In [1]:
import torch
import parmed as pmd
from openmm import app, unit, Context, VerletIntegrator
from openmm.app import AmberPrmtopFile
from openmm.openmm import HarmonicAngleForce, PeriodicTorsionForce, NonbondedForce
from parmed.openmm import energy_decomposition
import copy
import system.system as sys
import system.units as units
import system.box as box
from utils import * 

# --- Reference side: Amber topology/coordinates evaluated with OpenMM ---
parm = pmd.load_file("alanine-dipeptide.prmtop", xyz="alanine-dipeptide.pdb")
# Non-periodic, unconstrained system so every bonded term is present as a force term.
system = parm.createSystem(nonbondedMethod=app.NoCutoff, constraints=None)
# The (single) NonbondedForce; its exception list is inspected further below.
nonbonded = next(
    force for force in system.getForces() if isinstance(force, NonbondedForce)
)

# --- Reference energies from OpenMM (kJ/mol) ---
# The integrator is only required to build a Context; no dynamics are run.
integrator = VerletIntegrator(1 * unit.femtoseconds)
ctx = Context(system, integrator)
ctx.setPositions(parm.positions)
tot_E = ctx.getState(getEnergy=True).getPotentialEnergy()
# Per-force-group energies from ParmEd, requested directly in kJ/mol.
terms = energy_decomposition(parm, ctx, nrg=unit.kilojoule_per_mole)

# Re-key ParmEd's group names onto the category names used in the comparison table.
openmm_breakdown = {
    ours: terms[theirs]
    for ours, theirs in (
        ('bond', 'bond'),
        ('angle', 'angle'),
        ('dihedral', 'dihedral'),
        ('non_bonded', 'nonbonded'),
    )
}
openmm_breakdown['total'] = tot_E.value_in_unit(unit.kilojoule_per_mole)

# --- Setup SPARK system ---
# Fall back to the CPU so the comparison can still run on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
# SPARK runs in the AKMA unit system (energies in kcal/mol). One thermochemical
# kcal is exactly 4.184 kJ, so this factor converts SPARK energies to OpenMM's kJ/mol.
to_kjmol = 4.184
# build_top_and_features presumably comes from `from utils import *` -- confirm.
top, node_features, mass, energy_dict = build_top_and_features("alanine-dipeptide.prmtop")
# ParmEd coordinates with a leading axis of size 1 added
# (presumably a batch/replica dimension expected by sys.System -- confirm).
pos = torch.tensor(parm.coordinates, dtype=dtype, device=device).unsqueeze(0)
# Open (free-expansion) box: the edge length of 1000 does not matter here.
b = box.Box([1000, 1000, 1000], ["s", "s", "s"])
u = units.UnitSystem.akma()
# Random momenta: only potential energies are compared below, so their values are irrelevant.
mom = 0.5 * torch.randn_like(pos)
S = sys.System(pos, mom, mass, top, b, energy_dict, u, node_features)
S.compile_force_fn()

# Fold SPARK's per-term energies into the same four categories as OpenMM.
# Keys matching none of the categories are reported rather than silently dropped.
# Otherwise they would only surface as an unexplained gap between sum(parts)
# and the reported total.
spark_breakdown = {k: 0.0 for k in ['bond', 'angle', 'dihedral', 'non_bonded']}
unmatched_spark_terms = []
for key, val in S.potential_energy_split().items():
    e = val.item() * to_kjmol
    if key.startswith('bond'):
        spark_breakdown['bond'] += e
    elif key.startswith('angle'):
        spark_breakdown['angle'] += e
    elif key.startswith('dih'):
        spark_breakdown['dihedral'] += e
    elif key in ['LJ', 'LJ_Fudge', 'coulomb', 'coulomb_Fudge']:
        spark_breakdown['non_bonded'] += e
    else:
        unmatched_spark_terms.append((key, e))
if unmatched_spark_terms:
    print(f"Warning: SPARK energy terms not assigned to any category (kJ/mol): {unmatched_spark_terms}")
spark_breakdown['total'] = S.potential_energy().cpu().item() * to_kjmol

