### Velocity demo
The current approach shares the full optimizer state between rounds. The server aggregates the optimizer states it receives from the clients and sends the aggregate back, so that all clients begin the next round with the same aggregated optimizer state.
Limitations: 
- We are not measuring the communication cost of sharing the optimizer state, but it should be easy to add if desired.
- Currently the entire optimizer state is shared, although when training only a single layer we would only need to share the state related to that layer. To share only the updated layer's state, modify `_aggregate_optimizer_states`.

In [None]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union, Callable


import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import copy
import sys
import base64
import pickle
import copy
import torch.nn.functional as F
import torchvision.transforms as transforms
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader
from flwr.server.strategy import Strategy
import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Metrics, Context, Status, GetParametersRes, Parameters, GetParametersIns, MetricsAggregationFn,NDArrays,Scalar
from flwr.server import ServerApp, ServerConfig, ServerAppComponents 
from flwr.server.strategy import FedAvg, FedProx
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg

# Use the GPU when PyTorch reports one as available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")
# Turn off the `datasets` progress bars (imported from datasets.utils.logging).
disable_progress_bar()

In [None]:

BATCH_SIZE = 32


def load_datasets(partition_id: int, num_partitions: int, batch_size: int = BATCH_SIZE):
    """Build the CIFAR-10 data loaders for one federated partition.

    Args:
        partition_id: Index of the training partition to load.
        num_partitions: Number of partitions the training split is divided into.
        batch_size: Batch size used by all three loaders. Defaults to
            ``BATCH_SIZE`` (previously the value 32 was hard-coded in each
            ``DataLoader`` call and ``BATCH_SIZE`` was unused).

    Returns:
        ``(trainloader, valloader, testloader)``. The first two come from an
        80/20 split of the selected partition; ``testloader`` iterates the
        dataset's ``"test"`` split and does not depend on ``partition_id``.
    """
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test (fixed seed keeps the split reproducible)
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        """Convert every image of the batch to a normalized tensor."""
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(
        partition_train_test["train"], batch_size=batch_size, shuffle=True
    )
    valloader = DataLoader(partition_train_test["test"], batch_size=batch_size)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=batch_size)
    return trainloader, valloader, testloader

In [None]:
class Net(nn.Module):
    """Small convolutional classifier: two conv/pool stages and three linear layers.

    The first linear layer expects ``16 * 5 * 5`` features, which is what the
    two conv/pool stages produce for 3-channel 32x32 inputs. The output has
    10 values per sample.
    """

    def __init__(self) -> None:
        super().__init__()
        # Layers are created in the same order as before so that the
        # state_dict key order is unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the 10 output scores for each sample in ``x``."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        flat = features.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)

def get_parameters(net) -> List[np.ndarray]:
    """Return every ``state_dict`` entry of ``net`` as a NumPy array, in ``state_dict`` order."""
    return [tensor.cpu().numpy() for tensor in net.state_dict().values()]


