## Secure Aggregation in Federated Learning on MNIST

In this notebook, we demonstrate how to use secure aggregation in a federated learning setup using the MNIST dataset. We will configure the server and client agents to enable secure aggregation and run a simple federated learning experiment.

In [1]:
# Number of simulated federated-learning clients. Later cells use this value to
# override the server's `num_clients`, to build one client config per client,
# and to create one ClientAgent per client.
num_clients = 10

In [2]:
from omegaconf import OmegaConf

# Server configuration used by this notebook. Judging by the file names, the
# active path is the secure-aggregation variant and the commented-out path is
# the plain FedAvg config (kept so the two setups can be swapped for comparison).
# server_config_file = "../../examples/resources/configs/mnist/server_fedavg.yaml"
server_config_file = "../../examples/resources/configs/mnist/server_fedavg_sec_agg.yaml"
server_config = OmegaConf.load(server_config_file)
# Print the loaded config as YAML so the effective settings are visible below.
print(OmegaConf.to_yaml(server_config))

client_configs:
  train_configs:
    trainer: VanillaTrainer
    mode: step
    num_local_steps: 100
    optim: Adam
    optim_args:
      lr: 0.001
    loss_fn_path: ./resources/loss/celoss.py
    loss_fn_name: CELoss
    do_validation: true
    do_pre_validation: true
    metric_path: ./resources/metric/acc.py
    metric_name: accuracy
    use_dp: false
    epsilon: 1
    clip_grad: false
    clip_value: 1
    clip_norm: 1
    train_batch_size: 64
    val_batch_size: 64
    train_data_shuffle: true
    val_data_shuffle: false
  model_configs:
    model_path: ./resources/model/cnn.py
    model_name: CNN
    model_kwargs:
      num_channel: 1
      num_classes: 10
      num_pixel: 28
  comm_configs:
    compressor_configs:
      enable_compression: false
      lossy_compressor: SZ2Compressor
      lossless_compressor: blosc
      error_bounding_mode: REL
      error_bound: 0.001
      param_cutoff: 1024
server_configs:
  num_clients: 2
  scheduler: SyncScheduler
  scheduler_kwargs:
   

In [3]:
# Override the resource paths loaded from the YAML (which use ./resources/...)
# with paths under ../../examples/resources.
train_settings = server_config.client_configs.train_configs
train_settings.loss_fn_path = "../../examples/resources/loss/celoss.py"
train_settings.metric_path = "../../examples/resources/metric/acc.py"

model_settings = server_config.client_configs.model_configs
model_settings.model_path = "../../examples/resources/model/cnn.py"

# Experiment size: number of global rounds, and the client count chosen in the
# first cell (replaces the `num_clients` value that came from the YAML).
server_settings = server_config.server_configs
server_settings.num_global_epochs = 10
server_settings.num_clients = num_clients

In [4]:
# Load a single client configuration file. It serves as the template that a
# later cell deep-copies and customizes for each of the `num_clients` clients.
client_config_file = "../../examples/resources/configs/mnist/client_1.yaml"
client_config = OmegaConf.load(client_config_file)
# Print the template as YAML so its default values are visible below.
print(OmegaConf.to_yaml(client_config))

client_id: Client1
train_configs:
  device: cpu
  logging_output_dirname: ./output
  logging_output_filename: result
data_configs:
  dataset_path: ./resources/dataset/mnist_dataset.py
  dataset_name: get_mnist
  dataset_kwargs:
    num_clients: 2
    client_id: 0
    partition_strategy: class_noniid
    visualization: true
    output_dirname: ./output
    output_filename: visualization.pdf
comm_configs:
  grpc_configs:
    server_uri: localhost:50051
    max_message_size: 1048576
    use_ssl: false



In [5]:
import copy

# Build one independent configuration per client from the loaded template.
client_configs = []
for i in range(num_clients):
    # Deep copy so that per-client edits never leak into the template or into
    # another client's configuration.
    cfg = copy.deepcopy(client_config)
    cfg.client_id = f"Client{i + 1}"  # display IDs are 1-based: Client1..ClientN

    cfg.data_configs.dataset_path = "../../examples/resources/dataset/mnist_dataset.py"

    dataset_kwargs = cfg.data_configs.dataset_kwargs
    dataset_kwargs.num_clients = num_clients
    dataset_kwargs.client_id = i  # 0-based index passed to the dataset function
    # Only the first client requests the data-partition visualization.
    dataset_kwargs.visualization = i == 0

    client_configs.append(cfg)

In [6]:
from appfl.agent import ServerAgent, ClientAgent

# One server agent built from the server config, and one client agent for each
# per-client configuration prepared in the previous cell.
server_agent = ServerAgent(server_agent_config=server_config)
client_agents = [ClientAgent(client_agent_config=cfg) for cfg in client_configs]

