# Distributing an existing FLEXible experiment

FLEXible is a framework primarily designed for simulating federated learning (FL) experiments locally. However, it also provides a convenient method for reusing existing code and deploying it in a genuine federated environment. In this notebook, we demonstrate how to create a distributed FLEXible environment by reusing code from our local experiments. Additionally, we highlight the key differences and important considerations that need attention.

## The local experiment

In the following experiment, we will train a Multi-Layer Perceptron (MLP) on MNIST-style handwritten digits (the code below loads the `digits` split of the federated EMNIST dataset). This example is explained in greater detail in the `Federated MNIST PT example with flexible decorators` notebook. We recommend reviewing that example for a more comprehensive understanding.

In [None]:
import copy

from flex.data import Dataset
from flex.datasets import load
from flex.pool import init_server_model, deploy_server_model, collect_clients_weights, set_aggregated_weights, aggregate_weights
from flex.pool import FlexPool
from flex.model import FlexModel

import tensorly as tl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

# Pick the compute device by preference: CUDA first, then MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Load the federated EMNIST "digits" split together with its test data.
flex_dataset, test_data = load("federated_emnist", return_test=True, split="digits")

# Store the test data under the server's id so the server node can evaluate on it.
server_id = "server"
flex_dataset[server_id] = test_data


# Convert each sample to a tensor, then normalize with mean 0.5 and std 0.5.
mnist_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)


class SimpleNet(nn.Module):
    """Two-layer perceptron for 28x28 inputs.

    The input is flattened to 784 features per sample, passed through a
    128-unit hidden layer with ReLU, and projected to ``num_classes`` outputs.
    """

    def __init__(self, num_classes=10):
        """Create the layers.

        Args:
            num_classes: Size of the output layer.
        """
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of inputs."""
        hidden = F.relu(self.fc1(self.flatten(x)))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)


@init_server_model
def build_server_model():
    """Create the server's ``FlexModel`` with everything later FL stages need.

    Returns:
        A ``FlexModel`` holding a fresh ``SimpleNet`` under ``"model"``, the
        loss under ``"criterion"``, and the optimizer class and its keyword
        arguments under ``"optimizer_func"`` and ``"optimizer_kwargs"``.
    """
    flex_model = FlexModel()
    entries = {
        "model": SimpleNet(),
        # Kept alongside the model because training and evaluation read them later.
        "criterion": torch.nn.CrossEntropyLoss(),
        "optimizer_func": torch.optim.Adam,
        "optimizer_kwargs": {},
    }
    for key, value in entries.items():
        flex_model[key] = value
    return flex_model


# Build a client-server pool from the federated dataset; the node stored under
# server_id acts as the server and is initialized with build_server_model.
flex_pool = FlexPool.client_server_pool(
    flex_dataset, server_id=server_id, init_func=build_server_model
)


# Sub-pools for each role, used by the map() calls in the rest of the experiment.
clients = flex_pool.clients
servers = flex_pool.servers
aggregators = flex_pool.aggregators


# Select the clients that take part in this round.
clients_per_round = 20
selected_clients_pool = clients.select(clients_per_round)
selected_clients = selected_clients_pool.clients


@deploy_server_model
def copy_server_model_to_clients(server_flex_model: FlexModel):
    """Return an independent deep copy of the server's model for a client."""
    client_flex_model = copy.deepcopy(server_flex_model)
    return client_flex_model


# Copy the server's model to every selected client.
servers.map(copy_server_model_to_clients, selected_clients)


