In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
from pathlib import Path
from typing import Any
from logging import INFO, DEBUG
import json
import os
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from scipy.signal import medfilt
from flwr.common import log, ndarrays_to_parameters
import matplotlib.pyplot as plt
import math

from src.common.client_utils import (
    load_femnist_dataset,
    get_network_generator_cnn as get_network_generator,
    get_device,
    get_model_parameters,
    aggregate_weighted_average,
)


from src.flwr_core import (
    set_all_seeds,
    get_paths,
    decompress_dataset,
    get_flower_client_generator,
    sample_random_clients,
    get_federated_evaluation_function,
    create_iid_partition,
)

from src.estimate import (
    compute_critical_batch,
)

from src.experiments_simulation import (
    run_simulation,
    centralized_experiment,
)

from src.utils import get_centralized_acc_from_hist

PathType = Path | str | None

In [2]:
set_all_seeds()

PATHS = get_paths()

HOME_DIR = PATHS["home_dir"]
DATASET_DIR = PATHS["dataset_dir"]
DATA_DIR = PATHS["data_dir"]
CENTRALIZED_PARTITION = PATHS["centralized_partition"]
CENTRALIZED_MAPPING = PATHS["centralized_mapping"]
FEDERATED_PARTITION = PATHS["federated_partition"]
FEDERATED_IID_PARTITION = PATHS["iid_partition"]

# extract dataset from tar.gz
decompress_dataset(PATHS)

In [None]:
max_clients = 1000
create_iid_partition(PATHS, num_clients=max_clients)

In [4]:
NETWORK_GENERATOR = get_network_generator()
SEED_NET = NETWORK_GENERATOR()
SEED_MODEL_PARAMS = get_model_parameters(SEED_NET)
CID_CLIENT_GENERATOR = get_flower_client_generator(NETWORK_GENERATOR, FEDERATED_IID_PARTITION)

In [None]:
# FL experiments
experiment_batch_sizes = [32, 64, 128, 256, 512]
cohort_sizes = [5, 10, 20, 50, 75, 100, 150]


# Federated configuration dictionary
federated_train_config = {
    "epochs": 10,
    "batch_size": 32,
    "client_learning_rate": 0.01,
    "weight_decay": 0,
    "num_workers": 0,
    "max_batches": 100,
    "return_params": False

}

federated_test_config: dict[str, Any] = {
    "batch_size": 32,
    "num_workers": 0,
    "max_batches": 100,
}

num_rounds = 10
num_total_clients = 100
num_evaluate_clients = 0
num_clients_per_round = 10

initial_parameters = ndarrays_to_parameters(SEED_MODEL_PARAMS)

federated_evaluation_function = get_federated_evaluation_function(
    batch_size=federated_test_config["batch_size"],
    num_workers=federated_test_config["num_workers"],
    model_generator=NETWORK_GENERATOR,
    criterion=nn.CrossEntropyLoss(),
    max_batches=None if "max_batches" not in federated_test_config else federated_test_config["max_batches"],
)

server_learning_rate = 1.0
server_momentum = 0.0
accept_failures = False


CID_CLIENT_GENERATOR = get_flower_client_generator(NETWORK_GENERATOR, FEDERATED_IID_PARTITION)

list_of_ids = sample_random_clients(
    num_total_clients, federated_train_config["batch_size"],
    CID_CLIENT_GENERATOR, max_clients=max_clients,
)

federated_client_generator = (
    get_flower_client_generator(
        NETWORK_GENERATOR, FEDERATED_IID_PARTITION, lambda seq_id: list_of_ids[seq_id]
    )
)

### Estimating critical batch size

In [None]:

metric_keys = ['training_time', 'samples_processed', 'noise_scale', 'train_loss', 'actual_batches']
import gc