[34m[1mappfl: ✅[0m[2025-11-11 15:37:06,280 server]: Logging to ./output/result_Server_2025-11-11-15-37-06.txt
[34m[1mappfl: ✅[0m[2025-11-11 15:37:06,288 Client1]: Logging to ./output/result_Client1_2025-11-11-15-37-06.txt
[34m[1mappfl: ✅[0m[2025-11-11 15:37:13,064 Client2]: Logging to ./output/result_Client2_2025-11-11-15-37-13.txt
[34m[1mappfl: ✅[0m[2025-11-11 15:37:19,680 Client3]: Logging to ./output/result_Client3_2025-11-11-15-37-19.txt
[34m[1mappfl: ✅[0m[2025-11-11 15:37:26,320 Client4]: Logging to ./output/result_Client4_2025-11-11-15-37-26.txt
[34m[1mappfl: ✅[0m[2025-11-11 15:37:32,940 Client5]: Logging to ./output/result_Client5_2025-11-11-15-37-32.txt
[34m[1mappfl: ✅[0m[2025-11-11 15:37:39,616 Client6]: Logging to ./output/result_Client6_2025-11-11-15-37-39.txt
[34m[1mappfl: ✅[0m[2025-11-11 15:37:46,192 Client7]: Logging to ./output/result_Client7_2025-11-11-15-37-46.txt
[34m[1mappfl: ✅[0m[2025-11-11 15:37:52,956 Client8]: Logging to ./output/result

In [7]:
# Ask the server once for the client-side configuration it provides, then load
# that same configuration into every client agent.
client_config_from_server = server_agent.get_client_configs()
for agent in client_agents:
    agent.load_config(client_config_from_server)

In [8]:
# Fetch the initial global model from the server and load it into every client
# so that all clients start from the same parameters.
init_global_model = server_agent.get_parameters(serial_run=True)
for agent in client_agents:
    agent.load_parameters(init_global_model)

In [9]:
# [Optional] Report each client's local sample size to the server, keyed by the
# client's ID.
for agent in client_agents:
    sample_size = agent.get_sample_size()
    server_agent.set_sample_size(client_id=agent.get_id(), sample_size=sample_size)

In [10]:
# ---- Round-invariant secure-aggregation context ------------------------------
# These values are the same for every round, so compute them once up front
# instead of rebuilding them inside the round loop.

# Client IDs are read through get_id(), the same accessor used when submitting
# updates to the server below, so the ID list handed to each client matches the
# IDs the server receives.
all_client_ids = [str(c.get_id()) for c in client_agents]

# Shared secret placed in every client's runtime context.
# NOTE: hard-coded only to keep this demo self-contained; choose a per-job
# secret and keep it out of version control for real runs.
secure_agg_secret = b"APPFL_SECURE_AGG_v1"

# Per-client training-set sizes and their total (used if you want sample-size
# weighting). Assumes dataset sizes do not change between rounds.
local_num_examples_per_client = [len(c.train_dataset) for c in client_agents]
global_num_examples_sum = sum(local_num_examples_per_client)

# ---- Federated training rounds -----------------------------------------------
for round_id in range(server_config.server_configs.num_global_epochs):
    new_global_models = []
    for client_agent, local_num_examples in zip(
        client_agents, local_num_examples_per_client
    ):
        # Set the runtime context on the client before local training.
        client_agent.runtime_context = {
            "all_client_ids": all_client_ids,
            "round_id": round_id,
            "secure_agg_secret": secure_agg_secret,
            "global_num_examples_sum": global_num_examples_sum,
            "local_num_examples": local_num_examples,
        }

        # Client local training
        client_agent.train()
        local_model = client_agent.get_parameters()
        # get_parameters() may return either the model alone or a
        # (model, metadata) tuple; normalize to both pieces.
        if isinstance(local_model, tuple):
            local_model, metadata = local_model[0], local_model[1]
        else:
            metadata = {}

        # "Send" local model to server and get a Future object for the new global model
        # The Future object will be resolved when the server receives local models from all clients
        new_global_model_future = server_agent.global_update(
            client_id=client_agent.get_id(),
            local_model=local_model,
            blocking=False,
            **metadata,
        )
        new_global_models.append(new_global_model_future)

    # Load the new global model from the server into every client.
    for client_agent, new_global_model_future in zip(client_agents, new_global_models):
        client_agent.load_parameters(new_global_model_future.result())

[34m[1mappfl: ✅[0m[2025-11-11 15:38:13,914 Client1]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
[34m[1mappfl: ✅[0m[2025-11-11 15:38:14,837 Client1]:          0          Y                                          2.3006      15.9300
[34m[1mappfl: ✅[0m[2025-11-11 15:38:17,103 Client1]:          0          N     2.2656     0.4323        90.8109    15.4933      30.3500
[34m[1mappfl: ✅[0m[2025-11-11 15:38:20,275 Client2]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
[34m[1mappfl: ✅[0m[2025-11-11 15:38:21,196 Client2]:          0          Y                                          2.3006      15.9300
[34m[1mappfl: ✅[0m[2025-11-11 15:38:23,464 Client2]:          0          N     2.2670     0.3130        87.7500     9.8549      48.6000
[34m[1mappfl: ✅[0m[2025-11-11 15:38:26,618 Client3]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
[34m[1mappfl: ✅[0m[2025-