Merge pull request #834 from hamedhemati/fixes
Add force_data_batch_size to ReplayPlugin for manual assignment of data-mem ratio
AntonioCarta committed Jan 7, 2022
2 parents 6eb8c81 + 19a5ffd commit 8abc551
Showing 6 changed files with 37 additions and 13 deletions.
34 changes: 27 additions & 7 deletions avalanche/benchmarks/utils/data_loader.py
@@ -102,6 +102,7 @@ class GroupBalancedDataLoader:
def __init__(self, datasets: Sequence[AvalancheDataset],
oversample_small_groups: bool = False,
collate_mbatches=_default_collate_mbatches_fn,
batch_size: int = 32,
**kwargs):
""" Data loader that balances data from multiple datasets.
@@ -119,6 +120,8 @@ def __init__(self, datasets: Sequence[AvalancheDataset],
:param collate_mbatches: function that given a sequence of mini-batches
(one for each task) combines them into a single mini-batch. Used to
combine the mini-batches obtained separately from each task.
:param batch_size: the size of the batch. It must be greater than or
equal to the number of groups.
:param kwargs: data loader arguments used to instantiate the loader for
each group separately. See pytorch :class:`DataLoader`.
"""
@@ -127,8 +130,19 @@ def __init__(self, datasets: Sequence[AvalancheDataset],
self.oversample_small_groups = oversample_small_groups
self.collate_mbatches = collate_mbatches

# check if batch_size is larger than or equal to the number of datasets
assert batch_size >= len(datasets)

# divide the batch between all datasets in the group
ds_batch_size = batch_size // len(datasets)
remaining = batch_size % len(datasets)

for data in self.datasets:
self.dataloaders.append(DataLoader(data, **kwargs))
bs = ds_batch_size
if remaining > 0:
bs += 1
remaining -= 1
self.dataloaders.append(DataLoader(data, batch_size=bs, **kwargs))
self.max_len = max([len(d) for d in self.dataloaders])

def __iter__(self):
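
As a quick illustration of the batch-splitting logic added above, here is a standalone sketch with hypothetical numbers (not part of this commit): with 3 datasets and batch_size=8, each dataset gets a base of 8 // 3 = 2 samples, and the 8 % 3 = 2 leftover samples go to the first two loaders.

    # Standalone sketch of the splitting logic above; numbers are illustrative.
    batch_size = 8
    n_datasets = 3

    ds_batch_size = batch_size // n_datasets   # base of 2 samples per dataset
    remaining = batch_size % n_datasets        # 2 datasets receive one extra sample

    per_dataset_sizes = []
    for _ in range(n_datasets):
        bs = ds_batch_size
        if remaining > 0:
            bs += 1
            remaining -= 1
        per_dataset_sizes.append(bs)

    print(per_dataset_sizes)                   # [3, 3, 2]
    assert sum(per_dataset_sizes) == batch_size
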
@@ -249,7 +263,9 @@ def __init__(self, data: AvalancheDataset, memory: AvalancheDataset = None,
combine the mini-batches obtained separately from each task.
:param batch_size: the size of the batch. It must be greater than or
equal to the number of tasks.
:param ratio_data_mem: How many of the samples should be from
:param force_data_batch_size: How many of the samples should be from the
current `data`. If None, it will equally divide each batch between
samples from all seen tasks in the current `data` and `memory`.
:param kwargs: data loader arguments used to instantiate the loader for
each task separately. See pytorch :class:`DataLoader`.
"""
@@ -265,19 +281,23 @@ def __init__(self, data: AvalancheDataset, memory: AvalancheDataset = None,
assert force_data_batch_size <= batch_size, \
"Forced batch size of data must be <= entire batch size"

mem_batch_size = batch_size - force_data_batch_size
remaining_example = 0
remaining_example_data = 0

mem_keys = len(self.memory.task_set)
mem_batch_size = batch_size - force_data_batch_size
mem_batch_size_k = mem_batch_size // mem_keys
remaining_example_mem = mem_batch_size % mem_keys

assert mem_batch_size >= mem_keys, \
"Batch size must be greator or equal " \
"to the number of tasks in the memory."

self.loader_data, _ = self._create_dataloaders(
data, force_data_batch_size,
remaining_example, **kwargs)
remaining_example_data, **kwargs)
self.loader_memory, _ = self._create_dataloaders(
memory, mem_batch_size,
remaining_example, **kwargs)
memory, mem_batch_size_k,
remaining_example_mem, **kwargs)
else:
num_keys = len(self.data.task_set) + len(self.memory.task_set)
assert batch_size >= num_keys, \
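
To see how the force_data_batch_size branch above splits a batch, here is a worked sketch with hypothetical numbers (not taken from the commit): with batch_size=32, force_data_batch_size=24 and 4 tasks in memory, the current data keeps 24 samples and the memory gets 32 - 24 = 8, i.e. 8 // 4 = 2 per memory task with 8 % 4 = 0 left to redistribute.

    # Hypothetical numbers, mirroring the arithmetic in the branch above.
    batch_size = 32
    force_data_batch_size = 24
    mem_keys = 4                                           # tasks stored in memory

    mem_batch_size = batch_size - force_data_batch_size    # 8 samples for the memory
    mem_batch_size_k = mem_batch_size // mem_keys          # 2 samples per memory task
    remaining_example_mem = mem_batch_size % mem_keys      # 0 extra samples to spread

    assert force_data_batch_size <= batch_size
    assert mem_batch_size >= mem_keys                      # same checks as in the diff

    print(mem_batch_size_k, remaining_example_mem)         # 2 0
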
14 changes: 9 additions & 5 deletions avalanche/training/plugins/replay.py
@@ -28,16 +28,19 @@ class ReplayPlugin(StrategyPlugin):
The :mem_size: attribute controls the total number of patterns to be stored
in the external memory.
:param storage_policy: The policy that controls how to add new exemplars
in memory
:param force_data_batch_size: How many of the samples should be from the
current `data`. If None, it will equally divide each batch between
samples from all seen tasks in the current `data` and `memory`.
"""

def __init__(self, mem_size: int = 200,
storage_policy: Optional["ExemplarsBuffer"] = None):
"""
:param storage_policy: The policy that controls how to add new exemplars
in memory
"""
storage_policy: Optional["ExemplarsBuffer"] = None,
force_data_batch_size: int = None):
super().__init__()
self.mem_size = mem_size
self.force_data_batch_size = force_data_batch_size

if storage_policy is not None: # Use other storage policy
self.storage_policy = storage_policy
@@ -68,6 +71,7 @@ def before_training_exp(self, strategy: "BaseStrategy",
oversample_small_tasks=True,
num_workers=num_workers,
batch_size=strategy.train_mb_size,
force_data_batch_size=self.force_data_batch_size,
shuffle=shuffle)

def after_training_exp(self, strategy: "BaseStrategy", **kwargs):
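
A hedged usage sketch of the new plugin argument (values are illustrative; the import path is assumed from the file location avalanche/training/plugins/replay.py, and nothing beyond ReplayPlugin, mem_size and force_data_batch_size comes from the commit itself):

    # Usage sketch: reserve 12 of the 32 samples in each training mini-batch for
    # the current experience, leaving 20 samples to be drawn from the replay memory.
    from avalanche.training.plugins import ReplayPlugin

    replay = ReplayPlugin(mem_size=200, force_data_batch_size=12)

    # The plugin would then be passed to a strategy, e.g.
    #   BaseStrategy(model, optimizer, criterion,
    #                train_mb_size=32, plugins=[replay])
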
Binary file modified tests/target_metrics/mt.pickle
Binary file modified tests/target_metrics/sit.pickle
Binary file modified tests/target_metrics/tpp.pickle
2 changes: 1 addition & 1 deletion tests/test_metrics.py
@@ -542,7 +542,7 @@ def setUpClass(cls) -> None:
collect_all=True) # collect all metrics (set to True by default)
cl_strategy = BaseStrategy(
model, SGD(model.parameters(), lr=0.001, momentum=0.9),
CrossEntropyLoss(), train_mb_size=2, train_epochs=2,
CrossEntropyLoss(), train_mb_size=4, train_epochs=2,
eval_mb_size=2, device=DEVICE,
evaluator=eval_plugin, eval_every=1)
for i, experience in enumerate(benchmark.train_stream):
