In [1]:
import torch
import numpy as np
from torch_cluster import radius_graph
from torch_scatter import scatter_add, scatter
import mdtraj
import time
import math

  from .autonotebook import tqdm as notebook_tqdm


Note, use root environment (base) for this notebook since pip doesn't appear to be installing PyTorch in a separate environment

In [2]:
#assert True == False

Check to ensure Apple Silicon GPU is available

In [2]:
# Set user
user = 'airasj'
#user = 'mitadm'

# Set pathway
chig_path = '/Users/{}/OneDrive/MIT_Research/1st_Project/Small_Protein_MD/Datasets_MD/chignolin/MD/umbrella/CHARMM/c36_gbsa_hbond_constraint'.format(user)

# Load chignolin pdb
top = mdtraj.load('{}/parm_structs/chignolin_folded_avg_struct.pdb'.format(chig_path)).topology

# Load chignolin trajectory
umb_traj = torch.Tensor(mdtraj.load_netcdf('{}/chig_c36_gbn2_hbond_constraint_umb_100thou.nc'.format(chig_path), top).xyz)[0]
umb_traj_2 = torch.cat([umb_traj, umb_traj])

# Load umbrella energies
umb_energies = torch.Tensor(np.load('{}/energies/chig_100thou_Ugbsa_umb.npy'.format(chig_path)))

# Load atomic numbers (will use them as features) and partial charges
umb_charges = torch.Tensor(np.load('/Users/{}/OneDrive/MIT_Research/1st_Project/Small_Protein_MD/Datasets_MD/chignolin/MD/explicit/chignolin_ptCharges_c36.npy'.format(user))).view(-1, 1)
umb_elements = torch.Tensor(np.load('/Users/{}/OneDrive/MIT_Research/1st_Project/Small_Protein_MD/Datasets_MD/chignolin/MD/explicit/chignolin_elementsN_c36.npy'.format(user))).view(-1, 1)

Get atom types from topology

In [3]:
umb_types = np.array([a.name for a in top.atoms])

# Set all pairs
all_types = np.array(['C', 'CA', 'CB', 'CD', 'CD1', 'CD2', 'CE', 'CE1', 'CE2', 'CE3', 
                      'CG', 'CG1', 'CG2', 'CH2', 'CZ', 'CZ2', 'CZ3', 'H', 'H1', 'H2', 
                      'H3', 'HA', 'HA2', 'HA3', 'HB', 'HB1', 'HB2', 'HB3', 'HD1', 'HD11', 
                      'HD12', 'HD13', 'HD2', 'HD21', 'HD22', 'HD23', 'HD3', 'HE', 'HE1', 'HE2',
                      'HE21', 'HE22', 'HE3', 'HG', 'HG1', 'HG11', 'HG12', 'HG13', 'HG2', 'HG21',
                      'HG22', 'HG23', 'HG3', 'HH', 'HH11', 'HH12', 'HH2', 'HH21', 'HH22', 'HZ',
                      'HZ1', 'HZ2', 'HZ3', 'N', 'ND1', 'ND2', 'NE', 'NE1', 'NE2', 'NH1', 
                      'NH2', 'NZ', 'O', 'OD1', 'OD2', 'OE1', 'OE2', 'OG', 'OG1', 'OH', 
                      'OXT', 'SD', 'SG'])

# Embed the types
from sklearn.preprocessing import LabelEncoder
embed = LabelEncoder()
embed.fit(all_types)
print(embed.transform(np.array(['CA'])))
umb_embedding = torch.LongTensor(embed.transform(umb_types))
print(umb_embedding)

[1]
tensor([63, 17, 19, 20,  1, 21,  2, 27, 26, 10,  4, 28,  7, 38, 14, 79, 53,  5,
        32,  8, 39,  0, 72, 63, 17,  1, 21,  2, 27, 26, 10,  4, 28,  7, 38, 14,
        79, 53,  5, 32,  8, 39,  0, 72, 63, 17,  1, 21,  2, 27, 26, 10, 73, 74,
         0, 72, 63,  3, 36, 32,  1, 21,  2, 27, 26, 10, 52, 48,  0, 72, 63, 17,
         1, 21,  2, 27, 26, 10, 52, 48,  3, 75, 76,  0, 72, 63, 17,  1, 21,  2,
        24, 78, 44, 12, 49, 50, 51,  0, 72, 63, 17,  1, 23, 22,  0, 72, 63, 17,
         1, 21,  2, 24, 78, 44, 12, 49, 50, 51,  0, 72, 63, 17,  1, 21,  2, 27,
        26, 10,  4, 28, 67, 38,  8,  5,  9, 42, 16, 62, 15, 61, 13, 56,  0, 72,
         0, 72, 80, 63, 17,  1, 21,  2, 27, 26, 10,  4, 28,  7, 38, 14, 79, 53,
         5, 32,  8, 39])


