# Federated Protein Downstream Fine-tuning

<div class="alert alert-block alert-info"> <b>NOTE</b> This notebook was tested on a single A1000 GPU and is compatible with BioNeMo Framework v1.8. To leverage additional or higher-performance GPUs, you can modify the configuration files and simulation script to accommodate multiple devices and increase thread utilization respectively.</div>

The example datasets used here are made available by [Therapeutics Data Commons](https://tdcommons.ai/) through PyTDC.

This example shows three different downstream tasks for fine-tuning a BioNeMo ESM-style model on different datasets.
We separate the scripts and job configurations into three folders based on the dataset names:


1. `tap`: therapeutic antibody profiling
2. `sabdab`: SAbDab, the structural antibody database
3. `scl`: subcellular location prediction

## Setup

Ensure that you have read through the Getting Started section, can run the BioNeMo Framework docker container, and have configured the NGC Command Line Interface (CLI) within the container. It is assumed that this notebook is being executed from within the container.

<div class="alert alert-block alert-info"> <b>NOTE</b> Some of the cells below generate long text output.  We're using <pre>%%capture --no-display --no-stderr cell_output</pre> to suppress this output.  Comment or delete this line in the cells below to restore full output.</div>

### Import and install all required packages

In [1]:
%%capture --no-display --no-stderr cell_output
! pip install PyTDC
! pip install nvflare~=2.5.0
! pip install biopython
! pip install scikit-learn
! pip install matplotlib
! pip install protobuf==3.20
! pip install huggingface-hub==0.22.0

import os
import warnings


warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

#### Home Directory

In [2]:
bionemo_home = "/workspace/bionemo"
os.environ['BIONEMO_HOME'] = bionemo_home

### Download Model Checkpoints

To download pretrained models from the NGC registry, **please ensure that you have installed and configured the NGC CLI**; see the [Quickstart Guide](https://docs.nvidia.com/bionemo-framework/latest/quickstart-fw.html) for more information. The following code downloads the pretrained model `esm2nv_650M_converted.nemo` from the NGC registry.

In [3]:
# Define the NGC CLI API KEY and ORG for the model download.
# If these variables are not already set in the container, uncomment the lines
# below and fill in your own API KEY and ORG. Never commit a real key to a notebook.
# api_key = "<YOUR_NGC_API_KEY>"
# ngc_cli_org = "<YOUR_NGC_ORG>"
# # Update the environment variables
# os.environ['NGC_CLI_API_KEY'] = api_key
# os.environ['NGC_CLI_ORG'] = ngc_cli_org

# Set variables and paths for model and checkpoint
model_name = "esm2nv_650m" # "esm1nv"  
actual_checkpoint_name = "esm2nv_650M_converted.nemo" # "esm1nv.nemo"
model_path = os.path.join(bionemo_home, 'models')
checkpoint_path = os.path.join(model_path, actual_checkpoint_name)
os.environ['MODEL_PATH'] = model_path
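
Rather than hard-coding credentials in a notebook cell, it can help to verify that the NGC settings are already present in the environment before attempting a download. The following is a minimal sketch under that assumption; `missing_ngc_settings` is a hypothetical helper name and is not part of BioNeMo or the NGC CLI.

```python
import os


def missing_ngc_settings(env=None):
    """Return the names of required NGC variables that are unset or empty.

    `env` defaults to the process environment; a plain dict can be passed
    instead, which makes the helper easy to check in isolation.
    """
    env = os.environ if env is None else env
    required = ("NGC_CLI_API_KEY", "NGC_CLI_ORG")
    return [name for name in required if not env.get(name)]


missing = missing_ngc_settings()
if missing:
    print("Please set these variables before downloading:", ", ".join(missing))
else:
    print("NGC credentials found in the environment.")
```

If any variable is reported as missing, set it in the container (or in the cell above) before running the download cells.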

In [4]:
%%capture --no-display --no-stderr cell_output
if not os.path.exists(checkpoint_path):
    !cd /workspace/bionemo && \
    python download_artifacts.py --model_dir models --models {model_name}
else:
    print(f"Model {model_name} already exists at {model_path}.")

Repeat the download for `esm1nv`:

In [5]:
model_name = "esm1nv"
actual_checkpoint_name = "esm1nv.nemo"
model_path = os.path.join(bionemo_home, 'models')
checkpoint_path = os.path.join(model_path, actual_checkpoint_name)
os.environ['MODEL_PATH'] = model_path

In [6]:
%%capture --no-display --no-stderr cell_output
if not os.path.exists(checkpoint_path):
    !cd /workspace/bionemo && \
    python download_artifacts.py --model_dir models --models {model_name}
else:
    print(f"Model {model_name} already exists at {model_path}.")

### Task 1: Cross-endpoint multi-task fitting

#### Data: Five computational developability guidelines for therapeutic antibody profiling
See https://tdcommons.ai/single_pred_tasks/develop/#tap
- 241 Antibodies (both chains)

#### Task Description: *Regression*. 
Given the antibody's heavy chain and light chain sequences, predict its developability. The input X is a list of two sequences, where the first is the heavy chain and the second is the light chain.

Includes five metrics measuring developability of an antibody: 
 - complementarity-determining regions (CDR) length (trivial, excluded)
 - patches of surface hydrophobicity (PSH)
 - patches of positive charge (PPC)
 - patches of negative charge (PNC)
 - structural Fv charge symmetry parameter (SFvCSP)
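
The input format described above (a pair of heavy and light chain sequences) can be handled with a small helper. This is only an illustrative sketch: the helper names, the assumption that a pair may arrive as the string form of a Python list, the plain concatenation used to build a single model input, and the toy sequence fragments are all assumptions and are not taken from the data preparation script.

```python
import ast


def parse_antibody(x):
    """Return (heavy, light) from a two-element list or its string form."""
    if isinstance(x, str):
        # Assumption: a stored pair may be serialized as "['HEAVY', 'LIGHT']".
        x = ast.literal_eval(x)
    if len(x) != 2:
        raise ValueError(f"Expected [heavy, light], got {len(x)} sequences")
    heavy, light = x
    return heavy, light


def to_model_input(x, sep=""):
    """Join heavy and light chains into one string (hypothetical choice)."""
    heavy, light = parse_antibody(x)
    return f"{heavy}{sep}{light}"


# Toy fragments for illustration only, not real antibodies from the dataset.
example = "['EVQLVESGG', 'DIQMTQSPS']"
print(parse_antibody(example))
print(to_model_input(example))
```

How the two chains are actually combined for the model is a design choice of the preparation script; the separator argument is left configurable here for that reason.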

In the data preparation script, you can choose between uniform sampling of the data among clients and
heterogeneous data splits using a Dirichlet sampling strategy.
For the Dirichlet strategy, the value of `alpha` controls the level of heterogeneity. Below, we show a Dirichlet sampling with `alpha=1.0`.

In [7]:
! cd /bionemo_nvflare_examples/downstream/tap && python prepare_tap_data.py

Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!
Sampling with alpha=1.0
Save 12 training proteins for site-1 (frac=0.064)
Save 57 training proteins for site-2 (frac=0.295)
Save 34 training proteins for site-3 (frac=0.174)
Save 90 training proteins for site-4 (frac=0.466)
Saved 193 training and 48 testing proteins.
[[       nan 0.07017544 0.08823529 0.04444444]
 [       nan        nan 0.41176471 0.27777778]
 [       nan        nan        nan 0.2       ]
 [       nan        nan        nan        nan]]
Avg. overlap: 18.21%
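
The reported average is consistent with taking the mean of the non-NaN (upper-triangular) entries of the printed matrix. The sketch below reproduces that arithmetic from the printed values; what each pairwise entry measures is not stated in the output, so this only checks the averaging step.

```python
import numpy as np

nan = np.nan
# Pairwise values copied from the output above (upper triangle only).
overlap = np.array(
    [
        [nan, 0.07017544, 0.08823529, 0.04444444],
        [nan, nan, 0.41176471, 0.27777778],
        [nan, nan, nan, 0.2],
        [nan, nan, nan, nan],
    ]
)

# nanmean ignores the NaN placeholders on and below the diagonal.
avg_overlap = np.nanmean(overlap)
print(f"Avg. overlap: {100 * avg_overlap:.2f}%")
```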


|                                Uniform sampling                                 |                                    Dirichlet sampling                                     |
|:-------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------:|
| <img src="./tap/figs/tap_uniform.svg" alt="Uniform data sampling" width="150"/> | <img src="./tap/figs/tap_alpha1.0.svg" alt="Dirichlet sampling (alpha=1.0)" width="150"/> |
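
The Dirichlet sampling strategy mentioned above can be sketched as follows. This is a simplified illustration that draws per-client fractions from a symmetric Dirichlet distribution and splits shuffled sample indices accordingly; `dirichlet_split` is a hypothetical helper, and the actual preparation script may partition the data differently.

```python
import numpy as np


def dirichlet_split(n_samples, n_clients, alpha, seed=0):
    """Split sample indices among clients with Dirichlet-distributed sizes."""
    rng = np.random.default_rng(seed)
    # Draw one fraction per client; the fractions sum to 1.
    fractions = rng.dirichlet(alpha * np.ones(n_clients))
    # Shuffle the indices so each client receives a random subset.
    indices = rng.permutation(n_samples)
    # Turn cumulative fractions into cut points within the shuffled indices.
    cuts = (np.cumsum(fractions)[:-1] * n_samples).astype(int)
    return fractions, np.split(indices, cuts)


fractions, parts = dirichlet_split(n_samples=193, n_clients=4, alpha=1.0)
for i, (frac, part) in enumerate(zip(fractions, parts), start=1):
    print(f"site-{i}: {len(part)} samples (frac={frac:.3f})")
```

Because client sizes are obtained by truncating cumulative fractions, the realized share of each client can differ slightly from the sampled fraction.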


**Run training (central, local, & FL)**

You can change the FL job that's going to be simulated inside the `run_sim_tap.py` script.

In [None]:
! cd /bionemo_nvflare_examples/downstream/tap && python run_sim_tap.py

2024-10-04 21:33:29,665 - SimulatorRunner - INFO - Create the Simulator Server.
2024-10-04 21:33:29,666 - CoreCell - INFO - server: creating listener on tcp://0:55505
2024-10-04 21:33:29,686 - CoreCell - INFO - server: created backbone external listener for tcp://0:55505
2024-10-04 21:33:29,686 - ConnectorManager - INFO - 18052: Try start_listener Listener resources: {'secure': False, 'host': 'localhost'}
2024-10-04 21:33:29,686 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector [CH00002 PASSIVE tcp://0:5937] is starting
2024-10-04 21:33:30,188 - CoreCell - INFO - server: created backbone internal listener for tcp://localhost:5937
2024-10-04 21:33:30,188 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector [CH00001 PASSIVE tcp://0:55505] is starting
2024-10-04 21:33:30,272 - nvflare.fuel.hci.server.hci - INFO - Starting Admin Server localhost on Port 34843
2024-10-04 21:33:30,272 - SimulatorRunner - INFO - Deploy the Apps.
2024-10-04 21:33:30,278 - SimulatorRunner - INFO - Create t

### Task 2: Cross-compound task fitting

#### Data: Predicting Antibody Developability from Sequence using Machine Learning
See https://tdcommons.ai/single_pred_tasks/develop/#sabdab-chen-et-al
- 2,409 Antibodies (both chains)

#### Task Description: *Binary classification*. 
Given the antibody's heavy chain and light chain sequences, predict its developability. The input X is a list of two sequences, where the first is the heavy chain and the second is the light chain.

In [9]:
# You may need to adjust these paths to point to your own scripts.
! cd /bionemo_nvflare_examples/downstream/sabdab && python prepare_sabdab_data.py

Downloading...
100%|████████████████████████████████████████| 601k/601k [00:00<00:00, 797kiB/s]
Loading...
Done!
Sampling with alpha=1.0
Save 80 training proteins for site-1 (frac=0.041)
Save 365 training proteins for site-2 (frac=0.190)
Save 216 training proteins for site-3 (frac=0.112)
Save 578 training proteins for site-4 (frac=0.300)
Save 568 training proteins for site-5 (frac=0.295)
Save 119 training proteins for site-6 (frac=0.062)
Saved 1927 training and 482 testing proteins.
  TRAIN Pos/Neg ratio: neg=366, pos=1561: 4.265
  TRAIN Trivial accuracy: 0.810
  TEST Pos/Neg ratio: neg=116, pos=366: 3.155
  TEST Trivial accuracy: 0.759
[[       nan 0.04657534 0.02314815 0.0449827  0.04929577 0.03361345]
 [       nan        nan 0.18055556 0.17128028 0.19542254 0.18487395]
 [       nan        nan        nan 0.11591696 0.10211268 0.08403361]
 [       nan        nan        nan        nan 0.28521127 0.32773109]
 [       nan        nan        nan        nan        nan 0.28571429]
 [       n

Again, we are using the Dirichlet sampling strategy to generate heterogeneous data distributions among clients.
Lower values of `alpha` generate higher levels of heterogeneity.
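
The relationship between `alpha` and heterogeneity can be checked numerically. The sketch below, which assumes a symmetric Dirichlet distribution over six clients, estimates the average share held by the largest client for several values of `alpha`; a larger value indicates a more unbalanced split. The helper name and the choice of this summary statistic are illustrative assumptions.

```python
import numpy as np


def mean_largest_share(alpha, n_clients=6, n_draws=2000, seed=0):
    """Average fraction of data held by the largest client over many draws."""
    rng = np.random.default_rng(seed)
    draws = rng.dirichlet(alpha * np.ones(n_clients), size=n_draws)
    return float(draws.max(axis=1).mean())


for alpha in (0.1, 1.0, 10.0):
    print(f"alpha={alpha}: mean largest share = {mean_largest_share(alpha):.3f}")
```

As `alpha` grows, the largest share approaches the uniform value of `1 / n_clients`; as it shrinks, a single client tends to hold most of the data.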

|                                            Alpha 10.0                                             |                                            Alpha 1.0                                            |
|:-------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------:|
| <img src="./sabdab/figs/sabdab_alpha10.0.svg" alt="Dirichlet sampling (alpha=10.0)" width="150"/> | <img src="./sabdab/figs/sabdab_alpha1.0.svg" alt="Dirichlet sampling (alpha=1.0)" width="150"/> |


**Run training (central, local, & FL)**

You can change the FL job that's going to be simulated inside the `run_sim_sabdab.py` script.

In [10]:
! cd /bionemo_nvflare_examples/downstream/sabdab && python run_sim_sabdab.py

Traceback (most recent call last):
  File "/bionemo_nvflare_examples/downstream/sabdab/run_sim_sabdab.py", line 15, in <module>
    from nvflare import SimulatorRunner
ModuleNotFoundError: No module named 'nvflare'
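
The traceback above reports that `nvflare` could not be imported by the Python interpreter that ran the script. A minimal sketch for checking, from within the notebook, whether a package is importable by the current kernel is shown here. It uses only the standard library; `is_importable` is a hypothetical helper name. Note that `! python` may resolve to a different interpreter than the notebook kernel, so printing `sys.executable` can help when diagnosing this kind of error.

```python
import importlib.util
import sys


def is_importable(module_name):
    """Return True if a top-level module can be found by this interpreter."""
    return importlib.util.find_spec(module_name) is not None


print("Interpreter:", sys.executable)
print("nvflare importable:", is_importable("nvflare"))
```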


#### Results with heterogeneous data sampling (alpha=10.0)
| Setting | Accuracy  |
|:-------:|:---------:|
|  Local  |   0.821   |
|   FL    | **0.833** |

#### Results with heterogeneous data sampling (alpha=1.0)
| Setting | Accuracy  |
|:-------:|:---------:|
|  Local  |   0.813   |
|   FL    | **0.835** |
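
For context when reading these accuracies, the data preparation output above reports a "trivial accuracy", which is consistent with always predicting the majority class. The sketch below reproduces the printed ratios and trivial accuracies from the class counts in that output; `class_balance` is a hypothetical helper name.

```python
def class_balance(neg, pos):
    """Return (pos/neg ratio, accuracy of always predicting the majority class)."""
    total = neg + pos
    ratio = pos / neg
    trivial_accuracy = max(neg, pos) / total
    return ratio, trivial_accuracy


# Class counts copied from the data preparation output.
train_ratio, train_trivial = class_balance(neg=366, pos=1561)
test_ratio, test_trivial = class_balance(neg=116, pos=366)

print(f"TRAIN Pos/Neg ratio: {train_ratio:.3f}, trivial accuracy: {train_trivial:.3f}")
print(f"TEST  Pos/Neg ratio: {test_ratio:.3f}, trivial accuracy: {test_trivial:.3f}")
```

A model is only informative on this task if its accuracy exceeds the majority-class baseline on the same split.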

### Task 3: Subcellular location prediction with ESM2nv 650M
Follow the data download and preparation in [task_fitting.ipynb](../task_fitting/task_fitting.ipynb).

Here, we use a heterogeneous sampling with `alpha=1.0`.

<img src="./scl/figs/scl_alpha1.0.svg" alt="Dirichlet sampling (alpha=1.0)" width="300"/>


In [11]:
# For this to work, first run the task_fitting notebook at ../nvflare_with_bionemo/task_fitting/task_fitting.ipynb
! cd /bionemo_nvflare_examples/downstream/scl && python run_sim_scl.py

Traceback (most recent call last):
  File "/bionemo_nvflare_examples/downstream/scl/run_sim_scl.py", line 15, in <module>
    from nvflare import SimulatorRunner
ModuleNotFoundError: No module named 'nvflare'


Note that you can switch between local and FL jobs by modifying the `run_sim_scl.py` script.

#### Results with heterogeneous data sampling (alpha=10.0)
| Setting | Accuracy  |
|:-------:|:---------:|
|  Local  |   0.773   |
|   FL    | **0.776** |


<img src="./scl/figs/scl_results.svg" alt="Subcellular location prediction results" width="300"/>