# eegnet_linear_probing.py
# Linear probing of a pretrained EEGNet on the Schirrmeister2017 motor-imagery
# dataset via MOABB (83 lines / 70 loc, ~2.86 KB in the original listing).
from pathlib import Path
import torch.nn
import yaml
import numpy as np
from torch.nn import MSELoss
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import FunctionTransformer
import moabb
from moabb.datasets import Schirrmeister2017
from moabb.evaluations import WithinSessionEvaluation, CrossSubjectEvaluation
from moabb.paradigms import MotorImagery
from moabb.analysis import Results
from models import EEGNetv4
from skorch_frozen import FrozenNeuralNetTransformer
moabb.set_log_level("info")

# Load configuration: shared settings (config.yaml) plus machine-specific
# overrides (local_config.yaml), both expected next to this script.
config_file = Path(__file__).parent / 'config.yaml'
local_config_file = Path(__file__).parent / 'local_config.yaml'
config = yaml.safe_load(config_file.read_text())
local_config = yaml.safe_load(local_config_file.read_text())

# Frequently-used parameters pulled out of the config for convenience.
suffix = local_config['evaluation_params']['base']['suffix']
n_classes = config['paradigm_params']['base']['n_classes']
channels = config['paradigm_params']['base']['channels']
resample = config['paradigm_params']['base']['resample']
# Trial time window of the dataset.
# NOTE(review): t0/t1 are not referenced later in this file — confirm they
# are needed (e.g. by config values) before removing.
t0, t1 = Schirrmeister2017().interval
# Dataset and paradigm under evaluation.
dataset = Schirrmeister2017()
datasets = [dataset]
paradigm = MotorImagery(
    **config['paradigm_params']['base'],
    **config['paradigm_params']['single_band'],
)

# Resolve the checkpoint root directory: a throwaway Results object is built
# solely so MOABB computes the results filepath; its parent dir is where the
# per-subject checkpoints live.
results_param_names = ['suffix', 'overwrite', 'hdf5_path', 'additional_columns']
base_eval_params = local_config['evaluation_params']['base']
results_params = {
    name: base_eval_params[name]
    for name in results_param_names
    if name in base_eval_params
}
fake_results = Results(CrossSubjectEvaluation, MotorImagery, **results_params)
checkpoints_root_dir = Path(fake_results.filepath).parent
del fake_results
# Map each subject to its (unique) pretrained checkpoint file.
# Raises ValueError with a precise message when a subject directory holds
# zero or more than one *.ckpt file (the original message conflated both
# cases, which made diagnosis harder).
checkpoints_dict = {}
for subject in dataset.subject_list:
    path = checkpoints_root_dir / str(subject)
    files = list(path.glob('*.ckpt'))
    if not files:
        raise ValueError(f'No checkpoint file present at {path}')
    if len(files) > 1:
        raise ValueError(f'Multiple checkpoint files present at {path}: {files}')
    checkpoints_dict[subject] = str(files[0])
# Linear-probing pipeline: cast data to float32, embed with a frozen
# pretrained EEGNet, then train a logistic-regression readout on top.
# The checkpoint loaded here only serves to instantiate the embedding
# module; per-subject weights are swapped in later by pre_fit_function.
first_checkpoint = str(next(iter(checkpoints_dict.values())))
pipelines = {}
pipelines["EEGNet+LP"] = make_pipeline(
    FunctionTransformer(func=np.float32, inverse_func=np.float64),
    FrozenNeuralNetTransformer(
        EEGNetv4.load_from_checkpoint(first_checkpoint).embedding,
        criterion=MSELoss,
    ),
    LogisticRegression(),
)
def pre_fit_function(pipeline, dataset, subject):
    """Load the subject-specific pretrained embedding weights into the
    frozen transformer (pipeline step 1) before each fit."""
    ckpt_path = checkpoints_dict[subject]
    print(f'Loading checkpoint for subject {subject} from {ckpt_path}')
    state = EEGNetv4.load_from_checkpoint(ckpt_path).embedding.state_dict()
    pipeline[1].initialize().module.load_state_dict(state)
# Evaluation
# Run a within-session evaluation of every pipeline on every dataset,
# loading the subject's pretrained weights (via pre_fit_function) before
# each fit. Keyword arguments come from three config sections; note that
# any key duplicated across the unpacked dicts would raise a TypeError here.
evaluation = WithinSessionEvaluation(
    paradigm=paradigm, datasets=datasets,
    pre_fit_function=pre_fit_function,
    **config['evaluation_params']['base'],
    **config['evaluation_params']['within_session'],
    **local_config['evaluation_params']['base'],
)
# results is the per-pipeline/per-session score table produced by MOABB.
results = evaluation.process(pipelines)
print(results)