A framework for training Large Language Models to use external tools (e.g., search engines) through in-context learning combined with reinforcement learning.
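At inference time the trained model interleaves generation with search calls: it emits a query inside special tags, the search server answers, and the retrieved results are appended to the context before generation resumes. A minimal sketch of the tag-extraction step, assuming `<search>…</search>` markup (an assumption for illustration; check the repo's prompt templates for the exact tags):

```python
import re
from typing import Optional

# Assumed tag format; the repo's actual markup may differ.
SEARCH_TAG = re.compile(r"<search>(.*?)</search>", re.DOTALL)

def extract_search_query(generation: str) -> Optional[str]:
    """Return the first search query the model emitted, or None."""
    match = SEARCH_TAG.search(generation)
    return match.group(1).strip() if match else None

output = "I need to look this up. <search>2022 FIFA World Cup winner</search>"
print(extract_search_query(output))  # -> 2022 FIFA World Cup winner
```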
```bash
git clone https://github.com/applese233/ICRL.git
cd ICRL

# Environment with Python 3.9
conda env create -f environment.yml
conda activate icrl
pip install -e .
pip install flash-attn --no-build-isolation
```

Prepare the 3-shot few-shot data used by train_grpo_fewshot.sh:

```bash
python scripts/data_process/nq_search_fewshot.py \
    --local_dir data/nq_search_fewshot \
    --num_examples 3
```

Few-shot data preparation supports choosing which example set to use from the example/ directory. The default is fewshot_examples.txt.
```bash
# Use a built-in example set under example/
python scripts/data_process/nq_search_fewshot.py \
    --template_type fewshot \
    --examples_name 7B_examples.txt \
    --num_examples 3

# Or point to an explicit file path
python scripts/data_process/nq_search_fewshot.py \
    --template_type fewshot \
    --examples_file example/wrong_examples.txt \
    --num_examples 2
```

Available example files currently include fewshot_examples.txt, 7B_examples.txt, and wrong_examples.txt.
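Conceptually, these example files hold worked demonstrations that get prepended to each training question. A rough sketch of how a k-shot prompt can be assembled (the helper name and the blank-line-separated file format are assumptions, not this repo's exact parsing logic):

```python
def build_fewshot_prompt(examples_text: str, question: str, num_examples: int) -> str:
    """Prepend the first `num_examples` demonstrations to the question.

    Assumes demonstrations are separated by blank lines in the examples
    file; the repo's actual format may differ.
    """
    demos = [d.strip() for d in examples_text.split("\n\n") if d.strip()]
    shots = demos[:num_examples]
    return "\n\n".join(shots + [f"Question: {question}"])

examples = "Q: a?\nA: 1\n\nQ: b?\nA: 2\n\nQ: c?\nA: 3"
prompt = build_fewshot_prompt(examples, "Who wrote Hamlet?", 2)
print(prompt)
```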
If you want to run the curriculum script, prepare all three stages first:
```bash
# Stage 1: 3-shot
python scripts/data_process/nq_search_fewshot.py \
    --template_type fewshot \
    --local_dir data/nq_search_fewshot \
    --num_examples 3

# Stage 2: 2-shot
python scripts/data_process/nq_search_fewshot.py \
    --template_type fewshot \
    --local_dir data/nq_search_2shot \
    --num_examples 2

# Stage 3: 0-shot
python scripts/data_process/nq_search_fewshot.py \
    --template_type zeroshot \
    --local_dir data/nq_search_0shot
```

Training requires a web search server. We use SerpAPI or Serper.dev as the search backend.
First, get your API key from:
- SerpAPI (default)
- Serper.dev
Then start the server:

```bash
# Using SerpAPI (default)
SERPAPI_KEY=your_api_key bash scripts/search/run_search_server.sh

# Or using Serper.dev
PROVIDER=serper SERPAPI_KEY=your_api_key bash scripts/search/run_search_server.sh
```

The server will listen on http://127.0.0.1:8000/retrieve.
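The exact request schema of /retrieve is defined by the server script; purely as an illustration, a client might build a JSON POST like the following (the `queries` and `topk` fields are assumptions, not the documented API — consult the server code for the real interface):

```python
import json
from urllib import request

RETRIEVE_URL = "http://127.0.0.1:8000/retrieve"

def build_retrieve_request(queries, topk=3):
    """Build (but do not send) a POST request for the search server.

    The body schema (queries/topk) is an assumption for illustration only.
    """
    body = json.dumps({"queries": queries, "topk": topk}).encode("utf-8")
    return request.Request(
        RETRIEVE_URL, data=body, headers={"Content-Type": "application/json"}
    )

req = build_retrieve_request(["who won the 2022 world cup"])
print(req.full_url, json.loads(req.data))
```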
You can train the model with a single stage:

```bash
bash train_grpo_fewshot.sh
```

Or run the full three-stage curriculum:

```bash
bash train_curriculum.sh
```

The training scripts default to 4 GPUs and the dataset paths above, but you can override them with environment variables:
```bash
CUDA_VISIBLE_DEVICES=0,1 N_GPUS=2 bash train_grpo_fewshot.sh

CUDA_VISIBLE_DEVICES=0,1 N_GPUS=2 \
    STAGE1_DATA_DIR=data/nq_search_fewshot \
    STAGE2_DATA_DIR=data/nq_search_2shot \
    STAGE3_DATA_DIR=data/nq_search_0shot \
    bash train_curriculum.sh
```

Test the model with a single question:
```bash
# Ask a question directly
python infer.py \
    --model_path Qwen/Qwen2.5-7B-Instruct \
    --question "Who won the 2022 FIFA World Cup?"

# Interactive mode (enter questions one by one)
python infer.py --model_path your_trained_model

# Use few-shot prompts
python infer.py \
    --model_path your_model \
    --use_fewshot \
    --fewshot_path example/fewshot_examples.txt \
    --question "Your question here"
```

Evaluate on benchmark datasets:
```bash
# Convert evaluation datasets only
bash eval_batch_vllm.sh --datasets "triviaqa hotpotqa musique 2wikimultihopqa bamboogle" --convert_only

# Run evaluation with a checkpoint
bash eval_batch_vllm.sh \
    --checkpoint_ref icrl-grpo-qwen2.5-7B/global_step_50 \
    --datasets "triviaqa hotpotqa musique 2wikimultihopqa bamboogle" \
    --use_vllm

# Equivalent form: model name + global step
bash eval_batch_vllm.sh \
    --model_name icrl-grpo-qwen2.5-7B \
    --global_step global_step_50 \
    --datasets "triviaqa hotpotqa musique 2wikimultihopqa bamboogle" \
    --use_vllm
```

For named checkpoints, the default lookup path is ./checkpoints/<model_name>/actor/<global_step>. You can override the root with --checkpoint_dir.
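The name-plus-step lookup described above amounts to simple path joining. This sketch mirrors the stated default layout (the function itself is illustrative, not part of the repo):

```python
import os

def resolve_checkpoint(model_name, global_step, checkpoint_dir="./checkpoints"):
    """Resolve ./checkpoints/<model_name>/actor/<global_step>, the default
    lookup path; checkpoint_dir plays the role of --checkpoint_dir."""
    return os.path.join(checkpoint_dir, model_name, "actor", global_step)

path = resolve_checkpoint("icrl-grpo-qwen2.5-7B", "global_step_50")
print(path)  # ./checkpoints/icrl-grpo-qwen2.5-7B/actor/global_step_50 (on POSIX)
```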
Supported evaluation datasets:

- TriviaQA
- HotpotQA
- 2WikiMultihopQA
- MuSiQue
- Bamboogle
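These open-domain QA benchmarks are commonly scored with normalized exact match. As a reference point (not necessarily this repo's metric implementation), the standard SQuAD-style normalization looks like:

```python
import re
import string

def normalize_answer(s: str) -> str:
    """Lowercase, drop punctuation and articles, collapse whitespace --
    the usual normalization applied before exact-match comparison."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def exact_match(prediction: str, golds) -> bool:
    """True if the normalized prediction matches any gold answer."""
    return any(normalize_answer(prediction) == normalize_answer(g) for g in golds)

print(exact_match("The Eiffel Tower.", ["eiffel tower"]))  # -> True
```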
This project builds upon:
Apache License 2.0