In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset

import flwr as fl
import torch
import math

  from .autonotebook import tqdm as notebook_tqdm
2023-12-03 16:57:00,812	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


## Data Loading & Preprocessing

In [2]:
model_checkpoint = "distilgpt2"
datasets = load_dataset('wikitext', 'wikitext-103-raw-v1')

# Federated data split: this client keeps one contiguous shard of the training split
# (validation/test splits are left untouched).
# NUM_PARTITIONS = 2 / PARTITION_ID = 0 reproduces the original "first half" selection.
# NOTE: output directory names later in this notebook hard-code a "_p1" suffix;
# update them too if PARTITION_ID changes.
NUM_PARTITIONS = 2
PARTITION_ID = 0
assert 0 <= PARTITION_ID < NUM_PARTITIONS, "PARTITION_ID must be in [0, NUM_PARTITIONS)"

n_train_rows = len(datasets["train"])
shard_start = n_train_rows * PARTITION_ID // NUM_PARTITIONS
shard_end = n_train_rows * (PARTITION_ID + 1) // NUM_PARTITIONS
datasets["train"] = datasets["train"].select(range(shard_start, shard_end))

In [3]:
# Tokenizer matching `model_checkpoint`; use_fast=True requests the fast implementation.
tokenizer = AutoTokenizer.from_pretrained(
    model_checkpoint,
    use_fast=True,
)


def tokenize_function(examples):
    """Tokenize a batch of raw text rows.

    Args:
        examples: Batch mapping; only its ``"text"`` column is read.

    Returns:
        Whatever the notebook-level ``tokenizer`` returns for that list of texts.
    """
    raw_texts = examples["text"]
    encoded = tokenizer(raw_texts)
    return encoded


# Tokenize every split with 4 worker processes; the raw "text" column is dropped
# because only the tokenizer outputs are needed downstream.
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=["text"],
)

In [4]:
# Length (in tokens) of each training chunk produced by `group_texts`.
# Alternative: use the tokenizer's own limit, `tokenizer.model_max_length`.
block_size = 128


