In [1]:
import os
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
from PIL import Image

In [2]:
def load_nii(file_path):
    """Load NIfTI file and return 3D volume data."""
    try:
        img = nib.load(file_path)
        data = img.get_fdata()
        print(f"Loaded volume shape: {data.shape}")
        return data
    except Exception as e:
        raise ValueError(f"Error loading {file_path}: {str(e)}")

In [4]:
def normalize_slice(slice_data):
    """Normalize slice to 0-255 for image saving."""
    slice_data = (slice_data - np.min(slice_data)) / (np.max(slice_data) - np.min(slice_data)) * 255
    return slice_data.astype(np.uint8)

def save_slice(slice_data, output_dir, slice_idx, plane="axial"):
    """Save a single slice as PNG."""
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{plane}_slice_{slice_idx:04d}.png")
    Image.fromarray(slice_data).save(output_path)

In [5]:
def process_volume(volume, output_dir, plane="axial"):
    """Extract and save all slices from a 3D volume along a given plane."""
    if plane == "axial":
        slices = [volume[:, :, i] for i in range(volume.shape[2])]
    elif plane == "coronal":
        slices = [volume[:, i, :] for i in range(volume.shape[1])]
    elif plane == "sagittal":
        slices = [volume[i, :, :] for i in range(volume.shape[0])]
    else:
        raise ValueError("Invalid plane. Choose: 'axial', 'coronal', or 'sagittal'")

    # Parallelize saving slices (faster for large volumes)
    with Pool(cpu_count()) as pool:
        args = [(normalize_slice(slice.T), output_dir, i, plane) for i, slice in enumerate(slices)]
        pool.starmap(save_slice, tqdm(args, desc=f"Saving {plane} slices"))

In [None]:
# User inputs
nii_file = "dataset/subject01.nii"  # Replace with your .nii path
output_dir = "output_slices"         # Output folder
plane = "axial"                     # Slice plane: "axial", "coronal", or "sagittal"

# Load and process
volume = load_nii(nii_file)
process_volume(volume, output_dir, plane)
print(f"✅ All slices saved to: {output_dir}")

Loaded volume shape: (176, 256, 256)


  slice_data = (slice_data - np.min(slice_data)) / (np.max(slice_data) - np.min(slice_data)) * 255
  return slice_data.astype(np.uint8)
Saving axial slices: 100%|██████████| 256/256 [00:00<00:00, 379.46it/s]
