In [1]:
from rdkit import Chem
import torch
# Turn off TF32 for CUDA matmuls and for cuDNN, presumably so the energies and
# forces compared below are not affected by TF32's reduced precision.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

In [2]:
# Prefer the second GPU, but fall back to the first GPU (or the CPU) when it is
# not present: 'cuda:1' is an invalid device on a single-GPU machine even though
# torch.cuda.is_available() returns True there.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)
in_path = 'example.sdf'

# NOTE(review): these are absolute, user-specific paths; point them at the local
# model checkout before running on another machine.
# Judging by the file names, the first is the ensemble model and the second a
# single member of it.
aimnet2 = torch.jit.load('/home/jack/aimnet2_lab/models/aimnet2_wb97m-d3_ens.jpt', map_location=device)
aimnet2_0 = torch.jit.load('/home/jack/aimnet2_lab/models/aimnet2_wb97m-d3_0.jpt', map_location=device)

cuda:1


In [3]:
def sdf2aimnet_input(sdf: str, device=torch.device('cpu')) -> dict:
    """Build the AIMNet2 input dictionary from the first molecule of an SDF file.

    Only the first record of the file is read and only its default conformer is
    used, so the SDF is assumed to contain a single conformer.

    Args:
        sdf: Path to the SDF file. Explicit hydrogens are kept (``removeHs=False``).
        device: Torch device on which the returned tensors are created.

    Returns:
        A dict with three tensors, each with a leading batch dimension of 1:
        ``coord`` (atom positions of the conformer, shape ``(1, n_atoms, 3)``),
        ``numbers`` (atomic numbers, shape ``(1, n_atoms)``) and ``charge``
        (formal charge of the molecule as a float tensor, shape ``(1,)``).

    Raises:
        ValueError: If the file has no records or RDKit cannot parse the first one.
    """
    # RDKit yields None (instead of raising) for a record it cannot parse, and an
    # empty supplier would otherwise surface as a bare StopIteration. Report both
    # cases as one explicit error instead of an AttributeError further down.
    mol = next(Chem.SDMolSupplier(sdf, removeHs=False), None)
    if mol is None:
        raise ValueError(f'Could not read a valid molecule from {sdf!r}')

    conf = mol.GetConformer()
    # unsqueeze(0) adds the batch dimension to each per-molecule tensor.
    coord = torch.tensor(conf.GetPositions(), device=device).unsqueeze(0)
    numbers = torch.tensor([atom.GetAtomicNum() for atom in mol.GetAtoms()], device=device).unsqueeze(0)
    charge = torch.tensor([Chem.GetFormalCharge(mol)], device=device, dtype=torch.float)
    return dict(coord=coord, numbers=numbers, charge=charge)


# Build the model input from the SDF file and track gradients on the coordinates
# so that later cells can differentiate energies with respect to them.
dct = sdf2aimnet_input(in_path, device=device)
dct['coord'].requires_grad = True

# Evaluate the ensemble model; the energy it returns has no grad_fn attached.
aimnet2_out = aimnet2(dct)
print('aimnet2 energy: ', aimnet2_out['energy'])

aimnet2 energy:  tensor([-10518.2732], device='cuda:1', dtype=torch.float64)


In [4]:
# List every entry of the ensemble output together with its tensor shape.
for key in aimnet2_out:
    val = aimnet2_out[key]
    print(key, val.shape)

coord torch.Size([1, 20, 3])
numbers torch.Size([1, 20])
charge torch.Size([1])
charges torch.Size([1, 20])
charges_std torch.Size([1, 20])
energy torch.Size([1])
energy_std torch.Size([1])
forces torch.Size([1, 20, 3])
forces_std torch.Size([1, 20, 3])


In [5]:
# Re-print the ensemble energy stored in aimnet2_out for comparison with the models below.
print('aimnet2 energy: ', aimnet2_out['energy'])

aimnet2 energy:  tensor([-10518.2732], device='cuda:1', dtype=torch.float64)


## AIMNet2 single model

In [6]:
# Evaluate the single model on the same input dict that was fed to the ensemble.
aimnet2_0out = aimnet2_0(dct)
print('aimnet2_0 energy: ', aimnet2_0out['energy'])  # this energy carries a grad_fn, so it can be differentiated

aimnet2_0 energy:  tensor([-10518.2655], device='cuda:1', dtype=torch.float64,
       grad_fn=<AddBackward0>)


In [7]:
# Forces are the negative gradient of the single-model energy with respect to
# the coordinates; create_graph=True keeps the result itself differentiable.
forces = torch.autograd.grad(
    outputs=aimnet2_0out['energy'],
    inputs=dct['coord'],
    create_graph=True,
)[0].neg()
print(forces)

tensor([[[ 0.0198,  0.1783, -0.0657],
         [ 0.0356, -0.2943,  0.0706],
         [-0.0442,  0.0613,  0.1326],
         [ 0.1627,  0.1406,  0.0440],
         [-0.0146, -0.1142, -0.1864],
         [-0.1602,  0.0206, -0.0834],
         [-0.3099, -0.3564, -0.0733],
         [ 0.0133, -0.1086,  0.0467],
         [-0.1510,  0.0309, -0.0501],
         [-0.1301,  0.0278, -0.1938],
         [ 0.1394,  0.1691, -0.0113],
         [ 0.1099, -0.1648, -0.1143],
         [-0.0429, -0.0140, -0.2180],
         [-0.1936,  0.0702,  0.0784],
         [ 0.0217, -0.1661,  0.0584],
         [ 0.1824,  0.0953,  0.2214],
         [-0.1041,  0.1576,  0.1365],
         [ 0.0342,  0.0051,  0.1738],
         [ 0.1627, -0.0603, -0.0591],
         [ 0.2689,  0.3220,  0.0929]]], device='cuda:1', dtype=torch.float64,
       grad_fn=<NegBackward0>)


