# In [None]:
# ==========================================
# LTE-V XFR Contiguous-m: m sweep @ SNR=20dB
# - No K-fold
# - Train/Test split by blocks (disjoint blocks)
# - Within train blocks: randomly split 20% as val (stratified by block labels)
# ==========================================

import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm import tqdm
import h5py

# ================= Parameter settings =================
# Module-level configuration read by train_one_setting() and the __main__ sweep.
data_path = "D:/users/zhongyuan/rf_datasets"  # data folder (searched for *.mat files)
fs = 5e6  # sampling rate used to build the Doppler time axis (presumably Hz)
fc = 5.9e9  # carrier frequency in Hz, fed to compute_doppler_shift
v = 120                 # km/h (speed used for the artificial Doppler injection)
apply_doppler = True  # enable the synthetic Doppler phase rotation
apply_awgn = True  # enable additive white Gaussian noise

SNR_DB = 20            # fixed at 20 dB
test_size = 0.25       # block-level train/test split
val_size_in_train = 0.20  # hold out 20% of the training blocks for validation (block-level)
random_state = 42  # seed passed to both train_test_split calls

# m sweep (adjust as needed): 8, 24, ..., 568
M_LIST = list(range(8, 577, 16))

# Model hyperparameters
batch_size = 64
num_epochs = 300  # upper bound; early stopping may end training sooner
learning_rate = 1e-4
weight_decay = 1e-3  # passed to Adam as weight_decay
in_planes = 64  # channel count of the ResNet stem convolution
dropout = 0.5  # dropout probability inside every residual block
patience = 5  # epochs without validation improvement before early stop

# ================= 多普勒和AWGN处理函数 =================
def compute_doppler_shift(v_kmh, fc_hz):
    """Return the Doppler shift in Hz for a given speed and carrier.

    Args:
        v_kmh: Relative speed in km/h.
        fc_hz: Carrier frequency in Hz.

    Returns:
        ``(v / c) * fc`` with ``v`` converted to m/s and ``c`` rounded
        to 3e8 m/s.
    """
    speed_of_light = 3e8  # m/s, rounded
    speed_mps = v_kmh / 3.6  # km/h -> m/s
    velocity_ratio = speed_mps / speed_of_light
    return velocity_ratio * fc_hz

def apply_doppler_shift(signal, fd, fs):
    """Rotate a complex signal by a constant Doppler frequency offset.

    The time axis starts at zero and spans the last dimension of
    ``signal``, so the same phase ramp is broadcast over any leading
    dimensions.

    Args:
        signal: NumPy array of complex samples; time is the last axis.
        fd: Doppler frequency offset (same unit as ``fs``).
        fs: Sampling rate.

    Returns:
        ``signal * exp(j * 2 * pi * fd * n / fs)``.
    """
    n_samples = signal.shape[-1]
    time_axis = np.arange(n_samples) / fs
    rotator = np.exp(2j * np.pi * fd * time_axis)
    return signal * rotator

def add_awgn(signal, snr_db):
    """Add complex white Gaussian noise at the requested SNR.

    The noise power is derived from the measured mean power of
    ``signal`` and split equally between the real and imaginary parts.
    Noise is drawn from NumPy's global RNG (real part first, then the
    imaginary part).

    Args:
        signal: NumPy array of complex samples.
        snr_db: Target signal-to-noise ratio in dB.

    Returns:
        ``signal`` plus the generated noise, same shape as the input.
    """
    measured_power = np.mean(np.abs(signal) ** 2)
    target_noise_power = measured_power / (10 ** (snr_db / 10))

    # Draw order matters for reproducibility under a fixed seed.
    in_phase = np.random.randn(*signal.shape)
    quadrature = np.random.randn(*signal.shape)

    per_component_std = np.sqrt(target_noise_power / 2)
    return signal + per_component_std * (in_phase + 1j * quadrature)

