In [1]:
%matplotlib inline
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import tarfile
import numpy as np
from torch.utils.data import Subset
import copy
set_determinism(seed=0)
import torch.nn.functional as F
import time
import pandas as pd


#print_config()

# Read data

In [2]:
# read data
def get_spleen_dataset(data_dir):
    train_images = sorted(glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
    print(train_images)
    train_labels = sorted(glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
    data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]
    train_files, val_files = data_dicts[:-9], data_dicts[-9:]
    
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            # user can also add other random transforms
            # RandAffined(
            #     keys=['image', 'label'],
            #     mode=('bilinear', 'nearest'),
            #     prob=1.0, spatial_size=(96, 96, 96),
            #     rotate_range=(0, 0, np.pi/15),
            #     scale_range=(0.1, 0.1, 0.1)),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        ]
    )
    
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
    
    return train_ds, val_ds

# FedCluster with knowledge distillation

In [3]:
class FedCluster:
    def __init__(self, model, loss_function, device, num_clients, client_fraction, local_epochs, batch_size, lr=1e-4, num_clusters=2, teacher_model=None, alpha=0.5, temperature=2.0):
        self.global_model = model
        self.loss_function = loss_function
        self.device = device
        self.num_clients = num_clients
        self.client_fraction = client_fraction
        self.local_epochs = local_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.num_clusters = num_clusters
        self.clients = []  
        self.clusters = []
        self.teacher_model = teacher_model.to(device) if teacher_model else None
        self.alpha = alpha
        self.temperature = temperature  
        
        self.round_loss_values = []  
        self.dice_metric_values = []  

    def set_clients(self, train_datasets, val_datasets):
        # Create a list of dictionaries with train and val datasets for each client
        self.clients = [{"train": train, "val": val} for train, val in zip(train_datasets, val_datasets)]
        self.create_clusters()

        # Track communication costs
        self.communication_cost = 0  # For KD-based lightweight models
        self.communication_time = 0  # Total time spent in communication (in seconds)
        self.param_count = 0  # Total number of parameters exchanged

    def create_clusters(self):
        # Randomly group clients into clusters
        client_ids = np.arange(self.num_clients)
        np.random.shuffle(client_ids)
        cluster_size = self.num_clients // self.num_clusters
        self.clusters = [client_ids[i:i + cluster_size] for i in range(0, len(client_ids), cluster_size)]

    def calculate_model_size(self, model_state_dict):
        """
        Calculate the size of the model state in bytes.
        """
        return sum(param.numel() * param.element_size() for param in model_state_dict.values())
    
    def calculate_parameter_count(self, model_state_dict):
        """
        Count the number of parameters in the model state.
        """
        return sum(param.numel() for param in model_state_dict.values())

    def distillation_loss(self, student_logits, teacher_logits):
        """
        Compute the distillation loss using KL divergence between softened logits.
        Args:
            student_logits: Logits from the student model.
            teacher_logits: Logits from the teacher model.
        """
        student_soft = F.log_softmax(student_logits / self.temperature, dim=1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=1)
        return F.kl_div(student_soft, teacher_soft, reduction="batchmean") * (self.temperature ** 2)

    def client_update(self, client_data):
        """
        Perform local training on a client.
        """
        local_model = copy.deepcopy(self.global_model)
        local_model.train()
        local_model.to(self.device)
        local_train_start = time.time()

        # Freeze the teacher model if provided
        if self.teacher_model:
            self.teacher_model.eval()

        train_loader = DataLoader(client_data["train"], batch_size=self.batch_size, shuffle=True, num_workers=4)
        optimizer = torch.optim.Adam(local_model.parameters(), lr=self.lr)

        total_loss = 0
        num_batches = 0
        
        for _ in range(self.local_epochs):
            for batch_data in train_loader:
                inputs, labels = (
                    batch_data["image"].to(self.device),
                    batch_data["label"].to(self.device),
                )
                optimizer.zero_grad()
                outputs = local_model(inputs)

                # Compute Dice loss
                dice_loss = self.loss_function(outputs, labels)

                if self.teacher_model:
                    with torch.no_grad():
                        teacher_outputs = self.teacher_model(inputs)
                    
                    # Compute distillation loss
                    distill_loss = self.distillation_loss(outputs, teacher_outputs)
                    distill_loss = torch.sigmoid(distill_loss)

                    # Combine Dice loss and distillation loss
                    task_specific_loss = dice_loss.item()
                    loss = self.alpha * distill_loss + (1 - self.alpha) * dice_loss
                else:
                    loss = dice_loss

                loss.backward()
                optimizer.step()

                if self.teacher_model:
                    total_loss += task_specific_loss
                else:
                    total_loss += loss.item()
                num_batches += 1
        
        local_train_end = time.time()

        loss_per_client = total_loss / num_batches if num_batches > 0 else 0

        # Calculate model size for communication cost
        model_size = self.calculate_model_size(local_model.state_dict())

        # Calculate number of parameters passed for communication cost
        num_params = self.calculate_parameter_count(local_model.state_dict())

        local_train_time = local_train_end - local_train_start

        return local_model.state_dict(), loss_per_client, model_size, num_params, local_train_time

    def aggregate_updates(self, client_updates, client_sizes):
        """
        Aggregate client updates to update the global model.
        """
        global_weights = copy.deepcopy(self.global_model.state_dict())
        total_size = sum(client_sizes)

        for key in global_weights.keys():
            global_weights[key] = sum(
                client_updates[i][key] * client_sizes[i] / total_size for i in range(len(client_updates))
            )

        return global_weights

    def run(self, num_rounds, post_pred, post_label, dice_metric):
        """
        Run the FedCluster algorithm.
        """
        global_weights = self.global_model.state_dict()

        self.communication_time = 0  # Reset communication time
        self.param_count = 0  # Reset parameter count

        for round_num in range(num_rounds):
            print(f"--- Round {round_num + 1} ---")
            for cluster_idx, cluster in enumerate(self.clusters):
                print(f"Processing Cluster {cluster_idx + 1}")

                # Randomly select clients from the cluster
                num_selected_clients = max(int(self.client_fraction * len(cluster)), 1)
                selected_clients = np.random.choice(cluster, num_selected_clients, replace=False)

                client_updates = []
                client_sizes = []
                total_loss = 0

                round_start_time = time.time()

                # Each selected client performs local training
                for client_id in selected_clients:
                    client_data = self.clients[client_id]
                    client_size = len(client_data["train"])
                    client_sizes.append(client_size)

                    updated_weights, loss_per_client, model_size, num_params, local_train_time = self.client_update(client_data)
                    client_updates.append(updated_weights)
                    total_loss += loss_per_client

                    # Track communication cost
                    self.communication_time -= local_train_time
                    # print(local_train_time, self.communication_time)
                    self.communication_cost += model_size
                    self.param_count += num_params
                
                round_end_time = time.time()

                self.communication_time += (round_end_time - round_start_time)
                # print(round_end_time - round_start_time, self.communication_time)

                # Aggregate updates from the cluster
                global_weights = self.aggregate_updates(client_updates, client_sizes)
                self.global_model.load_state_dict(global_weights)

            # Evaluate the global model on all validation datasets
            avg_dice_score = self.evaluate_global_model(
                [client["val"] for client in self.clients],
                post_pred,
                post_label,
                dice_metric,
            )
            print(f"Average Dice Score after Round {round_num + 1}: {avg_dice_score:.4f}")

            self.dice_metric_values.append(avg_dice_score)
            avg_loss_per_round = total_loss / len(selected_clients) if len(selected_clients) > 0 else 0
            self.round_loss_values.append(avg_loss_per_round)

        # Print communication costs
        print(f"Total communication cost: {self.communication_cost / 1e6:.2f} MB")
        print(f"Total communication time: {self.communication_time:.2f} seconds")
        print(f"Total parameters exchanged: {self.param_count}")
            
    def evaluate_global_model(self, val_datasets, post_pred, post_label, dice_metric):
        """
        Evaluate the global model on all clients' validation datasets.
        """
        self.global_model.eval()
        dice_scores = []

        for val_data in val_datasets:
            val_loader = DataLoader(val_data, batch_size=1, num_workers=4)
            with torch.no_grad():
                for batch_data in val_loader:
                    val_inputs, val_labels = (
                        batch_data["image"].to(self.device),
                        batch_data["label"].to(self.device),
                    )
                    roi_size = (160, 160, 160)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, self.global_model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                    dice_metric(y_pred=val_outputs, y=val_labels)

                dice_score = dice_metric.aggregate().item()
                dice_scores.append(dice_score)
                dice_metric.reset()

        return np.mean(dice_scores)


