In [1]:
import json
import os
from collections import defaultdict
from glob import glob
from pathlib import Path

import ase
import ase.io
import numpy as np
import torch
from ase import Atoms
from tqdm import tqdm

from graphite.nn import periodic_radius_graph

from atomdiff.models.prior import LitScoreNet
from atomdiff.datasets import *

In [2]:
# Checkpoint of the trained score network that drives the sampling further down.
CHECKPOINT_PATH = './training_logs/stem-vasp-dup128/version_4/checkpoints/epoch=999-step=1250000.ckpt'

# Settings for the data module, gathered in one mapping so they are easy to scan/tune.
data_module_kwargs = dict(
    data_dir='./data/dummy/',
    cutoff=5.0,
    train_prior=True,
    k=0.8,
    train_size=0.9,
    scale_y=1.0,
    dup=1,
    batch_size=1,
    num_workers=4,
)
data_module = StructureDataModule(**data_module_kwargs)

score_net = LitScoreNet.load_from_checkpoint(CHECKPOINT_PATH)

# setup() has to run before `train_set` is read; the training set's diffuser is
# what `prior_score` and `denoise_by_sde` below use for sigma(t), f, g and g2.
data_module.setup()
diffuser = data_module.train_set.diffuser

Getting PyG Data objects...


100%|██████████| 1/1 [00:00<00:00, 48.72it/s]
100%|██████████| 1/1 [00:00<00:00, 982.27it/s]


In [3]:
def prior_score(z, pos, cell, t, cutoff=data_module.cutoff):
    """Evaluate the EMA score network on the periodic radius graph of ``pos``.

    Args:
        z: Per-atom input features, handed to the network unchanged.
        pos: Atom positions used to build the neighbour graph.
        cell: Simulation cell, forwarded to ``periodic_radius_graph``.
        t: Diffusion time; given to the network both directly and as
            ``diffuser.sigma(t)``.
        cutoff: Neighbour cutoff radius. The default is read from
            ``data_module.cutoff`` once, when this function is defined.

    Returns:
        Whatever ``score_net.ema_model`` returns for these inputs
        (used as the score by ``denoise_by_sde``).
    """
    bonds, bond_vec = periodic_radius_graph(pos, cutoff, cell)
    # Edge features: the edge vector with its length stacked on as extra column(s).
    bond_len = bond_vec.norm(dim=-1, keepdim=True)
    bond_feats = torch.hstack((bond_vec, bond_len))
    noise_level = diffuser.sigma(t)
    return score_net.ema_model(z, bonds, bond_feats, t, noise_level)

def denoise_by_sde(z, pos, cell, score_fn, ts=torch.linspace(0.999, 0.001, 128+1), sde=None):
    """Integrate positions along the time grid ``ts`` and return every frame.

    Each step applies
    ``pos += (f(t)*pos - g2(t)*score)*dt + g(t)*sqrt(|dt|)*noise``
    where ``t`` is the end point of the step and ``dt`` is the (signed)
    difference between consecutive entries of ``ts``.

    Args:
        z: Per-atom features, passed through unchanged to ``score_fn``.
        pos: Starting positions. The caller's tensor is NOT modified; the
            integration works on a private copy.
        cell: Simulation cell, passed through unchanged to ``score_fn``.
        score_fn: Callable ``score_fn(z, pos, cell, t)`` returning a tensor
            broadcastable against ``pos``.
        ts: 1-D tensor of time points, visited in the given order. The first
            entry is the start time; each later entry ends one step.
        sde: Object exposing callables ``f(t)``, ``g(t)`` and ``g2(t)``.
            When ``None`` (default) the notebook-level ``diffuser`` is used.

    Returns:
        Tensor stacking ``len(ts)`` frames along a new leading dimension;
        frame 0 is the starting positions. The result carries no autograd graph.
    """
    sde = diffuser if sde is None else sde
    f, g, g2 = sde.f, sde.g, sde.g2
    ts = ts.to(pos.device).view(-1, 1)

    # Sampling never back-propagates. Without no_grad every network call would
    # keep its graph alive through the whole trajectory and bloat (GPU) memory.
    with torch.no_grad():
        # Private copy so the in-place updates below never touch the caller's tensor.
        pos = pos.clone()
        pos_traj = [pos.clone()]
        for t_prev, t in zip(ts[:-1], ts[1:]):
            dt = t - t_prev
            eps = dt.abs().sqrt() * torch.randn_like(pos)
            score = score_fn(z, pos, cell, t)
            pos += (f(t)*pos - g2(t)*score)*dt + g(t)*eps
            pos_traj.append(pos.clone())
    return torch.stack(pos_traj)

