# Cifar-10 Example
In this notebook we train a basic convolutional neural network in a federated blockchain environment. This example assumes previous experience with PyTorch and with the flex framework for federated learning experiments.

In [None]:
import torch
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device  # displayed as the cell output in the notebook

First, we download the dataset and create the corresponding `FedDataset` for this experiment. We do not load a test dataset, since there is no dedicated server per se in this setup.

In [None]:
from flex.data import Dataset, FedDatasetConfig, FedDataDistribution
from torchvision import datasets, transforms

# Preprocessing applied at training time (`train` hands it to
# `to_torchvision_dataset`): resize every image to 224x224, convert it to a
# tensor and normalize each of the 3 channels with mean 0.5 and std 0.5.
cifar_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])

# CIFAR-10 training split, stored under the current directory (download=True).
# No transform is attached on purpose: the transforms are supplied later,
# during the training process.
train_data = datasets.CIFAR10(root=".", train=True, download=True, transform=None)

# Federated split of the centralized data into 100 nodes, with
# `replacement = False` and a fixed seed (0), presumably so the partition is
# reproducible across runs.
config = FedDatasetConfig(seed=0)
config.replacement = False
config.n_nodes = 100

flex_dataset = FedDataDistribution.from_config(
    config=config,
    centralized_data=Dataset.from_torchvision_dataset(train_data),
)

Here we create our pool as usual. Note that this time we use a `PoSBlockchainPool`, a pool that uses a Proof of Stake (PoS) blockchain to coordinate the training process. In this pool every client is also an aggregator. We opt for a PoS blockchain because of the low computational power needed to run it.

In [None]:
import torch.nn as nn
import torch.nn.functional as F

from flexBlock.pool import PoSBlockchainPool
from flex.pool import init_server_model
from flex.model import FlexModel



def get_model(num_classes=10, input_size=224):
    """Build a small LeNet-style CNN for square 3-channel images.

    Two stages of 3x3 convolution (no padding) + ReLU + 2x2 max-pooling,
    followed by a three-layer fully connected head.

    Args:
        num_classes: Number of output logits.
        input_size: Height and width of the (square) input images. Defaults to
            224 because ``cifar_transforms`` resizes CIFAR-10 images to
            224x224. Pass 32 to use native CIFAR-10 resolution.

    Returns:
        An ``nn.Sequential`` mapping ``(batch, 3, input_size, input_size)``
        tensors to ``(batch, num_classes)`` logits.

    Raises:
        ValueError: If ``input_size`` is too small to survive both stages.
    """
    # Each unpadded 3x3 conv trims 2 pixels and each 2x2 max-pool halves the
    # side (flooring): 224 -> 222 -> 111 -> 109 -> 54.
    side = ((input_size - 2) // 2 - 2) // 2
    if side < 1:
        raise ValueError(f"input_size={input_size} is too small for this network")
    flat_features = 16 * side * side  # 46656 for the default 224x224 input

    return nn.Sequential(
        nn.Conv2d(3, 6, 3),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(6, 16, 3),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Flatten(),
        nn.Linear(flat_features, 120),
        nn.ReLU(),
        nn.Linear(120, 84),
        nn.ReLU(),
        nn.Linear(84, num_classes),
    )


@init_server_model
def build_server_model():
    """Create the initial global FlexModel.

    Besides the network (moved to ``device``), the loss and an optimizer
    recipe (class + keyword arguments) are stored, because later stages of
    the FL training process (``train``) need them to build each client's
    optimizer.
    """
    flex_model = FlexModel()
    flex_model["model"] = get_model().to(device)

    training_setup = {
        "criterion": torch.nn.CrossEntropyLoss(),
        "optimizer_func": torch.optim.Adam,
        "optimizer_kwargs": {},
    }
    for key, value in training_setup.items():
        flex_model[key] = value

    return flex_model

# Blockchain-backed pool over the federated dataset; `build_server_model` is
# passed as the pool's initialization function.
flex_pool = PoSBlockchainPool(fed_dataset=flex_dataset, init_func=build_server_model)

# Role-specific views of the same pool.
clients, servers, aggregators = flex_pool.clients, flex_pool.servers, flex_pool.aggregators

print(f"Number of nodes in the pool {len(flex_pool)}")

We can select a subset of clients to take part in each round.

In [None]:
# Select the subset of clients that takes part in this round.
clients_per_round = 2
selected_clients_pool = clients.select(clients_per_round)
selected_clients = selected_clients_pool.clients

print(f'Server node is identified by key "{servers.actor_ids[0]}"')
print(
    f"Selected {len(selected_clients.actor_ids)} client nodes of a total of {len(clients.actor_ids)}"
)

In [None]:
from flex.pool import deploy_server_model_pt

# Deploy the server's model to every selected client using flex's
# `deploy_server_model_pt` primitive.
servers.map(deploy_server_model_pt, selected_clients)

We define the train loop manually.

In [None]:
from flex.data import Dataset
from torch.utils.data import DataLoader
from copy import deepcopy
from tqdm import tqdm


def train(
    client_flex_model: FlexModel,
    client_data: Dataset,
    *,
    epochs: int = 1,
    batch_size: int = 64,
):
    """Train one client's model on its local data (used via ``map``).

    Args:
        client_flex_model: The client's FlexModel. It must hold ``"model"``,
            ``"criterion"``, ``"optimizer_func"`` and ``"optimizer_kwargs"``
            (as set by ``build_server_model``). A deep copy of the model,
            taken before training, is stored under ``"previous_model"``.
        client_data: The client's local ``Dataset``. ``cifar_transforms`` is
            applied when it is converted to a torchvision dataset.
        epochs: Number of passes over the local data.
        batch_size: Mini-batch size.
    """
    train_dataset = client_data.to_torchvision_dataset(transform=cifar_transforms)
    # Reshuffle each epoch; otherwise the local samples are always visited in
    # the same stored order.
    client_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # nn.Module.to works in place and returns the module, so the model stored
    # in client_flex_model is moved to `device` as well.
    model = client_flex_model["model"].to(device)
    # Snapshot of the pre-training weights; required to use the
    # `collect_client_diff_weights_pt` primitive.
    client_flex_model["previous_model"] = deepcopy(model)
    optimizer = client_flex_model["optimizer_func"](
        model.parameters(), **client_flex_model["optimizer_kwargs"]
    )
    criterion = client_flex_model["criterion"]
    model.train()
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        for imgs, labels in tqdm(client_dataloader):
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(imgs), labels)
            loss.backward()
            optimizer.step()

Now let's train our clients as usual.

In [None]:
# Run the local `train` loop on every selected client; each client's FlexModel
# and local Dataset are passed to `train`.
selected_clients.map(train)

One important difference from a regular flex example is that the weight collection functions used by the blockchain pool are not compatible with the ones in flex. If we want to use a flex built-in collection primitive, or a custom one written for flex, we can adapt it to the blockchain environment with the function `collect_to_send_wrapper`.

In [None]:
from flex.pool import collect_client_diff_weights_pt
from flexBlock.pool.primitives import collect_to_send_wrapper

# Make the collect_client_diff_weights_pt primitive blockchain ready by wrapping
# it with `collect_to_send_wrapper`.
collect_client_diff = collect_to_send_wrapper(collect_client_diff_weights_pt)

# Each aggregator collects from the selected clients with the wrapped primitive,
# which relies on the `previous_model` snapshot stored by `train`.
aggregators.map(collect_client_diff, selected_clients)


Now the weights need to be shared between miners. In a real-life scenario this would be done through a gossip mechanism between miners; here we achieve it by calling the `flex_pool.gossip` method. Then we are ready to mine a block, which is done as part of the aggregation step. The `aggregate` method runs the consensus mechanism of our architecture (in our case, Proof of Stake), and the winning miner runs the given aggregation function. In this case, we use a simple average (`fed_avg`).

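Finally, the aggregated weights can be sent back to the server (with `set_aggregated_diff_weights_pt`) and the whole process repeated, as the training loop further below does.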

In [None]:
from flex.pool import fed_avg

# Share the collected weights between miners (simulated gossip), then mine a
# block: `aggregate` runs the PoS consensus and the winning miner aggregates
# the weights with `fed_avg`.
flex_pool.gossip()
flex_pool.aggregate(fed_avg)

Now we can run the federated experiment for a few rounds.

In [None]:
from flex.pool import set_aggregated_diff_weights_pt, deploy_server_model_pt

# Helper used at the end of every round to free memory held by client models.
def clean_up_models(client_model: FlexModel, _):
    """Empty a client's FlexModel and release the memory it was holding.

    Args:
        client_model: The client's FlexModel; every entry is removed.
        _: The client's data. It is unused and is present only because ``map``
            callbacks receive ``(model, data)``, as ``train`` does.
    """
    import gc

    client_model.clear()
    # Collect the now-unreferenced tensors, then return PyTorch's cached but
    # unused CUDA blocks to the driver. gc.collect() alone leaves them in the
    # caching allocator, so the GPU memory would still appear as used.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def train_n_rounds(n_rounds, clients_per_round=20):
    """Run a federated experiment on a fresh PoS blockchain pool.

    Each round, the function:
    1. Selects clients.
    2. Deploys the server model to them.
    3. Trains locally.
    4. Collects their weights.
    5. Gossips the weights between miners.
    6. Mines a block (consensus plus ``fed_avg`` aggregation).
    7. Sends the aggregated weights to the servers.
    8. Frees the clients' models.

    Args:
        n_rounds: Number of federated rounds.
        clients_per_round: Number of clients selected in each round.

    Returns:
        The ``PoSBlockchainPool`` used for training, so its nodes and models
        can be inspected or evaluated after the rounds finish.
    """
    pool = PoSBlockchainPool(fed_dataset=flex_dataset, init_func=build_server_model)
    for i in range(n_rounds):
        print(f"\nRunning round: {i+1} of {n_rounds}")
        selected_clients = pool.clients.select(clients_per_round).clients
        print(f"Selected clients for this round: {len(selected_clients)}")
        # Deploy the server model to the selected clients
        pool.servers.map(deploy_server_model_pt, selected_clients)
        # Each selected client trains its model locally
        selected_clients.map(train)
        # The aggregators collect weights from the selected clients, share them
        # (gossip) and mine a block, whose winner aggregates them with fed_avg
        pool.aggregators.map(collect_client_diff, selected_clients)
        pool.gossip()
        pool.aggregate(fed_avg)
        # The aggregators send the aggregated weights to the servers
        pool.aggregators.map(set_aggregated_diff_weights_pt, pool.servers)
        # Optional: clean up unused memory
        selected_clients.map(clean_up_models)
    return pool

In [None]:
# Full experiment: 10 rounds with 2 clients per round on a freshly built pool.
train_n_rounds(10, clients_per_round=2)