In [11]:
import torch
import random
import numpy as np
from itertools import compress
from rdkit.Chem.Scaffolds import MurckoScaffold
from collections import defaultdict
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs
from rdkit.ML.Cluster import Butina
from tqdm import tqdm

In [23]:
def generate_scaffold(smiles, include_chirality=False):
    """
    Compute the Bemis-Murcko scaffold of a molecule.

    :param smiles: SMILES string of the input molecule
    :param include_chirality: when True, stereochemistry is kept in the scaffold SMILES
    :return: SMILES string of the scaffold, as produced by RDKit's MurckoScaffoldSmiles
    """
    return MurckoScaffold.MurckoScaffoldSmiles(
        smiles=smiles,
        includeChirality=include_chirality,
    )


# # test generate_scaffold
# s = 'Cc1cc(Oc2nccc(CCC)c2)ccc1'
# scaffold = generate_scaffold(s)
# assert scaffold == 'c1ccc(Oc2ccccn2)cc1'

def _condensed_tanimoto_distances(fps):
    """
    Flat lower-triangle Tanimoto distance matrix (1 - similarity).

    Entries are ordered (1,0), (2,0), (2,1), (3,0), ... which is the layout passed
    to Butina.ClusterData with isDistData=True.
    :param fps: list of RDKit bit-vector fingerprints
    :return: array('d') holding n*(n-1)/2 distances
    """
    # Packed doubles instead of one Python float object per entry: for ~50k
    # molecules this matrix has ~1.25e9 entries, so the saving is large.
    from array import array

    dists = array('d')
    for i in tqdm(range(1, len(fps)), desc="Calculating similarities"):
        sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
        dists.extend(1 - x for x in sims)
    return dists


def _greedy_cluster_split(cluster_sets, n_items, frac_train, frac_valid):
    """
    Assign whole clusters to train, then valid, then test.

    :param cluster_sets: lists of row positions, one list per cluster, in the
        order they should be considered (largest cluster first)
    :param n_items: total number of rows being split
    :param frac_train: target fraction of rows in train
    :param frac_valid: target fraction of rows in valid
    :return: (train_idx, valid_idx, test_idx) lists of row positions
    """
    train_cutoff = frac_train * n_items
    valid_cutoff = (frac_train + frac_valid) * n_items
    train_idx, valid_idx, test_idx = [], [], []
    for members in cluster_sets:
        if len(train_idx) + len(members) <= train_cutoff:
            train_idx.extend(members)
        elif len(train_idx) + len(valid_idx) + len(members) <= valid_cutoff:
            valid_idx.extend(members)
        else:
            # A cluster that overflows valid goes to test; later (smaller)
            # clusters may still fit into train or valid.
            test_idx.extend(members)
    return train_idx, valid_idx, test_idx