# ================= 数据加载（按单文件生成 contiguous blocks，并翻转 block） =================
def load_and_preprocess_blocks_singlefile_contiguous(
    mat_folder,
    m=288,
    apply_doppler=False,
    target_velocity=30,
    apply_awgn=False,
    snr_db=20,
    fs=5e6,
    fc=5.9e9
):
    """Load every .mat file in a folder and cut it into contiguous blocks.

    Each frame is power-normalized, optionally Doppler-rotated and
    optionally corrupted with AWGN, then stored as (real, imag) pairs.
    Frames of one file are grouped into blocks of ``m`` consecutive
    frames and each block is transposed so the frame axis comes second.

    Constraints:
      - every block comes from a single .mat file (no cross-file stitching)
      - contiguous: frame indices are start + [0, 1, ..., m-1]
      - stride equals m (non-sliding, non-overlapping blocks)
      - files with fewer than ``m`` frames are skipped silently

    Args:
        mat_folder: Folder searched for ``*.mat`` files (read via h5py).
        m: Number of consecutive frames per block.
        apply_doppler: Apply the Doppler phase rotation when True.
        target_velocity: Speed in km/h used to compute the Doppler shift.
        apply_awgn: Add white Gaussian noise when True.
        snr_db: SNR in dB for the added noise.
        fs: Sampling rate used for the Doppler time axis.
        fc: Carrier frequency used for the Doppler shift.

    Returns:
        X_blocks: float32 array of shape (num_blocks, sample_len, m, 2).
        y_blocks: int64 array of shape (num_blocks,) with class indices.
        label_to_idx: dict mapping tx_id string -> class index.

    Raises:
        RuntimeError: If no block could be generated at all.
    """
    mat_files = sorted(glob.glob(os.path.join(mat_folder, "*.mat")))
    print(f"[INFO] 共找到 {len(mat_files)} 个 .mat 文件")
    fd = compute_doppler_shift(target_velocity, fc)
    print(f"[INFO] Doppler 设置：v={target_velocity} km/h, fd={fd:.2f} Hz, apply_doppler={apply_doppler}")
    print(f"[INFO] XFR contiguous 设置：m={m}")

    # ---- Pass 1: read each file and preprocess it frame by frame ----
    per_file = []  # (frames as float32 (num_frames, sample_len, 2), tx_id)
    for mat_path in tqdm(mat_files, desc="读取并预处理文件"):
        with h5py.File(mat_path, "r") as handle:
            dataset = handle["rfDataset"]
            dmrs_raw = dataset["dmrs"][:]
            # Complex values are stored as a compound type with
            # "real"/"imag" fields; result is (num_frames, sample_len).
            frames_complex = dmrs_raw["real"] + 1j * dmrs_raw["imag"]

            # txID is stored as character codes; zero codes are dropped.
            id_codes = dataset["txID"][:].flatten()
            tx_id = "".join(chr(code) for code in id_codes if code != 0)

        frames_iq = []
        for frame in frames_complex:
            # Step 1: normalize to unit mean power (epsilon avoids 0/0).
            rms = np.sqrt(np.mean(np.abs(frame) ** 2))
            frame = frame / (rms + 1e-12)

            # Step 2: Doppler (phase rotation only).
            if apply_doppler:
                frame = apply_doppler_shift(frame, fd, fs)

            # Step 3: AWGN.
            if apply_awgn:
                frame = add_awgn(frame, snr_db)

            # (sample_len, 2): real and imaginary parts side by side.
            frames_iq.append(np.stack((frame.real, frame.imag), axis=-1))

        per_file.append((np.array(frames_iq, dtype=np.float32), tx_id))

    # Class indices follow the sorted order of the distinct tx_id strings.
    distinct_labels = sorted({tx for _, tx in per_file})
    label_to_idx = dict(zip(distinct_labels, range(len(distinct_labels))))

    # ---- Pass 2: cut every file into non-overlapping blocks of m frames ----
    block_len = m
    hop = m  # non-sliding, non-overlapping

    X_blocks_list = []
    y_blocks_list = []
    for frames, tx_id in tqdm(per_file, desc="生成 contiguous blocks"):
        num_frames, _sample_len, _ = frames.shape
        if num_frames < block_len:
            continue

        class_index = label_to_idx[tx_id]
        for start in range(0, num_frames - block_len + 1, hop):
            window = frames[start:start + m]  # (m, sample_len, 2)
            # Swap to (sample_len, m, 2) so one row spans m frames.
            X_blocks_list.append(np.transpose(window, (1, 0, 2)))
            y_blocks_list.append(class_index)

    if len(X_blocks_list) == 0:
        raise RuntimeError("没有生成任何 block。请检查 m 是否过大，或文件帧数不足。")

    X_blocks = np.stack(X_blocks_list, axis=0)  # (num_blocks, sample_len, m, 2)
    y_blocks = np.array(y_blocks_list, dtype=np.int64)

    print(f"[INFO] 生成 block 数: {X_blocks.shape[0]}")
    print(f"[INFO] block 形状: (sample_len={X_blocks.shape[1]}, m={X_blocks.shape[2]}, 2)")
    print(f"[INFO] 类别数: {len(label_to_idx)}")
    return X_blocks, y_blocks, label_to_idx

# ================= 1D ResNet18（增加 dropout 和 in_planes） =================
class BasicBlock1D(nn.Module):
    """Two-convolution 1D residual block with dropout after the first conv.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to produce the
            shortcut branch; when None the input itself is the shortcut.
        dropout: Dropout probability applied after the first activation.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None, dropout=0.0):
        super().__init__()
        # Attribute names and creation order are kept stable: they define
        # the state_dict keys and the order of random weight initialization.
        self.conv1 = nn.Conv1d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.conv2 = nn.Conv1d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        shortcut = x if self.downsample is None else self.downsample(x)

        # Main branch: conv -> BN -> ReLU -> dropout -> conv -> BN.
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)
        y = self.dropout(y)
        y = self.conv2(y)
        y = self.bn2(y)

        y = y + shortcut
        return self.relu(y)

class ResNet18_1D(nn.Module):
    """ResNet-18 style classifier built from 1D convolutions.

    The network consumes two-channel sequences and outputs one logit per
    class: stem conv -> max-pool -> four stages of two BasicBlock1D each
    (64/128/256/512 channels) -> global average pool -> linear layer.

    Args:
        num_classes: Number of output logits.
        in_planes: Channel count produced by the stem convolution.
        dropout: Dropout probability forwarded to every residual block.
    """

    def __init__(self, num_classes=10, in_planes=64, dropout=0.0):
        super().__init__()
        # Tracks the channel count entering the next stage; updated by
        # _make_layer as stages are created.
        self.in_planes = in_planes

        # Stem.
        self.conv1 = nn.Conv1d(2, in_planes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # Residual stages layer1..layer4 as (output channels, stride).
        stage_plan = ((64, 1), (128, 2), (256, 2), (512, 2))
        for number, (width, stage_stride) in enumerate(stage_plan, start=1):
            stage = self._make_layer(width, 2, stride=stage_stride, dropout=dropout)
            setattr(self, f"layer{number}", stage)

        # Head.
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, planes, blocks, stride, dropout):
        """Build one stage of ``blocks`` residual blocks with ``planes`` channels.

        Only the first block may change resolution or width; it gets a
        1x1 conv + BN projection shortcut whenever the stride is not 1 or
        the channel count changes.
        """
        shape_preserved = stride == 1 and self.in_planes == planes
        projection = None
        if not shape_preserved:
            projection = nn.Sequential(
                nn.Conv1d(self.in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(planes)
            )

        head_block = BasicBlock1D(self.in_planes, planes, stride, projection, dropout)
        self.in_planes = planes
        tail_blocks = [
            BasicBlock1D(self.in_planes, planes, dropout=dropout)
            for _ in range(blocks - 1)
        ]
        return nn.Sequential(head_block, *tail_blocks)

    def forward(self, x):
        """Map a batch of shape (B, m, 2) to logits of shape (B, num_classes)."""
        # Conv1d expects channels first: (B, m, 2) -> (B, 2, m).
        feats = x.permute(0, 2, 1)

        feats = self.conv1(feats)
        feats = self.bn1(feats)
        feats = self.relu(feats)
        feats = self.maxpool(feats)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)

        pooled = self.avgpool(feats).squeeze(-1)
        return self.fc(pooled)

# ================= 辅助函数 =================
def evaluate_model(model, dataloader, device, num_classes):
    """Evaluate a classifier and return its accuracy and confusion matrix.

    The model is switched to eval mode (and left there) and run without
    gradient tracking over every batch of ``dataloader``.

    Args:
        model: Module mapping an input batch to per-class logits.
        dataloader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.
        num_classes: Number of classes; fixes the confusion matrix size.

    Returns:
        acc: Accuracy in percent (0.0 when the loader yields no samples).
        cm: Confusion matrix over labels ``0 .. num_classes - 1``.
    """
    model.eval()

    hits = 0
    seen = 0
    y_true = []
    y_pred = []

    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            predicted = model(batch_inputs).argmax(dim=1)

            hits += (predicted == batch_labels).sum().item()
            seen += batch_labels.size(0)
            y_true.extend(batch_labels.cpu().tolist())
            y_pred.extend(predicted.cpu().tolist())

    # max(1, ...) guards against division by zero on an empty loader.
    acc = 100.0 * hits / max(1, seen)
    cm = confusion_matrix(y_true, y_pred, labels=range(num_classes))
    return acc, cm

def plot_confusion_matrix(cm, save_path=None):
    """Render a confusion matrix as an annotated heatmap.

    The figure is written to ``save_path`` when one is given and is
    always closed afterwards (it is never shown interactively).

    Note: this sets the global matplotlib font to SimHei and disables
    the unicode minus sign; both settings persist after the call.

    Args:
        cm: Integer confusion matrix (cells are formatted with "d").
        save_path: Optional output path for the image.
    """
    fig = plt.figure(figsize=(8, 6))
    plt.rcParams.update({
        "font.sans-serif": ["SimHei"],
        "axes.unicode_minus": False,
    })

    ax = fig.add_subplot(111)
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", ax=ax)
    ax.set_title("Confusion Matrix")
    ax.set_ylabel("Reference")
    ax.set_xlabel("Predicted")

    if save_path:
        fig.savefig(save_path)
    plt.close(fig)

def _flatten_blocks(Xb, yb):
    """Expand blocks (B, sample_len, m, 2) into samples (B * sample_len, m, 2)."""
    X = Xb.reshape(-1, Xb.shape[2], Xb.shape[3])  # (-1, m, 2)
    y = np.repeat(yb, Xb.shape[1])                # one label per row of a block
    return X, y


def _make_loader(X, y, shuffle, drop_last=False):
    """Wrap NumPy samples/labels in a DataLoader without copying them again."""
    # torch.as_tensor reuses the NumPy buffer when the dtype already
    # matches, whereas torch.tensor would duplicate the large sample array.
    dataset = TensorDataset(
        torch.as_tensor(X, dtype=torch.float32),
        torch.as_tensor(y, dtype=torch.long),
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)


def _snapshot_weights(model):
    """Return a detached CPU copy of the model's state_dict."""
    return {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}


