In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.transforms import ToTensor

from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.kl import kl_divergence
import torchvision
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import random
from random import sample
import seaborn as sns

In [6]:
!rm -rf runs/tcga*

zsh:1: no matches found: runs/tcga*


In [7]:
tcga_tybalt_file_location = 'data/pancan_scaled_zeroone_rnaseq.tsv.gz'

In [9]:
rnaseq_df = pd.read_table(tcga_tybalt_file_location)
rnaseq_df.drop(columns=rnaseq_df.columns[0], axis=1,  inplace=True)
rnaseq_df = rnaseq_df.dropna()
print(rnaseq_df.shape)
rnaseq_df.head(2)

(10459, 5000)


Unnamed: 0,RPS4Y1,XIST,KRT5,AGR2,CEACAM5,KRT6A,KRT14,CEACAM6,DDX3Y,KDM5D,...,FAM129A,C8orf48,CDK5R1,FAM81A,C13orf18,GDPD3,SMAGP,C2orf85,POU5F1B,CHST2
0,0.678296,0.28991,0.03423,0.0,0.0,0.084731,0.031863,0.037709,0.746797,0.687833,...,0.44061,0.428782,0.732819,0.63434,0.580662,0.294313,0.458134,0.478219,0.168263,0.638497
1,0.200633,0.654917,0.181993,0.0,0.0,0.100606,0.050011,0.092586,0.103725,0.140642,...,0.620658,0.363207,0.592269,0.602755,0.610192,0.374569,0.72242,0.271356,0.160465,0.60256


In [10]:
test_set_percent = 0.2
rnaseq_df_test = rnaseq_df.sample(frac=test_set_percent)
rnaseq_df_train = rnaseq_df.drop(rnaseq_df_test.index)

In [27]:
# Define custom dataset class
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data.iloc[idx].values, dtype=torch.float32)


In [28]:
train_dataset = CustomDataset(rnaseq_df_train)
test_dataset = CustomDataset(rnaseq_df_test)

In [29]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=50, shuffle=True)
validation_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=50, shuffle=False)

In [30]:
class VAE(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: list, z_dim):
        super(VAE, self).__init__()
        
        self.z_dim = z_dim
        
        self.encoder_layers = nn.ModuleList([nn.Linear(input_dim, hidden_dim[0])])
        self.decoder_layers = nn.ModuleList([nn.Linear(hidden_dim[0], input_dim)])
                
        if len(hidden_dim)>1:
            for i in range(len(hidden_dim)-1):
                self.encoder_layers.append(nn.Linear(hidden_dim[i], hidden_dim[i+1]))
                self.decoder_layers.insert(0, nn.Linear(hidden_dim[i+1], hidden_dim[i]))
                
        self.encoder_layers.append(nn.Linear(hidden_dim[-1], 2 * z_dim))
        self.batchnorm = nn.BatchNorm1d(z_dim)
        self.decoder_layers.insert(0, nn.Linear(z_dim, hidden_dim[-1]))

        
    def encoder(self, x):
        for idx, layer in enumerate(self.encoder_layers):
            x = layer(x)
            if idx < len(self.encoder_layers) - 1:
                # x = F.dropout(x, 0.01)
                x = F.relu(x)
                #x = nn.BatchNorm1d(x)
        return x[...,:self.z_dim], x[...,self.z_dim:] # mu, log_var
    
    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        # std = torch.abs(log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample
        
    def decoder(self, z):
        for idx, layer in enumerate(self.decoder_layers):
            z = layer(z)
            if idx < len(self.decoder_layers) - 1:
                # x = F.dropout(x, 0.01)
                z = F.relu(z)
        return torch.sigmoid(z) 
    
    def forward(self, x):
        mu, log_var = self.encoder(x.view(-1, input_dim))
        mu = self.batchnorm(mu)
        log_var = self.batchnorm(log_var)
    #    z = self.sampling(mu, log_var)
        latent = MultivariateNormal(loc = mu, 
                                    scale_tril=torch.diag_embed(torch.exp(0.5*log_var)))
        z = latent.rsample()
           
    #    return self.decoder(z), mu, log_var
        return self.decoder(z), latent

    @staticmethod
    def loss_function(recon_x, x, mu, log_var):
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, input_dim), reduction='sum')
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return BCE + KLD
    
    @staticmethod
    def loss_function_dist(recon_x, x, latent, input_dim):
        prior = MultivariateNormal(loc = torch.zeros(latent.mean.shape[1]),
                                   scale_tril=torch.eye(latent.mean.shape[1]))
        
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, input_dim), reduction='sum')
        KLD = torch.sum(kl_divergence(latent, prior))
        return BCE + KLD

VAE model

