In [25]:
def animate_image_and_mask_compare(
    img_orig, mask_orig, img_aug, mask_aug,
    title="Image + Mask Animation (Original vs Predicted)",
    alpha=0.4, interval=100, threshold=0.0,
):
    """
    Animate two 3D volumes slice by slice, each with a red mask overlay, side by side.

    The left panel shows (img_orig, mask_orig) labelled "Original"; the right panel
    shows (img_aug, mask_aug) labelled "Predicted". Both panels step through the same
    slice index along the last axis, starting the static preview at the middle slice.

    Args:
        img_orig (array-like): Left-panel 3D image, slices along the last axis ([H, W, Z]).
        mask_orig (array-like): Left-panel 3D mask, same shape as ``img_orig``.
        img_aug (array-like): Right-panel 3D image ([H', W', Z]); Z must equal ``img_orig``'s.
        mask_aug (array-like): Right-panel 3D mask, same shape as ``img_aug``.
        title (str): Figure-level title (suptitle).
        alpha (float): Opacity of the red overlay on mask voxels.
        interval (int): Delay between frames in ms.
        threshold (float): A mask voxel is drawn when its value is > ``threshold``.
            The default 0.0 suits binary masks (or logits); for probability maps use
            e.g. 0.5, since nearly every voxel of a probability map is > 0.

    Returns:
        IPython.display.HTML: JavaScript animation for Jupyter display.

    Raises:
        ValueError: If an input is not 3D, a mask's shape differs from its image,
            the two images have a different number of slices, or there are no slices.
    """
    import numpy as np

    img_orig, mask_orig, img_aug, mask_aug = (
        np.asarray(arr) for arr in (img_orig, mask_orig, img_aug, mask_aug)
    )

    # Validate before any plotting so mistakes fail fast with a clear message
    # instead of an IndexError partway through rendering the frames.
    named = {"img_orig": img_orig, "mask_orig": mask_orig, "img_aug": img_aug, "mask_aug": mask_aug}
    for name, arr in named.items():
        if arr.ndim != 3:
            raise ValueError(f"{name} must be 3D [H, W, Z], got shape {arr.shape}")
    if mask_orig.shape != img_orig.shape:
        raise ValueError(f"mask_orig shape {mask_orig.shape} != img_orig shape {img_orig.shape}")
    if mask_aug.shape != img_aug.shape:
        raise ValueError(f"mask_aug shape {mask_aug.shape} != img_aug shape {img_aug.shape}")
    n_slices = img_orig.shape[2]
    if img_aug.shape[2] != n_slices:
        raise ValueError(
            f"img_orig and img_aug must have the same number of slices, "
            f"got {n_slices} and {img_aug.shape[2]}"
        )
    if n_slices == 0:
        raise ValueError("volumes have no slices along the last axis")

    import matplotlib.pyplot as plt
    from matplotlib import animation
    from IPython.display import HTML

    def red_overlay(mask_slice):
        """Return an RGBA float32 image: red everywhere, opaque (alpha) only on mask voxels."""
        overlay = np.zeros((*mask_slice.shape, 4), dtype=np.float32)
        overlay[..., 0] = 1.0
        overlay[..., 3] = (mask_slice > threshold) * alpha
        return overlay

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    fig.suptitle(title)
    start_idx = n_slices // 2

    # One entry per panel: (axes, label, image volume, mask volume, image artist, mask artist).
    panels = []
    for ax, label, img, mask in (
        (axes[0], "Original", img_orig, mask_orig),
        (axes[1], "Predicted", img_aug, mask_aug),
    ):
        # Fix the gray range to the whole volume: imshow would otherwise autoscale
        # from the first slice only, and set_data() never rescales, so other slices
        # would be clipped or washed out.
        vmin, vmax = float(np.nanmin(img)), float(np.nanmax(img))
        ax.set_title(f"{label} (slice {start_idx})")
        img_artist = ax.imshow(img[:, :, start_idx], cmap='gray', vmin=vmin, vmax=vmax)
        mask_artist = ax.imshow(red_overlay(mask[:, :, start_idx]))
        ax.axis('off')
        panels.append((ax, label, img, mask, img_artist, mask_artist))

    def update(i):
        """Show slice ``i`` in both panels and return the updated image artists."""
        artists = []
        for ax, label, img, mask, img_artist, mask_artist in panels:
            img_artist.set_data(img[:, :, i])
            mask_artist.set_data(red_overlay(mask[:, :, i]))
            ax.set_title(f"{label} (slice {i})")
            artists += [img_artist, mask_artist]
        return artists

    # blit=False: the per-panel titles change every frame but are not among the
    # artists returned by update(); with blitting only returned artists are redrawn.
    ani = animation.FuncAnimation(
        fig, update, frames=n_slices, interval=interval, blit=False
    )
    # Close the figure so the inline backend does not also display a static copy.
    plt.close(fig)
    return HTML(ani.to_jshtml())

