# Seq PDB map try

In [1]:
import pandas as pd

df_davis = pd.read_csv('./davis_seq_pdb.csv')
df_kiba = pd.read_csv('./kiba_seq_pdb.csv')

In [9]:
df_davis[df_davis['target_pdb_id'] == 'TOFOLD']['target_sequence'].to_list()[0]

'QLTPTNSLKRGGAHHRRCEVALLGCGAVLAATGLGFDLLEAGKCQLLPLEEPEPPAREEKKRREGLFQRSSRPRRSTSPPSRKLFKKEEPMLLLGDPSASLTLLSLSSISECNSTRSLLRSDSDEIVVYEMPVSPVEAPPLSPCTHNPLVNVRVERFKRDPNQSLTPTHVTLTTPSQPSSHRRTPSDGALKPETLLASRSPSSNGLSPSPGAGMLKTPSPSRDPGEFPRLPDPNVVFPPTPRRWNTQQDSTLERPKTLEFLPRPRPSANRQRLDPWWFVSPSHARSTSPANSSSTETPSNLDSCFASSSSTVEERPGLPALLPFQAGPLPPTERTLLDLDAEGQSQDSTVPLCRAELNTHRPAPYEIQQEFWS'

In [2]:
for pdb_id in df_davis['target_pdb_id'].unique():
    if pdb_id[: 4].lower() in ["6g5i", "5t0j", "5wve"]:
        print(pdb_id)

6G5I_36


In [3]:
for pdb_id in df_kiba['target_pdb_id'].unique():
    if pdb_id[: 4].lower() in ["6g5i", "5t0j", "5wve"]:
        print(pdb_id)

In [4]:
len(df_davis['target_pdb_id'].unique())

298

In [6]:
len(set(pdb_id[: 4].lower() for pdb_id in df_davis['target_pdb_id'].unique()))

298

In [37]:
import glob

for fn in glob.glob('./data/**/*.csv', recursive=True):
    df = pd.read_csv(fn)
    df = df.rename(columns={col: col.strip() for col in df.columns})
    df.to_csv(fn, index=False)

# Prot. -> LLM (FusionDTA)

In [1]:
import pandas as pd
import numpy as np
import os
import json,pickle
from collections import OrderedDict
import re
import csv
import esm
import torch
from rdkit import Chem
import pypdb

In [2]:
def generate_protein_pretraining_representation(dataset_name, prots):
    data_dict = {}
    prots_tuple = [(str(i), prots[i][:1022]) for i in range(len(prots))]
    model, alphabet = esm.pretrained.esm2_t36_3B_UR50D()
    batch_converter = alphabet.get_batch_converter()
    i = 0
    batch = 1
    
    while (batch*i) < len(prots):
        print('converting protein batch: '+ str(i))
        if (i + batch) < len(prots):
            pt = prots_tuple[batch*i:batch*(i+1)]
        else:
            pt = prots_tuple[batch*i:]
        
        batch_labels, batch_strs, batch_tokens = batch_converter(pt)
        #model = model.cuda()
        #batch_tokens = batch_tokens.cuda()
        
        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[36], return_contacts=True)
        token_representations = results["representations"][36].numpy()
        data_dict[i] = token_representations
        i += 1
    np.savez(dataset_name + '.npz', dict=data_dict)

In [3]:
seq = 'MSPLNQSAEGLPQEASNRSLNATETSEAWDPRTLQALKISLAVVLSVITLATVLSNAFVLTTILLTRKLHTPANYLIGSLATTDLLVSILVMPISIAYTITHTWNFGQILCDIWLSSDITCCTASILHLCVIALDRYWAITDALEYSKRRTAGHAATMIAIVWAISICISIPPLFWRQAKAQEEMSDCLVNTSQISYTIYSTCGAFYIPSVLLIILYGRIYRAARNRILNPPSLYGKRFTTAHLITGSAGSSLCSLNSSLHEGHSHSAGSPLFFNHVKIKLADSALERKRISAARERKATKILGIILGAFIICWLPFFVVSLVLPICRDSCWIHPALFDFFTWLGYLNSLINPIIYTVFNEEFRQAFQKIVPFRKAS'
q = pypdb.Query(seq,
          query_type="sequence", 
          return_type="polymer_entity")
None is None

True

In [4]:
def query_pdb_id(seq):
    q = pypdb.Query(seq,
              query_type="sequence", 
              return_type="polymer_entity")
    res = q.search()
    
    if res is not None:
        return res['result_set'][0]['identifier']
    else:
        return 'TOFOLD'

In [5]:
from tqdm import tqdm

