In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
from typing import Any
from logging import INFO, DEBUG
import json
import os
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from scipy.signal import medfilt
from flwr.common import log, ndarrays_to_parameters
import matplotlib.pyplot as plt
import math

from src.common.client_utils import (
    load_femnist_dataset,
    get_network_generator_cnn as get_network_generator,
    get_device,
    get_model_parameters,
    aggregate_weighted_average,
)


from src.flwr_core import (
    set_all_seeds,
    get_paths,
    decompress_dataset,
    get_flower_client_generator,
    sample_random_clients,
    get_federated_evaluation_function,
    create_iid_partition,
)

from src.estimate import (
    compute_critical_batch,
)

from src.experiments_simulation import (
    run_simulation,
    centralized_experiment,
)

from src.utils import get_centralized_acc_from_hist

PathType = Path | str | None

  from .autonotebook import tqdm as notebook_tqdm
2025-03-13 21:01:06,879	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [3]:
# Seed the random number generators (via set_all_seeds) before anything stochastic runs.
set_all_seeds()

PATHS = get_paths()

# Unpack the individual locations used throughout the notebook from PATHS.
(
    HOME_DIR,
    DATASET_DIR,
    DATA_DIR,
    CENTRALIZED_PARTITION,
    CENTRALIZED_MAPPING,
    FEDERATED_PARTITION,
    FEDERATED_IID_PARTITION,
) = (
    PATHS[path_key]
    for path_key in (
        "home_dir",
        "dataset_dir",
        "data_dir",
        "centralized_partition",
        "centralized_mapping",
        "federated_partition",
        "iid_partition",
    )
)

# Extract the dataset archive (tar.gz) referenced by PATHS.
decompress_dataset(PATHS)

In [4]:
# Number of clients in the generated IID partition (passed as `num_clients`).
# Named here so the value is visible and tunable instead of being a bare literal.
NUM_IID_CLIENTS = 3229
create_iid_partition(PATHS, num_clients=NUM_IID_CLIENTS)

In [5]:
# Factory that builds fresh CNN instances; one instance is created up front so its
# weights (SEED_MODEL_PARAMS) can serve as the common initial model parameters.
NETWORK_GENERATOR = get_network_generator()
SEED_MODEL_PARAMS = get_model_parameters(SEED_NET := NETWORK_GENERATOR())

# Client generator backed by the IID federated partition
# (presumably maps a client id to a Flower client — see get_flower_client_generator).
CID_CLIENT_GENERATOR = get_flower_client_generator(
    NETWORK_GENERATOR,
    FEDERATED_IID_PARTITION,
)

In [6]:
# Centralized experiments
centralized_experiment_batch_sizes = [32, 64, 128, 256, 512, 1024]

# Load the centralized train and test splits with the same loader used in FL.
# The mapping must be the one used by the FL centralized experiment.
centralized_train_dataset, centralized_test_dataset = (
    load_femnist_dataset(data_dir=DATA_DIR, mapping=CENTRALIZED_MAPPING, name=split_name)
    for split_name in ("train", "test")
)

# Training hyper-parameters for the centralized runs.
centralized_train_config = dict(
    epochs=10,
    batch_size=32,
    client_learning_rate=0.01,
    weight_decay=0,
    num_workers=0,
    max_batches=100,
)

# Evaluation settings, including the accuracy target for the centralized runs.
centralized_test_config = dict(
    batch_size=32,
    num_workers=0,
    max_batches=100,
    target_accuracy=0.70,
)

In [7]:
# FL experiments
experiment_batch_sizes = [16, 32, 64, 128, 256]
cohort_sizes = [5, 10, 20, 50, 75, 100]


# Federated configuration dictionary
federated_train_config = {
    "epochs": 10,
    "batch_size": 32,
    "client_learning_rate": 0.01,
    "weight_decay": 0,
    "num_workers": 0,
    "max_batches": 100,
}

federated_test_config: dict[str, Any] = {
    "batch_size": 32,
    "num_workers": 0,
    "max_batches": 100,
}

num_rounds = 10
num_total_clients = 100
num_evaluate_clients = 0
num_clients_per_round = 10

# Every federated run starts from the same seed weights.
initial_parameters = ndarrays_to_parameters(SEED_MODEL_PARAMS)

federated_evaluation_function = get_federated_evaluation_function(
    batch_size=federated_test_config["batch_size"],
    num_workers=federated_test_config["num_workers"],
    model_generator=NETWORK_GENERATOR,
    criterion=nn.CrossEntropyLoss(),
    # Optional key: pass None when the test config does not set it.
    max_batches=federated_test_config.get("max_batches"),
)

