
Initializing weights from checkpoint with configs - init_from_ptl_ckpt #4805

Closed
stalevna opened this issue Aug 25, 2022 · 1 comment
Labels: bug (Something isn't working)

@stalevna

Describe the bug
Hi! When I try to initialize model weights from a PyTorch Lightning checkpoint via configs, excluding some parts of the state dict, I get an error that `model_path` is not defined.
I looked at the code in https://github.com/NVIDIA/NeMo/blob/main/nemo/core/classes/modelPT.py, and indeed this branch initializes a variable named `ckpt_path`; `model_path` is only initialized for `init_from_nemo_model`.

Could you check whether this was an accidental copy-paste and `ckpt_path` should be used instead of `model_path` here:

    self.load_part_of_state_dict(
        ckpt['state_dict'], include, exclude, f'nemo file with path `{model_path}`'
    )

which appears in this snippet of the code:

    if 'init_from_ptl_ckpt' in cfg and cfg.init_from_ptl_ckpt is not None:
        with open_dict(cfg):
            if isinstance(cfg.init_from_ptl_ckpt, str):
                # Restore checkpoint
                ckpt_path = cfg.pop('init_from_ptl_ckpt')
                ckpt = torch.load(ckpt_path, map_location=map_location)

                # Restore checkpoint into current model
                self.load_state_dict(ckpt['state_dict'], strict=False)
                logging.info(
                    f'Model checkpoint restored from pytorch lightning chackpoint with path : `{ckpt_path}`'
                )

                del ckpt
            elif isinstance(cfg.init_from_ptl_ckpt, (DictConfig, dict)):
                model_load_dict = cfg.init_from_ptl_ckpt
                for model_load_cfg in model_load_dict.values():
                    ckpt_path = model_load_cfg.path
                    # Restore model
                    ckpt = torch.load(ckpt_path, map_location=map_location)

                    include = model_load_cfg.pop('include', [""])
                    exclude = model_load_cfg.pop('exclude', [])

                    self.load_part_of_state_dict(
                        ckpt['state_dict'], include, exclude, f'nemo file with path `{model_path}`'
                    )
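To illustrate the report, here is a minimal standalone sketch of the suspected bug, using hypothetical stand-ins (`load_part_of_state_dict` here is a simplified placeholder, not the real NeMo method, and no torch/NeMo imports are needed): in the dict-config branch only `ckpt_path` is assigned, so the f-string referencing `model_path` raises `NameError`, while substituting `ckpt_path` works.

```python
# Hypothetical, simplified reproduction of the scoping issue; not the real NeMo API.

def load_part_of_state_dict(state_dict, include, exclude, load_from_string):
    # Stand-in for ModelPT.load_part_of_state_dict: just report the source.
    print(f"loaded from {load_from_string}")

def restore_buggy(model_load_cfg):
    ckpt_path = model_load_cfg["path"]
    ckpt = {"state_dict": {}}  # stand-in for torch.load(ckpt_path)
    # Bug: `model_path` is only defined in the init_from_nemo_model branch,
    # so evaluating this f-string raises NameError.
    load_part_of_state_dict(
        ckpt["state_dict"], [""], [], f"nemo file with path `{model_path}`"
    )

def restore_fixed(model_load_cfg):
    ckpt_path = model_load_cfg["path"]
    ckpt = {"state_dict": {}}  # stand-in for torch.load(ckpt_path)
    # Suspected fix: use the variable that actually holds the checkpoint path.
    load_part_of_state_dict(
        ckpt["state_dict"], [""], [], f"checkpoint file with path `{ckpt_path}`"
    )

try:
    restore_buggy({"path": "model.ckpt"})
except NameError as err:
    print(f"buggy branch fails: {err}")

restore_fixed({"path": "model.ckpt"})
```

Running the sketch shows the buggy branch failing with `NameError: name 'model_path' is not defined` while the fixed branch completes normally.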
@stalevna stalevna added the bug Something isn't working label Aug 25, 2022
@titu1994 (Collaborator)

This seems to be a bug. @sam1373, could you send a PR to the R1.11.0 branch, please?
