In [15]:
from joblib import load
import pandas as pd
import numpy as np
import os
from  data_utilities import *
import cv2  # OpenCV 用于调整图像大小和颜色处理
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import gc  # 引入垃圾回收模块
from tqdm.auto import tqdm  # 自动适配环境 导入tqdm进度条库
from collections import defaultdict

dataset_name = 'ManySig'
dataset_path='../ManySig.pkl/'

compact_dataset = load_compact_pkl_dataset(dataset_path,dataset_name)

print("数据集发射机数量：",len(compact_dataset['tx_list']),"具体为：",compact_dataset['tx_list'])
print("数据集接收机数量：",len(compact_dataset['rx_list']),"具体为：",compact_dataset['rx_list'])
print("数据集采集天数：",len(compact_dataset['capture_date_list']),"具体为：",compact_dataset['capture_date_list'])


tx_list = compact_dataset['tx_list']
rx_list = compact_dataset['rx_list']
equalized = 0
capture_date_list = compact_dataset['capture_date_list']


n_tx = len(tx_list)
n_rx = len(rx_list)
print(n_tx,n_rx)


train_dates = ['2021_03_01', '2021_03_08', '2021_03_15']  # 设定你想用的训练日期
# X_train, y_train, X_test, y_test = preprocess_dataset_cross_IQ_blocks(
#     compact_dataset, tx_list, rx_list, 
#     train_dates=train_dates, 
#     max_sig=None,  # 或者 1000
#     equalized=0,
#     block_size=250
# )
X_train, y_train, X_test, y_test = preprocess_dataset_cross_IQ_blocks_date_interleaved(
    compact_dataset, tx_list, train_dates=train_dates,
    max_sig=None, equalized=equalized, block_size=240, y=80
)
print("X_train shape:", X_train.shape)  # (num_blocks*block_size, block_size, 2)
print("y_train shape:", y_train.shape)
print("X_test  shape:", X_test.shape)  # (num_blocks*block_size, block_size, 2)
print("y_test  shape:", y_test.shape)


数据集发射机数量： 6 具体为： ['14-10', '14-7', '20-15', '20-19', '6-15', '8-20']
数据集接收机数量： 12 具体为： ['1-1', '1-19', '14-7', '18-2', '19-2', '2-1', '2-19', '20-1', '3-19', '7-14', '7-7', '8-8']
数据集采集天数： 4 具体为： ['2021_03_01', '2021_03_08', '2021_03_15', '2021_03_23']
6 12
X_train shape: (230400, 240, 2)
y_train shape: (230400,)
X_test  shape: (76800, 240, 2)
y_test  shape: (76800,)


In [None]:
import numpy as np

# === 参数设置 ===
SNR_dB = 20            # 信噪比
fs = 20e6             # 采样率 (Hz)
fc = 2.4e9            # 载波频率 (Hz)
v = 120               # 速度 (m/s)

# === 多普勒频移计算 ===
def compute_doppler_shift(v, fc):
    c = 3e8  # 光速
    return (v / c) * fc

fd = compute_doppler_shift(v, fc)
print(f"[INFO] 多普勒频移 fd = {fd:.2f} Hz")

# === 多普勒变换 ===
def add_doppler_shift(signal, fd, fs):
    num_samples = signal.shape[-1]
    t = np.arange(num_samples) / fs
    doppler_phase = np.exp(1j * 2 * np.pi * fd * t)
    return signal * doppler_phase

# === 加噪声 + 多普勒 的主流程 ===
def preprocess_iq_data(data_real_imag, snr_db, fd, fs):
    # Step 1: 转为复数 IQ，shape: (N, T, 2) → (N, T)
    data_complex = data_real_imag[..., 0] + 1j * data_real_imag[..., 1]

    processed = []
    for sig in data_complex:
        # Step 2: 添加 AWGN 噪声
        signal_std = np.std(sig)
        noise_std = signal_std / (10 ** (snr_db / 20))
        noise = np.random.normal(0, noise_std, sig.shape) + 1j * np.random.normal(0, noise_std, sig.shape)
        noisy = sig + noise

        # Step 3: 添加多普勒频移
        shifted = add_doppler_shift(noisy, fd, fs)

        processed.append(shifted)

    processed = np.array(processed)  # shape: (N, T), complex

    # Step 4: 转回 [I, Q] 实数格式
    processed_real_imag = np.stack([processed.real, processed.imag], axis=-1)  # shape: (N, T, 2)

    return processed_real_imag

X_train_processed = preprocess_iq_data(X_train, snr_db=SNR_dB, fd=fd, fs=fs)
X_test_processed  = preprocess_iq_data(X_test,  snr_db=SNR_dB, fd=fd, fs=fs)

# 查看处理前后前10个点
print("原始信号 I 分量：", X_train[0, :10, 0])
print("处理后信号 I 分量：", X_train_processed[0, :10, 0])


[INFO] 多普勒频移 fd = 960.00 Hz
原始信号 I 分量： [-0.01974548  0.00335704 -0.00375376  0.00189214  0.00845358  0.00012207
 -0.00817892  0.00210577  0.00466931 -0.00378427]
处理后信号 I 分量： [-0.02210168  0.00381852 -0.00721617 -0.00231823  0.01310801  0.00945945
 -0.00852235  0.00570762  0.00971376 -0.0092444 ]


