# Federated Protein Embeddings and Task Model Fitting with BioNeMo

<div class="alert alert-block alert-info"> <b>NOTE</b> This notebook was tested on a single A1000 GPU and is compatible with BioNeMo Framework v1.8. To leverage additional or higher-performance GPUs, you can modify the configuration files and simulation script to accommodate multiple devices and increase thread utilization respectively. </div>

This example notebook shows how to obtain learned protein representations in the form of embeddings using the ESM-1nv pre-trained model in a federated learning (FL) setting. The model is trained with NVIDIA's BioNeMo framework for Large Language Model training and inference. For more details, please visit NVIDIA BioNeMo Service at https://www.nvidia.com/en-us/gpu-cloud/bionemo.

This example is based on NVIDIA BioNeMo Service [example](https://github.com/NVIDIA/BioNeMo/blob/main/examples/service/notebooks/task-fitting-predictor.ipynb) 
but runs inference locally (on the FL clients) instead of using BioNeMo's cloud API.

This notebook will walk you through the task fitting workflow in the following sections:

* Dataset sourcing & Data splitting
* Federated embedding extraction
* Training an MLP to predict subcellular location

## Setup

Ensure that you have read through the Getting Started section, can run the BioNeMo Framework docker container, and have configured the NGC Command Line Interface (CLI) within the container. It is assumed that this notebook is being executed from within the container.

<div class="alert alert-block alert-info"> <b>NOTE</b> Some of the cells below generate long text output.  We're using <pre>%%capture --no-display --no-stderr cell_output</pre> to suppress this output.  Comment or delete this line in the cells below to restore full output.</div>

### Import and install all required packages

In [1]:
#%%capture --no-display --no-stderr cell_output
! pip install nvflare~=2.5.0
! pip install biopython
! pip install scikit-learn
! pip install matplotlib
! pip install protobuf==3.20

import io
import numpy as np
import os
import pickle
import re
import requests
import split_data

from Bio import SeqIO
from importlib import reload
from nvflare import SimulatorRunner  
from split_data import split

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting nvflare~=2.5.0
  Downloading nvflare-2.5.0-py3-none-any.whl.metadata (11 kB)
Collecting cryptography>=36.0.0 (from nvflare~=2.5.0)
  Downloading cryptography-43.0.1-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (5.4 kB)
Collecting Flask==3.0.2 (from nvflare~=2.5.0)
  Downloading flask-3.0.2-py3-none-any.whl.metadata (3.6 kB)
Collecting Flask-JWT-Extended==4.6.0 (from nvflare~=2.5.0)
  Downloading Flask_JWT_Extended-4.6.0-py2.py3-none-any.whl.metadata (3.9 kB)
Collecting Flask-SQLAlchemy==3.1.1 (from nvflare~=2.5.0)
  Downloading flask_sqlalchemy-3.1.1-py3-none-any.whl.metadata (3.4 kB)
Collecting SQLAlchemy==2.0.16 (from nvflare~=2.5.0)
  Downloading SQLAlchemy-2.0.16-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.4 kB)
Collecting grpcio>=1.62.1 (from nvflare~=2.5.0)
  Downloading grpcio-1.66.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.9 kB)
Collecting

### Obtaining the protein embeddings using the BioNeMo ESM-1nv model
Using BioNeMo, each FL client can obtain numerical vector representations of protein sequences called embeddings. Protein embeddings can then be used for visualization or making downstream predictions.

Here we are interested in training a neural network to predict subcellular location from an embedding.

The data we will be using comes from the paper [Light attention predicts protein location from the language of life](https://academic.oup.com/bioinformaticsadvances/article/1/1/vbab035/6432029) by Stärk et al. In this paper, the authors developed a machine learning algorithm to predict the subcellular location of proteins from sequence through protein language models that are similar to those hosted by BioNeMo. Protein subcellular location refers to where the protein localizes in the cell; for example, a protein may be expressed in the Nucleus or in the Cytoplasm. Knowing where proteins localize can provide insights into the underlying mechanisms of cellular processes and help identify potential targets for drug development. The following image includes a few examples of subcellular locations in an animal cell:


(Image freely available at https://pixabay.com/images/id-48542)

### Dataset sourcing
For our target input sequences, we will point to FASTA sequences in a benchmark dataset called Fitness Landscape Inference for Proteins (FLIP). FLIP encompasses experimental data across adeno-associated virus stability for gene therapy, protein domain B1 stability and immunoglobulin binding, and thermostability from multiple protein families.

In [2]:
# Example protein dataset location
fasta_url = "http://data.bioembeddings.com/public/FLIP/fasta/scl/mixed_soft.fasta"

First, we define the source of the example protein dataset with the FASTA sequences. This data follows the [biotrainer](https://github.com/sacdallago/biotrainer/blob/main/docs/data_standardization.md) standard, so each record includes the class information in the FASTA header, followed by the protein sequence. Here are two example sequences in this file:

```
>Sequence1 TARGET=Cell_membrane SET=train VALIDATION=False
MMKTLSSGNCTLNVPAKNSYRMVVLGASRVGKSSIVSRFLNGRFEDQYTPTIEDFHRKVYNIHGDMYQLDILDTSGNHPFPAM
RRLSILTGDVFILVFSLDSRESFDEVKRLQKQILEVKSCLKNKTKEAAELPMVICGNKNDHSELCRQVPAMEAELLVSGDENC
AYFEVSAKKNTNVNEMFYVLFSMAKLPHEMSPALHHKISVQYGDAFHPRPFCMRRTKVAGAYGMVSPFARRPSVNSDLKYIKA
KVLREGQARERDKCSIQ
>Sequence4833 TARGET=Nucleus SET=train VALIDATION=False
MARTKQTARKSTGGKAPRKQLATKAARKSAPATGGVKKPHRFRPGTVALREIRKYQKSTELLIRKLPFQRLVREIAQDFKTDL
RFQSSAVAALQEAAEAYLVGLFEDTNLCAIHAKRVTIMPKDIQLARRIRGERA
```

Note the following attributes in the FASTA header:

* `TARGET` attribute holds the subcellular location classification for the sequence, for instance `Cell_membrane` and `Nucleus`. This dataset includes a total of ten subcellular location classes, which are listed later in this notebook.
* `SET` attribute defines whether the sequence should be used for training (`train`) or testing (`test`).
* `VALIDATION` attribute defines whether the sequence should be used for validation (all sequences where this is `True` also have `SET=train`).
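To make the header layout concrete, the following minimal sketch parses the `KEY=value` attributes of one of the example headers into a dictionary. The helper name `parse_header` is hypothetical, and the regular expression is an assumption modeled on the pattern used in the data splitting cell of this notebook, not a biotrainer API.

```python
import re

# KEY=value pairs: upper-case key, then an alphanumeric/underscore value
# (optionally signed and with a numeric suffix such as "1.5").
HEADER_ATTR = re.compile(r"([A-Z_]+)=(-?[A-Za-z0-9_]+[.0-9]*)")


def parse_header(description):
    """Return the KEY=value attributes of a FASTA header as a dict of strings."""
    return dict(HEADER_ATTR.findall(description))


header = "Sequence1 TARGET=Cell_membrane SET=train VALIDATION=False"
attrs = parse_header(header)
print(attrs)
# {'TARGET': 'Cell_membrane', 'SET': 'train', 'VALIDATION': 'False'}

# All values are strings, so flags must be compared explicitly.
is_validation = attrs["VALIDATION"] == "True"
print(is_validation)  # False
```

Note that the attribute values stay strings; `"False"` is truthy in Python, so a comparison against `"True"` is safer than `bool(attrs["VALIDATION"])`.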

### Downloading the protein sequences and subcellular location annotations
In this step we download the FASTA file defined above and parse the sequences into a list of BioPython SeqRecord objects.



In [3]:
# Download the FASTA file from FLIP: https://github.com/J-SNACKKB/FLIP/tree/main/splits/scl
fasta_content = requests.get(fasta_url, headers={
    'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; Win64; x86)'
}).content.decode('utf-8')
fasta_stream = io.StringIO(fasta_content)

# Obtain a list of SeqRecords/proteins which contain sequence and attributes
# from the FASTA header
proteins = list(SeqIO.parse(fasta_stream, "fasta"))
print(f"Downloaded {len(proteins)} sequences")

Downloaded 13949 sequences


In [4]:
bionemo_home = "/workspace/bionemo"
os.environ['BIONEMO_HOME'] = bionemo_home

### Download Model Checkpoints

To download pretrained models from the NGC registry, **please ensure that you have installed and configured the NGC CLI**; see the [Quickstart Guide](https://docs.nvidia.com/bionemo-framework/latest/quickstart-fw.html) for more information. The following code downloads the pretrained model `esm1nv.nemo` from the NGC registry.

In [5]:
# Define the NGC CLI API KEY and ORG for the model download.
# If these variables are not already set in the container, uncomment the lines
# below and fill in your own API KEY and ORG. Do not commit a real key.
# api_key = "<YOUR_NGC_API_KEY>"
# ngc_cli_org = "<YOUR_NGC_ORG>"
# # Update the environment variables
# os.environ['NGC_CLI_API_KEY'] = api_key
# os.environ['NGC_CLI_ORG'] = ngc_cli_org

model_name = "esm1nv"
actual_checkpoint_name = "esm1nv.nemo"
model_path = os.path.join(bionemo_home, 'models')
checkpoint_path = os.path.join(model_path, actual_checkpoint_name)
os.environ['MODEL_PATH'] = model_path

In [6]:
%%capture --no-display --no-stderr cell_output
if not os.path.exists(checkpoint_path):
    !cd /workspace/bionemo && \
    python download_artifacts.py --model_dir models --models {model_name}
else:
    print(f"Model {model_name} already exists at {model_path}.")

### Data splitting
Next, we prepare the data for simulating federated learning using `n_clients`.
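The split used in this notebook partitions the labeled proteins across sites with Dirichlet sampling controlled by a concentration parameter `alpha`. The sketch below illustrates the general idea only; it is not the implementation in `split_data.split`, and the function names and toy labels are hypothetical. For each class, it draws site proportions from a symmetric Dirichlet distribution (sampled here as normalized Gamma draws from the standard library) and cuts the shuffled class indices accordingly.

```python
import random
from collections import Counter, defaultdict


def dirichlet_proportions(n_sites, alpha, rng):
    """Sample from a symmetric Dirichlet(alpha) via normalized Gamma draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(n_sites)]
    total = sum(draws)
    return [d / total for d in draws]


def partition_by_label(labels, n_sites, alpha, seed=0):
    """Assign each sample index to one site, class by class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    sites = [[] for _ in range(n_sites)]
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        props = dirichlet_proportions(n_sites, alpha, rng)
        start, cumulative = 0, 0.0
        for site in range(n_sites):
            cumulative += props[site]
            # The last site takes the remainder so every index is assigned.
            end = len(indices) if site == n_sites - 1 else round(cumulative * len(indices))
            sites[site].extend(indices[start:end])
            start = end
    return sites


# Toy labels standing in for the TARGET attribute of each protein.
labels = ["Nucleus"] * 60 + ["Cytoplasm"] * 30 + ["Peroxisome"] * 10
sites = partition_by_label(labels, n_sites=3, alpha=100.0)
for k, idx in enumerate(sites, start=1):
    print(f"site-{k}: {dict(Counter(labels[i] for i in idx))}")
```

A large `alpha` concentrates the sampled proportions around an even split, so each site sees a similar class distribution. Small values such as `alpha=0.1` tend to give each site only a few dominant classes.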

In [7]:
n_clients = 3
# limit to proteins with sequence length <= 512 for embedding queries
MAX_SEQUENCE_LEN = 512
seed = 0
out_dir = "/tmp/data/mixed_soft"
split_alpha = 100.0  # large alpha gives a nearly uniform label split across sites; smaller values (e.g., 1.0) increase label heterogeneity

reload(split_data)

np.random.seed(seed)

# Extract metadata and split
data = []
for i, x in enumerate(proteins):
    # Skip sequences longer than the maximum length used for embedding queries
    if len(str(x.seq)) > MAX_SEQUENCE_LEN:
        continue

    # Parse the KEY=value attributes (TARGET, SET, VALIDATION) from the FASTA header
    entry = {key: value for key, value in re.findall(r"([A-Z_]+)=(-?[A-Za-z0-9_]+[.0-9]*)", x.description)}
    entry["sequence"] = str(x.seq)
    entry["id"] = str(i)

    data.append(entry)
print(f"Read {len(data)} valid sequences.")
               
# Split the data and save for each client
# Note, test_data is kept the same on each client and is not split
# `concat=False` is used for SCL experiments (see ../downstream/scl)
split(proteins=data, num_sites=n_clients, split_dir=out_dir, alpha=split_alpha, concat=False)  
# `concat=True` is used for separate inference + MLP classifier in this notebook
split(proteins=data, num_sites=n_clients, split_dir=out_dir, alpha=split_alpha, concat=True)  

Read 8619 valid sequences.
Partition protein dataset with 10 classes into 3 sites with Dirichlet sampling under alpha 100.0
{'site-1': {'Cell_membrane': 178,
            'Cytoplasm': 402,
            'Endoplasmic_reticulum': 160,
            'Extracellular': 461,
            'Golgi_apparatus': 56,
            'Lysosome': 51,
            'Mitochondrion': 324,
            'Nucleus': 587,
            'Peroxisome': 30,
            'Plastid': 142},
 'site-2': {'Cell_membrane': 156,
            'Cytoplasm': 354,
            'Endoplasmic_reticulum': 140,
            'Extracellular': 404,
            'Golgi_apparatus': 50,
            'Lysosome': 45,
            'Mitochondrion': 285,
            'Nucleus': 515,
            'Peroxisome': 27,
            'Plastid': 125},
 'site-3': {'Cell_membrane': 180,
            'Cytoplasm': 407,
            'Endoplasmic_reticulum': 162,
            'Extracellular': 467,
            'Golgi_apparatus': 58,
            'Lysosome': 53,
            'Mitochondrio

### Federated embedding extraction
Running inference of the ESM-1nv model to extract embeddings requires a GPU with at least 12 GB memory. Here we run inference on each client sequentially using one thread to preserve GPU memory.

First, copy the model into the job folder:

In [8]:
!cp /workspace/bionemo/models/esm1nv.nemo jobs/embeddings/app/models/.

In [9]:
simulator = SimulatorRunner(
    job_folder="jobs/embeddings",
    workspace="/tmp/nvflare/bionemo/embeddings",
    n_clients=n_clients,
    threads=1  # due to memory constraints, we run the client execution sequentially in one thread
)
run_status = simulator.run()
print("Simulator finished with run_status", run_status)

2024-10-04 20:49:50,145 - SimulatorRunner - INFO - Create the Simulator Server.
2024-10-04 20:49:50,150 - CoreCell - INFO - server: creating listener on tcp://0:50447
2024-10-04 20:49:50,174 - CoreCell - INFO - server: created backbone external listener for tcp://0:50447
2024-10-04 20:49:50,175 - ConnectorManager - INFO - 773: Try start_listener Listener resources: {'secure': False, 'host': 'localhost'}
2024-10-04 20:49:50,178 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector [CH00002 PASSIVE tcp://0:55960] is starting
2024-10-04 20:49:50,681 - CoreCell - INFO - server: created backbone internal listener for tcp://localhost:55960
2024-10-04 20:49:50,684 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector [CH00001 PASSIVE tcp://0:50447] is starting
2024-10-04 20:49:50,767 - nvflare.fuel.hci.server.hci - INFO - Starting Admin Server localhost on Port 42679
2024-10-04 20:49:50,767 - SimulatorRunner - INFO - Deploy the Apps.
2024-10-04 20:49:51,258 - SimulatorRunner - INFO - Create t

[NeMo W 2024-10-04 20:50:10 save_restore_connector:394] src path does not exist or it is not a path in nemo file. src value I got was: /tokenizers/vocab/protein_sequence_sentencepiece.vocab. Absolute: /tokenizers/vocab/protein_sequence_sentencepiece.vocab
[NeMo W 2024-10-04 20:50:10 megatron_base_model:1109] The model: ESM1nvModel() does not have field.name: context_parallel_size in its cfg. Add this key to cfg or config_mapping to make to make it configurable.
[NeMo W 2024-10-04 20:50:10 megatron_base_model:1109] The model: ESM1nvModel() does not have field.name: virtual_pipeline_model_parallel_size in its cfg. Add this key to cfg or config_mapping to make to make it configurable.
[NeMo W 2024-10-04 20:50:10 megatron_base_model:1109] The model: ESM1nvModel() does not have field.name: sequence_parallel in its cfg. Add this key to cfg or config_mapping to make to make it configurable.
[NeMo W 2024-10-04 20:50:10 megatron_base_model:1109] The model: ESM1nvModel() does not have field.name

[NeMo I 2024-10-04 20:50:10 megatron_init:251] Rank 0 has data parallel group : [0]
[NeMo I 2024-10-04 20:50:10 megatron_init:257] Rank 0 has combined group of data parallel and context parallel : [0]
[NeMo I 2024-10-04 20:50:10 megatron_init:262] All data parallel group ranks with context parallel combined: [[0]]
[NeMo I 2024-10-04 20:50:10 megatron_init:265] Ranks 0 has data parallel rank: 0
[NeMo I 2024-10-04 20:50:10 megatron_init:282] Rank 0 has context parallel group: [0]
[NeMo I 2024-10-04 20:50:10 megatron_init:285] All context parallel group ranks: [[0]]
[NeMo I 2024-10-04 20:50:10 megatron_init:286] Ranks 0 has context parallel rank: 0
[NeMo I 2024-10-04 20:50:10 megatron_init:297] Rank 0 has model parallel group: [0]
[NeMo I 2024-10-04 20:50:10 megatron_init:298] All model parallel group ranks: [[0]]
[NeMo I 2024-10-04 20:50:10 megatron_init:308] Rank 0 has tensor model parallel group: [0]
[NeMo I 2024-10-04 20:50:10 megatron_init:312] All tensor model parallel group ranks: 

[NeMo W 2024-10-04 20:50:50 save_restore_connector:394] src path does not exist or it is not a path in nemo file. src value I got was: /tokenizers/vocab/protein_sequence_sentencepiece.vocab. Absolute: /tokenizers/vocab/protein_sequence_sentencepiece.vocab
[NeMo W 2024-10-04 20:50:50 megatron_base_model:1109] The model: ESM1nvModel() does not have field.name: context_parallel_size in its cfg. Add this key to cfg or config_mapping to make to make it configurable.
[NeMo W 2024-10-04 20:50:50 megatron_base_model:1109] The model: ESM1nvModel() does not have field.name: virtual_pipeline_model_parallel_size in its cfg. Add this key to cfg or config_mapping to make to make it configurable.
[NeMo W 2024-10-04 20:50:50 megatron_base_model:1109] The model: ESM1nvModel() does not have field.name: sequence_parallel in its cfg. Add this key to cfg or config_mapping to make to make it configurable.
[NeMo W 2024-10-04 20:50:50 megatron_base_model:1109] The model: ESM1nvModel() does not have field.name

[NeMo I 2024-10-04 20:50:50 megatron_init:251] Rank 0 has data parallel group : [0]
[NeMo I 2024-10-04 20:50:50 megatron_init:257] Rank 0 has combined group of data parallel and context parallel : [0]
[NeMo I 2024-10-04 20:50:50 megatron_init:262] All data parallel group ranks with context parallel combined: [[0]]
[NeMo I 2024-10-04 20:50:50 megatron_init:265] Ranks 0 has data parallel rank: 0
[NeMo I 2024-10-04 20:50:50 megatron_init:282] Rank 0 has context parallel group: [0]
[NeMo I 2024-10-04 20:50:50 megatron_init:285] All context parallel group ranks: [[0]]
[NeMo I 2024-10-04 20:50:50 megatron_init:286] Ranks 0 has context parallel rank: 0
[NeMo I 2024-10-04 20:50:50 megatron_init:297] Rank 0 has model parallel group: [0]
[NeMo I 2024-10-04 20:50:50 megatron_init:298] All model parallel group ranks: [[0]]
[NeMo I 2024-10-04 20:50:50 megatron_init:308] Rank 0 has tensor model parallel group: [0]
[NeMo I 2024-10-04 20:50:50 megatron_init:312] All tensor model parallel group ranks: 

[NeMo W 2024-10-04 20:51:29 save_restore_connector:394] src path does not exist or it is not a path in nemo file. src value I got was: /tokenizers/vocab/protein_sequence_sentencepiece.vocab. Absolute: /tokenizers/vocab/protein_sequence_sentencepiece.vocab
[NeMo W 2024-10-04 20:51:29 megatron_base_model:1109] The model: ESM1nvModel() does not have field.name: context_parallel_size in its cfg. Add this key to cfg or config_mapping to make to make it configurable.
[NeMo W 2024-10-04 20:51:29 megatron_base_model:1109] The model: ESM1nvModel() does not have field.name: virtual_pipeline_model_parallel_size in its cfg. Add this key to cfg or config_mapping to make to make it configurable.
[NeMo W 2024-10-04 20:51:29 megatron_base_model:1109] The model: ESM1nvModel() does not have field.name: sequence_parallel in its cfg. Add this key to cfg or config_mapping to make to make it configurable.
[NeMo W 2024-10-04 20:51:29 megatron_base_model:1109] The model: ESM1nvModel() does not have field.name

[NeMo I 2024-10-04 20:51:29 megatron_init:251] Rank 0 has data parallel group : [0]
[NeMo I 2024-10-04 20:51:29 megatron_init:257] Rank 0 has combined group of data parallel and context parallel : [0]
[NeMo I 2024-10-04 20:51:29 megatron_init:262] All data parallel group ranks with context parallel combined: [[0]]
[NeMo I 2024-10-04 20:51:29 megatron_init:265] Ranks 0 has data parallel rank: 0
[NeMo I 2024-10-04 20:51:29 megatron_init:282] Rank 0 has context parallel group: [0]
[NeMo I 2024-10-04 20:51:29 megatron_init:285] All context parallel group ranks: [[0]]
[NeMo I 2024-10-04 20:51:29 megatron_init:286] Ranks 0 has context parallel rank: 0
[NeMo I 2024-10-04 20:51:29 megatron_init:297] Rank 0 has model parallel group: [0]
[NeMo I 2024-10-04 20:51:29 megatron_init:298] All model parallel group ranks: [[0]]
[NeMo I 2024-10-04 20:51:29 megatron_init:308] Rank 0 has tensor model parallel group: [0]
[NeMo I 2024-10-04 20:51:29 megatron_init:312] All tensor model parallel group ranks: 

### Inspecting the embeddings and labels
Embeddings returned from the BioNeMo model are vectors of fixed size for each input sequence. In other words, if we input 10 sequences, we will obtain a matrix `10xD`, where `D` is the size of the embedding (in the case of ESM-1nv, `D=768`). At a glance, these real-valued vector embeddings don't show any obvious features (see the printout in the next cell). But these vectors do contain information that can be used in downstream models to reveal properties of the protein, for example the subcellular location as we'll explore below.
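As a small illustration of the `10xD` shape described above, the sketch below stacks per-sequence embedding vectors into a single matrix, which is the layout a downstream classifier would typically expect. The records here are random stand-ins (an assumption for illustration), not real ESM-1nv outputs; only the `embeddings` and `id` keys mirror the inference results inspected in this notebook.

```python
import numpy as np

D = 768  # embedding size of ESM-1nv, as stated above
rng = np.random.default_rng(0)

# Hypothetical stand-in records; real ones come from the inference output.
records = [
    {"id": str(i), "embeddings": rng.standard_normal(D).astype(np.float32)}
    for i in range(10)
]

# Stack the fixed-size vectors row-wise into an (N, D) feature matrix.
X = np.stack([record["embeddings"] for record in records])
print(X.shape)  # (10, 768)
```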

In [10]:
# load embeddings from site-1 (the context manager closes the file after reading)
with open(os.path.join(out_dir, "data_site-1.pkl"), "rb") as f:
    protein_embeddings = pickle.load(f)
print(f"Loaded {len(protein_embeddings)} embeddings from site-1.")

for i in range(4):
    protein_embedding = protein_embeddings[i]
    print(f"Inference result contains {list(protein_embedding.keys())}")
    x = protein_embedding["embeddings"]
    print(f"{protein_embedding['id']}: range {np.min(x)}-{np.max(x)}, mean={np.mean(x)}, shape={x.shape}")

Loaded 4072 embeddings from site-1.
Inference result contains ['embeddings', 'hiddens', 'sequence', 'id']
1918: range -0.81640625-1.8037109375, mean=-0.0017579937120899558, shape=(768,)
Inference result contains ['embeddings', 'hiddens', 'sequence', 'id']
1325: range -1.302734375-1.5400390625, mean=-0.003067302517592907, shape=(768,)
Inference result contains ['embeddings', 'hiddens', 'sequence', 'id']
4345: range -0.79443359375-1.3369140625, mean=-0.0012353993952274323, shape=(768,)
Inference result contains ['embeddings', 'hiddens', 'sequence', 'id']
8030: range -0.84375-1.158203125, mean=0.0007279149140231311, shape=(768,)


Let's enumerate the labels corresponding to potential subcellular locations.

In [11]:
# Let's also print all the labels

labels = set([entry['TARGET'] for entry in data])

for i, label in enumerate(labels):
    print(f"{i+1}. {label.replace('_', ' ')}")

1. Cytoplasm
2. Plastid
3. Cell membrane
4. Endoplasmic reticulum
5. Extracellular
6. Nucleus
7. Golgi apparatus
8. Lysosome
9. Peroxisome
10. Mitochondrion


### Training an MLP to predict subcellular location
To classify proteins by their subcellular location, we train a simple scikit-learn Multi-layer Perceptron (MLP) classifier using Federated Averaging ([FedAvg](https://arxiv.org/abs/1602.05629)). The MLP model uses a network of hidden layers to fit the input embedding vectors to the model classes (the cellular locations above). In the simulation below, we define the MLP to use the Adam optimizer with a network of (512, 256, 128) hidden layers, set a random state (or seed) for reproducibility, and train for 30 rounds of FedAvg (see [config_fed_server.json](./jobs/fedavg/app/config/config_fed_server.json)).
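The aggregation step of FedAvg can be sketched as a weighted average of the clients' parameters, layer by layer. The code below is a minimal illustration under stated assumptions: the function name `fedavg` and the client sample counts are hypothetical, weighting by local sample count follows the FedAvg paper, and the aggregator configured for this job may weight contributions differently. The layer shapes match the MLP coefficient shapes reported in the simulator logs.

```python
import numpy as np


def fedavg(client_weights, client_sizes):
    """Average per-layer arrays across clients, weighted by local sample count."""
    total = float(sum(client_sizes))
    n_layers = len(client_weights[0])
    return [
        sum(weights[layer] * (size / total)
            for weights, size in zip(client_weights, client_sizes))
        for layer in range(n_layers)
    ]


# Coefficient shapes of an MLP with (512, 256, 128) hidden layers and 10 classes.
layer_shapes = [(768, 512), (512, 256), (256, 128), (128, 10)]

# Three hypothetical clients whose weights are constant 1.0, 2.0, and 3.0.
client_weights = [
    [np.full(shape, value) for shape in layer_shapes]
    for value in (1.0, 2.0, 3.0)
]
client_sizes = [100, 200, 700]  # hypothetical local training set sizes

global_weights = fedavg(client_weights, client_sizes)
print([w.shape for w in global_weights])
print(float(global_weights[0][0, 0]))  # 0.1*1 + 0.2*2 + 0.7*3, approximately 2.6
```

In each round, the server would send the averaged weights back to the clients, which continue training from them on their local data.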

We can use the same configuration also to simulate local training where each client is only training with their own data by setting `os.environ["SIM_LOCAL"] = "True"`. Our [BioNeMoMLPLearner](./jobs/fedavg/app/custom/bionemo_mlp_learner.py) will then ignore the global weights coming from the server.
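The switch between local and federated training could look roughly like the sketch below. This is not the code of `BioNeMoMLPLearner`; the helper names are hypothetical and only illustrate how a learner might read `SIM_LOCAL` and decide which weights to start a round from.

```python
import os


def is_local_simulation(env=None):
    """Interpret the SIM_LOCAL setting; environment values are always strings."""
    env = os.environ if env is None else env
    return env.get("SIM_LOCAL", "False").strip().lower() in ("1", "true", "yes")


def choose_start_weights(local_weights, global_weights, env=None):
    """Start from the client's own weights in local mode, else from the server's."""
    if is_local_simulation(env) and local_weights is not None:
        return local_weights
    return global_weights


# In local mode the global weights from the server are ignored.
print(choose_start_weights("local", "global", env={"SIM_LOCAL": "True"}))   # local
print(choose_start_weights("local", "global", env={"SIM_LOCAL": "False"}))  # global
```

Because environment variables are strings, `"False"` must be parsed explicitly; a plain truthiness check would treat it as enabled.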

### Local training

In [12]:
os.environ["SIM_LOCAL"] = "True"

simulator = SimulatorRunner(
    job_folder="jobs/fedavg",
    workspace=f"/tmp/nvflare/bionemo/local_alpha{split_alpha}",
    n_clients=n_clients,
    threads=n_clients
)
run_status = simulator.run()
print("Simulator finished with run_status", run_status)

2024-10-04 20:52:15,207 - SimulatorRunner - INFO - Create the Simulator Server.
2024-10-04 20:52:15,211 - CoreCell - INFO - server: creating listener on tcp://0:37559
2024-10-04 20:52:15,239 - CoreCell - INFO - server: created backbone external listener for tcp://0:37559
2024-10-04 20:52:15,240 - ConnectorManager - INFO - 2801: Try start_listener Listener resources: {'secure': False, 'host': 'localhost'}
2024-10-04 20:52:15,242 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector [CH00002 PASSIVE tcp://0:43680] is starting
2024-10-04 20:52:15,745 - CoreCell - INFO - server: created backbone internal listener for tcp://localhost:43680
2024-10-04 20:52:15,748 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector [CH00001 PASSIVE tcp://0:37559] is starting
2024-10-04 20:52:15,825 - nvflare.fuel.hci.server.hci - INFO - Starting Admin Server localhost on Port 59779
2024-10-04 20:52:15,826 - SimulatorRunner - INFO - Deploy the Apps.
2024-10-04 20:52:15,867 - SimulatorRunner - INFO - Create 



2024-10-04 20:52:17,123 - BioNeMoMLPModelPersistor - INFO - [identity=simulator_server, run=simulate_job]: MLPClassifier coefficients [(768, 512), (512, 256), (256, 128), (128, 10)], intercepts [(512,), (256,), (128,), (10,)]
2024-10-04 20:52:17,125 - AuxRunner - INFO - registered aux handler for topic fed.event
2024-10-04 20:52:17,128 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job]: starting workflow scatter_gather_ctl (<class 'nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather'>) ...
2024-10-04 20:52:17,132 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_gather_ctl]: Initializing ScatterAndGather workflow.
2024-10-04 20:52:17,133 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_gather_ctl]: Workflow scatter_gather_ctl (<class 'nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather'>) started
2024-10-04 20:52:17,134 - ScatterAndGather - INFO - [identity=simulator_server, 

### Federated learning

In [13]:
os.environ["SIM_LOCAL"] = "False"

simulator = SimulatorRunner(
    job_folder="jobs/fedavg",
    workspace=f"/tmp/nvflare/bionemo/fedavg_alpha{split_alpha}",
    n_clients=n_clients,
    threads=n_clients
)
run_status = simulator.run()
print("Simulator finished with run_status", run_status)

2024-10-04 20:58:43,915 - SimulatorRunner - INFO - Create the Simulator Server.
2024-10-04 20:58:43,919 - CoreCell - INFO - server: creating listener on tcp://0:41263
2024-10-04 20:58:43,942 - CoreCell - INFO - server: created backbone external listener for tcp://0:41263
2024-10-04 20:58:43,943 - ConnectorManager - INFO - 6087: Try start_listener Listener resources: {'secure': False, 'host': 'localhost'}
2024-10-04 20:58:43,945 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector [CH00002 PASSIVE tcp://0:42708] is starting
2024-10-04 20:58:44,447 - CoreCell - INFO - server: created backbone internal listener for tcp://localhost:42708
2024-10-04 20:58:44,450 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector [CH00001 PASSIVE tcp://0:41263] is starting
2024-10-04 20:58:44,529 - nvflare.fuel.hci.server.hci - INFO - Starting Admin Server localhost on Port 52439
2024-10-04 20:58:44,530 - SimulatorRunner - INFO - Deploy the Apps.
2024-10-04 20:58:44,536 - SimulatorRunner - INFO - Create 



2024-10-04 20:58:45,637 - BioNeMoMLPModelPersistor - INFO - [identity=simulator_server, run=simulate_job]: MLPClassifier coefficients [(768, 512), (512, 256), (256, 128), (128, 10)], intercepts [(512,), (256,), (128,), (10,)]
2024-10-04 20:58:45,639 - AuxRunner - INFO - registered aux handler for topic fed.event
2024-10-04 20:58:45,640 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job]: starting workflow scatter_gather_ctl (<class 'nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather'>) ...
2024-10-04 20:58:45,642 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_gather_ctl]: Initializing ScatterAndGather workflow.
2024-10-04 20:58:45,643 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_gather_ctl]: Workflow scatter_gather_ctl (<class 'nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather'>) started
2024-10-04 20:58:45,645 - ScatterAndGather - INFO - [identity=simulator_server, 

### TensorBoard Visualization
You can visualize the training progress using TensorBoard:
```
tensorboard --logdir /tmp/nvflare/bionemo
```

An example of local (red) vs federated (blue) training is shown below.

![TensorBoard training curves](tb_curve.png)