# Score dynamics

This notebook will guide you through (conditional) diffusion model training and inference for iteratively generating future molecular configurations based on an initial condition.

In [1]:
import torch
import numpy as np

## Data storage format

Each molecular trajectory is stored as an HDF5 file containing at least a `pos` *dataset* of shape [T, N, 3], where T is the number of timesteps and N is the number of particles. The `pos` dataset also carries the following attributes:
- `symbols`: atomic symbols
- `numbers`: atomic numbers
- `timestep`: timestep size (in picoseconds) of the MD trajectory

For example, you can retrieve the above information using the following code block.

A toy-sized version of the alanine dipeptide trajectory is included in this repo. Remember to install `h5py` first (`pip install h5py`).

In [2]:
import h5py

# Open the file read-only. Older h5py releases (< 3.0) default to mode 'a',
# which silently creates a new empty file if the path is wrong and permits
# accidental writes to the dataset.
# The handle is intentionally left open: `pos.attrs` is read again further down.
f = h5py.File('./data/ala-dipep-minidataset.hdf5', 'r')
pos = f['pos']
print(pos.shape)                # [T, N, 3]
print(pos.attrs.keys())
print(pos.attrs['symbols'])
print(pos.attrs['numbers'])
print(pos.attrs['timestep'])

(1000, 22, 3)
<KeysViewHDF5 ['numbers', 'symbols', 'timestep']>
['CH' 'HH' 'HH' 'HH' 'C' 'O' 'N' 'H' 'CA' 'HA' 'CB' 'HB' 'HB' 'HB' 'C' 'O'
 'N' 'H' 'CH' 'HH' 'HH' 'HH']
[6 1 1 1 6 8 7 1 6 1 6 1 1 1 6 8 7 1 6 1 1 1]
10.0


## Training

### Create lightning datamodule and module instances

For convenience, many details are handled behind the scenes with PyTorch Lightning, including:
- How the data is loaded, processed, batched, and so on.
- Model definition.
- How the model is trained, which optimizer, etc.

These implementations can be found at `./lit/` if you are interested.

In [3]:
from lit.datamodules import MolTrajDataModule
from lit.modules import LitNoiseNet

from glob import glob

# HDF5 trajectory files to train on (format described above).
file_list = ['./data/ala-dipep-minidataset.hdf5']

# Datamodule settings:
#   interval - number of timesteps used to form a displacement,
#              i.e. displacement = pos[n+interval] - pos[n]
#   scale    - factor multiplied into the atomic coordinates
#              (this affects the quality of the trained diffusion model)
datamodule_kwargs = dict(
    file_list=file_list,
    interval=1,
    scale=2.0,
    batch_size=128,
    num_workers=4,
)
datamodule = MolTrajDataModule(**datamodule_kwargs)

# To train from scratch, build a fresh model instead of loading a checkpoint:
# noise_net = LitNoiseNet(
#     num_species = 10,
#     num_convs   = 5,
#     dim         = 200,
#     out_dim     = 3,
#     cutoff      = 4.0 * 2.0,  # cutoff radius multiplied by scaling factor
#     ema_decay   = 0.9999,
#     learn_rate  = 1e-4,
# )

# Otherwise, restore the weights of a previously trained model.
checkpoint_path = './lit_logs/ala-dipep-r4-S2-nvt-rand-tsize-240k/version_2/checkpoints/epoch=3404-step=800000.ckpt'
noise_net = LitNoiseNet.load_from_checkpoint(checkpoint_path)

Lightning automatically upgraded your loaded checkpoint from v1.6.5 to v2.2.4. To apply the upgrade to your files permanently, run `python -m lightning.pytorch.utilities.upgrade_checkpoint lit_logs/ala-dipep-r4-S2-nvt-rand-tsize-240k/version_2/checkpoints/epoch=3404-step=800000.ckpt`


### Start training session

Skip this section if you loaded a pre-trained checkpoint above and only want to run inference; running it starts a long training session (`max_steps = 600_000`).

In [None]:
import lightning as L
from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.loggers import TensorBoardLogger

# TensorBoard logs (and, with Lightning's defaults, checkpoints) are written
# under ./lit_logs/ala-dipep-test/version_*/.
tb_logger = TensorBoardLogger(save_dir='./lit_logs/', name='ala-dipep-test')
progress_bar = TQDMProgressBar(refresh_rate=10)

trainer = L.Trainer(max_steps=600_000, logger=tb_logger, callbacks=[progress_bar])

# To resume a previous session, additionally pass ckpt_path='./lit_logs/...'.
trainer.fit(noise_net, datamodule)

## Roll-out from a given initial configuration

### Helper functions

In [4]:
from ase.neighborlist import primitive_neighbor_list

def ase_radius_graph(pos, cutoff, numbers, cell=None, pbc=(False, False, False)):
    """
    Build a directed radius graph with ASE's neighbour search.

    Args:
        pos: array of atomic positions, one row per atom.
        cutoff: a single float, or a dict of per-element-pair cutoffs as accepted
            by `ase.neighborlist.primitive_neighbor_list`.
        numbers: atomic numbers of the atoms (needed when `cutoff` is a dict).
        cell: 3x3 cell matrix; defaults to the identity (same as before).
        pbc: periodic-boundary flags along the three cell vectors.

    Returns:
        torch.long tensor of shape [2, E]; row 0 holds the `i` indices and
        row 1 the `j` indices returned by the neighbour search.
    """
    # Build the default per call instead of sharing one mutable ndarray/list
    # default object across every call.
    if cell is None:
        cell = np.diag([1., 1., 1.])
    i, j = primitive_neighbor_list('ij', positions=pos, cell=cell, cutoff=cutoff, pbc=pbc, numbers=numbers)
    return torch.stack([
        torch.as_tensor(i, dtype=torch.long),
        torch.as_tensor(j, dtype=torch.long),
    ])

def ala_dipep_chirality(pos):
    """
    Ad-hoc signed chirality indicator of the alanine-dipeptide CA centre.

    Returns the triple product ((CA - HA) x (CB - CA)) . (N - CA), using the
    fixed atom indices of this dataset's ordering: N=6, CA=8, HA=9, CB=10.
    Not applicable to other systems.
    """
    n_atom, ca, ha, cb = pos[6], pos[8], pos[9], pos[10]
    normal = torch.linalg.cross(ca - ha, cb - ca)
    return torch.dot(normal, n_atom - ca)

def _bond_lengths(pos, index):
    """Return the length of every (i, j) atom pair listed in the [2, E] `index` tensor."""
    i, j = index
    return (pos[j] - pos[i]).norm(dim=1)


def has_bad_structure(data, scale=1.0, step=None):
    """
    Sanity-check a generated structure.

    This function is ad-hoc for alanine dipeptides and alkanes.
    Not applicable to other systems.

    Args:
        data: graph object with `pos`, `XH_index` and `CNO_index` attributes
            (the index attributes are [2, E] atom-pair tensors).
        scale: factor the coordinates were multiplied by; all bond-length
            bounds below are multiplied by it as well.
        step: rollout step shown in the diagnostic message. If omitted, the
            notebook-global counter `n` is used when it exists (the behaviour
            of the previous version); otherwise the step is left out instead
            of raising NameError.

    Returns:
        True (after printing the reason) if any check fails, else False.
    """
    if step is None:
        step = globals().get('n')
    where = f'Step {step}: ' if step is not None else ''

    # (index attribute on `data`, label, lower bound, upper bound), unscaled.
    bond_checks = (
        ('XH_index', 'X-H', 0.9, 1.16),    # X-H bonds, X is any atom species
        ('CNO_index', 'CNO', 1.1, 1.70),   # any bonds between {C, N, O}
        # Currently disabled checks (need data.CH_index / data.CC_index):
        # ('CH_index', 'C-H', 1.02, 1.16),
        # ('CC_index', 'C-C', 1.38, 1.70),
    )
    for attr, label, lower, upper in bond_checks:
        lengths = _bond_lengths(data.pos, getattr(data, attr))
        if ((lengths < lower * scale) | (lengths > upper * scale)).any():
            print(f'{where}bad {label} bond length')
            return True

    # Check chirality of alanine dipeptide
    if ala_dipep_chirality(data.pos) < 0.0:
        print(f'{where}flipped chirality')
        return True

    return False

### Sample an initial structure

In [5]:
# Prepare the datamodule so that `datamodule.dataset` is available, then use
# its first sample as the starting structure of the rollout.
datamodule.setup()
initial_frame_index = 0
data = datamodule.dataset[initial_frame_index]

### Keep track of certain bonds (for validation)

In [6]:
SCALE = 2.0  # must equal the `scale` passed to MolTrajDataModule (coordinates are stored scaled)
numbers = pos.attrs['numbers']  # atomic numbers from the HDF5 `pos` dataset

