diff --git a/avalanche/benchmarks/utils/data_loader.py b/avalanche/benchmarks/utils/data_loader.py index 3a421696c..9339b5d71 100644 --- a/avalanche/benchmarks/utils/data_loader.py +++ b/avalanche/benchmarks/utils/data_loader.py @@ -102,6 +102,7 @@ class GroupBalancedDataLoader: def __init__(self, datasets: Sequence[AvalancheDataset], oversample_small_groups: bool = False, collate_mbatches=_default_collate_mbatches_fn, + batch_size: int = 32, **kwargs): """ Data loader that balances data from multiple datasets. @@ -119,6 +120,8 @@ def __init__(self, datasets: Sequence[AvalancheDataset], :param collate_mbatches: function that given a sequence of mini-batches (one for each task) combines them into a single mini-batch. Used to combine the mini-batches obtained separately from each task. + :param batch_size: the size of the batch. It must be greater than or + equal to the number of groups. :param kwargs: data loader arguments used to instantiate the loader for each group separately. See pytorch :class:`DataLoader`. """ @@ -127,8 +130,19 @@ def __init__(self, datasets: Sequence[AvalancheDataset], self.oversample_small_groups = oversample_small_groups self.collate_mbatches = collate_mbatches + # check if batch_size is larger than or equal to the number of datasets + assert batch_size >= len(datasets) + + # divide the batch between all datasets in the group + ds_batch_size = batch_size // len(datasets) + remaining = batch_size % len(datasets) + for data in self.datasets: - self.dataloaders.append(DataLoader(data, **kwargs)) + bs = ds_batch_size + if remaining > 0: + bs += 1 + remaining -= 1 + self.dataloaders.append(DataLoader(data, batch_size=bs, **kwargs)) self.max_len = max([len(d) for d in self.dataloaders]) def __iter__(self): @@ -249,7 +263,9 @@ def __init__(self, data: AvalancheDataset, memory: AvalancheDataset = None, combine the mini-batches obtained separately from each task. :param batch_size: the size of the batch. 
It must be greater than or equal to the number of tasks. - :param ratio_data_mem: How many of the samples should be from + :param force_data_batch_size: How many of the samples should be from the + current `data`. If None, it will equally divide each batch between + samples from all seen tasks in the current `data` and `memory`. :param kwargs: data loader arguments used to instantiate the loader for each task separately. See pytorch :class:`DataLoader`. """ @@ -265,19 +281,23 @@ def __init__(self, data: AvalancheDataset, memory: AvalancheDataset = None, assert force_data_batch_size <= batch_size, \ "Forced batch size of data must be <= entire batch size" - mem_batch_size = batch_size - force_data_batch_size - remaining_example = 0 + remaining_example_data = 0 + mem_keys = len(self.memory.task_set) + mem_batch_size = batch_size - force_data_batch_size + mem_batch_size_k = mem_batch_size // mem_keys + remaining_example_mem = mem_batch_size % mem_keys + assert mem_batch_size >= mem_keys, \ "Batch size must be greator or equal " \ "to the number of tasks in the memory." self.loader_data, _ = self._create_dataloaders( data, force_data_batch_size, - remaining_example, **kwargs) + remaining_example_data, **kwargs) self.loader_memory, _ = self._create_dataloaders( - memory, mem_batch_size, - remaining_example, **kwargs) + memory, mem_batch_size_k, + remaining_example_mem, **kwargs) else: num_keys = len(self.data.task_set) + len(self.memory.task_set) assert batch_size >= num_keys, \ diff --git a/avalanche/training/plugins/replay.py b/avalanche/training/plugins/replay.py index 168543e79..a6fad6d26 100644 --- a/avalanche/training/plugins/replay.py +++ b/avalanche/training/plugins/replay.py @@ -28,16 +28,19 @@ class ReplayPlugin(StrategyPlugin): The :mem_size: attribute controls the total number of patterns to be stored in the external memory. 
+ :param storage_policy: The policy that controls how to add new exemplars + in memory + :param force_data_batch_size: How many of the samples should be from the + current `data`. If None, it will equally divide each batch between + samples from all seen tasks in the current `data` and `memory`. """ def __init__(self, mem_size: int = 200, - storage_policy: Optional["ExemplarsBuffer"] = None): - """ - :param storage_policy: The policy that controls how to add new exemplars - in memory - """ + storage_policy: Optional["ExemplarsBuffer"] = None, + force_data_batch_size: Optional[int] = None): super().__init__() self.mem_size = mem_size + self.force_data_batch_size = force_data_batch_size if storage_policy is not None: # Use other storage policy self.storage_policy = storage_policy @@ -68,6 +71,7 @@ def before_training_exp(self, strategy: "BaseStrategy", oversample_small_tasks=True, num_workers=num_workers, batch_size=strategy.train_mb_size, + force_data_batch_size=self.force_data_batch_size, shuffle=shuffle) def after_training_exp(self, strategy: "BaseStrategy", **kwargs): diff --git a/tests/target_metrics/mt.pickle b/tests/target_metrics/mt.pickle index 162ae207d..53f02c071 100644 Binary files a/tests/target_metrics/mt.pickle and b/tests/target_metrics/mt.pickle differ diff --git a/tests/target_metrics/sit.pickle b/tests/target_metrics/sit.pickle index 2167514a9..9c13709c3 100644 Binary files a/tests/target_metrics/sit.pickle and b/tests/target_metrics/sit.pickle differ diff --git a/tests/target_metrics/tpp.pickle b/tests/target_metrics/tpp.pickle index 89df94540..b02386526 100644 Binary files a/tests/target_metrics/tpp.pickle and b/tests/target_metrics/tpp.pickle differ diff --git a/tests/test_metrics.py b/tests/test_metrics.py index fc9fd22c6..a560983d0 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -542,7 +542,7 @@ def setUpClass(cls) -> None: collect_all=True) # collect all metrics (set to True by default) cl_strategy = BaseStrategy( model, 
SGD(model.parameters(), lr=0.001, momentum=0.9), - CrossEntropyLoss(), train_mb_size=2, train_epochs=2, + CrossEntropyLoss(), train_mb_size=4, train_epochs=2, eval_mb_size=2, device=DEVICE, evaluator=eval_plugin, eval_every=1) for i, experience in enumerate(benchmark.train_stream):