In [129]:
# Dataset
import os
import numpy as np
import nibabel as nib

def load_nifti_file(filepath):
    """Read the NIfTI image at ``filepath`` and return its voxel data.

    The array comes from nibabel's ``get_fdata()``, so it is floating point.
    """
    image = nib.load(filepath)
    voxels = image.get_fdata()
    return voxels

def normalize(volume, min_val=-1000, max_val=400):
    """Clip a volume to ``[min_val, max_val]`` and standardize it.

    The clipped volume is shifted to zero mean and scaled to unit standard
    deviation, with both statistics taken over all voxels.

    Args:
        volume: array-like of voxel intensities.
        min_val: lower clipping bound, or None for no lower bound.
        max_val: upper clipping bound, or None for no upper bound.
            NOTE(review): the default window (-1000, 400) looks like a CT
            Hounsfield-unit window. For the MRI modalities that Datagen feeds
            in, it may saturate much of the signal; confirm it is intended.

    Returns:
        A float ndarray with the same shape as ``volume``. If the clipped
        volume is constant (std == 0), the result is all zeros rather than
        the NaNs a plain division would produce.
    """
    if min_val is None and max_val is None:
        clipped = np.asarray(volume)
    else:
        clipped = np.clip(volume, min_val, max_val)
    centred = clipped - np.mean(clipped)
    std = np.std(clipped)
    if std == 0:
        # Constant input (e.g. an empty/background-only volume): avoid 0/0.
        return centred
    return centred / std

# Source .nii files are read from input_dir; normalized .npy copies go to output_dir.
input_dir = './Dataset'
output_dir = './Dataset/Processed'

# Number of cases per phase; Datagen uses it as the upper bound on case ids.
total_size = {
    'Training': 369,
    'Validation': 125,
}

# Files loaded per case for each phase. Datagen treats 'seg' as the mask and
# every other entry as an input channel; only Training lists 'seg'.
# NOTE(review): `input` shadows the built-in input(). Renaming it also
# requires updating Datagen, which reads it as a global.
input = {
    'Training': ['flair', 't1', 't1ce', 't2', 'seg'],
    'Validation': ['flair', 't1', 't1ce', 't2'],
}
def Datagen(phase, chunk_index, chunk_size, input_dir, output_dir, volume_shape=(240, 240, 155)):
    """Load, normalize, cache and stack one chunk of cases.

    Case ids are 1-based. Chunk ``chunk_index`` covers ids
    ``chunk_index * chunk_size + 1`` .. ``(chunk_index + 1) * chunk_size``,
    clipped to ``total_size[phase]``. For each case:

    * every non-'seg' entry of ``input[phase]`` is read from
      ``<input_dir>/<phase>/<modality>/BraTS20_<phase>_<id>_<modality>.nii``,
      passed through ``normalize``, saved as ``.npy`` under
      ``<output_dir>/<phase>/chunk_<chunk_index>/<id>/`` and stored as one
      channel of ``X`` (channels follow the order in ``input[phase]``);
    * a 'seg' entry is read from ``<input_dir>/<phase>/mask/``, saved as
      ``.npy`` unnormalized and, for Training, stored in ``y``.

    Args:
        phase: key into ``total_size`` and ``input`` (e.g. 'Training').
        chunk_index: zero-based chunk number.
        chunk_size: maximum number of cases per chunk.
        input_dir: root directory of the source ``.nii`` files.
        output_dir: root directory for the cached ``.npy`` files.
        volume_shape: spatial shape every loaded volume must have.

    Returns:
        ``(X, y)`` when ``phase == 'Training'``, otherwise ``X``.
        ``X`` has shape ``(n_cases, *volume_shape, n_channels)`` (float64);
        ``y`` has shape ``(n_cases, *volume_shape, 1)`` (int). ``n_cases`` is
        the number of ids actually in the chunk: the last chunk may be shorter
        than ``chunk_size``, and a chunk past the end is empty.
    """
    start_id = chunk_index * chunk_size + 1
    end_id = min((chunk_index + 1) * chunk_size, total_size[phase])
    case_ids = range(start_id, end_id + 1)  # empty when start_id > end_id

    files = input[phase]
    # Channel index is independent of where 'seg' sits in the list.
    image_modalities = [m for m in files if m != 'seg']

    # Size by the cases actually loaded so no row is left uninitialized.
    # At the default shape with 4 channels, each case costs ~286 MB of float64.
    X = np.zeros((len(case_ids), *volume_shape, len(image_modalities)))
    # Only Training returns labels, so only Training allocates them.
    y = np.zeros((len(case_ids), *volume_shape, 1), dtype=int) if phase == 'Training' else None

    for row, case_num in enumerate(case_ids):
        case_id = f"{case_num:03d}"
        npy_folder_path = os.path.join(output_dir, phase, f'chunk_{chunk_index}', case_id)
        os.makedirs(npy_folder_path, exist_ok=True)

        for channel, modality in enumerate(image_modalities):
            name = f'BraTS20_{phase}_{case_id}_{modality}'
            volume = normalize(load_nifti_file(os.path.join(input_dir, phase, modality, name + '.nii')))
            np.save(os.path.join(npy_folder_path, name + '.npy'), volume)
            # Use the in-memory array directly; re-reading the .npy is redundant.
            X[row, ..., channel] = volume

        if 'seg' in files:
            name = f'BraTS20_{phase}_{case_id}_seg'
            mask = load_nifti_file(os.path.join(input_dir, phase, 'mask', name + '.nii'))
            np.save(os.path.join(npy_folder_path, name + '.npy'), mask)
            if y is not None:
                y[row] = mask.reshape((*volume_shape, 1))

    if phase == 'Training':
        return X, y
    return X