In [4]:
def split_dataset_iid(train_dataset, val_dataset, num_clients):
    #make the dataset evenly divisible by removing the last elements if needed
    train_dataset = train_dataset[:len(train_dataset) - len(train_dataset) % num_clients]
    val_dataset = val_dataset[:len(val_dataset) - len(val_dataset) % num_clients]
    
    #split the dataset into num_clients subsets randomly 
    train_indices = np.random.permutation(len(train_dataset))
    val_indices = np.random.permutation(len(val_dataset))
    train_client_splits = np.array_split(train_indices, num_clients)
    val_client_splits = np.array_split(val_indices, num_clients)
    
    #create a list of subsets
    train_client_datasets = [Subset(train_dataset, split) for split in train_client_splits]
    val_client_datasets = [Subset(val_dataset, split) for split in val_client_splits]
    
    return train_client_datasets, val_client_datasets

In [5]:
def split_dataset_non_iid(train_dataset, val_dataset, num_clients):
    if len(train_dataset) % num_clients != 0 or len(val_dataset) % num_clients != 0:
         #split the dataset into num_clients subsets randomly 
        train_indices = np.random.permutation(len(train_dataset))
        val_indices = np.random.permutation(len(val_dataset))
        train_client_splits = np.array_split(train_indices, num_clients)
        val_client_splits = np.array_split(val_indices, num_clients)
        
        #create a list of subsets
        train_client_datasets = [Subset(train_dataset, split) for split in train_client_splits]
        val_client_datasets = [Subset(val_dataset, split) for split in val_client_splits]
        
        return train_client_datasets, val_client_datasets
    
    else: 
        while len(train_dataset) % num_clients == 0:
            train_dataset = train_dataset[:len(train_dataset) - 1]
        return split_dataset_non_iid(train_dataset, val_dataset, num_clients)