B_simples = []
results = []
batch_sizes = [16, 32, 64, 128, 256]
for batch_size in batch_sizes:
    train_cfg = federated_train_config.copy()
    train_cfg["batch_size"] = batch_size
    ratio = np.sqrt(batch_size / 256)
    train_cfg["client_learning_rate"] = ratio * 0.01 # Same as centralized, but should be lower for FL
    train_cfg["return_params"] = True # needed to get g locals for estimation

    test_cfg = federated_test_config.copy()
    test_cfg["batch_size"] = batch_size

    
    parameters_for_each_round, hist = run_simulation(
        num_rounds = 10,
        num_total_clients = num_total_clients,
        num_clients_per_round = num_clients_per_round,
        num_evaluate_clients = num_evaluate_clients,
        min_available_clients = num_total_clients,
        min_fit_clients = num_clients_per_round,
        min_evaluate_clients = num_evaluate_clients,
        evaluate_fn = federated_evaluation_function,
        on_fit_config_fn = lambda _: train_cfg,
        on_evaluate_config_fn = lambda _: test_cfg,
        initial_parameters = initial_parameters,
        fit_metrics_aggregation_fn = aggregate_weighted_average,
        evaluate_metrics_aggregation_fn = aggregate_weighted_average,
        federated_client_generator = federated_client_generator,
        server_learning_rate=server_learning_rate,
        server_momentum=server_momentum,
        accept_failures=accept_failures,
        target_accuracy=0.60,
        use_target_accuracy=True
        )
    n_params = len(hist.metrics_distributed_fit.keys()) - 5
    param_keys = list(set(hist.metrics_distributed_fit.keys()) - set(metric_keys))
    hist_metrics = {key: hist.metrics_distributed_fit[key] for key in metric_keys}
    params = [hist.metrics_distributed_fit[key] for key in param_keys]
    del hist
    #gc.collect()

    res = (batch_size, parameters_for_each_round, hist_metrics, params)
    results.append(res)

In [None]:
import torch
import math

B_simples = [] # 244230
n_clients = 10

K = n_clients
alpha = 0.9

for k, res in enumerate(results):
    batch_size, parameters_for_each_round, hist_metrics, params = res
    B_small = batch_size
    B_big = num_clients_per_round * batch_size
    G_local = params
    n_rounds = len(params[0])
    params_filt = [params[i] for i in range(len(params)) if len(params[i]) == n_rounds]
    G_local_by_rounds = [[params_filt[i][j][1]['all'] for i in range(len(params_filt))] for j in range(n_rounds)]
    B_simples.append([0] * n_rounds)
    for round_idx, G_local in enumerate(G_local_by_rounds):
        G_local = [[el[1] for el in G_loc] for G_loc in G_local]
        G_local_filt = [G_local[i] for i in range(len(G_local)) if len(G_local[i]) == 10]
        G_local_filt = np.array(G_local_filt)
        G_local_filt = G_local_filt.reshape(K, -1)

        G_local_filt = [torch.tensor(G_local) for G_local in G_local_filt]
        
        local_norm_squared = torch.tensor([torch.norm(G_local)**2 for G_local in G_local_filt])

        GBsmall_squared = local_norm_squared.sum() / K

        G_big = sum(G_local_filt) / K

        GBbig_squared = torch.norm(G_big)**2 

        G2 = (1 / (B_big - B_small)) * (B_big * GBbig_squared - B_small * GBsmall_squared) 

        S = (B_small * B_big / (B_big - B_small)) * (GBbig_squared - GBsmall_squared)

        B_simple = S/G2

        B_simples[k][round_idx] = B_simple

B_simples = [[el.item() for el in B_simple] for B_simple in B_simples]
with open("B_simples.txt", "w") as f:
    f.write(str(B_simples))
print(B_simples)

# [[7331.360775759389, 10162.621287155156], [289988.8675930225, -76584.42877282941], [65100.4030868133, 15098.130444262237], [-52009.415823073214, -18557.47927118022, -11092165.511967614, -93741.82440415821], [-63007.125717336494, -74530.95902848525, -1084232.788576654, 15547940.564414758, 407759.4055155981, 272396.49553753383, -990024.0722842965, -127974.54799143146]]

In [None]:

metric_keys = ['training_time', 'samples_processed', 'noise_scale', 'train_loss', 'actual_batches']
import gc

