In [7]:
! pip install -q flwr[simulation] flwr-datasets[vision] torch torchvision matplotlib
! pip install -U ipywidgets
! pip install numpy==1.26.4
! pip install urllib3==1.26.6

zsh:1: no matches found: flwr[simulation]


In [1]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union, Callable

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import copy
import torch.nn.functional as F
import torchvision.transforms as transforms
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader
from flwr.server.strategy import Strategy
import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Metrics, Context, Status, GetParametersRes, Parameters, GetParametersIns, MetricsAggregationFn,NDArrays,Scalar
from flwr.server import ServerApp, ServerConfig, ServerAppComponents 
from flwr.server.strategy import FedAvg, FedProx
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg

# Pick the compute device once: CUDA when PyTorch reports it as available, CPU otherwise.
# (An MPS branch exists in the notebook history but is not enabled.)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

# Silence the `datasets` progress bars.
disable_progress_bar()

Training on cpu
Flower 1.15.1 / PyTorch 2.6.0


In [2]:

BATCH_SIZE = 32


def load_datasets(
    partition_id: int, num_partitions: int
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the data loaders for one federated partition of CIFAR-10.

    Args:
        partition_id: Index of the "train" partition to load.
        num_partitions: Number of partitions the "train" split is divided into.

    Returns:
        ``(trainloader, valloader, testloader)``: the first two come from an
        80/20 split (seed 42) of the selected partition, the third iterates
        over the dataset's "test" split. All loaders use ``BATCH_SIZE``; only
        the training loader shuffles.
    """
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        # Applied lazily through dataset.with_transform(...) instead of being
        # passed as a torchvision `transform=` argument.
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    # Use the module-level BATCH_SIZE instead of repeating the literal 32.
    trainloader = DataLoader(
        partition_train_test["train"], batch_size=BATCH_SIZE, shuffle=True
    )
    valloader = DataLoader(partition_train_test["test"], batch_size=BATCH_SIZE)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloader, valloader, testloader

In [3]:
class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    ``forward`` returns the raw output of the last linear layer (10 values per
    sample). The flatten step assumes the conv stack yields 16 x 5 x 5 features
    per sample.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registration order is kept so the state_dict key order is unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Convolutional feature extractor.
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Flatten to (N, 400) and run the fully connected head.
        flat = features.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)

class MoonNet(nn.Module):
    """Same layer stack as ``Net``, but ``forward`` also exposes the representation.

    ``forward`` returns ``(representation, classification)`` where
    ``representation`` is the activated output of ``fc2`` (84 values per sample)
    and ``classification`` is the output of ``fc3`` (10 values per sample).
    """

    def __init__(self) -> None:
        super(MoonNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # The return annotation previously claimed a single tensor; callers
        # (train_moon / test_moon) unpack a 2-tuple.
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        representation = x.clone()
        classification = self.fc3(x)
        return representation, classification

def get_parameters(net) -> List[np.ndarray]:
    """Return every ``state_dict`` entry of ``net`` as a NumPy array, in state_dict order."""
    arrays = []
    for tensor in net.state_dict().values():
        arrays.append(tensor.cpu().numpy())
    return arrays


def set_parameters(net, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``net``, pairing them with its state_dict keys in order.

    Args:
        net: Module whose state is overwritten.
        parameters: One array per ``state_dict`` entry, in ``state_dict`` order.

    Raises:
        RuntimeError: from ``load_state_dict(strict=True)`` when keys are missing
            (e.g. too few arrays were supplied) or shapes do not match.
    """
    # torch.as_tensor keeps each array's dtype and shape. The legacy
    # torch.Tensor(...) constructor used before always builds a default-float
    # tensor, which is not appropriate for integer / 0-d buffers such as
    # BatchNorm's num_batches_tracked.
    state_dict = OrderedDict(
        (key, torch.as_tensor(value))
        for key, value in zip(net.state_dict().keys(), parameters)
    )
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set.

    Runs ``epochs`` passes over ``trainloader`` with Adam (default settings) and
    cross-entropy loss, printing the per-sample mean loss and accuracy after
    each epoch. Batches are dicts with ``"img"`` and ``"label"`` entries.

    Args:
        net: Model to optimise in place; it is switched to training mode.
        trainloader: Iterable of batches.
        epochs: Number of passes over ``trainloader``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"].to(DEVICE), batch["label"].to(DEVICE)
            optimizer.zero_grad()
            # One forward pass serves both the loss and the accuracy metric.
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics: accumulate plain floats (not graph-holding tensors) and
            # weight the batch-mean loss by the batch size so the epoch figure
            # is a true per-sample average.
            batch_size = labels.size(0)
            epoch_loss += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        epoch_loss /= max(total, 1)
        epoch_acc = correct / max(total, 1)
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")
        
def proxima_train(net, trainloader, epochs: int, proximal_mu:float, global_params:List[torch.Tensor]):
    """Train the network with a FedProx-style proximal penalty.

    The per-batch loss is ``cross_entropy + (proximal_mu / 2) * sum ||w - w_global||^2``
    where the sum runs over the pairs produced by zipping ``net.parameters()``
    with ``global_params``.

    Args:
        net: Model to optimise in place; it is switched to training mode.
        trainloader: Iterable of dict batches with ``"img"`` and ``"label"``.
        epochs: Number of passes over ``trainloader``.
        proximal_mu: Weight of the proximal term.
        global_params: Reference (global) parameters, in ``net.parameters()``
            order. Any iterable is accepted, including a generator.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    # Materialise the reference weights once. A generator (such as
    # `model.parameters()`) would be exhausted after the first batch, silently
    # turning the proximal term into 0 for the rest of training. Detaching
    # keeps gradients from flowing into the reference copy.
    reference_params = [param.detach().clone() for param in global_params]
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"].to(DEVICE), batch["label"].to(DEVICE)
            optimizer.zero_grad()
            # One forward pass serves both the loss and the accuracy metric.
            outputs = net(images)

            # Squared L2 distance to the reference weights (the un-squared norm
            # used before does not match the (mu / 2) * ||w - w_t||^2 penalty).
            proximal_term = 0.0
            for local_weights, global_weights in zip(net.parameters(), reference_params):
                proximal_term = proximal_term + (local_weights - global_weights).pow(2).sum()
            loss = criterion(outputs, labels) + (proximal_mu / 2) * proximal_term

            loss.backward()
            optimizer.step()
            # Metrics: plain floats, weighted by batch size for a per-sample mean.
            batch_size = labels.size(0)
            epoch_loss += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        epoch_loss /= max(total, 1)
        epoch_acc = correct / max(total, 1)
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def train_moon(net,train_loader, global_net,previous_net, epochs, mu, temperature, partition_id=None):
    """Training function for MOON.

    Optimises ``net`` with Adam on ``cross_entropy + mu * contrastive`` where
    the contrastive term is a cross-entropy over the temperature-scaled cosine
    similarities of ``net``'s representation with the global model's
    (positive, class 0) and the previous model's (negative) representation.

    Args:
        net: Model being trained; ``net(x)`` must return ``(representation, logits)``.
        train_loader: Iterable of dict batches with ``"img"`` and ``"label"``.
        global_net: Frozen reference producing the positive representation.
        previous_net: Frozen reference producing the negative representation.
        epochs: Number of passes over ``train_loader``.
        mu: Weight of the contrastive loss.
        temperature: Divisor applied to the similarity logits.
        partition_id: Only used in progress messages.
    """
    print(f"Started training moon")
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())

    net.train()
    # Both reference models must live on the same device as the batches; they
    # are only evaluated, never trained.
    previous_net.to(DEVICE).eval()
    global_net.to(DEVICE).eval()

    cnt = 0
    cos = torch.nn.CosineSimilarity(dim=-1)

    for epoch in range(epochs):
        epoch_loss_collector = []
        epoch_loss1_collector = []
        epoch_loss2_collector = []
        for batch in train_loader:
            x, target = batch["img"].to(DEVICE), batch["label"].to(DEVICE)
            optimizer.zero_grad()

            # pro1 is the representation by the current model (Line 14 of Algorithm 1)
            pro1, out = net(x)
            # pro2 is the representation by the global model (Line 15 of Algorithm 1)
            # pro3 is the representation by the previous model (Line 16 of Algorithm 1)
            with torch.no_grad():
                pro2, _ = global_net(x)
                pro3, _ = previous_net(x)

            # posi is the positive pair, nega is the negative pair
            posi = cos(pro1, pro2)
            nega = cos(pro1, pro3)
            logits = torch.cat((posi.reshape(-1, 1), nega.reshape(-1, 1)), dim=1)
            if cnt % 500 == 0:
                print(f"Client {partition_id}: at epoch {epoch} batch {cnt}, similarity between current and global is {torch.mean(posi)}")
                print(f"Client {partition_id}: at epoch {epoch} batch {cnt}, similarity between current and previous is {torch.mean(nega)}")
                # `clientid` is an optional debugging tag set by the caller;
                # fall back to None instead of raising AttributeError.
                print(
                    f"Client {partition_id}: has net {getattr(net, 'clientid', None)} "
                    f"prev {getattr(previous_net, 'clientid', None)} "
                    f"global {getattr(global_net, 'clientid', None)}"
                )

            logits = logits / temperature
            # The positive pair sits in column 0, so every target label is 0.
            labels = torch.zeros(x.size(0), dtype=torch.long, device=DEVICE)

            # compute the model-contrastive loss (Line 17 of Algorithm 1)
            loss2 = mu * criterion(logits, labels)

            # compute the cross-entropy loss (Line 13 of Algorithm 1)
            loss1 = criterion(out, target)

            # compute the loss (Line 18 of Algorithm 1)
            loss = loss1 + loss2

            loss.backward()
            optimizer.step()

            cnt += 1
            epoch_loss_collector.append(loss.item())
            epoch_loss1_collector.append(loss1.item())
            epoch_loss2_collector.append(loss2.item())

        if not epoch_loss_collector:
            # Empty loader: nothing to average (avoids ZeroDivisionError).
            continue
        epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector)
        epoch_loss1 = sum(epoch_loss1_collector) / len(epoch_loss1_collector)
        epoch_loss2 = sum(epoch_loss2_collector) / len(epoch_loss2_collector)
        print(
            "Epoch: %d Loss: %f Loss1: %f Loss2: %f"
            % (epoch, epoch_loss, epoch_loss1, epoch_loss2)
        )

    # The original moved the previous model to CPU inside the batch loop, which
    # breaks the next forward pass on an accelerator. The move is kept, but only
    # once training has finished (presumably intended to free device memory).
    previous_net.to("cpu")


def test_moon(net, testloader):
    """
    Evaluate the network on the entire test set.
    Same as the regular test, but using the MoonNet 
    (where the output is a tuple of (representation, classification) )
    """
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            _, outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy




def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

def freeze_layers(model: torch.nn.Module, trainable_layers: int) -> None:
    """Leave exactly one parameter tensor trainable, or all of them.

    ``trainable_layers`` is a position in ``model.named_parameters()`` order;
    the tensor at that position gets ``requires_grad = True`` and every other
    one ``False``. The special value ``-1`` makes every tensor trainable.
    """
    train_everything = trainable_layers == -1
    for position, parameter in enumerate(model.parameters()):
        parameter.requires_grad = train_everything or position == trainable_layers

In [4]:

# --- Experiment configuration -------------------------------------------------
# Number of entries in Net's state_dict.
NETWORK_LEN = len(Net().state_dict())
EPOCHS = 8
NUM_PARTITIONS = 3
NUM_OF_CYCLES = 1
NUM_OF_FULL_UPDATES_BETWEEN_CYCLES = 5
# Per cycle: the full-update rounds plus two rounds for each of the NETWORK_LEN entries.
NUM_OF_ROUNDS = NUM_OF_CYCLES * (NUM_OF_FULL_UPDATES_BETWEEN_CYCLES + 2 * NETWORK_LEN)
print(f"Number of rounds: {NUM_OF_ROUNDS}")

# Resources requested for each simulated client.
backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}


Number of rounds: 25


# Normal FedAvg

In [9]:
class NormalFlowerClient(NumPyClient):
    """Flower client that trains every layer of its model each round.

    Args:
        partition_id: Identifier used in log messages.
        net: Model trained and evaluated by this client.
        trainloader: Loader passed to ``train``.
        valloader: Loader passed to ``test``.
    """

    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return all model parameters as NumPy arrays."""
        print(f"[Client {self.partition_id}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load ``parameters``, train for ``EPOCHS`` epochs and return the new weights."""
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=EPOCHS)
        # Report the number of samples: len(loader) is the number of batches,
        # which would skew the example-weighted aggregation on the server.
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load ``parameters`` and report loss / accuracy on the validation loader."""
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}


def client_fn(context: Context) -> Client:
    """Create the plain FedAvg client for the partition named in the node config."""
    model = Net().to(DEVICE)
    node_config = context.node_config
    partition_id = node_config["partition-id"]
    num_partitions = node_config["num-partitions"]
    # The global test loader is not needed on the client side.
    loaders = load_datasets(partition_id, num_partitions)
    trainloader, valloader = loaders[0], loaders[1]
    flower_client = NormalFlowerClient(partition_id, model, trainloader, valloader)
    return flower_client.to_client()


# ClientApp wrapping the factory above
client = ClientApp(client_fn=client_fn)

In [None]:
def server_fn(context: Context) -> ServerAppComponents:
    """Server components for the baseline: NUM_OF_ROUNDS rounds of FedAvg."""
    return ServerAppComponents(
        config=ServerConfig(num_rounds=NUM_OF_ROUNDS),
        strategy=FedAvg(),
    )


server = ServerApp(server_fn=server_fn)

# Launch the simulation with one supernode per partition.
simulation_options = {
    "server_app": server,
    "client_app": client,
    "num_supernodes": NUM_PARTITIONS,
    "backend_config": backend_config,
}
run_simulation(**simulation_options)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=25, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[91mERROR [0m:     An exception occurred !! local variable 'backend' referenced before assignment
[91mERROR [0m:     Traceback (most recent call last):
  File "/opt/homebrew/anaconda3/envs/fl_proj/lib/python3.9/site-packages/flwr/server/superlink/fleet/vce/vce_api.py", line 191, in run_api
    backend = backend_fn()
  File "/opt/homebrew/anaconda3/envs/fl_proj/lib/python3.9/site-packages/flwr/server/superlink/fleet/vce/vce_api.py", line 341, in backend_fn
    return backend_type(backend_config)
  File "/opt/homebrew/anaconda3/envs/fl_proj/lib/python3.9/site-packages/flwr/server/superlink/fleet/vce/backend/raybackend.py", line 52, in __init__
    self.init_ray(backend_config)
  File "/opt/homebrew/anaconda3/envs/fl_proj/lib/python3.9/site-packages/flwr/server/superlink/fleet

# FedAvgPart Experiments

In [5]:
class FedAvgPartFlowerClient(NumPyClient):
    """Flower client that trains only the layer selected by the server each round.

    The layer is read from ``config["trainable_layers"]``: an index into the
    model's parameter list, or ``-1`` for "train everything".

    Args:
        partition_id: Identifier used in log messages.
        net: Model trained and evaluated by this client.
        trainloader: Loader passed to ``train``.
        valloader: Loader passed to ``test``.
    """

    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the selected layer's parameters, or all of them.

        All parameters are returned when ``trainable_layers`` is absent from
        ``config`` or equals ``-1``. Previously a missing key raised KeyError
        and ``-1`` returned only the last tensor.
        """
        print(f"[Client {self.partition_id}] get_parameters")
        all_parameters = get_parameters(self.net)
        layer = config.get("trainable_layers", -1)
        if layer == -1:
            return all_parameters
        return [all_parameters[layer]]

    def fit(self, parameters, config):
        """Load ``parameters``, freeze all but the selected layer, train, return all weights."""
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)
        # A missing key means "train everything" (-1), matching get_parameters.
        freeze_layers(self.net, config.get("trainable_layers", -1))
        train(self.net, self.trainloader, epochs=EPOCHS)
        # Report the number of samples: len(loader) is the number of batches,
        # which would skew the example-weighted aggregation on the server.
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load ``parameters`` and report loss / accuracy on the validation loader."""
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}


def client_fn(context: Context) -> Client:
    """Create the layer-wise FedAvg client for the partition named in the node config."""
    model = Net().to(DEVICE)
    node_config = context.node_config
    partition_id = node_config["partition-id"]
    num_partitions = node_config["num-partitions"]
    # The global test loader is not needed on the client side.
    loaders = load_datasets(partition_id, num_partitions)
    trainloader, valloader = loaders[0], loaders[1]
    flower_client = FedAvgPartFlowerClient(partition_id, model, trainloader, valloader)
    return flower_client.to_client()


# ClientApp wrapping the factory above
client = ClientApp(client_fn=client_fn)

In [6]:
from typing import Union

from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg


class FedPartAvg(Strategy):
    """Federated averaging driven by a per-round layer-training schedule.

    Each round the strategy sends ``{"trainable_layers": <layer>}`` to the
    sampled clients, where ``<layer>`` comes from the sequence built by
    ``generate_layer_training_sequence`` (``-1`` means "train all layers").
    Client results are combined with an example-weighted average.

    Args:
        fraction_fit: Fraction of available clients sampled for training.
        fraction_evaluate: Fraction sampled for evaluation; ``0.0`` disables
            client-side evaluation.
        min_fit_clients: Lower bound on the training sample size.
        min_evaluate_clients: Lower bound on the evaluation sample size.
        min_available_clients: Clients required before sampling proceeds.
        evaluate_fn: Optional server-side evaluation callback.
        on_fit_config_fn: Optional callback ``round -> dict`` whose result is
            merged into the fit config (``trainable_layers`` always wins).
        on_evaluate_config_fn: Optional callback ``round -> dict`` supplying
            the evaluate config.
        accept_failures: When False, a round with any failure is not aggregated.
        initial_parameters: Optional initial global parameters; when None a
            fresh ``Net`` is used.
        fit_metrics_aggregation_fn: Optional aggregator for fit metrics.
        evaluate_metrics_aggregation_fn: Optional aggregator for evaluate metrics.
        inplace: Stored only; no method of this class reads it.
        layer_update_strategy: Stored only; no method of this class reads it.
    """

    def __init__(
        self,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                Optional[tuple[float, dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        inplace: bool = True,
        layer_update_strategy: str = "sequential",

    ) -> None:
        super().__init__()
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients
        self.evaluate_fn = evaluate_fn
        self.on_fit_config_fn = on_fit_config_fn
        self.on_evaluate_config_fn = on_evaluate_config_fn
        self.accept_failures = accept_failures
        self.initial_parameters = initial_parameters
        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
        self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn
        self.inplace = inplace

        self.layer_update_strategy = layer_update_strategy  # 'sequential' or 'cyclic'
        self.current_layer = 0  # kept for compatibility; not read by this class
        self.number_of_layers = None  # set in initialize_parameters
        self.layer_training_sequence = []  # schedule of layer indices (-1 = all)
        self.training_sequence_index = 0  # position of the next round in the schedule

    def __repr__(self) -> str:
        return "FedPartAvg"

    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return sample size and required number of clients."""
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Use a fraction of available clients for evaluation."""
        num_clients = int(num_available_clients * self.fraction_evaluate)
        return max(num_clients, self.min_evaluate_clients), self.min_available_clients

    def generate_layer_training_sequence(self) -> List[int]:
        """Generate a sequence of layers to train.

        For each of ``NUM_OF_CYCLES`` cycles: ``NUM_OF_FULL_UPDATES_BETWEEN_CYCLES``
        entries of ``-1`` (all layers), then every index in
        ``range(NETWORK_LEN)`` twice in a row.
        """
        layer_training_sequence: List[int] = []
        for _ in range(NUM_OF_CYCLES):
            layer_training_sequence.extend([-1] * NUM_OF_FULL_UPDATES_BETWEEN_CYCLES)
            for layer in range(NETWORK_LEN):
                layer_training_sequence.extend((layer, layer))
        return layer_training_sequence

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters and build the layer schedule.

        Uses ``initial_parameters`` when they were supplied to the constructor
        (previously they were silently ignored); otherwise the parameters of a
        freshly constructed ``Net``.
        """
        self.layer_training_sequence = self.generate_layer_training_sequence()
        if self.initial_parameters is not None:
            initial_parameters = self.initial_parameters
            self.initial_parameters = None  # no need to keep a second copy around
            self.number_of_layers = len(initial_parameters.tensors)
            return initial_parameters
        ndarrays = get_parameters(Net())
        self.number_of_layers = len(ndarrays)
        return ndarrays_to_parameters(ndarrays)

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[tuple[float, dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function."""
        if self.evaluate_fn is None:
            # No evaluation function provided
            return None
        parameters_ndarrays = parameters_to_ndarrays(parameters)
        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
        if eval_res is None:
            return None
        loss, metrics = eval_res
        return loss, metrics

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Picks the next entry of the layer schedule and sends it to every
        sampled client as ``config["trainable_layers"]``.
        """
        if not self.layer_training_sequence:
            # initialize_parameters has not built the schedule yet.
            self.layer_training_sequence = self.generate_layer_training_sequence()
        if self.layer_training_sequence:
            # Wrap around so running more rounds than the schedule holds
            # repeats it instead of raising IndexError.
            position = self.training_sequence_index % len(self.layer_training_sequence)
            trainable_layers = self.layer_training_sequence[position]
        else:
            # Empty schedule (e.g. zero cycles configured): train all layers.
            trainable_layers = -1

        config: Dict[str, Scalar] = {}
        if self.on_fit_config_fn is not None:
            config = dict(self.on_fit_config_fn(server_round))
        config["trainable_layers"] = trainable_layers

        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Log the layer chosen for this round (not the whole schedule).
        print(f"Training on layer {trainable_layers}")
        fit_configurations = [(client, FitIns(parameters, config)) for client in clients]

        self.training_sequence_index = self.training_sequence_index + 1

        return fit_configurations

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        if self.fraction_evaluate == 0.0:
            return []
        config: Dict[str, Scalar] = {}
        if self.on_evaluate_config_fn is not None:
            config = self.on_evaluate_config_fn(server_round)
        evaluate_ins = EvaluateIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Return client/config pairs
        return [(client, evaluate_ins) for client in clients]

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average.

        Returns ``(None, {})`` when there is nothing to aggregate, or when
        failures occurred and ``accept_failures`` is False.
        """
        if not results:
            return None, {}
        if not self.accept_failures and failures:
            return None, {}

        # Short summary instead of dumping every client's serialized parameters.
        print(f"Aggregating fit results from {len(results)} clients")

        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))

        metrics_aggregated: Dict[str, Scalar] = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        return parameters_aggregated, metrics_aggregated

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average."""

        if not results:
            return None, {}
        if not self.accept_failures and failures:
            return None, {}

        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )

        metrics_aggregated: Dict[str, Scalar] = {}
        if self.evaluate_metrics_aggregation_fn:
            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
        return loss_aggregated, metrics_aggregated

    

In [None]:


def server_fn(context: Context) -> ServerAppComponents:
    """Server components for the layer-wise experiment: NUM_OF_ROUNDS rounds of FedPartAvg."""
    return ServerAppComponents(
        config=ServerConfig(num_rounds=NUM_OF_ROUNDS),
        strategy=FedPartAvg(),
    )


server = ServerApp(server_fn=server_fn)

# Launch the simulation with one supernode per partition.
simulation_options = {
    "server_app": server,
    "client_app": client,
    "num_supernodes": NUM_PARTITIONS,
    "backend_config": backend_config,
}
run_simulation(**simulation_options)

# FedProxPart Experiments

In [7]:
class FedProxPartFlowerClient(NumPyClient):
    """Flower client combining layer-wise training with a FedProx proximal term.

    ``fit`` reads two config entries: ``trainable_layers`` (parameter index to
    train, ``-1`` or absent for all) and ``proximal_mu`` (required).

    Args:
        partition_id: Identifier used in log messages.
        net: Model trained and evaluated by this client.
        trainloader: Loader passed to ``proxima_train``.
        valloader: Loader passed to ``test``.
    """

    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the selected layer's parameters, or all of them.

        All parameters are returned when ``trainable_layers`` is absent from
        ``config`` or equals ``-1``. Previously a missing key raised KeyError
        and ``-1`` returned only the last tensor.
        """
        print(f"[Client {self.partition_id}] get_parameters")
        all_parameters = get_parameters(self.net)
        layer = config.get("trainable_layers", -1)
        if layer == -1:
            return all_parameters
        return [all_parameters[layer]]

    def fit(self, parameters, config):
        """Load the global weights, then train the selected layer with the proximal penalty."""
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)
        # Snapshot the just-loaded global weights as a *list* of detached
        # tensors. `copy.deepcopy(net).parameters()` is a one-shot generator:
        # it is consumed by the first batch and yields nothing afterwards.
        global_params = [param.detach().clone() for param in self.net.parameters()]
        freeze_layers(self.net, config.get("trainable_layers", -1))
        proxima_train(self.net, self.trainloader, EPOCHS, config["proximal_mu"], global_params)
        # Report the number of samples: len(loader) is the number of batches,
        # which would skew the example-weighted aggregation on the server.
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load ``parameters`` and report loss / accuracy on the validation loader."""
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}


