In [1]:
import flwr as fl
import sys
import numpy as np
from tensorflow import keras
import os
import pickle

global_model = keras.Sequential([
            keras.layers.Flatten(input_shape=(28,28)),
            keras.layers.Dense(128, activation='relu'),
            keras.layers.Dense(256, activation='relu'),
            keras.layers.Dense(10, activation='softmax')
        ])

In [2]:
# Define a custom strategy that saves the aggregated weights after each round.
class SaveModelStrategy(fl.server.strategy.FedAvg):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_dir = './'
    def aggregate_fit(self, rnd, results, failures):
        """Aggregates the weights from the clients and saves the aggregated weights
        to a file."""
        results_dict = {}
        for client_id, client_weights in results:
            results_dict[client_id] = client_weights
        print(results_dict.items())

        # Save the weights of each client to a separate file.
        for client_id, client_weights in results_dict.items():
            client_weights_bytes = pickle.dumps(client_weights)
            client_weights_ndarrays = fl.common.bytes_to_ndarray(client_weights_bytes)
            np.savez(f"{self.save_dir}/round-{rnd}-client-{client_id}-weights.npz", *client_weights_ndarrays)


        # Aggregate the weights from the clients.
        aggregated_weights, agg_metrics = super().aggregate_fit(rnd, results, failures)

        # Save the aggregated weights to a file.
        if aggregated_weights is not None:
            print(f"Saving round {rnd} aggregated_weights to npz file...")
            aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_weights)
            np.savez(f"round-{rnd}-weights.npz", *aggregated_ndarrays)

            # Set the global model weights to the aggregated weights.
            global_model.set_weights(aggregated_ndarrays)

        return aggregated_weights, agg_metrics


In [3]:
global_model.get_weights()

[array([[-0.0186247 ,  0.02316877,  0.02240704, ...,  0.06275254,
          0.03031649, -0.06502932],
        [-0.08054168, -0.01716843,  0.01335085, ...,  0.03897269,
         -0.01196608, -0.00784412],
        [ 0.0426623 ,  0.01697833, -0.01824363, ...,  0.06636366,
         -0.06258225, -0.04735654],
        ...,
        [-0.02132846,  0.01508735,  0.03979226, ...,  0.04802003,
         -0.05110517, -0.03852886],
        [-0.04995808, -0.03515657,  0.03035498, ...,  0.07391796,
         -0.02808165, -0.04586395],
        [ 0.07336161, -0.04918544, -0.02184445, ...,  0.00994907,
         -0.02106065,  0.00291602]], dtype=float32),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [4]:

# Start the server.
strategy = SaveModelStrategy()
PORT=5010
fl.server.start_server(
    server_address='localhost:'+str(PORT),
    config=fl.server.ServerConfig(num_rounds=3),
    grpc_max_message_length=1024*1024*1024,
    strategy=strategy
)

INFO flwr 2023-11-03 17:39:48,520 | app.py:162 | Starting Flower server, config: ServerConfig(num_rounds=3, round_timeout=None)
INFO flwr 2023-11-03 17:39:48,589 | app.py:175 | Flower ECE: gRPC server running (3 rounds), SSL is disabled
INFO flwr 2023-11-03 17:39:48,590 | server.py:89 | Initializing global parameters
INFO flwr 2023-11-03 17:39:48,590 | server.py:276 | Requesting initial parameters from one random client
INFO flwr 2023-11-03 17:39:58,365 | server.py:280 | Received initial parameters from one random client
INFO flwr 2023-11-03 17:39:58,367 | server.py:91 | Evaluating initial parameters
INFO flwr 2023-11-03 17:39:58,368 | server.py:104 | FL starting
DEBUG flwr 2023-11-03 17:40:00,836 | server.py:222 | fit_round 1: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-11-03 17:40:03,016 | server.py:236 | fit_round 1 received 2 results and 0 failures
IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing i

ValueError: Cannot load file containing pickled data when allow_pickle=False

In [None]:
global_model.summary()

In [None]:
global_model.get_weights()

In [None]:
import tensorflow as tf
global_model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
(_,_), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_test = x_test/255.0
x_test = x_test.reshape(x_test.shape[0],28,28,1)
acc = global_model.evaluate(x_test, y_test)
print(acc)

# Getting Hyperparametres

In [None]:
npzfile = np.load("./round-3-weights.npz")
updated_model_weights = [npzfile[key] for key in npzfile.files]

In [None]:
# Check the shapes and types of the global_model weights
global_weights = global_model.get_weights()

for i, w in enumerate(global_weights):
    print(f"Weight {i} - Shape: {w.shape}, Type: {w.dtype}")


In [None]:
for i, (gw, w) in enumerate(zip(global_weights, updated_model_weights)):
    if gw.shape != w.shape:
        print(f"Weight {i} shape mismatch - Global: {gw.shape}, Updated: {w.shape}")
    if gw.dtype != w.dtype:
        print(f"Weight {i} type mismatch - Global: {gw.dtype}, Updated: {w.dtype}")
    print("All is well.")

In [None]:
updated_model_weights