Inference-Time Training: With Greater Text Comes Greater Necessity

Long-text generation, such as novel writing or discourse-level translation with extremely long context, poses significant challenges to language models. Existing methods mainly focus on extending the model's context window through strategies such as length extrapolation. However, these methods demand daunting hardware resources during both training and inference.

Our proposed method, Temp-Lora, offers an alternative. Rather than relying on the KV cache to store all context information, Temp-Lora embeds this information directly into the model's parameters. During long-text generation, we use a temporary LoRA module that is progressively trained on the previously generated text. This efficiently preserves contextual knowledge, and the module is discarded after generation so that it has no permanent impact on the model parameters.
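To make "embedding context into parameters, then discarding the module" concrete, here is a minimal pure-Python sketch. It assumes the standard LoRA parameterization W_eff = W + (alpha / r) · B · A with B zero-initialized; the class and helper names are hypothetical and this is not the repository's implementation. The frozen base weights are never modified, so dropping the adapter restores the original model exactly.

```python
# Hypothetical sketch of a temporary LoRA adapter (not the repo's code).
# Assumes the standard LoRA form: W_eff = W + (alpha / r) * B @ A.


def matmul(x, y):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def add_scaled(w, delta, scale):
    """Return w + scale * delta element-wise, without mutating w."""
    return [[wi + scale * di for wi, di in zip(wr, dr)] for wr, dr in zip(w, delta)]


class TempLoRA:
    """Temporary low-rank adapter attached to a frozen weight matrix."""

    def __init__(self, base, rank, alpha):
        d_out, d_in = len(base), len(base[0])
        self.base = base  # frozen base weights; never written to
        self.scale = alpha / rank  # lora_alpha 64 / lora_rank 64 = 1.0 in the scripts
        self.A = [[0.01] * d_in for _ in range(rank)]  # rank x d_in
        self.B = [[0.0] * rank for _ in range(d_out)]  # d_out x rank, zero-init => no change at start

    def effective_weight(self):
        """Weights the model would use while the adapter is attached."""
        return add_scaled(self.base, matmul(self.B, self.A), self.scale)
```

In use, training on generated text would update only `A` and `B` (for example, setting `lora.B[0][0] = 1.0` stands in for a gradient step), and after generation the adapter object is simply dropped while `base` keeps its original values.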

Extensive experiments on the PG19 language-modeling benchmark and the GuoFeng discourse-level translation benchmark demonstrate the efficacy of Temp-Lora. Our findings show that: 1) Temp-Lora significantly improves generation quality on long texts, as evidenced by a 1.68 perplexity decrease and a 6.6 BLEU increase on GuoFeng; 2) Temp-Lora is compatible with and complementary to most existing long-text generation methods; and 3) Temp-Lora can significantly reduce computational cost.


  1. Write the path of configs/deepspeed_zero2.json into the deepspeed_config_file field of accelerate_default_config.yaml, as shown below:
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: <absolute path to deepspeed_zero2.json>
  zero3_init_flag: false
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
num_machines: 1
num_processes: 1
use_cpu: false
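Because deepspeed_config_file expects an absolute path, it can be convenient to generate the config rather than edit it by hand. The sketch below is a hypothetical helper (not part of the repository) that fills the same YAML with the resolved absolute path using only the standard library; `string.Template` is used so the literal `{}` in `fsdp_config` needs no escaping.

```python
from pathlib import Path
from string import Template

# Same content as the accelerate config above; $ds_path is substituted.
ACCELERATE_TEMPLATE = Template("""\
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: $ds_path
  zero3_init_flag: false
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
num_machines: 1
num_processes: 1
use_cpu: false
""")


def render_accelerate_config(ds_config: str) -> str:
    """Return the accelerate YAML with deepspeed_config_file set to an absolute path."""
    # resolve() makes a relative path absolute even if the file does not exist
    # yet (non-strict mode), so this works before the JSON is in place.
    ds_path = Path(ds_config).resolve()
    return ACCELERATE_TEMPLATE.substitute(ds_path=ds_path)
```

For example, running `Path("configs/accelerate_default_config.yaml").write_text(render_accelerate_config("configs/deepspeed_zero2.json"))` from the repository root would write the file in place.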
  2. Fill in the shell script for the training task

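To run experiments on the PG-19 dataset, fill in the scripts/ file, supplying the configuration files, dataset, and other information needed for training.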

ACCELERATE_CONFIG="<path to configs/accelerate_default_config.yaml>"
DS_CONFIG="<path to configs/deepspeed_zero2.json>"


mkdir -p $SAVE_DIR
accelerate launch --config_file $ACCELERATE_CONFIG trainer/ --model_name $MODEL_NAME \
  --train_fp "<training file, e.g. data/pg19_test/12204.txt>" \
  --eval_fp "<can be the same as train_fp>" \
  --lora_rank 64 \
  --lora_alpha 64 \
  --lora_dropout 0.05 \
  --learning_rate 5e-5 \
  --weight_decay 0 \
  --optim adamw_torch_fused \
  --lr_scheduler_type constant_with_warmup \
  --warmup_steps 2 \
  --output_dir $SAVE_DIR \
  --training_input_length 3072 \
  --stride_size 1024 \
  --eval_input_length 3072 \
  --num_train_epochs 2 \
  --gradient_checkpointing false

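The PG-19 script sets --training_input_length 3072 and --stride_size 1024. This README does not spell out how trainer/ combines the two values; one common reading, assumed here and not confirmed by the source, is a strided window: each window ends one stride further along the text, spans at most `input_length` tokens, and scores only the newly covered tokens. The function name is hypothetical.

```python
def sliding_windows(num_tokens: int, input_length: int, stride: int):
    """Yield (start, end, target_start) triples over a token sequence.

    Assumed semantics: tokens [start, end) are fed to the model, and only
    tokens [target_start, end) contribute to the loss. Every token is scored
    exactly once, with up to input_length - stride tokens of prior context.
    """
    if not 0 < stride <= input_length:
        raise ValueError("stride must be in (0, input_length]")
    prev_end = 0
    while prev_end < num_tokens:
        end = min(prev_end + stride, num_tokens)  # advance by one stride
        start = max(0, end - input_length)        # cap the window length
        yield start, end, prev_end
        prev_end = end
```

Under this reading, a 5000-token text with the script's settings yields windows ending at 1024, 2048, 3072, 4096 and 5000, and the last two windows start at 1024 and 1928 so that neither exceeds 3072 tokens.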

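To run experiments on the GuoFeng dataset, fill in the corresponding scripts/ file in the same way:

ACCELERATE_CONFIG="<path to configs/accelerate_default_config.yaml>"
DS_CONFIG="<path to configs/deepspeed_zero2.json>"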

mkdir -p $SAVE_DIR

accelerate launch --config_file $ACCELERATE_CONFIG trainer/ \
  --model_name "01-ai/Yi-6B-Chat" \
  --train_fp "<training file, e.g. data/cache_guofeng/102-bgwzsl.train>" \
  --eval_fp "<evaluation file, e.g. data/cache_guofeng/102-bgwzsl.eval>" \
  --lora_rank 64 \
  --lora_alpha 64 \
  --lora_dropout 0.05 \
  --learning_rate 5e-5 \
  --weight_decay 0 \
  --optim adamw_torch_fused \
  --lr_scheduler_type constant_with_warmup \
  --warmup_steps 2 \
  --output_dir $SAVE_DIR \
  --num_train_epochs 2 \
  --gradient_checkpointing false \
  --training_input_length 1 \
  --eval_input_length 1  # 1024 (annotation on both input-length arguments)
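Both scripts use --learning_rate 5e-5 with --lr_scheduler_type constant_with_warmup and --warmup_steps 2. The sketch below reproduces the multiplier used by transformers' `get_constant_schedule_with_warmup` (a linear ramp from 0 during warmup, then constant) in plain Python, so the per-step learning rate is easy to check; the function name is hypothetical.

```python
def constant_with_warmup_lr(step: int, base_lr: float, warmup_steps: int) -> float:
    """Learning rate at optimizer step `step` (0-based).

    Linear warmup from 0 to base_lr over `warmup_steps` steps, then constant.
    """
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr
```

With the scripts' settings the first steps run at 0, 2.5e-5, and then 5e-5 from step 2 onward. A very short warmup suits this setting, where each temporary module only sees a few training steps per chunk.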
  3. Run the corresponding shell script
bash scripts/ 
  4. Analyze the experimental results
python3 eval_results/ --help
python3 eval_results/ --help
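The PG19 results are reported as perplexity. As a reference for interpreting those numbers, the sketch below uses the generic definition (exponential of the mean per-token negative log-likelihood, natural log); it is not necessarily what the eval_results/ scripts compute, and the function names are hypothetical.

```python
import math


def perplexity(token_nlls):
    """exp of the mean per-token negative log-likelihood (natural log)."""
    if not token_nlls:
        raise ValueError("need at least one token")
    return math.exp(sum(token_nlls) / len(token_nlls))


def pooled_perplexity(windows):
    """Combine (mean_nll, num_scored_tokens) pairs into one perplexity.

    Each window's mean loss is weighted by how many tokens it scored, so a
    short window does not count as much as a full one.
    """
    total_tokens = sum(n for _, n in windows)
    if total_tokens == 0:
        raise ValueError("no scored tokens")
    total_nll = sum(loss * n for loss, n in windows)
    return math.exp(total_nll / total_tokens)
```

Weighting by token count matters whenever windows score different numbers of tokens (for example, a shorter final window); a plain average of per-window perplexities would give a different, biased number.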

Example and KV Cache Reuse


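An example of the following loop, which also implements KV Cache Reuse, is provided in trainer/:

- generate one chunk from the input
- train on that chunk
- continue generating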
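The control flow of that loop can be illustrated with a toy, pure-Python sketch. Every name here is hypothetical: a frozen bigram count table stands in for the base model, a dictionary of extra counts stands in for the temporary LoRA module, and "training" is count accumulation rather than gradient descent. KV cache reuse is not modeled; the real example lives in trainer/.

```python
from collections import Counter, defaultdict


def train_on(adapter, tokens):
    """Toy 'training': add bigram counts from tokens to the temporary adapter."""
    for prev, nxt in zip(tokens, tokens[1:]):
        adapter[prev][nxt] += 1


def next_token(prev, base, adapter):
    """Greedy choice from frozen base counts plus adapter counts."""
    scores = Counter(base.get(prev, {}))  # copy, so base is never mutated
    scores.update(adapter.get(prev, {}))
    if not scores:
        return None
    # Highest score wins; ties are broken alphabetically for determinism.
    return min(scores, key=lambda t: (-scores[t], t))


def generate_chunk(context, base, adapter, chunk_len):
    out, prev = [], context[-1]
    for _ in range(chunk_len):
        tok = next_token(prev, base, adapter)
        if tok is None:
            break
        out.append(tok)
        prev = tok
    return out


def generate_with_temp_adapter(prompt, base, num_chunks, chunk_len):
    adapter = defaultdict(Counter)  # temporary module, created per generation
    tokens = list(prompt)
    for _ in range(num_chunks):
        chunk = generate_chunk(tokens, base, adapter, chunk_len)  # 1) generate one chunk
        if not chunk:
            break
        train_on(adapter, tokens[-1:] + chunk)  # 2) train on that chunk
        tokens.extend(chunk)  # 3) continue from the extended text
    return tokens  # the adapter is dropped here; base is unchanged
```

The adapter lives only inside one call, mirroring how the temporary module is discarded after generation, while the base table is read but never written.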



Docker image: nvidia/cuda:11.8.0-devel-ubuntu22.04
Python version: 3.10.12
torch 2.2.1+cu118
torchaudio 2.2.1+cu118
torchvision 0.17.1+cu118
transformers 4.38.2
transformers-stream-generator 0.0.4
flash-attn 2.5.6
xformers 0.0.24+cu118
peft 0.9.0
accelerate 0.27.2
deepspeed 0.13.4
deepspeed-kernels 0.0.1.dev1698255861
datasets 2.18.0
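To compare a local environment against the versions above, a small standard-library check can help. This is a hypothetical helper, not part of the repository; it assumes the listed names are the pip distribution names, and it accepts an injectable version lookup so it can be tested without the packages installed.

```python
from importlib import metadata

# Versions listed above (distribution names assumed to match pip names).
EXPECTED = {
    "torch": "2.2.1+cu118",
    "torchaudio": "2.2.1+cu118",
    "torchvision": "0.17.1+cu118",
    "transformers": "4.38.2",
    "transformers-stream-generator": "0.0.4",
    "flash-attn": "2.5.6",
    "xformers": "0.0.24+cu118",
    "peft": "0.9.0",
    "accelerate": "0.27.2",
    "deepspeed": "0.13.4",
    "deepspeed-kernels": "0.0.1.dev1698255861",
    "datasets": "2.18.0",
}


def find_mismatches(expected, get_version=metadata.version):
    """Return {name: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, want in expected.items():
        try:
            have = get_version(name)
        except metadata.PackageNotFoundError:
            have = None  # package is not installed
        if have != want:
            mismatches[name] = (want, have)
    return mismatches
```

Running `print(find_mismatches(EXPECTED))` lists any package whose installed version differs from the list; an empty dict means everything matches.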