In [None]:
import numpy as np
from scipy.interpolate import CubicSpline

import matplotlib as mpl
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches

from matplotlib.offsetbox import OffsetImage, AnnotationBbox

In [None]:
# On Linux: manually download the Segoe UI font file; it is loaded and
# registered with matplotlib here.

import warnings
from pathlib import Path

font_path = str(Path.home() / '.local' / 'share' / 'fonts' / 'SEGOEUI.TTF')

if Path(font_path).is_file():
    # Register the file first so the family name set below refers to a font
    # that the font manager already knows.
    fm.fontManager.addfont(font_path)
    segoe_prop = fm.FontProperties(fname=font_path)
    mpl.rcParams['font.family'] = segoe_prop.get_name()
else:
    # Keep matplotlib's default family instead of aborting the notebook; the
    # warning makes the (visually different) fallback explicit.
    segoe_prop = fm.FontProperties()
    warnings.warn(
        f'Font file {font_path} not found; using the default font family.'
    )

In [None]:
# Matplotlib settings

mpl.rcParams["text.usetex"] = False
mpl.rcParams['mathtext.fontset'] = 'stix'
mpl.rcParams['mathtext.cal'] = 'stix'
mpl.rcParams['mathtext.rm'] = 'stix'
mpl.rcParams['mathtext.it'] = 'stix:italic'
mpl.rcParams['mathtext.bf'] = 'stix:bold'
mpl.rcParams["font.size"] = 10

%config InlineBackend.print_figure_kwargs={'bbox_inches': None}

model_color = "#dff2f4"
array_color = "#e7dff6"
dataset_color = "#fff1ee"
result_color = "#ffd566"

In [None]:
# Width handed to ``figsize`` (matplotlib interprets figsize in inches)
TEXTWIDTH = 7.08

# Line widths used by arrows (lw, lw_bold) and box outlines (lw_box)
lw = 0.5
lw_bold, lw_box = 2 * lw, lw

# One figure, half as tall as wide, split into a left and a right panel
fig = plt.figure(figsize=(TEXTWIDTH, TEXTWIDTH / 2))
spec = gridspec.GridSpec(nrows=1, ncols=2, figure=fig)
ax1, ax2 = (fig.add_subplot(spec[0, column]) for column in range(2))


def annotate(
    ax,
    text,
    xy,
    xytext=None,
    ha='center',
    va='center',
    arrowprops=None,
    arrow=False,
    **kwargs,
):
    """Add a text label and/or an arrow to ``ax`` in figure-fraction coordinates.

    Thin wrapper around ``ax.annotate`` that fixes ``xycoords`` and
    ``textcoords`` to ``'figure fraction'``, centres the text by default and
    supplies a default arrow style.

    Parameters
    ----------
    ax : object with an ``annotate`` method (a matplotlib Axes in this notebook)
        Target the annotation is added to.
    text : str or None
        Label to draw; callers pass ``None`` for arrow-only annotations.
    xy : tuple
        Annotated point; this is where the arrow ends.
    xytext : tuple, optional
        Text position; this is where the arrow starts.
    ha, va : str
        Horizontal / vertical text alignment (default ``'center'``).
    arrowprops : dict, optional
        Overrides merged on top of the default arrow style. A non-empty dict
        enables the arrow even when ``arrow`` is False. The dict is not
        modified.
    arrow : bool
        Draw an arrow using the default style (plus any overrides).
    **kwargs
        Forwarded unchanged to ``ax.annotate`` (e.g. ``fontsize``, ``bbox``).

    Notes
    -----
    The default line width is read from the notebook-level ``lw`` at call time.
    """
    if arrow or arrowprops:
        # Build a fresh dict on every call so neither the defaults nor the
        # caller's dict are ever shared or mutated.
        arrowprops = {
            'color': 'k',
            'arrowstyle': '-|>, head_width=0.15, head_length=0.3',
            'linewidth': lw,
            **(arrowprops or {}),
        }
    else:
        # No arrow requested: ax.annotate then draws text only.
        arrowprops = None
    ax.annotate(
        text,
        xy,
        xytext=xytext,
        xycoords='figure fraction',
        textcoords='figure fraction',
        ha=ha,
        va=va,
        arrowprops=arrowprops,
        **kwargs,
    )


# Both panels are pure drawing canvases: hide axes, ticks and frames
for panel in (ax1, ax2):
    panel.axis('off')

