## Load Data

In [1]:
# Enable IPython's autoreload extension in mode 2 so that imported modules
# are reloaded before each cell runs (edits to project code are picked up
# without restarting the kernel).
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from dipeptides.data import XYZData

# Cutoff value shared by the data module here and the model builders below.
cutoff = 4.0
edata = XYZData(batch_size=4, cutoff=cutoff)
atomic_nums = []  # not used in the visible cells; kept in case later cells read it

# Take a single validation batch. It is displayed as this cell's output and
# reused below for the models' first forward pass. next(iter(...)) fails
# immediately (StopIteration) if the loader is empty, instead of leaving
# `batch` undefined.
batch = next(iter(edata.val_dataloader()))
batch

  warn(


Batch(atomic_numbers=[187], batch=[187], bec=[187, 9], cell=[12, 3], dipole=[4, 3], edge_index=[2, 3078], energy=[4], force=[187, 3], hirsh_charges=[187], hirsh_dipole=[12], hirsh_quadrupole=[12, 3], mbi_charges=[187], mbi_dipole=[12], mbi_quadrupole=[12, 3], mul_charges=[187], mul_dipole=[12], mul_quadrupole=[12, 3], positions=[187, 3], pred_bec=[187, 9], pred_charges=[187], pred_dipole=[12], pred_quadrupole=[12, 3], ptr=[5], quadrupole=[4, 3, 3], shifts=[3078, 3], unit_shifts=[3078, 3])

## Load Model

In [3]:
import torch
from cace.tasks import GetLoss
from cace.tools import Metrics

# ---- Losses ---------------------------------------------------------------
# Mean-squared error on energy and force; the force term carries a weight of
# 1000 versus 1 for the energy term.
e_loss = GetLoss(
    target_name="energy", predict_name="pred_energy",
    loss_fn=torch.nn.MSELoss(), loss_weight=1,
)
f_loss = GetLoss(
    target_name="force", predict_name="pred_force",
    loss_fn=torch.nn.MSELoss(), loss_weight=1000,
)
losses = [e_loss, f_loss]

# ---- Metrics --------------------------------------------------------------
# RMSE for both quantities; only the energy metric is requested per atom.
e_metric = Metrics(
    target_name="energy", predict_name="pred_energy",
    name="e", metric_keys=["rmse"], per_atom=True,
)
f_metric = Metrics(
    target_name="force", predict_name="pred_force",
    name="f", metric_keys=["rmse"],
)
metrics = [e_metric, f_metric]

In [4]:
from cace.tasks import LightningData, LightningTrainingTask
from dipeptides.model import make_cace_lr

# SR model: same builder as the LR model, but with lr=False.
sr_model = make_cace_lr(cutoff=cutoff, lr=False)
sr_model.cuda()
# One forward pass on the sample batch before the checkpoint is loaded
# (presumably needed to initialise the model's parameters -- TODO confirm).
sr_model(batch.cuda())

# Wrap the model in a training task, then load the stored checkpoint into it.
sr_task = LightningTrainingTask(
    sr_model,
    losses=losses,
    metrics=metrics,
    logs_directory="model_runs/lightning_logs",
    name="test",
    scheduler_args=dict(mode="min", factor=0.8, patience=10),
    optimizer_args=dict(lr=0.01),
)
chkpt = "models/sr-model.ckpt"
sr_task.load(chkpt)
sr_task.model.cuda();

Loading model from models/sr-model.ckpt ...
Loading successful!


In [5]:
from cace.tasks import LightningData, LightningTrainingTask
from dipeptides.model import make_cace_lr

# LR model: same builder as the SR model, but with lr=True.
lr_model = make_cace_lr(cutoff=cutoff, lr=True)
lr_model.cuda()
# One forward pass on the sample batch before the checkpoint is loaded
# (presumably needed to initialise the model's parameters -- TODO confirm).
lr_model(batch.cuda())

# Wrap the model in a training task, then load the stored checkpoint into it.
lr_task = LightningTrainingTask(
    lr_model,
    losses=losses,
    metrics=metrics,
    logs_directory="model_runs/lightning_logs",
    name="test",
    scheduler_args=dict(mode="min", factor=0.8, patience=10),
    optimizer_args=dict(lr=0.01),
)
chkpt = "models/lr-model.ckpt"
lr_task.load(chkpt)
lr_task.model.cuda();

Loading model from models/lr-model.ckpt ...
Loading successful!


## Energy & Force Errors

In [6]:
# NOTE:
# Confirmed to work with https://github.com/dking072/cace.git
# at commit ad93b84a298cc2ef280adda64fdef2e3c31f8ece.
# Later adjustments to the Ewald code seem to break the long-range part.

In [7]:
import pandas as pd

# Energy / force RMSE of the LR and SR models on the test split.
#
# Predictions and reference values are collected batch by batch, stacked into
# single tensors, and then passed to the metric objects defined earlier.
pred_keys = ("pred_energy", "pred_force")
target_keys = ("energy", "force")

lr_parts = {k: [] for k in pred_keys}
sr_parts = {k: [] for k in pred_keys}
target_parts = {k: [] for k in target_keys}
index_parts = []   # per-atom structure index, made global across batches
n_structs_seen = 0

for batch in edata.test_dataloader():
    # Re-bind so this works whether .cuda() moves in place or returns a copy.
    batch = batch.cuda()
    # Gradients are intentionally left enabled for the forward pass
    # (presumably the force prediction needs autograd -- TODO confirm).
    lr_out = lr_task.model.model.forward(batch, training=False)
    sr_out = sr_task.model.model.forward(batch, training=False)
    for k in pred_keys:
        # detach() so the stored tensors do not keep each batch's autograd
        # graph alive on the GPU.
        lr_parts[k].append(lr_out[k].detach())
        sr_parts[k].append(sr_out[k].detach())
    for k in target_keys:
        target_parts[k].append(batch[k])
    # batch["batch"] numbers structures from 0 within each batch; shift it by
    # the structures already seen so it lines up with the stacked energies.
    # This replaces re-loading the whole dataset with a huge batch size, which
    # also overwrote the global `edata`.
    # Assumes batch["energy"] holds one entry per structure -- TODO confirm.
    index_parts.append(batch["batch"] + n_structs_seen)
    n_structs_seen += batch["energy"].shape[0]

# Same stacking as before: hstack for energies, vstack for forces.
lr_all = {
    "pred_energy": torch.hstack(lr_parts["pred_energy"]),
    "pred_force": torch.vstack(lr_parts["pred_force"]),
}
sr_all = {
    "pred_energy": torch.hstack(sr_parts["pred_energy"]),
    "pred_force": torch.vstack(sr_parts["pred_force"]),
}
batch_all = {
    "energy": torch.hstack(target_parts["energy"]),
    "force": torch.vstack(target_parts["force"]),
    "batch": torch.hstack(index_parts),
}

# Rows: E (energy), F (force); one column per model.
df = pd.DataFrame(
    {
        "CACE-LR test": {
            "E": e_metric(lr_all, batch_all)["rmse"].item(),
            "F": f_metric(lr_all, batch_all)["rmse"].item(),
        },
        "CACE-SR test": {
            "E": e_metric(sr_all, batch_all)["rmse"].item(),
            "F": f_metric(sr_all, batch_all)["rmse"].item(),
        },
    }
)
torch.cuda.empty_cache()
df

Unnamed: 0,CACE-LR test,CACE-SR test
E,0.001878,0.002354
F,0.061126,0.07243


In [6]:
import pandas as pd

# Energy / force RMSE of the LR and SR models on the validation split.
#
# Predictions and reference values are collected batch by batch, stacked into
# single tensors, and then passed to the metric objects defined earlier.
pred_keys = ("pred_energy", "pred_force")
target_keys = ("energy", "force")

lr_parts = {k: [] for k in pred_keys}
sr_parts = {k: [] for k in pred_keys}
target_parts = {k: [] for k in target_keys}
index_parts = []   # per-atom structure index, made global across batches
n_structs_seen = 0

for batch in edata.val_dataloader():
    # Re-bind so this works whether .cuda() moves in place or returns a copy.
    batch = batch.cuda()
    # NOTE(review): the test-split cell calls
    # `task.model.model.forward(batch, training=False)` instead; confirm the
    # two call paths are equivalent for evaluation.
    lr_out = lr_task.model(batch)
    sr_out = sr_task.model(batch)
    for k in pred_keys:
        # detach() so the stored tensors do not keep each batch's autograd
        # graph alive on the GPU.
        lr_parts[k].append(lr_out[k].detach())
        sr_parts[k].append(sr_out[k].detach())
    for k in target_keys:
        target_parts[k].append(batch[k])
    # batch["batch"] numbers structures from 0 within each batch; shift it by
    # the structures already seen so it lines up with the stacked energies.
    # This replaces re-loading the whole dataset with a huge batch size, which
    # also overwrote the global `edata`.
    # Assumes batch["energy"] holds one entry per structure -- TODO confirm.
    index_parts.append(batch["batch"] + n_structs_seen)
    n_structs_seen += batch["energy"].shape[0]

# Same stacking as before: hstack for energies, vstack for forces.
lr_all = {
    "pred_energy": torch.hstack(lr_parts["pred_energy"]),
    "pred_force": torch.vstack(lr_parts["pred_force"]),
}
sr_all = {
    "pred_energy": torch.hstack(sr_parts["pred_energy"]),
    "pred_force": torch.vstack(sr_parts["pred_force"]),
}
batch_all = {
    "energy": torch.hstack(target_parts["energy"]),
    "force": torch.vstack(target_parts["force"]),
    "batch": torch.hstack(index_parts),
}

# Rows: E (energy), F (force); one column per model.
df = pd.DataFrame(
    {
        "CACE-LR val": {
            "E": e_metric(lr_all, batch_all)["rmse"].item(),
            "F": f_metric(lr_all, batch_all)["rmse"].item(),
        },
        "CACE-SR val": {
            "E": e_metric(sr_all, batch_all)["rmse"].item(),
            "F": f_metric(sr_all, batch_all)["rmse"].item(),
        },
    }
)
torch.cuda.empty_cache()
df

Unnamed: 0,CACE-LR val,CACE-SR val
E,0.001289,0.001969
F,0.053149,0.058815
