In [1]:
# Show the GPU(s) visible to this kernel (driver/CUDA version, memory, utilisation).
!nvidia-smi

Fri Jun 14 06:41:58 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-PCI...  On   | 00000000:C3:00.0 Off |                    0 |
| N/A   35C    P0    35W / 250W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
from atomdiff.models.prior import LitScoreNet
from atomdiff.datasets import *

In [3]:
import os

# Data module over the structures in data/mp-20-reduced/.
# `cutoff` is read back later as `data_module.cutoff` and used as the default
# neighbour-graph radius of `prior_score`.
data_module = StructureMPDataModule(
    data_dir='data/mp-20-reduced/',
    cutoff=3.0,
    train_prior=True,
    k=0.8,
    train_size=0.9,
    scale_y=1.0,
    dup=1,
    batch_size=1,
    num_workers=4
)

# Constructor call for a fresh, untrained score network. Disabled: the
# trained checkpoint below is loaded instead.
# score_net = LitScoreNet(
#     num_species=82,
#     num_convs=5,
#     dim=200,
#     ema_decay=0.9999,
#     learn_rate=1e-3
# )

# Checkpoint of the trained score network. The default is the original
# absolute path inside one user's home directory; set ATOMDIFF_PRIOR_CKPT to
# point at a local copy instead of editing this cell.
# NOTE(review): the default path says 'mp-2-reduced' while the data directory
# is 'mp-20-reduced' -- confirm this is the intended training run.
PRIOR_CKPT_PATH = os.environ.get('ATOMDIFF_PRIOR_CKPT') or (
    '/global/homes/s/shuyijia/playground/atom-diff/training_logs/mp-2-reduced/version_0/checkpoints/epoch=12-step=400000.ckpt'
)

score_net = LitScoreNet.load_from_checkpoint(PRIOR_CKPT_PATH)

In [4]:
# Print the module tree of the loaded network to inspect its layers.
print(score_net)