# SCAFFOLDING
# Light grey straight lines (no arrow heads), given as pairs of end points
scaffold_style = {'color': 'lightgrey', 'arrowstyle': '-'}
scaffold_lines = [
    ((0.0015, -0.1), (0.0015, 1.1)),
    ((0.5, -0.1), (0.5, 1.1)),
    ((0.26, 0.44), (0.502, 0.44)),
    ((0.05, 0.44), (0.23, 0.44)),
    ((-0.1, 0.723), (0.35, 0.723)),
    ((0.6, 0.66), (1.1, 0.66)),
    ((0.498, 0.39), (0.66, 0.39)),
]
for end_a, end_b in scaffold_lines:
    annotate(ax1, None, end_a, end_b, arrowprops=scaffold_style)


# SUBFIGURE PRETRAIN
annotate(ax1, 'Chemical Pretraining', (0.25, 0.95), fontsize=10)

# DATASET CONSTRUCTION

# atoms.png, aligned by its bottom centre
atoms_box = AnnotationBbox(
    OffsetImage(mpimg.imread("atoms.png"), zoom=0.09),
    (0.1, 0.77),
    frameon=False,
    xycoords='figure fraction',
    box_alignment=(0.5, 0.0),
)
ax1.add_artist(atoms_box)

# NH3_mol.png, centred on its anchor point
molecule_box = AnnotationBbox(
    OffsetImage(mpimg.imread("NH3_mol.png"), zoom=0.07),
    (0.25, 0.8),
    frameon=False,
    xycoords='figure fraction',
)
ax1.add_artist(molecule_box)

# One image (<action>.png) with a caption underneath per distortion type
for i, action in enumerate(['Bending', 'Stretching', 'Breaking']):
    column = 0.35 + 0.055 * i
    pictogram = OffsetImage(mpimg.imread(f"{action.lower()}.png"), zoom=0.045)
    ax1.add_artist(
        AnnotationBbox(
            pictogram,
            (column, 0.78),
            frameon=False,
            xycoords='figure fraction',
            box_alignment=(0.5, 0.0),
        )
    )
    annotate(ax1, action, (column, 0.755), fontsize=6)

# Arrows between the three construction steps
for head, tail in [((0.2, 0.8), (0.17, 0.8)), ((0.32, 0.8), (0.29, 0.8))]:
    annotate(ax1, None, head, tail, arrow=True)

# Step headings
for heading, position in [
    ('Light Atom Species', (0.1, 0.89)),
    ('Assemble Molecules', (0.25, 0.89)),
    ('Distort Geometries', (0.4, 0.89)),
]:
    annotate(ax1, heading, position, fontsize=8)

# Element symbols placed around the atoms image
for symbol, position in [
    ('H', (0.056, 0.812)),
    ('Li', (0.085, 0.812)),
    ('B', (0.1135, 0.812)),
    ('C', (0.1437, 0.812)),
    ('N', (0.07, 0.75)),
    ('O', (0.1, 0.75)),
    ('F', (0.129, 0.75)),
]:
    annotate(ax1, symbol, position, fontsize=6)

# LAC DATASET

# Offsets applied to the coordinates of this section
x, y = 0, 0

annotate(ax1, None, (0.39 + x, 0.69 + y), (0.39, 0.74), arrow=True)
annotate(ax1, 'LAC Dataset', (0.31 + x, 0.69 + y), fontsize=8)

# Rounded background box (added as a patch to ax1, drawn at zorder 0)
empty_box = patches.FancyBboxPatch(
    (0.58 + x / 2, 0.465 + y),
    0.48,
    0.2,
    boxstyle="round,pad=0.0,rounding_size=0.02",
    facecolor=dataset_color,
    edgecolor='k',
    linewidth=lw_box,
    zorder=0,
    clip_on=False,
)
ax1.add_patch(empty_box)

# Left-aligned bullet list: (text, vertical position)
dataset_facts = [
    ('• 88 distinct molecules', 0.64),
    ('• 22350 structures', 0.61),
    ('• up to 22 electrons ', 0.58),
    ('• organic & inorganic chemistry', 0.55),
    ('• non-equilibrium & reactions', 0.52),
    ('• 45% multi-reference character', 0.49),
]
for fact, height in dataset_facts:
    annotate(ax1, fact, (0.275 + x, height + y), ha="left", fontsize=6)

# Short plain connector line (no arrow head)
annotate(
    ax1,
    None,
    (0.225 + x, 0.57 + y),
    (0.255 + x, 0.57 + y),
    arrowprops={"arrowstyle": "-"},
)

# AUGMENTATION

# Offsets applied to the coordinates of this section
x, y = -0.11, -0.01

