# Taking the pre-trained model from STU-Net

### Load the model using nnU-Net v1

In [1]:
import torch
import torch.nn as nn
import nnunet
from batchgenerators.utilities.file_and_folder_operations import load_pickle, join
import pkgutil
import importlib
import torch.nn.functional as F
import SimpleITK as sitk

# --- Helper to find the class ---
def recursive_find_python_class(folder, trainer_name, current_module):
    """Search a package tree for an attribute called ``trainer_name``.

    Plain modules found directly in ``folder`` are imported first; the first
    one exposing ``trainer_name`` wins.  Only if that yields nothing are the
    sub-packages visited recursively, in the order ``pkgutil`` reports them.

    Args:
        folder: One-element list holding the directory to scan.
        trainer_name: Name of the attribute (normally a class) to look for.
        current_module: Dotted import path that corresponds to ``folder``.

    Returns:
        The attribute that was found, or ``None`` when nothing matches.
    """
    found = None

    # Pass 1: non-package modules located directly in `folder`.
    for _, module_name, is_package in pkgutil.iter_modules(folder):
        if is_package:
            continue
        module = importlib.import_module(".".join((current_module, module_name)))
        if hasattr(module, trainer_name):
            found = getattr(module, trainer_name)
            break

    if found is not None:
        return found

    # Pass 2: descend into every sub-package until one of them yields a hit.
    for _, package_name, is_package in pkgutil.iter_modules(folder):
        if is_package:
            found = recursive_find_python_class(
                [join(folder[0], package_name)],
                trainer_name,
                current_module=".".join((current_module, package_name)),
            )
        if found is not None:
            break

    return found




Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet



### Test with one inference: load the pre-trained STU-Net Large and run inference.

