## Imports

In [1]:
import gc
import pickle
import platform
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import KFold
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from data_loader import load_data, get_balanced_data, normalize_features
# import wandb
# wandb.init(project='gene')

  from .autonotebook import tqdm as notebook_tqdm


## Session Parameters

In [2]:
# Report the interpreter this kernel is running; a shell `python` is looked up
# on PATH and may be a different installation than the kernel's.
print(f"Python {platform.python_version()}")

Python 3.10.15


In [3]:
# One value per line: torch version, numpy version, CUDA availability, CUDA toolkit version.
for version_info in (torch.__version__, np.__version__, torch.cuda.is_available(), torch.version.cuda):
    print(version_info)

1.12.1
1.26.4
True
11.3


In [4]:
# Pick the compute device once; the model, dataset and training cells below read
# this module-level `device`.
if torch.cuda.is_available():
    print(torch.cuda.current_device(), torch.cuda.device_count(), torch.cuda.get_device_name(0))
    device = torch.device("cuda:0")
else:
    # The CUDA queries above raise without a GPU, so fall back to the CPU and
    # keep the notebook runnable (slowly) on such machines.
    print("CUDA not available; falling back to CPU")
    device = torch.device("cpu")
print(device)

0 1 NVIDIA A40
cuda:0


## Loading Data

In [5]:
# Load the project data with "DEG" as the target, restricted to ct="Oligo_NN"
# (presumably a cell-type selector -- confirm in data_loader).
data = load_data(y_val="DEG", ct="Oligo_NN")
# NOTE(review): method='balanced' presumably resamples the classes; see data_loader for the exact rule.
X_balanced, y_balanced = get_balanced_data(data, method='balanced', y_val="DEG")
FEATURE_TYPES = ['mcg', 'atac', 'hic', 'genebody']

# Print how many samples each feature type holds after balancing.
for feature_name, samples in X_balanced.items():
    print(feature_name, len(samples))

Processed mcg data
Processed genebody data
Processed atac data
Processed hic data
zero: 4115, down: 874, up: 557
mcg 1671
genebody 1671
atac 1671
hic 1671


## Model Functions

In [25]:
# Hyperparameters shared by the model definition and the training function below.
HIDDEN_DIM = 64  # embedding width passed to the model as hidden_dim
NUM_LAYERS = 4  # transformer encoder layers per branch
NUM_HEADS = 8  # attention heads per encoder layer
DROPOUT = 0.1  # dropout passed to the encoder layers
LR = 0.001  # initial learning rate given to Adam
OUTPUT_DIM = 3  # number of classes (-1, 0, 1)
NUM_EPOCHS = 50  # maximum epochs per fold; also T_max of the cosine LR schedule
BATCH_SIZE = 32  # batch size of the train and test DataLoaders

