In [9]:
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from collections import Counter

In [10]:

##########################################
# Due to hardware limitations, part of the data is extracted from the complete data set and divided into 24 parts.
# Each part corresponds to a modulation, with 1200*26=31200 pieces of data
# Therefore, the current data set size is 748800*1024*2
##########################################
def split_data(X_data, Y_data, train_frac=0.8, valid_frac=0.1):
    n_examples = X_data.shape[0]
    n_train = int(n_examples * train_frac)
    n_valid = int(n_examples * valid_frac)

    indices = np.arange(n_examples)
    np.random.shuffle(indices)

    train_idx = indices[:n_train]
    valid_idx = indices[n_train:n_train + n_valid]
    test_idx = indices[n_train + n_valid:]

    X_train, Y_train = X_data[train_idx], Y_data[train_idx]
    X_valid, Y_valid = X_data[valid_idx], Y_data[valid_idx]
    X_test, Y_test = X_data[test_idx], Y_data[test_idx]

    return X_train, Y_train, X_valid, Y_valid, X_test, Y_test

# Create an empty list to stack data
X_train, Y_train, X_valid, Y_valid, X_test, Y_test = [], [], [], [], [], []

for i in range(24):

    filename = f'F:\sourceCode\AMC-2\ExtractDataset\part{i}.h5'
    print(filename)
    with h5py.File(filename, 'r') as f:
        X_data = f['X'][:]
        Y_data = f['Y'][:]

    X_tr, Y_tr, X_val, Y_val, X_te, Y_te = split_data(X_data, Y_data)
    if i == 0:
        X_train, Y_train = X_tr, Y_tr
        X_valid, Y_valid = X_val, Y_val
        X_test, Y_test = X_te, Y_te
    else:
        X_train = np.vstack((X_train, X_tr))
        Y_train = np.vstack((Y_train, Y_tr))
        X_valid = np.vstack((X_valid, X_val))
        Y_valid = np.vstack((Y_valid, Y_val))
        X_test = np.vstack((X_test, X_te))
        Y_test = np.vstack((Y_test, Y_te))

print('shape of X_train：',X_train.shape)
print('shape of Y_train：',Y_train.shape)
print('shape of X_test：',X_test.shape)
print('shape of Y_test：',Y_test.shape)

# for i in range(0,17): #17个数据集文件
#     ########打开文件#######
#     filename = f'F:\ISEP_Learning_Document\Semester3\End-of-track Project\dataset\data_raw\Dataset\part{i}.h5'
#     print(filename)
#     f = h5py.File(filename,'r')
#     ########读取数据#######
#     X_data = f['X'][:]
#     Y_data = f['Y'][:]
#     Z_data = f['Z'][:]
#     f.close()
#     #########分割训练集和测试集#########
#     #每读取到一个数据文件就直接分割为训练集和测试集，防止爆内存
#     n_examples = X_data.shape[0]
#     n_train = int(n_examples * 0.7)   #70%训练样本
#     train_idx = np.random.choice(range(0,n_examples), size=n_train, replace=False)#随机选取训练样本下标
#     test_idx = list(set(range(0,n_examples))-set(train_idx))        #测试样本下标
#     if i == 0:
#         X_train = X_data[train_idx]
#         Y_train = Y_data[train_idx]
#         Z_train = Z_data[train_idx]
#         X_test = X_data[test_idx]
#         Y_test = Y_data[test_idx]
#         Z_test = Z_data[test_idx]
#     else:
#         X_train = np.vstack((X_train, X_data[train_idx]))
#         Y_train = np.vstack((Y_train, Y_data[train_idx]))
#         Z_train = np.vstack((Z_train, Z_data[train_idx]))
#         X_test = np.vstack((X_test, X_data[test_idx]))
#         Y_test = np.vstack((Y_test, Y_data[test_idx]))
#         Z_test = np.vstack((Z_test, Z_data[test_idx]))
# print('训练集X维度：',X_train.shape)
# print('训练集Y维度：',Y_train.shape)
# print('训练集Z维度：',Z_train.shape)
# print('测试集X维度：',X_test.shape)
# print('测试集Y维度：',Y_test.shape)
# print('测试集Z维度：',Z_test.shape)