In [4]:
# Create 2 frame sets
coords_2 = torch.Tensor(mdtraj.load_netcdf('{}/chig_c36_gbn2_hbond_constraint_umb_100thou.nc'.format(chig_path), top).xyz)[:2]

coords_2_list = []
for coord in coords_2:
    coords_2_list.append(coord)
coords_2 = torch.cat(coords_2_list)

batch_2_list = []
for b in range(0,2):
    batch_2 = torch.zeros(len(umb_charges), dtype=torch.long)+b
    batch_2_list.append(batch_2)
    
batch_2 = torch.cat(batch_2_list)

#charge_pairs_2 = torch.cat([charge_pairs, charge_pairs])
umb_elements_2 = torch.cat([umb_elements, umb_elements])
umb_embedding_2 = torch.cat([umb_embedding, umb_embedding])

In [5]:
# Set batch tensor
batch = torch.cat([torch.zeros_like(umb_elements)+i for i in range(1)]).squeeze().type(torch.LongTensor)
print(batch)

# Compute edges for test frame
edges = radius_graph(umb_traj, r=1, max_num_neighbors=100000)
edges_2 = radius_graph(coords_2, r=1, batch=batch_2, max_num_neighbors=100000)
edges_3 = radius_graph(coords_2[166:], r=1, batch=batch, max_num_neighbors=100000)
print(edges)


# Compute charge pairs
#charge_pairs = umb_charges[edges[0]]*umb_charges[edges[1]]

#umb_elements = umb_elements.to(mps_device)
#coords = umb_traj.to(mps_device)
#edges = edges.to(mps_device)
#charge_pairs = charge_pairs.to(mps_device)
#batch = torch.zeros(len(coords), dtype=torch.long).to(mps_device)
#batch = torch.zeros(len(umb_traj), dtype=torch.long)

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([[  3,   4,   5,  ..., 121,  30,  39],
        [  0,   0,   0,  ..., 165, 165, 165]])


Test function for extracting edge batches, this can be used relatively quickly for normalization

### Test code from SAKE.py

In [6]:
import sys
import matplotlib.pyplot as plt
from collections import OrderedDict
import re
import Schake_CA_model_v1 as Schake

In [8]:
assert True == False

AssertionError: 

Build the Schake model

In [14]:
sake_modular = Schake.create_Schake(hidden_channels = 32, 
                                    num_layers = 4, 
                                    kernel_size = 18,
                                    cosine_offset = 0.5,
                                    sake_low_cut = 0, 
                                    sake_high_cut = 0.5, 
                                    schnet_low_cut = 0.5, 
                                    schnet_high_cut = 2.5, 
                                    schnet_act = torch.nn.CELU(alpha=2.0), 
                                    sake_act = torch.nn.CELU(alpha=2.0), 
                                    out_act = torch.nn.CELU(alpha=2.0),
                                    schnet_sel = 1,
                                    #schnet_sel = None,
                                    num_heads = 4, 
                                    embed_type = 'names', 
                                    num_out_layers = 3,
                                    max_num_neigh = 10000,
                                    normalize = False,
                                    device = 'cpu')

# Define function to edit a state dict
def edit_stateDict(state_dict):
        
    # Define new dictionary
    mod_dict = OrderedDict()
        
    # Define pattern to strip
    pattern = re.compile('module.')
        
    # Loop through loaded state dict
    for key, value in state_dict.items():
        if re.search('module.', key):
            mod_dict[re.sub(pattern, '', key)] = value
        else:
            mod_dict = state_dict

    return mod_dict

# Set statedict path
stateDict_path = '/Users/airasj/OneDrive/MIT_Research/2nd_Project/Datasets/SwissProt/train_models/Schake/testing/energies_only/res400_8GPUs_4node_varBatch_32w_1nm/stateDicts/Schake_af2SP_epoch140.pt'

# Load state dict for the model (parameters obviously won't be perfect matches for testing)
stateDict = torch.load(stateDict_path, map_location = torch.device('cpu'))
        
# Modify the stateDict with correct names
mod_stateDict = edit_stateDict(stateDict)
        
# Apply stateDict
sake_modular.load_state_dict(mod_stateDict)

<All keys matched successfully>

Test modular SAKE model

