# Federated Learning example

### This notebook is referenced by the [official Flower tutorial](https://flower.dev/docs/quickstart-pytorch.html).

In this tutorial we will learn how to train a Convolutional Neural Network on CIFAR10 using Flower and PyTorch.

Our example consists of one server and two clients, all using the same model architecture.

You can use one PC as the server and one or more other PCs as clients.

# Server

### If the current PC is the server, run the following cells to start the server.

In [1]:
from typing import List, Tuple

import flwr as fl
from flwr.common import Metrics

### Define strategy

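The strategy determines how the server aggregates the model weights delivered by the clients. The following example uses Federated Averaging (FedAvg), which averages the client weights, weighting each client by its number of training examples.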
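A minimal pure-Python sketch of the weighted averaging FedAvg performs, assuming each client reports its weights as a list of layers (plain float lists standing in for NumPy arrays) together with its number of training examples. The function name `fedavg_aggregate` is hypothetical and is not part of the Flower API.

```python
from typing import List, Tuple

# Hypothetical representation: one model = a list of layers, each a flat list of floats.
Weights = List[List[float]]


def fedavg_aggregate(results: List[Tuple[Weights, int]]) -> Weights:
    """Average each weight across clients, weighted by the client's example count."""
    total_examples = sum(num_examples for _, num_examples in results)
    first_weights = results[0][0]
    aggregated: Weights = []
    for layer_idx, layer in enumerate(first_weights):
        aggregated.append([
            sum(weights[layer_idx][i] * n for weights, n in results) / total_examples
            for i in range(len(layer))
        ])
    return aggregated


# Client A trained on 1 example, client B on 3, so B's weights count three times as much.
client_a = ([[1.0, 2.0], [0.0]], 1)
client_b = ([[3.0, 4.0], [4.0]], 3)
print(fedavg_aggregate([client_a, client_b]))  # [[2.5, 3.5], [3.0]]
```

Weighting by example count means a client with more local data pulls the global model further toward its own update; with equal counts this reduces to a plain mean.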

In [2]:
strategy = fl.server.strategy.FedAvg()

### Server's IP

In [3]:
SERVER_IP = 'your_own_IP'  # replace with this machine's IP address

### Start Flower server

In [4]:
# Flower 1.x replaces the config dict with config=fl.server.ServerConfig(num_rounds=3)
fl.server.start_server(
    server_address=SERVER_IP + ":8080",
    config={"num_rounds": 3},
    strategy=strategy,
)

INFO flower 2022-07-12 17:13:01,169 | app.py:109 | Flower server running (3 rounds)
SSL is disabled
INFO flower 2022-07-12 17:13:01,171 | server.py:128 | Initializing global parameters
INFO flower 2022-07-12 17:13:01,172 | server.py:327 | Requesting initial parameters from one random client
INFO flower 2022-07-12 17:14:20,528 | server.py:330 | Received initial parameters from one random client
INFO flower 2022-07-12 17:14:20,531 | server.py:130 | Evaluating initial parameters
INFO flower 2022-07-12 17:14:20,534 | server.py:143 | FL starting
DEBUG flower 2022-07-12 17:14:22,650 | server.py:269 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-07-12 17:14:30,809 | server.py:281 | fit_round received 2 results and 0 failures
DEBUG flower 2022-07-12 17:14:30,840 | server.py:215 | evaluate_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-07-12 17:14:36,786 | server.py:227 | evaluate_round received 2 results and 0 failures
DEBUG flower 2022-07-12 17:14:36,789 

History (loss, distributed):
	round 1: 2.069756031036377
	round 2: 1.668516755104065
	round 3: 1.5030343532562256
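The log and the History above follow a fixed per-round structure: the server sends the global weights to the sampled clients for `fit`, aggregates the returned weights, then sends the new global weights out for `evaluate` and records the example-weighted average of the client losses as the distributed loss. The toy simulation below sketches that loop under stated assumptions: no Flower or PyTorch is involved, each "model" is a single float, and each hypothetical `SimClient` moves the model halfway toward its own local target in one local epoch.

```python
from typing import List, Tuple


def weighted_avg(pairs: List[Tuple[float, int]]) -> float:
    """Average values weighted by their example counts."""
    total = sum(n for _, n in pairs)
    return sum(v * n for v, n in pairs) / total


class SimClient:
    """Hypothetical client whose 'model' is one float pulled toward a local target."""

    def __init__(self, target: float, num_examples: int) -> None:
        self.target = target
        self.num_examples = num_examples

    def fit(self, w: float) -> Tuple[float, int]:
        # One "local epoch": move halfway from the global weight toward the target.
        return w + 0.5 * (self.target - w), self.num_examples

    def evaluate(self, w: float) -> Tuple[float, int]:
        # Squared distance to the local target stands in for the test loss.
        return (w - self.target) ** 2, self.num_examples


def run(clients: List[SimClient], num_rounds: int, w: float = 0.0):
    history = []
    for rnd in range(1, num_rounds + 1):
        w = weighted_avg([c.fit(w) for c in clients])          # fit_round + aggregation
        loss = weighted_avg([c.evaluate(w) for c in clients])  # evaluate_round
        history.append((rnd, loss))
    return w, history


w, history = run([SimClient(1.0, 1), SimClient(3.0, 1)], num_rounds=3)
for rnd, loss in history:
    print(f"round {rnd}: {loss}")
# round 1: 2.0
# round 2: 1.25
# round 3: 1.0625
```

As in the real run, the distributed loss shrinks round by round; here the global weight converges toward 2.0, the example-weighted mean of the two local targets.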

# Client

### If the current PC is a client, run the following cells to start the client.

In [1]:
import warnings
from collections import OrderedDict

import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm

warnings.filterwarnings("ignore", category=UserWarning)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Target Server's IP

In [2]:
TARGET_SERVER_IP = 'your_own_IP'  # replace with the server's IP address

## Regular PyTorch pipeline: nn.Module, train, test, and DataLoader

### Sample Network

This is a simple image classification CNN.

In [3]:
class Net(nn.Module):
    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
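As a quick check of the `16 * 5 * 5` input size of `fc1`, the spatial size can be traced through the layers with the standard output-size formula for unpadded convolution and pooling, assuming 32x32 CIFAR-10 inputs. The helpers `conv_out` and `pool_out` are illustrative, not PyTorch functions.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output width/height of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size: int, kernel: int = 2, stride: int = 2) -> int:
    """Output width/height of an unpadded max-pool, as in nn.MaxPool2d(2, 2)."""
    return (size - kernel) // stride + 1


size = 32                            # CIFAR-10 images are 32x32
size = pool_out(conv_out(size, 5))   # conv1 (5x5): 32 -> 28, pool: 28 -> 14
size = pool_out(conv_out(size, 5))   # conv2 (5x5): 14 -> 10, pool: 10 -> 5
flat_features = 16 * size * size     # conv2 has 16 output channels
print(size, flat_features)  # 5 400
```

If the input resolution changed, `x.view(-1, 16 * 5 * 5)` and `fc1` would both need updating to the new flattened size.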

### Training, testing, and data-loading functions

In [4]:
def train(net, trainloader, epochs):
    """Train the model on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    for _ in range(epochs):
        for images, labels in tqdm(trainloader):
            optimizer.zero_grad()
            criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
            optimizer.step()


def test(net, testloader):
    """Validate the model on the test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    with torch.no_grad():
        for images, labels in tqdm(testloader):
            outputs = net(images.to(DEVICE))
            labels = labels.to(DEVICE)
            loss += criterion(outputs, labels).item()
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
    return loss / len(testloader.dataset), correct / total


def load_data():
    """Load CIFAR-10 (training and test set)."""
    trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = CIFAR10("./data", train=True, download=True, transform=trf)
    testset = CIFAR10("./data", train=False, download=True, transform=trf)
    return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset)
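`test` counts a prediction as correct when the index of the largest output (`torch.max(outputs.data, 1)[1]`, i.e. the argmax) equals the label. Because the test `DataLoader` uses the default batch size of 1, dividing the summed per-batch loss by the dataset size gives the mean per-example cross-entropy. The stdlib-only sketch below reproduces both metrics on plain lists of logits; `cross_entropy` and `evaluate_logits` are hypothetical helpers, not PyTorch functions.

```python
import math
from typing import List, Tuple


def cross_entropy(logits: List[float], label: int) -> float:
    """-log softmax(logits)[label], computed via a max-shifted log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def evaluate_logits(all_logits: List[List[float]], labels: List[int]) -> Tuple[float, float]:
    """Return (mean per-example loss, accuracy), mirroring test() with batch_size=1."""
    loss, correct = 0.0, 0
    for logits, label in zip(all_logits, labels):
        loss += cross_entropy(logits, label)
        pred = max(range(len(logits)), key=lambda k: logits[k])  # argmax
        correct += int(pred == label)
    return loss / len(labels), correct / len(labels)


# Two examples: the first is classified correctly, the second is not.
loss, acc = evaluate_logits([[2.0, 0.0], [0.0, 1.0]], labels=[0, 0])
print(loss, acc)
```

With a larger test batch size, `nn.CrossEntropyLoss` would already average within each batch, so dividing by the dataset size would understate the loss; dividing by the number of batches would be the matching normalization in that case.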

## Federation of the pipeline with Flower

In [5]:
net = Net().to(DEVICE)
trainloader, testloader = load_data()

Files already downloaded and verified
Files already downloaded and verified
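The progress bars in the client output follow from the loaders built here: 50,000 CIFAR-10 training images at `batch_size=32` give ceil(50000 / 32) = 1563 batches per epoch (the last batch is partial because `drop_last` defaults to `False`), and the 10,000 test images at the default batch size of 1 give 10,000 iterations. The sketch also shows what `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` does to each channel value produced by `ToTensor()`, which lies in [0, 1]; the `normalize` helper is illustrative.

```python
import math

train_size, test_size = 50_000, 10_000       # CIFAR-10 split sizes
train_batches = math.ceil(train_size / 32)   # batch_size=32, last batch is partial
test_batches = math.ceil(test_size / 1)      # DataLoader default batch_size=1
print(train_batches, test_batches)  # 1563 10000


def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Per-channel transform applied by Normalize: (x - mean) / std."""
    return (x - mean) / std


# ToTensor() yields values in [0, 1]; mean=std=0.5 maps them to [-1, 1].
print(normalize(0.0), normalize(0.5), normalize(1.0))  # -1.0 0.0 1.0
```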


### Define Flower client

#### get_parameters

1. return the model weights as a list of NumPy ndarrays

#### set_parameters

1. update the local model weights with the parameters received from the server

#### fit

1. set the local model weights
2. train the local model
3. return the updated local model weights, together with the number of training examples, to the server

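#### evaluate

1. set the local model weights
2. test the local model on the local test set and return the loss, the number of test examples, and the accuracy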
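The four methods above can be sketched without Flower or PyTorch to show what flows between client and server, assuming a model whose state is an ordered mapping from layer names to values (as `net.state_dict()` is). `ToyClient` is hypothetical. The key point is that only the values travel, so `set_parameters` re-attaches them to names by position, which works only because the key order is the same every time.

```python
from collections import OrderedDict
from typing import Dict, List, Tuple


class ToyClient:
    """Hypothetical stand-in for a NumPyClient; 'tensors' are plain float lists."""

    def __init__(self) -> None:
        # Mimics net.state_dict(): an ordered mapping from layer name to values.
        self.state = OrderedDict([("fc.weight", [0.0, 0.0]), ("fc.bias", [0.0])])

    def get_parameters(self) -> List[List[float]]:
        # Values only, in key order; the layer names are not sent.
        return [list(v) for v in self.state.values()]

    def set_parameters(self, parameters: List[List[float]]) -> None:
        # Re-attach names by position, relying on the same key order as get_parameters.
        self.state = OrderedDict(zip(self.state.keys(), parameters))

    def fit(self, parameters, config) -> Tuple[List[List[float]], int, Dict]:
        self.set_parameters(parameters)
        # Placeholder "training": add 1.0 to every value (stands in for train()).
        for name in self.state:
            self.state[name] = [v + 1.0 for v in self.state[name]]
        return self.get_parameters(), 2, {}  # 2 stands in for len(trainloader.dataset)

    def evaluate(self, parameters, config) -> Tuple[float, int, Dict]:
        self.set_parameters(parameters)
        # Placeholder "loss": sum of all values (stands in for test()).
        loss = sum(sum(v) for v in self.state.values())
        return loss, 2, {}


client = ToyClient()
params, num_examples, metrics = client.fit([[1.0, 2.0], [3.0]], config={})
print(params, num_examples, metrics)  # [[2.0, 3.0], [4.0]] 2 {}
```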

In [6]:
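class FlowerClient(fl.client.NumPyClient):
    # Pre-1.0 Flower signature; Flower 1.x calls get_parameters(self, config)
    def get_parameters(self):
        # Copy each state_dict tensor to the CPU and convert it to a NumPy array
        return [val.cpu().numpy() for _, val in net.state_dict().items()]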

    def set_parameters(self, parameters):
        params_dict = zip(net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train(net, trainloader, epochs=1)
        return self.get_parameters(), len(trainloader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy = test(net, testloader)
        return loss, len(testloader.dataset), {"accuracy": accuracy}

### Start Flower client

In [7]:
fl.client.start_numpy_client(TARGET_SERVER_IP+":8080", client=FlowerClient())

INFO flower 2022-07-12 17:18:12,639 | connection.py:102 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flower 2022-07-12 17:18:12,643 | connection.py:39 | ChannelConnectivity.IDLE
DEBUG flower 2022-07-12 17:18:12,646 | connection.py:39 | ChannelConnectivity.READY
100%|██████████| 1563/1563 [00:10<00:00, 154.38it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1365.73it/s]
100%|██████████| 1563/1563 [00:09<00:00, 157.43it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1266.81it/s]
100%|██████████| 1563/1563 [00:09<00:00, 157.21it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1329.18it/s]
DEBUG flower 2022-07-12 17:19:10,270 | connection.py:121 | gRPC channel closed
INFO flower 2022-07-12 17:19:10,271 | app.py:101 | Disconnect and shut down
