In [None]:
# --- Per-channel AnomalyTransformer forecasting anomaly detection (unsupervised) ---
# Goal: train ONE model per chan_id on train/ .npy (normal-only),
#       predict future column 0 using history of columns 1..F-1,
#       then detect anomalies in test via high prediction error + association discrepancy.

import os
import ast
import json
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from typing import Tuple, Optional

# -----------------------------
# Config (mirrors your notebook exactly)
# -----------------------------
# Dataset layout: label CSV plus one .npy file per channel in train/ and test/.
INPUT_DIR = Path("DataSet")
LABEL_FILE = INPUT_DIR / "labeled_anomalies.csv"
TRAIN_DIR = INPUT_DIR / "data/data/train"
TEST_DIR = INPUT_DIR / "data/data/test"

# Output locations for per-channel checkpoints and the metrics CSV.
OUT_MODELS_DIR = Path("models") / "anomaly_transformer"
OUT_RESULTS_DIR = Path("results") / "anomaly_transformer"
for _out_dir in (OUT_MODELS_DIR, OUT_RESULTS_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Seed NumPy and PyTorch so runs are repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Forecasting window params (Telemanom-style)
L_S = 100               # number of history rows fed to the model
N_PRED = 1              # number of future target values predicted per window
STRIDE = 1              # distance between consecutive window starts

# Training params (applied to every channel)
BATCH_SIZE = 256
EPOCHS = 50
LR = 1e-4
DROPOUT = 0.2
VAL_FRAC = 0.2          # trailing fraction of train windows held out for validation

# Anomaly threshold = this quantile of the validation score distribution
THRESH_Q = 0.995        # 99.5th percentile

# Scaling mode: "none" (use the data as stored) or "standard" (StandardScaler fit on train)
SCALE_MODE = "none"

# For quick testing: set to an int (e.g. 5) to cap the channel count. None runs all channels.
MAX_CHANNELS = None

print(f"DEVICE: {DEVICE}")
print(f"Window params: {{'L_S': {L_S}, 'N_PRED': {N_PRED}, 'STRIDE': {STRIDE}}}")
print(f"Train params: {{'BATCH_SIZE': {BATCH_SIZE}, 'EPOCHS': {EPOCHS}, 'LR': {LR}, 'VAL_FRAC': {VAL_FRAC}}}")


# -----------------------------
# Data utilities (identical to your notebook)
# -----------------------------
def load_chan_train_test(chan_id: str):
    """Read one channel's train and test matrices from disk.

    Returns:
        ``(x_train, x_test, scaler)``. Both arrays are float32. ``scaler`` is
        the fitted ``StandardScaler`` when ``SCALE_MODE == "standard"`` and
        ``None`` otherwise. ``(None, None, None)`` is returned when either
        ``.npy`` file is missing, either array is not 2-D, or the two arrays
        have a different number of columns.
    """
    paths = [base / f"{chan_id}.npy" for base in (TRAIN_DIR, TEST_DIR)]
    if not all(p.exists() for p in paths):
        return None, None, None

    x_train, x_test = (np.load(p).astype(np.float32) for p in paths)

    both_2d = x_train.ndim == 2 and x_test.ndim == 2
    if not both_2d or x_train.shape[1] != x_test.shape[1]:
        return None, None, None

    if SCALE_MODE != "standard":
        return x_train, x_test, None

    # Fit the scaler on train only and reuse its statistics for test.
    scaler = StandardScaler()
    x_train = scaler.fit_transform(x_train).astype(np.float32)
    x_test = scaler.transform(x_test).astype(np.float32)
    return x_train, x_test, scaler


def build_forecast_windows(x: np.ndarray, l_s: int, n_pred: int = 1, stride: int = 1, use_other_features_only: bool = True):
    """Slice a (time, feature) matrix into Telemanom-style forecasting samples.

    Each sample pairs ``l_s`` consecutive input rows with the ``n_pred``
    values of column 0 that immediately follow them.

    Args:
        x: 2-D array of shape (T, F).
        l_s: Number of history rows per window.
        n_pred: Number of future column-0 values used as the target.
        stride: Step between consecutive window start rows.
        use_other_features_only: When True the inputs are columns 1..F-1
            (column 0 is excluded); when False all columns are used.

    Returns:
        ``(X, y, t0)`` where ``X`` is float32 of shape
        (n_windows, l_s, n_inputs), ``y`` is float32 of shape
        (n_windows, n_pred) and ``t0 == l_s`` is the row index of the first
        target. Window starts are ``range(0, T - l_s - n_pred, stride)``, so
        the start at ``T - l_s - n_pred`` itself is not generated.
        ``(None, None, None)`` is returned when ``T - l_s - n_pred <= 0`` or
        when column 0 is excluded and ``F < 2``.
    """
    assert x.ndim == 2
    n_steps, n_feats = x.shape

    if not use_other_features_only:
        inputs = x
    elif n_feats >= 2:
        inputs = x[:, 1:]
    else:
        return None, None, None

    n_starts = n_steps - l_s - n_pred
    if n_starts <= 0:
        return None, None, None

    starts = range(0, n_starts, stride)
    X = np.stack([inputs[s : s + l_s] for s in starts]).astype(np.float32)
    y = np.stack([x[s + l_s : s + l_s + n_pred, 0] for s in starts]).astype(np.float32)
    return X, y, l_s


def anomaly_vector_from_sequences(T: int, anomaly_sequences) -> np.ndarray:
    """Expand ``(start, end)`` pairs into a length-``T`` int64 0/1 vector.

    Each pair marks the half-open slice ``[start, end)`` after clamping it to
    ``[0, T]``. Pairs that are empty after clamping are ignored.
    """
    labels = np.zeros(T, dtype=np.int64)
    for seq in anomaly_sequences:
        lo, hi = seq
        lo, hi = max(0, int(lo)), min(T, int(hi))
        if hi > lo:
            labels[lo:hi] = 1
    return labels


def contiguous_sequences_from_flags(flags: np.ndarray, offset: int = 0):
    """Return the runs of truthy flags as ``[start, end)`` index pairs.

    ``offset`` is added to both bounds of every pair. The result is a list of
    two-element lists, in order of appearance.
    """
    mask = np.asarray(flags).astype(bool)
    # Pad with False on both sides so runs touching either end still produce
    # a rising edge and a falling edge.
    padded = np.concatenate([[False], mask, [False]]).astype(np.int8)
    edges = np.diff(padded)
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1)
    return [[int(s) + offset, int(e) + offset] for s, e in zip(run_starts, run_ends)]


