In [12]:
from collections import defaultdict
import json
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.metrics import r2_score, mean_absolute_error

In [6]:
TARGET_COLUMNS = ['LogHLM', 'LogMLM', 'LogD', 'LogKSOL', 'LogMDR1-MDCKII']

In [5]:
pred_dir = '../output/asap/rnd_splits/'
models = ['chemprop', 'roberta']
preds = defaultdict(list)
for split_idx in range(5):
    for model_name in models:
        pred_path = Path(pred_dir) / model_name / 'run_0' / f'split_{split_idx}' / 'predictions.csv'
        preds[f'split_{split_idx}'].append(pd.read_csv(pred_path))

In [10]:
# copypaste from https://github.com/asapdiscovery/asap-polaris-blind-challenge-examples/blob/a613051bac57060f686d9993e201ecaa15e51009/evaulation.py
# with a log-transform fix according to this issue https://github.com/asapdiscovery/asap-polaris-blind-challenge-examples/issues/14

from collections import defaultdict
from typing import Tuple

def mask_nan(y_true, y_pred):
    mask = ~np.isnan(y_true)
    y_true = np.array(y_true)[mask]
    y_pred = np.array(y_pred)[mask]
    return y_true, y_pred

def eval_admet(preds: dict[str, list], refs: dict[str, list]) -> Tuple[dict[str, float], np.ndarray]:
    """
    Eval ADMET targets with MAE for pre-log10 transformed targets (LogD) and MALE  (MAE on log10 transformed dataset) on non-transformed data

    This provides a "relative" error metric that will not be as sensitive to the large outliers with huge errors. This is sometimes known as MALE.

    Parameters
    ----------
    preds : dict[str, list]
        Dictionary of predicted ADMET values.
    refs : dict[str, list]
        Dictionary of reference ADMET values.

    Returns
    -------
    dict[str, float]
        Returns a dictonary of summary statistics
    """
    keys = {
        "MLM",
        "HLM",
        "KSOL",
        "LogD",
        "MDR1-MDCKII",
    }
    # will be treated as is
    logscale_endpts = {"LogD"}

    collect = defaultdict(dict)

    for k in keys:
        if k not in preds.keys() or k not in refs.keys():
            raise ValueError("required key not present")

        ref, pred = mask_nan(refs[k], preds[k])

        if k in logscale_endpts:
            # already log10scaled
            mae = mean_absolute_error(ref, pred)
            r2 = r2_score(ref, pred)
        else:
            # clip to a detection limit
            # epsilon = 1e-8
            # pred = np.clip(pred, a_min=epsilon, a_max=None)
            # ref = np.clip(ref, a_min=epsilon, a_max=None)

            # transform both log10scale
            pred_log10s = np.log10(pred + 1.)
            ref_log10s = np.log10(ref + 1.)

            # compute MALE and R2 in log space
            mae = mean_absolute_error(ref_log10s, pred_log10s)
            r2 = r2_score(ref_log10s, pred_log10s)

        collect[k]["mean_absolute_error"] = mae
        collect[k]["r2"] = r2

    # compute macro average MAE
    macro_mae = np.mean([collect[k]["mean_absolute_error"] for k in keys])
    collect["aggregated"]["macro_mean_absolute_error"] = macro_mae

    # compute macro average R2
    macro_r2 = np.mean([collect[k]["r2"] for k in keys])
    collect["aggregated"]["macro_r2"] = macro_r2

    return collect

In [11]:
def extract_preds(preds: pd.DataFrame):
    preds_dict = {}
    for t in TARGET_COLUMNS:
        if t in ["LogHLM", "LogMLM", "LogKSOL", "LogMDR1-MDCKII"]:
            # transform back to non-log scale
            preds_dict[t[3:]] = np.power(10, preds.iloc[:, preds.columns.get_loc(f"pred_{t}")].values) - 1.
        else:
            preds_dict[t] = preds.iloc[:, preds.columns.get_loc(f"pred_{t}")].values
    
    return preds_dict

def extract_refs(refs: pd.DataFrame):
    refs_dict = {}
    for t in TARGET_COLUMNS:
        if t in ["LogHLM", "LogMLM", "LogKSOL", "LogMDR1-MDCKII"]:
            refs_dict[t[3:]] = refs.iloc[:, refs.columns.get_loc(t[3:])].values
        else:
            refs_dict[t] = refs.iloc[:, refs.columns.get_loc(t)].values
    
    return refs_dict

In [13]:
# average predictions across splits
for split_idx, split_preds in preds.items():
    for t in TARGET_COLUMNS:
        pred = split_preds[0].copy()
        pred[t] = (split_preds[0][t] + split_preds[1][t]) / 2

    train_preds = extract_preds(pred[pred["split"] == "train"])
    train_refs = extract_refs(pred[pred["split"] == "train"])
    val_preds = extract_preds(pred[pred["split"] == "val"])
    val_refs = extract_refs(pred[pred["split"] == "val"])

    metrics = eval_admet(val_preds, val_refs)
    print("\nVal metrics:")
    print(json.dumps(metrics, indent=2))


Val metrics:
{
  "LogD": {
    "mean_absolute_error": 0.46723858925373135,
    "r2": 0.7299866884942934
  },
  "MLM": {
    "mean_absolute_error": 0.40572426009085244,
    "r2": 0.33521354172613305
  },
  "MDR1-MDCKII": {
    "mean_absolute_error": 0.18577788307780504,
    "r2": 0.5648184088917596
  },
  "KSOL": {
    "mean_absolute_error": 0.2654457301867883,
    "r2": 0.5565578358026182
  },
  "HLM": {
    "mean_absolute_error": 0.27365844977556714,
    "r2": 0.2730263573939373
  },
  "aggregated": {
    "macro_mean_absolute_error": 0.31956898247694887,
    "macro_r2": 0.49192056646174825
  }
}

Val metrics:
{
  "LogD": {
    "mean_absolute_error": 0.39726614218571427,
    "r2": 0.8057381337306473
  },
  "MLM": {
    "mean_absolute_error": 0.3000724625310541,
    "r2": 0.6877236304349188
  },
  "MDR1-MDCKII": {
    "mean_absolute_error": 0.1942623224901187,
    "r2": 0.4221671113762564
  },
  "KSOL": {
    "mean_absolute_error": 0.3053094021621843,
    "r2": 0.5046082298758539
  },
