In [None]:
from copy import deepcopy
import numpy as np

from datasets.load import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
import torchtext

from flexnlp.utils.collators import ClassificationCollator

In [None]:

# Pick the best available accelerator: CUDA first, then Apple's MPS backend, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(device)

Load the dataset

In [None]:
# Alternative: imdb_dataset = load_dataset('imdb', split=['train', 'test'])  # IMDB from the Hugging Face hub
# AG_NEWS from torchtext; each sample is unpacked below as a (label, text) pair.
train_dataset, test_dataset = torchtext.datasets.AG_NEWS()
unique_classes = {label for label, _text in train_dataset}
num_classes = len(unique_classes)

Preparation: the embeddings, the vocabulary, etc.

In [None]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import GloVe, FastText, vocab

In [None]:
embeddings_dim = 50  # Dimension of the embeddings
glove = GloVe(name='6B', dim=embeddings_dim)  # Load the GloVe 6B embeddings with `embeddings_dim` dimensions.
# fasttext = FastText(language='en')  # To use FastText instead of GloVe
# `glove.stoi` maps token -> row index, but `vocab` reads those values as frequencies and
# drops every token whose value is below `min_freq` (default 1). With the default, the
# token stored at row 0 would be dropped and every later token would end up one row away
# from its vector. `min_freq=0` keeps every token, in row order.
vocabulary = vocab(glove.stoi, min_freq=0)
# vocabulary_fasttext = vocab(fasttext.stoi, min_freq=0)  # To use FastText instead of GloVe
vocab_size = len(vocabulary)  # Get the vocabulary size
print(f"Total vocabulary size: {vocab_size}")
print(f"Shape of embeddings: {glove.vectors.shape}")

In [None]:
example = "This is an example sentence to test the tokenizer."
# Two tokenizers to compare on the same sentence: torchtext's "basic_english" and spaCy's.
tokenizer = get_tokenizer("basic_english")
spacy_tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
example_tokens, example_tokens_spacy = tokenizer(example), spacy_tokenizer(example)

In [None]:
# Peek at the first 10 tokens of the vocabulary (displayed as the cell output).
first_tokens = vocabulary.get_itos()[:10]
first_tokens

In [None]:
# Inspect the vocabulary before "<pad>" is inserted: the token at id 0 and the first 10 tokens.
print(f"Token at index 0: {vocabulary.get_itos()[0]}")
print(f"First 10 tokens: {vocabulary.get_itos()[0:10]}")

In [None]:
pad_token = "<pad>"
pad_index = 0
# Insert "<pad>" at index 0, which shifts every GloVe token up by one id.
# `insert_token` raises if the token already exists, so guard it to keep this cell re-runnable.
if pad_token not in vocabulary:
    vocabulary.insert_token(pad_token, pad_index)
# Tokens that are not in the vocabulary are mapped to the padding index.
vocabulary.set_default_index(pad_index)
pretrained_embeddings = glove.vectors
print(f"Len pretrained embeddings: {len(pretrained_embeddings)}")
# Prepend an all-zeros row for "<pad>" so that row i lines up with vocabulary id i
# (assuming `vocabulary` was built from `glove.stoi` in row order).
pretrained_embeddings = torch.cat((torch.zeros(1, pretrained_embeddings.shape[1]), pretrained_embeddings))
print(f"Len pretrained embeddings: {len(pretrained_embeddings)}")

In [None]:
# After inserting "<pad>", the first token listed should be the padding token.
print(f"First 10 tokens (index 0 should be '<pad>'): {vocabulary.get_itos()[0:10]}")

We can use the basic English tokenizer from torchtext, or the spaCy tokenizer if spaCy and its `en_core_web_sm` model are installed. Here we try both tokenizers on the same example sentence.

In [None]:
# Show both tokenizations side by side.
for tokenizer_label, tokens in (
    ("Basic English Tokenizer", example_tokens),
    ("Spacy Tokenizer", example_tokens_spacy),
):
    print(f"{tokenizer_label}: {tokens}")

Clients may optionally want to remove stopwords, although the embeddings probably contain vectors for most of them. Below we show several options, and the user must decide which one to use. In this notebook we use the first option (keeping every token), as it preserves the most information. Otherwise, we would use the last option (keeping stopwords but dropping words that are not in the vocabulary), which still keeps as much information as possible.

Later we will tokenize the clients' data, convert each token to its index (id) in the embedding matrix, and pad the sequences to a common length.

In [None]:
# Remove stopwords
from nltk.corpus import stopwords
stop_words = set(stopwords.words('english'))

