In [1]:
# Automatically reload imported modules before executing each cell, so edits to
# project sources (e.g. the `src` package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Make the repository root importable so `from src...` imports below resolve.
# Assumes the kernel's working directory sits two levels below the root
# (e.g. notebooks/<subdir>); adjust the number of `.parent` hops otherwise.
import sys
from pathlib import Path

project_root = str(Path.cwd().parent.parent)
if project_root not in sys.path:
    # Prepend so the local `src` package wins over any installed namesake.
    sys.path.insert(0, project_root)

# Now you can import from src normally
from src.common.client_utils import load_femnist_dataset

  from .autonotebook import tqdm as notebook_tqdm
2025-03-18 18:59:37,478	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [3]:
from pathlib import Path
from typing import Any
from logging import INFO, DEBUG
import json
import os
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from scipy.signal import medfilt
from flwr.common import log, ndarrays_to_parameters
import matplotlib.pyplot as plt
import math

from src.common.client_utils import (
    load_femnist_dataset,
    get_network_generator_cnn as get_network_generator,
    get_device,
    get_model_parameters,
    aggregate_weighted_average,
)


from src.flwr_core import (
    set_all_seeds,
    get_paths,
    decompress_dataset,
    get_flower_client_generator,
    sample_random_clients,
    get_federated_evaluation_function,
)

from src.estimate import (
    compute_critical_batch,
)

from src.experiments_simulation import (
    run_simulation,
    centralized_experiment,
)

from src.utils import get_centralized_acc_from_hist

# Alias for values that may be a pathlib.Path, a plain str path, or None.
# PEP 604 `X | Y` syntax is evaluated at runtime here, so this needs Python 3.10+.
PathType = Path | str | None

In [4]:
# Seed random number generators (per the helper's name) before anything stochastic runs.
set_all_seeds()

# Resolve all filesystem locations once, then expose each as a module-level constant.
PATHS = get_paths()
HOME_DIR, DATASET_DIR, DATA_DIR = (
    PATHS[key] for key in ("home_dir", "dataset_dir", "data_dir")
)
CENTRALIZED_PARTITION, CENTRALIZED_MAPPING, FEDERATED_PARTITION = (
    PATHS[key]
    for key in ("centralized_partition", "centralized_mapping", "federated_partition")
)

# Extract the dataset from its tar.gz archive before anything reads it.
decompress_dataset(PATHS)

In [5]:
# Factory producing fresh model instances (imported as get_network_generator_cnn).
NETWORK_GENERATOR = get_network_generator()
# A single reference model; its parameters become `initial_parameters` for the
# federated runs configured in the next cell.
SEED_NET = NETWORK_GENERATOR()
SEED_MODEL_PARAMS = get_model_parameters(SEED_NET)
# Presumably maps a client id (cid) to a Flower client backed by FEDERATED_PARTITION
# — confirm in src.flwr_core.get_flower_client_generator.
CID_CLIENT_GENERATOR = get_flower_client_generator(NETWORK_GENERATOR, FEDERATED_PARTITION)

In [6]:
# --- Federated-learning experiment configuration ----------------------------
# Sweep grids. The cohort-sweep cell may override `cohort_sizes` for a given run.
experiment_batch_sizes = [16, 32, 64, 128, 256]
cohort_sizes = [5, 10, 20, 50, 75, 100]

# Client-side training hyper-parameters; copied per experiment and handed to
# clients through `on_fit_config_fn`.
federated_train_config: dict[str, Any] = {
    "epochs": 10,
    "batch_size": 32,
    "client_learning_rate": 0.01,
    "weight_decay": 0,
    "num_workers": 0,
    "max_batches": 100,
}

# Evaluation settings, used both for the server-side evaluation function and
# (copied) for `on_evaluate_config_fn`.
federated_test_config: dict[str, Any] = {
    "batch_size": 32,
    "num_workers": 0,
    "max_batches": 100,
}

num_rounds = 10
num_total_clients = 100
num_evaluate_clients = 0
num_clients_per_round = 10

# Every run starts from the same seed-model weights.
initial_parameters = ndarrays_to_parameters(SEED_MODEL_PARAMS)

