In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Option 1: Add project root to sys.path
import sys
from pathlib import Path

# Go up two levels from notebooks/centralized to reach project root
project_root = str(Path().absolute().parent.parent)
if project_root not in sys.path:
    sys.path.insert(0, project_root)


In [3]:
from pathlib import Path
from typing import Any
from logging import INFO, DEBUG
import json
import os
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from scipy.signal import medfilt
from flwr.common import log, ndarrays_to_parameters
import matplotlib.pyplot as plt
import math

from src.common.client_utils import (
    load_femnist_dataset,
    get_network_generator_cnn as get_network_generator,
    get_device,
    get_model_parameters,
    aggregate_weighted_average,
)


from src.flwr_core import (
    set_all_seeds,
    get_paths,
    decompress_dataset,
    get_flower_client_generator,
    sample_random_clients,
    get_federated_evaluation_function,
    create_iid_partition
)

from src.estimate import (
    compute_critical_batch,
)

from src.experiments_simulation import (
    run_simulation,
    centralized_experiment,
)

from src.utils import get_centralized_acc_from_hist

PathType = Path | str | None

  from .autonotebook import tqdm as notebook_tqdm
2025-03-18 21:44:44,798	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [4]:
set_all_seeds()

PATHS = get_paths()

HOME_DIR = PATHS["home_dir"]
DATASET_DIR = PATHS["dataset_dir"]
DATA_DIR = PATHS["data_dir"]
CENTRALIZED_PARTITION = PATHS["centralized_partition"]
CENTRALIZED_MAPPING = PATHS["centralized_mapping"]
FEDERATED_PARTITION = PATHS["federated_partition"]
FEDERATED_IID_PARTITION = PATHS["iid_partition"]

# extract dataset from tar.gz
decompress_dataset(PATHS)

In [5]:
max_clients = 1000
create_iid_partition(PATHS, num_clients=max_clients)

Current number of clients in iid partition: 1000, expected: 1000
IID partition already exists and has the correct number of clients


In [6]:
NETWORK_GENERATOR = get_network_generator()
SEED_NET = NETWORK_GENERATOR()
SEED_MODEL_PARAMS = get_model_parameters(SEED_NET)
CID_CLIENT_GENERATOR = get_flower_client_generator(NETWORK_GENERATOR, FEDERATED_IID_PARTITION)

In [7]:
# FL experiments
experiment_batch_sizes = [32, 64, 128, 256, 512]
cohort_sizes = [5, 10, 20, 50, 75, 100, 150]


# Federated configuration dictionary
federated_train_config = {
    "epochs": 10,
    "batch_size": 32,
    "client_learning_rate": 0.01,
    "weight_decay": 0,
    "num_workers": 0,
    "max_batches": 100,
}

federated_test_config: dict[str, Any] = {
    "batch_size": 32,
    "num_workers": 0,
    "max_batches": 100,
}

num_rounds = 10
num_total_clients = 100
num_evaluate_clients = 0
num_clients_per_round = 10

initial_parameters = ndarrays_to_parameters(SEED_MODEL_PARAMS)

federated_evaluation_function = get_federated_evaluation_function(
    batch_size=federated_test_config["batch_size"],
    num_workers=federated_test_config["num_workers"],
    model_generator=NETWORK_GENERATOR,
    criterion=nn.CrossEntropyLoss(),
    max_batches=None if "max_batches" not in federated_test_config else federated_test_config["max_batches"],
)

server_learning_rate = 1.0
server_momentum = 0.0
accept_failures = False


CID_CLIENT_GENERATOR = get_flower_client_generator(NETWORK_GENERATOR, FEDERATED_IID_PARTITION)

list_of_ids = sample_random_clients(
    num_total_clients, federated_train_config["batch_size"],
    CID_CLIENT_GENERATOR, max_clients=max_clients,
)

federated_client_generator = (
    get_flower_client_generator(
        NETWORK_GENERATOR, FEDERATED_IID_PARTITION, lambda seq_id: list_of_ids[seq_id]
    )
)

INFO flwr 2025-03-18 21:44:48,059 | flwr_core.py:107 | cid: 632


INFO flwr 2025-03-18 21:44:48,149 | flwr_core.py:107 | cid: 947
INFO flwr 2025-03-18 21:44:48,156 | flwr_core.py:107 | cid: 546
INFO flwr 2025-03-18 21:44:48,164 | flwr_core.py:107 | cid: 726
INFO flwr 2025-03-18 21:44:48,170 | flwr_core.py:107 | cid: 374
INFO flwr 2025-03-18 21:44:48,176 | flwr_core.py:107 | cid: 584
INFO flwr 2025-03-18 21:44:48,181 | flwr_core.py:107 | cid: 599
INFO flwr 2025-03-18 21:44:48,187 | flwr_core.py:107 | cid: 749
INFO flwr 2025-03-18 21:44:48,193 | flwr_core.py:107 | cid: 169
INFO flwr 2025-03-18 21:44:48,198 | flwr_core.py:107 | cid: 793
INFO flwr 2025-03-18 21:44:48,204 | flwr_core.py:107 | cid: 844
INFO flwr 2025-03-18 21:44:48,210 | flwr_core.py:107 | cid: 340
INFO flwr 2025-03-18 21:44:48,217 | flwr_core.py:107 | cid: 392
INFO flwr 2025-03-18 21:44:48,223 | flwr_core.py:107 | cid: 650
INFO flwr 2025-03-18 21:44:48,229 | flwr_core.py:107 | cid: 808
INFO flwr 2025-03-18 21:44:48,235 | flwr_core.py:107 | cid: 368
INFO flwr 2025-03-18 21:44:48,242 | flwr

