In [9]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import obonet
import random
import torch
import math
from Bio import SeqIO
import Bio.PDB
import urllib.request
import py3Dmol
import pylab
import pickle as pickle
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GATConv
from torch_geometric.nn import GATv2Conv
from torch_geometric.nn import GENConv
from torch_geometric.nn.models import MLP
from torch_geometric.data import Data
from torch_geometric.nn.pool import SAGPooling
from torch_geometric.nn.aggr import MeanAggregation
import matplotlib.pyplot as plt
import os
from Bio import PDB
from rdkit import Chem

In [205]:
class CFG:
    """Filesystem locations of the datasets read by this notebook.

    Both attributes are directory prefixes.  Other cells join them to file
    names by plain string concatenation (e.g. ``CFG.AA_mol2_files + filename``),
    so each value is normalised to end with a path separator.

    The defaults are the original hard-coded locations.  To run the notebook
    on another machine without editing code, set the ``BIOHACK_PDB_DIR`` and/or
    ``BIOHACK_AA_MOL2_DIR`` environment variables before executing this cell.
    """

    # Dataset root: one sub-directory per entry id, holding
    # "<id>_pocket_clean.pdb" and "<id>_ligand.mol2".
    pdbfiles: str = os.path.join(
        os.environ.get("BIOHACK_PDB_DIR") or "/home/paul/BioHack/pdbind-refined-set/", "")
    # Directory whose entries are the mol2 files turned into the AA graphs.
    AA_mol2_files: str = os.path.join(
        os.environ.get("BIOHACK_AA_MOL2_DIR") or "/home/paul/BioHack/AA_mol2/", "")

In [213]:
def get_atom_symbol(atomic_number):
    """Return the element symbol RDKit's periodic table reports for ``atomic_number``."""
    periodic_table = Chem.GetPeriodicTable()
    return periodic_table.GetElementSymbol(atomic_number)

def remove_hetatm(input_pdb_file, output_pdb_file):
    """Copy a PDB file while dropping every line that starts with ``HETATM``.

    Args:
        input_pdb_file: Path of the PDB file to read.
        output_pdb_file: Path of the file to (over)write with the kept lines.

    All other lines are written unchanged and in their original order.
    """
    with open(input_pdb_file, 'r') as source, open(output_pdb_file, 'w') as target:
        target.writelines(
            record for record in source if not record.startswith('HETATM')
        )
            
def get_atom_types_from_sdf(sdf_file):
    """Return the sorted list of distinct element symbols found in an SDF file.

    Every molecule yielded by ``Chem.SDMolSupplier`` is inspected; entries the
    supplier returns as ``None`` are skipped.
    """
    symbols = set()
    for molecule in Chem.SDMolSupplier(sdf_file):
        if molecule is None:
            continue
        for atom in molecule.GetAtoms():
            symbols.add(atom.GetSymbol())
    return sorted(symbols)

def get_atom_types_from_mol2_split(mol2_file):
    """Return the sorted, distinct atom types of a mol2 file, truncated at the first '.'.

    Rows between the ``@<TRIPOS>ATOM`` header and the ``@<TRIPOS>BOND`` header
    are split on whitespace and the sixth field is taken as the atom type;
    only the text before the first ``.`` is kept (``"C.3"`` -> ``"C"``).

    Args:
        mol2_file: Path of the mol2 file to read.

    Returns:
        Sorted list of the distinct truncated atom types.
    """
    atom_types_split = set()

    with open(mol2_file, 'r') as mol2:
        reading_atoms = False
        for line in mol2:
            record = line.strip()
            if record == '@<TRIPOS>ATOM':
                reading_atoms = True
                continue
            elif record == '@<TRIPOS>BOND':
                break

            if reading_atoms:
                parts = line.split()
                # The type is parts[5], so six fields are needed.  The previous
                # ">= 5" guard let five-field rows through and raised IndexError.
                if len(parts) >= 6:
                    atom_types_split.add(parts[5].split('.')[0])

    return sorted(atom_types_split)

def get_atom_types_from_mol2(mol2_file):
    """Return the sorted, distinct atom types listed in a mol2 file.

    Rows between the ``@<TRIPOS>ATOM`` header and the ``@<TRIPOS>BOND`` header
    are split on whitespace and the sixth field is collected verbatim
    (e.g. ``"C.3"``).

    Args:
        mol2_file: Path of the mol2 file to read.

    Returns:
        Sorted list of the distinct atom type strings.
    """
    atom_types = set()

    with open(mol2_file, 'r') as mol2:
        reading_atoms = False
        for line in mol2:
            record = line.strip()
            if record == '@<TRIPOS>ATOM':
                reading_atoms = True
                continue
            elif record == '@<TRIPOS>BOND':
                break

            if reading_atoms:
                parts = line.split()
                # The type is parts[5], so six fields are needed.  The previous
                # ">= 5" guard let five-field rows through and raised IndexError.
                if len(parts) >= 6:
                    atom_types.add(parts[5])

    return sorted(atom_types)

