In [4]:
import torch

# Pretrained autoencoder checkpoint (absolute path specific to the author's machine).
checkpoint_path = (
    "/home/ulixes/segmentation_cv/unet/AE-pretrained/models/"
    "ae_pet_reconstruction/best_model.pth"
)

# Deserialize every tensor onto the CPU so this inspection does not need a GPU.
# NOTE(review): torch.load unpickles the file -- only open checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
print("Checkpoint keys:", checkpoint.keys())



Checkpoint keys: dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict', 'best_loss', 'config'])


In [1]:
from unet import UNet
import torch.nn as nn

# Build the 6-stage UNet (3 input channels -> 3 output classes) with
# pretrained_encoder_path left empty, purely to inspect its top-level layout.
model = UNet(
    in_channels=3,
    num_classes=3,
    n_stages=6,
    # Encoder: one entry per stage (6 stages). The first stage uses stride 1;
    # the remaining five use stride 2.
    features_per_stage=[32, 64, 128, 256, 512, 512],
    kernel_sizes=[[3, 3]] * 6,
    strides=[[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]],
    n_conv_per_stage=[2] * 6,
    # Decoder settings have one entry fewer than the encoder (5 vs. 6).
    n_conv_per_stage_decoder=[2] * 5,
    conv_bias=True,
    # Normalization, activation and dropout configuration.
    norm_op=nn.InstanceNorm2d,
    norm_op_kwargs={"eps": 1e-5, "affine": True},
    dropout_op=None,
    nonlin=nn.LeakyReLU,
    nonlin_kwargs={"inplace": True},
    encoder_dropout_rates=[0.0, 0.0, 0.1, 0.2, 0.3, 0.3],
    decoder_dropout_rates=[0.3, 0.2, 0.2, 0.1, 0.0],
    # Presumably an empty path means "skip pretrained loading" -- confirm in UNet.
    pretrained_encoder_path="",
)

# List the model's direct submodules and their types.
print("Model's children:")
for name, child in model.named_children():
    print(f"  {name}: {type(child).__name__}")


Model's children:
  encoder_stages: ModuleList
  decoder_stages: ModuleList
  segmentation_output: Conv2d


In [2]:
import torch

checkpoint_path = "/home/ulixes/segmentation_cv/unet/AE-pretrained/models/ae_pet_reconstruction/best_model.pth"
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Every top-level entry of the checkpoint dict, one per line.
print("Top-level checkpoint keys:")
for key in checkpoint:
    print(f"  {key}")

# When the saved weights are present, list their names too. This is useful for
# matching them against the UNet's own `encoder_stages.*` names.
if "model_state_dict" in checkpoint:
    print("\nKeys in model_state_dict:")
    for key in checkpoint["model_state_dict"]:
        print(f"  {key}")


Top-level checkpoint keys:
  epoch
  model_state_dict
  optimizer_state_dict
  scheduler_state_dict
  best_loss
  config

Keys in model_state_dict:
  encoder_stages.0.block.0.weight
  encoder_stages.0.block.0.bias
  encoder_stages.0.block.1.weight
  encoder_stages.0.block.1.bias
  encoder_stages.0.block.3.weight
  encoder_stages.0.block.3.bias
  encoder_stages.0.block.4.weight
  encoder_stages.0.block.4.bias
  encoder_stages.1.block.0.weight
  encoder_stages.1.block.0.bias
  encoder_stages.1.block.1.weight
  encoder_stages.1.block.1.bias
  encoder_stages.1.block.3.weight
  encoder_stages.1.block.3.bias
  encoder_stages.1.block.4.weight
  encoder_stages.1.block.4.bias
  encoder_stages.2.block.0.weight
  encoder_stages.2.block.0.bias
  encoder_stages.2.block.1.weight
  encoder_stages.2.block.1.bias
  encoder_stages.2.block.4.weight
  encoder_stages.2.block.4.bias
  encoder_stages.2.block.5.weight
  encoder_stages.2.block.5.bias
  encoder_stages.3.block.0.weight
  encoder_stages.3.block.0

In [3]:
import torch

checkpoint_path = "/home/ulixes/segmentation_cv/unet/AE-pretrained/models/ae_pet_reconstruction/best_model.pth"
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
print("Top-level checkpoint keys:", list(checkpoint))
print("Encoder keys from model_state_dict:")

