In [None]:
"""
Step-by-step visualization of MONAI transforms applied to a 2D medical image and its label.
Each transform's input and output are plotted side by side, with error handling for varied image layouts and types.
"""

import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import monai
from monai.data import PILReader
from monai.transforms import (
    LoadImaged,
    EnsureChannelFirstd,
    ScaleIntensityd,
    SpatialPadd,
    RandSpatialCropd,
    RandAxisFlipd,
    RandRotate90d,
    RandGaussianNoised,
    RandAdjustContrastd,
    RandGaussianSmoothd,
    RandHistogramShiftd,
    RandZoomd,
    EnsureTyped,
    Compose,
)

In [None]:
def print_data_info(data_dict, stage=""):
    """Print shape, dtype and value range of every array-like entry in a data dict.

    Args:
        data_dict: Mapping of keys to values (NumPy arrays, torch tensors, file paths, ...).
            Values that are neither arrays nor tensors only have their Python type printed.
        stage: Free-text label included in the header line, e.g. "before transform".

    Empty arrays and tensors report ``("N/A", "N/A")`` as their range instead of raising
    (``ndarray.min()`` raises ``ValueError`` on an empty array).
    """
    print(f"\n--- Data Info {stage} ---")
    for key, value in data_dict.items():
        if isinstance(value, np.ndarray):
            shape = value.shape
            minmax = (value.min(), value.max()) if value.size > 0 else ("N/A", "N/A")
        elif isinstance(value, torch.Tensor):
            shape = tuple(value.shape)
            minmax = (value.min().item(), value.max().item()) if value.numel() > 0 else ("N/A", "N/A")
        else:
            print(f"{key}: Type={type(value)}")
            continue
        print(f"{key}: Shape={shape}, Type={value.dtype}, Min/Max={minmax}")

