In [None]:
from __future__ import annotations

from pathlib import Path
from typing import Tuple, Dict, Any, Callable

import numpy as np
import torch
from numpy.typing import ArrayLike

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    precision_recall_curve,
)

from pyod.models.ecod import ECOD

from usflows import Flow

# ---------------------------------------------------------------------------
# Global configuration
# ---------------------------------------------------------------------------

# Seed passed to train_test_split so the train/test partition is reproducible.
RANDOM_STATE: int = 42
# Default directory searched for the ADBench "Classical" .npz dataset files.
DATA_DIR: Path = Path("./nf4ad/data/adbench")  # relative to the current working directory


# ---------------------------------------------------------------------------
# Dataset resolution & loading (unchanged)
# ---------------------------------------------------------------------------

def resolve_npz_path(dataset_name: str, data_dir: Path = DATA_DIR) -> Path:
    """Map a loose dataset identifier to exactly one ``.npz`` file in ``data_dir``.

    Resolution order:
      1. ``data_dir / dataset_name`` if that is an existing file.
      2. ``data_dir / f"{dataset_name}.npz"`` (only tried when the name does
         not already end in ``.npz``).
      3. Fuzzy match over ``data_dir/*.npz``:
         - an all-digit name matches files whose name starts with
           ``"<digits>_"``;
         - any other name is compared case-insensitively against each file
           stem, and against the part of the stem after its first ``"_"``.

    Parameters
    ----------
    dataset_name : str
        Identifier to resolve; surrounding whitespace is ignored.
    data_dir : Path
        Directory containing the ``.npz`` files.

    Returns
    -------
    Path
        The single matching file.

    Raises
    ------
    FileNotFoundError
        If ``data_dir`` is not a directory, or nothing matches.
    RuntimeError
        If the fuzzy match finds more than one file.
    """
    name = dataset_name.strip()
    root = Path(data_dir)

    if not root.is_dir():
        raise FileNotFoundError(
            f"Data directory {root.resolve()} does not exist. "
            "Make sure you created it and copied the ADBench .npz files into it."
        )

    # Steps 1 and 2: literal file name, with or without the extension.
    literal = root / name
    if literal.is_file():
        return literal
    if not name.endswith(".npz"):
        literal_npz = root / f"{name}.npz"
        if literal_npz.is_file():
            return literal_npz

    # Step 3: fuzzy match against every .npz file in the directory.
    if name.isdigit():
        numeric_prefix = f"{name}_"
        matches = [
            path for path in root.glob("*.npz")
            if path.name.startswith(numeric_prefix)
        ]
    else:
        wanted = name.lower()

        def _stem_matches(path: Path) -> bool:
            # Accept either the full stem or the text after the first "_".
            stem = path.stem.lower()
            if stem == wanted:
                return True
            return "_" in stem and stem.split("_", 1)[1] == wanted

        matches = [path for path in root.glob("*.npz") if _stem_matches(path)]

    if not matches:
        raise FileNotFoundError(
            f"Could not match dataset name '{name}' to any .npz file in "
            f"{root.resolve()}."
        )

    if len(matches) > 1:
        listing = ", ".join(sorted(path.name for path in matches))
        raise RuntimeError(
            f"Dataset name '{name}' is ambiguous; it matches multiple files: "
            f"{listing}. Please specify a more precise name."
        )

    return matches[0]


def load_classical_dataset(
    dataset_name: str,
    data_dir: Path = DATA_DIR,
) -> Tuple[np.ndarray, np.ndarray]:
    """Load the ``X`` and ``y`` arrays of one dataset stored as ``.npz``.

    The file is located with ``resolve_npz_path``; a one-line summary
    (file name, shapes, mean of ``y``) is printed after loading.

    Parameters
    ----------
    dataset_name : str
        Identifier forwarded to ``resolve_npz_path``.
    data_dir : Path
        Directory containing the ``.npz`` files.

    Returns
    -------
    X : np.ndarray
        The array stored under key ``"X"``, returned as-is.
    y : np.ndarray
        The array stored under key ``"y"``, cast to ``int``.

    Raises
    ------
    FileNotFoundError, RuntimeError
        Propagated from ``resolve_npz_path``.
    KeyError
        If the archive has no ``"X"`` or ``"y"`` entry.
    """
    npz_path = resolve_npz_path(dataset_name, data_dir)

    # NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code. Only use this with trusted files.
    # The archive is opened in a ``with`` block so the underlying file handle
    # is closed once both arrays have been read.
    with np.load(npz_path, allow_pickle=True) as npz:
        X = npz["X"]
        y = npz["y"].astype(int)

    print(
        f"Loaded {npz_path.name}: X.shape={X.shape}, "
        f"y.shape={y.shape}, anomaly_ratio={y.mean():.4f}"
    )

    return X, y


# ---------------------------------------------------------------------------
# Metrics (unchanged)
# ---------------------------------------------------------------------------


def evaluate_anomaly_scores(
    y_true: ArrayLike,
    scores: ArrayLike,
) -> Dict[str, float]:
    """Compute ranking metrics for anomaly scores against binary labels.

    Both inputs are flattened to 1-D; labels are cast to ``int`` and scores
    to ``float``.

    Parameters
    ----------
    y_true : ArrayLike
        Ground-truth labels (0 = normal, 1 = anomalous).
    scores : ArrayLike
        Anomaly scores, one per label.

    Returns
    -------
    dict
        ``auc_roc``, ``auc_pr``, ``best_f1`` (maximum F1 along the
        precision-recall curve) and ``best_f1_threshold`` (the score
        threshold at that point, or NaN when the maximum falls on the final
        curve point, which has no associated threshold).

    Raises
    ------
    ValueError
        If the inputs differ in length or ``y_true`` holds fewer than two
        distinct labels.
    """
    labels = np.asarray(y_true).astype(int).ravel()
    values = np.asarray(scores, dtype=float).ravel()

    n_labels, n_values = labels.shape[0], values.shape[0]
    if n_labels != n_values:
        raise ValueError(
            f"y_true and scores must have the same length, "
            f"got {n_labels} and {n_values}."
        )

    if np.unique(labels).size < 2:
        raise ValueError(
            "y_true must contain both normal (0) and anomalous (1) labels."
        )

    auc_roc = float(roc_auc_score(labels, values))
    auc_pr = float(average_precision_score(labels, values))

    # F1 at every point of the PR curve; the epsilon avoids 0/0.
    prec, rec, cutoffs = precision_recall_curve(labels, values)
    f1_curve = 2 * prec * rec / (prec + rec + 1e-12)
    peak = int(np.argmax(f1_curve))

    # The curve has one more point than there are thresholds.
    if peak < cutoffs.size:
        peak_threshold = float(cutoffs[peak])
    else:
        peak_threshold = float("nan")

    return {
        "auc_roc": auc_roc,
        "auc_pr": auc_pr,
        "best_f1": float(f1_curve[peak]),
        "best_f1_threshold": peak_threshold,
    }


# ---------------------------------------------------------------------------
# Your method stub (unchanged)
# ---------------------------------------------------------------------------

def create_flow_prior(latent_dim, device: torch.device):
    """Build a ``NonUSFlow`` over a ``latent_dim``-dimensional space.

    The flow uses a standard Normal base distribution placed on ``device``,
    three coupling blocks, and a small one-hidden-layer MLP as conditioner.

    Parameters
    ----------
    latent_dim : int
        Dimensionality of the space the flow is defined on.
    device : torch.device
        Device for the base-distribution tensors; also passed to the flow.

    Returns
    -------
    NonUSFlow
        The constructed (untrained) flow.
    """
    from nf4ad.flows import NonUSFlow
    import pyro.distributions as dist
    import torch.nn as nn

    class SimpleConditioner(nn.Module):
        """Linear(in_dim, 128) -> ReLU -> Linear(128, out_dim)."""

        def __init__(self, in_dim, out_dim):
            super().__init__()
            layers = [
                nn.Linear(in_dim, 128),
                nn.ReLU(),
                nn.Linear(128, out_dim),
            ]
            self.net = nn.Sequential(*layers)

        def forward(self, x):
            return self.net(x)

    # Zero-mean, unit-scale Normal created directly on the target device.
    loc = torch.zeros(latent_dim, device=device)
    scale = torch.ones(latent_dim, device=device)
    base_dist = dist.Normal(loc, scale)

    conditioner_kwargs = {
        'in_dim': latent_dim,
        # Twice the latent size: per the original author, scale + shift for
        # affine coupling.
        'out_dim': latent_dim * 2,
    }

    return NonUSFlow(
        in_dims=[latent_dim],
        device=device,
        coupling_blocks=3,
        base_distribution=base_dist,
        prior_scale=1.0,
        affine_conjugation=True,
        conditioner_cls=SimpleConditioner,
        conditioner_args=conditioner_kwargs,
        nonlinearity=nn.ReLU(),
    )


def your_method_scores(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_test: np.ndarray,
) -> np.ndarray:
    """
    Score ``X_test`` with an ``ADBenchVAEFlowTabular`` model fit on the train split.

    The model is built with a flow prior from ``create_flow_prior`` whose
    dimensionality equals the number of input columns, fit on
    ``(X_train, y_train)``, and then asked for scores on ``X_test``.

    Parameters
    ----------
    X_train : np.ndarray of shape (n_train, n_features)
        Training features.
    y_train : np.ndarray
        Training labels, forwarded to ``fit``.
    X_test : np.ndarray
        Test features to score.

    Returns
    -------
    scores : np.ndarray of shape (n_test,)
        Anomaly scores for X_test, as returned by ``predict_score``
        (higher is expected to mean "more anomalous").
    """
    from nf4ad.adbench_wrapper import ADBenchVAEFlowTabular

    # Width of the tabular input; drives both the model input size and the
    # latent size.
    n_features = X_train.shape[1]

    # Use the GPU when present, otherwise fall back to CPU so the function
    # also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    vaeflow = ADBenchVAEFlowTabular(
        flow_prior=create_flow_prior(n_features, device=device),
        # Previously hard-coded to 5, which did not match the data width.
        # NOTE(review): assumes this argument is the input feature count --
        # confirm against ADBenchVAEFlowTabular.
        n_features=n_features,
        latent_dim=n_features,
    )
    vaeflow.fit(X_train, y_train)
    return vaeflow.predict_score(X_test)


# ---------------------------------------------------------------------------
# PyOD baseline (ECOD)
# ---------------------------------------------------------------------------

def run_pyod_ecod_baseline(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_test: np.ndarray,
) -> np.ndarray:
    """
    Unsupervised baseline: PyOD's ECOD detector.

    The detector is fit on ``X_train`` only; ``y_train`` is accepted so the
    signature matches other scoring functions but is not used.

    Returns
    -------
    np.ndarray
        Flat float array with the detector's decision-function value for
        each row of ``X_test`` (higher = more anomalous).
    """
    detector = ECOD()
    detector.fit(X_train)

    raw_scores = detector.decision_function(X_test)
    return np.asarray(raw_scores, dtype=float).ravel()


# ---------------------------------------------------------------------------
# High-level helper to run everything on one dataset (slightly tweaked)
# ---------------------------------------------------------------------------

def run_single_dataset_example(
    dataset_name: str = "cardio",
    data_dir: Path = DATA_DIR,
    use_baseline: bool = True,
    method_fn: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray] | None = None,
) -> Dict[str, Any]:
    """
    Run one scoring method end-to-end on one dataset and report metrics.

    Steps:
      1. Load ``X, y`` via ``load_classical_dataset``.
      2. Stratified 50/50 train/test split seeded with ``RANDOM_STATE``.
      3. Standardize features (scaler fit on the train half only).
      4. Score the test half with ECOD (``use_baseline=True``) or with
         ``method_fn(X_train, y_train, X_test)``.
      5. Evaluate with ``evaluate_anomaly_scores`` and print a summary.

    Returns
    -------
    dict
        Keys ``dataset``, ``n_train``, ``n_test``, ``model`` and ``metrics``.

    Raises
    ------
    ValueError
        If ``use_baseline`` is False and ``method_fn`` is None.
    """
    features, labels = load_classical_dataset(dataset_name, data_dir=data_dir)

    X_train, X_test, y_train, y_test = train_test_split(
        features,
        labels,
        test_size=0.5,
        random_state=RANDOM_STATE,
        stratify=labels,
    )

    # Fit the scaler on the train half only, then apply it to both halves.
    scaler = StandardScaler()
    train_std = scaler.fit_transform(X_train)
    test_std = scaler.transform(X_test)

    # Choose the scoring function and the label it is reported under.
    if use_baseline:
        scorer = run_pyod_ecod_baseline
        model_used = "PyOD-ECOD"
    elif method_fn is None:
        raise ValueError(
            "use_baseline=False but no method_fn was provided. "
            "Pass your own method wrapper, e.g. method_fn=your_method_scores."
        )
    else:
        scorer = method_fn
        model_used = getattr(method_fn, "__name__", "custom_method")

    scores = scorer(train_std, y_train, test_std)
    metrics = evaluate_anomaly_scores(y_test, scores)

    print(
        f"\nResults on dataset='{dataset_name}' using model='{model_used}':\n"
        f"  AUC-ROC : {metrics['auc_roc']:.4f}\n"
        f"  AUC-PR  : {metrics['auc_pr']:.4f}\n"
        f"  best F1 : {metrics['best_f1']:.4f} "
        f"(at score threshold ≈ {metrics['best_f1_threshold']:.4f})"
    )

    return {
        "dataset": dataset_name,
        "n_train": int(X_train.shape[0]),
        "n_test": int(X_test.shape[0]),
        "model": model_used,
        "metrics": metrics,
    }


# ---------------------------------------------------------------------------
# Example call
# ---------------------------------------------------------------------------

# Run the custom flow-based method (not the ECOD baseline) on the cardio data.
example_result = run_single_dataset_example(
    dataset_name="cardio",   # or "6", "6_cardio", "6_cardio.npz"
    use_baseline=False,       # skip the ECOD baseline; score with method_fn instead
    method_fn=your_method_scores,
)

# Bare expression so the notebook displays the result dict.
example_result


Loaded 6_cardio.npz: X.shape=(1831, 21), y.shape=(1831,), anomaly_ratio=0.0961
Training VAEFlow on 915 samples...
Input shape: torch.Size([915, 1, 3, 3])
Device: cuda
Training completed. Final loss: 200566927941790381113344.0000

Results on dataset='cardio' using model='your_method_scores':
  AUC-ROC : 0.5125
  AUC-PR  : 0.1025
  best F1 : 0.1820 (at score threshold ≈ 209428289828075601920.0000)


{'dataset': 'cardio',
 'n_train': 915,
 'n_test': 916,
 'model': 'your_method_scores',
 'metrics': {'auc_roc': 0.5125027448397014,
  'auc_pr': 0.10251017910372012,
  'best_f1': 0.18201754385947477,
  'best_f1_threshold': 2.094282898280756e+20}}