In [3]:
!pip install tensorflow flwr -U "flwr[simulation]" numpy


Collecting tensorflow
  Downloading tensorflow-2.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (601.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m601.3/601.3 MB[0m [31m589.4 kB/s[0m eta [36m0:00:00[0m
Collecting numpy
  Downloading numpy-2.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (19.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m19.5/19.5 MB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
Collecting h5py>=3.10.0 (from tensorflow)
  Downloading h5py-3.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.3/5.3 MB[0m [31m905.8 kB/s[0m eta [36m0:00:00[0m
Collecting ml-dtypes<0.5.0,>=0.3.1 (from tensorflow)
  Downloading ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m1.1 MB/s[0m eta [36m0:00:00

In [1]:
import tensorflow as tf
import flwr as fl
import numpy as np
from typing import List, Tuple, Dict
from collections import OrderedDict
import warnings
warnings.filterwarnings("ignore")

# Define the model
class Net(tf.keras.Model):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same')
        self.pool1 = tf.keras.layers.MaxPooling2D((2, 2))
        self.conv2 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')
        self.pool2 = tf.keras.layers.MaxPooling2D((2, 2))
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(128, activation='relu')
        self.fc2 = tf.keras.layers.Dense(10)  # CIFAR-10 has 10 classes

    def call(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        return self.fc2(x)

def set_parameters(model, parameters: List[np.ndarray]):
    params_dict = zip(model.trainable_variables, parameters)
    for var, param in params_dict:
        var.assign(param)

def get_parameters(model) -> List[np.ndarray]:
    return [var.numpy() for var in model.trainable_variables]

class FlowerClient(fl.client.NumPyClient):
    def __init__(self, model, trainloader, valloader):
        self.model = model
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        return get_parameters(self.model)

    def fit(self, parameters, config):
        set_parameters(self.model, parameters)
        self.model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
        self.model.fit(self.trainloader, epochs=1, verbose=0)
        return get_parameters(self.model), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        set_parameters(self.model, parameters)
        self.model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           metrics=['accuracy'])
        loss, accuracy = self.model.evaluate(self.valloader, verbose=0)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

def client_fn(cid: str) -> FlowerClient:
    """Create a Flower client representing a single organization."""

    # Load model
    model = Net()

    # Load data (CIFAR-10)
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    x_train, x_test = x_train.astype(np.float32), x_test.astype(np.float32)
    y_train, y_test = y_train.astype(np.int32), y_test.astype(np.int32)

    # Split data for each client (this is a simplification for illustration)
    num_clients = 10
    client_data_size = len(x_train) // num_clients
    start_idx = int(cid) * client_data_size
    end_idx = (int(cid) + 1) * client_data_size

    train_data = tf.data.Dataset.from_tensor_slices((x_train[start_idx:end_idx], y_train[start_idx:end_idx]))
    trainloader = train_data.batch(32)

    val_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    valloader = val_data.batch(32)

    return FlowerClient(model, trainloader, valloader)

def weighted_average(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}

# Create FedAvg strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,
    fraction_evaluate=0.5,
    min_fit_clients=10,
    min_evaluate_clients=5,
    min_available_clients=10,
    evaluate_metrics_aggregation_fn=weighted_average,
)

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=10,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
    client_resources={"num_cpus": 1, "num_gpus": 0.0}
)


[92mINFO [0m:      Starting Flower simulation, config: num_rounds=5, no round_timeout
2024-07-24 13:12:39,101	INFO worker.py:1752 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'object_store_memory': 4001238220.0, 'memory': 8002476443.0, 'CPU': 2.0, 'node:__internal_head__': 1.0, 'node:172.28.0.12': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 0.0}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 2 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(pid=5715)[0m 2024-07-24 13:12:45.077102: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registere

[36m(ClientAppActor pid=5716)[0m Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


[36m(ClientAppActor pid=5716)[0m 
[36m(ClientAppActor pid=5716)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=5716)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=5716)[0m         
[36m(pid=5716)[0m 2024-07-24 13:12:45.125326: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=5716)[0m 2024-07-24 13:12:45.207096: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=5716)[0m 2024-07-24 13:12:45.228599: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


[36m(ClientAppActor pid=5716)[0m [1m        0/170498071[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 0s/step
[1m   909312/170498071[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m25s[0m 0us/step 
[1m  9330688/170498071[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m4s[0m 0us/step
[1m 18595840/170498071[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m2s[0m 0us/step
[1m 26075136/170498071[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m2s[0m 0us/step
[1m 32677888/170498071[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m2s[0m 0us/step
[1m 39632896/170498071[0m [32m━━━━[0m[37m━━━━━━━━━━━━━━━━[0m [1m2s[0m 0us/step
[1m 51355648/170498071[0m [32m━━━━━━[0m[37m━━━━━━━━━━━━━━[0m [1m1s[0m 0us/step
[1m 58720256/170498071[0m [32m━━━━━━[0m[37m━━━━━━━━━━━━━━[0m [1m1s[0m 0us/step
[1m 65978368/170498071[0m [32m━━━━━━━[0m[37m━━━━━━━━━━━━━[0m [1m1s[0m 0us/step
[1m 73768960/170498071[0m [32m━━━━━━━━[0m[37m━━━━━━━━━━━━[0m [1m1s[0m 0us/step
[1m 81141760/170498071

[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Evaluating initial global parameters
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)
[36m(ClientAppActor pid=5716)[0m 
[36m(ClientAppActor pid=5716)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=5716)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=5716)[0m         
[36m(ClientAppActor pid=5715)[0m         [32m [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)[0m
[36m(ClientAppActor pid=5715)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=5715)[0m             entirely in future versions of Flower.
[36m(ClientAppActor

History (loss, distributed):
	round 1: 2.3205811023712157
	round 2: 2.314773750305176
	round 3: 2.312864828109741
	round 4: 2.312894678115845
	round 5: 2.3236358165740967
History (metrics, distributed, evaluate):
{'accuracy': [(1, 0.10070000141859055),
              (2, 0.09345999956130982),
              (3, 0.10750000029802323),
              (4, 0.09937999993562699),
              (5, 0.09931999891996383)]}