annotate(ax1, 'Augmentation', (0.24 + x, 0.68 + y), fontsize=8)

# The same image object is used for both the rotation and the fuzz pictogram
arrimg = mpimg.imread("C2H4_mol.png")
imagebox = OffsetImage(arrimg, zoom=0.05)

# Rotation
ab = AnnotationBbox(
    imagebox, (0.205 + x, 0.58 + y), xycoords='figure fraction', frameon=False
)
ax1.add_artist(ab)

rotation_style = {
    'connectionstyle': "angle,angleA=90,angleB=0,rad=18",
    "arrowstyle": "-|>, head_width=0.08, head_length=0.2",
}
rotation_arrows = [
    ((0.21 + x, 0.52 + y), (0.17 + x, 0.61 + y)),
    ((0.20 + x, 0.64 + y), (0.24 + x, 0.55 + y)),
]
for head, tail in rotation_arrows:
    annotate(ax1, None, head, tail, arrowprops=rotation_style)

annotate(ax1, 'Rotation', (0.205 + x, 0.49 + y), fontsize=7)

# Fuzz
ab = AnnotationBbox(
    imagebox, (0.28 + x, 0.58 + y), xycoords='figure fraction', frameon=False
)
ax1.add_artist(ab)

fuzz_style = {"arrowstyle": "-|>, head_width=0.05, head_length=0.1"}
fuzz_arrows = [
    ((0.27 + x, 0.64 + y), (0.288 + x, 0.61 + y)),
    ((0.265 + x, 0.62 + y), (0.256 + x, 0.58 + y)),
    ((0.25 + x, 0.56 + y), (0.273 + x, 0.545 + y)),
    ((0.32 + x, 0.57 + y), (0.2976 + x, 0.5845 + y)),
    ((0.3 + x, 0.62 + y), (0.287 + x, 0.587 + y)),
    ((0.295 + x, 0.555 + y), (0.268 + x, 0.576 + y)),
]
for head, tail in fuzz_arrows:
    annotate(ax1, None, head, tail, arrowprops=fuzz_style)

annotate(ax1, 'Fuzz', (0.28 + x, 0.49 + y), fontsize=7)

# Right-angled arrow leaving the augmentation area
annotate(
    ax1,
    None,
    (0.03, 0.425),
    (0.16 + x, 0.58 + y),
    arrowprops={'connectionstyle': "angle,angleA=0,angleB=90,rad=8"},
)

# PRETRAINING

# Offsets applied to the coordinates of this section
x, y = 0, -0.25

# The two right-angled connection styles (corner radius 8) used by the arrows
elbow_90_0 = {'connectionstyle': "angle,angleA=90,angleB=0,rad=8"}
elbow_0_90 = {'connectionstyle': "angle,angleA=0,angleB=90,rad=8"}

# Arrow to stacked electron positions r
annotate(
    ax1, None, (0.055 + x, 0.51 + y), (0.03 + x, 0.55 + y), arrowprops=elbow_90_0
)

# Arrow from r to Orbformer
annotate(ax1, None, (0.13 + x, 0.51 + y), (0.105 + x, 0.51 + y), arrow=True)

# Arrow from molecular configuration M to Orbformer
annotate(
    ax1, None, (0.13 + x, 0.455 + y), (0.03 + x, 0.625 + y), arrowprops=elbow_90_0
)

# Plot M and r boxes
empty_box = patches.FancyBboxPatch(
    (0.035 + x / 2, 0.62 + y),
    0.06,
    0.05,
    boxstyle="round,pad=0.0,rounding_size=0.01",
    facecolor=array_color,
    edgecolor='k',
    linewidth=lw_box,
    clip_on=False,
)
ax1.add_patch(empty_box)
annotate(ax1, r'$\mathbf{M}$', (0.03 + x, 0.64 + y), fontsize=9)
annotate(
    ax1,
    r"$\{\mathbf{x}\}_i$",
    (0.08 + x, 0.508 + y),
    fontsize=9,
    color="black",
    bbox={
        'boxstyle': "round,pad=0.3",
        'linewidth': lw_box,
        'edgecolor': "k",
        'facecolor': array_color,
    },
)

# Orbformer box
annotate(
    ax1,
    "Orbformer\nWavefunction",
    (0.2 + x, 0.48 + y),
    fontsize=9,
    color="black",
    bbox={
        'boxstyle': "round,pad=0.8",
        'linewidth': lw_box,
        'edgecolor': "k",
        'facecolor': model_color,
    },
)

