Skip to content

Commit

Permalink
Fix window bounds check in event_triggered_traces function
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed May 1, 2024
1 parent 0ea3cc8 commit 48ef857
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions bnpm/timeSeries.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def event_triggered_traces(
idx_triggers_clean = torch.as_tensor(idx_triggers[~torch.isnan(idx_triggers)], dtype=torch.long) ## remove nans from idx_triggers and convert to torch.long

windows = torch.stack([xAxis + i for i in torch.as_tensor(idx_triggers_clean, dtype=torch.long)], dim=0) ## make windows. shape = (n_triggers, len_window)
win_toInclude = (torch.any(windows<0, dim=1)==0) * (torch.any(windows>arr.shape[dim], dim=1)==0) ## boolean array of windows that are within the bounds of the length of 'dim'
win_toInclude = (torch.any(windows<0, dim=1)==0) * (torch.any(windows>=arr.shape[dim], dim=1)==0) ## boolean array of windows that are within the bounds of the length of 'dim'
n_win_excluded = torch.sum(win_toInclude==False) ## number of windows excluded due to window bounds. Only used for printing currently
windows = windows[win_toInclude] ## windows after pruning out windows that are out of bounds
n_windows = windows.shape[0] ## number of windows. Only used for printing currently
Expand All @@ -436,7 +436,7 @@ def event_triggered_traces(
shape = list(arr.shape) ## original shape
dims_perm = [dim] + list(range(dim)) + list(range(dim+1, len(shape))) ## new dims for indexing. put dim at dim 0
arr_perm = arr.permute(*dims_perm) ## permute to put 'dim' to dim 0
arr_idx = arr_perm.index_select(0, windows.reshape(-1)) ## index out windows along dim 0
arr_idx = arr_perm.index_select(dim=0, index=windows.reshape(-1)) ## index out windows along dim 0
rs = list(arr_perm.shape[1:]) + [n_windows, win_bounds[1]-win_bounds[0]] ## new shape for unflattening. 'dim' will be moved to dim -1, then reshaped to n_windows x len_window
arr_idx_rs = arr_idx.permute(*(list(range(1, arr_idx.ndim)) + [0])).reshape(*rs) ## permute to put current 'dim' (currently dim 0) to end, then reshape

Expand Down

0 comments on commit 48ef857

Please sign in to comment.