# 0) Modules & Functions

In [1]:
import pandas as pd 
from tqdm import tqdm
import os
import re
from rdkit import Chem
from rdkit.Chem import inchi
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from scipy.stats import wasserstein_distance

# Compute InChIKey from SMILES
def smiles_to_inchikey(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    try:
        return inchi.MolToInchiKey(mol)
    except:
        return None
    
def subscript(text):
    return f"$_{{{text}}}$"

def _replace_subscript(name):
    name = name.replace('Caco-2-2', 'Caco-2')
    name = name.replace('(Rat)', '(RAT)')
    name = name.replace('(Human,', '(HUMAN,')
    name = name.replace('(Rat,', '(RAT,')
    name = name.replace('(po, Rat,', '(po, RAT,')
    name = name.replace('(Mouse,', '(MOUSE,')
    name = name.replace('(Mouse)', '(MOUSE)')
    name = name.replace('(Rat)', '(RAT)')
    name = name.replace('(Dog,', '(DOG,')
    name = name.replace('Prolif (Cell)', 'Cell-prolif')
    name = name.replace('Half Life Microsome', 'logHalf-life$_{microsomal}$')
    name = name.replace('Half Life Plasma', 'logHalf-life$_{plasma}$')
    name = name.replace('Clearance Microsomal', 'logClearance$_{microsomal}$')
    name = name.replace('Clearance Total', 'logClearance$_{total}$')
    name = name.replace('Clearance Total', 'logClearance$_{total}$')
    name = name.replace('Clearance Renal', 'logClearance$_{renal}$')
    name = name.replace(', Cell)', ')')
    name = name.replace('Growth Tumor', 'Tumor-growth')
    name = name.replace('logClearance$_{total}$ iv (RAT,', 'logClearance$_{total}$ (iv, RAT,')
    name = name.replace('Stability Microsomal', 'Stability$_{microsomal}$')
    name = name.replace('pVD', 'logVD')
    return name

def clean_endpoint_name(name):
    name = name.replace('.csv', '')

    # Replace raw units with readable forms (keep lowercase)
    name = name.replace('pmL%min%kg', 'µL·min⁻¹·kg⁻¹')
    name = name.replace('pmg%kg', 'mg·kg⁻¹')
    name = name.replace('pHours', 'h')
    name = name.replace('pL%kg', 'L·kg⁻¹')
    name = name.replace('mol%kg', 'mol·kg⁻¹')

    # Remove prefixes
    name = re.sub(r'^(PUBLIC|BindingDB)___', '', name)

    # Replace separators and clean up
    name = name.replace('-', ' ').replace('_', ' ').replace('Public', '')
    name = re.sub(r'\s+', ' ', name).strip()

    # Remove "Biochemical"
    name = re.sub(r'\bBiochemical\b', '', name)

    # Fix substrings for cell types, models
    name = re.sub(r'\bCaco2\b', 'Caco-2', name)
    name = re.sub(r'\bCaco\b', 'Caco-2', name)
    name = re.sub(r'MDCK LE\b', 'MDCK-LE', name)
    name = re.sub(r'MDCK MDR1\b', 'MDCK-MDR1', name)
    name = re.sub(r'MCDK MDR1\b', 'MDCK-MDR1', name)

    # Apply LaTeX-style subscripts
    name = re.sub(r'\bp(IC|GI|LD|TD)50\b', lambda m: f"p{m.group(1)}{subscript('50')}", name)
    name = re.sub(r'\bpVDss\b', lambda _: f"pVD{subscript('ss')}", name)
    name = re.sub(r'\bpKi\b', lambda _: f"pK{subscript('i')}", name)
    name = re.sub(r'\bpKd\b', lambda _: f"pK{subscript('d')}", name)

    # logX formatting with special subscripts
    name = re.sub(r'\bLogD74\b', lambda _: f"logD{subscript('7-4')}", name)
    name = re.sub(r'\bLogPapp\b', lambda _: f"logP{subscript('app')}", name)
    name = re.sub(r'\bLogPow\b', lambda _: f"logP{subscript('ow')}", name)
    name = re.sub(r'\bLogS Water\b', lambda _: f"logS{subscript('water')}", name)
    name = re.sub(r'\bLogS Apparent\b', lambda _: f"logS{subscript('buffer')}", name)
    name = re.sub(r'\bLog([A-Z])', r'log\1', name)


    # Tokenize and collect metadata
    tokens = name.split()
    main = []
    meta = []

    # Reference species, sources, tissues, etc.
    meta_whitelist = {
        'MOUSE', 'RAT', 'HUMAN', 'DOG', 'HAMSTER', 'CELL',
        'PO', 'CAVPO', 'TETCF', 'CANEN', 'MYCTU', 'SARS', 'SARS2',
        'MDCK-MDR1', 'MDCK-LE', 'Caco-2', 'PAMPA', 'BOVIN', 'HV1BR', 
        'HV1H2', 'PIG', 'BACAN', 'TRYCR', 'CANAL', 'ECOLX',  'ECOLI', 
        'MESAU', 'HELPY', 'HCMVA', 'CHICK', "HCMVA", 'RICCO', 'MACFA', 
        'PLAF7','SPIOL', 'HV1N5', 'HV1H2', 'HV1BR', 'HV1B1', 'HCVJ4', 
        'HCVCO', 'HCVBK', 'HCV77', 'SHEEP', 'I34A1', 'CARPA','MELGA', 
        'INBLE', 'I77AB', 'I75A5', 'I34A1', 'I33A0', 'CLOPF', 'HELPJ', 
        'YEAST', 'PHOPY', 'PSEAE', 'HHV1S', "PORG3", 'MYCSM', 'STAAU',
        'HBVD3','ENTFA','HV1H2','PNECA', 'TOXGO', 'PLAFK', 'LEIMA', 'STAAU',
        'STAAM', 'TRYCR', 'STAAE', 'RABIT', 'AGABI', 'BACFG', 'BACSU', 'CLOBO', 'ELEEL', 'ENTCL', 'HORSE', 'KLEPN', 'LACCA', 'SERMA', 'STAAR'
    }

    unit_patterns = ['µL·min⁻¹·kg⁻¹', 'mg·kg⁻¹', 'L·kg⁻¹', 'mol·kg⁻¹', 'h']

    for token in tokens:
        token_upper = token.upper()
        if token in unit_patterns or token in meta_whitelist or token_upper in meta_whitelist:
            meta.append(token)
        elif token_upper in meta_whitelist:
            meta.append(token_upper)
        elif token.isupper() and len(token) <= 6:  # likely gene or protein name
            main.append(token)
        else:
            main.append(token)

    # Remove duplicates from meta, preserve order
    seen = set()
    meta_clean = []
    for m in meta:
        if m not in seen:
            seen.add(m)
            meta_clean.append(m)


    if meta_clean:
        return _replace_subscript(f"{' '.join(main)} ({', '.join(meta_clean)})")
    else:
        return _replace_subscript(' '.join(main))
    

def stratified_split_with_distribution_analysis(df, n_bins=5, min_data_points=5000):
    df = df.copy()
    
    # Step 1: Filter columns with enough data
    target_cols = df.drop(columns=['SMILES']).columns[df.drop(columns=['SMILES']).notna().sum() > min_data_points]
    df_targets = df[target_cols]

    # Step 2: Impute missing values
    df_imputed = df_targets.fillna(df_targets.median())

    # Step 3: Quantize
    df_binned = df_imputed.apply(lambda col: pd.qcut(col, q=n_bins, duplicates='drop', labels=False), axis=0)

    # Step 4: Stratified split
    X = df['SMILES'].values
    Y = df_binned.values
    mskf = MultilabelStratifiedKFold(n_splits=3, shuffle=False)
    sets = ['A', 'B', 'C']
    split_indices = []
    for _, test_index in mskf.split(X, Y):
        split_indices.append(test_index)
    
    # Step 5: Assign set labels
    df['Set'] = ''
    for i, indices in enumerate(split_indices):
        df.iloc[indices, df.columns.get_loc('Set')] = sets[i]

    # Step 6: Print percentage per set
    set_counts = df['Set'].value_counts(normalize=True) * 100

    return df

# 1) Data loading

In [2]:
df_full = pd.read_parquet('../data/exp/oneADMET.parquet')

# 2) Filter by dataset size

In [3]:
# Assuming df_full is your DataFrame
min_non_nan = 50
max_non_nan = 10000

valid_cols = [

    col for col in tqdm(df_full.columns)
    if min_non_nan <= df_full[col].notna().sum() <= max_non_nan
]

# If you want to exclude 'SMILES' and 'Set' columns from the check:
valid_cols = [
    col for col in tqdm(df_full.columns)
    if col not in ['SMILES', 'Set'] and min_non_nan <= df_full[col].notna().sum() <= max_non_nan
]

to_remove = [
    "PUBLIC___pIC50_CYP1A2_HUMAN_Biochemical_Public.csv",
    "PUBLIC___pIC50_CYP2C19_HUMAN_Biochemical_Public.csv",
    "PUBLIC___pIC50_CYP2C9_HUMAN_Biochemical_Public.csv",
    "PUBLIC___pIC50_CYP2D6_HUMAN_Biochemical_Public.csv",
    "PUBLIC___pIC50_CYP3A4_HUMAN_Biochemical_Public.csv",
    "pLD50_Unknown_Public.csv",
    "PUBLIC___LogHydrationFreeEnergy_Public.csv",
    "PPB_Human_Public.csv",
    "BindingDB___pIC50_CYP19A_HUMAN_Biochemical.csv",
    "BindingDB___pIC50_CYP1A1_HUMAN_Biochemical.csv",
    "BindingDB___pIC50_CYP1A2_HUMAN_Biochemical.csv",
    "BindingDB___pIC50_CYP1B1_HUMAN_Biochemical.csv",
    "BindingDB___pIC50_CYP2B6_HUMAN_Biochemical.csv",
    "BindingDB___pIC50_CYP4F2_HUMAN_Biochemical.csv"
]

filtered_cols = [
    col for col in tqdm(df_full.columns)
    if col not in ['SMILES', 'Set']
    and min_non_nan <= df_full[col].notna().sum() <= max_non_nan
    and col not in to_remove
]

# Keep only SMILES, Set, and the filtered columns
df_filtered = df_full[["SMILES"] + filtered_cols]

# Drop rows where all values in filtered_cols are NaN
df_filtered = df_filtered.dropna(subset=filtered_cols, how='all')

100%|██████████| 1535/1535 [00:01<00:00, 1185.55it/s]
100%|██████████| 1535/1535 [00:00<00:00, 2847.96it/s]
100%|██████████| 1535/1535 [00:00<00:00, 2851.34it/s]


# 2.5) Assign set

In [None]:
df_split = stratified_split_with_distribution_analysis(df_filtered)

# 3) Keep large enough test

In [None]:
set_out = []
for col in tqdm(filtered_cols):
    # Keep only SMILES, Set, and current column
    df_unique = df_split[["SMILES", "Set", col]].dropna()

    # Calculate % of each set
    set_counts = df_unique["Set"].value_counts(normalize=True)

    # Print if any set has more than 10%
    for set_name, frac in set_counts.items():
        if frac < 0.25:
            set_out.append(col)
            print(f" {col} {frac:.2%} in {set_name} set")


  4%|▍         | 56/1343 [00:01<00:35, 36.19it/s]

 BindingDB___pIC50_ABL2_HUMAN_Biochemical.csv 24.68% in A set
 BindingDB___pIC50_ACACA_RAT_Biochemical.csv 24.41% in B set


  7%|▋         | 92/1343 [00:02<00:34, 36.73it/s]

 BindingDB___pIC50_ADA2A_HUMAN_Biochemical.csv 24.80% in A set
 BindingDB___pIC50_ADRB3_HUMAN_Biochemical.csv 20.00% in A set


 10%|█         | 136/1343 [00:03<00:33, 36.49it/s]

 BindingDB___pIC50_ARBK1_BOVIN_Biochemical.csv 21.13% in C set


 18%|█▊        | 244/1343 [00:06<00:32, 33.41it/s]

 BindingDB___pIC50_CP2B1_RAT_Biochemical.csv 22.22% in B set


 20%|█▉        | 264/1343 [00:07<00:32, 33.35it/s]

 BindingDB___pIC50_DDB1_HUMAN_Biochemical.csv 24.74% in C set
 BindingDB___pIC50_DEF_STAAM_Biochemical.csv 20.59% in C set


 21%|██        | 276/1343 [00:07<00:31, 33.57it/s]

 BindingDB___pIC50_DHSO_HUMAN_Biochemical.csv 23.29% in C set
 BindingDB___pIC50_DNLI1_HUMAN_Biochemical.csv 23.08% in B set
 BindingDB___pIC50_DNLI1_HUMAN_Biochemical.csv 21.15% in C set


 23%|██▎       | 304/1343 [00:08<00:29, 34.84it/s]

 BindingDB___pIC50_ECE1_RAT_Biochemical.csv 20.25% in B set


 28%|██▊       | 376/1343 [00:10<00:27, 35.57it/s]

 BindingDB___pIC50_GBRA1_HUMAN_Biochemical.csv 21.74% in B set
 BindingDB___pIC50_GCKR_HUMAN_Biochemical.csv 24.24% in C set


 29%|██▉       | 392/1343 [00:11<00:27, 34.94it/s]

 BindingDB___pIC50_GRIK1_HUMAN_Biochemical.csv 18.64% in C set


 35%|███▌      | 476/1343 [00:13<00:24, 34.88it/s]

 BindingDB___pIC50_KDM4D_HUMAN_Biochemical.csv 24.18% in A set


 37%|███▋      | 496/1343 [00:14<00:23, 35.33it/s]

 BindingDB___pIC50_KS6A1_HUMAN_Biochemical.csv 20.91% in A set


 38%|███▊      | 512/1343 [00:14<00:24, 34.58it/s]

 BindingDB___pIC50_LMBL1_HUMAN_Biochemical.csv 23.60% in B set


 39%|███▊      | 520/1343 [00:14<00:24, 34.29it/s]

 BindingDB___pIC50_LST8_HUMAN_Biochemical.csv 22.12% in C set


 41%|████      | 548/1343 [00:15<00:22, 34.69it/s]

 BindingDB___pIC50_MDR1B_MOUSE_Biochemical.csv 18.57% in B set
 BindingDB___pIC50_MEP1B_HUMAN_Biochemical.csv 16.88% in C set


 44%|████▍     | 592/1343 [00:16<00:21, 34.74it/s]

 BindingDB___pIC50_MYLK_CHICK_Biochemical.csv 19.19% in A set


 47%|████▋     | 632/1343 [00:17<00:19, 35.67it/s]

 BindingDB___pIC50_NRP1_HUMAN_Biochemical.csv 21.35% in A set


 48%|████▊     | 648/1343 [00:18<00:19, 35.75it/s]

 BindingDB___pIC50_OXER1_HUMAN_Biochemical.csv 24.53% in C set


 52%|█████▏    | 700/1343 [00:19<00:17, 37.62it/s]

 BindingDB___pIC50_PDK4_RAT_Biochemical.csv 24.56% in A set


 54%|█████▍    | 724/1343 [00:20<00:16, 37.72it/s]

 BindingDB___pIC50_PK3CB_RAT_Biochemical.csv 20.37% in C set
 BindingDB___pIC50_PK3CG_MOUSE_Biochemical.csv 24.73% in B set


 55%|█████▍    | 736/1343 [00:20<00:16, 37.77it/s]

 BindingDB___pIC50_PMYT1_HUMAN_Biochemical.csv 18.18% in C set


 59%|█████▉    | 792/1343 [00:22<00:14, 37.87it/s]

 BindingDB___pIC50_RPOB_ECOLI_Biochemical.csv 20.37% in C set


 61%|██████    | 816/1343 [00:22<00:14, 36.29it/s]

 BindingDB___pIC50_SC6A9_RAT_Biochemical.csv 24.32% in A set


 63%|██████▎   | 852/1343 [00:23<00:13, 36.49it/s]

 BindingDB___pIC50_SRC_CHICK_Biochemical.csv 23.93% in B set


 64%|██████▍   | 860/1343 [00:24<00:13, 35.47it/s]

 BindingDB___pIC50_STAT6_HUMAN_Biochemical.csv 22.22% in A set
 BindingDB___pIC50_SUMO1_HUMAN_Biochemical.csv 13.56% in B set


 65%|██████▌   | 876/1343 [00:24<00:12, 35.98it/s]

 BindingDB___pIC50_THA_HUMAN_Biochemical.csv 24.29% in A set


 68%|██████▊   | 916/1343 [00:25<00:12, 35.19it/s]

 BindingDB___pIC50_UT2_RAT_Biochemical.csv 18.87% in A set


 69%|██████▉   | 932/1343 [00:26<00:11, 35.29it/s]

 BindingDB___pIC50_WNT3A_MOUSE_Biochemical.csv 23.94% in B set
 BindingDB___pKd_ABL1_HUMAN_Biochemical.csv 24.00% in C set
 BindingDB___pKd_ACES_HUMAN_Biochemical.csv 22.78% in C set
 BindingDB___pKd_ACM1_HUMAN_Biochemical.csv 24.53% in C set


 71%|███████   | 952/1343 [00:26<00:10, 36.18it/s]

 BindingDB___pKd_FAK1_HUMAN_Biochemical.csv 24.53% in A set


 71%|███████▏  | 960/1343 [00:26<00:10, 35.36it/s]

 BindingDB___pKd_LEG7_HUMAN_Biochemical.csv 17.74% in B set
 BindingDB___pKd_LEG9_HUMAN_Biochemical.csv 21.79% in B set
 BindingDB___pKd_MET_HUMAN_Biochemical.csv 24.07% in C set
 BindingDB___pKd_P3C2A_HUMAN_Biochemical.csv 24.68% in C set


 72%|███████▏  | 968/1343 [00:27<00:10, 35.75it/s]

 BindingDB___pKd_P3C2B_HUMAN_Biochemical.csv 23.53% in C set
 BindingDB___pKd_RASK_HUMAN_Biochemical.csv 24.86% in A set


 73%|███████▎  | 980/1343 [00:27<00:10, 36.24it/s]

 BindingDB___pKd_S1PR2_HUMAN_Biochemical.csv 22.41% in B set
 BindingDB___pKd_SHBG_HUMAN_Biochemical.csv 24.10% in B set


 77%|███████▋  | 1036/1343 [00:29<00:08, 36.36it/s]

 BindingDB___pKi_ADRB1_RAT_Biochemical.csv 23.01% in B set
 BindingDB___pKi_ANDR_RAT_Biochemical.csv 23.21% in A set


 79%|███████▊  | 1056/1343 [00:29<00:07, 36.25it/s]

 BindingDB___pKi_BLAT_ECOLX_Biochemical.csv 18.92% in B set


 80%|████████  | 1080/1343 [00:30<00:07, 36.58it/s]

 BindingDB___pKi_CAN2_HUMAN_Biochemical.csv 24.82% in C set
 BindingDB___pKi_CBPA1_HUMAN_Biochemical.csv 23.73% in B set


 85%|████████▌ | 1144/1343 [00:32<00:05, 35.72it/s]

 BindingDB___pKi_GBA1_HUMAN_Biochemical.csv 24.00% in A set


 88%|████████▊ | 1176/1343 [00:32<00:04, 36.37it/s]

 BindingDB___pKi_I23O1_HUMAN_Biochemical.csv 22.39% in C set


 89%|████████▉ | 1192/1343 [00:33<00:04, 36.99it/s]

 BindingDB___pKi_LGUL_HUMAN_Biochemical.csv 22.34% in B set


 93%|█████████▎| 1252/1343 [00:35<00:02, 35.17it/s]

 BindingDB___pKi_P4K2A_HUMAN_Biochemical.csv 24.00% in A set


 94%|█████████▍| 1260/1343 [00:35<00:02, 34.10it/s]

 BindingDB___pKi_PDE5A_HUMAN_Biochemical.csv 23.96% in C set


 96%|█████████▌| 1284/1343 [00:36<00:01, 33.14it/s]

 BindingDB___pKi_PYGM_HUMAN_Biochemical.csv 20.00% in A set


100%|██████████| 1343/1343 [00:37<00:00, 35.46it/s]

 BindingDB___pKi_VACHT_TETCF_Biochemical.csv 24.07% in B set





# 4) Filter out columns

In [None]:
to_remove_all = to_remove + set_out

filtered_cols_full = [
    col for col in tqdm(df_split.columns)
    if col not in ['SMILES', 'Set']
    and min_non_nan <= df_split[col].notna().sum() <= max_non_nan
    and col not in to_remove_all
    and col not in set_out
]

# Keep only SMILES, Set, and the filtered columns
df_filtered_MTL = df_split[["SMILES", "Set"] + filtered_cols_full]


100%|██████████| 1345/1345 [00:00<00:00, 2599.28it/s]


In [None]:
df_filtered_MTL

Unnamed: 0,SMILES,Set,Clearance-Microsomal_Mouse-pmL%min%kg_Public.csv,Clearance-Microsomal_Rat-pmL%min%kg_Public.csv,Clearance-Renal_pmL%min%kg_Public.csv,Clearance-Total_iv-Rat-pmL%min%kg_Public.csv,Half-Life_Human-Microsome-pHours_Public.csv,Half-Life_Human-Plasma-pHours_Public.csv,Half-Life_Rat-Microsome-pHours_Public.csv,Half-Life_Rat-Plasma-pHours_Public.csv,...,BindingDB___pKi_V1AR_RAT_Biochemical.csv,BindingDB___pKi_V1BR_HUMAN_Biochemical.csv,BindingDB___pKi_VACHT_HUMAN_Biochemical.csv,BindingDB___pKi_VACHT_RAT_Biochemical.csv,BindingDB___pKi_VDR_HUMAN_Biochemical.csv,BindingDB___pKi_VMAT2_RAT_Biochemical.csv,BindingDB___pKi_WDR5_HUMAN_Biochemical.csv,pIC50_Cell-Prolif_Public.csv,PUBLIC___pIC50_hERG_HAMSTER_Cell_Public.csv,PUBLIC___pKi_hERG_HAMSTER_Cell_Public.csv
5,CC(NC(=O)C(Cc1ccccc1)CP(O)(=O)C([NH3+])Cc1cccc...,B,,,,,,,,,...,,,,,,,,,,
6,COc1ccc2cc(C(=O)NCCc3ccc(N)cc3)c(=O)[nH]c2c1OCCCC,B,,,,,,,,,...,,,,,,,,,,
8,NC(=O)c1[n][n](CC(=O)N2C(CC3CC23)C(=O)NCC2CC2(...,B,,,,,,,,,...,,,,,,,,,,
9,CN(C)c1[n]c2c(cc(c(=O)[n]2C)-c2c(Cl)cccc2Cl)c[n]1,B,,,,,,,,,...,,,,,,,,,,
12,Cc1o[n]cc1C(=O)Nc1cc(F)c(cc1)-[n]1[n]c(cc1C1CC...,B,,,,,,,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
213183,Cc1ccc(C[n]2[n]cc3[n]c([nH]c23)-c2ccc(cc2)N2CC...,C,,,,,,,,,...,,,,,,,,,,
213187,CC1(C)C2CCC1(c1[n][n]c(cc12)-c1c(F)cccc1F)c1c[...,C,,,,,,,,,...,,,,,,,,,,
213188,O=C(Nc1c[n]2[n]c(ccc2[n]1)-c1c([n]c2cccc[n]21)...,A,,,,,,,,,...,,,,,,,,,,
213189,CC1(C)CC(c2ccc(Cl)cc2)=C(CN2CCN(CC2)c2ccc(c(c2...,B,,,,,,,,,...,,,,,,,,,,


# 5) Rename the columns

In [None]:
# Usage
cleaned_names = [clean_endpoint_name(name) for name in filtered_cols_full]
endpoint_name_mapping = dict(zip(filtered_cols_full, cleaned_names))

for old_name, new_name in tqdm(endpoint_name_mapping.items()):
    # Optionally, you can rename the columns in the DataFrame
    df_filtered_MTL.rename(columns={old_name: new_name}, inplace=True)

df_filtered_MTL_cols = df_filtered_MTL.columns.tolist()[2:]

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_filtered_MTL.rename(columns={old_name: new_name}, inplace=True)
100%|██████████| 1284/1284 [00:00<00:00, 6530.20it/s]


# 6) Prepare MTL data

In [None]:
# Drop rows where all values in filtered_cols are NaN
try:
    os.mkdir('../data/exp/MTL')
except FileExistsError:
    pass

df_filtered_MTL = df_filtered_MTL.dropna(subset=df_filtered_MTL_cols, how='all')

df_filtered_MTL['InChIKey'] = df_filtered_MTL['SMILES'].apply(smiles_to_inchikey)

# Insert InChIKey as first column
cols = ['InChIKey'] + [col for col in df_filtered_MTL.columns if col != 'InChIKey']
df_filtered_MTL = df_filtered_MTL[cols]

df_filtered_MTL.to_parquet('../data/exp/MTL/oneADMET_LR-MTL.parquet', index=False)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_filtered_MTL['InChIKey'] = df_filtered_MTL['SMILES'].apply(smiles_to_inchikey)


# 7) Prepare STL data

In [None]:
# Prepare STL data
try:
    os.mkdir('../data/exp/STL')
except FileExistsError:
    pass
for col in tqdm(df_filtered_MTL_cols):
    # Keep only SMILES, Set, and current column
    df_filtered_STL = df_filtered_MTL[['InChIKey', "SMILES", "Set", col]].dropna()
    df_filtered_STL.columns = ['InChIKey', "SMILES", "SET", "Y"]

    # Calculate % of each set
    set_counts = df_filtered_STL["SET"].value_counts(normalize=True)

    if len(df_filtered_STL.drop_duplicates("InChIKey")) != len(df_filtered_STL):
        print(f"Duplicate InChIKey found in {col}")

    # Save the filtered dataframe
    df_filtered_STL.to_parquet(f'../data/exp/STL/oneADMET_LR-STL---{col}.parquet', index=False)



100%|██████████| 1284/1284 [00:53<00:00, 24.18it/s]