In [17]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns
from torch.utils.data import DataLoader, TensorDataset, Subset
from datetime import datetime
from tqdm import tqdm
from sklearn.model_selection import KFold
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# 假设 SNR_dB 和 fd 已经定义
SNR_dB = globals().get('SNR_dB', 'no')
fd = globals().get('fd', 'no')

# === 模型与训练参数设置 ===
raw_input_dim = 2         # 每个时间步是 I/Q 两个值
model_dim = 128           # Transformer 模型内部维度
num_heads = 4
num_layers = 3
num_classes = len(np.unique(y_train))  # 或 len(tx_list)
dropout = 0.1
batch_size = 512
num_epochs = 100
learning_rate = 1e-4
patience = 5

# === 创建保存目录 ===
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
script_name = "wisig_cross"
folder_name = f"{timestamp}_{script_name}_SNR{SNR_dB}dB_fd{fd}_classes_{num_classes}_Transformer"
save_folder = os.path.join(os.getcwd(), "training_results", folder_name)
os.makedirs(save_folder, exist_ok=True)

results_file = os.path.join(save_folder, "results.txt")
with open(results_file, "w") as f:
    f.write(f"=== Experiment Summary ===\n")
    f.write(f"Timestamp: {timestamp}\n")
    f.write(f"Total Classes: {num_classes}\n")
    f.write(f"SNR: {SNR_dB} dB\n")
    f.write(f"fd (Doppler shift): {fd} Hz\n")
    f.write(f"equalized: {equalized} \n")
    
# === 模型定义 ===
class SignalTransformer(nn.Module):
    def __init__(self, raw_input_dim, model_dim, num_heads, num_layers, num_classes, dropout=0.1):
        super(SignalTransformer, self).__init__()
        self.embedding = nn.Linear(raw_input_dim, model_dim)
        encoder_layer = TransformerEncoderLayer(d_model=model_dim, nhead=num_heads, dropout=dropout, batch_first=True)
        self.encoder = TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(model_dim, num_classes)

    def forward(self, x):
        x = self.embedding(x)
        x = self.encoder(x)
        x = x[:, -1, :]
        x = self.fc(x)
        return x


# === 假设 X_train, y_train, X_test, y_test 都已定义并 shape 为 (N, L, 2) ===
# 若还未定义，可自行加载并 reshape
X_test_tensor = torch.tensor(X_test_processed, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

train_dataset = TensorDataset(torch.tensor(X_train_processed, dtype=torch.float32),
                               torch.tensor(y_train, dtype=torch.long))

# === K折交叉验证训练 ===
n_splits = 5
kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)
fold_results = []
test_results = []

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def compute_grad_norm(model):
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    return total_norm ** 0.5

def moving_average(x, w=5):
    return np.convolve(x, np.ones(w), 'valid') / w

avg_grad_norms_per_fold = []

