# JAX-MD Test with MMML Calculator

Demonstrates using the MMML `spherical_cutoff_calculator` as a JAX-MD energy function for:
1. **FIRE minimization** – energy minimization
2. **NVE simulation** – short microcanonical MD run

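Requires: `jax`, `jax_md`, `e3x`, `pycharmm`, `numpy`, `matplotlib`, the `mmml` package, and a model checkpoint. Set the `MMML_CKPT` environment variable to the checkpoint path, or rely on the default path `mmml/physnetjax/ckpts` used in the config cell below.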

In [None]:
import os
from pathlib import Path

import numpy as np
import jax
import jax.numpy as jnp
from jax import jit
import jax_md
from jax_md import space, simulate, quantity

from mmml.pycharmmInterface.mmml_calculator import setup_calculator
from mmml.pycharmmInterface.cutoffs import CutoffParameters

In [None]:
# Config: checkpoint path, system size
# The checkpoint location is taken from the MMML_CKPT environment variable when
# it is set and non-empty; otherwise a relative default path is used.
ckpt_env = os.environ.get("MMML_CKPT")
ckpt = Path(ckpt_env) if ckpt_env else Path("mmml/physnetjax/ckpts")
if not ckpt.exists():
    # Raise explicitly instead of asserting: assert statements are removed when
    # Python runs with -O, which would let a missing checkpoint slip through.
    raise FileNotFoundError(f"Checkpoint not found: {ckpt}")

# System definition: n_monomers monomers of n_atoms_monomer atoms each.
n_monomers = 2
n_atoms_monomer = 10
n_atoms = n_monomers * n_atoms_monomer
atomic_numbers = np.full(n_atoms, 6)  # carbon

In [None]:
# Build the calculator factory with the ML term enabled and the MM term
# disabled, sized for the system defined in the config cell.
factory = setup_calculator(
    ATOMS_PER_MONOMER=n_atoms_monomer,
    N_MONOMERS=n_monomers,
    doML=True,
    doMM=False,
    model_restart_path=ckpt,
    MAX_ATOMS_PER_SYSTEM=n_atoms,
)

cutoff_params = CutoffParameters()
# All-zero coordinates are passed as placeholders here; the energy functions
# below pass the actual positions on every call.
calc, spherical_cutoff_calculator = factory(
    atomic_numbers=atomic_numbers,
    atomic_positions=np.zeros((n_atoms, 3)),
    n_monomers=n_monomers,
    cutoff_params=cutoff_params,
)

# Split before sampling so that no PRNG key is used twice: `pos_key` is
# consumed for the initial positions, while `key` stays unconsumed and can be
# split again later (the NVE section does so for the initial velocities).
key = jax.random.PRNGKey(42)
key, pos_key = jax.random.split(key)
R0 = jnp.asarray(
    jax.random.uniform(pos_key, (n_atoms, 3), minval=0.0, maxval=10.0),
    dtype=jnp.float32,
)

## FIRE minimization

In [None]:
@jit
def jax_md_energy_fn(position, **kwargs):
    """Return a single energy value for ``position`` via the MMML calculator.

    Args:
        position: Atomic coordinates forwarded to the calculator as
            ``positions``.
        **kwargs: Accepted so the function can be called with extra keyword
            arguments; they are not used.

    Returns:
        The first element of the calculator's flattened ``energy`` output.
    """
    calculator_inputs = dict(
        positions=position,
        atomic_numbers=jnp.array(atomic_numbers),
        n_monomers=n_monomers,
        cutoff_params=cutoff_params,
    )
    result = spherical_cutoff_calculator(**calculator_inputs)
    # Flatten first so indexing works regardless of the energy's shape.
    flat_energy = result.energy.reshape(-1)
    return flat_energy[0]

# Free (non-periodic) space; `shift` is reused by the NVE section below.
displacement, shift = space.free()

# FIRE descent with the start and maximum time step both set to 0.001.
init_fn, step_fn = jax_md.minimize.fire_descent(
    jax_md_energy_fn,
    shift,
    dt_start=0.001,
    dt_max=0.001,
)
step_fn = jit(step_fn)

# Energy of the random starting geometry, for comparison after minimization.
E0 = float(jax_md_energy_fn(R0))
print(f"Initial energy: {E0:.6f}")

# Run a fixed number of FIRE steps from R0.
n_steps = 20
state = init_fn(R0)
for _ in range(n_steps):
    state = step_fn(state)

E_final = float(jax_md_energy_fn(state.position))
print(f"Final energy:   {E_final:.6f}")
print(f"Energy change:  {E_final - E0:.6f}")

## NVE simulation

In [None]:
@jit
def energy_fn(position, **kwargs):
    """Potential energy used by the NVE integrator.

    Calls the MMML calculator with ``position`` as the coordinates and returns
    the first element of its flattened ``energy`` output. Extra keyword
    arguments are accepted but not used.
    """
    species = jnp.array(atomic_numbers)
    calc_output = spherical_cutoff_calculator(
        positions=position,
        atomic_numbers=species,
        n_monomers=n_monomers,
        cutoff_params=cutoff_params,
    )
    energy_flat = calc_output.energy.reshape(-1)
    return energy_flat[0]

# Simulation parameters: per-atom mass, integration time step, and the kT
# value passed to the NVE initializer.
masses = jnp.full((n_atoms,), 12.0, dtype=jnp.float32)  # carbon amu
dt = 1e-3
kT = 0.001

init_fn_nve, apply_fn_nve = simulate.nve(energy_fn, shift, dt)
apply_fn_nve = jit(apply_fn_nve)

# Start the MD run from the FIRE-minimized positions, using a fresh subkey.
key, vel_key = jax.random.split(key)
state_nve = init_fn_nve(vel_key, state.position, kT, mass=masses)

# Advance the integrator and record (potential, kinetic) after every step.
n_md_steps = 50
energies = []
for _ in range(n_md_steps):
    state_nve = apply_fn_nve(state_nve)
    Ep = float(energy_fn(state_nve.position))
    Ek = float(
        quantity.kinetic_energy(
            momentum=state_nve.momentum,
            mass=state_nve.mass,
        )
    )
    energies.append((Ep, Ek))

# Transpose the list of pairs into one sequence per energy component.
Ep_vals, Ek_vals = zip(*energies)
Etot = np.asarray(Ep_vals) + np.asarray(Ek_vals)
print(f"Potential energy: {Ep_vals[0]:.6f} -> {Ep_vals[-1]:.6f}")
print(f"Total energy std: {np.std(Etot):.2e}")

In [None]:
import matplotlib.pyplot as plt

# Plot the potential, kinetic, and total energy recorded at each NVE step.
fig, ax = plt.subplots(figsize=(8, 4))
for series, series_label in (
    (Ep_vals, "Potential"),
    (Ek_vals, "Kinetic"),
    (Etot, "Total"),
):
    ax.plot(series, label=series_label)
ax.set_xlabel("Step")
ax.set_ylabel("Energy")
ax.legend()
plt.tight_layout()
plt.show()