In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torchvision import transforms

## 读入数据

In [3]:
from scipy.io import loadmat

# Subject B training recordings; loadmat returns a dict keyed by MATLAB variable name.
DATA_FILE = "dataset/Subject_B_Train.mat"
dataset = loadmat(DATA_FILE)

In [4]:
# Pull the arrays this notebook uses out of the loadmat dict (evaluated in this order).
signals, flashing, labels = (
    dataset["Signal"],        # last axis is treated as EEG channels further down
    dataset["Flashing"],
    dataset["StimulusType"],  # 0 when intensified row / column does not include target character.
)

signals.shape, labels.shape

((85, 7794, 64), (85, 7794))

In [5]:
class EEGDataset(Dataset):
    """Map-style dataset that pairs each example in ``features`` with its label.

    Args:
        features: Indexable along dim 0, one entry per example (anything with
            ``.shape``, e.g. a ``torch.Tensor`` or ``numpy.ndarray``).
        labels: Indexable along dim 0; must have as many entries as ``features``.

    Raises:
        ValueError: If ``features`` and ``labels`` differ in size along dim 0.
    """

    def __init__(self, features: torch.Tensor, labels: torch.Tensor):
        super().__init__()
        # Fail fast: a mismatch would otherwise either silently ignore trailing
        # labels or raise an IndexError late, in the middle of DataLoader iteration.
        if features.shape[0] != labels.shape[0]:
            raise ValueError(
                "features and labels must have the same number of examples, "
                f"got {features.shape[0]} and {labels.shape[0]}"
            )
        self.features = features
        self.labels = labels

    def __len__(self):
        """Number of examples (size of ``features`` along dim 0)."""
        return self.features.shape[0]

    def __getitem__(self, idx: int):
        """Return the ``(feature, label)`` pair at position ``idx``."""
        return self.features[idx], self.labels[idx]


In [None]:
# NOTE(review): `trans` is not used by any cell shown here, and this cell was never
# executed in the saved notebook (In [None]). Either wire it in or delete it.
# Per the torchvision docs, ToTensor treats a 2-D ndarray as (H, W) and returns (1, H, W).
# For one signals[i] of shape (samples, channels) that gives (1, samples, channels),
# which is NOT the (1, channels, samples) layout fed to EEGNet in the smoke test.
trans = transforms.ToTensor()


