## Explore ADMET dataset distributions (ChEMBL curated)

This notebook loads curated parquet splits from `data/admet/`:

- `herg_{train,cal,test}.parquet`
- `logs_{train,cal,test}.parquet`
- `logp_{train,cal,test}.parquet`

It plots:

- Class balance for hERG
- Feature distributions using ChEMBL-provided properties (`full_mwt`, `alogp`, `psa`)

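It also writes per-endpoint, per-split summary statistics (row counts, missing-value counts, descriptor means, and either label counts or target mean/std) to `data/admet/summary.json`. Split files that are missing are reported as `{"rows": 0}` instead of raising an error.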


In [None]:
from __future__ import annotations

import json
from pathlib import Path

import pandas as pd


# Directory holding the curated `{endpoint}_{split}.parquet` files read by this notebook.
ADMET_DIR = Path("data/admet")
# Path of the JSON file that receives the per-endpoint/per-split summary.
OUT_SUMMARY = ADMET_DIR / "summary.json"


def load_split(endpoint: str, split: str) -> pd.DataFrame:
    """Read the parquet file for one endpoint/split pair.

    The file is looked up as ``ADMET_DIR / "{endpoint}_{split}.parquet"``.

    Args:
        endpoint: Endpoint prefix of the file name (e.g. ``"herg"``).
        split: Split suffix of the file name (e.g. ``"train"``).

    Returns:
        The parquet contents, or an empty ``DataFrame`` when the file does
        not exist, so callers can test ``df.empty`` instead of handling errors.
    """
    parquet_file = ADMET_DIR / f"{endpoint}_{split}.parquet"
    return pd.read_parquet(parquet_file) if parquet_file.exists() else pd.DataFrame()


def _nan_to_none(value):
    """Return ``float(value)``, or ``None`` when *value* is NaN/NA."""
    return None if pd.isna(value) else float(value)


def dataset_summary(df: pd.DataFrame) -> dict:
    """Summarise one dataset split as a JSON-serialisable dict.

    Args:
        df: A split as returned by ``load_split``; may be empty.

    Returns:
        ``{"rows": 0}`` for an empty frame. Otherwise a dict with:

        * ``rows``: number of rows.
        * ``n_compounds``: distinct ``molecule_chembl_id`` values, or ``None``
          when that column is absent.
        * ``missing_smiles``: null ``canonical_smiles`` count, or ``None``
          when that column is absent.
        * ``{col}_missing`` / ``{col}_mean`` for each of ``full_mwt``,
          ``alogp`` and ``psa`` that is present (values that cannot be parsed
          as numbers count as missing; the mean is ``None`` if nothing parses).
        * ``label_counts`` (keys ``"0"`` / ``"1"``) when every non-null ``y``
          is 0 or 1; otherwise ``y_mean`` and ``y_std``, each ``None`` when
          undefined.
    """
    if df.empty:
        return {"rows": 0}
    out = {
        "rows": int(len(df)),
        "n_compounds": int(df["molecule_chembl_id"].nunique()) if "molecule_chembl_id" in df.columns else None,
        "missing_smiles": int(df["canonical_smiles"].isna().sum()) if "canonical_smiles" in df.columns else None,
    }
    for col in ("full_mwt", "alogp", "psa"):
        if col in df.columns:
            s = pd.to_numeric(df[col], errors="coerce")
            out[f"{col}_missing"] = int(s.isna().sum())
            out[f"{col}_mean"] = _nan_to_none(s.mean())
    if "y" in df.columns:
        labels = df["y"].dropna()
        if set(labels.unique().tolist()).issubset({0, 1}):
            # Normalise keys through int() so float (0.0/1.0, e.g. after NaN
            # promotion) and bool labels are reported under the same "0"/"1"
            # keys as integer labels.
            counts: dict = {}
            for k, v in labels.value_counts().items():
                key = str(int(k))
                counts[key] = counts.get(key, 0) + int(v)
            out["label_counts"] = counts
        else:
            s = pd.to_numeric(df["y"], errors="coerce")
            out["y_mean"] = _nan_to_none(s.mean())
            # The sample std is NaN with fewer than two numeric values; emit
            # None so the summary never serialises a bare NaN token.
            out["y_std"] = _nan_to_none(s.std())
    return out


# Build the summary for every endpoint/split pair and persist it as JSON.
summary = {
    # Timestamp.utcnow() is deprecated since pandas 2.2; now(tz="UTC") returns
    # the same timezone-aware UTC timestamp without the warning.
    "generated_at": pd.Timestamp.now(tz="UTC").isoformat(),
    "admet_dir": str(ADMET_DIR),
    "endpoints": {},
}

for endpoint in ("herg", "logs", "logp"):
    summary["endpoints"][endpoint] = {}
    for split in ("train", "cal", "test"):
        # A missing split file comes back as an empty frame and is summarised
        # as {"rows": 0}.
        df = load_split(endpoint, split)
        summary["endpoints"][endpoint][split] = dataset_summary(df)

# Make sure the output directory exists before writing the summary file.
OUT_SUMMARY.parent.mkdir(parents=True, exist_ok=True)
OUT_SUMMARY.write_text(json.dumps(summary, indent=2, sort_keys=True))

# Bare expression so the notebook displays the summary as the cell output.
summary


In [None]:
from __future__ import annotations

import matplotlib.pyplot as plt
import pandas as pd


def plot_herg_balance() -> None:
    """Draw a stacked bar chart of hERG label counts for each split.

    Splits that are empty or have no ``y`` column are left out. When no split
    contributes any data, a message is printed and nothing is plotted.
    """
    records = []
    for split_name in ("train", "cal", "test"):
        frame = load_split("herg", split_name)
        if "y" in frame.columns and not frame.empty:
            counts = frame["y"].value_counts().to_dict()
            records.append(
                {
                    "split": split_name,
                    "0": int(counts.get(0, 0)),
                    "1": int(counts.get(1, 0)),
                }
            )

    if len(records) == 0:
        print("No hERG splits found.")
        return

    # One row per split, one column per class label.
    balance = pd.DataFrame(records).set_index("split")
    axis = balance.plot(kind="bar", stacked=True, figsize=(6, 3))
    axis.set_title("hERG class balance")
    axis.set_ylabel("rows")
    plt.tight_layout()
    plt.show()


# Show the hERG class-balance chart.
plot_herg_balance()

# For each endpoint, pool all splits and histogram the three ChEMBL descriptors.
for endpoint in ("herg", "logs", "logp"):
    df = pd.concat(
        [load_split(endpoint, s) for s in ("train", "cal", "test")],
        ignore_index=True,
    )
    if not df.empty:
        fig, axes = plt.subplots(1, 3, figsize=(12, 3))
        for ax, (col, title) in zip(
            axes,
            (("full_mwt", "MW"), ("alogp", "LogP (ChEMBL)"), ("psa", "TPSA")),
        ):
            if col in df.columns:
                # Values that cannot be parsed as numbers are dropped before plotting.
                s = pd.to_numeric(df[col], errors="coerce").dropna()
                ax.hist(s, bins=40)
                ax.set_title(f"{endpoint}: {title}")
            else:
                # Hide the panel when the descriptor column is absent.
                ax.axis("off")
        plt.tight_layout()
        plt.show()
