This notebook requires the TinyMol dataset, which is downloaded by the following cell.

In [None]:
# Run the repository's TinyMol download script in a subprocess (IPython shell escape).
# The script path is relative, so it is resolved against the kernel's working directory.
!python ../../scripts/download_tinymol_dataset.py

In [None]:
import pandas as pd
import numpy as np
from uncertainties import ufloat_fromstr
import uncertainties.unumpy as unp

import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerTuple
import matplotlib.image as mpimg
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

from oneqmc.analysis.plot import set_defaults
from oneqmc.analysis import colours

# Project plotting defaults (imported from oneqmc.analysis.plot). Called before the
# rcParams override at the end of this cell -- presumably so that the override wins;
# confirm against set_defaults.
set_defaults()

# Colour tuples of matplotlib's "tab20b" colormap; indexed below for the baselines.
tab20b = plt.get_cmap("tab20b").colors

# Colour per method key. The plotting cell looks colours up with the same strings it
# uses to select runs (e.g. the entries of `experiments` and the baseline names).
# NOTE(review): the two "joint" keys are not referenced by any cell visible here.
colors = {
    "psiformer_singlepoint": colours.YELLOW,
    "orbformer_scratch": colours.TEAL,
    "orbformer_finetune_tm_512k": colours.PURPLE,
    "orbformer_finetune_oc_400k": colours.DARKTEAL,
    "scherbela_chem-pretrain": tab20b[4 + 8],
    "scherbela_hf-pretrain": tab20b[5 + 8],
    "gao_scratch": tab20b[6 + 12],
    "gao_finetune": tab20b[7 + 12],
    "orbformer_finetune_joint_tm_512k": colours.MIDTEAL,
    "orbformer_scratch_joint": colours.GREY,
}

# Base figure width; it is passed to `figsize` (which matplotlib reads in inches).
TEXTWIDTH = 7.08
plt.rcParams["font.size"] = 10

In [None]:
def fit(x, y):
    """Fit a power law to ``(x, y)`` by a straight-line fit in log-log space.

    A degree-1 polynomial ``log(y) = a * log(x) + b`` is fitted with
    ``np.polyfit`` and evaluated on 100 logarithmically spaced points between
    ``x[0]`` and ``x[-1]``.

    Parameters
    ----------
    x, y : sequences of numbers of equal length. Logarithms are taken of both,
        so the values are expected to be strictly positive.

    Returns
    -------
    (x_out, y_out) : the 100 evaluation points and the fitted values at them.
    """
    slope, intercept = np.polyfit(np.log(x), np.log(y), 1)
    log10_start = np.log10(x[0])
    log10_stop = np.log10(x[-1])
    grid = np.logspace(log10_start, log10_stop, 100)
    fitted = np.exp(slope * np.log(grid) + intercept)
    return grid, fitted

In [None]:
# Set to True to have the plotting cell write the figure to "01_tinymol.pdf".
save_figures = False

# Data

In [None]:
# Load the result tables.

to_datadir = "../../experiment_results/03_tinymol"

energy_index_cols = ["Ansatz", "System", "Fine-tune step", "Geometry"]
deeperwin_csv_path = f"{to_datadir}/references/tinymol_deeperwin.csv"


def _read_energy_table(csv_path):
    """Read a results CSV, parsing column "E" with ufloat_fromstr, indexed by run keys."""
    table = pd.read_csv(csv_path, converters={"E": ufloat_fromstr})
    return table.set_index(energy_index_cols)


# Tables whose "E" column is parsed into uncertainties.ufloat objects.
df = _read_energy_table(f"{to_datadir}/processed_data.csv")
df_gao = _read_energy_table(f"{to_datadir}/references/tinymol_gao.csv")

# DeepErwin reference table (semicolon-separated), indexed by its run descriptors.
scherbela_all = pd.read_csv(deeperwin_csv_path, delimiter=";").set_index(
    [
        "molecule",
        "method",
        "n_pretrain_variational",
        "epoch",
        "reuse_from",
        "geom_comment",
    ]
)
# Keep only rows whose geometry comment mentions "rot_dist_5" ...
has_rot_dist_5 = scherbela_all.index.get_level_values("geom_comment").str.contains(
    "rot_dist_5"
)
scherbela_rot = scherbela_all[has_rot_dist_5]
# ... and that belong to one of the two experiments used in the comparison.
# (The name `df_schebela` is kept as spelled because other cells refer to it.)
df_schebela = scherbela_rot[
    scherbela_rot.experiment.isin(
        [
            "2023-03-01_gao_shared_TinyMol",
            "2023-03-06_tinymol_v10_ablation_n_pretrain",
        ]
    )
]