F:\sourceCode\AMC-2\ExtractDataset\part0.h5
F:\sourceCode\AMC-2\ExtractDataset\part1.h5
F:\sourceCode\AMC-2\ExtractDataset\part2.h5
F:\sourceCode\AMC-2\ExtractDataset\part3.h5
F:\sourceCode\AMC-2\ExtractDataset\part4.h5
F:\sourceCode\AMC-2\ExtractDataset\part5.h5
F:\sourceCode\AMC-2\ExtractDataset\part6.h5
F:\sourceCode\AMC-2\ExtractDataset\part7.h5
F:\sourceCode\AMC-2\ExtractDataset\part8.h5
F:\sourceCode\AMC-2\ExtractDataset\part9.h5
F:\sourceCode\AMC-2\ExtractDataset\part10.h5
F:\sourceCode\AMC-2\ExtractDataset\part11.h5
F:\sourceCode\AMC-2\ExtractDataset\part12.h5
F:\sourceCode\AMC-2\ExtractDataset\part13.h5
F:\sourceCode\AMC-2\ExtractDataset\part14.h5
F:\sourceCode\AMC-2\ExtractDataset\part15.h5
F:\sourceCode\AMC-2\ExtractDataset\part16.h5
F:\sourceCode\AMC-2\ExtractDataset\part17.h5
F:\sourceCode\AMC-2\ExtractDataset\part18.h5
F:\sourceCode\AMC-2\ExtractDataset\part19.h5
F:\sourceCode\AMC-2\ExtractDataset\part20.h5
F:\sourceCode\AMC-2\ExtractDataset\part21.h5
F:\sourceCode\AMC-2\

In [11]:
classes = classes = [ 'OOK','4ASK','8ASK','BPSK', 'QPSK','8PSK','16PSK','32PSK','16APSK', '32APSK','64APSK','128APSK',
        '16QAM', '32QAM','64QAM','128QAM','256QAM','AM-SSB-WC','AM-SSB-SC','AM-DSB-WC',
        'AM-DSB-SC','FM', 'GMSK','OQPSK']

import numpy as np

def check_distribution(labels, classes):
    distribution = {class_name: np.sum(labels[:, i]) for i, class_name in enumerate(classes)}
    return distribution


# Check the distribution in the training set
train_distribution = check_distribution(Y_train, classes)
print("Modulated signal distribution in the training set：")
print(train_distribution)

# Check the distribution in the test set
test_distribution = check_distribution(Y_test, classes)
print("\nModulated signal distribution in test set：")
print(test_distribution)

Modulated signal distribution in the training set：
{'OOK': 24960, '4ASK': 24960, '8ASK': 24960, 'BPSK': 24960, 'QPSK': 24960, '8PSK': 24960, '16PSK': 24960, '32PSK': 24960, '16APSK': 24960, '32APSK': 24960, '64APSK': 24960, '128APSK': 24960, '16QAM': 24960, '32QAM': 24960, '64QAM': 24960, '128QAM': 24960, '256QAM': 24960, 'AM-SSB-WC': 24960, 'AM-SSB-SC': 24960, 'AM-DSB-WC': 24960, 'AM-DSB-SC': 24960, 'FM': 24960, 'GMSK': 24960, 'OQPSK': 24960}

Modulated signal distribution in test set：
{'OOK': 3120, '4ASK': 3120, '8ASK': 3120, 'BPSK': 3120, 'QPSK': 3120, '8PSK': 3120, '16PSK': 3120, '32PSK': 3120, '16APSK': 3120, '32APSK': 3120, '64APSK': 3120, '128APSK': 3120, '16QAM': 3120, '32QAM': 3120, '64QAM': 3120, '128QAM': 3120, '256QAM': 3120, 'AM-SSB-WC': 3120, 'AM-SSB-SC': 3120, 'AM-DSB-WC': 3120, 'AM-DSB-SC': 3120, 'FM': 3120, 'GMSK': 3120, 'OQPSK': 3120}


