In [193]:
import os

import nibabel as nib
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
def load_nifti_file(filepath):
    """Read the NIfTI image at ``filepath`` and return its voxel data.

    The data comes from nibabel's ``get_fdata()`` on the loaded image.
    """
    return nib.load(filepath).get_fdata()

def normalize(volume, min_val=-1000, max_val=400):
    """Clip ``volume`` to ``[min_val, max_val]`` and standardize it.

    Standardizing means subtracting the mean and dividing by the standard deviation.

    NOTE(review): the default window (-1000, 400) looks like a CT Hounsfield-unit
    window, but this notebook feeds it MRI volumes whose intensity scale is
    different. Confirm the window is intended (it is now a parameter).

    Args:
        volume: array-like of intensities.
        min_val: lower clip bound.
        max_val: upper clip bound.

    Returns:
        The clipped and standardized array. If the clipped volume is constant
        (std == 0), an array of zeros of the same shape is returned instead of
        the NaNs a division by zero would produce.
    """
    clipped = np.clip(volume, min_val, max_val)
    mean = np.mean(clipped)
    std = np.std(clipped)
    if std == 0:
        # Constant input (e.g. an empty slab): z-score is undefined; 0 is its centered value
        return np.zeros(np.shape(clipped))
    return (clipped - mean) / std

# Dataset layout: {phase: {sub-folder under <data_dir>/<phase>: file-name modality suffix}}.
# Only the 'Training' phase is listed; it is the one with a 'mask' (seg) entry.
path = {
    'Training': dict(
        flair='flair',
        t1='t1',
        t1ce='t1ce',
        t2='t2',
        mask='seg',
    ),
}

data_dir = './Dataset'               # root of the raw NIfTI files
output_dir = './Dataset/Processed'   # where the preprocessed .npy files are written
def preprocess_data(data_dir, output_dir, layout=None, num_subjects=369):
    """Convert the BraTS2020 NIfTI volumes to ``.npy`` files.

    Only the ``'Training'`` phase of ``layout`` is processed. For each
    ``{folder: modality}`` pair and each subject 1..``num_subjects``, the
    function reads
    ``<data_dir>/Training/<folder>/BraTS20_Training_<NNN>_<modality>.nii`` and
    writes ``<output_dir>/Training/<folder>/<same name>.npy``.

    Image modalities are passed through ``normalize``. The segmentation mask
    (modality ``'seg'``) is saved as raw integer labels: clipping and z-scoring
    a label map would destroy its class values.

    Args:
        data_dir: root of the raw dataset.
        output_dir: root under which the ``.npy`` files are written.
        layout: ``{phase: {folder: modality}}``; defaults to the module-level ``path``.
        num_subjects: highest subject id to read (ids are 1-based, zero-padded to 3 digits).
    """
    if layout is None:
        layout = path
    for phase, modalities in layout.items():
        if phase != 'Training':
            continue
        for folder, modality in modalities.items():
            src_dir = os.path.join(data_dir, phase, folder)
            dst_dir = os.path.join(output_dir, phase, folder)
            os.makedirs(dst_dir, exist_ok=True)
            for subject in range(1, num_subjects + 1):
                file_name = f'BraTS20_{phase}_{subject:03d}_{modality}'
                volume = load_nifti_file(os.path.join(src_dir, file_name + '.nii'))
                if modality == 'seg':
                    # get_fdata() returns floats; store the label ids compactly and exactly
                    volume = volume.astype(np.uint8)
                else:
                    volume = normalize(volume)
                np.save(os.path.join(dst_dir, file_name + '.npy'), volume)


preprocess_data(data_dir, output_dir)