In [6]:
#set the data path
root_dir = os.path.dirname(os.path.dirname(os.path.abspath("federated_learning")))
data_dir = os.path.join(root_dir, "data/raw/Task09_Spleen/")
model_dir = os.path.join(root_dir, "models_gpu")

In [7]:
#load the entire dataset
train_ds, val_ds = get_spleen_dataset(data_dir)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)

#split the dataset into num_clients IID subsets
num_clients = 4
train_client_datasets, val_client_datasets = split_dataset_iid(train_ds, val_ds, num_clients)



['c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_10.nii.gz', 'c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_12.nii.gz', 'c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_13.nii.gz', 'c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_14.nii.gz', 'c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_16.nii.gz', 'c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_17.nii.gz', 'c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_18.nii.gz', 'c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_19.nii.gz', 'c:\\Users\\edmund\\Documents\\mlsys\\federated_learning\\data/raw/Task09_Spleen/imagesTr\\spleen_2.nii.gz', 'c:\\Users

Loading dataset: 100%|██████████| 32/32 [00:31<00:00,  1.01it/s]
Loading dataset: 100%|██████████| 9/9 [00:07<00:00,  1.25it/s]


# Model training

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

In [9]:
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")

student_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(8, 16, 32, 64),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

teacher_state_dict = torch.load(os.path.join(model_dir, "fedc_r200_niid.pth"))

teacher_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
teacher_model.load_state_dict(teacher_state_dict['global_model_state'])

# FedAvg configuration
num_clients = 4  # Total number of clients
#num_rounds = 2  # Number of communication rounds
client_fraction = 0.5  # Fraction of clients selected per round
local_epochs = 2  # Number of epochs each client trains locally
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

fedcluster = FedCluster(
    model=student_model,
    loss_function=loss_function,
    device=device,
    num_clients=num_clients,
    client_fraction=client_fraction,
    local_epochs=local_epochs,
    batch_size=2,
    lr=1e-4,
    num_clusters=2,
    teacher_model=teacher_model,  # Pass the teacher model
    alpha=0.7,                    # Weight for distillation loss
    temperature=3.0               # Temperature for softening logits
)

fedcluster.set_clients(train_client_datasets, val_client_datasets)

num_rounds = 200
fedcluster.run(num_rounds, post_pred, post_label, dice_metric)

torch.save({
    'round_losses': fedcluster.round_loss_values,
    'dice_scores': fedcluster.dice_metric_values,
    'model size': fedcluster.communication_cost,
    'communication time': fedcluster.communication_time,
    'parameter count': fedcluster.param_count,
    'global_model_state': fedcluster.global_model.state_dict()
}, 'fedc_kd_r200_niid.pth')


  teacher_state_dict = torch.load(os.path.join(model_dir, "fedc_r200_niid.pth"))


--- Round 1 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 1: 0.0181
--- Round 2 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 2: 0.0334
--- Round 3 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 3: 0.0429
--- Round 4 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 4: 0.0592
--- Round 5 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 5: 0.0596
--- Round 6 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 6: 0.0661
--- Round 7 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 7: 0.0789
--- Round 8 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 8: 0.0877
--- Round 9 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 9: 0.0889
--- Round 10 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 10: 0.1066
--- Roun

In [10]:
fedcluster_nkd = FedCluster(
    model=model,
    loss_function=loss_function,
    device=device,
    num_clients=num_clients,
    client_fraction=client_fraction,
    local_epochs=local_epochs,
    batch_size=2,
    lr=1e-4,
    num_clusters=2,
    teacher_model=None,  # Pass the teacher model
    alpha=0.7,                    # Weight for distillation loss
    temperature=3.0               # Temperature for softening logits
)

fedcluster_nkd.set_clients(train_client_datasets, val_client_datasets)

num_rounds = 200     
fedcluster_nkd.run(num_rounds, post_pred, post_label, dice_metric)

torch.save({
    'round_losses': fedcluster_nkd.round_loss_values,
    'dice_scores': fedcluster_nkd.dice_metric_values,
    'model size': fedcluster_nkd.communication_cost,
    'communication time': fedcluster_nkd.communication_time,
    'parameter count': fedcluster_nkd.param_count,
    'global_model_state': fedcluster_nkd.global_model.state_dict()
}, 'fedc_nkd_r200_niid.pth')


--- Round 1 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 1: 0.0149
--- Round 2 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 2: 0.0322
--- Round 3 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 3: 0.0373
--- Round 4 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 4: 0.0416
--- Round 5 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 5: 0.0480
--- Round 6 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 6: 0.0486
--- Round 7 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 7: 0.0521
--- Round 8 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 8: 0.0420
--- Round 9 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 9: 0.0497
--- Round 10 ---
Processing Cluster 1
Processing Cluster 2
Average Dice Score after Round 10: 0.0548
--- Roun

In [11]:
data = []
for i in range(200):
    round_data = {
        'round': i+1,
        'round_losses': fedcluster.round_loss_values[i],
        'dice_scores': fedcluster.dice_metric_values[i] if i % 2 != 0 else '',  # 'NA' for even rounds
        'model_size': fedcluster.communication_cost if i == 0 else '',
        'communication_time': fedcluster.communication_time if i == 0 else '',
        'parameter_count': fedcluster.param_count if i == 0 else ''
    }
    data.append(round_data)

df = pd.DataFrame(data)
df.to_csv('fedc_kd_r200_niid_plot.csv', index=False)

print("CSV file saved as 'fedc_kd_r200_niid_plot.csv'")
df

CSV file saved as 'fedc_kd_r200_niid_plot.csv'


Unnamed: 0,round,round_losses,dice_scores,model_size,communication_time,parameter_count
0,1,0.605714,,478116800,14.470428,119524000
1,2,0.591351,0.033404,,,
2,3,0.581811,,,,
3,4,0.567760,0.059243,,,
4,5,0.555551,,,,
...,...,...,...,...,...,...
195,196,0.200636,0.923943,,,
196,197,0.167078,,,,
197,198,0.193291,0.927936,,,
198,199,0.136929,,,,


In [12]:
data = []
for i in range(200):
    round_data = {
        'round': i+1,
        'round_losses': fedcluster_nkd.round_loss_values[i],
        'dice_scores': fedcluster_nkd.dice_metric_values[i] if i % 2 != 0 else '',  # 'NA' for even rounds
        'model_size': fedcluster_nkd.communication_cost if i == 0 else '',
        'communication_time': fedcluster_nkd.communication_time if i == 0 else '',
        'parameter_count': fedcluster_nkd.param_count if i == 0 else ''
    }
    data.append(round_data)

df = pd.DataFrame(data)
df.to_csv('fedc_nkd_r200_niid_plot.csv', index=False)

print("CSV file saved as 'fedc_nkd_r200_niid_plot.csv'")

CSV file saved as 'fedc_nkd_r200_niid_plot.csv'


In [None]:
#IID KD
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")

student_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(8, 16, 32, 64),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

# FedAvg configuration
num_clients = 4  # Total number of clients
#num_rounds = 2  # Number of communication rounds
client_fraction = 0.5  # Fraction of clients selected per round
local_epochs = 2  # Number of epochs each client trains locally
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])


