# SWEEM Model Implementation

This file is used to illustrate the preprocessing, training, and evaluation 
stages of our model. Comments and more information will be provided per section.

## Preprocessing

Here we load in the data and establish our train-test split. We also set up dataloaders
for us to be able to properly use the data within our training loop.

In [1]:
%load_ext autoreload
%autoreload 2

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Running on", device)

Running on cuda:0


In [14]:
data = pd.read_csv('./Data/OmicsData/data.csv')
# data = pd.concat([data.iloc[:, :5541], data.iloc[:, -2:]], axis=1)

# Separate to make sure that there's an even distribution of 1s and 0s in train and test
data_ones = data[data.iloc[:, -1] == 1]
data_zeros = data[data.iloc[:, -1] == 0]

# Split the data into train and validation sets.
# Train test split is 80 20
train_data_ones, test_data_ones, train_labels_ones, test_labels_ones = train_test_split(
    data_ones.iloc[:, 1:-2], data_ones.iloc[:, -2:], test_size=0.2, random_state=42)
train_data_zeros, test_data_zeros, train_labels_zeros, test_labels_zeros = train_test_split(
    data_zeros.iloc[:, 1:-2], data_zeros.iloc[:, -2:], test_size=0.2, random_state=42)

# Concatenate in the end to make train and test
train_data = pd.concat((train_data_ones, train_data_zeros))
train_labels = pd.concat((train_labels_ones, train_labels_zeros))
test_data = pd.concat((test_data_ones, test_data_zeros)) 
test_labels = pd.concat((test_labels_ones, test_labels_zeros))

# 352 0's
# 123 1's

# Number of genes in rna: 5540
# Number of genes in scna: 5507
# Number of genes in methy: 4846
# Total number of genes: 15893
# Total number of samples in the final dataset: 475

#rna
print(train_data.columns[0])
print(train_data.columns[5539])
#scna
print(train_data.columns[5540])
print(train_data.columns[11046])
#methy
print(train_data.columns[11047])
print(train_data.columns[15892])

#labels
print(train_labels.columns)

# Create Tensor datasets
train_dataset = TensorDataset(torch.tensor(train_data.values, dtype=torch.float32), torch.tensor(train_labels.values, dtype=torch.float32))
#val_dataset   = TensorDataset(torch.tensor(validation_data.values, dtype=torch.float32), torch.tensor(validation_labels.values, dtype=torch.float32))
test_dataset  = TensorDataset(torch.tensor(test_data.values, dtype=torch.float32), torch.tensor(test_labels.values, dtype=torch.float32))

# Create DataLoader objects
batch_size = 16
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
#val_dataloader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False)
test_dataloader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False)

print("batch size: ", batch_size)

ZYX_rna
A2M_rna
ZYX_scna
A2M_scna
ZYX_methy
A2M_methy
Index(['OS_DAYS', 'OS_EVENT'], dtype='object')
batch size:  16


## Train Loop

In [17]:
### Training Loop
from model import SWEEM
from loss import temp_loss
num_epochs = 5000
epoch_train_losses = []
epoch_val_losses   = []

model = SWEEM(rna_dim = 5540, 
              scna_dim = 5507, 
              methy_dim = 4846, 
              use_rna = True,
              use_scna = True,
              use_methy = False,
              hidden_dim = 128, 
              self_att = True, 
              cross_att = False,
              device=device)
model.to(device)

