diff --git a/animaloc/train/trainers.py b/animaloc/train/trainers.py index e99b471..a46699c 100644 --- a/animaloc/train/trainers.py +++ b/animaloc/train/trainers.py @@ -524,7 +524,7 @@ def _train( wandb.log(loss_dict) self.losses = sum(loss for loss in loss_dict.values()) - batches_losses.append(self.losses) + batches_losses.append(self.losses.detach()) loss_dict_reduced = reduce_dict(loss_dict) losses_reduced = sum(loss for loss in loss_dict_reduced.values())