# FLEXNLP tutorial: Training a Question Answering model using PyTorch and Hugging Face

FLEXNLP is an extension of the FLEXible library, developed to add specific features for Natural Language Processing (NLP). We offer tools to adapt your code easily to a federated environment. If you are not familiar with FLEXible, we recommend looking at its tutorials first, to understand how to convert your centralized code into federated code.

In this notebook, we show how to federate a Hugging Face model for Question Answering. We use some primitives from FLEXible, but you can create your own.

In [None]:
from copy import deepcopy
import torch
import torch.nn as nn
from datasets import load_dataset
from datasets import Dataset as HFDataset
from transformers import AutoTokenizer
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer
import collections
import numpy as np
import evaluate

In [None]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(device)

### Load the dataset

First we load the dataset. As there are no federated datasets for this task, we need to load a centralized dataset and federate it. In this tutorial we use the `squad` dataset from **Hugging Face Datasets**. This dataset is commonly used as a benchmark for question answering models, and it is compatible with FLEXible, as we show below. We use only 1% of the training split, just to show how to load the data and use the model. We then split this subset into train/test ourselves instead of using the dataset's own train/validation splits.

In [None]:
# Load 1% of the SQuAD training split
squad = load_dataset("squad", split="train[:1%]")
# Split 80% train, 20% test
squad = squad.train_test_split(test_size=0.2)
print(squad)

### Preprocess

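In order to use the dataset, we need to preprocess it into the input format the model expects. We have created two different functions to preprocess the data: one for the training examples and another for the test/validation examples. The training function also computes the start and end token positions of each answer, while the validation function keeps the offsets needed to map the predictions back to the context.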
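Both preprocessing functions tokenize each question/context pair with `max_length=384` and `stride=128`, and `return_overflowing_tokens=True` splits long contexts into several overlapping features; `overflow_to_sample_mapping` then records which example each feature came from. The following is a simplified, pure-Python sketch of that windowing idea only. The function `sliding_windows` is hypothetical (not part of Transformers), and the real tokenizer additionally reserves room in every window for the question and the special tokens.

```python
def sliding_windows(token_ids, max_len, stride):
    """Split token_ids into windows of at most max_len tokens.

    Hypothetical helper: consecutive windows share `stride` tokens, which is
    the meaning of `stride` when used with return_overflowing_tokens=True.
    """
    if max_len <= stride:
        raise ValueError("max_len must be larger than stride")
    step = max_len - stride  # how far each new window advances
    windows = []
    start = 0
    while True:
        windows.append(token_ids[start:start + max_len])
        if start + max_len >= len(token_ids):
            break  # this window already reaches the end of the sequence
        start += step
    return windows


context_tokens = list(range(10))  # ten toy "token ids"
for window in sliding_windows(context_tokens, max_len=4, stride=2):
    print(window)
```

The overlap means an answer cut by one window boundary is likely to appear whole in a neighbouring window.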

In [None]:
model_checkpoint = "distilbert-base-uncased"
#model_checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

max_length = 384
stride = 128


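def preprocess_training_examples_as_lists(examples, answers_examples):
    """
    Preprocess training examples that come as lists instead of as a
    Hugging Face Dataset.

    Args:
        examples (list[list]): List of examples, each one a list with the
            columns ['id', 'title', 'context', 'question'], in that order.
        answers_examples (sequence): Sequence whose element at index 1 is the
            list of answers (SQuAD-style dicts with 'text' and 'answer_start').

    Returns:
        datasets.Dataset: Tokenized features with start/end token positions.
    """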
    questions = [q[3].strip() for q in examples]
    contexts = [c[2] for c in examples]
    inputs = tokenizer(
        questions,
        contexts,
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    # Answers are SQuAD-style dicts: {'text': [...], 'answer_start': [...]}
    answers = list(answers_examples[1])
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return HFDataset.from_dict(inputs)

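def preprocess_validation_examples(examples):
    """Tokenize validation examples, keeping for each feature the id of its
    source example and the offsets of its context tokens (None elsewhere)."""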
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs

train_dataset = squad["train"]

test_dataset = squad["test"].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=squad["test"].column_names,
)
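The core of `preprocess_training_examples_as_lists` is turning a character-level answer span into start/end token indices through the offset mapping. The sketch below isolates that loop as a standalone function so it can be tried on hand-written features. `char_span_to_token_span` is a hypothetical name, and the toy `offsets`/`sequence_ids` are made up for illustration rather than produced by a real tokenizer.

```python
def char_span_to_token_span(offsets, sequence_ids, start_char, end_char):
    """Map a character span of the context to (start_token, end_token).

    Returns (0, 0) when the answer is not fully inside this window's context,
    the same convention used by preprocess_training_examples_as_lists.
    """
    # Context tokens are the ones with sequence id 1 (the question has id 0).
    idx = 0
    while sequence_ids[idx] != 1:
        idx += 1
    context_start = idx
    while idx < len(sequence_ids) and sequence_ids[idx] == 1:
        idx += 1
    context_end = idx - 1

    # Answer not fully contained in this window: label it as (0, 0).
    if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
        return (0, 0)

    # Last context token whose start offset is <= start_char.
    idx = context_start
    while idx <= context_end and offsets[idx][0] <= start_char:
        idx += 1
    start_token = idx - 1

    # First context token whose end offset is >= end_char.
    idx = context_end
    while idx >= context_start and offsets[idx][1] >= end_char:
        idx -= 1
    end_token = idx + 1
    return (start_token, end_token)


# Toy features for "[CLS] who [SEP] paris is big [SEP]" (hand-written)
context = "Paris is big"
sequence_ids = [None, 0, None, 1, 1, 1, None]
offsets = [(0, 0), (0, 3), (0, 0), (0, 5), (6, 8), (9, 12), (0, 0)]

start_char = context.index("is big")
end_char = start_char + len("is big")
print(char_span_to_token_span(offsets, sequence_ids, start_char, end_char))  # (4, 5)
```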


# From centralized data to federated data

First, we federate the dataset using the `FedDataDistribution` class, which has functions to load datasets from deep learning libraries such as PyTorch or TensorFlow. In this notebook we use Hugging Face with PyTorch, so we use the primitive functions from the PyTorch ecosystem. The data comes from the Hugging Face *datasets* library, which is why we use the function `from_config_with_huggingface_dataset`.

In [None]:
from flex.data import FedDatasetConfig, FedDataDistribution

config = FedDatasetConfig(seed=0)
config.n_clients = 2
config.replacement = False # ensure that clients do not share any data
config.client_names = ['client1', 'client2'] # Optional
flex_dataset = FedDataDistribution.from_config_with_huggingface_dataset(data=train_dataset, config=config,
                                                                        X_columns=['id', 'title', 'context', 'question'],
                                                                        label_columns=['answers']
                                                                        )
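With `replacement=False`, each training example goes to at most one client. As a rough illustration of what such a disjoint split means, here is a minimal sketch that shuffles the indices with a fixed seed and deals them into near-equal shards. `split_among_clients` is hypothetical: it is not FLEXible's implementation, which may choose shard sizes differently.

```python
import random


def split_among_clients(n_samples, client_names, seed=0):
    """Hypothetical disjoint split: every index is assigned to exactly one client."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # reproducible shuffle, like seed=0 above
    base, extra = divmod(n_samples, len(client_names))
    shards, start = {}, 0
    for i, name in enumerate(client_names):
        size = base + (1 if i < extra else 0)  # spread the remainder over the first clients
        shards[name] = sorted(indices[start:start + size])
        start += size
    return shards


shards = split_among_clients(10, ["client1", "client2"], seed=0)
print(shards)
```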

We may also want a FLEXible version of the test data, which we can build with the function `from_huggingface_dataset` of the `Dataset` class.

In [None]:
from flex.data import Dataset

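# Keep the Hugging Face `test_dataset` unchanged: `compute_metrics` needs its
# `example_id` and `offset_mapping` columns, which this FLEXible view does not include.
flex_test_dataset = Dataset.from_huggingface_dataset(test_dataset,
                                                     X_columns=['input_ids', 'attention_mask'])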

# Federate a model with FLEXible.

Once we've federated the dataset, it's time to create the FlexPool. The FlexPool class simulates a real federated learning scenario, so it is in charge of the communication between the actors.

In [None]:
from flex.model import FlexModel
from flex.pool import FlexPool

from flex.pool.decorators import init_server_model
from flex.pool.decorators import deploy_server_model

In this notebook we are going to simulate a client-server architecture, which we can easily build with the FlexPool function `client_server_pool`. This function needs a FlexDataset, which we have already prepared, and a function to initialize the server model, which we have to create.

The model we are going to use is `distilbert-base-uncased` for Question Answering, and we load it as follows.

In [None]:
@init_server_model
def build_server_model():
    server_flex_model = FlexModel()

    # Use the same checkpoint as the tokenizer defined above
    server_flex_model['model'] = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)
    # Required to store this for later stages of the FL training process
    server_flex_model['training_args'] = TrainingArguments(
        output_dir="my_awesome_qa_model",
        # No evaluation during local training: clients have no eval_dataset,
        # and the global model is evaluated separately on the server.
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=3,
        weight_decay=0.01,
    )

    return server_flex_model

Once we've defined the function to initialize the server model, we can create the FlexPool using the function `client_server_pool`.

In [None]:
flex_pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=build_server_model)

clients = flex_pool.clients
servers = flex_pool.servers
aggregators = flex_pool.aggregators

print(f"Number of nodes in the pool {len(flex_pool)}: {len(servers)} server plus {len(clients)} clients. The server is also an aggregator")

We can use the decorator `deploy_server_model` to create a custom function that deploys our server model, or we can use the primitive `deploy_server_model_pt` to deploy the server model to the clients.

In [None]:
from flex.pool import deploy_server_model, deploy_server_model_pt

@deploy_server_model
def copy_server_model_to_clients(server_flex_model: FlexModel):
    return deepcopy(server_flex_model)

In [None]:
servers.map(copy_server_model_to_clients, clients) # Using the function created with the decorator
# servers.map(deploy_server_model_pt, clients) # Using the primitive function
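`copy_server_model_to_clients` returns a `deepcopy` of the server model. A minimal illustration of why the copy matters, using a plain dict as a hypothetical stand-in for a FlexModel: if the function returned the server model itself, the client and the server would share one object, so local training would overwrite the server state.

```python
from copy import deepcopy

server_model = {"weights": [0.0, 0.0]}  # stand-in for a FlexModel (dict-like)

shared = server_model                   # no copy: both names point to one object
shared["weights"][0] = 1.0              # a "client" updates its model
print(server_model["weights"])          # [1.0, 0.0]: the server changed too

server_model = {"weights": [0.0, 0.0]}
client_model = deepcopy(server_model)   # independent copy, nested lists included
client_model["weights"][0] = 1.0
print(server_model["weights"])          # [0.0, 0.0]: the server is untouched
```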

The raw text has to be preprocessed on each client, so we do it inside the train function: each client converts its examples with `preprocess_training_examples_as_lists` and then trains its model with the `train` method of the `Trainer` class from the Transformers library.

In [None]:
# Train each client's model
def train(client_flex_model: FlexModel, client_data: Dataset):
    print("Training client")
    model = client_flex_model['model']
    training_args = client_flex_model['training_args']
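    # Each row of X_data follows the X_columns order: ['id', 'title', 'context', 'question']
    X_data = client_data.X_data.tolist()
    y_data = client_data.to_list()
    # Tokenize the raw examples and compute the answer start/end token positions
    client_train_dataset = preprocess_training_examples_as_lists(examples=X_data, answers_examples=y_data)
    trainer = Trainer(
        model=model,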
        args=training_args,
        train_dataset=client_train_dataset,
        tokenizer=tokenizer,
    )
    trainer.train()

clients.map(train)
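`preprocess_training_examples_as_lists` reads the question and the context positionally (`q[3]`, `c[2]`), so it silently depends on the order of `X_columns` passed to `from_config_with_huggingface_dataset`. A minimal sketch of name-based access instead, assuming the rows keep that column order; `rows_to_columns` is a hypothetical helper, not part of FLEXible.

```python
X_COLUMNS = ["id", "title", "context", "question"]  # same order as X_columns above


def rows_to_columns(rows, column_names):
    """Hypothetical helper: turn positional rows into a dict of named columns."""
    return {name: [row[i] for row in rows] for i, name in enumerate(column_names)}


rows = [["id0", "Paris", "Paris is big.", " What is big? "]]  # toy client rows
cols = rows_to_columns(rows, X_COLUMNS)
questions = [q.strip() for q in cols["question"]]  # instead of q[3].strip()
contexts = cols["context"]                         # instead of c[2]
print(questions, contexts)
```

Looking columns up by name keeps the preprocessing correct if the column list is ever reordered.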

After training, we have to collect the weights of the clients' models in order to update the global model. To do so, we use the primitive `collect_clients_weights_pt`.

In [None]:
from flex.pool import collect_clients_weights_pt

aggregators.map(collect_clients_weights_pt, clients)

Once the weights are collected, we aggregate them. In this notebook we use the FedAvg method, which is already implemented in FLEXible.

In [None]:
from flex.pool import fed_avg

aggregators.map(fed_avg)
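Conceptually, FedAvg-style aggregation averages each layer's weights element-wise across the clients. The sketch below shows this with NumPy arrays; it is an illustration only, and we do not claim it is how FLEXible's `fed_avg` is implemented. The original FedAvg weights every client by its number of samples, which the hypothetical `num_samples` parameter reproduces; without it, the sketch takes a plain mean.

```python
import numpy as np


def average_weights(client_weights, num_samples=None):
    """client_weights: one list of np.ndarray (one array per layer) per client."""
    n_clients = len(client_weights)
    if num_samples is None:
        coeffs = np.full(n_clients, 1.0 / n_clients)  # plain mean
    else:
        counts = np.asarray(num_samples, dtype=float)
        coeffs = counts / counts.sum()                # sample-weighted mean
    averaged = []
    for layer_group in zip(*client_weights):          # same layer from every client
        stacked = np.stack(layer_group)               # shape: (n_clients, *layer_shape)
        averaged.append(np.tensordot(coeffs, stacked, axes=1))
    return averaged


client_a = [np.array([1.0, 2.0]), np.zeros((2, 2))]
client_b = [np.array([3.0, 4.0]), np.full((2, 2), 2.0)]
print(average_weights([client_a, client_b]))
```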

The function `set_aggregated_weights_pt` sets the aggregated weights on the server model to update it.

In [None]:
from flex.pool import set_aggregated_weights_pt

aggregators.map(set_aggregated_weights_pt, servers)

### Evaluate the model

Now it's time to evaluate the global model. To do so, we create a function using the decorator `evaluate_server_model`.

In question answering the raw predictions have to be postprocessed, so we have created the function `compute_metrics`, which maps the predicted start/end logits back to text and returns the performance of the model. We use the `Trainer` class here too: we create a `Trainer` instance from the server's FlexModel and call its `predict` method.

In [None]:
from tqdm import tqdm

n_best = 20  # number of top start/end logits considered per feature
max_answer_length = 30  # maximum answer length, in tokens
metric = evaluate.load("squad")

def compute_metrics(start_logits, end_logits, features, examples):
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)
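The span-selection step inside `compute_metrics` can be tried in isolation. This sketch, under the same rules, keeps the `n_best` highest start and end logits, skips spans that fall outside the context (offset `None`) or are too long or reversed, and returns the pair with the highest summed logit. `best_span` and the toy logits are hypothetical, made up for illustration.

```python
import numpy as np


def best_span(start_logits, end_logits, offsets, n_best=20, max_answer_length=30):
    """Return ((start, end), score) of the best valid span, or (None, -inf)."""
    start_indexes = np.argsort(start_logits)[-1:-n_best - 1:-1]  # highest first
    end_indexes = np.argsort(end_logits)[-1:-n_best - 1:-1]
    best, best_score = None, -np.inf
    for s in start_indexes:
        for e in end_indexes:
            if offsets[s] is None or offsets[e] is None:
                continue  # token not in the context
            if e < s or e - s + 1 > max_answer_length:
                continue  # reversed or too long
            score = start_logits[s] + end_logits[e]
            if score > best_score:
                best, best_score = (int(s), int(e)), score
    return best, best_score


context = "Paris is big"
offsets = [None, None, (0, 5), (6, 8), (9, 12), None]  # None = question/special token
start_logits = np.array([5.0, 0.1, 0.2, 3.0, 0.5, 4.0])
end_logits = np.array([5.0, 0.1, 0.3, 0.2, 2.0, 6.0])

(s, e), score = best_span(start_logits, end_logits, offsets)
print(context[offsets[s][0]:offsets[e][1]], score)  # is big 5.0
```

The highest raw logits (positions 0 and 5) are discarded because they point outside the context, which is exactly why the offsets of non-context tokens are set to `None` during validation preprocessing.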

In [None]:
from flex.pool import evaluate_server_model


@evaluate_server_model
def evaluate_global_model(server_flex_model: FlexModel, test_data=None):
    model = server_flex_model["model"]
    training_args = server_flex_model["training_args"]
    # The Trainer is only used for inference here, so no train_dataset is needed
    trainer = Trainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
    )
    predictions, _, _ = trainer.predict(test_data)
    start_logits, end_logits = predictions
    print("Server metrics:", compute_metrics(start_logits, end_logits, test_data, squad["test"]))

In [None]:
servers.map(evaluate_global_model, test_data=test_dataset)
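The `squad` metric loaded with `evaluate` reports exact match and F1, both on a 0 to 100 scale. To make those numbers easier to read, here is a simplified re-implementation of the idea (not the library's code): answers are normalized (lowercase, no punctuation, no articles, single spaces), exact match compares the normalized strings, F1 measures token overlap, and each prediction is scored against its best-matching gold answer. All function names are hypothetical.

```python
import collections
import re
import string


def normalize_answer(text):
    text = text.lower()
    text = "".join(ch for ch in text if ch not in set(string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)  # drop English articles
    return " ".join(text.split())                 # collapse whitespace


def exact_match(prediction, truth):
    return float(normalize_answer(prediction) == normalize_answer(truth))


def f1_score(prediction, truth):
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(truth).split()
    common = collections.Counter(pred_tokens) & collections.Counter(truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


def max_over_answers(metric_fn, prediction, gold_answers):
    return max(metric_fn(prediction, gold) for gold in gold_answers)


print(exact_match("The Eiffel Tower!", "eiffel tower"))  # 1.0
print(f1_score("Eiffel", "Eiffel Tower"))               # partial credit
```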

### Run the federated learning experiment for a few rounds

Now, we can summarize the steps provided above and run the federated experiment for multiple rounds:

In [None]:
def train_n_rounds(n_rounds, clients_per_round=2):  
    pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=build_server_model)
    for i in range(n_rounds):
        print(f"\nRunning round: {i+1} of {n_rounds}")
        selected_clients_pool = pool.clients.select(clients_per_round)
        selected_clients = selected_clients_pool.clients
        print(f"Selected clients for this round: {len(selected_clients)}")
        # Deploy the server model to the selected clients
        pool.servers.map(deploy_server_model_pt, selected_clients)
        # Each selected client trains her model
        selected_clients.map(train)
        # The aggregator collects weights from the selected clients and aggregates them
        pool.aggregators.map(collect_clients_weights_pt, selected_clients)
        pool.aggregators.map(fed_avg)
        # The aggregator sends its aggregated weights to the server
        pool.aggregators.map(set_aggregated_weights_pt, pool.servers)
        pool.servers.map(evaluate_global_model, test_data=test_dataset)

In [None]:
train_n_rounds(5)
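To see the structure of a round without downloading any model, the same deploy, train, collect, aggregate and set steps of `train_n_rounds` can be sketched with a toy linear-regression "model" in NumPy. Everything here (`local_train`, `run_rounds`, the synthetic client data) is hypothetical and stands in for the Transformers model and the FLEXible primitives.

```python
import numpy as np

rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])

# Two clients, each with its own private data (synthetic toy regression task)
clients = []
for _ in range(2):
    X = rng.normal(size=(50, 2))
    y = X @ true_w + 0.01 * rng.normal(size=50)
    clients.append((X, y))


def local_train(w, X, y, lr=0.1, steps=20):
    w = w.copy()  # the client trains its own copy of the server weights
    for _ in range(steps):
        grad = 2 * X.T @ (X @ w - y) / len(y)  # gradient of the mean squared error
        w -= lr * grad
    return w


def run_rounds(n_rounds, clients_per_round=2, seed=0):
    select_rng = np.random.default_rng(seed)
    server_w = np.zeros(2)
    for r in range(n_rounds):
        chosen = select_rng.choice(len(clients), size=clients_per_round, replace=False)
        client_ws = [local_train(server_w, *clients[i]) for i in chosen]  # deploy + train
        server_w = np.mean(client_ws, axis=0)  # collect + average + set on the server
        mse = np.mean([np.mean((X @ server_w - y) ** 2) for X, y in clients])
        print(f"round {r + 1}: mse={mse:.4f}")  # evaluate the global model
    return server_w


final_w = run_rounds(5)
```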

# End

Congratulations, you have just trained a Question Answering model using the FLEXNLP library from the FLEXible environment.