In [None]:
from dgllife.utils import ScaffoldSplitter,RandomSplitter
import pandas as pd
from sklearn.model_selection import train_test_split
from typing import Tuple, Dict, List, Set, Union
from collections import defaultdict
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
from random import Random


In [None]:
# Random split: ~80% train, ~10% validation, ~10% test.
def train_test_validation_split(df, seed: Union[int, None] = None):
    """Randomly split ``df`` into roughly 80% train, 10% validation and 10% test.

    Args:
        df: pandas object to split (rows are shuffled and partitioned).
        seed: Forwarded as ``random_state`` to both ``train_test_split`` calls.
            ``None`` (the default) keeps the previous unseeded behaviour; pass an
            int to make the split reproducible across kernel restarts.

    Returns:
        ``(train, validation, test)``, each with a fresh 0..n-1 index.
    """
    # Hold out 20% first, then halve the hold-out into test and validation.
    train_data, rest_data = train_test_split(df, test_size=0.2, random_state=seed)
    test_data, validation_data = train_test_split(rest_data, test_size=0.5, random_state=seed)
    return (train_data.reset_index(drop=True),
            validation_data.reset_index(drop=True),
            test_data.reset_index(drop=True))

In [None]:
# Scaffold split: rows sharing a Bemis–Murcko scaffold always land in the same subset.
def generate_scaffold(mol: Union[str, Chem.Mol], include_chirality: bool = False) -> str:
    """Return the Bemis–Murcko scaffold SMILES of a molecule.

    Args:
        mol: A SMILES string (any ``str`` subclass, e.g. ``numpy.str_``) or an RDKit ``Chem.Mol``.
        include_chirality: Forwarded to ``MurckoScaffoldSmiles(includeChirality=...)``.

    Returns:
        The scaffold SMILES produced by ``MurckoScaffold.MurckoScaffoldSmiles``.

    Raises:
        ValueError: if ``mol`` is a string that RDKit cannot parse as SMILES.
    """
    # isinstance (not ``type(mol) == str``) so str subclasses are parsed instead of
    # being handed to RDKit as if they were Mol objects.
    if isinstance(mol, str):
        smiles = str(mol)
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"Could not parse SMILES: {smiles!r}")
    return MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=include_chirality)

    
def scaffold_to_smiles(mols: Union[List[str], List[Chem.Mol]],
                       use_indices: bool = True) -> Dict[str, Union[Set[str], Set[int]]]:
    """Group molecules by their Bemis–Murcko scaffold.

    Args:
        mols: SMILES strings or RDKit molecules.
        use_indices: If True, each group stores positions into ``mols``; otherwise it
            stores the molecule entries themselves.

    Returns:
        A ``defaultdict(set)`` mapping scaffold SMILES to the set of members.
    """
    groups = defaultdict(set)
    progress = tqdm(enumerate(mols), total=len(mols))
    for position, molecule in progress:
        member = position if use_indices else molecule
        groups[generate_scaffold(molecule)].add(member)
    return groups

def scaffold_split(data: pd.DataFrame,
                   sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1),
                   balanced: bool = True,
                   seed: int = 2020,
                   smiles_column: str = 'smiles',
                   ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split ``data`` into train/validation/test so that no scaffold spans two subsets.

    Rows are grouped by the Bemis–Murcko scaffold of their SMILES. Whole groups are
    then assigned greedily in a fixed order: a group goes to train if train stays
    within its quota, otherwise to validation if validation stays within its quota,
    otherwise to test.

    Args:
        data: Input frame. Must contain ``smiles_column``; every other column is
            carried through unchanged (names and dtypes preserved).
        sizes: (train, validation, test) fractions of ``len(data)``; must sum to 1
            (checked with a small floating-point tolerance).
        balanced: If True, groups larger than half the validation or test quota are
            shuffled and placed first, followed by the shuffled remaining groups.
            If False, groups are ordered from largest to smallest.
        seed: Seed for the shuffles used when ``balanced`` is True.
        smiles_column: Name of the column holding SMILES strings.

    Returns:
        ``(train, val, test)`` DataFrames with the same columns as ``data`` and a
        fresh 0..n-1 index.
    """
    # Exact ``== 1`` can reject valid fractions because of float rounding.
    assert abs(sum(sizes) - 1.0) < 1e-8, f'sizes must sum to 1, got {sizes}'

    n_rows = len(data)
    train_size, val_size, test_size = sizes[0] * n_rows, sizes[1] * n_rows, sizes[2] * n_rows
    train, val, test = [], [], []
    train_scaffold_count, val_scaffold_count, test_scaffold_count = 0, 0, 0

    # Map from scaffold SMILES to the set of row positions sharing it
    scaffold_to_indices = scaffold_to_smiles(data[smiles_column].to_list(), use_indices=True)

    if balanced:
        # Groups bigger than half the val/test quota go first (shuffled), so they are
        # considered for train before the small groups (also shuffled).
        rng = Random(seed)
        big_index_sets, small_index_sets = [], []
        for index_set in scaffold_to_indices.values():
            if len(index_set) > val_size / 2 or len(index_set) > test_size / 2:
                big_index_sets.append(index_set)
            else:
                small_index_sets.append(index_set)
        rng.shuffle(big_index_sets)
        rng.shuffle(small_index_sets)
        index_sets = big_index_sets + small_index_sets
    else:
        # Sort from largest to smallest scaffold group (stable for ties)
        index_sets = sorted(scaffold_to_indices.values(), key=len, reverse=True)

    for index_set in index_sets:
        if len(train) + len(index_set) <= train_size:
            train += index_set
            train_scaffold_count += 1
        elif len(val) + len(index_set) <= val_size:
            val += index_set
            val_scaffold_count += 1
        else:
            test += index_set
            test_scaffold_count += 1
    print(f'Total scaffolds = {len(scaffold_to_indices):,} | '
          f'train scaffolds = {train_scaffold_count:,} | '
          f'val scaffolds = {val_scaffold_count:,} | '
          f'test scaffolds = {test_scaffold_count:,}')

    # Positional selection keeps the frame's own column names and dtypes. Hard-coded
    # ['smiles', 'basic_pka'] columns broke other datasets, and ``data.values`` upcasts
    # mixed frames to object dtype and may rebuild the array on every access.
    train_df = data.iloc[train].reset_index(drop=True)
    val_df = data.iloc[val].reset_index(drop=True)
    test_df = data.iloc[test].reset_index(drop=True)
    print(f'train.shape: {train_df.shape}, valid.shape: {val_df.shape}, test.shape: {test_df.shape}')

    assert train_df.shape[0] + val_df.shape[0] + test_df.shape[0] == data.shape[0]
    return train_df, val_df, test_df

In [None]:
# Load the raw datasets from CSV files (relative paths: resolved against the working directory).
logD = pd.read_csv(filepath_or_buffer="processed_chembl_29_logD(M-data).csv")
logp = pd.read_csv(filepath_or_buffer="logp.csv")
# NOTE(review): RT is loaded but not split in the cells visible here — confirm it is used later.
RT = pd.read_csv(filepath_or_buffer="RT.csv")

In [None]:
# Scaffold-split each dataset with scaffold_split's default fractions and seed
# (suffix _a = logD, _b = logp).
train_a, valid_a, test_a = scaffold_split(data=logD)
train_b, valid_b, test_b = scaffold_split(data=logp)