New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
optimized train time for a use case of small samples and large batch #268
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,8 +24,6 @@ | |
import torch | ||
import numpy as np | ||
|
||
import math | ||
|
||
|
||
class CollateToBatchList(Callable): | ||
""" | ||
|
@@ -36,15 +34,13 @@ def __init__( | |
self, | ||
skip_keys: Sequence[str] = tuple(), | ||
raise_error_key_missing: bool = True, | ||
missing_values: Sequence[str] = (None, "N/A"), | ||
): | ||
""" | ||
:param skip_keys: do not collect the listed keys | ||
:param raise_error_key_missing: if False, will not raise an error if there are keys that do not exist in some of the samples. Instead will set those values to None. | ||
""" | ||
self._skip_keys = skip_keys | ||
self._raise_error_key_missing = raise_error_key_missing | ||
self._missing_values = missing_values | ||
|
||
def __call__(self, samples: List[Dict]) -> Dict: | ||
""" | ||
|
@@ -98,22 +94,15 @@ def _collect_values_to_list(self, samples: List[str], key: str) -> Tuple[List, b | |
has_missing_values = False | ||
collected_values = [] | ||
for index, sample in enumerate(samples): | ||
sample = NDict(sample) | ||
if key not in sample: | ||
try: | ||
value = sample[key] | ||
except: | ||
has_error = True | ||
has_missing_values = True | ||
if self._raise_error_key_missing: | ||
raise Exception(f"Error: key {key} does not exist in sample {index}: {sample}") | ||
else: | ||
value = None | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Removing this to optimize the running time. |
||
value = sample[key] | ||
if isinstance(value, float) and math.isnan(value): | ||
has_missing_values = True | ||
for missing_value in self._missing_values: | ||
if type(missing_value) is type(value) and value == missing_value: | ||
has_missing_values = True | ||
break | ||
|
||
collected_values.append(value) | ||
return collected_values, has_error, has_missing_values | ||
|
@@ -124,49 +113,50 @@ def uncollate(batch: Dict) -> List[Dict]: | |
Reverse collate method | ||
Gets a batch_dict and convert it back to list of samples | ||
""" | ||
samples = [] | ||
if not isinstance(batch, NDict): | ||
batch = NDict(batch) | ||
keys = batch.keypaths() | ||
|
||
# empty batch | ||
if not keys: | ||
return samples | ||
if not batch.keys(): | ||
return [] | ||
|
||
if "data.sample_id" in keys: | ||
if isinstance(batch, NDict): | ||
batch = batch.flatten() | ||
|
||
# infer batch size | ||
if "data.sample_id" in batch: | ||
batch_size = len(batch["data.sample_id"]) | ||
else: | ||
batch_size = None | ||
|
||
if batch_size is None: | ||
keys = batch.keys() | ||
|
||
for key in keys: | ||
if isinstance(batch[key], torch.Tensor): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are there two different loops — one for each case? - There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Because first I want to look for a tensor (I trust it more), and if I can't find one then my second choice is (np.ndarray, list). |
||
batch_size = len(batch[key]) | ||
break | ||
|
||
if batch_size is None: | ||
for key in keys: | ||
if isinstance(batch[key], (np.ndarray, list)): | ||
batch_size = len(batch[key]) | ||
break | ||
if batch_size is None: | ||
for key in keys: | ||
if isinstance(batch[key], (np.ndarray, list)): | ||
batch_size = len(batch[key]) | ||
break | ||
|
||
if batch_size is None: | ||
return batch # assuming batch dict with no samples | ||
|
||
for sample_index in range(batch_size): | ||
sample = NDict() | ||
for key in keys: | ||
if isinstance(batch[key], (np.ndarray, torch.Tensor, list)): | ||
keys = batch.keys() | ||
samples = [{} for _ in range(batch_size)] | ||
for key in keys: | ||
values = batch[key] | ||
for sample_index in range(batch_size): | ||
|
||
if isinstance(values, (np.ndarray, torch.Tensor, list)): | ||
try: | ||
sample[key] = batch[key][sample_index] | ||
samples[sample_index][key] = values[sample_index] | ||
except IndexError: | ||
logging.error( | ||
f"Error - IndexError - key={key}, batch_size={batch_size}, type={type((batch[key]))}, len={len(batch[key])}" | ||
) | ||
raise | ||
else: | ||
sample[key] = batch[key] # broadcast single value for all batch | ||
|
||
samples.append(sample) | ||
samples[sample_index][key] = values # broadcast single value for all batch | ||
|
||
return samples |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -111,23 +111,41 @@ def flatten(self) -> dict: | |
#you can use it to get a list of the flat keys: | ||
print(nx.flatten().keys()) | ||
""" | ||
flat_dict = {} | ||
NDict._flatten_static(self._stored, None, flat_dict) | ||
return flat_dict | ||
|
||
all_keys = {} | ||
for key in self._stored: | ||
if isinstance(self._stored[key], MutableMapping): | ||
all_sub_keys = NDict(self[key]).flatten() | ||
keys_to_add = {f"{key}.{sub_key}": all_sub_keys[sub_key] for sub_key in all_sub_keys} | ||
all_keys.update(keys_to_add) | ||
else: | ||
all_keys[key] = self._stored[key] | ||
|
||
return all_keys | ||
@staticmethod | ||
def _flatten_static(item: Union[dict, Any], prefix: str, flat_dict: dict) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. COOL! |
||
if isinstance(item, MutableMapping): | ||
for key, value in item.items(): | ||
if prefix is None: | ||
cur_prefix = key | ||
else: | ||
cur_prefix = f"{prefix}.{key}" | ||
NDict._flatten_static(value, cur_prefix, flat_dict) | ||
else: | ||
flat_dict[prefix] = item | ||
|
||
def keypaths(self) -> List[str]: | ||
""" | ||
returns a list of keypaths (i.e. "a.b.c.d") to all values in the nested dict | ||
""" | ||
return list(self.flatten().keys()) | ||
return NDict._keypaths_static(self._stored, None) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why not use the same paradigm as before? Just calling flatten(). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Mostly because I don't want the overhead of creating a dictionary and then extracting the keys. | ||
||
|
||
@staticmethod | ||
def _keypaths_static(item: Union[dict, Any], prefix: str) -> List[str]: | ||
if isinstance(item, dict): | ||
keys = [] | ||
for key, value in item.items(): | ||
if prefix is None: | ||
cur_prefix = key | ||
else: | ||
cur_prefix = f"{prefix}.{key}" | ||
keys += NDict._keypaths_static(value, cur_prefix) | ||
return keys | ||
else: | ||
return [prefix] | ||
|
||
def keys(self) -> dict_keys: | ||
""" | ||
|
@@ -228,7 +246,7 @@ def pop(self, key: str) -> Any: | |
del self[key] | ||
return res | ||
|
||
def indices(self, indices: Union[torch.Tensor, numpy.ndarray]) -> dict: | ||
def indices(self, indices: Optional[numpy.ndarray]) -> dict: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Self-review: will remove the Optional here. |
||
""" | ||
Extract the specified indices from each element in the dictionary (if possible) | ||
:param nested_dict: input dict | ||
|
@@ -240,7 +258,7 @@ def indices(self, indices: Union[torch.Tensor, numpy.ndarray]) -> dict: | |
for key in all_keys: | ||
try: | ||
value = self[key] | ||
if isinstance(value, numpy.ndarray) or isinstance(value, torch.Tensor): | ||
if isinstance(value, (numpy.ndarray, torch.Tensor)): | ||
new_value = value[indices] | ||
elif isinstance(value, Sequence): | ||
new_value = [item for i, item in enumerate(value) if indices[i]] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the description we say "missing keep_keys are skipped", but I think that now we won't do that.
Could it be an issue if a user specifies keeping a key that doesn't exist?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! Now we will throw an error in such a case. I will update the comment.