In [12]:
from torch.utils.data import TensorDataset, DataLoader
import torch

# Convert Y_train, Y_valid and Y_test from one-hot encoding to category index
Y_train_indices = np.argmax(Y_train, axis=1)
Y_valid_indices = np.argmax(Y_valid, axis=1)
Y_test_indices = np.argmax(Y_test, axis=1)

# Convert data to PyTorch Tensor
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
Y_train_tensor = torch.tensor(Y_train_indices, dtype=torch.long)
X_valid_tensor = torch.tensor(X_valid, dtype=torch.float32)
Y_valid_tensor = torch.tensor(Y_valid_indices, dtype=torch.long)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
Y_test_tensor = torch.tensor(Y_test_indices, dtype=torch.long)

# DataLoader
train_dataset = TensorDataset(X_train_tensor, Y_train_tensor)
valid_dataset = TensorDataset(X_valid_tensor, Y_valid_tensor)
test_dataset = TensorDataset(X_test_tensor, Y_test_tensor)

train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=True)

In [13]:
from collections import Counter

def check_distribution(loader):
    counter = Counter()
    for _, targets in loader:
        counter.update(targets.tolist())
    return counter

# Check the distribution of training and test sets
train_distribution = check_distribution(train_loader)
test_distribution = check_distribution(test_loader)

print("Modulated signal distribution in the training set:", train_distribution)
print("Modulation signal distribution in the test set:", test_distribution)

Modulated signal distribution in the training set: Counter({16: 24960, 5: 24960, 13: 24960, 18: 24960, 14: 24960, 9: 24960, 11: 24960, 21: 24960, 12: 24960, 4: 24960, 6: 24960, 7: 24960, 10: 24960, 3: 24960, 15: 24960, 19: 24960, 8: 24960, 22: 24960, 2: 24960, 1: 24960, 0: 24960, 23: 24960, 17: 24960, 20: 24960})
Modulation signal distribution in the test set: Counter({16: 3120, 21: 3120, 15: 3120, 12: 3120, 8: 3120, 4: 3120, 23: 3120, 0: 3120, 17: 3120, 10: 3120, 6: 3120, 5: 3120, 9: 3120, 1: 3120, 18: 3120, 14: 3120, 20: 3120, 3: 3120, 19: 3120, 22: 3120, 11: 3120, 7: 3120, 13: 3120, 2: 3120})


