In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from torch.autograd import Variable
from torch.distributions.kl import kl_divergence
from torch.distributions.multivariate_normal import MultivariateNormal
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torchvision.utils import save_image

In [2]:
# Locations of the gzipped, tab-separated training expression matrices.
(
    tcga_train_file_location,
    target_train_file_location,
    gtex_train_file_location,
) = (
    'data/train_tcga_expression_matrix_processed.tsv.gz',
    'data/train_target_expression_matrix_processed.tsv.gz',
    'data/train_gtex_expression_matrix_processed.tsv.gz',
)

In [None]:
# Read the three training matrices, in the same order as the path variables.
tcga_df, target_df, gtex_df = (
    pd.read_table(path)
    for path in (
        tcga_train_file_location,
        target_train_file_location,
        gtex_train_file_location,
    )
)

In [None]:
# Preview the first two rows of the TCGA matrix.
tcga_df.head(n=2)

In [None]:
# Preview the first two rows of the TARGET matrix.
target_df.head(n=2)

In [None]:
# Preview the first two rows of the GTEx matrix.
gtex_df.head(n=2)

In [None]:
# Scale every feature column to [0, 1]. The VAE decoder ends in a sigmoid and is
# trained with binary cross-entropy, which expects targets in that range;
# z-scored inputs (StandardScaler) fall outside it and can make the BCE term
# negative.
scaler = MinMaxScaler()

In [None]:
# Keep only the expression columns (drop the sample identifier) and discard
# any rows that contain missing values. `gtex_df` itself is left untouched.
gtex_df_sort = gtex_df.drop(columns='sample_id').dropna()


In [None]:
# Fit the scaler on the GTEx features and apply it in one go. Note that the
# name is rebound from the DataFrame to the scaler's output here.
gtex_df_sort = scaler.fit(gtex_df_sort).transform(gtex_df_sort)

In [None]:
# Shuffled mini-batches of 100 samples over the scaled GTEx matrix (float32).
train_loader = torch.utils.data.DataLoader(
    dataset=torch.tensor(gtex_df_sort, dtype=torch.float32),
    batch_size=100,
    shuffle=True,
)

In [None]:
# Sanity check: (number of samples, number of features) after preprocessing.
np.shape(gtex_df_sort)

VAE model

In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder with a diagonal-Gaussian latent.

    The encoder maps ``input_dim`` features through the ``hidden_dim`` widths to
    ``2 * z_dim`` outputs, which are split into a mean and a log-variance. The
    decoder mirrors the encoder and ends in a sigmoid, so reconstructions lie
    in [0, 1].

    Args:
        input_dim: Number of features per sample.
        hidden_dim: Non-empty list of hidden-layer widths, in encoder order.
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_dim: int, hidden_dim: list, z_dim):
        super(VAE, self).__init__()

        # Stored on the instance so forward() does not depend on a
        # notebook-level global named `input_dim`.
        self.input_dim = input_dim
        self.z_dim = z_dim

        self.encoder_layers = nn.ModuleList([nn.Linear(input_dim, hidden_dim[0])])
        self.decoder_layers = nn.ModuleList([nn.Linear(hidden_dim[0], input_dim)])

        # Grow the encoder forwards and the decoder backwards so that the
        # decoder ends up as the mirror image of the encoder.
        for i in range(len(hidden_dim) - 1):
            self.encoder_layers.append(nn.Linear(hidden_dim[i], hidden_dim[i + 1]))
            self.decoder_layers.insert(0, nn.Linear(hidden_dim[i + 1], hidden_dim[i]))

        # The last encoder layer emits mean and log-variance side by side.
        self.encoder_layers.append(nn.Linear(hidden_dim[-1], 2 * z_dim))
        self.decoder_layers.insert(0, nn.Linear(z_dim, hidden_dim[-1]))

    def encoder(self, x):
        """Encode ``x`` and return ``(mu, log_var)``, each with last dim ``z_dim``."""
        last = len(self.encoder_layers) - 1
        for idx, layer in enumerate(self.encoder_layers):
            x = layer(x)
            if idx < last:
                # ReLU on every layer except the final (mu, log_var) projection.
                x = F.relu(x)
        return x[..., :self.z_dim], x[..., self.z_dim:]  # mu, log_var

    def sampling(self, mu, log_var):
        """Draw ``z = mu + std * eps`` with ``eps ~ N(0, I)`` (reparameterisation)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)  # z sample

    def decoder(self, z):
        """Decode latent ``z`` into a reconstruction squashed by a sigmoid."""
        last = len(self.decoder_layers) - 1
        for idx, layer in enumerate(self.decoder_layers):
            z = layer(z)
            if idx < last:
                z = F.relu(z)
        return torch.sigmoid(z)

    def forward(self, x):
        """Encode ``x``, sample a latent with ``rsample`` and decode it.

        Args:
            x: Input tensor; it is reshaped to ``(-1, input_dim)``.

        Returns:
            Tuple ``(reconstruction, latent)`` where ``latent`` is the
            ``MultivariateNormal`` posterior with diagonal scale
            ``exp(0.5 * log_var)``.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        mu, log_var = self.encoder(x.reshape(-1, self.input_dim))

        latent = MultivariateNormal(loc=mu,
                                    scale_tril=torch.diag_embed(torch.exp(0.5 * log_var)))
        z = latent.rsample()

        return self.decoder(z), latent

    @staticmethod
    def loss_function(recon_x, x, mu, log_var):
        """Summed BCE reconstruction loss plus closed-form KL to ``N(0, I)``.

        The feature count is taken from ``recon_x`` instead of a global.
        """
        n_features = recon_x.shape[-1]
        BCE = F.binary_cross_entropy(recon_x, x.reshape(-1, n_features), reduction='sum')
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return BCE + KLD

    @staticmethod
    def loss_function_dist(recon_x, x, latent, input_dim=None):
        """Summed BCE plus ``input_dim`` times the KL from ``latent`` to ``N(0, I)``.

        Args:
            recon_x: Decoder output with ``input_dim`` features in the last dim.
            x: Original batch; reshaped to ``(-1, input_dim)``.
            latent: Posterior distribution returned by ``forward``.
            input_dim: Feature count, also used as the KL weight. When omitted
                it is inferred from the last dimension of ``recon_x``.

        Returns:
            Scalar loss tensor.
        """
        if input_dim is None:
            input_dim = recon_x.shape[-1]

        # Build the prior on the same device/dtype as the posterior so the KL
        # also works when the model is not on the default device.
        loc = latent.mean
        z_dim = loc.shape[-1]
        prior = MultivariateNormal(
            loc=torch.zeros(z_dim, dtype=loc.dtype, device=loc.device),
            scale_tril=torch.eye(z_dim, dtype=loc.dtype, device=loc.device))

        BCE = F.binary_cross_entropy(recon_x, x.reshape(-1, input_dim), reduction='sum')
        KLD = torch.sum(kl_divergence(latent, prior))
        return BCE + input_dim * KLD

