In [1]:
import torch
import os

# Folder locations read from environment variables; each name is None when its variable is unset.
# NOTE(review): none of these three names is used by the visible cells — the checkpoint paths
# below are hardcoded instead.
nnUNet_raw, nnUNet_preprocessed, nnUNet_results = (
    os.environ.get(var_name)
    for var_name in ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results")
)

In [None]:

# Source checkpoint whose "network_weights" are transferred into model_to_update below
# (the printed transfer log shows its first conv has 1 input channel).
# weights_only=False unpickles arbitrary objects — only load checkpoint files you trust.
# NOTE(review): this path is absolute at the filesystem root while nnUNet_results is read but
# unused; presumably it should be joined under nnUNet_results — confirm.
model_999_fold4 = torch.load(
    "/Dataset999/nnUNetTrainer__nnUNetResEncUNetMPlans__3d_fullres/fold_4/checkpoint_final.pth",
    weights_only=False,
)


In [None]:
# Destination checkpoint; its "network_weights" dict is mutated in place by the transfer cells below.
model_to_update = torch.load(
    "./dir/fold_2/checkpoint_best_1_adapted.pth",
    weights_only=False,  # unpickles the whole checkpoint object — trusted files only
)


In [None]:
def compare_model_shapes(state_dict_1, state_dict_2, name_1="Model 1", name_2="Model 2", filter_key=None,
                         check_extra=False):
    """
    Compare the shapes of same-named tensors in two model state dicts and print a per-layer report.

    Parameters:
    - state_dict_1: dict, first model's state_dict (e.g., checkpoint["network_weights"])
    - state_dict_2: dict, second model's state_dict
    - name_1: str, label for the first model in printed messages
    - name_2: str, label for the second model in printed messages
    - filter_key: str or None, if provided only layers whose name contains this string
      (case-insensitive) are checked; None or "" checks every layer
    - check_extra: bool, if True also report layers that exist only in state_dict_2.
      By default those are ignored, because the scan iterates over state_dict_1.

    Returns:
    - List of mismatches, in this order:
      - (layer_name, "missing") for a layer in state_dict_1 but not in state_dict_2
      - (layer_name, shape_1, shape_2) for a layer whose shapes differ
      - (layer_name, "extra") for a layer only in state_dict_2 (only when check_extra=True)
    """
    needle = filter_key.lower() if filter_key else None

    def _selected(layer_name):
        # True when the layer passes the optional case-insensitive name filter.
        return needle is None or needle in layer_name.lower()

    mismatches = []

    for layer_name, tensor_1 in state_dict_1.items():
        if not _selected(layer_name):
            continue

        tensor_2 = state_dict_2.get(layer_name, None)
        if tensor_2 is None:
            print(f"⚠️ Layer missing in {name_2}: {layer_name}")
            mismatches.append((layer_name, "missing"))
            continue

        shape_1 = tuple(tensor_1.shape)
        shape_2 = tuple(tensor_2.shape)

        if shape_1 != shape_2:
            print(f"❌ {layer_name}: {name_1} shape={shape_1}, {name_2} shape={shape_2}")
            mismatches.append((layer_name, shape_1, shape_2))
        else:
            print(f"✅ {layer_name}: shape match {shape_1}")

    if check_extra:
        # Layers the second model has but the first lacks would otherwise go unreported.
        for layer_name in state_dict_2:
            if layer_name not in state_dict_1 and _selected(layer_name):
                print(f"⚠️ Layer missing in {name_1}: {layer_name}")
                mismatches.append((layer_name, "extra"))

    return mismatches


