In [None]:
!pip install -q datasets
!pip install -q transformers
!pip install -q evaluate
!pip install -q protobuf==4.25.3
!pip install -q flwr["simulation"]


In [None]:
import warnings
from transformers import logging
import os

# Silence warning and logging noise so the notebook output stays readable.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Only let error-level messages through the transformers logger.
logging.set_verbosity(logging.ERROR)
# Presumably hides TensorFlow's native INFO/WARNING logs — confirm against the TF docs.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Blanket filter: ignore warnings of every remaining category as well.
warnings.simplefilter('ignore')


In [None]:
import torch

# Device that the model and every batch are moved to (CPU only in this notebook).
DEVICE = torch.device("cpu")
CHECKPOINT = "google-bert/bert-base-uncased"  # transformer model checkpoint
NUM_CLIENTS = 1  # number of clients handed to the Flower simulation
NUM_ROUNDS = 3  # number of rounds handed to the Flower ServerConfig

### Data Handler

In [None]:
import random
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding
from torch.utils.data import DataLoader

def load_data(num_samples=20, batch_size=32):
    """Load a small random subset of IMDB and wrap it in data loaders.

    The dataset is shuffled with a fixed seed, the "unsupervised" split is
    dropped, and ``num_samples`` examples are drawn at random from each of
    the "train" and "test" splits. The draw uses the global ``random`` state,
    which is not seeded here, so the selected examples differ between runs.

    Only the selected examples are tokenized. Tokenization is applied per
    example without padding, so the result is the same as tokenizing the
    whole dataset first and selecting afterwards, at a fraction of the cost.

    Args:
        num_samples: Number of examples to keep per split.
        batch_size: Batch size used by both data loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader
        shuffles its data; the test loader does not. Batches are padded
        dynamically by ``DataCollatorWithPadding``.

    Raises:
        ValueError: If ``num_samples`` exceeds the size of a split
            (raised by ``random.sample``).
    """
    raw_datasets = load_dataset("imdb")
    raw_datasets = raw_datasets.shuffle(seed=2)

    # remove unnecessary data split
    del raw_datasets["unsupervised"]

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

    def tokenize(examples):
        # Truncate to the model's maximum length; padding is left to the collator.
        return tokenizer(examples["text"], truncation=True)

    # Subsample *before* tokenizing to reduce the computation cost.
    # "train" is sampled first, then "test", so the RNG is consumed in a fixed order.
    for split in ("train", "test"):
        population = random.sample(range(len(raw_datasets[split])), num_samples)
        raw_datasets[split] = raw_datasets[split].select(population)

    tokenized_datasets = raw_datasets.map(tokenize, batched=True)

    # Keep only model inputs and expose the target under the name "labels".
    tokenized_datasets = tokenized_datasets.remove_columns("text")
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    train_loader = DataLoader(tokenized_datasets["train"],
                              shuffle=True,
                              batch_size=batch_size,
                              collate_fn=data_collator)

    test_loader = DataLoader(tokenized_datasets["test"],
                             shuffle=False,
                             batch_size=batch_size,
                             collate_fn=data_collator)

    return train_loader, test_loader


### Training and testing the model

In [None]:
from transformers import AdamW

def train(net, train_loader, epochs, lr=5e-5):
    """Train ``net`` in place on ``train_loader`` for ``epochs`` passes.

    Uses ``torch.optim.AdamW`` instead of the deprecated optimizer of the
    same name exported by ``transformers``.

    Args:
        net: Model whose forward pass accepts a batch as keyword arguments
            and returns an object with a ``loss`` attribute.
        train_loader: Iterable of batches; each batch is a mapping from
            argument name to tensor.
        epochs: Number of full passes over ``train_loader``.
        lr: Learning rate for the optimizer.

    Returns:
        None. The parameters of ``net`` are updated in place and the model
        is left in training mode.
    """
    # weight_decay/eps are intended to mirror the defaults of the deprecated
    # transformers.AdamW (no weight decay, eps=1e-6) — confirm if exact
    # numerical parity with the old optimizer matters.
    optimizer = torch.optim.AdamW(net.parameters(), lr=lr, weight_decay=0.0, eps=1e-6)
    net.train()

    for _ in range(epochs):
        for batch in train_loader:
            # Move every tensor of the batch to the configured device.
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            outputs = net(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            # Clear gradients so they do not accumulate into the next step.
            optimizer.zero_grad()

In [None]:
from evaluate import load as load_metric

def test(net, test_loader):
    """Evaluate ``net`` on ``test_loader`` and return ``(loss, accuracy)``.

    Args:
        net: Model whose forward pass accepts a batch as keyword arguments
            and returns an object with ``loss`` and ``logits`` attributes.
        test_loader: Iterable of batches; each batch is a mapping that
            includes a ``"labels"`` tensor.

    Returns:
        A ``(loss, accuracy)`` tuple: the average loss per example and the
        accuracy computed by the ``evaluate`` "accuracy" metric.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no examples.
    """
    metric = load_metric("accuracy")
    total_loss = 0.0
    num_examples = 0
    net.eval()

    for batch in test_loader:
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        with torch.no_grad():
            outputs = net(**batch)

        # Assumes outputs.loss is the mean loss over the batch (the usual
        # reduction for Hugging Face classification heads — confirm for other
        # models). Weight it by the batch size so that dividing by the example
        # count below yields a true per-example average.
        batch_size = batch["labels"].size(0)
        total_loss += outputs.loss.item() * batch_size
        num_examples += batch_size

        predictions = torch.argmax(outputs.logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])

    loss = total_loss / num_examples
    accuracy = metric.compute()["accuracy"]

    return loss, accuracy

### Creating the model itself

In [None]:
from transformers import AutoModelForSequenceClassification

