In [189]:
import tqdm
from collections import defaultdict

import numpy as np
import pandas as pd

import torch
from rdkit import Chem
import dgl

In [190]:
# Directory containing the preprocessed CSV tables that the loading cell reads.
dataset_dir_path = "/hdd3/dti_databank/preprocessed/dataset_220622"
# File names of the complex / ligand / protein metadata tables; each is joined
# onto dataset_dir_path with '/' when it is read.
complex_metadata = "complex_metadata_pdb_2020_general.csv"
ligand_metadata = "ligand_metadata_pdb_2020_general.csv"
protein_metadata  = "protein_metadata_pdb_2020_general.csv"

---
### Raw data load

In [191]:
# Binding-affinity measure to keep (author's note lists the alternatives: IC50, nan).
ba_measure = 'KIKD'

# Complex table, restricted to rows whose ba_measure column equals the value above.
comp_meta_df = pd.read_csv(f'{dataset_dir_path}/{complex_metadata}').loc[
    lambda frame: frame.ba_measure == ba_measure
]
# Ligand and protein tables are loaded unfiltered.
lig_meta_df = pd.read_csv(f'{dataset_dir_path}/{ligand_metadata}')
prot_meta_df = pd.read_csv(f'{dataset_dir_path}/{protein_metadata}')

---
### Make dataset for affinity

In [192]:
# Attach ligand columns to each complex (inner join on ligand_id) ...
comp_lig_df = comp_meta_df.merge(lig_meta_df, how='inner', on='ligand_id')
# ... then protein columns (inner join on protein_id).
comp_lig_prot_df = comp_lig_df.merge(prot_meta_df, how='inner', on='protein_id')

# Keep only the label and the two raw inputs.
affinity_columns = ['ba_value', 'smiles', 'fasta']
c_p_df = comp_lig_prot_df[affinity_columns]
c_p_df.head()

Unnamed: 0,ba_value,smiles,fasta
0,0.4,CC(=O)NC(CCC(=O)O)C(=O)O,GFSATRSTVIQLLNNISTKREVEQYLKYFTSVSQQQFAVIKVGGAI...
1,0.49,CC(=O)N1CCCC(C)C1,MVNPTVFFDIAVDGEPLGRVSFELFADKVPKTAENFRALSTGEKGF...
2,1.6,CCOC(=O)C(=O)N1CCCCC1,MVNPTVFFDIAVDGEPLGRVSFELFADKVPKTAENFRALSTGEKGF...
3,3.58,CCCN(Cc1ccc(N)cc1)C(=O)NCC(=O)OCC,MVNPTVFFDIAVDGEPLGRVSFELFADKVPKTAENFRALSTGEKGF...
4,3.74,CCOC(=O)CNC(=O)N(Cc1ccc(N)cc1)Cc1nnn(C)n1,MVNPTVFFDIAVDGEPLGRVSFELFADKVPKTAENFRALSTGEKGF...


---
### 1) SMILES to graph

In [193]:
# Atom (node) feature layout:
# Element one-hot over ELEM_LIST : 63 entries
# Degree : [0,1,2,3,4,5] - 6 entries
# ExplicitValence : [1,2,3,4,5,6] - 6 entries
# ImplicitValence : [0,1,2,3,4,5] - 6 entries
# Aromatic : [0 or 1] - 1 entry
# Total length : 82 entries
# NOTE(review): keep these vocabularies in sync with get_atom_feature.

# Element symbols used for the one-hot encoding; the trailing 'unknown'
# entry is the bucket for any symbol not listed here.
ELEM_LIST = [
    'C', 'N', 'O', 'S', 'F', 
    'Si', 'P', 'Cl', 'Br', 'Mg', 
    'Na', 'Ca', 'Fe', 'As', 'Al', 
    'I', 'B', 'V', 'K', 'Tl', 
    'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 
    'Co', 'Se', 'Ti', 'Zn', 'H', 
    'Li', 'Ge', 'Cu', 'Au', 'Ni', 
    'Cd', 'In', 'Mn', 'Zr', 'Cr', 
    'Pt', 'Hg', 'Pb', 'W', 'Ru', 
    'Nb', 'Re', 'Te', 'Rh', 'Tc', 
    'Ba', 'Bi', 'Hf', 'Mo', 'U', 
    'Sm', 'Os', 'Ir', 'Ce', 'Gd',
    'Ga','Cs', 'unknown'
]


