# In [None]:
# LTEV_LateFusion_Sweep.py
# One script (single run produces all artifacts):
#   - Preload all frames once per SNR (shared)
#   - BASELINE RAW IQ: train ONE single-frame model per fold (independent of m),
#       then compute test-frame logits once, and late-fusion Acc(m) for all m in M_LIST.
#   - XFR: for each m, build XFR blocks by m, train one model per fold (row-level),
#       compute test row-logits once, evaluate S in S_LIST with rule S>m => all.
#   - Output per SNR:
#       * long results CSV (fold-level)
#       * Table1 (m@99) CSV + LaTeX
#       * Table2 (fixed q) CSV + LaTeX
#       * Curve data CSV + Plot PNG
#       * Raw artifacts: per-fold logits npz, splits, block definitions

import os
import glob
import json
import csv
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm
import h5py

# =========================
# Global config (only edit paths / SNR_LIST if needed)
# Module-level constants; most of them are read as globals by the functions below.
# =========================
DATA_PATH = "E:/rf_datasets/"   # *.mat folder

# PHY
FS = 5e6          # sample rate (Hz), passed as fs_hz to apply_doppler_shift
FC = 5.9e9        # carrier frequency (Hz), passed as fc_hz to compute_doppler_shift
V_KMH = 120       # speed (km/h) that sets the Doppler shift
APPLY_DOPPLER = True   # preload_by_snr: apply the Doppler frequency offset to every frame
APPLY_AWGN = True      # preload_by_snr: add AWGN at the current sweep SNR to every frame

# Sweep
SNR_LIST = [20, -5]  # e.g. list(range(20, -45, -5))
M_LIST = [1, 4, 8, 16, 32, 64, 128, 256]   # block sizes m (frames combined per fused decision)
S_LIST = [1, 4, 8, 16, 32, "all"]           # XFR: rows fused per block; an int S > m is evaluated as "all"
# NOTE(review): Q_LIST is not read in the code visible here; presumably the fixed
# inference budgets q for Table 2 -- confirm against the table-building code.
Q_LIST = [1, 4, 8, 16, 32]
ACC_TARGET = 99.0                           # accuracy (%) target used for the m@target table

# Selection modes
ROW_SELECT_MODE = "linspace"   # for S rows out of L (=288)
BASELINE_BLOCK_SHUFFLE = True  # baseline: shuffle frames within each TX before grouping into blocks
# Baseline test blocks for a given m are shuffled with seed SEED + BASELINE_BLOCK_SEED_PER_M + m,
# so they are built once and reused by every fold (the seed differs per m).
BASELINE_BLOCK_SEED_PER_M = 777  # keep blocks deterministic across folds

# Safety caps (XFR dataset can explode for small m)
# Both are applied by cap_blocks_balanced to the XFR blocks built for each m.
MAX_BLOCKS_PER_CLASS = 1200     # None to disable (not recommended for XFR small m)
MAX_TOTAL_BLOCKS = 12000        # None to disable

# Train
SEED = 42          # seeds the global RNGs and is the random_state of the data splits
N_SPLITS = 5       # StratifiedKFold folds over the train/val portion
BATCH_SIZE = 64
NUM_EPOCHS = 300   # upper bound; early stopping can end training sooner
LR = 1e-4          # Adam learning rate (halved every 10 epochs by StepLR)
WEIGHT_DECAY = 1e-3
IN_PLANES = 64     # stem width of ResNet18_1D
DROPOUT = 0.5      # dropout probability inside each residual block
PATIENCE = 5       # epochs without validation improvement before stopping

# Early stopping metric
# baseline: val loss (single-frame) (independent of m)
# xfr: fused val acc at S="all"
XFR_EARLY_STOP_S = "all"

# CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# =========================
# Repro
# =========================
def set_seed(sd=42):
    """Seed NumPy's global RNG and torch's CPU and all-CUDA-device RNGs with ``sd``."""
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(sd)

# Seed once at import time: the AWGN drawn in add_awgn (np.random) and the torch model
# initialisation both use these global RNGs, so runs are repeatable for a given SEED.
set_seed(SEED)


# =========================
# PHY helpers
# =========================
def compute_doppler_shift(v_kmh, fc_hz):
    """Return the Doppler shift in Hz for speed ``v_kmh`` (km/h) at carrier ``fc_hz`` (Hz).

    f_d = (v / c) * f_c, with v converted from km/h to m/s and c = 3e8 m/s.
    """
    speed_of_light = 3e8
    return ((v_kmh / 3.6) / speed_of_light) * fc_hz

def apply_doppler_shift(signal_c, fd_hz, fs_hz):
    """Rotate ``signal_c`` by a constant frequency offset of ``fd_hz`` Hz.

    Sample k along the last axis is multiplied by exp(j*2*pi*fd_hz*k/fs_hz), i.e. time
    starts at t = 0 for every call.
    """
    n_samples = signal_c.shape[-1]
    t_sec = np.arange(n_samples, dtype=np.float64) / fs_hz
    rotation = np.exp(1j * 2 * np.pi * fd_hz * t_sec)
    return signal_c * rotation

def add_awgn(signal_c, snr_db):
    """Return ``signal_c`` plus circular complex white Gaussian noise at ``snr_db`` dB SNR.

    Signal power is the mean |x|^2 of the input; the noise power is split equally between
    the real and imaginary parts. Noise is drawn from NumPy's global RNG (real part first,
    then imaginary part), so results depend on the global seed.
    """
    sig_power = np.mean(np.abs(signal_c) ** 2)
    noise_power = sig_power / (10 ** (snr_db / 10))
    scale = np.sqrt(noise_power / 2)
    shape = signal_c.shape
    re_part = np.random.randn(*shape)
    im_part = np.random.randn(*shape)
    return signal_c + scale * (re_part + 1j * im_part)

def power_normalize(signal_c, eps=1e-12):
    """Scale ``signal_c`` to unit mean power (mean |x|^2 == 1); ``eps`` avoids 0/0 on silence."""
    rms = np.sqrt(np.mean(np.abs(signal_c) ** 2))
    return signal_c / (rms + eps)


# =========================
# HDF5 helpers
# =========================
def read_tx_id_str(rfDataset):
    """Decode ``rfDataset["txID"]`` into a transmitter-id string.

    Every element (flattened in C order) is treated as a character code; zero codes are
    dropped wherever they occur. Presumably the field is a MATLAB char array stored as
    integer codes -- confirm against the dataset writer.
    """
    codes = (int(code) for code in rfDataset["txID"][:].flatten())
    return "".join(map(chr, filter(None, codes)))

def read_dmrs_complex(rfDataset):
    """Read ``rfDataset["dmrs"]`` as a complex array.

    Handles three storage layouts:
      * an HDF5 group with separate "real" and "imag" datasets,
      * a compound dataset whose dtype has "real" and "imag" fields,
      * anything else, returned as read (assumed to already be complex or real).
    """
    node = rfDataset["dmrs"]
    if isinstance(node, h5py.Group):
        return node["real"][:] + 1j * node["imag"][:]

    data = node[:]
    fields = data.dtype.fields if hasattr(data, "dtype") else None
    is_compound_complex = fields is not None and "real" in fields and "imag" in fields
    return data["real"] + 1j * data["imag"] if is_compound_complex else data


