# Federated Learning with BraTS 2021 Dataset
This notebook implements a federated learning framework for brain tumor segmentation using the BraTS 2021 dataset.

In [1]:
import os
import numpy as np
import nibabel as nib
import cv2
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

In [2]:
# Define dataset path and configurations
images_path = 'Data/BraTS2021'
IMG_SIZE = 128  # Resize images to 128x128
VOLUME_SLICES = 50  # Number of slices per patient
VOLUME_START_AT = 22  # Start slicing from this index
REDUCE_PATIENTS = False  # Option to reduce the number of patients per hospital
PATIENT_LIMIT = 10  # Number of patients per hospital if REDUCE_PATIENTS is True
SEED = 42  # Fixed seed so the shuffle (and therefore the hospital split) is reproducible

# Seed NumPy's global RNG before any shuffling happens in this notebook
np.random.seed(SEED)

# Load the dataset file paths.
# - Only directories are kept: the data generator expects one folder per patient,
#   so stray files (CSVs, .DS_Store, ...) must not be counted as patients.
# - os.listdir() returns entries in arbitrary order, so sort first; otherwise the
#   seeded shuffle below would still differ between machines/filesystems.
all_patients = sorted(
    os.path.join(images_path, p)
    for p in os.listdir(images_path)
    if os.path.isdir(os.path.join(images_path, p))
)
np.random.shuffle(all_patients)

# Optionally reduce dataset size
if REDUCE_PATIENTS:
    all_patients = all_patients[:PATIENT_LIMIT * 3]  # Total patients across hospitals

In [3]:
# Split dataset among hospitals (nodes): three contiguous, near-equal shards.
# The cut points are computed once; any remainder ends up in the last shard.
total_patients = len(all_patients)
first_cut = total_patients // 3
second_cut = 2 * total_patients // 3

nodes = {
    "Hospital_1": all_patients[:first_cut],
    "Hospital_2": all_patients[first_cut:second_cut],
    "Hospital_3": all_patients[second_cut:],
}

# Report how many patients each hospital received
for hospital, patients in nodes.items():
    print(f"{hospital} has {len(patients)} patients")

Hospital_1 has 417 patients
Hospital_2 has 417 patients
Hospital_3 has 417 patients