In [2]:
def restore_model_reconstruction(pkl_file, checkpoint=None, train=False, fp16=None):
    """Rebuild the nnU-Net v1 trainer stored in ``pkl_file`` and re-head its network.

    The trainer class named in the pickle is looked up inside ``nnunet``
    (falling back to ``meddec`` when that package is importable), instantiated,
    and its network is initialised from the stored plans.  Deep supervision is
    then switched off and ``seg_outputs`` is replaced by ``nn.Identity``
    placeholders followed by one freshly created 105-channel ``nn.Conv3d``
    head.  When ``checkpoint`` is given, matching weights are copied into the
    modified network and a loading report is printed.

    Args:
        pkl_file: Path to the pickle holding the ``init``, ``name`` and
            ``plans`` entries.
        checkpoint: Optional path to the checkpoint whose ``'state_dict'``
            entry is loaded.  With ``None`` the network keeps its freshly
            initialised weights and no loading report is printed.
        train: Unused; kept so that existing callers keep working.
        fp16: When not ``None``, overrides ``trainer.fp16``.

    Returns:
        The trainer instance; the modified network is ``trainer.network``.

    Raises:
        RuntimeError: If no trainer class called ``info['name']`` is found.
    """
    # Load configuration
    info = load_pickle(pkl_file)
    init = info['init']
    name = info['name']
    print(info)

    # Instantiate trainer
    search_in = join(nnunet.__path__[0], "training", "network_training")
    tr = recursive_find_python_class([search_in], name, current_module="nnunet.training.network_training")

    # Fallback for meddec. The package is optional, so a failed import is
    # deliberately ignored and reported through the RuntimeError below.
    if tr is None:
        try:
            import meddec
            search_in = join(meddec.__path__[0], "model_training")
            tr = recursive_find_python_class([search_in], name, current_module="meddec.model_training")
        except ImportError:
            pass

    if tr is None:
        raise RuntimeError(f"Could not find trainer: {name}")

    trainer = tr(*init)
    if fp16 is not None:
        trainer.fp16 = fp16

    # Initialize the network from the stored plans
    trainer.process_plans(info['plans'])
    trainer.initialize_network()

    # --- SURGERY STEP 1: MODIFY ARCHITECTURE ---
    print("\n--- Starting Architecture Surgery ---")

    # A. Disable deep supervision so forward() is not asked for multiple outputs.
    if hasattr(trainer.network, 'deep_supervision'):
        trainer.network.deep_supervision = False
    if hasattr(trainer.network, 'do_ds'):
        trainer.network.do_ds = False

    # B. Replace the last entry of seg_outputs (treated here as the
    #    highest-resolution head) with a fresh 105-channel Conv3d.
    old_final_layer = trainer.network.seg_outputs[-1]

    print(f"Replacing final layer: {old_final_layer}")
    print(f"Old config: In={old_final_layer.in_channels}, Out={old_final_layer.out_channels}")

    new_final_layer = nn.Conv3d(
        in_channels=old_final_layer.in_channels,
        out_channels=105,
        kernel_size=old_final_layer.kernel_size,
        stride=old_final_layer.stride,
        padding=old_final_layer.padding,
        bias=(old_final_layer.bias is not None)
    )

    # Keep the list length unchanged so index-based code in forward() stays
    # valid: every head but the last becomes a parameter-free Identity.
    new_seg_outputs = nn.ModuleList()
    for _ in range(len(trainer.network.seg_outputs) - 1):
        new_seg_outputs.append(nn.Identity())
    new_seg_outputs.append(new_final_layer)

    trainer.network.seg_outputs = new_seg_outputs
    print(
        f"Architecture Modified. Output heads: {len(trainer.network.seg_outputs)} "
        f"(Last one is active {new_final_layer.out_channels}-channel)."
    )

    # Without a checkpoint there is nothing to load and nothing to report.
    # (The report below needs `load_result` / `saved_state_dict`.)
    if checkpoint is None:
        print("\nNo checkpoint given: the network keeps its freshly initialised weights.")
        return trainer

    # --- SURGERY STEP 2: LOAD & PATCH WEIGHTS ---
    print(f"\n--- Loading and Patching Weights from {checkpoint} ---")

    # weights_only=False allows the numpy scalars stored in the checkpoint, but
    # it also unpickles arbitrary objects: only load checkpoints you trust.
    saved_state_dict = torch.load(checkpoint, map_location=torch.device('cpu'), weights_only=False)['state_dict']
    network_state_dict = trainer.network.state_dict()

    final_state_dict = {}

    # Prefix of the active head in the NEW network
    last_idx = len(trainer.network.seg_outputs) - 1
    new_layer_prefix = f"seg_outputs.{last_idx}"

    for key_new, param_new in network_state_dict.items():

        # 1. Output head (exact prefix match, so "seg_outputs.4" never matches "seg_outputs.40")
        if key_new.startswith(new_layer_prefix + "."):
            # NOTE(review): the checkpoint's high-res head is assumed to live at
            # index 4 -- confirm for architectures with a different depth.
            key_old = key_new.replace(new_layer_prefix, "seg_outputs.4")

            if key_old not in saved_state_dict:
                print(f"  WARNING: Could not find {key_old} in checkpoint!")
                continue

            param_old = saved_state_dict[key_old]
            # Keep as many leading output channels as the new head has
            # (weight and bias both carry the channel count in dim 0).
            patched_param = param_old[:param_new.shape[0]]

            # load_state_dict raises on size mismatches even with strict=False.
            if patched_param.shape != param_new.shape:
                print(f"  Skipping {key_new}: Shape mismatch {param_old.shape} vs {param_new.shape}")
                continue

            final_state_dict[key_new] = patched_param
            print(f"  PATCHED {key_new}: Sliced {key_old} {param_old.shape} -> {patched_param.shape}")

        # 2. Placeholder heads: Identity has no parameters, nothing to load.
        elif "seg_outputs" in key_new:
            pass

        # 3. Standard layers (encoder/decoder): copy when shapes agree.
        elif key_new in saved_state_dict:
            if saved_state_dict[key_new].shape == param_new.shape:
                final_state_dict[key_new] = saved_state_dict[key_new]
            else:
                print(f"  Skipping {key_new}: Shape mismatch {saved_state_dict[key_new].shape} vs {param_new.shape}")

    # strict=False because the architecture no longer matches the checkpoint.
    load_result = trainer.network.load_state_dict(final_state_dict, strict=False)

    print("\n" + "="*50)
    print("WEIGHT LOADING REPORT")
    print("="*50)

    # 1. Missing keys: parameters of the new network that received no weights.
    if len(load_result.missing_keys) > 0:
        print(f"‚ö†Ô∏è  MISSING KEYS ({len(load_result.missing_keys)}):")
        for k in load_result.missing_keys:
            print(f"   - {k}")
    else:
        print("‚úÖ No missing keys (All target layers received weights).")

    # 2. Checkpoint entries that were not carried over (e.g. the replaced
    #    seg_outputs.0-3 heads), found by comparing the file's keys with what was loaded.
    skipped_keys = [k for k in saved_state_dict if k not in final_state_dict]

    if len(skipped_keys) > 0:
        print(f"\nüóëÔ∏è  SKIPPED LAYERS ({len(skipped_keys)}):")
        print("   (These were present in the checkpoint but removed/ignored in the new model)")
        # Print the first 10 only, just to verify
        for k in skipped_keys[:10]:
            print(f"   - {k}")
        if len(skipped_keys) > 10:
            print(f"   ... and {len(skipped_keys)-10} more.")

    return trainer


