# Running Federated Learning Locally and on GCP's Vertex AI 

  

## Overview
Federated Learning is a privacy-preserving technique for training machine learning models.
When dealing with sensitive data (e.g., patient data), merging datasets from different institutions (clients) is often restricted.
Federated Learning enables the training of a machine learning model without sharing the underlying data.
This is achieved by training local models at each client and aggregating only the model weights. Instead of sending raw data to a central server, each client trains the model on its own data and shares only the model updates (e.g., weights or gradients) with the central server. The server aggregates these updates (e.g., using algorithms like Federated Averaging, or FedAvg) to improve the global model.

This approach is particularly useful in scenarios where data privacy, security, or bandwidth constraints make it impractical to centralize data.

In this tutorial we focus on centralized horizontal Federated Learning. Centralized FL has a coordinating server that controls the learning process and aggregates the model weights [1].
Horizontal FL means that the same features are available on each client (e.g., images) [2].
The counterpart is vertical FL, where clients hold different features for the same samples (e.g., patients).
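To make the distinction concrete, here is a minimal NumPy sketch of how one dataset would be partitioned in each setting. The 6 x 4 array is a made-up stand-in for a table of samples (rows) by features (columns).

```python
import numpy as np

# Toy table: 6 samples (rows, e.g. patients) x 4 features (columns)
data = np.arange(24).reshape(6, 4)

# Horizontal FL: clients hold different samples that share the same features
client_a_h, client_b_h = data[:3, :], data[3:, :]

# Vertical FL: clients hold different features of the same samples
client_a_v, client_b_v = data[:, :2], data[:, 2:]

print(client_a_h.shape, client_b_h.shape)  # (3, 4) (3, 4)
print(client_a_v.shape, client_b_v.shape)  # (6, 2) (6, 2)
```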

This is the general overview of a federated learning training process.
The image was taken from the [NVIDIA blog](https://blogs.nvidia.com/blog/what-is-federated-learning/) and slightly modified.

<img src="../../images/federated_learning_animation_still_white.png" width="800">

The training process can be split into an initialization step followed by [5 steps](https://flower.ai/docs/framework/tutorial-series-what-is-federated-learning.html):

0. Initialize the global model
1. Send global model to clients
2. Local training
3. Return model updates to coordinator
4. Aggregate model updates by averaging (FedAvg)
5. Repeat steps 1 to 4 until convergence
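The loop below is a minimal, self-contained sketch of steps 0-5 on a toy problem. It assumes a hypothetical linear regression model trained with plain NumPy gradient descent on two synthetic clients, not the PyTorch model used later in this tutorial. Only the weights ever leave a client; the raw `(X, y)` data stays local.

```python
import numpy as np

rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])

# Hypothetical toy setup: two clients, each with its own private (X, y) data
clients = []
for n in (40, 60):
    X = rng.normal(size=(n, 2))
    clients.append((X, X @ true_w + 0.01 * rng.normal(size=n)))


def local_train(w, X, y, lr=0.1, steps=5):
    """Step 2: a few gradient-descent steps on mean squared error, using local data only."""
    w = w.copy()
    for _ in range(steps):
        grad = 2 * X.T @ (X @ w - y) / len(y)
        w -= lr * grad
    return w


global_w = np.zeros(2)                                              # Step 0: initialize global model
sizes = np.array([len(y) for _, y in clients])
for _ in range(20):                                                 # Step 5: repeat
    local_ws = [local_train(global_w, X, y) for X, y in clients]    # Steps 1-3: send, train, return
    global_w = sum(w * n for w, n in zip(local_ws, sizes)) / sizes.sum()  # Step 4: FedAvg

print(np.round(global_w, 2))
```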

We have expanded these steps into seven for a more detailed walkthrough of our workflow.

## Prerequisites

This notebook was run on the machine type n1-highcpu-8 (8 vCPUs, 7.199 GB RAM) with the PyTorch kernel. Ensure that the Vertex AI and Cloud Storage APIs are enabled. If you need faster training, see [Spinning up a Vertex AI Notebook](https://github.com/STRIDES/NIHCloudLabGCP/blob/42ee2b7dbffce54e53a212d8c02ac16fd872c5be/docs/vertexai.md) to set up a notebook that uses GPUs.

## Learning Objectives
* Understand Federated Learning
* Learn to create centralized training and Federated Learning workflows locally
* Evaluate and visualize model performance of centralized training vs. Federated Learning
* Learn how to adapt the Federated Learning process to Google Cloud's Vertex AI

## Get Started

Install the following packages. In this tutorial we use the **PyTorch kernel** in a Vertex AI Workbench Jupyter notebook, which has PyTorch preinstalled. If you are not using the same setup, you can install the remaining packages by running `pip install torch pandas scikit-learn matplotlib ordereddict`.

The `kfp` package allows us to compile the functions we are about to write into a pipeline, which we will use in a later step.

In [None]:
! pip install google-cloud-aiplatform kfp

Import our packages and functions.

In [None]:
import os
import torch

import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
import matplotlib.pyplot as plt

from kfp import compiler
from torch.nn import Sequential
from collections import OrderedDict
from google.cloud import aiplatform
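from torch.utils.data import Dataset
from torch.utils.data import DataLoader
# note: kfp.dsl.Dataset (next line) shadows torch.utils.data.Dataset; use torch.utils.data.Dataset explicitly when subclassing
from kfp.dsl import component, pipeline, Output, Dataset, Model, Input, Artifact, Metrics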
from sklearn.preprocessing import StandardScaler

### Centralized Training

As a first step, we demonstrate training an ML model in the traditional, centralized way. Although this is not a prerequisite for Federated Learning, both approaches share many of the same steps, and we will compare the accuracy of the two (Centralized Training vs. Federated Learning). It also helps us confirm that our model is trainable.

To start centralized training, we first define a class called `BreastCancerDataset`. In this tutorial we use the Breast Cancer Wisconsin (Diagnostic) dataset [3]. It contains 30 features computed from digitized breast cancer images. The task is binary classification into the two classes "malignant" (=1) and "benign" (=0).

Once you have finished the tutorial and understand the main principles of Federated Learning, you can create your own dataset classes here.

#### 1. Data Prep
We have already split our data into training and validation datasets which you can see in the `data` directory. 

The main purpose of the class below is to standardize the feature columns of a DataFrame (excluding the first column, which is an ID, and the last column, which is the diagnosis label). Standardizing rescales every feature to zero mean and unit variance, so features with large numeric ranges do not dominate training.

It then converts the standardized features into a PyTorch tensor (`self.X`), and extracts the diagnosis labels (malignant = 1, benign = 0) from the last column of the DataFrame into a second PyTorch tensor (`self.y`).
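As a small illustration of what the scaler does, the sketch below standardizes two made-up features with very different ranges; afterwards each column has mean 0 and standard deviation 1, and the two columns end up on the same scale. Note that `BreastCancerDataset` fits a new `StandardScaler` on every DataFrame it receives, so the training and validation sets are each scaled with their own statistics.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Two made-up features on very different scales (e.g. a radius and an area)
X = np.array([[10.0, 500.0],
              [12.0, 700.0],
              [14.0, 900.0]], dtype=np.float32)

X_std = StandardScaler().fit_transform(X)

print(X_std.mean(axis=0))  # ~[0, 0]
print(X_std.std(axis=0))   # ~[1, 1]
```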

We will use this class for federated learning as well.

In [None]:
class BreastCancerDataset(torch.utils.data.Dataset):  # full path: `Dataset` is shadowed by kfp.dsl.Dataset in the imports
    def __init__(self, df):
        scaler = StandardScaler()
        self.X = torch.tensor(scaler.fit_transform(df.iloc[:,1:-1].values))   # first (ID) and last (diagnosis) columns are excluded
        self.y = torch.tensor(df.iloc[:,-1].values)                           # load the diagnosis (malignant=1, benign=0)
    
    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

After defining the class for our dataset, we load the data.
The data is split into a training and a validation subset.
Each loaded instance is additionally wrapped in a PyTorch `DataLoader()` object.
This serves the dataset to the model in batches during training.
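With its default `drop_last=False`, a `DataLoader` yields `ceil(n / batch_size)` batches, and the last one may be smaller than the rest. The hypothetical helper below reproduces that arithmetic without PyTorch, here for the 397 centralized training samples and `batch_size=50`:

```python
def batch_sizes(n_samples, batch_size, drop_last=False):
    """Hypothetical helper: sizes of the batches a DataLoader yields for n_samples."""
    full, rest = divmod(n_samples, batch_size)
    sizes = [batch_size] * full
    # the remainder forms one smaller final batch unless drop_last is set
    if rest and not drop_last:
        sizes.append(rest)
    return sizes


sizes = batch_sizes(397, 50)
print(len(sizes), sizes[-1])  # 8 47
```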

In [None]:
train_df = pd.read_csv(os.path.join("data", "full_train_data.csv"), dtype=np.float32)
val_df = pd.read_csv(os.path.join("data", "full_val_data.csv"), dtype=np.float32)

train_data = BreastCancerDataset(train_df)
val_data = BreastCancerDataset(val_df)

train_dataloader = DataLoader(train_data, batch_size=50, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=50, shuffle=False)

#### 2. Define Client class for model training and validation 

Now we define a `Client` class that is used for the training.
The client receives the model and the train and validation data loaders.
The class will be used for the centralized and federated training.

The class contains two functions:
- `train()` runs the training of the model
- `validate()` runs the validation of the model on the given `val_loader`

The `train` function in the `Client` class trains the client's local model for one epoch using its assigned training data. It iterates through the training dataset in batches, computes predictions using the model, and calculates the loss with the specified loss function (`criterion`). The function performs backpropagation by calculating gradients and updating the model's weights using the optimizer. It also tracks the number of correct predictions to compute the training accuracy for the epoch. Finally, it records the epoch's loss and accuracy in the client's `metrics` dictionary for later evaluation and visualization.

The `validate` function in the `Client` class evaluates the client's local model using its validation dataset. It sets the model to evaluation mode (`model.eval()`) to disable dropout and other training-specific behaviors. The function iterates over the validation dataset (similar to the training dataset), computes predictions, and calculates the loss for each batch without updating the model's weights. It also tracks the number of correct predictions to compute the validation accuracy. Finally, it records the average loss and accuracy for the validation epoch in the client's `metrics` dictionary for later analysis.
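One detail worth knowing: `nn.BCELoss` (with its default `reduction='mean'`) returns the mean loss of a batch, and both methods divide the summed batch losses by `len(loader)`, the number of batches. The reported epoch loss is therefore a mean of batch means, which differs slightly from the per-sample mean when the last batch is smaller. A toy NumPy sketch with made-up loss values:

```python
import numpy as np

# Made-up per-sample losses for 5 samples, batch_size=4 -> batches of 4 and 1
per_sample = np.array([1.0, 1.0, 1.0, 1.0, 3.0])
batches = [per_sample[:4], per_sample[4:]]

# What Client.train()/validate() compute: the mean of the per-batch mean losses
mean_of_batch_means = np.mean([b.mean() for b in batches])

# The per-sample average over the whole epoch
per_sample_mean = per_sample.mean()

print(mean_of_batch_means, per_sample_mean)  # 2.0 1.4
```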


In [None]:
class Client:
    def __init__(self, name, model, train_loader, val_loader, optimizer, criterion):
        self.name = name
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.metrics = dict({"train_acc": list(), "train_loss": list(), "val_acc": list(), "val_loss": list()})

        print(f"[INFO] Initialized client '{self.name}' with {len(train_loader.dataset)} train and {len(val_loader.dataset)} validation samples")
        
        
    def train(self):
        """
            Trains the model of the client for 1 epoch.
        """
        self.model.train()
        correct_predictions = 0
        running_loss = 0.0

        # iterate over training dataset
        for inputs, labels in self.train_loader:
            # make predictions
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            labels = torch.unsqueeze(labels, 1)

            # apply gradient
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()

            # calculate number of correct predictions
            predicted = torch.round(outputs)
            correct_predictions += (predicted == labels).sum().item()

        # calculate overall loss and acc.
        epoch_loss = running_loss / len(self.train_loader)
        accuracy = correct_predictions / len(self.train_loader.dataset)

        # save metrics
        self.metrics["train_acc"].append(accuracy)
        self.metrics["train_loss"].append(epoch_loss)
    
    def validate(self):
        """
            Validates the model of the client based on the given validation data loader.
        """
        self.model.eval()
        total_loss = 0
        correct_predictions = 0

        # iterate over validation data loader and make predictions
        with torch.no_grad():
            for inputs, labels in self.val_loader:
                outputs = self.model(inputs)
                labels = torch.unsqueeze(labels, 1)
                loss = self.criterion(outputs, labels)

                total_loss += loss.item()
                predicted = torch.round(outputs)
                correct_predictions += (predicted == labels).sum().item()

        # calculate overall loss and acc.
        average_loss = total_loss / len(self.val_loader)
        accuracy = correct_predictions / len(self.val_loader.dataset)

        # save metrics
        self.metrics["val_acc"].append(accuracy)
        self.metrics["val_loss"].append(average_loss)

#### 3. Defining the model

Now that we have our data set up and the client defined we can define our model!

The `SimpleNN` class defines a simple and small feedforward neural network for binary classification tasks. It contains three linear layers, with only a few nodes each. The network takes an input of a specified size (`n_input`), processes it through the layers, and outputs a single value between 0 and 1, representing the probability of the positive class. It is designed to be lightweight and efficient, making it suitable for use in both centralized and federated learning scenarios. The `forward` method defines how the input data flows through the network during training and inference.
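As a quick sanity check of how small this network is, the parameters can be counted by hand: each `nn.Linear(n_in, n_out)` holds `n_in * n_out` weights plus `n_out` biases. The sketch assumes `n_input=30`, as used below; in PyTorch, `sum(p.numel() for p in model.parameters())` should give the same total.

```python
# Layer sizes of SimpleNN with n_input=30: 30 -> 32 -> 16 -> 1
layer_sizes = [30, 32, 16, 1]

# Each nn.Linear(n_in, n_out) has n_in * n_out weights plus n_out biases
params_per_layer = [n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:])]

print(params_per_layer, sum(params_per_layer))  # [992, 528, 17] 1537
```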

After finishing the tutorial, feel free to come back here and implement your own models.



In [None]:
class SimpleNN(nn.Module):
    def __init__(self, n_input):
        super(SimpleNN, self).__init__()
        self.NN = Sequential(
            nn.Linear(n_input, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16,1),
            nn.Sigmoid()
        )

    def forward(self, x):
        probs = self.NN(x)   # sigmoid output: probabilities in (0, 1), not logits
        return probs

The `n_input` is set to 30 because the input layer of the `SimpleNN` neural network is designed to accept 30 features (or columns) as input. This matches the number of standardized feature columns in the dataset used for training and validation. In the context of the `BreastCancerDataset` class, the dataset contains 30 numerical features after excluding the ID column and the diagnosis label column. These 30 features are then fed into the neural network to make predictions.
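Rather than hard-coding 30, the input size could be derived from the data with the same slicing that `BreastCancerDataset` uses. A minimal sketch, with a hypothetical DataFrame (made-up column names) standing in for the CSV layout of ID column, features, and diagnosis:

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the CSV layout: ID column, 30 feature columns, diagnosis column
columns = ["id"] + [f"feature_{i}" for i in range(30)] + ["diagnosis"]
df = pd.DataFrame(np.zeros((3, len(columns)), dtype=np.float32), columns=columns)

# Same slicing as BreastCancerDataset: drop the first (ID) and last (label) columns
n_input = df.iloc[:, 1:-1].shape[1]
print(n_input)  # 30
```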

In [None]:
model = SimpleNN(n_input=30)

#### 4. Initializing the client

With the model available we can set up our client that is used for centralized training.

The `optimizer` is responsible for updating the model's parameters (weights) during training to minimize the loss function. It uses the gradients computed during backpropagation to adjust the weights in the direction that reduces the loss.

The `criterion` is the loss function used to measure how well the model's predictions match the true labels. It calculates the error between the predicted outputs and the actual targets, which the optimizer then tries to minimize.
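For reference, here is what the two pieces compute, sketched in NumPy with made-up inputs: binary cross-entropy, `loss = -mean(y * log(p) + (1 - y) * log(1 - p))`, and one SGD-with-momentum update in the form given in the PyTorch `torch.optim.SGD` documentation (buffer `b = momentum * b + grad`, then `w = w - lr * b`, with the default dampening of 0 and no Nesterov). These helpers are illustrations, not PyTorch's implementation.

```python
import numpy as np


def bce(p, y):
    """Binary cross-entropy averaged over the batch (what nn.BCELoss computes with reduction='mean')."""
    p, y = np.asarray(p, dtype=float), np.asarray(y, dtype=float)
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))


def sgd_momentum_step(param, grad, buf, lr=0.01, momentum=0.9):
    """One SGD-with-momentum update (no dampening, no Nesterov); the first step uses the raw gradient."""
    buf = grad if buf is None else momentum * buf + grad
    return param - lr * buf, buf


loss = bce([0.9, 0.2], [1, 0])   # confident, correct predictions give a small loss

w, buf = 0.0, None
for _ in range(2):               # with a constant gradient of 1.0 the step grows as momentum builds up
    w, buf = sgd_momentum_step(w, 1.0, buf)

print(round(loss, 4), round(w, 4))  # 0.1643 -0.029
```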

The model, data loaders, optimizer, and criterion are all passed to `central_client`, which we use for centralized training. After running this cell you should see output stating that the client has been initialized, which means it has been created successfully.

For this example the client has been assigned 397 samples for training and 172 samples for validation. 

In [None]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.BCELoss()
central_client = Client("central", model, train_dataloader, val_dataloader, optimizer, criterion)

#### 5. Begin training the model

Now we can start training. Using the `central_client` we just initialized, we run training and validation for 10 epochs: in each epoch we train the model once on all training samples and update its weights, then we validate the updated model.

In [None]:
epochs = 10
for i in range(epochs):
    print(f"Epoch {i}")
    # run one training epoch
    central_client.train()
    
    # run validation of training epoch
    central_client.validate()

#### 6. Plotting metrics to confirm convergence

After training, we plot the training and validation metrics to check whether the model converges, by monitoring loss and accuracy over the epochs. Let's look at what each metric means.

- **Training loss:** Measures how well the model is fitting the training data. A decreasing training loss over epochs indicates that the model is learning from the training data.
- **Training accuracy:** Tracks the proportion of correct predictions on the training dataset. An increasing training accuracy suggests that the model is improving its ability to classify the training samples correctly.
- **Validation loss:** Measures how well the model generalizes to unseen data (validation dataset). A decreasing validation loss indicates better generalization, while an increasing loss may suggest overfitting.
- **Validation accuracy:** Tracks the proportion of correct predictions on the validation dataset. An increasing validation accuracy indicates that the model is improving its performance on unseen data.

The model is considered to be converging when the training and validation losses stabilize (stop decreasing significantly) and the validation accuracy reaches a plateau.
If the validation loss starts increasing while the training loss continues to decrease, it may indicate overfitting, meaning the model is memorizing the training data instead of generalizing.
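These rules can also be checked programmatically on a client's `metrics` dictionary. The helper below is a hypothetical sketch (not part of the tutorial code) that reports the first epoch at which validation loss rises while training loss is still falling; the example values are invented, not results from this notebook.

```python
def first_overfitting_epoch(metrics):
    """
    Hypothetical helper: return the first epoch where validation loss rises while
    training loss still falls, or None if that never happens.
    Expects the layout of Client.metrics ("train_loss" and "val_loss" lists).
    """
    train, val = metrics["train_loss"], metrics["val_loss"]
    for epoch in range(1, len(val)):
        if val[epoch] > val[epoch - 1] and train[epoch] < train[epoch - 1]:
            return epoch
    return None


# Invented example values
example = {"train_loss": [0.60, 0.40, 0.30, 0.25, 0.20],
           "val_loss":   [0.55, 0.42, 0.35, 0.37, 0.40]}
print(first_overfitting_epoch(example))  # 3
```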

In [None]:
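def plot_metrics(client, op_save):
    """Plots a client's metrics; op_save may be None, a file path, or a KFP artifact with a `.path`."""
    plt.figure(figsize=(8, 4))
    for k, v in client.metrics.items():
        x_vals = range(len(v))
        plt.plot(x_vals, v, label=k)

    plt.ylim(bottom=0.0, top=1.0)
    plt.xlim(left=0)
    plt.xlabel("Epoch")
    plt.ylabel("Metric")
    plt.title(client.name)
    plt.legend()
    # save before plt.show(): in notebooks, saving after show() can write an empty image
    if op_save is not None:
        plt.savefig(getattr(op_save, "path", op_save))
    plt.show()
    plt.close()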

Run the code cell below to see a visual of our model metrics! 

**Note:** This function can also save the plot to a file, which we will do later in the Vertex AI Federated Learning portion of this tutorial. If you would like to do this now, replace the `None` value with a file path.

In [None]:
plot_metrics(central_client, None)

#### 7. Evaluating the model

Additionally, we evaluate the model on the validation dataset. The `run_prediction` function below iterates through the dataset, computes predictions using the model, and rounds the outputs to classify each sample as either 1 ("malignant") or 0 ("benign").

It compares the predicted labels with the true labels and counts the number of correct predictions.
The accuracy is computed as the ratio of correct predictions to the total number of samples in the validation dataset.
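The rounding step can be illustrated in isolation. In this NumPy sketch with made-up model outputs, rounding acts as a 0.5 threshold; note that both `np.round` and `torch.round` round halves to even, so an output of exactly 0.5 is classified as 0 (benign).

```python
import numpy as np

# Made-up sigmoid outputs for four samples and their true labels
probs = np.array([0.9, 0.2, 0.5, 0.7])
labels = np.array([1, 0, 1, 0])

# Rounding thresholds at 0.5 (ties go to the even value, so 0.5 -> 0)
predicted = np.round(probs)
accuracy = (predicted == labels).sum() / len(labels)

print(predicted.tolist(), accuracy)  # [1.0, 0.0, 0.0, 1.0] 0.5
```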

In [None]:
def run_prediction(model, test_data_path):
    model.eval()
    
    test_df = pd.read_csv(test_data_path, dtype=np.float32)
    test_data = BreastCancerDataset(test_df)
    test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)

    correct_predictions = 0

    # iterate over validation data loader and make predictions
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            outputs = model(inputs)
            labels = torch.unsqueeze(labels, 1)
            predicted = torch.round(outputs)
            correct_predictions += (predicted == labels).sum().item()

    # calculate overall acc.
    accuracy = correct_predictions / len(test_dataloader.dataset)
    
    print(f"{accuracy:.2f}")

Let's run the cell below to see our model's prediction accuracy!

In [None]:
print("Accuracy of the centrally trained model:")
run_prediction(central_client.model, 'data/full_val_data.csv')

The training and validation accuracy increase over the epochs, while the loss decreases.
This is a sign that our model converges, and we can move on to implementing federated learning.

### Implementing Federated Learning locally

Now that we have shown that our model is trainable with the given breast cancer dataset, we can implement federated learning!

In the `data` folder there are two clients already prepared for this tutorial (`client_0`, `client_1`).
The data was pre-split homogeneously across the two clients, stratified by diagnosis.

As with centralized training, we first prepare the data and then initialize each client, starting from the same initial global model (Steps 1-4).
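For FedAvg to be meaningful, each client has to train its own copy of the global model. If the same `nn.Module` object were passed to every `Client`, the clients would simply train one shared model in sequence, and averaging identical weights would change nothing. The difference is plain Python object aliasing, sketched below with a dict standing in for a model (a hypothetical simplification); `copy.deepcopy(model)` is a common way to copy a PyTorch module.

```python
import copy

# Hypothetical stand-in for a model: a dict of weights
global_model = {"w": [0.0, 0.0]}

# Passing the same object to every client: all "clients" update one shared model
shared = [global_model, global_model]
shared[0]["w"][0] = 1.0
print(global_model["w"], shared[1]["w"])  # [1.0, 0.0] [1.0, 0.0]

# Giving each client its own copy keeps local updates independent until aggregation
global_model = {"w": [0.0, 0.0]}
independent = [copy.deepcopy(global_model) for _ in range(2)]
independent[0]["w"][0] = 1.0
print(global_model["w"], independent[1]["w"])  # [0.0, 0.0] [0.0, 0.0]
```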



In [None]:
import copy

fed_model = SimpleNN(n_input=30)
# initialize clients, each with its own copy of the global model
clients = list()
for i in range(2):
    train_df = pd.read_csv(os.path.join("data", f"client_{i}", "train_data.csv"), dtype=np.float32)
    val_df = pd.read_csv(os.path.join("data", f"client_{i}", "val_data.csv"), dtype=np.float32)
    
    train_data = BreastCancerDataset(train_df)
    val_data = BreastCancerDataset(val_df)

    train_dataloader = DataLoader(train_data, batch_size=7, shuffle=True)
    val_dataloader = DataLoader(val_data, batch_size=7, shuffle=False)

    client_model = copy.deepcopy(fed_model)   # separate copy so local updates stay independent until aggregation
    optimizer = torch.optim.SGD(client_model.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.BCELoss()
    
    clients.append(Client(f"client_{i}", client_model, train_dataloader, val_dataloader, optimizer, criterion))

#### Define model aggregation

This step is new and unique to Federated Learning: we need a function that aggregates the model weights. In this tutorial we use the basic FedAvg algorithm for that [4]. For every parameter in the neural network, it computes the mean of the client values, weighted by each client's number of training samples.
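Written out, FedAvg sets each global parameter to `w = sum_k (n_k / n) * w_k`, where `w_k` is client k's value, `n_k` its number of training samples, and `n` the total over all clients. The weighted contributions have to be accumulated over all clients; the sketch below does this in NumPy for one hypothetical layer on two clients with made-up weights.

```python
import numpy as np

# Made-up weights of one layer after local training on two clients
client_weights = [np.array([1.0, 4.0]), np.array([2.0, 0.0])]
n_data_points = [100, 300]          # training samples per client

# FedAvg: w_global = sum_k (n_k / n) * w_k, with n = sum_k n_k
total = sum(n_data_points)
w_global = sum(w * (n / total) for w, n in zip(client_weights, n_data_points))

print(w_global.tolist())  # [1.75, 1.0]
```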

In [None]:
def fed_avg(global_state_dict, client_states, n_data_points):
    """
    Averages the weights of client models to update the global model by FedAvg.

    Args:
        global_state_dict: The state dict of the global PyTorch model.
        client_states: A list of PyTorch models state dicts representing client models.
        n_data_points: A list with the number of data points per client.

    Returns:
        The state dict of the updated global PyTorch model.
    """
    averaged_state_dict = OrderedDict()
    total = sum(n_data_points)

    for key in global_state_dict.keys():
        # weighted sum of this parameter over all clients
        averaged_state_dict[key] = sum(state[key] * (n / total) for state, n in zip(client_states, n_data_points))

    return averaged_state_dict

#### Definition of a coordination server

To orchestrate the federated learning process, we define a coordination server.
It has a single method, `run()`, that runs the federated training.
In each epoch it loops over the clients and trains one epoch on each client.
The updated client models are then aggregated by the FedAvg function.
The aggregated model is sent back to the clients before validation.

In [None]:
class FLServer:
    def __init__(self, model, clients):
        self.model = model
        self.clients = clients
        self.n_data_points = [len(client.train_loader.dataset) for client in self.clients]

    def run(self, epochs):
        for i in range(epochs):
            print(f"Epoch {i}")

            # Step 2 of figure at the beginning of the tutorial
            for client in self.clients:
                client.train()

            # aggregate the models using FedAvg (Step 3 & 4 of figure at the beginning of the tutorial)
            client_states = [client.model.state_dict() for client in self.clients]                 # Step 3
            aggregated_state = fed_avg(self.model.state_dict(), client_states, self.n_data_points) # Step 4
            self.model.load_state_dict(aggregated_state)
            
            # redistribute central model (Step 1 of figure at the beginning of the tutorial)
            for client in self.clients:
                client.model.load_state_dict(aggregated_state)

            # run validation of aggregated model
            for client in self.clients:
                client.validate()

            # repeat for n epochs (Step 5 of figure at the beginning of the tutorial)

#### Start Federated Learning Training

Now we can finally start our federated training by calling the `run()` method! This records training and validation accuracy and loss for each epoch on every client, which we then visualize in the next step.

In [None]:
fl_server = FLServer(fed_model, clients)
# distribute the central model to all clients (Step 1 of figure at the beginning of the tutorial)
for client in fl_server.clients:
    client.model.load_state_dict(fl_server.model.state_dict())

# run training with the server
fl_server.run(epochs=10)

#### Plot training metrics per client

After training is completed we can again have a look at the convergence of the model.
In this case we get one plot for each client, containing accuracy and loss.

In [None]:
for client in fl_server.clients:
    plot_metrics(client, None)

#### Compare Central vs. Federated Learning accuracy

Now we can compare the final performance of the centrally trained model against the model trained with federated learning.
The accuracies will likely not match exactly, but they should be close.

In [None]:
print("Centrally trained model accuracy:")
run_prediction(central_client.model, 'data/full_val_data.csv')
print()
print("Model trained with federated learning accuracy:")
run_prediction(fl_server.model, 'data/full_val_data.csv')

### FL Vertex AI Custom Pipelines


Google Cloud's **Vertex AI custom pipeline** is a user-defined machine learning workflow built using Kubeflow Pipelines (KFP) and executed on Google Cloud's Vertex AI Pipelines service. It allows you to orchestrate and automate ML tasks such as data preprocessing, model training, evaluation, and deployment. Custom pipelines are composed of modular components, each performing a specific task, and these components can exchange data through inputs and outputs. The pipeline is compiled into a JSON file and deployed to Vertex AI, where it runs in a managed environment with support for logging, monitoring, and metrics visualization. This approach enables scalable, reproducible, and efficient ML workflows integrated with Google Cloud's ecosystem.

**Kubeflow Pipelines** is an open-source platform for building, deploying, and managing machine learning (ML) workflows on Kubernetes. It provides a way to define and orchestrate ML workflows as a series of reusable components, where each component performs a specific task (e.g., data preprocessing, model training, evaluation). These workflows are defined in Python code and compiled into a format that can be executed on Kubernetes. This makes pipelines easy to orchestrate, reuse, and scale, and lets you integrate other tools and workflows and visualize metrics.

#### Create a bucket and add data

Before we create our pipeline, we need to add our data to a Cloud Storage bucket. Enter your bucket name in the following cell; the bucket is created in your currently configured Google Cloud project. Bucket names must be globally unique, otherwise bucket creation fails with an error.

**Note:** If you are not running this notebook in Vertex AI Workbench, you need to authenticate first: run `!gcloud auth login` before creating a bucket.

In [None]:
# Create a Google Cloud Storage bucket.
BUCKET='YOUR_BUCKET_NAME'

!gsutil mb gs://$BUCKET

Copy the client datasets (training and validation) and the full validation dataset to the bucket.

In [None]:
!gsutil cp -r data/client_0 gs://$BUCKET/data/
!gsutil cp -r data/client_1 gs://$BUCKET/data/
!gsutil cp data/full_val_data.csv gs://$BUCKET/data/

#### Training component

In Kubeflow Pipelines (KFP), **components** are the building blocks of a pipeline. Each component is a self-contained piece of code that performs a specific task in the machine learning workflow, such as data preprocessing, model training, evaluation, or deployment. Components are reusable, modular, and can be combined to create end-to-end pipelines.

Below we create two components: training and evaluation. Each component runs in its own isolated environment, so helper classes and functions have to be defined again inside each component.
Each component consists of the `@component` decorator and the function containing the steps of that component. The `@component` decorator may specify:
- The base image you would like the component to run in
- Packages to be installed in environment

The first component we create is the `training` component, which contains steps 1-5 plus the aggregation step unique to federated learning. The function signature labels the inputs, the outputs, and their data types. It's important to include the output data types, because these outputs become the inputs of our next component. This component outputs the training metrics and the model state.
 

In [None]:
@component(base_image="gcr.io/deeplearning-platform-release/pytorch-cu124.py310:latest",  
           packages_to_install=[
               "google-cloud-aiplatform",
               "ordereddict"
               ])
def training(
    clients_data_dir: str, 
    epochs: int,
    feature_inputs: int,
    num_client: int,
    client_metrics: Output[Metrics],
    client_metrics_output: Output[Dataset],
    trained_model_output: Output[Model]):

    import os
    from collections import OrderedDict
    import numpy as np
    import pandas as pd
    import torch
    import torch.nn as nn
    from torch.nn import Sequential
    from torch.utils.data import Dataset
    from torch.utils.data import DataLoader
    from sklearn.preprocessing import StandardScaler

    class BreastCancerDataset(Dataset):
        def __init__(self, df):
            scaler = StandardScaler()
            self.X = torch.tensor(scaler.fit_transform(df.iloc[:,1:-1].values))   # first (ID) and last (diagnosis) columns are excluded
            self.y = torch.tensor(df.iloc[:,-1].values)                           # load the diagnosis (malignant=1, benign=0)
        
        def __len__(self):
            return len(self.X)

        def __getitem__(self, idx):
            return self.X[idx], self.y[idx]
    
    class SimpleNN(nn.Module):
        def __init__(self, n_input):
            super(SimpleNN, self).__init__()
            self.NN = Sequential(
                nn.Linear(n_input, 32),
                nn.ReLU(),
                nn.Linear(32, 16),
                nn.ReLU(),
                nn.Linear(16, 1),
                nn.Sigmoid()
            )
        def forward(self, x):
            return self.NN(x)
    
    class Client:
        def __init__(self, name, model, train_loader, val_loader, optimizer, criterion):
            self.name = name
            self.model = model
            self.optimizer = optimizer
            self.criterion = criterion
            self.train_loader = train_loader
            self.val_loader = val_loader
            self.metrics = dict({"train_acc": list(), "train_loss": list(), "val_acc": list(), "val_loss": list()})

            print(f"[INFO] Initialized client '{self.name}' with {len(train_loader.dataset)} train and {len(val_loader.dataset)} validation samples")
            
            
        def train(self):
            """
                Trains the model of the client for 1 epoch.
            """
            self.model.train()
            correct_predictions = 0
            running_loss = 0.0

            # iterate over training dataset
            for inputs, labels in self.train_loader:
                # make predictions
                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                labels = torch.unsqueeze(labels, 1)

                # apply gradient
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()

                # calculate number of correct predictions
                predicted = torch.round(outputs)
                correct_predictions += (predicted == labels).sum().item()

            # calculate overall loss and acc.
            epoch_loss = running_loss / len(self.train_loader)
            accuracy = correct_predictions / len(self.train_loader.dataset)

            # save metrics
            self.metrics["train_acc"].append(accuracy)
            self.metrics["train_loss"].append(epoch_loss)
        
        def validate(self):
            """
                Validates the model of the client based on the given validation data loader.
            """
            self.model.eval()
            total_loss = 0
            correct_predictions = 0

            # iterate over validation data loader and make predictions
            with torch.no_grad():
                for inputs, labels in self.val_loader:
                    outputs = self.model(inputs)
                    labels = torch.unsqueeze(labels, 1)
                    loss = self.criterion(outputs, labels)

                    total_loss += loss.item()
                    predicted = torch.round(outputs)
                    correct_predictions += (predicted == labels).sum().item()

            # calculate overall loss and acc.
            average_loss = total_loss / len(self.val_loader)
            accuracy = correct_predictions / len(self.val_loader.dataset)

            # save metrics
            self.metrics["val_acc"].append(accuracy)
            self.metrics["val_loss"].append(average_loss)
    
    def fed_avg(global_state_dict, client_states, n_data_points):
        """
        Averages the weights of client models to update the global model by FedAvg.

        Args:
            global_state_dict: The state dict of the global PyTorch model.
            client_states: A list of PyTorch models state dicts representing client models.
            n_data_points: A list with the number of data points per client.

        Returns:
            The state dict of the updated global PyTorch model.
        """
        averaged_state_dict = OrderedDict()
        total = sum(n_data_points)

        for key in global_state_dict.keys():
            # weighted sum of this parameter over all clients
            averaged_state_dict[key] = sum(state[key] * (n / total) for state, n in zip(client_states, n_data_points))

        return averaged_state_dict

    class FLServer:
        def __init__(self, model, clients):
            self.model = model
            self.clients = clients
            self.n_data_points = [len(client.train_loader.dataset) for client in self.clients]

        def run(self, epochs):
            for i in range(epochs):
                print(f"Epoch {i}")

                # Step 2 of figure at the beginning of the tutorial
                for client in self.clients:
                    client.train()

                # aggregate the models using FedAvg (Step 3 & 4 of figure at the beginning of the tutorial)
                client_states = [client.model.state_dict() for client in self.clients]                 # Step 3
                aggregated_state = fed_avg(self.model.state_dict(), client_states, self.n_data_points) # Step 4
                self.model.load_state_dict(aggregated_state)
                
                # redistribute central model (Step 1 of figure at the beginning of the tutorial)
                for client in self.clients:
                    client.model.load_state_dict(aggregated_state)

                # run validation of aggregated model
                for client in self.clients:
                    client.validate()

                # repeat for n epochs (Step 5 of figure at the beginning of the tutorial)

    fed_model = SimpleNN(n_input=feature_inputs)
    # initialize clients
    clients = list()
    for i in range(num_client):
        train_df = pd.read_csv(os.path.join(clients_data_dir, f"client_{i}", "train_data.csv"), dtype=np.float32)
        val_df = pd.read_csv(os.path.join(clients_data_dir, f"client_{i}", "val_data.csv"), dtype=np.float32)
        
        train_data = BreastCancerDataset(train_df)
        val_data = BreastCancerDataset(val_df)

        train_dataloader = DataLoader(train_data, batch_size=7, shuffle=True)
        val_dataloader = DataLoader(val_data, batch_size=7, shuffle=False)

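        # each client needs its own model instance and an optimizer bound to it;
        # sharing fed_model would make all clients train the same weights one after another
        client_model = SimpleNN(n_input=feature_inputs)
        optimizer = torch.optim.SGD(client_model.parameters(), lr=0.01, momentum=0.9)
        criterion = nn.BCELoss()
        
        clients.append(Client(f"client_{i}", client_model, train_dataloader, val_dataloader, optimizer, criterion))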
    
    fl_server = FLServer(fed_model, clients)
    # distribute the central model to all clients (Step 1 of figure at the beginning of the tutorial)
    for client in fl_server.clients:
        client.model.load_state_dict(fl_server.model.state_dict())

    #run training with server
    fl_server.run(epochs=epochs)

    #save model
    torch.save(fl_server.model.state_dict(), trained_model_output.path)
    
    # Create a list to store the clients' metrics for visualization
    clients_data = []

    # Iterate through each client in fl_server.clients
    for client in fl_server.clients:
        # Create a dictionary for the current client
        client_data = {
            "name": client.name,
            "metrics": {}
        }

        # Add each metric to the "metrics" dictionary
        for metric_name, values in client.metrics.items():
            client_metrics.log_metric(f"{client.name} - {metric_name}", values)
            client_data["metrics"][metric_name] = values

        # Append the client's data to the list
        clients_data.append(client_data)  

    import json
    with open(client_metrics_output.path, "w") as f:
        json.dump(clients_data, f, indent=2)     
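FedAvg [4] computes the new global parameters as a weighted average of the client parameters, where each client's weight is its share of the training data: $w = \sum_k \frac{n_k}{n} w_k$ with $n = \sum_k n_k$. The framework-free sketch below illustrates that arithmetic with plain Python lists standing in for tensors. `fed_avg_lists` is a hypothetical helper for illustration only, not part of the training component.

```python
from typing import Dict, List


def fed_avg_lists(client_states: List[Dict[str, List[float]]],
                  n_data_points: List[int]) -> Dict[str, List[float]]:
    """Weighted average of client parameters (FedAvg), using lists instead of tensors."""
    total = sum(n_data_points)
    averaged: Dict[str, List[float]] = {}
    for key in client_states[0]:
        # start each parameter at zero, then add every client's weighted contribution
        averaged[key] = [0.0] * len(client_states[0][key])
        for state, n in zip(client_states, n_data_points):
            weight = n / total
            for j, value in enumerate(state[key]):
                averaged[key][j] += value * weight
    return averaged


# two toy clients: client_1 holds three times as much data as client_0
clients = [{"w": [1.0, 2.0]}, {"w": [3.0, 6.0]}]
print(fed_avg_lists(clients, n_data_points=[100, 300]))  # → {'w': [2.5, 5.0]}
```

Accumulating from zero matters: simply assigning each client's scaled weights in turn would keep only the last client's contribution (here `[2.25, 4.5]`).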

Looking at the last steps of the training component, you will notice that besides saving the model, the function also logs the training and validation accuracy and loss metrics for each client. These metrics are displayed in the Google Cloud console (Vertex AI > Pipelines > Metrics) and are also exported as a JSON dataset, which the next component uses to create a visual of our metrics.
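The exported file is a JSON list with one entry per client, each holding a `name` and a `metrics` dictionary that maps the `train_acc`, `train_loss`, `val_acc`, and `val_loss` histories to one value per epoch. The sketch below writes and reads back a file with that shape using only the standard library. The numbers are made-up placeholders, and `write_client_metrics` / `read_client_metrics` are hypothetical helper names.

```python
import json
import os
import tempfile
from typing import Dict, List


def write_client_metrics(path: str, clients: Dict[str, Dict[str, List[float]]]) -> None:
    """Serialize per-client metric histories in the same shape as the training component."""
    clients_data = [{"name": name, "metrics": metrics} for name, metrics in clients.items()]
    with open(path, "w") as f:
        json.dump(clients_data, f, indent=2)


def read_client_metrics(path: str) -> List[dict]:
    """Load the list that the visualization component iterates over."""
    with open(path, "r") as f:
        return json.load(f)


# placeholder values, one entry per epoch (not real training results)
example = {
    "client_0": {"train_acc": [0.80, 0.90], "train_loss": [0.50, 0.30],
                 "val_acc": [0.85, 0.88], "val_loss": [0.45, 0.35]},
    "client_1": {"train_acc": [0.75, 0.86], "train_loss": [0.55, 0.33],
                 "val_acc": [0.82, 0.87], "val_loss": [0.47, 0.36]},
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "client_metrics.json")
    write_client_metrics(path, example)
    loaded = read_client_metrics(path)

for client in loaded:
    print(client["name"], sorted(client["metrics"]))
```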

#### Visualization component

The visualization component (Step 6) takes the metrics output from the previous component and creates a plot for each client. All client plots are combined into one image, which is saved as a pipeline artifact in the Google Cloud Storage bucket used as the pipeline root.

In [None]:
@component(base_image="gcr.io/deeplearning-platform-release/pytorch-cu124.py310:latest",  
           packages_to_install=[
               "google-cloud-aiplatform",
               "matplotlib"
               ])
def visualization(metrics_input: Input[Dataset], plot_save: Output[Artifact]):
    import json
    import matplotlib.pyplot as plt

    # load the per-client metrics exported by the training component
    with open(metrics_input.path, "r") as f:
        metrics = json.load(f)

    # Create subplots: one for each client
    num_clients = len(metrics)
    fig, axes = plt.subplots(num_clients, 1, figsize=(8, 4 * num_clients), squeeze=False)
    
    # Iterate over clients and plot their metrics
    for i, client in enumerate(metrics):
        ax = axes[i, 0]  # Access the subplot for the current client
        for k, v in client["metrics"].items():
            x_vals = range(len(v))
            ax.plot(x_vals, v, label=k)

        # Customize the subplot
        ax.set_title(f"Metrics for {client['name']}")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Value")
        ax.legend()
        ax.grid(True)

    # Adjust layout and save the combined plot
    plt.tight_layout()
    plt.savefig(plot_save.path)
    print(f"Combined plot saved to {plot_save.path}")
    plt.close()

#### Evaluation component

Next, we create the `evaluation` component, which loads the model from the training component and uses the `run_prediction` function to evaluate its accuracy on the test data (Step 7). The accuracy is logged with the `log_metric` method of the `Metrics` output, so you can view it in the console by going to Vertex AI > Pipelines and clicking the metrics icon in your pipeline workflow.

In [None]:
@component(base_image="gcr.io/deeplearning-platform-release/pytorch-cu124.py310:latest",
           packages_to_install=[
               "google-cloud-aiplatform"
               ])
def evaluation(
    model_input: Input[Model],
    test_data_path: str,
    feature_inputs: int,
    eval_metrics: Output[Metrics]
    ):
    import pandas as pd
    import numpy as np
    import torch
    import torch.nn as nn
    from torch.nn import Sequential
    from sklearn.preprocessing import StandardScaler
    from torch.utils.data import DataLoader
    from torch.utils.data import Dataset

    class BreastCancerDataset(Dataset):
        def __init__(self, df):
            scaler = StandardScaler()
            self.X = torch.tensor(scaler.fit_transform(df.iloc[:,1:-1].values))   # first (ID) and last (diagnosis) columns are excluded
            self.y =  torch.tensor(df.iloc[:,-1].values)                          # load the diagnosis (malignant=1, benign=0)
        
        def __len__(self):
            return len(self.X)

        def __getitem__(self, idx):
            return self.X[idx], self.y[idx]
    
    class SimpleNN(nn.Module):
        def __init__(self, n_input):
            super(SimpleNN, self).__init__()
            self.NN = Sequential(
                nn.Linear(n_input, 32),
                nn.ReLU(),
                nn.Linear(32, 16),
                nn.ReLU(),
                nn.Linear(16, 1),
                nn.Sigmoid()
            )
        def forward(self, x):
            return self.NN(x)

    def run_prediction(model, test_data_path):
        model.eval()
        
        test_df = pd.read_csv(test_data_path, dtype=np.float32)
        test_data = BreastCancerDataset(test_df)
        test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)

        correct_predictions = 0

        # iterate over the test data loader and make predictions
        with torch.no_grad():
            for inputs, labels in test_dataloader:
                outputs = model(inputs)
                labels = torch.unsqueeze(labels, 1)
                predicted = torch.round(outputs)
                correct_predictions += (predicted == labels).sum().item()

        # calculate overall acc.
        accuracy = correct_predictions / len(test_dataloader.dataset)
        # log the accuracy and number of test samples so they appear in the Vertex AI console
        eval_metrics.log_metric("accuracy", accuracy)
        eval_metrics.log_metric("total_samples", len(test_dataloader.dataset))
        print(f"{accuracy:.2f}")

    # rebuild the model architecture and load the trained global weights
    model = SimpleNN(n_input=feature_inputs)
    model.load_state_dict(torch.load(model_input.path, map_location="cpu"))

    print("Model trained with federated learning accuracy:")
    run_prediction(model, test_data_path)
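`run_prediction` turns the sigmoid outputs into class labels with `torch.round`, so probabilities above 0.5 become 1 (malignant) and the rest become 0 (benign); accuracy is then the fraction of matching labels. A torch-free sketch of the same computation follows. `binary_accuracy` is a hypothetical helper, and Python's built-in `round` is used because, like `torch.round`, it rounds halves to the nearest even number, so an output of exactly 0.5 maps to 0.

```python
from typing import Sequence


def binary_accuracy(probabilities: Sequence[float], labels: Sequence[int]) -> float:
    """Fraction of samples whose rounded probability equals the true label."""
    if len(probabilities) != len(labels):
        raise ValueError("probabilities and labels must have the same length")
    # for p in [0, 1], round() maps p <= 0.5 to 0 and p > 0.5 to 1
    correct = sum(int(round(p) == y) for p, y in zip(probabilities, labels))
    return correct / len(labels)


outputs = [0.91, 0.12, 0.50, 0.67]   # sigmoid outputs (placeholders)
targets = [1, 0, 1, 1]               # ground-truth diagnoses
print(f"{binary_accuracy(outputs, targets):.2f}")  # → 0.75
```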

#### Construct the pipeline

The `@dsl.pipeline` decorator in Kubeflow Pipelines defines a pipeline. It assigns a name to the pipeline (e.g., `"federated-learning-pipeline"`) for identification in the Kubeflow or Vertex AI Pipelines UI. The decorated function specifies the components (steps) and their dependencies, describing how data flows between them. This function does not execute the pipeline but prepares it for compilation into a JSON file that can be submitted for execution. The pipeline enables reproducible and scalable orchestration of machine learning workflows.

In [None]:
@pipeline(
    name="federated-learning-pipeline",
    description="A pipeline for federated learning with client initialization, training, and evaluation."
)
def federated_learning_pipeline(
    clients_data_dir: str, 
    num_client: int, 
    feature_inputs: int, 
    epochs: int,
    test_data_path: str,
    ):
    
    train_model_task = training(
        clients_data_dir=clients_data_dir,
        num_client=num_client,
        epochs=epochs,
        feature_inputs=feature_inputs)
    
    visualization(
        metrics_input=train_model_task.outputs["client_metrics_output"])
    
    evaluation(
        model_input=train_model_task.outputs["trained_model_output"],
        test_data_path=test_data_path,
        feature_inputs=feature_inputs)

#### Compile the pipeline

Now that we have all of our components, we use the KFP `compiler` to convert our pipeline into a JSON-based Intermediate Representation (IR) that the Vertex AI Pipelines or Kubeflow Pipelines engine can execute. Compilation checks that the pipeline's structure is valid, records the dependencies between components, and serializes metadata such as inputs, outputs, and parameters. The compiler does not execute the pipeline; it only prepares it for orchestration on Kubernetes or other infrastructure, which makes the machine learning workflow reproducible, scalable, and shareable.

In [None]:
compiler.Compiler().compile(
    pipeline_func=federated_learning_pipeline,
    package_path="federated_learning_pipeline.json"
)
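To sanity-check the compiled file before submitting it, you can load it as JSON and list the tasks in the pipeline's DAG together with their dependencies. The sketch below assumes the KFP v2 IR layout, in which the pipeline name sits under `pipelineInfo.name` and the top-level tasks under `root.dag.tasks`, each optionally listing `dependentTasks`. `summarize_pipeline_spec` and `load_pipeline_spec` are hypothetical helpers, and the demonstration uses a small hand-written dictionary rather than the real compiled file.

```python
import json
from typing import Dict, List


def summarize_pipeline_spec(spec: dict) -> Dict[str, List[str]]:
    """Map each top-level task to the tasks it depends on (assumes the KFP v2 IR layout)."""
    tasks = spec["root"]["dag"]["tasks"]
    return {name: sorted(task.get("dependentTasks", [])) for name, task in tasks.items()}


def load_pipeline_spec(path: str) -> dict:
    """Read a compiled pipeline file such as federated_learning_pipeline.json."""
    with open(path, "r") as f:
        return json.load(f)


# hand-written stand-in for a compiled spec (real files contain many more fields)
toy_spec = {
    "pipelineInfo": {"name": "federated-learning-pipeline"},
    "root": {"dag": {"tasks": {
        "training": {},
        "visualization": {"dependentTasks": ["training"]},
        "evaluation": {"dependentTasks": ["training"]},
    }}},
}
print(toy_spec["pipelineInfo"]["name"])
print(summarize_pipeline_spec(toy_spec))
# for the real file:
# print(summarize_pipeline_spec(load_pipeline_spec("federated_learning_pipeline.json")))
```

Both downstream tasks depending only on the training task reflects the data flow in `federated_learning_pipeline`: visualization and evaluation each consume one training output and can run in parallel.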

#### Run the pipeline

Now we can finally submit and run our pipeline in Vertex AI! Specify your project ID and region. Enter the name under which the pipeline will be displayed in the console, and the path of the pipeline JSON file created in the previous step. Specify a `pipeline_root`: a bucket location that holds the inputs and outputs produced by the pipeline (in our case, the model file, the plots, and the metrics). Lastly, enter the parameter values needed to start the pipeline.

In [None]:
from google.cloud import aiplatform
project_id = "ENTER_YOUR_PROJECT_ID"
location = "ENTER_YOUR_REGION (e.g. us-central1)"

aiplatform.init(project=project_id, location=location)

pipeline_job = aiplatform.PipelineJob(
    display_name="federated-learning-pipeline",
    template_path="federated_learning_pipeline.json",
    pipeline_root=f"gs://{BUCKET}/pipeline_root/federated_learning_pipeline",
    parameter_values={
        "clients_data_dir": f"gs://{BUCKET}/data",
        "num_client": 2,
        "feature_inputs": 30,
        "epochs": 10,
        "test_data_path": f"gs://{BUCKET}/data/full_val_data.csv"
    }
)
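Two common causes of failed runs are parameter names that do not match the pipeline function's arguments and input files missing from the bucket. The sketch below checks a `parameter_values` dictionary against a function signature with `inspect.signature` and lists the CSV paths the training and evaluation components will read, following the `client_{i}/train_data.csv` and `client_{i}/val_data.csv` layout used in the training component. `check_parameter_values`, `expected_input_files`, and `pipeline_stub` are hypothetical names, and `my-bucket` is a placeholder. Because `@pipeline` turns the function into a KFP component object whose signature may not be inspectable the same way, the check is written against an undecorated stub with the same arguments.

```python
import inspect
from typing import Callable, Dict, List


def check_parameter_values(pipeline_func: Callable, parameter_values: Dict[str, object]) -> None:
    """Raise if parameter_values and the function's arguments do not match exactly."""
    expected = set(inspect.signature(pipeline_func).parameters)
    provided = set(parameter_values)
    missing, unknown = expected - provided, provided - expected
    if missing or unknown:
        raise ValueError(f"missing: {sorted(missing)}, unknown: {sorted(unknown)}")


def expected_input_files(clients_data_dir: str, num_client: int, test_data_path: str) -> List[str]:
    """List the CSV files the training and evaluation components expect to read."""
    base = clients_data_dir.rstrip("/")
    files = []
    for i in range(num_client):
        files.append(f"{base}/client_{i}/train_data.csv")
        files.append(f"{base}/client_{i}/val_data.csv")
    files.append(test_data_path)
    return files


# undecorated stub with the same arguments as federated_learning_pipeline
def pipeline_stub(clients_data_dir: str, num_client: int, feature_inputs: int,
                  epochs: int, test_data_path: str):
    pass


params = {
    "clients_data_dir": "gs://my-bucket/data",   # placeholder bucket name
    "num_client": 2,
    "feature_inputs": 30,
    "epochs": 10,
    "test_data_path": "gs://my-bucket/data/full_val_data.csv",
}
check_parameter_values(pipeline_stub, params)   # passes silently when the names match
for path in expected_input_files(params["clients_data_dir"], params["num_client"],
                                 params["test_data_path"]):
    print(path)
```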

Run your pipeline! After you run the cell below, you should see a link to your pipeline (under "View Pipeline Job:"). You can use this link to open the console, track your pipeline, and view its logs and outputs. This pipeline takes about 6 minutes to run.

In [None]:
pipeline_job.run()

After clicking the link, you can view and track the progress of your pipeline in the console.

![Pipeline run in the Vertex AI console](../../images/fl_console1.png)

You can click on the icons to view metrics.

![Pipeline metrics in the Vertex AI console](../../images/fl_console3.png)

You can also go to the location where your outputs are stored, such as our plots.

![Pipeline output artifacts](../../images/fl_console2.png)

## Conclusion

In this tutorial, we explored the implementation of a federated learning pipeline on Google Cloud Platform (GCP) using Vertex AI Pipelines. We demonstrated how to preprocess data, train models across multiple clients, and aggregate the results in a federated learning setup. Additionally, we visualized client-specific metrics and evaluated the aggregated global model on a separate test dataset, with the trained model weights stored as a pipeline artifact in Cloud Storage.

## Clean up

Make sure you shut down or delete any notebooks or buckets created in this tutorial. You can also delete your pipeline run by going to Vertex AI > Pipelines, selecting your pipeline run, and clicking 'Delete'.

## Additional resources

If you are interested in learning more about Federated Learning, here are some useful resources to continue your FL journey.

### Other Tutorials:
- [TensorFlow](https://www.tensorflow.org/federated/tutorials/building_your_own_federated_learning_algorithm)
- [FLOWER](https://flower.ai/docs/framework/tutorial-series-get-started-with-flower-pytorch.html)

### Frameworks:
- [NVFLare](https://developer.nvidia.com/flare)
- [TensorFlow](https://www.tensorflow.org/federated)
- [FLOWER](https://flower.ai/docs/framework/index.html)
- [FeatureCloud](https://featurecloud.ai/)

### Literature:
[1] Rieke et al., (2020), "[The future of digital health with federated learning](https://www.nature.com/articles/s41746-020-00323-1)"

[2] Zhang et al., (2021), "[A survey on federated learning](https://www.sciencedirect.com/science/article/pii/S0950705121000381)"

[3] Wolberg et al., (1993), "[Breast Cancer Wisconsin (Diagnostic)](https://doi.org/10.24432/C5DW2B)"

[4] McMahan et al., (2017), "[Communication-Efficient Learning of Deep Networks from Decentralized Data](https://proceedings.mlr.press/v54/mcmahan17a.html)"

