## Implementation of standard VAE and trainning procedure

The implementation inspired by: https://github.com/jmtomczak/intro_dgm/blob/main/vaes/vae_example.ipynb and https://github.com/DeepLearningDTU/02456-deep-learning-with-PyTorch/blob/master/7_Unsupervised/7.2-EXE-variational-autoencoder.ipynb

In [1]:
import os

import re
import random
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from sklearn.datasets import load_digits
from sklearn import datasets
from torch.utils.data import Dataset, DataLoader, RandomSampler
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

from torchsummary import summary

### Dataset
The dataset contain 68x68 images of single cells treated with different compounds. For each of the utilized compounds there is an associated mechanism of action (moa), which describes how the compound it affecting the cell. There are 12 different classes of moa.

In [2]:
start_time = time.time()
metadata = pd.read_csv('/Users/mikkelrasmussen/mnt/deep_learning_project/data/metadata.csv', engine="pyarrow")
print("pd.read_csv wiht pyarrow took %s seconds" % (time.time() - start_time))

pd.read_csv wiht pyarrow took 775.9014341831207 seconds


In [3]:
DMSO_indx = metadata.index[metadata['moa'] == 'DMSO']
DMSO_drop_indices = np.random.choice(DMSO_indx, size=260360, replace=False)

metadata_subsampled = metadata.drop(DMSO_drop_indices).reset_index()

In [4]:
metadata_subsampled.groupby("moa").size().reset_index(name='counts').sort_values(by="counts", ascending=False)

Unnamed: 0,moa,counts
10,Microtubule stabilizers,89157
1,Aurora kinase inhibitors,16810
4,DNA damage,16582
3,DMSO,16000
9,Microtubule destabilizers,15178
7,Epithelial,14955
6,Eg5 inhibitors,12525
8,Kinase inhibitors,11622
12,Protein synthesis,9715
0,Actin disruptors,7491


In [5]:
# Map from class name to class index
classes = {index: name for name, index in enumerate(metadata["moa"].unique())}
classes_inv = {v: k for k, v in classes.items()}
classes

{'DMSO': 0,
 'Microtubule stabilizers': 1,
 'Eg5 inhibitors': 2,
 'Epithelial': 3,
 'Actin disruptors': 4,
 'Microtubule destabilizers': 5,
 'Aurora kinase inhibitors': 6,
 'Protein degradation': 7,
 'DNA replication': 8,
 'DNA damage': 9,
 'Protein synthesis': 10,
 'Kinase inhibitors': 11,
 'Cholesterol-lowering': 12}

In [52]:
class SingleCellDataset(torch.utils.data.Dataset):
    def __init__(self, annotation_file, images_folder, class_map, 
                 mode='train', transform = None):
        self.df = annotation_file
        self.images_folder = images_folder
        self.transform = transform
        self.class2index = class_map
            

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        filename = self.df.loc[index, "Single_Cell_Image_Name"]
        label = self.class2index[self.df.loc[index, "moa"]]
        subfolder = re.search("(.*)_", filename).group(1)
        image = np.load(os.path.join(self.images_folder, subfolder, filename))
        if self.transform is not None:
            image = self.transform(image.astype(np.float32))
        return image, label

In [37]:
class Digits(Dataset):
    """Scikit-Learn Digits dataset."""

    def __init__(self, mode='train', transforms=None):
        digits = load_digits()
        if mode == 'train':
            self.data = digits.data[:1000].astype(np.float32)
        elif mode == 'val':
            self.data = digits.data[1000:1350].astype(np.float32)
        else:
            self.data = digits.data[1350:].astype(np.float32)

        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transforms:
            sample = self.transforms(sample)
        return sample

### Code for the standard Variational Autoencoder (VAE)

#### Functions for probability distributions

In [230]:
PI = torch.from_numpy(np.asarray(np.pi))
EPS = 1.0e-5

def log_categorical(x, p, num_classes=12, reduction=None, dim=None):
    x_one_hot = F.one_hot(x.long(), num_classes=-1)
    log_p = x_one_hot * torch.log(torch.clamp(p, EPS, 1.0 - EPS))
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p


def log_bernoulli(x, p, reduction=None, dim=None):
    pp = torch.clamp(p, EPS, 1.0 - EPS)
    log_p = x * torch.log(pp) + (1.0 - x) * torch.log(1.0 - pp)
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p
    

def log_normal_diag(x, mu, log_var, reduction=None, dim=None):
    D = x.shape[1]
    log_p = -0.5 * D * torch.log(2.0 * PI) - 0.5 * log_var - 0.5 * torch.exp(-log_var) * (x - mu)**2
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p


