Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix initializing weights from ptl ckpt with exclude #4807

Merged
Merged 1 commit on Aug 25, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions nemo/core/classes/modelPT.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,7 @@ def get_test_dataloader_prefix(self, dataloader_idx: int = 0) -> str:
"""
return self._test_names[dataloader_idx]

def load_part_of_state_dict(self, state_dict, include, exclude, load_from_string):
def load_part_of_state_dict(self, state_dict, include, exclude, load_from_string=None):

excluded_param_names = []
# create dict
Expand All @@ -961,12 +961,18 @@ def load_part_of_state_dict(self, state_dict, include, exclude, load_from_string

# Restore checkpoint part into current model
self.load_state_dict(dict_to_load, strict=False)
logging.info(f'Model checkpoint partially restored from {load_from_string}')
if len(excluded_param_names) > 0:
logging.info(
f'The following parameters were excluded from loading from {load_from_string} : {excluded_param_names}'
)
logging.info(f'Make sure that this is what you wanted!')
if load_from_string is not None:
logging.info(f'Model checkpoint partially restored from {load_from_string}')
if len(excluded_param_names) > 0:
logging.info(
f'The following parameters were excluded when loading from {load_from_string} : {excluded_param_names}'
)
logging.info(f'Make sure that this is what you wanted!')
else:
if len(excluded_param_names) > 0:
logging.info(
f'The following parameters were excluded when loading checkpoint : {excluded_param_names}'
)

@rank_zero_only
def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: str = 'cpu'):
Expand Down Expand Up @@ -1139,7 +1145,7 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st
exclude = model_load_cfg.pop('exclude', [])

self.load_part_of_state_dict(
ckpt['state_dict'], include, exclude, f'nemo file with path `{model_path}`'
ckpt['state_dict'], include, exclude, f'nemo file with path `{ckpt_path}`'
)

del ckpt
Expand Down