# Reconstruction of cellpainting image using ML

This notebook demonstrates basic generative models for Cell Painting data, covering both their training and image generation.

In [None]:
import torchvision
import torch
import json
import os
import sys
import importlib
import shutil
import numpy as np 
import lightning as L

from lightning.pytorch.utilities.model_summary import ModelSummary
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
from torchvision.transforms import v2

plt.rcParams["savefig.bbox"] = 'tight'

# Adding the package in loading path
sys.path.extend(["../"])

import gencellpainting as gc
from gencellpainting.utils.dataset import WGANCriticDataset, CellPaintingDatasetInMemory
from gencellpainting.model import *


### Paths

We define the paths to the dataset and to store the trained models, the training tensorboard logs as well as the image for evaluations.

In [None]:
# Path of the data. It can be overridden without editing the notebook by setting the
# CELLPAINTING_DATASET environment variable; otherwise the original default is used.
PATH_DATASET = os.environ.get(
    "CELLPAINTING_DATASET",
    "/mnt/c/Users/alexi/Documents/data/images/cellpainting/cpg0016-jump/data/jump_64px_uint8.pt",
)

# Root directory, resolved two levels above the current working directory
PATH_ROOT = os.path.abspath("../..")

# Path of optimized parameters
PATH_OPTIM = os.path.join(PATH_ROOT, "data", "optim")

# Path of the output of the model
PATH_OUTPUT = os.path.abspath(os.path.join(PATH_ROOT, "output"))
PATH_MODELS = os.path.join(PATH_OUTPUT, "models")
PATH_TSB_LOGS = os.path.join(PATH_OUTPUT, "tensorboard_logs")


### Parameters

High-level parameters, which will be used during training:

In [None]:
BATCH_SIZE = 64  # Batch size of the train/test dataloaders
MAX_EPOCHS = 100  # Intended upper bound on the number of training epochs
TEST_FRACTION = 0.2  # Fraction of the dataset held out as the test split

### Loading the data

The data has already been preprocessed into a tensor of dimension N, C, H, W where:
* __N__ is the number of examples.
* __C__ is the number of channels: in this case 5.
* __H__ is the height of the image: in this case, after resizing, 64.
* __W__ is the width of the image: in this case, after resizing, 64.

This tensor is directly passed to the dataset we constructed. It could also be used directly with a _TensorDataset_.

In [None]:
# Load the preprocessed tensor onto the CPU (regardless of the device it was saved from) and
# restrict unpickling to plain tensor data, then wrap it in the in-memory dataset.
ds = CellPaintingDatasetInMemory(
    tensor=torch.load(PATH_DATASET, map_location="cpu", weights_only=True)
)

We save the image dimension to use as parameters when training the networks.

In [None]:
# Take one sample to inspect; the index is arbitrary and requires a dataset with more than 1000 items
image = ds[1000]
# Channel count, height and width of a single image (C, H, W layout described above)
C,H,W = image.shape

We can define some utility functions to visualize the images

In [None]:
# Visualization function taken from https://docs.pytorch.org/vision/main/auto_examples/others/plot_visualization_utils.html#sphx-glr-auto-examples-others-plot-visualization-utils-py
def show(imgs):
    """Display one image tensor, or a list of image tensors, side by side on a single row."""
    images = imgs if isinstance(imgs, list) else [imgs]
    _, axes = plt.subplots(ncols=len(images), squeeze=False)
    for column, tensor in enumerate(images):
        # Detach from the autograd graph before converting to a PIL image
        pil_image = F.to_pil_image(tensor.detach())
        axis = axes[0, column]
        axis.imshow(np.asarray(pil_image))
        # Hide the ticks and their labels
        axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

def plot_cellpainting_image(image, nrow=3):
    """Plot every channel of a multi-channel image as a grid of single-channel tiles.

    Args:
        image: Tensor whose first dimension is split into slices of size 1, one tile each.
        nrow: Forwarded to ``torchvision.utils.make_grid`` as ``nrow``.
    """
    # One tensor per slice along dim 0, each keeping a leading dimension of size 1
    channels = list(torch.split(image, 1, dim=0))
    grid = torchvision.utils.make_grid(channels, nrow=nrow)
    show(grid)

plot_cellpainting_image(image)

In order to increase the diversity of the images, we add a set of transformations.

In [None]:
# Random horizontal/vertical flips, followed by conversion to float32
transforms = v2.Compose([
    v2.RandomHorizontalFlip(p = 0.5),
    v2.RandomVerticalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True) # Tensor values [0, 255] -> [0.0, 1.0]
])
# We pass it directly to the dataset's transform attribute for convenience.
# NOTE(review): this is set on the full dataset before the train/test split, so the random
# flips presumably also apply to the test samples — confirm this is intended.
ds.transform = transforms

We can split the dataset into a train and test set

In [None]:
# A seeded generator keeps the train/test partition identical across kernel restarts
ds_train, ds_test = torch.utils.data.random_split(
    ds,
    [1 - TEST_FRACTION, TEST_FRACTION],
    generator=torch.Generator().manual_seed(42),
)

