Skip to content

Commit

Permalink
Make offline ER us total batch size in first update (#381)
Browse files Browse the repository at this point in the history
  • Loading branch information
lballes committed Aug 18, 2023
1 parent 73df5c2 commit 2ca8df2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
9 changes: 7 additions & 2 deletions src/renate/updaters/experimental/offline_er.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def on_model_update_start(
self._num_points_current_task = len(train_dataset)

def train_dataloader(self) -> DataLoader:
train_loader = super().train_dataloader()
loaders = {"current_task": train_loader}
loaders = {}
if len(self._memory_buffer) > self._memory_batch_size:
loaders["current_task"] = super().train_dataloader()
loaders["memory"] = DataLoader(
dataset=self._memory_buffer,
batch_size=self._memory_batch_size,
Expand All @@ -83,6 +83,11 @@ def train_dataloader(self) -> DataLoader:
pin_memory=True,
collate_fn=self._train_collate_fn,
)
else:
batch_size = self._batch_size
self._batch_size += self._memory_batch_size
loaders["current_task"] = super().train_dataloader()
self._batch_size = batch_size
return CombinedLoader(loaders, mode="max_size_cycle")

def on_model_update_end(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions test/integration_tests/configs/suites/quick/offline-er.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
"dataset": "cifar10.json",
"backend": "local",
"job_name": "class-incremental-mlp-offline-er",
"expected_accuracy_linux": [[0.7319999933242798, 0.4699999988079071], [0.7515000104904175, 0.49300000071525574]],
"expected_accuracy_darwin": [[0.7300000190734863, 0.5350000262260437]]
"expected_accuracy_linux": [[0.6980000138282776, 0.546999990940094], [0.6514999866485596, 0.3725000023841858]],
"expected_accuracy_darwin": [[0.7315000295639038, 0.49000000953674316]]
}

0 comments on commit 2ca8df2

Please sign in to comment.