Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NDict optimization #271

Merged
merged 22 commits into from Feb 14, 2023
Merged

NDict optimization #271

merged 22 commits into from Feb 14, 2023

Conversation

SagiPolaczek
Copy link
Collaborator

@SagiPolaczek SagiPolaczek commented Feb 7, 2023

✅ Ready for review

Profiling

EHR Transformer (5 epochs)

Not Optimized

ncalls  tottime  percall  cumtime  percall filename:lineno(function)

2565    0.003    0.000    0.004    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:150(keys)
2565    0.004    0.000    0.005    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:159(items)
15390    0.043    0.000    0.117    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:223(get_closest_key)
15390    0.021    0.000    0.138    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:294(__contains__)
61560/2565    0.066    0.000    0.185    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:118(_flatten_static)
2565    0.008    0.000    0.193    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:100(flatten)
53905/2575    0.088    0.000    0.203    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:136(_keypaths_static)
2575    0.018    0.000    0.221    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:130(keypaths)
412420    0.497    0.000    1.916    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:172(__getitem__)
2565    4.561    0.002    4.561    0.002 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:264(<listcomp>)
2565    2.039    0.001    7.170    0.003 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:249(indices)
6275340/6272775    4.483    0.000   16.573    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:191(__setitem__)
330290    2.766    0.000   20.023    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:63(__init__)
Optimized
ncalls  tottime  percall  cumtime  percall filename:lineno(function)

       10    0.000    0.000    0.000    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:106(keys)
     2565    0.004    0.000    0.005    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:115(items)
    12835    0.024    0.000    0.027    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:100(keypaths)
     5130    0.026    0.000    0.049    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:149(is_prefix)
     5130    0.019    0.000    0.051    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:224(get_closest_key)
     5130    0.005    0.000    0.056    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:303(__contains__)
     5140    0.022    0.000    0.075    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:47(__init__)
     2565    0.026    0.000    0.076    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:172(get_sub_dict)
153960/143700    0.080    0.000    0.174    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:129(__getitem__)
    56450    0.047    0.000    0.275    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:193(__setitem__)
     5130    8.242    0.002    8.242    0.002 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:267(<listcomp>)
     5130    3.152    0.001   11.579    0.002 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:251(indices)

ISIC (2 epochs)

Not Optimized
ncalls  tottime  percall  cumtime  percall filename:lineno(function)

75/5    0.000    0.000    0.000    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:136(_keypaths_static)
5    0.000    0.000    0.000    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:130(keypaths)
4184    0.006    0.000    0.007    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:159(items)
25104    0.053    0.000    0.173    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:223(get_closest_key)
25104    0.024    0.000    0.197    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:294(__contains__)
65560    0.148    0.000    0.617    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:172(__getitem__)
4189    0.035    0.000    0.709    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:63(__init__)
179337/29303    0.350    0.000    0.960    0.000 /dccstor/mm_hcls/usr/sagi/fuse_2/fuse/utils/ndict.py:191(__setitem__)
Optimized
ncalls  tottime  percall  cumtime  percall filename:lineno(function)

5    0.000    0.000    0.000    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:106(keys)
4184    0.006    0.000    0.007    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:115(items)
20925    0.031    0.000    0.035    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:100(keypaths)
4184    0.042    0.000    0.110    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:172(get_sub_dict)
90664/82296    0.048    0.000    0.191    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:129(__getitem__)
16736    0.105    0.000    0.198    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:149(is_prefix)
25104    0.056    0.000    0.228    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:224(get_closest_key)
25104    0.017    0.000    0.245    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:303(__contains__)
8373    0.070    0.000    0.368    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:47(__init__)
133313/133263    0.125    0.000    0.500    0.000 /dccstor/mm_hcls/usr/sagi/fuse_3/fuse/utils/ndict.py:193(__setitem__)

@SagiPolaczek SagiPolaczek added the enhancement New feature or request label Feb 7, 2023
@SagiPolaczek SagiPolaczek marked this pull request as draft February 7, 2023 19:01

