# SWEEM Model Implementation

This file is used to illustrate the preprocessing, training, and evaluation 
stages of our model. Comments and more information will be provided per section.

## Preprocessing

Here we load in the data and establish our train-test split. We also set up dataloaders
for us to be able to properly use the data within our training loop.

In [1]:
%load_ext autoreload
%autoreload 2

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Running on", device)

Running on cpu


In [8]:
data = pd.read_csv('./Data/OmicsData/data.csv')
# data = pd.concat([data.iloc[:, :5541], data.iloc[:, -2:]], axis=1)

# Separate to make sure that there's an even distribution of 1s and 0s in train and test
data_ones = data[data.iloc[:, -1] == 1]
data_zeros = data[data.iloc[:, -1] == 0]

# Split the data into train and validation sets.
# Train test split is 80 20
train_data_ones, test_data_ones, train_labels_ones, test_labels_ones = train_test_split(
    data_ones.iloc[:, 1:-2], data_ones.iloc[:, -2:], test_size=0.2, random_state=42)
train_data_zeros, test_data_zeros, train_labels_zeros, test_labels_zeros = train_test_split(
    data_zeros.iloc[:, 1:-2], data_zeros.iloc[:, -2:], test_size=0.2, random_state=42)

# Concatenate in the end to make train and test
train_data = pd.concat((train_data_ones, train_data_zeros))
train_labels = pd.concat((train_labels_ones, train_labels_zeros))
test_data = pd.concat((test_data_ones, test_data_zeros)) 
test_labels = pd.concat((test_labels_ones, test_labels_zeros))

# 352 0's
# 123 1's

# Number of genes in rna: 5540
# Number of genes in scna: 5507
# Number of genes in methy: 4846
# Total number of genes: 15893
# Total number of samples in the final dataset: 475

#rna
print(train_data.columns[0])
print(train_data.columns[5539])
#scna
print(train_data.columns[5540])
print(train_data.columns[11046])
#methy
print(train_data.columns[11047])
print(train_data.columns[15892])

#labels
print(train_labels.columns)

# Create Tensor datasets
train_dataset = TensorDataset(torch.tensor(train_data.values, dtype=torch.float32), torch.tensor(train_labels.values, dtype=torch.float32))
#val_dataset   = TensorDataset(torch.tensor(validation_data.values, dtype=torch.float32), torch.tensor(validation_labels.values, dtype=torch.float32))
test_dataset  = TensorDataset(torch.tensor(test_data.values, dtype=torch.float32), torch.tensor(test_labels.values, dtype=torch.float32))

# Create DataLoader objects
batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
#val_dataloader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False)
test_dataloader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False)

print("batch size: ", batch_size)

ZYX_rna
A2M_rna
ZYX_scna
A2M_scna
ZYX_methy
A2M_methy
Index(['OS_DAYS', 'OS_EVENT'], dtype='object')


## Train Loop

In [9]:
### Training Loop
from model import SWEEM
from loss import temp_loss
num_epochs = 2
epoch_train_losses = []
epoch_val_losses   = []

model = SWEEM(rna_dim = 5540, 
              scna_dim = 5507, 
              methy_dim = 4846, 
              hidden_dim = 128, 
              self_att = False, 
              cross_att = False)
model.to(device)

# criterion = temp_loss
# binary cross entropy loss
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)

for epoch in range(num_epochs):
    epoch_train_loss = 0
    epoch_val_loss   = 0
    
    ## Training
    print(f"Epoch {epoch + 1} training:")
    progress_bar = tqdm(range(len(train_dataloader)))

    model.train()
    for i, (batchX, batchY) in enumerate(train_dataloader):        
        batchX = batchX.to(device)
        rna = batchX[:, :5540]
        scna = batchX[:, 5540:11047]
        methy = batchX[:, 11047:]
        time = batchY[:,0].reshape(-1, 1).to(device)
        event = batchY[:,1].reshape(-1, 1).to(device)
        
        outputs = model(rna, scna, methy, event)
        
        # loss = criterion(outputs, time, event)
        loss = criterion(outputs, event)
        print(f"batch {i+1} loss: ", loss)
        loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()
        
        epoch_train_loss += loss.item()
        progress_bar.update(1)
        
    ## Validation
    print(f"Epoch {epoch + 1} validation:")
    progress_bar = tqdm(range(len(test_dataloader)))
    
    model.eval()
    with torch.no_grad():
        for i, (batchX, batchY) in enumerate(test_dataloader):
            batchX = batchX.to(device)
            rna = batchX[:, :5540]
            scna = batchX[:, 5540:11047]
            methy = batchX[:, 11047:]
            time = batchY[:,0].reshape(-1, 1).to(device)
            event = batchY[:,1].reshape(-1, 1).to(device)
            outputs = model(rna, scna, methy, event)
            # loss = criterion(outputs, time, event)
            loss = criterion(outputs, event)
            epoch_val_loss += loss.item()
            progress_bar.update(1)

    # Save and print losses
    epoch_train_loss /= len(train_dataloader)
    epoch_val_loss /= len(test_dataloader)
    epoch_train_losses.append(epoch_train_loss)
    epoch_val_losses.append(epoch_val_loss)
    print(f"Epoch {epoch + 1} training loss: {epoch_train_loss}")
    print(f"Epoch {epoch + 1} validation loss: {epoch_val_loss}")

Epoch 1 training:


100%|██████████| 1/1 [01:26<00:00, 86.80s/it]


loss:  tensor(0.7814, grad_fn=<BinaryCrossEntropyBackward0>)




loss:  tensor(25., grad_fn=<BinaryCrossEntropyBackward0>)




loss:  tensor(31.2500, grad_fn=<BinaryCrossEntropyBackward0>)




loss:  tensor(21.8750, grad_fn=<BinaryCrossEntropyBackward0>)




loss:  tensor(25., grad_fn=<BinaryCrossEntropyBackward0>)




loss:  tensor(21.8750, grad_fn=<BinaryCrossEntropyBackward0>)




loss:  tensor(12.5000, grad_fn=<BinaryCrossEntropyBackward0>)




loss:  tensor(34.3750, grad_fn=<BinaryCrossEntropyBackward0>)




loss:  tensor(12.5000, grad_fn=<BinaryCrossEntropyBackward0>)




loss:  tensor(40.6250, grad_fn=<BinaryCrossEntropyBackward0>)




loss:  tensor(28.1250, grad_fn=<BinaryCrossEntropyBackward0>)


KeyboardInterrupt: 

In [9]:
### Sanity Check
model.eval()
with torch.no_grad():
    for (batchX, batchY) in test_dataloader:
        batchX = batchX.to(device)
        rna = batchX[:, :5540]
        scna = batchX[:, 5540:11047]
        methy = batchX[:, 11047:]
        time = batchY[:,0].reshape(-1, 1).to(device)
        event = batchY[:,1].reshape(-1, 1).to(device)
        outputs = model(rna, scna, methy, event)
        print(f"times: {time}")
        print(f"events: {event}")
        print(f"predictions: {outputs}")
        break

times: tensor([[ 706.],
        [ 772.],
        [ 184.],
        [ 455.],
        [ 407.],
        [ 544.],
        [1585.],
        [ 411.],
        [ 956.],
        [1354.],
        [ 835.],
        [2381.],
        [ 648.],
        [1257.],
        [ 748.],
        [6423.],
        [1401.],
        [ 442.],
        [ 993.],
        [2107.],
        [ 900.],
        [  96.],
        [ 491.],
        [5546.],
        [ 629.],
        [ 547.],
        [ 122.],
        [3725.],
        [  62.],
        [ 605.],
        [1220.],
        [1262.]])
events: tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
      