In [6]:
pip install flwr-datasets

Collecting flwr-datasets
  Downloading flwr_datasets-0.5.0-py3-none-any.whl.metadata (6.9 kB)
Collecting datasets<=3.1.0,>=2.14.6 (from flwr-datasets)
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting pyarrow>=15.0.0 (from datasets<=3.1.0,>=2.14.6->flwr-datasets)
  Downloading pyarrow-18.1.0-cp312-cp312-win_amd64.whl.metadata (3.4 kB)
Collecting xxhash (from datasets<=3.1.0,>=2.14.6->flwr-datasets)
  Downloading xxhash-3.5.0-cp312-cp312-win_amd64.whl.metadata (13 kB)
Collecting multiprocess<0.70.17 (from datasets<=3.1.0,>=2.14.6->flwr-datasets)
  Downloading multiprocess-0.70.16-py312-none-any.whl.metadata (7.2 kB)
Collecting huggingface-hub>=0.23.0 (from datasets<=3.1.0,>=2.14.6->flwr-datasets)
  Downloading huggingface_hub-0.27.1-py3-none-any.whl.metadata (13 kB)
Downloading flwr_datasets-0.5.0-py3-none-any.whl (87 kB)
   ---------------------------------------- 0.0/87.0 kB ? eta -:--:--
   ---- ----------------------------------- 10.2/87.0 kB ? eta -:--:--
  

  You can safely remove it manually.


In [12]:
pip install pyarrow[parquet]

Note: you may need to restart the kernel to use updated packages.




In [1]:
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

from utils3 import *

In [4]:
# Prepare data using Flower datasets

# FederatedDataset objects keyed by partition count. Every client calls
# load_data once per round, so reusing the object avoids rebuilding the
# federated dataset each time.
_FDS_CACHE = {}


def load_data(partition_id, num_partitions=5, batch_size=64, test_size=0.2, seed=42):
    """Load one MNIST partition and split it into train/test DataLoaders.

    Args:
        partition_id: Index of the partition of the "train" split to load.
        num_partitions: Number of partitions the "train" split is divided
            into. Defaults to 5.
        batch_size: Batch size for both DataLoaders. Defaults to 64.
        test_size: Fraction of the partition held out for local testing.
            Defaults to 0.2.
        seed: Seed for the train/test split so it is reproducible.
            Defaults to 42.

    Returns:
        (trainloader, testloader). Only the training loader shuffles.
    """
    fds = _FDS_CACHE.get(num_partitions)
    if fds is None:
        fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
        _FDS_CACHE[num_partitions] = fds
    partition = fds.load_partition(partition_id)

    traintest = partition.train_test_split(test_size=test_size, seed=seed)
    # `normalize` is applied lazily on access (with_transform), not eagerly.
    traintest = traintest.with_transform(normalize)
    trainset, testset = traintest["train"], traintest["test"]

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testloader = DataLoader(testset, batch_size=batch_size)
    return trainloader, testloader

In [6]:
# Flower can send configuration values to clients
def fit_config(server_round: int):
    """Return the fit configuration the server sends to clients each round.

    Rounds before 3 use a short local schedule (2 epochs); round 3 and later
    train longer (5 epochs).
    """
    if server_round < 3:
        epochs_this_round = 2
    else:
        epochs_this_round = 5
    return {"local_epochs": epochs_this_round}

In [8]:
# Initial global model: FedAvg starts from these weights.
net = SimpleModel()
params = ndarrays_to_parameters(get_weights(net))

def server_fn(context: Context):
    """Build the server-side components for the simulation.

    Uses FedAvg seeded with the module-level ``params`` and the per-round
    ``fit_config`` schedule, and runs for 3 rounds.

    Args:
        context: Flower run context (not used here).

    Returns:
        ServerAppComponents holding the strategy and the server config.
    """
    strategy = FedAvg(
        min_fit_clients=5,
        # Must be >= min_fit_clients. FedAvg's default (2) is lower, which
        # Flower flags with a "min_available_clients too low" warning.
        min_available_clients=5,
        # Federated evaluation is disabled: clients' evaluate() is not called.
        fraction_evaluate=0.0,
        initial_parameters=params,
        # Sends {"local_epochs": ...} to the clients before each fit round.
        on_fit_config_fn=fit_config,
    )
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(
        strategy=strategy,
        config=config,
    )

In [10]:
# Define Server App.
# The ServerApp gets its strategy and config by calling server_fn.
server = ServerApp(
    server_fn=server_fn,
)

In [12]:
# Define Flower Client
class FlowerClient(NumPyClient):
    """NumPyClient that trains and evaluates a local model on one data partition.

    Args:
        net: Model whose weights are exchanged with the server.
        trainloader: DataLoader used for local training in ``fit``.
        testloader: DataLoader used for local evaluation in ``evaluate``.
    """

    def __init__(self, net, trainloader, testloader):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader

    def fit(self, parameters, config):
        """Load the global weights, train locally, and return the updated weights.

        ``config["local_epochs"]`` is supplied by the server's on_fit_config_fn.

        Returns:
            (weights, num_examples, metrics). num_examples counts training
            samples (not batches); FedAvg uses it as the aggregation weight.
        """
        set_weights(self.net, parameters)

        epochs = config["local_epochs"]
        log(INFO, f"client trains for {epochs} epochs")
        train_model(self.net, self.trainloader, epochs)

        # len(DataLoader) is the number of batches; FedAvg needs the sample count.
        return get_weights(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate the given global weights on the local test split.

        Returns:
            (loss, num_examples, {"accuracy": accuracy}). num_examples counts
            test samples, not batches.
        """
        set_weights(self.net, parameters)
        loss, accuracy = evaluate_model(self.net, self.testloader)
        return loss, len(self.testloader.dataset), {"accuracy": accuracy}

In [14]:
# Create the Client Function and the Client App
def client_fn(context: Context) -> Client:
    """Build the client for the simulated node described by ``context``.

    Each call creates a fresh SimpleModel. The node's "partition-id" entry
    selects which data partition the client loads.
    """
    model = SimpleModel()
    node_partition = int(context.node_config["partition-id"])
    train_loader, test_loader = load_data(partition_id=node_partition)
    numpy_client = FlowerClient(model, train_loader, test_loader)
    return numpy_client.to_client()


client = ClientApp(client_fn=client_fn)

In [None]:
# Run Client and Server App.
# Five simulated SuperNodes match the five data partitions and min_fit_clients=5.
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=5,
    backend_config=backend_setup,
)

[92mINFO [0m: Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m: 
[92mINFO [0m: [INIT]
[92mINFO [0m: Using initial global parameters provided by strategy
[92mINFO [0m: Starting evaluation of initial global parameters
[92mINFO [0m: Evaluation returned no results (`None`)
[92mINFO [0m: 
[92mINFO [0m: [ROUND 1]
[92mINFO [0m: configure_fit: strategy sampled 5 clients (out of 5)
