diff --git a/celldetection/data/misc.py b/celldetection/data/misc.py index 9888ba6..ccec098 100644 --- a/celldetection/data/misc.py +++ b/celldetection/data/misc.py @@ -16,7 +16,7 @@ def transpose_spatial(inputs: np.ndarray, inputs_channels_last=True, spatial_dim else: # e.g. (0, 2, 3, 1) b = list(range(inputs.ndim - spatial_dims, inputs.ndim)) # spatial dims - c = list(range(has_batch, inputs.ndim - spatial_dims - (1 - has_batch))) # n channels + c = list(range(has_batch, inputs.ndim - spatial_dims)) # n channels return np.transpose(inputs, a + b + c)