In [4]:
# Define Data Generator
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` yielding 2-D slices of FLAIR/T1ce volumes with one-hot masks.

    One batch item is one patient. For each patient, ``VOLUME_SLICES`` consecutive
    slices along the last volume axis (starting at ``VOLUME_START_AT``) are resized
    to ``dim``; FLAIR goes to channel 0, T1ce to channel 1, and the segmentation
    slice is one-hot encoded into ``N_CLASSES`` channels.

    Args:
        list_IDs: Patient directory paths. Only the basename is used; files are
            looked up under the global ``images_path``.
        dim: Target ``(height, width)`` of every slice.
        batch_size: Number of patients per batch.
        n_channels: Number of input channels allocated in ``X``.
        shuffle: Reshuffle the patient order at the end of every epoch.
    """

    N_CLASSES = 4  # one-hot depth of the segmentation target

    def __init__(self, list_IDs, dim=(IMG_SIZE, IMG_SIZE), batch_size=1, n_channels=2, shuffle=True):
        # Newer Keras versions expect the base-class constructor to be called.
        super().__init__()
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Number of full batches per epoch (a trailing partial batch is dropped)."""
        return len(self.list_IDs) // self.batch_size

    def __getitem__(self, index):
        """Return ``(X, y)`` for batch number ``index``."""
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        Batch_ids = [self.list_IDs[k] for k in indexes]
        X, y = self.__data_generation(Batch_ids)
        return X, y

    def on_epoch_end(self):
        """Rebuild (and optionally shuffle) the patient index order."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    @staticmethod
    def _load_volume(case_path, modality):
        """Load one modality as an array, accepting ``.nii`` or ``.nii.gz`` on disk."""
        path = f'{case_path}_{modality}.nii'
        # Fall back to the gzip-compressed file only when the plain one is absent;
        # if neither exists, nib.load raises for the original ``.nii`` path.
        if not os.path.exists(path) and os.path.exists(path + '.gz'):
            path += '.gz'
        return nib.load(path).get_fdata()

    def __data_generation(self, Batch_ids):
        """Build the image and mask arrays for the given patients.

        Returns:
            X: ``(batch_size * VOLUME_SLICES, *dim, n_channels)`` array scaled by the
               batch maximum (left as all zeros if the batch maximum is 0).
            y: ``(batch_size * VOLUME_SLICES, *dim, N_CLASSES)`` one-hot masks.
        """
        X = np.zeros((self.batch_size * VOLUME_SLICES, *self.dim, self.n_channels))
        y = np.zeros((self.batch_size * VOLUME_SLICES, *self.dim, self.N_CLASSES))
        # cv2.resize takes (width, height) while ``dim`` is (height, width).
        target_size = (self.dim[1], self.dim[0])

        for c, i in enumerate(Batch_ids):
            case_path = os.path.join(images_path, os.path.basename(i), os.path.basename(i))
            flair = self._load_volume(case_path, 'flair')
            ce = self._load_volume(case_path, 't1ce')
            seg = self._load_volume(case_path, 'seg')

            for j in range(VOLUME_SLICES):
                row = j + VOLUME_SLICES * c
                z = j + VOLUME_START_AT
                X[row, :, :, 0] = cv2.resize(flair[:, :, z], target_size)
                X[row, :, :, 1] = cv2.resize(ce[:, :, z], target_size)
                # Nearest-neighbour keeps label values discrete (no blended classes).
                label = cv2.resize(seg[:, :, z], target_size, interpolation=cv2.INTER_NEAREST)
                # BraTS masks use labels {0, 1, 2, 4}. With depth=4, label 4 is out of
                # range and would become an all-zero one-hot vector, so remap it to 3.
                label[label == 4] = 3
                y[row] = tf.one_hot(label.astype(np.int32), depth=self.N_CLASSES)

        # Guard against an all-zero batch, where dividing by the max would give NaNs.
        peak = np.max(X)
        if peak > 0:
            X = X / peak
        return X, y

In [1]:
# Distribute dataset into train, validation, and test sets (70% / 15% / 15% per hospital)
hospitals_data = {}
for hospital, files in nodes.items():
    # Shuffle a copy: shuffling `files` directly would reorder the lists stored in
    # `nodes`, so re-running this cell would silently change an earlier cell's state.
    files = list(files)
    np.random.shuffle(files)

    train_split = int(0.7 * len(files))
    val_split = int(0.85 * len(files))
    hospitals_data[hospital] = {
        "train": DataGenerator(files[:train_split]),
        "val": DataGenerator(files[train_split:val_split]),
        "test": DataGenerator(files[val_split:])
    }

    # Print the number of batches in each split for this hospital
    splits = hospitals_data[hospital]
    print(f"{hospital}: Train={len(splits['train'])}, Val={len(splits['val'])}, Test={len(splits['test'])}")

NameError: name 'nodes' is not defined

In [None]:
# Visualize Some Slices from Each Hospital:
# first slice of the first training batch, FLAIR channel next to its label map.
for hospital, data in hospitals_data.items():
    X_sample, y_sample = data['train'][0]
    flair_slice = X_sample[0, ..., 0]
    # Collapse the one-hot channels back to a single class index per pixel
    label_map = y_sample[0].argmax(axis=-1)

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    flair_ax, seg_ax = ax
    flair_ax.imshow(flair_slice, cmap='gray')
    flair_ax.set_title(f"{hospital} - FLAIR")
    seg_ax.imshow(label_map, cmap='jet')
    seg_ax.set_title(f"{hospital} - Segmentation")
    plt.show()

In [None]:
# Define UNet Model
def create_unet_model(input_shape=(IMG_SIZE, IMG_SIZE, 2)):
    """Build a small two-level U-Net for 4-class per-pixel segmentation.

    The encoder halves the spatial resolution twice; the decoder doubles it twice
    with stride-2 transposed convolutions, so the softmax output has the same
    height and width as the input. Encoder features are concatenated into the
    decoder at matching resolutions (skip connections).

    Args:
        input_shape: ``(height, width, channels)`` of the input slices. Height and
            width must be divisible by 4 so that the upsampled feature maps line
            up with the encoder feature maps.

    Returns:
        An uncompiled ``keras.Model`` mapping ``(H, W, C)`` inputs to
        ``(H, W, 4)`` class probabilities.
    """
    inputs = layers.Input(shape=input_shape)

    # Encoder: H x W -> H/2 x W/2 -> H/4 x W/4
    enc1 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    x = layers.MaxPooling2D((2, 2))(enc1)
    enc2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2))(enc2)

    # Decoder: strides=(2, 2) undoes each pooling step. With the default stride of 1
    # the output would stay at H/4 x W/4 and could not be compared to full-size masks.
    x = layers.Conv2DTranspose(128, (3, 3), strides=(2, 2), activation='relu', padding='same')(x)
    x = layers.Concatenate()([x, enc2])
    x = layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), activation='relu', padding='same')(x)
    x = layers.Concatenate()([x, enc1])

    # Per-pixel class probabilities over the 4 segmentation classes
    outputs = layers.Conv2D(4, (1, 1), activation='softmax', padding='same')(x)
    return keras.Model(inputs, outputs)