In [None]:
import pandas as pd
# Show every column when a DataFrame is displayed (None removes the column limit).
pd.set_option('display.max_columns', None)

from warnings import filterwarnings
# Silence all Python warnings for the rest of the session.
# NOTE(review): this also hides any pandas deprecation / chained-assignment
# warnings that later cells might raise.
filterwarnings('ignore')

In [None]:
# Load the raw summary table and keep only the rows that have an "Experiment" value.
# reset_index() is called without drop=True, so the previous row labels are kept
# as a new "index" column.
summary_path = "local_summary.csv"
df = pd.read_csv(summary_path)
has_experiment = df["Experiment"].notna()
df = df[has_experiment].reset_index()

In [None]:
# Number of outer ("higher") and inner ("lower") repetitions used to name the folds.
num_higher_reps = 3
num_lower_reps = 1

# Target feature whose metrics / predictions are read from the summary columns.
y_feature = "log_affinity_data"

# Metric key as it appears in the summary column names -> short name used in aggregated columns.
metric_names = {
    "r2": "r2",
    "root_mean_squared_error": "rmse",
    "pearson::correlation": "pearson",
}

# Fold identifiers per split: validation folds carry both indices, test folds only the outer one.
validation_folds = []
for higher_idx in range(num_higher_reps):
    for lower_idx in range(num_lower_reps):
        validation_folds.append(f"fold_{higher_idx:02}_{lower_idx:02}")
test_folds = [f"fold_{higher_idx:02}" for higher_idx in range(num_higher_reps)]

split_names = {
    "validation": validation_folds,
    "test": test_folds,
}

# metric_agg: "<split>::<short metric>" -> list of per-fold column names to aggregate.
# metric_cols: flat list of every per-fold metric column, in the same order.
metric_agg = {}
metric_cols = []

for split_name, split_fold_names in split_names.items():
    for metric, short_name in metric_names.items():
        fold_cols = [
            f"{fold_name}.{split_name}.{y_feature}.{metric}"
            for fold_name in split_fold_names
        ]
        metric_cols.extend(fold_cols)
        metric_agg[f"{split_name}::{short_name}"] = fold_cols

# Added after the loop so it does not produce extra columns: lets the short name
# "pearson" be looked up in metric_names as well.
metric_names["pearson"] = metric_names["pearson::correlation"]

In [None]:
# For every "<split>::<metric>" entry, add the row-wise mean and standard deviation
# over its per-fold columns. pandas' std uses ddof=1 by default.
for agg_name, fold_cols in metric_agg.items():
    fold_values = df[fold_cols]
    df[f"{agg_name}::mean"] = fold_values.mean(axis=1)
    df[f"{agg_name}::std"] = fold_values.std(axis=1)

In [None]:
import numpy as np


def _mean_list_length(row):
    """Return the mean length of the stringified lists held in one row."""
    # NOTE(review): eval executes whatever the CSV cell contains; if the file is
    # not fully trusted, ast.literal_eval would be the safer parser.
    return np.array([len(eval(cell)) for cell in row]).mean()


# Replace each "filtering" cell by the "value" entry of its first element.
# NOTE(review): the column is overwritten in place, so re-running this cell
# without re-loading df will presumably fail or give different results.
df["filtering"] = [eval(raw)[0]["value"] for raw in df["filtering"]]

# Mean number of train / test predictions across the test folds.
train_prediction_cols = [
    f"{fold}.train.{y_feature}.predictions" for fold in split_names["test"]
]
df["num_train_points"] = df[train_prediction_cols].apply(_mean_list_length, axis=1)

test_prediction_cols = [
    f"{fold}.test.{y_feature}.predictions" for fold in split_names["test"]
]
df["num_test_points"] = df[test_prediction_cols].apply(_mean_list_length, axis=1)

In [None]:
# Per-test-fold column names for predictions, references and uniprot ids.
test_fold_names = split_names["test"]
test_prefixes = [f"{fold}.test.{y_feature}" for fold in test_fold_names]

pred_cols = [f"{prefix}.predictions" for prefix in test_prefixes]
ref_cols = [f"{prefix}.references" for prefix in test_prefixes]
uniprot_cols = [f"{prefix}.uniprot_id" for prefix in test_prefixes]

In [None]:
# Source column name -> column name used in the aggregated table.
col_map = {
    "rev": "rev",
    "dataset.dvc_config.rev": "dataset",
    "filtering": "uniprot_id",
    "higher_split.presets.columns": "higher_split",
    "featurisation.featurisation_name": "featurisation_name",
    'train.model_config.name': "model_tag",
    "num_train_points": "num_train_points",
    "num_test_points": "num_test_points",
}

# Keep the mapped columns, the aggregated mean/std metrics and the raw test outputs.
mean_cols = [f"{metric}::mean" for metric in metric_agg]
std_cols = [f"{metric}::std" for metric in metric_agg]
selected_cols = [*col_map, *mean_cols, *std_cols, *pred_cols, *ref_cols, *uniprot_cols]
df_agg = df[selected_cols].rename(columns=col_map)


def _normalise_higher_split(raw):
    """Reduce a stringified list of split columns to a zero-padded split name.

    Takes the first entry of the evaluated list, keeps its first three
    underscore-separated parts and rewrites the third part as a two-digit integer.
    """
    # NOTE(review): eval on CSV content; ast.literal_eval would be safer for untrusted files.
    parts = eval(raw)[0].split("_")[:3]
    return f"{'_'.join(parts[:2])}_{int(parts[2]):02}"


df_agg["higher_split"] = [_normalise_higher_split(raw) for raw in df_agg["higher_split"]]

In [None]:
# Persist the aggregated table. to_csv writes the DataFrame index by default,
# so the first column of the file is an unnamed index column.
df_agg.to_csv("local_aggregated_summary.csv")

In [None]:
# Reload the aggregated table from the CSV written by the previous cell.
# That file stores the DataFrame index as its first (unnamed) column; index_col=0
# restores it as the index instead of loading it as a spurious "Unnamed: 0" data
# column, which would otherwise be carried into every results CSV written from df_agg.
df_agg = pd.read_csv("local_aggregated_summary.csv", index_col=0)

In [None]:
# Rows whose featurisation is anything other than "mw".
df_models = df_agg[df_agg["featurisation_name"] != "mw"]

# For each (uniprot_id, higher_split) group, pick the row label with the highest
# mean validation r2; groups for which idxmax yields NaN are dropped.
best_row_labels = (
    df_models
    .groupby(["uniprot_id", "higher_split"])["validation::r2::mean"]
    .idxmax()
    .dropna()
)

df_best = df_models.loc[best_row_labels].assign(model_type="ligand_only_2d")
df_best.to_csv("../../notebooks/pdbbind_results/data/strat_local_ligand_only_2d_results.csv")

In [None]:
import molflux.metrics
import numpy as np

met_suite = molflux.metrics.load_suite("regression")

# Short metric name used in the output columns -> key read from the suite's result.
suite_keys = {
    "r2": "r2",
    "rmse": "root_mean_squared_error",
    "pearson": "pearson::correlation",
}

# For each higher_split group: pool the test predictions / references of all selected
# rows per fold, score each fold, then summarise the per-fold scores.
overall_metrics = {}
for group, df_g in df_best.groupby("higher_split"):
    sub_metrics = {short: [] for short in suite_keys}

    for col in split_names["test"]:
        # Each cell holds a stringified list; evaluate and flatten across rows.
        # NOTE(review): eval on CSV content; ast.literal_eval would be safer for untrusted files.
        preds = [
            value
            for cell in df_g[f"{col}.test.{y_feature}.predictions"]
            for value in eval(cell)
        ]
        refs = [
            value
            for cell in df_g[f"{col}.test.{y_feature}.references"]
            for value in eval(cell)
        ]
        met_vals = met_suite.compute(predictions=preds, references=refs)
        for short, key in suite_keys.items():
            sub_metrics[short].append(met_vals[key])

    # NOTE(review): numpy's std uses ddof=0, whereas the row-wise "::std" columns
    # added to df earlier use pandas' default ddof=1; confirm this is intended.
    agg_mets = {}
    for short, fold_scores in sub_metrics.items():
        fold_scores = np.array(fold_scores)
        agg_mets[f"test::{short}::mean"] = fold_scores.mean()
        agg_mets[f"test::{short}::std"] = fold_scores.std()
    overall_metrics[group] = agg_mets

# One row per higher_split, one column per aggregated metric.
df_overall = (
    pd.DataFrame.from_dict(overall_metrics)
    .T
    .reset_index()
    .rename(columns={"index": "higher_split"})
)
df_overall["model_type"] = "ligand_only_2d"

df_overall.to_csv("../../notebooks/pdbbind_results/data/overall_local_ligand_only_2d_results.csv")

In [None]:
# Rows that use the "mw" featurisation.
df_baseline = df_agg[df_agg["featurisation_name"] == "mw"]

# For each (uniprot_id, higher_split) group, pick the row label with the highest
# mean validation r2; groups for which idxmax yields NaN are dropped.
best_mw_row_labels = (
    df_baseline
    .groupby(["uniprot_id", "higher_split"])["validation::r2::mean"]
    .idxmax()
    .dropna()
)

df_best_mw = df_baseline.loc[best_mw_row_labels].assign(model_type="local_mw")
df_best_mw.to_csv("../../notebooks/pdbbind_results/data/strat_local_mw_results.csv")

In [None]:
import molflux.metrics
import numpy as np

met_suite = molflux.metrics.load_suite("regression")

# Short metric name used in the output columns -> key read from the suite's result.
suite_keys = {
    "r2": "r2",
    "rmse": "root_mean_squared_error",
    "pearson": "pearson::correlation",
}

# For each higher_split group: pool the test predictions / references of all selected
# "mw" rows per fold, score each fold, then summarise the per-fold scores.
overall_metrics = {}
for group, df_g in df_best_mw.groupby("higher_split"):
    sub_metrics = {short: [] for short in suite_keys}

    for col in split_names["test"]:
        # Each cell holds a stringified list; evaluate and flatten across rows.
        # NOTE(review): eval on CSV content; ast.literal_eval would be safer for untrusted files.
        preds = [
            value
            for cell in df_g[f"{col}.test.{y_feature}.predictions"]
            for value in eval(cell)
        ]
        refs = [
            value
            for cell in df_g[f"{col}.test.{y_feature}.references"]
            for value in eval(cell)
        ]
        met_vals = met_suite.compute(predictions=preds, references=refs)
        for short, key in suite_keys.items():
            sub_metrics[short].append(met_vals[key])

    # NOTE(review): numpy's std uses ddof=0, whereas the row-wise "::std" columns
    # added to df earlier use pandas' default ddof=1; confirm this is intended.
    agg_mets = {}
    for short, fold_scores in sub_metrics.items():
        fold_scores = np.array(fold_scores)
        agg_mets[f"test::{short}::mean"] = fold_scores.mean()
        agg_mets[f"test::{short}::std"] = fold_scores.std()
    overall_metrics[group] = agg_mets

# One row per higher_split, one column per aggregated metric.
df_overall = (
    pd.DataFrame.from_dict(overall_metrics)
    .T
    .reset_index()
    .rename(columns={"index": "higher_split"})
)
df_overall["model_type"] = "local_mw"

df_overall.to_csv("../../notebooks/pdbbind_results/data/overall_local_mw_results.csv")