# =========================
# Preload all files once per SNR (shared by baseline + XFR)
# Returns:
#   tx_list, label_to_idx
#   tx_to_seqs: {tx: [arr_file1(N,L,2), ...]}
#   baseline_global_frames: X_all (Ntot,L,2), y_all (Ntot,)
#   tx_to_global_indices: {tx: np.array(indices in X_all)}
# =========================
def preload_by_snr(mat_folder, snr_db):
    """Load every *.mat file in ``mat_folder`` once and apply the PHY impairments for ``snr_db``.

    Pass 1 reads only the txID of each file to group files by transmitter; class labels are
    the sorted txIDs. Pass 2 loads each file's DMRS matrix and, frame by frame, applies
    power normalisation, the Doppler rotation (if APPLY_DOPPLER) and AWGN at ``snr_db``
    (if APPLY_AWGN). Frames are processed in a fixed order (sorted TX, sorted file, row),
    and the AWGN uses NumPy's global RNG, so results depend on that order and on the seed.

    Args:
        mat_folder: folder searched (non-recursively) for "*.mat" HDF5 files.
        snr_db: SNR in dB used for the added noise.

    Returns:
        Tuple (tx_list, label_to_idx, tx_to_seqs, X_all, y_all, tx_to_global_indices,
        seg_len_ref):
          * tx_list: sorted txID strings; label_to_idx maps each to its class index.
          * tx_to_seqs: {tx: [float32 array (N_file, L, 2) per file]} (real, imag in last axis).
          * X_all (Ntot, L, 2) float32 / y_all (Ntot,) int64: all frames concatenated.
          * tx_to_global_indices: {tx: int64 row indices of that TX's frames in X_all}.
          * seg_len_ref: the common frame length L.

    Raises:
        RuntimeError: if no .mat file is found, a DMRS array is not 2-D, or files disagree
        on the frame length.

    Note: X_all is a separate copy of the arrays in tx_to_seqs (np.concatenate allocates),
    so the preprocessed data is held roughly twice in memory.
    """
    mat_files = glob.glob(os.path.join(mat_folder, "*.mat"))
    if len(mat_files) == 0:
        raise RuntimeError(f"No .mat in {mat_folder}")

    # One constant Doppler offset shared by all frames.
    fd = compute_doppler_shift(V_KMH, FC)

    # first pass: group by tx
    tx_to_files = {}
    for fp in tqdm(mat_files, desc=f"[SNR{snr_db}] Scan txID"):
        with h5py.File(fp, "r") as f:
            rfDataset = f["rfDataset"]
            tx = read_tx_id_str(rfDataset)
        tx_to_files.setdefault(tx, []).append(fp)

    # Sorting makes label indices independent of glob() ordering.
    tx_list = sorted(tx_to_files.keys())
    label_to_idx = {tx: i for i, tx in enumerate(tx_list)}
    num_classes = len(tx_list)
    print(f"[SNR{snr_db}] classes={num_classes} | label_to_idx={label_to_idx}")

    tx_to_seqs = {}
    X_all_list = []
    y_all_list = []
    tx_to_global_indices = {}

    global_cursor = 0  # next row index in the (future) concatenated X_all
    seg_len_ref = None  # frame length L of the first file; all others must match

    for tx in tx_list:
        files = sorted(tx_to_files[tx])
        seqs = []
        tx_global_indices = []

        for fp in tqdm(files, desc=f"[SNR{snr_db}] Load TX={tx}", leave=False):
            with h5py.File(fp, "r") as f:
                rfDataset = f["rfDataset"]
                dmrs_complex = read_dmrs_complex(rfDataset)  # (N,L) complex

            if dmrs_complex.ndim != 2:
                raise RuntimeError(f"dmrs shape error: {dmrs_complex.shape} | {fp}")
            N, L = dmrs_complex.shape
            if seg_len_ref is None:
                seg_len_ref = L
            else:
                if L != seg_len_ref:
                    raise RuntimeError(f"seg_len mismatch: {seg_len_ref} vs {L} in {fp}")

            # Per-frame impairment chain; the per-row loop keeps the np.random draw order fixed.
            processed = np.empty((N, L, 2), dtype=np.float32)
            for i in range(N):
                sig = dmrs_complex[i, :]
                sig = power_normalize(sig)
                if APPLY_DOPPLER:
                    sig = apply_doppler_shift(sig, fd, FS)
                if APPLY_AWGN:
                    sig = add_awgn(sig, snr_db)

                processed[i, :, 0] = sig.real.astype(np.float32)
                processed[i, :, 1] = sig.imag.astype(np.float32)

            seqs.append(processed)

            # baseline global frames
            X_all_list.append(processed)
            y_all_list.append(np.full((N,), label_to_idx[tx], dtype=np.int64))

            # record global indices for this chunk
            idx = np.arange(global_cursor, global_cursor + N, dtype=np.int64)
            tx_global_indices.append(idx)
            global_cursor += N

        tx_to_seqs[tx] = seqs
        tx_to_global_indices[tx] = np.concatenate(tx_global_indices, axis=0)

    # NOTE: astype() copies by default, so peak memory briefly holds an extra copy here.
    X_all = np.concatenate(X_all_list, axis=0).astype(np.float32)   # (Ntot,L,2)
    y_all = np.concatenate(y_all_list, axis=0).astype(np.int64)

    print(f"[SNR{snr_db}] Global frames: X={X_all.shape} y={y_all.shape} | seg_len={seg_len_ref}")
    return tx_list, label_to_idx, tx_to_seqs, X_all, y_all, tx_to_global_indices, seg_len_ref


# =========================
# Utility: pick S rows (for XFR row-fusion)
# =========================
def pick_rows(L, S, mode="linspace"):
    """Choose which S of the L rows of an XFR block are fused.

    Args:
        L: number of rows available.
        S: number of rows to keep. "all", None, or any integer (Python or NumPy) >= L means
            "use every row". Other values are converted with int().
        mode: selection strategy; only "linspace" (evenly spaced rows from 0 to L-1) exists.

    Returns:
        None when every row should be used, otherwise a sorted int array of S distinct row
        indices in [0, L).

    Raises:
        ValueError: if S < 1 (an empty selection would make the fused accuracy NaN) or
            ``mode`` is unknown.
    """
    if S == "all" or S is None or (isinstance(S, (int, np.integer)) and S >= L):
        return None
    S = int(S)
    if S < 1:
        raise ValueError(f"S must be >= 1 or 'all', got {S}")
    if mode == "linspace":
        idx = np.round(np.linspace(0, L - 1, num=S)).astype(int)
        idx = np.unique(np.clip(idx, 0, L - 1))
        if idx.size < S:
            # Rounding produced duplicates: top up with the lowest unused rows.
            # (setdiff1d is computed once, instead of rebuilding a set per candidate row.)
            unused = np.setdiff1d(np.arange(L), idx, assume_unique=True)
            idx = np.concatenate([idx, unused[: S - idx.size]])
        return np.sort(idx)[:S]
    raise ValueError(f"Unknown mode={mode}")


# =========================
# XFR block builder by m
# =========================
def build_xfr_blocks_by_m(tx_list, label_to_idx, tx_to_seqs, m):
    """Build XFR blocks of ``m`` frames per transmitter.

    For each TX, block b takes one frame from each of m consecutive files in round-robin
    order starting at file ``b % num_files`` (files are revisited when m > num_files), and
    each file contributes its frames in order without reuse. Building for a TX stops at
    the first block that cannot be completed. The m frames (L, 2) are stacked into an
    (L, m, 2) block, i.e. each of the L rows holds the same position across m frames.

    Args:
        tx_list: transmitters to process (defines output order).
        label_to_idx: {tx: class index}.
        tx_to_seqs: {tx: [array (N_file, L, 2) per file]}.
        m: frames per block; must be >= 1.

    Returns:
        (X_blocks float32 (Nblk, L, m, 2), y_blocks int64 (Nblk,)).

    Raises:
        ValueError: if m < 1.
        RuntimeError: if no transmitter yields a single complete block.
    """
    if m < 1:
        raise ValueError(f"m must be >= 1, got {m}")

    X_blocks = []
    y_blocks = []

    for tx in tx_list:
        seqs = tx_to_seqs[tx]
        num_files = len(seqs)
        if num_files == 0:
            # Nothing to draw from (and b % num_files would divide by zero).
            continue
        pointers = [0] * num_files  # next unused frame per file
        blocks_tx = []

        b = 0
        while True:
            start_file = b % num_files
            frames = []
            ok = True
            for t in range(m):
                fi = (start_file + t) % num_files
                pi = pointers[fi]
                if pi >= seqs[fi].shape[0]:
                    ok = False
                    break
                frames.append(seqs[fi][pi])  # (L,2)
                pointers[fi] += 1
            if not ok:
                break
            big = np.stack(frames, axis=0)            # (m,L,2)
            big = np.transpose(big, (1, 0, 2))        # (L,m,2)
            blocks_tx.append(big)
            b += 1

        if len(blocks_tx) == 0:
            continue

        X_tx = np.stack(blocks_tx, axis=0).astype(np.float32)  # (Btx,L,m,2)
        y_tx = np.full((X_tx.shape[0],), label_to_idx[tx], dtype=np.int64)

        X_blocks.append(X_tx)
        y_blocks.append(y_tx)

    if not X_blocks:
        raise RuntimeError(f"No XFR blocks could be built for m={m}; no TX has enough frames.")

    X_blocks = np.concatenate(X_blocks, axis=0)
    y_blocks = np.concatenate(y_blocks, axis=0)
    return X_blocks, y_blocks


