In [None]:
%matplotlib inline
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import tarfile
import numpy as np
from torch.utils.data import Subset
import copy
set_determinism(seed=0)
np.random.seed(0)

#print_config()


## Read and split data

In [None]:
def get_spleen_dataset(data_dir):
    train_images = sorted(glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
    print(train_images)
    train_labels = sorted(glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
    data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]
    train_files, val_files = data_dicts[:-9], data_dicts[-9:]
    
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        ]
    )
    
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
    
    return train_ds, val_ds

In [None]:
def split_dataset_iid(train_dataset, val_dataset, num_clients):
    #make the dataset evenly divisible by removing the last elements if needed
    train_dataset = train_dataset[:len(train_dataset) - len(train_dataset) % num_clients]
    val_dataset = val_dataset[:len(val_dataset) - len(val_dataset) % num_clients]
    
    #split the dataset into num_clients subsets randomly 
    train_indices = np.random.permutation(len(train_dataset))
    val_indices = np.random.permutation(len(val_dataset))
    train_client_splits = np.array_split(train_indices, num_clients)
    val_client_splits = np.array_split(val_indices, num_clients)
    
    #create a list of subsets
    train_client_datasets = [Subset(train_dataset, split) for split in train_client_splits]
    val_client_datasets = [Subset(val_dataset, split) for split in val_client_splits]
    
    return train_client_datasets, val_client_datasets

In [None]:
def split_dataset_non_iid(train_dataset, val_dataset, num_clients):
    if len(train_dataset) % num_clients != 0 or len(val_dataset) % num_clients != 0:
         #split the dataset into num_clients subsets randomly 
        train_indices = np.random.permutation(len(train_dataset))
        val_indices = np.random.permutation(len(val_dataset))
        train_client_splits = np.array_split(train_indices, num_clients)
        val_client_splits = np.array_split(val_indices, num_clients)
        
        #create a list of subsets
        train_client_datasets = [Subset(train_dataset, split) for split in train_client_splits]
        val_client_datasets = [Subset(val_dataset, split) for split in val_client_splits]
        
        return train_client_datasets, val_client_datasets
    
    else: 
        while len(train_dataset) % num_clients == 0:
            train_dataset = train_dataset[:len(train_dataset) - 1]
        return split_dataset_non_iid(train_dataset, val_dataset, num_clients)

In [None]:
#set the data path
root_dir = os.path.dirname(os.path.dirname(os.path.abspath("federated_learning")))
data_dir = os.path.join(root_dir, "data/raw/Task09_Spleen/")

In [81]:
#load the entire dataset
train_ds, val_ds = get_spleen_dataset(data_dir)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)

#split the dataset into num_clients IID subsets
num_clients = 4
train_client_datasets, val_client_datasets = split_dataset_iid(train_ds, val_ds, num_clients)

#check the number of samples in each client
for i, client_data in enumerate(val_client_datasets):
    print(f"Client {i+1} has {len(client_data)} samples.")      

