In [21]:
# In a Jupyter Notebook cell

import torch
from torch.utils.data import ConcatDataset
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider

# Import our custom project files
from local_flare_loader import LocalFlareTask3 # Needed for load_dataset to work
from dataset import NiftiDataset

# --- 1. CONFIGURE YOUR VISUALIZATION ---
# These three constants are read by the loading and plotting cells of this notebook.

# Path to your local dataset (passed to load_dataset as data_dir)
LOCAL_DATASET_PATH = "/mnt/asgard6/data/FLARE-MedFM/FLARE-Task3-DomainAdaption"

# Choose which dataset split to view. Options:
# 'train_ct_gt', 'train_ct_pseudo_aladdin', 'train_ct_pseudo_blackbean',
# 'validation_mri', 'validation_pet', 'train_mri_unlabeled', 'train_pet_unlabeled'
# The second '_'-separated token (e.g. 'mri') is shown as the modality in the plot
# title, and any name containing 'ct' gets CT windowing applied before display.
DATASET_CONFIG_NAME = 'train_mri_unlabeled' 

# Choose which sample from that split to visualize. The value is used directly as
# pytorch_dataset[SAMPLE_INDEX]; assuming zero-based indexing, 2 selects the 3rd sample.
SAMPLE_INDEX = 2

# ----------------------------------------

In [22]:
# In a Jupyter Notebook cell

print(f"--- Loading configuration: '{DATASET_CONFIG_NAME}' ---")
print(f"--- Visualizing sample index: {SAMPLE_INDEX} ---")

# Build the Hugging Face dataset from the local loader script and keep its
# "train" split; the records hold file paths rather than image data.
hf_dataset = load_dataset(
    "./local_flare_loader.py",
    name=DATASET_CONFIG_NAME,
    data_dir=LOCAL_DATASET_PATH,
    trust_remote_code=True,
)["train"]
print(len(hf_dataset), "samples found in the dataset.")

# Wrap the path records in the project's PyTorch dataset and fetch the
# requested sample.
pytorch_dataset = NiftiDataset(hf_dataset)
sample = pytorch_dataset[SAMPLE_INDEX]

# Bring both tensors to the CPU as NumPy arrays for plotting. Dropping axis 0
# of the image removes the channel dimension: (1, D, H, W) -> (D, H, W).
image_3d = np.squeeze(sample['image'].cpu().numpy(), axis=0)
label_3d = sample['label'].cpu().numpy()

# A minimum label value of -1 is treated as "no label"
# (presumably the placeholder NiftiDataset emits for unlabeled samples -- TODO confirm).
has_label = sample['label'].min() != -1

print(f"\nImage shape: {image_3d.shape}")
print(f"Label shape: {label_3d.shape}")
print(f"Dataset has labels: {has_label}")

--- Loading configuration: 'train_mri_unlabeled' ---
--- Visualizing sample index: 2 ---


Generating train split: 0 examples [00:00, ? examples/s]

4817 samples found in the dataset.

Image shape: (72, 512, 512)
Label shape: ()
Dataset has labels: False


In [27]:
# In a Jupyter Notebook cell

def window_image(image, window_center, window_width):
    """
    Clamp image intensities to a display window (used here for CT scans).

    The window spans ``window_center - window_width // 2`` to
    ``window_center + window_width // 2``. Values outside that range are set
    to the nearest bound. The input array is left untouched; a clamped copy
    is returned.
    """
    half_width = window_width // 2
    lower_bound = window_center - half_width
    upper_bound = window_center + half_width

    clamped = image.copy()
    # The upper mask is evaluated after the lower clamp has been applied.
    clamped[clamped < lower_bound] = lower_bound
    clamped[clamped > upper_bound] = upper_bound
    return clamped

