# Playground

In [1]:
from typing import List, Set, Dict, Tuple, Optional, Any
from collections import defaultdict

import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

import math 
import torch
from torch import nn, Tensor
from torch.nn.functional import softplus, relu
from torch.distributions import Distribution, Normal
from torch.utils.data import DataLoader, Dataset

from gmfpp.utils.data_preparation import *
from gmfpp.utils.data_transformers import *
from gmfpp.utils.plotting import *

from gmfpp.models.ReparameterizedDiagonalGaussian import *
from gmfpp.models.CytoVariationalAutoencoder import *
from gmfpp.models.VariationalAutoencoder import *
from gmfpp.models.ConvVariationalAutoencoder import *
from gmfpp.models.VariationalInference import *

%matplotlib inline

In [2]:
def constant_seed(seed: int = 0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

In [3]:
constant_seed()

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


## Load data

In [5]:
metadata_all = read_metadata("./data/all/metadata.csv")
mapping = get_MOA_mappings(metadata_all)

In [6]:
metadata = read_metadata("./data/tiny/metadata.csv")

In [7]:
metadata = shuffle_metadata(metadata)
metadata_train_all, metadata_test = split_metadata(metadata, split_fraction = .90)
metadata_train, metadata_validation = split_metadata(metadata_train_all, split_fraction = .90)

In [8]:
relative_path = get_relative_image_paths(metadata)
image_paths = ["./data/tiny/" + path for path in relative_path]
images = load_images(image_paths)

In [9]:
len(images)

259

## Normalize data

In [10]:
images = prepare_raw_images(images)
normalize_channels_inplace(images)
print(images.shape)

torch.Size([259, 3, 68, 68])


In [11]:
channel_first = view_channel_dim_first(images)
for i in range(channel_first.shape[0]):
    channel = channel_first[i]
    print("channel {} interval: [{:.2f}; {:.2f}]".format(i, torch.min(channel), torch.max(channel)))

channel 0 interval: [0.02; 1.00]
channel 1 interval: [0.04; 1.00]
channel 2 interval: [0.05; 1.00]


## Prepare datasets

In [12]:
class SingleCellDataset(Dataset):
    
    def __init__(self, metadata: pd.DataFrame, images: torch.Tensor, label_to_id: Dict[str, int]):
        self.metadata = metadata
        self.label_to_id = label_to_id
        self.images = images
        
    def __len__(self):
        return self.metadata.shape[0]

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        
        image = self.images[idx]
        
        label_name = row["moa"]
        label = self.label_to_id[label_name]
        
        return image, label

In [15]:
train_set = SingleCellDataset(metadata_train, images, mapping)
validation_set = SingleCellDataset(metadata_validation, images, mapping)
test_set = SingleCellDataset(metadata_test, images, mapping)

## VAE

In [16]:
image_shape = np.array([3, 68, 68])
latent_features = 256

vae = CytoVariationalAutoencoder(image_shape, latent_features)
#vae = VariationalAutoencoder(image_shape, latent_features)
vae = vae.to(device)

beta = 1.
vi = VariationalInference(beta=beta)


num_epochs = 10
batch_size = 32

learning_rate = 1e-3
weight_decay = 10e-4

optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate, weight_decay=weight_decay)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)

In [17]:
training_data = defaultdict(list)
validation_data = defaultdict(list)

for epoch in range(num_epochs):
    print(f"epoch: {epoch}/{num_epochs}")    

    training_epoch_data = defaultdict(list)
    vae.train()

    for x, y in train_loader:
        x = x.to(device)
        
        # perform a forward pass through the model and compute the ELBO
        loss, diagnostics, outputs = vi(vae, x)

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(vae.parameters(), 10_000)
        optimizer.step()

        # gather data for the current batch
        for k, v in diagnostics.items():
            training_epoch_data[k] += [v.mean().item()]

    print("training | elbo: {:2f}, log_px: {:.2f}, kl: {:.2f}:".format(np.mean(training_epoch_data["elbo"]), np.mean(training_epoch_data["log_px"]), np.mean(training_epoch_data["kl"])))

    # gather data for the full epoch
    for k, v in training_epoch_data.items():
        training_data[k] += [np.mean(training_epoch_data[k])]

    # Evaluate on a single batch, do not propagate gradients
    with torch.no_grad():
        vae.eval()

        # Just load a single batch from the test loader
        '''x, y = next(iter(test_loader))'''
        x = x.to(device)

        # perform a forward pass through the model and compute the ELBO
        loss, diagnostics, outputs = vi(vae, x)

        # gather data for the validation step
        for k, v in diagnostics.items():
            validation_data[k] += [v.mean().item()]

    print("validation | elbo: {:2f}, log_px: {:.2f}, kl: {:.2f}:".format(np.mean(validation_data["elbo"]), np.mean(validation_data["log_px"]), np.mean(validation_data["kl"])))    

