
Commit def5b9e
Fix initializing weights from ptl ckpt with exclude (#4807)
Signed-off-by: sam1373 <samuelkriman@gmail.com>

sam1373 authored and ericharper committed Sep 9, 2022
1 parent db2d95a commit def5b9e
Showing 1 changed file with 14 additions and 8 deletions.
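The scenario this commit fixes is partial initialization from a PyTorch Lightning checkpoint when an `exclude` list is given. A minimal, hypothetical config sketch for that case is shown below; the entry name, checkpoint path, and patterns are illustrative, assuming per-model entries with `path`, `include`, and `exclude` keys as consumed by `ModelPT.maybe_init_from_pretrained_checkpoint`.

```python
# Hypothetical usage sketch (names and paths are illustrative): partially
# initialize a ModelPT subclass from a PTL checkpoint, excluding some weights.
from omegaconf import OmegaConf

init_cfg = OmegaConf.create(
    {
        "init_from_ptl_ckpt": {
            "pretrained_encoder": {
                "path": "/path/to/checkpoint.ckpt",  # illustrative checkpoint path
                "include": [""],  # empty substring matches every parameter name
                "exclude": ["decoder"],  # skip parameters whose names contain "decoder"
            }
        }
    }
)

# `model` would be an instance of a ModelPT subclass, e.g.:
# model.maybe_init_from_pretrained_checkpoint(init_cfg)
```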
nemo/core/classes/modelPT.py: 22 changes (14 additions & 8 deletions)
@@ -948,7 +948,7 @@ def get_test_dataloader_prefix(self, dataloader_idx: int = 0) -> str:
"""
return self._test_names[dataloader_idx]

def load_part_of_state_dict(self, state_dict, include, exclude, load_from_string):
def load_part_of_state_dict(self, state_dict, include, exclude, load_from_string=None):

excluded_param_names = []
# create dict
@@ -971,12 +971,18 @@ def load_part_of_state_dict(self, state_dict, include, exclude, load_from_string

         # Restore checkpoint part into current model
         self.load_state_dict(dict_to_load, strict=False)
-        logging.info(f'Model checkpoint partially restored from {load_from_string}')
-        if len(excluded_param_names) > 0:
-            logging.info(
-                f'The following parameters were excluded from loading from {load_from_string} : {excluded_param_names}'
-            )
-            logging.info(f'Make sure that this is what you wanted!')
+        if load_from_string is not None:
+            logging.info(f'Model checkpoint partially restored from {load_from_string}')
+            if len(excluded_param_names) > 0:
+                logging.info(
+                    f'The following parameters were excluded when loading from {load_from_string} : {excluded_param_names}'
+                )
+                logging.info(f'Make sure that this is what you wanted!')
+        else:
+            if len(excluded_param_names) > 0:
+                logging.info(
+                    f'The following parameters were excluded when loading checkpoint : {excluded_param_names}'
+                )
 
     @rank_zero_only
     def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: str = 'cpu'):
@@ -1149,7 +1155,7 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: str = 'cpu'):
                         exclude = model_load_cfg.pop('exclude', [])
 
                         self.load_part_of_state_dict(
-                            ckpt['state_dict'], include, exclude, f'nemo file with path `{model_path}`'
+                            ckpt['state_dict'], include, exclude, f'nemo file with path `{ckpt_path}`'
                         )
 
                         del ckpt
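As a reference for the include/exclude selection that `load_part_of_state_dict` performs before calling `load_state_dict(..., strict=False)`, the standalone sketch below assumes substring matching on parameter names; it illustrates the filtering idea only and is not the NeMo implementation.

```python
# Standalone sketch of include/exclude filtering over a state dict, assuming
# substring matching on parameter names; not the NeMo implementation.
from typing import Dict, List, Tuple

import torch


def filter_state_dict(
    state_dict: Dict[str, torch.Tensor], include: List[str], exclude: List[str]
) -> Tuple[Dict[str, torch.Tensor], List[str]]:
    dict_to_load: Dict[str, torch.Tensor] = {}
    excluded_param_names: List[str] = []
    for name, tensor in state_dict.items():
        # Keep a parameter if its name contains any include substring ...
        keep = any(sub in name for sub in include)
        # ... unless its name also contains an exclude substring.
        if keep and any(sub in name for sub in exclude):
            excluded_param_names.append(name)
            keep = False
        if keep:
            dict_to_load[name] = tensor
    return dict_to_load, excluded_param_names


# Example: keep everything except decoder weights.
full = {"encoder.weight": torch.zeros(2), "decoder.weight": torch.ones(2)}
kept, dropped = filter_state_dict(full, include=[""], exclude=["decoder"])
assert list(kept) == ["encoder.weight"]
assert dropped == ["decoder.weight"]
```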
