This notebook shows how to create a Schake model using the provided code and how to use it to predict the energy and forces of an example molecule.

In [1]:
import torch
import sys
sys.path.append(".")  # Set pathway here
import Schake_model_v1 as Schake



Run the code below to create the Schake model

In [2]:
# Hyperparameters passed to Schake.create_Schake, collected in one place so
# they are easy to inspect and tweak before the model is built.
# NOTE(review): the *_low_cut / *_high_cut values are presumably distances in
# nm (the example coordinates in this notebook are converted to nm) -- confirm
# against Schake_model_v1.
schake_config = {
    # General architecture settings
    "hidden_channels": 32,
    "num_layers": 4,
    "kernel_size": 18,
    "cosine_offset": 0.5,
    # Cutoff window handed to the SAKE part of the model
    "sake_low_cut": 0,
    "sake_high_cut": 0.5,
    # Cutoff window handed to the SchNet part of the model
    "schnet_low_cut": 0.5,
    "schnet_high_cut": 2.5,
    # Activation functions (a separate CELU instance for each role)
    "schnet_act": torch.nn.CELU(alpha=2.0),
    "sake_act": torch.nn.CELU(alpha=2.0),
    "out_act": torch.nn.CELU(alpha=2.0),
    # Selects atomtype matching embedding 1.
    # Use None instead to avoid filtering.
    "schnet_sel": 1,
    "num_heads": 4,
    "embed_type": 'names',
    "num_out_layers": 3,
    "max_num_neigh": 10000,
    "normalize": False,
    "device": 'cpu',
}

# Build the model from the configuration above
model = Schake.create_Schake(**schake_config)

In [3]:
# Echo the model as the cell's last expression to display its module structure
model