In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Define residual stacking
# class ResidualStack(nn.Module):
#     def __init__(self, input_channels, output_channels, kernel_size, pool_size):
#         super(ResidualStack, self).__init__()
#         self.conv1x1 = nn.Conv2d(input_channels, output_channels, kernel_size=(1, 1), stride=1, padding='same')
#         self.conv2 = nn.Conv2d(output_channels, 32, kernel_size=kernel_size, stride=1, padding='same')
#         self.conv3 = nn.Conv2d(32, output_channels, kernel_size=kernel_size, stride=1, padding='same')
#         self.conv4 = nn.Conv2d(output_channels, 32, kernel_size=kernel_size, stride=1, padding='same')
#         self.conv5 = nn.Conv2d(32, output_channels, kernel_size=kernel_size, stride=1, padding='same')
#         self.pool = nn.MaxPool2d(pool_size)
#
#     def forward(self, x):
#         x = self.conv1x1(x)
#
#         # Residual Unit 1
#         shortcut = x
#         x = F.relu(self.conv2(x))
#         x = self.conv3(x)
#         x += shortcut
#         x = F.relu(x)
#
#         # Residual Unit 2
#         shortcut = x
#         x = F.relu(self.conv4(x))
#         x = self.conv5(x)
#         x += shortcut
#         x = F.relu(x)
#
#         x = self.pool(x)
#         return x
#
#
# # Define the complete model
# class ModulationClassificationModel(nn.Module):
#     def __init__(self, num_classes):
#         super(ModulationClassificationModel, self).__init__()
#         self.res_stack0 = ResidualStack(1, 32, kernel_size=(3, 2), pool_size=(2, 2))
#         self.res_stack1 = ResidualStack(32, 32, kernel_size=(3, 1), pool_size=(2, 1))
#         self.res_stack2 = ResidualStack(32, 32, kernel_size=(3, 1), pool_size=(2, 1))
#         self.res_stack3 = ResidualStack(32, 32, kernel_size=(3, 1), pool_size=(2, 1))
#         self.res_stack4 = ResidualStack(32, 32, kernel_size=(3, 1), pool_size=(2, 1))
#         self.res_stack5 = ResidualStack(32, 32, kernel_size=(3, 1), pool_size=(2, 1))
#         self.flatten = nn.Flatten()
#         self.fc1 = nn.Linear(32 * 16 * 1, 128)  # Adjust the input size based on the output of the last ResidualStack
#         self.alpha_dropout = nn.AlphaDropout(0.2)
#         self.fc2 = nn.Linear(128, num_classes)
#
#     def forward(self, x):
#         x = x.unsqueeze(1)  # Add a channel dimension
#         x = self.res_stack0(x)
#         x = self.res_stack1(x)
#         x = self.res_stack2(x)
#         x = self.res_stack3(x)
#         x = self.res_stack4(x)
#         x = self.res_stack5(x)
#         x = self.flatten(x)
#         x = F.selu(self.fc1(x))
#         x = self.alpha_dropout(x)
#         x = self.fc2(x)
#         return x
import torch
import torch.nn as nn
import torch.nn.functional as F


# 定义残差单元
class ResidualUnit(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size):
        super(ResidualUnit, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=kernel_size, stride=1, padding='same')
        self.conv2 = nn.Conv2d(32, output_channels, kernel_size=kernel_size, stride=1, padding='same')

    def forward(self, x):
        shortcut = x
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        x += shortcut
        return F.relu(x)

# 定义残差堆叠
class ResidualStack(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, pool_size):
        super(ResidualStack, self).__init__()
        self.conv1x1 = nn.Conv2d(input_channels, output_channels, kernel_size=(1, 1), stride=1, padding='same')
        self.res_unit1 = ResidualUnit(output_channels, output_channels, kernel_size)
        self.res_unit2 = ResidualUnit(output_channels, output_channels, kernel_size)
        self.pool = nn.MaxPool2d(pool_size)

    def forward(self, x):
        x = self.conv1x1(x)
        x = self.res_unit1(x)
        x = self.res_unit2(x)
        x = self.pool(x)
        return x

# 定义完整模型
class ModulationClassificationModel(nn.Module):
    def __init__(self, num_classes):
        super(ModulationClassificationModel, self).__init__()
        self.reshape = nn.Unflatten(1, (1, 1024, 2))
        self.res_stack0 = ResidualStack(1, 32, kernel_size=(3, 2), pool_size=(2, 2))
        self.res_stack1 = ResidualStack(32, 32, kernel_size=(3, 1), pool_size=(2, 1))
        self.res_stack2 = ResidualStack(32, 32, kernel_size=(3, 1), pool_size=(2, 1))
        self.res_stack3 = ResidualStack(32, 32, kernel_size=(3, 1), pool_size=(2, 1))
        self.res_stack4 = ResidualStack(32, 32, kernel_size=(3, 1), pool_size=(2, 1))
        self.res_stack5 = ResidualStack(32, 32, kernel_size=(3, 1), pool_size=(2, 1))
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32 * 16 * 1, 128)  # 根据池化后的特征图尺寸调整
        self.alpha_dropout = nn.AlphaDropout(0.3)
        self.fc2 = nn.Linear(128, num_classes)
        self.res_stack0 = ResidualStack(1, 32, kernel_size=(3, 2), pool_size=(2, 2))

    def forward(self, x):
        x = x.unsqueeze(1)  # 增加一个维度以匹配卷积层的输入要求
        x = self.res_stack0(x)
        x = self.res_stack1(x)
        x = self.res_stack2(x)
        x = self.res_stack3(x)
        x = self.res_stack4(x)
        x = self.res_stack5(x)
        x = self.flatten(x)
        x = F.selu(self.fc1(x))
        x = self.alpha_dropout(x)
        x = self.fc2(x)
        return x

