optimized train time for a use case of small samples and large batch (#…
mosheraboh committed Feb 1, 2023
1 parent 0e56bb9 commit 50a1565
Showing 6 changed files with 114 additions and 75 deletions.
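The change set targets the metric-collection path: when no sample-level hook is registered, batches are no longer split into per-sample dicts before collection. A minimal, self-contained sketch of the idea (illustrative names only, not FuseMedML code):

    import numpy as np

    batch = {"pred": np.random.rand(4096, 2), "label": np.random.randint(0, 2, size=4096)}

    # per-sample path: build 4096 tiny dicts, then append one by one
    per_sample = [{k: v[i] for k, v in batch.items()} for i in range(4096)]
    collected_slow = {k: [s[k] for s in per_sample] for k in batch}

    # batch-level path (this commit): one extend per key, no intermediate dicts
    collected_fast = {k: [] for k in batch}
    for k, v in batch.items():
        collected_fast[k].extend(v)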
8 changes: 5 additions & 3 deletions fuse/data/utils/collates.py
@@ -46,7 +46,7 @@ def __init__(
     ):
         """
         :param skip_keys: do not collect the listed keys
-        :param keep_keys: specifies a list of keys to collect. missing keep_keys are skipped.
+        :param keep_keys: specifies a list of keys to collect. See raise_error_key_missing argument (dealing with missing keys).
         :param special_handlers_keys: per key specify a callable which gets as an input list of values and convert it to a batch.
             The rest of the keys will be converted to batch using PyTorch default collate_fn()
             Example of such Callable can be seen in the CollateDefault.pad_all_tensors_to_same_size.
@@ -68,9 +68,11 @@ def __call__(self, samples: List[Dict]) -> Dict:
         batch_dict = NDict()
 
         # collect all keys
-        keys = self._collect_all_keys(samples)
         if self._keep_keys:
-            keys = [k for k in keys if k in self._keep_keys]
+            keys = self._keep_keys
+        else:
+            keys = self._collect_all_keys(samples)
 
         # collect values
         for key in keys:
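With this change, an explicit keep_keys list short-circuits key discovery: the collate no longer scans every sample for its keys. A hedged usage sketch (the dataset and key names are illustrative):

    from torch.utils.data import DataLoader
    from fuse.data.utils.collates import CollateDefault

    # collate only the keys the model consumes; with this commit the list is
    # taken as-is instead of being re-derived from every sample in the batch
    collate_fn = CollateDefault(keep_keys=["data.input.img", "data.label"])
    # dl = DataLoader(dataset, batch_size=512, collate_fn=collate_fn)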
6 changes: 5 additions & 1 deletion fuse/dl/lightning/pl_funcs.py
@@ -124,7 +124,11 @@ def convert_predictions_to_dataframe(predictions: List[NDict]) -> pd.DataFrame:
     predictions_per_sample = []
     for elem in predictions:
         predictions_per_sample += uncollate(elem)
-    keys = predictions_per_sample[0].keypaths()
+    if isinstance(predictions_per_sample[0], NDict):
+        keys = predictions_per_sample[0].keypaths()
+    else:  # dict
+        keys = predictions_per_sample[0].keys()
+
     for key in keys:
         values[key] = [elem[key] for elem in predictions_per_sample]
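The branch is needed because the optimized uncollate (see fuse/utils/data/collate.py below) can now return plain dicts whose keys are already flat, while NDict inputs still expose nested values through keypaths(). A small sketch of the two cases:

    from fuse.utils import NDict

    nested = NDict({"model": {"output": 0.7}})
    flat = {"model.output": 0.7}

    # both spellings list the same flat key names
    assert nested.keypaths() == ["model.output"]
    assert list(flat.keys()) == ["model.output"]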
4 changes: 2 additions & 2 deletions fuse/eval/examples/examples_segmentation.py
@@ -19,7 +19,7 @@
 from pycocotools.coco import COCO
 import numpy as np
 import nibabel as nib
-from fuse.utils import set_seed
+from fuse.utils import set_seed, NDict
 
 from fuse.eval.metrics.segmentation.metrics_segmentation_common import (
     MetricDice,
@@ -209,7 +209,7 @@ def data_iter():
         sample_dict["pred.array"] = np.array([(1.0, 0.0), (0.0, 1.0), (1.0, 0.0), (0.0, 1.0)])
         sample_dict["label.array"] = np.array([(1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0)])
         sample_dict["pixel_weight"] = {"1": np.array([(0.125, 0.125), (0.125, 0.125), (0.125, 0.125), (0.125, 0.125)])}
-        yield sample_dict
+        yield NDict(sample_dict)
 
     # list of metrics
     metrics = OrderedDict(
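Wrapping the yielded sample in NDict keeps the example compatible with the collector, which resolves dotted keypaths such as "pred.array". A minimal sketch of that access pattern:

    from fuse.utils import NDict

    sample = NDict({"pred.array": [1.0, 0.0]})
    print(sample["pred.array"])  # [1.0, 0.0]
    print(sample.keypaths())     # ['pred.array']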
65 changes: 45 additions & 20 deletions fuse/eval/metrics/metrics_common.py
@@ -82,19 +82,23 @@ def __init__(
         self,
         pre_collect_process_func: Optional[Callable] = None,
         post_collect_process_func: Optional[Callable] = None,
+        batch_pre_collect_process_func: Optional[Callable] = None,
         **keys_to_collect: Dict[str, str],
     ):
         """
-        :param pre_collect_process_func: Optional callable - the callable will get as an input a sample_dict and can preprocess it if required
-        :param post_collect_process_func: Optional callable - custom process func that convert the fields to be collected to the values that will actually be collected
-            the callable will get as an input the collected values of a single sample
-            and can return either a single value or dictionary. The returned values will be collected under the name "post_args"
+        :param pre_collect_process_func: Optional callable - the callable will get as an input a sample_dict and can preprocess it if required before collection.
+            Consider using batch_pre_collect_process_func instead to optimize the running time when working with a large batch size.
+        :param post_collect_process_func: Optional callable - custom process func - used to evaluate at the sample level and keep only the result.
+            Typically used in tasks such as segmentation to avoid storing all the images until the end of the epoch.
+            Can return either a single value or a dictionary. The returned values will be collected under the name "post_args"
+        :param batch_pre_collect_process_func: Optional callable - the callable will get as an input a batch_dict and can preprocess it if required before collection.
         :param keys_to_collect: specify the keys you want to collect from the source data
         """
         super().__init__()
         # store input
         self._pre_collect_process_func = pre_collect_process_func
         self._post_collect_process_func = post_collect_process_func
+        self._batch_pre_collect_process_func = batch_pre_collect_process_func
         self._keys_to_collect = copy.copy(keys_to_collect)
         self._id_keys = MetricCollector.DEFAULT_ID_KEYS
 
@@ -108,39 +112,54 @@ def collect(self, batch: Dict) -> None:
         if not isinstance(batch, NDict):
             batch = NDict(batch)
 
-        samples = uncollate(batch)
-
         # If in distributed mode (multi gpu training) we shall gather the result from all the machines to evaluate with respect to the entire batch.
         if dist.is_initialized():
             world_size = dist.get_world_size()  # num of gpus
             samples_gather = [None for rank in range(world_size)]
             # samples_gather[i] will have the 'samples' value of the i's GPU
-            dist.all_gather_object(samples_gather, samples)
+            dist.all_gather_object(samples_gather, batch)
 
             # union all the GPU's samples into one samples list
             samples = []
             for rank in range(world_size):
                 samples += samples_gather[rank]
 
-        for sample in samples:
-            sample_to_collect = {}
+        if self._pre_collect_process_func is not None or self._post_collect_process_func is not None:
+            samples = uncollate(batch)
+            for sample in samples:
 
-            if self._pre_collect_process_func is not None:
-                sample = NDict(self._pre_collect_process_func(sample))
+                if self._pre_collect_process_func is not None:
+                    sample = self._pre_collect_process_func(sample)
 
-            for name, key in self._keys_to_collect.items():
-                value = sample[key]
-                if isinstance(value, torch.Tensor):
-                    value = value.detach().cpu().numpy()
+                sample = NDict(sample)
 
-                sample_to_collect[name] = value
+                sample_to_collect = {}
+                for name, key in self._keys_to_collect.items():
+                    value = sample[key]
+                    if isinstance(value, torch.Tensor):
+                        value = value.detach().cpu().numpy()
 
-            if self._post_collect_process_func is not None:
-                sample_to_collect = {"post_args": self._post_collect_process_func(**sample_to_collect)}
+                    sample_to_collect[name] = value
 
-            # store it - assumes batch dimension? What about single sample?
-            for name in sample_to_collect:
-                self._collected_data[name].append(sample_to_collect[name])
+                if self._post_collect_process_func is not None:
+                    sample_to_collect = {"post_args": self._post_collect_process_func(**sample_to_collect)}
+
+                # store it - assumes batch dimension? What about single sample?
+                for name in sample_to_collect:
+                    self._collected_data[name].append(sample_to_collect[name])
+        else:  # work at a batch level - optimized for large batch size
+            if self._batch_pre_collect_process_func is not None:
+                batch = self._batch_pre_collect_process_func(batch)
+            batch_to_collect = {}
+            for name, key in self._keys_to_collect.items():
+                value = batch[key]
+                if isinstance(value, torch.Tensor):
+                    value = value.detach().cpu().numpy()
 
+                batch_to_collect[name] = value
+
+            # store it - assumes batch dimension? What about single sample?
+            for name in batch_to_collect:
+                self._collected_data[name].extend(batch_to_collect[name])
 
         # extract ids and store it in self._collected_ids
         ids = None
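In short, the per-sample path survives only for metrics that registered sample-level hooks; everything else now moves whole columns at once. A self-contained sketch of the two paths (stand-in code, not the real MetricCollector):

    from collections import defaultdict
    from typing import Dict, List

    def uncollate_sketch(batch: Dict[str, list]) -> List[dict]:
        n = len(next(iter(batch.values())))
        return [{k: v[i] for k, v in batch.items()} for i in range(n)]

    def collect_sketch(data, keys_to_collect, batch, per_sample_hooks=False):
        if per_sample_hooks:
            # slow path: one small dict per sample, appended one by one
            for sample in uncollate_sketch(batch):
                for name, key in keys_to_collect.items():
                    data[name].append(sample[key])
        else:
            # fast path (this commit): one extend per key, no per-sample dicts
            for name, key in keys_to_collect.items():
                data[name].extend(batch[key])

    data = defaultdict(list)
    collect_sketch(data, {"pred": "model.output"}, {"model.output": [0.2, 0.9]})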
@@ -257,6 +276,7 @@ def __init__(
         self,
         pre_collect_process_func: Optional[Callable] = None,
         post_collect_process_func: Optional[Callable] = None,
+        batch_pre_collect_process_func: Optional[Callable] = None,
         external_data_collector: Optional[MetricCollector] = None,
         extract_ids: bool = False,
         **kwargs,
@@ -276,7 +296,12 @@ def __init__(
         }
         self._value_args = {n: k for n, k in kwargs.items() if k is not None and not isinstance(k, str)}
         self._collector = (
-            MetricCollector(pre_collect_process_func, post_collect_process_func, **self._keys_to_collect)
+            MetricCollector(
+                pre_collect_process_func,
+                post_collect_process_func,
+                batch_pre_collect_process_func,
+                **self._keys_to_collect,
+            )
             if external_data_collector is None
             else external_data_collector
        )
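A hedged construction example for the new argument (the key names are illustrative, and the collector is shown standalone for brevity):

    import torch
    from fuse.eval.metrics.metrics_common import MetricCollector

    # run argmax once per batch instead of once per sample
    def to_class_ids(batch):
        batch["preds"] = batch["model.logits"].argmax(dim=1)
        return batch

    collector = MetricCollector(
        batch_pre_collect_process_func=to_class_ids,
        pred="preds",
        target="data.label",
    )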
62 changes: 26 additions & 36 deletions fuse/utils/data/collate.py
@@ -24,8 +24,6 @@
 import torch
 import numpy as np
 
-import math
-
 
 class CollateToBatchList(Callable):
     """
@@ -36,15 +34,13 @@ def __init__(
         self,
         skip_keys: Sequence[str] = tuple(),
         raise_error_key_missing: bool = True,
-        missing_values: Sequence[str] = (None, "N/A"),
     ):
         """
         :param skip_keys: do not collect the listed keys
         :param raise_error_key_missing: if False, will not raise an error if there are keys that do not exist in some of the samples. Instead will set those values to None.
         """
         self._skip_keys = skip_keys
         self._raise_error_key_missing = raise_error_key_missing
-        self._missing_values = missing_values
 
     def __call__(self, samples: List[Dict]) -> Dict:
         """
@@ -98,22 +94,15 @@ def _collect_values_to_list(self, samples: List[str], key: str) -> Tuple[List, bool, bool]:
         has_missing_values = False
         collected_values = []
         for index, sample in enumerate(samples):
-            sample = NDict(sample)
-            if key not in sample:
+            try:
+                value = sample[key]
+            except:
                 has_error = True
                 has_missing_values = True
                 if self._raise_error_key_missing:
                     raise Exception(f"Error: key {key} does not exist in sample {index}: {sample}")
                 else:
                     value = None
-            else:
-                value = sample[key]
-            if isinstance(value, float) and math.isnan(value):
-                has_missing_values = True
-            for missing_value in self._missing_values:
-                if type(missing_value) is type(value) and value == missing_value:
-                    has_missing_values = True
-                    break
 
             collected_values.append(value)
         return collected_values, has_error, has_missing_values
@@ -124,49 +113,50 @@ def uncollate(batch: Dict) -> List[Dict]:
     Reverse collate method
     Gets a batch_dict and convert it back to list of samples
     """
-    samples = []
-    if not isinstance(batch, NDict):
-        batch = NDict(batch)
-    keys = batch.keypaths()
-
     # empty batch
-    if not keys:
-        return samples
+    if not batch.keys():
+        return []
 
-    if "data.sample_id" in keys:
+    if isinstance(batch, NDict):
+        batch = batch.flatten()
+
+    # infer batch size
+    if "data.sample_id" in batch:
         batch_size = len(batch["data.sample_id"])
     else:
         batch_size = None
-
-    if batch_size is None:
+        keys = batch.keys()
 
         for key in keys:
             if isinstance(batch[key], torch.Tensor):
                 batch_size = len(batch[key])
                 break
-
-    if batch_size is None:
-        for key in keys:
-            if isinstance(batch[key], (np.ndarray, list)):
-                batch_size = len(batch[key])
-                break
+        if batch_size is None:
+            for key in keys:
+                if isinstance(batch[key], (np.ndarray, list)):
+                    batch_size = len(batch[key])
+                    break
 
     if batch_size is None:
         return batch  # assuming batch dict with no samples
 
-    for sample_index in range(batch_size):
-        sample = NDict()
-        for key in keys:
-            if isinstance(batch[key], (np.ndarray, torch.Tensor, list)):
+    keys = batch.keys()
+    samples = [{} for _ in range(batch_size)]
+    for key in keys:
+        values = batch[key]
+        for sample_index in range(batch_size):
+
+            if isinstance(values, (np.ndarray, torch.Tensor, list)):
                 try:
-                    sample[key] = batch[key][sample_index]
+                    samples[sample_index][key] = values[sample_index]
                 except IndexError:
                     logging.error(
                         f"Error - IndexError - key={key}, batch_size={batch_size}, type={type((batch[key]))}, len={len(batch[key])}"
                     )
                     raise
             else:
-                sample[key] = batch[key]  # broadcast single value for all batch
-
-        samples.append(sample)
+                samples[sample_index][key] = values  # broadcast single value for all batch
 
     return samples
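The rewrite swaps the loop nesting (keys outer, samples inner), preallocates plain dicts, and reads each batch value once per key rather than once per key per sample. Expected behavior, assuming the signature shown above:

    import numpy as np
    from fuse.utils.data.collate import uncollate

    batch = {
        "data.sample_id": ["a", "b"],
        "model.output": np.array([0.2, 0.9]),
        "epoch": 3,  # non-indexable value: broadcast to every sample
    }
    print(uncollate(batch))
    # [{'data.sample_id': 'a', 'model.output': 0.2, 'epoch': 3},
    #  {'data.sample_id': 'b', 'model.output': 0.9, 'epoch': 3}]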
44 changes: 31 additions & 13 deletions fuse/utils/ndict.py
@@ -111,23 +111,41 @@ def flatten(self) -> dict:
         #you can use it to get a list of the flat keys:
         print(nx.flatten().keys())
         """
+        flat_dict = {}
+        NDict._flatten_static(self._stored, None, flat_dict)
+        return flat_dict
 
-        all_keys = {}
-        for key in self._stored:
-            if isinstance(self._stored[key], MutableMapping):
-                all_sub_keys = NDict(self[key]).flatten()
-                keys_to_add = {f"{key}.{sub_key}": all_sub_keys[sub_key] for sub_key in all_sub_keys}
-                all_keys.update(keys_to_add)
-            else:
-                all_keys[key] = self._stored[key]
-
-        return all_keys
+    @staticmethod
+    def _flatten_static(item: Union[dict, Any], prefix: str, flat_dict: dict) -> None:
+        if isinstance(item, MutableMapping):
+            for key, value in item.items():
+                if prefix is None:
+                    cur_prefix = key
+                else:
+                    cur_prefix = f"{prefix}.{key}"
+                NDict._flatten_static(value, cur_prefix, flat_dict)
+        else:
+            flat_dict[prefix] = item
 
     def keypaths(self) -> List[str]:
         """
         returns a list of keypaths (i.e. "a.b.c.d") to all values in the nested dict
         """
-        return list(self.flatten().keys())
+        return NDict._keypaths_static(self._stored, None)
+
+    @staticmethod
+    def _keypaths_static(item: Union[dict, Any], prefix: str) -> List[str]:
+        if isinstance(item, MutableMapping):
+            keys = []
+            for key, value in item.items():
+                if prefix is None:
+                    cur_prefix = key
+                else:
+                    cur_prefix = f"{prefix}.{key}"
+                keys += NDict._keypaths_static(value, cur_prefix)
+            return keys
+        else:
+            return [prefix]
 
     def keys(self) -> dict_keys:
         """
@@ -228,7 +246,7 @@ def pop(self, key: str) -> Any:
         del self[key]
         return res
 
-    def indices(self, indices: Union[torch.Tensor, numpy.ndarray]) -> dict:
+    def indices(self, indices: numpy.ndarray) -> dict:
         """
         Extract the specified indices from each element in the dictionary (if possible)
         :param nested_dict: input dict
@@ -240,7 +258,7 @@ def indices(self, indices: numpy.ndarray) -> dict:
         for key in all_keys:
             try:
                 value = self[key]
-                if isinstance(value, numpy.ndarray) or isinstance(value, torch.Tensor):
+                if isinstance(value, (numpy.ndarray, torch.Tensor)):
                     new_value = value[indices]
                 elif isinstance(value, Sequence):
                     new_value = [item for i, item in enumerate(value) if indices[i]]
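The narrowed signature reflects how indices is used: as a boolean mask applied per value (note the Sequence branch filtering with indices[i]). A minimal sketch under that assumption:

    import numpy
    from fuse.utils import NDict

    n = NDict({"x": numpy.array([10, 20, 30]), "names": ["a", "b", "c"]})
    mask = numpy.array([True, False, True])
    print(n.indices(mask))  # e.g. {'x': array([10, 30]), 'names': ['a', 'c']}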