def get_atom_list_from_mol2_split(mol2_file):
    """Return the atom types of a mol2 file in file order, truncated at the first '.'.

    Rows between the ``@<TRIPOS>ATOM`` header and the ``@<TRIPOS>BOND`` header
    are split on whitespace and the sixth field is taken as the atom type;
    only the text before the first ``.`` is kept.  Duplicates are preserved,
    so the result has one entry per parsed atom row.

    Args:
        mol2_file: Path of the mol2 file to read.

    Returns:
        List of truncated atom type strings, one per atom row.
    """
    atom_list = []
    with open(mol2_file, 'r') as mol2:
        reading_atoms = False
        for line in mol2:
            record = line.strip()
            if record == '@<TRIPOS>ATOM':
                reading_atoms = True
                continue
            elif record == '@<TRIPOS>BOND':
                break

            if reading_atoms:
                parts = line.split()
                # The type is parts[5], so six fields are needed.  The previous
                # ">= 5" guard let five-field rows through and raised IndexError.
                if len(parts) >= 6:
                    atom_list.append(parts[5].split('.')[0])

    return atom_list

def get_atom_list_from_mol2(mol2_file):
    """Return the atom types of a mol2 file in file order.

    Rows between the ``@<TRIPOS>ATOM`` header and the ``@<TRIPOS>BOND`` header
    are split on whitespace and the sixth field is collected verbatim.
    Duplicates are preserved, so the result has one entry per parsed atom row.

    Args:
        mol2_file: Path of the mol2 file to read.

    Returns:
        List of atom type strings (e.g. ``"C.3"``), one per atom row.
    """
    atoms = []
    with open(mol2_file, 'r') as mol2:
        reading_atoms = False
        for line in mol2:
            record = line.strip()
            if record == '@<TRIPOS>ATOM':
                reading_atoms = True
                continue
            elif record == '@<TRIPOS>BOND':
                break

            if reading_atoms:
                parts = line.split()
                # The type is parts[5], so six fields are needed.  The previous
                # ">= 5" guard let five-field rows through and raised IndexError.
                if len(parts) >= 6:
                    atoms.append(parts[5])

    return atoms

def get_bond_types_from_mol2(mol2_file):
    """Return the sorted, distinct bond types listed in a mol2 file.

    Reading starts after the ``@<TRIPOS>BOND`` header and stops at the next
    line beginning with ``@<TRIPOS>`` that is not itself a BOND header.  Each
    row is split on whitespace and its fourth field is taken as the bond type;
    rows with fewer than four fields are ignored.
    """
    seen = set()
    with open(mol2_file, 'r') as handle:
        rows = iter(handle)

        # Phase 1: discard everything up to and including the bond header.
        for raw in rows:
            if raw.strip() == '@<TRIPOS>BOND':
                break

        # Phase 2: consume bond rows until another section begins.
        for raw in rows:
            marker = raw.strip()
            if marker == '@<TRIPOS>BOND':
                continue
            if marker.startswith('@<TRIPOS>'):
                break
            fields = raw.split()
            if len(fields) >= 4:
                seen.add(fields[3])

    return sorted(seen)

def read_mol2_bonds(mol2_file):
    """Read the bond table of a mol2 file.

    Reading starts after the ``@<TRIPOS>BOND`` header and stops at the next
    line beginning with ``@<TRIPOS>`` that is not itself a BOND header.  Rows
    with fewer than four whitespace-separated fields are ignored.

    Returns:
        ``(bonds, bond_types)``: ``bonds`` is a list of
        ``(atom1_index, atom2_index)`` integer pairs taken from the second and
        third fields, and ``bond_types`` is the parallel list of fourth-field
        strings.
    """
    pairs = []
    kinds = []
    with open(mol2_file, 'r') as handle:
        rows = iter(handle)

        # Phase 1: discard everything up to and including the bond header.
        for raw in rows:
            if raw.strip() == '@<TRIPOS>BOND':
                break

        # Phase 2: consume bond rows until another section begins.
        for raw in rows:
            marker = raw.strip()
            if marker == '@<TRIPOS>BOND':
                continue
            if marker.startswith('@<TRIPOS>'):
                break
            fields = raw.split()
            if len(fields) < 4:
                continue
            first, second = int(fields[1]), int(fields[2])
            pairs.append((first, second))
            kinds.append(fields[3])

    return pairs, kinds

def calc_residue_dist(residue_one, residue_two) :
    """Euclidean distance between the "CA" atom coordinates of two residues."""
    ca_one = residue_one["CA"]
    ca_two = residue_two["CA"]
    offset = ca_one.coord - ca_two.coord
    return np.sqrt(np.square(offset).sum())

def calc_dist_matrix(chain_one, chain_two) :
    """Matrix of ``calc_residue_dist`` values for every residue pair of two chains.

    Row ``r`` / column ``c`` holds the distance between the ``r``-th residue
    of ``chain_one`` and the ``c``-th residue of ``chain_two``.
    """
    distances = np.zeros((len(chain_one), len(chain_two)), float)
    for row, residue_one in enumerate(chain_one) :
        # Fill one full row at a time.
        distances[row, :] = [
            calc_residue_dist(residue_one, residue_two) for residue_two in chain_two
        ]
    return distances

def calc_contact_map(uniID,map_distance):
    """Compute the residue contact map of ``<uniID>_pocket_clean.pdb``.

    The file is read from ``CFG.pdbfiles/<uniID>/`` and only the first model
    of the parsed structure is used.

    Args:
        uniID: Entry id; names both the sub-directory and the file.
        map_distance: Threshold; a pair is a contact when its
            ``calc_dist_matrix`` value is strictly below it.

    Returns:
        ``(contact_map, index, chain_info)``: the boolean all-against-all
        matrix over the residues of every chain (chains in model order),
        the list ``0..n_residues-1``, and for each residue the pair
        ``[chain id, residue id]``.
    """
    pdb_path = CFG.pdbfiles + '/' + uniID + '/' + uniID + "_pocket_clean.pdb"
    model = Bio.PDB.PDBParser(QUIET = True).get_structure(uniID, pdb_path)[0]

    index = []
    chain_info = []
    row_blocks = []
    for chain1 in model:
        for resi in chain1:
            index.append(len(index))
            chain_info.append([chain1.id, resi.id])
        # One horizontal strip: this chain against every chain of the model.
        strip = [calc_dist_matrix(model[chain1.id], model[chain2.id]) for chain2 in model]
        row_blocks.append(np.hstack(strip))

    # Stack the per-chain strips into the full square distance matrix.
    top_matrix = np.vstack(row_blocks)
    contact_map = top_matrix < map_distance
    return contact_map, index, chain_info


