# Filtering out unnessary state_dict keys

In [6]:
import torch
from dynamic_network_architectures.architectures.unet import PlainConvUNetWithClassification

# Load the checkpoint
checkpoint_path = "C:/Users/Admin/nnUNet/nnUNet_results/Dataset001_Pancreas/nnUNetMultiTrainer_v2__nnUNetPlans__3d_fullres/fold_0/checkpoint_final.pth"
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Extract the state_dict
state_dict = checkpoint["network_weights"]

# Get the model's expected keys
model = PlainConvUNetWithClassification(    
    input_channels=1,
    n_stages=6,
    features_per_stage=(32, 64, 128, 256, 320, 320),
    conv_op=torch.nn.Conv3d,
    kernel_sizes=(3, 3, 3, 3, 3, 3),
    strides=(1, 2, 2, 2, 2, 2),
    n_conv_per_stage=(2, 2, 2, 2, 2, 2),
    num_classes=2,  # Adjust based on your segmentation classes
    n_conv_per_stage_decoder=(2, 2, 2, 2, 2),
    conv_bias=True,
    norm_op=torch.nn.InstanceNorm3d,
    norm_op_kwargs={"eps": 1e-5, "affine": True},
    dropout_op=None,
    dropout_op_kwargs=None,
    nonlin=torch.nn.LeakyReLU,
    nonlin_kwargs={"inplace": True},
    deep_supervision=False,
    classification_head={
        "n_classes": 3,  # Adjust based on your classification task
        "input_features": 320,
    },)  # Initialize the model
expected_keys = set(model.state_dict().keys())
print(f"Expected keys: {expected_keys}")
# Filter the state_dict
filtered_state_dict = {k: v for k, v in state_dict.items() if k in expected_keys}


  checkpoint = torch.load(checkpoint_path, map_location="cpu")


Expected keys: {'encoder.stages.5.0.convs.1.all_modules.1.bias', 'decoder.transpconvs.4.bias', 'encoder.stages.3.0.convs.0.conv.weight', 'encoder.stages.0.0.convs.0.all_modules.0.weight', 'encoder.stages.0.0.convs.1.conv.bias', 'encoder.stages.5.0.convs.1.norm.weight', 'decoder.stages.4.convs.1.all_modules.1.weight', 'decoder.stages.3.convs.1.conv.weight', 'decoder.encoder.stages.2.0.convs.1.all_modules.1.bias', 'decoder.encoder.stages.1.0.convs.1.conv.bias', 'decoder.stages.4.convs.1.all_modules.1.bias', 'decoder.encoder.stages.0.0.convs.0.all_modules.0.bias', 'decoder.encoder.stages.3.0.convs.0.norm.bias', 'decoder.seg_layers.1.weight', 'decoder.encoder.stages.4.0.convs.0.all_modules.1.weight', 'decoder.stages.2.convs.0.conv.bias', 'encoder.stages.3.0.convs.0.all_modules.1.weight', 'decoder.encoder.stages.1.0.convs.1.all_modules.1.weight', 'decoder.stages.4.convs.0.conv.weight', 'encoder.stages.1.0.convs.1.norm.weight', 'decoder.encoder.stages.0.0.convs.0.norm.bias', 'decoder.encoder

In [7]:
new_checkpoint_path = "C:/Users/Admin/nnUNet/nnUNet_results/Dataset001_Pancreas/nnUNetMultiTrainer_v2__nnUNetPlans__3d_fullres/fold_0/checkpoint_final_filtered.pth"
torch.save({
    "network_weights": filtered_state_dict,
    "optimizer_state": checkpoint["optimizer_state"],
    "grad_scaler_state": checkpoint["grad_scaler_state"],
    "logging": checkpoint["logging"],
    "_best_ema": checkpoint["_best_ema"],
    "current_epoch": checkpoint["current_epoch"],
    "init_args": checkpoint["init_args"],
    "trainer_name": checkpoint["trainer_name"],
    "inference_allowed_mirroring_axes": checkpoint["inference_allowed_mirroring_axes"],
}, new_checkpoint_path)

print(f"Filtered checkpoint saved at: {new_checkpoint_path}")


Filtered checkpoint saved at: C:/Users/Admin/nnUNet/nnUNet_results/Dataset001_Pancreas/nnUNetMultiTrainer_v2__nnUNetPlans__3d_fullres/fold_0/checkpoint_final_filtered.pth


In [5]:
print(filtered_state_dict.keys())

dict_keys(['encoder.stages.0.0.convs.0.conv.weight', 'encoder.stages.0.0.convs.0.conv.bias', 'encoder.stages.0.0.convs.0.norm.weight', 'encoder.stages.0.0.convs.0.norm.bias', 'encoder.stages.0.0.convs.0.all_modules.0.weight', 'encoder.stages.0.0.convs.0.all_modules.0.bias', 'encoder.stages.0.0.convs.0.all_modules.1.weight', 'encoder.stages.0.0.convs.0.all_modules.1.bias', 'encoder.stages.0.0.convs.1.conv.weight', 'encoder.stages.0.0.convs.1.conv.bias', 'encoder.stages.0.0.convs.1.norm.weight', 'encoder.stages.0.0.convs.1.norm.bias', 'encoder.stages.0.0.convs.1.all_modules.0.weight', 'encoder.stages.0.0.convs.1.all_modules.0.bias', 'encoder.stages.0.0.convs.1.all_modules.1.weight', 'encoder.stages.0.0.convs.1.all_modules.1.bias', 'encoder.stages.1.0.convs.0.conv.weight', 'encoder.stages.1.0.convs.0.conv.bias', 'encoder.stages.1.0.convs.0.norm.weight', 'encoder.stages.1.0.convs.0.norm.bias', 'encoder.stages.1.0.convs.0.all_modules.0.weight', 'encoder.stages.1.0.convs.0.all_modules.0.bias