In [1]:
import pandas as pd
from pathlib import Path
from sklearn.model_selection import KFold
from loguru import logger
from tqdm.auto import tqdm
import random
from rdkit.Chem.Scaffolds import MurckoScaffold
from typing import List, Tuple
from typing import Optional, List

In [2]:
raw_data_dir = Path("../datasets/raw")

# Read the three canonicalized raw datasets from raw_data_dir, in a fixed order.
data_consolidation, data_cyanine, data_xanthene = (
    pd.read_csv(raw_data_dir / csv_name)
    for csv_name in (
        "Dataset_Consolidation_canonicalized.csv",
        "Dataset_Cyanine_canonicalized.csv",
        "Dataset_Xanthene_canonicalized.csv",
    )
)

In [3]:
def drop_duplicates(df: pd.DataFrame, columns: List[str]) -> pd.DataFrame:
    """Drop rows that repeat an earlier row on ``columns`` and log row counts.

    The first occurrence of each key combination is kept (pandas default), and the
    result is re-indexed 0..n-1. ``df`` itself is left unmodified.

    :param df: input frame.
    :param columns: column names that together define a duplicate.
    :return: a new, de-duplicated frame with a fresh RangeIndex.
    """
    logger.info(f"before dropping duplicates: {df.shape[0]} rows")
    deduped = df.drop_duplicates(subset=columns).reset_index(drop=True)
    logger.info(f"after dropping duplicates: {deduped.shape[0]} rows")
    return deduped

# De-duplicate every dataset on the (smiles, solvent) pair, in the original order.
data_consolidation, data_cyanine, data_xanthene = [
    drop_duplicates(frame, ["smiles", "solvent"])
    for frame in (data_consolidation, data_cyanine, data_xanthene)
]

