! uv pip install jax[cuda12] --force-reinstall
# Imports

In [1]:
import os
from pathlib import Path

# Configure the GPU runtime BEFORE importing anything that may pull in JAX.
# These variables are only honoured if they are set before the JAX/CUDA
# backend is initialised; `mmml` may import (and possibly initialise) JAX at
# import time, so it must come after this block.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".99"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import ase
import jax
import mmml

# from jax import config
# config.update('jax_enable_x64', True)

# Check JAX configuration: local devices, active backend, and all devices.
devices = jax.local_devices()
print(devices)
print(jax.default_backend())
print(jax.devices())


[CudaDevice(id=0)]
gpu
[CudaDevice(id=0)]


In [2]:
import os
import jax
import jax.numpy as jnp

# Smoke test: report the dynamic-linker search path, then multiply two
# 1024x1024 float32 matrices and wait for the result so that a broken
# backend fails here rather than later during training.
print("LD_LIBRARY_PATH:", os.environ.get("LD_LIBRARY_PATH"))
x = jnp.ones((1024, 1024), dtype=jnp.float32)
y = jnp.ones((1024, 1024), dtype=jnp.float32)
print(jnp.matmul(x, y).block_until_ready().shape)

LD_LIBRARY_PATH: None
(1024, 1024)


In [3]:
from mmml.physnetjax import *

In [4]:
from mmml.physnetjax.physnetjax.calc.helper_mlp import get_ase_calc


In [5]:
?get_ase_calc