# MCMC arrows and text
annotate(
    ax1, None, (0.3 + x, 0.56 + y), (0.27 + x, 0.475 + y), arrowprops=elbow_0_90
)
# Direct ax1.annotate call: keeps matplotlib's default text alignment instead
# of the centred alignment applied by the annotate() helper
ax1.annotate(
    r"$\rho_\mathbf{M}(\mathbf{x}) = \left|\Psi(\mathbf{x}\mid\mathbf{M})\right|^2$",
    xy=(0.19 + x, 0.58 + y),
    xycoords='figure fraction',
    textcoords='figure fraction',
    fontsize=9,
    color="black",
)
annotate(
    ax1, None, (0.08 + x, 0.54 + y), (0.18 + x, 0.59 + y), arrowprops=elbow_0_90
)
ax1.annotate(
    r"ULA/MALA",
    xy=(0.1 + x, 0.605 + y),
    xycoords='figure fraction',
    textcoords='figure fraction',
    fontsize=7,
    color="black",
)

# SGD arrows and text
annotate(
    ax1, None, (0.3 + x, 0.39 + y), (0.27 + x, 0.475 + y), arrowprops=elbow_0_90
)
annotate(
    ax1,
    r"$\mathcal{E}_{\mathbf{M}}"
    r" \approx{\sum}_i~\,\frac{H_\mathbf{M}\Psi(\mathbf{x}_i\mid\mathbf{M})}{\Psi(\mathbf{x}_i\mid\mathbf{M})}$",
    (0.3 + x, 0.349 + y),
    fontsize=10,
    color="black",
)
annotate(
    ax1, None, (0.2 + x, 0.41 + y), (0.22 + x, 0.35 + y), arrowprops=elbow_0_90
)
annotate(
    ax1,
    r"$\nabla_\boldsymbol{\theta}~\mathcal{E}_\mathbf{M} \rightarrow"
    r" \boldsymbol{\theta}'$",
    xy=(0.14 + x, 0.37 + y),
    fontsize=9,
    color="black",
)

# Final parameters
annotate(
    ax1,
    None,
    (x + 0.205, y + 0.285),
    (x + 0.14, y + 0.34),
    fontsize=7,
    arrowprops=elbow_90_0,
)
annotate(
    ax1, r"$\boldsymbol{\theta}^*$", (x + 0.22, y + 0.285), fontsize=10, color="black"
)
annotate(
    ax1,
    ", Foundation Model Parameters",
    (x + 0.34, y + 0.285),
    fontsize=8,
    color="black",
)

# Training progress subfigure
# Illustrative curve only: a linear decay plus per-point noise plus one
# noise offset per block of ten points (five blocks in total).
axp = ax1.inset_axes((0.772 + x, 0.44 + y, 0.3, 0.13))

# Seeded generator so the saved figure is identical on every run
rng = np.random.default_rng(0)
x_plot = np.linspace(0, 1, 50)
y_plot = 20 - 15 * x_plot + rng.normal(0, 1, 50) + rng.normal(0, 5, 5).repeat(10)
axp.plot(x_plot, y_plot, c=result_color)

# Vertical separators at the end of each block of ten points
for i in range(1, 6):
    axp.axvline(i * 0.2, color='k', lw=lw)

axp.set_xlabel('Training Iteration', fontsize=7)
axp.set_ylabel('Variance', fontsize=7)
axp.set_xticks([])
axp.set_yticks([])
axp.set_zorder(0)
axp.spines['top'].set_visible(False)
axp.spines['right'].set_visible(False)

# Arrows from training back to M
# Right-angled plain lines (no arrow heads), given as (end point, start point)
feedback_style = {
    'connectionstyle': "angle,angleA=90,angleB=0,rad=8",
    'arrowstyle': '-',
}
feedback_lines = [
    ((0.3 + x, 0.64 + y), (0.245 + x, 0.5)),
    ((0.293 + x, 0.64 + y), (0.4811 + x, 0.562 + y)),
    ((0.427 + x, 0.64 + y), (0.4563 + x, 0.562 + y)),
    ((0.402 + x, 0.64 + y), (0.4315 + x, 0.562 + y)),
    ((0.377 + x, 0.64 + y), (0.4067 + x, 0.562 + y)),
    ((0.351 + x, 0.64 + y), (0.3819 + x, 0.562 + y)),
]
for end_point, start_point in feedback_lines:
    annotate(ax1, None, end_point, start_point, arrowprops=feedback_style)

