In [1]:
import os 

# Expose only GPU index 3 to this process. This is set here, before torch is
# imported in the next cell.
os.environ.update({'CUDA_VISIBLE_DEVICES': '3'})

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset

import flwr as fl
import torch
import math

  from .autonotebook import tqdm as notebook_tqdm
2023-11-29 16:13:04,125	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [3]:
class LanguageModelClient(fl.client.NumPyClient):
	"""Flower ``NumPyClient`` that wraps a PyTorch model and a trainer object.

	The model's ``state_dict`` is exchanged with the server as a list of NumPy
	arrays, ordered like ``state_dict().keys()``. Training and evaluation are
	delegated to ``trainer``, which must expose ``train()`` and
	``evaluate(dataset)``. The trainer is expected to operate on the same
	``model`` instance, otherwise the parameters loaded here would not affect it.
	"""

	def __init__(self, model, trainer, train_dataset, eval_dataset):
		"""Store the model, the trainer and the two datasets used for sizing/eval."""
		self.model = model
		self.trainer = trainer
		self.train_dataset = train_dataset
		self.eval_dataset = eval_dataset

	def get_parameters(self, config=None):
		"""Return every ``state_dict`` entry as a NumPy array, in key order.

		``config`` is accepted for the Flower interface but not used; it defaults
		to ``None`` so the method can also be called without arguments.
		"""
		return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

	def set_parameters(self, parameters):
		"""Load a list of NumPy arrays (in ``state_dict`` key order) into the model."""
		# torch.tensor keeps each array's dtype; torch.Tensor(...) would coerce
		# integer/bool buffers to float32.
		state_dict = {
			k: torch.tensor(v)
			for k, v in zip(self.model.state_dict().keys(), parameters)
		}
		# strict=True makes a key mismatch fail loudly instead of partially loading.
		self.model.load_state_dict(state_dict, strict=True)

	def fit(self, parameters, config):
		"""Load ``parameters``, run ``trainer.train()``, and return the new weights.

		Returns:
			``(parameters, num_train_examples, metrics)`` with an empty metrics dict.
		"""
		self.set_parameters(parameters)
		self.trainer.train()
		# get_parameters takes the config argument; calling it bare raised TypeError.
		return self.get_parameters(config), len(self.train_dataset), {}

	def evaluate(self, parameters, config):
		"""Load ``parameters`` and evaluate on ``eval_dataset``.

		Returns:
			``(loss, num_eval_examples, metrics)``; the loss is read from the
			``'eval_loss'`` entry of the trainer's result and metrics is empty.
		"""
		self.set_parameters(parameters)
		eval_result = self.trainer.evaluate(self.eval_dataset)
		# Flower expects a 3-tuple here: the metrics dict is required even if empty.
		return float(eval_result['eval_loss']), len(self.eval_dataset), {}

In [4]:
# Raw WikiText-103 corpus; the "train" and "validation" splits are used below.
datasets = load_dataset(
    'wikitext',
    'wikitext-103-raw-v1',
)

In [5]:
# One checkpoint name is shared by the tokenizer here and the model loaded later.
model_checkpoint = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(
    model_checkpoint,
    use_fast=True,
)

In [6]:
def tokenize_function(examples):
    """Run the notebook-level ``tokenizer`` over the ``"text"`` field of ``examples``."""
    texts = examples["text"]
    return tokenizer(texts)

# Tokenize in batches across 4 worker processes and drop the raw "text" column.
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=["text"],
)

In [7]:
# Number of tokens per chunk produced by group_texts.
# Alternative: use the tokenizer's own limit, i.e. tokenizer.model_max_length.
block_size: int = 128

In [8]:
def group_texts(examples, chunk_size=None):
    """Concatenate a batch of token lists and re-split it into fixed-size chunks.

    Args:
        examples: Mapping from column name to a list of lists (one inner list
            per example). Must contain an ``"input_ids"`` column.
        chunk_size: Length of every output chunk. Defaults to the notebook-level
            ``block_size`` when omitted.

    Returns:
        A dict with the same columns, each holding chunks of ``chunk_size``
        items, plus a ``"labels"`` column that is a shallow copy of the
        ``"input_ids"`` list. A trailing remainder shorter than ``chunk_size``
        is dropped.
    """
    # Local import so the function works regardless of the notebook import cell.
    from itertools import chain

    if chunk_size is None:
        chunk_size = block_size

    # Flatten each column in linear time (sum(lists, []) is quadratic).
    concatenated_examples = {
        k: list(chain.from_iterable(examples[k])) for k in examples.keys()
    }
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder so every chunk has exactly chunk_size items;
    # padding could be used instead if the model supported it.
    total_length = (total_length // chunk_size) * chunk_size
    # Split each flattened column into consecutive chunks.
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

# Regroup the tokenized text into fixed-size chunks, 1000 rows per batch,
# using 4 worker processes.
lm_datasets = tokenized_datasets.map(
    group_texts, batched=True, batch_size=1000, num_proc=4
)

In [9]:
# Causal-LM weights for the same checkpoint the tokenizer was loaded from.
model = AutoModelForCausalLM.from_pretrained(
    model_checkpoint,
)

In [10]:
# Keep only the last path component of the checkpoint id for the output folder name.
model_name = model_checkpoint.rsplit("/", 1)[-1]
training_args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-wikitext103",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    push_to_hub=False,
)

In [11]:
# The Trainer and the client share the same `model` object, so parameters the
# client loads are the ones the Trainer trains and evaluates.
client = LanguageModelClient(
    model=model,
    trainer=Trainer(
        model=model,
        args=training_args,
        train_dataset=lm_datasets["train"],
        eval_dataset=lm_datasets["validation"],
    ),
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["validation"],
)

In [None]:
# NOTE(review): this cell was never executed in this notebook. start_server
# presumably blocks until all rounds finish, so it likely has to run in a
# separate process from the client cell below -- confirm.
# NOTE(review): FedAvg() is used with its defaults; check that its minimum
# client counts match the number of clients that will actually connect.
fl.server.start_server(
    server_address="0.0.0.0:8080",
    strategy=fl.server.strategy.FedAvg(),
    config=fl.server.ServerConfig(num_rounds=3),
)

In [12]:
# Connect this notebook's client to the Flower server on port 8080 of this machine.
fl.client.start_numpy_client(client=client, server_address="localhost:8080")

INFO flwr 2023-11-29 16:13:34,935 | grpc.py:52 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flwr 2023-11-29 16:13:34,937 | connection.py:42 | ChannelConnectivity.IDLE
DEBUG flwr 2023-11-29 16:13:34,939 | connection.py:42 | ChannelConnectivity.CONNECTING


DEBUG flwr 2023-11-29 16:13:34,942 | connection.py:42 | ChannelConnectivity.READY


Epoch,Training Loss,Validation Loss


DEBUG flwr 2023-11-29 16:16:29,083 | connection.py:141 | gRPC channel closed


KeyboardInterrupt: 

--- 

In [None]:
# Stand-alone (non-federated) Trainer over the same model, arguments and splits.
trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=lm_datasets["validation"],
    train_dataset=lm_datasets["train"],
)

In [None]:
# Baseline: perplexity is exp(eval_loss) on the validation split before training.
eval_results_before = trainer.evaluate()
print(
    f"Perplexity before training: {math.exp(eval_results_before['eval_loss']):.2f}"
)

In [None]:
# Fine-tune, then report perplexity (exp of eval_loss) on the validation split again.
trainer.train()
eval_results_after = trainer.evaluate()
print(
    f"Perplexity after training: {math.exp(eval_results_after['eval_loss']):.2f}"
)