# SWEEM Model Implementation

This file is used to illustrate the preprocessing, training, and evaluation 
stages of our model. Comments and more information will be provided per section.

## Preprocessing

Here we load in the data and establish our train-test split. We also set up dataloaders
for us to be able to properly use the data within our training loop.

In [12]:
%load_ext autoreload
%autoreload 2

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Running on", device)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Running on cpu


In [56]:
data = pd.read_csv('./Data/OmicsData/data.csv')
# data = pd.concat([data.iloc[:, :5541], data.iloc[:, -2:]], axis=1)

# Separate to make sure that there's an even distribution of 1s and 0s in train and test
data_ones = data[data.iloc[:, -1] == 1]
data_zeros = data[data.iloc[:, -1] == 0]

# Split the data into train and validation sets.
# Train test split is 80 20
train_data_ones, test_data_ones, train_labels_ones, test_labels_ones = train_test_split(
    data_ones.iloc[:, 1:-2], data_ones.iloc[:, -2:], test_size=0.2, random_state=42)
train_data_zeros, test_data_zeros, train_labels_zeros, test_labels_zeros = train_test_split(
    data_zeros.iloc[:, 1:-2], data_zeros.iloc[:, -2:], test_size=0.2, random_state=42)

# Concatenate in the end to make train and test
train_data = pd.concat((train_data_ones, train_data_zeros))
train_labels = pd.concat((train_labels_ones, train_labels_zeros))
test_data = pd.concat((test_data_ones, test_data_zeros)) 
test_labels = pd.concat((test_labels_ones, test_labels_zeros))

# 352 0's
# 123 1's

# Number of genes in rna: 5540
# Number of genes in scna: 5507
# Number of genes in methy: 4846
# Total number of genes: 15893
# Total number of samples in the final dataset: 475

#rna
print(train_data.columns[0])
print(train_data.columns[5539])
#scna
print(train_data.columns[5540])
print(train_data.columns[11046])
#methy
print(train_data.columns[11047])
print(train_data.columns[15892])

#labels
print(train_labels.columns)

# Create Tensor datasets
train_dataset = TensorDataset(torch.tensor(train_data.values, dtype=torch.float32), torch.tensor(train_labels.values, dtype=torch.float32))
#val_dataset   = TensorDataset(torch.tensor(validation_data.values, dtype=torch.float32), torch.tensor(validation_labels.values, dtype=torch.float32))
test_dataset  = TensorDataset(torch.tensor(test_data.values, dtype=torch.float32), torch.tensor(test_labels.values, dtype=torch.float32))

# Create DataLoader objects
batch_size = 16
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
#val_dataloader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False)
test_dataloader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=True)

print("batch size: ", batch_size)

ZYX_rna
A2M_rna
ZYX_scna
A2M_scna
ZYX_methy
A2M_methy
Index(['OS_DAYS', 'OS_EVENT'], dtype='object')
batch size:  16


## Train Loop

In [70]:
settings = {
    "model": {
        "rna_dim": 5540,
        "scna_dim": 5507,
        "methy_dim": 4846,
        "use_rna": True,
        "use_scna": True,
        "use_methy": False,
        "hidden_dim": 32,
        "self_att": False,
        "cross_att": False,
        "device": device
    },
    "train": {
        "lr": 0.0001,
        "epochs": 101,
        "epoch_mod": 10
    }
}

In [71]:
### Training Loop
from model import SWEEM
from loss import temp_loss, neg_par_log_likelihood

epoch_train_losses = []
epoch_val_losses   = []

model = SWEEM(**settings["model"])
model.to(device)

# criterion = temp_loss
# binary cross entropy loss
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=settings["train"]["lr"])

