In [None]:
"""
Sim→Real Counter Predictor — AMIGA electronics (Gain 0 measured S21)
--------------------------------------------------------------------

Goal
====
Train on (muon waveform x(t), simulated S21_sim(f)) → simulated binary counts.
At test time, replace S21_sim(f) with measured S21_real_gain0(f) (VNA+Pitaya stitched)
and predict the expected counts as if the muon passed through the real electronics.

What you need to provide
=========================
- A dataset of simulated examples with:
  * waveform: float32 array of shape [T] (uniformly sampled x(t))
  * s21_sim: float32 array of shape [F] (amplitude in dB or linear, see CONFIG)
  * target_count: float32 scalar (or int) — number of ones in the binary counter for that waveform.
- One measured S21 curve for gain=0 (stitched VNA+Pitaya) of shape [F_meas].
  It must be resampled to the same frequency grid as s21_sim (use `resample_to_reference` below)
  and stored in the same units (dB or linear) as the simulated curves.

File/dir expectations (you can change paths):
- data/sim_train.npz  with keys: 'waveforms' [N,T], 's21' [N,F], 'counts' [N]
- data/sim_val.npz    with same keys for validation
- data/s21_real_gain0.npy with shape [F]

Key design choices
==================
- CNN1D encodes waveforms; MLP encodes S21; features are concatenated → MLP head predicts counts.
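- Optional frequency mask: S21 bins above `freq_mask_cut_mhz` (default 40 MHz) are scaled by `mask_attenuation`
  to down-weight unreliable regions in the S21 encoder input. It is active only when `freq_grid_mhz` is set in CONFIG.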
- Strong normalizations: waveform standardization; S21 either dB normalized to peak=0 or linear gain.
- Domain-robustness: light random perturbations of S21 during training (shift, tilt, noise) to avoid overfitting
  to a single "perfect" simulated curve.

Run
===
python sim2real_counter_predictor.py --train
python sim2real_counter_predictor.py --predict-real

This is a skeleton: adjust shapes, paths, and hyperparams to your data.
"""
from __future__ import annotations
import argparse
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# ------------------------------
# Config
# ------------------------------
@dataclass
class CONFIG:
    """Data locations, model sizes, optimiser settings and S21 preprocessing switches.

    Field order defines the positional signature generated by ``@dataclass``; keep it stable.
    """

    # --- input files ----------------------------------------------------------------
    train_npz: str = "data/sim_train.npz"           # simulated training set (.npz)
    val_npz: str = "data/sim_val.npz"               # simulated validation set (.npz)
    s21_real_path: str = "data/s21_real_gain0.npy"  # measured gain-0 S21 curve (.npy)

    # --- optional frequency mask on the S21 encoder input ---------------------------
    freq_mask_cut_mhz: float = 40.0                 # attenuate bins above this frequency (MHz)
    freq_grid_mhz: Optional[np.ndarray] = None      # frequency axis in MHz; None -> no mask
    mask_attenuation: float = 0.4                   # factor applied to bins above the cut

    # --- model input sizes ------------------------------------------------------------
    waveform_len: int = 2048                        # samples per waveform (T)
    s21_len: int = 1024                             # bins per S21 curve (F)
    wf_channels: int = 1                            # waveform input channels

    # --- optimisation -----------------------------------------------------------------
    batch_size: int = 64
    lr: float = 1e-3
    weight_decay: float = 1e-4
    epochs: int = 50
    # Evaluated once, when the class body executes.
    device: str = ("cuda" if torch.cuda.is_available() else "cpu")

    # --- S21 domain randomisation (training only) ---------------------------------------
    aug_shift_db: float = 0.5                       # max |global offset| in dB (uniform draw)
    aug_tilt_db: float = 0.8                        # max |end-to-end linear tilt| in dB (uniform draw)
    aug_noise_db: float = 0.2                       # std of additive per-bin Gaussian noise in dB

    # --- S21 normalisation ----------------------------------------------------------------
    s21_in_db: bool = True                          # True: curves are in dB (recommended); False: linear amplitude
    s21_peak_to_0db: bool = True                    # dB mode: shift each curve so its maximum is 0 dB


def set_seed(seed: int = 1234):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU plus every CUDA device)."""
    import random

    # Same order as before: stdlib, NumPy, torch CPU, torch CUDA.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


# ------------------------------
# Utilities
# ------------------------------

def resample_to_reference(x: np.ndarray, f_src: np.ndarray, f_ref: np.ndarray) -> np.ndarray:
    """Resample a 1-D curve onto a reference frequency grid by linear interpolation.

    Args:
        x: values sampled at the frequencies ``f_src`` (same shape as ``f_src``).
        f_src: source frequencies. They need not be ascending: if they are not (e.g. a
            stitched VNA+Pitaya sweep), ``f_src`` and ``x`` are sorted together first.
        f_ref: target frequencies, in the same units as ``f_src``.

    Returns:
        ``x`` evaluated at ``f_ref``, as computed by ``np.interp`` (float64 for real
        input). Points of ``f_ref`` outside ``[min(f_src), max(f_src)]`` take the nearest
        edge value; ``np.interp`` clamps and does not extrapolate.

    Raises:
        ValueError: if ``x`` and ``f_src`` have different shapes.
    """
    x = np.asarray(x)
    f_src = np.asarray(f_src)
    if x.shape != f_src.shape:
        raise ValueError(f"x and f_src must have the same shape, got {x.shape} and {f_src.shape}")
    # np.interp does not check monotonicity and returns meaningless values for a
    # non-increasing xp, so put the source samples in ascending frequency order.
    if f_src.size > 1 and np.any(np.diff(f_src) < 0):
        order = np.argsort(f_src, kind="stable")
        f_src = f_src[order]
        x = x[order]
    return np.interp(f_ref, f_src, x)


def build_freq_mask_hz(freq_grid_hz: Optional[np.ndarray], low_hz: float, high_hz: float, attenuation: float) -> Optional[np.ndarray]:
    """Per-bin float32 weights: 1.0 inside ``[low_hz, high_hz]``, ``attenuation`` outside.

    Returns ``None`` when no frequency grid is supplied, meaning "no masking". The test is
    a plain comparison, so the bounds only need to share the grid's units. The result has
    the grid's shape.
    """
    if freq_grid_hz is not None:
        outside_band = np.logical_or(freq_grid_hz < low_hz, freq_grid_hz > high_hz)
        return np.where(outside_band, np.float32(attenuation), np.float32(1.0)).astype(np.float32)
    return None


def normalize_spectrum_db(spec_db: np.ndarray, peak_to_0db: bool = True) -> np.ndarray:
    """Return a float32 copy of a dB spectrum, optionally shifted so its maximum sits at 0 dB.

    The input array is never modified.
    """
    shifted = spec_db.astype(np.float32)  # astype copies by default
    if not peak_to_0db:
        return shifted
    return shifted - shifted.max()


def augment_spec_db(x_db: np.ndarray, cfg: CONFIG) -> np.ndarray:
    """Return a randomly perturbed copy of a dB spectrum (training-time domain randomisation).

    Three terms drawn from NumPy's global RNG are added, in this order:
      * a global offset, uniform in ``[-cfg.aug_shift_db, cfg.aug_shift_db]``;
      * a linear tilt across the bins, centred on the middle of the band, whose end-to-end
        span is uniform in ``[-cfg.aug_tilt_db, cfg.aug_tilt_db]``;
      * i.i.d. Gaussian noise per bin with standard deviation ``cfg.aug_noise_db``.
    ``x_db`` itself is left untouched.
    """
    n_bins = x_db.shape[-1]
    offset = np.random.uniform(-cfg.aug_shift_db, cfg.aug_shift_db)
    span = np.random.uniform(-cfg.aug_tilt_db, cfg.aug_tilt_db)
    ramp = np.linspace(-0.5, 0.5, n_bins, dtype=np.float32)  # unit span from -0.5 to +0.5
    jitter = np.random.normal(0.0, cfg.aug_noise_db, size=n_bins).astype(np.float32)

    perturbed = x_db.copy()
    for term in (offset, span * ramp, jitter):
        perturbed += term  # in place, so the input dtype is kept
    return perturbed


def standardize_waveform(wf: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Z-score a waveform (subtract mean, divide by std + eps) and return it as float32.

    ``eps`` keeps a flat, zero-variance trace finite; such a trace maps to all zeros.
    """
    centred = wf - wf.mean()
    return (centred / (wf.std() + eps)).astype(np.float32)

# ------------------------------
# Dataset
# ------------------------------
class SimDataset(Dataset):
    """Simulated examples (waveform, S21 curve, target count) loaded from an ``.npz`` file.

    The file must contain ``'waveforms'`` [N, T], ``'s21'`` [N, F] and ``'counts'`` [N], with
    T == ``cfg.waveform_len`` and F == ``cfg.s21_len``. Items are returned as tensors
    ``(wf [1, T], s21 [F], y [])`` after the preprocessing in ``__getitem__``.

    Args:
        npz_path: path to the ``.npz`` file.
        cfg: pipeline configuration (normalisation / augmentation switches).
        freq_mask: optional per-bin weights of length F, multiplied into every S21 curve.
        train: when True, random S21 perturbations are drawn on every access.

    Raises:
        ValueError: if the arrays disagree with each other, with ``cfg``, or with ``freq_mask``.
    """

    def __init__(self, npz_path: str, cfg: CONFIG, freq_mask: Optional[np.ndarray] = None, train: bool = True):
        # The context manager closes the NpzFile; astype() materialises each array first.
        with np.load(npz_path) as data:
            self.waveforms = data['waveforms'].astype(np.float32)  # [N,T]
            self.s21 = data['s21'].astype(np.float32)              # [N,F]
            self.counts = data['counts'].astype(np.float32)        # [N]
        self.cfg = cfg
        self.freq_mask = np.asarray(freq_mask, dtype=np.float32) if freq_mask is not None else None
        self.train = train

        # Explicit exceptions rather than `assert`, so the checks survive `python -O`.
        if self.waveforms.ndim != 2 or self.waveforms.shape[1] != cfg.waveform_len:
            raise ValueError(
                f"Mismatch waveform_len: expected [N, {cfg.waveform_len}], got {self.waveforms.shape}")
        if self.s21.ndim != 2 or self.s21.shape[1] != cfg.s21_len:
            raise ValueError(f"Mismatch s21_len: expected [N, {cfg.s21_len}], got {self.s21.shape}")
        n = self.waveforms.shape[0]
        if self.s21.shape[0] != n or self.counts.shape[0] != n:
            raise ValueError(
                f"Row count mismatch: waveforms {n}, s21 {self.s21.shape[0]}, counts {self.counts.shape[0]}")
        if self.freq_mask is not None and self.freq_mask.shape != (cfg.s21_len,):
            raise ValueError(f"freq_mask must have shape ({cfg.s21_len},), got {self.freq_mask.shape}")

    def __len__(self):
        return self.waveforms.shape[0]

    def __getitem__(self, idx):
        wf = standardize_waveform(self.waveforms[idx])
        s21 = self.s21[idx]
        y = self.counts[idx]

        if self.cfg.s21_in_db:
            s21 = normalize_spectrum_db(s21, peak_to_0db=self.cfg.s21_peak_to_0db)
            if self.train:
                s21 = augment_spec_db(s21, self.cfg)
        elif self.train:
            # Linear amplitude: the augmentation parameters are in dB, so draw the same
            # shift/tilt/noise as a dB offset and apply it as an amplitude gain 10**(dB/20)
            # instead of adding dB values to linear data.
            delta_db = augment_spec_db(np.zeros_like(s21), self.cfg)
            s21 = s21 * (10.0 ** (delta_db / 20.0)).astype(np.float32)

        # Feature-level down-weighting of unreliable frequency bins.
        if self.freq_mask is not None:
            s21 = s21 * self.freq_mask

        wf_t = torch.from_numpy(wf)[None, :]  # [1,T]
        s21_t = torch.from_numpy(s21)         # [F]
        y_t = torch.tensor(y, dtype=torch.float32)
        return wf_t, s21_t, y_t


# ------------------------------
# Model
# ------------------------------
class WaveformEncoder(nn.Module):
    """1-D CNN mapping a waveform batch [B, in_ch, T] to a feature vector [B, hidden].

    Four stride-2 convolutions, each followed by ReLU, widen the channels to 128. Each one
    halves the time axis, rounded up. A global average pool over time then removes the
    dependence on T before a linear projection. ``t`` is accepted for interface
    compatibility but is not used.
    """

    def __init__(self, in_ch: int = 1, hidden: int = 64, t: int = 2048):
        super().__init__()
        layers = []
        channels = in_ch
        # (out_channels, kernel_size); odd kernels with padding k // 2.
        for out_channels, kernel in ((32, 9), (64, 7), (96, 5), (128, 5)):
            layers.append(nn.Conv1d(channels, out_channels, kernel_size=kernel, stride=2, padding=kernel // 2))
            layers.append(nn.ReLU())
            channels = out_channels
        layers.append(nn.AdaptiveAvgPool1d(1))
        self.net = nn.Sequential(*layers)
        self.proj = nn.Linear(channels, hidden)

    def forward(self, x):  # x: [B, in_ch, T]
        pooled = self.net(x).squeeze(-1)  # [B, 128, 1] -> [B, 128]
        return self.proj(pooled)           # [B, hidden]


class S21Encoder(nn.Module):
    """Fully connected encoder for an S21 curve: [B, f_len] -> [B, hidden].

    Widths are f_len -> 256 -> 128 -> hidden, with ReLU between layers and none after the last.
    """

    def __init__(self, f_len: int = 1024, hidden: int = 64):
        super().__init__()
        layers = [nn.Linear(f_len, 256)]
        for width_in, width_out in ((256, 128), (128, hidden)):
            layers += [nn.ReLU(), nn.Linear(width_in, width_out)]
        self.net = nn.Sequential(*layers)

    def forward(self, s21):  # s21: [B, f_len]
        encoded = self.net(s21)
        return encoded


class Predictor(nn.Module):
    """Regression head: concatenate waveform and S21 features, output one scalar per sample."""

    def __init__(self, wf_hidden: int = 64, s_hidden: int = 64):
        super().__init__()
        joint_width = wf_hidden + s_hidden
        self.mlp = nn.Sequential(
            nn.Linear(joint_width, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, wf_feat, s_feat):
        """Map [B, wf_hidden] and [B, s_hidden] features to [B] predictions."""
        joint = torch.cat((wf_feat, s_feat), dim=-1)
        return self.mlp(joint)[..., 0]  # drop the trailing size-1 output dim


class Sim2RealModel(nn.Module):
    """Waveform CNN and S21 MLP feature extractors joined by an MLP head that regresses counts."""

    def __init__(self, cfg: CONFIG):
        super().__init__()
        feat_dim = 64  # width of each feature vector fed to the head
        self.wf_enc = WaveformEncoder(in_ch=cfg.wf_channels, hidden=feat_dim, t=cfg.waveform_len)
        self.s_enc = S21Encoder(f_len=cfg.s21_len, hidden=feat_dim)
        self.head = Predictor(wf_hidden=feat_dim, s_hidden=feat_dim)

    def forward(self, wf, s21):
        """wf: [B, wf_channels, T]; s21: [B, s21_len] -> [B] predicted counts."""
        return self.head(self.wf_enc(wf), self.s_enc(s21))


# ------------------------------
# Training / Eval
# ------------------------------

def _run_epoch(model: nn.Module, loader: DataLoader, loss_fn: nn.Module, device: str,
               optim: Optional[torch.optim.Optimizer] = None) -> float:
    """Run one pass over ``loader`` and return the per-sample mean loss.

    Weights are updated only when ``optim`` is given (the training pass). Otherwise the model
    is in eval mode with gradients disabled. Returns NaN if the loader yields no batches.
    """
    training = optim is not None
    model.train(training)
    total, seen = 0.0, 0
    with torch.set_grad_enabled(training):
        for wf, s21, y in loader:
            wf, s21, y = wf.to(device), s21.to(device), y.to(device)
            if training:
                optim.zero_grad()
            loss = loss_fn(model(wf, s21), y)
            if training:
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
                optim.step()
            batch = y.size(0)
            total += loss.item() * batch
            seen += batch
    # Divide by the samples actually processed; with drop_last=True that is < len(dataset).
    return total / seen if seen else float("nan")


def train(cfg: CONFIG):
    """Train ``Sim2RealModel`` on the simulated sets in ``cfg``.

    Each epoch is followed by validation. The checkpoint with the lowest validation MSE is
    saved to ``artifacts/best.pt`` as ``{"model": state_dict, "cfg": dict}``.

    Raises:
        ValueError: if the training set is smaller than one batch (``drop_last=True`` would
            yield no batches) or the validation set is empty.
    """
    set_seed()

    # Optional mask: bins above cfg.freq_mask_cut_mhz are scaled by cfg.mask_attenuation.
    # build_freq_mask_hz only compares values, so the MHz grid and MHz cut can be passed
    # directly; a lower bound of -inf leaves only the upper cut active. Returns None when
    # cfg.freq_grid_mhz is None.
    freq_mask = build_freq_mask_hz(cfg.freq_grid_mhz, -np.inf, cfg.freq_mask_cut_mhz, cfg.mask_attenuation)

    ds_tr = SimDataset(cfg.train_npz, cfg, freq_mask=freq_mask, train=True)
    ds_va = SimDataset(cfg.val_npz, cfg, freq_mask=freq_mask, train=False)

    dl_tr = DataLoader(ds_tr, batch_size=cfg.batch_size, shuffle=True, num_workers=2, drop_last=True)
    dl_va = DataLoader(ds_va, batch_size=cfg.batch_size, shuffle=False, num_workers=2)
    if len(dl_tr) == 0:
        raise ValueError(
            f"Training set has {len(ds_tr)} samples, fewer than batch_size={cfg.batch_size}; "
            "no training batches would be produced")
    if len(ds_va) == 0:
        raise ValueError("Validation set is empty")

    model = Sim2RealModel(cfg).to(cfg.device)
    optim = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='min', factor=0.5, patience=5)
    loss_fn = nn.MSELoss()

    best_val = float('inf')
    best_path = Path("artifacts/best.pt")
    best_path.parent.mkdir(parents=True, exist_ok=True)
    # cfg.freq_grid_mhz may be an ndarray; store it as a list so the checkpoint holds only
    # builtin types.
    cfg_record = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in vars(cfg).items()}

    for epoch in range(1, cfg.epochs + 1):
        tr_loss = _run_epoch(model, dl_tr, loss_fn, cfg.device, optim)
        va_loss = _run_epoch(model, dl_va, loss_fn, cfg.device)
        sched.step(va_loss)

        print(f"Epoch {epoch:03d} | train {tr_loss:.4f} | val {va_loss:.4f}")
        if va_loss < best_val:
            best_val = va_loss
            torch.save({"model": model.state_dict(), "cfg": cfg_record}, best_path)
            print(f"  ↳ Saved best to {best_path} (val {best_val:.4f})")


@torch.no_grad()
def predict_with_real(cfg: CONFIG):
    """Predict counts for every validation waveform using the measured gain-0 S21 curve.

    Loads the weights from ``artifacts/best.pt`` and preprocesses the measured S21 the same
    way as the training curves (dB normalisation, frequency mask). It then runs the model in
    mini-batches of ``cfg.batch_size`` and saves the [N] predictions to
    ``artifacts/pred_counts_real_gain0.npy``.

    NOTE(review): the model is rebuilt from the ``cfg`` passed in, not from the cfg stored in
    the checkpoint; they must agree on ``wf_channels`` / ``s21_len``.

    Raises:
        ValueError: if the measured S21 curve does not have shape ``(cfg.s21_len,)``.
    """
    # NOTE(review): recent torch versions default torch.load to weights_only=True. A checkpoint
    # whose stored cfg contains a NumPy array may then refuse to load; confirm for your version.
    ckpt = torch.load("artifacts/best.pt", map_location=cfg.device)
    model = Sim2RealModel(cfg).to(cfg.device)
    model.load_state_dict(ckpt["model"])
    model.eval()

    # Measured S21, expected on the same frequency grid and in the same units as the
    # simulated curves.
    real_s21 = np.load(cfg.s21_real_path).astype(np.float32)
    if real_s21.shape != (cfg.s21_len,):
        raise ValueError(
            f"Measured S21 must have shape ({cfg.s21_len},) after resampling, got {real_s21.shape}")
    if cfg.s21_in_db:
        real_s21 = normalize_spectrum_db(real_s21, peak_to_0db=cfg.s21_peak_to_0db)

    # Same mask as in training: scale bins above freq_mask_cut_mhz (no lower cut).
    freq_mask = build_freq_mask_hz(cfg.freq_grid_mhz, -np.inf, cfg.freq_mask_cut_mhz, cfg.mask_attenuation)
    if freq_mask is not None:
        real_s21 = real_s21 * freq_mask

    real_s21_t = torch.from_numpy(real_s21)[None, :].to(cfg.device)  # [1,F]

    # Example: predict counts for each waveform in the validation set under the real S21.
    with np.load(cfg.val_npz) as data:
        waveforms = data['waveforms'].astype(np.float32)  # [N,T]

    step = max(1, cfg.batch_size)
    chunks = []
    # Mini-batches keep device memory bounded regardless of N.
    for start in range(0, waveforms.shape[0], step):
        wf_np = np.stack([standardize_waveform(w) for w in waveforms[start:start + step]])  # [b,T]
        wf_batch = torch.from_numpy(wf_np)[:, None, :].to(cfg.device)                     # [b,1,T]
        s21_batch = real_s21_t.expand(wf_batch.size(0), -1)  # [b,F] view, no copy
        chunks.append(model(wf_batch, s21_batch).cpu().numpy())
    yhat = np.concatenate(chunks) if chunks else np.empty(0, dtype=np.float32)  # predicted counts under real S21

    out_path = Path("artifacts/pred_counts_real_gain0.npy")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    np.save(out_path, yhat)
    print(f"Saved predictions under real S21 to {out_path}")


# ------------------------------
# Main
# ------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", action="store_true", help="Train the model on simulated data")
    parser.add_argument("--predict-real", action="store_true", help="Predict counts for val set using real S21")
    # parse_known_args rather than parse_args: when this file runs as a notebook cell,
    # sys.argv carries the kernel's own flags (e.g. "-f <connection file>"), which
    # parse_args() would reject by exiting.
    args, unknown = parser.parse_known_args()
    if unknown:
        print(f"Ignoring unrecognized arguments: {unknown}")

    cfg = CONFIG()

    if not (args.train or args.predict_real):
        # Nothing requested: show usage instead of finishing silently.
        parser.print_help()
    if args.train:
        train(cfg)
    if args.predict_real:
        predict_with_real(cfg)
