## Problem Description

## Federated learning using APPFL


<img src="https://github.com/APPFL/APPFL/blob/main/docs/_static/logo/logo_small.png?raw=true" width="40%" alt="APPFL Logo">


In this tutorial, we use the Advanced Privacy-Preserving Federated Learning ([APPFL](https://github.com/APPFL/APPFL)) framework to launch a client for a federated learning experiment with two clients and one central server. The code that launches the server is in your workspace as **APPFL_Server.ipynb**. Start the server first, before running this client notebook.

### Dataset

In this example, we work with the gridfm graphkit dataset. **This notebook represents Client 1.**

### Training Settings

We use a gridfm graphkit model.

In [None]:
# [WARNING]: Please only run this cell ONCE at the beginning of your script.
# First: Change the working directory to the root of the repository and ignore warnings
import os
import warnings

os.chdir("../..")
warnings.filterwarnings("ignore")

In [None]:
# Set seed for reproducibility
import torch
import random
import numpy as np

seed_value = 1

random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

### 1. Create federated learning client agent from configurations

We need to update the `server_uri` to the URL obtained from the **server** notebook.

#### Steps:

- ##### Obtain the `server` URL:
  - For example, the server URL from the server notebook is `172.31.79.131:50051`.

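- ##### Update the code:
  - In the cell below that defines `SERVER_URI`, set it to the URL you obtained. The configuration cell then writes it into `client_agent_config.comm_configs.grpc_configs["server_uri"]`.
  - In the same cell, set `NUM_CLIENTS` to the total number of clients (two in this experiment) and `CLIENT_ID` to this client's one-based ID (`1` for this notebook).
  - For example:
    `SERVER_URI = "172.31.79.131:50051"`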
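A mistyped address only surfaces as an error once the client tries to connect, so it can help to sanity-check the value first. The sketch below uses a hypothetical helper, `parse_server_uri`, which is not part of APPFL; it only assumes the URI has the `host:port` form shown above.

```python
from typing import Tuple


def parse_server_uri(uri: str) -> Tuple[str, int]:
    """Split a 'host:port' string and check that the port is a valid TCP port."""
    # rpartition splits on the last ':' so a bracketed IPv6 host keeps its colons
    host, sep, port_str = uri.strip().rpartition(":")
    if not sep or not host:
        raise ValueError(f"expected 'host:port', got {uri!r}")
    if not port_str.isdigit() or not 0 < int(port_str) < 65536:
        raise ValueError(f"invalid port in {uri!r}")
    return host, int(port_str)


print(parse_server_uri("172.31.79.131:50051"))  # ('172.31.79.131', 50051)
```

A value such as `"172.31.79.131"` (no port) raises a `ValueError` immediately instead of failing later inside the gRPC client.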

The client configuration has four main parts:

- `client_id`: A unique identifier for the client
- `train_configs`: Client-specific training-related configurations, such as the device and logging directories
- `data_configs`: Information about the dataloader file that creates a PyTorch dataset for the gridfm graphkit data
- `comm_configs`: Information needed to connect to the server launched by the server notebook
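To make the four sections concrete, the sketch below mirrors them as plain Python dictionaries and applies the same overrides that the following cells make with OmegaConf. The skeleton is an assumption: the real `client_1.yaml` contains more keys, the `device` value and default URI here are placeholders, and `make_client_config` is a hypothetical helper, not an APPFL function.

```python
import copy
from typing import Any, Dict

# Hypothetical skeleton: only the keys this notebook touches are shown.
BASE_CONFIG: Dict[str, Any] = {
    "client_id": "Client1",
    "train_configs": {"device": "cpu"},  # placeholder value
    "data_configs": {"dataset_kwargs": {"num_clients": 2, "client_id": 0}},
    "comm_configs": {"grpc_configs": {"server_uri": "localhost:50051"}},  # placeholder
}


def make_client_config(
    base: Dict[str, Any], client_id: int, num_clients: int, server_uri: str
) -> Dict[str, Any]:
    """Return a copy of `base` with per-client overrides applied."""
    cfg = copy.deepcopy(base)  # leave the shared base untouched
    cfg["client_id"] = f"Client{client_id}"
    cfg["data_configs"]["dataset_kwargs"]["num_clients"] = num_clients
    # The dataset receives a zero-based index, hence client_id - 1
    cfg["data_configs"]["dataset_kwargs"]["client_id"] = client_id - 1
    cfg["comm_configs"]["grpc_configs"]["server_uri"] = server_uri
    return cfg


print(make_client_config(BASE_CONFIG, 1, 2, "172.31.79.131:50051")["data_configs"])
# {'dataset_kwargs': {'num_clients': 2, 'client_id': 0}}
```

Copying the base before overriding means one template can produce the configuration for every client without the clients sharing mutable state.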

In [None]:
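NUM_CLIENTS = 2  # Total number of clients in this experiment
CLIENT_ID = 1  # One-based ID of this client (this notebook is Client 1)
SERVER_URI = "172.31.79.131:50051"  # Replace with the URI printed by the server notebook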

In [None]:
from omegaconf import OmegaConf
from appfl.agent import ClientAgent

client_agent_config = OmegaConf.load(
    "./resources/configs/grid/client_1.yaml"
)
# SERVER_URI must hold the URI reported by the server notebook (set in the previous cell)
client_agent_config.comm_configs.grpc_configs["server_uri"] = str(SERVER_URI)
client_agent_config.client_id = f"Client{CLIENT_ID}"
# The dataset takes the total client count and a zero-based client index
client_agent_config.data_configs.dataset_kwargs["num_clients"] = NUM_CLIENTS
client_agent_config.data_configs.dataset_kwargs["client_id"] = CLIENT_ID - 1

print("==========Client Configuration==========")
print(OmegaConf.to_yaml(client_agent_config))
print("========================================")
client_agent = ClientAgent(client_agent_config=client_agent_config)
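The cell above passes `num_clients` and a zero-based `client_id` (`CLIENT_ID - 1`) to the dataset. How the gridfm graphkit loader actually splits its data is not described here; the sketch below assumes an even, contiguous split, which is only one common scheme, to show how a zero-based index selects each client's share. `client_partition` is a hypothetical helper.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def client_partition(samples: Sequence[T], num_clients: int, client_id: int) -> List[T]:
    """Return the contiguous share of `samples` for a zero-based `client_id`.

    Shares differ in size by at most one; the first
    `len(samples) % num_clients` clients each get one extra sample.
    """
    if not 0 <= client_id < num_clients:
        raise ValueError(f"client_id must be in [0, {num_clients}), got {client_id}")
    base, extra = divmod(len(samples), num_clients)
    start = client_id * base + min(client_id, extra)
    size = base + (1 if client_id < extra else 0)
    return list(samples[start:start + size])


# Client 1 in this notebook corresponds to index CLIENT_ID - 1 = 0.
print(client_partition(list(range(10)), num_clients=2, client_id=0))  # [0, 1, 2, 3, 4]
```

Because every client computes its share from the same `num_clients`, the shares are disjoint and together cover the whole dataset, as long as all clients use the same value.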

### 2. Create Client Communicator

Now we create a gRPC client communicator for sending requests to the server.

In [None]:
from appfl.comm.grpc import GRPCClientCommunicator

client_communicator = GRPCClientCommunicator(
    client_id=client_agent.get_id(),
    **client_agent_config.comm_configs.grpc_configs,
)
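If the communicator cannot reach the server, the requests in the training loop will fail. A quick way to confirm that something is listening at the server address is a plain TCP connection attempt, sketched below with the standard `socket` module. This check is deliberately loose: gRPC runs over TCP, so a successful connection only shows that the port is open, not that the APPFL server is the process behind it. `server_is_reachable` is a hypothetical helper.

```python
import socket


def server_is_reachable(host: str, port: int, timeout: float = 3.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        # Open and immediately close a connection; no gRPC traffic is sent
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:  # refused, unreachable, DNS failure, or timed out
        return False
```

For example, `server_is_reachable("172.31.79.131", 50051)` should return `True` once the server notebook is running at that address.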

### 3. Start the training loop by sending requests to the server

The main training loop sends four types of requests to the server:

(1) `get_configuration()`: Get general client configurations for local training

(2) `get_global_model(init_model=True)`: Get the initial global model for training

(3) `update_global_model()`: Send the trained local model to update the global model, and get the updated model back for further local training

(4) `invoke_custom_action(action="close_connection")`: Close the connection with the server
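The control flow of the training loop in the next cell can be exercised without a real server. The sketch below replaces the gRPC server with a hypothetical in-memory `ToyServer` whose methods are named after the four requests listed above; it is not APPFL's implementation. It assumes a single client, a two-element model, a toy local update that moves each weight halfway toward a fixed target, and a server that stops after a fixed number of rounds by returning a `status` of `"DONE"`.

```python
from typing import Dict, List, Tuple


class ToyServer:
    """Hypothetical in-memory stand-in for the APPFL server (not APPFL's API)."""

    def __init__(self, num_rounds: int) -> None:
        self.num_rounds = num_rounds
        self.round = 0
        self.global_model: List[float] = [0.0, 0.0]
        self.connected = True

    def get_configuration(self) -> Dict[str, float]:
        # Toy "client configuration": just a step size for local training
        return {"step_size": 0.5}

    def get_global_model(self, init_model: bool = False) -> List[float]:
        return list(self.global_model)

    def update_global_model(
        self, local_model: List[float], **metadata: object
    ) -> Tuple[List[float], Dict[str, str]]:
        # With one client, "aggregation" simply adopts the local model
        self.global_model = list(local_model)
        self.round += 1
        status = "DONE" if self.round >= self.num_rounds else "RUNNING"
        return list(self.global_model), {"status": status}

    def invoke_custom_action(self, action: str) -> None:
        if action == "close_connection":
            self.connected = False


def local_train(model: List[float], target: List[float], step: float) -> List[float]:
    # Toy local update: move each weight a fraction `step` toward `target`
    return [w + step * (t - w) for w, t in zip(model, target)]


server = ToyServer(num_rounds=3)
config = server.get_configuration()
model = server.get_global_model(init_model=True)
rounds_run = 0
while True:
    local_model = local_train(model, target=[1.0, 2.0], step=config["step_size"])
    new_global_model, metadata = server.update_global_model(local_model, num_samples=10)
    rounds_run += 1
    if metadata["status"] == "DONE":
        break
    model = new_global_model
server.invoke_custom_action(action="close_connection")

print(rounds_run, server.global_model, server.connected)  # 3 [0.875, 1.75] False
```

As in the notebook's loop, the client checks the status before loading the returned model, so the final global model stays on the server side and the connection is closed explicitly afterward.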

In [None]:
# Get general client configurations
client_config = client_communicator.get_configuration()
client_agent.load_config(client_config)

# Get initial global model parameters
init_global_model = client_communicator.get_global_model(init_model=True)
client_agent.load_parameters(init_global_model)

# Start local training loop
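while True:
    client_agent.train()
    # Handle get_parameters() returning either the model alone or a (model, metadata) tuple
    local_model = client_agent.get_parameters()
    if isinstance(local_model, tuple):
        local_model, metadata = local_model[0], local_model[1]
    else:
        metadata = {}
    new_global_model, metadata = client_communicator.update_global_model(
        local_model, **metadata
    )
    # The server reports status "DONE" once training is finished
    if metadata["status"] == "DONE":
        break
    client_agent.load_parameters(new_global_model)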

# Close connection
client_communicator.invoke_custom_action(action="close_connection")