In [1]:
# 下载文件并命名为 Jamming_Classifier.zip
# !wget -O Jamming_Classifier.zip "https://zenodo.org/records/3783969/files/Jamming_Classifier.zip?download=1"

# 解压下载的 ZIP 文件
# !unzip Jamming_Classifier.zip

In [2]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 图像预处理：灰度图 → 256×256 → tensor (float32, 0~1)
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # 强制灰度
    transforms.Resize((256, 256)),                # 缩放到统一大小
    transforms.ToTensor(),                        # 转换为 float32 & 归一化 [0,1]
])

# 设置路径（你原本是调反了，training 应该是训练集）
train_dir = 'Dataset/Jamming_Classifier/Image_training_database'
test_dir = 'Dataset/Jamming_Classifier/Image_testing_database'

# 构建 PyTorch Dataset（仅记录路径 & 标签，不加载数据）
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)

# 构建 DataLoader：每次只加载一小批，避免内存溢出
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

# 获取类别名（ImageFolder 会自动按文件夹排序生成类别标签）
classes = train_dataset.classes
print("类别标签顺序:", classes)

类别标签顺序: ['DME', 'NB', 'NoJam', 'SingleAM', 'SingleChirp', 'SingleFM']


In [3]:
import os
# 统计每个类别的图像数量（保持与 ImageFolder 加载路径一致）
def count_images(folder_path, classes, dataset_name):
    print(f"\n{dataset_name} 圖片數量統計：")
    for label in classes:
        label_dir = os.path.join(folder_path, label)
        if not os.path.exists(label_dir):
            print(f"類別 {label} 不存在，跳過")
            continue
        count = len([file for file in os.listdir(label_dir) if file.endswith('.bmp')])
        print(f"{label}: {count} 張")

# 使用之前定义的路径和类名
count_images(train_dir, classes, '訓練集')
count_images(test_dir, classes, '測試集')


訓練集 圖片數量統計：
DME: 10000 張
NB: 10000 張
NoJam: 10000 張
SingleAM: 10000 張
SingleChirp: 10000 張
SingleFM: 10000 張

測試集 圖片數量統計：
DME: 10000 張
NB: 10000 張
NoJam: 10000 張
SingleAM: 10000 張
SingleChirp: 10000 張
SingleFM: 10000 張


In [4]:
import torch
import torch.nn as nn
from torchvision import models

# 使用预训练的 MobileNetV2 模型
base_model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)

# 修改第一层卷积层（将输入通道数从 3 改为 1）
old_conv = base_model.features[0][0]
new_conv = nn.Conv2d(
    in_channels=1,
    out_channels=old_conv.out_channels,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
    padding=old_conv.padding,
    bias=old_conv.bias is not None
)

# 平均 RGB 权重 → 适配灰度
with torch.no_grad():
    new_conv.weight = nn.Parameter(old_conv.weight.mean(dim=1, keepdim=True))

base_model.features[0][0] = new_conv

# 获取类别数（来自 ImageFolder）
num_classes = len(train_dataset.classes)

# 替换分类头
base_model.classifier = nn.Linear(base_model.classifier[1].in_features, num_classes)

# 冻结特征层参数（可选）
for param in base_model.features.parameters():
    param.requires_grad = False

# 设置 BN 层为 eval 模式，防止 BatchNorm 统计被更新
for m in base_model.features.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()

# 显示模型结构确认
print(base_model)

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

In [5]:
import torch.nn as nn
import torch.optim as optim

# 损失函数：多分类交叉熵
criterion = nn.CrossEntropyLoss()

# 优化器：只优化分类器层（因为前面特征层已经冻结）
optimizer = optim.Adam(base_model.classifier.parameters(), lr=0.001)

# 设置设备（自动选择 GPU/CPU）
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model.to(device)

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

In [6]:
from torch.utils.data import random_split, DataLoader
from torchvision import datasets

# 原始 ImageFolder 的完整训练集
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)

# 计算 80/20 分割
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_subset, val_subset = random_split(
    train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

# 构建 DataLoader（分批加载）
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True, num_workers=0)
val_loader = DataLoader(val_subset, batch_size=32, shuffle=False, num_workers=0)

# 测试集 DataLoader（不变）
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

# 类别标签（来自 ImageFolder 自动识别的 class 文件夹名）
classes = train_dataset.classes
print("类别顺序:", classes)

类别顺序: ['DME', 'NB', 'NoJam', 'SingleAM', 'SingleChirp', 'SingleFM']