# criterion = temp_loss
# binary cross entropy loss
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(num_epochs):
    epoch_train_loss = 0
    epoch_val_loss   = 0
    
    ## Training
    #print(f"Epoch {epoch + 1} training:")
    progress_bar = tqdm(range(len(train_dataloader)))

    model.train()
    for i, (batchX, batchY) in enumerate(train_dataloader):        
        batchX = batchX.to(device)
        rna = batchX[:, :5540]
        scna = batchX[:, 5540:11047]
        methy = batchX[:, 11047:]
        time = batchY[:,0].reshape(-1, 1).to(device)
        event = batchY[:,1].reshape(-1, 1).to(device)

        outputs = model(event, rna=rna, scna=scna, methy=methy)
        
        # loss = criterion(outputs, time, event)
        loss = criterion(outputs, event)
        # print(f"batch {i+1} loss: ", loss)
        loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()
        
        epoch_train_loss += loss.item()
        progress_bar.update(1)
        
    ## Validation
    #print(f"Epoch {epoch + 1} validation:")
    progress_bar = tqdm(range(len(test_dataloader)))
    
    model.eval()
    with torch.no_grad():
        for i, (batchX, batchY) in enumerate(test_dataloader):
            batchX = batchX.to(device)
            rna = batchX[:, :5540]
            scna = batchX[:, 5540:11047]
            methy = batchX[:, 11047:]
            time = batchY[:,0].reshape(-1, 1).to(device)
            event = batchY[:,1].reshape(-1, 1).to(device)
            outputs = model(event, rna=rna, scna=scna, methy=methy)


            #outputs = model(rna, scna, methy, event)
            # loss = criterion(outputs, time, event)
            loss = criterion(outputs, event)
            epoch_val_loss += loss.item()
            progress_bar.update(1)

    # Save and print losses
    epoch_train_loss /= len(train_dataloader)
    epoch_val_loss /= len(test_dataloader)
    epoch_train_losses.append(epoch_train_loss)
    epoch_val_losses.append(epoch_val_loss)
    if epoch % 20 == 0:
        print(f"Epoch {epoch + 1} training loss: {epoch_train_loss}")
        print(f"Epoch {epoch + 1} validation loss: {epoch_val_loss}")

 88%|████████▊ | 21/24 [00:14<00:02,  1.48it/s]
100%|██████████| 24/24 [00:01<00:00, 12.89it/s]


Epoch 1 training loss: 24.483292003472645
Epoch 1 validation loss: 26.041666666666668


100%|██████████| 6/6 [00:00<00:00, 161.82it/s]
100%|██████████| 24/24 [00:01<00:00, 13.38it/s]
100%|██████████| 6/6 [00:00<00:00, 160.46it/s]
100%|██████████| 24/24 [00:01<00:00, 13.18it/s]
100%|██████████| 6/6 [00:00<00:00, 146.04it/s]
100%|██████████| 24/24 [00:01<00:00, 13.21it/s]
100%|██████████| 6/6 [00:00<00:00, 164.39it/s]
100%|██████████| 24/24 [00:01<00:00, 13.31it/s]
100%|██████████| 6/6 [00:00<00:00, 132.18it/s]
100%|██████████| 24/24 [00:01<00:00, 12.54it/s]
100%|██████████| 6/6 [00:00<00:00, 140.84it/s]
100%|██████████| 24/24 [00:01<00:00, 13.03it/s]
100%|██████████| 6/6 [00:00<00:00, 157.16it/s]
100%|██████████| 24/24 [00:01<00:00, 13.36it/s]
100%|██████████| 6/6 [00:00<00:00, 161.79it/s]
100%|██████████| 24/24 [00:01<00:00, 13.33it/s]
100%|██████████| 6/6 [00:00<00:00, 142.18it/s]
100%|██████████| 24/24 [00:01<00:00, 12.09it/s]
100%|██████████| 6/6 [00:00<00:00, 153.19it/s]
100%|██████████| 24/24 [00:01<00:00, 12.77it/s]
100%|██████████| 6/6 [00:00<00:00, 165.21it/s]
100

Epoch 21 training loss: 25.87594699859619
Epoch 21 validation loss: 26.041666666666668


