# SWEEM Model Implementation

This notebook illustrates the preprocessing, training, and evaluation stages of our model.
Each section below includes its own comments and additional details.

## Preprocessing

Here we load the data and split it into training, test, and validation sets (80/10/10).
We also wrap each split in a PyTorch `DataLoader` so the training loop can iterate over it.

In [156]:
%load_ext autoreload
%autoreload 2

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader, TensorDataset

data = pd.read_csv('./Data/OmicsData/data.csv')
data = pd.concat([data.iloc[:, :5541], data.iloc[:, -2:]], axis=1)

# Split the data into train and validation sets.
# Train test validaiton split is 80 10 10
train_data, test_data, train_labels, test_labels = train_test_split(
    data.iloc[:, 1:-2], data.iloc[:, -2:], test_size=0.2, random_state=42)

test_data, validation_data, test_labels, validation_labels = train_test_split(
    test_data, test_labels, test_size=0.5, random_state=42)

print(data.shape)
print(train_data.shape)
print(test_data.shape)
print(validation_data.shape)

# Number of genes in rna: 5540
# Number of genes in scna: 5507
# Number of genes in methy: 4846
# Total number of samples in the final dataset: 475

#rna
print(train_data.columns[0])
print(train_data.columns[5539])
# #scna
# print(train_data.columns[5540])
# print(train_data.columns[11046])
# #methy
# print(train_data.columns[11047])
# print(train_data.columns[15892])

#labels
print(train_labels.columns)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
(475, 5543)
(380, 5540)
(47, 5540)
(48, 5540)
ZYX_rna
A2M_rna
Index(['OS_DAYS', 'OS_EVENT'], dtype='object')


In [157]:
# Wrap each split as (features, labels) tensor pairs.
train_dataset = TensorDataset(torch.tensor(train_data.values), torch.tensor(train_labels.values))
val_dataset   = TensorDataset(torch.tensor(validation_data.values), torch.tensor(validation_labels.values))
test_dataset  = TensorDataset(torch.tensor(test_data.values), torch.tensor(test_labels.values))

# One batch holds the whole training split (380 samples in the recorded run).
# Deriving it from the dataset, rather than hard-coding 380, keeps training
# full-batch if the split sizes change. The smaller validation and test splits
# also fit in a single batch.
# NOTE(review): the training criterion is `neg_par_log_likelihood`; if that is a
# Cox-style partial likelihood, its risk sets only span samples in the same
# batch, which is presumably why full-batch training was chosen — confirm.
batch_size = len(train_dataset)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False)
test_dataloader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False)

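## Define Baseline Model

The self-attention model itself is imported from the `model` module in the training cell. This cell defines a small `LinearRegression` MLP baseline.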

In [159]:
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm

class LinearRegression(nn.Module):
    """Small MLP baseline that scores each sample from its features plus one extra column.

    Despite the name, this is not a plain linear model. The features pass through
    ``Linear(input_size, hidden_size) -> ReLU -> Linear(hidden_size, 1) -> ReLU``.
    The resulting score is concatenated with ``x2`` and mixed by a final ``Linear(2, 1)``.

    Args:
        input_size: number of feature columns in ``x``.
        hidden_size: width of the hidden layer (defaults to 256, the original width).
    """

    def __init__(self, input_size, hidden_size=256):
        super(LinearRegression, self).__init__()
        # Layers are created in the original order so seeded initialisation is unchanged.
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, 1)
        # Combines the single learned score with the single extra column x2.
        self.linearOut = nn.Linear(2, 1)
        self.relu = nn.ReLU()
        # Not used in forward(); kept so the module's attributes stay unchanged.
        self.identity = nn.Identity()

    def forward(self, x, x2):
        """Return one output per sample.

        Args:
            x: feature tensor whose last dimension is ``input_size``;
                assumes (batch, input_size) — TODO confirm.
            x2: extra input concatenated along dim 1 with the (batch, 1) score.
                It must have exactly one column so ``linearOut`` receives 2 features.
                NOTE(review): the training loop passes the event indicator
                (a label) here, which would leak the target into the prediction.

        Returns:
            Tensor of shape (batch, 1).
        """
        out = self.linear1(x)
        out = self.relu(out)
        out = self.linear2(out)
        out = self.relu(out)
        out = torch.cat((out, x2), 1)
        out = self.linearOut(out)
        # out = torch.sigmoid(out)
        return out

## Train Loop with Alt Dataloader

In [172]:
### Training Loop
from model import SelfAttentionModel
from loss import temp_loss, neg_par_log_likelihood

num_epochs = 50
epoch_train_losses = []
epoch_val_losses   = []

# Seed torch so weight initialisation and DataLoader shuffling are reproducible
# on a fresh kernel (42 matches the random_state used for the data split).
torch.manual_seed(42)

# model = LinearRegression(5540)
model = SelfAttentionModel(5540, 1024, 32)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Running on", device)
model.to(device)

criterion = neg_par_log_likelihood
optimizer = optim.Adam(model.parameters(), lr=0.01)


def _prepare_batch(batchX, batchY, device):
    """Move one DataLoader batch to `device` as float32 tensors.

    Returns ``(features, time, event)``. ``time`` is column 0 and ``event`` is
    column 1 of ``batchY`` (OS_DAYS and OS_EVENT in the preprocessing split),
    each reshaped to ``(batch, 1)``.
    """
    features = batchX.to(torch.float32).to(device)
    time = batchY[:, 0].to(torch.float32).to(device).reshape(-1, 1)
    event = batchY[:, 1].to(torch.float32).to(device).reshape(-1, 1)
    return features, time, event


for epoch in range(num_epochs):
    epoch_train_loss = 0
    epoch_val_loss   = 0
    print(f"Epoch {epoch + 1} training:")

    ## Training
    model.train()
    # Iterating through tqdm, instead of manually updating a bar that is never
    # closed, lets tqdm close the bar itself when the epoch finishes.
    for batchX, batchY in tqdm(train_dataloader):
        batchX, time, event = _prepare_batch(batchX, batchY, device)

        optimizer.zero_grad()
        # NOTE(review): `event` is a label column but is also passed to the model
        # as an input. If the model uses it, this leaks the target — confirm
        # against SelfAttentionModel.forward.
        outputs = model(batchX, event)
        loss = criterion(outputs, time, event)
        loss.backward()
        optimizer.step()

        epoch_train_loss += loss.item()

    ## Validation
    # NOTE(review): in the recorded run the validation loss is identical in every
    # epoch. That suggests the model's validation outputs never change — worth
    # investigating (e.g. dead activations or a collapsed output).
    model.eval()
    with torch.no_grad():
        for batchX, batchY in val_dataloader:
            batchX, time, event = _prepare_batch(batchX, batchY, device)
            outputs = model(batchX, event)
            loss = criterion(outputs, time, event)
            epoch_val_loss += loss.item()

    # Average over batches, then record and print the per-epoch losses.
    epoch_train_loss /= len(train_dataloader)
    epoch_val_loss /= len(val_dataloader)
    epoch_train_losses.append(epoch_train_loss)
    epoch_val_losses.append(epoch_val_loss)
    print(f"Epoch {epoch + 1} training loss: {epoch_train_loss}")
    print(f"Epoch {epoch + 1} validation loss: {epoch_val_loss}")

Running on cpu
Epoch 1 training:


100%|██████████| 1/1 [00:06<00:00,  6.67s/it]


Epoch 1 training loss: 4.920126438140869
Epoch 1 validation loss: 3.359114646911621
Epoch 2 training:


100%|██████████| 1/1 [00:00<00:00, 29.20it/s]


Epoch 2 training loss: 4.933938503265381
Epoch 2 validation loss: 3.3591151237487793
Epoch 3 training:


100%|██████████| 1/1 [00:00<00:00, 33.35it/s]


Epoch 3 training loss: 4.949088096618652
Epoch 3 validation loss: 3.3591151237487793
Epoch 4 training:


100%|██████████| 1/1 [00:00<00:00, 38.59it/s]


Epoch 4 training loss: 4.940771579742432
Epoch 4 validation loss: 3.3591151237487793
Epoch 5 training:


100%|██████████| 1/1 [00:00<00:00, 41.60it/s]


Epoch 5 training loss: 4.915983200073242
Epoch 5 validation loss: 3.3591151237487793
Epoch 6 training:


100%|██████████| 1/1 [00:00<00:00, 40.66it/s]


Epoch 6 training loss: 4.951780796051025
Epoch 6 validation loss: 3.3591151237487793
Epoch 7 training:


100%|██████████| 1/1 [00:00<00:00, 39.45it/s]


Epoch 7 training loss: 4.994030475616455
Epoch 7 validation loss: 3.3591151237487793
Epoch 8 training:


100%|██████████| 1/1 [00:00<00:00, 37.89it/s]


Epoch 8 training loss: 4.953352928161621
Epoch 8 validation loss: 3.3591151237487793
Epoch 9 training:


100%|██████████| 1/1 [00:00<00:00, 40.22it/s]


Epoch 9 training loss: 5.0112385749816895
Epoch 9 validation loss: 3.3591151237487793
Epoch 10 training:


100%|██████████| 1/1 [00:00<00:00, 40.50it/s]


Epoch 10 training loss: 4.948269367218018
Epoch 10 validation loss: 3.3591151237487793
Epoch 11 training:


100%|██████████| 1/1 [00:00<00:00, 39.35it/s]


Epoch 11 training loss: 4.883542060852051
Epoch 11 validation loss: 3.3591151237487793
Epoch 12 training:


100%|██████████| 1/1 [00:00<00:00, 40.72it/s]


Epoch 12 training loss: 4.995030879974365
Epoch 12 validation loss: 3.3591151237487793
Epoch 13 training:


100%|██████████| 1/1 [00:00<00:00, 39.44it/s]


Epoch 13 training loss: 4.875766277313232
Epoch 13 validation loss: 3.3591151237487793
Epoch 14 training:


100%|██████████| 1/1 [00:00<00:00, 40.03it/s]


Epoch 14 training loss: 5.055960655212402
Epoch 14 validation loss: 3.3591151237487793
Epoch 15 training:


100%|██████████| 1/1 [00:00<00:00, 41.50it/s]


Epoch 15 training loss: 4.979366779327393
Epoch 15 validation loss: 3.3591151237487793
Epoch 16 training:


100%|██████████| 1/1 [00:00<00:00, 42.17it/s]


Epoch 16 training loss: 4.950570106506348
Epoch 16 validation loss: 3.3591151237487793
Epoch 17 training:


100%|██████████| 1/1 [00:00<00:00, 39.10it/s]


Epoch 17 training loss: 5.0021772384643555
Epoch 17 validation loss: 3.3591151237487793
Epoch 18 training:


100%|██████████| 1/1 [00:00<00:00, 37.28it/s]


Epoch 18 training loss: 4.838803768157959
Epoch 18 validation loss: 3.3591151237487793
Epoch 19 training:


100%|██████████| 1/1 [00:00<00:00, 34.86it/s]


Epoch 19 training loss: 5.147669792175293
Epoch 19 validation loss: 3.3591151237487793
Epoch 20 training:


100%|██████████| 1/1 [00:00<00:00, 24.55it/s]


Epoch 20 training loss: 4.973489761352539
Epoch 20 validation loss: 3.3591151237487793
Epoch 21 training:


100%|██████████| 1/1 [00:00<00:00, 29.07it/s]


Epoch 21 training loss: 4.9539642333984375
Epoch 21 validation loss: 3.3591151237487793
Epoch 22 training:


100%|██████████| 1/1 [00:00<00:00, 38.55it/s]


Epoch 22 training loss: 4.961251258850098
Epoch 22 validation loss: 3.3591151237487793
Epoch 23 training:


100%|██████████| 1/1 [00:00<00:00, 40.69it/s]


Epoch 23 training loss: 5.072635173797607
Epoch 23 validation loss: 3.3591151237487793
Epoch 24 training:


100%|██████████| 1/1 [00:00<00:00, 41.23it/s]


Epoch 24 training loss: 5.056114196777344
Epoch 24 validation loss: 3.3591151237487793
Epoch 25 training:


100%|██████████| 1/1 [00:00<00:00, 42.29it/s]


Epoch 25 training loss: 5.067052364349365
Epoch 25 validation loss: 3.3591151237487793
Epoch 26 training:


100%|██████████| 1/1 [00:00<00:00, 41.73it/s]


Epoch 26 training loss: 5.052810192108154
Epoch 26 validation loss: 3.3591151237487793
Epoch 27 training:


100%|██████████| 1/1 [00:00<00:00, 41.58it/s]


Epoch 27 training loss: 4.979764938354492
Epoch 27 validation loss: 3.3591151237487793
Epoch 28 training:


100%|██████████| 1/1 [00:00<00:00, 42.07it/s]


Epoch 28 training loss: 4.893612384796143
Epoch 28 validation loss: 3.3591151237487793
Epoch 29 training:


100%|██████████| 1/1 [00:00<00:00, 42.95it/s]


Epoch 29 training loss: 4.827571392059326
Epoch 29 validation loss: 3.3591151237487793
Epoch 30 training:


100%|██████████| 1/1 [00:00<00:00, 43.46it/s]


Epoch 30 training loss: 4.93640661239624
Epoch 30 validation loss: 3.3591151237487793
Epoch 31 training:


100%|██████████| 1/1 [00:00<00:00, 40.82it/s]


Epoch 31 training loss: 4.974845886230469
Epoch 31 validation loss: 3.3591151237487793
Epoch 32 training:


100%|██████████| 1/1 [00:00<00:00, 42.63it/s]


Epoch 32 training loss: 4.917346000671387
Epoch 32 validation loss: 3.3591151237487793
Epoch 33 training:


100%|██████████| 1/1 [00:00<00:00, 40.87it/s]


Epoch 33 training loss: 4.956727981567383
Epoch 33 validation loss: 3.3591151237487793
Epoch 34 training:


100%|██████████| 1/1 [00:00<00:00, 41.26it/s]


Epoch 34 training loss: 4.774108409881592
Epoch 34 validation loss: 3.3591151237487793
Epoch 35 training:


100%|██████████| 1/1 [00:00<00:00, 38.53it/s]


Epoch 35 training loss: 5.012545585632324
Epoch 35 validation loss: 3.3591151237487793
Epoch 36 training:


100%|██████████| 1/1 [00:00<00:00, 37.08it/s]


Epoch 36 training loss: 4.890004634857178
Epoch 36 validation loss: 3.3591151237487793
Epoch 37 training:


100%|██████████| 1/1 [00:00<00:00, 35.51it/s]


Epoch 37 training loss: 4.958044528961182
Epoch 37 validation loss: 3.3591151237487793
Epoch 38 training:


100%|██████████| 1/1 [00:00<00:00, 36.46it/s]


Epoch 38 training loss: 5.106845378875732
Epoch 38 validation loss: 3.3591151237487793
Epoch 39 training:


100%|██████████| 1/1 [00:00<00:00, 37.81it/s]


Epoch 39 training loss: 4.907041072845459
Epoch 39 validation loss: 3.3591151237487793
Epoch 40 training:


100%|██████████| 1/1 [00:00<00:00, 36.98it/s]


Epoch 40 training loss: 4.830103397369385
Epoch 40 validation loss: 3.3591151237487793
Epoch 41 training:


100%|██████████| 1/1 [00:00<00:00, 36.44it/s]


Epoch 41 training loss: 4.83859920501709
Epoch 41 validation loss: 3.3591151237487793
Epoch 42 training:


100%|██████████| 1/1 [00:00<00:00, 37.50it/s]


Epoch 42 training loss: 4.986613750457764
Epoch 42 validation loss: 3.3591151237487793
Epoch 43 training:


100%|██████████| 1/1 [00:00<00:00, 36.75it/s]


Epoch 43 training loss: 4.766414642333984
Epoch 43 validation loss: 3.3591151237487793
Epoch 44 training:


100%|██████████| 1/1 [00:00<00:00, 38.95it/s]


Epoch 44 training loss: 4.950599670410156
Epoch 44 validation loss: 3.3591151237487793
Epoch 45 training:


100%|██████████| 1/1 [00:00<00:00, 36.50it/s]


Epoch 45 training loss: 4.901646137237549
Epoch 45 validation loss: 3.3591151237487793
Epoch 46 training:


100%|██████████| 1/1 [00:00<00:00, 38.38it/s]


Epoch 46 training loss: 4.875185489654541
Epoch 46 validation loss: 3.3591151237487793
Epoch 47 training:


100%|██████████| 1/1 [00:00<00:00, 26.74it/s]


Epoch 47 training loss: 4.9830522537231445
Epoch 47 validation loss: 3.3591151237487793
Epoch 48 training:


100%|██████████| 1/1 [00:00<00:00, 37.86it/s]


Epoch 48 training loss: 5.0894455909729
Epoch 48 validation loss: 3.3591151237487793
Epoch 49 training:


100%|██████████| 1/1 [00:00<00:00, 36.95it/s]


Epoch 49 training loss: 4.9651875495910645
Epoch 49 validation loss: 3.3591151237487793
Epoch 50 training:


100%|██████████| 1/1 [00:00<00:00, 33.33it/s]


Epoch 50 training loss: 5.0098795890808105
Epoch 50 validation loss: 3.3591151237487793


In [None]:
### Sanity Checking Outputs

# Compare the model's outputs with the true labels on the first test batch.
model.eval()
sample_number = 5  # number of rows of labels / outputs to print

# Use names distinct from the preprocessing cell's `test_labels` DataFrame.
# Reusing that name would replace it with a tensor and break a re-run of the
# DataLoader cell, which reads `test_labels.values`.
batch_features, batch_labels = next(iter(test_dataloader))
batch_features = batch_features.to(torch.float32).to(device)
batch_event = batch_labels[:, 1].to(torch.float32).to(device).reshape(-1, 1)

with torch.no_grad():
    # Same call signature as the training loop, which passes the event column
    # as the second argument.
    # NOTE(review): this feeds the true event indicator into the model at
    # evaluation time. If the model uses it, predictions leak the label — confirm.
    outputs = model(batch_features, batch_event)

print("Intended Output: ", batch_labels[:sample_number])
print("Actual Output: ", outputs[:sample_number])