def one_hot_encode_single_res(res):
    """Return the 20-dimensional one-hot encoding of a three-letter residue name.

    Args:
        res: Residue name, one of the 20 upper-case three-letter codes listed
            in ``residue_order`` below (e.g. ``"GLY"``).

    Returns:
        A new 1-D float tensor of length 20 with a single 1.0 at the position
        of ``res`` in ``residue_order``.

    Raises:
        ValueError: If ``res`` is not one of the 20 supported names.  The
            previous check compared the *letters* of the name against a
            one-letter alphabet, so names such as ``"HOH"`` or ``"MSE"``
            passed validation and then failed with a bare ``KeyError``.
    """
    # Position in this tuple is the index of the hot element.
    residue_order = ('GLY', 'ALA', 'VAL', 'CYS', 'PRO', 'LEU', 'ILE', 'MET', 'TRP', 'PHE',
                     'LYS', 'ARG', 'HIS', 'SER', 'THR', 'TYR', 'ASN', 'GLN', 'ASP', 'GLU')

    if res not in residue_order:
        raise ValueError(f"Sequence has broken AA: {res}")

    # Build only the requested vector rather than 20 tensors per call.
    encoding = torch.zeros(len(residue_order))
    encoding[residue_order.index(res)] = 1.0
    return encoding

# Residue name -> 20-dimensional one-hot float tensor; the hot position is the
# residue's rank in the tuple below.
AA_dictionary = {
    name: torch.Tensor([1.0 if slot == position else 0.0 for slot in range(20)])
    for position, name in enumerate((
        'GLY', 'ALA', 'VAL', 'CYS', 'PRO', 'LEU', 'ILE', 'MET', 'TRP', 'PHE',
        'LYS', 'ARG', 'HIS', 'SER', 'THR', 'TYR', 'ASN', 'GLN', 'ASP', 'GLU',
    ))
}

def uniID2graph(uniID,map_distance):
    """Build the residue-level contact graph of ``<uniID>_pocket_clean.pdb``.

    Nodes are the residues returned by ``calc_contact_map`` (same order), with
    the one-hot residue encoding as features.  A directed edge ``i -> j`` is
    added for every ``True`` entry of the contact map, which includes ``i == j``
    whenever the map marks the diagonal.  Each edge attribute row is
    ``[CA distance / map_distance, *bond_type_dict['nc']]``.

    ``bond_type_dict`` is a notebook-level global (loaded from
    ``bond_type_dict.pkl`` in another cell) and must exist in the kernel
    before this function is called.

    Args:
        uniID: Entry id; names the sub-directory and file under ``CFG.pdbfiles``.
        map_distance: Contact threshold, also used to scale edge distances.

    Returns:
        ``(graph, coord)``: a ``torch_geometric.data.Data`` object and the list
        of per-residue "CA" coordinate arrays in node order.
    """
    contact_map, index, chain_info = calc_contact_map(uniID,map_distance)
    pdb_filename = uniID+"_pocket_clean.pdb"
    structure = Bio.PDB.PDBParser(QUIET = True).get_structure(uniID, (CFG.pdbfiles +'/'+uniID+'/'+pdb_filename))
    model = structure[0]

    # Resolve every node's residue once instead of re-indexing the model for
    # each (i, j) pair in the double loop below.
    residues = [model[chain_id][res_id] for chain_id, res_id in chain_info]
    node_feature = [one_hot_encode_single_res(residue.get_resname()) for residue in residues]
    coord = [residue['CA'].coord for residue in residues]
    no_bond = bond_type_dict['nc']

    edge_index = []
    edge_attr = []
    for i in index:
        for j in index:
            if contact_map[i,j]:
                diff_vector = coord[i] - coord[j]
                dist = (np.sqrt(np.sum(diff_vector * diff_vector))/map_distance)
                edge_index.append([i,j])
                edge_attr.append(np.hstack((dist,no_bond)))

    # Build the index tensor directly as int64 (no float32 round trip) and keep
    # the [2, num_edges] layout even when there are no edges.
    edge_pairs = np.array(edge_index, dtype=np.int64).reshape(-1, 2)
    edge_index = torch.from_numpy(np.ascontiguousarray(edge_pairs.T))
    # Convert to a single ndarray first; building a tensor from a list of
    # ndarrays is the slow path PyTorch warns about.
    edge_attr = torch.Tensor(np.array(edge_attr))
    node_feature = torch.stack(node_feature)
    graph = Data(x = node_feature, edge_index = edge_index,edge_attr = edge_attr)
    return graph, coord

def read_mol2_bonds_and_atoms(mol2_file):
    """Parse the ATOM and BOND sections of a mol2 file.

    Parsing stops at the first ``@<TRIPOS>SUBSTRUCTURE`` header.  Any other
    ``@<TRIPOS>`` header ends the section being read, so rows of unrelated
    sections are ignored.

    Args:
        mol2_file: Path of the mol2 file to read.

    Returns:
        ``(bonds, bond_types, atom_types, atom_coordinates)``:

        * ``bonds`` -- list of ``(atom1_index, atom2_index)`` integer pairs
          (second and third field of each bond row);
        * ``bond_types`` -- parallel list of fourth-field strings;
        * ``atom_types`` -- dict: atom id (first field) -> sixth field
          truncated at the first ``.``;
        * ``atom_coordinates`` -- dict: atom id -> ``(x, y, z)`` floats
          (third to fifth field).

        Bond rows need at least four fields and atom rows at least six;
        shorter rows are skipped.
    """
    bonds = []
    bond_types = []
    atom_types = {}
    atom_coordinates = {}

    # Exactly one section is active at a time.  The previous pair of boolean
    # flags left "reading atoms" switched on after the BOND header, so rows of
    # a later section with six or more fields were parsed as atoms.
    section = None
    with open(mol2_file, 'r') as mol2:
        for line in mol2:
            record = line.strip()
            if record.startswith('@<TRIPOS>'):
                if record.startswith('@<TRIPOS>SUBSTRUCTURE'):
                    break
                if record == '@<TRIPOS>BOND':
                    section = 'bond'
                elif record == '@<TRIPOS>ATOM':
                    section = 'atom'
                else:
                    section = None
                continue

            parts = line.split()
            if section == 'bond':
                if len(parts) >= 4:
                    bonds.append((int(parts[1]), int(parts[2])))
                    bond_types.append(parts[3])
            elif section == 'atom':
                if len(parts) >= 6:
                    atom_index = int(parts[0])
                    x, y, z = float(parts[2]), float(parts[3]), float(parts[4])
                    atom_types[atom_index] = parts[5].split('.')[0]
                    atom_coordinates[atom_index] = (x, y, z)

    return bonds, bond_types, atom_types, atom_coordinates

def molecule2graph(filename,map_distance):
    """Build the bond graph of the ligand ``<filename>_ligand.mol2``.

    The file is read from ``CFG.pdbfiles + filename + '/'``.  Every atom
    becomes a node whose feature vector is 20 zeros (placeholder features;
    width 20 matches the residue one-hot vectors so the two node sets can be
    concatenated).  Every bond row becomes one directed edge
    ``atom1 - 1 -> atom2 - 1`` with attribute row
    ``[bond length / map_distance, *bond_type_dict[bond type]]``.

    ``id2fullgraph`` calls this function under the name ``molecule2graph``;
    it was previously defined only as ``molecule2graph_AA``, a name that a
    later cell rebinds to a different function.

    ``bond_type_dict`` is a notebook-level global and must exist in the
    kernel before this function is called.
    NOTE(review): the ``- 1`` shift assumes atom ids run 1..N without gaps --
    confirm for the input files.

    Args:
        filename: Entry id; names the sub-directory and the file prefix.
        map_distance: Scale applied to bond lengths.

    Returns:
        ``(graph, atom_coordinates)``: a ``torch_geometric.data.Data`` object
        and the dict atom id -> ``(x, y, z)`` from the mol2 parser.
    """
    mol2_file = CFG.pdbfiles+filename+'/'+filename+'_ligand.mol2'
    bonds, bond_types, atom_types, atom_coordinates = read_mol2_bonds_and_atoms(mol2_file)

    node_feature = [torch.zeros(20) for _ in atom_types]

    edge_index = []
    edge_attr = []
    for (first, second), kind in zip(bonds, bond_types):
        coord1 = np.array(atom_coordinates[first])
        coord2 = np.array(atom_coordinates[second])
        dist = [np.sqrt(np.sum((coord1 - coord2)*(coord1 - coord2)))/map_distance]
        edge_index.append([first - 1, second - 1])
        edge_attr.append(np.hstack((dist, bond_type_dict[kind])))

    # Build the index tensor directly as int64 (no float32 round trip) and keep
    # the [2, num_edges] layout even for a molecule without bonds.
    edge_pairs = np.array(edge_index, dtype=np.int64).reshape(-1, 2)
    edge_index = torch.from_numpy(np.ascontiguousarray(edge_pairs.T))
    # Convert to a single ndarray first; building a tensor from a list of
    # ndarrays is the slow path PyTorch warns about.
    edge_attr = torch.Tensor(np.array(edge_attr))
    node_feature = torch.stack(node_feature)
    graph = Data(x = node_feature, edge_index = edge_index,edge_attr = edge_attr)

    return graph, atom_coordinates


# Backward-compatible alias: this cell used to bind the function to this name.
molecule2graph_AA = molecule2graph