100%|██████████| 6/6 [00:00<00:00, 144.85it/s]
100%|██████████| 24/24 [00:01<00:00, 13.20it/s]
100%|██████████| 6/6 [00:00<00:00, 157.03it/s]
100%|██████████| 24/24 [00:01<00:00, 13.34it/s]
100%|██████████| 6/6 [00:00<00:00, 153.51it/s]
100%|██████████| 24/24 [00:01<00:00, 13.35it/s]
100%|██████████| 6/6 [00:00<00:00, 153.25it/s]
100%|██████████| 24/24 [00:01<00:00, 13.21it/s]
100%|██████████| 6/6 [00:00<00:00, 161.32it/s]
100%|██████████| 24/24 [00:01<00:00, 13.33it/s]
100%|██████████| 6/6 [00:00<00:00, 153.08it/s]
100%|██████████| 24/24 [00:01<00:00, 13.37it/s]
100%|██████████| 6/6 [00:00<00:00, 157.56it/s]
100%|██████████| 24/24 [00:01<00:00, 13.08it/s]
100%|██████████| 6/6 [00:00<00:00, 156.73it/s]
100%|██████████| 24/24 [00:01<00:00, 13.31it/s]
100%|██████████| 6/6 [00:00<00:00, 157.57it/s]
100%|██████████| 24/24 [00:01<00:00, 13.32it/s]
100%|██████████| 6/6 [00:00<00:00, 152.82it/s]
100%|██████████| 24/24 [00:01<00:00, 13.05it/s]
100%|██████████| 6/6 [00:00<00:00, 148.80it/s]
100

Epoch 41 training loss: 25.757575750350952
Epoch 41 validation loss: 26.041666666666668


100%|██████████| 6/6 [00:00<00:00, 153.45it/s]
100%|██████████| 24/24 [00:01<00:00, 13.02it/s]
100%|██████████| 6/6 [00:00<00:00, 151.58it/s]
100%|██████████| 24/24 [00:01<00:00, 12.59it/s]
100%|██████████| 6/6 [00:00<00:00, 121.56it/s]
100%|██████████| 24/24 [00:01<00:00, 12.09it/s]
100%|██████████| 6/6 [00:00<00:00, 153.40it/s]
100%|██████████| 24/24 [00:01<00:00, 12.37it/s]
100%|██████████| 6/6 [00:00<00:00, 148.19it/s]
100%|██████████| 24/24 [00:01<00:00, 12.19it/s]
100%|██████████| 6/6 [00:00<00:00, 123.74it/s]
100%|██████████| 24/24 [00:01<00:00, 12.48it/s]
100%|██████████| 6/6 [00:00<00:00, 134.35it/s]
100%|██████████| 24/24 [00:01<00:00, 12.17it/s]
100%|██████████| 6/6 [00:00<00:00, 129.72it/s]
100%|██████████| 24/24 [00:02<00:00, 11.92it/s]
100%|██████████| 6/6 [00:00<00:00, 118.93it/s]
100%|██████████| 24/24 [00:02<00:00, 11.93it/s]
100%|██████████| 6/6 [00:00<00:00, 119.24it/s]
100%|██████████| 24/24 [00:02<00:00, 11.89it/s]
100%|██████████| 6/6 [00:00<00:00, 124.30it/s]
100

Epoch 61 training loss: 25.87594699859619
Epoch 61 validation loss: 26.041666666666668


100%|██████████| 6/6 [00:00<00:00, 150.00it/s]
100%|██████████| 24/24 [00:01<00:00, 13.39it/s]
100%|██████████| 6/6 [00:00<00:00, 144.88it/s]
100%|██████████| 24/24 [00:01<00:00, 13.40it/s]
100%|██████████| 6/6 [00:00<00:00, 148.76it/s]
100%|██████████| 24/24 [00:01<00:00, 13.32it/s]
100%|██████████| 6/6 [00:00<00:00, 141.56it/s]
100%|██████████| 24/24 [00:01<00:00, 13.25it/s]
100%|██████████| 6/6 [00:00<00:00, 138.67it/s]
100%|██████████| 24/24 [00:01<00:00, 13.24it/s]
100%|██████████| 6/6 [00:00<00:00, 157.20it/s]
100%|██████████| 24/24 [00:01<00:00, 13.35it/s]
100%|██████████| 6/6 [00:00<00:00, 149.97it/s]
100%|██████████| 24/24 [00:01<00:00, 13.32it/s]
100%|██████████| 6/6 [00:00<00:00, 162.16it/s]
100%|██████████| 24/24 [00:01<00:00, 13.35it/s]
100%|██████████| 6/6 [00:00<00:00, 160.59it/s]
100%|██████████| 24/24 [00:01<00:00, 13.38it/s]
100%|██████████| 6/6 [00:00<00:00, 160.97it/s]
100%|██████████| 24/24 [00:01<00:00, 13.33it/s]
100%|██████████| 6/6 [00:00<00:00, 152.91it/s]
100

