From 187eea9d33065d4c0ff84aa8bebaf64f244bf8ec Mon Sep 17 00:00:00 2001 From: GilesStrong Date: Fri, 19 Nov 2021 14:50:53 +0100 Subject: [PATCH] Fix multiclass BS=1 handling --- CHANGES.md | 1 + lumin/nn/data/batch_yielder.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index c009b1e..71170ae 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -28,6 +28,7 @@ - `bootstrap_stats` corrected computation of central 68% CI: was `np.percentile(np.abs(points), 68.2)` now `(np.percentile(points, 84.135)-np.percentile(points, 15.865))/2` - Error when trying to initialise `SEBlock2d` or `SEBlock3d` - Fixed ipython display import to only run if in notebook +- Bug in multiclass-classification with on a batch of 1 data-point caused by targets being squeezed 2 dimensions, rather than 1. ## Changes diff --git a/lumin/nn/data/batch_yielder.py b/lumin/nn/data/batch_yielder.py index 75a7fa8..c81d02a 100644 --- a/lumin/nn/data/batch_yielder.py +++ b/lumin/nn/data/batch_yielder.py @@ -59,7 +59,7 @@ def __iter__(self) -> List[Tensor]: if self.bulk_move: inputs = to_device(Tensor(self.inputs)) if self.targets is not None: - if 'multiclass' in self.objective: targets = to_device(Tensor(self.targets).long().squeeze()) + if 'multiclass' in self.objective: targets = to_device(Tensor(self.targets).long().squeeze(-1)) else: targets = to_device(Tensor(self.targets)) if self.weights is not None and self.use_weights: weights = to_device(Tensor(self.weights)) else: weights = None @@ -77,7 +77,7 @@ def __iter__(self) -> List[Tensor]: for i in range(0, upper, self.bs): idxs = full_idxs[i:i+self.bs] if self.targets is not None: - if 'multiclass' in self.objective: y = to_device(Tensor(self.targets[idxs]).long().squeeze()) + if 'multiclass' in self.objective: y = to_device(Tensor(self.targets[idxs]).long().squeeze(-1)) else: y = to_device(Tensor(self.targets[idxs])) else: y = None