for epoch in range(settings["train"]["epochs"]):
    epoch_train_loss = 0
    epoch_val_loss   = 0
    
    ## Training
    # print(f"Epoch {epoch + 1} training:")
    # progress_bar = tqdm(range(len(train_dataloader)))

    model.train()
    for i, (batchX, batchY) in enumerate(train_dataloader):        
        batchX = batchX.to(device)
        rna = batchX[:, :5540]
        scna = batchX[:, 5540:11047]
        methy = batchX[:, 11047:]
        time = batchY[:,0].reshape(-1, 1).to(device)
        event = batchY[:,1].reshape(-1, 1).to(device)

        outputs = model(event, rna=rna, scna=scna, methy=methy)
        
        # loss = criterion(outputs, time, event)
        loss = criterion(outputs, event)
        # print(f"batch {i+1} loss: ", loss)
        loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()
        
        epoch_train_loss += loss.item()
        # progress_bar.update(1)
        
    ## Validation
    # print(f"Epoch {epoch + 1} validation:")
    # progress_bar = tqdm(range(len(test_dataloader)))
    
    model.eval()
    with torch.no_grad():
        for i, (batchX, batchY) in enumerate(test_dataloader):
            batchX = batchX.to(device)
            rna = batchX[:, :5540]
            scna = batchX[:, 5540:11047]
            methy = batchX[:, 11047:]
            time = batchY[:,0].reshape(-1, 1).to(device)
            event = batchY[:,1].reshape(-1, 1).to(device)
            outputs = model(event, rna=rna, scna=scna, methy=methy)

            #outputs = model(rna, scna, methy, event)
            # loss = criterion(outputs, time, event)
            loss = criterion(outputs, event)
            epoch_val_loss += loss.item()
            # progress_bar.update(1)

    # Save and print losses
    epoch_train_loss /= len(train_dataloader)
    epoch_val_loss /= len(test_dataloader)
    epoch_train_losses.append(epoch_train_loss)
    epoch_val_losses.append(epoch_val_loss)
    if epoch % settings["train"]["epoch_mod"] == 0:
        print(f"Epoch {epoch + 1} training loss: {epoch_train_loss}")
        print(f"Epoch {epoch + 1} validation loss: {epoch_val_loss}")

Epoch 1 training loss: 1.091178297996521
Epoch 1 validation loss: 0.6686240285634995
Epoch 11 training loss: 0.6109214077393214
Epoch 11 validation loss: 0.5951964457829794
Epoch 21 training loss: 0.5825897653897604
Epoch 21 validation loss: 0.6006471465031306
Epoch 31 training loss: 0.5573716225723425
Epoch 31 validation loss: 0.5436550875504812
Epoch 41 training loss: 0.5563496152559916
Epoch 41 validation loss: 0.5546911706527075
Epoch 51 training loss: 0.5519858623544375
Epoch 51 validation loss: 0.5704788416624069
Epoch 61 training loss: 0.5693352917830149
Epoch 61 validation loss: 0.5341236641009649
Epoch 71 training loss: 0.5607726437350115
Epoch 71 validation loss: 0.4939448932806651
Epoch 81 training loss: 0.5264037922024727
Epoch 81 validation loss: 0.5111846625804901
Epoch 91 training loss: 0.5330311780174574
Epoch 91 validation loss: 0.589444657166799
Epoch 101 training loss: 0.541891548782587
Epoch 101 validation loss: 0.5184443593025208


In [72]:
### Sanity Check
model.eval()
with torch.no_grad():
    for (batchX, batchY) in test_dataloader:
        batchX = batchX.to(device)
        rna = batchX[:, :5540]
        scna = batchX[:, 5540:11047]
        methy = batchX[:, 11047:]
        time = batchY[:,0].reshape(-1, 1).to(device)
        event = batchY[:,1].reshape(-1, 1).to(device)
        outputs = model(event, rna=rna, scna=scna, methy=methy)
        
        # concat torch tensors
        table = torch.cat((time, event, outputs), 1)
        
        # print row by row
        print("Sanity Check:")
        print("time, event, predicted")
        for row in table:
            print(row.tolist())
        break

