Semi-Offline Reinforcement Learning for Optimized Text Generation

This repository contains the code, data, and checkpoints for our paper published at ICML 2023:
Semi-Offline Reinforcement Learning for Optimized Text Generation
Changyu Chen, Xiting Wang, Yiqiao Jin, Victor Ye Dong, Li Dong, Jie Cao, Yi Liu, Rui Yan
Paper: http://arxiv.org/abs/2306.09712

    @article{chen2023semi,
      title={Semi-Offline Reinforcement Learning for Optimized Text Generation},
      author={Chen, Changyu and Wang, Xiting and Jin, Yiqiao and Dong, Victor Ye and Dong, Li and Cao, Jie and Liu, Yi and Yan, Rui},
      journal={arXiv preprint arXiv:2306.09712},
      year={2023}
    }

1. Overview

Our semi-offline method is illustrated in setting (c2) of the overview figure in the paper: we use static data as the starting point and perform exploration with a single forward propagation (FP).

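To make the idea concrete, here is a minimal, hypothetical sketch of semi-offline exploration for a sequence-to-sequence model. All names below (e.g., semi_offline_sample) are illustrative assumptions, not this repository's API; the actual logic lives in the RL trainer selected via --trainer rl.

import torch

def semi_offline_sample(model, src_ids, tgt_ids, mask_token_id, mask_rate=0.4):
    # Hypothetical sketch, not this repo's actual code.
    # Start from the static ground-truth target and mask a random fraction
    # of its positions (mask_rate mirrors the --mask_rate flag below).
    mask = torch.rand_like(tgt_ids, dtype=torch.float) < mask_rate
    masked_tgt = torch.where(mask, torch.full_like(tgt_ids, mask_token_id), tgt_ids)
    # ONE forward propagation (FP) proposes tokens for every masked slot;
    # this single FP is the entire exploration step.
    with torch.no_grad():
        logits = model(input_ids=src_ids, decoder_input_ids=masked_tgt).logits
    sampled = torch.distributions.Categorical(logits=logits).sample()
    # Explore only at the masked positions; keep the static data everywhere else.
    return torch.where(mask, sampled, tgt_ids), mask

The filled-in sequence can then be scored with a reward such as ROUGE and used for a policy-gradient update at the masked positions.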

2. Install

git clone https://github.com/ChangyuChen347/semi-offline-RL.git
cd semi-offline-RL
conda create -n semi-offline-rl python=3.8
conda activate semi-offline-rl
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
python -c "import nltk; nltk.download('punkt'); nltk.download('stopwords')"
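Optionally, verify that PyTorch can see your GPU before moving on:

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"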

3. Train

Before training, download the checkpoint of the base model and the data (see Section 5). Then place the checkpoint in the ./model_output directory and the data in the ./static_data directory.
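For the CNN/DM example below, the layout implied by the --recover, --train_dir, and --eval_dir flags is:

semi-offline-RL/
├── model_output/
│   └── cnndm_base_model/
└── static_data/
    └── cnndm/
        ├── cnndm_train.tsv
        └── cnndm_valid.tsv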

You can simply run the bash scripts under the train directory.

bash train/run_cnndm.sh

or

CUDA_VISIBLE_DEVICES=0 python main.py \
    --do_train \
    --scene bart_cnndm_generation \
    --use_logit True \
    --report_to tensorboard \
    --seed 2022 \
    --smooth 0.1 \
    --trainer rl \
    --learning_rate 0.000001 \
    --num_train_epochs 60 \
    --max_grad_norm 1 \
    --print_every 1000 \
    --save_every 4000 \
    --eval_steps 2000 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --per_device_eval_batch_size 16 \
    --length_normalize_4_rl True \
    --training_length_penalty 1 \
    --train_dir static_data/cnndm/cnndm_train.tsv \
    --eval_dir static_data/cnndm/cnndm_valid.tsv \
    --cand_pos_remove_sp_tk True \
    --recover model_output/cnndm_base_model \
    --exp_name demo_cnndm \
    --rewards 'rouge' \
    --rouge_type 12l \
    --rl_weight 20 \
    --sample_num 63 \
    --mask_rate 0.4 \
    --kd_inputs_worst True \
    --eval_metrics rouges,rouge \
    --seq_decode_model bart 

Training Parameters:

Basic settings:

  • --learning_rate: Sets the learning rate for training.
  • --num_train_epochs: Specifies the number of training epochs.
  • --max_grad_norm: Sets the maximum gradient norm for gradient clipping.
  • --print_every: Prints training progress every specified number of steps.
  • --save_every: Saves the model every specified number of steps.
  • --eval_steps: Evaluates the model every specified number of steps.
  • --per_device_train_batch_size: Sets the training batch size per GPU.
  • --gradient_accumulation_steps: Accumulates gradients over the specified number of steps.
  • --per_device_eval_batch_size: Sets the evaluation batch size per GPU.
  • --length_normalize_4_rl: Applies length normalization for reinforcement learning.
  • --cand_pos_remove_sp_tk: Removes special tokens (pad/eos) from the candidate positions.
  • --exp_name: Specifies the experiment name.
  • --eval_metrics: Specifies the evaluation metrics (comma-separated, e.g., rouges,rouge).

Model and Task:

  • --scene: Specifies the scene or task for the model. The config files are in ./config/SceneConfigs/.
  • --train_dir: Specifies the training dataset directory.
  • --eval_dir: Specifies the evaluation dataset directory.
  • --recover: The path of the checkpoint to initialize training from.

RL settings:

  • --rewards: Specifies the reward metric.
  • --rouge_type: Sets the ROUGE metric type (12l means rouge-1, rouge-2, and rouge-L; see the reward sketch after this list).
  • --rl_weight: Sets the weight for reinforcement learning loss.
  • --sample_num: Sets the number of samples for RL.
  • --mask_rate: Sets the masking rate for both SFT and RL.
  • --kd_inputs_worst: Uses worst case inputs for knowledge distillation.
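Note that the effective batch size in the example above is per_device_train_batch_size × gradient_accumulation_steps × number of GPUs, i.e. 1 × 8 × 1 = 8. As for the reward, below is a minimal, hypothetical sketch of a 12l-style ROUGE reward that averages the ROUGE-1/2/L F1 scores, written with the rouge-score package; the repository's actual reward computation may differ in details:

# Hypothetical '12l' reward: mean of ROUGE-1/2/L F1 (may differ from the repo's code).
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)

def rouge_12l_reward(hypothesis: str, reference: str) -> float:
    scores = scorer.score(reference, hypothesis)
    return (scores["rouge1"].fmeasure +
            scores["rouge2"].fmeasure +
            scores["rougeL"].fmeasure) / 3.0

print(rouge_12l_reward("the cat sat on the mat", "a cat sat on a mat"))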

4. Evaluation

Before evaluation, download the required packages:

  1. The evaluation (word tokenization and metric computation) of CNN/DM and XSum follows BRIO: the predictions are first lowercased and tokenized using the PTB tokenizer provided by Stanford (download here), and then the ROUGE score is computed using the standard ROUGE Perl package (download here).

After downloading the two files, you can set the environment variables using the following commands:

export _ROUGE_PATH=./ROUGE-RELEASE-1.5.5
export CLASSPATH=./stanford-corenlp-3.8.0.jar

To use the ROUGE Perl package, you may need to install XML::DOM and XML::Parser. Alternatively, you can pass the "-p" flag to obtain Python results for a quick start; note that the Python results may differ slightly from the Perl results.
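For example, assuming the "-p" flag belongs to cal_rouge.py (the scorer invoked in the pipeline below), a Python-only run would look like:

python cal_rouge.py --ref ref.txt.token --hyp pred.txt.token -p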

  2. The evaluation of SQuAD follows LMQG.

To compute the METEOR score for SQuAD, download paraphrase-en.gz and place it in the ./lmqg/automatic_evaluation_tool/meteor/data/ directory.
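Assuming the file has been downloaded to the repository root, placing it would look like:

mkdir -p ./lmqg/automatic_evaluation_tool/meteor/data/
mv paraphrase-en.gz ./lmqg/automatic_evaluation_tool/meteor/data/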

Example: evaluating CNN/DM.

You can simply run the bash script under the evaluation directory:

bash evaluation/eval_cnn.sh

or

exp_name=rl
output_dir_path=eval_output
model_dir_path=model_output
dataset=cnndm
export _ROUGE_PATH=./ROUGE-RELEASE-1.5.5
export CLASSPATH=./stanford-corenlp-3.8.0.jar
bash evaluation/run_test_cnndm.sh ${exp_name} ${output_dir_path} ${model_dir_path}
python extract_prediction.py --dataset ${dataset} --exp_name ${exp_name} --output_dir_path ${output_dir_path}
cat ${output_dir_path}/${dataset}/${exp_name}/pred.txt | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > ${output_dir_path}/${dataset}/${exp_name}/pred.txt.token
cat ${output_dir_path}/${dataset}/${exp_name}/ref.txt | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > ${output_dir_path}/${dataset}/${exp_name}/ref.txt.token
python cal_rouge.py --ref  ${output_dir_path}/${dataset}/${exp_name}/ref.txt.token --hyp ${output_dir_path}/${dataset}/${exp_name}/pred.txt.token

5. Checkpoints and static datasets

The BASE models are supervised fine-tuned (SFT) models trained with the [mask] token. The RL models are our trained RL checkpoints.

          BASE (M-FT)               RL
CNN/DM    cnndm_bart_base_model     cnndm_bart_rl_model
SAMSum    samsum_bart_base_model    samsum_bart_rl_model
SQuAD     t5_squad_base_model       t5_squad_rl_model
XSum      xsum_pegasus_base_model   xsum_pegasus_rl_model

The training datasets (*_train.tsv) contain the source, the ground truth, and ordered candidates.

          Train              Validation         Test
CNN/DM    cnn_train.tsv      cnn_valid.tsv      cnn_test.tsv
SAMSum    samsum_train.tsv   samsum_valid.tsv   samsum_test.tsv
SQuAD     squad_train.tsv    squad_valid.tsv    squad_test.tsv
XSum      xsum_train.tsv     xsum_valid.tsv     xsum_test.tsv
