-
Notifications
You must be signed in to change notification settings - Fork 4
/
EEGNet.py
52 lines (46 loc) · 1.88 KB
/
EEGNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn as nn
class EEGNetModel(nn.Module):  # EEGNet-8,2
    """EEGNet for EEG trial classification.

    Expects input of shape (batch, 1, chans, time_points) and returns raw
    class logits of shape (batch, classes).

    Architecture (Lawhern et al., 2018, "EEGNet: A Compact Convolutional
    Network for EEG-based Brain-Computer Interfaces"):
      block1: temporal conv filters (padding='same' keeps the time length)
      block2: depthwise spatial conv collapsing the electrode axis to 1
      block3: separable conv (depthwise temporal + 1x1 pointwise), then FC

    Args:
        chans: number of EEG electrodes (input height).
        classes: number of output classes.
        time_points: number of samples per trial (input width).
        temp_kernel: temporal conv kernel length in block1.
        f1: number of temporal filters.
        f2: number of separable-conv output filters.
        d: depth multiplier of the depthwise spatial conv.
        pk1, pk2: average-pooling widths after block2 / block3.
        dropout_rate: dropout probability in blocks 2 and 3.
        max_norm1: max L2 norm for the depthwise spatial conv weights.
        max_norm2: max L2 norm for the final linear layer weights.
    """

    def __init__(self, chans=22, classes=4, time_points=1001, temp_kernel=32,
                 f1=16, f2=32, d=2, pk1=8, pk2=16, dropout_rate=0.5, max_norm1=1, max_norm2=0.25):
        super().__init__()
        # FC input features: the time axis shrinks by pk1 then pk2 (floor
        # division composes: (t//pk1)//pk2 == t//(pk1*pk2) for positive ints),
        # and block3 leaves f2 channels with height 1.
        linear_size = (time_points // (pk1 * pk2)) * f2
        # Temporal Filters
        self.block1 = nn.Sequential(
            nn.Conv2d(1, f1, (1, temp_kernel), padding='same', bias=False),
            nn.BatchNorm2d(f1),
        )
        # Spatial Filters: depthwise conv spanning all electrodes (height
        # chans -> 1), d filters per input channel.
        self.block2 = nn.Sequential(
            nn.Conv2d(f1, d * f1, (chans, 1), groups=f1, bias=False),  # Depthwise Conv
            nn.BatchNorm2d(d * f1),
            nn.ELU(),
            nn.AvgPool2d((1, pk1)),
            nn.Dropout(dropout_rate)
        )
        # Separable conv = depthwise temporal conv + 1x1 pointwise mix.
        # BUG FIX: the depthwise stage must use groups=d*f1 (its input channel
        # count) and keep d*f1 output channels; the original used groups=f2,
        # which raises whenever f2 != d * f1. Parameter shapes and behavior
        # are identical for the default configuration (f2 == d * f1 == 32).
        self.block3 = nn.Sequential(
            nn.Conv2d(d * f1, d * f1, (1, 16), groups=d * f1, bias=False, padding='same'),  # Depthwise Conv
            nn.Conv2d(d * f1, f2, kernel_size=1, bias=False),  # Pointwise Conv
            nn.BatchNorm2d(f2),
            nn.ELU(),
            nn.AvgPool2d((1, pk2)),
            nn.Dropout(dropout_rate)
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(linear_size, classes)
        # Apply max_norm constraint to the depthwise layer in block2
        self._apply_max_norm(self.block2[0], max_norm1)
        # Apply max_norm constraint to the linear layer
        self._apply_max_norm(self.fc, max_norm2)
        # NOTE(review): the constraints above run once at construction only;
        # the Keras reference EEGNet re-applies max_norm after every optimizer
        # step. Call _apply_max_norm in the training loop for full parity.

    def _apply_max_norm(self, layer, max_norm):
        """Renormalize each output unit's weight vector to L2 norm <= max_norm."""
        for name, param in layer.named_parameters():
            if 'weight' in name:
                param.data = torch.renorm(param.data, p=2, dim=0, maxnorm=max_norm)

    def forward(self, x):
        """Run the network: x (batch, 1, chans, time_points) -> logits (batch, classes)."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x