In [1]:
import os
import glob
import torch
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider, Dropdown
import warnings

# --- Suppress Matplotlib UserWarnings about figure layout ---
# Matplotlib reports most user-facing warnings (including the tight_layout
# ones) via ``matplotlib._api.warn_external``, which attributes the warning to
# the first stack frame *outside* matplotlib -- i.e. to this notebook's code.
# A ``module="matplotlib"`` filter therefore does not match those warnings, so
# the layout warnings are additionally filtered by message text (the message
# regex is matched case-insensitively against the start of the message).
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
warnings.filterwarnings(
    "ignore",
    message=r".*(tight.?layout|layout has changed)",
    category=UserWarning,
)

# =====================================================================================
# Configuration
# =====================================================================================

# IMPORTANT: Set this to the directory where you saved your preprocessed .pt files.
# This should match the `save_dir` from your spoof.py script.
PROCESSED_DATA_DIR = r"G:\My Drive\InputScans_Final"

# =====================================================================================
# Helper Functions
# =====================================================================================

def load_patient_data(patient_id, data_dir):
    """Load one patient's preprocessed image and label tensors onto the CPU.

    Reads ``<data_dir>/<patient_id>.pt`` and pulls the ``'image'`` and
    ``'label'`` entries out of the loaded object. The image is cast to
    float32 with ``.float()``; the label is returned unchanged.

    Args:
        patient_id: File stem of the patient's ``.pt`` file.
        data_dir: Directory containing the preprocessed files.

    Returns:
        ``(image_tensor, label_tensor)`` on success, otherwise ``(None, None)``.
        Failures (missing file, unreadable file, missing keys, ...) are
        reported with ``print`` instead of being raised.
    """
    pt_path = os.path.join(data_dir, f"{patient_id}.pt")
    missing = (None, None)

    # Guard clause: report a missing file up front instead of letting torch.load fail.
    if not os.path.exists(pt_path):
        print(f"Error: Processed file not found for patient '{patient_id}' at {pt_path}")
        return missing

    try:
        # map_location forces every stored tensor onto the CPU.
        # NOTE(review): torch.load unpickles the file -- only load trusted .pt files.
        payload = torch.load(pt_path, map_location=torch.device('cpu'))
        # Cast to float32 for matplotlib; presumably stored as float16 by the
        # preprocessing step (per the original comment) -- confirm.
        img = payload['image'].float()
        lbl = payload['label']
        # NOTE(review): the axis names in these messages are assumed, not checked
        # against the tensors -- confirm the layout with the preprocessing script.
        print(f"Successfully loaded Patient '{patient_id}'")
        print(f"  Image tensor shape: {img.shape} | (Channels, Depth, Height, Width)")
        print(f"  Label tensor shape: {lbl.shape} | (Depth, Height, Width)")
        return img, lbl
    except Exception as err:
        # Deliberate best-effort: report and signal failure to the caller via (None, None).
        print(f"An error occurred while loading {pt_path}: {err}")
        return missing

