In [None]:
import molflux.datasets
import pandas as pd
import tempfile
from dvc.api import DVCFileSystem

In [None]:
# Load the experiment summary and keep only rows that have an "Experiment" value.
df = pd.read_csv("summary.csv")
df = df[df["Experiment"].notna()].reset_index()

# Build a lookup from PDB code to UniProt id out of the processed dataset.
# (The original call had a stray second dot after `molflux.datasets`, which is
# a SyntaxError and prevented this cell from running at all.)
dataset = molflux.datasets.load_dataset_from_store("../pdb_processing/data/dataset_processed.parquet")
pdb_code_to_uniprot = {p["pdb_code"]: p["uniprot_id"] for p in dataset}

In [None]:
import molflux.metrics
from tqdm.auto import tqdm
# Register tqdm's pandas integration so that `progress_apply` (used further
# down to run `add_metrics` over every row) is available on DataFrames.
tqdm.pandas()

# Metric suite shared by every call to `add_metrics`; the entries read from
# its output are "r2", "pearson::correlation" and "root_mean_squared_error".
met_suite = molflux.metrics.load_suite("regression")

def add_metrics(row):
    """Attach overall and per-UniProt test metrics to one summary row.

    The test predictions for the revision named in ``row["rev"]`` are
    downloaded through DVC, scored with the module-level ``met_suite``, and
    the scores are written back onto ``row``.

    Args:
        row: A row of the summary frame; its ``"rev"`` entry selects the
            repository revision to read results from.

    Returns:
        The same ``row`` with ``"test::r2"``, ``"test::pearson"`` and
        ``"test::rmse"`` added, plus ``"<uniprot_id>::r2"`` /
        ``"::pearson"`` / ``"::rmse"`` for every UniProt id found in the
        results.
    """

    def _record(prefix, frame):
        # Score one slice of the predictions and store it under `prefix`.
        scores = met_suite.compute(
            references=frame["pk"].values,
            predictions=frame["pred"].values,
        )
        row[f"{prefix}::r2"] = scores["r2"]
        row[f"{prefix}::pearson"] = scores["pearson::correlation"]
        row[f"{prefix}::rmse"] = scores["root_mean_squared_error"]

    repo_fs = DVCFileSystem(
        "git@github.com:Exscientia/low-sim-pdbbind.git",
        rev=row["rev"],
        subrepos=True,
    )

    # The results file is fetched into a throwaway file and parsed while that
    # file still exists; it is removed when the `with` block exits.
    with tempfile.NamedTemporaryFile(suffix=".csv") as scratch:
        repo_fs.download(
            "pipelines/durant_models/data/results/trained_model_test.csv",
            scratch.name,
        )
        predictions = pd.read_csv(str(scratch.name))

    # Tag every prediction with the UniProt id of its PDB entry ("key" column).
    predictions["uniprot_id"] = predictions["key"].map(pdb_code_to_uniprot)

    # Whole test set first, then one set of metrics per UniProt id.
    _record("test", predictions)
    for uniprot_id, subset in predictions.groupby("uniprot_id"):
        _record(uniprot_id, subset)

    return row

In [None]:
# Remember which columns exist before the metrics are attached so the added
# ones can be identified afterwards.
cols_before = set(df.columns)
df = df.progress_apply(add_metrics, axis=1)

# Keep the added columns as a list in DataFrame column order. A set difference
# would give an arbitrary order for strings that can change between
# interpreter runs, which made the column order of the exported CSVs
# non-reproducible.
new_cols = [col for col in df.columns if col not in cols_before]

In [None]:
# Columns carried over from the summary frame: source name -> name used in
# the aggregated table.
col_map = {
    "rev": "rev",
    "model_repo": "model_tag",
    "higher_split.presets.column": "higher_split",
}
metric_cols = list(new_cols)

df_agg = df[[*col_map, *metric_cols]].rename(columns=col_map)

# The last 6 characters of the raw split label go to "fold"; the label itself
# then loses its last 7 characters.
df_agg["fold"] = df_agg["higher_split"].str[-6:]
df_agg["higher_split"] = df_agg["higher_split"].str[:-7]

# Rewrite each remaining label so its third "_"-separated field is an integer
# zero-padded to two digits; anything after the third field is dropped.
split_renames = {}
for label in set(df_agg["higher_split"]):
    fields = label.split("_")
    split_renames[label] = f"{'_'.join(fields[:2])}_{int(fields[2]):02}"
df_agg["higher_split"] = df_agg["higher_split"].map(split_renames)

# Mean and standard deviation of every metric per (model, split); keep the
# first revision of each group.
agg_spec = {metric: ["mean", "std"] for metric in metric_cols}
agg_spec["rev"] = "first"
df_agg = df_agg.groupby(["model_tag", "higher_split"]).agg(agg_spec).reset_index()

# Flatten the two-level column index to "<name>::<stat>"; columns with an
# empty second level (the group keys) keep their plain name.
flat_names = []
for col in df_agg.columns.values:
    flat_names.append("::".join(col).strip() if col[1] else col[0])
df_agg.columns = flat_names

In [None]:
# Persist the aggregated table. The index is only the default integer index
# left by reset_index(); writing it would bring it back as a stray
# "Unnamed: 0" column when the file is re-read with pd.read_csv, so it is
# omitted here.
df_agg.to_csv("aggregated_summary.csv", index=False)

In [None]:
# Reload the aggregated table from the CSV written by the previous cell, so
# the cells below can be run from the file on disk.
df_agg = pd.read_csv("aggregated_summary.csv")

In [None]:
# Expose the model tag under the additional name "model_type", then export
# the overall results table.
df_agg = df_agg.assign(model_type=df_agg["model_tag"])
df_agg.to_csv("../../notebooks/pdbbind_results/data/overall_global_durant_results.csv")

In [None]:
# UniProt ids whose per-target metric columns are pulled out of the
# aggregated table further down.
uniprot_ids = [
    "O60885", "P00734", "P00760", "P00918",
    "P07900", "P24941", "P56817", "Q9H2K2",
]

In [None]:
# Reshape the per-UniProt metric columns into long form: one block of rows per
# UniProt id, with that target's columns renamed to the generic "test::..."
# names and the id stored in a "uniprot_id" column.
stat_names = ("mean", "std")
metric_names = ("pearson", "r2", "rmse")

list_dfs = []
for uniprot in uniprot_ids:
    # Target-specific column -> generic column, all means first, then all stds.
    to_generic = {
        f"{uniprot}::{metric}::{stat}": f"test::{metric}::{stat}"
        for stat in stat_names
        for metric in metric_names
    }
    wanted = ["higher_split", *to_generic, "model_tag"]
    met_df = df_agg[wanted].rename(columns=to_generic)
    met_df["uniprot_id"] = uniprot
    list_dfs.append(met_df)

df_strat_best = pd.concat(list_dfs)
df_strat_best = df_strat_best.assign(model_type=df_strat_best["model_tag"])
df_strat_best.to_csv("../../notebooks/pdbbind_results/data/strat_global_durant_results.csv")