def id2fullgraph(filename, map_distance):
    """Merge the protein pocket graph and the ligand graph of one entry.

    Protein nodes come first, ligand nodes are appended after them, and the
    ligand edge indices are shifted by the number of protein nodes.  In
    addition, one directed edge ``residue j -> ligand atom i`` is added for
    every pair whose distance divided by ``map_distance`` is below 1.0, with
    attribute row ``[scaled distance, *bond_type_dict['nc']]``.

    Relies on notebook-level names: ``uniID2graph``, ``molecule2graph`` and
    ``bond_type_dict`` must all be defined in the kernel before calling.

    Args:
        filename: Entry id passed on to both graph builders.
        map_distance: Distance scale / cutoff shared by all edge kinds.

    Returns:
        A ``torch_geometric.data.Data`` object for the combined graph.
    """
    prot_graph, prot_coord = uniID2graph(filename,map_distance)
    mol_graph, mol_coord = molecule2graph(filename,map_distance)
    # The ligand builder returns a dict keyed by atom id; keep insertion order.
    mol_coord = list(mol_coord.values())

    # Fix: the offset used to read the undefined name `prot`; it has to be the
    # node count of the protein graph built above.
    n_prot = prot_graph.x.size()[0]
    node_features = torch.cat((prot_graph.x,mol_graph.x),dim = 0)
    update_edge_index = mol_graph.edge_index + n_prot
    edge_index = torch.cat((prot_graph.edge_index,update_edge_index), dim = 1)
    edge_attr = torch.cat((prot_graph.edge_attr,mol_graph.edge_attr), dim = 0)

    new_edge_index = []
    new_edge_attr = []
    for i, atom_xyz in enumerate(mol_coord):
        for j, residue_xyz in enumerate(prot_coord):
            dist_vec = atom_xyz - residue_xyz
            dist = np.sqrt(np.sum(dist_vec*dist_vec))/map_distance
            if dist < 1.0:
                new_edge_index.append([j,i + len(prot_coord)])
                new_edge_attr.append((np.hstack(([dist],bond_type_dict['nc']))))

    # Only concatenate when cross edges exist, so no shapeless empty tensor is
    # ever mixed into the [2, E] / [E, F] tensors.
    if new_edge_index:
        cross_pairs = np.array(new_edge_index, dtype=np.int64)
        cross_index = torch.from_numpy(np.ascontiguousarray(cross_pairs.T))
        cross_attr = torch.Tensor(np.array(new_edge_attr))
        edge_index = torch.cat((edge_index,cross_index), dim = 1)
        edge_attr = torch.cat((edge_attr,cross_attr), dim = 0)

    graph = Data(x = node_features, edge_index = edge_index,edge_attr = edge_attr)

    return graph

In [235]:
def read_mol2_bonds_and_atoms_AA(mol2_file):
    """Parse the ATOM and BOND sections of a mol2 file, including atom names.

    Parsing stops at the first ``@<TRIPOS>SUBSTRUCTURE`` header.  Any other
    ``@<TRIPOS>`` header ends the section being read, so rows of unrelated
    sections are ignored.

    Args:
        mol2_file: Path of the mol2 file to read.

    Returns:
        ``(bonds, bond_types, atom_types, atom_coordinates, atom_names)``:

        * ``bonds`` -- list of ``(atom1_index, atom2_index)`` integer pairs
          (second and third field of each bond row);
        * ``bond_types`` -- parallel list of fourth-field strings;
        * ``atom_types`` -- dict: atom id (first field) -> sixth field
          truncated at the first ``.``;
        * ``atom_coordinates`` -- dict: atom id -> ``(x, y, z)`` floats
          (third to fifth field);
        * ``atom_names`` -- dict: atom id -> second field.

        Bond rows need at least four fields and atom rows at least six;
        shorter rows are skipped.
    """
    bonds = []
    bond_types = []
    atom_names = {}
    atom_types = {}
    atom_coordinates = {}

    # Exactly one section is active at a time.  The previous pair of boolean
    # flags left "reading atoms" switched on after the BOND header, so rows of
    # a later section with six or more fields were parsed as atoms.
    section = None
    with open(mol2_file, 'r') as mol2:
        for line in mol2:
            record = line.strip()
            if record.startswith('@<TRIPOS>'):
                if record.startswith('@<TRIPOS>SUBSTRUCTURE'):
                    break
                if record == '@<TRIPOS>BOND':
                    section = 'bond'
                elif record == '@<TRIPOS>ATOM':
                    section = 'atom'
                else:
                    section = None
                continue

            parts = line.split()
            if section == 'bond':
                if len(parts) >= 4:
                    bonds.append((int(parts[1]), int(parts[2])))
                    bond_types.append(parts[3])
            elif section == 'atom':
                if len(parts) >= 6:
                    atom_index = int(parts[0])
                    x, y, z = float(parts[2]), float(parts[3]), float(parts[4])
                    atom_types[atom_index] = parts[5].split('.')[0]
                    atom_names[atom_index] = parts[1]
                    atom_coordinates[atom_index] = (x, y, z)

    return bonds, bond_types, atom_types, atom_coordinates, atom_names

def molecule2graph_AA(filename,map_distance):
    """Build the bond graph of a mol2 file and attach a master node.

    Every atom becomes a node whose features are ``atom2emb[element]``.
    Every bond row becomes one directed edge ``atom1 - 1 -> atom2 - 1`` with
    attribute row ``[bond length / map_distance, *bond_type_dict[bond type]]``.
    A final all-zero "master" node is appended, and every atom node gets one
    directed edge to it with attribute row ``[1.0, *bond_type_dict['1']]``.

    ``atom2emb`` and ``bond_type_dict`` are notebook-level globals and must
    exist in the kernel before this function is called.
    NOTE(review): the ``- 1`` shift assumes atom ids run 1..N without gaps --
    confirm for the input files.

    Args:
        filename: Full path of the mol2 file (used as given).
        map_distance: Scale applied to bond lengths.

    Returns:
        ``(graph, atom_coordinates)``: a ``torch_geometric.data.Data`` object
        and the dict atom id -> ``(x, y, z)`` from the mol2 parser.
    """
    bonds, bond_types, atom_types, atom_coordinates = read_mol2_bonds_and_atoms(filename)

    node_feature = [torch.Tensor(atom2emb[atom_types[atom]]) for atom in atom_types]

    edge_index = []
    edge_attr = []
    for (first, second), kind in zip(bonds, bond_types):
        coord1 = np.array(atom_coordinates[first])
        coord2 = np.array(atom_coordinates[second])
        dist = [np.sqrt(np.sum((coord1 - coord2)*(coord1 - coord2)))/map_distance]
        edge_index.append([first - 1, second - 1])
        edge_attr.append(np.hstack((dist, bond_type_dict[kind])))

    # Master node: zero features with the same width as an atom embedding.
    node_feature.append(torch.zeros(len(atom2emb['N'])))
    master = len(node_feature) - 1
    for i in range(master):
        edge_index.append([i, master])
        edge_attr.append(np.hstack((1.0, bond_type_dict['1'])))

    # Build the index tensor directly as int64 (no float32 round trip) and keep
    # the [2, num_edges] layout even when there are no edges.
    edge_pairs = np.array(edge_index, dtype=np.int64).reshape(-1, 2)
    edge_index = torch.from_numpy(np.ascontiguousarray(edge_pairs.T))
    # Convert to a single ndarray first; building a tensor from a list of
    # ndarrays is the slow path PyTorch warns about.
    edge_attr = torch.Tensor(np.array(edge_attr))
    node_feature = torch.stack(node_feature)
    graph = Data(x = node_feature, edge_index = edge_index,edge_attr = edge_attr)

    return graph, atom_coordinates

In [236]:
# Build one graph per file in CFG.AA_mol2_files, stored as [name, graph] where
# name is the file name up to its first '.'.
graph_list = []
# sorted(): os.listdir returns entries in arbitrary order, which would make the
# saved list differ between machines and runs.
for filename in sorted(os.listdir(CFG.AA_mol2_files)):
    mol2_file = CFG.AA_mol2_files + filename
    # Skip hidden entries and sub-directories (e.g. ".ipynb_checkpoints"),
    # which cannot be parsed as mol2 files.
    if filename.startswith('.') or not os.path.isfile(mol2_file):
        continue
    graph = molecule2graph_AA(mol2_file,12.0)[0]
    graph_list.append([filename.split('.')[0],graph])

In [237]:
# Persist the [name, graph] pairs in graph_list to AA_graphs.pt in the current
# working directory.
torch.save(graph_list,'AA_graphs.pt')

In [211]:
# Exploratory check: print 'yes' for every atom whose name is exactly 'N', 'C'
# or 'O', and 'no' otherwise.
# NOTE(review): `atom_names` is not assigned at notebook level in the visible
# cells (it is only a local inside read_mol2_bonds_and_atoms_AA), so this cell
# depends on a value already present in the kernel -- confirm where it came from.
backbone = ['N','C','O']
for atom in atom_names:
    if atom_names[atom] in backbone:
        print('yes')
    else:
        print('no')

yes
no
yes
yes
no
no
no
no
no
no
no
no
no
no
no
no
no
no
no


In [6]:
# Load the ele2emb mapping from ele2emb.pkl in the working directory.
# NOTE: pickle.load executes arbitrary code from the file; only load pickles
# from a trusted source.
with open('ele2emb.pkl', 'rb') as f:
    ele2emb = pickle.load(f)

In [7]:
# Re-key ele2emb by element symbol: each key of ele2emb is passed through
# get_atom_symbol and the value is carried over unchanged.
atom2emb = {}
for i in ele2emb:
    atom2emb[get_atom_symbol(i)] = ele2emb[i]

In [11]:
# Write the symbol-keyed atom2emb mapping to atom2emb.pkl in the working directory.
with open('atom2emb.pkl', 'wb') as f:
    pickle.dump(atom2emb, f)

In [15]:
# Load bond_type_dict, the global lookup used by the graph builders
# (indexed there with mol2 bond type strings and with 'nc').
# NOTE: pickle.load executes arbitrary code from the file; only load pickles
# from a trusted source.
with open('bond_type_dict.pkl', 'rb') as f:
    bond_type_dict = pickle.load(f)

In [16]:
# Inspect the encoding stored for bond type '1'.
bond_type_dict['1']

array([1., 0., 0., 0., 0., 0.])

In [38]:
# Smoke test of the mol2 parser: parse the ligand of the first directory entry
# returned by os.listdir, print the four results, then stop.
# Note: `filename` stays bound in the kernel after the loop; later cells that
# use a bare `filename` may be relying on it -- NOTE(review): confirm.
for filename in os.listdir(CFG.pdbfiles):
    mol2_file = CFG.pdbfiles+filename+'/'+filename+'_ligand.mol2'
    bonds, bond_types, atom_types, coordinates = read_mol2_bonds_and_atoms(mol2_file)
    print(atom_types)
    print(bonds)
    print(bond_types)
    print(coordinates)
    break  # only the first entry is inspected