# Single headed arrow with the same right-angled shape
annotate(
    ax1,
    None,
    (0.215, 0.57),
    (0.245 + x, 0.485),
    arrowprops={'connectionstyle': "angle,angleA=90,angleB=0,rad=8"},
)
annotate(ax1, r"draw molecule", (0.34 + x, 0.658 + y), fontsize=7, color="black")

# SUBFIGURE FINETUNE

for heading, position, size in [
    ('Transferable Finetuning', (0.75, 0.95), 10),
    ('Targeted Chemical Process', (0.62, 0.89), 8),
]:
    annotate(ax2, heading, position, fontsize=size)

# Plot molecules of chemical process
# Five images mol_XX.png (two-digit, zero-padded index), each with a plain
# right-angled line running to the left
for slot, frame in enumerate([0, 4, 9, 14, 19]):
    column = 0.58 + slot * 0.08
    snapshot = OffsetImage(
        mpimg.imread(f"bbmep_1106_mols/mol_{frame:02d}.png"), zoom=0.06
    )
    ax2.add_artist(
        AnnotationBbox(
            snapshot, (column, 0.8), frameon=False, xycoords='figure fraction'
        )
    )
    annotate(
        ax2,
        None,
        (0.55, 0.7009 - slot * 0.00047),
        (column, 0.74),
        arrowprops={
            'arrowstyle': '-',
            'connectionstyle': "angle,angleA=90,angleB=0,rad=8",
        },
    )

# Arrow to M block
annotate(
    ax2,
    None,
    (0.53, 0.63),
    (0.558, 0.7),
    arrowprops={
        'linewidth': lw_bold,
        'connectionstyle': "angle,angleA=0,angleB=-90,rad=8",
    },
)


# FINETUNE

# Offsets applied to the coordinates of this section
x, y = 0.5, -0.01

# Bold right-angled arrow styles (corner radius 8) used throughout this panel
bold_90_0 = {
    'linewidth': lw_bold,
    'connectionstyle': "angle,angleA=90,angleB=0,rad=8",
}
bold_0_90 = {
    'linewidth': lw_bold,
    'connectionstyle': "angle,angleA=0,angleB=90,rad=8",
}

# Arrow to stacked electron positions r
annotate(
    ax2, None, (0.08 + x, 0.51 + y), (0.03 + x, 0.55 + y), arrowprops=bold_90_0
)

# Arrow from r to Orbformer
annotate(
    ax2,
    None,
    (0.185 + x, 0.51 + y),
    (0.14 + x, 0.51 + y),
    arrowprops={'linewidth': lw_bold},
)

# Arrow from molecular configuration M to Orbformer
annotate(
    ax2, None, (0.185 + x, 0.44 + y), (0.03 + x, 0.575 + y), arrowprops=bold_90_0
)

# Plot M and r boxes
# Five copies of each box, every copy shifted slightly, giving a stacked look
for i in range(5):
    empty_box = patches.FancyBboxPatch(
        (-0.05 - i * 0.005, 0.56 + i * 0.005),
        0.06,
        0.05,
        boxstyle="round,pad=0.0,rounding_size=0.01",
        facecolor=array_color,
        edgecolor='k',
        linewidth=lw_box,
        clip_on=False,
    )
    ax2.add_patch(empty_box)
    annotate(
        ax2,
        r"$\{\mathbf{x}\}_i$",
        (0.115 + x - i * 0.0025, 0.495 + y + i * 0.005),
        fontsize=9,
        color="black",
        bbox={
            'boxstyle': "round,pad=0.3",
            'linewidth': lw_box,
            'edgecolor': "k",
            'facecolor': array_color,
        },
    )

annotate(ax2, r'$\mathbf{M}$', (0.0275 + x, 0.61 + y), fontsize=9)

# Orbformer box
annotate(
    ax2,
    "Pretrained\nOrbformer\nWavefunction",
    (0.255 + x, 0.48 + y),
    fontsize=9,
    color="black",
    bbox={
        'boxstyle': "round,pad=0.8",
        'linewidth': lw_box,
        'edgecolor': "k",
        'facecolor': model_color,
    },
)

# MCMC arrows and text
annotate(
    ax2, None, (0.4 + x, 0.58 + y), (0.33 + x, 0.48 + y), arrowprops=bold_0_90
)
annotate(
    ax2,
    r"$\rho_\mathbf{M}(\mathbf{x}) = \left|\Psi(\mathbf{x}\mid\mathbf{M})\right|^2$",
    (0.41 + x, 0.615 + y),
    fontsize=9,
    color="black",
)
annotate(
    ax2, None, (0.107 + x, 0.55 + y), (0.33 + x, 0.615 + y), arrowprops=bold_0_90
)
annotate(ax2, r"MALA", (0.22 + x, 0.635 + y), fontsize=7, color="black")

