### Velocity demo
Current approach shares full optimizer state between rounds. Optimizer states are aggregated when received on server end and then shared out such that all clients begin the next round with the aggregated optimizer state.
Limitations: 
- We are not measuring communication from this, but if desired, it should be pretty easy to add
- Currently, we share the entire optimizer state, but when training only a single layer, we would only need to share the state related to that layer. To only share updated layer, modify _aggregate_optimizer_states.

In [28]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union, Callable


import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import copy
import sys
import base64
import pickle
import copy
import torch.nn.functional as F
import torchvision.transforms as transforms
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader
from flwr.server.strategy import Strategy
import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Metrics, Context, Status, GetParametersRes, Parameters, GetParametersIns, MetricsAggregationFn,NDArrays,Scalar
from flwr.server import ServerApp, ServerConfig, ServerAppComponents 
from flwr.server.strategy import FedAvg, FedProx
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    ParametersRecord,
    parameters_to_ndarrays,
    array_from_numpy
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = "mps"
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")
disable_progress_bar()

Training on mps
Flower 1.15.1 / PyTorch 2.6.0


In [29]:

BATCH_SIZE = 32

def load_datasets(partition_id, num_partitions: int):
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
    valloader = DataLoader(partition_train_test["test"], batch_size=32)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=32)
    return trainloader, valloader, testloader

In [30]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters, trainable_layers=-1):
    """Set model parameters from a list of NumPy arrays."""
    current_state = OrderedDict(net.state_dict())
    
    if trainable_layers == -1:
        # Update all parameters
        params_dict = zip(current_state.keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        net.load_state_dict(state_dict, strict=True)
    else:
        # Only update the specified layer's parameters
        # Convert current state to numpy arrays
        numpy_state = [param.cpu().numpy() for param in current_state.values()]
        
        # Update the specific indices with new parameters
        numpy_state[trainable_layers*2] = parameters[0]
        numpy_state[trainable_layers*2 + 1] = parameters[1]
        
        # Convert back to torch and update state dict
        for idx, key in enumerate(current_state.keys()):
            current_state[key] = torch.from_numpy(numpy_state[idx])
        
        net.load_state_dict(current_state, strict=True)


def train(net, trainloader, epochs: int, optimizer=None):
    """Train the network on the training set."""
    print(f"training network...")
    criterion = torch.nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

def freeze_layers(model: torch.nn.Module, trainable_layers: int) -> None:
        """Freeze specified layers of the model."""
        trainable_layers_set = []
        if trainable_layers == -1:
            trainable_layers_set = [-1]
        else:
            trainable_layers_set = [trainable_layers *2, trainable_layers *2 +1]

        for idx, (name, param) in enumerate(model.named_parameters()):
            
            if idx in trainable_layers_set or trainable_layers_set[0] == -1:
                param.requires_grad = True
                print(f"layer index is {idx} and name{name} is trainabe")
            else:
                param.requires_grad = False
                print(f"layer index is {idx} and name{name} is frozen")


def get_parameters_size(params: Parameters) -> int:
    size = sys.getsizeof(params)  # Base size of the dataclass instance
    size += sys.getsizeof(params.tensor_type)  # Size of the string
    size += sys.getsizeof(params.tensors)  # Size of the list container
    size += sum(sys.getsizeof(tensor) for tensor in params.tensors)  # Size of each bytes object
    return size

def serialize_optimizer_state(state_dict):
    """Serialize optimizer state with reduced memory footprint"""
    lightweight_state = {'state': {}, 'param_groups': state_dict['param_groups']}
    
    for k, v in state_dict['state'].items():
        lightweight_state['state'][k] = {
            'exp_avg': v['exp_avg'].clone().detach().cpu().half().numpy(),  # Use half precision
            'exp_avg_sq': v['exp_avg_sq'].clone().detach().cpu().half().numpy(),
            'step': v['step'].item()  # Store as scalar instead of tensor
        }
    
    return base64.b64encode(pickle.dumps(lightweight_state, protocol=pickle.HIGHEST_PROTOCOL)).decode('ascii')

def deserialize_optimizer_state(serialized_state):
    """Deserialize and reconstruct optimizer state"""
    state_dict = pickle.loads(base64.b64decode(serialized_state))
    reconstructed_state = {'state': {}, 'param_groups': state_dict['param_groups']}
    
    for k, v in state_dict['state'].items():
        reconstructed_state['state'][k] = {
            'exp_avg': torch.tensor(v['exp_avg'], dtype=torch.float32),
            'exp_avg_sq': torch.tensor(v['exp_avg_sq'], dtype=torch.float32),
            'step': torch.tensor(v['step'])
        }
    
    return reconstructed_state


In [31]:

NETWORK_LEN = len(Net().state_dict().keys()) // 2
EPOCHS = 4
NUM_PARTITIONS = 6
NUM_OF_CYCLES  = 1
NUM_OF_FULL_UPDATES_BETWEEN_CYCLES = 1
NUM_OF_ROUNDS = (NUM_OF_CYCLES * NUM_OF_FULL_UPDATES_BETWEEN_CYCLES) + (NUM_OF_CYCLES * NETWORK_LEN *2)
print(f"Number of rounds: {NUM_OF_ROUNDS}")
backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}


Number of rounds: 11


In [32]:
from flwr.common import NDArrays, Scalar


