In [None]:
import sys

import hydra
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from hydra import compose, initialize
from lightning import Trainer
from omegaconf import DictConfig, OmegaConf
from rootutils import setup_root
from sklearn.metrics import confusion_matrix
from torchmetrics import AUROC, Accuracy, F1Score

# setup_root("../", indicator=".project-root", pythonpath=True)


# NOTE(review): hardcoded absolute, user-specific paths make this notebook
# machine-specific; the commented-out `setup_root` call above looks like the
# portable alternative — confirm before sharing.
# Presumably these entries exist so that the `main_code` imports below resolve.
sys.path.append("/home/rajeeva/Project/yeast_growth_pred/")
sys.path.append("/home/rajeeva/Project/yeast_growth_pred/code/")

from main_code.data import CancerKFoldModule
from main_code.lightning_model import NetMultiViewLightning
from main_code.utils.plot_utils import predict, set_theme

# Project plotting theme (project helper), then pandas engineering-notation
# float display with accuracy=2.
set_theme()
pd.set_eng_float_format(accuracy=2)

In [None]:
def plot_confusion_matrix(
    y,
    y_pred,
    title="Confusion matrix",
    ax=None,
    cmap=plt.cm.Blues,
    titlesize=18,
    labels=None,
    xlabel="Predicted",
    ylabel="True",
):
    """Draw a confusion-matrix heatmap of true vs. predicted labels.

    Parameters
    ----------
    y : array-like
        Ground-truth labels (anything ``sklearn.metrics.confusion_matrix`` accepts).
    y_pred : array-like
        Predicted labels, same length as ``y``.
    title : str
        Title of the Axes.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. When None, seaborn draws on pyplot's current Axes.
    cmap : matplotlib colormap
        Colormap for the cells.
    titlesize : int or float
        Font size of the title.
    labels : array-like, optional
        Label order forwarded to ``confusion_matrix``. Pass e.g. ``[0, 1]`` to
        always get a 2x2 matrix, even when one class is absent from both inputs.
    xlabel, ylabel : str
        Axis labels, applied to the same Axes as the heatmap.

    Returns
    -------
    matplotlib.axes.Axes
        The Axes the heatmap was drawn on.
    """
    cm = confusion_matrix(y, y_pred, labels=labels)

    # Work on the Axes seaborn actually drew on: `ax` may be None here, and the
    # pyplot-level plt.xlabel/plt.ylabel would label whichever Axes is current,
    # which in a subplot grid is generally not the one passed in.
    ax = sns.heatmap(cm, annot=True, cmap=cmap, fmt="d", ax=ax, cbar=False)

    ax.set_title(title, fontsize=titlesize)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    return ax

In [None]:
# Initializing Hydra configs.
# The checkpoint override is a raw string: Hydra's override grammar needs the
# literal backslash in "\=" to escape "=" inside the value, and in a plain
# Python string "\=" is an invalid escape sequence (DeprecationWarning, and a
# SyntaxWarning since Python 3.12). The resulting value is unchanged.
overrides = [
    # "data\=cancer_data"
    r"train_data.ckpt=/storage/bt20d204/runs/multiview/Bloom2013_09_01_2024/Bloom2013-1-epoch\=56_step\=6783.ckpt",  # noqa: E501
    # "test_data.name=Bloom2013",
    # "test_data.data_path=${paths.data_dir}/bloom2013_clf_3_pubchem.feather",
    "device=cpu",
]


# Compose the "eval" config found under ../configs/ with the overrides applied,
# then dump the resolved config as YAML for inspection.
with initialize(config_path="../configs/", version_base=None):
    cfg = compose(overrides=overrides, config_name="eval")

print(OmegaConf.to_yaml(cfg))

In [None]:
# Initializing the model from a checkpoint, mapped onto the configured device.
# NOTE(review): this loads a hardcoded PRISM checkpoint rather than the
# `train_data.ckpt` value passed in the Hydra overrides — confirm which one
# is intended.
checkpoint_path = "../../runs/cancer/PRISM_1Q94_15_01_2024/PRISM_1Q94-0-epoch=97_step=644840.ckpt"
model = NetMultiViewLightning.load_from_checkpoint(checkpoint_path, map_location=cfg.device)

