In [1]:
import numpy as np
import pandas as pd
import random
import torch
import math
from Bio import SeqIO
import Bio.PDB
import pickle as pickle
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GENConv
from torch_geometric.nn.models import MLP
from torch_geometric.data import Data
from torch_geometric.nn.aggr import MeanAggregation
import matplotlib.pyplot as plt
import os
from Bio import PDB
from rdkit import Chem
import blosum as bl

In [2]:
class CFG:
    pdbfiles: str = "/home/paul/Desktop/BioHack-Project-Walkthrough/pdbind-refined-set/"
    AA_mol2_files: str = "/home/paul/Desktop/BioHack-Project-Walkthrough/AA_mol2/"

In [5]:
with open('atom2emb.pkl', 'rb') as f:
    atom2emb = pickle.load(f)
    
with open('bond_type_dict.pkl', 'rb') as f:
    bond_type_dict = pickle.load(f)
    
one_letter_to_three_letter_dict = {'G':'gly',
                                   'A':'ala',
                                   'V':'val',
                                   'C':'cys',
                                   'P':'pro',
                                   'L':'leu',
                                   'I':'ile',
                                   'M':'met',
                                   'W':'trp',
                                   'F':'phe',
                                   'K':'lys',
                                   'R':'arg',
                                   'H':'his',
                                   'S':'ser',
                                   'T':'thr',
                                   'Y':'tyr',
                                   'N':'asn',
                                   'Q':'gln',
                                   'D':'asp',
                                   'E':'glu'
    
}

def BLOSUM_encode_single(seq,AA_dict):
    allowed = set("gavcplimwfkrhstynqdeuogavcplimwfkrhstynqde")
    if not set(seq).issubset(allowed):
        invalid = set(seq) - allowed
        raise ValueError(f"Sequence has broken AA: {invalid}")
    vec = AA_dict[seq]
    return vec

matrix = bl.BLOSUM(62)
allowed_AA = "GAVCPLIMWFKRHSTYNQDE"
BLOSUM_dict_three_letter = {}
for i in allowed_AA:
    vec = []
    for j in allowed_AA:
        vec.append(matrix[i][j])
    BLOSUM_dict_three_letter.update({one_letter_to_three_letter_dict[i]:torch.Tensor(vec)})
    
def read_mol2_bonds_and_atoms(mol2_file):
    bonds = []
    bond_types = []
    atom_types = {}
    atom_coordinates = {}

    with open(mol2_file, 'r') as mol2:
        reading_bonds = False
        reading_atoms = False
        for line in mol2:
            if line.strip() == '@<TRIPOS>BOND':
                reading_bonds = True
                continue
            elif line.strip() == '@<TRIPOS>ATOM':
                reading_atoms = True
                continue
            elif line.strip().startswith('@<TRIPOS>SUBSTRUCTURE'):
                break
            elif reading_bonds and line.strip().startswith('@<TRIPOS>'):
                reading_bonds = False
            elif reading_atoms and line.strip().startswith('@<TRIPOS>'):
                reading_atoms = False


            if reading_bonds:
                parts = line.split()
                if len(parts) >= 4:
                    atom1_index = int(parts[1])
                    atom2_index = int(parts[2])
                    bond_type = parts[3]
                    bonds.append((atom1_index, atom2_index))
                    bond_types.append(bond_type)

            if reading_atoms:
                parts = line.split()
                if len(parts) >= 6:
                    atom_index = int(parts[0])
                    atom_type = parts[5]
                    x, y, z = float(parts[2]), float(parts[3]), float(parts[4])
                    atom_types[atom_index] = atom_type.split('.')[0]
                    atom_coordinates[atom_index] = (x, y, z)

    return bonds, bond_types, atom_types, atom_coordinates  

