# This script creates the graphs from the PDB and Mol2 files.
![Graph visualization](images/graph_construction.png)<br>

In [1]:
import numpy as np
import pandas as pd
import random
import torch
import math
from Bio import SeqIO
import Bio.PDB
import pickle as pickle
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
import matplotlib.pyplot as plt
import os
from Bio import PDB
from rdkit import Chem
import blosum as bl
import pymol
from pymol import cmd
from graph_utils import *

In [2]:
class CFG:
    """Machine-specific absolute paths used by this notebook."""
    # Directory holding one sub-directory per complex ID; the functions in this
    # notebook read `<id>/<id>_pocket_clean.pdb` and `<id>/<id>_ligand.mol2` from it.
    pdbfiles: str = "/home/paul/Desktop/BioHack-Project-Walkthrough/pdbind-refined-set/"
    # Directory of amino-acid Mol2 files (not referenced in the cells shown here).
    AA_mol2_files: str = "/home/paul/Desktop/BioHack-Project-Walkthrough/AA_mol2/"
    # Project home directory (not referenced in the cells shown here).
    home: str = '/home/paul/Desktop/BioHack-Project-Walkthrough/BioHack/'

# The general strategy is to create a graph from the binding pocket PDB, then create a graph from the Mol2 file and finally join the two graphs together.    

## Protein Graph
Each node in the protein graph corresponds to a single residue. The node features are the amino acid embeddings computed in the previous script. Two nodes are connected if the alpha carbons of the corresponding amino acids are within 12 angstroms of one another. The edge features are calculated from the pairwise distances between the N, CA, and C atoms of the two amino acids, resulting in 9 total distances. Each distance is expanded into 12 Gaussian radial basis functions, resulting in 108 (9x12) features. This vector is then concatenated with a one-hot encoded vector containing bond type information. None of the nodes are considered bonded in the protein, but this vector will become relevant once this graph is merged with the small molecule graph.
## Ligand Graph
Each node in the ligand graph corresponds to a single atom in the molecule. Two nodes are connected if the atoms are bonded to each other or are within 12 angstroms of each other. The edge features are calculated from the pairwise distance between the two atoms. To match the dimensionality of the protein edge features, this distance is duplicated 9 times and used to calculate the Gaussian radial basis functions. This vector is then concatenated with a one-hot encoded vector containing bond type information.  
**n.b. This is how the graphs for computing amino acid embeddings were constructed**
## Merging the two graphs
Merging the graphs simply involves adding connections between amino acids and ligand atoms. Again, two nodes are connected if they are within 12 angstroms of one another (measured from the alpha carbon to the ligand atom). Pairwise distances are calculated between the three amino acid backbone atoms (N, CA, C) and the ligand atom, resulting in 3 distances. These distances are duplicated three times and used to calculate the Gaussian radial basis functions, and the bond type information is concatenated onto the resulting vector.  
Additionally, a "Master" node is added that is connected to all nodes in the graph. This node is included to facilitate the flow of information throughout the graph.

