In [1]:
import copy
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchsummary import summary
from torch.utils.data import DataLoader, TensorDataset
from DL_lab2_EEGclassification_model import EEGNet, DeepConvNet
from dataloader import read_bci_data

# Use the CUDA device when PyTorch can see a GPU; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def test(model, dataloader=None):
    """Evaluate ``model`` on labelled batches, print and return its accuracy.

    Args:
        model: A ``torch.nn.Module`` that maps a batch of inputs to per-class
            scores; the predicted class is the argmax over dimension 1.
        dataloader: Iterable yielding ``(data, label)`` batches. Defaults to
            the module-level ``test_dataloader`` built in the ``__main__``
            section, so the original ``test(model)`` call keeps working.

    Returns:
        Accuracy as a percentage in ``[0, 100]`` (also printed).

    Raises:
        ValueError: If the dataloader yields no samples (accuracy undefined).
    """
    if dataloader is None:
        # Backward compatibility: fall back to the script-level global.
        dataloader = test_dataloader

    # Put the inputs wherever the model's weights live, so data and model can
    # never end up on different devices. Parameter-free models run on the CPU.
    first_param = next(model.parameters(), None)
    target_device = (first_param.device if first_param is not None
                     else torch.device('cpu'))

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for data, label in dataloader:
            data = data.to(target_device, dtype=torch.float)
            # Labels may be stored as floats; eq() needs integer class ids.
            label = label.to(target_device, dtype=torch.long)
            predicts = model(data)
            correct += predicts.argmax(dim=1).eq(label).sum().item()
            # Count what was actually evaluated (correct under drop_last or
            # custom samplers, unlike len(dataset)).
            total += label.size(0)

    if total == 0:
        raise ValueError('test(): the dataloader yielded no samples')

    test_acc = 100.0 * correct / total
    print(f'Test acc: {test_acc:.3f}')
    return test_acc


# The name starts with "test", so pytest would otherwise try to collect this
# helper as a test case (and fail on the missing `model` fixture).
test.__test__ = False
    
if __name__ == '__main__':
    batch_size = 64

    # Only the test split returned by read_bci_data() is needed here.
    _, _, test_data, test_label = read_bci_data()
    dataset = TensorDataset(torch.Tensor(test_data), torch.Tensor(test_label))
    # Bound at module level on purpose: test(model) reads `test_dataloader`
    # as a global.
    test_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Layer-by-layer summaries for an input of size (1, 2, 750);
    # torchsummary prepends the batch dimension itself.
    summary(EEGNet().to(device), (1, 2, 750))
    summary(DeepConvNet().to(device), (1, 2, 750))

    # Architecture name -> constructor; activation name -> nn activation class.
    # The names double as parts of the checkpoint file names below.
    model_classes = {'EEGNet': EEGNet, 'DeepConvNet': DeepConvNet}
    activations = {'ReLU': nn.ReLU, 'ELU': nn.ELU, 'LeakyReLU': nn.LeakyReLU}

    for model_name, model_cls in model_classes.items():
        for activation, activation_cls in activations.items():
            # Each model gets its own fresh activation module instance.
            model = model_cls(activation=activation_cls())

            print(f'{model_name} {activation}:')
            # map_location lets a checkpoint saved on a GPU be loaded on a
            # CPU-only machine (and vice versa) by placing tensors on `device`.
            state_dict = torch.load(f'./{model_name}_{activation}.pt',
                                    map_location=device)
            model.load_state_dict(state_dict)
            model = model.to(device)
            test(model)
            print('----------------------------')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 2, 750]             816
       BatchNorm2d-2           [-1, 16, 2, 750]              32
            Conv2d-3           [-1, 32, 1, 750]              64
       BatchNorm2d-4           [-1, 32, 1, 750]              64
              ReLU-5           [-1, 32, 1, 750]               0
              ReLU-6           [-1, 32, 1, 750]               0
         AvgPool2d-7           [-1, 32, 1, 187]               0
           Dropout-8           [-1, 32, 1, 187]               0
            Conv2d-9           [-1, 32, 1, 187]          15,360
      BatchNorm2d-10           [-1, 32, 1, 187]              64
             ReLU-11           [-1, 32, 1, 187]               0
             ReLU-12           [-1, 32, 1, 187]               0
        AvgPool2d-13            [-1, 32, 1, 23]               0
          Dropout-14            [-1, 32