In [1]:
import flwr as fl
from pathlib import Path
import os
from sources.flwr_parameters.default_strategy_configs import \
    femnist_fit_config, femnist_eval_config

from sources.datasets.femnist.femnist_client_dataset_factory import FemnistClientDatasetFactory
from sources.flwr_parameters.default_parameters import DEFAULT_SEED
from sources.models.femnist.femnist_model_template import FemnistModelTemplate
from sources.flwr_strategies.model_logging_strategy_decorator import ModelLoggingStrategyDecorator
from sources.flwr_strategies.evaluation_metrics_logging_strategy_decorator import EvaluationMetricsLoggingStrategyDecorator

In [2]:
from sources.simulation_framework.multiprocessing_simulator import MultiprocessingBasedSimulator
# Project folders are resolved relative to the kernel's working directory:
# the base directory is taken two levels above it (presumably the repository
# root, i.e. this notebook is expected to run from two folders below it).
base_dir = Path.cwd().parent.parent
checkpoint_dir, data_dir = base_dir / "checkpoints", base_dir / "data"

# All artifacts of this run go under checkpoints/<experiment_name>/; the two
# output folders are handed on as plain strings rather than Path objects.
experiment_name = "initial_experiment"
model_saving_dir, metrics_saving_dir = (
    str(checkpoint_dir / experiment_name / leaf) for leaf in ("models", "metrics")
)

In [3]:

# Run length of the simulation: 2 federated rounds across 2 clients.
simulation_parameters = {"num_rounds": 2, "num_clients": 2}

# Server strategy, built inside-out:
#   1. plain FedAvg, configured with the FEMNIST fit/evaluate config callbacks;
#   2. wrapped by the evaluation-metrics logging decorator (metrics_saving_dir);
#   3. wrapped outermost by the model logging decorator (model_saving_dir).
# Both decorators are tagged with the same experiment identifier.
strategy = ModelLoggingStrategyDecorator(
    strategy=EvaluationMetricsLoggingStrategyDecorator(
        strategy=fl.server.strategy.FedAvg(
            on_fit_config_fn=femnist_fit_config,
            on_evaluate_config_fn=femnist_eval_config,
        ),
        metrics_logging_folder=metrics_saving_dir,
        experiment_identifier=experiment_name,
    ),
    model_saving_folder=model_saving_dir,
    experiment_identifier=experiment_name,
)

# Model template seeded with the project-wide default seed, plus the FEMNIST
# client dataset factory reading from data_dir.
model_template = FemnistModelTemplate(DEFAULT_SEED)
dataset_factory = FemnistClientDatasetFactory(data_dir)

# Assemble the multiprocessing simulator; it is launched in the next cell.
simulator = MultiprocessingBasedSimulator(
    simulation_parameters, strategy, model_template, dataset_factory
)

In [12]:
# Launch the federated simulation assembled in the previous cell.
# NOTE(review): the saved execution count for this cell is out of sequence with
# the cells above (12 after 3) — confirm the notebook reproduces under
# Restart Kernel → Run All.
simulator.start_simulation()