# SGD arrows and text
annotate(
    ax2, None, (0.4 + x, 0.32 + y), (0.33 + x, 0.48 + y), arrowprops=bold_0_90
)
annotate(
    ax2,
    r"$\mathcal{E}_{\mathbf{M}}"
    r" \approx{\sum}_{i}~\,\frac{H_\mathbf{M}\Psi(\mathbf{x}_i\mid\mathbf{M})}{\Psi(\mathbf{x}_i\mid\mathbf{M})}$",
    (0.4 + x, 0.27 + y),
    fontsize=10,
    color="black",
)
annotate(
    ax2,
    None,
    (0.25 + x, 0.39 + y),
    (0.32 + x, 0.3 - 0.02 + y),
    arrowprops=bold_0_90,
)
annotate(
    ax2,
    r"$\nabla_\boldsymbol{\theta}~{\sum}_{\mathbf{M}}~\,\mathcal{E}_\mathbf{M}"
    r" \rightarrow\boldsymbol{\theta}'$",
    (0.18 + x, 0.33 + y),
    fontsize=9,
    color="black",
)

# Arrows to PES datapoints
# (arrow end, arrow start, corner radius); thin default-width arrows
pes_arrows = [
    ((0.06 + x, 0.08), (0.32 + x, 0.29 - 0.02 + y), 20),
    ((0.125 + x, 0.08), (0.32 + x, 0.286 - 0.02 + y), 16),
    ((0.207 + x, 0.08), (0.32 + x, 0.282 - 0.02 + y), 12),
    ((0.288 + x, 0.14), (0.32 + x, 0.278 - 0.02 + y), 8),
    ((0.37 + x, 0.17), (0.37 + x, 0.255 - 0.02 + y), 0),
]
for head, tail, radius in pes_arrows:
    annotate(
        ax2,
        None,
        head,
        tail,
        arrowprops={
            'connectionstyle': f"angle,angleA=0,angleB=-90,rad={radius}"
        },
    )

# Energy-vs-reaction-coordinate inset in the finetuning panel.
# The data arrays use their own names so that the layout offsets `x`, `y`
# used by the drawing code above are not rebound to arrays.
axp = ax2.inset_axes((-0.002, 0.07 + y, 0.75, 0.16))

coord = np.arange(0, 20, 1)
energy = np.array([
    -323.76290,
    -323.77085,
    -323.76719,
    -323.76160,
    -323.76803,
    -323.77344,
    -323.75211,
    -323.75864,
    -323.76477,
    -323.76331,
    -323.75992,
    -323.76476,
    -323.75962,
    -323.76217,
    -323.66576,
    -323.60473,
    -323.52371,
    -323.60698,
    -323.60705,
    -323.61342,
])
energy_err = np.array([
    22e-5,
    26e-5,
    29e-5,
    28e-5,
    33e-5,
    10e-5,
    29e-5,
    3e-5,
    25e-5,
    23e-5,
    11e-5,
    15e-5,
    11e-5,
    5e-5,
    19e-5,
    8e-5,
    5e-5,
    9e-5,
    35e-5,
    9e-5,
])
# One value and one error bar per coordinate
assert len(coord) == len(energy) == len(energy_err)

# Indices of the points drawn with markers and error bars
idxs = [0, 4, 9, 14, 19]

# Smooth dotted curve through all points (cubic spline). The former linear
# np.interp result was overwritten immediately and has been removed.
xnew = np.linspace(0, 19, num=1000)
ynew = CubicSpline(coord, energy)(xnew)
axp.plot(xnew, ynew, ls=':', c='grey', lw=1)
axp.errorbar(
    coord[idxs],
    energy[idxs],
    energy_err[idxs],
    marker='.',
    ms=7,
    ls='',
    c=result_color,
)

axp.set_xlabel('Reaction Coordinate', fontsize=7)
axp.set_ylabel('Energy', fontsize=7)
axp.set_xticks([])
axp.set_yticks([])
axp.spines['top'].set_visible(False)
axp.spines['right'].set_visible(False)
axp.set_zorder(0)

# Let the two panels fill the whole figure, then write it to disk
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
fig.savefig('method.pdf', dpi=600)

In [None]:
# Colours of the architecture figure.
# NOTE(review): `array_color` is also assigned in the matplotlib-settings cell
# and read by the method figure; rebinding it here means re-running that
# figure's cell after this one changes its box colours. Consider a distinct
# name for this figure.
nuc_color, elec_color, array_color = "#dff2f4", 'lightyellow', 'lightgrey'