for fold, (train_idx, val_idx) in enumerate(kfold.split(train_dataset)):
    print(f"\n====== Fold {fold+1}/{n_splits} ======")

    train_subset = Subset(train_dataset, train_idx)
    val_subset = Subset(train_dataset, val_idx)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, drop_last=True)

    model = SignalTransformer(raw_input_dim, model_dim, num_heads, num_layers, num_classes, dropout).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []
    grad_norms = []

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        running_train_loss, correct_train, total_train = 0.0, 0, 0
        batch_grad_norms = []

        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as tepoch:
            for inputs, labels in tepoch:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()

                grad_norm = compute_grad_norm(model)
                batch_grad_norms.append(grad_norm)

                optimizer.step()

                running_train_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total_train += labels.size(0)
                correct_train += (predicted == labels).sum().item()

                tepoch.set_postfix(loss=running_train_loss / (len(train_loader)),
                                   accuracy=100 * correct_train / total_train,
                                   grad_norm=grad_norm)

        epoch_train_loss = running_train_loss / len(train_loader)
        train_losses.append(epoch_train_loss)
        train_accuracies.append(100 * correct_train / total_train)
        avg_grad_norm = np.mean(batch_grad_norms)
        grad_norms.append(avg_grad_norm)

        print(f"Epoch {epoch+1} Average Gradient Norm: {avg_grad_norm:.4f}")

        # === 验证 ===
        model.eval()
        running_val_loss, correct_val, total_val = 0.0, 0, 0

        with torch.no_grad():
            for val_inputs, val_labels in val_loader:
                val_inputs = val_inputs.to(device)
                val_labels = val_labels.to(device)

                val_outputs = model(val_inputs)
                val_loss = criterion(val_outputs, val_labels)
                running_val_loss += val_loss.item()
                _, val_predicted = torch.max(val_outputs, 1)
                total_val += val_labels.size(0)
                correct_val += (val_predicted == val_labels).sum().item()

        epoch_val_loss = running_val_loss / len(val_loader)
        val_losses.append(epoch_val_loss)
        val_accuracies.append(100 * correct_val / total_val)

        with open(results_file, "a") as f:
            f.write(f"Epoch {epoch+1} | Train Acc: {train_accuracies[-1]:.2f}% | Val Acc: {val_accuracies[-1]:.2f}%\n")

        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print("Early stopping")
            break

        scheduler.step()

    fold_results.append(max(val_accuracies))
    avg_grad_norms_per_fold.append(grad_norms)

    # === 绘制 loss 曲线 ===
    plt.figure()
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.plot(moving_average(train_losses), label='Train Loss (Smooth)', linestyle='--')
    plt.plot(moving_average(val_losses), label='Val Loss (Smooth)', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Fold {fold+1} Loss Curve')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(save_folder, f"fold_{fold+1}_loss_curve.png"))
    plt.close()

    # === 绘制 Gradient Norm 曲线 ===
    plt.figure()
    plt.plot(grad_norms, label='Gradient Norm')
    plt.xlabel('Epoch')
    plt.ylabel('Gradient Norm')
    plt.title(f'Fold {fold+1} Gradient Norm')
    plt.grid(True)
    plt.legend()
    plt.savefig(os.path.join(save_folder, f"fold_{fold+1}_grad_norm.png"))
    plt.close()

    # === 测试集评估 ===
    model.eval()
    test_preds, test_true = [], []

    with torch.no_grad():
        for test_inputs, test_labels in test_loader:
            test_inputs = test_inputs.to(device)
            test_labels = test_labels.to(device)

            test_outputs = model(test_inputs)
            _, predicted = torch.max(test_outputs, 1)
            test_preds.extend(predicted.cpu().numpy())
            test_true.extend(test_labels.cpu().numpy())

    test_preds = np.array(test_preds)
    test_true = np.array(test_true)
    test_accuracy = 100.0 * np.sum(test_preds == test_true) / len(test_true)
    test_results.append(test_accuracy)

    with open(results_file, "a") as f:
        f.write(f"Fold {fold+1} Test Accuracy: {test_accuracy:.2f}%\n")

    cm = confusion_matrix(test_true, test_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title(f'Test Confusion Matrix Fold {fold+1}')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig(os.path.join(save_folder, f"fold_{fold+1}_test_confusion_matrix.png"))
    plt.close()

# === 总结结果 ===
avg_val = np.mean(fold_results)
avg_test = np.mean(test_results)

with open(results_file, "a") as f:
    f.write("\n=== Summary ===\n")
    for i in range(n_splits):
        f.write(f"Fold {i+1}: Val Acc = {fold_results[i]:.2f}%, Test Acc = {test_results[i]:.2f}%\n")
    f.write(f"\nAverage Validation Accuracy: {avg_val:.2f}%\n")
    f.write(f"Average Test Accuracy: {avg_test:.2f}%\n")

print("\n=== Final Summary ===")
for i in range(n_splits):
    print(f"Fold {i+1}: Val = {fold_results[i]:.2f}%, Test = {test_results[i]:.2f}%")
print(f"Average Val Accuracy: {avg_val:.2f}%")
print(f"Average Test Accuracy: {avg_test:.2f}%")


Using device: cuda



Epoch 1/100: 100%|██████████| 360/360 [02:15<00:00,  2.65batch/s, accuracy=33.7, grad_norm=8.78, loss=1.46]  


Epoch 1 Average Gradient Norm: 3.1726


Epoch 2/100: 100%|██████████| 360/360 [01:04<00:00,  5.58batch/s, accuracy=63.8, grad_norm=6.76, loss=0.838] 


Epoch 2 Average Gradient Norm: 5.0834


Epoch 3/100: 100%|██████████| 360/360 [01:01<00:00,  5.89batch/s, accuracy=65.7, grad_norm=2.94, loss=0.801] 


Epoch 3 Average Gradient Norm: 4.6465


Epoch 4/100: 100%|██████████| 360/360 [01:00<00:00,  5.91batch/s, accuracy=69.8, grad_norm=9.39, loss=0.733] 


Epoch 4 Average Gradient Norm: 6.4487


Epoch 5/100: 100%|██████████| 360/360 [01:01<00:00,  5.89batch/s, accuracy=74, grad_norm=4.85, loss=0.645]  


Epoch 5 Average Gradient Norm: 8.9564


Epoch 6/100: 100%|██████████| 360/360 [01:01<00:00,  5.88batch/s, accuracy=76.6, grad_norm=5.52, loss=0.59] 


Epoch 6 Average Gradient Norm: 9.3079


Epoch 7/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=77.6, grad_norm=13.1, loss=0.56] 


Epoch 7 Average Gradient Norm: 9.4740


Epoch 8/100: 100%|██████████| 360/360 [01:02<00:00,  5.78batch/s, accuracy=79, grad_norm=2.99, loss=0.528]  


Epoch 8 Average Gradient Norm: 7.3845


Epoch 9/100: 100%|██████████| 360/360 [01:03<00:00,  5.71batch/s, accuracy=79.8, grad_norm=8.38, loss=0.508]


Epoch 9 Average Gradient Norm: 6.6548


Epoch 10/100: 100%|██████████| 360/360 [01:02<00:00,  5.72batch/s, accuracy=80.2, grad_norm=4.16, loss=0.497]


Epoch 10 Average Gradient Norm: 6.3942


Epoch 11/100: 100%|██████████| 360/360 [01:02<00:00,  5.73batch/s, accuracy=80.7, grad_norm=6.02, loss=0.482]


Epoch 11 Average Gradient Norm: 5.0739


Epoch 12/100: 100%|██████████| 360/360 [01:01<00:00,  5.85batch/s, accuracy=81, grad_norm=4.86, loss=0.477]  


Epoch 12 Average Gradient Norm: 5.0714


Epoch 13/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=81, grad_norm=5.55, loss=0.475]  


Epoch 13 Average Gradient Norm: 5.2830


Epoch 14/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=81.2, grad_norm=5, loss=0.469]   


