In [1]:
from __future__ import annotations

import itertools
import math
import random
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
import os
from dnn import GenericMLP, collate_dense_batch

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GroupShuffleSplit, GroupKFold, cross_validate, GridSearchCV, train_test_split

In [3]:
df = pd.read_csv("../data/MIMIC-ED/event_level_training_data.csv")
df.columns

Index(['stay_id', 'temperature', 'heartrate', 'resprate', 'o2sat', 'sbp',
       'dbp', 'pain', 'rhythm_flag', 'is_white', 'is_black', 'is_asian',
       'is_hispanic', 'is_other_race', 'gender_F', 'gender_M',
       'arrival_transport_AMBULANCE', 'arrival_transport_HELICOPTER',
       'arrival_transport_OTHER', 'arrival_transport_UNKNOWN',
       'arrival_transport_WALK IN', 'time_since_adm', 'gsn_16599.0',
       'gsn_43952.0', 'gsn_4490.0', 'gsn_66419.0', 'gsn_61716.0', 'is_sepsis'],
      dtype='object')

In [4]:
# count nans
df.isna().sum()

stay_id                         0
temperature                     0
heartrate                       0
resprate                        0
o2sat                           0
sbp                             0
dbp                             0
pain                            0
rhythm_flag                     0
is_white                        0
is_black                        0
is_asian                        0
is_hispanic                     0
is_other_race                   0
gender_F                        0
gender_M                        0
arrival_transport_AMBULANCE     0
arrival_transport_HELICOPTER    0
arrival_transport_OTHER         0
arrival_transport_UNKNOWN       0
arrival_transport_WALK IN       0
time_since_adm                  0
gsn_16599.0                     0
gsn_43952.0                     0
gsn_4490.0                      0
gsn_66419.0                     0
gsn_61716.0                     0
is_sepsis                       0
dtype: int64

In [5]:
# do gssplit with this small df
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
y = df["is_sepsis"]
X = df.drop(columns=["is_sepsis"])
train_idx, test_idx = next(gss.split(X, y, groups=df["stay_id"]))
x_train, x_test = X.iloc[train_idx], X.iloc[test_idx]
y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
# print column types in X
print(x_train.dtypes)


stay_id                           int64
temperature                     float64
heartrate                       float64
resprate                        float64
o2sat                           float64
sbp                             float64
dbp                             float64
pain                            float64
rhythm_flag                       int64
is_white                          int64
is_black                          int64
is_asian                          int64
is_hispanic                       int64
is_other_race                     int64
gender_F                          int64
gender_M                          int64
arrival_transport_AMBULANCE       int64
arrival_transport_HELICOPTER      int64
arrival_transport_OTHER           int64
arrival_transport_UNKNOWN         int64
arrival_transport_WALK IN         int64
time_since_adm                  float64
gsn_16599.0                       int64
gsn_43952.0                       int64
gsn_4490.0                        int64


In [11]:
# ============================
# Grid Search for GenericMLP (CUDA 12.1 optimized) + Save Best
# ============================
import os, math, random, itertools, json, datetime
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple, Sequence, NamedTuple, Optional
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

# -----------------------------
# Utilities
# -----------------------------
def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # ✅ CUDA 12.1/Ampere+ performance knobs
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    try:
        torch.backends.cuda.matmul.allow_tf32 = True  # TF32 for matmuls (A100/RTX30+)
        torch.backends.cudnn.allow_tf32 = True
    except Exception:
        pass

    try:
        # Prefer HFMA/TF32 kernels when applicable
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass


@dataclass
class TrainConfig:
    batch_size: int = 2048
    max_epochs: int = 30
    patience: int = 4
    grad_clip: float = 1.0
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    loss_type: str = "bce"          # "bce" or "mse"
    val_split: float = 0.2          # used only if we fall back to (X, y)
    num_workers: int = max(4, (os.cpu_count() or 8) // 2)
    prefetch_factor: int = 2
    amp: bool = True
    compile: bool = True            # ✅ will try torch.compile on CUDA
    use_fused_adamw: bool = True    # ✅ use fused AdamW when available


@dataclass
class GridResult:
    val_loss: float
    metrics: Dict[str, float]
    config: Dict[str, object]
    state_dict: Dict[str, torch.Tensor]  # ✅ keep best weights for this config


# -----------------------------
# Dataset (Fast path)
# -----------------------------
class FastTabularDataset(Dataset):
    """Stores full feature/target tensors (CPU) and slices by index."""
    def __init__(self, features_2d: torch.Tensor, targets_1d: torch.Tensor):
        assert features_2d.ndim == 2, "features must be 2D (N, D)"
        assert targets_1d.ndim == 1, "targets must be 1D (N,)"
        assert features_2d.size(0) == targets_1d.size(0)
        self.x = features_2d.contiguous()
        self.y = targets_1d.contiguous()

    def __len__(self) -> int:
        return self.x.size(0)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.x[idx], self.y[idx]


class Batch(NamedTuple):
    x: torch.Tensor  # (B, D)
    y: torch.Tensor  # (B, 1)

def collate_dense_batch(samples: Sequence[Tuple[torch.Tensor, torch.Tensor]]) -> Batch:
    xs, ys = zip(*samples)
    x = torch.stack(xs, dim=0)
    y = torch.as_tensor(ys, dtype=torch.float32).view(-1, 1)
    return Batch(x=x, y=y)


# -----------------------------
# Model
# -----------------------------
_ACTS: Dict[str, nn.Module] = {
    "relu": nn.ReLU(),
    "gelu": nn.GELU(approximate="tanh"),
    "silu": nn.SiLU()
}

class GenericMLP(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_sizes: Tuple[int, ...] = (256, 128),
        dropout: float = 0.0,
        activation: str = "relu",
        batch_norm: bool = False,
        final_dropout: float = 0.0,
    ):
        super().__init__()
        act = _ACTS[activation]
        layers: List[nn.Module] = []
        prev = input_size
        for h in hidden_sizes:
            layers.append(nn.Linear(prev, h, bias=not batch_norm))
            if batch_norm:
                layers.append(nn.BatchNorm1d(h))
            layers.append(act)
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            prev = h
        self.backbone = nn.Sequential(*layers)
        self.head = nn.Linear(prev, 1)
        self.final_dropout = nn.Dropout(final_dropout) if final_dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor, return_logits: bool = True) -> torch.Tensor:
        z = self.backbone(x)
        z = self.final_dropout(z)
        logits = self.head(z)
        if return_logits:
            return logits
        return torch.sigmoid(logits)


# -----------------------------
# Data loading (uses your pre-split, if present)
# -----------------------------
def make_fast_dataset_from_xy(X_np: np.ndarray, y_np: np.ndarray) -> FastTabularDataset:
    f = torch.tensor(X_np, dtype=torch.float32)
    t = torch.tensor(y_np, dtype=torch.float32)
    return FastTabularDataset(f, t)

def _maybe_use_presplit() -> Optional[FastTabularDataset]:
    """
    If x_train/y_train are defined in the notebook (as you mentioned),
    use them; otherwise return None and we'll fall back to full X,y.
    """
    g = globals()
    if all(k in g for k in ("x_train", "y_train")):
        X_np = g["x_train"].to_numpy() if hasattr(g["x_train"], "to_numpy") else np.asarray(g["x_train"])
        y_np = g["y_train"].to_numpy() if hasattr(g["y_train"], "to_numpy") else np.asarray(g["y_train"])
        return make_fast_dataset_from_xy(X_np, y_np)
    return None

def _fallback_full_dataset() -> FastTabularDataset:
    # Assumes pandas DataFrames/Series X, y exist (your original code).
    g = globals()
    X = g["X"]; y = g["y"]
    features_2d = torch.tensor(X.drop(columns=["stay_id"]).to_numpy(), dtype=torch.float32)
    targets_1d = torch.tensor(y.to_numpy(), dtype=torch.float32)
    return FastTabularDataset(features_2d, targets_1d)