# Square figure, half of TEXTWIDTH on each side, with a single hidden axes
side = 0.5 * TEXTWIDTH
fig = plt.figure(figsize=(side, side))
spec = gridspec.GridSpec(nrows=1, ncols=1, figure=fig)
ax1 = fig.add_subplot(spec[0, 0])
ax1.axis('off')

# Centred title
ax1.annotate(
    'Orbformer Wavefunction',
    (0.5, 0.95),
    xycoords='figure fraction',
    fontsize=10,
    ha="center",
    va="center",
)

# Four points in data coordinates: the first two in red, the last two in blue
elecs = np.array([[0.46, 0.91], [0.39, 0.73], [0.52, 0.86], [0.33, 0.96]])
for colour, group in (('r', elecs[:2]), ('b', elecs[2:])):
    ax1.scatter(group[:, 0], group[:, 1], c=colour, s=5)

# Fix the data range to the unit square
ax1.set_ylim(0, 1)
ax1.set_xlim(0, 1)

# Light grey connector lines (no arrow heads).
# Each entry: (end point, start point, connection style or None for a
# straight line). Both right-angled styles use a corner radius of 5.
grey_90_0 = "angle,angleA=90,angleB=0,rad=5"
grey_0_90 = "angle,angleA=0,angleB=90,rad=5"
grey_lines = [
    ((0.6, 0.8), (0.495, 0.85), grey_0_90),
    ((0.6, 0.76), (0.485, 0.72), grey_0_90),
    ((0.8, 0.78), (0.6, 0.82), grey_90_0),
    ((0.8, 0.77), (0.6, 0.73), grey_90_0),
    ((0.19, 0.79), (0.3, 0.82), grey_90_0),
    ((0.3, 0.8), (0.37, 0.85), grey_0_90),
    ((0.19, 0.78), (0.35, 0.805), grey_90_0),
    ((0.35, 0.79), (0.47, 0.81), grey_0_90),
    ((0.19, 0.77), (0.52, 0.77), None),
    ((0.19, 0.76), (0.32, 0.7), grey_90_0),
    ((0.32, 0.73), (0.42, 0.67), grey_0_90),
]
for end_point, start_point, connection in grey_lines:
    line_style = {'arrowstyle': '-', 'color': 'lightgray'}
    if connection is not None:
        line_style['connectionstyle'] = connection
    annotate(ax1, None, end_point, start_point, arrowprops=line_style)

# ________________________________________________

# Black connectors of the architecture diagram, drawn in list order.
# Each entry: (end point, start point, connection style or None, headless).
#   connection style None -> straight segment
#   headless True         -> plain line ('-'), otherwise the default arrow head
turn_90_0 = "angle,angleA=90,angleB=0,rad=5"
turn_0_90 = "angle,angleA=0,angleB=90,rad=5"
black_connectors = [
    ((0.13, 0.13), (0.13, 0.74), None, False),
    ((0.23, 0.68), (0.17, 0.74), turn_90_0, True),
    ((0.25, 0.635), (0.21, 0.68), turn_0_90, False),
    ((0.86, 0.68), (0.79, 0.74), turn_90_0, True),
    ((0.88, 0.635), (0.84, 0.68), turn_0_90, False),
    ((0.73, 0.68), (0.77, 0.74), turn_90_0, True),
    ((0.7, 0.635), (0.75, 0.68), turn_0_90, False),
    ((0.645, 0.7), (0.75, 0.74), turn_90_0, True),
    ((0.61, 0.635), (0.66, 0.7), turn_0_90, True),
    ((0.328, 0.6), (0.61, 0.65), turn_90_0, False),
    ((0.328, 0.56), (0.65, 0.56), None, False),
    ((0.81, 0.58), (0.75, 0.58), None, False),
    ((0.324, 0.49), (0.25, 0.54), turn_90_0, True),
    ((0.352, 0.49), (0.86, 0.54), turn_90_0, True),
    ((0.634, 0.45), (0.15, 0.74), turn_90_0, True),
    ((0.662, 0.45), (0.9, 0.54), turn_90_0, True),
    ((0.338, 0.405), (0.338, 0.473), None, False),
    ((0.649, 0.379), (0.649, 0.435), None, False),
    ((0.34, 0.3), (0.482, 0.26), turn_0_90, True),
    ((0.65, 0.3), (0.512, 0.26), turn_0_90, True),
    ((0.497, 0.19), (0.497, 0.242), None, False),
    ((0.632, 0.15), (0.578, 0.15), None, True),
    ((0.663, 0.15), (0.719, 0.097), turn_90_0, True),
    ((0.703, 0.085), (0.23, 0.085), None, True),
    ((0.81, 0.085), (0.732, 0.085), None, False),
]
for end_point, start_point, connection, headless in black_connectors:
    overrides = {}
    if connection is not None:
        overrides['connectionstyle'] = connection
    if headless:
        overrides['arrowstyle'] = '-'
    # arrow=True keeps the default arrow style for entries without overrides
    annotate(ax1, None, end_point, start_point, arrowprops=overrides, arrow=True)

