-  reference for review Hessian matrix: https://machinelearningmastery.com/a-gentle-introduction-to-hessian-matrices/
- reference for review Jacobian matrix: https://machinelearningmastery.com/a-gentle-introduction-to-the-jacobian/

In [1]:
from rdkit import Chem
import torch
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
in_path = 'example.sdf'
# there is no gradient information from the ensemble model
aimnet2 = torch.jit.load('/home/jack/aimnet2_lab/models/aimnet2_wb97m-d3_0.jpt', map_location=device)

def sdf2aimnet_input(sdf: str, device=torch.device('cpu')) -> dict:
    """Converts sdf to aimnet input, assuming the sdf has only 1 conformer."""
    mol = next(Chem.SDMolSupplier(sdf, removeHs=False))
    conf = mol.GetConformer()
    coord = torch.tensor(conf.GetPositions(), device=device).unsqueeze(0)
    numbers = torch.tensor([atom.GetAtomicNum() for atom in mol.GetAtoms()], device=device).unsqueeze(0)
    charge = torch.tensor([Chem.GetFormalCharge(mol)], device=device, dtype=torch.float)
    return dict(coord=coord, numbers=numbers, charge=charge)


dct = sdf2aimnet_input(in_path, device=device)
dct['coord'].requires_grad_(True)
aimnet2_out = aimnet2(dct)
print('aimnet2 energy: ', aimnet2_out['energy'])  #note that there is no gradient function for energy

aimnet2 energy:  tensor([-10518.2655], device='cuda:0', dtype=torch.float64,
       grad_fn=<AddBackward0>)


In [2]:
forces = - torch.autograd.grad(aimnet2_out['energy'], dct['coord'], create_graph=True)[0]
print(forces.shape)

torch.Size([1, 20, 3])


In [3]:
# using pytorch function for calculating the hessian
def func(coord, numbers=dct['numbers'], charge=dct['charge']):
    dct = dict(coord=coord, numbers=numbers, charge=charge)
    return aimnet2(dct)['energy']

hess2 = torch.autograd.functional.hessian(func, dct['coord'])
print(hess2)

tensor([[[[[[ 5.2781e+01, -9.5706e-01,  1.9930e+01],
            [-1.1365e+01,  6.8590e+00, -3.9351e+00],
            [-6.9205e-01, -1.6098e+00,  4.5698e-01],
            ...,
            [-4.3666e-02,  4.6905e-03,  4.8027e-02],
            [ 1.3112e-02, -5.1483e-02, -5.0566e-02],
            [ 1.7759e-01,  6.1218e-01, -2.7228e-02]]],


          [[[-9.5705e-01,  7.0478e+01, -1.5258e+01],
            [ 5.9971e+00, -4.0473e+01,  1.1000e+01],
            [-1.3390e+00, -1.2029e+00,  3.9974e-01],
            ...,
            [-1.8787e-02, -1.9064e-01,  1.0621e-01],
            [ 2.0141e-02, -1.8988e-01,  1.8768e-02],
            [-1.6291e-01, -3.3018e-01, -2.3343e-02]]],


          [[[ 1.9930e+01, -1.5258e+01,  2.5407e+01],
            [-3.7458e+00,  1.1760e+01, -1.1204e+01],
            [ 1.7912e+00,  2.1679e+00,  5.1628e-01],
            ...,
            [-3.3753e-02,  4.9814e-03,  1.9353e-02],
            [-1.8409e-03,  1.0843e-02, -4.5274e-02],
            [ 8.7504e-02,  2.8650e-01, -

In [4]:
print(hess2.shape)

torch.Size([1, 20, 3, 1, 20, 3])


In [5]:
hess3 = hess2.view(20, 3, 20, 3)
print(hess2)

tensor([[[[[[ 5.2781e+01, -9.5706e-01,  1.9930e+01],
            [-1.1365e+01,  6.8590e+00, -3.9351e+00],
            [-6.9205e-01, -1.6098e+00,  4.5698e-01],
            ...,
            [-4.3666e-02,  4.6905e-03,  4.8027e-02],
            [ 1.3112e-02, -5.1483e-02, -5.0566e-02],
            [ 1.7759e-01,  6.1218e-01, -2.7228e-02]]],


          [[[-9.5705e-01,  7.0478e+01, -1.5258e+01],
            [ 5.9971e+00, -4.0473e+01,  1.1000e+01],
            [-1.3390e+00, -1.2029e+00,  3.9974e-01],
            ...,
            [-1.8787e-02, -1.9064e-01,  1.0621e-01],
            [ 2.0141e-02, -1.8988e-01,  1.8768e-02],
            [-1.6291e-01, -3.3018e-01, -2.3343e-02]]],


          [[[ 1.9930e+01, -1.5258e+01,  2.5407e+01],
            [-3.7458e+00,  1.1760e+01, -1.1204e+01],
            [ 1.7912e+00,  2.1679e+00,  5.1628e-01],
            ...,
            [-3.3753e-02,  4.9814e-03,  1.9353e-02],
            [-1.8409e-03,  1.0843e-02, -4.5274e-02],
            [ 8.7504e-02,  2.8650e-01, -

this Hessian can be passed to `VibrationsData` object for calculating the vibration energy.
- https://wiki.fysik.dtu.dk/ase/ase/vibrations/modes.html#ase.vibrations.VibrationsData

In [8]:
# roman's script
hess = -torch.stack([torch.autograd.grad(f, dct['coord'], retain_graph=True)[0].flatten() for f in forces.flatten().unbind()])
hess = hess[:-3, :-3]
hess = hess.view(len(dct['numbers']), 3, len(dct['numbers']), 3).detach().cpu().numpy()
print(hess.shape)

RuntimeError: shape '[1, 3, 1, 3]' is invalid for input of size 3249