def _to_numpy(value):
    """Return ``value`` as a NumPy array if it is a tensor or array, otherwise ``None``."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    if isinstance(value, np.ndarray):
        return value
    return None


def _rgb_for_display(rgb):
    """Convert an HxWx3 / HxWx4 array to a dtype and range ``imshow`` renders without clipping.

    uint8 passes through. Integer data already within [0, 255] is scaled to [0, 1]. Any
    other data outside [0, 1] (e.g. float images after additive noise) is min-max
    rescaled, mirroring the autoscaling ``imshow`` already applies to grayscale panels.
    """
    if rgb.dtype == np.uint8:
        return rgb
    was_integer = np.issubdtype(rgb.dtype, np.integer)
    rgb = rgb.astype(np.float32)
    lo, hi = float(rgb.min()), float(rgb.max())
    if was_integer and lo >= 0.0 and hi <= 255.0:
        return rgb / 255.0
    if lo < 0.0 or hi > 1.0:
        rgb = (rgb - lo) / (hi - lo) if hi > lo else np.zeros_like(rgb)
    return rgb


def _prepare_for_display(arr, is_label):
    """Reduce ``arr`` to something ``Axes.imshow`` accepts.

    Axes beyond the third are treated as batch-like and index 0 is kept. A 3D array is
    treated as channel-first, unless its first axis is large (> 4) and its last axis is
    small (<= 4), in which case it is treated as channel-last (e.g. a raw HxWxC load).
    Non-label arrays with 3 or 4 channels are shown as RGB(A); otherwise channel 0 is shown.

    Returns:
        ``(image, is_rgb, note)``. ``image`` is None when nothing drawable remains (empty,
        or fewer than 2 dims). ``note`` is a title suffix set when only channel 0 is shown.
    """
    note = ""
    if arr.size == 0:
        return None, False, note
    while arr.ndim > 3:
        arr = arr[0]
    if arr.ndim < 2:
        return None, False, note
    if arr.ndim == 3:
        if arr.shape[0] > 4 and arr.shape[-1] <= 4:
            arr = np.moveaxis(arr, -1, 0)
        if not is_label and arr.shape[0] in (3, 4):
            return _rgb_for_display(np.transpose(arr, (1, 2, 0))), True, note
        if arr.shape[0] > 1:
            note = " (Channel 0)"
        arr = arr[0]
    # Convert dtypes that imshow may not accept for scalar (colormapped) data.
    if arr.dtype == bool:
        arr = arr.astype(np.uint8)
    elif arr.dtype == np.float16:
        arr = arr.astype(np.float32)
    return arr, False, note


def _draw_panel(ax, value, title, is_label, empty_text):
    """Draw one image/label panel on ``ax``, or a text placeholder when it can't be drawn.

    ``value`` may be None (shows ``empty_text``), a file path (shows its basename), or a
    NumPy array / torch tensor. Labels use the 'viridis' colormap, images use 'gray'.
    """
    if value is None:
        ax.text(0.5, 0.5, empty_text, ha='center', va='center')
    elif isinstance(value, str):
        ax.text(0.5, 0.5, f"File path: {os.path.basename(value)}",
                ha='center', va='center', wrap=True)
    else:
        arr = _to_numpy(value)
        image, is_rgb, note = (None, False, "") if arr is None else _prepare_for_display(arr, is_label)
        if image is None:
            what = f"shape {arr.shape}" if arr is not None else f"type {type(value).__name__}"
            ax.text(0.5, 0.5, f"Cannot display {what}", ha='center', va='center', wrap=True)
        elif is_rgb:
            ax.imshow(image)
        else:
            ax.imshow(image, cmap='viridis' if is_label else 'gray')
        title += note
    ax.set_title(title)
    ax.axis('off')


def _describe_entry(value, stage, name, summary=None):
    """Return info-panel text describing one entry (empty string if nothing to report).

    Args:
        value: File path, NumPy array, torch tensor, or anything else (ignored).
        stage: "Original" or "Transformed"; combined with ``name`` as the line heading.
        name: "Image" or "Label".
        summary: None, "range" (append min/max), or "unique" (append unique values).
    """
    heading = f"{stage} {name}"
    if isinstance(value, str):
        return f"{heading}: File path\n"
    arr = _to_numpy(value)
    if arr is None:
        return ""
    text = f"{heading} Shape: {tuple(value.shape)}\n{heading} Type: {value.dtype}\n"
    if summary == "range":
        if arr.size > 0:
            # float() so the .4f format also works for bool/integer scalars.
            text += f"{name} Value Range: [{float(arr.min()):.4f}, {float(arr.max()):.4f}]\n"
        else:
            text += f"{name} Value Range: empty\n"
    elif summary == "unique":
        unique_values = np.unique(arr)
        if len(unique_values) <= 10:  # Only list them if there aren't too many
            text += f"{name} Unique Values: {unique_values}\n"
        else:
            text += f"{name} Unique Values: {len(unique_values)} different values\n"
    return text


def visualize_transform(data_item, transform, title):
    """Apply ``transform`` to a copy of ``data_item`` and plot before/after panels.

    The figure shows the original and transformed ``"img"`` / ``"label"`` entries. It also
    has an info panel with shapes, dtypes, the transformed image's value range, the
    transformed label's unique values, and a description of the transform. Data summaries
    are also printed before and after the transform.

    Args:
        data_item: Dict with optional ``"img"`` and ``"label"`` entries (file paths, NumPy
            arrays or torch tensors). It is not modified: arrays and tensors are copied
            before the transform runs.
        transform: Callable taking and returning a dict (e.g. a MONAI dictionary
            transform or ``Compose``).
        title: Name used in the log output and as the figure title.

    Returns:
        The transformed dict, or ``data_item`` itself if anything fails (the error and
        traceback are printed), so a step-by-step pipeline can keep going.
    """
    fig = None
    try:
        # Copy arrays *and* tensors so an in-place transform cannot alter the caller's
        # data, which is also what the "Original" panels display.
        data_copy = {
            k: v.clone() if isinstance(v, torch.Tensor) else v.copy() if isinstance(v, np.ndarray) else v
            for k, v in data_item.items()
        }

        print(f"\n=== Applying: {title} ===")
        print_data_info(data_copy, "before transform")
        result = transform(data_copy)
        print_data_info(result, "after transform")

        fig = plt.figure(figsize=(15, 8))
        gs = GridSpec(2, 3, figure=fig)
        _draw_panel(fig.add_subplot(gs[0, 0]), data_item.get("img"), "Original Image", False, "No image data")
        _draw_panel(fig.add_subplot(gs[0, 1]), data_item.get("label"), "Original Label", True, "No label data")
        _draw_panel(fig.add_subplot(gs[1, 0]), result.get("img"), "Transformed Image", False, "No transformed image")
        _draw_panel(fig.add_subplot(gs[1, 1]), result.get("label"), "Transformed Label", True, "No transformed label")

        info_text = f"Transform: {title}\n\n"
        info_text += _describe_entry(data_item.get("img"), "Original", "Image")
        info_text += _describe_entry(data_item.get("label"), "Original", "Label")
        info_text += _describe_entry(result.get("img"), "Transformed", "Image", summary="range")
        info_text += _describe_entry(result.get("label"), "Transformed", "Label", summary="unique")

        info_text += "\nTransform Parameters:\n"
        if hasattr(transform, "transform_params"):
            for key, value in transform.transform_params.items():
                info_text += f"{key}: {value}\n"
        else:
            info_text += str(transform)[:200] + "...\n"  # Abbreviated transform description

        ax_info = fig.add_subplot(gs[:, 2])
        ax_info.axis('off')
        ax_info.text(0.1, 0.9, info_text, va='top', ha='left', wrap=True, fontsize=8)

        fig.tight_layout()
        fig.suptitle(title, fontsize=16)
        fig.subplots_adjust(top=0.9)
        plt.show()

        return result

    except Exception as e:
        import traceback

        print(f"\n!!! Error visualizing '{title}' transform !!!")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {str(e)}")
        print("\nTraceback:")
        traceback.print_exc()

        # Try to print more debug info about the data
        print("\nData item keys:", list(data_item.keys()))
        for key in data_item:
            if isinstance(data_item[key], (np.ndarray, torch.Tensor)):
                print(f"{key} shape:", data_item[key].shape if isinstance(data_item[key], np.ndarray)
                      else tuple(data_item[key].shape))

        # Return original data to continue pipeline
        return data_item

    finally:
        # Release the figure so repeated calls don't accumulate open figures; on the error
        # path this also keeps a half-drawn figure from being rendered later.
        if fig is not None:
            plt.close(fig)

In [None]:
def main(input_size=256, img_path=None, gt_path=None, seed=None):
    """Interactively pick an image (and its label) and visualize each MONAI transform on it.

    The transforms below are applied and plotted one at a time, each step receiving the
    previous step's output. Finally the whole list is applied at once as a ``Compose``.

    Args:
        input_size: Spatial size passed to ``SpatialPadd`` and used as ``roi_size`` for
            ``RandSpatialCropd``.
        img_path: Image directory. When None the user is prompted; an empty answer falls
            back to ``$MONAI_VIS_IMG_DIR`` if set, else to the built-in default path.
        gt_path: Label directory, resolved the same way (``$MONAI_VIS_LABEL_DIR``).
        seed: Optional integer. When given, every transform exposing ``set_random_state``
            is seeded with it so runs are reproducible.

    Labels are looked up as ``<gt_path>/<image stem>_label.png``.
    """
    import traceback  # used by the error handler at the end; not imported at notebook level

    print("\n=== MONAI Transform Visualization Tool ===")
    # NOTE: machine-specific defaults; override them via the arguments or environment variables.
    default_img_path = os.environ.get(
        "MONAI_VIS_IMG_DIR",
        "C:\\Users\\Samir\\Documents\\GitHub\\IFT3710-Advanced-Project-in-ML-AI\\notebooks\\preprocessing_outputs\\images",
    )
    default_gt_path = os.environ.get(
        "MONAI_VIS_LABEL_DIR",
        "C:\\Users\\Samir\\Documents\\GitHub\\IFT3710-Advanced-Project-in-ML-AI\\notebooks\\preprocessing_outputs\\labels",
    )

    if img_path is None:
        img_path = input(f"Enter the image directory path [default: {default_img_path}]: ") or default_img_path
    if gt_path is None:
        gt_path = input(f"Enter the label directory path [default: {default_gt_path}]: ") or default_gt_path

    # isdir rather than exists: os.listdir below fails if the path is a regular file.
    if not os.path.isdir(img_path):
        print(f"Error: Image directory '{img_path}' does not exist.")
        return

    has_labels = os.path.isdir(gt_path)
    if not has_labels:
        print(f"Warning: Label directory '{gt_path}' does not exist. Will proceed without labels.")

    # NOTE(review): .nii/.nii.gz files are listed, but LoadImaged below uses PILReader and
    # dtype=np.uint8, which target 2D bitmap images; NIfTI input presumably needs a
    # different reader/dtype. TODO confirm.
    extensions = ('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp', '.nii', '.nii.gz')
    img_names = sorted(f for f in os.listdir(img_path) if f.lower().endswith(extensions))

    if not img_names:
        print("No images found in the specified directory.")
        return

    print("\nAvailable images:")
    for i, name in enumerate(img_names, start=1):
        print(f"{i}. {name}")

    raw_selection = input(f"\nSelect image number to visualize (1-{len(img_names)}): ")
    try:
        selection = int(raw_selection) - 1
    except ValueError:
        print("Invalid selection.")
        return
    if not 0 <= selection < len(img_names):
        print("Invalid selection.")
        return

    selected_img = img_names[selection]
    # Strip only the extension, so names containing dots ("case.01.png") keep their full
    # stem; ".nii.gz" is a double extension that splitext would only half-remove.
    if selected_img.lower().endswith(".nii.gz"):
        stem = selected_img[:-len(".nii.gz")]
    else:
        stem = os.path.splitext(selected_img)[0]
    selected_gt = f"{stem}_label.png"  # Assuming this naming convention

    label_path = os.path.join(gt_path, selected_gt)
    label_exists = has_labels and os.path.exists(label_path)

    if not label_exists:
        print(f"Warning: Label file not found: {label_path}")
        proceed = input("Do you want to continue without a label? (y/n): ")
        if proceed.lower() != 'y':
            return

    # Omit "label" entirely when there is no label file: allow_missing_keys only skips
    # keys that are absent, so a None value would be handed to the image reader and fail.
    data_item = {"img": os.path.join(img_path, selected_img)}
    if label_exists:
        data_item["label"] = label_path

    transforms = [
        # Basic loading and preprocessing
        ("Load Image", LoadImaged(keys=["img", "label"], reader=PILReader, dtype=np.uint8, allow_missing_keys=True)),

        # Handle channel dimensions. channel_dim is left at its default (None) so the
        # original channel dim recorded in the loaded image's metadata is used;
        # "auto" is not a documented value for this argument.
        ("Ensure Channel First", EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True)),

        # Normalize and scale image intensities
        ("Scale Intensity", ScaleIntensityd(keys=["img"], allow_missing_keys=True)),

        # Ensure consistent spatial dimensions
        ("Spatial Padding", SpatialPadd(keys=["img", "label"], spatial_size=input_size, allow_missing_keys=True)),
        ("Random Spatial Crop", RandSpatialCropd(keys=["img", "label"], roi_size=input_size, random_size=True, allow_missing_keys=True)),

        # Augmentations (prob=1.0 so every step visibly does something)
        ("Random Axis Flip", RandAxisFlipd(keys=["img", "label"], prob=1.0, allow_missing_keys=True)),
        ("Random Rotate 90", RandRotate90d(keys=["img", "label"], prob=1.0, spatial_axes=[0, 1], allow_missing_keys=True)),
        ("Random Gaussian Noise", RandGaussianNoised(keys=["img"], prob=1.0, mean=0, std=0.1, allow_missing_keys=True)),
        ("Random Adjust Contrast", RandAdjustContrastd(keys=["img"], prob=1.0, gamma=1.5, allow_missing_keys=True)),
        ("Random Gaussian Smooth", RandGaussianSmoothd(keys=["img"], prob=1.0, sigma_x=(0.5, 1.0), allow_missing_keys=True)),
        ("Random Histogram Shift", RandHistogramShiftd(keys=["img"], prob=1.0, num_control_points=3, allow_missing_keys=True)),

        # Random zoom: bilinear for the image, nearest for the label (keeps label values discrete)
        ("Random Zoom", RandZoomd(
            keys=["img", "label"],
            prob=1.0,
            min_zoom=0.8,
            max_zoom=1.2,
            mode=("bilinear", "nearest"),
            keep_size=True,
            allow_missing_keys=True
        )),

        # Ensure proper tensor types
        ("Ensure Type", EnsureTyped(keys=["img", "label"], allow_missing_keys=True))
    ]

    if seed is not None:
        for _, transform in transforms:
            if hasattr(transform, "set_random_state"):
                transform.set_random_state(seed=seed)

    # Visualize each transform step by step, feeding each output into the next step
    print("\n=== Starting transform visualization ===")
    current_data = data_item
    for name, transform in transforms:
        current_data = visualize_transform(current_data, transform, name)

    # Show the full transformation pipeline applied at once to the untouched input
    try:
        print("\n=== Applying full transformation pipeline... ===")
        full_transforms = Compose([t for _, t in transforms])
        visualize_transform(data_item, full_transforms, "Complete Pipeline")
    except Exception as e:
        print(f"Error applying full pipeline: {str(e)}")
        traceback.print_exc()

    print("\n=== Transform visualization complete ===")

In [None]:
# Run the interactive tool only when executed directly (a notebook kernel runs cells as
# "__main__"), not when this code is imported, since main() prompts for input.
if __name__ == "__main__":
    main()