In [None]:
"""
Simple Jupyter-friendly Sim→Real counter predictor (minimal, opinionated)
---------------------------------------------------------------------

Objetivo
========
Entrenar con (waveform_muon, gain0_sim) -> counts (clase 0..20).
En test reemplazar gain0_sim por combined_real_gain0 y predecir los counts esperados.

Principios de diseño (simplificados a pedido):
- Espectros en dB, normalizados a pico = 0 dB (siempre).
- Enmascaramiento DURO fuera de [1e3, 1e9] Hz (se multiplican por 0).
- Sin augmentations ni randomizaciones.
- No hay resampling en el código: **asumo que gain0_sim y combined_real_gain0 ya están en la MISMA grilla**
  y que esa grilla está guardada en `data/freq_grid_hz.npy` (Hz). Si no la tenés, creala
  con la misma longitud que tus espectros. Esto evita complejidad en el script.
- Salida: clasificación con 21 clases (0..20).
- Código pensado para ejecutarse en un notebook: funciones que devuelven historiales y permiten
  graficar losses, comparar distribuciones y visualizar la matriz de confusión.

Archivos esperados (poner en data/):
- sim_train.npz  -> keys: 'waveforms' (N,T), 'gain0_sim' (N,F, dB), 'counts' (N) integers 0..20
- sim_val.npz    -> same structure (validación)
- combined_real_gain0.npy -> shape (F,) en dB (misma grilla que gain0_sim)
- freq_grid_hz.npy -> shape (F,) con frecuencias en Hz (misma malla para los espectros)

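Salida producida:
- artifacts/best_simple.pt                   (modelo guardado: época con menor val_loss)
- artifacts/pred_counts_real_gain0.npy       (predicciones sobre val usando combined_real_gain0)
- artifacts/pred_class_probs_real_gain0.npy  (probabilidades por clase, shape (N, 21))
- las funciones devuelven arrays para plotear en el notebook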

Instrucciones de uso (notebook):
1) Cargá este script como módulo o pegalo en una celda.
2) Llamá a `out = train_model(...)` para entrenar; el historial queda en `out['history']`.
3) Usá `plot_losses(out['history'])` y `plot_confusion(y_true, y_pred)` para visualizar.
4) Ejecutá `preds, probs = predict_with_real(...)` para obtener predicciones usando `combined_real_gain0`.

"""

from __future__ import annotations
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import matplotlib.pyplot as plt
from typing import Tuple, Dict, Any

# -----------------------------
# Minimal user-editable variables
# -----------------------------
# Input files (paths are relative to the current working directory).
TRAIN_NPZ = 'data/sim_train.npz'
VAL_NPZ = 'data/sim_val.npz'
REAL_SPEC_NPY = 'data/combined_real_gain0.npy'
FREQ_GRID_NPY = 'data/freq_grid_hz.npy'   # required: same length as spec
# Output directory; created when this code runs if it does not exist yet.
ARTIFACTS_DIR = Path('artifacts')
ARTIFACTS_DIR.mkdir(exist_ok=True)

# Compute device, plus the defaults train_model uses for epochs, batch size
# and learning rate.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCHS = 30
BATCH_SIZE = 64
LR = 1e-3

# Frequency bounds in Hz for the hard mask: bins outside [LOW_HZ, HIGH_HZ]
# are multiplied by zero.
LOW_HZ = 1e3
HIGH_HZ = 1e9

# Number of output classes. This is a fixed constant: nothing in this file
# infers it from the data, so edit it here if the label range changes.
NUM_CLASSES = 21  # classes 0..20

# -----------------------------
# Utility functions
# -----------------------------

def load_freq_grid(path: str | Path | None = None) -> np.ndarray:
    """Load the frequency grid shared by all spectra.

    Args:
        path: Location of the ``.npy`` grid file. When omitted, the
            module-level ``FREQ_GRID_NPY`` is used, so the original
            zero-argument call keeps working.

    Returns:
        The array stored in the file, exactly as ``np.load`` returns it.

    Raises:
        FileNotFoundError: If the grid file does not exist.
    """
    p = Path(FREQ_GRID_NPY if path is None else path)
    if not p.exists():
        raise FileNotFoundError(f"Freq grid file not found: {p}. Create a freq_grid_hz.npy matching your spectra")
    return np.load(p)


def build_hard_mask(freq_grid_hz: np.ndarray, low_hz: float = LOW_HZ, high_hz: float = HIGH_HZ) -> np.ndarray:
    """Return a float32 0/1 mask with the same shape as the frequency grid.

    Bins whose frequency is strictly below ``low_hz`` or strictly above
    ``high_hz`` get 0.0; every other bin (bounds included) gets 1.0.
    """
    out_of_band = (freq_grid_hz < low_hz) | (freq_grid_hz > high_hz)
    return np.where(out_of_band, np.float32(0.0), np.float32(1.0))


def normalize_spec_db_to_peak(spec_db: np.ndarray) -> np.ndarray:
    """Shift a dB spectrum so that its maximum sits at 0 dB.

    The input is converted to float32. The peak is taken along the last axis,
    so a 1D spectrum is shifted by its own maximum and a batch of spectra is
    shifted row by row.
    """
    spectrum = np.asarray(spec_db, dtype=np.float32)
    # keepdims makes the peak broadcast against both 1D and batched input.
    peak = spectrum.max(axis=-1, keepdims=True)
    return spectrum - peak


def standardize_waveform(wf: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Return the waveform as float32 with zero mean and (near) unit std.

    Mean and standard deviation are computed over the whole array. ``eps`` is
    added to the std so a constant waveform maps to zeros instead of dividing
    by zero.
    """
    samples = np.array(wf, dtype=np.float32)
    centered = samples - samples.mean()
    return centered / (samples.std() + eps)

# -----------------------------
# Dataset
# -----------------------------
class SimpleSimDataset(Dataset):
    """Simulated (waveform, spectrum) -> count samples read from one ``.npz`` file.

    The archive must provide ``waveforms`` [N, T], ``gain0_sim`` [N, F] and
    ``counts`` [N]. Each item is preprocessed on access: the waveform is
    standardized, the spectrum is shifted so its peak is 0 and then multiplied
    by the frequency mask.
    """

    def __init__(self, npz_path: str, freq_mask: np.ndarray):
        """Load the arrays into memory and validate their shapes.

        Args:
            npz_path: Path to the ``.npz`` archive.
            freq_mask: 1D multiplicative mask with one entry per frequency bin.

        Raises:
            ValueError: If array ranks, sample counts or the mask length are
                inconsistent.
        """
        # NpzFile keeps a file handle open; the context manager releases it
        # once ``astype`` has materialized the arrays.
        with np.load(npz_path) as data:
            self.waveforms = data['waveforms'].astype(np.float32)  # [N,T]
            self.spec = data['gain0_sim'].astype(np.float32)       # [N,F] in dB
            self.counts = data['counts'].astype(np.int64)          # [N]
        self.freq_mask = np.array(freq_mask, dtype=np.float32)

        # Explicit raises instead of asserts: asserts vanish under ``python -O``.
        if self.waveforms.ndim != 2:
            raise ValueError(f"'waveforms' must be 2D [N,T], got shape {self.waveforms.shape}")
        if self.spec.ndim != 2:
            raise ValueError(f"'gain0_sim' must be 2D [N,F], got shape {self.spec.shape}")
        n_samples = self.waveforms.shape[0]
        if self.spec.shape[0] != n_samples or self.counts.shape[0] != n_samples:
            raise ValueError(
                "Sample counts differ: "
                f"waveforms={n_samples}, gain0_sim={self.spec.shape[0]}, counts={self.counts.shape[0]}"
            )
        # A mask of the wrong length would otherwise only fail (or silently
        # broadcast) later, inside __getitem__.
        if self.freq_mask.shape != (self.spec.shape[1],):
            raise ValueError(
                f"freq_mask shape {self.freq_mask.shape} does not match spectrum length {self.spec.shape[1]}"
            )

    def __len__(self):
        return self.waveforms.shape[0]

    def __getitem__(self, idx):
        """Return ``(waveform [1,T], spectrum [F], label)`` as torch tensors."""
        wf = standardize_waveform(self.waveforms[idx])                     # [T]
        spec = normalize_spec_db_to_peak(self.spec[idx]) * self.freq_mask  # [F]
        # Out-of-range labels are clamped into [0, NUM_CLASSES - 1].
        y = min(max(int(self.counts[idx]), 0), NUM_CLASSES - 1)
        wf_t = torch.from_numpy(wf)[None, :].float()   # [1,T]
        spec_t = torch.from_numpy(spec).float()        # [F]
        return wf_t, spec_t, torch.tensor(y, dtype=torch.long)

# -----------------------------
# Model (small, simple)
# -----------------------------
class WaveformEncoder(nn.Module):
    """1D-conv encoder mapping a waveform [B, in_ch, T] to features [B, out_dim]."""

    def __init__(self, in_ch: int = 1, out_dim: int = 64):
        super().__init__()
        # The layer order fixes the ``net.<index>`` parameter names, so it is
        # kept exactly as before for saved checkpoints to keep loading.
        layers = [
            nn.Conv1d(in_ch, 32, kernel_size=9, stride=2, padding=4),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),  # collapse the time axis -> [B, 64, 1]
            nn.Flatten(),             # -> [B, 64]
            nn.Linear(64, out_dim),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        features = self.net(x)
        return features

class SpecEncoder(nn.Module):
    """Two-layer MLP mapping a spectrum [B, f_len] to features [B, out_dim]."""

    def __init__(self, f_len: int, out_dim: int = 64):
        super().__init__()
        hidden = 256
        # Same layer order as before, so the ``net.<index>`` parameter names
        # are unchanged.
        layers = [
            nn.Linear(f_len, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        features = self.net(x)
        return features

class SimpleClassifier(nn.Module):
    """Two-branch classifier: waveform and spectrum features are concatenated
    and mapped to class logits."""

    def __init__(self, spec_len: int, wf_out: int = 64, spec_out: int = 64, n_classes: int = NUM_CLASSES):
        super().__init__()
        # Submodules are created in the same order (wf_enc, spec_enc, head) so
        # parameter names and seeded initialization stay the same.
        self.wf_enc = WaveformEncoder(in_ch=1, out_dim=wf_out)
        self.spec_enc = SpecEncoder(f_len=spec_len, out_dim=spec_out)
        fused_dim = wf_out + spec_out
        self.head = nn.Sequential(
            nn.Linear(fused_dim, 128),
            nn.ReLU(),
            nn.Linear(128, n_classes),
        )

    def forward(self, wf, spec):
        """Return logits [B, n_classes] for waveforms [B,1,T] and spectra [B,F]."""
        fused = torch.cat([self.wf_enc(wf), self.spec_enc(spec)], dim=-1)
        return self.head(fused)

# -----------------------------
# Training / utilities for notebook
# -----------------------------

def train_model(train_npz: str = TRAIN_NPZ, val_npz: str = VAL_NPZ, epochs: int = EPOCHS, batch_size: int = BATCH_SIZE, lr: float = LR) -> Dict[str, Any]:
    """Train ``SimpleClassifier`` on the simulated data and track the losses.

    Args:
        train_npz: Archive with the training samples.
        val_npz: Archive with the validation samples.
        epochs: Number of passes over the training set.
        batch_size: Batch size for both loaders.
        lr: Adam learning rate.

    Returns:
        A dict with ``'model_state'`` (state dict after the LAST epoch),
        ``'history'`` (per-epoch ``'train_loss'`` / ``'val_loss'`` lists) and
        ``'spec_len'`` (spectrum length taken from the training set).

    Side effects:
        Prints one line per epoch and writes ``best_simple.pt`` under
        ``ARTIFACTS_DIR`` every time the validation loss improves, so the file
        holds the best-validation weights rather than the returned ones.
    """
    band_mask = build_hard_mask(load_freq_grid(), LOW_HZ, HIGH_HZ)

    train_set = SimpleSimDataset(train_npz, band_mask)
    val_set = SimpleSimDataset(val_npz, band_mask)
    spec_len = train_set.spec.shape[1]

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

    model = SimpleClassifier(spec_len=spec_len).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    def batch_loss(batch):
        """Move one (wf, spec, y) batch to DEVICE; return (loss, batch size)."""
        wf, spec, y = (t.to(DEVICE) for t in batch)
        return criterion(model(wf, spec), y), y.size(0)

    history = {'train_loss': [], 'val_loss': []}
    best_val = float('inf')
    ckpt_path = ARTIFACTS_DIR / 'best_simple.pt'

    for ep in range(1, epochs + 1):
        # --- training pass ---
        model.train()
        train_sum = 0.0
        for batch in train_loader:
            loss, n_batch = batch_loss(batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so a short last batch is averaged correctly.
            train_sum += loss.item() * n_batch
        train_loss = train_sum / len(train_set)

        # --- validation pass ---
        model.eval()
        val_sum = 0.0
        with torch.no_grad():
            for batch in val_loader:
                loss, n_batch = batch_loss(batch)
                val_sum += loss.item() * n_batch
        val_loss = val_sum / len(val_set)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        print(f"Epoch {ep:03d} | train {train_loss:.4f} | val {val_loss:.4f}")

        if val_loss < best_val:
            best_val = val_loss
            torch.save({'model': model.state_dict(), 'spec_len': spec_len}, ckpt_path)

    return {'model_state': model.state_dict(), 'history': history, 'spec_len': spec_len}

# -----------------------------
# Prediction using combined_real_gain0
# -----------------------------
@torch.no_grad()
def predict_with_real(val_npz: str = VAL_NPZ, checkpoint: str = ARTIFACTS_DIR / 'best_simple.pt') -> Tuple[np.ndarray, np.ndarray]:
    """Predict counts for the validation waveforms using the real gain spectrum.

    Every waveform in ``val_npz`` is paired with the single spectrum stored in
    ``REAL_SPEC_NPY`` (peak-normalized and band-masked like the training
    spectra) and classified with the model restored from ``checkpoint``.

    Args:
        val_npz: ``.npz`` archive providing a ``waveforms`` array [N, T].
        checkpoint: File written by ``torch.save`` holding a dict with the
            ``'model'`` state dict and, optionally, ``'spec_len'``.

    Returns:
        ``(preds, probs)``: int64 argmax class per waveform, shape [N], and
        float32 softmax probabilities, shape [N, NUM_CLASSES].

    Raises:
        ValueError: If the real spectrum is not 1D, or its length disagrees
            with the frequency grid or with the checkpoint's ``spec_len``.

    Side effects:
        Saves ``pred_class_probs_real_gain0.npy`` and
        ``pred_counts_real_gain0.npy`` under ``ARTIFACTS_DIR``.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    ck = torch.load(checkpoint, map_location=DEVICE)
    spec_len = ck.get('spec_len') if isinstance(ck, dict) else None

    mask = build_hard_mask(load_freq_grid(), LOW_HZ, HIGH_HZ)

    real_spec = np.load(REAL_SPEC_NPY).astype(np.float32)
    # Fail early with a clear message instead of a broadcasting error or a
    # size mismatch inside load_state_dict.
    if real_spec.ndim != 1:
        raise ValueError(f"Real spectrum must be 1D [F], got shape {real_spec.shape}")
    if mask.shape != real_spec.shape:
        raise ValueError(
            f"Frequency grid shape {mask.shape} does not match real spectrum shape {real_spec.shape}"
        )
    if spec_len is not None and int(spec_len) != real_spec.shape[0]:
        raise ValueError(
            f"Checkpoint was trained with spec_len={spec_len}, but the real spectrum has {real_spec.shape[0]} bins"
        )
    real_spec = normalize_spec_db_to_peak(real_spec) * mask

    # Context manager closes the archive once the array is in memory.
    with np.load(val_npz) as data:
        waveforms = data['waveforms'].astype(np.float32)

    model = SimpleClassifier(spec_len=real_spec.shape[0]).to(DEVICE)
    model.load_state_dict(ck['model'])
    model.eval()

    n_samples = waveforms.shape[0]
    batch_size = 128
    preds = np.zeros((n_samples,), dtype=np.int64)
    probs = np.zeros((n_samples, NUM_CLASSES), dtype=np.float32)

    # The spectrum is the same for every sample: move it to the device once
    # and expand it per batch instead of tiling a fresh copy each iteration.
    spec_row = torch.from_numpy(real_spec).float().to(DEVICE)[None, :]  # [1,F]

    for start in range(0, n_samples, batch_size):
        stop = start + batch_size
        chunk = waveforms[start:stop]
        wf_np = np.stack([standardize_waveform(w) for w in chunk], axis=0)
        wf_t = torch.from_numpy(wf_np)[:, None, :].float().to(DEVICE)   # [B,1,T]
        spec_t = spec_row.expand(wf_t.size(0), -1)                      # [B,F]
        p = torch.softmax(model(wf_t, spec_t), dim=-1).cpu().numpy()
        probs[start:stop] = p
        preds[start:stop] = p.argmax(axis=-1)

    np.save(ARTIFACTS_DIR / 'pred_class_probs_real_gain0.npy', probs)
    np.save(ARTIFACTS_DIR / 'pred_counts_real_gain0.npy', preds)
    return preds, probs

# -----------------------------
# Plot helpers for notebook
# -----------------------------
def plot_losses(history: Dict[str, list]):
    """Plot train and validation loss per epoch on a log-scaled y axis.

    ``history`` must contain the lists ``'train_loss'`` and ``'val_loss'``.
    """
    _, ax = plt.subplots(figsize=(6,4))
    for key, label in (('train_loss', 'train'), ('val_loss', 'val')):
        ax.plot(history[key], label=label)
    ax.set_yscale('log')
    ax.set_xlabel('epoch')
    ax.set_ylabel('loss')
    ax.legend()
    ax.grid(True)
    plt.show()


def plot_confusion(y_true: np.ndarray, y_pred: np.ndarray, normalize: bool = True, vmax: float | None = None):
    """Show a NUM_CLASSES x NUM_CLASSES confusion matrix (rows = true class).

    Args:
        y_true: True labels, one per sample.
        y_pred: Predicted labels, paired element-wise with ``y_true``.
        normalize: If True, each row is divided by its total so it sums to 1;
            rows with no samples stay at zero.
        vmax: Upper colour limit passed to ``imshow`` (None = automatic).

    Labels are truncated to integers and clamped into [0, NUM_CLASSES - 1],
    the same clamping ``SimpleSimDataset`` applies to its targets. Without it
    a raw count above the last class raises IndexError and a negative one
    silently wraps around to the last rows/columns.
    """
    K = NUM_CLASSES
    true_idx = np.asarray(y_true).astype(np.int64).ravel()
    pred_idx = np.asarray(y_pred).astype(np.int64).ravel()
    # Pair element-wise and ignore any unpaired tail, as zip() did.
    n_pairs = min(true_idx.size, pred_idx.size)
    true_idx = np.clip(true_idx[:n_pairs], 0, K - 1)
    pred_idx = np.clip(pred_idx[:n_pairs], 0, K - 1)

    cm = np.zeros((K, K), dtype=np.int64)
    # add.at accumulates repeated (true, pred) pairs; plain fancy-index
    # ``+=`` would count each distinct pair only once.
    np.add.at(cm, (true_idx, pred_idx), 1)

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True).astype(np.float64)
        m = np.divide(cm, row_sums, out=np.zeros((K, K), dtype=float), where=row_sums != 0)
    else:
        m = cm

    plt.figure(figsize=(8,6))
    plt.imshow(m, origin='lower', aspect='auto', vmax=vmax)
    plt.colorbar()
    plt.xlabel('predicted')
    plt.ylabel('true')
    plt.title('Confusion matrix (rows=true)')
    plt.show()


def plot_pred_vs_true_hist(y_true: np.ndarray, y_pred: np.ndarray):
    """Overlay histograms of true and predicted counts, one bin per class."""
    # Unit-wide bins centred on the integers 0..NUM_CLASSES-1.
    edges = np.arange(NUM_CLASSES+1)-0.5
    plt.figure(figsize=(8,4))
    for values, label in ((y_true, 'true'), (y_pred, 'pred')):
        plt.hist(values, bins=edges, alpha=0.5, label=label)
    plt.xlabel('counts')
    plt.ylabel('frequency')
    plt.legend()
    plt.show()

# End of script


In [None]:
"""
Simple Jupyter-friendly Sim→Real counter predictor (minimal, opinionated)
---------------------------------------------------------------------

Objetivo
========
Entrenar con (waveform_muon, gain0_sim) -> counts (clase 0..20).
En test reemplazar gain0_sim por combined_real_gain0 y predecir los counts esperados.

Principios de diseño (simplificados a pedido):
- Espectros en dB, normalizados a pico = 0 dB (siempre).
- Enmascaramiento DURO fuera de [1e3, 1e9] Hz (se multiplican por 0).
- Sin augmentations ni randomizaciones.
- No hay resampling en el código: **asumo que gain0_sim y combined_real_gain0 ya están en la MISMA grilla**
  y que esa grilla está guardada en `data/freq_grid_hz.npy` (Hz). Si no la tenés, creala
  con la misma longitud que tus espectros. Esto evita complejidad en el script.
- Salida: clasificación con 21 clases (0..20).
- Código pensado para ejecutarse en un notebook: funciones que devuelven historiales y permiten
  graficar losses, comparar distribuciones y visualizar la matriz de confusión.

Archivos esperados (poner en data/):
- sim_train.npz  -> keys: 'waveforms' (N,T), 'gain0_sim' (N,F, dB), 'counts' (N) integers 0..20
- sim_val.npz    -> same structure (validación)
- combined_real_gain0.npy -> shape (F,) en dB (misma grilla que gain0_sim)
- freq_grid_hz.npy -> shape (F,) con frecuencias en Hz (misma malla para los espectros)

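Salida producida:
- artifacts/best_simple.pt                   (modelo guardado: época con menor val_loss)
- artifacts/pred_counts_real_gain0.npy       (predicciones sobre val usando combined_real_gain0)
- artifacts/pred_class_probs_real_gain0.npy  (probabilidades por clase, shape (N, 21))
- las funciones devuelven arrays para plotear en el notebook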

Instrucciones de uso (notebook):
1) Cargá este script como módulo o pegalo en una celda.
2) Llamá a `out = train_model(...)` para entrenar; el historial queda en `out['history']`.
3) Usá `plot_losses(out['history'])` y `plot_confusion(y_true, y_pred)` para visualizar.
4) Ejecutá `preds, probs = predict_with_real(...)` para obtener predicciones usando `combined_real_gain0`.

"""

from __future__ import annotations
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import matplotlib.pyplot as plt
from typing import Tuple, Dict, Any

# -----------------------------
# Minimal user-editable variables
# -----------------------------
# Input files (paths are relative to the current working directory).
TRAIN_NPZ = 'data/sim_train.npz'
VAL_NPZ = 'data/sim_val.npz'
REAL_SPEC_NPY = 'data/combined_real_gain0.npy'
FREQ_GRID_NPY = 'data/freq_grid_hz.npy'   # required: same length as spec
# Output directory; created when this code runs if it does not exist yet.
ARTIFACTS_DIR = Path('artifacts')
ARTIFACTS_DIR.mkdir(exist_ok=True)

# Compute device, plus the defaults train_model uses for epochs, batch size
# and learning rate.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCHS = 30
BATCH_SIZE = 64
LR = 1e-3

# Frequency bounds in Hz for the hard mask: bins outside [LOW_HZ, HIGH_HZ]
# are multiplied by zero.
LOW_HZ = 1e3
HIGH_HZ = 1e9

# Number of output classes. This is a fixed constant: nothing in this file
# infers it from the data, so edit it here if the label range changes.
NUM_CLASSES = 21  # classes 0..20

# -----------------------------
# Utility functions
# -----------------------------

def load_freq_grid(path: str | Path | None = None) -> np.ndarray:
    """Load the frequency grid shared by all spectra.

    Args:
        path: Location of the ``.npy`` grid file. When omitted, the
            module-level ``FREQ_GRID_NPY`` is used, so the original
            zero-argument call keeps working.

    Returns:
        The array stored in the file, exactly as ``np.load`` returns it.

    Raises:
        FileNotFoundError: If the grid file does not exist.
    """
    p = Path(FREQ_GRID_NPY if path is None else path)
    if not p.exists():
        raise FileNotFoundError(f"Freq grid file not found: {p}. Create a freq_grid_hz.npy matching your spectra")
    return np.load(p)


def build_hard_mask(freq_grid_hz: np.ndarray, low_hz: float = LOW_HZ, high_hz: float = HIGH_HZ) -> np.ndarray:
    """Return a float32 0/1 mask with the same shape as the frequency grid.

    Bins whose frequency is strictly below ``low_hz`` or strictly above
    ``high_hz`` get 0.0; every other bin (bounds included) gets 1.0.
    """
    out_of_band = (freq_grid_hz < low_hz) | (freq_grid_hz > high_hz)
    return np.where(out_of_band, np.float32(0.0), np.float32(1.0))


def normalize_spec_db_to_peak(spec_db: np.ndarray) -> np.ndarray:
    """Shift a dB spectrum so that its maximum sits at 0 dB.

    The input is converted to float32. The peak is taken along the last axis,
    so a 1D spectrum is shifted by its own maximum and a batch of spectra is
    shifted row by row.
    """
    spectrum = np.asarray(spec_db, dtype=np.float32)
    # keepdims makes the peak broadcast against both 1D and batched input.
    peak = spectrum.max(axis=-1, keepdims=True)
    return spectrum - peak


def standardize_waveform(wf: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Return the waveform as float32 with zero mean and (near) unit std.

    Mean and standard deviation are computed over the whole array. ``eps`` is
    added to the std so a constant waveform maps to zeros instead of dividing
    by zero.
    """
    samples = np.array(wf, dtype=np.float32)
    centered = samples - samples.mean()
    return centered / (samples.std() + eps)

# -----------------------------
# Dataset
# -----------------------------
class SimpleSimDataset(Dataset):
    """Simulated (waveform, spectrum) -> count samples read from one ``.npz`` file.

    The archive must provide ``waveforms`` [N, T], ``gain0_sim`` [N, F] and
    ``counts`` [N]. Each item is preprocessed on access: the waveform is
    standardized, the spectrum is shifted so its peak is 0 and then multiplied
    by the frequency mask.
    """

    def __init__(self, npz_path: str, freq_mask: np.ndarray):
        """Load the arrays into memory and validate their shapes.

        Args:
            npz_path: Path to the ``.npz`` archive.
            freq_mask: 1D multiplicative mask with one entry per frequency bin.

        Raises:
            ValueError: If array ranks, sample counts or the mask length are
                inconsistent.
        """
        # NpzFile keeps a file handle open; the context manager releases it
        # once ``astype`` has materialized the arrays.
        with np.load(npz_path) as data:
            self.waveforms = data['waveforms'].astype(np.float32)  # [N,T]
            self.spec = data['gain0_sim'].astype(np.float32)       # [N,F] in dB
            self.counts = data['counts'].astype(np.int64)          # [N]
        self.freq_mask = np.array(freq_mask, dtype=np.float32)

        # Explicit raises instead of asserts: asserts vanish under ``python -O``.
        if self.waveforms.ndim != 2:
            raise ValueError(f"'waveforms' must be 2D [N,T], got shape {self.waveforms.shape}")
        if self.spec.ndim != 2:
            raise ValueError(f"'gain0_sim' must be 2D [N,F], got shape {self.spec.shape}")
        n_samples = self.waveforms.shape[0]
        if self.spec.shape[0] != n_samples or self.counts.shape[0] != n_samples:
            raise ValueError(
                "Sample counts differ: "
                f"waveforms={n_samples}, gain0_sim={self.spec.shape[0]}, counts={self.counts.shape[0]}"
            )
        # A mask of the wrong length would otherwise only fail (or silently
        # broadcast) later, inside __getitem__.
        if self.freq_mask.shape != (self.spec.shape[1],):
            raise ValueError(
                f"freq_mask shape {self.freq_mask.shape} does not match spectrum length {self.spec.shape[1]}"
            )

    def __len__(self):
        return self.waveforms.shape[0]

    def __getitem__(self, idx):
        """Return ``(waveform [1,T], spectrum [F], label)`` as torch tensors."""
        wf = standardize_waveform(self.waveforms[idx])                     # [T]
        spec = normalize_spec_db_to_peak(self.spec[idx]) * self.freq_mask  # [F]
        # Out-of-range labels are clamped into [0, NUM_CLASSES - 1].
        y = min(max(int(self.counts[idx]), 0), NUM_CLASSES - 1)
        wf_t = torch.from_numpy(wf)[None, :].float()   # [1,T]
        spec_t = torch.from_numpy(spec).float()        # [F]
        return wf_t, spec_t, torch.tensor(y, dtype=torch.long)

# -----------------------------
# Model (small, simple)
# -----------------------------
class WaveformEncoder(nn.Module):
    """1D-conv encoder mapping a waveform [B, in_ch, T] to features [B, out_dim]."""

    def __init__(self, in_ch: int = 1, out_dim: int = 64):
        super().__init__()
        # The layer order fixes the ``net.<index>`` parameter names, so it is
        # kept exactly as before for saved checkpoints to keep loading.
        layers = [
            nn.Conv1d(in_ch, 32, kernel_size=9, stride=2, padding=4),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),  # collapse the time axis -> [B, 64, 1]
            nn.Flatten(),             # -> [B, 64]
            nn.Linear(64, out_dim),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        features = self.net(x)
        return features

class SpecEncoder(nn.Module):
    """Two-layer MLP mapping a spectrum [B, f_len] to features [B, out_dim]."""

    def __init__(self, f_len: int, out_dim: int = 64):
        super().__init__()
        hidden = 256
        # Same layer order as before, so the ``net.<index>`` parameter names
        # are unchanged.
        layers = [
            nn.Linear(f_len, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        features = self.net(x)
        return features

class SimpleClassifier(nn.Module):
    """Two-branch classifier: waveform and spectrum features are concatenated
    and mapped to class logits."""

    def __init__(self, spec_len: int, wf_out: int = 64, spec_out: int = 64, n_classes: int = NUM_CLASSES):
        super().__init__()
        # Submodules are created in the same order (wf_enc, spec_enc, head) so
        # parameter names and seeded initialization stay the same.
        self.wf_enc = WaveformEncoder(in_ch=1, out_dim=wf_out)
        self.spec_enc = SpecEncoder(f_len=spec_len, out_dim=spec_out)
        fused_dim = wf_out + spec_out
        self.head = nn.Sequential(
            nn.Linear(fused_dim, 128),
            nn.ReLU(),
            nn.Linear(128, n_classes),
        )

    def forward(self, wf, spec):
        """Return logits [B, n_classes] for waveforms [B,1,T] and spectra [B,F]."""
        fused = torch.cat([self.wf_enc(wf), self.spec_enc(spec)], dim=-1)
        return self.head(fused)

# -----------------------------
# Training / utilities for notebook
# -----------------------------

def train_model(train_npz: str = TRAIN_NPZ, val_npz: str = VAL_NPZ, epochs: int = EPOCHS, batch_size: int = BATCH_SIZE, lr: float = LR) -> Dict[str, Any]:
    """Train ``SimpleClassifier`` on the simulated data and track the losses.

    Args:
        train_npz: Archive with the training samples.
        val_npz: Archive with the validation samples.
        epochs: Number of passes over the training set.
        batch_size: Batch size for both loaders.
        lr: Adam learning rate.

    Returns:
        A dict with ``'model_state'`` (state dict after the LAST epoch),
        ``'history'`` (per-epoch ``'train_loss'`` / ``'val_loss'`` lists) and
        ``'spec_len'`` (spectrum length taken from the training set).

    Side effects:
        Prints one line per epoch and writes ``best_simple.pt`` under
        ``ARTIFACTS_DIR`` every time the validation loss improves, so the file
        holds the best-validation weights rather than the returned ones.
    """
    band_mask = build_hard_mask(load_freq_grid(), LOW_HZ, HIGH_HZ)

    train_set = SimpleSimDataset(train_npz, band_mask)
    val_set = SimpleSimDataset(val_npz, band_mask)
    spec_len = train_set.spec.shape[1]

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

    model = SimpleClassifier(spec_len=spec_len).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    def batch_loss(batch):
        """Move one (wf, spec, y) batch to DEVICE; return (loss, batch size)."""
        wf, spec, y = (t.to(DEVICE) for t in batch)
        return criterion(model(wf, spec), y), y.size(0)

    history = {'train_loss': [], 'val_loss': []}
    best_val = float('inf')
    ckpt_path = ARTIFACTS_DIR / 'best_simple.pt'

    for ep in range(1, epochs + 1):
        # --- training pass ---
        model.train()
        train_sum = 0.0
        for batch in train_loader:
            loss, n_batch = batch_loss(batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so a short last batch is averaged correctly.
            train_sum += loss.item() * n_batch
        train_loss = train_sum / len(train_set)

        # --- validation pass ---
        model.eval()
        val_sum = 0.0
        with torch.no_grad():
            for batch in val_loader:
                loss, n_batch = batch_loss(batch)
                val_sum += loss.item() * n_batch
        val_loss = val_sum / len(val_set)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        print(f"Epoch {ep:03d} | train {train_loss:.4f} | val {val_loss:.4f}")

        if val_loss < best_val:
            best_val = val_loss
            torch.save({'model': model.state_dict(), 'spec_len': spec_len}, ckpt_path)

    return {'model_state': model.state_dict(), 'history': history, 'spec_len': spec_len}

# -----------------------------
# Prediction using combined_real_gain0
# -----------------------------
@torch.no_grad()
def predict_with_real(val_npz: str = VAL_NPZ, checkpoint: str = ARTIFACTS_DIR / 'best_simple.pt') -> Tuple[np.ndarray, np.ndarray]:
    """Predict counts for the validation waveforms using the real gain spectrum.

    Every waveform in ``val_npz`` is paired with the single spectrum stored in
    ``REAL_SPEC_NPY`` (peak-normalized and band-masked like the training
    spectra) and classified with the model restored from ``checkpoint``.

    Args:
        val_npz: ``.npz`` archive providing a ``waveforms`` array [N, T].
        checkpoint: File written by ``torch.save`` holding a dict with the
            ``'model'`` state dict and, optionally, ``'spec_len'``.

    Returns:
        ``(preds, probs)``: int64 argmax class per waveform, shape [N], and
        float32 softmax probabilities, shape [N, NUM_CLASSES].

    Raises:
        ValueError: If the real spectrum is not 1D, or its length disagrees
            with the frequency grid or with the checkpoint's ``spec_len``.

    Side effects:
        Saves ``pred_class_probs_real_gain0.npy`` and
        ``pred_counts_real_gain0.npy`` under ``ARTIFACTS_DIR``.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    ck = torch.load(checkpoint, map_location=DEVICE)
    spec_len = ck.get('spec_len') if isinstance(ck, dict) else None

    mask = build_hard_mask(load_freq_grid(), LOW_HZ, HIGH_HZ)

    real_spec = np.load(REAL_SPEC_NPY).astype(np.float32)
    # Fail early with a clear message instead of a broadcasting error or a
    # size mismatch inside load_state_dict.
    if real_spec.ndim != 1:
        raise ValueError(f"Real spectrum must be 1D [F], got shape {real_spec.shape}")
    if mask.shape != real_spec.shape:
        raise ValueError(
            f"Frequency grid shape {mask.shape} does not match real spectrum shape {real_spec.shape}"
        )
    if spec_len is not None and int(spec_len) != real_spec.shape[0]:
        raise ValueError(
            f"Checkpoint was trained with spec_len={spec_len}, but the real spectrum has {real_spec.shape[0]} bins"
        )
    real_spec = normalize_spec_db_to_peak(real_spec) * mask

    # Context manager closes the archive once the array is in memory.
    with np.load(val_npz) as data:
        waveforms = data['waveforms'].astype(np.float32)

    model = SimpleClassifier(spec_len=real_spec.shape[0]).to(DEVICE)
    model.load_state_dict(ck['model'])
    model.eval()

    n_samples = waveforms.shape[0]
    batch_size = 128
    preds = np.zeros((n_samples,), dtype=np.int64)
    probs = np.zeros((n_samples, NUM_CLASSES), dtype=np.float32)

    # The spectrum is the same for every sample: move it to the device once
    # and expand it per batch instead of tiling a fresh copy each iteration.
    spec_row = torch.from_numpy(real_spec).float().to(DEVICE)[None, :]  # [1,F]

    for start in range(0, n_samples, batch_size):
        stop = start + batch_size
        chunk = waveforms[start:stop]
        wf_np = np.stack([standardize_waveform(w) for w in chunk], axis=0)
        wf_t = torch.from_numpy(wf_np)[:, None, :].float().to(DEVICE)   # [B,1,T]
        spec_t = spec_row.expand(wf_t.size(0), -1)                      # [B,F]
        p = torch.softmax(model(wf_t, spec_t), dim=-1).cpu().numpy()
        probs[start:stop] = p
        preds[start:stop] = p.argmax(axis=-1)

    np.save(ARTIFACTS_DIR / 'pred_class_probs_real_gain0.npy', probs)
    np.save(ARTIFACTS_DIR / 'pred_counts_real_gain0.npy', preds)
    return preds, probs

# -----------------------------
# Plot helpers for notebook
# -----------------------------
def plot_losses(history: Dict[str, list]):
    """Plot train and validation loss per epoch on a log-scaled y axis.

    ``history`` must contain the lists ``'train_loss'`` and ``'val_loss'``.
    """
    _, ax = plt.subplots(figsize=(6,4))
    for key, label in (('train_loss', 'train'), ('val_loss', 'val')):
        ax.plot(history[key], label=label)
    ax.set_yscale('log')
    ax.set_xlabel('epoch')
    ax.set_ylabel('loss')
    ax.legend()
    ax.grid(True)
    plt.show()


def plot_confusion(y_true: np.ndarray, y_pred: np.ndarray, normalize: bool = True, vmax: float | None = None):
    """Show a NUM_CLASSES x NUM_CLASSES confusion matrix (rows = true class).

    Args:
        y_true: True labels, one per sample.
        y_pred: Predicted labels, paired element-wise with ``y_true``.
        normalize: If True, each row is divided by its total so it sums to 1;
            rows with no samples stay at zero.
        vmax: Upper colour limit passed to ``imshow`` (None = automatic).

    Labels are truncated to integers and clamped into [0, NUM_CLASSES - 1],
    the same clamping ``SimpleSimDataset`` applies to its targets. Without it
    a raw count above the last class raises IndexError and a negative one
    silently wraps around to the last rows/columns.
    """
    K = NUM_CLASSES
    true_idx = np.asarray(y_true).astype(np.int64).ravel()
    pred_idx = np.asarray(y_pred).astype(np.int64).ravel()
    # Pair element-wise and ignore any unpaired tail, as zip() did.
    n_pairs = min(true_idx.size, pred_idx.size)
    true_idx = np.clip(true_idx[:n_pairs], 0, K - 1)
    pred_idx = np.clip(pred_idx[:n_pairs], 0, K - 1)

    cm = np.zeros((K, K), dtype=np.int64)
    # add.at accumulates repeated (true, pred) pairs; plain fancy-index
    # ``+=`` would count each distinct pair only once.
    np.add.at(cm, (true_idx, pred_idx), 1)

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True).astype(np.float64)
        m = np.divide(cm, row_sums, out=np.zeros((K, K), dtype=float), where=row_sums != 0)
    else:
        m = cm

    plt.figure(figsize=(8,6))
    plt.imshow(m, origin='lower', aspect='auto', vmax=vmax)
    plt.colorbar()
    plt.xlabel('predicted')
    plt.ylabel('true')
    plt.title('Confusion matrix (rows=true)')
    plt.show()


def plot_pred_vs_true_hist(y_true: np.ndarray, y_pred: np.ndarray):
    """Overlay histograms of true and predicted counts, one bin per class."""
    # Unit-wide bins centred on the integers 0..NUM_CLASSES-1.
    edges = np.arange(NUM_CLASSES+1)-0.5
    plt.figure(figsize=(8,4))
    for values, label in ((y_true, 'true'), (y_pred, 'pred')):
        plt.hist(values, bins=edges, alpha=0.5, label=label)
    plt.xlabel('counts')
    plt.ylabel('frequency')
    plt.legend()
    plt.show()

# End of script
