In [None]:
import glob
import json
import os
from collections import OrderedDict

import albumentations as aug
import colorcet as cc
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import tqdm
import umap
from captum.attr import LRP, GuidedBackprop, IntegratedGradients, FeatureAblation
from captum.attr import visualization as viz
from sklearn.decomposition import NMF, PCA
from sklearn.manifold import TSNE, trustworthiness
from sklearn.metrics import classification_report
from torch import nn
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image

import models as models
from data.dataset import MOADataset, get_normalization_stats, dmso_normalization
from models.resnet import ResNetLRP
from utils.cka.cka_dataloader import CKA

import torch
import torch.nn.functional as F

from PIL import Image

import os
import json
import numpy as np

import torch.nn.functional as F
import cv2

from kornia.filters import gaussian_blur2d

In [None]:
# Experiment selection: the first five brightfield ("bf") runs (sorted by name),
# and inside each run the ResNet-50 sub-experiment that is analysed below.
# To analyse the fluorescence runs instead, glob
# "/proj/haste_berzelius/exps/specs_non_grit_based/*fl*" and use the sub-path
# "fl_10cls_basic_aug_dmso_norm_750e_sgd/ResNet_resnet50".
exp_folders = sorted(
    glob.glob("exps/specs_non_grit_based/*bf*")
)[:5]
exp_cat = "bf_10cls_basic_aug_dmsonorm_750e_sgd/ResNet_resnet50"
bf_folders = list(map(lambda run_dir: os.path.join(run_dir, exp_cat), exp_folders))


In [None]:
# Deterministic evaluation geometry: resize every image to 1080x1080, then keep
# the central 1024x1024 crop (both steps always applied, p=1.0).
valid_transforms = aug.Compose(
    [
        aug.Resize(height=1080, width=1080, p=1.0),
        aug.CenterCrop(height=1024, width=1024, p=1.0),
    ]
)

# Pairing of brightfield site ids with fluorescence site ids: row i maps the
# i-th bf site to the i-th fluorescence site.
site_conversion = pd.DataFrame(
    [
        ("s1", "s2"),
        ("s2", "s4"),
        ("s3", "s5"),
        ("s4", "s6"),
        ("s5", "s8"),
    ],
    columns=["bf_sites", "f_sites"],
)

# (full MoA name as used in the experiment configs, short display label).
# Kept as pairs so each long name stays visibly next to its abbreviation;
# moa_dict exposes them as two parallel, index-aligned lists.
_moa_name_pairs = [
    ("Aurora kinase inhibitor", "AuroraK-i"),
    ("tubulin polymerization inhibitor", "Tub.Pol.-i"),
    ("JAK inhibitor", "JAK-i"),
    ("protein synthesis inhibitor", "Prot.Synth.-i"),
    ("HDAC inhibitor", "HDAC-i"),
    ("topoisomerase inhibitor", "Topo.-i"),
    ("PARP inhibitor", "PARP-i"),
    ("ATPase inhibitor", "ATPase-i"),
    ("retinoid receptor agonist", "Ret.Rec.Ag"),
    ("HSP inhibitor", "HSP-i"),
    ("dmso", "DMSO"),
]
moa_dict = {
    "moa": [long_name for long_name, _ in _moa_name_pairs],
    "MoA": [short_name for _, short_name in _moa_name_pairs],
}
del _moa_name_pairs  # only needed to build moa_dict

In [None]:


# For every selected experiment: load its config and best checkpoint, run the
# model over the test split, and write two PNGs per sample into the folder of
# the sample's *true* MoA:
#   <plate>_<compound>_<well>_<site>_<predicted MoA>.png      (input visualisation)
#   <plate>_<compound>_<well>_<site>_gbp_<predicted MoA>.png  (guided-backprop map)

# Compounds whose samples are skipped (no images exported).
# NOTE(review): despite the name, these are the *excluded* compounds.
target_comps = [
    "CBK308276",
    "CBK308126",
    "CBK288281",
    "CBK309507",
    "CBK288327",
    "CBK290766",
    "CBK309655",
]


