# Federated Model Training

In the previous notebook, **Centralized_Model_Training**, we explored the limitations of centralized training, where models can struggle to generalize when certain classes or patterns are excluded from the training data. We observed that models trained under these conditions might produce incorrect predictions, highlighting the need for a more robust approach to training.

In the **Centralized_Training** notebook, three models were trained on three different datasets. When tested on data they had not seen during training, the models performed poorly, and some even scored an accuracy of zero.

Data volume and diversity are critical for training good models, but in practice the data is often distributed across devices and organizations. Traditional training approaches, like the one used in **Centralized_Training**, assume that all data can be gathered in one place, which is often difficult or even impossible because of privacy concerns, regulations, and the sheer volume of data.

Federated Learning (FL) addresses this problem. It operates on distributed data, allowing models to be trained across various devices and organizations while the data stays where it was collected. FL can be applied across different industries and organizational silos, providing a way to improve model training without compromising data privacy.

### Objectives

1. Understand the fundamentals of federated learning and its advantages over traditional centralized training.
2. Implement federated learning using the Flower framework.
3. Train models on data distributed across different clients while maintaining data privacy.
4. Evaluate the performance of the federated model and compare it with the centralized approach.

Let's begin by setting up the necessary libraries and configurations for our federated learning experiment.


In [1]:
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import ndarrays_to_parameters, Context
from flwr.server import ServerApp, ServerConfig
from flwr.server import ServerAppComponents
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation

from Utils2 import *

### Downloading and Preparing the MNIST Dataset

After importing the necessary libraries and the `Utils2.py` file, we will download the MNIST dataset. Then, we will split the dataset into three parts and exclude specific digits from each part to create distinct training sets.

In [2]:
trainset = datasets.MNIST(
    "./MNIST_data/", download=True, train=True, transform=transform
)


# Calculate the total length and split the dataset into three parts
total_length = len(trainset)
split_size = total_length // 3
torch.manual_seed(42)
part1, part2, part3 = random_split(trainset, [split_size] * 3)


# Exclude specific digits from each part
part1 = exclude_digits(part1, excluded_digits=[1, 3, 7])
part2 = exclude_digits(part2, excluded_digits=[2, 5, 8])
part3 = exclude_digits(part3, excluded_digits=[4, 6, 9])

# Store the training sets in a list
train_sets = [part1, part2, part3]
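
The `exclude_digits` helper (and the `include_digits` helper used for the test subsets) comes from `Utils2.py`, which is not shown in this notebook. As a rough illustration of the idea only, the sketch below filters a toy list of `(sample, label)` pairs in plain Python. The names `exclude_labels` and `include_labels` are hypothetical, and the real helpers presumably operate on PyTorch datasets rather than lists.

```python
from typing import List, Sequence, Tuple

Sample = Tuple[str, int]  # (placeholder for an image, digit label)


def exclude_labels(data: Sequence[Sample], excluded: Sequence[int]) -> List[Sample]:
    """Keep only the samples whose label is NOT in `excluded`."""
    banned = set(excluded)
    return [(x, y) for x, y in data if y not in banned]


def include_labels(data: Sequence[Sample], included: Sequence[int]) -> List[Sample]:
    """Keep only the samples whose label IS in `included`."""
    wanted = set(included)
    return [(x, y) for x, y in data if y in wanted]


# Toy "dataset": one placeholder sample per digit 0-9.
toy_data = [(f"img_{d}", d) for d in range(10)]

client_part = exclude_labels(toy_data, [1, 3, 7])
held_out = include_labels(toy_data, [1, 3, 7])

print(sorted(y for _, y in client_part))  # labels this client trains on
print(sorted(y for _, y in held_out))     # labels this client never sees
```

Each client therefore never sees three of the ten digits, which is what makes the three training sets differ in their label distribution.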

### Downloading the Test Dataset

Next, we will download the MNIST test dataset, which will be used to evaluate the federated model. In addition to the full test set, we will build three test subsets, each containing only the digits that were excluded from one of the training sets. These subsets show how well the model handles digits that a given client never saw during training.

In [3]:
# Download the MNIST test dataset
testset = datasets.MNIST(
    "./MNIST_data/", download=True, train=False, transform=transform
)
print("Number of examples in `testset`:", len(testset))

# Build test subsets containing only the digits that were excluded from each training part
testset_137 = include_digits(testset, [1, 3, 7])
testset_258 = include_digits(testset, [2, 5, 8])
testset_469 = include_digits(testset, [4, 6, 9])

Number of examples in `testset`: 10000


#### Model Parameter Exchange in Federated Learning

In federated learning, exchanging model parameters between the server and clients is essential. When a client receives model parameters from the server, it updates its local model with those parameters. After completing local training, the client sends the updated parameters back to the server. To support this exchange, we need two functions, `set_weights()` and `get_weights()`, defined as follows:


In [4]:
# Sets the parameters of the model
def set_weights(net, parameters):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict(
        {k: torch.tensor(v) for k, v in params_dict}
    )
    net.load_state_dict(state_dict, strict=True)

# Retrieves the parameters from the model
def get_weights(net):
    ndarrays = [
        val.cpu().numpy() for _, val in net.state_dict().items()
    ]
    return ndarrays
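
Note that `set_weights` pairs each incoming array with a `state_dict` key purely by position (through `zip`), so the list produced by `get_weights` must keep the same layer order on every client and on the server. The sketch below illustrates that pairing with a plain `OrderedDict` and Python lists standing in for a real `state_dict` and NumPy arrays. The layer names and the `*_sketch` function names are hypothetical.

```python
from collections import OrderedDict

# Hypothetical stand-in for net.state_dict(): layer name -> weights.
state = OrderedDict([("fc.weight", [[0.0, 0.0]]), ("fc.bias", [0.0])])


def get_weights_sketch(state):
    # Drop the names; only the order of the values is preserved.
    return [value for _, value in state.items()]


def set_weights_sketch(state, parameters):
    # Re-attach names purely by position, mirroring the zip() in set_weights.
    return OrderedDict(zip(state.keys(), parameters))


received = [[[0.5, -0.5]], [0.1]]  # what a server might send back
new_state = set_weights_sketch(state, received)
print(new_state["fc.weight"], new_state["fc.bias"])

# Round trip: extracting and re-loading leaves the state unchanged.
round_trip = set_weights_sketch(state, get_weights_sketch(state))
print(round_trip == state)
```

Because only positions are exchanged, every participant has to build the same model architecture (here, `SimpleModel`) before loading the parameters.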

#### FlowerClient Class in Federated Learning

To connect our training and evaluation pipeline in federated learning, we define a `FlowerClient` class, which is a subclass of the `NumPyClient` class. The `FlowerClient` class typically includes two key methods: the `fit` method and the `evaluate` method.

- **`fit` Method**: This method is responsible for training the neural network using the provided parameters and the local training dataset.
- **`evaluate` Method**: This method evaluates the performance of the neural network using the provided parameters and the local test dataset.

The `FlowerClient` class is defined as follows:


In [5]:
class FlowerClient(NumPyClient):
    def __init__(self, net, trainset, testset):
        self.net = net
        self.trainset = trainset
        self.testset = testset

    # Train the model
    def fit(self, parameters, config):
        set_weights(self.net, parameters)  # Set the model weights
        train_model(self.net, self.trainset)  # Train the model on local dataset
        return get_weights(self.net), len(self.trainset), {}   # Return updated weights, number of examples, and additional information

    # Test the model
    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]):
        set_weights(self.net, parameters)  # Set the model weights
        loss, accuracy = evaluate_model(self.net, self.testset)    # Evaluate the model on the local test dataset
        return loss, len(self.testset), {"accuracy": accuracy}     # Return loss, number of examples, and accuracy

#### Client Function for Federated Learning

To enable the Flower framework to create client instances as needed, we use a function called `client_fn`. This function creates Flower client instances on demand, so a client only occupies memory and compute while it is actually training or evaluating.

By using this function, federated training can easily span hundreds of clients and can be efficiently simulated on a single machine. The Flower framework calls the `client_fn` function whenever it requires an instance of a specific client to invoke the `fit` or `evaluate` methods of the `FlowerClient` class.

The `client_fn` function is defined as follows:

In [6]:
# Client function
def client_fn(context: Context) -> Client:
    net = SimpleModel() # Create an instance of the SimpleModel
    partition_id = int(context.node_config["partition-id"])   # Retrieve the partition ID from context
    client_train = train_sets[partition_id]  # Get the corresponding training set for the client
    client_test = testset  # Use the shared test set for evaluation
    return FlowerClient(net, client_train, client_test).to_client()   # Return the FlowerClient instance

### Creating a ClientApp Instance

Next, we create an instance of `ClientApp` by passing the previously defined `client_fn`. The `ClientApp` serves as the entry point for everything happening on the client side, facilitating communication between the client and the server in the federated learning setup.

The following code snippet demonstrates how to create the `ClientApp` instance:

In [7]:
# Create an instance of ClientApp by passing the previously defined client_fn
client = ClientApp(client_fn=client_fn)

### Server-Side Evaluation Function

Now we turn to the server side. The server aggregates the models received from the clients, and we want to track how well the resulting global model performs. For this purpose, we define the following evaluation function, called `evaluate`, which the strategy will run on the server:

In [8]:

def evaluate(server_round, parameters, config):
    net = SimpleModel()
    set_weights(net, parameters)


    # Assess accuracy on the full MNIST test set and on the three digit subsets.
    _, accuracy = evaluate_model(net, testset)
    _, accuracy137 = evaluate_model(net, testset_137)
    _, accuracy258 = evaluate_model(net, testset_258)
    _, accuracy469 = evaluate_model(net, testset_469)

    log(INFO, "test accuracy on all digits: %.4f", accuracy)
    log(INFO, "test accuracy on [1,3,7]: %.4f", accuracy137)
    log(INFO, "test accuracy on [2,5,8]: %.4f", accuracy258)
    log(INFO, "test accuracy on [4,6,9]: %.4f", accuracy469)

    if server_round == 3:
        cm = compute_confusion_matrix(net, testset)
        plot_confusion_matrix(cm, "Final Global Model")

### Server Application and Strategy Definition

To create a server application, we first need to determine which strategy we want to implement. The strategy serves as an abstraction that implements the server-side federated learning algorithm. In this case, we will be using the FedAvg strategy (Federated Averaging), which is commonly employed in federated learning scenarios.
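
At its core, Federated Averaging combines the client models into a weighted average, where each client's weight is proportional to the number of examples it trained on (the second value returned by `fit`). The following is a minimal NumPy sketch of that aggregation step, not Flower's internal implementation. The function name `fedavg` and the toy numbers are assumptions made for illustration.

```python
from typing import List, Tuple

import numpy as np

# One client update: (list of layer arrays, number of local training examples)
ClientUpdate = Tuple[List[np.ndarray], int]


def fedavg(updates: List[ClientUpdate]) -> List[np.ndarray]:
    """Average each layer across clients, weighted by example count."""
    total_examples = sum(num_examples for _, num_examples in updates)
    num_layers = len(updates[0][0])
    averaged = []
    for layer in range(num_layers):
        weighted_sum = sum(
            weights[layer] * num_examples for weights, num_examples in updates
        )
        averaged.append(weighted_sum / total_examples)
    return averaged


# Two toy clients with a two-layer "model" each.
updates = [
    ([np.array([1.0, 1.0]), np.array([0.0])], 100),
    ([np.array([3.0, 3.0]), np.array([4.0])], 300),
]
global_weights = fedavg(updates)
print(global_weights)
```

The second client holds three times as many examples as the first, so the averaged weights land three quarters of the way toward its values.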

Here’s how we define the server function:

In [9]:
# Initialize a SimpleModel instance and extract its initial parameters with get_weights.
# These parameters serve as the starting point for the federated learning process.
net = SimpleModel()
params = ndarrays_to_parameters(get_weights(net))

def server_fn(context: Context):
    strategy = FedAvg(
        fraction_fit=1.0,  # Sample 100% of available clients for training
        fraction_evaluate=0.0,  # Skip client-side evaluation; the server evaluates instead
        initial_parameters=params,
        evaluate_fn=evaluate,  # The function used for server-side evaluation
    )
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(
        strategy=strategy,
        config=config,
    )

### Creating the Server Application

Now that we have defined our server function and strategy, we can create an instance of `ServerApp`. This app will manage the federated learning process, coordinating the training and evaluation between clients and the server.

Here's how we can instantiate the `ServerApp`:

In [10]:
# Create an instance of the ServerApp
server = ServerApp(server_fn=server_fn)

### Running the Simulation

A real-life federated learning system is distributed, typically consisting of one server and several distributed devices (clients). For simplicity in this example, we will simulate such a system by running everything (both server and clients) on a single machine. 

To achieve this, we can use the Flower function `run_simulation`, as shown below:

In [None]:
# Start the simulation, passing in the server and client apps.
# num_supernodes is the number of simulated client nodes (one per training partition).
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=3,
)

INFO :      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
INFO :      
INFO :      [INIT]
INFO :      Using initial global parameters provided by strategy
INFO :      Evaluating initial global parameters
INFO :      test accuracy on all digits: 0.1267
INFO :      test accuracy on [1,3,7]: 0.2275
INFO :      test accuracy on [2,5,8]: 0.1201
INFO :      test accuracy on [4,6,9]: 0.0380
INFO :      
INFO :      [ROUND 1]
INFO :      configure_fit: strategy sampled 3 clients (out of 3)


(ClientAppActor pid=22196) ++++++++++ client intilized ++++++++


### Final Remarks

In the above results, the **[INIT]** phase indicates that the system is using the initial global parameters provided by the strategy. It then evaluates these global model parameters using the evaluation function we defined earlier. This evaluation includes assessing the test accuracy on all digits as well as on the three different subsets we established previously.

The process continues through **[ROUND 1]**, **[ROUND 2]**, and **[ROUND 3]** respectively, showcasing how the model evolves over successive training rounds. Each round involves aggregating updates from clients, enhancing the model's performance by leveraging the diversity of data from different clients.
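
To build intuition for how the global model evolves over rounds, here is a deliberately simplified toy simulation in plain Python. It is not the notebook's MNIST experiment: each "model" is a single number, each client's data is summarized by a local mean, and local training is a few gradient steps toward that mean. All names and numbers are hypothetical.

```python
# Hypothetical setup: each client's data is summarized by its mean and size.
clients = [(0.0, 100), (6.0, 100), (12.0, 200)]  # (local mean, num examples)


def local_train(w, local_mean, steps=5, lr=0.2):
    # Gradient descent on 0.5 * (w - local_mean)^2 pulls w toward the local mean.
    for _ in range(steps):
        w -= lr * (w - local_mean)
    return w


def run_round(global_w):
    # Every client starts from the current global model and trains locally.
    results = [(local_train(global_w, mean), n) for mean, n in clients]
    # The server then takes the example-weighted average of the local models.
    total = sum(n for _, n in results)
    return sum(w * n for w, n in results) / total


w = 0.0  # initial global model
history = []
for server_round in range(1, 4):
    w = run_round(w)
    history.append(w)
    print(f"round {server_round}: global model = {w:.3f}")
```

No single client's data points at the example-weighted mean of all clients (7.5 in this toy setup), yet the global model moves closer to it every round. Real neural networks are not guaranteed to behave this cleanly, but the same mechanism lets the global MNIST model learn digits that individual clients never saw.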

This federated learning approach demonstrates the potential for improved model accuracy and robustness, especially when dealing with non-IID data (data that is not independent and identically distributed) spread across multiple sources.
