In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pytraj as pt

import jax.numpy as jnp

from moldex.descriptors import MMElectrostaticPotential
from moldex.pytraj import data_for_elecpot

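We load a trajectory using pytraj.

We use `pt.iterload`, which reads frames lazily instead of loading the whole trajectory into memory, to mimic a realistic scenario in which trajectories are too large to fit in memory.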

In [3]:
traj = pt.iterload('data/cla_meoh.nc', top='data/cla_meoh.parm7')

traj

pytraj.TrajectoryIterator, 10 frames: 
Size: 0.005856 (GB)
<Topology: 26201 atoms, 4345 residues, 4345 mols, PBC with box type = octahedral>
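The `Size` reported above is consistent with storing every coordinate as a 64-bit float and dividing by 2**30 bytes. That is our assumption about how pytraj estimates it, even though pytraj labels the unit "GB". The quick estimate below reproduces the number and shows why lazy loading matters for long production runs. The 100,000-frame case is hypothetical.

```python
BYTES_PER_FLOAT64 = 8


def trajectory_size_gib(n_frames: int, n_atoms: int) -> float:
    """Estimate in-memory trajectory size, assuming x, y, z per atom stored as float64."""
    n_bytes = n_frames * n_atoms * 3 * BYTES_PER_FLOAT64
    return n_bytes / 1024**3


# This notebook's trajectory: 10 frames, 26201 atoms
print(round(trajectory_size_gib(10, 26201), 6))  # 0.005856, matching the Size line above

# A hypothetical production run: 100,000 frames of the same system
print(round(trajectory_size_gib(100_000, 26201), 2))  # 58.56
```

Memory grows linearly with the number of frames, so iterating frame by frame keeps the footprint at roughly one frame at a time.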
           

To instantiate the `MMElectrostaticPotential` class we need several ingredients: the indices of the atoms in the QM and MM parts, the charges of the MM atoms, and an array holding the index of the first atom of each residue (i.e., where each residue starts), followed by a final entry equal to the number of MM atoms.

Writing these by hand would be tedious, so we have written a helper to speed things up:

In [4]:
qm_indices, mm_indices, mm_charges, residues_array = data_for_elecpot(top=traj.top, qm_mask=':CLA')

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
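For reference, the layout of `residues_array` described above can be sketched as follows. `build_residues_array` is a hypothetical helper written for illustration. It is not part of moldex, and `data_for_elecpot` may build the array differently. The sketch assumes that positions are counted within the MM atom list and that the atoms of each residue are contiguous.

```python
from typing import Hashable, List, Sequence


def build_residues_array(mm_residue_ids: Sequence[Hashable]) -> List[int]:
    """Hypothetical helper: start offset of each residue among the MM atoms,
    followed by the total number of MM atoms as a final sentinel.

    Assumes atoms belonging to the same residue are contiguous.
    """
    starts: List[int] = []
    previous = object()  # sentinel that never equals a real residue id
    for position, residue_id in enumerate(mm_residue_ids):
        if residue_id != previous:
            starts.append(position)
            previous = residue_id
    starts.append(len(mm_residue_ids))
    return starts


# Residue id of each MM atom: 3 atoms in residue 5, 2 in residue 6, 1 in residue 7
example_offsets = build_residues_array([5, 5, 5, 6, 6, 7])
print(example_offsets)  # [0, 3, 5, 6]

# Atoms of residue k occupy positions example_offsets[k]:example_offsets[k + 1]
for k in range(len(example_offsets) - 1):
    print(k, list(range(example_offsets[k], example_offsets[k + 1])))
```

The final sentinel lets every residue, including the last one, be read as a half-open slice without special-casing the end of the array.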


At this point we can instantiate the descriptor. We also need to provide a cutoff: only MM *residues* within the cutoff distance of the QM part are included in the calculation of the descriptor.

In [5]:
elecpot = MMElectrostaticPotential(
    qm_indices=qm_indices,
    mm_indices=mm_indices,
    mm_charges=mm_charges,
    residues_array=residues_array,
    cutoff=10.0, # Angstrom
)
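The effect of a residue-based cutoff can be sketched in NumPy. This is a simplified illustration under our own assumptions, not moldex's implementation. Here a residue is selected when any of its atoms lies within `cutoff` of any QM atom, periodic images are ignored, and all names are hypothetical.

```python
import numpy as np


def residues_within_cutoff(qm_xyz, mm_xyz, residue_offsets, cutoff):
    """Indices of MM residues with at least one atom closer than `cutoff` to any QM atom.

    Assumptions (not necessarily what moldex does): minimum atom-atom distance
    as the criterion, no periodic images, and `residue_offsets` laid out like
    `residues_array` (start of each residue plus a final sentinel).
    """
    # Pairwise MM-QM distances, shape (n_mm, n_qm)
    distances = np.linalg.norm(mm_xyz[:, None, :] - qm_xyz[None, :, :], axis=-1)
    # True for every MM atom within the cutoff of at least one QM atom
    atom_is_close = (distances < cutoff).any(axis=1)
    selected = []
    for k in range(len(residue_offsets) - 1):
        start, stop = residue_offsets[k], residue_offsets[k + 1]
        # The whole residue is kept if any one of its atoms is close
        if atom_is_close[start:stop].any():
            selected.append(k)
    return selected


toy_qm = np.array([[0.0, 0.0, 0.0]])
toy_mm = np.array([
    [1.0, 0.0, 0.0], [20.0, 0.0, 0.0],   # residue 0: one atom close, one far
    [15.0, 0.0, 0.0], [16.0, 0.0, 0.0],  # residue 1: both atoms far
])
print(residues_within_cutoff(toy_qm, toy_mm, [0, 2, 4], cutoff=10.0))  # [0]
```

Note that residue 0 is kept even though one of its atoms is 20 Å away. One common motivation for cutting at the residue level is to avoid splitting a residue, which would change the net charge of the included environment.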

The descriptor is then computed along the trajectory as follows:

In [6]:
desc = []
for frame in pt.iterframe(traj):
    p = elecpot.compute(frame.xyz)
    desc.append(p)
    
desc = jnp.asarray(desc)

NOTE: this descriptor uses custom C code interfaced to JAX. There is no corresponding CUDA implementation (yet), so it can only run on CPU. If JAX has access to a GPU, wrap the call to the descriptor (`compute`) in the `jax.default_device` context manager so that this portion of the code runs on CPU, e.g.:

```python
import jax
import jax.numpy as jnp
import pytraj as pt

# Force the CPU-only custom kernel to run on the first CPU device,
# even when JAX would otherwise default to a GPU
with jax.default_device(jax.devices("cpu")[0]):
    desc = []
    for frame in pt.iterframe(traj):
        p = elecpot.compute(frame.xyz)
        desc.append(p)

    desc = jnp.asarray(desc)
```

In [7]:
desc

Array([[-0.12089559, -0.03347979, -0.02932582, ..., -0.06089249,
        -0.02984913, -0.04853807],
       [-0.13022247, -0.03419513, -0.00462489, ...,  0.01664686,
        -0.00511678,  0.01491921],
       [-0.16212597,  0.00224037, -0.05976309, ..., -0.01647541,
        -0.03034896, -0.01643274],
       ...,
       [-0.09233882,  0.02434437, -0.00114283, ...,  0.01889373,
         0.00490391,  0.00378224],
       [-0.16683061, -0.01922448, -0.00091998, ..., -0.02384871,
        -0.02185269, -0.03518623],
       [-0.13999091,  0.04262209,  0.00030935, ...,  0.01385698,
        -0.0386499 , -0.01524869]], dtype=float32)
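Each row of `desc` corresponds to one frame of the trajectory. As a rough mental model of what an MM electrostatic potential descriptor measures, the sketch below evaluates the Coulomb potential of the MM point charges at each QM atom, $V_i = \sum_j q_j / |\mathbf{r}_i - \mathbf{r}_j|$, and stacks one vector per frame. This is an assumption for illustration only. We have not checked at which points moldex evaluates the potential, nor which units or prefactors it uses. The function names are hypothetical.

```python
import numpy as np


def coulomb_potential(qm_xyz, mm_xyz, mm_charges):
    """Electrostatic potential at each QM atom due to the MM point charges.

    Plain Coulomb sum V_i = sum_j q_j / |r_i - r_j|, with no unit conversion,
    no cutoff and no periodic images (all simplifying assumptions).
    """
    # Pairwise QM-MM distances, shape (n_qm, n_mm)
    distances = np.linalg.norm(qm_xyz[:, None, :] - mm_xyz[None, :, :], axis=-1)
    return (mm_charges[None, :] / distances).sum(axis=1)


def potential_along_trajectory(frames_xyz, qm_idx, mm_idx, mm_charges):
    """Stack one potential vector per frame, giving shape (n_frames, n_qm)."""
    return np.stack([
        coulomb_potential(xyz[qm_idx], xyz[mm_idx], mm_charges)
        for xyz in frames_xyz
    ])


# Toy system: atom 0 is "QM", atoms 1 and 2 carry charges +1 and -1
toy_frames = [
    np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [4.0, 0.0, 0.0]]),
    np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [4.0, 0.0, 0.0]]),
]
toy_potential = potential_along_trajectory(
    toy_frames, qm_idx=[0], mm_idx=[1, 2], mm_charges=np.array([1.0, -1.0])
)
print(toy_potential.shape)  # (2, 1)
print(toy_potential[:, 0])  # [0.25 0.75]
```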

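You can visualize the cut in a PDB file using another helper. Here we pass `frame`, i.e. the last frame visited by the loop above; the file is written to `ml_env_cut.pdb`, which the last cell removes.

The ML and environment parts are distinguished by their residue names: `ML` and `ENV`, respectively.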

In [9]:
from moldex.pytraj import visualize_cut_pdb

visualize_cut_pdb(elecpot, frame.xyz, traj.top)
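To check the cut without opening a molecular viewer, one could count atoms per residue name in the written file. The sketch below assumes that the helper writes standard fixed-width PDB records, with the residue name in columns 18-20, to `ml_env_cut.pdb` (the file removed by the cleanup cell). `count_atoms_by_resname` is a hypothetical helper, not part of moldex.

```python
from collections import Counter
from typing import Iterable


def count_atoms_by_resname(pdb_lines: Iterable[str]) -> Counter:
    """Count ATOM/HETATM records per residue name (PDB columns 18-20)."""
    counts: Counter = Counter()
    for line in pdb_lines:
        if line.startswith(("ATOM", "HETATM")):
            counts[line[17:20].strip()] += 1
    return counts


# In the notebook, before the cleanup cell, one could run:
# with open("ml_env_cut.pdb") as fh:
#     print(count_atoms_by_resname(fh))

# Self-contained demo with hand-written fixed-width records
toy_pdb = [
    "REMARK toy example",
    "ATOM      1  MG   ML     1       0.000   0.000   0.000",
    "ATOM      2  O   ENV     2       3.000   0.000   0.000",
    "HETATM    3  H   ENV     2       3.900   0.000   0.000",
    "END",
]
print(count_atoms_by_resname(toy_pdb))  # Counter({'ENV': 2, 'ML': 1})
```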

In [10]:
! rm ml_env_cut.pdb