num_classes = len(classes)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ModulationClassificationModel(num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.9)
criterion = nn.CrossEntropyLoss()


In [15]:
from tqdm import tqdm


def train_model(model, criterion, optimizer, train_loader, valid_loader, device, num_epochs=100):
    best_valid_loss = float('inf')

    for epoch in range(num_epochs):

        model.train()
        train_loss = 0
        train_correct = 0
        total = 0

        train_bar = tqdm(train_loader, desc=f'Training Epoch {epoch + 1}/{num_epochs}')
        for data, target in train_bar:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, preds = torch.max(output, 1)
            train_correct += preds.eq(target).sum().item()
            total += target.size(0)

            train_bar.set_postfix(loss=train_loss/total, accuracy=100.*train_correct/total)


        model.eval()
        valid_loss = 0
        valid_correct = 0
        valid_total = 0
        valid_bar = tqdm(valid_loader, desc=f'Validating Epoch {epoch + 1}/{num_epochs}')
        with torch.no_grad():
            for data, target in valid_bar:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)
                valid_loss += loss.item()
                _, preds = torch.max(output, 1)
                valid_correct += preds.eq(target).sum().item()
                valid_total += target.size(0)
                valid_bar.set_postfix(loss=valid_loss/valid_total, accuracy=100.*valid_correct/valid_total)

        avg_valid_loss = valid_loss / len(valid_loader)
        valid_accuracy = 100. * valid_correct / valid_total
        print(f'Epoch {epoch + 1}/{num_epochs} - Valid Loss: {avg_valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.2f}%')


        lr_scheduler.step()


        if avg_valid_loss < best_valid_loss:
            best_valid_loss = avg_valid_loss
            torch.save(model.state_dict(), 'best_model.pth')

train_model(model, criterion, optimizer, train_loader, valid_loader, device, num_epochs=100)


Training Epoch 1/100: 100%|██████████| 2340/2340 [02:47<00:00, 13.96it/s, accuracy=28.4, loss=0.00805]
Validating Epoch 1/100: 100%|██████████| 293/293 [00:12<00:00, 22.73it/s, accuracy=38.1, loss=0.00698]


Epoch 1/100 - Valid Loss: 1.7836, Valid Accuracy: 38.13%


Training Epoch 2/100: 100%|██████████| 2340/2340 [02:50<00:00, 13.75it/s, accuracy=38.9, loss=0.00681]
Validating Epoch 2/100: 100%|██████████| 293/293 [00:09<00:00, 29.95it/s, accuracy=40.7, loss=0.00679]


Epoch 2/100 - Valid Loss: 1.7344, Valid Accuracy: 40.74%


Training Epoch 3/100: 100%|██████████| 2340/2340 [02:50<00:00, 13.69it/s, accuracy=42.4, loss=0.00642]
Validating Epoch 3/100: 100%|██████████| 293/293 [00:08<00:00, 32.61it/s, accuracy=41.2, loss=0.00669]


Epoch 3/100 - Valid Loss: 1.7109, Valid Accuracy: 41.24%


Training Epoch 4/100: 100%|██████████| 2340/2340 [02:51<00:00, 13.61it/s, accuracy=44.9, loss=0.00619]
Validating Epoch 4/100: 100%|██████████| 293/293 [00:08<00:00, 33.46it/s, accuracy=46.8, loss=0.00615]


Epoch 4/100 - Valid Loss: 1.5710, Valid Accuracy: 46.82%


Training Epoch 5/100: 100%|██████████| 2340/2340 [02:51<00:00, 13.68it/s, accuracy=47.3, loss=0.00596]
Validating Epoch 5/100: 100%|██████████| 293/293 [00:08<00:00, 32.74it/s, accuracy=47.6, loss=0.00605]


