# Federated Protein Downstream Fine-tuning

<div class="alert alert-block alert-info"> <b>NOTE</b> This notebook was tested on a DGX with 8 A100 GPUs with 80 GB memory each and is compatible with BioNeMo Framework v2.5. To leverage additional or higher-performance GPUs, you can modify the configuration files and simulation script to accommodate multiple devices and increase thread utilization respectively. To run with less memory consumption, you can reduce the micro-batch sizes in the `run_*.py` scripts.</div>

The example datasets used here are made available by [Therapeutics Data Commons](https://tdcommons.ai/) through PyTDC.

This example shows three different downstream tasks for fine-tuning a BioNeMo ESM-style model on different datasets.
We separate the scripts and job configurations into three folders based on the dataset names:


1. `tap`: therapeutic antibody profiling
2. `sabdab`: SAbDab, the structural antibody database
3. `scl`: subcellular location prediction

## Setup

Ensure that you have read through the Getting Started section, can run the BioNeMo Framework docker container, and have configured the NGC Command Line Interface (CLI) within the container. It is assumed that this notebook is being executed from within the container.

<div class="alert alert-block alert-info"> <b>NOTE</b> Some of the cells below generate long text output.  We're using <pre>%%capture --no-display --no-stderr cell_output</pre> to suppress this output.  Comment or delete this line in the cells below to restore full output.</div>

### Import and install all required packages

In [None]:
%%capture --no-display --no-stderr cell_output
! pip install fuzzywuzzy PyTDC --no-dependencies  # install tdc without dependencies to avoid version conflicts in the BioNeMo container
! pip install nvflare~=2.6rc

import os
import warnings

warnings.filterwarnings("ignore")
warnings.simplefilter("ignore")

---
### Task 1: Cross-endpoint multi-task fitting

#### Data: Five computational developability guidelines for therapeutic antibody profiling
See https://tdcommons.ai/single_pred_tasks/develop/#tap
- 241 Antibodies (both chains)

#### Task Description: *Regression*. 
Given the antibody's heavy chain and light chain sequences, predict its developability. The input X is a list of two sequences, where the first is the heavy chain and the second is the light chain.

The dataset includes five metrics measuring the developability of an antibody: 
 - Complementarity-determining regions (CDR) length - Trivial (excluded)
 - Patches of surface hydrophobicity (PSH) - Run on site-1
 - Patches of positive charge (PPC) - Run on site-2
 - Patches of negative charge (PNC) - Run on site-3
 - Structural Fv charge symmetry parameter (SFvCSP) - Run on site-4

As indicated, we run each endpoint regression task on a different client. This simulates a multi-task fitting scenario with multiple endpoints, where all clients jointly train a shared ESM encoder trunk but keep private regression heads for their respective endpoints (see the `BioNeMoExcludeParamsFilter` in [run_sim_tap.py](tap/run_sim_tap.py)).
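The idea of a shared trunk with private heads can be sketched in a few lines. This is a minimal illustration and not the actual `BioNeMoExcludeParamsFilter` implementation: the parameter names, the `regression_head.` prefix, and both helper functions are hypothetical and chosen only for this example.

```python
import numpy as np

# Hypothetical prefix for parameters that stay private to each client.
PRIVATE_PREFIX = "regression_head."


def exclude_private(params, prefix=PRIVATE_PREFIX):
    """Drop private parameters before a client sends its update to the server."""
    return {k: v for k, v in params.items() if not k.startswith(prefix)}


def fedavg(client_updates, num_samples):
    """Sample-weighted average of the parameters shared by all clients."""
    total = float(sum(num_samples))
    keys = client_updates[0].keys()
    return {
        k: sum(n / total * u[k] for u, n in zip(client_updates, num_samples))
        for k in keys
    }


# Two toy clients: same encoder shape, different private heads.
client_a = {"encoder.w": np.array([1.0, 2.0]), "regression_head.w": np.array([10.0])}
client_b = {"encoder.w": np.array([3.0, 4.0]), "regression_head.w": np.array([-5.0])}

shared = fedavg(
    [exclude_private(client_a), exclude_private(client_b)], num_samples=[1, 1]
)

# Each client merges the new global encoder but keeps its own head.
client_a.update(shared)
client_b.update(shared)
```

After the update, both toy clients hold the same averaged encoder weights, while each regression head is left untouched.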

<img src="./tap/figs/esm_multi_task.svg" alt="ESM Cross-endpoint multi-task fitting" width="400"/>

In the data preparation script, one can choose between uniform sampling of the data among clients and
heterogeneous data splits using a Dirichlet sampling strategy. 
Here, the value of `alpha` controls the level of heterogeneity, with lower values producing more heterogeneous splits. The figures below compare uniform sampling with Dirichlet sampling at `alpha=1.0`.
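As a rough sketch of what a Dirichlet-based split can look like, the snippet below draws per-client proportions from a symmetric Dirichlet distribution and cuts a shuffled index list accordingly. The function `dirichlet_split` is hypothetical and is not taken from `prepare_tap_data.py`; the actual script may differ, for example in how it handles very small client shares.

```python
import numpy as np


def dirichlet_split(num_samples, num_clients, alpha, seed=0):
    """Partition sample indices among clients with Dirichlet-distributed sizes."""
    rng = np.random.default_rng(seed)
    # One proportion per client; the proportions sum to 1.
    proportions = rng.dirichlet(alpha * np.ones(num_clients))
    indices = rng.permutation(num_samples)
    # Cumulative proportions give the cut points into the shuffled indices.
    cut_points = (np.cumsum(proportions)[:-1] * num_samples).astype(int)
    return np.split(indices, cut_points)


# Toy usage: 241 samples (the size of the TAP dataset) across four clients.
splits = dirichlet_split(num_samples=241, num_clients=4, alpha=1.0)
print([len(s) for s in splits])
```

A very large `alpha` approaches a uniform split in this sketch, while a small `alpha` concentrates most samples on a few clients.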

In [None]:
! cd /bionemo_nvflare_examples/downstream/tap && python prepare_tap_data.py

|                                Uniform sampling                                 |                                    Dirichlet sampling                                     |
|:-------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------:|
| <img src="./tap/figs/tap_uniform.svg" alt="Uniform data sampling" width="300"/> | <img src="./tap/figs/tap_alpha1.0.svg" alt="Dirichlet sampling (alpha=1.0)" width="300"/> |

**Run training (central, local, & FL)**

You can change the FL job that is simulated by changing the arguments of the `run_sim_tap.py` script. You can choose which ESM2 model to download (8M or 650M parameters). The ESM2 fine-tuning arguments, such as the learning rate, can be modified inside the script itself.

First, let's check its arguments.

In [None]:
! cd /bionemo_nvflare_examples/downstream/tap && python run_sim_tap.py --help

**1. Central training**

To simulate central training, we use four clients, running one round of training for several steps on a different regression task using the full dataset. Note that if the `--exp_name` argument contains `"central"`, the combined training dataset is used.

In [None]:
! cd /bionemo_nvflare_examples/downstream/tap && python run_sim_tap.py --num_clients=4 --num_rounds=1 --local_steps=1000 --exp_name central

**2. Local training**

To simulate local training, we use four clients, each running one round of training for several steps using the split datasets.

In [None]:
! cd /bionemo_nvflare_examples/downstream/tap && python run_sim_tap.py --num_clients=4 --num_rounds=1 --local_steps=1000 --exp_name local

**3. Federated training with FedAvg**

To simulate federated training, we use four clients, running several rounds with FedAvg, each with a smaller number of local steps. The total number of training steps per client (10 rounds × 100 local steps = 1,000) matches the local training scenario.

In [None]:
! cd /bionemo_nvflare_examples/downstream/tap && python run_sim_tap.py --num_clients=4 --num_rounds=10 --local_steps=100 --exp_name fedavg
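The round structure of the FedAvg run above (`--num_rounds=10 --local_steps=100`) can be illustrated with a toy simulation. This is only a sketch on a one-dimensional quadratic loss with equally weighted clients; the targets, the learning rate, and the helper functions are hypothetical and unrelated to the actual NVFlare job.

```python
import numpy as np


def local_sgd(w, target, steps, lr=0.01):
    """Run `steps` gradient steps on the toy loss 0.5 * (w - target) ** 2."""
    for _ in range(steps):
        w = w - lr * (w - target)
    return w


def simulate(targets, num_rounds, local_steps):
    """Each round: every client trains from the global weight, then the server averages."""
    w_global = 0.0
    for _ in range(num_rounds):
        local_ws = [local_sgd(w_global, t, local_steps) for t in targets]
        w_global = float(np.mean(local_ws))
    return w_global


targets = [1.0, 2.0, 3.0, 6.0]  # hypothetical per-client optima

# Same step budget per client: 1 x 1000 locally vs. 10 x 100 with averaging.
w_local = [local_sgd(0.0, t, steps=1000) for t in targets]
w_fedavg = simulate(targets, num_rounds=10, local_steps=100)
print("local:", np.round(w_local, 3), "fedavg:", round(w_fedavg, 3))
```

In this toy setting, each locally trained weight drifts to its own client optimum, whereas the averaged weight settles near the mean of the client optima.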

You can visualize the results in TensorBoard using `tensorboard --logdir /tmp/nvflare/bionemo/tap`. Note that for FedAvg, you can sort the x-axis by wall time, as each FL round creates a new TensorBoard output folder.

<div class="alert alert-block alert-info"> <b>NOTE</b> This public dataset is very small, and therefore, we only use it to illustrate the code example. The regression results are likely not reliable in practice. Hence, we skip the visualization here.</div>

---
### Task 2: Cross-compound task fitting

#### Data: Predicting Antibody Developability from Sequence using Machine Learning
See https://tdcommons.ai/single_pred_tasks/develop/#sabdab-chen-et-al
- 2,409 Antibodies (both chains)

#### Task Description: *Binary classification*. 
Given the antibody's heavy chain and light chain sequences, predict its developability. The input X is a list of two sequences, where the first is the heavy chain and the second is the light chain.

In [None]:
# you may need to fix these paths to your own scripts
! cd /bionemo_nvflare_examples/downstream/sabdab && python prepare_sabdab_data.py

Again, we are using the Dirichlet sampling strategy to generate heterogeneous data distributions among clients.
Lower values of `alpha` generate higher levels of heterogeneity.
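The effect of `alpha` can be quantified for the client proportions themselves. For a symmetric Dirichlet distribution over $K$ clients, the variance of a single client's share is $(K-1)/(K^2(K\alpha+1))$, so it shrinks as `alpha` grows. The sketch below compares this formula with an empirical estimate; it only looks at the sampled proportions and makes no assumption about how `prepare_sabdab_data.py` applies them to the data. The function name `proportion_variance` is introduced here for illustration.

```python
import numpy as np


def proportion_variance(alpha, num_clients):
    """Variance of one client's share under a symmetric Dirichlet(alpha)."""
    k = num_clients
    return (k - 1) / (k**2 * (k * alpha + 1))


rng = np.random.default_rng(0)
for alpha in (10.0, 1.0):
    # Draw many proportion vectors for six clients and inspect the first share.
    draws = rng.dirichlet(alpha * np.ones(6), size=20000)
    print(
        f"alpha={alpha}: analytic var={proportion_variance(alpha, 6):.4f}, "
        f"empirical var={draws[:, 0].var():.4f}"
    )
```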

|                                            Alpha 10.0                                             |                                            Alpha 1.0                                            |
|:-------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------:|
| <img src="./sabdab/figs/sabdab_alpha10.0.svg" alt="Dirichlet sampling (alpha=10.0)" width="150"/> | <img src="./sabdab/figs/sabdab_alpha1.0.svg" alt="Dirichlet sampling (alpha=1.0)" width="150"/> |


**Run training (central, local, & FL)**

You can change the FL job that is simulated by changing the arguments of the `run_sim_sabdab.py` script. You can choose which ESM2 model to download (8M or 650M parameters). The ESM2 fine-tuning arguments, such as the learning rate, can be modified inside the script itself.

First, let's check its arguments.

In [None]:
! cd /bionemo_nvflare_examples/downstream/sabdab && python run_sim_sabdab.py --help

**1. Central training**

To simulate central training, we use one client, running one round of training for several steps. Note that if the `--exp_name` argument contains `"central"`, the combined training dataset is used.

In [None]:
! cd /bionemo_nvflare_examples/downstream/sabdab && python run_sim_sabdab.py --num_clients=1 --num_rounds=1 --local_steps=3000 --exp_name central

**2. Local training**

To simulate local training, we use six clients, each running one round of training for several steps.

In [None]:
! cd /bionemo_nvflare_examples/downstream/sabdab && python run_sim_sabdab.py --num_clients=6 --num_rounds=1 --local_steps=3000 --exp_name local

**3. Federated training with FedAvg**

To simulate federated training, we use six clients, running several rounds with FedAvg, each with a smaller number of local steps.

In [None]:
! cd /bionemo_nvflare_examples/downstream/sabdab && python run_sim_sabdab.py --num_clients=6 --num_rounds=10 --local_steps=300 --exp_name fedavg

You can visualize the results in TensorBoard using `tensorboard --logdir /tmp/nvflare/bionemo/sabdab`. Note that for FedAvg, you can display a continuous training curve streamed to the server by selecting a `server` subfolder.

#### Results with heterogeneous data sampling (alpha=1.0)
| Setting  | Accuracy  |
|:--------:|:---------:|
|  Central |   *0.8504*   |
|  Local   |   0.8099   |
|   FedAvg | **0.8341** |


|                                Central & Local                                 |                                    FedAvg                                     |
|:-------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------:|
| <img src="./sabdab/figs/tb_curve_sabdab_central_local.png" alt="sabdab central and local training" width="600"/> | <img src="./sabdab/figs/tb_curve_sabdab_fedavg.png" alt="sabdab FedAvg training" width="600"/> |

---
### Task 3: Subcellular location prediction with ESM2nv 650M
In this example, we use the `--encoder-frozen` option inside the `run_sim_scl.py` script. You can specify different base ESM2 models using the `--model` option.
Follow the data download and preparation in [task_fitting.ipynb](../task_fitting/task_fitting.ipynb).

Here, we use a heterogeneous sampling with `alpha=1.0`.

<img src="./scl/figs/scl_alpha1.0.svg" alt="Dirichlet sampling (alpha=1.0)" width="300"/>

**1. Local training**

In [None]:
# for this to work run the task_fitting notebook first in ../nvflare_with_bionemo/task_fitting/task_fitting.ipynb in order to download the SCL dataset
!cd /bionemo_nvflare_examples/downstream/scl && python run_sim_scl.py --num_clients=3 --num_rounds=1 --local_steps=5000 --exp_name "local" --model "650m" --sim_gpus="0,1,2"

**2. Federated training with FedAvg**

In [None]:
!cd /bionemo_nvflare_examples/downstream/scl && python run_sim_scl.py --num_clients=3 --num_rounds=10 --local_steps=500 --exp_name "fedavg" --model "650m" --sim_gpus="0,1,2"

You can visualize the results in TensorBoard using `tensorboard --logdir /tmp/nvflare/bionemo/scl`. Note that for FedAvg, you can display a continuous training curve streamed to the server by selecting a `server` subfolder.

#### Results with heterogeneous data sampling (alpha=1.0)
|  Client   | Site-1 | Site-2 | Site-3 | Average accuracy |
|:---------:|:------:|:------:|:------:|:----------------:|
| # Samples |  1844  |  2921  |  2151  |        -         |
| Local     | 0.7819 | 0.7885 | 0.7921 |      0.7875      |
| FedAvg    | 0.8179 | 0.8131 | 0.8209 |    **0.8173**    |
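The average column in the table above is consistent with an unweighted mean over the three sites. The short check below reproduces it from the per-site values and also prints a sample-weighted mean for comparison; the helper names are introduced here for illustration only.

```python
# Per-site accuracies and sample counts copied from the table above.
site_accuracy = {
    "Local": [0.7819, 0.7885, 0.7921],
    "FedAvg": [0.8179, 0.8131, 0.8209],
}
num_samples = [1844, 2921, 2151]


def unweighted_mean(values):
    """Plain average over sites."""
    return sum(values) / len(values)


def weighted_mean(values, weights):
    """Average over sites, weighted by the number of samples per site."""
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


for setting, acc in site_accuracy.items():
    print(
        setting,
        round(unweighted_mean(acc), 4),
        round(weighted_mean(acc, num_samples), 4),
    )
```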

<img src="./scl/figs/tb_curve_scl.png" alt="SCL Training curve with Dirichlet sampling (alpha=1.0)" width="400"/>