class ThreeHeadTransformerModel(nn.Module):
    """Three parallel transformer encoders, one per input modality, feeding one classifier.

    Each branch projects its tokens to ``hidden_dim``, prepends a learned CLS
    token, runs a transformer encoder and keeps the CLS output as the summary
    of that sequence. The three summaries are concatenated, mixed by a linear
    layer and mapped to class logits.

    Args:
        input_1_dim, input_2_dim, input_3_dim: per-token feature size of each modality.
        hidden_dim: embedding width used inside every branch.
        output_dim: number of classes (size of the returned logits).
        num_layers: encoder layers per branch.
        num_heads: attention heads per encoder layer.
        dropout: dropout probability inside the encoder layers.
    """

    def __init__(self, input_1_dim, input_2_dim, input_3_dim, hidden_dim, output_dim, num_layers=2, num_heads=1, dropout=0.1):
        super().__init__()
        # Learned CLS tokens, one per branch; forward() prepends them to the sequences.
        self.cls_1_embedding = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.cls_2_embedding = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.cls_3_embedding = nn.Parameter(torch.randn(1, 1, hidden_dim))
        # Never read by forward(); kept so the parameter set (and state_dict keys)
        # still matches models already trained and saved with this class.
        self.combined_cls_embedding = nn.Parameter(torch.randn(1, 1, hidden_dim))

        # Per-token projections from each modality's feature size to hidden_dim.
        self.input_1_embedding = nn.Linear(input_1_dim, hidden_dim)
        self.input_2_embedding = nn.Linear(input_2_dim, hidden_dim)
        self.input_3_embedding = nn.Linear(input_3_dim, hidden_dim)

        # Transformer layers. The layer below is only a template: nn.TransformerEncoder
        # deep-copies it for every layer, so the branches do not share weights
        # (they merely start from identical initial values).
        # TODO: may need to use tanh in attention instead of softmax
        encoder_layers = nn.TransformerEncoderLayer(hidden_dim, num_heads, dim_feedforward=hidden_dim*4, dropout=dropout, batch_first=True, norm_first=True)
        self.transformer_1 = nn.TransformerEncoder(encoder_layers, num_layers)
        self.transformer_2 = nn.TransformerEncoder(encoder_layers, num_layers)
        self.transformer_3 = nn.TransformerEncoder(encoder_layers, num_layers)

        # Fuse the three CLS summaries with a linear layer, then classify.
        self.linear = nn.Linear(hidden_dim * 3, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, output_dim)

    @staticmethod
    def _encode_branch(tokens, mask, embedding, cls_token, encoder):
        """Embed one modality, prepend its CLS token, encode it, and return the CLS output."""
        embedded = embedding(tokens)
        batch_size = embedded.size(0)
        sequence = torch.cat((cls_token.expand(batch_size, -1, -1), embedded), dim=1)

        # src_key_padding_mask uses True for positions to ignore, i.e. the inverse
        # of the incoming validity mask. The CLS column is always attended.
        # Built on the mask's own device rather than a notebook-level global, so
        # the module works wherever its inputs live.
        cls_column = torch.zeros(batch_size, 1, dtype=torch.bool, device=mask.device)
        padding_mask = torch.cat((cls_column, ~mask.bool()), dim=1)

        encoded = encoder(sequence, src_key_padding_mask=padding_mask)
        # CLS pooling: average pooling is avoided on purpose because the
        # sequence length itself is informative.
        return encoded[:, 0, :]

    def forward(self, _x1, _x2, _x3, x1_mask, x2_mask, x3_mask):
        """Compute class logits for a batch.

        Args:
            _x1, _x2, _x3: padded inputs of shape (batch, seq_len_k, input_k_dim).
            x1_mask, x2_mask, x3_mask: (batch, seq_len_k) masks, non-zero for real
                tokens and zero for padding.

        Returns:
            Tensor of shape (batch, output_dim) with unnormalised class scores.
        """
        _z1 = self._encode_branch(_x1, x1_mask, self.input_1_embedding, self.cls_1_embedding, self.transformer_1)
        _z2 = self._encode_branch(_x2, x2_mask, self.input_2_embedding, self.cls_2_embedding, self.transformer_2)
        _z3 = self._encode_branch(_x3, x3_mask, self.input_3_embedding, self.cls_3_embedding, self.transformer_3)

        # Concatenate the three branch summaries: (batch, 3 * hidden_dim).
        _z = torch.cat((_z1, _z2, _z3), dim=1)
        _z = self.linear(_z)  # (batch, hidden_dim)
        # Final classifier
        return self.classifier(_z)

