optimized train time for a use case of small samples and large batch #268

Merged · 6 commits · Feb 1, 2023
6 changes: 4 additions & 2 deletions fuse/data/utils/collates.py
@@ -68,9 +68,11 @@ def __call__(self, samples: List[Dict]) -> Dict:
batch_dict = NDict()

# collect all keys
keys = self._collect_all_keys(samples)
if self._keep_keys:
Collaborator:

In the description we say: "missing keep_keys are skipped." I think that now we won't do that. Could it be an issue if a user specifies a key to keep that doesn't exist?

Collaborator Author:

Good catch. Now we will throw an error in such a case. I will update the comment.

keys = [k for k in keys if k in self._keep_keys]
keys = self._keep_keys
else:
keys = self._collect_all_keys(samples)

# collect values
for key in keys:

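A minimal sketch of the behavior change discussed in the thread above (assuming the collate class in this file is CollateDefault and that it accepts a keep_keys argument; the raise-on-missing-key behavior is what the author describes, not verified against the merged code):

from fuse.data.utils.collates import CollateDefault

# Hypothetical samples; "label" is missing from the second one.
samples = [
    {"data.sample_id": 0, "img": [1, 2], "label": 1},
    {"data.sample_id": 1, "img": [3, 4]},
]

# Before: keys were collected from the samples and filtered by keep_keys,
# so a keep_key absent from the samples was silently skipped.
# After: keep_keys is used directly, so a missing key is expected to raise.
collate_fn = CollateDefault(keep_keys=["data.sample_id", "img", "label"])
batch = collate_fn(samples)  # expected to raise for the missing "label"
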
6 changes: 5 additions & 1 deletion fuse/dl/lightning/pl_funcs.py
@@ -124,7 +124,11 @@ def convert_predictions_to_dataframe(predictions: List[NDict]) -> pd.DataFrame:
predictions_per_sample = []
for elem in predictions:
predictions_per_sample += uncollate(elem)
keys = predictions_per_sample[0].keypaths()
if isinstance(predictions_per_sample[0], NDict):
keys = predictions_per_sample[0].keypaths()
else: # dict
keys = predictions_per_sample[0].keys()

for key in keys:
values[key] = [elem[key] for elem in predictions_per_sample]

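A small illustration of why the isinstance branch above is needed: uncollate can yield plain dicts as well as NDicts, and only NDict exposes keypaths(). The helper below is hypothetical, just to show the two cases:

from fuse.utils.ndict import NDict

def sample_keys(sample):
    # NDict samples expose nested keypaths ("a.b.c"); plain dicts only
    # expose their top-level keys, hence the branch on the type.
    if isinstance(sample, NDict):
        return list(sample.keypaths())
    return list(sample.keys())

print(sample_keys(NDict({"a": {"b": 1}})))  # ['a.b']
print(sample_keys({"a.b": 1}))              # ['a.b']
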
63 changes: 43 additions & 20 deletions fuse/eval/metrics/metrics_common.py
@@ -82,19 +82,23 @@ def __init__(
self,
pre_collect_process_func: Optional[Callable] = None,
post_collect_process_func: Optional[Callable] = None,
batch_pre_collect_process_func: Optional[Callable] = None,
**keys_to_collect: Dict[str, str],
):
"""
:param pre_collect_process_func: Optional callable - the callable will get as an input a sample_dict and can preprocess it if required
:param post_collect_process_func: Optional callable - custom process func that convert the fields to be collected to the values that will actually be collected
the callable will get as an input the collected values of a single sample
and can return either a single value or dictionary. The returned values will be collected under the name "post_args"
:param pre_collect_process_func: Optional callable - gets a sample_dict as input and can preprocess it if required before collection.
Consider using batch_pre_collect_process_func instead to optimize the running time when working with large batch sizes.
:param post_collect_process_func: Optional callable - custom process func used to evaluate at the sample level and keep only the result.
Typically used in tasks such as segmentation to avoid storing all the images until the end of the epoch.
Can return either a single value or a dictionary. The returned values will be collected under the name "post_args".
:param batch_pre_collect_process_func: Optional callable - gets a batch_dict as input and can preprocess it if required before collection.
:param keys_to_collect: specify the keys you want to collect from the source data
"""
super().__init__()
# store input
self._pre_collect_process_func = pre_collect_process_func
self._post_collect_process_func = post_collect_process_func
self._batch_pre_collect_process_func = batch_pre_collect_process_func
self._keys_to_collect = copy.copy(keys_to_collect)
self._id_keys = MetricCollector.DEFAULT_ID_KEYS

@@ -108,39 +112,52 @@ def collect(self, batch: Dict) -> None:
if not isinstance(batch, NDict):
batch = NDict(batch)

samples = uncollate(batch)

# If in distributed mode (multi-GPU training), we gather the results from all machines to evaluate with respect to the entire batch.
if dist.is_initialized():
world_size = dist.get_world_size() # num of gpus
samples_gather = [None for rank in range(world_size)]
# samples_gather[i] will hold the value gathered from the i-th GPU
dist.all_gather_object(samples_gather, samples)
dist.all_gather_object(samples_gather, batch)

# union all the GPUs' samples into one samples list
samples = []
for rank in range(world_size):
samples += samples_gather[rank]

for sample in samples:
sample_to_collect = {}
if self._pre_collect_process_func is not None or self._post_collect_process_func is not None:
samples = uncollate(batch)
for sample in samples:

if self._pre_collect_process_func is not None:
sample = NDict(self._pre_collect_process_func(sample))

sample_to_collect = {}
for name, key in self._keys_to_collect.items():
value = sample[key]
if isinstance(value, torch.Tensor):
value = value.detach().cpu().numpy()

if self._pre_collect_process_func is not None:
sample = NDict(self._pre_collect_process_func(sample))
sample_to_collect[name] = value

if self._post_collect_process_func is not None:
sample_to_collect = {"post_args": self._post_collect_process_func(**sample_to_collect)}

# store it - assumes batch dimension? What about single sample?
for name in sample_to_collect:
self._collected_data[name].append(sample_to_collect[name])
else: # work in a batch level - optimized for large batch size
if self._batch_pre_collect_process_func is not None:
batch = self._batch_pre_collect_process_func(batch)
batch_to_collect = {}
for name, key in self._keys_to_collect.items():
value = sample[key]
value = batch[key]
if isinstance(value, torch.Tensor):
value = value.detach().cpu().numpy()

sample_to_collect[name] = value

if self._post_collect_process_func is not None:
sample_to_collect = {"post_args": self._post_collect_process_func(**sample_to_collect)}
batch_to_collect[name] = value

# store it - assumes batch dimension? What about single sample?
for name in sample_to_collect:
self._collected_data[name].append(sample_to_collect[name])
for name in batch_to_collect:
self._collected_data[name].extend(batch_to_collect[name])

# extract ids and store it in self._collected_ids
ids = None
@@ -257,6 +274,7 @@ def __init__(
self,
pre_collect_process_func: Optional[Callable] = None,
post_collect_process_func: Optional[Callable] = None,
batch_pre_collect_process_func: Optional[Callable] = None,
external_data_collector: Optional[MetricCollector] = None,
extract_ids: bool = False,
**kwargs,
Expand All @@ -276,7 +294,12 @@ def __init__(
}
self._value_args = {n: k for n, k in kwargs.items() if k is not None and not isinstance(k, str)}
self._collector = (
MetricCollector(pre_collect_process_func, post_collect_process_func, **self._keys_to_collect)
MetricCollector(
pre_collect_process_func,
post_collect_process_func,
batch_pre_collect_process_func,
**self._keys_to_collect,
)
if external_data_collector is None
else external_data_collector
)
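
A hedged usage sketch of the new batch-level hook: instead of preprocessing each sample after uncollate, the whole batch_dict is transformed once, which is the optimization this PR targets. The key names and the direct use of MetricCollector below are illustrative:

import torch
from fuse.eval.metrics.metrics_common import MetricCollector

def keep_probabilities(batch):
    # Runs once per batch (not once per sample): reduce raw logits to the
    # values the metric actually needs before collection.
    batch["model.probs"] = torch.softmax(batch["model.logits"], dim=1)
    return batch

collector = MetricCollector(
    batch_pre_collect_process_func=keep_probabilities,
    pred="model.probs",   # collected under the name "pred"
    target="data.label",  # collected under the name "target"
)
collector.collect({
    "model.logits": torch.randn(4, 2),
    "data.label": torch.tensor([0, 1, 1, 0]),
    "data.sample_id": list(range(4)),
})
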
62 changes: 26 additions & 36 deletions fuse/utils/data/collate.py
@@ -24,8 +24,6 @@
import torch
import numpy as np

import math


class CollateToBatchList(Callable):
"""
@@ -36,15 +34,13 @@ def __init__(
self,
skip_keys: Sequence[str] = tuple(),
raise_error_key_missing: bool = True,
missing_values: Sequence[Any] = (None, "N/A"),
):
"""
:param skip_keys: do not collect the listed keys
:param raise_error_key_missing: if False, will not raise an error if there are keys that do not exist in some of the samples. Instead will set those values to None.
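:param missing_values: values treated as missing (matched by type and equality); encountering one sets the returned has_missing_values flag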
"""
self._skip_keys = skip_keys
self._raise_error_key_missing = raise_error_key_missing
self._missing_values = missing_values

def __call__(self, samples: List[Dict]) -> Dict:
"""
@@ -98,22 +94,15 @@ def _collect_values_to_list(self, samples: List[str], key: str) -> Tuple[List, bool, bool]:
has_missing_values = False
collected_values = []
for index, sample in enumerate(samples):
sample = NDict(sample)
if key not in sample:
try:
value = sample[key]
except:
has_error = True
has_missing_values = True
if self._raise_error_key_missing:
raise Exception(f"Error: key {key} does not exist in sample {index}: {sample}")
else:
value = None
else:
Collaborator Author:

Removing this to optimize the running time. Detecting NaNs and similar checks could move to an optional op (in the data pipeline).

value = sample[key]
if isinstance(value, float) and math.isnan(value):
has_missing_values = True
for missing_value in self._missing_values:
if type(missing_value) is type(value) and value == missing_value:
has_missing_values = True
break

collected_values.append(value)
return collected_values, has_error, has_missing_values
@@ -124,49 +113,50 @@ def uncollate(batch: Dict) -> List[Dict]:
Reverse collate method
Gets a batch_dict and convert it back to list of samples
"""
samples = []
if not isinstance(batch, NDict):
batch = NDict(batch)
keys = batch.keypaths()

# empty batch
if not keys:
return samples
if not batch.keys():
return []

if "data.sample_id" in keys:
if isinstance(batch, NDict):
batch = batch.flatten()

# infer batch size
if "data.sample_id" in batch:
batch_size = len(batch["data.sample_id"])
else:
batch_size = None

if batch_size is None:
keys = batch.keys()

for key in keys:
if isinstance(batch[key], torch.Tensor):
Collaborator:

Why are there two different loops, one for each case (torch.Tensor vs. np.ndarray/list)? If I'm not missing something, we can check for both cases in the same loop.

Collaborator Author:

Because first I want to look for a tensor (I trust it more), and if I can't find one, my second choice is (np.ndarray, list).

batch_size = len(batch[key])
break

if batch_size is None:
for key in keys:
if isinstance(batch[key], (np.ndarray, list)):
batch_size = len(batch[key])
break
if batch_size is None:
for key in keys:
if isinstance(batch[key], (np.ndarray, list)):
batch_size = len(batch[key])
break

if batch_size is None:
return batch # assuming batch dict with no samples

for sample_index in range(batch_size):
sample = NDict()
for key in keys:
if isinstance(batch[key], (np.ndarray, torch.Tensor, list)):
keys = batch.keys()
samples = [{} for _ in range(batch_size)]
for key in keys:
values = batch[key]
for sample_index in range(batch_size):

if isinstance(values, (np.ndarray, torch.Tensor, list)):
try:
sample[key] = batch[key][sample_index]
samples[sample_index][key] = values[sample_index]
except IndexError:
logging.error(
f"Error - IndexError - key={key}, batch_size={batch_size}, type={type((batch[key]))}, len={len(batch[key])}"
)
raise
else:
sample[key] = batch[key] # broadcast single value for all batch

samples.append(sample)
samples[sample_index][key] = values # broadcast single value for all batch

return samples
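
A short sketch of the rewritten uncollate, including the broadcast path for non-indexable values (values are illustrative):

import numpy as np
from fuse.utils.data.collate import uncollate

batch = {
    "data.sample_id": [0, 1],
    "model.output": np.array([0.2, 0.8]),
    "epoch": 3,  # single value: broadcast to every sample
}
print(uncollate(batch))
# [{'data.sample_id': 0, 'model.output': 0.2, 'epoch': 3},
#  {'data.sample_id': 1, 'model.output': 0.8, 'epoch': 3}]
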
44 changes: 31 additions & 13 deletions fuse/utils/ndict.py
@@ -111,23 +111,41 @@ def flatten(self) -> dict:
# you can use it to get a list of the flat keys:
print(nx.flatten().keys())
"""
flat_dict = {}
NDict._flatten_static(self._stored, None, flat_dict)
return flat_dict

all_keys = {}
for key in self._stored:
if isinstance(self._stored[key], MutableMapping):
all_sub_keys = NDict(self[key]).flatten()
keys_to_add = {f"{key}.{sub_key}": all_sub_keys[sub_key] for sub_key in all_sub_keys}
all_keys.update(keys_to_add)
else:
all_keys[key] = self._stored[key]

return all_keys
@staticmethod
def _flatten_static(item: Union[dict, Any], prefix: str, flat_dict: dict) -> None:
Collaborator:

COOL!

if isinstance(item, MutableMapping):
for key, value in item.items():
if prefix is None:
cur_prefix = key
else:
cur_prefix = f"{prefix}.{key}"
NDict._flatten_static(value, cur_prefix, flat_dict)
else:
flat_dict[prefix] = item

def keypaths(self) -> List[str]:
"""
returns a list of keypaths (e.g. "a.b.c.d") to all values in the nested dict
"""
return list(self.flatten().keys())
return NDict._keypaths_static(self._stored, None)
Collaborator:

Why not use the same paradigm as before and just call flatten()? The two static functions have a lot in common.

Collaborator Author:

Mostly because I don't want the overhead of creating a dictionary and then extracting the keys.


@staticmethod
def _keypaths_static(item: Union[dict, Any], prefix: str) -> List[str]:
if isinstance(item, dict):
keys = []
for key, value in item.items():
if prefix is None:
cur_prefix = key
else:
cur_prefix = f"{prefix}.{key}"
keys += NDict._keypaths_static(value, cur_prefix)
return keys
else:
return [prefix]
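
For reference, a quick sketch of the two helpers on toy data (the keypaths variant avoids building the intermediate flat dict, per the thread above):

from fuse.utils.ndict import NDict

nd = NDict({"a": {"b": 1, "c": {"d": 2}}, "e": 3})
print(nd.flatten())   # {'a.b': 1, 'a.c.d': 2, 'e': 3}
print(nd.keypaths())  # ['a.b', 'a.c.d', 'e']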

def keys(self) -> dict_keys:
"""
@@ -228,7 +246,7 @@ def pop(self, key: str) -> Any:
del self[key]
return res

def indices(self, indices: Union[torch.Tensor, numpy.ndarray]) -> dict:
def indices(self, indices: Optional[numpy.ndarray]) -> dict:
Collaborator Author:

Self review: I will remove the Optional here.

"""
Extract the specified indices from each element in the dictionary (if possible)
:param indices: the indices to extract (e.g., a boolean mask) from each entry
@@ -240,7 +258,7 @@ def indices(self, indices: Union[torch.Tensor, numpy.ndarray]) -> dict:
for key in all_keys:
try:
value = self[key]
if isinstance(value, numpy.ndarray) or isinstance(value, torch.Tensor):
if isinstance(value, (numpy.ndarray, torch.Tensor)):
new_value = value[indices]
elif isinstance(value, Sequence):
new_value = [item for i, item in enumerate(value) if indices[i]]
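
Finally, a usage sketch of indices with a boolean mask, matching the ndarray and Sequence branches above (toy values):

import numpy as np
from fuse.utils.ndict import NDict

nd = NDict({"scores": np.array([0.1, 0.9, 0.4]), "names": ["a", "b", "c"]})
mask = np.array([True, False, True])
print(nd.indices(mask))
# {'scores': array([0.1, 0.4]), 'names': ['a', 'c']}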