In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.transforms import ToTensor

from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.kl import kl_divergence
import torchvision
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import random
from random import sample
import seaborn as sns

In [4]:
!rm -rf runs/tcga*

In [15]:
tcga_tybalt_file_location = 'data/pancan_scaled_zeroone_rnaseq.tsv.gz'

In [16]:
rnaseq_df = pd.read_table(tcga_tybalt_file_location)
rnaseq_df.drop(columns=rnaseq_df.columns[0], axis=1,  inplace=True)
rnaseq_df = rnaseq_df.dropna()
print(rnaseq_df.shape)
rnaseq_df.head(2)

(10459, 5000)


Unnamed: 0,RPS4Y1,XIST,KRT5,AGR2,CEACAM5,KRT6A,KRT14,CEACAM6,DDX3Y,KDM5D,...,FAM129A,C8orf48,CDK5R1,FAM81A,C13orf18,GDPD3,SMAGP,C2orf85,POU5F1B,CHST2
0,0.678296,0.28991,0.03423,0.0,0.0,0.084731,0.031863,0.037709,0.746797,0.687833,...,0.44061,0.428782,0.732819,0.63434,0.580662,0.294313,0.458134,0.478219,0.168263,0.638497
1,0.200633,0.654917,0.181993,0.0,0.0,0.100606,0.050011,0.092586,0.103725,0.140642,...,0.620658,0.363207,0.592269,0.602755,0.610192,0.374569,0.72242,0.271356,0.160465,0.60256


In [17]:
test_set_percent = 0.2
rnaseq_df_test = rnaseq_df.sample(frac=test_set_percent)
rnaseq_df_train = rnaseq_df.drop(rnaseq_df_test.index)

In [18]:
# Define custom dataset class
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data.iloc[idx].values, dtype=torch.float32)


In [19]:
train_dataset = CustomDataset(rnaseq_df_train)
test_dataset = CustomDataset(rnaseq_df_test)

In [20]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=50, shuffle=True)
validation_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=50, shuffle=False)

In [21]:
class VAE(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: list, z_dim):
        super(VAE, self).__init__()
        
        self.z_dim = z_dim
        
        self.encoder_layers = nn.ModuleList([nn.Linear(input_dim, hidden_dim[0])])
        self.decoder_layers = nn.ModuleList([nn.Linear(hidden_dim[0], input_dim)])
                
        if len(hidden_dim)>1:
            for i in range(len(hidden_dim)-1):
                self.encoder_layers.append(nn.Linear(hidden_dim[i], hidden_dim[i+1]))
                self.decoder_layers.insert(0, nn.Linear(hidden_dim[i+1], hidden_dim[i]))
                
        self.encoder_layers.append(nn.Linear(hidden_dim[-1], 2 * z_dim))
        self.batchnorm = nn.BatchNorm1d(z_dim)
        self.decoder_layers.insert(0, nn.Linear(z_dim, hidden_dim[-1]))

        
    def encoder(self, x):
        for idx, layer in enumerate(self.encoder_layers):
            x = layer(x)
            if idx < len(self.encoder_layers) - 1:
                # x = F.dropout(x, 0.01)
                x = F.relu(x)
                #x = nn.BatchNorm1d(x)
        return x[...,:self.z_dim], x[...,self.z_dim:] # mu, log_var
    
    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        # std = torch.abs(log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample
        
    def decoder(self, z):
        for idx, layer in enumerate(self.decoder_layers):
            z = layer(z)
            if idx < len(self.decoder_layers) - 1:
                # x = F.dropout(x, 0.01)
                z = F.relu(z)
        return torch.sigmoid(z) 
    
    def forward(self, x):
        mu, log_var = self.encoder(x.view(-1, input_dim))
        mu = self.batchnorm(mu)
        log_var = self.batchnorm(log_var)
    #    z = self.sampling(mu, log_var)
        latent = MultivariateNormal(loc = mu, 
                                    scale_tril=torch.diag_embed(torch.exp(0.5*log_var)))
        z = latent.rsample()
           
    #    return self.decoder(z), mu, log_var
        return self.decoder(z), latent

    @staticmethod
    def loss_function(recon_x, x, mu, log_var):
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, input_dim), reduction='sum')
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return BCE + KLD
    
    @staticmethod
    def loss_function_dist(recon_x, x, latent, input_dim):
        prior = MultivariateNormal(loc = torch.zeros(latent.mean.shape[1]),
                                   scale_tril=torch.eye(latent.mean.shape[1]))
        
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, input_dim), reduction='sum')
        KLD = torch.sum(kl_divergence(latent, prior))
        return BCE + KLD

VAE model

