diff --git a/avalanche/benchmarks/utils/data_loader.py b/avalanche/benchmarks/utils/data_loader.py index 492fa7cdc..09b97fc6d 100644 --- a/avalanche/benchmarks/utils/data_loader.py +++ b/avalanche/benchmarks/utils/data_loader.py @@ -31,6 +31,10 @@ from avalanche.benchmarks.utils.data import AvalancheDataset from avalanche.benchmarks.utils.data_attribute import DataAttribute +from avalanche.benchmarks.utils.ffcv_support.ffcv_components import ( + HybridFfcvLoader, + has_ffcv_support, +) from avalanche.distributed.distributed_helper import DistributedHelper from torch.utils.data.sampler import Sampler, BatchSampler @@ -93,7 +97,9 @@ def __init__( will not contribute to the minibatch composition near the end of the epoch. :param distributed_sampling: If True, apply the PyTorch - :class:`DistributedSampler`. Defaults to False. + :class:`DistributedSampler`. Defaults to True. + Note: the distributed sampler is not applied if not running + a distributed training, even when True is passed. :param never_ending: If True, this data loader will cycle indefinitely by iterating over all datasets again and again and the epoch will never end. In this case, the `termination_dataset` and @@ -125,8 +131,10 @@ def __init__( self.termination_dataset: int = termination_dataset self.never_ending: bool = never_ending + self.loader_kwargs, self.ffcv_args = self._extract_ffcv_args(self.loader_kwargs) + # Only used if persistent_workers == True in loader kwargs - self._persistent_loader = None + self._persistent_loader: Optional[DataLoader] = None if "collate_fn" not in self.loader_kwargs: self.loader_kwargs["collate_fn"] = self.datasets[0].collate_fn @@ -155,7 +163,7 @@ def __init__( _make_data_loader( data_subset, distributed_sampling, - kwargs, + self.loader_kwargs, subset_mb_size, force_no_workers=True, )[0] @@ -165,7 +173,7 @@ def __init__( _make_data_loader( self.datasets[self.termination_dataset], distributed_sampling, - kwargs, + self.loader_kwargs, self.batch_sizes[self.termination_dataset], force_no_workers=True, )[0] @@ -198,23 +206,68 @@ def _get_loader(self): self.loader_kwargs, ) - overall_dataset = ConcatDataset(self.datasets) - multi_dataset_batch_sampler = MultiDatasetSampler( - overall_dataset.datasets, + self.datasets, samplers, termination_dataset_idx=self.termination_dataset, oversample_small_datasets=self.oversample_small_datasets, never_ending=self.never_ending, ) - loader = _make_data_loader_with_batched_sampler( - overall_dataset, - batch_sampler=multi_dataset_batch_sampler, + if has_ffcv_support(self.datasets): + loader = self._make_ffcv_loader( + self.datasets, + multi_dataset_batch_sampler, + ) + else: + loader = self._make_pytorch_loader( + self.datasets, + multi_dataset_batch_sampler, + ) + + return loader + + def _make_pytorch_loader( + self, datasets: List[AvalancheDataset], batch_sampler: Sampler[List[int]] + ): + return _make_data_loader_with_batched_sampler( + ConcatDataset(datasets), + batch_sampler=batch_sampler, data_loader_args=self.loader_kwargs, ) - return loader + def _make_ffcv_loader( + self, datasets: List[AvalancheDataset], batch_sampler: Sampler[List[int]] + ): + ffcv_args = dict(self.ffcv_args) + device = ffcv_args.pop("device") + print_ffcv_summary = ffcv_args.pop("print_ffcv_summary") + + persistent_workers = self.loader_kwargs.get("persistent_workers", False) + + return HybridFfcvLoader( + dataset=AvalancheDataset(datasets), + batch_sampler=batch_sampler, + ffcv_loader_parameters=ffcv_args, + device=device, + persistent_workers=persistent_workers, + print_ffcv_summary=print_ffcv_summary, + ) + + def _extract_ffcv_args(self, loader_args): + loader_args = dict(loader_args) + ffcv_args: Dict[str, Any] = loader_args.pop("ffcv_args", dict()) + ffcv_args.setdefault("device", None) + ffcv_args.setdefault("print_ffcv_summary", False) + + for arg_name, arg_value in loader_args.items(): + if arg_name in ffcv_args: + # Already specified in ffcv_args -> discard + continue + + if arg_name in HybridFfcvLoader.VALID_FFCV_PARAMS: + ffcv_args[arg_name] = arg_value + return loader_args, ffcv_args def __len__(self): return self.n_iterations @@ -238,6 +291,17 @@ def _create_samplers( return samplers +class SingleDatasetDataLoader(MultiDatasetDataLoader): + """ + Replacement of PyTorch DataLoader that also supports + the additional loading mechanisms implemented in + :class:`MultiDatasetDataLoader`. + """ + + def __init__(self, datasets: AvalancheDataset, batch_size: int = 1, **kwargs): + super().__init__([datasets], [batch_size], **kwargs) + + class GroupBalancedDataLoader(MultiDatasetDataLoader): """Data loader that balances data from multiple datasets.""" @@ -265,7 +329,9 @@ def __init__( :param batch_size: the size of the batch. It must be greater than or equal to the number of groups. :param distributed_sampling: If True, apply the PyTorch - :class:`DistributedSampler`. Defaults to False. + :class:`DistributedSampler`. Defaults to True. + Note: the distributed sampler is not applied if not running + a distributed training, even when True is passed. :param kwargs: data loader arguments used to instantiate the loader for each group separately. See pytorch :class:`DataLoader`. """ @@ -302,7 +368,7 @@ def __init__( data: AvalancheDataset, batch_size: int = 32, oversample_small_groups: bool = False, - distributed_sampling: bool = True, # TODO: doc fix + distributed_sampling: bool = True, **kwargs ): """Task-balanced data loader for Avalanche's datasets. @@ -320,7 +386,9 @@ def __init__( :param oversample_small_groups: whether smaller tasks should be oversampled to match the largest one. :param distributed_sampling: If True, apply the PyTorch - :class:`DistributedSampler`. Defaults to False. + :class:`DistributedSampler`. Defaults to True. + Note: the distributed sampler is not applied if not running + a distributed training, even when True is passed. :param kwargs: data loader arguments used to instantiate the loader for each task separately. See pytorch :class:`DataLoader`. """ @@ -374,7 +442,9 @@ def __init__( final mini-batch, NOT the final mini-batch size. The final mini-batches will be of size `len(datasets) * batch_size`. :param distributed_sampling: If True, apply the PyTorch - :class:`DistributedSampler`. Defaults to False. + :class:`DistributedSampler`. Defaults to True. + Note: the distributed sampler is not applied if not running + a distributed training, even when True is passed. :param kwargs: data loader arguments used to instantiate the loader for each group separately. See pytorch :class:`DataLoader`. """ @@ -434,7 +504,9 @@ def __init__( task-balanced, otherwise it creates a single data loader for the buffer samples. :param distributed_sampling: If True, apply the PyTorch - :class:`DistributedSampler`. Defaults to False. + :class:`DistributedSampler`. Defaults to True. + Note: the distributed sampler is not applied if not running + a distributed training, even when True is passed. :param kwargs: data loader arguments used to instantiate the loader for each task separately. See pytorch :class:`DataLoader`. """ @@ -700,6 +772,7 @@ def _make_data_loader( force_no_workers: bool = False, ): data_loader_args = data_loader_args.copy() + data_loader_args.pop("ffcv_args", None) collate_from_data_or_kwargs(dataset, data_loader_args) @@ -747,6 +820,8 @@ def _make_data_loader_with_batched_sampler( data_loader_args.pop("sampler", False) data_loader_args.pop("drop_last", False) + data_loader_args.pop("ffcv_args", None) + return DataLoader(dataset, batch_sampler=batch_sampler, **data_loader_args) diff --git a/avalanche/benchmarks/utils/dataset_traversal_utils.py b/avalanche/benchmarks/utils/dataset_traversal_utils.py new file mode 100644 index 000000000..d3f1f82fe --- /dev/null +++ b/avalanche/benchmarks/utils/dataset_traversal_utils.py @@ -0,0 +1,375 @@ +from collections import OrderedDict, defaultdict, deque +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) +from avalanche.benchmarks.scenarios.generic_scenario import CLScenario +from avalanche.benchmarks.utils.data import ( + _FlatDataWithTransform, + AvalancheDataset, +) +from avalanche.benchmarks.utils.dataset_definitions import IDataset +from avalanche.benchmarks.utils.dataset_utils import find_list_from_index +from avalanche.benchmarks.utils.flat_data import FlatData + +from torch.utils.data import Subset, ConcatDataset, Dataset + +from avalanche.benchmarks.utils.transform_groups import EmptyTransformGroups +from avalanche.benchmarks.utils.transforms import TupleTransform +from torchvision.datasets.vision import StandardTransform + + +def dataset_list_from_benchmark(benchmark: CLScenario) -> List[AvalancheDataset]: + """ + Traverse a benchmark and obtain the dataset of each experience. + + This will traverse all streams in alphabetical order. + + :param benchmark: The benchmark to traverse. + :return: The list of datasets. + """ + single_datasets = OrderedDict() + for stream_name in sorted(benchmark.streams.keys()): + stream = benchmark.streams[stream_name] + for experience in stream: + dataset: AvalancheDataset = experience.dataset + if dataset not in single_datasets: + single_datasets[dataset] = dataset + + return list(single_datasets.keys()) + + +def flat_datasets_from_benchmark( + benchmark: CLScenario, include_leaf_transforms: bool = True +): + """ + Obtain a list of flattened datasets from a benchmark. + + In practice, this function will traverse all the + datasets in the benchmark to find the leaf datasets. + A dataset can be traversed and flattened to (one or more) leaf + dataset(s) if all subset and dataset concatenations point to a + single leaf dataset and if transformations are the same across + all paths. + + Traversing the dataset means traversing :class:`AvalancheDataset` + as well as PyTorch :class:`Subset` and :class:`ConcatDataset` to + obtain the leaf datasets, the indices, and the transformations chain. + + Note: this means that datasets will be plain PyTorch datasets, + not :class:`AvalancheDataset` (Avalanche datasets are traversed). + + In common benchmarks, this returns one dataset for the train + and one dataset for test. + + :param benchmark: The benchmark to traverse. + :param include_leaf_transforms: If True, include the transformations + found in the leaf dataset in the transforms list. Defaults to True. + :return: The list of leaf datasets. Each element in the list is + a tuple `(dataset, indices, transforms)`. + """ + single_datasets = dataset_list_from_benchmark(benchmark) + leaves = leaf_datasets( + AvalancheDataset(single_datasets), + include_leaf_transforms=include_leaf_transforms, + ) + + result = [] + for dataset, indices_and_transforms in leaves.items(): + # Check that all transforms are the same + first_transform = indices_and_transforms[0][1] + same_transforms = all([first_transform == t for _, t in indices_and_transforms]) + + if not same_transforms: + for indices, transforms in indices_and_transforms: + result.append((dataset, indices, transforms)) + continue + + flat_indices = [i for i, _ in indices_and_transforms] + + result.append((dataset, flat_indices, first_transform)) + return result + + +T = TypeVar("T") +Y = TypeVar("Y") +TraverseT = Union[Dataset, AvalancheDataset, FlatData, IDataset] + + +def _traverse_supported_dataset_with_intermediate( + dataset: TraverseT, + values_selector: Callable[ + [TraverseT, Optional[List[int]], Optional[T]], Optional[List[Y]] + ], + intermediate_selector: Optional[Callable[[TraverseT, Optional[T]], T]] = None, + intermediate: Optional[T] = None, + indices: Optional[List[int]] = None, +) -> List[Y]: + """ + Traverse the given dataset by gathering required info. + + The given dataset is traversed by covering all sub-datasets + contained in PyTorch :class:`Subset` and :class`ConcatDataset` + as well as :class:`AvalancheDataset`. + + For each dataset, the `values_selector` will be called to gather + the required information. The values returned by the given selector + are then concatenated to create a final list of values. + + While traversing, the `intermediate_selector` (if provided) + will be called to create a chain of intermediate values, which + are passed to `values_selector`. + + :param dataset: The dataset to traverse. + :param values_selector: A function that, given the dataset + and the indices to consider (which may be None if the entire + dataset must be considered), returns a list of selected values. + :returns: The list of selected values. + """ + + if intermediate_selector is not None: + intermediate = intermediate_selector(dataset, intermediate) + + leaf_result: Optional[List[Y]] = values_selector(dataset, indices, intermediate) + + if leaf_result is not None: + if len(leaf_result) == 0: + raise RuntimeError("Empty result") + return leaf_result + + if isinstance(dataset, AvalancheDataset): + return list( + _traverse_supported_dataset_with_intermediate( + dataset._flat_data, + values_selector, + intermediate_selector=intermediate_selector, + indices=indices, + intermediate=intermediate, + ) + ) + + if isinstance(dataset, Subset): + if indices is None: + indices = [dataset.indices[x] for x in range(len(dataset))] + else: + indices = [dataset.indices[x] for x in indices] + + return list( + _traverse_supported_dataset_with_intermediate( + dataset.dataset, + values_selector, + intermediate_selector=intermediate_selector, + indices=indices, + intermediate=intermediate, + ) + ) + + if isinstance(dataset, FlatData) and dataset._indices is not None: + if indices is None: + indices = [dataset._indices[x] for x in range(len(dataset))] + else: + indices = [dataset._indices[x] for x in indices] + + if isinstance(dataset, (ConcatDataset, FlatData)): + result: List[Y] = [] + + concatenated_datasets: Sequence[TraverseT] + if isinstance(dataset, ConcatDataset): + concatenated_datasets = dataset.datasets + else: + concatenated_datasets = dataset._datasets + + if indices is None: + for c_dataset in concatenated_datasets: + result += list( + _traverse_supported_dataset_with_intermediate( + c_dataset, + values_selector, + intermediate_selector=intermediate_selector, + indices=indices, + intermediate=intermediate, + ) + ) + if len(result) == 0: + raise RuntimeError("Empty result") + return result + + datasets_to_indexes = defaultdict(list) + indexes_to_dataset = [] + datasets_len = [] + recursion_result = [] + + all_size = 0 + for c_dataset in concatenated_datasets: + len_dataset = len(c_dataset) + datasets_len.append(len_dataset) + all_size += len_dataset + + for subset_idx in indices: + dataset_idx, pattern_idx = find_list_from_index( + subset_idx, datasets_len, all_size + ) + datasets_to_indexes[dataset_idx].append(pattern_idx) + indexes_to_dataset.append(dataset_idx) + + for dataset_idx, c_dataset in enumerate(concatenated_datasets): + recursion_result.append( + deque( + _traverse_supported_dataset_with_intermediate( + c_dataset, + values_selector, + intermediate_selector=intermediate_selector, + indices=datasets_to_indexes[dataset_idx], + intermediate=intermediate, + ) + ) + ) + + result = [] + for idx in range(len(indices)): + dataset_idx = indexes_to_dataset[idx] + result.append(recursion_result[dataset_idx].popleft()) + + if len(result) == 0: + raise RuntimeError("Empty result") + return result + + raise ValueError("Error: can't find the needed data in the given dataset") + + +def _extract_transforms_from_standard_dataset(dataset): + if hasattr(dataset, "transforms"): + # Has torchvision >= v0.3.0 transforms + # Ignore transform and target_transform + transforms = getattr(dataset, "transforms") + if isinstance(transforms, StandardTransform): + if ( + transforms.transform is not None + or transforms.target_transform is not None + ): + return TupleTransform( + [transforms.transform, transforms.target_transform] + ) + elif hasattr(dataset, "transform") or hasattr(dataset, "target_transform"): + return TupleTransform( + [getattr(dataset, "transform"), getattr(dataset, "target_transform")] + ) + + return None + + +def leaf_datasets(dataset: TraverseT, include_leaf_transforms: bool = True): + """ + Obtains the leaf datasets of a Dataset. + + This is a low level utility. For most use cases, it is better to use + :func:`single_flat_dataset` or :func:`flat_datasets_from_benchmark`. + + :param dataset: The dataset to traverse. + :param include_leaf_transforms: If True, include the transformations + found in the leaf dataset in the transforms list. Defaults to True. + :return: A dictionary mapping each leaf dataset to a list of tuples. + Each tuple contains two elements: the index and the transformation + applied to that exemplar. + """ + + def leaf_selector(subset, indices, transforms): + if isinstance(subset, (AvalancheDataset, FlatData, Subset, ConcatDataset)): + # Returning None => continue traversing + return None + + if indices is None: + indices = range(len(subset)) + + if include_leaf_transforms: + leaf_transforms = _extract_transforms_from_standard_dataset(subset) + + if leaf_transforms is not None: + transforms = list(transforms) + [leaf_transforms] + + return [(subset, idx, transforms) for idx in indices] + + def transform_selector(subset, transforms): + if isinstance(subset, _FlatDataWithTransform): + if subset._frozen_transform_groups is not None and not isinstance( + subset._frozen_transform_groups, EmptyTransformGroups + ): + transforms = list(transforms) + [ + subset._frozen_transform_groups[ + subset._frozen_transform_groups.current_group + ] + ] + if subset._transform_groups is not None and not isinstance( + subset._transform_groups, EmptyTransformGroups + ): + transforms = list(transforms) + [ + subset._transform_groups[subset._transform_groups.current_group] + ] + + return transforms + + leaves = _traverse_supported_dataset_with_intermediate( + dataset, + leaf_selector, + intermediate_selector=transform_selector, + intermediate=[], + ) + + leaves_dict: Dict[Any, List[Tuple[int, Any]]] = defaultdict(list) + for leaf_dataset, idx, transform in leaves: + transform_reversed = list(reversed(transform)) + leaves_dict[leaf_dataset].append((idx, transform_reversed)) + + return leaves_dict + + +def single_flat_dataset(dataset, include_leaf_transforms: bool = True): + """ + Obtains the single leaf dataset of a Dataset. + + A dataset can be traversed and flattened to a single leaf dataset + if all subset and dataset concatenations point to a single leaf + dataset and if transformations are the same across all paths. + + :param dataset: The dataset to traverse. + :param include_leaf_transforms: If True, include the transformations + found in the leaf dataset in the transforms list. Defaults to True. + :return: A tuple containing three elements: the dataset, the list of + indices, and the list of transformations. If the dataset cannot + be flattened to a single dataset, None is returned. + """ + leaves_dict = leaf_datasets( + dataset, include_leaf_transforms=include_leaf_transforms + ) + if len(leaves_dict) != 1: + return None + + # Obtain the single dataset element + dataset = list(leaves_dict.keys())[0] + indices_and_transforms = list(leaves_dict.values())[0] + + # Check that all transforms are the same + first_transform = indices_and_transforms[0][1] + same_transforms = all([first_transform == t for _, t in indices_and_transforms]) + + if not same_transforms: + return None + + flat_indices = [i for i, _ in indices_and_transforms] + + return dataset, flat_indices, first_transform + + +__all__ = [ + "dataset_list_from_benchmark", + "flat_datasets_from_benchmark", + "leaf_datasets", + "single_flat_dataset", +] diff --git a/avalanche/benchmarks/utils/ffcv_support/__init__.py b/avalanche/benchmarks/utils/ffcv_support/__init__.py new file mode 100644 index 000000000..200a4857c --- /dev/null +++ b/avalanche/benchmarks/utils/ffcv_support/__init__.py @@ -0,0 +1 @@ +from .ffcv_components import * diff --git a/avalanche/benchmarks/utils/ffcv_support/ffcv_components.py b/avalanche/benchmarks/utils/ffcv_support/ffcv_components.py new file mode 100644 index 000000000..9a86d66f7 --- /dev/null +++ b/avalanche/benchmarks/utils/ffcv_support/ffcv_components.py @@ -0,0 +1,669 @@ +""" +Components used to enable the FFCV dataloading mechanisms. + +It is usually sufficient to call `enable_ffcv` on the given +benchmark to get started with the FFCV support. + +Please refer to the examples for more details. +""" + +from dataclasses import dataclass +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Union, +) +from collections import OrderedDict +import warnings +import numpy as np + +import torch +from torch.utils.data.sampler import Sampler +from avalanche.benchmarks.scenarios.generic_scenario import CLScenario +from avalanche.benchmarks.utils.data import AvalancheDataset +from avalanche.benchmarks.utils.dataset_traversal_utils import ( + flat_datasets_from_benchmark, + single_flat_dataset, +) + +from avalanche.benchmarks.utils.utils import concat_datasets + +if TYPE_CHECKING: + from avalanche.benchmarks.utils.ffcv_support.ffcv_support_internals import ( + EncoderDef, + DecoderDef, + ) + + +FFCV_EXPERIMENTAL_WARNED = False + + +@dataclass +class FFCVInfo: + path: Path + encoder_dictionary: "EncoderDef" + decoder_dictionary: "DecoderDef" + decoder_includes_transformations: bool + device: torch.device + + +def enable_ffcv( + benchmark: CLScenario, + write_dir: Union[str, Path], + device: torch.device, + ffcv_parameters: Dict[str, Any], + force_overwrite: bool = False, + encoder_def: "EncoderDef" = None, + decoder_def: "DecoderDef" = None, + decoder_includes_transformations: Optional[bool] = None, + print_summary: bool = True, +) -> None: + """ + Enables the support for FFCV data loading for the given benchmark. + + Once the support is added, the strategies will create FFCV-based dataloaders + instead of the usual PyTorch-based ones. + + Please note that FFCV is an optional dependency whose installation process + is described in the official FFCV website. + + This function supposes that the benchmark is based on a few base datasets + (usually one for train and one for test). This is the case for Split-* benchmarks + and is also the usual case for the vast majority of benchmarks. The support for + "sparse" datasets such as CORe50 will be added in the near future. + + When this function is first called on a benchmark, the underlying datasets are + serialized on disk. If a `encoder_def` is given, that definition is used. Otherwise, + a definition is searched in the leaf dataset (`_ffcv_encoder` field, if available). + If such a definition is not found, it is created automatically. + Refer to the FFCV documentation for more details on the encoder pipeline. + + Please note that the serialized datasets are independent of the benchmark seed, + number of experiences, presence of task labels, etcetera. This means that the + same folder can be reused for the same benchmark type. + + The definition of the decoder pipeline is created later, if not + given using `decoder_def`. However, creating the decoder pipeline is a complex + task and not all field types and transformations are fully supported. Consider + passing an explicit `decoder_def` in case of unexpected outputs. If a decoder + definition is not passed explicitly, Avalanche will try to use the dataset + `_ffcv_decoder` field if available before attempting to create one automatically. + + See the `ffcv` examples for more info on how to tune the decoder definitions and for + examples of advanced use of the FFCV support. + + :param benchmark: The benchmark for which the support for FFCV loader should be enabled. + :param write_dir: Where the datasets should be serialized in FFCV format. + :param device: The device used for training. + :param ffcv_parameters: Parameters to be passed to FFCV writer and RGB fields. + :param force_overwrite: If True, serialized datasets already found in `write_dir` will be + overwritten. + :param encoder_def: The definition of the dataset fields. See the FFCV guide for more details. + :param decoder_def: The definition of the decoder pipeline. If not None, then + `decoder_includes_transformations` must be passed. + :param decoder_includes_transformations: If True, then Avalanche will treat `decoder_def` as + the complete pipeline, transformations included. If False, Avalanche will suppose that only + the decoder is passed for each field and transformations will be translated by Avalanche + from the torchvision ones. + :param print_summary: If True (default), will print some verbose info to stdout regaring the + datasets and pipelines. Once you have a complete working FFCV pipeline, you can consider + setting this to False. + """ + global FFCV_EXPERIMENTAL_WARNED + + if not FFCV_EXPERIMENTAL_WARNED: + warnings.warn("The support for FFCV is experimental. Use at your own risk!") + FFCV_EXPERIMENTAL_WARNED = True + + from ffcv.writer import DatasetWriter + from avalanche.benchmarks.utils.ffcv_support.ffcv_support_internals import ( + _make_ffcv_decoder, + _make_ffcv_encoder, + ) + + if decoder_def is not None: + if decoder_includes_transformations is None: + raise ValueError( + "When defining the decoder pipeline, " + "please specify `decoder_includes_transformations`" + ) + assert isinstance(decoder_includes_transformations, bool) + + if decoder_includes_transformations is None: + decoder_includes_transformations = False + + write_dir = Path(write_dir) + write_dir.mkdir(exist_ok=True, parents=True) + + flattened_datasets = flat_datasets_from_benchmark(benchmark) + + if print_summary: + print("FFCV will serialize", len(flattened_datasets), "datasets") + + for idx, (dataset, _, _) in enumerate(flattened_datasets): + if print_summary: + print("-" * 25, "Dataset", idx, "-" * 25) + + # Note: it is appropriate to serialize the dataset in its raw + # version (without transformations). Transformations will be + # applied at loading time. + with _SuppressTransformations(dataset): + dataset_ffcv_path = write_dir / f"dataset{idx}.beton" + + # Obtain the encoder dictionary + # The FFCV encoder is a ordered dictionary mapping each + # field (by name) to the field encoder. + # + # Example: + # { + # 'image': RGBImageField(), + # 'label: IntField() + # } + # + # Some fields (especcially the RGBImageField) accept + # some parameters that are here contained in ffcv_parameters. + encoder_dict = _make_ffcv_encoder(dataset, encoder_def, ffcv_parameters) + + if encoder_dict is None: + raise RuntimeError( + "Could not create the encoder pipeline for " "the given dataset" + ) + + if print_summary: + print("### Encoder ###") + for field_name, encoder_pipeline in encoder_dict.items(): + print(f'Field "{field_name}"') + print("\t", encoder_pipeline) + + # Obtain the decoder dictionary + # The FFCV decoder is a ordered dictionary mapping each + # field (by name) to the field pipeline. + # A field pipeline is made of a decoder followed by + # transformations. + # + # Example: + # { + # 'image': [ + # SimpleRGBImageDecoder(), + # RandomHorizontalFlip(), + # ToTensor(), + # ... + # ], + # 'label: [IntDecoder(), ToTensor(), Squeeze(), ...] + # } + # + # However, unless the user specified a full custom decoder + # pipeline, Avalanche will obtain only the decoder for each + # field. The transformations, which may vary, will be added by the + # data loader. + decoder_dict = _make_ffcv_decoder( + dataset, decoder_def, ffcv_parameters, encoder_dictionary=encoder_dict + ) + + if decoder_dict is None: + raise RuntimeError( + "Could not create the decoder pipeline " "for the given dataset" + ) + + if print_summary: + print("### Decoder ###") + for field_name, decoder_pipeline in decoder_dict.items(): + print(f'Field "{field_name}"') + for pipeline_element in decoder_pipeline: + print("\t", pipeline_element) + + if decoder_includes_transformations: + print("This pipeline already includes transformations") + else: + print("This pipeline does not include transformations") + + if force_overwrite or not dataset_ffcv_path.exists(): + if print_summary: + print("Serializing dataset to:", str(dataset_ffcv_path)) + + writer_kwarg_parameters = dict() + if "page_size" in ffcv_parameters: + writer_kwarg_parameters["page_size"] = ffcv_parameters["page_size"] + + if "num_workers" in ffcv_parameters: + writer_kwarg_parameters["num_workers"] = ffcv_parameters[ + "num_workers" + ] + + writer = DatasetWriter( + str(dataset_ffcv_path), + OrderedDict(encoder_dict), + **writer_kwarg_parameters, + ) + writer.from_indexed_dataset(dataset) + + if print_summary: + print("Dataset serialized successfully") + + # Set the FFCV file path and encoder/decoder dictionaries + # Those will be used later in the data loading process and may + # also be useful for debugging purposes + dataset.ffcv_info = FFCVInfo( + dataset_ffcv_path, + encoder_dict, + decoder_dict, + decoder_includes_transformations, + torch.device(device), + ) + + if print_summary: + print("-" * 61) + + +class _SuppressTransformations: + """ + Suppress the transformations of a dataset. + + This will act on the transformation fields. + + Note: there are no ways to suppress hard coded transformations + or transformations held in fields with custom names. + """ + + SUPPRESS_FIELDS = ["transform", "target_transform", "transforms"] + + def __init__(self, dataset): + self.dataset = dataset + self._held_out_transforms = dict() + + def __enter__(self): + self._held_out_transforms = dict() + for transform_field in _SuppressTransformations.SUPPRESS_FIELDS: + if hasattr(self.dataset, transform_field): + field_content = getattr(self.dataset, transform_field) + self._held_out_transforms[transform_field] = field_content + setattr(self.dataset, transform_field, None) + + def __exit__(self, *_): + for transform_field, field_content in self._held_out_transforms.items(): + setattr(self.dataset, transform_field, field_content) + self._held_out_transforms.clear() + + +class _GetItemDataset: + def __init__( + self, + dataset: AvalancheDataset, + reversed_indices: Dict[int, int], + collate_fn=None, + ): + self.dataset: AvalancheDataset = dataset + self.reversed_indices: Dict[int, int] = reversed_indices + + all_data_attributes = self.dataset._data_attributes.values() + self.get_item_data_attributes = list( + filter(lambda x: x.use_in_getitem, all_data_attributes) + ) + + self.collate_fn = ( + collate_fn if collate_fn is not None else self.dataset.collate_fn + ) + + if self.collate_fn is None: + raise RuntimeError("Undefined collate function") + + def __getitem__(self, indices): + elements_from_attributes = [] + for idx in indices: + reversed_idx = self.reversed_indices[int(idx)] + values = [] + for da in self.get_item_data_attributes: + values.append(da[reversed_idx]) + elements_from_attributes.append(tuple(values)) + + return tuple(self.collate_fn(elements_from_attributes)) + + +def has_ffcv_support(datasets: List[AvalancheDataset]): + """ + Checks if the support for FFCV was enabled for the given + dataset list. + + This will 1) check if all the given :class:`AvalancheDataset` + point to the same leaf dataset and 2) if the leaf dataset + has the proper FFCV info setted by the :func:`enable_ffcv` + function. + + :param dataset: The list of datasets. + :return: True if FFCV can be used to load the given datasets, + False otherwise. + """ + try: + flat_set = single_flat_dataset(concat_datasets(datasets)) + except Exception: + return False + + if flat_set is None: + return False + + leaf_dataset = flat_set[0] + + return hasattr(leaf_dataset, "ffcv_info") + + +class _MappedBatchsampler(Sampler[List[int]]): + """ + Internal utility to better support the `set_epoch` method in FFCV. + + This is a wrapper of a batch sampler that may be based on a PyTorch + :class:`DistributedSampler`. This allows passing the `set_epoch` + call to the underlying sampler. + """ + + def __init__(self, batch_sampler: Sampler[List[int]], indices): + self.batch_sampler = batch_sampler + self.indices = indices + + def __iter__(self): + for batch in self.batch_sampler: + batch_mapped = [self.indices[int(x)] for x in batch] + yield np.array(batch_mapped) + + def __len__(self): + return len(self.batch_sampler) + + def set_epoch(self, epoch: int): + if hasattr(self.batch_sampler, "set_epoch"): + self.batch_sampler.set_epoch(epoch) + else: + if hasattr(self.batch_sampler, "sampler"): + if hasattr(self.batch_sampler.sampler, "set_epoch"): + self.batch_sampler.sampler.set_epoch(epoch) + + +class HybridFfcvLoader: + """ + A dataloader used to load :class:`AvalancheDataset`s for which + the FFCV support was previously enabled by using :func:`enable_ffcv`. + + This is not a pure wrapper of a FFCV loader: this hybrid dataloader + is in charge of both creating the FFCV loader and merging + the Avalanche-specific info contained in the :class:`DataAttribute` + fields of the datasets (such as task labels). + """ + + ALREADY_COVERED_PARAMS = set( + ( + "fname", + "batch_size", + "order", + "distributed", + "seed", + "indices", + "pipelines", + ) + ) + + VALID_FFCV_PARAMS = set( + ( + "fname", + "batch_size", + "num_workers", + "os_cache", + "order", + "distributed", + "seed", + "indices", + "pipelines", + "custom_fields", + "drop_last", + "batches_ahead", + "recompile", + ) + ) + + def __init__( + self, + dataset: AvalancheDataset, + batch_sampler: Iterable[List[int]], + ffcv_loader_parameters: Dict[str, Any], + device: Optional[Union[str, torch.device]] = None, + persistent_workers: bool = False, + print_ffcv_summary: bool = True, + start_immediately: bool = False, + ): + """ + Creates an instance of the Avalanche-FFCV hybrid dataloader. + + :param dataset: The dataset to be loaded. + :param batch_sampler: The batch sampler to use. + :param ffcv_loader_parameters: The FFCV-specific parameters to pass to + the FFCV loader. Should not contain the elements such as `fname`, + `batch_size`, `order`, and all the parameters listed in the + `ALREADY_COVERED_PARAMS` class field, as they are already set by Avalanche. + :param device: The target device. + :param persistent_workers: If True, this loader will not re-create the FFCV loader + between epochs. Defaults to False. + :param print_ffcv_summary: If True, a summary of the decoder pipeline (and additional + useful info) will be printed. Defaults to True. + :param start_immediately: If True, the FFCV loader should be started immediately. + Defaults to False. + """ + from avalanche.benchmarks.utils.ffcv_support.ffcv_loader import _CustomLoader + + self.dataset: AvalancheDataset = dataset + self.batch_sampler = batch_sampler + self.ffcv_loader_parameters = ffcv_loader_parameters + self.persistent_workers: bool = persistent_workers + + for param_name in HybridFfcvLoader.ALREADY_COVERED_PARAMS: + if param_name in self.ffcv_loader_parameters: + warnings.warn( + f"`{param_name}` should not be passed to the ffcv loader!" + ) + + if print_ffcv_summary: + print("-" * 15, "HybridFfcvLoader summary", "-" * 15) + + ffcv_info = self._extract_ffcv_info( + dataset=self.dataset, device=device, print_summary=print_ffcv_summary + ) + + if print_ffcv_summary: + print("-" * 56) + + ( + self.ffcv_dataset_path, + self.ffcv_decoder_dictionary, + self.leaf_indices, + self.get_item_dataset, + self.device, + ) = ffcv_info + + self._persistent_loader: Optional["_CustomLoader"] = None + + if start_immediately: + # If persistent_workers is False, this loader will be + # used at first __iter__ and immediately set to None + self._persistent_loader = self._make_loader() + + @staticmethod + def _extract_ffcv_info( + dataset: AvalancheDataset, + device: Optional[Union[str, torch.device]] = None, + print_summary: bool = True, + ): + from avalanche.benchmarks.utils.ffcv_support.ffcv_transform_utils import ( + adapt_transforms, + check_transforms_consistency, + ) + + # Obtain the leaf dataset, the indices, + # and the transformations to apply + flat_set_def = single_flat_dataset(dataset) + if flat_set_def is None: + raise RuntimeError("The dataset cannot be traversed to the leaf dataset.") + + leaf_dataset, indices, transforms = flat_set_def + if print_summary: + print( + "The input AvalancheDataset is a subset of the leaf dataset", + leaf_dataset, + ) + print("The input dataset contains", len(indices), "elements") + print("The original chain of transformations is:") + for t in transforms: + print("\t", t) + print("Will try to translate those transformations to FFCV") + + ffcv_info: FFCVInfo = leaf_dataset.ffcv_info + + ffcv_dataset_path = ffcv_info.path + ffcv_decoder_dictionary = ffcv_info.decoder_dictionary + decoder_includes_transformations = ffcv_info.decoder_includes_transformations + + if device is None: + device = ffcv_info.device + device = torch.device(device) + + # Map the indices so that we know how leaf + # dataset indices are mapped in the AvalancheDataset + reversed_indices = dict() + for avl_idx, leaf_idx in enumerate(indices): + reversed_indices[leaf_idx] = avl_idx + + # We will use the GetItemDataset to get those Avalanche-specific + # dynamic fields that are not loaded by FFCV, such as the task label + get_item_dataset = _GetItemDataset(dataset, reversed_indices=reversed_indices) + + if print_summary: + if len(get_item_dataset.get_item_data_attributes) > 0: + print( + "The following data attributes are returned in " + "the example tuple:" + ) + for da in get_item_dataset.get_item_data_attributes: + print("\t", da.name) + else: + print("No data attributes are returned in the example tuple.") + + # Defensive copy + # Alas, FFCV Loader internally modifies it, so this is also + # needed when decoder_includes_transformations is True + ffcv_decoder_dictionary = OrderedDict(ffcv_decoder_dictionary) + + if not decoder_includes_transformations: + # Adapt the transformations (usually from torchvision) to FFCV. + # Most torchvision transformations cannot be mapped to FFCV ones, + # but they still work. + ffcv_decoder_dictionary_lst = list(ffcv_decoder_dictionary.values()) + + adapted_transforms = adapt_transforms( + transforms, ffcv_decoder_dictionary_lst, device=device + ) + + for i, field_name in enumerate(ffcv_decoder_dictionary.keys()): + ffcv_decoder_dictionary[field_name] = adapted_transforms[i] + + for field_name, field_decoder in ffcv_decoder_dictionary.items(): + if print_summary: + print(f'Checking pipeline for field "{field_name}"') + no_issues = check_transforms_consistency(field_decoder) + + if print_summary and no_issues: + print(f"No issues for this field") + + if print_summary: + print("### The final chain of transformations is: ###") + for field_name, field_transforms in ffcv_decoder_dictionary.items(): + print(f'Field "{field_name}":') + for t in field_transforms: + print("\t", t) + + return ( + ffcv_dataset_path, + ffcv_decoder_dictionary, + indices, + get_item_dataset, + device, + ) + + def _make_loader(self): + from ffcv.loader import OrderOption + from avalanche.benchmarks.utils.ffcv_support.ffcv_loader import _CustomLoader + + ffcv_dataset_path = self.ffcv_dataset_path + ffcv_decoder_dictionary = OrderedDict(self.ffcv_decoder_dictionary) + leaf_indices = list(self.leaf_indices) + + return _CustomLoader( + str(ffcv_dataset_path), + batch_size=len(leaf_indices) // len(self.batch_sampler), # Not used + indices=leaf_indices, + order=OrderOption.SEQUENTIAL, + pipelines=ffcv_decoder_dictionary, + batch_sampler=_MappedBatchsampler(self.batch_sampler, leaf_indices), + **self.ffcv_loader_parameters, + ) + + def __iter__(self): + from avalanche.benchmarks.utils.ffcv_support.ffcv_epoch_iterator import ( + _CustomEpochIterator, + ) + + get_item_dataset = self.get_item_dataset + + # Instantiate the FFCV loader + if self._persistent_loader is not None: + ffcv_loader = self._persistent_loader + + if not self.persistent_workers: + # Corner case: + # This may happen if start_immediately is True + # but persistent_workers is False + self._persistent_loader = None + else: + ffcv_loader = self._make_loader() + + if self.persistent_workers: + self._persistent_loader = ffcv_loader + + epoch_iterator: "_CustomEpochIterator" = iter(ffcv_loader) + + for indices, batch in epoch_iterator: + # Before returning the batch, obtain the custom Avalanche values + # and add it to the batch. + # Those are the values not found in the FFCV dataset + # (and not stored on disk!). + # + # A common element is the task label, which is usually returned + # as the third element. + # + # In practice, those fields are "data attributes" + # of the input AvalancheDataset whose `use_in_getitem` + # field is True. + # + # This means in practice: + # 1. obtain the `batch` from FFCV (usually is a tuple `x, y`). + # 2. obtain the Avalanche values such as `t` (or others). + # We do this through the `get_item_dataset`. + # 3. create an overall tuple `x, y, t, ...`. + + elements_from_attributes = get_item_dataset[indices] + + elements_from_attributes_device = [] + + for element in elements_from_attributes: + if isinstance(element, torch.Tensor): + element = element.to(self.device, non_blocking=True) + elements_from_attributes_device.append(element) + + overall_batch = tuple(batch) + tuple(elements_from_attributes_device) + + yield overall_batch + + def __len__(self): + return len(self.batch_sampler) + + +__all__ = ["enable_ffcv", "has_ffcv_support", "HybridFfcvLoader"] diff --git a/avalanche/benchmarks/utils/ffcv_support/ffcv_epoch_iterator.py b/avalanche/benchmarks/utils/ffcv_support/ffcv_epoch_iterator.py new file mode 100644 index 000000000..7c36b633f --- /dev/null +++ b/avalanche/benchmarks/utils/ffcv_support/ffcv_epoch_iterator.py @@ -0,0 +1,119 @@ +""" +Custom version of the FFCV epoch iterator. +""" +from threading import Thread, Event, Lock +from queue import Queue +from typing import List, Sequence, TYPE_CHECKING + +from ffcv.traversal_order.quasi_random import QuasiRandom +from ffcv.loader.epoch_iterator import ( + EpochIterator as FFCVEpochIterator, + QUASIRANDOM_ERROR_MSG, +) + +import torch + +if TYPE_CHECKING: + from avalanche.benchmarks.utils.ffcv_support.ffcv_loader import _CustomLoader + +IS_CUDA = torch.cuda.is_available() + + +class AtomicCounter: + """ + An atomic, thread-safe incrementing counter. + + Based on: + https://gist.github.com/benhoyt/8c8a8d62debe8e5aa5340373f9c509c7 + """ + + def __init__(self): + """Initialize a new atomic counter to 0.""" + self.value = 0 + self._lock = Lock() + + def increment(self): + """ + Atomically increment the counter by 1 and return the + previous value. + """ + with self._lock: + prev_value = self.value + self.value += 1 + return prev_value + + +class _QueueWithIndex(Queue): + """ + A Python Queue that also returns the index of the inserted element. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._counter = AtomicCounter() + + def _put(self, item): + item_index = self._counter.increment() + super()._put((item_index, item)) + + +class _CustomEpochIterator(FFCVEpochIterator, Thread): + def __init__(self, loader: "_CustomLoader", batches: Sequence[List[int]]): + Thread.__init__(self, daemon=True) + self.loader: "_CustomLoader" = loader + self.metadata = loader.reader.metadata + self.current_batch_slot = 0 + self.batches = batches + self.iter_ixes = iter(batches) + self.closed = False + self.output_queue = _QueueWithIndex(self.loader.batches_ahead) + self.terminate_event = Event() + self.memory_context = self.loader.memory_manager.schedule_epoch(batches) + + if IS_CUDA: + self.current_stream = torch.cuda.current_stream() + + try: + self.memory_context.__enter__() + except MemoryError as e: + if not isinstance(loader.traversal_order, QuasiRandom): + print(QUASIRANDOM_ERROR_MSG) + print("Full error below:") + + raise e + + self.storage_state = self.memory_context.state + + self.cuda_streams = [ + (torch.cuda.Stream() if IS_CUDA else None) + for _ in range(self.loader.batches_ahead + 2) + ] + + max_batch_size = max(map(len, batches), default=0) + + self.memory_allocations = self.loader.graph.allocate_memory( + max_batch_size, self.loader.batches_ahead + 2 + ) + + self.start() + + def __next__(self): + result = self.output_queue.get() + batch_index, result = result + + if result is None: + self.close() + raise StopIteration() + + slot, result = result + indices = list(self.batches[batch_index]) + + if IS_CUDA: + stream = self.cuda_streams[slot] + # We wait for the copy to be done + self.current_stream.wait_stream(stream) + + return indices, result + + +__all__ = ["_CustomEpochIterator"] diff --git a/avalanche/benchmarks/utils/ffcv_support/ffcv_loader.py b/avalanche/benchmarks/utils/ffcv_support/ffcv_loader.py new file mode 100644 index 000000000..8f2ae9f10 --- /dev/null +++ b/avalanche/benchmarks/utils/ffcv_support/ffcv_loader.py @@ -0,0 +1,178 @@ +""" +Custom version of the FFCV loader that accepts a batch sampler. +""" +from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union +import warnings + +from ffcv.fields.base import Field + +import torch as ch + +from torch.utils.data.sampler import BatchSampler, Sampler + +from ffcv.loader.loader import ( + Loader as FFCVLoader, + OrderOption, + ORDER_TYPE, + DEFAULT_OS_CACHE, +) + +from ffcv.traversal_order.base import TraversalOrder +from ffcv.pipeline.operation import Operation +from ffcv.pipeline import Compiler + +from avalanche.benchmarks.utils.ffcv_support.ffcv_epoch_iterator import ( + _CustomEpochIterator, +) + + +class _TraversalOrderAsSampler(Sampler[int]): + def __init__(self, traversal_order: TraversalOrder): + self.traversal_order: TraversalOrder = traversal_order + self.current_epoch: int = 0 + + def __iter__(self): + yield from self.traversal_order.sample_order(self.current_epoch) + + def __len__(self): + return len(self.traversal_order.indices) + + def set_epoch(self, epoch: int): + self.current_epoch = epoch + + +class _CustomLoader(FFCVLoader): + """ + Customized FFCV loader class that can be used as a drop-in replacement + for standard (e.g. PyTorch) data loaders. + + Differently from the original FFCV loader, this version also accepts a batch sampler. + + Parameters + ---------- + fname: str + Full path to the location of the dataset (.beton file format). + batch_size : int + Batch size. + num_workers : int + Number of workers used for data loading. Consider using the actual number of cores instead of the number of threads if you only use JITed augmentations as they usually don't benefit from hyper-threading. + os_cache : bool + Leverages the operating for caching purposes. This is beneficial when there is enough memory to cache the dataset and/or when multiple processes on the same machine training using the same dataset. See https://docs.ffcv.io/performance_guide.html for more information. + order : Union[OrderOption, TraversalOrder] + Traversal order, one of: SEQEUNTIAL, RANDOM, QUASI_RANDOM, or a custom TraversalOrder + + QUASI_RANDOM is a random order that tries to be as uniform as possible while minimizing the amount of data read from the disk. Note that it is mostly useful when `os_cache=False`. Currently unavailable in distributed mode. + distributed : bool + For distributed training (multiple GPUs). Emulates the behavior of DistributedSampler from PyTorch. + seed : int + Random seed for batch ordering. + indices : Sequence[int] + Subset of dataset by filtering only some indices. + pipelines : Mapping[str, Sequence[Union[Operation, torch.nn.Module]] + Dictionary defining for each field the sequence of Decoders and transforms to apply. + Fileds with missing entries will use the default pipeline, which consists of the default decoder and `ToTensor()`, + but a field can also be disabled by explicitly by passing `None` as its pipeline. + custom_fields : Mapping[str, Field] + Dictonary informing the loader of the types associated to fields that are using a custom type. + drop_last : bool + Drop non-full batch in each iteration. + batches_ahead : int + Number of batches prepared in advance; balances latency and memory. + recompile : bool + Recompile every iteration. This is necessary if the implementation of some augmentations are expected to change during training. + batch_sampler : BatchSampler + If not None, will ignore `batch_size`, `indices`, `drop_last` and will use this sampler instead. + The batch sampler must be an iterable that outputs lists of int (the indices of examples to include in each batch). + When running in a distributed training setup, the BatchSampler should already wrap a DistributedSampler. + """ + + def __init__( + self, + fname: str, + batch_size: int, + num_workers: int = -1, + os_cache: bool = DEFAULT_OS_CACHE, + order: Union[ORDER_TYPE, TraversalOrder] = OrderOption.SEQUENTIAL, + distributed: bool = False, + seed: Optional[int] = None, # For ordering of samples + indices: Optional[Sequence[int]] = None, # For subset selection + pipelines: Mapping[str, Sequence[Union[Operation, ch.nn.Module]]] = {}, + custom_fields: Mapping[str, Type[Field]] = {}, + drop_last: bool = True, + batches_ahead: int = 3, + recompile: bool = False, # Recompile at every epoch + batch_sampler: Optional[Sampler[List[int]]] = None, + ): + # Set batch sampler to an empty list so that next_traversal_order() + # and __len__() work when running super().__init__(...) + self.batch_sampler: Sampler[List[int]] = [] + + super().__init__( + fname=fname, + batch_size=batch_size, + num_workers=num_workers, + os_cache=os_cache, + order=order, + distributed=distributed, + seed=seed, + indices=indices, + pipelines=pipelines, + custom_fields=custom_fields, + drop_last=drop_last, + batches_ahead=batches_ahead, + recompile=recompile, + ) + + self._args["batch_sampler"] = batch_sampler + + if batch_sampler is None: + batch_sampler = BatchSampler( + _TraversalOrderAsSampler(self.traversal_order), + batch_size=batch_size, + drop_last=drop_last, + ) + + self.batch_sampler = batch_sampler + + def next_traversal_order(self): + # Manage distributed sampler, which has to know the id of the current epoch + self._batch_sampler_set_epoch() + + return list(self.batch_sampler) + + def __iter__(self): + Compiler.set_num_threads(self.num_workers) + order = self.next_traversal_order() + self.next_epoch += 1 + + # Compile at the first epoch + if self.code is None or self.recompile: + self.generate_code() + + return _CustomEpochIterator(self, order) + + def filter(self, field_name: str, condition: Callable[[Any], bool]) -> "FFCVLoader": + if self._args["batch_sampler"] is not None: + warnings.warn( + "The original loader was created by passing a batch sampler. " + "The filtered loader will not inherit the sampler!" + ) + + return super().filter(field_name, condition) + + def __len__(self): + return len(self.batch_sampler) + + def _batch_sampler_set_epoch(self): + if hasattr(self.batch_sampler, "set_epoch"): + # Supports batch samplers with set_epoch method + self.batch_sampler.set_epoch(self.next_epoch) + else: + # Standard setup: the batch sampler wraps a TraversalOrder or + # a distributed sampler + if hasattr(self.batch_sampler, "sampler"): + if hasattr(self.batch_sampler.sampler, "set_epoch"): + self.batch_sampler.sampler.set_epoch(self.next_epoch) + + +__all__ = ["_CustomLoader"] diff --git a/avalanche/benchmarks/utils/ffcv_support/ffcv_support_internals.py b/avalanche/benchmarks/utils/ffcv_support/ffcv_support_internals.py new file mode 100644 index 000000000..e8ff6e098 --- /dev/null +++ b/avalanche/benchmarks/utils/ffcv_support/ffcv_support_internals.py @@ -0,0 +1,344 @@ +""" +Internal utils needed to enable the support for FFCV in Avalanche. +""" + +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Union, +) +from collections import OrderedDict +import numpy as np + +from torch import Tensor + +from PIL.Image import Image + +from ffcv.fields import TorchTensorField +from ffcv.fields.decoders import ( + IntDecoder, + FloatDecoder, + NDArrayDecoder, + SimpleRGBImageDecoder, +) + + +if TYPE_CHECKING: + from ffcv.fields import Field + from ffcv.pipeline.operation import Operation + + FFCVEncodeDef = OrderedDict[str, Field] + FFCVDecodeDef = OrderedDict[str, List[Operation]] + + FFCVParameters = Dict[str, Any] + EncoderDef = Optional[ + Union["FFCVEncodeDef", Callable[[FFCVParameters], "FFCVEncodeDef"]] + ] + DecoderDef = Optional[ + Union["FFCVDecodeDef", Callable[[FFCVParameters], "FFCVDecodeDef"]] + ] + + +def _image_encoder(ffcv_parameters: "FFCVParameters"): + """ + Create a :class:`RGBImageField` given additional + parameters passed by the user. Follows the FFCV defaults. + + :param ffcv_parameters: The additional parameters passed + to the :func:`enable_ffcv` function. + :return: A :class:`RGBImageField` instance. + """ + from ffcv.fields import RGBImageField + + return RGBImageField( + write_mode=ffcv_parameters.get("write_mode", "raw"), + max_resolution=ffcv_parameters.get("max_resolution", None), + smart_threshold=ffcv_parameters.get("smart_threshold", None), + jpeg_quality=ffcv_parameters.get("jpeg_quality", 90), + compress_probability=ffcv_parameters.get("compress_probability", 0.5), + ) + + +def _ffcv_infer_encoder(value, ffcv_parameters: "FFCVParameters") -> Optional["Field"]: + """ + Infers the field encoder definition from a given example. + + :param value: The example obtained from the dataset. + :param ffcv_parameters: The additional parameters passed + to the :func:`enable_ffcv` function. + :return: A :class:`Field` instance or None if it cannot be + inferred. + """ + from ffcv.fields import ( + IntField, + FloatField, + NDArrayField, + TorchTensorField, + ) + + if isinstance(value, int): + return IntField() + + if isinstance(value, float): + return FloatField() + + if isinstance(value, np.ndarray): + return NDArrayField(value.dtype, shape=value.shape) + + if isinstance(value, Tensor): + return TorchTensorField(value.dtype, shape=value.shape) + + if isinstance(value, Image): + return _image_encoder(ffcv_parameters) + + return None + + +def _ffcv_infer_decoder( + value, + ffcv_parameters: "FFCVParameters", + encoder: Optional["Field"] = None, + add_common_collate: bool = True, +) -> Optional[List["Operation"]]: + """ + Infers the field decoder definition from a given example. + + :param value: The example obtained from the dataset. + :param ffcv_parameters: The additional parameters passed + to the :func:`enable_ffcv` function. + :param encoder: If not None, will try to infer the decoder + definition from the field. + :param add_common_collate: If True, will apply a PyTorch-alike + collate to int and float fields so that they end up being + a flat PyTorch tensor instead of a list of int/float. + :return: The decoder pipeline as a list of :class:`Operation` + or None if the decoder pipeline cannot be inferred. + """ + from ffcv.transforms import ToTensor, Squeeze + + if encoder is not None: + if isinstance(encoder, TorchTensorField): + return [NDArrayDecoder(), ToTensor()] + + encoder_class = encoder.get_decoder_class() + pipeline: List["Operation"] = [encoder_class()] + if add_common_collate and encoder_class in [IntDecoder, FloatDecoder]: + pipeline.extend((ToTensor(), Squeeze())) + return pipeline + + if isinstance(value, int): + pipeline: List["Operation"] = [IntDecoder()] + + if add_common_collate: + pipeline.extend((ToTensor(), Squeeze())) + return pipeline + + if isinstance(value, float): + pipeline: List["Operation"] = [FloatDecoder()] + + if add_common_collate: + pipeline.extend((ToTensor(), Squeeze())) + return pipeline + + if isinstance(value, np.ndarray): + return [NDArrayDecoder()] + + if isinstance(value, Tensor): + return [NDArrayDecoder(), ToTensor()] + + if isinstance(value, Image): + return [SimpleRGBImageDecoder()] + + return None + + +def _check_dataset_ffcv_encoder(dataset) -> "EncoderDef": + """ + Returns the dataset-specific FFCV encoder definition, if available. + """ + encoder_fn_or_def = getattr(dataset, "_ffcv_encoder", None) + return encoder_fn_or_def + + +def _check_dataset_ffcv_decoder(dataset) -> "DecoderDef": + """ + Returns the dataset-specific FFCV decoder definition, if available. + """ + decoder_fn_or_def = getattr(dataset, "_ffcv_decoder", None) + return decoder_fn_or_def + + +def _encoder_infer_all( + dataset, ffcv_parameters: "FFCVParameters" +) -> Optional["FFCVEncodeDef"]: + """ + Infer the encoder pipeline from the dataset. + + :param dataset: The dataset to use. Must have at least + one example. + :param ffcv_parameters: The additional parameters passed + to the :func:`enable_ffcv` function. + :return: The encoder pipeline or None if it could not be inferred. + """ + dataset_item = dataset[0] + + types = [] + + # Try to infer the field type for each element + for item in dataset_item: + inferred_type = _ffcv_infer_encoder(item, ffcv_parameters) + + if inferred_type is None: + return None + + types.append(inferred_type) + + # Type inferred for all fields + # Let's apply a generic name and return the dictionary + result = OrderedDict() + for i, t in enumerate(types): + result[f"field_{i}"] = t + + return result + + +def _decoder_infer_all( + dataset, + ffcv_parameters: "FFCVParameters", + encoder_dictionary: Optional["FFCVEncodeDef"] = None, +) -> Optional["FFCVDecodeDef"]: + """ + Infer the decoder pipeline from the dataset. + + :param dataset: The dataset to use. Must have at least + one example. + :param ffcv_parameters: The additional parameters passed + to the :func:`enable_ffcv` function. + :param encoder_dictionary: If not None, will be used as a + basis to create the decoder pipeline. + :return: The decoder pipeline or None if it could not be inferred. + """ + dataset_item: Sequence[Any] = dataset[0] + + types: List[List["Operation"]] = [] + + encoder_hints: List[Optional["Field"]] = [] + field_names: List[str] + + if encoder_dictionary is None: + encoder_hints = [None] * len(dataset_item) + field_names = [f"field_{i}" for i in range(len(dataset_item))] + else: + if len(encoder_dictionary) != len(dataset_item): + raise ValueError("Wrong number of elements in encoder dictionary.") + + encoder_hints.extend(encoder_dictionary.values()) + field_names = list(encoder_dictionary.keys()) + + # Try to infer the field type for each element + for item, field_encoder in zip(dataset_item, encoder_hints): + inferred_type = _ffcv_infer_decoder( + item, ffcv_parameters, encoder=field_encoder + ) + + if inferred_type is None: + return None + + types.append(inferred_type) + + # Type inferred for all fields + # Let's apply the name and return the dictionary + result = OrderedDict() + for t, field_name in zip(types, field_names): + result[field_name] = t + + return result + + +def _make_ffcv_encoder( + dataset, user_encoder_def: "EncoderDef", ffcv_parameters: "FFCVParameters" +) -> Optional["FFCVEncodeDef"]: + """ + Infer the encoder pipeline from either a user definition, + the dataset `_ffcv_encoder` field, of from the examples format. + + :param dataset: The dataset to use. Must have at least + one example to attempt an inference for data format. + :param user_encoder_def: An optional user-given encoder definition. + Can be a dictionary or callable that accepts the ffcv parameters + and returns the encoder dictionary. + :param ffcv_parameters: The additional parameters passed + to the :func:`enable_ffcv` function. + :return: The encoder pipeline or None if it could not be inferred. + """ + encoder_def = None + + # Use the user-provided pipeline / pipeline factory + if user_encoder_def is not None: + encoder_def = user_encoder_def + if callable(encoder_def): + encoder_def = encoder_def(ffcv_parameters) + + # Check if the dataset has an explicit field/method + if encoder_def is None: + encoder_def = _check_dataset_ffcv_encoder(dataset) + if callable(encoder_def): + encoder_def = encoder_def(ffcv_parameters) + + # Try to infer the pipeline from the dataset + if encoder_def is None: + encoder_def = _encoder_infer_all(dataset, ffcv_parameters) + + return encoder_def + + +def _make_ffcv_decoder( + dataset, + user_decoder_def: "DecoderDef", + ffcv_parameters: "FFCVParameters", + encoder_dictionary: Optional["FFCVEncodeDef"], +) -> Optional["FFCVDecodeDef"]: + """ + Infer the decoder pipeline from either a user definition, + the dataset `_ffcv_decoder` field, of from the examples format. + + :param dataset: The dataset to use. Must have at least + one example to attempt an inference for data format. + :param user_decoder_def: An optional user-given decoder definition. + Can be a dictionary or callable that accepts the ffcv parameters + and returns the decoder dictionary. + :param ffcv_parameters: The additional parameters passed + to the :func:`enable_ffcv` function. + :param encoder_dictionary: If not None, will be used to infer + the decoders. + :return: The decoder pipeline or None if it could not be inferred. + """ + decode_def = None + + # Use the user-provided pipeline / pipeline factory + if user_decoder_def is not None: + decode_def = user_decoder_def + if callable(decode_def): + decode_def = decode_def(ffcv_parameters) + + # Check if the dataset has an explicit field/method + if decode_def is None: + decode_def = _check_dataset_ffcv_decoder(dataset) + if callable(decode_def): + decode_def = decode_def(ffcv_parameters) + + # Try to infer the pipeline from the dataset + if decode_def is None: + decode_def = _decoder_infer_all( + dataset, ffcv_parameters, encoder_dictionary=encoder_dictionary + ) + + return decode_def + + +__all__ = ["_make_ffcv_encoder", "_make_ffcv_decoder"] diff --git a/avalanche/benchmarks/utils/ffcv_support/ffcv_transform_utils.py b/avalanche/benchmarks/utils/ffcv_support/ffcv_transform_utils.py new file mode 100644 index 000000000..1c84949f3 --- /dev/null +++ b/avalanche/benchmarks/utils/ffcv_support/ffcv_transform_utils.py @@ -0,0 +1,595 @@ +""" +Utilities used to translate torchvision transformations to FFCV. +""" + +from typing import ( + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + Tuple, + Type, + Union, +) +from typing_extensions import Literal +import warnings +import numpy as np + +import torch + +from avalanche.benchmarks.utils.transforms import flat_transforms_recursive + +from torchvision.transforms import ToTensor as ToTensorTV +from torchvision.transforms import PILToTensor as PILToTensorTV +from torchvision.transforms import Normalize as NormalizeTV +from torchvision.transforms import ConvertImageDtype as ConvertTV +from torchvision.transforms import RandomResizedCrop as RandomResizedCropTV +from torchvision.transforms import RandomHorizontalFlip as RandomHorizontalFlipTV +from torchvision.transforms import RandomCrop as RandomCropTV +from torchvision.transforms import Lambda + +from ffcv.transforms import ToTensor as ToTensorFFCV +from ffcv.transforms import ToDevice as ToDeviceFFCV +from ffcv.transforms import ToTorchImage as ToTorchImageFFCV +from ffcv.transforms import NormalizeImage as NormalizeFFCV +from ffcv.transforms import Convert as ConvertFFCV +from ffcv.transforms import View as ViewFFCV +from ffcv.transforms import Squeeze as SqueezeFFCV +from ffcv.transforms import RandomResizedCrop as RandomResizedCropFFCV +from ffcv.transforms import RandomHorizontalFlip as RandomHorizontalFlipFFCV +from ffcv.transforms import RandomTranslate as RandomTranslateFFCV +from ffcv.transforms import Cutout as CutoutFFCV +from ffcv.transforms import ImageMixup as ImageMixupFFCV +from ffcv.transforms import LabelMixup as LabelMixupFFCV +from ffcv.transforms import MixupToOneHot as MixupToOneHotFFCV +from ffcv.transforms import Poison as PoisonFFCV +from ffcv.transforms import ReplaceLabel as ReplaceLabelFFCV +from ffcv.transforms import RandomBrightness as RandomBrightnessFFCV +from ffcv.transforms import RandomContrast as RandomContrastFFCV +from ffcv.transforms import RandomSaturation as RandomSaturationFFCV +from ffcv.transforms import ModuleWrapper +from ffcv.pipeline.operation import Operation +from ffcv.pipeline.state import State +from ffcv.pipeline.allocation_query import AllocationQuery + +from ffcv.fields.decoders import SimpleRGBImageDecoder, RandomResizedCropRGBImageDecoder +from dataclasses import replace + + +class CallableAdapter: + def __init__(self, callable_obj): + self.callable_obj = callable_obj + + def __repr__(self) -> str: + return f"CallableAdapter({self.callable_obj})" + + def __call__(self, batch): + result = [] + for element in batch: + result.append(self.callable_obj(element)) + + if isinstance(batch, np.ndarray): + return np.array(result) + elif isinstance(batch, torch.Tensor): + return torch.asarray(result) + else: + return result + + +class ScaleFrom_0_255_To_0_1(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + default_float_dtype = torch.get_default_dtype() + + return input.to(dtype=default_float_dtype).div(255) + + +class FFCVTransformRegistry(NamedTuple): + numpy_cpu: bool + pytorch_cpu: bool + pytorch_gpu: bool + + +FFCV_TRANSFORMS_DEFS: Dict[Type, FFCVTransformRegistry] = {} + + +def make_transform_defs(): + """ + Fills a series of definition obtained by the FFCV documentation, + source code, and manual attempts. + + These definitions are used to properly arrange the order of transformations. + In FFCV, not all transformations support NumPy, PyTorch CPU, and PyTorch GPU + inputs. The supported inputs are defined in dictionaries stored in + `FFCV_TRANSFORMS_DEFS`. + """ + global FFCV_TRANSFORMS_DEFS + + FFCV_TRANSFORMS_DEFS[ToDeviceFFCV] = FFCVTransformRegistry( + numpy_cpu=False, + pytorch_cpu=True, + pytorch_gpu=True, # GPU -> CPU, probably unused + ) + + FFCV_TRANSFORMS_DEFS[ToTorchImageFFCV] = FFCVTransformRegistry( + numpy_cpu=False, pytorch_cpu=True, pytorch_gpu=False + ) + + FFCV_TRANSFORMS_DEFS[NormalizeFFCV] = FFCVTransformRegistry( + numpy_cpu=True, pytorch_cpu=False, pytorch_gpu=True + ) + + FFCV_TRANSFORMS_DEFS[ConvertFFCV] = FFCVTransformRegistry( + numpy_cpu=False, pytorch_cpu=True, pytorch_gpu=True + ) + + # Note: for some reason, view == convert in FFCV + # View should not used to change the shape of the tensor (it does not work) + FFCV_TRANSFORMS_DEFS[ViewFFCV] = FFCV_TRANSFORMS_DEFS[ConvertFFCV] + + FFCV_TRANSFORMS_DEFS[SqueezeFFCV] = FFCVTransformRegistry( + numpy_cpu=False, pytorch_cpu=True, pytorch_gpu=True + ) + + FFCV_TRANSFORMS_DEFS[MixupToOneHotFFCV] = FFCVTransformRegistry( + numpy_cpu=False, pytorch_cpu=True, pytorch_gpu=True + ) + + FFCV_TRANSFORMS_DEFS[ModuleWrapper] = FFCVTransformRegistry( + numpy_cpu=False, pytorch_cpu=True, pytorch_gpu=True + ) + + FFCV_TRANSFORMS_DEFS[SmartModuleWrapper] = FFCVTransformRegistry( + numpy_cpu=True, pytorch_cpu=True, pytorch_gpu=True + ) + + numpy_only_types = [ + ToTensorFFCV, + RandomResizedCropFFCV, + RandomHorizontalFlipFFCV, + RandomTranslateFFCV, + CutoutFFCV, + ImageMixupFFCV, + LabelMixupFFCV, + PoisonFFCV, + ReplaceLabelFFCV, + RandomBrightnessFFCV, + RandomContrastFFCV, + RandomSaturationFFCV, + ] + + for t_type in numpy_only_types: + FFCV_TRANSFORMS_DEFS[t_type] = FFCVTransformRegistry( + numpy_cpu=True, pytorch_cpu=False, pytorch_gpu=False + ) + + +def adapt_transforms( + transforms_list, ffcv_decoder_list, device: Optional[torch.device] = None +): + """ + Adapt the list of torchvision transformations to FFCV. + + This will use hard-coded transformations that will usually + make sense for the vast majority of situations. However, + in some cases it makes sense to pass an explicit `decoder_def` + to the :func:`enable_ffcv` function. + + :param transforms_list: The list of transformations. May include + multi-param transformations. Avalanche will usually obtain this + list from the AvalancheDatasets. + :param ffcv_decoder_list: The list of FFCV decoders. + :param device: If passed, the FFCV `ToDevice` operation will be added. + :return: The transformations adapted for FFCV. + """ + result = [] + for field_idx, pipeline_head in enumerate(ffcv_decoder_list): + transforms = flat_transforms_recursive(transforms_list, field_idx) + transforms = pipeline_head + transforms + transforms = _apply_transforms_pre_optimization(transforms, device=device) + + field_transforms: List[Operation] = [] + for t in transforms: + if isinstance(t, Operation): + # Already an FFCV transform + field_transforms.append(t) + elif isinstance(t, PILToTensorTV): + field_transforms.append(ToTensorFFCV()) + field_transforms.append(ToTorchImageFFCV()) + elif isinstance(t, ToTensorTV): + field_transforms.append(ToTensorFFCV()) + field_transforms.append(ToTorchImageFFCV()) + field_transforms.append(ModuleWrapper(ScaleFrom_0_255_To_0_1())) + elif isinstance(t, ConvertTV): + field_transforms.append(ConvertFFCV(t.dtype)) + elif isinstance(t, RandomResizedCropTV): + field_transforms.append(RandomResizedCropFFCV(t.scale, t.ratio, t.size)) + elif isinstance(t, RandomHorizontalFlipTV): + field_transforms.append(RandomHorizontalFlipFFCV(t.p)) + elif isinstance(t, RandomCropTV): + field_transforms.append( + SmartModuleWrapper( + t, expected_out_type="as_previous", expected_shape=t.size + ) + ) + elif isinstance(t, torch.nn.Module): + field_transforms.append(SmartModuleWrapper(t)) + else: + # Last hope... + field_transforms.append(SmartModuleWrapper(CallableAdapter(t))) + field_transforms = _add_to_device_operation(field_transforms, device=device) + result.append(field_transforms) + return result + + +def _apply_transforms_pre_optimization( + transformations: List[Any], device: Optional[torch.device] = None +): + """ + Applies common pre-optimizations to the list of transformations. + + :param transformations: The list of transformations. + :param device: If passed, the FFCV `ToDevice` operation will be added. + :return: The transformations optimized for FFCV. + """ + if len(transformations) < 2: + # No optimizations to apply if there are less than 2 transformations + return transformations + + result = [transformations[0]] + + for t in transformations[1:]: + if ( + isinstance(t, NormalizeTV) + and isinstance(result[-1], ToTensorTV) + and device is not None + and device.type == "cuda" + ): + # Optimize ToTensor+Normalize combo + + # ToTensor from torchvision does the following: + # 1. PIL/NDArray -> Tensor + # 2. Shape (H x W x C) -> (C x H x W) + # 3. [0, 255] -> [0.0, 1.0] + # In FFCV, the fist two steps are implemented as separate + # transformations. The range change is not available in a + # standalone way, but it is applied when normalizing. + + # Note: we apply this optimization only when running on CUDA + # as the FFCV Normalize is currently bugged and + # does not work on CPU with PyTorch Tensor inputs. + # It *may* work with CPU+NDArray... + + result[-1] = ToTensorFFCV() + result.append(ToTorchImageFFCV()) + + dtype = torch.zeros(0, dtype=torch.get_default_dtype()).numpy().dtype + + mean = np.array(t.mean) * 255 + std = np.array(t.std) * 255 + result.append(NormalizeFFCV(mean, std, dtype)) + + elif isinstance(t, RandomResizedCropTV) and isinstance( + result[-1], SimpleRGBImageDecoder + ): + size = t.size + if isinstance(size, int): + size = [size, size] + elif len(size) == 1: + size = [size[0], size[0]] + result[-1] = RandomResizedCropRGBImageDecoder(size, t.scale, t.ratio) + else: + result.append(t) + + return result + + +def _add_to_device_operation(transformations, device: Optional[torch.device] = None): + """ + Given a list of FFCV transformations, insert the `ToDevice` operation in the most + appropriate place. + + The corrent position of the `ToDevice` operation in FFCV is very hard to infer. + Avalanche uses the `FFCV_TRANSFORMS_DEFS` dictionary to infer the correct position + based on the kind of input and outputs supported by the transformations. + + :param transformations: The list of transformations (modified in place). + :param device: If passed, the FFCV `ToDevice` operation will be added. + :return: The transformations with `ToDevice`. + """ + if device is None: + return transformations + + # Check if ToDevice is laready in the pipeline + for t in transformations: + if isinstance(t, ToDeviceFFCV): + # Already set + return transformations + + # All decoders (first operation in the pipeline) return NumPy arrays + is_numpy = True + is_cpu = True + + transformations = list(transformations) + inserted = False + for i, t in enumerate(transformations): + t_def = FFCV_TRANSFORMS_DEFS.get(type(t), None) + if t_def is None: + # Unknown operation + continue + + if is_numpy and not t_def.numpy_cpu: + # Unmanageable situation: the current input is a NumPy array + # but the transformation only supports PyTorch Tensor. + + # A warning is already raised by check_transforms_consistency, + # so it's not a big issue... + # Anyway, the pipeline is probably doomed to fail + break + elif not is_numpy: + if not (t_def.pytorch_cpu or t_def.pytorch_gpu): + # Unmanageable situation: the current input is a PyTorch Tensor + # but the transformation only supports NumPy arrays. + + # A warning is already raised by check_transforms_consistency + break + + if is_cpu and t_def.pytorch_gpu: + transformations.insert(i, ToDeviceFFCV(device=device)) + inserted = True + break + + elif (not is_cpu) and t_def.pytorch_cpu: + # From GPU to CPU is currently unsupported + # Maybe in the future we can try to manage this... + break + + if isinstance(t, ToTensorFFCV): + is_numpy = False + elif isinstance(t, ToDeviceFFCV): + is_cpu = t.device.type == "cpu" + + if not inserted: + transformations.append(ToDeviceFFCV(device)) + + return transformations + + +def check_transforms_consistency(transformations, warn_gpu_to_cpu: bool = True): + """ + Checks if the list of transformations has issues with input/output formats + and devices consistency: + + :param transformations: The list of transformations. + :param warn_gpu_to_cpu: Warn if ToDevice is used to move tensors from the + gpu to the cpu. + :return: True if the list of transformations has no obvious consistency issues. + """ + had_issues = False + + # All decoders (first operation in the pipeline) return NumPy arrays + is_numpy = True + is_cpu = True + + for t in transformations: + t_def = FFCV_TRANSFORMS_DEFS.get(type(t), None) + if t_def is None: + # Unknown operation + continue + + bad_usage_type = None + + if is_numpy and not t_def.numpy_cpu: + bad_usage_type = "NumPy arrays" + elif not is_numpy: + if is_cpu and not t_def.pytorch_cpu: + bad_usage_type = "CPU PyTorch Tensors" + elif (not is_cpu) and not t_def.pytorch_gpu: + bad_usage_type = "GPU PyTorch Tensors" + + if bad_usage_type is not None: + warnings.warn( + f"Transformation {type(t)} cannot be used on " + f"{bad_usage_type}.\n" + f"Its registered definition is: {t_def}.\n" + f"This may lead to issues with Numba..." + ) + had_issues = True + + if isinstance(t, ToTensorFFCV): + is_numpy = False + elif isinstance(t, ToDeviceFFCV): + if (not is_cpu) and t.device.type == "cpu": + if warn_gpu_to_cpu: + warnings.warn( + f"Moving a Tensor from GPU to CPU is quite unusual..." + ) + had_issues = True + + is_cpu = t.device.type == "cpu" + + return not had_issues + + +class SmartModuleWrapper(Operation): + """ + Transform using the given torch.nn.Module. + + This covers those transformations implemented as a torch module which + are not already translated from torchvision. + + This is a smarter version of the FFCV wrapper as it allows + having NumPy inputs and setting explicit shapes for input and outputs. + + Parameters + ---------- + module: torch.nn.Module + The module for transformation + """ + + def __init__( + self, + module: torch.nn.Module, + expected_out_type: Union[ + np.dtype, torch.dtype, Literal["as_previous"] + ] = "as_previous", + expected_shape: Union[Tuple[int, ...], Literal["as_previous"]] = "as_previous", + smart_reshape: bool = True, + ): + """ + Creates an instance of a SmartModuleWrapper. + + :param module: The module to use. The module must be able to process + the inputs in batches. + :param expected_out_type: The expected type of the output. Default to `as_previous`. + :param expected_shape: The expected shape of the output. Default to `as_previous`. + :param smart_reshape: If True, will try to compute the proper shape conversion + when the input is NumPy and the shape suggests that an image is being passed. + """ + super().__init__() + self.module = module + self.expected_out_type = expected_out_type + self.expected_shape = expected_shape + self.input_type = "numpy" + self.output_type = "numpy" + self.smart_reshape = smart_reshape + + def __repr__(self) -> str: + return f"SmartModuleWrapper({self.module})" + + def generate_code(self) -> Callable: + """ + Obtain the correct function for the given input and output + definitions. + + :return: The callable to be used as the transformation. + """ + + def convert_apply_convert_reshape(inp, _): + inp_as_tensor = torch.from_numpy(inp) + # N, H, W, C -> N, C, H, W + inp_as_tensor = inp_as_tensor.permute([0, 3, 1, 2]) + res = self.module(inp_as_tensor) + + # N, C, H, W -> N, H, W, C + res_as_np: np.ndarray = res.numpy() + return res_as_np.transpose((0, 2, 3, 1)) + + def convert_apply_reshape(inp, _): + inp_as_tensor = torch.from_numpy(inp) + # N, H, W, C -> N, C, H, W + inp_as_tensor = inp_as_tensor.permute([0, 3, 1, 2]) + + res = self.module(inp_as_tensor) + return res + + def apply_convert_reshape(inp, _): + res = self.module(inp) + + # N, C, H, W -> N, H, W, C + res_as_np: np.ndarray = res.numpy() + return res_as_np.transpose((0, 2, 3, 1)) + + def convert_apply_convert(inp, _): + inp_as_tensor = torch.from_numpy(inp) + res = self.module(inp_as_tensor) + return res.numpy() + + def convert_apply(inp, _): + inp_as_tensor = torch.from_numpy(inp) + res = self.module(inp_as_tensor) + return res + + def apply_convert(inp, _): + res = self.module(inp) + return res.numpy() + + def apply(inp, _): + device = inp.device + return self.module(inp).to(device, non_blocking=True) + + # (input_type, output_type, smart_reshape) -> func + func_table = { + ("numpy", "numpy", True): convert_apply_convert_reshape, + ("numpy", "torch", True): convert_apply_reshape, + ("torch", "numpy", True): apply_convert_reshape, + ("numpy", "numpy", False): convert_apply_convert, + ("numpy", "torch", False): convert_apply, + ("torch", "numpy", False): apply_convert, + ("torch", "torch", True): apply, + ("torch", "torch", False): apply, + } + + return func_table[(self.input_type, self.output_type, self.smart_reshape)] + + def declare_state_and_memory( + self, previous_state: State + ) -> Tuple[State, Optional[AllocationQuery]]: + if len(previous_state.shape) != 3: + self.smart_reshape = False + + self._fill_types(previous_state) + self._to_device(previous_state) + self._compute_smart_shape(previous_state) + + state_changes = dict() + if self.expected_out_type != "as_previous": + # Output type != input type + state_changes["dtype"] = self.expected_out_type + + state_changes["shape"] = self.expected_shape + + return replace(previous_state, jit_mode=False, **state_changes), None + + def _fill_types(self, previous_state: State): + if isinstance(previous_state.dtype, torch.dtype): + self.input_type = "torch" + else: + self.input_type = "numpy" + + if self.expected_out_type == "as_previous": + self.output_type = self.input_type + else: + if isinstance(self.expected_out_type, torch.dtype): + self.output_type = "torch" + else: + self.output_type = "numpy" + + def _to_device(self, previous_state: State): + if previous_state.device.type != "cpu": + if hasattr(self.module, "to"): + self.module = self.module.to(previous_state.device) + + def _compute_smart_shape(self, previous_state: State): + if self.smart_reshape: + if self.input_type == "numpy": + h, w, c = previous_state.shape + else: + c, h, w = previous_state.shape + + patch_shape = True + if self.expected_shape != "as_previous": + if ( + isinstance(self.expected_shape, int) + or len(self.expected_shape) == 1 + ): + h = self.expected_shape + w = self.expected_shape + elif len(self.expected_shape) == 2: + h, w = self.expected_shape + else: + # Completely user-managed + patch_shape = False + + if patch_shape: + if self.output_type == "numpy": + self.expected_shape = (h, w, c) + else: + self.expected_shape = (c, h, w) + + +make_transform_defs() diff --git a/avalanche/benchmarks/utils/flat_data.py b/avalanche/benchmarks/utils/flat_data.py index 63dae165d..f4c17abca 100644 --- a/avalanche/benchmarks/utils/flat_data.py +++ b/avalanche/benchmarks/utils/flat_data.py @@ -575,6 +575,8 @@ def _flatdata_repr(dataset, indent=0): """Return the string representation of the dataset. Shows the underlying dataset tree. """ + from avalanche.benchmarks.utils.data import _FlatDataWithTransform + if isinstance(dataset, FlatData): ss = dataset._indices is not None cc = len(dataset._datasets) != 1 @@ -584,6 +586,11 @@ def _flatdata_repr(dataset, indent=0): + f"{dataset.__class__.__name__} (len={len(dataset)},subset={ss}," f"cat={cc},cf={cf})\n" ) + if isinstance(dataset, _FlatDataWithTransform): + s = s[:-2] + ( + f",transform_groups={dataset._transform_groups}," + f"frozen_transform_groups={dataset._frozen_transform_groups})\n" + ) for dd in dataset._datasets: s += _flatdata_repr(dd, indent + 1) return s diff --git a/avalanche/benchmarks/utils/transforms.py b/avalanche/benchmarks/utils/transforms.py index 99b94fba6..e41536182 100644 --- a/avalanche/benchmarks/utils/transforms.py +++ b/avalanche/benchmarks/utils/transforms.py @@ -13,19 +13,49 @@ This module contains a bunch of utility classes to help define multi-argument transformations. """ +from abc import ABC, abstractmethod import warnings -from typing import Callable, Sequence +from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union from inspect import signature, Parameter +from torchvision.transforms import Compose -class MultiParamTransform: +class MultiParamTransform(ABC): """We need this class to be able to distinguish between a single argument transformation and multi-argument ones. Transformations are callable objects. """ + @abstractmethod def __call__(self, *args, **kwargs): + """ + Applies this transformations to the given inputs. + """ + pass + + @abstractmethod + def flat_transforms(self, position: int) -> List[Any]: + """ + Returns a flat list of transformations. + + A flat list of transformations is a list in which + all intermediate wrappers (such as torchvision Compose, + Avalanche MultiParamCompose, ...) are removed. + + The position parameter is used to control which transformations + are to be returned based on the position of the tranformed element. + Position 0 means transformations on the "x" value, + 1 means "target" (or y) transformations, and so on. + + Please note that transformations acting on multiple parameters + may be returned when appropriate. This is common for object + detection augmentations that transform x (image) and y (bounding boxes) + inputs at the same time. + + :position: The position of the tranformed element. + :return: A list of transformations for the given position. + """ pass @@ -49,8 +79,8 @@ class MultiParamCompose(MultiParamTransform): def __init__(self, transforms: Sequence[Callable]): # skip empty transforms transforms = list(filter(lambda x: x is not None, transforms)) - self.transforms = transforms - self.param_def = [] + self.transforms = list(transforms) + self.param_def: List[Tuple[int, int]] = [] self.max_params = -1 self.min_params = -1 @@ -63,7 +93,7 @@ def __init__(self, transforms: Sequence[Callable]): all_maxes = set([max_p for _, max_p in self.param_def]) if len(all_maxes) > 1: warnings.warn( - "Transformations define a different amount of parameters. " + "Transformations define a different number of parameters. " "This may lead to errors. This warning will only appear" "once.", ComposeMaxParamsWarning, @@ -75,6 +105,20 @@ def __init__(self, transforms: Sequence[Callable]): self.max_params = max(all_maxes) self.min_params = min([min_p for min_p, _ in self.param_def]) + def __eq__(self, other): + if self is other: + return True + + if not isinstance(other, MultiParamCompose): + return False + + return ( + self.transforms == other.transforms + and self.param_def == other.param_def + and self.min_params == other.min_params + and self.max_params == other.max_params + ) + def __call__(self, *args, force_tuple_output=False): if len(self.transforms) > 0: for transform, (min_par, max_par) in zip(self.transforms, self.param_def): @@ -97,6 +141,17 @@ def __repr__(self): def __str__(self): return self.__repr__() + def flat_transforms(self, position: int): + all_transforms = [] + + for transform, par_def in zip(self.transforms, self.param_def): + max_params = par_def[1] + + if position < max_params or max_params == -1: + all_transforms.append(transform) + + return flat_transforms_recursive(all_transforms, position) + class MultiParamTransformCallable(MultiParamTransform): """Generic multi-argument transformation.""" @@ -145,7 +200,7 @@ def _call_transform(transform_callable, _, max_par, *params): return params_list @staticmethod - def _detect_parameters(transform_callable): + def _detect_parameters(transform_callable) -> Tuple[int, int]: min_params = 0 max_params = 0 @@ -192,12 +247,30 @@ def _is_torchvision_transform(transform_callable): tc_module = tc_class.__module__ return "torchvision.transforms" in tc_module + def flat_transforms(self, position: int): + if position < self.max_params or self.max_params == -1: + return flat_transforms_recursive(self.transform, position) + return [] + + def __eq__(self, other): + if self is other: + return True + + if not isinstance(other, MultiParamTransformCallable): + return False + + return ( + self.transform == other.transform + and self.min_params == other.min_params + and self.max_params == other.max_params + ) + class TupleTransform(MultiParamTransform): """Multi-argument transformation represented as tuples.""" def __init__(self, transforms: Sequence[Callable]): - self.transforms = transforms + self.transforms = list(transforms) def __call__(self, *args): args_list = list(args) @@ -209,6 +282,60 @@ def __call__(self, *args): def __str__(self): return "TupleTransform({})".format(self.transforms) + def __repr__(self): + return "TupleTransform({})".format(self.transforms) + + def __eq__(self, other): + if self is other: + return True + + if not isinstance(other, TupleTransform): + return False + + return self.transforms == other.transforms + + def flat_transforms(self, position: int): + if position < len(self.transforms): + return flat_transforms_recursive(self.transforms[position], position) + return [] + + +def flat_transforms_recursive(transforms: Union[List, Any], position: int) -> List[Any]: + """ + Flattens a list of transformations. + + :param transforms: The list of transformations to flatten. + :param position: The position of the transformed element. + :return: A flat list of transformations. + """ + if not isinstance(transforms, Iterable): + transforms = [transforms] + + must_flat = True + while must_flat: + must_flat = False + flattened_list = [] + + for transform in transforms: + flat_strat = getattr(transform, "flat_transforms", None) + if callable(flat_strat): + flattened_list.extend(flat_strat(position)) + must_flat = True + elif isinstance(transform, Compose): + flattened_list.extend(transform.transforms) + must_flat = True + elif isinstance(transform, Sequence): + flattened_list.extend(transform) + must_flat = True + elif transform is None: + pass + else: + flattened_list.append(transform) + + transforms = flattened_list + + return transforms + class ComposeMaxParamsWarning(Warning): def __init__(self, message): @@ -224,4 +351,5 @@ def __init__(self, message): "MultiParamTransformCallable", "ComposeMaxParamsWarning", "TupleTransform", + "flat_transforms_recursive", ] diff --git a/avalanche/training/plugins/replay.py b/avalanche/training/plugins/replay.py index 832a26df6..96ed23c47 100644 --- a/avalanche/training/plugins/replay.py +++ b/avalanche/training/plugins/replay.py @@ -1,6 +1,8 @@ from typing import Optional, TYPE_CHECKING -from avalanche.benchmarks.utils import concat_classification_datasets +from packaging.version import parse +import torch + from avalanche.benchmarks.utils.data_loader import ReplayDataLoader from avalanche.training.plugins.strategy_plugin import SupervisedPlugin from avalanche.training.storage_policy import ( @@ -97,6 +99,18 @@ def before_training_exp( batch_size_mem = strategy.train_mb_size assert strategy.adapted_dataset is not None + + other_dataloader_args = dict() + + if "ffcv_args" in kwargs: + other_dataloader_args["ffcv_args"] = kwargs["ffcv_args"] + + if "persistent_workers" in kwargs: + if parse(torch.__version__) >= parse("1.7.0"): + other_dataloader_args["persistent_workers"] = kwargs[ + "persistent_workers" + ] + strategy.dataloader = ReplayDataLoader( strategy.adapted_dataset, self.storage_policy.buffer, @@ -107,6 +121,7 @@ def before_training_exp( num_workers=num_workers, shuffle=shuffle, drop_last=drop_last, + **other_dataloader_args ) def after_training_exp(self, strategy: "SupervisedTemplate", **kwargs): diff --git a/avalanche/training/supervised/ar1.py b/avalanche/training/supervised/ar1.py index 85a0955d2..3a2f52075 100644 --- a/avalanche/training/supervised/ar1.py +++ b/avalanche/training/supervised/ar1.py @@ -229,7 +229,9 @@ def _before_training_exp(self, **kwargs): ] self.cwr_plugin.reset_weights(self.cwr_plugin.cur_class) - def make_train_dataloader(self, num_workers=0, shuffle=True, **kwargs): + def make_train_dataloader( + self, num_workers=0, shuffle=True, persistent_workers=False, **kwargs + ): """ Called after the dataset instantiation. Initialize the data loader. @@ -273,6 +275,7 @@ def make_train_dataloader(self, num_workers=0, shuffle=True, **kwargs): batch_size=current_batch_mb_size, num_workers=num_workers, shuffle=shuffle, + persistent_workers=persistent_workers, **kwargs ) diff --git a/avalanche/training/templates/base.py b/avalanche/training/templates/base.py index 33b1a67e5..7cd34c6a6 100644 --- a/avalanche/training/templates/base.py +++ b/avalanche/training/templates/base.py @@ -1,7 +1,7 @@ import sys import warnings from collections import defaultdict -from typing import Generic, Iterable, Sequence, Optional, TypeVar, Union, List +from typing import Iterable, Sequence, Optional, TypeVar, Union, List import torch from torch.nn import Module diff --git a/avalanche/training/templates/base_sgd.py b/avalanche/training/templates/base_sgd.py index 8e40e2b5c..7c949fbf0 100644 --- a/avalanche/training/templates/base_sgd.py +++ b/avalanche/training/templates/base_sgd.py @@ -4,21 +4,22 @@ import torch from torch.nn import Module, CrossEntropyLoss from torch.optim import Optimizer -from torch.utils.data import DataLoader from torch import Tensor from avalanche.benchmarks import CLExperience, CLStream from avalanche.benchmarks.scenarios.generic_scenario import DatasetExperience from avalanche.benchmarks.utils.data import AvalancheDataset from avalanche.core import BasePlugin, BaseSGDPlugin -from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin +from avalanche.training.plugins import EvaluationPlugin from avalanche.training.plugins.clock import Clock from avalanche.training.plugins.evaluation import default_evaluator from avalanche.training.templates.base import BaseTemplate from avalanche.benchmarks.utils.data_loader import ( + SingleDatasetDataLoader, TaskBalancedDataLoader, collate_from_data_or_kwargs, ) + from avalanche.training.templates.strategy_mixin_protocol import SGDStrategyProtocol from avalanche.training.utils import trigger_plugins @@ -358,7 +359,10 @@ def _obtain_common_dataloader_parameters(self, **kwargs): other_dataloader_args = {} if "persistent_workers" in kwargs: - if parse(torch.__version__) >= parse("1.7.0"): + if ( + parse(torch.__version__) >= parse("1.7.0") + and kwargs.get("num_workers", 0) > 0 + ): other_dataloader_args["persistent_workers"] = kwargs[ "persistent_workers" ] @@ -395,8 +399,6 @@ def make_train_dataloader( assert self.adapted_dataset is not None - torch.utils.data.DataLoader - other_dataloader_args = self._obtain_common_dataloader_parameters( batch_size=self.train_mb_size, num_workers=num_workers, @@ -406,6 +408,9 @@ def make_train_dataloader( drop_last=drop_last, ) + if "ffcv_args" in kwargs: + other_dataloader_args["ffcv_args"] = kwargs["ffcv_args"] + self.dataloader = TaskBalancedDataLoader( self.adapted_dataset, oversample_small_groups=True, **other_dataloader_args ) @@ -441,7 +446,12 @@ def make_eval_dataloader( collate_from_data_or_kwargs(self.adapted_dataset, other_dataloader_args) - self.dataloader = DataLoader(self.adapted_dataset, **other_dataloader_args) + if "ffcv_args" in kwargs: + other_dataloader_args["ffcv_args"] = kwargs["ffcv_args"] + + self.dataloader = SingleDatasetDataLoader( + self.adapted_dataset, **other_dataloader_args + ) def eval_dataset_adaptation(self, **kwargs): """Initialize `self.adapted_dataset`.""" diff --git a/avalanche/training/templates/problem_type/supervised_problem.py b/avalanche/training/templates/problem_type/supervised_problem.py index 1b12fd539..fff259ee4 100644 --- a/avalanche/training/templates/problem_type/supervised_problem.py +++ b/avalanche/training/templates/problem_type/supervised_problem.py @@ -46,8 +46,10 @@ def _unpack_minibatch(self): if isinstance(mbatch, tuple): mbatch = list(mbatch) + self.mbatch = mbatch + for i in range(len(mbatch)): - self.mbatch[i] = mbatch[i].to(self.device) # type: ignore + mbatch[i] = mbatch[i].to(self.device) # type: ignore __all__ = ["SupervisedProblem"] diff --git a/examples/ffcv/README.md b/examples/ffcv/README.md new file mode 100644 index 000000000..f2f3ad641 --- /dev/null +++ b/examples/ffcv/README.md @@ -0,0 +1,13 @@ +# Avalanche-FFCV examples + +This folder contains some examples that can be used to get started with the [FFCV](https://ffcv.io/) data loading mechanism in Avalanche. + +Avalanche currently supports the FFCV data loading mechanism for virtually all benchmark types. However, automatic support is given only for **classification** and **regression** tasks due to the complex encoder/decoder definitions in FFCV. + +## Examples list + +- `ffcv_enable.py`: the main example, shows how to enable FFCV in Avalanche. +- `ffcv_enable_rgb_compress.py`: shows how to use the jpg/mixed image encoding. +- `ffcv_io_manual_test.py`: a template you can use to manually setup the decoder pipeline. +- `ffcv_try_speed.py`: a benchmarking script to compare FFCV to PyTorch. + diff --git a/examples/ffcv/ffcv_enable.py b/examples/ffcv/ffcv_enable.py new file mode 100644 index 000000000..385d6af99 --- /dev/null +++ b/examples/ffcv/ffcv_enable.py @@ -0,0 +1,147 @@ +""" +This example shows how to use FFCV data loading system in Avalanche. +""" + +import argparse +from datetime import datetime +import time + +import torch +import torch.optim.lr_scheduler +from torch.optim import Adam +from avalanche.benchmarks import SplitMNIST +from avalanche.benchmarks.classic.ccifar100 import SplitCIFAR100 +from avalanche.benchmarks.classic.ctiny_imagenet import SplitTinyImageNet +from avalanche.benchmarks.utils.ffcv_support import enable_ffcv +from avalanche.models import SimpleMLP +from avalanche.training.determinism.rng_manager import RNGManager +from avalanche.training.supervised import Naive +from avalanche.training.plugins import ReplayPlugin +from avalanche.evaluation.metrics import accuracy_metrics +from avalanche.logging import TensorboardLogger, InteractiveLogger +from avalanche.training.plugins import EvaluationPlugin + + +def main(cuda: int): + # --- CONFIG + device = torch.device(f"cuda:{cuda}" if torch.cuda.is_available() else "cpu") + RNGManager.set_random_seeds(1234) + + benchmark_type = "cifar100" + + # --- BENCHMARK CREATION + num_workers = 8 + if benchmark_type == "mnist": + input_size = 28 * 28 + num_workers = 4 + benchmark = SplitMNIST( + n_experiences=5, seed=42, class_ids_from_zero_from_first_exp=True + ) + elif benchmark_type == "cifar100": + benchmark = SplitCIFAR100(5, seed=1234, shuffle=True) + input_size = 32 * 32 * 3 + elif benchmark_type == "tinyimagenet": + benchmark = SplitTinyImageNet() + input_size = 64 * 64 * 3 + else: + raise RuntimeError("Unknown benchmark") + + # Enabling FFCV in Avalanche is as simple as calling `enable_ffcv`. + # This function will: + # - Prepare an encoder pipeline + # - Prepare a decoder pipeline (transformations) + # - Write the datasets (usually train and test) on disk + # - Enable FFCV in strategies + # + # Note that Avalanche will make some assumptions regarding the + # decoder (loader+transformations) part. If the decoder does not + # work as intended (bad outputs, exceptions, crashes), then + # it is better to use the `ffcv_io_manual_test.py` example to + # prepare a manual pipeline. + # + # Ad-hoc pipelines can be passed as the `encoder_def` + # and `decoder_def` parameters. + print("Enabling FFCV support...") + print("The may include writing the datasets in FFCV format. May take some time...") + enable_ffcv( + benchmark=benchmark, + write_dir=f"./ffcv_test_{benchmark_type}", + device=device, + ffcv_parameters=dict(num_workers=8), + ) + print("FFCV enabled!") + + # -------------------- THAT'S IT!! ------------------------------ + # The rest of the script is an usual Avalanche code. + # + # In certain situations, you may want to pass some custom + # parameters to the FFCV Loader. This can be achieved + # when calling `train()` and `eval()` (see the main loop). + # --------------------------------------------------------------- + + # MODEL CREATION + model = SimpleMLP(input_size=input_size, num_classes=benchmark.n_classes) + + # METRICS + eval_plugin = EvaluationPlugin( + accuracy_metrics(stream=True, experience=True), + loggers=[TensorboardLogger(f"tb_data/{datetime.now()}"), InteractiveLogger()], + ) + + # CREATE THE STRATEGY INSTANCE + replay_plugin = ReplayPlugin(mem_size=100, batch_size=125, batch_size_mem=25) + cl_strategy = Naive( + model, + Adam(model.parameters()), + train_mb_size=128, + train_epochs=4, + eval_mb_size=128, + device=device, + plugins=[replay_plugin], + evaluator=eval_plugin, + ) + + # TRAINING LOOP + # For FFCV, you can pass the Loader parameters using ffcv_args + # Notice that some parameters like shuffle, num_workers, ..., + # which are also found in the PyTorch DataLoader, can be passed + # to train() and eval() as usual: they will be passed to the FFCV + # Loader as they would be passed to the PyTorch Dataloader. + # + # In addition to the FFCV Loader parameters, you can pass the + # print_ffcv_summary flag (which is managed by Avalanche), + # which allows for printing the pipeline and the status of + # internal checks made by Avalanche. That flag is very useful + # when setting up an FFCV+Avalanche experiment. Once you are sure + # that the code works as intended, it is better to remove it as + # the logging may be quite verbose... + start_time = time.time() + for i, experience in enumerate(benchmark.train_stream): + cl_strategy.train( + experience, + shuffle=False, + persistent_workers=True, + num_workers=num_workers, + ffcv_args={"print_ffcv_summary": True, "batches_ahead": 2}, + ) + + cl_strategy.eval( + benchmark.test_stream[: i + 1], + shuffle=False, + num_workers=num_workers, + ffcv_args={"print_ffcv_summary": True, "batches_ahead": 2}, + ) + end_time = time.time() + print("Overall time:", end_time - start_time, "seconds") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--cuda", + type=int, + default=0, + help="Select zero-indexed cuda device. -1 to use CPU.", + ) + args = parser.parse_args() + main(args.cuda) diff --git a/examples/ffcv/ffcv_enable_rgb_compress.py b/examples/ffcv/ffcv_enable_rgb_compress.py new file mode 100644 index 000000000..858fd6d7a --- /dev/null +++ b/examples/ffcv/ffcv_enable_rgb_compress.py @@ -0,0 +1,157 @@ +""" +This example shows how to use FFCV data loading system in Avalanche +when compressing RGB images is required. + +FFCV allows for various tweaks to be used when manipulating images. +In particular, FFCV allows storing images as JPGs with custom +quality. In addition, the max side of the image and other custom +elements can be set. + +This tutorial will show how to set these parameters. +""" + +import argparse +from datetime import datetime +import time + +import torch +import torch.optim.lr_scheduler +from torch.optim import Adam +from avalanche.benchmarks import SplitMNIST +from avalanche.benchmarks.classic.ccifar100 import SplitCIFAR100 +from avalanche.benchmarks.classic.ctiny_imagenet import SplitTinyImageNet +from avalanche.benchmarks.utils.ffcv_support import enable_ffcv +from avalanche.models import SimpleMLP +from avalanche.training.determinism.rng_manager import RNGManager +from avalanche.training.supervised import Naive +from avalanche.training.plugins import ReplayPlugin +from avalanche.evaluation.metrics import accuracy_metrics +from avalanche.logging import TensorboardLogger, InteractiveLogger +from avalanche.training.plugins import EvaluationPlugin + + +def main(cuda: int): + device = torch.device(f"cuda:{cuda}" if torch.cuda.is_available() else "cpu") + RNGManager.set_random_seeds(1234) + + benchmark_type = "tinyimagenet" + + # --- BENCHMARK CREATION + num_workers = 8 + if benchmark_type == "mnist": + input_size = 28 * 28 + num_workers = 4 + benchmark = SplitMNIST( + n_experiences=5, seed=42, class_ids_from_zero_from_first_exp=True + ) + elif benchmark_type == "cifar100": + benchmark = SplitCIFAR100(5, seed=1234, shuffle=True) + input_size = 32 * 32 * 3 + elif benchmark_type == "tinyimagenet": + benchmark = SplitTinyImageNet() + input_size = 64 * 64 * 3 + else: + raise RuntimeError("Unknown benchmark") + + # Enabling FFCV in Avalanche is as simple as calling `enable_ffcv`. + # For additional info regarding on how this works, please refer + # to the `ffcv_enable.py` example. + # In this example, the focus is on the RGB encoder customization. + # + # `ffcv_parameters` is where we pass custom parameters for the RGB encoder + # These parameters are listed in the FFCV website: + # https://docs.ffcv.io/working_with_images.html + # As an example, here we set parameters like + # write_mode, compress_probability, and jpeg_quality + # + # Note: an alternative way to achieve this is to specify the encoder + # dictionary directly by passing the `encoder_def` parameter. + print("Enabling FFCV support...") + print("The may include writing the datasets in FFCV format. May take some time...") + enable_ffcv( + benchmark=benchmark, + write_dir=f"./ffcv_test_compress_{benchmark_type}", + device=device, + ffcv_parameters=dict( + num_workers=8, + write_mode="proportion", + compress_probability=0.25, + jpeg_quality=90, + ), + ) + print("FFCV enabled!") + + # -------------------- THAT'S IT!! ------------------------------ + # The rest of the script is an usual Avalanche code. + # + # In certain situations, you may want to pass some custom + # parameters to the FFCV Loader. This can be achieved + # when calling `train()` and `eval()` (see the main loop). + # --------------------------------------------------------------- + + # MODEL CREATION + model = SimpleMLP(input_size=input_size, num_classes=benchmark.n_classes) + + # METRICS + eval_plugin = EvaluationPlugin( + accuracy_metrics(stream=True, experience=True), + loggers=[TensorboardLogger(f"tb_data/{datetime.now()}"), InteractiveLogger()], + ) + + # CREATE THE STRATEGY INSTANCE + replay_plugin = ReplayPlugin(mem_size=100, batch_size=125, batch_size_mem=25) + cl_strategy = Naive( + model, + Adam(model.parameters()), + train_mb_size=128, + train_epochs=4, + eval_mb_size=128, + device=device, + plugins=[replay_plugin], + evaluator=eval_plugin, + ) + + # TRAINING LOOP + # For FFCV, you can pass the Loader parameters using ffcv_args + # Notice that some parameters like shuffle, num_workers, ..., + # which are also found in the PyTorch DataLoader, can be passed + # to train() and eval() as usual: they will be passed to the FFCV + # Loader as they would be passed to the PyTorch Dataloader. + # + # In addition to the FFCV Loader parameters, you can pass the + # print_ffcv_summary flag (which is managed by Avalanche), + # which allows for printing the pipeline and the status of + # internal checks made by Avalanche. That flag is very useful + # when setting up an FFCV+Avalanche experiment. Once you are sure + # that the code works as intended, it is better to remove it as + # the logging may be quite verbose... + start_time = time.time() + for i, experience in enumerate(benchmark.train_stream): + cl_strategy.train( + experience, + shuffle=False, + persistent_workers=True, + num_workers=num_workers, + ffcv_args={"print_ffcv_summary": True, "batches_ahead": 2}, + ) + + cl_strategy.eval( + benchmark.test_stream[: i + 1], + shuffle=False, + num_workers=num_workers, + ffcv_args={"print_ffcv_summary": True, "batches_ahead": 2}, + ) + end_time = time.time() + print("Overall time:", end_time - start_time, "seconds") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--cuda", + type=int, + default=0, + help="Select zero-indexed cuda device. -1 to use CPU.", + ) + args = parser.parse_args() + main(args.cuda) diff --git a/examples/ffcv/ffcv_io_manual_test.py b/examples/ffcv/ffcv_io_manual_test.py new file mode 100644 index 000000000..71536cd91 --- /dev/null +++ b/examples/ffcv/ffcv_io_manual_test.py @@ -0,0 +1,201 @@ +""" +Simple script used to (manually) check if the FFCV pipeline returns +the expected outputs. This can be used to inspect the output +of a decoding pipeline. + +It is recommended to start with the automatic translation pipeline, +which Avalanche tries to put toghether when `enable_ffcv` +has no `decoder_def` parameter. If you are not happy with the +automatic pipeline, you can start putting your custom pipeline together +by following the FFCV tutorials! +""" + +# %% +import random +import time +from matplotlib import pyplot as plt + +import torch +from avalanche.benchmarks.classic.ccifar100 import SplitCIFAR100 +from avalanche.benchmarks.classic.ctiny_imagenet import SplitTinyImageNet +from avalanche.benchmarks.utils.ffcv_support import enable_ffcv +from avalanche.benchmarks.utils.ffcv_support.ffcv_components import ( + HybridFfcvLoader, +) +from avalanche.training.determinism.rng_manager import RNGManager + +from torchvision.transforms.functional import to_pil_image +from torchvision import transforms +from torch.utils.data import DataLoader + +from torch.utils.data.sampler import ( + BatchSampler, + SequentialSampler, +) + + +# %% +def main(cuda: int): + # --- CONFIG + device = torch.device(f"cuda:{cuda}" if torch.cuda.is_available() else "cpu") + RNGManager.set_random_seeds(1234) + + # Define here the transformations to check + + # --- CIFAR-100 --- + cifar_train_transform = transforms.Compose( + [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(p=0.5), + transforms.ToTensor(), + ] + ) + cifar_eval_transform = transforms.Compose( + [ + transforms.ToTensor(), + ] + ) + benchmark = SplitCIFAR100( + 5, + seed=4321, + shuffle=True, + train_transform=cifar_train_transform, + eval_transform=cifar_eval_transform, + return_task_id=True, + ) + write_dir = "./ffcv_manual_test_cifar100" + + # --- TinyImagenet --- + # benchmark = SplitTinyImageNet() + # write_dir = "./ffcv_manual_test_tiny_imagenet" + + # It is recommended to start with `None`, so that Avalanche can try + # putting a pipeline together automatically by translating common + # torchvision transformations to FFCV. + # If you encounter issues or the output is not what you expect, then + # it is recommended to start from the pipeline printed by Avalanche + # and adapt it by following the guides in the FFCV website and repo. + custom_decoder_pipeline = None + + num_workers = 8 + + print("Preparing FFCV datasets...") + enable_ffcv( + benchmark=benchmark, + write_dir=write_dir, + device=device, + ffcv_parameters=dict(num_workers=num_workers), + decoder_def=custom_decoder_pipeline, + print_summary=True, # Leave to True to get important info! + ) + print("FFCV datasets ready") + + # Create the FFCV Loader + # Here we use the HybridFfcvLoader directly to load an AvalancheDataset + # The HybridFfcvLoader is an internal utility we here use to directly check + # if the decoder pipeline is working as intended. + # Note: this is not the way FFCV should be used in Avalanche + # Refer to the `ffcv_enable.py` example for the correct way + + start_time = time.time() + ffcv_data_loader = HybridFfcvLoader( + benchmark.train_stream[0].dataset, + batch_sampler=BatchSampler( + SequentialSampler(benchmark.train_stream[0].dataset), + batch_size=12, + drop_last=True, + ), + ffcv_loader_parameters=dict(num_workers=num_workers, drop_last=True), + device=device, + persistent_workers=False, + print_ffcv_summary=True, + start_immediately=False, + ) + end_time = time.time() + print("Loader creation took", end_time - start_time, "seconds") + + # Also load the same data using a PyTorch DataLoader + # Note: data will be different when using random augmentations! + pytorch_loader = DataLoader( + benchmark.train_stream[0].dataset, + batch_size=12, + drop_last=True, + ) + + start_time = time.time() + for i, (ffcv_batch, torch_batch) in enumerate( + zip(ffcv_data_loader, pytorch_loader) + ): + print(f"Batch {i} composition (FFCV vs PyTorch)") + for element in ffcv_batch: + print(element.shape, "vs", element.shape) + + n_to_show = 3 + for idx in range(n_to_show): + as_img_ffcv = to_pil_image(ffcv_batch[0][idx]) + as_img_torch = to_pil_image(torch_batch[0][idx]) + + f, axarr = plt.subplots(1, 2) + ffcv_label = ffcv_batch[1][idx].item() + torch_label = torch_batch[1][idx].item() + ffcv_task = ffcv_batch[2][idx].item() + torch_task = torch_batch[2][idx].item() + f.suptitle( + f"Label: {ffcv_label}/{torch_label}, " + f"Task label: {ffcv_task}/{torch_task}" + ) + + axarr[0].set_title("FFCV") + axarr[0].imshow(as_img_ffcv) + axarr[1].set_title("PyTorch") + axarr[1].imshow(as_img_torch) + + plt.show() + f.clear() + + # --------------------------------------------- + # Checks to verify that ffcv == pytorch + # Note: when using certain transformations such as Normalize, + # having `almost_same` True is usually sufficient even if + # `all_same` is False. + all_same = True + almost_same = True + correct_device = True + + for f, t in zip(ffcv_batch, torch_batch): + print(f.shape, t.shape) + correct_device = correct_device and f.device == device + f = f.cpu() + t = t.cpu() + + exactly_same = torch.equal(f, t) + all_same = all_same and exactly_same + + if f.dtype.is_floating_point: + almost_same = almost_same and ( + torch.sum(torch.abs(f - t) > 1e-6).item() == 0 + ) + else: + almost_same = almost_same and exactly_same + + print("all_same", all_same) + print("almost_same", almost_same) + print("correct_device", correct_device) + # --------------------------------------------- + + # Keep this break if it is sufficient to analyze only the first batch + break + + # Print batch separator + print("." * 40) + + end_time = time.time() + print("Loop time:", end_time - start_time, "seconds") + + +# When running on VSCode (with Python extension), you will notice additional +# controls such as "Run Cell", "Run Above", ... +# The recommended way to use this script +# is to first "Run Above" and then "Run Cell". +# %% +main(0) diff --git a/examples/ffcv/ffcv_try_speed.py b/examples/ffcv/ffcv_try_speed.py new file mode 100644 index 000000000..ed487d2e2 --- /dev/null +++ b/examples/ffcv/ffcv_try_speed.py @@ -0,0 +1,175 @@ +""" +This scripts can be used to measure the speed of the FFCV dataloader. + +Note: this is not the correct way to use FFCV in Avalanche. For a proper +example, please refer to `ffcv_enable.py`. This script should be used +to measure speed only. +""" + +import argparse +import time +from typing import Tuple + +import torch +import torch.optim.lr_scheduler +from avalanche.benchmarks import SplitMNIST +from avalanche.benchmarks.classic.ccifar100 import SplitCIFAR100 +from avalanche.benchmarks.classic.core50 import CORe50 +from avalanche.benchmarks.classic.ctiny_imagenet import SplitTinyImageNet +from avalanche.benchmarks.utils.data import AvalancheDataset +from avalanche.benchmarks.utils.ffcv_support import ( + HybridFfcvLoader, + enable_ffcv, +) +from avalanche.training.determinism.rng_manager import RNGManager + +from ffcv.transforms import ToTensor + +from torchvision.transforms import Compose, ToTensor, Normalize + +from torch.utils.data import DataLoader +from tqdm import tqdm + + +def main(cuda: int): + # --- CONFIG + device = torch.device(f"cuda:{cuda}" if torch.cuda.is_available() else "cpu") + RNGManager.set_random_seeds(1234) + + benchmark_type = "cifar100" + + # --- BENCHMARK CREATION + if benchmark_type == "mnist": + benchmark = SplitMNIST( + n_experiences=5, seed=42, class_ids_from_zero_from_first_exp=True + ) + elif benchmark_type == "core50": + benchmark = CORe50() + benchmark.n_classes = 50 + elif benchmark_type == "cifar100": + cifar100_train_transform = Compose( + [ + ToTensor(), + Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)), + ] + ) + + cifar100_eval_transform = Compose( + [ + ToTensor(), + Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)), + ] + ) + benchmark = SplitCIFAR100( + 5, + seed=1234, + shuffle=True, + train_transform=cifar100_train_transform, + eval_transform=cifar100_eval_transform, + ) + elif benchmark_type == "tinyimagenet": + benchmark = SplitTinyImageNet() + else: + raise RuntimeError("Unknown benchmark") + + # Note: when Numba uses TBB, then 20 is the limit number of workers + # However, this limit does not apply when using OpenMP + # (which may be faster...). If you want to test using OpenMP, then + # run this script with the following command: + # NUMBA_THREADING_LAYER=omp NUMBA_NUM_THREADS=32 python benchmark_ffcv.py + for num_workers in [8, 16, 32]: + print("num_workers =", num_workers) + print("device =", device) + benchmark_pytorch_speed( + benchmark, device=device, num_workers=num_workers, epochs=4 + ) + benchmark_ffcv_speed( + benchmark, + f"./ffcv_test_{benchmark_type}", + device=device, + num_workers=num_workers, + epochs=4, + ) + + +def benchmark_ffcv_speed( + benchmark, path, device, batch_size=128, num_workers=1, epochs=1 +): + print("Testing FFCV Loader speed") + + all_train_dataset = [x.dataset for x in benchmark.train_stream] + avl_set = AvalancheDataset(all_train_dataset) + avl_set = avl_set.train() + + start_time = time.time() + enable_ffcv( + benchmark, + path, + device, + dict(num_workers=num_workers), + print_summary=False, # Better keep this true on non-benchmarking code + ) + end_time = time.time() + print("FFCV preparation time:", end_time - start_time, "seconds") + + start_time = time.time() + ffcv_loader = HybridFfcvLoader( + avl_set, + None, + batch_size, + dict(num_workers=num_workers, drop_last=True), + device=device, + print_ffcv_summary=False, + ) + + for _ in tqdm(range(epochs)): + for batch in ffcv_loader: + # "Touch" tensors to make sure they already moved to GPU + batch[0][0] + batch[-1][0] + + end_time = time.time() + print("FFCV time:", end_time - start_time, "seconds") + + +def benchmark_pytorch_speed(benchmark, device, batch_size=128, num_workers=1, epochs=1): + print("Testing PyTorch Loader speed") + + all_train_dataset = [x.dataset for x in benchmark.train_stream] + avl_set = AvalancheDataset(all_train_dataset) + avl_set = avl_set.train() + + start_time = time.time() + torch_loader = DataLoader( + avl_set, + batch_size, + num_workers=num_workers, + pin_memory=True, + drop_last=True, + shuffle=False, + persistent_workers=True, + ) + + batch: Tuple[torch.Tensor] + for _ in tqdm(range(epochs)): + for batch in torch_loader: + batch = tuple(x.to(device, non_blocking=True) for x in batch) + + # "Touch" tensors to make sure they already moved to GPU + batch[0][0] + batch[-1][0] + + end_time = time.time() + print("PyTorch time:", end_time - start_time, "seconds") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--cuda", + type=int, + default=0, + help="Select zero-indexed cuda device. -1 to use CPU.", + ) + args = parser.parse_args() + main(args.cuda) diff --git a/tests/benchmarks/ffcv/__init__.py b/tests/benchmarks/ffcv/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/benchmarks/ffcv/test_ffcv_support.py b/tests/benchmarks/ffcv/test_ffcv_support.py new file mode 100644 index 000000000..bdf8b7ad6 --- /dev/null +++ b/tests/benchmarks/ffcv/test_ffcv_support.py @@ -0,0 +1,152 @@ +import os +import random +import tempfile +import unittest +import torch +from torch.utils.data.sampler import ( + BatchSampler, + SubsetRandomSampler, + SequentialSampler, +) +from torch.utils.data.dataloader import DataLoader + +from avalanche.benchmarks.classic.cmnist import SplitMNIST +from avalanche.benchmarks.utils.data_loader import MultiDatasetSampler +from avalanche.benchmarks.utils import AvalancheDataset, DataAttribute +from torchvision.transforms import Normalize + +try: + import ffcv + + skip = False +except ImportError: + skip = True + + +class FFCVSupportTests(unittest.TestCase): + @unittest.skipIf(skip, reason="Need ffcv to run these tests") + def test_simple_scenario(self): + from avalanche.benchmarks.utils.ffcv_support.ffcv_components import ( + enable_ffcv, + HybridFfcvLoader, + ) + + train_transform = Normalize((0.1307,), (0.3081,)) + + eval_transform = Normalize((0.1307,), (0.3081,)) + + use_gpu = str(os.environ["USE_GPU"]).lower() in ["true", "1"] + + if use_gpu: + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + + benchmark = SplitMNIST( + 5, + seed=4321, + shuffle=True, + return_task_id=True, + train_transform=train_transform, + eval_transform=eval_transform, + ) + + with tempfile.TemporaryDirectory() as write_dir: + num_workers = 4 + + enable_ffcv( + benchmark=benchmark, + write_dir=write_dir, + device=device, + ffcv_parameters=dict(num_workers=num_workers), + print_summary=False, + ) + + dataset_0 = benchmark.train_stream[0].dataset + dataset_1 = benchmark.train_stream[1].dataset + + subset_indices = list(range(0, len(dataset_0), 5)) + random.shuffle(subset_indices) + + generator_0_a = torch.Generator() + generator_0_a.manual_seed(2147483647) + + generator_0_b = torch.Generator() + generator_0_b.manual_seed(2147483647) + + sampler_0_a = BatchSampler( + SubsetRandomSampler(subset_indices, generator_0_a), + batch_size=12, + drop_last=True, + ) + + sampler_0_b = BatchSampler( + SubsetRandomSampler(subset_indices, generator_0_b), + batch_size=12, + drop_last=True, + ) + + sampler_0_a_lst = list(sampler_0_a) + sampler_0_b_lst = list(sampler_0_b) + self.assertEqual(sampler_0_a_lst, sampler_0_b_lst) + + sampler_1 = BatchSampler( + SequentialSampler(dataset_1), batch_size=123, drop_last=False + ) + + batch_sampler_a = MultiDatasetSampler( + [dataset_0, dataset_1], + [sampler_0_a, sampler_1], + oversample_small_datasets=True, + ) + + batch_sampler_b = MultiDatasetSampler( + [dataset_0, dataset_1], + [sampler_0_b, sampler_1], + oversample_small_datasets=True, + ) + + batch_sampler_a_lst = list(batch_sampler_a) + batch_sampler_b_lst = list(batch_sampler_b) + self.assertEqual(batch_sampler_a_lst, batch_sampler_b_lst) + + sum_len = len(dataset_0) + len(dataset_1) + concat_dataset = AvalancheDataset( + [dataset_0, dataset_1], + data_attributes=[ + DataAttribute( + list(range(sum_len)), "custom_attr", use_in_getitem=True + ) + ], + ) + + ffcv_data_loader = HybridFfcvLoader( + concat_dataset, + batch_sampler=batch_sampler_a, + ffcv_loader_parameters=dict(num_workers=num_workers, drop_last=False), + device=device, + persistent_workers=False, + print_ffcv_summary=False, + start_immediately=False, + ) + + pytorch_loader = DataLoader(concat_dataset, batch_sampler=batch_sampler_b) + + self.assertEqual(len(ffcv_data_loader), len(pytorch_loader)) + + for i, (ffcv_batch, torch_batch) in enumerate( + zip(ffcv_data_loader, pytorch_loader) + ): + for f, t in zip(ffcv_batch, torch_batch): + self.assertEqual(f.device, device) + f = f.cpu() + t = t.cpu() + + if f.dtype.is_floating_point: + self.assertTrue(torch.sum(torch.abs(f - t) > 1e-6).item() == 0) + else: + self.assertTrue(torch.equal(f, t)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_transformations.py b/tests/test_transformations.py new file mode 100644 index 000000000..2d96dc559 --- /dev/null +++ b/tests/test_transformations.py @@ -0,0 +1,298 @@ +import copy +import unittest +from avalanche.benchmarks.datasets.dataset_utils import default_dataset_location +from avalanche.benchmarks.utils.data import AvalancheDataset +from avalanche.benchmarks.utils.dataset_traversal_utils import single_flat_dataset +from avalanche.benchmarks.utils.detection_dataset import DetectionDataset +from avalanche.benchmarks.classic.cmnist import SplitMNIST +from avalanche.benchmarks.utils.transform_groups import TransformGroups + +from avalanche.benchmarks.utils.transforms import ( + MultiParamCompose, + MultiParamTransformCallable, + TupleTransform, + flat_transforms_recursive, +) + +import torch +from PIL import ImageChops +from torch import Tensor +from torch.utils.data import DataLoader, ConcatDataset +from torchvision.datasets import MNIST +from torchvision.transforms import ( + ToTensor, + Compose, + CenterCrop, + Normalize, + Lambda, + RandomHorizontalFlip, +) +from torchvision.transforms.functional import to_tensor +from PIL.Image import Image + +from tests.unit_tests_utils import get_fast_detection_datasets + + +def pil_images_equal(img_a, img_b): + diff = ImageChops.difference(img_a, img_b) + + return not diff.getbbox() + + +def zero_if_label_2(img_tensor: Tensor, class_label): + if int(class_label) == 2: + torch.full(img_tensor.shape, 0.0, out=img_tensor) + + return img_tensor, class_label + + +def get_mbatch(data, batch_size=5): + dl = DataLoader( + data, shuffle=False, batch_size=batch_size, collate_fn=data.collate_fn + ) + return next(iter(dl)) + + +class TransformsTest(unittest.TestCase): + def test_multi_param_transform_callable(self): + dataset: DetectionDataset + dataset, _ = get_fast_detection_datasets() + + boxes = [] + i = 0 + while len(boxes) == 0: + x_orig, y_orig, t_orig = dataset[i] + boxes = y_orig["boxes"] + i += 1 + i -= 1 + + x_expect = to_tensor(copy.deepcopy(x_orig)) + x_expect[0][0] += 1 + + y_expect = copy.deepcopy(y_orig) + y_expect["boxes"][0][0] += 1 + + def do_something_xy(img, target): + img = to_tensor(img) + img[0][0] += 1 + target["boxes"][0][0] += 1 + return img, target + + uut = MultiParamTransformCallable(do_something_xy) + + # Test __eq__ + uut_eq = MultiParamTransformCallable(do_something_xy) + self.assertTrue(uut == uut_eq) + self.assertTrue(uut_eq == uut) + + x, y, t = uut(*dataset[i]) + + self.assertIsInstance(x, torch.Tensor) + self.assertIsInstance(y, dict) + self.assertIsInstance(t, int) + + self.assertTrue(torch.equal(x_expect, x)) + keys = set(y_expect.keys()) + self.assertSetEqual(keys, set(y.keys())) + + for k in keys: + self.assertTrue(torch.equal(y_expect[k], y[k]), msg=f"Wrong {k}") + + def test_multi_param_compose(self): + dataset: DetectionDataset + dataset, _ = get_fast_detection_datasets() + + assert_called = 0 + + def do_something_xy(img: Tensor, target): + nonlocal assert_called + assert_called += 1 + img = img.clone() + img[0][0] += 1 + target["boxes"][0][0] += 1 + return img, target + + t_x = lambda x, y: (to_tensor(x), y) + t_xy = do_something_xy + t_x_1_element = ToTensor() + + boxes = [] + i = 0 + while len(boxes) == 0: + x_orig, y_orig, t_orig = dataset[i] + boxes = y_orig["boxes"] + i += 1 + i -= 1 + + x_expect = to_tensor(copy.deepcopy(x_orig)) + x_expect[0][0] += 1 + + y_expect = copy.deepcopy(y_orig) + y_expect["boxes"][0][0] += 1 + + uut_2 = MultiParamCompose([t_x, t_xy]) + + # Test __eq__ + uut_2_eq = MultiParamCompose([t_x, t_xy]) + self.assertTrue(uut_2 == uut_2_eq) + self.assertTrue(uut_2_eq == uut_2) + + with self.assertWarns(Warning): + # Assert that the following warn is raised: + # "Transformations define a different number of parameters. ..." + uut_1 = MultiParamCompose([t_x_1_element, t_xy]) + + for uut, uut_type in zip((uut_1, uut_2), ("uut_1", "uut_2")): + with self.subTest(uut_type=uut_type): + initial_assert_called = assert_called + + x, y, t = uut(*dataset[i]) + + self.assertEqual(initial_assert_called + 1, assert_called) + + self.assertIsInstance(x, torch.Tensor) + self.assertIsInstance(y, dict) + self.assertIsInstance(t, int) + + self.assertTrue(torch.equal(x_expect, x)) + keys = set(y_expect.keys()) + self.assertSetEqual(keys, set(y.keys())) + + for k in keys: + self.assertTrue(torch.equal(y_expect[k], y[k]), msg=f"Wrong {k}") + + def test_tuple_transform(self): + dataset = MNIST(root=default_dataset_location("mnist"), download=True) + + t_x = ToTensor() + t_y = lambda element: element + 1 + t_bad = lambda element: element - 1 + + uut = TupleTransform([t_x, t_y]) + + uut_eq = TupleTransform( + (t_x, t_y) # Also test with a tuple instead of a list here + ) + + uut_not_x = TupleTransform([None, t_y]) + + uut_bad = TupleTransform((t_x, t_y, t_bad)) + + x_orig, y_orig = dataset[0] + + # Test with x transform + x, y = uut(*dataset[0]) + + self.assertIsInstance(x, torch.Tensor) + self.assertIsInstance(y, int) + + self.assertTrue(torch.equal(to_tensor(x_orig), x)) + self.assertEqual(y_orig + 1, y) + + # Test without x transform + x, y = uut_not_x(*dataset[0]) + + self.assertIsInstance(x, Image) + self.assertIsInstance(y, int) + + self.assertEqual(x_orig, x) + self.assertEqual(y_orig + 1, y) + + # Check __eq__ works + self.assertTrue(uut == uut_eq) + self.assertTrue(uut_eq == uut) + + self.assertFalse(uut == uut_not_x) + self.assertFalse(uut_not_x == uut) + + with self.assertRaises(Exception): + # uut_bad has 3 transforms, which is incorrect + uut_bad(*dataset[0]) + + def test_flat_transforms_recursive_only_torchvision(self): + x_transform = ToTensor() + x_transform_list = [CenterCrop(24), Normalize(0.5, 0.1)] + x_transform_composed = Compose(x_transform_list) + + expected_x = [x_transform] + x_transform_list + + # Single transforms checks + self.assertSequenceEqual( + [x_transform], flat_transforms_recursive([x_transform], 0) + ) + + self.assertSequenceEqual( + [x_transform], flat_transforms_recursive(x_transform, 0) + ) + + self.assertSequenceEqual( + x_transform_list, flat_transforms_recursive(x_transform_list, 0) + ) + + self.assertSequenceEqual( + x_transform_list, flat_transforms_recursive(x_transform_composed, 0) + ) + + # Hybrid list checks + self.assertSequenceEqual( + expected_x, + flat_transforms_recursive([x_transform, x_transform_composed], 0), + ) + + def test_flat_transforms_recursive_from_dataset(self): + x_transform = ToTensor() + x_transform_list = [CenterCrop(24), Normalize(0.5, 0.1)] + x_transform_additional = RandomHorizontalFlip(p=0.2) + x_transform_composed = Compose(x_transform_list) + + expected_x = [x_transform] + x_transform_list + [x_transform_additional] + + y_transform = Lambda(lambda x: max(0, x - 1)) + + dataset = MNIST( + root=default_dataset_location("mnist"), download=True, transform=x_transform + ) + + transform_group = TransformGroups( + transform_groups={ + "train": TupleTransform([x_transform_composed, y_transform]) + } + ) + + transform_group_additional_1a = TransformGroups( + transform_groups={"train": TupleTransform([x_transform_additional, None])} + ) + transform_group_additional_1b = TransformGroups( + transform_groups={"train": TupleTransform([x_transform_additional, None])} + ) + + avl_dataset = AvalancheDataset([dataset], transform_groups=transform_group) + + avl_subset_1 = avl_dataset.subset([1, 2, 3]) + avl_subset_2 = avl_dataset.subset([5, 6, 7]) + + avl_subset_1 = AvalancheDataset( + [avl_subset_1], transform_groups=transform_group_additional_1a + ) + avl_subset_2 = AvalancheDataset( + [avl_subset_2], transform_groups=transform_group_additional_1b + ) + + for concat_type, avl_concat in zip( + ["avalanche", "pytorch"], + [ + avl_subset_1.concat(avl_subset_2), + ConcatDataset([avl_subset_1, avl_subset_2]), + ], + ): + with self.subTest("Concatenation type", concat_type=concat_type): + _, _, transforms = single_flat_dataset(avl_concat) + x_flattened = flat_transforms_recursive(transforms, 0) + y_flattened = flat_transforms_recursive(transforms, 1) + + self.assertSequenceEqual(expected_x, x_flattened) + self.assertSequenceEqual([y_transform], y_flattened) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests_utils.py b/tests/unit_tests_utils.py index 2eb74b551..febf6bebe 100644 --- a/tests/unit_tests_utils.py +++ b/tests/unit_tests_utils.py @@ -1,10 +1,12 @@ +import copy +import itertools from os.path import expanduser import os import random import torch from PIL.Image import Image -from sklearn.datasets import make_classification +from sklearn.datasets import make_blobs, make_classification from sklearn.model_selection import train_test_split import numpy as np from torch.utils.data import TensorDataset, Dataset @@ -14,6 +16,9 @@ from torchvision.transforms import Compose, ToTensor from avalanche.benchmarks import nc_benchmark +from avalanche.benchmarks.utils.detection_dataset import ( + make_detection_dataset, +) # Environment variable used to skip some expensive tests that are very unlikely @@ -214,6 +219,149 @@ def set_deterministic_run(seed=0): torch.backends.cudnn.deterministic = True +class _DummyDetectionDataset: + """ + A dataset that makes a defensive copy of the + targets before returning them. + + Alas, many detection transformations, including the + ones in the torchvision repository, modify bounding boxes + (and other elements) in place. + Luckly, images seem to be never modified in place. + """ + + def __init__(self, images, targets): + self.images = images + self.targets = targets + + def __len__(self): + return len(self.images) + + def __getitem__(self, index): + return self.images[index], copy.deepcopy(self.targets[index]) + + +def get_fast_detection_datasets( + n_images=30, + max_elements_per_image=10, + n_samples_per_class=20, + n_classes=10, + seed=None, + image_size=64, + n_test_images=5, +): + if seed is not None: + np.random.seed(seed) + random.seed(seed) + + assert n_images * max_elements_per_image >= n_samples_per_class * n_classes + assert n_test_images < n_images + assert n_test_images > 0 + + base_n_per_images = (n_samples_per_class * n_classes) // n_images + additional_elements = (n_samples_per_class * n_classes) % n_images + to_allocate = np.full(n_images, base_n_per_images) + to_allocate[:additional_elements] += 1 + np.random.shuffle(to_allocate) + classes_elements = np.repeat(np.arange(n_classes), n_samples_per_class) + np.random.shuffle(classes_elements) + + import matplotlib.colors as mcolors + + forms = ["ellipse", "rectangle", "line", "arc"] + colors = list(mcolors.TABLEAU_COLORS.values()) + combs = list(itertools.product(forms, colors)) + random.shuffle(combs) + + generated_images = [] + generated_targets = [] + for img_idx in range(n_images): + n_to_allocate = to_allocate[img_idx] + base_alloc_idx = to_allocate[:img_idx].sum() + classes_to_instantiate = classes_elements[ + base_alloc_idx : base_alloc_idx + n_to_allocate + ] + + _, _, clusters = make_blobs( + n_to_allocate, + n_features=2, + centers=n_to_allocate, + center_box=(0, image_size - 1), + random_state=seed, + return_centers=True, + ) + + from PIL import Image as ImageApi + from PIL import ImageDraw + + im = ImageApi.new("RGB", (image_size, image_size)) + draw = ImageDraw.Draw(im) + + target = { + "boxes": torch.zeros((n_to_allocate, 4), dtype=torch.float32), + "labels": torch.zeros((n_to_allocate,), dtype=torch.long), + "image_id": torch.full((1,), img_idx, dtype=torch.long), + "area": torch.zeros((n_to_allocate,), dtype=torch.float32), + "iscrowd": torch.zeros((n_to_allocate,), dtype=torch.long), + } + + obj_sizes = np.random.uniform( + low=image_size * 0.1 * 0.95, + high=image_size * 0.1 * 1.05, + size=(n_to_allocate,), + ) + for center_idx, center in enumerate(clusters): + obj_size = float(obj_sizes[center_idx]) + class_to_gen = classes_to_instantiate[center_idx] + + class_form, class_color = combs[class_to_gen] + + left = center[0] - obj_size + top = center[1] - obj_size + right = center[0] + obj_size + bottom = center[1] + obj_size + ltrb = (left, top, right, bottom) + if class_form == "ellipse": + draw.ellipse(ltrb, fill=class_color) + elif class_form == "rectangle": + draw.rectangle(ltrb, fill=class_color) + elif class_form == "line": + draw.line(ltrb, fill=class_color, width=max(1, int(obj_size * 0.25))) + elif class_form == "arc": + draw.arc(ltrb, fill=class_color, start=45, end=200) + else: + raise RuntimeError("Unsupported form") + + target["boxes"][center_idx] = torch.as_tensor(ltrb) + target["labels"][center_idx] = class_to_gen + target["area"][center_idx] = obj_size**2 + + generated_images.append(np.array(im)) + generated_targets.append(target) + im.close() + + test_indices = set( + np.random.choice(n_images, n_test_images, replace=False).tolist() + ) + train_images = [x for i, x in enumerate(generated_images) if i not in test_indices] + test_images = [x for i, x in enumerate(generated_images) if i in test_indices] + + train_targets = [ + x for i, x in enumerate(generated_targets) if i not in test_indices + ] + test_targets = [x for i, x in enumerate(generated_targets) if i in test_indices] + + return make_detection_dataset( + _DummyDetectionDataset(train_images, train_targets), + targets=train_targets, + task_labels=0, + ), make_detection_dataset( + _DummyDetectionDataset(test_images, test_targets), + targets=test_targets, + task_labels=0, + ) + + __all__ = [ "common_setups", "load_benchmark", @@ -221,4 +369,5 @@ def set_deterministic_run(seed=0): "load_experience_train_eval", "get_device", "set_deterministic_run", + "get_fast_detection_datasets", ]