In [1]:
import os
import pickle
import torch
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem

import torch.nn.functional as F
from torch_geometric.data import Data

from Get_Mol_features import get_mol_features, remove_hydrogen

### 1. Load data

In [2]:
# atom features (9): [atomic_num, chirality, degree, formal_charge, numH, num_radical_e, hybridization, is_in_aromatic, is_in_ring]
# edge features (4): [bond_type, bond_stereo, is_conjugated, bond_direction]

In [3]:
# Read the ZINC compound table with every column kept as strings (so ids are not
# coerced to numbers), then pull out the two columns the rest of the notebook uses.
zinc_input_df = pd.read_csv(
    "../../input_data/compound/zinc_combined_apr_8_2019.csv.gz",
    sep=",",
    compression="gzip",
    dtype="str",
)

zinc_smiles_list = zinc_input_df["smiles"].tolist()
zinc_id_list = zinc_input_df["zinc_id"].tolist()
print("Load {} compound".format(len(zinc_id_list)))

Load 2000000 compound


### 2. Get molecular graph features

In [4]:
# Accumulators filled molecule-by-molecule by the featurisation loop below.
data_smiles_list = []      # SMILES of each successfully featurised molecule
data_list = []             # numeric ZINC id of each featurised molecule

# Running prefix sums of atom / edge counts; both start at offset 0.
atom_slices = [0]
edge_slices = [0]

total_eigvecs = []
total_eigvals = []

all_atom_features = []     # one atom-feature tensor per molecule
all_edge_features = []     # one edge-feature object per molecule
edge_indices = []          # one edge-index object per molecule
total_n_atoms = []         # atom count per molecule

total_atoms = 0
total_edges = 0
avg_degree = 0

In [5]:
# Featurise every SMILES string into a molecular graph.
# Each molecule's values are computed inside the `try` first and are only appended
# to the accumulators afterwards. A failure part-way through therefore can never
# leave the per-molecule lists misaligned (ids paired with another molecule's graph).
# Skipped inputs are recorded with a reason instead of being silently swallowed.
# NOTE: this cell appends to the accumulators created in the previous cell; re-run
# that cell before re-running this one.
failed_records = []  # (row index, SMILES, reason) for every skipped input


def _parse_zinc_id(zinc_id):
    """Return the numeric part of a ZINC identifier, e.g. 'ZINC000012345' -> 12345.

    ``int`` already accepts leading zeros in strings. Stripping them first with
    ``.lstrip('0')`` would break on an all-zero suffix ('' is not an int).
    """
    return int(zinc_id.split('ZINC')[1])


for i, (s, zinc_id) in enumerate(zip(zinc_smiles_list, zinc_id_list)):
    if i % 10000 == 0:
        print(i, len(data_smiles_list), len(zinc_smiles_list))
    # Each example contains a single species
    try:
        mol = Chem.MolFromSmiles(s)
        if mol is None:  # RDKit could not parse the SMILES
            failed_records.append((i, s, "invalid SMILES"))
            continue
        n_atoms = mol.GetNumAtoms()
        if n_atoms == 0:  # e.g. an empty SMILES; would divide by zero below
            failed_records.append((i, s, "molecule has no atoms"))
            continue

        atom_features_list, edge_index, edge_features, n_edges = get_mol_features(mol)
        # NOTE(review): get_mol_features apparently signals failure by returning 0
        # in the first slot - confirm against Get_Mol_features.
        if atom_features_list == 0:
            failed_records.append((i, s, "get_mol_features failed"))
            continue
        atom_features = torch.tensor(atom_features_list, dtype=torch.long)
        mol_id = _parse_zinc_id(zinc_id)
    except Exception as exc:  # keep going, but remember what was skipped and why
        failed_records.append((i, s, repr(exc)))
        continue

    # Commit the molecule: every per-molecule list grows by exactly one entry.
    all_atom_features.append(atom_features)
    edge_indices.append(edge_index)
    all_edge_features.append(edge_features)
    total_n_atoms.append(n_atoms)

    # NOTE(review): `n_edges / 2` assumes get_mol_features counts each bond in both
    # directions - confirm before interpreting avg_degree.
    avg_degree += (n_edges / 2) / n_atoms
    total_edges += n_edges
    total_atoms += n_atoms
    edge_slices.append(total_edges)
    atom_slices.append(total_atoms)

    data_list.append(mol_id)
    data_smiles_list.append(s)