## ANI-2xt

In [8]:
from Auto3D.batch_opt.ANI2xt_no_rep import ANI2xt

In [9]:
# Instantiate ANI-2xt on the selected device; .double() presumably casts the
# model to float64 to match the float64 coordinates — TODO confirm.
ani2xt = ANI2xt(device=device).double()

  self_energies = torch.tensor(self_energies, dtype=torch.double)


In [10]:
# Sanity check: the model parameters and both input tensors should share a device,
# and the inputs should keep their batched shapes.
print(next(ani2xt.parameters()).device)
print(*(dct[name].device for name in ('numbers', 'coord')), sep='\n')
print(*(dct[name].shape for name in ('numbers', 'coord')), sep='\n')


cuda:1
cuda:1
cuda:1
torch.Size([1, 20])
torch.Size([1, 20, 3])


In [11]:
# Map atomic numbers (H, C, N, O, F, S, Cl) to the element indices passed to ANI2xt.
periodict2idx = {1:0, 6:1, 7:2, 8:3, 9:4, 16:5, 17:6}
# Drop only the batch dimension: a bare .squeeze() would also collapse a
# single-atom molecule to a 0-d tensor, which cannot be iterated. Converting
# with .tolist() once also avoids one .item() device-to-host sync per atom.
numbers2 = torch.tensor(
    [periodict2idx[z] for z in dct['numbers'].squeeze(0).tolist()],
    device=device,
).unsqueeze(0)
print(numbers2)
print(numbers2.shape)
print(numbers2.dtype)
# ANI2xt returns two values here, printed below as energy and forces.
e, f = ani2xt(numbers2, dct['coord'])
print(e, f)

tensor([[1, 1, 1, 1, 1, 1, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
       device='cuda:1')
torch.Size([1, 20])
torch.int64
tensor([-10506.7458], device='cuda:1', dtype=torch.float64,
       grad_fn=<MulBackward0>) tensor([[[-0.1086, -0.5575,  0.0870],
         [ 0.1492,  0.4527, -0.0644],
         [ 0.0527,  0.0769,  0.1976],
         [-0.0455, -0.1433,  0.0164],
         [ 0.1827,  0.3086, -0.2481],
         [-0.0053,  0.5933, -0.1470],
         [-0.1012, -0.1057, -0.0098],
         [-0.0458, -0.4597,  0.0937],
         [-0.1637, -0.0676, -0.0271],
         [-0.1362,  0.0115, -0.1234],
         [ 0.1414,  0.2044, -0.0117],
         [ 0.0519, -0.1948, -0.0617],
         [-0.0190,  0.0291, -0.1953],
         [-0.1545,  0.1042,  0.0672],
         [ 0.0356, -0.1654,  0.0644],
         [ 0.0809, -0.0095,  0.1417],
         [ 0.0052,  0.1215,  0.1229],
         [ 0.0556, -0.0132,  0.2087],
         [ 0.2004, -0.0839, -0.0542],
         [-0.1756, -0.1015, -0.0569]]], device='cuda:1', dty

## ANI-2x

In [12]:
import torchani
from Auto3D.utils import hartree2ev

In [13]:
# Load ANI-2x, move it to the selected device and cast it to double.
# periodic_table_index=True presumably lets the model take atomic numbers
# directly (dct['numbers'] is passed to it unchanged below) — TODO confirm.
ani2x = torchani.models.ANI2x(periodic_table_index=True).to(device).double()

In [14]:
# Confirm the batched shapes of the inputs handed to ANI-2x.
print(*(dct[name].shape for name in ('numbers', 'coord')), sep='\n')

torch.Size([1, 20])
torch.Size([1, 20, 3])


In [15]:
# ANI-2x energy scaled by hartree2ev, and forces as the negative gradient of the
# summed energy with respect to the coordinates.
e = ani2x((dct['numbers'], dct['coord'])).energies * hartree2ev
g, = torch.autograd.grad(e.sum(), dct['coord'])
f = g.neg()
print(e, f)

tensor([-10508.4711], device='cuda:1', dtype=torch.float64,
       grad_fn=<MulBackward0>) tensor([[[-2.7626e-03, -4.0848e-01,  1.0773e-01],
         [ 5.9649e-04,  6.7412e-01, -1.8385e-01],
         [ 5.3409e-02,  3.1565e-02,  2.1073e-01],
         [ 3.1580e-02, -6.9823e-02,  3.6084e-02],
         [ 6.2624e-02,  2.5490e-01, -2.7018e-01],
         [-4.8090e-02,  2.2278e-01, -7.1346e-02],
         [ 3.6675e-01, -1.4568e-01,  2.1477e-01],
         [-3.4433e-01, -5.4557e-01, -3.2470e-02],
         [-4.1047e-02, -5.8516e-02,  3.0438e-02],
         [-4.6297e-03, -1.5265e-02, -4.9141e-03],
         [ 3.9882e-02,  8.3096e-02, -4.4136e-02],
         [-1.6983e-02, -3.8818e-02, -2.6022e-02],
         [-2.2146e-02,  1.4923e-02, -9.1110e-03],
         [-1.4549e-02,  1.3608e-02, -1.9153e-02],
         [ 1.5301e-02, -2.0725e-02,  1.3282e-02],
         [ 1.1635e-02,  5.9222e-02, -6.2080e-03],
         [ 8.8300e-02, -1.3558e-04,  9.4194e-02],
         [-1.1661e-02, -3.1744e-02,  1.0882e-02],
         