In [5]:

import numpy as np
import nibabel as nib
from tensorflow.keras.utils import to_categorical

def load_brats_data(dataset_year, missing_modalities=False, image_size=128, seed=None,
                    missing_channel=1):
    """Generate a synthetic stand-in for one hospital's BraTS cohort.

    No files are read: images and masks are random arrays laid out like a
    4-modality (T1, T1ce, T2, FLAIR) stack of 2-D slices.

    Args:
        dataset_year: BraTS release year; 2021 yields 500 samples, any other
            value yields 400.
        missing_modalities: if True, simulate a hospital lacking one modality
            by zero-filling channel ``missing_channel``. The channel is zeroed
            rather than dropped so every hospital produces the same input
            shape -- the per-client U-Nets must have identical weight shapes
            for FedAvg to average them.
        image_size: height and width of each synthetic slice.
        seed: optional seed for a private RNG. When None, NumPy's global RNG
            is used, so ``np.random.seed`` controls reproducibility.
        missing_channel: index of the modality zeroed when
            ``missing_modalities`` is True (default 1, i.e. T1ce).

    Returns:
        (x_data, y_data): float images of shape
        (num_samples, image_size, image_size, 4) and integer binary masks of
        shape (num_samples, image_size, image_size, 1).
    """
    num_samples = 500 if dataset_year == 2021 else 400
    modalities = 4  # T1, T1ce, T2, FLAIR
    # Both np.random and RandomState expose rand/randint, so either can be used.
    rng = np.random if seed is None else np.random.RandomState(seed)

    x_data = rng.rand(num_samples, image_size, image_size, modalities)
    if missing_modalities:
        x_data[..., missing_channel] = 0.0  # modality absent at this hospital
    y_data = rng.randint(0, 2, (num_samples, image_size, image_size, 1))  # Binary segmentation mask
    return x_data, y_data

# Simulate dataset loading for 5 hospitals. Seed NumPy's global RNG first so the
# synthetic cohorts (drawn via np.random inside load_brats_data) are reproducible
# across kernel restarts.
SEED = 42
np.random.seed(SEED)
datasets = {
    1: load_brats_data(2021),
    2: load_brats_data(2020),
    3: load_brats_data(2022),
    4: load_brats_data(2019),
    5: load_brats_data(2021, missing_modalities=True),
}
    

TypeError: Descriptors cannot be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates

In [None]:

from tensorflow.keras import layers, models

def create_unet_model(input_shape=(128, 128, 4)):
    """Build and compile a minimal one-level U-Net for binary segmentation.

    Args:
        input_shape: (height, width, channels) of each input image.

    Returns:
        A Keras Model compiled with Adam and binary cross-entropy that maps an
        image to a single-channel sigmoid mask of the same spatial size.
    """
    def double_conv(tensor, filters):
        # Two stacked 3x3 same-padded ReLU convolutions form one U-Net stage.
        for _ in range(2):
            tensor = layers.Conv2D(filters, 3, activation="relu", padding="same")(tensor)
        return tensor

    image = layers.Input(input_shape)

    # Encoder stage; its output is also the skip connection.
    skip = double_conv(image, 64)
    downsampled = layers.MaxPooling2D(pool_size=(2, 2))(skip)

    # Bottleneck, then back up to full resolution.
    bottleneck = double_conv(downsampled, 128)
    upsampled = layers.UpSampling2D(size=(2, 2))(bottleneck)

    merged = layers.Concatenate()([skip, upsampled])
    mask = layers.Conv2D(1, 1, activation="sigmoid")(merged)

    unet = models.Model(image, mask)
    unet.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    return unet
    

In [None]:

import flwr as fl