print(f"Featurised {len(data_list)} molecules, skipped {len(failed_records)}")
if failed_records:
    print("First skipped inputs:", failed_records[:5])

0 0 2000000
10000 10000 2000000
20000 20000 2000000
30000 30000 2000000
40000 40000 2000000
50000 50000 2000000
60000 60000 2000000
70000 70000 2000000
80000 80000 2000000
90000 90000 2000000
100000 100000 2000000
110000 110000 2000000
120000 120000 2000000
130000 130000 2000000
140000 140000 2000000
150000 150000 2000000
160000 160000 2000000
170000 170000 2000000
180000 180000 2000000
190000 190000 2000000
200000 200000 2000000
210000 210000 2000000
220000 220000 2000000
230000 230000 2000000
240000 240000 2000000
250000 250000 2000000
260000 260000 2000000
270000 270000 2000000
280000 280000 2000000
290000 290000 2000000
300000 300000 2000000
310000 310000 2000000
320000 320000 2000000
330000 330000 2000000
340000 340000 2000000
350000 350000 2000000
360000 360000 2000000
370000 370000 2000000
380000 380000 2000000
390000 390000 2000000
400000 400000 2000000
410000 410000 2000000
420000 420000 2000000
430000 430000 2000000
440000 440000 2000000
450000 450000 2000000
460000 460000 20

In [6]:
# Write the SMILES of every featurised molecule, one per line with no header or
# index, in the same order as the molecule ids collected in data_list.
data_smiles_series = pd.Series(data_smiles_list)
data_smiles_series.to_csv("../../input_data/compound/smilesnonH.csv",
                          header=False, index=False)

In [7]:
# Assemble everything the downstream dataset needs into one dict:
#   - per-molecule ZINC ids and atom counts,
#   - atom / edge slice offsets (running prefix sums starting at 0) that cut each
#     molecule back out of the concatenated tensors,
#   - the concatenated edge indices, atom features and edge features,
#   - the per-molecule degree statistic accumulated above, averaged over molecules.
# Check alignment first: a mismatch would silently pair ids with the wrong graphs,
# and an empty run would otherwise fail with an opaque torch.cat / ZeroDivisionError.
n_mols = len(data_list)
assert n_mols > 0, "no molecules were featurised; nothing to save"
assert (len(data_smiles_list) == len(total_n_atoms) == len(all_atom_features)
        == len(edge_indices) == len(all_edge_features) == n_mols), \
    "per-molecule lists are misaligned"
assert len(atom_slices) == len(edge_slices) == n_mols + 1, \
    "slice offsets are misaligned with the molecule list"

data_dict = {'mol_ids': data_list,
             'n_atoms': torch.tensor(total_n_atoms, dtype=torch.long),
             'atom_slices': torch.tensor(atom_slices, dtype=torch.long),
             'edge_slices': torch.tensor(edge_slices, dtype=torch.long),
             'edge_indices': torch.cat(edge_indices, dim=1),
             'atom_features': torch.cat(all_atom_features, dim=0),
             'edge_features': torch.cat(all_edge_features, dim=0),
             'avg_degree': avg_degree / n_mols}

In [8]:
# Serialise the assembled graph dictionary next to the SMILES export.
torch.save(obj=data_dict, f="../../input_data/compound/MgraphDatanonH.pt")

In [9]:
# Report how many molecular graphs were built (print's default separator adds the space).
print("Molecular graphs:", len(data_list))

Molecular graphs: 2000000