# Load the CHECKPOINT weights into a sequence-classification model with
# two output labels and move it to DEVICE.
net = AutoModelForSequenceClassification.from_pretrained(CHECKPOINT, num_labels=2).to(DEVICE)

### Creating the IMDBClient

In [None]:
from collections import OrderedDict
import flwr as fl

class IMDBClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a model on given loaders.

    Attributes are plain references to the objects passed in; the model is
    updated in place by ``fit``.
    """

    def __init__(self, net, train_loader, test_loader):
        """Store the model and the loaders used for training and evaluation."""
        self.net = net
        self.train_loader = train_loader
        self.test_loader = test_loader

    @staticmethod
    def _num_examples(loader):
        """Return how many examples ``loader`` serves.

        Falls back to the number of batches when the loader exposes no
        dataset or the dataset has no length.
        """
        try:
            return len(loader.dataset)
        except (AttributeError, TypeError):
            return len(loader)

    def get_parameters(self, config):
        """Return every entry of the model's state dict as a NumPy array."""
        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

    def set_parameters(self, parameters):
        """Load ``parameters`` into the model.

        ``parameters`` must be in the same order as the keys of the model's
        state dict; loading is strict, so missing or extra entries raise.
        """
        params_dict = zip(self.net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train for one epoch starting from ``parameters``.

        Returns:
            The updated parameters, the number of training examples, and an
            empty metrics dict.
        """
        self.set_parameters(parameters)

        print("Training Started...")
        train(self.net, self.train_loader, epochs=1)

        print("Training Finished")
        # Report the number of examples (not batches): the count is the
        # weight used when results from several clients are averaged.
        num_examples = self._num_examples(self.train_loader)
        return self.get_parameters(config={}), num_examples, {}

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on the test loader.

        Returns:
            The loss, the number of test examples, and a metrics dict with
            ``"accuracy"`` and ``"loss"``.
        """
        self.set_parameters(parameters)
        loss, accuracy = test(self.net, self.test_loader)
        num_examples = self._num_examples(self.test_loader)
        return float(loss), num_examples, {"accuracy": float(accuracy), "loss": float(loss)}

  and should_run_async(code)


### Generating the clients

In [None]:
# Build the data loaders once at module level; client_fn reuses them for every client it creates.
train_loader, test_loader = load_data()

In [None]:
def client_fn(cid):
    """Create the Flower client for the simulated client ``cid``.

    ``cid`` is accepted but not used: every call wraps the same module-level
    ``net``, ``train_loader`` and ``test_loader``.
    """
    client = IMDBClient(net=net,
                        train_loader=train_loader,
                        test_loader=test_loader)
    return client

### Starting the simulation

In [None]:
def weighted_average(metrics):
    """Aggregate per-client metrics into example-weighted averages.

    Args:
        metrics: Sequence of ``(num_examples, metrics_dict)`` pairs, where
            each ``metrics_dict`` contains ``"accuracy"`` and ``"loss"``.

    Returns:
        A dict with the weighted ``"accuracy"`` and ``"loss"``. An empty dict
        is returned when there is nothing to average (no entries, or a total
        example count of zero) instead of raising ``ZeroDivisionError``.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        return {}

    # Weight each client's metric by the number of examples it was computed on.
    weighted_accuracy = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    weighted_loss = sum(num_examples * m["loss"] for num_examples, m in metrics)

    return {"accuracy": weighted_accuracy / total_examples,
            "loss": weighted_loss / total_examples}

In [None]:
# Federated averaging strategy. Both fractions are set to 1.0, which
# presumably selects every available client for training and evaluation in
# each round — confirm against the installed Flower version. Client
# evaluation metrics are combined with weighted_average.
strategy = fl.server.strategy.FedAvg(fraction_fit=1.0,
                                     fraction_evaluate=1.0,
                                     evaluate_metrics_aggregation_fn=weighted_average)

In [None]:
# Run the federated simulation: NUM_CLIENTS clients built by client_fn,
# NUM_ROUNDS rounds, coordinated by the FedAvg strategy defined above.
# Both the per-client resources and the Ray init arguments request 1 CPU and
# no GPU; "log_to_driver": False is passed through to Ray's initialisation.
fl.simulation.start_simulation(client_fn=client_fn,
                               num_clients=NUM_CLIENTS,
                               config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
                               strategy=strategy,
                               client_resources={"num_cpus": 1, "num_gpus": 0},
                               ray_init_args={"log_to_driver": False,
                                              "num_cpus": 1,
                                              "num_gpus": 0})

INFO flwr 2024-03-10 10:31:01,692 | app.py:178 | Starting Flower simulation, config: ServerConfig(num_rounds=3, round_timeout=None)
INFO:flwr:Starting Flower simulation, config: ServerConfig(num_rounds=3, round_timeout=None)
2024-03-10 10:31:08,379	INFO worker.py:1621 -- Started a local Ray instance.
INFO flwr 2024-03-10 10:31:11,335 | app.py:213 | Flower VCE: Ray initialized with resources: {'CPU': 1.0, 'memory': 7901215950.0, 'object_store_memory': 3950607974.0, 'node:172.28.0.12': 1.0, 'node:__internal_head__': 1.0}
INFO:flwr:Flower VCE: Ray initialized with resources: {'CPU': 1.0, 'memory': 7901215950.0, 'object_store_memory': 3950607974.0, 'node:172.28.0.12': 1.0, 'node:__internal_head__': 1.0}
INFO flwr 2024-03-10 10:31:11,345 | app.py:219 | Optimize your simulation with Flower VCE: https://flower.dev/docs/framework/how-to-run-simulations.html
INFO:flwr:Optimize your simulation with Flower VCE: https://flower.dev/docs/framework/how-to-run-simulations.html
INFO flwr 2024-03-10 10: