# Federated Learning with BraTS 2021 Dataset
This notebook implements a federated learning framework for brain tumor segmentation using the BraTS 2021 dataset.

In [20]:
import os
import numpy as np
import nibabel as nib
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, confusion_matrix
import cv2

In [21]:
images_path = 'Data/BraTS2021'
IMG_SIZE = 128  # Resize images to 128x128
VOLUME_SLICES = 50  # Number of slices per patient
VOLUME_START_AT = 22  # Start slicing from this index
BATCH_SIZE = 8  # Batch size for training
NUM_EPOCHS = 10  # Number of epochs for training
NUM_ROUNDS = 5  # Number of federated learning rounds

In [22]:
# Load the dataset file paths
all_patients = [os.path.join(images_path, p) for p in os.listdir(images_path)]
np.random.shuffle(all_patients)

# Split dataset among hospitals (nodes)
nodes = {
    "Hospital_1": all_patients[:len(all_patients)//3],
    "Hospital_2": all_patients[len(all_patients)//3:2*len(all_patients)//3],
    "Hospital_3": all_patients[2*len(all_patients)//3:]
}

In [23]:
def create_dataset(file_paths, batch_size=BATCH_SIZE):
    def parse_function(file_path):
        # Load and preprocess data
        flair = nib.load(file_path + '_flair.nii').get_fdata()
        ce = nib.load(file_path + '_t1ce.nii').get_fdata()
        seg = nib.load(file_path + '_seg.nii').get_fdata()

        # Resize and normalize
        flair = cv2.resize(flair[:, :, VOLUME_START_AT:VOLUME_START_AT + VOLUME_SLICES], (IMG_SIZE, IMG_SIZE))
        ce = cv2.resize(ce[:, :, VOLUME_START_AT:VOLUME_START_AT + VOLUME_SLICES], (IMG_SIZE, IMG_SIZE))
        seg = cv2.resize(seg[:, :, VOLUME_START_AT:VOLUME_START_AT + VOLUME_SLICES], (IMG_SIZE, IMG_SIZE))

        # Adjust class values
        seg[seg == 4] = 3
        seg = tf.one_hot(tf.cast(seg, tf.int32), depth=4)

        return (flair, ce), seg

    dataset = tf.data.Dataset.from_tensor_slices(file_paths)
    dataset = dataset.map(parse_function, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset

In [24]:
def create_unet_model(input_shape=(IMG_SIZE, IMG_SIZE, 2)):
    inputs = layers.Input(shape=input_shape)
    
    # Downsample
    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x)
    
    # Upsample
    x = layers.Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same')(x)
    x = layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same')(x)
    outputs = layers.Conv2D(4, (1, 1), activation='softmax')(x)
    
    return keras.Model(inputs, outputs)

In [25]:
# Initialize models for each hospital
local_models = {hospital: create_unet_model() for hospital in nodes.keys()}
global_model = create_unet_model()

# Federated Averaging
def federated_averaging(weight_list):
    avg_weights = []
    for weights in zip(*weight_list):
        avg_weights.append(np.mean(weights, axis=0))
    return avg_weights

In [10]:
## Cell 5: Define UNet model with proper upsampling
def create_unet_model(input_shape=(IMG_SIZE, IMG_SIZE, 2)):
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same')(x)
    x = layers.Conv2DTranspose(4, (4, 4), strides=(4, 4), activation='softmax', padding='same')(x)  # Proper upsampling
    return keras.Model(inputs, x)

In [11]:
# Initialize models for each hospital
local_models = {}
for hospital in hospitals_data.keys():
    local_models[hospital] = create_unet_model()

In [12]:
num_rounds = 1
global_model = create_unet_model()

def federated_averaging(weight_list):
    avg_weights = []
    for weights in zip(*weight_list):
        avg_weights.append(np.mean(weights, axis=0))
    return avg_weights


In [26]:
for round_num in range(NUM_ROUNDS):
    print(f"Round {round_num + 1}/{NUM_ROUNDS}")
    local_weights = []
    
    for hospital, files in nodes.items():
        print(f"Training {hospital}...")
        train_dataset = create_dataset(files[:150])  # Use first 150 patients for training
        val_dataset = create_dataset(files[150:180])  # Use next 30 for validation

        # Create a fresh model to avoid memory leak
        local_model = create_unet_model()
        local_model.set_weights(global_model.get_weights())
        
        # Compile and train
        local_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
        history = local_model.fit(train_dataset, epochs=NUM_EPOCHS, validation_data=val_dataset)
        
        # Save trained weights
        local_weights.append(local_model.get_weights())

    # Update global model
    new_global_weights = federated_averaging(local_weights)
    global_model.set_weights(new_global_weights)

Round 1/5
Training Hospital_1...


TypeError: in user code:

    File "C:\Users\binwa\AppData\Local\Temp\ipykernel_49572\2802195307.py", line 4, in parse_function  *
        flair = nib.load(file_path + '_flair.nii').get_fdata()
    File "c:\Users\binwa\anaconda3\envs\mlenv\lib\site-packages\nibabel\loadsave.py", line 96, in load  *
        filename = _stringify_path(filename)
    File "c:\Users\binwa\anaconda3\envs\mlenv\lib\site-packages\nibabel\filename_parser.py", line 41, in _stringify_path  *
        return pathlib.Path(filepath_or_buffer).expanduser().as_posix()
    File "c:\Users\binwa\anaconda3\envs\mlenv\lib\pathlib.py", line 1082, in __new__  **
        self = cls._from_parts(args, init=False)
    File "c:\Users\binwa\anaconda3\envs\mlenv\lib\pathlib.py", line 707, in _from_parts
        drv, root, parts = self._parse_args(args)
    File "c:\Users\binwa\anaconda3\envs\mlenv\lib\pathlib.py", line 691, in _parse_args
        a = os.fspath(a)

    TypeError: expected str, bytes or os.PathLike object, not SymbolicTensor