In [None]:
def transfer_weights_from_smaller_to_larger(src_state_dict, dst_state_dict, verbose=True):
    """
    Transfer weights from a smaller model (e.g., T1-only) into a larger model (e.g., T1+FLAIR).

    Same-shape tensors are copied whole. For tensors of equal rank but different shape,
    only the overlapping leading region (indices 0..min(src, dst) along every dim) is
    copied; the rest of the destination tensor keeps its current values.

    The following keys are skipped with a warning:
    - keys absent from the destination;
    - keys whose tensors differ in rank. Their dims cannot be paired one-to-one, and
      slice-assigning them would broadcast the source into the wrong layout.

    Parameters:
    - src_state_dict: dict, state_dict from smaller model
    - dst_state_dict: dict, state_dict of the larger model; modified in place
    - verbose: bool, print details

    Returns:
    - List of keys that were partially or fully updated
    """
    updated = []

    for key, src_tensor in src_state_dict.items():
        if key not in dst_state_dict:
            if verbose:
                print(f"⚠️ Skipping {key}: not in destination model")
            continue

        dst_tensor = dst_state_dict[key]
        if src_tensor.shape == dst_tensor.shape:
            dst_state_dict[key] = src_tensor.clone()
            updated.append(key)
            if verbose:
                print(f"✅ {key}: copied full tensor")
            continue

        if src_tensor.ndim != dst_tensor.ndim:
            # zip() below would truncate to the shorter shape and the slice assignment
            # would broadcast, e.g. a (3,) source filling every row of a (3, 3) target.
            if verbose:
                print(f"⚠️ Skipping {key}: rank mismatch {tuple(src_tensor.shape)} vs {tuple(dst_tensor.shape)}")
            continue

        # Only copy the overlapping leading region; everything else in dst is kept.
        new_tensor = dst_tensor.clone()
        slices = tuple(slice(0, min(s1, s2)) for s1, s2 in zip(src_tensor.shape, dst_tensor.shape))
        new_tensor[slices] = src_tensor[slices]
        dst_state_dict[key] = new_tensor
        updated.append(key)
        if verbose:
            print(f"🔁 {key}: copied partial tensor {src_tensor.shape} → {dst_tensor.shape}")

    return updated
updated = transfer_weights_from_smaller_to_larger(
    src_state_dict=model_999_fold4["network_weights"],
    dst_state_dict=model_to_update["network_weights"],
)

In [None]:
def check_shared_weights_equal(src_state_dict, dst_state_dict, rtol=1e-5, atol=1e-8, verbose=True):
    """
    Check that the overlapping region of every src tensor equals the same region in dst.

    The overlap is the leading index range 0..min(src, dst) along each dim, i.e. the
    region a partial (crop/pad) copy writes. Keys absent from dst are reported (when
    verbose) but not counted as mismatches. Tensors of different rank have no
    well-defined overlap and are counted as mismatches. Without this check, zip()
    would truncate the slices and torch.allclose could broadcast them into a false match.

    Parameters:
    - src_state_dict: dict, e.g., from T1-only model
    - dst_state_dict: dict, e.g., from T1+FLAIR model
    - rtol, atol: float, relative and absolute tolerance passed to torch.allclose
    - verbose: bool, whether to print results

    Returns:
    - List of keys (names only) whose shared region differs or whose rank differs
    """
    mismatches = []

    for key, src_tensor in src_state_dict.items():
        if key not in dst_state_dict:
            if verbose:
                print(f"⚠️ Missing in destination: {key}")
            continue

        dst_tensor = dst_state_dict[key]
        if src_tensor.ndim != dst_tensor.ndim:
            if verbose:
                print(f"❌ Rank mismatch in {key}: {tuple(src_tensor.shape)} vs {tuple(dst_tensor.shape)}")
            mismatches.append(key)
            continue

        slices = tuple(slice(0, min(s1, s2)) for s1, s2 in zip(src_tensor.shape, dst_tensor.shape))
        src_sub = src_tensor[slices]
        dst_sub = dst_tensor[slices]

        if not torch.allclose(src_sub, dst_sub, rtol=rtol, atol=atol):
            if verbose:
                print(f"❌ Mismatch in {key} (shared shape {src_sub.shape})")
            mismatches.append(key)
        elif verbose:
            print(f"✅ {key}: shared weights match")

    return mismatches


In [None]:
mismatches = check_shared_weights_equal(
    src_state_dict=model_999_fold4["network_weights"],
    dst_state_dict=model_to_update["network_weights"],
)
print(f"Found {len(mismatches)} mismatches")

In [None]:
import torch