def molecule2graph_AA(filename,map_distance, norm_map_distance = 12.0):
    node_feature = []
    edge_index = []
    edge_attr = []
    mol2_file = CFG.AA_mol2_files+filename
    bonds, bond_types, atom_types, atom_coordinates = read_mol2_bonds_and_atoms(mol2_file)
    for atom in atom_types:
        node_feature.append(torch.Tensor(atom2emb[atom_types[atom]]))
    

    for atom1 in range(1, len(atom_types)+1):
        for atom2 in range(atom1 + 1, len(atom_types)+1):
            bonded_flag = 0
            for i, bond in enumerate(bonds):
                if (atom1 in bond) and (atom2 in bond):
                    edge_index.append([bond[0] - 1,bond[1] - 1])
                    coord1 = np.array(atom_coordinates[bond[0]])
                    coord2 = np.array(atom_coordinates[bond[1]])
                    dist = math.dist(coord1, coord2)
                    d = []
                    for l in range(12):
                        d.append(np.exp((-1.0*(dist - 2.0*(l + 0.5))**2.0)/norm_map_distance))
                    bond_type = bond_type_dict[bond_types[i]]
                    edge_attr.append(np.hstack((d,d,d,d,d,d,d,d,d,bond_type)))
                    bonded_flag = 1
                
            if bonded_flag == 0:
                coord1 = np.array(atom_coordinates[atom1])
                coord2 = np.array(atom_coordinates[atom2])
                dist = math.dist(coord1, coord2)
                if dist < map_distance:
                    edge_index.append([atom1 - 1,atom2 - 1])
                    d = []
                    for l in range(12):
                        d.append(np.exp((-1.0*(dist - 2.0*(l + 0.5))**2.0)/norm_map_distance))
                    bond_type = bond_type_dict['nc']
                    edge_attr.append(np.hstack((d,d,d,d,d,d,d,d,d,bond_type)))

    
    edge_index = np.array(edge_index)
    edge_index = edge_index.transpose()
    edge_index = torch.Tensor(edge_index)
    edge_index = edge_index.to(torch.int64)
    edge_attr = torch.Tensor(np.array(edge_attr))
    node_feature = torch.stack(node_feature)
    
    #Master_node
    new_edge_index = []
    new_edge_attr = []
    node_features = torch.cat((node_feature,torch.zeros(len(atom2emb['N'])).unsqueeze(0)),dim = 0)
    
    for i in range(len(node_features) - 1):
        new_edge_index.append([i,int(len(node_features)-1)])
        bond_type = bond_type_dict['nc']
        new_edge_attr.append(np.hstack((np.zeros(9*len(d)),bond_type)))
    
    new_edge_index = np.array(new_edge_index)
    new_edge_index = new_edge_index.transpose()
    new_edge_index = torch.Tensor(new_edge_index)
    new_edge_index = new_edge_index.to(torch.int64)
    new_edge_attr = torch.Tensor(np.array(new_edge_attr))    
    
    edge_index = torch.cat((edge_index,new_edge_index), dim = 1)
    edge_attr = torch.cat((edge_attr,new_edge_attr), dim = 0)
    
    graph = Data(x = node_feature, edge_index = edge_index,edge_attr = edge_attr)#, pos = new_mol_coords)
    graph.label = filename.split('.')[0]
    return graph

In [6]:
AA_graphs = []
for filename in os.listdir(CFG.AA_mol2_files):
    AA_graphs.append(molecule2graph_AA(filename,12.0))

In [7]:
print(random.choice(AA_graphs))

Data(x=[21, 133], edge_index=[2, 231], edge_attr=[231, 114], label='tyr')


In [8]:
softmax = nn.Softmax(dim = 0)
for i, graph in enumerate(AA_graphs):
    AA_graphs[i].y = softmax(BLOSUM_encode_single(graph.label,BLOSUM_dict_three_letter))

