In [2]:
%load_ext autoreload
%autoreload 2
import sys
import h5py 
import pandas as pd
import numpy as np
import torch 
sys.path.insert(0, "../examples")
sys.path.insert(0, "data/components/")
from QMmodel import GNN_QM
from MDmodel import GNN_MD
from data.components.transformQM import GNNTransformQM
from data.components.transformMD import GNNTransformMD
from data.processing.inference_QM import main




## Creating an H5 file from a ligand pdbid

We want to run inference on a new structure from the PDB. You can either provide the name of an already downloaded file (`fileName`) or just give the `pdbid`, in which case the file is downloaded automatically. (If you run the script directly in the terminal, just give the keywords at the prompt.)

In [2]:
class Args:
    """Attribute container handed to ``main`` in place of parsed command-line arguments."""

    datasetOutName = 'inference_for_qm.hdf5'  # name of the HDF5 file loaded again further below
    fileName = None  # no already downloaded file is given; only the pdbid is set
    pdbid = "vww"


args = Args()

In [3]:
main(args)

reading vww.sdf


## Prediction of Ionization potential and Hardness by our model

We load the created h5 file and store the elements and coordinates in a dataframe.

In [4]:
# Open the file written by the previous step explicitly read-only: it is only
# inspected here, and an explicit mode does not depend on the h5py default.
qmh5_file = "inference_for_qm.hdf5"
qm_H5File = h5py.File(qmh5_file, "r")

In [5]:
# Collect the ligand's coordinates and element entries in a single frame.
ligand_properties = qm_H5File["vww"]["atom_properties"]
prop = ligand_properties["atom_properties_values"]
column_names = ["x", "y", "z", "element"]

atoms = pd.DataFrame(columns=column_names)
for column_index, axis_name in enumerate(column_names[:3]):
    # columns 0, 1 and 2 of the property table are stored as x, y and z
    atoms[axis_name] = prop[:, column_index].astype(np.float32)

atoms["element"] = np.array(list(ligand_properties["atoms_names"][:]))


In [6]:
# Bundle the inputs under the keys handed to the QM transform.
# NOTE(review): labels=0 and bonds=None look like placeholders for inference — confirm.
item = dict(atoms=atoms, labels=0, bonds=None, id="vww")

transform = GNNTransformQM()
data_item = transform(item)

We run inference using cpu.

In [8]:
# Load the stored checkpoint onto the CPU and pull out the model weights.
checkpoint = torch.load(
    "../examples/logs/QM_latest/best_weights_rep0.pt",
    map_location=torch.device('cpu'),
)
cpt = checkpoint["model_state_dict"]

# Build the network for the transformed item and restore the weights.
model = GNN_QM(data_item.num_features, 64)
model.load_state_dict(cpt)
model.eval()

GNN_QM(
  (lin0): Linear(in_features=25, out_features=64, bias=True)
  (conv): NNConv(64, 64, aggr=mean, nn=Sequential(
    (0): Linear(in_features=1, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=4096, bias=True)
  ))
  (gru): GRU(64, 64)
  (set2set): Set2Set(64, 128)
  (lin1): Linear(in_features=128, out_features=64, bias=True)
  (lin2): Linear(in_features=64, out_features=2, bias=True)
)

In [9]:
# predict with the model; gradients are not needed for inference, so the
# forward pass runs without building the autograd graph
with torch.no_grad():
    y_hat = model(data_item)

In [10]:
y_hat

tensor([ 0.0480, -0.0375], grad_fn=<ViewBackward0>)

## Creating H5 file for a protein-ligand complex

Similar to the ligand case, we download a PDB file, convert it to Amber format and store it in an H5 file. This step requires AmberTools to be installed, so you might have to switch the conda environment.

In [1]:
# NOTE: this rebinds `main`, replacing the inference_QM entry point imported in
# the first cell; re-run that import before repeating the QM section.
from data.processing.pdb_to_h5 import main

In [4]:
mdh5_file = "inference_for_md.hdf5"

In [3]:
class Args:
    """Attribute container handed to ``main`` in place of parsed command-line arguments."""

    datasetOutName = mdh5_file
    mask = "!@H=" # no Hydrogens, see https://amberhub.chpc.utah.edu/atom-mask-selection-syntax/
    mapPath = "data/processing/Maps/"
    fileName = None
    pdbid = "11GS"