def train(client_flex_model: FlexModel, client_data: Dataset):
    """Train a client's local model on that client's own data.

    Runs 5 passes over ``client_data`` in batches of 20, updating the model
    stored under ``client_flex_model["model"]`` in place. Nothing is returned.

    Args:
        client_flex_model: Client model container; it must provide the keys
            ``"model"``, ``"criterion"``, ``"optimizer_func"`` and
            ``"optimizer_kwargs"``.
        client_data: The client's local dataset.
    """
    train_dataset = client_data.to_torchvision_dataset(transform=mnist_transforms)
    client_dataloader = DataLoader(train_dataset, batch_size=20)

    # Move the model to the target device *before* building the optimizer, so
    # the optimizer is constructed over parameters at their final location.
    model = client_flex_model["model"].to(device)
    model.train()
    optimizer = client_flex_model["optimizer_func"](
        model.parameters(), **client_flex_model["optimizer_kwargs"]
    )
    criterion = client_flex_model["criterion"]

    for _ in range(5):
        for imgs, labels in client_dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            pred = model(imgs)
            loss = criterion(pred, labels)
            loss.backward()
            optimizer.step()


# Run local training on every selected client.
selected_clients.map(train)


@collect_clients_weights
def get_clients_weights(client_flex_model: FlexModel):
    """Return the client model's weights as a list, in ``state_dict`` order."""
    state = client_flex_model["model"].state_dict()
    return list(state.values())


# Collect the weights of every selected client into the aggregator.
aggregators.map(get_clients_weights, selected_clients)


# Use PyTorch as TensorLy's backend for the tl.stack / tl.mean calls below.
tl.set_backend("pytorch")


@aggregate_weights
def aggregate_with_fedavg(list_of_weights: list):
    """Average the clients' weights layer by layer.

    Args:
        list_of_weights: One entry per client, each a list of per-layer
            weights. The layer count is taken from the first client.

    Returns:
        A list with one averaged weight per layer, in the same layer order.
    """
    num_layers = len(list_of_weights[0])
    return [
        # Stack the same layer from every client, then average over clients.
        tl.mean(tl.stack([client[idx] for client in list_of_weights]), axis=0)
        for idx in range(num_layers)
    ]


# Aggregate the collected client weights with FedAvg.
aggregators.map(aggregate_with_fedavg)

@set_aggregated_weights
def set_agreggated_weights_to_server(server_flex_model: FlexModel, aggregated_weights):
    """Copy aggregated weights into the model, in ``state_dict`` order.

    Args:
        server_flex_model: Model container holding the network under ``"model"``.
        aggregated_weights: One weight per ``state_dict`` entry. Entries may be
            ``torch.Tensor`` objects (local experiment) or NumPy arrays (as
            received over the network in a distributed run).
    """
    with torch.no_grad():
        weight_dict = server_flex_model["model"].state_dict()
        for layer_key, new in zip(weight_dict, aggregated_weights):
            # Tensor.copy_ only accepts tensors: as_tensor passes tensors
            # through unchanged and wraps NumPy arrays; copy_ then handles the
            # dtype/device conversion.
            weight_dict[layer_key].copy_(torch.as_tensor(new))


# Load the aggregated weights into the server's model.
aggregators.map(set_agreggated_weights_to_server, servers)


def evaluate_global_model(server_flex_model: FlexModel, test_data: Dataset):
    """Evaluate the model held in ``server_flex_model`` on ``test_data``.

    Args:
        server_flex_model: Model container providing ``"model"`` and
            ``"criterion"``.
        test_data: Dataset to evaluate on.

    Returns:
        A ``(loss, accuracy)`` tuple: the mean of the per-batch losses and the
        fraction of correctly classified samples.
    """
    net = server_flex_model["model"]
    net.eval()
    net = net.to(device)
    loss_fn = server_flex_model["criterion"]

    # Get the test data as a torchvision dataset and iterate it in batches.
    eval_dataset = test_data.to_torchvision_dataset(transform=mnist_transforms)
    eval_loader = DataLoader(
        eval_dataset, batch_size=256, shuffle=True, pin_memory=False
    )

    batch_losses = []
    correct = 0
    seen = 0
    with torch.no_grad():
        for inputs, target in eval_loader:
            seen += target.size(0)
            inputs, target = inputs.to(device), target.to(device)
            output = net(inputs)
            batch_losses.append(loss_fn(output, target).item())
            # Index of the largest output per sample is the predicted class.
            _, predicted = output.max(1, keepdim=True)
            correct += predicted.eq(target.view_as(predicted)).sum().item()

    mean_loss = sum(batch_losses) / len(batch_losses)
    accuracy = correct / seen
    return mean_loss, accuracy