In [7]:
epochs = 30
patience = 5
best_val_loss = float('inf')
best_state_dict = None
epochs_no_improve = 0

for epoch in range(1, epochs+1):
    base_model.train()  # 切换模型到训练模式
    train_loss_sum = 0.0
    train_correct = 0
    total_train = 0

    # 训练批次循环
    for X_batch, y_batch in train_loader:
        # 将数据加载到计算设备
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)
        # 前向传播
        # X_batch = X_batch.repeat(1, 3, 1, 1)  # 从灰度 (B,1,256,256) → RGB (B,3,256,256)
        outputs = base_model(X_batch)
        loss = criterion(outputs, y_batch)
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 统计训练损失和准确度
        train_loss_sum += loss.item() * y_batch.size(0)
        _, pred_labels = torch.max(outputs, 1)
        train_correct += (pred_labels == y_batch).sum().item()
        total_train += y_batch.size(0)
    # 计算平均训练损失和准确率
    avg_train_loss = train_loss_sum / total_train
    train_accuracy = train_correct / total_train

    # 验证阶段
    base_model.eval()  # 切换模型到评估模式
    val_loss_sum = 0.0
    val_correct = 0
    total_val = 0
    # 在验证集上不需要计算梯度
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            outputs = base_model(X_batch)
            loss = criterion(outputs, y_batch)
            # 累计验证损失和准确度
            val_loss_sum += loss.item() * y_batch.size(0)
            _, pred_labels = torch.max(outputs, 1)
            val_correct += (pred_labels == y_batch).sum().item()
            total_val += y_batch.size(0)
    avg_val_loss = val_loss_sum / total_val
    val_accuracy = val_correct / total_val

    # 输出本轮训练的结果
    print(f"Epoch {epoch}/{epochs} - "
          f"loss: {avg_train_loss:.4f} - accuracy: {train_accuracy:.4f} - "
          f"val_loss: {avg_val_loss:.4f} - val_accuracy: {val_accuracy:.4f}")

    # Early Stopping 检查：若验证损失改善，则保存最佳模型权重
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_state_dict = base_model.state_dict()  # 保存当前最佳状态
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print("验证集 loss 多次没有改善，提前停止训练。")
            if best_state_dict is not None:
                base_model.load_state_dict(best_state_dict)  # 恢复最佳模型权重
            break

Epoch 1/30 - loss: 0.2401 - accuracy: 0.9341 - val_loss: 0.1139 - val_accuracy: 0.9600
Epoch 2/30 - loss: 0.1268 - accuracy: 0.9552 - val_loss: 0.0990 - val_accuracy: 0.9633
Epoch 3/30 - loss: 0.1096 - accuracy: 0.9613 - val_loss: 0.0845 - val_accuracy: 0.9702
Epoch 4/30 - loss: 0.0999 - accuracy: 0.9645 - val_loss: 0.0796 - val_accuracy: 0.9699
Epoch 5/30 - loss: 0.0925 - accuracy: 0.9666 - val_loss: 0.0765 - val_accuracy: 0.9712
Epoch 6/30 - loss: 0.0882 - accuracy: 0.9687 - val_loss: 0.0718 - val_accuracy: 0.9729
Epoch 7/30 - loss: 0.0859 - accuracy: 0.9688 - val_loss: 0.0742 - val_accuracy: 0.9724
Epoch 8/30 - loss: 0.0815 - accuracy: 0.9696 - val_loss: 0.0704 - val_accuracy: 0.9738
Epoch 9/30 - loss: 0.0790 - accuracy: 0.9714 - val_loss: 0.0682 - val_accuracy: 0.9731
Epoch 10/30 - loss: 0.0797 - accuracy: 0.9705 - val_loss: 0.0883 - val_accuracy: 0.9663
Epoch 11/30 - loss: 0.0776 - accuracy: 0.9715 - val_loss: 0.0793 - val_accuracy: 0.9695
Epoch 12/30 - loss: 0.0752 - accuracy: 0.

In [9]:
base_model.eval()  # 模型设为评估模式
test_correct = 0
total_test = 0

with torch.no_grad():
    for X_batch, y_batch in test_loader:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)
        # X_batch = X_batch.repeat(1, 3, 1, 1)  # 转换为 RGB 三通道
        outputs = base_model(X_batch)
        _, pred_labels = torch.max(outputs, 1)
        test_correct += (pred_labels == y_batch).sum().item()
        total_test += y_batch.size(0)

test_accuracy = test_correct / total_test
print(f"测试集上的准确率: {test_accuracy * 100:.2f}%")


测试集上的准确率: 85.10%
