In [None]:
# Build and compile the 2D segmentation model on 128x128 inputs with 4 channels.
# NOTE(review): build_unet, plot_model, keras and the custom metric functions
# (dice_coef, precision, sensitivity, specificity, dice_coef_*) are not defined in
# the visible part of the notebook; this cell only runs if an earlier cell provides them.
input_layer = Input((128, 128, 4))
# presumably ('he_normal', 0.2) are the kernel initializer and dropout rate - confirm against build_unet
model = build_unet(input_layer, 'he_normal', 0.2)
model.compile(
    loss="categorical_crossentropy", 
    optimizer=keras.optimizers.Adam(learning_rate=0.001), 
    metrics=[
        'accuracy',
        tf.keras.metrics.MeanIoU(num_classes=4),
        dice_coef,
        precision,
        sensitivity,
        specificity,
        dice_coef_necrotic,
        dice_coef_edema,
        dice_coef_enhancing
    ]
)
# Render the layer graph top-to-bottom ('TB') with tensor shapes and layer names shown.
plot_model(model, 
           show_shapes = True,
           show_dtype=False,
           show_layer_names = True, 
           rankdir = 'TB', 
           expand_nested = False, 
           dpi = 70)


In [207]:
# Dataset
import os
import numpy as np
import nibabel as nib
import tensorflow as tf
import cv2
from sklearn.model_selection import train_test_split


def load_nifti_file(filepath):
    """Read the NIfTI image stored at `filepath` and return its voxel data as a numpy array."""
    return nib.load(filepath).get_fdata()

def normalize(volume, min_val=-1000, max_val=400):
    """Clip a volume to ``[min_val, max_val]`` and standardize it (zero mean, unit std).

    Args:
        volume: array-like of intensities.
        min_val: lower clipping bound. Defaults to -1000, the previously hard-coded value.
        max_val: upper clipping bound. Defaults to 400, the previously hard-coded value.

    Returns:
        np.ndarray with the same shape as ``volume``. If the clipped volume is
        constant (standard deviation 0) an all-zero array is returned instead of
        the NaN values a division by zero would produce.

    NOTE(review): the default bounds look like a CT intensity window; confirm they
    are appropriate for the MRI modalities used elsewhere in this notebook.
    """
    volume = np.clip(volume, min_val, max_val)
    mean = np.mean(volume)
    std = np.std(volume)
    if std == 0:
        # Constant volume: every voxel equals the mean, so the standardized value is 0.
        return np.zeros_like(volume, dtype=float)
    return (volume - mean) / std



# Root folder of the raw dataset and the folder where converted .npy volumes are written.
input_dir = './Dataset'
output_dir = './Dataset/Processed'
# Number of subjects per phase; Datagen uses it to cap the end of the last chunk.
total_size={'Training':369,'Validation':125}

# File-name suffixes loaded for each phase (only 'Training' lists the 'seg' mask).
# NOTE(review): this name shadows the built-in input() for the rest of the kernel session.
input={'Training':['flair', 't1', 't1ce', 't2','seg'],
 'Validation':['flair', 't1', 't1ce', 't2']}

# Human-readable names for the segmentation label values.
SEGMENT_CLASSES = {
    0 : 'NOT tumor',
    1 : 'NECROTIC/CORE', # or NON-ENHANCING tumor CORE
    2 : 'EDEMA',
    3 : 'ENHANCING' # original 4 -> converted into 3 later
}
def Datagen(phase,chunk_index,chunk_size,input_dir, output_dir):
    """Convert one chunk of subjects from NIfTI to .npy and build slice-wise arrays.

    Subjects ``chunk_index * chunk_size + 1`` through
    ``min((chunk_index + 1) * chunk_size, total_size[phase])`` (inclusive) are processed.
    Each volume listed in ``input[phase]`` is loaded from
    ``<input_dir>/<phase>/<modality or 'mask'>/BraTS20_<phase>_<id>_<modality>.nii``
    and saved as ``.npy`` under ``<output_dir>/<phase>/chunk_<chunk_index>/<id>/``.
    The saved volumes are then read back and the first 155 slices along the last
    axis of each one are resized to 128x128 and stacked subject after subject.

    Args:
        phase: 'Training' or 'Validation'.
        chunk_index: zero-based index of the chunk to process.
        chunk_size: number of subjects per chunk.
        input_dir: root directory of the raw dataset.
        output_dir: root directory for the converted .npy files.

    Returns:
        For 'Training': ``(X, Y)`` where ``X`` is a numpy array of shape
        ``(n_subjects * 155, 128, 128, n_image_modalities)`` scaled by its global
        maximum, and ``Y`` is a TensorFlow tensor holding the one-hot mask
        (4 classes, label 4 remapped to 3) resized to 128x128.
        For 'Validation': ``X`` only.

    Raises:
        ValueError: if ``phase`` is unknown or the chunk contains no subjects.
    """
    if phase not in ('Training', 'Validation'):
        raise ValueError(f"phase must be 'Training' or 'Validation', got {phase!r}")

    n_slices = 155
    start_id = chunk_index * chunk_size + 1
    end_id = min((chunk_index + 1) * chunk_size, total_size[phase])
    # end_id is the last subject of the chunk, so the range has to include it.
    subject_ids = [f"{i:03d}" for i in range(start_id, end_id + 1)]
    if not subject_ids:
        raise ValueError(f"chunk {chunk_index} of phase {phase!r} contains no subjects")

    # Step 1: convert every NIfTI volume of the chunk to a .npy file.
    for subject_id in subject_ids:
        npy_folder_path = os.path.join(output_dir, phase, f'chunk_{chunk_index}', subject_id)
        os.makedirs(npy_folder_path, exist_ok=True)
        for modality in input[phase]:
            # Masks are read from the 'mask' sub-folder, every other modality from its own.
            source_subdir = 'mask' if modality == 'seg' else modality
            volume_name = f'BraTS20_{phase}_{subject_id}_{modality}'
            volume = load_nifti_file(os.path.join(input_dir, phase, source_subdir, volume_name + '.nii'))
            np.save(os.path.join(npy_folder_path, volume_name + '.npy'), volume)

    # Step 2: build the slice-wise dataset.
    is_training = phase == 'Training'
    image_modalities = [m for m in input[phase] if m != 'seg']
    n_rows = len(subject_ids) * n_slices
    # Buffers are sized to the subjects actually present so no row is left uninitialized.
    X = np.empty([n_rows, 128, 128, len(image_modalities)])
    y = np.empty([n_rows, 240, 240], dtype=int) if is_training else None

    for subject_pos, subject_id in enumerate(subject_ids):
        # Each subject owns its own block of n_slices rows; without this offset
        # every subject would overwrite the rows of the previous one.
        offset = subject_pos * n_slices
        folder_path = os.path.join(output_dir, phase, f'chunk_{chunk_index}', subject_id)
        for k, modality in enumerate(image_modalities):
            nmpy_channel = np.load(os.path.join(folder_path, f'BraTS20_{phase}_{subject_id}_{modality}' + '.npy'))
            for h in range(n_slices):
                X[offset + h, :, :, k] = cv2.resize(nmpy_channel[:, :, h], (128, 128))
        if is_training:
            npy_mask = np.load(os.path.join(folder_path, f'BraTS20_{phase}_{subject_id}_seg' + '.npy'))
            for h in range(n_slices):
                y[offset + h] = npy_mask[:, :, h]

    # Scale by the global maximum; skip the division for an all-zero chunk.
    peak = np.max(X)
    if peak != 0:
        X = X / peak

    if not is_training:
        return X

    y[y == 4] = 3  # remap label 4 to 3 so the labels are contiguous 0..3
    mask = tf.one_hot(y, 4)
    # NOTE(review): the default resize interpolation blends neighbouring one-hot
    # labels into soft values; consider method='nearest' if hard labels are required.
    Y = tf.image.resize(mask, (128, 128))
    return X, Y
       


In [208]:
# Build the slice-wise training arrays for the first chunk (chunk_index=0) of 10 subjects.
X_train, y_train = Datagen(phase='Training', chunk_index=0, chunk_size=10, input_dir=input_dir, output_dir=output_dir)

In [213]:
# Convert the label tensor to numpy and hold out 20% of the slices as a test set.
# NOTE(review): this cell overwrites X_train / y_train with their own split, so it is
# not idempotent: a second run fails on y_train.numpy() (y_train is a numpy array by
# then) and would split the already reduced training set again.
Y=y_train.numpy()
X_train, X_test, y_train, y_test = train_test_split(X_train, Y, test_size=0.2, random_state=42)

