Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions self_supervised_pretraining/ssl_finetuning_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,15 @@ def main():
vit_dict = torch.load(pretrained_path)
vit_weights = vit_dict['state_dict']

# Delete the following variable names conv3d_transpose.weight, conv3d_transpose.bias,
# conv3d_transpose_1.weight, conv3d_transpose_1.bias as they were used to match dimensions
# while pretraining with ViTAutoEnc and are not a part of ViT backbone (this is used in UNETR)
vit_weights.pop('conv3d_transpose_1.bias')
vit_weights.pop('conv3d_transpose_1.weight')
vit_weights.pop('conv3d_transpose.bias')
vit_weights.pop('conv3d_transpose.weight')

model.vit.load_state_dict(vit_weights)
# Remove items of vit_weights if they are not in the ViT backbone (this is used in UNETR).
# For example, some variable names like conv3d_transpose.weight, conv3d_transpose.bias,
# conv3d_transpose_1.weight and conv3d_transpose_1.bias are used to match dimensions
# while pretraining with ViTAutoEnc and are not a part of ViT backbone.
model_dict = model.vit.state_dict()
vit_weights = {k: v for k, v in vit_weights.items() if k in model_dict}
model_dict.update(vit_weights)
model.vit.load_state_dict(model_dict)
del model_dict, vit_weights, vit_dict
print('Pretrained Weights Succesfully Loaded !')

elif use_pretrained==0:
Expand Down