server_learning_rate = 1.0
server_momentum = 0.0
accept_failures = False


# Reuses CID_CLIENT_GENERATOR from the model-setup cell (built on
# FEDERATED_IID_PARTITION) rather than constructing an identical second one.
list_of_ids = sample_random_clients(
    num_total_clients,
    federated_train_config["batch_size"],
    CID_CLIENT_GENERATOR,
)

# The sampled IDs are bound as a default argument so that reassigning the global
# `list_of_ids` in a later cell cannot silently change which clients this
# generator returns.
federated_client_generator = get_flower_client_generator(
    NETWORK_GENERATOR,
    FEDERATED_IID_PARTITION,
    lambda seq_id, ids=list_of_ids: ids[seq_id],
)

INFO flwr 2025-03-13 21:01:30,559 | flwr_core.py:107 | cid: 2530


INFO flwr 2025-03-13 21:01:30,669 | flwr_core.py:107 | cid: 2184
INFO flwr 2025-03-13 21:01:30,671 | flwr_core.py:107 | cid: 2907
INFO flwr 2025-03-13 21:01:30,673 | flwr_core.py:107 | cid: 1498
INFO flwr 2025-03-13 21:01:30,675 | flwr_core.py:107 | cid: 2338
INFO flwr 2025-03-13 21:01:30,677 | flwr_core.py:107 | cid: 2399
INFO flwr 2025-03-13 21:01:30,679 | flwr_core.py:107 | cid: 2997
INFO flwr 2025-03-13 21:01:30,681 | flwr_core.py:107 | cid: 678
INFO flwr 2025-03-13 21:01:30,683 | flwr_core.py:107 | cid: 3175
INFO flwr 2025-03-13 21:01:30,685 | flwr_core.py:107 | cid: 1363
INFO flwr 2025-03-13 21:01:30,687 | flwr_core.py:107 | cid: 1571
INFO flwr 2025-03-13 21:01:30,689 | flwr_core.py:107 | cid: 2600
INFO flwr 2025-03-13 21:01:30,691 | flwr_core.py:107 | cid: 1473
INFO flwr 2025-03-13 21:01:30,693 | flwr_core.py:107 | cid: 1260
INFO flwr 2025-03-13 21:01:30,695 | flwr_core.py:107 | cid: 1603
INFO flwr 2025-03-13 21:01:30,697 | flwr_core.py:107 | cid: 2855
INFO flwr 2025-03-13 21:01

# Federated experiments: varying the local batch size

In [8]:
# Fit-metric keys in `hist.metrics_distributed_fit` that are kept as named metrics;
# the batch-size sweep collects every remaining key separately.
metric_keys = [
    "training_time",
    "samples_processed",
    "noise_scale",
    "train_loss",
    "actual_batches",
]

In [9]:
# NOTE(review): B_simples is not populated in this cell.
B_simples = []
results = []
# Sweep the same local batch sizes declared in the FL configuration cell.
batch_sizes = list(experiment_batch_sizes)

# Square-root learning-rate scaling: lr = base_lr * sqrt(B / LR_REFERENCE_BATCH_SIZE).
# Only B == LR_REFERENCE_BATCH_SIZE uses the base (centralized) rate; smaller
# batches get a lower rate.
LR_REFERENCE_BATCH_SIZE = 256
base_client_learning_rate = federated_train_config["client_learning_rate"]

# Client IDs are sampled from, and training clients are built on, the SAME
# partition. sample_random_clients receives the batch size (presumably to pick
# clients with enough data), so sampling against one partition while training on
# another can select clients that are unusable in the training partition.
experiment_partition = FEDERATED_IID_PARTITION
partition_client_generator = get_flower_client_generator(NETWORK_GENERATOR, experiment_partition)