Now, in order to perform the learning, we need to create batches using a _DataLoader_.

In [None]:
# Training batches are reshuffled at every epoch
dl_train = torch.utils.data.DataLoader(
    ds_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)
# Test batches keep a fixed order
dl_test = torch.utils.data.DataLoader(
    ds_test,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

### Logging

We log the training using tensorboard, storing some diagnostic metrics as well as sampling some images at each training step. You can run tensorboard with `tensorboard --logdir {PATH_TSB_LOGS}`, using the path defined at the beginning of this notebook.

In [None]:
# We erase the logs if they exist, so that each run starts from an empty log directory.
# Caution: this removes everything stored under PATH_TSB_LOGS, for all models.
if os.path.exists(PATH_TSB_LOGS):
    shutil.rmtree(PATH_TSB_LOGS)
print("Logging for tensorboard in '{}'".format(PATH_TSB_LOGS))


# Training

We first retrieve the best parameter sets found by the grid search for the VAE and WGAN-GP.

In [None]:
def load_best_params(model_name):
    """Load the best hyper-parameters stored for ``model_name``.

    Reads ``<PATH_OPTIM>/<model_name>_hpar_optim/best_parameters_<model_name>_hpar_optim.json``
    and returns its decoded JSON content.
    """
    study = model_name + "_hpar_optim"
    filename = "best_parameters_" + study + ".json"
    with open(os.path.join(PATH_OPTIM, study, filename), "r") as handle:
        return json.load(handle)

# Best hyper-parameter sets for the two models trained below
params_VAE = load_best_params("VAE")
params_WGANGP = load_best_params("WGANGP")



## $\beta$-VAE 
We first train the $\beta$-VAE using the optimized parameters. We update the monitoring parameters to output images during training at every epoch.

In [None]:
# Monitoring settings, plus integer casts of the values loaded from JSON
params_VAE.update({
    "epoch_monitoring_interval": 1,
    "n_images_monitoring": 6,
    "latent_dim": int(params_VAE["latent_dim"]),
    "network_capacity": int(params_VAE["network_capacity"]),
})
print(params_VAE)

In [None]:
model_vae = VAE(**params_VAE)

A good practice is to evaluate if:

1. The model can actually run.
2. We don't have a … (TODO: this item was left unfinished and needs to be completed).

In [None]:
ModelSummary(model_vae)

We configure the trainer, with an early stopping case

In [None]:
NAME_MODEL = "VAE"
# Per-model sub-directory of the TensorBoard log directory
plogs = os.path.join(PATH_TSB_LOGS,NAME_MODEL)
# Deleting the model log folder if it already exists for clarity
if os.path.isdir(plogs):
    shutil.rmtree(plogs)
# Logger writing under PATH_TSB_LOGS with the model name as run name
tb_logger = L.pytorch.loggers.TensorBoardLogger(save_dir = PATH_TSB_LOGS, name=NAME_MODEL)


In [None]:
# Use the MAX_EPOCHS constant from the Parameters section instead of a hard-coded value
trainer = L.Trainer(max_epochs=MAX_EPOCHS, accelerator="gpu", devices=1, logger=tb_logger)

In [None]:
trainer.fit(model_vae, train_dataloaders=dl_train, val_dataloaders=dl_test)

We save the model after training

In [None]:
trainer.save_checkpoint(os.path.join(PATH_MODELS,NAME_MODEL))

In [None]:
# Pick one image from the test split
b1 = ds_test[10]
# Add a leading batch dimension of size 1
b1 = b1[None,:,:,:]
b1.shape

In [None]:
# Reconstruction is inference only: switch to eval mode and disable gradient tracking
model_vae.eval()
with torch.no_grad():
    y1 = model_vae.decoder(model_vae.encoder(b1).sample())

In [None]:
plot_cellpainting_image(b1.squeeze())

In [None]:
plot_cellpainting_image(y1.squeeze())

### Wasserstein GAN

Wasserstein GANs are a more stable version of GANs, and to be frank I am just curious about their performance. We first have to modify the dataset: a Wasserstein GAN needs _N_ samples to train the generator and _N x C_ samples to train the _C_ critics (here _C_ denotes `NCRITICS`, not the number of image channels).

In [None]:
# Number of critic samples attached to each generator sample
NCRITICS = 5
ds_W = WGANCriticDataset(ds, ncritic=NCRITICS)
# A seeded generator keeps the train/test partition identical across kernel restarts
ds_train_W, ds_test_W = torch.utils.data.random_split(
    ds_W,
    [1 - TEST_FRACTION, TEST_FRACTION],
    generator=torch.Generator().manual_seed(42),
)

In order for the dataloader to know how to stitch a batch together we need to provide a specific `collate_fn` argument

In [None]:
# This function returns a tuple with 2 elements:
# 1. The images to process with the generator (B x C x H x W)
# 2. The images to process with the discriminator for learning ((B x NCRITIC) x C x H x W)
def collate_wgan_batch(batch):
    """Collate ``(generator_image, critic_images)`` pairs into two stacked tensors.

    The critic images of all samples are flattened into a single sequence before stacking.
    """
    generator_samples = []
    critic_samples = []
    for gen_img, critic_imgs in batch:
        generator_samples.append(gen_img)
        # Iterating over critic_imgs yields one image at a time
        critic_samples.extend(critic_imgs)
    return torch.stack(generator_samples), torch.stack(critic_samples)

In [None]:
dl_train_W = torch.utils.data.DataLoader(ds_train_W, batch_size=32, shuffle=True, collate_fn=collate_wgan_batch, num_workers=4)
# The test loader is not shuffled, consistently with dl_test: evaluation order should be stable
dl_test_W = torch.utils.data.DataLoader(ds_test_W, batch_size=32, shuffle=False, collate_fn=collate_wgan_batch, num_workers=4)

We can now test the model

In [None]:
import gencellpainting.model.WGAN as WGAN
importlib.reload(WGAN)

In [None]:
# Monitoring settings, plus integer casts of the values loaded from JSON
params_WGANGP["epoch_monitoring_interval"] = 1
params_WGANGP["n_images_monitoring"] = 3
params_WGANGP["noise_dim"] = int(params_WGANGP["noise_dim"])
params_WGANGP["network_capacity"] = int(params_WGANGP["network_capacity"])
# Print the WGAN-GP parameters (the original printed the VAE ones by mistake)
print(params_WGANGP)

In [None]:
wgan = WGAN.WGAN_GP(**params_WGANGP)

In [None]:
ModelSummary(wgan)

In [None]:
tb_logger = L.pytorch.loggers.TensorBoardLogger(save_dir = PATH_TSB_LOGS, name="WGAN_GP")
# Use the MAX_EPOCHS constant from the Parameters section instead of a hard-coded value
trainer_wg = L.Trainer(max_epochs=MAX_EPOCHS, accelerator="gpu", devices=1, logger=tb_logger)

In [None]:
trainer_wg.fit(wgan,dl_train_W)

In [None]:
trainer_wg.save_checkpoint(os.path.join(PATH_MODELS,"WGANGP"))

### Diffusion based model

This section presents the diffusion training process

In [None]:
import gencellpainting.model.net.UNETdiffusion as UND
importlib.reload(UND)
import gencellpainting.model.diffusion as DIF
importlib.reload(DIF)


In [None]:
TIME_CHANNELS = 62  # Passed as `time_channels` to the U-Net and as `time_dim` to the diffusion process
NETWORK_CAPACITY = 32  # Passed as `network_capacity` to the U-Net
NSTEPS = 200  # Passed as `nsteps` to the diffusion process
NLAYERS = 3  # Passed as `nlayers` to the U-Net

We can now create the diffusion process using the UNET created

In [None]:
# ds.n_channels is passed as the first two positional arguments
# (presumably the input and output channel counts — confirm against UNetDiffusionV2)
diff_unet = UND.UNetDiffusionV2(
    ds.n_channels,
    ds.n_channels,
    time_channels=TIME_CHANNELS,
    network_capacity=NETWORK_CAPACITY,
    nlayers=NLAYERS,
)

# Diffusion process wrapping the U-Net defined above
diffusion = DIF.DiffusionProcess(
    1,
    time_dim=TIME_CHANNELS,
    nsteps=NSTEPS,
    model=diff_unet,
    include_time_emb=True,
)

This version of diffusion is suited to generating images with values in [-1,1], whereas our current dataset has values in [0,1]. We can create a new dataloader to rescale the images.

In [None]:
# Gaussian scaling
ds.tensor.float().mean()

In [None]:
def collate_diff_batch(batch):
    """Stack a list of image tensors into one batch and apply ``x * 2 - 1``.

    The affine map sends values in [0, 1] onto [-1, 1].
    """
    stacked = torch.stack(batch)
    return stacked * 2. - 1.

In [None]:
dl_train_diff = torch.utils.data.DataLoader(ds_train,batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_diff_batch, num_workers=4)


In [None]:
ModelSummary(diffusion)

We need to train the model

In [None]:

NAME_MODEL = "DIFFUNET"
# Start from an empty log folder for this model
plogs = os.path.join(PATH_TSB_LOGS, NAME_MODEL)
if os.path.isdir(plogs):
    shutil.rmtree(plogs)
tb_logger = L.pytorch.loggers.TensorBoardLogger(save_dir = PATH_TSB_LOGS, name=NAME_MODEL)

# Use the MAX_EPOCHS constant from the Parameters section instead of a hard-coded value
trainer_diff = L.Trainer(max_epochs=MAX_EPOCHS, accelerator="gpu", devices=1, logger=tb_logger)

In [None]:
trainer_diff.fit(diffusion,dl_train_diff)

In [None]:
trainer_diff.save_checkpoint(os.path.join(PATH_MODELS,NAME_MODEL))