Epoch 14 Average Gradient Norm: 5.0283


Epoch 15/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=81.2, grad_norm=2.32, loss=0.47] 


Epoch 15 Average Gradient Norm: 5.3567


Epoch 16/100: 100%|██████████| 360/360 [01:01<00:00,  5.85batch/s, accuracy=81.4, grad_norm=2.44, loss=0.464]


Epoch 16 Average Gradient Norm: 4.9312


Epoch 17/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=81.5, grad_norm=6.11, loss=0.461]


Epoch 17 Average Gradient Norm: 4.8539


Epoch 18/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=81.5, grad_norm=5.42, loss=0.46] 


Epoch 18 Average Gradient Norm: 5.0332


Epoch 19/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=81.7, grad_norm=3.66, loss=0.456]


Epoch 19 Average Gradient Norm: 4.6622


Epoch 20/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=81.7, grad_norm=3.56, loss=0.456]


Epoch 20 Average Gradient Norm: 4.9492
Early stopping



Epoch 1/100: 100%|██████████| 360/360 [01:01<00:00,  5.88batch/s, accuracy=32.3, grad_norm=5.08, loss=1.49]  


Epoch 1 Average Gradient Norm: 3.6075


Epoch 2/100: 100%|██████████| 360/360 [01:01<00:00,  5.88batch/s, accuracy=62.4, grad_norm=1.8, loss=0.868]  


Epoch 2 Average Gradient Norm: 5.3404


Epoch 3/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=65.8, grad_norm=6.88, loss=0.804] 


Epoch 3 Average Gradient Norm: 4.5333


Epoch 4/100: 100%|██████████| 360/360 [01:01<00:00,  5.89batch/s, accuracy=71, grad_norm=7.58, loss=0.714]  


Epoch 4 Average Gradient Norm: 7.2836


Epoch 5/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=74.7, grad_norm=13, loss=0.634]  


Epoch 5 Average Gradient Norm: 8.1954


Epoch 6/100: 100%|██████████| 360/360 [01:01<00:00,  5.88batch/s, accuracy=76.6, grad_norm=1.82, loss=0.591]


Epoch 6 Average Gradient Norm: 6.6987


Epoch 7/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=77.9, grad_norm=3.32, loss=0.559]


Epoch 7 Average Gradient Norm: 5.8393


Epoch 8/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=78.6, grad_norm=5.05, loss=0.538]


Epoch 8 Average Gradient Norm: 5.2598


Epoch 9/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=79.2, grad_norm=8.6, loss=0.523] 


Epoch 9 Average Gradient Norm: 5.1200


Epoch 10/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=79.5, grad_norm=5.09, loss=0.512]


Epoch 10 Average Gradient Norm: 5.1248


Epoch 11/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=80.3, grad_norm=3.7, loss=0.492] 


Epoch 11 Average Gradient Norm: 3.7544


Epoch 12/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=80.4, grad_norm=2.34, loss=0.488]


Epoch 12 Average Gradient Norm: 4.2888


Epoch 13/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=80.5, grad_norm=4.66, loss=0.484]


Epoch 13 Average Gradient Norm: 4.1782


Epoch 14/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=80.7, grad_norm=5.1, loss=0.479] 


Epoch 14 Average Gradient Norm: 4.2048


Epoch 15/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=80.8, grad_norm=4.59, loss=0.475]


Epoch 15 Average Gradient Norm: 4.1077


Epoch 16/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=81, grad_norm=8.07, loss=0.472]  


Epoch 16 Average Gradient Norm: 4.3222


Epoch 17/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=81.1, grad_norm=3.21, loss=0.468]


Epoch 17 Average Gradient Norm: 4.3429


Epoch 18/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=81.3, grad_norm=3.38, loss=0.464]


Epoch 18 Average Gradient Norm: 4.2839


Epoch 19/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=81.4, grad_norm=8.23, loss=0.461]


Epoch 19 Average Gradient Norm: 4.2260


Epoch 20/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=81.3, grad_norm=2.72, loss=0.461]


Epoch 20 Average Gradient Norm: 4.5753


Epoch 21/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=81.7, grad_norm=4.91, loss=0.451]


Epoch 21 Average Gradient Norm: 3.5526


Epoch 22/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=81.8, grad_norm=2.43, loss=0.449]


Epoch 22 Average Gradient Norm: 3.6525


Epoch 23/100: 100%|██████████| 360/360 [01:01<00:00,  5.85batch/s, accuracy=81.9, grad_norm=3.6, loss=0.446] 


Epoch 23 Average Gradient Norm: 3.5045


Epoch 24/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=81.8, grad_norm=4.88, loss=0.448]


Epoch 24 Average Gradient Norm: 4.0015


Epoch 25/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=82, grad_norm=3.54, loss=0.444]  


Epoch 25 Average Gradient Norm: 3.7514


Epoch 26/100: 100%|██████████| 360/360 [01:01<00:00,  5.88batch/s, accuracy=82, grad_norm=5.05, loss=0.443]  


Epoch 26 Average Gradient Norm: 3.9080


Epoch 27/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=82, grad_norm=2.09, loss=0.444]  


Epoch 27 Average Gradient Norm: 4.1594


Epoch 28/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=82.1, grad_norm=6.98, loss=0.441]


Epoch 28 Average Gradient Norm: 3.8266


Epoch 29/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=82.1, grad_norm=4.29, loss=0.44] 


Epoch 29 Average Gradient Norm: 3.8796


