## Outline

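This notebook details how to set up the hybrid ML/MM simulations. It loads the acetone data and a trained ML checkpoint, builds the matching PyCHARMM system, and initializes ASE simulations with the hybrid calculator.

## Fitting the LJ terms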

In [1]:
import os

# Set environment variables BEFORE importing jax (or mmml, which may import
# jax itself). XLA reads XLA_PYTHON_CLIENT_MEM_FRACTION when the GPU client is
# created, and CUDA reads CUDA_VISIBLE_DEVICES when it initialises. Setting
# them after that point has no effect.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".99"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import mmml
import ase
from pathlib import Path
import argparse
import sys
import numpy as np
import jax
import jax.numpy as jnp

# Check JAX configuration: which devices and backend JAX actually picked up.
devices = jax.local_devices()
print(devices)
print(jax.default_backend())
print(jax.devices())


[CudaDevice(id=0)]
gpu
[CudaDevice(id=0)]


# Setup: Mock CLI Arguments (following run_sim.py structure)

The cells below import the required modules and then create a mock `args` object that mimics the CLI arguments of `run_sim.py`.
This lets the notebook follow the same structure as the script.

In [2]:
# Import required modules (following run_sim.py structure)
# The CHARMM/BLOCK log shown below this cell is printed while these imports run.
from mmml.cli.base import (
    load_model_parameters,
    resolve_checkpoint_paths,
    setup_ase_imports,
    setup_mmml_imports,
)
from mmml.pycharmmInterface import import_pycharmm
import pycharmm
import pycharmm.ic as ic
import pycharmm.psf as psf
# NOTE: `energy` is the pycharmm.energy module (used for INTE evaluation and
# energy.show()); later cells must not rebind this name.
import pycharmm.energy as energy
from mmml.pycharmmInterface.mmml_calculator import setup_calculator, CutoffParameters
from mmml.physnetjax.physnetjax.data.data import prepare_datasets
from mmml.physnetjax.physnetjax.data.batches import prepare_batches_jit
from mmml.pycharmmInterface.setupBox import setup_box_generic
from mmml.pycharmmInterface import setupRes, setupBox
from mmml.pycharmmInterface.import_pycharmm import reset_block, coor
from mmml.pycharmmInterface.pycharmmCommands import CLEAR_CHARMM

# Setup ASE imports
Atoms = setup_ase_imports()
# NOTE: this rebinds CutoffParameters and setup_calculator, which were imported
# explicitly from mmml_calculator above; from here on the objects returned by
# setup_mmml_imports() are the ones in use.
CutoffParameters, ev2kcalmol, setup_calculator, get_ase_calc = setup_mmml_imports()

# Additional imports for simulation
import ase.io as ase_io
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary, ZeroRotation
from ase.md.verlet import VelocityVerlet
import ase.optimize as ase_opt

/scicore/home/meuwly/boitti0000/mmml/mmml/data/top_all36_cgenff.rtf
/scicore/home/meuwly/boitti0000/mmml/mmml/data/par_all36_cgenff.prm
CHARMM_HOME /scicore/home/meuwly/boitti0000/mmml/setup/charmm
CHARMM_LIB_DIR /scicore/home/meuwly/boitti0000/mmml/setup/charmm
  
 CHARMM>     BLOCK
 Block structure initialized with   3 blocks.
 All atoms have been assigned to block 1.
 All interaction coefficients have been set to unity.
  Setting number of block exclusions nblock_excldPairs=0
  
  BLOCK>            CALL 1 SELE ALL END
 SELRPN>      0 atoms have been selected out of      0
 The selected atoms have been reassigned to block   1
  
  BLOCK>              COEFF 1 1 1.0
  
  BLOCK>            END
 Matrix of Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of BOND Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.00000   1.00000   1.00000
 Matrix of ANGLE Interaction Coefficients
 
    1.00000
    1.00000   1.00000
    1.0

In [3]:
# ========================================================================
# MOCK CLI ARGUMENTS (spoofing run_sim.py CLI)
# ========================================================================
# Create a mock args object that mimics the CLI arguments from run_sim.py
# This allows the notebook to follow the same structure as the script

class MockArgs:
    """Stand-in for the argparse namespace produced by the ``run_sim.py`` CLI.

    Every attribute mirrors one CLI option. The defaults are set below, and
    any of them can be overridden with keyword arguments, e.g.
    ``MockArgs(n_monomers=3, ml_cutoff=2.5)``.

    Args:
        **overrides: attribute values that replace the defaults. Unknown
            names raise ``TypeError``, the way argparse rejects unknown
            options, so a typo cannot silently create a new attribute.
            A ``str`` checkpoint is converted to ``Path``.

    Raises:
        TypeError: if an override name is not one of the known options.
    """

    def __init__(self, **overrides):
        # Paths
        self.pdbfile = None  # Will be created from valid_data if needed
        # Picks up a notebook-level RESTART if one is already defined.
        self.checkpoint = Path(RESTART) if 'RESTART' in globals() else None

        # System parameters
        self.n_monomers = 2
        self.n_atoms_monomer = 10
        self.atoms_per_monomer = 10  # Alias for compatibility

        # Calculator parameters
        self.ml_cutoff = 2.0
        self.mm_switch_on = 4.0
        self.mm_cutoff = 1.0
        self.include_mm = True
        self.skip_ml_dimers = False
        self.debug = False

        # MD simulation parameters
        self.temperature = 210.0
        self.timestep = 0.1
        self.nsteps_jaxmd = 100_000
        self.nsteps_ase = 10000
        self.ensemble = "nvt"
        self.heating_interval = 500
        self.write_interval = 100
        self.energy_catch = 0.5

        # Output
        self.output_prefix = "md_simulation"
        self.cell = None  # No PBC by default

        # Validation
        self.validate = False

        # Apply keyword overrides, rejecting names that are not known options.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(f"Unknown MockArgs option(s): {', '.join(unknown)}")
        for name, value in overrides.items():
            setattr(self, name, value)

        if isinstance(self.checkpoint, str):
            self.checkpoint = Path(self.checkpoint)

        # Keep the two spellings of the monomer size in sync when only one is given.
        if 'n_atoms_monomer' in overrides and 'atoms_per_monomer' not in overrides:
            self.atoms_per_monomer = self.n_atoms_monomer
        elif 'atoms_per_monomer' in overrides and 'n_atoms_monomer' not in overrides:
            self.n_atoms_monomer = self.atoms_per_monomer

# Instantiate the mock CLI namespace that every following cell reads from.
args = MockArgs()

# Notebook-level overrides. ATOMS_PER_MONOMER / N_MONOMERS are assigned from
# `args` by a later cell, so on a fresh kernel neither name exists yet and the
# MockArgs defaults are kept. They only take effect when this cell is re-run.
if 'ATOMS_PER_MONOMER' in globals():
    args.n_atoms_monomer = args.atoms_per_monomer = ATOMS_PER_MONOMER
if 'N_MONOMERS' in globals():
    args.n_monomers = N_MONOMERS

# Echo the values that matter most for the system/calculator setup.
print(f"Mock args created:")
print(f"  n_monomers: {args.n_monomers}")
print(f"  n_atoms_monomer: {args.n_atoms_monomer}")
print(f"  ml_cutoff: {args.ml_cutoff}")
print(f"  mm_switch_on: {args.mm_switch_on}")
print(f"  mm_cutoff: {args.mm_cutoff}")

Mock args created:
  n_monomers: 2
  n_atoms_monomer: 10
  ml_cutoff: 2.0
  mm_switch_on: 4.0
  mm_cutoff: 1.0


In [19]:
from mmml.physnetjax.physnetjax.restart.restart import get_last, get_params_model, get_params_model_with_ase
# Checkpoint run to load. The cluster home directory defaults to the original
# location but can be overridden with the MMML_SCICORE environment variable,
# so the notebook also runs from other accounts/machines.
uid = "test-84aa02d9-e329-46c4-b12c-f55e6c9a2f94"
SCICORE = Path(os.environ.get("MMML_SCICORE", "/scicore/home/meuwly/boitti0000/"))
RESTART = str(SCICORE / "ckpts" / f"{uid}")
args.checkpoint = Path(RESTART)

In [20]:
# System parameters (can be overridden by args)
# Module-level system-size constants, taken from the mock CLI args.
ATOMS_PER_MONOMER, N_MONOMERS = args.n_atoms_monomer, args.n_monomers

# Load Data and Prepare Batches (following run_sim.py structure)

This cell loads the training and validation data and prepares batches for both. The validation batches are used to initialize simulations.
Note: The residue numbers in the PDB/PSF may need to be adjusted based on the actual system.

In [21]:
# ========================================================================
# LOAD DATA AND PREPARE BATCHES (following run_sim.py structure)
# ========================================================================

# Root PRNG key for data loading. It is kept if already defined, so an earlier
# cell can choose the seed.
if 'data_key' not in globals():
    data_key = jax.random.PRNGKey(42)

# Derive independent sub-keys instead of reusing one key for the dataset split
# and both batch shuffles (JAX keys must not be reused). data_key itself is not
# overwritten, so re-running this cell gives the same split and batches.
split_key, valid_batch_key, train_batch_key = jax.random.split(data_key, 3)

NUM_TRAIN = 10500
NUM_VALID = 10500
N_ATOMS_TOTAL = ATOMS_PER_MONOMER * N_MONOMERS

# Data file location (relative to SCICORE when that is defined).
if 'SCICORE' in globals():
    data_file = SCICORE / "mmml/mmml/data/fixed-acetone-only_MP2_21000.npz"
else:
    # Fallback: adjust path as needed
    data_file = Path("/scicore/home/meuwly/boitti0000/mmml/mmml/data/fixed-acetone-only_MP2_21000.npz")

print(f"Loading data from: {data_file}")

# Prepare datasets
train_data, valid_data = prepare_datasets(
    split_key,
    NUM_TRAIN,
    NUM_VALID,
    [data_file],
    natoms=N_ATOMS_TOTAL,
)

# Batch size 1: each validation batch is one configuration used to start a simulation.
valid_batches = prepare_batches_jit(valid_batch_key, valid_data, 1, num_atoms=N_ATOMS_TOTAL)
train_batches = prepare_batches_jit(train_batch_key, train_data, 1, num_atoms=N_ATOMS_TOTAL)

print(f"Loaded {len(valid_data['R'])} validation samples")
print(f"Prepared {len(valid_batches)} validation batches")
print(f"Each batch contains {len(valid_batches[0]['R'])} atoms")

Loading data from: /scicore/home/meuwly/boitti0000/mmml/mmml/data/fixed-acetone-only_MP2_21000.npz
dataR (21000, 20, 3)
dataE [-81.79712432 -81.48244884 -81.38548297 -81.44645775 -81.74704898
 -81.67295344 -81.32876002 -81.82201676 -81.8124061  -81.80508929]
dataE [-81.79712432 -81.48244884 -81.38548297 -81.44645775 -81.74704898
 -81.67295344 -81.32876002 -81.82201676 -81.8124061  -81.80508929]
D (21000, 3)
Q 1 (21000,) 21000
Q (21000,)
Loaded 10500 validation samples
Prepared 10500 validation batches
Each batch contains 20 atoms


In [22]:
# Additional utility imports (if needed)
from ase.visualize.plot import plot_atoms

In [None]:
# Additional PyCHARMM imports (already imported in cell 3, but kept for reference)
from mmml.pycharmmInterface import setupRes, setupBox

In [None]:
# ========================================================================
# LOAD MODEL AND SETUP CALCULATOR (following run_sim.py structure)
# ========================================================================

# Pick the checkpoint source: args.checkpoint wins, otherwise a notebook-level
# RESTART. Both go through resolve_checkpoint_paths, so epoch_dir is derived the
# same way in either case (the fallback previously used the base dir as epoch dir).
if args.checkpoint is not None:
    checkpoint_path = Path(args.checkpoint)
elif 'RESTART' in globals():
    checkpoint_path = Path(RESTART)
else:
    raise ValueError("Checkpoint path must be provided via args.checkpoint or RESTART variable")

base_ckpt_dir, epoch_dir = resolve_checkpoint_paths(checkpoint_path)
print(f"Checkpoint base dir: {base_ckpt_dir}")
print(f"Checkpoint epoch dir: {epoch_dir}")

# Load model parameters for the full system (all monomers together)
natoms = ATOMS_PER_MONOMER * N_MONOMERS
params, model = load_model_parameters(epoch_dir, natoms)
model.natoms = natoms
print(f"Model loaded: {model}")

# Setup calculator factory (following run_sim.py).
# NOTE: the factory is given the checkpoint directory (model_restart_path),
# not the `params` loaded above.
calculator_factory = setup_calculator(
    ATOMS_PER_MONOMER=args.n_atoms_monomer,
    N_MONOMERS=args.n_monomers,
    ml_cutoff_distance=args.ml_cutoff,
    mm_switch_on=args.mm_switch_on,
    mm_cutoff=args.mm_cutoff,
    doML=True,
    doMM=args.include_mm,
    doML_dimer=not args.skip_ml_dimers,
    debug=args.debug,
    model_restart_path=base_ckpt_dir,
    MAX_ATOMS_PER_SYSTEM=natoms,
    ml_energy_conversion_factor=1,
    ml_force_conversion_factor=1,
    cell=args.cell,
)

# Create cutoff parameters
CUTOFF_PARAMS = CutoffParameters(
    ml_cutoff=args.ml_cutoff,
    mm_switch_on=args.mm_switch_on,
    mm_cutoff=args.mm_cutoff,
)
print(f"Cutoff parameters: {CUTOFF_PARAMS}")

# Initialize Simulations from valid_data Batches

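This section initializes simulations using positions and atomic numbers from `valid_data` batches. Each batch is turned into an ASE `Atoms` object with a hybrid calculator attached, which can then be used to run a simulation.
The PyCHARMM system is set up first (next cell), because the MM contributions depend on it.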

In [None]:
# ========================================================================
# SETUP Pycharmm SYSTEM FIRST (required before MM contributions)
# ========================================================================
# IMPORTANT: PyCHARMM system must be initialized BEFORE creating calculators
# that use MM contributions, otherwise charges won't be available
#
# This generates residues in PyCHARMM and builds the structure.
# The atom ordering from PyCHARMM will be used to reorder valid_data batch atoms.


def _run_pycharmm_diagnostic(show_fn, label):
    """Run a display-only PyCHARMM helper and report a failure instead of raising.

    Printing coordinates/state is informational only. An error here must not
    fall through to the outer ``except`` below, which would disable MM.
    """
    try:
        show_fn()
    except Exception as exc:
        print(f"  Note: could not show {label}: {exc}")


# Clear CHARMM state
CLEAR_CHARMM()
reset_block()

# Generate residues in PyCHARMM
# For N_MONOMERS=2, we generate "ACO ACO" (two acetone molecules)
# Adjust the residue string based on N_MONOMERS and your system
residue_string = " ".join(["ACO"] * N_MONOMERS)
print(f"Generating {N_MONOMERS} residues: {residue_string}")

n_system_atoms = N_MONOMERS * ATOMS_PER_MONOMER

try:
    # Generate residues (this creates the PSF structure)
    setupRes.generate_residue(residue_string)
    print("Residues generated successfully")

    # Build the structure using internal coordinates
    ic.build()
    print("Structure built using internal coordinates")

    # Show coordinates (diagnostic only)
    _run_pycharmm_diagnostic(lambda: coor.show(), "coordinates")

    # Get PyCHARMM atom ordering information, used later to reorder valid_data batch atoms.
    pycharmm_atypes = np.array(psf.get_atype())[:n_system_atoms]
    # NOTE(review): psf.get_res() may return one entry per residue rather than
    # one per atom. Confirm before treating this as per-atom residue IDs.
    pycharmm_resids = np.array(psf.get_res())[:n_system_atoms]
    pycharmm_iac = np.array(psf.get_iac())[:n_system_atoms]

    print(f"PyCHARMM atom types: {pycharmm_atypes}")
    print(f"PyCHARMM residue IDs: {pycharmm_resids}")
    print(f"PyCHARMM has {len(pycharmm_atypes)} atoms")

    # View PyCHARMM state (diagnostic only)
    _run_pycharmm_diagnostic(
        lambda: mmml.pycharmmInterface.import_pycharmm.view_pycharmm_state(),
        "PyCHARMM state",
    )

except Exception as e:
    # Setup itself failed: fall back to ML-only by disabling MM.
    print(f"Warning: Could not initialize PyCHARMM system: {e}")
    print("You may need to adjust residue names/numbers")
    print("MM contributions will be disabled if PyCHARMM is not initialized")
    if args.include_mm:
        print("Setting include_mm=False since PyCHARMM initialization failed")
        args.include_mm = False
    pycharmm_atypes = None
    pycharmm_resids = None
    pycharmm_iac = None

# Setup PyCHARMM System (REQUIRED before MM contributions)

**IMPORTANT**: The PyCHARMM system must be initialized BEFORE creating calculators that use MM contributions. 

The PyCHARMM setup cell above:
1. Generates residues using `setupRes.generate_residue()` (e.g., "ACO ACO" for two acetone molecules)
2. Builds the structure using `ic.build()`
3. Gets the atom ordering from PyCHARMM

**Note on atom reordering**: The atoms from `valid_data` batches may need to be reordered to match PyCHARMM's atom ordering. 
The `reorder_atoms_to_match_pycharmm()` function handles this, but you may need to customize it based on your system.

- Residue names (e.g., "ACO" for acetone) must match your system
- The number of residues should match `N_MONOMERS`
- If PyCHARMM initialization fails, MM contributions will be automatically disabled

In [None]:
# ========================================================================
# ATOM REORDERING FUNCTION
# ========================================================================
# PyCHARMM has a specific atom ordering based on residue and atom type.
# The valid_data batch atoms need to be reordered to match PyCHARMM's ordering.
# This function tries different orderings and selects the one that minimizes
# CHARMM internal energy.

def _candidate_orderings(n_atoms, n_monomers, atoms_per_monomer, swap_pairs):
    """Build candidate atom permutations: identity first, duplicates removed.

    The order is:
      1. identity;
      2. ``swap_pairs`` applied inside each monomer on its own;
      3. ``swap_pairs`` applied inside all monomers at once;
      4. a first<->last atom swap inside each monomer.
    For 2 monomers of 10 atoms and ``swap_pairs=((0, 3),)`` this gives
    identity, 0<->3, 10<->13, both, 0<->9, 10<->19. Swaps that would fall
    outside a monomer or outside ``n_atoms`` are skipped.
    """
    base_indices = np.arange(n_atoms)

    def _swapped(pairs):
        # Apply each (a, b) swap of absolute atom indices to a fresh identity copy.
        perm = base_indices.copy()
        for a, b in pairs:
            perm[a], perm[b] = perm[b], perm[a]
        return perm

    # Absolute swap pairs for every monomer (local indices shifted by monomer start).
    per_monomer_pairs = []
    for monomer_idx in range(n_monomers):
        start = monomer_idx * atoms_per_monomer
        per_monomer_pairs.append([
            (start + i, start + j)
            for i, j in swap_pairs
            if max(i, j) < atoms_per_monomer and start + max(i, j) < n_atoms
        ])

    candidates = [base_indices.copy()]
    for pairs in per_monomer_pairs:
        if pairs:
            candidates.append(_swapped(pairs))
    all_pairs = [pair for pairs in per_monomer_pairs for pair in pairs]
    if all_pairs:
        candidates.append(_swapped(all_pairs))
    for monomer_idx in range(n_monomers):
        start = monomer_idx * atoms_per_monomer
        end = start + atoms_per_monomer
        if end <= n_atoms:
            candidates.append(_swapped([(start, end - 1)]))

    # Drop duplicates (e.g. "all monomers" == "one monomer" when n_monomers == 1).
    unique, seen = [], set()
    for perm in candidates:
        key = tuple(perm.tolist())
        if key not in seen:
            seen.add(key)
            unique.append(perm)
    return unique


def reorder_atoms_to_match_pycharmm(R, Z, pycharmm_atypes, pycharmm_resids,
                                    n_monomers=None, atoms_per_monomer=None,
                                    swap_pairs=((0, 3),)):
    """
    Reorder atoms from a valid_data batch to match PyCHARMM's atom ordering.

    Tries a set of candidate atom permutations and keeps the one with the
    lowest CHARMM internal energy (INTE term). Only the listed candidates are
    tested: if the correct ordering is not among them, it is not found.

    Args:
        R: Positions from valid_data batch (n_atoms, 3)
        Z: Atomic numbers from valid_data batch (n_atoms,)
        pycharmm_atypes: Atom types from PyCHARMM PSF (currently unused; kept
            for API compatibility)
        pycharmm_resids: Residue IDs from PyCHARMM PSF (currently unused; kept
            for API compatibility)
        n_monomers: number of monomers (defaults to the global N_MONOMERS)
        atoms_per_monomer: atoms per monomer (defaults to the global
            ATOMS_PER_MONOMER)
        swap_pairs: monomer-local (i, j) index pairs to swap when building
            candidates (default: ((0, 3),))

    Returns:
        R_reordered: Reordered positions matching PyCHARMM ordering
        Z_reordered: Reordered atomic numbers matching PyCHARMM ordering
        reorder_indices: Indices used for reordering

    Side effects:
        Sets PyCHARMM coordinates (coor.set_positions) for every candidate and
        finally to the best ordering.
    """
    import pandas as pd

    if n_monomers is None:
        n_monomers = N_MONOMERS
    if atoms_per_monomer is None:
        atoms_per_monomer = ATOMS_PER_MONOMER

    n_atoms = len(R)

    print("  Reordering atoms to match PyCHARMM ordering...")
    print(f"  Original R shape: {R.shape}, Z shape: {Z.shape}")

    base_indices = np.arange(n_atoms)
    candidate_orderings = _candidate_orderings(n_atoms, n_monomers, atoms_per_monomer, swap_pairs)

    print(f"  Trying {len(candidate_orderings)} different atom orderings...")

    # Evaluate each ordering by computing CHARMM internal energy
    best_energy = float('inf')
    best_indices = base_indices
    best_R = R
    best_Z = Z

    for i, reorder_indices in enumerate(candidate_orderings):
        try:
            R_test = R[reorder_indices]
            Z_test = Z[reorder_indices]

            # Push the candidate positions into PyCHARMM and evaluate energy
            xyz = pd.DataFrame(R_test, columns=["x", "y", "z"])
            coor.set_positions(xyz)
            energy.get_energy()
            inte_energy = energy.get_term_by_name("INTE")

            print(f"    Ordering {i+1}/{len(candidate_orderings)}: INTE = {inte_energy:.6f} kcal/mol")

            # Strict '<' keeps the first of equally good orderings
            if inte_energy < best_energy:
                best_energy = inte_energy
                best_indices = reorder_indices
                best_R = R_test
                best_Z = Z_test

        except Exception as e:
            print(f"    Ordering {i+1} failed: {e}")
            continue

    if not np.isfinite(best_energy):
        print("  Warning: no ordering produced a finite INTE energy; keeping the original atom order")
    print(f"  Best ordering found: INTE = {best_energy:.6f} kcal/mol")
    print(f"  Reorder indices: {best_indices}")

    # Leave PyCHARMM holding the chosen ordering
    xyz = pd.DataFrame(best_R, columns=["x", "y", "z"])
    coor.set_positions(xyz)

    return best_R, best_Z, best_indices


# ========================================================================
# INITIALIZE SIMULATIONS FROM VALID_DATA BATCHES
# ========================================================================
# Following run_sim.py structure, we'll initialize simulations using valid_data batches
# NOTE: PyCHARMM system should be set up in the previous cell before calling this

def initialize_simulation_from_batch(batch_idx=0):
    """
    Build an ASE ``Atoms`` object with a hybrid ML/MM calculator from one
    entry of ``valid_batches``.

    Args:
        batch_idx: which entry of ``valid_batches`` to read (default: 0).

    Returns:
        atoms: ASE Atoms object built from the batch positions/atomic numbers
        hybrid_calc: the hybrid calculator attached to ``atoms``

    Raises:
        ValueError: if the batch positions are neither 2-D nor 3-D.
        Exception: any error raised while computing the initial energy/forces
            is re-raised after printing diagnostics.
    """
    batch = valid_batches[batch_idx]
    positions, numbers = batch["R"], batch["Z"]

    # Reduce to a single configuration of shape (n_atoms, 3).
    if positions.ndim not in (2, 3):
        raise ValueError(f"Unexpected R shape: {positions.shape}")
    if positions.ndim == 3:
        # (batch_size, n_atoms, 3): keep the first configuration only
        positions, numbers = positions[0], numbers[0]

    # Trim to the expected system size if the batch is larger
    n_atoms_expected = ATOMS_PER_MONOMER * N_MONOMERS
    if len(positions) != n_atoms_expected:
        print(f"Warning: Expected {n_atoms_expected} atoms, got {len(positions)}")
        positions = positions[:n_atoms_expected]
        numbers = numbers[:n_atoms_expected]

    print(f"Initializing simulation from batch {batch_idx}")
    print(f"  Positions shape: {positions.shape}")
    print(f"  Atomic numbers shape: {numbers.shape}")
    print(f"  Number of atoms: {len(positions)}")

    # Reordering needs MM enabled and a successfully built PyCHARMM system.
    can_reorder = (
        args.include_mm
        and 'pycharmm_atypes' in globals()
        and pycharmm_atypes is not None
    )
    if not can_reorder:
        print(f"  No reordering applied (MM disabled or PyCHARMM not initialized)")
    else:
        positions, numbers, _ = reorder_atoms_to_match_pycharmm(
            positions, numbers, pycharmm_atypes, pycharmm_resids
        )
        print(f"  Atoms reordered to match PyCHARMM ordering")

    atoms = ase.Atoms(numbers, positions)

    # Keep PyCHARMM's coordinates identical to the ASE positions when MM is on.
    if args.include_mm:
        try:
            import pandas as pd
            coor.set_positions(pd.DataFrame(positions, columns=["x", "y", "z"]))
            print("  Synced positions with PyCHARMM")
        except Exception as e:
            print(f"  Warning: Could not sync positions with PyCHARMM: {e}")

    # Hybrid calculator (following run_sim.py); MM terms need PyCHARMM set up first.
    hybrid_calc, _ = calculator_factory(
        atomic_numbers=numbers,
        atomic_positions=positions,
        n_monomers=args.n_monomers,
        cutoff_params=CUTOFF_PARAMS,
        doML=True,
        doMM=args.include_mm,
        doML_dimer=not args.skip_ml_dimers,
        backprop=True,
        debug=args.debug,
        energy_conversion_factor=1,
        force_conversion_factor=1,
    )
    atoms.calc = hybrid_calc

    # Evaluate once up front so configuration problems surface immediately.
    try:
        hybrid_energy = float(atoms.get_potential_energy())
        hybrid_forces = np.asarray(atoms.get_forces())
        print(f"Initial energy: {hybrid_energy:.6f} eV")
        print(f"Initial forces shape: {hybrid_forces.shape}")
        print(f"Max force: {np.abs(hybrid_forces).max():.6f} eV/Å")
    except Exception as e:
        print(f"Warning: Could not compute initial energy/forces: {e}")
        print("This may be due to PyCHARMM not being properly initialized")
        print("or atom ordering mismatch. Check the reordering function.")
        raise

    return atoms, hybrid_calc

# Build the first simulation from validation batch 0
atoms, hybrid_calc = initialize_simulation_from_batch(batch_idx=0)

# Initialize Multiple Simulations from valid_data Batches

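This cell demonstrates how to initialize multiple simulations from different batches.
Each simulation has its own ASE `Atoms` object and calculator. However, PyCHARMM holds a single set of coordinates: after this cell it contains the positions of the last batch processed, so re-sync PyCHARMM before running CHARMM on any other simulation.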

In [None]:
# ========================================================================
# INITIALIZE MULTIPLE SIMULATIONS FROM VALID_DATA BATCHES
# ========================================================================
# Following run_sim.py structure, we can initialize multiple simulations

def initialize_multiple_simulations(n_simulations=5):
    """
    Initialize simulations from the first ``n_simulations`` valid_data batches.

    Batches that fail to initialize are reported and skipped (best effort), so
    the returned list may be shorter than requested.

    Args:
        n_simulations: Number of simulations to initialize (default: 5). It is
            capped at the number of available batches.

    Returns:
        List of (atoms, hybrid_calc) tuples, in batch order.

    Note:
        Each initialization pushes its positions into PyCHARMM, so afterwards
        PyCHARMM holds the coordinates of the last batch processed.
    """
    # Denominator for progress messages: the number actually attempted.
    n_to_init = min(n_simulations, len(valid_batches))
    simulations = []

    for batch_idx in range(n_to_init):
        try:
            sim_atoms, sim_calc = initialize_simulation_from_batch(batch_idx=batch_idx)
        except Exception as e:
            # Best effort: one bad batch should not stop the others.
            print(f"Warning: Failed to initialize simulation from batch {batch_idx}: {e}")
            continue
        simulations.append((sim_atoms, sim_calc))
        print(f"Successfully initialized simulation {batch_idx+1}/{n_to_init}")

    return simulations

# Set up to five simulations, one per valid_data batch (change n_simulations as needed).
simulations = initialize_multiple_simulations(n_simulations=5)
print(f"\nInitialized {len(simulations)} simulations from valid_data batches")

# Example: Run a Simple Energy Calculation

This demonstrates how to use the initialized simulations.

In [None]:
# ========================================================================
# EXAMPLE: RUN ENERGY CALCULATIONS
# ========================================================================
# NOTE: results are stored under example_* names. Assigning to `energy` here
# would rebind the pycharmm.energy module imported earlier and break every later
# CHARMM energy call (energy.show(), the INTE evaluation used for reordering).

# Example: Calculate energy for the first simulation
if len(simulations) > 0:
    atoms_example, calc_example = simulations[0]
    example_energy = atoms_example.get_potential_energy()
    example_forces = atoms_example.get_forces()
    # Largest per-atom force vector norm (not the largest single component).
    max_force_norm = np.linalg.norm(example_forces, axis=1).max()
    print(f"Example simulation energy: {example_energy:.6f} eV")
    print(f"Example simulation forces shape: {example_forces.shape}")
    print(f"Max force magnitude: {max_force_norm:.6f} eV/Å")
else:
    print("No simulations initialized. Check batch data and system parameters.")

In [None]:
# Next Steps: Running MD Simulations
#
# To run MD simulations following `run_sim.py`, you can:
# 1. Use the `minimize_structure` function from run_sim.py
# 2. Use the `run_ase_md` function for ASE-based MD
# 3. Use JAX-MD for more advanced simulations
#
# See `run_sim.py` for complete MD simulation setup.

In [None]:
# ========================================================================
# HELPER FUNCTIONS (from run_sim.py)
# ========================================================================
# These functions can be copied from run_sim.py for running MD simulations

def minimize_structure(atoms, run_index=0, nsteps=60, fmax=0.0006, charmm=False, calculator=None):
    """
    Minimize structure using BFGS optimizer (from run_sim.py)

    Args:
        atoms: ASE Atoms object (must have calculator set, or provide calculator)
        run_index: Index for trajectory file naming
        nsteps: Maximum number of optimization steps
        fmax: Force convergence criterion
        charmm: If True, run CHARMM (ABNR) minimization first, starting from
            the positions of ``atoms``
        calculator: Optional calculator to set if atoms doesn't have one

    Returns:
        The same ``atoms`` object, with minimized positions. PyCHARMM's
        coordinates are synced to them on success.

    Raises:
        RuntimeError: if ``atoms`` has no calculator, none is given, and one
            cannot be created with ``calculator_factory``.
    """
    import pandas as pd

    def _push_positions_to_charmm():
        # PyCHARMM keeps one global coordinate set; copy the ASE positions into it.
        xyz = pd.DataFrame(atoms.get_positions(), columns=["x", "y", "z"])
        coor.set_positions(xyz)

    # Ensure calculator is set
    if atoms.calc is None:
        if calculator is not None:
            atoms.calc = calculator
        else:
            # Try to create calculator from atoms
            Z = atoms.get_atomic_numbers()
            R = atoms.get_positions()
            try:
                calc, _ = calculator_factory(
                    atomic_numbers=Z,
                    atomic_positions=R,
                    n_monomers=args.n_monomers,
                    cutoff_params=CUTOFF_PARAMS,
                    doML=True,
                    doMM=args.include_mm,
                    doML_dimer=not args.skip_ml_dimers,
                    backprop=True,
                    debug=args.debug,
                    energy_conversion_factor=1,
                    force_conversion_factor=1,
                )
                atoms.calc = calc
                print("  Created calculator for minimization")
            except Exception as e:
                raise RuntimeError(f"Cannot minimize: atoms has no calculator and cannot create one: {e}")

    if charmm:
        # Start CHARMM from *these* atoms: PyCHARMM's coordinates may belong to
        # whichever structure last synced them (e.g. another simulation).
        _push_positions_to_charmm()
        pycharmm.minimize.run_abnr(nstep=1000, tolenr=1e-6, tolgrd=1e-6)
        pycharmm.lingo.charmm_script("ENER")
        pycharmm.energy.show()
        atoms.set_positions(coor.get_positions())

    traj = ase_io.Trajectory(f'bfgs_{run_index}_{args.output_prefix}_minimized.traj', 'w')
    try:
        print("Minimizing structure with hybrid calculator")
        print(f"Running BFGS for {nsteps} steps")
        print(f"Running BFGS with fmax: {fmax}")
        _ = ase_opt.BFGS(atoms, trajectory=traj).run(fmax=fmax, steps=nsteps)
        # Sync with PyCHARMM
        _push_positions_to_charmm()
        traj.write(atoms)
    finally:
        # Close the trajectory even if the optimizer raises.
        traj.close()
    return atoms

# Example: minimize the first initialized simulation.
# Set RUN_MINIMIZATION = True to actually run BFGS (it writes a .traj file).
RUN_MINIMIZATION = False

if len(simulations) > 0:
    # Get atoms and calculator from the simulation
    atoms_to_minimize, calc_to_minimize = simulations[0]
    # Atoms.copy() does not carry the calculator over, so re-attach it.
    atoms_to_minimize = atoms_to_minimize.copy()
    atoms_to_minimize.calc = calc_to_minimize
    if RUN_MINIMIZATION:
        print("Running minimization...")
        print("Note: Calculator is preserved from the initialized simulation")
        atoms_minimized = minimize_structure(atoms_to_minimize, run_index=0, nsteps=100, fmax=0.0006)
    else:
        print("Skipping minimization (set RUN_MINIMIZATION = True to run it)")


# Notes on Residue Numbers and Atom Ordering

When setting up PyCHARMM simulations:

**Residue Setup:**
- Use `setupRes.generate_residue("ACO ACO")` to generate residues (for 2 acetone molecules)
- Use `ic.build()` to build the structure
- The number of residues should match `N_MONOMERS`

**Atom Ordering:**
- PyCHARMM has a specific atom ordering based on residue and atom type
- The `valid_data` batch atoms **must be reordered** to match PyCHARMM's ordering
- The `reorder_atoms_to_match_pycharmm()` function tries several candidate orderings and keeps the one with the **lowest CHARMM internal energy** (`energy.get_term_by_name("INTE")`)
- Candidate swaps tested: indices 0↔3, 10↔13, both together, and a first↔last atom swap within each monomer
- The best ordering is chosen automatically, but only among these candidates: if the correct ordering is not in the list, it will not be found

**To customize reordering:**
1. Add more swap patterns to the `candidate_orderings` list in `reorder_atoms_to_match_pycharmm()`
2. The function will automatically test all candidates and select the one with minimum INTE energy
3. Example swap (0↔3): `swap = base_indices.copy(); swap[0] = base_indices[3]; swap[3] = base_indices[0]`
4. The energy-based selection picks the best of the listed candidates; it cannot find an ordering that is not in `candidate_orderings`


In [None]:
# ========================================================================
# SUMMARY
# ========================================================================
# One print call, one line per item (sep="\n" reproduces separate prints).
print(
    "=" * 60,
    "Simulation Setup Complete",
    "=" * 60,
    f"Number of simulations initialized: {len(simulations)}",
    f"Number of atoms per simulation: {ATOMS_PER_MONOMER * N_MONOMERS}",
    f"Number of monomers: {N_MONOMERS}",
    f"Atoms per monomer: {ATOMS_PER_MONOMER}",
    f"ML cutoff: {args.ml_cutoff} Å",
    f"MM switch on: {args.mm_switch_on} Å",
    f"MM cutoff: {args.mm_cutoff} Å",
    f"Valid data batches available: {len(valid_batches)}",
    "=" * 60,
    "\nTo run MD simulations, use the helper functions or refer to run_sim.py",
    "Note: Residue numbers may need adjustment based on your system",
    sep="\n",
)


In [None]:
# Print the current CHARMM energy terms. The fully-qualified module is used so
# this keeps working even if a cell has rebound the short name `energy`
# (e.g. to a float holding an ASE energy).
pycharmm.energy.show()

In [None]:
# Raw positions and atomic numbers of the first validation batch.
R, Z = (valid_batches[0][key] for key in ("R", "Z"))
R, Z