# Playground

In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to the project sources are picked up without a restart.
%reload_ext autoreload
%autoreload 2

import sys
# Make the modules under ./src importable (relative to the working directory).
sys.path.append("./src")

In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn
from torch.utils.data import DataLoader

from dataset import *
from plots import *
from models import *

## Metadata Summary

In [None]:
# Load the dataset metadata; the cells below filter, stratify and plot it.
metadata = load_metadata()

In [None]:
# Preview the first rows of the metadata (shown as the cell's output).
metadata.head()

In [None]:
# Dimensions of the metadata (presumably rows x columns of a DataFrame).
metadata.shape

In [None]:
# Plot the MOA distribution over the full metadata
# (MOA presumably stands for mechanism of action — confirm in the plots module).
plot_MOA_distribution(metadata)

In [None]:
# Treatment heatmap for the full metadata; the same plot is drawn for the
# tiny and stratified subsets further down, for comparison.
plot_treatment_heatmap(metadata)

## Tiny - Subset of Dataset

In [None]:
# Restrict the metadata to the rows belonging to one multi-cell image, so the
# following cells can run on a small sample.
# Alternative image name kept for quick switching:
#   "B02_s1_w16F89C55C-7808-4136-82E4-E066F8E3CB10"
tiny_metadata = metadata[
    metadata["Multi_Cell_Image_Name"]
    == "Week10_200907_F02_s1_w14631241C-4FA2-4BC9-8693-D7D268CAEE82"
]

In [None]:
# Load the images referenced by the tiny subset and normalize them
# channel-wise in one step; `images` ends up bound to the normalized result.
images = normalize_channel_wise(load_images_from_metadata(tiny_metadata))

# Report how many images were loaded (size of the leading axis).
print(f"{images.shape[0]} images")

In [None]:
# Treatment heatmap restricted to the tiny subset.
plot_treatment_heatmap(tiny_metadata)

In [None]:
# Take the first entry of `images` for the single-image plots below.
# NOTE: `images` is whatever the most recently executed loading cell produced.
img1 = images[0]

In [None]:
# Render the selected image.
plot_image(img1)

In [None]:
# Render the selected image's channels (presumably one panel per channel).
plot_channels(img1)

## Stratified - Subset of Dataset

In [None]:
# Entries to exclude from the stratified sample. Each pair looks like
# (MOA class, concentration) — TODO confirm against stratify_metadata.
blacklist = [
    ("Eg5 inhibitors", 0.1),
    ("Microtubule destabilizers", 0.3),
    ("Cholesterol-lowering", 6.0),
]

# Build the stratified subset; the meaning of 60 (presumably a per-group
# sample size) is defined by stratify_metadata — verify there.
stratified = stratify_metadata(metadata, 60, blacklist=blacklist)

In [None]:
# Load the images referenced by the stratified subset and normalize them
# channel-wise in one step; this rebinds `images`, replacing any earlier load.
images = normalize_channel_wise(load_images_from_metadata(stratified))

# Report how many images were loaded (size of the leading axis).
print(f"{images.shape[0]} images")

In [None]:
# Treatment heatmap restricted to the stratified subset.
plot_treatment_heatmap(stratified)

## Model - Training

In [None]:
# Convert the normalized images to z-scores before training (exact statistics
# are defined by normalized_to_zscore).
# NOTE: this rebinds `images`, so re-running the cell feeds already-converted
# data back into the helper; re-run the loading cell first if repeating it.
images = normalized_to_zscore(images)

In [None]:
# Sanity check: number of images going into training (leading axis size).
print(f"{images.shape[0]} images")

In [None]:
# Derive the cropped views used as training input (cropping details live in
# view_cropped_images).
cropped_images = view_cropped_images(images)
# Train for 20 epochs; epoch_sample_times presumably controls how often sample
# images are generated during training — TODO confirm in train().
train(cropped_images, epochs=20, epoch_sample_times=2)

## Model - Load pretrained

In [None]:
# Rebuild the network and restore its weights from the saved checkpoint.
unet = UNet()
# map_location="cpu" lets a checkpoint written from a GPU run load on a machine
# without CUDA; load_state_dict then copies the values into unet's own
# parameters, so the model's device is unaffected.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
unet.load_state_dict(
    torch.load("./models/DDPM_Unconditional/ckpt.pt", map_location="cpu")
)

# Draw a single sample from the restored model.
diffusion = Diffusion()
sampled_images = diffusion.sample(unet, n=1)

## Results

In [None]:
# Load the per-epoch sample images together with their epoch labels
# (presumably written out during training), then plot them as a series.
epoch_images, epochs = load_epoch_images()
plot_epoch_sample_series(epoch_images, epochs)

In [None]:
# Show the first image of the last entry in epoch_images
# (presumably the most recent sampled epoch).
plot_image(epoch_images[-1][0])