This notebook shows how to create a SAKE model with the provided code and use it to predict the energy and forces of a small molecule.

In [1]:
import torch
import sys
# Make SAKE_model_v1.py importable: "." works when the notebook runs from the
# directory that contains it; otherwise replace "." with that directory's path.
sys.path.append(".")  # Directory containing SAKE_model_v1.py
import SAKE_model_v1 as SAKE



Create the SAKE layers: input embedding, output projection, SAKE layer stack, and energy network.

In [2]:
# Build the model components: input embedding, output projection, the stack of
# SAKE layers, and the energy network. Two separate CELU instances are created,
# one for the SAKE layers and one for the energy network.
(embed_in, embed_out,
 sake_block, energy_NN) = SAKE.create_SAKE_layers(
    in_node_nf=1,
    hidden_nf=32,
    out_node_nf=32,
    n_layers=4,         # number of stacked SAKE layers
    n_heads=4,          # attention heads per SAKE layer
    cutoff=1,           # presumably nm, the unit of the coordinates below
    kernel_size=18,     # width of the radial-basis (rbf) expansion
    embed_type='gaff',  # atom-type embedding scheme
    normalize=False,
    act_fn=torch.nn.CELU(alpha=2.0),
    energy_act_fn=torch.nn.CELU(alpha=2.0),
)

Assemble these layers into a complete SAKE model.

In [3]:
# Wrap the layers created above into a single model running on the CPU.
# cutoff uses the same value as the create_SAKE_layers call. max_num_neighbors
# is presumably a cap on neighbors per atom; 1000 is far above the 21 atoms
# used here (assumption — confirm against SAKE_modular).
model = SAKE.SAKE_modular(
    embedding_in=embed_in,
    sake_conv=sake_block,
    embedding_out=embed_out,
    energy_network=energy_NN,
    cutoff=1,
    max_num_neighbors=1000,
    device='cpu',
)

In [4]:
# Display the model's module hierarchy (layer types and sizes)
model

SAKE_modular(
  (embedding_in): Embedding(97, 32)
  (embedding_out): Linear(in_features=32, out_features=32, bias=True)
  (sake_conv): ModuleList(
    (0-3): 4 x SAKELayer(
      (edge_mlp): Sequential(
        (0): Linear(in_features=97, out_features=32, bias=True)
        (1): CELU(alpha=2.0)
        (2): Linear(in_features=32, out_features=32, bias=True)
        (3): CELU(alpha=2.0)
      )
      (node_mlp): Sequential(
        (0): Linear(in_features=96, out_features=32, bias=True)
        (1): CELU(alpha=2.0)
        (2): Linear(in_features=32, out_features=32, bias=True)
        (3): CELU(alpha=2.0)
      )
      (spatial_att_mlp): Linear(in_features=32, out_features=4, bias=True)
      (semantic_att_mlp): Sequential(
        (0): Linear(in_features=32, out_features=4, bias=True)
        (1): CELU(alpha=2.0)
        (2): Linear(in_features=4, out_features=1, bias=True)
      )
      (rbf): Sequential(
        (0): expnorm_smearing()
        (1): Linear(in_features=18, out_feature

Create an example molecule to pass to the model.

Example molecule: diisopropyl sulfide (also called isopropyl sulfide), (CH3)2CH–S–CH(CH3)2, C6H14S

In [5]:
# Cartesian coordinates, one row per atom, written in angstroms and converted
# to nm (the length unit of the forces printed at the end).
ANGSTROM_TO_NM = 0.1
coords = torch.tensor([[-1.96430671,  0.31051004, -1.19996035],
                       [-0.73622882, -0.72752088, -1.01535392],
                       [-0.31321597, -1.34827518, -2.37928247],
                       [ 0.66871333, -0.02617523, -0.19206515],
                       [ 0.5306921 , -0.4202556 ,  1.58075047],
                       [-0.70359355,  0.3118152 ,  2.04930782],
                       [ 1.73722041,  0.19472905,  2.38563561],
                       [-2.5447526 , -0.15219232, -2.0220902 ],
                       [-1.59456003,  1.33184624, -1.40155053],
                       [-2.54677987,  0.50790024, -0.19426368],
                       [-1.09588408, -1.29260612, -0.15630379],
                       [ 0.71891737, -1.72426915, -2.38179493],
                       [-1.09203029, -2.00252628, -2.73346186],
                       [-0.32819545, -0.57277346, -3.17117214],
                       [ 0.64011461, -1.53141069,  1.76843905],
                       [-1.12419271,  0.97337341,  1.3624804 ],
                       [-0.48255327,  0.92818391,  2.98026752],
                       [-1.46087456, -0.40343699,  2.31385899],
                       [ 2.58227634, -0.45088139,  2.34737015],
                       [ 1.82486057,  1.24805963,  2.03895712],
                       [ 1.6380769 ,  0.33527586,  3.53223348]]) * ANGSTROM_TO_NM

# Atom types (GAFF), in the same order as the rows of `coords`
atomtypes = ['c3', 'c3', 'c3', 'ss', 'c3', 'c3', 'c3', 'hc', 'hc', 'hc', 'h1',
             'hc', 'hc', 'hc', 'h1', 'hc', 'hc', 'hc', 'hc', 'hc', 'hc']

# Integer index per atom type for the embedding layer. The indices are
# arbitrary for this demo. Deriving `embedding` from `atomtypes` (rather than
# typing a second parallel list) keeps the two from drifting out of sync.
atomtype_to_index = {'c3': 0, 'ss': 1, 'hc': 2, 'h1': 3}
embedding = torch.tensor([atomtype_to_index[atype] for atype in atomtypes],
                         dtype=torch.long)
assert embedding.shape[0] == coords.shape[0], "need exactly one atom type per coordinate row"

# Batch index per atom: all zeros, i.e. every atom belongs to one molecule
batch = torch.zeros_like(embedding)

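Predict the energy with SAKE, then obtain the forces as the negative gradient of the energy with respect to the atomic coordinates, $\mathbf{F} = -\partial U / \partial \mathbf{x}$.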

In [9]:
# Forces are the negative gradient of the energy with respect to positions,
# so autograd has to track operations on the coordinates.
coords.requires_grad_(True)

# Predicted energy for the atoms described by (embedding, coords, batch)
U_pred = model(embedding, coords, batch)

# Seeding the vector-Jacobian product with ones yields the gradient of the
# summed energy with respect to every coordinate.
grad_outputs = [torch.ones_like(U_pred)]

# F = -dU/dx. create_graph=True keeps a differentiable graph of the forces,
# presumably so they can also be used in a training loss; it is not needed for
# inference alone.
F_pred = torch.autograd.grad(outputs=[U_pred],
                             inputs=[coords],
                             grad_outputs=grad_outputs,
                             retain_graph=True,
                             create_graph=True)[0].neg()



Print the predicted energy and forces.

In [11]:
# Convert the energy tensor to plain Python floats before formatting:
# str.format with a float spec only works on 0-dimensional tensors and raises
# TypeError for e.g. a shape-(1,) energy. Flattening also prints one line per
# molecule when `batch` contains several molecules.
for energy in U_pred.detach().reshape(-1).tolist():
    print("Energy: {:.3f} kJ/mol".format(energy))
print("\nForces (in kJ/mol/nm):")
print(F_pred.detach())

Energy: 1.200 kJ/mol

Forces (in kJ/mol/nm):
tensor([[-0.0535,  0.0115, -0.0432],
        [-0.0242, -0.0388, -0.0592],
        [-0.0056, -0.0434, -0.0890],
        [ 0.0566,  0.0148, -0.0275],
        [ 0.0344, -0.0249,  0.0618],
        [-0.0175,  0.0082,  0.0614],
        [ 0.0699,  0.0130,  0.0801],
        [-0.0775, -0.0040, -0.0857],
        [-0.0461,  0.0396, -0.0485],
        [-0.0680,  0.0206, -0.0179],
        [-0.0142, -0.0107, -0.0024],
        [ 0.0215, -0.0518, -0.1024],
        [-0.0434, -0.0635, -0.1259],
        [-0.0198, -0.0153, -0.1329],
        [ 0.0055, -0.0134,  0.0188],
        [-0.0328,  0.0424,  0.0358],
        [ 0.0030,  0.0409,  0.0951],
        [-0.0386, -0.0005,  0.0613],
        [ 0.1095, -0.0027,  0.0976],
        [ 0.0647,  0.0499,  0.0794],
        [ 0.0762,  0.0279,  0.1435]])