def plot_slice(image_slice, label_slice, title="", *, label_vmin=None, label_vmax=None):
    """Plot a single 2D slice with its label map overlaid.

    Args:
        image_slice: 2D intensity array, drawn with the 'bone' colormap.
        label_slice: 2D label array drawn on top of the image (expected to
            have the same shape as ``image_slice``). Pixels equal to 0 are
            masked out, so the background stays transparent.
        title: Figure title.
        label_vmin, label_vmax: Optional fixed colour limits for the label
            overlay. When omitted (the default) matplotlib rescales the
            colormap to the label values present in *this* slice, so the same
            label value can get a different colour on different slices; pass
            the volume-wide label range to keep colours consistent.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    # Mask the label so that the background (value 0) is not shown
    masked_label = np.ma.masked_where(label_slice == 0, label_slice)

    _fig, ax = plt.subplots(1, 1, figsize=(7, 7), dpi=100)
    ax.imshow(image_slice, cmap='bone')  # Display the scan slice with the 'bone' colormap
    # Overlay the segmentation in colour. Nearest-neighbour resampling keeps the
    # label map discrete: the default interpolation may smooth the overlay when
    # it is not upsampled by a large factor, blending neighbouring label values
    # into colours that correspond to no real label.
    ax.imshow(
        masked_label,
        cmap='autumn',
        alpha=0.6,
        interpolation='nearest',
        vmin=label_vmin,
        vmax=label_vmax,
    )
    ax.set_title(title, fontsize=14)
    ax.axis('off')  # Hide axes
    plt.tight_layout()
    plt.show()

# =====================================================================================
# Main Interactive Visualizer
# =====================================================================================

def visualize_patient(patient_id, data_dir=PROCESSED_DATA_DIR):
    """Create an interactive slice viewer for one patient's preprocessed scan.

    Loads ``<data_dir>/<patient_id>.pt`` with ``load_patient_data`` and, if
    both tensors are available, displays three linked widgets -- view plane,
    channel and slice index -- that redraw the chosen 2D slice with its label
    overlay through ``plot_slice``.

    Args:
        patient_id: File stem of the patient's ``.pt`` file.
        data_dir: Directory holding the preprocessed ``.pt`` files.

    Returns:
        None. The widgets are displayed as a side effect; if loading fails
        nothing is displayed (``load_patient_data`` prints the reason).
    """
    image, label = load_patient_data(patient_id, data_dir)
    if image is None or label is None:
        return

    # Channel mapping based on your preprocessing script. Only channels that
    # exist in the loaded tensor are offered, so a selection can never index
    # past the channel dimension.
    known_channels = {
        '0: t1c': 0,
        '1: t1n': 1,
        '2: t2f': 2,
        '3: t2w': 3,
        '4: Gradient Map': 4,
    }
    channel_map = {name: idx for name, idx in known_channels.items() if idx < image.shape[0]}
    if not channel_map:
        print(f"Error: image tensor for patient '{patient_id}' has no channels to display.")
        return

    # Slice counts along each plane's slicing axis; the image is indexed as
    # image[channel, dim1, dim2, dim3].
    # NOTE(review): dim1 is treated as Axial, dim2 as Coronal and dim3 as
    # Sagittal; confirm this matches the axis order written by the
    # preprocessing script.
    depth, height, width = image.shape[1:]
    plane_sizes = {'Axial': depth, 'Sagittal': width, 'Coronal': height}
    # Remember the slice per plane so switching views and back returns to the
    # slice that was being viewed.
    last_slice = {plane: size // 2 for plane, size in plane_sizes.items()}

    lbl = label.numpy()  # the label volume does not depend on any widget

    def viewer(view_plane, channel_name, slice_idx):
        """Draw slice ``slice_idx`` of the chosen channel in the chosen plane."""
        img_channel = image[channel_map[channel_name]].numpy()
        if view_plane == 'Axial':
            image_slice = img_channel[slice_idx, :, :]
            label_slice = lbl[slice_idx, :, :]
        elif view_plane == 'Sagittal':
            image_slice = img_channel[:, :, slice_idx].T  # Transpose for correct orientation
            label_slice = lbl[:, :, slice_idx].T
        else:  # Coronal
            image_slice = img_channel[:, slice_idx, :].T
            label_slice = lbl[:, slice_idx, :].T
        title = f"Patient: {patient_id} | {view_plane} Slice: {slice_idx} | Channel: {channel_name}"
        plot_slice(image_slice, label_slice, title)

    view_dropdown = Dropdown(options=['Axial', 'Sagittal', 'Coronal'], value='Axial', description='View:')
    channel_dropdown = Dropdown(options=list(channel_map), value=next(iter(channel_map)), description='Channel:')
    # A single slider shared by all planes; its range is switched on view change.
    slice_slider = IntSlider(
        min=0, max=depth - 1, step=1, value=last_slice['Axial'],
        description='Slice:', continuous_update=False,
    )

    def on_view_change(change):
        """Store the old plane's slice and re-range the slider for the new plane."""
        last_slice[change['old']] = slice_slider.value
        # Lowering max clamps the value if needed; the dropdown already holds
        # the new plane, so any redraw triggered here uses a valid index.
        slice_slider.max = plane_sizes[change['new']] - 1
        slice_slider.value = last_slice[change['new']]

    # Registered before interact() so the slider range is updated before
    # interact's own redraw for the view change runs.
    view_dropdown.observe(on_view_change, names='value')

    # Passing widget instances (rather than plain strings, which interact would
    # turn into editable Text boxes) builds a single interactive control that
    # reuses exactly these three widgets.
    interact(viewer, view_plane=view_dropdown, channel_name=channel_dropdown, slice_idx=slice_slider)


# =====================================================================================
# Execution Block
# =====================================================================================

def _running_in_notebook():
    """Return True when running inside an IPython kernel (e.g. Jupyter).

    ``get_ipython`` only exists as a builtin while IPython is active, so a
    plain ``python script.py`` run gets False here instead of a NameError.
    The shell's whole class hierarchy is checked so subclasses of ipykernel's
    shell class are detected too.
    """
    try:
        shell = get_ipython()  # noqa: F821 -- injected into builtins by IPython
    except NameError:
        return False
    return any('ipykernel' in cls.__module__ for cls in type(shell).__mro__)


if __name__ == "__main__" and _running_in_notebook():
    if not os.path.exists(PROCESSED_DATA_DIR):
        print(f"Error: The directory '{PROCESSED_DATA_DIR}' does not exist.")
        print("Please update the 'PROCESSED_DATA_DIR' variable to the correct path.")
    else:
        # Find available patient IDs from the filenames
        patient_files = glob.glob(os.path.join(PROCESSED_DATA_DIR, "*.pt"))
        if not patient_files:
            print(f"No preprocessed '.pt' files found in '{PROCESSED_DATA_DIR}'.")
        else:
            # Patient IDs are the file stems ('BraTS-PED-00000.pt' -> 'BraTS-PED-00000').
            # splitext strips only the trailing extension, whereas str.replace
            # would also delete a '.pt' occurring inside the name.
            available_patients = sorted(
                os.path.splitext(os.path.basename(f))[0] for f in patient_files
            )

            # --- START VISUALIZATION ---
            # Change this to the patient you want to view
            patient_to_visualize = available_patients[0]

            print("="*50)
            print("STARTING INTERACTIVE VISUALIZER")
            print(f"Available patients found: {len(available_patients)}")
            print("To view a different patient, change the 'patient_to_visualize' variable.")
            print("="*50)

            visualize_patient(patient_to_visualize)
else:
    print("This script is designed to be run in a Jupyter Notebook environment.")

STARTING INTERACTIVE VISUALIZER
Available patients found: 4
To view a different patient, change the 'patient_to_visualize' variable.
Successfully loaded Patient 'BraTS-PED-00001-000'
  Image tensor shape: torch.Size([5, 240, 240, 155]) | (Channels, Depth, Height, Width)
  Label tensor shape: torch.Size([240, 240, 155]) | (Depth, Height, Width)


interactive(children=(Dropdown(description='View:', options=('Axial', 'Sagittal', 'Coronal'), value='Axial'), …