# Determinant penalty ablation study

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from oneqmc.analysis.plot import set_defaults
from oneqmc.analysis import colours

# Apply the plotting defaults provided by oneqmc.analysis.plot before any
# figure in this notebook is drawn.
set_defaults()

In [None]:
# Shared x-axis for the energy curves: evenly spaced values from 0 (inclusive)
# to 200 (exclusive) in steps of 0.025.
t = np.arange(start=0, stop=200, step=0.025)

In [None]:
# Open the archive of smoothed energy curves for the determinant-penalty
# ablation; individual curves are looked up by run name when plotting.
smooth_energy = np.load(
    file="../../experiment_results/08_si_ablations/det_penalty/energy.npz",
)

In [None]:
# Overlay the smoothed energy curve of every ablation run on a single figure.
# Each entry is (key in the energy archive, legend label, extra line styling).
energy_curves = [
    ("standard_psiformer", "Psiformer", {"c": colours.TEAL}),
    ("nopretrain_psiformer", "Psiformer no HF", {"c": colours.YELLOW}),
    (
        "onedet_psiformer",
        "Psiformer 1 determinant",
        {"linestyle": "dashed", "c": colours.PURPLE},
    ),
    (
        "onedet_nopretrain_psiformer",
        "Psiformer 1 determinant no HF",
        {"linestyle": "dashed", "c": colours.ORANGE},
    ),
    # No explicit colour here: matplotlib's default colour cycle is used.
    ("orbformer", "Orbformer no HF", {}),
    (
        "orbformer_no_pen",
        "Orbformer no HF, no penalty,\nold initialization",
        {"c": "tab:green"},
    ),
]

plt.figure(figsize=(7.5, 5.25))
for run_key, legend_label, line_style in energy_curves:
    plt.plot(t, smooth_energy[run_key], label=legend_label, **line_style)

# Restrict the y-range to a narrow energy window so the curves can be told apart.
plt.ylim([-40.5148, -40.5132])
plt.legend()
plt.xlabel("Iteration (thousands)")
plt.ylabel("Smoothed energy (Ha)")
plt.tight_layout()
plt.show()


# Langevin sampling ablation study

In [None]:
# Sampling statistics for the two runs compared below; the plot legends label
# the first "MALA" and the second "ULA+MALA".
collated_stats = np.load(
    file="../../experiment_results/08_si_ablations/langevin/mala.npz",
)
collated_stats2 = np.load(
    file="../../experiment_results/08_si_ablations/langevin/ula_mala.npz",
)
# Iteration index, sized from the MALA run's step-size series.
# NOTE(review): the ULA+MALA series are plotted against this same axis, so they
# are assumed to have the same length -- confirm.
time = np.arange(collated_stats["sampling/tau"].shape[0])

In [None]:
# Three stacked panels sharing the iteration axis: step size (top), mean
# electron-electron distance (middle) and mean log|Psi| (bottom).
fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, sharex=True, figsize=(5, 7))

# One entry per sampler run: (statistics archive, legend label, line colour).
sampler_runs = [
    (collated_stats, "MALA", colours.BRAND_TEAL),
    (collated_stats2, "ULA+MALA", colours.ORANGE),
]

# Top panel: the step size series is plotted as stored.
for run_stats, run_label, run_colour in sampler_runs:
    ax1.plot(time, run_stats["sampling/tau"], label=run_label, c=run_colour)
# Dashed marker at iteration 150, drawn in the ULA+MALA colour.
# NOTE(review): presumably the point where that run switches sampler -- confirm.
ax1.axvline(150, c=colours.ORANGE, linestyle="--")
ax1.legend()
ax1.set_ylabel("Step size $\\tau^{(t)}$")

# Remaining panels: each statistic is averaged over its last axis first.
averaged_panels = [
    (ax3, "sampling/log_psi", "Mean $\\log | \\Psi_\\theta|$"),
    (ax2, "sampling/pdists", "Mean elec-elec distance"),
]
for panel, stat_key, y_label in averaged_panels:
    for run_stats, run_label, run_colour in sampler_runs:
        panel.plot(
            time,
            run_stats[stat_key].mean(-1),
            label=run_label,
            c=run_colour,
        )
    panel.axvline(150, c=colours.ORANGE, linestyle="--")
    panel.set_ylabel(y_label)

# Only the bottom panel carries the shared x-axis label.
ax3.set_xlabel("Iteration $t$")

plt.tight_layout()
# Tighten the vertical gap between panels after tight_layout has run.
plt.subplots_adjust(wspace=0, hspace=0.07)

# Iteration timing ablation study

### Configurations investigated