def precision_recall_f1(y_true: np.ndarray, y_pred: np.ndarray):
    """Compute point-wise TP/FP/FN counts and precision, recall and F1.

    Both inputs are cast to int64 and compared against 0 and 1. Any ratio
    whose denominator is zero is reported as 0.0.

    Returns:
        dict with keys ``tp``, ``fp``, ``fn``, ``precision``, ``recall``, ``f1``.
    """
    truth = y_true.astype(np.int64)
    guess = y_pred.astype(np.int64)

    flagged = guess == 1
    tp = int(np.count_nonzero(flagged & (truth == 1)))
    fp = int(np.count_nonzero(flagged & (truth == 0)))
    fn = int(np.count_nonzero((guess == 0) & (truth == 1)))

    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    if precision + recall > 0:
        f1 = (2 * precision * recall) / (precision + recall)
    else:
        f1 = 0.0

    return {"tp": tp, "fp": fp, "fn": fn, "precision": precision, "recall": recall, "f1": f1}


def point_adjust_predictions(y_true_full: np.ndarray, y_pred_full: np.ndarray) -> np.ndarray:
    """Apply point adjustment to a 0/1 prediction vector.

    For every contiguous ground-truth anomaly segment that contains at least
    one predicted 1, the whole segment is marked as predicted. Predictions
    outside ground-truth segments are left as they are. A new int64 array is
    returned and the inputs are not modified.
    """
    truth = y_true_full.astype(np.int64)
    pred = y_pred_full.astype(np.int64)
    adjusted = pred.copy()

    n = len(truth)
    seg_open = None  # start index of the segment being scanned, if any
    # Iterate one step past the end with a sentinel 0 so that a segment
    # touching the last sample is closed as well.
    for pos in range(n + 1):
        label = 0 if pos >= n else truth[pos]
        if seg_open is None:
            if label == 1:
                seg_open = pos
            continue
        if label != 0:
            continue
        if (pred[seg_open:pos] == 1).any():
            adjusted[seg_open:pos] = 1
        seg_open = None

    return adjusted


def precision_recall_f1_point_adjusted(y_true_full: np.ndarray, y_pred_full: np.ndarray):
    """Score predictions after point adjustment.

    The predictions are passed through ``point_adjust_predictions`` and the
    result is scored against ``y_true_full`` with ``precision_recall_f1``.
    """
    return precision_recall_f1(
        y_true_full,
        point_adjust_predictions(y_true_full, y_pred_full),
    )


# -----------------------------
# AnomalyTransformer Core (ICLR 2022) - FIXED DIMENSION HANDLING
# -----------------------------
class AnomalyAttention(nn.Module):
    """Multi-head self-attention that also reports an association discrepancy.

    Two attention maps are formed over the same window:

    * the *series association*: the usual softmax(QK^T / sqrt(head_dim));
    * the *prior association*: a row-normalised Gaussian kernel over the
      distance ``|i - j|`` between positions, with one width parameter
      ``u`` per head.

    The discrepancy is the absolute difference between the two maps. Only the
    series association is used to mix the values; ``u`` influences the
    returned discrepancy only, so it receives gradient only if the caller
    includes the discrepancy in its loss.

    Args:
        d_model: Embedding width; must be a positive multiple of ``n_heads``.
        n_heads: Number of attention heads.

    Raises:
        ValueError: If ``d_model`` is not a positive multiple of ``n_heads``.
    """
    def __init__(self, d_model: int, n_heads: int = 8):
        super().__init__()
        # Fail early with a clear message instead of an opaque .view() shape
        # error on the first forward pass.
        if n_heads <= 0 or d_model <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be a positive multiple of n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # Q, K, V projections
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

        # Per-head width of the Gaussian prior (small init)
        self.u = nn.Parameter(torch.randn(n_heads, 1, 1) * 0.1)

        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention to ``x`` of shape [B, L, D].

        Returns:
            out: [B, L, D] attended and projected values.
            discrepancy: [B, L] mean of |series - prior| over heads and query
                positions, i.e. one value per key position.
        """
        B, L, D = x.shape

        # Project and split into heads: [B, H, L, head_dim]
        Q = self.q_proj(x).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)

        # Series-wise association (scaled dot-product attention): [B, H, L, L]
        attn_series = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.head_dim)
        attn_series = F.softmax(attn_series, dim=-1)

        # Prior association. It does not depend on the batch, so it is built
        # once as [H, L, L] and broadcast against attn_series rather than
        # being materialised for every batch element.
        positions = torch.arange(L, device=x.device).float()
        distance = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))  # [L, L]
        attn_prior = torch.exp(-distance ** 2 / (2 * (self.u ** 2 + 1e-6)))    # [H, L, L]
        attn_prior = attn_prior / (attn_prior.sum(dim=-1, keepdim=True) + 1e-6)

        # Association discrepancy, reduced over heads (dim 1) and queries (dim 2)
        discrepancy = torch.abs(attn_series - attn_prior).mean(dim=(1, 2))

        # Weighted value aggregation, then merge heads back to [B, L, D]
        out = torch.matmul(attn_series, V).transpose(1, 2).contiguous().view(B, L, D)
        out = self.out_proj(out)

        return out, discrepancy


class AnomalyTransformerBlock(nn.Module):
    """One encoder layer: anomaly attention followed by a feed-forward net.

    Both sub-layers use a residual connection with LayerNorm applied after
    the addition. The feed-forward net expands to ``4 * d_model`` with GELU
    and applies dropout after each linear stage.
    """
    def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        hidden = d_model * 4
        self.attention = AnomalyAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the transformed sequence and the attention's discrepancy."""
        attended, discrepancy = self.attention(x)
        normed = self.norm1(attended + x)
        return self.norm2(normed + self.ffn(normed)), discrepancy


class AnomalyTransformerForecaster(nn.Module):
    """Forecasting variant: predicts the next value(s) of the target channel.

    The input window is projected to ``d_model``, passed through a stack of
    ``AnomalyTransformerBlock`` layers, mean-pooled over time and mapped to
    ``n_pred`` outputs by a linear head.

    Args:
        input_dim: Number of input features per timestep (columns 1..F-1).
        d_model: Internal embedding width.
        n_layers: Number of transformer blocks (at least 1).
        n_heads: Attention heads per block.
        window_size: History length; stored for checkpoint metadata only.
        n_pred: Forecast horizon.
        dropout: Dropout probability passed to each block.

    Raises:
        ValueError: If ``n_layers`` is smaller than 1.
    """
    def __init__(
        self,
        input_dim: int,      # Features from OTHER channels (columns 1..F-1)
        d_model: int = 128,
        n_layers: int = 3,
        n_heads: int = 8,
        window_size: int = 100,
        n_pred: int = 1,
        dropout: float = 0.1
    ):
        super().__init__()
        if n_layers < 1:
            # forward() averages the discrepancy over the layers, so an empty
            # stack would divide by zero.
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")

        # Kept as plain attributes so checkpoint/metadata writers that look
        # them up with getattr() record real values instead of None.
        self.input_dim = input_dim
        self.window_size = window_size
        self.n_pred = n_pred

        # Input projection (handles variable input dimensions)
        self.input_proj = nn.Linear(input_dim, d_model)

        # Transformer encoder stack
        self.layers = nn.ModuleList([
            AnomalyTransformerBlock(d_model, n_heads, dropout)
            for _ in range(n_layers)
        ])

        # Forecasting head (predicts next n_pred values of target channel)
        self.forecast_head = nn.Linear(d_model, n_pred)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [batch, window_size, input_dim] - history of OTHER channels
        Returns:
            forecast: [batch, n_pred] - predicted values for target channel
            discrepancy: [batch, window_size] - layer-averaged association
                discrepancy as returned by the blocks
        """
        # Project to model dimension
        hidden = self.input_proj(x)

        # Pass through transformer layers (accumulate discrepancies)
        total_discrepancy = 0
        for layer in self.layers:
            hidden, disc = layer(hidden)
            total_discrepancy = total_discrepancy + disc

        # Average discrepancies across layers
        avg_discrepancy = total_discrepancy / len(self.layers)

        # The context vector is the MEAN over the time axis, not the last
        # timestep's representation.
        # NOTE(review): no positional encoding is added in this class and the
        # pooling is a mean, so unless the layers inject position information
        # the forecast cannot depend on timestep order - confirm intended.
        context = hidden.mean(dim=1)            # [B, D]
        forecast = self.forecast_head(context)  # [B, n_pred]

        return forecast, avg_discrepancy


# -----------------------------
# Dataset & Training Utilities (IDENTICAL structure to your CNN)
# -----------------------------
class ForecastWindowDataset(Dataset):
    """In-memory dataset of (input window, target) pairs.

    Both arrays are copied into float32 tensors at construction. Item ``i``
    is the tuple ``(X[i], y[i])``.
    """

    def __init__(self, X: np.ndarray, y: np.ndarray):
        as_f32 = dict(dtype=torch.float32)
        self.X = torch.tensor(X, **as_f32)
        self.y = torch.tensor(y, **as_f32)

    def __len__(self):
        n_windows = len(self.X)
        return n_windows

    def __getitem__(self, idx):
        window = self.X[idx]
        target = self.y[idx]
        return window, target


def split_train_val_sequential(X: np.ndarray, y: np.ndarray, val_frac: float = 0.2):
    """Split windows chronologically into a leading train part and a trailing validation part.

    The validation part takes the last ``int(len(X) * val_frac)`` samples,
    but never fewer than one. Returns ``(X_tr, y_tr, X_va, y_va)``, or
    ``None`` when no sample would remain for training.
    """
    total = len(X)
    n_holdout = max(1, int(total * val_frac))
    cut = total - n_holdout
    if cut <= 0:
        return None
    return X[:cut], y[:cut], X[cut:], y[cut:]


def train_transformer(model, train_loader, val_loader, epochs: int, lr: float, device: str, chan_id: str):
    """Train ``model`` with MSE loss and return it with the best-validation weights loaded.

    The model is expected to return ``(prediction, extra)`` from its forward
    pass; only the prediction enters the loss. Optimisation uses AdamW with
    cosine learning-rate annealing over ``epochs`` and gradient-norm clipping
    at 1.0. After every epoch the validation MSE is measured, and the weights
    of the epoch with the lowest value are restored before returning.

    Args:
        model: Module to train (moved to ``device`` in place).
        train_loader: Loader yielding ``(inputs, targets)`` batches for training.
        val_loader: Loader yielding ``(inputs, targets)`` batches for validation.
        epochs: Number of passes over ``train_loader``.
        lr: Initial learning rate.
        device: Torch device string.
        chan_id: Channel identifier (accepted for API compatibility; unused here).

    Returns:
        The same ``model`` instance.
    """
    model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_val_loss = float('inf')
    best_state = None
    log_every = max(1, epochs // 5)

    for epoch in range(1, epochs + 1):
        model.train()
        tr_loss = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            pred, _ = model(xb)
            loss = criterion(pred, yb)  # pred and yb are both [B, n_pred]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            # Weight by batch size so the epoch value is a true per-sample mean
            tr_loss += loss.item() * xb.size(0)
        tr_loss /= len(train_loader.dataset)

        model.eval()
        va_loss = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                pred, _ = model(xb)
                loss = criterion(pred, yb)
                va_loss += loss.item() * xb.size(0)
        va_loss /= len(val_loader.dataset)

        scheduler.step()

        if va_loss < best_val_loss:
            best_val_loss = va_loss
            # .cpu() returns the SAME tensor when the model already lives on
            # the CPU, so without clone() the snapshot would alias the live
            # parameters and be overwritten by later optimizer steps.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

        if epoch == 1 or epoch % log_every == 0 or epoch == epochs:
            print(f"  epoch {epoch: >3}/{epochs}  train_mse={tr_loss:.6f}  val_mse={va_loss:.6f}")

    if best_state is not None:
        model.load_state_dict(best_state)
    return model


@torch.no_grad()
def predict_anomaly_scores(model, loader, device: str, n_pred: int = 1):
    """Score every sample in ``loader`` with a combined anomaly score.

    For each sample the squared forecast error is averaged over the
    prediction horizon and the model's discrepancy output is averaged over
    dim 1. The returned score is ``sqrt(error * discrepancy + 1e-8)``.

    Args:
        model: Module returning ``(prediction, discrepancy)``; put in eval mode.
        loader: Iterable of ``(inputs, targets)`` batches.
        device: Torch device string the batches are moved to.
        n_pred: Accepted for API compatibility; not used by the computation.

    Returns:
        1-D NumPy array with one score per sample, in loader order.
    """
    model.eval()
    per_batch_err = []
    per_batch_disc = []

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        forecast, disc = model(inputs)

        # [B, n_pred] -> [B]: mean squared error over the horizon
        sq_err = (forecast - targets).pow(2)
        per_batch_err.append(sq_err.mean(dim=1).cpu().numpy())
        # [B, L] -> [B]: mean discrepancy over the window
        per_batch_disc.append(disc.mean(dim=1).cpu().numpy())

    mse = np.concatenate(per_batch_err)
    assoc = np.concatenate(per_batch_disc)

    # Geometric-mean style combination; the epsilon keeps the sqrt argument positive
    return np.sqrt(mse * assoc + 1e-8)


# -----------------------------
# Per-Channel Training Pipeline (IDENTICAL to your CNN structure)
# -----------------------------


def save_channel_model(chan_id: str, model, threshold: float, config: dict, out_dir: Path = Path("streaming_models")):
    """Save a channel's model weights and a JSON metadata file.

    Two files are written into ``out_dir`` (created if missing):
    ``{chan_id}.pt`` with the state dict and basic shape information, and
    ``{chan_id}_metadata.json`` with the threshold and training config.

    Args:
        chan_id: Channel identifier used in the file names.
        model: Trained ``nn.Module``.
        threshold: Anomaly threshold stored in the metadata.
        config: Dict with at least the keys ``'L_S'``, ``'N_PRED'`` and
            ``'STRIDE'``; it must be JSON-serialisable.
        out_dir: Destination directory.

    Returns:
        ``(model_path, meta_path)``.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # Resolve the input feature count. Models that do not expose
    # `input_dim` / `features_in` but do have an `input_proj` linear layer
    # would otherwise be recorded as None.
    input_dim = getattr(model, 'input_dim', getattr(model, 'features_in', None))
    if input_dim is None:
        input_dim = getattr(getattr(model, 'input_proj', None), 'in_features', None)

    model_type = model.__class__.__name__

    # Save PyTorch model
    model_path = out_dir / f"{chan_id}.pt"
    torch.save({
        'model_state_dict': model.state_dict(),
        'model_type': model_type,
        'input_dim': input_dim,
        'window_size': getattr(model, 'window_size', config['L_S']),
        'n_pred': getattr(model, 'n_pred', config['N_PRED']),
    }, model_path)

    # Save metadata (threshold + config)
    meta_path = out_dir / f"{chan_id}_metadata.json"
    metadata = {
        'chan_id': chan_id,
        'threshold': float(threshold),
        'window_size': config['L_S'],
        'n_pred': config['N_PRED'],
        'stride': config['STRIDE'],
        'model_type': model_type,
        'input_features': input_dim,
        'training_config': config
    }
    # Explicit encoding so the file does not depend on the platform default
    with open(meta_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print(f"✓ Saved {chan_id} model to {model_path} + metadata")
    return model_path, meta_path

def get_anomaly_sequences_for_chan(df: pd.DataFrame, chan_id: str):
    """Return the labelled anomaly sequences for ``chan_id``.

    The first row of ``df`` whose ``chan_id`` column matches is used. Its
    ``anomaly_sequences`` cell is normally a string such as
    ``"[[2149, 2349], [4536, 4844]]"`` and is parsed with
    ``ast.literal_eval``. A cell that already holds a list or tuple is
    returned as a list without parsing.

    Returns:
        The parsed sequences, or ``[]`` when the channel is absent or the
        cell cannot be parsed (malformed text, NaN, ...).
    """
    matches = df[df["chan_id"] == chan_id]
    if matches.empty:
        return []

    raw = matches.iloc[0]["anomaly_sequences"]
    if isinstance(raw, (list, tuple)):
        # Already parsed (e.g. a DataFrame built in memory): literal_eval
        # would reject a non-string and the labels would be silently lost.
        return list(raw)

    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # These are the documented failure modes of literal_eval; treat any
        # of them as "no usable labels" and let everything else propagate.
        return []


def run_one_channel(chan_id: str, verbose: bool = True, labels: Optional[pd.DataFrame] = None):
    """Train, threshold, evaluate and checkpoint the model for one channel.

    Steps: load the channel's train/test arrays, build forecasting windows,
    split the train windows chronologically into train/validation, train an
    ``AnomalyTransformerForecaster``, set the anomaly threshold to the
    ``THRESH_Q`` quantile of the validation scores, flag test windows whose
    score exceeds it, compute point-adjusted precision/recall/F1 against the
    labelled sequences, and save a checkpoint under ``OUT_MODELS_DIR``.

    Args:
        chan_id: Channel identifier (file stem of the ``.npy`` files).
        verbose: Print progress and skip reasons.
        labels: Label table with ``chan_id`` and ``anomaly_sequences``
            columns. When omitted, the module-level ``df_labels`` is used.

    Returns:
        A dict of per-channel sizes, threshold, metrics and sequences, or
        ``None`` when the channel cannot be processed (missing/invalid data
        or too few rows to build windows).
    """
    x_train, x_test, scaler = load_chan_train_test(chan_id)
    if x_train is None:
        if verbose:
            print(f"[{chan_id}] missing or invalid train/test")
        return None

    # Build windows (forecasting setup: predict col 0 from cols 1..F-1)
    X_tr_all, y_tr_all, t0_tr = build_forecast_windows(
        x_train, l_s=L_S, n_pred=N_PRED, stride=STRIDE, use_other_features_only=True
    )
    X_te, y_te, t0_te = build_forecast_windows(
        x_test, l_s=L_S, n_pred=N_PRED, stride=STRIDE, use_other_features_only=True
    )
    if X_tr_all is None or X_te is None:
        if verbose:
            print(f"[{chan_id}] not enough length for windows")
        return None

    split = split_train_val_sequential(X_tr_all, y_tr_all, val_frac=VAL_FRAC)
    if split is None:
        if verbose:
            print(f"[{chan_id}] not enough train windows to split")
        return None

    X_tr, y_tr, X_va, y_va = split

    train_ds = ForecastWindowDataset(X_tr, y_tr)
    val_ds = ForecastWindowDataset(X_va, y_va)
    test_ds = ForecastWindowDataset(X_te, y_te)

    # Only the training loader is shuffled; val/test keep chronological order
    # so that score index i corresponds to window i.
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

    # Initialize AnomalyTransformer (input_dim = F-1 other channels)
    model = AnomalyTransformerForecaster(
        input_dim=X_tr.shape[2],
        d_model=128,
        n_layers=3,
        n_heads=8,
        window_size=L_S,
        n_pred=N_PRED,
        dropout=DROPOUT
    )

    if verbose:
        print(f"\n[{chan_id}] train_windows={len(train_ds)} val_windows={len(val_ds)} "
              f"test_windows={len(test_ds)} F_in={X_tr.shape[2]}")

    # Train model
    model = train_transformer(
        model, train_loader, val_loader,
        epochs=EPOCHS, lr=LR, device=DEVICE, chan_id=chan_id
    )

    # Threshold = THRESH_Q quantile of the validation score distribution
    val_scores = predict_anomaly_scores(model, val_loader, device=DEVICE, n_pred=N_PRED)
    thr = float(np.quantile(val_scores, THRESH_Q))

    # Compute scores on test set
    test_scores = predict_anomaly_scores(model, test_loader, device=DEVICE, n_pred=N_PRED)
    pred_flags = (test_scores > thr).astype(np.int64)

    # Align with ground truth
    label_df = df_labels if labels is None else labels
    anomaly_seqs = get_anomaly_sequences_for_chan(label_df, chan_id)
    gt_full = anomaly_vector_from_sequences(T=x_test.shape[0], anomaly_sequences=anomaly_seqs)

    # Map window predictions to original time indices: window i scores the
    # row at t0_te + i * STRIDE.
    pred_indices = t0_te + np.arange(len(pred_flags)) * STRIDE
    pred_indices = pred_indices.astype(np.int64)
    valid = pred_indices < len(gt_full)
    pred_indices = pred_indices[valid]
    y_pred_pts = pred_flags[valid]
    y_true_pts = gt_full[pred_indices]

    # Point-adjusted metrics, computed on the scored rows only (rows before
    # t0_te have no prediction and are not part of the evaluation).
    metrics = precision_recall_f1_point_adjusted(y_true_pts, y_pred_pts)

    # Runs are located in window-index space and then mapped through
    # pred_indices, so the reported [start, end) row bounds stay correct when
    # STRIDE > 1 (for STRIDE == 1 this equals "window index + t0_te").
    pred_sequences = [
        [int(pred_indices[s]), int(pred_indices[e - 1]) + 1]
        for s, e in contiguous_sequences_from_flags(y_pred_pts.astype(bool))
    ]

    # Save model
    save_path = OUT_MODELS_DIR / f"{chan_id}.pt"
    torch.save(
        {
            "chan_id": chan_id,
            "state_dict": {k: v.cpu() for k, v in model.state_dict().items()},
            "input_dim": int(X_tr.shape[2]),
            "l_s": int(L_S),
            "n_pred": int(N_PRED),
            "stride": int(STRIDE),
            "thr_q": float(THRESH_Q),
            "threshold": float(thr),
            "scaler_mean": None if scaler is None else scaler.mean_.astype(np.float32),
            "scaler_scale": None if scaler is None else scaler.scale_.astype(np.float32),
        },
        save_path,
    )

    out = {
        "chan_id": chan_id,
        "T_train": int(x_train.shape[0]),
        "T_test": int(x_test.shape[0]),
        "F": int(x_train.shape[1]),
        "F_in": int(X_tr.shape[2]),
        "n_train_windows": int(len(train_ds)),
        "n_val_windows": int(len(val_ds)),
        "n_test_windows": int(len(test_ds)),
        "threshold": float(thr),
        "tp": metrics["tp"],
        "fp": metrics["fp"],
        "fn": metrics["fn"],
        "precision": metrics["precision"],
        "recall": metrics["recall"],
        "f1": metrics["f1"],
        "pred_sequences": pred_sequences,
        "true_sequences": anomaly_seqs,
        "model_path": str(save_path),
    }

    if verbose:
        print(f"[{chan_id}] thr={thr:.6f}  P={metrics['precision']:.3f} R={metrics['recall']:.3f} F1={metrics['f1']:.3f}")
    return out


# -----------------------------
# Main Execution
# -----------------------------
if __name__ == "__main__":
    # Script entry point: run every labelled channel and write one metrics CSV.
    # `df_labels` is deliberately a module-level name because
    # run_one_channel() reads it as a global.
    df_labels = pd.read_csv(LABEL_FILE)
    print(f"Loaded labels: {df_labels.shape}")
    print(df_labels.head(3))

    results = []
    chan_ids = df_labels["chan_id"].tolist()

    if MAX_CHANNELS is not None:
        chan_ids = chan_ids[:int(MAX_CHANNELS)]

    print(f"\nTotal channels to run: {len(chan_ids)}\n")

    for cid in chan_ids:
        r = run_one_channel(cid, verbose=True)
        if r is not None:
            results.append(r)

    res_df = pd.DataFrame(results)
    res_path = OUT_RESULTS_DIR / "per_channel_metrics.csv"
    res_df.to_csv(res_path, index=False)

    print(f"\nSaved metrics: {res_path}")

    if len(results) > 0:
        # Column selection is only valid when at least one channel produced a
        # row; an empty DataFrame has no columns and would raise KeyError.
        print(res_df[["chan_id", "precision", "recall", "f1", "threshold"]].head(10))

        # Summary statistics (unweighted means over channels)
        print("\n=== SUMMARY STATISTICS ===")
        print(f"Mean Precision: {res_df['precision'].mean():.4f}")
        print(f"Mean Recall:    {res_df['recall'].mean():.4f}")
        print(f"Mean F1:        {res_df['f1'].mean():.4f}")
        print(f"Channels with F1 > 0.5: {(res_df['f1'] > 0.5).sum()} / {len(res_df)}")
    else:
        print("\nNo channel produced results; nothing to summarise.")

DEVICE: cuda
Window params: {'L_S': 100, 'N_PRED': 1, 'STRIDE': 1}
Train params: {'BATCH_SIZE': 256, 'EPOCHS': 50, 'LR': 0.0001, 'VAL_FRAC': 0.2}
Loaded labels: (82, 5)
  chan_id spacecraft                           anomaly_sequences  \
0     P-1       SMAP  [[2149, 2349], [4536, 4844], [3539, 3779]]   
1     S-1       SMAP                              [[5300, 5747]]   
2     E-1       SMAP                [[5000, 5030], [5610, 6086]]   

                                  class  num_values  
0  [contextual, contextual, contextual]        8505  
1                               [point]        7331  
2              [contextual, contextual]        8516  

Total channels to run: 82


[P-1] train_windows=2217 val_windows=554 test_windows=8404 F_in=24
  epoch   1/50  train_mse=0.451543  val_mse=0.193827
  epoch  10/50  train_mse=0.157447  val_mse=0.191873
  epoch  20/50  train_mse=0.154448  val_mse=0.190271
  epoch  30/50  train_mse=0.149726  val_mse=0.186310
  epoch  40/50  train_mse=0.145591