In [1]:
import os
import nibabel as nib
import numpy as np
import tensorflow as tf

data_dir = r'C:\Users\DELL\Graduation Project\Datasets\Testing Dataloader\BraTS2021_00000'
batch_size = 32
img_height = 256
img_width = 256

# Define the ImageDataGenerator with appropriate preprocessing parameters
datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

# Create a generator for loading the Nifti volumes
def nifti_generator(data_dir, batch_size):
    # List the Nifti files in the data directory
    nifti_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.nii.gz') or f.endswith('.nii')]
    
    # Load the Nifti volumes and yield batches of images
    while True:
        np.random.shuffle(nifti_files)
        for i in range(0, len(nifti_files), batch_size):
            batch_files = nifti_files[i:i+batch_size]
            batch_images = []
            for file in batch_files:
                nifti = nib.load(file)
                # Get the image data from the Nifti volume
                image = nifti.get_fdata()
                # Resize the image to the desired size
                image = tf.image.resize(image, [img_height, img_width])
                # Add the image to the batch
                batch_images.append(image)
            # Apply the data augmentation and yield the batch of images
            batch_images = np.stack(batch_images, axis=0)
            yield datagen.flow(batch_images, batch_size=batch_size).next()

# Create an image dataset from the Nifti generator
nifti_dataset = tf.data.Dataset.from_generator(
    nifti_generator, args=[data_dir, batch_size], 
    output_types=(tf.float32, tf.float32), 
    output_shapes=((batch_size, img_height, img_width, 1), (batch_size,)))


In [5]:
print(nifti_dataset)

<FlatMapDataset shapes: ((32, 256, 256, 1), (32,)), types: (tf.float32, tf.float32)>