model.eval()

In [None]:
# PRISM 19Q4 data module, prepared for the "test" stage only.
prism_data_path = "../../data/cancer/PRISM_19Q4/"
datamodule = CancerKFoldModule(path=prism_data_path)
datamodule.setup(stage="test")

In [None]:
# CPU-only trainer used for batch prediction.
trainer = Trainer(accelerator="cpu")

# NOTE(review): this assigns a DataLoader *instance* to the datamodule's
# `predict_dataloader` attribute, shadowing whatever that attribute was on this
# object. `predict_loader()` iterates this attribute directly, so the
# assignment is kept.
datamodule.predict_dataloader = torch.utils.data.DataLoader(
    datamodule.test_dataset,
    batch_size=256,
    num_workers=1,
    shuffle=False,  # iterate the test set in dataset order
    drop_last=False,  # keep the final partial batch
    # Pinned memory only speeds up host->GPU transfers. The trainer above runs
    # on CPU, so pinning is wasted work, and recent torch versions may warn
    # about it when no accelerator is available.
    pin_memory=False,
)


def predict_loader():
    """Yield the first element of every batch in ``datamodule.predict_dataloader``.

    This feeds ``trainer.predict`` only ``batch[0]`` of each batch — presumably
    the input features with the target dropped (TODO confirm the batch layout).
    """
    yield from (batch[0] for batch in datamodule.predict_dataloader)


# next(predict_loader())

In [None]:
# Batch inference over the input-only generator.
predictions = trainer.predict(model=model, dataloaders=predict_loader())

In [None]:
# Display the raw `trainer.predict` output (presumably one entry per batch).
# NOTE(review): this renders every batch's predictions; consider showing a
# slice (e.g. predictions[:1]) for large test sets.
predictions

In [None]:
# Train/test dataset names (used in figure titles) and the test data file path.
train, test = cfg.train_data.name, cfg.test_data.name
filename = cfg.test_data.data_path

In [None]:
# Load the test data. The `Strain` column is taken out of the frame, so it is
# not part of the feature matrix built later, and kept separately in `strain`.
data = pd.read_feather(filename)
# `drop(..., inplace=True)` returns None, so `strain` used to always be None.
# `pop` removes the column from `data` in the same way and returns it.
strain = data.pop("Strain")
# Columns whose names start with "Y".
variation_columns = [col for col in data.columns if col.startswith("Y")]
data.head()

In [None]:
# Defining metrics
# Binary classification metrics, moved to the configured device
# (constructed in the order Accuracy, AUROC, F1Score).
acc, auc, f1 = (
    metric_cls(task="binary").to(cfg.device) for metric_cls in (Accuracy, AUROC, F1Score)
)

In [None]:
# Per-condition evaluation: for each unique `Condition`, run the model on that
# condition's rows, print Accuracy / AUC / F1, keep the predicted probabilities
# in `pred_dict`, and draw a confusion matrix on its own subplot.
conditions = data["Condition"].unique()

# Size the grid from the number of conditions (8 columns, as many rows as
# needed). With a fixed 5x8 grid, `zip` stopped at 40 axes and silently
# skipped every further condition.
n_cols = 8
n_rows = max(1, -(-len(conditions) // n_cols))  # ceiling division, at least 1 row
fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(16, 2.4 * n_rows),  # 2.4 in per row: 12 in for 5 rows, as before
    sharex=True,
    sharey=True,
    tight_layout=True,
    squeeze=False,  # always a 2-D array, so .flatten() also works for one row
)

records = []
pred_dict = dict()

for key, ax in zip(conditions, axes.flatten()):
    subset = data.loc[data["Condition"] == key]
    X = (
        torch.tensor(subset.drop(columns=["Phenotype", "Condition"]).values)
        .float()
        .to(cfg.device)
    )
    # Integer targets: torchmetrics' binary AUROC rejects floating-point
    # targets during argument validation.
    y = torch.tensor(subset["Phenotype"].values).long().to(cfg.device)

    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        y_pred = torch.sigmoid(model(X)).reshape(-1)

    acc_score = acc(y_pred, y).item()
    auc_score = auc(y_pred, y).item()
    f1_score = f1(y_pred, y).item()

    print(f"{key}", end=": ")
    print(f"Accuracy: {acc_score:.2f} | AUC: {auc_score:.2f} | F1: {f1_score:.2f}")

    # .cpu() so the numpy conversion also works when cfg.device is a GPU.
    pred_dict[key] = y_pred.cpu().numpy()
    y_pred = np.rint(pred_dict[key])  # probabilities -> 0/1 labels for the matrix

    plot_confusion_matrix(y.cpu().numpy(), y_pred, title=key, ax=ax, titlesize=10)
    # Per-panel axis labels are replaced by the figure-level labels below.
    ax.set(xlabel="", ylabel="")

    records.append([key, acc_score, auc_score, f1_score])

# Hide grid cells that did not receive a condition.
for unused_ax in axes.flatten()[len(conditions):]:
    unused_ax.set_visible(False)

# Build the table once instead of growing it row by row.
results_df = pd.DataFrame(records, columns=["Compound", "Accuracy", "AUC", "F1"])

fig.supxlabel("Predicted")
fig.supylabel("True")
fig.suptitle(f"{train} | {test}");

In [None]:
# Give the metric columns a float dtype and the Compound column str values.
results_df = results_df.astype(
    {"F1": float, "AUC": float, "Accuracy": float, "Compound": str}
)

In [None]:
# Mean ± sample standard deviation (pandas default, ddof=1) of each metric
# across conditions, rounded to 2 decimals to match the per-condition printout
# (str() on the raw floats printed every digit).
metric_summary = results_df[["F1", "AUC", "Accuracy"]].apply(
    lambda col: f"{col.mean():.2f} ± {col.std():.2f}", axis=0
)
print(metric_summary)

In [None]:
# Condition-name lists used to flag conditions in the results plots.
# `bloom_chemicals` uses mixed-case spellings; `bloom_2013_chemicals` is all
# lower-case and is matched against lower-cased heatmap tick labels.
bloom_chemicals = (
    "diamide formamide MgCl2 CuSO4 etoh CoCl2 trehalose "
    "xylose raffinose lactate neomycin MnSO4 zeocin"
).split()

bloom_2013_chemicals = np.array(
    (
        "berbamine cocl2 diamide ethanol formamide hydroxyurea lactate lactose "
        "menadione mgcl2 indoleacetic_acid neomycin raffinose trehalose xylose zeocin"
    ).split(),
    dtype="<U19",
)


# Compounds in ascending-AUC order, plus one color per compound: "white" when
# the name is in `bloom_chemicals`, "yellow" otherwise.
# NOTE(review): `text_coloring` is not used anywhere else in this notebook, and
# this membership test is case-sensitive while the heatmap cell lower-cases its
# labels before comparing — confirm which scheme is intended.
yticklabels = results_df.set_index("Compound").sort_values(by="AUC").index
text_coloring = [
    "yellow" if condition not in bloom_chemicals else "white"
    for condition in yticklabels
]

In [None]:
fig2, ax = plt.subplots(figsize=(6, 12))

# Metric heatmap with one row per compound, rows sorted by ascending AUC.
sns.heatmap(
    results_df.set_index("Compound").sort_values("AUC"),
    cmap="viridis",
    annot=True,
    fmt=".2f",
    vmin=0,
    vmax=1,
    yticklabels=True,  # label every row so each compound can be highlighted below
    ax=ax,  # draw on the Axes created above rather than pyplot's current Axes
)
ax.set_title(f"{train} Model on {test}", size=18)

# Highlight compounds whose lower-cased name is in `bloom_2013_chemicals`.
for text in ax.get_yticklabels():
    if text.get_text().lower() in bloom_2013_chemicals:
        text.set(backgroundcolor="yellow")
    else:
        text.set(backgroundcolor="white")

# Save only after the tick labels are highlighted, so the file includes it.
# fig2.savefig(f"../baselines/mutation_only/{train} on {test}.png", dpi=300, transparent=True, bbox_inches='tight')