In [8]:
import pickle

def save_experiment(save_file_name, batch_size, parameters_for_each_round, hist):
    results_dict = {
        'batch_size': batch_size,
        'parameters_for_each_round': parameters_for_each_round,
        'history': hist
    }
    
    with open(save_file_name, 'wb') as f:  # Note: 'wb' for binary write mode
        pickle.dump(results_dict, f)

def load_experiment(file_name):
    with open(file_name, 'rb') as f:  # Note: 'rb' for binary read mode
        results_dict = pickle.load(f)
    
    return (
        results_dict['batch_size'],
        results_dict['parameters_for_each_round'],
        results_dict['history'],
    )

In [9]:


total_global_batch_results = []
cs_bs_pairs = [(5, 20), (20, 50), (50, 200), (100, 250), (100, 1000), (100, 2000), (100, 4000), (100, 12000)]
cs_bs_pairs = [(5, 16), (10, 32), (20, 64), (50, 128), (75, 256), (100, 512)]

for cohort_size, batch_size in cs_bs_pairs:
    global_batch_size = batch_size * cohort_size
    train_cfg = federated_train_config.copy()
    ratio = np.sqrt(cohort_size * batch_size / 1e6)
    # ratio = 100 / 100 = 1 * 0.01
    # if i multiply by batch size, i want to divide
    train_cfg["client_learning_rate"] = ratio * 0.01

    test_cfg = federated_test_config.copy()

    parameters_for_each_round, hist = run_simulation(
        num_rounds = 10,
        num_total_clients = num_total_clients,
        num_clients_per_round = cohort_size,
        num_evaluate_clients = num_evaluate_clients,
        min_available_clients = num_total_clients,
        min_fit_clients = cohort_size,
        min_evaluate_clients = num_evaluate_clients,
        evaluate_fn = federated_evaluation_function,
        on_fit_config_fn = lambda _: train_cfg,
        on_evaluate_config_fn = lambda _: test_cfg,
        initial_parameters = initial_parameters,
        fit_metrics_aggregation_fn = aggregate_weighted_average,
        evaluate_metrics_aggregation_fn = aggregate_weighted_average,
        federated_client_generator = federated_client_generator,
        server_learning_rate=server_learning_rate,
        server_momentum=server_momentum,
        accept_failures=accept_failures,
        target_accuracy=0.70,
        use_target_accuracy=True,
        )

    total_global_batch_results.append((global_batch_size, parameters_for_each_round, hist))
    save_experiment(f"federated_iid_global_batch_results_{global_batch_size}.pkl", global_batch_size, parameters_for_each_round=parameters_for_each_round, hist=hist)

INFO flwr 2025-03-18 21:44:48,828 | experiments_simulation.py:232 | FL will execute for 10 rounds
INFO flwr 2025-03-18 21:44:48,831 | app.py:149 | Starting Flower simulation, config: ServerConfig(num_rounds=10, round_timeout=None)
INFO flwr 2025-03-18 21:44:48,832 | flwr_core.py:264 | Initializing global parameters
INFO flwr 2025-03-18 21:44:48,833 | server_returns_parameters.py:273 | Using initial parameters provided by strategy
INFO flwr 2025-03-18 21:44:48,834 | flwr_core.py:269 | Evaluating initial parameters


 11%|█         | 100/891 [00:02<00:19, 39.90it/s]
INFO flwr 2025-03-18 21:44:51,774 | flwr_core.py:272 | initial parameters (loss, other metrics): 413.6843070983887, {'accuracy': 0.0065625}
INFO flwr 2025-03-18 21:44:51,774 | flwr_core.py:280 | FL starting - Target accuracy: 0.7
DEBUG flwr 2025-03-18 21:44:51,775 | server_returns_parameters.py:223 | fit_round 1: strategy sampled 5 clients (out of 100)
INFO flwr 2025-03-18 21:44:51,776 | flwr_core.py:107 | cid: 985
INFO flwr 2025-03-18 21:44:51,779 | flwr_core.py:107 | cid: 68
INFO flwr 2025-03-18 21:44:51,782 | flwr_core.py:107 | cid: 291
INFO flwr 2025-03-18 21:44:51,783 | flwr_core.py:107 | cid: 852
INFO flwr 2025-03-18 21:44:51,785 | flwr_core.py:107 | cid: 993
DEBUG flwr 2025-03-18 21:45:04,947 | server_returns_parameters.py:237 | fit_round 1 received 5 results and 0 failures
 11%|█         | 100/891 [00:02<00:16, 47.02it/s]
INFO flwr 2025-03-18 21:45:07,284 | flwr_core.py:303 | fit progress: (round 1, accuracy 0.079375, loss 356.2