Epoch 30/100: 100%|██████████| 360/360 [01:01<00:00,  5.88batch/s, accuracy=82.2, grad_norm=2.03, loss=0.439]


Epoch 30 Average Gradient Norm: 3.6685
Early stopping



Epoch 1/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=23, grad_norm=21.1, loss=1.68]    


Epoch 1 Average Gradient Norm: 2.1807


Epoch 2/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=62.7, grad_norm=3.06, loss=0.858] 


Epoch 2 Average Gradient Norm: 7.2680


Epoch 3/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=65.9, grad_norm=3.06, loss=0.792] 


Epoch 3 Average Gradient Norm: 6.0985


Epoch 4/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=70.3, grad_norm=5.96, loss=0.725]


Epoch 4 Average Gradient Norm: 7.9116


Epoch 5/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=74.1, grad_norm=17.5, loss=0.64] 


Epoch 5 Average Gradient Norm: 12.4617


Epoch 6/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=77, grad_norm=3.91, loss=0.576]  


Epoch 6 Average Gradient Norm: 14.5772


Epoch 7/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=78.7, grad_norm=10.1, loss=0.534]


Epoch 7 Average Gradient Norm: 14.1127


Epoch 8/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=79.5, grad_norm=10.8, loss=0.514]


Epoch 8 Average Gradient Norm: 13.7790


Epoch 9/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=80.1, grad_norm=10.7, loss=0.498]


Epoch 9 Average Gradient Norm: 12.4503


Epoch 10/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=80.4, grad_norm=15.1, loss=0.491]


Epoch 10 Average Gradient Norm: 11.9379


Epoch 11/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=80.9, grad_norm=5.86, loss=0.475]


Epoch 11 Average Gradient Norm: 9.9021


Epoch 12/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=81.1, grad_norm=9.66, loss=0.471]


Epoch 12 Average Gradient Norm: 10.0615


Epoch 13/100: 100%|██████████| 360/360 [01:01<00:00,  5.88batch/s, accuracy=81, grad_norm=19.5, loss=0.474]  


Epoch 13 Average Gradient Norm: 11.8106


Epoch 14/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=81.3, grad_norm=3.71, loss=0.466]


Epoch 14 Average Gradient Norm: 10.3996


Epoch 15/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=81.3, grad_norm=5.12, loss=0.465]


Epoch 15 Average Gradient Norm: 10.5235


Epoch 16/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=81.3, grad_norm=5.41, loss=0.464]


Epoch 16 Average Gradient Norm: 10.7149


Epoch 17/100: 100%|██████████| 360/360 [01:01<00:00,  5.88batch/s, accuracy=81.6, grad_norm=20.6, loss=0.459]


Epoch 17 Average Gradient Norm: 9.7803


Epoch 18/100: 100%|██████████| 360/360 [01:01<00:00,  5.88batch/s, accuracy=81.7, grad_norm=7.74, loss=0.457]


Epoch 18 Average Gradient Norm: 9.8797


Epoch 19/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=81.5, grad_norm=3.22, loss=0.458]


Epoch 19 Average Gradient Norm: 10.7773
Early stopping



Epoch 1/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=42, grad_norm=6.7, loss=1.29]     


Epoch 1 Average Gradient Norm: 4.0090


Epoch 2/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=64.3, grad_norm=7.01, loss=0.826] 


Epoch 2 Average Gradient Norm: 4.3219


Epoch 3/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=65.5, grad_norm=2.05, loss=0.802] 


Epoch 3 Average Gradient Norm: 3.8389


Epoch 4/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=70, grad_norm=2.15, loss=0.733]   


Epoch 4 Average Gradient Norm: 4.4958


Epoch 5/100: 100%|██████████| 360/360 [01:01<00:00,  5.87batch/s, accuracy=75, grad_norm=3.46, loss=0.635]  


Epoch 5 Average Gradient Norm: 6.2155


Epoch 6/100: 100%|██████████| 360/360 [01:01<00:00,  5.86batch/s, accuracy=77.1, grad_norm=2.38, loss=0.587]


Epoch 6 Average Gradient Norm: 5.7836


Epoch 7/100: 100%|██████████| 360/360 [01:00<00:00,  5.90batch/s, accuracy=78, grad_norm=5.48, loss=0.565]  


Epoch 7 Average Gradient Norm: 5.4548


Epoch 8/100: 100%|██████████| 360/360 [01:00<00:00,  5.91batch/s, accuracy=78.5, grad_norm=3.04, loss=0.553]


Epoch 8 Average Gradient Norm: 5.1269


Epoch 9/100: 100%|██████████| 360/360 [01:01<00:00,  5.90batch/s, accuracy=78.7, grad_norm=5.12, loss=0.544]


Epoch 9 Average Gradient Norm: 4.8361


Epoch 10/100: 100%|██████████| 360/360 [01:00<00:00,  5.91batch/s, accuracy=79, grad_norm=3.65, loss=0.536]  


Epoch 10 Average Gradient Norm: 4.7221


Epoch 11/100: 100%|██████████| 360/360 [01:00<00:00,  5.94batch/s, accuracy=79.6, grad_norm=6.08, loss=0.519] 


Epoch 11 Average Gradient Norm: 3.3964


Epoch 12/100: 100%|██████████| 360/360 [01:00<00:00,  5.91batch/s, accuracy=79.7, grad_norm=5.9, loss=0.518] 


