# MD from scratch

# BIOE/CME 209 Homework

In [None]:
from pathlib import Path

import jax
import jax.numpy as jnp
import mdtraj as md
import numpy as np
import py3Dmol

# OpenMM is used only to add hydrogens (preprocessing step)
from openmm.app import ForceField, Modeller, PDBFile

from mdfs.energy import BondedSet, DSFParams, NonbondedSet
from mdfs.params import (
    assign_types_and_charges_from_templates,
    lj_to_jax,
    load_ffxml,
    prepare_bonded_arrays_for_energy,
    prepare_nonbonded_scaling,
)
from mdfs.simulate import run, simulate_langevin

# Import your project modules
from mdfs.topology import load_topology_mdtraj


## Input paths

In [None]:
# All paths are relative to the notebook's working directory.
pdb_in, pdb_with_h = (
    Path("../assets/poly_A.pdb"),  # AlphaFold PDB without hydrogens
    Path("./poly_A_withH.pdb"),  # temporary output: protonated PDB written below
)
ffxml_path = Path("../src/mdfs/ffxml/amber19/protein.ff19SB.xml")  # Amber ff19SB in FFXML form

print("Using FFXML:", ffxml_path.resolve())


## Add hydrogens with OpenMM Modeller (protein pH 7)

In [None]:
# Read the heavy-atom model and wrap it in a Modeller so atoms can be added.
pdb = PDBFile(str(pdb_in))
modeller = Modeller(pdb.topology, pdb.positions)

# OpenMM's bundled Amber19 force field supplies the residue templates that
# addHydrogens uses for placement; no water model is needed for this step.
amber19 = ForceField("amber19-all.xml")
modeller.addHydrogens(forcefield=amber19, pH=7.0)

# Persist the protonated structure so the following cells can reload it.
with pdb_with_h.open("w") as handle:
    PDBFile.writeFile(modeller.topology, modeller.positions, handle)
print("Wrote:", pdb_with_h.resolve())


## Load positions/topology via MDTraj

In [None]:
# Reload the protonated structure with MDTraj; coordinates come back in nm.
traj0 = md.load(str(pdb_with_h))
top0 = traj0.topology

# Frame 0 as an (N, 3) float64 array.
R0_nm = np.asarray(traj0.xyz[0], dtype=np.float64)
N = len(R0_nm)

print("Atoms:", N, "| Residues:", top0.n_residues, "| Chains:", top0.n_chains)


In [None]:
# ----------------------------------------------------------
# 4) Visualize the starting structure (py3Dmol, cartoon + sticks)
# ----------------------------------------------------------
view = py3Dmol.view(width=600, height=420)
# Path.read_text opens and closes the file; a bare open(...).read() leaks the handle.
view.addModel(pdb_with_h.read_text(), "pdb")
view.setStyle({"cartoon": {"color": "spectrum"}})
# 3Dmol atom selections take element filters directly as {"elem": ...}.
# The "atom" key selects by atom *name*, so {"atom": {"elem": ...}} matched nothing.
view.addStyle({"elem": "H"}, {"stick": {}})  # emphasize hydrogens
view.addStyle({"elem": ["C", "N", "O", "S"]}, {"stick": {}})
view.zoomTo()
view.show()


In [None]:
# ----------------------------------------------------------
# 5) Build graph/topology arrays for bonded terms (MDTraj)
# ----------------------------------------------------------

top_arrays = load_topology_mdtraj(str(pdb_with_h))
# top_arrays: positions (nm), bonds, angles, dihedrals, onefour, atom/res names, box_lengths (or None)
assert top_arrays.positions.shape == R0_nm.shape

# Fallback: a loose orthorhombic box = molecular bounding box + margin on each side.
lo = R0_nm.min(axis=0)
hi = R0_nm.max(axis=0)
extents = hi - lo
margin = 1.0  # nm
padded_box = np.maximum(extents + 2 * margin, 3.0)  # at least 3 nm to be safe

# Only trust the PDB unit cell if it can actually contain the molecule.
# Some PDBs (e.g. AlphaFold models) carry a placeholder CRYST1 of 1 Å. That cell
# may be carried into the protonated file, and a 0.1 nm box would make every atom
# overlap its own periodic images.
# NOTE(review): assumes box_lengths is in nm like positions — confirm in mdfs.topology.
box = padded_box
if top_arrays.box_lengths is not None:
    pdb_box = np.asarray(top_arrays.box_lengths, dtype=np.float64)
    if np.all(pdb_box > extents):
        box = pdb_box
    else:
        print(f"WARNING: ignoring PDB unit cell {pdb_box} nm (smaller than the molecule); using padded box.")

print("Box (nm):", box)