In [130]:
# Load (and cache as .npy) the first chunk of 10 Training cases.
X_train, y_train = Datagen(
    phase='Training',
    chunk_index=0,
    chunk_size=10,
    input_dir=input_dir,
    output_dir=output_dir,
)

In [131]:
# Load (and cache as .npy) the first chunk of 10 Validation cases; no masks.
X_val = Datagen(
    phase='Validation',
    chunk_index=0,
    chunk_size=10,
    input_dir=input_dir,
    output_dir=output_dir,
)

In [None]:
# Model Architecture


In [None]:
# Train 

In [None]:
# Evaluate

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, Concatenate, Activation
from tensorflow.keras.models import Model

def conv_block(inputs, num_filters):
    """Apply two size-3, 'same'-padded Conv3D layers, each followed by ReLU.

    Args:
        inputs: Keras tensor to transform.
        num_filters: output channels of both convolutions.

    Returns:
        The Keras tensor produced by the second ReLU.
    """
    x = inputs
    for _ in range(2):
        x = Conv3D(num_filters, 3, padding='same')(x)
        x = Activation('relu')(x)
    return x

def encoder_block(inputs, num_filters):
    """Run ``conv_block`` and then a (2, 2, 2) max-pool.

    Returns:
        ``(features, pooled)``. ``features`` is the pre-pooling conv output,
        kept for the decoder's skip connection. ``pooled`` is its downsampled
        version, which feeds the next encoder level.
    """
    features = conv_block(inputs, num_filters)
    pooled = MaxPooling3D(pool_size=(2, 2, 2))(features)
    return features, pooled

def decoder_block(inputs, skip_features, num_filters):
    """Upsample by (2, 2, 2), concatenate the skip features, then ``conv_block``.

    The upsampled tensor must match ``skip_features`` in every non-channel
    dim, or ``Concatenate`` fails.
    """
    upsampled = UpSampling3D(size=(2, 2, 2))(inputs)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)