def log_standard_normal(x, reduction=None, dim=None):
    D = x.shape[1]
    log_p = -0.5 * D * torch.log(2.0 * PI) - 0.5 * x**2
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p
    

#### Encoder class

In [231]:
class Encoder(nn.Module):
    def __init__(self, encoder_net):
        super(Encoder, self).__init__()
        
        # The init of the encoder network
        self.encoder = encoder_net
    
    # The reparameterization trick for Gaussians
    @staticmethod
    def reparameterization(mu, log_var):
        # The formula is the following:
        # z = mu + std * epsilon
        
        # First, we need to get std from log-variance
        std = torch.exp(0.5*log_var)
        
        # Second, we sample epsilon from Normal(0, 1)
        eps = torch.randn_like(std)
        
        # Finally, z is calculated
        z = mu + std * eps
        
        return z
    
    # This function implements the output of the encoder network 
    # (i.e., parameters of a Gaussian)
    def encode(self, x):
        # First, we calculate the output of the encoder network of size 2M.
        h_e = self.encoder(x)
        
        # Second, we must divide the output to the mean and log-variance.
        mu_e, log_var_e = torch.chunk(h_e, 2, dim=1)
        
        return mu_e, log_var_e
    
    # Sampling procedure.
    def sample(self, x=None, mu_e=None, log_var_e=None):
        # If we do not provide a mean an a log-variance, we must first calculate it:
        if (mu_e is None) and (log_var_e is None):
            mu_e, log_var_e = self.encode(x)
        
        # Or the final sample
        else:
            # Otherwise, we can simply apply the reparameterization trick!
            if (mu_e is None) and (log_var_e is None):
                raise ValueError('mu and log-var cannot be None!')
        
        z = self.reparameterization(mu_e, log_var_e)
        return z
    
    # This function calculates the log-probability that is later used for calculating the ELBO.
    def log_prob(self, x=None, mu_e=None, log_var_e=None, z=None):
        # If we provide x alone, then we can calculate a corresponding sample.
        if x is not None:
            mu_e, log_var_e = self.encode(x)
            z = self.sample(mu_e=mu_e, log_var_e=log_var_e)
        else:
        # Otherwise, we should provide mu, log-var and z!
            if (mu_e is None) or (log_var_e is None) or (z is None):
                raise ValueError('mu, log-var and z cannot be None!')
                
        return log_normal_diag(z, mu_e, log_var_e)
    
    # PyTorch forward pass: it is either log-probability (by default) or sampling.
    def forward(self, x, type='log_prob'):
        assert type in ['encode', 'log_prob'], 'Type could be either encode or log_prob'
        if type == 'log_prob':
            return self.log_prob(x)
        else:
            return self.sample(x)            

#### Decoder class

In [232]:
class Decoder(nn.Module):
    def __init__(self, decoder_net, distribution='categorical', num_vals=None):
        super(Decoder, self).__init__()
        
        # The decoder network.
        self.decoder = decoder_net
        # The distribution used for the decoder (it is categorical by default)
        self.distribution = distribution
        # The number of possible values. This is important for the categorical distribution.
        self.num_vals = num_vals
    
    # This function calculates parameters of the likelihood function p(x|z)
    def decode(self, z):
        # First, we apply the decoder network.
        h_d = self.decoder(z)
        
        # We will mainly use the categorical distribution.
        if self.distribution == 'categorical':
            # We save the shapes: batch size
            b = h_d.shape[0]
            # and the dimensionality of x.
            d = h_d.shape[1]//self.num_vals
            # Then we reshape to (Batch size, Dimensionality, Number of Values)
            h_d = h_d.view(b, d, self.num_vals)
            # To get probabilities, we apply softmax
            mu_d = torch.softmax(h_d, 2)
            return [mu_d]
        
        # ... however, we also present the Bernoulli distribution.
        elif self.distribution == 'bernoulli':
            # In the Bernoulli case, we have x_d \in {0, 1}.
            # Therefore, it is enough to output a single probability,
            # because p(x_d=1|z) = \theta and p(x_d=0|z) = 1 - \theta
            mu_d = torch.sigmoid(h_d)
            return [mu_d]
        else:
            raise ValueError('Distribution must be either: categorical or bernoulli')
        
    # This function implements sampling from the decoder
    def sample(self, z):
        outs = self.decode(z)
        
        if self.distribution == 'categorical':
            # We take the output of the decoder
            mu_d = outs[0]
            # and save shapes (we will need that for reshaping).
            b = mu_d.shape[0]
            m = mu_d.shape[1]
            # Here we use reshaping
            mu_d = mu_d.view(b, -1, self.num_vals)
            p = mu_d.view(-1, self.num_vals)
            # Eventually, we sample from the categorical (the built-in PyTorch function)
            x_new = torch.multinomial(p, num_samples=1).view(b, m)
            return x_new
            
        elif self.distribution == 'bernoulli':
            # In the case of Bernoulli, we do not need any reshaping
            mu_d = outs[0]
            # and we can use the built-in PyTorch sampler!
            x_new = torch.bernoulli(mu_d)
            return x_new
    
        else: 
            raise ValueError('Distribution must be either: categorical or bernoulli')
            
    # This function calculates the conditional log-likelihood function.
    def log_prob(self, x, z):
        outs = self.decode(z)
        
        if self.distribution == 'categorical':
            mu_d = outs[0]
            log_p = log_categorical(x, mu_d, num_classes=self.num_vals, 
                                    reduction='sum', dim=-1).sum(-1)
            
        elif self.distribution == 'bernoulli':
            mu_d = outs[0]
            log_p = log_bernoulli(x, mu_d, reduction='sum', dim=-1)
            
        else:
            raise ValueError('Distribution must be either: categorical or bernoulli')
        
        return log_p
    
    # The forward pass is either a log-prob or a sample.
    def forward(self, z, x=None, type='log_prob'):
        assert type in ['decoder', 'log_prob'], 'Type could be either decode or log_prob'
        if type == 'log_prob':
            return self.log_prob(x, z)
        else:
            return self.sample(x)

#### Prior class

In [233]:
# The current implementation of the prior is very simple, namely, it is a standard Gaussian.
# We could have used a built-in PyTorch distribution. However, we did not do that for two reasons:
#    (i): It is important to think of the prior as a crucial component in VAEs
#    (ii): We can implement a learnable prior (e.g. a flow-based prior, VampPrior, a mixture of distributions)

class Prior(nn.Module):
    def __init__(self, L):
        super(Prior, self).__init__()
        self.L = L
    
    def sample(self, batch_size):
        z = torch.randn((batch_size, self.L))
        return z
    
    def log_prob(self, z):
        return log_standard_normal(z)

#### Full VAE class

In [234]:
class VAE(nn.Module):
    def __init__(self, encoder_net, decoder_net, num_vals=256, L=16, likelihood_type='categorical'):
        super(VAE, self).__init__()
        
        self.encoder = Encoder(encoder_net=encoder_net)
        self.decoder = Decoder(distribution=likelihood_type,
                              decoder_net=decoder_net, num_vals=num_vals)
        self.prior = Prior(L=L)
        self.num_vals = num_vals
        self.likelihood_type = likelihood_type
        
    def forward(self, x, reduction='avg'):
        # Encoder
        mu_e, log_var_e = self.encoder.encode(x)
        z = self.encoder.sample(mu_e=mu_e, log_var_e=log_var_e)
        
        # ELBO        
        RE = self.decoder.log_prob(x, z)
        KL = (self.prior.log_prob(z) - self.encoder.log_prob(mu_e=mu_e, log_var_e=log_var_e, z=z)).sum(-1)
        
        if reduction == 'sum':
            return -(RE + KL).sum()
        elif reduction == 'avg':
            return -(RE + KL).mean()
        else:
            raise ValueError('reduction must be either: sum or avg')
    
    def sample(self, batch_size=64):
        z = self.prior.sample(batch_size=batch_size)
        return self.decoder.sample(z)

#### Auxiliary functions for training, evaluation and plotting

In [235]:
def evaluation(test_loader, name=None, model_best=None, epoch=None,
               device='cpu'):
    # EVALUATION
    if model_best is None:
        # load best performing model
        model_best = torch.load(name + '.model')

    model_best.eval()
    loss = 0.
    N = 0.
    for indx_batch, (test_batch, test_target) in enumerate(test_loader):
        test_batch = test_batch.to(device)
        
        loss_t = model_best.forward(test_batch, reduction='sum')
        loss = loss + loss_t.item()
        N = N + test_batch.shape[0]
    loss = loss / N

    if epoch is None:
        print(f'FINAL LOSS: nll={loss}')
    else:
        print(f'Epoch: {epoch}, val nll={loss}')

    return loss

def samples_real(name, test_loader):
    # REAL-------
    num_x = 4
    num_y = 4
    x = next(iter(test_loader)).detach().numpy()

    fig, ax = plt.subplots(num_x, num_y)
    for i, ax in enumerate(ax.flatten()):
        plottable_image = np.reshape(x[i], (8, 8))
        ax.imshow(plottable_image, cmap='gray')
        ax.axis('off')

    plt.savefig(name+'_real_images.pdf', bbox_inches='tight')
    plt.close()
    

