# FLEXNLP tutorial: Training a Semantic Similarity/Semantic Search model with Sentence Transformers

FLEXNLP is an extension of the FLEXible library, developed to add specific features for Natural Language Processing (NLP). We offer tools to easily adapt your code to a federated environment. If you are not familiar with FLEXible, we recommend first looking at its tutorials to understand how to convert centralized code into federated code.

In this notebook, we show how to federate a Sentence Transformers model. We use some primitives from FLEXible, but you can also create your own.

```python
from copy import deepcopy
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, models
from sentence_transformers import losses
from sentence_transformers.evaluation import TripletEvaluator
import torch
```

```python
# Pick the best available accelerator: CUDA, then Apple MPS, then CPU
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(device)
```

### Load the dataset

First, we load the dataset. Since there are no federated datasets for this task, we load a centralized dataset and federate it. In this tutorial we use the `embedding-data/QQP_triplets` dataset from **Hugging Face Datasets**, keeping only the first 1% of its training split. We split this data into train and test sets so that we can also evaluate the model on the server side.

```python
# Load the dataset
dataset_id = "embedding-data/QQP_triplets"
# dataset_id = "embedding-data/sentence-compression"

# Keep only 1% of the training split to speed up the tutorial, then hold out 10% of it for testing
data = load_dataset(dataset_id, split='train[:1%]').train_test_split(test_size=0.1)
dataset, test_dataset = data['train'], data['test']
print(f"- Our training subset of {dataset_id} has {dataset.num_rows} examples.")
print(f"- Each example is a {type(dataset[0])} with a {type(dataset[0]['set'])} as value.")
print(f"- Examples look like this: {dataset[0]}")
```

## 1) From centralized data to federated data

We federate the dataset using the `FedDataDistribution` class, which provides functions to load datasets from deep learning libraries such as PyTorch, TensorFlow, or Hugging Face. Since our text data comes from Hugging Face Datasets, we use `from_config_with_huggingface_dataset`, which takes a `FedDatasetConfig` describing how the data is split across clients.

```python
from flex.data import FedDatasetConfig, FedDataDistribution

config = FedDatasetConfig(seed=0)
config.n_clients = 2
config.replacement = False  # ensure that clients do not share any data
config.client_names = ['client1', 'client2']  # Optional
# The triplets live in the 'set' column; there is no separate label, so the same column is passed as label
flex_dataset = FedDataDistribution.from_config_with_huggingface_dataset(
    data=dataset,
    config=config,
    X_columns=['set'],
    label_columns=['set'],
)
```
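Setting `config.replacement = False` means the clients receive disjoint shards of the data. The following NumPy sketch illustrates that idea with a hypothetical helper, `split_without_replacement`; it is not FLEXible's implementation, which may assign shard sizes differently.

```python
import numpy as np

def split_without_replacement(n_samples, n_clients, seed=0):
    """Hypothetical helper: shuffle the indices once and cut them into disjoint shards."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(n_samples)
    # np.array_split tolerates n_samples not divisible by n_clients
    return [shard.tolist() for shard in np.array_split(indices, n_clients)]

client1, client2 = split_without_replacement(n_samples=10, n_clients=2)
print(len(client1), len(client2))   # 5 5
print(set(client1) & set(client2))  # set()
```

Every sample ends up in exactly one shard, which is what "clients do not share any data" means.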

## 2) Federate a model with FLEXible

Once the dataset is federated, it's time to create the FlexPool. The FlexPool class simulates a real federated learning scenario, so it is in charge of the communication between actors (clients, servers and aggregators).

```python
from flex.model import FlexModel
from flex.pool import FlexPool

from flex.pool.decorators import init_server_model
```

In this notebook we simulate a client-server architecture, which we can easily build with the FlexPool function `client_server_pool`. This function needs a federated dataset, which we have already prepared, and a function to initialize the server model, which we have to write.

The model we use is `distilroberta-base`, loaded through the SentenceTransformers library and followed by a pooling layer that turns its token embeddings into a single sentence embedding.

```python
@init_server_model
def build_server_model():
    server_flex_model = FlexModel()
    # Transformer encoder that produces one embedding per token
    word_embedding_model = models.Transformer('distilroberta-base')
    # Pooling layer that turns the token embeddings into a fixed-size sentence embedding
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    server_flex_model['model'] = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    return server_flex_model
```
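`models.Pooling` called only with the embedding dimension uses, to our understanding, mean pooling by default: the sentence embedding is the average of the token embeddings, ignoring padding positions. A minimal NumPy sketch of that computation, with a hypothetical `mean_pooling` helper and made-up toy vectors:

```python
import numpy as np

def mean_pooling(token_embeddings, attention_mask):
    """Average the token vectors of each sentence, ignoring padding (mask == 0)."""
    mask = attention_mask[..., None].astype(token_embeddings.dtype)  # (batch, seq, 1)
    summed = (token_embeddings * mask).sum(axis=1)                  # (batch, dim)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                  # avoid division by zero
    return summed / counts

# One sentence, three token positions (the last one is padding), 2-dimensional embeddings
tokens = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])
print(mean_pooling(tokens, mask))  # [[2. 3.]]
```

The padded position does not affect the result, so sentences of different lengths in the same batch get comparable embeddings.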

Once we have defined the function to initialize the server model, we can create the FlexPool with `FlexPool.client_server_pool`.

```python
flex_pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=build_server_model)

clients = flex_pool.clients
servers = flex_pool.servers
aggregators = flex_pool.aggregators

print(f"Number of nodes in the pool {len(flex_pool)}: {len(servers)} server plus {len(clients)} clients. The server is also an aggregator")
```

We can use the decorator `deploy_server_model` to create a custom function that deploys our server model, or we can use the primitive `deploy_server_model_pt` to deploy the server model to the clients.

```python
from flex.pool import deploy_server_model, deploy_server_model_pt

@deploy_server_model
def copy_server_model_to_clients(server_flex_model: FlexModel):
    return deepcopy(server_flex_model)
```

```python
servers.map(copy_server_model_to_clients, clients)  # Using the function created with the decorator
# servers.map(deploy_server_model_pt, clients)  # Using the primitive function
```
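Deep-copying matters because each client must train its own, independent copy of the server model. A small illustration, with a plain dictionary standing in (hypothetically) for a `FlexModel`:

```python
from copy import deepcopy

# Hypothetical stand-in for a FlexModel: a dict holding mutable "weights"
server_model = {"model": {"layer.weight": [0.0, 0.0]}}

client_model = deepcopy(server_model)            # independent copy, as in copy_server_model_to_clients
client_model["model"]["layer.weight"][0] = 1.0   # a "local update" on the client
print(server_model["model"]["layer.weight"])     # [0.0, 0.0]
print(client_model["model"]["layer.weight"])     # [1.0, 0.0]

shallow = dict(server_model)                     # a shallow copy shares the nested objects
shallow["model"]["layer.weight"][1] = 5.0
print(server_model["model"]["layer.weight"])     # [0.0, 5.0]
```

With a shallow copy, a client's local training would silently overwrite the server's parameters.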

We need to prepare the data for the model. FLEXNLP provides an adapter for triplet datasets, `ss_triplet_input_adapter`, so once the data is loaded into FLEXible we can use it directly. Each client keeps 75% of its data for training; the client-side evaluator is commented out, and the model is evaluated on the server side instead.

```python
from flexnlp.utils.adapters import ss_triplet_input_adapter

def train(client_flex_model: FlexModel, client_data):
    print("Training client")
    model = client_flex_model['model']
    # Encode two sample sentences to compare the embeddings before and after training
    sentences = ['This is an example sentence', 'Each sentence is converted']
    encodings = model.encode(sentences)
    print(f"Old encodings: {encodings}")
    # Use 75% of the client's triplets for training and keep the rest as a local dev set
    X_data = client_data.X_data.tolist()
    tam_train = int(len(X_data) * 0.75)
    X_data, X_test = X_data[:tam_train], X_data[tam_train:]
    train_dataloader, dev_examples = ss_triplet_input_adapter(X_data, X_test)
    train_loss = losses.TripletLoss(model=model)
    # evaluator = TripletEvaluator.from_input_examples(dev_examples)
    warmup_steps = int(len(train_dataloader) * 1 * 0.1)  # 10% of the training steps (1 epoch)
    model.fit(train_objectives=[(train_dataloader, train_loss)],
        epochs=1,
        warmup_steps=warmup_steps,
        # evaluator=evaluator,
        evaluation_steps=1000,
    )
    # model.evaluate(evaluator, 'model_evaluation')
    encodings = model.encode(sentences)
    print(f"New encodings: {encodings}")
```
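For intuition about what `losses.TripletLoss` optimizes: for an anchor a, a positive p and a negative n, it penalizes max(d(a, p) - d(a, n) + margin, 0), so training pushes the positive closer to the anchor than the negative by at least the margin. To our knowledge the library defaults to Euclidean distance and a margin of 5; check the SentenceTransformers documentation for your version. A NumPy sketch with toy embeddings (the `triplet_loss` helper is hypothetical):

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=5.0):
    """Mean of max(||a - p|| - ||a - n|| + margin, 0) over a batch of triplets."""
    d_pos = np.linalg.norm(anchor - positive, axis=1)
    d_neg = np.linalg.norm(anchor - negative, axis=1)
    return np.maximum(d_pos - d_neg + margin, 0.0).mean()

anchor = np.array([[0.0, 0.0], [0.0, 0.0]])
positive = np.array([[3.0, 4.0], [1.0, 0.0]])   # distances to the anchors: 5 and 1
negative = np.array([[3.0, 0.0], [0.0, 10.0]])  # distances to the anchors: 3 and 10
print(triplet_loss(anchor, positive, negative))  # 3.5
```

Only the first triplet contributes (5 - 3 + 5 = 7); the second already satisfies the margin (1 - 10 + 5 < 0), so the batch mean is 7 / 2.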

```python
clients.map(train)
```

After training, the aggregator has to collect the weights of the client models before they can be aggregated into the global model. To do so, we use the primitive `collect_clients_weights_pt`.

```python
from flex.pool import collect_clients_weights_pt

aggregators.map(collect_clients_weights_pt, clients)
```

Once the weights are collected, we aggregate them with the weighted FedAvg method already implemented in FLEXible. Here we draw the client weights randomly from a Dirichlet distribution, so they are non-negative and sum to 1, but feel free to use any strategy for assigning weights to the clients.

```python
import numpy as np
from flex.pool import weighted_fed_avg

# One random weight per client; Dirichlet samples are non-negative and sum to 1
weights_by_clients = np.random.dirichlet(np.ones(config.n_clients), size=1)[0]

aggregators.map(weighted_fed_avg, ponderation=weights_by_clients)
```
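Conceptually, weighted FedAvg replaces each layer of the global model with a convex combination of the clients' corresponding layers. The sketch below assumes that formulation; the `weighted_average` helper and the toy weights are hypothetical and do not reproduce FLEXible's `weighted_fed_avg` source.

```python
import numpy as np

def weighted_average(client_weights, ponderation):
    """Layer-wise weighted average of per-client lists of weight arrays."""
    ponderation = np.asarray(ponderation, dtype=float)
    ponderation = ponderation / ponderation.sum()  # normalize defensively
    n_layers = len(client_weights[0])
    return [
        sum(w * client[layer] for w, client in zip(ponderation, client_weights))
        for layer in range(n_layers)
    ]

# Two clients, each with two "layers": a 2x2 matrix and a bias vector
client_a = [np.ones((2, 2)), np.array([0.0, 0.0])]
client_b = [np.zeros((2, 2)), np.array([4.0, 8.0])]

aggregated = weighted_average([client_a, client_b], ponderation=[0.75, 0.25])
print(aggregated[0])  # every entry is 0.75
print(aggregated[1])  # [1. 2.]
```

A client with a larger weight pulls the global model closer to its own parameters; with equal weights this reduces to a plain average.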

The function `set_aggregated_weights_pt` loads the aggregated weights into the server model to update it.

```python
from flex.pool import set_aggregated_weights_pt

aggregators.map(set_aggregated_weights_pt, servers)
```

Now it's time to evaluate the global model. To do so, we write a function using the `evaluate_server_model` decorator. We use the adapter function again, this time to build the test examples. The **SentenceTransformers** library provides evaluators for different kinds of datasets, and `ss_triplet_input_adapter` lets us use `TripletEvaluator` on the selected data. The results are saved as a CSV file in the folder passed to `model.evaluate`, here *server_evaluation*.

```python
from flex.pool import evaluate_server_model
from flex.data import Dataset

# Wrap the held-out Hugging Face split as a FLEXible Dataset
test_dataset = Dataset.from_huggingface_dataset(test_dataset, X_columns=['set'])

@evaluate_server_model
def evaluate_global_model(server_flex_model: FlexModel, test_data=None):
    # Build the evaluation examples from the test data passed to the function
    _, X_test = ss_triplet_input_adapter(X_test_as_list=test_data.X_data.tolist(), train=False)
    model = server_flex_model["model"]
    evaluator = TripletEvaluator.from_input_examples(X_test)
    model.evaluate(evaluator, 'server_evaluation')
    print("Model evaluation saved to file.")
```

```python
servers.map(evaluate_global_model, test_data=test_dataset)
```
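`TripletEvaluator` reports, roughly, the fraction of triplets in which the anchor is closer to the positive than to the negative; it computes this for several distance functions. A NumPy sketch of the cosine-distance variant, using a hypothetical `triplet_accuracy` helper and toy vectors:

```python
import numpy as np

def cosine_distance(a, b):
    """Row-wise 1 - cosine similarity."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return 1.0 - (a * b).sum(axis=1)

def triplet_accuracy(anchor, positive, negative):
    """Fraction of triplets whose positive is closer to the anchor than the negative."""
    return float(np.mean(cosine_distance(anchor, positive) < cosine_distance(anchor, negative)))

anchor = np.array([[1.0, 0.0], [0.0, 1.0]])
positive = np.array([[0.9, 0.1], [1.0, 0.0]])
negative = np.array([[0.0, 1.0], [0.0, 2.0]])
print(triplet_accuracy(anchor, positive, negative))  # 0.5
```

The second triplet fails because its negative points in exactly the same direction as the anchor; cosine distance ignores vector length.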

### Run the federated learning experiment for a few rounds

Now, we can summarize the steps provided above and run the federated experiment for multiple rounds:

```python
def train_n_rounds(n_rounds, clients_per_round=2):
    pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=build_server_model)
    for i in range(n_rounds):
        print(f"\nRunning round: {i+1} of {n_rounds}")
        selected_clients_pool = pool.clients.select(clients_per_round)
        selected_clients = selected_clients_pool.clients
        print(f"Selected clients for this round: {len(selected_clients)}")
        # Deploy the server model to the selected clients
        pool.servers.map(deploy_server_model_pt, selected_clients)
        # Each selected client trains its model
        selected_clients.map(train)
        # The aggregator collects the weights from the selected clients and aggregates them,
        # using one random weight per selected client
        pool.aggregators.map(collect_clients_weights_pt, selected_clients)
        weights_by_clients = np.random.dirichlet(np.ones(len(selected_clients)), size=1)[0]
        pool.aggregators.map(weighted_fed_avg, ponderation=weights_by_clients)
        # The aggregator sends its aggregated weights to the server
        pool.aggregators.map(set_aggregated_weights_pt, pool.servers)
        # Evaluate the updated global model
        pool.servers.map(evaluate_global_model, test_data=test_dataset)
```

```python
# Train the model for n_rounds
train_n_rounds(5)
```
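To see why repeating the deploy, local training, collection, aggregation and update steps converges, here is a toy simulation with a single scalar parameter instead of a transformer. Everything in it is hypothetical: `local_train` takes a few gradient steps on a squared error, and clients are weighted by their number of samples, a common alternative to the random Dirichlet weights used above.

```python
import numpy as np

def local_train(theta, data, lr=0.1, steps=5):
    """A few gradient steps on the mean of (theta - x)^2 / 2 over the client's data."""
    for _ in range(steps):
        grad = theta - data.mean()
        theta = theta - lr * grad
    return theta

def run_rounds(client_data, n_rounds=50):
    server_theta = 0.0
    sizes = np.array([len(d) for d in client_data], dtype=float)
    ponderation = sizes / sizes.sum()  # weight clients by dataset size
    for _ in range(n_rounds):
        # deploy -> local training -> collect -> weighted average -> set on the server
        local = [local_train(server_theta, d) for d in client_data]
        server_theta = float(np.dot(ponderation, local))
    return server_theta

client_data = [np.array([1.0, 1.0, 1.0]), np.array([5.0])]
print(round(run_rounds(client_data), 3))  # 2.0
```

The global parameter converges to 2.0, the mean of all four samples, even though the clients' local optima are 1.0 and 5.0.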

## End

Congratulations, you've just trained a **SentenceTransformers** model using FLEXible!