# Evaluate the global model on the server and print the first server's result.
metrics = servers.map(evaluate_global_model)
print(metrics[0])

This is quite some code, but we can summarize it in the following steps:
- From a dataset, we are able to federate it and create a pool from it. In our case the loaded dataset is already federated.
- We create a `FlexModel` instance for the server which contains all necessary variables for the training and evaluation.
- The server's model is copied to clients.
- Clients train their model with their data.
- Weights are collected by the server.
- The server/aggregator aggregates all the weights.
- The weights are set to the server and evaluated.
- Repeat from the third step (copying the server's model to the clients) until convergence is reached.

Now let's explore how to distribute this code.

## Distributing the local experiment
FLEXible supports (at the time of writing) a real distributed client-server architecture. We will start with the client. A FLEXible client is an instance of a subclass of the abstract class `flex.distributed.Client`. This class already contains all the code necessary for communication, so the user only has to implement the methods that define how to train, collect and set weights, and run evaluation. Moreover, if we wrote a local experiment first, FLEXible provides a convenient `flex.distributed.ClientBuilder` that lets us reuse most of that code with very few changes. Let's see:

In [None]:
# Client.py
# We assume that we have imported all the previous functions definitions from the local experiment.

from flex.distributed import ClientBuilder

# Load a dataset. In a local experiment a FedDataset is nothing but a collection
# of datasets, so we simply pick the first one from the collection.
first_node_id = list(flex_dataset.keys())[0]
dataset = flex_dataset[first_node_id]

# Create the client by handing everything it needs to the ClientBuilder.
# The functions defined for the local experiment are reused as they are.
# NOTE(review): evaluate_global_model returns a (loss, accuracy) tuple, while
# the notes below say eval functions must return a dict of str -> float;
# confirm how the builder handles this.
client = (
    ClientBuilder()
    .build_model(build_server_model)
    .dataset(dataset)
    .collect_weights(get_clients_weights)
    .train(train)
    .eval(evaluate_global_model, dataset)
    .set_weights(set_agreggated_weights_to_server)
    .build()
)

# Finally, we can run the client.
client.run(address="you_address_here:port")

We see that we have been able to create a client in just a few lines of code. For the sake of simplicity, we are using the same dataset for training (set by `.dataset`) and for evaluation (set by `.eval`). Now let's talk about the minimum changes that we should take into account.
## Changes necessary for client code
### Collect and set weights must be wrapped
When we are writing the functions for our local experiments, we are always working with pools. The decorators `@collect_clients_weights` and `@set_aggregated_weights` allow us to make use of the user-defined functions for the whole pool. Still, in a real distributed scenario, we have no pool object. The client builder exploits these decorators so it is able to recover the original function. This is an opinionated decision but since most of the users will define these functions with the provided decorators, it is also the most convenient.
### Training and evaluation functions return dictionaries
In order to recover on the server side any information about training or evaluation, such as accuracy or loss, these functions must return a dictionary with strings as keys and floats as values (for example, `{"loss": 0.12, "accuracy": 0.97}`). This is enforced by the communication protocol used. If a function does not return anything, an empty dictionary is always returned.

### Set weights receives an array of NumPy arrays
In a local experiment, since we never leave the computer, we tend to always keep weights on a device such as the GPU. Unfortunately, network communication requires moving the weights to the CPU before they are sent. When we are collecting weights, this conversion is transparent to the user, but when receiving weights through the network (that is, when `set_weights` is being called) we are unable to convert them to any other kind of array, since FLEXible is framework agnostic. Therefore, it is the responsibility of the user to convert these weights to the desired tensor type, such as `torch.Tensor`.


## The server code
Moving on, we show how the server code works. It is different from the local code but still very simple to use.

In [None]:
# Server.py

from flex.distributed import Server

# We need to create the server and store it in a variable.
server = Server()

# Start the server.
server.run(address="your_address_here", port=8080)

# Everything after run() is wrapped in try/finally so that the server is
# stopped even if one of the steps below raises.
try:
    # Now we are able to use the server to communicate with clients.
    # We are able to see the total number of clients connected to the server.
    clients_connected = len(server)

    # To retrieve their ids.
    clients_ids = server.get_ids()

    # Now we may select a given amount of clients to run the FL process.
    number_of_clients = 10
    selected_ids = clients_ids[:number_of_clients]

    # We can now tell the clients to start training.
    # By passing the selected ids, the server will only communicate with the selected clients.
    metrics = server.train(node_ids=selected_ids)

    # Let's see the metrics.
    for node_id in metrics:
        print(f"Client with id {node_id} has sent the following metrics: {metrics[node_id]}")

    # Running evaluation is equivalent.
    metrics = server.eval(node_ids=selected_ids)

    # Now, we can also get the weights from the clients.
    weights = server.collect_weights(node_ids=selected_ids)

    # These weights can be aggregated. The result is named `aggregated_weights`
    # so it does not shadow the `aggregate_weights` decorator from flex.pool.
    aggregated_weights = aggregate_with_fedavg.__wrapped__(weights)

    # And then send the aggregated weights to the clients.
    server.send_weights(aggregated_weights, node_ids=selected_ids)
finally:
    # Finally, we stop the server.
    # This is very important: run() starts threads that hang until stop() is called.
    server.stop()

As we can see, the Server API allows us to fetch information from clients through a simple interface. The only important thing to notice here is that we must call `server.stop()` at the end of the code. This is necessary since `server.run()` starts multiple threads that will hang until we call this method.


FLEX also provides an alternative Server API based on `asyncio`, and thus on the `async/await` syntax. This API is not only more efficient but also allows running operations with a timeout. The API is mostly the same and can be found in `flex.distributed.aio.Server`. Here we can see the previous server code written with this API.

In [None]:
# Server_aio.py
import asyncio

from flex.distributed.aio import Server

async def run_server():
    """Run one federated round with the asyncio-based FLEXible server.

    Starts the server, waits for clients, asks a subset of them to train and
    evaluate, aggregates their weights with FedAvg and sends the result back.
    The server is always stopped at the end, even if a step fails.
    """
    # We need to create the server and store it in a variable.
    server = Server()

    # Start the server.
    await server.run(address="your_address_here", port=8080)

    # Everything after run() is wrapped in try/finally so that the server is
    # stopped even if one of the steps below raises.
    try:
        # Now we are able to use the server to communicate with clients.
        # We are able to see the total number of clients connected to the server.
        clients_connected = len(server)

        # We can also wait until a given number of clients are connected.
        await server.wait_for_clients(10)

        # To retrieve their ids.
        clients_ids = server.get_ids()

        # Now we may select a given amount of clients to run the FL process.
        number_of_clients = 10
        selected_ids = clients_ids[:number_of_clients]

        # We can now tell the clients to start training.
        # By passing the selected ids, the server will only communicate with the selected clients.
        metrics = await server.train(node_ids=selected_ids)

        # Let's see the metrics.
        for node_id in metrics:
            print(f"Client with id {node_id} has sent the following metrics: {metrics[node_id]}")

        # Running evaluation is equivalent.
        metrics = await server.eval(node_ids=selected_ids)

        # Now, we can also get the weights from the clients.
        weights = await server.collect_weights(node_ids=selected_ids)

        # These weights can be aggregated. The result is named
        # `aggregated_weights` so it does not shadow the `aggregate_weights`
        # decorator from flex.pool.
        aggregated_weights = aggregate_with_fedavg.__wrapped__(weights)

        # And then send the aggregated weights to the clients.
        await server.send_weights(aggregated_weights, node_ids=selected_ids)
    finally:
        # Finally, we stop the server.
        # This is very important to avoid leaving the server running.
        await server.stop()

# Run the run_server() coroutine to completion with asyncio.
asyncio.run(run_server())