# Federated Learning with BraTS 2021 Dataset
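This notebook simulates federated learning for brain tumor segmentation on the BraTS 2021 dataset. The patients are split across three simulated hospitals. Each hospital trains a segmentation model locally on FLAIR and T1ce slices. After every round, a global model is formed by averaging the locally trained weights.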

In [1]:
import os
import numpy as np
import nibabel as nib
import cv2
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

In [2]:
# Define dataset path and configurations
images_path = 'Data/BraTS2021'
IMG_SIZE = 128  # Resize images to 128x128
VOLUME_SLICES = 50  # Number of slices per patient
VOLUME_START_AT = 22  # Start slicing from this index
SEED = 42  # Seed for patient shuffles/splits and model weight initialisation

# Seed Python's `random`, NumPy (used for every shuffle below) and the
# TensorFlow backend (weight initialisers) so Restart & Run All reproduces
# the same splits and starting weights.
keras.utils.set_random_seed(SEED)

# Load the dataset file paths. Keep only patient directories (stray files in
# the dataset folder would otherwise be treated as patients), and sort before
# shuffling because os.listdir order is arbitrary and would defeat the seed.
all_patients = sorted(
    os.path.join(images_path, p)
    for p in os.listdir(images_path)
    if os.path.isdir(os.path.join(images_path, p))
)
assert all_patients, f"No patient directories found under {images_path!r}"
np.random.shuffle(all_patients)

In [3]:
# Partition the shuffled patient list into three contiguous, near-equal shards,
# one per simulated hospital (federated node). Shard k spans
# [k*n//3, (k+1)*n//3), so the last shard absorbs any remainder.
nodes = {
    f"Hospital_{k + 1}": all_patients[k * len(all_patients) // 3:(k + 1) * len(all_patients) // 3]
    for k in range(3)
}

In [4]:
# Define Data Generator
class DataGenerator(keras.utils.Sequence):
    """Keras data generator yielding 2-D slices from per-patient NIfTI volumes.

    For each patient in a batch, ``VOLUME_SLICES`` consecutive slices starting at
    index ``VOLUME_START_AT`` along the third volume axis are extracted.
    Each input sample stacks the FLAIR slice (channel 0) and the T1ce slice
    (channel 1), resized to ``dim``.
    The target is the matching segmentation slice, one-hot encoded into 4 classes
    and resized to ``dim`` with nearest-neighbour interpolation.

    Args:
        list_IDs: patient paths or names. Only the basename of each entry is used;
            it is joined onto ``images_path`` to locate ``<name>/<name>_*.nii``.
        dim: (height, width) of the resized slices.
        batch_size: number of *patients* per batch. Each batch therefore holds
            ``batch_size * VOLUME_SLICES`` slices.
        n_channels: size of X's channel axis. Only channels 0 and 1 are filled,
            so values other than 2 are not meaningful.
        shuffle: reshuffle the patient order at the end of every epoch.
    """

    def __init__(self, list_IDs, dim=(IMG_SIZE, IMG_SIZE), batch_size=1, n_channels=2, shuffle=True):
        # Keras 3's PyDataset (which Sequence aliases) expects its __init__ to run.
        # For the older tf.keras Sequence this resolves to object.__init__ and is harmless.
        super().__init__()
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Number of full batches per epoch; trailing patients that do not fill a batch are dropped."""
        return len(self.list_IDs) // self.batch_size

    def __getitem__(self, index):
        """Return the (X, Y) pair for batch ``index``."""
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        Batch_ids = [self.list_IDs[k] for k in indexes]
        X, y = self.__data_generation(Batch_ids)
        return X, y

    def on_epoch_end(self):
        """Rebuild the patient index order, shuffling it if requested."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, Batch_ids):
        """Load and preprocess all slices for the given patients.

        Returns:
            X: float array of shape (batch_size * VOLUME_SLICES, *dim, n_channels),
               divided by the batch maximum (left unscaled if that maximum is not positive).
            Y: one-hot tensor of shape (batch_size * VOLUME_SLICES, *dim, 4).
        """
        height, width = self.dim
        X = np.zeros((self.batch_size * VOLUME_SLICES, height, width, self.n_channels))
        # Segmentation slices are stored at native resolution before resizing.
        # Assumes a 240x240 in-plane volume size — TODO confirm for non-BraTS data.
        y = np.zeros((self.batch_size * VOLUME_SLICES, 240, 240))

        for c, i in enumerate(Batch_ids):
            case_path = os.path.join(images_path, os.path.basename(i), os.path.basename(i))
            flair = nib.load(f'{case_path}_flair.nii').get_fdata()
            ce = nib.load(f'{case_path}_t1ce.nii').get_fdata()
            seg = nib.load(f'{case_path}_seg.nii').get_fdata()

            for j in range(VOLUME_SLICES):
                slice_idx = j + VOLUME_START_AT
                row = j + VOLUME_SLICES * c
                # cv2.resize takes dsize as (width, height).
                X[row, :, :, 0] = cv2.resize(flair[:, :, slice_idx], (width, height))
                X[row, :, :, 1] = cv2.resize(ce[:, :, slice_idx], (width, height))
                y[row] = seg[:, :, slice_idx]

        # Remap label 4 to 3 so every label fits tf.one_hot's depth of 4 (0..3).
        y[y == 4] = 3
        # tf.one_hot only accepts integer indices; get_fdata() produced float64.
        mask = tf.one_hot(y.astype(np.int32), 4)
        Y = tf.image.resize(mask, (height, width), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        # Scale inputs by the batch maximum. Guard against an all-zero batch,
        # which would otherwise turn the whole batch into NaNs.
        x_max = np.max(X)
        if x_max > 0:
            X = X / x_max
        return X, Y

In [5]:
# Distribute each hospital's patients into train / validation / test splits
# (about 70% / 15% / 15%). Splitting is done per patient, so all slices of one
# patient land in the same split.
hospitals_data = {}
for hospital, files in nodes.items():
    # Shuffle a copy: np.random.shuffle works in place, and shuffling `files`
    # directly would silently reorder the lists stored in `nodes`.
    files = list(files)
    np.random.shuffle(files)
    train_split = int(0.7 * len(files))
    val_split = int(0.85 * len(files))
    hospitals_data[hospital] = {
        "train": DataGenerator(files[:train_split]),
        "val": DataGenerator(files[train_split:val_split]),
        "test": DataGenerator(files[val_split:]),
    }

In [6]:
# Define UNet model for segmentation
def create_unet_model(input_shape=(IMG_SIZE, IMG_SIZE, 2), n_classes=4):
    """Build a small two-level U-Net for per-pixel multi-class segmentation.

    The encoder downsamples twice (H -> H/2 -> H/4). The decoder upsamples back
    to full resolution with stride-2 transposed convolutions and concatenates
    the encoder features at each matching resolution (skip connections).
    The output therefore has the same height and width as the input, which
    matches the resized one-hot masks produced by the data generator.

    Args:
        input_shape: (height, width, channels) of the input slices. Height and
            width must be divisible by 4 because of the two 2x2 poolings.
        n_classes: number of output classes (softmax channels).

    Returns:
        An uncompiled ``keras.Model`` mapping (H, W, C) -> (H, W, n_classes).
    """
    inputs = layers.Input(shape=input_shape)

    # Encoder.
    enc1 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)   # H
    x = layers.MaxPooling2D((2, 2))(enc1)                                          # H/2
    enc2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)       # H/2
    x = layers.MaxPooling2D((2, 2))(enc2)                                          # H/4

    # Bottleneck.
    x = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x)          # H/4

    # Decoder. Stride-2 transposed convs double the spatial size. Without them,
    # the output would stay at H/4 and mismatch the H x W targets.
    x = layers.Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(x)    # H/2
    x = layers.Concatenate()([x, enc2])
    x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = layers.Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(x)     # H
    x = layers.Concatenate()([x, enc1])
    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)

    outputs = layers.Conv2D(n_classes, (1, 1), activation='softmax')(x)
    return keras.Model(inputs, outputs)

In [7]:
# One freshly built model per hospital, keyed by hospital name (same key order
# as hospitals_data). The federated loop loads the global weights into each
# model before local training.
local_models = {hospital_name: create_unet_model() for hospital_name in hospitals_data}

In [8]:
# Federated training setup: the number of communication rounds, and the
# server-side model. Each round, the hospitals start from this model's
# weights, and it then receives the averaged weights.
num_rounds = 1  # communication rounds (local training + aggregation)
global_model = create_unet_model()

def federated_averaging(weight_list, sample_counts=None):
    """Aggregate per-client model weights into a single set of global weights.

    Args:
        weight_list: sequence of per-client weight lists, each shaped like the
            output of ``model.get_weights()`` (same number of layers and the same
            per-layer shapes across clients).
        sample_counts: optional per-client number of training samples. When
            given, each client's weights are weighted by its share of the total,
            as in FedAvg. When None (default), all clients count equally.

    Returns:
        List of numpy arrays, one per layer, holding the (weighted) mean.
        Returns an empty list when ``weight_list`` is empty.

    Raises:
        ValueError: if ``sample_counts`` does not have one entry per client, or
            its total is not positive.
    """
    if not weight_list:
        return []
    if sample_counts is None:
        # Unweighted mean, layer by layer.
        return [np.mean(layer_stack, axis=0) for layer_stack in zip(*weight_list)]

    counts = np.asarray(sample_counts, dtype=np.float64)
    if counts.shape != (len(weight_list),):
        raise ValueError(
            f"sample_counts must have one entry per client "
            f"({len(weight_list)}), got shape {counts.shape}"
        )
    if counts.sum() <= 0:
        raise ValueError("sample_counts must sum to a positive number")

    averaged = []
    for layer_stack in zip(*weight_list):
        # Keep each layer's dtype (e.g. float32) so set_weights gets matching arrays.
        layer_dtype = np.asarray(layer_stack[0]).dtype
        averaged.append(np.average(np.stack(layer_stack), axis=0, weights=counts).astype(layer_dtype))
    return averaged

for round_num in range(num_rounds):
    # Client step: each hospital starts from the current global weights,
    # trains one epoch on its own training split, and reports its weights.
    local_weights = []
    for hospital in local_models:
        model = local_models[hospital]
        splits = hospitals_data[hospital]
        model.set_weights(global_model.get_weights())
        model.compile(
            optimizer='adam',
            loss=tf.keras.losses.CategoricalCrossentropy(),
            metrics=['accuracy'],
        )
        model.fit(splits['train'], epochs=1, validation_data=splits['val'])
        local_weights.append(model.get_weights())

    # Server step: aggregate the collected client weights into the global model.
    new_global_weights = federated_averaging(local_weights)
    global_model.set_weights(new_global_weights)

: 

In [None]:
# Evaluation on test sets.
# Each hospital's held-out test split is scored with two models, so local and
# federated performance can be compared directly:
#   - the hospital's own locally trained model (its state after the last round);
#   - the aggregated global model.
# global_model is only used via get/set_weights during training, so it must be
# compiled before evaluate(). The loss and metric mirror the local training setup.
global_model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=['accuracy'],
)

metrics = {}
for hospital, model in local_models.items():
    test_data = hospitals_data[hospital]['test']
    loss, acc = model.evaluate(test_data)
    global_loss, global_acc = global_model.evaluate(test_data)
    metrics[hospital] = {
        'Loss': loss,
        'Accuracy': acc,
        'Global Loss': global_loss,
        'Global Accuracy': global_acc,
    }

In [None]:
# Plot per-hospital test accuracy. If the evaluation also recorded the global
# model's scores ('Global Accuracy'), draw local and global bars side by side.
# Otherwise plot the local models alone, with a title that says so.
hospitals = list(metrics.keys())
accuracy = [metrics[h]['Accuracy'] for h in hospitals]
has_global = all('Global Accuracy' in metrics[h] for h in hospitals)

fig, ax = plt.subplots()
bar_positions = np.arange(len(hospitals))
if has_global:
    bar_width = 0.4
    global_accuracy = [metrics[h]['Global Accuracy'] for h in hospitals]
    ax.bar(bar_positions - bar_width / 2, accuracy, bar_width, label='Local model')
    ax.bar(bar_positions + bar_width / 2, global_accuracy, bar_width, label='Global model')
    ax.set_title('Test Accuracy: Local vs. Global Models')
    ax.legend()
else:
    ax.bar(bar_positions, accuracy)
    ax.set_title('Test Accuracy: Local Models')
ax.set_xticks(bar_positions)
ax.set_xticklabels(hospitals)
ax.set_xlabel('Hospital')
ax.set_ylabel('Accuracy')
plt.show()

In [None]:
# Print evaluation results. Global-model scores are appended when the
# evaluation recorded them; otherwise the output is the local scores only.
for hospital, metric in metrics.items():
    line = f"{hospital} - Loss: {metric['Loss']:.4f}, Accuracy: {metric['Accuracy']:.4f}"
    if 'Global Accuracy' in metric and 'Global Loss' in metric:
        line += f" | Global Loss: {metric['Global Loss']:.4f}, Global Accuracy: {metric['Global Accuracy']:.4f}"
    print(line)