# GBT-net Model Loading Notebook

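This notebook demonstrates how to load the pretrained GBT-net checkpoint into two models: a SwinUNETR segmentation model, and a `SwinTransformer_new` classification model that reuses the checkpoint's `swinViT` backbone weights.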

In [None]:
import copy
import warnings

import torch
from monai.networks.nets import swin_unetr
from monai.networks.nets.swin_unetr import SwinUNETR
from gbt_net import SwinTransformer_new, PatchMergingV3

## Utility Function
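Define a helper that renames checkpoint state-dict keys by replacing one substring with another. By default, it strips the `module.` prefix that `torch.nn.DataParallel` adds to parameter names.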

In [None]:
def change_weight_key(state_dict: dict, source_key: str = 'module.', target_key: str = '') -> dict:
    """Rename state-dict keys in place by substituting ``source_key`` with ``target_key``.

    Every key containing ``source_key`` is re-inserted under the name produced by
    ``str.replace`` (all occurrences are substituted); other keys are untouched.
    Renamed entries move to the end of the dict's insertion order.

    Args:
        state_dict: Mapping of parameter names to values; modified in place.
        source_key: Substring to search for in each key (default ``'module.'``).
        target_key: Replacement text (default ``''``, i.e. delete ``source_key``).

    Returns:
        The same ``state_dict`` object, so the call can also be used as an expression.

    Warns:
        UserWarning: if a renamed key collides with an entry that already exists,
            which would otherwise silently replace that entry's value.
    """
    # Snapshot matching keys first: the dict is mutated while renaming.
    matching_keys = [key for key in state_dict if source_key in key]
    if not matching_keys:
        print(f"Tag '{source_key}' not found in state dict - nothing to change")
        return state_dict

    print(f"Tag '{source_key}' found in {len(matching_keys)} keys - replacing with '{target_key}'")
    for key in matching_keys:
        new_key = key.replace(source_key, target_key)
        if new_key != key and new_key in state_dict:
            warnings.warn(f"Renaming '{key}' to '{new_key}' overwrites an existing entry")
        state_dict[new_key] = state_dict.pop(key)
    return state_dict

## Load Checkpoint
Load the pretrained checkpoint file. Replace the placeholder path below with the location of your downloaded checkpoint.

In [None]:
# Location of the downloaded GBT-net checkpoint -- edit this to point at your copy.
CHECKPOINT_PATH = "/path/to/your/directory/gbt_net_checkpoint.pth"

# map_location="cpu" lets a checkpoint saved from GPU tensors load on a CPU-only
# machine; load_state_dict later copies the values onto each model's own device.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources
# (recent PyTorch versions also accept weights_only=True to restrict unpickling).
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")

## Initialize models and load state dicts

In [None]:
# segmentation model
# Register the GBT-net patch-merging layer under the name "mergingv3" so that
# downsample="mergingv3" below can be resolved. Adding a key (instead of
# replacing the whole MERGING_MODE dict) keeps any merging modes MONAI already
# registered usable elsewhere in this kernel, and is idempotent on re-run.
swin_unetr.MERGING_MODE["mergingv3"] = PatchMergingV3

segmentation_model = SwinUNETR(
    img_size=(128, 128, 128),
    in_channels=1,
    out_channels=1,
    feature_size=48,
    downsample="mergingv3",
    use_v2=True,
)

# strict=False tolerates name mismatches, so report them explicitly:
# both lists are empty when every checkpoint key matches this model.
missing_keys, unexpected_keys = segmentation_model.load_state_dict(checkpoint, strict=False)
print("segmentation missing keys:", missing_keys)
print("segmentation unexpected keys:", unexpected_keys)

In [None]:
# classification model
classification_model = SwinTransformer_new(
    in_chans=1,
    embed_dim=48,
    window_size=(7, 7, 7),
    patch_size=(2, 2, 2),
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    downsample="mergingv3",
    use_v2=True,
)

# The classifier's weights live under the `swinViT.` prefix in the checkpoint,
# so strip that prefix before loading. change_weight_key mutates its argument
# in place, so rename a copy: renaming `checkpoint` itself would make the
# segmentation cell silently load nothing (strict=False) if it is re-run.
# copy.copy keeps the dict type and any attributes (e.g. a state dict's _metadata).
classification_state_dict = copy.copy(checkpoint)
change_weight_key(classification_state_dict, 'swinViT.', '')

# strict=False: keys outside the `swinViT.` backbone (presumably the rest of the
# SwinUNETR weights) are expected to appear as unexpected keys here.
missing_keys, unexpected_keys = classification_model.load_state_dict(
    classification_state_dict, strict=False
)
print("classification missing keys:", missing_keys)
print("classification unexpected keys:", unexpected_keys)