In [None]:
class CustomDataGen(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(X, y)`` batches of 3D volumes and their labels.

    Each entry of ``image_filenames`` / ``labels`` is a path to a ``.npy`` file
    holding one volume of shape ``image_size``. Batches have shape
    ``(n, *image_size, 1)`` with ``n <= batch_size``; the final batch may be smaller.

    Args:
        image_filenames: paths to the input volumes.
        labels: paths to the matching label volumes (same length and order).
        batch_size: maximum number of samples per batch.
        image_size: spatial shape of every stored volume.
        augment: if True, reshuffle the sample order every epoch and apply random
            flips. Each flip is applied identically to a volume and its label.
        binarize_labels: if True, labels become a 0/1 mask (``label > 0``). This
            matches a single-channel sigmoid output trained with binary cross-entropy.
        flip_axes: spatial axes (0-based, excluding batch/channel) that may be
            flipped when augmenting. NOTE(review): which array axis is
            left-right depends on the NIfTI orientation; confirm.
        seed: seed for the shuffling/augmentation RNG (None = nondeterministic).
    """

    def __init__(self, image_filenames, labels, batch_size, image_size, augment=False,
                 binarize_labels=True, flip_axes=(0,), seed=None):
        super().__init__()
        if len(image_filenames) != len(labels):
            raise ValueError(
                f"got {len(image_filenames)} image files but {len(labels)} label files")
        self.image_filenames = list(image_filenames)
        self.labels = list(labels)
        self.batch_size = batch_size
        self.image_size = tuple(image_size)
        self.augment = augment
        self.binarize_labels = binarize_labels
        self.flip_axes = tuple(flip_axes)
        self._rng = np.random.default_rng(seed)
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (the last one may be partial)."""
        return int(np.ceil(len(self.image_filenames) / self.batch_size))

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when augmenting."""
        self.indexes = np.arange(len(self.image_filenames))
        if self.augment:
            self._rng.shuffle(self.indexes)

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(X, y)``."""
        batch_indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        batch_filenames = [self.image_filenames[i] for i in batch_indexes]
        batch_labels = [self.labels[i] for i in batch_indexes]

        X, y = self._data_generation(batch_filenames, batch_labels)
        if self.augment:
            X, y = self._random_flip(X, y)
        return X, y

    def _random_flip(self, X, y):
        """Flip each sample along each of ``flip_axes`` with probability 0.5.

        The same flip is applied to a volume and its label. This replaces a 2D
        ``ImageDataGenerator``, which rejects 5D volume batches and would only
        transform ``X``, never the mask.
        """
        for i in range(len(X)):
            for axis in self.flip_axes:
                if self._rng.random() < 0.5:
                    # .copy(): avoid assigning an overlapping reversed view onto itself
                    X[i] = np.flip(X[i], axis=axis).copy()
                    y[i] = np.flip(y[i], axis=axis).copy()
        return X, y

    def _data_generation(self, batch_filenames, batch_labels):
        """Load one batch of volumes and labels into float32 arrays with a channel axis."""
        n = len(batch_filenames)
        # Sized by the real batch length so the final, smaller batch has no uninitialized rows
        X = np.empty((n, *self.image_size, 1), dtype=np.float32)
        y = np.empty((n, *self.image_size, 1), dtype=np.float32)

        for i, (image_path, label_path) in enumerate(zip(batch_filenames, batch_labels)):
            X[i, ..., 0] = np.load(image_path)
            label = np.load(label_path)
            if self.binarize_labels:
                label = label > 0
            y[i, ..., 0] = label

        return X, y


In [None]:
def get_file_paths(data_dir):
    """Collect matching FLAIR image and segmentation label paths under ``data_dir``.

    ``data_dir`` is walked recursively. A file whose name contains ``'flair'`` is
    an image; otherwise, one whose name contains ``'seg'`` is a label. Images and
    labels are paired by file name with that token removed, e.g.
    ``BraTS20_Training_001_flair.npy`` <-> ``BraTS20_Training_001_seg.npy``.

    Returns:
        ``(image_paths, label_paths)``: two equal-length lists. Images are sorted
        by path, and ``label_paths[i]`` belongs to ``image_paths[i]``.

    Raises:
        FileNotFoundError: if ``data_dir`` is not a directory (``os.walk`` would
            otherwise silently yield nothing).
        ValueError: if an image has no label or vice versa, or a pairing key is
            duplicated. Positional pairing of independently sorted lists would
            silently misalign every later pair.
    """
    if not os.path.isdir(data_dir):
        raise FileNotFoundError(f"data directory not found: {data_dir}")

    images = {}
    labels = {}
    for root, _, files in os.walk(data_dir):
        for file in files:
            if 'flair' in file:
                bucket, key = images, file.replace('flair', '', 1)
            elif 'seg' in file:
                bucket, key = labels, file.replace('seg', '', 1)
            else:
                continue
            if key in bucket:
                raise ValueError(f"duplicate file name under {data_dir}: {file}")
            bucket[key] = os.path.join(root, file)

    unmatched = sorted(set(images) ^ set(labels))
    if unmatched:
        raise ValueError(
            f"{len(unmatched)} unpaired image/label file(s) under {data_dir}, e.g. {unmatched[:5]}")

    pairs = sorted((image_path, key) for key, image_path in images.items())
    image_paths = [image_path for image_path, _ in pairs]
    label_paths = [labels[key] for _, key in pairs]
    return image_paths, label_paths


In [None]:
data_dir = './Dataset/Processed'
image_size = (240, 240, 155)
batch_size = 2
epochs = 100
VAL_FRACTION = 0.2   # share of processed Training subjects held out for validation
SEED = 42            # fixes the train/val split

# preprocess_data writes to <data_dir>/Training/<modality>/ and creates no 'train'/'val'
# folders. Use such folders if they exist; otherwise split the processed Training
# subjects deterministically.
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'val')
if os.path.isdir(train_dir) and os.path.isdir(val_dir):
    train_image_paths, train_label_paths = get_file_paths(train_dir)
    val_image_paths, val_label_paths = get_file_paths(val_dir)