Epoch 5/100 - Valid Loss: 1.5456, Valid Accuracy: 47.61%


Training Epoch 6/100: 100%|██████████| 2340/2340 [02:48<00:00, 13.87it/s, accuracy=48.9, loss=0.00579]
Validating Epoch 6/100: 100%|██████████| 293/293 [00:08<00:00, 34.43it/s, accuracy=47.8, loss=0.00609]


Epoch 6/100 - Valid Loss: 1.5554, Valid Accuracy: 47.78%


Training Epoch 7/100: 100%|██████████| 2340/2340 [02:47<00:00, 13.93it/s, accuracy=50.4, loss=0.00563]
Validating Epoch 7/100: 100%|██████████| 293/293 [00:08<00:00, 34.04it/s, accuracy=50.3, loss=0.00573]


Epoch 7/100 - Valid Loss: 1.4637, Valid Accuracy: 50.33%


Training Epoch 8/100: 100%|██████████| 2340/2340 [02:47<00:00, 13.96it/s, accuracy=51.7, loss=0.00552]
Validating Epoch 8/100: 100%|██████████| 293/293 [00:08<00:00, 34.41it/s, accuracy=53, loss=0.00552]  


Epoch 8/100 - Valid Loss: 1.4112, Valid Accuracy: 53.00%


Training Epoch 9/100: 100%|██████████| 2340/2340 [02:48<00:00, 13.87it/s, accuracy=53.1, loss=0.00542]
Validating Epoch 9/100: 100%|██████████| 293/293 [00:08<00:00, 34.42it/s, accuracy=53.3, loss=0.00556]


Epoch 9/100 - Valid Loss: 1.4208, Valid Accuracy: 53.32%


Training Epoch 10/100: 100%|██████████| 2340/2340 [02:48<00:00, 13.87it/s, accuracy=54.1, loss=0.00532]
Validating Epoch 10/100: 100%|██████████| 293/293 [00:08<00:00, 34.13it/s, accuracy=54.8, loss=0.00553]


Epoch 10/100 - Valid Loss: 1.4121, Valid Accuracy: 54.79%


Training Epoch 11/100: 100%|██████████| 2340/2340 [02:50<00:00, 13.73it/s, accuracy=54.8, loss=0.00527]
Validating Epoch 11/100: 100%|██████████| 293/293 [00:08<00:00, 33.90it/s, accuracy=55.6, loss=0.00531]


Epoch 11/100 - Valid Loss: 1.3576, Valid Accuracy: 55.58%


Training Epoch 12/100: 100%|██████████| 2340/2340 [02:49<00:00, 13.84it/s, accuracy=55.1, loss=0.00523]
Validating Epoch 12/100: 100%|██████████| 293/293 [00:08<00:00, 33.77it/s, accuracy=55.2, loss=0.00535]


Epoch 12/100 - Valid Loss: 1.3666, Valid Accuracy: 55.19%


Training Epoch 13/100: 100%|██████████| 2340/2340 [02:51<00:00, 13.65it/s, accuracy=55.6, loss=0.00518]
Validating Epoch 13/100: 100%|██████████| 293/293 [00:08<00:00, 33.43it/s, accuracy=55.5, loss=0.00536]


Epoch 13/100 - Valid Loss: 1.3699, Valid Accuracy: 55.48%


Training Epoch 14/100:  51%|█████     | 1194/2340 [01:27<01:23, 13.66it/s, accuracy=55.9, loss=0.00515]


KeyboardInterrupt: 

In [None]:
# 测试模型
def test_model(model, test_loader, device):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            test_loss += loss.item()
            _, preds = torch.max(output, 1)
            correct += preds.eq(target).sum().item()
            total += target.size(0)

    avg_test_loss = test_loss / len(test_loader)
    test_accuracy = 100. * correct / total
    print(f'Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

# 加载最佳模型并进行测试
model.load_state_dict(torch.load('best_model.pth'))
test_model(model, test_loader, device)