# SWEEM Model Implementation

This file is used to illustrate the preprocessing, training, and evaluation 
stages of our model. Comments and more information will be provided per section.

## Preprocessing

Here we load in the data and establish our train-test split. We also set up dataloaders
for us to be able to properly use the data within our training loop.

In [13]:
%load_ext autoreload
%autoreload 2

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import checkpoint

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Running on", device)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Running on cpu


In [14]:
data = pd.read_csv('./Data/OmicsData/data.csv')
# data = pd.concat([data.iloc[:, :5541], data.iloc[:, -2:]], axis=1)

# Separate to make sure that there's an even distribution of 1s and 0s in train and test
data_ones = data[data.iloc[:, -1] == 1]
data_zeros = data[data.iloc[:, -1] == 0]

# Split the data into train and validation sets.
# Train test split is 80 20
train_data_ones, test_data_ones, train_labels_ones, test_labels_ones = train_test_split(
    data_ones.iloc[:, 1:-2], data_ones.iloc[:, -2:], test_size=0.2, random_state=42)
train_data_zeros, test_data_zeros, train_labels_zeros, test_labels_zeros = train_test_split(
    data_zeros.iloc[:, 1:-2], data_zeros.iloc[:, -2:], test_size=0.2, random_state=42)

# Concatenate in the end to make train and test
train_data = pd.concat((train_data_ones, train_data_zeros))
train_labels = pd.concat((train_labels_ones, train_labels_zeros))
test_data = pd.concat((test_data_ones, test_data_zeros)) 
test_labels = pd.concat((test_labels_ones, test_labels_zeros))

# 352 0's
# 123 1's

# Number of genes in rna: 5540
# Number of genes in scna: 5507
# Number of genes in methy: 4846
# Total number of genes: 15893
# Total number of samples in the final dataset: 475

#rna
print(train_data.columns[0])
print(train_data.columns[5539])
#scna
print(train_data.columns[5540])
print(train_data.columns[11046])
#methy
print(train_data.columns[11047])
print(train_data.columns[15892])

#labels
print(train_labels.columns)

# Create Tensor datasets
train_dataset = TensorDataset(torch.tensor(train_data.values, dtype=torch.float32), torch.tensor(train_labels.values, dtype=torch.float32))
#val_dataset   = TensorDataset(torch.tensor(validation_data.values, dtype=torch.float32), torch.tensor(validation_labels.values, dtype=torch.float32))
test_dataset  = TensorDataset(torch.tensor(test_data.values, dtype=torch.float32), torch.tensor(test_labels.values, dtype=torch.float32))

# Create DataLoader objects
batch_size = 16
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
#val_dataloader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False)
test_dataloader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=True)

print("batch size: ", batch_size)

ZYX_rna
A2M_rna
ZYX_scna
A2M_scna
ZYX_methy
A2M_methy
Index(['OS_DAYS', 'OS_EVENT'], dtype='object')
batch size:  16


## Train Loop

In [15]:
settings = {
    "model": {
        "rna_dim": 5540,
        "scna_dim": 5507,
        "methy_dim": 4846,
        "use_rna": True,
        "use_scna": True,
        "use_methy": False,
        "hidden_dim": 64,
        "self_att": False,
        "cross_att": False,
        "device": device
    },
    "train": {
        "lr": 0.0001,
        "l2": 1e-5,
        "epochs": 5,
        "epoch_mod": 1
    }
}

In [16]:
### Training Loop
from model import SWEEM
from loss import temp_loss, neg_par_log_likelihood

epoch_train_losses = []
epoch_val_losses   = []

model = SWEEM(**settings["model"])
model.to(device)

# criterion = temp_loss
# binary cross entropy loss
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=settings["train"]["lr"], weight_decay=settings["train"]["l2"])