LitScoreNet(
  (model): ScoreModel(
    (encoder): Encoder(
      (embed_node): Sequential(
        (0): MLP(dims=[82, 200, 200], act=SiLU())
        (1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
      )
      (embed_edge): Sequential(
        (0): MLP(dims=[4, 200, 200], act=SiLU())
        (1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
      )
      (embed_time): Sequential(
        (0): GaussianRandomFourierFeatures()
        (1): MLP(dims=[200, 200, 200], act=SiLU())
        (2): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
      )
    )
    (processor): Processor(
      (convs): ModuleList(
        (0-4): 5 x MeshGraphNetsConv(
          node_dim=200, edge_dim=200
          (edge_processor): EdgeProcessor(
            (edge_mlp): Sequential(
              (0): MLP(dims=[600, 200, 200, 200], act=SiLU())
              (1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
            )
          )
          (node_processor): NodeProcessor(
     

In [5]:
from atomdiff.utils.pbc import periodic_radius_graph

# Run the data module's setup step. The log printed under this cell
# ("Creating torch_geometric.data.Data objects...") presumably comes from
# this call -- TODO confirm.
data_module.setup()
# Diffuser attached to the training set; the helpers below read its
# sigma(t), f, g and g2 callables.
diffuser = data_module.train_set.diffuser

def prior_score(z, pos, cell, t, cutoff=data_module.cutoff):
    """Evaluate the EMA score network on the periodic neighbour graph of ``pos``.

    Args:
        z: Node input passed as the first argument to the network.
        pos: Positions used to build the periodic radius graph.
        cell: Cell passed to ``periodic_radius_graph`` together with ``pos``.
        t: Time value, passed to the network and to ``diffuser.sigma``.
        cutoff: Radius handed to ``periodic_radius_graph``; defaults to the
            data module's cutoff as it was when this function was defined.

    Returns:
        Whatever ``score_net.ema_model`` returns for these inputs.
    """
    edge_index, displacements = periodic_radius_graph(pos, cutoff, cell)

    # Edge features: the displacement vector followed by its length.
    distances = torch.norm(displacements, dim=-1, keepdim=True)
    edge_features = torch.hstack((displacements, distances))

    noise_level = diffuser.sigma(t)
    return score_net.ema_model(z, edge_index, edge_features, t, noise_level)

def denoise_by_sde(z, pos, cell, score_fn, ts=torch.linspace(0.999, 0.001, 128+1)):
    """Integrate positions along the time grid ``ts`` with stochastic Euler steps.

    Each step applies::

        pos <- pos + (f(t)*pos - g2(t)*score)*dt + g(t)*sqrt(|dt|)*noise

    where ``f``, ``g`` and ``g2`` come from the module-level ``diffuser``,
    ``score = score_fn(z, pos, cell, t)``, ``dt`` is the signed difference
    between consecutive time points and ``noise`` is standard normal.

    Args:
        z: Forwarded unchanged to ``score_fn``.
        pos: Starting positions (tensor). The caller's tensor is NOT modified;
            every step produces a new tensor.
        cell: Forwarded unchanged to ``score_fn``.
        score_fn: Callable ``(z, pos, cell, t)`` returning a tensor that
            broadcasts against ``pos``.
        ts: Time points in integration order. Defaults to 129 points from
            0.999 down to 0.001.

    Returns:
        Tensor of shape ``(len(ts), *pos.shape)``: the starting positions
        followed by the positions after each step.
    """
    ts = ts.to(pos.device).view(-1, 1)
    f, g, g2 = diffuser.f, diffuser.g, diffuser.g2

    pos_traj = [pos]
    # NOTE(review): score and coefficients are evaluated at the END time of
    # each interval (t), as in the original code -- confirm this is intended.
    for t_prev, t in zip(ts[:-1], ts[1:]):
        dt = t - t_prev
        eps = dt.abs().sqrt() * torch.randn_like(pos)
        score = score_fn(z, pos, cell, t)
        # Out-of-place update: an in-place `pos += ...` would overwrite the
        # tensor the caller passed in.
        pos = pos + (f(t)*pos - g2(t)*score)*dt + g(t)*eps
        pos_traj.append(pos)
    return torch.stack(pos_traj)

2059
Creating torch_geometric.data.Data objects...


100%|██████████| 2059/2059 [00:00<00:00, 2686.30it/s]


In [6]:
# Training DataLoader; a later cell reaches `dataset.atom_encoder` through it.
train_loader = data_module.train_dataloader()

In [24]:
import ase.io

# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
path = "./data/mp-20-reduced/raw_train/mp-1210001.cif"

# Reference structure: supplies the atomic numbers and the cell used for
# generation. Only the shape of `positions` is used by the generation cell.
atoms = ase.io.read(path)

atomic_numbers = atoms.numbers
positions = torch.tensor(atoms.positions, dtype=torch.float)
cell = torch.tensor(atoms.cell.array, dtype=torch.float)

# Encode the atomic numbers (as a column) with the training dataset's encoder.
z = train_loader.dataset.atom_encoder.transform(atomic_numbers.reshape(-1, 1))
z = torch.tensor(z, dtype=torch.float)

positions = positions.to(device)
cell = cell.to(device)
z = z.to(device)

In [26]:
import random
import ase
from tqdm.notebook import trange

# Setup
NUM_GEN = 1       # number of generations/cells

# Move the EMA network to the sampling device and switch it to evaluation
# mode, so any train-only layers (e.g. dropout) are disabled while sampling.
score_net.ema_model.to(device)
score_net.ema_model.eval()

priors = []
with torch.no_grad():
    for i in trange(NUM_GEN):

        # Start with a random structure: uniform fractional coordinates
        # mapped into the cell.
        pos = torch.rand(positions.shape).to(device) @ cell

        # Denoise over 64 steps; `pos_traj` keeps every intermediate frame
        # and is read again by the export cell below.
        pos_traj = denoise_by_sde(z, pos, cell, prior_score, ts=torch.linspace(0.999, 0.001, 64+1))
        pos = pos_traj[-1]

        # Convert generation to ASE Atoms, wrapped back into the cell
        atoms = ase.Atoms(numbers=atomic_numbers, positions=pos.detach().cpu().numpy(), cell=cell.cpu().numpy(), pbc=[True]*3)
        atoms.wrap()
        priors.append(atoms)

  0%|          | 0/1 [00:00<?, ?it/s]

In [31]:
import os

# Write every frame of `pos_traj` (the trajectory left over from the last
# generation) as its own CIF file, named by step index.
# Create the target directory first; the write below fails if it is missing.
os.makedirs('./outputs/mp-20-reduced', exist_ok=True)

# NOTE(review): frames are written without atoms.wrap(), unlike the final
# structures stored in `priors` -- confirm this is intended.
for i, _pos in enumerate(pos_traj):
    atoms = ase.Atoms(numbers=atomic_numbers, positions=_pos.detach().cpu().numpy(), cell=cell.cpu().numpy(), pbc=[True]*3)
    ase.io.write(f'./outputs/mp-20-reduced/{i}.cif', atoms)