In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rcParams

import numpy as np
from uncertainties import ufloat_fromstr
from uncertainties.unumpy import nominal_values, std_devs

import matplotlib.ticker as ticker
from matplotlib.legend_handler import HandlerTuple
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

from oneqmc.analysis import HARTREE_TO_KCAL, colours
from oneqmc.analysis.plot import set_defaults, get_cyclic_cmap

# Apply the shared plotting defaults from oneqmc.analysis.plot
# (presumably matplotlib style / rcParams — see that module).
set_defaults()

In [None]:
# Set to True to also write the figures below to PDF files (relative to the working directory).
save_figures = False

In [None]:
def cmap_idx(i):
    """Map the i-th bar of a method group to a colormap coordinate.

    The index is shifted by 3.9 and divided by 8, so i = 0 maps to 0.4875 and every
    further bar moves 0.125 along the colormap. For i >= 5 the result exceeds 1;
    presumably get_cyclic_cmap wraps such values (TODO confirm).
    """
    start, step = 3.9, 8
    return (i + start) / step

In [None]:
# Individual Orbformer colours.
# NOTE(review): these three names are not referenced in the cells visible here — confirm
# they are still needed before relying on them.
orbformer_from_scratch_color = "#1d2429"
orbformer_beta_color = "#4e938d"
orbformer_alpha_color = "#74baae"
# Bar colours per method, looked up by the outer keys of the energy dicts below.
# - Non-Orbformer entries are colormaps from get_cyclic_cmap; the plotting functions call
#   them as colormaps[method](cmap_idx(i)) for the i-th bar of a group.
# - Orbformer entries are plain lists of hex colours indexed directly by bar position, so
#   each list needs at least as many colours as that method has bars.
# NOTE(review): "MRAQCC/int-sp", "MRAQCC/complete" and "CCSD(T)" do not appear as keys in
# the energy dicts visible here.
colormaps = {
    "M06-2X": get_cyclic_cmap([colours.BRAND_BLUE, colours.BRAND_LIGHT_BLUE]),
    "B2-PLYP": get_cyclic_cmap([colours.BRAND_MAGENTA, colours.BRAND_LIGHT_MAGENTA]),
    "CASPT2(4,4)": get_cyclic_cmap([colours.BRAND_RED, colours.BRAND_LIGHT_RED]),
    "CASPT2(6,6)": get_cyclic_cmap([colours.BRAND_ORANGE, colours.BRAND_LIGHT_ORANGE]),
    "MRAQCC/int-sp": get_cyclic_cmap(
        [colours.BRAND_YELLOW, colours.BRAND_LIGHT_YELLOW]
    ),
    "MRAQCC(6,6)/int-sp": get_cyclic_cmap(
        [colours.BRAND_YELLOW, colours.BRAND_LIGHT_YELLOW]
    ),
    "CCSD(T)": get_cyclic_cmap([colours.BRAND_RED, colours.BRAND_LIGHT_RED]),
    "MRAQCC/complete": get_cyclic_cmap(
        [colours.BRAND_GREEN, colours.BRAND_LIGHT_GREEN]
    ),
    "MRAQCC(6,6)/complete": get_cyclic_cmap(
        [colours.BRAND_GREEN, colours.BRAND_LIGHT_GREEN]
    ),
    # Same colours as "Orbformer from scratch" / "Orbformer finetune LAC" below; presumably
    # older alias names, not used as keys by the energy dicts visible here.
    "Orbformer scratch": ["#539992", "#7AD4C5", "#73BDB1"],
    "Orbformer LAC": ["#1C909F", "#246c75", "#2E5459", "#2d494d"],
    "Orbformer from scratch": ["#539992", "#7AD4C5", "#73BDB1"],
    "Orbformer finetune LAC": ["#1C909F", "#246c75", "#2E5459", "#2d494d"],
}

