# Prepare Data
Given an SDF file with the QM9 molecules and a CSV file with the respective labels,
this notebook filters, shuffles, and partitions the molecules and writes a matching label file for each partition.

When the QM9 data set was created, a small part of the molecules did not retain their configuration during geometry optimization. The IDs of these molecules are listed in `uncharacterized-ids.txt`, and the corresponding molecules are filtered out.

Additionally, due to some [erroneously parsed bonds in the SDF file](https://github.com/deepchem/deepchem/issues/1122#issuecomment-402662114), some molecules cannot be imported with RDKit. They are removed as well.


In [1]:
import rdkit.Chem as Chem
import numpy as np
import os
from tqdm import tqdm_notebook as tqdm

def filter_sdf(raw_path, filter_path, destination_path):
    """Copy an sdf file while dropping unwanted molecules.

    A molecule is dropped if its id is listed in the file ``filter_path`` or if
    rdkit cannot import it (the supplier yields ``None`` for it). The raw text
    of every remaining record is copied unchanged to ``destination_path``.

    The molecule id is read from the second whitespace-separated token of the
    record text.

    Args:
        raw_path: path of the sdf file to read.
        filter_path: path of a text file with one integer molecule id per line;
            blank lines are ignored.
        destination_path: path of the sdf file to write. An existing file is
            overwritten, so re-running this function does not duplicate records.
    """
    with open(filter_path, 'r') as filter_file:
        # A set gives constant-time membership tests inside the loop below.
        ids_to_remove = {int(line) for line in filter_file if line.strip()}

    mols = Chem.SDMolSupplier(raw_path, removeHs=False)  # import molecules from sdf file
    num_mols = len(mols)
    print(num_mols, 'molecules in total')

    num_accepted = 0
    # Open the destination once instead of re-opening it for every molecule.
    with open(destination_path, 'w', newline='') as out_file:
        for i, mol in enumerate(tqdm(mols)):
            if mol is None:
                # rdkit could not import this record
                continue
            record = mols.GetItemText(i)
            gdb_id = int(record.split()[1])
            if gdb_id in ids_to_remove:
                continue
            out_file.write(record)
            num_accepted += 1
    print('removed {} mols'.format(num_mols - num_accepted))
    
# Remove the molecules listed in uncharacterized-ids.txt and those rdkit cannot
# import from mols.sdf; the remaining records are written to mols_filtered.sdf.
filter_sdf('mols.sdf', 'uncharacterized-ids.txt', 'mols_filtered.sdf')

133656 molecules in total


HBox(children=(IntProgress(value=0, max=133656), HTML(value='')))


removed 5078 mols


In [2]:
# shuffle and partition
def create_partitions(mol_path, val_size=10000, test_size=10000, shuffle=True, seed=None):
    """Split an sdf file into training, validation and test sdf files.

    The records of ``mol_path`` are optionally shuffled and then written to
    ``training.sdf``, ``validation.sdf`` and ``test.sdf`` in the directory of
    ``mol_path``. Existing files with these names are overwritten.

    Args:
        mol_path: path of the sdf file to partition.
        val_size: number of molecules in the validation partition.
        test_size: number of molecules in the test partition.
        shuffle: if True, assign molecules to partitions in random order;
            otherwise keep the file order.
        seed: optional seed for a reproducible shuffle. If None, numpy's global
            random state is used.

    Raises:
        AssertionError: if any of the three partitions would be empty.
    """
    mols = Chem.SDMolSupplier(mol_path, removeHs=False)  # import molecules from sdf file
    num_mols = len(mols)

    # The training partition receives everything not reserved for validation/test.
    train_size = num_mols - val_size - test_size
    assert train_size > 0 and val_size > 0 and test_size > 0, \
        'every partition must contain at least one molecule'

    if not shuffle:
        perm = np.arange(num_mols)
    elif seed is None:
        perm = np.random.permutation(num_mols)
    else:
        # A private generator keeps the global random state untouched.
        perm = np.random.RandomState(seed).permutation(num_mols)

    partitions = (
        ('training.sdf', perm[:train_size]),
        ('validation.sdf', perm[train_size:train_size + val_size]),
        ('test.sdf', perm[train_size + val_size:]),
    )

    out_dir = os.path.dirname(mol_path)
    for file_name, indices in partitions:
        with open(os.path.join(out_dir, file_name), 'w', newline='') as f:
            for i in tqdm(indices):
                # GetItemText is called with a plain Python int, not a numpy integer
                f.write(mols.GetItemText(int(i)))

# Seed numpy's global random generator so that the shuffled split is the same
# on every run of the notebook.
np.random.seed(42)
create_partitions('mols_filtered.sdf')

HBox(children=(IntProgress(value=0, max=108578), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))




In [3]:
# extract labels
def write_labels_for_sdf(mol_path, label_path):
    """Write the label rows that belong to the molecules of an sdf file.

    The output file is named like ``mol_path`` with its extension replaced by
    ``_labels.csv``. It starts with the header line of ``label_path``, followed
    by one label row per molecule in the order of the sdf file.

    The molecule id is read from the second whitespace-separated token of each
    record text, and the label row for id ``k`` is taken to be the ``k``-th
    line after the header of ``label_path``.

    Args:
        mol_path: path of the sdf file whose molecules need labels.
        label_path: path of the csv file with a header line and the label rows.

    Raises:
        IndexError: if a molecule id has no corresponding row in ``label_path``.
    """
    with open(label_path, 'r') as label_file:
        header = label_file.readline()
        labels = label_file.readlines()

    # splitext handles extensions of any length, unlike slicing off four characters
    destination_path = os.path.splitext(mol_path)[0] + '_labels.csv'

    mols = Chem.SDMolSupplier(mol_path, removeHs=False)  # import molecules from sdf file
    num_mols = len(mols)

    # Open the destination once instead of re-opening it for every molecule.
    with open(destination_path, 'w', newline='') as out_file:
        out_file.write(header)
        for i in tqdm(range(num_mols)):
            gdb_id = int(mols.GetItemText(i).split()[1])
            # Guard against ids <= 0, which would silently wrap around via
            # negative indexing, and against ids beyond the label file.
            if not 1 <= gdb_id <= len(labels):
                raise IndexError('no label row for molecule id {}'.format(gdb_id))
            row = labels[gdb_id - 1]
            # The last line of the label file may lack a trailing newline.
            out_file.write(row if row.endswith('\n') else row + '\n')
            
# Write one label file per partition, in the order training, validation, test.
for partition_path in ('training.sdf', 'validation.sdf', 'test.sdf'):
    write_labels_for_sdf(partition_path, 'labels.csv')

HBox(children=(IntProgress(value=0, max=108578), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))


