In [21]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import flwr as fl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10

# FedProx algorithm using package
from torch.optim.lr_scheduler import StepLR
import copy

DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
print(f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}")

Training on cpu using PyTorch 1.13.1+cu117 and Flower 1.1.0


In [22]:
# Data preprocessing, divide data into 10 equal parts, to simulate at most 10 clients
NUM_CLIENTS = 10

def load_datasets(num_clients: int): 
    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
      [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # Split training set into `num_clients` partitions to simulate different local datasets
    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))

    # Split each partition into train/val and create DataLoader
    # Default batch_size = 32
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True)) # shuffle to random set dataset
        valloaders.append(DataLoader(ds_val, batch_size=32))
    testloader = DataLoader(testset, batch_size=32) 
    return trainloaders, valloaders, testloader

trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)

Files already downloaded and verified
Files already downloaded and verified


In [23]:
# Define Pytorch modle for all client
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor: # 預測function
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Get pytorch model parameter as ndarray due to flower mechanism
def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)

# Client train model by Client's dataset
# net: client model
# trainloader: client's dataset
# epochs: the number of total epochs, decide by server
# config: the config pass by server, include some FedProx hyperparameter
# globol_model_ndarry: global model parameters with List[np.ndarray] type
def train(net, trainloader, epochs: int, config, globol_model_ndarry):
    # FedProx hyperparameter, server pass
    gamma = config["gamma"]
    mu = config["mu"]
    learning_rate = config["learning_rate"]
    
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss() # Client define loss function
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate) # Client define optimizer
    net.train()
    
    globol_model_ndarry_copy = copy.deepcopy(globol_model_ndarry) # Copy server module parameters
    stepLR = StepLR(optimizer, step_size=10, gamma=gamma) # Learning rate decrease(FedProx), the step_size is default to 10
    
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            
            # compute FedProx proximal_term
            proximal_term = 0.0
            for w, w_t in zip(net.parameters(), globol_model_ndarry_copy):
                w_t_tensor = torch.from_numpy(w_t) # globol model parameters type, from ndarrau to tensor
                proximal_term += (w - w_t_tensor).norm(2)
            
            loss = criterion(net(images), labels) + (mu / 2) * proximal_term
            loss.backward()
            optimizer.step()
            
            # Evaluate Metrics
            epoch_loss += loss
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        
        # Evaulate
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")
        
        stepLR.step() # Learning rate decreasse
    
def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

In [24]:
from flwr.common import Code, EvaluateIns, EvaluateRes, FitIns, FitRes, GetParametersIns, GetParametersRes, Status
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays


class FlowerClient(fl.client.Client):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        print(f"[Client {self.cid}] get_parameters")

        # Get parameters as a list of NumPy ndarray's
        ndarrays: List[np.ndarray] = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object
        parameters = ndarrays_to_parameters(ndarrays)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return GetParametersRes(
            status=status,
            parameters=parameters,
        )

    def fit(self, ins: FitIns) -> FitRes:
        print(f"[Client {self.cid}] fit, config: {ins.config}")
        
        # Globol model parameters
        # Deserialize parameters to NumPy ndarray's
        parameters_original = ins.parameters
        ndarrays_original = parameters_to_ndarrays(parameters_original)

        # Update local model, train, get updated parameters
        set_parameters(self.net, ndarrays_original)
        train(self.net, self.trainloader, epochs=ins.config["locol_epochs"], config=ins.config, globol_model_ndarry=ndarrays_original)
        ndarrays_updated = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object
        parameters_updated = ndarrays_to_parameters(ndarrays_updated)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return FitRes(
            status=status,
            parameters=parameters_updated,
            num_examples=len(self.trainloader),
            metrics={},
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        print(f"[Client {self.cid}] evaluate, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's
        parameters_original = ins.parameters
        ndarrays_original = parameters_to_ndarrays(parameters_original)

        set_parameters(self.net, ndarrays_original)
        loss, accuracy = test(self.net, self.valloader)
        # return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return EvaluateRes(
            status=status,
            loss=float(loss),
            num_examples=len(self.valloader),
            metrics={"accuracy": float(accuracy)},
        )

def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)

In [25]:
def fit_config(server_round: int) -> Dict[str, str]:
    config = {
        "gamma": 0.1,
        "mu": 0.01,
        "learning_rate": 0.01,
        "locol_epochs": 1,
    }
    return config


strategy = fl.server.strategy.FedAvg(
        fraction_fit=1.0,  # Sample 100% of available clients for training
        fraction_evaluate=0.5,  # Sample 50% of available clients for evaluation
        on_fit_config_fn = fit_config,
)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=2, # This example only simulate two clients
    config=fl.server.ServerConfig(num_rounds=3),
    strategy = strategy,
)