In [None]:
# Directory holding this study's processed results and molecule images.
# (Plain string: the previous f-string had no placeholders.)
data_dir = "../../experiment_results/01_diels-alder"
# Orbformer energies indexed by (experiment, step, molecule); the "energy" column is
# parsed with ufloat_fromstr into uncertainties values (nominal value +/- std dev).
orbformer_results = pd.read_csv(
    f"{data_dir}/processed_data.csv",
    converters={"energy": ufloat_fromstr},
).set_index(["experiment", "step", "molecule"])

In [None]:
# Experiment key and training-step checkpoints (presumably of the fine-tuned "beta" run).
# NOTE(review): neither name is used in the cells visible here — confirm before removing.
orbformer_key = "beta_all_geoms_merged_bs1024"
orbformer_iters = [150_000, 180_000, 200_000, 230_000, 260_000, 300_000]

# Combined barplot

In [None]:
def make_orbformer_result(pretrain, step, end_state="ts_concerted"):
    """Return the Orbformer energy of ``end_state`` relative to the reactants, in kcal/mol.

    Computes E(end_state) - E(ethene) - E(trans_butadiene) from ``orbformer_results``
    for one experiment at one training step. The result is converted from Hartree with
    ``HARTREE_TO_KCAL``, and only the nominal value is kept (the uncertainty is dropped).

    Args:
        pretrain: label in the ``experiment`` index level, e.g. ``"beta"`` or ``"from_scratch"``.
        step: label in the ``step`` index level (training step).
        end_state: label in the ``molecule`` index level for the transition state / product.

    Returns:
        float: the relative energy in kcal/mol.
    """

    def energy(molecule):
        # Single .loc lookup on the full (experiment, step, molecule) key, rather than
        # chained indexing (row lookup followed by ["energy"]).
        return orbformer_results.loc[(pretrain, step, molecule), "energy"]

    delta = energy(end_state) - energy("ethene") - energy("trans_butadiene")
    return HARTREE_TO_KCAL * delta.nominal_value

In [None]:
# Concerted-pathway activation energies in kcal/mol, keyed by method and then by basis set
# (or training-checkpoint label for Orbformer). combined_bar() plots |value - Ref|.
# Entries given in kJ/mol are converted exactly with 1 kcal = 4.184 kJ (thermochemical calorie)
# instead of the rounded factor 0.239.
concerted_activation_energies = {
    "Ref": 25.8,
    "M06-2X": {
        "6-31+G(d)": 21.07,
        "6-311+G(d)": 20.25,
        "6-311+G(2d,1p)": 22.08,
        "6-311++G(2d,1p)": 22.07,
    },
    "B2-PLYP": {
        "6-31+G(d)": 24.77,
        "6-311+G(d)": 26.76,
        "6-311+G(2d,1p)": 24.28,
        "6-311++G(2d,1p)": 24.24,
    },
    "CASPT2(4,4)": {  # kJ/mol -> kcal/mol
        "cc-pVDZ": 84.0 / 4.184,
        "cc-pVTZ": 77.7 / 4.184,
        "cc-pVQZ": 76.8 / 4.184,
        "cc-pV5Z": 76.4 / 4.184,
    },
    "CASPT2(6,6)": {  # kJ/mol -> kcal/mol
        "cc-pVDZ": 96.1 / 4.184,
        "cc-pVTZ": 90.9 / 4.184,
        "cc-pVQZ": 90.3 / 4.184,
        "cc-pV5Z": 90.0 / 4.184,
    },
    "MRAQCC(6,6)/int-sp": {  # kJ/mol -> kcal/mol
        "6-31G(d)": 107.7 / 4.184,
        "6-31G(d,p)": 105.7 / 4.184,
        "6-311G(d,p)": 101.5 / 4.184,
        "6-311G(2d,1p)": 99.2 / 4.184,
    },
    "MRAQCC(6,6)/complete": {  # used as-is (kcal/mol), no conversion
        "6-31G(d)": 24.22,
        "6-31G(d,p)": 23.71,
        "6-311G(2d,1p)": 22.2,
    },
    # NOTE(review): the from-scratch "128k" label uses step 130_000, while the fine-tuned
    # run uses 128_000 — confirm this is the intended checkpoint.
    "Orbformer from scratch": {
        "128k": make_orbformer_result("from_scratch", 130_000),
        "200k": make_orbformer_result("from_scratch", 200_000),
    },
    "Orbformer finetune LAC": {
        "128k": make_orbformer_result("beta", 128_000),
        "200k": make_orbformer_result("beta", 200_000),
    },
}

