Skip to content

Commit

Permalink
Merge pull request #88 from ANNUBS/fix_build_trials_inputs
Browse files Browse the repository at this point in the history
fix: values assignment for the input signals
  • Loading branch information
gcroci2 committed May 10, 2024
2 parents ecefec2 + 0c75ab4 commit 0dc7c68
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
6 changes: 3 additions & 3 deletions annubes/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,10 +470,10 @@ def _build_trials_inputs(self) -> NDArray[np.float64]:
(len(self._time[n]), self._n_inputs),
dtype=np.float32,
)
for idx, _ in enumerate(self._modalities):
value = self._rng.choice(self.stim_intensities, 1) if self._modality_seq[n] != "X" else 0
for idx, mod in enumerate(self._modalities):
if self._modality_seq[n] != "X" and self._modality_seq[n] == mod:
x[n][self._phases[n]["input"], idx] = self._rng.choice(self.stim_intensities, 1)
x[n][self._phases[n]["fix_time"], idx] = self.fix_intensity
x[n][self._phases[n]["input"], idx] = value
x[n][self._phases[n]["input"], self._n_inputs - 1] = 1 # start cue
# add noise
x[n] += noise_factor * self._rng.normal(loc=0, scale=1, size=x[n].shape)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,3 +550,19 @@ def test_plot_trials(task: Task):
== f"Number of plots requested ({n_plots}) exceeds number of trials ({ntrials}). Will plot all trials."
)
assert warning.category == UserWarning


def test_intensity_trials():
task = Task(name=NAME, session=SESSION, stim_intensities=STIM_INTENSITIES, scaling=SCALING)
trials = task.generate_trials(ntrials=NTRIALS)
high_val = 0.6 # for a signal to be considered high
low_val = 0.3 # for a signal to be considered low
for n in range(NTRIALS):
for idx, mod in enumerate(task._modalities):
assert (
(trials["inputs"][n][task._phases[n]["input"], idx] > high_val).all()
if trials["modality_seq"][n] == mod # check if the signal is high if the modality is the current one
else (
trials["inputs"][n][task._phases[n]["input"], idx] < low_val
).all() # check if the signal is low otherwise
)

0 comments on commit 0dc7c68

Please sign in to comment.