In [None]:
# ANALYZE: Check Z dimension distribution across all volumes
def analyze_dataset_dimensions(volume_dir='volume'):
    """
    Scan every ``.nii`` volume in ``volume_dir`` and report shape/spacing statistics.

    Only header accessors (``get_data_shape`` / ``get_zooms``) are queried on
    each loaded image; this function never asks for the voxel data itself.
    The report (Z-dimension statistics, a slice-count histogram, voxel
    spacing ranges and patch-size recommendations) is printed to stdout.

    Args:
        volume_dir: Directory containing the ``.nii`` files. Files are
            processed in sorted filename order.

    Returns:
        dict with keys:
            'z_dims'   - np.ndarray holding ``shape[2]`` of every volume
            'spacings' - np.ndarray of the first three zooms per volume
            'shapes'   - list of the full header shapes
            'min_z'    - smallest Z dimension (int)
            'p5_z'     - 5th percentile of Z, truncated to int
            'median_z' - median Z, truncated to int

    Raises:
        ValueError: if ``volume_dir`` contains no ``.nii`` files.
    """

    print("Analyzing dataset dimensions...")
    print("This may take 1-2 minutes...\n")

    # List the directory once and iterate the result directly, so the number
    # of volumes and their filenames can never disagree.
    volume_files = sorted(f for f in os.listdir(volume_dir) if f.endswith('.nii'))
    if not volume_files:
        # Fail early with a clear message instead of NumPy's opaque
        # "zero-size array" error further down.
        raise ValueError(f"No .nii volumes found in {volume_dir!r}")
    total_volumes = len(volume_files)

    shapes = []
    spacings = []

    # Scan all volumes (header fields only)
    for scanned, filename in enumerate(volume_files, start=1):
        header = nib.load(os.path.join(volume_dir, filename)).header
        shapes.append(header.get_data_shape())
        spacings.append(header.get_zooms()[:3])

        if scanned % 20 == 0:
            print(f"  Scanned {scanned}/{total_volumes}...")

    # Index 2 of each shape is treated as the Z (slice) axis.
    z_dims = np.array([shape[2] for shape in shapes])
    spacings = np.array(spacings)
    n_volumes = len(z_dims)

    print("\n" + "=" * 70)
    print("Z DIMENSION STATISTICS")
    print("=" * 70)
    print(f"Total volumes: {n_volumes}")
    print("\nZ Dimension (number of slices):")
    print(f"  Minimum:         {z_dims.min()} slices")
    print(f"  5th percentile:  {np.percentile(z_dims, 5):.0f} slices")
    print(f"  25th percentile: {np.percentile(z_dims, 25):.0f} slices")
    print(f"  Median:          {np.median(z_dims):.0f} slices")
    print(f"  75th percentile: {np.percentile(z_dims, 75):.0f} slices")
    print(f"  95th percentile: {np.percentile(z_dims, 95):.0f} slices")
    print(f"  Maximum:         {z_dims.max()} slices")
    print(f"  Mean:            {z_dims.mean():.1f} slices")
    print(f"  Std Dev:         {z_dims.std():.1f} slices")

    print("\nDistribution:")
    bins = [(0, 64), (64, 96), (96, 128), (128, 160), (160, 200), (200, 300)]
    for low, high in bins:
        count = np.sum((z_dims >= low) & (z_dims < high))
        pct = 100 * count / n_volumes
        print(f"  {low:3d}-{high:3d} slices: {count:3d} volumes ({pct:5.1f}%)")

    # Volumes at or above the last bin edge would otherwise be missing from
    # the histogram; report them only when present so that the output is
    # unchanged for datasets that fit within the fixed bins.
    top_edge = bins[-1][1]
    overflow = np.sum(z_dims >= top_edge)
    if overflow:
        pct = 100 * overflow / n_volumes
        print(f"  {top_edge:3d}+    slices: {overflow:3d} volumes ({pct:5.1f}%)")

    print("\n" + "=" * 70)
    print("VOXEL SPACING STATISTICS (mm)")
    print("=" * 70)
    for axis, label in enumerate("XYZ"):
        column = spacings[:, axis]
        print(f"{label} spacing: min={column.min():.3f}, max={column.max():.3f}, median={np.median(column):.3f}")

    print("\n" + "=" * 70)
    print("RECOMMENDATIONS FOR PATCH SIZE")
    print("=" * 70)

    # Calculate recommendations (percentiles are truncated to whole slices)
    min_z = int(z_dims.min())
    p5_z = int(np.percentile(z_dims, 5))
    median_z = int(np.median(z_dims))

    volumes_below_96 = np.sum(z_dims < 96)
    volumes_below_128 = np.sum(z_dims < 128)

    print("\nOption 1: Conservative (works for ALL volumes)")
    print(f"  Patch size: (128, 128, {min_z})")
    print("  Covers: 100% of volumes")
    print(f"  Trade-off: {'Small Z context' if min_z < 80 else 'Reasonable Z context'}")

    print("\nOption 2: 95th percentile (skip 5% thinnest)")
    print(f"  Patch size: (128, 128, {p5_z})")
    print("  Covers: 95% of volumes")
    print(f"  Skip: {np.sum(z_dims < p5_z)} volumes")

    print("\nOption 3: Use 128×128×128 (standard)")
    print("  Patch size: (128, 128, 128)")
    print(f"  Works without padding: {np.sum(z_dims >= 128)} volumes ({100*np.mean(z_dims >= 128):.1f}%)")
    print(f"  Need padding/skip: {volumes_below_128} volumes ({100*volumes_below_128/n_volumes:.1f}%)")

    print("\nOption 4: Use 128×128×96")
    print("  Patch size: (128, 128, 96)")
    print(f"  Works without padding: {np.sum(z_dims >= 96)} volumes ({100*np.mean(z_dims >= 96):.1f}%)")
    print(f"  Need padding/skip: {volumes_below_96} volumes ({100*volumes_below_96/n_volumes:.1f}%)")

    # Final recommendation: prefer the largest Z patch for which at most 5%
    # of the volumes are too thin.
    print("\n" + "=" * 70)
    if volumes_below_128 <= 0.05 * n_volumes:  # 5% or less
        print("✓ RECOMMENDATION: Use (128, 128, 128)")
        print(f"  Only {volumes_below_128} volumes need special handling")
        print("  Option: Skip thin volumes or pad them")
    elif volumes_below_96 <= 0.05 * n_volumes:
        print("✓ RECOMMENDATION: Use (128, 128, 96)")
        print(f"  Good balance: covers {100*np.mean(z_dims >= 96):.1f}% without modification")
    else:
        print(f"✓ RECOMMENDATION: Use (128, 128, {p5_z})")
        print("  Covers 95% of volumes without padding")
    print("=" * 70)

    return {
        'z_dims': z_dims,
        'spacings': spacings,
        'shapes': shapes,
        'min_z': min_z,
        'p5_z': p5_z,
        'median_z': median_z
    }

# Kick off the scan of the default 'volume' directory and keep the returned
# summary dict in `stats`.
print("Starting dimension analysis...\n")
stats = analyze_dataset_dimensions(volume_dir='volume')