In [None]:
# Build the model, the optimizer, and the per-epoch loss history.

# Seed torch's global RNG so weight initialisation and the random draws made
# during training are repeatable from a fresh kernel.
SEED = 0
torch.manual_seed(SEED)

input_dim = gtex_df_sort.shape[1]
print(input_dim)

DEVICE = 'cpu'

vae = VAE(input_dim=input_dim, hidden_dim=[512, 512], z_dim=100)
vae.to(DEVICE)

optimizer = optim.Adam(vae.parameters())
avg_loss = []  # one entry per completed training epoch, appended by train()
def train(epoch, input_dim):
    """Run one optimisation pass over ``train_loader`` and log the mean loss.

    Uses the notebook-level ``vae``, ``optimizer``, ``train_loader``, ``DEVICE``
    and ``avg_loss``. Progress is printed every 100 batches, and the epoch's
    per-sample average loss is printed and appended to ``avg_loss``.

    Args:
        epoch: Epoch number, used only in the printed messages.
        input_dim: Feature count forwarded to ``VAE.loss_function_dist``.
    """
    vae.train()
    running_loss = 0

    for step, batch in enumerate(train_loader):
        batch = batch.to(DEVICE)

        optimizer.zero_grad()
        reconstruction, posterior = vae(batch)
        batch_loss = VAE.loss_function_dist(reconstruction, batch, posterior, input_dim)
        batch_loss.backward()
        optimizer.step()

        batch_loss_value = batch_loss.item()
        running_loss += batch_loss_value

        if step % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, step * len(batch), len(train_loader.dataset),
                100. * step / len(train_loader), batch_loss_value / len(batch)))

    # The absolute value of the summed loss is averaged over the dataset.
    # NOTE(review): abs() presumably hides negative totals - confirm whether
    # it is still needed.
    epoch_average = abs(running_loss) / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, epoch_average))
    avg_loss.append(epoch_average)

In [None]:
# Smoke check: push five random rows through the model to confirm that the
# forward pass runs and returns (reconstruction, latent distribution).
vae(torch.rand((5, gtex_df_sort.shape[1])))

In [None]:
#%%
def test(loader=None):
    """Evaluate the VAE on held-out batches and print the per-sample loss.

    Uses the notebook-level ``vae`` and ``DEVICE``.

    Args:
        loader: DataLoader yielding evaluation batches. When omitted, the
            notebook-level ``test_loader`` is used, which must then have been
            defined beforehand.
    """
    if loader is None:
        loader = test_loader

    vae.eval()
    test_loss = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(DEVICE)
            recon, latent = vae(data)

            # Sum up the batch loss. The feature count is passed explicitly
            # (loss_function_dist takes it as its fourth argument) and is read
            # from the reconstruction rather than from a global.
            test_loss += VAE.loss_function_dist(
                recon, data, latent, recon.shape[-1]).item()

    test_loss /= len(loader.dataset)
    print('====> Test set loss: {:.4f}'.format(abs(test_loss)))

In [None]:
#%%
# Train for a fixed number of epochs, numbered from 1.
# Evaluation via test() is left disabled.
N_EPOCHS = 4
for epoch in range(1, N_EPOCHS + 1):
    train(epoch, input_dim)

In [None]:
# Per-sample average training loss recorded by train(), one point per epoch.
fig, ax = plt.subplots()
ax.plot(range(1, len(avg_loss) + 1), avg_loss, marker='o')
ax.set(title='VAE training loss', xlabel='Epoch', ylabel='Average loss per sample')
plt.show()

In [None]:
# `summary` is not imported or defined anywhere in this notebook, so report the
# architecture with the module's own repr and count the trainable parameters.
print(vae)
print('Trainable parameters: {:,}'.format(
    sum(p.numel() for p in vae.parameters() if p.requires_grad)))