GRAG4CM

This repository contains the core code for the paper "GRAG4CM: A RAG-Aware Generative Model with Chain-of-Thought Reasoning for Chinese Medicine". It implements:

  • 🚀 Stage 1 (Stage1/): LoRA fine-tuning of a local base model (default: Qwen) on structured TCM QA / judgment data.
  • 🎯 Stage 2 (Stage2/): Optional Prompt / Prefix Tuning on top of the LoRA-merged base model, plus unified validation.
  • 📊 Verification (Verification/): Standalone evaluation of a trained LoRA model with multiple QA-style metrics.

📁 1. Repository Layout

.
├── Stage1/                      # Stage 1: LoRA fine-tuning (Qwen)
│   ├── config.py                # Training configuration (model, data paths, LoRA and training hyper-params)
│   ├── data_processor.py        # Data loading & normalization into JSON-style prompts and targets
│   ├── standard_evaluator.py    # Unified evaluator (ROUGE / BLEU / GLEU / Distinct / BERTScore / BLEURT, etc.)
│   ├── trainer.py               # LoRA training logic based on HF Trainer + custom evaluation
│   └── train.py                 # Entry point for Qwen LoRA fine-tuning
│
├── Stage2/                      # Stage 2: Prompt / Prefix Tuning (optional / advanced)
│   └── train2.py                # P-Tuning / Prefix-Tuning + multi-dataset validation
│
├── Verification/                # Standalone evaluation of a trained LoRA adapter
│   └── evaluate_trained_lora.py # Per-sample generation and metric accumulation
│
├── GTCMQA.json                  # Example or related dataset (used in the paper)
└── README.md

Stage 1 (Stage1/) and Verification (Verification/) together are sufficient to reproduce the main training and evaluation pipeline. Stage 2 (Stage2/) provides an additional P/Prefix-Tuning phase and is optional.


⚙️ 2. Environment & Dependencies

  • Python 3.10+ is recommended.
  • A GPU with sufficient memory (e.g., 24GB) is strongly recommended for training.
  • Core Python libraries:
    • torch (CUDA build)
    • transformers
    • peft
    • datasets
    • numpy

Example installation (adapt to your environment):

pip install torch transformers peft datasets numpy

BERTScore and BLEURT in this project are implemented using local models loaded via transformers. No extra BLEURT toolkit is required.
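
For reference, a greatly simplified sketch of how a BERTScore-style F1 can be computed from a local BERT via greedy token matching (the actual logic lives in Stage1/standard_evaluator.py and may differ, e.g., in token filtering or IDF weighting):

import torch
from transformers import AutoTokenizer, AutoModel

tok = AutoTokenizer.from_pretrained("models/bert-base-chinese")
bert = AutoModel.from_pretrained("models/bert-base-chinese").eval()

@torch.no_grad()
def bert_score_f1(candidate: str, reference: str) -> float:
    def embed(text):
        enc = tok(text, return_tensors="pt", truncation=True, max_length=512)
        hidden = bert(**enc).last_hidden_state[0]       # (seq_len, hidden_dim)
        return torch.nn.functional.normalize(hidden, dim=-1)

    c, r = embed(candidate), embed(reference)
    sim = c @ r.T                               # pairwise cosine similarities
    precision = sim.max(dim=1).values.mean()    # best reference match per candidate token
    recall = sim.max(dim=0).values.mean()       # best candidate match per reference token
    return (2 * precision * recall / (precision + recall)).item()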


📂 3. Models & Data Preparation

🧠 3.1 Base Model (Qwen)

Prepare a local base model in models/Qwen/. Any Qwen-style causal LM compatible with transformers can be used.

Example (download from HuggingFace and save locally):

from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "Qwen/xxx"  # TODO: replace with the exact model you use

tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

save_dir = "models/Qwen"
tok.save_pretrained(save_dir)
model.save_pretrained(save_dir)

🧪 3.2 Evaluation Models: BERT & BLEURT

Stage1/standard_evaluator.py and Verification/evaluate_trained_lora.py expect the following local models:

models/
├── bert-base-chinese/       # For BERTScore
└── bleurt-base-128/         # For BLEURT (regression scores, normalized to [0, 1])

You can download them from HuggingFace and save them in the above directories, e.g.:

from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification

# BERT-Base Chinese
bert_id = "bert-base-chinese"
bert_tok = AutoTokenizer.from_pretrained(bert_id)
bert_model = AutoModel.from_pretrained(bert_id)
bert_tok.save_pretrained("models/bert-base-chinese")
bert_model.save_pretrained("models/bert-base-chinese")

# BLEURT-base-128 (replace with the exact BLEURT model you use)
bleurt_id = "[bleurt-model-id]"  # e.g., "lucadiliello/BLEURT-20" or your local export
bleurt_tok = AutoTokenizer.from_pretrained(bleurt_id)
bleurt_model = AutoModelForSequenceClassification.from_pretrained(bleurt_id)
bleurt_tok.save_pretrained("models/bleurt-base-128")
bleurt_model.save_pretrained("models/bleurt-base-128")
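
A rough sketch of how the saved BLEURT model can then be queried per (reference, candidate) pair, reusing bleurt_tok and bleurt_model from the snippet above; the clamp into [0, 1] is an assumption here, and the actual normalization in standard_evaluator.py may differ:

import torch

@torch.no_grad()
def bleurt_score(candidate: str, reference: str) -> float:
    enc = bleurt_tok(reference, candidate, return_tensors="pt",
                     truncation=True, max_length=128)
    raw = bleurt_model(**enc).logits.squeeze().item()   # single regression score
    return min(max(raw, 0.0), 1.0)                      # clamp into [0, 1] (assumed)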

📄 3.3 Training & Validation Data

Stage1/config.py (class NewTrainConfig) defines the default data paths:

dataBuilder/
└── output/
    ├── train_dataset.json   # Training set
    └── val_dataset.json     # Validation set

The raw data formats are flexible. Stage1/data_processor.py supports:

  1. Nested QA schema

    {
      "instruction": "...",
      "input": {"question": "...", "knowledge": ["..."]},
      "output": {"thinking": "...", "answer": "..."}
    }
  2. Flat QA schema

    {
      "instruction": "...",
      "question": "...",
      "knowledge": ["..."],
      "answer": "..."
    }
  3. Legacy judgment schema

    {
      "input": {"Instruction": "...", "Knowledge": ["..."], "describe": "..."},
      "output": {"answer": "Right" or "Error", "reason": "..."}
    }

These are normalized into a JSON-style prompt string of the form:

{
  "Instruction": "...",
  "Knowledge": ["..."],
  "Question": "..."
}

and a corresponding target output (either a JSON object or plain answer text, depending on output_mode).
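
As an illustration, a simplified sketch of this normalization (the real logic in Stage1/data_processor.py handles more variants and edge cases):

import json

def build_prompt(sample: dict) -> str:
    # Normalize any of the three supported schemas into the unified prompt string.
    if isinstance(sample.get("input"), dict):   # nested QA or legacy judgment schema
        inner = sample["input"]
        instruction = sample.get("instruction") or inner.get("Instruction", "")
        knowledge = inner.get("knowledge") or inner.get("Knowledge", [])
        question = inner.get("question") or inner.get("describe", "")
    else:                                       # flat QA schema
        instruction = sample.get("instruction", "")
        knowledge = sample.get("knowledge", [])
        question = sample.get("question", "")
    prompt = {"Instruction": instruction, "Knowledge": knowledge, "Question": question}
    return json.dumps(prompt, ensure_ascii=False, indent=2)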

When you adapt this project, please ensure your preprocessed datasets are placed at the paths expected by NewTrainConfig, or change the paths in Stage1/config.py.


🚀 4. Stage 1: LoRA Fine-tuning (Stage1/)

Stage 1 performs LoRA fine-tuning of the base model on the structured TCM QA / judgment data.

4.1 Configuration

All key hyper-parameters are defined in Stage1/config.py (NewTrainConfig):

  • model_name_or_path: base model directory (default: ../models/Qwen)
  • train_data_path, val_data_path: training / validation data
  • max_length: maximum sequence length
  • output_mode: "full" (full structured output) or "answer" (answer-only)
  • LoRA: lora_r, lora_alpha, lora_dropout, lora_target_modules
  • Training: num_train_epochs, per_device_train_batch_size, gradient_accumulation_steps, learning_rate, weight_decay, warmup_ratio, etc.
  • System: bf16, gradient_checkpointing, logging configuration, random seed.

Edit config.py to match your hardware and dataset size.
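
For orientation, the shape of NewTrainConfig is roughly the following; the field names are taken from the list above, but the default values here are illustrative, not the ones shipped in config.py:

from dataclasses import dataclass, field

@dataclass
class NewTrainConfig:
    # Model & data (paths relative to Stage1/)
    model_name_or_path: str = "../models/Qwen"
    train_data_path: str = "../dataBuilder/output/train_dataset.json"
    val_data_path: str = "../dataBuilder/output/val_dataset.json"
    max_length: int = 1024                 # illustrative
    output_mode: str = "full"              # "full" or "answer"
    # LoRA
    lora_r: int = 8                        # illustrative
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_target_modules: list = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"])
    # Training
    num_train_epochs: int = 3
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 8
    learning_rate: float = 2e-4
    weight_decay: float = 0.01
    warmup_ratio: float = 0.03
    # System
    bf16: bool = True
    gradient_checkpointing: bool = True
    seed: int = 42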

4.2 Start Training

From the project root:

cd Stage1
python train.py

Stage1/train.py will:

  1. Check GPU availability and basic environment conditions.
  2. Instantiate NewTrainConfig from config.py.
  3. Load the base model and tokenizer (Qwen by default), disable the chat template, and attach LoRA modules (sketched below).
  4. Use NewJudgmentDataProcessor (data_processor.py) to:
    • load raw data;
    • normalize it into JSON-style prompts and outputs;
    • tokenize and create HF Dataset objects with proper input_ids, attention_mask, and labels.
  5. Use NewLoRATrainer (trainer.py) to run training with a custom Trainer subclass, including
    • small-scale validation based on the unified metrics (ROUGE / BLEU / GLEU / Distinct / BERTScore / BLEURT);
    • saving LoRA weights at each epoch and tracking the best epoch according to an aggregated metric.

The LoRA checkpoints and logs are written under the output_dir specified in NewTrainConfig (by default under ../ResModels/).
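
For reference, step 3 above amounts to something like the following peft call (a sketch, not a verbatim excerpt of train.py):

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

cfg = NewTrainConfig()
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(cfg.model_name_or_path, trust_remote_code=True)

lora_cfg = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=cfg.lora_r,
    lora_alpha=cfg.lora_alpha,
    lora_dropout=cfg.lora_dropout,
    target_modules=cfg.lora_target_modules,
)
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()   # only the LoRA parameters remain trainable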


🎯 5. Stage 2: Prompt / Prefix Tuning (Stage2/, Optional)

Stage2/train2.py provides an optional second stage where Prompt Tuning or Prefix Tuning is applied on top of a LoRA-merged base model, followed by multi-dataset validation.

Typical responsibilities of Stage2/train2.py include:

  • Merging the base model and a chosen LoRA adapter into a new base.
  • Applying Prompt Tuning or Prefix Tuning via peft (e.g., PromptTuningConfig, PrefixTuningConfig); see the sketch after this list.
  • Preparing simple supervised datasets from generic QA-like JSON / JSONL files.
  • Running validation on several datasets (e.g., val_dataset.json, ChatMed.jsonl, CMtMedQA.jsonl) using the same metric definitions as Verification/evaluate_trained_lora.py.
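
At its core, the merge-then-tune flow looks roughly like this (a sketch with placeholder paths and an assumed num_virtual_tokens; see train2.py for the real options):

from transformers import AutoModelForCausalLM
from peft import PeftModel, PrefixTuningConfig, TaskType, get_peft_model

# Merge the Stage 1 LoRA adapter into the base weights (paths are placeholders).
base = AutoModelForCausalLM.from_pretrained("../models/Qwen", trust_remote_code=True)
merged = PeftModel.from_pretrained(base, "path/to/lora_epoch_dir").merge_and_unload()

# Attach a fresh Prefix Tuning adapter on top of the merged model.
prefix_cfg = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=20)
model = get_peft_model(merged, prefix_cfg)
model.print_trainable_parameters()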

Usage (high-level):

cd Stage2
python train2.py --help   # inspect all available options

Stage 2 is more engineering-heavy and is not strictly required to reproduce the primary LoRA fine-tuning and evaluation results reported in the paper.


📊 6. Verification: Evaluating a Trained LoRA Adapter (Verification/)

The script Verification/evaluate_trained_lora.py evaluates a trained LoRA adapter in a standalone way. It supports both JSON and JSONL data files and is compatible with the same schemas used for training.

📑 6.1 Data Schemas

Supported evaluation items include:

  • New QA schema

    {
      "instruction": "...",
      "knowledge": ["...", "..."],
      "question": "...",
      "answer": "..."    // reference answer (string)
    }
  • Legacy schema

    {
      "input": {"Instruction": "...", "Knowledge": ["..."], "describe": "..."},
      "output": {"answer": "...", "reason": "..."}
    }

The script normalizes all inputs into a JSON prompt with keys Instruction, Knowledge, and Question, and extracts a reference answer string for metric computation.

🧪 6.2 Usage Example

From the project root:

cd Verification

python evaluate_trained_lora.py \
  --data_file path/to/eval_dataset.jsonl \
  --base_model path/to/base_model_dir \
  --lora_path path/to/lora_adapter_dir \
  --output_dir path/to/verification_outputs \
  --max_samples 0

Main arguments:

  • --data_file: evaluation dataset (JSON or JSONL).
  • --base_model: base model directory (must be consistent with Stage 1, e.g., ../models/Qwen).
  • --lora_path: directory containing the trained LoRA adapter (e.g., one lora_epoch-* subfolder from Stage1 outputs).
  • --output_dir: directory where evaluation JSON summaries are written.
  • --max_samples: maximum number of samples to evaluate (0 = all).

For each sample, the script:

  1. Prints the full JSON prompt sent to the model.
  2. Generates the model output using greedy decoding (see the sketch after this list).
  3. Extracts the answer field (if present) from the output as the prediction.
  4. Extracts the reference answer from the sample.
  5. Computes per-sample metrics using StandardEvaluator:
    • BLEU-1/2/3/4, GLEU
    • ROUGE-1-F, ROUGE-2-F, ROUGE-L (Recall / F1)
    • QA F1, Distinct-1/2
    • BERTScore, BLEURT
  6. Accumulates running averages and finally writes a JSON summary file with metric means and an overall all_metrics_avg.
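
Steps 2-3 can be pictured with the following sketch (generate_prediction is a hypothetical helper; the script's actual answer extraction is more robust):

import json
import torch

@torch.no_grad()
def generate_prediction(model, tokenizer, prompt: str, max_new_tokens: int = 512) -> str:
    # Greedy decoding (do_sample=False), as in step 2 above.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    text = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:],
                            skip_special_tokens=True)
    # Prefer the "answer" field when the model emits a JSON object (step 3).
    try:
        return json.loads(text).get("answer", text)
    except (json.JSONDecodeError, AttributeError):
        return text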

🔁 7. Reproducing the Main Results

To reproduce the main training and evaluation pipeline from the paper:

  1. Prepare models and data

    • Place the base model under models/Qwen/.
    • Place BERT and BLEURT models under models/bert-base-chinese/ and models/bleurt-base-128/.
    • Place the preprocessed training and validation datasets under dataBuilder/output/.
  2. Run Stage 1 (LoRA fine-tuning)

    cd Stage1
    python train.py
  3. Select a LoRA checkpoint

    • Choose one of the lora_epoch-* directories created under the Stage 1 output directory (see logs for the best epoch).
  4. Run Verification

    cd Verification
    python evaluate_trained_lora.py \
      --data_file path/to/eval_dataset.jsonl \
      --base_model ../models/Qwen \
      --lora_path path/to/chosen_lora_epoch_dir \
      --output_dir path/to/verification_outputs
  5. (Optional) Run Stage 2

    • Use Stage2/train2.py for Prompt / Prefix Tuning and additional validations if needed.

📝 8. Citation

If you use this code or data in your research, please cite our paper:

@article{[your_key],
  title   = {[Project / Paper Title]},
  author  = {[Authors]},
  journal = {[Venue]},
  year    = {[Year]},
}

Replace the fields with your actual publication information before release.


⚖️ 9. License

This repository is released under the [Choose a License, e.g., MIT or Apache-2.0] license. See the LICENSE file for details.


✉️ 10. Contact

If you have questions or suggestions, feel free to open an issue or contact:

  • Email: 20251513002@sspu.edu.cn
  • GitHub Issues: please provide your environment details and a minimal reproduction when reporting bugs.