In [None]:
# -------------------------------------------------------------
# 6) Load Amber19 ff19SB parameters and map types/charges
# -------------------------------------------------------------
ff = load_ffxml(str(ffxml_path))
types_np, charges_np = assign_types_and_charges_from_templates(
    res_names=top_arrays.res_names,
    atom_names=top_arrays.atom_names,
    ff=ff,
)

# Sanity: ensure most atoms got typed (types==-1 means unmatched)
typed_fraction = (types_np >= 0).mean()
print(f"Typed atoms: {typed_fraction * 100:.1f}%")
if typed_fraction < 0.95:
    print(
        "WARNING: Some atoms not matched to residue templates. "
        "They will fall back to element-based types (if you coded that) and zero charge."
    )

# Even a handful of untyped atoms matters. A type index of -1 that is used
# directly to index a parameter table would silently select the *last* row,
# because NumPy/JAX wrap negative indices. List them whenever any are present.
# Atom order in top0 is assumed to match top_arrays (both are read from pdb_with_h).
untyped_idx = np.flatnonzero(types_np < 0)
if untyped_idx.size:
    shown = ", ".join(str(top0.atom(int(i))) for i in untyped_idx[:10])
    more = "" if untyped_idx.size <= 10 else f" (+{untyped_idx.size - 10} more)"
    print(f"Untyped atoms ({untyped_idx.size}): {shown}{more}")


In [None]:
# ------------------------------------------------------------------------
# 7) Build per-instance bonded arrays and 1–4/exclusion matrices
# ------------------------------------------------------------------------
# Per-term force constants / equilibrium values for every bond, angle and
# dihedral instance in the topology. torsion_tmax=4 presumably caps the number
# of periodic terms kept per dihedral — confirm in mdfs.params.
bond_params, angle_params, torsion_params = prepare_bonded_arrays_for_energy(
    bonds=top_arrays.bonds,
    angles=top_arrays.angles,
    dihedrals=top_arrays.dihedrals,
    types=types_np,
    ff=ff,
    torsion_tmax=4,
)
bond_k, bond_r0 = bond_params
ang_k, ang_th0 = angle_params
tor_n, tor_kn, tor_delta, tor_mask = torsion_params

# Exclusion mask plus the 1–4 van der Waals / electrostatic scale-factor matrices.
ex_mat, s14_vdw_mat, s14_elec_mat = prepare_nonbonded_scaling(
    n_atoms=N,
    bonds=top_arrays.bonds,
    dihedrals=top_arrays.dihedrals,
    ff=ff,
)

n_bonds, n_angles, n_torsions = len(bond_k), len(ang_k), tor_n.shape[0]
print("Bonds/Angles/Dihedrals:", n_bonds, n_angles, n_torsions)


In [None]:
# ------------------------------------------------------------------------
# 8) Assemble BondedSet / NonbondedSet (JAX arrays) for energy functions
# ------------------------------------------------------------------------

# Units follow the OpenMM FFXML convention: nm, kJ/mol, elementary charges (e).
# In those units Coulomb's constant 1/(4*pi*eps0) is ~138.935 kJ·nm/(mol·e^2).
# k_e = 1.0 would weaken electrostatics ~139x relative to the LJ terms.
# NOTE(review): assumes the DSF energy is k_e * q_i * q_j * f(r) with no internal
# unit conversion — confirm in mdfs.energy.
COULOMB_K_KJ_NM = 138.935456  # same value as OpenMM's ONE_4PI_EPS0

# Bonded
bonded = BondedSet(
    bonds=jnp.asarray(top_arrays.bonds),
    k_r=jnp.asarray(bond_k),
    r0=jnp.asarray(bond_r0),
    angles=jnp.asarray(top_arrays.angles),
    k_theta=jnp.asarray(ang_k),
    theta0=jnp.asarray(ang_th0),
    dihs=jnp.asarray(top_arrays.dihedrals),
    n=jnp.asarray(tor_n),
    k_n=jnp.asarray(tor_kn),
    delta=jnp.asarray(tor_delta),
    active_mask=jnp.asarray(tor_mask),
)

# Nonbonded template (pairs are injected at run-time by simulate)
lj_params = lj_to_jax(ff.lj)  # ensure JAX dtypes
nonbonded = NonbondedSet(
    pairs=jnp.zeros((0, 2), dtype=jnp.int32),
    types=jnp.asarray(types_np, dtype=jnp.int32),
    q=jnp.asarray(charges_np),
    lj_params=lj_params,
    r_cut_lj=1.2,  # nm (typical short-range cutoff)
    dsf=DSFParams(alpha=3.0, r_cut=1.2, k_e=COULOMB_K_KJ_NM),  # DSF electrostatics
    scale14_vdw=jnp.asarray(s14_vdw_mat),  # (N,N) matrices used inside energy with pairs
    scale14_elec=jnp.asarray(s14_elec_mat),
    exclude_mask=jnp.asarray(ex_mat, dtype=bool),
    shift_lj=True,  # energy-shifted LJ at cutoff (optional)
)