In [9]:
class MaxNormConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose per-output-filter weights are capped at an L2 norm of ``max_norm``.

    The cap is enforced at the start of every ``forward`` call (training and eval
    alike) by replacing ``self.weight.data`` with a renormalised copy.

    Args:
        *args, **kwargs: Passed through unchanged to ``nn.Conv2d``.
        max_norm: Upper bound on the L2 norm of each ``weight[i]`` slice. Default 1.
    """

    def __init__(self, *args, max_norm=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_norm = max_norm

    def forward(self, x):
        # Each slice along dim 0 (one per output filter) whose L2 norm exceeds
        # max_norm is scaled down to max_norm; smaller slices are left as they are.
        # Going through `.data` keeps this projection out of the autograd graph.
        clipped_weight = torch.renorm(
            self.weight.data, p=2, dim=0, maxnorm=self.max_norm
        )
        self.weight.data = clipped_weight
        conv_output = super().forward(x)
        return conv_output


In [10]:
class EEGNet(nn.Module):
    """EEGNet-style compact CNN for EEG classification.

    Input shape is ``(batch, 1, num_channels, time)``. Output shape is
    ``(batch, num_classes)`` raw logits, with no softmax applied. The classifier is an
    ``nn.LazyLinear``, so its input size is inferred on the first forward pass.
    Later inputs must flatten to that same size.

    Args:
        num_channels: Number of EEG electrodes (``C``); height of the depthwise
            spatial kernel.
        num_classes: Number of output logits (``N``). Default 2.
        F1: Number of temporal filters in the first conv. Default 16.
        D: Depth multiplier of the depthwise spatial conv. Default 2.
        F2: Number of pointwise filters in the separable conv. Defaults to ``F1 * D``.
        dropout: Dropout probability (``p``) used in both blocks. Default 0.5.
        kernel_length: Temporal kernel length of the first conv. Default 64.
        separable_kernel_length: Temporal kernel length of the separable
            (depthwise) conv. Default 16.
    """

    def __init__(self, num_channels, *, num_classes=2, F1=16, D=2, F2=None,
                 dropout=0.5, kernel_length=64, separable_kernel_length=16):
        super().__init__()
        # NOTE(review): T is not read anywhere in this module. LazyLinear infers the
        # flattened size from the input instead. Kept only for attribute compatibility.
        self.T = 120

        self.C = num_channels
        self.F1 = F1  # number of channels after the first (temporal) conv
        self.D = D  # depth multiplier, used in the depthwise conv
        self.p = dropout  # dropout rate
        self.F2 = F1 * D if F2 is None else F2  # pointwise filters in the separable conv
        self.N = num_classes  # number of classes

        # Block 1. Input here is (batch, 1, C, T).
        # Note: padding="same" with an even kernel length may make PyTorch emit a UserWarning.
        b1 = nn.Sequential(
            # Temporal conv: a 1 x kernel_length kernel slid along time on every electrode.
            nn.Conv2d(1, self.F1, (1, kernel_length), padding="same"),
            nn.BatchNorm2d(self.F1, affine=True, eps=1e-3),
            # Depthwise spatial conv spanning all C electrodes (collapses height to 1);
            # D filters per temporal filter, each max-norm constrained.
            MaxNormConv2d(self.F1, self.F1 * self.D,
                          (self.C, 1), max_norm=1, groups=self.F1),
            nn.BatchNorm2d(self.F1 * self.D, affine=True, eps=1e-3),
            nn.ELU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(self.p),
        )

        # Block 2: separable conv = depthwise temporal conv followed by a 1x1 pointwise conv.
        b2 = nn.Sequential(
            nn.Conv2d(self.F1 * self.D, self.F1 * self.D, (1, separable_kernel_length),
                      groups=self.F1 * self.D, padding="same"),  # Separable Conv (depthwise)
            nn.Conv2d(self.F1 * self.D, self.F2, kernel_size=1),  # Separable Conv (pointwise)
            # eps=1e-3 for consistency with the two BatchNorm layers in block 1.
            nn.BatchNorm2d(self.F2, eps=1e-3),
            nn.ELU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(self.p),
            nn.Flatten(),
        )

        self.net = nn.Sequential(b1, b2, nn.LazyLinear(self.N))

    def forward(self, x):
        """Map an input of shape (batch, 1, C, T) to raw logits of shape (batch, N)."""
        return self.net(x)

In [11]:
# Smoke test: push the first recording through a freshly initialised network.
# signals[0] is (samples, channels); EEGNet expects (batch, 1, channels, samples).
n_channels = signals.shape[-1]  # derive from the data instead of hard-coding 64
net = EEGNet(num_channels=n_channels)
# Cast explicitly so the input dtype matches the float32 model weights.
x0 = torch.from_numpy(signals[0]).float().T.reshape(1, 1, n_channels, -1)
# NOTE: this first call fixes the LazyLinear head's in_features from this input's
# length (7794 samples). Later inputs to `net` must flatten to the same size.
y1 = net(x0)

  return F.conv2d(input, weight, bias, self.stride,


In [12]:
# Input recording as (channels, samples) next to the network output (batch, classes).
torch.from_numpy(signals[0]).T.shape, y1.shape

(torch.Size([64, 7794]), torch.Size([1, 2]))

In [13]:
# Raw logits (no softmax) from the freshly initialised network for the first recording.
y1

tensor([[-0.7683, -0.6069]], grad_fn=<AddmmBackward0>)