# Unfilled circles of radius 0.02 (centres in ax1 data coordinates)
circle_centres = [
    (0.275, 0.49),
    (0.675, 0.44),
    (0.675, 0.05),
    (0.765, -0.035),
    (0.48, 0.19),
]
circ1, circ2, circ3, circ4, circ5 = [
    patches.Circle(centre, 0.02, lw=lw, ec='k', fill=False)
    for centre in circle_centres
]
for ring in (circ1, circ2, circ3, circ4, circ5):
    ax1.add_patch(ring)
# The fourth circle is centred at a negative y value; its clip box is removed
circ4.set_clip_box(None)

# Operator symbols (text, position in figure fractions, font size)
operator_symbols = [
    (r"$\times$", (0.34, 0.484), 8),
    (r"$\times$", (0.65, 0.446), 8),
    (r"$\bullet$", (0.4975, 0.253), 8),
    (r"$\sum$", (0.649, 0.147), 4.5),
    (r"$\bullet$", (0.7185, 0.08), 8),
]
for symbol, position, size in operator_symbols:
    annotate(ax1, symbol, position, fontsize=size, color="black")

# Images H.png and Li.png, both at zoom 0.2, centred on their anchor points
for image_file, anchor in [("H.png", (0.46, 0.72)), ("Li.png", (0.46, 0.85))]:
    picture = OffsetImage(mpimg.imread(image_file), zoom=0.2)
    ax1.add_artist(
        AnnotationBbox(picture, anchor, frameon=False, xycoords='figure fraction')
    )

def _rounded_box(facecolor, pad=0.8):
    """Return bbox properties for a black-edged rounded text box."""
    return {
        'boxstyle': f"round,pad={pad}",
        'linewidth': lw,
        'edgecolor': "k",
        'facecolor': facecolor,
    }


# Input boxes: an empty two-line placeholder text carries the box, the symbol
# and the shape label are then written on top of it
for centre_x, facecolor, shape_label, symbol in [
    (0.15, elec_color, r"$[E,4]$", r"$\mathbf{x}$"),
    (0.77, nuc_color, r"$[N,4]$", r"$\mathbf{M}$"),
]:
    annotate(
        ax1,
        "\n      ",
        (centre_x, 0.775),
        bbox=_rounded_box(facecolor),
        fontsize=7,
        color="black",
    )
    annotate(ax1, shape_label, (centre_x, 0.76), fontsize=6, color="black")
    annotate(ax1, symbol, (centre_x, 0.79), fontsize=9, color="black")

# Labelled component boxes: (text, position, fill colour)
component_boxes = [
    ("Electron\nTransformer\n" + r"$[E,F_\text{rep}]$", (0.25, 0.58), elec_color),
    ("Nuclei\nMPNN\n" + r"$[N,F_\text{nuc}]$", (0.7, 0.58), nuc_color),
    ("Orbital\nGenerator\n" + r"$[O,N,F_\text{orb}]$", (0.88, 0.58), nuc_color),
    ("Jastrow factor\n" + r"$[1]$", (0.15, 0.09), array_color),
    ("Envelopes\n" + r"$[E,O,D]$", (0.65, 0.34), dataset_color),
    (
        "Projected electron\n representations\n" + r"$[E,O,D]$",
        (0.35, 0.355),
        dataset_color,
    ),
    ("Determinant\n" + r"$[D]$", (0.5, 0.15), array_color),
]
for label, position, facecolor in component_boxes:
    annotate(
        ax1,
        label,
        position,
        bbox=_rounded_box(facecolor),
        fontsize=6,
        color="black",
    )

# Output box: larger font, smaller padding
annotate(
    ax1,
    r"$\Psi$",
    (0.85, 0.085),
    bbox=_rounded_box(array_color, pad=0.5),
    fontsize=11,
    color="black",
)

fig.savefig('architecture.pdf', dpi=600)