! uv pip install jax[cuda12] --force-reinstall
# Imports

In [1]:
import os
from pathlib import Path

# Configure the GPU runtime BEFORE importing anything that might pull in JAX.
# These variables are read when the XLA/CUDA backend is initialised; setting
# them after a package has already initialised JAX has no effect.
# NOTE(review): mmml presumably imports JAX internally - not visible from here.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".45"  # fraction of GPU memory JAX preallocates
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # expose only physical GPU 1 to this process

import mmml
import ase
import jax

# Optional: enable 64-bit floating point in JAX.
# from jax import config
# config.update('jax_enable_x64', True)

# Check JAX configuration: the devices listed here should reflect the
# CUDA_VISIBLE_DEVICES setting above.
devices = jax.local_devices()
print(devices)
print(jax.default_backend())
print(jax.devices())


[CudaDevice(id=0)]
gpu
[CudaDevice(id=0)]


In [2]:
from mmml.physnetjax import *

In [3]:
from mmml.physnetjax.physnetjax.calc.helper_mlp import get_ase_calc
reshifted = False

In [4]:
from mmml.physnetjax.physnetjax.models import model as model
from mmml.physnetjax.physnetjax.models.model import EF
from mmml.physnetjax.physnetjax.training.training import train_model


# Data

In [5]:
from mmml.physnetjax.physnetjax.data.data import prepare_datasets
from mmml.physnetjax.physnetjax.data.batches import prepare_batches_jit

In [6]:
data_key, train_key = jax.random.split(jax.random.PRNGKey(42), 2)
BATCHSIZE = 128

In [7]:
# Initialize random key for data loading
if 'data_key' not in globals():
    data_key = jax.random.PRNGKey(42)


data_file = "/pchem-data/meuwly/boittier/home/mmml/mmml/data/fixed-acetone-only_MP2_21000.npz"

print(f"Loading data from: {data_file}")

# Prepare datasets
train_data, valid_data = prepare_datasets(
    data_key, 
    10500,  # num_train
    10500,  # num_valid
    [data_file], 
    natoms=20
)


Loading data from: /pchem-data/meuwly/boittier/home/mmml/mmml/data/fixed-acetone-only_MP2_21000.npz
dataR (21000, 20, 3)
dataE [-81.79712432 -81.48244884 -81.38548297 -81.44645775 -81.74704898
 -81.67295344 -81.32876002 -81.82201676 -81.8124061  -81.80508929]
dataE [-81.79712432 -81.48244884 -81.38548297 -81.44645775 -81.74704898
 -81.67295344 -81.32876002 -81.82201676 -81.8124061  -81.80508929]
D (21000, 3)
Q 1 (21000,) 21000
Q (21000,)


In [8]:
files = [data_file]
train_size = 20000 
valid_size = 1000
NATOMSMAX = 20

train_data, valid_data = prepare_datasets(data_key, train_size, valid_size, files, natoms=NATOMSMAX)

dataR (21000, 20, 3)
dataE [-81.79712432 -81.48244884 -81.38548297 -81.44645775 -81.74704898
 -81.67295344 -81.32876002 -81.82201676 -81.8124061  -81.80508929]
dataE [-81.79712432 -81.48244884 -81.38548297 -81.44645775 -81.74704898
 -81.67295344 -81.32876002 -81.82201676 -81.8124061  -81.80508929]
D (21000, 3)
Q 1 (21000,) 21000
Q (21000,)


In [9]:
import openqdc

In [10]:
train_data, valid_data = prepare_datasets(data_key, train_size, valid_size, files, natoms=NATOMSMAX)



valid_batches = prepare_batches_jit(data_key, valid_data, BATCHSIZE, num_atoms = NATOMSMAX)

dataR (21000, 20, 3)
dataE [-81.79712432 -81.48244884 -81.38548297 -81.44645775 -81.74704898
 -81.67295344 -81.32876002 -81.82201676 -81.8124061  -81.80508929]
dataE [-81.79712432 -81.48244884 -81.38548297 -81.44645775 -81.74704898
 -81.67295344 -81.32876002 -81.82201676 -81.8124061  -81.80508929]
D (21000, 3)
Q 1 (21000,) 21000
Q (21000,)


