optimized train time for a use case of small samples and large batch #268
Conversation
has_error = True
has_missing_values = True
if self._raise_error_key_missing:
    raise Exception(f"Error: key {key} does not exist in sample {index}: {sample}")
else:
    value = None
else:
Removing this to optimize the running time.
Detecting NaNs and similar checks could move to an optional op (in the data pipeline).
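A minimal sketch of what such an optional pipeline op could look like, so the check is only paid for when a user opts in. The class name `OpDetectNaN` and the sample-dict interface are illustrative assumptions, not FuseMedML's actual op API:

```python
import numpy as np

# Hypothetical sketch: NaN detection as an opt-in pipeline op rather than a
# check inside the collate hot path. OpDetectNaN and its call signature are
# assumptions for illustration only.
class OpDetectNaN:
    def __init__(self, keys):
        self._keys = keys  # keys to validate in each sample

    def __call__(self, sample: dict) -> dict:
        for key in self._keys:
            value = sample.get(key)
            if isinstance(value, np.ndarray) and np.isnan(value).any():
                raise ValueError(f"NaN detected in key '{key}'")
        return sample  # pass the sample through unchanged when clean

op = OpDetectNaN(keys=["data.input"])
clean = op({"data.input": np.array([1.0, 2.0])})  # no NaNs: returned as-is
```

Users who don't add the op to their pipeline pay nothing, which matches the goal of keeping the collate path fast.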
fuse/utils/ndict.py
Outdated
@@ -228,7 +246,7 @@ def pop(self, key: str) -> Any:
         del self[key]
         return res

-    def indices(self, indices: Union[torch.Tensor, numpy.ndarray]) -> dict:
+    def indices(self, indices: Optional[numpy.ndarray]) -> dict:
Self review: I will remove the Optional here.
Thanks!!
I added a few comments inline to consider :)
@@ -68,9 +68,11 @@ def __call__(self, samples: List[Dict]) -> Dict:
         batch_dict = NDict()

         # collect all keys
         keys = self._collect_all_keys(samples)
         if self._keep_keys:
In the description we say: "missing keep_keys are skipped." I think that now we won't do that.
Could it be an issue if a user specifies keeping a key that doesn't exist?
Good catch! Now we will throw an error in such a case. I will update the comment.
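To illustrate the behavior agreed on here, a small sketch of a key-collection step that raises when a requested `keep_keys` entry is absent from every sample, instead of silently skipping it. The function name `collect_keys` and its signature are assumptions, not the PR's actual helper:

```python
# Hypothetical sketch of the discussed behavior: a keep_keys entry that no
# sample contains raises instead of being silently skipped.
def collect_keys(samples, keep_keys=None):
    all_keys = set()
    for sample in samples:
        all_keys.update(sample.keys())
    if keep_keys is not None:
        # fail loudly on keys the user asked for but no sample provides
        missing = [k for k in keep_keys if k not in all_keys]
        if missing:
            raise KeyError(f"keep_keys not found in any sample: {missing}")
        return list(keep_keys)
    return sorted(all_keys)

samples = [{"a": 1, "b": 2}, {"a": 3}]
collect_keys(samples, keep_keys=["a"])  # -> ["a"]
```

Raising early surfaces a likely typo in the user's configuration instead of producing a batch that quietly lacks the key.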
            batch_size = len(batch["data.sample_id"])
        else:
            batch_size = None

        if batch_size is None:
            keys = batch.keys()

            for key in keys:
                if isinstance(batch[key], torch.Tensor):
Why are there two different loops, one for each case: torch.Tensor and (np.ndarray, list)?
If I'm not missing something, we can check for the two cases in the same loop.
Because first I want to look for a tensor (I trust it more), and only if I can't find one is my second choice (np.ndarray, list).
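The two-pass preference described here can be sketched as follows. To keep the sketch dependency-light, NumPy arrays stand in for `torch.Tensor` as the "trusted" type; the structure mirrors the reviewer's explanation, not the PR's actual implementation:

```python
import numpy as np

# Sketch of two-pass batch-size inference: scan once for the most trusted
# value type, and only fall back to weaker candidates if none is found.
# np.ndarray plays the role of torch.Tensor in this illustration.
def infer_batch_size(batch: dict):
    # first pass: trusted type
    for value in batch.values():
        if isinstance(value, np.ndarray):
            return len(value)
    # second pass: fallback types, checked only when no array was found
    for value in batch.values():
        if isinstance(value, (list, tuple)):
            return len(value)
    return None

batch = {"meta": "x", "ids": [1, 2, 3], "data": np.zeros((4, 2))}
infer_batch_size(batch)  # -> 4: the array wins even though the list comes first
```

A single loop checking both types at once would return whichever matching value it encounters first, which is exactly the ordering guarantee the two-pass version avoids giving up.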
        return all_keys

    @staticmethod
    def _flatten_static(item: Union[dict, Any], prefix: str, flat_dict: dict) -> None:
COOL!
     def keypaths(self) -> List[str]:
         """
         returns a list of keypaths (i.e. "a.b.c.d") to all values in the nested dict
         """
-        return list(self.flatten().keys())
+        return NDict._keypaths_static(self._stored, None)
Why not use the same paradigm as before, just calling flatten()?
The two static functions have a lot in common.
Mostly because I don't want the overhead of creating a dictionary and then extracting the keys.
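The idea being defended here, collecting keypaths recursively without materializing the intermediate flattened dict, might look roughly like this. The function name and the `"."` separator follow the docstring's `"a.b.c.d"` convention; the body is an assumption, not the PR's `_keypaths_static`:

```python
# Sketch of recursive keypath collection that appends paths directly to a
# list, skipping the flatten()-then-keys() round trip. Illustrative only.
def keypaths_static(item, prefix=None):
    paths = []
    for key, value in item.items():
        path = key if prefix is None else f"{prefix}.{key}"
        if isinstance(value, dict):
            # recurse into nested dicts, extending the dotted prefix
            paths.extend(keypaths_static(value, path))
        else:
            paths.append(path)
    return paths

nested = {"a": {"b": {"c": 1}, "d": 2}, "e": 3}
keypaths_static(nested)  # -> ["a.b.c", "a.d", "e"]
```

The saving is modest per call, but for large batches assembled key by key it avoids allocating one throwaway dict per lookup.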
Thanks again!
I added comments in a different review 😄
(I used the VS Code interface, and because of the changes in the middle it didn't allow me to approve.)
This nicely optimizes the train running time in such a use case (small samples and large batch size).
Leaving ndict optimization for the future.