Skip to content

Commit

Permalink
prediction fix
Browse files Browse the repository at this point in the history
  • Loading branch information
GilesStrong committed Nov 30, 2020
1 parent c5e931f commit 9a45395
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
## Fixes

- `Model` now creates `cb_savepath` is it didn't already exist
- Bug in `PredHandler` where predictions were kept on device leading to increased memory usage

## Changes

Expand Down
2 changes: 1 addition & 1 deletion lumin/nn/callbacks/pred_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ def on_pred_begin(self) -> None: self.preds = []
def on_pred_end(self) -> None: self.preds = torch.cat(self.preds)
def get_preds(self) -> np.ndarray: return self.preds
def on_forwards_end(self) -> None:
if self.model.fit_params.state == 'test': self.preds.append(self.model.fit_params.y_pred)
if self.model.fit_params.state == 'test': self.preds.append(self.model.fit_params.y_pred.cpu().detach())

0 comments on commit 9a45395

Please sign in to comment.