In [194]:
# For get_atom_feature
def one_of_k_encoding(x, vocab:list) -> list:
    """One-hot encode ``x`` against ``vocab``.

    A value that is not in ``vocab`` is replaced by the last vocabulary
    entry before encoding, so the last slot doubles as the fallback bucket.

    Args:
        x: Value to encode.
        vocab: Allowed values, in output order.

    Returns:
        A list of 0/1 ints with the same length as ``vocab``.
    """
    target = x if x in vocab else vocab[-1]
    return [int(target == entry) for entry in vocab]

In [195]:
# For get_molecular_graph
def get_atom_feature(atom) -> list:
    """Build the per-atom (node) feature list.

    Layout, in order:
        * one-hot of the element symbol over ``ELEM_LIST``
        * one-hot of the atom degree over ``[0, 1, 2, 3, 4, 5]``
        * one-hot of the explicit valence over ``[1, 2, 3, 4, 5, 6]``
        * one-hot of the implicit valence over ``[0, 1, 2, 3, 4, 5]``
        * a single aromaticity flag

    Values outside a vocabulary fall into that vocabulary's last slot
    (see ``one_of_k_encoding``).

    Args:
        atom: Object exposing ``GetSymbol``, ``GetDegree``,
            ``GetExplicitValence``, ``GetImplicitValence`` and
            ``GetIsAromatic`` (an RDKit atom in this notebook).

    Returns:
        The concatenated feature list.
    """
    atom_feature = one_of_k_encoding(atom.GetSymbol(), ELEM_LIST)
    # Degree vocabulary starts at 0, matching the feature layout documented
    # next to ELEM_LIST. The previous [1..6] vocabulary had no slot for
    # degree 0, so atoms without neighbours were encoded as degree 6.
    atom_feature += one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
    atom_feature += one_of_k_encoding(atom.GetExplicitValence(), [1, 2, 3, 4, 5, 6])
    atom_feature += one_of_k_encoding(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5])
    atom_feature += [atom.GetIsAromatic()]

    return atom_feature

In [196]:
# For get_molecular_graph
def get_bond_feature(bond) -> list:
    """Build the six-element feature list for one bond.

    The first four entries flag whether the bond type is SINGLE, DOUBLE,
    TRIPLE or AROMATIC (in that order); the last two are the values of
    ``bond.GetIsConjugated()`` and ``bond.IsInRing()``.

    Args:
        bond: Object exposing ``GetBondType``, ``GetIsConjugated`` and
            ``IsInRing`` (an RDKit bond in this notebook).

    Returns:
        A list of six flags.
    """
    kind = bond.GetBondType()
    bond_types = Chem.rdchem.BondType
    type_flags = [
        kind == candidate
        for candidate in (
            bond_types.SINGLE,
            bond_types.DOUBLE,
            bond_types.TRIPLE,
            bond_types.AROMATIC,
        )
    ]
    return type_flags + [bond.GetIsConjugated(), bond.IsInRing()]