INFO flower 2022-12-21 21:45:10,968 | app.py:140 | Starting Flower simulation, config: ServerConfig(num_rounds=3, round_timeout=None)
2022-12-21 21:45:14,178	INFO worker.py:1518 -- Started a local Ray instance.
INFO flower 2022-12-21 21:45:15,384 | app.py:174 | Flower VCE: Ray initialized with resources: {'object_store_memory': 1446241075.0, 'CPU': 4.0, 'node:192.168.0.4': 1.0, 'memory': 2892482151.0}
INFO flower 2022-12-21 21:45:15,385 | server.py:86 | Initializing global parameters
INFO flower 2022-12-21 21:45:15,386 | server.py:270 | Requesting initial parameters from one random client
INFO flower 2022-12-21 21:45:18,247 | server.py:274 | Received initial parameters from one random client
INFO flower 2022-12-21 21:45:18,248 | server.py:88 | Evaluating initial parameters
INFO flower 2022-12-21 21:45:18,249 | server.py:101 | FL starting
DEBUG flower 2022-12-21 21:45:18,250 | server.py:215 | fit_round 1: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_get_parameters pid=40286)[0m [Client 1] get_parameters
[2m[36m(launch_and_fit pid=40286)[0m [Client 0] fit, config: {'gamma': 0.1, 'mu': 0.01, 'learning_rate': 0.01, 'locol_epochs': 1}
[2m[36m(launch_and_fit pid=40287)[0m [Client 1] fit, config: {'gamma': 0.1, 'mu': 0.01, 'learning_rate': 0.01, 'locol_epochs': 1}
[2m[36m(launch_and_fit pid=40286)[0m Epoch 1: train loss 0.06942853331565857, accuracy 0.21666666666666667


DEBUG flower 2022-12-21 21:45:23,323 | server.py:229 | fit_round 1 received 2 results and 0 failures
DEBUG flower 2022-12-21 21:45:23,330 | server.py:165 | evaluate_round 1: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_fit pid=40287)[0m Epoch 1: train loss 0.06889630109071732, accuracy 0.20755555555555555


DEBUG flower 2022-12-21 21:45:25,522 | server.py:179 | evaluate_round 1 received 2 results and 0 failures
DEBUG flower 2022-12-21 21:45:25,523 | server.py:215 | fit_round 2: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_evaluate pid=40286)[0m [Client 0] evaluate, config: {}
[2m[36m(launch_and_evaluate pid=40287)[0m [Client 1] evaluate, config: {}
[2m[36m(launch_and_fit pid=40286)[0m [Client 1] fit, config: {'gamma': 0.1, 'mu': 0.01, 'learning_rate': 0.01, 'locol_epochs': 1}
[2m[36m(launch_and_fit pid=40287)[0m [Client 0] fit, config: {'gamma': 0.1, 'mu': 0.01, 'learning_rate': 0.01, 'locol_epochs': 1}


DEBUG flower 2022-12-21 21:45:30,084 | server.py:229 | fit_round 2 received 2 results and 0 failures
DEBUG flower 2022-12-21 21:45:30,090 | server.py:165 | evaluate_round 2: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_fit pid=40287)[0m Epoch 1: train loss 0.06463585793972015, accuracy 0.24466666666666667
[2m[36m(launch_and_fit pid=40286)[0m Epoch 1: train loss 0.06322526931762695, accuracy 0.27244444444444443


DEBUG flower 2022-12-21 21:45:32,737 | server.py:179 | evaluate_round 2 received 2 results and 0 failures
DEBUG flower 2022-12-21 21:45:32,741 | server.py:215 | fit_round 3: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_evaluate pid=40286)[0m [Client 1] evaluate, config: {}
[2m[36m(launch_and_evaluate pid=40287)[0m [Client 0] evaluate, config: {}
[2m[36m(launch_and_fit pid=40286)[0m [Client 1] fit, config: {'gamma': 0.1, 'mu': 0.01, 'learning_rate': 0.01, 'locol_epochs': 1}
[2m[36m(launch_and_fit pid=40287)[0m [Client 0] fit, config: {'gamma': 0.1, 'mu': 0.01, 'learning_rate': 0.01, 'locol_epochs': 1}


DEBUG flower 2022-12-21 21:45:37,124 | server.py:229 | fit_round 3 received 2 results and 0 failures
DEBUG flower 2022-12-21 21:45:37,129 | server.py:165 | evaluate_round 3: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_fit pid=40286)[0m Epoch 1: train loss 0.0607190765440464, accuracy 0.30622222222222223
[2m[36m(launch_and_fit pid=40287)[0m Epoch 1: train loss 0.06096450611948967, accuracy 0.31577777777777777
[2m[36m(launch_and_evaluate pid=40287)[0m [Client 1] evaluate, config: {}


DEBUG flower 2022-12-21 21:45:56,285 | server.py:179 | evaluate_round 3 received 2 results and 0 failures
INFO flower 2022-12-21 21:45:56,286 | server.py:144 | FL finished in 38.036099082000874
INFO flower 2022-12-21 21:45:56,287 | app.py:192 | app_fit: losses_distributed [(1, 0.07017705440521241), (2, 0.06140286910533905), (3, 0.058093232035636905)]
INFO flower 2022-12-21 21:45:56,288 | app.py:193 | app_fit: metrics_distributed {}
INFO flower 2022-12-21 21:45:56,289 | app.py:194 | app_fit: losses_centralized []
INFO flower 2022-12-21 21:45:56,289 | app.py:195 | app_fit: metrics_centralized {}


[2m[36m(launch_and_evaluate pid=40289)[0m [Client 0] evaluate, config: {}


History (loss, distributed):
	round 1: 0.07017705440521241
	round 2: 0.06140286910533905
	round 3: 0.058093232035636905