In [22]:
 def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(train_loader):
        
        # Every data instance 
        data = data.to(DEVICE)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        recon_batch, latent = vae(data)

        # Compute the loss and its gradients
        loss = VAE.loss_function_dist(recon_batch, data, latent, input_dim)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 100 == 99:
            last_loss = running_loss / 100.0 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(train_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss

In [23]:
# build model
input_dim=rnaseq_df.shape[1]
#vae = VAE(input_dim=input_dim, hidden_dim=[100,100], z_dim=100)
vae = VAE(input_dim=input_dim, hidden_dim=[100], z_dim=100)
# if torch.backends.mps.is_available():
#     DEVICE = 'mps'
# else:
#train_loader = torch.utils.data.DataLoader(dataset=torch.Tensor(torch.randn(30, 5000)), batch_size=100, shuffle=True)

DEVICE = 'cpu'
    
vae.to(DEVICE)

optimizer = optim.Adam(vae.parameters(), lr=0.0005)


# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/tcga_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 20

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    vae.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    # We don't need gradients on to do reporting
    vae.train(False)

    running_vloss = 0.0
    for i, vdata in enumerate(validation_loader):
        vinputs = vdata
        voutputs, latent = vae(vinputs)
        
        vloss = VAE.loss_function_dist(voutputs, vinputs, latent, input_dim)
        running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        #torch.save(vae.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:
  batch 100 loss: 159118.78109375
LOSS train 159118.78109375 valid 207546.734375
EPOCH 2:
  batch 100 loss: 143969.975
LOSS train 143969.975 valid 187245.34375
EPOCH 3:
  batch 100 loss: 141760.40515625
LOSS train 141760.40515625 valid 172713.28125
EPOCH 4:
  batch 100 loss: 140533.98828125
LOSS train 140533.98828125 valid 173668.71875
EPOCH 5:
  batch 100 loss: 140028.98203125
LOSS train 140028.98203125 valid 164026.125
EPOCH 6:
  batch 100 loss: 139635.4678125
LOSS train 139635.4678125 valid 161267.953125
EPOCH 7:
  batch 100 loss: 139351.01046875
LOSS train 139351.01046875 valid 160423.046875
EPOCH 8:
  batch 100 loss: 138997.33484375
LOSS train 138997.33484375 valid 160612.171875
EPOCH 9:
  batch 100 loss: 138725.40625
LOSS train 138725.40625 valid 157223.515625
EPOCH 10:
  batch 100 loss: 138685.10375
LOSS train 138685.10375 valid 155908.421875
EPOCH 11:
  batch 100 loss: 138485.31078125
LOSS train 138485.31078125 valid 154812.625
EPOCH 12:
  batch 100 loss: 138628.094218

Shuffle the colums of data

In [24]:
torch.save(vae.state_dict(), 'vae_weights.pth')

BioBomb data

In [25]:
rnaseq_df_test.to_csv('data/rnaseq_df_test.csv')

In [26]:
rnaseq_df_test

Unnamed: 0,RPS4Y1,XIST,KRT5,AGR2,CEACAM5,KRT6A,KRT14,CEACAM6,DDX3Y,KDM5D,...,FAM129A,C8orf48,CDK5R1,FAM81A,C13orf18,GDPD3,SMAGP,C2orf85,POU5F1B,CHST2
1980,0.000000,0.676411,0.538535,0.829060,0.371617,0.076709,0.493035,0.495491,0.000000,0.000000,...,0.641401,0.258891,0.520563,0.400908,0.437579,0.560402,0.483713,0.054940,0.333915,0.313550
5638,0.000000,0.601335,0.791538,0.580519,0.711329,0.771928,0.557838,0.656916,0.000000,0.000000,...,0.608653,0.157189,0.526314,0.651298,0.422983,0.648942,0.718139,0.030541,0.219439,0.447640
8240,0.000000,0.801651,0.515508,0.526579,0.119657,0.212524,0.086728,0.253842,0.000000,0.000000,...,0.278018,0.160775,0.496794,0.384108,0.434125,0.476615,0.634136,0.000000,0.433654,0.308404
3555,0.767083,0.132404,0.000000,0.461636,0.000000,0.000000,0.000000,0.046623,0.774512,0.736268,...,0.359250,0.657670,0.340666,0.708832,0.358783,0.328927,0.500310,0.000000,0.224678,0.514864
3936,0.022867,0.731514,0.140167,0.749021,0.750316,0.031328,0.000000,0.716501,0.000000,0.000000,...,0.851758,0.450602,0.528139,0.605051,0.601097,0.446412,0.684836,0.074063,0.170475,0.413623
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8798,0.529371,0.278337,0.084118,0.626482,0.848247,0.239914,0.029636,0.871150,0.565451,0.569982,...,0.407657,0.252476,0.523529,0.654260,0.860068,0.551171,0.780504,0.000000,0.575375,0.269600
8383,0.000000,0.689733,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.586353,0.041414,0.496126,0.915666,0.337744,0.167310,0.481890,0.000000,0.371185,0.467613
3652,0.067000,0.797838,0.230361,0.000000,0.057428,0.142460,0.103106,0.024850,0.033195,0.033241,...,0.587226,0.556678,0.428475,0.360708,0.344114,0.503687,0.616717,0.200213,0.582584,0.630867
9487,0.706362,0.155848,0.100019,0.000000,0.000000,0.000000,0.021514,0.000000,0.742599,0.759649,...,0.223681,0.380915,0.811927,0.718884,0.443742,0.379867,0.292060,0.771193,0.227811,0.516097