B_simples = []
results = []
batch_sizes = [16, 32, 64, 128, 256]
batch_times = []
for batch_size in batch_sizes:
    train_cfg = federated_train_config.copy()
    train_cfg["batch_size"] = batch_size
    ratio = np.sqrt(batch_size / 256)
    train_cfg["client_learning_rate"] = ratio * 0.01 # Same as centralized, but should be lower for FL

    test_cfg = federated_test_config.copy()
    test_cfg["batch_size"] = batch_size

    local_list_of_ids = sample_random_clients(num_total_clients, train_cfg["batch_size"], CID_CLIENT_GENERATOR, max_clients=max_clients)
    local_federated_client_generator = get_flower_client_generator(NETWORK_GENERATOR, FEDERATED_IID_PARTITION, lambda seq_id: local_list_of_ids[seq_id])

    parameters_for_each_round, hist = run_simulation(
        num_rounds = 2,
        num_total_clients = num_total_clients,
        num_clients_per_round = num_clients_per_round,
        num_evaluate_clients = num_evaluate_clients,
        min_available_clients = num_total_clients,
        min_fit_clients = num_clients_per_round,
        min_evaluate_clients = num_evaluate_clients,
        evaluate_fn = federated_evaluation_function,
        on_fit_config_fn = lambda _: train_cfg,
        on_evaluate_config_fn = lambda _: test_cfg,
        initial_parameters = initial_parameters,
        fit_metrics_aggregation_fn = aggregate_weighted_average,
        evaluate_metrics_aggregation_fn = aggregate_weighted_average,
        federated_client_generator = local_federated_client_generator,
        server_learning_rate=server_learning_rate,
        server_momentum=server_momentum,
        accept_failures=accept_failures,
        )

    times = []
    for round_idx, round_metrics in hist.metrics_distributed_fit['training_time']:
        if round_idx > num_rounds:
            break
        round_times = [t for _, t in round_metrics['all']]
        times.append(np.mean(round_times))

    cumulative_time = np.sum(times)
    batch_times.append((batch_size, cumulative_time, times))
    #n_params = len(hist.metrics_distributed_fit.keys()) - 5
    #param_keys = list(set(hist.metrics_distributed_fit.keys()) - set(metric_keys))
    #hist_metrics = {key: hist.metrics_distributed_fit[key] for key in metric_keys}
    #params = [hist.metrics_distributed_fit[key] for key in param_keys]
    #del hist
    #gc.collect()

    #res = (batch_size, parameters_for_each_round, hist_metrics, params)
    #results.append(res)


In [None]:
B_simples = [[-18407.686925566362, -13762.505288528548, -29665.108319094634, -12927.35371816805, 4372.760959719874, 10805.953609141283, -21040.561981702886, 23238.6966912209, -20496.066521615783, 40973.22031090615], [-40754.83941407597, -94828.01133622219, -168606.6061223791, -13053.941180392021, -12196.393251686653, 139131.94393159807, -18836.993638551605, 14911.932928595123, 71887.69261004392, -429673.47555676906], [-50326.972540666844, -221078.50863027747, -12331.574093571728, -10559.19947656049, -533488.2914706912, 72097.91998244265, 2756775.10465039, -37545.04286075924, -215591.0786897387, -124943.4867790073], [-86654.8377399388, 70070.96612662714, 869667.1183916295, 62085.1567899368, 47205.20980176431, -64801.69123063742, 59619.394715874994, 43996.21136494777, 76026.20309461898, 370612.13195318024], [-37359.721294111616, -179548.41406089318, -266947.3063659719, -956503.1802564623, 272930.986020992, -264428.3445225671, 100874.68615423376, -94089.052502197, -138475.6462939103, -258445.3923542993]]

batch_sizes = [16, 32, 64, 128, 256]
B_simples = [[abs(el) for el in subl] for subl in B_simples]
# apply median filter for each sublist
B_simple_median = [medfilt(subl, 5) for subl in B_simples]
# plot nicely B_simple through the rounds
fig, ax = plt.subplots(figsize=(10, 5))
for bs, B_sim in zip(batch_sizes, B_simple_median):
    ax.plot(B_sim, label=f"Batch size: {bs}")
ax.set_xlabel('Round')
ax.set_ylabel('Batch Size')
ax.set_title('Critical Batch Size through Rounds')
ax.legend()
# log scale y
ax.set_yscale('log')
#ax.set_ylim(0, 5000)
plt.grid(True)
plt.show()



In [None]:
#for b, cumulative_round_times, total_times in batch_times:
#    print("Batch size: ", b)
#    print("Total time: ", cumulative_round_times)
#    print("Times per round: ", total_times)
"""

Batch size:  16
Total time:  22.73178415029806
Times per round:  [22.73178415029806]
Batch size:  32
Total time:  13.338960419098475
Times per round:  [13.338960419098475]
Batch size:  64
Total time:  5.612612966999677
Times per round:  [5.612612966999677]
Batch size:  128
Total time:  3.1369846047005012
Times per round:  [3.1369846047005012]
Batch size:  256
Total time:  1.3977464394006347
Times per round:  [1.3977464394006347]

"""



In [None]:
import pickle