def group_texts(examples, chunk_size=None):
    """Concatenate a batch of tokenized rows and re-split it into fixed-length chunks.

    Args:
        examples: Mapping from column name to a list of per-row token lists
            (e.g. the outputs of ``tokenize_function``). Must contain ``input_ids``.
        chunk_size: Chunk length in tokens. ``None`` (the default) falls back to the
            notebook-level ``block_size``.

    Returns:
        Dict with the same columns, each a list of ``chunk_size``-token lists, plus a
        ``labels`` column holding a copy of the ``input_ids`` chunk list. The tail of
        the concatenation that does not fill a whole chunk is dropped.

    Raises:
        ValueError: if the resolved chunk size is not positive.
    """
    from itertools import chain

    size = block_size if chunk_size is None else chunk_size
    if size <= 0:
        raise ValueError(f"chunk_size must be positive, got {size}")

    # chain.from_iterable concatenates in linear time; sum(lists, []) re-copies the
    # growing accumulator on every addition (quadratic in the batch's token count).
    concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
    # Columns are assumed to be aligned token-for-token, so the first column's length is used.
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder; padding would be the alternative if the model supported it.
    total_length = (total_length // size) * size
    result = {
        k: [t[i : i + size] for i in range(0, total_length, size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


# Re-chunk the tokenized corpus into fixed-length blocks
# (1000 rows per map batch, 4 worker processes).
lm_datasets = tokenized_datasets.map(group_texts, batched=True, batch_size=1000, num_proc=4)

## Load Model & Training Arguments

In [5]:
model = AutoModelForCausalLM.from_pretrained(model_checkpoint)

# Output directory stem: last path component of the checkpoint id (drops any "org/" prefix).
model_name = model_checkpoint.rsplit("/", 1)[-1]

# One epoch per Trainer.train() call, evaluating and checkpointing at the end of each epoch.
training_args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-wikitext103_p1",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=1,
    push_to_hub=False,
)

## Flower (flwr) Client

In [6]:
class LanguageModelClient(fl.client.NumPyClient):
    """Flower NumPy client that fine-tunes a causal LM through a Hugging Face ``Trainer``.

    Weights are exchanged with the server as the model's ``state_dict`` values, in
    ``state_dict`` key order, converted to NumPy arrays.
    """

    def __init__(self, model, trainer, train_dataset, eval_dataset, checkpoint_prefix=None):
        """
        Args:
            model: ``torch.nn.Module`` whose ``state_dict`` is exchanged with the server.
                Presumably the same object ``trainer`` trains — TODO confirm at call sites.
            trainer: Object exposing ``train()``, ``evaluate(dataset)`` and ``save_model(path)``.
            train_dataset: Training split; its length is reported to the server from ``fit``.
            eval_dataset: Evaluation split passed to ``trainer.evaluate``.
            checkpoint_prefix: Path prefix for per-round checkpoints. ``None`` keeps the
                original ``f"{model_name}-finetuned-wiki_p1"`` (notebook-level ``model_name``).
        """
        self.model = model
        self.trainer = trainer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.checkpoint_prefix = checkpoint_prefix
        self.rounds = 0  # completed fit() rounds; used to number saved checkpoints

    def get_parameters(self, config):
        """Return the model's state_dict tensors as NumPy arrays, in state_dict key order."""
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters):
        """Load NumPy arrays (in state_dict key order) into the model.

        Raises:
            ValueError: if the number of arrays differs from the number of state_dict entries.
        """
        keys = list(self.model.state_dict().keys())
        # zip() would silently drop surplus arrays; fail loudly on any count mismatch.
        if len(parameters) != len(keys):
            raise ValueError(
                f"expected {len(keys)} parameter arrays, got {len(parameters)}"
            )
        # torch.tensor keeps each array's dtype; torch.Tensor(...) always built float32,
        # re-casting any integer/bool buffers on the round trip.
        state_dict = {k: torch.tensor(v) for k, v in zip(keys, parameters)}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Load the global weights, train locally, save a checkpoint, return the new weights."""
        self.set_parameters(parameters)
        self.trainer.train()
        prefix = self.checkpoint_prefix
        if prefix is None:
            prefix = f"{model_name}-finetuned-wiki_p1"
        self.trainer.save_model(f"{prefix}-{self.rounds}")
        self.rounds += 1
        return self.get_parameters(config), len(self.train_dataset), {}

    def evaluate(self, parameters, config):
        """Load the global weights and return (eval loss, eval set size, {"perplexity": ...})."""
        self.set_parameters(parameters)
        eval_result = self.trainer.evaluate(self.eval_dataset)
        loss = float(eval_result['eval_loss'])
        # math.exp overflows (OverflowError) for very large losses; report infinite perplexity.
        try:
            perplexity = math.exp(loss)
        except OverflowError:
            perplexity = float("inf")
        return loss, len(self.eval_dataset), {"perplexity": perplexity}

In [7]:
# The Trainer and the client are handed the same `model` object, so weights loaded
# by the client's set_parameters are the ones the Trainer fine-tunes.
client = LanguageModelClient(
    model=model,
    trainer=Trainer(
        model=model,
        args=training_args,
        train_dataset=lm_datasets["train"],
        eval_dataset=lm_datasets["validation"],
    ),
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["validation"],
)

In [8]:
# Connect to the local Flower server and run federated rounds with the `client` built above.
# Reusing it avoids constructing a second, identical Trainer/client pair while the first
# one sits unused.
fl.client.start_numpy_client(
    server_address="localhost:8080",
    client=client,
    # 1 GiB gRPC message cap; each round ships the full state_dict in both directions.
    grpc_max_message_length=1024 * 1024 * 1024,
)

INFO flwr 2023-12-03 16:57:06,852 | grpc.py:52 | Opened insecure gRPC connection (no certificates were passed)


DEBUG flwr 2023-12-03 16:57:06,855 | connection.py:42 | ChannelConnectivity.IDLE
DEBUG flwr 2023-12-03 16:57:06,857 | connection.py:42 | ChannelConnectivity.CONNECTING
DEBUG flwr 2023-12-03 16:57:06,860 | connection.py:42 | ChannelConnectivity.READY


Epoch,Training Loss,Validation Loss
1,3.5687,3.44846


Epoch,Training Loss,Validation Loss
1,3.4996,3.402641


Epoch,Training Loss,Validation Loss
1,3.4438,3.366887


DEBUG flwr 2023-12-03 19:27:04,779 | connection.py:141 | gRPC channel closed
INFO flwr 2023-12-03 19:27:04,780 | app.py:304 | Disconnect and shut down
INFO flwr 2023-12-03 19:27:04,780 | app.py:304 | Disconnect and shut down


--- 