In [9]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.node_feature_size = 133
        self.node_feature_hidden_size = 133
        self.node_feature_size_out = 20
        self.conv1 = GENConv(self.node_feature_size,self.node_feature_hidden_size,aggr = 'mean',edge_dim = 114, num_layer = 2,norm = 'layer')
        self.conv2 = GENConv(self.node_feature_hidden_size,self.node_feature_hidden_size,aggr = 'mean',edge_dim = 114,num_layer = 2,norm = 'layer')
        self.conv3 = GENConv(self.node_feature_hidden_size,self.node_feature_hidden_size,aggr = 'mean',edge_dim = 114,num_layer = 2,norm = 'layer')
        self.linear1 = nn.Linear(self.node_feature_hidden_size,self.node_feature_size_out)
        self.ReLu = nn.ReLU()
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self,graph):
        x, edge_index, edge_attr = graph.x, graph.edge_index, graph.edge_attr
        x1 = self.conv1(x, edge_index, edge_attr)
        x1 = self.ReLu(x1)
        x1 = self.conv2(x1, edge_index, edge_attr)
        x1 = self.ReLu(x1)
        x1 = self.conv3(x1, edge_index, edge_attr)
        x1 = x1[-1]
        x1 = torch.tanh(x1)
        x1 = self.linear1(x1)
        return x1
    
    def encode(self,graph):
        x, edge_index, edge_attr = graph.x,graph.edge_index,graph.edge_attr
        x1 = self.conv1(x, edge_index,edge_attr)
        x1 = self.ReLu(x1)
        x1 = self.conv2(x1, edge_index,edge_attr)
        x1 = self.ReLu(x1)
        x1 = self.conv3(x1, edge_index,edge_attr)
        x1 = x1[-1]
        x1 = torch.tanh(x1)
        return x1
    
    def decode(self,encoding):
        x1 = self.linear1(encoding)
        return x1

In [10]:
from torch_geometric.loader import DataLoader
train_dl = DataLoader(AA_graphs,batch_size = 1, shuffle = True)

In [11]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Net()
model.to(DEVICE) # put on GPU

# Define a loss function (e.g., Mean Squared Error) and an optimizer (e.g., Adam)
criterion = loss_func = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5000,8000], gamma=0.1)
#optimizer = torch.optim.SGD(model.parameters(),lr = 5e-6)

# Training loop
num_epochs = 10000  # Adjust the number of epochs as needed
losses = []
lowest = 0.01

for epoch in range(num_epochs):
    total_loss = 0.0
    val_loss = 0.0
    
        
    for batch in train_dl:
        model.train()
        inputs = batch[0].to(DEVICE)
        print(inputs)

        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(inputs)
        
        # Compute the loss
        loss = criterion(outputs, inputs.y)
        
        # Backpropagation and optimization
        loss.backward()
        optimizer.step()
        
        inputs= inputs.to('cpu')
        
        total_loss += loss.item()
            
    for batch in train_dl:
        with torch.no_grad():
            model.eval()
            inputs = batch[0].to(DEVICE)
        
            # Forward pass
            outputs = model(inputs)
        
            # Compute the loss
            loss = criterion(outputs, inputs.y)
        
            inputs= inputs.to('cpu')
            
            val_loss += loss.item()
            
    # Print the average loss for this epoch
    avg_loss = total_loss / len(train_dl)
    val_avg_loss = val_loss / len(train_dl)
    
    if lowest > val_avg_loss:
        torch.save(model.state_dict(), 'AA_encoder_11172023.pt')
        lowest = val_avg_loss
    
    print(f'Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f} Val Loss: {val_avg_loss:.4f}')
    
print('Training complete')


Data(x=[20, 133], edge_index=[2, 210], edge_attr=[210, 114], label='phe', y=[20])