for batch_size in batch_sizes:
    train_cfg = federated_train_config.copy()
    train_cfg["batch_size"] = batch_size
    ratio = np.sqrt(batch_size / LR_REFERENCE_BATCH_SIZE)
    train_cfg["client_learning_rate"] = ratio * base_client_learning_rate

    test_cfg = federated_test_config.copy()
    test_cfg["batch_size"] = batch_size

    local_list_of_ids = sample_random_clients(
        num_total_clients, train_cfg["batch_size"], partition_client_generator
    )
    # Default arguments bind this iteration's values when the lambdas are created,
    # instead of looking up the loop variables whenever the callbacks are invoked.
    local_federated_client_generator = get_flower_client_generator(
        NETWORK_GENERATOR,
        experiment_partition,
        lambda seq_id, ids=local_list_of_ids: ids[seq_id],
    )

    parameters_for_each_round, hist = run_simulation(
        num_rounds=num_rounds,
        num_total_clients=num_total_clients,
        num_clients_per_round=num_clients_per_round,
        num_evaluate_clients=num_evaluate_clients,
        min_available_clients=num_total_clients,
        min_fit_clients=num_clients_per_round,
        min_evaluate_clients=num_evaluate_clients,
        evaluate_fn=federated_evaluation_function,
        on_fit_config_fn=lambda _, cfg=train_cfg: cfg,
        on_evaluate_config_fn=lambda _, cfg=test_cfg: cfg,
        initial_parameters=initial_parameters,
        fit_metrics_aggregation_fn=aggregate_weighted_average,
        evaluate_metrics_aggregation_fn=aggregate_weighted_average,
        federated_client_generator=local_federated_client_generator,
        server_learning_rate=server_learning_rate,
        server_momentum=server_momentum,
        accept_failures=accept_failures,
        target_accuracy=0.60,
        use_target_accuracy=True,
    )

    fit_metrics = hist.metrics_distributed_fit
    # Every key that is not one of the named metrics goes into `params`. The dict's
    # insertion order is kept (instead of iterating a set, whose order for string
    # keys changes between interpreter runs), so `params` is ordered reproducibly.
    param_keys = [key for key in fit_metrics if key not in metric_keys]
    n_params = len(param_keys)
    hist_metrics = {key: fit_metrics[key] for key in metric_keys}
    params = [fit_metrics[key] for key in param_keys]
    # Drop references to the full history once the needed pieces are extracted.
    del hist, fit_metrics

    res = (batch_size, parameters_for_each_round, hist_metrics, params)
    results.append(res)

INFO flwr 2025-03-13 21:01:30,938 | flwr_core.py:107 | cid: 2530
INFO flwr 2025-03-13 21:01:30,941 | flwr_core.py:107 | cid: 2184
INFO flwr 2025-03-13 21:01:30,944 | flwr_core.py:107 | cid: 2907
INFO flwr 2025-03-13 21:01:30,946 | flwr_core.py:107 | cid: 1498
INFO flwr 2025-03-13 21:01:30,948 | flwr_core.py:107 | cid: 2338
INFO flwr 2025-03-13 21:01:30,950 | flwr_core.py:107 | cid: 2399
INFO flwr 2025-03-13 21:01:30,953 | flwr_core.py:107 | cid: 2997
INFO flwr 2025-03-13 21:01:30,955 | flwr_core.py:107 | cid: 678


INFO flwr 2025-03-13 21:01:30,959 | flwr_core.py:107 | cid: 3175
INFO flwr 2025-03-13 21:01:30,961 | flwr_core.py:107 | cid: 1363
INFO flwr 2025-03-13 21:01:30,964 | flwr_core.py:107 | cid: 1571
INFO flwr 2025-03-13 21:01:30,966 | flwr_core.py:107 | cid: 2600
INFO flwr 2025-03-13 21:01:30,968 | flwr_core.py:107 | cid: 1473
INFO flwr 2025-03-13 21:01:30,970 | flwr_core.py:107 | cid: 1260
INFO flwr 2025-03-13 21:01:30,972 | flwr_core.py:107 | cid: 1603
INFO flwr 2025-03-13 21:01:30,975 | flwr_core.py:107 | cid: 2855
INFO flwr 2025-03-13 21:01:30,977 | flwr_core.py:107 | cid: 839
INFO flwr 2025-03-13 21:01:30,979 | flwr_core.py:107 | cid: 3119
INFO flwr 2025-03-13 21:01:30,981 | flwr_core.py:107 | cid: 2688
INFO flwr 2025-03-13 21:01:30,984 | flwr_core.py:107 | cid: 1494
INFO flwr 2025-03-13 21:01:30,986 | flwr_core.py:107 | cid: 447
INFO flwr 2025-03-13 21:01:30,988 | flwr_core.py:107 | cid: 1742
INFO flwr 2025-03-13 21:01:30,990 | flwr_core.py:107 | cid: 2601
INFO flwr 2025-03-13 21:01:

RuntimeError: Simulation crashed.