In [2]:
%pip install -q flwr[simulation] torch torchvision


In [3]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import flwr as fl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10

DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
print(f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}")

Training on cpu using PyTorch 1.12.1+cu113 and Flower 1.1.0


In [4]:
NUM_CLIENTS = 10

def load_datasets(num_clients: int):
    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
      [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # Split training set into `num_clients` partitions to simulate different local datasets
    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=32))
    testloader = DataLoader(testset, batch_size=32)
    return trainloaders, valloaders, testloader

trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

In [6]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.cid}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}
    
    def get_properties(self, config):
        Client_RAM = {"RAM": 2} # Client will has there way to get RAM number, In this example we use 2GB
        return Client_RAM


def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)

In [7]:
from typing import Callable, Union

from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
    GetPropertiesIns
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy

class Client_selection_RAM_strategy(fl.server.strategy.FedAvg):
    def __repr__(self) -> str:
        return "Client_selection_RAM_strategy"

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training."""
        weights = parameters_to_ndarrays(parameters)
        self.pre_weights = weights
        parameters = ndarrays_to_parameters(weights)
        config = {}
        if self.on_fit_config_fn is not None:
            # Custom fit config function provided
            config = self.on_fit_config_fn(server_round)
        fit_ins = FitIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        
        all_clients = client_manager.all() # Dict[str, ClientProxy], Return all available clients
        selected_clients = [] # the clients list we will select

        for client in all_clients.values(): # look all clients  
            config_properties = GetPropertiesIns({"RAM": 0}) # config (Config) – Configuration parameters requested by the server. This can be used to tell the client which properties are needed along with some Scalar attributes.
            client_properties = client.get_properties(config_properties, timeout=2.0) # get each client properties, pass GetPropertiesIns and timeout parameter
            client_propertie = client_properties.properties
            if client_propertie["RAM"] > 1: # Choose the client which RAM > 1GB
                selected_clients.append(client)

        if len(selected_clients) == 0:
            print("No client has be selected")
        # Return client/config pairs
        return [(client, fit_ins) for client in selected_clients]      

In [8]:
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=2,
    config=fl.server.ServerConfig(num_rounds=2),
    strategy=Client_selection_RAM_strategy(),  # <-- pass the new strategy here
)

INFO flower 2022-11-20 13:21:53,846 | app.py:143 | Starting Flower simulation, config: ServerConfig(num_rounds=2, round_timeout=None)
INFO:flower:Starting Flower simulation, config: ServerConfig(num_rounds=2, round_timeout=None)
2022-11-20 13:21:55,551	INFO worker.py:1518 -- Started a local Ray instance.
INFO flower 2022-11-20 13:21:58,455 | app.py:177 | Flower VCE: Ray initialized with resources: {'memory': 7956017972.0, 'CPU': 2.0, 'node:172.28.0.2': 1.0, 'object_store_memory': 3978008985.0}
INFO:flower:Flower VCE: Ray initialized with resources: {'memory': 7956017972.0, 'CPU': 2.0, 'node:172.28.0.2': 1.0, 'object_store_memory': 3978008985.0}
INFO flower 2022-11-20 13:21:58,463 | server.py:86 | Initializing global parameters
INFO:flower:Initializing global parameters
INFO flower 2022-11-20 13:21:58,467 | server.py:270 | Requesting initial parameters from one random client
INFO:flower:Requesting initial parameters from one random client
INFO flower 2022-11-20 13:22:03,043 | server.py:

[2m[36m(launch_and_get_parameters pid=4740)[0m [Client 1] get_parameters


DEBUG flower 2022-11-20 13:22:08,940 | server.py:220 | fit_round 1: strategy sampled 2 clients (out of 2)
DEBUG:flower:fit_round 1: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_fit pid=4740)[0m [Client 1] fit, config: {}
[2m[36m(launch_and_fit pid=4741)[0m [Client 0] fit, config: {}
[2m[36m(launch_and_fit pid=4740)[0m Epoch 1: train loss 0.06344978511333466, accuracy 0.24022222222222223


DEBUG flower 2022-11-20 13:22:18,425 | server.py:234 | fit_round 1 received 2 results and 0 failures
DEBUG:flower:fit_round 1 received 2 results and 0 failures
DEBUG flower 2022-11-20 13:22:18,461 | server.py:170 | evaluate_round 1: strategy sampled 2 clients (out of 2)
DEBUG:flower:evaluate_round 1: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_fit pid=4741)[0m Epoch 1: train loss 0.06556787341833115, accuracy 0.2262222222222222
[2m[36m(launch_and_evaluate pid=4740)[0m [Client 0] evaluate, config: {}
[2m[36m(launch_and_evaluate pid=4741)[0m [Client 1] evaluate, config: {}


DEBUG flower 2022-11-20 13:22:21,765 | server.py:184 | evaluate_round 1 received 2 results and 0 failures
DEBUG:flower:evaluate_round 1 received 2 results and 0 failures
DEBUG flower 2022-11-20 13:22:24,597 | server.py:220 | fit_round 2: strategy sampled 2 clients (out of 2)
DEBUG:flower:fit_round 2: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_fit pid=4741)[0m [Client 1] fit, config: {}
[2m[36m(launch_and_fit pid=4740)[0m [Client 0] fit, config: {}
[2m[36m(launch_and_fit pid=4741)[0m Epoch 1: train loss 0.05627148225903511, accuracy 0.3388888888888889


DEBUG flower 2022-11-20 13:22:32,780 | server.py:234 | fit_round 2 received 2 results and 0 failures
DEBUG:flower:fit_round 2 received 2 results and 0 failures
DEBUG flower 2022-11-20 13:22:32,793 | server.py:170 | evaluate_round 2: strategy sampled 2 clients (out of 2)
DEBUG:flower:evaluate_round 2: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_fit pid=4740)[0m Epoch 1: train loss 0.05621957778930664, accuracy 0.3388888888888889
[2m[36m(launch_and_evaluate pid=4740)[0m [Client 0] evaluate, config: {}
[2m[36m(launch_and_evaluate pid=4741)[0m [Client 1] evaluate, config: {}


DEBUG flower 2022-11-20 13:22:36,067 | server.py:184 | evaluate_round 2 received 2 results and 0 failures
DEBUG:flower:evaluate_round 2 received 2 results and 0 failures
INFO flower 2022-11-20 13:22:36,075 | server.py:144 | FL finished in 32.97151034700073
INFO:flower:FL finished in 32.97151034700073
INFO flower 2022-11-20 13:22:36,080 | app.py:192 | app_fit: losses_distributed [(1, 0.06254182982444763), (2, 0.055628442406654356)]
INFO:flower:app_fit: losses_distributed [(1, 0.06254182982444763), (2, 0.055628442406654356)]
INFO flower 2022-11-20 13:22:36,088 | app.py:193 | app_fit: metrics_distributed {}
INFO:flower:app_fit: metrics_distributed {}
INFO flower 2022-11-20 13:22:36,095 | app.py:194 | app_fit: losses_centralized []
INFO:flower:app_fit: losses_centralized []
INFO flower 2022-11-20 13:22:36,101 | app.py:195 | app_fit: metrics_centralized {}
INFO:flower:app_fit: metrics_centralized {}


History (loss, distributed):
	round 1: 0.06254182982444763
	round 2: 0.055628442406654356