In [None]:
from warnings import filterwarnings

import pandas as pd

# Show every column when a DataFrame is rendered (no truncation).
pd.set_option("display.max_columns", None)

# Silence all Python warnings for the rest of the session.
filterwarnings("ignore")

In [None]:
# Load the experiment summary and keep only rows that have an "Experiment" value.
# `reset_index()` (without drop=True) keeps the previous row labels as an
# "index" column.
df = pd.read_csv("summary.csv")
df = df[df["Experiment"].notna()].reset_index()

In [None]:
# Number of fold repetitions encoded in the summary column names
# (presumably outer / inner split repetitions -- confirm against the pipeline).
num_higher_reps = 3
num_lower_reps = 1

# Target whose metrics are aggregated.
y_feature = "log_affinity_data"

# Metric name as it appears in the summary columns -> short alias used in the
# aggregated column names.
metric_names = {
    "r2": "r2",
    "root_mean_squared_error": "rmse",
    "pearson::correlation": "pearson",
}

# Fold identifiers per split: validation folds carry both indices
# ("fold_HH_LL"), test folds only the higher one ("fold_HH").
split_names = {
    "validation": [
        f"fold_{outer:02}_{inner:02}"
        for outer in range(num_higher_reps)
        for inner in range(num_lower_reps)
    ],
    "test": [f"fold_{outer:02}" for outer in range(num_higher_reps)],
}

# metric_agg: "<split>::<alias>" -> list of per-fold summary columns.
# metric_cols: flat list of every per-fold metric column, in generation order.
metric_agg = {}
metric_cols = []

for split, folds in split_names.items():
    for long_name, alias in metric_names.items():
        fold_cols = [f"{fold}.{split}.{y_feature}.{long_name}" for fold in folds]
        metric_cols.extend(fold_cols)
        metric_agg[f"{split}::{alias}"] = fold_cols

# Added only after the loop above, so it does not generate extra columns:
# lets the short alias "pearson" be looked up in metric_names as well.
metric_names["pearson"] = metric_names["pearson::correlation"]

In [None]:
# For every "<split>::<metric>" entry, add the row-wise mean and standard
# deviation across its per-fold columns (pandas' default std, i.e. ddof=1).
for agg_name, fold_cols in metric_agg.items():
    fold_values = df[fold_cols]
    df[f"{agg_name}::mean"] = fold_values.mean(axis=1)
    df[f"{agg_name}::std"] = fold_values.std(axis=1)

In [None]:
import numpy as np
from tqdm.auto import tqdm

import molflux.datasets
import molflux.metrics

# Register the tqdm-backed `progress_apply` methods on pandas objects.
tqdm.pandas()

# Metric suite used to score each stratification group.
met_suite = molflux.metrics.load_suite("regression")

def add_stratified_metrics(row, strat_name="uniprot_id"):
    """Add per-group test metrics, aggregated over the test folds, to one row.

    For each fold in ``split_names["test"]`` the references, predictions and
    group labels stored as strings on ``row`` are parsed, samples are grouped
    by label, and ``met_suite`` is computed per group. For every group and
    metric the mean and standard deviation across folds are then written back
    to ``row`` as ``"{group}::{metric}::mean"`` / ``"{group}::{metric}::std"``
    with ``metric`` in ``{"r2", "rmse", "pearson"}``.

    A group is aggregated over the folds in which it actually occurs, so a
    group that is missing from some folds neither raises nor gets dropped.

    Args:
        row: One row of the summary frame; must contain the
            ``"{fold}.test.{y_feature}.references"``, ``".predictions"`` and
            ``".{strat_name}"`` cells for every test fold.
        strat_name: Column suffix holding the per-sample group labels.

    Returns:
        The same ``row`` object with the aggregated entries added. If there
        are no test folds the row is returned unchanged.
    """
    fold_metrics = {}

    for fold in split_names["test"]:
        prefix = f"{fold}.test.{y_feature}"

        # NOTE(review): eval() executes whatever expression is stored in these
        # cells. Only run this on summary files you trust; if the cells are
        # plain Python literals, ast.literal_eval would be a safer parser.
        refs = eval(row[f"{prefix}.references"])
        preds = eval(row[f"{prefix}.predictions"])
        strats = eval(row[f"{prefix}.{strat_name}"])

        df_tmp = pd.DataFrame({"strats": strats, "refs": refs, "preds": preds})

        per_group = {}
        for group, df_group in df_tmp.groupby("strats"):
            mets = met_suite.compute(
                references=df_group["refs"].tolist(),
                predictions=df_group["preds"].tolist(),
            )
            per_group[group] = {
                "r2": mets["r2"],
                "rmse": mets["root_mean_squared_error"],
                "pearson": mets["pearson::correlation"],
            }

        fold_metrics[fold] = per_group

    # Union of the groups seen in any fold, in first-seen order. Iterating only
    # the last fold's groups would drop groups absent from it and fail on
    # groups absent from an earlier fold.
    all_groups = dict.fromkeys(
        group for per_group in fold_metrics.values() for group in per_group
    )

    for group in all_groups:
        available = [
            per_group[group]
            for per_group in fold_metrics.values()
            if group in per_group
        ]
        for met in available[0]:
            met_array = np.array([fold_result[met] for fold_result in available])
            row[f"{group}::{met}::mean"] = met_array.mean()
            # NOTE(review): ndarray.std() defaults to ddof=0, whereas the
            # overall "<split>::<metric>::std" columns come from pandas'
            # DataFrame.std (ddof=1). Confirm which convention is intended.
            row[f"{group}::{met}::std"] = met_array.std()

    return row

In [None]:
# Remember the existing columns so the ones added by add_stratified_metrics
# can be identified afterwards.
# NOTE: `df` is overwritten here, so re-running this cell without re-running
# the cells above yields an empty `strat_cols`.
cols_before = set(df.columns)
df = df.progress_apply(add_stratified_metrics, axis=1)
strat_cols = set(df.columns).difference(cols_before)

In [None]:
# Raw summary column -> short column name used in the aggregated table.
col_map = {
    "rev": "rev",
    "dataset.dvc_config.rev": "dataset",
    "higher_split.presets.columns": "higher_split",
    "featurisation.which_hydrogens": "which_hydrogens",
    "featurisation.featurisation_name": "featurisation_name",
    "train.model_config.config.tag": "model_tag",
    "train.model_config.config.pooling_head": "multi_graph",
    # 'train.model_config.config.jitter': "jitter",
    # 'train.model_config.config.y_graph_scalars_loss_config.name': "loss_func",
}

# All "::mean" columns first, then all "::std" columns.
agg_stat_cols = [
    f"{metric}::{stat}" for stat in ("mean", "std") for metric in metric_agg
]

# `strat_cols` is a set of strings, whose iteration order changes between
# kernel sessions; sort it so the column order of the exported CSVs is
# reproducible.
df_agg = df[list(col_map) + agg_stat_cols + sorted(strat_cols)].rename(columns=col_map)

# A run is flagged as pre-trained unless its raw tag contains "plain".
df_agg["pre_trained"] = ~df_agg["model_tag"].str.contains("plain")

# Strip the literal "multi_graph_" prefix, then translate to display labels.
# Tags that are not listed in the mapping become NaN.
df_agg["model_tag"] = df_agg["model_tag"].str.replace("multi_graph_", "", regex=False)
df_agg["model_tag"] = df_agg["model_tag"].map(
    {
        "qm_egnn_model_two_stage": "EGNN_QM",
        "diffusion_egnn_model_two_stage": "EGNN_DIFF",
        "plain_egnn_model": "EGNN",
        # "qm_egnn_model_two_stage_contact_map": "EGNN_QM_CM",
        # "diffusion_egnn_model_two_stage_contact_map": "EGNN_DIFF_CM",
        # "plain_egnn_model_contact_map": "EGNN_CM",
    }
)

# Rows using the pooling head are "multi"; everything else is "single".
df_agg["multi_graph"] = (
    df_agg["multi_graph"]
    .map({"InvariantLigandPocketPoolingHead": "multi"})
    .fillna("single")
)


def _normalise_split_name(raw):
    """Return the first three "_"-separated parts of the first split column,
    with the third part zero-padded to two digits (e.g. "a_b_3_x" -> "a_b_03").
    """
    # NOTE(review): eval() executes the cell text; only use on trusted
    # summaries (ast.literal_eval is safer if the cell is a plain list literal).
    parts = eval(raw)[0].split("_")[:3]
    return f"{'_'.join(parts[:2])}_{int(parts[2]):02}"


df_agg["higher_split"] = [_normalise_split_name(x) for x in df_agg["higher_split"]]
# df_agg["jitter"] = df_agg["jitter"].fillna(0.0)

In [None]:
# Display the aggregated table for visual inspection.
df_agg

In [None]:
# Persist the aggregated table; the row index is written as the first,
# unnamed CSV column.
df_agg.to_csv("aggregated_summary.csv", index=True)

In [None]:
# Reload the aggregated table from disk. The file was written with its row
# index as the first column, so read that column back as the index instead of
# letting it become a spurious "Unnamed: 0" data column that would propagate
# into the exported result files.
df_agg = pd.read_csv("aggregated_summary.csv", index_col=0)

In [None]:
# For each (model, graph mode, split) combination keep the run with the
# highest mean validation r2.
group_keys = ["model_tag", "multi_graph", "higher_split"]
best_idx = df_agg.groupby(group_keys)["validation::r2::mean"].idxmax()
df_best = df_agg.loc[best_idx]

# Combined label, e.g. "<model_tag>_<multi_graph>".
df_best["model_type"] = df_best["model_tag"] + "_" + df_best["multi_graph"]
df_best.to_csv("../../notebooks/pdbbind_results/data/overall_ligand_pocket_3d_results.csv")

In [None]:
# Group identifiers whose stratified metrics are exported below.
uniprot_ids = [
    "O60885", "P00734", "P00760", "P00918",
    "P07900", "P24941", "P56817", "Q9H2K2",
]

In [None]:
# Build a long-format table: one slice of `df_best` per identifier, with that
# identifier's "<id>::<metric>::<stat>" columns renamed to the generic
# "test::<metric>::<stat>" names, tagged with a "uniprot_id" column.
metric_stat_pairs = [
    (met, stat) for stat in ("mean", "std") for met in ("pearson", "r2", "rmse")
]
trailing_cols = ["featurisation_name", "model_tag", "which_hydrogens", "multi_graph"]

list_dfs = []
for uniprot in uniprot_ids:
    rename_map = {
        f"{uniprot}::{met}::{stat}": f"test::{met}::{stat}"
        for met, stat in metric_stat_pairs
    }
    selected_cols = ["higher_split", *rename_map, *trailing_cols]
    met_df = df_best[selected_cols].rename(columns=rename_map)
    met_df["uniprot_id"] = uniprot
    list_dfs.append(met_df)

df_strat_best = pd.concat(list_dfs)
df_strat_best["model_type"] = df_strat_best["model_tag"] + "_" + df_strat_best["multi_graph"]
df_strat_best.to_csv("../../notebooks/pdbbind_results/data/strat_ligand_pocket_3d_results.csv")