# =========================
# Baseline block builder by m (on a given pool, not on global)
# pool positions are [0..Npool-1], grouped within each TX
# Returns: blocks_pos (Nblk,m) positions into pool, y_blocks (Nblk,)
# =========================
def build_baseline_blocks_from_pool(y_pool, num_classes, m, shuffle=True, seed=0):
    """Partition the positions of each class in ``y_pool`` into disjoint blocks of ``m``.

    Classes are visited in order 0..num_classes-1. Each non-empty class's positions are
    optionally shuffled with one shared ``np.random.default_rng(seed)`` and then cut
    into floor(count / m) blocks; leftover positions are dropped.

    Returns:
        (blocks_pos int64 (Nblk, m), y_blocks int64 (Nblk,)), or (None, None) when no class
        has at least m samples.
    """
    rng = np.random.default_rng(seed)
    block_parts, label_parts = [], []
    for cls in range(num_classes):
        members = np.nonzero(y_pool == cls)[0].astype(np.int64)
        if not members.size:
            continue
        if shuffle:
            rng.shuffle(members)
        n_full = members.size // m
        if n_full > 0:
            block_parts.append(members[: n_full * m].reshape(n_full, m))
            label_parts.append(np.full((n_full,), cls, dtype=np.int64))
    if not block_parts:
        return None, None
    return np.concatenate(block_parts, axis=0), np.concatenate(label_parts, axis=0)


# =========================
# Caps: enforce balanced session count for feasibility (XFR)
# =========================
def cap_blocks_balanced(X_blocks, y_blocks, num_classes):
    """Randomly subsample blocks so every class respects a common cap.

    Each class's block indices are shuffled (rng seeded with SEED + 999) and truncated to
    MAX_BLOCKS_PER_CLASS. If the result still exceeds MAX_TOTAL_BLOCKS, a single per-class
    quota is lowered (water-filling) until the total fits. Small classes keep all their
    blocks; large classes are cut to the same size. This replaces truncating a shuffled
    concatenation, which dropped a random number of blocks from each class and skewed
    class balance. When the total cap is not hit, the result matches the previous
    implementation exactly. The kept indices are shuffled before being returned.

    Args:
        X_blocks: array indexed along axis 0 by block.
        y_blocks: int class label per block.
        num_classes: labels considered are 0..num_classes-1.

    Returns:
        (X_blocks[kept], y_blocks[kept], kept), where kept indexes the input arrays.

    Raises:
        RuntimeError: if no class has any block.
    """
    rng = np.random.default_rng(SEED + 999)

    per_class = []
    for c in range(num_classes):
        idx = np.where(y_blocks == c)[0]
        if idx.size == 0:
            continue
        rng.shuffle(idx)
        if MAX_BLOCKS_PER_CLASS is not None:
            idx = idx[: min(idx.size, MAX_BLOCKS_PER_CLASS)]
        per_class.append(idx)

    if len(per_class) == 0:
        raise RuntimeError("No blocks after per-class capping.")

    sizes = np.array([idx.size for idx in per_class], dtype=np.int64)
    if MAX_TOTAL_BLOCKS is not None and int(sizes.sum()) > MAX_TOTAL_BLOCKS:
        # Largest quota q with sum(min(size_c, q)) <= MAX_TOTAL_BLOCKS (binary search).
        lo, hi = 0, int(sizes.max())
        while lo < hi:
            mid = (lo + hi + 1) // 2
            if int(np.minimum(sizes, mid).sum()) <= MAX_TOTAL_BLOCKS:
                lo = mid
            else:
                hi = mid - 1
        quota = max(lo, 1)  # keep at least one block per class in degenerate configs
        per_class = [idx[:quota] for idx in per_class]

    kept = np.concatenate(per_class, axis=0)
    rng.shuffle(kept)

    # Safety net: only triggers when the total cap is smaller than the number of classes.
    if MAX_TOTAL_BLOCKS is not None and kept.size > MAX_TOTAL_BLOCKS:
        kept = kept[:MAX_TOTAL_BLOCKS]

    return X_blocks[kept], y_blocks[kept], kept


# =========================
# Model: ResNet18 1D (supports variable length via AdaptiveAvgPool)
#   - baseline input: (L,2) with L=288
#   - XFR row input: (m,2) with variable m
# =========================
class BasicBlock1D(nn.Module):
    """Two-conv 1D residual block: conv-BN-ReLU-dropout-conv-BN, plus shortcut, then ReLU.

    Args:
        in_planes: input channels.
        planes: output channels.
        stride: stride of the first conv (length is downsampled by it).
        downsample: optional module applied to the input on the shortcut path; it must
            produce the same shape as the main path when stride != 1 or channels change.
        dropout: dropout probability applied after the first activation.
    """

    # Channel multiplier of this block type (1 for basic blocks); not read elsewhere in
    # the code visible here.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None, dropout=0.0):
        super().__init__()
        # Submodules are created in this fixed order so seeded initialisation is stable.
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        h = self.dropout(self.relu(self.bn1(self.conv1(x))))
        h = self.bn2(self.conv2(h))
        h += shortcut
        return self.relu(h)

