In [1]:
#!/usr/bin/env python3
import argparse
import numpy as nps
from ase import Atoms
from ase.io import read, write
from ase.optimize import *
import torch
from spo.spookynet import SpookyNet
"""
Spookynet: Learning force fields with electronic degrees of freedom and nonlocal effects.
OT Unke, S Chmiela, M Gastegger, KT Schütt, HE Sauceda, KR Müller
Nat. Commun. 12(1), 2021, 1-14, (2021).
"""


'\nSpookynet: Learning force fields with electronic degrees of freedom and nonlocal effects.\nOT Unke, S Chmiela, M Gastegger, KT Schütt, HE Sauceda, KR Müller\nNat. Commun. 12(1), 2021, 1-14, (2021).\n'

In [2]:
import torch
import pickle

# Load the saved DataLoader
filename = 'dataloader.pkl'
with open(filename, 'rb') as f:
    loaded_dataloader = pickle.load(f)

In [3]:
spookyNet_model = 'prueba.pth'
charge = 1
magmom = 0 
nbeads = 10
use_gpu = True

In [4]:
torch.cuda.get_device_name(0)

'NVIDIA GeForce RTX 3060'

In [5]:
# load SpookyNet model
try:
    model = SpookyNet(load_from=spookyNet_model)
    model.to(torch.float32)
    model.eval()
    print(
        " @ForceField: SpookyNet model " + spookyNet_model + " loaded"
    )
except ValueError:
    raise ValueError(
        "ERROR: Reading SpookyNet model " + spookyNet_model + " file failed."
    )

if use_gpu and not torch.cuda.is_available():
    print(
        " @ForceField: No GPU available: Evaluation on GPU was requested"
        + " but no GPU device is available. Falling back to CPU."
    )

if use_gpu and torch.cuda.is_available():
    model.cuda()
    device = 'cuda:0'
    
else:
    device = 'cpu'
print(device)

 @ForceField: SpookyNet model prueba.pth loaded
cuda:0


In [6]:
enegias = []

In [7]:
j = 0
for batch in loaded_dataloader:
    # Process the batch as needed
    datos = batch  # Unpack the batch into inputs and targets
    nxyz = datos['nxyz']
    xyz = nxyz[:, 1:]
    xyz = xyz.numpy()
    len_atoms = nxyz.shape[0]
    nu_atoms = nxyz[:, 0]

    nu = datos['num_atoms']
    # generate idx lists for finding neighboring atoms
    idx = torch.arange(len_atoms, dtype=torch.int64, device=device)
    idx_i = idx.view(-1, 1).expand(-1, len_atoms).reshape(-1)
    idx_j = idx.view(1, -1).expand(len_atoms, -1).reshape(-1)

    idx_i = datos['mol_nbrs'][:, 0].to(device)
    idx_j = datos['mol_nbrs'][:, 1].to(device)

    # exclude self-interactions
    idx_i, idx_j = idx_i[idx_i != idx_j], idx_j[idx_i != idx_j]

    batch_seg = torch.cat([torch.ones(int(num_atoms)) * i for i, num_atoms in
                               enumerate(nu)]).to(torch.long).to(device)
    # create input dictionary
    inputs2 = {
        "Z": torch.tensor(nu_atoms, dtype=torch.int64, device=device),
        "Q": torch.full((nbeads,), charge, dtype=torch.float32, device=device),
        "S": torch.full((nbeads,), magmom, dtype=torch.float32, device=device),
        "idx_i": torch.cat([idx_i + i * len_atoms for i in range(1)]),
        "idx_j": torch.cat([idx_j + i * len_atoms for i in range(1)]),
        "batch_seg": batch_seg,
        "num_batch": 10
    }
    R = torch.tensor(xyz,
                    dtype=torch.float32,
                    device=device)
    R.requires_grad = True
    inputs2["R"] = R

    outputs = model.energy_and_forces(**inputs2)
    energy = outputs[0].detach().cpu().numpy()
    forces = outputs[1].detach().cpu().numpy().reshape((1, -1, 3))
    enegias.append(energy)
    print('energy : ',energy, 'forces : ', forces)


  "Z": torch.tensor(nu_atoms, dtype=torch.int64, device=device),


torch.Size([170]) torch.Size([10]) torch.Size([10]) torch.Size([170, 3])
energy :  [-0.0841921   0.08180042  0.01822178 -0.01216138 -0.02671169  0.05273353
 -0.05380144 -0.02578509 -0.07193723  0.06553777] forces :  [[[ 2.22766146e-01  3.08526188e-01  3.72679234e-01]
  [ 9.88022983e-02  2.18968853e-01 -8.25225264e-02]
  [ 3.55656818e-03  1.91873506e-01  4.96921539e-02]
  [-1.29688099e-01  1.00383125e-01 -6.09104782e-02]
  [-2.34301805e-01  5.86121231e-02  4.47650850e-02]
  [ 1.50685370e-01 -2.29977831e-01 -3.62922326e-02]
  [-2.54732788e-01 -1.33506417e-01  7.89312720e-02]
  [-3.46375048e-01 -8.12287629e-02 -2.80671179e-01]
  [ 1.61254644e-01 -3.99302170e-02  1.19033769e-01]
  [-2.01691300e-01 -3.17947268e-01  3.21236163e-01]
  [ 2.23876700e-01  1.71944097e-01 -5.76051056e-01]
  [-2.13019475e-01  8.27216133e-02 -3.00275415e-01]
  [ 4.45081353e-01  6.51166812e-02  6.84519708e-02]
  [ 1.09674528e-01  8.81208926e-02  7.01455623e-02]
  [-9.59624425e-02 -1.08761480e-02 -2.63513476e-01]
  [ 

KeyboardInterrupt: 