# jaxmm Quick Start

Core API in 5 minutes. Extract force field parameters from OpenMM once,
then evaluate energy, compute forces, and batch-process configurations
in pure JAX.

In [1]:
import warnings
# Blanket filter: hides every warning category from here on. Keeps the demo
# output clean, but deprecation notices are hidden as well.
warnings.filterwarnings("ignore")

import numpy as np
import jax
# Enable 64-bit mode right after importing jax, before any arrays are created.
jax.config.update("jax_enable_x64", True)  # float64 required
import jax.numpy as jnp
import jax.random as random

from openmm import unit
from openmmtools import testsystems

import jaxmm
from jaxmm import FEMTOSECOND



## Extract parameters

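OpenMM builds the molecule and assigns force field parameters.
`extract_params` pulls everything into frozen dataclasses of JAX arrays.
After this, OpenMM is no longer needed for energy evaluation, forces, or MD;
in this notebook only the starting positions and the topology handed to the
viewer still come from OpenMM objects.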

In [2]:
# Alanine dipeptide in vacuum, constraints=None (22 atoms, 66 DOF)
aldp = testsystems.AlanineDipeptideVacuum(constraints=None)

# One-time extraction (uses OpenMM)
params = jaxmm.extract_params(aldp.system)

# Strip the OpenMM units (nanometers) and keep the coordinates as float64
positions_nm = aldp.positions.value_in_unit(unit.nanometer)
pos = jnp.array(positions_nm, dtype=jnp.float64)

# Number of entries per bonded term = length of its first index array
n_bonds = params.bonds.atom_i.shape[0]
n_angles = params.angles.atom_i.shape[0]
n_torsions = params.torsions.atom_i.shape[0]
has_gbsa = params.gbsa is not None

print(f"Atoms: {params.n_atoms}")
print(f"Bonds: {n_bonds}")
print(f"Angles: {n_angles}")
print(f"Torsions: {n_torsions}")
print(f"GBSA: {'yes' if has_gbsa else 'no'}")

Atoms: 22
Bonds: 21
Angles: 36
Torsions: 52
GBSA: no


## Evaluate energy

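Energy functions share the call pattern `f(positions, params)`.
`total_energy` returns a scalar in kJ/mol, `energy_components` returns a dict
of per-term contributions, and `log_boltzmann` additionally takes a
`temperature` and returns the log Boltzmann factor $-E / (k_B T)$.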

In [3]:
# Total potential energy of the starting geometry
energy = jaxmm.total_energy(pos, params)
total_kj = float(energy)
print(f"Total energy: {total_kj:.4f} kJ/mol")

# Per-term decomposition, one right-aligned row per term
components = jaxmm.energy_components(pos, params)
for term, term_energy in components.items():
    label = f"{term:>10s}"
    value = f"{float(term_energy):10.4f}"
    print(f"  {label}: {value} kJ/mol")

# Log Boltzmann factor: -E / (kB * T)
lp = jaxmm.log_boltzmann(pos, params, temperature=300.0)
log_p = float(lp)
print(f"\nlog p(x) at 300K: {log_p:.4f}")

Total energy: -88.0886 kJ/mol
       bonds:     0.0862 kJ/mol
      angles:     1.5144 kJ/mol
    torsions:     8.0563 kJ/mol
   nonbonded:   -97.7455 kJ/mol

log p(x) at 300K: 35.3154


## Forces via jax.grad

Forces are the negative gradient of energy w.r.t. positions.
Because jaxmm is pure JAX, this is automatic.

In [4]:
# jax.grad differentiates with respect to the first argument (positions)
energy_gradient = jax.grad(jaxmm.total_energy)

# Forces = -dE/dx
forces = -energy_gradient(pos, params)

# Euclidean norm over the last axis gives one magnitude per atom
per_atom_magnitude = jnp.linalg.norm(forces, axis=-1)
max_magnitude = float(per_atom_magnitude.max())

print(f"Forces shape: {forces.shape}")
print(f"Max force magnitude: {max_magnitude:.2f} kJ/mol/nm")

Forces shape: (22, 3)
Max force magnitude: 867.10 kJ/mol/nm


## Batch evaluation with vmap

`jax.vmap` vectorizes energy evaluation across a batch of configurations.
Combined with `jax.jit`, this gives large speedups over sequential evaluation.

In [5]:
# Generate a batch of configurations via short Langevin MD.
# n_steps and save_every are declared static for jit.
simulate = jax.jit(
    jaxmm.langevin_baoab, static_argnames=("n_steps", "save_every")
)
zeros_like_pos = jnp.zeros_like(pos)  # second positional argument, all zeros
result = simulate(
    pos, zeros_like_pos, params,
    dt=1.0 * FEMTOSECOND, temperature=300.0, friction=1.0,
    n_steps=5000, save_every=100, key=random.key(0),
)
batch = result.trajectory_positions  # (50, 22, 3)

# Vectorized energy evaluation: map over the leading axis of the positions
# while sharing one set of parameters, then compile the mapped function.
per_config_energy = jax.vmap(jaxmm.total_energy, in_axes=(0, None))
batch_energy = jax.jit(per_config_energy)
energies = batch_energy(batch, params)

e_lo = float(energies.min())
e_hi = float(energies.max())
print(f"Batch shape: {batch.shape}")
print(f"Energies: {energies.shape}")
print(f"Range: [{e_lo:.1f}, {e_hi:.1f}] kJ/mol")

Batch shape: (50, 22, 3)
Energies: (50,)
Range: [-101.6, -25.1] kJ/mol


## Minimize

L-BFGS energy minimization, pure JAX.

In [6]:
# Energy of the starting geometry, for comparison with the minimized one
e_before = float(jaxmm.total_energy(pos, params))

# Minimize, then re-evaluate the energy at the new geometry
pos_min = jaxmm.minimize_energy(pos, params)
e_min = jaxmm.total_energy(pos_min, params)
e_after = float(e_min)

print(f"Energy before: {e_before:.2f} kJ/mol")
print(f"Energy after:  {e_after:.2f} kJ/mol")

Energy before: -88.09 kJ/mol
Energy after:  -118.45 kJ/mol


## Visualize the molecule

Interactive 3D view at the minimized geometry. Requires `py3Dmol`.

In [None]:
# Imported in this cell rather than the top one: this section needs py3Dmol
from jaxmm.notebook import show_structure

# Viewer dimensions passed through to show_structure
viewer_size = {"width": 600, "height": 400}
view = show_structure(pos_min, aldp.topology, **viewer_size)
view.show()

## Save and load parameters

Serialize to `.npz` (no pickle). After saving, OpenMM is not needed to
reload and use the parameters.

In [None]:
# Write the parameters to an .npz archive and read them straight back.
# The path is named once so save and load cannot drift apart.
params_path = "aldp_vacuum_params.npz"
jaxmm.save_params(params, params_path)
params_loaded = jaxmm.load_params(params_path)

# Verify roundtrip: the reloaded parameters must reproduce the energy
e_original = jaxmm.total_energy(pos, params)
e_loaded = jaxmm.total_energy(pos, params_loaded)
roundtrip_error = float(abs(e_original - e_loaded))
print(f"Energy match: {roundtrip_error:.1e} kJ/mol difference")

# Fail loudly under Restart & Run All instead of relying on someone reading
# the printed number; 1e-6 kJ/mol is far below any physically relevant scale.
assert roundtrip_error < 1e-6, (
    f"save/load roundtrip changed the energy by {roundtrip_error:.3e} kJ/mol"
)

## Summary

| Step | Function | Needs OpenMM? |
|------|----------|---------------|
| Extract params | `jaxmm.extract_params(system)` | Yes (one-time) |
| Save params | `jaxmm.save_params(params, path)` | No |
| Load params | `jaxmm.load_params(path)` | No |
| Energy | `jaxmm.total_energy(pos, params)` | No |
| Forces | `-jax.grad(jaxmm.total_energy)(pos, params)` | No |
| Batch | `jax.vmap(jaxmm.total_energy, in_axes=(0, None))` | No |
| Minimize | `jaxmm.minimize_energy(pos, params)` | No |
| Log Boltzmann | `jaxmm.log_boltzmann(pos, params, temperature=T)` | No |
| MD | `jaxmm.langevin_baoab(...)` / `jaxmm.verlet(...)` | No |