From 50a1565eefe3cbff38ac7b8cd77e053cd2524df1 Mon Sep 17 00:00:00 2001 From: Moshiko Raboh <86309179+mosheraboh@users.noreply.github.com> Date: Wed, 1 Feb 2023 14:34:25 +0200 Subject: [PATCH] optimized train time for a use case of small samples and large batch (#268) --- fuse/data/utils/collates.py | 8 ++- fuse/dl/lightning/pl_funcs.py | 6 +- fuse/eval/examples/examples_segmentation.py | 4 +- fuse/eval/metrics/metrics_common.py | 65 ++++++++++++++------- fuse/utils/data/collate.py | 62 +++++++++----------- fuse/utils/ndict.py | 44 +++++++++----- 6 files changed, 114 insertions(+), 75 deletions(-) diff --git a/fuse/data/utils/collates.py b/fuse/data/utils/collates.py index a9b14709e..3563bc950 100644 --- a/fuse/data/utils/collates.py +++ b/fuse/data/utils/collates.py @@ -46,7 +46,7 @@ def __init__( ): """ :param skip_keys: do not collect the listed keys - :param keep_keys: specifies a list of keys to collect. missing keep_keys are skipped. + :param keep_keys: specifies a list of keys to collect. See raise_error_key_missing argument (dealing with missing keys). :param special_handlers_keys: per key specify a callable which gets as an input list of values and convert it to a batch. The rest of the keys will be converted to batch using PyTorch default collate_fn() Example of such Callable can be seen in the CollateDefault.pad_all_tensors_to_same_size. @@ -68,9 +68,11 @@ def __call__(self, samples: List[Dict]) -> Dict: batch_dict = NDict() # collect all keys - keys = self._collect_all_keys(samples) if self._keep_keys: - keys = [k for k in keys if k in self._keep_keys] + keys = self._keep_keys + else: + keys = self._collect_all_keys(samples) + # collect values for key in keys: diff --git a/fuse/dl/lightning/pl_funcs.py b/fuse/dl/lightning/pl_funcs.py index f842c3d24..55ea96115 100644 --- a/fuse/dl/lightning/pl_funcs.py +++ b/fuse/dl/lightning/pl_funcs.py @@ -124,7 +124,11 @@ def convert_predictions_to_dataframe(predictions: List[NDict]) -> pd.DataFrame: predictions_per_sample = [] for elem in predictions: predictions_per_sample += uncollate(elem) - keys = predictions_per_sample[0].keypaths() + if isinstance(predictions_per_sample[0], NDict): + keys = predictions_per_sample[0].keypaths() + else: # dict + keys = predictions_per_sample[0].keys() + for key in keys: values[key] = [elem[key] for elem in predictions_per_sample] diff --git a/fuse/eval/examples/examples_segmentation.py b/fuse/eval/examples/examples_segmentation.py index 2643a91d1..6b6a52206 100644 --- a/fuse/eval/examples/examples_segmentation.py +++ b/fuse/eval/examples/examples_segmentation.py @@ -19,7 +19,7 @@ from pycocotools.coco import COCO import numpy as np import nibabel as nib -from fuse.utils import set_seed +from fuse.utils import set_seed, NDict from fuse.eval.metrics.segmentation.metrics_segmentation_common import ( MetricDice, @@ -209,7 +209,7 @@ def data_iter(): sample_dict["pred.array"] = np.array([(1.0, 0.0), (0.0, 1.0), (1.0, 0.0), (0.0, 1.0)]) sample_dict["label.array"] = np.array([(1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0)]) sample_dict["pixel_weight"] = {"1": np.array([(0.125, 0.125), (0.125, 0.125), (0.125, 0.125), (0.125, 0.125)])} - yield sample_dict + yield NDict(sample_dict) # list of metrics metrics = OrderedDict( diff --git a/fuse/eval/metrics/metrics_common.py b/fuse/eval/metrics/metrics_common.py index 6a02a7211..38e2a8da2 100644 --- a/fuse/eval/metrics/metrics_common.py +++ b/fuse/eval/metrics/metrics_common.py @@ -82,19 +82,23 @@ def __init__( self, pre_collect_process_func: Optional[Callable] = None, post_collect_process_func: Optional[Callable] = None, + batch_pre_collect_process_func: Optional[Callable] = None, **keys_to_collect: Dict[str, str], ): """ - :param pre_collect_process_func: Optional callable - the callable will get as an input a sample_dict and can preprocess it if required - :param post_collect_process_func: Optional callable - custom process func that convert the fields to be collected to the values that will actually be collected - the callable will get as an input the collected values of a single sample - and can return either a single value or dictionary. The returned values will be collected under the name "post_args" + :param pre_collect_process_func: Optional callable - the callable will get as an input a sample_dict and can preprocess it if required before collection. + Consider using batch_pre_collect_process_func instead to optimize the running time when working with large batch size. + :param post_collect_process_func: Optional callable - custom process func - used to evaluate in a sample level and keep only the result. + Typically used in methods such as segmentation to avoid from storing all the images until the end of the epoch. + Can return either a single value or dictionary. The returned values will be collected under the name "post_args" + :param batch_pre_collect_process_func: Optional callable - the callable will get as an input a batch_dict and can preprocess it if required before collection. :param keys_to_collect: specify the keys you want to collect from the source data """ super().__init__() # store input self._pre_collect_process_func = pre_collect_process_func self._post_collect_process_func = post_collect_process_func + self._batch_pre_collect_process_func = batch_pre_collect_process_func self._keys_to_collect = copy.copy(keys_to_collect) self._id_keys = MetricCollector.DEFAULT_ID_KEYS @@ -108,39 +112,54 @@ def collect(self, batch: Dict) -> None: if not isinstance(batch, NDict): batch = NDict(batch) - samples = uncollate(batch) - # If in distributed mode (multi gpu training) we shall gather the result from all the machine to evaluate with respect to the entire batch. if dist.is_initialized(): world_size = dist.get_world_size() # num of gpus samples_gather = [None for rank in range(world_size)] # samples_gather[i] will have the 'samples' value of the i's GPU - dist.all_gather_object(samples_gather, samples) + dist.all_gather_object(samples_gather, batch) # union all the GPU's samples into one samples list samples = [] for rank in range(world_size): samples += samples_gather[rank] - for sample in samples: - sample_to_collect = {} + if self._pre_collect_process_func is not None or self._post_collect_process_func is not None: + samples = uncollate(batch) + for sample in samples: + + if self._pre_collect_process_func is not None: + sample = self._pre_collect_process_func(sample) + + sample = NDict(sample) + + sample_to_collect = {} + for name, key in self._keys_to_collect.items(): + value = sample[key] + if isinstance(value, torch.Tensor): + value = value.detach().cpu().numpy() - if self._pre_collect_process_func is not None: - sample = NDict(self._pre_collect_process_func(sample)) + sample_to_collect[name] = value + if self._post_collect_process_func is not None: + sample_to_collect = {"post_args": self._post_collect_process_func(**sample_to_collect)} + + # store it - assumes batch dimension? What about single sample? + for name in sample_to_collect: + self._collected_data[name].append(sample_to_collect[name]) + else: # work in a batch level - optimized for large batch size + if self._batch_pre_collect_process_func is not None: + batch = self._batch_pre_collect_process_func(batch) + batch_to_collect = {} for name, key in self._keys_to_collect.items(): - value = sample[key] + value = batch[key] if isinstance(value, torch.Tensor): value = value.detach().cpu().numpy() - sample_to_collect[name] = value - - if self._post_collect_process_func is not None: - sample_to_collect = {"post_args": self._post_collect_process_func(**sample_to_collect)} + batch_to_collect[name] = value - # store it - assumes batch dimension? What about single sample? - for name in sample_to_collect: - self._collected_data[name].append(sample_to_collect[name]) + for name in batch_to_collect: + self._collected_data[name].extend(batch_to_collect[name]) # extract ids and store it in self._collected_ids ids = None @@ -257,6 +276,7 @@ def __init__( self, pre_collect_process_func: Optional[Callable] = None, post_collect_process_func: Optional[Callable] = None, + batch_pre_collect_process_func: Optional[Callable] = None, external_data_collector: Optional[MetricCollector] = None, extract_ids: bool = False, **kwargs, @@ -276,7 +296,12 @@ def __init__( } self._value_args = {n: k for n, k in kwargs.items() if k is not None and not isinstance(k, str)} self._collector = ( - MetricCollector(pre_collect_process_func, post_collect_process_func, **self._keys_to_collect) + MetricCollector( + pre_collect_process_func, + post_collect_process_func, + batch_pre_collect_process_func, + **self._keys_to_collect, + ) if external_data_collector is None else external_data_collector ) diff --git a/fuse/utils/data/collate.py b/fuse/utils/data/collate.py index 56a0fbd71..d2c5c77b5 100644 --- a/fuse/utils/data/collate.py +++ b/fuse/utils/data/collate.py @@ -24,8 +24,6 @@ import torch import numpy as np -import math - class CollateToBatchList(Callable): """ @@ -36,7 +34,6 @@ def __init__( self, skip_keys: Sequence[str] = tuple(), raise_error_key_missing: bool = True, - missing_values: Sequence[str] = (None, "N/A"), ): """ :param skip_keys: do not collect the listed keys @@ -44,7 +41,6 @@ def __init__( """ self._skip_keys = skip_keys self._raise_error_key_missing = raise_error_key_missing - self._missing_values = missing_values def __call__(self, samples: List[Dict]) -> Dict: """ @@ -98,22 +94,15 @@ def _collect_values_to_list(self, samples: List[str], key: str) -> Tuple[List, b has_missing_values = False collected_values = [] for index, sample in enumerate(samples): - sample = NDict(sample) - if key not in sample: + try: + value = sample[key] + except: has_error = True has_missing_values = True if self._raise_error_key_missing: raise Exception(f"Error: key {key} does not exist in sample {index}: {sample}") else: value = None - else: - value = sample[key] - if isinstance(value, float) and math.isnan(value): - has_missing_values = True - for missing_value in self._missing_values: - if type(missing_value) is type(value) and value == missing_value: - has_missing_values = True - break collected_values.append(value) return collected_values, has_error, has_missing_values @@ -124,49 +113,50 @@ def uncollate(batch: Dict) -> List[Dict]: Reverse collate method Gets a batch_dict and convert it back to list of samples """ - samples = [] - if not isinstance(batch, NDict): - batch = NDict(batch) - keys = batch.keypaths() - # empty batch - if not keys: - return samples + if not batch.keys(): + return [] - if "data.sample_id" in keys: + if isinstance(batch, NDict): + batch = batch.flatten() + + # infer batch size + if "data.sample_id" in batch: batch_size = len(batch["data.sample_id"]) else: batch_size = None - if batch_size is None: + keys = batch.keys() + for key in keys: if isinstance(batch[key], torch.Tensor): batch_size = len(batch[key]) break - if batch_size is None: - for key in keys: - if isinstance(batch[key], (np.ndarray, list)): - batch_size = len(batch[key]) - break + if batch_size is None: + for key in keys: + if isinstance(batch[key], (np.ndarray, list)): + batch_size = len(batch[key]) + break if batch_size is None: return batch # assuming batch dict with no samples - for sample_index in range(batch_size): - sample = NDict() - for key in keys: - if isinstance(batch[key], (np.ndarray, torch.Tensor, list)): + keys = batch.keys() + samples = [{} for _ in range(batch_size)] + for key in keys: + values = batch[key] + for sample_index in range(batch_size): + + if isinstance(values, (np.ndarray, torch.Tensor, list)): try: - sample[key] = batch[key][sample_index] + samples[sample_index][key] = values[sample_index] except IndexError: logging.error( f"Error - IndexError - key={key}, batch_size={batch_size}, type={type((batch[key]))}, len={len(batch[key])}" ) raise else: - sample[key] = batch[key] # broadcast single value for all batch - - samples.append(sample) + samples[sample_index][key] = values # broadcast single value for all batch return samples diff --git a/fuse/utils/ndict.py b/fuse/utils/ndict.py index a4621b5db..1372975e7 100644 --- a/fuse/utils/ndict.py +++ b/fuse/utils/ndict.py @@ -111,23 +111,41 @@ def flatten(self) -> dict: #you can use it to get a list of the flat keys: print(nx.flatten().keys()) """ + flat_dict = {} + NDict._flatten_static(self._stored, None, flat_dict) + return flat_dict - all_keys = {} - for key in self._stored: - if isinstance(self._stored[key], MutableMapping): - all_sub_keys = NDict(self[key]).flatten() - keys_to_add = {f"{key}.{sub_key}": all_sub_keys[sub_key] for sub_key in all_sub_keys} - all_keys.update(keys_to_add) - else: - all_keys[key] = self._stored[key] - - return all_keys + @staticmethod + def _flatten_static(item: Union[dict, Any], prefix: str, flat_dict: dict) -> None: + if isinstance(item, MutableMapping): + for key, value in item.items(): + if prefix is None: + cur_prefix = key + else: + cur_prefix = f"{prefix}.{key}" + NDict._flatten_static(value, cur_prefix, flat_dict) + else: + flat_dict[prefix] = item def keypaths(self) -> List[str]: """ returns a list of keypaths (i.e. "a.b.c.d") to all values in the nested dict """ - return list(self.flatten().keys()) + return NDict._keypaths_static(self._stored, None) + + @staticmethod + def _keypaths_static(item: Union[dict, Any], prefix: str) -> List[str]: + if isinstance(item, MutableMapping): + keys = [] + for key, value in item.items(): + if prefix is None: + cur_prefix = key + else: + cur_prefix = f"{prefix}.{key}" + keys += NDict._keypaths_static(value, cur_prefix) + return keys + else: + return [prefix] def keys(self) -> dict_keys: """ @@ -228,7 +246,7 @@ def pop(self, key: str) -> Any: del self[key] return res - def indices(self, indices: Union[torch.Tensor, numpy.ndarray]) -> dict: + def indices(self, indices: numpy.ndarray) -> dict: """ Extract the specified indices from each element in the dictionary (if possible) :param nested_dict: input dict @@ -240,7 +258,7 @@ def indices(self, indices: Union[torch.Tensor, numpy.ndarray]) -> dict: for key in all_keys: try: value = self[key] - if isinstance(value, numpy.ndarray) or isinstance(value, torch.Tensor): + if isinstance(value, (numpy.ndarray, torch.Tensor)): new_value = value[indices] elif isinstance(value, Sequence): new_value = [item for i, item in enumerate(value) if indices[i]]