In [1]:
from collections import defaultdict
import json
from pathlib import Path
import sys

import pandas as pd

sys.path.insert(0, '../agenticadmet')
from eval import extract_preds, extract_refs, eval_admet

In [2]:
TARGET_COLUMNS = ['LogHLM', 'LogMLM', 'LogD', 'LogKSOL', 'LogMDR1-MDCKII']

In [3]:
pred_dir = '../output/asap/rnd_splits/'
models = ['chemprop', 'roberta']
preds = defaultdict(list)
for split_idx in range(5):
    for model_name in models:
        pred_path = Path(pred_dir) / model_name / 'run_0' / f'split_{split_idx}' / 'predictions.csv'
        preds[f'split_{split_idx}'].append(pd.read_csv(pred_path))

In [4]:
# average predictions across splits
for split_idx, split_preds in preds.items():
    for t in TARGET_COLUMNS:
        pred = split_preds[0].copy()
        pred[f"pred_{t}"] = (split_preds[0][f"pred_{t}"] + split_preds[1][f"pred_{t}"]) / 2

    train_preds = extract_preds(pred[pred["split"] == "train"])
    train_refs = extract_refs(pred[pred["split"] == "train"])
    val_preds = extract_preds(pred[pred["split"] == "val"])
    val_refs = extract_refs(pred[pred["split"] == "val"])

    metrics = eval_admet(val_preds, val_refs)
    print("\nVal metrics:")
    print(json.dumps(metrics, indent=2))


Val metrics:
{
  "HLM": {
    "mean_absolute_error": 0.35427863199510806,
    "r2": 0.2574464720345797
  },
  "KSOL": {
    "mean_absolute_error": 0.3865705027615289,
    "r2": 0.34739706183218844
  },
  "LogD": {
    "mean_absolute_error": 0.4518723898032787,
    "r2": 0.7818159279952948
  },
  "MLM": {
    "mean_absolute_error": 0.34520219464093466,
    "r2": 0.5266217719645775
  },
  "MDR1-MDCKII": {
    "mean_absolute_error": 0.22995053353849323,
    "r2": 0.3926929280893212
  },
  "aggregated": {
    "macro_mean_absolute_error": 0.35357485054786875,
    "macro_r2": 0.4611948323831923
  }
}

Val metrics:
{
  "HLM": {
    "mean_absolute_error": 0.3642703922592749,
    "r2": 0.3375760363807351
  },
  "KSOL": {
    "mean_absolute_error": 0.35402169402384404,
    "r2": 0.378212608923861
  },
  "LogD": {
    "mean_absolute_error": 0.5059657347727273,
    "r2": 0.6638158763836648
  },
  "MLM": {
    "mean_absolute_error": 0.442124383882978,
    "r2": 0.31202641557971655
  },
  "MDR1-MDC