# Lowercase once, so the stopword and vocabulary checks test the same form that is printed.
# The NLTK English stopword list is lowercase, so checking "This" instead of "this" would
# let capitalised stopwords slip through.
lowered_tokens = [word.lower() for word in example_tokens_spacy]

print(f"Example tokens tokenized: {lowered_tokens}")

print(f"Example tokens without stopwords: {[word for word in lowered_tokens if word not in stop_words]}")

print(f"Example tokens without stopwords and word in vocabulary: {[word for word in lowered_tokens if word not in stop_words and word in vocabulary]}")

print(f"Example tokens without removing stopwords and word in vocabulary: {[word for word in lowered_tokens if word in vocabulary]}")

# From centralized data to federated data

First we're going to federate the dataset using the FedDataDistribution class, which has functions to load datasets from deep learning libraries such as PyTorch or TensorFlow. In this notebook we are using PyTorch, so we use the functions from the PyTorch ecosystem; for text datasets, that is the function `from_config_with_torchtext_dataset`.

In [None]:
from flex.data import FedDatasetConfig, FedDataDistribution

# Split the AG_NEWS training data between two clients.
config = FedDatasetConfig(seed=0)
config.n_clients = 2
config.replacement = False  # sample without replacement so the clients never share data
config.client_names = ['client1', 'client2']  # optional
flex_dataset = FedDataDistribution.from_config_with_torchtext_dataset(
    data=train_dataset,
    config=config,
)

We also want the test data as a FLEXible dataset, so we use the function `from_torchtext_dataset` of the `Dataset` class.

In [None]:
from flex.data import Dataset

# Wrap the torchtext test split in a FLEXible Dataset. The isinstance guard keeps the cell
# idempotent: re-running it does not try to convert an already-converted dataset.
if not isinstance(test_dataset, Dataset):
    test_dataset = Dataset.from_torchtext_dataset(test_dataset)

# 2) Federate a model with FLEXible.

Once we've federated the dataset, it's time to create the FlexPool. The FlexPool class is the one that simulates the real-time scenario for federated learning, so it is in charge of the communications across actors. 

In [None]:
from flex.model import FlexModel
from flex.pool import FlexPool

from flex.pool.decorators import init_server_model
from flex.pool.decorators import deploy_server_model

In this notebook we are going to simulate a client-server architecture, which we can easily build using the FlexPool class and its function `client_server_pool`. This function needs a FlexDataset, which we have already prepared, and a function to initialize the server model, which we still have to create.

The model we are going to use is a simple GRU network: a frozen embedding layer initialized with GloVe, a single-layer GRU, and a linear output layer.

In [None]:

class GruNet(nn.Module):
    """Text classifier: frozen pretrained embeddings -> single-layer GRU -> linear head.

    Args:
        embeddings: 2-D tensor of pretrained vectors, one row per vocabulary id;
            row 0 is used as the padding vector.
        hidden_size: size of the GRU hidden state.
        num_classes: number of output classes.

    `forward` applies the linear head to the GRU's final hidden state, so the output
    has shape [1, batch_size, num_classes] (the leading 1 is the GRU layer axis).
    """

    def __init__(self, embeddings, hidden_size, num_classes):
        super().__init__()
        # Embedding layer initialised from the pretrained vectors and kept frozen.
        self.emb = nn.Embedding.from_pretrained(embeddings, freeze=True, padding_idx=0)
        # The GRU consumes vectors as wide as the embeddings.
        self.embedding_size = embeddings.shape[1]
        self.gru = nn.GRU(
            self.embedding_size, hidden_size, batch_first=True, num_layers=1
        )
        # Prediction layer on top of the last hidden state.
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        embedded = self.emb(x)                # [batch_size, seq_len, emb_dim]
        _, last_hidden = self.gru(embedded)   # [1, batch_size, hidden_size]
        return self.fc(last_hidden)           # [1, batch_size, num_classes]


@init_server_model
def build_server_model():
    """Create the server FlexModel: a GruNet classifier plus the settings clients reuse.

    The criterion, optimizer class and optimizer kwargs are stored alongside the
    model so later stages of the FL process (client training) can read them.
    """
    flex_model = FlexModel()

    flex_model['model'] = GruNet(
        embeddings=pretrained_embeddings,
        hidden_size=128,
        num_classes=num_classes,
    )
    flex_model["criterion"] = torch.nn.CrossEntropyLoss()
    flex_model["optimizer_func"] = torch.optim.SGD
    flex_model["optimizer_kwargs"] = {}

    return flex_model

Once we've defined the function to initialize the server model, we can create the FlexPool using the function `client_server_pool`.