Epoch 12 Average Gradient Norm: 3.8809


Epoch 13/100: 100%|██████████| 360/360 [01:01<00:00,  5.88batch/s, accuracy=79.6, grad_norm=2.91, loss=0.517]


Epoch 13 Average Gradient Norm: 4.2335


Epoch 14/100: 100%|██████████| 360/360 [01:00<00:00,  5.91batch/s, accuracy=79.9, grad_norm=8.22, loss=0.51] 


Epoch 14 Average Gradient Norm: 3.7839


Epoch 15/100: 100%|██████████| 360/360 [01:01<00:00,  5.90batch/s, accuracy=79.8, grad_norm=1.73, loss=0.511] 


Epoch 15 Average Gradient Norm: 4.1752


Epoch 16/100: 100%|██████████| 360/360 [01:00<00:00,  5.93batch/s, accuracy=80.1, grad_norm=2.89, loss=0.504]


Epoch 16 Average Gradient Norm: 3.6539


Epoch 17/100: 100%|██████████| 360/360 [01:00<00:00,  5.90batch/s, accuracy=80.1, grad_norm=5.55, loss=0.503]


Epoch 17 Average Gradient Norm: 4.2206


Epoch 18/100: 100%|██████████| 360/360 [01:00<00:00,  5.90batch/s, accuracy=80.2, grad_norm=3.7, loss=0.501] 


Epoch 18 Average Gradient Norm: 4.1979


Epoch 19/100: 100%|██████████| 360/360 [01:00<00:00,  5.94batch/s, accuracy=80.2, grad_norm=6, loss=0.497]   


Epoch 19 Average Gradient Norm: 3.9395


Epoch 20/100: 100%|██████████| 360/360 [01:00<00:00,  5.91batch/s, accuracy=80.4, grad_norm=2.58, loss=0.491]


Epoch 20 Average Gradient Norm: 3.8175


Epoch 21/100: 100%|██████████| 360/360 [01:00<00:00,  5.91batch/s, accuracy=80.6, grad_norm=2.33, loss=0.485]


Epoch 21 Average Gradient Norm: 3.4592


Epoch 22/100: 100%|██████████| 360/360 [01:00<00:00,  5.92batch/s, accuracy=80.7, grad_norm=1.77, loss=0.481]


Epoch 22 Average Gradient Norm: 3.4394


Epoch 23/100: 100%|██████████| 360/360 [01:01<00:00,  5.89batch/s, accuracy=80.8, grad_norm=4.36, loss=0.479]


Epoch 23 Average Gradient Norm: 3.4708


Epoch 24/100: 100%|██████████| 360/360 [01:01<00:00,  5.90batch/s, accuracy=80.8, grad_norm=6.46, loss=0.476]


Epoch 24 Average Gradient Norm: 3.7732


Epoch 25/100: 100%|██████████| 360/360 [01:00<00:00,  5.91batch/s, accuracy=81, grad_norm=3.4, loss=0.474]   


Epoch 25 Average Gradient Norm: 3.7466


Epoch 26/100: 100%|██████████| 360/360 [01:00<00:00,  5.93batch/s, accuracy=81.1, grad_norm=3.13, loss=0.47] 


Epoch 26 Average Gradient Norm: 3.6787


Epoch 27/100: 100%|██████████| 360/360 [01:00<00:00,  5.91batch/s, accuracy=81.2, grad_norm=4.45, loss=0.467]


Epoch 27 Average Gradient Norm: 3.8317


Epoch 28/100: 100%|██████████| 360/360 [01:00<00:00,  5.91batch/s, accuracy=81.2, grad_norm=2.95, loss=0.465]


Epoch 28 Average Gradient Norm: 3.8851


Epoch 29/100: 100%|██████████| 360/360 [01:01<00:00,  5.90batch/s, accuracy=81.3, grad_norm=3.59, loss=0.464]


Epoch 29 Average Gradient Norm: 3.9489


Epoch 30/100: 100%|██████████| 360/360 [01:00<00:00,  5.90batch/s, accuracy=81.4, grad_norm=5.01, loss=0.46] 


Epoch 30 Average Gradient Norm: 3.6074


Epoch 31/100: 100%|██████████| 360/360 [01:00<00:00,  5.93batch/s, accuracy=81.6, grad_norm=1.76, loss=0.457]


Epoch 31 Average Gradient Norm: 3.6390


Epoch 32/100: 100%|██████████| 360/360 [01:00<00:00,  5.91batch/s, accuracy=81.6, grad_norm=5.5, loss=0.456] 


Epoch 32 Average Gradient Norm: 3.6307


Epoch 33/100: 100%|██████████| 360/360 [01:00<00:00,  5.91batch/s, accuracy=81.6, grad_norm=1.83, loss=0.455]


Epoch 33 Average Gradient Norm: 3.5581


Epoch 34/100: 100%|██████████| 360/360 [01:00<00:00,  5.93batch/s, accuracy=81.6, grad_norm=2.76, loss=0.455]


Epoch 34 Average Gradient Norm: 3.5129


Epoch 35/100: 100%|██████████| 360/360 [01:00<00:00,  5.91batch/s, accuracy=81.7, grad_norm=5.72, loss=0.453]


Epoch 35 Average Gradient Norm: 3.5390


Epoch 36/100: 100%|██████████| 360/360 [01:00<00:00,  5.94batch/s, accuracy=81.8, grad_norm=5.68, loss=0.452]