ESM = True
datasets = ['davis','kiba']
for dataset in datasets:
    fpath = 'data/' + dataset + '/raw/'
    train_valid_folds = json.load(open(fpath + "folds/train_fold_setting1.txt"))
    test_fold = json.load(open(fpath + "folds/test_fold_setting1.txt"))
    valid_ids = [0, 1, 2, 3, 4]
    valid_folds = [train_valid_folds[vid] for vid in valid_ids]
    train_folds = []
    for valid_id in valid_ids:
        temp = []
        for idx in range(5):
            if idx != valid_id:
                temp += train_valid_folds[idx]
        train_folds.append(temp)
    
    ligands = json.load(open(fpath + "ligands_can.txt"), object_pairs_hook=OrderedDict)
    proteins = json.load(open(fpath + "proteins.txt"), object_pairs_hook=OrderedDict)
    affinity = pickle.load(open(fpath + "Y","rb"), encoding='latin1')
    drugs = []
    drug_smiles = []
    prot_seqs = []
    prot_pdb_ids = []
    for d in tqdm(ligands.keys()):
        #lg = ligands[d]
        lg = Chem.MolToSmiles(Chem.MolFromSmiles(ligands[d]),isomericSmiles=True)
        drugs.append(lg)
        drug_smiles.append(ligands[d])
    for t in tqdm(proteins.keys()):
        prot_seqs.append(proteins[t])
        prot_pdb_ids.append(query_pdb_id(proteins[t]))
    if dataset == 'davis':
        affinity = [-np.log10(y/1e9) for y in affinity]

    # protein pretraing presentation
    if ESM:
        generate_protein_pretraining_representation(dataset, prot_seqs)

    affinity = np.asarray(affinity)
    opts = ['train','valid']

    print('generating test data')
    rows, cols = np.where(np.isnan(affinity) == False)  
    test_rows, test_cols = rows[test_fold], cols[test_fold]
    with open('data/' + dataset + '/' + dataset  + '_test.csv', 'w') as f:
        f.write('compound_iso_smiles, target_sequence, target_pdb_id, affinity, protein_id, drug_id\n')
        for pair_ind in range(len(test_rows)):
            ls = []
            ls += [ drugs[test_rows[pair_ind]] ]
            ls += [ prot_seqs[test_cols[pair_ind]] ]
            ls += [ prot_pdb_ids[test_cols[pair_ind]] ]
            ls += [ affinity[test_rows[pair_ind], test_cols[pair_ind]] ]
            ls += [ test_cols[pair_ind] ]
            ls += [ test_rows[pair_ind] ]
            f.write(','.join(map(str, ls)) + '\n')

    for i in range(5):
        train_fold = train_folds[i]
        valid_fold = valid_folds[i]
        for opt in opts:
            rows, cols = np.where(np.isnan(affinity) == False)  
            if opt == 'train':
                rows, cols = rows[train_fold], cols[train_fold]
                
                #generating cold data
                with open('data/' + dataset + '_cold' + '.csv', 'w') as f:
                    f.write('compound_iso_smiles, target_sequence, affinity, protein_id, drug_id\n')
                    for pair_ind in range(len(rows)):
                        ls = []
                        ls += [ drugs[rows[pair_ind]]  ]
                        ls += [ prots[cols[pair_ind]]  ]
                        ls += [ affinity[rows[pair_ind],cols[pair_ind]]  ]
                        ls += [ cols[pair_ind] ]
                        ls += [ rows[pair_ind] ]
                        f.write(','.join(map(str,ls)) + '\n') 
            elif opt == 'valid':
                rows, cols = rows[valid_fold], cols[valid_fold]
                
                #generating cold data
                with open('data/' + dataset + '_cold' + '.csv', 'a') as f:
                    for pair_ind in range(len(rows)):
                        ls = []
                        ls += [ drugs[rows[pair_ind]]  ]
                        ls += [ prots[cols[pair_ind]]  ]
                        ls += [ affinity[rows[pair_ind],cols[pair_ind]]  ]
                        ls += [ cols[pair_ind] ]
                        ls += [ rows[pair_ind] ]
                        f.write(','.join(map(str,ls)) + '\n') 
                      
            #5-fold data
            print('generating 5-fold data')
            with open('data/' + dataset + '/' + dataset + '_' + opt + '_fold_' + str(i) + '.csv', 'w') as f:
                f.write('compound_iso_smiles, target_sequence, target_pdb_id, affinity, protein_id, drug_id\n')
                for pair_ind in range(len(rows)):
                    ls = []
                    ls += [ drugs[rows[pair_ind]] ]
                    ls += [ prot_seqs[cols[pair_ind]] ]
                    ls += [ prot_pdb_ids[cols[pair_ind]] ]
                    ls += [ affinity[rows[pair_ind], cols[pair_ind]] ]
                    ls += [ cols[pair_ind] ]
                    ls += [ rows[pair_ind] ]
                    f.write(','.join(map(str,ls)) + '\n')

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 68/68 [00:00<00:00, 2862.29it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 442/442 [26:31<00:00,  3.60s/it]


ModuleNotFoundError: No module named 'fused_layer_norm_cuda'

# Prot. -> PDB file -> Hybrid Graphs (AttnSiteDTI)

In [1]:
import os
import pickle
from collections import OrderedDict
import random
import glob

import pandas as pd
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer
import dgl
import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix
import networkx as nx
from Bio.PDB import *
import deepchem
import pickle

2023-02-02 15:04:47.591148: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-02 15:04:48.298870: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-02-02 15:04:52.212567: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/luozc/CCMpred/lib:/home/luozc/miniconda3/lib:/usr/lib/x86_64-linux-gnu/:/home/luo

In [2]:
def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set))
    # print list((map(lambda s: x == s, allowable_set)))
    return list(map(lambda s: x == s, allowable_set))


def one_of_k_encoding_unk(x, allowable_set):
    """Maps inputs not in the allowable set to the last element."""
    if x not in allowable_set:
        x = allowable_set[-1]
    return list(map(lambda s: x == s, allowable_set))


def atom_feature(atom):
    return np.array(one_of_k_encoding_unk(atom.GetSymbol(),
                                          ['C', 'N', 'O', 'S', 'F', 'P', 'Cl', 'Br', 'B', 'H']) +
                    one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8]) +
                    one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]) +
                    one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5]) +
                    [atom.GetIsAromatic()])  # (10, 9, 5, 6, 1) --> total 31


def get_atom_feature(m):
    H = []
    for i in range(len(m)):
        H.append(atom_feature(m[i][0]))
    H = np.array(H)

    return H

pk = deepchem.dock.ConvexHullPocketFinder()
def process_protein(pdb_file):
    m = Chem.MolFromPDBFile(pdb_file)
    am = GetAdjacencyMatrix(m) # n2 x n2
    pockets = pk.find_pockets(pdb_file)
    n2 = m.GetNumAtoms()
    c2 = m.GetConformers()[0] # Tid
    d2 = np.array(c2.GetPositions()) # n2 x 3
    binding_parts = []
    not_in_binding = [i for i in range(0, n2)]
    constructed_graphs = []
    for bound_box in pockets:
        x_min = bound_box.x_range[0]
        x_max = bound_box.x_range[1]
        y_min = bound_box.y_range[0]
        y_max = bound_box.y_range[1]
        z_min = bound_box.z_range[0]
        z_max = bound_box.z_range[1]
        binding_parts_atoms = []
        idxs = []
        for idx, atom_cord in enumerate(d2):
            if x_min < atom_cord[0] < x_max and y_min < atom_cord[1] < y_max and z_min < atom_cord[2] < z_max:
                binding_parts_atoms.append((m.GetAtoms()[idx], atom_cord))
                idxs.append(idx)
                if idx in not_in_binding:
                    not_in_binding.remove(idx)

        ami = am[np.array(idxs)[:, None], np.array(idxs)] # len(idxs) x len(idxs)
        H = get_atom_feature(binding_parts_atoms)
        g = nx.convert_matrix.from_numpy_matrix(ami)
        graph = dgl.from_networkx(g)
        graph.ndata['h'] = torch.Tensor(H)
        graph = dgl.add_self_loop(graph)
        constructed_graphs.append(graph)
        binding_parts.append(binding_parts_atoms)

    constructed_graphs = dgl.batch(constructed_graphs)

    return binding_parts, not_in_binding, constructed_graphs

In [3]:
random.seed(42)

node_featurizer = CanonicalAtomFeaturizer(atom_data_field='h') # Tid

zero = np.eye(2)[1] # [0, 1]
one = np.eye(2)[0] # [1, 0]

In [None]:
df = pd.read_csv("humanSeqPdb") # 1787 x 2
print(len(df['pdb_id'].unique())) # 1662

with open("human_data.txt", 'r') as fp:
    train_raw = fp.read()

constructed_graphs = ""
raw_data = train_raw.split("\n") # 6729 list
random.shuffle(raw_data)
raw_data_train = raw_data[0: int(len(raw_data)*0.8)] # 5383
raw_data_valid = raw_data[int(len(raw_data)*0.8): int(len(raw_data)*0.9)] # 673
raw_data_test = raw_data[int(len(raw_data)*0.9): int(len(raw_data))] # 673
del raw_data

p_graphs = {}

