In [1]:
import torch
from tqdm import tqdm
from Datasets_Test import test_dataloader
from model_seizure import EEGLightNet

In [2]:
def test_model(model, dataloader, device):
    """
    测试模型在给定测试集上的表现:
    参数:
    - model: 已经训练好的模型。
    - dataloader: 测试数据集的数据加载器。
    - device: 训练设备(cpu/cuda)。
    """
    # 切换模型到评估模式
    model.eval()
    model.to(device)
    
    correct_predictions = 0
    total_predictions = 0
    
    # 使用 tqdm 包裹 dataloader 来显示进度条
    progress_bar = tqdm(dataloader, desc='Testing', unit='batch')
    
    # 在不计算梯度的情况下进行前向传播
    with torch.no_grad():
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)
            
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            
            total_predictions += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

            # 更新进度条的信息，显示当前的准确率
            progress_bar.set_postfix({'accuracy': correct_predictions / total_predictions})
    
    accuracy = correct_predictions / total_predictions
    print(f"\nTest Accuracy: {accuracy:.4f}")

In [3]:
# Restore the trained EEGLightNet weights and switch the network to inference mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = "model_seizure_5(92.2).pth"

# weights_only=True restricts unpickling to tensors/primitive containers.
state_dict = torch.load(model_path, weights_only=True, map_location=device)

model = EEGLightNet()
model.load_state_dict(state_dict)
model.eval()

EEGLightNet(
  (input_block): Sequential(
    (0): Conv1d(6, 16, kernel_size=(1,), stride=(1,))
    (1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (feature_extractor): Sequential(
    (0): LightweightMultiScaleConv(
      (conv3): Conv1d(16, 21, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv5): Conv1d(16, 21, kernel_size=(5,), stride=(1,), padding=(2,))
      (conv7): Conv1d(16, 21, kernel_size=(7,), stride=(1,), padding=(3,))
      (attention): Sequential(
        (0): AdaptiveAvgPool1d(output_size=1)
        (1): Flatten(start_dim=1, end_dim=-1)
        (2): Linear(in_features=63, out_features=15, bias=True)
        (3): ReLU()
        (4): Linear(in_features=15, out_features=3, bias=True)
        (5): Softmax(dim=1)
      )
      (bn): BatchNorm1d(63, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): Mish()
    )
    (1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mod

In [4]:
# Evaluate the restored model on the held-out test set (accuracy is printed).
test_model(model=model, dataloader=test_dataloader, device=device);

Testing: 100%|██████████| 2474/2474 [00:14<00:00, 165.40batch/s, accuracy=0.922]


Test Accuracy: 0.9219



