## Outline

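This notebook walks through setting up the hybrid ML/MM simulations: loading the reference data and the trained ML model, building the hybrid calculator, and fitting the Lennard-Jones (LJ) terms of the MM part.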

## Fitting the Lennard-Jones terms

In [1]:
import os

# GPU settings for JAX. They only take effect if they are set before JAX
# creates its backend, so set them before importing jax and before importing
# mmml (which may import jax itself).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".99"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import mmml
import ase
from pathlib import Path
import argparse
import sys
import numpy as np
import jax
import jax.numpy as jnp

# Check JAX configuration
devices = jax.local_devices()
print(devices)
print(jax.default_backend())
print(jax.devices())


[CudaDevice(id=0)]
gpu
[CudaDevice(id=0)]


# Setup: Mock CLI Arguments (following run_sim.py structure)

The next two cells import the modules used by `run_sim.py` and create a mock `args` object that mimics its CLI arguments.
This lets the notebook follow the same structure as the script.

In [2]:
# Import required modules (following run_sim.py structure)
from mmml.cli.base import (
    load_model_parameters,
    resolve_checkpoint_paths,
    setup_ase_imports,
    setup_mmml_imports,
)
# NOTE(review): the CHARMM/BLOCK log in this cell's output is presumably
# emitted while the pycharmm-related modules below are imported — confirm.
from mmml.pycharmmInterface import import_pycharmm
import pycharmm
import pycharmm.ic as ic
import pycharmm.psf as psf
import pycharmm.energy as energy
from mmml.pycharmmInterface.mmml_calculator import setup_calculator, CutoffParameters
from mmml.physnetjax.physnetjax.data.data import prepare_datasets
from mmml.physnetjax.physnetjax.data.batches import prepare_batches_jit
from mmml.pycharmmInterface.setupBox import setup_box_generic
from mmml.pycharmmInterface import setupRes, setupBox
from mmml.pycharmmInterface.import_pycharmm import reset_block, coor
from mmml.pycharmmInterface.pycharmmCommands import CLEAR_CHARMM

# Setup ASE imports
Atoms = setup_ase_imports()
# NOTE(review): this rebinds CutoffParameters and setup_calculator, replacing
# the names imported directly from mmml_calculator above; the later binding
# wins. Presumably they are the same objects — confirm.
CutoffParameters, ev2kcalmol, setup_calculator, get_ase_calc = setup_mmml_imports()

# Additional imports for simulation
import ase.io as ase_io
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary, ZeroRotation
from ase.md.verlet import VelocityVerlet
import ase.optimize as ase_opt

/scicore/home/meuwly/boitti0000/mmml/mmml/data/top_all36_cgenff.rtf
/scicore/home/meuwly/boitti0000/mmml/mmml/data/par_all36_cgenff.prm
CHARMM_HOME /scicore/home/meuwly/boitti0000/mmml/setup/charmm
CHARMM_LIB_DIR /scicore/home/meuwly/boitti0000/mmml/setup/charmm
  
 CHARMM>     BLOCK
 Block structure initialized with   3 blocks.
 All atoms have been assigned to block 1.
 All interaction coefficients have been set to unity.
  Setting number of block exclusions nblock_excldPairs=0
  
  BLOCK>            CALL 1 SELE ALL END
 SELRPN>      0 atoms have been selected out of      0
 The selected atoms have been reassigned to block   1
  
  BLOCK>              COEFF 1 1 1.0
  
  BLOCK>            END
 Matrix of Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of BOND Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of ANGLE Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.0

In [3]:
# ========================================================================
# MOCK CLI ARGUMENTS (spoofing run_sim.py CLI)
# ========================================================================
# Create a mock args object that mimics the CLI arguments from run_sim.py
# This allows the notebook to follow the same structure as the script

class MockArgs:
    """Mock CLI arguments following run_sim.py structure.

    Every attribute starts from the defaults set in ``__init__``. Any of them
    can be overridden with keyword arguments, e.g.
    ``MockArgs(temperature=300.0, n_monomers=4)``.

    Raises:
        TypeError: if an override names an attribute that has no default.
            This catches typos instead of silently adding a new field.
    """

    def __init__(self, **overrides):
        # Paths
        self.pdbfile = None  # Will be created from valid_data if needed
        # RESTART is defined in a later cell, so on a fresh kernel this is None;
        # the checkpoint cell then falls back to RESTART directly.
        self.checkpoint = Path(RESTART) if 'RESTART' in globals() else None

        # System parameters
        self.n_monomers = 2
        self.n_atoms_monomer = 10
        self.atoms_per_monomer = 10  # Alias for compatibility

        # Calculator parameters
        self.ml_cutoff = 2.0
        self.mm_switch_on = 4.0
        self.mm_cutoff = 1.0
        self.include_mm = True
        self.skip_ml_dimers = False
        self.debug = False

        # MD simulation parameters
        self.temperature = 210.0
        self.timestep = 0.1
        self.nsteps_jaxmd = 100_000
        self.nsteps_ase = 10000
        self.ensemble = "nvt"
        self.heating_interval = 500
        self.write_interval = 100
        self.energy_catch = 0.5

        # Output
        self.output_prefix = "md_simulation"
        self.cell = None  # No PBC by default

        # Validation
        self.validate = False

        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"MockArgs got an unexpected argument: {name!r}")
            setattr(self, name, value)
        # Keep the compatibility alias in sync when only one of the two is given.
        if "n_atoms_monomer" in overrides and "atoms_per_monomer" not in overrides:
            self.atoms_per_monomer = self.n_atoms_monomer
        elif "atoms_per_monomer" in overrides and "n_atoms_monomer" not in overrides:
            self.n_atoms_monomer = self.atoms_per_monomer

# Create mock args object
args = MockArgs()

# Override with notebook-specific values if needed
if 'ATOMS_PER_MONOMER' in globals():
    args.n_atoms_monomer = ATOMS_PER_MONOMER
    args.atoms_per_monomer = ATOMS_PER_MONOMER
if 'N_MONOMERS' in globals():
    args.n_monomers = N_MONOMERS

print("Mock args created:")
print(f"  n_monomers: {args.n_monomers}")
print(f"  n_atoms_monomer: {args.n_atoms_monomer}")
print(f"  ml_cutoff: {args.ml_cutoff}")
print(f"  mm_switch_on: {args.mm_switch_on}")
print(f"  mm_cutoff: {args.mm_cutoff}")

Mock args created:
  n_monomers: 2
  n_atoms_monomer: 10
  ml_cutoff: 2.0
  mm_switch_on: 4.0
  mm_cutoff: 1.0


In [4]:
# Notebook-wide system size, taken from the mock CLI args.
N_MONOMERS = args.n_monomers
ATOMS_PER_MONOMER = args.n_atoms_monomer

# Load Data and Prepare Batches (following run_sim.py structure)

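This cell loads the training and validation data and prepares batches of one structure each from both. The validation batches are used to initialize simulations; the training batches are available for fitting the LJ parameters.
Note: The residue numbers in the PDB/PSF may need to be adjusted based on the actual system.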

In [5]:
# ========================================================================
# LOAD DATA AND PREPARE BATCHES (following run_sim.py structure)
# ========================================================================

# Reuse an existing key on re-runs, otherwise start from a fixed seed.
if 'data_key' not in globals():
    data_key = jax.random.PRNGKey(42)

# Sizes of the two splits.
N_TRAIN_SAMPLES = 10500
N_VALID_SAMPLES = 10500
n_atoms_total = ATOMS_PER_MONOMER * N_MONOMERS

# SCICORE is normally defined in the checkpoint cell further down, so on a
# fresh kernel the explicit fallback path is used (same file with the default SCICORE).
data_file = (
    SCICORE / "mmml/mmml/data/fixed-acetone-only_MP2_21000.npz"
    if 'SCICORE' in globals()
    else Path("/scicore/home/meuwly/boitti0000/mmml/mmml/data/fixed-acetone-only_MP2_21000.npz")
)

print(f"Loading data from: {data_file}")

train_data, valid_data = prepare_datasets(
    data_key,
    N_TRAIN_SAMPLES,
    N_VALID_SAMPLES,
    [data_file],
    natoms=n_atoms_total,
)

# Batches of a single structure each.
# NOTE(review): the same data_key is used for the split and both shuffles;
# JAX recommends splitting keys instead of reusing them.
valid_batches = prepare_batches_jit(data_key, valid_data, 1, num_atoms=n_atoms_total)
train_batches = prepare_batches_jit(data_key, train_data, 1, num_atoms=n_atoms_total)

n_valid_samples = len(valid_data['R'])
n_valid_batches = len(valid_batches)
atoms_in_first_batch = len(valid_batches[0]['R'])
print(f"Loaded {n_valid_samples} validation samples")
print(f"Prepared {n_valid_batches} validation batches")
print(f"Each batch contains {atoms_in_first_batch} atoms")

Loading data from: /scicore/home/meuwly/boitti0000/mmml/mmml/data/fixed-acetone-only_MP2_21000.npz
dataR (21000, 20, 3)
dataE [-81.79712432 -81.48244884 -81.38548297 -81.44645775 -81.74704898
 -81.67295344 -81.32876002 -81.82201676 -81.8124061  -81.80508929]
dataE [-81.79712432 -81.48244884 -81.38548297 -81.44645775 -81.74704898
 -81.67295344 -81.32876002 -81.82201676 -81.8124061  -81.80508929]
D (21000, 3)
Q 1 (21000,) 21000
Q (21000,)
Loaded 10500 validation samples
Prepared 10500 validation batches
Each batch contains 20 atoms


In [6]:
# Additional utility imports (if needed)
from ase.visualize.plot import plot_atoms

In [7]:
# Additional PyCHARMM imports (already imported in cell 3, but kept for reference)
from mmml.pycharmmInterface import setupRes, setupBox

In [12]:
# ========================================================================
# LOAD MODEL AND SETUP CALCULATOR (following run_sim.py structure)
# ========================================================================
uid = "test-84aa02d9-e329-46c4-b12c-f55e6c9a2f94"
# Root of the cluster home directory. Set MMML_SCICORE_ROOT to run elsewhere;
# the default keeps the original location.
SCICORE = Path(os.environ.get("MMML_SCICORE_ROOT", "/scicore/home/meuwly/boitti0000/"))
# Epoch directory holding the JSON export of the trained model.
RESTART = str(SCICORE / "ckpts" / uid / "epoch-5450" / "json_checkpoint")
# ========================================================================
# JSON-BASED CHECKPOINT LOADER (no orbax/pickle required)
# ========================================================================
def load_model_parameters_json(epoch_dir, natoms, use_orbax=False):
    """
    Load model parameters from checkpoint using JSON (no orbax/pickle required).
    
    This function tries to load checkpoints from JSON files first, then falls back
    to pickle if needed. JSON is preferred for portability.
    
    Args:
        epoch_dir: Path to checkpoint epoch directory
        natoms: Number of atoms
        use_orbax: If True, try orbax first (default: False)
    
    Returns:
        params, model: Model parameters and model instance
    """
    from mmml.physnetjax.physnetjax.models.model import EF
    import json
    import pickle
    
    epoch_dir = Path(epoch_dir)
    
    # Try orbax first if requested
    if use_orbax:
        try:
            from mmml.physnetjax.physnetjax.restart.restart import get_params_model
            params, model = get_params_model(str(epoch_dir), natoms=natoms)
            if model is not None:
                print("✓ Loaded checkpoint using orbax")
                return params, model
        except Exception as e:
            print(f"Warning: orbax loading failed: {e}")
            print("Falling back to JSON/pickle-based loading...")
    
    # Helper function to convert JSON-serialized arrays back to JAX arrays
    def json_to_jax(obj):
        """Recursively convert JSON lists to JAX arrays."""
        if isinstance(obj, dict):
            return {k: json_to_jax(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            # Check if it's a nested list (array)
            if len(obj) > 0 and isinstance(obj[0], (list, int, float)):
                arr = jnp.array(obj)
                return arr
            else:
                return [json_to_jax(item) for item in obj]
        elif isinstance(obj, (int, float)):
            return obj
        else:
            return obj
    
    # Try JSON-based loading first (preferred)
    json_candidates = [
        epoch_dir / "params.json",
        epoch_dir / "best_params.json",
        epoch_dir / "checkpoint.json",
        epoch_dir / "final_params.json",
    ]
    
    params = None
    params_source = None
    
    # Try JSON files first
    for json_path in json_candidates:
        if json_path.exists():
            print(f"Loading parameters from JSON: {json_path}")
            try:
                with open(json_path, 'r') as f:
                    checkpoint_data = json.load(f)
                
                # Extract params
                if isinstance(checkpoint_data, dict):
                    params_data = checkpoint_data.get('params') or checkpoint_data.get('ema_params') or checkpoint_data
                else:
                    params_data = checkpoint_data
                
                # Convert JSON arrays back to JAX arrays
                params = json_to_jax(params_data)
                params_source = "json"
                break
            except Exception as e:
                print(f"  Failed to load from {json_path}: {e}")
                continue
    
    # Fall back to pickle if JSON not found
    if params is None:
        pickle_candidates = [
            epoch_dir / "params.pkl",
            epoch_dir / "best_params.pkl",
            epoch_dir / "checkpoint.pkl",
            epoch_dir / "final_params.pkl",
        ]
        
        for pkl_path in pickle_candidates:
            if pkl_path.exists():
                print(f"Loading parameters from pickle: {pkl_path}")
                with open(pkl_path, 'rb') as f:
                    checkpoint_data = pickle.load(f)
                
                # Extract params
                if isinstance(checkpoint_data, dict):
                    params = checkpoint_data.get('params') or checkpoint_data.get('ema_params') or checkpoint_data
                else:
                    params = checkpoint_data
                params_source = "pickle"
                break
    
    if params is None:
        all_candidates = [str(p) for p in json_candidates + [
            epoch_dir / "params.pkl",
            epoch_dir / "best_params.pkl",
            epoch_dir / "checkpoint.pkl",
            epoch_dir / "final_params.pkl",
        ]]
        raise FileNotFoundError(
            f"Could not find parameters in {epoch_dir}.\n"
            f"Tried JSON: {[str(p) for p in json_candidates]}\n"
            f"Tried pickle: {[str(p) for p in pickle_candidates if p.exists()]}\n"
            f"Please ensure checkpoint files exist."
        )
    
    # Load model config (prefer JSON)
    config_candidates = [
        epoch_dir / "model_config.json",
        epoch_dir.parent / "model_config.json",
        epoch_dir / "model_config.pkl",
        epoch_dir.parent / "model_config.pkl",
    ]
    
    model_kwargs = {}
    for config_path in config_candidates:
        if config_path.exists():
            print(f"Loading model config from: {config_path}")
            try:
                if config_path.suffix == '.json':
                    with open(config_path, 'r') as f:
                        model_kwargs = json.load(f)
                else:
                    with open(config_path, 'rb') as f:
                        model_kwargs = pickle.load(f)
                break
            except Exception as e:
                print(f"  Warning: Failed to load config from {config_path}: {e}")
                continue
    
    # If no config found, try to extract from checkpoint directory structure
    if not model_kwargs:
        print("Warning: No model config found, using defaults")
        # Try to infer from directory name or use defaults
        model_kwargs = {
            'features': 64,
            'cutoff': 8.0,
            'max_degree': 2,
            'num_iterations': 3,
        }
    
    # Set natoms
    model_kwargs['natoms'] = natoms
    
    # Create model
    model = EF(**model_kwargs)
    model.natoms = natoms
    
    print(f"✓ Loaded checkpoint using {params_source} (no orbax required)")
    print(f"  Model: {model}")
    
    return params, model

# Resolve checkpoint paths
if args.checkpoint is not None:
    base_ckpt_dir, epoch_dir = resolve_checkpoint_paths(args.checkpoint)
    print(f"Checkpoint base dir: {base_ckpt_dir}")
    print(f"Checkpoint epoch dir: {epoch_dir}")
elif 'RESTART' in globals():
    # Fallback: use RESTART as both the base and the epoch directory.
    base_ckpt_dir = Path(RESTART)
    epoch_dir = base_ckpt_dir
else:
    raise ValueError("Checkpoint path must be provided via args.checkpoint or RESTART variable")

# Total number of atoms in the simulated system (all monomers).
natoms = ATOMS_PER_MONOMER * N_MONOMERS

# Prefer the JSON/pickle loader (no orbax needed); fall back to the
# orbax-based loader from mmml.cli.base only if that fails.
try:
    params, model = load_model_parameters_json(epoch_dir, natoms, use_orbax=False)
    print(f"Model loaded using JSON/pickle: {model}")
except Exception as e:
    print(f"JSON/pickle-based loading failed: {e}")
    print("Trying orbax-based loading (requires GPU environment)...")
    try:
        params, model = load_model_parameters(epoch_dir, natoms)
        model.natoms = natoms
        print(f"Model loaded using orbax: {model}")
    except Exception as e2:
        # Chain the orbax error so its full traceback is kept with the summary.
        raise RuntimeError(
            f"Failed to load model with all methods:\n"
            f"  JSON/pickle: {e}\n"
            f"  Orbax: {e2}\n"
            f"Make sure checkpoint files exist in {epoch_dir}\n"
            f"Preferred format: JSON files (params.json, model_config.json)"
        ) from e2

# Setup calculator factory (following run_sim.py)
calculator_factory = setup_calculator(
    ATOMS_PER_MONOMER=args.n_atoms_monomer,
    N_MONOMERS=args.n_monomers,
    ml_cutoff_distance=args.ml_cutoff,
    mm_switch_on=args.mm_switch_on,
    mm_cutoff=args.mm_cutoff,
    doML=True,
    doMM=args.include_mm,
    doML_dimer=not args.skip_ml_dimers,
    debug=args.debug,
    model_restart_path=base_ckpt_dir,
    MAX_ATOMS_PER_SYSTEM=natoms,
    ml_energy_conversion_factor=1,
    ml_force_conversion_factor=1,
    cell=args.cell,
)

# Create cutoff parameters
CUTOFF_PARAMS = CutoffParameters(
    ml_cutoff=args.ml_cutoff,
    mm_switch_on=args.mm_switch_on,
    mm_cutoff=args.mm_cutoff,
)
print(f"Cutoff parameters: {CUTOFF_PARAMS}")

Loading parameters from JSON: /scicore/home/meuwly/boitti0000/ckpts/test-84aa02d9-e329-46c4-b12c-f55e6c9a2f94/epoch-5450/json_checkpoint/params.json
Loading model config from: /scicore/home/meuwly/boitti0000/ckpts/test-84aa02d9-e329-46c4-b12c-f55e6c9a2f94/epoch-5450/json_checkpoint/model_config.json
✓ Loaded checkpoint using json (no orbax required)
  Model: EF(
    # attributes
    features = 32
    max_degree = 1
    num_iterations = 2
    num_basis_functions = 32
    cutoff = 8.0
    max_atomic_number = 40
    charges = True
    natoms = 20
    total_charge = 0
    n_res = 4
    zbl = False
    debug = False
    efa = False
    use_energy_bias = True
)
Model loaded using JSON/pickle: EF(
    # attributes
    features = 32
    max_degree = 1
    num_iterations = 2
    num_basis_functions = 32
    cutoff = 8.0
    max_atomic_number = 40
    charges = True
    natoms = 20
    total_charge = 0
    n_res = 4
    zbl = False
    debug = False
    efa = False
    use_energy_bias = True
)
[

## Fit Lennard-Jones Parameters to Training Data

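Before running simulations, we can optimize the LJ parameters to better match the training dataset. These are per-atom-type scaling factors for the well depth $\epsilon$ and the minimum-energy distance $R_\mathrm{min}$, called `ep_scale` and `sig_scale` in the code. Only the MM part of the hybrid potential is fitted; the ML contribution is left unchanged.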


In [15]:
# ========================================================================
# FIT LENNARD-JONES PARAMETERS TO TRAINING DATA (JAX-native approach)
# ========================================================================
# This creates a JAX-differentiable loss function for optimizing ep_scale and sig_scale
# The loss is computed using JAX, making it fully differentiable

def extract_lj_parameters_from_calculator():
    """
    Extract base LJ parameters and indices from the calculator setup.
    
    This should be called once after calculator_factory is created to extract
    the base parameters that will be scaled during optimization.
    
    Returns:
        dict with keys:
            atc_epsilons: Base epsilon values for each atom type
            atc_rmins: Base rmin values for each atom type
            atc_qs: Charges for each atom type
            at_codes: Atom type codes for each atom in the system
            pair_idx_atom_atom: Pair indices for atom-atom interactions
    """
    import pycharmm.param as param
    from mmml.pycharmmInterface.import_pycharmm import psf
    from itertools import product
    
    # Get atom type codes
    atc = param.get_atc()
    at_codes = np.array(psf.get_iac())[:N_MONOMERS * ATOMS_PER_MONOMER]
    
    # Load CGENFF parameters (this should match what's in setup_calculator)
    from mmml.pycharmmInterface.import_pycharmm import (
        CGENFF_RTF, CGENFF_PRM, read, settings
    )
    from mmml.pycharmmInterface.import_pycharmm import reset_block
    reset_block()
    read.rtf(CGENFF_RTF)
    bl = settings.set_bomb_level(-2)
    wl = settings.set_warn_level(-2)
    read.prm(CGENFF_PRM)
    settings.set_bomb_level(bl)
    settings.set_warn_level(wl)
    
    # Extract parameters from parameter file
    cgenff_rtf = open(CGENFF_RTF).readlines()
    atc = param.get_atc()
    cgenff_params_dict_q = {}
    for _ in cgenff_rtf:
        if _.startswith("ATOM"):
            _, atomname, at, q = _.split()[:4]
            try:
                cgenff_params_dict_q[at] = float(q)
            except:
                cgenff_params_dict_q[at] = float(q.split("!")[0])
    
    cgenff_params_dict = {}
    for p in open(CGENFF_PRM).readlines():
        if len(p) > 5 and len(p.split()) > 4 and p.split()[1] == "0.0" and p[0] != "!":
            res, _, ep, sig = p.split()[:4]
            cgenff_params_dict[res] = (float(ep), float(sig))
    
    # Extract base parameters
    atc_epsilons = np.array([cgenff_params_dict.get(_, (0.0, 0.0))[0] for _ in atc])
    atc_rmins = np.array([cgenff_params_dict.get(_, (0.0, 0.0))[1] for _ in atc])
    atc_qs = np.array([cgenff_params_dict_q.get(_, 0.0) for _ in atc])
    
    # Compute pair indices (matching the calculator setup)
    from mmml.pycharmmInterface.mmml_calculator import dimer_permutations
    pair_idxs_product = np.array([(a,b) for a,b in list(product(np.arange(ATOMS_PER_MONOMER), repeat=2))])
    dimer_perms = np.array(dimer_permutations(N_MONOMERS))
    pair_idxs_np = dimer_perms * ATOMS_PER_MONOMER
    pair_idx_atom_atom = pair_idxs_np[:, None, :] + pair_idxs_product[None,...]
    pair_idx_atom_atom = pair_idx_atom_atom.reshape(-1, 2)
    
    return {
        "atc_epsilons": atc_epsilons,
        "atc_rmins": atc_rmins,
        "atc_qs": atc_qs,
        "at_codes": at_codes,
        "pair_idx_atom_atom": pair_idx_atom_atom,
    }


def create_lj_fitting_factory(
    base_calculator_factory,
    atc_epsilons,
    atc_rmins,
    atc_qs,
    at_codes,
    pair_idx_atom_atom,
    cutoff_params,
):
    """
    Create a function that computes hybrid energy/forces with scalable LJ parameters.

    The returned function takes per-atom-type scaling factors ``ep_scale`` and
    ``sig_scale`` as JAX arrays. Its MM part is pure JAX and differentiable
    with respect to both. The ML part is obtained through
    ``base_calculator_factory`` and ASE on every call and does not depend on
    the scaling factors.

    Args:
        base_calculator_factory: The base calculator factory (from setup_calculator);
            only used for the ML contribution (called with doMM=False)
        atc_epsilons: Base epsilon values per atom type (only the magnitude is used)
        atc_rmins: Base rmin values per atom type; a pair's rmin is the sum of
            the two atoms' values
        atc_qs: Charges for each atom type
        at_codes: Atom type index for each atom in the system
        pair_idx_atom_atom: (n_pairs, 2) integer array of atom-index pairs
        cutoff_params: Cutoff parameters, forwarded to the ML calculator only

    Returns:
        compute_energy_forces: Function (R, Z, ep_scale, sig_scale) -> (E, F)
    """
    import warnings

    # Convert the fixed inputs once instead of on every energy evaluation.
    atc_epsilons_j = jnp.asarray(atc_epsilons)
    atc_rmins_j = jnp.asarray(atc_rmins)
    atc_qs_j = jnp.asarray(atc_qs)
    idx_i = pair_idx_atom_atom[:, 0]
    idx_j = pair_idx_atom_atom[:, 1]

    coulombs_constant = 3.32063711e2
    coulomb_epsilon = 1e-10

    def lennard_jones(r, rmin, ep):
        # 12-6 form written with rmin; ep is the (positive) combined epsilon.
        r6 = (rmin / r) ** 6
        return ep * (r6 ** 2 - 2 * r6)

    def coulomb(r, qq, constant=coulombs_constant, eps=coulomb_epsilon):
        r_safe = jnp.maximum(r, eps)
        # NOTE(review): with the leading minus, like charges (qq > 0) get a
        # negative energy, the opposite of the usual Coulomb sign. Kept as-is;
        # confirm it matches the convention of the hybrid calculator.
        return -constant * qq / r_safe

    def compute_energy_forces(R, Z, ep_scale, sig_scale):
        """
        Compute hybrid MM/ML energy and forces with differentiable LJ parameters.

        Args:
            R: Positions (n_atoms, 3)
            Z: Atomic numbers (n_atoms,)
            ep_scale: Epsilon scaling factors (n_atom_types,) - JAX array
            sig_scale: Rmin scaling factors (n_atom_types,) - JAX array

        Returns:
            E: Total energy (scalar)
            F: Forces (n_atoms, 3)
        """
        # Scaled per-type parameters (epsilons forced negative before scaling).
        at_ep = -1 * jnp.abs(atc_epsilons_j) * ep_scale
        at_rm = atc_rmins_j * sig_scale

        rmins_per_system = jnp.take(at_rm, at_codes)
        epsilons_per_system = jnp.take(at_ep, at_codes)
        q_per_system = jnp.take(atc_qs_j, at_codes)

        pair_rm = jnp.take(rmins_per_system, idx_i) + jnp.take(rmins_per_system, idx_j)
        ep_prod = jnp.take(epsilons_per_system, idx_i) * jnp.take(epsilons_per_system, idx_j)
        # Geometric-mean combination. A plain ep_prod ** 0.5 has an infinite
        # derivative at 0, which becomes NaN (inf * 0) in the gradient for atom
        # types with zero epsilon. The double `where` keeps the forward value
        # (0 -> 0, negative -> NaN) while giving a zero gradient at 0.
        is_zero = ep_prod == 0
        pair_ep = jnp.where(is_zero, 0.0, jnp.where(is_zero, 1.0, ep_prod) ** 0.5)
        pair_qq = jnp.take(q_per_system, idx_i) * jnp.take(q_per_system, idx_j)

        def mm_energy_fn(R_pos):
            disp = R_pos[idx_i] - R_pos[idx_j]
            dist = jnp.linalg.norm(disp, axis=1)
            return (lennard_jones(dist, pair_rm, pair_ep) + coulomb(dist, pair_qq)).sum()

        # No switching function: every listed pair contributes in full.
        mm_energy, mm_grad = jax.value_and_grad(mm_energy_fn)(R)
        mm_forces = -mm_grad

        # ML contribution through the regular calculator and ASE (MM disabled
        # there because it is computed above with the scaled parameters).
        try:
            calc, _ = base_calculator_factory(
                atomic_numbers=Z,
                atomic_positions=R,
                n_monomers=args.n_monomers,
                cutoff_params=cutoff_params,
                doML=True,
                doMM=False,
                doML_dimer=not args.skip_ml_dimers,
                backprop=True,
                debug=False,
                energy_conversion_factor=1,
                force_conversion_factor=1,
            )
            atoms = ase.Atoms(Z, R)
            atoms.calc = calc
            ml_energy = jnp.array(atoms.get_potential_energy())
            ml_forces = jnp.array(atoms.get_forces())
        except Exception as exc:
            # Keep the MM-only fallback, but do not hide it. The ASE path
            # presumably needs concrete arrays, so this may also trigger when R
            # is a JAX tracer (e.g. under jit) — confirm.
            warnings.warn(
                f"ML contribution unavailable, using MM only: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )
            ml_energy = jnp.array(0.0)
            ml_forces = jnp.zeros_like(R)

        total_energy = ml_energy + mm_energy
        total_forces = ml_forces + mm_forces

        return total_energy, total_forces

    return compute_energy_forces


def fit_lj_parameters_to_training_data_jax(
    train_batches,
    base_calculator_factory,
    atc_epsilons,
    atc_rmins,
    atc_qs,
    at_codes,
    pair_idx_atom_atom,
    initial_ep_scale=None,
    initial_sig_scale=None,
    n_samples=None,
    energy_weight=1.0,
    force_weight=1.0,
    learning_rate=0.01,
    n_iterations=100,
    verbose=True
):
    """
    Fit LJ parameters (ep_scale, sig_scale) to training data using JAX optimization.
    
    This is a fully JAX-native approach that:
    - Creates a differentiable loss function
    - Uses JAX gradients for optimization
    - Uses optax for optimization
    
    Args:
        train_batches: List of training batches (from prepare_batches_jit)
        base_calculator_factory: Base calculator factory (from setup_calculator)
        atc_epsilons: Base epsilon values for each atom type (numpy array)
        atc_rmins: Base rmin values for each atom type (numpy array)
        atc_qs: Charges for each atom type (numpy array)
        at_codes: Atom type codes for each atom in the system (numpy array)
        pair_idx_atom_atom: Pair indices for atom-atom interactions (numpy array)
        initial_ep_scale: Initial epsilon scaling factors (array, defaults to ones)
        initial_sig_scale: Initial sigma scaling factors (array, defaults to ones)
        n_samples: Number of training samples to use (if None, uses all)
        energy_weight: Weight for energy loss term
        force_weight: Weight for force loss term
        learning_rate: Learning rate for optimization
        n_iterations: Number of optimization iterations
        verbose: Print progress
    
    Returns:
        optimized_ep_scale: Optimized epsilon scaling factors (JAX array)
        optimized_sig_scale: Optimized sigma scaling factors (JAX array)
        loss_history: History of loss values during optimization
    """
    import optax
    
    # Convert inputs to JAX arrays
    atc_epsilons_jax = jnp.array(atc_epsilons)
    atc_rmins_jax = jnp.array(atc_rmins)
    atc_qs_jax = jnp.array(atc_qs)
    at_codes_jax = jnp.array(at_codes)
    pair_idx_atom_atom_jax = jnp.array(pair_idx_atom_atom)
    
    n_atom_types = len(atc_epsilons)
    
    # Initialize scaling factors as JAX arrays
    if initial_ep_scale is None:
        initial_ep_scale = jnp.ones(n_atom_types)
    else:
        initial_ep_scale = jnp.array(initial_ep_scale)
    
    if initial_sig_scale is None:
        initial_sig_scale = jnp.ones(n_atom_types)
    else:
        initial_sig_scale = jnp.array(initial_sig_scale)
    
    # Create the differentiable factory
    compute_energy_forces = create_lj_fitting_factory(
        base_calculator_factory,
        atc_epsilons_jax,
        atc_rmins_jax,
        atc_qs_jax,
        at_codes_jax,
        pair_idx_atom_atom_jax,
        CUTOFF_PARAMS,
    )
    
    # Select training samples
    if n_samples is None:
        n_samples = len(train_batches)
    n_samples = min(n_samples, len(train_batches))
    selected_batches = train_batches[:n_samples]
    
    if verbose:
        print(f"Fitting LJ parameters using {n_samples} training samples")
        print(f"  Number of atom types: {n_atom_types}")
        print(f"  Initial ep_scale: {initial_ep_scale}")
        print(f"  Initial sig_scale: {initial_sig_scale}")
        print(f"  Energy weight: {energy_weight}, Force weight: {force_weight}")
        print(f"  Learning rate: {learning_rate}, Iterations: {n_iterations}")
    
    # Prepare training data
    training_data = []
    for batch in selected_batches:
        R = batch["R"]
        Z = batch["Z"]
        E_ref = batch.get("E", None)
        F_ref = batch.get("F", None)
        
        # Handle batch dimension
        if R.ndim == 3:
            for config_idx in range(R.shape[0]):
                training_data.append({
                    "R": jnp.array(R[config_idx]),
                    "Z": jnp.array(Z[config_idx]),
                    "E_ref": jnp.array(E_ref[config_idx]) if E_ref is not None else None,
                    "F_ref": jnp.array(F_ref[config_idx]) if F_ref is not None else None,
                })
        else:
            training_data.append({
                "R": jnp.array(R),
                "Z": jnp.array(Z),
                "E_ref": jnp.array(E_ref) if E_ref is not None else None,
                "F_ref": jnp.array(F_ref) if F_ref is not None else None,
            })
    
    # Define loss function
    def loss_fn(ep_scale, sig_scale):
        """JAX-differentiable loss function."""
        total_energy_error = 0.0
        total_force_error = 0.0
        n_configs = 0
        
        for data in training_data:
            try:
                E_pred, F_pred = compute_energy_forces(
                    data["R"],
                    data["Z"],
                    ep_scale,
                    sig_scale
                )
                
                if data["E_ref"] is not None:
                    energy_error = (E_pred - data["E_ref"]) ** 2
                    total_energy_error += energy_error
                
                if data["F_ref"] is not None:
                    force_error = jnp.mean((F_pred - data["F_ref"]) ** 2)
                    total_force_error += force_error
                
                n_configs += 1
            except Exception as e:
                if verbose:
                    print(f"  Warning: Error in loss computation: {e}")
                continue
        
        if n_configs == 0:
            return jnp.inf
        
        avg_energy_error = total_energy_error / n_configs
        avg_force_error = total_force_error / n_configs
        
        loss = energy_weight * avg_energy_error + force_weight * avg_force_error
        return loss
    
    # Initialize parameters
    params = {
        "ep_scale": initial_ep_scale,
        "sig_scale": initial_sig_scale,
    }
    
    # Create optimizer
    optimizer = optax.adam(learning_rate=learning_rate)
    opt_state = optimizer.init(params)
    
    # Optimization loop
    loss_history = []
    best_loss = jnp.inf
    best_params = params
    
    if verbose:
        print(f"\nStarting JAX optimization...")
    
    for iteration in range(n_iterations):
        # Compute loss and gradients
        loss, grads = jax.value_and_grad(
            lambda p: loss_fn(p["ep_scale"], p["sig_scale"])
        )(params)
        
        # Update parameters
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        
        # Clip parameters to reasonable bounds
        params = {
            "ep_scale": jnp.clip(params["ep_scale"], 0.1, 10.0),
            "sig_scale": jnp.clip(params["sig_scale"], 0.1, 10.0),
        }
        
        loss_val = float(loss)
        loss_history.append(loss_val)
        
        if loss_val < best_loss:
            best_loss = loss_val
            best_params = params.copy()
        
        if verbose and (iteration % 10 == 0 or iteration == n_iterations - 1):
            print(f"  Iteration {iteration:4d}: Loss = {loss_val:.6f}")
            print(f"    ep_scale: {params['ep_scale']}")
            print(f"    sig_scale: {params['sig_scale']}")
    
    if verbose:
        print(f"\n✓ Optimization complete!")
        print(f"  Final loss: {best_loss:.6f}")
        print(f"  Optimized ep_scale: {best_params['ep_scale']}")
        print(f"  Optimized sig_scale: {best_params['sig_scale']}")
    
    return best_params["ep_scale"], best_params["sig_scale"], loss_history
    
    # Get number of atom types from PyCHARMM if not provided
    if n_atom_types is None:
        try:
            atc = param.get_atc()
            n_atom_types = len(atc)
            if verbose:
                print(f"Inferred {n_atom_types} atom types from PyCHARMM")
        except Exception as e:
            raise ValueError(f"Could not determine n_atom_types. Please provide it explicitly. Error: {e}")
    
    # Initialize scaling factors
    if initial_ep_scale is None:
        initial_ep_scale = np.ones(n_atom_types)
    if initial_sig_scale is None:
        initial_sig_scale = np.ones(n_atom_types)
    
    # Flatten parameters for optimization
    x0 = np.concatenate([initial_ep_scale, initial_sig_scale])
    n_params = len(x0)
    
    # Select training samples
    if n_samples is None:
        n_samples = len(train_batches)
    n_samples = min(n_samples, len(train_batches))
    selected_batches = train_batches[:n_samples]
    
    if verbose:
        print(f"Fitting LJ parameters using {n_samples} training samples")
        print(f"  Number of atom types: {n_atom_types}")
        print(f"  Initial ep_scale: {initial_ep_scale}")
        print(f"  Initial sig_scale: {initial_sig_scale}")
        print(f"  Energy weight: {energy_weight}, Force weight: {force_weight}")
    
    def loss_function(x):
        """Compute loss for given LJ scaling parameters."""
        # Unpack parameters
        ep_scale = x[:n_atom_types]
        sig_scale = x[n_atom_types:]
        
        total_energy_error = 0.0
        total_force_error = 0.0
        n_configs = 0
        
        for batch_idx, batch in enumerate(selected_batches):
            try:
                # Get reference data
                R = batch["R"]
                Z = batch["Z"]
                E_ref = batch.get("E", None)
                F_ref = batch.get("F", None)
                
                # Handle batch dimension
                if R.ndim == 3:
                    # Multiple configurations in batch
                    for config_idx in range(R.shape[0]):
                        R_config = R[config_idx]
                        Z_config = Z[config_idx]
                        
                        # Recreate calculator factory with current LJ parameters
                        # Note: This is expensive but necessary since ep_scale/sig_scale are set at factory creation
                        calc_factory_with_lj = setup_calculator(
                            ATOMS_PER_MONOMER=args.n_atoms_monomer,
                            N_MONOMERS=args.n_monomers,
                            ml_cutoff_distance=args.ml_cutoff,
                            mm_switch_on=args.mm_switch_on,
                            mm_cutoff=args.mm_cutoff,
                            doML=True,
                            doMM=args.include_mm,
                            doML_dimer=not args.skip_ml_dimers,
                            debug=False,
                            model_restart_path=base_ckpt_dir,
                            MAX_ATOMS_PER_SYSTEM=len(Z_config),
                            ml_energy_conversion_factor=1,
                            ml_force_conversion_factor=1,
                            cell=args.cell,
                            ep_scale=ep_scale,
                            sig_scale=sig_scale,
                        )
                        
                        # Create calculator with current LJ parameters
                        calc, _ = calc_factory_with_lj(
                            atomic_numbers=Z_config,
                            atomic_positions=R_config,
                            n_monomers=args.n_monomers,
                            cutoff_params=CUTOFF_PARAMS,
                            doML=True,
                            doMM=args.include_mm,
                            doML_dimer=not args.skip_ml_dimers,
                            backprop=True,
                            debug=False,
                            energy_conversion_factor=1,
                            force_conversion_factor=1,
                        )
                        
                        # Create ASE atoms object
                        atoms = ase.Atoms(Z_config, R_config)
                        atoms.calc = calc
                        
                        # Compute predicted energy and forces
                        E_pred = atoms.get_potential_energy()
                        F_pred = atoms.get_forces()
                        
                        # Compute errors
                        if E_ref is not None:
                            E_ref_config = E_ref[config_idx] if E_ref.ndim > 0 else E_ref
                            energy_error = (E_pred - E_ref_config) ** 2
                            total_energy_error += energy_error
                        
                        if F_ref is not None:
                            F_ref_config = F_ref[config_idx] if F_ref.ndim == 3 else F_ref
                            force_error = np.mean((F_pred - F_ref_config) ** 2)
                            total_force_error += force_error
                        
                        n_configs += 1
                else:
                    # Single configuration
                    # Recreate calculator factory with current LJ parameters
                    calc_factory_with_lj = setup_calculator(
                        ATOMS_PER_MONOMER=args.n_atoms_monomer,
                        N_MONOMERS=args.n_monomers,
                        ml_cutoff_distance=args.ml_cutoff,
                        mm_switch_on=args.mm_switch_on,
                        mm_cutoff=args.mm_cutoff,
                        doML=True,
                        doMM=args.include_mm,
                        doML_dimer=not args.skip_ml_dimers,
                        debug=False,
                        model_restart_path=base_ckpt_dir,
                        MAX_ATOMS_PER_SYSTEM=len(Z),
                        ml_energy_conversion_factor=1,
                        ml_force_conversion_factor=1,
                        cell=args.cell,
                        ep_scale=ep_scale,
                        sig_scale=sig_scale,
                    )
                    
                    # Create calculator with current LJ parameters
                    calc, _ = calc_factory_with_lj(
                        atomic_numbers=Z,
                        atomic_positions=R,
                        n_monomers=args.n_monomers,
                        cutoff_params=CUTOFF_PARAMS,
                        doML=True,
                        doMM=args.include_mm,
                        doML_dimer=not args.skip_ml_dimers,
                        backprop=True,
                        debug=False,
                        energy_conversion_factor=1,
                        force_conversion_factor=1,
                    )
                    
                    # Create ASE atoms object
                    atoms = ase.Atoms(Z, R)
                    atoms.calc = calc
                    
                    # Compute predicted energy and forces
                    E_pred = atoms.get_potential_energy()
                    F_pred = atoms.get_forces()
                    
                    # Compute errors
                    if E_ref is not None:
                        energy_error = (E_pred - E_ref) ** 2
                        total_energy_error += energy_error
                    
                    if F_ref is not None:
                        force_error = np.mean((F_pred - F_ref) ** 2)
                        total_force_error += force_error
                    
                    n_configs += 1
                    
            except Exception as e:
                if verbose:
                    print(f"  Warning: Error processing batch {batch_idx}: {e}")
                continue
        
        if n_configs == 0:
            raise ValueError("No valid configurations processed")
        
        # Normalize by number of configurations
        avg_energy_error = total_energy_error / n_configs
        avg_force_error = total_force_error / n_configs
        
        # Combined loss
        loss = energy_weight * avg_energy_error + force_weight * avg_force_error
        
        if verbose:
            print(f"  Loss: {loss:.6f} (E: {avg_energy_error:.6f}, F: {avg_force_error:.6f})")
        
        return loss
    
    # Set up bounds (keep scaling factors positive and reasonable)
    bounds = [(0.1, 10.0)] * n_params  # Allow 0.1x to 10x scaling
    
    # Optimize
    if verbose:
        print(f"\nStarting optimization with method: {method}")
        print(f"  Bounds: {bounds[0]} (applied to all {n_params} parameters)")
    
    result = minimize(
        loss_function,
        x0=x0,
        method=method,
        bounds=bounds,
        options={'maxiter': maxiter, 'disp': verbose}
    )
    
    # Extract optimized parameters
    optimized_ep_scale = result.x[:n_atom_types]
    optimized_sig_scale = result.x[n_atom_types:]
    
    if verbose:
        print(f"\n✓ Optimization complete!")
        print(f"  Final loss: {result.fun:.6f}")
        print(f"  Optimized ep_scale: {optimized_ep_scale}")
        print(f"  Optimized sig_scale: {optimized_sig_scale}")
        print(f"  Success: {result.success}")
        if result.message:
            print(f"  Message: {result.message}")
    
    return result, optimized_ep_scale, optimized_sig_scale

# Step 1: Extract base LJ parameters (run once, after calculator_factory is created).
# The returned dict supplies the per-atom-type LJ/charge arrays consumed in Step 2.
lj_params = extract_lj_parameters_from_calculator()
assert {"atc_epsilons", "atc_rmins", "atc_qs", "at_codes", "pair_idx_atom_atom"} <= set(lj_params), (
    "extract_lj_parameters_from_calculator() is missing keys needed for LJ fitting"
)
lj_params


  
 CHARMM>     BLOCK
  
  BLOCK>            CALL 1 SELE ALL END
 SELRPN>     20 atoms have been selected out of     20
 The selected atoms have been reassigned to block   1
  
  BLOCK>              COEFF 1 1 1.0
  
  BLOCK>            END
 Matrix of Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of BOND Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of ANGLE Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of DIHE Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of CROSS Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of ELEC Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of VDW Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 

{'atc_epsilons': array([ 0.    ,  0.    , -0.045 , -0.035 , -0.024 , -0.031 , -0.026 ,
        -0.028 , -0.03  , -0.028 , -0.028 , -0.04  , -0.046 , -0.046 ,
        -0.1   , -0.046 , -0.046 , -0.009 , -0.01  , -0.012 , -0.03  ,
        -0.046 , -0.046 , -0.03  , -0.046 , -0.046 , -0.03  , -0.167 ,
        -0.1032, -0.18  , -0.068 , -0.064 , -0.068 , -0.068 , -0.068 ,
        -0.068 , -0.064 , -0.11  , -0.11  , -0.11  , -0.098 , -0.07  ,
        -0.06  , -0.09  , -0.07  , -0.058 , -0.05  , -0.02  , -0.02  ,
        -0.05  , -0.068 , -0.068 , -0.068 , -0.068 , -0.07  , -0.09  ,
        -0.1   , -0.04  , -0.07  , -0.07  , -0.099 , -0.067 , -0.099 ,
        -0.032 , -0.02  , -0.032 , -0.042 , -0.031 , -0.056 , -0.06  ,
        -0.11  , -0.055 , -0.078 , -0.077 , -0.07  , -0.078 , -0.08  ,
        -0.056 , -0.065 , -0.036 , -0.036 , -0.06  , -0.035 , -0.059 ,
        -0.032 , -0.18  , -0.2   , -0.2   , -0.2   , -0.2   , -0.2   ,
        -0.2   , -0.2   , -0.2   , -0.2   , -0.2   , -0.2   ,

In [16]:

# Step 2: Fit the per-atom-type LJ scaling factors (ep_scale / sig_scale) with JAX + optax.
# The base LJ/charge arrays are forwarded by name from `lj_params` (Step 1).
opt_ep_scale, opt_sig_scale, loss_history = fit_lj_parameters_to_training_data_jax(
    train_batches=train_batches,
    base_calculator_factory=calculator_factory,
    **{
        key: lj_params[key]
        for key in ("atc_epsilons", "atc_rmins", "atc_qs", "at_codes", "pair_idx_atom_atom")
    },
    n_samples=20,  # only the first 20 batches are used for fitting
    energy_weight=1.0,
    force_weight=1.0,
    learning_rate=0.01,
    n_iterations=100,
    verbose=True,
)



Fitting LJ parameters using 20 training samples
  Number of atom types: 163
  Initial ep_scale: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  Initial sig_scale: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1.

TypeError: Gradient only defined for scalar-output functions. Output had shape: (1, 1).

In [None]:
# Step 3: Rebuild the calculator factory with the fitted LJ scaling factors.
# The fitted scales come back as JAX arrays and are copied to NumPy for setup_calculator.
calculator_factory_optimized = setup_calculator(
    # --- system layout ---
    ATOMS_PER_MONOMER=args.n_atoms_monomer,
    N_MONOMERS=args.n_monomers,
    MAX_ATOMS_PER_SYSTEM=natoms,
    cell=args.cell,
    # --- cutoffs / switching ---
    ml_cutoff_distance=args.ml_cutoff,
    mm_switch_on=args.mm_switch_on,
    mm_cutoff=args.mm_cutoff,
    # --- which contributions to evaluate ---
    doML=True,
    doMM=args.include_mm,
    doML_dimer=not args.skip_ml_dimers,
    # --- ML model + unit handling ---
    model_restart_path=base_ckpt_dir,
    ml_energy_conversion_factor=1,
    ml_force_conversion_factor=1,
    debug=args.debug,
    # --- fitted LJ scaling factors ---
    ep_scale=np.array(opt_ep_scale),
    sig_scale=np.array(opt_sig_scale),
)

# Initialize Simulations from valid_data Batches

This section initializes simulations using positions and atomic numbers from `valid_data` batches.
Each batch can be used to create an ASE Atoms object and run a simulation.

In [14]:
# ========================================================================
# PYCHARMM SYSTEM SETUP (run before creating any MM-enabled calculator)
# ========================================================================
# MM contributions need charges from the PyCHARMM PSF, so the residues must be
# generated before a calculator with MM is built. The PSF atom ordering captured
# here is later used to reorder the valid_data batch atoms.

# Start from a clean CHARMM state.
CLEAR_CHARMM()
reset_block()

# One residue per monomer, e.g. N_MONOMERS=2 -> "ACO ACO" (two acetone molecules).
# Change the residue name if the system is not acetone.
residue_string = " ".join("ACO" for _ in range(N_MONOMERS))
print(f"Generating {N_MONOMERS} residues: {residue_string}")

try:
    # Create the PSF entries, then build Cartesian coordinates from internal coordinates.
    setupRes.generate_residue(residue_string)
    print("Residues generated successfully")

    ic.build()
    print("Structure built using internal coordinates")

    coor.show()

    # Snapshot PSF atom information, truncated to the ML system size; it is used
    # to reorder the valid_data batch atoms.
    n_system_atoms = N_MONOMERS * ATOMS_PER_MONOMER
    pycharmm_atypes = np.array(psf.get_atype())[:n_system_atoms]
    # NOTE(review): confirm psf.get_res() is per-atom; if it is per-residue,
    # slicing it by the atom count is not meaningful.
    pycharmm_resids = np.array(psf.get_res())[:n_system_atoms]
    pycharmm_iac = np.array(psf.get_iac())[:n_system_atoms]

    print(f"PyCHARMM atom types: {pycharmm_atypes}")
    print(f"PyCHARMM residue IDs: {pycharmm_resids}")
    print(f"PyCHARMM has {len(pycharmm_atypes)} atoms")

    mmml.pycharmmInterface.import_pycharmm.view_pycharmm_state()

except Exception as e:
    # Best-effort setup: on failure, report it, disable MM and clear the PSF snapshots.
    print(f"Warning: Could not initialize PyCHARMM system: {e}")
    print("You may need to adjust residue names/numbers")
    print("MM contributions will be disabled if PyCHARMM is not initialized")
    if args.include_mm:
        print("Setting include_mm=False since PyCHARMM initialization failed")
        args.include_mm = False
    pycharmm_atypes = pycharmm_resids = pycharmm_iac = None

  
 CHARMM>     DELETE ATOM SELE ALL END
  
  
 CHARMM>     DELETE PSF SELE ALL END
  
  
 CHARMM>     BLOCK
  
  BLOCK>            CALL 1 SELE ALL END
Generating 2 residues: ACO ACO
***** Generating residue *****
 SELRPN>      0 atoms have been selected out of      0
 The selected atoms have been reassigned to block   1
  
  BLOCK>              COEFF 1 1 1.0
  
  BLOCK>            END
 Matrix of Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of BOND Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of ANGLE Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of DIHE Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of CROSS Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of ELEC Interaction Coefficients
 
    1.00000
    1.00000   

# Setup PyCHARMM System (REQUIRED before MM contributions)

**IMPORTANT**: The PyCHARMM system must be initialized BEFORE creating calculators that use MM contributions. 

The PyCHARMM setup cell:
1. Generates residues using `setupRes.generate_residue()` (e.g., "ACO ACO" for two acetone molecules)
2. Builds the structure using `ic.build()`
3. Reads the atom ordering (atom types, residue info, IAC codes) from the PyCHARMM PSF

**Note on atom reordering**: The atoms from `valid_data` batches may need to be reordered to match PyCHARMM's atom ordering. 
The `reorder_atoms_to_match_pycharmm()` function handles this, but you may need to customize it based on your system.

Things to check before running it:
- Residue names (e.g., "ACO" for acetone) must match your system.
- The number of residues must equal `N_MONOMERS`.
- If PyCHARMM initialization fails, MM contributions are automatically disabled (`args.include_mm` is set to `False`).

In [None]:
# ========================================================================
# ATOM REORDERING FUNCTION
# ========================================================================
# PyCHARMM has a specific atom ordering based on residue and atom type.
# The valid_data batch atoms need to be reordered to match PyCHARMM's ordering.
# This function tries different orderings and selects the one that minimizes
# CHARMM internal energy.

def reorder_atoms_to_match_pycharmm(R, Z, pycharmm_atypes, pycharmm_resids):
    """
    Reorder atoms from valid_data batch to match PyCHARMM's atom ordering.
    
    This function tries different atom orderings and selects the one that
    minimizes CHARMM internal energy (INTE term).
    
    Args:
        R: Positions from valid_data batch (n_atoms, 3)
        Z: Atomic numbers from valid_data batch (n_atoms,)
        pycharmm_atypes: Atom types from PyCHARMM PSF
        pycharmm_resids: Residue IDs from PyCHARMM PSF
    
    Returns:
        R_reordered: Reordered positions matching PyCHARMM ordering
        Z_reordered: Reordered atomic numbers matching PyCHARMM ordering
        reorder_indices: Indices used for reordering
    """
    import pandas as pd
    
    n_atoms = len(R)
    
    print("  Reordering atoms to match PyCHARMM ordering...")
    print(f"  Original R shape: {R.shape}, Z shape: {Z.shape}")
    
    # Start with identity mapping
    base_indices = np.arange(n_atoms)
    
    # Generate candidate reorderings to try
    # Start with the identity (no reordering)
    candidate_orderings = [base_indices.copy()]
    
    # Add common swap patterns (based on user's example)
    # Swap indices 0 ↔ 3
    swap_1 = base_indices.copy()
    swap_1[0] = base_indices[3]
    swap_1[3] = base_indices[0]
    candidate_orderings.append(swap_1)
    
    # Swap indices 10 ↔ 13
    swap_2 = base_indices.copy()
    swap_2[10] = base_indices[13]
    swap_2[13] = base_indices[10]
    candidate_orderings.append(swap_2)
    
    # Combined swap: 0↔3 and 10↔13
    swap_combined = base_indices.copy()
    swap_combined[0] = base_indices[3]
    swap_combined[3] = base_indices[0]
    swap_combined[10] = base_indices[13]
    swap_combined[13] = base_indices[10]
    candidate_orderings.append(swap_combined)
    
    # Try additional swaps within each monomer if needed
    # For each monomer, try swapping first and last atoms
    for monomer_idx in range(N_MONOMERS):
        start_idx = monomer_idx * ATOMS_PER_MONOMER
        end_idx = (monomer_idx + 1) * ATOMS_PER_MONOMER
        if end_idx <= n_atoms:
            swap_monomer = base_indices.copy()
            if start_idx < n_atoms and end_idx - 1 < n_atoms:
                swap_monomer[start_idx] = base_indices[end_idx - 1]
                swap_monomer[end_idx - 1] = base_indices[start_idx]
                candidate_orderings.append(swap_monomer)
    
    print(f"  Trying {len(candidate_orderings)} different atom orderings...")
    
    # Evaluate each ordering by computing CHARMM internal energy
    best_energy = float('inf')
    best_indices = base_indices
    best_R = R
    best_Z = Z
    
    for i, reorder_indices in enumerate(candidate_orderings):
        try:
            # Apply reordering
            R_test = R[reorder_indices]
            Z_test = Z[reorder_indices]
            
            # Validate reordered arrays
            if R_test.shape != R.shape or Z_test.shape != Z.shape:
                print(f"    Ordering {i+1} failed: shape mismatch after reordering")
                continue
            
            # Check for NaN/Inf in positions
            if np.any(~np.isfinite(R_test)):
                print(f"    Ordering {i+1} failed: NaN/Inf in positions")
                continue
            
            # Set positions in PyCHARMM
            xyz = pd.DataFrame(R_test, columns=["x", "y", "z"])
            coor.set_positions(xyz)
            
            # Compute energy with error handling
            try:
                energy.get_energy()
                inte_energy = energy.get_term_by_name("INTE")
                
                # Check if energy is valid
                if not np.isfinite(inte_energy):
                    print(f"    Ordering {i+1} failed: invalid energy (NaN/Inf)")
                    continue
                
                print(f"    Ordering {i+1}/{len(candidate_orderings)}: INTE = {inte_energy:.6f} kcal/mol")
                
                # Keep track of best (lowest energy) ordering
                if inte_energy < best_energy:
                    best_energy = inte_energy
                    best_indices = reorder_indices
                    best_R = R_test
                    best_Z = Z_test
                    
            except Exception as e:
                print(f"    Ordering {i+1} energy calculation failed: {e}")
                import traceback
                traceback.print_exc()
                continue
                
        except Exception as e:
            print(f"    Ordering {i+1} failed: {e}")
            import traceback
            traceback.print_exc()
            continue
    
    # Validate that we found a valid ordering
    if best_energy == float('inf'):
        raise RuntimeError(
            "Failed to find valid atom ordering. All orderings produced invalid energies. "
            "This may indicate:\n"
            "1. PyCHARMM is not properly initialized\n"
            "2. Atom positions are invalid (NaN/Inf)\n"
            "3. Atom types/charges mismatch between batch and PyCHARMM\n"
            "4. PyCHARMM energy calculation is failing"
        )
    
    print(f"  Best ordering found: INTE = {best_energy:.6f} kcal/mol")
    print(f"  Reorder indices: {best_indices}")
    
    # Validate final arrays
    if np.any(~np.isfinite(best_R)):
        raise RuntimeError("Final reordered positions contain NaN/Inf values")
    
    # Set final positions in PyCHARMM
    try:
        xyz = pd.DataFrame(best_R, columns=["x", "y", "z"])
        coor.set_positions(xyz)
        print("  Final positions set in PyCHARMM")
    except Exception as e:
        print(f"  Warning: Could not set final positions in PyCHARMM: {e}")
        raise
    
    return best_R, best_Z, best_indices


# ========================================================================
# INITIALIZE SIMULATIONS FROM VALID_DATA BATCHES
# ========================================================================
# Following run_sim.py structure, we'll initialize simulations using valid_data batches
# NOTE: PyCHARMM system should be set up in the previous cell before calling this

def initialize_simulation_from_batch(batch_idx=0):
    """
    Initialize a simulation from a valid_data batch.

    Only the first configuration of the batch is used. If MM is enabled and
    PyCHARMM atom types are available, the atoms are reordered to match the
    PyCHARMM ordering. With MM enabled, PyCHARMM's (global) coordinates are
    overwritten with this configuration.

    Args:
        batch_idx: Index of the batch to use (default: 0)

    Returns:
        atoms: ASE Atoms object initialized from the batch
        hybrid_calc: Hybrid calculator for the system

    Raises:
        ValueError: If the batch positions are neither 2-D nor 3-D, or if the
            configuration has fewer atoms than ATOMS_PER_MONOMER * N_MONOMERS.
    """
    # Get positions and atomic numbers from batch
    R = valid_batches[batch_idx]["R"]
    Z = valid_batches[batch_idx]["Z"]

    # Extract the first configuration from the batch
    # Note: batches may contain multiple configurations
    if R.ndim == 3:
        # Batch shape: (batch_size, n_atoms, 3)
        R = R[0]
        Z = Z[0]
    elif R.ndim == 2:
        # Already flattened: (n_atoms, 3)
        pass
    else:
        raise ValueError(f"Unexpected R shape: {R.shape}")

    # Ensure we have the right number of atoms. Surplus atoms are truncated.
    # A deficit cannot be repaired by slicing, so it would only fail later
    # inside PyCHARMM or the calculator. Fail early with a clear message.
    n_atoms_expected = ATOMS_PER_MONOMER * N_MONOMERS
    if len(R) < n_atoms_expected:
        raise ValueError(
            f"Batch {batch_idx} has {len(R)} atoms, but {n_atoms_expected} are required "
            f"({N_MONOMERS} monomers x {ATOMS_PER_MONOMER} atoms per monomer)"
        )
    if len(R) != n_atoms_expected:
        print(f"Warning: Expected {n_atoms_expected} atoms, got {len(R)}")
        R = R[:n_atoms_expected]
        Z = Z[:n_atoms_expected]

    print(f"Initializing simulation from batch {batch_idx}")
    print(f"  Positions shape: {R.shape}")
    print(f"  Atomic numbers shape: {Z.shape}")
    print(f"  Number of atoms: {len(R)}")

    # Reorder atoms to match PyCHARMM ordering if PyCHARMM is initialized
    if args.include_mm and 'pycharmm_atypes' in globals() and pycharmm_atypes is not None:
        R, Z, reorder_indices = reorder_atoms_to_match_pycharmm(
            R, Z, pycharmm_atypes, pycharmm_resids
        )
        print("  Atoms reordered to match PyCHARMM ordering")
    else:
        print("  No reordering applied (MM disabled or PyCHARMM not initialized)")

    # Create ASE Atoms object
    atoms = ase.Atoms(Z, R)

    # Sync positions with PyCHARMM if MM is enabled.
    # This ensures PyCHARMM coordinates match the batch positions. Note that
    # this overwrites PyCHARMM's coordinates from any previous initialization.
    if args.include_mm:
        try:
            import pandas as pd
            xyz = pd.DataFrame(R, columns=["x", "y", "z"])
            coor.set_positions(xyz)
            print("  Synced positions with PyCHARMM")
        except Exception as e:
            print(f"  Warning: Could not sync positions with PyCHARMM: {e}")

    # Create hybrid calculator (following run_sim.py)
    # Note: MM contributions require PyCHARMM to be initialized first
    hybrid_calc, _ = calculator_factory(
        atomic_numbers=Z,
        atomic_positions=R,
        n_monomers=args.n_monomers,
        cutoff_params=CUTOFF_PARAMS,
        doML=True,
        doMM=args.include_mm,
        doML_dimer=not args.skip_ml_dimers,
        backprop=True,
        debug=args.debug,
        energy_conversion_factor=1,
        force_conversion_factor=1,
    )

    atoms.calc = hybrid_calc

    # Get initial energy and forces (also a smoke test of the calculator setup)
    try:
        hybrid_energy = float(atoms.get_potential_energy())
        hybrid_forces = np.asarray(atoms.get_forces())
        print(f"Initial energy: {hybrid_energy:.6f} eV")
        print(f"Initial forces shape: {hybrid_forces.shape}")
        print(f"Max force: {np.abs(hybrid_forces).max():.6f} eV/Å")
    except Exception as e:
        print(f"Warning: Could not compute initial energy/forces: {e}")
        print("This may be due to PyCHARMM not being properly initialized")
        print("or atom ordering mismatch. Check the reordering function.")
        raise

    return atoms, hybrid_calc

# Build the first simulation from the default batch (index 0)
atoms, hybrid_calc = initialize_simulation_from_batch(0)

# Initialize Multiple Simulations from valid_data Batches

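This cell initializes multiple simulations, one per `valid_data` batch.
Each simulation gets its own ASE `Atoms` object and calculator. However, when MM is enabled, every initialization overwrites PyCHARMM's single global coordinate set. After this cell, PyCHARMM therefore holds the coordinates of the last batch that was initialized. Keep this in mind before running CHARMM-side operations (e.g. minimization) on another simulation.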

In [None]:
# ========================================================================
# INITIALIZE MULTIPLE SIMULATIONS FROM VALID_DATA BATCHES
# ========================================================================
# Following run_sim.py structure, we can initialize multiple simulations

def initialize_multiple_simulations(n_simulations=5):
    """
    Initialize multiple simulations from different valid_data batches.

    Batches are used in order, starting at index 0. At most
    ``min(n_simulations, len(valid_batches))`` batches are attempted. A batch
    that fails to initialize is reported and skipped, so the returned list can
    be shorter than requested.

    Each initialization goes through ``initialize_simulation_from_batch``. With
    MM enabled, that call overwrites PyCHARMM's coordinates, so afterwards
    PyCHARMM holds the last successfully initialized batch.

    Args:
        n_simulations: Number of simulations to initialize (default: 5)

    Returns:
        List of (atoms, hybrid_calc) tuples
    """
    simulations = []
    n_batches = len(valid_batches)
    # Number of batches actually attempted. Also used as the denominator of
    # the progress message, so "k/N" reflects what is really being done.
    n_to_init = min(n_simulations, n_batches)
    if n_to_init < n_simulations:
        print(
            f"Warning: requested {n_simulations} simulations but only "
            f"{n_batches} batches are available"
        )

    for i in range(n_to_init):
        try:
            atoms, calc = initialize_simulation_from_batch(batch_idx=i)
        except Exception as e:
            # Best effort: report the failing batch and continue with the next one.
            print(f"Warning: Failed to initialize simulation from batch {i}: {e}")
            continue
        simulations.append((atoms, calc))
        print(f"Successfully initialized simulation {i+1}/{n_to_init}")

    return simulations

# Number of batches to turn into simulations -- tune as needed
N_SIMULATIONS = 5
simulations = initialize_multiple_simulations(n_simulations=N_SIMULATIONS)
print(f"\nInitialized {len(simulations)} simulations from valid_data batches")

# Example: Run a Simple Energy Calculation

This demonstrates how to use the initialized simulations.

In [None]:
# ========================================================================
# EXAMPLE: RUN ENERGY CALCULATIONS
# ========================================================================

# Example: Calculate energy for the first simulation.
# NOTE: results are stored in `example_energy` / `example_forces`, not in
# `energy`. `energy` is the `pycharmm.energy` module imported at the top of
# the notebook. Rebinding it to a number breaks later calls such as
# `energy.show()` and `energy.get_term_by_name(...)`.
if len(simulations) > 0:
    atoms_example, calc_example = simulations[0]
    example_energy = atoms_example.get_potential_energy()
    example_forces = atoms_example.get_forces()
    print(f"Example simulation energy: {example_energy:.6f} eV")
    print(f"Example simulation forces shape: {example_forces.shape}")
    print(f"Max force magnitude: {np.abs(example_forces).max():.6f} eV/Å")
else:
    print("No simulations initialized. Check batch data and system parameters.")

  
 CHARMM>     BLOCK
  
  BLOCK>            CALL 1 SELE ALL END
 SELRPN>    500 atoms have been selected out of    500
 The selected atoms have been reassigned to block   1
  
  BLOCK>              COEFF 1 1 1.0
  
  BLOCK>            END
 Matrix of Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of BOND Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of ANGLE Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of DIHE Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of CROSS Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of ELEC Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of VDW Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 

Array gpu:0 -2.325e+04

In [None]:
# Next Steps: Running MD Simulations

To run MD simulations following `run_sim.py`, you can:
1. Use the `minimize_structure` function from run_sim.py
2. Use the `run_ase_md` function for ASE-based MD
3. Use JAX-MD for more advanced simulations

See `run_sim.py` for complete MD simulation setup.

[31mSignature:[39m
calculator_factory(
    atomic_numbers: [33m'Array'[39m,
    atomic_positions: [33m'Array'[39m,
    n_monomers: [33m'int'[39m,
    cutoff_params: [33m'CutoffParameters'[39m = [38;5;28;01mNone[39;00m,
    doML: [33m'bool'[39m = [38;5;28;01mTrue[39;00m,
    doMM: [33m'bool'[39m = [38;5;28;01mTrue[39;00m,
    doML_dimer: [33m'bool'[39m = [38;5;28;01mTrue[39;00m,
    backprop: [33m'bool'[39m = [38;5;28;01mFalse[39;00m,
    debug: [33m'bool'[39m = [38;5;28;01mFalse[39;00m,
    energy_conversion_factor: [33m'float'[39m = [32m1.0[39m,
    force_conversion_factor: [33m'float'[39m = [32m1.0[39m,
    verbose: [33m'bool'[39m = [38;5;28;01mNone[39;00m,
) -> [33m'Tuple[AseDimerCalculator, Callable]'[39m
[31mDocstring:[39m
Factory function to create calculator instances

Args:
    verbose: If True, store full ModelOutput breakdown in results.
             If None, defaults to debug value.
[31mFile:[39m      ~/mmml/mmml/pycharmmInte

In [None]:
# ========================================================================
# HELPER FUNCTIONS (from run_sim.py)
# ========================================================================
# These functions can be copied from run_sim.py for running MD simulations

def minimize_structure(atoms, run_index=0, nsteps=60, fmax=0.0006, charmm=False, calculator=None):
    """
    Minimize structure using BFGS optimizer (from run_sim.py)

    Optionally runs a CHARMM ABNR minimization first (starting from the
    positions in ``atoms``). It then runs BFGS with the hybrid calculator,
    writing the path to ``bfgs_{run_index}_{args.output_prefix}_minimized.traj``.
    Finally, the minimized positions are pushed back into PyCHARMM.

    Args:
        atoms: ASE Atoms object (must have calculator set, or provide calculator)
        run_index: Index for trajectory file naming
        nsteps: Maximum number of optimization steps
        fmax: Force convergence criterion
        charmm: If True, run CHARMM minimization first
        calculator: Optional calculator to set if atoms doesn't have one

    Returns:
        The same ``atoms`` object, updated in place with minimized positions.

    Raises:
        RuntimeError: If ``atoms`` has no calculator, none is given, and one
            cannot be created.
    """
    import pandas as pd

    def _push_positions_to_charmm():
        # PyCHARMM holds one global coordinate set; overwrite it with atoms'.
        xyz = pd.DataFrame(atoms.get_positions(), columns=["x", "y", "z"])
        coor.set_positions(xyz)

    # Ensure calculator is set
    if atoms.calc is None:
        if calculator is not None:
            atoms.calc = calculator
        else:
            # Try to create calculator from atoms
            Z = atoms.get_atomic_numbers()
            R = atoms.get_positions()
            try:
                calc, _ = calculator_factory(
                    atomic_numbers=Z,
                    atomic_positions=R,
                    n_monomers=args.n_monomers,
                    cutoff_params=CUTOFF_PARAMS,
                    doML=True,
                    doMM=args.include_mm,
                    doML_dimer=not args.skip_ml_dimers,
                    backprop=True,
                    debug=args.debug,
                    energy_conversion_factor=1,
                    force_conversion_factor=1,
                )
                atoms.calc = calc
                print("  Created calculator for minimization")
            except Exception as e:
                raise RuntimeError(f"Cannot minimize: atoms has no calculator and cannot create one: {e}") from e

    if charmm:
        # CHARMM minimizes whatever coordinates PyCHARMM currently holds, which
        # may belong to a different structure (e.g. another initialized batch).
        # Start it from *these* atoms.
        _push_positions_to_charmm()
        pycharmm.minimize.run_abnr(nstep=1000, tolenr=1e-6, tolgrd=1e-6)
        pycharmm.lingo.charmm_script("ENER")
        pycharmm.energy.show()
        atoms.set_positions(coor.get_positions())

    traj = ase_io.Trajectory(f'bfgs_{run_index}_{args.output_prefix}_minimized.traj', 'w')
    try:
        print("Minimizing structure with hybrid calculator")
        print(f"Running BFGS for {nsteps} steps")
        print(f"Running BFGS with fmax: {fmax}")
        _ = ase_opt.BFGS(atoms, trajectory=traj).run(fmax=fmax, steps=nsteps)
        # Sync with PyCHARMM
        _push_positions_to_charmm()
        traj.write(atoms)
    finally:
        # Close the trajectory file even if the optimizer or calculator raises.
        traj.close()
    return atoms

# Example: Minimize the first simulation.
# Set RUN_MINIMIZATION = True to actually run BFGS (it writes a .traj file).
RUN_MINIMIZATION = False

if len(simulations) > 0:
    # Get atoms and calculator from the simulation
    atoms_to_minimize, calc_to_minimize = simulations[0]
    # Minimize a copy so the initialized simulation keeps its positions;
    # re-attach the calculator explicitly on the copy.
    atoms_to_minimize = atoms_to_minimize.copy()
    atoms_to_minimize.calc = calc_to_minimize
    if RUN_MINIMIZATION:
        print("Running minimization...")
        print("Note: Calculator is preserved from the initialized simulation")
        atoms_minimized = minimize_structure(atoms_to_minimize, run_index=0, nsteps=100, fmax=0.0006)
    else:
        print("Minimization skipped (set RUN_MINIMIZATION = True to run it)")


# Notes on Residue Numbers and Atom Ordering

When setting up PyCHARMM simulations:

**Residue Setup:**
- Use `setupRes.generate_residue("ACO ACO")` to generate residues (for 2 acetone molecules)
- Use `ic.build()` to build the structure
- The number of residues should match `N_MONOMERS`

**Atom Ordering:**
- PyCHARMM has a specific atom ordering based on residue and atom type
- The `valid_data` batch atoms **must be reordered** to match PyCHARMM's ordering
- The `reorder_atoms_to_match_pycharmm()` function tries different orderings and selects the one that **minimizes CHARMM internal energy** (`energy.get_term_by_name("INTE")`)
- Common swaps tested: indices 0↔3, 10↔13, and combinations
- Among these candidate orderings, the function picks the one with the lowest INTE energy

**To customize reordering:**
1. Add more swap patterns to the `candidate_orderings` list in `reorder_atoms_to_match_pycharmm()`
2. The function will automatically test all candidates and select the one with minimum INTE energy
3. Example swaps: `fix_idxs[0] = _fix_idxs[3]; fix_idxs[3] = _fix_idxs[0]` (swap 0↔3)
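4. The energy-based selection only chooses among the listed candidates. If the correct ordering is not in `candidate_orderings`, it cannot be found, so check that the selected INTE energy is physically reasonable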


In [None]:
# ------------------------------------------------------------------------
# SUMMARY: report the setup that the cells above produced
# ------------------------------------------------------------------------
_rule = "=" * 60
_summary_lines = [
    _rule,
    "Simulation Setup Complete",
    _rule,
    f"Number of simulations initialized: {len(simulations)}",
    f"Number of atoms per simulation: {ATOMS_PER_MONOMER * N_MONOMERS}",
    f"Number of monomers: {N_MONOMERS}",
    f"Atoms per monomer: {ATOMS_PER_MONOMER}",
    f"ML cutoff: {args.ml_cutoff} Å",
    f"MM switch on: {args.mm_switch_on} Å",
    f"MM cutoff: {args.mm_cutoff} Å",
    f"Valid data batches available: {len(valid_batches)}",
    _rule,
    "\nTo run MD simulations, use the helper functions or refer to run_sim.py",
    "Note: Residue numbers may need adjustment based on your system",
]
for _line in _summary_lines:
    print(_line)


In [None]:
# Print the current CHARMM energy breakdown. Go through the package path
# rather than the bare `energy` alias, which is easily shadowed by a numeric
# energy value in notebook cells.
pycharmm.energy.show()


 NONBOND OPTION FLAGS: 
     ELEC     VDW      ATOMs    CDIElec  FSHIft   VATOm    VFSWIt  
     BYGRoup  NOEXtnd  NOEWald 
 CUTNB  = 14.000 CTEXNB =999.000 CTONNB = 10.000 CTOFNB = 12.000
 CGONNB =  0.000 CGOFNB = 10.000
 WMIN   =  1.500 WRNMXD =  0.500 E14FAC =  1.000 EPS    =  1.000
 NBXMOD =      5
 There are        0 atom  pairs and        0 atom  exclusions.
 There are        0 group pairs and        0 group exclusions.
 <MAKINB> with mode   5 found   1200 exclusions and    600 interactions(1-4)
 <MAKGRP> found      0 group exclusions.
ENER ENR:  Eval#     ENERgy      Delta-E         GRMS
 ----------       ---------    ---------    ---------    ---------    ---------
ENER>        0     -0.00000      0.00000      0.00000
 ----------       ---------    ---------    ---------    ---------    ---------


In [None]:
first_batch = valid_batches[0]
R, Z = first_batch["R"], first_batch["Z"]
(R, Z)