def save_experiment(save_file_name, batch_size, parameters_for_each_round, hist):
    """Save experiment results using pickle.
    
    Args:
        save_file_name (str): Path to save the results
        batch_size (int): Batch size used in experiment
        parameters_for_each_round (list): List of model parameters for each round
        hist (History): Flower History object containing metrics
    """
    
    results_dict = {
        'batch_size': batch_size,
        'parameters_for_each_round': parameters_for_each_round,
        'history': hist
    }
    
    with open(save_file_name, 'wb') as f:  # Note: 'wb' for binary write mode
        pickle.dump(results_dict, f)

def load_experiment(file_name):
    """Load experiment results from a pickle file.
    
    Args:
        file_name (str): Path to the results file
        
    Returns:
        tuple: (batch_size, parameters_for_each_round, hist)
    """
    with open(file_name, 'rb') as f:  # Note: 'rb' for binary read mode
        results_dict = pickle.load(f)
    
    return (
        results_dict['batch_size'],
        results_dict['parameters_for_each_round'],
        results_dict['history'],
    )

total_batch_results = []

experiment_batch_sizes = [256, 128, 64, 32, 16]
for batch_size in experiment_batch_sizes:
    print(f"------------------------------------------------- BATCH SIZE: {batch_size} ---------------------------------------------------------------")
    train_cfg = federated_train_config.copy()
    train_cfg["batch_size"] = batch_size
    ratio = np.sqrt(batch_size / 256) # non-iid ratio, 32 not working, 64 not working, 128 not working, 
    learning_rate = ratio * 0.01 # Same as centralized, but should be lower for FL
    train_cfg["client_learning_rate"] = learning_rate

    test_cfg = federated_test_config.copy()
    test_cfg["batch_size"] = batch_size

    local_list_of_ids = sample_random_clients(num_total_clients, train_cfg["batch_size"], CID_CLIENT_GENERATOR, max_clients=max_clients)
    local_federated_client_generator = get_flower_client_generator(NETWORK_GENERATOR, FEDERATED_IID_PARTITION, lambda seq_id: local_list_of_ids[seq_id])

    parameters_for_each_round, hist = run_simulation(
        num_rounds = num_rounds,
        num_total_clients = num_total_clients,
        num_clients_per_round = num_clients_per_round,
        num_evaluate_clients = num_evaluate_clients,
        min_available_clients = num_total_clients,
        min_fit_clients = num_clients_per_round,
        min_evaluate_clients = num_evaluate_clients,
        evaluate_fn = federated_evaluation_function,
        on_fit_config_fn = lambda _: train_cfg,
        on_evaluate_config_fn = lambda _: test_cfg,
        initial_parameters = initial_parameters,
        fit_metrics_aggregation_fn = aggregate_weighted_average,
        evaluate_metrics_aggregation_fn = aggregate_weighted_average,
        federated_client_generator = local_federated_client_generator,
        server_learning_rate=server_learning_rate,
        server_momentum=server_momentum,
        accept_failures=accept_failures,
        target_accuracy=0.60,
        use_target_accuracy=True,
        )

    total_batch_results.append((batch_size, parameters_for_each_round, hist))
    save_experiment(fr"results/IID_federated_local_batch_results_{batch_size}.pkl", batch_size, parameters_for_each_round, hist)
    # open a file, and append to it batch size, learnign rate, number of rounds taken to reach target accuracy
    with open("bruh_file.txt", "a") as f:
        f.write(f"Batch size: {batch_size}, Learning rate: {learning_rate}, Number of rounds taken to reach target accuracy: {len(hist.metrics_centralized['accuracy'])}\n")

In [None]:
import pickle

def save_experiment(save_file_name, batch_size, parameters_for_each_round, hist):
    """Save experiment results using pickle.
    
    Args:
        save_file_name (str): Path to save the results
        batch_size (int): Batch size used in experiment
        parameters_for_each_round (list): List of model parameters for each round
        hist (History): Flower History object containing metrics
    """
    
    results_dict = {
        'batch_size': batch_size,
        'parameters_for_each_round': parameters_for_each_round,
        'history': hist
    }
    
    with open(save_file_name, 'wb') as f:  # Note: 'wb' for binary write mode
        pickle.dump(results_dict, f)

def load_experiment(file_name):
    """Load experiment results from a pickle file.
    
    Args:
        file_name (str): Path to the results file
        
    Returns:
        tuple: (batch_size, parameters_for_each_round, hist)
    """
    with open(file_name, 'rb') as f:  # Note: 'rb' for binary read mode
        results_dict = pickle.load(f)
    
    return (
        results_dict['batch_size'],
        results_dict['parameters_for_each_round'],
        results_dict['history'],
    )