for epoch in range(settings["train"]["epochs"]):
    epoch_train_loss = 0
    epoch_val_loss   = 0
    
    ## Training
    # print(f"Epoch {epoch + 1} training:")
    # progress_bar = tqdm(range(len(train_dataloader)))

    model.train()
    for i, (batchX, batchY) in enumerate(train_dataloader):        
        batchX = batchX.to(device)
        rna = batchX[:, :5540]
        scna = batchX[:, 5540:11047]
        methy = batchX[:, 11047:]
        time = batchY[:,0].reshape(-1, 1).to(device)
        event = batchY[:,1].reshape(-1, 1).to(device)

        outputs = model(event, rna=rna, scna=scna, methy=methy)
        
        # loss = criterion(outputs, time, event)
        loss = criterion(outputs, event)
        # print(f"batch {i+1} loss: ", loss)
        loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()
        
        epoch_train_loss += loss.item()
        # progress_bar.update(1)
        
    ## Validation
    # print(f"Epoch {epoch + 1} validation:")
    # progress_bar = tqdm(range(len(test_dataloader)))
    
    model.eval()
    with torch.no_grad():
        for i, (batchX, batchY) in enumerate(test_dataloader):
            batchX = batchX.to(device)
            rna = batchX[:, :5540]
            scna = batchX[:, 5540:11047]
            methy = batchX[:, 11047:]
            time = batchY[:,0].reshape(-1, 1).to(device)
            event = batchY[:,1].reshape(-1, 1).to(device)
            outputs = model(event, rna=rna, scna=scna, methy=methy)

            #outputs = model(rna, scna, methy, event)
            # loss = criterion(outputs, time, event)
            loss = criterion(outputs, event)
            epoch_val_loss += loss.item()
            # progress_bar.update(1)

    # Save and print losses
    epoch_train_loss /= len(train_dataloader)
    epoch_val_loss /= len(test_dataloader)
    epoch_train_losses.append(epoch_train_loss)
    epoch_val_losses.append(epoch_val_loss)
    if epoch % settings["train"]["epoch_mod"] == 0:
        print(f"Epoch {epoch + 1} training loss: {epoch_train_loss}")
        print(f"Epoch {epoch + 1} validation loss: {epoch_val_loss}")

Epoch 1 training loss: 1.4706139067808788
Epoch 1 validation loss: 1.409365137418111
Epoch 2 training loss: 0.835314237823089
Epoch 2 validation loss: 0.7055389583110809
Epoch 3 training loss: 0.6451539024710655
Epoch 3 validation loss: 0.6701457848151525
Epoch 4 training loss: 0.6479893177747726
Epoch 4 validation loss: 0.6169155240058899
Epoch 5 training loss: 0.6121218030651411
Epoch 5 validation loss: 0.5708580911159515


In [17]:
### Sanity Check
model.eval()
with torch.no_grad():
    for (batchX, batchY) in test_dataloader:
        batchX = batchX.to(device)
        rna = batchX[:, :5540].to(device)
        scna = batchX[:, 5540:11047].to(device)
        methy = batchX[:, 11047:].to(device)
        time = batchY[:,0].reshape(-1, 1).to(device)
        event = batchY[:,1].reshape(-1, 1).to(device)
        outputs = model(event, rna=rna, scna=scna, methy=methy)
        
        # concat torch tensors
        table = torch.cat((time, event, outputs), 1)
        
        # print row by row
        print("Sanity Check:")
        print("time, event, predicted")
        for row in table:
            print(row.tolist())
        break

Sanity Check:
time, event, predicted
[442.0, 0.0, 0.1938619166612625]
[292.0, 0.0, 0.18911351263523102]
[564.0, 0.0, 0.2952427566051483]
[722.0, 1.0, 0.5404471755027771]
[241.0, 1.0, 0.4509541690349579]
[326.0, 0.0, 0.24414752423763275]
[434.0, 0.0, 0.32922402024269104]
[63.0, 0.0, 0.09088050574064255]
[1585.0, 1.0, 0.4302815794944763]
[544.0, 0.0, 0.33130818605422974]
[585.0, 0.0, 0.31077396869659424]
[245.0, 1.0, 0.32782426476478577]
[609.0, 0.0, 0.3032451868057251]
[2860.0, 0.0, 0.40283891558647156]
[946.0, 0.0, 0.3657281994819641]
[162.0, 0.0, 0.30680325627326965]


In [21]:
import checkpoint

# save and load for training

checkpoint.save("./sweem.model", model, settings, optimizer, epoch_train_losses, epoch_val_losses, inference=False)

model, settings, optimizer, epoch_train_losses, epoch_val_losses = checkpoint.load("./sweem.model", SWEEM, optim.Adam, inference=False)

print("settings: ", settings)

settings:  {'model': {'rna_dim': 5540, 'scna_dim': 5507, 'methy_dim': 4846, 'use_rna': True, 'use_scna': True, 'use_methy': False, 'hidden_dim': 64, 'self_att': False, 'cross_att': False, 'device': device(type='cpu')}, 'train': {'lr': 0.0001, 'l2': 1e-05, 'epochs': 5, 'epoch_mod': 1}}


In [23]:
# save and load for inference

checkpoint.save("./sweem_inf.model", model, settings, inference=True)

model, settings = checkpoint.load("./sweem_inf.model", SWEEM, inference=True)

print("settings: ", settings)

settings:  {'model': {'rna_dim': 5540, 'scna_dim': 5507, 'methy_dim': 4846, 'use_rna': True, 'use_scna': True, 'use_methy': False, 'hidden_dim': 64, 'self_att': False, 'cross_att': False, 'device': device(type='cpu')}, 'train': {'lr': 0.0001, 'l2': 1e-05, 'epochs': 5, 'epoch_mod': 1}}