In [26]:
# Dataset Object
class ThreeGeneDataset(Dataset):
    """Dataset yielding three variable-length feature sequences and a class label per sample.

    Samples are produced as CPU tensors. Moving each sample to the GPU here
    would force the collate function to copy it back for padding, so the
    device transfer is done once per batch in ``comb_collate_fn`` instead.

    Args:
        _data1, _data2, _data3: indexable collections; item ``idx`` of each is a
            2-D (sequence_length, feature_dim) array-like for that modality.
        labels: indexable collection of class labels in {-1, 0, 1}.
    """

    def __init__(self, _data1, _data2, _data3, labels):
        self.data1 = _data1
        self.data2 = _data2
        self.data3 = _data3
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(data1, data2, data3, label, mask1, mask2, mask3)`` for sample ``idx``."""
        data1 = torch.as_tensor(self.data1[idx], dtype=torch.float32)
        data2 = torch.as_tensor(self.data2[idx], dtype=torch.float32)
        data3 = torch.as_tensor(self.data3[idx], dtype=torch.float32)

        # Add 1 to shift labels from (-1, 0, 1) to the class indices (0, 1, 2).
        label = torch.tensor([self.labels[idx] + 1], dtype=torch.long)

        # Per-sample masks are all ones; padding positions are added at collate time.
        mask1 = torch.ones(len(data1))
        mask2 = torch.ones(len(data2))
        mask3 = torch.ones(len(data3))

        return data1, data2, data3, label, mask1, mask2, mask3


def _pad_sequences(sequences):
    """Stack (length_i, feature_dim) tensors into a zero-padded batch and a 0/1 validity mask."""
    lengths = [len(seq) for seq in sequences]
    max_len = max(lengths)
    padded = torch.zeros(len(sequences), max_len, sequences[0].size(1))
    mask = torch.zeros(len(sequences), max_len)
    for row, (seq, length) in enumerate(zip(sequences, lengths)):
        padded[row, :length] = seq
        mask[row, :length] = 1
    return padded, mask


## handling the batches
def comb_collate_fn(batch, target_device=None):
    """Pad a list of ``ThreeGeneDataset`` samples into batch tensors on one device.

    Samples keep the order in which the DataLoader supplied them, so with
    ``shuffle=False`` the batch rows line up with the underlying dataset order.
    The per-sample masks in ``batch`` are ignored; masks are rebuilt from the
    sequence lengths.

    Args:
        batch: list of 7-tuples as returned by ``ThreeGeneDataset.__getitem__``.
        target_device: device for the returned tensors. Defaults to the
            notebook-level ``device``.

    Returns:
        ``(x1, x2, x3, labels, mask1, mask2, mask3)`` where ``xk`` is
        (batch, max_len_k, feature_dim_k), ``maskk`` is (batch, max_len_k) with
        1 for real tokens and 0 for padding, and ``labels`` is a 1-D long tensor.
    """
    if target_device is None:
        target_device = device  # notebook-level default set in the session cell

    d1_sequences, d2_sequences, d3_sequences, labels, _, _, _ = zip(*batch)

    d1_padded_seqs, d1_padded_masks = _pad_sequences(d1_sequences)
    d2_padded_seqs, d2_padded_masks = _pad_sequences(d2_sequences)
    d3_padded_seqs, d3_padded_masks = _pad_sequences(d3_sequences)

    return (
        d1_padded_seqs.to(target_device),
        d2_padded_seqs.to(target_device),
        d3_padded_seqs.to(target_device),
        torch.cat(labels).to(target_device),
        d1_padded_masks.to(target_device),
        d2_padded_masks.to(target_device),
        d3_padded_masks.to(target_device),
    )

In [27]:
class EarlyStopper:
    """Signal when the validation loss has stopped improving.

    Args:
        patience: number of "bad" epochs tolerated before stopping.
        min_delta: margin above the best loss that an epoch must exceed to
            count as bad.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        """Record one epoch's loss and return True once training should stop."""
        # A new best loss resets the bad-epoch counter.
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        # Losses within min_delta of the best neither reset nor advance the counter.
        is_bad_epoch = validation_loss > (self.min_validation_loss + self.min_delta)
        if not is_bad_epoch:
            return False

        self.counter += 1
        return self.counter >= self.patience
        