In [197]:
def get_molecular_graph(smi: str) -> dgl.DGLGraph:
    """Convert a SMILES string into a DGL graph with atom and bond features.

    Every bond contributes two directed edges (begin->end and end->begin)
    that carry the same bond feature list.

    Args:
        smi: SMILES string of the molecule.

    Returns:
        A graph with one node per atom, node features under ``ndata['h']``
        and edge features under ``edata['e_ij']`` (both float64).

    Raises:
        ValueError: If RDKit cannot parse ``smi``.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail here
        # with the offending input instead of an AttributeError further down.
        raise ValueError(f"Could not parse SMILES string: {smi!r}")

    graph = dgl.DGLGraph()

    # graph : num_nodes
    atoms_mol = mol.GetAtoms()
    num_atoms = len(atoms_mol)
    graph.add_nodes(num_atoms)

    # graph : ndata_schemes
    atom_feature_ls = [get_atom_feature(atom) for atom in atoms_mol]
    graph.ndata['h'] = torch.tensor(atom_feature_ls, dtype=torch.float64)

    # graph : num_edges + edata_schemes
    # Collect endpoints first so that all edges are added in a single call;
    # the (src, dst), (dst, src) ordering per bond is kept.
    src_ls, dst_ls, bond_feature_ls = [], [], []
    for bond in mol.GetBonds():
        bond_feature = get_bond_feature(bond)
        src = bond.GetBeginAtom().GetIdx()
        dst = bond.GetEndAtom().GetIdx()
        src_ls += [src, dst]
        dst_ls += [dst, src]
        bond_feature_ls += [bond_feature, bond_feature]

    if src_ls:
        graph.add_edges(src_ls, dst_ls)

    # Length of the list returned by get_bond_feature. The explicit reshape
    # keeps the feature tensor 2-D, shape (0, 6), for molecules without
    # bonds, so their graphs share the edge-feature scheme of all others.
    num_bond_features = 6
    bond_feature_ts = torch.tensor(bond_feature_ls, dtype=torch.float64)
    graph.edata['e_ij'] = bond_feature_ts.reshape(len(bond_feature_ls), num_bond_features)

    return graph

In [198]:
# Smoke test: build the graph for the ligand stored under index label 0.
get_molecular_graph(c_p_df['smiles'][0])



Graph(num_nodes=13, num_edges=24,
      ndata_schemes={'h': Scheme(shape=(82,), dtype=torch.float64)}
      edata_schemes={'e_ij': Scheme(shape=(6,), dtype=torch.float64)})

---
### 2) Protein to sequence

In [199]:
# One-letter codes of the 20 residues that are encoded individually.
aa_list = list('ACDEFGHIKLMNPQRSTVWY')

In [200]:
# Each previously unseen key receives the next free integer id on first lookup.
word_dict = defaultdict(lambda: len(word_dict))

sequence = c_p_df.fasta[0].upper()
word_ls = list(sequence)
# Residues outside aa_list all share the id of the placeholder key 'X'.
output = [word_dict[word if word in aa_list else 'X'] for word in word_ls]
np.array(output,np.int32)

array([ 0,  1,  2,  3,  4,  5,  2,  4,  6,  7,  8,  9,  9, 10, 10,  7,  2,
        4, 11,  5, 12,  6, 12,  8, 13,  9, 11, 13,  1,  4,  2,  6,  2,  8,
        8,  8,  1,  3,  6,  7, 11,  6,  0,  0,  3,  7,  7,  2, 14, 10,  9,
       15, 12,  9,  3,  2, 16,  9,  3,  1,  9, 13, 15,  6,  0,  9, 13, 17,
        7,  6,  9, 15,  0,  4,  0, 17,  8,  6, 10,  0,  5,  9, 12,  3,  8,
        0,  7, 12, 17, 14, 13,  7, 14,  0,  7,  5,  7,  4, 14, 12, 15,  4,
       18,  3,  6,  6,  5, 11, 16,  1,  9, 12,  8, 10,  9, 11,  9,  6,  4,
        3,  9, 12,  8,  9,  0,  6,  5,  3,  5, 17,  7,  4,  2,  0,  6,  1,
        4,  3, 14, 13,  9, 14, 11, 14, 11, 13, 11,  9,  6,  0, 10,  7, 11,
        2,  6,  4, 11, 12, 17,  7, 12,  3,  2,  7, 11,  3,  0,  3,  9, 17,
        7,  9,  4,  2,  9,  3, 12,  4,  3,  2,  0,  8, 18,  9, 10,  6, 10,
        3, 14,  6,  3,  3,  0, 12,  9,  3,  5,  6,  1, 12, 17,  9, 11,  7,
        6, 13,  9, 10, 12, 11,  0,  0,  7,  7, 10,  0,  2,  4,  0, 12, 11,
        7,  2, 18,  7, 10

In [201]:
def load_blosum62(blosum_path = '/hdd3/seungheun/dti_study/blosum62.txt') -> dict:
    """Read a whitespace-separated BLOSUM matrix file into a dictionary.

    Skipped lines:
        * blank lines,
        * comment lines whose first non-blank character is ``#``,
        * lines that start with whitespace (the column-header row).

    For every remaining line, the first token is the row key and the
    remaining tokens are converted to floats.

    Args:
        blosum_path: Path of the matrix text file.

    Returns:
        Dict mapping each row key to a 1-D float ``numpy`` array.
    """
    blosum_dict = {}
    with open(blosum_path, 'r') as fr:
        for line in fr:
            stripped = line.strip()
            # Blank and '#' lines previously crashed the parser
            # (IndexError / ValueError); they carry no matrix data.
            if not stripped or stripped.startswith('#'):
                continue
            # Header row: begins with whitespace because it has no row key.
            if line.startswith((' ', '\t')):
                continue
            parsed = stripped.split()
            blosum_dict[parsed[0]] = np.array(parsed[1:]).astype(float)

    return blosum_dict

In [202]:
# Load the substitution matrix from load_blosum62's default path.
blosum_dict = load_blosum62()

---
### 3) collate_function

In [203]:
def my_collate_fn(batch) -> tuple:
    """Collate a list of samples into one batched graph and a label tensor.

    Each item is indexed as ``item[0]`` (molecule) and ``item[1]`` (binding
    affinity value). The molecule may be a SMILES string, which is converted
    with ``get_molecular_graph``, or an already built ``dgl.DGLGraph`` such
    as the one returned by ``MyDataset.__getitem__``, which is used as is.

    Args:
        batch: Iterable of ``(molecule, ba_value)`` items.

    Returns:
        ``(batched_graph, ba_values)`` where ``batched_graph`` is the result
        of ``dgl.batch`` and ``ba_values`` is a float64 tensor.
    """
    graph_ls, ba_value_ls = [], []

    for item in batch:
        molecule = item[0]
        ba_value = item[1]
        # Accept pre-built graphs so this function can be used directly as
        # the collate_fn of a DataLoader over MyDataset.
        if isinstance(molecule, dgl.DGLGraph):
            graph = molecule
        else:
            graph = get_molecular_graph(molecule)
        graph_ls.append(graph)
        ba_value_ls.append(ba_value)

    batched_graph = dgl.batch(graph_ls)
    ba_value_ts = torch.tensor(ba_value_ls, dtype=torch.float64)

    return batched_graph, ba_value_ts

---
### 4) Dataset class

In [204]:
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(molecular graph, binding affinity)``.

    Args:
        splitted_set: Mapping-like object (for example a DataFrame or a
            dict of sequences) with ``'smiles'``, ``'fasta'`` and
            ``'ba_value'`` entries of equal length.
    """

    def __init__(self, splitted_set):
        # Materialise the columns as plain lists so that __getitem__ indexes
        # by position; a pandas Series from a train/test split keeps its
        # original, possibly non-contiguous, index labels.
        self._smi_ls = list(splitted_set['smiles'])
        self._prot_ls = list(splitted_set['fasta'])
        self._ba_value_ls = list(splitted_set['ba_value'])

    def __len__(self):
        # The constructor argument is not stored, so the length comes from
        # the stored SMILES list.
        return len(self._smi_ls)

    def __getitem__(self, idx):
        """Return ``(graph, ba_value)`` for the sample at position ``idx``."""
        c_graph = get_molecular_graph(self._smi_ls[idx])
        y = self._ba_value_ls[idx]
        return c_graph, y