In [3]:
# Save the patched model so it can be loaded later
PLANS_PATH = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/checkpoints/pre-trained/Independent/large_ep4k.model.pkl"
MODEL_PATH = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/checkpoints/pre-trained/Independent/large_ep4k.model"

# Rebuild the network through the nnU-Net trainer and keep the bare module.
trainer = restore_model_reconstruction(pkl_file=PLANS_PATH, checkpoint=MODEL_PATH, train=False)
model = trainer.network

# Sanity check on the replaced head.
# Expected (per the recorded outputs): a Conv3d with 64 input and 105 output channels.
final_head = model.seg_outputs[-1]
print("\nFinal Check:")
print(f"Output Head: {final_head}")

# Persist only the state dict of the modified network.
patched_weights = model.state_dict()
torch.save(patched_weights, "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/checkpoints/pre-trained/tmp/stunet_reconstruction_weights.pth")
print("‚úÖ Weights saved.")


OrderedDict([('init', ('', 0, '', '', False, 0, True, False, True)), ('name', 'STUNetTrainer_large'), ('plans', {'num_stages': 1, 'num_modalities': 1, 'modalities': {0: 'nonCT'}, 'normalization_schemes': OrderedDict([(0, 'nonCT')]), 'num_classes': 104, 'all_classes': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104], 'base_num_features': 32, 'use_mask_for_norm': OrderedDict([(0, False)]), 'keep_only_largest_region': None, 'min_region_size_per_class': None, 'min_size_per_class': None, 'transpose_forward': [np.int64(0), 1, 2], 'transpose_backward': [np.int64(0), np.int64(1), np.int64(2)], 'data_identifier': 'nnUNetData_plans_v2.1',

In [4]:
from stunet_model import STUNetReconstruction
import torch

# Build the standalone network (no nnU-Net involved) ...
model = STUNetReconstruction()

# ... and fill it with the weights saved earlier. strict=True is intentional:
# every key in the file must match the standalone class.
weights_file = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/checkpoints/pre-trained/tmp/stunet_reconstruction_weights.pth"
state_dict = torch.load(weights_file, map_location='cpu')
model.load_state_dict(state_dict, strict=True)

# Switch to evaluation mode and move to the GPU.
model.eval()
model.cuda()
print("Model loaded successfully without nnU-Net dependencies!")

Model loaded successfully without nnU-Net dependencies!


In [8]:
import torch.nn.functional as F
def pad_and_center_crop_3d(tensor, crop_size=(128,128,128), pad_value=-1000):
    """Pad a 5-D volume where needed, then cut out its centre.

    Every spatial axis shorter than the requested size is padded symmetrically
    with ``pad_value`` (an odd amount puts the extra voxel at the end of the
    axis); afterwards a centred window of exactly ``crop_size`` is returned.

    Args:
        tensor: ``torch.Tensor`` with five dimensions; the last three are
            treated as depth, height and width.
        crop_size: Target ``(depth, height, width)``.
        pad_value: Constant used to fill padded voxels.  The default of
            ``-1000`` keeps the previous hard-coded behaviour.

    Returns:
        Tensor whose last three dimensions equal ``crop_size``.

    Raises:
        ValueError: If ``tensor`` does not have exactly five dimensions.
    """
    _, _, depth, height, width = tensor.shape
    d_crop, h_crop, w_crop = crop_size

    # Amount missing on each axis (0 when the axis is already large enough).
    pad_d = max(d_crop - depth, 0)
    pad_h = max(h_crop - height, 0)
    pad_w = max(w_crop - width, 0)

    # F.pad expects pairs starting from the LAST dimension:
    # (w_left, w_right, h_left, h_right, d_left, d_right).
    tensor = F.pad(
        tensor,
        (
            pad_w // 2, pad_w - pad_w // 2,
            pad_h // 2, pad_h - pad_h // 2,
            pad_d // 2, pad_d - pad_d // 2,
        ),
        mode='constant',
        value=pad_value,
    )

    # Centre crop on the (possibly padded) volume.
    _, _, depth, height, width = tensor.shape
    d_start = (depth - d_crop) // 2
    h_start = (height - h_crop) // 2
    w_start = (width - w_crop) // 2

    return tensor[
        :, :,
        d_start:d_start + d_crop,
        h_start:h_start + h_crop,
        w_start:w_start + w_crop,
    ]

def z_score_normalize(tensor, eps=1e-8):
    """
    Standardise a tensor with its own global statistics.

    The mean and standard deviation are taken over ALL elements of ``tensor``
    (via ``tensor.mean()`` / ``tensor.std()``), so the result has the same
    shape as the input.

    Args:
        tensor: torch.Tensor, e.g. shape [1, 1, D, H, W]
        eps: small constant added to the standard deviation to avoid
            division by zero
    Returns:
        ``(tensor - mean) / (std + eps)``
    """
    centered = tensor - tensor.mean()
    scale = tensor.std() + eps
    return centered / scale

def save_tensor_as_nii(tensor, output_path, reference_image=None):
    """
    Write a PyTorch tensor to disk as a 3-D image through SimpleITK.

    tensor: torch.Tensor, e.g. shape [1,1,D,H,W] or [D,H,W]; if a tuple is
        passed, only its first element is written. Leading axes beyond the
        last three are dropped by taking index 0.
    output_path: destination file name handed to ``sitk.WriteImage``
        (e.g. a ``.nii.gz`` path)
    reference_image: optional sitk.Image whose spacing, origin and direction
        are copied onto the written image
    """
    # Some models return a tuple of outputs; only the first entry is saved.
    source = tensor[0] if isinstance(tensor, tuple) else tensor
    print(f"tensor: {source.shape}")

    # Bring the data to host memory and convert it to a numpy array.
    host_tensor = source.cpu() if source.is_cuda else source
    volume = host_tensor.detach().numpy()

    # Strip batch / channel axes until only three dimensions remain.
    for _ in range(max(volume.ndim - 3, 0)):
        volume = volume[0]

    sitk_image = sitk.GetImageFromArray(volume)

    # Carry over the geometric metadata when a reference image is supplied.
    if reference_image is not None:
        sitk_image.SetSpacing(reference_image.GetSpacing())
        sitk_image.SetOrigin(reference_image.GetOrigin())
        sitk_image.SetDirection(reference_image.GetDirection())

    sitk.WriteImage(sitk_image, output_path)
    print(f"Saved output to {output_path}")


In [9]:
# Do one inference (check if the results are good; the scan should be a CT scan)
import nibabel as nib
import numpy as np
import torch

# Read the CT volume as float32 and reverse its axis order.
nii = nib.load("/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/ct.nii")
image_array = nii.get_fdata().astype(np.float32).transpose(2, 1, 0)

# Prepend batch and channel axes, then move the volume to the GPU.
image_tensor = torch.from_numpy(image_array)[None, None].cuda()

print(f"image_tensor: {image_tensor.shape}")

# Pad / centre-crop to a 128^3 patch and z-score normalise it.
# cropped: shape [1, 1, 128, 128, 128]
cropped = z_score_normalize(
    pad_and_center_crop_3d(image_tensor, crop_size=(128, 128, 128))
)
print(cropped.shape)


image_tensor: torch.Size([1, 1, 133, 160, 103])
torch.Size([1, 1, 128, 128, 128])


In [10]:
# Run a single forward pass on the prepared patch and save input and prediction.
model = model.cuda()

# Inference only: disabling autograd avoids keeping every intermediate
# activation of the network alive for a backward pass that never happens.
with torch.no_grad():
    output = model(cropped)
print(f"output: {output.shape}")

# Save the normalised input patch for visual comparison.
save_tensor_as_nii(tensor=cropped[0][0], output_path="/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/cropped.nii.gz", reference_image=None)

# Per-voxel label = channel with the highest score (computed once, reused below).
predicted_labels = torch.argmax(output, dim=1)
print(predicted_labels.shape)
save_tensor_as_nii(predicted_labels[0].to(torch.uint8), "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/output.nii.gz", reference_image=None)

output: torch.Size([1, 105, 128, 128, 128])
tensor: torch.Size([128, 128, 128])
Saved output to /mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/cropped.nii.gz
torch.Size([1, 128, 128, 128])
tensor: torch.Size([128, 128, 128])
Saved output to /mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/output.nii.gz


In [1]:
# Verify that the network rebuilt through nnU-Net and the standalone
# STUNetReconstruction produce the same intermediate activations for the same
# input. The max diff per layer should be 0.0.
import torch
import numpy as np

# ==========================================
# 1. SETUP: LOAD BOTH MODELS
# ==========================================

# A. Load Reference Model (the one built through the nnunet logic)
#    Uses the restore function that relies on the original nnunet code.
PLANS = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/checkpoints/pre-trained/Independent/large_ep4k.model.pkl"
CKPT = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/checkpoints/pre-trained/Independent/large_ep4k.model"

print("Loading Reference Model...")
# `restore_model_reconstruction` must already be defined in the running kernel
# (run the cell that defines it first); otherwise this line raises NameError.
ref_trainer = restore_model_reconstruction(PLANS, CKPT, train=False, fp16=True)
model_ref = ref_trainer.network
model_ref.eval().cuda()

# B. Load the standalone model and the weights saved from the patched network
print("Loading Standalone Model...")
from stunet_model import STUNetReconstruction
model_new = STUNetReconstruction()
model_new.load_state_dict(torch.load("/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/checkpoints/pre-trained/tmp/stunet_reconstruction_weights.pth"), strict=True)
model_new.eval().cuda()

# ==========================================
# 2. DEBUGGING TOOL: HOOKS
# ==========================================
# Forward hooks capture the output of every listed block.

# name -> captured output, one dict per model
activations_ref = {}
activations_new = {}

def get_hook(name, storage_dict):
    """Return a forward hook that stores the module output under ``name`` in ``storage_dict``."""
    def hook(model, input, output):
        # Detach and move to CPU to save GPU memory
        storage_dict[name] = output.detach().cpu()
    return hook

# Dotted module names to compare (encoder blocks + decoder blocks).
# Both models must expose every name listed here via named_modules().
layer_names = [
    # Encoder
    "conv_blocks_context.0", 
    "conv_blocks_context.1", 
    "conv_blocks_context.2", 
    "conv_blocks_context.3", 
    "conv_blocks_context.4", 
    "conv_blocks_context.5", # Bottleneck

    # Decoder (Upsampling)
    "upsample_layers.0",
    "upsample_layers.1",

    # Decoder (Localization)
    "conv_blocks_localization.0",
    "conv_blocks_localization.1",
]

# Attach hooks
for name in layer_names:
    # Look the layer up by its dotted name in named_modules()
    # (raises KeyError if a model has no module with that name).
    # Ref Model
    layer_ref = dict(model_ref.named_modules())[name]
    layer_ref.register_forward_hook(get_hook(name, activations_ref))

    # New Model
    layer_new = dict(model_new.named_modules())[name]
    layer_new.register_forward_hook(get_hook(name, activations_new))

# ==========================================
# 3. RUN COMPARISON
# ==========================================
print("\n--- Running Comparison ---")
with torch.no_grad():
    # Create a random input and feed the SAME tensor to both models
    dummy_input = torch.randn(1, 1, 128, 128, 128).cuda()

    # Forward passes (outputs are discarded; the hooks record the activations)
    _ = model_ref(dummy_input)
    _ = model_new(dummy_input)

# Check errors layer by layer, stopping at the first mismatch
for name in layer_names:
    act_ref = activations_ref[name]
    act_new = activations_new[name]

    # 1. Check Shape
    if act_ref.shape != act_new.shape:
        print(f"‚ùå SHAPE MISMATCH at {name}!")
        print(f"   Ref: {act_ref.shape}")
        print(f"   New: {act_new.shape}")
        print("   -> Check your strides or kernel sizes in this block.")
        break

    # 2. Check Values (largest absolute element-wise difference)
    diff = torch.abs(act_ref - act_new).max().item()
    print(f"Layer {name:<30} | Max Diff: {diff:.8f}")

    if diff > 1e-4: # Tolerance threshold
        print(f"‚ùå VALUE MISMATCH at {name}!")
        print("   -> The weights match, but the calculation is different.")
        print("   -> Possible causes: InstanceNorm settings (eps/momentum), LeakyReLU slope, or interpolation method.")
        break

print("--- Done ---")

Loading Reference Model...


NameError: name 'restore_model_reconstruction' is not defined

### All seems good. Save the model independently so it can be loaded with PyTorch only -> ignore nnU-Net v1
* For this, we:
    * Save the model with the replaced output layer (outputs 1 channel)
    * Define the model network without the nnunet framework, in stunet_model.py

In [2]:
def restore_model_reconstruction(pkl_file, checkpoint=None, train=False, fp16=None):
    """Rebuild the nnU-Net v1 trainer stored in ``pkl_file`` with a 1-channel output head.

    The stored plans are overridden to a single class, the trainer class named
    in the pickle is looked up inside ``nnunet`` (falling back to ``meddec``
    when that package is importable), and its network is initialised.  Deep
    supervision is then switched off and ``seg_outputs`` is replaced by
    ``nn.Identity`` placeholders followed by one freshly created 1-channel
    ``nn.Conv3d`` head.  When ``checkpoint`` is given, matching weights are
    copied into the modified network (the head receives the first output
    channel of the checkpoint's head) and a loading report is printed.

    Args:
        pkl_file: Path to the pickle holding the ``init``, ``name`` and
            ``plans`` entries.
        checkpoint: Optional path to the checkpoint whose ``'state_dict'``
            entry is loaded.  With ``None`` the network keeps its freshly
            initialised weights and no loading report is printed.
        train: Unused; kept so that existing callers keep working.
        fp16: When not ``None``, overrides ``trainer.fp16``.

    Returns:
        The trainer instance; the modified network is ``trainer.network``.

    Raises:
        RuntimeError: If no trainer class called ``info['name']`` is found.
    """
    # 1. Load configuration
    info = load_pickle(pkl_file)
    init = info['init']
    name = info['name']

    # 2. Force the plans to 1 class (the recorded run shows this yields a
    #    2-channel head before surgery)
    info['plans']['num_classes'] = 1
    info['plans']['all_classes'] = [1]

    # 3. Instantiate trainer
    search_in = join(nnunet.__path__[0], "training", "network_training")
    tr = recursive_find_python_class([search_in], name, current_module="nnunet.training.network_training")

    # Fallback for meddec. The package is optional, so a failed import is
    # deliberately ignored and reported through the RuntimeError below.
    if tr is None:
        try:
            import meddec
            search_in = join(meddec.__path__[0], "model_training")
            tr = recursive_find_python_class([search_in], name, current_module="meddec.model_training")
        except ImportError:
            pass

    if tr is None:
        raise RuntimeError(f"Could not find trainer: {name}")

    trainer = tr(*init)
    if fp16 is not None:
        trainer.fp16 = fp16

    # 4. Initialize the network from the (overridden) plans
    trainer.process_plans(info['plans'])
    trainer.initialize_network()

    # --- SURGERY STEP 1: MODIFY ARCHITECTURE ---
    print("\n--- Starting Architecture Surgery ---")

    # A. Disable deep supervision so forward() is not asked for multiple outputs.
    if hasattr(trainer.network, 'deep_supervision'):
        trainer.network.deep_supervision = False
    if hasattr(trainer.network, 'do_ds'):
        trainer.network.do_ds = False

    # B. Replace the last entry of seg_outputs (treated here as the
    #    highest-resolution head) with a fresh 1-channel Conv3d.
    old_final_layer = trainer.network.seg_outputs[-1]

    print(f"Replacing final layer: {old_final_layer}")
    print(f"Old config: In={old_final_layer.in_channels}, Out={old_final_layer.out_channels}")

    new_final_layer = nn.Conv3d(
        in_channels=old_final_layer.in_channels,
        out_channels=1,  # <--- forcing 1 channel here
        kernel_size=old_final_layer.kernel_size,
        stride=old_final_layer.stride,
        padding=old_final_layer.padding,
        bias=(old_final_layer.bias is not None)
    )

    # Keep the list length unchanged so index-based code in forward() stays
    # valid: every head but the last becomes a parameter-free Identity.
    new_seg_outputs = nn.ModuleList()
    for _ in range(len(trainer.network.seg_outputs) - 1):
        new_seg_outputs.append(nn.Identity())
    new_seg_outputs.append(new_final_layer)

    trainer.network.seg_outputs = new_seg_outputs
    print(f"Architecture Modified. Output heads: {len(trainer.network.seg_outputs)} (Last one is active 1-channel).")

    # Without a checkpoint there is nothing to load and nothing to report.
    # (The report below needs `load_result` / `saved_state_dict`.)
    if checkpoint is None:
        print("\nNo checkpoint given: the network keeps its freshly initialised weights.")
        return trainer

    # --- SURGERY STEP 2: LOAD & PATCH WEIGHTS ---
    print(f"\n--- Loading and Patching Weights from {checkpoint} ---")

    # weights_only=False allows the numpy scalars stored in the checkpoint, but
    # it also unpickles arbitrary objects: only load checkpoints you trust.
    saved_state_dict = torch.load(checkpoint, map_location=torch.device('cpu'), weights_only=False)['state_dict']
    network_state_dict = trainer.network.state_dict()

    final_state_dict = {}

    # Prefix of the active head in the NEW network
    last_idx = len(trainer.network.seg_outputs) - 1
    new_layer_prefix = f"seg_outputs.{last_idx}"

    for key_new, param_new in network_state_dict.items():

        # 1. Output head (exact prefix match, so "seg_outputs.4" never matches "seg_outputs.40")
        if key_new.startswith(new_layer_prefix + "."):
            # NOTE(review): the checkpoint's high-res head is assumed to live at
            # index 4 -- confirm for architectures with a different depth.
            key_old = key_new.replace(new_layer_prefix, "seg_outputs.4")

            if key_old not in saved_state_dict:
                print(f"  WARNING: Could not find {key_old} in checkpoint!")
                continue

            param_old = saved_state_dict[key_old]
            # Keep only the first output channel (dim 0 of both weight and bias).
            # NOTE(review): channel 0 is presumably the background class of the
            # original segmentation head -- confirm this is the intended init.
            patched_param = param_old[:1]

            # load_state_dict raises on size mismatches even with strict=False.
            if patched_param.shape != param_new.shape:
                print(f"  Skipping {key_new}: Shape mismatch {param_old.shape} vs {param_new.shape}")
                continue

            final_state_dict[key_new] = patched_param
            print(f"  PATCHED {key_new}: Sliced {key_old} {param_old.shape} -> {patched_param.shape}")

        # 2. Placeholder heads: Identity has no parameters, nothing to load.
        elif "seg_outputs" in key_new:
            pass

        # 3. Standard layers (encoder/decoder): copy when shapes agree.
        elif key_new in saved_state_dict:
            if saved_state_dict[key_new].shape == param_new.shape:
                final_state_dict[key_new] = saved_state_dict[key_new]
            else:
                print(f"  Skipping {key_new}: Shape mismatch {saved_state_dict[key_new].shape} vs {param_new.shape}")

    # strict=False because the architecture no longer matches the checkpoint.
    load_result = trainer.network.load_state_dict(final_state_dict, strict=False)

    print("\n" + "="*50)
    print("WEIGHT LOADING REPORT")
    print("="*50)

    # 1. Missing keys: parameters of the new network that received no weights.
    if len(load_result.missing_keys) > 0:
        print(f"‚ö†Ô∏è  MISSING KEYS ({len(load_result.missing_keys)}):")
        for k in load_result.missing_keys:
            print(f"   - {k}")
    else:
        print("‚úÖ No missing keys (All target layers received weights).")

    # 2. Checkpoint entries that were not carried over (e.g. the replaced
    #    seg_outputs.0-3 heads), found by comparing the file's keys with what was loaded.
    skipped_keys = [k for k in saved_state_dict if k not in final_state_dict]

    if len(skipped_keys) > 0:
        print(f"\nüóëÔ∏è  SKIPPED LAYERS ({len(skipped_keys)}):")
        print("   (These were present in the checkpoint but removed/ignored in the new model)")
        # Print the first 10 only, just to verify
        for k in skipped_keys[:10]:
            print(f"   - {k}")
        if len(skipped_keys) > 10:
            print(f"   ... and {len(skipped_keys)-10} more.")

    return trainer

# --- EXECUTE ---
PLANS_PATH = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/checkpoints/pre-trained/Independent/large_ep4k.model.pkl"
MODEL_PATH = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/checkpoints/pre-trained/Independent/large_ep4k.model"

# Rebuild the 1-channel network through the nnU-Net trainer and keep the bare module.
trainer = restore_model_reconstruction(pkl_file=PLANS_PATH, checkpoint=MODEL_PATH, train=False)
model = trainer.network

# Needed below to save the weights and to rebuild the standalone network.
from stunet_model import STUNetReconstruction
import torch

# Sanity check on the replaced head.
# Should print: Conv3d(64, 1, kernel_size=(1, 1, 1), ...)
final_head = model.seg_outputs[-1]
print("\nFinal Check:")
print(f"Output Head: {final_head}")

# Persist the state dict of the modified network; the same file is read back below.
BINARY_WEIGHTS_PATH = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/checkpoints/pre-trained/Independent/binary_large_ep4k.pth"
torch.save(model.state_dict(), BINARY_WEIGHTS_PATH)
print("‚úÖ Weights saved.")

# Rebuild the network without nnU-Net and load the saved weights into it.
# strict=True is intentional: every key in the file must match the standalone class.
model = STUNetReconstruction()
state_dict = torch.load(BINARY_WEIGHTS_PATH, map_location='cpu')
model.load_state_dict(state_dict, strict=True)

# Ready for training: training mode, on the GPU.
model.train()
model.cuda()
print("Model loaded successfully without nnU-Net dependencies!")

nnUNet_raw_data_base is not defined and nnU-Net can only be used on data for which preprocessed files are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set this up properly.
nnUNet_preprocessed is not defined and nnU-Net can not be used for preprocessing or training. If this is not intended, please read documentation/setting_up_paths.md for information on how to set this up.
RESULTS_FOLDER is not defined and nnU-Net cannot be used for training or inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information on how to set this up.

--- Starting Architecture Surgery ---
Replacing final layer: Conv3d(64, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
Old config: In=64, Out=2
Architecture Modified. Output heads: 5 (Last one is active 1-channel).

--- Loading and Patching Weights from /mounts/