# EEG SSVEP Test

## 一、Dataset类创建

In [71]:
from typing import Any, Callable, Dict, List, Optional, Tuple
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import numpy as np
import scipy

class Benchmark(Dataset):
    
    classes = {
        
    }
    
    stim_event_freq = [8., 8.2, 8.4, 8.6, 8.8, 9., 9.2, 9.4, 9.6, 9.8, 10., 10.2, 10.4, 10.6,
                       10.8, 11., 11.2, 11.4, 11.6, 11.8, 12., 12.2, 12.4, 12.6, 12.8, 13., 13.2, 13.4,
                       13.6, 13.8, 14., 14.2, 14.4, 14.6, 14.8, 15., 15.2, 15.4, 15.6, 15.8]
    
    def __init__(
        self,
        root: str = '',
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super(Dataset).__init__()
        self.root = root
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.data, self.pre_data, self.label = self.load_data()
    
    def load_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        channels = [53, 54, 55, 57, 58, 59, 61, 62, 63]
        channels = [i - 1 for i in channels]
        
        if self.train:
            # train data
            data = np.zeros((200*35, len(channels), 1375))
            pre_data = np.zeros((200*35, len(channels), 125))
            label = np.zeros(200*35, dtype=int)
        else:
            # test data
            data = np.zeros((40*35, len(channels), 1375))
            pre_data = np.zeros((40*35, len(channels), 125))
            label = np.zeros(40*35, dtype=int)
            
        for sub_num in range(1, 36):
            f = scipy.io.loadmat(self.root + f"/S{sub_num}.mat")
            print(f"mat{sub_num}文件大小: {f['data'].shape}")
            for block in range(6):
                for target in range(40):
                    if self.train and block!=5:
                        data[(sub_num - 1) * 200 + block * 40 + target] = f["data"][channels, 125:, target, block]
                        pre_data[(sub_num - 1) * 200 + block * 40 + target] = f["data"][channels, :125, target, block]
                        label[(sub_num - 1) * 200 + block * 40 + target] = int(target + 1)
                    elif not self.train and block==5:
                        data[(sub_num - 1) * 40 + target] = f["data"][channels, 125:, target, block]
                        pre_data[(sub_num - 1) * 40 + target] = f["data"][channels, :125, target, block]
                        label[(sub_num - 1) * 40 + target] = int(target + 1)
        return data, pre_data, label
    
    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, index) -> Tuple[Any, Any]:
        eeg, target = self.data[index], self.label[index]
        
        if self.transform is not None:
            eeg = self.transform(eeg)

        if self.target_transform is not None:
            target = self.target_transform(target)
        
        return eeg, target


In [72]:
# torch.set_default_dtype(torch.float64)
train_data = Benchmark("E:\Datasets\BCI\SSVEP\Benchmark", train = True, 
                       transform = transforms.Compose([
                           transforms.ToTensor(),
                       ]))
test_data = Benchmark("E:\Datasets\BCI\SSVEP\Benchmark", train = False, 
                       transform = transforms.Compose([
                           transforms.ToTensor(),
                       ]))

