In [1]:
import torch
import numpy as np
import sys
import Schake_model_v2 as Schake
import mdtraj

Check whether CUDA is available and select the compute device

In [2]:
# Select the compute device once; it is passed to the model constructor and to
# torch.load (map_location) below.
cuda_avail = torch.cuda.is_available()
device = 'cuda' if cuda_avail else 'cpu'

Create energy model

In [3]:
# Set energy model.
# These hyperparameters should match the ones used to train
# 'Schake_trained_weights.pt'. Shape-affecting ones (e.g. hidden_channels,
# num_layers, num_heads) are checked by load_state_dict below, which raises on
# mismatch by default. Non-learned settings such as the cutoffs are not checked.
# return_logits=False: the forward pass returns only the energy tensor.
E_model = Schake.create_Schake(
    hidden_channels=32,
    num_layers=2,
    kernel_size=18,
    neighbor_embed='resid_bb3',
    sake_low_cut=0.25,
    sake_high_cut=1,
    schnet_low_cut=1,
    schnet_high_cut=2.5,
    schnet_act=torch.nn.CELU(2),
    sake_act=torch.nn.CELU(2),
    out_act=torch.nn.Tanh(),
    max_num_neigh=10000,
    schnet_sel=1,
    trainable_sake_kernel=False,
    trainable_schnet_kernel=False,
    num_heads=4,
    embed_type='elements',
    num_out_layers=3,
    device=device,
    single_pro=True,
    return_logits=False,
    energy_func='ms',
)

# Load the trained weights directly onto the selected device.
# NOTE(review): torch.load unpickles arbitrary objects. Only load trusted
# checkpoints, or pass weights_only=True on torch versions that support it.
state_dict = torch.load('Schake_trained_weights.pt',
                        map_location=torch.device(device))

# Edit the state dict so it fits the single-protein wrapper model.
# NOTE(review): presumably remaps keys under the `pseudo_energy` submodule;
# confirm against Schake_model_v2._SP_state_dict.
new_dict = Schake._SP_state_dict(state_dict)
E_model.load_state_dict(new_dict)

# The model is only used for inference and TorchScript export from here on.
# Switch to eval mode so any training-only behaviour (dropout, batch-norm
# statistics) is disabled for the forward passes and for torch.jit.trace.
# .eval() returns the module itself, so the cell still displays the model summary.
E_model.eval()