[32m2025-04-05 17:19:56.883[0m | [1mINFO    [0m | [36m__main__[0m:[36mdrop_duplicates[0m:[36m2[0m - [1mbefore dropping duplicates: 36750 rows[0m


[32m2025-04-05 17:19:56.909[0m | [1mINFO    [0m | [36m__main__[0m:[36mdrop_duplicates[0m:[36m5[0m - [1mafter dropping duplicates: 36735 rows[0m
[32m2025-04-05 17:19:56.911[0m | [1mINFO    [0m | [36m__main__[0m:[36mdrop_duplicates[0m:[36m2[0m - [1mbefore dropping duplicates: 1496 rows[0m
[32m2025-04-05 17:19:56.915[0m | [1mINFO    [0m | [36m__main__[0m:[36mdrop_duplicates[0m:[36m5[0m - [1mafter dropping duplicates: 1496 rows[0m
[32m2025-04-05 17:19:56.916[0m | [1mINFO    [0m | [36m__main__[0m:[36mdrop_duplicates[0m:[36m2[0m - [1mbefore dropping duplicates: 1152 rows[0m
[32m2025-04-05 17:19:56.918[0m | [1mINFO    [0m | [36m__main__[0m:[36mdrop_duplicates[0m:[36m5[0m - [1mafter dropping duplicates: 1146 rows[0m


In [4]:
def random_split(
    df: pd.DataFrame,
    save_dir: Path,
    name: str,
    random_state: int = 42,
    n_splits: int = 5,
) -> None:
    """Write ``n_splits`` random K-fold CSVs for ``df`` under ``save_dir / "random"``.

    For every fold a file ``{name}_fold{fold}.csv`` is written. It contains every
    row of ``df`` tagged in a ``split`` column as ``"train"`` or ``"valid"``, plus a
    second copy of the validation rows tagged ``"test"``. The test set is
    therefore identical to the validation set, and each file has
    ``len(df) + n_valid`` rows.

    :param df: data to split. Rows are split by position; the index is ignored
        and is not written to disk.
    :param save_dir: parent directory; its ``random`` subdirectory is created if missing.
    :param name: prefix used for the output file names and log messages.
    :param random_state: seed for the KFold shuffle.
    :param n_splits: number of folds (defaults to the previously hard-coded 5).
    """
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    _save_dir = save_dir / "random"
    _save_dir.mkdir(parents=True, exist_ok=True)

    # KFold yields positional indices. Reset to a RangeIndex so that `.loc` labels
    # coincide with positions even when `df` carries a non-default index.
    base = df.reset_index(drop=True)

    for fold, (train_index, valid_index) in enumerate(kf.split(base)):
        _df = base.copy()
        _df.loc[valid_index, "split"] = "valid"  # this fold is held out as valid
        _df.loc[train_index, "split"] = "train"  # the remaining folds are train
        # The test set is a duplicate of the validation rows.
        df_test = _df[_df["split"] == "valid"].copy()
        df_test["split"] = "test"
        _df = pd.concat([_df, df_test], ignore_index=True)
        _df.to_csv(_save_dir / f"{name}_fold{fold}.csv", index=False)

        counts = _df["split"].value_counts()
        n_total = len(_df)
        n_test = int(counts.get("test", 0))
        n_valid = int(counts.get("valid", 0))
        n_train = int(counts.get("train", 0))

        logger.info(
            f"length of {name}_fold{fold}: {n_total}; length of train: {n_train}; length of valid: {n_valid}; length of test: {n_test}"
        )

In [5]:
# Write random K-fold CSVs for each dataset under raw_data_dir / "random".
random_split(df=data_consolidation, save_dir=raw_data_dir, name="consolidation")
random_split(df=data_cyanine, save_dir=raw_data_dir, name="cyanine")
random_split(df=data_xanthene, save_dir=raw_data_dir, name="xanthene")

[32m2025-04-05 17:19:57.252[0m | [1mINFO    [0m | [36m__main__[0m:[36mrandom_split[0m:[36m21[0m - [1mlength of consolidation_fold0: 44082; length of train: 29388; length of valid: 7347; length of test: 7347[0m
[32m2025-04-05 17:19:57.495[0m | [1mINFO    [0m | [36m__main__[0m:[36mrandom_split[0m:[36m21[0m - [1mlength of consolidation_fold1: 44082; length of train: 29388; length of valid: 7347; length of test: 7347[0m
[32m2025-04-05 17:19:57.700[0m | [1mINFO    [0m | [36m__main__[0m:[36mrandom_split[0m:[36m21[0m - [1mlength of consolidation_fold2: 44082; length of train: 29388; length of valid: 7347; length of test: 7347[0m
[32m2025-04-05 17:19:57.866[0m | [1mINFO    [0m | [36m__main__[0m:[36mrandom_split[0m:[36m21[0m - [1mlength of consolidation_fold3: 44082; length of train: 29388; length of valid: 7347; length of test: 7347[0m
[32m2025-04-05 17:19:58.043[0m | [1mINFO    [0m | [36m__main__[0m:[36mrandom_split[0m:[36m21[0m - [1m

In [6]:
def generate_scaffold(smiles, include_chirality=False):
    """Return the Bemis-Murcko scaffold of a molecule as a SMILES string.

    Thin wrapper around RDKit's ``MurckoScaffold.MurckoScaffoldSmiles``.

    :param smiles: SMILES string of the input molecule.
    :param include_chirality: forwarded as RDKit's ``includeChirality`` flag.
    :return: scaffold SMILES. Any exception raised by RDKit propagates to the caller
        (presumably including unparsable SMILES — confirm against the RDKit version in use).
    """
    return MurckoScaffold.MurckoScaffoldSmiles(
        includeChirality=include_chirality,
        smiles=smiles,
    )

In [7]:
def scaffold_split(
    smiles_list: List[str],
    k: int = 5,  # Number of folds
    balanced: bool = True,
    seed: int = 42,
) -> Tuple[List[Tuple[List[int], List[int]]], List[Optional[str]]]:
    """Split molecules into ``k`` cross-validation folds grouped by Murcko scaffold.

    Every molecule sharing a scaffold is placed on the same side (train or valid)
    of each fold, because KFold is run over scaffold groups rather than molecules.
    Fold sizes can still differ, since scaffold groups differ in size.

    :param smiles_list: SMILES strings. Indices in the returned folds are
        positions in this list.
    :param k: number of folds.
    :param balanced: if True, shuffle the scaffold groups with ``seed`` before
        passing them to KFold. Despite the name, this does not size-balance the folds.
    :param seed: seed for the pre-shuffle and for KFold.
    :return: ``(folds, scaffolds)``.
        ``folds`` holds ``k`` ``(train_indices, valid_indices)`` pairs.
        ``scaffolds[i]`` is the scaffold of ``smiles_list[i]``, or ``None`` if it
        could not be computed. ``scaffolds`` always has ``len(smiles_list)`` entries.
    :raises ValueError: from KFold, e.g. when there are fewer scaffold groups than ``k``.
    """
    # Map each scaffold key to the positions of the molecules that share it.
    scaffold_groups = {}
    scaffolds: List[Optional[str]] = []
    for i, smiles in enumerate(tqdm(smiles_list)):
        try:
            scaffold = generate_scaffold(smiles, include_chirality=True)
        except Exception:
            logger.warning(f"Error generating scaffold for {smiles}")
            # Keep `scaffolds` aligned with `smiles_list`, because callers assign it
            # as a DataFrame column. Still place the molecule in a fold: group it by
            # its own SMILES so repeated rows of the same molecule stay together.
            # A tuple key cannot collide with a real (string) scaffold.
            scaffolds.append(None)
            key = ("__no_scaffold__", smiles)
        else:
            scaffolds.append(scaffold)
            key = scaffold
        scaffold_groups.setdefault(key, []).append(i)

    # Group indices by scaffold, in first-seen order.
    scaffold_sets = list(scaffold_groups.values())

    if balanced:
        # A local RNG gives the same permutation as random.seed(seed) followed by
        # random.shuffle(...), without resetting the global `random` state.
        random.Random(seed).shuffle(scaffold_sets)

    # Run KFold over scaffold groups, so each group lands wholly in train or valid.
    kf = KFold(n_splits=k, shuffle=True, random_state=seed)
    folds = []
    for train_groups, valid_groups in kf.split(scaffold_sets):
        train_fold = [idx for g in train_groups for idx in scaffold_sets[g]]
        val_fold = [idx for g in valid_groups for idx in scaffold_sets[g]]
        folds.append((train_fold, val_fold))

    return folds, scaffolds

In [10]:
def scaffold_split_df(
    df: pd.DataFrame,
    name: str,
    k: int = 5,
    balanced: bool = True,
    seed: int = 42,
    save_dir: Path = raw_data_dir,
) -> None:
    """Write ``k`` scaffold-grouped K-fold CSVs for ``df`` under ``save_dir / "scaffold"``.

    Side effect: a ``scaffold`` column is added to (or overwritten on) ``df`` in place.

    For every fold a file ``{name}_fold{fold}.csv`` is written. It contains every
    row tagged in a ``split`` column as ``"train"`` or ``"valid"``, plus a second
    copy of the validation rows tagged ``"test"``. The test set is identical to
    the validation set.

    :param df: data with a ``smiles`` column. Rows are split by position; the
        index is not written to disk.
    :param name: prefix used for the output file names and log messages.
    :param k: number of folds.
    :param balanced: forwarded to ``scaffold_split``.
    :param seed: forwarded to ``scaffold_split``.
    :param save_dir: parent directory; its ``scaffold`` subdirectory is created if
        missing. The default is bound to ``raw_data_dir`` when the function is defined.
    :raises ValueError: if ``scaffold_split`` returns a scaffold list whose length
        differs from ``len(df)``.
    """
    smiles_list = df["smiles"].tolist()
    folds, scaffolds = scaffold_split(smiles_list, k=k, balanced=balanced, seed=seed)
    if len(scaffolds) != len(df):
        raise ValueError(
            f"scaffold_split returned {len(scaffolds)} scaffolds for {len(df)} rows; "
            "cannot align them with the DataFrame"
        )
    df["scaffold"] = scaffolds

    _save_dir = save_dir / "scaffold"
    _save_dir.mkdir(parents=True, exist_ok=True)

    n_scaffolds = df["scaffold"].nunique()
    n_smiles_unique = df["smiles"].nunique()
    logger.info(f"number of scaffolds: {n_scaffolds}")
    logger.info(f"number of smiles: {n_smiles_unique}")

    # Fold indices are positions in `smiles_list`. Reset to a RangeIndex so that
    # `.loc` labels coincide with positions even for a non-default index.
    base = df.reset_index(drop=True)

    for fold, (train_idx, val_idx) in enumerate(folds):
        _df = base.copy()
        _df.loc[val_idx, "split"] = "valid"
        _df.loc[train_idx, "split"] = "train"
        n_unassigned = int(_df["split"].isna().sum())
        if n_unassigned:
            logger.warning(
                f"{name}_fold{fold}: {n_unassigned} rows were not assigned to any split"
            )
        # The test set is a duplicate of the validation rows.
        df_test = _df[_df["split"] == "valid"].copy()
        df_test["split"] = "test"
        _df = pd.concat([_df, df_test], ignore_index=True)
        _df.to_csv(_save_dir / f"{name}_fold{fold}.csv", index=False)

        counts = _df["split"].value_counts()
        n_total = len(_df)
        n_test = int(counts.get("test", 0))
        n_valid = int(counts.get("valid", 0))
        n_train = int(counts.get("train", 0))

        logger.info(
            f"length of {name}_fold{fold}: {n_total}; length of train: {n_train}; length of valid: {n_valid}; length of test: {n_test}"
        )

In [11]:
scaffold_split_df(df=data_consolidation, name="consolidation")  # scaffold K-fold CSVs for the consolidation set

  0%|          | 0/36735 [00:00<?, ?it/s]

[32m2025-04-05 17:20:55.832[0m | [1mINFO    [0m | [36m__main__[0m:[36mscaffold_split_df[0m:[36m18[0m - [1mnumber of scaffolds: 9984[0m
[32m2025-04-05 17:20:55.834[0m | [1mINFO    [0m | [36m__main__[0m:[36mscaffold_split_df[0m:[36m19[0m - [1mnumber of smiles: 25128[0m
[32m2025-04-05 17:20:56.045[0m | [1mINFO    [0m | [36m__main__[0m:[36mscaffold_split_df[0m:[36m35[0m - [1mlength of consolidation_fold0: 43183; length of train: 30287; length of valid: 6448; length of test: 6448[0m
[32m2025-04-05 17:20:56.238[0m | [1mINFO    [0m | [36m__main__[0m:[36mscaffold_split_df[0m:[36m35[0m - [1mlength of consolidation_fold1: 43568; length of train: 29902; length of valid: 6833; length of test: 6833[0m
[32m2025-04-05 17:20:56.433[0m | [1mINFO    [0m | [36m__main__[0m:[36mscaffold_split_df[0m:[36m35[0m - [1mlength of consolidation_fold2: 44247; length of train: 29223; length of valid: 7512; length of test: 7512[0m
[32m2025-04-05 17:20:56.62

In [12]:
scaffold_split_df(df=data_cyanine, name="cyanine")  # scaffold K-fold CSVs for the cyanine set

  0%|          | 0/1496 [00:00<?, ?it/s]

[32m2025-04-05 17:21:06.484[0m | [1mINFO    [0m | [36m__main__[0m:[36mscaffold_split_df[0m:[36m18[0m - [1mnumber of scaffolds: 385[0m
[32m2025-04-05 17:21:06.486[0m | [1mINFO    [0m | [36m__main__[0m:[36mscaffold_split_df[0m:[36m19[0m - [1mnumber of smiles: 792[0m
[32m2025-04-05 17:21:06.516[0m | [1mINFO    [0m | [36m__main__[0m:[36mscaffold_split_df[0m:[36m35[0m - [1mlength of cyanine_fold0: 1767; length of train: 1225; length of valid: 271; length of test: 271[0m
[32m2025-04-05 17:21:06.545[0m | [1mINFO    [0m | [36m__main__[0m:[36mscaffold_split_df[0m:[36m35[0m - [1mlength of cyanine_fold1: 1839; length of train: 1153; length of valid: 343; length of test: 343[0m
[32m2025-04-05 17:21:06.572[0m | [1mINFO    [0m | [36m__main__[0m:[36mscaffold_split_df[0m:[36m35[0m - [1mlength of cyanine_fold2: 1724; length of train: 1268; length of valid: 228; length of test: 228[0m
[32m2025-04-05 17:21:06.600[0m | [1mINFO    [0m | [36m_

In [13]:
scaffold_split_df(df=data_xanthene, name="xanthene")  # scaffold K-fold CSVs for the xanthene set

  0%|          | 0/1146 [00:00<?, ?it/s]

[32m2025-04-05 17:21:08.798[0m | [1mINFO    [0m | [36m__main__[0m:[36mscaffold_split_df[0m:[36m18[0m - [1mnumber of scaffolds: 278[0m
[32m2025-04-05 17:21:08.799[0m | [1mINFO    [0m | [36m__main__[0m:[36mscaffold_split_df[0m:[36m19[0m - [1mnumber of smiles: 704[0m
[32m2025-04-05 17:21:08.822[0m | [1mINFO    [0m | [36m__main__[0m:[36mscaffold_split_df[0m:[36m35[0m - [1mlength of xanthene_fold0: 1377; length of train: 915; length of valid: 231; length of test: 231[0m
[32m2025-04-05 17:21:08.845[0m | [1mINFO    [0m | [36m__main__[0m:[36mscaffold_split_df[0m:[36m35[0m - [1mlength of xanthene_fold1: 1328; length of train: 964; length of valid: 182; length of test: 182[0m
[32m2025-04-05 17:21:08.868[0m | [1mINFO    [0m | [36m__main__[0m:[36mscaffold_split_df[0m:[36m35[0m - [1mlength of xanthene_fold2: 1307; length of train: 985; length of valid: 161; length of test: 161[0m
[32m2025-04-05 17:21:08.893[0m | [1mINFO    [0m | [36m_