In [None]:
import os

import torch
from tqdm import tqdm
from nequip.utils.config import Config
from nequip.data import dataset_from_config, DataLoader, AtomicData
from litraj.data import download_dataset

In [None]:
# Create the download target directory. exist_ok=True keeps this cell safe to
# re-run (e.g. after "Restart & Run All") when ./data already exists.
os.makedirs('data', exist_ok=True)

In [None]:
# Fetch the BVEL13k dataset with litraj; the cells below read its
# train/val/test .xyz files from ./data/BVEL13k/.
download_dataset('BVEL13k', 'data') # download the dataset to ./data folder

In [26]:
# Load the Allegro training configuration from YAML, then point its training
# and validation dataset entries at the files downloaded into ./data.
config = Config.from_file('../configs/allegro_BVEL13k_E3D.yaml')
config['dataset_file_name'] = './data/BVEL13k/BVEL13k_train.xyz'
config['validation_dataset_file_name'] = './data/BVEL13k/BVEL13k_val.xyz'

In [27]:
# Experiment-tracking settings. The entity/project values below are
# placeholders; presumably they only matter once 'wandb' is set to True.
config['wandb'] = False              # set to True to use wandb
config['wandb_entity'] = 'your_entity'
config['wandb_project'] = 'your_project'

In [None]:
# Write the adjusted configuration to config.yaml in the working directory;
# this is the file handed to nequip-train in the next cell.
config.save('config.yaml')

In [None]:
# Run training through the nequip-train command-line tool using config.yaml.
!nequip-train config.yaml

In [None]:
# Deploy: export the trained model to a standalone file (best_model.pt) that is loaded below for inference

In [None]:
# Build the deployed model file best_model.pt from the training run directory.
# NOTE(review): the --train-dir value presumably has to match the output
# directory produced by the training run above — confirm against the YAML config.
!nequip-deploy build --train-dir allegro_BVEL13k_e3d_rmax_7_lr_005_reproduce best_model.pt

In [None]:
# Load the deployed model written by nequip-deploy as a TorchScript module.
model_path = 'best_model.pt'
model = torch.jit.load(model_path)

In [None]:
# Repoint the dataset entry at the test split and build the test dataset.
# NOTE(review): this overwrites 'dataset_file_name' in the shared `config`
# object, so after this cell `config` no longer refers to the training file;
# re-running the save/train cells without reloading the config would use the
# test split.
config['dataset_file_name'] = './data/BVEL13k/BVEL13k_test.xyz'
# prefix="dataset" presumably selects the `dataset_*` keys of the config — TODO confirm.
dataset = dataset_from_config(config, prefix="dataset")

In [None]:
# Evaluate the deployed model on the test set, one structure per batch, and
# convert the summed quantity the model was trained on back to a per-site value.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

# batch_size must stay 1: the site count of each structure is taken from the
# number of rows of `pos`, which for larger batches would count the whole batch.
dataloader = DataLoader(
        dataset=dataset,
        shuffle=False,
        batch_size=1)

# Collect per-structure results in lists and concatenate once after the loop.
# Calling torch.cat on a growing tensor in every iteration re-copies all
# previous results and relied on dtype promotion against an empty tensor.
pred_chunks = []
target_chunks = []
site_counts = []  # we were training on E_3D_x_nsites


def _concat_flat(chunks):
    """Join per-structure tensors into one 1-D tensor (empty if there are none)."""
    if not chunks:
        return torch.tensor([])
    return torch.cat(chunks).reshape(-1)


with torch.no_grad():
    for batch in tqdm(dataloader):
        batch = batch.to(device)
        data = AtomicData.to_AtomicDataDict(batch)
        # Read the reference value before the forward pass, in case the model
        # stores its prediction under the same key of the dict it is given.
        target_chunks.append(data['total_energy'].detach().cpu())
        site_counts.append(data['pos'].shape[0])
        result = model(data)
        pred_chunks.append(result['total_energy'].detach().cpu())

nsites = torch.tensor(site_counts, dtype=torch.float32)
targets = (_concat_flat(target_chunks) / nsites).numpy() # convert to E_3D
preds = (_concat_flat(pred_chunks) / nsites).numpy()     # convert to E_3D

In [None]:
import matplotlib.pyplot as plt

# Parity plot of predicted vs. reference values: the black line is the
# identity (targets against themselves), the scatter shows the predictions.
fig, ax = plt.subplots(dpi=200, figsize=(3, 3))
ax.plot(targets, targets, color='k', linewidth=0.75, zorder=-1)
ax.scatter(targets, preds, s=10, alpha=0.5)
ax.set_xlabel('$E_a^{3D}$(BVEL), eV')
ax.set_ylabel('$E_a^{3D}$(Allegro), eV')
fig.tight_layout()

In [None]:
from litraj.metrics import get_metrics

# Summarize the agreement between reference and predicted per-site values.
# NOTE(review): argument order is (targets, preds); presumably this matches
# get_metrics' (true, predicted) convention — confirm against litraj.metrics.
get_metrics(targets, preds)