In [11]:
import argparse
import sys
from pathlib import Path
import itertools
import json
import time

import numpy as np

from mmml.cli.base import (
    load_model_parameters,
    resolve_checkpoint_paths,
    setup_ase_imports,
    setup_mmml_imports,
)

In [18]:
from dataclasses import dataclass, asdict
from types import SimpleNamespace
from typing import Optional, Dict, Any
from pathlib import Path

@dataclass
class OptArgs:
    """Configuration for the cutoff grid search performed in this notebook.

    The field names and defaults are meant to mirror the command-line parser
    (see the "defaults match parser" notes below), so an instance can stand in
    for parsed CLI arguments.

    The path-like fields (``dataset``, ``pdbfile``, ``checkpoint``, ``out``,
    ``out_npz``) may be given either as ``Path`` objects or as plain strings;
    strings are converted to ``Path`` on construction so that later code can
    rely on ``Path`` methods such as ``.parent``.
    """

    # Required
    dataset: Path
    pdbfile: Path
    checkpoint: Path
    n_monomers: int
    n_atoms_monomer: int

    # Optimization controls (defaults match parser)
    # The three *_grid fields are comma-separated lists of candidate values.
    ml_cutoff_grid: str = "1.5,2.0,2.5,3.0"
    mm_switch_on_grid: str = "4.0,5.0,6.0,7.0"
    mm_cutoff_grid: str = "0.5,1.0,1.5,2.0"
    energy_weight: float = 1.0
    force_weight: float = 1.0
    max_frames: int = 200
    out: Optional[Path] = None
    out_npz: Optional[Path] = None
    validate: bool = False

    # MD simulation arguments (defaults match parser)
    energy_catch: float = 0.5
    cell: Optional[float] = None
    ml_cutoff: float = 2.0
    mm_switch_on: float = 5.0
    mm_cutoff: float = 1.0
    include_mm: bool = False
    skip_ml_dimers: bool = False
    debug: bool = False

    def __post_init__(self) -> None:
        """Convert path-like fields supplied as strings into ``Path`` objects."""
        for field_name in ("dataset", "pdbfile", "checkpoint", "out", "out_npz"):
            value = getattr(self, field_name)
            # Optional outputs stay None; existing Path instances are kept as-is.
            if value is not None and not isinstance(value, Path):
                setattr(self, field_name, Path(value))

    def as_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain ``{name: value}`` dictionary."""
        return asdict(self)

    def as_namespace(self) -> SimpleNamespace:
        """Return all fields as a ``SimpleNamespace`` (attribute access, like argparse)."""
        return SimpleNamespace(**self.as_dict())

# Concrete configuration for this run. Everything not listed here keeps the
# OptArgs default (grids, weights, max_frames, output paths, ...).
args = OptArgs(
    # input files
    dataset=Path("/pchem-data/meuwly/boittier/home/fixed-acetone-only_MP2_21000.npz"),
    pdbfile=Path("/pchem-data/meuwly/boittier/home/acetone.3/pdb/init-packmol.pdb"),
    checkpoint=Path("/pchem-data/meuwly/boittier/home/acetone.2/ACO-b4f39bb9-8ca7-485e-bf51-2e5236e51b56"),
    # system layout
    n_monomers=2,
    n_atoms_monomer=10,
    # cutoffs used for the initial calculator
    ml_cutoff=0.1,
    mm_switch_on=6.0,
    mm_cutoff=5.0,
    # calculator switches
    include_mm=True,
    skip_ml_dimers=False,
    debug=False,
)

# Attribute-style view of the same settings (equivalent to args.as_namespace()).
ns = SimpleNamespace(**asdict(args))

ns
        

namespace(dataset=PosixPath('/pchem-data/meuwly/boittier/home/fixed-acetone-only_MP2_21000.npz'),
          pdbfile=PosixPath('/pchem-data/meuwly/boittier/home/acetone.3/pdb/init-packmol.pdb'),
          checkpoint=PosixPath('/pchem-data/meuwly/boittier/home/acetone.2/ACO-b4f39bb9-8ca7-485e-bf51-2e5236e51b56'),
          n_monomers=2,
          n_atoms_monomer=10,
          ml_cutoff_grid='1.5,2.0,2.5,3.0',
          mm_switch_on_grid='4.0,5.0,6.0,7.0',
          mm_cutoff_grid='0.5,1.0,1.5,2.0',
          energy_weight=1.0,
          force_weight=1.0,
          max_frames=200,
          out=None,
          out_npz=None,
          validate=False,
          energy_catch=0.5,
          cell=None,
          ml_cutoff=0.1,
          mm_switch_on=6.0,
          mm_cutoff=5.0,
          include_mm=True,
          skip_ml_dimers=False,
          debug=False)

In [19]:
# Rebind `args` to the SimpleNamespace view built above; from here on `args`
# is no longer the OptArgs instance, so only attribute access is available.
args = ns

In [21]:
# Resolve the configured checkpoint into the two locations used later in this
# cell: `epoch_dir` is passed to load_model_parameters and `base_ckpt_dir` is
# passed to setup_calculator as `model_restart_path`.
base_ckpt_dir, epoch_dir = resolve_checkpoint_paths(args.checkpoint)
# Setup imports
# These helpers return the ASE / MMML objects as values rather than importing
# them at module level.
Atoms = setup_ase_imports()
CutoffParameters, ev2kcalmol, setup_calculator, get_ase_calc = setup_mmml_imports()

# Additional imports for this demo
# All heavy / optional dependencies are imported in one guarded block so a
# missing package produces a single readable message instead of a traceback.
# The import order is kept as in the original cell.
try:
    import pycharmm
    import ase
    import ase.calculators.calculator as ase_calc
    import ase.io as ase_io
    from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary, ZeroRotation
    from ase.md.verlet import VelocityVerlet
    import ase.optimize as ase_opt
    import matplotlib.pyplot as plt
    import py3Dmol
    from mmml.pycharmmInterface.import_pycharmm import coor
    from mmml.pycharmmInterface.setupBox import setup_box_generic
    import pandas as pd
    from mmml.pycharmmInterface.import_pycharmm import minimize
    import jax_md
    # JAX-MD imports
    from jax_md import space, smap, energy, quantity, simulate, partition, units
    from ase.units import _amu

    import jax.numpy as jnp

    import jax, e3x
    from jax import jit, grad, lax, ops, random
    from ase.io import Trajectory
except ModuleNotFoundError as exc:
    # Only a missing module is handled here; other import-time errors propagate.
    sys.exit(f"Required modules not available: {exc}")

# setup_box_generic and ase_io.read are both given the path as a plain string.
pdbfilename = str(args.pdbfile)

# Setup box and load PDB
# NOTE(review): side_length=1000 is a hard-coded value; presumably a box large
# enough to act as "no box" for a gas-phase dimer — confirm against setup_box_generic.
setup_box_generic(pdbfilename, side_length=1000)
pdb_ase_atoms = ase_io.read(pdbfilename)

print(f"Loaded PDB file: {pdb_ase_atoms}")
# The two literals are joined by implicit concatenation, so the first one must
# end with a space (the original printed "andthe").
print(
    "Note: for testing the dimer calculator, the pdb file should contain a dimer, and "
    "the atom types should be consistent with the dimer calculator."
)

# ========================================================================
# MASS SETUP FOR JAX-MD SIMULATION
# ========================================================================
# JAX-MD requires proper mass arrays for temperature calculation and dynamics
# NOTE(review): within this cell the mass arrays are only built and printed;
# confirm where (or whether) they are consumed later in the notebook.
# NOTE(review): the `Si_` prefix looks like a leftover from another example —
# the values are the masses of whatever atoms the PDB file contains.

# Get atomic masses from ASE (in atomic mass units)
raw_masses = pdb_ase_atoms.get_masses()
print(f"Raw masses from ASE: {raw_masses}")
Si_mass = jnp.array(raw_masses)  # Use ASE masses directly (in amu)
Si_mass_sum = Si_mass.sum()
print(f"Si_mass (ASE masses in amu): {Si_mass}")
print(f"Si_mass sum: {Si_mass_sum}")

# Expand mass array to match momentum dimensions for JAX-MD broadcasting
# Momentum has shape (n_atoms, 3), so mass must also have shape (n_atoms, 3)
Si_mass_expanded = jnp.repeat(Si_mass[:, None], 3, axis=1)  # Shape: (n_atoms, 3)
print(f"Si_mass_expanded shape: {Si_mass_expanded.shape}")
print(f"Si_mass_expanded sample: {Si_mass_expanded[0]}")
# Print the coordinates held by PyCHARMM next to the ones ASE read from the PDB.
print(f"PyCHARMM coordinates: {coor.get_positions()}")
print(f"Ase coordinates: {pdb_ase_atoms.get_positions()}")
# Element-wise `==` between the two coordinate sets.
# NOTE(review): this is an exact floating-point comparison; if it is meant as a
# sanity check, a tolerance-based comparison would be more robust.
print(f"{coor.get_positions() == pdb_ase_atoms.get_positions()}")

# NOTE(review): prints the PyCHARMM coordinates a second time (same call as above).
print(coor.get_positions())

# Load model parameters
natoms = len(pdb_ase_atoms)
n_monomers, n_atoms_monomer = args.n_monomers, args.n_atoms_monomer
# The atom count read from the PDB must equal n_monomers * n_atoms_monomer,
# otherwise the configuration does not describe the loaded structure.
assert natoms == n_monomers * n_atoms_monomer, "n_atoms_monomer * n_monomers != natoms"
params, model = load_model_parameters(epoch_dir, natoms)
# Overwrite the atom count stored on the model with the one of this system.
model.natoms = natoms
print(f"Model loaded: {model}")

# Atomic numbers and Cartesian positions of the loaded structure
Z = pdb_ase_atoms.get_atomic_numbers()
R = pdb_ase_atoms.get_positions()

# Switches that are passed identically to the factory constructor and to the
# factory call below: ML always on, MM and ML-dimer terms taken from `args`.
_shared_calc_flags = dict(
    doML=True,
    doMM=args.include_mm,
    doML_dimer=not args.skip_ml_dimers,
    debug=args.debug,
)

# Setup calculator factory
calculator_factory = setup_calculator(
    ATOMS_PER_MONOMER=args.n_atoms_monomer,
    N_MONOMERS=args.n_monomers,
    ml_cutoff_distance=args.ml_cutoff,
    mm_switch_on=args.mm_switch_on,
    mm_cutoff=args.mm_cutoff,
    **_shared_calc_flags,
    model_restart_path=base_ckpt_dir,
    MAX_ATOMS_PER_SYSTEM=natoms,
    ml_energy_conversion_factor=1,
    ml_force_conversion_factor=1,
    cell=args.cell,
)

# Cutoff triple for the initial calculator (taken straight from `args`).
CUTOFF_PARAMS = CutoffParameters(
    ml_cutoff=args.ml_cutoff,
    mm_switch_on=args.mm_switch_on,
    mm_cutoff=args.mm_cutoff,
)
print(f"Cutoff parameters: {CUTOFF_PARAMS}")

# Create hybrid calculator
# The factory returns a pair; only the first element (the calculator) is used.
# The PBC mapping options (do_pbc_map / pbc_map) are not passed here.
hybrid_calc, _ = calculator_factory(
    atomic_numbers=Z,
    atomic_positions=R,
    n_monomers=args.n_monomers,
    cutoff_params=CUTOFF_PARAMS,
    backprop=True,
    energy_conversion_factor=1,
    force_conversion_factor=1,
    **_shared_calc_flags,
)
print(f"Hybrid calculator created: {hybrid_calc}")

# `atoms` is an alias (not a copy) of the structure read from the PDB file.
atoms = pdb_ase_atoms




  
 CHARMM>     DELETE ATOM SELE ALL END
 SELRPN>     20 atoms have been selected out of     20

 Message from MAPIC: Atom numbers are changed.

 Message from MAPIC:          2 residues deleted.

 Message from MAPIC:          1 segments deleted.
 DELTIC:        18 bonds deleted
 DELTIC:        30 angles deleted
 DELTIC:        24 dihedrals deleted
 DELTIC:         2 improper dihedrals deleted
 DELTIC:         2 acceptors deleted
 PSFSUM> PSF modified: NONBOND lists and IMAGE atoms cleared.
 PSFSUM> Summary of the structure file counters :
         Number of segments      =        0   Number of residues   =        0
         Number of atoms         =        0   Number of groups     =        0
         Number of bonds         =        0   Number of angles     =        0
         Number of dihedrals     =        0   Number of impropers  =        0
         Number of cross-terms   =        0   Number of autogens   =        0
         Number of HB acceptors  =        0   Number of HB donors



dict_keys(['opt_state', 'params', 'step'])


Model loaded: EF(
    # attributes
    features = 24
    max_degree = 1
    num_iterations = 2
    num_basis_functions = 12
    cutoff = 5.0
    max_atomic_number = 28
    charges = False
    natoms = 20
    total_charge = 0
    n_res = 2
    zbl = False
    debug = False
    efa = False
)
0 1
unique_res_ids [0, 1]
len(dimer_perms) 1
dict_keys(['opt_state', 'params', 'step'])


Cutoff parameters: <mmml.pycharmmInterface.mmml_calculator.CutoffParameters object at 0x149e60af0f20>
Hybrid calculator created: <mmml.pycharmmInterface.mmml_calculator.setup_calculator.<locals>.AseDimerCalculator object at 0x149dee519bb0>


In [33]:
# Show the (ml_cutoff, mm_switch_on, mm_cutoff) triple stored on the calculator.
tuple(
    getattr(hybrid_calc.cutoff_params, field)
    for field in ("ml_cutoff", "mm_switch_on", "mm_cutoff")
)

(0.1, 6.0, 5.0)

In [37]:
# Optional periodic cell: when args.cell is given, build a box with three equal
# edge lengths and 90-degree angles, attach it to `atoms` and switch PBC on.
if args.cell is not None:
    print("Setting cell")
    from ase.cell import Cell
    print("Creating cell")
    box_length = float(args.cell)
    cell = Cell.fromcellpar([box_length, box_length, box_length, 90., 90., 90.])
    atoms.set_cell(cell)
    # Enable periodic boundary conditions
    atoms.set_pbc(True)
    # Array properties are read from an ndarray view of the cell so the
    # diagnostics do not depend on Cell itself exposing ndarray attributes.
    cell_array = np.asarray(cell)
    print(f"Cell: {cell}")
    print(f"PBC enabled: {atoms.pbc}")
    print(f"Cell shape: {cell_array.shape}")
    print(f"Cell type: {type(cell)}")
    print(f"Cell dtype: {cell_array.dtype}")
    print(f"Cell size: {cell_array.size}")
    print(f"Cell ndim: {cell_array.ndim}")
else:
    cell = None
    print("No cell provided")

print(f"ASE atoms: {atoms}")
atoms.calc = hybrid_calc
# Get initial energy and forces
hybrid_energy = float(atoms.get_potential_energy())
hybrid_forces = np.asarray(atoms.get_forces())
# NOTE(review): the label says eV, but the calculator was built with conversion
# factors of 1 — confirm the unit the model actually returns.
print(f"Initial energy: {hybrid_energy:.6f} eV")
print(f"Initial forces: {hybrid_forces}")


# load dataset
dataset = np.load(args.dataset)
print(f"Dataset: {dataset}")
print(f"Dataset keys: {dataset.keys()}")
R_all = dataset["R"]  # (n_frames, natoms, 3)
# Fall back to the atomic numbers read from the PDB when the archive has no "Z".
Z_ds = dataset.get("Z", Z)
if Z_ds.ndim > 1:
    # Per-frame atomic numbers: use the first frame's row for the whole dataset.
    Z_ds = np.array(Z_ds[0]).astype(int)
print(f"R shape: {R_all.shape}")
E_all = dataset.get("E", None)
F_all = dataset.get("F", None)
has_E = E_all is not None and np.size(E_all) > 0
has_F = F_all is not None and np.size(F_all) > 0
n_frames = R_all.shape[0]

print("arranging frames by CoM")
# Distance between the two monomer centres for every frame. The "centre" is the
# unweighted mean of the atomic positions (no masses are involved). The first
# n_atoms_monomer rows are taken as monomer 1 and all remaining rows as monomer 2.
_n_mono = args.n_atoms_monomer
com_distances = np.array(
    [
        np.linalg.norm(frame[:_n_mono].mean(axis=0) - frame[_n_mono:].mean(axis=0))
        for frame in R_all
    ]
)

# Flag frames whose atom count is that of a dimer. The "N" array is read once
# (every dataset["N"] lookup loads the array from the archive again). If the
# archive carries no "N" entry, all frames are treated as dimers.
if "N" in dataset:
    _atoms_per_frame = np.asarray(dataset["N"]).reshape(n_frames, -1)[:, 0]
    _is_dimer = _atoms_per_frame == 2 * _n_mono
else:
    _is_dimer = np.ones(n_frames, dtype=bool)
count_non_dimer = int(np.count_nonzero(~_is_dimer))

# Drop the non-dimer frames themselves (not simply the tail of the sorted
# order) and sort the remaining ones by monomer-monomer distance.
_dimer_idx = np.flatnonzero(_is_dimer)
if _dimer_idx.size == 0:
    raise ValueError("No dimer frames found in the dataset; nothing to evaluate.")
_sorted_dimers = _dimer_idx[np.argsort(com_distances[_dimer_idx])]

# A positive max_frames caps the number of evaluated frames; any other value
# (None, 0, -1, ...) means "use every dimer frame".
if args.max_frames is not None and args.max_frames > 0:
    n_eval = min(int(_sorted_dimers.size), args.max_frames)
else:
    n_eval = int(_sorted_dimers.size)

# Pick exactly n_eval distinct frames spread evenly from the shortest to the
# longest distance. Integer arithmetic keeps the picks unique and avoids a
# zero stride when n_eval equals the number of available frames.
_picks = (np.arange(n_eval) * (_sorted_dimers.size - 1)) // max(n_eval - 1, 1)
frame_indices = _sorted_dimers[_picks]
print(f"Evaluating {n_eval} frames (out of {n_frames}). E available: {has_E}, F available: {has_F}")

# Utility to parse grids
def _parse_grid(s: str) -> list[float]:
    return [float(x) for x in s.split(",") if x.strip() != ""]

# Parse the three comma-separated grid specifications from `args` in one pass.
ml_grid, mm_on_grid, mm_cut_grid = (
    _parse_grid(spec)
    for spec in (args.ml_cutoff_grid, args.mm_switch_on_grid, args.mm_cutoff_grid)
)
print(f"Grid sizes -> ml:{len(ml_grid)} mm_on:{len(mm_on_grid)} mm_cut:{len(mm_cut_grid)}")


No cell provided
ASE atoms: Atoms(symbols='OC3H6OC3H6', pbc=False, atomtypes=..., bfactor=..., occupancy=..., residuenames=..., residuenumbers=..., calculator=AseDimerCalculator(...))
Initial energy: -73.381851 eV
Initial forces: [[ 3.6557930e+00  1.0154342e+01 -7.0488071e-01]
 [-1.0048665e+00 -3.9248860e+00  7.3255281e+00]
 [-4.1517158e+00 -4.4469900e+00 -4.7821269e+00]
 [-9.8661852e-01 -2.8884964e+00  1.5761465e-01]
 [-5.9345388e-01  3.7028468e-01 -9.0858644e-01]
 [-2.3831147e-01 -2.9437959e-02 -3.5048112e-02]
 [ 1.3743185e+00  7.2051257e-01 -6.2255359e-01]
 [ 1.0860609e+00  8.6109841e-01 -3.9593297e-01]
 [ 6.2033278e-01 -3.4821886e-01 -8.3028078e-03]
 [ 2.3846097e-01 -4.6820870e-01 -2.5710434e-02]
 [-8.3674831e+00 -9.9758892e+00  2.2805994e+00]
 [-1.4439833e-01  7.4412069e+00  1.5081589e-01]
 [ 6.0849085e+00  7.1063310e-01 -8.6344406e-02]
 [ 1.4873812e+00  1.8369488e+00 -3.4228727e-01]
 [ 4.8169678e-01  6.3304883e-01  5.3847957e-01]
 [ 7.5287640e-01 -4.8852590e-01 -1.4228216e+00]
 [

In [38]:

# Objective evaluation for a given cutoff triple
def evaluate_objective(atoms, ml_cutoff: float, mm_switch_on: float, mm_cutoff: float) -> dict:
    """Score one (ml_cutoff, mm_switch_on, mm_cutoff) triple against the reference data.

    A new calculator is built with ``calculator_factory`` for the given cutoffs
    and attached to ``atoms``. For every index in the notebook-level
    ``frame_indices`` the positions are set to ``R_all[i]`` and the predicted
    energy/forces are compared with ``E_all[i]`` / ``F_all[i]`` (whichever of the
    two is available).

    Side effects: ``atoms`` is modified in place — after the call it carries the
    newly built calculator and the positions of the last evaluated frame.

    Args:
        atoms: ASE ``Atoms`` object used as the evaluation scratch structure.
        ml_cutoff: value passed to ``CutoffParameters`` as ``ml_cutoff``.
        mm_switch_on: value passed to ``CutoffParameters`` as ``mm_switch_on``.
        mm_cutoff: value passed to ``CutoffParameters`` as ``mm_cutoff``.

    Returns:
        dict with keys ``ml_cutoff``, ``mm_switch_on``, ``mm_cutoff``,
        ``mse_energy``, ``mse_forces`` and ``objective``, where
        ``objective = energy_weight * mse_energy + force_weight * mse_forces``.
        An MSE whose reference data is unavailable is reported as 0.0.

    Raises:
        ValueError: if no frames are selected or the dataset provides neither
            reference energies nor reference forces. In both cases the objective
            would be a meaningless 0.0 for every grid point.
    """
    n_selected = len(frame_indices)
    if n_selected == 0:
        raise ValueError("evaluate_objective: frame_indices is empty, nothing to evaluate")
    if not (has_E or has_F):
        raise ValueError("evaluate_objective: dataset has neither reference energies nor forces")

    # atoms = atoms.copy()
    local_params = CutoffParameters(
        ml_cutoff=ml_cutoff,
        mm_switch_on=mm_switch_on,
        mm_cutoff=mm_cutoff,
    )
    # Rebuild calculator with new cutoffs (local name, so the notebook-level
    # `hybrid_calc` is not shadowed inside this function).
    trial_calc, _ = calculator_factory(
        atomic_numbers=Z_ds,
        atomic_positions=R_all[0],
        n_monomers=args.n_monomers,
        cutoff_params=local_params,
        doML=True,
        doMM=args.include_mm,
        doML_dimer=not args.skip_ml_dimers,
        backprop=True,
        debug=args.debug,
        energy_conversion_factor=1,
        force_conversion_factor=1,
    )
    # Echo the cutoffs the calculator actually carries, as a sanity check.
    print(
        trial_calc.cutoff_params.ml_cutoff,
        trial_calc.cutoff_params.mm_switch_on,
        trial_calc.cutoff_params.mm_cutoff
    )
    atoms.calc = trial_calc

    se_e = 0.0  # running sum of squared energy errors
    se_f = 0.0  # running sum of per-frame mean squared force errors
    for i in frame_indices:
        atoms.positions = R_all[i]
        pred_E = float(atoms.get_potential_energy())
        pred_F = np.asarray(atoms.get_forces())
        if has_E:
            # squeeze() so that a length-1 array entry converts cleanly to float
            ref_E = float(np.squeeze(E_all[i]))
            se_e += (pred_E - ref_E) ** 2
        if has_F:
            ref_F = np.asarray(F_all[i])
            se_f += float(np.mean((pred_F - ref_F) ** 2))

    mse_e = (se_e / n_selected) if has_E else 0.0
    mse_f = (se_f / n_selected) if has_F else 0.0
    obj = args.energy_weight * mse_e + args.force_weight * mse_f
    _out_dict = {
        "ml_cutoff": ml_cutoff,
        "mm_switch_on": mm_switch_on,
        "mm_cutoff": mm_cutoff,
        "mse_energy": mse_e,
        "mse_forces": mse_f,
        "objective": obj,
    }
    print(f"Objective: {_out_dict}")
    return _out_dict

In [40]:
# Inspect the Atoms object that is passed to evaluate_objective in the grid
# search; its repr also shows the currently attached calculator.
atoms

Atoms(symbols='OC3H6OC3H6', pbc=False, atomtypes=..., bfactor=..., occupancy=..., residuenames=..., residuenumbers=..., calculator=AseDimerCalculator(...))

In [None]:


# Grid search
# Evaluate every combination of the three cutoff grids and keep the one with
# the lowest objective. An empty grid would produce no combinations at all and
# leave `best` undefined, so it is rejected up front.
if not (ml_grid and mm_on_grid and mm_cut_grid):
    raise ValueError(
        "Grid search needs at least one value in each of ml_cutoff_grid, "
        "mm_switch_on_grid and mm_cutoff_grid"
    )

# perf_counter is monotonic, so the elapsed time cannot be distorted by
# wall-clock adjustments during a long search.
start = time.perf_counter()
best = None
results = []
for ml_c, mm_on, mm_c in itertools.product(ml_grid, mm_on_grid, mm_cut_grid):
    res = evaluate_objective(atoms, ml_c, mm_on, mm_c)
    results.append(res)
    # Strict "<" keeps the first combination when objectives tie.
    if best is None or res["objective"] < best["objective"]:
        best = res
    print(
        f"ml={ml_c:.3f} mm_on={mm_on:.3f} mm_cut={mm_c:.3f} -> obj={res['objective']:.6e} (E={res['mse_energy']:.6e}, F={res['mse_forces']:.6e})"
    )
elapsed = time.perf_counter() - start
print(f"Grid search completed in {elapsed:.1f}s over {len(results)} combos.")
print(f"Best: {best}")

if args.out is not None:
    # JSON summary: best combination, all results and the weights used.
    payload = {
        "best": best,
        "results": results,
        "n_eval_frames": int(n_eval),
        "energy_weight": args.energy_weight,
        "force_weight": args.force_weight,
    }
    out_path = Path(args.out)  # accept a plain string as well as a Path
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(payload, f, indent=2)
    print(f"Saved results to {args.out}")

if args.out_npz is not None:
    # Save detailed results as NPZ
    # Maps each per-combination array name in the archive to the key of the
    # corresponding entry in the result dicts.
    column_names = {
        "ml_cutoffs": "ml_cutoff",
        "mm_switch_ons": "mm_switch_on",
        "mm_cutoffs": "mm_cutoff",
        "mse_energies": "mse_energy",
        "mse_forces": "mse_forces",
        "objectives": "objective",
    }
    npz_data = {
        array_name: np.array([r[key] for r in results])
        for array_name, key in column_names.items()
    }
    # Scalars of the best combination are stored as "best_<key>".
    npz_data.update({f"best_{key}": best[key] for key in column_names.values()})
    npz_data.update(
        {
            "n_eval_frames": n_eval,
            "energy_weight": args.energy_weight,
            "force_weight": args.force_weight,
        }
    )
    npz_path = Path(args.out_npz)  # accept a plain string as well as a Path
    npz_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez(npz_path, **npz_data)
    # np.savez appends ".npz" when the file name does not already end with it;
    # report the name that was actually written.
    if not npz_path.name.endswith(".npz"):
        npz_path = npz_path.with_name(npz_path.name + ".npz")
    print(f"Saved detailed results to {npz_path}")

1.5 4.0 0.5
  
 CHARMM>     BLOCK
  
  BLOCK>            CALL 1 SELE ALL END
 SELRPN>     20 atoms have been selected out of     20
 The selected atoms have been reassigned to block   1
  
  BLOCK>              COEFF 1 1 1.0
  
  BLOCK>            END
 Matrix of Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of BOND Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of ANGLE Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of DIHE Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of CROSS Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of ELEC Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of VDW Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000

In [None]:
# Display the full list of per-combination result dicts collected by the grid search.
results