# --- Energy breakdown comparison ---
# Table layout:
#   - one row per energy category, in a fixed order;
#   - the sum of those rows;
#   - each engine's own reported total.
# A gap between the last two rows usually means some energy term was not mapped
# to any of the four categories.
breakdown_terms = ('bond', 'angle', 'dihedral', 'non_bonded')
header_fmt = "{:<14} {:>20} {:>20} {:>20}"
row_fmt = "{:<14} {:20.8f} {:20.8f} {:20.8f}"
header = header_fmt.format("Term", "OpenMM (kJ/mol)", "SPARK (kJ/mol)", "Δ (O - S)")
# Size the rule to the actual table width (14 + 3 * (1 + 20) = 77 columns).
# A hard-coded width falls out of sync with the format strings.
rule = "-" * len(header)
print(header)
print(rule)
for key in breakdown_terms:
    o, s = openmm_breakdown[key], spark_breakdown[key]
    print(row_fmt.format(key, o, s, o - s))
print(rule)
openmm_sum = sum(openmm_breakdown[k] for k in breakdown_terms)
spark_sum = sum(spark_breakdown[k] for k in breakdown_terms)
print(row_fmt.format("sum(parts)", openmm_sum, spark_sum, openmm_sum - spark_sum))
print(row_fmt.format("reported total", openmm_breakdown['total'], spark_breakdown['total'],
                     openmm_breakdown['total'] - spark_breakdown['total']))

# --- Extract and compare bonded terms ---
# Each term is reduced to a canonical index tuple, so the OpenMM and SPARK lists
# can be compared as sets regardless of the direction either side lists a term in:
#   bond      i-j      -> (min, max)
#   angle     i-j-k    -> (min(i, k), j, max(i, k))   the apex j stays in the middle
#   dihedral  i-j-k-l  -> the smaller of (i, j, k, l) and its reverse (l, k, j, i)


def _canonical_angle(i, j, k):
    """Return angle i-j-k with its end atoms ordered and the apex kept central.

    Fully sorting the triple would lose which atom is the apex. Two different
    angles over the same three atoms would then collide, and a wrong apex on
    either side would go unnoticed.
    """
    return (min(i, k), j, max(i, k))


def _canonical_dihedral(a, b, c, d):
    """Return torsion a-b-c-d or its reverse d-c-b-a, whichever sorts first.

    A torsion and its reverse define the same dihedral angle. Canonicalizing
    makes the comparison independent of the direction each side lists atoms in.
    """
    forward = (a, b, c, d)
    return min(forward, forward[::-1])


# OpenMM side. Every force exposing getNumBonds() contributes bonds.
omm_bonds = {
    tuple(sorted((int(a), int(b))))
    for f in system.getForces() if hasattr(f, "getNumBonds")
    for i in range(f.getNumBonds())
    for a, b, *_ in [f.getBondParameters(i)]
}
omm_angles = {
    _canonical_angle(int(a), int(b), int(c))
    for f in system.getForces() if isinstance(f, HarmonicAngleForce)
    for i in range(f.getNumAngles())
    for a, b, c, *_ in [f.getAngleParameters(i)]
}
# The set collapses repeated quartets (e.g. several periodicities on one torsion).
omm_dihedrals = {
    _canonical_dihedral(int(a), int(b), int(c), int(d))
    for f in system.getForces() if isinstance(f, PeriodicTorsionForce)
    for i in range(f.getNumTorsions())
    for a, b, c, d, *_ in [f.getTorsionParameters(i)]
}

# SPARK side. Indices are converted to plain ints so they hash and compare
# exactly like OpenMM's indices, whatever integer-like type the topology stores.
bonded_top = {
    tuple(sorted((int(i), int(j))))
    for label, entries in top.get_arity(2).items()
    if label.startswith("bondtype_")
    for i, j in entries
}
angle_top = {
    _canonical_angle(int(i), int(j), int(k))
    for label, entries in top.get_arity(3).items()
    if label.startswith("angletype_")
    for i, j, k in entries
}
dihedral_top = {
    _canonical_dihedral(int(i), int(j), int(k), int(m))
    for label, entries in top.get_arity(4).items()
    if label.startswith("dihtype_")
    for i, j, k, m in entries
}