# Record which atom pairs are bonded in the current (initial) structure, using
# per-element-pair distance cutoffs scaled like the coordinates. has_bad_structure()
# later checks these bonds during the rollout.
# `.cpu()` keeps this cell re-runnable after `data` has been moved to the GPU,
# and the resulting index tensors are placed on the same device as `data.pos`.
init_pos = data.pos.detach().cpu().numpy()
pos_device = data.pos.device
data.XH_index  = ase_radius_graph(init_pos, cutoff={('H', 'C'): 1.74*SCALE, ('H', 'N'): 1.65*SCALE,  ('H', 'O'): 1.632*SCALE}, numbers=numbers).to(pos_device)
data.CNO_index = ase_radius_graph(init_pos, cutoff={('C', 'C'): 2.04*SCALE, ('C', 'N'): 1.95*SCALE,  ('C', 'O'): 1.932*SCALE}, numbers=numbers).to(pos_device)
# data.CH_index  = ase_radius_graph(init_pos, cutoff={('H', 'C'): 1.74*SCALE}, numbers=numbers).to(pos_device)
# data.CC_index  = ase_radius_graph(init_pos, cutoff={('C', 'C'): 2.04*SCALE}, numbers=numbers).to(pos_device)

### Prepare for rollout

In [7]:
NUM_STEPS = 1_000   # generated frames to collect, in addition to the initial frame
rollout_pos = []    # accepted frames (scaled coordinates)
bad_pos = []        # rejected frames, kept for inspection
n = 0               # number of accepted frames so far (also read by has_bad_structure)

# Fall back to CPU when no GPU is available (sampling will be much slower).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the RNG so that the stochastic rollout is reproducible.
SEED = 0
torch.manual_seed(SEED)

diffuser = noise_net.diffuser
molecular_graph = noise_net._molecular_graph
data = data.to(DEVICE)

# TorchScript-compile the EMA model used for sampling, and switch it to
# inference mode (relevant if the network contains dropout/normalisation layers).
model = torch.jit.script(noise_net.ema_model.module).to(DEVICE)
model.eval()

### Rollout!

In [8]:
import functools
from tqdm.notebook import tqdm

BACKTRACK = 8  # accepted frames to discard when a bad structure is generated

# NOTE: there is no retry limit; if sampling keeps failing from the same frame,
# this loop keeps backtracking until it is interrupted.
with torch.no_grad():
    # One tick per accepted frame: the initial frame plus NUM_STEPS generated ones.
    with tqdm(total=NUM_STEPS + 1, desc='Rollout') as pbar:
        # Continue from `n` if this cell is re-run after an interruption.
        pbar.n = n; pbar.refresh()
        while n <= NUM_STEPS:
            # Update graph data for the current positions
            data = molecular_graph(data, cutoff=4.0*SCALE)

            # Catch bad structures and roll back to an earlier accepted frame
            if has_bad_structure(data, scale=SCALE):
                bad_pos.append(data.pos.cpu().numpy())
                if not rollout_pos:
                    raise RuntimeError('The initial structure fails the sanity checks; there is no frame to roll back to.')
                # Drop the last BACKTRACK accepted frames, but always keep the initial one.
                del rollout_pos[max(1, len(rollout_pos) - BACKTRACK):]
                # Restart from the newest surviving frame. Pop it, because it is
                # re-checked and re-appended at the top of the next iteration;
                # leaving it in the list would record the same frame twice.
                data.pos = rollout_pos.pop()
                n = len(rollout_pos)
                pbar.n = n; pbar.refresh()
                continue

            rollout_pos.append(data.pos.clone())
            n += 1
            pbar.update(1)
            if n > NUM_STEPS:
                break  # all frames collected; skip drawing a sample that would be discarded

            # Generate displacements by reverse denoising from Gaussian noise
            x_T = torch.randn_like(data.pos)
            noise_model = functools.partial(model, x_atm=data.z, bnd_index=data.edge_index, x_bnd=data.edge_attr)
            xs = diffuser.reverse_denoise(x_T, noise_model, diffuser.solver3, M=20)

            # Apply displacements (the final element of `xs`)
            dx = xs[-1]
            data.pos += dx

Rollout:   0%|          | 0/1000 [00:00<?, ?it/s]

### Save the roll-out

In [9]:
import ase, ase.io

# Convert every accepted frame back to unscaled coordinates, wrap it in an ASE
# Atoms object with a 30 x 30 x 30 box, centre it in that box, and write the
# whole trajectory as extended XYZ. The loop variable is named `frame` so it
# does not shadow the h5py `pos` dataset used earlier.
frames = torch.stack(rollout_pos).cpu().numpy()
rollout_traj = []
for frame in frames:
    atoms = ase.Atoms(numbers=numbers, positions=frame/SCALE, cell=np.diag([30., 30., 30]))
    atoms.center()
    rollout_traj.append(atoms)
ase.io.write('./ala-dipep-rollout.extxyz', rollout_traj)