## Setup

In [1]:
# Re-import edited project modules (e.g. admet.model.*) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
import logging
import itertools
import warnings

import pandas as pd
import numpy as np
from tqdm.auto import tqdm

import matplotlib.pyplot as plt
import seaborn as sns

from statsmodels.stats.multicomp import pairwise_tukeyhsd
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import r2_score, mean_absolute_error

import useful_rdkit_utils as uru
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
from rdkit.DataStructs import BulkTanimotoSimilarity

from admet.model.lgbm_wrapper import LGBMMorganCountWrapper

In [3]:
%matplotlib inline  # render matplotlib/seaborn figures as static images in the cell output

In [4]:
# Register tqdm progress-bar variants on pandas objects (e.g. Series.progress_apply),
# so long row-wise operations can report progress.
tqdm.pandas()

In [5]:
# Notebook logger: timestamped records at DEBUG level, emitted through a single stream handler.
level = logging.DEBUG
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

logger = logging.getLogger(__name__)
# Drop any handlers attached by a previous run of this cell so records are not duplicated.
logger.handlers.clear()

ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
logger.setLevel(level)

logger.info("Imports successful.")

2025-11-15 20:17:02,831 - __main__ - INFO - Imports successful.


## Load Data

In [6]:
# Data input and output directories.
# Paths are resolved relative to the current working directory's parent
# (assumes the notebook is launched from a direct subdirectory of the repo root).
base_data_dir = Path.cwd().parents[0] / "assets/dataset/eda/data/set"
output_dir = base_data_dir.parents[2] / "splits"

# Validate the input location before touching the filesystem, so a wrong working
# directory fails fast instead of first creating a stray output directory.
if not base_data_dir.exists():
    raise FileNotFoundError(f"Data directory not found at {base_data_dir}")
output_dir.mkdir(parents=True, exist_ok=True)

logger.info(f"Output directory set to {output_dir}")
logger.info(f"Input data directory found at {base_data_dir}")
# iterdir() order is filesystem-dependent; sort for a stable, reproducible listing.
for dataset_dir in sorted(base_data_dir.iterdir()):
    logger.info(f"Dataset name: {dataset_dir.name}")

2025-11-15 20:17:02,858 - __main__ - INFO - Output directory set to /media/aglisman/Linux_Overflow/home/aglisman/VSCodeProjects/OpenADMET-ExpansionRx-Blind-Challenge/assets/dataset/splits
2025-11-15 20:17:02,859 - __main__ - INFO - Input data directory found at /media/aglisman/Linux_Overflow/home/aglisman/VSCodeProjects/OpenADMET-ExpansionRx-Blind-Challenge/assets/dataset/eda/data/set
2025-11-15 20:17:02,859 - __main__ - INFO - Dataset name: cleaned_combined_datasets_low_quality_summary_table.csv
2025-11-15 20:17:02,859 - __main__ - INFO - Dataset name: cleaned_combined_datasets_medium_quality_summary_table.csv
2025-11-15 20:17:02,860 - __main__ - INFO - Dataset name: cleaned_combined_datasets_medium_quality.csv
2025-11-15 20:17:02,860 - __main__ - INFO - Dataset name: cleaned_combined_datasets_low_medium_high_quality.csv
2025-11-15 20:17:02,860 - __main__ - INFO - Dataset name: cleaned_combined_datasets_high_quality.csv
2025-11-15 20:17:02,860 - __main__ - INFO - Dataset name: cleaned

In [7]:
fpgen = rdFingerprintGenerator.GetMorganGenerator()  # Morgan fingerprint generator with RDKit's default settings; NOTE(review): not used in the visible cells — TODO confirm it is still needed

In [11]:
# Load the three quality tiers. Each tier name maps to the cleaned CSV it is read from
# (per the filenames, "medium" and "low" are cumulative: medium+high and low+medium+high).
dataset_files = {
    "high": "cleaned_combined_datasets_high_quality.csv",
    "medium": "cleaned_combined_datasets_medium_high_quality.csv",
    "low": "cleaned_combined_datasets_low_medium_high_quality.csv",
}

# Parse every tier with the same options. low_memory=False reads each file in a single
# pass, so dtype inference is never done per chunk (avoids mixed-type columns / DtypeWarning).
datasets = {
    name: pd.read_csv(base_data_dir / filename, low_memory=False)
    for name, filename in dataset_files.items()
}