Epoch 81 training loss: 25.757575750350952
Epoch 81 validation loss: 26.041666666666668


100%|██████████| 6/6 [00:00<00:00, 156.75it/s]
100%|██████████| 24/24 [00:01<00:00, 13.22it/s]
100%|██████████| 6/6 [00:00<00:00, 152.97it/s]
100%|██████████| 24/24 [00:01<00:00, 13.12it/s]
100%|██████████| 6/6 [00:00<00:00, 141.64it/s]
100%|██████████| 24/24 [00:01<00:00, 13.25it/s]
100%|██████████| 6/6 [00:00<00:00, 149.64it/s]
100%|██████████| 24/24 [00:01<00:00, 13.11it/s]
100%|██████████| 6/6 [00:00<00:00, 133.34it/s]
100%|██████████| 24/24 [00:01<00:00, 13.23it/s]
100%|██████████| 6/6 [00:00<00:00, 130.43it/s]
100%|██████████| 24/24 [00:01<00:00, 13.17it/s]
100%|██████████| 6/6 [00:00<00:00, 157.58it/s]
100%|██████████| 24/24 [00:01<00:00, 13.32it/s]
100%|██████████| 6/6 [00:00<00:00, 148.77it/s]
100%|██████████| 24/24 [00:01<00:00, 13.31it/s]
100%|██████████| 6/6 [00:00<00:00, 137.69it/s]
100%|██████████| 24/24 [00:01<00:00, 13.17it/s]
100%|██████████| 6/6 [00:00<00:00, 135.23it/s]
100%|██████████| 24/24 [00:01<00:00, 13.08it/s]
100%|██████████| 6/6 [00:00<00:00, 136.35it/s]
100

Epoch 101 training loss: 25.520833333333332
Epoch 101 validation loss: 26.041666666666668


100%|██████████| 6/6 [00:00<00:00, 153.10it/s]
100%|██████████| 24/24 [00:01<00:00, 13.26it/s]
100%|██████████| 6/6 [00:00<00:00, 156.72it/s]
100%|██████████| 24/24 [00:01<00:00, 13.39it/s]
100%|██████████| 6/6 [00:00<00:00, 152.66it/s]
100%|██████████| 24/24 [00:01<00:00, 13.37it/s]
100%|██████████| 6/6 [00:00<00:00, 149.22it/s]
100%|██████████| 24/24 [00:01<00:00, 13.26it/s]
100%|██████████| 6/6 [00:00<00:00, 141.59it/s]
100%|██████████| 24/24 [00:01<00:00, 13.35it/s]
100%|██████████| 6/6 [00:00<00:00, 156.59it/s]
100%|██████████| 24/24 [00:01<00:00, 13.29it/s]
100%|██████████| 6/6 [00:00<00:00, 152.09it/s]
100%|██████████| 24/24 [00:01<00:00, 13.28it/s]
100%|██████████| 6/6 [00:00<00:00, 142.16it/s]
100%|██████████| 24/24 [00:01<00:00, 13.19it/s]
100%|██████████| 6/6 [00:00<00:00, 139.53it/s]
100%|██████████| 24/24 [00:01<00:00, 13.20it/s]
100%|██████████| 6/6 [00:00<00:00, 148.82it/s]
100%|██████████| 24/24 [00:01<00:00, 13.23it/s]
100%|██████████| 6/6 [00:00<00:00, 132.84it/s]
100

Epoch 121 training loss: 25.87594699859619
Epoch 121 validation loss: 26.041666666666668


