# FLEXible tutorial: CIFAR10 classification using PyTorch

FLEXible is a library for federating models. We offer tools to load data and federate it, or to load already-federated data, as well as tools to create a federated environment. The user defines the model and the *communication primitives* used to train it in a federated environment, and we also offer decorators so that advanced users can implement their own federated workflow. Our Python decorators cover the following federated learning steps:
- initialization: initialize the model on the server.
- deploy model: send the model to the clients.
- training: define the training function.
- collect the weights: gather the clients' model parameters so they can be aggregated later.
- aggregate the weights: combine the collected weights with an aggregation method.
- deploy model: send the model with the updated weights back to the clients.
- evaluate: define the evaluation function.

In this notebook, we show how to use the predefined primitive functions, leaving the user to implement a few key pieces:
- The model to train: the server and the clients need to know which model will be trained.
- The aggregation method: in this notebook we use FedAvg, available in FLEXible as `fed_avg`, as the aggregation function.

Note that the primitive functions we offer are basic: they assume a particular structure for the federated training loop. For a more customizable training loop, see the notebook "Federated MNIST PT example with flexible decorators", which shows how to implement the primitive functions from scratch.

In [None]:
import torch
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
device

In [None]:
from flex.data import Dataset, FedDatasetConfig, FedDataDistribution
from torchvision import datasets, transforms

cifar_transforms = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

train_data = datasets.CIFAR10(
        root=".",
        train=True,
        download=True,
        transform=None,  # No transforms here: they are applied later, inside the training function
)

test_data = datasets.CIFAR10(
        root=".",
        train=False,
        download=True,
        transform=cifar_transforms
)


config = FedDatasetConfig(seed=0)
config.replacement = False
config.n_clients = 100

flex_dataset = FedDataDistribution.from_config(
                        centralized_data=Dataset.from_torchvision_dataset(train_data), 
                        config=config
                )
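The configuration above asks for 100 clients built without replacement, so no training image is shared between two clients. As a rough mental model, assuming an even IID split (the actual partitioning logic lives inside `FedDataDistribution` and may differ in detail), it behaves like shuffling the sample indices with the seed and cutting them into disjoint shards. `iid_partition` is a hypothetical helper written only for illustration.

```python
import random

def iid_partition(n_samples, n_clients, seed=0):
    """Hypothetical helper: split sample indices into disjoint, near-equal shards.

    Sampling without replacement means each index lands in exactly one shard,
    which mirrors the intent of `config.replacement = False`.
    """
    rng = random.Random(seed)
    indices = list(range(n_samples))
    rng.shuffle(indices)
    # Shard i takes every n_clients-th shuffled index starting at position i,
    # so shard sizes differ by at most one
    return {client_id: indices[client_id::n_clients] for client_id in range(n_clients)}

# CIFAR10 has 50,000 training images, so 100 clients receive 500 images each
shards = iid_partition(50_000, 100, seed=0)
print(len(shards), len(shards[0]))  # 100 500
```

Fixing the seed makes the split reproducible, which is the same reason the notebook passes `seed=0` to `FedDatasetConfig`.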

`@init_server_model` is a decorator that initializes the server model in a client-server architecture. The decorated function needs no specific arguments; it builds and returns the `FlexModel` that will live on the server.

In [None]:
import torch.nn as nn
import torch.nn.functional as F

from flex.pool import init_server_model
from flex.pool import FlexPool
from flex.model import FlexModel

from torchvision.models import resnet18

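def get_model(num_classes=10):
    # ResNet-18 with torchvision's default pretrained (ImageNet) weights
    resnet_model = resnet18(weights='DEFAULT')
    # Freeze the backbone: only the new classification head will be trained
    for p in resnet_model.parameters():
        p.requires_grad = False
    # Replace the ImageNet head with a fresh, trainable layer for CIFAR10's classes
    resnet_model.fc = torch.nn.Linear(resnet_model.fc.in_features, num_classes)
    return resnet_model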

@init_server_model
def build_server_model():
    server_flex_model = FlexModel()

    server_flex_model["model"] = get_model().to(device)
    # Required to store this for later stages of the FL training process
    server_flex_model["criterion"] = torch.nn.CrossEntropyLoss()
    server_flex_model["optimizer_func"] = torch.optim.Adam
    server_flex_model["optimizer_kwargs"] = {}

    return server_flex_model

flex_pool = FlexPool.client_server_architecture(flex_dataset, init_func=build_server_model)

clients = flex_pool.clients
servers = flex_pool.servers
aggregators = flex_pool.aggregators

print(f"Number of nodes in the pool {len(flex_pool)}: {len(servers)} server plus {len(clients)} clients. The server is also an aggregator")
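To see what an initialization decorator has to do, here is a minimal sketch that stands in a plain `dict` for the server model (`FlexModel` is dictionary-like, as the `server_flex_model["model"] = ...` assignments show). `init_server_model_sketch` is hypothetical and is not FLEXible's implementation: it assumes the decorator only needs to call the user's builder and check that a model container came back.

```python
import functools

def init_server_model_sketch(build_fn):
    """Hypothetical stand-in for @init_server_model (not FLEXible's code)."""
    @functools.wraps(build_fn)
    def wrapper(*args, **kwargs):
        # Whatever arguments the caller passes are ignored here, so the
        # user's builder is free to take no parameters at all.
        model = build_fn()
        if not isinstance(model, dict):
            raise TypeError("the builder must return a dict-like model container")
        return model
    return wrapper

@init_server_model_sketch
def build_toy_server_model():
    # Plain values stand in for the network weights and training settings
    return {"model": [0.0, 0.0], "optimizer_kwargs": {"lr": 0.1}}

toy_server_model = build_toy_server_model("ignored-argument")
print(toy_server_model["model"])  # [0.0, 0.0]
```

Using `functools.wraps` keeps the builder's name and docstring, which makes decorated functions easier to identify when debugging.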

FLEXible also lets us select a subsample of the clients to take part in each training round.

In [None]:
#Select clients
clients_per_round=2
selected_clients_pool = clients.select(clients_per_round)
selected_clients = selected_clients_pool.clients

print(f"Server node is identified by key \"{servers.actor_ids[0]}\"")
print(f"Selected {len(selected_clients.actor_ids)} client nodes of a total of {len(clients.actor_ids)}")
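A minimal sketch of per-round client sampling, assuming uniform sampling without replacement (we have not confirmed exactly how `FlexPool.select` draws clients when given an integer). `select_clients` is a hypothetical helper; a fresh draw each round means different clients contribute over time.

```python
import random

def select_clients(client_ids, k, rng):
    """Hypothetical helper: draw k distinct client ids uniformly at random."""
    client_ids = list(client_ids)
    if k > len(client_ids):
        raise ValueError("cannot select more clients than exist")
    return rng.sample(client_ids, k)

rng = random.Random(0)
all_client_ids = range(100)
for round_idx in range(3):
    chosen = select_clients(all_client_ids, 2, rng)
    print(f"round {round_idx}: {chosen}")
```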

`@deploy_server_model` is a decorator that copies the model from the server to the clients at each federated learning round. The decorated function must take at least one argument: the `FlexModel` that stores the server's model. Here we use `deploy_server_model_pt`, a ready-made PyTorch primitive from `flex.pool`.

In [None]:
from flex.pool import deploy_server_model_pt

servers.map(deploy_server_model_pt, selected_clients)
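Conceptually, deploying gives every selected client its own independent copy of the server's model, so local training cannot modify the server's weights in place. The sketch below uses plain dicts in place of `FlexModel` objects; `deploy_sketch` is hypothetical and only assumes that the copy must be deep.

```python
from copy import deepcopy

def deploy_sketch(server_model, client_models):
    """Hypothetical stand-in for deploy_server_model_pt: overwrite each client's model."""
    for client_id in client_models:
        # A deep copy keeps client-side updates from leaking back into the server model
        client_models[client_id] = deepcopy(server_model)
    return client_models

toy_server_model = {"model": [0.5, -1.0]}
toy_clients = deploy_sketch(toy_server_model, {"c0": {}, "c1": {}})
toy_clients["c0"]["model"][0] = 99.0   # simulate local training on client c0
print(toy_server_model["model"])       # [0.5, -1.0]
print(toy_clients["c1"]["model"])      # [0.5, -1.0]
```

With a shallow copy, all clients and the server would share the same inner list, and client `c0`'s update would silently change everyone's weights.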

There is no decorator for the training step, because it can be implemented directly: `map` calls the training function once per client, passing that client's `FlexModel` and its local `Dataset`.

In [None]:
from flex.data import Dataset
from torch.utils.data import DataLoader
from copy import deepcopy
from tqdm import tqdm

def train(client_flex_model: FlexModel, client_data: Dataset):
    train_dataset = client_data.to_torchvision_dataset(transform=cifar_transforms)
    client_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)  # reshuffle the local data every epoch
    model = client_flex_model["model"]
    model = model.to(device)
    client_flex_model['previous_model'] = deepcopy(model) # Required to use `collect_client_diff_weights_pt` primitive
    optimizer = client_flex_model['optimizer_func'](model.parameters(), **client_flex_model["optimizer_kwargs"])
    model = model.train()
    criterion = client_flex_model["criterion"]
    epochs = 5
    for _ in tqdm(range(epochs)):
        for imgs, labels in client_dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            pred = model(imgs)
            loss = criterion(pred, labels)
            loss.backward()
            optimizer.step()

In [None]:
selected_clients.map(train)
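Because `map` hands each client's model and data to the function, a plain function with that signature is all the training step needs. The sketch below imitates that calling convention with a hypothetical `ToyClientPool` and a one-parameter least-squares model trained by gradient descent. It assumes only the `(model, data)` argument order visible in `train(client_flex_model, client_data)`; none of these names are FLEXible APIs.

```python
class ToyClientPool:
    """Hypothetical stand-in for a pool of clients: id -> (model dict, data list)."""
    def __init__(self, models, datasets):
        self.models = models
        self.datasets = datasets

    def map(self, func, **kwargs):
        # Call func(model, data) once per client and collect the return values
        return [func(self.models[cid], self.datasets[cid], **kwargs) for cid in self.models]

def toy_train(model, data, lr=0.1, epochs=50):
    """Fit y ~ w * x by gradient descent on the mean squared error; mutates model in place."""
    w = model["w"]
    for _ in range(epochs):
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    model["w"] = w

toy_pool = ToyClientPool(
    models={"a": {"w": 0.0}, "b": {"w": 0.0}},
    datasets={"a": [(1.0, 2.0), (2.0, 4.0)], "b": [(1.0, 3.0), (3.0, 9.0)]},
)
toy_pool.map(toy_train)
print({cid: round(m["w"], 3) for cid, m in toy_pool.models.items()})  # {'a': 2.0, 'b': 3.0}
```

Like `train`, `toy_train` returns nothing and updates the model it receives in place; each client ends up fitting its own local data.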

As its name says, `collect_client_diff_weights_pt` collects weights from a set of clients. Specifically, it collects the difference between each client's model before and after training, that is, what the model learned during its local training step. Note that the weights from before training are assumed to be stored under the `previous_model` key of the client's `FlexModel`, which is why `train` saves a copy there.

In [None]:
from flex.pool import collect_client_diff_weights_pt

aggregators.map(collect_client_diff_weights_pt, selected_clients)
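In plain numbers, the quantity collected from each client is its weights after local training minus the weights it received. A sketch over flat lists of floats, assuming a model can be flattened into a list of parameters (the real primitive works on PyTorch tensors and reads the "before" weights from `previous_model`); `weight_diff` is a hypothetical helper.

```python
def weight_diff(previous, trained):
    """Hypothetical helper: element-wise trained - previous, i.e. the local update."""
    if len(previous) != len(trained):
        raise ValueError("models must have the same number of parameters")
    return [after - before for before, after in zip(previous, trained)]

previous_model = [1.0, 2.0, 3.0]   # weights deployed by the server
trained_model = [1.5, 1.0, 3.0]    # weights after local training
print(weight_diff(previous_model, trained_model))  # [0.5, -1.0, 0.0]
```

A zero entry means local training left that parameter unchanged, which is what happens to the frozen ResNet backbone in this notebook.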

`fed_avg` implements the Federated Averaging (FedAvg) aggregator, which computes the mean of the weights collected in the previous step.

In [None]:
from flex.pool import fed_avg

aggregators.map(fed_avg)
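The core of FedAvg is an element-wise mean over clients. A minimal sketch, assuming each client's contribution is a flat list of floats. FedAvg as originally proposed weights each client by its number of local samples; the optional `sample_counts` argument shows that variant, and when all shards have the same size (as an even split would give) both choices produce the same result. `fed_avg_sketch` is hypothetical and not FLEXible's `fed_avg`.

```python
def fed_avg_sketch(client_weights, sample_counts=None):
    """Hypothetical FedAvg: (optionally sample-weighted) element-wise mean of client vectors."""
    if not client_weights:
        raise ValueError("need at least one client")
    if sample_counts is None:
        sample_counts = [1] * len(client_weights)  # every client counts equally
    total = sum(sample_counts)
    n_params = len(client_weights[0])
    return [
        sum(n * w[i] for n, w in zip(sample_counts, client_weights)) / total
        for i in range(n_params)
    ]

toy_diffs = [[0.5, -1.0, 0.0], [1.5, 1.0, 2.0]]
print(fed_avg_sketch(toy_diffs))               # [1.0, 0.0, 1.0]
print(fed_avg_sketch(toy_diffs, [300, 100]))   # [0.75, -0.5, 0.5]
```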

`set_aggregated_diff_weights_pt` adds the aggregated weights to the server's weights. It assumes that the aggregated weights are differences, collected with logic similar to that of `collect_client_diff_weights_pt`.

In [None]:
from flex.pool import set_aggregated_diff_weights_pt

aggregators.map(set_aggregated_diff_weights_pt, servers)
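Adding the averaged update to the server's weights closes the round. Assuming flat lists of floats, equal client weights, and that every selected client started from the same server weights, the result is exactly the average of the clients' trained models, which is why averaging differences and averaging weights are interchangeable in that setting. `apply_update` is a hypothetical helper.

```python
def apply_update(server_weights, aggregated_diff):
    """Hypothetical helper: new server weights = old weights + averaged client update."""
    return [w + d for w, d in zip(server_weights, aggregated_diff)]

toy_server = [1.0, 2.0]
trained = [[2.0, 2.0], [0.0, 4.0]]   # two clients, both started from toy_server
diffs = [[t - s for t, s in zip(client, toy_server)] for client in trained]
mean_diff = [sum(col) / len(col) for col in zip(*diffs)]
new_server = apply_update(toy_server, mean_diff)

# Sanity check: same result as averaging the trained models directly
mean_of_models = [sum(col) / len(col) for col in zip(*trained)]
print(new_server, mean_of_models)  # [1.0, 3.0] [1.0, 3.0]
```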

`@evaluate_server_model` is a decorator for testing the server model. The decorated function must take at least one argument, the server's `FlexModel`; extra keyword arguments, such as `test_data` below, are passed in through `map`.

In [None]:
from flex.pool import evaluate_server_model

@evaluate_server_model
def evaluate_global_model(server_flex_model: FlexModel, test_data=None):
    model = server_flex_model["model"]
    model = model.to(device)
    model.eval()
    criterion = server_flex_model['criterion']
    # test_data is already a torchvision dataset with cifar_transforms applied
    test_dataloader = DataLoader(test_data, batch_size=256, shuffle=False, pin_memory=False)
    losses = []
    test_acc = 0
    total_count = 0
    with torch.no_grad():
        for data, target in tqdm(test_dataloader):
            total_count += target.size(0)
            data, target = data.to(device), target.to(device)
            output = model(data)
            losses.append(criterion(output, target).item())
            # Predicted class = index of the highest logit
            pred = output.argmax(dim=1)
            test_acc += pred.eq(target).sum().item()

    test_loss = sum(losses) / len(losses)  # mean of the per-batch mean losses
    test_acc /= total_count
    return test_loss, test_acc

In [None]:

metrics = servers.map(evaluate_global_model, test_data=test_data)
print(metrics[0])
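The two metrics returned by `evaluate_global_model` come down to simple arithmetic. A sketch on plain lists, assuming per-class scores are already computed: accuracy is the fraction of argmax predictions that match the label. Averaging per-batch mean losses, as `evaluate_global_model` does, gives the smaller final batch (16 of CIFAR10's 10,000 test images with `batch_size=256`) the same weight as a full batch; weighting each batch by its size yields the exact per-sample mean. `accuracy` and `mean_loss` are hypothetical helpers.

```python
def accuracy(scores, labels):
    """Hypothetical helper: fraction of rows whose highest score sits at the true label."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in scores]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)

def mean_loss(batch_losses, batch_sizes):
    """Hypothetical helper: per-sample mean loss from per-batch mean losses."""
    return sum(loss * n for loss, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)

scores = [[0.1, 2.0, -1.0], [3.0, 0.0, 0.5], [0.2, 0.1, 0.9], [1.0, 1.5, 0.0]]
labels = [1, 0, 0, 1]
print(accuracy(scores, labels))         # 0.75
print(mean_loss([0.5, 1.0], [3, 1]))    # 0.625
```

With equal batch sizes `mean_loss` reduces to the plain average, so the difference only matters for the last, partial batch.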

### Run the federated learning experiment for a few rounds

Now, we can summarize the steps provided above and run the federated experiment for multiple rounds:

In [None]:
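# Auxiliary function to free unused GPU memory on the clients
def clean_up_models(client_model: FlexModel, _):
    import gc
    client_model.clear()  # drop references to the client's model and training objects
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release cached, unoccupied blocks held by PyTorch's allocator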

def train_n_rounds(n_rounds, clients_per_round=20):  
    pool = FlexPool.client_server_architecture(fed_dataset=flex_dataset, init_func=build_server_model)
    for i in range(n_rounds):
        print(f"\nRunning round: {i+1} of {n_rounds}")
        selected_clients_pool = pool.clients.select(clients_per_round)
        selected_clients = selected_clients_pool.clients
        print(f"Selected clients for this round: {len(selected_clients)}")
        # Deploy the server model to the selected clients
        pool.servers.map(deploy_server_model_pt, selected_clients)
        # Each selected client trains her model
        selected_clients.map(train)
        # The aggregator collects weights from the selected clients and aggregates them
        pool.aggregators.map(collect_client_diff_weights_pt, selected_clients)
        pool.aggregators.map(fed_avg)
        # The aggregator sends its aggregated weights to the server
        pool.aggregators.map(set_aggregated_diff_weights_pt, pool.servers)
        # Optional: evaluate the server model
        metrics = pool.servers.map(evaluate_global_model, test_data=test_data)
        # Optional: clean-up unused memory
        selected_clients.map(clean_up_models)
        loss, acc = metrics[0]
        print(f"Server: Test acc: {acc:.4f}, test loss: {loss:.4f}")

In [None]:
train_n_rounds(20, clients_per_round=10)
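The round structure of `train_n_rounds` can be traced end to end on a toy problem small enough to check by hand. This is a pure-Python sketch, not FLEXible code: the "model" is a single weight `w` for `y ~ w * x`, every client's data comes from `y = 3x` without noise, and `local_train` and `run_toy_fedavg` are hypothetical names. The comments map each line to the corresponding step above.

```python
import random
from copy import deepcopy

def local_train(model, data, lr=0.1, epochs=5):
    """Gradient descent on the mean squared error for y ~ w * x (mutates model)."""
    for _ in range(epochs):
        grad = sum(2 * (model["w"] * x - y) * x for x, y in data) / len(data)
        model["w"] -= lr * grad

def run_toy_fedavg(client_data, n_rounds, clients_per_round, seed=0):
    rng = random.Random(seed)
    server = {"w": 0.0}                                            # initialization
    for _ in range(n_rounds):
        chosen = rng.sample(sorted(client_data), clients_per_round)  # client selection
        diffs = []
        for cid in chosen:
            local = deepcopy(server)                               # deploy
            local_train(local, client_data[cid])                   # local training
            diffs.append(local["w"] - server["w"])                 # collect the difference
        server["w"] += sum(diffs) / len(diffs)                     # FedAvg + set aggregated diff
    return server["w"]

# Noise-free data from y = 3x; each client sees a different range of x
toy_data = {cid: [(x, 3.0 * x) for x in (0.5 + cid * 0.1, 1.0 + cid * 0.1)] for cid in range(10)}
print(round(run_toy_fedavg(toy_data, n_rounds=20, clients_per_round=3), 3))  # 3.0
```

Because all clients share the same optimum here, the global weight converges to 3 regardless of which clients are sampled; with heterogeneous client data, the outcome of FedAvg depends on the selection and the amount of local training.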