def _evaluate_3head_model(model, loader, criterion):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns:
        ``(mean_batch_loss, accuracy, predictions, labels)`` where the last two
        are flat lists of class indices in loader order.
    """
    model.eval()
    loss_sum = 0.0
    predictions = []
    labels = []
    with torch.no_grad():
        for x1, x2, x3, batch_y, mask1, mask2, mask3 in loader:
            # reshape(-1) rather than squeeze(): squeeze() collapses a
            # single-sample batch to a 0-d tensor, which CrossEntropyLoss
            # rejects against (1, num_classes) logits.
            targets = batch_y.reshape(-1)
            outputs = model(x1, x2, x3, mask1, mask2, mask3)
            loss_sum += criterion(outputs, targets).item()
            predictions.extend(outputs.argmax(dim=1).cpu().tolist())
            labels.extend(targets.cpu().tolist())
    correct = sum(pred == target for pred, target in zip(predictions, labels))
    return loss_sum / len(loader), correct / len(labels), predictions, labels


def train_3head_model(X_train_1, X_train_2, X_train_3, y_train, X_test_1, X_test_2, X_test_3, y_test, exp_name, fold_idx):
    """Train a ``ThreeHeadTransformerModel`` on one fold and evaluate it on the held-out part.

    Args:
        X_train_1, X_train_2, X_train_3: training samples of the three modalities;
            ``X_train_k[0][0]`` must be a feature vector, its length sets the
            input size of branch ``k``.
        y_train: training labels in {-1, 0, 1}.
        X_test_1, X_test_2, X_test_3, y_test: the held-out counterparts.
        exp_name, fold_idx: experiment identifiers; accepted for experiment
            tracking but not used by the current code.

    Returns:
        ``((predictions, labels), (train_accuracies, train_losses, test_accuracies,
        test_losses), model)``. Predictions and labels are class indices (0, 1, 2)
        for the held-out samples; the four metric lists have one entry per epoch run.

    NOTE(review): early stopping monitors the loss on the held-out fold and the
    returned model is the one from the last epoch run (not the best epoch), so
    the reported test metrics are not from a fully untouched set.
    """
    input_1_dim = len(X_train_1[0][0])
    input_2_dim = len(X_train_2[0][0])
    input_3_dim = len(X_train_3[0][0])

    train_dataset = ThreeGeneDataset(X_train_1, X_train_2, X_train_3, y_train)
    test_dataset = ThreeGeneDataset(X_test_1, X_test_2, X_test_3, y_test)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=comb_collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=comb_collate_fn)

    model = ThreeHeadTransformerModel(input_1_dim, input_2_dim, input_3_dim, HIDDEN_DIM, OUTPUT_DIM, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, dropout=DROPOUT).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR)
    # Cosine decay from LR down to eta_min over the full epoch budget; stepped once per epoch.
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0.0001)

    train_accuracies = []
    test_accuracies = []
    train_losses = []
    test_losses = []
    early_stopper = EarlyStopper(patience=5, min_delta=0.03)
    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0
        train_correct = 0
        train_total = 0
        for x1, x2, x3, batch_y, mask1, mask2, mask3 in train_loader:
            targets = batch_y.reshape(-1)  # see _evaluate_3head_model for why not squeeze()
            optimizer.zero_grad()
            outputs = model(x1, x2, x3, mask1, mask2, mask3)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            train_correct += (outputs.argmax(dim=1) == targets).sum().item()
            train_total += targets.size(0)
        lr_scheduler.step()

        train_loss = total_loss / len(train_loader)
        train_accuracy = train_correct / train_total
        test_loss, test_accuracy, _, _ = _evaluate_3head_model(model, test_loader, criterion)

        train_accuracies.append(train_accuracy)
        train_losses.append(train_loss)
        test_accuracies.append(test_accuracy)
        test_losses.append(test_loss)
        print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')
        if early_stopper.early_stop(test_loss):
            print("Early Stop")
            break

    # Final pass to collect per-sample predictions from the model as returned.
    _, _, all_predictions, all_labels = _evaluate_3head_model(model, test_loader, criterion)

    return (all_predictions, all_labels), (train_accuracies, train_losses, test_accuracies, test_losses), model

In [28]:
# 5-fold splitter with shuffling and a fixed random_state, so the folds are the same on every run.
kf = KFold(n_splits=5, shuffle=True, random_state=25)

In [29]:
# Ablation: three-head model on mcg + atac + hic (genebody left out), one model per fold.
exp_name = f'three-head-{time.strftime("%Y%m%d-%H%M%S")}'

# Create the output folder before training starts; otherwise a missing directory
# would only surface when the first finished fold tries to write its results.
results_dir = Path("../3head_results/no_genebody")
results_dir.mkdir(parents=True, exist_ok=True)

for i, (train_index, test_index) in enumerate(kf.split(X_balanced['mcg'])):
    # The same fold indices slice every modality and the labels, keeping samples aligned.
    # `j` is used inside the comprehensions so the fold index `i` is not shadowed.
    X_train_mcg, X_test_mcg = [X_balanced['mcg'][j] for j in train_index], [X_balanced['mcg'][j] for j in test_index]
    X_train_atac, X_test_atac = [X_balanced['atac'][j] for j in train_index], [X_balanced['atac'][j] for j in test_index]
    X_train_hic, X_test_hic = [X_balanced['hic'][j] for j in train_index], [X_balanced['hic'][j] for j in test_index]

    y_train, y_test = [y_balanced[j] for j in train_index], [y_balanced[j] for j in test_index]

    # Normalise each modality per fold (train and test parts are passed together).
    X_train_mcg_normalized, X_test_mcg_normalized = normalize_features(X_train_mcg, X_test_mcg)
    X_train_atac_normalized, X_test_atac_normalized = normalize_features(X_train_atac, X_test_atac)
    X_train_hic_normalized, X_test_hic_normalized = normalize_features(X_train_hic, X_test_hic)

    elem_stats, epoch_stats, model = train_3head_model(X_train_mcg_normalized, X_train_atac_normalized, X_train_hic_normalized, y_train,
                                                          X_test_mcg_normalized, X_test_atac_normalized, X_test_hic_normalized, y_test, exp_name=exp_name, fold_idx=i)

    # Per-sample predictions and per-epoch curves for this fold.
    df_elem = pd.DataFrame({'preds' : elem_stats[0], 'labels' : elem_stats[1]})
    df_epoch = pd.DataFrame({'train_acc' : epoch_stats[0], 'train_loss' : epoch_stats[1],
                             'test_acc' : epoch_stats[2], 'test_loss' : epoch_stats[3]})
    df_elem.to_csv(results_dir / f"res_bal_elem_{i}.csv")
    df_epoch.to_csv(results_dir / f"res_bal_epoch_{i}.csv")

    torch.save(model, results_dir / f"model_{i}.pt")

    # Release the fold's model and training data before the next fold starts.
    del model
    del X_train_mcg, X_train_mcg_normalized
    del X_train_atac, X_train_atac_normalized
    del X_train_hic, X_train_hic_normalized

Epoch [1/50], Train Loss: 1.1726, Train Accuracy: 0.4349, Test Loss: 0.9585, Test Accuracy: 0.5493
Epoch [2/50], Train Loss: 0.9203, Train Accuracy: 0.5524, Test Loss: 0.8185, Test Accuracy: 0.6299
Epoch [3/50], Train Loss: 0.9036, Train Accuracy: 0.5554, Test Loss: 0.8606, Test Accuracy: 0.5731
Epoch [4/50], Train Loss: 0.8940, Train Accuracy: 0.5831, Test Loss: 0.8710, Test Accuracy: 0.5552
Epoch [5/50], Train Loss: 0.8720, Train Accuracy: 0.5778, Test Loss: 0.7916, Test Accuracy: 0.6149
Epoch [6/50], Train Loss: 0.8283, Train Accuracy: 0.5981, Test Loss: 0.7500, Test Accuracy: 0.6299
Epoch [7/50], Train Loss: 0.8079, Train Accuracy: 0.6078, Test Loss: 0.7825, Test Accuracy: 0.6358
Epoch [8/50], Train Loss: 0.8116, Train Accuracy: 0.6355, Test Loss: 0.7448, Test Accuracy: 0.6299
Epoch [9/50], Train Loss: 0.7586, Train Accuracy: 0.6475, Test Loss: 0.7782, Test Accuracy: 0.6269
Epoch [10/50], Train Loss: 0.7661, Train Accuracy: 0.6295, Test Loss: 0.7280, Test Accuracy: 0.6657
Epoch [11

In [30]:
# Ablation: three-head model on mcg + atac + genebody (hic left out), one model per fold.
# The model has three branches, so the experiment is labelled "three-head".
exp_name = f'three-head-{time.strftime("%Y%m%d-%H%M%S")}'

# Create the output folder before training starts; otherwise a missing directory
# would only surface when the first finished fold tries to write its results.
results_dir = Path("../3head_results/no_hic")
results_dir.mkdir(parents=True, exist_ok=True)

kf = KFold(n_splits=5, shuffle=True, random_state=25)
for i, (train_index, test_index) in enumerate(kf.split(X_balanced['mcg'])):
    # The same fold indices slice every modality and the labels, keeping samples aligned.
    # `j` is used inside the comprehensions so the fold index `i` is not shadowed.
    X_train_mcg, X_test_mcg = [X_balanced['mcg'][j] for j in train_index], [X_balanced['mcg'][j] for j in test_index]
    X_train_atac, X_test_atac = [X_balanced['atac'][j] for j in train_index], [X_balanced['atac'][j] for j in test_index]
    X_train_genebody, X_test_genebody = [X_balanced['genebody'][j] for j in train_index], [X_balanced['genebody'][j] for j in test_index]

    y_train, y_test = [y_balanced[j] for j in train_index], [y_balanced[j] for j in test_index]

    # Normalise each modality per fold (train and test parts are passed together).
    X_train_mcg_normalized, X_test_mcg_normalized = normalize_features(X_train_mcg, X_test_mcg)
    X_train_atac_normalized, X_test_atac_normalized = normalize_features(X_train_atac, X_test_atac)
    X_train_genebody_normalized, X_test_genebody_normalized = normalize_features(X_train_genebody, X_test_genebody)

    elem_stats, epoch_stats, model = train_3head_model(X_train_mcg_normalized, X_train_atac_normalized, X_train_genebody_normalized, y_train,
                                                          X_test_mcg_normalized, X_test_atac_normalized, X_test_genebody_normalized, y_test, exp_name=exp_name, fold_idx=i)

    # Per-sample predictions and per-epoch curves for this fold.
    df_elem = pd.DataFrame({'preds' : elem_stats[0], 'labels' : elem_stats[1]})
    df_epoch = pd.DataFrame({'train_acc' : epoch_stats[0], 'train_loss' : epoch_stats[1],
                             'test_acc' : epoch_stats[2], 'test_loss' : epoch_stats[3]})
    df_elem.to_csv(results_dir / f"res_bal_elem_{i}.csv")
    df_epoch.to_csv(results_dir / f"res_bal_epoch_{i}.csv")

    torch.save(model, results_dir / f"model_{i}.pt")

    # Release the fold's model and training data before the next fold starts.
    del model
    del X_train_mcg, X_train_mcg_normalized
    del X_train_atac, X_train_atac_normalized
    del X_train_genebody, X_train_genebody_normalized


Epoch [1/50], Train Loss: 1.1335, Train Accuracy: 0.4454, Test Loss: 0.9158, Test Accuracy: 0.5224
Epoch [2/50], Train Loss: 0.9212, Train Accuracy: 0.5651, Test Loss: 0.8443, Test Accuracy: 0.5761
Epoch [3/50], Train Loss: 0.8582, Train Accuracy: 0.5981, Test Loss: 0.8771, Test Accuracy: 0.5851
Epoch [4/50], Train Loss: 0.8262, Train Accuracy: 0.6153, Test Loss: 0.8560, Test Accuracy: 0.5791
Epoch [5/50], Train Loss: 0.8605, Train Accuracy: 0.5808, Test Loss: 0.7910, Test Accuracy: 0.5851
Epoch [6/50], Train Loss: 0.8084, Train Accuracy: 0.6160, Test Loss: 0.7218, Test Accuracy: 0.6567
Epoch [7/50], Train Loss: 0.7853, Train Accuracy: 0.6347, Test Loss: 0.7857, Test Accuracy: 0.6418
Epoch [8/50], Train Loss: 0.7799, Train Accuracy: 0.6280, Test Loss: 0.7412, Test Accuracy: 0.6567
Epoch [9/50], Train Loss: 0.7568, Train Accuracy: 0.6527, Test Loss: 0.7070, Test Accuracy: 0.6657
Epoch [10/50], Train Loss: 0.7645, Train Accuracy: 0.6602, Test Loss: 0.7508, Test Accuracy: 0.6478
Epoch [11

In [31]:
# Ablation: three-head model on mcg + hic + genebody (atac left out), one model per fold.
# The model has three branches, so the experiment is labelled "three-head".
exp_name = f'three-head-{time.strftime("%Y%m%d-%H%M%S")}'

# Create the output folder before training starts; otherwise a missing directory
# would only surface when the first finished fold tries to write its results.
results_dir = Path("../3head_results/no_atac")
results_dir.mkdir(parents=True, exist_ok=True)

kf = KFold(n_splits=5, shuffle=True, random_state=25)
for i, (train_index, test_index) in enumerate(kf.split(X_balanced['mcg'])):
    # The same fold indices slice every modality and the labels, keeping samples aligned.
    # `j` is used inside the comprehensions so the fold index `i` is not shadowed.
    X_train_mcg, X_test_mcg = [X_balanced['mcg'][j] for j in train_index], [X_balanced['mcg'][j] for j in test_index]
    X_train_hic, X_test_hic = [X_balanced['hic'][j] for j in train_index], [X_balanced['hic'][j] for j in test_index]
    X_train_genebody, X_test_genebody = [X_balanced['genebody'][j] for j in train_index], [X_balanced['genebody'][j] for j in test_index]

    y_train, y_test = [y_balanced[j] for j in train_index], [y_balanced[j] for j in test_index]

    # Normalise each modality per fold (train and test parts are passed together).
    X_train_mcg_normalized, X_test_mcg_normalized = normalize_features(X_train_mcg, X_test_mcg)
    X_train_hic_normalized, X_test_hic_normalized = normalize_features(X_train_hic, X_test_hic)
    X_train_genebody_normalized, X_test_genebody_normalized = normalize_features(X_train_genebody, X_test_genebody)

    elem_stats, epoch_stats, model = train_3head_model(X_train_mcg_normalized, X_train_hic_normalized, X_train_genebody_normalized, y_train,
                                                          X_test_mcg_normalized, X_test_hic_normalized, X_test_genebody_normalized, y_test, exp_name=exp_name, fold_idx=i)

    # Per-sample predictions and per-epoch curves for this fold.
    df_elem = pd.DataFrame({'preds' : elem_stats[0], 'labels' : elem_stats[1]})
    df_epoch = pd.DataFrame({'train_acc' : epoch_stats[0], 'train_loss' : epoch_stats[1],
                             'test_acc' : epoch_stats[2], 'test_loss' : epoch_stats[3]})
    df_elem.to_csv(results_dir / f"res_bal_elem_{i}.csv")
    df_epoch.to_csv(results_dir / f"res_bal_epoch_{i}.csv")

    torch.save(model, results_dir / f"model_{i}.pt")

    # Release the fold's model and training data before the next fold starts.
    del model
    del X_train_mcg, X_train_mcg_normalized
    del X_train_hic, X_train_hic_normalized
    del X_train_genebody, X_train_genebody_normalized


Epoch [1/50], Train Loss: 1.1677, Train Accuracy: 0.3930, Test Loss: 1.0101, Test Accuracy: 0.5045
Epoch [2/50], Train Loss: 0.9932, Train Accuracy: 0.5000, Test Loss: 0.9518, Test Accuracy: 0.5284
Epoch [3/50], Train Loss: 0.9316, Train Accuracy: 0.5561, Test Loss: 0.8054, Test Accuracy: 0.5731
Epoch [4/50], Train Loss: 0.8823, Train Accuracy: 0.5831, Test Loss: 0.8013, Test Accuracy: 0.6000
Epoch [5/50], Train Loss: 0.9083, Train Accuracy: 0.5666, Test Loss: 0.8515, Test Accuracy: 0.5582
Epoch [6/50], Train Loss: 0.8673, Train Accuracy: 0.5868, Test Loss: 0.8990, Test Accuracy: 0.5343
Epoch [7/50], Train Loss: 0.8497, Train Accuracy: 0.6085, Test Loss: 0.8179, Test Accuracy: 0.6000
Epoch [8/50], Train Loss: 0.8568, Train Accuracy: 0.5891, Test Loss: 0.7975, Test Accuracy: 0.6119
Epoch [9/50], Train Loss: 0.8304, Train Accuracy: 0.6145, Test Loss: 0.8120, Test Accuracy: 0.5940
Epoch [10/50], Train Loss: 0.8430, Train Accuracy: 0.6175, Test Loss: 0.7811, Test Accuracy: 0.6149
Epoch [11

In [33]:
# Ablation: three-head model on atac + hic + genebody (mcg left out), one model per fold.
# The model has three branches, so the experiment is labelled "three-head".
exp_name = f'three-head-{time.strftime("%Y%m%d-%H%M%S")}'

# Create the output folder before training starts; otherwise a missing directory
# would only surface when the first finished fold tries to write its results.
results_dir = Path("../3head_results/no_mcg")
results_dir.mkdir(parents=True, exist_ok=True)

kf = KFold(n_splits=5, shuffle=True, random_state=25)
# The split is still taken over X_balanced['mcg'] (as in the other ablations) even
# though mcg is not a model input here; only its length matters to KFold.
for i, (train_index, test_index) in enumerate(kf.split(X_balanced['mcg'])):
    # The same fold indices slice every modality and the labels, keeping samples aligned.
    # `j` is used inside the comprehensions so the fold index `i` is not shadowed.
    X_train_atac, X_test_atac = [X_balanced['atac'][j] for j in train_index], [X_balanced['atac'][j] for j in test_index]
    X_train_hic, X_test_hic = [X_balanced['hic'][j] for j in train_index], [X_balanced['hic'][j] for j in test_index]
    X_train_genebody, X_test_genebody = [X_balanced['genebody'][j] for j in train_index], [X_balanced['genebody'][j] for j in test_index]

    y_train, y_test = [y_balanced[j] for j in train_index], [y_balanced[j] for j in test_index]

    # Normalise each modality per fold (train and test parts are passed together).
    X_train_atac_normalized, X_test_atac_normalized = normalize_features(X_train_atac, X_test_atac)
    X_train_hic_normalized, X_test_hic_normalized = normalize_features(X_train_hic, X_test_hic)
    X_train_genebody_normalized, X_test_genebody_normalized = normalize_features(X_train_genebody, X_test_genebody)

    elem_stats, epoch_stats, model = train_3head_model(X_train_atac_normalized, X_train_hic_normalized, X_train_genebody_normalized, y_train,
                                                          X_test_atac_normalized, X_test_hic_normalized, X_test_genebody_normalized, y_test, exp_name=exp_name, fold_idx=i)

    # Per-sample predictions and per-epoch curves for this fold.
    df_elem = pd.DataFrame({'preds' : elem_stats[0], 'labels' : elem_stats[1]})
    df_epoch = pd.DataFrame({'train_acc' : epoch_stats[0], 'train_loss' : epoch_stats[1],
                             'test_acc' : epoch_stats[2], 'test_loss' : epoch_stats[3]})
    df_elem.to_csv(results_dir / f"res_bal_elem_{i}.csv")
    df_epoch.to_csv(results_dir / f"res_bal_epoch_{i}.csv")

    torch.save(model, results_dir / f"model_{i}.pt")

    # Release the fold's model and training data before the next fold starts.
    del model
    del X_train_atac, X_train_atac_normalized
    del X_train_hic, X_train_hic_normalized
    del X_train_genebody, X_train_genebody_normalized


Epoch [1/50], Train Loss: 1.0667, Train Accuracy: 0.4888, Test Loss: 0.8518, Test Accuracy: 0.5582
Epoch [2/50], Train Loss: 0.8723, Train Accuracy: 0.5711, Test Loss: 0.8087, Test Accuracy: 0.6060
Epoch [3/50], Train Loss: 0.8499, Train Accuracy: 0.5928, Test Loss: 0.8824, Test Accuracy: 0.5970
Epoch [4/50], Train Loss: 0.8593, Train Accuracy: 0.5913, Test Loss: 0.9975, Test Accuracy: 0.5403
Epoch [5/50], Train Loss: 0.8411, Train Accuracy: 0.5936, Test Loss: 0.7638, Test Accuracy: 0.6179
Epoch [6/50], Train Loss: 0.8020, Train Accuracy: 0.6243, Test Loss: 0.8321, Test Accuracy: 0.6149
Epoch [7/50], Train Loss: 0.8098, Train Accuracy: 0.6100, Test Loss: 0.8056, Test Accuracy: 0.6090
Epoch [8/50], Train Loss: 0.7888, Train Accuracy: 0.6317, Test Loss: 0.7667, Test Accuracy: 0.6060
Epoch [9/50], Train Loss: 0.7836, Train Accuracy: 0.6422, Test Loss: 0.7457, Test Accuracy: 0.6299
Epoch [10/50], Train Loss: 0.7571, Train Accuracy: 0.6362, Test Loss: 0.7447, Test Accuracy: 0.6299
Epoch [11

## Clear GPU

In [34]:
# del model
# del optimizer

# `gc` is imported with the other modules in the imports cell at the top.
# Collect unreachable Python objects first so any tensors they hold are freed,
# then ask PyTorch to release its cached CUDA memory.
gc.collect()

torch.cuda.empty_cache()

In [35]:
# Report the GPU memory PyTorch currently holds, converted from bytes.
bytes_per_mb = 1024**2
allocated_mb = torch.cuda.memory_allocated() / bytes_per_mb
reserved_mb = torch.cuda.memory_reserved() / bytes_per_mb
print(f"Allocated Memory: {allocated_mb:.2f} MB")
print(f"Cached Memory: {reserved_mb:.2f} MB")

Allocated Memory: 0.00 MB
Cached Memory: 0.00 MB