# -----------------------------
# Loaders
# -----------------------------
def make_loaders(ds: Dataset, cfg: TrainConfig) -> Tuple[DataLoader, DataLoader]:
    # When pre-split is used, ds is already "train only"; we'll still carve a val split from it.
    n_val = max(1, int(len(ds) * cfg.val_split))
    n_train = len(ds) - n_val
    train_ds, val_ds = random_split(ds, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    pin = (cfg.device == "cuda")
    nw = cfg.num_workers if len(ds) > 1024 else 0
    pw = bool(nw > 0)

    train_loader = DataLoader(
        train_ds,
        batch_size=cfg.batch_size,
        shuffle=True,
        collate_fn=collate_dense_batch,
        num_workers=nw,
        pin_memory=pin,
        persistent_workers=pw,
        prefetch_factor=cfg.prefetch_factor if nw > 0 else None,
        drop_last=False,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=cfg.batch_size,
        shuffle=False,
        collate_fn=collate_dense_batch,
        num_workers=nw,
        pin_memory=pin,
        persistent_workers=pw,
        prefetch_factor=cfg.prefetch_factor if nw > 0 else None,
        drop_last=False,
    )
    return train_loader, val_loader


# -----------------------------
# Training / Evaluation
# -----------------------------
def _best_amp_dtype() -> torch.dtype:
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16

def _make_optimizer(params, cfg: TrainConfig, lr: float, weight_decay: float):
    # ✅ Prefer fused AdamW on CUDA if available (PyTorch 2.0+)
    use_fused = (cfg.use_fused_adamw and torch.cuda.is_available())
    try:
        if use_fused:
            return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay, fused=True)
    except TypeError:
        pass
    return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)

def train_and_validate(
    model: GenericMLP,
    train_loader: DataLoader,
    val_loader: DataLoader,
    cfg: TrainConfig,
    optim_kwargs: Dict,
) -> Tuple[float, Dict[str, float], Dict[str, torch.Tensor]]:
    device = torch.device(cfg.device)
    model.to(device)

    # Loss
    if cfg.loss_type == "bce":
        criterion = nn.BCEWithLogitsLoss()
        use_logits = True
    elif cfg.loss_type == "mse":
        criterion = nn.MSELoss()
        use_logits = False
    else:
        raise ValueError("loss_type must be 'bce' or 'mse'")

    optimizer = _make_optimizer(model.parameters(), cfg, lr=optim_kwargs["lr"], weight_decay=optim_kwargs["weight_decay"])
    scaler = torch.cuda.amp.GradScaler(enabled=(cfg.amp and device.type == "cuda"))
    amp_dtype = _best_amp_dtype()

    # ✅ Optional compile (PyTorch 2.x)
    if cfg.compile and hasattr(torch, "compile") and device.type == "cuda":
        try:
            model = torch.compile(model, mode="reduce-overhead", fullgraph=False)
        except Exception:
            pass

    best_val_loss = math.inf
    best_metrics: Dict[str, float] = {}
    best_state = None
    epochs_no_improve = 0

    for epoch in range(cfg.max_epochs):
        model.train()
        train_loss_sum = 0.0
        n_train = 0

        for batch in train_loader:
            x, y = batch.x.to(device, non_blocking=True), batch.y.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            if cfg.amp and device.type == "cuda":
                with torch.cuda.amp.autocast(dtype=amp_dtype):
                    if use_logits:
                        logits = model(x, return_logits=True)
                        loss = criterion(logits, y)
                    else:
                        probs = model(x, return_logits=False)
                        loss = criterion(probs, y)
                scaler.scale(loss).backward()
                if cfg.grad_clip is not None:
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                scaler.step(optimizer)
                scaler.update()
            else:
                if use_logits:
                    logits = model(x, return_logits=True)
                    loss = criterion(logits, y)
                else:
                    probs = model(x, return_logits=False)
                    loss = criterion(probs, y)
                loss.backward()
                if cfg.grad_clip is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                optimizer.step()

            bs = x.size(0)
            train_loss_sum += float(loss) * bs
            n_train += bs

        train_loss = train_loss_sum / max(1, n_train)

        # ---- Validation (fast path) ----
        model.eval()
        val_loss_sum = 0.0
        n_val = 0
        correct = 0

        with torch.inference_mode():
            ctx = torch.cuda.amp.autocast(dtype=amp_dtype) if (cfg.amp and device.type == "cuda") else torch.no_grad()
            with ctx:
                for batch in val_loader:
                    x, y = batch.x.to(device, non_blocking=True), batch.y.to(device, non_blocking=True)
                    if use_logits:
                        logits = model(x, return_logits=True)
                        loss = criterion(logits, y)
                        probs = torch.sigmoid(logits)
                    else:
                        probs = model(x, return_logits=False)
                        loss = criterion(probs, y)
                    bs = x.size(0)
                    val_loss_sum += float(loss) * bs
                    n_val += bs
                    preds = (probs >= 0.5).to(y.dtype)
                    correct += (preds == (y >= 0.5)).sum().item()

        val_loss = val_loss_sum / max(1, n_val)
        val_acc = correct / max(1, n_val)

        # ---- Early stopping ----
        if val_loss < best_val_loss - 1e-6:
            best_val_loss = val_loss
            best_metrics = {"val_loss": val_loss, "val_acc": val_acc, "train_loss": train_loss}
            epochs_no_improve = 0
            # ✅ keep raw (uncompiled) weights on CPU
            state_cpu = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            best_state = state_cpu
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= cfg.patience:
                break

    if best_state is None:
        raise RuntimeError("Training completed without capturing a best state.")
    return best_val_loss, best_metrics, best_state


# -----------------------------
# Grid search (+ save best model)
# -----------------------------
def _rebuild_model(input_size: int, cfg_dict: Dict[str, object]) -> GenericMLP:
    return GenericMLP(
        input_size=input_size,
        hidden_sizes=tuple(cfg_dict["hidden_sizes"]),
        dropout=float(cfg_dict["dropout"]),
        activation=str(cfg_dict["activation"]),
        batch_norm=bool(cfg_dict["batch_norm"]),
        final_dropout=float(cfg_dict["final_dropout"]),
    )

def _save_best_model(
    filepath: str,
    input_size: int,
    best_cfg: Dict[str, object],
    best_state: Dict[str, torch.Tensor],
    metrics: Dict[str, float],
    val_loss: float,
    train_cfg: TrainConfig,
):
    payload = {
        "model_class": "GenericMLP",
        "model_kwargs": {
            "input_size": int(input_size),
            "hidden_sizes": tuple(best_cfg["hidden_sizes"]),
            "dropout": float(best_cfg["dropout"]),
            "activation": str(best_cfg["activation"]),
            "batch_norm": bool(best_cfg["batch_norm"]),
            "final_dropout": float(best_cfg["final_dropout"]),
        },
        "state_dict": best_state,
        "grid_config": {k: (tuple(v) if isinstance(v, list) else v) for k, v in best_cfg.items()},
        "metrics": metrics,
        "val_loss": float(val_loss),
        "train_config": asdict(train_cfg),
        "created": datetime.datetime.now().isoformat(),
        "torch_version": torch.__version__,
        "cuda_version": torch.version.cuda if torch.cuda.is_available() else None,
        "device_saved": "cuda" if torch.cuda.is_available() else "cpu",
    }
    torch.save(payload, filepath)

    # ✅ Optional TorchScript export for deployment
    try:
        model = _rebuild_model(input_size, best_cfg)
        model.load_state_dict(best_state, strict=True)
        model.eval()
        example = torch.randn(1, input_size)
        scripted = torch.jit.trace(model, example)
        torch.jit.save(scripted, os.path.splitext(filepath)[0] + "_scripted.pt")
    except Exception:
        pass

