In [5]:
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import KFold

from data_loader import load_data, get_balanced_data, normalize_features

import wandb
# Start a Weights & Biases run in the 'gene' project for this notebook session.
wandb.init(project='gene')

# Load the raw data and derive the balanced features/labels used below.
# NOTE(review): X_balanced is indexed with the keys 'mcg' and 'atac' in this
# notebook; the exact structure comes from data_loader -- confirm there.
data = load_data()
X_balanced, y_balanced = get_balanced_data(data)

# Sanity check: print the number of samples available for each modality.
print(*(len(X_balanced[modality]) for modality in ('mcg', 'atac')))


    gene   2mo   9mo  18mo  9mo-2mo  18mo-9mo   9mo/2mo  18mo/9mo  old-young  \
0  Rgs20  0.65  0.60  0.90    -0.05      0.30  0.923077  1.500000       0.25   
1  Sulf1  0.36  0.52  0.56     0.16      0.04  1.444444  1.076923       0.20   
2  Sulf1  0.43  0.59  0.64     0.16      0.05  1.372093  1.084746       0.21   
3   Eya1  0.68  0.62  0.47    -0.06     -0.15  0.911765  0.758065      -0.21   
4   Eya1  0.61  0.37  0.45    -0.24      0.08  0.606557  1.216216      -0.16   

   old/young  distance  
0   1.384615  151241.0  
1   1.555556  121205.0  
2   1.488372  170142.0  
3   0.691176  137980.0  
4   0.737705  138254.0  
      gene  log2(old/young)  distance
