In [1]:
import os
import copy
import numpy as np
import torch
import torch.nn.functional as F
from monai.data import decollate_batch
import time
import glob

%matplotlib inline
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import tarfile
import numpy as np
from torch.utils.data import Subset
import copy
set_determinism(seed=0)

#print_config()


# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
2024-12-22 12:40:23.361801: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Read and split data

In [13]:
def get_spleen_dataset(data_dir):
    train_images = sorted(glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
    print(train_images)
    train_labels = sorted(glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
    data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]
    train_files, val_files = data_dicts[:-9], data_dicts[-9:]
    
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            # user can also add other random transforms
            # RandAffined(
            #     keys=['image', 'label'],
            #     mode=('bilinear', 'nearest'),
            #     prob=1.0, spatial_size=(96, 96, 96),
            #     rotate_range=(0, 0, np.pi/15),
            #     scale_range=(0.1, 0.1, 0.1)),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        ]
    )
    
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
    
    return train_ds, val_ds

In [4]:
def split_dataset_non_iid(train_dataset, val_dataset, num_clients):
    if len(train_dataset) % num_clients != 0 or len(val_dataset) % num_clients != 0:
         #split the dataset into num_clients subsets randomly 
        train_indices = np.random.permutation(len(train_dataset))
        val_indices = np.random.permutation(len(val_dataset))
        train_client_splits = np.array_split(train_indices, num_clients)
        val_client_splits = np.array_split(val_indices, num_clients)
        
        #create a list of subsets
        train_client_datasets = [Subset(train_dataset, split) for split in train_client_splits]
        val_client_datasets = [Subset(val_dataset, split) for split in val_client_splits]
        
        return train_client_datasets, val_client_datasets
    
    else: 
        while len(train_dataset) % num_clients == 0:
            train_dataset = train_dataset[:len(train_dataset) - 1]
        return split_dataset_non_iid(train_dataset, val_dataset, num_clients)

In [5]:
#set the data path
root_dir = os.path.dirname(os.path.dirname(os.path.abspath("federated_learning")))
data_dir = os.path.join(root_dir, "data/raw/Task09_Spleen/")
model_dir = os.path.join(root_dir, "models_gpu")

In [6]:
#load the entire dataset
train_ds, val_ds = get_spleen_dataset(data_dir)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)

#split the dataset into num_clients IID subsets
num_clients = 4
train_client_datasets, val_client_datasets = split_dataset_non_iid(train_ds, val_ds, num_clients)

#check the number of samples in each client
for i, client_data in enumerate(val_client_datasets):
    print(f"Client {i+1} has {len(client_data)} samples.")      



['/Users/danyu/Desktop/SJTU/Distributed_ML_Systems/federated_learning/data/raw/Task09_Spleen/imagesTr/spleen_10.nii.gz', '/Users/danyu/Desktop/SJTU/Distributed_ML_Systems/federated_learning/data/raw/Task09_Spleen/imagesTr/spleen_12.nii.gz', '/Users/danyu/Desktop/SJTU/Distributed_ML_Systems/federated_learning/data/raw/Task09_Spleen/imagesTr/spleen_13.nii.gz', '/Users/danyu/Desktop/SJTU/Distributed_ML_Systems/federated_learning/data/raw/Task09_Spleen/imagesTr/spleen_14.nii.gz', '/Users/danyu/Desktop/SJTU/Distributed_ML_Systems/federated_learning/data/raw/Task09_Spleen/imagesTr/spleen_16.nii.gz', '/Users/danyu/Desktop/SJTU/Distributed_ML_Systems/federated_learning/data/raw/Task09_Spleen/imagesTr/spleen_17.nii.gz', '/Users/danyu/Desktop/SJTU/Distributed_ML_Systems/federated_learning/data/raw/Task09_Spleen/imagesTr/spleen_18.nii.gz', '/Users/danyu/Desktop/SJTU/Distributed_ML_Systems/federated_learning/data/raw/Task09_Spleen/imagesTr/spleen_19.nii.gz', '/Users/danyu/Desktop/SJTU/Distributed_

Loading dataset: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]
Loading dataset: 100%|██████████| 9/9 [00:07<00:00,  1.19it/s]

