In [14]:
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import KFold

from data_loader import load_data, get_balanced_data, normalize_features

import wandb
wandb.init(project='gene')

data = load_data()
X_balanced, y_balanced = get_balanced_data(data)
print('got balanced data')
for key in X_balanced:
    print(key, 'length:', len(X_balanced[key]))


  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)


Processed mcg data
Processed genebody data
Processed atac data
Processed hic data
zero: 11118, non-zero: 3467
got balanced data
feature_type length: 5200


In [38]:
HIDDEN_DIM = 16
NUM_LAYERS = 2
NUM_HEADS = 1
DROPOUT = 0.2
LR = 0.001
OUTPUT_DIM = 3  # number of classes (-1, 0, 1)
NUM_EPOCHS = 200
BATCH_SIZE = 32
FEATURE_TYPES = ['mcg', 'genebody', 'atac', 'hic']

# HIDDEN_DIM = 64
# NUM_LAYERS = 1
# NUM_HEADS = 1
# DROPOUT = 0.3
# LR = 1e-4
# OUTPUT_DIM = 3  # number of classes (-1, 0, 1)
# NUM_EPOCHS = 200
# BATCH_SIZE = 16


class NHeadTransformerModel(nn.Module):
    def __init__(self, head_input_dims, hidden_dim, output_dim, num_layers=2, num_heads=1, dropout=0.1):
        super(NHeadTransformerModel, self).__init__()
        self.embeddings = nn.ModuleList()
        for i, head_input_dim in enumerate(head_input_dims):
            self.embeddings.append(nn.Linear(head_input_dim, hidden_dim))
        
        # TODO: may need to use tanh in attention instead of softmax
        encoder_layers = nn.TransformerEncoderLayer(hidden_dim, num_heads, dim_feedforward=hidden_dim*2, dropout=dropout, batch_first=True, norm_first=True)
        self.transformers = nn.ModuleList()
        for _ in range(len(head_input_dims)):
            self.transformers.append(nn.TransformerEncoder(encoder_layers, num_layers))
        
        self.classifier = nn.Linear(hidden_dim * len(head_input_dims), output_dim)
        self.hidden_dim = hidden_dim

    def forward(self, inputs, masks):
        embeddings = [self.embeddings[i](inputs[i]) for i in range(len(inputs))]
        
        transformed = [self.transformers[i](embeddings[i], src_key_padding_mask=~masks[i].bool()) for i in range(len(inputs) - 1)]
        
        # Global average pooling
        pooled = [x.mean(dim=1) for x in transformed]
        
        # Concatenate MCG and ATAC embeddings
        combined_x = torch.cat(pooled, dim=1)
        
        output = self.classifier(combined_x)
        return output

class CombinedGeneDataset(Dataset):
    def __init__(self, labels, inputs):
        self.inputs = inputs
        self.labels = labels
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        inputs = [torch.FloatTensor(self.inputs[i][idx]) for i in range(len(self.inputs))]
        label = torch.LongTensor([self.labels[idx] + 1])  # Add 1 to shift labels to 0, 1, 2
        masks = [torch.ones(len(inputs[i])) for i in range(len(inputs))]
        return inputs + [label] + masks

def combined_collate_fn(batch):
    batch.sort(key=lambda x: len(x[0]), reverse=True)
    data_types = list(zip(*batch))
    
    labels = data_types[len(FEATURE_TYPES)]
    data_types = data_types[:len(FEATURE_TYPES)]

    # inputs:
    #  [[mcg seq1, seq2], [atac seq1, seq2], [hic seq1, seq2]]
    
    lengths = [[len(seq) for seq in data_type] for data_type in data_types]
    max_lens = [max(ls) for ls in lengths]
    
    padded_seqs = [torch.zeros(len(data_type), max_len, data_type[0].size(1)) 
                   for data_type, max_len in zip(data_types, max_lens)]
    padded_masks = [torch.zeros(len(data_type), max_len) 
                    for data_type, max_len in zip(data_types, max_lens)]
    
    for i in range(len(batch)):
        for j, (seq, length) in enumerate(zip(data_types, lengths)):
            padded_seqs[j][i, :length[i]] = seq[i]
            padded_masks[j][i, :length[i]] = 1
    
    return padded_seqs + [torch.cat(labels)] + padded_masks

    

