In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the model class
class DualTaskClassifier(nn.Module):
    """1-D CNN + self-attention network with two classification heads.

    A shared trunk of three ``Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2)``
    stages is followed by a residual multi-head self-attention layer and a mean
    over the sequence axis. The pooled 256-d feature feeds two independent MLP
    heads: one producing mutation-type logits and one producing cancer-type
    logits.

    Submodule attribute names (``conv1``, ``bn1``, ``attention``,
    ``fc_mutation1``, ...) define the ``state_dict`` keys; renaming any of them
    breaks loading previously saved checkpoints.

    Args:
        num_mutation_classes: Number of logits produced by the mutation head.
        num_cancer_classes: Number of logits produced by the cancer head.
        input_channels: Number of feature channels expected on axis 1 of the
            input tensor.
    """

    # Each of the three MaxPool1d(2) stages floors the length in half, so at
    # least 2 ** 3 positions are needed for one position to reach attention.
    _MIN_SEQ_LEN = 2 ** 3

    def __init__(self, num_mutation_classes, num_cancer_classes, input_channels=5):
        super().__init__()

        # Stored so forward() can validate the channel axis of incoming data.
        self.input_channels = input_channels

        # Shared convolutional trunk (channel widths 64 -> 128 -> 256).
        self.conv1 = nn.Conv1d(input_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)

        # One pooling module reused after every conv stage (halves the length).
        self.pool = nn.MaxPool1d(2)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(256)

        # Shared dropout, applied to the pooled features and inside both heads.
        self.dropout = nn.Dropout(0.3)

        # batch_first=True: attention consumes (batch, seq, embed).
        self.attention = nn.MultiheadAttention(embed_dim=256, num_heads=4, batch_first=True)

        # Mutation-type head.
        self.fc_mutation1 = nn.Linear(256, 128)
        self.fc_mutation2 = nn.Linear(128, num_mutation_classes)

        # Cancer-type head.
        self.fc_cancer1 = nn.Linear(256, 128)
        self.fc_cancer2 = nn.Linear(128, num_cancer_classes)

    def forward(self, x):
        """Run the shared trunk and both classification heads.

        Args:
            x: Tensor of shape ``(batch, input_channels, seq_len)`` with
                ``seq_len >= 8``.

        Returns:
            Tuple ``(mutation_out, cancer_out)`` of raw logits with shapes
            ``(batch, num_mutation_classes)`` and ``(batch, num_cancer_classes)``.

        Raises:
            ValueError: If ``x`` is not 3-D, its channel axis does not equal
                ``input_channels``, or it is too short to survive the three
                pooling stages.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected a 3-D input (batch, channels, seq_len), got shape {tuple(x.shape)}"
            )
        _, channels, seq_len = x.shape
        if channels != self.input_channels:
            raise ValueError(
                f"expected {self.input_channels} input channels, got {channels}"
            )
        if seq_len < self._MIN_SEQ_LEN:
            raise ValueError(
                f"sequence length must be at least {self._MIN_SEQ_LEN} "
                f"(three MaxPool1d(2) stages), got {seq_len}"
            )

        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)

        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)

        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool(x)

        # Self-attention over the pooled positions, added back as a residual.
        x_attn = x.permute(0, 2, 1)
        attn_output, _ = self.attention(x_attn, x_attn, x_attn)
        x = x + attn_output.permute(0, 2, 1)

        # Global average over the sequence axis -> (batch, 256).
        x = torch.mean(x, dim=2)

        shared_feat = self.dropout(x)

        mutation_feat = F.relu(self.fc_mutation1(shared_feat))
        mutation_feat = self.dropout(mutation_feat)
        mutation_out = self.fc_mutation2(mutation_feat)

        cancer_feat = F.relu(self.fc_cancer1(shared_feat))
        cancer_feat = self.dropout(cancer_feat)
        cancer_out = self.fc_cancer2(cancer_feat)

        return mutation_out, cancer_out

# Instantiate the model. These sizes must match the checkpoint being loaded;
# otherwise load_state_dict raises on the mismatched layer shapes.
num_mutation_classes = 3  # output size of the mutation head in the checkpoint
num_cancer_classes = 3    # output size of the cancer head in the checkpoint
input_channels = 17       # class default is 5; 17 when using extra features (what this checkpoint expects)

WEIGHTS_PATH = "cancer_and_mutation_type_classifier.pth"

model = DualTaskClassifier(num_mutation_classes, num_cancer_classes, input_channels)

# Load weights. weights_only=True restricts unpickling to tensors and plain
# containers, so a tampered .pth file cannot execute arbitrary code
# (available since torch 1.13, the default from torch 2.6).
state_dict = torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

# Print architecture
print("Model Architecture:\n")
print(model)


Model Architecture:

DualTaskClassifier(
  (conv1): Conv1d(17, 64, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv2): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv3): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))
  (pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.3, inplace=False)
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
  )
  (fc_mutation1): Linear(in_features=256, out_features=128, bias=True)
  (fc_mutation2): Linear(in_features=128, out_features=3, bias=True)
  (fc_cancer1): Linear(in_features=256, out_features=128, bias=True)
  (f