In [None]:
from torch import cuda

# Report CUDA availability and choose the device used by the rest of the notebook.
# get_device_name()/current_device() raise when CUDA is unavailable, so they are
# only queried when a GPU is present; otherwise we fall back to the CPU instead
# of crashing before _device is ever assigned.
_cuda_ok = cuda.is_available()
print(_cuda_ok)
if _cuda_ok:
    print(cuda.get_device_name())
    print(cuda.current_device())

_device = 'cuda' if _cuda_ok else 'cpu'
print(f"available device {_device}")

In [None]:
"""
Create a simple multitask learning architecture with three tasks:
1. Promoter detection.
2. Splice-site detection.
3. poly-A detection.
"""
from torch import nn

class MTModel(nn.Module):
    """
    Multitask model: one shared trunk whose output is fed to three task heads.

    Args:
        shared_parameters: Module applied to the input first; its output is
            passed, unchanged, to every head.
        promoter_head: Head producing the promoter-detection output.
        splice_site_head: Head producing the splice-site-detection output.
        polya_head: Head producing the poly-A-detection output.
    """
    def __init__(self, shared_parameters, promoter_head, splice_site_head, polya_head):
        super().__init__()
        # Assignment order determines the order of children, parameters()
        # and state_dict() keys, so it is kept stable.
        self.shared_layer = shared_parameters
        self.promoter_layer = promoter_head
        self.splice_site_layer = splice_site_head
        self.polya_layer = polya_head

    def forward(self, x):
        """Return ``(promoter_out, splice_site_out, polya_out)`` for input ``x``."""
        features = self.shared_layer(x)
        task_heads = (self.promoter_layer, self.splice_site_layer, self.polya_layer)
        return tuple(head(features) for head in task_heads)

class SharedParameter(nn.Module):
    """
    Shared trunk applied to the input before every task head.

    A single fully-connected layer followed by ReLU. In this notebook its output
    is consumed by heads whose first layer expects 512 input features, so the
    defaults keep ``out_features`` at 512.

    Args:
        device: Device on which the layer's parameters are allocated.
            Defaults to "cpu", like the task heads.
        in_features: Width of the incoming feature vector (default 512).
        out_features: Width of the shared representation (default 512).
    """
    def __init__(self, device="cpu", in_features=512, out_features=512):
        super().__init__()
        self.stack = nn.Sequential(
            nn.Linear(in_features, out_features, device=device),
            nn.ReLU(),
        )

    def forward(self, x):
        """Map ``x`` (last dimension ``in_features``) to the shared representation."""
        x = self.stack(x)
        return x

class PromoterHead(nn.Module):
    """
    Promoter-detection head.

    Layer layout follows DeePromoter (Oubounyt et al., 2019): a 512 -> 128
    projection with ReLU and dropout (p=0.5), then a 128 -> 1 projection passed
    through a sigmoid, so the single output lies in (0, 1).

    Args:
        device: Device on which the linear layers' parameters are allocated.
    """
    def __init__(self, device="cpu"):
        super().__init__()
        layers = [
            nn.Linear(512, out_features=128, device=device),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, out_features=1, device=device),
            nn.Sigmoid(),
        ]
        # Unpacking keeps the same submodule indices ("0".."4") as a literal Sequential.
        self.stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid promoter score for ``x`` (last dimension 512)."""
        return self.stack(x)

class SpliceSiteHead(nn.Module):
    """
    Splice-site detection head.

    Layer layout follows Splice2Deep (Albaradei et al., 2020): one 512 -> 2
    linear layer followed by a softmax over the two output classes.

    Args:
        device: Device on which the linear layer's parameters are allocated.

    NOTE(review): the output is already softmax-normalised. If this head is
    trained with nn.CrossEntropyLoss (which expects raw logits), softmax is
    applied twice — confirm the loss used downstream.
    """
    def __init__(self, device="cpu"):
        super().__init__()
        self.stack = nn.Sequential(
            nn.Linear(512, out_features=2, device=device),
            # Explicit dim: nn.Softmax() without one is deprecated, and for 3-D
            # input it would normalise over dim 0 (the batch) instead of the
            # class dimension. dim=-1 matches the implicit choice for 1-D/2-D input.
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Return class probabilities (last dimension 2) for ``x`` (last dimension 512)."""
        x = self.stack(x)
        return x


class PolyAHead(nn.Module):
    """
    Poly-A detection head.

    Layer layout follows DeeReCT-PolyA (Xia et al., 2018): a 512 -> 64 projection
    with ReLU, then 64 -> 2 followed by a softmax over the two output classes.

    Args:
        device: Device on which the linear layers' parameters are allocated.

    NOTE(review): the output is already softmax-normalised. If this head is
    trained with nn.CrossEntropyLoss (which expects raw logits), softmax is
    applied twice — confirm the loss used downstream.
    """
    def __init__(self, device='cpu'):
        super().__init__()
        self.stack = nn.Sequential(
            nn.Linear(512, 64, device=device),
            nn.ReLU(),
            nn.Linear(64, 2, device=device),
            # Explicit dim: the implicit choice is deprecated and picks dim 0
            # (the batch) for 3-D input. dim=-1 always normalises over the classes.
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Return class probabilities (last dimension 2) for ``x`` (last dimension 512)."""
        x = self.stack(x)
        return x


# Build every component on the selected device. Keep this construction order:
# default weight initialisation draws from the global RNG, so reordering would
# change the initial weights under a fixed seed.
polya_head = PolyAHead(_device)
promoter_head = PromoterHead(_device)
splice_head = SpliceSiteHead(_device)
shared_parameter = SharedParameter(_device)

# Wire the shared trunk and the three heads together, then move the assembled
# model to _device.
model = MTModel(
    shared_parameters=shared_parameter,
    promoter_head=promoter_head,
    splice_site_head=splice_head,
    polya_head=polya_head,
).to(_device)
print(model)

