## Using Proof of Federated Learning (PoFL) on the MNIST dataset

In this notebook we will use a blockchain architecture based on Proof of Federated Learning (PoFL) to train a model on the MNIST dataset. PoFL is a consensus mechanism designed specifically for federated learning, unlike traditional consensus mechanisms such as Proof of Work (PoW) or Proof of Stake (PoS), which are designed for generic blockchains. This notebook stays at the usage level: we will not dive into how PoFL is implemented in `flexBlock`.

### How does PoFL work?
In a PoFL blockchain we distinguish two types of nodes: miners and clients. The miners create new blocks, run the consensus mechanism and validate blocks. The clients are the data owners: they are the ones who train the model. Every client is assigned to a miner, and each miner together with its clients acts as an independent pool. The miner with the best model (i.e. the model with the highest accuracy) propagates its model to the other miners, which validate it and accept it only if it is better than their own.
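To make the propagation rule concrete, here is a framework-free sketch of one selection step. It is a hypothetical illustration, not `flexBlock`'s implementation: the `Miner` class, its fields and `propagate_best_model` are invented for this example, and "models" are plain strings with a precomputed accuracy.

```python
# Hypothetical, framework-free sketch of the PoFL selection rule described above.
# It is NOT flexBlock's implementation; models are strings, accuracies plain floats.
from dataclasses import dataclass, field


@dataclass
class Miner:
    name: str
    clients: list = field(default_factory=list)  # clients in this miner's pool
    model: str = "initial"                        # stand-in for a real model
    accuracy: float = 0.0                         # accuracy of `model` on shared eval data


def propagate_best_model(miners):
    """Broadcast the most accurate miner's model; others accept it only if it beats theirs."""
    best = max(miners, key=lambda m: m.accuracy)
    for miner in miners:
        # Each receiving miner validates the candidate against its own model.
        if miner is not best and best.accuracy > miner.accuracy:
            miner.model, miner.accuracy = best.model, best.accuracy
    return best


miners = [
    Miner("m0", ["c0", "c1"], "model-m0", 0.71),
    Miner("m1", ["c2", "c3"], "model-m1", 0.84),
    Miner("m2", ["c4", "c5"], "model-m2", 0.65),
]
best = propagate_best_model(miners)
print(best.name)                  # m1
print([m.model for m in miners])  # ['model-m1', 'model-m1', 'model-m1']
```

Note that the pools themselves (each miner's `clients`) are untouched: only the model moves between miners.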


Let's start by writing some boilerplate code. We will use PyTorch for the model.

In [None]:
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device}")

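We are going to use the digits split of the federated EMNIST dataset, an MNIST-style collection of handwritten digits. To load it we use the `load` function from `flex`, which gives us access to federated datasets; with `return_test=True` it also returns a held-out test set.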

In [None]:
from flex.datasets import load
from torchvision import transforms

flex_dataset, test_data = load("federated_emnist", return_test=True, split="digits")

mnist_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
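The two transforms above first scale pixels to the unit interval and then shift and rescale them. A plain-Python sketch of the arithmetic, assuming uint8/PIL inputs (for which `ToTensor` divides by 255), shows that `Normalize((0.5,), (0.5,))` maps pixel values into [-1, 1]; the helper names are hypothetical.

```python
# Plain-Python illustration of the two transforms above (no torchvision needed).
# Helper names are hypothetical; they only mirror the arithmetic.
def to_unit_range(pixel: int) -> float:
    # ToTensor maps uint8 pixel values in [0, 255] to floats in [0.0, 1.0]
    return pixel / 255.0


def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    # Normalize computes (x - mean) / std for each channel
    return (x - mean) / std


for pixel in (0, 255):
    print(pixel, normalize(to_unit_range(pixel)))
# 0 -1.0
# 255 1.0
```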

Now we define the model and the function that initializes the server (miner) models. This is standard `flex` code.

In [None]:
import torch.nn as nn
import torch.nn.functional as F

from flex.pool import init_server_model
from flex.model import FlexModel



class SimpleNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


@init_server_model
def build_server_model():
    server_flex_model = FlexModel()

    server_flex_model["model"] = SimpleNet()
    # Required to store this for later stages of the FL training process.
    # SimpleNet already outputs log-probabilities, so NLLLoss is the matching criterion
    server_flex_model["criterion"] = torch.nn.NLLLoss()
    server_flex_model["optimizer_func"] = torch.optim.Adam
    server_flex_model["optimizer_kwargs"] = {}
    return server_flex_model
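As a quick sanity check on model size, the parameter count of `SimpleNet` can be worked out by hand: a `Linear(in, out)` layer holds an `out x in` weight matrix plus `out` biases, and `Flatten` has no parameters. The helper below is only for illustration.

```python
# Count SimpleNet's trainable parameters by hand.
# A Linear(in, out) layer has an (out x in) weight matrix plus `out` biases.
def linear_params(in_features: int, out_features: int) -> int:
    return in_features * out_features + out_features


fc1 = linear_params(28 * 28, 128)  # 784 * 128 + 128 = 100480
fc2 = linear_params(128, 10)       # 128 * 10 + 10 = 1290
print(fc1 + fc2)                   # 101770
```

With PyTorch installed, `sum(p.numel() for p in SimpleNet().parameters())` should give the same total. These parameters live in four tensors (`fc1.weight`, `fc1.bias`, `fc2.weight`, `fc2.bias`), which is what each client will later send to its miner, roughly 0.4 MB per client per round as float32.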

Now it is time to create our pool. `flexBlock` has a ready-to-use implementation of PoFL: we just need to import `PoFLBlockchainPool` from the `pool` module and create an instance of it. It requires a federated dataset, a function for initializing the models and the number of miners we want in our blockchain. The clients will be split evenly between the miners.

In [None]:

from flexBlock.pool import PoFLBlockchainPool

pool = PoFLBlockchainPool(fed_dataset=flex_dataset, init_func=build_server_model, number_of_miners=10)

clients = pool.clients
servers = pool.servers
aggregators = pool.aggregators

print(
    f"Number of nodes in the pool {len(pool)}: {len(servers)} miners plus {len(clients)} clients. Each miner also acts as an aggregator"
)
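How exactly `flexBlock` splits the clients is not shown here. As a hypothetical illustration of what "split evenly" can mean, a round-robin assignment keeps pool sizes within one client of each other; `split_clients_evenly` is an invented helper, not part of `flexBlock`.

```python
# Hypothetical helper: NOT flexBlock's code, just one way to split clients evenly.
def split_clients_evenly(client_ids, number_of_miners):
    """Assign clients round-robin so pool sizes differ by at most one."""
    pools = {miner: [] for miner in range(number_of_miners)}
    for i, client_id in enumerate(client_ids):
        pools[i % number_of_miners].append(client_id)
    return pools


pools = split_clients_evenly([f"client_{i}" for i in range(23)], number_of_miners=10)
print(sorted(len(p) for p in pools.values()))  # [2, 2, 2, 2, 2, 2, 2, 3, 3, 3]
```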

We can then define a function for deploying each miner's (server's) model to its clients. Since we are using `flexBlock`, we use the `deploy_miner_model` decorator from the `pool.decorators` module. If you already have a function that does this in `flex`, remember that you can reuse it by wrapping it with the `deploy_server_to_miner_wrapper` function from the `pool.primitives` module.

In [None]:
import copy

from flexBlock.pool.decorators import deploy_miner_model

@deploy_miner_model
def copy_server_model_to_clients(server_flex_model: FlexModel):
    return copy.deepcopy(server_flex_model)


servers.map(copy_server_model_to_clients, clients)
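The deployment function uses `copy.deepcopy` rather than sharing the miner's object. The plain-Python sketch below, with a dict standing in for a `FlexModel`, shows why: each client must train its own copy without mutating the miner's model, and a shallow copy would leak updates back.

```python
import copy

# Stand-in for a FlexModel: a dict holding a mutable list of "weights".
miner_model = {"weights": [0.0, 0.0, 0.0]}

# Each client receives its own deep copy, as copy_server_model_to_clients does.
client_models = [copy.deepcopy(miner_model) for _ in range(2)]

# Local "training" on one client mutates only that client's copy.
client_models[0]["weights"][0] = 1.5
print(miner_model["weights"])       # [0.0, 0.0, 0.0]
print(client_models[1]["weights"])  # [0.0, 0.0, 0.0]

# A shallow copy shares the inner list, so the update leaks back to the miner.
shallow = copy.copy(miner_model)
shallow["weights"][0] = 9.9
print(miner_model["weights"])       # [9.9, 0.0, 0.0]
```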

Let's define a standard training function: each client trains its local copy of the model for one epoch over its own data.

In [None]:
from flex.data import Dataset
from torch.utils.data import DataLoader


def train(client_flex_model: FlexModel, client_data: Dataset):
    train_dataset = client_data.to_torchvision_dataset(transform=mnist_transforms)
    client_dataloader = DataLoader(train_dataset, batch_size=20)
    model = client_flex_model["model"]
    optimizer = client_flex_model["optimizer_func"](
        model.parameters(), **client_flex_model["optimizer_kwargs"]
    )
    model = model.train()
    model = model.to(device)
    criterion = client_flex_model["criterion"]
    for _ in range(1):
        for imgs, labels in client_dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            pred = model(imgs)
            loss = criterion(pred, labels)
            loss.backward()
            optimizer.step()

Next, a function that collects each client's weights and sends them to its miner.

In [None]:
from flexBlock.pool import send_weights_to_miner

@send_weights_to_miner
def get_clients_weights(client_flex_model: FlexModel):
    weight_dict = client_flex_model["model"].state_dict()
    return [weight_dict[name] for name in weight_dict]

And a function that writes the aggregated weights into each miner's model.

In [None]:
from flex.pool import set_aggregated_weights


@set_aggregated_weights
def set_agreggated_weights_to_server(server_flex_model: FlexModel, aggregated_weights):
    with torch.no_grad():
        weight_dict = server_flex_model["model"].state_dict()
        for layer_key, new in zip(weight_dict, aggregated_weights):
            weight_dict[layer_key].copy_(new)
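These two functions only work together because both rely on the same key order: `get_clients_weights` sends bare values in `state_dict` order, and `set_agreggated_weights_to_server` pairs them back to keys by position. A plain-Python sketch of that round trip, with a dict standing in for the `state_dict` and a made-up "doubling" step standing in for real aggregation:

```python
# Plain-Python stand-in for a state_dict: dicts preserve insertion order (Python 3.7+),
# just as a PyTorch state_dict (an OrderedDict) preserves layer order.
state = {"fc1.weight": [0.5, 0.25], "fc1.bias": [0.0], "fc2.weight": [0.75], "fc2.bias": [1.0]}

# What get_clients_weights sends: values only, in key order.
sent = [state[name] for name in state]

# Placeholder "aggregation" that doubles every value (not a real aggregator).
aggregated = [[2 * v for v in tensor] for tensor in sent]

# What set_agreggated_weights_to_server does: pair keys and values back by position.
for key, new in zip(state, aggregated):
    state[key][:] = new  # in-place update, like tensor.copy_()

print(state["fc2.bias"])  # [2.0]
```

Because `zip` silently stops at the shorter input, a length mismatch would go unnoticed; on Python 3.10+ `zip(state, aggregated, strict=True)` raises a `ValueError` instead.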

Now we need a function for evaluating the global model of each miner. It takes a miner's global model and computes its accuracy on the test set.

In [None]:
def evaluate_global_model(server_flex_model: FlexModel, test_data: Dataset):
    model = server_flex_model["model"]
    model = model.to(device)
    model.eval()
    correct = 0
    total_count = 0
    # get test data as a torchvision object
    test_dataset = test_data.to_torchvision_dataset(transform=mnist_transforms)
    # Order does not matter for evaluation, so there is no need to shuffle
    test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)
    with torch.no_grad():
        for data, target in test_dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Predicted class = index of the highest log-probability
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total_count += target.size(0)
    return correct / total_count
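The accuracy metric itself is simple: take the argmax of each row of model outputs and count how often it matches the label. Since `log_softmax` is monotonic, the argmax over log-probabilities is the same as over probabilities. A stdlib-only sketch with made-up scores:

```python
def accuracy(scores, targets):
    """Fraction of rows whose highest-scoring class equals the target label."""
    correct = 0
    for row, target in zip(scores, targets):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += int(predicted == target)
    return correct / len(targets)


# Three samples, three classes (made-up log-probabilities); the first two are right.
scores = [[-0.1, -2.3, -3.0], [-2.0, -0.2, -1.9], [-0.5, -1.2, -2.2]]
print(accuracy(scores, [0, 1, 2]))  # 0.6666666666666666
```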

Next, we define the training loop. When we call the `aggregate` method we pass `eval_function`, which computes the accuracy; `eval_dataset`, the dataset on which the accuracy is computed; and `accuracy`, the accuracy threshold. `aggregate` returns a boolean indicating whether an aggregation was reached, that is, whether at least one miner's model surpassed the threshold and was therefore propagated to the other miners.

In [None]:
from flex.pool import fed_avg

def train_until_acc(acc: float):
    aggregated = False
    i = 0
    while not aggregated:
        i = i + 1
        print(f"\nRunning round: {i}")
        # Deploy each miner's model to its clients
        pool.servers.map(copy_server_model_to_clients, clients)
        selected_clients = pool.clients.select(20)
        # Each selected client trains her model
        selected_clients.map(train)
        # The aggregators (miners) collect the weights of their selected clients
        pool.aggregators.map(get_clients_weights, selected_clients)
        # Aggregate the collected weights and check whether any miner surpassed `acc`
        aggregated = pool.aggregate(
            fed_avg,
            set_agreggated_weights_to_server,
            eval_function=evaluate_global_model,
            eval_dataset=test_data,
            accuracy=acc,
        )
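The stopping condition of this loop can be sketched without any framework. The code below is a hypothetical illustration of the logic described above, not `flexBlock`'s `aggregate`: it assumes FedAvg is a uniform element-wise mean of the clients' weights within each pool (`toy_fed_avg`, `aggregate_round` and the toy evaluation are invented names), then reports whether the best miner surpassed the threshold.

```python
# Hypothetical sketch: NOT flexBlock's `aggregate`, just the logic described above.
def toy_fed_avg(client_weights):
    """Uniform element-wise mean of per-client weight lists (flattened to floats)."""
    n = len(client_weights)
    return [sum(values) / n for values in zip(*client_weights)]


def aggregate_round(miner_pools, evaluate, threshold):
    """FedAvg inside each pool, then report whether any miner beat the threshold."""
    miner_models = {miner: toy_fed_avg(w) for miner, w in miner_pools.items()}
    scores = {miner: evaluate(model) for miner, model in miner_models.items()}
    best = max(scores, key=scores.get)
    return scores[best] > threshold, best, miner_models


def toy_eval(model):
    # Stand-in "accuracy": just the mean of the model's two weights.
    return sum(model) / len(model)


miner_pools = {
    "m0": [[0.25, 0.5], [0.25, 0.5]],  # averages to [0.25, 0.5] -> score 0.375
    "m1": [[0.5, 1.0], [1.0, 0.5]],    # averages to [0.75, 0.75] -> score 0.75
}
reached, best, models = aggregate_round(miner_pools, toy_eval, threshold=0.6)
print(reached, best)  # True m1
```

With a threshold of `0.8` the same call would return `False`, and `train_until_acc` would run another round.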

Finally, we run the experiment, training until some miner's model surpasses 60% test accuracy.

In [None]:
train_until_acc(acc=0.6)