In [922]:
from tqdm import tqdm
from collections import defaultdict
from itertools import product

import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F 
from torch.nn import Linear, Parameter, GRUCell, Conv1d
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing, GCNConv
from torch_geometric.utils import add_self_loops, degree, smiles

from rdkit import Chem

In [3]:
dataset_dir_path = "/hdd3/dti_databank/preprocessed/dataset_220622"
complex_metadata = "complex_metadata_pdb_2020_general.csv"
ligand_metadata = "ligand_metadata_pdb_2020_general.csv"
protein_metadata  = "protein_metadata_pdb_2020_general.csv"

---
### Raw data load

In [4]:
ba_measure = 'KIKD' # IC50, nan

comp_meta_df = pd.read_csv(f'{dataset_dir_path}/{complex_metadata}')
comp_meta_df = comp_meta_df[comp_meta_df.ba_measure==ba_measure]
lig_meta_df = pd.read_csv(f'{dataset_dir_path}/{ligand_metadata}')
prot_meta_df = pd.read_csv(f'{dataset_dir_path}/{protein_metadata}')

---
### Make dataset for affinity

In [6]:
comp_lig_df = pd.merge(left=comp_meta_df, right=lig_meta_df, how='inner', on='ligand_id')
comp_lig_prot_df = pd.merge(left=comp_lig_df, right=prot_meta_df, how='inner', on='protein_id')
c_p_df = comp_lig_prot_df[['ba_value','smiles','fasta']]
c_p_df.tail()

Unnamed: 0,ba_value,smiles,fasta
9131,12.0,CC1CN(c2ccncc2NC(=O)c2ccc(F)c(-c3c(F)cccc3F)n2...,EPLESQYQVGPLLGSGGFGSVYSGIRVSDNLPVAIKHVEKDRISDW...
9132,12.39,Nc1ccc2c(c1)c(-c1ccccc1)[n+](CCCCCCc1cnnn1CCNc...,EGREDPQLLVRVRGGQLRGIRLKAPGGPVSAFLGIPFAEPPVGSRR...
9133,15.0,Nc1ccc2c(c1)c(-c1ccccc1)[n+](CCCCCCc1cnnn1CCNc...,EGREDPQLLVRVRGGQLRGIRLKAPGGPVSAFLGIPFAEPPVGSRR...
9134,12.72,O=C(O)C(O)(COP(=O)(O)O)C(O)C(O)COP(=O)(O)O,ASVGFKAGVKDYKLTYYTPEYETLDTDILAAFRVSPQPGVPPEEAG...
9135,13.82,C[N+](C)(C)c1cccc(C(O)(O)C(F)(F)F)c1,SELLVNTKSGKVMGTRVPVLSSHISAFLGIPFAEPPVGNMRFRRPE...


---
# The Graph Convolution Module
---

### 1) Smiles to graph

In [7]:
# node len : 63개
# Degree : [0,1,2,3,4,5] - 6개
# ExplicitValence : [1,2,3,4,5,6] - 6개
# ImplicitValence : [0,1,2,3,4,5] - 6개
# Aromatic : [0 or 1] - 1개
# 총 length : 82개 ()

ELEM_LIST = [
    'C', 'N', 'O', 'S', 'F', 
    'Si', 'P', 'Cl', 'Br', 'Mg', 
    'Na', 'Ca', 'Fe', 'As', 'Al', 
    'I', 'B', 'V', 'K', 'Tl', 
    'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 
    'Co', 'Se', 'Ti', 'Zn', 'H', 
    'Li', 'Ge', 'Cu', 'Au', 'Ni', 
    'Cd', 'In', 'Mn', 'Zr', 'Cr', 
    'Pt', 'Hg', 'Pb', 'W', 'Ru', 
    'Nb', 'Re', 'Te', 'Rh', 'Tc', 
    'Ba', 'Bi', 'Hf', 'Mo', 'U', 
    'Sm', 'Os', 'Ir', 'Ce', 'Gd',
    'Ga','Cs', 'unknown'
]

In [8]:
# For get_atom_feature
def one_of_k_encoding(x, vocab:list) -> list:
	if x not in vocab:
		x = vocab[-1]
	return list(map(lambda s: int(x==s), vocab))

In [432]:
# For get_molecular_graph
def get_atom_feature(atom) -> list:
    atom_feature =  one_of_k_encoding(atom.GetSymbol(), ELEM_LIST)
    atom_feature += one_of_k_encoding(atom.GetDegree(), [1,2,3,4,5,6])
    atom_feature += one_of_k_encoding(atom.GetExplicitValence(), [1,2,3,4,5,6])
    atom_feature += one_of_k_encoding(atom.GetImplicitValence(), [0,1,2,3,4,5])
    atom_feature += [int(atom.GetIsAromatic())]
    
    return [atom_feature]

In [10]:
def get_bond_idx(bond) -> list:
    begin = bond.GetBeginAtomIdx()
    end = bond.GetEndAtomIdx()
    bond_idx = [[begin, end],[end, begin]]
    
    return bond_idx

In [455]:
# For get_molecular_graph
def get_bond_feature(bond) -> list:
    bt = bond.GetBondType()
    bond_feature = [
        bt == Chem.rdchem.BondType.SINGLE,
        bt == Chem.rdchem.BondType.DOUBLE,
        bt == Chem.rdchem.BondType.TRIPLE,
        bt == Chem.rdchem.BondType.AROMATIC,
        bond.GetIsConjugated(),
        bond.IsInRing()
    ]
    return [bond_feature,bond_feature]
    

In [456]:
def get_molecular_graph(smi: str) -> Data:
    
    mol = Chem.MolFromSmiles(smi)

    # graph : get x
    atom_feature_ls = [get_atom_feature(atom) for atom in mol.GetAtoms()]
    atom_feature_ts = torch.tensor(atom_feature_ls, dtype=torch.float32).view(-1,82)

    # graph : get edge_index, edge_attr
    bond_idx_ls, bond_feature_ls = [], []
    for bond in mol.GetBonds():
        bond_idx_ls += get_bond_idx(bond)
        bond_feature_ls += get_bond_feature(bond)
    bond_idx_ts = torch.tensor(bond_idx_ls, dtype=torch.long).t().view(2,-1)
    bond_feature_ts = torch.tensor(bond_feature_ls, dtype=torch.float32).view(-1,6)
    
    return Data(x=atom_feature_ts, edge_index=bond_idx_ts, edge_attr=bond_feature_ts, smiles=smi)

In [457]:
graph = get_molecular_graph(c_p_df.smiles[0])
graph

Data(x=[13, 82], edge_index=[2, 24], edge_attr=[24, 6], smiles='CC(=O)NC(CCC(=O)O)C(=O)O')

### 2)Graph Convolution Module

In [840]:
# Start
print('x')
print(graph.x.shape)
print()

# Equation 1
print('v_i_0')
w_init = Linear(82,128)
v_i_0 = F.leaky_relu(w_init(graph.x))
print(v_i_0.shape)
print()

# Equation 2
print('t_i_1')
w_gather = Linear(134,128)
row,col = graph.edge_index[0], graph.edge_index[1]
sca = scatter_add(src=graph.edge_attr.T, index=col, dim_size=graph.num_nodes)
print(sca.shape)
v_i_0_sca = torch.cat((v_i_0, sca.T),1)
print(v_i_0_sca.shape)
t_i_1 = F.leaky_relu(w_gather(v_i_0_sca))
print(t_i_1.shape)
print()

# Equation 3
print('u_i_1')
w_update = Linear(256,128)
v_i_0_t_i_1 = torch.cat((v_i_0, t_i_1), axis=1)
u_i_0 = F.leaky_relu(w_update(v_i_0_t_i_1))
print(u_i_0.shape)
print()

# Equation 4
print('s__0')
s__0 = torch.sum(v_i_0, axis=0, keepdim=True)
print(s__0.shape)
print()

# Equation 5
print('u_s_0')
w_super = Linear(128,128)
u_s_0 = F.tanh(w_super(s__0))
print(u_s_0.shape)
print()

x
torch.Size([13, 82])

v_i_0
torch.Size([13, 128])

t_i_1
torch.Size([6, 13])
torch.Size([13, 134])
torch.Size([13, 128])

u_i_1
torch.Size([13, 128])

s__0
torch.Size([1, 128])

u_s_0
torch.Size([1, 128])





In [795]:
############ Step 1 ############

# Equation 6
print('u_s_to_v_0')
u_s_to_v_0 = F.tanh(w_s_to_v(s__0))
print(u_s_to_v_0.shape)
print()

# Equation 7
k = 2
w_s_to_v = Linear(128,128)
w_v_to_s = Linear(128*k,128)
a_v_ls = []

for idx in range(1,k+1):
    w_vatt_0_0 = Linear(128,128)
    w_satt_0_0 = Linear(128,128)
    w_att_0_0 = Linear(128,1)
    
    # Equation 9
    print(f'b_v_i_1_0 : k = {idx}')
    b_v_i_1_0 = F.tanh(w_vatt_0_0(v_i_0)) * F.tanh(w_satt_0_0(s__0))
    print(b_v_i_1_0.shape)
    print()

    # Equation 8
    print(f'a_v_i_1_0 : K = {idx}')
    a_v_i_1_0 = F.softmax(w_att_0_0(b_v_i_1_0))
    print(a_v_i_1_0.shape)
    print()

    a_v_j_0 = torch.sum(a_v_i_1_0*v_i_0, axis=0, keepdim=True)
    a_v_ls += a_v_j_0

a_v_j_k_0 = torch.cat(a_v_ls,dim=0).view(1,-1)

print('u_v_to_s_0')
u_v_to_s_0 = F.tanh(w_v_to_s(a_v_j_k_0))
print(u_v_to_s_0.shape)

u_s_to_v_0
torch.Size([1, 128])

b_v_i_1_0 : k = 1
torch.Size([13, 128])

a_v_i_1_0 : K = 1
torch.Size([13, 1])

b_v_i_1_0 : k = 2
torch.Size([13, 128])

a_v_i_1_0 : K = 2
torch.Size([13, 1])

u_v_to_s_0
torch.Size([1, 128])


  a_v_i_1_0 = F.softmax(w_att_0_0(b_v_i_1_0))


In [818]:
############ Step 2 ############

# Equation 10
print('g_v_to_s_0')
w_gate11 = Linear(128,128)
w_gate12 = Linear(128,128)
g_v_to_s_0 = F.sigmoid(w_gate11(u_v_to_s_0) + w_gate12(u_s_0))
print(g_v_to_s_0.shape)
print()

# Equation 11
print('t_v_to_s_0')
t_v_to_s_0 = (1-g_v_to_s_0)*u_v_to_s_0 + g_v_to_s_0*u_s_0
print(t_v_to_s_0.shape)
print()

# Equation 12
print('g_s_to_i_0')
w_gate21 = Linear(128,128)
w_gate22 = Linear(128,128)
g_s_to_i_0 = F.sigmoid(w_gate21(u_i_0) + w_gate22(u_s_to_v_0))
print(g_s_to_i_0.shape)
print()

# Equation 13
print('t_s_to_i_0')
t_s_to_i_0 = (1-g_s_to_i_0)*u_i_0 + g_s_to_i_0*u_s_to_v_0
print(t_s_to_i_0.shape)

g_v_to_s_0
torch.Size([1, 128])

t_v_to_s_0
torch.Size([1, 128])

g_s_to_i_0
torch.Size([13, 128])

t_s_to_i_0
torch.Size([13, 128])


In [839]:
############ Step 3 ############

# Equation 14
print('v_i_1')
gru_v = GRUCell(128,128)
v_i_1 = gru_v(v_i_0, t_s_to_i_0)
print(v_i_1.shape)
print()

# Equation 15
print('s__1')
gru_s = GRUCell(128,128)
s__1 = gru_s(s__0,t_v_to_s_0)
print(s__1.shape)

v_i_1
torch.Size([13, 128])

s__1
torch.Size([1, 128])


---
# The CNN Module
---

### 1) Fasta to sequence

In [1082]:
def load_blosum62(blosum_path = '/hdd3/seungheun/dti_study/blosum62.txt') -> dict:
    blosum_dict = {}
    with open(blosum_path, 'r') as fr:
        for line in fr:
            if line.startswith(' '):
                continue
            parsed = line.strip('\n').split()
            blosum_dict[parsed[0]] = np.array(parsed[1:]).astype(np.float32)
    return blosum_dict

In [1087]:
blosum_dict = load_blosum62()
fasta = c_p_df.fasta[0]
init_feature = torch.tensor([blosum_dict[r] for r in fasta]).T
init_feature

tensor([[ 0., -2.,  1.,  ..., -1.,  0., -2.],
        [-3., -2., -1.,  ..., -3., -3., -2.],
        [-1., -3.,  0.,  ..., -2., -1., -3.],
        ...,
        [-3., -1., -2.,  ..., -3., -3., -1.],
        [-2.,  1., -3.,  ..., -3., -2.,  2.],
        [-3.,  3., -2.,  ..., -2., -3.,  7.]])

### 2) 1-D Covolution layers

In [1116]:
kernel_size = 7 # or 5
padding = int((kernel_size-1)/2)
conv1d = Conv1d(20,128,kernel_size=kernel_size,padding=padding,dtype=torch.float32)
r_j_1 = F.leaky_relu(conv1d(init_feature)).T
r_j_1.shape

torch.Size([290, 128])

---
# The Pairwise Interaction Prediction Module
---

In [1117]:
# (Equation 16)
print("p_i_j")
w_atom = Linear(128,128)
w_residue = Linear(128,128)
p_i_j = F.sigmoid(torch.mm(F.leaky_relu(w_atom(v_i_1)), F.leaky_relu(w_residue(r_j_1)).T))
print(p_i_j.shape)

p_i_j
torch.Size([13, 290])


---
# The Affinity Prediction Module
---

In [1133]:
h_2 = 128 # or 64

# (Equation 17)
print('h_v_i')
w_v = Linear(128, h_2)
h_v_i = F.leaky_relu(w_v(v_i_1))
print(h_v_i.shape)
print()

# (Equation 18)
print('h_s')
w_r = Linear(128, h_2)
h_s = F.leaky_relu(w_r(s__1))
print(h_s.shape)
print()

# (Equation 19)
print('h_r_j')
w_s = Linear(128, h_2)
h_r_j = F.leaky_relu(w_s(r_j_1))
print(h_r_j.shape)
print()



# (Equation 20)
print('h_c_0')
h_c_0 = torch.sum(h_v_i,axis=0,keepdim=True)/h_v_i.size(0)
print(h_c_0.shape)
print()

# (Equation 21)
print('h_p_0')
h_p_0 = torch.sum(h_r_j,axis=0,keepdim=True)/h_r_j.size(0)
print(h_p_0.shape)
print()

# (Equation 22)
print('m_0')
m_0 = h_c_0 * h_p_0
print(m_0.shape)
print()



#

h_v_i
torch.Size([13, 128])

h_s
torch.Size([1, 128])

h_r_j
torch.Size([290, 128])

h_c_0
torch.Size([1, 128])

h_p_0
torch.Size([1, 128])

m_0
torch.Size([1, 128])

