From 969e23fd6e29e3efc991a51e53e3188cca6037db Mon Sep 17 00:00:00 2001
From: "bencrulis@gmail.com" <bencrulis@gmail.com>
Date: Thu, 25 May 2023 13:37:51 +0200
Subject: [PATCH] Added a variable to control whether to retain the graph or
 not in the automatic backward pass

---
 avalanche/training/templates/base_sgd.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/avalanche/training/templates/base_sgd.py b/avalanche/training/templates/base_sgd.py
index 35e21a49a..fb598d9cf 100644
--- a/avalanche/training/templates/base_sgd.py
+++ b/avalanche/training/templates/base_sgd.py
@@ -114,6 +114,9 @@ def __init__(
         )
         """ Eval mini-batch size. """
 
+        self.retain_graph: bool = False
+        """ Retain graph when calling loss.backward(). """
+
         if evaluator is None:
             evaluator = EvaluationPlugin()
         elif callable(evaluator):
@@ -220,7 +223,7 @@ def training_epoch(self, **kwargs):
 
     def backward(self):
         """Run the backward pass."""
-        self.loss.backward()
+        self.loss.backward(retain_graph=self.retain_graph)
 
     def optimizer_step(self):
         """Execute the optimizer step (weights update)."""
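
Usage note: with this patch applied, a strategy built on BaseSGDTemplate can opt in
to graph retention by flipping the new attribute, which is useful when a plugin or
custom criterion needs to call backward() more than once over (parts of) the same
autograd graph. The sketch below assumes a recent Avalanche release where
`avalanche.training.supervised.Naive` and `avalanche.models.SimpleMLP` are
available; the exact constructor arguments may differ between versions.

    import torch
    from torch.nn import CrossEntropyLoss
    from torch.optim import SGD

    from avalanche.models import SimpleMLP
    from avalanche.training.supervised import Naive

    model = SimpleMLP(num_classes=10)
    strategy = Naive(
        model,
        SGD(model.parameters(), lr=0.01),
        CrossEntropyLoss(),
        train_mb_size=32,
    )

    # New attribute introduced by this patch (default False, so the
    # stock behavior of loss.backward() is unchanged). Setting it to
    # True makes the automatic backward pass pass retain_graph=True,
    # allowing a second backward over the same graph.
    strategy.retain_graph = True

Keeping the default at False preserves the existing behavior for all strategies
and avoids the extra memory cost of retaining intermediate buffers when it is
not needed.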