{1: 'C', 2: 'C', 3: 'C', 4: 'C', 5: 'C', 6: 'C', 7: 'C', 8: 'C', 9: 'N', 10: 'S', 11: 'O', 12: 'O', 13: 'O', 14: 'O', 15: 'C', 16: 'C', 17: 'C', 18: 'C', 19: 'C', 20: 'C', 21: 'C', 22: 'C', 23: 'N', 24: 'O', 25: 'O', 26: 'O', 27: 'O', 28: 'O', 29: 'H', 30: 'H', 31: 'H', 32: 'H', 33: 'H', 34: 'H', 35: 'H', 36: 'H', 37: 'H', 38: 'H', 39: 'H', 40: 'H', 41: 'H', 42: 'H', 43: 'H', 44: 'H', 45: 'H', 46: 'H', 47: 'H', 48: 'H', 49: 'H', 50: 'H', 51: 'H', 52: 'H', 53: 'H', 54: 'H'}
[(2, 1), (1, 10), (13, 1), (3, 2), (2, 9), (4, 3), (3, 11), (4, 5), (12, 4), (5, 6), (5, 13), (6, 14), (7, 8), (9, 7), (10, 7), (15, 12), (16, 15), (15, 26), (16, 17), (23, 16), (17, 18), (17, 24), (18, 19), (18, 25), (19, 20), (26, 19), (20, 27), (21, 22), (21, 23), (21, 28), (1, 29), (2, 30), (3, 31), (4, 32), (5, 33), (6, 34), (6, 35), (8, 36), (8, 37), (8, 38), (11, 39), (14, 40), (15, 41), (16, 42), (17, 43), (18, 44), (19, 45), (20, 46), (20, 47), (22, 48), (22, 49), (22, 50), (23, 51), (24, 52), (25, 53), (27,

In [174]:
# Build the combined graph for whichever entry id `filename` currently holds in
# the kernel (it is not assigned in this cell), with a distance scale of 12.0.
id2fullgraph(filename,12.0)

Data(x=[111, 20], edge_index=[2, 2482], edge_attr=[2482, 7])

In [202]:
# Run id2fullgraph over every entry of CFG.pdbfiles to check that each one can
# be converted; the resulting graphs are discarded.
count = 0
for filename in os.listdir(CFG.pdbfiles):
    print(filename)  # shows the current entry, i.e. the one that failed if an error follows
    id2fullgraph(filename, 12.0)
    count += 1
    # Progress marker every 1000 processed entries.
    if count % 1000 == 0:
        print(count)

2wlz
4isi
2wnj
4qxo
2vvs
5cap
1yqj
5mro
1w0z
4a6l
3dgo
4ab9
1pzi
5e89
3fee
2zkj
5ewk
4und
5ngz
6hsh
5gja
3mof
5j6m
1z9g
3k5v
3prs
1c4u
5wbm
5evd
4b73
1nl9
1eb2
6gji
1gj6
6udu
6gj8
4cfl
6ic2
1elb
1b40
1i9n
5za9
4je7
4k0o
6p84
1ajv
3gnw
6eq7
5upf
4flp
4io2
5d47
1vyg
6i8m
3ewj
3el1
3p8p
4r5a
5ipj
2j4g
4g0p
6gvz
6qgf
5w44
1hps
5wuk
4avi
4zzy
4cs9
4cjp
5lny
3dzt
2gj5
1oau
6ibk
4nku
2vyt
3dbu
6cjv
1g7f
6g6t
1mq6
5ahw
5zw6
3ttp
4rqv
3sut
6p3t
5e1s
2aj8
3su1
3wmc
4ks4
6hr2
4mhy
2g94
1i1e
2avq
6eeo
1ws4
5oha
3czv
5fs5
3jy0
5dit
4o9w
3hmo
5n84
3wtj
3ao4
1wn6
2j79
4ih3
6p89
4i9h
5zg3
1gyx
5dnu
4p6w
3s5y
3oil
6bm5
1bzj
3ov1
1njd
4hp0
2a8g
2gvv
5fcz
5v82
5d25
4ago
2rfh
4bt5
5cy9
2qbu
4gzx
4m2u
5t9u
1bgq
4k3h
6e5l
4zzd
1g7q
5sz7
2v00
4zip
2wer
5kej
5ufr
5ot9
1w5x
4pin
5g17
1v1j
2cf8
5u0w
4q1x
3uw5
5ma7
2vo4
6qlt
1msn
3mxe
1hpx
4p3h
4x8u
1o0f
1hk4
2f1g
4der
3hk1
6c85
5os2
4u70
4uma
1x39
5azf
5g2g
5vcw
6nxz
6p5o
1sw2
5d26
1z6e
3ucj
4cd5
2bok
2br1
1g3d
3fcq
1o5e
4cps
4q46
4q8y
4x5r
5btv
4uof
4xoc
4iwz


5tmn
3uod
4mrg
1xka
6fba
4rd3
2ymd
2yel
4u73
5k1f
3fuz
2w26
5epn
4djo
4ad6
4ufk
4b6s
5n25
3e5a
5i9z
3f5k
5kqx
2oxd
6i66
1ajq
3gv9
3exh
3i25
1k1i
1elc
3t85
4n9a
6b96
6cpw
5nw0
3cj4
3m3c
1o2n
2cet
3r5t
5od1
4psb
4ih5
5j8u
6fhq
3c2r
3aau
4gql
4m0y
4cpw
4ayq
5dfp
4u0w
5cp5
6p9e
1wdn
1kav
6cwh
4g8y
1g7g
1fzm
6d78
2xbx
2qi5
3mhl
6pve
1ciz
6gga
6mja
2euk
2tpi
4tmn
4q99
1d6w
4dfg
4umc
6nwl
3s77
4x48
1z4o
4l19
2wvt
1jao
1li2
2r2m
3jzj
3eax
2w8w
2w8j
1a4w
4twp
3zze
6qpl
4m0e
1sld
3pwk
6ftp
4az5
6g9i
5lud
1k22
5j7q
5hvs
2ypo
3g30
4djw
4kcx
1y3x
1t31
186l
6n5x
4avh
3b4f
4mn3
6md0
6ge7
3myg
3n76
3ps1
4fxp
1kjr
6ei5
5fsy
3gqz
3t0d
6fcj
4x6n
1k1y
5wbo
4i7j
4mc9
6std
4poh
6b4l
5n0e
2wzs
2cej
4lov
3q6z
1d3p
3qgw
4cmo
6gl9
5ekm
4ty7
5nz4
3nex
6ht1
6czb
3q6w
4gqq
5fls
5ld8
6hza
3iph
6hni
5byi
3v7x
1qb9
3pn1
1pa9
1bnq
2v88
4b5s
1l8g
2uxi
4zls
1a1e
3v5p
3vfa
3vhk
1c88
2h3e
3rux
5mmg
4jyc
4bs0
3kmc
3f78
4azi
2fxv
1b3g
3rdq
4h42
5os4
5llg
3brn
5ep7
6b97
3ddg
3tfp
2q8m
4q90
3hek
4djr
4jpx
3rt8
2j7b
2hkf
3ebi


6d15
3ipq
1pzp
3oku
1o0n
1s38
4deu
4bcn
3rr4
5yj8
4zw7
5eh7
3iw6
5vm0
1odj
1str
4cd0
2xb7
6g34
3uo4
3f19
2hu6
5dqf
5eis
1yej
5f08
5ey4
6r4k
4re2
4e7r
1o1s
1ctu
5ouh
5wgp
6nsv
3st5
2bal
1fjs
1e6s
4cpz
4m3p
4unp
5n2z
6f90
3f80
2yfx
1ii5
3lq2
3f5l
5qa8
4idn
1ql7
2ra6
4pvx
4i5c
2p53
1j14
6dh1
5eng
4zeb
2bt9
2b1g
4b6o
6eed
5m9w
4av5
4auj
3nb5
3hf8
4l51
4de0
3ouj
4pop
2sim
2j94
6o95
5l9l
4qsv
1egh
2xef
2xj1
1v16
2wr8
5sz0
5sz4
2o8h
1m83
1lyx
3k1j
6epy
1fzj
1m1b
4css
4wop
1h46
3gc4
2evl
6idg
5vh0
6ej2
1qy2
4djv
6f28
5g1z
3b26
3s71
1y6r
1d4i
2xn3
2xht
5hwv
3l4z
5hbn
6jbe
5er2
2xej
6np4
6cdo
2pql
5ka7
4qll
1k1n
3hkw
4x5q
5ymx
1sgu
4lm4
6ior
3ccz
3gy2
6cfc
3ip6
4nyf
4cu7
3bxe
6nv9
5jt9
6i63
3d7z
6hh3
3cz1
1mtr
6eux
1g98
1hi3
1lkk
2uz9
1k21
3ms9
1x8t
2boj
2uy0
3r6u
1pph
5o1f
1y20
6ajz
3cm2
2uwp
4u0f
4hym
4z84
3dd8
5dbm
2x0y
3hmp
1dhi
4rww
1wht
1utl
1tsy
6cjr
1qbs
1y1z
3p4v
3ryy
5d3n
1f0u
5u0z
5ivv
3b68
6hgi
5aoj
1lan
3n86
3wtn
6d1i
3n35
4poj
5d0c
1uv6
1bzy
2xp7
1xq0
5mod
5dyo
3suu
4h3g
5xo7
5zag


3u8k
2hzl
6chp
4llk
2qta
6hgr
4r5t
1hos
4x6m
5ueu
6h36
5sym
5dpx
1tni
3sjf
4e5w
6dpt
5meh
4a6c
6jay
3t70
6mla
3p5l
4ibb
4z1e
4dst
3igp
6mj4
6ays
2oc2
1wvj
3ng4
4ih6
5nlk
3c8a
5cst
1fzo
4cwp
1e4h
2a5c
2aoe
2qg2
6ixd
4n6z
6r8w
5fpk
6bhv
1q8u
3v5t
5nk6
4ax9
5gsa
3hzv
2r1y
2epn
5eij
2e1w
4ozj
4dew
1sqa
1fd0
3ljg
5ceq
5a2i
5g5v
5vcv
5om3
3bxg
5eq1
4ty6
4da5
6e7j
4j48
4mhz
1ik4
2ovv
5sz2
2v95
1n4h
3mhc
2jiw
4agl
6nw3
1h4w
5000
4ujb
1bnv
4hla
2pvl
3mss
3fwv
3tao
3lk8
5wex
4ban
2v8w
5d45
2pqc
2i4w
1c1u
4ca6
4bf6
2jdu
1jak
2xhm
5f2r
2wzm
5jhb
1ejn
5fho
4kzu
4qfn
4q4q
2uwo
5n1s
4m2r
6n79
4je8
4d8z
5mby
5op4
4xmr
2qbp
2y82
2cf9
6guh
4u6w
5d2r
1d4h
3gkz
1erb
5fnd
2xjx
3zps
3tf6
6f3b
6n7a
1b3l
3ekt
3fv2
1fl3
2fzk
4ibk
1ado
3cow
6g2c
4q9o
5e2k
1tx7
4gzp
4q83
4bcs
4lm1
4gih
5n93
4kyh
4y5d
4ogj
4pft
1e1v
4o61
3lmk
1d3d
1e3g
4k0y
1o35
6cn5
5ko5
4zei
1x8j
6d1g
5lwm
4qlk
1m0n
6c9v
2i6b
2j7d
2pow
5d0r
4y3y
2e2r
5o9r
3n7o
2wc3
5ih9
5oh9
1u1w
3c4h
1p57
2ha2
4u1b
5n31
5f2u
4uc5
3o5n
6qrc
1o7o
6d1a
5vij
1gfy