# More robust evaluate function:
def get_evaluate_fn(
    testloader: DataLoader,
    net: torch.nn.Module,
) -> Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]:
    """Return an evaluation function for server-side evaluation."""
    
    # used to check if they're changing
    previous_params = None
    
    def evaluate(
        server_round: int, parameters: NDArrays, config: Dict[str, Scalar]
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Use the entire test set for evaluation."""
        nonlocal previous_params
        
        print(f"\n==== Server-side evaluation for round {server_round} ====")
        
        # Check if parameters changed from previous round
        if previous_params is not None:
            param_change = False
            for i, (prev, curr) in enumerate(zip(previous_params, parameters)):
                diff = np.abs(prev - curr).mean()
                if diff > 1e-6:
                    param_change = True
                    print(f"  Parameter {i}: Changed by {diff:.6f}")
            
            if not param_change:
                print("  WARNING: Parameters haven't changed from previous round!")
        
        previous_params = [p.copy() for p in parameters]
        net_copy = copy.deepcopy(net)

        # Update model with the latest parameters
        params_dict = zip(net_copy.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v, device=DEVICE) for k, v in params_dict})
        
        # Check if state dict keys match model keys
        model_keys = set(net_copy.state_dict().keys())
        params_keys = set(state_dict.keys())
        if model_keys != params_keys:
            print(f"  WARNING: Key mismatch between model and parameters!")
            print(f"  Missing in params: {model_keys - params_keys}")
            print(f"  Extra in params: {params_keys - model_keys}")
        
        net_copy.load_state_dict(state_dict, strict=True)
        net_copy.to(DEVICE)
        net_copy.eval()
        
        # Test the model
        loss, accuracy = test(net_copy, testloader)
        print(f"  Evaluation results - Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")
        
        # Return loss and metrics
        return loss, {"accuracy": accuracy}
    
    return evaluate

# def get_evaluate_fn(
#     testloader: DataLoader,
#     net: torch.nn.Module,
# ) -> Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]:
#     """Return an evaluation function for server-side evaluation."""

#     def evaluate(
#         server_round: int, parameters: NDArrays, config: Dict[str, Scalar]
#     ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
#         """Use the entire test set for evaluation."""
        
#         # Copy model parameters to avoid modifying the original
#         net_copy = copy.deepcopy(net)
        
#         # Update model with the latest parameters
#         params_dict = zip(net_copy.state_dict().keys(), parameters)
#         state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
#         net_copy.load_state_dict(state_dict, strict=True)
        
#         net_copy.to(DEVICE)
#         net_copy.eval()

#         # Test the model
#         loss, accuracy = test(net_copy, testloader)
        
#         # Return loss and metrics
#         return loss, {"accuracy": accuracy}

#     return evaluate

# FedAvg with velocity from global model

In [33]:
from typing import Union

from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg

fed_part_avg_result = {}
fed_part_avg_model_results = {}

class FedPartAvgWithVelocity(Strategy):
    def __init__(
        self,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                Optional[tuple[float, dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        inplace: bool = True,
        layer_update_strategy: str = "sequential",
        
    ) -> None:
        super().__init__()
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients
        self.evaluate_fn = evaluate_fn
        self.on_fit_config_fn = on_fit_config_fn
        self.on_evaluate_config_fn = on_evaluate_config_fn
        self.accept_failures = accept_failures
        self.initial_parameters = initial_parameters
        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
        self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn
        self.inplace = inplace

        self.layer_update_strategy = layer_update_strategy  # 'sequential' or 'cyclic'
        self.current_layer = 0  # Track which layer to update
        self.number_of_layers = None
        self.layer_training_sequence = []
        self.training_sequence_index = 0
        self.latest_parameters = initial_parameters
        self.updated_layers = -1
        self.global_optim_state = None


    def __repr__(self) -> str:
        return "FedPartAvgWithVelocity"
    

    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return sample size and required number of clients."""
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Use a fraction of available clients for evaluation."""
        num_clients = int(num_available_clients * self.fraction_evaluate)
        return max(num_clients, self.min_evaluate_clients), self.min_available_clients
    
    def generate_layer_training_sequence(self) -> List[int]:
        """Generate a sequence of layers to train."""
        layer_training_sequence = []
        for _ in range(NUM_OF_CYCLES):
            for _ in range(NUM_OF_FULL_UPDATES_BETWEEN_CYCLES):
                    layer_training_sequence.append(-1)
            for layer in range(NETWORK_LEN):
                    layer_training_sequence.append(layer)
                    layer_training_sequence.append(layer)

        return layer_training_sequence

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters."""
        net = Net()
        ndarrays = get_parameters(net)
        self.layer_training_sequence = self.generate_layer_training_sequence()
        self.number_of_layers = len(ndarrays)
        self.latest_parameters = ndarrays_to_parameters(ndarrays)

        return ndarrays_to_parameters(ndarrays)
    


    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[tuple[float, dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function."""
        if self.evaluate_fn is None:
            # No evaluation function provided
            return None
        parameters_ndarrays = parameters_to_ndarrays(parameters)
        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
        if eval_res is None:
            return None
        
        if server_round in fed_part_avg_model_results:  
            expand_fed_part_avg_model_results= {**fed_part_avg_model_results[server_round], "global_loss": eval_res[0], "global_metrics": eval_res[1]}
        else:
            expand_fed_part_avg_model_results= {"global_loss": eval_res[0], "global_metrics": eval_res[1]}
        
        fed_part_avg_model_results[server_round] = expand_fed_part_avg_model_results
        
        loss, metrics = eval_res
        return loss, metrics

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training."""
        
        config = {"trainable_layers": self.layer_training_sequence[self.training_sequence_index], "updated_layers": self.updated_layers}

        if self.global_optim_state is not None:
            config["optimizer_state"] = self.global_optim_state
        
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )
        
        print(f"Training on layer {self.layer_training_sequence}")
        fit_configurations = []

        params_array = parameters_to_ndarrays(parameters)
        
        # If doing full model update, send all parameters
        if self.layer_training_sequence[self.training_sequence_index] == -1 or self.updated_layers == -1:
            selected_params = parameters
        else:
            layer_idx = self.updated_layers
            selected_params = ndarrays_to_parameters([
                    params_array[layer_idx * 2],     # Weight
                    params_array[layer_idx * 2 + 1]  # Bias
                ])

        for idx, client in enumerate(clients):
            fit_configurations.append((client, FitIns(selected_params, config)))

        self.updated_layers = self.layer_training_sequence[self.training_sequence_index]
        self.training_sequence_index = self.training_sequence_index + 1
        
        return fit_configurations
    

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        if self.fraction_evaluate == 0.0:
            return []
        config = {}
        evaluate_ins = EvaluateIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Return client/config pairs
        return [(client, evaluate_ins) for client in clients]

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        global global_all_optimizer_states
        """Aggregate fit results using weighted average."""
        
        total_size = 0
        for client, fit_res in results:
            total_size += get_parameters_size(fit_res.parameters)
            total_size += fit_res.metrics["recieved_parameter_size"]

        print(f"total size: {total_size}")
        
        if fed_part_avg_result.get(server_round):
            fed_part_avg_result[server_round]["total_size"] = total_size
        else:
            fed_part_avg_result[server_round] = {"total_size": total_size}
        


        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]

        aggregated_weights = aggregate(weights_results)
        trained_layer = results[0][1].metrics["trained_layer"]
        print(f"aggregated weight size {len(aggregated_weights)} ")

        if trained_layer == -1:
            self.latest_parameters = ndarrays_to_parameters(aggregated_weights)
        else:
            current_model = parameters_to_ndarrays(self.latest_parameters)
            print(f"updateing layers {trained_layer* 2}  and {trained_layer* 2 + 1} ")
            current_model[trained_layer* 2] = aggregated_weights[0]
            current_model[trained_layer* 2 +1] = aggregated_weights[1]
            self.latest_parameters = ndarrays_to_parameters(current_model)

        # include optimizer state in metrics
        metrics = {}
        optimizer_states_serialized = [res.metrics.get("optimizer_state", None) for _, res in results]
        optimizer_states = [deserialize_optimizer_state(state) for state in optimizer_states_serialized if state is not None]
        aggregated_optimizer_state = self._aggregate_optimizer_states(optimizer_states)
        aggregated_optimizer_state_serialized = serialize_optimizer_state(aggregated_optimizer_state)
        metrics["optimizer_state"] = aggregated_optimizer_state_serialized
        self.global_optim_state = aggregated_optimizer_state_serialized

        return self.latest_parameters, metrics


    def _aggregate_optimizer_states(self, optimizer_states):
        # aggregates by comuting (unweighted) mean of 1st and 2nd momentum, and taking the maximum step value. (alternatively, could be changed to reset steps to 0)
        """Aggregate optimizer states with reduced memory usage"""
        if not optimizer_states:
            return None
            
        # Start with an empty state dict structure
        aggregated_state_dict = {'state': {}, 'param_groups': optimizer_states[0]['param_groups']}
        
        # Get all layer keys
        all_layers = list(optimizer_states[0]['state'].keys())
        
        for layer in all_layers:
            # Initialize with zeros instead of stacking
            if layer not in aggregated_state_dict['state']:
                aggregated_state_dict['state'][layer] = {}
                
            # Get shape from first optimizer
            exp_avg_shape = optimizer_states[0]['state'][layer]['exp_avg'].shape
            exp_avg_sq_shape = optimizer_states[0]['state'][layer]['exp_avg_sq'].shape
            
            # Pre-allocate tensors
            sum_exp_avg = torch.zeros(exp_avg_shape, device='cpu')
            sum_exp_avg_sq = torch.zeros(exp_avg_sq_shape, device='cpu')
            max_step = 0
            
            # Sum without stacking
            for state in optimizer_states:
                if layer in state['state']:
                    sum_exp_avg += state['state'][layer]['exp_avg'].cpu()
                    sum_exp_avg_sq += state['state'][layer]['exp_avg_sq'].cpu()
                    max_step = max(max_step, state['state'][layer]['step'].item())
            
            # Average
            n_states = len(optimizer_states)
            aggregated_state_dict['state'][layer] = {
                'exp_avg': sum_exp_avg / n_states,
                'exp_avg_sq': sum_exp_avg_sq / n_states,
                'step': torch.tensor(max_step)
            }
        
        return aggregated_state_dict


    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average."""

        if not results:
            return None, {}
        
        total_loss = 0
        for _, evaluate_res in results:
            total_loss += evaluate_res.loss

        if fed_part_avg_result.get(server_round):
            fed_part_avg_result[server_round]["total_loss"] = total_loss
        else:
            fed_part_avg_result[server_round] = {"total_loss": total_loss}

        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )
        metrics_aggregated = {}
        return loss_aggregated, metrics_aggregated

    

In [34]:
class FedAvgPartWithVelocityFlowerClient(NumPyClient):
    def __init__(self, partition_id, net, trainloader, valloader, context: Context):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
        self.client_state = context.state
    
        # Initialize parameters record if it doesn't exist
        if "net_parameters" not in self.client_state.parameters_records:
            self.client_state.parameters_records["net_parameters"] = ParametersRecord()
            # Save initial model state
            self._save_model_state()


    def get_parameters(self, config):
        print(f"[Client {self.partition_id}] get_parameters")
        parameters = get_parameters(self.net)
        if config["trainable_layers"] == -1:
            trained_layers = parameters
        else:
            trained_layers = [
                parameters[config["trainable_layers"] * 2],
                parameters[config["trainable_layers"] * 2 + 1]
            ]
        
        return trained_layers

    def _save_model_state(self):
        """Save current model parameters to context"""
        p_record = ParametersRecord()
        parameters = get_parameters(self.net)
        
        for i, param in enumerate(parameters):
            print(f"Saving layer {i} giving shape {param.shape}")
            p_record[f"layer_{i}"] = array_from_numpy(param)
        
        self.client_state.parameters_records["net_parameters"] = p_record

    def _load_model_state(self):
        """Load model parameters from context"""
        p_record = self.client_state.parameters_records["net_parameters"]
        parameters = []
        
        for i in range(len(p_record)):
            print(f"Loading layer {i} with shape {p_record[f'layer_{i}'].numpy().shape}")
            parameters.append(p_record[f"layer_{i}"].numpy())
        
        print(f"Loading model. Set parameters to net with {len(parameters)} params in total.")
        for i, p in enumerate(parameters):
            print(f"Loading param {i} with shape {p.shape}")
        set_parameters(self.net, parameters)

    def fit(self, parameters, config):
        print(f"[Client {self.partition_id}] fit")

        self._load_model_state()
        recieved_parameter_size = get_parameters_size(ndarrays_to_parameters(parameters))  
        set_parameters(self.net, parameters, config["updated_layers"])
        freeze_layers(self.net, config["trainable_layers"])

        self.optimizer = torch.optim.Adam(self.net.parameters())
        if "optimizer_state" in config:
            print(f"Got optimizer state in config, setting..")
            serialized_optimizer_state = config["optimizer_state"]
            optim_state = deserialize_optimizer_state(serialized_optimizer_state)
            self._set_optimizer_state(optim_state)

        train(self.net, self.trainloader, epochs=EPOCHS, optimizer=self.optimizer)
        self._save_model_state()
        print('finished train')

        # handle optimizer state (serialize before passing)
        optim_state = self._get_optimizer_state()
        serialized_optimizer_state = serialize_optimizer_state(optim_state)
        print(f"After training, got optim state {serialized_optimizer_state[:50]}")

        new_config = {
            "trained_layer":config["trainable_layers"], 
            "recieved_parameter_size": recieved_parameter_size,
            "optimizer_state": serialized_optimizer_state
        }

        return self.get_parameters(config), len(self.trainloader), new_config

    def evaluate(self, parameters, config):
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        self._load_model_state()
        # This part looks a bit sus. Why are we loading and then setting params to be the same?
        current_state = get_parameters(self.net)
        set_parameters(self.net, current_state)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

    def _get_optimizer_state(self):
        """Extract only the necessary parts of optimizer state"""
        state_dict = self.optimizer.state_dict()
        # Extract only essential information
        pruned_state = {'state': {}, 'param_groups': state_dict['param_groups']}
        
        for k, v in state_dict['state'].items():
            pruned_state['state'][k] = {
                'exp_avg': v['exp_avg'].clone().detach().cpu(),
                'exp_avg_sq': v['exp_avg_sq'].clone().detach().cpu(),
                'step': v['step']
            }
        
        return pruned_state

    def _set_optimizer_state(self, state_dict):
        """Load optimizer state efficiently"""
        # Move tensors to the appropriate device first
        for k in state_dict['state']:
            for key in ['exp_avg', 'exp_avg_sq']:
                if key in state_dict['state'][k]:
                    state_dict['state'][k][key] = state_dict['state'][k][key].to(DEVICE)
        
        self.optimizer.load_state_dict(state_dict)
    

def client_fn(context: Context) -> Client:
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]

    # Initialize network if not in context
    if not hasattr(context, 'net'):
        context.net = Net().to(DEVICE)
    
    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)

    return FedAvgPartWithVelocityFlowerClient(partition_id, context.net, trainloader, valloader, context).to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)

In [35]:
net = Net().to(DEVICE)

_, _, testloader = load_datasets(0, NUM_PARTITIONS)

evaluate_fn = get_evaluate_fn(testloader, net)

def server_fn(context: Context) -> ServerAppComponents:
    # Configure the server for just 3 rounds of training
    config = ServerConfig(num_rounds=NUM_OF_ROUNDS)
    return ServerAppComponents(
        config=config,
        strategy=FedPartAvgWithVelocity(
            evaluate_fn=evaluate_fn,
        ),
    )

server = ServerApp(server_fn=server_fn)

# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=11, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters



==== Server-side evaluation for round 0 ====


[92mINFO [0m:      initial parameters (loss, other metrics): 0.07213468878269196, {'accuracy': 0.0917}
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)


  Evaluation results - Loss: 0.0721, Accuracy: 0.0917
Training on layer [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
[36m(ClientAppActor pid=30196)[0m Saving layer 0 giving shape (6, 3, 5, 5)
[36m(ClientAppActor pid=30196)[0m Saving layer 1 giving shape (6,)
[36m(ClientAppActor pid=30196)[0m Saving layer 2 giving shape (16, 6, 5, 5)
[36m(ClientAppActor pid=30196)[0m Saving layer 3 giving shape (16,)
[36m(ClientAppActor pid=30196)[0m Saving layer 4 giving shape (120, 400)
[36m(ClientAppActor pid=30196)[0m Saving layer 5 giving shape (120,)
[36m(ClientAppActor pid=30196)[0m Saving layer 6 giving shape (84, 120)
[36m(ClientAppActor pid=30196)[0m Saving layer 7 giving shape (84,)
[36m(ClientAppActor pid=30196)[0m Saving layer 8 giving shape (10, 84)
[36m(ClientAppActor pid=30196)[0m Saving layer 9 giving shape (10,)
[36m(ClientAppActor pid=30196)[0m [Client 5] fit
[36m(ClientAppActor pid=30196)[0m Loading layer 0 with shape (6, 3, 5, 5)
[36m(ClientAppActor pid=30196)[0m Loa

[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures


total size: 2999136
aggregated weight size 10 

==== Server-side evaluation for round 1 ====
  Parameter 0: Changed by 0.033680
  Parameter 1: Changed by 0.045348
  Parameter 2: Changed by 0.027236
  Parameter 3: Changed by 0.060290
  Parameter 4: Changed by 0.015780
  Parameter 5: Changed by 0.034241
  Parameter 6: Changed by 0.011974
  Parameter 7: Changed by 0.036609
  Parameter 8: Changed by 0.022635
  Parameter 9: Changed by 0.044100


[92mINFO [0m:      fit progress: (1, 0.053068926656246185, {'accuracy': 0.4059}, 27.912993625000126)
[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)


  Evaluation results - Loss: 0.0531, Accuracy: 0.4059
[36m(ClientAppActor pid=30201)[0m [Client 3] evaluate, config: {}
[36m(ClientAppActor pid=30201)[0m Loading layer 0 with shape (6, 3, 5, 5)
[36m(ClientAppActor pid=30201)[0m Loading layer 1 with shape (6,)
[36m(ClientAppActor pid=30201)[0m Loading layer 2 with shape (16, 6, 5, 5)
[36m(ClientAppActor pid=30201)[0m Loading layer 3 with shape (16,)
[36m(ClientAppActor pid=30201)[0m Loading layer 4 with shape (120, 400)
[36m(ClientAppActor pid=30201)[0m Loading layer 5 with shape (120,)
[36m(ClientAppActor pid=30201)[0m Loading layer 6 with shape (84, 120)
[36m(ClientAppActor pid=30201)[0m Loading layer 7 with shape (84,)
[36m(ClientAppActor pid=30201)[0m Loading layer 8 with shape (10, 84)
[36m(ClientAppActor pid=30201)[0m Loading layer 9 with shape (10,)
[36m(ClientAppActor pid=30201)[0m Loading model. Set parameters to net with 10 params in total.
[36m(ClientAppActor pid=30201)[0m Loading param 0 with shape 

[92mINFO [0m:      aggregate_evaluate: received 6 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)


Training on layer [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
[36m(ClientAppActor pid=30197)[0m [Client 2] fit
[36m(ClientAppActor pid=30197)[0m layer index is 0 and nameconv1.weight is trainabe
[36m(ClientAppActor pid=30197)[0m layer index is 1 and nameconv1.bias is trainabe
[36m(ClientAppActor pid=30197)[0m layer index is 2 and nameconv2.weight is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 3 and nameconv2.bias is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 4 and namefc1.weight is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 5 and namefc1.bias is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 6 and namefc2.weight is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 7 and namefc2.bias is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 8 and namefc3.weight is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 9 and namefc3.bias is frozen
[36m(ClientAppActor pid=30197)[0m Got optimizer state in config, setting

[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures


total size: 1513632
aggregated weight size 2 
updateing layers 0  and 1 

==== Server-side evaluation for round 2 ====
  Parameter 0: Changed by 0.034049
  Parameter 1: Changed by 0.048975


[92mINFO [0m:      fit progress: (2, 0.04913188210725784, {'accuracy': 0.4211}, 48.16628341600017)
[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)


  Evaluation results - Loss: 0.0491, Accuracy: 0.4211
[36m(ClientAppActor pid=30200)[0m [Client 5] evaluate, config: {}
[36m(ClientAppActor pid=30200)[0m Loading layer 0 with shape (6, 3, 5, 5)
[36m(ClientAppActor pid=30200)[0m Loading layer 1 with shape (6,)
[36m(ClientAppActor pid=30200)[0m Loading layer 2 with shape (16, 6, 5, 5)
[36m(ClientAppActor pid=30200)[0m Loading layer 3 with shape (16,)
[36m(ClientAppActor pid=30200)[0m Loading layer 4 with shape (120, 400)
[36m(ClientAppActor pid=30200)[0m Loading layer 5 with shape (120,)
[36m(ClientAppActor pid=30200)[0m Loading layer 6 with shape (84, 120)
[36m(ClientAppActor pid=30200)[0m Loading layer 7 with shape (84,)
[36m(ClientAppActor pid=30200)[0m Loading layer 8 with shape (10, 84)
[36m(ClientAppActor pid=30200)[0m Loading layer 9 with shape (10,)
[36m(ClientAppActor pid=30200)[0m Loading model. Set parameters to net with 10 params in total.
[36m(ClientAppActor pid=30200)[0m Loading param 0 with shape 

[36m(ClientAppActor pid=30199)[0m Using the latest cached version of the dataset since cifar10 couldn't be found on the Hugging Face Hub
[36m(ClientAppActor pid=30199)[0m Found the latest cached dataset configuration 'plain_text' at /Users/asbjornlorenzen/.cache/huggingface/datasets/cifar10/plain_text/0.0.0/0b2714987fa478483af9968de7c934580d0bb9a2 (last modified on Tue Mar 11 10:39:48 2025).


[36m(ClientAppActor pid=30196)[0m [Client 0] evaluate, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=30196)[0m Loading layer 9 with shape (10,)[32m [repeated 20x across cluster][0m
[36m(ClientAppActor pid=30196)[0m Loading model. Set parameters to net with 10 params in total.[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=30196)[0m Loading param 9 with shape (10,)[32m [repeated 20x across cluster][0m


[92mINFO [0m:      aggregate_evaluate: received 6 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)


Training on layer [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
[36m(ClientAppActor pid=30197)[0m [Client 5] fit
[36m(ClientAppActor pid=30197)[0m layer index is 0 and nameconv1.weight is trainabe
[36m(ClientAppActor pid=30197)[0m layer index is 1 and nameconv1.bias is trainabe
[36m(ClientAppActor pid=30197)[0m layer index is 2 and nameconv2.weight is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 3 and nameconv2.bias is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 4 and namefc1.weight is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 5 and namefc1.bias is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 6 and namefc2.weight is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 7 and namefc2.bias is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 8 and namefc3.weight is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 9 and namefc3.bias is frozen
[36m(ClientAppActor pid=30197)[0m Got optimizer state in config, setting

[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures


[36m(ClientAppActor pid=30200)[0m Saving layer 9 giving shape (10,)[32m [repeated 50x across cluster][0m
[36m(ClientAppActor pid=30200)[0m finished train[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=30200)[0m After training, got optim state gAWVoi4AAAAAAAB9lCiMBXN0YXRllH2UKEsAfZQojAdleHBfYX[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=30200)[0m [Client 4] get_parameters[32m [repeated 5x across cluster][0m
total size: 28128
aggregated weight size 2 
updateing layers 0  and 1 

==== Server-side evaluation for round 3 ====
  Parameter 0: Changed by 0.017936
  Parameter 1: Changed by 0.033001


[92mINFO [0m:      fit progress: (3, 0.04871329727172852, {'accuracy': 0.4238}, 89.31518787499999)
[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)


  Evaluation results - Loss: 0.0487, Accuracy: 0.4238
[36m(ClientAppActor pid=30197)[0m [Client 5] evaluate, config: {}
[36m(ClientAppActor pid=30197)[0m Loading layer 0 with shape (6, 3, 5, 5)
[36m(ClientAppActor pid=30197)[0m Loading layer 1 with shape (6,)
[36m(ClientAppActor pid=30197)[0m Loading layer 2 with shape (16, 6, 5, 5)
[36m(ClientAppActor pid=30197)[0m Loading layer 3 with shape (16,)
[36m(ClientAppActor pid=30197)[0m Loading layer 4 with shape (120, 400)
[36m(ClientAppActor pid=30197)[0m Loading layer 5 with shape (120,)
[36m(ClientAppActor pid=30197)[0m Loading layer 6 with shape (84, 120)
[36m(ClientAppActor pid=30197)[0m Loading layer 7 with shape (84,)
[36m(ClientAppActor pid=30197)[0m Loading layer 8 with shape (10, 84)
[36m(ClientAppActor pid=30197)[0m Loading layer 9 with shape (10,)
[36m(ClientAppActor pid=30197)[0m Loading model. Set parameters to net with 10 params in total.
[36m(ClientAppActor pid=30197)[0m Loading param 0 with shape 

[92mINFO [0m:      aggregate_evaluate: received 6 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)


Training on layer [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
[36m(ClientAppActor pid=30196)[0m [Client 2] fit
[36m(ClientAppActor pid=30196)[0m layer index is 0 and nameconv1.weight is frozen
[36m(ClientAppActor pid=30196)[0m layer index is 1 and nameconv1.bias is frozen
[36m(ClientAppActor pid=30196)[0m layer index is 2 and nameconv2.weight is trainabe
[36m(ClientAppActor pid=30196)[0m layer index is 3 and nameconv2.bias is trainabe
[36m(ClientAppActor pid=30196)[0m layer index is 4 and namefc1.weight is frozen
[36m(ClientAppActor pid=30196)[0m layer index is 5 and namefc1.bias is frozen
[36m(ClientAppActor pid=30196)[0m layer index is 6 and namefc2.weight is frozen
[36m(ClientAppActor pid=30196)[0m layer index is 7 and namefc2.bias is frozen
[36m(ClientAppActor pid=30196)[0m layer index is 8 and namefc3.weight is frozen
[36m(ClientAppActor pid=30196)[0m layer index is 9 and namefc3.bias is frozen
[36m(ClientAppActor pid=30196)[0m Got optimizer state in config, setting

[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures


total size: 75168
aggregated weight size 2 
updateing layers 2  and 3 

==== Server-side evaluation for round 4 ====
  Parameter 2: Changed by 0.021095
  Parameter 3: Changed by 0.032073


[92mINFO [0m:      fit progress: (4, 0.04646151665449143, {'accuracy': 0.4528}, 113.24806099999978)
[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)


  Evaluation results - Loss: 0.0465, Accuracy: 0.4528
[36m(ClientAppActor pid=30196)[0m [Client 3] evaluate, config: {}
[36m(ClientAppActor pid=30196)[0m Loading layer 0 with shape (6, 3, 5, 5)
[36m(ClientAppActor pid=30196)[0m Loading layer 1 with shape (6,)
[36m(ClientAppActor pid=30196)[0m Loading layer 2 with shape (16, 6, 5, 5)
[36m(ClientAppActor pid=30196)[0m Loading layer 3 with shape (16,)
[36m(ClientAppActor pid=30196)[0m Loading layer 4 with shape (120, 400)
[36m(ClientAppActor pid=30196)[0m Loading layer 5 with shape (120,)
[36m(ClientAppActor pid=30196)[0m Loading layer 6 with shape (84, 120)
[36m(ClientAppActor pid=30196)[0m Loading layer 7 with shape (84,)
[36m(ClientAppActor pid=30196)[0m Loading layer 8 with shape (10, 84)
[36m(ClientAppActor pid=30196)[0m Loading layer 9 with shape (10,)
[36m(ClientAppActor pid=30196)[0m Loading model. Set parameters to net with 10 params in total.
[36m(ClientAppActor pid=30196)[0m Loading param 0 with shape 

[92mINFO [0m:      aggregate_evaluate: received 6 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)


Training on layer [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
[36m(ClientAppActor pid=30199)[0m [Client 0] fit
[36m(ClientAppActor pid=30199)[0m layer index is 0 and nameconv1.weight is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 1 and nameconv1.bias is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 2 and nameconv2.weight is trainabe
[36m(ClientAppActor pid=30199)[0m layer index is 3 and nameconv2.bias is trainabe
[36m(ClientAppActor pid=30199)[0m layer index is 4 and namefc1.weight is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 5 and namefc1.bias is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 6 and namefc2.weight is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 7 and namefc2.bias is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 8 and namefc3.weight is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 9 and namefc3.bias is frozen
[36m(ClientAppActor pid=30199)[0m Got optimizer state in config, setting

[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures


total size: 122208
aggregated weight size 2 
updateing layers 2  and 3 

==== Server-side evaluation for round 5 ====
  Parameter 2: Changed by 0.013361
  Parameter 3: Changed by 0.033342


[92mINFO [0m:      fit progress: (5, 0.04581130701303482, {'accuracy': 0.4577}, 144.9998774579999)
[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)


  Evaluation results - Loss: 0.0458, Accuracy: 0.4577
[36m(ClientAppActor pid=30197)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=30197)[0m Loading layer 0 with shape (6, 3, 5, 5)
[36m(ClientAppActor pid=30197)[0m Loading layer 1 with shape (6,)
[36m(ClientAppActor pid=30197)[0m Loading layer 2 with shape (16, 6, 5, 5)
[36m(ClientAppActor pid=30197)[0m Loading layer 3 with shape (16,)
[36m(ClientAppActor pid=30197)[0m Loading layer 4 with shape (120, 400)
[36m(ClientAppActor pid=30197)[0m Loading layer 5 with shape (120,)
[36m(ClientAppActor pid=30197)[0m Loading layer 6 with shape (84, 120)
[36m(ClientAppActor pid=30197)[0m Loading layer 7 with shape (84,)
[36m(ClientAppActor pid=30197)[0m Loading layer 8 with shape (10, 84)
[36m(ClientAppActor pid=30197)[0m Loading layer 9 with shape (10,)
[36m(ClientAppActor pid=30197)[0m Loading model. Set parameters to net with 10 params in total.
[36m(ClientAppActor pid=30197)[0m Loading param 0 with shape 

[92mINFO [0m:      aggregate_evaluate: received 6 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)


Training on layer [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
[36m(ClientAppActor pid=30199)[0m [Client 2] fit
[36m(ClientAppActor pid=30199)[0m layer index is 0 and nameconv1.weight is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 1 and nameconv1.bias is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 2 and nameconv2.weight is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 3 and nameconv2.bias is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 4 and namefc1.weight is trainabe
[36m(ClientAppActor pid=30199)[0m layer index is 5 and namefc1.bias is trainabe
[36m(ClientAppActor pid=30199)[0m layer index is 6 and namefc2.weight is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 7 and namefc2.bias is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 8 and namefc3.weight is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 9 and namefc3.bias is frozen
[36m(ClientAppActor pid=30199)[0m Got optimizer state in config, setting

[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures


total size: 1219104
aggregated weight size 2 
updateing layers 4  and 5 

==== Server-side evaluation for round 6 ====
  Parameter 4: Changed by 0.046805
  Parameter 5: Changed by 0.020523


[92mINFO [0m:      fit progress: (6, 0.04234300102591514, {'accuracy': 0.5049}, 166.54411583299998)
[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)


  Evaluation results - Loss: 0.0423, Accuracy: 0.5049
[36m(ClientAppActor pid=30201)[0m [Client 4] evaluate, config: {}
[36m(ClientAppActor pid=30201)[0m Loading layer 0 with shape (6, 3, 5, 5)
[36m(ClientAppActor pid=30201)[0m Loading layer 1 with shape (6,)
[36m(ClientAppActor pid=30201)[0m Loading layer 2 with shape (16, 6, 5, 5)
[36m(ClientAppActor pid=30201)[0m Loading layer 3 with shape (16,)
[36m(ClientAppActor pid=30201)[0m Loading layer 4 with shape (120, 400)
[36m(ClientAppActor pid=30201)[0m Loading layer 5 with shape (120,)
[36m(ClientAppActor pid=30201)[0m Loading layer 6 with shape (84, 120)
[36m(ClientAppActor pid=30201)[0m Loading layer 7 with shape (84,)
[36m(ClientAppActor pid=30201)[0m Loading layer 8 with shape (10, 84)
[36m(ClientAppActor pid=30201)[0m Loading layer 9 with shape (10,)
[36m(ClientAppActor pid=30201)[0m Loading model. Set parameters to net with 10 params in total.
[36m(ClientAppActor pid=30201)[0m Loading param 0 with shape 

[92mINFO [0m:      aggregate_evaluate: received 6 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)


Training on layer [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
[36m(ClientAppActor pid=30199)[0m [Client 1] fit
[36m(ClientAppActor pid=30199)[0m layer index is 0 and nameconv1.weight is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 1 and nameconv1.bias is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 2 and nameconv2.weight is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 3 and nameconv2.bias is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 4 and namefc1.weight is trainabe
[36m(ClientAppActor pid=30199)[0m layer index is 5 and namefc1.bias is trainabe
[36m(ClientAppActor pid=30199)[0m layer index is 6 and namefc2.weight is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 7 and namefc2.bias is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 8 and namefc3.weight is frozen
[36m(ClientAppActor pid=30199)[0m layer index is 9 and namefc3.bias is frozen
[36m(ClientAppActor pid=30199)[0m Got optimizer state in config, setting

[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures


[36m(ClientAppActor pid=30197)[0m Saving layer 0 giving shape (6, 3, 5, 5)
[36m(ClientAppActor pid=30197)[0m Saving layer 1 giving shape (6,)
[36m(ClientAppActor pid=30197)[0m Saving layer 2 giving shape (16, 6, 5, 5)
[36m(ClientAppActor pid=30197)[0m Saving layer 3 giving shape (16,)
[36m(ClientAppActor pid=30197)[0m Saving layer 4 giving shape (120, 400)
[36m(ClientAppActor pid=30197)[0m Saving layer 5 giving shape (120,)
[36m(ClientAppActor pid=30197)[0m Saving layer 6 giving shape (84, 120)
[36m(ClientAppActor pid=30197)[0m Saving layer 7 giving shape (84,)
[36m(ClientAppActor pid=30197)[0m Saving layer 8 giving shape (10, 84)
[36m(ClientAppActor pid=30197)[0m Saving layer 9 giving shape (10,)
[36m(ClientAppActor pid=30197)[0m finished train
[36m(ClientAppActor pid=30197)[0m After training, got optim state gAWVoi4AAAAAAAB9lCiMBXN0YXRllH2UKEsAfZQojAdleHBfYX
[36m(ClientAppActor pid=30197)[0m [Client 3] get_parameters
total size: 2316000
aggregated weight siz

[92mINFO [0m:      fit progress: (7, 0.04120952035784722, {'accuracy': 0.5213}, 201.1755780829999)
[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)


  Evaluation results - Loss: 0.0412, Accuracy: 0.5213
[36m(ClientAppActor pid=30196)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=30196)[0m Loading layer 0 with shape (6, 3, 5, 5)
[36m(ClientAppActor pid=30196)[0m Loading layer 1 with shape (6,)
[36m(ClientAppActor pid=30196)[0m Loading layer 2 with shape (16, 6, 5, 5)
[36m(ClientAppActor pid=30196)[0m Loading layer 3 with shape (16,)
[36m(ClientAppActor pid=30196)[0m Loading layer 4 with shape (120, 400)
[36m(ClientAppActor pid=30196)[0m Loading layer 5 with shape (120,)
[36m(ClientAppActor pid=30196)[0m Loading layer 6 with shape (84, 120)
[36m(ClientAppActor pid=30196)[0m Loading layer 7 with shape (84,)
[36m(ClientAppActor pid=30196)[0m Loading layer 8 with shape (10, 84)
[36m(ClientAppActor pid=30196)[0m Loading layer 9 with shape (10,)
[36m(ClientAppActor pid=30196)[0m Loading model. Set parameters to net with 10 params in total.
[36m(ClientAppActor pid=30196)[0m Loading param 0 with shape 

[92mINFO [0m:      aggregate_evaluate: received 6 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 8]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)


Training on layer [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
[36m(ClientAppActor pid=30197)[0m [Client 4] fit
[36m(ClientAppActor pid=30197)[0m layer index is 0 and nameconv1.weight is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 1 and nameconv1.bias is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 2 and nameconv2.weight is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 3 and nameconv2.bias is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 4 and namefc1.weight is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 5 and namefc1.bias is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 6 and namefc2.weight is trainabe
[36m(ClientAppActor pid=30197)[0m layer index is 7 and namefc2.bias is trainabe
[36m(ClientAppActor pid=30197)[0m layer index is 8 and namefc3.weight is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 9 and namefc3.bias is frozen
[36m(ClientAppActor pid=30197)[0m Got optimizer state in config, setting

[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures


total size: 1405056
aggregated weight size 2 
updateing layers 6  and 7 

==== Server-side evaluation for round 8 ====
  Parameter 6: Changed by 0.082426
  Parameter 7: Changed by 0.018952


[92mINFO [0m:      fit progress: (8, 0.040402452236413956, {'accuracy': 0.5329}, 221.13410449999992)
[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)


  Evaluation results - Loss: 0.0404, Accuracy: 0.5329
[36m(ClientAppActor pid=30199)[0m [Client 1] evaluate, config: {}
[36m(ClientAppActor pid=30199)[0m Loading layer 0 with shape (6, 3, 5, 5)
[36m(ClientAppActor pid=30199)[0m Loading layer 1 with shape (6,)
[36m(ClientAppActor pid=30199)[0m Loading layer 2 with shape (16, 6, 5, 5)
[36m(ClientAppActor pid=30199)[0m Loading layer 3 with shape (16,)
[36m(ClientAppActor pid=30199)[0m Loading layer 4 with shape (120, 400)
[36m(ClientAppActor pid=30199)[0m Loading layer 5 with shape (120,)
[36m(ClientAppActor pid=30199)[0m Loading layer 6 with shape (84, 120)
[36m(ClientAppActor pid=30199)[0m Loading layer 7 with shape (84,)
[36m(ClientAppActor pid=30199)[0m Loading layer 8 with shape (10, 84)
[36m(ClientAppActor pid=30199)[0m Loading layer 9 with shape (10,)
[36m(ClientAppActor pid=30199)[0m Loading model. Set parameters to net with 10 params in total.
[36m(ClientAppActor pid=30199)[0m Loading param 0 with shape 

[92mINFO [0m:      aggregate_evaluate: received 6 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 9]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)


Training on layer [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
[36m(ClientAppActor pid=30196)[0m [Client 4] fit
[36m(ClientAppActor pid=30196)[0m layer index is 0 and nameconv1.weight is frozen
[36m(ClientAppActor pid=30196)[0m layer index is 1 and nameconv1.bias is frozen
[36m(ClientAppActor pid=30196)[0m layer index is 2 and nameconv2.weight is frozen
[36m(ClientAppActor pid=30196)[0m layer index is 3 and nameconv2.bias is frozen
[36m(ClientAppActor pid=30196)[0m layer index is 4 and namefc1.weight is frozen
[36m(ClientAppActor pid=30196)[0m layer index is 5 and namefc1.bias is frozen
[36m(ClientAppActor pid=30196)[0m layer index is 6 and namefc2.weight is trainabe
[36m(ClientAppActor pid=30196)[0m layer index is 7 and namefc2.bias is trainabe
[36m(ClientAppActor pid=30196)[0m layer index is 8 and namefc3.weight is frozen
[36m(ClientAppActor pid=30196)[0m layer index is 9 and namefc3.bias is frozen
[36m(ClientAppActor pid=30196)[0m Got optimizer state in config, setting

[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures


total size: 494112
aggregated weight size 2 
updateing layers 6  and 7 

==== Server-side evaluation for round 9 ====
  Parameter 6: Changed by 0.011882
  Parameter 7: Changed by 0.018287


[92mINFO [0m:      fit progress: (9, 0.04015083453059196, {'accuracy': 0.5364}, 253.50491112500004)
[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)


  Evaluation results - Loss: 0.0402, Accuracy: 0.5364
[36m(ClientAppActor pid=30196)[0m [Client 5] evaluate, config: {}
[36m(ClientAppActor pid=30196)[0m Loading layer 0 with shape (6, 3, 5, 5)
[36m(ClientAppActor pid=30196)[0m Loading layer 1 with shape (6,)
[36m(ClientAppActor pid=30196)[0m Loading layer 2 with shape (16, 6, 5, 5)
[36m(ClientAppActor pid=30196)[0m Loading layer 3 with shape (16,)
[36m(ClientAppActor pid=30196)[0m Loading layer 4 with shape (120, 400)
[36m(ClientAppActor pid=30196)[0m Loading layer 5 with shape (120,)
[36m(ClientAppActor pid=30196)[0m Loading layer 6 with shape (84, 120)
[36m(ClientAppActor pid=30196)[0m Loading layer 7 with shape (84,)
[36m(ClientAppActor pid=30196)[0m Loading layer 8 with shape (10, 84)
[36m(ClientAppActor pid=30196)[0m Loading layer 9 with shape (10,)
[36m(ClientAppActor pid=30196)[0m Loading model. Set parameters to net with 10 params in total.
[36m(ClientAppActor pid=30196)[0m Loading param 0 with shape 

[92mINFO [0m:      aggregate_evaluate: received 6 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 10]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)


Training on layer [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
[36m(ClientAppActor pid=30201)[0m [Client 4] fit
[36m(ClientAppActor pid=30201)[0m layer index is 0 and nameconv1.weight is frozen
[36m(ClientAppActor pid=30201)[0m layer index is 1 and nameconv1.bias is frozen
[36m(ClientAppActor pid=30201)[0m layer index is 2 and nameconv2.weight is frozen
[36m(ClientAppActor pid=30201)[0m layer index is 3 and nameconv2.bias is frozen
[36m(ClientAppActor pid=30201)[0m layer index is 4 and namefc1.weight is frozen
[36m(ClientAppActor pid=30201)[0m layer index is 5 and namefc1.bias is frozen
[36m(ClientAppActor pid=30201)[0m layer index is 6 and namefc2.weight is frozen
[36m(ClientAppActor pid=30201)[0m layer index is 7 and namefc2.bias is frozen
[36m(ClientAppActor pid=30201)[0m layer index is 8 and namefc3.weight is trainabe
[36m(ClientAppActor pid=30201)[0m layer index is 9 and namefc3.bias is trainabe
[36m(ClientAppActor pid=30201)[0m Got optimizer state in config, setting

[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures


[36m(ClientAppActor pid=30201)[0m Saving layer 0 giving shape (6, 3, 5, 5)
[36m(ClientAppActor pid=30201)[0m Saving layer 1 giving shape (6,)
[36m(ClientAppActor pid=30201)[0m Saving layer 2 giving shape (16, 6, 5, 5)
[36m(ClientAppActor pid=30201)[0m Saving layer 3 giving shape (16,)
[36m(ClientAppActor pid=30201)[0m Saving layer 4 giving shape (120, 400)
[36m(ClientAppActor pid=30201)[0m Saving layer 5 giving shape (120,)
[36m(ClientAppActor pid=30201)[0m Saving layer 6 giving shape (84, 120)
[36m(ClientAppActor pid=30201)[0m Saving layer 7 giving shape (84,)
[36m(ClientAppActor pid=30201)[0m Saving layer 8 giving shape (10, 84)
[36m(ClientAppActor pid=30201)[0m Saving layer 9 giving shape (10,)
[36m(ClientAppActor pid=30201)[0m finished train
[36m(ClientAppActor pid=30201)[0m After training, got optim state gAWVoi4AAAAAAAB9lCiMBXN0YXRllH2UKEsAfZQojAdleHBfYX
[36m(ClientAppActor pid=30201)[0m [Client 4] get_parameters
total size: 270576
aggregated weight size

[92mINFO [0m:      fit progress: (10, 0.0400597141802311, {'accuracy': 0.5396}, 276.90542562500013)
[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)


  Evaluation results - Loss: 0.0401, Accuracy: 0.5396
[36m(ClientAppActor pid=30200)[0m [Client 1] evaluate, config: {}
[36m(ClientAppActor pid=30200)[0m Loading layer 0 with shape (6, 3, 5, 5)
[36m(ClientAppActor pid=30200)[0m Loading layer 1 with shape (6,)
[36m(ClientAppActor pid=30200)[0m Loading layer 2 with shape (16, 6, 5, 5)
[36m(ClientAppActor pid=30200)[0m Loading layer 3 with shape (16,)
[36m(ClientAppActor pid=30200)[0m Loading layer 4 with shape (120, 400)
[36m(ClientAppActor pid=30200)[0m Loading layer 5 with shape (120,)
[36m(ClientAppActor pid=30200)[0m Loading layer 6 with shape (84, 120)
[36m(ClientAppActor pid=30200)[0m Loading layer 7 with shape (84,)
[36m(ClientAppActor pid=30200)[0m Loading layer 8 with shape (10, 84)
[36m(ClientAppActor pid=30200)[0m Loading layer 9 with shape (10,)
[36m(ClientAppActor pid=30200)[0m Loading model. Set parameters to net with 10 params in total.
[36m(ClientAppActor pid=30200)[0m Loading param 0 with shape 

[92mINFO [0m:      aggregate_evaluate: received 6 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 11]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)


Training on layer [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
[36m(ClientAppActor pid=30197)[0m [Client 4] fit
[36m(ClientAppActor pid=30197)[0m layer index is 0 and nameconv1.weight is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 1 and nameconv1.bias is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 2 and nameconv2.weight is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 3 and nameconv2.bias is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 4 and namefc1.weight is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 5 and namefc1.bias is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 6 and namefc2.weight is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 7 and namefc2.bias is frozen
[36m(ClientAppActor pid=30197)[0m layer index is 8 and namefc3.weight is trainabe
[36m(ClientAppActor pid=30197)[0m layer index is 9 and namefc3.bias is trainabe
[36m(ClientAppActor pid=30197)[0m Got optimizer state in config, setting

[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures


[36m(ClientAppActor pid=30196)[0m Saving layer 9 giving shape (10,)[32m [repeated 50x across cluster][0m
[36m(ClientAppActor pid=30196)[0m finished train[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=30199)[0m After training, got optim state gAWVoi4AAAAAAAB9lCiMBXN0YXRllH2UKEsAfZQojAdleHBfYX[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=30199)[0m [Client 1] get_parameters[32m [repeated 4x across cluster][0m
total size: 47040
aggregated weight size 2 
updateing layers 8  and 9 

==== Server-side evaluation for round 11 ====
  Parameter 8: Changed by 0.006865
  Parameter 9: Changed by 0.020202


[92mINFO [0m:      fit progress: (11, 0.04002404733896255, {'accuracy': 0.5385}, 315.4028872079998)
[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)


  Evaluation results - Loss: 0.0400, Accuracy: 0.5385
[36m(ClientAppActor pid=30198)[0m [Client 4] evaluate, config: {}
[36m(ClientAppActor pid=30198)[0m Loading layer 0 with shape (6, 3, 5, 5)
[36m(ClientAppActor pid=30198)[0m Loading layer 1 with shape (6,)
[36m(ClientAppActor pid=30198)[0m Loading layer 2 with shape (16, 6, 5, 5)
[36m(ClientAppActor pid=30196)[0m Epoch 4: train loss 0.038429856300354004, accuracy 0.5534053405340534
[36m(ClientAppActor pid=30196)[0m After training, got optim state gAWVoi4AAAAAAAB9lCiMBXN0YXRllH2UKEsAfZQojAdleHBfYX
[36m(ClientAppActor pid=30196)[0m [Client 5] get_parameters
[36m(ClientAppActor pid=30198)[0m 
[36m(ClientAppActor pid=30198)[0m Loading layer 3 with shape (16,)
[36m(ClientAppActor pid=30198)[0m Loading layer 4 with shape (120, 400)
[36m(ClientAppActor pid=30198)[0m Loading layer 5 with shape (120,)
[36m(ClientAppActor pid=30198)[0m Loading layer 6 with shape (84, 120)
[36m(ClientAppActor pid=30198)[0m Loading lay

[92mINFO [0m:      aggregate_evaluate: received 6 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 11 round(s) in 319.07s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.04827920325253873
[92mINFO [0m:      		round 2: 0.05045354747695938
[92mINFO [0m:      		round 3: 0.049991784227344904
[92mINFO [0m:      		round 4: 0.04814031727312565
[92mINFO [0m:      		round 5: 0.04728299412124754
[92mINFO [0m:      		round 6: 0.04506685741065479
[92mINFO [0m:      		round 7: 0.043642411462499295
[92mINFO [0m:      		round 8: 0.042221576943799896
[92mINFO [0m:      		round 9: 0.042044280448858086
[92mINFO [0m:      		round 10: 0.041684713707020746
[92mINFO [0m:      		round 11: 0.04164603814092357
[92mINFO [0m:      	History (loss, centralized):
[92mINFO [0m:      		round 0: 0.07213468878269196
[92mINFO [0m:      		round 1: 0.053068926656246185
[92mINFO [0m:      		round 

[36m(ClientAppActor pid=30197)[0m [Client 5] evaluate, config: {}[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=30197)[0m Loading layer 9 with shape (10,)[32m [repeated 50x across cluster][0m
[36m(ClientAppActor pid=30197)[0m Loading model. Set parameters to net with 10 params in total.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=30197)[0m Loading param 9 with shape (10,)[32m [repeated 50x across cluster][0m


In [None]:
with open(f'results/fed_part_momentum_avg_result.p', 'wb') as file:
    pickle.dump(fed_part_avg_result, file)

with open(f'results/fed_part_avg_momentum_model_results.p', 'wb') as file:
    pickle.dump(fed_part_avg_model_results, file)

In [None]:
import matplotlib.pyplot as plt
import numpy as np



# # Plot the total size of parameters for each round
# fed_part_avg_rounds = list(fed_part_avg_result.keys())
# fed_part_avg_sizes = [fed_part_avg_result[round]["total_size"] for round in fed_part_avg_rounds]

# plt.figure(figsize=(10, 5))
# plt.plot(fed_part_avg_rounds, fed_part_avg_sizes, marker='o', linestyle='-', color='b', label='FedPartAvgMomentum')
# # plt.plot(fed_avg_rounds, fed_avg_sizes, marker='o', linestyle='-', color='r', label='FedAvg')
# plt.xlabel('Round')
# plt.ylabel('Total Size of Parameters (bytes)')
# plt.title('Total Size of Parameters for Each Round')
# plt.legend()
# plt.grid(True)

# fed_part_avg_losses = [fed_part_avg_result[round]["total_loss"] for round in fed_part_avg_rounds]

# plt.figure(figsize=(10, 5))
# plt.plot(fed_part_avg_rounds, fed_part_avg_losses, marker='o', linestyle='-', color='b', label='FedPartAvgMomentum')
# # plt.plot(fed_avg_rounds, fed_avg_losses, marker='o', linestyle='-', color='r', label='FedAvg')
# plt.xlabel('Round')
# plt.ylabel('Total Loss')
# plt.title('Total Loss for Each Round')
# plt.legend()
# plt.grid(True)


# fed_part_avg_model_rounds = list(fed_part_avg_model_results.keys())
# fed_part_avg_accuracies = [fed_part_avg_model_results[round]["global_metrics"]["accuracy"] for round in fed_part_avg_model_rounds]

# plt.figure(figsize=(10, 5))
# plt.plot(fed_part_avg_model_rounds, fed_part_avg_accuracies, marker='o', linestyle='-', color='b', label='FedPartAvgMomentum')
# # plt.plot(fed_avg_model_rounds, fed_avg_accuracies, marker='o', linestyle='-', color='r', label='FedAvg')
# plt.xlabel('Round')
# plt.ylabel('Accuracy')
# plt.title('Accuracy for Each Round')
# plt.legend()
# plt.grid(True)

# fed_part_avg_global_losses = [fed_part_avg_model_results[round]["global_loss"] for round in fed_part_avg_model_rounds]

# plt.figure(figsize=(10, 5))
# plt.plot(fed_part_avg_model_rounds, fed_part_avg_global_losses, marker='o', linestyle='-', color='b', label='FedPartAvgMomentum')
# # plt.plot(fed_avg_model_rounds, fed_avg_global_losses, marker='o', linestyle='-', color='r', label='FedAvg')
# plt.xlabel('Round')
# plt.ylabel('Loss')
# plt.legend()
# plt.title('Loss for Each Round')


# Using momentum similarity for weighing updates
Weigh updates by cosine similarity between client momentum and global momentum from previous round. 

Future work ideas:
- Save cosine similarities. Use these to compute the distributions of these similarities and identify
- Use cosine similarities to identify outlier clients.
This is likely of greater interest in heterogeneous data settings.

In [None]:
from typing import Union
from functools import partial, reduce
from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg

fed_momentum_weighting_results = {}
fed_momentum_weighting_model_results = {}

global client_opt
global prev_opt

class FedPartAvgWithVelocityweighting(Strategy):
    def __init__(
        self,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                Optional[tuple[float, dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        inplace: bool = True,
        layer_update_strategy: str = "sequential",
        
    ) -> None:
        super().__init__()
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients
        self.evaluate_fn = evaluate_fn
        self.on_fit_config_fn = on_fit_config_fn
        self.on_evaluate_config_fn = on_evaluate_config_fn
        self.accept_failures = accept_failures
        self.initial_parameters = initial_parameters
        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
        self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn
        self.inplace = inplace

        self.layer_update_strategy = layer_update_strategy  # 'sequential' or 'cyclic'
        self.current_layer = 0  # Track which layer to update
        self.number_of_layers = None
        self.layer_training_sequence = []
        self.training_sequence_index = 0
        self.latest_parameters = initial_parameters
        self.updated_layers = -1
        self.global_optim_state = None


    def __repr__(self) -> str:
        return "FedPartAvgWithVelocityWeighting"
    

    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return sample size and required number of clients."""
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Use a fraction of available clients for evaluation."""
        num_clients = int(num_available_clients * self.fraction_evaluate)
        return max(num_clients, self.min_evaluate_clients), self.min_available_clients
    
    def generate_layer_training_sequence(self) -> List[int]:
        """Generate a sequence of layers to train."""
        layer_training_sequence = []
        for _ in range(NUM_OF_CYCLES):
            for _ in range(NUM_OF_FULL_UPDATES_BETWEEN_CYCLES):
                    layer_training_sequence.append(-1)
            for layer in range(NETWORK_LEN):
                    layer_training_sequence.append(layer)
                    layer_training_sequence.append(layer)

        return layer_training_sequence

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters."""
        net = Net()
        ndarrays = get_parameters(net)
        self.layer_training_sequence = self.generate_layer_training_sequence()
        self.number_of_layers = len(ndarrays)
        self.latest_parameters = ndarrays_to_parameters(ndarrays)

        return ndarrays_to_parameters(ndarrays)
    


    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[tuple[float, dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function."""
        if self.evaluate_fn is None:
            # No evaluation function provided
            return None
        parameters_ndarrays = parameters_to_ndarrays(parameters)
        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
        if eval_res is None:
            return None
        
        if server_round in fed_momentum_weighting_model_results:  
            expand_fed_part_avg_model_results= {**fed_momentum_weighting_model_results[server_round], "global_loss": eval_res[0], "global_metrics": eval_res[1]}
        else:
            expand_fed_part_avg_model_results= {"global_loss": eval_res[0], "global_metrics": eval_res[1]}
        
        fed_momentum_weighting_model_results[server_round] = expand_fed_part_avg_model_results
        
        loss, metrics = eval_res
        return loss, metrics

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training."""
        
        config = {"trainable_layers": self.layer_training_sequence[self.training_sequence_index], "updated_layers": self.updated_layers}

        if self.global_optim_state is not None:
            config["optimizer_state"] = self.global_optim_state
        
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )
        
        print(f"Training on layer {self.layer_training_sequence}")
        fit_configurations = []

        params_array = parameters_to_ndarrays(parameters)
        
        # If doing full model update, send all parameters
        if self.layer_training_sequence[self.training_sequence_index] == -1 or self.updated_layers == -1:
            selected_params = parameters
        else:
            layer_idx = self.updated_layers
            selected_params = ndarrays_to_parameters([
                    params_array[layer_idx * 2],     # Weight
                    params_array[layer_idx * 2 + 1]  # Bias
                ])

        for idx, client in enumerate(clients):
            fit_configurations.append((client, FitIns(selected_params, config)))

        self.updated_layers = self.layer_training_sequence[self.training_sequence_index]
        self.training_sequence_index = self.training_sequence_index + 1
        
        return fit_configurations
    

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        if self.fraction_evaluate == 0.0:
            return []
        config = {}
        evaluate_ins = EvaluateIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Return client/config pairs
        return [(client, evaluate_ins) for client in clients]

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        global global_all_optimizer_states
        """Aggregate fit results using weighted average."""
        
        total_size = 0
        for client, fit_res in results:
            total_size += get_parameters_size(fit_res.parameters)
            total_size += fit_res.metrics["recieved_parameter_size"]
            
        print(f"total size: {total_size}")
        
        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]

        metrics = {}
        optimizer_states_serialized = [res.metrics.get("optimizer_state", None) for _, res in results]
        optimizer_states = [deserialize_optimizer_state(state) for state in optimizer_states_serialized if state is not None]

        # Get cosine similarities
        if server_round == 1:
            # First round doesn't have a previous state. Use average momentum instead
            equal_weights = [1] * len(optimizer_states)
            prev_global_optimizer_state = self._aggregate_optimizer_states_weighted(optimizer_states, equal_weights)
        else:
            prev_global_optimizer_state = deserialize_optimizer_state(self.global_optim_state)
        cos_similarities = self._cos_similarity_from_optimizer_states(optimizer_states, prev_global_optimizer_state)

        # Aggregate nn params using weighting from momentum
        aggregated_weights = self._aggregate_params_weighted(weights_results, cos_similarities)
        trained_layer = results[0][1].metrics["trained_layer"] 

        if trained_layer == -1:
            self.latest_parameters = ndarrays_to_parameters(aggregated_weights)
        else:
            current_model = parameters_to_ndarrays(self.latest_parameters)
            current_model[trained_layer* 2] = aggregated_weights[0]
            current_model[trained_layer* 2 +1] = aggregated_weights[1]
            self.latest_parameters = ndarrays_to_parameters(current_model)

        # aggregate optimizer state using weighting from momentum
        aggregated_optimizer_state = self._aggregate_optimizer_states_weighted(optimizer_states, cos_similarities)
        aggregated_optimizer_state_serialized = serialize_optimizer_state(aggregated_optimizer_state)
        metrics["optimizer_state"] = aggregated_optimizer_state_serialized
        self.global_optim_state = aggregated_optimizer_state_serialized

        # Add sizes to total_size:
        # Note that we currently add the whole optimizer state for ease of implementation, but this is redundant for all but the unfrozen layer as the frozen layers opt state is unchanged.
        sizes = sum([sys.getsizeof(data) for data in optimizer_states_serialized if data is not None])
        print(f"Adding optimizer state adds size {sizes}")
        total_size += sizes

        if fed_momentum_weighting_results.get(server_round):
            fed_momentum_weighting_results[server_round]["total_size"] = total_size
        else:
            fed_momentum_weighting_results[server_round] = {"total_size": total_size}

        return self.latest_parameters, metrics


    def _momentum_based_aggregate(results: list[tuple[NDArrays, int]]) -> NDArrays:
        """Compute weighted average of params."""
        num_examples_total = sum(num_examples for (_, num_examples) in results)

        # list of weights, each multiplied by the related number of examples
        weighted_weights = [
            [layer * num_examples for layer in weights] for weights, num_examples in results
        ]

        # Compute average weights of each layer
        weights_prime: NDArrays = [
            reduce(np.add, layer_updates) / num_examples_total
            for layer_updates in zip(*weighted_weights)
        ]
        return weights_prime

    def _aggregate_optimizer_states_weighted(self, optimizer_states, similarities):
        # aggregates by comuting mean of 1st and 2nd momentum, and taking the maximum step value. (alternatively, could be changed to reset steps to 0)
        # uses direct normalization (weight of similarity divided by sum of similarities)
        aggregated_state_dict = copy.deepcopy(optimizer_states[0])
        sum_similarities = sum(similarities)

        for layer in range(len(optimizer_states[0]['state'])):
            weighted_fst_momentum = torch.zeros_like(optimizer_states[0]['state'][layer]['exp_avg'])
            weighted_snd_momentum = torch.zeros_like(optimizer_states[0]['state'][layer]['exp_avg_sq'])
            
            # Accumulate weighted momentums from each client
            for i in range(len(optimizer_states)):
                weight = similarities[i] / sum_similarities
                weighted_fst_momentum += optimizer_states[i]['state'][layer]['exp_avg'] * weight
                weighted_snd_momentum += optimizer_states[i]['state'][layer]['exp_avg_sq'] * weight

            aggregated_state_dict['state'][layer]['exp_avg'] = weighted_fst_momentum
            aggregated_state_dict['state'][layer]['exp_avg_sq'] = weighted_snd_momentum
        max_steps = max([optimizer_states[i]['state'][layer]['step'] for i in range(len(optimizer_states))])
        aggregated_state_dict['state'][layer]['step'] = max_steps
        
        return aggregated_state_dict

    def _aggregate_params_weighted(self, results: list[tuple[NDArrays, int]], cos_similarities: list[float]) -> NDArrays:
        """Compute cos similarity weighted average."""
        # Calculate the total number of examples used during training
        assert len(results) == len(cos_similarities)
        sum_sim = sum(cos_similarities)

        # Create a list of weights, each multiplied by the cos similiarity
        weighted_weights = [
            [layer * cos_similarities[i] for layer in results[i][0]] for i in range(len(results))
        ]

        # Compute average weights of each layer
        weights_prime: NDArrays = [
            reduce(np.add, layer_updates) / sum_sim
            for layer_updates in zip(*weighted_weights)
        ]
        return weights_prime

    def _cos_similarity_from_optimizer_states(self, optimizer_states, prev_global_optimizer_state):
        # Returns: cosine similarity from each optimizer state to the prev global state
        cos = torch.nn.CosineSimilarity(dim=-1)
        similarities = []
        layers = range(len(optimizer_states[0]['state']))

        # get global opt params as vector
        fst_momentum_glob = [prev_global_optimizer_state['state'][l]['exp_avg'].view(-1) for l in layers]
        snd_momentum_glob = [prev_global_optimizer_state['state'][l]['exp_avg_sq'].view(-1) for l in layers]
        glob_state_vec = torch.cat(fst_momentum_glob + snd_momentum_glob)

        for client in range(len(optimizer_states)):
            # get client opt params as vector
            fst_momentum_client = [optimizer_states[client]['state'][l]['exp_avg'].view(-1) for l in layers]
            snd_momentum_client = [optimizer_states[client]['state'][l]['exp_avg_sq'].view(-1) for l in layers]
            client_state_vec = torch.cat(fst_momentum_client + snd_momentum_client)

            sim = cos(client_state_vec, glob_state_vec)
            similarities.append(sim.item())

        return similarities

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average."""

        if not results:
            return None, {}
        
        total_loss = 0
        for _, evaluate_res in results:
            total_loss += evaluate_res.loss

        if fed_momentum_weighting_results.get(server_round):
            fed_momentum_weighting_results[server_round]["total_loss"] = total_loss
        else:
            fed_momentum_weighting_results[server_round] = {"total_loss": total_loss}

        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )
        metrics_aggregated = {}
        return loss_aggregated, metrics_aggregated

In [None]:
class FedAvgPartWithVelocityWeightingFlowerClient(NumPyClient):
    def __init__(self, partition_id, net, trainloader, valloader, context: Context):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
        self.client_state = context.state

        # Initialize parameters record if it doesn't exist
        if "net_parameters" not in self.client_state.parameters_records:
            self.client_state.parameters_records["net_parameters"] = ParametersRecord()
            # Save initial model state
            self._save_model_state()

    def _save_model_state(self):
        """Save current model parameters to context"""
        p_record = ParametersRecord()
        parameters = get_parameters(self.net)
        
        for i, param in enumerate(parameters):
            p_record[f"layer_{i}"] = array_from_numpy(param)
        
        self.client_state.parameters_records["net_parameters"] = p_record

    def _load_model_state(self):
        """Load model parameters from context"""
        p_record = self.client_state.parameters_records["net_parameters"]
        parameters = []
        
        for i in range(len(p_record)):
            parameters.append(p_record[f"layer_{i}"].numpy())
        
        set_parameters(self.net, parameters)

    def get_parameters(self, config):
        print(f"[Client {self.partition_id}] get_parameters")
        parameters = get_parameters(self.net)
        trainable_layer = config["trainable_layers"]
        self._save_model_state()
        
        if trainable_layer == -1:
            return parameters
        
        trained_layer = [parameters[trainable_layer*2], parameters[trainable_layer*2 +1]]
        return trained_layer

    def fit(self, parameters, config):
        print(f"[Client {self.partition_id}] fit")

        self._load_model_state()
        received_parameter_size = get_parameters_size(ndarrays_to_parameters(parameters))
        set_parameters(self.net, parameters, config["updated_layers"])
        freeze_layers(self.net, config["trainable_layers"])

        self.optimizer = torch.optim.Adam(self.net.parameters())
        if "optimizer_state" in config:
            print(f"Got optimizer state in config, setting..")
            serialized_optimizer_state = config["optimizer_state"]
            optim_state = deserialize_optimizer_state(serialized_optimizer_state)
            self._set_optimizer_state(optim_state)

        train(self.net, self.trainloader, epochs=EPOCHS, optimizer=self.optimizer)
        print('finished train')

        # handle optimizer state (serialize before passing)
        optim_state = self._get_optimizer_state()
        serialized_optimizer_state = serialize_optimizer_state(optim_state)
        print(f"After training, got optim state {serialized_optimizer_state[:50]}")

        new_config =  {
            "trained_layer":config["trainable_layers"], 
            "recieved_parameter_size": received_parameter_size,
            "optimizer_state": serialized_optimizer_state
            }

        self._save_model_state()

        return self.get_parameters(config), len(self.trainloader), new_config

    def evaluate(self, parameters, config):
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        self._load_model_state()
        current_state = get_parameters(self.net)
        set_parameters(self.net, current_state)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

    
    def _get_optimizer_state(self):
        state_dict = self.optimizer.state_dict()
        return state_dict
    
    def _set_optimizer_state(self, state_dict):
        self.optimizer.load_state_dict(state_dict)


def client_fn(context: Context) -> Client:
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]

    # Initialize network if not in context
    if not hasattr(context, 'net'):
        context.net = Net().to(DEVICE)

    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)
    return FedAvgPartWithVelocityWeightingFlowerClient(partition_id, context.net, trainloader, valloader, context).to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)

In [19]:
net = Net().to(DEVICE)

_, _, testloader = load_datasets(0, NUM_PARTITIONS)

evaluate_fn = get_evaluate_fn(testloader, net)

def server_fn(context: Context) -> ServerAppComponents:
    # Configure the server for just 3 rounds of training
    config = ServerConfig(num_rounds=NUM_OF_ROUNDS)
    return ServerAppComponents(
        config=config,
        strategy=FedPartAvgWithVelocityweighting(
            evaluate_fn=evaluate_fn,
        ),
    )

server = ServerApp(server_fn=server_fn)

# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

  obj.co_lnotab,  # for < python 3.10 [not counted in args]
[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=41, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      initial parameters (loss, other metrics): 0.07207868139743805, {'accuracy': 0.1}
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)


Training on layer [-1, -1, -1, -1, -1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17]


[36m(ClientAppActor pid=90287)[0m   obj.co_lnotab,  # for < python 3.10 [not counted in args]


[36m(ClientAppActor pid=90287)[0m [Client 1] fit
[36m(ClientAppActor pid=90287)[0m layer index is 0 and nameconv1.weight is trainabe
[36m(ClientAppActor pid=90287)[0m layer index is 1 and nameconv1.bias is trainabe
[36m(ClientAppActor pid=90287)[0m layer index is 2 and nameconv2.weight is trainabe
[36m(ClientAppActor pid=90287)[0m layer index is 3 and nameconv2.bias is trainabe
[36m(ClientAppActor pid=90287)[0m layer index is 4 and nameconv3.weight is trainabe
[36m(ClientAppActor pid=90287)[0m layer index is 5 and nameconv3.bias is trainabe
[36m(ClientAppActor pid=90287)[0m layer index is 6 and nameconv4.weight is trainabe
[36m(ClientAppActor pid=90287)[0m layer index is 7 and nameconv4.bias is trainabe
[36m(ClientAppActor pid=90287)[0m layer index is 8 and nameconv5.weight is trainabe
[36m(ClientAppActor pid=90287)[0m layer index is 9 and nameconv5.bias is trainabe
[36m(ClientAppActor pid=90287)[0m layer index is 10 and nameconv6.weight is trainabe
[36m(Client

[36m(ClientAppActor pid=90283)[0m   obj.co_lnotab,  # for < python 3.10 [not counted in args][32m [repeated 3x across cluster][0m


[36m(ClientAppActor pid=90286)[0m Epoch 1: train loss 0.06495826691389084, accuracy 0.1932193219321932
[36m(ClientAppActor pid=90284)[0m [Client 2] fit[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=90284)[0m layer index is 17 and namefc3.bias is trainabe[32m [repeated 36x across cluster][0m
[36m(ClientAppActor pid=90284)[0m training network...[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=90283)[0m Epoch 1: train loss 0.06467293202877045, accuracy 0.19801980198019803[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=90287)[0m Epoch 2: train loss 0.05860227718949318, accuracy 0.27583620818959054[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=90283)[0m Epoch 2: train loss 0.05582219362258911, accuracy 0.31218121812181215[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=90286)[0m Epoch 3: train loss 0.05033428966999054, accuracy 0.3835883588358836[32m [repeated 3x across cluster][0m
[36m(ClientApp

[36m(ClientAppActor pid=90284)[0m   obj.co_lnotab,  # for < python 3.10 [not counted in args][32m [repeated 2x across cluster][0m


: 

In [None]:
with open(f'results/fed_part_avg_momentum_weighting_results.p', 'wb') as file:
    pickle.dump(fed_momentum_weighting_results, file)

with open(f'results/fed_part_avg_momentum_weighting_model_results.p', 'wb') as file:
    pickle.dump(fed_momentum_weighting_model_results, file)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Plot the total size of parameters for each round
fed_momentum_weighting_rounds = list(fed_momentum_weighting_results.keys())
fed_momentum_weighting_sizes = [fed_momentum_weighting_results[round]["total_size"] for round in fed_momentum_weighting_rounds]

plt.figure(figsize=(10, 5))
plt.plot(fed_momentum_weighting_rounds, fed_momentum_weighting_sizes, marker='o', linestyle='-', color='b', label='FedPartWithVelocityWeighting')
# plt.plot(fed_avg_rounds, fed_avg_sizes, marker='o', linestyle='-', color='r', label='FedAvg')
plt.xlabel('Round')
plt.ylabel('Total Size of Parameters (bytes)')
plt.title('Total Size of Parameters for Each Round')
plt.legend()
plt.grid(True)

fed_momentum_weighting_losses = [fed_momentum_weighting_results[round]["total_loss"] for round in fed_momentum_weighting_rounds]

plt.figure(figsize=(10, 5))
plt.plot(fed_momentum_weighting_rounds, fed_momentum_weighting_losses, marker='o', linestyle='-', color='b', label='FedPartWithVelocityWeighting')
# plt.plot(fed_avg_rounds, fed_avg_losses, marker='o', linestyle='-', color='r', label='FedAvg')
plt.xlabel('Round')
plt.ylabel('Total Loss')
plt.title('Total Loss for Each Round')
plt.legend()
plt.grid(True)


fed_momentum_weighting_model_rounds = list(fed_momentum_weighting_model_results.keys())
fed_momentum_weighting_accuracies = [fed_momentum_weighting_model_results[round]["global_metrics"]["accuracy"] for round in fed_momentum_weighting_model_rounds]

plt.figure(figsize=(10, 5))
plt.plot(fed_momentum_weighting_model_rounds, fed_momentum_weighting_accuracies, marker='o', linestyle='-', color='b', label='FedPartWithVelocityWeighting')
# plt.plot(fed_avg_model_rounds, fed_avg_accuracies, marker='o', linestyle='-', color='r', label='FedAvg')
plt.xlabel('Round')
plt.ylabel('Accuracy')
plt.title('Accuracy for Each Round')
plt.legend()
plt.grid(True)

fed_momentum_weighting_global_losses = [fed_momentum_weighting_model_results[round]["global_loss"] for round in fed_momentum_weighting_model_rounds]

plt.figure(figsize=(10, 5))
plt.plot(fed_momentum_weighting_model_rounds, fed_momentum_weighting_global_losses, marker='o', linestyle='-', color='b', label='FedPartWithVelocityWeighting')
# plt.plot(fed_avg_model_rounds, fed_avg_global_losses, marker='o', linestyle='-', color='r', label='FedAvg')
plt.xlabel('Round')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss for Each Round')