Client 1 has 3 samples.
Client 2 has 2 samples.
Client 3 has 2 samples.
Client 4 has 2 samples.





# FedMixCluster

In [None]:
class FedMixCluster:
    def __init__(self, model, loss_function, device, num_clients, client_fraction, local_epochs, batch_size, lr, teacher_model=None, alpha=0.5, temperature=3.0, lambda_=10, beta_=1.5, threshold=0.9, num_clusters=2):
        self.global_model = model  # The student model
        self.loss_function = loss_function
        self.device = device
        self.num_clients = num_clients
        self.client_fraction = client_fraction
        self.local_epochs = local_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.num_clusters = num_clusters
        self.clients = []
        self.clusters = []
        self.teacher_model = teacher_model.to(device) if teacher_model else None
        self.alpha = alpha
        self.temperature = temperature
        
        self.weights = []
        self.round_loss_values = []  
        self.dice_metric_values = [] 

        self.lambda_ = lambda_
        self.beta_ = beta_
        self.threshold = threshold
        self.client_supervision_types = []  # Types of client supervision: fully_labeled, weakly_labeled, unlabeled
        
    def set_clients(self, train_datasets, val_datasets):
        # Assign datasets to clients
        self.clients = [{"train": train, "val": val} for train, val in zip(train_datasets, val_datasets)]
        self.create_clusters()
        
        # Track communication costs
        self.communication_cost = 0  # For KD-based lightweight models
        self.communication_time = 0  # Total time spent in communication (in seconds)
        self.param_count = 0  # Total number of parameters exchanged

    def set_supervision_types(self, supervision_types):
        # Assign supervision type to each client
        if len(supervision_types) != self.num_clients:
            raise ValueError("Length of supervision_types must match the number of clients.")
        self.client_supervision_types = supervision_types

    def create_clusters(self):
        # Randomly group clients into clusters
        client_ids = np.arange(self.num_clients)
        np.random.shuffle(client_ids)
        cluster_size = self.num_clients // self.num_clusters
        self.clusters = [client_ids[i:i + cluster_size] for i in range(0, len(client_ids), cluster_size)]

    def calculate_model_size(self, model_state_dict):
        # Calculate the size of the model state in bytes
        return sum(param.numel() * param.element_size() for param in model_state_dict.values())

    def calculate_parameter_count(self, model_state_dict):
        """
        Count the number of parameters in the model state.
        """
        return sum(param.numel() for param in model_state_dict.values())
    
    def distillation_loss(self, student_logits, teacher_logits):
        """
        Compute the distillation loss using KL divergence between softened logits.
        Args:
            student_logits: Logits from the student model.
            teacher_logits: Logits from the teacher model.
        """
        student_soft = F.log_softmax(student_logits / self.temperature, dim=1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=1)
        return F.kl_div(student_soft, teacher_soft, reduction="batchmean") * (self.temperature ** 2)

    def pseudo_label_generation(self, model, data_loader, supervision_type):
        # Generate and refine pseudo-labels
        pseudo_data = []
        model.eval()
        with torch.no_grad():
            for batch_data in data_loader:
                inputs = batch_data["image"].to(self.device)  # Shape: [batch_size, 1, H, W, D]
                outputs = model(inputs)  # Shape: [batch_size, num_classes, H, W, D]

                if supervision_type == "unlabeled":
                    # Generate pseudo-labels for unlabeled data
                    probs = torch.softmax(outputs, dim=1)  # Shape: [batch_size, num_classes, H, W, D]
                    confidence, pseudo_labels = torch.max(probs, dim=1)  # Shape: [batch_size, H, W, D]
                    mask = confidence > self.threshold  # Shape: [batch_size, H, W, D]

                    if mask.sum() == 0:
                        # Skip this batch if no confident predictions
                        continue

                    # Align mask dimensions with inputs
                    mask = mask.unsqueeze(1)  # Add the channel dimension: [batch_size, 1, H, W, D]

                    # Filter inputs and pseudo-labels using the mask
                    masked_inputs = inputs[mask.expand_as(inputs)].contiguous()
                    masked_labels = pseudo_labels[mask.squeeze(1)].contiguous()

                    if masked_inputs.numel() > 0 and masked_labels.numel() > 0:
                        # Reshape to ensure valid tensor dimensions
                        masked_inputs = masked_inputs.view(-1, *inputs.shape[1:])
                        masked_labels = masked_labels.view(-1, *pseudo_labels.shape[1:])
                        pseudo_data.append((masked_inputs, masked_labels))

                elif supervision_type == "weakly_labeled":
                    # Add logic to refine pseudo-labels with bounding boxes if needed
                    pass  # Implement bounding box refinement logic if required

                else:
                    # Fully labeled data, no pseudo-labeling required
                    pseudo_data.append((inputs, batch_data["label"].to(self.device)))

        return pseudo_data


    def client_update(self, client_data, supervision_type):
        local_model = copy.deepcopy(self.global_model)
        local_model.train()
        local_model.to(self.device)
        local_train_start = time.time()

        if self.teacher_model:
            self.teacher_model.eval()  # Ensure teacher model is frozen during training

        train_loader = DataLoader(client_data["train"], batch_size=self.batch_size, shuffle=True, num_workers=4)
        optimizer = torch.optim.Adam(local_model.parameters(), lr=self.lr)
        
        total_loss = 0
        num_batches = 0
        
        # Debug: Verify train_loader output
        #for batch_data in train_loader:
        #    print("Raw batch_data in client_update:", batch_data)  # Debugging
        #    break

        for _ in range(self.local_epochs):
            for batch_data in train_loader:
                # Debug: Ensure batch_data has correct keys
                #print("Before processing batch_data:", batch_data)

                inputs = batch_data["image"].to(self.device)
                labels = batch_data["label"].to(self.device)

                optimizer.zero_grad()

                if supervision_type == "fully_labeled":
                    outputs = local_model(inputs)
                    dice_loss = self.loss_function(outputs, labels)

                    if self.teacher_model:
                        with torch.no_grad():
                            teacher_outputs = self.teacher_model(inputs)

                        distill_loss = self.distillation_loss(outputs, teacher_outputs)
                        distill_loss = torch.sigmoid(distill_loss)
                        # Combine Dice loss and distillation loss
                        task_specific_loss = dice_loss.item()
                        loss = self.alpha * distill_loss + (1 - self.alpha) * dice_loss
                    else:
                        loss = dice_loss

                elif supervision_type in ["weakly_labeled", "unlabeled"]:
                    pseudo_data = self.pseudo_label_generation(local_model, train_loader, supervision_type)
                    if len(pseudo_data) > 0:
                        pseudo_inputs, pseudo_labels = zip(*pseudo_data)
                        pseudo_inputs = torch.cat(pseudo_inputs, dim=0)
                        pseudo_labels = torch.cat(pseudo_labels, dim=0)

                        outputs = local_model(pseudo_inputs)
                        dice_loss = self.loss_function(outputs, pseudo_labels)

                        if self.teacher_model:
                            with torch.no_grad():
                                teacher_outputs = self.teacher_model(pseudo_inputs)

                            distill_loss = self.distillation_loss(outputs, teacher_outputs)
                            distill_loss = torch.sigmoid(distill_loss)
                            # Combine Dice loss and distillation loss
                            task_specific_loss = dice_loss.item()
                            loss = self.alpha * distill_loss + (1 - self.alpha) * dice_loss
                        else:
                            loss = dice_loss
                    else:
                        continue  # Skip this batch if no valid pseudo-labels are generated

                else:
                    raise ValueError(f"Unsupported supervision type: {supervision_type}")

                loss.backward()
                optimizer.step()

                if self.teacher_model:
                    total_loss += task_specific_loss
                else:
                    total_loss += loss.item()
                num_batches += 1

        local_train_end = time.time()
        
        loss_per_client = total_loss / num_batches if num_batches > 0 else 0

        # Calculate model size for communication cost
        model_size = self.calculate_model_size(local_model.state_dict())

        # Calculate number of parameters passed for communication cost
        num_params = self.calculate_parameter_count(local_model.state_dict())

        local_train_time = local_train_end - local_train_start

        return local_model.state_dict(), loss_per_client, model_size, num_params, local_train_time

    def dynamic_weighting(self, client_losses, client_sizes):
        """
        Calculate dynamic weights for each client based on its dataset size and task-specific loss.
        Args:
            client_losses: List of average losses for each client.
            client_sizes: List of the number of samples in each client.
        """
        # Score based on loss (penalize higher losses)
        scores = [loss ** self.beta_ for loss in client_losses]
        denominator = sum(scores)
        
        if denominator == 0:
            # If all losses are zero, assign equal weights
            weighted_scores = [1 / len(client_losses) for _ in client_losses]
        else:
            weighted_scores = [score / denominator for score in scores]

        # Weight based on data size
        size_weights = [size / sum(client_sizes) for size in client_sizes]

        # Combine size and loss-based weights
        self.weights = [w + self.lambda_ * s for w, s in zip(size_weights, weighted_scores)]

        # Normalize weights
        total_weight = sum(self.weights)
        if total_weight == 0:
            # Safeguard against division by zero
            self.weights = [1 / len(client_sizes) for _ in client_sizes]
        else:
            self.weights = [w / total_weight for w in self.weights]

    def aggregate_updates(self, client_updates, client_sizes, client_losses):
        """
        Aggregate client updates to compute global weights, ensuring all parameters are tensors.
        """
        self.dynamic_weighting(client_losses, client_sizes)
        global_weights = copy.deepcopy(self.global_model.state_dict())

        for key in global_weights.keys():
            # Ensure all values being aggregated are tensors
            aggregated_tensors = torch.stack([
                client_updates[i][key] * self.weights[i] for i in range(len(client_updates))
            ], dim=0).sum(dim=0)

            # Assign aggregated tensor to global weights
            global_weights[key] = aggregated_tensors

        return global_weights

    def cluster_aggregate(self, cluster_updates, cluster_sizes):
        """
        Aggregate updates across clusters to compute global weights, ensuring all parameters are tensors.
        """
        global_weights = copy.deepcopy(self.global_model.state_dict())
        total_size = sum(cluster_sizes)

        for key in global_weights.keys():
            # Ensure all values being aggregated are tensors
            aggregated_tensors = torch.stack([
                cluster_updates[i][key] * (cluster_sizes[i] / total_size) for i in range(len(cluster_updates))
            ], dim=0).sum(dim=0)

            # Assign aggregated tensor to global weights
            global_weights[key] = aggregated_tensors

        return global_weights


    def run(self, num_rounds, post_pred, post_label, dice_metric):
        # Run the FedMixCluster algorithm with clustering
        global_weights = self.global_model.state_dict()
        
        self.communication_time = 0  # Reset communication time
        self.param_count = 0  # Reset parameter count
        
        for round_num in range(num_rounds):
            print(f"--- Round {round_num + 1} ---")
            cluster_updates = []
            cluster_sizes = []

            for cluster_idx, cluster in enumerate(self.clusters):
                print(f"Processing Cluster {cluster_idx + 1}")
                client_updates = []
                client_sizes = []
                client_losses = []
                total_loss = 0
                
                round_start_time = time.time()

                # Randomly select clients within the cluster
                num_selected_clients = max(int(self.client_fraction * len(cluster)), 1)
                selected_clients = np.random.choice(cluster, num_selected_clients, replace=False)

                for client_id in selected_clients:
                    client_data = self.clients[client_id]
                    supervision_type = self.client_supervision_types[client_id]
                    updated_weights, loss_per_client, model_size, num_params, local_train_time = self.client_update(client_data, supervision_type)
                    client_updates.append(updated_weights)
                    client_sizes.append(len(client_data["train"]))
                    client_losses.append(loss_per_client)
                    total_loss += loss_per_client
                    
                    # Track communication cost
                    self.communication_time -= local_train_time
                    # print(local_train_time, self.communication_time)
                    self.communication_cost += model_size
                    self.param_count += num_params
                    
                round_end_time = time.time()
                self.communication_time += (round_end_time - round_start_time)
                    
                # Aggregate within the cluster
                cluster_weights = self.aggregate_updates(client_updates, client_sizes, client_losses)
                cluster_updates.append(cluster_weights)
                cluster_sizes.append(sum(client_sizes))

            global_weights = self.cluster_aggregate(cluster_updates, cluster_sizes)
            self.global_model.load_state_dict(global_weights)

            # Evaluate the global model
            avg_dice_score = self.evaluate_global_model([client["val"] for client in self.clients], post_pred, post_label, dice_metric)
            print(f"Average Dice Score after Round {round_num + 1}: {avg_dice_score:.4f}")
            self.dice_metric_values.append(avg_dice_score)
            avg_loss_per_round = total_loss / len(selected_clients) if len(selected_clients) > 0 else 0
            self.round_loss_values.append(avg_loss_per_round)
            
        # Print communication costs
        print(f"Total communication cost: {self.communication_cost / 1e6:.2f} MB")
        print(f"Total communication time: {self.communication_time:.2f} seconds")
        print(f"Total parameters exchanged: {self.param_count}")
            

    def evaluate_global_model(self, val_datasets, post_pred, post_label, dice_metric):
        # Evaluate the global model
        self.global_model.eval()
        dice_scores = []
        for val_data in val_datasets:
            val_loader = DataLoader(val_data, batch_size=1, num_workers=4)
            with torch.no_grad():
                for batch_data in val_loader:
                    inputs = batch_data["image"].to(self.device)
                    labels = batch_data["label"].to(self.device)
                    outputs = sliding_window_inference(inputs, (160, 160, 160), 4, self.global_model)
                    outputs = [post_pred(i) for i in decollate_batch(outputs)]
                    labels = [post_label(i) for i in decollate_batch(labels)]
                    dice_metric(y_pred=outputs, y=labels)
                dice_score = dice_metric.aggregate().item()
                dice_scores.append(dice_score)
                dice_metric.reset()
        return np.mean(dice_scores)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")