def samples_generated(name, data_loader, extra_name=''):
    x = next(iter(data_loader)).detach().numpy()

    # GENERATIONS-------
    model_best = torch.load(name + '.model')
    model_best.eval()

    num_x = 4
    num_y = 4
    x = model_best.sample(num_x * num_y)
    x = x.detach().numpy()

    fig, ax = plt.subplots(num_x, num_y)
    for i, ax in enumerate(ax.flatten()):
        plottable_image = np.reshape(x[i], (8, 8))
        ax.imshow(plottable_image, cmap='gray')
        ax.axis('off')

    plt.savefig(name + '_generated_images' + extra_name + '.pdf', bbox_inches='tight')
    plt.close()
    

def plot_curve(name, nll_val):
    plt.plot(np.arange(len(nll_val)), nll_val, linewidth='3')
    plt.xlabel('epochs')
    plt.ylabel('nll')
    plt.savefig(name + '_nll_val_curve.pdf', bbox_inches='tight')
    plt.close()

In [236]:
def training(name, max_patience, num_epochs, model, optimizer, training_loader, val_loader):
    nll_val = []
    best_nll = 1000.
    patience = 0
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f">> Using device: {device}")

    # move the model to the device
    model = model.to(device)

    # Main loop
    for e in range(num_epochs):
        # TRAINING
        model.train()
        for indx_batch, (batch, target) in enumerate(training_loader):
            batch = batch.to(device)
            
            if hasattr(model, 'dequantization'):
                if model.dequantization:
                    batch = batch + torch.rand(batch.shape)
                    
            loss = model.forward(batch)

            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()

        # Validation
        loss_val = evaluation(val_loader, model_best=model, epoch=e, device=device)
        nll_val.append(loss_val)  # save for plotting

        if e == 0:
            print('saved!')
            torch.save(model, name + '.model')
            best_nll = loss_val
        else:
            if loss_val < best_nll:
                print('saved!')
                torch.save(model, name + '.model')
                best_nll = loss_val
                patience = 0

                samples_generated(name, val_loader, extra_name="_epoch_" + str(e))
            else:
                patience = patience + 1

        if patience > max_patience:
            break

    nll_val = np.asarray(nll_val)

    return nll_val

#### Initialize dataloaders

In [237]:
batch_size = 10
train_data = Digits(mode='train')
val_data = Digits(mode='val')
test_data = Digits(mode='test')

training_loader_digit = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader_digit = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader_digit = DataLoader(test_data, batch_size=batch_size, shuffle=False)

result_dir = 'results/'
if not(os.path.exists(result_dir)):
    os.mkdir(result_dir)
name = 'vae'


images_folder = "/Users/mikkelrasmussen/mnt/deep_learning_project/data/singh_cp_pipeline_singlecell_images"
train_transforms = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Lambda(lambda x: torch.flatten(x)),
     transforms.Lambda(lambda x: x/x.max())]
)
train_set = SingleCellDataset(images_folder=images_folder, 
                              annotation_file=metadata_subsampled, 
                              transform=train_transforms,
                              class_map=classes)

#sampler = RandomSampler(train_set, replacement=False, num_samples=100)
#train_dataloader = DataLoader(train_set, sampler=sampler, 
#                              batch_size=batch_size, drop_last=True)

# Define the size of the train, validation and test datasets
data_prct = 0.01
train_prct = 0.8

data_amount = int(len(metadata_subsampled) * data_prct)
train_size = int(train_prct * data_amount)
val_size = (data_amount - train_size) // 2
test_size = (data_amount - train_size) // 2

indicies = torch.randperm(len(metadata_subsampled))
train_indices = indicies[:train_size]
val_indicies = indicies[train_size:train_size+val_size]
test_indicies = indicies[train_size+val_size:train_size+val_size+test_size]

# Checking there are not overlapping incdicies
print(sum(np.isin(train_indices.numpy() , [val_indicies.numpy(), test_indicies.numpy()])))
print(sum(np.isin(val_indicies.numpy() , [train_indices.numpy(), test_indicies.numpy()])))
print(sum(np.isin(test_indicies.numpy() , [train_indices.numpy(), val_indicies.numpy()])))

training_set = torch.utils.data.Subset(train_set, train_indices.tolist())
val_set = torch.utils.data.Subset(train_set, val_indicies.tolist())
testing_set = torch.utils.data.Subset(train_set, test_indicies.tolist())