In [None]:
# Reaction energies in kcal/mol, keyed by method and then by basis set (or training-checkpoint
# label for Orbformer). combined_bar() plots |value - Ref| for the methods present here;
# methods missing from this dict get a "Data not available" placeholder in that plot.
reaction_energies = {
    "Ref": 25.8 - 68.5,  # kcal/mol
    "M06-2X": {
        "6-31+G(d)": -39.67,
        "6-311+G(d)": -39.56,
        "6-311+G(2d,1p)": -36.92,
        "6-311++G(2d,1p)": -36.25,
    },
    "B2-PLYP": {
        "6-31+G(d)": -31.82,
        "6-311+G(d)": -32.70,
        "6-311+G(2d,1p)": -29.87,
        "6-311++G(2d,1p)": -29.88,
    },
    "MRAQCC(6,6)/int-sp": {
        "6-31G(d)": -43.8,
        "6-31G(d,p)": -43.85,
        "6-311G(d,p)": -43.14,
        "6-311G(2d,1p)": -41.03,
    },
    "MRAQCC(6,6)/complete": {
        "6-31G(d)": -42.46,
        "6-31G(d,p)": -42.52,
        "6-311G(2d,1p)": -39.7,
    },
    # Orbformer: cyclohexene energy relative to the reactants at each checkpoint.
    "Orbformer from scratch": {
        label: make_orbformer_result("from_scratch", ckpt, end_state="cyclohexene")
        for label, ckpt in (("128k", 130_000), ("200k", 200_000))
    },
    "Orbformer finetune LAC": {
        label: make_orbformer_result("beta", ckpt, end_state="cyclohexene")
        for label, ckpt in (("128k", 128_000), ("200k", 200_000))
    },
}

In [None]:
# Stepwise activation energies (Orbformer end state "tsf_anti") in kcal/mol, keyed by method
# and then by basis set (or training-checkpoint label for Orbformer). There is no "Ref" entry:
# tsf_anti_bar() plots these raw values rather than errors.
# Entries given in kJ/mol are converted exactly with 1 kcal = 4.184 kJ.
tsf_anti_energies = {
    "M06-2X": {
        "6-31+G(d)": 35.05,
        "6-311+G(d)": 34.54,
        "6-311+G(2d,1p)": 36.35,
        "6-311++G(2d,1p)": 36.34,
    },
    "B2-PLYP": {
        "6-31+G(d)": 39.15,
        "6-311+G(d)": 38.98,
        "6-311+G(2d,1p)": 40.10,
        "6-311++G(2d,1p)": 40.07,
    },
    "CASPT2(4,4)": {  # kJ/mol -> kcal/mol
        "cc-pVDZ": 127.5 / 4.184,
        "cc-pVTZ": 129.0 / 4.184,
        "cc-pVQZ": 129.3 / 4.184,
        "cc-pV5Z": 129.2 / 4.184,
    },
    "CASPT2(6,6)": {  # kJ/mol -> kcal/mol
        "cc-pVDZ": 124.7 / 4.184,
        "cc-pVTZ": 125.7 / 4.184,
        "cc-pVQZ": 125.9 / 4.184,
        "cc-pV5Z": 125.8 / 4.184,
    },
    # Used as-is, with no conversion. The magnitudes match the other kcal/mol entries.
    # NOTE(review): confirm the units against the source publication.
    "MRAQCC(6,6)/int-sp": {
        "6-31G(d)": 32.05,
        "6-31G(d,p)": 31.77,
    },
    # NOTE(review): the from-scratch "128k" label uses step 130_000 — confirm checkpoint.
    "Orbformer from scratch": {
        "128k": make_orbformer_result("from_scratch", 130_000, end_state="tsf_anti"),
        "200k": make_orbformer_result("from_scratch", 200_000, end_state="tsf_anti"),
    },
    "Orbformer finetune LAC": {
        "128k": make_orbformer_result("beta", 128_000, end_state="tsf_anti"),
        "200k": make_orbformer_result("beta", 200_000, end_state="tsf_anti"),
    },
}

