# SWEEM Model Implementation

This file is used to illustrate the preprocessing, training, and evaluation 
stages of our model. Comments and more information will be provided per section.

## Preprocessing

Here we load in the data and establish our train-test split. We also set up dataloaders
for us to be able to properly use the data within our training loop.

In [4]:
%load_ext autoreload
%autoreload 2

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Running on", device)

Running on cpu


In [6]:
data = pd.read_csv('./Data/OmicsData/data.csv')
# data = pd.concat([data.iloc[:, :5541], data.iloc[:, -2:]], axis=1)

# Split the data into train and validation sets.
# Train test validaiton split is 80 10 10
train_data, test_data, train_labels, test_labels = train_test_split(
    data.iloc[:, 1:-2], data.iloc[:, -2:], test_size=0.2, random_state=42)

test_data, validation_data, test_labels, validation_labels = train_test_split(
    test_data, test_labels, test_size=0.5, random_state=42)

print(data.shape)
print(train_data.shape)
print(test_data.shape)
print(validation_data.shape)

# Number of genes in rna: 5540
# Number of genes in scna: 5507
# Number of genes in methy: 4846
# Total number of genes: 15893
# Total number of samples in the final dataset: 475

#rna
print(train_data.columns[0])
print(train_data.columns[5539])
#scna
print(train_data.columns[5540])
print(train_data.columns[11046])
#methy
print(train_data.columns[11047])
print(train_data.columns[15892])

#labels
print(train_labels.columns)

(475, 15896)
(380, 15893)
(47, 15893)
(48, 15893)
ZYX_rna
A2M_rna
ZYX_scna
A2M_scna
ZYX_methy
A2M_methy
Index(['OS_DAYS', 'OS_EVENT'], dtype='object')


In [7]:
# Create Tensor datasets
train_dataset = TensorDataset(torch.tensor(train_data.values, dtype=torch.float32), torch.tensor(train_labels.values, dtype=torch.float32))
val_dataset   = TensorDataset(torch.tensor(validation_data.values, dtype=torch.float32), torch.tensor(validation_labels.values, dtype=torch.float32))
test_dataset  = TensorDataset(torch.tensor(test_data.values, dtype=torch.float32), torch.tensor(test_labels.values, dtype=torch.float32))

# Create DataLoader objects
batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False)
test_dataloader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False)

## Define Self-Attention Model

In [5]:
class LinearRegression(nn.Module):
    def __init__(self, input_size):
        super(LinearRegression, self).__init__()
        self.linear1 = nn.Linear(input_size, 256)
        self.linear2 = nn.Linear(256, 1)
        self.relu = nn.ReLU()
        self.identity = nn.Identity()

    def forward(self, x):
        out = self.linear1(x)
        out = self.relu(out)
        out = self.linear2(out)
        # out = torch.sigmoid(out)
        return out

model = LinearRegression(15893)
print(model)

LinearRegression(
  (linear1): Linear(in_features=15893, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=1, bias=True)
  (relu): ReLU()
  (identity): Identity()
)


## Train Loop with Alt Dataloader

In [8]:
### Training Loop
from model import SelfAttentionModel
from loss import temp_loss, neg_par_log_likelihood
num_epochs = 1
epoch_train_losses = []
epoch_val_losses   = []

model = SelfAttentionModel(5540, 5507, 4846)
model.to(device)

criterion = temp_loss
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(num_epochs):
    epoch_train_loss = 0
    epoch_val_loss   = 0
    print(f"Epoch {epoch + 1} training:")
    progress_bar = tqdm(range(len(train_dataloader)))

    ## Training
    model.train()
    for (batchX, batchY) in train_dataloader:        
        batchX = batchX.to(device)
        rna = batchX[:, :5540]
        scna = batchX[:, 5540:11047]
        methy = batchX[:, 11047:]
        time = batchY[:,0].reshape(-1, 1).to(device)
        event = batchY[:,1].reshape(-1, 1).to(device)
        
        outputs = model(rna, scna, methy, event)
        
        loss = criterion(outputs, time, event)
        loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()
        
        epoch_train_loss += loss.item()
        progress_bar.update(1)
        
        ## Validation
        model.eval()
        with torch.no_grad():
            for (batchX, batchY) in val_dataloader:
                batchX = batchX.to(device)
                rna = batchX[:, :5540]
                scna = batchX[:, 5540:11047]
                methy = batchX[:, 11047:]
                time = batchY[:,0].reshape(-1, 1).to(device)
                event = batchY[:,1].reshape(-1, 1).to(device)
                outputs = model(rna, scna, methy, event)
                loss = criterion(outputs, time, event)
                epoch_val_loss += loss.item()

        # Save and print losses
        epoch_train_loss /= len(train_dataloader)
        epoch_val_loss /= len(val_dataloader)
        epoch_train_losses.append(epoch_train_loss)
        epoch_val_losses.append(epoch_val_loss)
        print(f"Epoch {epoch + 1} training loss: {epoch_train_loss}")
        print(f"Epoch {epoch + 1} validation loss: {epoch_val_loss}")