class HospitalClient(fl.client.NumPyClient):
    """Flower NumPy client wrapping one hospital's Keras model and local data.

    Args:
        model: compiled Keras model whose weights are exchanged with the server.
        train_data: (x_train, y_train) arrays used by ``fit``.
        test_data: (x_test, y_test) arrays used by ``evaluate``.
    """

    def __init__(self, model, train_data, test_data):
        self.model = model
        self.train_data = train_data
        self.test_data = test_data

    def get_parameters(self, config=None):
        """Return the current model weights as a list of NumPy arrays.

        ``config`` is accepted (and ignored) because Flower >= 1.0 calls
        ``get_parameters(config)``; the default keeps zero-arg calls working.
        """
        return self.model.get_weights()

    def fit(self, parameters, config):
        """Load the global weights, train on local data, and return the update.

        Reads an optional ``local_epochs`` entry (default 1) from ``config``.

        Returns:
            (updated_weights, num_train_examples, metrics_dict).
        """
        self.model.set_weights(parameters)
        x_train, y_train = self.train_data
        epochs = int((config or {}).get("local_epochs", 1))
        self.model.fit(x_train, y_train, epochs=epochs, verbose=0)
        return self.model.get_weights(), len(x_train), {}

    def evaluate(self, parameters, config):
        """Evaluate the global weights on the local test split.

        Returns:
            (loss, num_test_examples, {"accuracy": accuracy}). An empty test
            split reports 0 examples instead of handing an empty array to Keras.
        """
        self.model.set_weights(parameters)
        x_test, y_test = self.test_data
        if len(x_test) == 0:
            # 0 examples gives this client zero weight in FedAvg's loss average.
            return 0.0, 0, {}
        loss, accuracy = self.model.evaluate(x_test, y_test, verbose=0)
        # Flower validates that the loss is a float; cast defensively.
        return float(loss), len(x_test), {"accuracy": float(accuracy)}
    

In [None]:

from multiprocessing import Process

def start_server(num_rounds=3, num_clients=5, server_address="0.0.0.0:8080"):
    """Run a Flower FedAvg server for ``num_rounds`` rounds; blocks until done.

    Args:
        num_rounds: number of federated training rounds.
        num_clients: clients that must be connected before a round starts;
            all of them are sampled for every fit and evaluate round.
        server_address: host:port to listen on. It must be reachable at the
            address the clients dial (127.0.0.1:8080 in ``start_client``).
    """
    # Flower's default FedAvg begins once only 2 clients are connected, so early
    # rounds could run without every hospital; require all of them instead.
    strategy = fl.server.strategy.FedAvg(
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_fit_clients=num_clients,
        min_evaluate_clients=num_clients,
        min_available_clients=num_clients,
    )
    fl.server.start_server(
        server_address=server_address,
        config=fl.server.ServerConfig(num_rounds=num_rounds),
        strategy=strategy,
    )

def start_client(hospital_id, train_fraction=0.8, server_address="127.0.0.1:8080"):
    """Run one hospital as a Flower client; blocks until the server finishes.

    Args:
        hospital_id: key into the module-level ``datasets`` mapping.
        train_fraction: fraction of the hospital's samples used for local
            training; the remainder is its local test split.
        server_address: host:port of the Flower server.
    """
    x_all, y_all = datasets[hospital_id]
    # Split proportionally so cohorts of any size (400 or 500 samples here) keep a
    # non-empty test split; a fixed cut at 400 left the 400-sample cohorts with none.
    split = int(len(x_all) * train_fraction)
    train_data = (x_all[:split], y_all[:split])
    test_data = (x_all[split:], y_all[split:])

    model = create_unet_model(input_shape=x_all.shape[1:])
    client = HospitalClient(model, train_data, test_data)
    fl.client.start_numpy_client(server_address=server_address, client=client)

import time

SERVER_STARTUP_DELAY_S = 3      # let the gRPC server bind before clients dial in
SERVER_SHUTDOWN_TIMEOUT_S = 60  # grace period for the server to exit after the last round

# NOTE(review): this relies on the "fork" start method (Linux default), where children
# inherit `datasets` and the notebook-defined target functions. Under "spawn"
# (macOS/Windows default) targets defined in a notebook's __main__ generally cannot be
# loaded by the children -- TODO confirm on the target platform, or move them to a .py module.
server_process = Process(target=start_server)
server_process.start()
time.sleep(SERVER_STARTUP_DELAY_S)