---
### Scratch: collate function for a DataFrame batch
---

In [205]:
def my_collate_fn_heun(batch) -> tuple:
    """Collate a DataFrame slice into one batched graph and a label tensor.

    Args:
        batch: DataFrame whose first column holds SMILES strings and whose
            second column holds binding affinity values.

    Returns:
        ``(batched_graph, ba_values)`` where ``batched_graph`` is the result
        of ``dgl.batch`` and ``ba_values`` is a float64 tensor.
    """
    graph_ls, ba_value_ls = [], []

    # Read both columns by position with .iloc. Integer keys such as row[0]
    # on a label-indexed row Series rely on a positional fallback that
    # pandas has deprecated.
    for smiles, ba_value in zip(batch.iloc[:, 0], batch.iloc[:, 1]):
        graph_ls.append(get_molecular_graph(smiles))
        ba_value_ls.append(ba_value)
    graph_graph = dgl.batch(graph_ls)
    ba_value_ts = torch.tensor(ba_value_ls, dtype=torch.float64)

    return graph_graph, ba_value_ts

# First three rows, with SMILES in column 0 and the affinity value in column 1.
batch = c_p_df.loc[:, ['smiles', 'ba_value']].iloc[:3]
my_collate_fn_heun(batch)



(Graph(num_nodes=36, num_edges=70,
       ndata_schemes={'h': Scheme(shape=(82,), dtype=torch.float64)}
       edata_schemes={'e_ij': Scheme(shape=(6,), dtype=torch.float64)}),
 tensor([0.4000, 0.4900, 1.6000], dtype=torch.float64))