Inference-Time Training: With Greater Text Comes Greater Necessity

Long-text generation, such as novel writing or discourse-level translation with extremely long context, poses significant challenges to language models. Existing methods mainly extend the model's context window through strategies such as length extrapolation. However, these methods demand substantial hardware resources during both training and inference.

Our proposed method, Temp-Lora, offers an alternative. Rather than relying on the KV cache to store all context information, Temp-Lora embeds this information directly into the model's parameters. During long-text generation, we employ a temporary LoRA module that is progressively trained on the previously generated text. This efficiently preserves contextual knowledge, and the module is discarded after generation so that the model's parameters are not permanently affected.
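The generate-train-continue loop described above can be sketched as follows. This is a toy illustration, not the repository's code: the "model" is a table of token weights, the hypothetical `TempAdapter` stands in for the temporary LoRA module, and `generate_chunk` and `train_adapter` are hypothetical placeholders for real decoding and a real gradient step.

```python
from dataclasses import dataclass, field


@dataclass
class TempAdapter:
    # Hypothetical stand-in for a temporary LoRA module: additive deltas
    # applied on top of the frozen base weights.
    deltas: dict = field(default_factory=dict)
    steps: int = 0


def generate_chunk(context, base, adapter, vocab, chunk_size):
    """Toy decoder: greedily picks the token with the highest effective
    weight (base + adapter delta), never repeating the previous token."""
    out = []
    prev = context[-1] if context else None
    for _ in range(chunk_size):
        candidates = [t for t in vocab if t != prev]
        tok = max(candidates,
                  key=lambda t: base.get(t, 0.0) + adapter.deltas.get(t, 0.0))
        out.append(tok)
        prev = tok
    return out


def train_adapter(adapter, chunk, lr=0.1):
    """Toy training step: only the adapter's deltas are updated."""
    for tok in chunk:
        adapter.deltas[tok] = adapter.deltas.get(tok, 0.0) + lr
    adapter.steps += 1


def temp_lora_generate(prompt, base, vocab, num_chunks, chunk_size):
    adapter = TempAdapter()            # fresh temporary module for this text
    text = list(prompt)
    for _ in range(num_chunks):
        chunk = generate_chunk(text, base, adapter, vocab, chunk_size)
        text.extend(chunk)             # 1) generate the next chunk
        train_adapter(adapter, chunk)  # 2) train the adapter on that chunk
    steps = adapter.steps
    del adapter                        # 3) discard; the base weights were never written
    return text, steps


base = {"a": 1.0, "b": 0.5, "c": 0.2}
text, steps = temp_lora_generate(["c"], base, ["a", "b", "c"], num_chunks=3, chunk_size=4)
print(len(text), steps)  # 13 3
```

The structure makes the key property explicit: only the adapter ever changes. The base weights are read but never written, and the adapter is created per text and dropped once generation ends.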

Extensive experiments on the PG19 language-modeling benchmark and the GuoFeng discourse-level translation benchmark demonstrate the efficacy of Temp-Lora. Our findings reveal that: 1) Temp-Lora significantly improves generation quality on long texts, evidenced by a 1.68 perplexity decrease and a 6.6 BLEU increase on GuoFeng; 2) Temp-Lora is compatible with and complementary to most existing long-text generation methods; and 3) Temp-Lora can significantly reduce computation cost.

Running the Code

  1. Write the path of configs/deepspeed_zero2.json into the deepspeed_config_file field of accelerate_default_config.yaml, as shown below:
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: <absolute path to deepspeed_zero2.json>
  zero3_init_flag: false
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
num_machines: 1
num_processes: 1
use_cpu: false
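If you prefer not to edit the file by hand, the substitution can be scripted. A minimal standard-library sketch, assuming the key appears exactly once on its own line as in the example above; `set_deepspeed_config_path` is a hypothetical helper, not part of the repository:

```python
import re
from pathlib import Path

_KEY_LINE = re.compile(r"^([ \t]*deepspeed_config_file:).*$", re.MULTILINE)


def set_deepspeed_config_path(yaml_text: str, ds_config: str) -> str:
    """Return `yaml_text` with the `deepspeed_config_file:` value replaced by
    the absolute path of `ds_config`; indentation is preserved."""
    abs_path = str(Path(ds_config).resolve())
    # A function replacement avoids backslash-escaping issues in paths.
    new_text, count = _KEY_LINE.subn(lambda m: f"{m.group(1)} {abs_path}", yaml_text)
    if count != 1:
        raise ValueError(f"expected one deepspeed_config_file line, found {count}")
    return new_text


if __name__ == "__main__":
    cfg = Path("configs/accelerate_default_config.yaml")
    if cfg.exists():  # assumes the script is run from the repository root
        cfg.write_text(set_deepspeed_config_path(cfg.read_text(), "configs/deepspeed_zero2.json"))
```

The key is matched line by line rather than parsed as YAML so that the rest of the file, including comments and key order, is left untouched.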
  2. Fill in the shell script for the training task

To run experiments on the PG-19 dataset, fill in scripts/llama2.sh, completing the configuration files, dataset, and other information used for training.

ACCELERATE_CONFIG="<path to configs/accelerate_default_config.yaml>"
DS_CONFIG="<path to configs/deepspeed_zero2.json>"
SAVE_DIR="<directory for saving model outputs>"

MODEL_NAME="togethercomputer/LLaMA-2-7B-32K"

mkdir -p $SAVE_DIR
accelerate launch --config_file $ACCELERATE_CONFIG trainer/acc_pg19_trainer.py --model_name $MODEL_NAME \
  --train_fp "<training file, e.g. data/pg19_test/12204.txt>" \
  --eval_fp "<can be the same as train_fp>" \
  --lora_rank 64 \
  --lora_alpha 64 \
  --lora_dropout 0.05 \
  --learning_rate 5e-5 \
  --weight_decay 0 \
  --optim adamw_torch_fused \
  --lr_scheduler_type constant_with_warmup \
  --warmup_steps 2 \
  --output_dir $SAVE_DIR \
  --training_input_length 3072 \
  --stride_size 1024 \
  --eval_input_length 3072 \
  --num_train_epochs 2 \
  --gradient_checkpointing "false" \
  --use_flash_attention_2
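The `--training_input_length 3072` and `--stride_size 1024` flags suggest a sliding window over the text. The sketch below shows one plausible reading, which is an assumption based on the flag names rather than on trainer/acc_pg19_trainer.py: the text advances 1024 tokens at a time, and each step trains on the newest 1024 tokens with up to 2048 earlier tokens as context. `make_windows` is hypothetical.

```python
def make_windows(tokens, input_length=3072, stride=1024):
    """Split `tokens` into (context, target) pairs under the assumed reading
    of the flags: each target is the next `stride` tokens, and its context is
    at most `input_length - stride` tokens immediately preceding it."""
    if not 0 < stride <= input_length:
        raise ValueError("need 0 < stride <= input_length")
    max_context = input_length - stride
    windows = []
    for start in range(0, len(tokens), stride):
        target = tokens[start:start + stride]
        context = tokens[max(0, start - max_context):start]
        windows.append((context, target))
    return windows


# Small example: 10 tokens, 6-token windows advancing by 4.
for context, target in make_windows(list(range(10)), input_length=6, stride=4):
    print(context, target)
# [] [0, 1, 2, 3]
# [2, 3] [4, 5, 6, 7]
# [6, 7] [8, 9]
```

Under this reading, every training input stays within 3072 tokens regardless of how long the text grows, while the number of training steps per epoch grows linearly with text length (one per 1024-token chunk).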

For the GuoFeng dataset, the corresponding script is 01_6B_chat.sh:

ACCELERATE_CONFIG="<path to configs/accelerate_default_config.yaml>"
DS_CONFIG="<path to configs/deepspeed_zero2.json>"
SAVE_DIR="<directory for saving model outputs>"

mkdir -p $SAVE_DIR

TRAINING_INPUT_LENGTH=1  # 1024
EVAL_INPUT_LENGTH=1      # 1024

accelerate launch --config_file $ACCELERATE_CONFIG trainer/acc_guo_feng_trainer.py \
  --model_name "01-ai/Yi-6B-Chat" \
  --train_fp "<training file, e.g. data/cache_guofeng/102-bgwzsl.train>" \
  --eval_fp "<evaluation file, e.g. data/cache_guofeng/102-bgwzsl.eval>" \
  --lora_rank 64 \
  --lora_alpha 64 \
  --lora_dropout 0.05 \
  --learning_rate 5e-5 \
  --weight_decay 0 \
  --optim adamw_torch_fused \
  --lr_scheduler_type constant_with_warmup \
  --warmup_steps 2 \
  --output_dir $SAVE_DIR \
  --num_train_epochs 2 \
  --gradient_checkpointing false \
  --training_input_length $TRAINING_INPUT_LENGTH \
  --eval_input_length $EVAL_INPUT_LENGTH
  3. Run the corresponding shell script, for example:
bash scripts/llama2.sh 
  4. Analyze the experimental results:
python3 eval_results/pg19.py --help
python3 eval_results/guo_feng.py --help
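Perplexity (PPL), one of the metrics in the results above, is the exponential of the mean per-token negative log-likelihood, so lower is better. As a reference for reading the evaluation output, a minimal sketch; it assumes per-token NLL values in nats are available, which may not match what eval_results/pg19.py actually reads or prints, and the numbers below are made up for illustration:

```python
import math


def perplexity(token_nlls):
    """Perplexity = exp(mean per-token negative log-likelihood), NLL in nats."""
    if not token_nlls:
        raise ValueError("need at least one token")
    return math.exp(sum(token_nlls) / len(token_nlls))


# Hypothetical per-token NLLs for the same text, without and with Temp-Lora.
baseline = perplexity([2.0, 2.2, 1.8])
temp_lora = perplexity([1.9, 2.0, 1.7])
print(f"PPL {baseline:.2f} -> {temp_lora:.2f} (decrease {baseline - temp_lora:.2f})")
```

Because PPL exponentiates the mean NLL, a fixed absolute PPL decrease corresponds to a larger NLL improvement at low perplexity than at high perplexity.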

Example and KV Cache Reuse

We provide an example that

- generates one chunk from the input,
- trains on that chunk, and
- continues generating.

The example also implements KV Cache Reuse; see trainer/acc_complete_example_trainer.py.
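Why cache reuse matters can be seen from a cost count. After each training step the adapter changes, so cached key/value states for earlier tokens were computed with older parameters: one option is to recompute them, the other is to keep using them. The toy below is our own illustration of that trade-off, as we understand the term, and not the logic of trainer/acc_complete_example_trainer.py; `ToyKVCache` and `run` are hypothetical, and "training" is reduced to bumping a version number.

```python
from dataclasses import dataclass, field


@dataclass
class ToyKVCache:
    # One entry per token: (token, adapter version its state was computed with).
    entries: list = field(default_factory=list)
    computed: int = 0  # total per-token state computations so far

    def append(self, tokens, version):
        for tok in tokens:
            self.entries.append((tok, version))
            self.computed += 1

    def recompute(self, version):
        # Rebuild every cached state under the updated adapter.
        self.entries = [(tok, version) for tok, _ in self.entries]
        self.computed += len(self.entries)


def run(prompt, chunks, reuse_cache):
    """Generate-train-continue loop; `chunks` stand in for generated text."""
    cache = ToyKVCache()
    version = 0                       # adapter version, bumped by each training step
    cache.append(prompt, version)
    for chunk in chunks:
        cache.append(chunk, version)  # states for the newly generated chunk
        version += 1                  # "train" the temporary adapter on the chunk
        if not reuse_cache:
            cache.recompute(version)  # refresh all states with the new adapter
    return cache


prompt, chunks = ["p1", "p2"], [["a", "b"], ["c", "d"], ["e", "f"]]
print(run(prompt, chunks, reuse_cache=True).computed)   # 8
print(run(prompt, chunks, reuse_cache=False).computed)  # 26
```

Full recomputation grows quadratically with the number of chunks, whereas reuse stays linear, at the cost of a cache that mixes states computed under different adapter versions.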

Environment

We built the runtime environment on nvidia/cuda:11.8.0-devel-ubuntu22.04. The required Python libraries are listed in requirements.txt; the most important ones are:

use docker: nvidia/cuda:11.8.0-devel-ubuntu22.04
python version: 3.10.12
torch 2.2.1+cu118
torchaudio  2.2.1+cu118
torchvision 0.17.1+cu118
transformers 4.38.2
transformers-stream-generator 0.0.4
flash-attn 2.5.6
xformers 0.0.24+cu118
peft 0.9.0
accelerate 0.27.2
deepspeed 0.13.4
deepspeed-kernels 0.0.1.dev1698255861
datasets 2.18.0
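To compare a local installation against the versions listed above, a small standard-library check can help. A minimal sketch: `find_mismatches` is a hypothetical helper, and only a subset of the listed packages is included.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions taken from the list above (subset; extend as needed).
EXPECTED = {
    "torch": "2.2.1+cu118",
    "transformers": "4.38.2",
    "flash-attn": "2.5.6",
    "peft": "0.9.0",
    "accelerate": "0.27.2",
    "deepspeed": "0.13.4",
    "datasets": "2.18.0",
}


def find_mismatches(expected, lookup=version):
    """Return {package: installed_or_None} for packages whose installed
    version differs from the expected one (None means not installed)."""
    mismatches = {}
    for pkg, want in expected.items():
        try:
            have = lookup(pkg)
        except PackageNotFoundError:
            have = None
        if have != want:
            mismatches[pkg] = have
    return mismatches


if __name__ == "__main__":
    for pkg, have in find_mismatches(EXPECTED).items():
        print(f"{pkg}: expected {EXPECTED[pkg]}, found {have}")
```

The `lookup` parameter defaults to `importlib.metadata.version` but can be swapped out, which keeps the comparison logic testable without the packages installed.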