In [1]:
!pip install -U "flwr[simulation]" tensorflow numpy

Collecting numpy
  Using cached numpy-2.2.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)


In [2]:
import flwr as fl
import tensorflow as tf
import numpy as np
from typing import List, Tuple, Optional
import random
import time
import sys
import warnings
import os

In [3]:
warnings.filterwarnings('ignore')

In [4]:
# Suppress TensorFlow warnings and other unnecessary warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Suppress TensorFlow logging (0 = all, 3 = none)
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)

In [5]:
# Initial configurations
NUM_CLIENTS = 10
NUM_ROUNDS = 10
FRACTION_FIT = 0.5  # Fraction of clients randomly selected per round
BATCH_SIZE = 32
EPOCHS = 1
DATASET_SIZE_PER_CLIENT = 6000  # Number of samples per client (approx. 60000 / 10)

In [6]:
# Global variables to track metrics
total_data_transferred = 0
start_time = time.time()

In [7]:
# Define the model
def create_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

In [8]:
# Calculate model size (to estimate transferred data volume)
def get_model_size(model):
    total_params = sum(np.prod(w.shape) for w in model.get_weights())
    return total_params * 4  # Assume each parameter is 4 bytes (float32)

In [9]:
# Load and split MNIST dataset among clients
def load_data():
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train, y_train = x_train[:60000], y_train[:60000]  # Limit to 60000 samples
    x_test, y_test = x_test[:10000], y_test[:10000]

    # Normalize data
    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0

    # Split data among clients
    client_data = []
    samples_per_client = len(x_train) // NUM_CLIENTS
    for i in range(NUM_CLIENTS):
        start_idx = i * samples_per_client
        end_idx = (i + 1) * samples_per_client
        client_data.append((x_train[start_idx:end_idx], y_train[start_idx:end_idx]))
    return client_data, (x_test, y_test)

In [10]:
# Define the client class
class MnistClient(fl.client.NumPyClient):
    def __init__(self, cid: str, x_train, y_train, x_test, y_test):
        self.cid = cid
        self.x_train = x_train
        self.y_train = y_train
        self.x_test = x_test
        self.y_test = y_test
        self.model = create_model()

    def get_parameters(self, config):
        return self.model.get_weights()

    def fit(self, parameters, config):
        self.model.set_weights(parameters)
        self.model.fit(self.x_train, self.y_train, batch_size=BATCH_SIZE, epochs=EPOCHS, verbose=0)
        global total_data_transferred
        model_size = get_model_size(self.model)
        total_data_transferred += model_size * 2  # Account for sending and receiving parameters for this client
        return self.model.get_weights(), len(self.x_train), {}

    def evaluate(self, parameters, config):
        self.model.set_weights(parameters)
        loss, accuracy = self.model.evaluate(self.x_test, self.y_test, verbose=0)
        return loss, len(self.x_test), {"accuracy": float(accuracy)}  # Ensure float for serialization

In [11]:
# Define FedAvg strategy with random client selection
strategy = fl.server.strategy.FedAvg(
    fraction_fit=FRACTION_FIT,  # Fraction of clients selected per round
    fraction_evaluate=1.0,      # Evaluate on all clients
    min_fit_clients=int(NUM_CLIENTS * FRACTION_FIT),
    min_evaluate_clients=NUM_CLIENTS,
    min_available_clients=NUM_CLIENTS,
)

In [12]:
# Custom FedAvg strategy to aggregate accuracy and track data transfer
class CustomFedAvg(fl.server.strategy.FedAvg):
    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.EvaluateRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[float], dict]:
        """Aggregate evaluation loss and accuracy."""
        if not results:
            return None, {}

        # Aggregate loss using default FedAvg
        loss_aggregated = super().aggregate_evaluate(server_round, results, failures)[0]

        # Aggregate accuracy
        accuracies = [r.metrics["accuracy"] * r.num_examples for _, r in results]
        examples = [r.num_examples for _, r in results]
        accuracy_aggregated = sum(accuracies) / sum(examples) if examples else 0.0

        return loss_aggregated, {"accuracy": accuracy_aggregated}

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures: List[BaseException],
    ) -> Tuple[fl.common.Parameters, dict]:
        """Aggregate fit results and update data transfer."""
        global total_data_transferred
        if results:
            model_size = get_model_size(create_model())  # Use a sample model to estimate size
            num_clients = len(results)
            total_data_transferred += model_size * 2 * num_clients  # Account for all active clients
        return super().aggregate_fit(server_round, results, failures)

In [13]:
# Define strategy with custom FedAvg
strategy = CustomFedAvg(
    fraction_fit=FRACTION_FIT,  # Fraction of clients selected per round
    fraction_evaluate=1.0,      # Evaluate on all clients
    min_fit_clients=int(NUM_CLIENTS * FRACTION_FIT),
    min_evaluate_clients=NUM_CLIENTS,
    min_available_clients=NUM_CLIENTS,
)

In [14]:
# Function to calculate convergence speed
def calculate_convergence_speed(accuracies: List[float]) -> float:
    if len(accuracies) < 2:
        return 0.0
    improvements = [accuracies[i + 1] - accuracies[i] for i in range(len(accuracies) - 1)]
    avg_improvement = np.mean(improvements)
    return avg_improvement / NUM_ROUNDS  # Average accuracy improvement per round

In [15]:
# Main function to start server and clients
def main():
    # Load data
    client_data, (x_test, y_test) = load_data()

    # List to store accuracies
    accuracies = []

    # Function to create clients
    def client_fn(cid: str) -> fl.client.Client:
        return MnistClient(cid, client_data[int(cid)][0], client_data[int(cid)][1], x_test, y_test)

    # Start simulation
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=NUM_CLIENTS,
        config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
        strategy=strategy,
    )

    # Extract accuracies from history
    for round_num, accuracy in history.metrics_distributed.get('accuracy', []):
        accuracies.append(accuracy)  # Directly append the accuracy value (float)

    # Calculate metrics
    end_time = time.time()
    convergence_speed = calculate_convergence_speed(accuracies)
    total_data_mb = total_data_transferred / (1024 * 1024)  # Convert to megabytes

    # Print results
    print(f"Number of communication rounds: {NUM_ROUNDS}")
    print(f"Total data transferred: {total_data_mb:.2f} MB")
    print(f"Final model accuracy: {accuracies[-1]:.4f}" if accuracies else "No accuracy data")
    print(f"Convergence speed (avg accuracy improvement per round): {convergence_speed:.6f}")

if __name__ == "__main__":
    main()

	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=10, no round_timeout
2025-05-16 10:20:24,537	INFO worker.py:1771 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initial

Number of communication rounds: 10
Total data transferred: 38.82 MB
Final model accuracy: 0.9651
Convergence speed (avg accuracy improvement per round): 0.000641