teacher_state_dict = torch.load(os.path.join(model_dir, "fedc_r200_iid.pth"))

teacher_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
teacher_model.load_state_dict(teacher_state_dict['global_model_state'])

fedcluster_kd_iid = FedCluster(
    model=student_model,
    loss_function=loss_function,
    device=device,
    num_clients=num_clients,
    client_fraction=client_fraction,
    local_epochs=local_epochs,
    batch_size=2,
    lr=1e-4,
    num_clusters=2,
    teacher_model=teacher_model,  # Pass the teacher model
    alpha=0.7,                    # Weight for distillation loss
    temperature=3.0               # Temperature for softening logits
)

fedcluster_kd_iid.set_clients(train_client_datasets, val_client_datasets)

num_rounds = 200
fedcluster_kd_iid.run(num_rounds, post_pred, post_label, dice_metric)

torch.save({
    'round_losses': fedcluster_kd_iid.round_loss_values,
    'dice_scores': fedcluster_kd_iid.dice_metric_values,
    'model size': fedcluster_kd_iid.communication_cost,
    'communication time': fedcluster_kd_iid.communication_time,
    'parameter count': fedcluster_kd_iid.param_count,
    'global_model_state': fedcluster_kd_iid.global_model.state_dict()
}, 'fedc_kd_r200_iid.pth')