def _append_line(path, text):
    """Append ``text`` to the UTF-8 log file at ``path``."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(text)


def train_one_setting(m, save_folder):
    """Train and test one model for a given block length ``m``.

    Pipeline: build contiguous blocks, split them block-wise into
    train/test, hold out part of the training blocks for validation,
    expand every block into ``sample_len`` samples of shape (m, 2), train
    a ResNet18_1D with early stopping on validation accuracy, then test
    the best checkpoint.

    Side effects: appends to ``results.txt`` and writes the test
    confusion matrix image and the best weights into ``save_folder``.

    Args:
        m: Number of consecutive frames per block (sequence length seen
            by the network).
        save_folder: Existing directory that receives logs and artifacts.

    Returns:
        (best_val_acc, test_acc), both in percent.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n\n================== m={m}, SNR={SNR_DB} dB ==================\n")
    print(f"[INFO] 使用设备: {device}")

    # 1) Build the blocks.
    X_blocks, y_blocks, label_to_idx = load_and_preprocess_blocks_singlefile_contiguous(
        data_path,
        m=m,
        apply_doppler=apply_doppler,
        target_velocity=v,
        apply_awgn=apply_awgn,
        snr_db=SNR_DB,
        fs=fs,
        fc=fc
    )
    num_classes = len(label_to_idx)

    num_blocks = X_blocks.shape[0]
    print(f"[INFO] 总 block 数: {num_blocks}")

    # 2) Block-level train/test split (disjoint blocks, stratified by class).
    block_idx = np.arange(num_blocks)
    train_block_idx, test_block_idx = train_test_split(
        block_idx,
        test_size=test_size,
        stratify=y_blocks,
        random_state=random_state
    )

    # 3) Hold out part of the training blocks for validation (still
    #    stratified at block level).
    train_block_idx2, val_block_idx = train_test_split(
        train_block_idx,
        test_size=val_size_in_train,
        stratify=y_blocks[train_block_idx],
        random_state=random_state
    )

    print(f"[INFO] 训练 block 数: {len(train_block_idx2)}, 验证 block 数: {len(val_block_idx)}, 测试 block 数: {len(test_block_idx)}")

    # 4) Expand blocks into samples: each block yields sample_len samples.
    #    Fancy indexing copies the selected blocks, so the three splits do
    #    not alias X_blocks.
    X_train, y_train = _flatten_blocks(X_blocks[train_block_idx2], y_blocks[train_block_idx2])
    X_val,   y_val   = _flatten_blocks(X_blocks[val_block_idx],    y_blocks[val_block_idx])
    X_test,  y_test  = _flatten_blocks(X_blocks[test_block_idx],   y_blocks[test_block_idx])
    del X_blocks  # every split owns its own copy; release the full array

    # BatchNorm1d raises in training mode when a batch holds a single
    # value per channel, which happens for small m (sequence length
    # collapses to 1) if the last batch contains exactly one sample.
    # Drop that lone trailing sample only in this case.
    n_train = len(X_train)
    lone_tail = n_train > 1 and n_train % batch_size == 1

    train_loader = _make_loader(X_train, y_train, shuffle=True, drop_last=lone_tail)
    val_loader = _make_loader(X_val, y_val, shuffle=False)
    test_loader = _make_loader(X_test, y_test, shuffle=False)

    # 5) Training.
    model = ResNet18_1D(num_classes=num_classes, in_planes=in_planes, dropout=dropout).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    best_val_acc = 0.0
    # Start from the initial weights so a checkpoint always exists, even
    # if validation accuracy never clears the improvement threshold.
    best_model_wts = _snapshot_weights(model)
    patience_counter = 0

    results_file = os.path.join(save_folder, "results.txt")
    _append_line(results_file, f"\n==== m={m}, SNR={SNR_DB} dB ====\n")

    for epoch in range(num_epochs):
        model.train()
        running_loss, correct_train, total_train = 0.0, 0, 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            total_train += labels.size(0)
            correct_train += (preds == labels).sum().item()

        train_loss = running_loss / max(1, len(train_loader))
        train_acc = 100.0 * correct_train / max(1, total_train)

        # Validation loss is intentionally not computed; only the training
        # loss and the validation accuracy are logged.
        val_acc, _ = evaluate_model(model, val_loader, device, num_classes)
        log_msg = f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, Val Acc={val_acc:.2f}%"
        print(log_msg)
        _append_line(results_file, log_msg + "\n")

        # An epoch counts as an improvement only if it beats the best
        # validation accuracy by more than 0.01 percentage points.
        if val_acc > best_val_acc + 0.01:
            best_val_acc = val_acc
            patience_counter = 0
            best_model_wts = _snapshot_weights(model)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                msg = f"Early stop: {patience} epochs no val improvement."
                print(msg)
                _append_line(results_file, msg + "\n")
                break

        scheduler.step()

    # 6) Test with the best checkpoint.
    model.load_state_dict(best_model_wts)
    test_acc, test_cm = evaluate_model(model, test_loader, device, num_classes)
    print(f"[RESULT] m={m}, best_val={best_val_acc:.2f}%, test_acc={test_acc:.2f}%")

    plot_confusion_matrix(test_cm, save_path=os.path.join(save_folder, f"confusion_matrix_test_m{m}.png"))
    torch.save(best_model_wts, os.path.join(save_folder, f"best_model_m{m}.pth"))

    _append_line(results_file, f"[RESULT] m={m}, best_val={best_val_acc:.4f}, test_acc={test_acc:.4f}\n")

    return best_val_acc, test_acc

