In [11]:
import pandas as pd
import numpy as np
import deepchem as dc
from deepchem.splits.splitters import generate_scaffold

In [2]:
# Load the fine-tuning dataset; the scaffold split below reads its 'smiles' column.
# NOTE: pd.read_pickle unpickles arbitrary objects — only load files you trust.
ft_data_path = "./ft_data.pkl"
data = pd.read_pickle(ft_data_path)

In [3]:
# One scaffold per molecule (via DeepChem's generate_scaffold), positionally
# aligned with the rows of `data`.
scaffolds = list(map(generate_scaffold, data['smiles']))

In [4]:
# Group row positions by scaffold: {scaffold: [positions of rows sharing it]}.
scaffolds_dict = {}
for row_pos, scaffold_key in enumerate(scaffolds):
    scaffolds_dict.setdefault(scaffold_key, []).append(row_pos)

In [5]:
# Sort each group's member positions, then order the groups by size
# (largest first); equal-size groups are ordered by their first member
# position, also descending because of reverse=True.
scaffolds_dict = {
    scaffold_key: sorted(member_positions)
    for scaffold_key, member_positions in scaffolds_dict.items()
}
scaffold_sets = sorted(
    scaffolds_dict.values(),
    key=lambda group: (len(group), group[0]),
    reverse=True,
)

In [6]:
# Target split fractions (train / valid / test). The cutoffs are row counts:
# scaffold groups fill train up to train_cutoff, then train+valid up to
# valid_cutoff; everything else goes to test.
frac_train, frac_valid, frac_test = 0.6, 0.2, 0.2
# frac_test is only implied by the cutoffs (it is never used directly), so
# guard against the three fractions drifting out of sync when edited.
assert np.isclose(frac_train + frac_valid + frac_test, 1.0), "split fractions must sum to 1"
train_cutoff = frac_train * len(data)
valid_cutoff = (frac_train + frac_valid) * len(data)
train_inds, valid_inds, test_inds = [], [], []

In [7]:
# Greedily assign scaffold groups (in scaffold_sets order) to train, then
# valid, then test. A whole group always lands in a single split, so no
# scaffold is shared between splits.
# The lists are (re)initialised here so re-running this cell on its own does
# not append every index a second time (which would leak train rows into test).
train_inds, valid_inds, test_inds = [], [], []
for scaffold_set in scaffold_sets:
    if len(train_inds) + len(scaffold_set) > train_cutoff:
        if len(train_inds) + len(valid_inds) + len(scaffold_set) > valid_cutoff:
            test_inds += scaffold_set
        else:
            valid_inds += scaffold_set
    else:
        train_inds += scaffold_set

# Every row must land in exactly one split.
assert set(train_inds).isdisjoint(valid_inds)
assert set(train_inds).isdisjoint(test_inds)
assert set(valid_inds).isdisjoint(test_inds)
assert len(train_inds) + len(valid_inds) + len(test_inds) == len(data)

In [8]:
# Sizes of the (train, valid, test) scaffold splits.
tuple(len(split) for split in (train_inds, valid_inds, test_inds))

(5808, 1936, 1936)

In [9]:
# The split indices are row *positions* in `data` (they enumerate the
# `scaffolds` list built from data['smiles']), so select with .iloc.
# .loc would treat them as index labels, which picks the wrong rows or raises
# KeyError whenever `data` does not carry a default 0..n-1 RangeIndex.
# The validation scaffolds are folded into the training frame; only the test
# scaffolds are held out.
train_data = data.iloc[train_inds + valid_inds]
test_data = data.iloc[test_inds]

In [10]:
# Persist the training frame (train + valid scaffolds) and the held-out test frame.
for split_frame, split_path in ((train_data, "./ft_data_train.pkl"),
                                (test_data, "./ft_data_test.pkl")):
    split_frame.to_pickle(split_path)