# convert continuous measurements to categorical ones based
# on defined bins mapping static clinical characteristics
# (Age, Gender, ICU type, Height, etc)
for k in sample_dict["StaticDetails"]:
sample_dict["StaticDetails"][k] = k + "_" + str(np.digitize(sample_dict["StaticDetails"][k], bins[k]))
sample_dict[f"StaticDetails.{k}"] = k + "_" + str(np.digitize(sample_dict["StaticDetails"][k], bins[k]))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A small fix to match the new NDict impl.

Since we now don't return a "real" nested dict, we can't change a returned sub-dict and expect the changes to be reflected in the original dictionary.

This kind of fix might be needed to be applied on other projects that are not covered by the CI tests.

train_metrics["gender_auc"] = Filter(
MetricAUCROC(pred="model.output.gender", target="Gender"),
"filter",
pre_collect_process_func=filter_gender_label_unknown_for_metric,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self-reminder:

open an issue on the pre_collect_process_func that cause to disproportionate amount of NDict's __init__ calls.

for key in keys:
if isinstance(batch[key], torch.Tensor):
for key in batch.keys():
if isinstance(batch[key], (torch.Tensor, np.ndarray, list)):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mosheraboh

I didn't understand why we first check for Tensors and ndarrays separately so I changed it..

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was safer, but let's change and see if someone encounter an issue

return self._stored[key]

# the key is a prefix for other value(s)
elif self.is_prefix(key): # TODO can be more optimized. we pass here once and in the "get_sub_dict" once again
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please see TODO comment.

In my opinion is it enough to leave it like that for now (readability & code reuse VS optimization)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove is_prefix.
Instead return None is get_subdict

def __delitem__(self, key: str) -> None:
"""
:param key:
TODO should we delete both value and prefix ?
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mosheraboh

What do you think?

Currently we delete the value (and not the subdict!)
if the value doesn't exists but the subdict does, we delete the subdict.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ans (talked offline):
delete both

@SagiPolaczek
Copy link
Collaborator Author

Added a self-CR.

@SagiPolaczek SagiPolaczek marked this pull request as ready for review February 14, 2023 08:51
for key in keys:
if isinstance(batch[key], torch.Tensor):
for key in batch.keys():
if isinstance(batch[key], (torch.Tensor, np.ndarray, list)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was safer, but let's change and see if someone encounter an issue

in deep copy, all values are copied recursively
:param deepcopy: if true, does deep copy, otherwise does shalow copy

:param deepcopy: if true, does deep copy, otherwise does a shallow copy
"""
if not deepcopy:
return NDict(copy.copy(self._stored))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

already_flat=True

NDict._flatten_static(value, cur_prefix, flat_dict)
else:
flat_dict[prefix] = item
return self._stored
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return self instead

@@ -163,83 +123,139 @@ def merge(self, other: dict) -> NDict:
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change to get NDict input

self[k] = v

return
return self
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove - and change signature

return self._stored[key]

# the key is a prefix for other value(s)
elif self.is_prefix(key): # TODO can be more optimized. we pass here once and in the "get_sub_dict" once again
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove is_prefix.
Instead return None is get_subdict

suffix_key = None
for kk in self.keypaths():
if kk.startswith(prefix_key):
suffix_key = kk.replace(prefix_key, "", 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kk[len(prefix_key):]

res[suffix_key] = self[kk]

if suffix_key is None and key not in self:
raise NestedKeyError(key, self)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return None

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need for key not in self

# set the value
element[nested_key[-1]] = value
# delete entire branch
elif self.is_prefix(key):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

redundant

if key in self._stored:
return key

key_parts = key.split(".")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use similarity between strings

mosheraboh
mosheraboh previously approved these changes Feb 14, 2023
Copy link
Collaborator

@mosheraboh mosheraboh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! 🚀


def keypaths(self) -> List[str]:
def keypaths(self) -> dict_keys:
"""
:return: a list of keypaths (i.e. "a.b.c.d") to all values in the nested dict
"""
return list(self._stored.keys())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the list

"""
return self.keypaths()

def top_level_keys(self) -> dict_keys:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-> List[str]

@SagiPolaczek SagiPolaczek merged commit b01394f into master Feb 14, 2023
@SagiPolaczek SagiPolaczek deleted the sagi/ndict_opt branch March 8, 2023 17:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants