This repository contains the core code for the paper GRAG4CM: A RAG-Aware Generative Model with Chain-of-Thought Reasoning for Chinese Medicine. It implements:
- 🚀 Stage 1 (Stage1/): LoRA fine-tuning of a local base model (default: Qwen) on structured TCM QA / judgment data.
- 🎯 Stage 2 (Stage2/): Optional Prompt / Prefix Tuning on top of the LoRA-merged base model, plus unified validation.
- 📊 Verification (Verification/): Standalone evaluation of a trained LoRA model with multiple QA-style metrics.
```
.
├── Stage1/                      # Stage 1: LoRA fine-tuning (Qwen)
│   ├── config.py                # Training configuration (model, data paths, LoRA and training hyper-params)
│   ├── data_processor.py        # Data loading & normalization into JSON-style prompts and targets
│   ├── standard_evaluator.py    # Unified evaluator (ROUGE / BLEU / GLEU / Distinct / BERTScore / BLEURT, etc.)
│   ├── trainer.py               # LoRA training logic based on HF Trainer + custom evaluation
│   └── train.py                 # Entry point for Qwen LoRA fine-tuning
│
├── Stage2/                      # Stage 2: Prompt / Prefix Tuning (optional / advanced)
│   └── train2.py                # P-Tuning / Prefix-Tuning + multi-dataset validation
│
├── Verification/                # Standalone evaluation of a trained LoRA adapter
│   └── evaluate_trained_lora.py # Per-sample generation and metric accumulation
│
├── GTCMQA.json                  # Example or related dataset (used in the paper)
└── README.md
```
Stage 1 (Stage1/) and Verification (Verification/) together are sufficient to reproduce the main training and evaluation pipeline. Stage 2 (Stage2/) provides an additional P/Prefix-Tuning phase and is optional.
- Python 3.10+ is recommended.
- A GPU with sufficient memory (e.g., 24GB) is strongly recommended for training.
- Core Python libraries:
  `torch` (CUDA build), `transformers`, `peft`, `datasets`, `numpy`
Example installation (adapt to your environment):
```bash
pip install torch transformers peft datasets numpy
```

BERTScore and BLEURT in this project are implemented using local models loaded via `transformers`. No extra BLEURT toolkit is required.
Prepare a local base model in models/Qwen/. Any Qwen-like Causal LM that is compatible with transformers can be used.
Example (download from HuggingFace and save locally):
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "Qwen/xxx"  # TODO: replace with the exact model you use
tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

save_dir = "models/Qwen"
tok.save_pretrained(save_dir)
model.save_pretrained(save_dir)
```

`Stage1/standard_evaluator.py` and `Verification/evaluate_trained_lora.py` expect the following local models:
```
models/
├── bert-base-chinese/   # For BERTScore
└── bleurt-base-128/     # For BLEURT (regression scores, normalized to [0, 1])
```
You can download them from HuggingFace and save them in the above directories, e.g.:
```python
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification

# BERT-Base Chinese
bert_id = "bert-base-chinese"
bert_tok = AutoTokenizer.from_pretrained(bert_id)
bert_model = AutoModel.from_pretrained(bert_id)
bert_tok.save_pretrained("models/bert-base-chinese")
bert_model.save_pretrained("models/bert-base-chinese")

# BLEURT-base-128 (replace with the exact BLEURT model you use)
bleurt_id = "[bleurt-model-id]"  # e.g., "lucadiliello/BLEURT-20" or your local export
bleurt_tok = AutoTokenizer.from_pretrained(bleurt_id)
bleurt_model = AutoModelForSequenceClassification.from_pretrained(bleurt_id)
bleurt_tok.save_pretrained("models/bleurt-base-128")
bleurt_model.save_pretrained("models/bleurt-base-128")
```

`Stage1/config.py` (class `NewTrainConfig`) defines the default data paths:
```
dataBuilder/
└── output/
    ├── train_dataset.json  # Training set
    └── val_dataset.json    # Validation set
```
The raw data formats are flexible. Stage1/data_processor.py supports:
- Nested QA schema:

  ```json
  {
    "instruction": "...",
    "input": {"question": "...", "knowledge": ["..."]},
    "output": {"thinking": "...", "answer": "..."}
  }
  ```

- Flat QA schema:

  ```json
  {"instruction": "...", "question": "...", "knowledge": ["..."], "answer": "..."}
  ```

- Legacy judgment schema:

  ```json
  {
    "input": {"Instruction": "...", "Knowledge": ["..."], "describe": "..."},
    "output": {"answer": "Right", "reason": "..."}
  }
  ```

  where `answer` is either `"Right"` or `"Error"`.
These are normalized into a JSON-style prompt string of the form:
```json
{
  "Instruction": "...",
  "Knowledge": ["..."],
  "Question": "..."
}
```

and a corresponding target output (either a JSON object or plain answer text, depending on `output_mode`).
When you adapt this project, please ensure your preprocessed datasets are placed at the paths expected by NewTrainConfig, or change the paths in Stage1/config.py.
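As a rough illustration of this normalization, the schemas above can be folded into the unified prompt along these lines (`build_prompt` is a hypothetical helper, not the actual `data_processor.py` code):

```python
import json

def build_prompt(sample):
    """Fold a flat, nested, or legacy sample into the unified JSON prompt.
    Hypothetical sketch; the real logic lives in Stage1/data_processor.py."""
    inp = sample.get("input", sample)  # nested/legacy samples wrap fields in "input"
    prompt = {
        "Instruction": sample.get("instruction", inp.get("Instruction", "")),
        "Knowledge": inp.get("knowledge", inp.get("Knowledge", [])),
        "Question": inp.get("question", inp.get("describe", "")),
    }
    return json.dumps(prompt, ensure_ascii=False, indent=2)

flat = {"instruction": "Answer using the knowledge.", "question": "What does ginseng do?",
        "knowledge": ["Ginseng tonifies qi."], "answer": "It tonifies qi."}
print(build_prompt(flat))
```

The reference answer (`answer`, or `output.answer` in the nested/legacy schemas) is extracted separately as the training target.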
Stage 1 performs LoRA fine-tuning of the base model on the structured TCM QA / judgment data.
All key hyper-parameters are defined in Stage1/config.py (NewTrainConfig):
- `model_name_or_path`: base model directory (default: `../models/Qwen`)
- `train_data_path`, `val_data_path`: training / validation data
- `max_length`: maximum sequence length
- `output_mode`: `"full"` (full structured output) or `"answer"` (answer-only)
- LoRA: `lora_r`, `lora_alpha`, `lora_dropout`, `lora_target_modules`
- Training: `num_train_epochs`, `per_device_train_batch_size`, `gradient_accumulation_steps`, `learning_rate`, `weight_decay`, `warmup_ratio`, etc.
- System: `bf16`, `gradient_checkpointing`, logging configuration, random seed.
Edit config.py to match your hardware and dataset size.
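For orientation, the configuration surface can be pictured as a dataclass along these lines. `SketchTrainConfig` is a hypothetical stand-in for `NewTrainConfig`: the field names follow the list above, but the default values shown are illustrative rather than the repository's actual settings.

```python
from dataclasses import dataclass, field

@dataclass
class SketchTrainConfig:
    # Field names follow the hyper-parameters listed above; defaults are
    # illustrative, not the repository's actual values.
    model_name_or_path: str = "../models/Qwen"
    train_data_path: str = "dataBuilder/output/train_dataset.json"
    val_data_path: str = "dataBuilder/output/val_dataset.json"
    max_length: int = 1024
    output_mode: str = "full"  # "full" (structured) or "answer" (answer-only)
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_target_modules: list = field(default_factory=lambda: ["q_proj", "v_proj"])
    num_train_epochs: int = 3
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 8
    learning_rate: float = 2e-4
    bf16: bool = True

cfg = SketchTrainConfig(learning_rate=1e-4)  # override per your hardware
```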
From the project root:
```bash
cd Stage1
python train.py
```

`Stage1/train.py` will:
- Check GPU availability and basic environment conditions.
- Instantiate `NewTrainConfig` from `config.py`.
- Load the base model and tokenizer (Qwen by default), disable the chat template, and attach LoRA modules.
- Use `NewJudgmentDataProcessor` (`data_processor.py`) to:
  - load raw data;
  - normalize it into JSON-style prompts and outputs;
  - tokenize and create HF `Dataset` objects with proper `input_ids`, `attention_mask`, and `labels`.
- Use `NewLoRATrainer` (`trainer.py`) to run training with a custom `Trainer` subclass, including:
  - small-scale validation based on the unified metrics (ROUGE / BLEU / GLEU / Distinct / BERTScore / BLEURT);
  - saving LoRA weights at each epoch and tracking the best epoch according to an aggregated metric.
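The `labels` construction mentioned above commonly masks prompt tokens with `-100` so that the loss is computed only on the target span; here is a minimal sketch of that convention (assumed behavior, not the repository's exact code):

```python
IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy loss

def build_inputs_and_labels(prompt_ids, target_ids):
    """Concatenate prompt and target token ids, masking the prompt in the
    labels so the loss covers only the target span (sketch only)."""
    input_ids = list(prompt_ids) + list(target_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(target_ids)
    attention_mask = [1] * len(input_ids)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

features = build_inputs_and_labels([101, 2023, 102], [3437, 102])
# features["labels"] == [-100, -100, -100, 3437, 102]
```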
The LoRA checkpoints and logs are written under the `output_dir` specified in `NewTrainConfig` (by default under `../ResModels/`).
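The best-epoch bookkeeping can be pictured as follows; the plain mean used to aggregate metrics here is an assumption, as the repository may weight individual metrics differently:

```python
def aggregate(metrics):
    """Collapse a metric dict into one scalar; a plain mean is assumed here."""
    return sum(metrics.values()) / len(metrics)

history = {  # illustrative per-epoch validation metrics
    1: {"rouge_l_f": 0.31, "bleu_4": 0.12, "bertscore_f1": 0.78},
    2: {"rouge_l_f": 0.35, "bleu_4": 0.15, "bertscore_f1": 0.80},
    3: {"rouge_l_f": 0.33, "bleu_4": 0.14, "bertscore_f1": 0.79},
}
best_epoch = max(history, key=lambda e: aggregate(history[e]))
print(best_epoch)  # → 2, the epoch with the highest aggregated score
```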
Stage2/train2.py provides an optional second stage where Prompt Tuning or Prefix Tuning is applied on top of a LoRA-merged base model, followed by multi-dataset validation.
Typical responsibilities of Stage2/train2.py include:
- Merging the base model and a chosen LoRA adapter into a new base.
- Applying Prompt Tuning or Prefix Tuning via `peft` (e.g., `PromptTuningConfig`, `PrefixTuningConfig`).
- Preparing simple supervised datasets from generic QA-like JSON / JSONL files.
- Running validation on several datasets (e.g., `val_dataset.json`, `ChatMed.jsonl`, `CMtMedQA.jsonl`) using the same metric definitions as `Verification/evaluate_trained_lora.py`.
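As a rough sketch of the dataset-preparation step, a generic QA-style JSONL loader might look like this (the `load_qa_records` helper and its field names are illustrative, not `train2.py`'s actual code):

```python
import io
import json

def load_qa_records(stream):
    """Parse QA-style JSON lines into (question, reference answer) pairs.
    Illustrative helper; Stage2/train2.py has its own dataset loaders."""
    pairs = []
    for line in stream:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        rec = json.loads(line)
        pairs.append((rec.get("question", ""), rec.get("answer", "")))
    return pairs

sample = io.StringIO('{"question": "Q1", "answer": "A1"}\n{"question": "Q2", "answer": "A2"}\n')
print(load_qa_records(sample))  # → [('Q1', 'A1'), ('Q2', 'A2')]
```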
Usage (high-level):
```bash
cd Stage2
python train2.py --help  # inspect all available options
```

Stage 2 is more engineering-heavy and is not strictly required to reproduce the primary LoRA fine-tuning and evaluation results reported in the paper.
The script Verification/evaluate_trained_lora.py evaluates a trained LoRA adapter in a standalone way. It supports both JSON and JSONL data files and is compatible with the same schemas used for training.
Supported data schemas include:
- New QA schema:

  ```json
  {
    "instruction": "...",
    "knowledge": ["...", "..."],
    "question": "...",
    "answer": "..."
  }
  ```

  where `answer` is the reference answer (string).

- Legacy schema:

  ```json
  {
    "input": {"Instruction": "...", "Knowledge": ["..."], "describe": "..."},
    "output": {"answer": "...", "reason": "..."}
  }
  ```
The script normalizes all inputs into a JSON prompt with keys Instruction, Knowledge, and Question, and extracts a reference answer string for metric computation.
From the project root:
```bash
cd Verification
python evaluate_trained_lora.py \
  --data_file path/to/eval_dataset.jsonl \
  --base_model path/to/base_model_dir \
  --lora_path path/to/lora_adapter_dir \
  --output_dir path/to/verification_outputs \
  --max_samples 0
```

Main arguments:

- `--data_file`: evaluation dataset (JSON or JSONL).
- `--base_model`: base model directory (must be consistent with Stage 1, e.g., `../models/Qwen`).
- `--lora_path`: directory containing the trained LoRA adapter (e.g., one `lora_epoch-*` subfolder from the Stage 1 outputs).
- `--output_dir`: directory where evaluation JSON summaries are written.
- `--max_samples`: maximum number of samples to evaluate (`0` = all).
For each sample, the script:

- Prints the full JSON prompt sent to the model.
- Generates the model output using greedy decoding.
- Extracts the `answer` field (if present) from the output as the prediction.
- Extracts the reference answer from the sample.
- Computes per-sample metrics using `StandardEvaluator`:
  - BLEU-1/2/3/4, GLEU
  - ROUGE-1-F, ROUGE-2-F, ROUGE-L (Recall / F1)
  - QA F1, Distinct-1/2
  - BERTScore, BLEURT
- Accumulates running averages and finally writes a JSON summary file with metric means and an overall `all_metrics_avg`.
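The accumulation step can be sketched as follows (a hypothetical `update_running_avg` helper, not the script's exact implementation):

```python
def update_running_avg(avg, sample_metrics, n):
    """Fold one sample's metrics into running means over the previous n samples."""
    return {k: (avg.get(k, 0.0) * n + v) / (n + 1) for k, v in sample_metrics.items()}

avg, n = {}, 0
for m in [{"bleu_1": 0.25, "rouge_l_f": 0.5},
          {"bleu_1": 0.75, "rouge_l_f": 0.5}]:
    avg = update_running_avg(avg, m, n)
    n += 1
print(avg)  # → {'bleu_1': 0.5, 'rouge_l_f': 0.5}
```

Updating incrementally like this avoids holding every per-sample metric dict in memory during long evaluation runs.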
To reproduce the main training and evaluation pipeline from the paper:
1. Prepare models and data:
   - Place the base model under `models/Qwen/`.
   - Place the BERT and BLEURT models under `models/bert-base-chinese/` and `models/bleurt-base-128/`.
   - Place the preprocessed training and validation datasets under `dataBuilder/output/`.

2. Run Stage 1 (LoRA fine-tuning):

   ```bash
   cd Stage1
   python train.py
   ```

3. Select a LoRA checkpoint:
   - Choose one of the `lora_epoch-*` directories created under the Stage 1 output directory (see logs for the best epoch).

4. Run Verification:

   ```bash
   cd Verification
   python evaluate_trained_lora.py \
     --data_file path/to/eval_dataset.jsonl \
     --base_model ../models/Qwen \
     --lora_path path/to/chosen_lora_epoch_dir \
     --output_dir path/to/verification_outputs
   ```

5. (Optional) Run Stage 2:
   - Use `Stage2/train2.py` for Prompt / Prefix Tuning and additional validations if needed.
If you use this code or data in your research, please cite our paper:
```bibtex
@article{[your_key],
  title   = {[Project / Paper Title]},
  author  = {[Authors]},
  journal = {[Venue]},
  year    = {[Year]},
}
```

Replace the placeholder fields with your actual publication information before release.
This repository is released under the [Choose a License, e.g., MIT or Apache-2.0] license. See the LICENSE file for details.
If you have questions or suggestions, feel free to open an issue or contact:
- Email: 20251513002@sspu.edu.cn
- GitHub Issues: please provide your environment details and a minimal reproduction when reporting bugs.