FFCV support #1439

Merged · 24 commits · Aug 1, 2023

Commits (changes from all commits)
5de1db8
FFCV support draft
lrzpellegrini Jun 29, 2023
8a7f7ef
Fix typo
lrzpellegrini Jun 29, 2023
1873651
Fixed PEP8 issues
lrzpellegrini Jul 5, 2023
8ea1ff9
Merge remote-tracking branch 'upstream/master' into ffcv_support_pt2
lrzpellegrini Jul 5, 2023
3bdeed5
Fix merge issue. Fix minor issue in FFCV support.
lrzpellegrini Jul 5, 2023
db0a58a
Fix for Python 3.7
lrzpellegrini Jul 5, 2023
de33044
Better dataset traversal and transformations equality checks
lrzpellegrini Jul 7, 2023
328374f
Implement __eq__ in transforms. Add transforms unit tests.
lrzpellegrini Jul 7, 2023
f6e8b1c
Fixed linter issues
lrzpellegrini Jul 7, 2023
b127a08
Additional linter fix.
lrzpellegrini Jul 7, 2023
33f6a65
Implemented batch sampling
lrzpellegrini Jul 10, 2023
78375c5
Made internal elements private. Better names. Improved doc.
lrzpellegrini Jul 12, 2023
a1fcd6a
Minor fix
lrzpellegrini Jul 12, 2023
fd6016e
FFCV: loading no longer requires indices. Add unit test.
lrzpellegrini Jul 12, 2023
7cc9313
Added example for custom RGB fields.
lrzpellegrini Jul 12, 2023
4523d4b
Add FFCV docstrings
lrzpellegrini Jul 31, 2023
a0fbe78
Revert default persistent_workers value to False.
lrzpellegrini Jul 31, 2023
4f3f62b
Additional docstrings for SmartModuleWrapper
lrzpellegrini Jul 31, 2023
6b6df54
Moved FFCV examples into an ad-hoc folder.
lrzpellegrini Jul 31, 2023
81764c7
Fix linter issue
lrzpellegrini Jul 31, 2023
7bc5b04
Added unit tests for transformations flattening
lrzpellegrini Jul 31, 2023
1438199
Remove use_ffcv flag from MultiDatasetDataLoader
lrzpellegrini Jul 31, 2023
eca63fa
Merge remote-tracking branch 'upstream/master' into ffcv_support_pt2
lrzpellegrini Jul 31, 2023
bdc690f
Fix merge problem
lrzpellegrini Jul 31, 2023
avalanche/benchmarks/utils/data_loader.py (107 changes: 91 additions & 16 deletions)
@@ -31,6 +31,10 @@

from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.data_attribute import DataAttribute
+ from avalanche.benchmarks.utils.ffcv_support.ffcv_components import (
+ HybridFfcvLoader,
+ has_ffcv_support,
+ )
from avalanche.distributed.distributed_helper import DistributedHelper

from torch.utils.data.sampler import Sampler, BatchSampler
@@ -93,7 +97,9 @@ def __init__(
will not contribute to the minibatch composition near the end of
the epoch.
:param distributed_sampling: If True, apply the PyTorch
- :class:`DistributedSampler`. Defaults to False.
+ :class:`DistributedSampler`. Defaults to True.
+ Note: the distributed sampler is not applied if not running
+ a distributed training, even when True is passed.
:param never_ending: If True, this data loader will cycle indefinitely
by iterating over all datasets again and again and the epoch will
never end. In this case, the `termination_dataset` and
@@ -125,8 +131,10 @@ def __init__(
self.termination_dataset: int = termination_dataset
self.never_ending: bool = never_ending

+ self.loader_kwargs, self.ffcv_args = self._extract_ffcv_args(self.loader_kwargs)
+
# Only used if persistent_workers == True in loader kwargs
- self._persistent_loader = None
+ self._persistent_loader: Optional[DataLoader] = None

if "collate_fn" not in self.loader_kwargs:
self.loader_kwargs["collate_fn"] = self.datasets[0].collate_fn
@@ -155,7 +163,7 @@ def __init__(
_make_data_loader(
data_subset,
distributed_sampling,
- kwargs,
+ self.loader_kwargs,
subset_mb_size,
force_no_workers=True,
)[0]
@@ -165,7 +173,7 @@
_make_data_loader(
self.datasets[self.termination_dataset],
distributed_sampling,
- kwargs,
+ self.loader_kwargs,
self.batch_sizes[self.termination_dataset],
force_no_workers=True,
)[0]
@@ -198,23 +206,68 @@ def _get_loader(self):
self.loader_kwargs,
)

- overall_dataset = ConcatDataset(self.datasets)

multi_dataset_batch_sampler = MultiDatasetSampler(
- overall_dataset.datasets,
+ self.datasets,
samplers,
termination_dataset_idx=self.termination_dataset,
oversample_small_datasets=self.oversample_small_datasets,
never_ending=self.never_ending,
)

- loader = _make_data_loader_with_batched_sampler(
- overall_dataset,
- batch_sampler=multi_dataset_batch_sampler,
+ if has_ffcv_support(self.datasets):
+ loader = self._make_ffcv_loader(
+ self.datasets,
+ multi_dataset_batch_sampler,
+ )
+ else:
+ loader = self._make_pytorch_loader(
+ self.datasets,
+ multi_dataset_batch_sampler,
+ )
+
+ return loader
+
+ def _make_pytorch_loader(
+ self, datasets: List[AvalancheDataset], batch_sampler: Sampler[List[int]]
+ ):
+ return _make_data_loader_with_batched_sampler(
+ ConcatDataset(datasets),
+ batch_sampler=batch_sampler,
data_loader_args=self.loader_kwargs,
)
-
- return loader
+ def _make_ffcv_loader(
+ self, datasets: List[AvalancheDataset], batch_sampler: Sampler[List[int]]
+ ):
+ ffcv_args = dict(self.ffcv_args)
+ device = ffcv_args.pop("device")
+ print_ffcv_summary = ffcv_args.pop("print_ffcv_summary")
+
+ persistent_workers = self.loader_kwargs.get("persistent_workers", False)
+
+ return HybridFfcvLoader(
+ dataset=AvalancheDataset(datasets),
+ batch_sampler=batch_sampler,
+ ffcv_loader_parameters=ffcv_args,
+ device=device,
+ persistent_workers=persistent_workers,
+ print_ffcv_summary=print_ffcv_summary,
+ )
+
+ def _extract_ffcv_args(self, loader_args):
+ loader_args = dict(loader_args)
+ ffcv_args: Dict[str, Any] = loader_args.pop("ffcv_args", dict())
+ ffcv_args.setdefault("device", None)
+ ffcv_args.setdefault("print_ffcv_summary", False)
+
+ for arg_name, arg_value in loader_args.items():
+ if arg_name in ffcv_args:
+ # Already specified in ffcv_args -> discard
+ continue
+
+ if arg_name in HybridFfcvLoader.VALID_FFCV_PARAMS:
+ ffcv_args[arg_name] = arg_value
+ return loader_args, ffcv_args

def __len__(self):
return self.n_iterations
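
For context, here is a minimal sketch (not part of the PR) of how the new `ffcv_args` entry is meant to ride in the loader kwargs, based on `_extract_ffcv_args` above. The positional `(datasets, batch_sizes)` call form is inferred from `SingleDatasetDataLoader` further below; `batches_ahead` is assumed to be one of `HybridFfcvLoader.VALID_FFCV_PARAMS` and is used purely for illustration, while only `device` and `print_ffcv_summary` are consumed directly by `_extract_ffcv_args`.

```python
# Hedged sketch: FFCV-specific options travel in the "ffcv_args" dict of the
# loader kwargs. Anything left in ffcv_args after "device" and
# "print_ffcv_summary" are popped is forwarded to the FFCV loader.
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.data_loader import MultiDatasetDataLoader


def make_loader(train_data: AvalancheDataset) -> MultiDatasetDataLoader:
    return MultiDatasetDataLoader(
        [train_data],                    # one dataset per entry
        [128],                           # matching per-dataset batch sizes
        num_workers=8,                   # regular DataLoader argument
        persistent_workers=True,         # also read by the FFCV path
        ffcv_args={
            "device": "cuda:0",          # popped by _extract_ffcv_args
            "print_ffcv_summary": True,  # popped by _extract_ffcv_args
            "batches_ahead": 2,          # assumed FFCV Loader option
        },
    )
```

When `has_ffcv_support` returns True for the wrapped datasets, `_get_loader` builds the `HybridFfcvLoader` shown above; otherwise the same kwargs (minus `ffcv_args`) feed the regular PyTorch `DataLoader` path.
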
@@ -238,6 +291,17 @@ def _create_samplers(
return samplers


+ class SingleDatasetDataLoader(MultiDatasetDataLoader):
+ """
+ Replacement of PyTorch DataLoader that also supports
+ the additional loading mechanisms implemented in
+ :class:`MultiDatasetDataLoader`.
+ """
+
+ def __init__(self, datasets: AvalancheDataset, batch_size: int = 1, **kwargs):
+ super().__init__([datasets], [batch_size], **kwargs)
+
+
class GroupBalancedDataLoader(MultiDatasetDataLoader):
"""Data loader that balances data from multiple datasets."""

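Below is a usage sketch (not part of the PR) for the `SingleDatasetDataLoader` added above; the `GroupBalancedDataLoader` lines immediately preceding are unchanged context. The `shuffle`/`num_workers` kwargs and the task-labels-last mini-batch layout are assumptions based on the usual Avalanche loader conventions, not on this diff.

```python
# Hedged sketch: SingleDatasetDataLoader as a drop-in loader for one
# AvalancheDataset. Without FFCV-enabled data it falls back to the plain
# PyTorch loading path of MultiDatasetDataLoader.
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.data_loader import SingleDatasetDataLoader


def iterate_once(data: AvalancheDataset) -> None:
    loader = SingleDatasetDataLoader(
        data,
        batch_size=64,
        shuffle=True,     # assumed DataLoader-style kwarg
        num_workers=4,
    )
    for batch in loader:
        # Avalanche mini-batches usually end with the task labels.
        x, y, t = batch[0], batch[1], batch[-1]
```
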
@@ -265,7 +329,9 @@ def __init__(
:param batch_size: the size of the batch. It must be greater than or
equal to the number of groups.
:param distributed_sampling: If True, apply the PyTorch
- :class:`DistributedSampler`. Defaults to False.
+ :class:`DistributedSampler`. Defaults to True.
+ Note: the distributed sampler is not applied if not running
+ a distributed training, even when True is passed.
:param kwargs: data loader arguments used to instantiate the loader for
each group separately. See pytorch :class:`DataLoader`.
"""
@@ -302,7 +368,7 @@ def __init__(
data: AvalancheDataset,
batch_size: int = 32,
oversample_small_groups: bool = False,
- distributed_sampling: bool = True, # TODO: doc fix
+ distributed_sampling: bool = True,
**kwargs
):
"""Task-balanced data loader for Avalanche's datasets.
@@ -320,7 +386,9 @@ def __init__(
:param oversample_small_groups: whether smaller tasks should be
oversampled to match the largest one.
:param distributed_sampling: If True, apply the PyTorch
- :class:`DistributedSampler`. Defaults to False.
+ :class:`DistributedSampler`. Defaults to True.
+ Note: the distributed sampler is not applied if not running
+ a distributed training, even when True is passed.
:param kwargs: data loader arguments used to instantiate the loader for
each task separately. See pytorch :class:`DataLoader`.
"""
@@ -374,7 +442,9 @@ def __init__(
final mini-batch, NOT the final mini-batch size. The final
mini-batches will be of size `len(datasets) * batch_size`.
:param distributed_sampling: If True, apply the PyTorch
- :class:`DistributedSampler`. Defaults to False.
+ :class:`DistributedSampler`. Defaults to True.
+ Note: the distributed sampler is not applied if not running
+ a distributed training, even when True is passed.
:param kwargs: data loader arguments used to instantiate the loader for
each group separately. See pytorch :class:`DataLoader`.
"""
@@ -434,7 +504,9 @@ def __init__(
task-balanced, otherwise it creates a single data loader for the
buffer samples.
:param distributed_sampling: If True, apply the PyTorch
- :class:`DistributedSampler`. Defaults to False.
+ :class:`DistributedSampler`. Defaults to True.
+ Note: the distributed sampler is not applied if not running
+ a distributed training, even when True is passed.
:param kwargs: data loader arguments used to instantiate the loader for
each task separately. See pytorch :class:`DataLoader`.
"""
@@ -700,6 +772,7 @@ def _make_data_loader(
force_no_workers: bool = False,
):
data_loader_args = data_loader_args.copy()
data_loader_args.pop("ffcv_args", None)

collate_from_data_or_kwargs(dataset, data_loader_args)

@@ -747,6 +820,8 @@ def _make_data_loader_with_batched_sampler(
data_loader_args.pop("sampler", False)
data_loader_args.pop("drop_last", False)

data_loader_args.pop("ffcv_args", None)

return DataLoader(dataset, batch_sampler=batch_sampler, **data_loader_args)

