In [1]:
import os

import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import AllChem

from Get_Mol_features import get_mol_features, remove_hydrogen

### 1. Load data

In [3]:
# Load the IC50 table and keep the compound-ID and isomeric-SMILES columns.
IC50_df = pd.read_csv("../../input_data/BindingDB/IC50_data.tsv", sep="\t")
IC50_cid = IC50_df["CID"]
IC50_iso_SMILES = IC50_df["SMILES_iso"]
print(f"[IC50] unique compounds: {np.unique(IC50_cid).size}")

[IC50] unique compounds: 582841


In [4]:
# Load the Ki table and keep the compound-ID and isomeric-SMILES columns.
Ki_df = pd.read_csv("../../input_data/BindingDB/Ki_data.tsv", sep="\t")
Ki_cid = Ki_df["CID"]
Ki_iso_SMILES = Ki_df["SMILES_iso"]
print(f"[Ki] unique compounds: {np.unique(Ki_cid).size}")

[Ki] unique compounds: 183546


In [5]:
# Union of compound IDs from both assays; np.unique returns them sorted and
# de-duplicated.
total_cid = np.unique(np.concatenate([Ki_cid, IC50_cid]))
print(f"[Total]: {total_cid.size}")

[Total]: 737776


### 2. Build molecular graphs from per-compound SDF files

In [6]:
# Accumulators filled by the graph-building loop.
atom_slices = [0]        # cumulative atom counts; the first graph starts at 0
edge_slices = [0]        # cumulative edge counts; the first graph starts at 0
all_atom_features = []   # one atom-feature tensor per molecule
all_edge_features = []   # one edge-feature entry per molecule
edge_indices = []        # one edge-index entry per molecule
total_n_atoms = []       # atom count of each molecule
id_list = []             # CID of each processed molecule
total_atoms = 0
total_edges = 0
avg_degree = 0

In [7]:
# Directory containing one SDF file per compound, named "<CID>.sdf".
path = (
    "../../preprocessing_BindingDB/"
    "data/sdf/"
)

In [None]:
# Build one molecular graph per compound and pack the per-molecule results
# into lists plus cumulative slice offsets.
#
# The accumulators are (re)initialised here so that re-running this cell on
# its own cannot append a second copy of every graph.
atom_slices, edge_slices = [0], [0]
all_atom_features, all_edge_features = [], []
edge_indices, total_n_atoms, id_list = [], [], []
total_atoms, total_edges = 0, 0
avg_degree = 0

for mol_idx, lig_code in enumerate(total_cid):
    # Coarse progress report every 500 compounds.
    if mol_idx % 500 == 0:
        print(mol_idx, len(total_cid))

    # One SDF file per compound, named after its CID.
    sdf_path = os.path.join(path, f"{lig_code}.sdf")
    suppl = Chem.SDMolSupplier(sdf_path, removeHs=True)

    # Only the first record of each file is used. `next(..., None)` covers an
    # empty file, and SDMolSupplier yields None for a record RDKit cannot
    # parse. A zero-atom molecule is rejected too, because it would divide by
    # zero in the degree statistic below.
    mol = next(iter(suppl), None)
    usable = mol is not None and mol.GetNumAtoms() > 0
    if usable:
        atom_features_list, edge_index, edge_features, n_edges = get_mol_features(mol)
        # An integer 0 is treated as get_mol_features' failure sentinel
        # (TODO confirm in Get_Mol_features). Check it BEFORE building the
        # tensor so a bogus 0-d entry is never appended; an isinstance guard
        # avoids `array == 0` raising on numpy output.
        usable = not (isinstance(atom_features_list, (int, np.integer))
                      and atom_features_list == 0)
    if not usable:
        print("Please remove the SDF files below and preprocess once again.")
        print(sdf_path)
        break

    n_atoms = mol.GetNumAtoms()
    all_atom_features.append(torch.tensor(atom_features_list, dtype=torch.long))

    # Running sum of (n_edges / 2) / n_atoms, divided by the number of
    # molecules when data_dict is assembled.
    # NOTE(review): if n_edges counts every bond in both directions, this is
    # bonds per atom, i.e. half of the usual mean node degree -- confirm.
    avg_degree += (n_edges / 2) / n_atoms
    edge_indices.append(edge_index)
    all_edge_features.append(edge_features)

    # Cumulative offsets: molecule i owns rows [atom_slices[i], atom_slices[i+1])
    # of the concatenated atom features, and (presumably, if get_mol_features
    # returns n_edges entries) [edge_slices[i], edge_slices[i+1]) of the edges.
    total_edges += n_edges
    total_atoms += n_atoms
    total_n_atoms.append(n_atoms)
    edge_slices.append(total_edges)
    atom_slices.append(total_atoms)

    id_list.append(lig_code)

In [None]:
# Pack all graphs into one dictionary of flat tensors.
# The per-molecule lists must be aligned with mol_ids, otherwise the slice
# offsets would index the wrong rows. Fail loudly here instead of saving a
# corrupt file, e.g. if the build loop stopped early on a bad SDF.
assert len(id_list) > 0, "No molecular graphs were built."
assert len(all_atom_features) == len(all_edge_features) == len(edge_indices) == len(id_list), \
    "Per-molecule feature lists are out of sync with mol_ids."
assert len(atom_slices) == len(edge_slices) == len(id_list) + 1, \
    "Slice offsets are out of sync with mol_ids."

data_dict = {'mol_ids': id_list,
             # atom count of each molecule
             'n_atoms': torch.tensor(total_n_atoms, dtype=torch.long),
             # cumulative offsets into atom_features (length = n_mols + 1)
             'atom_slices': torch.tensor(atom_slices, dtype=torch.long),
             # cumulative offsets into the edge arrays (length = n_mols + 1)
             'edge_slices': torch.tensor(edge_slices, dtype=torch.long),
             'edge_indices': torch.cat(edge_indices, dim=1),
             'atom_features': torch.cat(all_atom_features, dim=0),
             'edge_features': torch.cat(all_edge_features, dim=0),
             # mean over molecules of the per-molecule value summed in the loop
             'avg_degree': avg_degree / len(id_list)}

In [None]:
# Persist the packed graph dataset in the same directory as the input TSVs.
torch.save(obj=data_dict, f="../../input_data/BindingDB/BindingDB_graph_data.pt")

In [None]:
# Report how many molecular graphs were packed into the saved dataset.
print(f"Molecular graphs: {len(data_dict['mol_ids'])}")