In [12]:
def uniID2graph(uniID,map_distance, norm_map_distance = 12.0):
    """Build the residue-level graph of one binding pocket.

    Nodes are the residues listed by ``calc_contact_map``; node features come
    from ``AA_embeddings`` and the targets ``y`` from ``OHE_dict``, both keyed
    by residue name. An edge ``[i, j]`` (stored once, with ``j > i``) is added
    wherever ``contact_map[i, j] == 1``. Each edge carries 9 x 12 Gaussian
    radial-basis values (all N/CA/C pairings of the two residues, 12 centres
    at 1, 3, ..., 23) followed by the non-covalent bond-type vector
    ``bond_type_dict['nc']``.

    Args:
        uniID: Complex ID; ``<uniID>/<uniID>_pocket_clean.pdb`` is read from
            ``CFG.pdbfiles``.
        map_distance: Contact cut-off, forwarded to ``calc_contact_map``.
        norm_map_distance: Denominator of the Gaussian exponent.

    Returns:
        ``(graph, coord)`` where ``graph`` is a ``Data`` object with ``x``,
        ``edge_index``, ``edge_attr``, ``pos`` (CA coordinates) and ``y``, and
        ``coord`` is a list with one ``[N, CA, C]`` coordinate triple per node.
    """
    atom_names = ['N','CA','C']
    n_rbf = 12
    nc_bond = bond_type_dict['nc']

    contact_map, index, chain_info = calc_contact_map(uniID,map_distance)
    pdb_code = uniID
    pdb_filename = uniID+"_pocket_clean.pdb"
    structure = Bio.PDB.PDBParser(QUIET = True).get_structure(pdb_code, (CFG.pdbfiles +'/'+pdb_code+'/'+pdb_filename))
    model = structure[0]

    # Resolve every residue and its backbone coordinates once instead of
    # walking model -> chain -> residue -> atom again for each atom pair.
    # chain_info[i] is used as (chain id, residue id) for contact-map index i.
    backbone = {}
    node_feature = []
    coord = []
    y = []
    for i in index:
        residue = model[chain_info[i][0]][chain_info[i][1]]
        resname = residue.get_resname()
        node_feature.append(AA_embeddings[resname])
        backbone[i] = [residue[name].coord for name in atom_names]
        coord.append(backbone[i])
        y.append(OHE_dict[resname])

    # NOTE(review): edge indices are the raw contact-map indices while nodes
    # are stored in iteration order of `index`; this presumably relies on
    # `index` being 0..n-1 in order - confirm against calc_contact_map.
    edge_index = []
    edge_attr = []
    for i in index:
        for j in index:
            if contact_map[i,j] == 1 and j > i:
                edge_index.append([i,j])
                d = []
                for xyz_i in backbone[i]:
                    for xyz_j in backbone[j]:
                        diff_vector = xyz_i - xyz_j
                        dist = (np.sqrt(np.sum(diff_vector * diff_vector)))
                        d.extend(
                            np.exp((-1.0*(dist - 2.0*(l + 0.5))**2.0)/norm_map_distance)
                            for l in range(n_rbf)
                        )
                edge_attr.append(np.hstack((d,nc_bond)))

    # pos holds the CA coordinate (second entry of each backbone triple).
    new_prot_coord = torch.stack([torch.Tensor(triple[1]) for triple in coord])

    # reshape() keeps the tensors 2-D even when the pocket has no contacts, so
    # that later concatenation along the edge dimension still works.
    edge_pairs = np.array(edge_index, dtype=np.int64).reshape(-1, 2)
    edge_index = torch.from_numpy(np.ascontiguousarray(edge_pairs.T))
    feature_dim = len(atom_names)**2 * n_rbf + np.asarray(nc_bond).size
    edge_attr = torch.Tensor(np.array(edge_attr).reshape(-1, feature_dim))

    node_feature = torch.stack(node_feature)
    y = torch.stack(y)
    graph = Data(x = node_feature, edge_index = edge_index,edge_attr = edge_attr, pos = new_prot_coord)
    graph.y = y
    return graph, coord

In [13]:
def molecule2graph(filename,map_distance, norm_map_distance = 12.0):
    """Build the atom-level graph of one ligand from its Mol2 file.

    Every atom returned by ``read_mol2_bonds_and_atoms`` becomes a node whose
    features are ``atom2emb[atom type]`` and whose target ``y`` is a zero
    vector of length 20. Bonded atom pairs are always connected and carry
    their Mol2 bond type; non-bonded pairs are connected when closer than
    ``map_distance`` and carry the ``'nc'`` bond type. Edge features are the
    12 Gaussian radial-basis values of the inter-atomic distance, repeated 9
    times, followed by the bond-type vector.

    Args:
        filename: Complex ID; ``<filename>/<filename>_ligand.mol2`` is read
            from ``CFG.pdbfiles``.
        map_distance: Distance cut-off for non-bonded edges.
        norm_map_distance: Denominator of the Gaussian exponent.

    Returns:
        ``(graph, atom_coordinates)`` where ``graph`` is a ``Data`` object with
        ``x``, ``edge_index``, ``edge_attr`` and ``y``, and ``atom_coordinates``
        is the coordinate mapping returned by the Mol2 reader.
    """
    n_rbf = 12
    n_repeat = 9  # matches the 9 backbone-atom pairings of protein edges
    nc_bond = bond_type_dict['nc']

    mol2_file = CFG.pdbfiles+filename+'/'+filename+'_ligand.mol2'
    bonds, bond_types, atom_types, atom_coordinates = read_mol2_bonds_and_atoms(mol2_file)

    node_feature = [torch.Tensor(atom2emb[atom_types[atom]]) for atom in atom_types]
    y = [torch.zeros(20) for _ in atom_types]

    def edge_features(dist, bond_type):
        # 12 Gaussians centred at 1, 3, ..., 23, tiled to the protein edge width.
        d = [np.exp((-1.0*(dist - 2.0*(l + 0.5))**2.0)/norm_map_distance) for l in range(n_rbf)]
        return np.hstack((d,) * n_repeat + (bond_type,))

    # Index bonds by their unordered atom pair so each pair is a dictionary
    # lookup rather than a scan over the whole bond list.
    # Assumes each bond is an (atom1, atom2) pair of 1-based atom ids.
    bonds_by_pair = {}
    for k, bond in enumerate(bonds):
        bonds_by_pair.setdefault(frozenset((bond[0], bond[1])), []).append((k, bond))

    edge_index = []
    edge_attr = []
    n_atoms = len(atom_types)
    # NOTE(review): node order follows `atom_types` while edges use
    # `atom id - 1`; this presumably relies on ids being 1..n in order.
    for atom1 in range(1, n_atoms+1):
        for atom2 in range(atom1 + 1, n_atoms+1):
            matches = bonds_by_pair.get(frozenset((atom1, atom2)))
            if matches:
                for k, bond in matches:
                    # Keep the direction in which the bond is listed in the file.
                    edge_index.append([bond[0] - 1,bond[1] - 1])
                    dist = math.dist(atom_coordinates[bond[0]], atom_coordinates[bond[1]])
                    edge_attr.append(edge_features(dist, bond_type_dict[bond_types[k]]))
            else:
                dist = math.dist(atom_coordinates[atom1], atom_coordinates[atom2])
                if dist < map_distance:
                    edge_index.append([atom1 - 1,atom2 - 1])
                    edge_attr.append(edge_features(dist, nc_bond))

    # reshape() keeps both tensors 2-D for a ligand without any edge; going
    # through one ndarray avoids the slow list-of-arrays tensor construction.
    edge_pairs = np.array(edge_index, dtype=np.int64).reshape(-1, 2)
    edge_index = torch.from_numpy(np.ascontiguousarray(edge_pairs.T))
    feature_dim = n_repeat * n_rbf + np.asarray(nc_bond).size
    edge_attr = torch.Tensor(np.array(edge_attr).reshape(-1, feature_dim))

    node_feature = torch.stack(node_feature)
    y = torch.stack(y)
    graph = Data(x = node_feature, edge_index = edge_index,edge_attr = edge_attr)
    graph.y = y
    return graph, atom_coordinates

In [14]:
def id2fullgraph(filename, map_distance, norm_map_distance = 12.0):
    """Merge the pocket graph and the ligand graph of one complex.

    The node, target, and edge tensors of ``uniID2graph`` and
    ``molecule2graph`` are concatenated (ligand node indices are shifted by the
    number of protein nodes). A residue -> ligand-atom edge is added when the
    residue's CA atom is closer than ``map_distance`` to the ligand atom; its
    features are the Gaussian radial-basis values of the N/CA/C-to-atom
    distances (3 x 12), repeated three times, plus the ``'nc'`` bond type.
    Finally a zero-feature "master" node is appended and every other node is
    connected to it with all-zero distance features.

    Args:
        filename: Complex ID (sub-directory name inside ``CFG.pdbfiles``).
        map_distance: Distance cut-off used for all edge types.
        norm_map_distance: Denominator of the Gaussian exponent; forwarded to
            both sub-graph builders so all edge features share one width.

    Returns:
        A ``Data`` object with ``x``, ``edge_index``, ``edge_attr``, ``pos`` and
        ``y``.
    """
    n_rbf = 12
    n_dist = 9  # distance slots per edge (3 x 3 backbone pairings)
    nc_bond = bond_type_dict['nc']

    prot_graph, prot_coord = uniID2graph(filename,map_distance,norm_map_distance)
    prot_graph = prot_graph.to('cpu')
    mol_graph, mol_coord = molecule2graph(filename,map_distance,norm_map_distance)
    mol_graph = mol_graph.to('cpu')
    mol_coord = [mol_coord[i] for i in mol_coord]
    node_features = torch.cat((prot_graph.x,mol_graph.x),dim = 0)
    y = torch.cat((prot_graph.y,mol_graph.y),dim = 0)
    update_edge_index = mol_graph.edge_index + prot_graph.x.size()[0]
    edge_index = torch.cat((prot_graph.edge_index,update_edge_index), dim = 1)
    edge_attr = torch.cat((prot_graph.edge_attr,mol_graph.edge_attr), dim = 0)

    # Protein-ligand edges.
    new_edge_index = []
    new_edge_attr = []
    n_prot = len(prot_coord)
    for i, ligand_xyz in enumerate(mol_coord):
        for j, backbone in enumerate(prot_coord):
            dists = []
            for atom_xyz in backbone:
                dist_vec = ligand_xyz - atom_xyz
                dists.append(np.sqrt(np.sum(dist_vec*dist_vec)))
            # Entry 1 of the backbone triple is CA; it alone decides contact,
            # so the basis functions are only evaluated for accepted pairs.
            if (dists[1]/map_distance) < 1.0:
                d = [
                    np.exp((-1.0*(dist - 2.0*(l + 0.5))**2.0)/norm_map_distance)
                    for dist in dists
                    for l in range(n_rbf)
                ]
                new_edge_index.append([j,i + n_prot])
                new_edge_attr.append((np.hstack((d,d,d,nc_bond))))

    #Master_node
    node_features = torch.cat((node_features,torch.zeros(len(atom2emb['N'])).unsqueeze(0)),dim = 0)
    y = torch.cat((y,torch.zeros(20).unsqueeze(0)),dim = 0)

    # The feature width is stated explicitly instead of being read from the
    # last `d` of the loop above, which is undefined when that loop is empty.
    master = int(len(node_features)-1)
    master_attr = np.hstack((np.zeros(n_dist * n_rbf),nc_bond))
    for i in range(master):
        new_edge_index.append([i,master])
        new_edge_attr.append(master_attr)

    new_edge_pairs = np.array(new_edge_index, dtype=np.int64).reshape(-1, 2)
    new_edge_index = torch.from_numpy(np.ascontiguousarray(new_edge_pairs.T))
    feature_dim = n_dist * n_rbf + np.asarray(nc_bond).size
    # One ndarray first: building a tensor from a list of arrays is very slow.
    new_edge_attr = torch.Tensor(np.array(new_edge_attr).reshape(-1, feature_dim))

    #Include Ca or atom coordinates
    # NOTE(review): no coordinate is appended for the master node, so `pos`
    # has one row fewer than `x`; kept as in the original.
    prot_coord = [torch.Tensor(i[1]) for i in prot_coord]
    new_prot_coord = torch.stack(prot_coord)
    mol_coord = [torch.Tensor(i) for i in mol_coord]
    new_mol_coord = torch.stack(mol_coord)
    coords = torch.vstack((new_prot_coord, new_mol_coord))

    edge_index = torch.cat((edge_index,new_edge_index), dim = 1)
    edge_attr = torch.cat((edge_attr,new_edge_attr), dim = 0)

    graph = Data(x = node_features, edge_index = edge_index,edge_attr = edge_attr, pos = coords)
    graph.y = y

    return graph

In [15]:
# Build one merged protein-ligand graph per entry of the data directory,
# tagging each graph with its complex ID and reporting progress every 1000.
count = 0
graph_list_mn = []
for count, filename in enumerate(os.listdir(CFG.pdbfiles), start=1):
    graph = id2fullgraph(filename, 12.0)
    graph.label = filename
    graph_list_mn.append(graph)
    if not count % 1000:
        print(count)

1000
2000
3000
4000
5000


## Determining "Designable Indices"
The task of the graph neural network is to predict the identity of the amino acids that are deemed "in contact" with the small molecule. Our criterion for "in contact" is that any part of the residue is within 5 angstroms of any part of the ligand. Here we determine which amino acids should be the prediction target.

