# Federated Learning example

### This notebook is adapted from the [official Flower tutorial](https://flower.dev/docs/quickstart-pytorch.html).

In this notebook we will learn how to train a HANet model with federated learning using Flower and PyTorch.
We will use 3 clients.

# Server

### If this PC is to be the server, run the cells below to start the server.

In [1]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import flwr as fl
import numpy as np
import torch
from flwr.common import Metrics

from handover_grasping.model import HANet

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


### Define strategy

The strategy decides how the model weights delivered by the clients are combined. The following example uses federated averaging (FedAvg) and additionally saves the aggregated model after every round.

In [2]:
# Server-side copy of the model. It is only used to turn the aggregated weight
# arrays back into a PyTorch state_dict so they can be checkpointed.
# NOTE(review): the client cell builds `HANet(4)`; confirm `HANet()` yields the
# same architecture, otherwise `load_state_dict(strict=True)` below will fail.
net = HANet()


class SaveModelAndMetricsStrategy(fl.server.strategy.FedAvg):
    """FedAvg strategy that also checkpoints the aggregated model every round.

    After the standard FedAvg aggregation, the aggregated weights are loaded
    into the module-level ``net`` and written to ``model_round_{rnd}.pth`` in
    the current working directory.
    """

    def aggregate_fit(
        self,
        rnd: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],  # FitRes is like EvaluateRes and has a metrics key
        failures: List[BaseException],
    ) -> Tuple[Optional[fl.common.Parameters], Dict[str, fl.common.Scalar]]:
        """Aggregate model weights using weighted average and store a checkpoint.

        Args:
            rnd: current federated round number (used in the checkpoint name).
            results: successful ``(client, FitRes)`` pairs from this round.
            failures: exceptions raised by clients that failed this round.

        Returns:
            The ``(parameters, metrics)`` tuple produced by
            ``FedAvg.aggregate_fit``, unchanged. When ``parameters`` is
            ``None`` no checkpoint is written.
        """
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(rnd, results, failures)

        if aggregated_parameters is not None:
            print(f"Saving round {rnd} aggregated_parameters...")
            # Convert `Parameters` to `List[np.ndarray]`
            aggregated_weights: List[np.ndarray] = fl.common.parameters_to_weights(aggregated_parameters)

            # Convert `List[np.ndarray]` to a PyTorch `state_dict`; the arrays are
            # matched to keys by position, i.e. in `net.state_dict()` order.
            params_dict = zip(net.state_dict().keys(), aggregated_weights)
            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
            net.load_state_dict(state_dict, strict=True)
            torch.save(net.state_dict(), f"model_round_{rnd}.pth")

        return aggregated_parameters, aggregated_metrics


# Create the strategy: it requires 3 available clients and fits on at least 3
# clients per round. The server itself is started in a separate cell.
strategy = SaveModelAndMetricsStrategy(min_fit_clients=3, min_available_clients=3)

### Server's IP

In [3]:
# Address the Flower server listens on. Set the FL_SERVER_IP environment
# variable to override it without editing the notebook.
import os

SERVER_IP = os.environ.get("FL_SERVER_IP", 'your_own_IP')

### Start Flower server

In [None]:
ROUNDS = 10  # number of federated learning rounds

# Launch the Flower server on port 8080 with the checkpointing strategy.
fl.server.start_server(
    server_address=f"{SERVER_IP}:8080",
    config=dict(num_rounds=ROUNDS),
    strategy=strategy,
)

INFO flower 2022-07-20 15:07:21,001 | app.py:109 | Flower server running (10 rounds)
SSL is disabled
INFO - 2022-07-20 15:07:21,001 - app - Flower server running (10 rounds)
SSL is disabled
INFO flower 2022-07-20 15:07:21,005 | server.py:128 | Initializing global parameters
INFO - 2022-07-20 15:07:21,005 - server - Initializing global parameters
INFO flower 2022-07-20 15:07:21,006 | server.py:327 | Requesting initial parameters from one random client
INFO - 2022-07-20 15:07:21,006 - server - Requesting initial parameters from one random client
INFO flower 2022-07-20 15:08:11,729 | server.py:330 | Received initial parameters from one random client
INFO - 2022-07-20 15:08:11,729 - server - Received initial parameters from one random client
INFO flower 2022-07-20 15:08:11,730 | server.py:130 | Evaluating initial parameters
INFO - 2022-07-20 15:08:11,730 - server - Evaluating initial parameters
INFO flower 2022-07-20 15:08:11,732 | server.py:143 | FL starting
INFO - 2022-07-20 15:08:11,732

# Client

### If this PC is to be a client, run the cells below to start the client.

In [None]:
from handover_grasping import HANet
import warnings
from collections import OrderedDict

import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

warnings.filterwarnings("ignore", category=UserWarning)

### Target Server's IP

In [None]:
# Address of the Flower server this client connects to. Set the FL_SERVER_IP
# environment variable to override it without editing the notebook.
import os

TARGET_SERVER_IP = os.environ.get("FL_SERVER_IP", 'your_own_IP')

### Training, testing and data-loading functions