federated_evaluation_function = get_federated_evaluation_function(
    batch_size=federated_test_config["batch_size"],
    num_workers=federated_test_config["num_workers"],
    model_generator=NETWORK_GENERATOR,
    criterion=nn.CrossEntropyLoss(),
    # `.get` yields None when no cap is configured, same as an explicit membership test.
    max_batches=federated_test_config.get("max_batches"),
)

server_learning_rate = 1.0
server_momentum = 0.0
accept_failures = False

# Reuses CID_CLIENT_GENERATOR from the model-setup cell rather than building an
# identical second generator. The second argument is the training batch size —
# presumably a minimum per-client sample requirement; confirm in src.flwr_core.
list_of_ids = sample_random_clients(
    num_total_clients,
    federated_train_config["batch_size"],
    CID_CLIENT_GENERATOR,
)

# Client generator that maps a sequential simulation index to one of the sampled cids.
federated_client_generator = get_flower_client_generator(
    NETWORK_GENERATOR, FEDERATED_PARTITION, lambda seq_id: list_of_ids[seq_id]
)

INFO flwr 2025-03-18 18:59:41,454 | flwr_core.py:107 | cid: 2530


INFO flwr 2025-03-18 18:59:41,574 | flwr_core.py:107 | cid: 2184
INFO flwr 2025-03-18 18:59:41,579 | flwr_core.py:107 | cid: 2907
INFO flwr 2025-03-18 18:59:41,582 | flwr_core.py:107 | cid: 1498
INFO flwr 2025-03-18 18:59:41,586 | flwr_core.py:107 | cid: 2338
INFO flwr 2025-03-18 18:59:41,590 | flwr_core.py:107 | cid: 2399
INFO flwr 2025-03-18 18:59:41,593 | flwr_core.py:107 | cid: 2997
INFO flwr 2025-03-18 18:59:41,597 | flwr_core.py:107 | cid: 678
INFO flwr 2025-03-18 18:59:41,601 | flwr_core.py:107 | cid: 3175
INFO flwr 2025-03-18 18:59:41,604 | flwr_core.py:107 | cid: 1363
INFO flwr 2025-03-18 18:59:41,608 | flwr_core.py:107 | cid: 1571
INFO flwr 2025-03-18 18:59:41,611 | flwr_core.py:107 | cid: 2600
INFO flwr 2025-03-18 18:59:41,614 | flwr_core.py:107 | cid: 1473
INFO flwr 2025-03-18 18:59:41,617 | flwr_core.py:107 | cid: 1260
INFO flwr 2025-03-18 18:59:41,622 | flwr_core.py:107 | cid: 1603
INFO flwr 2025-03-18 18:59:41,625 | flwr_core.py:107 | cid: 2855
INFO flwr 2025-03-18 18:59

In [7]:
import pickle

def save_experiment(save_file_name, batch_size, parameters_for_each_round, hist):
    """Pickle one experiment's outputs to ``save_file_name``.

    The file contains a dict with keys ``'batch_size'``,
    ``'parameters_for_each_round'`` and ``'history'`` (in that order), which
    ``load_experiment`` reads back.

    Args:
        save_file_name: Destination path; an existing file is overwritten.
        batch_size: Value stored under ``'batch_size'``.
        parameters_for_each_round: Value stored under ``'parameters_for_each_round'``.
        hist: Value stored under ``'history'``.
    """
    payload = dict(
        batch_size=batch_size,
        parameters_for_each_round=parameters_for_each_round,
        history=hist,
    )
    # Pickle is a byte stream, so the file must be opened in binary mode.
    with open(save_file_name, mode="wb") as handle:
        pickle.dump(payload, handle)

def load_experiment(file_name):
    """Read back a pickle written by ``save_experiment``.

    Args:
        file_name: Path to the pickle file.

    Returns:
        Tuple ``(batch_size, parameters_for_each_round, history)``.

    Raises:
        KeyError: if the stored dict lacks one of the three expected keys.

    Warning:
        ``pickle.load`` can execute arbitrary code embedded in the file; only
        load files this project produced itself, never untrusted input.
    """
    with open(file_name, mode="rb") as handle:
        stored = pickle.load(handle)

    field_order = ("batch_size", "parameters_for_each_round", "history")
    return tuple(stored[field] for field in field_order)

In [8]:
# Cohort sizes to run in this session. The full sweep grid from the config cell
# is [5, 10, 20, 50, 75, 100]; this run narrows it to a single size.
cohort_sizes = [100]

