In [1]:
import os
import sys
import random
import multiprocessing as mp
# combine imports for read
from ase.io import read
import torch
from torch import compile as torch_compile
from torch.func import vjp, vmap
from torch.nn import MSELoss
from torch.optim.swa_utils import AveragedModel
from torch.utils.data import DataLoader
from torch_geometric.data import Batch
from torch_geometric.nn.pool import radius_graph
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import numpy as np
from loaders.loaders import AtomsToGraphs, collate_fn_e2gnn, move_batch, AtomsDataset
from sklearn.metrics import mean_absolute_error, mean_squared_error
# Add E2GNN source path
sys.path.append(os.path.join(os.getcwd(), 'E2GNN'))
from E2GNN import E2GNN

cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.


In [2]:
def evaluate_full(loader, mdl, device, hess_rows=256, seed=None):
    """Collect true/predicted energies, forces and sampled Hessian rows.

    Each structure in ``loader`` is evaluated on its own (as a batch of one)
    so that the force output can be differentiated with respect to that
    structure's positions. Predicted Hessian rows are obtained one at a time
    as the negated gradient of a single force component w.r.t. positions.

    Args:
        loader: Iterable of batches that expose ``to_data_list()``.
        mdl: Model called as ``mdl(batch)`` returning ``(energy, forces)``.
            It is switched to eval mode and left in that mode.
        device: Device handed to ``move_batch`` for every structure.
        hess_rows: Maximum number of Hessian rows sampled per structure. When
            a structure has fewer force components than this, all rows are
            used (in random order).
        seed: Optional integer seed for the row sampling. With the default
            ``None`` the rows are drawn from torch's global RNG, so results
            differ between runs whenever fewer than all rows are sampled.

    Returns:
        Tuple ``(e_true, e_pred, f_true, f_pred, h_true, h_pred)`` of flat
        Python lists of floats. The Hessian lists hold whole rows
        concatenated in sampling order.
    """
    mdl.eval()
    e_true, e_pred = [], []
    f_true, f_pred = [], []
    h_true, h_pred = [], []

    # A dedicated CPU generator makes the sampled rows reproducible without
    # touching the global RNG state.
    row_rng = None
    if seed is not None:
        row_rng = torch.Generator().manual_seed(seed)

    for batch in tqdm(loader):
        for data in batch.to_data_list():
            single = Batch.from_data_list([data])
            single = move_batch(single, device, torch.float)
            # Fresh leaf tensor so gradients w.r.t. positions can be taken.
            single.pos = single.pos.detach().clone().requires_grad_(True)

            # energy & force
            e_out, f_out = mdl(single)
            e_true.append(single.y.item())
            e_pred.append(e_out.item())

            f_flat = f_out.reshape(-1)
            f_true.extend(single.force.reshape(-1).cpu().tolist())
            f_pred.extend(f_flat.detach().cpu().tolist())

            # ground-truth Hessian as an (n, n) matrix, n = number of force components
            n = f_flat.numel()
            H_gt = single.hessian.reshape(n, n)

            # Sample row indices on the CPU and iterate over plain ints.
            row_ids = torch.randperm(n, generator=row_rng)[:hess_rows].tolist()
            for row in row_ids:
                # One-hot grad_outputs selects the single component f_row.
                go = torch.zeros_like(f_flat)
                go[row] = 1.0

                # d f_row / d x; the original code treats this as -H[row]
                # (i.e. forces are the negative energy gradient).
                g = torch.autograd.grad(
                    f_flat, single.pos,
                    grad_outputs=go,
                    retain_graph=True,  # graph is reused for the next row
                )[0]

                # flip sign so the predicted row is +H[row]
                pred_row = -g.reshape(-1)   # shape (n,)
                true_row = H_gt[row]        # shape (n,)

                h_pred.extend(pred_row.cpu().tolist())
                h_true.extend(true_row.cpu().tolist())

    return e_true, e_pred, f_true, f_pred, h_true, h_pred

def combine_xyz_files(paths):
    """Read every frame from each file in ``paths`` into one flat list.

    Each path is read with ``read(path, ":")`` (all frames) and the frames
    are concatenated in the order the paths are given.
    """
    return [frame for path in paths for frame in read(path, ":")]

