In [None]:
! pip install -e ..

In [None]:
from fl_g13.fl_pytorch.client_app import get_client_app
from fl_g13.fl_pytorch.server_app import get_server_app
from fl_g13.fl_pytorch.model import get_experiment_setting
from flwr.simulation import run_simulation
from fl_g13.fl_pytorch.constants import NUM_CLIENTS, DEFAULT_FRACTION_FIT, DEFAULT_NUM_ROUNDS, DEFAULT_LOCAL_EPOCHS

In [None]:
# LOCAL=True selects a tiny smoke-test configuration for quick runs on a laptop;
# LOCAL=False uses the project-wide defaults from fl_g13.fl_pytorch.constants.
LOCAL = True

if LOCAL:
    # Small, fast setup: 5 clients, 3 rounds, 1 local epoch each,
    # fraction_fit=1 (sample all available clients), and preview the data split.
    number_of_clients = 5
    number_of_rounds = 3
    local_epochs = 1
    fraction_fit = 1
    show_distribution = True
else:
    number_of_clients = NUM_CLIENTS
    number_of_rounds = DEFAULT_NUM_ROUNDS
    local_epochs = DEFAULT_LOCAL_EPOCHS
    fraction_fit = DEFAULT_FRACTION_FIT
    show_distribution = False

In [None]:
# Relative to the notebook's working directory; passed to get_server_app as
# `checkpoint_dir` (presumably where model checkpoints are written — confirm).
checkpoint_dir: str = "./../models/"

In [None]:
from flwr_datasets import FederatedDataset, partitioner
from fl_g13.fl_pytorch.datasets import show_partition_distribution
    
if show_distribution:
    # Split the CIFAR-10 training set IID into one partition per simulated client,
    # then visualise how the samples are spread across partitions.
    iid_partitioner = partitioner.IidPartitioner(num_partitions=number_of_clients)
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": iid_partitioner})
    p = fds.partitioners["train"]
    show_partition_distribution(p)

In [None]:
# Model, optimizer, loss, device (and scheduler) shared by client and server apps.
model, optimizer, criterion, device, scheduler = get_experiment_setting()
# NOTE(review): `scheduler` is returned but not passed to either app below, so
# this cell does not use it — confirm whether that is intended.

client_app = get_client_app(
    model,
    optimizer,
    criterion,
    device,
    partition="iid",
    local_epochs=local_epochs,
)

# Flower's FedAvg samples max(int(available * fraction_fit), min_fit_clients)
# clients per round. Pinning min_fit_clients to the total number of clients
# would silently override fraction_fit, so derive the floor from the fraction
# instead (at least one client). With fraction_fit == 1 nothing changes.
# NOTE(review): assumes get_server_app forwards these to a FedAvg-style strategy.
min_fit_clients = max(1, int(number_of_clients * fraction_fit))

server_app = get_server_app(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    num_rounds=number_of_rounds,
    min_available_clients=number_of_clients,
    min_fit_clients=min_fit_clients,
    checkpoint_dir=checkpoint_dir,
    fraction_fit=fraction_fit,
)

In [None]:
# Launch the in-process Flower simulation with one virtual SuperNode per client.
run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=number_of_clients,
)

In [None]:
from fl_g13.fl_pytorch.datasets import plot_results

# Retrieve the metrics the server strategy accumulated during the simulation.
# NOTE(review): `_strategy` is a private ServerApp attribute; this relies on
# Flower internals (and is None if the app was built from a server_fn).
# Consider having get_server_app expose the strategy explicitly.
strategy = server_app._strategy
results = strategy.results

# Only plot when something was actually recorded; calling plot_results on an
# empty/None result set would otherwise fail or draw an empty figure.
if results:
    print("Contenuto di results.json:", results)
    plot_results(results)
else:
    print("No results recorded by the strategy; nothing to plot.")