args = Args()

In [4]:
main(args)

11GS/11GS.pdb was created. Please always use this file for inspection because the coordinates might get translated during amber file generation and thus might vary from the input pdb file.
The following trajectory was created: pytraj.TrajectoryIterator, 1 frames: 
Size: 0.000146 (GB)
<Topology: 6534 atoms, 416 residues, 2 mols, non-PBC>
           
molecule begin atom index [0, 1631, 3262] [1631, 1631]


## Prediction of adaptability by our model

In [5]:
# switch to misato env if not running from container

# Open read-only: the file is only inspected here.
md_H5File = h5py.File(mdh5_file, "r")
structure_group = md_H5File["11GS"]

cutoff = structure_group["molecules_begin_atom_index"][:][-1] # cutoff defines protein atoms

# Read the reference coordinates from disk once instead of once per axis.
coordinates_ref = structure_group["atoms_coordinates_ref"][:][:cutoff]

column_names = ["x", "y", "z", "element"]
atoms_protein = pd.DataFrame(columns = column_names)
atoms_protein["x"] = coordinates_ref[:, 0]
atoms_protein["y"] = coordinates_ref[:, 1]
atoms_protein["z"] = coordinates_ref[:, 2]

atoms_protein["element"] = structure_group["atoms_element"][:][:cutoff]

# NOTE(review): scores=0 looks like a placeholder for inference — confirm.
item = {
    "scores": 0,
    "id": "11GS",
    "atoms_protein": atoms_protein,
}

transform = GNNTransformMD()
data_item = transform(item)



In [6]:
 md_H5File["11GS"]["molecules_begin_atom_index"][:]

array([   0, 1631, 3262])

In [7]:
import torch 

# Load the stored checkpoint onto the CPU and pull out the model weights.
checkpoint = torch.load(
    "../examples/logs/MD_latest/best_weights_rep0.pt",
    map_location=torch.device('cpu'),
)
cpt = checkpoint["model_state_dict"]

# Build the network for the transformed item and restore the weights.
model = GNN_MD(data_item.num_features, 64)
model.load_state_dict(cpt)
model.eval()

GNN_MD(
  (conv1): GCNConv(11, 64)
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): GCNConv(64, 128)
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): GCNConv(128, 256)
  (bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): GCNConv(256, 256)
  (bn4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): GCNConv(256, 512)
  (bn5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=1, bias=True)
)

In [8]:
model(data_item).shape

torch.Size([3262])

In [9]:
model(data_item)

tensor([4.2869, 4.5388, 3.8849,  ..., 1.5878, 1.5421, 1.8159],
       grad_fn=<ViewBackward0>)

## Visualization of Adaptability

In [1]:
# switch to ambertools env if not running from container
import nglview as nv
import pytraj as pt
import os




In [2]:
def add_opacity_to_spheres(num_spheres, opacity, widget=None):
    """Set the opacity of the last ``num_spheres`` components of a view.

    Parameters
    ----------
    num_spheres : int
        Number of trailing components to update.
    opacity : float
        Value forwarded to ``update_representation`` as ``opacity``.
    widget : optional
        View to update. Defaults to the notebook-level ``view``.

    Notes
    -----
    Assumes the spheres are the most recently added components — confirm if
    other components are added after them.
    """
    target = view if widget is None else widget
    # Component indices are zero-based, so the last valid one is n_components - 1.
    last_component = target.n_components - 1
    for offset in range(num_spheres):
        target.update_representation(component=last_component - offset, opacity=opacity)
              
def show_ada_spheres(indices, prediction, color, radiusFactor):
    """Add one sphere per atom index to the notebook-level ``view``.

    Coordinates are read from the notebook-level ``traj``. ``indices`` are
    zero-based; the atom mask is built from ``index + 1``. Each sphere gets
    the radius ``prediction[index] / radiusFactor`` and the given ``color``.
    """
    for atom_index in indices:
        atom_mask = '@' + str(atom_index + 1)
        atom_value = prediction[atom_index]
        x, y, z = traj[atom_mask].xyz[0][0]
        view.shape.add_sphere([x, y, z], color, atom_value / radiusFactor)

