# In [None]:
"""
ImmunoFoundationModel - JupyterLab single-file version

Research-grade Python code for multimodal modeling of women's reproductive health,
focused on reproductive aging / ovarian reserve and endometriosis risk.

This script integrates:
- Data schemas and preprocessing
- Multi-modal PyTorch Dataset
- Modality-specific encoders and ImmunoFoundationModel
- Multitask training engine + optional masked-feature pretraining
- Integrated gradients interpretability
- Example configs and a synthetic test pipeline

IMPORTANT:
- No biological constants or thresholds are hard-coded.
- Wherever clinical cut points are needed, insert them manually as `# TODO` with real values.
"""

from __future__ import annotations

import argparse
import json
import logging
import os
import pickle
import random
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import yaml
from captum.attr import IntegratedGradients
from sklearn import metrics
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset

# ---------------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------------


def set_seed(seed: int) -> None:
    """Seed the random number generators used in this file.

    Args:
        seed: Integer seed handed, in order, to Python's ``random``, NumPy's
            global generator, ``torch.manual_seed`` and
            ``torch.cuda.manual_seed_all``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def load_config(path: Path) -> Dict[str, Any]:
    """Read a YAML configuration file and return its parsed contents.

    Args:
        path: Location of the YAML file; opened as UTF-8 text.

    Returns:
        The object produced by ``yaml.safe_load`` for the file.
    """
    with open(path, mode="r", encoding="utf-8") as stream:
        parsed = yaml.safe_load(stream)
    return parsed


def configure_logging(level: int = logging.INFO) -> None:
    """Set up root logging via ``logging.basicConfig``.

    Args:
        level: Logging level passed through to ``logging.basicConfig``.
    """
    record_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(level=level, format=record_format)


# Module-level logger used by the preprocessing and training code below.
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Data schemas
# ---------------------------------------------------------------------------


@dataclass
class ModalitySchema:
    """Column-level description of one data modality.

    Attributes:
        name: Modality name (e.g., "methylation", "immune").
        id_column: Identifier column used to join this modality with the others.
        feature_columns: Names of the columns used as model features.
        categorical_columns: Optional subset of columns treated as categorical.
        column_mapping: Raw column name -> canonical internal column name.
    """

    # Field order is part of the constructor signature; do not reorder.
    name: str
    id_column: str = field(default="patient_id")
    feature_columns: List[str] = field(default_factory=list)
    categorical_columns: Optional[List[str]] = field(default=None)
    column_mapping: Dict[str, str] = field(default_factory=dict)


@dataclass
class TaskSchema:
    """Names of the target columns for each supervised task.

    Every attribute holds the column name (in the merged dataframe) of one
    task's target; ``None`` means no column is configured for that task.
    """

    # Regression-style targets.
    reproductive_age: Optional[str] = None
    oocyte_yield: Optional[str] = None
    # Endometriosis targets.
    endometriosis_presence: Optional[str] = None
    endometriosis_stage: Optional[str] = None


@dataclass
class DatasetSchema:
    """Bundle of every modality schema plus the task schema.

    Attributes:
        modalities: Modality name -> ``ModalitySchema``.
        tasks: ``TaskSchema`` naming the target columns.
    """

    # Keyed by modality name.
    modalities: Dict[str, ModalitySchema]
    # Target column names for the supervised tasks.
    tasks: TaskSchema


def create_schema_from_config(config: Dict[str, Any]) -> DatasetSchema:
    """Build a ``DatasetSchema`` from a parsed configuration dictionary.

    Args:
        config: Configuration with a ``"data"`` section containing a required
            ``"modalities"`` mapping and an optional ``"tasks"`` mapping.

    Returns:
        A ``DatasetSchema`` with one ``ModalitySchema`` per configured modality
        and a ``TaskSchema`` filled from the ``"tasks"`` mapping (missing task
        keys become ``None``).
    """
    modalities: Dict[str, ModalitySchema] = {
        modality_name: ModalitySchema(
            name=modality_name,
            id_column=spec.get("id_column", "patient_id"),
            feature_columns=spec.get("feature_columns", []),
            categorical_columns=spec.get("categorical_columns"),
            column_mapping=spec.get("column_mapping", {}),
        )
        for modality_name, spec in config["data"]["modalities"].items()
    }

    task_columns = config["data"].get("tasks", {})
    task_fields = (
        "reproductive_age",
        "oocyte_yield",
        "endometriosis_presence",
        "endometriosis_stage",
    )
    tasks = TaskSchema(**{task: task_columns.get(task) for task in task_fields})
    return DatasetSchema(modalities=modalities, tasks=tasks)


# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------


class DataPreprocessor:
    """Load, merge, split, and scale/encode multi-modal tabular data.

    One ``ColumnTransformer`` is built per modality: ``StandardScaler`` for the
    continuous feature columns and ``OneHotEncoder`` for the categorical ones.
    ``fit`` is meant to be called on the training split only; ``transform``
    then reuses the fitted transformers on any split. Fitted transformers and
    the derived feature names can be written to / read from ``artifacts_dir``.
    """

    def __init__(self, schema: DatasetSchema, artifacts_dir: Path):
        """Store the schema and create the artifacts directory if needed.

        Args:
            schema: Dataset schema describing modalities and tasks.
            artifacts_dir: Directory for pickled transformers and metadata.
        """
        self.schema = schema
        self.artifacts_dir = Path(artifacts_dir)
        self.artifacts_dir.mkdir(parents=True, exist_ok=True)
        self.transformers: Dict[str, ColumnTransformer] = {}
        self.feature_names_: Dict[str, List[str]] = {}

    @staticmethod
    def _load_table(path: Path) -> pd.DataFrame:
        """Read a table from disk (Parquet for .parquet/.pq, CSV otherwise).

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if not path.exists():
            raise FileNotFoundError(f"Expected data file at {path}")
        if path.suffix.lower() in {".parquet", ".pq"}:
            return pd.read_parquet(path)
        return pd.read_csv(path)

    def load_modality(
        self,
        modality_name: str,
        path: Path,
        extra_columns: Optional[List[str]] = None,
    ) -> pd.DataFrame:
        """Load a single modality table, apply column mapping, and trim columns.

        Args:
            modality_name: Key into ``self.schema.modalities``.
            path: Location of the table on disk.
            extra_columns: Non-feature columns (e.g. supervised targets) to
                carry along when they are present in the table.

        Returns:
            The loaded frame. With an explicit feature list it contains the id
            column, the features and any present extras; without one it
            contains every column that is not entirely NaN.

        Raises:
            ValueError: If an explicit feature list is configured and the id
                column or any feature column is missing from the table.

        Note:
            When the schema has no feature list for the modality, this method
            fills ``feature_columns`` on the schema object in place.
        """
        modality: ModalitySchema = self.schema.modalities[modality_name]
        df = self._load_table(path)

        # Apply column mapping if provided
        if modality.column_mapping:
            df = df.rename(columns=modality.column_mapping)

        extras = extra_columns or []
        retained_extras = [col for col in extras if col in df.columns]

        if modality.feature_columns:
            expected = set(modality.feature_columns + [modality.id_column])
            missing = expected - set(df.columns)
            if missing:
                raise ValueError(f"Missing columns for {modality_name}: {missing}")
            selected_cols = [modality.id_column] + modality.feature_columns + retained_extras
            # Deduplicate while preserving order
            unique_cols = list(dict.fromkeys(selected_cols))
            df = df[unique_cols]
        else:
            # No feature list provided: infer it from the table. Columns requested via
            # ``extra_columns`` (the supervised targets) stay in the frame but must not
            # become model inputs, otherwise the labels leak into the features.
            df = df.dropna(axis=1, how="all")
            excluded = {modality.id_column, *retained_extras}
            modality.feature_columns = [c for c in df.columns if c not in excluded]

        return df

    def merge_modalities(self, modality_frames: Dict[str, pd.DataFrame]) -> pd.DataFrame:
        """Inner-join all modalities on their id column; drop rows with any NaNs.

        Non-feature columns (typically target columns) that are already present
        in the merged frame are dropped from later frames before joining, so
        they keep their configured name instead of receiving pandas'
        ``_x``/``_y`` suffixes. The first occurrence wins.

        Raises:
            ValueError: If ``modality_frames`` is empty.
        """
        merged: Optional[pd.DataFrame] = None
        for name, frame in modality_frames.items():
            modality = self.schema.modalities[name]
            if merged is None:
                merged = frame
                continue
            # Only non-feature columns are deduplicated; a clash between declared
            # feature columns is left untouched so it is not silently papered over.
            protected = set(modality.feature_columns) | {modality.id_column}
            repeated = [c for c in frame.columns if c in merged.columns and c not in protected]
            if repeated:
                logger.warning(
                    "Dropping columns %s from %s because they are already present in the merged frame",
                    repeated,
                    name,
                )
                frame = frame.drop(columns=repeated)
            merged = merged.merge(frame, on=modality.id_column, how="inner")
        if merged is None:
            raise ValueError("No modalities provided for merge")
        merged = merged.dropna(axis=0, how="any")
        logger.info("Merged modalities shape: %s", merged.shape)
        return merged

    def _build_transformer(self, modality_name: str) -> ColumnTransformer:
        """Create the (unfitted) column transformer for one modality.

        Raises:
            ValueError: If the modality has neither continuous nor categorical
                columns configured.
        """
        modality = self.schema.modalities[modality_name]
        categorical = modality.categorical_columns or []
        continuous = [c for c in modality.feature_columns if c not in categorical]
        transformers = []
        if continuous:
            transformers.append(("continuous", StandardScaler(), continuous))
        if categorical:
            transformers.append(
                (
                    "categorical",
                    OneHotEncoder(handle_unknown="ignore"),
                    categorical,
                )
            )
        if not transformers:
            raise ValueError(f"No transformers configured for {modality_name}")
        # sparse_threshold=0.0 forces a dense ndarray. With the default threshold a
        # mostly-zero one-hot block makes ``transform`` return a SciPy sparse matrix,
        # which the downstream dataset cannot turn into tensors row by row.
        return ColumnTransformer(transformers, sparse_threshold=0.0)

    def fit(self, df: pd.DataFrame) -> None:
        """Fit transformers on feature columns for each modality.

        Also records, per modality, the output feature names: expanded one-hot
        names for the categorical block and the raw column names otherwise.
        """
        for modality_name, modality in self.schema.modalities.items():
            transformer = self._build_transformer(modality_name)
            transformer.fit(df[modality.feature_columns])
            self.transformers[modality_name] = transformer

            feature_names: List[str] = []
            for name, trans, cols in transformer.transformers_:
                if name == "categorical" and hasattr(trans, "get_feature_names_out"):
                    feature_names.extend(trans.get_feature_names_out(cols).tolist())
                elif isinstance(cols, list):
                    feature_names.extend(cols)
            self.feature_names_[modality_name] = feature_names
            logger.info(
                "Fitted transformer for %s with %d features",
                modality_name,
                len(feature_names),
            )

    def transform(self, df: pd.DataFrame) -> Dict[str, np.ndarray]:
        """Transform the dataframe into numeric arrays by modality.

        Returns:
            Modality name -> transformed array, for every fitted/loaded
            transformer.
        """
        transformed: Dict[str, np.ndarray] = {}
        for modality_name, transformer in self.transformers.items():
            modality = self.schema.modalities[modality_name]
            transformed[modality_name] = transformer.transform(df[modality.feature_columns])
        return transformed

    def fit_transform(self, df: pd.DataFrame) -> Dict[str, np.ndarray]:
        """Convenience method: fit on df, then transform df."""
        self.fit(df)
        return self.transform(df)

    def split(
        self,
        df: pd.DataFrame,
        test_size: float,
        val_size: float,
        random_state: int,
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Randomly split the merged dataframe into train/val/test sets.

        Args:
            df: Merged dataframe.
            test_size: Fraction of ``df`` reserved for the test split.
            val_size: Fraction of ``df`` reserved for the validation split.
            random_state: Seed used for both shuffles.

        Returns:
            ``(train_df, val_df, test_df)``.
        """
        train_df, test_df = train_test_split(
            df,
            test_size=test_size,
            random_state=random_state,
            shuffle=True,
        )
        # ``val_size`` is expressed relative to the full frame, so rescale it to the
        # portion that remains after the test split was removed.
        relative_val = val_size / max(1e-8, (1 - test_size))
        train_df, val_df = train_test_split(
            train_df,
            test_size=relative_val,
            random_state=random_state,
            shuffle=True,
        )
        return train_df, val_df, test_df

    def save_artifacts(self) -> None:
        """Save transformers and feature names to disk."""
        for modality_name, transformer in self.transformers.items():
            path = self.artifacts_dir / f"{modality_name}_transformer.pkl"
            with open(path, "wb") as f:
                pickle.dump(transformer, f)
        meta = {"feature_names": self.feature_names_}
        with open(self.artifacts_dir / "preprocess_meta.json", "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)
        logger.info("Saved preprocessing artifacts to %s", self.artifacts_dir)

    def load_artifacts(self) -> None:
        """Load transformers and feature names from disk.

        SECURITY NOTE: ``pickle.load`` executes arbitrary code embedded in the
        file. Only point ``artifacts_dir`` at artifacts you produced yourself.
        """
        feature_meta_path = self.artifacts_dir / "preprocess_meta.json"
        with open(feature_meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)
        self.feature_names_ = meta.get("feature_names", {})
        for modality_name in self.schema.modalities:
            path = self.artifacts_dir / f"{modality_name}_transformer.pkl"
            with open(path, "rb") as f:
                self.transformers[modality_name] = pickle.load(f)
        logger.info("Loaded preprocessing artifacts from %s", self.artifacts_dir)


# ---------------------------------------------------------------------------
# Dataset and collator
# ---------------------------------------------------------------------------


class MultiModalDataset(Dataset):
    """PyTorch Dataset for multi-modal tabular data.

    Args:
        inputs: Dict mapping modality names to numpy arrays of shape (n_samples, n_features).
        targets: Dict mapping task names to numpy arrays of targets.

    Raises:
        ValueError: If ``inputs`` is empty, or if any modality/target array
            does not have the same number of samples as the first modality.
    """

    def __init__(self, inputs: Dict[str, np.ndarray], targets: Dict[str, np.ndarray]):
        if not inputs:
            raise ValueError("MultiModalDataset requires at least one input modality")
        self.inputs = inputs
        self.targets = targets
        # The first modality defines the sample count; everything else must agree,
        # otherwise indexing would fail (or silently misalign) deep inside training.
        self.length = next(iter(inputs.values())).shape[0]
        mismatched = {
            name: values.shape[0]
            for name, values in inputs.items()
            if values.shape[0] != self.length
        }
        mismatched.update(
            {name: len(values) for name, values in targets.items() if len(values) != self.length}
        )
        if mismatched:
            raise ValueError(
                f"Expected {self.length} samples in every array, got mismatches: {mismatched}"
            )

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one sample as a flat dict of modality tensors and task targets.

        Modality rows are converted to float tensors; targets keep the dtype
        that ``torch.as_tensor`` derives from the stored value.
        """
        sample: Dict[str, torch.Tensor] = {
            modality: torch.as_tensor(values[idx]).float()
            for modality, values in self.inputs.items()
        }
        for task, target in self.targets.items():
            sample[task] = torch.as_tensor(target[idx])
        return sample


class BatchCollator:
    """Callable collate function that stacks per-sample dicts key by key."""

    def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Stack every key of the first sample across the whole batch.

        Args:
            batch: List of per-sample dictionaries sharing the same keys.

        Returns:
            Dict with the same keys; each value is the samples' tensors stacked
            along a new leading dimension.
        """
        reference_keys = batch[0].keys()
        return {
            key: torch.stack([sample[key] for sample in batch], dim=0)
            for key in reference_keys
        }


# ---------------------------------------------------------------------------
# Models: encoders and foundation model
# ---------------------------------------------------------------------------


def build_mlp(input_dim: int, hidden_dims: Tuple[int, ...], dropout: float = 0.1) -> nn.Sequential:
    """Assemble an MLP made of Linear -> BatchNorm1d -> ReLU -> Dropout stages.

    Args:
        input_dim: Width of the input features.
        hidden_dims: Output width of each stage, in order. An empty sequence
            yields an empty ``nn.Sequential``.
        dropout: Dropout probability used in every stage.

    Returns:
        ``nn.Sequential`` containing four modules per entry of ``hidden_dims``.
    """
    widths = (input_dim, *hidden_dims)
    modules: List[nn.Module] = []
    # Consecutive width pairs give the (in, out) sizes of each stage.
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules += [
            nn.Linear(fan_in, fan_out),
            nn.BatchNorm1d(fan_out),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
    return nn.Sequential(*modules)


class MethylationEncoder(nn.Module):
    """MLP encoder for DNA methylation features.

    Args:
        input_dim: Width of the incoming feature vector.
        hidden_dims: Hidden layer widths handed to ``build_mlp``.
    """

    def __init__(self, input_dim: int, hidden_dims: Tuple[int, ...]):
        super().__init__()
        self.net = build_mlp(input_dim=input_dim, hidden_dims=hidden_dims)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Pass ``x`` through the MLP and return the result."""
        encoded = self.net(x)
        return encoded


class ImmuneEncoder(nn.Module):
    """MLP encoder for immune marker and cytokine features.

    Args:
        input_dim: Width of the incoming feature vector.
        hidden_dims: Hidden layer widths handed to ``build_mlp``.
    """

    def __init__(self, input_dim: int, hidden_dims: Tuple[int, ...]):
        super().__init__()
        self.net = build_mlp(input_dim=input_dim, hidden_dims=hidden_dims)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Pass ``x`` through the MLP and return the result."""
        encoded = self.net(x)
        return encoded


class MitoEncoder(nn.Module):
    """MLP encoder for mitochondrial features.

    Args:
        input_dim: Width of the incoming feature vector.
        hidden_dims: Hidden layer widths handed to ``build_mlp``.
    """

    def __init__(self, input_dim: int, hidden_dims: Tuple[int, ...]):
        super().__init__()
        self.net = build_mlp(input_dim=input_dim, hidden_dims=hidden_dims)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Pass ``x`` through the MLP and return the result."""
        encoded = self.net(x)
        return encoded


class ClinicalEncoder(nn.Module):
    """MLP encoder for clinical and lifestyle covariates.

    Args:
        input_dim: Width of the incoming feature vector.
        hidden_dims: Hidden layer widths handed to ``build_mlp``.
    """

    def __init__(self, input_dim: int, hidden_dims: Tuple[int, ...]):
        super().__init__()
        self.net = build_mlp(input_dim=input_dim, hidden_dims=hidden_dims)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Pass ``x`` through the MLP and return the result."""
        encoded = self.net(x)
        return encoded


def encoder_dispatch(modality: str, input_dim: int, hidden_dims: Tuple[int, ...]) -> nn.Module:
    """Return appropriate encoder module for a given modality name.

    Only the requested encoder is instantiated. Building all four encoders to
    pick one allocates three unused networks (costly for wide inputs) and
    consumes random-number-generator state for weights that are thrown away.

    Args:
        modality: One of "methylation", "immune", "mitochondrial", "clinical".
        input_dim: Width of the modality's feature vector.
        hidden_dims: Hidden layer widths for the encoder MLP.

    Returns:
        A freshly constructed encoder for ``modality``.

    Raises:
        ValueError: If ``modality`` is not a supported name.
    """
    encoder_classes: Dict[str, type] = {
        "methylation": MethylationEncoder,
        "immune": ImmuneEncoder,
        "mitochondrial": MitoEncoder,
        "clinical": ClinicalEncoder,
    }
    if modality not in encoder_classes:
        raise ValueError(f"Unsupported modality {modality}")
    return encoder_classes[modality](input_dim, hidden_dims)


class ImmunoFoundationModel(nn.Module):
    """Multimodal model: per-modality encoders, a shared backbone, and task heads.

    Each modality is encoded separately, the encodings are concatenated along
    the feature dimension, passed through a shared MLP backbone, and fed to one
    single-output linear head per task.

    Outputs:
        - reproductive_age: continuous regression
        - oocyte_yield: continuous regression
        - endometriosis_presence: binary classification (logits)
        - endometriosis_stage: optional continuous/ordinal regression (if enabled)
    """

    def __init__(
        self,
        input_dims: Dict[str, int],
        encoder_hidden: Tuple[int, ...] = (128, 64),
        backbone_hidden: Tuple[int, ...] = (128, 64),
        dropout: float = 0.1,
        enable_stage: bool = True,
    ):
        """Build encoders, backbone and heads.

        Args:
            input_dims: Modality name -> input feature width.
            encoder_hidden: Hidden widths shared by every modality encoder.
            backbone_hidden: Hidden widths of the shared backbone.
            dropout: Dropout probability used in the backbone.
            enable_stage: Whether to add the ``endometriosis_stage`` head.
        """
        super().__init__()
        self.encoders = nn.ModuleDict()
        fused_width = 0
        for modality, input_dim in input_dims.items():
            self.encoders[modality] = encoder_dispatch(modality, input_dim, encoder_hidden)
            # With no hidden layers the encoder is an empty stack, so its output
            # width is the raw input width.
            fused_width += encoder_hidden[-1] if encoder_hidden else input_dim

        self.backbone = build_mlp(fused_width, backbone_hidden, dropout=dropout)
        head_width = backbone_hidden[-1] if backbone_hidden else fused_width

        task_names = ["reproductive_age", "oocyte_yield", "endometriosis_presence"]
        if enable_stage:
            task_names.append("endometriosis_stage")
        self.heads = nn.ModuleDict({task: nn.Linear(head_width, 1) for task in task_names})

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:  # type: ignore[override]
        """Run every encoder, fuse the encodings, and apply each task head.

        Args:
            inputs: Modality name -> input tensor; must contain every modality
                the model was built with.

        Returns:
            Task name -> head output tensor.
        """
        encoded = [encoder(inputs[modality]) for modality, encoder in self.encoders.items()]
        shared = self.backbone(torch.cat(encoded, dim=1))
        return {task: head(shared) for task, head in self.heads.items()}


class MaskedFeatureModel(nn.Module):
    """Wrap a foundation model's encoders with linear reconstruction heads.

    For each modality in ``input_dims`` a linear layer maps the encoder's
    output back to the modality's input width.
    """

    def __init__(self, foundation: ImmunoFoundationModel, input_dims: Dict[str, int]):
        """Create one reconstruction head per modality.

        Args:
            foundation: Model whose ``encoders`` are reused.
            input_dims: Modality name -> input feature width to reconstruct.

        Raises:
            ValueError: If an encoder's ``net`` contains no ``nn.Linear`` layer,
                so its output width cannot be determined.
        """
        super().__init__()
        self.foundation = foundation
        self.reconstructors = nn.ModuleDict()
        for mod, dim in input_dims.items():
            # The encoder's output width is the out_features of its last Linear layer.
            linear_layers = [
                layer
                for layer in foundation.encoders[mod].net  # type: ignore[attr-defined]
                if isinstance(layer, nn.Linear)
            ]
            if not linear_layers:
                raise ValueError(f"Unable to infer latent dimension for modality {mod}")
            self.reconstructors[mod] = nn.Linear(linear_layers[-1].out_features, dim)

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:  # type: ignore[override]
        """Encode each modality and reconstruct its input from the encoding.

        Returns:
            Modality name -> reconstructed tensor.
        """
        # Encode everything first, then reconstruct.
        latents: Dict[str, torch.Tensor] = {}
        for mod, encoder in self.foundation.encoders.items():
            latents[mod] = encoder(inputs[mod])

        reconstructions: Dict[str, torch.Tensor] = {}
        for mod, latent in latents.items():
            reconstructions[mod] = self.reconstructors[mod](latent)
        return reconstructions


# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------


def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Return MAE, RMSE, and R^2 for regression.

    Args:
        y_true: Ground-truth target values.
        y_pred: Predicted values.

    Returns:
        Dict with keys ``"mae"``, ``"rmse"`` and ``"r2"`` as Python floats.
    """
    mae = metrics.mean_absolute_error(y_true, y_pred)
    # The ``squared=False`` keyword of ``mean_squared_error`` was deprecated and later
    # removed from scikit-learn; taking the square root of the MSE gives the same RMSE
    # for the single-output targets used here and works on every version.
    rmse = np.sqrt(metrics.mean_squared_error(y_true, y_pred))
    r2 = metrics.r2_score(y_true, y_pred)
    return {"mae": float(mae), "rmse": float(rmse), "r2": float(r2)}


def classification_metrics(y_true: np.ndarray, y_pred_logits: np.ndarray) -> Dict[str, float]:
    """Return AUROC, AUPRC, accuracy, sensitivity, and specificity for binary classification.

    Args:
        y_true: Binary ground-truth labels (0/1).
        y_pred_logits: Raw model logits; ``(N,)`` and ``(N, 1)`` are both accepted.

    Returns:
        Dict with keys ``"auroc"``, ``"auprc"``, ``"accuracy"``, ``"sensitivity"``
        and ``"specificity"``. AUROC is NaN when ``y_true`` holds a single
        class; sensitivity/specificity are NaN when their denominator is zero.
    """
    # Flatten so column-vector head outputs line up with 1-D label arrays.
    y_true = np.asarray(y_true).ravel()
    logits = np.asarray(y_pred_logits, dtype=np.float64).ravel()
    # Sigmoid written via tanh: identical mathematically to 1 / (1 + exp(-x)) but
    # it cannot overflow for large-magnitude logits (exp() overflows quickly,
    # especially on float32 model outputs).
    y_prob = 0.5 * (1.0 + np.tanh(0.5 * logits))
    if len(np.unique(y_true)) > 1:
        auroc = metrics.roc_auc_score(y_true, y_prob)
    else:
        auroc = float("nan")
    precision, recall, _ = metrics.precision_recall_curve(y_true, y_prob)
    auprc = metrics.auc(recall, precision)
    # Fixed 0.5 probability cut-off for the thresholded metrics.
    preds = (y_prob >= 0.5).astype(int)
    acc = metrics.accuracy_score(y_true, preds)
    tn, fp, fn, tp = metrics.confusion_matrix(y_true, preds, labels=[0, 1]).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
    specificity = tn / (tn + fp) if (tn + fp) > 0 else float("nan")
    return {
        "auroc": auroc,
        "auprc": auprc,
        "accuracy": acc,
        "sensitivity": sensitivity,
        "specificity": specificity,
    }


# ---------------------------------------------------------------------------
# Training engine
# ---------------------------------------------------------------------------


@dataclass
class TrainingConfig:
    """Hyperparameters for supervised training."""

    # Optimisation schedule.
    epochs: int = 5
    lr: float = 1e-3
    weight_decay: float = 1e-4
    batch_size: int = 32
    # Device string handed to ``torch.device``.
    device: str = "cpu"
    # Maximum gradient norm used for clipping.
    grad_clip: float = 1.0
    # Number of non-improving epochs tolerated before early stopping.
    patience: int = 3
    modality_dropout: float = 0.0  # placeholder if you want to expand


class Trainer:
    """Trainer handling multitask supervised learning.

    Holds the model, an Adam optimizer, a ``ReduceLROnPlateau`` scheduler and
    one loss function per task. Batches are flat dictionaries: keys that name
    a task in ``loss_fns`` are targets, every other key is a model input.
    """

    def __init__(self, model: ImmunoFoundationModel, config: TrainingConfig):
        """Move the model to the configured device and build optimizer/scheduler.

        Args:
            model: Model to train; called with a dict of inputs and expected to
                return a dict of per-task predictions.
            config: Training hyperparameters.
        """
        self.model = model
        self.config = config
        self.device = torch.device(config.device)
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
        self.lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode="min",
            patience=2,
            factor=0.5,
        )
        self.loss_fns: Dict[str, nn.Module] = {
            "reproductive_age": nn.MSELoss(),
            "oocyte_yield": nn.MSELoss(),
            "endometriosis_presence": nn.BCEWithLogitsLoss(),
            "endometriosis_stage": nn.MSELoss(),
        }

    def _split_batch(
        self, batch: Dict[str, torch.Tensor]
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Separate a batch into model inputs (moved to device) and task targets."""
        inputs = {k: v.to(self.device) for k, v in batch.items() if k not in self.loss_fns}
        targets = {k: v for k, v in batch.items() if k in self.loss_fns}
        return inputs, targets

    def _snapshot_state(self) -> Dict[str, torch.Tensor]:
        """Return an independent CPU copy of the model's current state dict.

        ``clone()`` is essential: ``state_dict()`` hands out tensors that share
        storage with the live parameters and ``.cpu()`` is a no-op for tensors
        already on the CPU, so without a copy the snapshot would keep changing
        as training continues.
        """
        return {k: v.detach().cpu().clone() for k, v in self.model.state_dict().items()}

    def _compute_loss(self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Average the per-task losses for every predicted task present in ``batch``.

        Raises:
            ValueError: If none of the predicted tasks has a target in ``batch``.
        """
        losses: List[torch.Tensor] = []
        for task, pred in outputs.items():
            if task not in batch:
                continue
            target = batch[task].float().to(self.device)
            # Head outputs carry a trailing singleton dimension; match the target shape.
            pred = pred.view_as(target)
            losses.append(self.loss_fns[task](pred, target))
        if not losses:
            raise ValueError("No losses computed; check task configuration")
        return torch.stack(losses).mean()

    def train_epoch(self, loader: DataLoader) -> float:
        """Run one optimisation pass over ``loader``.

        Returns:
            Mean batch loss, or ``0.0`` when the loader yields no batches.
        """
        self.model.train()
        running_loss = 0.0
        batches = 0
        for batch in loader:
            inputs, targets = self._split_batch(batch)
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self._compute_loss(outputs, targets)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
            self.optimizer.step()
            running_loss += loss.item()
            batches += 1
        # Counting batches (as ``evaluate`` does) avoids dividing by zero on an empty loader.
        return running_loss / max(1, batches)

    @torch.no_grad()
    def evaluate(self, loader: DataLoader) -> Dict[str, float]:
        """Compute the mean loss and per-task metrics over ``loader``.

        Returns:
            Dict with ``"loss"`` plus ``"<task>_<metric>"`` entries for every
            task that had both predictions and targets.
        """
        self.model.eval()
        all_outputs: Dict[str, List[np.ndarray]] = {task: [] for task in self.loss_fns}
        all_targets: Dict[str, List[np.ndarray]] = {task: [] for task in self.loss_fns}
        total_loss = 0.0
        batches = 0
        for batch in loader:
            inputs, targets = self._split_batch(batch)
            outputs = self.model(inputs)
            loss = self._compute_loss(outputs, targets)
            total_loss += loss.item()
            batches += 1
            for task, pred in outputs.items():
                if task not in targets:
                    continue
                all_outputs[task].append(pred.cpu().numpy())
                all_targets[task].append(targets[task].cpu().numpy())
        metrics_report: Dict[str, float] = {"loss": total_loss / max(1, batches)}
        for task, preds in all_outputs.items():
            if not preds or not all_targets[task]:
                continue
            y_true = np.concatenate(all_targets[task])
            y_pred = np.concatenate(preds)
            if task in {"reproductive_age", "oocyte_yield", "endometriosis_stage"}:
                task_metrics = regression_metrics(y_true, y_pred)
            else:
                task_metrics = classification_metrics(y_true, y_pred)
            for k, v in task_metrics.items():
                metrics_report[f"{task}_{k}"] = v
        return metrics_report

    def fit(self, train_loader: DataLoader, val_loader: DataLoader, output_dir: Path) -> Dict[str, float]:
        """Train with early stopping and save the best checkpoint.

        Args:
            train_loader: Batches used for optimisation.
            val_loader: Batches used for validation after every epoch.
            output_dir: Directory that receives ``best_model.pt``.

        Returns:
            Validation metrics of the last epoch that was run (not necessarily
            the best epoch).
        """
        best_val = float("inf")
        patience_counter = 0
        best_state: Optional[Dict[str, torch.Tensor]] = None
        output_dir.mkdir(parents=True, exist_ok=True)
        val_metrics: Dict[str, float] = {}
        for epoch in range(self.config.epochs):
            train_loss = self.train_epoch(train_loader)
            val_metrics = self.evaluate(val_loader)
            self.lr_scheduler.step(val_metrics["loss"])
            logger.info(
                "Epoch %d train_loss=%.4f val_loss=%.4f",
                epoch + 1,
                train_loss,
                val_metrics["loss"],
            )
            if val_metrics["loss"] < best_val:
                best_val = val_metrics["loss"]
                # Must be a real copy, otherwise later epochs overwrite the "best" weights.
                best_state = self._snapshot_state()
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= self.config.patience:
                    logger.info("Early stopping at epoch %d", epoch + 1)
                    break
        if best_state:
            torch.save(best_state, output_dir / "best_model.pt")
        return val_metrics


def pretrain_masked_reconstruction(
    model: ImmunoFoundationModel,
    loader: DataLoader,
    input_dims: Dict[str, int],
    device: str,
    epochs: int,
    mask_prob: float = 0.15,
) -> None:
    """Pretrain the model's encoders by reconstructing randomly zeroed features.

    Args:
        model: Foundation model whose encoders are trained in place.
        loader: Yields batch dictionaries; keys matching ``model.heads`` are
            treated as targets and ignored.
        input_dims: Modality name -> input feature width.
        device: Device the masked-feature module and batches are moved to.
        epochs: Number of passes over ``loader``.
        mask_prob: Per-element probability of zeroing an input value.
    """
    mask_module = MaskedFeatureModel(model, input_dims).to(device)
    optimizer = optim.Adam(mask_module.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()

    for epoch in range(epochs):
        epoch_loss = 0.0
        for batch in loader:
            # Everything that is not a task head name is a modality input.
            clean = {name: value.to(device) for name, value in batch.items() if name not in model.heads}

            corrupted: Dict[str, torch.Tensor] = {}
            for modality, tensor in clean.items():
                hidden = torch.rand_like(tensor) < mask_prob
                corrupted[modality] = tensor.masked_fill(hidden, 0.0)

            optimizer.zero_grad()
            reconstructions = mask_module(corrupted)
            # The loss compares the reconstruction with the full uncorrupted input.
            per_modality = [loss_fn(reconstructions[m], clean[m]) for m in reconstructions]
            loss = torch.stack(per_modality).mean()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        logger.info("Pretrain epoch %d loss=%.4f", epoch + 1, epoch_loss / len(loader))


# ---------------------------------------------------------------------------
# Interpretability with Integrated Gradients
# ---------------------------------------------------------------------------


def compute_integrated_gradients(
    model: torch.nn.Module,
    inputs: Dict[str, torch.Tensor],
    target_head: str,
    baseline: float = 0.0,
    n_steps: int = 32,
) -> Dict[str, torch.Tensor]:
    """Attribute one task head's output to the input features of every modality.

    Args:
        model: Trained ImmunoFoundationModel; switched to eval mode here.
        inputs: Dictionary of modality -> tensor with batch dimension.
        target_head: Name of the task head whose output is explained.
        baseline: Constant value that fills the baseline tensor of each modality.
        n_steps: Number of interpolation steps passed to Integrated Gradients.

    Returns:
        Modality name -> detached attribution tensor.

    Raises:
        ValueError: If the model's outputs do not contain ``target_head``.
    """
    model.eval()
    modality_names = list(inputs)
    tracked = tuple(inputs[name].clone().requires_grad_(True) for name in modality_names)

    def forward_func(*tensors: torch.Tensor) -> torch.Tensor:
        # Captum passes tensors positionally; rebuild the dict the model expects.
        outputs = model(dict(zip(modality_names, tensors)))
        if target_head not in outputs:
            raise ValueError(f"Task head {target_head} not found in model outputs")
        return outputs[target_head]

    explainer = IntegratedGradients(forward_func)
    reference = tuple(torch.full_like(tensor, baseline) for tensor in tracked)
    ig_values = explainer.attribute(tracked, baselines=reference, n_steps=n_steps)
    return {name: attr.detach() for name, attr in zip(modality_names, ig_values)}


def summarize_feature_importance(attributions: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Reduce attributions to one score per feature for each modality.

    Args:
        attributions: Modality name -> attribution tensor.

    Returns:
        Modality name -> mean of the absolute attributions over dimension 0.
    """
    summary: Dict[str, torch.Tensor] = {}
    for modality, attr in attributions.items():
        summary[modality] = torch.mean(torch.abs(attr), dim=0)
    return summary


# ---------------------------------------------------------------------------
# High-level helpers for training/inference (CLI-style but callable in notebook)
# ---------------------------------------------------------------------------


def prepare_targets(config: Dict[str, Any], df: pd.DataFrame) -> Dict[str, np.ndarray]:
    """Collect the configured target columns of ``df`` as numpy arrays.

    Args:
        config: Configuration whose ``config["data"]["tasks"]`` maps task names
            to column names (``None`` disables a task).
        df: Dataframe containing the target columns.

    Returns:
        Task name -> array of target values, for every task with a column.

    Raises:
        ValueError: If a configured column is absent from ``df``.
    """
    task_columns = config["data"].get("tasks", {})
    targets: Dict[str, np.ndarray] = {}
    for task_name, column in task_columns.items():
        if column is not None:
            if column not in df.columns:
                raise ValueError(f"Expected target column {column} for task {task_name}")
            targets[task_name] = df[column].to_numpy()
    return targets


def build_dataloaders(
    config: Dict[str, Any],
    preprocessor: DataPreprocessor,
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    test_df: pd.DataFrame,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Turn the three split dataframes into train/val/test dataloaders.

    Args:
        config: Configuration; ``config["training"]`` supplies ``batch_size``
            (default 32) and ``drop_last`` (default False, train loader only).
        preprocessor: Fitted preprocessor used to transform each split.
        train_df: Training split.
        val_df: Validation split.
        test_df: Test split.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the train loader
        shuffles.
    """
    frames = {"train": train_df, "val": val_df, "test": test_df}
    # Transform all splits first, then extract targets, then wrap into datasets.
    split_inputs = {split: preprocessor.transform(frame) for split, frame in frames.items()}
    split_targets = {split: prepare_targets(config, frame) for split, frame in frames.items()}
    datasets = {
        split: MultiModalDataset(split_inputs[split], split_targets[split]) for split in frames
    }

    training_section = config["training"]
    batch_size = training_section.get("batch_size", 32)
    drop_last = training_section.get("drop_last", False)
    collator = BatchCollator()

    train_loader = DataLoader(
        datasets["train"],
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collator,
        drop_last=drop_last,
    )
    val_loader = DataLoader(
        datasets["val"],
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collator,
        drop_last=False,
    )
    test_loader = DataLoader(
        datasets["test"],
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collator,
        drop_last=False,
    )
    return train_loader, val_loader, test_loader


def run_training(config_path: Path) -> None:
    """Train a model end to end from a YAML config (also callable from a notebook).

    Steps, in order: load the config, configure logging, seed the RNGs, load
    and merge every configured modality, split into train/val/test, fit and
    save the preprocessor on the training split, build the model, optionally
    run masked-feature pretraining, fit, evaluate on the test split, and write
    ``metrics.json`` under the artifacts directory.

    Args:
        config_path: Path to the YAML configuration file.
    """
    config = load_config(config_path)
    configure_logging()
    set_seed(config.get("seed", 42))

    artifacts_dir = Path(config.get("artifacts_dir", "artifacts"))
    artifacts_dir.mkdir(parents=True, exist_ok=True)

    preprocessor = DataPreprocessor(create_schema_from_config(config), artifacts_dir / "preprocess")

    # Target columns are passed as extra columns to every modality load.
    label_columns = [col for col in config["data"].get("tasks", {}).values() if col]
    frames_by_modality: Dict[str, pd.DataFrame] = {
        modality: preprocessor.load_modality(
            modality,
            Path(spec["path"]),
            extra_columns=label_columns,
        )
        for modality, spec in config["data"]["modalities"].items()
    }
    merged = preprocessor.merge_modalities(frames_by_modality)

    train_df, val_df, test_df = preprocessor.split(
        merged,
        test_size=config["data"].get("test_size", 0.2),
        val_size=config["data"].get("val_size", 0.1),
        random_state=config.get("seed", 42),
    )

    # Fit on the training split only, then persist the fitted preprocessor.
    preprocessor.fit(train_df)
    preprocessor.save_artifacts()

    train_loader, val_loader, test_loader = build_dataloaders(
        config, preprocessor, train_df, val_df, test_df
    )

    # Per-modality input width is read off the transformed training inputs.
    train_inputs = train_loader.dataset.inputs  # type: ignore[attr-defined]
    input_dims = {name: train_inputs[name].shape[1] for name in train_inputs}

    model_section = config["model"]
    stage_column = config["data"].get("tasks", {}).get("endometriosis_stage")
    model = ImmunoFoundationModel(
        input_dims=input_dims,
        encoder_hidden=tuple(model_section.get("encoder_hidden", [128, 64])),
        backbone_hidden=tuple(model_section.get("backbone_hidden", [128, 64])),
        dropout=model_section.get("dropout", 0.1),
        enable_stage=stage_column is not None,
    )

    training_section = config["training"]
    training_defaults = {
        "epochs": 5,
        "lr": 1e-3,
        "weight_decay": 1e-4,
        "batch_size": 32,
        "device": "cpu",
        "grad_clip": 1.0,
        "patience": 3,
        "modality_dropout": 0.0,
    }
    training_cfg = TrainingConfig(
        **{key: training_section.get(key, fallback) for key, fallback in training_defaults.items()}
    )

    trainer = Trainer(model, training_cfg)

    if training_section.get("pretrain_masked", False):
        pretrain_masked_reconstruction(
            model,
            train_loader,
            input_dims=input_dims,
            device=training_cfg.device,
            epochs=training_section.get("pretrain_epochs", 3),
            mask_prob=training_section.get("mask_prob", 0.15),
        )

    val_metrics = trainer.fit(train_loader, val_loader, artifacts_dir / "checkpoints")
    test_metrics = trainer.evaluate(test_loader)

    metrics_path = artifacts_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as handle:
        json.dump({"val": val_metrics, "test": test_metrics}, handle, indent=2)
    logger.info("Saved metrics to %s", metrics_path)


def _predictions_to_list(values: np.ndarray) -> List[Any]:
    """Convert one output head's array to a list with one entry per sample.

    Size-1 axes after the first are dropped, so an ``(N, 1)`` head becomes a
    flat list of ``N`` numbers. The leading sample axis is always kept, so a
    single-row input still yields a one-element list instead of a bare scalar.
    """
    singleton_axes = tuple(axis for axis in range(1, values.ndim) if values.shape[axis] == 1)
    if singleton_axes:
        values = values.squeeze(axis=singleton_axes)
    # A 0-d output (no sample axis at all) is promoted to a one-element list.
    return np.atleast_1d(values).tolist()


def run_inference(config_path: Path, input_path: Path, model_checkpoint: Path) -> Dict[str, List[float]]:
    """Run inference on new data using saved preprocessing artifacts and a checkpoint.

    Args:
        config_path: YAML configuration used for training.
        input_path: Directory holding the new modality files. Each file must
            have the same filename as the corresponding ``path`` entry in the
            training config.
        model_checkpoint: File containing the model ``state_dict``.

    Returns:
        Mapping from output name to per-sample predictions. The list always has
        one entry per input row, including when there is exactly one row.
        The results are also printed as JSON.
    """
    config = load_config(config_path)
    configure_logging()
    schema = create_schema_from_config(config)
    preprocessor = DataPreprocessor(schema, Path(config.get("artifacts_dir", "artifacts")) / "preprocess")
    preprocessor.load_artifacts()

    modality_frames: Dict[str, pd.DataFrame] = {}
    for name, modality_cfg in config["data"]["modalities"].items():
        # Reuse the training filename, but look for it under input_path.
        filename = Path(modality_cfg["path"]).name
        modality_frames[name] = preprocessor.load_modality(name, input_path / filename)

    merged = preprocessor.merge_modalities(modality_frames)
    inputs = preprocessor.transform(merged)

    input_dims = {name: arr.shape[1] for name, arr in inputs.items()}
    model = ImmunoFoundationModel(
        input_dims=input_dims,
        encoder_hidden=tuple(config["model"].get("encoder_hidden", [128, 64])),
        backbone_hidden=tuple(config["model"].get("backbone_hidden", [128, 64])),
        dropout=config["model"].get("dropout", 0.1),
        enable_stage=config["data"].get("tasks", {}).get("endometriosis_stage") is not None,
    )
    # SECURITY NOTE: torch.load unpickles the file, so only load checkpoints
    # from a trusted source. Consider weights_only=True once the minimum
    # supported torch version allows it.
    state = torch.load(model_checkpoint, map_location="cpu")
    model.load_state_dict(state)
    model.eval()

    tensors = {k: torch.tensor(v).float() for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(tensors)
    results = {k: _predictions_to_list(v.numpy()) for k, v in outputs.items()}
    print(json.dumps(results, indent=2))
    return results


# ---------------------------------------------------------------------------
# Example configs as Python dicts (you can save them to YAML if you want)
# ---------------------------------------------------------------------------

# Example configuration: endometriosis presence and stage prediction.
# The two reproductive-aging tasks are disabled (None).
example_endometriosis_config: Dict[str, Any] = {
    "seed": 7,
    "artifacts_dir": "artifacts/endometriosis",
    # ---- data: split fractions, one entry per modality, and task targets ----
    "data": {
        "test_size": 0.2,
        "val_size": 0.1,
        "modalities": {
            "clinical": {
                "path": "data/patients.csv",
                "id_column": "patient_id",
                "feature_columns": [
                    "age_years",
                    "bmi",
                    "amh_ng_ml",
                    "afc",
                ],
                "categorical_columns": [
                    "ethnicity",
                    "smoking_status",
                    "parity",
                ],
                "column_mapping": {},
            },
            "methylation": {
                "path": "data/methylation.parquet",
                "id_column": "patient_id",
                # TODO: fill with real CpG/region features
                "feature_columns": [],
            },
            "immune": {
                "path": "data/immune.csv",
                "id_column": "patient_id",
                # TODO: fill with real immune markers
                "feature_columns": [],
            },
            "mitochondrial": {
                "path": "data/mitochondrial.csv",
                "id_column": "patient_id",
                # TODO: fill with mitochondrial scores
                "feature_columns": [],
            },
        },
        # Task name -> target column; None disables the task.
        "tasks": {
            "reproductive_age": None,
            "oocyte_yield": None,
            "endometriosis_presence": "clinical_endometriosis",
            # TODO: ensure clinical staging schema
            "endometriosis_stage": "endometriosis_stage",
        },
    },
    # ---- model: hidden layer widths and dropout ----
    "model": {
        "encoder_hidden": [64, 32],
        "backbone_hidden": [64, 32],
        "dropout": 0.15,
    },
    # ---- training: optimisation settings ----
    "training": {
        "epochs": 8,
        "lr": 0.0005,
        "weight_decay": 0.0001,
        "batch_size": 16,
        "device": "cpu",
        "grad_clip": 1.0,
        "patience": 3,
        "pretrain_masked": False,
    },
}

# Example configuration: reproductive aging (reproductive age, oocyte yield)
# plus endometriosis presence; staging is disabled (None).
example_reproductive_aging_config: Dict[str, Any] = {
    "seed": 42,
    "artifacts_dir": "artifacts/reproductive_aging",
    "data": {
        "test_size": 0.2,
        "val_size": 0.1,
        "modalities": {
            "clinical": {
                "path": "data/patients.csv",
                "id_column": "patient_id",
                # "oocyte_yield" is deliberately NOT listed here. It is a
                # prediction target (see "tasks"), and feeding a target back in
                # as an input feature is target leakage.
                "feature_columns": [
                    "age_years",
                    "bmi",
                    "amh_ng_ml",
                    "afc",
                ],
                "categorical_columns": [
                    "ethnicity",
                    "smoking_status",
                    "parity",
                    "cycle_phase",
                ],
                "column_mapping": {},
            },
            "methylation": {
                "path": "data/methylation.parquet",
                "id_column": "patient_id",
                "feature_columns": [],  # TODO: CpG or region aggregate columns
            },
            "immune": {
                "path": "data/immune.csv",
                "id_column": "patient_id",
                "feature_columns": [],  # TODO: immune markers and cytokines
            },
            "mitochondrial": {
                "path": "data/mitochondrial.csv",
                "id_column": "patient_id",
                "feature_columns": [],  # TODO: mitochondrial scores
            },
        },
        # Task name -> target column; None disables the task.
        "tasks": {
            "reproductive_age": "reproductive_age_label",  # TODO: replace with real column
            "oocyte_yield": "oocyte_yield",
            "endometriosis_presence": "clinical_endometriosis",
            "endometriosis_stage": None,
        },
    },
    "model": {
        "encoder_hidden": [128, 64],
        "backbone_hidden": [128, 64],
        "dropout": 0.2,
    },
    "training": {
        "epochs": 10,
        "lr": 0.001,
        "weight_decay": 0.0001,
        "batch_size": 32,
        "device": "cpu",
        "grad_clip": 1.0,
        "patience": 3,
        "pretrain_masked": True,
        "pretrain_epochs": 2,
        "mask_prob": 0.15,
    },
}


def save_config_to_yaml(config: Dict[str, Any], path: Path) -> None:
    """Write a config dict (e.g. one of the examples above) to ``path`` as YAML.

    Missing parent directories are created first; an existing file is
    overwritten.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        yaml.safe_dump(config, stream=handle)


# ---------------------------------------------------------------------------
# Synthetic pipeline demo (non-biological, just to validate code paths)
# ---------------------------------------------------------------------------


def run_synthetic_pipeline_demo(base_dir: Path = Path("synthetic_demo")) -> None:
    """Create tiny synthetic tables, run preprocessing + one epoch of training.

    This uses non-biological placeholder data to validate shapes and code paths.
    The RNGs are seeded first, so the generated tables and the model
    initialisation use the same seed on every run.

    Args:
        base_dir: Directory that receives the synthetic CSV files (under
            ``data/``) and the artifacts directory. Created if missing.
    """
    seed = 123
    # Seed before drawing any synthetic values; otherwise the tables (and the
    # reported loss) change from run to run despite the config declaring a seed.
    set_seed(seed)

    base_dir.mkdir(parents=True, exist_ok=True)
    data_dir = base_dir / "data"
    data_dir.mkdir(exist_ok=True)

    # Synthetic tables
    patient_ids = [f"p{i}" for i in range(12)]
    n_patients = len(patient_ids)
    clinical = pd.DataFrame(
        {
            "patient_id": patient_ids,
            "age_years": np.linspace(30, 40, n_patients),
            "bmi": np.linspace(20, 28, n_patients),
            "amh_ng_ml": np.linspace(1, 3, n_patients),
            "afc": np.linspace(10, 20, n_patients),
            "oocyte_yield": np.linspace(5, 15, n_patients),
            "clinical_endometriosis": np.random.randint(0, 2, n_patients),
            "repro_age_target": np.linspace(32, 45, n_patients),
        }
    )
    methylation = pd.DataFrame(
        {
            "patient_id": patient_ids,
            "meth_1": np.random.rand(n_patients),
            "meth_2": np.random.rand(n_patients),
        }
    )
    immune = pd.DataFrame(
        {
            "patient_id": patient_ids,
            "immune_1": np.random.rand(n_patients),
            "immune_2": np.random.rand(n_patients),
        }
    )
    mitochondrial = pd.DataFrame(
        {
            "patient_id": patient_ids,
            "mito_1": np.random.rand(n_patients),
        }
    )

    paths = {
        "clinical": data_dir / "patients.csv",
        "methylation": data_dir / "methylation.csv",
        "immune": data_dir / "immune.csv",
        "mitochondrial": data_dir / "mitochondrial.csv",
    }
    clinical.to_csv(paths["clinical"], index=False)
    methylation.to_csv(paths["methylation"], index=False)
    immune.to_csv(paths["immune"], index=False)
    mitochondrial.to_csv(paths["mitochondrial"], index=False)

    config = {
        "seed": seed,
        "artifacts_dir": str(base_dir / "artifacts"),
        "data": {
            "test_size": 0.2,
            "val_size": 0.1,
            "modalities": {
                "clinical": {
                    "path": str(paths["clinical"]),
                    "id_column": "patient_id",
                    # "oocyte_yield" is a task target below, so it is kept out
                    # of the input features to avoid target leakage.
                    "feature_columns": [
                        "age_years",
                        "bmi",
                        "amh_ng_ml",
                        "afc",
                    ],
                    "categorical_columns": [],
                    "column_mapping": {},
                },
                "methylation": {
                    "path": str(paths["methylation"]),
                    "id_column": "patient_id",
                    "feature_columns": ["meth_1", "meth_2"],
                },
                "immune": {
                    "path": str(paths["immune"]),
                    "id_column": "patient_id",
                    "feature_columns": ["immune_1", "immune_2"],
                },
                "mitochondrial": {
                    "path": str(paths["mitochondrial"]),
                    "id_column": "patient_id",
                    "feature_columns": ["mito_1"],
                },
            },
            "tasks": {
                "reproductive_age": "repro_age_target",
                "oocyte_yield": "oocyte_yield",
                "endometriosis_presence": "clinical_endometriosis",
                "endometriosis_stage": None,
            },
        },
        "model": {
            "encoder_hidden": [32, 16],
            "backbone_hidden": [32, 16],
            "dropout": 0.1,
        },
        "training": {
            "epochs": 1,
            "lr": 1e-3,
            "batch_size": 4,
            "device": "cpu",
            "patience": 2,
        },
    }

    schema = create_schema_from_config(config)
    preprocessor = DataPreprocessor(schema, Path(config["artifacts_dir"]) / "preprocess")

    task_columns = [col for col in config["data"]["tasks"].values() if col]
    modality_frames = {
        name: preprocessor.load_modality(name, Path(mod_cfg["path"]), extra_columns=task_columns)
        for name, mod_cfg in config["data"]["modalities"].items()
    }
    merged = preprocessor.merge_modalities(modality_frames)
    # Only the training split is exercised by this smoke test.
    train_df, _val_df, _test_df = preprocessor.split(
        merged, test_size=0.2, val_size=0.1, random_state=seed
    )
    preprocessor.fit(train_df)

    train_inputs = preprocessor.transform(train_df)
    # Same target extraction as the real training path.
    targets = prepare_targets(config, train_df)

    dataset = MultiModalDataset(train_inputs, targets)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=BatchCollator(), drop_last=True)

    input_dims = {name: arr.shape[1] for name, arr in train_inputs.items()}
    model = ImmunoFoundationModel(
        input_dims=input_dims,
        encoder_hidden=(16, 8),
        backbone_hidden=(16, 8),
        dropout=0.1,
        enable_stage=False,
    )
    trainer = Trainer(model, TrainingConfig(epochs=1, batch_size=2))

    # Shape sanity check on a two-sample batch; no gradients are needed for it.
    with torch.no_grad():
        initial_outputs = model(
            {
                k: torch.tensor(v[:2]).float()
                for k, v in train_inputs.items()
            }
        )
    expected_heads = {"reproductive_age", "oocyte_yield", "endometriosis_presence"}
    assert set(initial_outputs.keys()) >= expected_heads, "model is missing an expected output head"
    for tensor in initial_outputs.values():
        assert tensor.shape[0] == 2, "output batch dimension does not match the input batch"

    train_loss = trainer.train_epoch(loader)
    print(f"Synthetic demo training loss: {train_loss:.4f}")


# ---------------------------------------------------------------------------
# Optional CLI entry point (works in a .py script; harmless in notebook)
# ---------------------------------------------------------------------------


def main() -> None:
    """Parse command-line arguments and dispatch to the demo, training or inference.

    ``--synthetic_demo`` takes precedence over everything else. Otherwise
    ``--config`` is required, and inference mode additionally needs
    ``--input_dir`` and ``--checkpoint``.

    Raises:
        ValueError: If a required argument for the selected mode is missing.
    """
    cli = argparse.ArgumentParser(description="Train or run inference with the ImmunoFoundationModel")
    cli.add_argument("--config", type=Path, required=False, help="Path to YAML configuration file")
    cli.add_argument("--mode", choices=["train", "inference"], default="train")
    cli.add_argument("--input_dir", type=Path, help="Directory with modality files for inference")
    cli.add_argument("--checkpoint", type=Path, help="Model checkpoint for inference")
    cli.add_argument("--synthetic_demo", action="store_true", help="Run synthetic pipeline demo instead")
    options = cli.parse_args()

    # The demo needs no config, so it is handled before any validation.
    if options.synthetic_demo:
        run_synthetic_pipeline_demo()
        return

    if options.config is None:
        raise ValueError("Please provide --config or use --synthetic_demo")

    if options.mode == "train":
        run_training(options.config)
        return

    # Remaining choice is "inference".
    if options.input_dir is None or options.checkpoint is None:
        raise ValueError("Inference mode requires --input_dir and --checkpoint")
    run_inference(options.config, options.input_dir, options.checkpoint)


if __name__ == "__main__":
    # Notebook cells also execute with __name__ == "__main__". There, sys.argv
    # still holds the kernel launcher's own arguments (e.g. "-f <connection
    # file>"), which argparse rejects with SystemExit. Start the CLI only when
    # the process was not started by a Jupyter kernel launcher. Both
    # `python this_script.py ...` and `%run this_script.py ...` put the script
    # path in sys.argv[0], so they still reach main().
    import sys

    if not Path(sys.argv[0] if sys.argv else "").name.endswith("kernel_launcher.py"):
        main()