# Initial positions/velocities (JAX)
R0 = jnp.asarray(R0_nm)
V0 = jnp.zeros_like(R0)
box_jax = jnp.asarray(box)  # (3,)


In [None]:
# ---------------------------------------------------
# 9) Build a Langevin simulation and run a short traj
# ---------------------------------------------------
# Unit system: nm, ps, amu  =>  energy in kJ/mol (1 amu·nm²/ps² = 1 kJ/mol),
# which matches the FFXML parameters.
dt = 0.002  # ps
gamma = 1.0  # 1/ps
temperature = 300.0  # K
mass = 12.0  # amu, simple scalar mass per atom for teaching

# Boltzmann constant in kJ/(mol·K). With kB=1.0, kT at 300 K would be
# 300 kJ/mol instead of ~2.494 kJ/mol (~120x too hot).
KB_KJ_PER_MOL_K = 0.0083144626

r_cut_neighbor = 1.2  # nm
skin = 0.2  # nm

init_fn, step_fn = simulate_langevin(
    R0=R0,
    V0=V0,
    box=box_jax,
    bonded=bonded,
    nonbonded=nonbonded,
    dt=dt,
    mass=mass,
    gamma=gamma,
    temperature=temperature,
    r_cut_neighbor=r_cut_neighbor,
    skin=skin,
    kB=KB_KJ_PER_MOL_K,
)

# Collect a short trajectory for visualization/analysis
n_steps = 2000
report_every = 50

frames = []


def callback(step, state):
    # state.integ.R is (N,3) nm
    frames.append(np.array(state.integ.R))


state = init_fn(jax.random.PRNGKey(0))
state = run(
    init_fn,
    step_fn,
    n_steps=n_steps,
    key=jax.random.PRNGKey(1),
    state=state,
    report_interval=report_every,
    callback=callback,
)

# np.stack on an empty list raises an opaque error; fail with a clear message instead.
if not frames:
    raise RuntimeError(
        f"No frames were recorded (n_steps={n_steps}, report_every={report_every})."
    )
frames = np.stack(frames, axis=0)  # (T, N, 3) in nm
frames.shape


In [None]:
# ---------------------------------------------------
# 10) Visualize the trajectory with py3Dmol
# ---------------------------------------------------
# Convert to an MDTraj Trajectory for convenience (topology from traj0)
traj = md.Trajectory(frames, traj0.topology)  # frames in nm (MDTraj-native)

# Keep at most 30 evenly spaced frames so the animation stays light.
sel = np.linspace(0, len(traj) - 1, num=min(30, len(traj)), dtype=int)
traj_show = traj.slice(sel)

# mdtraj.Trajectory provides the file-based save_pdb (no to_pdb), so round-trip
# through a file. save_pdb writes one MODEL record per frame.
pdb_show = Path("./poly_A_short_frames.pdb")
traj_show.save_pdb(str(pdb_show))
pdb_str = pdb_show.read_text()

view = py3Dmol.view(width=700, height=480)
# addModelsAsFrames loads every MODEL as an animation frame. Plain addModel
# shows a single structure by default, leaving animate() nothing to play.
view.addModelsAsFrames(pdb_str, "pdb")
view.setStyle({"cartoon": {"color": "spectrum"}})
# Element filter goes directly in the selection; the "atom" key means atom name.
view.addStyle({"elem": "H"}, {"stick": {}})
view.zoomTo()
view.animate({"loop": "forward", "reps": 1, "step": 1, "interval": 60})
view.show()


In [None]:
# ---------------------------------------------------
# 11) Simple analysis with MDTraj: backbone RMSD
# ---------------------------------------------------
import matplotlib.pyplot as plt

# Cα-only sub-trajectory, aligned onto its first reported frame.
ca_indices = traj.topology.select("name CA")
traj_ca = traj.atom_slice(ca_indices)
traj_ca.superpose(traj_ca, 0)

rmsd = md.rmsd(traj_ca, traj_ca, 0)  # nm, one value per frame
# Successive reported frames are report_every steps of dt (ps) apart.
time_ps = np.arange(len(rmsd)) * report_every * dt

plt.figure()
plt.plot(time_ps, rmsd)
plt.xlabel("Time (ps)")
plt.ylabel("Cα RMSD (nm)")
plt.title("Backbone RMSD vs Time")
plt.show()


In [None]:
# ---------------------------------------------------
# 12) Save the short trajectory (optional)
# ---------------------------------------------------
xtc_out = Path(".") / "poly_A_short.xtc"
traj.save_xtc(str(xtc_out))  # all collected frames, written from nm coordinates
print("Saved trajectory:", xtc_out.resolve())