save_set = []
i = 1
# for item in raw_data_train:
#     print(i)
#     i += 1
#     try:
#         a = item.split()
#         smile = a[0]
#         sequence = a[1]
#         pdb_code = df.loc[df["sequence"] == sequence]["pdb_id"].item()[:-1] # Tid: -1 means omit the info of chain
#         if pdb_code != "6g5i" and pdb_code != "5t0j" and pdb_code != "5wve":
#             if pdb_code not in p_graphs.keys():
#                 pdbl = PDBList()
#                 pdbl.retrieve_pdb_file(
#                     pdb_code, pdir='./pdbs/', overwrite=True, file_format="pdb"
#                 )
#                 # Rename file to .pdb from .ent
#                 os.rename(
#                     './pdbs/' + "pdb" + pdb_code + ".ent", './pdbs/' + pdb_code + ".pdb"
#                 )
#                 # Assert file has been downloaded
#                 assert any(pdb_code in s for s in os.listdir('./pdbs/'))
#                 #print(f"Downloaded PDB file for: {pdb_code}")
#                 _, _, constructed_graphs = process_protein(f"./pdbs/{pdb_code}.pdb")

#                 p_graphs[pdb_code] = constructed_graphs
#             else:
#                 constructed_graphs = p_graphs[pdb_code]

#             g = smiles_to_bigraph(smile, node_featurizer=node_featurizer)
#             g = dgl.add_self_loop(g)
#             if a[2] == "1":
#                 save_set.append(((constructed_graphs, g), one))
#             else:
#                 save_set.append(((constructed_graphs, g), zero)) # FIXME
#     except Exception as e:
#         print(e)
#         continue


# with open(f'human_part_train.pkl', 'wb') as f:
#     pickle.dump(save_set, f)

In [13]:
p_graphs = {}

save_set = []
item = 'CC[C@@]1(C[C@@H]2C3=CC(=C(C=C3CCN2C[C@H]1CC(C)C)OC)OC)O MSPLNQSAEGLPQEASNRSLNATETSEAWDPRTLQALKISLAVVLSVITLATVLSNAFVLTTILLTRKLHTPANYLIGSLATTDLLVSILVMPISIAYTITHTWNFGQILCDIWLSSDITCCTASILHLCVIALDRYWAITDALEYSKRRTAGHAATMIAIVWAISICISIPPLFWRQAKAQEEMSDCLVNTSQISYTIYSTCGAFYIPSVLLIILYGRIYRAARNRILNPPSLYGKRFTTAHLITGSAGSSLCSLNSSLHEGHSHSAGSPLFFNHVKIKLADSALERKRISAARERKATKILGIILGAFIICWLPFFVVSLVLPICRDSCWIHPALFDFFTWLGYLNSLINPIIYTVFNEEFRQAFQKIVPFRKAS 0'

a = item.split()
smile = a[0]
sequence = a[1]
pdb_code = '4iarA'[: 4]
if pdb_code != "6g5i" and pdb_code != "5t0j" and pdb_code != "5wve":
    if pdb_code not in p_graphs.keys():
        pdbl = PDBList()
        pdbl.retrieve_pdb_file(
            pdb_code, pdir='./pdbs/', overwrite=True, file_format="pdb"
        )
        # Rename file to .pdb from .ent
        os.rename(
            './pdbs/' + "pdb" + pdb_code + ".ent", './pdbs/' + pdb_code + ".pdb"
        )
        # Assert file has been downloaded
        assert any(pdb_code in s for s in os.listdir('./pdbs/'))
        #print(f"Downloaded PDB file for: {pdb_code}")
        _, _, constructed_graphs = process_protein(f"./pdbs/{pdb_code}.pdb")

        p_graphs[pdb_code] = constructed_graphs
    else:
        constructed_graphs = p_graphs[pdb_code]

    g = smiles_to_bigraph(smile, node_featurizer=node_featurizer)
    g = dgl.add_self_loop(g)
    if a[2] == "1":
        save_set.append(((constructed_graphs, g), one))
    else:
        save_set.append(((constructed_graphs, g), zero))

Downloading PDB structure '4iar'...


In [14]:
p_graphs

{'4iar': Graph(num_nodes=14086, num_edges=41200,
       ndata_schemes={'h': Scheme(shape=(31,), dtype=torch.float32)}
       edata_schemes={})}

In [15]:
save_set

[((Graph(num_nodes=14086, num_edges=41200,
         ndata_schemes={'h': Scheme(shape=(31,), dtype=torch.float32)}
         edata_schemes={}),
   Graph(num_nodes=25, num_edges=79,
         ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
         edata_schemes={})),
  array([0., 1.]))]

In [47]:
pdb_ids = ['6T29_1', '6TLU_1', '6QX9_43']

In [49]:
from Bio.PDB import *

for pdb_id in pdb_ids:
    pdb_code = pdb_id[: 4].lower()
    pdbl = PDBList()
    pdbl.retrieve_pdb_file(
        pdb_code, pdir='./pdbs/', overwrite=False, file_format="mmCif"
    )
    # # Rename file to .pdb from .ent
    # os.rename(
    #     './pdbs/' + "pdb" + pdb_code + ".ent", './pdbs/' + pdb_code + ".pdb"
    # )

Downloading PDB structure '6t29'...
Downloading PDB structure '6tlu'...
Downloading PDB structure '6qx9'...


In [12]:
for fn in os.listdir('./pdbs'):
    if fn.endswith('.cif'):
        print('./pdbs/' + fn)

./pdbs/6t29.cif
./pdbs/6tlu.cif
./pdbs/AF-P80192-F1-model_v4.cif
./pdbs/6qx9.cif


In [13]:
!python cif2pdb.py ./pdbs/AF-P80192-F1-model_v4.cif ./pdbs/AF-P80192-F1-model_v4.pdb

In [14]:
!python cif2pdb.py ./pdbs/6t29.cif ./pdbs/6t29.pdb

In [15]:
!python cif2pdb.py ./pdbs/6tlu.cif ./pdbs/6tlu.pdb

In [None]:
!python cif2pdb.py ./pdbs/6qx9.cif ./pdbs/6qx9.pdb

ERROR: Too many chains to represent in PDB format


In [59]:
for fn in os.listdir('./pdbs'):
    if any(pdb_id[: 4].lower() in fn for pdb_id in pdb_ids):
        print('./pdbs/' + fn)

NameError: name 'pdb_ids' is not defined

In [61]:
from Bio.PDB.MMCIFParser import MMCIFParser

ciffile = './pdbs/6qx9.cif'
strucid = ciffile[:4] if len(ciffile)>4 else "1xxx"
parser = MMCIFParser()
structure = parser.get_structure(strucid, ciffile)
list(structure.get_chains())



[<Chain id=1>,
 <Chain id=6>,
 <Chain id=5O>,
 <Chain id=B4>,
 <Chain id=13>,
 <Chain id=4B>,
 <Chain id=5e>,
 <Chain id=I>,
 <Chain id=1K>,
 <Chain id=4C>,
 <Chain id=41>,
 <Chain id=R>,
 <Chain id=1f>,
 <Chain id=4e>,
 <Chain id=66>,
 <Chain id=X>,
 <Chain id=22>,
 <Chain id=5>,
 <Chain id=67>,
 <Chain id=62>,
 <Chain id=2B>,
 <Chain id=53>,
 <Chain id=A2>,
 <Chain id=B2>,
 <Chain id=2f>,
 <Chain id=5C>,
 <Chain id=5X>,
 <Chain id=11>,
 <Chain id=42>,
 <Chain id=1b>,
 <Chain id=B5>,
 <Chain id=1A>,
 <Chain id=S>,
 <Chain id=5f>,
 <Chain id=5J>,
 <Chain id=51>,
 <Chain id=4D>,
 <Chain id=63>,
 <Chain id=2b>,
 <Chain id=2>,
 <Chain id=4f>,
 <Chain id=B3>,
 <Chain id=1g>,
 <Chain id=23>,
 <Chain id=5b>,
 <Chain id=68>,
 <Chain id=43>,
 <Chain id=1e>,
 <Chain id=5A>,
 <Chain id=A3>,
 <Chain id=U>,
 <Chain id=2g>,
 <Chain id=5D>,
 <Chain id=52>,
 <Chain id=12>,
 <Chain id=64>,
 <Chain id=2e>,
 <Chain id=BP>,
 <Chain id=1C>,
 <Chain id=5g>,
 <Chain id=K>,
 <Chain id=4b>,
 <Chain id=4A>,
 <

In [62]:
len(list(structure.get_chains()))

71

In [67]:
with open('./davis_1_site_msg.pkl', 'rb') as f:
    msg = pickle.load(f)
len(msg), msg

(8,
 ['7 PDB id: 3MY0_1 ',
  "27 PDB id: 7MFE_1 Python argument types in\n    rdkit.Chem.rdmolops.GetAdjacencyMatrix(NoneType)\ndid not match C++ signature:\n    GetAdjacencyMatrix(RDKit::ROMol {lvalue} mol, bool useBO=False, int emptyVal=0, bool force=False, char const* prefix='')",
  "34 PDB id: 6W4P_1 Python argument types in\n    rdkit.Chem.rdmolops.GetAdjacencyMatrix(NoneType)\ndid not match C++ signature:\n    GetAdjacencyMatrix(RDKit::ROMol {lvalue} mol, bool useBO=False, int emptyVal=0, bool force=False, char const* prefix='')",
  "46 PDB id: 1UA2_1 Python argument types in\n    rdkit.Chem.rdmolops.GetAdjacencyMatrix(NoneType)\ndid not match C++ signature:\n    GetAdjacencyMatrix(RDKit::ROMol {lvalue} mol, bool useBO=False, int emptyVal=0, bool force=False, char const* prefix='')",
  "67 PDB id: 5OOI_1 Python argument types in\n    rdkit.Chem.rdmolops.GetAdjacencyMatrix(NoneType)\ndid not match C++ signature:\n    GetAdjacencyMatrix(RDKit::ROMol {lvalue} mol, bool useBO=False, 

In [66]:
with open('./davis_2_site_msg.pkl', 'rb') as f:
    msg = pickle.load(f)
len(msg), msg

(9,
 ["8 PDB id: 5T89_2 Python argument types in\n    rdkit.Chem.rdmolops.GetAdjacencyMatrix(NoneType)\ndid not match C++ signature:\n    GetAdjacencyMatrix(RDKit::ROMol {lvalue} mol, bool useBO=False, int emptyVal=0, bool force=False, char const* prefix='')",
  "9 PDB id: 7QDP_2 Python argument types in\n    rdkit.Chem.rdmolops.GetAdjacencyMatrix(NoneType)\ndid not match C++ signature:\n    GetAdjacencyMatrix(RDKit::ROMol {lvalue} mol, bool useBO=False, int emptyVal=0, bool force=False, char const* prefix='')",
  "34 PDB id: 4Z32_1 Python argument types in\n    rdkit.Chem.rdmolops.GetAdjacencyMatrix(NoneType)\ndid not match C++ signature:\n    GetAdjacencyMatrix(RDKit::ROMol {lvalue} mol, bool useBO=False, int emptyVal=0, bool force=False, char const* prefix='')",
  "38 PDB id: 5NCL_1 Python argument types in\n    rdkit.Chem.rdmolops.GetAdjacencyMatrix(NoneType)\ndid not match C++ signature:\n    GetAdjacencyMatrix(RDKit::ROMol {lvalue} mol, bool useBO=False, int emptyVal=0, bool forc

In [3]:
df = df_davis.copy()

In [4]:
import os
import json
from collections import OrderedDict
import re
import csv
import pypdb
from tqdm import tqdm
import glob

import pandas as pd
import dgl
import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix
import networkx as nx
from Bio.PDB import *
import deepchem
import pickle
import time

def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set))
    # print list((map(lambda s: x == s, allowable_set)))
    return list(map(lambda s: x == s, allowable_set))


def one_of_k_encoding_unk(x, allowable_set):
    """Maps inputs not in the allowable set to the last element."""
    if x not in allowable_set:
        x = allowable_set[-1]
    return list(map(lambda s: x == s, allowable_set))


def atom_feature(atom):
    return np.array(one_of_k_encoding_unk(atom.GetSymbol(),
                                          ['C', 'N', 'O', 'S', 'F', 'P', 'Cl', 'Br', 'B', 'H']) +
                    one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8]) +
                    one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]) +
                    one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5]) +
                    [atom.GetIsAromatic()])  # (10, 9, 5, 6, 1) --> total 31


def get_atom_feature(m):
    H = []
    for i in range(len(m)):
        H.append(atom_feature(m[i][0]))
    H = np.array(H)

    return H

pk = deepchem.dock.ConvexHullPocketFinder()
def process_protein(pdb_file):
    m = Chem.MolFromPDBFile(pdb_file)
    if m is None:
        return None, None, None
    am = GetAdjacencyMatrix(m) # n2 x n2
    pockets = pk.find_pockets(pdb_file)
    n2 = m.GetNumAtoms()
    c2 = m.GetConformers()[0] # Tid
    d2 = np.array(c2.GetPositions()) # n2 x 3
    binding_parts = []
    not_in_binding = [i for i in range(0, n2)]
    constructed_graphs = []
    for bound_box in pockets:
        x_min = bound_box.x_range[0]
        x_max = bound_box.x_range[1]
        y_min = bound_box.y_range[0]
        y_max = bound_box.y_range[1]
        z_min = bound_box.z_range[0]
        z_max = bound_box.z_range[1]
        binding_parts_atoms = []
        idxs = []
        for idx, atom_cord in enumerate(d2):
            if x_min < atom_cord[0] < x_max and y_min < atom_cord[1] < y_max and z_min < atom_cord[2] < z_max:
                binding_parts_atoms.append((m.GetAtoms()[idx], atom_cord))
                idxs.append(idx)
                if idx in not_in_binding:
                    not_in_binding.remove(idx)

        ami = am[np.array(idxs)[:, None], np.array(idxs)] # len(idxs) x len(idxs)
        H = get_atom_feature(binding_parts_atoms)
        g = nx.convert_matrix.from_numpy_matrix(ami)
        graph = dgl.from_networkx(g)
        graph.ndata['h'] = torch.Tensor(H)
        graph = dgl.add_self_loop(graph)
        constructed_graphs.append(graph)
        binding_parts.append(binding_parts_atoms)

    constructed_graphs = dgl.batch(constructed_graphs)

    return binding_parts, not_in_binding, constructed_graphs