def compare(label, a, b):
    """Print a set-difference report for two collections of index tuples.

    ``a`` is the OpenMM-derived collection and ``b`` the SPARK-derived one.
    The report shows how many entries they share, then the sorted entries
    found only in ``a`` and the sorted entries found only in ``b``.
    """
    print(f"\nComparing {label}")
    n_shared = len(a & b)
    print(f"✓ Matched: {n_shared}")
    missing_in_spark = sorted(a - b)
    print(f"x In OpenMM not SPARK: {missing_in_spark}")
    extra_in_spark = sorted(b - a)
    print(f"! In SPARK not OpenMM: {extra_in_spark}")

# Bonded-topology checks: OpenMM-derived set (a) versus SPARK-derived set (b).
compare(label="bonds (1–2)", a=omm_bonds, b=bonded_top)
compare(label="angles (1–3)", a=omm_angles, b=angle_top)
compare(label="dihedrals (1–4)", a=omm_dihedrals, b=dihedral_top)

# --- Nonbonded exclusions and exceptions ---
# OpenMM stores both full exclusions and scaled 1-4 interactions as
# NonbondedForce exceptions:
#   - chargeProd and epsilon both exactly zero -> the pair is switched off
#     entirely, i.e. a 1-2/1-3 exclusion;
#   - any other exception -> treated as a scaled 1-4 pair.
# NOTE(review): a genuine 1-4 pair whose scaled chargeProd and epsilon both happen
# to be zero would be classified as an exclusion here.
zero_charge_prod = 0.0 * unit.elementary_charge**2
zero_epsilon = 0.0 * unit.kilojoule_per_mole
excluded_pairs = set()
one_four_exceptions = set()
# Dedicated loop names: unpacking into `a, b` at module level would overwrite
# the global `b` (the SPARK box built earlier in the script).
for exc_idx in range(nonbonded.getNumExceptions()):
    p1, p2, charge_prod, _sigma, epsilon = nonbonded.getExceptionParameters(exc_idx)
    pair = tuple(sorted((int(p1), int(p2))))
    if charge_prod == zero_charge_prod and epsilon == zero_epsilon:
        excluded_pairs.add(pair)
    else:
        one_four_exceptions.add(pair)

# SPARK side. Pairs are converted to plain ints so they hash and compare like
# OpenMM's indices:
#   - scaled 1-4 LJ pairs (LJ_Fudge);
#   - 1-2 pairs (bonds) plus 1-3 pairs (the end atoms of every angle).
lj_fudge_pairs = {tuple(sorted(int(x) for x in p)) for p in top.get(2, 'LJ_Fudge')}
top_exclusions = bonded_top | {
    tuple(sorted((int(i), int(k))))
    for label, entries in top.get_arity(3).items()
    if label.startswith("angletype_")
    for i, _, k in entries
}

compare("1–4 exceptions vs LJ_Fudge", one_four_exceptions, lj_fudge_pairs)
compare("1–2/1–3 exclusions", excluded_pairs, top_exclusions)


╔═══════════════════════════════════════════════════╗
║                                                   ║
║  ██████╗   ██████╗    ██╗      ██████╗   ██╗  ██╗ ║
║ ██╔════╝  ██╔══██╗   ██╔██╗    ██╔══██╗  ██║ ██╔╝ ║
║ ╚█████╗   ██████╔╝  ██╔╝╚██╗   ██████╔╝  █████╔╝  ║
║  ╚═══██╗  ██╔═══╝  ██╔╝  ╚██╗  ██╔══██╗  ██╔═██╗  ║
║ ██████╔╝  ██║     ██╔╝    ╚██╗ ██║  ██║  ██║ ╚██╗ ║
║ ╚═════╝   ╚═╝     ╚═╝      ╚═╝ ╚═╝  ╚═╝  ╚═╝  ╚═╝ ║
║                                                   ║
║     Statistical Physics Autodiff Research Kit     ║
╚═══════════════════════════════════════════════════╝

          V(r)           ψ, φ              q
           │               │               │
           ○               ○               ○
         ╱ | ╲           ╱ | ╲           ╱ | ╲
        ○  ○  ○         ○  ○  ○         ○  ○  ○
         ╲ | ╱           ╲ | ╱           ╲ | ╱
           ○               ○               ○
           │               │               │
          g(r)             F         