## Privacy-Preserving Federated Learning on CIFAR-10 Using Opacus

In [None]:
# Number of simulated clients. This value is written into the server config,
# determines how many client configs/agents are created, and is passed to the
# dataset loader via `dataset_kwargs.num_clients`.
num_clients = 10

In [None]:
from omegaconf import OmegaConf

# Location of the server configuration YAML (per its file name: FedAvg with
# Opacus for CIFAR-10), relative to this notebook.
server_config_file = "../../examples/resources/configs/cifar10/server_fedavg_opacus.yaml"

# Parse the YAML into an OmegaConf object, then echo it back so the full set
# of server settings is visible in the notebook output.
server_config = OmegaConf.load(server_config_file)
server_config_yaml = OmegaConf.to_yaml(server_config)
print(server_config_yaml)

In [None]:
# Training settings that the server hands to every client.
train_configs = server_config.client_configs.train_configs

# Point the clients at the loss and metric definitions, using paths relative
# to this notebook.
train_configs.loss_fn_path = "../../examples/resources/loss/celoss.py"
train_configs.metric_path = "../../examples/resources/metric/acc.py"

# Select the model definition file. The Opacus-specific ResNet file is used
# only when differential privacy is enabled AND the mechanism is "opacus";
# every other combination uses the plain ResNet file.
# `.get` with a default is used so that a configuration which omits `use_dp`
# or `dp_mechanism` falls back to the plain ResNet instead of raising on the
# missing key.
use_opacus = bool(train_configs.get("use_dp", False)) and (
    train_configs.get("dp_mechanism", None) == "opacus"
)
server_config.client_configs.model_configs.model_path = (
    "../../examples/resources/model/resnet_opacus.py"
    if use_opacus
    else "../../examples/resources/model/resnet.py"
)

# Keep the demo short and tell the server how many clients will participate.
server_config.server_configs.num_global_epochs = 3
server_config.server_configs.num_clients = num_clients

In [None]:
# Template client configuration; it is loaded once here and displayed as YAML
# so its contents can be inspected in the notebook output.
client_config_file = (
    "../../examples/resources/configs/cifar10/client_1.yaml"
)
client_config = OmegaConf.load(client_config_file)
client_config_yaml = OmegaConf.to_yaml(client_config)
print(client_config_yaml)

In [None]:
import copy

# Build one independent configuration per client from the shared template.
# A deep copy is taken so that per-client edits do not leak between clients.
client_configs = []
for i in range(num_clients):
    cfg = copy.deepcopy(client_config)

    # Client IDs are 1-based ("Client1", "Client2", ...), while the dataset
    # loader receives the 0-based index.
    cfg.client_id = f"Client{i + 1}"
    cfg.data_configs.dataset_path = "../../examples/resources/dataset/cifar10_dataset.py"

    dataset_kwargs = cfg.data_configs.dataset_kwargs
    dataset_kwargs.num_clients = num_clients
    dataset_kwargs.client_id = i
    # Only the first client has the visualization flag turned on.
    dataset_kwargs.visualization = i == 0

    client_configs.append(cfg)

In [None]:
from appfl.agent import ServerAgent, ClientAgent

# Instantiate the server agent from the server configuration, and one client
# agent for each per-client configuration prepared above.
server_agent = ServerAgent(server_agent_config=server_config)
client_agents = [ClientAgent(client_agent_config=cfg) for cfg in client_configs]

In [None]:
# Fetch the client-side settings held by the server once, then apply the same
# settings to every client agent.
client_config_from_server = server_agent.get_client_configs()
for agent in client_agents:
    agent.load_config(client_config_from_server)

In [None]:
# Ask the server for its initial global model parameters and load them into
# every client agent before training begins.
init_global_model = server_agent.get_parameters(serial_run=True)
for agent in client_agents:
    agent.load_parameters(init_global_model)

In [None]:
# [Optional] Report each client's local sample size to the server, keyed by
# the client's ID.
for agent in client_agents:
    server_agent.set_sample_size(
        client_id=agent.get_id(),
        sample_size=agent.get_sample_size(),
    )

In [None]:
# Serial federated training: repeat rounds until the server reports that
# training has finished.
while not server_agent.training_finished():
    # One Future per client, in the same order as `client_agents`.
    pending_global_models = []

    for agent in client_agents:
        # Run local training and collect the resulting parameters.
        agent.train()
        trained = agent.get_parameters()

        # `get_parameters` may return either the model alone or a tuple whose
        # first two items are the model and a metadata dict.
        metadata = {}
        local_model = trained
        if isinstance(trained, tuple):
            local_model = trained[0]
            metadata = trained[1]

        # "Send" the local model to the server without blocking. The returned
        # Future is resolved once the server has received local models from
        # all clients.
        future = server_agent.global_update(
            client_id=agent.get_id(),
            local_model=local_model,
            blocking=False,
            **metadata,
        )
        pending_global_models.append(future)

    # Every client has submitted its model; load the new global model into
    # each client for the next round.
    for agent, future in zip(client_agents, pending_global_models):
        agent.load_parameters(future.result())