# Client LR = base LR * sqrt(cohort_size / REFERENCE_COHORT_SIZE), where the base
# LR comes from the shared training config instead of a duplicated literal.
REFERENCE_COHORT_SIZE = 100
base_client_learning_rate = federated_train_config["client_learning_rate"]

total_cohort_results = []
for cohort_size in cohort_sizes:
    train_cfg = federated_train_config.copy()
    ratio = np.sqrt(cohort_size / REFERENCE_COHORT_SIZE)
    train_cfg["client_learning_rate"] = ratio * base_client_learning_rate

    test_cfg = federated_test_config.copy()

    parameters_for_each_round, hist = run_simulation(
        num_rounds=num_rounds,
        num_total_clients=num_total_clients,
        num_clients_per_round=cohort_size,
        num_evaluate_clients=num_evaluate_clients,
        min_available_clients=num_total_clients,
        min_fit_clients=cohort_size,
        min_evaluate_clients=num_evaluate_clients,
        evaluate_fn=federated_evaluation_function,
        # Bind this iteration's configs via default arguments: a bare closure would
        # read whatever `train_cfg` / `test_cfg` refer to at call time, which is
        # wrong if the callback outlives this loop iteration.
        on_fit_config_fn=lambda _, cfg=train_cfg: cfg,
        on_evaluate_config_fn=lambda _, cfg=test_cfg: cfg,
        initial_parameters=initial_parameters,
        fit_metrics_aggregation_fn=aggregate_weighted_average,
        evaluate_metrics_aggregation_fn=aggregate_weighted_average,
        federated_client_generator=federated_client_generator,
        server_learning_rate=server_learning_rate,
        server_momentum=server_momentum,
        accept_failures=accept_failures,
        target_accuracy=0.70,
        use_target_accuracy=True,
    )

    total_cohort_results.append((cohort_size, parameters_for_each_round, hist))
    # NOTE: save_experiment stores this argument under the "batch_size" key; in
    # this sweep that slot carries the cohort size, so read it as such on load.
    save_experiment(
        f"federated_cohort_results_{cohort_size}.pkl",
        batch_size=cohort_size,
        parameters_for_each_round=parameters_for_each_round,
        hist=hist,
    )

INFO flwr 2025-03-18 18:59:42,042 | experiments_simulation.py:232 | FL will execute for 10 rounds
INFO flwr 2025-03-18 18:59:42,049 | app.py:149 | Starting Flower simulation, config: ServerConfig(num_rounds=10, round_timeout=None)
INFO flwr 2025-03-18 18:59:42,051 | flwr_core.py:264 | Initializing global parameters
INFO flwr 2025-03-18 18:59:42,053 | server_returns_parameters.py:273 | Using initial parameters provided by strategy
INFO flwr 2025-03-18 18:59:42,055 | flwr_core.py:269 | Evaluating initial parameters


 11%|█         | 100/891 [00:03<00:25, 31.05it/s]
INFO flwr 2025-03-18 18:59:45,836 | flwr_core.py:272 | initial parameters (loss, other metrics): 413.6843070983887, {'accuracy': 0.0065625}
INFO flwr 2025-03-18 18:59:45,837 | flwr_core.py:280 | FL starting - Target accuracy: 0.7
DEBUG flwr 2025-03-18 18:59:45,842 | server_returns_parameters.py:223 | fit_round 1: strategy sampled 100 clients (out of 100)
INFO flwr 2025-03-18 18:59:45,844 | flwr_core.py:107 | cid: 937
INFO flwr 2025-03-18 18:59:45,847 | flwr_core.py:107 | cid: 1167
INFO flwr 2025-03-18 18:59:45,850 | flwr_core.py:107 | cid: 2167
INFO flwr 2025-03-18 18:59:45,851 | flwr_core.py:107 | cid: 2471
INFO flwr 2025-03-18 18:59:45,862 | flwr_core.py:107 | cid: 1287
INFO flwr 2025-03-18 18:59:45,870 | flwr_core.py:107 | cid: 51
INFO flwr 2025-03-18 18:59:45,872 | flwr_core.py:107 | cid: 757
INFO flwr 2025-03-18 18:59:45,875 | flwr_core.py:107 | cid: 1742
INFO flwr 2025-03-18 18:59:45,883 | flwr_core.py:107 | cid: 2782
INFO flwr 20