def plot_corr(x, y, title, xlabel, ylabel, filename=None):
    """Draw a square correlation scatter plot of ``y`` against ``x``.

    Both axes share the same limits (data range plus 5% padding), a dashed
    y = x guide line is drawn, and a text box reports the MAE and MSE between
    the two series.

    Args:
        x: Values for the horizontal axis (array-like).
        y: Values for the vertical axis (array-like).
        title: Figure title.
        xlabel: Label for the horizontal axis.
        ylabel: Label for the vertical axis.
        filename: If given (truthy), the figure is saved there and closed;
            otherwise it is shown with ``plt.show()``.
    """
    xs = np.asarray(x)
    ys = np.asarray(y)

    # Shared axis range so the y = x line runs corner to corner.
    lo = min(xs.min(), ys.min())
    hi = max(xs.max(), ys.max())
    pad = 0.05 * (hi - lo)
    axis_range = (lo - pad, hi + pad)

    # Error metrics shown in the text box.
    mae = mean_absolute_error(ys, xs)
    mse = mean_squared_error(ys, xs)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(xs, ys, alpha=0.3, s=5)
    # perfect-correlation guide
    ax.plot(axis_range, axis_range, linestyle='--', linewidth=1, color='black')

    ax.set_xlim(axis_range)
    ax.set_ylim(axis_range)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    ax.text(
        0.05, 0.95, f"MAE = {mae:.3e}\nMSE = {mse:.3e}",
        transform=ax.transAxes,
        fontsize=10, verticalalignment='top',
        bbox=dict(boxstyle='round', facecolor='white', alpha=0.7),
    )

    fig.tight_layout()
    if filename:
        fig.savefig(filename)
        plt.close(fig)
    else:
        plt.show()

def combine_target_dicts(paths):
    """Merge several saved target dictionaries into one, with CPU tensors.

    Each file in ``paths`` is expected to hold a dict with the keys
    ``"energy"``, ``"forces"`` and ``"hessian"``, each mapping to a sequence
    of tensors. The sequences are concatenated per key in path order.

    Args:
        paths: Iterable of file paths readable by ``torch.load``.

    Returns:
        Dict with keys ``"energy"``, ``"forces"``, ``"hessian"``; each value
        is a list of CPU tensors.

    Raises:
        KeyError: If a loaded dict lacks one of the three keys.
    """
    combined = {"energy": [], "forces": [], "hessian": []}
    for p in paths:
        # Load straight onto the CPU: without map_location, tensors saved
        # from a GPU are restored on that GPU, which fails on machines
        # without it and wastes device memory before the copy back.
        d = torch.load(p, map_location="cpu")
        for k in combined:
            # .cpu() is a no-op for tensors already on the CPU; it is kept
            # so the CPU guarantee does not depend on map_location alone.
            combined[k].extend(x.cpu() for x in d[k])
    return combined




In [4]:
# Use the first GPU when one is visible; otherwise fall back to the CPU so
# the cell still runs on a machine without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# The checkpoint holds a whole pickled model object (it is used directly as a
# module below), not a state_dict. PyTorch releases where torch.load defaults
# to weights_only=True refuse to unpickle arbitrary classes, so the flag is
# set explicitly. NOTE: unpickling executes code stored in the file; only load
# checkpoints from a trusted source.
ema_model = torch.load(
    "/home/alyssenko/c51_project/e2gnn_student_supervised_HESSIAN.model",
    map_location=device,
    weights_only=False,
)
XYZ_TEST = [
    "/home/alyssenko/c51_project/BOTNet-datasets/dataset_3BPA/test_600K.xyz",
]

PT_TEST = [
    "/home/alyssenko/c51_project/BOTNet-datasets/dataset_3BPA/precomputed_training_data_test_600K.pt",
]

XYZ_TRAIN = [
    "BOTNet-datasets/dataset_3BPA/train_300K.xyz",
    "BOTNet-datasets/dataset_3BPA/train_mixedT.xyz",
    "BOTNet-datasets/dataset_3BPA/test_dih.xyz",
]

PT_TRAIN = [
    "BOTNet-datasets/dataset_3BPA/precomputed_training_data_train_300K.pt",
    "BOTNet-datasets/dataset_3BPA/precomputed_training_data_train_mixedT.pt",
    "BOTNet-datasets/dataset_3BPA/precomputed_training_data_test_dih.pt",
]


# Training split: built only to obtain the normalisation statistics that the
# test split reuses below.
train_atoms = combine_xyz_files(XYZ_TRAIN)
train_tgt   = combine_target_dicts(PT_TRAIN)
# Standard deviation over every Hessian entry of the training targets.
all_h = torch.cat([h.flatten() for h in train_tgt['hessian']])
hessian_std = all_h.std()
train_ds = AtomsDataset(train_atoms, train_tgt, device, hessian_scale=hessian_std, plot_hist=False)
train_loader = DataLoader(train_ds, batch_size=512, shuffle=True, collate_fn=collate_fn_e2gnn)  # ,num_workers=8)
mean_e, std_e, h_scale = train_ds.energy_mean, train_ds.energy_std, train_ds.hessian_scale

