# Federated Fine-tuning of an AMPLIFY Model

This example demonstrates how to use the AMPLIFY protein language model from [chandar-lab/AMPLIFY](https://github.com/chandar-lab/AMPLIFY) for fine-tuning on multiple downstream tasks. AMPLIFY is a powerful protein language model that can be adapted to various protein-related tasks. In this example, we show how to fine-tune AMPLIFY to predict several protein properties from antibody sequence data. For more details, please refer to this [paper](https://www.biorxiv.org/content/10.1101/2024.09.23.614603v1).

Note that this script assumes a regular Python environment and, unlike the previous example, does not rely on Docker. For running AMPLIFY within the BioNeMo Framework, please see [here](https://docs.nvidia.com/bionemo-framework/latest/models/amplify/).

### Prerequisites
First, download the data and install the required dependencies.

### Dataset

Before running the data preparation script, you need to clone the FLAb repository to obtain the required data:

In [None]:
!git clone https://github.com/Graylab/FLAb.git

The FLAb repository contains experimental data for six properties of therapeutic antibodies: expression, thermostability, immunogenicity, aggregation, polyreactivity, and binding affinity.

Next, we clone the AMPLIFY code and install it as a local pip package following the instructions [here](https://github.com/chandar-lab/AMPLIFY?tab=readme-ov-file#installation-as-a-local-pip-package).

We recommend creating a new virtual environment for this JupyterLab Python kernel before installing the dependencies.

In [None]:
!git clone https://github.com/chandar-lab/AMPLIFY
!pip install --upgrade pip
!pip install --editable AMPLIFY[dev]

Furthermore, we install the required dependencies for this example:

In [None]:
!pip install -r requirements.txt

## Federated Multi-task Fine-tuning

In this scenario, each client trains on a different downstream task from the [FLAb](https://github.com/Graylab/FLAb.git) antibody fitness datasets using a custom regression head. At the same time, the clients jointly fine-tune the pretrained AMPLIFY model trunk using **Federated Learning (FL)** so that they can benefit from each other's data.

<div style="display: flex; justify-content: center; margin: 20px 0;">
<img src="./figs/amplify_multi_task.svg" alt="AMPLIFY model for multi-task fine-tuning" style="width: 400px;"/>
</div>

The process involves:
1. Obtaining antibody sequence data from [FLAb](https://github.com/Graylab/FLAb.git)
2. Preparing the data for fine-tuning by combining the "light" and "heavy" antibody sequences with a "|" separator and splitting the data across clients.
3. Fine-tuning the AMPLIFY model for antibody property prediction in two scenarios:
    - Local training: Each data owner/client trains only on their local data.
    - Federated learning: We use the federated averaging algorithm to jointly train a global model on all the clients' data.

To allow clients to keep their regressor model local, we add an NVFlare [filter](https://nvflare.readthedocs.io/en/main/programming_guide/filters.html#filters) that removes the local regression layers before the updated AMPLIFY trunk is returned to the server for aggregation. See [run_fl_multitask.py](run_fl_multitask.py), where we add the [ExcludeParamsFilter](src/filters.py) filter.
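Conceptually, such a filter only has to drop the regressor entries from the dictionary of model weights before it leaves the client. The sketch below is not the `ExcludeParamsFilter` from `src/filters.py` and does not use the NVFlare filter API; it is a plain-Python illustration that assumes the local head's parameter names start with a hypothetical `regressor.` prefix and the trunk's with a hypothetical `trunk.` prefix.

```python
from typing import Dict, Iterable, TypeVar

T = TypeVar("T")


def exclude_params(weights: Dict[str, T], exclude_prefixes: Iterable[str]) -> Dict[str, T]:
    """Return a copy of `weights` without entries whose name starts with any of the prefixes."""
    prefixes = tuple(exclude_prefixes)
    return {name: value for name, value in weights.items() if not name.startswith(prefixes)}


# Hypothetical parameter names: a shared trunk plus a client-local regression head
client_weights = {
    "trunk.encoder.0.weight": [0.1, 0.2],
    "trunk.encoder.0.bias": [0.0],
    "regressor.0.weight": [0.5],
    "regressor.0.bias": [0.0],
}

# Only the trunk entries would be sent to the server
shared = exclude_params(client_weights, ["regressor."])
print(sorted(shared))  # ['trunk.encoder.0.bias', 'trunk.encoder.0.weight']
```

Because the server only ever receives the trunk entries, federated averaging aggregates the shared trunk while each client's regression head stays local.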

### Data Preparation

The [combine_data.py](src/combine_data.py) script prepares the data for fine-tuning. It processes CSV files containing 'heavy' and 'light' feature columns, combines them, and splits the data into training and test sets for each task.

**Combine the CSV Datasets**

In [None]:
for task in ["aggregation", "binding", "expression", "immunogenicity", "polyreactivity", "tm"]:
    print(f"Combining {task} CSV data")  # `$task` is only expanded inside `!` shell commands
    !python src/combine_data.py --input_dir ./FLAb/data/$task --output_dir ./FLAb/data_fl/$task


This will:
1. Read all CSV files from the `data` directory for each of the six antibody properties (aggregation, binding, expression, immunogenicity, polyreactivity, and thermostability)
2. Combine the 'heavy' and 'light' columns with a '|' separator into a 'combined' column
3. Split the data into training (80%) and test (20%) sets
4. Save the processed data to the specified output directory
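The combine-and-split steps above can be sketched in memory as follows. This is not the code in `src/combine_data.py`: it skips reading and writing CSV files, assumes the heavy chain comes before the light chain in the combined string, and uses a made-up `fitness` label column and made-up sequences for illustration.

```python
import random
from typing import Dict, List, Tuple

Record = Dict[str, object]


def combine_chains(records: List[Record], sep: str = "|") -> List[Record]:
    """Add a 'combined' field that joins the heavy and light chains with `sep`."""
    # Assumption: heavy chain first, then light chain
    return [{**r, "combined": f"{r['heavy']}{sep}{r['light']}"} for r in records]


def split_train_test(
    records: List[Record], test_fraction: float = 0.2, seed: int = 0
) -> Tuple[List[Record], List[Record]]:
    """Shuffle with a fixed seed and hold out `test_fraction` of the rows as the test set."""
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    n_test = round(len(shuffled) * test_fraction)
    return shuffled[n_test:], shuffled[:n_test]


# Toy rows with made-up sequences and a made-up 'fitness' label
rows = [{"heavy": f"EVQL{i}", "light": f"DIQM{i}", "fitness": float(i)} for i in range(10)]
train, test = split_train_test(combine_chains(rows))
print(len(train), len(test))  # 8 2
print(train[0]["combined"])
```

Fixing the shuffle seed keeps the 80/20 split reproducible across runs, which matters when comparing local and federated experiments on the same test sets.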

### Experiments
The following experiments use the [120M AMPLIFY](https://huggingface.co/chandar-lab/AMPLIFY_120M) pretrained model from HuggingFace. It was tested using three NVIDIA A100 GPUs with 80 GB memory each.
With the 120M AMPLIFY model, we can run two clients on each GPU, as specified by the `--sim_gpus` argument to `run_fl_*.py`.
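Assuming the i-th entry of `--sim_gpus` is the device of the i-th client (the authoritative mapping is whatever `run_fl_*.py` and the NVFlare simulator actually do), the string `"0,1,2,0,1,2"` places two of the six clients on each GPU. The client names below are hypothetical.

```python
from collections import Counter
from typing import Dict, List


def assign_clients_to_gpus(sim_gpus: str, client_names: List[str]) -> Dict[str, int]:
    """Map each client to the GPU id at the same position in a comma-separated list."""
    gpu_ids = [int(g) for g in sim_gpus.split(",")]
    if len(gpu_ids) != len(client_names):
        raise ValueError(f"{len(gpu_ids)} GPU entries for {len(client_names)} clients")
    return dict(zip(client_names, gpu_ids))


# Hypothetical client names, one per FLAb task
tasks = ["aggregation", "binding", "expression", "immunogenicity", "polyreactivity", "tm"]
mapping = assign_clients_to_gpus("0,1,2,0,1,2", tasks)
print(mapping)
print(Counter(mapping.values()))  # two clients per GPU
```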


#### Local Training
First we run the local training. Here, each data owner/client trains only on their local data. As we only run 1 round, the clients will never get the benefit of the updated global model and can only learn from their own data.

This command will:
1. Run federated learning with 6 clients (one for each task)
2. Perform one round of training with NVFlare
3. Each client will train for 10 local epochs per round
4. Use the 120M parameter AMPLIFY model by default
5. Configure the regression MLP with layer sizes [128, 64, 32]
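The `--layer_sizes "128,64,32"` argument can be read as the hidden widths of a small MLP on top of the trunk. The sketch below turns that string into per-layer shapes and counts the head's parameters. It assumes fully connected layers with biases, a 640-dimensional pooled embedding as input (an assumption; check the hidden size in the AMPLIFY_120M model config), and a single regression output; activations add no parameters and are omitted.

```python
from typing import List, Tuple


def mlp_layer_shapes(input_dim: int, layer_sizes: str, output_dim: int = 1) -> List[Tuple[int, int]]:
    """Return (in_features, out_features) for each linear layer of the regression head."""
    hidden = [int(s) for s in layer_sizes.split(",")]
    dims = [input_dim] + hidden + [output_dim]
    return list(zip(dims[:-1], dims[1:]))


def count_params(shapes: List[Tuple[int, int]]) -> int:
    """Weights plus biases of all linear layers."""
    return sum(i * o + o for i, o in shapes)


# Assumption: 640-dim pooled embedding from the 120M trunk and a single regression output
shapes = mlp_layer_shapes(640, "128,64,32")
print(shapes)  # [(640, 128), (128, 64), (64, 32), (32, 1)]
print(count_params(shapes))  # 92417
```

Under these assumptions the head adds fewer than 0.1M parameters, so almost all of the model's weights sit in the shared trunk.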

You can monitor the training progress with TensorBoard by running `tensorboard --logdir /tmp/nvflare/AMPLIFY` in a separate terminal.


<div class="alert alert-block alert-info"> <b>NOTE:</b> To speed things up, we only run a few local epochs and rounds here. The result plots shown below were obtained with `local_epochs=600` for local training and `num_rounds=600` for federated learning.</div>

In [None]:
!python run_fl_multitask.py \
    --num_rounds 1 \
    --local_epochs 10 \
    --pretrained_model "chandar-lab/AMPLIFY_120M" \
    --layer_sizes "128,64,32" \
    --exp_name "local_singletask" \
    --sim_gpus "0,1,2,0,1,2"

#### Federated Learning
Next, we run the same data setup, but with the federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) algorithm.
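In FedAvg, the server replaces each shared parameter with a weighted average of the clients' values, where client k gets weight n_k / n (in the original paper, n_k is the client's number of training samples). The plain-Python sketch below illustrates that rule on hypothetical weights; it is not NVFlare's implementation, and how this job weights its clients is not stated here.

```python
from typing import Dict, List, Tuple

Weights = Dict[str, List[float]]


def fedavg(client_updates: List[Tuple[Weights, int]]) -> Weights:
    """Weighted average of client weights; client k is weighted by n_k / sum(n)."""
    total = sum(n for _, n in client_updates)
    if total <= 0:
        raise ValueError("aggregation weights must sum to a positive number")
    names = client_updates[0][0].keys()
    return {
        name: [
            sum(weights[name][i] * n for weights, n in client_updates) / total
            for i in range(len(client_updates[0][0][name]))
        ]
        for name in names
    }


# Two hypothetical clients sharing one trunk parameter vector
global_weights = fedavg([
    ({"trunk.w": [1.0, 2.0]}, 1),
    ({"trunk.w": [3.0, 4.0]}, 3),
])
print(global_weights)  # {'trunk.w': [2.5, 3.5]}
```

In this example, only the trunk entries reach the server, since the regression heads are filtered out on each client before aggregation.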

This command will:
1. Run federated learning with 6 clients (one for each task)
2. Perform 10 rounds of federated averaging
3. Each client will train for 1 local epoch per round
4. Use the 120M parameter AMPLIFY model by default
5. Configure the regression MLP with layer sizes [128, 64, 32]
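The two commands give each client the same total number of local epochs; what differs is how often clients resume from an aggregated global model. A quick check with the values used in this notebook, assuming round 1 starts from the pretrained weights:

```python
from typing import Dict


def training_budget(num_rounds: int, local_epochs: int) -> Dict[str, int]:
    return {
        # Epochs each client trains over the whole run
        "total_local_epochs": num_rounds * local_epochs,
        # Rounds in which a client starts from an aggregated global model
        # (round 1 starts from the pretrained weights)
        "rounds_from_global_model": num_rounds - 1,
    }


print(training_budget(num_rounds=1, local_epochs=10))  # {'total_local_epochs': 10, 'rounds_from_global_model': 0}
print(training_budget(num_rounds=10, local_epochs=1))  # {'total_local_epochs': 10, 'rounds_from_global_model': 9}
```

Matching the total local epochs keeps the comparison fair: any difference between the two runs comes from sharing the trunk, not from extra training.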

In [None]:
!python run_fl_multitask.py \
    --num_rounds 10 \
    --local_epochs 1 \
    --pretrained_model "chandar-lab/AMPLIFY_120M" \
    --layer_sizes "128,64,32" \
    --exp_name "fedavg_multitask" \
    --sim_gpus "0,1,2,0,1,2"

### Visualize the Results

Apart from monitoring the progress with TensorBoard, you can also use the plotting code in [figs/plot_training_curves.py](./figs/plot_training_curves.py) to load the generated TensorBoard event files and compare the performance of the "local" and "fedavg" experiments for each task. Here's an example of how to use it:

In [None]:
# Plot test-set RMSE for all tasks
!python figs/plot_training_curves.py \
    --log_dir /tmp/nvflare/AMPLIFY/multitask \
    --output_dir ./figs/tb_figs_rmse \
    --tag "RMSE/local_test" \
    --out_metric "RMSE"

In [None]:
# Plot test-set Pearson coefficients for all tasks
!python figs/plot_training_curves.py \
    --log_dir /tmp/nvflare/AMPLIFY/multitask \
    --output_dir ./figs/tb_figs_pearson \
    --tag "Pearson/local_test" \
    --out_metric "Pearson"

This will generate plots for each task comparing the local and federated training performance, saving them as both PNG and SVG files in the specified output directory. The plots will show the progression of the specified metric (RMSE or Pearson coefficients) over training steps for both local and federated training approaches.

**120M AMPLIFY Multi-task Fine-tuning Results**

We plot the RMSE and Pearson coefficients for the different downstream tasks: "aggregation", "binding", "expression", "immunogenicity", "polyreactivity", and "thermostability (tm)". For RMSE, lower is better. As can be observed, the models trained with FedAvg achieve lower test-set RMSE on several downstream tasks than their locally trained counterparts.

Pearson coefficients closer to 1.0 indicate a stronger positive linear correlation between the ground-truth and predicted values. Several downstream tasks are challenging for the 120M model, which achieves only low correlation scores on them; see the [FLAb paper](https://www.biorxiv.org/content/10.1101/2024.01.13.575504v1) for comparison. However, the FedAvg experiment shows benefits for several downstream tasks.
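For reference, the two metrics follow the standard definitions sketched below in plain Python. This is not the code the training script uses to log `RMSE/local_test` and `Pearson/local_test`.

```python
import math
from typing import Sequence


def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Root mean squared error: sqrt(mean((y - y_hat)^2))."""
    n = len(y_true)
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n)


def pearson(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Pearson correlation: covariance divided by the product of standard deviations."""
    n = len(y_true)
    mean_t = sum(y_true) / n
    mean_p = sum(y_pred) / n
    cov = sum((t - mean_t) * (p - mean_p) for t, p in zip(y_true, y_pred))
    var_t = sum((t - mean_t) ** 2 for t in y_true)
    var_p = sum((p - mean_p) ** 2 for p in y_pred)
    return cov / math.sqrt(var_t * var_p)


y = [1.0, 2.0, 3.0, 4.0]
y_hat = [1.5, 2.5, 3.5, 4.5]  # every prediction is off by the same constant
print(rmse(y, y_hat))  # 0.5
print(pearson(y, y_hat))  # 1.0
```

The example shows why both metrics are reported: predictions shifted by a constant offset are perfectly correlated (Pearson 1.0) yet still have a non-zero RMSE.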

> Note: by default, the training curves are smoothed with a window of 30 (controlled by the `smoothing_window` argument).
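The exact smoothing in `figs/plot_training_curves.py` is not shown here. A common choice, assumed in this sketch, is a trailing moving average in which each point is the mean of up to `window` most recent values:

```python
from typing import List, Sequence


def moving_average(values: Sequence[float], window: int = 30) -> List[float]:
    """Trailing moving average; early points average over the values available so far."""
    if window < 1:
        raise ValueError("window must be >= 1")
    smoothed = []
    running_sum = 0.0
    for i, v in enumerate(values):
        running_sum += v
        if i >= window:
            # Drop the value that just left the window
            running_sum -= values[i - window]
        smoothed.append(running_sum / min(i + 1, window))
    return smoothed


print(moving_average([1.0, 2.0, 3.0, 4.0], window=2))  # [1.0, 1.5, 2.5, 3.5]
```

A larger window gives smoother curves but reacts more slowly to changes, so the start of each curve is averaged over fewer points than the rest.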

### Root Mean Squared Error
<div style="display: flex; justify-content: center; gap: 20px; flex-wrap: nowrap;">
<img src="./figs/tb_figs_rmse/aggregation.svg" alt="Aggregation" style="width: 300px; flex-shrink: 0;"/>
<img src="./figs/tb_figs_rmse/binding.svg" alt="Binding" style="width: 300px; flex-shrink: 0;"/>
<img src="./figs/tb_figs_rmse/expression.svg" alt="Expression" style="width: 300px; flex-shrink: 0;"/>
</div>

<div style="display: flex; justify-content: center; gap: 20px; flex-wrap: nowrap;">
<img src="./figs/tb_figs_rmse/immunogenicity.svg" alt="Immunogenicity"  style="width: 300px; flex-shrink: 0;">
<img src="./figs/tb_figs_rmse/polyreactivity.svg" alt="Polyreactivity"  style="width: 300px; flex-shrink: 0;">
<img src="./figs/tb_figs_rmse/tm.svg" alt="Thermostability"  style="width: 300px; flex-shrink: 0;">
</div>


### Pearson Coefficient
<div style="display: flex; justify-content: center; gap: 20px; flex-wrap: nowrap;">
<img src="./figs/tb_figs_pearson/aggregation.svg" alt="Aggregation" style="width: 300px; flex-shrink: 0;"/>
<img src="./figs/tb_figs_pearson/binding.svg" alt="Binding" style="width: 300px; flex-shrink: 0;"/>
<img src="./figs/tb_figs_pearson/expression.svg" alt="Expression" style="width: 300px; flex-shrink: 0;"/>
</div>

<div style="display: flex; justify-content: center; gap: 20px; flex-wrap: nowrap;">
<img src="./figs/tb_figs_pearson/immunogenicity.svg" alt="Immunogenicity"  style="width: 300px; flex-shrink: 0;">
<img src="./figs/tb_figs_pearson/polyreactivity.svg" alt="Polyreactivity"  style="width: 300px; flex-shrink: 0;">
<img src="./figs/tb_figs_pearson/tm.svg" alt="Thermostability"  style="width: 300px; flex-shrink: 0;">
</div>

## Summary

This notebook demonstrates federated fine-tuning of the AMPLIFY protein language model for drug discovery applications. Here are the key components and steps covered:

1. **Setup and Dependencies**
   - Installation of AMPLIFY and required dependencies
   - Setup of the FLAb repository containing experimental data for therapeutic antibodies

2. **Data Preparation**
   - Processing of six antibody properties: aggregation, binding, expression, immunogenicity, polyreactivity, and thermostability
   - Data splitting into training (80%) and test (20%) sets
   - Combination of heavy and light chain sequences

3. **Model Architecture**
   - Based on the 120M AMPLIFY pretrained model
   - Transformer-based architecture with 24 encoder blocks
   - Custom regression head with layer sizes [128, 64, 32]
   - Total parameters: ~118M

4. **Training Process**
   - Federated learning setup with multiple clients
   - Learning rates: Trunk (0.0001) and Regressor (0.01)
   - Training metrics tracked: MSE loss, RMSE, and Pearson correlation
   - Model evaluation on test sets for each property

5. **Results**
   - Performance metrics tracked across different antibody properties
   - Visualization of training progress and model predictions
   - Comparison of federated vs. local training approaches

This example showcases how federated learning can be applied to drug discovery tasks while each research institution keeps its raw data local.

Let's recap what we learned in this [chapter](../../11.3_recap/recap.ipynb).