else:
    all_image_paths, all_label_paths = get_file_paths(os.path.join(data_dir, 'Training'))
    order = np.random.default_rng(SEED).permutation(len(all_image_paths))
    n_val = int(round(len(order) * VAL_FRACTION))
    val_idx, train_idx = order[:n_val], order[n_val:]
    train_image_paths = [all_image_paths[i] for i in train_idx]
    train_label_paths = [all_label_paths[i] for i in train_idx]
    val_image_paths = [all_image_paths[i] for i in val_idx]
    val_label_paths = [all_label_paths[i] for i in val_idx]

train_gen = CustomDataGen(train_image_paths, train_label_paths, batch_size, image_size, augment=True)
val_gen = CustomDataGen(val_image_paths, val_label_paths, batch_size, image_size, augment=False)


In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, Concatenate, Activation
from tensorflow.keras.models import Model

def conv_block(inputs, num_filters):
    """Apply two rounds of a 3x3x3 'same'-padded ``Conv3D`` followed by ReLU.

    Args:
        inputs: input tensor.
        num_filters: number of filters in each convolution.

    Returns:
        The output tensor of the second ReLU.
    """
    x = inputs
    for _ in range(2):
        x = Conv3D(num_filters, 3, padding='same')(x)
        x = Activation('relu')(x)
    return x

def encoder_block(inputs, num_filters):
    """Run a conv block, then downsample by 2 on every spatial axis.

    Returns:
        ``(features, pooled)``: the pre-pooling features (used as the skip
        connection) and the max-pooled tensor passed to the next level.
    """
    features = conv_block(inputs, num_filters)
    pooled = MaxPooling3D(pool_size=2)(features)
    return features, pooled

def decoder_block(inputs, skip_features, num_filters):
    """Upsample ``inputs`` by 2, concatenate with ``skip_features``, then apply a conv block.

    Max-pooling floors odd sizes (e.g. 19 -> 9), so the 2x-upsampled tensor can
    be one voxel shorter than its skip connection (18 vs 19). ``Concatenate``
    would then fail. The upsampled tensor is therefore zero-padded at the end of
    each spatial axis to the skip connection's size. This lets depths such as
    155 build. Assumes channels-last tensors ``(batch, d1, d2, d3, channels)``.
    """
    x = UpSampling3D(size=(2, 2, 2))(inputs)

    padding = []
    for up_dim, skip_dim in zip(x.shape[1:4], skip_features.shape[1:4]):
        # Unknown (None) static dims cannot be reconciled here; leave them unpadded
        diff = 0 if up_dim is None or skip_dim is None else skip_dim - up_dim
        padding.append((0, max(diff, 0)))
    if any(after for _, after in padding):
        x = tf.keras.layers.ZeroPadding3D(padding=tuple(padding))(x)

    x = Concatenate()([x, skip_features])
    return conv_block(x, num_filters)
def unet_model(input_shape):
    """Assemble a 4-level 3D U-Net.

    The encoder levels use 32, 64, 128 and 256 filters, with a 512-filter bridge.

    Args:
        input_shape: shape of one sample including its channel axis.

    Returns:
        An uncompiled ``Model`` named ``'3d_unet'``. It ends in a 1-filter,
        1x1x1 ``Conv3D`` with sigmoid activation.
    """
    inputs = Input(input_shape)

    # Contracting path: keep each level's pre-pooling features for the skip connections
    skips = []
    x = inputs
    for filters in (32, 64, 128, 256):
        features, x = encoder_block(x, filters)
        skips.append((features, filters))

    x = conv_block(x, 512)  # bridge

    # Expanding path: consume the skips deepest-first with the matching filter count
    for features, filters in reversed(skips):
        x = decoder_block(x, features, filters)

    outputs = Conv3D(1, 1, padding='same', activation='sigmoid')(x)
    return Model(inputs, outputs, name='3d_unet')


In [None]:
# One single-channel volume of spatial shape image_size per sample
model = unet_model(input_shape=tuple(image_size) + (1,))

In [None]:
# NOTE(review): voxel-wise 'accuracy' is likely dominated by background voxels;
# consider adding an overlap metric such as Dice.
model.compile(optimizer=Adam(learning_rate=1e-4), loss='binary_crossentropy', metrics=['accuracy'])

# Write the model to disk whenever val_loss improves
checkpoint = ModelCheckpoint('unet_brats2020.h5', save_best_only=True, monitor='val_loss', mode='min')
# restore_best_weights: when early stopping triggers, roll the in-memory model back to
# its best epoch instead of keeping the last (worse) one, which is saved afterwards
early_stopping = EarlyStopping(monitor='val_loss', patience=10, mode='min', restore_best_weights=True)

history = model.fit(train_gen, validation_data=val_gen, epochs=epochs, callbacks=[checkpoint, early_stopping])


In [None]:
# Keras 3 (TF >= 2.16) requires save_weights() paths to end in '.weights.h5'.
# Older tf.keras infers HDF5 from the '.h5' suffix, so this name works on both.
model.save_weights('model_weights.weights.h5')