for name, df in datasets.items():
    logger.info(f"Dataset: {name}, shape: {df.shape}")
    logger.info(f"Columns: {df.columns.tolist()}")
    logger.info(f"Unique Dataset Constituents: {df['Dataset'].unique()}")

2025-11-15 20:20:36,071 - __main__ - INFO - Dataset: high, shape: (5326, 12)
2025-11-15 20:20:36,072 - __main__ - INFO - Columns: ['Molecule Name', 'SMILES', 'Dataset', 'LogD', 'KSOL', 'HLM CLint', 'MLM CLint', 'Caco-2 Permeability Papp A>B', 'Caco-2 Permeability Efflux', 'MPPB', 'MBPB', 'MGMB']
2025-11-15 20:20:36,080 - __main__ - INFO - Unique Dataset Constituents: ['expansionrx']
2025-11-15 20:20:36,081 - __main__ - INFO - Dataset: medium, shape: (94708, 12)
2025-11-15 20:20:36,081 - __main__ - INFO - Columns: ['Molecule Name', 'SMILES', 'Dataset', 'LogD', 'KSOL', 'HLM CLint', 'MLM CLint', 'Caco-2 Permeability Papp A>B', 'Caco-2 Permeability Efflux', 'MPPB', 'MBPB', 'MGMB']
2025-11-15 20:20:36,083 - __main__ - INFO - Unique Dataset Constituents: ['expansionrx' 'kermt_public']
2025-11-15 20:20:36,084 - __main__ - INFO - Dataset: low, shape: (116527, 12)
2025-11-15 20:20:36,084 - __main__ - INFO - Columns: ['Molecule Name', 'SMILES', 'Dataset', 'LogD', 'KSOL', 'HLM CLint', 'MLM CLint'

In [10]:
# Number of repetitions of the full cross-validation procedure; each repetition
# re-clusters the molecules and re-shuffles the groups.
n_folds = 5
# Number of group-aware folds produced within each repetition.
n_splits = 5

# Splitting strategy -> function assigning one cluster label per SMILES string.
# UMAP-based clustering (uru.get_umap_clusters, key "umap_cluster") is currently disabled.
split_dict = {
    "random_cluster": uru.get_random_clusters,
    "scaffold_cluster": uru.get_bemis_murcko_clusters,
    "butina_cluster": uru.get_butina_clusters,
}
# Strategies to evaluate, in the same order as declared in split_dict.
split_list = list(split_dict)

In [None]:
# FIXME: make clustering/splitting stratified by the source `Dataset` column (e.g. 'expansionrx' vs 'kermt_public') so each fold keeps a comparable mix of sources

In [None]:
# Measure how test-fold size varies across splitting strategies.
# `size_df` was previously never defined in this notebook (the cell only ran on leftover
# kernel state) and its `fp`/`logS` columns do not exist in these datasets, so build it
# explicitly from one quality tier. The smallest tier is used by default because the
# Butina strategy presumably compares molecules pairwise — TODO confirm before using larger tiers.
size_eval_name = "high"
size_df = datasets[size_eval_name]

result_list = []
for split in split_list:
    cluster_fn = split_dict[split]
    # Each repetition re-clusters and re-shuffles the groups into `n_splits` folds.
    for i in tqdm(range(n_folds), desc=split):
        cluster_list = cluster_fn(size_df.SMILES)
        # Seed with the repetition index: repetitions differ, but the whole run is reproducible.
        group_kfold_shuffle = uru.GroupKFoldShuffle(
            n_splits=n_splits, shuffle=True, random_state=i
        )
        # Folds are formed from the group labels; X only supplies one row per molecule
        # (assumes X/y are not otherwise used by the splitter — TODO confirm).
        for _train, test in group_kfold_shuffle.split(size_df.SMILES, groups=cluster_list):
            result_list.append([size_eval_name, split, len(test)])

result_df = pd.DataFrame(result_list, columns=["dataset", "split", "num_test"])
result_df.to_csv(output_dir / "dataset_split_sizes.csv", index=False)

In [None]:
# Distribution of test-fold sizes for each splitting strategy (one box per strategy).
sns.set_style("whitegrid")
fig, ax = plt.subplots()
sns.boxplot(x="split", y="num_test", data=result_df, ax=ax)
ax.set_xlabel("Dataset Splitting Strategy")
ax.set_ylabel("Test Set Size")
ax.set_title("Test-set size across repeated group k-fold splits")
# Render the figure and suppress the repr of the last Axes call.
plt.show()