2023-02-06 09:04:34.143489: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-06 09:04:34.867119: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
  np.bool8: (False, True),
2023-02-06 09:04:38.752193: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/luozc/CCMpred/lib:/home/luozc/miniconda3/lib:/usr/lib/

In [5]:
df['target_pdb_id'].unique()[239: 244]

array(['6GL7_3', '6ZXF_35', '6G51_23', '6G5I_36', '4NEU_1'], dtype=object)

In [10]:
with open('./davis_site_msg.pkl', 'rb') as f:
    msg = pickle.load(f)
len(msg), msg

(3,
 ['7 PDB id: 3MY0_1 ',
  '176 PDB id: 5ZCS_1 ',
  '236 PDB id: 6QX9_43 Bad input file ./pdbs/6qx9.pdb'])

In [11]:
with open('./kiba_site_msg.pkl', 'rb') as f:
    msg = pickle.load(f)
len(msg), msg

(1, ['81 PDB id: 5ZCS_1 '])

In [22]:
dataset = 'davis'
df_davis[df_davis['target_pdb_id'] == 'TOFOLD']['target_sequence'].unique()

array(['QLTPTNSLKRGGAHHRRCEVALLGCGAVLAATGLGFDLLEAGKCQLLPLEEPEPPAREEKKRREGLFQRSSRPRRSTSPPSRKLFKKEEPMLLLGDPSASLTLLSLSSISECNSTRSLLRSDSDEIVVYEMPVSPVEAPPLSPCTHNPLVNVRVERFKRDPNQSLTPTHVTLTTPSQPSSHRRTPSDGALKPETLLASRSPSSNGLSPSPGAGMLKTPSPSRDPGEFPRLPDPNVVFPPTPRRWNTQQDSTLERPKTLEFLPRPRPSANRQRLDPWWFVSPSHARSTSPANSSSTETPSNLDSCFASSSSTVEERPGLPALLPFQAGPLPPTERTLLDLDAEGQSQDSTVPLCRAELNTHRPAPYEIQQEFWS'],
      dtype=object)

In [23]:
prot_seq_AF_dict = {'QLTPTNSLKRGGAHHRRCEVALLGCGAVLAATGLGFDLLEAGKCQLLPLEEPEPPAREEKKRREGLFQRSSRPRRSTSPPSRKLFKKEEPMLLLGDPSASLTLLSLSSISECNSTRSLLRSDSDEIVVYEMPVSPVEAPPLSPCTHNPLVNVRVERFKRDPNQSLTPTHVTLTTPSQPSSHRRTPSDGALKPETLLASRSPSSNGLSPSPGAGMLKTPSPSRDPGEFPRLPDPNVVFPPTPRRWNTQQDSTLERPKTLEFLPRPRPSANRQRLDPWWFVSPSHARSTSPANSSSTETPSNLDSCFASSSSTVEERPGLPALLPFQAGPLPPTERTLLDLDAEGQSQDSTVPLCRAELNTHRPAPYEIQQEFWS':'AF-P80192-F1-model_v4'}

with open(dataset + '_seq_AF.csv', 'w') as f:
    f.write('target_sequence,target_AF_id\n')
    for seq, AF_id in prot_seq_AF_dict.items():
        ls = []
        ls += [ seq ]
        ls += [ AF_id ]
        f.write(','.join(map(str, ls)) + '\n')

In [25]:
df = pd.read_csv(dataset + '_seq_AF.csv')

with open(f'{dataset}_site.pkl', 'rb') as f:
    p_graphs = pickle.load(f)
print(len(p_graphs), type(p_graphs))

for AF_id in tqdm(df['target_AF_id'].unique()):
    try:
        AF_code = AF_id
        if AF_code not in p_graphs.keys():
            # Assert file has been downloaded
            assert any(AF_code in s for s in os.listdir('./pdbs/'))
            #print(f"Downloaded PDB file for: {AF_code}")
            _, _, constructed_graphs = process_protein(f"./pdbs/{AF_code}.pdb")

            p_graphs[AF_code] = constructed_graphs
    except Exception as e:
        print(f'{time.strftime("%Y-%m-%d|%H:%M:%S", time.localtime())}: AF_id={AF_id}')
        continue

with open(f'{dataset}_site.pkl', 'wb') as f:
    pickle.dump(p_graphs, f)

292 <class 'dict'>


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.73s/it]


In [26]:
with open(f'{dataset}_site.pkl', 'rb') as f:
    p_graphs = pickle.load(f)

In [28]:
cnt = 0
for val in p_graphs.values():
    if val is None:
        cnt += 1
cnt

23

# Filter tasks without Hybrid Graphs

In [64]:
def get_prot_key(row):
    pdb_id = seq2pdb[seq2pdb['target_sequence'] == row['target_sequence']]['target_pdb_id'].to_list()[0]
    if pdb_id == 'TOFOLD':
        return seq2AF[seq2AF['target_sequence'] == row['target_sequence']]['target_AF_id'].to_list()[0]
    else:
        return pdb_id[: 4].lower()

def add_prot_key_col(df):
    new_df = df.copy()
    new_df['prot_key'] = df.apply(get_prot_key, axis=1)
    return new_df

def filter_by_prot_key(df):
    new_df = df.copy()
    return new_df[new_df.apply(lambda row: row['prot_key'] in prot_key2prot_graph.keys(), axis=1)]

In [115]:
test_df.columns

Index(['compound_iso_smiles', 'target_sequence', 'target_pdb_id', 'affinity',
       'protein_id', 'drug_id'],
      dtype='object')

In [66]:
datasets = ['davis', 'kiba']
folds = '01234'

for dataset in datasets:
    # test df
    test_df = pd.read_csv(f'data/{dataset}/{dataset}_test.csv')
    
    # filter no Hybrid graphs task
    if dataset == 'davis':
        seq2AF = pd.read_csv(f'{dataset}_seq_AF.csv')
    seq2pdb = pd.read_csv(f'{dataset}_seq_pdb.csv')
    
    with open(f'{dataset}_site.pkl', 'rb') as fp:
        prot_key2prot_graph = pickle.load(fp)
    prot_key2prot_graph = {key: val for key, val in prot_key2prot_graph.items() if val is not None}
    
    filter_by_prot_key(add_prot_key_col(test_df)).to_csv(f'data/{dataset}/{dataset}_test_clr.csv', index=False)
    
    for fold in tqdm(folds):
        # train & val df
        train_df = pd.read_csv(f'data/{dataset}/{dataset}_train_fold_{fold}.csv')
        valid_df = pd.read_csv(f'data/{dataset}/{dataset}_valid_fold_{fold}.csv')

        filter_by_prot_key(add_prot_key_col(train_df)).to_csv(f'data/{dataset}/{dataset}_train_fold_{fold}_clr.csv', index=False)
        filter_by_prot_key(add_prot_key_col(valid_df)).to_csv(f'data/{dataset}/{dataset}_valid_fold_{fold}_clr.csv', index=False)
        
        print(dataset, f'fold {fold} clear!')

 20%|███████████████████████                                                                                            | 1/5 [00:09<00:39,  9.79s/it]

