# Exploratory Data Analysis: FeTA 2.4 Fetal Brain MRI

This notebook explores the FeTA 2.4 dataset for fetal brain MRI segmentation.

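**Contents:**
1. Load a sample MRI volume and its segmentation label (or synthetic placeholder data if no dataset is found)
2. Print shape and intensity statistics
3. Visualize middle slices in all orientations
4. Overlay segmentation on MRI
5. Analyze label distribution and per-class tissue volumes
6. Preview the preprocessing pipeline (intensity clipping and z-score normalization)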

In [None]:
# Import required libraries
import sys
# Make the parent directory importable so the project's `src` package resolves below.
# NOTE(review): assumes the notebook is launched from a direct subdirectory of the repo root.
sys.path.append('..')  # Add parent directory to path

import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib  # NOTE(review): not referenced directly in the cells shown; volumes are loaded via src.data
from pathlib import Path

# Import our modules
# Project helpers: NIfTI loading / dataset discovery, and intensity preprocessing.
from src.data import load_nifti_volume, discover_dataset
from src.preprocessing import zscore_normalize, intensity_clipping

# Configure matplotlib
plt.style.use('default')
# IPython magic (not plain Python): render figures inline. This cell therefore
# only runs inside Jupyter/IPython, not as a standalone .py script.
%matplotlib inline

## 1. Configuration

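Set the path to your FeTA dataset directory. If the directory cannot be found, later cells fall back to synthetic placeholder data so the notebook still runs end-to-end.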

In [None]:
# --- User configuration -------------------------------------------------------
# Root of the FeTA 2.4 training split; edit this to match your local copy.
DATA_DIR = Path("../data/feta_2.4/train")

# To bypass dataset discovery, point at one subject explicitly instead, e.g.:
# SAMPLE_IMAGE = Path("../data/feta_2.4/train/sub-001/sub-001_T2w.nii.gz")
# SAMPLE_LABEL = Path("../data/feta_2.4/train/sub-001/sub-001_dseg.nii.gz")

# Label index -> tissue name for the FeTA segmentation maps (index = position).
CLASS_NAMES = dict(enumerate((
    'Background',
    'External CSF',
    'Gray Matter',
    'White Matter',
    'Ventricles',
    'Cerebellum',
    'Deep Gray Matter',
    'Brainstem',
)))

# Label index -> RGBA colour, 0-255 per channel. Tissue classes are drawn at
# half opacity (alpha 128); the background entry is fully transparent.
_TISSUE_RGB = (
    (255, 0, 0),      # 1: External CSF - red
    (0, 255, 0),      # 2: Gray Matter - green
    (0, 0, 255),      # 3: White Matter - blue
    (255, 255, 0),    # 4: Ventricles - yellow
    (255, 0, 255),    # 5: Cerebellum - magenta
    (0, 255, 255),    # 6: Deep Gray Matter - cyan
    (255, 128, 0),    # 7: Brainstem - orange
)
CLASS_COLORS = {0: [0, 0, 0, 0]}
CLASS_COLORS.update(
    {idx: [*rgb, 128] for idx, rgb in enumerate(_TISSUE_RGB, start=1)}
)

## 2. Dataset Discovery

Scan the data directory to find all image-label pairs.

In [None]:
# Scan DATA_DIR for (T2w image, dseg label) pairs. When the directory is
# missing, `samples` stays empty and the next cell switches to demo data.
samples = []
if not DATA_DIR.exists():
    print(f"Data directory not found: {DATA_DIR}")
    print("Please update DATA_DIR with the correct path to your FeTA dataset.")
else:
    samples = discover_dataset(
        DATA_DIR,
        image_pattern="*_T2w.nii.gz",
        label_pattern="*_dseg.nii.gz",
    )
    print(f"Found {len(samples)} samples in {DATA_DIR}")

    # Preview at most the first five discovered subjects.
    preview = samples[:5]
    if preview:
        print("\nFirst 5 samples:")
    for entry in preview:
        print(f"  {entry['subject_id']}: {entry['image'].name}")

## 3. Load Sample MRI and Label

Load one MRI volume and its corresponding segmentation label.

In [None]:
# Load the first discovered subject. If none were found, synthesise a toy
# "brain" so the remaining cells can still run end-to-end.
if samples:
    sample = samples[0]
    print(f"Loading: {sample['subject_id']}")

    # MRI volume together with its voxel-to-world affine.
    mri_volume, mri_affine = load_nifti_volume(sample['image'], return_affine=True)

    # Segmentation map. Cast to integers once, here, so every later cell sees
    # exact integer class ids, matching the placeholder branch below. np.rint
    # guards against float round-off (e.g. 2.9999) in case the loader returns floats.
    label_volume = load_nifti_volume(sample['label'])
    label_volume = np.rint(label_volume).astype(np.int64)

    print("\nVolumes loaded successfully!")
else:
    # Create placeholder data for demonstration
    print("Creating placeholder data for demonstration...")
    print("(Replace with real data by updating DATA_DIR)")

    # Fixed seed so the synthetic volume is reproducible across runs.
    np.random.seed(42)
    shape = (256, 256, 256)

    # Radial distance of every voxel from the volume centre (via broadcasting).
    x, y, z = np.ogrid[:shape[0], :shape[1], :shape[2]]
    center = np.array(shape) // 2
    r = np.sqrt((x - center[0])**2 + (y - center[1])**2 + (z - center[2])**2)

    # Dummy MRI: Gaussian intensity blob plus additive Gaussian noise.
    mri_volume = np.exp(-r**2 / (2 * 50**2)) * 1000 + np.random.randn(*shape) * 50
    mri_volume = mri_volume.astype(np.float32)
    mri_affine = np.eye(4)

    # Dummy labels as concentric shells. Each assignment overwrites the inner
    # part of the previous shell. The nesting follows real anatomy from the
    # outside in: external CSF -> cortical gray matter -> white matter -> ventricles.
    label_volume = np.zeros(shape, dtype=np.int64)
    label_volume[r < 80] = 1  # External CSF
    label_volume[r < 60] = 2  # Gray Matter (cortex surrounds the white matter)
    label_volume[r < 40] = 3  # White Matter
    label_volume[r < 20] = 4  # Ventricles

## 4. Volume Statistics

Examine the shape, data type, and intensity statistics of the volumes.

In [None]:
# Global intensity statistics of the MRI volume, printed as an aligned table
# (captions are left-padded to a 13-character column).
banner = "=" * 50
print(banner)
print("MRI Volume Statistics")
print(banner)

summary_rows = (
    ("Shape:", f"{mri_volume.shape}"),
    ("Data type:", f"{mri_volume.dtype}"),
    ("Min value:", f"{mri_volume.min():.2f}"),
    ("Max value:", f"{mri_volume.max():.2f}"),
    ("Mean:", f"{mri_volume.mean():.2f}"),
    ("Std:", f"{mri_volume.std():.2f}"),
    ("Median:", f"{np.median(mri_volume):.2f}"),
)
for caption, text in summary_rows:
    print(f"{caption:<13}{text}")

# Same statistics restricted to strictly positive voxels (rough foreground proxy).
non_zero = mri_volume[mri_volume > 0]
if non_zero.size:
    print("\nNon-zero region:")
    print(f"  Voxels:    {non_zero.size:,}")
    print(f"  Mean:      {non_zero.mean():.2f}")
    print(f"  Std:       {non_zero.std():.2f}")

In [None]:
# Summary of the segmentation map: shape, dtype and per-class voxel counts.
print("\n" + "=" * 50)
print("Segmentation Label Statistics")
print("=" * 50)
print(f"Shape:       {label_volume.shape}")
print(f"Data type:   {label_volume.dtype}")

# One pass over the volume yields the sorted label ids AND their voxel counts,
# instead of re-scanning every voxel once per label with `label_volume == label`.
unique_labels, label_counts = np.unique(label_volume, return_counts=True)
print(f"\nUnique label values: {unique_labels.tolist()}")

# Label distribution
print("\nLabel Distribution:")
print("-" * 40)
total_voxels = label_volume.size

for label, count in zip(unique_labels, label_counts):
    label_id = int(label)  # prints "3", not "3.0", if labels are stored as floats
    percentage = 100 * count / total_voxels
    name = CLASS_NAMES.get(label_id, f'Unknown ({label_id})')
    print(f"  {label_id}: {name:<20} {count:>10,} ({percentage:>5.2f}%)")

## 5. Intensity Distribution

Visualize the histogram of MRI intensities.

In [None]:
# Intensity histograms: every voxel (left) vs. strictly positive voxels (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Full histogram. ravel() returns a view where possible, avoiding the extra
# full-volume copy that flatten() always makes.
all_values = mri_volume.ravel()
ax = axes[0]
ax.hist(all_values, bins=100, color='steelblue', alpha=0.7, edgecolor='black')
ax.set_xlabel('Intensity', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.set_title('MRI Intensity Distribution (All Voxels)', fontsize=14)
all_mean = all_values.mean()
ax.axvline(all_mean, color='red', linestyle='--', label=f'Mean: {all_mean:.1f}')
ax.legend()

# Positive-intensity voxels as a rough proxy for the brain region.
ax = axes[1]
non_zero_values = mri_volume[mri_volume > 0]
ax.set_xlabel('Intensity', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.set_title('MRI Intensity Distribution (Brain Region)', fontsize=14)
if non_zero_values.size:
    ax.hist(non_zero_values, bins=100, color='coral', alpha=0.7, edgecolor='black')
    positive_mean = non_zero_values.mean()
    ax.axvline(positive_mean, color='red', linestyle='--', label=f'Mean: {positive_mean:.1f}')
    ax.legend()
else:
    # Same guard as the statistics cell: an empty selection would produce a
    # NaN mean line and a legend with nothing to show.
    ax.text(0.5, 0.5, 'No positive-intensity voxels', ha='center', va='center',
            transform=ax.transAxes)

plt.tight_layout()
plt.show()

## 6. Visualize Middle Slices

Display the middle slice in axial, coronal, and sagittal views.

In [None]:
def show_middle_slices(volume, title="MRI Volume", slice_idx=None):
    """Display one slice through a 3D volume along each of its three array axes.

    Args:
        volume: 3D intensity array. Axis 0 is captioned "Z"/axial, axis 1
            "Y"/coronal and axis 2 "X"/sagittal. NOTE(review): these captions
            assume load_nifti_volume returns (Z, Y, X) order; raw nibabel data
            is usually (x, y, z), which would swap axial and sagittal. Confirm.
        title: Figure super-title.
        slice_idx: Optional (z, y, x) indices of the slices to show, using the
            same convention as show_overlay's ``slice_idx``. Defaults to the
            middle index along each axis, so existing calls are unchanged.
    """
    if slice_idx is None:
        mid_z, mid_y, mid_x = (volume.shape[axis] // 2 for axis in range(3))
    else:
        mid_z, mid_y, mid_x = slice_idx

    views = (
        (volume[mid_z, :, :], f'Axial (Z={mid_z})'),
        (volume[:, mid_y, :], f'Coronal (Y={mid_y})'),
        (volume[:, :, mid_x], f'Sagittal (X={mid_x})'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (image, caption) in zip(axes, views):
        ax.imshow(image, cmap='gray', origin='lower')
        ax.set_title(caption, fontsize=12)
        ax.axis('off')

    plt.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Show MRI slices
show_middle_slices(mri_volume, "MRI Volume - Middle Slices")

## 7. Visualize Segmentation Labels

Display the segmentation labels with a custom colormap.

In [None]:
def create_label_colormap(num_classes=8):
    """Build a discrete colormap with one RGB colour per label index.

    Args:
        num_classes: Number of leading palette entries to include; entry 0 is
            the background. The palette has 8 entries, so larger values are
            truncated to 8 by slicing.

    Returns:
        A matplotlib ``ListedColormap`` over the selected colours.
    """
    from matplotlib.colors import ListedColormap

    # RGB in [0, 1], ordered by label index.
    palette = (
        (0, 0, 0),      # 0: Background - black
        (1, 0, 0),      # 1: External CSF - red
        (0, 1, 0),      # 2: Gray Matter - green
        (0, 0, 1),      # 3: White Matter - blue
        (1, 1, 0),      # 4: Ventricles - yellow
        (1, 0, 1),      # 5: Cerebellum - magenta
        (0, 1, 1),      # 6: Deep Gray Matter - cyan
        (1, 0.5, 0),    # 7: Brainstem - orange
    )
    selected = [list(rgb) for rgb in palette[:num_classes]]
    return ListedColormap(selected)

def show_label_slices(label, title="Segmentation Labels"):
    """Plot the middle slice of a label volume along each array axis, then a legend.

    Args:
        label: 3D label volume. Values 0-7 are coloured with
            create_label_colormap(); the fixed vmin=0 / vmax=7 clips anything
            outside that range to the end colours.
        title: Super-title of the slice figure.

    NOTE(review): the Z/Y/X (axial/coronal/sagittal) captions assume the array
    is ordered (Z, Y, X). Confirm against load_nifti_volume's output layout.
    """
    cmap = create_label_colormap()
    mid_z, mid_y, mid_x = (label.shape[axis] // 2 for axis in range(3))

    panels = (
        (label[mid_z, :, :], f'Axial (Z={mid_z})'),
        (label[:, mid_y, :], f'Coronal (Y={mid_y})'),
        (label[:, :, mid_x], f'Sagittal (X={mid_x})'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (label_slice, caption) in zip(axes, panels):
        ax.imshow(label_slice, cmap=cmap, vmin=0, vmax=7, origin='lower')
        ax.set_title(caption, fontsize=12)
        ax.axis('off')

    plt.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

    # Separate legend figure: one coloured unit bar per known class id.
    fig_legend, ax_legend = plt.subplots(figsize=(10, 1))
    class_ids = list(CLASS_NAMES.keys())
    for class_id in class_ids:
        ax_legend.bar(class_id, 1, color=cmap(class_id),
                      label=f'{class_id}: {CLASS_NAMES[class_id]}')
    ax_legend.set_xlim(-0.5, len(CLASS_NAMES) - 0.5)
    ax_legend.set_xticks(class_ids)
    ax_legend.set_xticklabels([str(class_id) for class_id in class_ids])
    ax_legend.set_yticks([])
    ax_legend.legend(loc='upper center', bbox_to_anchor=(0.5, -0.5), ncol=4)
    plt.title('Label Legend', fontsize=12)
    plt.tight_layout()
    plt.show()

# Show label slices
show_label_slices(label_volume.astype(int), "Segmentation Labels - Middle Slices")

## 8. Overlay Labels on MRI

Visualize the segmentation overlaid on the MRI.

In [None]:
def show_overlay(mri, label, slice_idx=None, alpha=0.4):
    """Show MRI slices (top row) and the same slices with labels overlaid (bottom row).

    Args:
        mri: 3D intensity volume; min-max scaled to [0, 1] for display only.
        label: 3D label volume with the same shape as ``mri``. Label 0 is left
            transparent; other values use create_label_colormap() with
            vmin=0 / vmax=7.
        slice_idx: Optional (z, y, x) slice indices; defaults to the middle
            index along each axis.
        alpha: Opacity of the label overlay.
    """
    cmap = create_label_colormap()

    if slice_idx is not None:
        mid_z, mid_y, mid_x = slice_idx
    else:
        mid_z, mid_y, mid_x = mri.shape[0] // 2, mri.shape[1] // 2, mri.shape[2] // 2

    # Min-max scale for display; the epsilon avoids 0/0 on a constant volume.
    lo, hi = mri.min(), mri.max()
    mri_display = (mri - lo) / (hi - lo + 1e-8)

    # (caption, index tuple) for each plane, in column order.
    planes = (
        ('Axial', (mid_z, slice(None), slice(None))),
        ('Coronal', (slice(None), mid_y, slice(None))),
        ('Sagittal', (slice(None), slice(None), mid_x)),
    )

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    for col, (plane, index) in enumerate(planes):
        mri_slice = mri_display[index]
        label_slice = label[index]
        top, bottom = axes[0, col], axes[1, col]

        # Top row: MRI only.
        top.imshow(mri_slice, cmap='gray', origin='lower')
        top.set_title(f'MRI - {plane}', fontsize=12)
        top.axis('off')

        # Bottom row: MRI with background-masked labels on top.
        bottom.imshow(mri_slice, cmap='gray', origin='lower')
        overlay = np.ma.masked_where(~(label_slice > 0), label_slice)
        bottom.imshow(overlay, cmap=cmap, alpha=alpha, vmin=0, vmax=7, origin='lower')
        bottom.set_title(f'Overlay - {plane}', fontsize=12)
        bottom.axis('off')

    plt.suptitle('MRI with Segmentation Overlay', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Show overlay
show_overlay(mri_volume, label_volume.astype(int), alpha=0.5)

## 9. Per-Class Slice Visualization

Highlight each tissue class separately on a single 2D slice of the volume.

In [None]:
def show_per_class_slices(mri, label, slice_idx=None):
    """Show one slice of ``mri`` with each foreground class highlighted in its own panel.

    Args:
        mri: 3D intensity volume; min-max scaled to [0, 1] for display only.
        label: Label volume with the same shape as ``mri``; 0 is background.
        slice_idx: Index along axis 0 (captioned "Z"); defaults to the middle
            slice. NOTE(review): the "Z" caption assumes (Z, Y, X) array order.

    Every non-zero label present anywhere in ``label`` gets a panel, even if it
    does not occur in the chosen slice (that panel then shows only the MRI).
    Prints a message and returns without plotting when there is no foreground.
    """
    from matplotlib.colors import ListedColormap

    if slice_idx is None:
        slice_idx = mri.shape[0] // 2

    # Foreground class ids (background 0 excluded), as plain ints.
    class_ids = [int(c) for c in np.unique(label) if c > 0]
    if not class_ids:
        # Without this guard cols would be 0 and the grid maths divides by zero.
        print("No foreground labels found; nothing to plot.")
        return

    cols = min(4, len(class_ids))
    rows = (len(class_ids) + cols - 1) // cols  # ceiling division
    # squeeze=False always returns a 2D ndarray of Axes. With the default, a
    # 1x1 grid (a single class) returns a bare Axes, which has no .reshape/.flatten.
    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False)
    axes = axes.ravel()

    # Min-max scale for display; the epsilon avoids 0/0 on a constant volume.
    mri_display = (mri - mri.min()) / (mri.max() - mri.min() + 1e-8)
    mri_slice = mri_display[slice_idx, :, :]
    label_slice = label[slice_idx, :, :]

    # RGB colours for classes 1..7, in the same order and values as
    # create_label_colormap() (whose entry 0 is the background), so index
    # with class_id - 1.
    class_rgb = [
        (1, 0, 0),      # 1: red
        (0, 1, 0),      # 2: green
        (0, 0, 1),      # 3: blue
        (1, 1, 0),      # 4: yellow
        (1, 0, 1),      # 5: magenta
        (0, 1, 1),      # 6: cyan
        (1, 0.5, 0),    # 7: orange
    ]

    for ax, class_id in zip(axes, class_ids):
        ax.imshow(mri_slice, cmap='gray', origin='lower')

        class_mask = label_slice == class_id
        if class_mask.any():
            rgb = class_rgb[(class_id - 1) % len(class_rgb)]
            # Outline: the 0.5 iso-line of the binary mask.
            ax.contour(class_mask.astype(float), levels=[0.5], colors=[rgb], linewidths=2)
            # Fill: translucent single-colour layer over this class's voxels only.
            fill = np.ma.masked_where(~class_mask, np.ones_like(mri_slice))
            ax.imshow(fill, cmap=ListedColormap([rgb]), alpha=0.3, origin='lower')

        ax.set_title(f'{class_id}: {CLASS_NAMES.get(class_id, "Unknown")}', fontsize=11)
        ax.axis('off')

    # Hide any leftover, unused grid cells.
    for ax in axes[len(class_ids):]:
        ax.axis('off')

    plt.suptitle(f'Per-Class Segmentation (Z={slice_idx})', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Show per-class visualization
show_per_class_slices(mri_volume, label_volume.astype(int))

## 10. Label Volume Analysis

Compute the voxel count and physical volume (in mm³ and ml, i.e. voxel count × voxel volume) of each tissue class.

In [None]:
def analyze_label_volumes(label, voxel_size=(1.0, 1.0, 1.0)):
    """Compute the voxel count and physical volume of every label in ``label``.

    Args:
        label: Label volume (any shape); its values are class ids.
        voxel_size: Per-axis voxel spacing in mm. Their product is the volume
            of a single voxel in mm^3.

    Returns:
        One dict per label present, sorted by id and including background, with
        keys 'class_id' (int), 'name', 'voxels' (int), 'volume_mm3' and
        'volume_ml'. Background (id 0) is omitted only from the printed table.
    """
    voxel_volume = np.prod(voxel_size)  # in mm^3

    print("\n" + "=" * 60)
    print("Tissue Volume Analysis")
    print("=" * 60)
    print(f"Voxel size: {voxel_size} mm")
    print(f"Voxel volume: {voxel_volume:.4f} mm³\n")

    # A single pass over the volume gives sorted label ids and their counts,
    # instead of one full `label == class_id` comparison per class.
    class_ids, counts = np.unique(label, return_counts=True)

    results = []
    for raw_id, raw_count in zip(class_ids, counts):
        class_id = int(raw_id)  # plain int, so float-typed labels print as "3", not "3.0"
        count = int(raw_count)
        volume_mm3 = count * voxel_volume
        volume_ml = volume_mm3 / 1000  # 1 ml = 1000 mm^3
        name = CLASS_NAMES.get(class_id, f'Class {class_id}')

        results.append({
            'class_id': class_id,
            'name': name,
            'voxels': count,
            'volume_mm3': volume_mm3,
            'volume_ml': volume_ml
        })

        if class_id > 0:  # Skip background in printout
            print(f"{class_id}: {name:<20} {count:>10,} voxels = {volume_mm3:>12,.1f} mm³ = {volume_ml:>8.2f} ml")

    return results

# Analyze volumes using the physical voxel spacing taken from the image affine,
# rather than assuming 1 mm isotropic voxels (which misstates tissue volumes
# for data with other resolutions).
# The spacing along each voxel axis is the Euclidean norm of the matching
# column of the affine's 3x3 linear part; this also holds for rotated/oblique
# affines. The placeholder data uses np.eye(4), giving (1.0, 1.0, 1.0).
# NOTE(review): assumes load_nifti_volume(..., return_affine=True) returns the
# standard 4x4 NIfTI affine. The voxel volume (a product) does not depend on
# axis order, even if the loader reorders the array axes.
voxel_size = tuple(float(s) for s in np.linalg.norm(np.asarray(mri_affine)[:3, :3], axis=0))
volume_results = analyze_label_volumes(label_volume, voxel_size)

In [None]:
# Visualize tissue volumes (background excluded).
tissue_data = [r for r in volume_results if r['class_id'] > 0]

if not tissue_data:
    # barh/pie with no entries produce empty or degenerate charts.
    print("No foreground tissue classes to plot.")
else:
    names = [r['name'] for r in tissue_data]
    volumes = [r['volume_ml'] for r in tissue_data]
    voxels = [r['voxels'] for r in tissue_data]

    # Reuse the per-class colours from CLASS_COLORS (RGBA, 0-255 per channel),
    # converted to 0-1 RGB. This keeps these charts consistent with the class
    # colours defined in the configuration cell. Unknown ids fall back to gray.
    colors = [
        tuple(channel / 255 for channel in CLASS_COLORS[r['class_id']][:3])
        if r['class_id'] in CLASS_COLORS else (0.5, 0.5, 0.5)
        for r in tissue_data
    ]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Bar chart of volumes
    axes[0].barh(names, volumes, color=colors)
    axes[0].set_xlabel('Volume (ml)', fontsize=12)
    axes[0].set_title('Tissue Volumes', fontsize=14)

    # Pie chart of voxel distribution
    axes[1].pie(voxels, labels=names, autopct='%1.1f%%', colors=colors, startangle=90)
    axes[1].set_title('Tissue Distribution', fontsize=14)

    plt.tight_layout()
    plt.show()

## 11. Preprocessing Preview

Preview the effect of preprocessing transforms.

In [None]:
# Preprocessing preview: clip intensity outliers, then z-score normalize.
print("Applying preprocessing...")

# Untouched copy of the raw volume, kept as the comparison baseline.
original = mri_volume.copy()

# Step 1: clip intensities to the 1st-99th percentile range.
clipped = intensity_clipping(mri_volume, percentile_low=1, percentile_high=99)
# Step 2: z-score normalize the clipped volume.
normalized = zscore_normalize(clipped)

# Side-by-side statistics for each pipeline stage, printed as one block.
stage_report = [
    f"\nOriginal: min={original.min():.2f}, max={original.max():.2f}, mean={original.mean():.2f}",
    f"Clipped:  min={clipped.min():.2f}, max={clipped.max():.2f}, mean={clipped.mean():.2f}",
    f"Normalized: min={normalized.min():.2f}, max={normalized.max():.2f}, "
    f"mean={normalized.mean():.4f}, std={normalized.std():.4f}",
]
print("\n".join(stage_report))

In [None]:
# Visualize preprocessing effects: the middle axis-0 slice (top row) and the
# intensity histogram (bottom row) for each pipeline stage, one per column.
mid_z = mri_volume.shape[0] // 2

pipeline_stages = (
    # (volume, slice title, histogram title, histogram colour)
    (original, 'Original', 'Original Histogram', 'steelblue'),
    (clipped, 'After Clipping', 'Clipped Histogram', 'coral'),
    (normalized, 'After Z-score Normalization', 'Normalized Histogram', 'green'),
)

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for column, (stage_volume, slice_title, hist_title, hist_color) in enumerate(pipeline_stages):
    slice_ax, hist_ax = axes[0, column], axes[1, column]

    slice_ax.imshow(stage_volume[mid_z, :, :], cmap='gray', origin='lower')
    slice_ax.set_title(slice_title, fontsize=12)
    slice_ax.axis('off')

    hist_ax.hist(stage_volume.flatten(), bins=100, color=hist_color, alpha=0.7)
    hist_ax.set_title(hist_title, fontsize=12)
    hist_ax.set_xlabel('Intensity')

plt.suptitle('Preprocessing Pipeline Visualization', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 12. Summary

Key findings from the EDA:

In [None]:
# Final summary of the explored volume and the preprocessing applied.
# Label ids and their voxel counts are computed once here and reused below,
# rather than re-running np.unique and a full-volume comparison per class.
present_labels, present_counts = np.unique(label_volume, return_counts=True)
voxels_per_label = {int(l): int(c) for l, c in zip(present_labels, present_counts)}
# Background (label 0) is not a tissue, so it is excluded from the class count.
n_tissue_classes = sum(1 for class_id in voxels_per_label if class_id != 0)

print("\n" + "=" * 60)
print(" EDA Summary")
print("=" * 60)
print(f"\nVolume shape: {mri_volume.shape}")
print(f"Number of tissue classes: {n_tissue_classes} (excluding background)")
print(f"Unique labels: {present_labels.tolist()}")
print(f"\nIntensity range: [{mri_volume.min():.1f}, {mri_volume.max():.1f}]")
print("\nPreprocessing applied:")
print("  1. Intensity clipping (1st-99th percentile)")
print("  2. Z-score normalization")
print("\nReady for training with the following tissue classes:")
for idx, name in CLASS_NAMES.items():
    if idx in voxels_per_label:
        print(f"  {idx}: {name} ({voxels_per_label[idx]:,} voxels)")
print("\n" + "=" * 60)

## Next Steps

1. **Preprocessing**: Apply the preprocessing pipeline to all samples
2. **Dataset**: Create PyTorch Dataset with augmentation
3. **Model**: Train the 3D U-Net
4. **Evaluation**: Evaluate using multi-class Dice scores