# 0) Modules & Functions

In [11]:
import pandas as pd 
from tqdm import tqdm
import os
import re
from rdkit import Chem
from rdkit.Chem import inchi


# Compute InChIKey from SMILES
def smiles_to_inchikey(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    try:
        return inchi.MolToInchiKey(mol)
    except:
        return None
    
def subscript(text):
    return f"$_{{{text}}}$"

def _replace_subscript(name):
    name = name.replace('Caco-2-2', 'Caco-2')
    name = name.replace('(Rat)', '(RAT)')
    name = name.replace('(Human,', '(HUMAN,')
    name = name.replace('(Rat,', '(RAT,')
    name = name.replace('(po, Rat,', '(po, RAT,')
    name = name.replace('(Mouse,', '(MOUSE,')
    name = name.replace('(Mouse)', '(MOUSE)')
    name = name.replace('(Rat)', '(RAT)')
    name = name.replace('(Dog,', '(DOG,')
    name = name.replace('Prolif (Cell)', 'Cell-prolif')
    name = name.replace('Half Life Microsome', 'logHalf-life$_{microsomal}$')
    name = name.replace('Half Life Plasma', 'logHalf-life$_{plasma}$')
    name = name.replace('Clearance Microsomal', 'logClearance$_{microsomal}$')
    name = name.replace('Clearance Total', 'logClearance$_{total}$')
    name = name.replace('Clearance Total', 'logClearance$_{total}$')
    name = name.replace('Clearance Renal', 'logClearance$_{renal}$')
    name = name.replace(', Cell)', ')')
    name = name.replace('Growth Tumor', 'Tumor-growth')
    name = name.replace('logClearance$_{total}$ iv (RAT,', 'logClearance$_{total}$ (iv, RAT,')
    name = name.replace('Stability Microsomal', 'Stability$_{microsomal}$')
    name = name.replace('pVD', 'logVD')
    return name

def clean_endpoint_name(name):
    name = name.replace('.csv', '')

    # Replace raw units with readable forms (keep lowercase)
    name = name.replace('pmL%min%kg', 'µL·min⁻¹·kg⁻¹')
    name = name.replace('pmg%kg', 'mg·kg⁻¹')
    name = name.replace('pHours', 'h')
    name = name.replace('pL%kg', 'L·kg⁻¹')
    name = name.replace('mol%kg', 'mol·kg⁻¹')

    # Remove prefixes
    name = re.sub(r'^(PUBLIC|BindingDB)___', '', name)

    # Replace separators and clean up
    name = name.replace('-', ' ').replace('_', ' ').replace('Public', '')
    name = re.sub(r'\s+', ' ', name).strip()

    # Remove "Biochemical"
    name = re.sub(r'\bBiochemical\b', '', name)

    # Fix substrings for cell types, models
    name = re.sub(r'\bCaco2\b', 'Caco-2', name)
    name = re.sub(r'\bCaco\b', 'Caco-2', name)
    name = re.sub(r'MDCK LE\b', 'MDCK-LE', name)
    name = re.sub(r'MDCK MDR1\b', 'MDCK-MDR1', name)
    name = re.sub(r'MCDK MDR1\b', 'MDCK-MDR1', name)

    # Apply LaTeX-style subscripts
    name = re.sub(r'\bp(IC|GI|LD|TD)50\b', lambda m: f"p{m.group(1)}{subscript('50')}", name)
    name = re.sub(r'\bpVDss\b', lambda _: f"pVD{subscript('ss')}", name)
    name = re.sub(r'\bpKi\b', lambda _: f"pK{subscript('i')}", name)
    name = re.sub(r'\bpKd\b', lambda _: f"pK{subscript('d')}", name)

    # logX formatting with special subscripts
    name = re.sub(r'\bLogD74\b', lambda _: f"logD{subscript('7-4')}", name)
    name = re.sub(r'\bLogPapp\b', lambda _: f"logP{subscript('app')}", name)
    name = re.sub(r'\bLogPow\b', lambda _: f"logP{subscript('ow')}", name)
    name = re.sub(r'\bLogS Water\b', lambda _: f"logS{subscript('water')}", name)
    name = re.sub(r'\bLogS Apparent\b', lambda _: f"logS{subscript('buffer')}", name)
    name = re.sub(r'\bLog([A-Z])', r'log\1', name)


    # Tokenize and collect metadata
    tokens = name.split()
    main = []
    meta = []

    # Reference species, sources, tissues, etc.
    meta_whitelist = {
        'MOUSE', 'RAT', 'HUMAN', 'DOG', 'HAMSTER', 'CELL',
        'PO', 'CAVPO', 'TETCF', 'CANEN', 'MYCTU', 'SARS', 'SARS2',
        'MDCK-MDR1', 'MDCK-LE', 'Caco-2', 'PAMPA', 'BOVIN', 'HV1BR', 
        'HV1H2', 'PIG', 'BACAN', 'TRYCR', 'CANAL', 'ECOLX',  'ECOLI', 
        'MESAU', 'HELPY', 'HCMVA', 'CHICK', "HCMVA", 'RICCO', 'MACFA', 
        'PLAF7','SPIOL', 'HV1N5', 'HV1H2', 'HV1BR', 'HV1B1', 'HCVJ4', 
        'HCVCO', 'HCVBK', 'HCV77', 'SHEEP', 'I34A1', 'CARPA','MELGA', 
        'INBLE', 'I77AB', 'I75A5', 'I34A1', 'I33A0', 'CLOPF', 'HELPJ', 
        'YEAST', 'PHOPY', 'PSEAE', 'HHV1S', "PORG3", 'MYCSM', 'STAAU',
        'HBVD3','ENTFA','HV1H2','PNECA', 'TOXGO', 'PLAFK', 'LEIMA', 'STAAU',
        'STAAM', 'TRYCR', 'STAAE', 'RABIT', 'AGABI', 'BACFG', 'BACSU', 'CLOBO', 'ELEEL', 'ENTCL', 'HORSE', 'KLEPN', 'LACCA', 'SERMA', 'STAAR'
    }

    unit_patterns = ['µL·min⁻¹·kg⁻¹', 'mg·kg⁻¹', 'L·kg⁻¹', 'mol·kg⁻¹', 'h']

    for token in tokens:
        token_upper = token.upper()
        if token in unit_patterns or token in meta_whitelist or token_upper in meta_whitelist:
            meta.append(token)
        elif token_upper in meta_whitelist:
            meta.append(token_upper)
        elif token.isupper() and len(token) <= 6:  # likely gene or protein name
            main.append(token)
        else:
            main.append(token)

    # Remove duplicates from meta, preserve order
    seen = set()
    meta_clean = []
    for m in meta:
        if m not in seen:
            seen.add(m)
            meta_clean.append(m)


    if meta_clean:
        return _replace_subscript(f"{' '.join(main)} ({', '.join(meta_clean)})")
    else:
        return _replace_subscript(' '.join(main))

# 1) Data loading

In [2]:
df_full = pd.read_parquet('../data/exp/oneADMET.parquet')

# 2) Filter by dataset size

In [3]:
# Assuming df_full is your DataFrame
min_non_nan = 50
max_non_nan = 10000

valid_cols = [

    col for col in tqdm(df_full.columns)
    if min_non_nan <= df_full[col].notna().sum() <= max_non_nan
]

# If you want to exclude 'SMILES' and 'Set' columns from the check:
valid_cols = [
    col for col in tqdm(df_full.columns)
    if col not in ['SMILES', 'Set'] and min_non_nan <= df_full[col].notna().sum() <= max_non_nan
]

to_remove = [
    "PUBLIC___pIC50_CYP1A2_HUMAN_Biochemical_Public.csv",
    "PUBLIC___pIC50_CYP2C19_HUMAN_Biochemical_Public.csv",
    "PUBLIC___pIC50_CYP2C9_HUMAN_Biochemical_Public.csv",
    "PUBLIC___pIC50_CYP2D6_HUMAN_Biochemical_Public.csv",
    "PUBLIC___pIC50_CYP3A4_HUMAN_Biochemical_Public.csv",
    "pLD50_Unknown_Public.csv",
    "PUBLIC___LogHydrationFreeEnergy_Public.csv",
    "PPB_Human_Public.csv",
    "BindingDB___pIC50_CYP19A_HUMAN_Biochemical.csv",
    "BindingDB___pIC50_CYP1A1_HUMAN_Biochemical.csv",
    "BindingDB___pIC50_CYP1A2_HUMAN_Biochemical.csv",
    "BindingDB___pIC50_CYP1B1_HUMAN_Biochemical.csv",
    "BindingDB___pIC50_CYP2B6_HUMAN_Biochemical.csv",
    "BindingDB___pIC50_CYP4F2_HUMAN_Biochemical.csv"
]

filtered_cols = [
    col for col in tqdm(df_full.columns)
    if col not in ['SMILES', 'Set']
    and min_non_nan <= df_full[col].notna().sum() <= max_non_nan
    and col not in to_remove
]

# Keep only SMILES, Set, and the filtered columns
df_filtered = df_full[["SMILES", "Set"] + filtered_cols]

# Drop rows where all values in filtered_cols are NaN
df_filtered = df_filtered.dropna(subset=filtered_cols, how='all')

100%|██████████| 1535/1535 [00:00<00:00, 1671.53it/s]
100%|██████████| 1535/1535 [00:00<00:00, 3150.73it/s]
100%|██████████| 1535/1535 [00:00<00:00, 3127.59it/s]


# 3) Keep large enough test

In [4]:
set_out = []
for col in tqdm(filtered_cols):
    # Keep only SMILES, Set, and current column
    df_filtered = df_full[["SMILES", "Set", col]].dropna()

    # Calculate % of each set
    set_counts = df_filtered["Set"].value_counts(normalize=True)

    # Print if any set has more than 10%
    for set_name, frac in set_counts.items():
        if frac < 0.1 and set_name == "test":
            set_out.append(set_name)

    # Save the filtered dataframe
    # df_filtered.to_parquet(f'../data/exp/MTL/oneADMET_{col}.parquet', index=False)

100%|██████████| 1343/1343 [00:36<00:00, 37.21it/s]


# 4) Filter out columns

In [5]:
to_remove_all = to_remove + set_out

filtered_cols_full = [
    col for col in tqdm(df_full.columns)
    if col not in ['SMILES', 'Set']
    and min_non_nan <= df_full[col].notna().sum() <= max_non_nan
    and col not in to_remove_all
]

# Keep only SMILES, Set, and the filtered columns
df_filtered_MTL = df_full[["SMILES", "Set"] + filtered_cols_full]


100%|██████████| 1535/1535 [00:00<00:00, 2221.45it/s]


# 5) Rename the columns

In [6]:
# Usage
cleaned_names = [clean_endpoint_name(name) for name in filtered_cols_full]
endpoint_name_mapping = dict(zip(filtered_cols_full, cleaned_names))

for old_name, new_name in tqdm(endpoint_name_mapping.items()):
    # Optionally, you can rename the columns in the DataFrame
    df_filtered_MTL.rename(columns={old_name: new_name}, inplace=True)

df_filtered_MTL_cols = df_filtered_MTL.columns.tolist()[2:]

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_filtered_MTL.rename(columns={old_name: new_name}, inplace=True)
100%|██████████| 1343/1343 [00:00<00:00, 6571.93it/s]


# 6) Prepare MTL data

In [12]:
# Drop rows where all values in filtered_cols are NaN
try:
    os.mkdir('../data/exp/MTL')
except FileExistsError:
    pass

df_filtered_MTL = df_filtered_MTL.dropna(subset=df_filtered_MTL_cols, how='all')

df_filtered_MTL['InChIKey'] = df_filtered_MTL['SMILES'].apply(smiles_to_inchikey)

# Insert InChIKey as first column
cols = ['InChIKey'] + [col for col in df_filtered_MTL.columns if col != 'InChIKey']
df_filtered_MTL = df_filtered_MTL[cols]

df_filtered_MTL.to_parquet('../data/exp/MTL/oneADMET_LR-MTL.parquet', index=False)

# 7) Prepare STL data

In [13]:
# Prepare STL data
try:
    os.mkdir('../data/exp/STL')
except FileExistsError:
    pass
for col in tqdm(df_filtered_MTL_cols):
    # Keep only SMILES, Set, and current column
    df_filtered_STL = df_filtered_MTL[['InChIKey', "SMILES", "Set", col]].dropna()
    df_filtered_STL.columns = ['InChIKey', "SMILES", "SET", "Y"]

    # Calculate % of each set
    set_counts = df_filtered_STL["SET"].value_counts(normalize=True)

    if len(df_filtered_STL.drop_duplicates("InChIKey")) != len(df_filtered_STL):
        print(f"Duplicate InChIKey found in {col}")

    # Save the filtered dataframe
    df_filtered_STL.to_parquet(f'../data/exp/STL/oneADMET_LR-STL---{col}.parquet', index=False)



100%|██████████| 1343/1343 [00:52<00:00, 25.49it/s]