In [16]:
# Get predictions
with torch.no_grad():
    start = time.time()
    #out = sake_modular(umb_elements.type(torch.LongTensor).squeeze(), umb_traj, batch=batch)
    out = sake_modular(umb_embedding, umb_traj, batch=batch)
    end = time.time()
    print('------------------------')
    print('Time elapsed: {:.3f} sec'.format(end-start))
    print(out)
    
    start = time.time()
    #out_2 = sake_modular(umb_elements_2.type(torch.LongTensor).squeeze(), coords_2, batch=batch_2)
    out_2 = sake_modular(umb_embedding_2, coords_2, batch=batch_2)
    end = time.time()
    print('------------------------')
    print('Time elapsed: {:.3f} sec'.format(end-start))
    print(out_2)
    
    start = time.time()
    #out_3 = sake_modular(umb_elements.type(torch.LongTensor).squeeze(), coords_2[166:], batch=batch)
    out_3 = sake_modular(umb_embedding, coords_2[166:], batch=batch)
    end = time.time()
    print('------------------------')
    print('Time elapsed: {:.3f} sec'.format(end-start))
    print(out_3)

------------------------
Time elapsed: 0.062 sec
tensor(-2448.3525)
------------------------
Time elapsed: 0.061 sec
tensor([-2448.3525, -4832.5327])
------------------------
Time elapsed: 0.036 sec
tensor(-4832.5327)


#### Examine outputs

In [None]:
out[2]

In [None]:
umb_embedding[out[2]]

In [None]:
umb_embedding[25]

In [None]:
print(umb_types[101])
print(umb_types[0])

In [None]:
import matplotlib.pyplot as plt
plt.plot(out[2][100])

In [None]:
print(out[1].min().sqrt())
print(out[1].max().sqrt())

In [None]:
print(out[1].shape)

In [None]:
print(out[2].min())
print(out[2].max())

In [None]:
print(out[2].shape)

Examine RBFs

In [None]:
out[3][0]

In [None]:
out[4][0]

In [None]:
assert True == False

In [None]:
#np.save('rbfs/rbf_expnorm_sqrt_diff_end.npy', out[4].numpy(), allow_pickle=False)

In [None]:
np.save('rbfs/SAKE_rbf_expnorm_sqrt_1nm.npy', out[3].numpy(), allow_pickle=False)

In [None]:
out[0]

#### Note on RBF and cosine cutoff functions

- Cosine cutoff
    - SAKE -> the cosine cutoff starts at 0, ends at fixed point of $2 r_\mathrm{cut}^\mathrm{SAKE}$. B/c the SAKE cutoff is set to 0.5 nm here, the minimum value for the cosine cutoff is 0.5. B/c the cutoff now needs to closely align with the offsetted cutoff for the SAKE layer, we simply input $r$ rather than $r^2$ for simplicity. The function is as follows:
    $\frac{1}{2} \left( \cos \left[ \frac{\pi r}{2 r_\mathrm{cut}^\mathrm{SAKE}} \right] + 1 \right)$
    - SchNet -> B/c the cutoff distance for SchNet is so much larger now, can't use the same cutoff function as from SAKE. As such, the function has been modified as follows: 
    $ \frac{1}{4} \left( \cos \left[ \frac{\pi \left( r - \mathrm{offset} \right)}{r_\mathrm{cut}^\mathrm{SchNet} - \mathrm{offset}} \right] + 1 \right)$.
    This modification provides more coverage of the entire range of alpha carbons that SchNet can process.

- RBF function
    - SAKE -> like with the original Schake, the expnorm function is used with a start of 0 nm, and a fixed endpoint of $2 \left( r_\mathrm{cut}^\mathrm{SAKE} \right)^2$ (despite the fact that this point will never be reached). $r^2$ is inputted to the model rather than $r$, as in the original formulation. Note that setitng the endpoint to $2 \left( r_\mathrm{cut}^\mathrm{SAKE} \right)^2$ ensures that the RBF will behave consistently with the original Schake
    - SchNet -> like with the original Schake, the expnorm function is used with a start of 0 nm (despite the fact that this 0 nm point will never be reached) and an endpoint of $r^\mathrm{SchNet}_\mathrm{cut}$. Like the original Schake, $r$ is inputted to the function.

In [None]:
r = 1
0.5*(torch.cos(torch.tensor(math.pi)*r)+1)

#### Compute number of pairs for benchmark dataset

Goal here is to examine the total number of atom pairs produced from the benchmark dataset so that we can try and determine a good cutoff distance to use for the model

