In [1]:
import os
import torch
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from Get_Mol_features import get_mol_features, remove_hydrogen

### 1. Load data

In [2]:
# Read ligand codes as strings: PDB ligand IDs such as "313" look numeric, and pandas'
# (chunked) type inference could otherwise turn them into ints or drop leading zeros.
Training_df = pd.read_csv("../../input_data/PDB/BA/Training_BA_data.tsv", sep="\t", dtype={"Ligand_Codes": str})
Training_lig_codes = Training_df["Ligand_Codes"].to_numpy()
print(f"[PDBbind] unique compounds: {len(np.unique(Training_lig_codes))}")

[PDBbind] unique compounds: 10408


In [3]:
# Read ligand codes as strings so numeric-looking PDB IDs (e.g. "313") are not parsed as ints.
CASF2016_df = pd.read_csv("../../input_data/PDB/BA/CASF2016_BA_data.tsv", sep="\t", dtype={"Ligand_Codes": str})
CASF2016_lig_codes = CASF2016_df["Ligand_Codes"].to_numpy()
print(f"[CASF2016] unique compounds: {len(np.unique(CASF2016_lig_codes))}")

[CASF2016] unique compounds: 262


In [4]:
# Read ligand codes as strings so numeric-looking PDB IDs (e.g. "313") are not parsed as ints.
CASF2013_df = pd.read_csv("../../input_data/PDB/BA/CASF2013_BA_data.tsv", sep="\t", dtype={"Ligand_Codes": str})
CASF2013_lig_codes = CASF2013_df["Ligand_Codes"].to_numpy()
print(f"[CASF2013] unique compounds: {len(np.unique(CASF2013_lig_codes))}")

[CASF2013] unique compounds: 160


In [5]:
# Read ligand codes as strings so numeric-looking PDB IDs (e.g. "313") are not parsed as ints.
CSAR2014_df = pd.read_csv("../../input_data/PDB/BA/CSAR2014_BA_data.tsv", sep="\t", dtype={"Ligand_Codes": str})
CSAR2014_lig_codes = CSAR2014_df["Ligand_Codes"].to_numpy()
print(f"[CSAR2014] unique compounds: {len(np.unique(CSAR2014_lig_codes))}")

[CSAR2014] unique compounds: 46


In [6]:
# Read ligand codes as strings so numeric-looking PDB IDs (e.g. "313") are not parsed as ints.
CSAR2012_df = pd.read_csv("../../input_data/PDB/BA/CSAR2012_BA_data.tsv", sep="\t", dtype={"Ligand_Codes": str})
CSAR2012_lig_codes = CSAR2012_df["Ligand_Codes"].to_numpy()
print(f"[CSAR2012] unique compounds: {len(np.unique(CSAR2012_lig_codes))}")

[CSAR2012] unique compounds: 54


In [7]:
# Read ligand codes as strings so numeric-looking PDB IDs (e.g. "313") are not parsed as ints.
CSARset1_df = pd.read_csv("../../input_data/PDB/BA/CSARset1_BA_data.tsv", sep="\t", dtype={"Ligand_Codes": str})
CSARset1_lig_codes = CSARset1_df["Ligand_Codes"].to_numpy()
print(f"[CSARset1] unique compounds: {len(np.unique(CSARset1_lig_codes))}")

[CSARset1] unique compounds: 140


In [8]:
# Read ligand codes as strings so numeric-looking PDB IDs (e.g. "313") are not parsed as ints.
CSARset2_df = pd.read_csv("../../input_data/PDB/BA/CSARset2_BA_data.tsv", sep="\t", dtype={"Ligand_Codes": str})
CSARset2_lig_codes = CSARset2_df["Ligand_Codes"].to_numpy()
print(f"[CSARset2] unique compounds: {len(np.unique(CSARset2_lig_codes))}")

[CSARset2] unique compounds: 120


In [9]:
# Read ligand codes as strings so numeric-looking PDB IDs (e.g. "313") are not parsed as ints.
Astex_df = pd.read_csv("../../input_data/PDB/BA/Astex_BA_data.tsv", sep="\t", dtype={"Ligand_Codes": str})
Astex_lig_codes = Astex_df["Ligand_Codes"].to_numpy()
print(f"[Astex] unique compounds: {len(np.unique(Astex_lig_codes))}")

[Astex] unique compounds: 72


In [10]:
# Read ligand codes as strings so numeric-looking PDB IDs (e.g. "313") are not parsed as ints.
COACH420_df = pd.read_csv("../../input_data/PDB/BA/COACH420_IS_data.tsv", sep="\t", dtype={"Ligand_Codes": str})
COACH420_lig_codes = COACH420_df["Ligand_Codes"].to_numpy()
print(f"[COACH420] unique compounds: {len(np.unique(COACH420_lig_codes))}")

[COACH420] unique compounds: 232


In [11]:
# Read ligand codes as strings so numeric-looking PDB IDs (e.g. "313") are not parsed as ints.
HOLO4K_df = pd.read_csv("../../input_data/PDB/BA/HOLO4K_IS_data.tsv", sep="\t", dtype={"Ligand_Codes": str})
HOLO4K_lig_codes = HOLO4K_df["Ligand_Codes"].to_numpy()
print(f"[HOLO4K] unique compounds: {len(np.unique(HOLO4K_lig_codes))}")