def set_parameters(net, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``net``, matching them to ``state_dict`` keys by position.

    Args:
        net: Module whose ``state_dict`` is overwritten.
        parameters: One array per ``state_dict`` entry, in ``state_dict`` order.

    Raises:
        RuntimeError: Raised by ``load_state_dict(strict=True)`` when keys are
            missing or shapes do not match.
    """
    # torch.as_tensor keeps each array's dtype and shape, whereas the legacy
    # torch.Tensor(...) constructor coerces everything to the default float type.
    state_dict = OrderedDict(
        (key, torch.as_tensor(array))
        for key, array in zip(net.state_dict().keys(), parameters)
    )
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int, optimizer=None):
    """Train the network on the training set.

    Args:
        net: Model to train; it is put into training mode.
        trainloader: Iterable of batches, each a mapping with ``"img"`` and
            ``"label"`` entries. ``len(trainloader.dataset)`` is used to
            normalize the reported loss.
        epochs: Number of passes over ``trainloader``.
        optimizer: Optimizer to step. When ``None`` a fresh
            ``torch.optim.Adam`` over ``net.parameters()`` is created.

    Per-epoch loss and accuracy are printed; nothing is returned.
    """
    print("training network...")
    criterion = torch.nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            # Reuse the outputs computed above instead of running a second
            # forward pass just for the loss.
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics: .item() detaches the value, so the running sum is a
            # plain float and does not keep each batch's autograd graph alive.
            epoch_loss += loss.item()
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate ``net`` on every batch of ``testloader``.

    Returns:
        ``(loss, accuracy)``: the summed per-batch cross-entropy divided by
        ``len(testloader.dataset)``, and the fraction of correct predictions.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images = batch["img"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            logits = net(images)
            running_loss += loss_fn(logits, labels).item()
            # Index of the largest output per sample is the predicted class.
            predicted = torch.max(logits.data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
    return running_loss / len(testloader.dataset), n_correct / n_seen

def freeze_layers(model: torch.nn.Module, trainable_layers: int) -> None:
    """Leave only one parameter of ``model`` trainable, or all of them.

    ``trainable_layers`` is the position of the parameter (in
    ``model.parameters()`` order) that keeps ``requires_grad=True``; every
    other parameter is frozen. The value ``-1`` makes all parameters trainable.
    """
    train_everything = trainable_layers == -1
    for position, parameter in enumerate(model.parameters()):
        parameter.requires_grad = train_everything or position == trainable_layers

def get_parameters_size(params: Parameters) -> int:
    """Estimate the in-memory size in bytes of a ``Parameters`` object.

    Adds up ``sys.getsizeof`` for the object itself, its ``tensor_type``
    string, the ``tensors`` list, and every element of that list.
    ``sys.getsizeof`` is shallow, which is why the elements are counted
    separately.
    """
    container_sizes = (
        sys.getsizeof(params),  # the Parameters instance
        sys.getsizeof(params.tensor_type),  # the type string
        sys.getsizeof(params.tensors),  # the list container
    )
    payload_size = sum(sys.getsizeof(blob) for blob in params.tensors)
    return sum(container_sizes) + payload_size


In [None]:

# Number of entries in the model's state_dict.
NETWORK_LEN = len(Net().state_dict())
# Passed as `epochs` to train() by the client on every fit call.
EPOCHS = 5
NUM_PARTITIONS = 3
NUM_OF_CYCLES = 1
NUM_OF_FULL_UPDATES_BETWEEN_CYCLES = 5
# Each cycle consists of the full-model rounds plus two rounds per state_dict entry.
NUM_OF_ROUNDS = NUM_OF_CYCLES * (NUM_OF_FULL_UPDATES_BETWEEN_CYCLES + 2 * NETWORK_LEN)
print(f"Number of rounds: {NUM_OF_ROUNDS}")
backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}


In [None]:
from flwr.common import NDArrays, Scalar

def get_evaluate_fn(
    testloader: DataLoader,
    net: torch.nn.Module,
) -> Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]:
    """Build the server-side evaluation callback bound to ``testloader`` and ``net``."""

    def evaluate(
        server_round: int, parameters: NDArrays, config: Dict[str, Scalar]
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Load ``parameters`` into a copy of ``net`` and score it on ``testloader``.

        Returns the loss and a metrics dict with a single ``"accuracy"`` entry.
        """
        # Work on a deep copy so the `net` captured by the closure is never modified.
        candidate = copy.deepcopy(net)

        # Pair the received arrays with the state_dict keys by position.
        weights = OrderedDict(
            (name, torch.tensor(array))
            for name, array in zip(candidate.state_dict().keys(), parameters)
        )
        candidate.load_state_dict(weights, strict=True)

        candidate.to(DEVICE)
        candidate.eval()

        loss, accuracy = test(candidate, testloader)
        return loss, {"accuracy": accuracy}

    return evaluate

# FedAvg with velocity from global model

In [None]:
from typing import Union

from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg

# Per-round bookkeeping filled in by FedPartAvgWithVelocity, keyed by server round:
# "total_size" is written by aggregate_fit and "total_loss" by aggregate_evaluate.
fed_part_avg_result = {}
# Server-side evaluation results keyed by server round ("global_loss", "global_metrics"),
# written by FedPartAvgWithVelocity.evaluate.
fed_part_avg_model_results = {}

class FedPartAvgWithVelocity(Strategy):
    """Layer-wise federated averaging that also shares the clients' Adam state.

    Each round trains either the whole model (sequence entry ``-1``) or a
    single ``state_dict`` entry, following the order produced by
    ``generate_layer_training_sequence``. Clients return their serialized
    optimizer state in the fit metrics; the server averages those states and
    sends the result back in the next round's fit config under the key
    ``"optimizer_state"``.

    The constructor follows the keyword arguments of Flower's ``FedAvg``.
    ``on_fit_config_fn``, ``on_evaluate_config_fn``, both metrics aggregation
    functions, ``inplace`` and ``layer_update_strategy`` are stored but not
    read anywhere in this class.

    Per-round bookkeeping is written to the module-level dictionaries
    ``fed_part_avg_result`` and ``fed_part_avg_model_results``.
    """

    def __init__(
        self,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                Optional[tuple[float, dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        inplace: bool = True,
        layer_update_strategy: str = "sequential",
    ) -> None:
        super().__init__()
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients
        self.evaluate_fn = evaluate_fn
        self.on_fit_config_fn = on_fit_config_fn
        self.on_evaluate_config_fn = on_evaluate_config_fn
        self.accept_failures = accept_failures
        self.initial_parameters = initial_parameters
        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
        self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn
        self.inplace = inplace

        self.layer_update_strategy = layer_update_strategy  # 'sequential' or 'cyclic'
        self.current_layer = 0  # Track which layer to update
        self.number_of_layers = None
        # Filled by initialize_parameters (or lazily by _layer_for_step).
        self.layer_training_sequence = []
        # Position in layer_training_sequence of the NEXT round to configure.
        self.training_sequence_index = 0
        # Full global model; partial rounds patch a single entry of it.
        self.latest_parameters = initial_parameters
        # Base64/pickle-serialized aggregated optimizer state, or None before round 1.
        self.global_optim_state = None

    def __repr__(self) -> str:
        return "FedPartAvgWithVelocity"

    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return sample size and required number of clients."""
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Use a fraction of available clients for evaluation."""
        num_clients = int(num_available_clients * self.fraction_evaluate)
        return max(num_clients, self.min_evaluate_clients), self.min_available_clients

    def generate_layer_training_sequence(self) -> List[int]:
        """Generate a sequence of layers to train.

        For each of ``NUM_OF_CYCLES`` cycles the sequence holds
        ``NUM_OF_FULL_UPDATES_BETWEEN_CYCLES`` entries of ``-1`` (train the
        whole model) followed by every index in ``range(NETWORK_LEN)`` twice.
        """
        layer_training_sequence = []
        for _ in range(NUM_OF_CYCLES):
            layer_training_sequence.extend([-1] * NUM_OF_FULL_UPDATES_BETWEEN_CYCLES)
            for layer in range(NETWORK_LEN):
                layer_training_sequence.extend((layer, layer))
        return layer_training_sequence

    def _layer_for_step(self, step: int) -> int:
        """Return the training-sequence entry for ``step``.

        The index wraps around, so running more rounds than the sequence has
        entries starts another pass instead of raising ``IndexError``. The
        sequence is generated on first use if it has not been built yet.
        """
        if not self.layer_training_sequence:
            self.layer_training_sequence = self.generate_layer_training_sequence()
        return self.layer_training_sequence[step % len(self.layer_training_sequence)]

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters.

        Uses ``initial_parameters`` from the constructor when they were given
        (previously they were silently ignored); otherwise the parameters of a
        freshly constructed ``Net`` are used.
        """
        self.layer_training_sequence = self.generate_layer_training_sequence()

        initial_parameters = self.initial_parameters
        if initial_parameters is None:
            initial_parameters = ndarrays_to_parameters(get_parameters(Net()))

        self.number_of_layers = len(initial_parameters.tensors)
        # Keep the full model so that a partial round can patch one entry of it.
        self.latest_parameters = initial_parameters
        return initial_parameters

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[tuple[float, dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function.

        The result is also stored in ``fed_part_avg_model_results`` under
        ``server_round``. Returns ``None`` when no evaluation function was
        configured or when it returns ``None``.
        """
        if self.evaluate_fn is None:
            # No evaluation function provided
            return None
        eval_res = self.evaluate_fn(server_round, parameters_to_ndarrays(parameters), {})
        if eval_res is None:
            return None

        loss, metrics = eval_res
        fed_part_avg_model_results.setdefault(server_round, {}).update(
            {"global_loss": loss, "global_metrics": metrics}
        )
        return loss, metrics

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Every sampled client receives the same parameters and a config holding
        ``"trainable_layers"`` and, once available, the aggregated
        ``"optimizer_state"``.
        """
        if self.latest_parameters is None:
            # aggregate_fit needs a full model to patch on partial rounds.
            self.latest_parameters = parameters

        trainable_layer = self._layer_for_step(self.training_sequence_index)
        config = {"trainable_layers": trainable_layer}
        if self.global_optim_state is not None:
            config["optimizer_state"] = self.global_optim_state

        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Report the layer chosen for this round rather than the whole sequence.
        print(f"Training on layer {trainable_layer}")
        fit_ins = FitIns(parameters, config)

        self.training_sequence_index = self.training_sequence_index + 1

        return [(client, fit_ins) for client in clients]

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        if self.fraction_evaluate == 0.0:
            return []
        config = {}
        evaluate_ins = EvaluateIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Return client/config pairs
        return [(client, evaluate_ins) for client in clients]

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average.

        On a full round (sequence entry ``-1``) the aggregated weights replace
        the whole model. On a partial round the clients send a single array,
        which replaces the matching entry of the stored model. The clients'
        optimizer states are averaged and kept for the next ``configure_fit``.

        Returns:
            The updated full model and a metrics dict that contains the
            serialized aggregated optimizer state when at least one client
            supplied one. ``(None, {})`` is returned when there are no
            results, or when failures occurred and ``accept_failures`` is
            false.
        """
        if not results:
            return None, {}
        if not self.accept_failures and failures:
            return None, {}

        total_size = sum(
            get_parameters_size(fit_res.parameters) for _, fit_res in results
        )
        print(f"total size: {total_size}")

        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        aggregated_weights = aggregate(weights_results)

        # configure_fit already advanced the index, so the round that just
        # finished is the previous entry.
        trained_layer = self._layer_for_step(self.training_sequence_index - 1)
        if trained_layer == -1:
            self.latest_parameters = ndarrays_to_parameters(aggregated_weights)
        else:
            current_model = parameters_to_ndarrays(self.latest_parameters)
            current_model[trained_layer] = aggregated_weights[0]
            self.latest_parameters = ndarrays_to_parameters(current_model)

        metrics = {}

        serialized_states = [
            state
            for state in (res.metrics.get("optimizer_state") for _, res in results)
            if state is not None
        ]
        # NOTE(security): pickle.loads executes arbitrary code contained in the
        # payload. This is only acceptable while every client is trusted (as in
        # this local simulation); do not use it with untrusted clients.
        optimizer_states = [
            pickle.loads(base64.b64decode(state)) for state in serialized_states
        ]

        if optimizer_states:
            aggregated_optimizer_state = self._aggregate_optimizer_states(optimizer_states)
            aggregated_optimizer_state_serialized = base64.b64encode(
                pickle.dumps(aggregated_optimizer_state)
            ).decode('ascii')
            # include optimizer state in metrics
            metrics["optimizer_state"] = aggregated_optimizer_state_serialized
            self.global_optim_state = aggregated_optimizer_state_serialized

        # Count the serialized optimizer states towards the round's total size.
        total_size += sum(sys.getsizeof(state) for state in serialized_states)
        fed_part_avg_result.setdefault(server_round, {})["total_size"] = total_size

        return self.latest_parameters, metrics

    def _aggregate_optimizer_states(self, optimizer_states):
        """Average Adam optimizer state dicts from several clients.

        For every parameter key that all clients have state for, the result
        holds the unweighted mean of the first and second moments
        (``exp_avg``, ``exp_avg_sq``) and the maximum ``step``. Everything
        else (param groups, keys missing on some client) is copied from the
        first client's state dict.

        Args:
            optimizer_states: Non-empty list of ``Optimizer.state_dict()``
                results.

        Returns:
            A new state dict; the inputs are not modified.
        """
        aggregated_state_dict = copy.deepcopy(optimizer_states[0])

        # Iterate over the actual keys instead of range(len(...)): the state
        # keys are not guaranteed to be the contiguous range 0..n-1.
        shared_keys = [
            key
            for key in aggregated_state_dict['state']
            if all(key in client_state['state'] for client_state in optimizer_states)
        ]

        for key in shared_keys:
            entries = [client_state['state'][key] for client_state in optimizer_states]
            target = aggregated_state_dict['state'][key]

            target['exp_avg'] = torch.mean(
                torch.stack([entry['exp_avg'] for entry in entries]), dim=0
            )
            target['exp_avg_sq'] = torch.mean(
                torch.stack([entry['exp_avg_sq'] for entry in entries]), dim=0
            )

            # `step` may be a tensor or a plain number; float() handles both.
            # (alternatively, could be changed to reset steps to 0)
            max_step = max(float(entry['step']) for entry in entries)
            template_step = target['step']
            if torch.is_tensor(template_step):
                # Write back a 0-dim tensor with the template's dtype and device.
                target['step'] = torch.tensor(
                    max_step, dtype=template_step.dtype, device=template_step.device
                )
            else:
                target['step'] = type(template_step)(max_step)

        return aggregated_state_dict

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average.

        The unweighted sum of the client losses is also stored in
        ``fed_part_avg_result`` under ``server_round``.
        """
        if not results:
            return None, {}

        total_loss = sum(evaluate_res.loss for _, evaluate_res in results)
        fed_part_avg_result.setdefault(server_round, {})["total_loss"] = total_loss

        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )
        metrics_aggregated = {}
        return loss_aggregated, metrics_aggregated

    

In [None]:
class FedAvgPartWithVelocityFlowerClient(NumPyClient):
    """Flower client that trains one layer (or the whole model) and reports its Adam state.

    The fit config may carry ``"trainable_layers"`` (index of the entry to
    train, ``-1`` or absent for the whole model) and ``"optimizer_state"``
    (a base64-encoded pickle of an optimizer ``state_dict``).
    """

    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
        # Created in fit(); kept on the instance so its state can be read back.
        self.optimizer = None

    def get_parameters(self, config):
        """Return all model arrays, or only the trained one on a partial round."""
        print(f"[Client {self.partition_id}] get_parameters")
        parameters = get_parameters(self.net)
        # A config without "trainable_layers" is treated as a full-model request
        # instead of raising KeyError.
        trainable_layers = config.get("trainable_layers", -1)
        if trainable_layers == -1:
            return parameters
        return [parameters[trainable_layers]]

    def fit(self, parameters, config):
        """Train locally and return the trained arrays together with the optimizer state.

        Returns:
            ``(arrays, len(self.trainloader), {"optimizer_state": <base64 str>})``.
        """
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)
        freeze_layers(self.net, config.get("trainable_layers", -1))
        # The optimizer is built over ALL parameters, frozen ones included,
        # as in the original implementation.
        # NOTE(review): presumably needed so the shared state_dict lines up
        # with this optimizer's parameter list — confirm.
        self.optimizer = torch.optim.Adam(self.net.parameters())

        if "optimizer_state" in config:
            print("Got optimizer state in config, setting..")
            serialized_optimizer_state = config["optimizer_state"]
            # NOTE(security): pickle.loads executes arbitrary code contained in
            # the payload; only acceptable while the server is trusted.
            optim_state = pickle.loads(base64.b64decode(serialized_optimizer_state))
            self._set_optimizer_state(optim_state)

        train(self.net, self.trainloader, epochs=EPOCHS, optimizer=self.optimizer)
        print('finished train')

        # handle optimizer state (serialize before passing)
        optim_state = self._get_optimizer_state()
        serialized_optimizer_state = base64.b64encode(pickle.dumps(optim_state)).decode('ascii')
        print(f"After training, got optim state {serialized_optimizer_state[:50]}")

        # NOTE(review): len(self.trainloader) is the number of batches, not the
        # number of examples; it is what the server uses as the aggregation
        # weight — confirm this is intended.
        return (
            self.get_parameters(config),
            len(self.trainloader),
            {"optimizer_state": serialized_optimizer_state},
        )

    def evaluate(self, parameters, config):
        """Evaluate the received parameters on the local validation loader."""
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

    def _get_optimizer_state(self):
        """Return the current optimizer's ``state_dict``."""
        return self.optimizer.state_dict()

    def _set_optimizer_state(self, state_dict):
        """Load ``state_dict`` into the current optimizer."""
        self.optimizer.load_state_dict(state_dict)


def client_fn(context: Context) -> Client:
    """Build the Flower client for the partition named in ``context.node_config``."""
    model = Net().to(DEVICE)
    node_config = context.node_config
    trainloader, valloader, _ = load_datasets(
        node_config["partition-id"], node_config["num-partitions"]
    )
    numpy_client = FedAvgPartWithVelocityFlowerClient(
        node_config["partition-id"], model, trainloader, valloader
    )
    return numpy_client.to_client()


# Create the ClientApp, which builds its clients through `client_fn`
client = ClientApp(client_fn=client_fn)

In [None]:
# Model instance captured by the server-side evaluation function below.
net = Net().to(DEVICE)

# Only the third return value (the loader over the "test" split) is used here.
_, _, testloader = load_datasets(0, NUM_PARTITIONS)

# Evaluation callback that is handed to the strategy in server_fn.
evaluate_fn = get_evaluate_fn(testloader, net)

def server_fn(context: Context) -> ServerAppComponents:
    """Assemble the server components: round budget plus the FedPartAvgWithVelocity strategy."""
    # Configure the server for NUM_OF_ROUNDS rounds of training
    config = ServerConfig(num_rounds=NUM_OF_ROUNDS)
    strategy = FedPartAvgWithVelocity(evaluate_fn=evaluate_fn)
    return ServerAppComponents(config=config, strategy=strategy)

# Create the ServerApp from the factory defined above
server = ServerApp(server_fn=server_fn)

# Run simulation with NUM_PARTITIONS supernodes and the resources in backend_config
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

In [None]:
# Persist the per-round results; the file name carries the EPOCHS value of this run.
with open(f'../data/fed_part_{EPOCHS}.p', 'wb') as file:
    pickle.dump(fed_part_avg_result, file)

In [None]:
# Value substituted into the results file name below.
# NOTE(review): the save cell names files by EPOCHS, so this presumably selects
# an epoch count rather than a run number — confirm.
load_run = 2
# pickle.load can execute arbitrary code from the file; only open files you created yourself.
with open(f'../data/fed_part_{load_run}.p', 'rb') as file:
    data = pickle.load(file)

In [None]:
# Show the object loaded from the results file.
print(data)

In [None]:
import matplotlib.pyplot as plt
import numpy as np



def _plot_per_round(rounds, values, ylabel, title, baseline=None):
    """Draw one per-round line chart in a new figure.

    Args:
        rounds: X values (server rounds).
        values: Y values for the FedPartAvg run, one per round.
        ylabel: Label of the y axis.
        title: Figure title.
        baseline: Optional ``(rounds, values)`` pair drawn in red and labelled
            'FedAvg', for comparing against a plain FedAvg run.
    """
    plt.figure(figsize=(10, 5))
    plt.plot(rounds, values, marker='o', linestyle='-', color='b', label='FedPartAvg')
    if baseline is not None:
        baseline_rounds, baseline_values = baseline
        plt.plot(baseline_rounds, baseline_values, marker='o', linestyle='-', color='r', label='FedAvg')
    plt.xlabel('Round')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)


# Plot the total size of parameters for each round
fed_part_avg_rounds = list(fed_part_avg_result.keys())
fed_part_avg_sizes = [fed_part_avg_result[rnd]["total_size"] for rnd in fed_part_avg_rounds]
_plot_per_round(
    fed_part_avg_rounds,
    fed_part_avg_sizes,
    'Total Size of Parameters (bytes)',
    'Total Size of Parameters for Each Round',
)

# Plot the summed client evaluation loss for each round
fed_part_avg_losses = [fed_part_avg_result[rnd]["total_loss"] for rnd in fed_part_avg_rounds]
_plot_per_round(
    fed_part_avg_rounds,
    fed_part_avg_losses,
    'Total Loss',
    'Total Loss for Each Round',
)

# Plot the server-side accuracy for each evaluated round
fed_part_avg_model_rounds = list(fed_part_avg_model_results.keys())
fed_part_avg_accuracies = [
    fed_part_avg_model_results[rnd]["global_metrics"]["accuracy"]
    for rnd in fed_part_avg_model_rounds
]
_plot_per_round(
    fed_part_avg_model_rounds,
    fed_part_avg_accuracies,
    'Accuracy',
    'Accuracy for Each Round',
)

# Plot the server-side loss for each evaluated round
fed_part_avg_global_losses = [
    fed_part_avg_model_results[rnd]["global_loss"] for rnd in fed_part_avg_model_rounds
]
_plot_per_round(
    fed_part_avg_model_rounds,
    fed_part_avg_global_losses,
    'Loss',
    'Loss for Each Round',
)


# Using momentum similarity for weighting updates
Ideas:
- Weight updates by the cosine similarity between each client's momentum and the global momentum from the previous round.
- Save the cosine similarities and use them to compute the distribution of these similarities.
- Use the cosine similarities to identify outlier clients.


This is likely of greater interest in heterogeneous data settings.

In [None]:
from typing import Union
from functools import partial, reduce
from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg

# Result stores for the momentum-weighting strategy.
# NOTE(review): presumably keyed by server round like the fed_part_avg_* dictionaries — confirm.
fed_momentum_weighting_results = {}
fed_momentum_weighting_model_results = {}

# NOTE(review): a `global` statement at module scope has no effect and does not
# create the names; client_opt / prev_opt still have to be assigned before use.
global client_opt
global prev_opt

class FedPartAvgWithVelocityweighting(Strategy):
    """Layer-wise federated averaging weighted by optimizer-momentum similarity.

    Every round trains either the whole network (schedule entry ``-1``) or a
    single entry of the parameter list, following the schedule built by
    :meth:`generate_layer_training_sequence`. Clients return their optimizer
    state (``exp_avg`` / ``exp_avg_sq`` / ``step`` per parameter) together
    with the trained parameters. The server weights each client by the cosine
    similarity between its momentum buffers and the previous global momentum,
    and uses those weights for both the parameters and the optimizer state
    that is sent back to the clients in the next round.

    Per-round bookkeeping is written to the module-level dictionaries
    ``fed_momentum_weighting_results`` and
    ``fed_momentum_weighting_model_results``.
    """

    def __init__(
        self,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                Optional[tuple[float, dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        inplace: bool = True,
        layer_update_strategy: str = "sequential",
    ) -> None:
        """Store the configuration and reset the layer-schedule state.

        The signature mirrors Flower's ``FedAvg``. Within this class only the
        sampling fractions/minimums and ``evaluate_fn`` are read; the other
        arguments are stored as attributes but not consulted.
        """
        super().__init__()
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients
        self.evaluate_fn = evaluate_fn
        self.on_fit_config_fn = on_fit_config_fn
        self.on_evaluate_config_fn = on_evaluate_config_fn
        self.accept_failures = accept_failures
        self.initial_parameters = initial_parameters
        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
        self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn
        self.inplace = inplace

        self.layer_update_strategy = layer_update_strategy  # 'sequential' or 'cyclic'
        self.current_layer = 0  # Track which layer to update
        self.number_of_layers = None
        self.layer_training_sequence = []
        # Index of the NEXT schedule entry; configure_fit advances it.
        self.training_sequence_index = 0
        self.latest_parameters = initial_parameters
        # Base64/pickle-serialized aggregated optimizer state (None until the
        # first successful aggregate_fit).
        self.global_optim_state = None

    def __repr__(self) -> str:
        return "FedPartAvgWithVelocity"

    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return sample size and required number of clients."""
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Use a fraction of available clients for evaluation."""
        num_clients = int(num_available_clients * self.fraction_evaluate)
        return max(num_clients, self.min_evaluate_clients), self.min_available_clients

    def generate_layer_training_sequence(self) -> List[int]:
        """Generate the per-round schedule of layers to train.

        Each of the ``NUM_OF_CYCLES`` cycles starts with
        ``NUM_OF_FULL_UPDATES_BETWEEN_CYCLES`` full-model rounds (``-1``) and
        then schedules every layer index in ``range(NETWORK_LEN)`` for two
        consecutive rounds.
        """
        layer_training_sequence: List[int] = []
        for _ in range(NUM_OF_CYCLES):
            layer_training_sequence.extend([-1] * NUM_OF_FULL_UPDATES_BETWEEN_CYCLES)
            for layer in range(NETWORK_LEN):
                layer_training_sequence.extend((layer, layer))
        return layer_training_sequence

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters from a freshly built ``Net``.

        Also builds the layer schedule and records the number of parameter
        arrays. ``self.initial_parameters`` is not consulted here.
        """
        net = Net()
        ndarrays = get_parameters(net)
        self.layer_training_sequence = self.generate_layer_training_sequence()
        self.number_of_layers = len(ndarrays)

        initial = ndarrays_to_parameters(ndarrays)
        # Keep the server-side copy in sync with what is actually handed out,
        # so a single-layer round can be spliced into it even if no
        # full-model round has run yet.
        self.latest_parameters = initial
        return initial

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[tuple[float, dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function.

        The result is also merged into
        ``fed_momentum_weighting_model_results[server_round]`` under the keys
        ``"global_loss"`` and ``"global_metrics"``. Returns ``None`` when no
        evaluation function is configured or it yields ``None``.
        """
        if self.evaluate_fn is None:
            # No evaluation function provided
            return None
        parameters_ndarrays = parameters_to_ndarrays(parameters)
        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
        if eval_res is None:
            return None

        loss, metrics = eval_res
        previous_entry = fed_momentum_weighting_model_results.get(server_round, {})
        fed_momentum_weighting_model_results[server_round] = {
            **previous_entry,
            "global_loss": loss,
            "global_metrics": metrics,
        }
        return loss, metrics

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Sends the scheduled layer index as ``"trainable_layers"`` and, once
        available, the serialized global optimizer state as
        ``"optimizer_state"``. Advances the schedule index by one.

        Raises:
            IndexError: if more rounds are run than the schedule has entries.
        """
        current_layer = self.layer_training_sequence[self.training_sequence_index]
        config = {"trainable_layers": current_layer}

        if self.global_optim_state is not None:
            config["optimizer_state"] = self.global_optim_state

        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Report the layer trained this round (-1 means the whole network),
        # not the entire schedule.
        print(f"Training on layer {current_layer}")
        fit_ins = FitIns(parameters, config)

        self.training_sequence_index += 1

        return [(client, fit_ins) for client in clients]

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        if self.fraction_evaluate == 0.0:
            return []
        config = {}
        evaluate_ins = EvaluateIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Return client/config pairs
        return [(client, evaluate_ins) for client in clients]

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using momentum-similarity weights.

        Steps:
          1. Decode each client's optimizer state from its fit metrics.
          2. Compute the cosine similarity of each client state to the
             previous global state (in the first round, to the plain mean of
             the client states).
          3. Average parameters and optimizer states with those weights.
          4. For a single-layer round, splice the aggregated layer into the
             stored global model; for a full round, replace it.

        Returns:
            The updated global parameters and a metrics dict carrying the
            serialized aggregated optimizer state under ``"optimizer_state"``.
            ``(None, {})`` when ``results`` is empty.

        Raises:
            ValueError: if any fit result lacks an ``"optimizer_state"``
                metric.
        """
        if not results:
            # Nothing to aggregate this round; leave the stored state untouched.
            return None, {}

        total_size = sum(
            get_parameters_size(fit_res.parameters) for _, fit_res in results
        )
        print(f"total size: {total_size}")

        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]

        optimizer_states_serialized = [
            fit_res.metrics.get("optimizer_state") for _, fit_res in results
        ]
        if any(state is None for state in optimizer_states_serialized):
            raise ValueError(
                "Every fit result must carry an 'optimizer_state' metric; "
                "momentum-based weighting needs one optimizer state per client."
            )
        # NOTE(security): pickle.loads runs arbitrary code from its input.
        # Acceptable for a local simulation; do not use with untrusted clients.
        optimizer_states = [
            pickle.loads(base64.b64decode(state))
            for state in optimizer_states_serialized
        ]

        # Get cosine similarities
        if server_round == 1 or self.global_optim_state is None:
            # No previous global state yet. Use the unweighted mean instead.
            equal_weights = [1] * len(optimizer_states)
            prev_global_optimizer_state = self._aggregate_optimizer_states_weighted(
                optimizer_states, equal_weights
            )
        else:
            prev_global_optimizer_state = pickle.loads(
                base64.b64decode(self.global_optim_state)
            )
        cos_similarities = self._cos_similarity_from_optimizer_states(
            optimizer_states, prev_global_optimizer_state
        )

        # Aggregate nn params using weighting from momentum
        aggregated_weights = self._aggregate_params_weighted(
            weights_results, cos_similarities
        )

        # configure_fit already advanced the index, so the layer trained in
        # this round is the previous schedule entry.
        trained_layer = self.layer_training_sequence[self.training_sequence_index - 1]
        if trained_layer == -1:
            self.latest_parameters = ndarrays_to_parameters(aggregated_weights)
        else:
            # Clients returned only the trained layer; splice it into the model.
            current_model = parameters_to_ndarrays(self.latest_parameters)
            current_model[trained_layer] = aggregated_weights[0]
            self.latest_parameters = ndarrays_to_parameters(current_model)

        # Aggregate optimizer state using weighting from momentum
        aggregated_optimizer_state = self._aggregate_optimizer_states_weighted(
            optimizer_states, cos_similarities
        )
        aggregated_optimizer_state_serialized = base64.b64encode(
            pickle.dumps(aggregated_optimizer_state)
        ).decode('ascii')
        metrics: Dict[str, Scalar] = {
            "optimizer_state": aggregated_optimizer_state_serialized
        }
        self.global_optim_state = aggregated_optimizer_state_serialized

        # Add the size of the received serialized optimizer states.
        total_size += sum(sys.getsizeof(data) for data in optimizer_states_serialized)

        fed_momentum_weighting_results.setdefault(server_round, {})[
            "total_size"
        ] = total_size

        return self.latest_parameters, metrics

    @staticmethod
    def _momentum_based_aggregate(results: list[tuple[NDArrays, int]], cosine_similarities) -> NDArrays:
        """Compute the example-count weighted average of params.

        ``cosine_similarities`` is accepted but not used. Declared as a
        static method because the signature has no ``self`` parameter.
        """
        # Calculate the total number of examples used during training
        num_examples_total = sum(num_examples for (_, num_examples) in results)

        # Create a list of weights, each multiplied by the related number of examples
        weighted_weights = [
            [layer * num_examples for layer in weights] for weights, num_examples in results
        ]

        # Compute average weights of each layer
        weights_prime: NDArrays = [
            reduce(np.add, layer_updates) / num_examples_total
            for layer_updates in zip(*weighted_weights)
        ]
        return weights_prime

    def _aggregate_optimizer_states_weighted(self, optimizer_states, similarities):
        """Combine client optimizer states into one state dict.

        For every parameter entry in the first client's ``'state'`` mapping,
        the first and second moments (``exp_avg``, ``exp_avg_sq``) are
        averaged with weights ``similarity / sum(similarities)`` and ``step``
        is set to the maximum over clients. All other content is a deep copy
        of the first client's state dict.
        """
        aggregated_state_dict = copy.deepcopy(optimizer_states[0])
        sum_similarities = sum(similarities)
        # Direct normalization: each weight is its share of the similarity sum.
        weights = [similarity / sum_similarities for similarity in similarities]

        # Iterate over the actual state keys instead of assuming 0..n-1.
        for layer, reference in optimizer_states[0]['state'].items():
            weighted_fst_momentum = torch.zeros_like(reference['exp_avg'])
            weighted_snd_momentum = torch.zeros_like(reference['exp_avg_sq'])

            # Accumulate weighted momentums from each client
            for client_state, weight in zip(optimizer_states, weights):
                layer_state = client_state['state'][layer]
                weighted_fst_momentum += layer_state['exp_avg'] * weight
                weighted_snd_momentum += layer_state['exp_avg_sq'] * weight

            target = aggregated_state_dict['state'][layer]
            target['exp_avg'] = weighted_fst_momentum
            target['exp_avg_sq'] = weighted_snd_momentum
            # Take the maximum step for EVERY entry (previously only the last
            # entry was updated because this ran after the loop).
            target['step'] = max(
                client_state['state'][layer]['step']
                for client_state in optimizer_states
            )

        return aggregated_state_dict

    def _aggregate_params_weighted(self, results: list[tuple[NDArrays, int]], cos_similarities: list[float]) -> NDArrays:
        """Compute the cosine-similarity weighted average of client params.

        The example counts in ``results`` are ignored; client ``i`` is
        weighted by ``cos_similarities[i] / sum(cos_similarities)``.

        Raises:
            ValueError: if the two lists differ in length.
        """
        if len(results) != len(cos_similarities):
            raise ValueError(
                f"Got {len(results)} client results but "
                f"{len(cos_similarities)} cosine similarities."
            )
        sum_sim = sum(cos_similarities)

        # Scale every layer of every client by that client's similarity
        weighted_weights = [
            [layer * similarity for layer in client_weights]
            for (client_weights, _), similarity in zip(results, cos_similarities)
        ]

        # Compute average weights of each layer
        weights_prime: NDArrays = [
            reduce(np.add, layer_updates) / sum_sim
            for layer_updates in zip(*weighted_weights)
        ]
        return weights_prime

    @staticmethod
    def _flatten_momentum(optimizer_state, layers):
        """Concatenate all first, then all second moments into one 1-D tensor."""
        first_moments = [optimizer_state['state'][l]['exp_avg'].reshape(-1) for l in layers]
        second_moments = [optimizer_state['state'][l]['exp_avg_sq'].reshape(-1) for l in layers]
        return torch.cat(first_moments + second_moments)

    def _cos_similarity_from_optimizer_states(self, optimizer_states, prev_global_optimizer_state):
        """Return each client's cosine similarity to the previous global state.

        Both the client and the global momentum buffers are flattened into a
        single vector over the parameter entries present in the first
        client's state. The result is a list of Python floats in client order.
        """
        cos = torch.nn.CosineSimilarity(dim=-1)
        layers = list(optimizer_states[0]['state'])

        glob_state_vec = self._flatten_momentum(prev_global_optimizer_state, layers)

        return [
            cos(self._flatten_momentum(client_state, layers), glob_state_vec).item()
            for client_state in optimizer_states
        ]

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average.

        Also stores the unweighted sum of client losses under
        ``fed_momentum_weighting_results[server_round]["total_loss"]``.
        Returns ``(None, {})`` when there are no results.
        """
        if not results:
            return None, {}

        total_loss = sum(evaluate_res.loss for _, evaluate_res in results)
        fed_momentum_weighting_results.setdefault(server_round, {})[
            "total_loss"
        ] = total_loss

        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )
        metrics_aggregated = {}
        return loss_aggregated, metrics_aggregated

    

In [None]:
class FedAvgPartWithVelocityWeightingFlowerClient(NumPyClient):
    """Flower client that trains selected layers and reports its optimizer state.

    The optimizer state travels as a base64-encoded pickle string: inbound in
    ``config["optimizer_state"]`` and outbound in the fit metrics under the
    same key.
    """

    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id, self.net = partition_id, net
        self.trainloader, self.valloader = trainloader, valloader

    def get_parameters(self, config):
        """Return all parameter arrays, or only the one being trained.

        ``config["trainable_layers"] == -1`` selects the full list; any other
        value is used as an index and yields a one-element list.
        """
        print(f"[Client {self.partition_id}] get_parameters")
        all_layers = get_parameters(self.net)
        layer_index = config["trainable_layers"]
        if layer_index != -1:
            return [all_layers[layer_index]]
        return all_layers

    def fit(self, parameters, config):
        """Train locally and return (parameters, len(trainloader), metrics)."""
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)
        freeze_layers(self.net, config["trainable_layers"])
        # A fresh Adam optimizer is created every round; the state received
        # from the server (if any) is loaded into it below.
        self.optimizer = torch.optim.Adam(self.net.parameters())

        if "optimizer_state" in config:
            print(f"Got optimizer state in config, setting..")
            # NOTE(security): unpickling executes arbitrary code; only safe
            # when the server is trusted (as in this local simulation).
            incoming_state = pickle.loads(base64.b64decode(config["optimizer_state"]))
            self._set_optimizer_state(incoming_state)

        train(self.net, self.trainloader, epochs=EPOCHS, optimizer=self.optimizer)
        print('finished train')

        # Serialize the post-training optimizer state so it fits in the metrics dict.
        pickled_state = pickle.dumps(self._get_optimizer_state())
        serialized_optimizer_state = base64.b64encode(pickled_state).decode('ascii')
        print(f"After training, got optim state {serialized_optimizer_state[:50]}")

        trained_layers = self.get_parameters(config)
        fit_metrics = {"optimizer_state": serialized_optimizer_state}
        return trained_layers, len(self.trainloader), fit_metrics

    def evaluate(self, parameters, config):
        """Evaluate the given parameters on the local validation loader."""
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        val_loss, val_accuracy = test(self.net, self.valloader)
        eval_metrics = {"accuracy": float(val_accuracy)}
        return float(val_loss), len(self.valloader), eval_metrics

    def _get_optimizer_state(self):
        """Return the current optimizer's state dict."""
        return self.optimizer.state_dict()

    def _set_optimizer_state(self, state_dict):
        """Load ``state_dict`` into the current optimizer."""
        self.optimizer.load_state_dict(state_dict)


def client_fn(context: Context) -> Client:
    """Build the velocity-weighting client for the partition named in ``context``.

    Reads ``"partition-id"`` and ``"num-partitions"`` from
    ``context.node_config``, loads that partition's train/validation loaders
    (the test loader is discarded) and wraps a fresh ``Net`` on ``DEVICE``.
    """
    model = Net().to(DEVICE)
    node_config = context.node_config
    partition_id = node_config["partition-id"]
    num_partitions = node_config["num-partitions"]
    loaders = load_datasets(partition_id, num_partitions)
    trainloader, valloader = loaders[0], loaders[1]
    numpy_client = FedAvgPartWithVelocityWeightingFlowerClient(
        partition_id, model, trainloader, valloader
    )
    return numpy_client.to_client()


# Create the ClientApp that wraps client_fn; it is handed to run_simulation
# as `client_app`.
client = ClientApp(client_fn=client_fn)

In [None]:
# Model instance handed to get_evaluate_fn for server-side evaluation.
net = Net().to(DEVICE)

# Only the third value returned by load_datasets (the test loader) is kept;
# partition 0 is requested just to obtain it.
_, _, testloader = load_datasets(0, NUM_PARTITIONS)

# Evaluation callback passed to the strategy as `evaluate_fn`.
evaluate_fn = get_evaluate_fn(testloader, net)

def server_fn(context: Context) -> ServerAppComponents:
    """Return the server components for a velocity-weighted run.

    The run lasts ``NUM_OF_ROUNDS`` rounds and uses a new
    ``FedPartAvgWithVelocityweighting`` strategy wired to the module-level
    ``evaluate_fn``. ``context`` is not read.
    """
    round_config = ServerConfig(num_rounds=NUM_OF_ROUNDS)
    velocity_strategy = FedPartAvgWithVelocityweighting(evaluate_fn=evaluate_fn)
    return ServerAppComponents(strategy=velocity_strategy, config=round_config)

# Wrap server_fn in a ServerApp so the simulation can build the strategy.
server = ServerApp(server_fn=server_fn)

# Run simulation with one supernode per data partition.
# `backend_config` is defined outside this cell.
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Shared line style for every curve of the velocity-weighting run.
_velocity_line_style = dict(
    marker='o', linestyle='-', color='b', label='FedPartWithVelocityWeighting'
)

# Figure 1: total size recorded per round by aggregate_fit.
fed_momentum_weighting_rounds = list(fed_momentum_weighting_results)
fed_momentum_weighting_sizes = [
    fed_momentum_weighting_results[rnd]["total_size"]
    for rnd in fed_momentum_weighting_rounds
]

plt.figure(figsize=(10, 5))
plt.plot(fed_momentum_weighting_rounds, fed_momentum_weighting_sizes, **_velocity_line_style)
# To overlay a FedAvg baseline, plot fed_avg_rounds / fed_avg_sizes here (color='r').
plt.title('Total Size of Parameters for Each Round')
plt.xlabel('Round')
plt.ylabel('Total Size of Parameters (bytes)')
plt.legend()
plt.grid(True)

# Figure 2: summed client evaluation loss recorded per round by aggregate_evaluate.
fed_momentum_weighting_losses = [
    fed_momentum_weighting_results[rnd]["total_loss"]
    for rnd in fed_momentum_weighting_rounds
]

plt.figure(figsize=(10, 5))
plt.plot(fed_momentum_weighting_rounds, fed_momentum_weighting_losses, **_velocity_line_style)
# To overlay a FedAvg baseline, plot fed_avg_rounds / fed_avg_losses here (color='r').
plt.title('Total Loss for Each Round')
plt.xlabel('Round')
plt.ylabel('Total Loss')
plt.legend()
plt.grid(True)

# Figure 3: accuracy from the server-side evaluation of each round.
fed_momentum_weighting_model_rounds = list(fed_momentum_weighting_model_results)
fed_momentum_weighting_accuracies = [
    fed_momentum_weighting_model_results[rnd]["global_metrics"]["accuracy"]
    for rnd in fed_momentum_weighting_model_rounds
]

plt.figure(figsize=(10, 5))
plt.plot(fed_momentum_weighting_model_rounds, fed_momentum_weighting_accuracies, **_velocity_line_style)
# To overlay a FedAvg baseline, plot fed_avg_model_rounds / fed_avg_accuracies here (color='r').
plt.title('Accuracy for Each Round')
plt.xlabel('Round')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

# Figure 4: loss from the server-side evaluation of each round (no grid).
fed_momentum_weighting_global_losses = [
    fed_momentum_weighting_model_results[rnd]["global_loss"]
    for rnd in fed_momentum_weighting_model_rounds
]

plt.figure(figsize=(10, 5))
plt.plot(fed_momentum_weighting_model_rounds, fed_momentum_weighting_global_losses, **_velocity_line_style)
# To overlay a FedAvg baseline, plot fed_avg_model_rounds / fed_avg_global_losses here (color='r').
plt.xlabel('Round')
plt.ylabel('Loss')
plt.title('Loss for Each Round')
plt.legend()


# Momentum similarity weighted updates with data heterogeneity

In [None]:

# Mini-batch size; re-declares the same value as the BATCH_SIZE set in the
# first data-loading cell of this notebook.
BATCH_SIZE = 32
from flwr_datasets.partitioner import DirichletPartitioner

def load_heterogeneous_datasets(partition_id, num_partitions: int, alpha: float = 0.1):
    """Load one label-skewed CIFAR-10 partition plus the shared test split.

    The training split is divided with a ``DirichletPartitioner`` over the
    ``"label"`` column, and the requested partition is split 80/20 into
    train/validation data with a fixed seed.

    Args:
        partition_id: Index of the partition to load.
        num_partitions: Total number of partitions.
        alpha: Concentration parameter forwarded to ``DirichletPartitioner``.
            The default 0.1 is the value that used to be hard-coded.

    Returns:
        ``(trainloader, valloader, testloader)``. The train loader is
        shuffled, and all three use ``BATCH_SIZE``.
    """
    dirichlet_partitioner = DirichletPartitioner(
        num_partitions=num_partitions, alpha=alpha, partition_by="label"
    )
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": dirichlet_partitioner})
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        # Instead of passing transforms to CIFAR10(..., transform=transform)
        # we will use this function to dataset.with_transform(apply_transforms)
        # The transforms object is exactly the same
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    # Use the module-level BATCH_SIZE instead of repeating the literal 32.
    trainloader = DataLoader(partition_train_test["train"], batch_size=BATCH_SIZE, shuffle=True)
    valloader = DataLoader(partition_train_test["test"], batch_size=BATCH_SIZE)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloader, valloader, testloader


def heterogeneous_client_fn(context: Context) -> Client:
    """Build the velocity-weighting client on a Dirichlet-partitioned dataset.

    Reads ``"partition-id"`` and ``"num-partitions"`` from
    ``context.node_config`` and loads the loaders through
    ``load_heterogeneous_datasets`` (the test loader is discarded).
    """
    model = Net().to(DEVICE)
    node_config = context.node_config
    partition_id = node_config["partition-id"]
    num_partitions = node_config["num-partitions"]
    loaders = load_heterogeneous_datasets(partition_id, num_partitions)
    trainloader, valloader = loaders[0], loaders[1]
    numpy_client = FedAvgPartWithVelocityWeightingFlowerClient(
        partition_id, model, trainloader, valloader
    )
    return numpy_client.to_client()


# Create the ClientApp for the heterogeneous setup. This rebinds the
# module-level name `client`, which an earlier cell assigned to the
# ClientApp built from client_fn.
client = ClientApp(client_fn=heterogeneous_client_fn)

In [None]:
# Start the heterogeneous run from empty result stores. The strategy writes
# into these module-level dicts keyed by round, so entries left over from the
# previous (homogeneous) run would otherwise show up in the plots below
# whenever this run records fewer rounds.
fed_momentum_weighting_results.clear()
fed_momentum_weighting_model_results.clear()

# Model instance handed to get_evaluate_fn for server-side evaluation.
net = Net().to(DEVICE)

# Only the test loader is kept.
# NOTE(review): this uses load_datasets rather than
# load_heterogeneous_datasets; confirm both return the same test split.
_, _, testloader = load_datasets(0, NUM_PARTITIONS)

# Evaluation callback passed to the strategy as `evaluate_fn`.
evaluate_fn = get_evaluate_fn(testloader, net)

def server_fn(context: Context) -> ServerAppComponents:
    """Return the server components for the heterogeneous-data run.

    The run lasts ``NUM_OF_ROUNDS`` rounds and uses a new
    ``FedPartAvgWithVelocityweighting`` strategy wired to the module-level
    ``evaluate_fn``. ``context`` is not read.
    """
    round_config = ServerConfig(num_rounds=NUM_OF_ROUNDS)
    velocity_strategy = FedPartAvgWithVelocityweighting(evaluate_fn=evaluate_fn)
    return ServerAppComponents(strategy=velocity_strategy, config=round_config)

# Wrap the server_fn defined just above in a ServerApp.
server = ServerApp(server_fn=server_fn)

# Run simulation with one supernode per data partition. `client` is the
# ClientApp built from heterogeneous_client_fn; `backend_config` is defined
# outside this cell.
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Shared line style for every curve of the heterogeneous velocity-weighting run.
_velocity_line_style = dict(
    marker='o', linestyle='-', color='b', label='FedPartWithVelocityWeighting'
)

# Figure 1: total size recorded per round by aggregate_fit.
fed_momentum_weighting_rounds = list(fed_momentum_weighting_results)
fed_momentum_weighting_sizes = [
    fed_momentum_weighting_results[rnd]["total_size"]
    for rnd in fed_momentum_weighting_rounds
]

plt.figure(figsize=(10, 5))
plt.plot(fed_momentum_weighting_rounds, fed_momentum_weighting_sizes, **_velocity_line_style)
# To overlay a FedAvg baseline, plot fed_avg_rounds / fed_avg_sizes here (color='r').
plt.title('Total Size of Parameters for Each Round (With Data Heterogeneity)')
plt.xlabel('Round')
plt.ylabel('Total Size of Parameters (bytes)')
plt.legend()
plt.grid(True)

# Figure 2: summed client evaluation loss recorded per round by aggregate_evaluate.
fed_momentum_weighting_losses = [
    fed_momentum_weighting_results[rnd]["total_loss"]
    for rnd in fed_momentum_weighting_rounds
]

plt.figure(figsize=(10, 5))
plt.plot(fed_momentum_weighting_rounds, fed_momentum_weighting_losses, **_velocity_line_style)
# To overlay a FedAvg baseline, plot fed_avg_rounds / fed_avg_losses here (color='r').
plt.title('Total Loss for Each Round  (With Data Heterogeneity)')
plt.xlabel('Round')
plt.ylabel('Total Loss')
plt.legend()
plt.grid(True)

# Figure 3: accuracy from the server-side evaluation of each round.
fed_momentum_weighting_model_rounds = list(fed_momentum_weighting_model_results)
fed_momentum_weighting_accuracies = [
    fed_momentum_weighting_model_results[rnd]["global_metrics"]["accuracy"]
    for rnd in fed_momentum_weighting_model_rounds
]

plt.figure(figsize=(10, 5))
plt.plot(fed_momentum_weighting_model_rounds, fed_momentum_weighting_accuracies, **_velocity_line_style)
# To overlay a FedAvg baseline, plot fed_avg_model_rounds / fed_avg_accuracies here (color='r').
plt.title('Accuracy for Each Round  (With Data Heterogeneity)')
plt.xlabel('Round')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

# Figure 4: loss from the server-side evaluation of each round (no grid).
fed_momentum_weighting_global_losses = [
    fed_momentum_weighting_model_results[rnd]["global_loss"]
    for rnd in fed_momentum_weighting_model_rounds
]

plt.figure(figsize=(10, 5))
plt.plot(fed_momentum_weighting_model_rounds, fed_momentum_weighting_global_losses, **_velocity_line_style)
# To overlay a FedAvg baseline, plot fed_avg_model_rounds / fed_avg_global_losses here (color='r').
plt.xlabel('Round')
plt.ylabel('Loss')
plt.title('Loss for Each Round  (With Data Heterogeneity)')
plt.legend()