Epoch 36 Average Gradient Norm: 3.5634


Epoch 37/100: 100%|██████████| 360/360 [01:00<00:00,  5.93batch/s, accuracy=81.7, grad_norm=3.12, loss=0.451]


Epoch 37 Average Gradient Norm: 3.6075
Early stopping



Epoch 1/100: 100%|██████████| 360/360 [01:00<00:00,  5.90batch/s, accuracy=34.2, grad_norm=2.94, loss=1.45]  


Epoch 1 Average Gradient Norm: 3.7128


Epoch 2/100: 100%|██████████| 360/360 [01:00<00:00,  5.91batch/s, accuracy=64.1, grad_norm=6.37, loss=0.831] 


Epoch 2 Average Gradient Norm: 5.3468


Epoch 3/100: 100%|██████████| 360/360 [01:00<00:00,  5.90batch/s, accuracy=65.5, grad_norm=5.49, loss=0.799] 


Epoch 3 Average Gradient Norm: 4.2639


Epoch 4/100: 100%|██████████| 360/360 [01:00<00:00,  5.93batch/s, accuracy=69.8, grad_norm=6.52, loss=0.729] 


Epoch 4 Average Gradient Norm: 5.5938


Epoch 5/100: 100%|██████████| 360/360 [01:00<00:00,  5.91batch/s, accuracy=75.5, grad_norm=1.74, loss=0.619]


Epoch 5 Average Gradient Norm: 7.0092


Epoch 6/100: 100%|██████████| 360/360 [01:00<00:00,  5.95batch/s, accuracy=77.4, grad_norm=5.51, loss=0.574]


Epoch 6 Average Gradient Norm: 6.2066


Epoch 7/100: 100%|██████████| 360/360 [01:01<00:00,  5.90batch/s, accuracy=78.4, grad_norm=4.6, loss=0.546] 


Epoch 7 Average Gradient Norm: 5.1756


Epoch 8/100: 100%|██████████| 360/360 [01:00<00:00,  5.94batch/s, accuracy=79, grad_norm=6.61, loss=0.527]  


Epoch 8 Average Gradient Norm: 5.0349


Epoch 9/100: 100%|██████████| 360/360 [01:00<00:00,  5.90batch/s, accuracy=79.4, grad_norm=4.88, loss=0.517]


Epoch 9 Average Gradient Norm: 5.3049


Epoch 10/100: 100%|██████████| 360/360 [01:00<00:00,  5.92batch/s, accuracy=80.1, grad_norm=8.47, loss=0.499]


Epoch 10 Average Gradient Norm: 4.4294


Epoch 11/100: 100%|██████████| 360/360 [01:00<00:00,  5.93batch/s, accuracy=80.5, grad_norm=2.78, loss=0.485]


Epoch 11 Average Gradient Norm: 3.7312


Epoch 12/100: 100%|██████████| 360/360 [01:00<00:00,  5.93batch/s, accuracy=80.7, grad_norm=3.79, loss=0.481]


Epoch 12 Average Gradient Norm: 4.0609


Epoch 13/100: 100%|██████████| 360/360 [01:00<00:00,  5.90batch/s, accuracy=80.9, grad_norm=5.09, loss=0.474]


Epoch 13 Average Gradient Norm: 3.7780


Epoch 14/100: 100%|██████████| 360/360 [01:00<00:00,  5.93batch/s, accuracy=80.9, grad_norm=2.08, loss=0.473]


Epoch 14 Average Gradient Norm: 4.1929


Epoch 15/100: 100%|██████████| 360/360 [01:00<00:00,  5.90batch/s, accuracy=81.1, grad_norm=6.91, loss=0.469]


Epoch 15 Average Gradient Norm: 4.1456


Epoch 16/100: 100%|██████████| 360/360 [01:00<00:00,  5.91batch/s, accuracy=81.2, grad_norm=4.89, loss=0.464]


Epoch 16 Average Gradient Norm: 3.9048


Epoch 17/100: 100%|██████████| 360/360 [01:00<00:00,  5.92batch/s, accuracy=81.3, grad_norm=4.22, loss=0.462]


Epoch 17 Average Gradient Norm: 4.1462


Epoch 18/100: 100%|██████████| 360/360 [01:00<00:00,  5.95batch/s, accuracy=81.4, grad_norm=5.57, loss=0.46] 


Epoch 18 Average Gradient Norm: 4.2861


Epoch 19/100: 100%|██████████| 360/360 [01:00<00:00,  5.92batch/s, accuracy=81.6, grad_norm=5.56, loss=0.456]


Epoch 19 Average Gradient Norm: 3.9045


Epoch 20/100: 100%|██████████| 360/360 [01:01<00:00,  5.90batch/s, accuracy=81.5, grad_norm=3.02, loss=0.456]


Epoch 20 Average Gradient Norm: 4.1155
Early stopping

=== Final Summary ===
Fold 1: Val = 81.62%, Test = 78.82%
Fold 2: Val = 83.15%, Test = 80.81%
Fold 3: Val = 81.51%, Test = 76.85%
Fold 4: Val = 82.15%, Test = 79.67%
Fold 5: Val = 82.13%, Test = 79.73%
Average Val Accuracy: 82.11%
Average Test Accuracy: 79.18%


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns
from torch.utils.data import DataLoader, TensorDataset, random_split
from datetime import datetime
from tqdm import tqdm
import random
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from sklearn.model_selection import train_test_split

