In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import seaborn as sns
import torch
from torchmetrics import MeanSquaredError, R2Score

from src.data import KFoldEncodeModule
from src.utils.plot_utils import predict, set_theme

# Apply the project's plotting theme (helper imported from src.utils.plot_utils;
# NOTE(review): its exact effect is not visible in this notebook).
set_theme()
# Show floats in pandas output using engineering notation with two decimals.
pd.set_eng_float_format(accuracy=2)

In [None]:
from src.lightning_model import Netlightning

# Checkpoint of the regression network that is evaluated in this notebook.
checkpoint_path = (
    "../runs/regression_bloom/bloom2015_reg/bloom2015_reg-2epoch=36-step=58460.ckpt"
)

# Rebuild the model from the checkpoint, passing the loss function name explicitly.
model = Netlightning.load_from_checkpoint(checkpoint_path, loss_function="mse")

# Put the module in evaluation mode before predicting
# (presumably a torch/Lightning module - confirm against src.lightning_model).
model.eval()

In [None]:
# Labels used in the figure title: the dataset named as the training set
# and the dataset the model is evaluated on.
train, test = "Bloom2015", "Bloom2013"

# Feather file holding the evaluation data.
filename = "../data/regression_data/bloom2013_regression.feather"

In [None]:
# Read the evaluation table and discard the "Strain" column in one step,
# so re-running this cell always starts from the file on disk.
data = pd.read_feather(filename).drop(columns=["Strain"])

# Names of every column whose label starts with "Y".
variation_columns = list(filter(lambda name: name.startswith("Y"), data.columns))

# Preview the first rows.
data.head()

In [None]:
from sklearn.model_selection import train_test_split

# Keep a fixed 20% subset of the loaded data for evaluation; the remaining 80%
# is discarded. random_state pins the shuffle so the subset is reproducible.
_, test_data = train_test_split(data, test_size=0.2, random_state=42)

In [None]:
# Score the model separately on every condition found in the evaluation subset.
# Outputs of this cell:
#   pred_dict  - condition -> 1-D numpy array of predictions (row order of test_data)
#   results_df - one row per condition with its MSE and R-squared
pred_dict = {}
mse = MeanSquaredError()
r2 = R2Score()

# Rows are gathered in a list and turned into a DataFrame once at the end;
# this avoids growing a frame row by row and lets pandas infer numeric dtypes
# for the metric columns.
records = []

for key in test_data["Condition"].unique():
    condition_test = test_data[test_data["Condition"] == key]

    # Every column except the condition label and the target is fed to the model.
    X = torch.tensor(condition_test.drop(columns=["Condition", "Phenotype"]).values).float()
    # Target as a column vector (presumably matching the model's output shape - confirm).
    y = torch.tensor(condition_test["Phenotype"].values).float().unsqueeze(1)

    y_pred = predict(model, X).to("cpu")

    # Calling a torchmetrics metric returns the value for the batch passed in,
    # so each number below describes the current condition only.
    mse_val = mse(y_pred, y).detach().item()
    r2_val = r2(y_pred, y).detach().item()

    print(f"{key}: MSE:  {mse_val:.2f}  |  R2: {r2_val:.2f}")

    pred_dict[key] = y_pred.reshape(-1).detach().numpy()
    records.append({"Compound": key, "MSE": mse_val, "R-squared": r2_val})

results_df = pd.DataFrame(records, columns=["Compound", "MSE", "R-squared"])

In [None]:
# Store compound names as plain strings.
results_df["Compound"] = results_df["Compound"].astype(str)

# Summarise each metric column across compounds as "<mean> ± <std>".
metric_summary = results_df[["MSE", "R-squared"]].apply(
    lambda column: " ± ".join(map(str, (column.mean(), column.std()))),
    axis=0,
)
print(metric_summary)

In [None]:
# Two side-by-side heatmaps (R-squared on the left, MSE on the right) that share
# one row order: compounds sorted by ascending R-squared.
fig2, ax = plt.subplots(1, 2, figsize=(8, 10), sharey=True, tight_layout=True)
r2_ax, mse_ax = ax

ranked = results_df.set_index("Compound").sort_values("R-squared")
heatmap_style = dict(cmap="viridis", annot=True, fmt=".2f")

sns.heatmap(ranked.drop(columns=["MSE"]), ax=r2_ax, **heatmap_style)
sns.heatmap(ranked.drop(columns=["R-squared"]), ax=mse_ax, **heatmap_style)

# Lower-case compound names to highlight on the y axis
# (presumably the conditions also present in Bloom 2013 - confirm).
bloom_2013_chemicals = np.array(
    [
        "berbamine",
        "cocl2",
        "diamide",
        "ethanol",
        "formamide",
        "hydroxyurea",
        "lactate",
        "lactose",
        "menadione",
        "mgcl2",
        "indoleacetic_acid",
        "neomycin",
        "raffinose",
        "trehalose",
        "xylose",
        "zeocin",
    ],
    dtype="<U19",
)

# Labels are lower-cased before the lookup; matches get a yellow background,
# every other label a white one.
for label in r2_ax.get_yticklabels():
    highlighted = label.get_text().lower() in bloom_2013_chemicals
    label.set(backgroundcolor="yellow" if highlighted else "white")

plt.suptitle(f"{train} Model on {test} - Regression", size=18)

In [None]:
# Condition inspected in this cell and in the histogram cells below.
condition = "CoCl2"

# True phenotype against model prediction for that condition.
fig, ax = plt.subplots(figsize=(8, 6), tight_layout=True)
true_values = test_data.loc[test_data["Condition"] == condition, "Phenotype"]
sns.scatterplot(x=true_values, y=pred_dict[condition], ax=ax)

# Replace the automatic axis labels with explicit ones.
ax.set_ylabel("Predicted")
ax.set_xlabel("True")

In [None]:
# Distribution of phenotype values for the selected condition over the full dataset.
phenotype_all = data.loc[data["Condition"] == condition, "Phenotype"]
hist_ax = sns.histplot(x=phenotype_all, bins=50, color="cornflowerblue")
hist_ax.set_title("Histogram of phenotype values")

In [None]:
# Same histogram, restricted to the evaluation subset.
phenotype_test = test_data.loc[test_data["Condition"] == condition, "Phenotype"]
hist_ax = sns.histplot(x=phenotype_test, bins=50, color="cornflowerblue")
hist_ax.set_title("Histogram of phenotype values - Test Set")