In [None]:
!pip install rdkit

Collecting rdkit
  Downloading rdkit-2025.9.4-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (4.2 kB)
Downloading rdkit-2025.9.4-cp312-cp312-manylinux_2_28_x86_64.whl (36.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.6/36.6 MB[0m [31m76.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rdkit
Successfully installed rdkit-2025.9.4


In [None]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m38.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.7.0


In [None]:
!pip install --quiet optuna


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/413.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m413.9/413.9 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

%%writefile attentivefp_reviewer_compliant.py
"""
Reviewer-compliant AttentiveFP training script (Colab-friendly, notebook-independent)

Implements reviewer-requested robustness & reproducibility:
✅ train/val/test split (no test leakage)
✅ random OR Murcko scaffold split option
✅ Optuna hyperparameter tuning on VAL only
✅ early stopping on VAL loss
✅ repeated seeds + mean±std aggregation
✅ y-randomization sanity check (shuffle TRAIN labels only)
✅ imbalance sensitivity experiment via negative subsampling (inactive_ratio)
✅ class-weighted BCE loss via pos_weight
✅ saves split indices + per-seed metrics to disk

Expected CSV columns (default):
- SMILES
- PUBCHEM_ACTIVITY_OUTCOME  (Active/Inactive)
"""

import os
import json
import math
import random
import argparse
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim

from torch_geometric.nn import AttentiveFP
from torch_geometric.loader import DataLoader
from torch_geometric.utils import from_smiles

import os
# Module-level side effect: importing/running this script switches PyTorch into
# deterministic mode for the whole process.
# CUBLAS_WORKSPACE_CONFIG is set before any model is built; per the PyTorch
# reproducibility notes, cuBLAS needs this setting for deterministic results on CUDA
# (presumably it must be in place before the first cuBLAS call — confirm if this
# module is ever imported after other CUDA work has run).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # or ":16:8"
# Ask PyTorch to use deterministic algorithm implementations.
torch.use_deterministic_algorithms(True)

from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    f1_score,
    cohen_kappa_score,
    precision_score,
    recall_score,
    accuracy_score,
)

# RDKit scaffold split
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold

import optuna


# -----------------------------
# Reproducibility
# -----------------------------
def set_seed(seed: int) -> None:
    """Seed every random number generator the script relies on.

    Seeds Python's ``random``, NumPy's legacy global generator, and PyTorch
    (CPU and all CUDA devices), then puts cuDNN into deterministic,
    non-benchmarking mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# -----------------------------
# Config
# -----------------------------
@dataclass
class RunConfig:
    """Settings for one experiment (shared by every seed of the run).

    Only ``csv_path`` is required; every other field has a default.
    """

    # Path to the input CSV (read with pandas in run_one_seed).
    csv_path: str
    # Column names and the label string that counts as the positive class.
    smiles_col: str = "SMILES"
    label_col: str = "PUBCHEM_ACTIVITY_OUTCOME"
    active_value: str = "Active"

    # NOTE: this default is evaluated once, when the class body is executed,
    # not each time a RunConfig is created.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Splits
    split_mode: str = "random"  # random | scaffold
    test_frac: float = 0.20
    val_frac: float = 0.10  # fraction of full dataset
    # None means "not provided"; main() always fills this in from --seeds.
    seeds: Optional[List[int]] = None

    # Training
    batch_size: int = 64
    max_epochs: int = 75
    patience: int = 15  # epochs without validation-loss improvement before stopping

    # Optuna
    do_optuna: bool = False
    optuna_trials: int = 20
    trial_max_epochs: int = 25  # epoch cap used inside each Optuna trial

    # Sanity / robustness
    y_randomization: bool = False

    # Imbalance experiment: keep all actives, subsample inactives to approx 1:k
    inactive_ratio: Optional[int] = None

    # Root output directory; run_one_seed writes "splits/" and "metrics/" below it.
    save_dir: str = "runs_attentivefp"


# -----------------------------
# Data utilities
# -----------------------------
def to_binary_label(x: str, active_value: str) -> int:
    """Map a raw label to 1 (matches ``active_value``) or 0 (anything else).

    The raw value is converted to ``str`` and stripped of surrounding
    whitespace before the comparison; the comparison is case-sensitive.
    """
    cleaned = str(x).strip()
    return int(cleaned == active_value)


def murcko_scaffold(smiles: str) -> str:
    """Compute the Murcko scaffold of a molecule as a non-isomeric SMILES string.

    Returns an empty string when the SMILES cannot be parsed, when no scaffold
    object is produced, or when any step raises an exception.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        scaffold = MurckoScaffold.GetScaffoldForMol(mol) if mol is not None else None
        if scaffold is not None:
            return Chem.MolToSmiles(scaffold, isomericSmiles=False)
    except Exception:
        # Best-effort by design: any failure is reported as "no scaffold".
        pass
    return ""


def random_stratified_split(
    labels: List[int],
    test_frac: float,
    val_frac: float,
    seed: int,
) -> Tuple[List[int], List[int], List[int]]:
    """
    Class-stratified random split over sample positions 0..N-1.

    Each class (label 1, then label 0) is shuffled and cut independently, so
    every split receives roughly ``test_frac`` / ``val_frac`` of each class.
    Both fractions are relative to the full dataset; whatever is left over
    becomes the training set.

    Returns:
        (train_idx, val_idx, test_idx), each a shuffled list of indices.
    """
    rng = np.random.default_rng(seed)
    y = np.asarray(labels, dtype=int)

    # Shuffle positives first, then negatives, to keep the RNG draw order fixed.
    shuffled_by_class = []
    for cls in (1, 0):
        members = np.nonzero(y == cls)[0].tolist()
        rng.shuffle(members)
        shuffled_by_class.append(members)

    parts = {"train": [], "val": [], "test": []}
    for members in shuffled_by_class:
        size = len(members)
        n_test = int(round(test_frac * size))
        n_val = int(round(val_frac * size))
        cut = n_test + n_val
        parts["test"] = parts["test"] + members[:n_test]
        parts["val"] = parts["val"] + members[n_test:cut]
        parts["train"] = parts["train"] + members[cut:]

    # Mix the two classes inside each split.
    for split_name in ("train", "val", "test"):
        rng.shuffle(parts[split_name])

    return parts["train"], parts["val"], parts["test"]


def stratified_group_split(
    groups: List[str],
    labels: List[int],
    test_frac: float,
    val_frac: float,
    seed: int,
) -> Tuple[List[int], List[int], List[int]]:
    """
    Greedy group (scaffold) split that never divides a group between splits.

    Groups are shuffled, then assigned whole: to test while it fits within the
    test target, otherwise to val while it fits, otherwise to train. If test or
    val ends up short of its target, whole groups (smallest first) are moved out
    of train, but only while that brings the split closer to its target or the
    split would otherwise be empty. Split sizes are therefore approximate; group
    integrity is exact.

    NOTE: despite the function name, class balance is NOT enforced. ``labels``
    is only used for the dataset size; the name and signature are kept for
    backward compatibility.

    Args:
        groups: group id (e.g. scaffold SMILES) per sample, length N.
        labels: 0/1 label per sample, length N.
        test_frac: target test fraction of the full dataset.
        val_frac: target validation fraction of the full dataset.
        seed: seed for the group shuffle and the final index shuffles.

    Returns:
        (train_idx, val_idx, test_idx), each a shuffled list of sample indices.

    Raises:
        ValueError: if ``groups`` and ``labels`` differ in length.
    """
    if len(groups) != len(labels):
        raise ValueError(
            f"groups and labels must have the same length: {len(groups)} != {len(labels)}"
        )

    rng = np.random.default_rng(seed)
    n_total = len(labels)

    # group -> member indices (insertion-ordered, so iteration is deterministic)
    group_to_idx: Dict[str, List[int]] = {}
    for i, g in enumerate(groups):
        group_to_idx.setdefault(g, []).append(i)

    all_groups = list(group_to_idx.keys())
    rng.shuffle(all_groups)

    n_test_target = int(round(test_frac * n_total))
    n_val_target = int(round(val_frac * n_total))

    test_idx: List[int] = []
    val_idx: List[int] = []
    train_groups: List[str] = []

    # Greedy pass: fill test, then val, everything else goes to train.
    for g in all_groups:
        idxs = group_to_idx[g]
        if len(test_idx) + len(idxs) <= n_test_target:
            test_idx.extend(idxs)
        elif len(val_idx) + len(idxs) <= n_val_target:
            val_idx.extend(idxs)
        else:
            train_groups.append(g)

    def group_size(g: str) -> int:
        return len(group_to_idx[g])

    def top_up(dst: List[int], n_target: int) -> None:
        """Move whole groups from train into ``dst`` while that helps; keep >= 1 train group."""
        while len(dst) < n_target and len(train_groups) > 1:
            smallest = min(train_groups, key=group_size)  # first minimum -> deterministic
            deficit = n_target - len(dst)
            overshoot = len(dst) + group_size(smallest) - n_target
            # Stop once adding the group would leave us no closer to the target,
            # unless the split is still empty (an empty val/test set is unusable).
            if dst and overshoot >= deficit:
                break
            train_groups.remove(smallest)
            dst.extend(group_to_idx[smallest])

    top_up(test_idx, n_test_target)
    top_up(val_idx, n_val_target)

    train_idx: List[int] = [i for g in train_groups for i in group_to_idx[g]]

    rng.shuffle(train_idx)
    rng.shuffle(val_idx)
    rng.shuffle(test_idx)
    return train_idx, val_idx, test_idx


def apply_inactive_ratio_sampling(
    df: pd.DataFrame,
    labels: List[int],
    inactive_ratio: int,
    seed: int,
) -> pd.DataFrame:
    """
    Negative subsampling for the imbalance-stress experiment.

    Every active row (label 1) is kept; inactive rows (label 0) are sampled
    without replacement down to at most ``#active * inactive_ratio`` rows.
    No oversampling is performed. The surviving rows are returned in shuffled
    order with a fresh 0..n-1 index.
    """
    rng = np.random.default_rng(seed)
    y = np.asarray(labels, dtype=int)

    active_rows = np.nonzero(y == 1)[0]
    inactive_rows = np.nonzero(y == 0)[0]

    # Never ask for more inactives than exist.
    n_inactive_keep = min(len(inactive_rows), len(active_rows) * inactive_ratio)
    chosen_inactive = rng.choice(inactive_rows, size=n_inactive_keep, replace=False)

    row_order = np.concatenate([active_rows, chosen_inactive])
    rng.shuffle(row_order)

    subset = df.iloc[row_order]
    return subset.reset_index(drop=True)


def build_graphs_from_smiles(
    smiles_list: List[str],
    labels: List[int],
) -> Tuple[List, List[int], List[int]]:
    """
    Convert SMILES strings into labelled molecular graphs, skipping unusable ones.

    Args:
        smiles_list: SMILES string per sample.
        labels: 0/1 label per sample, aligned with ``smiles_list``.

    Returns:
      graphs: list of graph objects as returned by ``from_smiles``, with float
              node/edge features and ``y`` set to a float32 tensor of shape [1, 1]
      kept_indices: positions in ``smiles_list`` that produced a graph
      skipped_indices: positions that were skipped (conversion failed, a required
              field was missing, or the molecule has no atoms)
    """
    graphs = []
    kept = []
    skipped = []

    for i, smi in enumerate(smiles_list):
        try:
            g = from_smiles(smi)
            if g is None or g.x is None or g.edge_index is None:
                skipped.append(i)
                continue

            # A graph without atoms carries no signal and leaves a gap in the
            # batch vector used for graph-level pooling.
            # NOTE(review): some torch_geometric versions appear to turn an
            # unparsable SMILES into such an empty molecule instead of raising —
            # confirm against the installed version.
            if g.x.size(0) == 0:
                skipped.append(i)
                continue

            # ensure floats
            g.x = g.x.float()
            if getattr(g, "edge_attr", None) is not None:
                g.edge_attr = g.edge_attr.float()

            # label tensor [1,1]
            g.y = torch.tensor([labels[i]], dtype=torch.float32).view(-1, 1)

            graphs.append(g)
            kept.append(i)
        except Exception:
            # Best-effort by design: one bad record must not abort the whole run.
            skipped.append(i)

    return graphs, kept, skipped


# -----------------------------
# Model + training
# -----------------------------
class AttentiveFPBinary(nn.Module):
    """AttentiveFP wrapper that emits a single logit per graph (binary classification)."""

    def __init__(
        self,
        in_channels: int,
        edge_dim: int,
        hidden_channels: int,
        num_layers: int,
        num_timesteps: int,
        dropout: float,
    ):
        super().__init__()
        backbone_kwargs = dict(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=1,  # one raw logit; sigmoid is applied by the caller / loss
            edge_dim=edge_dim,
            num_layers=num_layers,
            num_timesteps=num_timesteps,
            dropout=dropout,
        )
        self.model = AttentiveFP(**backbone_kwargs)

    def forward(self, data):
        """Run the wrapped model on a batch exposing x, edge_index, edge_attr and batch."""
        node_features, connectivity = data.x, data.edge_index
        return self.model(node_features, connectivity, data.edge_attr, data.batch)


@torch.no_grad()
def evaluate(model, loader, device) -> Dict[str, float]:
    """Score ``model`` on every batch of ``loader`` and return classification metrics.

    Logits are passed through a sigmoid; hard predictions use a 0.5 threshold.
    ROC-AUC and PR-AUC are reported as NaN unless exactly two distinct label
    values are present. The model is switched to eval mode and left there.

    Returns:
        Dict with keys accuracy, roc_auc, pr_auc, precision, recall, f1, kappa.
    """
    model.eval()
    targets, scores = [], []

    for batch in loader:
        batch = batch.to(device)
        probabilities = torch.sigmoid(model(batch).view(-1))
        targets += batch.y.view(-1).detach().cpu().numpy().tolist()
        scores += probabilities.detach().cpu().numpy().tolist()

    y_true_arr = np.asarray(targets, dtype=np.float32)
    y_prob_arr = np.asarray(scores, dtype=np.float32)
    y_pred = (y_prob_arr >= 0.5).astype(int)

    # Ranking metrics are only defined when both classes are present.
    has_both_classes = len(np.unique(y_true_arr)) == 2
    nan = float("nan")

    return {
        "accuracy": float(accuracy_score(y_true_arr, y_pred)),
        "roc_auc": float(roc_auc_score(y_true_arr, y_prob_arr)) if has_both_classes else nan,
        "pr_auc": float(average_precision_score(y_true_arr, y_prob_arr)) if has_both_classes else nan,
        "precision": float(precision_score(y_true_arr, y_pred, zero_division=0)),
        "recall": float(recall_score(y_true_arr, y_pred, zero_division=0)),
        "f1": float(f1_score(y_true_arr, y_pred, zero_division=0)),
        "kappa": float(cohen_kappa_score(y_true_arr, y_pred)),
    }


def compute_pos_weight(train_labels: List[int], device: str) -> torch.Tensor:
    """Return the positive-class weight (#negatives / #positives) as a 1-element tensor.

    Falls back to a weight of 1.0 when the labels contain no positives.
    The tensor is float32 and created on ``device``.
    """
    positives = int(np.sum(train_labels))
    negatives = int(len(train_labels) - positives)
    if positives > 0:
        weight = negatives / max(1, positives)
    else:
        weight = 1.0
    return torch.tensor([weight], dtype=torch.float32, device=device)


def train_with_early_stopping(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: str,
    lr: float,
    weight_decay: float,
    max_epochs: int,
    patience: int,
    pos_weight: torch.Tensor,
) -> Tuple[nn.Module, Dict[str, float]]:
    """Train with Adam and class-weighted BCE, stopping early on validation loss.

    After each epoch the validation loss (mean of per-batch losses) is compared
    with the best seen so far; an improvement of more than 1e-5 snapshots the
    weights, otherwise a counter advances and training stops once it reaches
    ``patience``. The best snapshot, if any, is restored before returning.

    Returns:
        (model, val_metrics): the model with the best weights loaded, and the
        ``evaluate`` metrics on the validation loader plus ``val_loss_best``.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    def mean_validation_loss() -> float:
        """Average the per-batch validation losses (0.0 for an empty loader)."""
        model.eval()
        batch_losses = []
        with torch.no_grad():
            for val_batch in val_loader:
                val_batch = val_batch.to(device)
                val_logits = model(val_batch).view(-1)
                batch_losses.append(float(criterion(val_logits, val_batch.y.view(-1)).item()))
        return sum(batch_losses) / max(1, len(batch_losses))

    best_val = float("inf")
    best_state = None
    epochs_without_gain = 0

    for _ in range(max_epochs):
        model.train()
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            loss = criterion(model(batch).view(-1), batch.y.view(-1))
            loss.backward()
            optimizer.step()

        val_loss = mean_validation_loss()

        if val_loss < best_val - 1e-5:
            # Keep a CPU copy so the snapshot is independent of later updates.
            best_val = val_loss
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            epochs_without_gain = 0
            continue

        epochs_without_gain += 1
        if epochs_without_gain >= patience:
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    val_metrics = evaluate(model, val_loader, device)
    val_metrics["val_loss_best"] = float(best_val)
    return model, val_metrics


# -----------------------------
# Optuna tuning (VAL only)
# -----------------------------
def optuna_objective_factory(
    graphs: List,
    train_idx: List[int],
    val_idx: List[int],
    cfg: RunConfig,
    in_channels: int,
    edge_dim: int,
    train_labels: List[int],
):
    """Build an Optuna objective that trains on TRAIN and scores on VAL only.

    The data loaders and the positive-class weight are created once and shared
    by every trial.

    Args:
        graphs: all graph objects; ``train_idx`` / ``val_idx`` index into it.
        train_idx: positions of the training graphs.
        val_idx: positions of the validation graphs.
        cfg: run settings (device, batch size, trial epoch cap, patience).
        in_channels: node feature width.
        edge_dim: edge feature width.
        train_labels: training labels used to derive the positive-class weight.

    Returns:
        A callable ``objective(trial) -> float`` to be maximised: validation
        ROC-AUC, or validation accuracy when ROC-AUC is NaN.
    """
    device = cfg.device

    train_set = [graphs[i] for i in train_idx]
    val_set = [graphs[i] for i in val_idx]
    train_loader = DataLoader(train_set, batch_size=cfg.batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=cfg.batch_size, shuffle=False)

    pos_weight = compute_pos_weight(train_labels, device=device)

    def objective(trial: optuna.Trial) -> float:
        # Suggestion order is kept fixed so seeded samplers stay reproducible.
        hidden = trial.suggest_int("hidden_channels", 64, 256)
        num_layers = trial.suggest_int("num_layers", 2, 5)
        num_timesteps = trial.suggest_int("num_timesteps", 2, 6)
        dropout = trial.suggest_float("dropout", 0.0, 0.6)
        lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
        wd = trial.suggest_float("weight_decay", 1e-7, 1e-3, log=True)

        model = AttentiveFPBinary(
            in_channels=in_channels,
            edge_dim=edge_dim,
            hidden_channels=hidden,
            num_layers=num_layers,
            num_timesteps=num_timesteps,
            dropout=dropout,
        )

        # train_with_early_stopping already evaluates the restored best model
        # on val_loader, so its metrics are reused instead of running a second
        # validation pass.
        _, val_metrics = train_with_early_stopping(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            device=device,
            lr=lr,
            weight_decay=wd,
            max_epochs=cfg.trial_max_epochs,
            patience=max(5, cfg.patience // 2),  # trials stop earlier than the final fit
            pos_weight=pos_weight,
        )

        score = val_metrics.get("roc_auc", float("nan"))
        if np.isnan(score):
            # ROC-AUC is undefined (e.g. single-class validation set); fall back to accuracy.
            score = val_metrics.get("accuracy", 0.0)
        return float(score)

    return objective


# -----------------------------
# Run one seed
# -----------------------------
def run_one_seed(cfg: RunConfig, seed: int) -> Dict[str, object]:
    """Run the full pipeline for one seed and write its split and metrics files.

    Steps: seed the RNGs, read the CSV, optionally subsample inactives, build
    graphs (skipping unusable SMILES), split into train/val/test, optionally
    shuffle the TRAIN labels (y-randomization), optionally tune with Optuna on
    VAL, train the final model with early stopping, and evaluate on TEST.

    Files written under ``cfg.save_dir``:
      - splits/splits_seed{seed}.json: split indices (positions in the list of
        kept graphs) plus ``csv_row_of_graph``, which maps each graph position
        to the 0-based data-row position in the CSV as read by pandas.
      - metrics/metrics_seed{seed}.json: hyperparameters, VAL and TEST metrics.

    Args:
        cfg: run settings.
        seed: seed used for RNG seeding, subsampling, splitting and Optuna.

    Returns:
        Dict with keys seed, best_params, val_metrics, test_metrics, split_path.

    Raises:
        RuntimeError: if fewer than 50 graphs remain after SMILES conversion.
    """
    set_seed(seed)
    os.makedirs(cfg.save_dir, exist_ok=True)
    os.makedirs(os.path.join(cfg.save_dir, "splits"), exist_ok=True)
    os.makedirs(os.path.join(cfg.save_dir, "metrics"), exist_ok=True)

    df = pd.read_csv(cfg.csv_path)

    # Remember where every row sits in the CSV. Subsampling resets the index and
    # SMILES filtering drops rows, so without this the saved split indices could
    # not be traced back to the input file.
    row_col = "__csv_row__"
    df = df.assign(**{row_col: np.arange(len(df))})

    # labels from raw df (before invalid SMILES filtering)
    raw_labels = [to_binary_label(x, cfg.active_value) for x in df[cfg.label_col].tolist()]

    # Optional imbalance sampling (keeps all actives; samples inactives)
    if cfg.inactive_ratio is not None and cfg.inactive_ratio > 0:
        df = apply_inactive_ratio_sampling(df, raw_labels, cfg.inactive_ratio, seed)
        raw_labels = [to_binary_label(x, cfg.active_value) for x in df[cfg.label_col].tolist()]

    smiles = df[cfg.smiles_col].astype(str).tolist()
    csv_rows = df[row_col].tolist()

    # Build graphs (skip invalid SMILES)
    graphs, kept, skipped = build_graphs_from_smiles(smiles, raw_labels)
    labels_kept = [raw_labels[i] for i in kept]
    smiles_kept = [smiles[i] for i in kept]
    csv_row_of_graph = [int(csv_rows[i]) for i in kept]

    if len(graphs) < 50:
        raise RuntimeError(f"Too few valid graphs after SMILES parsing: {len(graphs)}")

    # Infer input dims from the first graph
    in_channels = int(graphs[0].x.size(-1))
    edge_dim = int(graphs[0].edge_attr.size(-1)) if getattr(graphs[0], "edge_attr", None) is not None else 0

    # Split indices on the kept list (0..len(graphs)-1)
    if cfg.split_mode == "scaffold":
        scaffolds = [murcko_scaffold(smi) for smi in smiles_kept]
        # Molecules without a scaffold each form their own singleton group.
        scaffolds = [s if s else f"NOSCAF_{i}" for i, s in enumerate(scaffolds)]
        train_idx, val_idx, test_idx = stratified_group_split(
            groups=scaffolds,
            labels=labels_kept,
            test_frac=cfg.test_frac,
            val_frac=cfg.val_frac,
            seed=seed,
        )
    else:
        train_idx, val_idx, test_idx = random_stratified_split(
            labels=labels_kept,
            test_frac=cfg.test_frac,
            val_frac=cfg.val_frac,
            seed=seed,
        )

    # y-randomization only on TRAIN labels (VAL/TEST labels stay untouched)
    if cfg.y_randomization:
        rng = np.random.default_rng(seed)
        y_train = np.array([labels_kept[i] for i in train_idx], dtype=int)
        rng.shuffle(y_train)
        for j, idx in enumerate(train_idx):
            graphs[idx].y = torch.tensor([int(y_train[j])], dtype=torch.float32).view(-1, 1)

    # Dataloaders
    train_set = [graphs[i] for i in train_idx]
    val_set = [graphs[i] for i in val_idx]
    test_set = [graphs[i] for i in test_idx]

    train_loader = DataLoader(train_set, batch_size=cfg.batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=cfg.batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=cfg.batch_size, shuffle=False)

    # Shuffling permutes labels but keeps the class counts, so the original
    # training labels give the same positive-class weight.
    train_labels = [labels_kept[i] for i in train_idx]

    # Default params (overridden by the Optuna best trial when tuning is enabled)
    best_params = dict(
        hidden_channels=128,
        num_layers=3,
        num_timesteps=3,
        dropout=0.2,
        lr=1e-3,
        weight_decay=1e-5,
    )

    # Optuna tuning on VAL only
    if cfg.do_optuna:
        objective = optuna_objective_factory(
            graphs=graphs,
            train_idx=train_idx,
            val_idx=val_idx,
            cfg=cfg,
            in_channels=in_channels,
            edge_dim=edge_dim,
            train_labels=train_labels,
        )
        sampler = optuna.samplers.TPESampler(seed=seed)
        study = optuna.create_study(direction="maximize", sampler=sampler)
        study.optimize(objective, n_trials=cfg.optuna_trials)

        # Defaults above remain in place for any key the study did not tune.
        best_params.update(study.best_params)

    # Train final model with best params
    model = AttentiveFPBinary(
        in_channels=in_channels,
        edge_dim=edge_dim,
        hidden_channels=int(best_params["hidden_channels"]),
        num_layers=int(best_params["num_layers"]),
        num_timesteps=int(best_params["num_timesteps"]),
        dropout=float(best_params["dropout"]),
    )

    pos_weight = compute_pos_weight(train_labels, device=cfg.device)

    model, val_metrics = train_with_early_stopping(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=cfg.device,
        lr=float(best_params["lr"]),
        weight_decay=float(best_params["weight_decay"]),
        max_epochs=cfg.max_epochs,
        patience=cfg.patience,
        pos_weight=pos_weight,
    )

    # TEST is touched exactly once, after all tuning and training decisions.
    test_metrics = evaluate(model, test_loader, cfg.device)

    # Save splits
    split_payload = {
        "seed": seed,
        "csv_path": cfg.csv_path,
        "split_mode": cfg.split_mode,
        "test_frac": cfg.test_frac,
        "val_frac": cfg.val_frac,
        "train_idx": train_idx,
        "val_idx": val_idx,
        "test_idx": test_idx,
        "csv_row_of_graph": csv_row_of_graph,
        "n_graphs": len(graphs),
        "skipped_smiles_count": len(skipped),
        "in_channels": in_channels,
        "edge_dim": edge_dim,
        "y_randomization": cfg.y_randomization,
        "inactive_ratio": cfg.inactive_ratio,
    }
    split_path = os.path.join(cfg.save_dir, "splits", f"splits_seed{seed}.json")
    with open(split_path, "w", encoding="utf-8") as f:
        json.dump(split_payload, f, indent=2)

    # Save metrics + params
    result = {
        "seed": seed,
        "best_params": best_params,
        "val_metrics": val_metrics,
        "test_metrics": test_metrics,
        "split_path": split_path,
    }
    metrics_path = os.path.join(cfg.save_dir, "metrics", f"metrics_seed{seed}.json")
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)

    return result


def aggregate_results(results: List[Dict[str, object]]) -> Dict[str, Dict[str, float]]:
    """Summarise per-seed TEST metrics as mean and sample standard deviation.

    NaN entries (and metrics missing from a result) are ignored. The standard
    deviation uses ddof=1 and is NaN when fewer than two valid values exist.

    Returns:
        ``{"test_summary": {metric: {"mean": ..., "std": ...}}}``
    """
    metric_names = ("accuracy", "roc_auc", "pr_auc", "precision", "recall", "f1", "kappa")
    missing = float("nan")

    summary = {}
    for name in metric_names:
        values = np.asarray(
            [run["test_metrics"].get(name, missing) for run in results],
            dtype=np.float64,
        )
        mean_value = float(np.nanmean(values))
        enough_values = np.sum(~np.isnan(values)) > 1
        std_value = float(np.nanstd(values, ddof=1)) if enough_values else missing
        summary[name] = {"mean": mean_value, "std": std_value}

    return {"test_summary": summary}


# -----------------------------
# CLI
# -----------------------------
def parse_args():
    """Parse the command-line options of the training script from ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = [
        ("--csv_path", dict(type=str, required=True, help="Path to AID*.csv file with SMILES + labels")),
        ("--split_mode", dict(type=str, default="random", choices=["random", "scaffold"])),
        ("--seeds", dict(type=str, default="0,1,2,3,4", help="Comma-separated seeds")),
        ("--save_dir", dict(type=str, default="runs_attentivefp")),
        ("--batch_size", dict(type=int, default=64)),
        ("--max_epochs", dict(type=int, default=75)),
        ("--patience", dict(type=int, default=15)),
        ("--do_optuna", dict(action="store_true", help="Enable Optuna tuning (VAL only)")),
        ("--optuna_trials", dict(type=int, default=20)),
        ("--trial_max_epochs", dict(type=int, default=25)),
        ("--split_test_frac", dict(type=float, default=0.20)),
        ("--split_val_frac", dict(type=float, default=0.10)),
        ("--y_randomization", dict(action="store_true", help="Shuffle TRAIN labels for sanity check")),
        ("--inactive_ratio", dict(type=int, default=None, help="If set, sample inactives to approx 1:k ratio")),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def main():
    """Command-line entry point: run every requested seed and write the summary.

    Builds a RunConfig from the parsed arguments, runs ``run_one_seed`` for each
    seed, aggregates the TEST metrics, and saves them to
    ``<save_dir>/summary_mean_std.json``.
    """
    args = parse_args()

    # "0, 1,,2" -> [0, 1, 2]; empty tokens are dropped.
    seed_list = [int(token.strip()) for token in args.seeds.split(",") if token.strip() != ""]

    cfg = RunConfig(
        csv_path=args.csv_path,
        split_mode=args.split_mode,
        test_frac=args.split_test_frac,
        val_frac=args.split_val_frac,
        seeds=seed_list,
        save_dir=args.save_dir,
        batch_size=args.batch_size,
        max_epochs=args.max_epochs,
        patience=args.patience,
        do_optuna=args.do_optuna,
        optuna_trials=args.optuna_trials,
        trial_max_epochs=args.trial_max_epochs,
        y_randomization=args.y_randomization,
        inactive_ratio=args.inactive_ratio,
    )

    os.makedirs(cfg.save_dir, exist_ok=True)

    banner = (
        f"Device: {cfg.device}",
        f"CSV: {cfg.csv_path}",
        f"Split: {cfg.split_mode} | test_frac={cfg.test_frac} val_frac={cfg.val_frac}",
        f"Seeds: {cfg.seeds}",
        f"Optuna: {cfg.do_optuna} | trials={cfg.optuna_trials}",
        f"y_randomization: {cfg.y_randomization} | inactive_ratio: {cfg.inactive_ratio}",
    )
    for line in banner:
        print(line)

    per_seed_results = []
    for seed in cfg.seeds:
        print(f"\n=== Running seed {seed} ===")
        outcome = run_one_seed(cfg, seed)
        print("Test metrics:", outcome["test_metrics"])
        per_seed_results.append(outcome)

    summary = aggregate_results(per_seed_results)
    summary_path = os.path.join(cfg.save_dir, "summary_mean_std.json")
    with open(summary_path, "w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2)

    print("\n=== Mean±Std (TEST) ===")
    for metric, stats in summary["test_summary"].items():
        print(f"{metric:>10}: {stats['mean']:.4f} ± {stats['std']:.4f}")
    print(f"\nSaved summary: {summary_path}")


if __name__ == "__main__":
    main()


Writing attentivefp_reviewer_compliant.py
