In [1]:
%matplotlib inline
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import tarfile
import numpy as np
from torch.utils.data import Subset
import copy
set_determinism(seed=0)
np.random.seed(0)
import torch.nn.functional as F
import time
import pandas as pd

#print_config()


## Read data

In [2]:
# read data
def get_spleen_dataset(data_dir):
    train_images = sorted(glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
    print(train_images)
    train_labels = sorted(glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
    data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]
    train_files, val_files = data_dicts[:-9], data_dicts[-9:]
    
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        ]
    )
    
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
    
    return train_ds, val_ds

In [5]:
# split the dataset into iid subsets
def split_dataset_iid(train_dataset, val_dataset, num_clients):
    #make the dataset evenly divisible by removing the last elements if needed
    train_dataset = train_dataset[:len(train_dataset) - len(train_dataset) % num_clients]
    val_dataset = val_dataset[:len(val_dataset) - len(val_dataset) % num_clients]
    
    #split the dataset into num_clients subsets randomly 
    train_indices = np.random.permutation(len(train_dataset))
    val_indices = np.random.permutation(len(val_dataset))
    train_client_splits = np.array_split(train_indices, num_clients)
    val_client_splits = np.array_split(val_indices, num_clients)
    
    #create a list of subsets
    train_client_datasets = [Subset(train_dataset, split) for split in train_client_splits]
    val_client_datasets = [Subset(val_dataset, split) for split in val_client_splits]
    
    return train_client_datasets, val_client_datasets

In [6]:
# split the dataset into non-iid subsets
def split_dataset_non_iid(train_dataset, val_dataset, num_clients):
    if len(train_dataset) % num_clients != 0 or len(val_dataset) % num_clients != 0:
         #split the dataset into num_clients subsets randomly 
        train_indices = np.random.permutation(len(train_dataset))
        val_indices = np.random.permutation(len(val_dataset))
        train_client_splits = np.array_split(train_indices, num_clients)
        val_client_splits = np.array_split(val_indices, num_clients)
        
        #create a list of subsets
        train_client_datasets = [Subset(train_dataset, split) for split in train_client_splits]
        val_client_datasets = [Subset(val_dataset, split) for split in val_client_splits]
        
        return train_client_datasets, val_client_datasets
    
    else: 
        while len(train_dataset) % num_clients == 0:
            train_dataset = train_dataset[:len(train_dataset) - 1]
        return split_dataset_non_iid(train_dataset, val_dataset, num_clients)

In [7]:
#set the data path
root_dir = os.path.dirname(os.path.dirname(os.path.abspath("federated_learning")))
data_dir = os.path.join(root_dir, "data/raw/Task09_Spleen/")
model_dir = os.path.join(root_dir, "models_gpu")

In [12]:
#load the entire dataset
train_ds, val_ds = get_spleen_dataset(data_dir)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)

#split the dataset into num_clients non IID subsets
num_clients = 4
train_client_datasets, val_client_datasets = split_dataset_non_iid(train_ds, val_ds, num_clients)

#check the number of samples in each client
for i, client_data in enumerate(val_client_datasets):
    print(f"Client {i+1} has {len(client_data)} samples.")     

 