Schake_modular(
  (embedding_in): Embedding(83, 32)
  (embedding_out): Linear(in_features=32, out_features=32, bias=True)
  (sake_rbf_func): expnorm_smearing()
  (schnet_rbf_func): expnorm_smearing()
  (sake_layers): ModuleList(
    (0-3): 4 x SAKELayer(
      (edge_mlp): Sequential(
        (0): Linear(in_features=97, out_features=32, bias=True)
        (1): CELU(alpha=2.0)
        (2): Linear(in_features=32, out_features=32, bias=True)
        (3): CELU(alpha=2.0)
      )
      (node_mlp): Sequential(
        (0): Linear(in_features=96, out_features=32, bias=True)
        (1): CELU(alpha=2.0)
        (2): Linear(in_features=32, out_features=32, bias=True)
        (3): CELU(alpha=2.0)
      )
      (spatial_att_mlp): Linear(in_features=32, out_features=4, bias=True)
      (semantic_att_mlp): Sequential(
        (0): Linear(in_features=32, out_features=4, bias=True)
        (1): CELU(alpha=2.0)
        (2): Linear(in_features=4, out_features=1, bias=True)
      )
      (rbf_model): Lin

Create a dummy molecule to use as input to the model

Isopropyl Sulfide

In [4]:
# Cartesian coordinates of the 21 atoms, written in angstroms (one row per atom)
coords_angstrom = torch.tensor([
    [-1.96430671,  0.31051004, -1.19996035],
    [-0.73622882, -0.72752088, -1.01535392],
    [-0.31321597, -1.34827518, -2.37928247],
    [ 0.66871333, -0.02617523, -0.19206515],
    [ 0.5306921 , -0.4202556 ,  1.58075047],
    [-0.70359355,  0.3118152 ,  2.04930782],
    [ 1.73722041,  0.19472905,  2.38563561],
    [-2.5447526 , -0.15219232, -2.0220902 ],
    [-1.59456003,  1.33184624, -1.40155053],
    [-2.54677987,  0.50790024, -0.19426368],
    [-1.09588408, -1.29260612, -0.15630379],
    [ 0.71891737, -1.72426915, -2.38179493],
    [-1.09203029, -2.00252628, -2.73346186],
    [-0.32819545, -0.57277346, -3.17117214],
    [ 0.64011461, -1.53141069,  1.76843905],
    [-1.12419271,  0.97337341,  1.3624804 ],
    [-0.48255327,  0.92818391,  2.98026752],
    [-1.46087456, -0.40343699,  2.31385899],
    [ 2.58227634, -0.45088139,  2.34737015],
    [ 1.82486057,  1.24805963,  2.03895712],
    [ 1.6380769 ,  0.33527586,  3.53223348],
])

# Convert angstroms to nm
coords = coords_angstrom * 0.1

# GAFF atomtype of every atom, in the same order as the coordinate rows
atomtypes = ['c3', 'c3', 'c3', 'ss', 'c3', 'c3', 'c3', 'hc', 'hc', 'hc', 'h1',
             'hc', 'hc', 'hc', 'h1', 'hc', 'hc', 'hc', 'hc', 'hc', 'hc']

# Arbitrary integer index assigned to each atomtype for this example
atomtype_to_index = {'c3': 0, 'ss': 1, 'hc': 2, 'h1': 3}

# Integer (int64) embedding index per atom, looked up from its atomtype
embedding = torch.tensor([atomtype_to_index[name] for name in atomtypes],
                         dtype=torch.long)

# Batch tensor: all zeros, i.e. every atom is assigned to batch entry 0
batch = torch.zeros_like(embedding)

Get energy and force predictions from the Schake model

In [5]:
# Enable gradient tracking on the coordinates so the predicted energy can be
# differentiated with respect to the atomic positions
coords.requires_grad_(True)

# Forward pass: predicted energy for the input molecule
U_pred = model(embedding, coords, batch)

# Seed for the backward pass: ones with the same shape as the energy output
grad_outputs = [torch.ones_like(U_pred)]

# Gradient of the energy with respect to the coordinates.
# create_graph=True builds a graph of the derivative itself, so the result
# remains differentiable; retain_graph=True keeps the graph from being freed.
energy_gradients = torch.autograd.grad(
    outputs=[U_pred],
    inputs=[coords],
    grad_outputs=grad_outputs,
    create_graph=True,
    retain_graph=True,
)

# Forces are the negative gradient of the energy w.r.t. the coordinates
F_pred = -energy_gradients[0]



Print outputs

In [6]:
# Report the predicted energy, formatted to three decimal places
energy_template = "Energy: {:.3f} kJ/mol"
print(energy_template.format(U_pred))

# Report the forces; detach() yields a tensor that is cut off from the
# autograd graph before printing
print("\nForces (in kJ/mol/nm):")
forces_detached = F_pred.detach()
print(forces_detached)

Energy: 3.917 kJ/mol

Forces (in kJ/mol/nm):
tensor([[-4.1480e-04, -1.5626e-03, -6.9167e-03],
        [ 2.7790e-03,  8.2153e-03,  7.2570e-03],
        [-7.6717e-03,  4.6120e-03, -5.8667e-03],
        [-1.9928e-02, -8.3908e-03,  1.2441e-02],
        [-3.6314e-03,  5.3461e-03, -6.5350e-03],
        [ 8.7607e-03,  3.2686e-05, -7.3707e-03],
        [-2.5678e-03,  1.9912e-03,  7.2989e-03],
        [ 9.7536e-03, -1.4699e-05,  1.0006e-03],
        [ 3.4913e-04, -8.1564e-03,  4.3534e-04],
        [ 1.1763e-02, -1.4022e-03, -1.4344e-03],
        [-1.2447e-02, -1.2305e-02, -3.0892e-03],
        [-6.5812e-03,  5.3779e-03, -4.5154e-03],
        [ 2.0958e-03,  1.2365e-02,  6.5001e-03],
        [-6.5587e-03,  7.1730e-04,  3.8638e-03],
        [ 1.2421e-02, -3.8652e-03,  1.8774e-02],
        [ 1.0793e-02, -1.4696e-02, -1.1709e-02],
        [ 5.9829e-03,  4.3622e-04, -1.2868e-03],
        [ 1.2558e-02,  7.4629e-03, -1.0459e-02],
        [-6.3799e-03,  5.7675e-03,  3.8518e-03],
        [-5.3823e-03, -2