total_cohort_results = []
cohort_sizes =  [5, 10, 20, 50, 75, 100]
for cohort_size in cohort_sizes:
    train_cfg = federated_train_config.copy()
    ratio = np.sqrt(cohort_size / 100)
    train_cfg["client_learning_rate"] = ratio * 0.01

    test_cfg = federated_test_config.copy()

    parameters_for_each_round, hist = run_simulation(
        num_rounds = 10,
        num_total_clients = num_total_clients,
        num_clients_per_round = cohort_size,
        num_evaluate_clients = num_evaluate_clients,
        min_available_clients = num_total_clients,
        min_fit_clients = cohort_size,
        min_evaluate_clients = num_evaluate_clients,
        evaluate_fn = federated_evaluation_function,
        on_fit_config_fn = lambda _: train_cfg,
        on_evaluate_config_fn = lambda _: test_cfg,
        initial_parameters = initial_parameters,
        fit_metrics_aggregation_fn = aggregate_weighted_average,
        evaluate_metrics_aggregation_fn = aggregate_weighted_average,
        federated_client_generator = federated_client_generator,
        server_learning_rate=server_learning_rate,
        server_momentum=server_momentum,
        accept_failures=accept_failures,
        target_accuracy=0.60,
        use_target_accuracy=True,
        )

    total_cohort_results.append((cohort_size, parameters_for_each_round, hist))
    save_experiment(f"results/IID_federated_cohort_results_{cohort_size}.pkl", cohort_size, parameters_for_each_round=parameters_for_each_round, hist=hist)

In [None]:
import pickle

def save_experiment(save_file_name, batch_size, parameters_for_each_round, hist):
    """Save experiment results using pickle.
    
    Args:
        save_file_name (str): Path to save the results
        batch_size (int): Batch size used in experiment
        parameters_for_each_round (list): List of model parameters for each round
        hist (History): Flower History object containing metrics
    """
    
    results_dict = {
        'batch_size': batch_size,
        'parameters_for_each_round': parameters_for_each_round,
        'history': hist
    }
    
    with open(save_file_name, 'wb') as f:  # Note: 'wb' for binary write mode
        pickle.dump(results_dict, f)

def load_experiment(file_name):
    """Load experiment results from a pickle file.
    
    Args:
        file_name (str): Path to the results file
        
    Returns:
        tuple: (batch_size, parameters_for_each_round, hist)
    """
    with open(file_name, 'rb') as f:  # Note: 'rb' for binary read mode
        results_dict = pickle.load(f)
    
    return (
        results_dict['batch_size'],
        results_dict['parameters_for_each_round'],
        results_dict['history'],
    )

total_global_batch_results = []
cs_bs_pairs = [(100, 4000), (100, 12000)]#[(5, 20), (20, 50), (50, 200), (100, 250), (100, 1000), (100, 2000), (100, 4000), (100, 12000)]
for cohort_size, batch_size in cs_bs_pairs:
    global_batch_size = batch_size * cohort_size
    train_cfg = federated_train_config.copy()
    ratio = np.sqrt(cohort_size * batch_size / 1e6)
    # if i multiply by batch size, i want to divide
    train_cfg["client_learning_rate"] = ratio * 0.01
    #train_cfg["max_batches"] = 1000

    test_cfg = federated_test_config.copy()

    parameters_for_each_round, hist = run_simulation(
        num_rounds = 10,
        num_total_clients = num_total_clients,
        num_clients_per_round = cohort_size,
        num_evaluate_clients = num_evaluate_clients,
        min_available_clients = num_total_clients,
        min_fit_clients = cohort_size,
        min_evaluate_clients = num_evaluate_clients,
        evaluate_fn = federated_evaluation_function,
        on_fit_config_fn = lambda _: train_cfg,
        on_evaluate_config_fn = lambda _: test_cfg,
        initial_parameters = initial_parameters,
        fit_metrics_aggregation_fn = aggregate_weighted_average,
        evaluate_metrics_aggregation_fn = aggregate_weighted_average,
        federated_client_generator = federated_client_generator,
        server_learning_rate=server_learning_rate,
        server_momentum=server_momentum,
        accept_failures=accept_failures,
        target_accuracy=0.60,
        use_target_accuracy=True,
        )

    total_global_batch_results.append((global_batch_size, parameters_for_each_round, hist))
    save_experiment(f"results/IID_federated_global_batch_results_{global_batch_size}.pkl", global_batch_size, parameters_for_each_round=parameters_for_each_round, hist=hist)