# Imports

In [None]:
import os, sys, logging
# Presumably disables Ray's log de-duplication so repeated log lines from the
# simulated clients are not collapsed. Set before the flwr imports below, since
# flwr's simulation engine is assumed to start Ray and read this variable then.
os.environ["RAY_DEDUP_LOGS"] = "0"

import pandas as pd

# Put the parent of the current working directory on the import path so the
# project-local modules imported below (configuration, federated_metrics,
# federated_clients, model) can be resolved. Assumes the notebook is run from
# its own directory; TODO confirm.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../")))
from configuration import Configuration
# `path` is an alias for os.path.join, used to build result directories.
from os.path import join as path
from warnings import simplefilter
# Silence pandas PerformanceWarning and every RuntimeWarning for the session.
# NOTE(review): a blanket RuntimeWarning filter can also hide genuine numeric
# problems (e.g. NaN/inf in metrics); consider narrowing it.
simplefilter(action="ignore", category=pd.errors.PerformanceWarning)
simplefilter(action="ignore", category=RuntimeWarning)

from flwr.client import ClientApp
from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAdagrad, FedAvg, FedProx, FedYogi, FedAdam
from flwr.simulation import run_simulation
from torch.utils.data import DataLoader

from federated_metrics import (
    MetricsTracker,
    evaluate_metrics_aggregation_fn, 
    fit_metrics_aggregation_fn, 
    on_fit_config_fn, 
    on_evaluate_config_fn,
    evaluate_fn
)

from federated_clients import (
    get_parameters,
    FederatedClient
)

from model import Model

# Configuration
# Project-wide settings object. The cells below read result paths, features,
# classes, number of rounds, client count and the FedProx mu from it.
c = Configuration()

# logging
# Root logger at INFO level with timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)

# Control Panel

Experiment settings: the scenario name, the dataset, and the aggregation strategies to compare.

In [None]:
# Experiment identifiers used to name result directories and metric trackers.
SCENARIO = "realworld-A"
DATASET = "thesis-simulation"

# (display name, Flower strategy class) pairs, kept in run order.
# A dict preserves insertion order, so .items() yields the same sequence.
AGGREGATE_ALGOS = list(
    {
        "FedAvg": FedAvg,
        "FedProx": FedProx,
        "FedAdam": FedAdam,
        "FedYogi": FedYogi,
        "FedAdagrad": FedAdagrad,
    }.items()
)

In [None]:
# Loop-invariant setup shared by every strategy run.
features = c.features
columns = features + c.appl  # not used in this cell; kept for any later cell that reads it

in_features = len(features)
out_features = len(c.classes)

# Run one complete federated simulation per aggregation strategy.
for algo, algo_class in AGGREGATE_ALGOS:
    logger.info("Running scenario %s with strategy %s", SCENARIO, algo)

    path_case_home = path(c.path_results, "scenarios", SCENARIO, algo)
    os.makedirs(path_case_home, exist_ok=True)

    metrics_tracker = MetricsTracker(SCENARIO, algo)

    # Model handed to the server-side evaluate_fn below.
    model_test = Model(in_features, out_features)

    # NOTE: client_fn and server_fn close over the loop variables (algo,
    # algo_class, metrics_tracker, model_test) and read them when called,
    # not when defined. This is only correct because run_simulation below
    # consumes them within this same iteration (presumably it blocks until
    # the run finishes). Do not keep these callables past the iteration.

    def client_fn(context: Context):
        """Build the Flower client for one simulated node."""
        # Client ids are the node's partition-id shifted by one.
        client_id = context.node_config["partition-id"] + 1
        model = Model(in_features, out_features)
        # FedProx clients additionally receive the proximal coefficient mu.
        extra_args = (c.fedprox_mu,) if algo == "FedProx" else ()
        return FederatedClient(
            model, client_id, DATASET, features, *extra_args
        ).to_client()

    # Create the ClientApp
    client = ClientApp(client_fn=client_fn)

    def server_fn(context: Context):
        """Build the strategy and round configuration for this run."""
        config = ServerConfig(num_rounds=c.rounds)
        # Arguments common to every strategy, kept in one place so the
        # FedProx and non-FedProx configurations cannot drift apart.
        strategy_kwargs = {
            "fraction_fit": 1.0,  # Sample 100% of available clients for training
            "fraction_evaluate": 1.0,  # Sample 100% of available clients for evaluation
            "min_fit_clients": c.clients,  # Never sample less than X clients for training
            "min_evaluate_clients": c.clients,  # Never sample less than X clients for evaluation
            "min_available_clients": c.clients,  # Wait until X clients are available
            "evaluate_fn": evaluate_fn(model_test, metrics_tracker),
            "on_fit_config_fn": on_fit_config_fn,
            "on_evaluate_config_fn": on_evaluate_config_fn,
            "accept_failures": True,
            # NOTE(review): each run draws initial weights from a fresh
            # Model(...). Unless Model seeds itself, the strategies start from
            # different initializations; confirm this is intended.
            "initial_parameters": ndarrays_to_parameters(
                get_parameters(Model(in_features, out_features))
            ),
            "fit_metrics_aggregation_fn": fit_metrics_aggregation_fn(metrics_tracker),
            "evaluate_metrics_aggregation_fn": evaluate_metrics_aggregation_fn(metrics_tracker),
        }
        if algo == "FedProx":
            strategy_kwargs["proximal_mu"] = c.fedprox_mu
        strategy = algo_class(**strategy_kwargs)
        return ServerAppComponents(strategy=strategy, config=config)

    # Create the ServerApp
    server = ServerApp(server_fn=server_fn)

    # Rebuilt every iteration in case run_simulation adds keys to it.
    backend_config = {"client_resources": {"num_cpus": 1.0, "num_gpus": 0.0}}
    run_simulation(
        server_app=server,
        client_app=client,
        num_supernodes=c.clients,
        backend_config=backend_config,
        verbose_logging=False,
    )