In [11]:
atom_energies = {1: -13.717939590030356 ,
6: -1029.831662730747 ,
7: -1485.40806126101 ,
8: -2042.7920344362644 ,
16: -10831.264715514206 ,}

In [12]:
XXX = train_data["E"] / (train_data["Z"].sum(axis=1))
XXX .flatten().flatten().mean()

np.float64(-1.3306833009327976)

## Save Checkpoint as JSON (for portability)

After training, you can save checkpoints as JSON files for easy loading without requiring orbax or pickle. This is useful for sharing models or loading in different environments.


In [13]:
# ========================================================================
# SAVE CHECKPOINT AS JSON (no orbax/pickle required for loading)
# ========================================================================
# This function converts JAX parameters to JSON-serializable format
# and saves them along with model configuration

def save_checkpoint_as_json(params, model, save_dir, epoch=None, best_loss=None):
    """
    Save a model checkpoint as plain JSON files for portability.

    Files written into ``save_dir`` (created if missing):
    - params.json: ``params`` with every array leaf converted to nested lists
    - model_config.json: model configuration (see below)
    - metadata.json: only when ``epoch`` and/or ``best_loss`` is given

    The model configuration comes from ``model.return_attributes()`` when the
    model provides that method; otherwise a fixed list of known attribute
    names is read from the model instance. In both cases the values are
    converted to JSON-compatible Python types before being written.

    Args:
        params: Model parameters (nested mappings/lists of arrays).
        model: Model instance.
        save_dir: Directory to save checkpoint files (str or Path).
        epoch: Optional epoch number; stored as ``int``.
        best_loss: Optional best loss value (array or scalar); stored as ``float``.

    Returns:
        pathlib.Path: the directory the files were written to.
    """
    import json
    from collections.abc import Mapping
    from pathlib import Path

    # NumPy is imported locally so the function does not depend on a global
    # `jnp`/`np` having been defined by some other notebook cell.
    import numpy as np

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    def jax_to_json(obj):
        """Recursively convert arrays/containers to JSON-serializable objects."""
        # Mapping (not just dict) so that dict-like containers are walked
        # instead of being stringified.
        if isinstance(obj, Mapping):
            return {k: jax_to_json(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [jax_to_json(item) for item in obj]
        if hasattr(obj, '__array__'):  # JAX/NumPy arrays and NumPy scalars
            return np.asarray(obj).tolist()
        if isinstance(obj, (int, float, str, bool, type(None))):
            return obj
        # Best effort for anything else: accept it only if it converts to a
        # plain numeric/bool array, otherwise fall back to its string form.
        try:
            arr = np.asarray(obj)
        except Exception:
            return str(obj)
        if arr.dtype.kind in "biuf":
            return arr.tolist()
        return str(obj)

    # Save parameters as JSON
    params_json = jax_to_json(params)
    params_path = save_dir / "params.json"

    print(f"Saving parameters to: {params_path}")
    with open(params_path, 'w') as f:
        json.dump(params_json, f, indent=2)
    print(f"  ✓ Saved {params_path}")

    # Extract model configuration
    model_config = {}

    if hasattr(model, 'return_attributes'):
        # PhysNet EF models. The returned values are sanitised because
        # json.dump rejects NumPy/JAX scalars and arrays.
        model_config = jax_to_json(model.return_attributes())
    elif hasattr(model, '__dict__'):
        # Fall back to reading known configuration attributes one by one.
        config_attrs = [
            'features', 'cutoff', 'max_degree', 'num_iterations',
            'num_basis_functions', 'max_atomic_number', 'n_res',
            'zbl', 'efa', 'charges', 'natoms', 'total_charge'
        ]
        for attr in config_attrs:
            if not hasattr(model, attr):
                continue
            value = getattr(model, attr)
            if hasattr(value, '__array__'):
                # Single-element arrays are stored as a float, larger ones as lists.
                arr = np.asarray(value)
                value = float(arr.item()) if arr.size == 1 else arr.tolist()
            else:
                value = jax_to_json(value)
            model_config[attr] = value
    else:
        print("Warning: Could not extract model configuration")

    # Save model config as JSON
    config_path = save_dir / "model_config.json"
    print(f"Saving model config to: {config_path}")
    with open(config_path, 'w') as f:
        json.dump(model_config, f, indent=2)
    print(f"  ✓ Saved {config_path}")

    # Optionally save metadata
    has_metadata = epoch is not None or best_loss is not None
    if has_metadata:
        metadata = {}
        if epoch is not None:
            metadata['epoch'] = int(epoch)
        if best_loss is not None:
            # Arrays are unwrapped to a Python float first.
            if hasattr(best_loss, '__array__'):
                metadata['best_loss'] = float(np.asarray(best_loss).item())
            else:
                metadata['best_loss'] = float(best_loss)

        metadata_path = save_dir / "metadata.json"
        print(f"Saving metadata to: {metadata_path}")
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        print(f"  ✓ Saved {metadata_path}")

    print(f"\n✓ Checkpoint saved as JSON in: {save_dir}")
    print(f"  Files created:")
    print(f"    - {params_path.name}")
    print(f"    - {config_path.name}")
    if has_metadata:
        print(f"    - metadata.json")
    print(f"\n  These files can be loaded without orbax or pickle!")

    return save_dir


In [14]:
model = EF(
    features=128,
    max_degree = 1,
    num_basis_functions=64,
    num_iterations = 3,
    n_res=3,
    cutoff = 12.0,
    max_atomic_number = 40,
    zbl=True,
    efa=False,
    charges=True,
    debug=False
)
model

EF(
    # attributes
    features = 128
    max_degree = 1
    num_iterations = 3
    num_basis_functions = 64
    cutoff = 12.0
    max_atomic_number = 40
    charges = True
    natoms = 60
    total_charge = 0
    n_res = 3
    zbl = True
    debug = False
    efa = False
    use_energy_bias = True
)

## Training

In [15]:
do_training = True
if do_training:
    # uid = "test-84aa02d9-e329-46c4-b12c-f55e6c9a2f94"
    uid = "pyhsnetacetone-d38b2d5c-b24d-432b-83b4-801ff726dbde"
    uid = "eq_acetone-dc858977-288b-447a-a877-7801923bac47"
    # SCICORE = Path('/scicore/home/meuwly/boitti0000/ckpts')
    SCICORE = Path("/pchem-data/meuwly/boittier/home/ckpts")
    RESTART=str(SCICORE / f"{uid}")
    params_out = train_model(
        train_key,
        model,
        train_data,
        valid_data, 
        num_epochs = 5000,
        learning_rate=0.001,
        batch_size=BATCHSIZE,
        num_atoms=NATOMSMAX,
        energy_weight=1,
        restart=RESTART,
        conversion={'energy': 1, 'forces': 1},
        print_freq=1,
        name='eq_acetone',
        best=False,
        optimizer=None,
        transform=None,
        schedule_fn="constant",
        objective='valid_loss',
        ckpt_dir=SCICORE,
        log_tb=False,
        batch_method="default",
        batch_args_dict=None,
        data_keys=('R', 'Z', 'F', "N", 'E', 'D', 'batch_segments'),
        
    )

Using default (fat) batching method


Extra Validation Info:
Z: Array[1000, 20] i32 n=20000 (78Kb) x∈[0, 8] μ=3.038 σ=2.769 cpu:0
R: Array[1000, 20, 3] n=60000 (0.2Mb) x∈[-17.661, 18.309] μ=-1.848 σ=7.122 cpu:0
E: Array[1000, 1] 3.9Kb x∈[-82.149, -40.481] μ=-77.528 σ=12.309 cpu:0
N: Array[1000, 1] i32 3.9Kb x∈[10, 20] μ=18.990 σ=3.013 cpu:0
F: Array[1000, 20, 3] n=60000 (0.2Mb) x∈[-4.446, 4.514] μ=3.099e-10 σ=0.959 cpu:0
D: Array[1000, 3] n=3000 (12Kb) x∈[-0.419, 0.412] μ=-0.014 σ=0.167 cpu:0


ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/pchem-data/meuwly/boittier/home/.conda/envs/mmml-gpu/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x1486a58ead80> is already entered
ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/pchem-data/meuwly/boittier/home/.conda/envs/mmml-gpu/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x1486a58ead80> is already entered
ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/pchem-data/meuwly/boittier/home/.conda/envs/mmml-gpu/lib/python3.12/asyncio/events.py", line 88, in _run
  

dict_keys(['opt_state', 'params', 'step'])


Restoring from /pchem-data/meuwly/boittier/home/ckpts/eq_acetone-dc858977-288b-447a-a877-7801923bac47/epoch-145
Restored keys: dict_keys(['best_loss', 'ema_params', 'epoch', 'lr_eff', 'model', 'model_attributes', 'objectives', 'opt_state', 'params', 'transform_state'])
Training resumed from step 145, best_loss Array gpu:0 6.380e+07




Output()


KeyboardInterrupt



### Example: Save loaded checkpoint as JSON

After loading a checkpoint (e.g., from orbax), you can save it as JSON for easier loading later:


In [None]:
from mmml.physnetjax.physnetjax.restart.restart import get_last, get_params_model, get_params_model_with_ase

In [None]:
uid = "pyhsnetacetone-d38b2d5c-b24d-432b-83b4-801ff726dbde"
uid = "eq_acetone-46a8cd1d-880c-427b-8c3f-c206c3b75a19"
SCICORE = Path('/scicore/home/meuwly/boitti0000/')
SCICORE = Path("/pchem-data/meuwly/boittier/home/ckpts")
RESTART=str(SCICORE / f"{uid}")
RESTART

# │ /pchem-data/meuwly/boittier/home/ckpts/eq_acetone-a114f328-a678-4999-904a-ef8ff78a7eb1/epoch-9 │ 20:46:28  │     


# │ /pchem-data/meuwly/boittier/home/ckpts/eq_acetone-472e0e91-11d3-4fc9-b773-ab5a68b4cc42/epoch-10 │ 20:43:48  │    


In [None]:
last= get_last(RESTART)
last 

In [None]:
params, model, everything = get_params_model(last, return_everything=True)
params, model, everything

In [None]:
# Example: Save the loaded checkpoint as JSON
# This converts the orbax checkpoint to JSON format for portability
import jax.numpy as jnp
# Save the loaded params and model as JSON
json_save_dir = last / "json_checkpoint"
save_checkpoint_as_json(
    params=params,
    model=model,
    save_dir=json_save_dir,
    epoch=everything.get('epoch', None) if 'everything' in locals() else None,
    best_loss=everything.get('best_loss', None) if 'everything' in locals() else None
)

# Now you can load this checkpoint later using:
# from 3-sim.ipynb: load_model_parameters_json(json_save_dir, natoms=NATOMSMAX)
json_save_dir


In [None]:
list(json_save_dir.glob("*"))

In [None]:
N = valid_batches[0]["N"][0]
R = valid_batches[0]["R"][:N]
Z = valid_batches[0]["Z"][:N]
atoms=ase.Atoms(Z , R)
atoms

In [None]:
get_params_model_with_ase?

# Validation

In [None]:
from mmml.physnetjax.physnetjax.analysis.analysis import  *
model.natoms = 20
output = mmml.physnetjax.physnetjax.analysis.analysis.eval(valid_batches, model, params, batch_size=BATCHSIZE)
Es, Eeles, predEs, Fs, predFs, Ds, predDs, charges, outputs = output


In [None]:
ase_kcalmol = ase.units.kcal/ase.units.mol
1/ase_kcalmol

In [None]:
# Coarse histogram of the reference energies; drawn for visual comparison only,
# its return value is not used.
plt.hist(Es, bins=20)
# Fine histogram of the predicted energies; its counts and edges drive the
# grouping below.
bin_count_edges = plt.hist(predEs, bins=2000)
counts, edges = bin_count_edges[0], bin_count_edges[1]

# Build group boundaries: start at the lowest edge and add a separator each
# time a populated bin is followed by an empty one. The separator sits at
# edges[i], the right edge of the last populated bin, so that bin's data stays
# in the group to its left. Starting at i = 1 avoids comparing the first bin
# against the last one through negative indexing.
bins = [edges[0]]
for i in range(1, len(counts)):
    if counts[i] == 0 and counts[i - 1] != 0:
        bins.append(edges[i])
        plt.axvline(bins[-1])
bins.append(edges[-1])
bins
    

In [None]:
Es

In [None]:
# One parity plot (reference vs. predicted energy, converted to kcal/mol) per
# energy window defined by consecutive entries of `bins`. Windows are open
# intervals; empty windows are skipped.
for i, (lower, upper) in enumerate(zip(bins[:-1], bins[1:])):
    monomers_idx = (Es > lower) & (Es < upper)
    if np.any(monomers_idx):
        print(bins)
        ax = plt.gca()
        plot(
            Es[monomers_idx]/ase_kcalmol,
            predEs[monomers_idx]/ase_kcalmol,
            ax,
            units="kcal/mol",
            _property="",
            kde=False,
            s=10,
            diag=True,
        )
        plt.show()
    plt.show()

In [None]:
ax = plt.gca()
plot(Es/ase_kcalmol, predEs/ase_kcalmol, ax, units="kcal/mol", _property="", kde=True, s=10, diag=True)

In [None]:
ax = plt.gca()
plot(Fs/ase_kcalmol, predFs/ase_kcalmol, ax, units="kcal/mol", _property="", kde=True, s=1, diag=True)

In [None]:
# Parity plot of reference vs. predicted dipole components.
# The dipoles are plotted unscaled: ase_kcalmol is an energy conversion factor
# and does not apply to a quantity whose axis is labelled in e Å.
# The label is a raw string because "\A" is not a valid escape sequence in a
# normal string literal (newer Python versions warn about it).
ax = plt.gca()
plot(Ds, predDs, ax, units=r"e $\AA$", _property="", kde=True, s=1, diag=True)

# Calculator

In [None]:
# !conda install pint
from mmml.pycharmmInterface import import_pycharmm
import pycharmm
from mmml.pycharmmInterface.mmml_calculator import setup_calculator, CutoffParameters

In [None]:
?CutoffParameters

In [None]:
ATOMS_PER_MONOMER = 10
N_MONOMERS = 2

In [None]:
?setup_calculator

In [None]:
calculator_factory = setup_calculator(
    ATOMS_PER_MONOMER,
    N_MONOMERS,
    ml_cutoff_distance  = 0.01,
    mm_switch_on = 8.0,
    mm_cutoff  = 5.0,
    doML = True,
    doMM  = True,
    doML_dimer  = True,
    debug  = False,
    ep_scale = None,
    sig_scale = None,
    model_restart_path = RESTART,
    MAX_ATOMS_PER_SYSTEM = 20,
)

In [None]:
from ase.visualize.plot import plot_atoms
from mmml.pycharmmInterface import import_pycharmm
from mmml.pycharmmInterface.import_pycharmm import  *

In [None]:
from mmml.pycharmmInterface import setupRes, setupBox
from mmml.pycharmmInterface.import_pycharmm import reset_block, reset_block_no_internal
from mmml.pycharmmInterface.pycharmmCommands import CLEAR_CHARMM

In [None]:
CLEAR_CHARMM()
reset_block()
reset_block_no_internal()
reset_block()
reset_block()
reset_block_no_internal()
reset_block()

In [None]:
train_data, valid_data = prepare_datasets(data_key, 10500, 10500, [ "/pchem-data/meuwly/boittier/home/mmml/mmml/data/fixed-acetone-only_MP2_21000.npz"], natoms=20)
valid_batches = prepare_batches_jit(data_key, valid_data, 1, num_atoms = 20)
train_batches = prepare_batches_jit(data_key, train_data, 1, num_atoms = 20)

In [None]:
setupBox.initialize_psf("ACO", 2, 30, None)

In [None]:
mmml.pycharmmInterface.import_pycharmm.view_pycharmm_state()

In [None]:
energy.show()

In [None]:
R = valid_batches[0]["R"]
Z = valid_batches[0]["Z"]
R,Z

In [None]:
ase_atoms = ase.Atoms(Z, R)

In [None]:
FACTOR = 1/(ase.units.kcal/ase.units.mol)
calcs = calculator_factory(atomic_numbers=Z , atomic_positions=R , n_monomers=2,
                           # energy_conversion_factor=FACTOR, force_conversion_factor=FACTOR, 
                           debug=False
                          )

In [None]:
ase_atoms.calc = calcs[0]

In [None]:
ase_atoms.get_potential_energy()

In [None]:
calculator_factory?

In [None]:
pycharmm_quiet()
pycharmm_verbose()

In [None]:
# Atom-index permutation over 20 slots: identity except that positions
# 0 <-> 3 and 10 <-> 13 are exchanged.
_fix_idxs = np.arange(20)
fix_idxs = _fix_idxs.copy()
fix_idxs[[0, 3, 10, 13]] = _fix_idxs[[3, 0, 13, 10]]
# Index of the validation batch used by the comparison cells that follow.
batch_index = 0

In [None]:
test_R = valid_batches[batch_index]["R"][:20][fix_idxs]

In [None]:
test_Z = valid_batches[batch_index]["Z"][:20]
test_Z, test_Z[fix_idxs]

In [None]:
model.natoms = 20
model

In [None]:
ref_physnet_atoms = ase.Atoms(test_Z[fix_idxs], test_R)
ref_physnet_atoms.calc = get_ase_calc(params, model, ref_physnet_atoms, {"energy": 1, "forces": 1}, ['energy', 'forces'])

In [None]:
ref_physnet_atoms.get_potential_energy()

In [None]:
ref_physnet_F = ref_physnet_atoms.get_forces().flatten()

In [None]:
ref_e = valid_batches[batch_index]["E"][0][0]
ref_e

In [None]:
ref_f = valid_batches[batch_index]["F"][:20]
refF = ref_f.flatten()

In [None]:
ase_atoms.set_positions(test_R)

In [None]:
ase_atoms.set_atomic_numbers(ref_physnet_atoms.get_atomic_numbers())

In [None]:
ase_atoms.get_potential_energy()

In [None]:
mF = np.array(ase_atoms.get_forces()).flatten()

In [None]:
ks = [
 'dH',
 'energy',
 'forces',
 'internal_E',
 'internal_F',
 'ml_2b_E',
 'ml_2b_F',
 'mm_E',
 'mm_F']
di = {}
for k in ks:
    di[k] = dict(ase_atoms.calc.results)["out"].__getattribute__(k).flatten()

In [None]:
dir(dict(ase_atoms.calc.results)["out"])
di

In [None]:
view_atoms(ref_physnet_atoms)

In [None]:
plt.scatter(refF, ref_physnet_F)

In [None]:
plt.scatter(mF,refF)

In [None]:
plt.scatter(refF,di["mm_F"])

In [None]:
plt.scatter(mF,di["mm_F"])

In [None]:
np.array(dict(ase_atoms.calc.results)["out"].internal_E)*FACTOR + dict(ase_atoms.calc.results)["out"].mm_E

In [None]:
np.array(dict(ase_atoms.calc.results)["out"].internal_F)

In [None]:
view_atoms(ase_atoms)

In [None]:
setupRes.generate_residue("ACO ACO")
ic.build()
coor.show()

In [None]:
test_R

In [None]:
xyz = pd.DataFrame(test_R, columns=["x", "y", "z"])
coor.set_positions(xyz)

In [None]:
# energy.show()

In [None]:
# coor.show()

In [None]:
nbonds = """!#########################################
! Bonded/Non-bonded Options & Constraints
!#########################################

! Non-bonding parameters
nbonds atom cutnb 14.0  ctofnb 12.0 ctonnb 10.0 -
vswitch NBXMOD 3 -
inbfrq -1 imgfrq -1
"""
pycharmm.lingo.charmm_script(nbonds)

In [None]:
pycharmm_quiet()
energy.show()

In [None]:
energy.get_term_by_name("VDW")

In [None]:
energy.get_term_by_name("ELEC")

In [None]:
mmml.pycharmmInterface.import_pycharmm.view_pycharmm_state()

In [None]:
pycharmm_verbose()
energy.show()

In [None]:
CLEAR_CHARMM()

In [None]:
atoms = setupRes.main("ACO")

In [None]:
atoms = setupRes.generate_coordinates()
_ = setupRes.coor.get_positions()
atoms.set_positions(_)
reset_block()
reset_block_no_internal()
reset_block()

In [None]:
mmml.pycharmmInterface.import_pycharmm.view_pycharmm_state()

In [None]:
atoms = setupRes.generate_coordinates()
_ = setupRes.coor.get_positions()
atoms.set_positions(_)
reset_block()
reset_block_no_internal()
reset_block()


In [None]:
mmml.pycharmmInterface.import_pycharmm.view_pycharmm_state()

In [None]:
energy.show()

In [None]:
energy.get_term_by_name("VDW")

In [None]:
pycharmm.lingo.get_energy_value("ENER")

# Example: packmol for a dimer system

In [None]:
run_packmol(20, 30)

In [None]:
def CLEAR_CHARMM():
    """Send two CHARMM script commands in order: delete all atoms, then delete the PSF.

    NOTE(review): this definition replaces the CLEAR_CHARMM imported earlier
    in the notebook from mmml.pycharmmInterface.pycharmmCommands - confirm
    that overriding it is intended.
    """
    for command in ("DELETE ATOM SELE ALL END", "DELETE PSF SELE ALL END"):
        pycharmm.lingo.charmm_script(command)

CLEAR_CHARMM()

In [None]:
reset_block()
reset_block_no_internal()
reset_block()
mmml.pycharmmInterface.import_pycharmm.view_pycharmm_state()

In [None]:
# pycharmm_verbose()

In [None]:
# pycharmm.lingo.charmm_script(nbonds)

# #equivalent CHARMM scripting command: minimize abnr nstep 1000 tole 1e-3 tolgr 1e-3
# minimize.run_abnr(nstep=1000, tolenr=1e-1, tolgrd=1e-1)
# #equivalent CHARMM scripting command: energy
# energy.show()

In [None]:
mmml.pycharmmInterface.import_pycharmm.view_pycharmm_state()

In [None]:
ase_atoms.set_positions(coor.get_positions())
view_atoms(ase_atoms)

In [None]:
cs

In [None]:
ase_atoms.get_potential_energy()

In [None]:
ase_atoms.calc.results #["out"]

In [None]:
ase_atoms.calc.results["out"].mm_E

In [None]:
# Run structure optimization with BFGS.
_ = ase_opt.BFGS(ase_atoms).run(fmax=0.001)

In [None]:
view_atoms(ase_atoms)

In [None]:
# Simulation settings.
temperature = 10   # passed as temperature_K when drawing momenta
timestep_fs = 0.1  # multiplied by ase.units.fs for the integrator
num_steps = 300

# Initial momenta: Maxwell-Boltzmann draw, then remove centre-of-mass
# translation and overall rotation.
MaxwellBoltzmannDistribution(ase_atoms, temperature_K=temperature)
Stationary(ase_atoms)
ZeroRotation(ase_atoms)

# Velocity Verlet integrator.
integrator = VelocityVerlet(ase_atoms, timestep=timestep_fs*ase.units.fs)

# Storage for the trajectory and the three energy traces.
frames = np.zeros((num_steps, len(ase_atoms), 3))
potential_energy, kinetic_energy, total_energy = (
    np.zeros((num_steps,)) for _ in range(3)
)

# Advance one step at a time so every frame can be recorded.
for i in range(num_steps):
    integrator.run(1)
    frames[i] = ase_atoms.get_positions()
    potential_energy[i] = ase_atoms.get_potential_energy()
    kinetic_energy[i] = ase_atoms.get_kinetic_energy()
    total_energy[i] = ase_atoms.get_total_energy()
    # Report progress every 100th step.
    if not i % 100:
        print(f"step {i:5d} epot {potential_energy[i]: 5.3f} ekin {kinetic_energy[i]: 5.3f} etot {total_energy[i]: 5.3f}")

In [None]:
# # Visualize the structure with py3Dmol.
# view = py3Dmol.view()
# xyz = io.StringIO()
# ase_io.write(xyz, ase_atoms, format='xyz')
# view.addModel(xyz.getvalue(), 'xyz')
# view.setStyle({'stick': {'radius': 0.15}, 'sphere': {'scale': 0.25}})
# view.show()
# view.getModel().setCoordinates(frames[::100], 'array')
# view.animate({'loop': 'forward', 'interval': 0.1})
# view.show() 

In [None]:
%matplotlib inline
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
plt.xlabel('time [fs]')
plt.ylabel('energy [eV]')
time = np.arange(num_steps) * timestep_fs
plt.plot(time, potential_energy, label='potential energy')
plt.plot(time, kinetic_energy, label='kinetic energy')
plt.plot(time, total_energy, label='total energy')
plt.legend()
plt.grid()

In [None]:
view_atoms(ase_atoms)