In [None]:
def train(net, trainloader, epochs, lr=1e-3):
    """Train the model on the training set (updates ``net`` in place).

    Each batch must be a dict with ``'color'``, ``'depth'`` and ``'label'``
    tensors. The label is reordered with ``permute(0, 2, 3, 1)`` (dim 1 moved
    last) so it lines up with the network output -- presumably channels-first
    to channels-last; TODO confirm against HANet's output layout.

    Args:
        net: model called as ``net(color, depth)`` and returning logits.
            Inputs are moved to the device of its parameters.
        trainloader: iterable of batch dicts (e.g. a ``DataLoader``).
        epochs: number of passes over ``trainloader``.
        lr: Adam learning rate; defaults to the previously hard-coded 1e-3.
    """
    device = next(net.parameters()).device
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    # The model may have been left in eval mode by a previous evaluation;
    # dropout / batch-norm must be in training mode here.
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            optimizer.zero_grad()
            color = batch['color'].to(device)
            depth = batch['depth'].to(device)
            label = batch['label'].permute(0, 2, 3, 1).to(device).float()
            loss = criterion(net(color, depth), label)
            loss.backward()
            optimizer.step()


def test(net, testloader):
    """Validate the model on the test set."""
    criterion = torch.nn.BCEWithLogitsLoss()
    correct, total, loss = 0, 0, 0.0
    with torch.no_grad():
        for i_batch, sampled_batched in enumerate(testloader):
            color = sampled_batched['color'].cuda()
            depth = sampled_batched['depth'].cuda()
            labels = sampled_batched['label'].permute(0,2,3,1).cuda().float()
            
            outputs = net(color, depth)
            loss += criterion(outputs, labels).item()
            
    return loss / len(testloader.dataset)


def load_data(data_path='/home/arg/handover_grasping/data/HANet_training_datasets',
              batch_size=8, num_workers=8):
    """Build the client's train and test DataLoaders.

    Args:
        data_path: root of the HANet training datasets (defaults to the
            previously hard-coded location).
        batch_size: training batch size (test batch size stays 1).
        num_workers: worker processes for both loaders.

    Returns:
        ``(trainloader, testloader)``: the train loader uses split
        ``'train_split_1'`` and is shuffled; the test loader uses split
        ``'fl_test'`` and is not shuffled.
    """
    # NOTE(review): `handover_grasping_dataset` was never imported anywhere in
    # this notebook, so the original cell raised NameError on a fresh kernel.
    # The module path below is assumed -- confirm it matches the installed
    # handover_grasping package.
    from handover_grasping.dataset import handover_grasping_dataset

    dataset_train = handover_grasping_dataset(data_path, color_type='png', mode='train_split_1')
    dataset_test = handover_grasping_dataset(data_path, color_type='png', mode='fl_test')

    trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    testloader = DataLoader(dataset_test, batch_size=1, shuffle=False, num_workers=num_workers)
    return trainloader, testloader

### Initialize HANet and the data loaders

In [None]:
# Build the local HANet on the GPU, then create this client's data loaders.
net = HANet(4).cuda()
trainloader, testloader = load_data()

### Define Flower client

#### get_parameters

1. Return the model weights as a list of NumPy ndarrays.

#### set_parameters

1. Update the local model weights with the parameters received from the server.

#### fit

1. Set the local model weights.
2. Train the local model.
3. Return the updated local model weights.

#### evaluate
1. Set the local model weights.
2. Test the local model.

In [None]:
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates the module-level ``net``.

    Uses the globals ``net``, ``trainloader``, ``testloader``, ``train`` and
    ``test`` defined in earlier cells. Parameters are exchanged as the full
    ``net.state_dict()`` (including buffers), in key order.
    """

    def get_parameters(self):
        """Return every ``state_dict`` entry of ``net`` as a NumPy array, in key order."""
        return [val.cpu().numpy() for val in net.state_dict().values()]

    def set_parameters(self, parameters):
        """Load ``parameters`` (arrays ordered like ``net.state_dict()``) into ``net``.

        Raises:
            ValueError: if the number of arrays does not match the number of
                ``state_dict`` entries. ``zip`` would otherwise silently drop
                extra arrays.
        """
        keys = list(net.state_dict().keys())
        if len(parameters) != len(keys):
            raise ValueError(
                f"Expected {len(keys)} parameter arrays, got {len(parameters)}"
            )
        state_dict = OrderedDict((k, torch.tensor(v)) for k, v in zip(keys, parameters))
        net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Load the server weights, train locally and return the updated weights.

        The number of local epochs can be set by the server through
        ``config["epochs"]``; it defaults to 10 (the previous fixed value).
        """
        self.set_parameters(parameters)
        epochs = int((config or {}).get("epochs", 10))
        train(net, trainloader, epochs=epochs)
        return self.get_parameters(), len(trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load the server weights and return ``(loss, num_examples, metrics)``.

        The model is not evaluated for accuracy, so only the loss is reported
        as a metric. A constant placeholder accuracy would be misleading if
        aggregated.
        """
        self.set_parameters(parameters)
        loss = float(test(net, testloader))
        return loss, len(testloader.dataset), {"loss": loss}

### Start Flower client

In [None]:
# Connect this client to the Flower server on port 8080.
flower_client = FlowerClient()
fl.client.start_numpy_client(
    server_address=f"{TARGET_SERVER_IP}:8080",
    client=flower_client,
)