[HOLO4K] unique compounds: 1476


In [12]:
# Deduplicate ligand codes across every dataset so each ligand is featurised once.
# np.unique also sorts, which fixes the order of the graphs in the saved file.
total_lig_codes = np.unique(
    np.concatenate([
        Training_lig_codes,
        CASF2016_lig_codes,
        CASF2013_lig_codes,
        CSAR2014_lig_codes,
        CSAR2012_lig_codes,
        CSARset1_lig_codes,
        CSARset2_lig_codes,
        Astex_lig_codes,
        COACH420_lig_codes,
        HOLO4K_lig_codes,
    ])
)
print("[Total] unique compounds: {}".format(total_lig_codes.shape[0]))
print()

[Total] unique compounds: 11995



### 2. Get molecular graph features

In [13]:
# Running offsets: each molecule appends the cumulative atom / edge count, starting from 0.
atom_slices = [0]
edge_slices = [0]

# Per-molecule pieces, concatenated once the loop has finished.
all_atom_features = []
all_edge_features = []
edge_indices = []
total_n_atoms = []
id_list = []

# Cumulative counters and the running sum used for the mean degree.
total_atoms = 0
total_edges = 0
avg_degree = 0

In [14]:
# Folder holding one "<ligand code>_ideal.sdf" file per ligand, read by the featurisation loop.
path = "/".join(["..", "..", "Preprocessing_PDB", "data", "PDB", "ligand"])

In [15]:
# Featurise every unique ligand and append its graph to the accumulators created in the
# initialisation cell. Those accumulators are not reset here: re-run that cell before
# re-running this one, otherwise graphs are appended twice.

# Ligand codes that get an extra remove_hydrogen pass after featurisation
# (NOTE(review): presumably their hydrogens survive removeHs=True — confirm).
LIGANDS_NEEDING_H_REMOVAL = {"313", "M2T"}

# Fail fast with a clear message instead of an RDKit file error on the first ligand.
if not os.path.isdir(path):
    raise FileNotFoundError(f"Ligand SDF directory not found: {os.path.abspath(path)}")

for mol_idx, lig_code in enumerate(total_lig_codes):
    if mol_idx % 1000 == 0:
        print(mol_idx, len(total_lig_codes))
    sdf_path = os.path.join(path, f"{lig_code}_ideal.sdf")
    suppl = Chem.SDMolSupplier(sdf_path, removeHs=True)

    # Only the first record of each SDF is used. RDKit yields None for a record it cannot
    # parse; check it so the error names the file instead of an AttributeError on None.
    mol = next(iter(suppl))
    if mol is None:
        raise RuntimeError(f"RDKit could not parse {sdf_path}; please preprocess it once again.")

    n_atoms = mol.GetNumAtoms()

    atom_features_list, edge_index, edge_features, n_edges = get_mol_features(mol)

    # get_mol_features returns 0 as a failure sentinel. Raise rather than `break`: a bare
    # break let the following cells pack and save an incomplete dataset without any error.
    if atom_features_list == 0:
        raise RuntimeError(
            "Please remove the SDF files below and preprocess once again.\n" + sdf_path
        )

    if lig_code in LIGANDS_NEEDING_H_REMOVAL:
        n_atoms, atom_features_list, edge_features, edge_index, n_edges = remove_hydrogen(
            atom_features_list, edge_index, edge_features
        )

    all_atom_features.append(torch.tensor(atom_features_list, dtype=torch.long))

    # NOTE(review): (n_edges / 2) / n_atoms assumes n_edges counts each bond twice (once per
    # direction) — confirm against get_mol_features. If so, this is half the conventional
    # mean degree (n_edges / n_atoms); kept unchanged since consumers may rely on the value.
    avg_degree += (n_edges / 2) / n_atoms
    edge_indices.append(edge_index)
    all_edge_features.append(edge_features)

    total_edges += n_edges
    total_atoms += n_atoms
    total_n_atoms.append(n_atoms)

    # Cumulative totals: molecule k spans [slices[k], slices[k + 1]).
    edge_slices.append(total_edges)
    atom_slices.append(total_atoms)

    id_list.append(lig_code)

0 11995
1000 11995




2000 11995
3000 11995
4000 11995
5000 11995
6000 11995
7000 11995




8000 11995
9000 11995




10000 11995
11000 11995


In [16]:
# Pack the per-molecule pieces into flat tensors. atom_slices / edge_slices hold the
# cumulative atom and edge counts (starting at 0), and avg_degree is averaged over molecules.
data_dict = dict(
    mol_ids=id_list,
    n_atoms=torch.tensor(total_n_atoms, dtype=torch.long),
    atom_slices=torch.tensor(atom_slices, dtype=torch.long),
    edge_slices=torch.tensor(edge_slices, dtype=torch.long),
    edge_indices=torch.cat(edge_indices, dim=1),
    atom_features=torch.cat(all_atom_features, dim=0),
    edge_features=torch.cat(all_edge_features, dim=0),
    avg_degree=avg_degree / len(id_list),
)

In [17]:
# Persist the packed graph dataset next to the input TSV files.
torch.save(obj=data_dict, f="../../input_data/PDB/BA/PDB_graph_data.pt")

In [18]:
# Report how many ligand graphs were built.
print("Molecular graphs: %d" % len(id_list))

Molecular graphs: 11995
