Skip to content

Commit

Permalink
Fix multiclass BS=1 handling
Browse files Browse the repository at this point in the history
  • Loading branch information
GilesStrong committed Nov 19, 2021
1 parent 54337ab commit 187eea9
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
- `bootstrap_stats` corrected computation of central 68% CI: was `np.percentile(np.abs(points), 68.2)` now `(np.percentile(points, 84.135)-np.percentile(points, 15.865))/2`
- Error when trying to initialise `SEBlock2d` or `SEBlock3d`
- Fixed ipython display import to only run if in notebook
- Bug in multiclass-classification with on a batch of 1 data-point caused by targets being squeezed 2 dimensions, rather than 1.

## Changes

Expand Down
4 changes: 2 additions & 2 deletions lumin/nn/data/batch_yielder.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __iter__(self) -> List[Tensor]:
if self.bulk_move:
inputs = to_device(Tensor(self.inputs))
if self.targets is not None:
if 'multiclass' in self.objective: targets = to_device(Tensor(self.targets).long().squeeze())
if 'multiclass' in self.objective: targets = to_device(Tensor(self.targets).long().squeeze(-1))
else: targets = to_device(Tensor(self.targets))
if self.weights is not None and self.use_weights: weights = to_device(Tensor(self.weights))
else: weights = None
Expand All @@ -77,7 +77,7 @@ def __iter__(self) -> List[Tensor]:
for i in range(0, upper, self.bs):
idxs = full_idxs[i:i+self.bs]
if self.targets is not None:
if 'multiclass' in self.objective: y = to_device(Tensor(self.targets[idxs]).long().squeeze())
if 'multiclass' in self.objective: y = to_device(Tensor(self.targets[idxs]).long().squeeze(-1))
else: y = to_device(Tensor(self.targets[idxs]))
else:
y = None
Expand Down

0 comments on commit 187eea9

Please sign in to comment.