# Playground

In [None]:
# Automatically reload edited modules before each cell executes, so changes to
# the project code take effect without restarting the kernel.
%reload_ext autoreload
%autoreload 2

import sys
# Make modules under ./code importable (presumably where dataset/plots/models/utils
# live). The relative path is resolved against the notebook's working directory.
sys.path.append("./code")

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import logging
# Log at DEBUG so the project's own debug messages are visible.
# %H (24-hour clock) is used instead of %I: without %p, a 12-hour clock prints
# 01:00 and 13:00 identically, which makes log timestamps ambiguous.
logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.DEBUG, datefmt="%H:%M:%S")
# A DEBUG root logger also lets third-party libraries (e.g. matplotlib's font
# lookup, PIL's image-plugin parsing) flood the cell output with internal debug
# messages. Raise only those loggers to WARNING; other loggers are unaffected.
for _noisy_logger in ("matplotlib", "PIL"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)

from dataset import *
from plots import *
from models import *
from utils import *

# Project helper (star-imported above); presumably seeds the random number
# generators so runs are reproducible — see its definition to confirm.
fix_seed()

## Metadata Summary

In [None]:
# Load the full metadata table via the project helper; later cells filter it.
metadata = load_metadata()

In [None]:
# Preview the first rows of the metadata table.
metadata.head()

In [None]:
# (rows, columns) of the metadata table.
metadata.shape

In [None]:
# Plot the MOA distribution of the full metadata (project plotting helper).
plot_MOA_distribution(metadata)

In [None]:
# Treatment heatmap for the full metadata (project plotting helper).
plot_treatment_heatmap(metadata)

## Tiny Subset of the Dataset

In [None]:
# Tiny subset: only the metadata rows whose Multi_Cell_Image_Name equals one
# specific image, for quick smoke tests of the pipeline.
tiny_metadata = metadata[metadata["Multi_Cell_Image_Name"] == "Week10_200907_F02_s1_w14631241C-4FA2-4BC9-8693-D7D268CAEE82"]
# Alternative source image (swap with the line above to try another one):
#tiny_metadata = metadata[metadata["Multi_Cell_Image_Name"] == "B02_s1_w16F89C55C-7808-4136-82E4-E066F8E3CB10"]

# Fail fast if the hard-coded name matches nothing; otherwise the later
# `train_images[0]` would fail with an unrelated-looking IndexError.
if tiny_metadata.empty:
    raise ValueError("No metadata rows match the selected Multi_Cell_Image_Name")

tiny_images = load_images_from_metadata(tiny_metadata)
# Report how many images were actually loaded (same as the stratified cell).
print(f"{tiny_images.shape[0]} images")

In [None]:
# Treatment heatmap restricted to the tiny subset.
plot_treatment_heatmap(tiny_metadata)

## Stratified Subset of the Dataset

In [None]:
# (MOA, value) pairs to exclude from the stratified subset. The numbers are
# presumably concentrations — confirm against stratify_metadata.
blacklist = [
    ("Eg5 inhibitors", 0.1),
    ("Microtubule destabilizers", 0.3),
    ("Cholesterol-lowering", 6.0),
]
# 60 is passed positionally to stratify_metadata; its meaning (presumably a
# per-group sample count) is defined by that helper.
stratified_metadata = stratify_metadata(metadata, 60, blacklist=blacklist)
stratified_images = load_images_from_metadata(stratified_metadata)

# Report how many images were loaded for the subset.
print(f"{stratified_images.shape[0]} images")

In [None]:
# Treatment heatmap for the stratified subset.
plot_treatment_heatmap(stratified_metadata)

## Model - Training

In [None]:
# Choose the data to train on (the tiny subset here; the stratified subset above
# can be substituted), then normalize the images channel by channel.
train_images, train_metadata = tiny_images, tiny_metadata
train_images = normalize_channel_wise(train_images)

In [None]:
# First (already normalized) training image, inspected in the next cells.
img1 = train_images[0]

In [None]:
# Show the selected image (project plotting helper).
plot_image(img1)

In [None]:
# Show the selected image's channels (presumably one panel per channel).
plot_channels(img1)

In [None]:
# Convert the normalized images to z-scores, then crop them, before training.
# NOTE(review): train_images is rebound to its own transformed output, so
# re-running this cell applies both transforms again to already-transformed
# data. Re-run the data-selection cell first if you need to repeat it.
train_images = normalized_to_zscore(train_images)
train_images = view_cropped_images(train_images)

In [None]:
# Train the conditional diffusion model on the selected metadata/images.
# epoch_sample_times is interpreted by the helper (presumably how often samples
# are generated during training) — confirm in its definition.
train_conditional_diffusion_model(train_metadata, train_images, epochs=20, epoch_sample_times=2, batch_size=2)

## Model - Load Pretrained Checkpoint

In [None]:
# Load a previously saved UNet checkpoint and draw one sample with it.
# NOTE(review): this loads the *Unconditional* checkpoint, while the training
# cell calls train_conditional_diffusion_model and the results cell reads
# DDPM_Conditional — confirm this is the intended checkpoint.
unet = UNet()
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
# load_state_dict copies the tensors into unet's existing parameters, so the
# resulting module is the same as without it.
unet.load_state_dict(torch.load("./models/DDPM_Unconditional/ckpt.pt", map_location="cpu"))
# Inference: switch dropout/batch-norm layers (if any) to evaluation mode.
# NOTE(review): unet is never moved to a device here; if Diffusion samples on a
# GPU, move unet there first — confirm against Diffusion's defaults.
unet.eval()
diffusion = Diffusion()
# No gradients are needed for sampling; skip autograd bookkeeping.
with torch.no_grad():
    sampled_images = diffusion.sample(unet, n=1)

## Results

In [None]:
# Load the saved per-epoch sample images from the conditional run's results
# directory and plot them as a series. epoch_images is reused in the next cell.
epoch_images, epochs = load_epoch_images("./results/DDPM_Conditional/")
plot_epoch_sample_series(epoch_images, epochs)

In [None]:
# Show the first image of the last entry in epoch_images — presumably the final
# epoch, assuming load_epoch_images returns epochs in ascending order (verify).
plot_image(epoch_images[-1][0])