davis fold 0 clear!


 40%|██████████████████████████████████████████████                                                                     | 2/5 [00:19<00:29,  9.85s/it]

davis fold 1 clear!


 60%|█████████████████████████████████████████████████████████████████████                                              | 3/5 [00:29<00:19,  9.87s/it]

davis fold 2 clear!


 80%|████████████████████████████████████████████████████████████████████████████████████████████                       | 4/5 [00:39<00:09,  9.91s/it]

davis fold 3 clear!


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:49<00:00,  9.88s/it]

davis fold 4 clear!



 20%|███████████████████████                                                                                            | 1/5 [00:35<02:23, 35.75s/it]

kiba fold 0 clear!


 40%|██████████████████████████████████████████████                                                                     | 2/5 [01:12<01:48, 36.06s/it]

kiba fold 1 clear!


 60%|█████████████████████████████████████████████████████████████████████                                              | 3/5 [01:47<01:11, 35.98s/it]

kiba fold 2 clear!


 80%|████████████████████████████████████████████████████████████████████████████████████████████                       | 4/5 [02:23<00:35, 36.00s/it]

kiba fold 3 clear!


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [02:59<00:00, 36.00s/it]

kiba fold 4 clear!





In [67]:
datasets = ['davis', 'kiba']

for dataset in datasets:    
    # filter no Hybrid graphs task
    if dataset == 'davis':
        seq2AF = pd.read_csv(f'{dataset}_seq_AF.csv')
    seq2pdb = pd.read_csv(f'{dataset}_seq_pdb.csv')
    
    with open(f'{dataset}_site.pkl', 'rb') as fp:
        prot_key2prot_graph = pickle.load(fp)
    prot_key2prot_graph = {key: val for key, val in prot_key2prot_graph.items() if val is not None}
    
    # cold df
    cold_df = pd.read_csv(f'data/{dataset}_cold.csv')

    filter_by_prot_key(add_prot_key_col(cold_df)).to_csv(f'data/{dataset}_cold_clr.csv', index=False)
    
    print(dataset, f'cold clear!')

davis cold clear!
kiba cold clear!


In [96]:
def atom_features(atom):
    # 44 + 11 + 11 + 11 + 1 + 3 + 1
    return np.array(one_of_k_encoding_unk(atom.GetSymbol(),
                                          ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As',
                                           'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se',
                                           'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr',
                                           'Pt', 'Hg', 'Pb', 'X']) +
                    one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +
                    one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +
                    one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +
                    [atom.GetIsAromatic()]
                    +
                    one_of_k_encoding_unk(atom.GetFormalCharge() , [-1,0,1]) +
                    [atom.IsInRing()]
)


# one ont encoding
def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        # print(x)
        raise Exception('input {0} not in allowable set{1}:'.format(x, allowable_set))
    return list(map(lambda s: x == s, allowable_set))


def one_of_k_encoding_unk(x, allowable_set):
    '''Maps inputs not in the allowable set to the last element.'''
    if x not in allowable_set:
        x = allowable_set[-1]
    return list(map(lambda s: x == s, allowable_set))

In [97]:
# Drug smiles -> Graph, call func-smile_to_graph, exist for sure
def smile_to_graph(smile):
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return None
    c_size = mol.GetNumAtoms()

    features = []
    for atom in mol.GetAtoms():
        feature = atom_features(atom)
        features.append(feature / sum(feature))

    edges = []
    for bond in mol.GetBonds():
        edges.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
    g = nx.Graph(edges).to_directed()
    edge_index = []
    mol_adj = np.zeros((c_size, c_size))
    for e1, e2 in g.edges:
        mol_adj[e1, e2] = 1
        # edge_index.append([e1, e2])
    mol_adj += np.matrix(np.eye(mol_adj.shape[0]))
    index_row, index_col = np.where(mol_adj >= 0.5)
    for i, j in zip(index_row, index_col):
        edge_index.append([i, j])
    return c_size, features, edge_index

# Prot seq -> LLM, use pid2llm dict, exist for sure
pid2llm = np.load(f'{dataset}.npz',allow_pickle=True)['dict'][()]
def pid_to_llm(pid):
    return torch.from_numpy(pid2llm[pid]).squeeze()

# Prot seq -> Hybrid graphs, seq2[AF | pdb_id[: 4]], then use site.pkl, may not exist(no [pdb_id] or None)
with open(f'{dataset}_site.pkl', 'rb') as fp:
    prot_key2prot_graph = pickle.load(fp)
def prot_key_to_graph(prot_key):
    return prot_key2prot_graph[prot_key]

In [118]:
with open(f'{dataset}_site.pkl', 'rb') as fp:
    prot_key2prot_graph = pickle.load(fp)

In [119]:
with open(f'{dataset}_site.pkl', 'rb') as fp:
    prot_key2prot_graph = pickle.load(fp)

In [98]:
dataset = 'davis'
fold = '0'

train_df = pd.read_csv(f'data/{dataset}/{dataset}_train_fold_{fold}_clr.csv')

In [99]:
train_df_sub = train_df[: 10]

In [114]:
for idx, row in train_df_sub.iterrows():
    cis, pid, key = row.compound_iso_smiles, row.protein_id, row.prot_key
    mol_g = smile_to_graph(cis)
    prot_llm = pid_to_llm(pid)
    prot_g = prot_key_to_graph(key)
    
    print(torch.Tensor(mol_g[1]).shape, torch.Tensor(mol_g[2]).shape, prot_llm.shape)
    print(prot_g)

torch.Size([19, 82]) torch.Size([61, 2]) torch.Size([849, 2560])
Graph(num_nodes=12701, num_edges=36429,
      ndata_schemes={'h': Scheme(shape=(31,), dtype=torch.float32)}
      edata_schemes={})
torch.Size([35, 82]) torch.Size([119, 2]) torch.Size([774, 2560])
Graph(num_nodes=16369, num_edges=45211,
      ndata_schemes={'h': Scheme(shape=(31,), dtype=torch.float32)}
      edata_schemes={})
torch.Size([34, 82]) torch.Size([108, 2]) torch.Size([495, 2560])
Graph(num_nodes=15888, num_edges=42868,
      ndata_schemes={'h': Scheme(shape=(31,), dtype=torch.float32)}
      edata_schemes={})
torch.Size([37, 82]) torch.Size([117, 2]) torch.Size([526, 2560])
Graph(num_nodes=19148, num_edges=54038,
      ndata_schemes={'h': Scheme(shape=(31,), dtype=torch.float32)}
      edata_schemes={})
torch.Size([36, 82]) torch.Size([116, 2]) torch.Size([589, 2560])
Graph(num_nodes=41359, num_edges=119859,
      ndata_schemes={'h': Scheme(shape=(31,), dtype=torch.float32)}
      edata_schemes={})
torch.Size

In [128]:
prot_g

Graph(num_nodes=24118, num_edges=66726,
      ndata_schemes={'h': Scheme(shape=(31,), dtype=torch.float32)}
      edata_schemes={})

In [130]:
(prot_g.ndata['h']).shape

torch.Size([24118, 31])