In [None]:
#IID Non-KD
fedcluster_nkd_iid = FedCluster(
    model=model,
    loss_function=loss_function,
    device=device,
    num_clients=num_clients,
    client_fraction=client_fraction,
    local_epochs=local_epochs,
    batch_size=2,
    lr=1e-4,
    num_clusters=2,
    teacher_model=None,  # Pass the teacher model
    alpha=0.7,                    # Weight for distillation loss
    temperature=3.0               # Temperature for softening logits
)

fedcluster_nkd_iid.set_clients(train_client_datasets, val_client_datasets)

num_rounds = 200     
fedcluster_nkd_iid.run(num_rounds, post_pred, post_label, dice_metric)

torch.save({
    'round_losses': fedcluster_nkd_iid.round_loss_values,
    'dice_scores': fedcluster_nkd_iid.dice_metric_values,
    'model size': fedcluster_nkd_iid.communication_cost,
    'communication time': fedcluster_nkd_iid.communication_time,
    'parameter count': fedcluster_nkd_iid.param_count,
    'global_model_state': fedcluster_nkd_iid.global_model.state_dict()
}, 'fedc_nkd_r200_iid.pth')


In [11]:
data = []
for i in range(200):
    round_data = {
        'round': i+1,
        'round_losses': fedcluster_kd_iid.round_loss_values[i],
        'dice_scores': fedcluster_kd_iid.dice_metric_values[i] if i % 2 != 0 else '',  # 'NA' for even rounds
        'model_size': fedcluster_kd_iid.communication_cost if i == 0 else '',
        'communication_time': fedcluster_kd_iid.communication_time if i == 0 else '',
        'parameter_count': fedcluster_kd_iid.param_count if i == 0 else ''
    }
    data.append(round_data)

df = pd.DataFrame(data)
df.to_csv('fedc_kd_r200_iid_plot.csv', index=False)

print("CSV file saved as 'fedc_kd_r200_iid_plot.csv'")
df

CSV file saved as 'fedc_kd_r200_iid_plot.csv'


Unnamed: 0,round,round_losses,dice_scores,model_size,communication_time,parameter_count
0,1,0.605714,,478116800,12.765136,119524000
1,2,0.591350,0.033372,,,
2,3,0.581809,,,,
3,4,0.567759,0.059096,,,
4,5,0.555549,,,,
...,...,...,...,...,...,...
195,196,0.201577,0.926182,,,
196,197,0.166992,,,,
197,198,0.195321,0.928434,,,
198,199,0.137545,,,,


In [None]:
data = []
for i in range(200):
    round_data = {
        'round': i+1,
        'round_losses': fedcluster_nkd_iid.round_loss_values[i],
        'dice_scores': fedcluster_nkd_iid.dice_metric_values[i] if i % 2 != 0 else '',  # 'NA' for even rounds
        'model_size': fedcluster_nkd_iid.communication_cost if i == 0 else '',
        'communication_time': fedcluster_nkd_iid.communication_time if i == 0 else '',
        'parameter_count': fedcluster_nkd_iid.param_count if i == 0 else ''
    }
    data.append(round_data)

df = pd.DataFrame(data)
df.to_csv('fedc_nkd_r200_iid_plot.csv', index=False)

print("CSV file saved as 'fedc_nkd_r200_iid_plot.csv'")
df