0    Itgb5         1.521456   32221.0
1   Begain         1.320315   37592.0
2   BEGAIN         1.320315   36795.0
3     Eya4         1.588159  247781.0
4  Gm27247         1.018367   97190.0
[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], [[0,

In [6]:
# --- Model architecture (passed to TwoHeadTransformerModel) ---
OUTPUT_DIM = 3  # number of classes (-1, 0, 1)
HIDDEN_DIM = 16  # embedding width of each modality branch
NUM_HEADS = 1  # attention heads per encoder layer
NUM_LAYERS = 2  # encoder layers per modality branch
DROPOUT = 0.2  # dropout used inside the encoder layers

# --- Optimisation ---
NUM_EPOCHS = 200  # also used as T_max of the cosine LR schedule
LR = 1e-3  # initial Adam learning rate


class TwoHeadTransformerModel(nn.Module):
    """Two-branch transformer classifier over mCG and ATAC feature sequences.

    Each modality is linearly embedded to ``hidden_dim``, passed through its
    own ``nn.TransformerEncoder``, pooled over the sequence dimension, and the
    two pooled vectors are concatenated and fed to a linear classifier.

    Args:
        mcg_input_dim: Feature size of each mCG sequence element.
        atac_input_dim: Feature size of each ATAC sequence element.
        hidden_dim: Embedding width used inside both branches.
        output_dim: Number of output classes (logits per sample).
        num_layers: Number of encoder layers per branch.
        num_heads: Number of attention heads per encoder layer.
        dropout: Dropout probability inside the encoder layers.
    """

    def __init__(self, mcg_input_dim, atac_input_dim, hidden_dim, output_dim, num_layers=2, num_heads=1, dropout=0.1):
        super().__init__()
        self.mcg_embedding = nn.Linear(mcg_input_dim, hidden_dim)
        self.atac_embedding = nn.Linear(atac_input_dim, hidden_dim)

        encoder_layers = nn.TransformerEncoderLayer(hidden_dim, num_heads, dim_feedforward=hidden_dim*2, dropout=dropout, batch_first=True)
        self.mcg_transformer = nn.TransformerEncoder(encoder_layers, num_layers)
        self.atac_transformer = nn.TransformerEncoder(encoder_layers, num_layers)

        self.classifier = nn.Linear(hidden_dim * 2, output_dim)
        self.hidden_dim = hidden_dim

    @staticmethod
    def _masked_mean(x, mask):
        """Average ``x`` over dim 1 using only positions where ``mask`` is non-zero."""
        valid = mask.bool().unsqueeze(-1)
        # Zero out padded positions so they contribute nothing to the sum.
        summed = x.masked_fill(~valid, 0.0).sum(dim=1)
        # Clamp so a fully padded row yields zeros instead of dividing by zero.
        counts = valid.sum(dim=1).clamp(min=1).to(x.dtype)
        return summed / counts

    def forward(self, mcg_x, mcg_mask, atac_x, atac_mask):
        """Compute class logits for a padded batch.

        Args:
            mcg_x: Padded mCG features, indexed as (batch, seq, feature).
            mcg_mask: (batch, seq) mask, non-zero for real positions and zero
                for padding.
            atac_x: Padded ATAC features, indexed as (batch, seq, feature).
            atac_mask: (batch, seq) mask with the same convention.

        Returns:
            Tensor of shape (batch, output_dim) with unnormalised logits.
        """
        mcg_x = self.mcg_embedding(mcg_x)
        atac_x = self.atac_embedding(atac_x)

        # src_key_padding_mask expects True at padded positions, hence the inversion.
        mcg_x = self.mcg_transformer(mcg_x, src_key_padding_mask=~mcg_mask.bool())
        atac_x = self.atac_transformer(atac_x, src_key_padding_mask=~atac_mask.bool())

        # Mask-aware average pooling: a plain mean over dim 1 would also
        # average the padded positions, making the pooled vector depend on
        # how much padding the batch happened to add.
        mcg_x = self._masked_mean(mcg_x, mcg_mask)
        atac_x = self._masked_mean(atac_x, atac_mask)

        # Concatenate MCG and ATAC embeddings
        combined_x = torch.cat((mcg_x, atac_x), dim=1)

        output = self.classifier(combined_x)
        return output

class CombinedGeneDataset(Dataset):
    """Dataset pairing per-gene mCG and ATAC sequences with a class label.

    Labels are expected in {-1, 0, 1} and are shifted by +1 on access so they
    become class indices {0, 1, 2}.
    """

    def __init__(self, mcg_data, atac_data, labels):
        """Keep references to the per-gene sequences and their labels."""
        self.labels = labels
        self.atac_data = atac_data
        self.mcg_data = mcg_data

    def __len__(self):
        """Number of samples, taken from the label collection."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(mcg, atac, label, mcg_mask, atac_mask)`` for sample ``idx``.

        The two feature tensors are float tensors, the label is a length-1
        long tensor, and each mask is an all-ones vector with one entry per
        sequence element.
        """
        mcg_features = torch.FloatTensor(self.mcg_data[idx])
        atac_features = torch.FloatTensor(self.atac_data[idx])

        # Add 1 to shift labels to 0, 1, 2
        shifted_label = self.labels[idx] + 1
        target = torch.LongTensor([shifted_label])

        mcg_valid = torch.ones(mcg_features.size(0))
        atac_valid = torch.ones(atac_features.size(0))
        return mcg_features, atac_features, target, mcg_valid, atac_valid

def combined_collate_fn(batch):
    """Collate variable-length (mcg, atac, label, mcg_mask, atac_mask) samples.

    Samples are ordered by descending mCG sequence length, then both
    modalities are zero-padded to the longest sequence in the batch and
    matching validity masks are built (1 for real positions, 0 for padding).

    Args:
        batch: Sequence of 5-tuples as produced by ``CombinedGeneDataset``.
            The per-sample masks are ignored; masks are rebuilt from the
            sequence lengths. The input sequence is not modified.

    Returns:
        Tuple ``(padded_mcg, padded_atac, labels, mcg_masks, atac_masks)``
        where the padded tensors are (batch, max_len, feature), the masks are
        (batch, max_len) and ``labels`` is the concatenation of the
        per-sample label tensors.
    """
    # sorted() instead of list.sort(): keeps the caller's list untouched and
    # also accepts non-list sequences such as tuples.
    ordered = sorted(batch, key=lambda sample: len(sample[0]), reverse=True)
    mcg_sequences, atac_sequences, labels, _mcg_masks, _atac_masks = zip(*ordered)

    mcg_lengths = [len(seq) for seq in mcg_sequences]
    atac_lengths = [len(seq) for seq in atac_sequences]
    mcg_max_len = max(mcg_lengths)
    atac_max_len = max(atac_lengths)
    batch_size = len(ordered)

    padded_mcg_seqs = torch.zeros(batch_size, mcg_max_len, mcg_sequences[0].size(1))
    padded_atac_seqs = torch.zeros(batch_size, atac_max_len, atac_sequences[0].size(1))
    padded_mcg_masks = torch.zeros(batch_size, mcg_max_len)
    padded_atac_masks = torch.zeros(batch_size, atac_max_len)

    for i, (mcg_seq, atac_seq, mcg_length, atac_length) in enumerate(zip(mcg_sequences, atac_sequences, mcg_lengths, atac_lengths)):
        padded_mcg_seqs[i, :mcg_length] = mcg_seq
        padded_atac_seqs[i, :atac_length] = atac_seq
        padded_mcg_masks[i, :mcg_length] = 1
        padded_atac_masks[i, :atac_length] = 1

    return padded_mcg_seqs, padded_atac_seqs, torch.cat(labels), padded_mcg_masks, padded_atac_masks

def _evaluate_accuracy(model, loader):
    """Return the fraction of correctly classified samples in ``loader``.

    Puts ``model`` in eval mode and runs without gradient tracking. Returns
    0.0 when the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for mcg_x, atac_x, batch_y, mcg_mask, atac_mask in loader:
            targets = batch_y.view(-1)
            outputs = model(mcg_x, mcg_mask, atac_x, atac_mask)
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            total += targets.size(0)
    return correct / max(total, 1)


def train_combined_model(X_train_mcg, X_train_atac, y_train, X_test_mcg, X_test_atac, y_test, exp_name):
    """Train a ``TwoHeadTransformerModel`` on one split and return test accuracy.

    Builds train/test loaders, trains for ``NUM_EPOCHS`` epochs with Adam and
    a cosine-annealed learning rate, and logs per-epoch metrics to wandb
    under the group ``exp_name``.

    Args:
        X_train_mcg, X_train_atac: Per-sample sequences of feature vectors
            for the training split (one list per modality).
        y_train: Training labels; shifted by +1 inside the dataset.
        X_test_mcg, X_test_atac, y_test: Same for the held-out split.
        exp_name: wandb group name shared by all runs of one experiment.

    Returns:
        Accuracy on the test split after the last epoch, as a float.
    """
    wandb.init(project='gene', group=exp_name)
    # Feature size is read from the first element of the first training sequence.
    mcg_input_dim = len(X_train_mcg[0][0])
    atac_input_dim = len(X_train_atac[0][0])

    train_dataset = CombinedGeneDataset(X_train_mcg, X_train_atac, y_train)
    test_dataset = CombinedGeneDataset(X_test_mcg, X_test_atac, y_test)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=combined_collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=combined_collate_fn)

    model = TwoHeadTransformerModel(mcg_input_dim, atac_input_dim, HIDDEN_DIM, OUTPUT_DIM, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, dropout=DROPOUT)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0.0001)

    for epoch in tqdm(range(NUM_EPOCHS)):
        model.train()
        # Read the LR before the scheduler steps so the logged value is the
        # one actually used for this epoch's updates.
        epoch_lr = optimizer.param_groups[0]['lr']
        total_loss = 0
        train_correct = 0
        train_total = 0
        for mcg_x, atac_x, batch_y, mcg_mask, atac_mask in train_loader:
            # view(-1) rather than squeeze(): squeeze() turns a size-1 batch
            # into a 0-dim tensor that no longer matches the (1, C) logits.
            targets = batch_y.view(-1)
            optimizer.zero_grad()
            outputs = model(mcg_x, mcg_mask, atac_x, atac_mask)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            train_correct += (outputs.argmax(dim=1) == targets).sum().item()
            train_total += targets.size(0)
        lr_scheduler.step()

        accuracy = _evaluate_accuracy(model, test_loader)
        wandb.log({'epoch': epoch, 'LR': epoch_lr, 'train_loss': total_loss/len(train_loader), 'train_accuracy': train_correct/train_total, 'test_accuracy': accuracy})

    final_accuracy = _evaluate_accuracy(model, test_loader)
    print(f'Final Test Accuracy: {final_accuracy:.4f}')
    return final_accuracy

In [7]:
import time
# Timestamped group name so all folds of this experiment are grouped in wandb.
exp_name = f'two-head-{time.strftime("%Y%m%d-%H%M%S")}'

# Record the hyperparameters on the currently active wandb run.
wandb.config.update(dict(
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    dropout=DROPOUT,
    lr=LR,
    output_dim=OUTPUT_DIM,
    num_epochs=NUM_EPOCHS,
))

# 5-fold cross-validation with a fixed shuffle seed; the mCG sample list
# defines the indices that are then applied to ATAC and the labels as well.
kf = KFold(n_splits=5, shuffle=True, random_state=25)
accuracies = []
for train_index, test_index in kf.split(X_balanced['mcg']):
    X_train_mcg = [X_balanced['mcg'][i] for i in train_index]
    X_test_mcg = [X_balanced['mcg'][i] for i in test_index]
    X_train_atac = [X_balanced['atac'][i] for i in train_index]
    X_test_atac = [X_balanced['atac'][i] for i in test_index]
    y_train = [y_balanced[i] for i in train_index]
    y_test = [y_balanced[i] for i in test_index]

    # NOTE(review): presumably normalisation statistics come from the train
    # split only -- confirm in data_loader.normalize_features.
    X_train_mcg_normalized, X_test_mcg_normalized = normalize_features(X_train_mcg, X_test_mcg)
    X_train_atac_normalized, X_test_atac_normalized = normalize_features(X_train_atac, X_test_atac)

    accuracies.append(
        train_combined_model(
            X_train_mcg_normalized,
            X_train_atac_normalized,
            y_train,
            X_test_mcg_normalized,
            X_test_atac_normalized,
            y_test,
            exp_name=exp_name,
        )
    )

# Summarise the per-fold test accuracies.
print(f'Mean Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}')

 63%|██████▎   | 126/200 [01:34<00:55,  1.33it/s]


KeyboardInterrupt: 