epoch: 0/10
training | elbo: -14701.419271, log_px: -13283.28, kl: 1418.14:
validation | elbo: -13014.904297, log_px: -12988.78, kl: 26.12:
epoch: 1/10
training | elbo: -13492.100098, log_px: -12836.09, kl: 656.01:
validation | elbo: -13081.771973, log_px: -12945.89, kl: 135.89:
epoch: 2/10
training | elbo: -13293.366699, log_px: -12798.09, kl: 495.27:
validation | elbo: -13272.919271, log_px: -12909.20, kl: 363.72:
epoch: 3/10
training | elbo: -13179.118978, log_px: -12775.27, kl: 403.85:
validation | elbo: -13558.588379, log_px: -12877.05, kl: 681.54:
epoch: 4/10
training | elbo: -13127.640951, log_px: -12765.65, kl: 361.99:
validation | elbo: -13673.250195, log_px: -12860.11, kl: 813.14:
epoch: 5/10
training | elbo: -13095.499512, log_px: -12756.63, kl: 338.87:
validation | elbo: -13699.078776, log_px: -12843.98, kl: 855.10:
epoch: 6/10
training | elbo: -13059.161458, log_px: -12750.91, kl: 308.25:
validation | elbo: -13660.979074, log_px: -12832.04, kl: 828.94:
epoch: 7/10
training

In [None]:
plt.plot(training_data["elbo"])

## Compare reconstruction and original image

In [None]:
x = train_set[0]

In [None]:
plot_image(x)

In [None]:
vae.eval() # because of batch normalization
outputs = vae(x[None,:,:,:])
px = outputs["px"]

x_reconstruction = px.sample()
x_reconstruction = x_reconstruction[0]
plot_image_channels(x_reconstruction)

In [None]:
x_reconstruction = px.sample()
x_reconstruction = x_reconstruction[0]
plot_image(clip_image_to_zero_one(x_reconstruction))

In [None]:
# @TODO cleanup. Used to images from cluster quickly
#x_reconstruction = torch.tensor(np.array(load_images(["./dump/images/x0_reconstruction.npy"])[0], dtype=np.float32))
#x = torch.tensor(np.array(load_images(["./dump/images/x0.npy"])[0], dtype=np.float32))

In [None]:
plot_image_channels(x)

In [None]:
plot_image_channels(x_reconstruction)

## Downstream Classification

In [18]:
class NeuralNetwork(nn.Module):
    
    def __init__(self, n_classes: int = 13):
        super(NeuralNetwork, self).__init__()
        
        self.net = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, n_classes))

    def forward(self, x):
        logits = self.net(x)
        return logits

    
N_classes = len(mapping)
classifier = NeuralNetwork(N_classes).to(device)
print(classifier)

NeuralNetwork(
  (net): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=13, bias=True)
  )
)


In [19]:
class SingleCellDataset(Dataset):
    
    def __init__(self, metadata: pd.DataFrame, images: torch.Tensor, label_to_id: Dict[str, int]):
        self.metadata = metadata
        self.label_to_id = label_to_id
        self.images = images
        
    def __len__(self):
        return self.metadata.shape[0]

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        
        image = self.images[idx]
        
        label_name = row["moa"]
        label = self.label_to_id[label_name]
        
        return image, label

In [20]:
# VAE
image_shape = np.array([3, 68, 68])
latent_features = 256
vae = CytoVariationalAutoencoder(image_shape, latent_features) # @TODO: load trained parameters
vae.eval()

# Classifier
N_classes = len(mapping)
classifier = NeuralNetwork(N_classes).to(device)

# Training
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(classifier.parameters(), lr=1e-3)

num_epochs = 5
batch_size = 16

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)

In [21]:
for epoch in range(num_epochs):
    print(f"epoch: {epoch}/{num_epochs}")    

    training_epoch_data = defaultdict(list)
    vae.train()
    
    train_epoch_loss = []

    for x, y in train_loader:
        x = x.to(device)
        
        outputs = vae(x)
        z = outputs["z"]
        
        pred = classifier(z)
        loss = loss_fn(pred, y)
        
        train_epoch_loss.append(loss.item())

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
            
    print("training | loss: {:.2f}".format(np.mean(train_epoch_loss)))

epoch: 0/5
training | loss: 2.80
epoch: 1/5
training | loss: 2.61
epoch: 2/5
training | loss: 2.43
epoch: 3/5
training | loss: 2.24
epoch: 4/5
training | loss: 2.07


## @TODO
- Look at latent representation
    - How does changing one latent variable change the image reconstruction?
    - How similiar are images in the latent space (cosine-simularity)