In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.transforms import ToTensor

from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.kl import kl_divergence
import torchvision
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import random
from random import sample
import seaborn as sns

In [2]:
tcga_tybalt_file_location = 'data/pancan_scaled_zeroone_rnaseq.tsv.gz'

In [3]:
rnaseq_df = pd.read_table(tcga_tybalt_file_location)
rnaseq_df.drop(columns=rnaseq_df.columns[0], axis=1,  inplace=True)
rnaseq_df = rnaseq_df.dropna()
print(rnaseq_df.shape)
rnaseq_df.head(2)

(10459, 5000)


Unnamed: 0,RPS4Y1,XIST,KRT5,AGR2,CEACAM5,KRT6A,KRT14,CEACAM6,DDX3Y,KDM5D,...,FAM129A,C8orf48,CDK5R1,FAM81A,C13orf18,GDPD3,SMAGP,C2orf85,POU5F1B,CHST2
0,0.678296,0.28991,0.03423,0.0,0.0,0.084731,0.031863,0.037709,0.746797,0.687833,...,0.44061,0.428782,0.732819,0.63434,0.580662,0.294313,0.458134,0.478219,0.168263,0.638497
1,0.200633,0.654917,0.181993,0.0,0.0,0.100606,0.050011,0.092586,0.103725,0.140642,...,0.620658,0.363207,0.592269,0.602755,0.610192,0.374569,0.72242,0.271356,0.160465,0.60256


In [4]:
test_set_percent = 0.1
rnaseq_df_test = rnaseq_df.sample(frac=test_set_percent)
rnaseq_df_train = rnaseq_df.drop(rnaseq_df_test.index)

In [5]:
# Define custom dataset class
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data.iloc[idx].values, dtype=torch.float32)


In [6]:
train_dataset = CustomDataset(rnaseq_df_train)
test_dataset = CustomDataset(rnaseq_df_test)

In [7]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
validation_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=True)

In [8]:
class VAE(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: list, z_dim):
        super(VAE, self).__init__()
        
        self.z_dim = z_dim
        
        self.encoder_layers = nn.ModuleList([nn.Linear(input_dim, hidden_dim[0])])
        self.decoder_layers = nn.ModuleList([nn.Linear(hidden_dim[0], input_dim)])
                
        if len(hidden_dim)>1:
            for i in range(len(hidden_dim)-1):
                self.encoder_layers.append(nn.Linear(hidden_dim[i], hidden_dim[i+1]))
                self.decoder_layers.insert(0, nn.Linear(hidden_dim[i+1], hidden_dim[i]))
                
        self.encoder_layers.append(nn.Linear(hidden_dim[-1], 2 * z_dim))
        self.batchnorm = nn.BatchNorm1d(z_dim)
        self.decoder_layers.insert(0, nn.Linear(z_dim, hidden_dim[-1]))

        
    def encoder(self, x):
        for idx, layer in enumerate(self.encoder_layers):
            x = layer(x)
            if idx < len(self.encoder_layers) - 1:
                # x = F.dropout(x, 0.01)
                x = F.relu(x)
                #x = nn.BatchNorm1d(x)
        return x[...,:self.z_dim], x[...,self.z_dim:] # mu, log_var
    
    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        # std = torch.abs(log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample
        
    def decoder(self, z):
        for idx, layer in enumerate(self.decoder_layers):
            z = layer(z)
            if idx < len(self.decoder_layers) - 1:
                # x = F.dropout(x, 0.01)
                z = F.relu(z)
        return torch.sigmoid(z) 
    
    def forward(self, x):
        mu, log_var = self.encoder(x.view(-1, input_dim))
        mu = self.batchnorm(mu)
        log_var = self.batchnorm(log_var)
    #    z = self.sampling(mu, log_var)
        latent = MultivariateNormal(loc = mu, 
                                    scale_tril=torch.diag_embed(torch.exp(0.5*log_var)))
        z = latent.rsample()
           
    #    return self.decoder(z), mu, log_var
        return self.decoder(z), latent

    @staticmethod
    def loss_function(recon_x, x, mu, log_var):
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, input_dim), reduction='sum')
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return BCE + KLD
    
    @staticmethod
    def loss_function_dist(recon_x, x, latent, input_dim):
        prior = MultivariateNormal(loc = torch.zeros(latent.mean.shape[1]),
                                   scale_tril=torch.eye(latent.mean.shape[1]))
        
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, input_dim), reduction='sum')
        KLD = torch.sum(kl_divergence(latent, prior))
        return BCE + KLD

VAE model

In [11]:
 def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(train_loader):
        
        # Every data instance 
        data = data.to(DEVICE)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        recon_batch, latent = vae(data)

        # Compute the loss and its gradients
        loss = VAE.loss_function_dist(recon_batch, data, latent, input_dim)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 100 == 99:
            last_loss = running_loss / 100.0 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(train_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss

In [12]:
# build model
input_dim=rnaseq_df.shape[1]
#vae = VAE(input_dim=input_dim, hidden_dim=[100,100], z_dim=100)
vae = VAE(input_dim=input_dim, hidden_dim=[100], z_dim=100)
# if torch.backends.mps.is_available():
#     DEVICE = 'mps'
# else:
#train_loader = torch.utils.data.DataLoader(dataset=torch.Tensor(torch.randn(30, 5000)), batch_size=100, shuffle=True)

DEVICE = 'cpu'
    
vae.to(DEVICE)

optimizer = optim.Adam(vae.parameters(), lr=0.0005)


# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/tcga_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 20

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    vae.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    # We don't need gradients on to do reporting
    vae.train(False)

    running_vloss = 0.0
    for i, vdata in enumerate(validation_loader):
        vinputs = vdata
        voutputs, latent = vae(vinputs)
        
        vloss = VAE.loss_function_dist(voutputs, vinputs, latent, input_dim)
        running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        #torch.save(vae.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:
  batch 100 loss: 102268.02375
  batch 200 loss: 94172.823671875
LOSS train 94172.823671875 valid 107453.6328125
EPOCH 2:
  batch 100 loss: 91474.396328125
  batch 200 loss: 90896.451015625
LOSS train 90896.451015625 valid 101955.5625
EPOCH 3:
  batch 100 loss: 90095.656953125
  batch 200 loss: 90095.94796875
LOSS train 90095.94796875 valid 97991.1875
EPOCH 4:
  batch 100 loss: 89608.863828125
  batch 200 loss: 89559.135703125
LOSS train 89559.135703125 valid 95764.1875
EPOCH 5:
  batch 100 loss: 89406.8259375
  batch 200 loss: 89227.239765625
LOSS train 89227.239765625 valid 94891.2578125
EPOCH 6:
  batch 100 loss: 89022.93625
  batch 200 loss: 89222.17609375
LOSS train 89222.17609375 valid 96044.8984375
EPOCH 7:
  batch 100 loss: 89043.665078125
  batch 200 loss: 88941.221640625
LOSS train 88941.221640625 valid 95405.5625
EPOCH 8:
  batch 100 loss: 88708.440546875
  batch 200 loss: 88863.93671875
LOSS train 88863.93671875 valid 95216.6796875
EPOCH 9:
  batch 100 loss: 88784.

Shuffle the colums of data

In [13]:
torch.save(vae.state_dict(), 'vae_weights.pth')

BioBomb data

In [None]:
tcga_biobomb_file_location = 'data/rescaled_5000_gtex_df_sort.tsv.gz'

In [None]:
tcga_df = pd.read_table(tcga_biobomb_file_location)
tcga_df.drop(columns=tcga_df.columns[0], axis=1,  inplace=True)
tcga_df = tcga_df.dropna()
print(tcga_df.shape)
tcga_df.head(2)

In [171]:
test_set_percent = 0.1
tcga_df_test = rnaseq_df.sample(frac=test_set_percent)
tcga_df_train = rnaseq_df.drop(tcga_df_test.index)

In [172]:
train_dataset_tcga = CustomDataset(tcga_df_train)
test_dataset_tcga = CustomDataset(tcga_df_test)

In [173]:
train_loader_tcga = torch.utils.data.DataLoader(dataset=train_dataset_tcga, batch_size=100, shuffle=True)
validation_loader_tcga = torch.utils.data.DataLoader(dataset=test_dataset_tcga, batch_size=100, shuffle=True)

In [174]:
# build model
input_dim=rnaseq_df.shape[1]

vae_tcga = VAE(input_dim=input_dim, hidden_dim=[100], z_dim=100)

# if torch.backends.mps.is_available():
#     DEVICE = 'mps'
# else:
#train_loader = torch.utils.data.DataLoader(dataset=torch.Tensor(torch.randn(30, 5000)), batch_size=100, shuffle=True)

DEVICE = 'cpu'
    
vae_tcga.to(DEVICE)

optimizer = optim.Adam(vae_tcga.parameters(), lr=0.0005)


# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/tcga_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 50

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    vae_tcga.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    # We don't need gradients on to do reporting
    vae_tcga.train(False)

    running_vloss = 0.0
    for i, vdata in enumerate(validation_loader):
        vinputs = vdata
        voutputs, latent = vae_tcga(vinputs)
        
        vloss = VAE.loss_function_dist(voutputs, vinputs, latent, input_dim)
        running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    avg_vloss = avg_vloss/100
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        #torch.save(vae.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:
  batch 1 loss: 2978.0325
LOSS train 2978.0325 valid 3333.0634765625
EPOCH 2:
  batch 1 loss: 2968.221875
LOSS train 2968.221875 valid 3333.6884765625
EPOCH 3:
  batch 1 loss: 2990.526875
LOSS train 2990.526875 valid 3332.948974609375
EPOCH 4:
  batch 1 loss: 3003.7159375
LOSS train 3003.7159375 valid 3333.268798828125
EPOCH 5:
  batch 1 loss: 2992.02
LOSS train 2992.02 valid 3332.671142578125
EPOCH 6:
  batch 1 loss: 2973.1271875
LOSS train 2973.1271875 valid 3333.08740234375
EPOCH 7:
  batch 1 loss: 3001.1975
LOSS train 3001.1975 valid 3333.07958984375
EPOCH 8:
  batch 1 loss: 3001.526875
LOSS train 3001.526875 valid 3333.5986328125
EPOCH 9:
  batch 1 loss: 2986.7059375
LOSS train 2986.7059375 valid 3332.742919921875
EPOCH 10:
  batch 1 loss: 2984.953125
LOSS train 2984.953125 valid 3332.91748046875
EPOCH 11:
  batch 1 loss: 2986.73125
LOSS train 2986.73125 valid 3332.461669921875
EPOCH 12:
  batch 1 loss: 2983.706875
LOSS train 2983.706875 valid 3332.802734375
EPOCH 13:
  b

In [21]:
torch.save(vae.state_dict(), 'vae_weights.pth')