Schake_modular_Zs_SP(
  (pseudo_energy): Schake_modular_Zs(
    (embedding_in): ModuleList(
      (0): Embedding(20, 16)
      (1): Embedding(64, 16)
    )
    (embedding_out): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (sake_rbf_func): expnorm_smearing()
    (schnet_rbf_func): expnorm_smearing()
    (sake_layers): ModuleList(
      (0-1): 2 x SAKELayer(
        (edge_mlp): Sequential(
          (0): Linear(in_features=97, out_features=32, bias=True)
          (1): CELU(alpha=2)
          (2): Linear(in_features=32, out_features=32, bias=True)
          (3): CELU(alpha=2)
        )
        (node_mlp): Sequential(
          (0): Linear(in_features=96, out_features=32, bias=True)
          (1): CELU(alpha=2)
          (2): Linear(in_features=32, out_features=32, bias=True)
          (3): CELU(alpha=2)
        )
        (spatial_att_mlp): Linear(in_features=32, out_features=4, bias=True)
        (semantic_att_mlp): Sequential(
          (0): Linear(in_features=32, out_featur

Load data for the target protein

In [4]:
# Load model inputs for the target structure (backbone 'bb3' atom filter).
pdb_name = 'md_structs/Home.pdb'
inputs = Schake._load_mol(pdb_name, flt_grp='bb3')

# Convert each field to a tensor and place it on `device`. This is the same
# device the model was built on, so there is no separate CUDA branch to keep in sync.
names = torch.LongTensor(inputs[0]).to(device)
aas = torch.LongTensor(inputs[1]).to(device)
coords = torch.Tensor(inputs[2]).to(device)
batch = torch.LongTensor(inputs[3]).to(device)
ss8_ohv = torch.Tensor(inputs[4]).to(device)
ca_idxs = torch.LongTensor(inputs[5]).to(device)

# Load info into the energy model.
# temp = 1 because the simulation temperature is set in OpenMM instead.
# NOTE(review): with temp=1, E_model.kBT reads ~0.0083, which matches the gas
# constant in kJ/(mol*K). Confirm that `temp` is in kelvin.
temp = 1
E_model.load_mol_info(names, aas, ss8_ohv, batch, ca_idxs, temp=temp)

In [5]:
# Thermal energy scale stored on the model after load_mol_info(…, temp=temp).
# NOTE(review): presumably derived from `temp`; for temp=1 it shows ~0.0083,
# consistent with kJ/mol units. Confirm in Schake_model_v2.
E_model.kBT

tensor(0.0083)

Extract PDB coords

In [6]:
# Reference coordinates: the first frame of the same PDB file, as a tensor of the
# default float dtype on `device`. mdtraj stores xyz in nanometers.
pdb = mdtraj.load_pdb(pdb_name)
pdb_coord = torch.Tensor(pdb.xyz[0]).to(device)

Compute the model's energy for the PDB coordinates

In [7]:
# Eager forward pass with autograd switched off; only the energy value is needed here.
with torch.set_grad_enabled(False):
    E = E_model(pdb_coord)
print(E)

tensor(0.3918)


If the model returns logits (`return_logits=True`), print their softmax probabilities masked by `ss8_ohv`

In [8]:
# With return_logits=False (as configured above), E is a 0-d energy tensor and
# E[1] raises IndexError. Presumably, with return_logits=True, E[1] holds the logits.
try:
    print(E[1].softmax(-1) * ss8_ohv)
except IndexError:
    # Only the "no logits returned" case is expected here. Any other failure
    # (e.g. a shape mismatch against ss8_ohv) should surface, not be masked.
    print('Logits not returned.')

Logits not returned.


Test the force output (forces are the negative gradient of the energy with respect to the coordinates)

In [9]:
# Track gradients on the input coordinates (in place) so the energy can be
# differentiated with respect to them.
pdb_coord.requires_grad_(True)
E = E_model(pdb_coord)

# dE/dx for a (possibly non-scalar) energy, seeded with ones; no higher-order
# graph is built and the graph is freed afterwards.
(dE_dx,) = torch.autograd.grad(
    E,
    pdb_coord,
    grad_outputs=torch.ones_like(E),
    create_graph=False,
    retain_graph=False,
)

# Forces are the negative energy gradient.
F = -dE_dx

In [10]:
# Coordinates of the atoms selected by ca_idxs (presumably the C-alpha atoms),
# shown for comparison with the forces on the same atoms in the next cell.
pdb_coord[ca_idxs]

tensor([[ 0.8529,  1.3238,  0.9905],
        [ 0.8518,  1.2151,  0.8892],
        [ 0.7571,  1.2472,  0.7727],
        [ 0.7596,  1.1685,  0.6618],
        [ 0.6633,  1.1828,  0.5535],
        [ 0.5379,  1.0997,  0.5671],
        [ 0.5551,  0.9709,  0.6030],
        [ 0.4500,  0.8758,  0.6425],
        [ 0.4385,  0.8579,  0.7944],
        [ 0.3206,  0.8220,  0.8422],
        [ 0.2953,  0.8195,  0.9854],
        [ 0.3467,  0.6952,  1.0476],
        [ 0.4208,  0.7063,  1.1541],
        [ 0.4605,  0.5895,  1.2307],
        [ 0.4607,  0.6220,  1.3797],
        [ 0.3414,  0.6567,  1.4326],
        [ 0.3193,  0.7177,  1.5673],
        [ 0.3847,  0.6668,  1.6935],
        [ 0.3808,  0.5296,  1.7115],
        [ 0.4206,  0.4695,  1.8373],
        [ 0.5685,  0.4374,  1.8243],
        [ 0.6278,  0.4485,  1.7039],
        [ 0.7782,  0.4183,  1.6877],
        [ 0.8485,  0.5538,  1.6909],
        [ 0.7843,  0.6610,  1.6337],
        [ 0.8302,  0.7997,  1.6296],
        [ 0.8857,  0.8536,  1.7583],
 

In [11]:
# Forces on the same atoms selected by ca_idxs (presumably the C-alpha atoms).
F.index_select(0, ca_idxs)

tensor([[ 1.4575e-01,  1.8658e-01,  5.3599e-01],
        [-3.2425e-01, -8.5129e-02, -9.5306e-01],
        [ 1.0503e-01,  4.8757e-02,  5.1711e-02],
        [-3.4708e-01, -6.0206e-01, -4.8584e-01],
        [ 7.0381e-01,  4.0217e-01,  5.1811e-01],
        [-4.0688e-02, -2.1268e-01, -3.9972e-02],
        [ 2.0603e-01,  9.6857e-02, -7.5215e-02],
        [ 1.8231e-01,  4.4876e-01, -3.0594e-01],
        [-1.1652e-01, -1.0415e-01,  1.9226e-01],
        [-4.5313e-02,  4.2904e-02, -5.2270e-02],
        [-5.1524e-01, -1.0201e-01,  1.5420e-01],
        [ 1.6108e-01,  5.7711e-02, -7.5600e-02],
        [ 1.1591e-02,  5.0106e-02,  1.7541e-01],
        [ 1.0269e-01, -3.8685e-01,  3.5890e-01],
        [ 4.7330e-02,  6.0918e-02,  4.6103e-01],
        [ 3.6434e-01, -7.2732e-02,  2.7346e-01],
        [ 4.2195e-02, -1.0380e-01, -1.7019e-02],
        [ 3.8959e-01,  6.2036e-02,  2.9065e-02],
        [ 2.9972e-02,  2.0848e-01, -9.5175e-03],
        [ 1.2507e-01, -1.5330e-01, -3.4507e-01],
        [ 3.7053e-01

Trace the model with TorchScript

In [13]:
# Trace the model with one example input. `(pdb_coord)` is not a tuple, it is just
# `pdb_coord`, so pass the explicit 1-tuple of example inputs.
# NOTE(review): tracing records only the operations executed for this example.
# Python control flow that depends on the input is frozen, so verify outputs on
# other conformations.
traced_model = torch.jit.trace(E_model, (pdb_coord,))

Check that the traced model reproduces the eager model's output

In [14]:
# Run the traced model and verify it matches the eager model on the same input,
# to floating-point tolerance, instead of comparing the printed values by eye.
traced_E = traced_model(pdb_coord)
with torch.no_grad():
    eager_E = E_model(pdb_coord)
torch.testing.assert_close(traced_E.detach(), eager_E)
traced_E

tensor(0.3918, grad_fn=<MulBackward0>)

If desired, save the traced model

In [15]:
# Export the TorchScript model. The file name records the `temp` value that was
# passed to load_mol_info.
save_path = 'Home_model_{}K.pt'.format(temp)
torch.jit.save(traced_model, save_path)