-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_fbmsnet.py
73 lines (59 loc) · 2.08 KB
/
main_fbmsnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import numpy
import mne
from model.networks import FBMSNet
from get_data_data_filter import eegmat,stew,nback,eegmat_test,stew_test,nback_test,eegmat_val,stew_val,nback_val
from sklearn.metrics import accuracy_score,f1_score
# Seed the torch and numpy random number generators with a fixed value.
torch.manual_seed(42)
numpy.random.seed(42)


def _make_loader(dataset):
    """Wrap *dataset* in a DataLoader with batch_size=10 and shuffle=True."""
    return torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True)


# Each imported dataset name is rebound to the DataLoader built from it; the
# code further down iterates these names directly.
# NOTE(review): the *_val / *_test loaders are shuffled as well. Batch order
# should not matter for accuracy/F1, so shuffle=False may suffice there --
# confirm before changing, since it could alter the random stream.
eegmat = _make_loader(eegmat)
eegmat_val = _make_loader(eegmat_val)
eegmat_test = _make_loader(eegmat_test)
stew = _make_loader(stew)
stew_val = _make_loader(stew_val)
stew_test = _make_loader(stew_test)
nback = _make_loader(nback)
nback_val = _make_loader(nback_val)
nback_test = _make_loader(nback_test)

# NOTE(review): nChan / nTime / nClass presumably mean input channels, time
# samples per example, and number of classes -- confirm against
# model.networks.FBMSNet.
model = FBMSNet(nChan=10, nTime=2560, nClass=2)

# Cross-entropy loss, optimised with Adam at a learning rate of 0.001.
crt = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=0.001)
# Train on the `eegmat` loader for two epochs; after each epoch print the
# summed batch loss, then accuracy and macro-F1 on the `stew_val` loader.
# NOTE(review): the nesting below is reconstructed (the captured source had
# no indentation). Validation is assumed to run once per epoch -- confirm.
for epoch in range(2):
    model.train()
    batch_losses = []
    for data, labels, _ in eegmat:
        optim.zero_grad()
        data = data.to(dtype=torch.float)
        # The model returns a pair; only the first element is used here.
        logits, _ = model(data)
        loss = crt(logits, labels)
        loss.backward()
        optim.step()
        batch_losses.append(loss.item())
    # Reported value is the SUM of per-batch losses, not the mean.
    epoch_loss = sum(batch_losses)
    print(epoch, epoch_loss)

    model.eval()
    with torch.no_grad():
        predictions = []
        targets = []
        for data, labels, _ in stew_val:
            data = data.to(dtype=torch.float)
            logits, _ = model(data)
            # .tolist() hands scikit-learn plain Python numbers instead of a
            # list of 0-d tensors.
            # Assumes labels are collated into a tensor -- TODO confirm.
            predictions.extend(logits.argmax(dim=-1).tolist())
            targets.extend(labels.tolist())
        acc = accuracy_score(targets, predictions)
        f1 = f1_score(targets, predictions, average="macro")
        print(acc, '\t', f1)
# Final evaluation: print accuracy and macro-F1 of the model on the `nback`
# loader.
# NOTE(review): this iterates `nback`, not `nback_test` / `nback_val` --
# confirm that this is the intended split.
model.eval()
with torch.no_grad():
    predictions = []
    targets = []
    for data, labels, _ in nback:
        data = data.to(dtype=torch.float)
        # The model returns a pair; only the first element is used here.
        logits, _ = model(data)
        # .tolist() hands scikit-learn plain Python numbers instead of a
        # list of 0-d tensors.
        # Assumes labels are collated into a tensor -- TODO confirm.
        predictions.extend(logits.argmax(dim=-1).tolist())
        targets.extend(labels.tolist())
    acc = accuracy_score(targets, predictions)
    f1 = f1_score(targets, predictions, average="macro")
    print(acc, '\t', f1)