def get_entries(struct, f):
    """Return the first trajectory frame and the element entry of a structure.

    Parameters
    ----------
    struct : str
        Name of the structure group inside ``f``.
    f : mapping with a ``get`` method (e.g. an open h5py file)
        Container queried with ``"<struct>/<dataset>"`` keys.

    Returns
    -------
    tuple
        ``(f["<struct>/trajectory_coordinates"][0], f["<struct>/atoms_element"])``.
    """
    # The adaptability dataset is not returned here, so it is not looked up.
    atoms_coordinates_ref = f.get(struct + '/trajectory_coordinates')[0]
    atoms_element = f.get(struct + '/atoms_element')
    return atoms_coordinates_ref, atoms_element

def get_entries_ada(struct, f):
    """Return the reference coordinates, elements and adaptability of a structure.

    Each entry is looked up with ``f.get("<struct>/<dataset>")``, so an entry
    that ``f.get`` cannot find comes back as whatever ``f.get`` returns for a
    missing key.
    """
    dataset_names = (
        'atoms_coordinates_ref',
        'atoms_element',
        'atoms_feature_adaptability',
    )
    return tuple(f.get(struct + '/' + dataset_name) for dataset_name in dataset_names)

##
## Not needed
##
        
def show_ada_r_spheres_color_conversion(indices, ada_type, color, radiusFactor, index_conversion):
    """Add one sphere per atom index, looking values up through an index mapping.

    Uses the notebook-level ``prediction``, ``traj`` and ``view``.

    Parameters
    ----------
    indices : iterable of int
        Zero-based atom indices; the atom mask is built from ``index + 1``.
    ada_type : str
        Key selecting the entry of ``prediction`` to read values from.
    color
        Colour passed to ``view.shape.add_sphere``.
    radiusFactor : number
        Divisor applied to the looked-up value to obtain the sphere radius.
    index_conversion : mapping
        Maps each entry of ``indices`` to the position used in ``prediction``.
    """
    for atom_index in indices:
        pred_mask = '@' + str(atom_index + 1)
        # The radius previously read the undefined name `sh_type`; the value
        # looked up with `ada_type` is the one to draw.
        sh_value = prediction[ada_type][index_conversion[atom_index]]
        x, y, z = traj[pred_mask].xyz[:,:,:][0][0]
        view.shape.add_sphere([x, y, z], color, sh_value / radiusFactor)
        
def get_index_conversion(atoms_element, atoms_coordinates_ref):
    """Map positions in the hydrogen-free ordering to all-atom indices.

    Parameters
    ----------
    atoms_element : array-like supporting ``[:]``
        Element entry per atom; entries equal to 1 are left out.
    atoms_coordinates_ref : array-like
        Only its first dimension is used, as the number of atoms to consider.

    Returns
    -------
    dict
        ``{position_among_kept_atoms: all_atom_index}`` with plain ``int``
        keys and values, in ascending order.
    """
    n_atoms = np.shape(atoms_coordinates_ref)[0]
    noh_indices = np.where(atoms_element[:] != 1)[0] # change if not hydrogen
    # np.where yields ascending indices, so enumerating them gives the same
    # numbering as a running counter without a membership test per atom.
    kept_indices = noh_indices[noh_indices < n_atoms]
    return dict(enumerate(kept_indices.tolist()))

def get_values_from_atomIndices(prediction, indices, ada_type, index_conversion):
    """Return ``prediction[ada_type]`` at the converted position of each index.

    Every entry of ``indices`` is translated through ``index_conversion``
    before the lookup; the result is a list in the order of ``indices``.
    """
    return [prediction[ada_type][index_conversion[atom_index]] for atom_index in indices]

def convert_indices(indices, index_conversion):
    """Translate each index through ``index_conversion`` and return the results as a list."""
    return [index_conversion[atom_index] for atom_index in indices]


### OUT

We need to load the h5 file with hydrogens and the h5 file with the hydrogens stripped (noh) after processing so that we assign the correct atom indices for the pdb file that we want to visualize. 

In [14]:
f_ada = h5py.File('../data/MD/h5_files/tiny_md.hdf5', 'r')
f_ada_noh = h5py.File(mdh5_file, 'r')

In [15]:
print(np.shape(f_ada['11GS']['trajectory_coordinates'][0]))
print(f_ada_noh['11GS']['atoms_coordinates_ref'])

(6600, 3)
<HDF5 dataset "atoms_coordinates_ref": shape (3262, 3), type "<f8">


In [16]:
struct = '11GS'

# Entries of the file that still contains hydrogens.
atoms_coordinates_ref, atoms_element = get_entries(struct, f_ada)
# Entries of the file written during inference (hydrogens stripped).
atoms_coordinates_ref_noH, atoms_element_noH, feature_ada_noh = get_entries_ada(struct, f_ada_noh)

# Forward mapping and its inverse (values become keys).
index_conversion = get_index_conversion(atoms_element, atoms_coordinates_ref)
inverse_index_conversion = dict(zip(index_conversion.values(), index_conversion.keys()))

In [None]:
# we use the MISATO colors
# Both schemes share every element colour except carbon, so the JavaScript is
# generated from a single template with the carbon colour filled in.
atom_color_template = '''
 this.atomColor = function (atom) {
     if (atom.element == "C") {
       return CARBON_HEX // C
     } else if (atom.element == "H") {
       return 0xecf0f1
     } else if (atom.element == "S") {
       return 0xf1c40f
     } else if (atom.element == "N") {
       return 0x2980b9
     } else if (atom.element == "O") {
       return 0xFF0D0D
     }
 }
'''
cm = nv.color.ColormakerRegistry
for scheme_name, carbon_hex in (('dgreenC', '0x60a854'), ('lgreenC', '0x89e876')):
    cm.add_scheme_func(scheme_name, atom_color_template.replace('CARBON_HEX', carbon_hex))

### OUT

In [17]:
prediction = pd.DataFrame(model(data_item).detach().numpy(), columns = ['prediction'])

In [18]:
f_ada_noh['11GS'].keys()


<KeysViewHDF5 ['atoms_coordinates_ref', 'atoms_element', 'atoms_number', 'atoms_residue', 'atoms_type', 'molecules_begin_atom_index']>

In [19]:
prediction = pd.DataFrame(model(data_item).detach().numpy(), columns = ['prediction'])

In [20]:
prediction

Unnamed: 0,prediction
0,4.251047
1,4.606583
2,4.346603
3,3.720964
4,4.035377
...,...
3257,2.260487
3258,2.674389
3259,1.613312
3260,1.525985


In [21]:
md_H5File['11GS']['atoms_coordinates_ref']

<HDF5 dataset "atoms_coordinates_ref": shape (3262, 3), type "<f8">

In [36]:
struct = '11GS'
# Load the PDB file from the directory named after the structure.
pdb_path = f'{struct}/{struct}.pdb'
traj = pt.load(pdb_path)
view = nv.show_pytraj(traj)

In [37]:
view

NGLWidget()

Check from here:

In [32]:
# Atom indices of residues 27, 58, 107 and 100 whose names match C=, N=, O= or S=.
residue_indices1, residue_indices2, residue_indices3, residue_indices4 = (
    list(traj.top.atom_indices(':' + str(residue_number) + '@C=,N=,O=,S='))
    for residue_number in (27, 58, 107, 100)
)
residue_indices = [
    atom_index
    for residue_group in (residue_indices1, residue_indices2, residue_indices3, residue_indices4)
    for atom_index in residue_group
]
# NOTE(review): the sphere-drawing call that follows is given residue_indices,
# not converted_indices — confirm which indexing the prediction expects.
converted_indices = convert_indices(residue_indices, inverse_index_conversion)

In [34]:
show_ada_spheres(residue_indices, prediction['prediction'], (0,1,0), 1.5)

In [None]:
view.render_image(trim=True, factor=12)

In [35]:
add_opacity_to_spheres(view.n_components-1, 0.5)

In [None]:
view.render_image(trim=True, factor=12)

In [None]:
view.download_image()

In [None]:
# The helper defined in this notebook is `show_ada_r_spheres_color_conversion`,
# and the prediction frame only has a 'prediction' column.
yellow = (1, 1, 0)  # RGB tuple, same form as the colour passed to show_ada_spheres
show_ada_r_spheres_color_conversion(residue_indices, 'prediction', yellow, 2.5, inverse_index_conversion)

In [65]:
view