Skip to content

Commit

Permalink
Merge pull request #3 from sliwy/test_fix
Browse files Browse the repository at this point in the history
restored previous tests behavior
  • Loading branch information
PierreGtch committed Sep 9, 2023
2 parents 6514091 + 8348d9a commit b35dbdd
Showing 1 changed file with 38 additions and 20 deletions.
58 changes: 38 additions & 20 deletions test/unit_tests/test_eegneuralnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __getitem__(self, item):
return torch.rand(3, 10), item % 4


class MockModule(EEGModuleMixin, torch.nn.Module):
class MockModule1(EEGModuleMixin, torch.nn.Module):
def __init__(
self,
preds,
Expand All @@ -46,7 +46,30 @@ def __init__(
input_window_seconds=input_window_seconds,
sfreq=sfreq,
)
# self.preds = to_tensor(preds, device='cpu')
self.preds = to_tensor(preds, device='cpu')
self.final_layer = torch.nn.Conv1d(self.n_chans, self.n_outputs, self.n_times)

def forward(self, x):
return self.preds

class MockModule2(EEGModuleMixin, torch.nn.Module):
def __init__(
self,
n_outputs=None,
n_chans=None,
chs_info=None,
n_times=None,
input_window_seconds=None,
sfreq=None,
):
super().__init__(
n_outputs=n_outputs,
n_chans=n_chans,
chs_info=chs_info,
n_times=n_times,
input_window_seconds=input_window_seconds,
sfreq=sfreq,
)
self.final_layer = torch.nn.Conv1d(self.n_chans, self.n_outputs, self.n_times)

def forward(self, x):
Expand Down Expand Up @@ -138,15 +161,16 @@ def test_trialwise_predict_and_predict_proba(eegneuralnet_cls):
[0.125, 0.875],
[1., 0.],
[0.8, 0.2],
[0.8, 0.2],
[0.9, 0.1],
]
)
eegneuralnet = eegneuralnet_cls(
MockModule,
MockModule1,
module__preds=preds,
module__n_outputs=2,
module__n_chans=3,
module__n_times=3,
module__n_times=10,
optimizer=optim.Adam,
batch_size=32
)
Expand All @@ -158,7 +182,7 @@ def test_trialwise_predict_and_predict_proba(eegneuralnet_cls):

def test_cropped_predict_and_predict_proba(eegneuralnet_cls, preds):
eegneuralnet = eegneuralnet_cls(
MockModule,
MockModule1,
module__preds=preds,
module__n_outputs=4,
module__n_chans=3,
Expand All @@ -180,7 +204,7 @@ def test_cropped_predict_and_predict_proba(eegneuralnet_cls, preds):

def test_cropped_predict_and_predict_proba_not_aggregate_predictions(eegneuralnet_cls, preds):
eegneuralnet = eegneuralnet_cls(
MockModule,
MockModule1,
module__preds=preds,
module__n_outputs=4,
module__n_chans=3,
Expand All @@ -200,7 +224,7 @@ def test_cropped_predict_and_predict_proba_not_aggregate_predictions(eegneuralne

def test_predict_trials(eegneuralnet_cls, preds):
eegneuralnet = eegneuralnet_cls(
MockModule,
MockModule1,
module__preds=preds,
module__n_outputs=4,
module__n_chans=3,
Expand All @@ -219,7 +243,7 @@ def test_predict_trials(eegneuralnet_cls, preds):

def test_clonable(eegneuralnet_cls, preds):
eegneuralnet = eegneuralnet_cls(
MockModule,
MockModule1,
module__preds=preds,
module__n_outputs=4,
module__n_chans=3,
Expand All @@ -241,8 +265,7 @@ def test_clonable(eegneuralnet_cls, preds):
def test_set_signal_params_numpy(eegneuralnet_cls, preds, Xy):
X, y = Xy
net = eegneuralnet_cls(
MockModule,
module__preds=preds,
MockModule2,
cropped=False,
optimizer=optim.Adam,
batch_size=32,
Expand All @@ -262,9 +285,8 @@ def test_set_signal_params_epochs(eegneuralnet_cls, preds):
def test_set_signal_params_torch_ds(eegneuralnet_cls, preds):
n_outputs = (1 if eegneuralnet_cls == EEGRegressor else 4)
net = eegneuralnet_cls(
MockModule,
MockModule2,
module__n_outputs=n_outputs,
module__preds=preds,
cropped=False,
optimizer=optim.Adam,
batch_size=32,
Expand All @@ -280,8 +302,7 @@ def test_set_signal_params_torch_ds(eegneuralnet_cls, preds):
def test_set_signal_params_windows_ds_metadata(eegneuralnet_cls, preds, windows_dataset_metadata):
n_outputs = (1 if eegneuralnet_cls == EEGRegressor else 4)
net = eegneuralnet_cls(
MockModule,
module__preds=preds,
MockModule2,
cropped=False,
optimizer=optim.Adam,
batch_size=32,
Expand All @@ -297,8 +318,7 @@ def test_set_signal_params_windows_ds_metadata(eegneuralnet_cls, preds, windows_
def test_set_signal_params_windows_ds_channels(eegneuralnet_cls, preds, windows_dataset_channels):
n_outputs = (1 if eegneuralnet_cls == EEGRegressor else 4)
net = eegneuralnet_cls(
MockModule,
module__preds=preds,
MockModule2,
module__n_outputs=n_outputs,
cropped=False,
optimizer=optim.Adam,
Expand All @@ -315,8 +335,7 @@ def test_set_signal_params_windows_ds_channels(eegneuralnet_cls, preds, windows_
def test_set_signal_params_concat_ds_metadata(eegneuralnet_cls, preds, concat_dataset_metadata):
n_outputs = (1 if eegneuralnet_cls == EEGRegressor else 4)
net = eegneuralnet_cls(
MockModule,
module__preds=preds,
MockModule2,
cropped=False,
optimizer=optim.Adam,
batch_size=32,
Expand All @@ -332,8 +351,7 @@ def test_set_signal_params_concat_ds_metadata(eegneuralnet_cls, preds, concat_da
def test_set_signal_params_concat_ds_channels(eegneuralnet_cls, preds, concat_dataset_channels):
n_outputs = (1 if eegneuralnet_cls == EEGRegressor else 4)
net = eegneuralnet_cls(
MockModule,
module__preds=preds,
MockModule2,
module__n_outputs=n_outputs,
cropped=False,
optimizer=optim.Adam,
Expand Down

0 comments on commit b35dbdd

Please sign in to comment.