def client_fn(context: Context) -> Client:
    """Create the layer-wise FedProx client for the partition named in the node config."""
    model = Net().to(DEVICE)
    node_config = context.node_config
    partition_id = node_config["partition-id"]
    num_partitions = node_config["num-partitions"]
    # The global test loader is not needed on the client side.
    loaders = load_datasets(partition_id, num_partitions)
    trainloader, valloader = loaders[0], loaders[1]
    flower_client = FedProxPartFlowerClient(partition_id, model, trainloader, valloader)
    return flower_client.to_client()


# ClientApp wrapping the factory above
client = ClientApp(client_fn=client_fn)

In [8]:
class FedPartProx(FedPartAvg):
    """``FedPartAvg`` that additionally sends a proximal coefficient to clients.

    Every fit config produced by the parent class is extended with
    ``"proximal_mu"``. All constructor arguments are keyword-only and, apart
    from ``proximal_mu`` (required), are forwarded unchanged to ``FedPartAvg``.
    """

    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                Optional[tuple[float, dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        proximal_mu: float,
    ) -> None:
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.proximal_mu = proximal_mu

    def __repr__(self) -> str:
        # Previously returned the parent's name ("FedPartAvg"), a copy-paste slip.
        return "FedPartProx"

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> list[tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Sends the proximal factor mu to the clients
        """
        # Get the standard client/config pairs from the FedPartAvg super-class
        client_config_pairs = super().configure_fit(
            server_round, parameters, client_manager
        )

        # Return client/config pairs with the proximal factor mu added
        return [
            (
                client,
                FitIns(
                    fit_ins.parameters,
                    {**fit_ins.config, "proximal_mu": self.proximal_mu},
                ),
            )
            for client, fit_ins in client_config_pairs
        ]


NameError: name 'FedPartAvg' is not defined

In [None]:
def server_fn(context: Context) -> ServerAppComponents:
    """Server components for the FedProx variant: NUM_OF_ROUNDS rounds with mu = 0.1."""
    return ServerAppComponents(
        config=ServerConfig(num_rounds=NUM_OF_ROUNDS),
        strategy=FedPartProx(proximal_mu=0.1),
    )


server = ServerApp(server_fn=server_fn)

# Launch the simulation with one supernode per partition.
simulation_options = {
    "server_app": server,
    "client_app": client,
    "num_supernodes": NUM_PARTITIONS,
    "backend_config": backend_config,
}
run_simulation(**simulation_options)

# FedMoon Experiments:

In [None]:
class FedMoonPartFlowerClient(NumPyClient):
    """Flower client that trains with MOON's model-contrastive loss.

    After every ``fit`` the local model is saved to
    ``<model_dir>/<partition_id>/prev_net.pt`` and reloaded as the "previous"
    model on the next ``fit``. Layer freezing is not applied: all layers are
    trained every round.

    Args:
        partition_id: Identifier used in log messages and in the checkpoint path.
        net: Model returning ``(representation, logits)`` from ``forward``.
        trainloader: Loader passed to ``train_moon``.
        valloader: Loader passed to ``test_moon``.
        mu: Weight of the contrastive loss (default matches the previously
            hard-coded value).
        temperature: Contrastive temperature (default matches the previously
            hard-coded value).
    """

    def __init__(self, partition_id, net, trainloader, valloader, mu=5, temperature=0.5):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
        self.model_dir = "models"
        self.mu = mu
        self.temperature = temperature

    def get_parameters(self, config):
        """Return all model parameters as NumPy arrays."""
        print(f"[Client {self.partition_id}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Run one round of MOON training and persist the resulting local model."""
        print(f"[Client {self.partition_id}] fit, config: {config}")

        client_dir = os.path.join(self.model_dir, str(self.partition_id))
        prev_path = os.path.join(client_dir, "prev_net.pt")

        # Previous model: the checkpoint from the last round when it exists,
        # otherwise a copy of the local net as it is before the global weights
        # are loaded. The check is on the file itself, so an existing but empty
        # directory no longer makes torch.load fail.
        if os.path.exists(prev_path):
            prev_model = type(self.net)()
            # map_location lets a checkpoint saved on another device load here.
            prev_model.load_state_dict(torch.load(prev_path, map_location=DEVICE))
        else:
            prev_model = copy.deepcopy(self.net)
        # train_moon feeds it batches that live on DEVICE.
        prev_model.to(DEVICE)

        # update params for current model (loading global params)
        set_parameters(self.net, parameters)

        # create global model (same params that were just loaded)
        global_model = type(self.net)()
        global_model.load_state_dict(self.net.state_dict())
        global_model.to(DEVICE)

        # Debugging tags printed by train_moon.
        self.net.clientid = self.partition_id
        global_model.clientid = self.partition_id
        prev_model.clientid = self.partition_id

        train_moon(
            self.net,
            self.trainloader,
            global_model,
            prev_model,
            EPOCHS,
            mu=self.mu,
            temperature=self.temperature,
            partition_id=self.partition_id,
        )

        # save current model so it becomes the "previous" model next round
        os.makedirs(client_dir, exist_ok=True)
        torch.save(self.net.state_dict(), prev_path)

        # Report the number of samples: len(loader) is the number of batches,
        # which would skew the example-weighted aggregation on the server.
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load ``parameters`` and report loss / accuracy on the validation loader."""
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test_moon(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}


def client_fn(context: Context) -> Client:
    """Build the Flower client for the partition assigned to this node.

    Reads ``partition-id`` and ``num-partitions`` from ``context.node_config``,
    loads that partition's train/validation loaders, and wraps a new
    ``MoonNet`` on ``DEVICE`` in a ``FedMoonPartFlowerClient``.
    """
    model = MoonNet().to(DEVICE)

    node_cfg = context.node_config
    pid = node_cfg["partition-id"]
    total_partitions = node_cfg["num-partitions"]

    # The third value returned by load_datasets is not needed by the client.
    train_dl, val_dl, _ = load_datasets(pid, total_partitions)

    moon_client = FedMoonPartFlowerClient(pid, model, train_dl, val_dl)
    return moon_client.to_client()


# Create the ClientApp, which uses client_fn to build the per-partition clients
client = ClientApp(client_fn=client_fn)

In [None]:
# Train FedMOON
def server_fn(context: Context) -> ServerAppComponents:
    """Build the server components for the FedMOON run.

    Only the number of rounds is configured, taken from ``NUM_OF_ROUNDS``.
    No strategy is passed, so Flower falls back to its default strategy.
    """
    return ServerAppComponents(config=ServerConfig(num_rounds=NUM_OF_ROUNDS))

# Wrap server_fn in a ServerApp so the simulation can build the server side.
server = ServerApp(server_fn=server_fn)

# Run the simulation with one supernode per data partition (NUM_PARTITIONS).
# `client` and `backend_config` are defined elsewhere in the notebook.
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=25, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57379)[0m [Client 0] get_parameters
[36m(ClientAppActor pid=57379)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=57379)[0m Started training moon
[36m(ClientAppActor pid=57379)[0m Client 0: at epoch 0 batch 0, similarity between current and global is 1.0
[36m(ClientAppActor pid=57379)[0m Client 0: at epoch 0 batch 0, similarity between current and previous is 0.1600523740053177
[36m(ClientAppActor pid=57379)[0m Client 0: has net 0 prev 0 global 0
[36m(ClientAppActor pid=57379)[0m Epoch: 0 Loss: 2.741122 Loss1: 1.946624 Loss2: 0.794498
[36m(ClientAppActor pid=57378)[0m [Client 1] fit, config: {}[32m [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)[0m
[36m(ClientAppActor pid=57377)[0m Started training moon[32m [repeated 2x across cluster][0m
[36m(ClientAppActor p

[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57378)[0m [Client 1] evaluate, config: {}
[36m(ClientAppActor pid=57377)[0m Epoch: 7 Loss: 1.897076 Loss1: 1.102979 Loss2: 0.794097[32m [repeated 2x across cluster][0m


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57379)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=57379)[0m Started training moon
[36m(ClientAppActor pid=57379)[0m Client 0: at epoch 0 batch 0, similarity between current and global is 1.0
[36m(ClientAppActor pid=57379)[0m Client 0: at epoch 0 batch 0, similarity between current and previous is 0.9699110984802246
[36m(ClientAppActor pid=57379)[0m Client 0: has net 0 prev 0 global 0
[36m(ClientAppActor pid=57377)[0m [Client 0] evaluate, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57379)[0m Epoch: 0 Loss: 4.481057 Loss1: 1.353308 Loss2: 3.127749
[36m(ClientAppActor pid=57377)[0m [Client 1] fit, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57377)[0m Started training moon[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57377)[0m Client 1: at epoch 0 batch 0, similarity between current and global is 1.0[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid

[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57377)[0m [Client 0] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57379)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=57379)[0m Started training moon
[36m(ClientAppActor pid=57379)[0m Client 0: at epoch 0 batch 0, similarity between current and global is 1.0
[36m(ClientAppActor pid=57379)[0m Client 0: at epoch 0 batch 0, similarity between current and previous is 0.7924039363861084
[36m(ClientAppActor pid=57379)[0m Client 0: has net 0 prev 0 global 0
[36m(ClientAppActor pid=57377)[0m Epoch: 7 Loss: 3.858482 Loss1: 0.810878 Loss2: 3.047604[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57377)[0m Client 1: at epoch 0 batch 0, similarity between current and global is 1.0
[36m(ClientAppActor pid=57377)[0m Client 1: at epoch 0 batch 0, similarity between current and previous is 0.7729507088661194
[36m(ClientAppActor pid=57377)[0m Client 1: has net 1 prev 1 global 1
[36m(ClientAppActor pid=57379)[0m [Client 2] evaluate, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor 

[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57379)[0m [Client 2] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57378)[0m [Client 1] fit, config: {}
[36m(ClientAppActor pid=57377)[0m Epoch: 7 Loss: 2.477286 Loss1: 0.590669 Loss2: 1.886617[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57379)[0m Started training moon
[36m(ClientAppActor pid=57377)[0m Client 0: at epoch 0 batch 0, similarity between current and global is 1.0
[36m(ClientAppActor pid=57377)[0m Client 0: at epoch 0 batch 0, similarity between current and previous is 0.6458202600479126
[36m(ClientAppActor pid=57377)[0m Client 0: has net 0 prev 0 global 0
[36m(ClientAppActor pid=57377)[0m [Client 0] evaluate, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57377)[0m [Client 0] fit, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57379)[0m Epoch: 0 Loss: 2.645252 Loss1: 0.950089 Loss2: 1.695163
[36m(ClientAppActor pid=57378)[0m Started training moon[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57379)[0m Client 

[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57377)[0m [Client 0] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57377)[0m [Client 1] fit, config: {}
[36m(ClientAppActor pid=57377)[0m Started training moon
[36m(ClientAppActor pid=57377)[0m Client 1: at epoch 0 batch 0, similarity between current and global is 1.0
[36m(ClientAppActor pid=57377)[0m Client 1: at epoch 0 batch 0, similarity between current and previous is 0.6457317471504211
[36m(ClientAppActor pid=57377)[0m Client 1: has net 1 prev 1 global 1
[36m(ClientAppActor pid=57377)[0m Epoch: 7 Loss: 1.812464 Loss1: 0.422319 Loss2: 1.390144[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57378)[0m Client 2: at epoch 0 batch 0, similarity between current and global is 1.0
[36m(ClientAppActor pid=57378)[0m Client 2: at epoch 0 batch 0, similarity between current and previous is 0.709255039691925
[36m(ClientAppActor pid=57378)[0m Client 2: has net 2 prev 2 global 2
[36m(ClientAppActor pid=57379)[0m [Client 2] evaluate, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor p

[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57379)[0m [Client 2] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57378)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=57378)[0m Started training moon
[36m(ClientAppActor pid=57378)[0m Epoch: 7 Loss: 1.924207 Loss1: 0.316086 Loss2: 1.608120[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57378)[0m Client 0: at epoch 0 batch 0, similarity between current and global is 1.0
[36m(ClientAppActor pid=57378)[0m Client 0: at epoch 0 batch 0, similarity between current and previous is 0.6460323929786682
[36m(ClientAppActor pid=57378)[0m Client 0: has net 0 prev 0 global 0
[36m(ClientAppActor pid=57377)[0m Client 1: at epoch 0 batch 0, similarity between current and global is 1.0
[36m(ClientAppActor pid=57377)[0m Client 1: at epoch 0 batch 0, similarity between current and previous is 0.8255131840705872
[36m(ClientAppActor pid=57377)[0m Client 1: has net 1 prev 1 global 1
[36m(ClientAppActor pid=57378)[0m [Client 0] evaluate, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor 

[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57379)[0m [Client 2] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57377)[0m [Client 1] fit, config: {}
[36m(ClientAppActor pid=57377)[0m Epoch: 7 Loss: 2.211791 Loss1: 0.292984 Loss2: 1.918807[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57379)[0m Started training moon
[36m(ClientAppActor pid=57379)[0m Client 0: at epoch 0 batch 0, similarity between current and global is 1.0
[36m(ClientAppActor pid=57379)[0m Client 0: at epoch 0 batch 0, similarity between current and previous is 0.795071542263031
[36m(ClientAppActor pid=57379)[0m Client 0: has net 0 prev 0 global 0
[36m(ClientAppActor pid=57378)[0m [Client 1] evaluate, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57378)[0m [Client 2] fit, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57379)[0m Epoch: 0 Loss: 2.548912 Loss1: 0.772506 Loss2: 1.776406
[36m(ClientAppActor pid=57378)[0m Started training moon[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57378)[0m Client 2

[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57379)[0m [Client 2] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 8]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57377)[0m [Client 1] fit, config: {}
[36m(ClientAppActor pid=57377)[0m Started training moon
[36m(ClientAppActor pid=57378)[0m Epoch: 7 Loss: 1.923394 Loss1: 0.240490 Loss2: 1.682904[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57379)[0m Client 2: at epoch 0 batch 0, similarity between current and global is 1.0
[36m(ClientAppActor pid=57379)[0m Client 2: at epoch 0 batch 0, similarity between current and previous is 0.7044244408607483
[36m(ClientAppActor pid=57379)[0m Client 2: has net 2 prev 2 global 2
[36m(ClientAppActor pid=57378)[0m [Client 0] evaluate, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57378)[0m [Client 0] fit, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57378)[0m Started training moon[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57379)[0m Epoch: 0 Loss: 2.618172 Loss1: 0.710835 Loss2: 1.907337
[36m(ClientAppActor pid=57378)[0m Client 

[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57377)[0m [Client 0] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 9]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57377)[0m [Client 2] fit, config: {}
[36m(ClientAppActor pid=57377)[0m Started training moon
[36m(ClientAppActor pid=57377)[0m Client 2: at epoch 0 batch 0, similarity between current and global is 1.0
[36m(ClientAppActor pid=57377)[0m Client 2: at epoch 0 batch 0, similarity between current and previous is 0.666132390499115
[36m(ClientAppActor pid=57377)[0m Client 2: has net 2 prev 2 global 2
[36m(ClientAppActor pid=57377)[0m Epoch: 7 Loss: 2.227530 Loss1: 0.223230 Loss2: 2.004300[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57378)[0m Client 1: at epoch 0 batch 0, similarity between current and global is 1.0
[36m(ClientAppActor pid=57378)[0m Client 1: at epoch 0 batch 0, similarity between current and previous is 0.6805444955825806
[36m(ClientAppActor pid=57378)[0m Client 1: has net 1 prev 1 global 1
[36m(ClientAppActor pid=57379)[0m [Client 2] evaluate, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor p

[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57377)[0m [Client 0] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 10]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=57379)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=57379)[0m Started training moon
[36m(ClientAppActor pid=57377)[0m Epoch: 7 Loss: 2.078840 Loss1: 0.219332 Loss2: 1.859508[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=57379)[0m Client 0: at epoch 0 batch 0, similarity between current and global is 1.0
[36m(ClientAppActor pid=57379)[0m Client 0: at epoch 0 batch 0, similarity between current and previous is 0.6702529191970825
[36m(ClientAppActor pid=57379)[0m Client 0: has net 0 prev 0 global 0
[36m(ClientAppActor pid=57377)[0m Client 1: at epoch 0 batch 0, similarity between current and global is 1.0
[36m(ClientAppActor pid=57377)[0m Client 1: at epoch 0 batch 0, similarity between current and previous is 0.7199082374572754
[36m(ClientAppActor pid=57377)[0m Client 1: has net 1 prev 1 global 1
[36m(ClientAppActor pid=57379)[0m [Client 2] evaluate, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor 