In [6]:
import os 

# Expose only physical GPU 3 to this process. This runs before torch is
# imported (next cell), so CUDA initialization sees the restricted device list.
os.environ.update(CUDA_VISIBLE_DEVICES="3")

In [7]:
import math
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import flwr as fl
import numpy as np
import torch
from datasets import load_dataset
from flwr.common import EvaluateRes, FitRes, Scalar
from flwr.server.client_proxy import ClientProxy
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments

import logger

In [None]:
class FedAvgWithModel(fl.server.strategy.FedAvg):
    """FedAvg strategy that also reports weighted accuracy and checkpoints each round.

    Behaves like ``fl.server.strategy.FedAvg`` but additionally:

    * ``aggregate_evaluate`` computes an example-weighted average of the
      ``"accuracy"`` metric reported by clients and adds it to the metrics dict;
    * ``aggregate_fit`` loads the aggregated weights into ``model`` and writes a
      checkpoint with ``save_pretrained`` to ``{save_path}/round_{server_round}``.

    Args:
        model: Model that receives the aggregated weights. The arrays sent by
            clients are matched to ``model.state_dict()`` keys purely by position,
            so clients must send parameters in the same key order. It must also
            provide ``save_pretrained`` (presumably a Hugging Face
            ``PreTrainedModel`` -- confirm against the client code).
        save_path: Directory prefix under which per-round checkpoints are written.
        *args, **kwargs: Forwarded unchanged to ``FedAvg.__init__``.
    """

    def __init__(self, model, save_path, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = model
        self.save_path = save_path

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.EvaluateRes]],
        failures: List[Union[Tuple[fl.server.client_proxy.ClientProxy, fl.common.EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, fl.common.Scalar]]:
        """Aggregate evaluation loss and add an example-weighted accuracy.

        Loss and any metrics are first aggregated by ``FedAvg.aggregate_evaluate``.
        Accuracy is then averaged over the clients that report an ``"accuracy"``
        metric, weighting each by its ``num_examples``.

        Returns:
            ``(aggregated_loss, metrics)``, where ``metrics`` holds the base-class
            metrics plus ``"accuracy"``. ``"accuracy"`` is omitted when no reporting
            client evaluated a non-zero number of examples.
        """
        if not results:
            return None, {}

        # Let FedAvg aggregate the loss (and any configured metric aggregation).
        aggregated_loss, aggregated_metrics = super().aggregate_evaluate(server_round, results, failures)

        # Only clients that actually report accuracy contribute. A missing key
        # would otherwise raise KeyError and abort the whole round.
        weighted = [
            (res.metrics["accuracy"] * res.num_examples, res.num_examples)
            for _, res in results
            if "accuracy" in res.metrics
        ]
        total_examples = sum(num for _, num in weighted)
        if total_examples == 0:
            # Nothing to weight by; avoid ZeroDivisionError.
            return aggregated_loss, aggregated_metrics

        aggregated_accuracy = sum(acc for acc, _ in weighted) / total_examples
        print(f"Round {server_round} accuracy aggregated from client results: {aggregated_accuracy}")

        # Keep the base-class metrics instead of discarding them.
        return aggregated_loss, {**aggregated_metrics, "accuracy": aggregated_accuracy}

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures: List[Union[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes], BaseException]],
    ) -> Tuple[Optional[fl.common.Parameters], Dict[str, fl.common.Scalar]]:
        """Aggregate model weights with FedAvg and save a checkpoint of the result.

        Returns:
            Whatever ``FedAvg.aggregate_fit`` returned, unchanged.

        Raises:
            ValueError: if the number of aggregated arrays differs from the number
                of entries in ``self.model.state_dict()``.
        """
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)

        if aggregated_parameters is not None:
            print(f"Saving round {server_round} aggregated_parameters...")

            # Convert `Parameters` to `List[np.ndarray]`.
            aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)

            # Arrays are mapped to state_dict keys by position. zip() would silently
            # drop surplus arrays, so check the counts match explicitly.
            keys = list(self.model.state_dict().keys())
            if len(keys) != len(aggregated_ndarrays):
                raise ValueError(
                    f"Round {server_round}: received {len(aggregated_ndarrays)} parameter arrays "
                    f"but the model state_dict has {len(keys)} entries"
                )

            state_dict = OrderedDict((key, torch.tensor(arr)) for key, arr in zip(keys, aggregated_ndarrays))
            self.model.load_state_dict(state_dict, strict=True)

            # Save in SafeTensors format. The transformers kwarg is
            # `safe_serialization`; an unknown kwarg such as `safe` is silently ignored.
            self.model.save_pretrained(f"{self.save_path}/round_{server_round}", safe_serialization=True)
        return aggregated_parameters, aggregated_metrics

In [8]:
# NOTE(review): `logger` is a project-local module; what `configure("1234", ...)`
# does is not visible here. The captured output below shows every flwr log line
# twice, which may mean a second handler is attached to the flwr logger -- verify.
logger.configure("1234", filename="./server.log")

# NOTE(review): the FedAvgWithModel strategy defined in an earlier cell is not used
# here; plain FedAvg is passed instead. As a result, its weighted-accuracy reporting
# and per-round checkpointing never run. To enable them, build it with a model
# instance and a save path, and pass it as `strategy=`.
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=fl.server.strategy.FedAvg(),
    # 1 GiB gRPC message cap, presumably so full LLM weight payloads fit in a
    # single message.
    grpc_max_message_length=1024 * 1024 * 1024,
)

INFO flwr 2023-12-03 16:56:28,460 | app.py:163 | Starting Flower server, config: ServerConfig(num_rounds=3, round_timeout=None)
INFO flwr 2023-12-03 16:56:28,460 | app.py:163 | Starting Flower server, config: ServerConfig(num_rounds=3, round_timeout=None)
INFO flwr 2023-12-03 16:56:28,496 | app.py:176 | Flower ECE: gRPC server running (3 rounds), SSL is disabled
INFO flwr 2023-12-03 16:56:28,496 | app.py:176 | Flower ECE: gRPC server running (3 rounds), SSL is disabled
INFO flwr 2023-12-03 16:56:28,497 | server.py:89 | Initializing global parameters
INFO flwr 2023-12-03 16:56:28,497 | server.py:89 | Initializing global parameters
INFO flwr 2023-12-03 16:56:28,498 | server.py:276 | Requesting initial parameters from one random client
INFO flwr 2023-12-03 16:56:28,498 | server.py:276 | Requesting initial parameters from one random client
INFO flwr 2023-12-03 16:57:05,321 | server.py:280 | Received initial parameters from one random client
INFO flwr 2023-12-03 16:57:05,321 | server.py:280

History (loss, distributed):
	round 1: 4.365115014493489
	round 2: 4.478074505413008
	round 3: 4.46090363198044

--- 