# Federated NLP with BERT Model

## Introduction 
In this example, we show how to use [NVIDIA FLARE](https://nvidia.github.io/NVFlare) for a Natural Language Processing (NLP) task using the [BERT](https://github.com/google-research/bert) model from [Hugging Face](https://huggingface.co/). We select [BERT-base-uncased](https://huggingface.co/bert-base-uncased) as our base model.

## Setup
Install the required packages for training:

In [None]:
%pip install -r code/requirements.txt

## Download Data 
The raw data can be accessed from the [official page](https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/). 
In this example, we use the preprocessed CSV files from the reference repo, which can be downloaded [here](https://drive.google.com/drive/folders/13wROtEAnMgWpLMIGHB5CY1BQ1Xe2XqhG). 

In the following, we download the three files `train.csv`, `dev.csv`, and `test.csv`, and save them to `/tmp/nvflare/dataset/nlp_ner`.

In [None]:
%%sh
mkdir -p /tmp/nvflare/dataset/nlp_ner
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1YWGBElsqj5ENsW0PtYwMlk_ShBt8MsLD' -O /tmp/nvflare/dataset/nlp_ner/dev.csv
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=12kXGQPW-do-F7T-TLGycl0DCw6eQIaZc' -O /tmp/nvflare/dataset/nlp_ner/test.csv
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1fjsf0jFKWu_-bbx236oB6e7DqOqGmw3y' -O /tmp/nvflare/dataset/nlp_ner/train.csv

## Data Preprocessing 
We then use the preprocessed data to generate random splits for both 4-client and 2-client experiments. 
If you saved the data elsewhere, modify the dataset root path below (the argument passed to `code/prepare_data.sh`) so that it points to the folder containing the three downloaded CSV files.

In [None]:
! code/prepare_data.sh /tmp/nvflare/dataset/nlp_ner

The expected output is:
```
4-client
(7594, 5) (2531, 5)
(5063, 5) (2531, 5)
(2532, 5) (2531, 5)
(2532, 5) (2532, 5)
(950, 5) (316, 5)
(634, 5) (316, 5)
(318, 5) (316, 5)
(318, 5) (318, 5)
```
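
Each printed line appears to show a pair of shapes: the rows still unassigned and the rows handed to one client. Under that reading, the training set has 10125 rows and the validation set has 1266 rows, and the first three clients each receive `total // 4` rows while the last client takes the remainder. The sketch below reproduces those sizes with a random split over row indices. It is an illustration only: the function name `split_indices`, the seed, and the splitting rule are assumptions inferred from the output, not the actual logic of `prepare_data.sh`.

```python
import random


def split_indices(num_rows, num_clients, seed=0):
    """Shuffle row indices and cut them into one chunk per client.

    The first num_clients - 1 clients each get num_rows // num_clients rows;
    the last client takes whatever remains (assumed rule, see lead-in).
    """
    indices = list(range(num_rows))
    random.Random(seed).shuffle(indices)
    chunk = num_rows // num_clients
    splits = [indices[i * chunk:(i + 1) * chunk] for i in range(num_clients - 1)]
    splits.append(indices[(num_clients - 1) * chunk:])
    return splits


# Totals inferred from the shapes printed above (7594 + 2531 and 950 + 316).
train_sizes = [len(s) for s in split_indices(10125, 4)]
dev_sizes = [len(s) for s in split_indices(1266, 4)]
print(train_sizes)  # [2531, 2531, 2531, 2532]
print(dev_sizes)    # [316, 316, 316, 318]
```
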
The task here is to categorize each word in the text into one of three classes specified by the label. For example, the sentence 
`Recent progress has resulted in part of the gene mutated in Duchenne and the milder Becker muscular dystrophies being cloned and has suggested that the gene itself extends over 1 , 000 to 2 , 000 kilobases ( kb ) .` is mapped to the label vector `O O O O O O O O O O O B I I I I I I O O O O O O O O O O O O O O O O O O O O O O O`. `B` marks the beginning of an entity, `I` marks each subsequent word inside the entity, and `O` marks all other words.
Let's take a closer look at the word-label correspondence:
![data sample](./figs/sample.png)
As shown above, the task is to capture the keywords related to medical findings.
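
To make the labeling scheme concrete, the following minimal sketch recovers entity phrases from a token list and its `B`/`I`/`O` labels, using the example sentence from the text. The helper name `extract_entities` is hypothetical and is not part of the example code.

```python
def extract_entities(tokens, labels):
    """Return the entity phrases marked by B/I labels."""
    entities, current = [], []
    for token, label in zip(tokens, labels):
        if label == "B":
            # A new entity starts; close any entity still open.
            if current:
                entities.append(" ".join(current))
            current = [token]
        elif label == "I" and current:
            # Continue the entity that is currently open.
            current.append(token)
        else:
            # "O" (or a stray "I") ends any open entity.
            if current:
                entities.append(" ".join(current))
            current = []
    if current:
        entities.append(" ".join(current))
    return entities


sentence = (
    "Recent progress has resulted in part of the gene mutated in Duchenne "
    "and the milder Becker muscular dystrophies being cloned and has "
    "suggested that the gene itself extends over 1 , 000 to 2 , 000 "
    "kilobases ( kb ) ."
)
tokens = sentence.split()
# 11 leading "O" labels, one "B", six "I", and "O" for the remaining tokens.
labels = ["O"] * 11 + ["B"] + ["I"] * 6 + ["O"] * (len(tokens) - 18)

print(extract_entities(tokens, labels))
# ['Duchenne and the milder Becker muscular dystrophies']
```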

## Run automated experiments
We run federated training of the BERT model on 4 clients using the NVFlare Simulator via the [Job API](https://nvflare.readthedocs.io/en/main/programming_guide/fed_job_api.html). To save time, we run only 5 rounds of federated training.

In [None]:
%cd code
! python nlp_fl_job.py --model_name Bert
%cd ..
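
The text does not say which aggregation rule `nlp_fl_job.py` uses. Assuming a FedAvg-style weighted average, which is a common choice, each round can be pictured as the server combining the client updates in proportion to their local sample counts. The sketch below illustrates that idea on plain Python lists; the function name `weighted_average`, the parameter name `"w"`, and the toy values are all hypothetical, and this is not NVFlare's implementation.

```python
def weighted_average(client_weights, client_sizes):
    """Average per-parameter values, weighting each client by its sample count.

    client_weights: list of dicts mapping a parameter name to a list of floats.
    client_sizes:   list of local sample counts, one per client.
    """
    total = sum(client_sizes)
    averaged = {}
    for name in client_weights[0]:
        length = len(client_weights[0][name])
        averaged[name] = [
            sum(w[name][i] * n for w, n in zip(client_weights, client_sizes)) / total
            for i in range(length)
        ]
    return averaged


# Toy example: two clients holding 1 and 3 samples respectively.
updates = [{"w": [0.0, 4.0]}, {"w": [4.0, 0.0]}]
global_weights = weighted_average(updates, [1, 3])
print(global_weights)  # {'w': [3.0, 1.0]}
```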

## Results
### Validation curve on each site
In this example, each client computes its validation scores using its own
validation set. We record the loss, F1 score, precision, and recall. 
The curves can be viewed with TensorBoard, with each client training for 50 epochs (50 FL rounds, 1 local epoch per round).

For the BERT model, the TensorBoard curves can be visualized with:

In [None]:
%load_ext tensorboard
%tensorboard --logdir /tmp/nvflare/workspace/works/Bert/

### Testing score
The testing score is computed for the global model on the test set.
We provide a script for evaluating the global model on the test data. 

In [None]:
%cd code
! sh test_global_model.sh /tmp/nvflare/dataset/nlp_ner
%cd ..

The test results are:
```
BERT
              precision    recall  f1-score   support

           _       0.83      0.92      0.87      1255

   micro avg       0.83      0.92      0.87      1255
   macro avg       0.83      0.92      0.87      1255
weighted avg       0.83      0.92      0.87      1255
```
Note that training is not deterministic, so the numbers may vary slightly between runs.
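
As a quick sanity check on the report above, the F1 score is the harmonic mean of precision and recall. The short sketch below (the helper name `f1_score` is ours, not from the example code) confirms that a precision of 0.83 and a recall of 0.92 give the reported F1 of 0.87.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


print(round(f1_score(0.83, 0.92), 2))  # 0.87
```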

In this section, we showed how to train a BERT model with a standard PyTorch training loop. Now let's move on to the next section, [LLM Supervised Fine-Tuning (SFT)](../08.2_llm_sft/LLM_SFT.ipynb), where we will see how to use existing Trainer scripts via the Hugging Face APIs.