diff --git a/self_supervised_pretraining/ssl_finetuning_train.py b/self_supervised_pretraining/ssl_finetuning_train.py
index 51dfeed150..d4da2a1da9 100644
--- a/self_supervised_pretraining/ssl_finetuning_train.py
+++ b/self_supervised_pretraining/ssl_finetuning_train.py
@@ -190,15 +190,15 @@ def main():
         vit_dict = torch.load(pretrained_path)
         vit_weights = vit_dict['state_dict']
 
-        # Delete the following variable names conv3d_transpose.weight, conv3d_transpose.bias,
-        # conv3d_transpose_1.weight, conv3d_transpose_1.bias as they were used to match dimensions
-        # while pretraining with ViTAutoEnc and are not a part of ViT backbone (this is used in UNETR)
-        vit_weights.pop('conv3d_transpose_1.bias')
-        vit_weights.pop('conv3d_transpose_1.weight')
-        vit_weights.pop('conv3d_transpose.bias')
-        vit_weights.pop('conv3d_transpose.weight')
-
-        model.vit.load_state_dict(vit_weights)
+        # Remove items of vit_weights if they are not in the ViT backbone (this is used in UNETR).
+        # For example, some variable names like conv3d_transpose.weight, conv3d_transpose.bias,
+        # conv3d_transpose_1.weight and conv3d_transpose_1.bias are used to match dimensions
+        # while pretraining with ViTAutoEnc and are not a part of the ViT backbone.
+        model_dict = model.vit.state_dict()
+        vit_weights = {k: v for k, v in vit_weights.items() if k in model_dict}
+        model_dict.update(vit_weights)
+        model.vit.load_state_dict(model_dict)
+        del model_dict, vit_weights, vit_dict
         print('Pretrained Weights Succesfully Loaded !')
 
     elif use_pretrained==0:
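
For reference, the new hunk replaces hard-coded `pop()` calls with the common "filter by the target state_dict keys" pattern, so any checkpoint entries that do not belong to the ViT backbone are dropped automatically. Below is a minimal, self-contained sketch of that pattern; the toy `Backbone` and `PretrainedNet` modules are hypothetical stand-ins for `model.vit` and the ViTAutoEnc checkpoint, not code from this PR.

```python
# Minimal sketch of the filtered state-dict loading pattern used in the hunk above.
# Backbone/PretrainedNet are illustrative stand-ins, not part of the PR.
import torch.nn as nn


class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)  # present in both networks


class PretrainedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)            # shared with the backbone
        self.conv3d_transpose = nn.Linear(4, 4)  # decoder-only, not in the backbone


backbone, pretrained = Backbone(), PretrainedNet()
ckpt_weights = pretrained.state_dict()

# Keep only keys that exist in the backbone, then load the merged dict.
model_dict = backbone.state_dict()
filtered = {k: v for k, v in ckpt_weights.items() if k in model_dict}
model_dict.update(filtered)
backbone.load_state_dict(model_dict)
print('loaded keys:', sorted(filtered))
```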