# Graph network denoiser for thermal perturbations

This notebook will guide you through training and applying a graph network model for denoising thermal perturbation in crystal structures.

In [1]:
import numpy as np
import torch
import ase.io

## Training

### Create lightning datamodule and module instances

PyTorch Lightning hides many of the details for your convenience, including:
- How the data is loaded, processed, batched, and so on.
- How the denoiser model is defined.
- How the model is trained (optimizer, learning rate, and so on).

These implementations can be found at `./lit/` if you are interested.
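
As a rough picture of what these modules hide: the model predicts a displacement that is subtracted from the positions, so one plausible training recipe is to perturb ideal positions with Gaussian noise and ask the model to predict that noise. The sketch below assumes this recipe. The helper `make_noisy_sample` and the uniform sampling of the noise level up to `sigma_max` are illustrative assumptions, not the code in `./lit/`.

```python
import numpy as np

def make_noisy_sample(ideal_pos, sigma_max=0.3, rng=None):
    """Hypothetical helper: perturb ideal positions and return (noisy_pos, noise).

    Assumes one noise level per structure, drawn uniformly from [0, sigma_max].
    """
    rng = np.random.default_rng() if rng is None else rng
    sigma = rng.uniform(0.0, sigma_max)
    noise = sigma * rng.standard_normal(ideal_pos.shape)
    return ideal_pos + noise, noise

# Toy "ideal" structure: FCC conventional-cell basis with a = 3.6 (illustrative values)
ideal_pos = np.array([
    [0.0, 0.0, 0.0],
    [1.8, 1.8, 0.0],
    [1.8, 0.0, 1.8],
    [0.0, 1.8, 1.8],
])

noisy_pos, noise = make_noisy_sample(ideal_pos, rng=np.random.default_rng(0))

# A perfect denoiser would predict `noise`, so subtracting it recovers the ideal positions
recovered = noisy_pos - noise
```

Under this assumption, `sigma_max` bounds how strongly the ideal crystals are perturbed during training.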

In [2]:
from lit.datamodules import PeriodicStructureDataModule
from lit.modules import LitEquivariantNoiseNet

# Read ideal structures
file_list = [
    './data/Cu/ideal-bcc.dump',
    './data/Cu/ideal-fcc.dump',
    './data/Cu/ideal-hcp.dump'
]

# Datamodule
datamodule = PeriodicStructureDataModule(
    file_list    = file_list, 
    large_cutoff = 5.0,  # a large cutoff to precompute long bonds/edges
    duplicate    = 256,
    batch_size   = 8,
    num_workers  = 4,
)

# New model instance
# noise_net = LitEquivariantNoiseNet(
#     num_species   = 1, 
#     node_dim      = 64, 
#     ff_dim        = 64, 
#     init_edge_dim = 1, 
#     edge_dim      = 64, 
#     num_heads     = 4, 
#     num_layers    = 4,
#     sigma_max     = 0.3,  # maximum level of noise to be applied to ideal crystals 
#     cutoff        = 3.2,  # a smaller cutoff for downselecting the precomputed edges
#     learn_rate    = 1e-4, 
# )

# Load model weights from a saved checkpoint
noise_net = LitEquivariantNoiseNet.load_from_checkpoint(
    './lit_logs/equivariant-noise-net/version_0/checkpoints/epoch=1149-step=100000.ckpt'
)

### Start training session

Skip this section if you loaded a pre-trained model that is ready for inference.

In [None]:
import lightning as L
from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.loggers import TensorBoardLogger

trainer = L.Trainer(
    max_steps = 100_000,
    logger    = TensorBoardLogger(save_dir='./lit_logs/', name='equivariant-noise-net'),
    callbacks = [TQDMProgressBar(refresh_rate=10)],
)

trainer.fit(
    noise_net, datamodule,
    # ckpt_path = './lit_logs/...',  # specify this if you are resuming from a previous training session
)

## Inference

### Denoise test data

A model trained on Cu systems can still be applied to systems of other elements: scale the input data so that its atomic radius or interatomic distance matches that of the ideal Cu lattices. I used the following scaling factors:
- In-house data
    - Cu: 1.0
    - Ta: 0.82
- Benchmark data from Chung et al. (*"Data-centric framework for crystal structure identification in atomistic simulations using machine learning"*)
    - Al (FCC): 0.9
    - Ar (FCC): 0.65
    - Fe (BCC): 0.9
    - Li (BCC): 0.75
    - Mg (HCP): 0.75
    - Ti (HCP): 0.8
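
The scaling is a uniform rescaling of positions and cell that is undone on the output. A minimal sketch with made-up numbers (the cell and positions below are illustrative, not taken from the datasets above):

```python
import numpy as np

scale = 0.9  # e.g. the factor listed above for Al (FCC)

# Illustrative cubic cell and two atoms
cell = np.diag([4.05, 4.05, 4.05])
pos = np.array([
    [0.0,   0.0,   0.0],
    [2.025, 2.025, 0.0],
])

# Scale positions and cell together so fractional coordinates are unchanged
scaled_pos = pos * scale
scaled_cell = cell * scale

frac_before = pos @ np.linalg.inv(cell)
frac_after = scaled_pos @ np.linalg.inv(scaled_cell)

# Interatomic distances shrink by the same factor
d_before = np.linalg.norm(pos[1] - pos[0])
d_after = np.linalg.norm(scaled_pos[1] - scaled_pos[0])

# Dividing by the factor maps results back to the original length scale
restored_pos = scaled_pos / scale
```

Because only the length scale changes, the lattice type and the atoms' fractional coordinates are untouched.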

### Helper function for denoising

In [3]:
from sklearn.preprocessing import LabelEncoder
from tqdm.notebook import trange
from graphite.nn import periodic_radius_graph

@torch.no_grad()
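def denoise_snapshot(atoms, model, scale=1.0, steps=8):
    """Iteratively denoise an ASE `Atoms` snapshot.

    Returns a list of `steps + 1` position arrays in the original length units,
    starting with the input positions.
    """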
    # Convert to PyG format
    x = LabelEncoder().fit_transform(atoms.numbers)
    species = torch.tensor(x,                   dtype=torch.long)
    pos     = torch.tensor(atoms.positions,     dtype=torch.float)
    cell    = torch.tensor(atoms.cell.tolist(), dtype=torch.float)

    # Scale
    pos  *= scale
    cell *= scale
    
    # Denoising trajectory
    pos_traj = [atoms.positions]    
    for _ in trange(steps):
        # r is the same value as the model's `cutoff` set above (3.2)
        edge_index, edge_vec = periodic_radius_graph(pos, r=3.2, cell=cell)
        edge_len = torch.linalg.norm(edge_vec, dim=1, keepdim=True)
        _, disp = model(species, edge_index, edge_attr=edge_len, edge_vec=edge_vec)
        pos -= disp
        pos_traj.append(pos.clone().numpy() / scale)
    
    return pos_traj
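
To see why the helper applies the model several times, here is a toy version of the same loop. The stand-in `toy_model` is an assumption for illustration only: it returns a fraction of each atom's offset from the nearest integer lattice site, mimicking a denoiser that removes only part of the noise per step.

```python
import numpy as np

def toy_model(pos, strength=0.5):
    """Stand-in for the network: predict part of the offset from the nearest lattice site."""
    return strength * (pos - np.round(pos))

rng = np.random.default_rng(0)
ideal = np.arange(10, dtype=float)  # ideal 1D lattice with unit spacing
# Bounded noise keeps each atom closest to its own site
pos = ideal + rng.uniform(-0.3, 0.3, size=ideal.shape)

errors = [np.abs(pos - ideal).max()]
for _ in range(8):
    disp = toy_model(pos)
    pos = pos - disp  # same update as `pos -= disp` in the helper
    errors.append(np.abs(pos - ideal).max())
```

In this toy setting the largest error halves at every step, so `errors` shrinks geometrically. The real model's per-step behavior may differ.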

### Denoise a perturbed FCC snapshot

You can also apply the model to other structures. Just change `TEST_FNAME` to a different structure file.

In [6]:
from pathlib import Path
from scipy.spatial.transform import Rotation

TEST_FNAME = './data/Cu/3400K-fcc.dump'
path = Path(TEST_FNAME)

# Read perturbed data
noisy_atoms = ase.io.read(path)

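# Apply a random rotation to test equivariance (row-vector convention).
# Plain assignment is used because older NumPy releases do not support in-place `@=`.
R = Rotation.random().as_matrix()
noisy_atoms.positions = noisy_atoms.positions @ R
noisy_atoms.set_cell(np.asarray(noisy_atoms.cell) @ R)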

# Get denoising trajectory (positions only)
noise_net.eval(); noise_net.to('cpu')
pos_traj = denoise_snapshot(noisy_atoms, noise_net.model, scale=1.0, steps=8)

# Construct the denoising trajectory including other information (chemical symbols, cell dimensions, etc.)
denoising_traj = [
    ase.Atoms(
        symbols   = noisy_atoms.get_chemical_symbols(),
        positions = pos,
        cell      = noisy_atoms.cell,
        pbc       = True
    )
    for pos in pos_traj
]

# Wrap the atoms into the simulation box
for atoms in denoising_traj: atoms.wrap()

# Save the denoising steps
ase.io.write(path.with_suffix('.denoised_.extxyz'), denoising_traj)
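
The random rotation in the cell above probes equivariance: rotating the input should rotate the predicted displacements by the same matrix. Below is a hedged sketch of such a check using a stand-in function `toy_disp` (an assumption for illustration). With the real model, you would compare the displacements predicted from the original and the rotated snapshots in the same way.

```python
import numpy as np

def toy_disp(pos):
    """Stand-in equivariant 'model': displacement of each atom from the centroid."""
    return pos - pos.mean(axis=0, keepdims=True)

rng = np.random.default_rng(0)
pos = rng.normal(size=(16, 3))

# Random orthogonal matrix from a QR decomposition;
# flip one column if needed so that it is a proper rotation (det = +1)
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
if np.linalg.det(Q) < 0:
    Q[:, 0] *= -1
R = Q

# Row-vector convention, matching `positions @ R` in the cell above
disp_then_rotate = toy_disp(pos) @ R
rotate_then_disp = toy_disp(pos @ R)

max_diff = np.abs(disp_then_rotate - rotate_then_disp).max()
```

For an equivariant function, `max_diff` should be at the level of floating-point round-off.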