def plot_slice(z):
    """
    Draw slice ``z`` of the loaded volume as a three-panel figure.

    Serves as the slider callback. The panels show the image slice, the
    segmentation mask (or a placeholder text when no label exists), and the
    mask overlaid on the image. Reads the notebook-level variables
    ``image_3d``, ``label_3d``, ``has_label`` and ``DATASET_CONFIG_NAME``.
    """
    current_image = image_3d[z, :, :]

    # CT configurations are windowed (center 40, width 400) for better contrast
    is_ct = 'ct' in DATASET_CONFIG_NAME.lower()
    if is_ct:
        current_image = window_image(current_image, window_center=40, window_width=400)

    fig, (ax_image, ax_mask, ax_overlay) = plt.subplots(1, 3, figsize=(18, 6))

    # --- Panel 1: the image slice ---
    ax_image.imshow(current_image, cmap='gray')
    ax_image.set_title(f"Image Slice (Modality: {DATASET_CONFIG_NAME.split('_')[1].upper()})")

    # --- Panels 2 and 3: titles and the grayscale base of the overlay ---
    ax_mask.set_title("Segmentation Mask")
    ax_overlay.imshow(current_image, cmap='gray')
    ax_overlay.set_title("Image + Mask Overlay")

    if has_label:
        current_label = label_3d[z, :, :]
        ax_mask.imshow(current_label, cmap='nipy_spectral', interpolation='none')
        # Mask out label 0 so the background stays transparent in the overlay
        foreground = np.ma.masked_where(current_label == 0, current_label)
        ax_overlay.imshow(foreground, cmap='nipy_spectral', alpha=0.5, interpolation='none')
    else:
        ax_mask.text(0.5, 0.5, 'No Label Available', ha='center', va='center', fontsize=12)

    # Hide ticks and frames on every panel
    for panel in (ax_image, ax_mask, ax_overlay):
        panel.axis('off')

    plt.suptitle(f"Viewing Slice Z = {z}", fontsize=16)
    plt.tight_layout()
    plt.show()

# Build the interactive slice browser.
# Axis 0 of the 3D volume is the z-axis, so its length sets the slider range;
# the slider starts on the middle slice.
num_slices = image_3d.shape[0]
slice_slider = IntSlider(
    min=0,
    max=num_slices - 1,
    step=1,
    value=num_slices // 2,
    description='Slice:',
)
interact(plot_slice, z=slice_slider);

interactive(children=(IntSlider(value=50, description='Slice:'), Output()), _dom_classes=('widget-interact',))

In [3]:
def window_image(image, window_center, window_width):
    """
    Applies CT windowing for better visualization.

    Returns a copy of ``image`` whose values are limited to the range
    ``window_center - window_width // 2`` .. ``window_center + window_width // 2``.

    NOTE(review): this notebook defines ``window_image`` with the same body in
    another cell as well; only one definition is needed.
    """
    half = window_width // 2
    low, high = window_center - half, window_center + half

    result = image.copy()
    too_low = result < low
    result[too_low] = low
    # Computed after the lower clamp, exactly as the in-place comparison would be.
    too_high = result > high
    result[too_high] = high
    return result

In [14]:
# Sanity check: count the entries in the MRI validation image folder.

import os

mri_val_dir = os.path.join(LOCAL_DATASET_PATH, "validation", "MRI_imagesVal")
list_dirs1 = os.listdir(mri_val_dir)
# To inspect another folder, list it the same way, e.g.:
# list_dirs2 = os.listdir(os.path.join(LOCAL_DATASET_PATH, "train_MRI_unlabeled", "LLD-MMRI-3984"))

# Report the count once, labeled with the folder it belongs to
# (printing len(list_dirs1) twice made it look like two folders were compared).
print(f"{len(list_dirs1)} entries in {mri_val_dir}")

110 110


In [15]:
# Inspect the first raw record of the loaded Hugging Face dataset. Each record is
# a dict of file paths ('image_path', 'label_path'), not the image data itself.
hf_dataset[0]

{'image_path': '/mnt/asgard6/data/FLARE-MedFM/FLARE-Task3-DomainAdaption/validation/PET_labelsVal/fdg_1bb48bfb40_12-02-2000-NA-PET-CT_Ganzkoerper__primaer_mit_KM-90244.nii.gz',
 'label_path': 'N/A'}

In [4]:
# Load the 'train_ct_gt' configuration and display the resulting DatasetDict
# (split names, features and row counts).
# trust_remote_code=True is passed explicitly, matching the other load_dataset
# call in this notebook, because the dataset is built by a local loader script.
load_dataset(
    "./local_flare_loader.py",
    name="train_ct_gt",
    data_dir=LOCAL_DATASET_PATH,
    trust_remote_code=True,
)

DatasetDict({
    train: Dataset({
        features: ['image_path', 'label_path'],
        num_rows: 50
    })
})