['/Users/danyu/Desktop/SJTU/Distributed_ML_Systems/federated_learning/data/raw/Task09_Spleen/imagesTr/spleen_10.nii.gz', '/Users/danyu/Desktop/SJTU/Distributed_ML_Systems/federated_learning/data/raw/Task09_Spleen/imagesTr/spleen_12.nii.gz', '/Users/danyu/Desktop/SJTU/Distributed_ML_Systems/federated_learning/data/raw/Task09_Spleen/imagesTr/spleen_13.nii.gz', '/Users/danyu/Desktop/SJTU/Distributed_ML_Systems/federated_learning/data/raw/Task09_Spleen/imagesTr/spleen_14.nii.gz', '/Users/danyu/Desktop/SJTU/Distributed_ML_Systems/federated_learning/data/raw/Task09_Spleen/imagesTr/spleen_16.nii.gz', '/Users/danyu/Desktop/SJTU/Distributed_ML_Systems/federated_learning/data/raw/Task09_Spleen/imagesTr/spleen_17.nii.gz', '/Users/danyu/Desktop/SJTU/Distributed_ML_Systems/federated_learning/data/raw/Task09_Spleen/imagesTr/spleen_18.nii.gz', '/Users/danyu/Desktop/SJTU/Distributed_ML_Systems/federated_learning/data/raw/Task09_Spleen/imagesTr/spleen_19.nii.gz', '/Users/danyu/Desktop/SJTU/Distributed_

Loading dataset: 100%|██████████| 32/32 [00:42<00:00,  1.33s/it]
Loading dataset: 100%|██████████| 9/9 [00:13<00:00,  1.53s/it]

Client 1 has 2 samples.
Client 2 has 2 samples.
Client 3 has 2 samples.
Client 4 has 2 samples.





# FedAvg

In [88]:
class FedAvg:
    def __init__(self, model, loss_function, device, num_clients, client_fraction, local_epochs, batch_size, lr=1e-4):
        self.global_model = model
        self.loss_function = loss_function
        self.device = device
        self.num_clients = num_clients
        self.client_fraction = client_fraction
        self.local_epochs = local_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.clients = []  # Placeholder for client datasets
        
        self.round_loss_values = []  # To store round-wise training losses
        self.dice_metric_values = []  # To store round-wise Dice scores

    def set_clients(self, train_datasets, val_datasets):
        #create a list of dictionaries with train, val datasets for each client
        self.clients = [{"train": train, "val": val} for train, val in zip(train_datasets, val_datasets)]

    def client_update(self, client_data):
        """
        Perform local training on a client.
        Args:
            client_data: A dictionary with 'train' and 'val' datasets for the client.
        Returns:
            Updated weights (state_dict) of the local model.
        """
        #initialize a local model
        local_model = copy.deepcopy(self.global_model)
        local_model.train()
        local_model.to(self.device)

        #create data loader for the client's training dataset
        train_loader = DataLoader(client_data["train"], batch_size=2, shuffle=True, num_workers=4)
        # Optimizer for the local training
        optimizer = torch.optim.Adam(local_model.parameters(), lr=self.lr)
        
        total_loss = 0
        num_batches = 0
        
        # Perform local training
        for _ in range(self.local_epochs):
            for batch_data in train_loader:
                inputs, labels = (
                    batch_data["image"].to(self.device),
                    batch_data["label"].to(self.device),
                )
                optimizer.zero_grad()
                outputs = local_model(inputs)
                loss = self.loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                num_batches += 1

        # Calculate the average loss for this client
        loss_per_client = total_loss / num_batches if num_batches > 0 else 0
        # Return the updated model weights
        return local_model.state_dict(), loss_per_client

    def aggregate_updates(self, client_updates, client_sizes):
        """
        Aggregate client updates to update the global model.
        Args:
            client_updates: List of state_dicts from clients.
            client_sizes: List of dataset sizes for each client.
        Returns:
            Aggregated weights for the global model.
        """
        # Initialize global weights
        global_weights = copy.deepcopy(self.global_model.state_dict())
        total_size = sum(client_sizes)

        # Perform weighted aggregation
        for key in global_weights.keys():
            global_weights[key] = sum(
                client_updates[i][key] * client_sizes[i] / total_size for i in range(len(client_updates))
            )

        return global_weights

    def evaluate_global_model(self, val_datasets, post_pred, post_label, dice_metric):
        """
        Evaluate the global model on all clients' validation datasets.
        Args:
            val_datasets: List of validation datasets for each client.
            post_pred: Post-processing function for predictions.
            post_label: Post-processing function for labels.
            dice_metric: Metric instance for Dice score.
        Returns:
            Average Dice score across all validation datasets.
        """
        self.global_model.eval()
        dice_scores = []

        for val_data in val_datasets:
            val_loader = DataLoader(val_data, batch_size=1, num_workers=4)
            with torch.no_grad():
                for batch_data in val_loader:
                    val_inputs, val_labels = (
                        batch_data["image"].to(self.device),
                        batch_data["label"].to(self.device),
                    )
                    roi_size = (160, 160, 160)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, self.global_model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                    dice_metric(y_pred=val_outputs, y=val_labels)

                # Aggregate Dice metric for this client
                dice_score = dice_metric.aggregate().item()
                dice_scores.append(dice_score)
                dice_metric.reset()

        # Return the average Dice score
        return np.mean(dice_scores)

    def run(self, num_rounds, post_pred, post_label, dice_metric):
        """
        Run the FedAvg algorithm for the specified number of rounds.
        Args:
            num_rounds: Number of global training rounds.
            post_pred: Post-processing function for predictions.
            post_label: Post-processing function for labels.
            dice_metric: Metric instance for Dice score.
        """
        global_weights = self.global_model.state_dict()

        for round_num in range(num_rounds):
            print(f"--- Round {round_num + 1} ---")

            # Randomly select clients
            num_selected_clients = max(int(self.client_fraction * self.num_clients), 1)
            selected_clients = np.random.choice(self.num_clients, num_selected_clients, replace=False)

            client_updates = []
            client_sizes = []
            total_loss = 0

            # Each selected client performs local training
            for client_id in selected_clients:
                client_data = self.clients[client_id]
                client_size = len(client_data["train"])
                client_sizes.append(client_size)

                updated_weights, loss_per_client = self.client_update(client_data)
                client_updates.append(updated_weights)
                total_loss += loss_per_client

            # Aggregate updates to update global model weights
            global_weights = self.aggregate_updates(client_updates, client_sizes)
            self.global_model.load_state_dict(global_weights)

            # Evaluate the global model on all validation datasets
            avg_dice_score = self.evaluate_global_model(
                [client["val"] for client in self.clients],
                post_pred,
                post_label,
                dice_metric,
            )
            print(f"Average Dice Score after Round {round_num + 1}: {avg_dice_score:.4f}")

            self.dice_metric_values.append(avg_dice_score)
            avg_loss_per_round = total_loss / len(selected_clients) if len(selected_clients) > 0 else 0
            self.round_loss_values.append(avg_loss_per_round)

# FedCluster

In [None]:
#implement FedCluster***: https://arxiv.org/pdf/2009.10748
class FedCluster:
    def __init__(self, model, loss_function, device, num_clients, client_fraction, local_epochs, batch_size, lr=1e-4, num_clusters=2):
        self.global_model = model
        self.loss_function = loss_function
        self.device = device
        self.num_clients = num_clients
        self.client_fraction = client_fraction
        self.local_epochs = local_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.num_clusters = num_clusters
        self.clients = []  
        self.clusters = []  
        
        self.round_loss_values = []  
        self.dice_metric_values = []  

    def set_clients(self, train_datasets, val_datasets):
        # Create a list of dictionaries with train and val datasets for each client
        self.clients = [{"train": train, "val": val} for train, val in zip(train_datasets, val_datasets)]
        self.create_clusters()

    def create_clusters(self):
        # Randomly group clients into clusters
        client_ids = np.arange(self.num_clients)
        np.random.shuffle(client_ids)
        cluster_size = self.num_clients // self.num_clusters
        self.clusters = [client_ids[i:i + cluster_size] for i in range(0, len(client_ids), cluster_size)]

    def client_update(self, client_data):
        """
        Perform local training on a client.
        """
        local_model = copy.deepcopy(self.global_model)
        local_model.train()
        local_model.to(self.device)

        train_loader = DataLoader(client_data["train"], batch_size=self.batch_size, shuffle=True, num_workers=4)
        optimizer = torch.optim.Adam(local_model.parameters(), lr=self.lr)

        total_loss = 0
        num_batches = 0
        
        for _ in range(self.local_epochs):
            for batch_data in train_loader:
                inputs, labels = (
                    batch_data["image"].to(self.device),
                    batch_data["label"].to(self.device),
                )
                optimizer.zero_grad()
                outputs = local_model(inputs)
                loss = self.loss_function(outputs, labels)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1
        
        loss_per_client = total_loss / num_batches if num_batches > 0 else 0
        return local_model.state_dict(), loss_per_client

    def aggregate_updates(self, client_updates, client_sizes):
        """
        Aggregate client updates to update the global model.
        """
        global_weights = copy.deepcopy(self.global_model.state_dict())
        total_size = sum(client_sizes)

        for key in global_weights.keys():
            global_weights[key] = sum(
                client_updates[i][key] * client_sizes[i] / total_size for i in range(len(client_updates))
            )

        return global_weights

    def run(self, num_rounds, post_pred, post_label, dice_metric):
        """
        Run the FedCluster algorithm.
        """
        global_weights = self.global_model.state_dict()

        for round_num in range(num_rounds):
            print(f"--- Round {round_num + 1} ---")
            for cluster_idx, cluster in enumerate(self.clusters):
                print(f"Processing Cluster {cluster_idx + 1}")

                # Randomly select clients from the cluster
                num_selected_clients = max(int(self.client_fraction * len(cluster)), 1)
                selected_clients = np.random.choice(cluster, num_selected_clients, replace=False)

                client_updates = []
                client_sizes = []
                total_loss = 0

                # Each selected client performs local training
                for client_id in selected_clients:
                    client_data = self.clients[client_id]
                    client_size = len(client_data["train"])
                    client_sizes.append(client_size)

                    updated_weights, loss_per_client = self.client_update(client_data)
                    client_updates.append(updated_weights)
                    total_loss += loss_per_client
                    
                # Aggregate updates from the cluster
                global_weights = self.aggregate_updates(client_updates, client_sizes)
                self.global_model.load_state_dict(global_weights)

            # Evaluate the global model on all validation datasets
            avg_dice_score = self.evaluate_global_model(
                [client["val"] for client in self.clients],
                post_pred,
                post_label,
                dice_metric,
            )
            print(f"Average Dice Score after Round {round_num + 1}: {avg_dice_score:.4f}")

            self.dice_metric_values.append(avg_dice_score)
            avg_loss_per_round = total_loss / len(selected_clients) if len(selected_clients) > 0 else 0
            self.round_loss_values.append(avg_loss_per_round)
            
    def evaluate_global_model(self, val_datasets, post_pred, post_label, dice_metric):
        """
        Evaluate the global model on all clients' validation datasets.
        """
        self.global_model.eval()
        dice_scores = []

        for val_data in val_datasets:
            val_loader = DataLoader(val_data, batch_size=1, num_workers=4)
            with torch.no_grad():
                for batch_data in val_loader:
                    val_inputs, val_labels = (
                        batch_data["image"].to(self.device),
                        batch_data["label"].to(self.device),
                    )
                    roi_size = (160, 160, 160)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, self.global_model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                    dice_metric(y_pred=val_outputs, y=val_labels)

                dice_score = dice_metric.aggregate().item()
                dice_scores.append(dice_score)
                dice_metric.reset()

        return np.mean(dice_scores)


## Training FedAvg

In [90]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")

num_clients = 4  # Total number of clients
client_fraction = 0.5  # Fraction of clients selected per round
local_epochs = 2  # Number of epochs each client trains locally
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

fedavg = FedAvg(
    model=model,
    loss_function=loss_function,
    device=device,
    num_clients=num_clients,
    client_fraction=0.5,  # Fraction of clients participating in each round
    local_epochs=2,       # Number of local epochs per client
    batch_size=2,         # Batch size for local training
    lr=1e-4               # Learning rate
)

fedavg.set_clients(train_client_datasets, val_client_datasets)

fedavg.run(
    num_rounds=200,           # Number of global rounds
    post_pred=post_pred,     # Post-processing for predictions
    post_label=post_label,   # Post-processing for labels
    dice_metric=dice_metric  # Dice metric instance
)

torch.save({
    'round_losses': fedavg.round_loss_values,
    'dice_scores': fedavg.dice_metric_values,
    'global_model_state': fedavg.global_model.state_dict()
}, 'fedavg_r200_iid.pth')

# Training FedCluster

In [None]:
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")


num_clients = 4  # Total number of clients
client_fraction = 0.5  # Fraction of clients selected per round
local_epochs = 2  # Number of epochs each client trains locally
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

fedcluster = FedCluster(
    model=model,
    loss_function=loss_function,
    device=device,
    num_clients=num_clients,
    client_fraction=client_fraction,
    local_epochs=local_epochs,
    batch_size=2,
    lr=1e-4,
    num_clusters=2,
)

fedcluster.set_clients(train_client_datasets, val_client_datasets)

num_rounds = 200
fedcluster.run(num_rounds, post_pred, post_label, dice_metric)

torch.save({
    'round_losses': fedcluster.round_loss_values,
    'dice_scores': fedcluster.dice_metric_values,
    'global_model_state': fedcluster.global_model.state_dict()
}, 'fedc_r200_iid.pth')