def train_combined_model(X_trains, y_train, X_tests, y_test, exp_name, fold_idx):
    wandb.init(project='gene', group=exp_name, name=f'fold-{fold_idx}')
    
    train_dataset = CombinedGeneDataset(y_train, X_trains)
    test_dataset = CombinedGeneDataset(y_test, X_tests)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=combined_collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=combined_collate_fn)

    model = NHeadTransformerModel(head_input_dims=[len(x[0]) for x in X_trains], hidden_dim=HIDDEN_DIM, output_dim=OUTPUT_DIM, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, dropout=DROPOUT)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0.0001)
    # Create the OneCycleLR scheduler
    # lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR, total_steps=NUM_EPOCHS,
    #                           pct_start=0.8, anneal_strategy='cos',
    #                           cycle_momentum=False, div_factor=5.0,
    #                           final_div_factor=10.0)

    for epoch in tqdm(range(NUM_EPOCHS)):
        model.train()
        total_loss = 0
        train_correct = 0
        train_total = 0
        for inputs in train_loader:
            optimizer.zero_grad()
            data = inputs[:, :len(FEATURE_TYPES)-1]
            labels = inputs[len(FEATURE_TYPES)]
            masks = inputs[len(FEATURE_TYPES) + 1:]
            outputs = model(data, masks)
            loss = criterion(outputs, labels.squeeze())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            train_correct += (outputs.argmax(dim=1) == labels.squeeze()).sum().item()
            train_total += labels.size(0)
        lr_scheduler.step()

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs in test_loader:
                data = inputs[:, :len(FEATURE_TYPES)-1]
                labels = inputs[len(FEATURE_TYPES)]
                masks = inputs[len(FEATURE_TYPES) + 1:]
                outputs = model(data, masks)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels.squeeze()).sum().item()
        
        accuracy = correct / total
        # print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {total_loss/len(train_loader):.4f}, Train Accuracy: {train_correct/train_total:.4f}, Test Accuracy: {accuracy:.4f}')
        wandb.log({'epoch': epoch, 'LR': optimizer.param_groups[0]['lr'], 'train_loss': total_loss/len(train_loader), 'train_accuracy': train_correct/train_total, 'test_accuracy': accuracy})

    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for inputs in test_loader:
            data = inputs[:, :len(FEATURE_TYPES)-1]
            labels = inputs[len(FEATURE_TYPES)]
            masks = inputs[len(FEATURE_TYPES) + 1:]
            outputs = model(data, masks)
            _, predicted = torch.max(outputs.data, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    final_accuracy = sum(np.array(all_predictions) == np.array(all_labels).squeeze()) / len(all_labels)
    print(f'Final Test Accuracy: {final_accuracy:.4f}')
    return final_accuracy

In [39]:
import time
exp_name = f'two-head-{time.strftime("%Y%m%d-%H%M%S")}'

wandb.config.update({
    'hidden_dim': HIDDEN_DIM,
    'num_layers': NUM_LAYERS,
    'num_heads': NUM_HEADS,
    'dropout': DROPOUT,
    'lr': LR,
    'output_dim': OUTPUT_DIM,
    'num_epochs': NUM_EPOCHS,
    'batch_size': BATCH_SIZE
})


kf = KFold(n_splits=5, shuffle=True, random_state=25)
accuracies = []
first_key = list(X_balanced.keys())[0]
for i, (train_index, test_index) in enumerate(kf.split(X_balanced[first_key])):
    X_trains = {key: [X_balanced[key][i] for i in train_index] for key in X_balanced}
    y_train = [y_balanced[i] for i in train_index]
    X_tests = {key: [X_balanced[key][i] for i in test_index] for key in X_balanced}
    y_test = [y_balanced[i] for i in test_index]
    
    X_trains_normalized_dict, X_tests_normalized_dict = {}, {}
    for key in X_balanced:
        X_trains_normalized_dict[key], X_tests_normalized_dict[key] = normalize_features(X_trains[key], X_tests[key])
    X_trains_normalized = [X_trains_normalized_dict[key] for key in X_trains_normalized_dict]
    X_tests_normalized = [X_tests_normalized_dict[key] for key in X_tests_normalized_dict]
    
    accuracies.append(train_combined_model(X_trains_normalized, y_train, X_tests_normalized, y_test, exp_name=exp_name, fold_idx=i))

print(f'Mean Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}')

  0%|          | 0/200 [00:00<?, ?it/s]


IndexError: list index out of range