def transfer_weights_repeat_to_larger_T1_Flair(src_state_dict, dst_state_dict, verbose=True):
    """
    Transfer weights from a smaller model into a larger one by tiling each source tensor.

    For each key present in both dicts, same-shape tensors are copied whole. Otherwise
    the source is repeated along every dim where the destination is larger and then
    cropped to the destination shape. Dims where the destination is smaller are just
    cropped. For example, a first-conv weight (32, 1, 3, 3, 3) becomes (32, 2, 3, 3, 3)
    with the single input-channel kernel duplicated.

    The following keys are skipped with a warning:
    - keys absent from the destination;
    - tensors whose rank differs;
    - source tensors with a zero-size dim that would have to grow.

    NOTE(review): duplicated input-channel kernels are summed by the convolution. If the
    input modalities have similar scale, initial activations grow roughly by the repeat
    factor; consider dividing by it if that matters for fine-tuning.

    Parameters:
    - src_state_dict: dict, state_dict from smaller model
    - dst_state_dict: dict, state_dict of the larger model; modified in place
    - verbose: bool, print details

    Returns:
    - List of keys that were partially or fully updated
    """
    updated = []

    for key, src_tensor in src_state_dict.items():
        if key not in dst_state_dict:
            if verbose:
                print(f"⚠️ Skipping {key}: not in destination model")
            continue

        dst_tensor = dst_state_dict[key]
        if src_tensor.shape == dst_tensor.shape:
            dst_state_dict[key] = src_tensor.clone()
            updated.append(key)
            if verbose:
                print(f"✅ {key}: copied full tensor")
            continue

        new_shape = dst_tensor.shape
        src_shape = src_tensor.shape

        if len(src_shape) != len(new_shape):
            # Dims cannot be paired one-to-one; tiling/cropping would fail or mis-align.
            if verbose:
                print(f"⚠️ Skipping {key}: rank mismatch {src_shape} vs {new_shape}")
            continue
        if any(src == 0 and dst > 0 for src, dst in zip(src_shape, new_shape)):
            # An empty dim cannot be tiled up to a non-empty one (and would divide by zero).
            if verbose:
                print(f"⚠️ Skipping {key}: cannot repeat an empty dim {src_shape} → {new_shape}")
            continue

        # Ceil division so the tiled tensor always covers dst, even when dst is not an
        # exact multiple of src (floor division would leave it too short, e.g. 2 → 5
        # would yield 4). Dims where dst <= src get factor 1 and are only cropped.
        repeats = [-(-dst // src) if dst > src else 1 for src, dst in zip(src_shape, new_shape)]
        expanded_tensor = src_tensor.repeat(*repeats)
        # Crop the overshoot (and any dims where dst is smaller) to the exact dst shape.
        slices = tuple(slice(0, dst) for dst in new_shape)
        expanded_tensor = expanded_tensor[slices].clone()
        dst_state_dict[key] = expanded_tensor
        updated.append(key)
        if verbose:
            print(f"🔁 {key}: repeated tensor {src_shape} → {new_shape}")

    return updated

# Overwrite model_to_update's weights with model_999_fold4's, tiling tensors whose shapes
# grew (e.g. the first conv's input-channel dim, per the printed log below).
updated = transfer_weights_repeat_to_larger_T1_Flair(model_999_fold4['network_weights'], model_to_update['network_weights'])

# Save the entire loaded checkpoint object (not just network_weights) with the updated weights.
torch.save(model_to_update, "./updated_checkpoint.pth")




🔁 encoder.stem.convs.0.conv.weight: repeated tensor torch.Size([32, 1, 3, 3, 3]) → torch.Size([32, 2, 3, 3, 3])
✅ encoder.stem.convs.0.conv.bias: copied full tensor
✅ encoder.stem.convs.0.norm.weight: copied full tensor
✅ encoder.stem.convs.0.norm.bias: copied full tensor
🔁 encoder.stem.convs.0.all_modules.0.weight: repeated tensor torch.Size([32, 1, 3, 3, 3]) → torch.Size([32, 2, 3, 3, 3])
✅ encoder.stem.convs.0.all_modules.0.bias: copied full tensor
✅ encoder.stem.convs.0.all_modules.1.weight: copied full tensor
✅ encoder.stem.convs.0.all_modules.1.bias: copied full tensor
✅ encoder.stages.0.blocks.0.conv1.conv.weight: copied full tensor
✅ encoder.stages.0.blocks.0.conv1.conv.bias: copied full tensor
✅ encoder.stages.0.blocks.0.conv1.norm.weight: copied full tensor
✅ encoder.stages.0.blocks.0.conv1.norm.bias: copied full tensor
✅ encoder.stages.0.blocks.0.conv1.all_modules.0.weight: copied full tensor
✅ encoder.stages.0.blocks.0.conv1.all_modules.0.bias: copied full tensor
✅ encoder.

In [5]:
def check_repeated_weights_equal(src_state_dict, dst_state_dict, rtol=1e-5, atol=1e-8, verbose=True):
    """
    Check that each dst tensor equals its src tensor tiled (repeated) up to dst's shape.

    For every key present in both dicts, identical shapes are compared directly.
    Otherwise the expected tensor is built from src and compared with dst:
    - src is repeated along each dim where dst is larger, with enough copies to cover
      dst even when dst is not an exact multiple, then cropped;
    - src is cropped along dims where dst is smaller.

    Keys missing from dst are reported (when verbose) but not counted. Rank mismatches,
    and src dims of size 0 that would need to grow, are counted as mismatches, since
    no tiling can reproduce dst.

    Parameters:
    - src_state_dict: dict, state_dict of the smaller source model
    - dst_state_dict: dict, state_dict of the larger destination model
    - rtol, atol: float, tolerances passed to torch.allclose
    - verbose: bool, whether to print per-key results

    Returns:
    - List of keys with mismatches
    """
    mismatches = []

    for key, src_tensor in src_state_dict.items():
        if key not in dst_state_dict:
            if verbose:
                print(f"⚠️ Missing in destination: {key}")
            continue

        dst_tensor = dst_state_dict[key]
        src_shape = src_tensor.shape
        dst_shape = dst_tensor.shape

        # If shapes are the same, just compare directly
        if src_shape == dst_shape:
            if not torch.allclose(src_tensor, dst_tensor, rtol=rtol, atol=atol):
                if verbose:
                    print(f"❌ Mismatch in {key} (identical shape {src_shape})")
                mismatches.append(key)
            elif verbose:
                print(f"✅ {key}: identical weights match")
            continue

        # No tiling can produce dst when ranks differ or an empty src dim must grow;
        # count these as mismatches instead of crashing on index/zero-division errors.
        if len(src_shape) != len(dst_shape) or any(
            src == 0 and dst > 0 for src, dst in zip(src_shape, dst_shape)
        ):
            if verbose:
                print(f"❌ Mismatch in {key} (cannot tile {src_shape} to {dst_shape})")
            mismatches.append(key)
            continue

        # Ceil division so the tiled tensor covers dst even when dst is not a multiple
        # of src (floor division would make it too short); dims where dst <= src get 1.
        repeats = [-(-dst // src) if dst > src else 1 for src, dst in zip(src_shape, dst_shape)]
        expected_tensor = src_tensor.repeat(*repeats)
        # Crop the overshoot (and any dims where dst is smaller) to the exact dst shape.
        slices = tuple(slice(0, dst) for dst in dst_shape)
        expected_tensor = expected_tensor[slices]

        # Compare expected_tensor and actual dst_tensor
        if not torch.allclose(expected_tensor, dst_tensor, rtol=rtol, atol=atol):
            if verbose:
                print(f"❌ Mismatch in {key} (expected repeated shape {expected_tensor.shape}, got {dst_tensor.shape})")
            mismatches.append(key)
        elif verbose:
            print(f"✅ {key}: repeated weights match")

    return mismatches

mismatches = check_repeated_weights_equal(
    src_state_dict=model_999_fold4["network_weights"],
    dst_state_dict=model_to_update["network_weights"],
)

✅ encoder.stem.convs.0.conv.weight: repeated weights match
✅ encoder.stem.convs.0.conv.bias: identical weights match
✅ encoder.stem.convs.0.norm.weight: identical weights match
✅ encoder.stem.convs.0.norm.bias: identical weights match
✅ encoder.stem.convs.0.all_modules.0.weight: repeated weights match
✅ encoder.stem.convs.0.all_modules.0.bias: identical weights match
✅ encoder.stem.convs.0.all_modules.1.weight: identical weights match
✅ encoder.stem.convs.0.all_modules.1.bias: identical weights match
✅ encoder.stages.0.blocks.0.conv1.conv.weight: identical weights match
✅ encoder.stages.0.blocks.0.conv1.conv.bias: identical weights match
✅ encoder.stages.0.blocks.0.conv1.norm.weight: identical weights match
✅ encoder.stages.0.blocks.0.conv1.norm.bias: identical weights match
✅ encoder.stages.0.blocks.0.conv1.all_modules.0.weight: identical weights match
✅ encoder.stages.0.blocks.0.conv1.all_modules.0.bias: identical weights match
✅ encoder.stages.0.blocks.0.conv1.all_modules.1.weight: 