training_loader = torch.utils.data.DataLoader(training_set, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(testing_set, batch_size=batch_size, shuffle=True)

print(len(training_loader.dataset))
print(len(val_loader.dataset))
print(len(test_loader.dataset))

0
0
0
1824
228
228


#### Hyperparameters

In [243]:
D = 64   # input dimension
L = 16  # number of latents
M = 256  # the number of neurons in scale (s) and translation (t) nets

lr = 1e-3 # learning rate
num_epochs = 1000 # max. number of epochs
max_patience = 20 # an early stopping is used, if training doesn't improve for longer than 20 epochs, it is stopped

#### Initialize VAE

In [246]:
likelihood_type = 'categorical'

if likelihood_type == 'categorical':
    num_vals = 17
elif likelihood_type == 'bernoulli':
    num_vals = 1

encoder = nn.Sequential(nn.Linear(D, M), nn.LeakyReLU(),
                        nn.Linear(M, M), nn.LeakyReLU(),
                        nn.Linear(M, 2 * L))

decoder = nn.Sequential(nn.Linear(L, M), nn.LeakyReLU(),
                        nn.Linear(M, M), nn.LeakyReLU(),
                        nn.Linear(M, num_vals * D))

prior = torch.distributions.MultivariateNormal(torch.zeros(L), torch.eye(L))
model = VAE(encoder_net=encoder, decoder_net=decoder, num_vals=num_vals, L=L, likelihood_type=likelihood_type)

# Print the summary (like in Keras)
print("ENCODER:\n")
print(summary(encoder, input_size=(1, D, D)))
print("\nDECODER:\n")
print(summary(decoder, input_size=(1, L, L)))

ENCODER:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1           [-1, 1, 64, 256]          16,640
         LeakyReLU-2           [-1, 1, 64, 256]               0
            Linear-3           [-1, 1, 64, 256]          65,792
         LeakyReLU-4           [-1, 1, 64, 256]               0
            Linear-5            [-1, 1, 64, 32]           8,224
Total params: 90,656
Trainable params: 90,656
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 0.52
Params size (MB): 0.35
Estimated Total Size (MB): 0.88
----------------------------------------------------------------
None

DECODER:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1           [-1, 1, 16, 256]           4,352
         LeakyReL

In [240]:
# OPTIMIZER
optimizer = torch.optim.Adamax([p for p in model.parameters() if p.requires_grad == True], lr=lr)

In [241]:
# Training procedure
nll_val = training(name=result_dir + name, 
                   max_patience=max_patience, 
                   num_epochs=1, 
                   model=model, 
                   optimizer=optimizer,
                   training_loader=training_loader, 
                   val_loader=val_loader)

>> Using device: cpu


KeyboardInterrupt: 

In [242]:
test_loss = evaluation(name=result_dir + name, test_loader=test_loader)
f = open(result_dir + name + '_test_loss.txt', "w")
f.write(str(test_loss))
f.close()

samples_real(result_dir + name, test_loader)

plot_curve(result_dir + name, nll_val)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x13872 and 64x256)

In [35]:
indx_batch, batch = next(iter(enumerate(training_loader)))

In [26]:
batch

tensor([[ 0.,  1., 15.,  ..., 16., 16.,  2.],
        [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
        [ 0.,  0.,  2.,  ..., 16.,  8.,  0.],
        ...,
        [ 0.,  0.,  7.,  ..., 14.,  6.,  0.],
        [ 0.,  0.,  0.,  ..., 15.,  8.,  0.],
        [ 0.,  0.,  0.,  ...,  6.,  0.,  0.]])

In [90]:
batch.shape

torch.Size([64, 64])

In [41]:
batch_new = next(iter(train_dataloader))

In [44]:
batch_new

[tensor([[  704.,   720.,   688.,  ...,  7584.,  8112.,  8608.],
         [  816.,   848.,   784.,  ...,  7328.,  8208., 10288.],
         [  544.,   480.,   528.,  ...,  1232.,  1264.,  1312.],
         ...,
         [  672.,   640.,   624.,  ...,  5616.,  5328.,  4976.],
         [ 2048.,  2112.,  2320.,  ...,  2496.,  2496.,  2320.],
         [  544.,   592.,   624.,  ...,  8496., 10272., 11872.]]),
 tensor([ 4,  1,  1,  0,  1, 11,  1,  9,  1,  1,  8,  1,  1,  5, 10,  1,  6,  0,
          2, 12,  3,  1, 10,  1,  2,  5, 10,  3,  3,  3,  3,  2, 11,  5,  1,  9,
          3,  2,  1,  4, 11,  1,  2,  1,  1, 12,  3,  1,  2,  7, 12,  1,  1,  6,
          0,  1,  1,  1, 12,  9,  9,  5,  1,  4])]