Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6268 enhance hovernet load pretrained function #6269

Merged

Conversation

yiheng-wang-nv
Copy link
Contributor

Fixes #6268 .

Description

This PR enhances Hovernet's load pretrained function.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Signed-off-by: Yiheng Wang <vennw@nvidia.com>
Copy link
Contributor

@KumoLiu KumoLiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for update!

@Nic-Ma
Copy link
Contributor

Nic-Ma commented Apr 3, 2023

Hi @yiheng-wang-nv @KumoLiu ,

I don't quite understand the use case of this PR, could you please help share some examples here?
I mainly want to ensure the bundle user experience is not affected, especially configs.

Thanks in advance.

@yiheng-wang-nv
Copy link
Contributor Author

Hi @yiheng-wang-nv @KumoLiu ,

I don't quite understand the use case of this PR, could you please help share some examples here? I mainly want to ensure the bundle user experience is not affected, especially configs.

Thanks in advance.

Hi Nic, some users are used to save not only weights, but also other information such as learning rate, epoch ... into the whole weight file, if print the keys of the file we will see something like:

dict_keys(['epoch', 'arch', 'state_dict', 'optimizer', 'version', 'args', 'amp_scaler', 'metric'])

Where, only the state_dict is the place we need.

Currently, Hovernet is designed to load standard pytorch resnet50 pretrained weights which only contain the weights information, and print the keys of the file we will see:

odict_keys(['conv1.weight', 'bn1.running_mean', 'bn1.running_var', 'bn1.weight', 'bn1.bias', 'layer1.0.conv1.weight', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.conv2.weight', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.conv3.weight', 'layer1.0.bn3.running_mean', 'layer1.0.bn3.running_var', 'layer1.0.bn3.weight', 'layer1.0.bn3.bias', 'layer1.0.downsample.0.weight' ...

@Nic-Ma
Copy link
Contributor

Nic-Ma commented Apr 3, 2023

I mean an example to show how to use this arg in a python program and a bundle config.

@yiheng-wang-nv
Copy link
Contributor Author

I see, in a bundle config we can do:

    "network_def": {
        "_target_": "HoVerNet",
        "mode": "@hovernet_mode",
        "in_channels": 3,
        "out_classes": 5,
        "adapt_standard_resnet": true,
        "pretrained_url": "$None",
        "freeze_encoder": true,
        "pretrained_state_dict_key": "state_dict"
    },

Only add this arg: "pretrained_state_dict_key"

@Nic-Ma
Copy link
Contributor

Nic-Ma commented Apr 3, 2023

@yiheng-wang-nv OK, looks good to me, please fix the CI errors.

Thanks.

@wyli wyli enabled auto-merge April 3, 2023 18:59
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
auto-merge was automatically disabled April 3, 2023 19:57

Merge queue setting changed

@wyli wyli enabled auto-merge (squash) April 3, 2023 19:58
@wyli wyli merged commit bb4df37 into Project-MONAI:dev Apr 3, 2023
25 of 30 checks passed
@yiheng-wang-nv yiheng-wang-nv deleted the 6268-enhance-hovernet-load-pretrain branch April 4, 2023 02:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Enhance Hovernet load pretrained weights functionality
4 participants