In [None]:
flex_pool = FlexPool.client_server_pool(
    fed_dataset=flex_dataset,
    init_func=build_server_model,
)

# Handles to the three actor groups of the pool.
clients, servers, aggregators = flex_pool.clients, flex_pool.servers, flex_pool.aggregators

print(f"Number of nodes in the pool {len(flex_pool)}: {len(servers)} server plus {len(clients)} clients. The server is also an aggregator")

We can use the decorator `deploy_server_model` to create a custom function that deploys our server model, or we can use the primitive `deploy_server_model_pt` to deploy the server model to the clients.

In [None]:
from flex.pool import deploy_server_model, deploy_server_model_pt

@deploy_server_model
def copy_server_model_to_clients(server_flex_model: FlexModel):
    """Return an independent deep copy of the server FlexModel for a client."""
    client_copy = deepcopy(server_flex_model)
    return client_copy

In [None]:
# Send a copy of the server model to every client, using the decorated function defined above.
servers.map(copy_server_model_to_clients, clients) # Using the function created with the decorator
# servers.map(deploy_server_model_pt, clients) # Alternative: the built-in PyTorch primitive

Since the text has to be preprocessed and batched on each client, we do it inside the train function.

In [None]:
import re
import random

from tqdm import tqdm

from torch.nn.utils.rnn import pad_sequence

# Batch size (also used by the evaluation loader) and number of local epochs per round.
BATCH_SIZE, NUM_EPOCHS = 256, 10

def clean_str(string):
    """
    Tokenization/string cleaning.

    Adapted from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py.
    Unlike the original, parentheses and question marks become plain "(", ")" and "?"
    tokens. The original replacement strings contained backslash escapes that re.sub
    kept literally, so it produced backslash-prefixed tokens that are unlikely to
    have an embedding.

    Args:
        string: raw input text.

    Returns:
        The cleaned, lower-cased string with single spaces between tokens.
    """
    # Replace every character outside this whitelist with a space.
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    # Split English clitics off their host word ("it's" -> "it 's", "don't" -> "do n't").
    string = re.sub(r"\'s", " 's", string)
    string = re.sub(r"\'ve", " 've", string)
    string = re.sub(r"n\'t", " n't", string)
    string = re.sub(r"\'re", " 're", string)
    string = re.sub(r"\'d", " 'd", string)
    string = re.sub(r"\'ll", " 'll", string)
    # Surround punctuation with spaces so each mark becomes its own token.
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " ( ", string)
    string = re.sub(r"\)", " ) ", string)
    string = re.sub(r"\?", " ? ", string)
    # Collapse runs of whitespace introduced above.
    string = re.sub(r"\s{2,}", " ", string)

    return string.strip().lower()

def collate_batch(batch):
    """Turn a list of (text, label) samples into a padded id tensor and a label tensor.

    Each text is cleaned with `clean_str`, tokenized with `spacy_tokenizer`, mapped to
    vocabulary ids and wrapped with the "<pad>" id on both sides. Each label is
    converted to int and shifted down by one (presumably 1-based classes -> 0-based).

    Returns:
        (padded_ids, labels): ids padded with `pad_index` into a batch-first tensor,
        and the labels as an int64 tensor.
    """
    def preprocess_text(text):
        pad_id = vocabulary["<pad>"]
        token_ids = [vocabulary[token] for token in spacy_tokenizer(clean_str(text))]
        return [pad_id, *token_ids, pad_id]

    encoded_texts = []
    labels = []
    for raw_text, raw_label in batch:
        labels.append(int(raw_label) - 1)
        encoded_texts.append(torch.tensor(preprocess_text(raw_text)))
    padded_ids = pad_sequence(encoded_texts, padding_value=pad_index, batch_first=True)
    return padded_ids, torch.tensor(labels, dtype=torch.int64)

def batch_sampler_v2(batch_size, indices, pool_factor=100):
    """Yield batches of sample indices whose samples have similar lengths.

    The (index, length) pairs are shuffled and split into pools of
    `batch_size * pool_factor` consecutive pairs. Each pool is sorted by length, and
    the result is cut into batches of `batch_size` indices (the last one may be
    smaller). Shuffling first keeps batches random; sorting inside pools groups
    samples of similar length together.

    Args:
        batch_size: maximum number of sample indices per yielded batch.
        indices: sequence of (sample_index, length) pairs. It is not modified; a
            shuffled copy is used.
        pool_factor: how many batches' worth of samples are sorted together.

    Yields:
        Lists of sample indices.
    """
    # Shuffle a copy so the caller's (possibly cached) list is left untouched.
    shuffled = list(indices)
    random.shuffle(shuffled)
    pool_size = batch_size * pool_factor
    pooled = []
    # Create pools of indices with similar lengths.
    for start in range(0, len(shuffled), pool_size):
        pooled.extend(sorted(shuffled[start:start + pool_size], key=lambda pair: pair[1]))

    ordered_ids = [pair[0] for pair in pooled]

    # Yield the indices for each batch.
    for start in range(0, len(ordered_ids), batch_size):
        yield ordered_ids[start:start + batch_size]