student_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(8, 16, 32, 64),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

teacher_state_dict = torch.load(os.path.join(model_dir, "fedc_r200_niid.pth"))

teacher_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
teacher_model.load_state_dict(teacher_state_dict['global_model_state'])

num_clients = 4  # Total number of clients
client_fraction = 0.5  # Fraction of clients selected per round
local_epochs = 2  # Number of epochs each client trains locally
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

fedmix = FedMixCluster(
    model=student_model,
    loss_function=loss_function,
    device=device,
    num_clients=num_clients,
    client_fraction=client_fraction,
    local_epochs=local_epochs,
    batch_size=2,
    lr=1e-4,
    num_clusters=2,
    teacher_model=teacher_model,  # Pass the teacher model
    alpha=0.7,                    # Weight for distillation loss
    temperature=3.0               # Temperature for softening logits
)

# Set clients and supervision types
fedmix.set_clients(train_client_datasets, val_client_datasets)
supervision_types = ["fully_labeled", "weakly_labeled", "unlabeled", "fully_labeled"]
fedmix.set_supervision_types(supervision_types)

# Run FedMixCluster
num_rounds = 5
fedmix.run(num_rounds, post_pred, post_label, dice_metric)


--- Round 1 ---
Processing Cluster 1


  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site

Processing Cluster 2


  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site