/opt/conda/conda-bld/pytorch_1682343995622/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [62,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch_1682343995622/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [63,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch_1682343995622/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [64,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch_1682343995622/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [65,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch_1682343995622/work/aten/src/ATen/native/cuda/ScatterGat

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
model.load_state_dict(torch.load('AA_encoder_11172023.pt'))
graph = random.choice(data)
pred = model.forward(graph.to(DEVICE))
print(pred)
print(graph.y)

In [None]:
embedding = model.encode(graph.to(DEVICE))

In [None]:
print(embedding)

In [None]:
with open('atom2emb.pkl', 'rb') as f:
    atom2emb = pickle.load(f)

In [None]:
atom2emb['Fe']

In [None]:
AA_embeddings = {}
for graph in graph_list:
    pred = model.encode(graph[1].to(DEVICE))
    AA_embeddings.update({upper2lower[graph[0]]:pred})

In [None]:
for i in AA_embeddings:
    print(i)
    print(AA_embeddings[i])

In [None]:
with open('AA_embeddings_11172023.pkl', 'wb') as f:
    pickle.dump(AA_embeddings, f)

In [None]:
def get_atom_symbol(atomic_number):
    return Chem.PeriodicTable.GetElementSymbol(Chem.GetPeriodicTable(), atomic_number)

def remove_hetatm(input_pdb_file, output_pdb_file):
    # Open the input PDB file for reading and the output PDB file for writing
    with open(input_pdb_file, 'r') as infile, open(output_pdb_file, 'w') as outfile:
        for line in infile:
            # Check if the line starts with 'HETATM' (non-protein atoms)
            if line.startswith('HETATM'):
                continue  # Skip this line (HETATM record)
            # Write all other lines to the output file
            outfile.write(line)
            
def get_atom_types_from_sdf(sdf_file):
    supplier = Chem.SDMolSupplier(sdf_file)
    atom_types = set()

    for mol in supplier:
        if mol is not None:
            atoms = mol.GetAtoms()
            atom_types.update([atom.GetSymbol() for atom in atoms])

    return sorted(list(atom_types))

def get_atom_types_from_mol2_split(mol2_file):
    atom_types = set()

    with open(mol2_file, 'r') as mol2:
        reading_atoms = False
        for line in mol2:
            if line.strip() == '@<TRIPOS>ATOM':
                reading_atoms = True
                continue
            elif line.strip() == '@<TRIPOS>BOND':
                break

            if reading_atoms:
                parts = line.split()
                if len(parts) >= 5:
                    atom_type = parts[5]
                    atom_types.add(atom_type)
    
    atom_types_split = set()
    for atom in atom_types:
        atom_types_split.add(str(atom).split('.')[0])
        

    return sorted(list(atom_types_split))

def get_atom_types_from_mol2(mol2_file):
    atom_types = set()

    with open(mol2_file, 'r') as mol2:
        reading_atoms = False
        for line in mol2:
            if line.strip() == '@<TRIPOS>ATOM':
                reading_atoms = True
                continue
            elif line.strip() == '@<TRIPOS>BOND':
                break

            if reading_atoms:
                parts = line.split()
                if len(parts) >= 5:
                    atom_type = parts[5]
                    atom_types.add(atom_type)

    return sorted(list(atom_types))

def get_atom_list_from_mol2_split(mol2_file):
    atoms = []
    with open(mol2_file, 'r') as mol2:
        reading_atoms = False
        for line in mol2:
            if line.strip() == '@<TRIPOS>ATOM':
                reading_atoms = True
                continue
            elif line.strip() == '@<TRIPOS>BOND':
                break

            if reading_atoms:
                parts = line.split()
                if len(parts) >= 5:
                    atom_type = parts[5]
                    atoms.append(atom_type)
    
    atom_list = []
    for atom in atoms:
        atom_list.append(str(atom).split('.')[0])
        

    return atom_list

def get_atom_list_from_mol2(mol2_file):
    atoms = []
    with open(mol2_file, 'r') as mol2:
        reading_atoms = False
        for line in mol2:
            if line.strip() == '@<TRIPOS>ATOM':
                reading_atoms = True
                continue
            elif line.strip() == '@<TRIPOS>BOND':
                break

            if reading_atoms:
                parts = line.split()
                if len(parts) >= 5:
                    atom_type = parts[5]
                    atoms.append(atom_type)

    return atoms

def get_bond_types_from_mol2(mol2_file):
    bond_types = set()

    with open(mol2_file, 'r') as mol2:
        reading_bonds = False
        for line in mol2:
            if line.strip() == '@<TRIPOS>BOND':
                reading_bonds = True
                continue
            elif reading_bonds and line.strip().startswith('@<TRIPOS>'):
                break

            if reading_bonds:
                parts = line.split()
                if len(parts) >= 4:
                    bond_type = parts[3]
                    bond_types.add(bond_type)

    return sorted(list(bond_types))

def read_mol2_bonds(mol2_file):
    bonds = []
    bond_types = []

    with open(mol2_file, 'r') as mol2:
        reading_bonds = False
        for line in mol2:
            if line.strip() == '@<TRIPOS>BOND':
                reading_bonds = True
                continue
            elif reading_bonds and line.strip().startswith('@<TRIPOS>'):
                break

            if reading_bonds:
                parts = line.split()
                if len(parts) >= 4:
                    atom1_index = int(parts[1])
                    atom2_index = int(parts[2])
                    bond_type = parts[3]
                    bonds.append((atom1_index, atom2_index))
                    bond_types.append(bond_type)

    return bonds, bond_types

def calc_residue_dist(residue_one, residue_two) :
    """Returns the C-alpha distance between two residues"""
    diff_vector  = residue_one["CA"].coord - residue_two["CA"].coord
    return np.sqrt(np.sum(diff_vector * diff_vector))

def calc_dist_matrix(chain_one, chain_two) :
    """Returns a matrix of C-alpha distances between two chains"""
    answer = np.zeros((len(chain_one), len(chain_two)), float)
    for row, residue_one in enumerate(chain_one) :
        for col, residue_two in enumerate(chain_two) :
            answer[row, col] = calc_residue_dist(residue_one, residue_two)
    return answer

def calc_contact_map(uniID,map_distance):
    pdb_code = uniID
    pdb_filename = uniID+"_pocket_clean.pdb"
    structure = Bio.PDB.PDBParser(QUIET = True).get_structure(pdb_code, (CFG.pdbfiles +'/'+pdb_code+'/'+pdb_filename))
    model = structure[0]
    flag1 = 0
    flag2 = 0
    idx = 0
    index = []
    chain_info = []
    
    for chain1 in model:
        for resi in chain1:
            index.append(idx)
            idx += 1
            chain_info.append([chain1.id,resi.id])
        for chain2 in model:
            if flag1 == 0:
                dist_matrix = calc_dist_matrix(model[chain1.id], model[chain2.id])
            else:
                new_matrix = calc_dist_matrix(model[chain1.id], model[chain2.id])
                dist_matrix = np.hstack((dist_matrix,new_matrix))
            flag1 += 1
        flag1 = 0
        if flag2 == 0:
            top_matrix = dist_matrix
        else:
            top_matrix = np.vstack((top_matrix,dist_matrix))
        flag2 += 1
    
    contact_map = top_matrix < map_distance
    return contact_map, index, chain_info

one_letter_to_three_letter_dict = {'G':'gly',
                                   'A':'ala',
                                   'V':'val',
                                   'C':'cys',
                                   'P':'pro',
                                   'L':'leu',
                                   'I':'ile',
                                   'M':'met',
                                   'W':'trp',
                                   'F':'phe',
                                   'K':'lys',
                                   'R':'arg',
                                   'H':'his',
                                   'S':'ser',
                                   'T':'thr',
                                   'Y':'tyr',
                                   'N':'asn',
                                   'Q':'gln',
                                   'D':'asp',
                                   'E':'glu'
    
}

def BLOSUM_encode_single(seq,AA_dict):
    allowed = set("gavcplimwfkrhstynqdeuogavcplimwfkrhstynqde")
    if not set(seq).issubset(allowed):
        invalid = set(seq) - allowed
        raise ValueError(f"Sequence has broken AA: {invalid}")
    vec = AA_dict[seq]
    return vec

matrix = bl.BLOSUM(62)
allowed_AA = "GAVCPLIMWFKRHSTYNQDE"
BLOSUM_dict_three_letter = {}
for i in allowed_AA:
    vec = []
    for j in allowed_AA:
        vec.append(matrix[i][j])
    BLOSUM_dict_three_letter.update({one_letter_to_three_letter_dict[i]:torch.Tensor(vec)})

def uniID2graph(uniID,map_distance):
    atom_name = 'CA'
    node_feature = []
    edge_index = []
    edge_attr = []
    coord = []
    contact_map, index, chain_info = calc_contact_map(uniID,map_distance)
    pdb_code = uniID
    pdb_filename = uniID+"_pocket_clean.pdb"
    structure = Bio.PDB.PDBParser(QUIET = True).get_structure(pdb_code, (CFG.pdbfiles +'/'+pdb_code+'/'+pdb_filename))
    model = structure[0]
    
    for i in index:
        node_feature.append(one_hot_encode_single_res(model[chain_info[i][0]][chain_info[i][1]].get_resname()))
        coord.append(model[chain_info[i][0]][chain_info[i][1]]['CA'].coord)
        for j in index:
            if contact_map[i,j] == 1:
                edge_index.append([i,j])
                diff_vector = model[chain_info[i][0]][chain_info[i][1]]['CA'].coord - model[chain_info[j][0]][chain_info[j][1]]['CA'].coord
                dist = (np.sqrt(np.sum(diff_vector * diff_vector))/map_distance)
                bond_type = bond_type_dict['nc']
                edge_attr.append(np.hstack((dist,bond_type)))
                            
    edge_index = np.array(edge_index)
    edge_index = edge_index.transpose()
    edge_index = torch.Tensor(edge_index)
    edge_index = edge_index.to(torch.int64)
    edge_attr = torch.Tensor(edge_attr)
    node_feature = torch.stack(node_feature)
    graph = Data(x = node_feature, edge_index = edge_index,edge_attr = edge_attr)
    return graph, coord

def read_mol2_bonds_and_atoms(mol2_file):
    bonds = []
    bond_types = []
    atom_types = {}
    atom_coordinates = {}

    with open(mol2_file, 'r') as mol2:
        reading_bonds = False
        reading_atoms = False
        for line in mol2:
            if line.strip() == '@<TRIPOS>BOND':
                reading_bonds = True
                continue
            elif line.strip() == '@<TRIPOS>ATOM':
                reading_atoms = True
                continue
            elif line.strip().startswith('@<TRIPOS>SUBSTRUCTURE'):
                break
            elif reading_bonds and line.strip().startswith('@<TRIPOS>'):
                reading_bonds = False
            elif reading_atoms and line.strip().startswith('@<TRIPOS>'):
                reading_atoms = False


            if reading_bonds:
                parts = line.split()
                if len(parts) >= 4:
                    atom1_index = int(parts[1])
                    atom2_index = int(parts[2])
                    bond_type = parts[3]
                    bonds.append((atom1_index, atom2_index))
                    bond_types.append(bond_type)

            if reading_atoms:
                parts = line.split()
                if len(parts) >= 6:
                    atom_index = int(parts[0])
                    atom_type = parts[5]
                    x, y, z = float(parts[2]), float(parts[3]), float(parts[4])
                    atom_types[atom_index] = atom_type.split('.')[0]
                    atom_coordinates[atom_index] = (x, y, z)

    return bonds, bond_types, atom_types, atom_coordinates

def molecule2graph_AA(filename,map_distance):
    node_feature = []
    edge_index = []
    edge_attr = []
    mol2_file = CFG.pdbfiles+filename+'/'+filename+'_ligand.mol2'
    bonds, bond_types, atom_types, atom_coordinates = read_mol2_bonds_and_atoms(mol2_file)
    for atom in atom_types:
        node_feature.append(torch.zeros(20))
        #node_feature.append(torch.Tensor(atom2emb[atom_types[atom]]))
    for i in range(len(bonds)):
        bond = bonds[i]
        edge_index.append([bond[0] - 1,bond[1] - 1])
        coord1 = np.array(atom_coordinates[bond[0]])
        coord2 = np.array(atom_coordinates[bond[1]])
        dist = [np.sqrt(np.sum((coord1 - coord2)*(coord1 - coord2)))/map_distance]
        bond_type = bond_type_dict[bond_types[i]]
        edge_attr.append(np.hstack((dist,bond_type)))
    
    edge_index = np.array(edge_index)
    edge_index = edge_index.transpose()
    edge_index = torch.Tensor(edge_index)
    edge_index = edge_index.to(torch.int64)
    edge_attr = torch.Tensor(edge_attr)
    node_feature = torch.stack(node_feature)
    graph = Data(x = node_feature, edge_index = edge_index,edge_attr = edge_attr)
    
    return graph, atom_coordinates

def id2fullgraph(filename, map_distance):
    prot_graph, prot_coord = uniID2graph(filename,map_distance)
    mol_graph, mol_coord = molecule2graph(filename,map_distance)
    mol_coord = [mol_coord[i] for i in mol_coord]
    node_features = torch.cat((prot_graph.x,mol_graph.x),dim = 0)
    update_edge_index = mol_graph.edge_index + prot.x.size()[0]
    edge_index = torch.cat((prot_graph.edge_index,update_edge_index), dim = 1)
    edge_attr = torch.cat((prot_graph.edge_attr,mol_graph.edge_attr), dim = 0)
    
    new_edge_index = []
    new_edge_attr = []
    for i in range(len(mol_coord)):
        for j in range(len(prot_coord)):
            dist_vec = mol_coord[i] - prot_coord[j]
            dist = np.sqrt(np.sum(dist_vec*dist_vec))/map_distance
            if dist < 1.0:
                new_edge_index.append([j,i + len(prot_coord)])
                new_edge_attr.append((np.hstack(([dist],bond_type_dict['nc']))))
                
    new_edge_index = np.array(new_edge_index)
    new_edge_index = new_edge_index.transpose()
    new_edge_index = torch.Tensor(new_edge_index)
    new_edge_index = new_edge_index.to(torch.int64)
    new_edge_attr = torch.Tensor(new_edge_attr)
    
    edge_index = torch.cat((edge_index,new_edge_index), dim = 1)
    edge_attr = torch.cat((edge_attr,new_edge_attr), dim = 0)
    
    graph = Data(x = node_features, edge_index = edge_index,edge_attr = edge_attr)
    
    return graph

def molecule2graph_AA(filename,map_distance):
    node_feature = []
    edge_index = []
    edge_attr = []
    mol2_file = filename
    bonds, bond_types, atom_types, atom_coordinates = read_mol2_bonds_and_atoms(mol2_file)
    for atom in atom_types:
        #node_feature.append(torch.zeros(20))
        node_feature.append(torch.Tensor(atom2emb[atom_types[atom]]))
    for i in range(len(bonds)):
        bond = bonds[i]
        edge_index.append([bond[0] - 1,bond[1] - 1])
        coord1 = np.array(atom_coordinates[bond[0]])
        coord2 = np.array(atom_coordinates[bond[1]])
        dist = [np.sqrt(np.sum((coord1 - coord2)*(coord1 - coord2)))/map_distance]
        bond_type = bond_type_dict[bond_types[i]]
        edge_attr.append(np.hstack((dist,bond_type)))
    
    #Master_node
    node_feature.append(torch.zeros(len(atom2emb['N'])))
    
    for i in range(len(node_feature) - 1):
        edge_index.append([i,int(len(node_feature)-1)])
        bond_type = bond_type_dict['1']
        edge_attr.append(np.hstack((1.0,bond_type)))
    
    edge_index = np.array(edge_index)
    edge_index = edge_index.transpose()
    edge_index = torch.Tensor(edge_index)
    edge_index = edge_index.to(torch.int64)
    edge_attr = torch.Tensor(edge_attr)
    node_feature = torch.stack(node_feature)
    graph = Data(x = node_feature, edge_index = edge_index,edge_attr = edge_attr)
    
    return graph, atom_coordinates

upper2lower = {
    "ala": "ALA",
    "arg": "ARG",
    "asn": "ASN",
    "asp": "ASP",
    "cys": "CYS",
    "gln": "GLN",
    "glu": "GLU",
    "gly": "GLY",
    "his": "HIS",
    "ile": "ILE",
    "leu": "LEU",
    "lys": "LYS",
    "met": "MET",
    "phe": "PHE",
    "pro": "PRO",
    "ser": "SER",
    "thr": "THR",
    "trp": "TRP",
    "tyr": "TYR",
    "val": "VAL",
}