In [1]:
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
from collections import defaultdict
import random
import pandas as pd


In [2]:
def get_scaffold(smiles: str) -> str:
    mol = Chem.MolFromSmiles(smiles)
    scaffold = MurckoScaffold.GetScaffoldForMol(mol)
    return Chem.MolToSmiles(scaffold, canonical=True)


In [9]:
def scaffold_split(
    df,
    smiles_col="smiles",
    frac_train=0.8,
    seed=42
):
    random.seed(seed)

    scaffold_groups = defaultdict(list)
    for idx, row in df.iterrows():
        scaffold = get_scaffold(row[smiles_col])
        scaffold_groups[scaffold].append(idx)

    scaffolds = list(scaffold_groups.keys())
    random.shuffle(scaffolds)

    train_idx, test_idx = [], []
    n_total = len(df)
    n_train = int(frac_train * n_total)

    for scaffold in scaffolds:
        if len(train_idx) < n_train:
            train_idx.extend(scaffold_groups[scaffold])
        else:
            test_idx.extend(scaffold_groups[scaffold])

    return df.loc[train_idx], df.loc[test_idx]


In [10]:
clean_df = pd.read_csv("training_data.csv")
print(clean_df.columns)

train_df, val_df = scaffold_split(clean_df)


Index(['ACTIVITY', 'smiles'], dtype='str')


In [11]:
print(train_df)
print(val_df)

       ACTIVITY                                             smiles
3974          0        COc1cc(NCC2CCNCC2)nc2c1nnn2-c1cccc(N(C)C)c1
6552          1  FC(F)(F)Oc1cccc(-n2nnc3cc(Cl)c(NCC4CCNCC4)nc32)c1
10066         1    COc1cc(NCC2CCNCC2)nc2c1nnn2-c1cccc(OC(F)(F)F)c1
7189          0               N=C(Nc1ccc2c(c1)CCCN2CCCNCCO)c1cccs1
7200          0                N=C(Nc1ccc2c(c1)CCCN2CCNCCO)c1cccs1
...         ...                                                ...
3282          1      CCN1CCNC(CN2CCN(C(=O)Nc3ccc(Cl)c(Cl)c3)CC2)C1
7327          1  Cc1cccc(C)c1OCC(=O)N[C@H](Cc1ccccc1)[C@H](O)C[...
7871          1  CCN(CC)Cc1cc(C=NN=C(N)CC(O)c2cc3c(F)cc(F)cc3c3...
4765          1      Fc1cc(-c2cc3sc(N4CCC(N5CCCCC5)CC4)nc3cn2)ccn1
6362          1     COc1cc(-c2cc3sc(N4CCC(N5CCCCC5)CC4)nc3cn2)ccn1

[9545 rows x 2 columns]
       ACTIVITY                                             smiles
9884          0  NC1(C(=O)N[C@@H](CCN2CCCCC2)c2ccc(Cl)cc2)CCN(c...
1335          0    Cc1nc2ccncc2c(=O)n