['c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_10.nii.gz', 'c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_12.nii.gz', 'c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_13.nii.gz', 'c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_14.nii.gz', 'c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_16.nii.gz', 'c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_17.nii.gz', 'c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_18.nii.gz', 'c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_19.nii.gz', 'c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_2.nii.gz', 'c:\\Users

Loading dataset: 100%|██████████| 32/32 [00:28<00:00,  1.12it/s]
Loading dataset: 100%|██████████| 9/9 [00:06<00:00,  1.37it/s]

Client 1 has 3 samples.
Client 2 has 2 samples.
Client 3 has 2 samples.
Client 4 has 2 samples.





# FedAvg with knowledge distillation

In [9]:
class FedAvg:
    def __init__(self, model, loss_function, device, num_clients, client_fraction, local_epochs, batch_size, lr=1e-4,
                 teacher_model=None, alpha=0.5, temperature=2.0):
        """
        Args:
            teacher_model: Pretrained teacher model for knowledge distillation.
            alpha: Weighting factor for distillation loss.
            temperature: Temperature for softening logits in distillation.
        """
        self.global_model = model
        self.loss_function = loss_function
        self.device = device
        self.num_clients = num_clients
        self.client_fraction = client_fraction
        self.local_epochs = local_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.teacher_model = teacher_model.to(device) if teacher_model else None
        self.alpha = alpha
        self.temperature = temperature
        self.clients = []

        self.round_loss_values = []
        self.dice_metric_values = []

    def set_clients(self, train_datasets, val_datasets):
        self.clients = [{"train": train, "val": val} for train, val in zip(train_datasets, val_datasets)]

        # Track communication costs
        self.communication_cost = 0  # For KD-based lightweight models
        self.communication_time = 0  # Total time spent in communication (in seconds)
        self.param_count = 0  # Total number of parameters exchanged

    def calculate_model_size(self, model_state_dict):
        """
        Calculate the size of the model state in bytes.
        """
        return sum(param.numel() * param.element_size() for param in model_state_dict.values())
    
    def calculate_parameter_count(self, model_state_dict):
        """
        Count the number of parameters in the model state.
        """
        return sum(param.numel() for param in model_state_dict.values())

    def distillation_loss(self, student_logits, teacher_logits):
        """
        Compute the distillation loss using KL divergence between softened logits.
        Args:
            student_logits: Logits from the student model.
            teacher_logits: Logits from the teacher model.
        """
        student_soft = F.log_softmax(student_logits / self.temperature, dim=1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=1)
        return F.kl_div(student_soft, teacher_soft, reduction="batchmean") * (self.temperature ** 2)
    
    def client_update(self, client_data):
        """
        Perform local training on a client with optional knowledge distillation.
        Args:
            client_data: A dictionary with 'train' and 'val' datasets for the client.
        Returns:
            Updated weights (state_dict) of the local model, and its size.
        """
        # Initialize a local model
        local_model = copy.deepcopy(self.global_model)
        local_model.train()
        local_model.to(self.device)
        local_train_start = time.time()

        # Freeze the teacher model if provided
        if self.teacher_model:
            self.teacher_model.eval()

        train_loader = DataLoader(client_data["train"], batch_size=self.batch_size, shuffle=True, num_workers=4)
        optimizer = torch.optim.Adam(local_model.parameters(), lr=self.lr)

        total_loss = 0
        num_batches = 0

        # Perform local training    
        for _ in range(self.local_epochs):
            for batch_data in train_loader:
                inputs, labels = (
                    batch_data["image"].to(self.device),
                    batch_data["label"].to(self.device),
                )
                optimizer.zero_grad()
                outputs = local_model(inputs)

                # Compute Dice loss
                dice_loss = self.loss_function(outputs, labels)

                if self.teacher_model:
                    with torch.no_grad():
                        teacher_outputs = self.teacher_model(inputs)

                    # Compute distillation loss
                    distill_loss = self.distillation_loss(outputs, teacher_outputs)
                    distill_loss = torch.sigmoid(distill_loss)

                    # Combine Dice loss and distillation loss
                    task_specific_loss = dice_loss.item()
                    loss = self.alpha * distill_loss + (1 - self.alpha) * dice_loss
                    
                else:
                    loss = dice_loss

                loss.backward()
                optimizer.step()
                
                if self.teacher_model:
                    total_loss += task_specific_loss
                else:
                    total_loss += loss.item()
                
                num_batches += 1

        local_train_end = time.time()

        # Calculate the average loss for this client
        loss_per_client = total_loss / num_batches if num_batches > 0 else 0

        # Calculate model size for communication cost
        model_size = self.calculate_model_size(local_model.state_dict())

        # Calculate number of parameters passed for communication cost
        num_params = self.calculate_parameter_count(local_model.state_dict())

        local_train_time = local_train_end - local_train_start

        return local_model.state_dict(), loss_per_client, model_size, num_params, local_train_time
    
    def aggregate_updates(self, client_updates, client_sizes):
        """
        Aggregate client updates to update the global model.
        Args:
            client_updates: List of state_dicts from clients.
            client_sizes: List of dataset sizes for each client.
        Returns:
            Aggregated weights for the global model.
        """
        # Initialize global weights
        global_weights = copy.deepcopy(self.global_model.state_dict())
        total_size = sum(client_sizes)

        # Perform weighted aggregation
        for key in global_weights.keys():
            global_weights[key] = sum(
                client_updates[i][key] * client_sizes[i] / total_size for i in range(len(client_updates))
            )

        return global_weights

    def evaluate_global_model(self, val_datasets, post_pred, post_label, dice_metric):
        """
        Evaluate the global model on all clients' validation datasets.
        Args:
            val_datasets: List of validation datasets for each client.
            post_pred: Post-processing function for predictions.
            post_label: Post-processing function for labels.
            dice_metric: Metric instance for Dice score.
        Returns:
            Average Dice score across all validation datasets.
        """
        self.global_model.eval()
        dice_scores = []

        for val_data in val_datasets:
            val_loader = DataLoader(val_data, batch_size=1, num_workers=4)
            with torch.no_grad():
                for batch_data in val_loader:
                    val_inputs, val_labels = (
                        batch_data["image"].to(self.device),
                        batch_data["label"].to(self.device),
                    )
                    roi_size = (160, 160, 160)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, self.global_model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                    dice_metric(y_pred=val_outputs, y=val_labels)

                # Aggregate Dice metric for this client
                dice_score = dice_metric.aggregate().item()
                dice_scores.append(dice_score)
                dice_metric.reset()

        return np.mean(dice_scores)

    def run(self, num_rounds, post_pred, post_label, dice_metric):
        """
        Run the FedAvg algorithm for the specified number of rounds.
        Args:
            num_rounds: Number of global training rounds.
            post_pred: Post-processing function for predictions.
            post_label: Post-processing function for labels.
            dice_metric: Metric instance for Dice score.
        """
        global_weights = self.global_model.state_dict()

        self.communication_time = 0  # Reset communication time
        self.param_count = 0  # Reset parameter count

        for round_num in range(num_rounds):
            print(f"--- Round {round_num + 1} ---")

            # Randomly select clients
            num_selected_clients = max(int(self.client_fraction * self.num_clients), 1)
            selected_clients = np.random.choice(self.num_clients, num_selected_clients, replace=False)

            client_updates = []
            client_sizes = []
            total_loss = 0

            round_start_time = time.time()
            
            # Each selected client performs local training
            for client_id in selected_clients:
                client_data = self.clients[client_id]
                client_size = len(client_data["train"])
                client_sizes.append(client_size)

                updated_weights, loss_per_client, model_size, num_params, local_train_time = self.client_update(client_data)
                client_updates.append(updated_weights)
                total_loss += loss_per_client

                # Track communication cost
                self.communication_time -= local_train_time
                # print(local_train_time, self.communication_time)
                self.communication_cost += model_size
                self.param_count += num_params

            round_end_time = time.time()

            self.communication_time += (round_end_time - round_start_time)
            # print(round_end_time - round_start_time, self.communication_time)

            # Aggregate updates to update global model weights
            global_weights = self.aggregate_updates(client_updates, client_sizes)
            self.global_model.load_state_dict(global_weights)

            # Evaluate the global model on all validation datasets
            avg_dice_score = self.evaluate_global_model(
                [client["val"] for client in self.clients],
                post_pred,
                post_label,
                dice_metric,
            )
            print(f"Average Dice Score after Round {round_num + 1}: {avg_dice_score:.4f}")

            self.dice_metric_values.append(avg_dice_score)
            avg_loss_per_round = total_loss / len(selected_clients) if len(selected_clients) > 0 else 0
            self.round_loss_values.append(avg_loss_per_round)

        # Print communication costs
        print(f"Total communication cost: {self.communication_cost / 1e6:.2f} MB")
        print(f"Total communication time: {self.communication_time:.2f} seconds")
        print(f"Total parameters exchanged: {self.param_count}")

# Model training

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
## KD non iid

# initialize the models
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")

student_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(8, 16, 32, 64),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

teacher_state_dict = torch.load(os.path.join(model_dir, "fedavg_r200_niid.pth"))

teacher_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
teacher_model.load_state_dict(teacher_state_dict['global_model_state'])

num_clients = 4  # Total number of clients
client_fraction = 0.5  # Fraction of clients selected per round
local_epochs = 2  # Number of epochs each client trains locally
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

fedavg = FedAvg(
    model=student_model,
    loss_function=loss_function,
    device=device,
    num_clients=num_clients,
    client_fraction=0.5,
    local_epochs=2,
    batch_size=2,
    lr=1e-4,
    teacher_model=teacher_model,  # Pass the teacher model
    alpha=0.7,                    # Weight for distillation loss
    temperature=3.0               # Temperature for softening logits
)

fedavg.set_clients(train_client_datasets, val_client_datasets)

fedavg.run(
    num_rounds=200,           
    post_pred=post_pred,     
    post_label=post_label,   
    dice_metric=dice_metric  
)

torch.save({
    'round_losses': fedavg.round_loss_values,
    'dice_scores': fedavg.dice_metric_values,
    'model size': fedavg.communication_cost, ##bandwidth or data transfer
    'communication time': fedavg.communication_time, ## training and convergence time
    'parameter count': fedavg.param_count,
    'global_model_state': fedavg.global_model.state_dict()
}, 'fedavg_kd_r200_niid.pth')

#alt: cosine similarity or comparing mse between teacher and student or KL efficiency or comparing loss between teacher and student

  teacher_state_dict = torch.load(os.path.join(model_dir, "fedavg_r200_niid.pth"))


--- Round 1 ---
Average Dice Score after Round 1: 0.0132
--- Round 2 ---
Average Dice Score after Round 2: 0.0192
--- Round 3 ---
Average Dice Score after Round 3: 0.0288
--- Round 4 ---
Average Dice Score after Round 4: 0.0348
--- Round 5 ---
Average Dice Score after Round 5: 0.0382
--- Round 6 ---
Average Dice Score after Round 6: 0.0435
--- Round 7 ---
Average Dice Score after Round 7: 0.0515
--- Round 8 ---
Average Dice Score after Round 8: 0.0511
--- Round 9 ---
Average Dice Score after Round 9: 0.0521
--- Round 10 ---
Average Dice Score after Round 10: 0.0589
--- Round 11 ---
Average Dice Score after Round 11: 0.0602
--- Round 12 ---
Average Dice Score after Round 12: 0.0735
--- Round 13 ---
Average Dice Score after Round 13: 0.0630
--- Round 14 ---
Average Dice Score after Round 14: 0.0747
--- Round 15 ---
Average Dice Score after Round 15: 0.0792
--- Round 16 ---
Average Dice Score after Round 16: 0.1028
--- Round 17 ---
Average Dice Score after Round 17: 0.0751
--- Round 18 --

In [17]:
## No KD non iid

fedavg_nkd = FedAvg(
    model=model,
    loss_function=loss_function,
    device=device,
    num_clients=num_clients,
    client_fraction=0.5,
    local_epochs=2,
    batch_size=2,
    lr=1e-4,
    teacher_model=None,  # Pass the teacher model
    alpha=0.7,                    # Weight for distillation loss
    temperature=3.0               # Temperature for softening logits
)

fedavg_nkd.set_clients(train_client_datasets, val_client_datasets)

fedavg_nkd.run(
    num_rounds=200,           
    post_pred=post_pred,     
    post_label=post_label,   
    dice_metric=dice_metric  
)

torch.save({
    'round_losses': fedavg_nkd.round_loss_values,
    'dice_scores': fedavg_nkd.dice_metric_values,
    'model size': fedavg_nkd.communication_cost,
    'communication time': fedavg_nkd.communication_time,
    'parameter count': fedavg_nkd.param_count,
    'global_model_state': fedavg_nkd.global_model.state_dict()
}, 'fedavg_nkd_r200_niid.pth')

--- Round 1 ---
Average Dice Score after Round 1: 0.0144
--- Round 2 ---
Average Dice Score after Round 2: 0.0163
--- Round 3 ---
Average Dice Score after Round 3: 0.0246
--- Round 4 ---
Average Dice Score after Round 4: 0.0328
--- Round 5 ---
Average Dice Score after Round 5: 0.0402
--- Round 6 ---
Average Dice Score after Round 6: 0.0437
--- Round 7 ---
Average Dice Score after Round 7: 0.0446
--- Round 8 ---
Average Dice Score after Round 8: 0.0463
--- Round 9 ---
Average Dice Score after Round 9: 0.0483
--- Round 10 ---
Average Dice Score after Round 10: 0.0472
--- Round 11 ---
Average Dice Score after Round 11: 0.0476
--- Round 12 ---
Average Dice Score after Round 12: 0.0521
--- Round 13 ---
Average Dice Score after Round 13: 0.0539
--- Round 14 ---
Average Dice Score after Round 14: 0.0591
--- Round 15 ---
Average Dice Score after Round 15: 0.0596
--- Round 16 ---
Average Dice Score after Round 16: 0.0683
--- Round 17 ---
Average Dice Score after Round 17: 0.0564
--- Round 18 --

In [13]:
num_clients = 4
train_client_datasets, val_client_datasets = split_dataset_iid(train_ds, val_ds, num_clients)

#check the number of samples in each client
for i, client_data in enumerate(val_client_datasets):
    print(f"Client {i+1} has {len(client_data)} samples.")     

Client 1 has 2 samples.
Client 2 has 2 samples.
Client 3 has 2 samples.
Client 4 has 2 samples.


In [14]:
## KD iid

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")

student_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(8, 16, 32, 64),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

teacher_state_dict = torch.load(os.path.join(model_dir, "fedavg_r200_iid.pth"))

teacher_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
teacher_model.load_state_dict(teacher_state_dict['global_model_state'])

# FedAvg configuration
num_clients = 4  # Total number of clients
client_fraction = 0.5  # Fraction of clients selected per round
local_epochs = 2  # Number of epochs each client trains locally
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

fedavg_iid = FedAvg(
    model=student_model,
    loss_function=loss_function,
    device=device,
    num_clients=num_clients,
    client_fraction=0.5,
    local_epochs=2,
    batch_size=2,
    lr=1e-4,
    teacher_model=teacher_model,  # Pass the teacher model
    alpha=0.7,                    # Weight for distillation loss
    temperature=3.0               # Temperature for softening logits
)

fedavg_iid.set_clients(train_client_datasets, val_client_datasets)

fedavg_iid.run(
    num_rounds=200,           
    post_pred=post_pred,     
    post_label=post_label,   
    dice_metric=dice_metric  
)

torch.save({
    'round_losses': fedavg_iid.round_loss_values,
    'dice_scores': fedavg_iid.dice_metric_values,
    'model size': fedavg_iid.communication_cost,
    'communication time': fedavg_iid.communication_time,
    'parameter count': fedavg_iid.param_count,
    'global_model_state': fedavg_iid.global_model.state_dict()
}, 'fedavg_kd_r200_iid.pth')

  teacher_state_dict = torch.load(os.path.join(model_dir, "fedavg_r200_iid.pth"))


--- Round 1 ---
Average Dice Score after Round 1: 0.0130
--- Round 2 ---
Average Dice Score after Round 2: 0.0188
--- Round 3 ---
Average Dice Score after Round 3: 0.0273
--- Round 4 ---
Average Dice Score after Round 4: 0.0340
--- Round 5 ---
Average Dice Score after Round 5: 0.0390
--- Round 6 ---
Average Dice Score after Round 6: 0.0418
--- Round 7 ---
Average Dice Score after Round 7: 0.0479
--- Round 8 ---
Average Dice Score after Round 8: 0.0500
--- Round 9 ---
Average Dice Score after Round 9: 0.0526
--- Round 10 ---
Average Dice Score after Round 10: 0.0623
--- Round 11 ---
Average Dice Score after Round 11: 0.0646
--- Round 12 ---
Average Dice Score after Round 12: 0.0744
--- Round 13 ---
Average Dice Score after Round 13: 0.0745
--- Round 14 ---
Average Dice Score after Round 14: 0.0593
--- Round 15 ---
Average Dice Score after Round 15: 0.0643
--- Round 16 ---
Average Dice Score after Round 16: 0.0916
--- Round 17 ---
Average Dice Score after Round 17: 0.0753
--- Round 18 --

In [None]:
## No KD iid

fedavg_nkd_iid = FedAvg(
    model=model,
    loss_function=loss_function,
    device=device,
    num_clients=num_clients,
    client_fraction=0.5,
    local_epochs=2,
    batch_size=2,
    lr=1e-4,
    teacher_model=None,  # Pass the teacher model
    alpha=0.7,                    # Weight for distillation loss
    temperature=3.0               # Temperature for softening logits
)

fedavg_nkd_iid.set_clients(train_client_datasets, val_client_datasets)

fedavg_nkd_iid.run(
    num_rounds=200,           
    post_pred=post_pred,     
    post_label=post_label,   
    dice_metric=dice_metric  
)

torch.save({
    'round_losses': fedavg_nkd_iid.round_loss_values,
    'dice_scores': fedavg_nkd_iid.dice_metric_values,
    'model size': fedavg_nkd_iid.communication_cost,
    'communication time': fedavg_nkd_iid.communication_time,
    'parameter count': fedavg_nkd_iid.param_count,
    'global_model_state': fedavg_nkd_iid.global_model.state_dict()
}, 'fedavg_nkd_r200_iid.pth')

In [None]:
data = []
for i in range(200):
    round_data = {
        'round': i+1,
        'round_losses': fedavg.round_loss_values[i],
        'dice_scores': fedavg.dice_metric_values[i] if i % 2 != 0 else '',  # 'NA' for even rounds
        'model_size': fedavg.communication_cost if i == 0 else '',
        'communication_time': fedavg.communication_time if i == 0 else '',
        'parameter_count': fedavg.param_count if i == 0 else ''
    }
    data.append(round_data)

df = pd.DataFrame(data)
df.to_csv('fedavg_kd_r200_niid_plot.csv', index=False)

print("CSV file saved as 'fedavg_kd_r200_niid_plot.csv'")
df

In [21]:
data = []
for i in range(200):
    round_data = {
        'round': i+1,
        'round_losses': fedavg_nkd.round_loss_values[i],
        'dice_scores': fedavg_nkd.dice_metric_values[i] if i % 2 != 0 else '',  # 'NA' for even rounds
        'model_size': fedavg_nkd.communication_cost if i == 0 else '',
        'communication_time': fedavg_nkd.communication_time if i == 0 else '',
        'parameter_count': fedavg_nkd.param_count if i == 0 else ''
    }
    data.append(round_data)

df = pd.DataFrame(data)
df.to_csv('fedavg_nkd_r200_niid_plot.csv', index=False)

print("CSV file saved as 'fedavg_nkd_r200_niid_plot.csv'")
df

CSV file saved as 'fedavg_nkd_r200_niid_plot.csv'


Unnamed: 0,round,round_losses,dice_scores,model_size,communication_time,parameter_count
0,1,0.679961,,7698219200,15.656538,1924548000
1,2,0.664540,0.016262,,,
2,3,0.642148,,,,
3,4,0.630550,0.032777,,,
4,5,0.622092,,,,
...,...,...,...,...,...,...
195,196,0.250318,0.16247,,,
196,197,0.226813,,,,
197,198,0.237786,0.063828,,,
198,199,0.233401,,,,


In [17]:
data = []
for i in range(200):
    round_data = {
        'round': i+1,
        'round_losses': fedavg_iid.round_loss_values[i],
        'dice_scores': fedavg_iid.dice_metric_values[i] if i % 2 != 0 else '',  # 'NA' for even rounds
        'model_size': fedavg_iid.communication_cost if i == 0 else '',
        'communication_time': fedavg_iid.communication_time if i == 0 else '',
        'parameter_count': fedavg_iid.param_count if i == 0 else ''
    }
    data.append(round_data)

df = pd.DataFrame(data)
df.to_csv('fedavg_kd_r200_iid_plot.csv', index=False)

print("CSV file saved as 'fedavg_kd_r200_iid_plot.csv'")
df

CSV file saved as 'fedavg_kd_r200_iid_plot.csv'


Unnamed: 0,round,round_losses,dice_scores,model_size,communication_time,parameter_count
0,1,0.623015,,478116800,13.011005,119524000
1,2,0.605093,0.018751,,,
2,3,0.603996,,,,
3,4,0.591854,0.03398,,,
4,5,0.584480,,,,
...,...,...,...,...,...,...
195,196,0.218153,0.889146,,,
196,197,0.218031,,,,
197,198,0.203979,0.814147,,,
198,199,0.232753,,,,


In [None]:
data = []
for i in range(200):
    round_data = {
        'round': i+1,
        'round_losses': fedavg_nkd_iid.round_loss_values[i],
        'dice_scores': fedavg_nkd_iid.dice_metric_values[i] if i % 2 != 0 else '',  # 'NA' for even rounds
        'model_size': fedavg_nkd_iid.communication_cost if i == 0 else '',
        'communication_time': fedavg_nkd_iid.communication_time if i == 0 else '',
        'parameter_count': fedavg_nkd_iid.param_count if i == 0 else ''
    }
    data.append(round_data)

df = pd.DataFrame(data)
df.to_csv('fedavg_nkd_r200_iid_plot.csv', index=False)

print("CSV file saved as 'fedavg_nkd_r200_iid_plot.csv'")
df