# Test split, constructed with the training statistics.
test_atoms = combine_xyz_files(XYZ_TEST)
test_tgt   = combine_target_dicts(PT_TEST)
test_ds   = AtomsDataset(test_atoms, test_tgt, device,
                            energy_mean=mean_e, energy_std=std_e, hessian_scale=h_scale)
# Small, unshuffled batches: evaluation processes structures one by one.
test_loader = DataLoader(test_ds, batch_size=4, shuffle=False, collate_fn=collate_fn_e2gnn)


In [7]:
# Dataset sizes (test, train), the energy normalisation statistics, and the
# number of trainable parameters of the loaded model.
n_test_structures, n_train_structures = len(test_atoms), len(train_atoms)
print(n_test_structures, n_train_structures)
print(mean_e, std_e)
n_trainable_params = sum(
    param.numel() for param in ema_model.parameters() if param.requires_grad
)
print(n_trainable_params)

2138 8047
tensor(-17678.3105) tensor(0.6023)
887745


In [5]:
# Evaluate the loaded model on the test loader, collecting flat lists of
# true/predicted energies, force components and sampled Hessian rows.
(
    test_e_true, test_e_pred,
    test_f_true, test_f_pred,
    test_h_true, test_h_pred,
) = evaluate_full(test_loader, ema_model, device)

100%|██████████| 535/535 [09:16<00:00,  1.04s/it]


In [6]:
# One true-vs-predicted correlation plot per quantity, each saved to a PNG.
for true_vals, pred_vals, plot_title, x_label, y_label, out_file in (
    (test_e_true, test_e_pred, "Energy Correlation", "True Energy", "Predicted Energy", "correlation_plot_energy.png"),
    (test_f_true, test_f_pred, "Force Correlation", "True Force", "Predicted Force", "correlation_plot_force.png"),
    (test_h_true, test_h_pred, "Hessian Correlation", "True Hessian", "Predicted Hessian", "correlation_plot_hessian.png"),
):
    plot_corr(true_vals, pred_vals, plot_title, x_label, y_label, out_file)

In [13]:

def plot_corr_unscaled(
    x_norm, y_norm,
    title, xlabel, ylabel,
    scale, mean=0.0,
    filename=None
):
    """Undo a linear normalisation, then draw the correlation plot.

    Both series are mapped back with ``value * scale + mean`` and handed to
    ``plot_corr``, which draws the scatter, the y = x guide line and the
    MAE/MSE text box. Delegating keeps the two plots from drifting apart.

    Args:
        x_norm: Normalised values for the horizontal axis (array-like).
        y_norm: Normalised values for the vertical axis (array-like).
        title: Figure title.
        xlabel: Label for the horizontal axis.
        ylabel: Label for the vertical axis.
        scale: Multiplicative factor of the normalisation; anything
            ``float()`` accepts.
        mean: Additive offset of the normalisation (default 0.0); anything
            ``float()`` accepts.
        filename: If given (truthy), the figure is saved there and closed;
            otherwise it is shown.
    """
    scale = float(scale)
    mean = float(mean)

    # Back to the original (un-normalised) values.
    x = np.asarray(x_norm, dtype=float) * scale + mean
    y = np.asarray(y_norm, dtype=float) * scale + mean

    plot_corr(x, y, title, xlabel, ylabel, filename)

In [14]:
# x_norm is drawn on the horizontal axis and y_norm on the vertical one, so
# the ground truth goes to x_norm and the prediction to y_norm to match the
# "True ..." / "Predicted ..." axis labels.
plot_corr_unscaled(
    x_norm   = test_e_true,
    y_norm   = test_e_pred,
    title    = "Energy Correlation (unscaled)",
    xlabel   = "True Energy (eV)",
    ylabel   = "Predicted Energy (eV)",
    scale    = std_e.detach().cpu().numpy(),
    mean     = mean_e.detach().cpu().numpy(),
    filename = "corr_energy_unscaled.png"
)

# Forces are rescaled with the energy scale factor and no offset.
plot_corr_unscaled(
    x_norm   = test_f_true,
    y_norm   = test_f_pred,
    title    = "Force Correlation (unscaled)",
    xlabel   = "True Force (eV/A)",
    ylabel   = "Predicted Force (eV/A)",
    scale    = std_e.detach().cpu().numpy(),
    mean     = 0.0,
    filename = "corr_force_unscaled.png"
)

# Hessian entries are rescaled with the dataset's Hessian scale, no offset.
plot_corr_unscaled(
    x_norm   = test_h_true,
    y_norm   = test_h_pred,
    title    = "Hessian Correlation (unscaled)",
    xlabel   = "True Hessian (eV/A^2)",
    ylabel   = "Predicted Hessian (eV/A^2)",
    scale    = h_scale.detach().cpu().numpy(),
    mean     = 0.0,
    filename = "corr_hessian_unscaled.png"
)