def _cached_train_indices(client_flex_model: FlexModel, client_data: Dataset):
    """Return (sample_index, token_count) pairs for `client_data`, cached on the model.

    The pairs feed `batch_sampler_v2` and are stored under 'train_indices' so they are
    computed at most once per FlexModel.
    """
    if 'train_indices' not in client_flex_model:
        X_data, _ = client_data.to_list()
        # NOTE(review): `s[0]` assumes each X item is a sequence whose first element is
        # the text; if X items are plain strings this measures only the first character.
        client_flex_model['train_indices'] = [(i, len(tokenizer(s[0]))) for i, s in enumerate(X_data)]
    return client_flex_model['train_indices']


def train(client_flex_model: FlexModel, client_data: Dataset, use_length_batching: bool = False):
    """Train the client's model locally for NUM_EPOCHS epochs, printing per-epoch metrics.

    Args:
        client_flex_model: the client's FlexModel; reads 'model', 'criterion',
            'optimizer_func' and 'optimizer_kwargs'.
        client_data: the client's local FLEXible Dataset.
        use_length_batching: if True, batch samples of similar token length together with
            `batch_sampler_v2` (lengths cached under 'train_indices'). Otherwise use
            shuffled batches of BATCH_SIZE.
    """
    model = client_flex_model["model"]
    # Move the model to the device before building the optimizer (PyTorch recommendation).
    model = model.to(device)
    model = model.train()
    criterion = client_flex_model["criterion"]
    # lr=0.1 is the default; an "lr" entry in optimizer_kwargs overrides it instead of
    # raising a duplicate-keyword error.
    optimizer_kwargs = {"lr": 0.1, **client_flex_model["optimizer_kwargs"]}
    optimizer = client_flex_model['optimizer_func'](model.parameters(), **optimizer_kwargs)

    if use_length_batching:
        train_indices = _cached_train_indices(client_flex_model, client_data)
    else:
        client_dataloader = DataLoader(client_data, collate_fn=collate_batch, batch_size=BATCH_SIZE,
                                       shuffle=True)

    for _ in tqdm(range(NUM_EPOCHS)):
        if use_length_batching:
            # batch_sampler_v2 is a one-shot generator, so the DataLoader is rebuilt every epoch.
            client_dataloader = DataLoader(client_data, collate_fn=collate_batch,
                                           batch_sampler=batch_sampler_v2(BATCH_SIZE, train_indices))
        losses = []
        correct, seen = 0, 0
        for texts, labels in client_dataloader:
            optimizer.zero_grad()
            texts, labels = texts.to(device), labels.to(device)
            # Drop the leading layer axis of the [1, batch_size, num_classes] GruNet output.
            predicted_labels = model(texts).squeeze(dim=0)
            loss = criterion(predicted_labels, labels)
            if predicted_labels.isnan().any():
                # Diagnostic output to help track down numerical instability.
                print(f"Text in batch: {texts}")
                print(f"Predicted labels in batch: {predicted_labels}")
                print(f"Labels in batch: {labels}")
                print(f"Loss in batch: {loss}")
            loss.backward()
            # Clip the total gradient norm to 0.1 before each update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            losses.append(loss.item())
            correct += (predicted_labels.argmax(1) == labels).sum().item()
            seen += labels.shape[0]
        # A client without data yields no batches; report NaN instead of dividing by zero.
        epoch_loss = sum(losses) / len(losses) if losses else float("nan")
        epoch_acc = correct / seen if seen else float("nan")
        print(f"Accuracy after epoch: {epoch_acc}\t|\tLoss after epoch: {epoch_loss}")

In [None]:
# Run the local training loop (`train`) on every client's model and data.
clients.map(train)

After training, we have to collect the weights of the clients' models in order to update the global model. To do so, we use the primitive `collect_clients_weights_pt`.

In [None]:
from flex.pool import collect_clients_weights_pt

# The aggregator gathers the trained weights from every client.
aggregators.map(collect_clients_weights_pt, clients)

Once the weights are collected, we aggregate them. In this notebook we use the FedAvg method, which is already implemented in FLEXible.

In [None]:
from flex.pool import fed_avg

# Aggregate the collected client weights with FedAvg.
aggregators.map(fed_avg)

The function `set_aggregated_weights_pt` sends the aggregated weights to the server model to update it.

In [None]:
from flex.pool import set_aggregated_weights_pt

# Push the aggregated weights into the server's model.
aggregators.map(set_aggregated_weights_pt, servers)

Now it's time to evaluate the global model. To do so, we create a function using the decorator `evaluate_server_model`.

In [None]:

from flex.pool import evaluate_server_model

@evaluate_server_model
def evaluate_global_model(server_flex_model: FlexModel, test_data=None):
    """Evaluate the server's global model and print the test loss and accuracy.

    Args:
        server_flex_model: the server's FlexModel; reads 'model' and 'criterion'.
        test_data: FLEXible Dataset to evaluate on. Falls back to the notebook-level
            `test_dataset` when not given.

    Returns:
        (test_loss, test_acc): mean per-batch loss, and accuracy over all test samples
        (NaN for both when the test data yields no batches).
    """
    eval_data = test_data if test_data is not None else test_dataset
    model = server_flex_model["model"]
    model.eval()
    model = model.to(device)
    criterion = server_flex_model['criterion']
    # Group samples of similar token length into the same batch to reduce padding.
    X_data, _ = eval_data.to_list()
    # NOTE(review): `s[0]` assumes each X item is a sequence whose first element is the
    # text; if X items are plain strings this measures only the first character.
    test_indices = [(i, len(tokenizer(s[0]))) for i, s in enumerate(X_data)]
    test_dataloader = DataLoader(eval_data, collate_fn=collate_batch,
                                 batch_sampler=batch_sampler_v2(BATCH_SIZE, test_indices))
    losses = []
    correct, total_count = 0, 0
    with torch.no_grad():
        for data, target in test_dataloader:
            data, target = data.to(device), target.to(device)
            # Drop the leading layer axis of the [1, batch_size, num_classes] GruNet output.
            output = model(data).squeeze(dim=0)
            losses.append(criterion(output, target).item())
            correct += (output.argmax(1) == target).sum().item()
            # Count every sample exactly once.
            total_count += target.shape[0]

    test_loss = sum(losses) / len(losses) if losses else float("nan")
    test_acc = correct / total_count if total_count else float("nan")
    print(f"test loss: {test_loss}")
    print(f"test acc: {test_acc}")
    return test_loss, test_acc

In [None]:
# Evaluate the updated global model on the server side.
metrics = servers.map(evaluate_global_model, test_data=test_dataset)

In [None]:
# Result for the first server: the (test_loss, test_acc) pair returned by evaluate_global_model.
metrics[0]

### Run the federated learning experiment for a few rounds

Now, we can summarize the steps provided above and run the federated experiment for multiple rounds:

In [None]:
def train_n_rounds(n_rounds, clients_per_round=2):
    """Run `n_rounds` of federated training on a freshly created client-server pool.

    Each round selects clients, deploys the server model to them, trains them locally,
    collects and FedAvg-aggregates their weights, sends the result to the server and
    evaluates it on the test set.

    Args:
        n_rounds: number of federated rounds to run.
        clients_per_round: number of clients selected in each round.

    Returns:
        A list with the server's (test_loss, test_acc) after each round.
    """
    pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=build_server_model)
    history = []
    for i in range(n_rounds):
        print(f"\nRunning round: {i+1} of {n_rounds}")
        selected_clients_pool = pool.clients.select(clients_per_round)
        selected_clients = selected_clients_pool.clients
        print(f"Selected clients for this round: {len(selected_clients)}")
        # Deploy the server model to the selected clients
        pool.servers.map(deploy_server_model_pt, selected_clients)
        # Each selected client trains its model
        selected_clients.map(train)
        # The aggregator collects weights from the selected clients and aggregates them
        pool.aggregators.map(collect_clients_weights_pt, selected_clients)
        pool.aggregators.map(fed_avg)
        # The aggregator sends its aggregated weights to the server
        pool.aggregators.map(set_aggregated_weights_pt, pool.servers)
        # Evaluate on the FLEXible test set built earlier (`test_imdb_dataset` was never defined).
        metrics = pool.servers.map(evaluate_global_model, test_data=test_dataset)
        loss, acc = metrics[0]
        history.append((loss, acc))
        print(f"Server: Test acc: {acc:.4f}, test loss: {loss:.4f}")
    return history

In [None]:
# Uncomment to run the multi-round experiment; every round trains each selected client for NUM_EPOCHS epochs.
# train_n_rounds(5)