In [None]:
def _add_molecule_image(ax, img, xy, zoom, xybox=None):
    """Draw a borderless molecule image on ``ax`` (all coordinates in axes fractions).

    Args:
        ax: matplotlib Axes to draw on.
        img: image array as returned by ``plt.imread``.
        xy: anchor point of the image.
        zoom: scale factor passed to ``OffsetImage``.
        xybox: optional position for the image itself; when given, the image is drawn
            there with an arrow pointing at ``xy``.
    """
    ax.add_artist(
        AnnotationBbox(
            OffsetImage(img, zoom=zoom),
            xy,
            xybox=xybox,
            xycoords="axes fraction",
            bboxprops={"linewidth": 0},
            arrowprops=None if xybox is None else dict(arrowstyle="->"),
        )
    )


def combined_bar():
    """Plot per-method errors of the concerted activation energy and of the reaction energy.

    Left panel: |E_act - Ref| for every entry of ``concerted_activation_energies``.
    Right panel: |E_rxn - Ref| from ``reaction_energies`` for the same rows. Rows with no
    reaction energy get a light-grey placeholder bar.

    Layout:
        - Bars are grouped by method (outer dict key) and labelled by basis set /
          checkpoint (inner key).
        - A twin y-axis carries the method labels.
        - Molecule images from ``{data_dir}/molecule_images`` are drawn on both panels.

    Returns:
        matplotlib.figure.Figure: the created figure. It also remains pyplot's current
        figure, so ``plt.savefig`` / ``plt.show`` in the calling cell keep working.
    """
    fig, ax = plt.subplots(1, 2, figsize=(10, 7), sharey=True)
    inter_group_gap = 0.5
    cum_shift = 0
    positions = []
    low_level_labels = []
    # Group boundaries on the y-axis; the leading 1.0 sits above the first bar (y = 0).
    high_level_positions = [1.0]
    high_level_labels = []

    for outer_k, inner in concerted_activation_energies.items():
        if outer_k == "Ref":
            continue
        # A method (or single basis set) missing from reaction_energies falls back to the
        # placeholder bar instead of raising KeyError.
        reaction_group = reaction_energies.get(outer_k, {})
        for i_inner, inner_k in enumerate(inner.keys()):
            positions.append(-(cum_shift + i_inner))
            low_level_labels.append(inner_k)
            if outer_k.startswith("Orbformer"):
                c = colormaps[outer_k][i_inner]
            else:
                c = colormaps[outer_k](cmap_idx(i_inner))
            ax[0].barh(
                positions[-1],
                abs(inner[inner_k] - concerted_activation_energies["Ref"]),
                color=c,
            )
            if inner_k in reaction_group:
                ax[1].barh(
                    positions[-1],
                    abs(reaction_group[inner_k] - reaction_energies["Ref"]),
                    color=c,
                )
            else:
                ax[1].barh(positions[-1], 6.5, color="#ededeb", height=1.5)
        high_level_positions.append(positions[-1] - 0.5 * inter_group_gap)
        high_level_labels.append(outer_k)
        cum_shift += len(inner) + inter_group_gap

    # Reference line at 1 kcal/mol on both panels (presumably the chemical-accuracy threshold).
    ax[0].axvline(1.0, linestyle=":", color="k", zorder=-100)
    ax[1].axvline(1.0, linestyle=":", color="k", zorder=-100)

    # Molecule images: each file is read once; ethene and butadiene appear on both panels.
    img_dir = f"{data_dir}/molecule_images"
    ethene_img = plt.imread(f"{img_dir}/ethene_66_6311Gdp.png", format="png")
    butadiene_img = plt.imread(f"{img_dir}/trans_butadiene_66_6311Gdp.png", format="png")
    ts_concerted_img = plt.imread(
        f"{img_dir}/ts_concerted_biradical_66_6311Gdp.png", format="png"
    )
    # NOTE(review): the cyclohexene image is loaded from a "cyclohexane_..." file — confirm.
    cyclohexene_img = plt.imread(f"{img_dir}/cyclohexane_66_6311Gdp.png", format="png")

    # Left panel: ethene, butadiene (with arrow) and the concerted transition state.
    mol_xoffset = -0.1
    mol_yoffset = 0.0
    _add_molecule_image(
        ax[0], ethene_img, (0.9 + mol_xoffset, 0.50 + mol_yoffset), zoom=0.135
    )
    _add_molecule_image(
        ax[0],
        butadiene_img,
        (0.9 + mol_xoffset, 0.23 + mol_yoffset),
        zoom=0.135,
        xybox=(0.9 + mol_xoffset, 0.38 + mol_yoffset),
    )
    _add_molecule_image(
        ax[0], ts_concerted_img, (0.9 + mol_xoffset, 0.12 + mol_yoffset), zoom=0.15
    )

    # Right panel: ethene, butadiene (with arrow) and cyclohexene.
    mol_xoffset = -0.05
    mol_yoffset = 0.0
    _add_molecule_image(
        ax[1], ethene_img, (0.8 + mol_xoffset, 0.58 + mol_yoffset), zoom=0.135
    )
    _add_molecule_image(
        ax[1],
        butadiene_img,
        (0.8 + mol_xoffset, 0.28 + mol_yoffset),
        zoom=0.135,
        xybox=(0.8 + mol_xoffset, 0.44 + mol_yoffset),
    )
    _add_molecule_image(
        ax[1], cyclohexene_img, (0.8 + mol_xoffset, 0.15 + mol_yoffset), zoom=0.135
    )

    # Set up the axes
    ax[0].set_ylim(min(positions) - 1, max(positions) + 1)
    ax[0].set_xlabel("Activation energy error (kcal/mol)")
    ax[1].set_xlabel("Reaction energy error (kcal/mol)")
    ax[0].set_yticks(positions, low_level_labels, fontsize=10)

    # Twin axis with its (invisible) left spine pushed out past the basis-set labels.
    # Method names are drawn as minor tick labels centred on each group.
    ax2 = ax[0].twinx()
    ax2.set_ylim(min(positions) - 1, max(positions) + 1)
    ax2.spines["left"].set_position(("axes", -0.32))
    ax2.spines["left"].set_linewidth(0)
    ax2.tick_params("both", length=0, width=0, which="minor")
    ax2.tick_params("both", length=0, width=0, direction="in", which="major")
    ax2.yaxis.set_ticks_position("left")
    ax2.yaxis.set_label_position("left")
    # y = -13 (data coordinates) is hand-tuned for the current dict contents: it falls among
    # the placeholder rows of the CASPT2 groups, which have no reaction energies.
    ax[1].annotate("Data not available", (0.55, -13), fontsize=12)

    ax2.set_yticks(np.array(high_level_positions) - 0.5)
    high_level_half_points = [
        (start + stop) / 2 - 0.5
        for start, stop in zip(high_level_positions[:-1], high_level_positions[1:])
    ]
    ax2.yaxis.set_major_formatter(ticker.NullFormatter())
    ax2.yaxis.set_minor_locator(ticker.FixedLocator(high_level_half_points))
    ax2.yaxis.set_minor_formatter(ticker.FixedFormatter(high_level_labels))
    ax2.tick_params(axis="y", which="minor", labelsize=10)

    fig.align_ylabels()
    fig.subplots_adjust(wspace=0)
    # NOTE(review): tight_layout() recomputes wspace for multi-column layouts, so the
    # wspace=0 request above is normally overridden; kept to preserve existing output.
    fig.tight_layout()
    return fig