client_processes = []
for hospital_id in datasets:
    p = Process(target=start_client, args=(hospital_id,))
    client_processes.append(p)
    p.start()

for p in client_processes:
    p.join()

# start_server returns after its configured rounds; only force-kill it if it hangs.
server_process.join(timeout=SERVER_SHUTDOWN_TIMEOUT_S)
if server_process.is_alive():
    server_process.terminate()
    server_process.join()
    

In [None]:

def evaluate_global_model(global_model, datasets, train_fraction=0.8):
    """Evaluate a model on every hospital's held-out split.

    Args:
        global_model: compiled model exposing ``evaluate(x, y, verbose=0)``
            that returns (loss, accuracy), and optionally ``input_shape``.
        datasets: mapping hospital_id -> (x, y) arrays. The first
            ``int(len(x) * train_fraction)`` samples are treated as training
            data and the remainder as the test split.
        train_fraction: fraction of each hospital's samples reserved for training.

    Returns:
        dict hospital_id -> {"loss": ..., "accuracy": ...}. A hospital whose
        test split is empty, or whose channel count differs from the model's
        input, gets {"loss": None, "accuracy": None, "skipped": reason} instead
        of aborting the whole evaluation.
    """
    input_shape = getattr(global_model, "input_shape", None)
    expected_channels = input_shape[-1] if input_shape else None

    results = {}
    for hospital_id, (x_all, y_all) in datasets.items():
        split = int(len(x_all) * train_fraction)
        x_test, y_test = x_all[split:], y_all[split:]

        if len(x_test) == 0:
            results[hospital_id] = {"loss": None, "accuracy": None,
                                    "skipped": "empty test split"}
            continue
        if expected_channels is not None and x_test.shape[-1] != expected_channels:
            results[hospital_id] = {
                "loss": None,
                "accuracy": None,
                "skipped": f"data has {x_test.shape[-1]} channels, "
                           f"model expects {expected_channels}",
            }
            continue

        loss, accuracy = global_model.evaluate(x_test, y_test, verbose=0)
        results[hospital_id] = {"loss": loss, "accuracy": accuracy}
    return results

# NOTE(review): this model is freshly initialised, not the outcome of the federated run
# above -- the server runs in a separate process and its final weights are never handed
# back to the notebook. TODO: persist the aggregated parameters on the server side
# (e.g. a strategy subclass that saves them) and load them here via set_weights().
global_model = create_unet_model()
global_results = evaluate_global_model(global_model, datasets)
global_results
    

In [None]:

import matplotlib.pyplot as plt

def visualize_segmentation(model, x_test, y_test, index=0, channel=0):
    """Plot one input slice, its ground-truth mask, and the model's prediction.

    Args:
        model: model whose ``predict`` output is indexed as [0, :, :, 0].
        x_test: image array indexed as x_test[index, :, :, channel].
        y_test: mask array indexed as y_test[index, :, :, 0].
        index: which sample to display.
        channel: input channel (modality) shown in the first panel.
    """
    # Slice index:index+1 to keep the batch dimension that predict() expects.
    prediction = model.predict(x_test[index:index + 1], verbose=0)[0, :, :, 0]

    # Pin masks and sigmoid probabilities to [0, 1]; imshow would otherwise
    # auto-scale and stretch a near-constant prediction to full contrast.
    panels = [
        ("Input Image", x_test[index, :, :, channel], {}),
        ("Ground Truth", y_test[index, :, :, 0], {"vmin": 0, "vmax": 1}),
        ("Predicted Segmentation", prediction, {"vmin": 0, "vmax": 1}),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (title, image, scale) in zip(axes, panels):
        ax.imshow(image, cmap="gray", **scale)
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    plt.show()

# Show hospital 1's first held-out sample (the tail after the 80% training split),
# rather than a sample the clients trained on.
TRAIN_FRACTION = 0.8
x_all, y_all = datasets[1]
split = int(len(x_all) * TRAIN_FRACTION)
x_test, y_test = x_all[split:], y_all[split:]
visualize_segmentation(global_model, x_test, y_test)
    