1. Psiformer, no flash attention, loop Laplacian, float32
2. Orbformer, no flash attention, loop Laplacian, float32
3. Orbformer, no flash attention, loop Laplacian, TF32
4. Orbformer, no flash attention, forward Laplacian, TF32
5. Orbformer, flash attention, forward Laplacian, TF32

Each configuration can be run with the correspondingly numbered command below.

1. `python scripts/transferable.py -a psiformer-new -d SYSTEM --electron-batch-size 256 -n 100 --max-eq-steps 10 --metric-logger-period 25 --metric-logger tb h5 --det-penalty-weight 0.0 --no-balance-grad --no-flash-attn --laplacian loop --jax-matmul-precision highest`
2. `python scripts/transferable.py -a orbformer-se -d SYSTEM --electron-batch-size 256 -n 100 --max-eq-steps 10 --metric-logger-period 25 --metric-logger tb h5 --det-penalty-weight 0.0 --no-balance-grad --no-flash-attn --laplacian loop --jax-matmul-precision highest`
3. `python scripts/transferable.py -a orbformer-se -d SYSTEM --electron-batch-size 256 -n 100 --max-eq-steps 10 --metric-logger-period 25 --metric-logger tb h5 --det-penalty-weight 0.0 --no-balance-grad --no-flash-attn --laplacian loop --jax-matmul-precision high`
4. `python scripts/transferable.py -a orbformer-se -d SYSTEM --electron-batch-size 256 -n 100 --max-eq-steps 10 --metric-logger-period 25 --metric-logger tb h5 --det-penalty-weight 0.0 --no-balance-grad --no-flash-attn --laplacian forward --jax-matmul-precision high`
5. `python scripts/transferable.py -a orbformer-se -d SYSTEM --electron-batch-size 256 -n 100 --max-eq-steps 10 --metric-logger-period 25 --metric-logger tb h5 --det-penalty-weight 0.0 --no-balance-grad --flash-attn --laplacian forward --jax-matmul-precision high`

In [None]:
# Legend label for each benchmark system, keyed by its dataset identifier.
names = dict(
    mep_2=r"Ethane ($E=18$)",
    mep_418=r"Formamide ($E=24$)",
    mep_548=r"Propanol ($E=34$)",
    mep_1327=r"2-Aminopropan-2-ol ($E=42$)",
    mep_1106=r"L-Alanine ($E=48$)",
)

# Five timing measurements per system. The plotting cell treats entry 0 as the
# baseline and entries 1-4 as the variants compared against it.
# NOTE(review): presumably ordered like the five configurations listed above;
# the units are not stated in this notebook -- confirm.
data = {
    system_id: np.array(timings)
    for system_id, timings in (
        ("mep_2", [47.0, 85.0, 78.0, 51.0, 47.0]),
        ("mep_418", [64.0, 112.0, 81.0, 57.0, 54.0]),
        ("mep_548", [141.0, 232.0, 224.0, 111.0, 101.0]),
        ("mep_1327", [211.0, 330.0, 309.0, 150.0, 134.0]),
        ("mep_1106", [259.0, 392.0, 257.0, 159.0, 153.0]),
    )
}

In [None]:
# Horizontal bar chart of iteration timings relative to the first (baseline)
# entry of each system's timing array.

# One bar colour per system, in the iteration order of `data`.
categorical = [
    colours.BRAND_TEAL,
    colours.YELLOW,
    colours.BRAND_PURPLE,
    colours.GREY,
    colours.BRAND_RED,
]

plt.figure(figsize=(10, 5.4))
for i, (k, v) in enumerate(data.items()):
    plt.barh(
        # Rows 4..1 hold the non-baseline entries; each system is shifted by a
        # multiple of 0.17 so the five bars of a row sit side by side around
        # the row's tick.
        np.arange(len(v) - 1, 0, -1) + 0.17 * (2 - i),
        # With left=1 and width=ratio-1, every bar spans from 1 to the timing
        # ratio, so a bar ending left of 1 is faster than the baseline.
        v[1:] / v[0] - 1,
        height=0.14,
        left=1,
        label=names[k],
        color=categorical[i],
    )

# Select the log scale exactly once, BEFORE customising ticks: setting the
# scale installs default log locators/formatters, so repeating the call after
# xticks()/minorticks_off() would throw the custom ticks away.
plt.xscale("log")
plt.minorticks_off()
plt.xticks([0.5, 0.7, 1, 1.4, 2], [0.5, 0.7, 1, 1.4, 2])
# Reference line at a ratio of 1, i.e. equal to the baseline.
plt.axvline(1, c="k")
plt.yticks(
    [4, 3, 2, 1],
    ["Naive Orbformer", "+ TF32", "+ Forward Laplacian", "+ Flash Attention"],
)
plt.xlabel("Iteration timing relative to naive Psiformer")
plt.legend(loc="lower right")
plt.tight_layout()