def run_grid_search(save_path: str = "best_generic_mlp.pt"):
    print("Using device:", "cuda" if torch.cuda.is_available() else "cpu")
    seed_everything(7)

    # Prefer your pre-split x_train/y_train if present
    ds = _maybe_use_presplit()
    if ds is None:
        ds = _fallback_full_dataset()

    train_cfg = TrainConfig()
    train_loader, val_loader = make_loaders(ds, train_cfg)
    input_size = ds.x.size(-1)

    param_grid = {
        "hidden_sizes": [
            (128, 64),
            (256, 128),
            (256, 128, 64),
        ],
        "dropout": [0.0, 0.1, 0.25],
        "activation": ["relu", "gelu"],
        "batch_norm": [False, True],
        "final_dropout": [0.0, 0.1],
        "lr": [1e-3, 5e-4],
        "weight_decay": [0.0, 1e-4],
    }

    keys = list(param_grid.keys())
    grid_values = [param_grid[k] for k in keys]
    total_configs = math.prod(len(v) for v in grid_values)

    results: List[GridResult] = []
    best_overall: Optional[GridResult] = None  # ✅ track winner across grid

    print(f"Total configs: {total_configs}")

    for i, values in enumerate(itertools.product(*grid_values), 1):
        cfg_dict = dict(zip(keys, values))
        print(f"\n[{i}/{total_configs}] Config: {cfg_dict}")

        model = _rebuild_model(input_size, cfg_dict)

        best_val, metrics, best_state = train_and_validate(
            model,
            train_loader=train_loader,
            val_loader=val_loader,
            cfg=train_cfg,
            optim_kwargs={"lr": cfg_dict["lr"], "weight_decay": cfg_dict["weight_decay"]},
        )

        result = GridResult(val_loss=best_val, metrics=metrics, config=cfg_dict, state_dict=best_state)
        results.append(result)
        print(f"=> best_val_loss={best_val:.4f}, val_acc={metrics['val_acc']:.4f}")

        # ✅ Maintain global best
        if (best_overall is None) or (best_val < best_overall.val_loss - 1e-12):
            best_overall = result

    # Sort for leaderboard
    results.sort(key=lambda r: r.val_loss)

    print("\n===== Leaderboard (by val_loss) =====")
    for rank, result in enumerate(results[:10], 1):
        val_loss = result.val_loss
        metrics = result.metrics
        cfg_dict = result.config
        print(
            f"#{rank:>2} val_loss={val_loss:.4f} val_acc={metrics['val_acc']:.4f} "
            f"train_loss={metrics['train_loss']:.4f} | {cfg_dict}"
        )

    # ✅ Save the single best model across the grid
    assert best_overall is not None
    print(f"\nSaving best model to: {save_path}")
    _save_best_model(
        filepath=save_path,
        input_size=input_size,
        best_cfg=best_overall.config,
        best_state=best_overall.state_dict,
        metrics=best_overall.metrics,
        val_loss=best_overall.val_loss,
        train_cfg=train_cfg,
    )

    # Report best config/metrics
    print("\nBest Config:")
    for k, v in best_overall.config.items():
        print(f"  {k}: {v}")
    print("Best Metrics:")
    for k, v in best_overall.metrics.items():
        print(f"  {k}: {v:.4f}")
    print(f"Saved weights and metadata to: {save_path}")
    ts_path = os.path.splitext(save_path)[0] + "_scripted.pt"
    if os.path.exists(ts_path):
        print(f"(TorchScript export available at: {ts_path})")


# ============================
# (Optional) Quick single run
# ============================
if __name__ == "__main__":
    run_grid_search("best_generic_mlp.pt")


Using device: cuda
Total configs: 288

[1/288] Config: {'hidden_sizes': (128, 64), 'dropout': 0.0, 'activation': 'relu', 'batch_norm': False, 'final_dropout': 0.0, 'lr': 0.001, 'weight_decay': 0.0}


  scaler = torch.cuda.amp.GradScaler(enabled=(cfg.amp and device.type == "cuda"))
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/ykim3041/anaconda3/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ykim3041/anaconda3/lib/python3.12/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'FastTabularDataset' on <module '__main__' (<class '_frozen_importlib.BuiltinImporter'>)>


KeyboardInterrupt: 