In [16]:
 def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(train_loader):
        
        # Every data instance 
        data = data.to(DEVICE)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        recon_batch, latent = vae(data)

        # Compute the loss and its gradients
        loss = VAE.loss_function_dist(recon_batch, data, latent, input_dim)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 100 == 99:
            last_loss = running_loss / 100.0 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(train_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss

In [17]:
# build model
input_dim=rnaseq_df.shape[1]
#vae = VAE(input_dim=input_dim, hidden_dim=[100,100], z_dim=100)
vae = VAE(input_dim=input_dim, hidden_dim=[100], z_dim=100)
# if torch.backends.mps.is_available():
#     DEVICE = 'mps'
# else:
#train_loader = torch.utils.data.DataLoader(dataset=torch.Tensor(torch.randn(30, 5000)), batch_size=100, shuffle=True)

DEVICE = 'cpu'
    
vae.to(DEVICE)

optimizer = optim.Adam(vae.parameters(), lr=0.0005)


# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/tcga_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 20

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    vae.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    # We don't need gradients on to do reporting
    vae.train(False)

    running_vloss = 0.0
    for i, vdata in enumerate(validation_loader):
        vinputs = vdata
        voutputs, latent = vae(vinputs)
        
        vloss = VAE.loss_function_dist(voutputs, vinputs, latent, input_dim)
        running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        #torch.save(vae.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:
  batch 100 loss: 159190.97359375
LOSS train 159190.97359375 valid 186473.984375
EPOCH 2:
  batch 100 loss: 144129.82484375
LOSS train 144129.82484375 valid 158527.640625
EPOCH 3:
  batch 100 loss: 141895.46359375
LOSS train 141895.46359375 valid 151911.84375
EPOCH 4:
  batch 100 loss: 140772.69609375
LOSS train 140772.69609375 valid 151951.03125
EPOCH 5:
  batch 100 loss: 140114.9878125
LOSS train 140114.9878125 valid 148374.75
EPOCH 6:
  batch 100 loss: 139427.70390625
LOSS train 139427.70390625 valid 148817.921875
EPOCH 7:
  batch 100 loss: 139188.25890625
LOSS train 139188.25890625 valid 147931.75
EPOCH 8:
  batch 100 loss: 138941.07171875
LOSS train 138941.07171875 valid 147887.828125
EPOCH 9:
  batch 100 loss: 138843.95921875
LOSS train 138843.95921875 valid 146943.671875
EPOCH 10:
  batch 100 loss: 138710.41921875
LOSS train 138710.41921875 valid 147469.421875
EPOCH 11:
  batch 100 loss: 138632.77
LOSS train 138632.77 valid 148356.59375
EPOCH 12:
  batch 100 loss: 13853

Shuffle the colums of data

In [18]:
torch.save(vae.state_dict(), 'vae_weights.pth')

BioBomb data

In [20]:
rnaseq_df_test.to_csv('data/rnaseq_df_test.csv')

In [21]:
rnaseq_df_test

Unnamed: 0,RPS4Y1,XIST,KRT5,AGR2,CEACAM5,KRT6A,KRT14,CEACAM6,DDX3Y,KDM5D,...,FAM129A,C8orf48,CDK5R1,FAM81A,C13orf18,GDPD3,SMAGP,C2orf85,POU5F1B,CHST2
2018,0.000000,0.711395,0.079027,0.000000,0.000000,0.023197,0.121969,0.062543,0.000000,0.000000,...,0.572671,0.512329,0.461674,0.491376,0.501528,0.495456,0.644358,0.113569,0.429427,0.582995
10150,0.781910,0.031101,0.684937,0.700038,0.230084,0.254548,0.416508,0.361120,0.781960,0.786848,...,0.659345,0.208477,0.458558,0.505048,0.454214,0.501694,0.586376,0.263308,0.332541,0.388225
7324,0.788988,0.079403,0.457493,0.680465,0.366980,0.285862,0.310623,0.352631,0.769542,0.737494,...,0.293865,0.247636,0.407056,0.375633,0.160404,0.778409,0.753785,0.043088,0.535991,0.555555
2055,0.747995,0.046092,0.130644,0.177930,0.129670,0.115420,0.039233,0.157857,0.729496,0.689956,...,0.479202,0.082720,0.480683,0.638266,0.464715,0.302999,0.599158,0.169778,0.289458,0.360866
3477,0.000000,0.798650,0.712625,0.560407,0.061522,0.166504,0.664300,0.331431,0.000000,0.000000,...,0.799483,0.487555,0.380651,0.409978,0.531780,0.534171,0.609271,0.146475,0.193681,0.412169
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7975,0.736170,0.152171,0.476311,0.702972,0.768507,0.343426,0.281088,0.758171,0.760010,0.740522,...,0.728246,0.450068,0.435667,0.488870,0.465533,0.509826,0.706824,0.165557,0.337093,0.427780
8056,0.038897,0.759848,0.311547,0.646421,0.384474,0.253147,0.215385,0.431012,0.077446,0.045757,...,0.654467,0.365170,0.405802,0.517247,0.563040,0.391213,0.665276,0.427613,0.196777,0.596736
8791,0.530575,0.151500,0.332931,0.694337,0.931342,0.262074,0.335299,0.671591,0.751476,0.763546,...,0.367071,0.262613,0.361214,0.661306,0.897802,0.482602,0.739770,0.044569,0.468041,0.341742
1265,0.764453,0.174147,0.824906,0.453406,0.608007,0.842768,0.765355,0.648007,0.779927,0.729218,...,0.478843,0.438191,0.567193,0.274560,0.363217,0.512191,0.693135,0.110666,0.242758,0.675748