100%|██████████| 6/6 [00:00<00:00, 152.53it/s]
100%|██████████| 24/24 [00:01<00:00, 13.28it/s]
100%|██████████| 6/6 [00:00<00:00, 157.93it/s]
100%|██████████| 24/24 [00:01<00:00, 13.37it/s]
100%|██████████| 6/6 [00:00<00:00, 157.17it/s]
100%|██████████| 24/24 [00:01<00:00, 13.24it/s]
100%|██████████| 6/6 [00:00<00:00, 152.87it/s]
100%|██████████| 24/24 [00:01<00:00, 13.31it/s]
100%|██████████| 6/6 [00:00<00:00, 156.26it/s]
100%|██████████| 24/24 [00:01<00:00, 13.31it/s]
100%|██████████| 6/6 [00:00<00:00, 156.75it/s]
100%|██████████| 24/24 [00:01<00:00, 13.23it/s]
100%|██████████| 6/6 [00:00<00:00, 143.24it/s]
100%|██████████| 24/24 [00:01<00:00, 13.28it/s]
100%|██████████| 6/6 [00:00<00:00, 157.14it/s]
100%|██████████| 24/24 [00:01<00:00, 13.32it/s]
100%|██████████| 6/6 [00:00<00:00, 144.86it/s]
100%|██████████| 24/24 [00:01<00:00, 13.30it/s]
100%|██████████| 6/6 [00:00<00:00, 134.05it/s]
100%|██████████| 24/24 [00:01<00:00, 13.26it/s]
100%|██████████| 6/6 [00:00<00:00, 141.10it/s]
100

Epoch 141 training loss: 25.757575750350952
Epoch 141 validation loss: 26.041666666666668


100%|██████████| 6/6 [00:00<00:00, 156.72it/s]
100%|██████████| 24/24 [00:01<00:00, 13.33it/s]
100%|██████████| 6/6 [00:00<00:00, 141.73it/s]
100%|██████████| 24/24 [00:01<00:00, 13.17it/s]
100%|██████████| 6/6 [00:00<00:00, 121.12it/s]
100%|██████████| 24/24 [00:01<00:00, 13.28it/s]
100%|██████████| 6/6 [00:00<00:00, 146.01it/s]
100%|██████████| 24/24 [00:01<00:00, 13.28it/s]
100%|██████████| 6/6 [00:00<00:00, 149.65it/s]
100%|██████████| 24/24 [00:01<00:00, 13.24it/s]
100%|██████████| 6/6 [00:00<00:00, 135.48it/s]
100%|██████████| 24/24 [00:01<00:00, 13.34it/s]
100%|██████████| 6/6 [00:00<00:00, 156.42it/s]
100%|██████████| 24/24 [00:01<00:00, 13.31it/s]
100%|██████████| 6/6 [00:00<00:00, 157.19it/s]
100%|██████████| 24/24 [00:01<00:00, 13.29it/s]
100%|██████████| 6/6 [00:00<00:00, 145.38it/s]
100%|██████████| 24/24 [00:01<00:00, 13.30it/s]
100%|██████████| 6/6 [00:00<00:00, 152.57it/s]
100%|██████████| 24/24 [00:01<00:00, 13.33it/s]
100%|██████████| 6/6 [00:00<00:00, 153.00it/s]
100

Epoch 161 training loss: 25.87594699859619
Epoch 161 validation loss: 26.041666666666668


100%|██████████| 6/6 [00:00<00:00, 135.36it/s]
100%|██████████| 24/24 [00:01<00:00, 13.15it/s]
100%|██████████| 6/6 [00:00<00:00, 156.73it/s]
100%|██████████| 24/24 [00:01<00:00, 13.22it/s]
100%|██████████| 6/6 [00:00<00:00, 145.11it/s]
100%|██████████| 24/24 [00:01<00:00, 13.27it/s]
100%|██████████| 6/6 [00:00<00:00, 149.00it/s]
100%|██████████| 24/24 [00:01<00:00, 13.22it/s]
100%|██████████| 6/6 [00:00<00:00, 148.68it/s]
100%|██████████| 24/24 [00:01<00:00, 13.26it/s]
100%|██████████| 6/6 [00:00<00:00, 153.48it/s]
100%|██████████| 24/24 [00:01<00:00, 13.27it/s]
100%|██████████| 6/6 [00:00<00:00, 157.09it/s]
100%|██████████| 24/24 [00:01<00:00, 13.22it/s]
100%|██████████| 6/6 [00:00<00:00, 145.17it/s]
100%|██████████| 24/24 [00:01<00:00, 13.28it/s]
100%|██████████| 6/6 [00:00<00:00, 148.70it/s]
100%|██████████| 24/24 [00:01<00:00, 13.31it/s]
100%|██████████| 6/6 [00:00<00:00, 112.49it/s]
100%|██████████| 24/24 [00:01<00:00, 13.26it/s]
100%|██████████| 6/6 [00:00<00:00, 140.99it/s]
100