# Keep only the saved entries belonging to the encoder stages, in stored order.
encoder_keys = list(
    filter(lambda name: name.startswith("encoder_stages."), checkpoint["model_state_dict"])
)
for key in encoder_keys:
    print(f"  {key}")


Top-level checkpoint keys: ['epoch', 'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict', 'best_loss', 'config']
Encoder keys from model_state_dict:
  encoder_stages.0.block.0.weight
  encoder_stages.0.block.0.bias
  encoder_stages.0.block.1.weight
  encoder_stages.0.block.1.bias
  encoder_stages.0.block.3.weight
  encoder_stages.0.block.3.bias
  encoder_stages.0.block.4.weight
  encoder_stages.0.block.4.bias
  encoder_stages.1.block.0.weight
  encoder_stages.1.block.0.bias
  encoder_stages.1.block.1.weight
  encoder_stages.1.block.1.bias
  encoder_stages.1.block.3.weight
  encoder_stages.1.block.3.bias
  encoder_stages.1.block.4.weight
  encoder_stages.1.block.4.bias
  encoder_stages.2.block.0.weight
  encoder_stages.2.block.0.bias
  encoder_stages.2.block.1.weight
  encoder_stages.2.block.1.bias
  encoder_stages.2.block.4.weight
  encoder_stages.2.block.4.bias
  encoder_stages.2.block.5.weight
  encoder_stages.2.block.5.bias
  encoder_stages.3.block.0.weight
  encoder_

In [6]:
from unet import UNet
import torch.nn as nn

# Same architecture as the inspection cell above. pretrained_encoder_path is left
# empty on purpose, because the encoder weights are loaded explicitly below.
model = UNet(
    in_channels=3,
    num_classes=3,
    n_stages=6,
    # Encoder: one entry per stage (6 stages); only the first stage keeps stride 1.
    features_per_stage=[32, 64, 128, 256, 512, 512],
    kernel_sizes=[[3, 3]] * 6,
    strides=[[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]],
    n_conv_per_stage=[2] * 6,
    # Decoder settings have one entry fewer than the encoder (5 vs. 6).
    n_conv_per_stage_decoder=[2] * 5,
    conv_bias=True,
    # Normalization, activation and dropout configuration.
    norm_op=nn.InstanceNorm2d,
    norm_op_kwargs={"eps": 1e-5, "affine": True},
    dropout_op=None,
    nonlin=nn.LeakyReLU,
    nonlin_kwargs={"inplace": True},
    encoder_dropout_rates=[0.0, 0.0, 0.1, 0.2, 0.3, 0.3],
    decoder_dropout_rates=[0.3, 0.2, 0.2, 0.1, 0.0],
    pretrained_encoder_path="",
)

# Load the autoencoder's encoder weights into this model. According to its log
# output, this call also freezes the encoder stages.
# Relies on `checkpoint_path` having been defined by an earlier cell.
model.load_pretrained_encoder(checkpoint_path)


Successfully loaded pre-trained encoder from /home/ulixes/segmentation_cv/unet/AE-pretrained/models/ae_pet_reconstruction/best_model.pth
Encoder stages have been frozen.


In [7]:
# Sanity check: after load_pretrained_encoder, every encoder parameter should
# have requires_grad == False.
encoder_params = list(model.encoder_stages.parameters())
# all() over an empty list is True, which would report "frozen" even if the
# encoder had no parameters at all. Require at least one parameter.
frozen = len(encoder_params) > 0 and all(not p.requires_grad for p in encoder_params)
print("Are all encoder parameters frozen?", frozen)


Are all encoder parameters frozen? True


In [8]:
import torch

# Smoke-test the forward pass with one random 3-channel 512x512 input.
dummy_input = torch.randn(1, 3, 512, 512)  # Example input

# Only the output shape is needed, so disable autograd. Presumably the decoder
# parameters still require grad (only the encoder was frozen). Without no_grad(),
# a full backward graph for this pass would be built and kept alive via `output`.
# NOTE(review): model.eval() is never called in this notebook, so dropout is
# presumably still active; that does not affect the shape check.
with torch.no_grad():
    output = model(dummy_input)

# Expect (batch, channels of the final Conv2d, H, W): full input resolution.
expected_shape = (
    dummy_input.shape[0],
    model.segmentation_output.out_channels,
    *dummy_input.shape[2:],
)
assert output.shape == expected_shape, (
    f"unexpected output shape {tuple(output.shape)}, expected {expected_shape}"
)
print("Output shape:", output.shape)


Output shape: torch.Size([1, 3, 512, 512])