In [16]:
def createMask(filename,graph,num_masked):
    """Return a boolean node mask with ``num_masked`` random residues selected.

    The residues of ``<filename>_pocket_clean.pdb`` are counted, ``num_masked``
    of those positions are set to True at random (via ``random.shuffle``, so
    the global ``random`` seed controls the result), and the mask is padded
    with False for all remaining rows of ``graph.x``.

    Args:
        filename: Complex ID (sub-directory name inside ``CFG.pdbfiles``).
        graph: Graph whose ``x`` row count fixes the mask length.
        num_masked: Number of residues to mark True.

    Returns:
        1-D boolean ``numpy`` array.

    Raises:
        ValueError: If ``num_masked`` is negative or exceeds the residue count.
    """
    pdb_code = filename
    pdb_filename = pdb_code+"_pocket_clean.pdb"
    structure = Bio.PDB.PDBParser(QUIET = True).get_structure(pdb_code, (CFG.pdbfiles +'/'+pdb_code+'/'+pdb_filename))
    model = structure[0]
    count = sum(1 for chain in model for _ in chain)

    if not 0 <= num_masked <= count:
        raise ValueError(
            f"num_masked must be between 0 and {count} for {filename}, got {num_masked}"
        )

    # Explicit bool dtype: stacking with an empty list (num_masked == 0 or
    # num_masked == count) would otherwise promote the mask to float64.
    protein_mask = np.zeros(count, dtype=bool)
    protein_mask[count - num_masked:] = True
    random.shuffle(protein_mask)

    # Every remaining node of the graph is never masked.
    molecule_mask = np.zeros(max(graph.x.size()[0] - count, 0), dtype=bool)

    return np.concatenate((protein_mask,molecule_mask))

def design_indicies(filename):
    """Return the positions of pocket residues that lie within 5 of the ligand.

    PyMOL selects every atom within 5 of the ligand object (excluding the
    ligand itself); the (chain, residue number) pairs of those atoms are then
    matched against the residues of the pocket PDB as parsed by Bio.PDB.
    Positions count residues in chain-then-residue iteration order, from 0.
    Atoms whose PyMOL chain identifier is empty are ignored.

    Note: clears the PyMOL session (``cmd.delete('all')``) and relies on PyMOL
    naming the loaded ligand object ``<filename>_ligand``.
    """
    pocket_pdb = filename +'_pocket_clean.pdb'
    ligand_mol2 = filename+ '_ligand.mol2'
    ligand_obj = filename+ '_ligand'
    pocket_path = CFG.pdbfiles +'/'+filename+'/'+pocket_pdb

    cmd.delete('all')
    cmd.load(pocket_path)
    cmd.load(CFG.pdbfiles +'/'+filename+'/'+ligand_mol2)
    cmd.create('interacting_resi', '(all within 5 of '+ ligand_obj+') and not '+ligand_obj)

    # One pass per attribute; both lists are filled in the same atom order.
    chain_list = []
    resi_list = []
    namespace = {'chain_list':chain_list,'resi_list':resi_list}
    cmd.iterate('interacting_resi', 'chain_list.append(chain)', space=namespace)
    cmd.iterate('interacting_resi', 'resi_list.append(resi)', space=namespace)

    # Unique (chain, residue number) pairs of the contacting atoms.
    contacts = {
        (chain_id, resi_id)
        for chain_id, resi_id in zip(chain_list, resi_list)
        if chain_id
    }

    structure = Bio.PDB.PDBParser(QUIET = True).get_structure(filename, pocket_path)
    residue_keys = [
        (str(chain.id), str(residue.id[1]))
        for chain in structure[0]
        for residue in chain
    ]

    return [position for position, key in enumerate(residue_keys) if key in contacts]

In [17]:
# Attach the list of ligand-contacting residue positions to every graph,
# reporting progress every 1000 graphs.
count = 0
for count, entry in enumerate(graph_list_mn, start=1):
    entry.designable_indicies = design_indicies(entry.label)
    if not count % 1000:
        print(count)

1000
2000
3000
4000
5000


In [18]:
# Serialize the full list of annotated graphs to the current working directory.
torch.save(graph_list_mn,'binding_pocket_graphs.pt')