Epoch 181 training loss: 25.87594699859619
Epoch 181 validation loss: 26.041666666666668


100%|██████████| 6/6 [00:00<00:00, 135.57it/s]
100%|██████████| 24/24 [00:01<00:00, 13.11it/s]
100%|██████████| 6/6 [00:00<00:00, 152.66it/s]
100%|██████████| 24/24 [00:01<00:00, 13.28it/s]
100%|██████████| 6/6 [00:00<00:00, 139.12it/s]
100%|██████████| 24/24 [00:01<00:00, 13.20it/s]
100%|██████████| 6/6 [00:00<00:00, 129.94it/s]
100%|██████████| 24/24 [00:01<00:00, 13.24it/s]
100%|██████████| 6/6 [00:00<00:00, 135.59it/s]
100%|██████████| 24/24 [00:01<00:00, 13.29it/s]
100%|██████████| 6/6 [00:00<00:00, 152.72it/s]
100%|██████████| 24/24 [00:01<00:00, 13.20it/s]
100%|██████████| 6/6 [00:00<00:00, 149.31it/s]
100%|██████████| 24/24 [00:01<00:00, 13.30it/s]
100%|██████████| 6/6 [00:00<00:00, 157.15it/s]
100%|██████████| 24/24 [00:01<00:00, 13.25it/s]
100%|██████████| 6/6 [00:00<00:00, 149.13it/s]
100%|██████████| 24/24 [00:01<00:00, 13.20it/s]
100%|██████████| 6/6 [00:00<00:00, 141.63it/s]
100%|██████████| 24/24 [00:01<00:00, 13.23it/s]
100%|██████████| 6/6 [00:00<00:00, 149.22it/s]
100

Epoch 201 training loss: 25.87594699859619
Epoch 201 validation loss: 26.041666666666668


100%|██████████| 6/6 [00:00<00:00, 153.14it/s]
100%|██████████| 24/24 [00:01<00:00, 13.27it/s]
100%|██████████| 6/6 [00:00<00:00, 148.60it/s]
100%|██████████| 24/24 [00:01<00:00, 13.23it/s]
100%|██████████| 6/6 [00:00<00:00, 144.92it/s]
100%|██████████| 24/24 [00:01<00:00, 13.26it/s]
100%|██████████| 6/6 [00:00<00:00, 152.66it/s]
100%|██████████| 24/24 [00:01<00:00, 13.28it/s]
100%|██████████| 6/6 [00:00<00:00, 157.11it/s]
100%|██████████| 24/24 [00:01<00:00, 13.23it/s]
100%|██████████| 6/6 [00:00<00:00, 148.84it/s]
100%|██████████| 24/24 [00:01<00:00, 13.28it/s]
100%|██████████| 6/6 [00:00<00:00, 157.09it/s]
100%|██████████| 24/24 [00:01<00:00, 13.27it/s]
100%|██████████| 6/6 [00:00<00:00, 148.65it/s]
100%|██████████| 24/24 [00:01<00:00, 13.17it/s]
100%|██████████| 6/6 [00:00<00:00, 152.38it/s]
100%|██████████| 24/24 [00:01<00:00, 13.28it/s]
100%|██████████| 6/6 [00:00<00:00, 142.75it/s]
100%|██████████| 24/24 [00:01<00:00, 13.29it/s]
100%|██████████| 6/6 [00:00<00:00, 156.57it/s]
100

Epoch 221 training loss: 25.757575750350952
Epoch 221 validation loss: 26.041666666666668


