# Parameter-Efficient Fine-Tuning (PEFT) with NeMo

In this example, we utilize NeMo's [PEFT](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/nlp/nemo_megatron/prompt_learning.html)
methods to showcase how to adapt a large language model (LLM) to 
a downstream task, such as financial sentiment predictions. 

With a one-line configuration change, you can try different PEFT techniques such as [p-tuning](https://arxiv.org/abs/2103.10385), [adapters](https://proceedings.mlr.press/v97/houlsby19a.html), or [LoRA](https://arxiv.org/abs/2106.09685). Each of them adds a small number of trainable parameters to the LLM
that condition the model to produce the desired output for the downstream task.
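
To make "a small number of trainable parameters" concrete, the sketch below compares the trainable parameter count of a LoRA update with full fine-tuning of a single weight matrix. The layer size and rank are illustrative assumptions and are not taken from the NeMo configuration.

```python
def full_finetune_params(d_in: int, d_out: int) -> int:
    """Trainable parameters when the whole d_out x d_in weight matrix is updated."""
    return d_in * d_out


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters for a low-rank update B @ A,
    with A of shape (rank, d_in) and B of shape (d_out, rank)."""
    return rank * d_in + d_out * rank


# Illustrative values only (assumed, not read from the NeMo config).
d_in, d_out, rank = 1024, 1024, 8
full = full_finetune_params(d_in, d_out)
lora = lora_params(d_in, d_out, rank)
print(f"full fine-tuning: {full} params, LoRA: {lora} params, ratio: {lora / full:.4f}")
```

The frozen base weights stay untouched in all three PEFT schemes; only the added parameters are trained.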

For more details, see the [PEFT script](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py) in NeMo, which we adapt using NVFlare's Lightning client API to run in a federated scenario.

## Dependencies
We assume you followed the instructions [here](../../README.md#requirements) 
to install the NeMo framework and the NeMo-NVFlare package. 

## Download the pre-trained LLM
In this example, we use a `MegatronGPTModel`, a transformer-based language model based on the GPT architecture.

In [None]:
# Check what GPT .nemo models we have available on NGC
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
MegatronGPTModel.list_available_models()

In [None]:
# Download the model from NGC
import os
model_file = "megatron_gpt_345m.nemo"
if not os.path.isfile(model_file):
    !wget "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/megatron_gpt_345m/versions/1/files/$model_file"
else:
    print(f"{model_file} already downloaded.")

## Data preprocessing
As our downstream task, we will use the [Financial PhraseBank dataset](https://huggingface.co/datasets/financial_phrasebank) for sentiment analysis.

The Financial PhraseBank dataset contains the sentiments for financial news headlines from a retail investor's perspective. Further details about the dataset can be found in Malo et al.'s ["Good Debt or Bad Debt: Detecting Semantic Orientations in Economic Texts"](https://arxiv.org/abs/1307.5336).


#### 1. Download the preprocessing scripts
We use the preprocessing script provided by NeMo, which can be downloaded from GitHub.

In [None]:
script_name = "prompt_learning_financial_phrase_bank_preprocessing.py"
if not os.path.isfile(script_name):
    !wget -N "https://raw.githubusercontent.com/NVIDIA/NeMo/main/scripts/dataset_processing/nlp/financial_phrase_bank/$script_name"
else:
    print(f"{script_name} already downloaded.")

#### 2. Download the Financial PhraseBank Dataset

Download the `FinancialPhraseBank-v1.0.zip` dataset from [here](https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v1.0/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v1.0.zip).

Then extract it under `./data`.

#### 3. Preprocess the dataset

In [None]:
!python3 prompt_learning_financial_phrase_bank_preprocessing.py
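
As a rough illustration of what this preprocessing step does, the sketch below converts one raw line into a JSON record. It assumes the raw files store each example as `sentence@label`, and it reuses the `taskname` and `sentence` keys that appear in the inference examples later in this notebook. The `label` key and the `to_record` helper are assumptions for illustration, not the NeMo script itself.

```python
import json


def to_record(raw_line: str, taskname: str = "sentiment") -> dict:
    """Convert one raw 'sentence@label' line into a JSON-serializable record."""
    # Split on the last '@' so that an '@' inside the sentence is preserved.
    sentence, label = raw_line.rstrip("\n").rsplit("@", 1)
    return {"taskname": taskname, "sentence": sentence.strip(), "label": label.strip()}


raw = "The agreement is valid for four years .@neutral"
record = to_record(raw)
print(json.dumps(record))  # one line of a .jsonl file
```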

#### 4. Split the dataset to simulate clients
Next, we use three clients to simulate federated learning for p-tuning with NeMo.

In [None]:
!python3 data/split_financial_phrase_data.py --data_path data/FinancialPhraseBank-v1.0/financial_phrase_bank_train.jsonl --num_clients 3 --out_dir data/FinancialPhraseBank-v1.0_split
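
The idea behind the split can be sketched as follows. This is a minimal round-robin partition under the assumption that every client should receive roughly the same number of examples; the actual `split_financial_phrase_data.py` script may shuffle or weight the split differently. The shard names follow the `site-<i>.jsonl` pattern used by the job configuration later in this notebook.

```python
def split_records(records: list, num_clients: int) -> dict:
    """Assign records round-robin to client shards named site-1 ... site-N."""
    if num_clients < 1:
        raise ValueError("num_clients must be at least 1")
    shards = {f"site-{i + 1}": [] for i in range(num_clients)}
    for idx, rec in enumerate(records):
        shards[f"site-{idx % num_clients + 1}"].append(rec)
    return shards


# Toy records standing in for lines of financial_phrase_bank_train.jsonl.
records = [{"sentence": f"example {i}"} for i in range(7)]
shards = split_records(records, num_clients=3)
for name, shard in shards.items():
    print(f"{name}.jsonl would contain {len(shard)} records")
```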

## Federated learning simulations
Next, we use NVFlare's [simulator](https://nvflare.readthedocs.io/en/latest/user_guide/fl_simulator.html) to simulate two settings: each client training locally on its own dataset, and all three clients training together using the [FedAvg](https://arxiv.org/abs/1602.05629) algorithm implemented in NVFlare.
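
For readers new to FedAvg, the core aggregation step is a sample-weighted average of the client parameters. The sketch below uses plain Python lists and a hypothetical `fedavg` helper; NVFlare's own implementation differs in detail (for example in how updates are represented and weighted), so treat this as an illustration of the idea only.

```python
def fedavg(client_updates: list) -> dict:
    """Weighted average of client parameters.

    client_updates: list of (params, num_samples) pairs, where params maps
    a parameter name to a list of floats.
    """
    total = sum(n for _, n in client_updates)
    if total <= 0:
        raise ValueError("total number of samples must be positive")
    first_params = client_updates[0][0]
    averaged = {}
    for name, values in first_params.items():
        averaged[name] = [
            sum(params[name][i] * n for params, n in client_updates) / total
            for i in range(len(values))
        ]
    return averaged


# Three toy clients; the parameter name and sample counts are made up.
updates = [
    ({"prompt": [1.0, 2.0]}, 100),
    ({"prompt": [3.0, 4.0]}, 100),
    ({"prompt": [5.0, 6.0]}, 200),
]
global_params = fedavg(updates)
print(global_params)
```

Because only the PEFT parameters are trainable, only those small tensors need to be exchanged and averaged in each round.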

With this setting, we require a GPU with at least 16GB of memory to run all clients in parallel on the same GPU. 
If you have multiple GPUs in your system, you can use the `gpu` argument to assign one GPU for each client, e.g., `gpu="0,1"`.

We will use NVFlare's job command for each setting to create the configurations needed to train the models based on the [sag_nemo](https://github.com/NVIDIA/NVFlare/blob/main/job_templates/sag_pt_deploy_map/info.md) job template. This template allows the definition of different configurations for each client, which we will use to assign their local training data file to each of them.

#### 1. Local P-Tuning
First, we create the job files and modify them to include the data paths for each client and the pre-trained LLM using the `-f` option.
Note, the `app_config` options are specific to the app script (`megatron_gpt_peft_tuning.py`) and modify variables in the NeMo config file (`megatron_gpt_peft_tuning_config.yaml`) directly on execution.

At this point, we also set the number of local training steps and FL rounds to simulate local training.

The PEFT method is "ptuning".

In [None]:
# NVFLARE_HOME points to the author's local NVFlare checkout; change it to match your setup.
%env NVFLARE_HOME=/home/hroth/Code2/nvflare/nemo_peft_example
#!python3 -m pip install -e /home/hroth/Code2/nvflare/nemo_peft_example

import os

# Raw strings keep the literal backslashes that escape `=`, `[`, and `]` in the overrides.
peft_scheme = r"model.peft.peft_scheme\=ptuning"  # can be either ptuning, adapter, lora, or ia3
app_script = "megatron_gpt_peft_tuning.py"
restore_from_path = rf"model.restore_from_path\={os.getcwd()}/megatron_gpt_345m.nemo"
trainer_config = r"trainer.max_steps\=2000 trainer.val_check_interval\=100"
val_files = rf"model.data.validation_ds.file_names\=\[{os.getcwd()}/data/FinancialPhraseBank-v1.0/financial_phrase_bank_val.jsonl\]"
train_files_prefix = rf"model.data.train_ds.file_names\=\[{os.getcwd()}/data/FinancialPhraseBank-v1.0_split/site"

!nvflare job create -force -j "./jobs/peft_p-tuning_local_345M" -w "sag_nemo" -sd "code" \
   -f app_1/config_fed_client.conf app_script={app_script} app_config="{peft_scheme} {restore_from_path} {trainer_config} {val_files} {train_files_prefix}-1.jsonl\]" \
   -f app_2/config_fed_client.conf app_script={app_script} app_config="{peft_scheme} {restore_from_path} {trainer_config} {val_files} {train_files_prefix}-2.jsonl\]" \
   -f app_3/config_fed_client.conf app_script={app_script} app_config="{peft_scheme} {restore_from_path} {trainer_config} {val_files} {train_files_prefix}-3.jsonl\]" \
   -f app_server/config_fed_server.conf num_rounds=1
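
To see what each client receives, the sketch below assembles the same `app_config` override string in plain Python. The `client_app_config` helper and the `/workspace` base directory are hypothetical; the override keys and file names are copied from the cell above.

```python
def client_app_config(site: int, base_dir: str, max_steps: int = 2000) -> str:
    """Assemble the app_config override string for one client."""
    data_dir = f"{base_dir}/data"
    overrides = [
        r"model.peft.peft_scheme\=ptuning",
        rf"model.restore_from_path\={base_dir}/megatron_gpt_345m.nemo",
        rf"trainer.max_steps\={max_steps}",
        r"trainer.val_check_interval\=100",
        rf"model.data.validation_ds.file_names\=\[{data_dir}/FinancialPhraseBank-v1.0/financial_phrase_bank_val.jsonl\]",
        # Only this entry differs between clients.
        rf"model.data.train_ds.file_names\=\[{data_dir}/FinancialPhraseBank-v1.0_split/site-{site}.jsonl\]",
    ]
    return " ".join(overrides)


# Print the overrides for client 1, one per line, for readability.
for override in client_app_config(site=1, base_dir="/workspace").split(" "):
    print(override)
```

All clients share the validation file and the pre-trained model; only the training file changes per site.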

Next, simulate each client p-tuning on its local dataset using the FL simulator. To do this, we run only one round of FL, with each client running 2000 p-tuning steps (`trainer.max_steps=2000`) on its local dataset.

In [None]:
from nvflare import SimulatorRunner    

simulator = SimulatorRunner(
    job_folder="jobs/peft_p-tuning_local_345M",
    workspace="/tmp/nvflare/nemo/peft_p-tuning_local_345M",
    n_clients=3,
    threads=3
)
run_status = simulator.run()
print("Simulator finished with run_status", run_status)

#### 2. Federated P-Tuning
Next, we use the [FedAvg](https://arxiv.org/abs/1602.05629) algorithm to p-tune the model in a federated scenario. First, create and modify the configuration files again.
This time, we increase the number of FL rounds and decrease the number of local training steps per round to match the federated scenario.

In [None]:
import os

# Raw strings keep the literal backslashes that escape `=`, `[`, and `]` in the overrides.
peft_scheme = r"model.peft.peft_scheme\=ptuning"  # can be either ptuning, adapter, lora, or ia3
app_script = "megatron_gpt_peft_tuning.py"
restore_from_path = rf"model.restore_from_path\={os.getcwd()}/megatron_gpt_345m.nemo"
trainer_config = r"trainer.max_steps\=200 trainer.val_check_interval\=100"
val_files = rf"model.data.validation_ds.file_names\=\[{os.getcwd()}/data/FinancialPhraseBank-v1.0/financial_phrase_bank_val.jsonl\]"
train_files_prefix = rf"model.data.train_ds.file_names\=\[{os.getcwd()}/data/FinancialPhraseBank-v1.0_split/site"

!nvflare job create -force -j "./jobs/peft_p-tuning_fedavg_345M" -w "sag_nemo" -sd "code" \
   -f app_1/config_fed_client.conf app_script={app_script} app_config="{peft_scheme} {restore_from_path} {trainer_config} {val_files} {train_files_prefix}-1.jsonl\]" \
   -f app_2/config_fed_client.conf app_script={app_script} app_config="{peft_scheme} {restore_from_path} {trainer_config} {val_files} {train_files_prefix}-2.jsonl\]" \
   -f app_3/config_fed_client.conf app_script={app_script} app_config="{peft_scheme} {restore_from_path} {trainer_config} {val_files} {train_files_prefix}-3.jsonl\]" \
   -f app_server/config_fed_server.conf num_rounds=10
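
The values passed to `nvflare job create` in the two settings can be compared with a short calculation. The sketch assumes that `trainer.max_steps` applies to each FL round separately, so the per-client training budget is the product of rounds and steps; `total_local_steps` is a hypothetical helper.

```python
def total_local_steps(num_rounds: int, max_steps: int) -> int:
    """Training steps per client, assuming max_steps is applied in every round."""
    return num_rounds * max_steps


# Values taken from the two `nvflare job create` commands in this notebook.
local_only = total_local_steps(num_rounds=1, max_steps=2000)
federated = total_local_steps(num_rounds=10, max_steps=200)
print(f"local: {local_only} steps, federated: {federated} steps")
```

Under that assumption both settings give every client the same number of optimizer steps, which keeps the comparison of validation losses fair.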

Next, simulate the federated p-tuning using FedAvg. Here, each client p-tunes for 200 local steps (`trainer.max_steps=200`) before sending its local model updates to the server for aggregation. This is repeated for 10 FL rounds (`num_rounds=10`).

In [28]:
from nvflare import SimulatorRunner    

simulator = SimulatorRunner(
    job_folder="jobs/peft_p-tuning_fedavg_345M",
    workspace="/tmp/nvflare/nemo/peft_p-tuning_fedavg_345M",
    n_clients=3,
    threads=3
)
run_status = simulator.run()
print("Simulator finished with run_status", run_status)

2023-10-25 18:59:15,954 - SimulatorRunner - INFO - Create the Simulator Server.
2023-10-25 18:59:15,959 - CoreCell - INFO - server: creating listener on tcp://0:50463
2023-10-25 18:59:15,990 - CoreCell - INFO - server: created backbone external listener for tcp://0:50463
2023-10-25 18:59:15,991 - ConnectorManager - INFO - 27322: Try start_listener Listener resources: {'secure': False, 'host': 'localhost'}
2023-10-25 18:59:15,993 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector [CH00002 PASSIVE tcp://0:56096] is starting
2023-10-25 18:59:16,495 - CoreCell - INFO - server: created backbone internal listener for tcp://localhost:56096
2023-10-25 18:59:16,499 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector [CH00001 PASSIVE tcp://0:50463] is starting
2023-10-25 18:59:16,581 - nvflare.fuel.hci.server.hci - INFO - Starting Admin Server localhost on Port 57567
2023-10-25 18:59:16,582 - SimulatorRunner - INFO - Deploy the Apps.
2023-10-25 18:59:16,591 - SimulatorRunner - INFO - Create

NEMO version 1.21.0rc0
NEMO version 1.21.0rc0
2023-10-25 18:59:27,347 - IntimeModelSelector - INFO - model selection weights control: None
2023-10-25 18:59:27,349 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job]: Server runner starting ...
[NeMo I 2023-10-25 18:59:27 megatron_trainer_builder:49] Detected interactive environment, using NLPDDPStrategyNotebook


2023-10-25 18:59:27,630 - SimulatorClientRunner - INFO - Start the clients run simulation.
2023-10-25 18:59:27,657 - pytorch_lightning.utilities.rank_zero - INFO - GPU available: True (cuda), used: True
2023-10-25 18:59:27,659 - pytorch_lightning.utilities.rank_zero - INFO - TPU available: False, using: 0 TPU cores
2023-10-25 18:59:27,660 - pytorch_lightning.utilities.rank_zero - INFO - IPU available: False, using: 0 IPUs
2023-10-25 18:59:27,661 - pytorch_lightning.utilities.rank_zero - INFO - HPU available: False, using: 0 HPUs
2023-10-25 18:59:28,652 - SimulatorClientRunner - INFO - Simulate Run client: site-1 on GPU group: None
2023-10-25 18:59:28,655 - SimulatorClientRunner - INFO - Simulate Run client: site-2 on GPU group: None
2023-10-25 18:59:28,675 - SimulatorClientRunner - INFO - Simulate Run client: site-3 on GPU group: None
2023-10-25 18:59:29,789 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connection [CN00007 127.0.0.1:50463 <= 127.0.0.1:34430] is created: PID: 27322
2023-1

[NeMo W 2023-10-25 18:59:30 megatron_base_model:810] The model: MegatronGPTSFTModel() does not have field.name: virtual_pipeline_model_parallel_size in its cfg. Add this key to cfg or config_mapping to make to make it configurable.
[NeMo W 2023-10-25 18:59:30 megatron_base_model:810] The model: MegatronGPTSFTModel() does not have field.name: gradient_accumulation_fusion in its cfg. Add this key to cfg or config_mapping to make to make it configurable.
[NeMo W 2023-10-25 18:59:30 megatron_base_model:810] The model: MegatronGPTSFTModel() does not have field.name: overlap_p2p_comm in its cfg. Add this key to cfg or config_mapping to make to make it configurable.
[NeMo W 2023-10-25 18:59:30 megatron_base_model:810] The model: MegatronGPTSFTModel() does not have field.name: batch_p2p_comm in its cfg. Add this key to cfg or config_mapping to make to make it configurable.


[NeMo I 2023-10-25 18:59:30 megatron_init:234] Rank 0 has data parallel group: [0]
[NeMo I 2023-10-25 18:59:30 megatron_init:237] All data parallel group ranks: [[0]]
[NeMo I 2023-10-25 18:59:30 megatron_init:238] Ranks 0 has data parallel rank: 0
[NeMo I 2023-10-25 18:59:30 megatron_init:246] Rank 0 has model parallel group: [0]
[NeMo I 2023-10-25 18:59:30 megatron_init:247] All model parallel group ranks: [[0]]
[NeMo I 2023-10-25 18:59:30 megatron_init:257] Rank 0 has tensor model parallel group: [0]
[NeMo I 2023-10-25 18:59:30 megatron_init:261] All tensor model parallel group ranks: [[0]]
[NeMo I 2023-10-25 18:59:30 megatron_init:262] Rank 0 has tensor model parallel rank: 0
[NeMo I 2023-10-25 18:59:30 megatron_init:276] Rank 0 has pipeline model parallel group: [0]
[NeMo I 2023-10-25 18:59:30 megatron_init:288] Rank 0 has embedding group: [0]
[NeMo I 2023-10-25 18:59:30 megatron_init:294] All pipeline model parallel group ranks: [[0]]
[NeMo I 2023-10-25 18:59:30 megatron_init:295]

23-10-25 18:59:30 - PID:27322 - rank:(0, 0, 0, 0) - microbatches.py:39 - INFO - setting number of micro-batches to constant 32

[NeMo I 2023-10-25 18:59:30 tokenizer_utils:204] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m, custom vocab file: /tmp/tmpqnulxgln/bfcdca5e44814366bdb5dcd651325152_gpt2-vocab.json, and merges file: /tmp/tmpqnulxgln/315a11fd68be49d6abdb34363e8c4997_gpt2-merge.txt
[NeMo I 2023-10-25 18:59:30 tokenizer_utils:130] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /tmp/tmpqnulxgln/bfcdca5e44814366bdb5dcd651325152_gpt2-vocab.json, merges_files: /tmp/tmpqnulxgln/315a11fd68be49d6abdb34363e8c4997_gpt2-merge.txt, special_tokens_dict: {}, and use_fast: False


Using sep_token, but it is not set yet.
Using cls_token, but it is not set yet.
Using pad_token, but it is not set yet.
Using mask_token, but it is not set yet.


[NeMo I 2023-10-25 18:59:31 megatron_base_model:312] Padded vocab_size: 50304, original vocab_size: 50257, dummy tokens: 47.


[NeMo W 2023-10-25 18:59:31 megatron_gpt_model:1598] The model: MegatronGPTSFTModel() does not have field.name: num_query_

[NeMo I 2023-10-25 18:59:32 nlp_overrides:686] Model MegatronGPTSFTModel was successfully restored from /home/hroth/Code2/nvflare/nemo_peft_example/integration/nemo/examples/peft/megatron_gpt_345m.nemo.
2023-10-25 18:59:32,510 - root - INFO - Adding adapter weights to the model for PEFT
[NeMo I 2023-10-25 18:59:32 nlp_adapter_mixins:182] Before adding PEFT params:
      | Name        | Type       | Params
    -------------------------------------------
    0 | model       | GPTModel   | 354 M 
    1 | val_metric  | ModuleList | 0     
    2 | test_metric | ModuleList | 0     
    -------------------------------------------
    0         Trainable params
    354 M     Non-trainable params
    354 M     Total params
    1,419.485 Total estimated model params size (MB)
[NeMo I 2023-10-25 18:59:32 nlp_adapter_mixins:195] After adding PEFT params:
      | Name        | Type       | Params
    -------------------------------------------
    0 | model       | GPTModel   | 356 M 
    1 | val_m

2023-10-25 18:59:33,015 - numexpr.utils - INFO - Note: NumExpr detected 36 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2023-10-25 18:59:33,015 - numexpr.utils - INFO - NumExpr defaulting to 8 threads.


2023-10-25 18:59:33,820 - Cell - INFO - Register blob CB for channel='aux_communication', topic='*'
2023-10-25 18:59:33,821 - Cell - INFO - Register blob CB for channel='aux_communication', topic='*'
2023-10-25 18:59:33,965 - Cell - INFO - Register blob CB for channel='aux_communication', topic='*'
2023-10-25 18:59:34,340 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather, peer=site-1, peer_run=simulate_job, task_name=train, task_id=2b6acff2-0f6a-4567-94a7-570c11cc8fd7]: assigned task to client site-1: name=train, id=2b6acff2-0f6a-4567-94a7-570c11cc8fd7
2023-10-25 18:59:34,342 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather, peer=site-1, peer_run=simulate_job, task_name=train, task_id=2b6acff2-0f6a-4567-94a7-570c11cc8fd7]: sent task assignment to client. client_name:site-1 task_id:2b6acff2-0f6a-4567-94a7-570c11cc8fd7
2023-10-25 18:59:34,343 - GetTaskCommand - INFO - return task to client.  client_name:

Process Process-46:


Simulator finished with run_status -9
[rank: 0] Received SIGTERM: 15
[rank: 0] Received SIGTERM: 15
Epoch 0: :  38%|███▊      | 77/201 [02:29<04:00, v_num=9-45, reduced_train_loss=0.765, global_step=76.00, consumed_samples=9856.0, train_step_timing in s=1.960]
Epoch 0: :  38%|███▊      | 77/201 [02:29<04:01, v_num=9-46, reduced_train_loss=0.651, global_step=76.00, consumed_samples=9856.0, train_step_timing in s=1.960]
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 801, in <module>
    main()
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed

You can visualize the training process using TensorBoard:

In [None]:
!tensorboard --logdir /tmp/nvflare/nemo

## Results
In this scenario, all clients use the same validation set, which allows a direct comparison between the locally p-tuned models and the federated global model. As expected, the global model trained with FedAvg reaches a lower validation loss than the models trained only on their local datasets. This is because the global model aggregates updates learned from all client datasets and can therefore generalize better.

![validation loss](./figs/val_loss.svg)

## Inference

We can use `model.generate()` to run inference after p-tuning the model. 
Let's define some test examples to feed to the p-tuned model to see its predictions.

In [None]:
test_examples = [
    {"taskname": "sentiment", "sentence": "The products have a low salt and fat content ."},
    {"taskname": "sentiment", "sentence": "The agreement is valid for four years ."},
    {"taskname": "sentiment", "sentence": "Diluted EPS rose to EUR3 .68 from EUR0 .50 ."},
    {"taskname": "sentiment", "sentence": "The company is well positioned in Brazil and Uruguay ."},
    {"taskname": "sentiment", "sentence": "Profit before taxes decreased by 9 % to EUR 187.8 mn in the first nine months of 2008 , compared to EUR 207.1 mn a year earlier ."},
]

Next, we will load the global model.

In [None]:
import os
import torch
import pytorch_lightning as pl
from nemo_nvflare.fed_megatron_gpt_prompt_learning_model import FedMegatronGPTPromptLearningModel
from nemo_nvflare.utils import load_weights
from omegaconf import OmegaConf
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from pytorch_lightning.plugins.environments import TorchElasticEnvironment

# Load the model configuration from the federated job created above
config = OmegaConf.load("jobs/peft_p-tuning_fedavg_345M/server/config/megatron_gpt_prompt_learning_config.yaml")

# Set GPT model path
config.model.language_model_path = "megatron_gpt_345m.nemo"

# Load task templates
config.model.task_templates = OmegaConf.load("jobs/peft_p-tuning_fedavg_345M/server/config/task_templates.json")

# Set the tasks that were learned
config.model.new_tasks = ["sentiment"]

# Setup cluster environment parameters
# use torch elastic cluster environment so `create_process_externally` is True
# the launcher is set to None. It will not try to spawn new processes.
# It won't create the misconfiguration error because of the `interactive session`
os.environ["LOCAL_RANK"] = '0'
os.environ["RANK"] = '0'
os.environ["WORLD_SIZE"] = '1'
strategy = NLPDDPStrategy(find_unused_parameters=False, no_ddp_communication_hook=True)
plugins = [TorchElasticEnvironment()]

# Set up the trainer and load the model that was used for p-tuning
trainer = pl.Trainer(plugins=plugins, strategy=strategy, **config.trainer)
model = FedMegatronGPTPromptLearningModel(cfg=config.model, trainer=trainer)
model.init_prompt_encoder()

print("Model initialized", type(model))

Overwrite the prompt encoder with the best global model.

In [None]:
ckpt = torch.load("/tmp/nvflare/nemo/peft_p-tuning_fedavg_345M/simulate_job/app_server/best_FL_global_model.pt")
global_weights = ckpt["model"]

n_loaded = load_weights(model, global_weights, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
print(f"Loaded {n_loaded} of {len(global_weights)} weights")

Run the model.

In [None]:
response = model.generate(inputs=test_examples, length_params=None)

print('The prediction results of some sample queries with the trained model:')
for result in response['sentences']:
    print(result)
    print("-" * 30)

The expected output predictions look something like this:

>      The products have a low salt and fat content . sentiment: neutral
>      ------------------------------
>      The agreement is valid for four years . sentiment: neutral
>      ------------------------------
>      Diluted EPS rose to EUR3 .68 from EUR0 .50 . sentiment: positive
>      ------------------------------
>      The company is well positioned in Brazil and Uruguay . sentiment: positive
>      ------------------------------
>      Profit before taxes decreased by 9 % to EUR 187.8 mn in the first nine months of 2008 , compared to EUR 207.1 mn a year earlier . sentiment: negative
>      ------------------------------
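
If you want the predicted label on its own, for example to compute accuracy, the generated strings can be post-processed. The sketch below assumes the output format shown above, where the prediction follows the literal marker `sentiment:`; `parse_prediction` is a hypothetical helper and not part of NeMo or NVFlare.

```python
def parse_prediction(generated: str, marker: str = "sentiment:") -> tuple:
    """Split a generated string into (sentence, predicted_label)."""
    # Use the last occurrence of the marker in case the sentence contains it too.
    sentence, sep, label = generated.rpartition(marker)
    if not sep:
        # Marker not found: return the text unchanged with no label.
        return generated.strip(), None
    return sentence.strip(), label.strip()


generated = "Diluted EPS rose to EUR3 .68 from EUR0 .50 . sentiment: positive"
sentence, label = parse_prediction(generated)
print(f"{label!r} <- {sentence!r}")
```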