# 假设 SNR_dB, fd, equalized 已定义
SNR_dB = globals().get('SNR_dB', 'no')
fd = globals().get('fd', 'no')
equalized = globals().get('equalized', 'no')

# 假设 X_train_processed, y_train, X_test_processed, y_test 已定义
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# === 模型定义 ===
class SignalTransformer(nn.Module):
    def __init__(self, raw_input_dim, model_dim, num_heads, num_layers, num_classes, dropout=0.1):
        super(SignalTransformer, self).__init__()
        self.embedding = nn.Linear(raw_input_dim, model_dim)
        encoder_layer = TransformerEncoderLayer(
            d_model=model_dim, nhead=num_heads, dropout=dropout, batch_first=True
        )
        self.encoder = TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(model_dim, num_classes)

    def forward(self, x):
        x = self.embedding(x)
        x = self.encoder(x)
        x = x[:, -1, :]
        x = self.fc(x)
        return x

# === 数据准备 ===
X_train_tensor = torch.tensor(X_train_processed, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
X_test_tensor = torch.tensor(X_test_processed, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)

num_classes = len(np.unique(y_train))

# 划分训练集 / 验证集
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

# === 参数空间 ===
param_space = {
    "model_dim": [128, 256, 512],
    "num_heads": [2, 4, 8],
    "num_layers": [1, 2, 3],
    "dropout": [0.1, 0.3, 0.5],
    "learning_rate": [1e-3, 5e-4, 1e-4],
    "batch_size": [128, 256, 512]
}
num_search = 100  # 随机搜索次数
patience = 5
raw_input_dim = 2
num_epochs = 300

results_summary = []
best_config = None
best_val_acc = 0

# 计算梯度范数
def compute_grad_norm(model):
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    return total_norm ** 0.5

# 平滑曲线
def moving_average(x, w=5):
    return np.convolve(x, np.ones(w), 'valid') / w

# === 随机搜索 ===
for search_idx in range(num_search):
    config = {k: random.choice(v) for k, v in param_space.items()}
    print(f"\n=== Random Search {search_idx+1}/{num_search} ===")
    print(f"Params: {config}")

    # 创建保存目录
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    script_name = "wisig_cross_random"
    folder_name = f"{timestamp}_{script_name}_SNR{SNR_dB}"
    save_folder = os.path.join("search_results", folder_name)
    os.makedirs(save_folder, exist_ok=True)
    results_file = os.path.join(save_folder, "results.txt")

    with open(results_file, "w") as f:
        f.write(f"=== Hyperparameters ===\n{config}\n")

    # DataLoader
    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False)

    # 模型 & 优化器
    model = SignalTransformer(raw_input_dim, config["model_dim"], config["num_heads"],
                              config["num_layers"], num_classes, config["dropout"]).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []
    grad_norms = []

    best_val = 0
    patience_counter = 0
    best_model_wts = None

    for epoch in range(num_epochs):
    # 训练
        model.train()
        running_loss, correct_train, total_train = 0.0, 0, 0
        batch_grad_norms = []

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            grad_norm = compute_grad_norm(model)
            batch_grad_norms.append(grad_norm)
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        train_losses.append(running_loss / len(train_loader))
        train_accuracies.append(100 * correct_train / total_train)
        grad_norms.append(np.mean(batch_grad_norms))

        # 验证
        model.eval()
        correct_val, total_val = 0, 0
        val_loss_sum = 0.0
        with torch.no_grad():
            for val_inputs, val_labels in val_loader:
                val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
                val_outputs = model(val_inputs)
                val_loss = criterion(val_outputs, val_labels)
                val_loss_sum += val_loss.item()
                _, val_pred = torch.max(val_outputs, 1)
                total_val += val_labels.size(0)
                correct_val += (val_pred == val_labels).sum().item()

        val_acc = 100 * correct_val / total_val
        val_losses.append(val_loss_sum / len(val_loader))
        val_accuracies.append(val_acc)

        # 早停
        if val_acc > best_val:
            best_val = val_acc
            patience_counter = 0
            best_model_wts = model.state_dict()
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

        scheduler.step()


    # 恢复最佳权重
    if best_model_wts:
        model.load_state_dict(best_model_wts)

    # 测试集
    model.eval()
    test_preds, test_true = [], []
    with torch.no_grad():
        for test_inputs, test_labels in test_loader:
            test_inputs, test_labels = test_inputs.to(device), test_labels.to(device)
            test_outputs = model(test_inputs)
            _, predicted = torch.max(test_outputs, 1)
            test_preds.extend(predicted.cpu().numpy())
            test_true.extend(test_labels.cpu().numpy())

    test_acc = 100 * np.sum(np.array(test_preds) == np.array(test_true)) / len(test_true)
    with open(results_file, "a") as f:
        f.write(f"Params: {config} \n Val Acc: {val_acc:.2f}% | Test Acc: {test_acc:.2f}%\n")

    # 控制台即时输出
    print(f"[Result] Config {search_idx+1}/{num_search} - Val Acc: {val_acc:.2f}%, Test Acc: {test_acc:.2f}%")

    # 记录结果
    results_summary.append((config, best_val, test_acc))
    if best_val > best_val_acc:
        best_val_acc = best_val
        best_config = (config, test_acc)


# === 最佳结果 ===
print("\n=== Best Hyperparameters ===")
print(best_config)