def scaffold_split(dataset, smiles_list, task_idx=None, null_value=0,
                   frac_train=0.8, frac_valid=0.1, frac_test=0.1,
                   cutoff=0.4, radius=2, n_bits=1024):
    """
    Split a DataFrame into train/valid/test by clustering Bemis-Murcko scaffolds.

    The chiral scaffold of every molecule is fingerprinted (Morgan bit vector),
    scaffolds are clustered with the Butina algorithm on Tanimoto distance, and
    whole clusters are assigned greedily (largest first) to train, then valid,
    then test, so no cluster is shared between splits. Because clusters are kept
    whole, the realized fractions only approximate the requested ones.

    :param dataset: pandas DataFrame with one row per molecule. It is NOT modified;
        the returned frames are slices of a copy that has two extra columns,
        'scaffold' and 'cluster'.
    :param smiles_list: iterable of SMILES strings aligned row-by-row with dataset
    :param task_idx: optional column label of dataset; rows whose value in that
        column equals null_value are dropped before splitting. None keeps all rows.
    :param null_value: value marking a missing label in the task_idx column
    :param frac_train: fraction of rows for train
    :param frac_valid: fraction of rows for valid
    :param frac_test: fraction of rows for test (the three must sum to 1)
    :param cutoff: Butina distance threshold, in units of 1 - Tanimoto similarity
    :param radius: Morgan fingerprint radius
    :param n_bits: Morgan fingerprint length
    :return: (train_df, valid_df, test_df); rows keep dataset's index labels
    :raises ValueError: if smiles_list and dataset have different lengths
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)

    smiles_list = list(smiles_list)
    if len(smiles_list) != len(dataset):
        raise ValueError(
            f"smiles_list has {len(smiles_list)} entries but dataset has {len(dataset)} rows")

    # Row filter. dataset is a DataFrame, so the task label is read as a column.
    if task_idx is not None:
        keep = (dataset[task_idx] != null_value).to_numpy()
    else:
        keep = np.ones(len(dataset), dtype=bool)

    # Work on a filtered copy: the caller's frame is left untouched, and row
    # positions in `frame` line up with the scaffolds, fingerprints and every
    # index list computed below.
    frame = dataset.iloc[keep].copy()
    frame['scaffold'] = [generate_scaffold(s, include_chirality=True)
                         for s in compress(smiles_list, keep)]

    fps = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(s), radius, nBits=n_bits)
           for s in frame['scaffold']]
    print("fingerprint completed")
    n = len(fps)
    dists = _condensed_tanimoto_distances(fps)

    print("start clustering")
    clusters = Butina.ClusterData(dists, n, cutoff, isDistData=True)
    cluster_labels = np.zeros(n, dtype=int)
    for cluster_id, members in enumerate(clusters, start=1):
        for pos in members:
            cluster_labels[pos] = cluster_id
    frame['cluster'] = cluster_labels
    print("clustering completed")

    # Group row positions by cluster label. Positions are visited in ascending
    # order, so each group is already sorted. Clusters are then ordered largest
    # first; ties are broken by first position (descending).
    groups = defaultdict(list)
    for pos, label in enumerate(cluster_labels):
        groups[label].append(pos)
    cluster_sets = sorted(groups.values(), key=lambda s: (len(s), s[0]), reverse=True)

    train_idx, valid_idx, test_idx = _greedy_cluster_split(
        cluster_sets, n, frac_train, frac_valid)

    assert not set(train_idx) & set(valid_idx)
    assert not set(train_idx) & set(test_idx)
    assert not set(valid_idx) & set(test_idx)
    assert len(train_idx) + len(valid_idx) + len(test_idx) == n

    return frame.iloc[train_idx], frame.iloc[valid_idx], frame.iloc[test_idx]

In [25]:
# Pool the original USPTO-50k train/val/test splits so they can be re-split by
# scaffold cluster, then write the new splits next to the originals.
USPTO50K_RAW_DIR = 'data/uspto50k/raw'

df_train, df_val, df_test = (
    pd.read_csv(f'{USPTO50K_RAW_DIR}/uspto50k_{split}_application.csv', index_col=False)
    for split in ('train', 'val', 'test')
)
dataset = pd.concat([df_train, df_val, df_test], ignore_index=True)

# The third '>'-separated field of the reaction SMILES is the product.
smiles_list = dataset['reactants>reagents>production'].apply(lambda rxn: rxn.split('>')[2])
print(len(dataset))

train_dataset, valid_dataset, test_dataset = scaffold_split(
    dataset, smiles_list,
    task_idx=None, null_value=0,
    frac_train=0.8, frac_valid=0.1, frac_test=0.1,
)

for split, split_frame in (('train', train_dataset),
                           ('val', valid_dataset),
                           ('test', test_dataset)):
    split_frame.to_csv(f'{USPTO50K_RAW_DIR}/uspto50k_{split}_scaffold.csv', index=False)

50016
fingerprint completed


Calculating similarities: 100%|██████████| 50015/50015 [11:49<00:00, 70.46it/s]  


start clutering
cluter completed
