## Implementation of standard VAE and training procedure

The implementation is inspired by: https://github.com/jmtomczak/intro_dgm/blob/main/vaes/vae_example.ipynb and https://github.com/DeepLearningDTU/02456-deep-learning-with-PyTorch/blob/master/7_Unsupervised/7.2-EXE-variational-autoencoder.ipynb

In [5]:
import os

import re
import random
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from sklearn.datasets import load_digits
from sklearn import datasets
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

from torchsummary import summary

In [3]:
!pip3 install torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1


### Dataset
The dataset contains 68x68 images of single cells treated with different compounds. Each of the utilized compounds has an associated mechanism of action (moa), which describes how the compound affects the cell. There are 12 different classes of moa.

In [6]:
# Load the metadata table with the pyarrow CSV engine and report how long
# the read took (wall-clock seconds).
start_time = time.time()
metadata = pd.read_csv('../data/metadata.csv', engine="pyarrow")
elapsed_seconds = time.time() - start_time
print("pd.read_csv wiht pyarrow took %s seconds" % elapsed_seconds)

pd.read_csv wiht pyarrow took 2185.364300966263 seconds


In [7]:
# Row labels of every entry whose mechanism of action is 'DMSO'.
is_dmso_row = metadata['moa'] == 'DMSO'
DMSO_indx = metadata.index[is_dmso_row]

# Pick 260360 distinct DMSO rows at random (no seed is set here, so the
# selection differs between runs) and remove them from the table.
DMSO_drop_indices = np.random.choice(DMSO_indx, size=260360, replace=False)

# reset_index() renumbers the remaining rows from 0; the previous row labels
# are kept in a new "index" column.
metadata_subsampled = metadata.drop(index=DMSO_drop_indices).reset_index()

In [8]:
# Number of rows per mechanism of action after subsampling, largest first.
(
    metadata_subsampled
    .groupby("moa")
    .size()
    .reset_index(name='counts')
    .sort_values(by="counts", ascending=False)
)

Unnamed: 0,moa,counts
10,Microtubule stabilizers,89157
1,Aurora kinase inhibitors,16810
4,DNA damage,16582
3,DMSO,16000
9,Microtubule destabilizers,15178
7,Epithelial,14955
6,Eg5 inhibitors,12525
8,Kinase inhibitors,11622
12,Protein synthesis,9715
0,Actin disruptors,7491


In [9]:
# Map from class name to class index
classes = {index: name for name, index in enumerate(metadata["moa"].unique())}
classes

{'DMSO': 0,
 'Microtubule stabilizers': 1,
 'Eg5 inhibitors': 2,
 'Epithelial': 3,
 'Actin disruptors': 4,
 'Microtubule destabilizers': 5,
 'Aurora kinase inhibitors': 6,
 'Protein degradation': 7,
 'DNA replication': 8,
 'DNA damage': 9,
 'Protein synthesis': 10,
 'Kinase inhibitors': 11,
 'Cholesterol-lowering': 12}

In [10]:
class SingleCellDataset(torch.utils.data.Dataset):
    """Dataset that loads one stored array and its class index per table row.

    Args:
        annotation_file: Table accessed via ``.loc[index, column]``; it must
            provide the columns "Single_Cell_Image_Name" and "moa".
            Despite the name, this is the table object itself, not a path.
        images_folder: Root directory containing one sub-directory per
            file-name prefix (see ``__getitem__``).
        class_map: Mapping from the value in the "moa" column to the label
            that is returned.
        transform: Optional callable applied to the loaded array after it has
            been cast to int16. When it is None the array is returned exactly
            as loaded, without the cast.
    """

    def __init__(self, annotation_file, images_folder, class_map, transform = None):
        self.df, self.images_folder = annotation_file, images_folder
        self.class2index = class_map
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the row whose index label is ``index``."""
        image_name = self.df.loc[index, "Single_Cell_Image_Name"]
        moa_name = self.df.loc[index, "moa"]
        target = self.class2index[moa_name]

        # The sub-directory is everything before the LAST underscore of the
        # file name (the regex group is greedy).
        subdir = re.search("(.*)_", image_name).group(1)
        pixels = np.load(os.path.join(self.images_folder, subdir, image_name))

        if self.transform is None:
            return pixels, target
        return self.transform(pixels.astype("int16")), target

### Code for the standard Variational Autoencoder (VAE)

#### Functions for probability distributions

In [61]:
PI = torch.from_numpy(np.asarray(np.pi))
EPS = 1.0e-5


def _reduce(log_p, reduction=None, dim=None):
    """Apply the optional reduction shared by the log-probability helpers.

    ``reduction`` is 'avg' (mean), 'sum', or anything else for no reduction.
    With ``dim=None`` the reduction runs over all elements.
    """
    if reduction == 'avg':
        return log_p.mean() if dim is None else log_p.mean(dim)
    if reduction == 'sum':
        return log_p.sum() if dim is None else log_p.sum(dim)
    return log_p


def log_categorical(x, p, num_classes=12, reduction=None, dim=None):
    """Log-probability of integer values ``x`` under categorical probabilities ``p``.

    Args:
        x: Tensor of category indices; it is cast to long and one-hot encoded
            with ``num_classes`` categories along a new last axis.
        p: Probabilities, broadcastable against the one-hot tensor. They are
            clamped to [EPS, 1 - EPS] before the logarithm to avoid log(0).
        num_classes: Number of categories used for the one-hot encoding.
        reduction: 'avg', 'sum' or None (no reduction).
        dim: Dimension(s) to reduce over; None reduces over all elements.

    Returns:
        The one-hot-masked log-probabilities (zero for non-selected
        categories), optionally reduced.
    """
    x_one_hot = F.one_hot(x.long(), num_classes=num_classes)
    log_p = x_one_hot * torch.log(torch.clamp(p, EPS, 1.0 - EPS))
    return _reduce(log_p, reduction, dim)


def log_normal_diag(x, mu, log_var, reduction=None, dim=None):
    """Element-wise log-density of a diagonal Gaussian N(mu, exp(log_var)).

    Each element contributes
    ``-0.5*log(2*pi) - 0.5*log_var - 0.5*(x - mu)**2 / exp(log_var)``,
    so summing over the D event dimensions yields the joint log-density.
    The normalising constant is therefore added once per element and NOT
    scaled by D here; scaling it per element would count it D*D times after
    the sum.

    Args:
        x: Points at which to evaluate the density.
        mu: Mean, broadcastable to ``x``.
        log_var: Log-variance, broadcastable to ``x``.
        reduction: 'avg', 'sum' or None (no reduction).
        dim: Dimension(s) to reduce over; None reduces over all elements.
    """
    log_p = (
        -0.5 * torch.log(2.0 * PI)
        - 0.5 * log_var
        - 0.5 * torch.exp(-log_var) * (x - mu) ** 2
    )
    return _reduce(log_p, reduction, dim)


def log_normal_standard(x, reduction=None, dim=None):
    """Element-wise log-density of the standard Gaussian N(0, 1).

    Each element contributes ``-0.5*log(2*pi) - 0.5*x**2``; as in
    ``log_normal_diag`` the constant is added once per element so that a sum
    over the event dimensions gives the joint log-density.

    Args:
        x: Points at which to evaluate the density.
        reduction: 'avg', 'sum' or None (no reduction).
        dim: Dimension(s) to reduce over; None reduces over all elements.
    """
    log_p = -0.5 * torch.log(2.0 * PI) - 0.5 * x ** 2
    return _reduce(log_p, reduction, dim)

#### Encoder class

In [74]:
class Encoder(nn.Module):
    """Gaussian encoder built around a network that outputs mean and log-variance.

    The wrapped ``encoder_net`` must return a tensor whose dimension 1 can be
    split into two equal chunks: the first is used as the mean, the second
    as the log-variance of a diagonal Gaussian over the latent ``z``.
    """

    def __init__(self, encoder_net):
        super().__init__()

        # The init of the encoder network
        self.encoder = encoder_net

    # The reparameterization trick for Gaussians
    @staticmethod
    def reparameterization(mu, log_var):
        """Return ``z = mu + std * epsilon`` with ``epsilon ~ Normal(0, 1)``."""
        # First, we need to get std from log-variance
        std = torch.exp(0.5 * log_var)

        # Second, we sample epsilon from Normal(0, 1)
        eps = torch.randn_like(std)

        # Finally, z is calculated
        return mu + std * eps

    # This function implements the output of the encoder network
    # (i.e., parameters of a Gaussian)
    def encode(self, x):
        """Run the encoder network and return ``(mu_e, log_var_e)``."""
        # First, we calculate the output of the encoder network of size 2M.
        h_e = self.encoder(x)

        # Second, we must divide the output to the mean and log-variance.
        mu_e, log_var_e = torch.chunk(h_e, 2, dim=1)

        return mu_e, log_var_e

    # Sampling procedure.
    def sample(self, x=None, mu_e=None, log_var_e=None):
        """Draw a latent sample via the reparameterization trick.

        Either pass ``x`` alone (the Gaussian parameters are then computed by
        ``encode``) or pass both ``mu_e`` and ``log_var_e``.

        Raises:
            ValueError: if exactly one of ``mu_e`` / ``log_var_e`` is given.
        """
        # If we do not provide a mean and a log-variance, we must first calculate them:
        if (mu_e is None) and (log_var_e is None):
            mu_e, log_var_e = self.encode(x)
        # Otherwise both must be present; a half-specified Gaussian would only
        # fail later with an obscure TypeError inside the reparameterization.
        elif (mu_e is None) or (log_var_e is None):
            raise ValueError('mu and log-var cannot be None!')

        return self.reparameterization(mu_e, log_var_e)

    # This function calculates the log-probability that is later used for calculating the ELBO.
    def log_prob(self, x=None, mu_e=None, log_var_e=None, z=None):
        """Return the element-wise log-density of ``z`` under the encoder Gaussian.

        If ``x`` is given, the parameters are computed from it and a fresh
        ``z`` is sampled (any ``mu_e`` / ``log_var_e`` / ``z`` arguments are
        then ignored). Otherwise all three must be provided.

        Raises:
            ValueError: if ``x`` is None and any of ``mu_e``, ``log_var_e``,
                ``z`` is missing.
        """
        # If we provide x alone, then we can calculate a corresponding sample.
        if x is not None:
            mu_e, log_var_e = self.encode(x)
            z = self.sample(mu_e=mu_e, log_var_e=log_var_e)
        # Otherwise, we should provide mu, log-var and z!
        elif (mu_e is None) or (log_var_e is None) or (z is None):
            raise ValueError('mu, log-var and z cannot be None!')

        return log_normal_diag(z, mu_e, log_var_e)

    # PyTorch forward pass: it is either log-probability (by default) or sampling.
    def forward(self, x, type='log_prob'):
        """Return ``log_prob(x)`` for type 'log_prob', otherwise ``sample(x)``.

        NOTE: type 'encode' returns a latent *sample*, not the
        ``(mu, log_var)`` pair produced by ``encode``. This is kept as-is
        because existing callers may rely on it.
        """
        assert type in ['encode', 'log_prob'], 'Type could be either encode or log_prob'
        if type == 'log_prob':
            return self.log_prob(x)
        return self.sample(x)

#### Decoder class

In [None]:
class Decoder(nn.Module):
    def __init__(self, decoder_net, distribution='categorical', num_vals=None):
        super(Decoder, self).__init__()
        
        # The decoder network.
        self.decoder = decoder_net
        # The distribution used for the decoder (it is categorical by default)
        self.distribution = distribution
        # The number of possible values. This is important for the categorical distribution.
        self.num_vals = num_vals
        
    def decode(self, z):
        h_d = self.decoder(z)
        
        if self.distribution == 'categorical':
            b = h_d.shape[0]
            d = h_d.shape[1]
            h_d = h_d