# ================= 主程序：不同 m 的准确率曲线 =================
if __name__ == "__main__":
    # Sweep the block length m, train/test one model per value, then save
    # and plot test accuracy versus m.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    doppler_hz = int(compute_doppler_shift(v, fc))
    folder_name = f"{timestamp}_LTEV_XFR_contiguous_m_sweep_SNR{SNR_DB}dB_fd{doppler_hz}"
    save_root = os.path.join(os.getcwd(), "training_results", folder_name)
    os.makedirs(save_root, exist_ok=True)

    m_vals, best_vals, test_accs = [], [], []

    for m in M_LIST:
        # Each setting gets its own sub-folder for logs and artifacts.
        save_folder = os.path.join(save_root, f"m_{m}")
        os.makedirs(save_folder, exist_ok=True)

        best_val_acc, test_acc = train_one_setting(m, save_folder)
        m_vals.append(m)
        best_vals.append(best_val_acc)
        test_accs.append(test_acc)

    # Save the raw curves, then plot m vs. test accuracy.
    np.save(os.path.join(save_root, "m_vals.npy"), np.array(m_vals))
    np.save(os.path.join(save_root, "best_val_accs.npy"), np.array(best_vals))
    np.save(os.path.join(save_root, "test_accs.npy"), np.array(test_accs))

    plt.figure(figsize=(8,5))
    plt.plot(m_vals, test_accs, marker="o", linestyle="-")
    plt.xlabel("m (XFR length)")
    plt.ylabel("Test Accuracy (%)")
    plt.title(f"Contiguous XFR: m vs Test Accuracy @ SNR={SNR_DB} dB")
    plt.grid(True)
    fig_path = os.path.join(save_root, "m_vs_testacc.png")
    plt.savefig(fig_path)
    plt.show()
    print(f"[INFO] Saved curve: {fig_path}")

    # Print the summary; the banner follows SNR_DB instead of a
    # hard-coded value so it stays correct when the constant changes.
    print(f"\n===== Summary (SNR={SNR_DB}dB) =====")
    for m_val, acc in zip(m_vals, test_accs):
        print(f"m={m_val:>4d} -> Test Acc={acc:.2f}%")
