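# POS-EGNN

This notebook demonstrates the POS-EGNN model through its ASE calculator:
- extracting learned embeddings for a molecule;
- predicting its energy and forces;
- running a short Langevin molecular-dynamics simulation, visualized with nglview.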

## Setup

In [1]:
# Uncomment to install the notebook-only (visualization) dependencies.
# `%pip` installs into the running kernel's environment, unlike `!pip`,
# which may target a different Python interpreter.
# %pip install nglview ipywidgets

In [2]:
from pathlib import Path

import nglview as nv
import numpy as np
import torch
from ase import units
from ase.io import read
from ase.md.langevin import Langevin

from posegnn.calculator import PosEGNNCalculator



In [3]:
# Let PyTorch pick faster, reduced-precision internal algorithms for float32
# matrix multiplications where the backend supports them
# (see torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision("high")

# Device used for the calculator, the model input and all inference below.
device = "cpu"

## Feature Extraction

In [4]:
# Download the checkpoint first: https://huggingface.co/ibm-research/materials.pos-egnn
CHECKPOINT_PATH = "pos-egnn.v1-6M.ckpt"
STRUCTURE_PATH = "inputs/3BPA.xyz"

calculator = PosEGNNCalculator(CHECKPOINT_PATH, device=device, compute_stress=False)
atoms = read(STRUCTURE_PATH, index=0)  # first frame of the file only

# Build the model input once and keep the device-resident copy: `data` is reused
# by the inference cell below, so it must live on the same device as the model.
# NOTE(review): `_build_data` is a private calculator helper and may change
# between posegnn versions.
data = calculator._build_data(atoms).to(device)

# Select index 1 of the last axis of "embedding_0", then drop axis 2.
# NOTE(review): the meaning of these axes is defined by posegnn, not here; confirm
# against the model docs. Per the recorded output the result is (27, 256),
# presumably one 256-dimensional embedding per atom.
embeddings = calculator.model.forward(data)["embedding_0"][:, :, :, 1].squeeze(2)

embeddings.shape

torch.Size([27, 256])

## Inference

In [5]:
# Predict energy and forces directly from the model, with stress disabled.
out = calculator.model.compute_properties(data, compute_stress=False)
energy, force = out["total_energy"], out["force"]

In [6]:
# Display the raw model outputs from the cell above. The energy tensor still
# carries autograd history (note the grad_fn in its repr), so call .detach()
# before converting it to NumPy or keeping it around.
energy, force

(tensor([-175.0519], grad_fn=<AddBackward0>),
 tensor([[ 0.3428, -0.4197,  0.7246],
         [-0.8685, -0.1219, -2.3050],
         [ 0.2631,  0.0661,  0.8548],
         [-0.2307,  0.0230, -0.5161],
         [-0.4390,  2.7678, -0.7030],
         [ 0.0393, -0.5039,  1.0452],
         [ 0.3763, -2.2708, -0.7662],
         [ 0.2588, -1.6086, -0.0870],
         [-0.0932, -0.2467, -0.4807],
         [ 0.0185,  1.0018,  2.1512],
         [-0.4606,  1.3631, -0.3847],
         [ 0.3861, -0.3217,  0.6269],
         [-0.2910,  0.2251, -0.2673],
         [ 1.3340, -1.7278, -0.0881],
         [-0.9644,  1.1447,  1.0665],
         [-0.7468,  0.5678,  0.0310],
         [ 0.4204,  0.7406, -0.6954],
         [-0.2565,  0.2528,  0.2541],
         [ 2.0052, -0.3826, -0.2691],
         [-0.0074,  0.4379, -0.2768],
         [ 0.6456, -0.5602, -0.1124],
         [-0.0060, -1.0381,  0.2364],
         [-0.0415,  0.0296, -0.0675],
         [-0.8607,  0.0030,  0.0678],
         [-0.0546,  0.0535, -0.0684],
    

In [7]:
# Repeat the prediction through the standard ASE calculator interface.
# Use get_potential_energy(): ASE's get_total_energy() also adds the atoms'
# kinetic energy, so it only agrees with the model's energy while no momenta
# are set (e.g. it would differ after the MD run below).
atoms.calc = calculator
atoms.get_potential_energy(), atoms.get_forces()

(array([-175.05188], dtype=float32),
 array([[ 0.34280193, -0.4196788 ,  0.72462475],
        [-0.8685478 , -0.1218636 , -2.3050241 ],
        [ 0.26306948,  0.06607056,  0.8547611 ],
        [-0.23073699,  0.0230464 , -0.51613975],
        [-0.4390195 ,  2.7678292 , -0.70297706],
        [ 0.03933173, -0.50390184,  1.0451794 ],
        [ 0.3762841 , -2.2708364 , -0.7662431 ],
        [ 0.2588454 , -1.6086005 , -0.08700079],
        [-0.0931953 , -0.24666795, -0.48069412],
        [ 0.0184921 ,  1.0017664 ,  2.1512074 ],
        [-0.46055824,  1.3630679 , -0.3847049 ],
        [ 0.38605824, -0.32170495,  0.6269282 ],
        [-0.29103592,  0.22509198, -0.26729935],
        [ 1.3340424 , -1.7278178 , -0.08812346],
        [-0.9644211 ,  1.1447096 ,  1.0665404 ],
        [-0.7467996 ,  0.56782794,  0.03098067],
        [ 0.42040414,  0.74056077, -0.69537413],
        [-0.25654244,  0.2528266 ,  0.25414667],
        [ 2.0051587 , -0.382573  , -0.26911476],
        [-0.0074314 ,  0.4378659

## Molecular Dynamics Simulation

In [8]:
# Langevin MD at 310 K with a 0.5 fs timestep, writing every step to an XYZ file.
# NOTE(review): ASE interprets `friction` in inverse ASE time units, not fs^-1;
# confirm 0.005 is the intended coupling strength.
TRAJECTORY_FILE = "output.xyz"
MD_SEED = 42  # fixes the thermostat noise so the trajectory is reproducible
N_STEPS = 100

# Frames are appended below, so remove any file left by a previous run;
# otherwise a re-run silently extends the old trajectory.
Path(TRAJECTORY_FILE).unlink(missing_ok=True)

dyn = Langevin(
    atoms=atoms,
    friction=0.005,
    temperature_K=310,
    timestep=0.5 * units.fs,
    rng=np.random.default_rng(MD_SEED),
)


def write_frame():
    """Append the dynamics' current atomic configuration to TRAJECTORY_FILE."""
    dyn.atoms.write(TRAJECTORY_FILE, append=True)


dyn.attach(write_frame, interval=1)
dyn.run(N_STEPS)

True

In [9]:
# Load every frame written during the MD run (index=":" selects all frames)
# and render them as an interactive nglview trajectory.
traj = read("output.xyz", index=":")
view = nv.show_asetraj(traj)
display(view)

NGLWidget(max_frame=100)