In [137]:
vars(prot_g)

{'_graph': <dgl.heterograph_index.HeteroGraphIndex at 0x7f97f4c8a640>,
 '_canonical_etypes': [('_N', '_E', '_N')],
 '_batch_num_nodes': {'_N': tensor([ 174,  141,  860,  435,  457,  518,  324,  413,  283,  236,  144,  195,
           248,  461,  484,  596,  204,  352,  497,  577,  173,  161,  648,  263,
           631,  768,   96,   55,  146,  257, 1482,  516,  287,  242,  214,  421,
           335,  443,  176,  315,  252,   87,  106,  259,  372,  170,  493,  342,
           653,  573,  478,  255,  300,  427,  392,  302,   77,  223,  232,  500,
           106,  881,  313,  203,  257,  637])},
 '_batch_num_edges': {('_N',
   '_E',
   '_N'): tensor([ 484,  399, 2412, 1211, 1297, 1454,  922, 1115,  797,  598,  392,  531,
           678, 1255, 1330, 1640,  596,  998, 1369, 1605,  471,  439, 1814,  741,
          1787, 2150,  266,  147,  390,  725, 4202, 1454,  799,  676,  600, 1153,
           921, 1211,  486,  869,  688,  241,  288,  717, 1020,  482, 1337,  940,
          1799, 1515, 1276

In [138]:
vars(prot_g).keys()

dict_keys(['_graph', '_canonical_etypes', '_batch_num_nodes', '_batch_num_edges', '_ntypes', '_is_unibipartite', '_srctypes_invmap', '_dsttypes_invmap', '_etypes', '_etype2canonical', '_etypes_invmap', '_node_frames', '_edge_frames'])

In [139]:
prot_g.ndata

{'h': tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.]])}

In [141]:
from dgl.nn.pytorch.conv import GATConv, GraphConv, TAGConv, GINConv, APPNPConv
from dgl.nn import TWIRLSConv
from dgl.nn.pytorch.glob import MaxPooling, GlobalAttentionPooling
import torch
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F

In [142]:
prot_conv = nn.ModuleList()
for _ in range(5):
    prot_conv.append(TAGConv(31, 31, 2))

pooling_prot = nn.Linear(31, 1)
pool_prot = GlobalAttentionPooling(pooling_prot)

In [143]:
prot_conv

ModuleList(
  (0): TAGConv(
    (lin): Linear(in_features=93, out_features=31, bias=True)
  )
  (1): TAGConv(
    (lin): Linear(in_features=93, out_features=31, bias=True)
  )
  (2): TAGConv(
    (lin): Linear(in_features=93, out_features=31, bias=True)
  )
  (3): TAGConv(
    (lin): Linear(in_features=93, out_features=31, bias=True)
  )
  (4): TAGConv(
    (lin): Linear(in_features=93, out_features=31, bias=True)
  )
)

In [145]:
pool_prot

GlobalAttentionPooling(
  (gate_nn): Linear(in_features=31, out_features=1, bias=True)
)

In [148]:
# single graph

feat_prot = prot_g.ndata['h']
print(feat_prot.shape)
for module in prot_conv:
    feat_prot = F.relu(module(prot_g, feat_prot))
feat_prot.shape

torch.Size([24118, 31])


torch.Size([24118, 31])

In [150]:
# single graph

prot_repr = pool_prot(prot_g, feat_prot)
prot_repr.shape

torch.Size([66, 31])

In [152]:
prot_repr = prot_repr.view(1, -1, 31)
prot_repr.shape

torch.Size([1, 66, 31])

In [169]:
# batch graph

prot_gs = []
for idx, row in train_df_sub.iterrows():
    key = row.prot_key
    prot_g = prot_key_to_graph(key)
    prot_gs.append(prot_g)

In [170]:
prot_gs

[Graph(num_nodes=12701, num_edges=36429,
       ndata_schemes={'h': Scheme(shape=(31,), dtype=torch.float32)}
       edata_schemes={}),
 Graph(num_nodes=16369, num_edges=45211,
       ndata_schemes={'h': Scheme(shape=(31,), dtype=torch.float32)}
       edata_schemes={}),
 Graph(num_nodes=15888, num_edges=42868,
       ndata_schemes={'h': Scheme(shape=(31,), dtype=torch.float32)}
       edata_schemes={}),
 Graph(num_nodes=19148, num_edges=54038,
       ndata_schemes={'h': Scheme(shape=(31,), dtype=torch.float32)}
       edata_schemes={}),
 Graph(num_nodes=41359, num_edges=119859,
       ndata_schemes={'h': Scheme(shape=(31,), dtype=torch.float32)}
       edata_schemes={}),
 Graph(num_nodes=19823, num_edges=55623,
       ndata_schemes={'h': Scheme(shape=(31,), dtype=torch.float32)}
       edata_schemes={}),
 Graph(num_nodes=44729, num_edges=131149,
       ndata_schemes={'h': Scheme(shape=(31,), dtype=torch.float32)}
       edata_schemes={}),
 Graph(num_nodes=13191, num_edges=35285,
     

In [171]:
prot_convs = prot_conv

In [173]:
relu = nn.ReLU()

In [177]:
for prot_g in prot_gs:
    feat_prot = prot_g.ndata['h']

    for prot_conv in prot_convs:
        feat_prot = relu(prot_conv(prot_g, feat_prot))

    prot_repr = pool_prot(prot_g, feat_prot).view(1, -1, 31)
    prot_repr = F.pad(
        input=prot_repr, 
        pad=(0, 0, 0, 140 - prot_repr.size()[1]), 
        mode='constant', value=0)
    print(prot_repr.shape)

torch.Size([1, 140, 31])
torch.Size([1, 140, 31])
torch.Size([1, 140, 31])
torch.Size([1, 140, 31])
torch.Size([1, 140, 31])
torch.Size([1, 140, 31])
torch.Size([1, 140, 31])
torch.Size([1, 140, 31])
torch.Size([1, 140, 31])
torch.Size([1, 140, 31])


In [178]:
prot_g_bilstm = nn.LSTM(31, 31, num_layers=2, batch_first=True, bidirectional=True, dropout=0.2)

In [185]:
prot_reprs = []
for prot_g in prot_gs:
    feat_prot = prot_g.ndata['h']

    for prot_conv in prot_convs:
        feat_prot = relu(prot_conv(prot_g, feat_prot))

    prot_repr = pool_prot(prot_g, feat_prot).view(1, -1, 31)
    prot_repr = F.pad(
        input=prot_repr, 
        pad=(0, 0, 0, 140 - prot_repr.size()[1]), 
        mode='constant', value=0)
    prot_repr, _ = prot_g_bilstm(prot_repr)
    prot_reprs.append(prot_repr)
torch.concat(prot_reprs).permute(1, 0, 2).shape

torch.Size([140, 10, 62])

In [None]:
mol_gs = []
for idx, row in train_df_sub.iterrows():
    cis = row.compound_iso_smiles
    mol_g = smile_to_graph(cis)
    mol_gs.append(mol_g)

In [191]:
for idx, row in train_df_sub.iterrows():
    cis, pid, key = row.compound_iso_smiles, row.protein_id, row.prot_key
    prot_llm = pid_to_llm(pid)
    print(prot_llm.shape)

torch.Size([849, 2560])
torch.Size([774, 2560])
torch.Size([495, 2560])
torch.Size([526, 2560])
torch.Size([589, 2560])
torch.Size([1024, 2560])
torch.Size([747, 2560])
torch.Size([841, 2560])
torch.Size([768, 2560])
torch.Size([640, 2560])


In [186]:
from model.TGCA.tgt_guided_cross_attention_model import TargetGuidedCrossAttention

In [187]:
tgca = TargetGuidedCrossAttention(embed_dim=128, num_heads=1)

In [188]:
tgca

TargetGuidedCrossAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
)

In [192]:
llms = torch.ones(211, 1024, 512)

In [193]:
llms.shape

torch.Size([211, 1024, 512])

In [199]:
llms.flatten(2).shape

torch.Size([211, 1024, 512])

In [202]:
llms.transpose(-1, -2).shape

torch.Size([211, 512, 1024])

In [210]:
datasets = ['kiba', 'davis']

for dataset in datasets:
    pid2llm = np.load(f'{dataset}.npz',allow_pickle=True)['dict'][()]
    print(max([llm.shape[1] for llm in pid2llm.values()]))

1024
1024


# CM csv clr...

In [5]:
data_dir = '/data5/luozc/projects/DTA/GINCM-DTA/data/'
def filter_by_tgt_ori_key(df, dataset):
    new_df = df.copy()
    return new_df[new_df.apply(lambda row: valid_target(row.target_original_key, dataset, data_dir), axis=1)]

In [6]:
dataset = 'kiba'
df = pd.read_csv(f'data/{dataset}/{dataset}_test.csv')

In [None]:
filter_by_tgt_ori_key(df, dataset)

Unnamed: 0,compound_iso_smiles,target_sequence,target_pdb_id,target_original_key,affinity,protein_id,drug_id
0,Cc1ccccc1NC(=O)Nc1ccc(-c2coc3ncnc(N)c23)cc1,MPALARDGGQLPLLVVFSAMIFGTITNQDLPVIKCVLINHKNNDSS...,7QDP_2,P36888,14.400162,83,345
1,Cc1ccccc1-c1c(-c2ccc3[nH]nc(N)c3c2)nnn1Cc1ccccc1,MPHPRRYHSSERGSRGSYREHYRSRKHKRRRSRSWSSSSDRTRRRR...,6KHE_1,P49760,12.399998,104,495
2,CC(C)(C(N)=O)n1cc(-c2cnc(N)c3c(-c4ccc(NC(=O)Nc...,MSGRPRTTSFAESCKPVQQPSAFGSMKVSRDKDGSKVTTVVATPGQ...,1I09_1,P49841,11.400000,106,163
3,Fc1ccc(-c2ccc3nccn3n2)cn1,MGAIGLLWLLPLLLSTAAVGSGMGTGQRAGSPAAGPPLQPREPLSY...,7NX3_1,Q9UM73,10.100000,222,1372
4,NCCCOc1cc2c(c(-c3ccc(Nc4nc5ccccc5o4)cc3)c1)CNC2=O,MASSSGSKAEFIVGGKYKLVRKIGSGSFGDIYLAINITNGEEVAVK...,5FQD_3,P48729,12.699998,97,254
...,...,...,...,...,...,...,...
19704,CN(C)c1cc2sncc2cc1NC(=O)C(=O)O,MSELEEDFAKILMLKEERIKELEKRLSEKEEEIQELKRKLHKCQSV...,7LV3_1,Q13976,11.800001,159,467
19705,CN1CCC1COc1cncc(CCc2ccncc2)c1,MAPFLRIAFNSYELGSLQAEDEANQPFCAVKMKEALSTERGKTLVQ...,1XJD_1,Q05655,11.200000,141,967
19706,O=C(CO)N1CCC(c2[nH]nc(-c3ccc(Cl)cc3F)c2-c2ccnc...,MSSWIRWHGPAMARLWGFCWLVVGFWRAAFACPTSCKCSASRIWCS...,4ASZ_1,Q16620,11.400000,176,104
19707,NNc1cc(N2CCOCC2)nc(OCCc2ccccn2)n1,MATCIGEKIEDFKVGNLLGKGSFAGVYRAESIHTGLEVAIKMIDKK...,3COK_1,O00444,11.500000,4,968


In [9]:
datasets = ['davis', 'kiba']
folds = '01234'

for dataset in datasets:
    # test df
    test_df = pd.read_csv(f'data/{dataset}/{dataset}_test.csv')
    
    filter_by_tgt_ori_key(test_df, dataset).to_csv(f'data/{dataset}/{dataset}_test_cm.csv', index=False)
    
    for fold in tqdm(folds):
        # train & val df
        train_df = pd.read_csv(f'data/{dataset}/{dataset}_train_fold_{fold}.csv')
        valid_df = pd.read_csv(f'data/{dataset}/{dataset}_valid_fold_{fold}.csv')

        filter_by_tgt_ori_key(train_df, dataset).to_csv(f'data/{dataset}/{dataset}_train_fold_{fold}_cm.csv', index=False)
        filter_by_tgt_ori_key(valid_df, dataset).to_csv(f'data/{dataset}/{dataset}_valid_fold_{fold}_cm.csv', index=False)
        
        print(dataset, f'fold {fold} clear!')

 20%|███████████████████████                                                                                            | 1/5 [00:01<00:04,  1.01s/it]

davis fold 0 clear!


 40%|██████████████████████████████████████████████                                                                     | 2/5 [00:02<00:03,  1.01s/it]

davis fold 1 clear!


 60%|█████████████████████████████████████████████████████████████████████                                              | 3/5 [00:03<00:02,  1.01s/it]

davis fold 2 clear!


 80%|████████████████████████████████████████████████████████████████████████████████████████████                       | 4/5 [00:04<00:01,  1.01s/it]

davis fold 3 clear!


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:05<00:00,  1.02s/it]

davis fold 4 clear!



 20%|███████████████████████                                                                                            | 1/5 [00:03<00:15,  3.77s/it]

kiba fold 0 clear!


 40%|██████████████████████████████████████████████                                                                     | 2/5 [00:07<00:11,  3.78s/it]

kiba fold 1 clear!


 60%|█████████████████████████████████████████████████████████████████████                                              | 3/5 [00:11<00:07,  3.78s/it]

kiba fold 2 clear!


 80%|████████████████████████████████████████████████████████████████████████████████████████████                       | 4/5 [00:15<00:03,  3.79s/it]

kiba fold 3 clear!


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:18<00:00,  3.79s/it]

kiba fold 4 clear!





In [10]:
datasets = ['davis', 'kiba']

for dataset in datasets:
    # cold df
    cold_df = pd.read_csv(f'data/{dataset}_cold.csv')

    filter_by_tgt_ori_key(cold_df, dataset).to_csv(f'data/{dataset}_cold_cm.csv', index=False)
    
    print(dataset, f'cold clear!')

davis cold clear!
kiba cold clear!