In [27]:
import sys
sys.path.append('/Users/airasj/OneDrive/MIT_Research/2nd_Project/Datasets/SwissProt/python')
import dataset_prep

# Load dataset
bench_dataset = torch.load('/Users/airasj/OneDrive/MIT_Research/2nd_Project/Figures/benchmark_figures/SPAF_dataset_bench_big.pt')

# Loop through dataset, get num of atoms
num_atoms, entries = [], []
for data in bench_dataset:
    num_atoms.append(data[-2].item())
    entries.append(data[-1].item())
    
# Conver to numpy array
num_atoms = np.array(num_atoms)
entries = np.array(entries)

Loop through the benchmarking dataset to get number of pairs

In [30]:
def coord2radial(edge_index, coord):
    row, col = edge_index
    coord_diff = coord[row] - coord[col]
    radial = torch.sum(coord_diff**2, 1).unsqueeze(1)

    return radial, coord_diff

# SAKE cutoff
sake_cut = 0.5  # -> this is 0.5 nm
schnet_cut = 2.5
#h_sel = 1
h_sel = None

sake_pairs, schnet_pairs = [], []
other_schnet_pairs = []
for data in bench_dataset:
# Generate adjacency lists
    edges = radius_graph(data[0], 
                         r=schnet_cut, # This can be different depending on which model
                         batch=torch.zeros(data[-2], dtype=torch.long), 
                         max_num_neighbors=10000
                        )

    # Compute pairwise distances and edge vectors
    radial, coord_diff = coord2radial(edges, data[0])

    # Compute distances (sqrt of radial)
    dist = torch.sqrt(radial)

    # Filter edges, coord_diff, radial, dist, rbf based on individual cutoffs
    sake_mask = torch.where((dist < sake_cut) & (dist >= 0))[0]
    schnet_mask = torch.where((dist >= sake_cut) & (dist <= schnet_cut))[0]

    # Reshape the edges, extract only necessary edges for each model
    sake_edges = edges.T[sake_mask].T
    schnet_edges = edges.T[schnet_mask].T

    # Extract radial, coord_diff for SAKE pairs only
    sake_radial, sake_coord_diff = radial[sake_mask], coord_diff[sake_mask]

    # Extract distance for SchNet pairs only
    schnet_dist = dist[schnet_mask]
    
    if h_sel != None:
        # For SchNet pairs, create adjacency list in terms of species
        h_schnet_edges = data[-3][schnet_edges]

        # Filter SchNet edges to only include atom type of interest
        h_mask = torch.where(h_schnet_edges[0] == h_sel)[0]
        schnet_edges = schnet_edges.T[h_mask].T
    
    # Append number of pairs to list
    sake_pairs.append(sake_edges.shape[-1])
    schnet_pairs.append(schnet_edges.shape[-1])
    # Check to see number of pairs when other adjacency list used
    other_schnet_pairs.append(torch.where(h_schnet_edges[1] == 1)[0].shape[-1])

In [31]:
start = 510
stop = start + 10
print(np.array(schnet_pairs)[start:stop])
print(np.array(other_schnet_pairs)[start:stop])

[27864150 29276034  5740080 34216242 33934474 27630380 33760678 17072840
 32863462 20193182]
[1368804 1368804 1368804 1368804 1368804 1368804 1368804 1368804 1368804
 1368804]


In [13]:
np.array_equal(np.array(schnet_pairs), np.array(other_schnet_pairs))

True

As structures get bigger, there starts to be a discrepancy between the adjacency lists. This isn't a major problem, but it is interesting

Save the number of pairs

In [14]:
np.save('schake_pair_analysis/schake_sake_pairs.npy', np.array(sake_pairs))
np.save('schake_pair_analysis/schake_schnet_pairs.npy', np.array(schnet_pairs))

In [33]:
np.save('schake_pair_analysis/num_atoms.npy', num_atoms)

#### Get info to create figure displaying how Schake-C$\alpha$ works

Get list of pairs for A1W0R3

In [12]:
torch.zeros(bench_dataset[293][-2], dtype=torch.long)

tensor([0, 0, 0,  ..., 0, 0, 0])

In [13]:
with torch.no_grad():
    start = time.time()
    out = sake_modular(bench_dataset[293][-3], bench_dataset[293][0], batch=torch.zeros(bench_dataset[293][-2], dtype=torch.long))
    end = time.time()
    print('------------------------')
    print('Time elapsed: {:.3f} sec'.format(end-start))
    print(out[0])

------------------------
Time elapsed: 3.468 sec
tensor(49464496.)


#### Create pair fraction figure

Load the pairs