In [None]:
combined_bar()
if save_figures:
    # plt.savefig acts on pyplot's current figure, i.e. the one combined_bar() just created.
    # The dpi setting affects the embedded raster molecule images in the PDF.
    plt.savefig("barplot-reaction-activation.pdf", dpi=300)
plt.show()

In [None]:
def tsf_anti_bar():
    """Plot the stepwise activation energy for every method and basis set.

    Bars show the raw values of ``tsf_anti_energies`` (kcal/mol), not errors. They are
    grouped by method (outer dict key) and labelled by basis set / checkpoint (inner key).
    A twin y-axis carries the method labels. A "Ref" entry, if ever added, is skipped,
    as in ``combined_bar``.

    Returns:
        matplotlib.figure.Figure: the created figure. It also remains pyplot's current
        figure, so ``plt.savefig`` / ``plt.show`` in the calling cell keep working.
    """
    fig, ax = plt.subplots(figsize=(7, 7))
    inter_group_gap = 0.5
    cum_shift = 0
    positions = []
    low_level_labels = []
    # Group boundaries on the y-axis; the leading 1.0 sits above the first bar (y = 0).
    high_level_positions = [1.0]
    high_level_labels = []

    for outer_k, inner in tsf_anti_energies.items():
        if outer_k == "Ref":
            continue
        for i_inner, inner_k in enumerate(inner.keys()):
            positions.append(-(cum_shift + i_inner))
            low_level_labels.append(inner_k)
            if outer_k.startswith("Orbformer"):
                c = colormaps[outer_k][i_inner]
            else:
                c = colormaps[outer_k](cmap_idx(i_inner))
            ax.barh(
                positions[-1],
                inner[inner_k],
                color=c,
            )
        high_level_positions.append(positions[-1] - 0.5 * inter_group_gap)
        high_level_labels.append(outer_k)
        cum_shift += len(inner) + inter_group_gap

    # Set up the axes
    ax.set_ylim(min(positions) - 1, max(positions) + 1)
    ax.set_xlabel("Stepwise activation energy (kcal/mol)")
    ax.set_yticks(positions, low_level_labels, fontsize=10)

    # Twin axis with its (invisible) left spine pushed out past the basis-set labels.
    # Method names are drawn as minor tick labels centred on each group.
    ax2 = ax.twinx()
    ax2.set_ylim(min(positions) - 1, max(positions) + 1)
    ax2.spines["left"].set_position(("axes", -0.32))
    ax2.spines["left"].set_linewidth(0)
    ax2.tick_params("both", length=0, width=0, which="minor")
    ax2.tick_params("both", length=0, width=0, direction="in", which="major")
    ax2.yaxis.set_ticks_position("left")
    ax2.yaxis.set_label_position("left")

    ax2.set_yticks(np.array(high_level_positions) - 0.5)
    high_level_half_points = [
        (start + stop) / 2 - 0.5
        for start, stop in zip(high_level_positions[:-1], high_level_positions[1:])
    ]
    ax2.yaxis.set_major_formatter(ticker.NullFormatter())
    ax2.yaxis.set_minor_locator(ticker.FixedLocator(high_level_half_points))
    ax2.yaxis.set_minor_formatter(ticker.FixedFormatter(high_level_labels))
    ax2.tick_params(axis="y", which="minor", labelsize=10)

    fig.tight_layout()
    return fig

In [None]:
tsf_anti_bar()
if save_figures:
    # plt.savefig acts on pyplot's current figure, i.e. the one tsf_anti_bar() just created.
    plt.savefig("barplot-stepwise.pdf")
plt.show()