# FLEXible tutorial: Text classification using PyTorch

FLEXible is a library to federate models. We offer the tools to load and federate data (or to load already federated data) and the tools to create a federated environment. The user must define the model and the *communication primitives* to train the model in a federated environment. These primitives can be expressed in the following steps:
- initialization: Initialize the model in the server.
- deploy model: Deploy the model to the clients.
- training: Define the train function.
- collect the weights: Collect the weights of the clients' model parameters to aggregate them later.
- aggregate the weights: Use an aggregation method to aggregate the collected weights.
- deploy model: Deploy the model with the updated weights to the clients.
- evaluate: Define the evaluate function.

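To make the order of these steps concrete, here is a framework-free sketch of one federated round. Every name in it (`initialize`, `deploy`, `train`, `collect`, `aggregate`, `evaluate`) is hypothetical and only mirrors the list above; it does not use FLEXible's API, and the "model" is just a list of numbers updated with a toy gradient step.

```python
from copy import deepcopy

# Toy stand-ins (hypothetical, not FLEXible objects): a "model" is a dict with a list
# of floats, and each client's local data is a list of numbers.

def initialize(server):
    server["weights"] = [0.0, 0.0]

def deploy(server, clients):
    for cid in clients:
        clients[cid]["model"] = deepcopy(server)

def train(client, lr=0.5):
    # One gradient step on 0.5 * mean((w - x)^2): w <- w - lr * (w - mean(x)).
    target = sum(client["data"]) / len(client["data"])
    client["model"]["weights"] = [w - lr * (w - target) for w in client["model"]["weights"]]

def collect(clients):
    return [clients[cid]["model"]["weights"] for cid in clients]

def aggregate(collected):
    # Element-wise mean over clients (plain averaging).
    return [sum(vals) / len(vals) for vals in zip(*collected)]

def evaluate(weights, data):
    # Mean absolute distance between the weights and the mean of the data.
    target = sum(data) / len(data)
    return sum(abs(w - target) for w in weights) / len(weights)

server = {}
clients = {"client_0": {"data": [1.0, 3.0]}, "client_1": {"data": [5.0, 7.0]}}

initialize(server)                               # initialization
deploy(server, clients)                          # deploy model
for cid in clients:                              # training
    train(clients[cid])
server["weights"] = aggregate(collect(clients))  # collect + aggregate the weights
deploy(server, clients)                          # deploy the updated model
print(server["weights"], evaluate(server["weights"], [1.0, 3.0, 5.0, 7.0]))
```

Each client only ever sees its own `data`; what travels between server and clients is the model.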
In this notebook, we show how to implement these primitives and how to use FLEXible in order to federate a model using PyTorch. In this way, we will train a model using multiple clients, but without sharing any data between clients. We will follow this [tutorial](https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html#) from the PyTorch tutorials for text classification.

## Setup

In [None]:
import os
import time

from copy import deepcopy
import numpy as np

import torch
from torchtext.datasets import AG_NEWS
from datasets.load import load_dataset
from torch import nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader, TensorDataset, Dataset


In PyTorch we need to define the device where we are going to train and evaluate the model.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") # For MAC_OS GPU
print(device)

As usual in every experiment, the first step is to load the dataset we will use. In this case we will use the **AG_News** dataset for a supervised text classification model. In this tutorial we have to tokenize the data in order to build the vocabulary, as we will build a text classification model (an `EmbeddingBag` layer followed by a linear layer) from scratch.

We will use torchtext to load the tokenizer and to build the vocabulary from the dataset.
For downloading the dataset, we will use the [datasets package](https://huggingface.co/datasets).

In [None]:
tokenizer = get_tokenizer('basic_english') # Get the tokenizer

ag_news_dataset = load_dataset('ag_news') # Get both splits (train and test) from the Hugging Face Hub
train_data = ag_news_dataset['train'] # Get the train data
test_data = ag_news_dataset['test'] # Get the test data

dataset_ag_news = AG_NEWS(split='train')

# Function to help us build the vocabulary
def yield_tokens_hf(data):
    for text in data['text']:
        yield tokenizer(text)

vocab_hf = build_vocab_from_iterator(yield_tokens_hf(train_data), specials=["<unk>"])
vocab_hf.set_default_index(vocab_hf["<unk>"])

train_examples, train_labels = np.array(train_data['text']), np.array(train_data['label'])
test_examples, test_labels = np.array(test_data['text']), np.array(test_data['label'])

print(f"Training entries: {len(train_examples)}, test entries: {len(test_examples)}")

# Auxiliary functions used by the collate function.
text_pipeline = lambda x: vocab_hf(tokenizer(x))
label_pipeline = lambda x: int(x)

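The tokenizer and vocabulary calls above come from torchtext. As a rough, dependency-free illustration of what they do, the sketch below uses a hypothetical `simple_tokenizer` that only approximates `basic_english` (lowercasing and splitting off punctuation), builds a token-to-index mapping with `<unk>` first and the most frequent tokens next, and maps unknown tokens to the `<unk>` index, similar in spirit to `set_default_index`.

```python
import re
from collections import Counter

def simple_tokenizer(text):
    # Rough approximation of a "basic_english"-style tokenizer (assumption):
    # lowercase, then split into word tokens and single punctuation marks.
    return re.findall(r"\w+|[^\w\s]", text.lower())

def build_vocab(texts, specials=("<unk>",)):
    counts = Counter(tok for t in texts for tok in simple_tokenizer(t))
    # Most frequent tokens first, ties broken alphabetically; specials go first.
    ordered = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    itos = list(specials) + [tok for tok, _ in ordered if tok not in specials]
    return {tok: i for i, tok in enumerate(itos)}

def encode(text, vocab, unk="<unk>"):
    # Unknown tokens fall back to the <unk> index.
    return [vocab.get(tok, vocab[unk]) for tok in simple_tokenizer(text)]

vocab = build_vocab(["Stocks rise.", "Stocks fall!"])
print(vocab)
print(encode("Stocks crash.", vocab))
```

Here `"crash"` was never seen while building the vocabulary, so it is encoded with the `<unk>` index 0.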
# 1) From centralized data to federated data
First and foremost, as we are using a centralized dataset, we have to federate it, that is, split it among the federated clients. To do so, we need to encapsulate it in the basic data object of FLEXible, called **FlexDataObject**. To create a **FlexDataObject** we need to have the data as *numpy.arrays*.

The dataset AG_News is available in both TorchText and HuggingFace datasets, so we can use both to load the data as a FlexDataObject.

In [None]:
from flex.data import FlexDataObject

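# Option 1: load the data from the torchtext dataset
flex_data = FlexDataObject.from_torchtext_dataset(dataset_ag_news)

# Option 2: load the data from the Hugging Face dataset (this overwrites option 1)
flex_data = FlexDataObject.from_huggingface_dataset(hf_dataset=train_data, X_columns='text', label_column='label')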

In order to federate our dataset, we need to specify how we want to split it among clients in a ``FlexDatasetConfig`` object. For this case we want to split it evenly between 2 clients, that is, an iid distribution. To apply our config to our centralized dataset, we use ``FlexDataDistribution.from_config``. A more complete description of the configuration options of ``FlexDatasetConfig`` to federate a dataset can be found in the documentation.

In [None]:
from flex.data import FlexDatasetConfig, FlexDataDistribution

config = FlexDatasetConfig(seed=0)
config.n_clients = 2
config.replacement = False # ensure that clients do not share any data
config.client_names = ['client1', 'client2'] # Optional
flex_dataset = FlexDataDistribution.from_config(cdata=flex_data, config=config)

However, there is a shortcut: if we want to split the dataset iid between the clients, we can directly use ``FlexDataDistribution.iid_distribution`` with our centralized data stored in a ``FlexDataObject`` and the number of clients. Note that in this case the names of the clients are generated automatically: client number ``i`` gets the id ``f"client_{i}"``.

In [None]:
from flex.data import FlexDataDistribution

flex_dataset = FlexDataDistribution.iid_distribution(flex_data, n_clients=2)

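Conceptually, an iid split without replacement shuffles the example indices and deals them into disjoint, nearly equal shards. The hypothetical `iid_split` below sketches that idea with the standard library; it is not FLEXible's implementation, and the zero-based `client_{i}` names are an assumption that only follows the naming convention described above.

```python
import random

def iid_split(n_examples, n_clients, seed=0):
    """Hypothetical helper: shuffle indices and deal them into disjoint, near-equal shards."""
    indices = list(range(n_examples))
    random.Random(seed).shuffle(indices)
    # Client i takes every n_clients-th shuffled index starting at i, so no index is shared.
    return {f"client_{i}": sorted(indices[i::n_clients]) for i in range(n_clients)}

shards = iid_split(n_examples=10, n_clients=2)
print({cid: len(idx) for cid, idx in shards.items()})
```

Because every index lands in exactly one shard, this matches the `replacement = False` setting used in the configuration above: no example is shared between clients.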
# 2) Federating a model with FLEXible

Once we've federated the dataset, we have to create the FlexPool. The FlexPool class simulates a real-time scenario for federated learning, so it is in charge of the communications across the actors. The class FlexPool will assign to each actor a role (client, aggregator, server), so they can communicate during the training phase.

Please check the notebook about the actors (TODO: write the notebook about the actors and their relationships) to learn more about the actors and their relationships in FLEXible.

To create a pool of actors, we need a federated dataset, like the one we have just created, and a function that initializes the model on the server side, because the server will send the model to the clients so they can train it. As we already have the federated dataset (flex_dataset), we will now create the model.

In this case, we define a small PyTorch text classification model that follows the PyTorch tutorial, so there is little to code. We also consider a federated setup commonly known as the client-server architecture, where a server orchestrates the training of the federated clients in every round.

In the following, we create a client server architecture and provide a function to initialize the server model.

In [None]:
class TextClassificationModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

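`nn.EmbeddingBag` uses `mode='mean'` by default, so for each sentence (bag) the model averages the embedding rows of its tokens, and `offsets` marks where each bag starts in the flat `text` tensor; the linear layer then maps each pooled vector to one score per class. A plain-Python sketch of that forward pass, with hypothetical helpers `embedding_bag_mean` and `linear` and tiny made-up numbers:

```python
def embedding_bag_mean(table, text, offsets):
    """Mean of the embedding rows in each bag; bag i spans text[offsets[i]:offsets[i + 1]]."""
    bounds = list(offsets) + [len(text)]
    bags = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        rows = [table[tok] for tok in text[start:end]]
        if not rows:                              # empty bag -> zero vector
            bags.append([0.0] * len(table[0]))
        else:
            bags.append([sum(col) / len(rows) for col in zip(*rows)])
    return bags

def linear(x, weight, bias):
    # weight is (num_class, embed_dim), like nn.Linear.weight: one row per class.
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]

table = [[0.0, 0.0], [1.0, 3.0], [3.0, 1.0], [4.0, 0.0]]  # vocab_size=4, embed_dim=2
text = [1, 2, 3]          # two sentences flattened: [1, 2] and [3]
offsets = [0, 2]
pooled = embedding_bag_mean(table, text, offsets)
logits = [linear(p, weight=[[1.0, 0.0], [0.0, 1.0]], bias=[0.0, 0.5]) for p in pooled]
print(pooled, logits)     # [[2.0, 2.0], [4.0, 0.0]] [[2.0, 2.5], [4.0, 0.5]]
```

Averaging makes the representation independent of sentence length, which is why the model needs no padding.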
In [None]:
def initialize_server_model(flex_model, *args, **kwargs):
    print("Initializing server model.")
    model = TextClassificationModel(vocab_size=kwargs['vocab_size'],
                                    embed_dim=kwargs['embed_dim'],
                                    num_class=kwargs['num_class']).to(device)  # batches from collate_batch are on `device`
    flex_model['model'] = model
    flex_model['criterion'] = kwargs['criterion']
    flex_model['optimizer'] = kwargs['optimizer'](model.parameters(), lr=kwargs['learning_rate'])

In [None]:
vocab_size = len(vocab_hf) # Length of the vocab
embed_dim = 64 # We set the embedding dimension to 64 to keep the model small for the tutorial
num_class = len(set(train_labels)) # Number of classes on the dataset

In [None]:
from flex.pool import FlexPool

flex_pool = FlexPool.client_server_architecture(fed_dataset=flex_dataset, init_func=initialize_server_model,
                                                vocab_size=vocab_size, embed_dim=embed_dim,
                                                num_class=num_class,
                                                criterion=torch.nn.CrossEntropyLoss(),
                                                optimizer=torch.optim.SGD,
                                                learning_rate=5
                                                )
clients = flex_pool.clients
server = flex_pool.servers
print(f"Server node is identified by {server.actor_ids}")
print(f"Client nodes are identified by {clients.actor_ids}")

We define the criterion and the optimizer now, even though they will be used later. This is needed because we may want to use optimizers that keep a state (for example, SGD with momentum or Adam) across the multiple rounds of federated learning, and we don't want that state to be re-initialized in each round. Plain SGD without momentum, as used here, keeps no state, so strictly it does not need to be declared here, but this is the best practice so you don't forget it later in the training function.

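To see why a stateful optimizer should be created once and kept, consider a toy SGD with momentum in plain Python. The hypothetical `MomentumSGD` class follows the usual update `v <- momentum * v + g`, `w <- w - lr * v`. Created once, its velocity carries over between rounds; re-created in every round, the velocity is reset and the accumulated momentum is lost.

```python
class MomentumSGD:
    """Toy SGD with momentum (hypothetical): the velocity is state that should survive between rounds."""

    def __init__(self, lr=0.1, momentum=0.9):
        self.lr, self.momentum = lr, momentum
        self.velocity = None

    def step(self, weights, grads):
        if self.velocity is None:
            self.velocity = [0.0] * len(weights)
        self.velocity = [self.momentum * v + g for v, g in zip(self.velocity, grads)]
        return [w - self.lr * v for w, v in zip(weights, self.velocity)]

grads = [1.0]

# Optimizer created once (as in initialize_server_model): momentum accumulates.
opt = MomentumSGD()
w_persistent = [0.0]
for _ in range(2):                    # two "rounds", one step each
    w_persistent = opt.step(w_persistent, grads)

# Optimizer re-created in every round: the velocity starts from zero each time.
w_fresh = [0.0]
for _ in range(2):
    w_fresh = MomentumSGD().step(w_fresh, grads)

print(w_persistent, w_fresh)          # about [-0.29] versus [-0.2]
```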
We have to create the function that will deploy the model to the clients. 

In [None]:
def deploy_model_to_clients(server_model, clients_model, *args, **kwargs):
    print("Initializing model at a client")
    for client_id in clients_model:
        clients_model[client_id] = deepcopy(server_model)

To make things easier, FlexPool lets the user work with organized pools, such as clients, aggregators or servers. This helps to understand how the actors are connected.

In [None]:
clients = flex_pool.clients
server = flex_pool.servers

To apply all the primitives, such as the deploy step, we will use the **map** function of *FlexPool*. The map function works in the following way: the pool that calls `map` is the one that sends a message to the destination pool passed as argument. If we do not specify a destination pool, the message is "sent" to the same pool that calls `map`. This is what we need to tell the clients to train or evaluate their own model.

In [None]:
server.map(deploy_model_to_clients, clients)

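A minimal mental model of these `map` semantics, written as a hypothetical `ToyPool` class: this is not FLEXible's implementation, only a sketch that assumes a pool holds one model dict per actor, that mapping to a destination passes the destination pool's models to the function, and that mapping without a destination passes each actor its own model and local data.

```python
class ToyPool:
    """Hypothetical, simplified pool: a mental model of map, NOT FLEXible's implementation."""

    def __init__(self, models, data=None):
        self.models = models          # actor_id -> model dict
        self.data = data or {}        # actor_id -> local data (absent for the server)

    def map(self, func, dst_pool=None, **kwargs):
        if dst_pool is None:
            # No destination: every actor runs func on its own model and data.
            return [func(self.models[aid], self.data.get(aid), **kwargs) for aid in self.models]
        # With a destination: each source actor "sends" its model to the destination models.
        return [func(self.models[aid], dst_pool.models, **kwargs) for aid in self.models]

def deploy(server_model, clients_models):
    for cid in clients_models:
        clients_models[cid]["w"] = server_model["w"]

def local_step(model, data, lr=1.0):
    model["w"] += lr * sum(data)

server = ToyPool({"server": {"w": 1}})
clients = ToyPool({"client_0": {}, "client_1": {}}, data={"client_0": [1, 2], "client_1": [5]})
server.map(deploy, clients)          # server -> clients
clients.map(local_step, lr=1)        # each client works on its own model and data
print({cid: m["w"] for cid, m in clients.models.items()})
```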
Once the model is deployed on the clients, it is time to create the training function. Our PyTorch model needs a DataLoader to feed the data into the model in batches. So first, let's define what the DataLoader requires: a Dataset that wraps the client data and, in our case, a collate function. This collate function will help us to apply the preprocessing to each batch.

In [None]:
class TextDataset(Dataset):
    def __init__(self, text, label):
        self.text = text
        self.label = label

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        return self.text[idx], self.label[idx]

In [None]:
def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_text, _label) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)

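`collate_batch` concatenates the token ids of every text in the batch and records where each text starts, using an exclusive cumulative sum of the lengths (`torch.tensor(offsets[:-1]).cumsum(dim=0)`). The same bookkeeping in plain Python, with a hypothetical `collate_ids` that takes already-tokenized ids:

```python
from itertools import accumulate

def collate_ids(batch):
    """Flatten (token_ids, label) pairs into labels, one flat id list and start offsets."""
    labels = [label for _, label in batch]
    flat = [tok for ids, _ in batch for tok in ids]
    lengths = [len(ids) for ids, _ in batch]
    # Start position of each sequence: 0, then the running sum of all previous lengths.
    offsets = [0] + list(accumulate(lengths[:-1]))
    return labels, flat, offsets

labels, flat, offsets = collate_ids([([5, 6, 7], 0), ([8, 9], 2), ([10, 11, 12, 13], 1)])
print(labels, flat, offsets)  # [0, 2, 1] [5, 6, 7, 8, 9, 10, 11, 12, 13] [0, 3, 5]
```

These offsets are exactly what `nn.EmbeddingBag` expects to split the flat tensor back into one bag per text.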
In PyTorch, we need to define the training function for our model. Given that, we will create two different functions for training: one that fits the model and one that trains it at the client. In TensorFlow (Keras) we can call the *fit* method and the model will just train, but here we have to write this fit method ourselves.

First we create the fit method as follows:

In [None]:
def fit_client_model(model, dataloader, *args, **kwargs):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()
    optimizer = kwargs['optimizer']
    criterion = kwargs['criterion']
    epoch = kwargs['epoch']
    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                    '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                                total_acc/total_count))
            total_acc, total_count = 0, 0
            start_time = time.time()

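The loop clips gradients with `torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)`, which rescales all gradients jointly when their global L2 norm exceeds `max_norm`. The sketch below, a hypothetical `clip_grad_norm` working on plain lists rather than the PyTorch function, is intended to mirror that default behaviour (`norm_type=2`, a small epsilon in the denominator, scale factor capped at 1).

```python
import math

def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients by one common factor so their global L2 norm is at most max_norm.

    Returns the original total norm and the (possibly) rescaled gradients.
    """
    total_norm = math.sqrt(sum(g * g for layer in grads for g in layer))
    coef = min(1.0, max_norm / (total_norm + eps))   # never scale gradients up
    return total_norm, [[g * coef for g in layer] for layer in grads]

total, clipped = clip_grad_norm([[3.0], [4.0]], max_norm=0.1)
print(total, clipped)
```

With `max_norm=0.1` and the large learning rate used in this tutorial (`learning_rate=5`), clipping keeps each SGD step bounded.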
And now the train method for our federated environment:

In [None]:
def train(client_model, client_data, *args, **kwargs):
    print("Training model at client.")
    model = client_model['model']
    optimizer = client_model['optimizer']  # created once in initialize_server_model
    criterion = client_model['criterion']
    epochs = kwargs['epochs']
    client_dataset = TextDataset(text=client_data.X_data, label=client_data.y_data)
    client_dataloader = DataLoader(client_dataset, batch_size=kwargs['batch_size'], shuffle=False, collate_fn=collate_batch)
    for epoch in range(epochs):
        epoch_start_time = time.time()
        fit_client_model(model=model, dataloader=client_dataloader,
                        optimizer=optimizer, criterion=criterion, epoch=epoch)
        print('-' * 59)
        print('| end of epoch {:3d} | time: {:5.2f}s | '.format(epoch,
                                            time.time() - epoch_start_time))
        print('-' * 59)
        

Now we will train the model on the client side. We will use the *map* function to tell the clients to train the model; to do so, we just need to call it from the clients pool.

In [None]:
clients.map(train, batch_size=512, epochs=10)

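Now that we have trained the model, we have to aggregate the weights. To do so, the clients send their weights to the aggregator, which performs the aggregation step. For the tutorial, we implement the FedAvg aggregation mechanism as a plain average of the client weights (FedAvg weights each client by its number of training examples, which gives the same result here because the iid split assigns essentially the same amount of data to every client). The aggregation is split into two steps: 1) collecting the weights from each client and 2) averaging them.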

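The two steps can be sketched without any framework. The hypothetical `fedavg` helper below averages the clients' parameters layer by layer; the optional `n_samples` argument weights each client by its number of training examples, as in the original FedAvg algorithm, while leaving it out gives the plain mean that `aggregate_weights` computes.

```python
def fedavg(client_weights, n_samples=None):
    """Layer-wise weighted average of client parameters (FedAvg-style sketch).

    client_weights: one entry per client, each a list of layers, each layer a flat list of floats.
    n_samples: optional number of training examples per client; None means equal weights.
    """
    if n_samples is None:
        n_samples = [1] * len(client_weights)
    total = sum(n_samples)
    averaged = []
    for layers in zip(*client_weights):          # the same layer from every client
        averaged.append([
            sum(n * v for n, v in zip(n_samples, values)) / total
            for values in zip(*layers)           # the same scalar from every client
        ])
    return averaged

weights_per_client = [
    [[1.0, 2.0], [0.0]],   # client A: two "layers"
    [[3.0, 6.0], [4.0]],   # client B
]
print(fedavg(weights_per_client))                    # [[2.0, 4.0], [2.0]]
print(fedavg(weights_per_client, n_samples=[3, 1]))  # [[1.5, 3.0], [1.0]]
```

Averaging per layer (rather than building one array from all clients) is what lets layers of different shapes be handled together.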
First, we select the aggregator, which in this case is the same as the server, because in the client server architecture, the server is also an aggregator.

In [None]:
aggregator = flex_pool.aggregators
aggregator.actor_ids

In [None]:
def collect_weights(client_model, aggregator_model, **kwargs):
    print("Collecting weights.")
    if 'weights' not in aggregator_model["server"]:
        aggregator_model["server"]['weights'] = []
    client_weights = [param.cpu().data.numpy() for param in client_model['model'].parameters()]
    aggregator_model["server"]['weights'].append(client_weights)

In [None]:
clients.map(collect_weights, aggregator)

In [None]:
def aggregate_weights(agg_model, *args):
    print("Aggregating weights")
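    # Average layer by layer: the layers have different shapes, so they cannot be stacked into a single array
    averaged_weights = [np.mean(np.stack(layer), axis=0) for layer in zip(*agg_model['weights'])]
    with torch.no_grad():
        for old, new in zip(agg_model['model'].parameters(), averaged_weights):
            old.copy_(torch.from_numpy(new))  # in-place copy keeps the parameter's device and dtype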
    agg_model["weights"] = []

In [None]:
aggregator.map(aggregate_weights)

Now it's the server's turn to update the weights of the clients' models with the aggregated ones, and then to evaluate the model.

In [None]:
def deploy_global_model_to_clients(server_model, clients_models, *args, **kwargs):
    print("Deploying the global model on the clients.")
    aggregated_weights = [param.cpu().data.numpy() for param in server_model['model'].parameters()]
    with torch.no_grad():
        for client_model in clients_models:
            for old, new in zip(clients_models[client_model]['model'].parameters(), aggregated_weights):
                old.copy_(torch.from_numpy(new))  # in-place copy keeps the parameter's device and dtype

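The weights are written into each client's existing parameter objects instead of building a new model. One reason this matters: each client's optimizer holds references to its model's parameters, so freshly created parameter objects would not be the ones it updates. The plain-Python analogy below (lists standing in for tensors, a hypothetical `optimizer_params` list standing in for the optimizer's `param_groups`) shows the difference between rebinding and writing in place.

```python
# Lists stand in for parameter tensors; optimizer_params stands in for the
# references an optimizer keeps (hypothetical names, for illustration only).
model_params = [[0.0, 0.0], [0.0]]
optimizer_params = list(model_params)   # new outer list, same inner list objects

aggregated = [[1.0, 2.0], [3.0]]

# Rebinding: new objects that the "optimizer" knows nothing about.
rebound = [list(layer) for layer in aggregated]
print(rebound[0] is optimizer_params[0])                 # False

# Writing in place: the same objects now hold the aggregated values.
for param, new in zip(model_params, aggregated):
    param[:] = new
print(model_params[0] is optimizer_params[0], optimizer_params)  # True [[1.0, 2.0], [3.0]]
```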
In [None]:
server.map(deploy_global_model_to_clients, clients)

And now we can evaluate the model with the test set that we prepared at the beginning of the notebook.

In [None]:
def evaluate_model_torch(model, dataloader, criterion):
    model.eval()
    total_acc, total_count = 0, 0
    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc/total_count

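`evaluate_model_torch` accumulates correct predictions and example counts over all batches before dividing, so it reports the accuracy over the whole dataset rather than the mean of per-batch accuracies (the two differ when batches have different sizes). A minimal stdlib sketch of that bookkeeping, with a hypothetical `argmax` helper that, like `torch.argmax` is documented to do, returns the first index on ties:

```python
def argmax(row):
    # Index of the largest score; max() keeps the first index on ties.
    return max(range(len(row)), key=row.__getitem__)

def evaluate_batches(batches):
    """Accuracy over all examples: sum correct predictions and counts across batches, divide once."""
    total_acc, total_count = 0, 0
    for logits, labels in batches:
        total_acc += sum(argmax(row) == y for row, y in zip(logits, labels))
        total_count += len(labels)
    return total_acc / total_count

batches = [
    ([[0.1, 2.0], [1.5, 0.2]], [1, 1]),   # 1 of 2 correct
    ([[0.0, 3.0]], [1]),                  # 1 of 1 correct
]
print(evaluate_batches(batches))  # 2/3, not the mean of per-batch accuracies (0.75)
```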
In [None]:
def evaluate_model(flex_model, data, *args, **kwargs):
    model = flex_model['model']
    criterion = flex_model['criterion']
    if data is not None:
        print("Evaluating model at client.")
        client_dataset = TextDataset(data.X_data, data.y_data)
        client_dataloader = DataLoader(client_dataset, batch_size=kwargs['batch_size'], shuffle=True, collate_fn=collate_batch)
        results_local = evaluate_model_torch(model, client_dataloader, criterion)
        print(f"Results at client on client's data: {results_local}")
    else:
        print("Evaluating model at server")
    test_dataset = TextDataset(kwargs['test_examples'], kwargs['test_labels'])
    test_dataloader = DataLoader(test_dataset, batch_size=kwargs['batch_size'], shuffle=True, collate_fn=collate_batch)
    results = evaluate_model_torch(model, test_dataloader, criterion)
    print(f"Results on test data: {results}")

In [None]:
server.map(evaluate_model, test_examples=test_examples, test_labels=test_labels, batch_size=8)

In [None]:
clients.map(evaluate_model, test_examples=test_examples, test_labels=test_labels, batch_size=8)

# Putting it all together

You have just trained a model for one round using FLEXible. Now you can put everything together in a function and iterate for multiple rounds.

In [None]:
def train_n_rounds(n_rounds, batch_size, epochs):
    pool = FlexPool.client_server_architecture(fed_dataset=flex_dataset, init_func=initialize_server_model,
                                                vocab_size=vocab_size, embed_dim=embed_dim,
                                                num_class=num_class,
                                                criterion=torch.nn.CrossEntropyLoss(),
                                                optimizer=torch.optim.SGD,
                                                learning_rate=5
                                                )
    pool.servers.map(deploy_model_to_clients, pool.clients)
    for i in range(n_rounds):
        print(f"\nRunning round: {i}\n")
        pool.clients.map(train, batch_size=batch_size, epochs=epochs)
        pool.clients.map(evaluate_model, test_examples=test_examples, test_labels=test_labels, batch_size=8)
        pool.clients.map(collect_weights, pool.aggregators)
        pool.aggregators.map(aggregate_weights)
        pool.servers.map(deploy_global_model_to_clients, pool.clients)
        pool.servers.map(evaluate_model, test_examples=test_examples, test_labels=test_labels, batch_size=8)

In [None]:
train_n_rounds(n_rounds=4, batch_size=512, epochs=10)

### END
Congratulations, now you know how to train a model using FLEXible for multiple rounds. Remember that it's important to first initialize the model on the server and deploy it to the clients, so you can run the rounds without problems!