From 3a7a35f33b192c8589e800de517e7ae8bbdeb41c Mon Sep 17 00:00:00 2001
From: Yiheng Wang
Date: Wed, 2 Mar 2022 07:52:00 +0800
Subject: [PATCH 1/2] fix load pretrain

Signed-off-by: Yiheng Wang
---
 .../ssl_finetuning_train.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/self_supervised_pretraining/ssl_finetuning_train.py b/self_supervised_pretraining/ssl_finetuning_train.py
index 51dfeed150..ceb2e81bd2 100644
--- a/self_supervised_pretraining/ssl_finetuning_train.py
+++ b/self_supervised_pretraining/ssl_finetuning_train.py
@@ -190,15 +190,15 @@ def main():
         vit_dict = torch.load(pretrained_path)
         vit_weights = vit_dict['state_dict']
 
-        # Delete the following variable names conv3d_transpose.weight, conv3d_transpose.bias,
-        # conv3d_transpose_1.weight, conv3d_transpose_1.bias as they were used to match dimensions
-        # while pretraining with ViTAutoEnc and are not a part of ViT backbone (this is used in UNETR)
-        vit_weights.pop('conv3d_transpose_1.bias')
-        vit_weights.pop('conv3d_transpose_1.weight')
-        vit_weights.pop('conv3d_transpose.bias')
-        vit_weights.pop('conv3d_transpose.weight')
-
-        model.vit.load_state_dict(vit_weights)
+        # Remove items of vit_weights if they are not in the ViT backbone (this is used in UNETR).
+        # For example, variable names like conv3d_transpose.weight, conv3d_transpose.bias,
+        # conv3d_transpose_1.weight and conv3d_transpose_1.bias are used to match dimensions
+        # while pretraining with ViTAutoEnc and are not a part of the ViT backbone.
+        model_dict = model.vit.state_dict()
+        vit_weights = {k: v for k, v in vit_weights.items() if k in model_dict}
+        model_dict.update(vit_weights)
+        model.vit.load_state_dict(model_dict)
+        del vit_weights, vit_dict
         print('Pretrained Weights Succesfully Loaded !')
 
     elif use_pretrained==0:

From 9b67f683a1dc205363a658afd38add523954c9e7 Mon Sep 17 00:00:00 2001
From: Yiheng Wang
Date: Wed, 2 Mar 2022 08:02:28 +0800
Subject: [PATCH 2/2] del loaded weights

Signed-off-by: Yiheng Wang
---
 self_supervised_pretraining/ssl_finetuning_train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/self_supervised_pretraining/ssl_finetuning_train.py b/self_supervised_pretraining/ssl_finetuning_train.py
index ceb2e81bd2..d4da2a1da9 100644
--- a/self_supervised_pretraining/ssl_finetuning_train.py
+++ b/self_supervised_pretraining/ssl_finetuning_train.py
@@ -198,7 +198,7 @@ def main():
         vit_weights = {k: v for k, v in vit_weights.items() if k in model_dict}
         model_dict.update(vit_weights)
         model.vit.load_state_dict(model_dict)
-        del vit_weights, vit_dict
+        del model_dict, vit_weights, vit_dict
         print('Pretrained Weights Succesfully Loaded !')
 
     elif use_pretrained==0:
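
For reference, a minimal standalone sketch of the filtered state-dict loading pattern that these two patches introduce. The UNETR constructor arguments and the checkpoint path are illustrative placeholders, not values taken from the tutorial script.

# Sketch only: the model arguments and checkpoint path below are hypothetical.
import torch
from monai.networks.nets import UNETR

model = UNETR(in_channels=1, out_channels=14, img_size=(96, 96, 96))

# Load the self-supervised ViTAutoEnc checkpoint.
vit_dict = torch.load("./pretrained_vit.pt")  # hypothetical path
vit_weights = vit_dict["state_dict"]

# Keep only the keys that also exist in the ViT backbone of UNETR;
# decoder-only layers such as conv3d_transpose* are dropped by the filter.
model_dict = model.vit.state_dict()
vit_weights = {k: v for k, v in vit_weights.items() if k in model_dict}
model_dict.update(vit_weights)
model.vit.load_state_dict(model_dict)

# Release the checkpoint tensors once they have been copied into the model.
del model_dict, vit_weights, vit_dict

A related alternative is model.vit.load_state_dict(vit_weights, strict=False), which also ignores the unexpected decoder keys, but it silently skips missing keys as well; the explicit filter above keeps the intent visible.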