In [38]:
import torch
from utils.constants import *
from model_definitions.UNet import UNet3D

from utils.MRIDataset import MRIDataset
from utils.transformations import MRIAugmentationPipeline
from torch.utils.data import DataLoader

from itertools import islice

# --------------------------
# Load the training dataset (no augmentations) and fetch one fixed sample
# --------------------------
TRAIN_DATASET = MRIDataset(f"../{TRAIN_IMG_DIR}", f"../{TRAIN_LABEL_FILE}", is_train=True, augmentations=None)
train_loader = DataLoader(TRAIN_DATASET, batch_size=1, shuffle=False)

SAMPLE_INDEX = 10         # loader is not shuffled, so this always selects the same sample
SHOW_PREDICTIONS = False  # set True to animate each model's predicted mask for the sample

# Fetched once, outside the model loop: islice has to load every preceding
# sample, so doing this per model would repeat that work each time.
img, mask, label = next(islice(train_loader, SAMPLE_INDEX, None))


def to_numpy_volume(t):
    """Convert a batched tensor to a 3D numpy volume for animate_image_and_mask_compare.

    Drops the two leading singleton dims and reverses the remaining axes so the
    axis the animation steps through is last. Assumes a [1, 1, A, B, C]-shaped
    tensor — TODO confirm against MRIDataset / UNet3D output shapes.
    .float() first because numpy cannot represent every torch dtype (e.g. bfloat16).
    """
    return t.detach().float().cpu().squeeze(0).squeeze(0).permute(2, 1, 0).numpy()


models = ['trial_32', 'trial_46', 'trial_27', 'trial_43', 'trial_42']   # same 5 top models

for model_name in models:
    print(f"\nEvaluating model: {model_name}")

    # Load one checkpoint per iteration so only one state dict is resident on DEVICE.
    # weights_only=True: the file is loaded as a plain state dict, so no arbitrary
    # unpickling is needed.
    state_dict = torch.load(
        f"../{EXPERIMENT_DIR}/UNet3D/{model_name}/model.pt", map_location=DEVICE, weights_only=True
    )
    model = UNet3D(depth=5, base_filters=16, clf_threshold=[0.5, 0.5, 0.5]).to(DEVICE, DTYPE)
    model.load_state_dict(state_dict)
    model.eval()

    # Summary statistics of the classifier head's parameters. no_grad keeps the
    # concatenated tensor out of the autograd graph; .item() prints plain floats.
    with torch.no_grad():
        clf_weights = torch.cat([p.flatten() for p in model.classifier.parameters()])
    print(clf_weights.shape)
    print(
        f"mean={clf_weights.mean().item():.4f} std={clf_weights.std().item():.4f} "
        f"min={clf_weights.min().item():.4f} max={clf_weights.max().item():.4f}"
    )

    if SHOW_PREDICTIONS:
        with torch.no_grad():
            _, mask_pred = model(img.to(DEVICE, DTYPE))
        img_np, mask_np = to_numpy_volume(img), to_numpy_volume(mask)
        display(animate_image_and_mask_compare(img_np, mask_np, img_np, to_numpy_volume(mask_pred)))




Evaluating model: trial_32
torch.Size([25155])
tensor(0.0007, device='cuda:0', grad_fn=<MeanBackward0>) tensor(0.0511, device='cuda:0', grad_fn=<StdBackward0>) tensor(-0.1290, device='cuda:0', grad_fn=<MinBackward1>) tensor(0.1337, device='cuda:0', grad_fn=<MaxBackward1>)

Evaluating model: trial_46
torch.Size([25155])
tensor(0.0007, device='cuda:0', grad_fn=<MeanBackward0>) tensor(0.0509, device='cuda:0', grad_fn=<StdBackward0>) tensor(-0.1315, device='cuda:0', grad_fn=<MinBackward1>) tensor(0.1320, device='cuda:0', grad_fn=<MaxBackward1>)

Evaluating model: trial_27
torch.Size([25155])
tensor(0.0008, device='cuda:0', grad_fn=<MeanBackward0>) tensor(0.0512, device='cuda:0', grad_fn=<StdBackward0>) tensor(-0.1290, device='cuda:0', grad_fn=<MinBackward1>) tensor(0.1377, device='cuda:0', grad_fn=<MaxBackward1>)

Evaluating model: trial_43
torch.Size([25155])
tensor(0.0005, device='cuda:0', grad_fn=<MeanBackward0>) tensor(0.0514, device='cuda:0', grad_fn=<StdBackward0>) tensor(-0.1378, d