Epoch 1 training:


  8%|▊         | 1/12 [00:06<01:12,  6.61s/it]

Epoch 1 training loss: 0.0212319220105807
Epoch 1 validation loss: 0.359375


 17%|█▋        | 2/12 [00:11<00:56,  5.65s/it]

Epoch 1 training loss: 0.027810993500881728
Epoch 1 validation loss: 0.5390625


 25%|██▌       | 3/12 [00:17<00:50,  5.57s/it]

Epoch 1 training loss: 0.025755082791740144
Epoch 1 validation loss: 0.62890625


 33%|███▎      | 4/12 [00:22<00:45,  5.69s/it]

Epoch 1 training loss: 0.015167090232645014
Epoch 1 validation loss: 0.673828125


 42%|████▏     | 5/12 [00:28<00:38,  5.55s/it]

Epoch 1 training loss: 0.024701424186053752
Epoch 1 validation loss: 0.6962890625


 50%|█████     | 6/12 [00:33<00:32,  5.48s/it]

Epoch 1 training loss: 0.020287618682171146
Epoch 1 validation loss: 0.70751953125


 58%|█████▊    | 7/12 [00:38<00:27,  5.44s/it]

Epoch 1 training loss: 0.02512813489018093
Epoch 1 validation loss: 0.713134765625


 67%|██████▋   | 8/12 [00:44<00:21,  5.38s/it]

Epoch 1 training loss: 0.022927344574181746
Epoch 1 validation loss: 0.7159423828125


 75%|███████▌  | 9/12 [00:49<00:16,  5.45s/it]

Epoch 1 training loss: 0.014931445381181812
Epoch 1 validation loss: 0.71734619140625


 83%|████████▎ | 10/12 [00:55<00:10,  5.41s/it]

Epoch 1 training loss: 0.01947345378176515
Epoch 1 validation loss: 0.718048095703125


 92%|█████████▏| 11/12 [01:00<00:05,  5.31s/it]

Epoch 1 training loss: 0.02245612114848043
Epoch 1 validation loss: 0.7183990478515625


100%|██████████| 12/12 [01:05<00:00,  5.42s/it]

Epoch 1 training loss: 0.02865705735996477
Epoch 1 validation loss: 0.7185745239257812


In [9]:
### Sanity Check
model.eval()
with torch.no_grad():
    for (batchX, batchY) in val_dataloader:
        batchX = batchX.to(device)
        rna = batchX[:, :5540]
        scna = batchX[:, 5540:11047]
        methy = batchX[:, 11047:]
        time = batchY[:,0].reshape(-1, 1).to(device)
        event = batchY[:,1].reshape(-1, 1).to(device)
        outputs = model(rna, scna, methy, event)
        print(f"times: {time}")
        print(f"events: {event}")
        print(f"predictions: {outputs}")
        break

times: tensor([[ 706.],
        [ 772.],
        [ 184.],
        [ 455.],
        [ 407.],
        [ 544.],
        [1585.],
        [ 411.],
        [ 956.],
        [1354.],
        [ 835.],
        [2381.],
        [ 648.],
        [1257.],
        [ 748.],
        [6423.],
        [1401.],
        [ 442.],
        [ 993.],
        [2107.],
        [ 900.],
        [  96.],
        [ 491.],
        [5546.],
        [ 629.],
        [ 547.],
        [ 122.],
        [3725.],
        [  62.],
        [ 605.],
        [1220.],
        [1262.]])
events: tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
      