In [None]:
import pathlib

import pandas as pd
import polaris as po

from matplotlib import pyplot as plt
from polaris.utils.types import TargetType

In [None]:
# Seeds of the repeated runs. Each seed shows up in the results CSV column
# names as "<benchmark>_seed<seed>_<split>", and the per-seed curves are
# averaged when the statistics are computed further down.
RNG_SEEDS = (42, 117, 709, 1701, 9001)

# Benchmark identifiers handed to `po.load_benchmark`. Commented-out entries are
# excluded from the analysis. Every enabled entry must have matching per-seed
# columns in each results CSV, because the statistics cell indexes them by name.
POLARIS_BENCHMARKS = (
    # "polaris/pkis2-ret-wt-cls-v2",
    # "polaris/pkis2-ret-wt-reg-v2",
    "polaris/pkis2-kit-wt-cls-v2",
    # "polaris/pkis2-kit-wt-reg-v2",
    "polaris/pkis2-egfr-wt-reg-v2",
    "polaris/adme-fang-solu-1",
    "polaris/adme-fang-rppb-1",
    # "polaris/adme-fang-hppb-1",
    # "polaris/adme-fang-perm-1",
    # "polaris/adme-fang-rclint-1",
    # "polaris/adme-fang-hclint-1",
    # "tdcommons/lipophilicity-astrazeneca",
    # "tdcommons/ppbr-az",
    "tdcommons/clearance-hepatocyte-az",
    # "tdcommons/cyp2d6-substrate-carbonmangels",
    # "tdcommons/half-life-obach",
    # "tdcommons/cyp2c9-substrate-carbonmangels",
    # "tdcommons/clearance-microsome-az",
    "tdcommons/dili",
    "tdcommons/bioavailability-ma",
    # "tdcommons/vdss-lombardo",
    # "tdcommons/cyp3a4-substrate-carbonmangels",
    "tdcommons/pgp-broccatelli",
    "tdcommons/caco2-wang",
    "tdcommons/herg",
    "tdcommons/bbb-martins",
    # "tdcommons/ames",
    "tdcommons/ld50-zhu",
)

In [None]:
# Load each selected benchmark and sort it into a regression or classification
# bucket. The task type is read from the first target column only.
regression_benchmarks, classification_benchmarks = {}, {}
for benchmark_id in POLARIS_BENCHMARKS:
    loaded = po.load_benchmark(benchmark_id)
    first_target = list(loaded.target_cols)[0]
    is_regression = loaded.target_types[first_target] == TargetType.REGRESSION
    bucket = regression_benchmarks if is_regression else classification_benchmarks
    bucket[benchmark_id] = loaded

# Combined mapping: regression benchmarks first, then classification ones.
benchmarks = dict(regression_benchmarks)
benchmarks.update(classification_benchmarks)

In [None]:
# Results are read from CSV files that sit in the current working directory.
CWD = pathlib.Path.cwd()

# Maps each case label (used in the plot legend) to its results CSV.
RESULTS_FILES = {
    label: CWD / filename
    for label, filename in (
        ("Default", "default.csv"),
        ("Delayed", "delayed.csv"),
        ("Frozen", "frozen.csv"),
        ("Dropout", "dropout.csv"),
    )
}

In [None]:
# For every case, collapse the per-seed loss curves of each benchmark/split
# into a row-wise mean and standard deviation across seeds.
statistics = {}
for case_label, csv_path in RESULTS_FILES.items():
    raw_results = pd.read_csv(csv_path)
    summary_columns = {}
    for benchmark_id in benchmarks:
        # Column names use the benchmark id with every non-alphanumeric
        # character replaced by an underscore.
        column_prefix = "".join(ch if ch.isalnum() else "_" for ch in benchmark_id)
        for split in ("train", "val"):
            seed_columns = [
                f"{column_prefix}_seed{seed}_{split}" for seed in RNG_SEEDS
            ]
            per_seed = raw_results.loc[:, seed_columns]
            summary_columns[f"{column_prefix}_{split}_mean"] = per_seed.mean(axis=1)
            summary_columns[f"{column_prefix}_{split}_std"] = per_seed.std(axis=1)
    statistics[case_label] = pd.DataFrame(summary_columns)

Learning curves for each model and case

In [None]:
# Grid of learning-curve panels: one panel per benchmark, `ncols` panels per row.
num_benchmarks = len(benchmarks)
ncols = 3
nrows = (num_benchmarks + ncols - 1) // ncols  # ceiling division
# squeeze=False always yields a 2-D array of Axes, so flatten() is safe for
# any grid shape (including a single row or a single column).
fig, axes = plt.subplots(
    nrows, ncols, figsize=(16, 5 * nrows), squeeze=False
)
axes = axes.flatten()

# (split key in the statistics columns, line style, wording used in the legend)
split_styles = [("val", "-", "validation"), ("train", "--", "training")]

for bm_index, (bm_id, benchmark) in enumerate(benchmarks.items()):
    target_cols = list(benchmark.target_cols)
    task_type = benchmark.target_types[target_cols[0]]
    # Only the training split is needed, for the size shown in the panel title.
    train, _test = benchmark.get_train_test_split()
    train_set_size = len(train)

    # Same sanitisation as used for the statistics column names.
    prefix = "".join(c if c.isalnum() else "_" for c in bm_id)
    ax = axes[bm_index]
    for case_index, (case, case_stats) in enumerate(statistics.items()):
        # One colour per case; line style distinguishes validation from training.
        color = f"C{case_index}"
        for split, linestyle, split_label in split_styles:
            ax.plot(
                case_stats.index,
                case_stats[f"{prefix}_{split}_mean"],
                color=color,
                linestyle=linestyle,
                label=f"{case}, {split_label}",
                linewidth=3.0,
            )
    ax.set_title(f"{bm_id}\n{task_type.value} ({train_set_size})", fontsize=18)
    # The last `ncols` panels are the lowest visible panel of their column.
    if bm_index >= num_benchmarks - ncols:
        ax.set_xlabel("Epoch", fontsize=18)
    # Only the left-most panel of each row carries the y-axis label.
    if bm_index % ncols == 0:
        ax.set_ylabel("Scaled Loss", fontsize=18)
    ax.set_ylim(bottom=0, top=1)
    ax.tick_params(axis="both", which="major", width=2.0, length=8)
    ax.tick_params(axis="both", which="minor", width=1.5, length=4)

# Hide the unused panels of an incomplete last row.
for extra_ax in axes[num_benchmarks:]:
    extra_ax.axis("off")

# Every panel plots the same case/split lines, so the first panel's handles
# serve as a single shared legend placed below the grid.
handles, labels = axes[0].get_legend_handles_labels()
legend = fig.legend(
    ncol=4,
    frameon=False,
    handles=handles,
    labels=labels,
    loc="upper center",
    bbox_to_anchor=(0.5, 0),
    fontsize=20,
)

fig.tight_layout()

# bbox_inches="tight" keeps the legend, which sits outside the axes grid,
# inside the saved image.
for ext in ["png", "eps"]:
    fig.savefig(f"fine_tuning.{ext}", dpi=600, bbox_inches="tight")

plt.show()