class ResNet18_1D(nn.Module):
    """ResNet-18-style 1D classifier over 2-channel (real/imag) sequences.

    Input is (B, T, 2), channels last; forward permutes it to (B, 2, T). The network is:
    a stride-2 7-tap stem conv, a stride-2 max-pool, four stages of two BasicBlock1D
    (64/128/256/512 channels, stages 2-4 with stride 2), global average pooling and a
    linear head. AdaptiveAvgPool1d(1) makes the head independent of T, which is why the
    same class serves baseline frames (T = L) and XFR rows (T = m).

    Args:
        num_classes: number of output logits.
        in_planes: channel width of the stem conv.
        dropout: dropout probability inside every residual block.
    """

    def __init__(self, num_classes, in_planes=64, dropout=0.0):
        super().__init__()
        # Running channel count; _make_layer reads it and advances it stage by stage.
        self.in_planes = in_planes
        self.conv1 = nn.Conv1d(2, in_planes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 2, stride=1, dropout=dropout)
        self.layer2 = self._make_layer(128, 2, stride=2, dropout=dropout)
        self.layer3 = self._make_layer(256, 2, stride=2, dropout=dropout)
        self.layer4 = self._make_layer(512, 2, stride=2, dropout=dropout)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, planes, blocks, stride, dropout):
        """Build one stage: a (possibly strided/projecting) block, then blocks-1 plain blocks."""
        needs_projection = stride != 1 or self.in_planes != planes
        shortcut = (
            nn.Sequential(
                nn.Conv1d(self.in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(planes),
            )
            if needs_projection
            else None
        )
        head = BasicBlock1D(self.in_planes, planes, stride, shortcut, dropout)
        self.in_planes = planes
        tail = [BasicBlock1D(planes, planes, dropout=dropout) for _ in range(blocks - 1)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        # (B, T, 2) -> (B, 2, T): Conv1d expects channels first.
        h = self.maxpool(self.relu(self.bn1(self.conv1(x.permute(0, 2, 1)))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
        return self.fc(self.avgpool(h).squeeze(-1))


# =========================
# Datasets
# =========================
class XFRRowDataset(Dataset):
    """Row-level view over XFR blocks.

    Item ``idx`` is row ``idx % L`` of block ``block_indices[idx // L]``; it returns that
    row (expected shape (m, 2)) together with the block's label. The block tensors are
    shared, not copied.

    Args:
        X_blocks: tensor indexed as X_blocks[block, row]; expected (Nblk, L, m, 2) on CPU.
        y_blocks: per-block labels, expected (Nblk,).
        block_indices: which blocks this dataset exposes.
        L: rows per block.
    """

    def __init__(self, X_blocks, y_blocks, block_indices, L):
        self.X = X_blocks
        self.y = y_blocks
        self.block_indices = np.array(block_indices, dtype=np.int64)
        self.L = int(L)

    def __len__(self):
        return self.L * self.block_indices.size

    def __getitem__(self, idx):
        pos, row = divmod(idx, self.L)
        blk = int(self.block_indices[pos])
        return self.X[blk, row], self.y[blk]

class FrameIndexDataset(Dataset):
    """Single-frame dataset exposing a subset of shared frame/label tensors.

    Item ``i`` is ``(X_all[frame_indices[i]], y_all[frame_indices[i]])``; the tensors are
    shared, not copied.

    Args:
        X_all: frames, expected (Ntot, L, 2) float tensor.
        y_all: labels, expected (Ntot,) long tensor.
        frame_indices: rows of X_all/y_all exposed by this dataset, in order.
    """

    def __init__(self, X_all, y_all, frame_indices):
        self.X, self.y = X_all, y_all
        self.idx = np.array(frame_indices, dtype=np.int64)

    def __len__(self):
        return self.idx.size

    def __getitem__(self, i):
        frame = int(self.idx[i])
        return self.X[frame], self.y[frame]


# =========================
# Inference helpers
# =========================
@torch.no_grad()
def infer_logits_for_indices(model, X_all_cpu, frame_indices, num_classes, batch_size=512):
    """Return float32 logits (len(frame_indices), num_classes) for the selected frames.

    Puts ``model`` in eval mode, then runs it on X_all_cpu[frame_indices] in consecutive
    chunks of ``batch_size`` moved to DEVICE. Row i of the result corresponds to
    frame_indices[i].
    """
    model.eval()
    order = np.array(frame_indices, dtype=np.int64)
    total = order.size
    logits_out = np.zeros((total, num_classes), dtype=np.float32)
    for start in range(0, total, batch_size):
        stop = min(total, start + batch_size)
        batch = X_all_cpu[order[start:stop]].to(DEVICE)  # (B,L,2)
        logits_out[start:stop] = model(batch).detach().cpu().numpy().astype(np.float32)
    return logits_out

@torch.no_grad()
def infer_xfr_row_logits_all(model, X_blk, num_classes, batch_blocks=8):
    """Return float32 per-row logits (Nblk, L, num_classes) for every XFR block.

    ``X_blk`` is expected to be a CPU tensor (Nblk, L, m, 2). Blocks are processed
    ``batch_blocks`` at a time: each chunk's rows are flattened to (B*L, m, 2), classified
    in one forward pass on DEVICE, and reshaped back per block.
    """
    model.eval()
    n_blocks, n_rows, width, n_ch = X_blk.shape
    result = np.zeros((n_blocks, n_rows, num_classes), dtype=np.float32)

    for lo in range(0, n_blocks, batch_blocks):
        hi = min(n_blocks, lo + batch_blocks)
        chunk = X_blk[lo:hi].to(DEVICE)                       # (B,L,m,2)
        n_chunk = chunk.shape[0]
        rows_flat = chunk.reshape(n_chunk * n_rows, width, n_ch)  # (B*L,m,2)
        chunk_logits = model(rows_flat).reshape(n_chunk, n_rows, num_classes)
        result[lo:hi] = chunk_logits.detach().cpu().numpy().astype(np.float32)

    return result

def xfr_fused_acc_from_logits(logits, y_true, m, S):
    """Fused block accuracy from XFR per-row logits.

    The logits of the rows chosen by pick_rows(L, S_eff, ROW_SELECT_MODE) are averaged
    per block and argmax-ed. An integer S larger than m is evaluated as "all" rows.

    Args:
        logits: (N, L, K) per-row logits.
        y_true: (N,) block labels.
        m: frames per block (only used for the S > m rule).
        S: rows to fuse (int or "all").

    Returns:
        (accuracy in %, K x K confusion matrix, str(S_eff)).
    """
    S_eff = "all" if (S != "all" and isinstance(S, int) and S > m) else S

    n_blocks, n_rows, n_classes = logits.shape
    rows = pick_rows(n_rows, S_eff, mode=ROW_SELECT_MODE)
    selected = logits if rows is None else logits[:, rows, :]
    fused = selected.mean(axis=1)

    y_pred = fused.argmax(axis=1)
    acc = 100.0 * (y_pred == y_true).mean()
    cm = confusion_matrix(y_true, y_pred, labels=list(range(n_classes)))
    return acc, cm, str(S_eff)

def baseline_fused_acc_from_frame_logits(frame_logits, blocks_pos, y_blocks, num_classes):
    """Late-fusion accuracy of the baseline: average frame logits within each block.

    Args:
        frame_logits: (Npool, K) logits of the test-pool frames.
        blocks_pos: (Nblk, m) positions into frame_logits forming each block.
        y_blocks: (Nblk,) block labels.
        num_classes: K, fixes the confusion-matrix size.

    Returns:
        (accuracy in % as float, K x K confusion matrix).
    """
    block_scores = frame_logits[blocks_pos].mean(axis=1)  # (Nblk,K)
    predictions = block_scores.argmax(axis=1)
    accuracy = 100.0 * np.mean(predictions == y_blocks)
    conf_mat = confusion_matrix(y_blocks, predictions, labels=list(range(num_classes)))
    return float(accuracy), conf_mat


# =========================
# BASELINE: train once per fold (per SNR), evaluate ALL m on same test pool
# =========================
def train_baseline_shared_all_m(snr_db, baseline_dir, X_all_np, y_all_np, num_classes):
    """Train one single-frame RAW-IQ model per CV fold and score late fusion for every m.

    One stratified 75/25 frame-level split (random_state=SEED) fixes the test pool for all
    folds and all m. For each m in M_LIST the test pool is partitioned once into per-class
    blocks of m frames (saved under test_blocks/), so every fold is scored on identical
    blocks. StratifiedKFold(N_SPLITS) then splits the remaining frames into train/val. Each
    fold trains ResNet18_1D (Adam + StepLR) with early stopping on validation loss, computes
    the test-frame logits once, and evaluates fused accuracy (mean logits per block) for
    every m.

    Args:
        snr_db: SNR label used in logs and artifacts.
        baseline_dir: output directory (created if missing).
        X_all_np: frames, presumably the (Ntot, L, 2) X_all from preload_by_snr.
        y_all_np: (Ntot,) integer class labels.
        num_classes: number of classes K.

    Returns:
        List of dicts, one per (fold, m), with keys method, snr_db, m, S, S_eff, fold, acc,
        val_best (here: the best validation loss of that fold).

    Raises:
        RuntimeError: if some m in M_LIST yields no test blocks.
    """
    os.makedirs(baseline_dir, exist_ok=True)

    X_all = torch.from_numpy(np.ascontiguousarray(X_all_np)).float()  # CPU
    y_all = torch.from_numpy(np.ascontiguousarray(y_all_np)).long()   # CPU

    # Frame-level split (fixed for all m)
    idx_all = np.arange(y_all_np.shape[0], dtype=np.int64)
    trval_idx, test_idx = train_test_split(idx_all, test_size=0.25, random_state=SEED, stratify=y_all_np)

    # Labels of the test pool, in test_idx order; block positions below index into this order.
    y_test_pool = y_all_np[test_idx].astype(np.int64)

    np.savez_compressed(os.path.join(baseline_dir, "split_frames.npz"),
                        trval_idx=trval_idx, test_idx=test_idx, y_test_pool=y_test_pool)

    # Pre-build deterministic test blocks for each m (reused across folds)
    blocks_dir = os.path.join(baseline_dir, "test_blocks")
    os.makedirs(blocks_dir, exist_ok=True)
    test_blocks_by_m = {}
    for m in M_LIST:
        blocks_pos, y_blk = build_baseline_blocks_from_pool(
            y_pool=y_test_pool,
            num_classes=num_classes,
            m=m,
            shuffle=BASELINE_BLOCK_SHUFFLE,
            seed=SEED + BASELINE_BLOCK_SEED_PER_M + m
        )
        if blocks_pos is None:
            raise RuntimeError(f"[BASE] No test blocks for m={m} (too few samples in some class?)")
        test_blocks_by_m[m] = (blocks_pos, y_blk)
        np.savez_compressed(os.path.join(blocks_dir, f"blocks_m{m:03d}.npz"),
                            blocks_pos=blocks_pos.astype(np.int32),
                            y_blocks=y_blk.astype(np.int64),
                            m=np.int32(m))

    # CV on train_val frames
    skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
    y_trval = y_all_np[trval_idx]

    fold_rows = []

    for fold, (tr_i, va_i) in enumerate(skf.split(trval_idx, y_trval), start=1):
        fold_dir = os.path.join(baseline_dir, f"fold{fold}")
        os.makedirs(fold_dir, exist_ok=True)

        # tr_i / va_i are positions within trval_idx; map them back to global frame indices.
        train_idx = trval_idx[tr_i]
        val_idx   = trval_idx[va_i]

        train_ds = FrameIndexDataset(X_all, y_all, train_idx)
        val_ds   = FrameIndexDataset(X_all, y_all, val_idx)
        # NOTE(review): drop_last=True yields zero training batches if the fold has fewer
        # than BATCH_SIZE frames.
        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
        val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        model = ResNet18_1D(num_classes=num_classes, in_planes=IN_PLANES, dropout=DROPOUT).to(DEVICE)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

        best_val_loss = float("inf")
        best_state = None  # CPU copy of the weights with the lowest val loss so far
        patience_cnt = 0

        log_csv = os.path.join(fold_dir, "train_log.csv")
        with open(log_csv, "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(["epoch", "train_loss", "train_acc", "val_loss", "val_acc"])

        for epoch in range(1, NUM_EPOCHS + 1):
            # train
            model.train()
            tr_loss = 0.0
            tr_cor, tr_tot = 0, 0
            for xb, yb in train_loader:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                optimizer.zero_grad()
                logits = model(xb)
                loss = criterion(logits, yb)
                loss.backward()
                optimizer.step()
                tr_loss += loss.item()
                pred = torch.argmax(logits, dim=1)
                tr_cor += (pred == yb).sum().item()
                tr_tot += yb.size(0)
            # Mean of per-batch losses (not weighted by batch size).
            tr_loss /= max(len(train_loader), 1)
            tr_acc = 100.0 * tr_cor / max(tr_tot, 1)

            # val
            model.eval()
            va_loss = 0.0
            va_cor, va_tot = 0, 0
            with torch.no_grad():
                for xb, yb in val_loader:
                    xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                    logits = model(xb)
                    loss = criterion(logits, yb)
                    va_loss += loss.item()
                    pred = torch.argmax(logits, dim=1)
                    va_cor += (pred == yb).sum().item()
                    va_tot += yb.size(0)
            va_loss /= max(len(val_loader), 1)
            va_acc = 100.0 * va_cor / max(va_tot, 1)

            print(f"[BASE][SNR{snr_db}] Fold{fold} Ep{epoch:03d} | "
                  f"TrainAcc {tr_acc:.2f}% | ValAcc {va_acc:.2f}% | ValLoss {va_loss:.4f}")

            with open(log_csv, "a", newline="", encoding="utf-8") as f:
                w = csv.writer(f)
                w.writerow([epoch, tr_loss, tr_acc, va_loss, va_acc])

            # early stop by val loss (independent of m)
            if va_loss < best_val_loss - 1e-6:
                best_val_loss = va_loss
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                patience_cnt = 0
            else:
                patience_cnt += 1
                if patience_cnt >= PATIENCE:
                    break

            scheduler.step()

        # Restore the best-val-loss weights before testing.
        if best_state is not None:
            model.load_state_dict(best_state)

        torch.save(model.state_dict(), os.path.join(fold_dir, "best_model.pth"))

        # Compute TEST frame logits ONCE for this fold
        test_logits = infer_logits_for_indices(model, X_all, test_idx, num_classes, batch_size=512)  # (Ntest,K)

        np.savez_compressed(os.path.join(fold_dir, "test_frame_logits.npz"),
                            logits=test_logits.astype(np.float16),
                            y_true=y_test_pool.astype(np.int64),
                            test_idx=test_idx.astype(np.int64),
                            snr_db=np.int32(snr_db),
                            K=np.int32(num_classes))

        # Evaluate all m using prebuilt blocks (fast, fair)
        for m in M_LIST:
            blocks_pos, y_blk = test_blocks_by_m[m]
            acc, cm = baseline_fused_acc_from_frame_logits(test_logits, blocks_pos, y_blk, num_classes)
            np.save(os.path.join(fold_dir, f"cm_test_m{m:03d}.npy"), cm)

            fold_rows.append({
                "method": "BASELINE_RAWIQ",
                "snr_db": int(snr_db),
                "m": int(m),
                "S": "m",
                "S_eff": "m",
                "fold": int(fold),
                "acc": float(acc),
                "val_best": float(best_val_loss)
            })

    return fold_rows


# =========================
# XFR: train one m, test all S
# =========================
def train_xfr_one_m(snr_db, m, out_dir, tx_list, label_to_idx, tx_to_seqs):
    """Train and evaluate the XFR row model for one block width ``m``.

    Steps:
      1. Build XFR blocks of m frames per TX (build_xfr_blocks_by_m) and cap them
         (cap_blocks_balanced).
      2. Make one stratified 75/25 block-level train/test split (random_state=SEED).
      3. Run StratifiedKFold(N_SPLITS) over the train/val blocks. Each fold trains
         ResNet18_1D on individual rows (m, 2) of the training blocks, early-stops on
         fused validation accuracy at S=XFR_EARLY_STOP_S, and then computes test row
         logits once to score every S in S_LIST.

    Args:
        snr_db: SNR label used in logs and artifacts.
        m: frames per block.
        out_dir: output directory for this m (created if missing).
        tx_list, label_to_idx, tx_to_seqs: presumably the outputs of preload_by_snr.

    Returns:
        List of dicts, one per (fold, S), with keys method, snr_db, m, S, S_eff, fold, acc,
        val_best (here: best fused validation accuracy in %).

    Raises:
        RuntimeError: if the built blocks do not have width m.
    """
    num_classes = len(label_to_idx)
    # Create the destination before writing blocks_meta.json (the baseline trainer does
    # the same for its directory); previously a missing out_dir raised FileNotFoundError.
    os.makedirs(out_dir, exist_ok=True)

    # 1) build blocks by m
    X_np, y_np = build_xfr_blocks_by_m(tx_list, label_to_idx, tx_to_seqs, m)
    X_np, y_np, kept = cap_blocks_balanced(X_np, y_np, num_classes)

    Nblk, L, mm, C = X_np.shape
    if mm != m:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise RuntimeError(f"[XFR] block width mismatch: expected m={m}, got {mm}")
    print(f"[XFR][SNR{snr_db}] m={m} | blocks={Nblk} | X={X_np.shape}")

    with open(os.path.join(out_dir, "blocks_meta.json"), "w", encoding="utf-8") as f:
        json.dump({"snr_db": snr_db, "m": m, "blocks": int(Nblk), "L": int(L), "kept_idx_len": int(len(kept))}, f, indent=2)

    X = torch.from_numpy(np.ascontiguousarray(X_np)).float()   # CPU
    y = torch.from_numpy(np.ascontiguousarray(y_np)).long()    # CPU

    # 2) block split
    idx_all = np.arange(Nblk)
    trval_idx, test_idx = train_test_split(idx_all, test_size=0.25, random_state=SEED, stratify=y_np)

    np.savez_compressed(os.path.join(out_dir, "split_indices.npz"),
                        trval_idx=trval_idx, test_idx=test_idx, y_blocks=y_np)

    X_test_blk = X[test_idx]
    y_test_blk = y[test_idx].numpy()

    # 3) CV on trval blocks
    skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
    y_trval = y_np[trval_idx]

    fold_rows = []
    for fold, (tr_i, va_i) in enumerate(skf.split(trval_idx, y_trval), start=1):
        fold_dir = os.path.join(out_dir, f"fold{fold}")
        os.makedirs(fold_dir, exist_ok=True)

        # Positions within trval_idx -> block indices into X / y.
        tr_blocks = trval_idx[tr_i]
        va_blocks = trval_idx[va_i]

        # Training/validation samples are single rows; whole blocks stay in one split.
        train_ds = XFRRowDataset(X, y, tr_blocks, L)
        val_ds = XFRRowDataset(X, y, va_blocks, L)

        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
        val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        model = ResNet18_1D(num_classes=num_classes, in_planes=IN_PLANES, dropout=DROPOUT).to(DEVICE)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

        best_val_fused = -1.0
        best_state = None  # CPU copy of the weights with the best fused val accuracy
        patience_cnt = 0

        log_csv = os.path.join(fold_dir, "train_log.csv")
        with open(log_csv, "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(["epoch", "train_loss", "train_row_acc", "val_loss", "val_row_acc", "val_fused_acc_Sall"])

        X_va_blk = X[va_blocks]
        y_va_blk = y[va_blocks].numpy()

        for epoch in range(1, NUM_EPOCHS + 1):
            model.train()
            tr_loss = 0.0
            tr_cor, tr_tot = 0, 0
            for xb, yb in train_loader:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                optimizer.zero_grad()
                logits = model(xb)
                loss = criterion(logits, yb)
                loss.backward()
                optimizer.step()
                tr_loss += loss.item()
                pred = torch.argmax(logits, dim=1)
                tr_cor += (pred == yb).sum().item()
                tr_tot += yb.size(0)
            tr_loss /= max(len(train_loader), 1)
            tr_acc = 100.0 * tr_cor / max(tr_tot, 1)

            # val row
            model.eval()
            va_loss = 0.0
            va_cor, va_tot = 0, 0
            with torch.no_grad():
                for xb, yb in val_loader:
                    xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                    logits = model(xb)
                    loss = criterion(logits, yb)
                    va_loss += loss.item()
                    pred = torch.argmax(logits, dim=1)
                    va_cor += (pred == yb).sum().item()
                    va_tot += yb.size(0)
            va_loss /= max(len(val_loader), 1)
            va_row_acc = 100.0 * va_cor / max(va_tot, 1)

            # val fused (S=all)
            va_logits = infer_xfr_row_logits_all(model, X_va_blk, num_classes, batch_blocks=8)
            va_acc_fused, _, _ = xfr_fused_acc_from_logits(va_logits, y_va_blk, m, XFR_EARLY_STOP_S)

            print(f"[XFR][SNR{snr_db}] m={m:3d} Fold{fold} Ep{epoch:03d} | "
                  f"TrainRowAcc {tr_acc:.2f}% | ValRowAcc {va_row_acc:.2f}% | ValFusedAcc(S=all) {va_acc_fused:.2f}%")

            with open(log_csv, "a", newline="", encoding="utf-8") as f:
                w = csv.writer(f)
                w.writerow([epoch, tr_loss, tr_acc, va_loss, va_row_acc, va_acc_fused])

            # Early stopping on fused block accuracy (higher is better).
            if va_acc_fused > best_val_fused + 1e-6:
                best_val_fused = va_acc_fused
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                patience_cnt = 0
            else:
                patience_cnt += 1
                if patience_cnt >= PATIENCE:
                    break

            scheduler.step()

        if best_state is not None:
            model.load_state_dict(best_state)

        # ---- TEST: compute logits once, reuse for all S ----
        test_logits = infer_xfr_row_logits_all(model, X_test_blk, num_classes, batch_blocks=8)  # (Ntest,L,K)

        np.savez_compressed(os.path.join(fold_dir, "test_row_logits.npz"),
                            logits=test_logits.astype(np.float16),
                            y_true=y_test_blk.astype(np.int64),
                            snr_db=np.int32(snr_db),
                            m=np.int32(m),
                            L=np.int32(L),
                            K=np.int32(num_classes))

        for S in S_LIST:
            acc, cm, S_eff = xfr_fused_acc_from_logits(test_logits, y_test_blk, m, S)
            np.save(os.path.join(fold_dir, f"cm_test_S{S}.npy"), cm)
            fold_rows.append({
                "method": "XFR",
                "snr_db": int(snr_db),
                "m": int(m),
                "S": str(S),
                "S_eff": str(S_eff),
                "fold": int(fold),
                "acc": float(acc),
                "val_best": float(best_val_fused)
            })

        torch.save(model.state_dict(), os.path.join(fold_dir, "best_model.pth"))

    return fold_rows


# =========================
# Postprocess: aggregate, make tables + plot
# =========================
def aggregate_results_to_csv(rows, csv_path):
    """Write result dicts to ``csv_path`` as a CSV with a single header line.

    The columns are the union of all row keys in first-seen order, so rows from different
    methods that carry extra fields no longer make DictWriter raise; missing values are
    written as empty cells. An empty ``rows`` produces an empty file instead of raising
    IndexError. ``rows`` may be any iterable of dicts (it is materialised once).
    """
    rows = list(rows)
    fieldnames = list(dict.fromkeys(key for row in rows for key in row))
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        if not fieldnames:
            return
        w = csv.DictWriter(f, fieldnames=fieldnames, restval="")
        w.writeheader()
        w.writerows(rows)

def aggregate_mean_std(rows):
    """Group fold-level rows and return {(method, snr_db, m, S_eff): (mean acc, std acc)}.

    S_eff falls back to S when absent; snr_db and m are coerced to int and acc to float,
    so rows read back from CSV (all strings) aggregate the same as in-memory rows. The std
    is the population standard deviation (np.std, ddof=0).
    """
    from collections import defaultdict

    buckets = defaultdict(list)
    for row in rows:
        key = (row["method"], int(row["snr_db"]), int(row["m"]), str(row.get("S_eff", row["S"])))
        buckets[key].append(float(row["acc"]))
    return {key: (float(np.mean(accs)), float(np.std(accs))) for key, accs in buckets.items()}

def find_m_at_target(curve_points, target=99.0):
    """Return the smallest m (as int) whose accuracy is >= ``target``, or None if none is.

    ``curve_points`` is an iterable of (m, accuracy) pairs in any order.
    """
    ordered = sorted(curve_points, key=lambda point: point[0])
    reaching = (m for m, acc in ordered if acc >= target)
    first = next(reaching, None)
    return None if first is None else int(first)

def write_latex_table(path, latex_str):
    """Write ``latex_str`` verbatim to ``path`` as UTF-8, replacing any existing file."""
    with open(path, mode="w", encoding="utf-8") as out_file:
        out_file.write(latex_str)

def _post_effective_s(S, m):
    """Map a nominal row budget ``S`` to the S key evaluated at block size ``m``.

    Implements the sweep rule "S > m => all": an integer ``S`` larger than
    ``m`` becomes ``"all"``. Any other value, including ``"all"`` itself, is
    returned as ``str(S)``.
    """
    if isinstance(S, int) and S > m:
        return "all"
    return str(S)


def _post_write_table1(out_root, snr_db, baseline_m99, xfr_m99):
    """Write Table 1 (smallest m reaching ACC_TARGET) as CSV and LaTeX.

    Args:
        out_root: output directory for this SNR.
        snr_db: SNR used in the file names.
        baseline_m99: baseline m@target, or None if never reached.
        xfr_m99: dict ``str(S) -> m@target or None`` for every S in S_LIST.

    Returns:
        Path of the CSV file written.
    """
    target = int(ACC_TARGET)

    # One entry per nominal S. The displayed budget q applies the
    # "S > m => all" rule at the m where the target was first reached.
    xfr_entries = []
    for S in S_LIST:
        m99 = xfr_m99[str(S)]
        q_show = str(S) if m99 is None else _post_effective_s(S, m99)
        xfr_entries.append((S, q_show, m99))

    table1_csv = os.path.join(out_root, f"SNR{snr_db}dB_table1_m_at_{target}.csv")
    with open(table1_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["Method", "Budget(q)", f"m@{target}"])
        w.writerow(["RAW IQ late fusion (baseline)", "q=m",
                    baseline_m99 if baseline_m99 is not None else "NA"])
        for S, q_show, m99 in xfr_entries:
            w.writerow([f"XFR row-fusion (S={S})", f"q={q_show}",
                        m99 if m99 is not None else "NA"])

    latex = [
        r"\begin{table}[t]",
        r"\centering",
        r"\caption{Collection budget to reach %d\%% accuracy ($m@%d$). Baseline uses $q=m$ inferences; XFR uses $q=S$ (if $S>m$, we use all rows).}" % (target, target),
        # The label follows ACC_TARGET like the caption and file name do.
        # It reads "tab:m_at_99" for the default target.
        r"\label{tab:m_at_%d}" % target,
        r"\begin{tabular}{lcc}",
        r"\hline",
        r"Method & Budget $q$ & $m@%d$ \\" % target,
        r"\hline",
        r"RAW IQ late fusion & $q=m$ & %s \\" % (str(baseline_m99) if baseline_m99 is not None else "N/A"),
    ]
    for S, q_show, m99 in xfr_entries:
        latex.append(r"XFR row-fusion ($S=%s$) & $q=%s$ & %s \\" %
                     (str(S), str(q_show), (str(m99) if m99 is not None else "N/A")))
    latex.extend([r"\hline", r"\end{tabular}", r"\end{table}"])
    write_latex_table(os.path.join(out_root, f"SNR{snr_db}dB_table1.tex"), "\n".join(latex))
    return table1_csv


def _post_write_table2(out_root, snr_db, agg):
    """Write Table 2 (baseline vs XFR at a fixed inference budget q) as CSV and LaTeX.

    For each q in Q_LIST, the baseline is looked up at m=q and XFR at
    m=q, S=q. A row is written as "NA" (CSV) / "N/A" (LaTeX) when either
    side has no aggregated result.

    Returns:
        Path of the CSV file written.
    """
    entries = []  # (q, baseline mean or None, XFR mean or None)
    for q in Q_LIST:
        kb = ("BASELINE_RAWIQ", snr_db, q, "m")
        kx = ("XFR", snr_db, q, str(q))
        b = agg[kb][0] if kb in agg else None
        x = agg[kx][0] if kx in agg else None
        entries.append((q, b, x))

    table2_csv = os.path.join(out_root, f"SNR{snr_db}dB_table2_fixed_q.csv")
    with open(table2_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["q", "Baseline Acc (m=q)", "XFR Acc (m=q, S=q)", "Delta(XFR-Baseline)"])
        for q, b, x in entries:
            if b is None or x is None:
                w.writerow([q, "NA", "NA", "NA"])
            else:
                w.writerow([q, f"{b:.2f}", f"{x:.2f}", f"{(x - b):.2f}"])

    latex = [
        r"\begin{table}[t]",
        r"\centering",
        r"\caption{Fair comparison under the same inference budget $q$ (number of model calls). Baseline uses $m=q$ late fusion. XFR uses $m=q$ and $S=q$ row fusion.}",
        r"\label{tab:fixed_q}",
        r"\begin{tabular}{cccc}",
        r"\hline",
        r"$q$ & Baseline Acc (\%) & XFR Acc (\%) & $\Delta$ \\",
        r"\hline",
    ]
    for q, b, x in entries:
        if b is None or x is None:
            latex.append(r"%d & N/A & N/A & N/A \\" % q)
        else:
            latex.append(r"%d & %.2f & %.2f & %.2f \\" % (q, b, x, x - b))
    latex.extend([r"\hline", r"\end{tabular}", r"\end{table}"])
    write_latex_table(os.path.join(out_root, f"SNR{snr_db}dB_table2.tex"), "\n".join(latex))
    return table2_csv


def _post_write_curve_csv(out_root, snr_db, baseline_curve, xfr_curves):
    """Write the mean/std accuracy-vs-m points behind the plot; return the CSV path."""
    curve_csv = os.path.join(out_root, f"SNR{snr_db}dB_curve_data.csv")
    with open(curve_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["method", "S", "m", "acc_mean", "acc_std"])
        for m, mean, std in baseline_curve:
            w.writerow(["BASELINE_RAWIQ", "m", m, f"{mean:.4f}", f"{std:.4f}"])
        for S in S_LIST:
            for m, mean, std, S_eff in xfr_curves[str(S)]:
                w.writerow(["XFR", f"{S}(eff={S_eff})", m, f"{mean:.4f}", f"{std:.4f}"])
    return curve_csv


def _post_plot_acc_vs_m(out_root, snr_db, baseline_curve, xfr_curves):
    """Plot mean accuracy vs m for the baseline and each nominal S; return the PNG path."""
    fig_path = os.path.join(out_root, f"SNR{snr_db}dB_acc_vs_m.png")
    fig = plt.figure(figsize=(8, 5))
    try:
        n_lines = 0
        if baseline_curve:
            plt.plot([p[0] for p in baseline_curve], [p[1] for p in baseline_curve],
                     marker="o", label="RAW IQ baseline (late fusion)")
            n_lines += 1
        for S in S_LIST:
            pts = xfr_curves[str(S)]
            if not pts:
                continue
            plt.plot([p[0] for p in pts], [p[1] for p in pts],
                     marker="o", label=f"XFR (S={S})")
            n_lines += 1

        plt.axhline(ACC_TARGET, linestyle="--")
        plt.xlabel("m (frames per session)")
        plt.ylabel("Accuracy (%)")
        plt.title(f"LTE-V: Accuracy vs m (SNR={snr_db} dB)")
        plt.grid(True)
        # Only labelled curves go in the legend. When none were drawn there is
        # nothing to show, so skip legend() instead of triggering its warning.
        if n_lines:
            plt.legend()
        plt.tight_layout()
        plt.savefig(fig_path, dpi=200)
    finally:
        # Release the figure even if saving fails, so a multi-SNR sweep does not
        # accumulate open figures.
        plt.close(fig)
    return fig_path


def make_tables_and_plot(out_root, snr_db, rows, num_classes):
    """Aggregate fold-level rows for one SNR and write tables, curve data and plot.

    Files written into ``out_root``, in this order:
      * ``SNR{snr}dB_table1_m_at_{target}.csv`` and ``SNR{snr}dB_table1.tex``
      * ``SNR{snr}dB_table2_fixed_q.csv`` and ``SNR{snr}dB_table2.tex``
      * ``SNR{snr}dB_curve_data.csv``
      * ``SNR{snr}dB_acc_vs_m.png``
      * ``SNR{snr}dB_m_at_{target}.json`` (m@target summary)

    Args:
        out_root: existing output directory for this SNR.
        snr_db: SNR as an int. It is part of every aggregate lookup key, so it
            must match the ``snr_db`` values stored in ``rows``.
        rows: fold-level result dicts, as consumed by ``aggregate_mean_std``.
        num_classes: number of classes; only recorded in the JSON summary.
    """
    agg = aggregate_mean_std(rows)

    # Baseline curve: (m, mean, std) for every m in M_LIST that has results.
    baseline_curve = []
    for m in M_LIST:
        key = ("BASELINE_RAWIQ", snr_db, m, "m")
        if key in agg:
            mean, std = agg[key]
            baseline_curve.append((m, mean, std))

    # XFR curves are keyed by nominal S. Each point is looked up under the
    # effective S ("all" once S exceeds m), and the point records which
    # effective S was used.
    xfr_curves = {str(S): [] for S in S_LIST}
    for m in M_LIST:
        for S in S_LIST:
            S_eff = _post_effective_s(S, m)
            key = ("XFR", snr_db, m, S_eff)
            if key in agg:
                mean, std = agg[key]
                xfr_curves[str(S)].append((m, mean, std, S_eff))

    baseline_m99 = find_m_at_target(
        [(m, mean) for (m, mean, _std) in baseline_curve], target=ACC_TARGET)
    xfr_m99 = {
        str(S): find_m_at_target(
            [(m, mean) for (m, mean, _std, _eff) in xfr_curves[str(S)]], target=ACC_TARGET)
        for S in S_LIST
    }

    table1_csv = _post_write_table1(out_root, snr_db, baseline_m99, xfr_m99)
    table2_csv = _post_write_table2(out_root, snr_db, agg)
    curve_csv = _post_write_curve_csv(out_root, snr_db, baseline_curve, xfr_curves)
    fig_path = _post_plot_acc_vs_m(out_root, snr_db, baseline_curve, xfr_curves)

    m99_json = {
        "snr_db": int(snr_db),
        "acc_target": float(ACC_TARGET),
        "num_classes": int(num_classes),
        "baseline_m_at_target": baseline_m99,
        "xfr_m_at_target": xfr_m99
    }
    with open(os.path.join(out_root, f"SNR{snr_db}dB_m_at_{int(ACC_TARGET)}.json"), "w", encoding="utf-8") as f:
        json.dump(m99_json, f, ensure_ascii=False, indent=2)

    print(f"[POST][SNR{snr_db}] Table1: {table1_csv}")
    print(f"[POST][SNR{snr_db}] Table2: {table2_csv}")
    print(f"[POST][SNR{snr_db}] Curve data: {curve_csv}")
    print(f"[POST][SNR{snr_db}] Plot: {fig_path}")


# =========================
# Main
# =========================
def main():
    """Run the whole sweep and write every artifact under a timestamped folder.

    Creates ``<cwd>/training_results/<timestamp>_LTEV_LateFusion_Sweep/`` with
    ``run_config.json``. For each SNR in ``SNR_LIST`` it adds an
    ``SNR<snr>dB`` sub-folder holding:
      * ``preload_meta.json``;
      * the shared baseline run (``BASELINE_RAWIQ_SHARED``);
      * one ``XFR/m###`` folder per m in ``M_LIST``;
      * the long results CSV;
      * the tables and plot from ``make_tables_and_plot``.
    """
    run_stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    out_root = os.path.join(os.getcwd(), "training_results", f"{run_stamp}_LTEV_LateFusion_Sweep")
    os.makedirs(out_root, exist_ok=True)

    # Snapshot of every knob that influences results, stored next to the outputs.
    run_cfg = dict(
        DATA_PATH=DATA_PATH,
        SNR_LIST=SNR_LIST,
        M_LIST=M_LIST,
        S_LIST=S_LIST,
        Q_LIST=Q_LIST,
        ACC_TARGET=ACC_TARGET,
        MAX_BLOCKS_PER_CLASS=MAX_BLOCKS_PER_CLASS,
        MAX_TOTAL_BLOCKS=MAX_TOTAL_BLOCKS,
        PHY=dict(FS=FS, FC=FC, V_KMH=V_KMH, APPLY_DOPPLER=APPLY_DOPPLER, APPLY_AWGN=APPLY_AWGN),
        TRAIN=dict(SEED=SEED, N_SPLITS=N_SPLITS, BATCH_SIZE=BATCH_SIZE, NUM_EPOCHS=NUM_EPOCHS,
                   LR=LR, WEIGHT_DECAY=WEIGHT_DECAY, IN_PLANES=IN_PLANES, DROPOUT=DROPOUT,
                   PATIENCE=PATIENCE),
        RULE=dict(S_gt_m_to_all=True),
        BASELINE=dict(train_once_per_snr=True, blocks_built_on_test_pool=True),
        DEVICE=str(DEVICE),
    )
    with open(os.path.join(out_root, "run_config.json"), "w", encoding="utf-8") as cfg_file:
        json.dump(run_cfg, cfg_file, ensure_ascii=False, indent=2)

    for snr_db in SNR_LIST:
        snr_dir = os.path.join(out_root, f"SNR{snr_db}dB")
        os.makedirs(snr_dir, exist_ok=True)

        # Load every frame for this SNR once; the baseline and all XFR runs share it.
        (tx_list, label_to_idx, tx_to_seqs,
         X_all, y_all, _tx_to_global_indices, seg_len) = preload_by_snr(DATA_PATH, snr_db)
        num_classes = len(label_to_idx)

        with open(os.path.join(snr_dir, "preload_meta.json"), "w", encoding="utf-8") as meta_file:
            json.dump(dict(snr_db=int(snr_db), num_classes=int(num_classes),
                           seg_len=int(seg_len), label_to_idx=label_to_idx),
                      meta_file, ensure_ascii=False, indent=2)

        # Baseline: trained once per SNR, then evaluated for every m.
        all_rows = list(train_baseline_shared_all_m(
            snr_db, os.path.join(snr_dir, "BASELINE_RAWIQ_SHARED"), X_all, y_all, num_classes))

        # XFR: a separate training run per m.
        xfr_root = os.path.join(snr_dir, "XFR")
        os.makedirs(xfr_root, exist_ok=True)
        for m in M_LIST:
            m_dir = os.path.join(xfr_root, f"m{m:03d}")
            os.makedirs(m_dir, exist_ok=True)
            all_rows += train_xfr_one_m(snr_db, m, m_dir, tx_list, label_to_idx, tx_to_seqs)

        long_csv = os.path.join(snr_dir, f"SNR{snr_db}dB_results_long.csv")
        aggregate_results_to_csv(all_rows, long_csv)
        print(f"[SAVE][SNR{snr_db}] long results: {long_csv}")

        make_tables_and_plot(snr_dir, int(snr_db), all_rows, num_classes)

    print(f"[DONE] All outputs in: {out_root}")


if __name__ == "__main__":
    # Report the compute device up front, then run the whole sweep.
    print("[INFO] device={}".format(DEVICE))
    main()
