# VAE and VSC for Cell Images

In [None]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from torch import cuda
from torch import nn, optim
from torch.utils import data
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [None]:
data_dir = '../../../Data/'
results_dir = '../../../results/epoch100beta5e-1/'
batch_size = 128

# Create the output folder up front: the plt.savefig() calls in test() and in
# the training loop fail if the target directory does not exist.
os.makedirs(results_dir, exist_ok=True)

torch.manual_seed(22)
device = torch.device("cuda" if cuda.is_available() else "cpu")
print(device)

The following code is a utility to load and split the data, given the path where the data is stored. PyTorch's `ImageFolder` keeps the folder names as labels, as if this were a classification task, but the labels can simply be ignored when using the data loaders. Since the task at hand comes from a very specific domain (biology and cell images), the normalization values commonly used for computer vision tasks cannot be used. The mean and std used to normalize the dataset were precomputed from a small sample of the dataset and might not be accurate, but they seem to work well for now.

In [None]:
def split_data(data_dir, n_split=0.2, batch_size=256):
    """Load an ImageFolder dataset and split it into train/val/test loaders.

    A fraction ``n_split`` of all images is held out as the test set, and the
    same fraction of the remainder becomes the validation set. The class
    labels produced by ImageFolder are still yielded by the loaders but are
    not needed for the VAE.

    Args:
        data_dir: root folder in ``torchvision.datasets.ImageFolder`` layout.
        n_split: fraction used for both the test split and the validation split.
        batch_size: batch size of all three loaders.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    pin_memory = cuda.is_available()
    # NOTE(review): 0 worker processes with CUDA and 4 without looks inverted
    # relative to the usual choice; kept as-is, confirm it is intentional.
    workers = 0 if cuda.is_available() else 4

    # Per-channel mean/std precomputed for this dataset (see the "Extra" cell
    # at the end of the notebook); imshow() undoes the same normalization.
    image_dataset = datasets.ImageFolder(data_dir, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.0302, 0.0660, 0.0518), (0.0633, 0.0974, 0.0766))
    ]))

    # Hold out the test set first, then carve validation out of the rest.
    total = len(image_dataset)
    n_test = int(total * n_split)
    train_set, test_set = data.random_split(image_dataset, (total - n_test, n_test))

    n_val = int(len(train_set) * n_split)
    train_set, val_set = data.random_split(train_set, (len(train_set) - n_val, n_val))

    print('Train split: ', len(train_set))
    print('Val split: ', len(val_set))
    print('Test split: ', len(test_set))

    def make_loader(subset, shuffle):
        # All loaders share batch size, worker count and pinning policy.
        return data.DataLoader(
            subset,
            batch_size=batch_size,
            num_workers=workers,
            shuffle=shuffle,
            pin_memory=pin_memory
        )

    # Only training data needs reshuffling each epoch. A fixed order for
    # val/test makes evaluation reproducible. It also means the per-epoch
    # reconstruction images (taken from the first batch) show the same cells
    # every epoch.
    return (
        make_loader(train_set, shuffle=True),
        make_loader(val_set, shuffle=False),
        make_loader(test_set, shuffle=False),
    )

Helper function to view tensors as a plot

In [None]:
def imshow(inp, title=None, mean=(0.0302, 0.0660, 0.0518), std=(0.0633, 0.0974, 0.0766)):
    """Plot a channels-first image tensor after undoing the dataset normalization.

    Args:
        inp: image tensor laid out as (C, H, W), e.g. the output of
            ``torchvision.utils.make_grid``. It may live on the GPU or still
            be attached to the autograd graph.
        title: optional figure title.
        mean, std: per-channel statistics of the Normalize transform; the
            defaults are the values used in split_data().
    """
    # detach()/cpu(): Tensor.numpy() refuses CUDA tensors and tensors that
    # require grad (e.g. a reconstruction computed outside torch.no_grad()).
    inp = inp.detach().cpu().numpy().transpose((1, 2, 0))
    # Invert Normalize: x = std * x_norm + mean, broadcast over H x W x C.
    inp = np.array(std) * inp + np.array(mean)
    inp = np.clip(inp, 0, 1)
    plt.figure(figsize=(15, 15))
    plt.imshow(inp)
    if title is not None:
        plt.title(title)

## Variational AutoEncoder

The previous work this is based on had a specific VAE architecture implemented in TensorFlow. The first step was to migrate that network to PyTorch, keeping the same layers and the same operations in the encoder and decoder, as well as the same loss function.

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder migrated from a TensorFlow model.

    Encoder: three stages of 3x3 same-padding conv -> ReLU -> 2x2 max-pool
    (3 -> 128 -> 64 -> 32 channels). The result is flattened and fed to two
    linear heads that produce the posterior mean and log-variance.

    Decoder: linear + ReLU back to 32x12x12, then nearest-neighbour
    upsampling and 3x3 convolutions up to a 3-channel image.

    The hard-coded 4608 = 32 * 12 * 12 features mean inputs must be 96x96
    (three 2x poolings: 96 -> 48 -> 24 -> 12).

    Args:
        latent_dim: size of the latent vector z.
        decoder_activation: if True, insert a ReLU after each of the two
            intermediate decoder convolutions. The default (False) keeps the
            migrated layer list, in which those convolutions and the
            upsampling layers have no non-linearity between them. That part
            of the decoder is therefore a purely linear map.
    """

    def __init__(self, latent_dim, decoder_activation=False):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder: 3x96x96 -> 128x48x48 -> 64x24x24 -> 32x12x12
        self.encoder_conv1 = self.getConvolutionLayer(3, 128)
        self.encoder_conv2 = self.getConvolutionLayer(128, 64)
        self.encoder_conv3 = self.getConvolutionLayer(64, 32)

        self.flatten = nn.Flatten()

        # Heads on the flattened 32*12*12 = 4608 features:
        # fc1 -> posterior mean, fc2 -> posterior log-variance.
        self.encoder_fc1 = nn.Linear(4608, self.latent_dim)
        self.encoder_fc2 = nn.Linear(4608, self.latent_dim)

        # Decoder
        self.decoder_fc1 = nn.Sequential(
            nn.Linear(self.latent_dim, 4608),
            nn.ReLU()
        )
        # Applied after reshaping to 32x12x12 in decode(): -> 32x24x24
        self.decoder_upsampler1 = nn.Upsample(scale_factor=(2, 2), mode='nearest')
        # 32x24x24 -> 64x24x24 -> 64x48x48
        self.decoder_deconv1 = self._getDecoderLayer(32, 64, decoder_activation)
        # 64x48x48 -> 128x48x48 -> 128x96x96
        self.decoder_deconv2 = self._getDecoderLayer(64, 128, decoder_activation)
        # 128x96x96 -> 3x96x96. No output activation: the training inputs are
        # mean/std-normalized in split_data() and can be negative.
        self.decoder_conv1 = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=3, stride=1, padding=1)

    def getConvolutionLayer(self, in_channels, out_channels):
        """Encoder stage: 3x3 same-padding conv, ReLU, then 2x2 max-pool (halves H and W)."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

    @staticmethod
    def _getDecoderLayer(in_channels, out_channels, activation):
        """Decoder stage: 3x3 same-padding conv, optional ReLU, 2x nearest upsampling.

        The Conv2d is always at index 0, so state_dict parameter names do not
        depend on ``activation``. The ReLU/Upsample layers have no parameters,
        so the random initialisation sequence is also unaffected.
        """
        layers = [nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)]
        if activation:
            layers.append(nn.ReLU())
        layers.append(nn.Upsample(scale_factor=(2, 2), mode='nearest'))
        return nn.Sequential(*layers)

    def encode(self, x):
        """Map images to the parameters of the approximate posterior q(z|x).

        Returns:
            (mu, logvar): posterior mean and log-variance, each (batch, latent_dim).
        """
        x = self.encoder_conv1(x)
        x = self.encoder_conv2(x)
        x = self.encoder_conv3(x)

        x = self.flatten(x)
        # The second head is consumed as a log-variance by reparameterize()
        # and loss_function(), so it is named accordingly (it is not sigma).
        mu = self.encoder_fc1(x)
        logvar = self.encoder_fc2(x)

        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I) (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        # Same shape as std, sampled from a standard normal.
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors of shape (batch, latent_dim) to 3x96x96 images."""
        z = self.decoder_fc1(z)
        z = self.decoder_upsampler1(z.view(-1, 32, 12, 12))
        z = self.decoder_deconv1(z)
        z = self.decoder_deconv2(z)
        recon = self.decoder_conv1(z)
        return recon

    def forward(self, x):
        """Encode, sample z and decode; returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

From what I've seen, MSE and KLD are somewhat incompatible depending on how the scores are aggregated: both have to be aggregated consistently (either sum or mean). Mixing aggregation techniques makes the two scores differ so much in magnitude that the network ends up optimizing mostly the one with the bigger impact. The current solution is a β-VAE, where a β parameter multiplies the KLD to control how much weight it has in the loss function. In the current loss, the squared error is summed over each image's pixels and then averaged over the batch. The KLD is summed over the latent dimensions and over the batch.

In [None]:
# Reconstruction (per-image sum, batch mean) + beta-weighted KL divergence.
def loss_function(recon_x, x, mu, logvar, epoch_n=False, beta=None, kld_reduction='sum'):
    """Beta-VAE loss: reconstruction error plus beta-weighted KL divergence.

    Reconstruction term: squared error summed over each image's dims 1-3
    (C, H, W for 4-D input), then averaged over the batch.

    KL term: closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian (see
    Appendix B of Kingma & Welling, "Auto-Encoding Variational Bayes",
    https://arxiv.org/abs/1312.6114), summed over the latent dimensions and
    then:
      * ``kld_reduction='sum'`` (default, original behaviour): summed over the
        batch. Because the reconstruction term is a batch mean, the effective
        KL weight grows with the batch size (and is smaller for a final
        partial batch).
      * ``kld_reduction='mean'``: averaged over the batch, giving the same
        per-sample scale as the reconstruction term.

    Args:
        recon_x: reconstructions, same shape as ``x``.
        x: target images.
        mu: posterior means, (batch, latent_dim).
        logvar: posterior log-variances, (batch, latent_dim).
        epoch_n: unused; kept so existing callers keep working.
        beta: weight of the KL term. ``None`` falls back to the notebook-level
            ``beta`` global (the original behaviour).
        kld_reduction: 'sum' or 'mean' over the batch dimension.

    Returns:
        (loss, mse, kld) scalar tensors; ``kld`` already includes ``beta``.

    Raises:
        ValueError: if ``kld_reduction`` is neither 'sum' nor 'mean'.
    """
    if beta is None:
        beta = globals()['beta']

    # Sum of squared errors per image, averaged over the batch
    # (F.mse_loss(reduction='mean') would also average over pixels).
    mse = torch.mean(torch.sum((x - recon_x).pow(2), dim=(1, 2, 3)))

    kld_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    if kld_reduction == 'sum':
        kld = -0.5 * torch.sum(kld_terms) * beta
    elif kld_reduction == 'mean':
        kld = torch.mean(-0.5 * torch.sum(kld_terms, dim=1)) * beta
    else:
        raise ValueError("kld_reduction must be 'sum' or 'mean', got {!r}".format(kld_reduction))

    loss = mse + kld
    return loss, mse, kld

Helpers for the model training and testing loop 

In [None]:
def train(epoch, train_loader):
    """Run one training epoch and report the epoch averages.

    Relies on the notebook-level ``model``, ``optimizer``, ``device``,
    ``beta`` and ``loss_function``.

    Args:
        epoch: 1-based epoch number, used for logging and passed to loss_function.
        train_loader: DataLoader yielding (images, labels); labels are ignored.

    Returns:
        (avg_loss, avg_mse, avg_kld):
          * avg_loss and avg_mse: batch values weighted by batch size and
            divided by the dataset size.
          * avg_kld: the batch KLD values averaged over the number of batches,
            with the beta factor divided back out.
    """
    model.train()
    train_loss = 0
    train_mse = 0
    train_kld = 0
    # Log about four times per epoch. max(1, ...) avoids a modulo-by-zero
    # when the loader has fewer than four batches.
    log_interval = max(1, len(train_loader) // 4)
    for batch_idx, (images, _) in enumerate(train_loader):
        images = images.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(images)
        loss, mse, kld = loss_function(recon_batch, images, mu, logvar, epoch)

        loss.backward()
        optimizer.step()

        current_batch_size = len(images)
        # NOTE(review): weighting loss by batch size treats it as a per-sample
        # mean; that holds for the MSE term but not for a batch-summed KLD.
        train_loss += loss.item() * current_batch_size
        train_mse += mse.item() * current_batch_size
        train_kld += kld.item()

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(images), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item()))

    datapoints = len(train_loader.dataset)
    avg_loss = train_loss / datapoints
    avg_mse = train_mse / datapoints
    avg_kld = train_kld / (beta * len(train_loader))

    print('====> Epoch: {} Average loss: {:.8f}'.format(epoch, avg_loss))
    print('*** Avg MSE: {:.4f}'.format(avg_mse))
    # Print the same un-scaled KLD that is returned and that test() prints
    # (previously this multiplied beta back in, so train/val logs disagreed).
    print('*** Avg KLD: {:.8f}'.format(avg_kld))
    return avg_loss, avg_mse, avg_kld

In [None]:
def test(epoch, test_loader):
    """Evaluate the notebook-level ``model`` on ``test_loader`` without gradients.

    Up to 8 inputs from the first batch, followed by their reconstructions,
    are plotted and saved as ``reconstruction_<epoch>.png`` in ``results_dir``.

    Returns:
        (test_loss, test_mse, test_kld), averaged the same way as in train().
    """
    model.eval()
    total_loss, total_mse, total_kld = 0, 0, 0
    with torch.no_grad():
        for step, (batch, _) in enumerate(test_loader):
            batch = batch.to(device)
            recon_batch, mu, logvar = model(batch)
            loss, mse, kld = loss_function(recon_batch, batch, mu, logvar, epoch)

            size = batch.size(0)
            total_loss += loss.item() * size
            total_mse += mse.item() * size
            total_kld += kld.item()

            if step != 0:
                continue
            # Top row: originals; bottom row: reconstructions.
            shown = min(size, 8)
            pairs = torch.cat([batch[:shown], recon_batch[:shown]]).cpu()
            imshow(torchvision.utils.make_grid(pairs))
            plt.savefig(results_dir + 'reconstruction_' + str(epoch) + '.png')
            plt.close()

    n_points = len(test_loader.dataset)
    test_loss = total_loss / n_points
    test_mse = total_mse / n_points
    test_kld = total_kld / (beta * len(test_loader))
    print('====> Test set loss: {:.8f}'.format(test_loss))
    print('*** Avg MSE: {:.8f}'.format(test_mse))
    print('*** Avg KLD: {:.8f}'.format(test_kld))
    return test_loss, test_mse, test_kld

In [None]:
def get_time_in_hours(seconds):
    """Split a duration in seconds into (hours, minutes, seconds).

    Works for ints and floats; with a float input every component is a float
    and the seconds part keeps its fractional value.

    >>> get_time_in_hours(3725)
    (1, 2, 5)
    """
    hours, rest = divmod(seconds, 3600)
    minutes, secs = divmod(rest, 60)
    return hours, minutes, secs

### Training
The training of the migrated VAE starts here

In [None]:
# Build the three loaders (test_data is kept aside; training monitors val_data).
train_data, val_data, test_data = split_data(data_dir, batch_size=batch_size)

In [None]:
# Build the VAE with a 256-dimensional latent space on the selected device.
model = VAE(256).to(device)
# Display the layer summary as the cell output.
model

In [None]:
optimizer = optim.Adam(model.parameters(), lr=1e-4)
train_trace = {
    'loss': [],
    'mse': [],
    'kld': []
}

val_trace = {
    'loss': [],
    'mse': [],
    'kld': []
}
epochs = 100
# KL weight of the beta-VAE. loss_function, train and test read this global,
# so it must be defined before the first epoch runs.
beta = 0.5
since = time.time()
for epoch in range(1, epochs + 1):
    loss, mse, kld = train(epoch, train_data)
    train_trace['loss'].append(loss)
    train_trace['mse'].append(mse)
    train_trace['kld'].append(kld)

    # Per-epoch monitoring uses the validation split; test_data is not used here.
    loss, mse, kld = test(epoch, val_data)
    val_trace['loss'].append(loss)
    val_trace['mse'].append(mse)
    val_trace['kld'].append(kld)

    # Decode 4 random draws from the N(0, I) prior to monitor sample quality.
    with torch.no_grad():
        # model.latent_dim instead of a hard-coded 256 keeps this in sync with
        # the VAE(...) constructor argument.
        sample = torch.randn(4, model.latent_dim).to(device)
        sample = model.decode(sample).cpu()
        sample = torchvision.utils.make_grid(sample)
        imshow(sample)
        plt.savefig(results_dir + 'sample_' + str(epoch) + '.png')
        plt.close()

    # Cumulative wall-clock time since training started.
    epoch_time = time.time() - since
    e_hours, e_minutes, e_seconds = get_time_in_hours(epoch_time)
    print('Time elapsed {:.0f}h {:.0f}m {:.0f}s'.format(e_hours, e_minutes, e_seconds))

In [None]:
def plot_loss(train_data, val_data, epochs, xlabel, ylabel, title):
    """Plot per-epoch training and validation curves on one figure.

    Args:
        train_data: per-epoch training values (index 0 = epoch 1).
        val_data: per-epoch validation values (index 0 = epoch 1).
        epochs: number of epochs; sets the x tick range.
        xlabel, ylabel, title: axis labels and figure title.
    """
    plt.figure(figsize=(10, 10))

    # Integer y bounds over both curves (int() truncates, as before).
    low_bound = int(min(min(train_data), min(val_data)))
    up_bound = int(max(max(train_data), max(val_data)))

    # The training loop numbers epochs from 1, so plot against 1..N rather
    # than the default 0-based list index.
    plt.plot(range(1, len(train_data) + 1), train_data, label='train')
    plt.plot(range(1, len(val_data) + 1), val_data, label='validation')

    # Integer tick steps become 0 for fewer than 20 epochs, or for a y-range
    # narrower than ~18, and np.arange raises on a zero step. Floor the x step
    # at 1; if the y step would be 0, keep matplotlib's automatic y ticks.
    plt.xticks(np.arange(0, epochs + 1, max(1, epochs // 20)))
    y_step = int((1.1 * up_bound - 1.1 * low_bound) / 20)
    if y_step > 0:
        plt.yticks(np.arange(low_bound, up_bound * 1.01, y_step))

    plt.ylabel(ylabel)
    plt.xlabel(xlabel)
    plt.title(title)
    plt.legend()
    plt.grid()
    plt.show()

In [None]:
# Total loss curves (train vs validation).
plot_loss(train_trace['loss'], val_trace['loss'], epochs,
          xlabel='Epochs', ylabel="Loss (MSE + KLD)",
          title=f"Average Loss for Beta {beta}")

In [None]:
# Reconstruction-term curves (train vs validation).
plot_loss(train_trace['mse'], val_trace['mse'], epochs,
          xlabel='Epochs', ylabel="MSE",
          title=f"Average MSE for Beta {beta}")

In [None]:
# KL-term curves (train vs validation), with beta divided back out.
plot_loss(train_trace['kld'], val_trace['kld'], epochs,
          xlabel='Epochs', ylabel="KLD",
          title=f"Average KLD for Beta {beta}")

### Experimental Cells
Please ignore

In [None]:
img, _ = next(iter(val_data))
print(img.shape)
first_image = img[0]
# Overall maximum of the first (normalized) image, then each channel's maximum.
print(torch.max(first_image))
for channel in first_image:
    print(torch.max(channel))

In [None]:
x = img[0:100].to(device)
# Inference only: no_grad avoids building an autograd graph for the whole
# batch, and calling model(x) rather than model.forward(x) runs module hooks.
with torch.no_grad():
    recon, mu, logvar = model(x)

In [None]:
# Pixel-space blend: image 1 plus image 2 scaled by 0.25 (slices keep the batch
# dimension); the result is shown as the cell output.
img[1:2] + img[2:3]*0.25

In [None]:
# Linear interpolation in latent space between two cells of the batch:
# decode z_a, three intermediate points and z_b.
# NOTE(review): indices 48 and 50 assume the batch has at least 51 images.
with torch.no_grad():
    cell_a = img[48:49].to(device)
    cell_b = img[50:51].to(device)

    # encode() returns (mean, log-variance); the second value is a
    # log-variance, which is what reparameterize() expects.
    mu_a, logvar_a = model.encode(cell_a)
    mu_b, logvar_b = model.encode(cell_b)

    # One random draw from each posterior (not the posterior means).
    z_a = model.reparameterize(mu_a, logvar_a)
    z_b = model.reparameterize(mu_b, logvar_b)

    z_diff = z_b - z_a
    frames = [model.decode(z_a)]
    frames += [model.decode(z_a + (z_diff * t)) for t in (0.25, 0.50, 0.75)]
    frames.append(model.decode(z_b))
    interpolated = torch.cat(frames, dim=0).cpu()
    base_a = cell_a.cpu()
    base_b = cell_b.cpu()

In [None]:
# The two original cells that the interpolation runs between.
endpoints = torch.cat((base_a, base_b), dim=0)
imshow(torchvision.utils.make_grid(endpoints))

In [None]:
# Decoded interpolation frames, left (cell a) to right (cell b).
interpolation_grid = torchvision.utils.make_grid(interpolated)
imshow(interpolation_grid)
plt.show()

In [None]:
# Loss terms for the images x forwarded above (uses the global beta);
# display the beta-scaled KLD term.
loss, mse, kld = loss_function(recon, x, mu, logvar)
kld

In [None]:
# KLD summed over latent dims per sample, then averaged over the batch (no beta).
torch.mean((-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)))

In [None]:
# Per-sample KLD summed over the batch: the same quantity as the next cell's
# all-element sum (no beta).
torch.sum(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))

In [None]:
# KLD summed over all elements, as in loss_function before the beta factor.
-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

In [None]:
# Reconstruction term returned by loss_function above.
mse

In [None]:
# True MSE: mean over C, H, W of each image, then mean over the batch.
torch.mean(torch.mean((x - recon).pow(2), dim=(1,2,3)))

In [None]:
# Mean over the batch of each image's sum of squared errors
# (the same expression loss_function uses as its reconstruction term).
torch.mean(torch.sum((x - recon).pow(2), dim=(1,2,3)))

In [None]:
# Per-image sum of squared errors, one value per image in the batch.
torch.sum((x - recon).pow(2), dim=(1,2,3))

In [None]:
# Display the raw reconstruction tensor.
recon

In [None]:
# Squared error summed over every element of the whole batch.
F.mse_loss(recon, x, reduction='sum')

# Extra

The cell below calculates the per-channel mean and standard deviation of the dataset so it can be normalized properly.

In [None]:
image_dataset = datasets.ImageFolder(data_dir, transform=transforms.Compose([
        transforms.ToTensor(),
    ]))

loader = data.DataLoader(
    image_dataset,
    batch_size=128,
    num_workers=0,
    # Pinned host memory only helps when copying to a CUDA device.
    pin_memory=cuda.is_available(),
    shuffle=False
)

# Per-channel statistics over every pixel of every image:
#   mean = E[x],  std = sqrt(E[x^2] - E[x]^2).
# The previous version averaged per-image stds, which ignores the variation
# between images and so underestimates the dataset std. Sums are accumulated
# in float64 to avoid precision loss over millions of pixels.
channel_sum = 0.
channel_sq_sum = 0.
n_pixels = 0
for dat, _ in loader:
    n_channels = dat.size(1)
    # (B, C, H, W) -> (C, B*H*W)
    pixels = dat.transpose(0, 1).reshape(n_channels, -1).double()
    channel_sum += pixels.sum(dim=1)
    channel_sq_sum += pixels.pow(2).sum(dim=1)
    n_pixels += pixels.size(1)

mean = channel_sum / n_pixels
std = (channel_sq_sum / n_pixels - mean.pow(2)).clamp(min=0).sqrt()
mean = mean.float()
std = std.float()
print("mean: ", mean)
print("std: ", std)