def _strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from keys.

    Checkpoints saved from an ``nn.DataParallel`` wrapper prefix every key with
    ``module.``. Keys without the prefix are kept unchanged instead of being
    blindly truncated.
    """
    stripped = OrderedDict()
    for key, value in state_dict.items():
        stripped[key[len(prefix):] if key.startswith(prefix) else key] = value
    return stripped


def _min_max_scale(array):
    """Linearly rescale a numpy array to [0, 1]; a constant array maps to zeros.

    The constant case is handled explicitly because a zero range would
    otherwise yield NaN, which has no defined uint8 conversion.
    """
    low, high = array.min(), array.max()
    span = high - low
    if span == 0:
        return np.zeros_like(array, dtype=np.float32)
    return (array - low) / span


def _write_png(path, image):
    """Write ``image`` with OpenCV; raise ``OSError`` if the write fails.

    ``cv2.imwrite`` reports failure only through its boolean return value.
    """
    if not cv2.imwrite(path, image):
        raise OSError(f"cv2.imwrite failed to write {path}")


for bf_folder in bf_folders:
    with open(os.path.join(bf_folder, "config_exp.json")) as config_file:
        config_bf = json.load(config_file)

    # Override the data root; fill in keys that older configs may lack.
    config_bf["data"]["data_folder"] = "datasets/specs"
    config_bf["data"].setdefault("mean_mode", "mean")
    config_bf["data"].setdefault("modality", "bf")
    image_folder = os.path.join(bf_folder, "test_images")

    # One output sub-folder per MoA, named by its short label. The folders are
    # ordered by the alphabetically sorted long names.
    # NOTE(review): moa_folders[target] assumes MOADataset assigns label indices
    # in this same sorted order (the dataset receives the unsorted list) — confirm.
    moas = np.sort(config_bf["data"]["moas"])
    moas_short = [moa_dict["MoA"][moa_dict["moa"].index(moa)] for moa in moas]
    moa_folders = [os.path.join(image_folder, moa) for moa in moas_short]
    for moa_folder in moa_folders:
        os.makedirs(moa_folder, exist_ok=True)

    bf_df = (
        pd.read_csv(config_bf["data"]["test_csv_path"])
        .sort_values(["plate", "well", "site"])
        .reset_index(drop=True)
    )

    model = getattr(models, config_bf["model"]["type"])(**config_bf["model"]["args"])
    model = model.cuda()  # nn.DataParallel(model.cuda())

    model_checkpoint = torch.load(os.path.join(bf_folder, config_bf["test"]["model_path"]))
    new_state_dict = _strip_module_prefix(model_checkpoint["model_state_dict"])
    model.load_state_dict(new_state_dict)
    model.eval()

    best_train_epoch = model_checkpoint["epoch"]
    best_train_accuracy = model_checkpoint["epoch_accuracy"]
    print(f"Loading model with Epoch:{best_train_epoch},Accuracy:{best_train_accuracy}")

    bf_dataset = MOADataset(
        root=config_bf["data"]["data_folder"],
        csv_file=bf_df,
        normalize=config_bf["data"]["normalization"],
        dmso_stats_path=config_bf["data"]["dmso_stats_path"],
        moas=config_bf["data"]["moas"],
        geo_transform=valid_transforms,
        bg_correct=config_bf["data"]["bg_correct"],
        modality=config_bf["data"]["modality"],
        mean_mode=config_bf["data"]["mean_mode"],
    )

    # NOTE(review): stats_df and feature_mask are not used in this cell; kept in
    # case later cells reference them.
    stats_df = pd.read_csv(config_bf["data"]["dmso_stats_path"])

    gbp = GuidedBackprop(model)
    feature_mask = torch.tensor([[[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]]])

    for i in tqdm.tqdm(range(len(bf_dataset))):
        bf_sample = bf_dataset[i]
        target = bf_sample[1]
        plate = bf_sample[2]
        site = bf_sample[3]
        compound = bf_sample[4]
        well = bf_sample[5]

        if compound in target_comps:
            continue

        # Named bf_input (not `input`) so the builtin is not shadowed notebook-wide.
        bf_input = bf_sample[0].unsqueeze(0).cuda()

        # The forward pass used only to pick the predicted label needs no graph;
        # guided backprop below runs its own forward/backward.
        with torch.no_grad():
            output = F.softmax(model(bf_input), dim=1)
        prediction_score, pred_label_idx = torch.topk(output, 1)
        pred_label_idx.squeeze_()

        # Input visualisation: per-pixel maximum over channels, scaled to [0, 1].
        channel_max = torch.max(bf_input, dim=1)[0]
        bf_input_vis = _min_max_scale(channel_max.squeeze(0).detach().cpu().numpy())

        # squeeze(0) drops only the batch axis, so a single-channel input still
        # has the channel axis that the (C, H, W) -> (H, W, C) transpose expects.
        attribution_gbp = gbp.attribute(bf_input, target=pred_label_idx)
        attribution_gbp = viz._normalize_image_attr(
            np.transpose(attribution_gbp.squeeze(0).detach().cpu().numpy(), (1, 2, 0)),
            sign="positive",
        )

        pred_short = moas_short[pred_label_idx]
        file_stem = f"{plate}_{compound}_{well}_{site}"
        _write_png(
            os.path.join(moa_folders[target], f"{file_stem}_{pred_short}.png"),
            (bf_input_vis * 255).astype("uint8"),
        )
        _write_png(
            os.path.join(moa_folders[target], f"{file_stem}_gbp_{pred_short}.png"),
            (attribution_gbp * 255).astype("uint8"),
        )