# 0) Modules & Functions

In [6]:
import os
import pandas as pd
from pathlib import Path
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, MACCSkeys
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import numpy as np
from tqdm import tqdm
import joblib
import subprocess
import time

# Seeds for the repeated model fits: run_stl_predictions trains one model per
# entry and stores its predictions in the column pY_rep<position>, 1-based.
RANDOM_STATES = [42, 1337, 29121997]

# === DESCRIPTOR FUNCTIONS ===
def smiles_to_ecfp(smiles, radius=2, nBits=2048):
    """Encode a SMILES string as a Morgan bit-vector fingerprint.

    Args:
        smiles: SMILES string to parse with RDKit.
        radius: Morgan radius forwarded to RDKit.
        nBits: Length of the returned vector.

    Returns:
        A NumPy array built from the RDKit bit vector, or an all-zero array of
        length ``nBits`` when RDKit cannot parse ``smiles``.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if not molecule:
        # Unparsable SMILES: fall back to an all-zero vector so callers can still stack rows.
        return np.zeros(nBits)
    bit_vector = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=nBits)
    return np.array(bit_vector)

def smiles_to_avalon(smiles, nBits=512):
    """Encode a SMILES string as an Avalon fingerprint of length ``nBits``.

    Args:
        smiles: SMILES string to parse with RDKit.
        nBits: Length of the fingerprint (and of the fallback zero vector).

    Returns:
        A NumPy array of length ``nBits`` holding the Avalon fingerprint bits,
        or an all-zero array of the same length when RDKit cannot parse
        ``smiles``. Both branches have the same length, so rows can always be
        stacked into one matrix.
    """
    # Imported locally so this function does not depend on edits to the import cell.
    from rdkit.Avalon import pyAvalonTools

    fingerprint = np.zeros(nBits)
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return fingerprint

    # MACCS keys (used previously) have a fixed length unrelated to nBits;
    # Avalon fingerprints honour nBits and match the function's name.
    bit_vector = pyAvalonTools.GetAvalonFP(mol, nBits=nBits)
    DataStructs.ConvertToNumpyArray(bit_vector, fingerprint)
    return fingerprint

# === STL PROCESS ===
def run_stl_predictions(data_root, combinations, descriptor_fn, model_fn, model_name, desc_name, output_root="../data/prediction"):
    """Train one single-task model per endpoint file and save test-set predictions.

    For every fold in ``combinations`` and for both training sources
    (``"noaug"`` and ``"aug"``), each ``*.parquet`` file found in
    ``{data_root}/{train_source}/{combo}/STL`` is used as a training set and
    the file of the same name in ``{data_root}/test/{combo}/STL`` as the test
    set. One model is fitted per seed in ``RANDOM_STATES``.

    Args:
        data_root: Root directory containing the ``noaug``, ``aug`` and ``test`` trees.
        combinations: Iterable of fold directory names.
        descriptor_fn: Callable mapping a SMILES string to a 1-D feature vector.
        model_fn: Callable taking a seed and returning an unfitted regressor
            exposing ``fit`` and ``predict``.
        model_name: Model label used in the output directory name.
        desc_name: Descriptor label used in the output directory name.
        output_root: Root directory for the prediction files.

    Side effects:
        Writes ``{output_root}/{train_source}/{combo}/STL/{model_name}_{desc_name}/<file>``
        containing the test columns ``InChIKey``, ``SMILES``, ``AUG``, ``Y``,
        ``STD`` plus, per repetition ``k``, ``pY_rep{k}``,
        ``pY_train_time_rep{k}`` and ``pY_test_time_rep{k}`` (seconds, the same
        value repeated on every row).

    Raises:
        FileNotFoundError: If a training directory or a matching test file is missing.
    """
    for combo in combinations:
        for train_source in ["noaug", "aug"]:
            stl_train_dir = f"{data_root}/{train_source}/{combo}/STL"
            stl_test_dir = f"{data_root}/test/{combo}/STL"
            pred_dir = f"{output_root}/{train_source}/{combo}/STL/{model_name}_{desc_name}"
            Path(pred_dir).mkdir(parents=True, exist_ok=True)

            # Filter and sort up front: the progress-bar total then counts only
            # real work, and the processing order no longer depends on os.listdir.
            parquet_files = sorted(
                name for name in os.listdir(stl_train_dir) if name.endswith(".parquet")
            )

            for file in tqdm(parquet_files, desc=f"{combo} - {train_source} - {model_name} + {desc_name}"):
                train_df = pd.read_parquet(os.path.join(stl_train_dir, file))
                test_df = pd.read_parquet(os.path.join(stl_test_dir, file))

                X_train = np.stack(train_df["SMILES"].map(descriptor_fn))
                y_train = train_df["Y"].values
                X_test = np.stack(test_df["SMILES"].map(descriptor_fn))

                preds_df = test_df[["InChIKey", "SMILES", "AUG", "Y", "STD"]].copy()

                for i, seed in enumerate(RANDOM_STATES):
                    model = model_fn(seed)

                    # perf_counter is monotonic, unlike time.time, so the
                    # measured durations cannot be skewed by clock adjustments.
                    start_train = time.perf_counter()
                    model.fit(X_train, y_train)
                    train_seconds = time.perf_counter() - start_train

                    start_test = time.perf_counter()
                    preds = model.predict(X_test)
                    test_seconds = time.perf_counter() - start_test

                    preds_df[f"pY_rep{i+1}"] = preds
                    preds_df[f"pY_train_time_rep{i+1}"] = train_seconds
                    preds_df[f"pY_test_time_rep{i+1}"] = test_seconds

                out_path = os.path.join(pred_dir, file)
                preds_df.to_parquet(out_path, index=False)
                # tqdm.write keeps the active progress bar intact, unlike print.
                tqdm.write(f"Predictions saved to {out_path}")

# === MODEL FACTORIES ===
def get_rf(seed):
    """Return an unfitted 200-tree random-forest regressor seeded with ``seed``."""
    rf_params = {
        "n_estimators": 200,
        "random_state": seed,
        "n_jobs": -1,
    }
    return RandomForestRegressor(**rf_params)

def get_xgb(seed):
    """Return an unfitted 200-estimator XGBoost regressor seeded with ``seed``."""
    xgb_params = {
        "n_estimators": 200,
        "random_state": seed,
        "n_jobs": -1,
        "verbosity": 0,
    }
    return XGBRegressor(**xgb_params)



# 1) Run the STL predictions

In [7]:
# === ENTRYPOINTS ===
# Random forest on ECFP features, for each of the ten nine-letter fold names.
run_stl_predictions(
    data_root="../data",
    combinations=[
        "BCDEFGHIJ", "ACDEFGHIJ", "ABDEFGHIJ", "ABCEFGHIJ", "ABCDFGHIJ",
        "ABCDEGHIJ", "ABCDEFHIJ", "ABCDEFGIJ", "ABCDEFGHJ", "ABCDEFGHI",
    ],
    descriptor_fn=smiles_to_ecfp,
    model_fn=get_rf,
    model_name="RF",
    desc_name="ECFP",
)
# run_stl_predictions("../data", ["AB", "AC", "BC"], smiles_to_avalon, get_xgb, "XGB", "AVALON")
# run_mtl_predictions("../data", ["AB", "AC", "BC"], "../models/chemprop_mtl")


BCDEFGHIJ - noaug - RF + ECFP:   0%|          | 0/1 [00:00<?, ?it/s]

BCDEFGHIJ - noaug - RF + ECFP: 100%|██████████| 1/1 [00:06<00:00,  6.26s/it]


Predictions saved to ../data/prediction/noaug/BCDEFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


BCDEFGHIJ - aug - RF + ECFP: 100%|██████████| 1/1 [00:11<00:00, 11.24s/it]


Predictions saved to ../data/prediction/aug/BCDEFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ACDEFGHIJ - noaug - RF + ECFP:  50%|█████     | 1/2 [00:06<00:06,  6.17s/it]

Predictions saved to ../data/prediction/noaug/ACDEFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ACDEFGHIJ - noaug - RF + ECFP: 100%|██████████| 2/2 [00:06<00:00,  3.42s/it]


Predictions saved to ../data/prediction/noaug/ACDEFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ACDEFGHIJ - aug - RF + ECFP:  50%|█████     | 1/2 [00:12<00:12, 12.27s/it]

Predictions saved to ../data/prediction/aug/ACDEFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ACDEFGHIJ - aug - RF + ECFP: 100%|██████████| 2/2 [00:13<00:00,  6.59s/it]


Predictions saved to ../data/prediction/aug/ACDEFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABDEFGHIJ - noaug - RF + ECFP:  50%|█████     | 1/2 [00:06<00:06,  6.36s/it]

Predictions saved to ../data/prediction/noaug/ABDEFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ABDEFGHIJ - noaug - RF + ECFP: 100%|██████████| 2/2 [00:07<00:00,  3.52s/it]


Predictions saved to ../data/prediction/noaug/ABDEFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABDEFGHIJ - aug - RF + ECFP:  50%|█████     | 1/2 [00:11<00:11, 11.66s/it]

Predictions saved to ../data/prediction/aug/ABDEFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ABDEFGHIJ - aug - RF + ECFP: 100%|██████████| 2/2 [00:12<00:00,  6.26s/it]


Predictions saved to ../data/prediction/aug/ABDEFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABCEFGHIJ - noaug - RF + ECFP:  50%|█████     | 1/2 [00:06<00:06,  6.13s/it]

Predictions saved to ../data/prediction/noaug/ABCEFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ABCEFGHIJ - noaug - RF + ECFP: 100%|██████████| 2/2 [00:06<00:00,  3.40s/it]


Predictions saved to ../data/prediction/noaug/ABCEFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABCEFGHIJ - aug - RF + ECFP:  50%|█████     | 1/2 [00:12<00:12, 12.43s/it]

Predictions saved to ../data/prediction/aug/ABCEFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ABCEFGHIJ - aug - RF + ECFP: 100%|██████████| 2/2 [00:13<00:00,  6.68s/it]


Predictions saved to ../data/prediction/aug/ABCEFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABCDFGHIJ - noaug - RF + ECFP:  50%|█████     | 1/2 [00:06<00:06,  6.30s/it]

Predictions saved to ../data/prediction/noaug/ABCDFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ABCDFGHIJ - noaug - RF + ECFP: 100%|██████████| 2/2 [00:06<00:00,  3.49s/it]


Predictions saved to ../data/prediction/noaug/ABCDFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABCDFGHIJ - aug - RF + ECFP:  50%|█████     | 1/2 [00:12<00:12, 12.83s/it]

Predictions saved to ../data/prediction/aug/ABCDFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ABCDFGHIJ - aug - RF + ECFP: 100%|██████████| 2/2 [00:13<00:00,  6.84s/it]


Predictions saved to ../data/prediction/aug/ABCDFGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABCDEGHIJ - noaug - RF + ECFP:  50%|█████     | 1/2 [00:06<00:06,  6.17s/it]

Predictions saved to ../data/prediction/noaug/ABCDEGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ABCDEGHIJ - noaug - RF + ECFP: 100%|██████████| 2/2 [00:06<00:00,  3.42s/it]


Predictions saved to ../data/prediction/noaug/ABCDEGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABCDEGHIJ - aug - RF + ECFP:  50%|█████     | 1/2 [00:11<00:11, 11.88s/it]

Predictions saved to ../data/prediction/aug/ABCDEGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ABCDEGHIJ - aug - RF + ECFP: 100%|██████████| 2/2 [00:12<00:00,  6.34s/it]


Predictions saved to ../data/prediction/aug/ABCDEGHIJ/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABCDEFHIJ - noaug - RF + ECFP:  50%|█████     | 1/2 [00:06<00:06,  6.26s/it]

Predictions saved to ../data/prediction/noaug/ABCDEFHIJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ABCDEFHIJ - noaug - RF + ECFP: 100%|██████████| 2/2 [00:06<00:00,  3.47s/it]


Predictions saved to ../data/prediction/noaug/ABCDEFHIJ/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABCDEFHIJ - aug - RF + ECFP:  50%|█████     | 1/2 [00:13<00:13, 13.07s/it]

Predictions saved to ../data/prediction/aug/ABCDEFHIJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ABCDEFHIJ - aug - RF + ECFP: 100%|██████████| 2/2 [00:13<00:00,  7.00s/it]


Predictions saved to ../data/prediction/aug/ABCDEFHIJ/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABCDEFGIJ - noaug - RF + ECFP:  50%|█████     | 1/2 [00:06<00:06,  6.19s/it]

Predictions saved to ../data/prediction/noaug/ABCDEFGIJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ABCDEFGIJ - noaug - RF + ECFP: 100%|██████████| 2/2 [00:06<00:00,  3.45s/it]


Predictions saved to ../data/prediction/noaug/ABCDEFGIJ/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABCDEFGIJ - aug - RF + ECFP:  50%|█████     | 1/2 [00:12<00:12, 12.25s/it]

Predictions saved to ../data/prediction/aug/ABCDEFGIJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ABCDEFGIJ - aug - RF + ECFP: 100%|██████████| 2/2 [00:13<00:00,  6.52s/it]


Predictions saved to ../data/prediction/aug/ABCDEFGIJ/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABCDEFGHJ - noaug - RF + ECFP:  50%|█████     | 1/2 [00:06<00:06,  6.05s/it]

Predictions saved to ../data/prediction/noaug/ABCDEFGHJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ABCDEFGHJ - noaug - RF + ECFP: 100%|██████████| 2/2 [00:06<00:00,  3.36s/it]


Predictions saved to ../data/prediction/noaug/ABCDEFGHJ/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABCDEFGHJ - aug - RF + ECFP:  50%|█████     | 1/2 [00:11<00:11, 11.18s/it]

Predictions saved to ../data/prediction/aug/ABCDEFGHJ/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ABCDEFGHJ - aug - RF + ECFP: 100%|██████████| 2/2 [00:12<00:00,  6.00s/it]


Predictions saved to ../data/prediction/aug/ABCDEFGHJ/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABCDEFGHI - noaug - RF + ECFP:  33%|███▎      | 1/3 [00:06<00:12,  6.14s/it]

Predictions saved to ../data/prediction/noaug/ABCDEFGHI/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ABCDEFGHI - noaug - RF + ECFP:  67%|██████▋   | 2/3 [00:06<00:02,  2.93s/it]

Predictions saved to ../data/prediction/noaug/ABCDEFGHI/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABCDEFGHI - noaug - RF + ECFP: 100%|██████████| 3/3 [00:07<00:00,  2.52s/it]


Predictions saved to ../data/prediction/noaug/ABCDEFGHI/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ AMPN (HUMAN).parquet


ABCDEFGHI - aug - RF + ECFP:  33%|███▎      | 1/3 [00:11<00:23, 11.74s/it]

Predictions saved to ../data/prediction/aug/ABCDEFGHI/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ TGFR1 (HUMAN).parquet


ABCDEFGHI - aug - RF + ECFP:  67%|██████▋   | 2/3 [00:12<00:05,  5.27s/it]

Predictions saved to ../data/prediction/aug/ABCDEFGHI/STL/RF_ECFP/oneADMET_LR-STL---pK$_{i}$ CXCR3 (HUMAN).parquet


ABCDEFGHI - aug - RF + ECFP: 100%|██████████| 3/3 [00:14<00:00,  4.75s/it]

Predictions saved to ../data/prediction/aug/ABCDEFGHI/STL/RF_ECFP/oneADMET_LR-STL---pIC$_{50}$ AMPN (HUMAN).parquet





In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, r2_score

def _plot_metric_boxplot(df_perf, metric, title):
    """Draw one box plot of ``metric`` per dataset, split by augmentation."""
    plt.figure(figsize=(16, 8))
    sns.boxplot(x="Dataset", y=metric, hue="Augmentation", data=df_perf)
    plt.title(title)
    plt.xticks(rotation=90)
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def collect_perf_stats_all(base_dir_aug, base_dir_noaug, folds, reps=(1, 2, 3), model_dir="RF_ECFP"):
    """Score saved STL predictions and plot R² / RMSE for Aug vs No Aug.

    Prediction files are read from ``<base_dir>/<fold>/STL/<model_dir>/``.
    Every ``*.parquet`` file name found in any fold of either tree is treated
    as a dataset, so endpoints present only in some folds are still scored.

    Args:
        base_dir_aug: Root of the predictions from models trained on augmented data.
        base_dir_noaug: Root of the predictions from models trained without augmentation.
        folds: Iterable of fold directory names.
        reps: Repetition numbers; column ``pY_rep{k}`` is scored for each ``k``.
        model_dir: Sub-directory holding the predictions of the model to score.

    Returns:
        A DataFrame with one row per (dataset, fold, augmentation, repetition)
        and columns ``Dataset``, ``Fold``, ``Augmentation``, ``RMSE``, ``R2``,
        ``Rep``. It is empty (and no plots are drawn) when nothing could be scored.

    Notes:
        Files that cannot be read or scored are skipped, with a warning that
        names the file and the error.
    """
    # Imported locally so this function does not depend on edits to the import cell.
    import warnings

    columns = ["Dataset", "Fold", "Augmentation", "RMSE", "R2", "Rep"]
    sources = [("Aug", base_dir_aug), ("No Aug", base_dir_noaug)]

    # Union of dataset files over every fold and both trees; a single fold
    # does not necessarily contain every endpoint.
    dataset_names = set()
    for _, base_dir in sources:
        for fold in folds:
            fold_dir = os.path.join(base_dir, fold, "STL", model_dir)
            if os.path.isdir(fold_dir):
                dataset_names.update(
                    name for name in os.listdir(fold_dir) if name.endswith(".parquet")
                )

    records = []
    for dataset_name in sorted(dataset_names):
        for aug_type, base_dir in sources:
            for fold in folds:
                file_path = os.path.join(base_dir, fold, "STL", model_dir, dataset_name)
                if not os.path.exists(file_path):
                    continue
                try:
                    df = pd.read_parquet(file_path)
                    y_true = df["Y"].to_numpy()
                    for k in reps:
                        y_pred = df[f"pY_rep{k}"].to_numpy()
                        # sqrt(MSE) instead of squared=False: that keyword is
                        # not accepted by recent scikit-learn releases.
                        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
                        r2 = r2_score(y_true, y_pred)
                        records.append({
                            "Dataset": dataset_name,
                            "Fold": fold,
                            "Augmentation": aug_type,
                            "RMSE": rmse,
                            "R2": r2,
                            "Rep": k,
                        })
                except Exception as exc:
                    # Stay best-effort, but make the failure visible.
                    warnings.warn(f"Skipping {file_path}: {exc!r}")
                    continue

    df_perf = pd.DataFrame(records, columns=columns)
    if df_perf.empty:
        warnings.warn("No prediction files could be scored; skipping plots.")
        return df_perf

    _plot_metric_boxplot(df_perf, "R2", "R² Distribution per Dataset (Aug vs No Aug)")
    _plot_metric_boxplot(df_perf, "RMSE", "RMSE Distribution per Dataset (Aug vs No Aug)")

    return df_perf

# Example usage
df_stats_all = collect_perf_stats_all(
    base_dir_aug="../data/prediction/aug",
    base_dir_noaug="../data/prediction/noaug",
    # Each fold name is "ABCDEFGHIJ" with one letter removed, in letter order.
    folds=["ABCDEFGHIJ".replace(letter, "") for letter in "ABCDEFGHIJ"],
)

In [None]:
# === MTL PREDICTION (Chemprop) ===
def run_mtl_predictions(data_root, combinations, model_path, chemprop_predict_path="chemprop_predict.py", output_root="../data/prediction"):
    """Run the Chemprop prediction script on every MTL test file.

    For each fold in ``combinations``, every ``*.parquet`` file in
    ``{data_root}/test/{combo}/MTL`` is read, its ``SMILES`` column is written
    to a temporary CSV, and ``chemprop_predict_path`` is invoked on it in a
    subprocess.

    Args:
        data_root: Root directory containing the ``test`` tree.
        combinations: Iterable of fold directory names.
        model_path: Directory passed to the script as ``--checkpoint_dir``.
        chemprop_predict_path: Path of the prediction script run with ``python``.
        output_root: Root directory for the prediction CSVs.

    Side effects:
        Writes ``{output_root}/aug/{combo}/MTL/CHEMPROP/<name>_preds.csv`` via
        the subprocess and prints one status line per file. A non-zero exit
        code is reported as a failure and processing continues with the next file.
    """
    # Imported locally so this function does not depend on edits to the import cell.
    import tempfile

    for combo in combinations:
        test_dir = f"{data_root}/test/{combo}/MTL"
        pred_dir = f"{output_root}/aug/{combo}/MTL/CHEMPROP"
        Path(pred_dir).mkdir(parents=True, exist_ok=True)

        # Filter first so the progress-bar total counts only the files that are processed.
        parquet_files = sorted(
            name for name in os.listdir(test_dir) if name.endswith(".parquet")
        )

        for file in tqdm(parquet_files, desc=f"{combo} - MTL Chemprop"):
            stem = os.path.splitext(file)[0]
            test_path = os.path.join(test_dir, file)
            output_csv = os.path.join(pred_dir, f"{stem}_preds.csv")

            df = pd.read_parquet(test_path)

            # The intermediate CSV lives in a throw-away directory: nothing in
            # the test data tree is overwritten or deleted, and it is cleaned
            # up even if the subprocess call raises.
            with tempfile.TemporaryDirectory() as tmp_dir:
                csv_path = os.path.join(tmp_dir, f"{stem}.csv")
                df[["SMILES"]].to_csv(csv_path, index=False)

                start_predict = time.perf_counter()
                result = subprocess.run([
                    "python", chemprop_predict_path,
                    "--test_path", csv_path,
                    "--checkpoint_dir", model_path,
                    "--preds_path", output_csv,
                    "--no_cuda"
                ])
                elapsed = time.perf_counter() - start_predict

            if result.returncode != 0:
                # Do not report success when the prediction script failed.
                print(f"[✗] Prediction failed for {file} (exit code {result.returncode})")
                continue

            print(f"[✓] Predicted {file} in {elapsed:.2f} seconds")

