Reasoning SFT → GRPO training for Gemma 3 family with built-in benchmarks (GSM8K, MBPP, HumanEval) and optional quantum reward shaping.
Recommended order:
- Create and activate a Python 3.10–3.12 virtualenv.
- Install a matching PyTorch build for your CUDA/CPU:

  ```bash
  # Example: CUDA 12.1
  pip install --index-url https://download.pytorch.org/whl/cu121 torch torchvision torchaudio

  # Or CPU-only
  pip install --index-url https://download.pytorch.org/whl/cpu torch torchvision torchaudio
  ```
- Install the rest of the requirements:

  ```bash
  pip install -r requirements.txt
  ```
Optional (GPU only): FlashAttention. It is not required to run the training script and is intentionally not pinned in `requirements.txt`, because the right build depends on your CUDA and compiler toolchain. If you want it:

```bash
# Try prebuilt wheels first
pip install flash-attn --no-build-isolation

# If that fails, consult the FlashAttention project docs for a source build
# compatible with your CUDA + PyTorch versions.
```
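A quick sanity check that the installed PyTorch build matches your hardware (optional; assumes the steps above completed):

```python
import torch

# Prints the installed torch version and whether a CUDA device is visible.
# False is expected with the CPU-only build.
print(torch.__version__)
print(torch.cuda.is_available())
```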
- Math SFT: GSM8K + Hendrycks MATH (selected configs). Always included.
- Code SFT (optional):
  - MBPP: the loader tries multiple dataset variants; included by default when code solutions are available.
  - HumanEval: optional; includes canonical solutions if present.
  - Important: using HumanEval for training contaminates the benchmark. If you plan to report HumanEval, keep it disabled during SFT.
- RL (GRPO):
  - Math: GSM8K + MATH transformed into prompts with numeric exact-match rewards.
  - Code: MBPP and HumanEval tasks with unit-test rewards run in an on-device Python sandbox. Both are enabled by default; disable them via the flags below. (Both reward shapes are sketched after this list.)
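Roughly, the two reward shapes look like the sketch below. The function names and details are illustrative assumptions, not the script's actual implementation: numeric exact match on the final number for math, pass/fail unit tests for code.

```python
import os
import re
import subprocess
import sys
import tempfile


def exact_match_reward(completion: str, target: str) -> float:
    """Illustrative math reward: 1.0 if the last number in the completion
    equals the reference answer, else 0.0."""
    numbers = re.findall(r"-?\d+(?:\.\d+)?", completion.replace(",", ""))
    try:
        return 1.0 if numbers and float(numbers[-1]) == float(target) else 0.0
    except ValueError:
        return 0.0


def unit_test_reward(code: str, tests: str, timeout: float = 10.0) -> float:
    """Illustrative code reward: 1.0 if the generated code plus the task's
    test asserts run cleanly in a subprocess, 0.0 on any failure or timeout.
    NOTE: a plain subprocess is not a security boundary."""
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(code + "\n\n" + tests)
        path = f.name
    try:
        proc = subprocess.run(
            [sys.executable, path], capture_output=True, timeout=timeout
        )
        return 1.0 if proc.returncode == 0 else 0.0
    except subprocess.TimeoutExpired:
        return 0.0
    finally:
        os.unlink(path)
```

In GRPO, per-completion rewards like these are typically normalized within each group of samples for the same prompt (reward minus group mean, divided by group standard deviation) to form advantages.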
`train_quantum_reasoner.py`: end-to-end SFT and GRPO with dataset toggles and fast batched eval.
Train with math only for SFT; include MBPP and HumanEval in RL:
```bash
python train_quantum_reasoner.py \
  --base_model google/gemma-3-1b-it \
  --output_dir ./runs/gemma-3-1b-qrl \
  --sft_steps 300 --rl_steps 300 \
  --no_sft_humaneval --sft_include_mbpp \
  --rl_include_mbpp --rl_include_humaneval
```
Avoid HumanEval entirely (no contamination in either phase):
```bash
python train_quantum_reasoner.py \
  --no_sft_humaneval --no_rl_humaneval
```
Limit dataset sizes to speed up experiments:
```bash
python train_quantum_reasoner.py \
  --sft_mbpp_limit 100 \
  --rl_mbpp_limit 200 --rl_humaneval_limit 30
```
- HumanEval in SFT is disabled by default; enable it with `--sft_include_humaneval` only if you do not intend to report HumanEval scores as an unbiased benchmark.
- The Python sandbox used for code rewards is a plain local subprocess and is not a security boundary. Use it only in controlled environments.