# The same file, unfiltered and indexed differently, is used to look up reference
# energies by experiment name.
df_ref_cc = pd.read_csv(deeperwin_csv_path, delimiter=";").set_index(
    ["molecule", "experiment", "geom_comment"]
)

In [None]:
# Molecules shown in the panels annotated "in dist" (systems_id) and
# "out of dist" (systems_ood) in the plotting cell.
systems_id = "CNH C2H4 COH2".split()
systems_ood = "C3H4 CN2H2 CNOH CO2".split()

# Fine-tuning step counts at which energies are available.
finetune_steps = np.array((0, 250, 1000, 4000, 8000, 16000, 32000))
# String form of each step count divided by ten ("0", "25", ..., "3200").
finetune_steps_str = [str(n_steps // 10) for n_steps in finetune_steps.tolist()]

# "Ansatz" keys of the runs drawn in the left-hand panels, in plotting order.
experiments = [
    "psiformer_singlepoint",
    "orbformer_finetune_tm_512k",
    "orbformer_finetune_oc_400k",
    "orbformer_scratch",
]

# Combined plot

In [None]:
def rel_error(energies, ref):
    """Mean absolute deviation between mean-centred energies and a mean-centred reference.

    Both inputs are centred by subtracting their mean over the last axis, so
    only differences along that axis are compared. The centred reference gets a
    new axis inserted before its last one so that it broadcasts against the
    second-to-last axis of ``energies``.

    Parameters
    ----------
    energies : array of shape (..., k, n)
    ref : array of shape (..., n)

    Returns
    -------
    Array of shape (..., k): the absolute deviation averaged over the last axis.
    """

    def centred(values):
        return values - values.mean(-1, keepdims=True)

    deviation = centred(energies) - centred(ref)[..., None, :]
    return abs(deviation).mean(-1)

In [None]:
# "Ansatz" keys of the fine-tuned runs used for the pretraining-scaling panels.
# Each key ends in "<N>k"; a later cell parses that suffix into a step count and
# uses the "tm" / "oc" tag to split the runs into two series.
experiments_scaling = [
    f"orbformer_finetune_{tag}_{kilo_steps}k"
    for tag, checkpoints in (
        ("tm", (32, 64, 128, 256, 512)),
        ("oc", (40, 100, 200, 400)),
    )
    for kilo_steps in checkpoints
]

In [None]:
def _scaling_runs(tag):
    """Return the entries of `experiments_scaling` whose name contains `tag`."""
    return [exp for exp in experiments_scaling if tag in exp]


def _pretrain_steps(tag, multiplier=1):
    """Parse the trailing "<N>k" of each matching run name into N * 1000 * multiplier."""
    return [
        int(exp.split("_")[-1][:-1]) * 1000 * multiplier for exp in _scaling_runs(tag)
    ]


def _relative_error_for(ansatz, systems, reference):
    """Compute `rel_error` against `reference` for one ansatz on the given systems.

    The energies are selected from `df`, sorted by index and reshaped to
    (len(systems), -1, 10). After selecting the ansatz the remaining index levels
    are (System, Fine-tune step, Geometry), so the last axis is expected to hold
    the 10 geometries of each system -- TODO confirm every system has exactly 10.
    """
    energies = (
        df.loc[ansatz]
        .loc[systems]
        .sort_index()
        .E.to_numpy()
        .reshape(len(systems), -1, 10)
    )
    return rel_error(energies, reference)


# x values of the scaling panels. These depend only on the run names, so they are
# computed once instead of inside the loop over system groups.
pretrain_steps_tm = _pretrain_steps("tm")
# NOTE(review): the "oc" step counts are doubled; the reason is not visible in
# this notebook -- confirm before changing.
pretrain_steps_oc = _pretrain_steps("oc", multiplier=2)
pretrain_steps_scratch = [0]

# Results keyed by system group: 0 -> systems_id, 1 -> systems_ood. Each value is a
# list with one `rel_error` array per run, in the order of `experiments_scaling`.
data_scaling_tm = {}
data_scaling_oc = {}
data_scaling_scratch = {}
for i, s in enumerate([systems_id, systems_ood]):
    # Reference energies of experiment "CCSD(T)_CBS" for this group, 10 per row.
    E_CC = (
        df_ref_cc.swaplevel(0, 1)
        .loc["CCSD(T)_CBS"]
        .loc[s]
        .sort_index()
        .E.to_numpy()
        .reshape(-1, 10)
    )
    data_scaling_tm[i] = [
        _relative_error_for(ansatz, s, E_CC) for ansatz in _scaling_runs("tm")
    ]
    data_scaling_oc[i] = [
        _relative_error_for(ansatz, s, E_CC) for ansatz in _scaling_runs("oc")
    ]
    data_scaling_scratch[i] = [_relative_error_for("orbformer_scratch", s, E_CC)]

In [None]:
# 2x2 figure. This part fills the left-hand column (axs1: systems_id, axs3:
# systems_ood) with error versus fine-tuning steps and builds the shared legend.
fig, ((axs1, axs2), (axs3, axs4)) = plt.subplots(
    figsize=(TEXTWIDTH+0.45, 4), nrows=2, ncols=2
)
# Added to every x value; finetune_steps contains 0, which cannot be drawn on the
# logarithmic x-axis without a shift.
x_offset = 100
# Each series' x values are multiplied by exp(eps), and eps grows by 0.005 per series.
# NOTE(review): eps is not reset between the two panels, so the series in the second
# panel are shifted slightly more than in the first -- confirm this is intended.
eps = 0.0  # used to avoid error bar collisions
for i, (ax, systems) in enumerate(zip((axs1, axs3), (systems_id, systems_ood))):
    # CCSD(T) reference
    # Rows of df_ref_cc with experiment "CCSD(T)_CBS" for this panel's systems,
    # sorted by index and reshaped to 10 values per row.
    E_CC = (
        df_ref_cc.swaplevel(0, 1)
        .loc["CCSD(T)_CBS"]
        .loc[systems]
        .sort_index()
        .E.to_numpy()
        .reshape(-1, 10)
    )

    # baselines
    for ansatz in [
        "scherbela_chem-pretrain",
        "scherbela_hf-pretrain",
        "gao_scratch",
        "gao_finetune",
    ]:
        # `l` is the length of the middle axis in the reshape below, i.e. how many
        # of the trailing entries of finetune_steps this baseline is plotted against.
        # df_schebela index levels: (molecule, method, n_pretrain_variational, epoch,
        # reuse_from, geom_comment).
        if ansatz == "scherbela_chem-pretrain":
            df_run = df_schebela.loc[
                systems, "reuseshared", 500000.0, :, "500kshared_tinymol_v10"
            ]
            l = 7
        elif ansatz == "scherbela_hf-pretrain":
            df_run = df_schebela.loc[systems, "shared", 0, finetune_steps, :]
            l = 7
        elif ansatz in ["gao_finetune", "gao_scratch"]:
            df_run = df_gao.loc[ansatz].loc[systems]
            l = 6
        n = df_run.sort_index().E.to_numpy().reshape(-1, l, 10)
        # 627.5095 scales the error to the kcal/mol shown on the y-axis label
        # (presumably the tabulated energies are in Hartree -- confirm).
        MARE = 627.5095 * rel_error(unp.nominal_values(n), E_CC)
        # Marker: mean over axis 0; error bar: standard deviation over axis 0
        # (axis 0 presumably enumerates the systems, given the reshape above).
        ax.errorbar(
            np.exp(eps) * (finetune_steps[(7 - l) :] + x_offset),
            unp.nominal_values(MARE).mean(0),
            unp.nominal_values(MARE).std(0),
            c=colors[ansatz],
            label=ansatz,
            marker=".",
            ls="",
            markersize=10,
        )
        eps += 0.005
        # Power-law trend line through the last six step counts (step 0 excluded).
        x, y = fit(finetune_steps[-6:], unp.nominal_values(MARE).mean(0)[-6:])
        ax.plot(x + x_offset, y, c=colors[ansatz])

    # orbformer runs
    for ansatz in experiments:
        # Energies reshaped to (system, fine-tune step, 10) before computing the error.
        MARE = 627.5095 * rel_error(
            df.loc[ansatz]
            .loc[systems]
            .sort_index()
            .E.to_numpy()
            .reshape(len(systems), -1, 10),
            E_CC,
        )
        # x values: the trailing MARE.shape[1] step counts, multiplied by 10 for
        # "singlepoint" runs and by 1/3 (first panel) or 1/4 (second panel) for
        # "joint" runs, then shifted by x_offset.
        ax.errorbar(
            np.exp(eps)
            * (
                finetune_steps[-MARE.shape[1] :]
                * (10 if "singlepoint" in ansatz else 1)
                * ((1 / 4 if i else 1 / 3) if "joint" in ansatz else 1)
                + x_offset
            ),
            unp.nominal_values(MARE).mean(0),
            unp.nominal_values(MARE).std(0),
            label=ansatz.replace("_", " "),
            c=colors[ansatz],
            marker=".",
            ls="",
            markersize=10,
        )
        eps += 0.005
        # NOTE(review): the trend line applies the "singlepoint" factor but not the
        # "joint" factor used for the markers above. No "joint" run is listed in
        # `experiments`, so this has no effect at present.
        x, y = fit(
            finetune_steps[-6:] * (10 if "singlepoint" in ansatz else 1),
            unp.nominal_values(MARE[:, -6:]).mean(0),
        )
        ax.plot(x + x_offset, y, c=colors[ansatz])

    # plot formatting
    ax.set_xlim(200, 600000)
    ax.set_yscale("log")
    ax.set_xscale("log")
    # Tick labels are the (unshifted) tick positions divided by ten, matching the
    # "steps per structure" wording of the x-axis label.
    ax.set_xticks(
        np.array([320, 3200, 32000, 320000]) + x_offset,
        [32, 320, 3200, "32k"],
    )
    ax.minorticks_off()
    if not i:
        ax.set_ylabel("MARE wrt. CCSD(T)/CBS (kcal/mol)")
        ax.yaxis.set_label_coords(-0.15, 0)
    else:
        ax.set_xlabel("Scratch training or finetuning\n steps per structure")
    ax.annotate(
        ["in dist", "out of dist"][i],
        (0.025, 0.05),
        xycoords="axes fraction",
        fontsize=10,
    )
# Legend handles are read from `ax`, which still refers to the last panel drawn by
# the loop above (axs3). They come back in plotting order: indices 0-3 are the four
# baselines and 4-7 the entries of `experiments`. The [0] element of each errorbar
# handle is used, and the order below must match `labels`.
handles = ax.get_legend_handles_labels()[0]
handles_grouped = (
    handles[7][0],
    handles[6][0],
    handles[5][0],
    handles[4][0],
    handles[0][0],
    handles[1][0],
    handles[2][0],
    handles[3][0],
)
labels = (
    "Orbformer - scratch",
    "Orbformer - finetune (LAC)",
    "Orbformer - finetune (TinyMol)",
    "Psiformer - scratch (single-point)",
    "Scherbela - finetune (TinyMol)",
    "Scherbela - finetune (Hartree Fock)",
    "Gao - scratch",
    "Gao - finetune (TinyMol)",
)
# One figure-level legend placed above the panels.
fig.legend(
    loc="center",
    handles=handles_grouped,
    labels=labels,
    handler_map={tuple: HandlerTuple(ndivide=None)},
    ncol=3,
    bbox_to_anchor=(0.515, 0.972),
    columnspacing=0.2,
    handletextpad=0,
)

# -------------------------------------------------------------------------------------------
# Right-hand panels (axs2: systems_id, axs4: systems_ood): error at a fixed
# fine-tuning budget as a function of the number of pretraining steps.
# Every x value is shifted by x_offset so that the 0-step (scratch) point can be
# drawn on the logarithmic x-axis.
x_offset = 20000
c = (
    colors["orbformer_finetune_tm_512k"],
    colors["orbformer_finetune_oc_400k"],
    colors["orbformer_scratch"],
)
# Indices into finetune_steps[-6:] selecting which fine-tuning budgets are drawn.
# NOTE(review): the x-axis label below is hard-coded for subset = [3].
subset = [3]
subset_max = max(subset)
for i, (ax, systems) in enumerate(zip((axs2, axs4), [systems_id, systems_ood])):
    # Keep the last six fine-tuning steps of every run and stack the runs along a
    # new leading axis, with the scratch run first as the 0-pretraining-steps point.
    data_scaling_scratch_i = np.stack([d[:, -6:] for d in data_scaling_scratch[i]])
    data_scaling_tm_i = np.concatenate(
        (data_scaling_scratch_i, np.stack([d[:, -6:] for d in data_scaling_tm[i]])), 0
    )
    data_scaling_oc_i = np.concatenate(
        (data_scaling_scratch_i, np.stack([d[:, -6:] for d in data_scaling_oc[i]])), 0
    )
    for k, (pretrain_steps, data_scaling_i) in enumerate(
        (
            ([0] + pretrain_steps_tm, data_scaling_tm_i),
            ([0] + pretrain_steps_oc, data_scaling_oc_i),
        )
    ):
        nominal = unp.nominal_values(data_scaling_i)
        for j in subset:
            # Later entries of `subset` are drawn more opaque. Guard the division
            # so that subset == [0] does not raise ZeroDivisionError.
            alpha = 0.3 + 0.7 * j / subset_max if subset_max else 1.0
            # Marker: mean over axis 1; error bar: standard deviation over axis 1
            # (the system axis of the stacked arrays). 627.5095 scales to kcal/mol.
            ax.errorbar(
                np.array(pretrain_steps[:]) + x_offset,
                627.5095 * nominal[:, :, j].mean(1),
                627.5095 * nominal[:, :, j].std(1),
                label=f"ft-{finetune_steps[-6:][j]}",
                c=c[k],
                alpha=alpha,
                marker=".",
                markersize=10,
            )
    # Axis formatting does not depend on the series, so it is applied once per panel.
    ax.set_xscale("log")
    ax.set_xlim(x_offset * 0.9, 1.1 * 800000 + x_offset)
    ax.set_yscale("log")
    ax.set_yticks([0.2, 0.5, 1, 2, 5, 10, 20, 50], [0.2, 0.5, 1, 2, 5, 10, 20, 50])
    ax.minorticks_off()
    ax.annotate(
        ["in dist", "out of dist"][i],
        (0.025, 0.05),
        xycoords="axes fraction",
        fontsize=10,
    )
    if not i:
        # The upper panel shares its x-axis visually with the lower one: no ticks.
        ax.set_xticks(
            [],
        )
    else:
        ax.set_xticks(
            [
                x_offset,
                32000 + x_offset,
                128000 + x_offset,
                256000 + x_offset,
                512000 + x_offset,
            ],
            ["0", "32k", "128k", "256k", "512k"],
        )
        ax.set_xlabel("Chemical pretraining steps\n(800 finetune steps per structure)")
# NOTE(review): `handles` and `labels` are rebuilt here but not passed to any legend
# call in the code visible in this cell; kept in case a later cell relies on them.
handles, _ = ax.get_legend_handles_labels()
labels = [
    [
        "finetune 250",
        "finetune 1k",
        "finetune 4k",
        "finetune 8k",
        "finetune 16k",
        "finetune 32k",
    ][i]
    for i in subset
]
# -------------------------------------------------------------------------------------------
# Shared y-axis styling for all four panels, visited row by row.
row_ylims = ((0.05, 100), (0.15, 200))  # (top row, bottom row)
y_ticks = [0.2, 0.5, 1, 2, 5, 10, 20, 50]
for i, ax in enumerate((axs1, axs2, axs3, axs4)):
    row, col = divmod(i, 2)
    ax.set_ylim(*row_ylims[row])
    if col:
        # Right-hand column: no y ticks.
        ax.set_yticks([])
    else:
        # Left-hand column: explicit ticks labelled with their own values.
        ax.set_yticks(y_ticks, list(y_ticks))
    # Grey band from 0 to 5 and a dotted horizontal line at 1 (in y-axis units).
    ax.axhspan(0, 5, color="grey", alpha=0.2, lw=0)
    ax.axhline(1, color="k", ls=":", zorder=1)

# Add molecule structures
# --------------------------------------------------------------------------------
# One thumbnail per molecule, placed at y = 25 in data coordinates of the right-hand
# panels. Each row: (target axes, molecule name, image zoom, x position).
molecule_image_dir = "../../experiment_results/03_tinymol/molecule_images"
molecule_insets = (
    (axs2, "C2H4", 0.07, 140000),
    (axs2, "CNH", 0.07, 300000),
    (axs2, "COH2", 0.09, 600000),
    (axs4, "C3H4", 0.09, 50000),
    (axs4, "CN2H2", 0.08, 130000),
    (axs4, "CNOH", 0.1, 300000),
    (axs4, "CO2", 0.07, 600000),
)
for inset_ax, molecule, zoom, x_pos in molecule_insets:
    arrimg = mpimg.imread(f"{molecule_image_dir}/{molecule}.png")
    imagebox = OffsetImage(arrimg, zoom=zoom)
    ab = AnnotationBbox(imagebox, (x_pos, 25), frameon=False)
    inset_ax.add_artist(ab)

# Panel letters, positioned in data coordinates of each panel.
panel_letters = (
    (axs1, "(a)", (300000, 40)),
    (axs3, "(b)", (300000, 70)),
    (axs2, "(c)", (20100, 40)),
    (axs4, "(d)", (20100, 70)),
)
for panel_ax, letter, position in panel_letters:
    panel_ax.annotate(letter, position)

# --------------------------------------------------------------------------------
# Remove the horizontal and vertical gaps between the panels and extend the
# subplot area to the right edge of the figure.
fig.subplots_adjust(hspace=0.0, wspace=0, right=1)
if save_figures:
    # plt.savefig writes pyplot's current figure; the file lands in the kernel's
    # working directory because the name is relative.
    plt.savefig("01_tinymol.pdf", bbox_inches="tight", dpi=600)