def unet_model(input_shape, num_classes=1, base_filters=32):
    """Build a 3D U-Net: four encoder levels, a bridge, four decoder levels.

    Args:
        input_shape: per-sample shape without the batch axis, assumed
            channels-last ``(D, H, W, C)`` (the Keras default). Every dim
            except the last must be a multiple of 16 (or None). The four
            2x max-poolings floor odd sizes, after which an upsampled tensor
            no longer matches its skip connection in ``Concatenate``.
        num_classes: output channels. 1 (default) gives one sigmoid channel;
            more than 1 gives a per-voxel softmax over ``num_classes``.
        base_filters: filters at the first level, doubled at each deeper
            level. The default 32 gives 32/64/128/256 plus a 512-filter bridge.

    Returns:
        An uncompiled Keras ``Model`` named '3d_unet'.

    Raises:
        ValueError: if ``num_classes`` < 1, or a spatial dim of
            ``input_shape`` is not a multiple of 16.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")
    depth = 4
    factor = 2 ** depth
    # Fail fast with a clear message instead of a shape error deep in Concatenate.
    bad_dims = [d for d in tuple(input_shape)[:-1] if d is not None and d % factor]
    if bad_dims:
        raise ValueError(
            f"spatial dims of input_shape {tuple(input_shape)} must be multiples of "
            f"{factor} for a {depth}-level U-Net; offending dims: {bad_dims}"
        )

    inputs = Input(input_shape)

    # Encoder: keep each level's pre-pooling features for the skip connections.
    skips = []
    x = inputs
    for level in range(depth):
        skip, x = encoder_block(x, base_filters * 2 ** level)
        skips.append(skip)

    # Bridge at the coarsest resolution.
    x = conv_block(x, base_filters * 2 ** depth)

    # Decoder: walk back up, pairing each level with its matching skip.
    for level in reversed(range(depth)):
        x = decoder_block(x, skips[level], base_filters * 2 ** level)

    activation = 'sigmoid' if num_classes == 1 else 'softmax'
    outputs = Conv3D(num_classes, 1, padding='same', activation=activation)(x)

    return Model(inputs, outputs, name='3d_unet')
# Build, compile, train and save the 3D U-Net.
# These Keras names are used below but were not imported anywhere else.
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import Adam

# TODO(review): `image_size`, `epochs`, `train_gen` and `val_gen` are not defined
# in this notebook; define them before running this cell.
# NOTE(review): the model takes 1 input channel and a binary loss, but Datagen
# stacks 4 modalities per case and keeps the raw mask label values. Confirm the
# generators select/binarize accordingly. Also note that Datagen returns no
# labels for the 'Validation' phase.
model = unet_model(input_shape=(*image_size, 1))
model.compile(optimizer=Adam(learning_rate=1e-4), loss='binary_crossentropy', metrics=['accuracy'])

# Keras 3 requires full-model checkpoints to end in '.keras' and weight files
# to end in '.weights.h5'; the old '.h5' names raise there.
checkpoint = ModelCheckpoint('unet_brats2020.keras', save_best_only=True, monitor='val_loss', mode='min')
# When early stopping triggers, roll back to the best val_loss epoch so the
# weights saved below are the best ones rather than `patience` epochs later.
early_stopping = EarlyStopping(monitor='val_loss', patience=10, mode='min', restore_best_weights=True)

history = model.fit(train_gen, validation_data=val_gen, epochs=epochs, callbacks=[checkpoint, early_stopping])
model.save_weights('model.weights.h5')

In [None]:
# Iterate over every Training chunk so Datagen writes each case's .npy files.
# Datagen's return value (the chunk's stacked arrays) is not needed here; note
# that each call still holds one full chunk in memory while it runs.
chunk_size = 10  # cases per chunk, matching the chunk loads above
n_chunks = (total_size['Training'] + chunk_size - 1) // chunk_size  # ceil division
for chunk_index in range(n_chunks):
    # Datagen expects a zero-based chunk number, not a case id.
    Datagen(phase='Training', chunk_index=chunk_index, chunk_size=chunk_size,
            input_dir=input_dir, output_dir=output_dir)