In [191]:
# Build the slice-wise validation inputs for the first chunk; the 'Validation' phase returns X only.
X_val = Datagen(phase='Validation', chunk_index=0, chunk_size=10, input_dir=input_dir, output_dir=output_dir)

In [None]:
# Model Architecture


In [None]:
# Train 

In [None]:
# Evaluate

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, Concatenate, Activation
from tensorflow.keras.models import Model

def conv_block(inputs, num_filters):
    """Apply two same-padded Conv3D layers (kernel size 3), each followed by a ReLU activation."""
    features = inputs
    for _ in range(2):
        features = Conv3D(num_filters, 3, padding='same')(features)
        features = Activation('relu')(features)
    return features

def encoder_block(inputs, num_filters):
    """Run a conv block, then 2x2x2 max pooling.

    Returns a pair ``(skip, pooled)``: the conv-block output before pooling
    and the pooled tensor.
    """
    skip = conv_block(inputs, num_filters)
    pooled = MaxPooling3D(pool_size=(2, 2, 2))(skip)
    return skip, pooled

def decoder_block(inputs, skip_features, num_filters):
    """Upsample `inputs` by 2x2x2, concatenate with `skip_features`, then apply a conv block."""
    upsampled = UpSampling3D(size=(2, 2, 2))(inputs)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)
def unet_model(input_shape):
    """Assemble the 3D U-Net: four encoder levels, a bridge, and four decoder levels.

    The encoder uses 32, 64, 128 and 256 filters, the bridge 512, and the decoder
    mirrors the encoder in reverse while consuming the matching skip tensors.
    The output layer is a 1-filter, kernel-size-1 Conv3D with sigmoid activation.
    """
    inputs = Input(input_shape)

    # Encoder: remember each level's pre-pooling features for the skip connections.
    skips = []
    x = inputs
    for filters in (32, 64, 128, 256):
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    # Bridge
    x = conv_block(x, 512)

    # Decoder: walk back up, pairing each level with the deepest unused skip tensor.
    for filters, skip in zip((256, 128, 64, 32), reversed(skips)):
        x = decoder_block(x, skip, filters)

    outputs = Conv3D(1, 1, padding='same', activation='sigmoid')(x)

    return Model(inputs, outputs, name='3d_unet')
# Build, compile and train the 3D U-Net on single-channel volumes of size `image_size`.
# NOTE(review): image_size, Adam, ModelCheckpoint, EarlyStopping, train_gen, val_gen and
# epochs are not defined or imported in the visible part of the notebook.
# NOTE(review): this rebinds `model`, replacing any model created by an earlier cell.
model = unet_model(input_shape=(*image_size, 1))
model.compile(optimizer=Adam(learning_rate=1e-4), loss='binary_crossentropy', metrics=['accuracy'])

# Keep only the checkpoint with the lowest validation loss, and stop training after
# 10 epochs without improvement in validation loss.
checkpoint = ModelCheckpoint('unet_brats2020.h5', save_best_only=True, monitor='val_loss', mode='min')
early_stopping = EarlyStopping(monitor='val_loss', patience=10, mode='min')

history = model.fit(train_gen, validation_data=val_gen, epochs=epochs, callbacks=[checkpoint, early_stopping])
# NOTE(review): some Keras versions require the weights file name to end in
# '.weights.h5' - confirm against the installed version.
model.save_weights('model_weights.h5')

In [None]:
# Iterate over each chunk.
# Datagen expects a zero-based chunk index and derives the subject ids from it, so the
# loop runs over chunk indices rather than over starting subject ids. The returned
# arrays are discarded here; each call still writes the chunk's .npy files.
# NOTE(review): total_files and chunk_size are not defined in the visible part of
# the notebook; an earlier cell must provide them.
n_chunks = (total_files + chunk_size - 1) // chunk_size  # ceiling division
for chunk_index in range(n_chunks):
    start = chunk_index * chunk_size + 1             # first subject id of this chunk
    end = min(start + chunk_size - 1, total_files)   # last subject id of this chunk
    Datagen(phase='Training', chunk_index=chunk_index, chunk_size=chunk_size, input_dir=input_dir, output_dir=output_dir)