### Get all validation structures

In [4]:
path = "./data/stem/vasp/"
# NOTE: despite the name, `cifs` holds the paths of the *.vasp files in `path`.
# glob() yields matches in file-system order, which differs between machines;
# sorting makes the processing order (and any `cifs[:n]` preview) reproducible.
cifs = sorted(glob(path + "*.vasp"))

In [7]:
# Prefer the GPU, but fall back to the CPU so the cell also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
score_net.ema_model.to(device)
# Sampling is inference only: switch off training-time layer behaviour
# (dropout, batch-norm updates, ...) in case the model contains any.
score_net.ema_model.eval()

save_path = "outputs/stem_vasp_ts_512/"
# exist_ok avoids the check-then-create race of `if not exists: makedirs`.
os.makedirs(save_path, exist_ok=True)

# Time grid of the denoising run (512 steps). It is the same for every
# structure, so it is built once instead of once per loop iteration.
denoise_ts = torch.linspace(0.999, 0.001, 512+1)


def _as_atoms(positions, numbers, cell):
    """Wrap one frame of positions (a tensor) in a fully periodic ``ase.Atoms``."""
    return ase.Atoms(
        numbers=numbers,
        positions=positions.detach().cpu().numpy(),
        cell=cell.detach().cpu().numpy(),
        pbc=[True] * 3,
    )


for cif in tqdm(cifs, disable=False):
    sname = Path(cif).stem
    atoms = ase.io.read(cif)

    # Only the number of atoms, the species and the cell of the reference
    # structure are used; its actual positions are replaced by random ones.
    pos = torch.tensor(atoms.positions, dtype=torch.float, device=device)
    cell = torch.tensor(atoms.cell.array, dtype=torch.float, device=device)
    atomic_numbers = atoms.numbers

    # random positions: uniform numbers in [0, 1) per atom, multiplied by the cell matrix
    random_pos = torch.rand_like(pos) @ cell
    # Keep an untouched copy for the "_initial" file: the sampler may update
    # the tensor it is given in place.
    random_pos_clone = random_pos.clone().detach()

    # initialize z
    z = torch.tensor(atomic_numbers, device=device)
    z = data_module.train_set.atom_encoder(z).float()

    # denoise -- no gradients are needed, and tracking them would keep the
    # autograd graph of all 512 network calls alive for every structure.
    with torch.no_grad():
        pos_traj = denoise_by_sde(z, random_pos, cell, prior_score, ts=denoise_ts)
    final_pos = pos_traj[-1]

    # save traj
    denoise_traj = [_as_atoms(frame, atomic_numbers, cell) for frame in pos_traj]
    ase.io.write(os.path.join(save_path, f"{sname}.extxyz"), denoise_traj)

    # save final structure
    ase.io.write(
        os.path.join(save_path, f"{sname}_final.cif"),
        _as_atoms(final_pos, atomic_numbers, cell),
    )

    # save initial structure
    ase.io.write(
        os.path.join(save_path, f"{sname}_initial.cif"),
        _as_atoms(random_pos_clone, atomic_numbers, cell),
    )

100%|██████████| 625/625 [17:26<00:00,  1.67s/it]


In [None]:
# Debug aid: shell out to `nvidia-smi` to inspect GPU utilisation/memory after the run above.
!nvidia-smi

In [None]:
# Get detailed memory statistics
# Snapshot of PyTorch's CUDA memory allocator statistics (a flat name -> value mapping).
stats = torch.cuda.memory_stats()

# Emit one "name: value" line per statistic.
for stat_name in stats:
    print(f"{stat_name}: {stats[stat_name]}")

In [None]:
# Peek at the first five collected structure file paths (sanity check on the glob above).
cifs[:5]