mat1文件大小: (64, 1500, 40, 6)
mat2文件大小: (64, 1500, 40, 6)
mat3文件大小: (64, 1500, 40, 6)
mat4文件大小: (64, 1500, 40, 6)
mat5文件大小: (64, 1500, 40, 6)
mat6文件大小: (64, 1500, 40, 6)
mat7文件大小: (64, 1500, 40, 6)
mat8文件大小: (64, 1500, 40, 6)
mat9文件大小: (64, 1500, 40, 6)
mat10文件大小: (64, 1500, 40, 6)
mat11文件大小: (64, 1500, 40, 6)
mat12文件大小: (64, 1500, 40, 6)
mat13文件大小: (64, 1500, 40, 6)
mat14文件大小: (64, 1500, 40, 6)
mat15文件大小: (64, 1500, 40, 6)
mat16文件大小: (64, 1500, 40, 6)
mat17文件大小: (64, 1500, 40, 6)
mat18文件大小: (64, 1500, 40, 6)
mat19文件大小: (64, 1500, 40, 6)
mat20文件大小: (64, 1500, 40, 6)
mat21文件大小: (64, 1500, 40, 6)
mat22文件大小: (64, 1500, 40, 6)
mat23文件大小: (64, 1500, 40, 6)
mat24文件大小: (64, 1500, 40, 6)
mat25文件大小: (64, 1500, 40, 6)
mat26文件大小: (64, 1500, 40, 6)
mat27文件大小: (64, 1500, 40, 6)
mat28文件大小: (64, 1500, 40, 6)
mat29文件大小: (64, 1500, 40, 6)
mat30文件大小: (64, 1500, 40, 6)
mat31文件大小: (64, 1500, 40, 6)
mat32文件大小: (64, 1500, 40, 6)
mat33文件大小: (64, 1500, 40, 6)
mat34文件大小: (64, 1500, 40, 6)
mat35文件大小: (64, 1500, 4

In [73]:
print("train_data:\n", len(train_data))
print(f"shape: {train_data[0][0].shape}, type: {type(train_data[0][0])}")

print("test_data:\n", len(test_data))
print(f"shape: {test_data[0][0].shape}, type: {type(test_data[0][0])}")

print(train_data[2][1])

train_data:
 7000
shape: torch.Size([1, 9, 1375]), type: <class 'torch.Tensor'>
test_data:
 1400
shape: torch.Size([1, 9, 1375]), type: <class 'torch.Tensor'>
3


## 二、EEGNet构建

In [74]:
from torchsummary import summary
import time

device = torch.device('cuda' if not torch.cuda.is_available() else 'cpu')

class DepthwiseSeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, depth_multiplier=1):
        super(DepthwiseSeparableConv2d, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels * depth_multiplier, kernel_size=kernel_size,
                                   stride=(1, 1), padding=(0, 0 if kernel_size[0]>kernel_size[1] else max(kernel_size)//2), groups=in_channels, bias=False)
        self.pointwise = nn.Conv2d(in_channels * depth_multiplier, out_channels, kernel_size=(1, 1),
                                   stride=(1, 1), padding=(0, 0), bias=False)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x


class EEGNet(nn.Module):
    def __init__(self, nb_classes, Chans=64, Samples=128, dropoutRate=0.5, kernLength=64,
                 F1=8, D=2, F2=16, norm_rate=0.25, dropoutType='Dropout'):
        super(EEGNet, self).__init__()

        if dropoutType == 'SpatialDropout2D':
            self.dropoutType = nn.Dropout2d
        elif dropoutType == 'Dropout':
            self.dropoutType = nn.Dropout
        else:
            raise ValueError('dropoutType must be one of SpatialDropout2D '
                             'or Dropout, passed as a string.')

        self.block1 = nn.Sequential(
            nn.Conv2d(1, F1, (1, kernLength), padding=(0, kernLength//2), bias=False),
            nn.BatchNorm2d(F1),
            DepthwiseSeparableConv2d(F1, F1, kernel_size=(Chans, 1), depth_multiplier=D),
            nn.BatchNorm2d(F1),
            nn.ELU(),
            nn.AvgPool2d((1, 4)),
            self.dropoutType(dropoutRate)
        )

        self.block2 = nn.Sequential(
            DepthwiseSeparableConv2d(F1, F2, kernel_size=(1, 16), depth_multiplier=1),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.AvgPool2d((1, 8)),
            self.dropoutType(dropoutRate)
        )

        self.block3 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(F2*int(np.floor((np.floor((Samples+1)/4)+1)/8)), nb_classes),
            nn.Softmax(dim=1)
        )

    def forward(self, input):
        x = self.block1(input)
        x = self.block2(x)
        x = self.block3(x)
        return x

learning_rate = 1e-3
nb_classes = 40
Chans = 9
Samples = 1375
model = EEGNet(nb_classes, Chans, Samples)
model = model.to(device)
# torch.set_default_dtype(torch.float64)
print(model.state_dict()['block1.0.weight'].dtype)
print(device)
# print(summary(model, input_size=(1, 9, 1375), device="cuda"))


torch.float64
cpu


In [81]:
from torch.utils.tensorboard import SummaryWriter
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# tensorboard 记录训练结果
# writer = SummaryWriter("./logs_train")

# dataloader
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16, shuffle=False)

# 损失函数
criterion = nn.CrossEntropyLoss()

# 优化器
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, foreach=False)

# Training Loop
start_time = time.time()
total_train_step = 0
num_epochs = 200
for epoch in range(num_epochs):
    model.train()
    for inputs, labels in train_loader:
        labels = labels - 1
        inputs, labels = inputs.to(device), labels.to(device)
        # 优化器清除梯度
        optimizer.zero_grad()
        outputs = model(inputs)
        # 交叉熵计算损失
        loss = criterion(outputs, labels.long())
        # 优化器优化模型
        loss.backward()
        optimizer.step()
        # 误差分析
        total_train_step += 1
        if total_train_step % 100 == 0:
            end_time = time.time()
            print(f"{end_time-start_time}  训练次数：{total_train_step}, Loss：{loss}")
            # writer.add_scalar("train_loss", loss.item(), total_train_step)
            
# 测试
model.eval()
with torch.no_grad():
    total_correct = 0        
    total_samples = 0
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        values, predicted = torch.max(outputs, dim=1)
        total_correct += (predicted == labels).sum().item()
        total_samples += len(labels)
    
    accuracy = total_correct / total_samples
    print(f"Test Accuracy: {accuracy:.4f}")

    torch.save(model, f".\Weights/eeg_gpu.pth")
    # torch.save(light.state_dict(), f"Weights/light_{epoch}.pth")
    print("模型已保存")
# writer.close()

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [30]:
# dataloader = DataLoader(train_data, batch_size=16, shuffle=True)
# for i, (eegs, labels) in enumerate(dataloader):
#     print(f"{i}:", eegs.shape)

0: torch.Size([16, 1, 9, 1375])
1: torch.Size([16, 1, 9, 1375])
2: torch.Size([16, 1, 9, 1375])
3: torch.Size([16, 1, 9, 1375])
4: torch.Size([16, 1, 9, 1375])
5: torch.Size([16, 1, 9, 1375])
6: torch.Size([16, 1, 9, 1375])
7: torch.Size([16, 1, 9, 1375])
8: torch.Size([16, 1, 9, 1375])
9: torch.Size([16, 1, 9, 1375])
10: torch.Size([16, 1, 9, 1375])
11: torch.Size([16, 1, 9, 1375])
12: torch.Size([16, 1, 9, 1375])
13: torch.Size([16, 1, 9, 1375])
14: torch.Size([16, 1, 9, 1375])
15: torch.Size([16, 1, 9, 1375])
16: torch.Size([16, 1, 9, 1375])
17: torch.Size([16, 1, 9, 1375])
18: torch.Size([16, 1, 9, 1375])
19: torch.Size([16, 1, 9, 1375])
20: torch.Size([16, 1, 9, 1375])
21: torch.Size([16, 1, 9, 1375])
22: torch.Size([16, 1, 9, 1375])
23: torch.Size([16, 1, 9, 1375])
24: torch.Size([16, 1, 9, 1375])
25: torch.Size([16, 1, 9, 1375])
26: torch.Size([16, 1, 9, 1375])
27: torch.Size([16, 1, 9, 1375])
28: torch.Size([16, 1, 9, 1375])
29: torch.Size([16, 1, 9, 1375])
30: torch.Size([16, 

[ 8.   8.2  8.4  8.6  8.8  9.   9.2  9.4  9.6  9.8 10.  10.2 10.4 10.6
 10.8 11.  11.2 11.4 11.6 11.8 12.  12.2 12.4 12.6 12.8 13.  13.2 13.4
 13.6 13.8 14.  14.2 14.4 14.6 14.8 15.  15.2 15.4 15.6 15.8]