100%|██████████| 6/6 [00:00<00:00, 131.67it/s]
100%|██████████| 24/24 [00:01<00:00, 12.94it/s]
100%|██████████| 6/6 [00:00<00:00, 141.26it/s]
100%|██████████| 24/24 [00:01<00:00, 12.99it/s]
100%|██████████| 6/6 [00:00<00:00, 142.06it/s]
100%|██████████| 24/24 [00:01<00:00, 13.04it/s]
100%|██████████| 6/6 [00:00<00:00, 144.50it/s]
100%|██████████| 24/24 [00:01<00:00, 12.91it/s]
100%|██████████| 6/6 [00:00<00:00, 132.74it/s]
100%|██████████| 24/24 [00:01<00:00, 12.93it/s]
100%|██████████| 6/6 [00:00<00:00, 145.31it/s]
100%|██████████| 24/24 [00:01<00:00, 13.01it/s]
100%|██████████| 6/6 [00:00<00:00, 124.26it/s]
100%|██████████| 24/24 [00:01<00:00, 12.91it/s]
100%|██████████| 6/6 [00:00<00:00, 129.23it/s]
100%|██████████| 24/24 [00:01<00:00, 12.98it/s]
100%|██████████| 6/6 [00:00<00:00, 127.84it/s]
100%|██████████| 24/24 [00:01<00:00, 13.04it/s]
100%|██████████| 6/6 [00:00<00:00, 132.69it/s]
100%|██████████| 24/24 [00:01<00:00, 13.14it/s]
100%|██████████| 6/6 [00:00<00:00, 139.10it/s]
100

Epoch 241 training loss: 25.757575750350952
Epoch 241 validation loss: 26.041666666666668


100%|██████████| 6/6 [00:00<00:00, 112.28it/s]
100%|██████████| 24/24 [00:01<00:00, 13.17it/s]
100%|██████████| 6/6 [00:00<00:00, 151.12it/s]
100%|██████████| 24/24 [00:01<00:00, 13.15it/s]
100%|██████████| 6/6 [00:00<00:00, 149.00it/s]
100%|██████████| 24/24 [00:01<00:00, 13.17it/s]
100%|██████████| 6/6 [00:00<00:00, 153.04it/s]
100%|██████████| 24/24 [00:01<00:00, 13.15it/s]
100%|██████████| 6/6 [00:00<00:00, 145.70it/s]
100%|██████████| 24/24 [00:01<00:00, 13.15it/s]
100%|██████████| 6/6 [00:00<00:00, 141.95it/s]
100%|██████████| 24/24 [00:01<00:00, 13.15it/s]
100%|██████████| 6/6 [00:00<00:00, 114.46it/s]
100%|██████████| 24/24 [00:01<00:00, 12.94it/s]
100%|██████████| 6/6 [00:00<00:00, 94.62it/s]
100%|██████████| 24/24 [00:01<00:00, 12.74it/s]
100%|██████████| 6/6 [00:00<00:00, 126.11it/s]
100%|██████████| 24/24 [00:01<00:00, 13.11it/s]
100%|██████████| 6/6 [00:00<00:00, 138.83it/s]
 33%|███▎      | 8/24 [00:00<00:01, 12.80it/s]

In [None]:
### Sanity Check
model.eval()
with torch.no_grad():
    for (batchX, batchY) in test_dataloader:
        batchX = batchX.to(device)
        rna = batchX[:, :5540]
        scna = batchX[:, 5540:11047]
        methy = batchX[:, 11047:]
        time = batchY[:,0].reshape(-1, 1).to(device)
        event = batchY[:,1].reshape(-1, 1).to(device)
        outputs = model(rna, scna, methy, event)
        print(f"times: {time}")
        print(f"events: {event}")
        print(f"predictions: {outputs}")
        break

times: tensor([[ 706.],
        [ 772.],
        [ 184.],
        [ 455.],
        [ 407.],
        [ 544.],
        [1585.],
        [ 411.],
        [ 956.],
        [1354.],
        [ 835.],
        [2381.],
        [ 648.],
        [1257.],
        [ 748.],
        [6423.],
        [1401.],
        [ 442.],
        [ 993.],
        [2107.],
        [ 900.],
        [  96.],
        [ 491.],
        [5546.],
        [ 629.],
        [ 547.],
        [ 122.],
        [3725.],
        [  62.],
        [ 605.],
        [1220.],
        [1262.]])
events: tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
      