In [None]:
!pip install nglview ipywidgets --quiet

In [4]:
import torch
from omegaconf import OmegaConf
from nablaDFT.optimization import PYGAseInterface

In [59]:
# Output directory for the optimization run, input geometry and model checkpoint.
workdir = "./optimize"
molecule_path = "./moses_7570.xyz"
ckpt_path = "../checkpoints/GemNet-OC/GemNet-OC_100k.ckpt"

# Load the model YAML and nest it under a top-level "model" key.
model_cfg = OmegaConf.create(dict(model=OmegaConf.load("../config/model/gemnet-oc.yaml")))

In [60]:
# Fall back to the CPU when no CUDA device is visible, so this cell does not
# fail on machines without a GPU. Behaviour on GPU machines is unchanged.
# NOTE(review): assumes PYGAseInterface accepts "cpu" as a device string — confirm.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Build the optimizer interface from the molecule file, the model config and
# the checkpoint; it logs that it loads the model from `ckpt_path`.
optimizer = PYGAseInterface(
    molecule_path=molecule_path,
    working_dir=workdir,
    config=model_cfg,
    ckpt_path=ckpt_path,
    energy_key="energy",
    force_key="forces",
    energy_unit="eV",
    position_unit="Ang",
    device=device,
    dtype=torch.float32,
)

INFO:nablaDFT.optimization.pyg_ase_interface:Loading model from ../checkpoints/GemNet-OC/GemNet-OC_100k.ckpt
/mnt/2tb/ber/miniconda3/envs/nablaDFT/lib/python3.9/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'metric' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['metric'])`.
INFO:root:Restore model weights from ../checkpoints/GemNet-OC/GemNet-OC_100k.ckpt


In [61]:
# Run the geometry optimization (the log shows BFGSLineSearch steps).
# NOTE(review): presumably stops once the max force falls below `fmax` or after
# `steps` iterations, following the ASE optimizer convention — confirm.
optimizer.optimize(steps=100, fmax=1e-4)

                Step[ FC]     Time          Energy          fmax
BFGSLineSearch:    0[  0] 14:00:28       -5.322232        0.1936
BFGSLineSearch:    1[  3] 14:00:29       -5.359284        0.1254
BFGSLineSearch:    2[  8] 14:00:30       -5.364477        0.0777
BFGSLineSearch:    3[ 42] 14:00:39       -5.365033        0.0655
BFGSLineSearch:    4[ 64] 14:00:45       -5.365070        0.0558
BFGSLineSearch:    5[ 80] 14:00:49       -5.365069        0.0558
BFGSLineSearch:    6[110] 14:00:57       -5.365223        0.0424
BFGSLineSearch:    7[126] 14:01:01       -5.365287        0.0307
BFGSLineSearch:    8[148] 14:01:06       -5.365572        0.0285
BFGSLineSearch:    9[176] 14:01:14       -5.365817        0.0263
BFGSLineSearch:   10[195] 14:01:18       -5.366064        0.0279
BFGSLineSearch:   11[213] 14:01:23       -5.366118        0.0395
BFGSLineSearch:   12[216] 14:01:23       -5.367243        0.0452
BFGSLineSearch:   13[218] 14:01:24       -5.372262        0.0454
BFGSLineSearch:   14[220]

INFO:nablaDFT.optimization.pyg_ase_interface:Save molecule in optimization


## Visualize

In [65]:
from ase.io.trajectory import Trajectory
from nglview import show_asetraj

In [66]:
# Build the trajectory path from `workdir` (the directory passed to the
# optimizer as `working_dir`) instead of repeating the literal, so changing
# `workdir` cannot leave this cell reading a stale or missing file.
traj = Trajectory(f"{workdir}/optimization.traj")
# Last expression of the cell: displays the interactive NGL trajectory viewer.
show_asetraj(traj)

NGLWidget(max_frame=100)