This project implements a Chain of Thought (CoT) decoding method (Paper) for transformer models using PyTorch and Hugging Face's Transformers library. The CoT approach enhances the reasoning capabilities of models by allowing them to generate intermediate steps in their thought process.
- Load various transformer models and tokenizers from Hugging Face Hub.
- Calculate CoT scores to assess the quality of generated answers (see the sketch after this list).
- Support for both sequential and parallel decoding (sequential mode uses slightly less VRAM but is much slower).
- Optimized for both CPU and GPU (CUDA and MPS) environments.
- Configurable model selection, question input, and decoding parameters.
- Compatible with Llama and Phi models (Gemma2 is not working at the moment).
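
For context, the CoT score described in the linked paper is the average margin between the top-1 and top-2 token probabilities over the tokens of the final answer: branches whose answer tokens are decoded with a large margin are treated as more confident. Below is a minimal sketch of how such a score can be computed; the function name and signature are illustrative, not the actual API of this repository.

```python
import torch
import torch.nn.functional as F

def cot_score(answer_logits: torch.Tensor) -> float:
    """Average (p_top1 - p_top2) margin over the answer tokens.

    answer_logits: (num_answer_tokens, vocab_size) logits at the positions
    where the answer tokens were generated.
    """
    probs = F.softmax(answer_logits, dim=-1)   # per-position token distributions
    top2 = probs.topk(2, dim=-1).values        # (num_answer_tokens, 2)
    margins = top2[:, 0] - top2[:, 1]          # confidence margin per token
    return margins.mean().item()
```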
To set up the environment, you can create a new virtual environment and install the required packages using the provided `requirements.txt`.
```bash
# Create a new environment (optional)
conda create -n cot-decoding python=3.10
conda activate cot-decoding
```
Then, install the requirements:
```bash
pip install -r requirements.txt
```
Run the script with the following command:
```bash
python main.py --model_name <model_name> --question "<your_question>" --k <number_of_branches> --aggregation <max|sum> --device <cuda|cpu|mps>
```

For example:

```bash
python main.py --model_name "meta-llama/Llama-3.2-1B-Instruct" --question "Sally has two brothers, Sam and Joe. Sam has one sister. How many sisters does Joe have? Think step by step. You MUST end your reply with Answer:, FOLLOWED BY A SINGLE NUMBER." --k 10 --aggregation max --device cuda
```
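
The example prompt forces replies to end with "Answer:" followed by a single number, which makes it straightforward to pull the final answer out of each branch and aggregate scores across branches that agree. A rough sketch of that post-processing is shown below; the helper names are hypothetical and not the script's actual functions.

```python
import re
from collections import defaultdict

def extract_answer(text: str) -> str | None:
    """Return the number that follows the last 'Answer:' marker, if any."""
    matches = re.findall(r"Answer:\s*(-?\d+(?:\.\d+)?)", text)
    return matches[-1] if matches else None

def aggregate(branches, mode="max"):
    """branches: list of (decoded_text, cot_score) pairs.

    Groups branch scores by their extracted answer and returns the answer
    with the best aggregated score (max or sum, as in --aggregation).
    """
    scores = defaultdict(list)
    for text, score in branches:
        answer = extract_answer(text)
        if answer is not None:
            scores[answer].append(score)
    agg = {a: (max(s) if mode == "max" else sum(s)) for a, s in scores.items()}
    return max(agg, key=agg.get) if agg else None
```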
- `--model_name`: Model checkpoint name (default: `meta-llama/Llama-3.2-1B-Instruct`).
- `--question`: Question to ask the model (default: "Sally has two brothers...").
- `--k`: Number of decoding branches (default: 10).
- `--aggregation`: Method for aggregating CoT scores (`max` or `sum`, default: `max`).
- `--device`: Device to run the model on (`cuda`, `cpu`, or `mps`, default: `cuda`).
- `--use_sequential`: Use sequential processing for low RAM situations (optional).
- `--system_prompt`: Use a custom system prompt. If not given, defaults to the chat template included in the tokenizer (optional).
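
For orientation, these options roughly correspond to an argparse configuration like the following. This is a hypothetical reconstruction for illustration, not a copy of `main.py`.

```python
import argparse

parser = argparse.ArgumentParser(description="Chain of Thought (CoT) decoding")
parser.add_argument("--model_name", default="meta-llama/Llama-3.2-1B-Instruct",
                    help="Hugging Face model checkpoint to load")
parser.add_argument("--question", default="Sally has two brothers...",  # truncated here, see above
                    help="Question to ask the model")
parser.add_argument("--k", type=int, default=10,
                    help="Number of decoding branches to explore")
parser.add_argument("--aggregation", choices=["max", "sum"], default="max",
                    help="How to aggregate CoT scores across branches")
parser.add_argument("--device", choices=["cuda", "cpu", "mps"], default="cuda",
                    help="Device to run the model on")
parser.add_argument("--use_sequential", action="store_true",
                    help="Process branches sequentially to reduce peak memory usage")
parser.add_argument("--system_prompt", default=None,
                    help="Optional custom system prompt; falls back to the tokenizer's chat template")
args = parser.parse_args()
```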
Contributions are welcome! Please feel free to submit a pull request or open an issue.
This project is licensed under the Apache 2.0 License - see the LICENSE file for details.