# Federated Protein Property Prediction with BioNeMo

<div class="alert alert-block alert-info"> <b>NOTE:</b> This notebook was tested on a DGX with one A100 GPU with 80 GB memory and is compatible with BioNeMo Framework v2.5. To leverage additional or higher-performance GPUs, you can modify the configuration files and simulation script to accommodate multiple devices and increase thread utilization, respectively. To lower memory consumption, reduce the micro-batch sizes in the `run_*.py` scripts.</div>

This example shows how to fine-tune a pretrained ESM-2 model on a subcellular location prediction task, comparing local training with federated training.

## Prerequisites

<div class="alert alert-block alert-info"> <b>NOTE:</b> This notebook is designed to run inside the BioNeMo Framework Docker container. Follow these [instructions](https://docs.nvidia.com/ai-enterprise/deployment/vmware/latest/docker.html) to set up your Docker environment and execute the following bash script before opening this notebook.</div>

To set up your environment, simply run (outside this notebook):

```bash
./start_bionemo.sh
```

This script will automatically pull the [BioNeMo Docker container](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/clara/containers/bionemo-framework) (tested with version nvcr.io/nvidia/clara/bionemo-framework:2.5) and launch Jupyter Lab at http://hostname:8888. Open that URL in your browser and access this notebook.

For detailed setup guidance, refer to the [BioNeMo User Guide](https://docs.nvidia.com/bionemo-framework/latest/user-guide/).

Once you open this notebook, continue executing the cells below.

<div class="alert alert-block alert-info"> <b>NOTE:</b> Some cells below produce long outputs. To suppress them, we use:<br><br> <pre>%%capture --no-display --no-stderr cell_output</pre><br> Comment or remove this line to restore full output.</div>

### Import and install all required packages

In [None]:
%%capture --no-display --no-stderr cell_output
! pip install "nvflare>=2.6"
! pip install biopython --no-dependencies

import io
import os
import warnings
import requests
from Bio import SeqIO

warnings.filterwarnings("ignore")
warnings.simplefilter("ignore")

## Subcellular location prediction with ESM2nv
Here, we are interested in training a neural network to predict subcellular location from an embedding.

The data we will be using comes from the paper ["Light attention predicts protein location from the language of life"](https://academic.oup.com/bioinformaticsadvances/article/1/1/vbab035/6432029) by Stärk et al. In this paper, the authors developed a machine learning algorithm to predict the subcellular location of proteins from sequence through protein language models that are similar to those available in BioNeMo. Protein subcellular location refers to where the protein localizes in the cell; for example, a protein may localize to the nucleus or the cytoplasm. Knowing where proteins localize can provide insights into the underlying mechanisms of cellular processes and help identify potential targets for drug development. The following image includes a few examples of subcellular locations in an animal cell:

<img src="https://cdn.pixabay.com/photo/2012/05/07/14/58/cell-48542_1280.png" alt="Subcellular locations" width="500"/>
(Image freely available at https://pixabay.com/images/id-48542)

**Data Splitting**

Here, we simulate heterogeneous client data using Dirichlet sampling with `alpha=1.0`, which yields moderate label heterogeneity across clients. To speed up the runtime and reduce computational resources, we use the pretrained [ESM-2nv 8M](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/clara/models/esm2nv8m) parameter model from BioNeMo.

<img src="./figs/scl_alpha1.0.svg" alt="Dirichlet sampling (alpha=1.0)" width="500"/>
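A common way to simulate this kind of label heterogeneity is to draw, for every class, the fraction of its samples assigned to each site from a Dirichlet distribution with concentration `alpha`; smaller values of `alpha` concentrate each class on fewer sites. The sketch below illustrates the idea with NumPy. `dirichlet_label_split` is a hypothetical helper, and we have not verified that it matches the internals of the `split` function in `split_data.py` used later in this notebook.

```python
import numpy as np


def dirichlet_label_split(labels, num_sites, alpha, seed=42):
    """Partition sample indices across sites with Dirichlet label skew.

    Hypothetical helper for illustration only.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    site_indices = [[] for _ in range(num_sites)]
    for label in np.unique(labels):
        # Indices of all samples with this label, in random order.
        idx = np.flatnonzero(labels == label)
        rng.shuffle(idx)
        # Fraction of this class that each site receives.
        proportions = rng.dirichlet(alpha * np.ones(num_sites))
        # Turn cumulative fractions into cut points inside idx.
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for site, chunk in enumerate(np.split(idx, cuts)):
            site_indices[site].extend(chunk.tolist())
    return site_indices


labels = ["Nucleus"] * 50 + ["Cytoplasm"] * 30 + ["Cell_membrane"] * 20
sites = dirichlet_label_split(labels, num_sites=3, alpha=1.0)
for s, idx in enumerate(sites):
    counts = {c: sum(labels[i] == c for i in idx) for c in sorted(set(labels))}
    print(f"site-{s + 1}: {counts}")
```

Every sample lands on exactly one site; rerunning with a larger `alpha` (for example, `alpha=10.0`) should give per-site class counts that are closer to the global class proportions.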

### Data prep
For our target input sequences, we will point to FASTA sequences in a benchmark dataset called Fitness Landscape Inference for Proteins (FLIP). FLIP encompasses experimental data across adeno-associated virus stability for gene therapy, protein domain B1 stability and immunoglobulin binding, and thermostability from multiple protein families. Here, we use its subcellular location (SCL) split.

In [None]:
# Example protein dataset location
fasta_url = "http://data.bioembeddings.com/public/FLIP/fasta/scl/mixed_soft.fasta"

The cell above defines the source of the example protein dataset. This data follows the [biotrainer](https://github.com/sacdallago/biotrainer/blob/main/docs/data_standardization.md) standard, so each FASTA header carries the class label and split information, followed by the protein sequence. Here are two example sequences from this file:

```
>Sequence1 TARGET=Cell_membrane SET=train VALIDATION=False
MMKTLSSGNCTLNVPAKNSYRMVVLGASRVGKSSIVSRFLNGRFEDQYTPTIEDFHRKVYNIHGDMYQLDILDTSGNHPFPAM
RRLSILTGDVFILVFSLDSRESFDEVKRLQKQILEVKSCLKNKTKEAAELPMVICGNKNDHSELCRQVPAMEAELLVSGDENC
AYFEVSAKKNTNVNEMFYVLFSMAKLPHEMSPALHHKISVQYGDAFHPRPFCMRRTKVAGAYGMVSPFARRPSVNSDLKYIKA
KVLREGQARERDKCSIQ
>Sequence4833 TARGET=Nucleus SET=train VALIDATION=False
MARTKQTARKSTGGKAPRKQLATKAARKSAPATGGVKKPHRFRPGTVALREIRKYQKSTELLIRKLPFQRLVREIAQDFKTDL
RFQSSAVAALQEAAEAYLVGLFEDTNLCAIHAKRVTIMPKDIQLARRIRGERA
```

Note the following attributes in the FASTA header:

* `TARGET` attribute holds the subcellular location class of the sequence, for instance Cell_membrane or Nucleus. This dataset includes a total of ten subcellular location classes.
* `SET` attribute defines whether the sequence should be used for training (train) or testing (test).
* `VALIDATION` attribute defines whether the sequence should be used for validation (all sequences where this is True also have SET=train).
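The header attributes can be read with a `KEY=value` regular expression, as the data-splitting cell further below does, and then mapped to a train, validation, or test split. The sketch below is a minimal, pure-Python illustration; `parse_header` and `assign_split` are hypothetical helpers, and the value character class is written as `[A-Za-z0-9_]` so that labels such as `Cell_membrane` are captured whole.

```python
import re

# KEY=value pairs such as TARGET=Cell_membrane SET=train VALIDATION=False
HEADER_ATTR = re.compile(r"([A-Z_]+)=(-?[A-Za-z0-9_]+[.0-9]*)")


def parse_header(description):
    """Return the KEY=value attributes of a biotrainer-style FASTA header."""
    return dict(HEADER_ATTR.findall(description))


def assign_split(attrs):
    """Map the SET and VALIDATION attributes to 'train', 'valid', or 'test'."""
    if attrs["SET"] == "test":
        return "test"
    # Validation sequences are a subset of SET=train.
    return "valid" if attrs["VALIDATION"] == "True" else "train"


header = "Sequence1 TARGET=Cell_membrane SET=train VALIDATION=False"
attrs = parse_header(header)
print(attrs)  # {'TARGET': 'Cell_membrane', 'SET': 'train', 'VALIDATION': 'False'}
print(assign_split(attrs))  # train
```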

### Downloading the protein sequences and subcellular location annotations
In this step we download the FASTA file defined above and parse the sequences into a list of BioPython SeqRecord objects.

In [None]:
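# Download the FASTA file from FLIP: https://github.com/J-SNACKKB/FLIP/tree/main/splits/scl
response = requests.get(fasta_url, headers={
    'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; Win64; x86)'
}, timeout=60)
response.raise_for_status()  # fail early on HTTP errors instead of parsing an error page
fasta_content = response.content.decode('utf-8')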
fasta_stream = io.StringIO(fasta_content)

# Obtain a list of SeqRecords/proteins which contain sequence and attributes
# from the FASTA header
proteins = list(SeqIO.parse(fasta_stream, "fasta"))
print(f"Downloaded {len(proteins)} sequences")

### Data splitting
Next, we prepare the data for simulating federated learning across `n_clients` clients. Note that in this example each client receives a copy of the same test set; the test set is not split.

In [None]:
import numpy as np
import re
from split_data import split

n_clients = 3
# Keep only proteins with at most 512 residues for embedding queries
MAX_SEQUENCE_LEN = 512
seed = 42
data_root = "/tmp/data/mixed_soft"
split_alpha = 1.0  # moderate label heterogeneity of alpha=1.0

np.random.seed(seed)

# Extract metadata from each FASTA header, e.g. TARGET=Cell_membrane SET=train VALIDATION=False
header_pattern = re.compile(r"([A-Z_]+)=(-?[A-Za-z0-9_]+[.0-9]*)")
data = []
for i, x in enumerate(proteins):
    if len(str(x.seq)) > MAX_SEQUENCE_LEN:
        continue

    entry = dict(header_pattern.findall(x.description))
    entry["sequences"] = str(x.seq)
    entry["id"] = str(i)
    entry["labels"] = entry["TARGET"]

    data.append(entry)
print(f"Read {len(data)} valid sequences.")

# Split the data and save it for each client
# Note: test_data is kept the same on each client and is not split
# `concat=False` is used for SCL experiments (see ../downstream/scl)
split(proteins=data, num_sites=n_clients, split_dir=data_root, alpha=split_alpha, concat=False)

**Run training (local and FL)**

As usual, [run_sim_scl.py](./run_sim_scl.py) uses NVFlare's Job API to configure our job.
You can change the FL job to be simulated by changing the arguments of the run script, including which ESM-2 model to download (8M or 650M parameters). Other ESM-2 fine-tuning hyperparameters, such as the learning rate, can be modified inside the script itself.

First, let's check its arguments.

In [None]:
!python run_sim_scl.py --help

In this example, we use the `--encoder-frozen` option inside the `run_sim_scl.py` script. You can specify different base ESM2 models using the `--model` option.

**1. Local training**

To simulate local training, we use three clients, each training on its own data split for a single round of 5,000 steps.

In [None]:
# For this to work, first run ../nvflare_with_bionemo/task_fitting/task_fitting.ipynb to download the SCL dataset. All clients run on the same GPU.
!python run_sim_scl.py --num_clients=3 --num_rounds=1 --local_steps=5000 --exp_name "local" --model "8m" --sim_gpus="0"

**2. Federated training with FedAvg**

To simulate federated training, we use the same three clients, running 10 rounds of FedAvg with 500 local steps per round. The total number of training steps per client (10 rounds × 500 steps = 5,000) matches the local training scenario.
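FedAvg builds the global model as an average of the client models; in its original formulation, each client is weighted by its share of the training data. The sketch below shows that weighted average on NumPy arrays, together with the step arithmetic from the paragraph above. The helper `fedavg` and the parameter name `head.weight` are hypothetical, and NVFlare's FedAvg implementation may weight clients differently (for example, by the number of local training steps), so treat this as an illustration of the idea rather than the code that runs inside `run_sim_scl.py`.

```python
import numpy as np


def fedavg(client_weights, client_num_samples):
    """Sample-weighted average of client parameter dicts (illustrative FedAvg)."""
    total = float(sum(client_num_samples))
    return {
        name: sum(w[name] * (n / total) for w, n in zip(client_weights, client_num_samples))
        for name in client_weights[0]
    }


# Training budget per client: FedAvg spreads the same 5,000 steps over 10 rounds.
local_steps_total = 1 * 5000
fedavg_steps_total = 10 * 500
assert local_steps_total == fedavg_steps_total

# Toy example: three clients, each holding one 2-element parameter tensor.
clients = [
    {"head.weight": np.array([1.0, 1.0])},
    {"head.weight": np.array([2.0, 2.0])},
    {"head.weight": np.array([4.0, 4.0])},
]
global_model = fedavg(clients, client_num_samples=[1, 1, 2])
print(global_model["head.weight"])  # [2.75 2.75]
```

In each FL round, the server would broadcast `global_model` back to the clients, which then resume local training from it.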

In [None]:
!python run_sim_scl.py --num_clients=3 --num_rounds=10 --local_steps=500 --exp_name "fedavg" --model "8m" --sim_gpus="0"

You can visualize the results in TensorBoard using `tensorboard --logdir /tmp/nvflare/bionemo/scl`. For FedAvg, you can display the continuous training curve streamed to the server by selecting the `server` subfolder.

#### Results with heterogeneous data sampling (alpha=1.0)
|  Client   | Site-1  | Site-2 | Site-3 | Average accuracy |
|:---------:|:-------:|:------:|:------:|:----------------:|
| # Samples |  1844   | 2921   | 2151   |                  |
| Local     |  0.7819 | 0.7885 | 0.7921 | 0.7875           |
| FedAvg    |  0.8179 | 0.8131 | 0.8209 | **0.8173**       |

<img src="./figs/tb_curve_scl.png" alt="SCL Training curve with Dirichlet sampling (alpha=1.0)" width="600"/>
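The Average column is the unweighted mean of the three site accuracies. The short check below recomputes it from the table values and lists each site's accuracy gain from FedAvg over local training; all numbers are copied from the table, so nothing new is measured.

```python
# Per-site accuracies copied from the results table above.
local = {"site-1": 0.7819, "site-2": 0.7885, "site-3": 0.7921}
fedavg = {"site-1": 0.8179, "site-2": 0.8131, "site-3": 0.8209}


def mean(values):
    """Unweighted arithmetic mean."""
    values = list(values)
    return sum(values) / len(values)


gains = {site: fedavg[site] - local[site] for site in local}
print(f"Local average:  {mean(local.values()):.4f}")   # 0.7875
print(f"FedAvg average: {mean(fedavg.values()):.4f}")  # 0.8173
for site, gain in gains.items():
    print(f"{site}: {gain:+.4f}")
```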

## Summary

In this section, we explored the application of federated learning to protein subcellular location prediction using NVIDIA's BioNeMo Framework. Here are the key takeaways:

* We tackled the challenge of predicting protein subcellular locations from sequence data. In a similar fashion, other crucial tasks for biopharma and drug development applications could be addressed.

* We compared both local and federated (FedAvg) training approaches with the ESM-2 8M parameter model from BioNeMo.

**Key Learnings**:
   - In this experiment, FedAvg improved the average accuracy from 0.7875 (local training) to 0.8173
   - Every participating site achieved higher accuracy with collaborative training than with local training alone
   - BioNeMo Framework provides powerful tools for biological sequence analysis

This example demonstrates how federated learning can be applied to healthcare and life sciences applications, enabling collaborative model development without sharing raw data between sites.

In the next [section](../11.2.2_drug_discovery_amplify/finetuning_amplify.ipynb), we'll learn how to fine-tune the AMPLIFY protein language model on multiple downstream tasks.