Sanity Check:
time, event, predicted
[492.0, 1.0, 0.49946901202201843]
[1058.0, 0.0, 0.1540626883506775]
[1585.0, 1.0, 0.23833434283733368]
[257.0, 0.0, 0.43128618597984314]
[2433.0, 1.0, 0.38814905285835266]
[1191.0, 0.0, 0.2459455281496048]
[2772.0, 0.0, 0.3042075037956238]
[1183.0, 1.0, 0.17927546799182892]
[522.0, 0.0, 0.23103226721286774]
[533.0, 0.0, 0.38528895378112793]
[1458.0, 0.0, 0.2566658854484558]
[55.0, 0.0, 0.3051906228065491]
[3253.0, 0.0, 0.49946901202201843]
[411.0, 0.0, 0.37086620926856995]
[585.0, 0.0, 0.2274782508611679]
[544.0, 0.0, 0.3775646686553955]


In [73]:
import checkpoint

checkpoint.save("./sweem.model", model, optimizer, epoch_train_losses, epoch_val_losses, settings)

model, optimizer, epoch_train_losses, epoch_val_losses, settings = checkpoint.load("./sweem.model", SWEEM, optim.Adam)

In [74]:
### Sanity Check
model.eval()
with torch.no_grad():
    for (batchX, batchY) in test_dataloader:
        batchX = batchX.to(device)
        rna = batchX[:, :5540]
        scna = batchX[:, 5540:11047]
        methy = batchX[:, 11047:]
        time = batchY[:,0].reshape(-1, 1).to(device)
        event = batchY[:,1].reshape(-1, 1).to(device)
        outputs = model(event, rna=rna, scna=scna, methy=methy)
        
        # concat torch tensors
        table = torch.cat((time, event, outputs), 1)
        
        # print row by row
        print("Sanity Check:")
        print("time, event, predicted")
        for row in table:
            print(row.tolist())
        break

Sanity Check:
time, event, predicted
[835.0, 0.0, 0.2668880820274353]
[585.0, 0.0, 0.22324496507644653]
[492.0, 1.0, 0.49946901202201843]
[351.0, 1.0, 0.39260634779930115]
[992.0, 0.0, 0.16521213948726654]
[1481.0, 1.0, 0.11523571610450745]
[326.0, 0.0, 0.49946901202201843]
[1242.0, 1.0, 0.3857501745223999]
[1183.0, 1.0, 0.18767809867858887]
[442.0, 0.0, 0.12709879875183105]
[1137.0, 1.0, 0.21330054104328156]
[1393.0, 0.0, 0.49946901202201843]
[153.0, 0.0, 0.1625303030014038]
[566.0, 0.0, 0.45260533690452576]
[908.0, 0.0, 0.49946901202201843]
[442.0, 0.0, 0.49946901202201843]


In [75]:
print(epoch_train_losses)

[1.091178297996521, 0.6804453569153944, 0.6231756955385208, 0.6233249207337698, 0.6118835993111134, 0.6371393452088038, 0.6210088841617107, 0.6015973823765913, 0.5887632754941782, 0.5704758378366629, 0.6109214077393214, 0.5836750157177448, 0.6033178592721621, 0.6143522250155607, 0.5988648335138956, 0.5756115627785524, 0.6025844464699427, 0.5780049984653791, 0.6006980662544569, 0.5883169881999493, 0.5825897653897604, 0.5896185760696729, 0.5902309653659662, 0.6154125332832336, 0.5851350935796896, 0.5549411637087663, 0.5669780758519968, 0.5781946939726671, 0.5811946851511797, 0.5735298444827398, 0.5573716225723425, 0.5893578144411246, 0.5751541579763094, 0.5706254243850708, 0.5452771161993345, 0.5425012335181236, 0.5652928203344345, 0.5566865851481756, 0.5741078394154707, 0.5471719801425934, 0.5563496152559916, 0.5660253005723158, 0.5536101547380289, 0.5570531040430069, 0.5383208480974039, 0.5439110932250818, 0.5770820217827956, 0.5656026514867941, 0.551738460858663, 0.5632294379174709, 0