Average Dice Score after Round 1: 0.0157
--- Round 2 ---
Processing Cluster 1


  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site

Processing Cluster 2


  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site

Average Dice Score after Round 2: 0.0181
--- Round 3 ---
Processing Cluster 1


  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site

Processing Cluster 2


  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site

Average Dice Score after Round 3: 0.0190
--- Round 4 ---
Processing Cluster 1


  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site

Processing Cluster 2


  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site

Average Dice Score after Round 4: 0.0200
--- Round 5 ---
Processing Cluster 1


  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site

Processing Cluster 2


  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site

Average Dice Score after Round 5: 0.0213
Total communication cost: 11.95 MB
Total communication time: 0.70 seconds
Total parameters exchanged: 2988100


In [27]:
torch.save({
    'round_losses': fedmix.round_loss_values,
    'dice_scores': fedmix.dice_metric_values,
    'model size': fedmix.communication_cost,
    'communication time': fedmix.communication_time,
    'parameter count': fedmix.param_count,
    'global_model_state': fedmix.global_model.state_dict()
}, 'fedm_kd_r5_niid.pth')

# FedMix (without FedCluster and knowledge distillation)

In [33]:
class FedMixv:
    def __init__(self, model, loss_function, device, num_clients, client_fraction, local_epochs, batch_size, lr, lambda_=10, beta_=1.5, threshold=0.9):
        self.global_model = model  # The global model
        self.loss_function = loss_function
        self.device = device
        self.num_clients = num_clients
        self.client_fraction = client_fraction
        self.local_epochs = local_epochs
        self.batch_size = batch_size
        self.lr = lr

        self.weights = []  # Dynamic weights for aggregation
        self.round_loss_values = []  # Per round loss tracking
        self.dice_metric_values = []  # Per round Dice score tracking
        self.communication_cost = 0  # Total communication cost
        self.communication_time = 0  # Total time spent in communication (seconds)
        self.param_count = 0  # Total number of parameters exchanged

        self.lambda_ = lambda_  # Weighting parameter for data size
        self.beta_ = beta_  # Weighting parameter for task-specific loss
        self.threshold = threshold  # Confidence threshold for pseudo-labeling
        self.client_supervision_types = []  # Types of client supervision: fully_labeled, weakly_labeled, unlabeled
        self.clients = []

    def set_clients(self, train_datasets, val_datasets):
        # Assign datasets to clients
        self.clients = [{"train": train, "val": val} for train, val in zip(train_datasets, val_datasets)]

    def set_supervision_types(self, supervision_types):
        # Assign supervision type to each client
        if len(supervision_types) != self.num_clients:
            raise ValueError("Length of supervision_types must match the number of clients.")
        self.client_supervision_types = supervision_types

    def calculate_model_size(self, model_state_dict):
        # Calculate the size of the model state in bytes
        return sum(param.numel() * param.element_size() for param in model_state_dict.values())

    def calculate_parameter_count(self, model_state_dict):
        # Count the number of parameters in the model state
        return sum(param.numel() for param in model_state_dict.values())

    def pseudo_label_generation(self, model, data_loader, supervision_type):
        # Generate and refine pseudo-labels
        pseudo_data = []
        model.eval()
        with torch.no_grad():
            for batch_data in data_loader:
                inputs = batch_data["image"].to(self.device)  # Shape: [batch_size, 1, H, W, D]
                outputs = model(inputs)  # Shape: [batch_size, num_classes, H, W, D]

                if supervision_type == "unlabeled":
                    # Generate pseudo-labels for unlabeled data
                    probs = torch.softmax(outputs, dim=1)  # Shape: [batch_size, num_classes, H, W, D]
                    confidence, pseudo_labels = torch.max(probs, dim=1)  # Shape: [batch_size, H, W, D]
                    mask = confidence > self.threshold  # Shape: [batch_size, H, W, D]

                    # Align mask dimensions with inputs
                    mask = mask.unsqueeze(1)  # Add the channel dimension: [batch_size, 1, H, W, D]

                    # Ensure the mask is valid
                    if mask.sum() > 0:  # Check if the mask is not empty
                        # Gather valid inputs and labels
                        masked_inputs = inputs[mask.expand_as(inputs)]
                        masked_labels = pseudo_labels[mask.squeeze(1)]

                        # Reshape valid tensors to expected shape
                        masked_inputs = masked_inputs.view(-1, *inputs.shape[1:])
                        masked_labels = masked_labels.view(-1, *pseudo_labels.shape[1:])

                        pseudo_data.append((masked_inputs, masked_labels))

                elif supervision_type == "weakly_labeled":
                    # Add logic to refine pseudo-labels with bounding boxes if needed
                    pass  # Implement bounding box refinement logic if required

                else:
                    # Fully labeled data, no pseudo-labeling required
                    pseudo_data.append((inputs, batch_data["label"].to(self.device)))

        return pseudo_data


    def client_update(self, client_data, supervision_type):
        local_model = copy.deepcopy(self.global_model)
        local_model.train()
        local_model.to(self.device)
        local_train_start = time.time()

        train_loader = DataLoader(client_data["train"], batch_size=self.batch_size, shuffle=True, num_workers=4)
        optimizer = torch.optim.Adam(local_model.parameters(), lr=self.lr)
        
        total_loss = 0
        num_batches = 0

        for _ in range(self.local_epochs):
            for batch_data in train_loader:
                inputs = batch_data["image"].to(self.device)
                labels = batch_data["label"].to(self.device)

                optimizer.zero_grad()

                if supervision_type == "fully_labeled":
                    outputs = local_model(inputs)
                    loss = self.loss_function(outputs, labels)

                elif supervision_type in ["weakly_labeled", "unlabeled"]:
                    pseudo_data = self.pseudo_label_generation(local_model, train_loader, supervision_type)
                    if len(pseudo_data) > 0:
                        pseudo_inputs, pseudo_labels = zip(*pseudo_data)
                        pseudo_inputs = torch.cat(pseudo_inputs, dim=0)
                        pseudo_labels = torch.cat(pseudo_labels, dim=0)

                        outputs = local_model(pseudo_inputs)
                        loss = self.loss_function(outputs, pseudo_labels)
                    else:
                        continue  # Skip this batch if no valid pseudo-labels are generated

                else:
                    raise ValueError(f"Unsupported supervision type: {supervision_type}")

                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

        local_train_end = time.time()
        
        loss_per_client = total_loss / num_batches if num_batches > 0 else 0

        # Calculate model size for communication cost
        model_size = self.calculate_model_size(local_model.state_dict())

        # Calculate number of parameters passed for communication cost
        num_params = self.calculate_parameter_count(local_model.state_dict())

        local_train_time = local_train_end - local_train_start

        return local_model.state_dict(), loss_per_client, model_size, num_params, local_train_time

    def dynamic_weighting(self, client_losses, client_sizes):
        # Calculate dynamic weights for each client based on its dataset size and task-specific loss
        scores = [loss ** self.beta_ for loss in client_losses]
        denominator = sum(scores)
        
        if denominator == 0:
            weighted_scores = [1 / len(client_losses) for _ in client_losses]
        else:
            weighted_scores = [score / denominator for score in scores]

        size_weights = [size / sum(client_sizes) for size in client_sizes]
        self.weights = [w + self.lambda_ * s for w, s in zip(size_weights, weighted_scores)]

        total_weight = sum(self.weights)
        if total_weight == 0:
            self.weights = [1 / len(client_sizes) for _ in client_sizes]
        else:
            self.weights = [w / total_weight for w in self.weights]

    def aggregate_updates(self, client_updates, client_sizes, client_losses):
        # Aggregate client updates to compute global weights
        self.dynamic_weighting(client_losses, client_sizes)
        global_weights = copy.deepcopy(self.global_model.state_dict())

        for key in global_weights.keys():
            aggregated_tensors = torch.stack([
                client_updates[i][key] * self.weights[i] for i in range(len(client_updates))
            ], dim=0).sum(dim=0)

            global_weights[key] = aggregated_tensors

        return global_weights

    def run(self, num_rounds, post_pred, post_label, dice_metric):
        global_weights = self.global_model.state_dict()

        for round_num in range(num_rounds):
            print(f"--- Round {round_num + 1} ---")
            client_updates = []
            client_sizes = []
            client_losses = []
            total_loss = 0
            round_start_time = time.time()

            # Randomly select clients for this round
            num_selected_clients = max(int(self.client_fraction * self.num_clients), 1)
            selected_clients = np.random.choice(range(self.num_clients), num_selected_clients, replace=False)

            for client_id in selected_clients:
                client_data = self.clients[client_id]
                supervision_type = self.client_supervision_types[client_id]
                updated_weights, loss_per_client, model_size, num_params, local_train_time = self.client_update(client_data, supervision_type)
                client_updates.append(updated_weights)
                client_sizes.append(len(client_data["train"]))
                client_losses.append(loss_per_client)
                total_loss += loss_per_client

                # Track communication cost
                self.communication_time -= local_train_time
                self.communication_cost += model_size
                self.param_count += num_params

            round_end_time = time.time()
            self.communication_time += (round_end_time - round_start_time)

            # Aggregate updates to compute new global model
            global_weights = self.aggregate_updates(client_updates, client_sizes, client_losses)
            self.global_model.load_state_dict(global_weights)

            # Evaluate the global model
            avg_dice_score = self.evaluate_global_model([client["val"] for client in self.clients], post_pred, post_label, dice_metric)
            print(f"Average Dice Score after Round {round_num + 1}: {avg_dice_score:.4f}")
            self.dice_metric_values.append(avg_dice_score)
            avg_loss_per_round = total_loss / len(selected_clients) if len(selected_clients) > 0 else 0
            self.round_loss_values.append(avg_loss_per_round)

        print(f"Total communication cost: {self.communication_cost / 1e6:.2f} MB")
        print(f"Total communication time: {self.communication_time:.2f} seconds")
        print(f"Total parameters exchanged: {self.param_count}")

    def evaluate_global_model(self, val_datasets, post_pred, post_label, dice_metric):
        self.global_model.eval()
        dice_scores = []
        for val_data in val_datasets:
            val_loader = DataLoader(val_data, batch_size=1, num_workers=4)
            with torch.no_grad():
                for batch_data in val_loader:
                    inputs = batch_data["image"].to(self.device)
                    labels = batch_data["label"].to(self.device)
                    outputs = sliding_window_inference(inputs, (160, 160, 160), 4, self.global_model)
                    outputs = [post_pred(i) for i in decollate_batch(outputs)]
                    labels = [post_label(i) for i in decollate_batch(labels)]
                    dice_metric(y_pred=outputs, y=labels)
                dice_score = dice_metric.aggregate().item()
                dice_scores.append(dice_score)
                dice_metric.reset()
        return np.mean(dice_scores)


In [34]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")

student_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(8, 16, 32, 64),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

teacher_state_dict = torch.load(os.path.join(model_dir, "fedc_r200_niid.pth"))

teacher_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
teacher_model.load_state_dict(teacher_state_dict['global_model_state'])

# FedAvg configuration
num_clients = 4  # Total number of clients
#num_rounds = 2  # Number of communication rounds
client_fraction = 0.5  # Fraction of clients selected per round
local_epochs = 2  # Number of epochs each client trains locally
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])


fedmixv = FedMixv(
    model=student_model,
    loss_function=loss_function,
    device=device,
    num_clients=num_clients,
    client_fraction=client_fraction,
    local_epochs=local_epochs,
    batch_size=2,
    lr=1e-4,
    lambda_=10,
    beta_=1.5,
    threshold=0.9
)

# Set clients and supervision types
fedmixv.set_clients(train_client_datasets, val_client_datasets)
supervision_types = ["fully_labeled", "weakly_labeled", "unlabeled", "fully_labeled"]
fedmixv.set_supervision_types(supervision_types)

# Run FedMixCluster
num_rounds = 5
fedmixv.run(num_rounds, post_pred, post_label, dice_metric)


--- Round 1 ---


  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site

Average Dice Score after Round 1: 0.0313
--- Round 2 ---


  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site

Average Dice Score after Round 2: 0.0477
--- Round 3 ---


  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site

Average Dice Score after Round 3: 0.0591
--- Round 4 ---


  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site

Average Dice Score after Round 4: 0.0609
--- Round 5 ---


  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so'
  warn(
  Referenced from: '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site

Average Dice Score after Round 5: 0.0670
Total communication cost: 11.95 MB
Total communication time: 0.77 seconds
Total parameters exchanged: 2988100


In [35]:
torch.save({
    'round_losses': fedmixv.round_loss_values,
    'dice_scores': fedmixv.dice_metric_values,
    'model size': fedmixv.communication_cost,
    'communication time': fedmixv.communication_time,
    'parameter count': fedmixv.param_count,
    'global_model_state': fedmixv.global_model.state_dict()
}, 'fedm_r5_niid.pth')