[31mSignature:[39m
get_ase_calc(
    params,
    model,
    ase_mol,
    conversion={[33m'energy'[39m: [32m1[39m, [33m'forces'[39m: [32m1[39m, [33m'dipole'[39m: [32m1[39m},
    implemented_properties=[[33m'energy'[39m, [33m'forces'[39m, [33m'dipole'[39m],
)
[31mDocstring:[39m
Ase calculator implementation for physnetjax model

Args:
params: params of the physnetjax model
model: physnetjax model
ase_mol: ase molecule
conversion: conversion factor for the energy, forces, and dipole
implemented_properties: implemented properties for the ase calculator

Returns:
Ase calculator implementation for physnetjax model
[31mFile:[39m      ~/mmml/mmml/physnetjax/physnetjax/calc/helper_mlp.py
[31mType:[39m      function

In [6]:
from mmml.physnetjax.physnetjax.models import model as model
from mmml.physnetjax.physnetjax.models.model import EF
from mmml.physnetjax.physnetjax.training.training import train_model


# Data

In [7]:
from mmml.physnetjax.physnetjax.data.data import prepare_datasets
from mmml.physnetjax.physnetjax.data.batches import prepare_batches_jit

In [8]:
?prepare_datasets

[31mSignature:[39m
prepare_datasets(
    key,
    train_size=[32m0[39m,
    valid_size=[32m0[39m,
    files=[38;5;28;01mNone[39;00m,
    clean=[38;5;28;01mFalse[39;00m,
    esp_mask=[38;5;28;01mFalse[39;00m,
    clip_esp=[38;5;28;01mFalse[39;00m,
    natoms=[32m60[39m,
    verbose=[38;5;28;01mFalse[39;00m,
    subtract_atom_energies=[38;5;28;01mFalse[39;00m,
    subtract_mean=[38;5;28;01mFalse[39;00m,
)
[31mDocstring:[39m
Prepare datasets for training and validation.

Args:
    key: Random key for dataset shuffling.
    num_train (int): Number of training samples.
    num_valid (int): Number of validation samples.
    filename (str or list): Filename(s) to load datasets from.

Returns:
    tuple: A tuple containing train_data and valid_data dictionaries.
[31mFile:[39m      ~/mmml/mmml/physnetjax/physnetjax/data/data.py
[31mType:[39m      function

In [9]:
# Scratch check: build a typed PRNG key from seed 1 and display it.
# The result is not stored or used by later cells.
jax.random.key(1)

Array((), dtype=key<fry>) overlaying:
Array[2] u32 μ=0.500 σ=0.500 gpu:0 [0, 1]

In [10]:
# Split seed 42 into two keys: data_key is passed to prepare_datasets and
# prepare_batches_jit, train_key is passed to train_model.
data_key, train_key = jax.random.split(jax.random.PRNGKey(42), 2)
# Batch size shared by batch preparation, training and evaluation below.
BATCHSIZE = 16

In [11]:
# Collect every dataset archive matching "b*npz" in the user's home directory.
# NOTE(review): absolute, user-specific cluster path; consider a configurable
# DATA_DIR so the notebook runs elsewhere.
data_paths = list(Path("/scicore/home/meuwly/boitti0000/").glob("b*npz"))
data_paths

[Path('/scicore/home/meuwly/boitti0000/beta-diketones_71208.npz')]

In [12]:
# NOTE(review): numpy is first imported here, mid-notebook; later cells rely on `np`.
import numpy as np
# Peek at the "Z" array of the first archive (shown below as 71208 x 17 with
# zero-padded rows).
np.load(data_paths[0])["Z"]

array([[6, 6, 1, ..., 0, 0, 0],
       [6, 6, 1, ..., 1, 0, 0],
       [6, 6, 6, ..., 0, 0, 0],
       ...,
       [8, 6, 6, ..., 0, 0, 0],
       [6, 6, 6, ..., 0, 0, 0],
       [8, 6, 6, ..., 0, 0, 0]], shape=(71208, 17))

In [13]:
# Scratch arithmetic: 20% of the 71208 structures in the archive (not stored).
71208 * .2

14241.6

In [14]:
# Echo the batch size defined above.
BATCHSIZE

16

In [15]:
# Split configuration: 59000 training and 1400 validation structures drawn
# from the archives found above.
files = data_paths
train_size = 59000
valid_size = 1400
NATOMSMAX = 17  # padded atoms per structure (second axis of "Z" is 17)

train_data, valid_data = prepare_datasets(
    data_key,
    train_size=train_size,
    valid_size=valid_size,
    files=files,
    natoms=NATOMSMAX,
)

dataR (71208, 17, 3)
dataE [-54.02564528 -60.37702056 -34.56810181 -34.74576462 -41.01765859
 -40.48231955 -53.06176803 -35.04941783 -27.50888271 -40.22177633]
dataE [-54.02564528 -60.37702056 -34.56810181 -34.74576462 -41.01765859
 -40.48231955 -53.06176803 -35.04941783 -27.50888271 -40.22177633]
D (71208, 3)
Q 1 (71208,) 71208
Q (71208,)


In [16]:
?EF

[31mInit signature:[39m
EF(
    features: int = [32m32[39m,
    max_degree: int = [32m3[39m,
    num_iterations: int = [32m2[39m,
    num_basis_functions: int = [32m16[39m,
    cutoff: float = [32m6.0[39m,
    max_atomic_number: int = [32m118[39m,
    charges: bool = [38;5;28;01mFalse[39;00m,
    natoms: int = [32m60[39m,
    total_charge: float = [32m0[39m,
    n_res: int = [32m3[39m,
    zbl: bool = [38;5;28;01mTrue[39;00m,
    debug: Union[bool, List[str]] = [38;5;28;01mFalse[39;00m,
    efa: bool = [38;5;28;01mFalse[39;00m,
    use_energy_bias: bool = [38;5;28;01mTrue[39;00m,
    parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object at [32m0x14881447bf20[39m>,
    name: Optional[str] = [38;5;28;01mNone[39;00m,
) -> [38;5;28;01mNone[39;00m
[31mDocstring:[39m     
Energy and Forces Neural Network Model.

A neural network model that predicts molecular energies an

In [17]:
# Pre-build the validation batches (reusing data_key). They are consumed by
# the evaluation cell and by the calculator cross-checks further down.
valid_batches = prepare_batches_jit(data_key, valid_data, BATCHSIZE, num_atoms = NATOMSMAX)

In [18]:
# Build the energy/force network. Keywords follow the order of the EF
# signature; anything not listed keeps its default.
# NOTE(review): `natoms` is left at its default (60) although the data are
# padded to NATOMSMAX = 17 atoms -- confirm whether EF needs natoms=NATOMSMAX
# or whether train_model handles this through its own num_atoms argument.
model = EF(
    features=32,
    max_degree=1,
    num_iterations=2,
    num_basis_functions=32,
    cutoff=8.0,
    max_atomic_number=40,
    charges=True,
    n_res=4,
    zbl=False,
    efa=False,
)
model

EF(
    # attributes
    features = 32
    max_degree = 1
    num_iterations = 2
    num_basis_functions = 32
    cutoff = 8.0
    max_atomic_number = 40
    charges = True
    natoms = 60
    total_charge = 0
    n_res = 4
    zbl = False
    debug = False
    efa = False
    use_energy_bias = True
)

# Training

In [None]:
# Train (or resume) the model. `restart` points at the checkpoint directory of
# the run identified by `uid`; new checkpoints are written under SCICORE.
# NOTE(review): absolute, user-specific cluster paths -- consider making the
# checkpoint root configurable.
uid = "test-84aa02d9-e329-46c4-b12c-f55e6c9a2f94"
SCICORE = Path('/scicore/home/meuwly/boitti0000/ckpts')
params_out = train_model(
    train_key,
    model,
    train_data,
    valid_data,
    num_epochs=6300,
    learning_rate=0.0005,
    batch_size=BATCHSIZE,
    num_atoms=NATOMSMAX,
    energy_weight=100,
    restart=str(SCICORE / uid),
    conversion={'energy': 1, 'forces': 1},
    print_freq=1,
    name='test',
    best=False,
    optimizer=None,
    transform=None,
    schedule_fn="constant",
    objective='valid_loss',
    ckpt_dir=SCICORE,
    log_tb=False,
    batch_method="default",
    batch_args_dict=None,
    data_keys=('R', 'Z', 'F', "N", 'E', 'D', 'batch_segments'),
)

Using default (fat) batching method


Extra Validation Info:
Z: Array[1400, 17] i32 n=23800 (93Kb) x∈[0, 8] μ=2.022 σ=2.707 cpu:0
R: Array[1400, 17, 3] n=71400 (0.3Mb) x∈[-3.429, 3.382] μ=0.004 σ=0.894 cpu:0
E: Array[1400, 1] 5.5Kb x∈[-67.615, -9.490] μ=-42.597 σ=13.520 cpu:0
N: Array[1400, 1] i32 5.5Kb x∈[3, 17] μ=10.344 σ=3.098 cpu:0
F: Array[1400, 17, 3] n=71400 (0.3Mb) x∈[-11.194, 11.759] μ=-2.838e-10 σ=1.317 cpu:0
D: Array[1400, 3] n=4200 (16Kb) x∈[-1.076, 0.854] μ=-0.010 σ=0.272 cpu:0


ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/scicore/home/meuwly/boitti0000/.conda/envs/mmml-gpu/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x14949e83bb40> is already entered
ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/scicore/home/meuwly/boitti0000/.conda/envs/mmml-gpu/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x14949e83bb40> is already entered
ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/scicore/home/meuwly/boitti0000/.conda/envs/mmml-gpu/lib/python3.12/asyncio/events.py", line 88, in _run
    s

dict_keys(['opt_state', 'params', 'step'])


ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='Task-168' coro=<_async_in_context.<locals>.run_in_context() running at /scicore/home/meuwly/boitti0000/.local/lib/python3.12/site-packages/ipykernel/utils.py:60> wait_for=<Task pending name='Task-2' coro=<Kernel.shell_main() running at /scicore/home/meuwly/boitti0000/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py:590> cb=[Task.task_wakeup()]> cb=[ZMQStream._run_callback.<locals>._log_error() at /scicore/home/meuwly/boitti0000/.local/lib/python3.12/site-packages/zmq/eventloop/zmqstream.py:563]>
ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='Task-2' coro=<Kernel.shell_main() running at /scicore/home/meuwly/boitti0000/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py:590> cb=[Task.task_wakeup()]>
ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='Task-25' coro=<Kernel.shell_main() running at /scicore/home/meuwly/boitti0000/.local/li

Restoring from /scicore/home/meuwly/boitti0000/ckpts/test-84aa02d9-e329-46c4-b12c-f55e6c9a2f94/epoch-5209
Restored keys: dict_keys(['best_loss', 'ema_params', 'epoch', 'lr_eff', 'model', 'model_attributes', 'objectives', 'opt_state', 'params', 'transform_state'])
Training resumed from step 5209, best_loss Array gpu:0 146.001




Output()

ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/scicore/home/meuwly/boitti0000/.conda/envs/mmml-gpu/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x14949e83bb40> is already entered
ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='Task-2842' coro=<_async_in_context.<locals>.run_in_context() done, defined at /scicore/home/meuwly/boitti0000/.local/lib/python3.12/site-packages/ipykernel/utils.py:57> wait_for=<Task pending name='Task-2843' coro=<Kernel.shell_main() running at /scicore/home/meuwly/boitti0000/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py:590> cb=[Task.__wakeup()]> cb=[ZMQStream._run_callback.<locals>._log_error() at /scicore/home/meuwly/boitti0000/.local/lib/python3.12/site-packages/zmq/eventloop/zmqstream.py:563]>
ERROR:asyncio:Task w

# Validation

In [None]:
# NOTE(review): wildcard import. `plot` and `plt` used by the following cells
# are presumably provided by it -- confirm and import them explicitly.
from mmml.physnetjax.physnetjax.analysis.analysis import  *

# Run the trained parameters over the validation batches and unpack reference
# and predicted energies, forces and dipoles, plus charges and raw outputs.
output = mmml.physnetjax.physnetjax.analysis.analysis.eval(valid_batches, model, params_out, batch_size=BATCHSIZE)
Es, Eeles, predEs, Fs, predFs, Ds, predDs, charges, outputs = output
# Boolean masks selecting structures by reference energy.
# NOTE(review): the -45 / -50 thresholds are unexplained magic numbers; the
# monomer/dimer naming suggests they were tuned for another dataset -- confirm.
monomers_idx = Es > -45
dimers_idx = Es < -50

In [None]:
# Energy parity plot for the "monomer" subset, zoomed to the -42..-40 window.
ax = plt.gca()
plot(Es[monomers_idx], predEs[monomers_idx], ax, units="eV", _property="", kde=False, s=1, diag=True)
plt.xlim(-42, -40)
plt.ylim(-42, -40)
plt.show()
# Second parity plot for the "dimer" subset, drawn on the axes that are
# current after the previous figure was shown; this one uses kde=True.
ax = plt.gca()
plot(Es[dimers_idx], predEs[dimers_idx], ax, units="eV", _property="", kde=True, s=1, diag=True)

In [None]:
# Force parity plot over all validation structures.
# NOTE(review): the label says "kcal/mol" while the energy plots use "eV" and
# training used conversion factors of 1 -- confirm the actual force units.
ax = plt.gca()
plot(Fs, predFs, ax, units="kcal/mol", _property="", kde=True, s=1, diag=True)

# Calculator

In [None]:
from mmml.pycharmmInterface.setupBox import *
import pycharmm
from mmml.pycharmmInterface.mmml_calculator import setup_calculator, CutoffParameters

In [None]:
?CutoffParameters

In [None]:
# System layout handed to setup_calculator: two monomers of ten atoms each.
ATOMS_PER_MONOMER = 10
N_MONOMERS = 2

In [None]:
?setup_calculator

In [None]:
# Build a factory for calculators with the ML, MM and ML-dimer terms enabled.
# MAX_ATOMS_PER_SYSTEM is derived from the layout constants instead of
# repeating the literal 20, so the two cannot drift apart.
# NOTE(review): mm_cutoff (1.0) is smaller than mm_switch_on (4.0) -- confirm
# whether mm_cutoff is an absolute distance or a switching width.
# NOTE(review): model_restart_path is an absolute path on a different machine
# and a different run id than the checkpoint trained above -- confirm which
# model is meant to be loaded.
calculator_factory = setup_calculator(
    ATOMS_PER_MONOMER,
    N_MONOMERS,
    ml_cutoff_distance=2.0,
    mm_switch_on=4.0,
    mm_cutoff=1.0,
    doML=True,
    doMM=True,
    doML_dimer=True,
    debug=False,
    ep_scale=None,
    sig_scale=None,
    model_restart_path="/home/ericb/mmml/mmml/physnetjax/ckpts/test-9af0d71b-4140-4d4b-83e3-ce07c652d048",
    MAX_ATOMS_PER_SYSTEM=ATOMS_PER_MONOMER * N_MONOMERS,
)

In [None]:
from ase.visualize.plot import plot_atoms
from mmml.pycharmmInterface import import_pycharmm
from mmml.pycharmmInterface.import_pycharmm import  *

In [None]:
from mmml.pycharmmInterface import setupRes, setupBox
from mmml.pycharmmInterface.import_pycharmm import reset_block, reset_block_no_internal
from mmml.pycharmmInterface.pycharmmCommands import CLEAR_CHARMM

In [None]:
# Clear any existing CHARMM system, then reset the BLOCK settings.
# NOTE(review): the reset calls are repeated several times in alternating
# order; presumably once each would do -- confirm before trimming, since the
# helpers' side effects are not visible here.
CLEAR_CHARMM()
reset_block()
reset_block_no_internal()
reset_block()
reset_block()
reset_block_no_internal()
reset_block()

In [None]:
# Set up the CHARMM PSF for residue "ACO".
# NOTE(review): presumably 2 copies in a box of side 30 -- confirm the meaning
# of the positional arguments against initialize_psf's signature.
initialize_psf("ACO", 2, 30, None)

In [None]:
# mmml.pycharmmInterface.import_pycharmm.view_pycharmm_state()

In [None]:
# Print the CHARMM energy of the freshly built system. `energy` is presumably
# brought in by one of the wildcard imports above.
energy.show()

In [None]:
# Snapshot the current CHARMM state as an ASE Atoms object, then pull out the
# positions and atomic numbers used to build the calculator.
ase_atoms = ase_from_pycharmm_state()
R = ase_atoms.get_positions()
Z = ase_atoms.get_atomic_numbers()
R,Z

In [None]:
# ase.units.kcal / ase.units.mol is one kcal/mol expressed in eV, so FACTOR is
# the eV -> kcal/mol multiplier. How the factory applies the two conversion
# factors is not visible here.
FACTOR = 1 / (ase.units.kcal / ase.units.mol)
calcs = calculator_factory(
    atomic_numbers=Z,
    atomic_positions=R,
    n_monomers=2,
    energy_conversion_factor=FACTOR,
    force_conversion_factor=FACTOR,
    debug=False,
)

In [None]:
# Attach the first object returned by the factory as the ASE calculator.
ase_atoms.calc = calcs[0]

In [None]:
# Evaluate the attached calculator at the CHARMM-built geometry.
ase_atoms.get_potential_energy()

In [None]:
# NOTE(review): quiet is immediately followed by verbose, so presumably the
# net effect is verbose output -- confirm which setting is intended.
pycharmm_quiet()
pycharmm_verbose()

In [None]:
fix_idxs = np.array(range(20))
_fix_idxs = np.array(range(20))
fix_idxs[0] = _fix_idxs[3]
fix_idxs[3] = _fix_idxs[0]
fix_idxs[10] = _fix_idxs[13]
fix_idxs[13] = _fix_idxs[10]
batch_index = 0

In [None]:
# First 20 position rows of the chosen validation batch, reordered by fix_idxs.
# NOTE(review): structures here are padded to NATOMSMAX = 17 atoms; if the
# batch is flattened per atom, a 20-row slice spans two structures -- confirm.
test_R = valid_batches[batch_index]["R"][:20][fix_idxs]

In [None]:
# Matching atomic numbers, kept in the original order. The display shows them
# before and after the fix_idxs permutation.
test_Z = valid_batches[batch_index]["Z"][:20]
test_Z, test_Z[fix_idxs]

In [None]:
# Flax linen Modules are frozen dataclasses once constructed, so assigning
# `model.natoms = 20` raises a frozen-module error. Create an updated copy
# instead; every other hyperparameter is carried over unchanged.
model = model.clone(natoms=20)
model

In [None]:
# Reference system evaluated with the PhysNet model alone: permuted atomic
# numbers with the (already permuted) positions, energy and forces only.
ref_physnet_atoms = ase.Atoms(test_Z[fix_idxs], test_R)
ref_physnet_atoms.calc = get_ase_calc(
    params_out,
    model,
    ref_physnet_atoms,
    conversion={"energy": 1, "forces": 1},
    implemented_properties=['energy', 'forces'],
)

In [None]:
# Energy of the reference structure from the PhysNet-only calculator.
ref_physnet_atoms.get_potential_energy()

In [None]:
# PhysNet-only forces, flattened to 1-D for the scatter comparisons below.
ref_physnet_F = ref_physnet_atoms.get_forces().flatten()

In [None]:
# Dataset reference energy: first entry of "E" in the chosen validation batch.
ref_e = valid_batches[batch_index]["E"][0][0]
ref_e

In [None]:
# Dataset reference forces for the first 20 rows, flattened to 1-D.
# NOTE(review): unlike test_R these rows are not reordered by fix_idxs --
# confirm the atom order matches what they are compared against.
ref_f = valid_batches[batch_index]["F"][:20]
refF = ref_f.flatten()

In [None]:
# Move the CHARMM-built system onto the reference geometry so both calculators
# are evaluated at the same coordinates.
ase_atoms.set_positions(test_R)

In [None]:
# Show the atom ordering of the CHARMM-built system, to compare by eye with
# test_Z[fix_idxs].
ase_atoms.get_atomic_numbers()

In [None]:
# Energy from the attached calculator at the reference geometry.
ase_atoms.get_potential_energy()

In [None]:
# Forces from the attached calculator at the reference geometry, flattened.
mF = np.array(ase_atoms.get_forces()).flatten()

In [None]:
# Fields of the calculator's "out" result to pull into a plain dict, each
# flattened to 1-D so they can be compared and plotted directly.
ks = [
    'dH',
    'energy',
    'forces',
    'internal_E',
    'internal_F',
    'ml_2b_E',
    'ml_2b_F',
    'mm_E',
    'mm_F',
]
di = {
    k: getattr(dict(ase_atoms.calc.results)["out"], k).flatten()
    for k in ks
}

In [None]:
# NOTE(review): only the last expression of a cell is displayed, so the dir()
# result is discarded; only `di` is shown.
dir(dict(ase_atoms.calc.results)["out"])
di

In [None]:
# Visual check of the reference structure.
view_atoms(ref_physnet_atoms)

In [None]:
# Parity scatter: dataset reference forces (x) vs PhysNet-only forces (y).
plt.scatter(refF, ref_physnet_F)

In [None]:
# Parity scatter: attached-calculator forces (x) vs dataset reference forces (y).
plt.scatter(mF,refF)

In [None]:
# Dataset reference forces (x) vs the calculator's "mm_F" component (y).
plt.scatter(refF,di["mm_F"])

In [None]:
# Total calculator forces (x) vs the calculator's "mm_F" component (y).
plt.scatter(mF,di["mm_F"])

In [None]:
# Sum of the "internal_E" term (scaled by FACTOR) and the "mm_E" term.
# NOTE(review): only internal_E is scaled; presumably mm_E is already in the
# target unit -- confirm the units of both terms.
np.array(dict(ase_atoms.calc.results)["out"].internal_E)*FACTOR + dict(ase_atoms.calc.results)["out"].mm_E

In [None]:
# Display the "internal_F" component of the calculator result.
np.array(dict(ase_atoms.calc.results)["out"].internal_F)

In [None]:
# Visual check of the CHARMM-built system at the reference geometry.
view_atoms(ase_atoms)

In [None]:
# Generate two ACO residues in CHARMM, build coordinates through `ic`, and
# print them. `ic` and `coor` are presumably pycharmm modules brought in by a
# wildcard import -- confirm.
setupRes.generate_residue("ACO ACO")
ic.build()
coor.show()

In [None]:
# Echo the reference coordinates about to be pushed into CHARMM.
test_R

In [None]:
# Wrap the reference coordinates in an x/y/z DataFrame and write them into
# CHARMM. NOTE(review): `pd` is not imported in any visible cell; presumably it
# arrives through a wildcard import -- import pandas explicitly.
xyz = pd.DataFrame(test_R, columns=["x", "y", "z"])
coor.set_positions(xyz)

In [None]:
# energy.show()

In [None]:
# coor.show()

In [None]:
# CHARMM script setting the non-bonded options (atom-based cutoffs
# 14/12/10 with vswitch), sent verbatim to the CHARMM interpreter below.
nbonds = """!#########################################
! Bonded/Non-bonded Options & Constraints
!#########################################

! Non-bonding parameters
nbonds atom cutnb 14.0  ctofnb 12.0 ctonnb 10.0 -
vswitch NBXMOD 3 -
inbfrq -1 imgfrq -1
"""
pycharmm.lingo.charmm_script(nbonds)

In [None]:
# Switch CHARMM output to quiet, then evaluate the energy with the new
# non-bonded settings.
pycharmm_quiet()
energy.show()

In [None]:
# Read back the "VDW" energy term from the last CHARMM energy evaluation.
energy.get_term_by_name("VDW")

In [None]:
# Read back the "ELEC" energy term from the last CHARMM energy evaluation.
energy.get_term_by_name("ELEC")

In [None]:
# Visualise the current CHARMM state.
mmml.pycharmmInterface.import_pycharmm.view_pycharmm_state()

In [None]:
# Same energy evaluation again, this time with verbose CHARMM output.
pycharmm_verbose()
energy.show()

In [None]:
# Clear the CHARMM system before rebuilding a single residue.
CLEAR_CHARMM()

In [None]:
# Build the "ACO" residue through setupRes.main and keep the returned object.
atoms = setupRes.main("ACO")

In [None]:
# Regenerate coordinates, copy CHARMM's positions onto `atoms`, and reset the
# BLOCK settings. NOTE(review): `_` is also IPython's "last output" variable,
# so a descriptive name would be safer. The same cell is repeated verbatim
# further down -- a candidate for a small helper function.
atoms = setupRes.generate_coordinates()
_ = setupRes.coor.get_positions()
atoms.set_positions(_)
reset_block()
reset_block_no_internal()
reset_block()

In [None]:
# Visualise the CHARMM state after regenerating coordinates.
mmml.pycharmmInterface.import_pycharmm.view_pycharmm_state()

In [None]:
# Verbatim repeat of the earlier coordinate-generation cell: regenerate
# coordinates, sync them onto `atoms`, reset the BLOCK settings.
# NOTE(review): duplicated cell; `_` also shadows IPython's last-output variable.
atoms = setupRes.generate_coordinates()
_ = setupRes.coor.get_positions()
atoms.set_positions(_)
reset_block()
reset_block_no_internal()
reset_block()


In [None]:
# Visualise the CHARMM state after the second coordinate generation.
mmml.pycharmmInterface.import_pycharmm.view_pycharmm_state()

In [None]:
# Print the CHARMM energy of the rebuilt single-residue system.
energy.show()

In [None]:
# "VDW" term of the energy just evaluated.
energy.get_term_by_name("VDW")

In [None]:
# Read the "ENER" value back from CHARMM.
pycharmm.lingo.get_energy_value("ENER")

# Example: packmol for a dimer system

In [None]:
# Pack the system with packmol.
# NOTE(review): presumably 20 is a particle count and 30 a box size -- confirm
# against run_packmol's signature and replace the magic numbers with constants.
run_packmol(20, 30)

In [None]:
def CLEAR_CHARMM():
    """Wipe the current CHARMM system: delete all atoms, then the PSF.

    Each command is sent to the CHARMM interpreter as its own script, in
    that order.

    NOTE(review): this definition shadows the CLEAR_CHARMM imported from
    mmml.pycharmmInterface.pycharmmCommands earlier in the notebook.
    """
    for command in ("DELETE ATOM SELE ALL END", "DELETE PSF SELE ALL END"):
        pycharmm.lingo.charmm_script(command)

# Run the locally defined clear routine right away.
CLEAR_CHARMM()

In [None]:
# Reset the BLOCK settings after clearing, then visualise the CHARMM state.
reset_block()
reset_block_no_internal()
reset_block()
mmml.pycharmmInterface.import_pycharmm.view_pycharmm_state()

In [None]:
# pycharmm_verbose()

In [None]:
# pycharmm.lingo.charmm_script(nbonds)

# #equivalent CHARMM scripting command: minimize abnr nstep 1000 tole 1e-3 tolgr 1e-3
# minimize.run_abnr(nstep=1000, tolenr=1e-1, tolgrd=1e-1)
# #equivalent CHARMM scripting command: energy
# energy.show()

In [None]:
# Visualise the current CHARMM state once more.
mmml.pycharmmInterface.import_pycharmm.view_pycharmm_state()

In [None]:
# Copy the current CHARMM coordinates onto the ASE system and view it.
# NOTE(review): this assumes CHARMM currently holds the same number of atoms
# as ase_atoms, although the system was cleared above -- confirm.
ase_atoms.set_positions(coor.get_positions())
view_atoms(ase_atoms)

In [None]:
# NOTE(review): `cs` is not defined in any visible cell; unless a wildcard
# import provides it, this raises NameError on a fresh kernel. Possibly a
# leftover or a typo for `calcs` -- confirm or remove.
cs

In [None]:
# Energy from the attached calculator at the CHARMM coordinates.
ase_atoms.get_potential_energy()

In [None]:
# Inspect everything the calculator cached during the last evaluation.
ase_atoms.calc.results #["out"]

In [None]:
# "mm_E" component of the last evaluation.
ase_atoms.calc.results["out"].mm_E

In [None]:
# Run structure optimization with BFGS until the force criterion fmax=0.001
# is met; the return value is discarded.
# NOTE(review): `ase_opt` is not imported in any visible cell; presumably it is
# `ase.optimize` arriving through a wildcard import -- import it explicitly.
_ = ase_opt.BFGS(ase_atoms).run(fmax=0.001)

In [None]:
# View the structure after the BFGS optimization.
view_atoms(ase_atoms)

In [None]:
# Parameters.
temperature = 10
timestep_fs = 0.1
num_steps = 300

# Draw initial momenta.
MaxwellBoltzmannDistribution(ase_atoms, temperature_K=temperature)
Stationary(ase_atoms)  # Remove center of mass translation.
ZeroRotation(ase_atoms)  # Remove rotations.

# Initialize Velocity Verlet integrator.
integrator = VelocityVerlet(ase_atoms, timestep=timestep_fs*ase.units.fs)

# Run molecular dynamics.
frames = np.zeros((num_steps, len(ase_atoms), 3))
potential_energy = np.zeros((num_steps,))
kinetic_energy = np.zeros((num_steps,))
total_energy = np.zeros((num_steps,))
for i in range(num_steps):
  # Run 1 time step.
  integrator.run(1)
  # Save current frame and keep track of energies.
  frames[i] = ase_atoms.get_positions()
  potential_energy[i] = ase_atoms.get_potential_energy()
  kinetic_energy[i] = ase_atoms.get_kinetic_energy()
  total_energy[i] = ase_atoms.get_total_energy()
  # Occasionally print progress.
  if i % 100 == 0:
    print(f"step {i:5d} epot {potential_energy[i]: 5.3f} ekin {kinetic_energy[i]: 5.3f} etot {total_energy[i]: 5.3f}")

In [None]:
# # Visualize the structure with py3Dmol.
# view = py3Dmol.view()
# xyz = io.StringIO()
# ase_io.write(xyz, ase_atoms, format='xyz')
# view.addModel(xyz.getvalue(), 'xyz')
# view.setStyle({'stick': {'radius': 0.15}, 'sphere': {'scale': 0.25}})
# view.show()
# view.getModel().setCoordinates(frames[::100], 'array')
# view.animate({'loop': 'forward', 'interval': 0.1})
# view.show() 

In [None]:
# Plot potential, kinetic and total energy along the MD run.
%matplotlib inline
# Stretch the axes area to the full figure canvas.
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
plt.xlabel('time [fs]')
plt.ylabel('energy [eV]')
# Time axis: sample index times the timestep.
# NOTE(review): each sample is recorded after integrator.run(1), so sample i
# corresponds to (i + 1) * timestep_fs -- confirm whether the offset matters.
time = np.arange(num_steps) * timestep_fs
plt.plot(time, potential_energy, label='potential energy')
plt.plot(time, kinetic_energy, label='kinetic energy')
plt.plot(time, total_energy, label='total energy')
plt.legend()
plt.grid()

In [None]:
# View the final structure after the MD run.
view_atoms(ase_atoms)