In [1]:
from pathlib import Path
import re
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider

# --- PATH TO YOUR DATA ---
DATA_ROOT = Path("/Users/michaelbanks/Desktop/AI TXG/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData")

# --- BRATS NAMING + MODALITIES ---
MODALITIES = ["flair", "t1", "t1ce", "t2", "seg"]

# Example filename: BraTS20_Training_001_flair.nii -> ID = 001
ID_RE = re.compile(r".*_(\d+)_", re.IGNORECASE)


def list_nii_files(root: Path):
    """Recursively list all .nii files under DATA_ROOT."""
    return sorted(root.rglob("*.nii"))


def index_brats_volumes(nii_files):
    """
    Build:
    {
        1: {"flair": Path(...), "t1": Path(...), "t1ce": Path(...), "t2": Path(...), "seg": Path(...)},
        2: {...},
        ...
    }
    """
    volumes = {}

    for fp in nii_files:
        name = fp.stem.lower()  # filename without extension

        # patient ID
        m = ID_RE.match(name)
        if not m:
            continue
        vid = int(m.group(1))

        # modality
        modality = None
        for mod in MODALITIES:
            if name.endswith(mod):
                modality = mod
                break
        if modality is None:
            continue

        volumes.setdefault(vid, {})[modality] = fp

    # sort by patient ID
    return dict(sorted(volumes.items()))


nii_files = list_nii_files(DATA_ROOT)
vol_index = index_brats_volumes(nii_files)

print("Total .nii files:", len(nii_files))
print("Total volumes:", len(vol_index))

# quick sanity check: how many volumes actually have a tumor mask (seg)?
n_with_seg = sum(1 for v in vol_index.values() if "seg" in v)
print("Volumes with seg mask:", n_with_seg)


# --- LOADING ONE PATIENT VOLUME ---

def load_volume(vid):
    """Load all four MRI modalities and seg mask for a given patient ID."""
    entry = vol_index[vid]

    # load modalities (order: T1, T1ce, T2, FLAIR to match your earlier code)
    imgs = []
    for mod in ["t1", "t1ce", "t2", "flair"]:
        data = nib.load(str(entry[mod])).get_fdata()
        imgs.append(data)

    # stack to (Z, H, W, 4)
    X = np.stack(imgs, axis=-1).astype(np.float32)

    # segmentation mask: (Z, H, W)
    Y = nib.load(str(entry["seg"])).get_fdata().astype(np.int16)

    return X, Y


# --- INTERACTIVE VISUALIZATION ---

# freeze list of patient IDs so we can index them by integer
VOL_IDS = list(vol_index.keys())

# load one volume to set slider ranges
X0, Y0 = load_volume(VOL_IDS[0])
num_slices = X0.shape[0]

def show_slice(vol_idx=0, z=0, modality=3):
    """
    vol_idx: index into VOL_IDS
    z: slice index
    modality: 0=T1, 1=T1ce, 2=T2, 3=FLAIR
    """
    vid = VOL_IDS[vol_idx]
    X, Y = load_volume(vid)

    img = X[z, :, :, modality]
    mask = Y[z, :, :]  # tumor labels >0

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    # Left: raw MRI slice
    axes[0].imshow(img, cmap='gray')
    axes[0].set_title(f"ID {vid} | Modality {['T1','T1ce','T2','FLAIR'][modality]}\nSlice {z} (no mask)")
    axes[0].axis('off')

    # Right: MRI + tumor overlay
    axes[1].imshow(img, cmap='gray')
    axes[1].imshow(np.ma.masked_where(mask == 0, mask),
                   alpha=0.7)
    axes[1].set_title(f"ID {vid} | With Tumor Mask\nSlice {z}")
    axes[1].axis('off')

    plt.tight_layout()
    plt.show()


interact(
    show_slice,
    vol_idx=IntSlider(min=0, max=len(VOL_IDS) - 1, step=1, value=0, description="Patient"),
    z=IntSlider(min=0, max=num_slices - 1, step=1, value=num_slices // 2, description="Slice"),
    modality=IntSlider(min=0, max=3, step=1, value=3, description="Modality"),
)


Total .nii files: 1845
Total volumes: 369
Volumes with seg mask: 369


interactive(children=(IntSlider(value=0, description='Patient', max=368), IntSlider(value=120, description='Sl…

<function __main__.show_slice(vol_idx=0, z=0, modality=3)>

In [None]:

# --- IMPORT SMART DATASET ---
from old_brats_utils import BraTSSmartDataset

# Initialize the smart dataset (limited to first 50 patients for speed)
smart_dataset = BraTSSmartDataset(DATA_ROOT, limit_patients=50)

print(f"Smart Dataset initialized with {len(smart_dataset)} valid slices")


TypeError: BraTSSmartDataset.__init__() got an unexpected keyword argument 'limit_patients'

In [None]:

# --- INTERACTIVE SMART DATASET VISUALIZATION ---

def visualize_smart_dataset(slice_idx=0):
    """
    Visualize a slice from the smart dataset.
    
    slice_idx: Index into the valid_slices list (0 to len(dataset)-1)
    """
    if slice_idx >= len(smart_dataset):
        print(f"Invalid slice index. Dataset has {len(smart_dataset)} slices.")
        return
    
    # Get the slice from dataset
    img_tensor, mask_tensor = smart_dataset[slice_idx]
    
    # Convert tensors to numpy
    img = img_tensor.squeeze().numpy()
    mask = mask_tensor.squeeze().numpy()
    
    # Get the underlying patient/slice info for reference
    patient_idx, slice_num = smart_dataset.valid_slices[slice_idx]
    patient_path = smart_dataset.patient_folders[patient_idx]
    patient_id = patient_path.name
    
    # Create visualization
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Left: Raw image (z-score normalized)
    axes[0].imshow(img, cmap='gray')
    axes[0].set_title(f"Patient {patient_id}\nSlice {slice_num} (Z-Score Normalized)")
    axes[0].axis('off')
    
    # Middle: Tumor mask
    axes[1].imshow(mask, cmap='gray')
    axes[1].set_title(f"Tumor Mask\n(Binary)")
    axes[1].axis('off')
    
    # Right: Overlay
    axes[2].imshow(img, cmap='gray')
    axes[2].imshow(np.ma.masked_where(mask == 0, mask), cmap='jet', alpha=0.6)
    axes[2].set_title(f"Image + Tumor Overlay\nSlice {slice_idx}/{len(smart_dataset)-1}")
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()


# Create interactive slider
interact(
    visualize_smart_dataset,
    slice_idx=IntSlider(min=0, max=len(smart_dataset)-1, step=1, value=0, description="Slice Index")
)


interactive(children=(IntSlider(value=0, description='Slice Index', max=3146), Output()), _dom_classes=('widge…

<function __main__.visualize_smart_dataset(slice_idx=0)>