
Fix: Paroquant impl accuracy #2601

Merged
Qubitium merged 5 commits into main from paroquant-sync on Mar 24, 2026

Conversation

@Qubitium
Collaborator

  GPU: NVIDIA GeForce RTX 4090 (PCI-ordered)
  +----------------+-----------------------+-----------------------+------------------+-------------------+-------------------+------------------+
  | Shape          | Ours(gptqmodel) s/loss| Ours(reference) s/loss| Official s/loss  | Off/OursG speed   | Off/OursR speed   | Off/OursG loss   |
  +----------------+-----------------------+-----------------------+------------------+-------------------+-------------------+------------------+
  | 192x2048->512  | 2.817 / 0.005365      | 7.120 / 0.005243      | 5.655 / 0.004038 | 2.007 (official slower) | 0.794 (official faster) | 0.753 |
  | 192x2048->2048 | 2.534 / 0.005494      | 7.084 / 0.005447      | 5.666 / 0.004086 | 2.236 (official slower) | 0.800 (official faster) | 0.744 |
  | 192x8192->2048 |10.132 / 0.022573      |28.321 / 0.022481      |22.080 / 0.016616 | 2.179 (official slower) | 0.780 (official faster) | 0.736 |
  +----------------+-----------------------+-----------------------+------------------+-------------------+-------------------+------------------+
  | AVG            | 5.161 / 0.011144      |14.175 / 0.011057      |11.134 / 0.008246 | 2.141             | 0.791             | 0.744            |
  +----------------+-----------------------+-----------------------+------------------+-------------------+-------------------+------------------+

  GPU: NVIDIA PG506-230 (PCI-ordered, non-A100)
  +----------------+-----------------------+-----------------------+------------------+-------------------+-------------------+------------------+
  | Shape          | Ours(gptqmodel) s/loss| Ours(reference) s/loss| Official s/loss  | Off/OursG speed   | Off/OursR speed   | Off/OursG loss   |
  +----------------+-----------------------+-----------------------+------------------+-------------------+-------------------+------------------+
  | 192x2048->512  | 3.023 / 0.005317      | 7.403 / 0.005368      | 5.646 / 0.004039 | 1.868 (official slower) | 0.763 (official faster) | 0.760 |
  | 192x2048->2048 | 2.646 / 0.005515      | 7.247 / 0.005540      | 5.566 / 0.004099 | 2.103 (official slower) | 0.768 (official faster) | 0.743 |
  | 192x8192->2048 |10.326 / 0.022637      |28.758 / 0.022575      |21.803 / 0.016303 | 2.111 (official slower) | 0.758 (official faster) | 0.720 |
  +----------------+-----------------------+-----------------------+------------------+-------------------+-------------------+------------------+
  | AVG            | 5.332 / 0.011157      |14.470 / 0.011161      |11.005 / 0.008147 | 2.028             | 0.763             | 0.741            |
  +----------------+-----------------------+-----------------------+------------------+-------------------+-------------------+------------------+

@Qubitium
Collaborator Author

We still have accuracy drift at the module quantization level vs reference. There appears to be a porting bug: we treat sym=True as a no-zero-point optimization, but the paroquant optimization stage can actually still optimize the zero point even in symmetric mode.
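A toy sketch of why pinning the zero point hurts (pure Python, not the actual ParoQuant optimizer; `dequant`, the 4-bit range, and the brute-force candidate sweep are all illustrative assumptions): searching over the zero point can never do worse than fixing it mid-range, and on a slightly asymmetric weight distribution it usually does strictly better.

```python
import random
import statistics

def dequant(w, scale, zp, qmax=15):
    # 4-bit affine quantize-dequantize: q = clip(round(w/scale) + zp, 0, qmax)
    q = min(max(round(w / scale) + zp, 0), qmax)
    return (q - zp) * scale

def mse(ws, scale, zp):
    return statistics.fmean((w - dequant(w, scale, zp)) ** 2 for w in ws)

rng = random.Random(0)
ws = [rng.gauss(0.05, 1.0) for _ in range(4096)]  # slightly asymmetric weights
scale = (max(ws) - min(ws)) / 15

err_fixed = mse(ws, scale, 8)  # "sym" shortcut: zero point pinned mid-range
err_best = min(mse(ws, scale, zp) for zp in range(16))  # zero point optimized too
assert err_best <= err_fixed
```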

@Qubitium
Collaborator Author

Qubitium commented Mar 24, 2026

@liang2kl @zhijian-liu @HsChen-sys @SubSir Hi guys! I think some of you already know I have been working on a fast but accurate port/implementation of ParoQuant in GPT-QModel, so paro can benefit from the lifecycle tooling and the expanded set of models we already support and will support.

I have benchmarked ParoQuant even with the current not-so-great optimizer (vs reference) and the results are very good to excellent. It may get even better as this code gets closer to parity with reference.

Right now I am hunting down per-module quantization accuracy drift between the current code and z-lab's reference implementation.

Once ready I would love feedback from your team. Also, since GPT-QModel integration with kernel selection is already used by HF Transformers for both gptq and awq, we could have GPT-QModel handle the loading and module-level (kernel) loading for paroquant as well (the config is very similar to gptq/awq), and we already have bindings for that inside Transformers.

Let me know your thoughts.

@Qubitium
Collaborator Author

Qubitium commented Mar 24, 2026

reference impl accuracy drift vs upstream is fixed.

Current status:

fast is ~2x faster than official but has ~25% higher error loss.
reference is within 1% of official loss but ~25% slower than official.

  +--------------------+---------------------------+---------------------------------------------+---------------------------------------------------+----------------------+
  | Attribute          | What it controls          | fast                                        | reference                                         | Closest to upstream  |
  +--------------------+---------------------------+---------------------------------------------+---------------------------------------------------+----------------------+
  | opt_stage_impl     | optimizer/train loop      | our simpler AdamW loop in fp32              | AMP + GradScaler + official-style cosine schedule | reference            |
  | opt_pair_impl      | rotation pair schedule    | our direct packed-buffer builder            | official-style pair selection + kernel padding    | reference            |
  | opt_quantizer_impl | quantizer math + export   | symmetric quantizer                         | affine quantizer with learned zero point          | reference            |
  +--------------------+---------------------------+---------------------------------------------+---------------------------------------------------+----------------------+

@Qubitium Qubitium changed the title Paroquant sync Fix: Paroquant impl accuracy Mar 24, 2026
@Qubitium
Collaborator Author

Qubitium commented Mar 24, 2026

I did a sweep of all the stage_impl, pair_impl, and quant_impl options (fast, reference) on an A100 and here are the results. Summary: quant_impl using reference is crucial for accuracy; pair_impl fast is a huge speedup vs reference.

rfr and ffr are the best. quant_impl should always use reference; pair_impl can use the much faster fast impl without lowering accuracy.

  AVG ranking across all 3 shapes
  +-------+-----------+-----------+-----------+----------+----------+--------+--------+--------+
  | Combo | Stage     | Pair      | Quant     | Avg s    | Avg loss | Speed# | Loss#  | Pareto |
  +-------+-----------+-----------+-----------+----------+----------+--------+--------+--------+
  | fff   | fast      | fast      | fast      | 5.247    | 0.011157 | 4      | 5      | no     |
  | ffr   | fast      | fast      | reference | 5.123    | 0.008257 | 2      | 3      | yes    | <--
  | frf   | fast      | reference | fast      | 14.282   | 0.011157 | 6      | 6      | no     |
  | frr   | fast      | reference | reference | 14.445   | 0.008257 | 8      | 4      | no     |
  | rff   | reference | fast      | fast      | 5.112    | 0.011161 | 1      | 7      | yes    |
  | rfr   | reference | fast      | reference | 5.127    | 0.008124 | 3      | 1      | yes    | <--
  | rrf   | reference | reference | fast      | 14.334   | 0.011161 | 7      | 8      | no     |
  | rrr   | reference | reference | reference | 14.182   | 0.008124 | 5      | 2      | no     |
  +-------+-----------+-----------+-----------+----------+----------+--------+--------+--------+

  Per-shape time/loss
  +-------+--------------------+--------------------+--------------------+
  | Combo | 2048->512          | 2048->2048         | 8192->2048         |
  +-------+--------------------+--------------------+--------------------+
  | fff   | 2.656 / 0.005317   | 2.658 / 0.005515   | 10.426 / 0.022637  |
  | ffr   | 2.611 / 0.004051   | 2.618 / 0.004066   | 10.141 / 0.016656  |
  | frf   | 7.172 / 0.005317   | 7.082 / 0.005515   | 28.592 / 0.022637  |
  | frr   | 7.250 / 0.004051   | 7.567 / 0.004066   | 28.517 / 0.016656  |
  | rff   | 2.613 / 0.005368   | 2.607 / 0.005540   | 10.115 / 0.022575  |
  | rfr   | 2.653 / 0.004018   | 2.652 / 0.004047   | 10.075 / 0.016308  |
  | rrf   | 7.215 / 0.005368   | 7.130 / 0.005540   | 28.656 / 0.022575  |
  | rrr   | 7.258 / 0.004018   | 7.095 / 0.004047   | 28.192 / 0.016308  |
  +-------+--------------------+--------------------+--------------------+
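The Pareto column above can be reproduced mechanically from the Avg s / Avg loss pairs: a combo is Pareto-optimal when no other combo is at least as fast and at least as accurate while strictly better in one dimension. A plain-Python check, with numbers copied from the AVG ranking table:

```python
def pareto_front(results):
    """Return combos not dominated by any other combo on (time, loss)."""
    front = set()
    for name, (t, loss) in results.items():
        dominated = any(
            t2 <= t and l2 <= loss and (t2 < t or l2 < loss)
            for other, (t2, l2) in results.items()
            if other != name
        )
        if not dominated:
            front.add(name)
    return front

# Avg s / Avg loss from the sweep table above
sweep = {
    "fff": (5.247, 0.011157), "ffr": (5.123, 0.008257),
    "frf": (14.282, 0.011157), "frr": (14.445, 0.008257),
    "rff": (5.112, 0.011161), "rfr": (5.127, 0.008124),
    "rrf": (14.334, 0.011161), "rrr": (14.182, 0.008124),
}
print(sorted(pareto_front(sweep)))  # ['ffr', 'rff', 'rfr']
```

This matches the three "yes" rows in the table; rrr is dominated because rfr reaches the same 0.008124 loss in roughly a third of the time.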

@Qubitium
Collaborator Author

Qubitium commented Mar 24, 2026

Re-ran the sweep with double the iterations. Optimizer schedule unchanged at 1+1.

  +-------+-----------+-----------+-----------+-----------+-----------+------------+------------+--------+--------+--------+
  | Combo | Stage     | Pair      | Quant     | Avg s     | Std s     | Avg loss   | Std loss   | Speed# | Loss#  | Pareto |
  +-------+-----------+-----------+-----------+-----------+-----------+------------+------------+--------+--------+--------+
  | fff   | fast      | fast      | fast      | 5.229     | 0.012     | 0.011157   | 0.000000   | 3      | 5      | no     |
  | ffr   | fast      | fast      | reference | 5.278     | 0.010     | 0.008257   | 0.000000   | 4      | 3      | no     |
  | frf   | fast      | reference | fast      | 14.234    | 0.050     | 0.011157   | 0.000000   | 6      | 6      | no     |
  | frr   | fast      | reference | reference | 14.303    | 0.075     | 0.008257   | 0.000000   | 7      | 4      | no     |
  | rff   | reference | fast      | fast      | 5.207     | 0.019     | 0.011161   | 0.000000   | 2      | 7      | no     |
  | rfr   | reference | fast      | reference | 5.170     | 0.008     | 0.008124   | 0.000000   | 1      | 1      | yes    |  <--
  | rrf   | reference | reference | fast      | 14.087    | 0.070     | 0.011161   | 0.000000   | 5      | 8      | no     |
  | rrr   | reference | reference | reference | 14.459    | 0.180     | 0.008124   | 0.000000   | 8      | 2      | no     |
  +-------+-----------+-----------+-----------+-----------+-----------+------------+------------+--------+--------+--------+

rfr is now the clear winner in this simulated single module quantization test.

@Qubitium Qubitium merged commit b8595e8 into main Mar 24, 2026
6 checks passed
@Qubitium Qubitium deleted the paroquant-sync branch March 24, 2026 17:38
@zhijian-liu

Happy to help in any way we can! Thanks for bringing in the support!

@liang2kl
Collaborator

liang2kl commented Mar 25, 2026

@Qubitium Thanks for the excellent support! The implementation looks very solid.

I'd like to discuss a few currently WIP changes in our codebase:

Feel free to test these out and see if they are actually better!

@Qubitium
Collaborator Author

Qubitium commented Mar 25, 2026

@zhijian-liu @liang2kl I just cleaned up some code around paroquant and its tests, ran a fast/simple test, and everything seems to work. The next step is full validation (kernel and quantization) of accuracy vs the official implementation from z-lab. Still need to wire up more unit tests to guard against regressions and verify accuracy.

To assist with model testing, I have invited you two to ModelCloud/Evalution (currently a private repo, but it will be public soon), which runs the llm benchmarks we all know and love. It's a reimplementation of how I think an llm benchmark should be structured. It currently uses Transformers/GPTQModel for inference with paged attention + fa2 + continuous batching, so the tests are super fast. vllm/sglang and other engines are pending dev/testing.

Install evalution (main) + gptqmodel (main) + flash-attn + latest transformers (5.3.0 I think) and run the following full quant test + eval (gsm8k, arc, mmlu.stem) micro test. The test only quantizes 2 layers of llama 3.2 1b instruct to save time, and the eval datasets are row-limited with max_rows to reduce time as well. This is my CI test, so I need it to run as fast as possible, but you guys can change/tweak it to run as a full test.

CUDA_VISIBLE_DEVICES=0 pytest -v tests/models/test_llama3_2_paroquant.py

If you crash in the benchmark, run `pip uninstall kernels` (upstream bug, fix pending).

I recommend a 3.13t or 3.14t (nogil) env run with PYTHON_GIL=0, as gptqmodel is designed to be thread-safe and there are lots of speedups we can unlock with true threading, both in inference and quant.
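A quick sanity check that the interpreter really is free-threaded with the GIL off (the `Py_GIL_DISABLED` config var identifies a 3.13t/3.14t build; `sys._is_gil_enabled` reports the runtime state and only exists on 3.13+, hence the `getattr` guard so this sketch also runs safely on regular builds):

```python
import sys
import sysconfig

def gil_disabled() -> bool:
    """True only on a free-threaded (nogil) build with the GIL actually off."""
    if not sysconfig.get_config_var("Py_GIL_DISABLED"):
        return False  # regular CPython build: the GIL is always on
    # Runtime state depends on PYTHON_GIL / -X gil; default to "enabled" if absent
    return not getattr(sys, "_is_gil_enabled", lambda: True)()

print("free-threaded with GIL off:", gil_disabled())
```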

GPTQModel's tree-based model definition structure allows fast model prototyping (adding new models), so z-lab can concentrate on paroquant correctness/efficiency and worry less about how to wire it up to new models.

GPTQModel also supports multi-gpu accelerated processing for subsets like qkv, gate/up, etc., and with paroquant being much more compute-intensive, the (auto) multi-gpu acceleration will be a huge boost. I have yet to test this multi-gpu (data-parallel) acceleration with paroquant but will do so later. Multi-gpu acceleration requires a 3.13t or 3.14t (nogil) build of Python launched with PYTHON_GIL=0.

INFO  Evaluation comparison:
| Metric                                       |   PAROQUANT |
|----------------------------------------------|-------------|
| arc_challenge :: accuracy,loglikelihood      |      0.3281 |
| arc_challenge :: accuracy,loglikelihood_norm |      0.3359 |
| gsm8k_platinum_cot :: acc,num                |      0.3984 |
| mmlu_stem :: acc,ll                          |      0.3125 |
| mmlu_stem :: acc,ll_avg                      |      0.3125 |
INFO  Reusing post-quant validation model for backend `PAROQUANT`                                                                                                                                                                       
INFO  gc.collect() reclaimed 0 objects in 0.979s                                                                                                                                                                                        
modules in model: {<class 'gptqmodel.nn_modules.qlinear.paroquant.ParoQuantQuantLinear'>}
INFO  Reusing evaluation results for backend `PAROQUANT`; skipping duplicate lm_eval run                                                                                                                                                
INFO  gsm8k_platinum_cot:acc,num: `0.3984375` vs `0.460938` diff 86.44%                                                                                                                                                                 
INFO  arc_challenge:acc: `0.328125` vs `0.3216723549488055` diff 102.01%                                                                                                                                                                
INFO  arc_challenge:acc_norm: `0.3359375` vs `0.3515358361774744` diff 95.56%                                                                                                                                                           
INFO  mmlu_stem:acc: `0.3125` vs `0.40120520139549637` diff 77.89%             
INFO  +---------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                     
INFO  | process | layer | module                    | feat: in, out | dtype: size  | loss         | samples | damp    | time  | fwd_time | (v)ram       | dynamic |                                                                     
INFO  +---------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                     
INFO  | paroquant | 0     | self_attn.q_proj          | 2048, 2048    | f16: 16.2MB  | 0.0004742453 | 2048    |         | 3.697 | 1.357    | cuda 1.14G   |         |                                                                   
INFO  +-----------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                   
INFO  | paroquant | 0     | self_attn.k_proj          | 2048, 512     | f16: 4.1MB   | 0.0008439685 | 2048    |         | 2.618 | 1.357    | cuda 1.14G   |         |                                                                   
INFO  +-----------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                   
INFO  | paroquant | 0     | self_attn.v_proj          | 2048, 512     | f16: 4.1MB   | 0.0000160968 | 2048    |         | 2.543 | 1.357    | cuda 1.14G   |         |                                                                   
INFO  +-----------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                   
INFO  | paroquant | 0     | self_attn.o_proj          | 2048, 2048    | f16: 16.2MB  | 0.0000002911 | 2048    |         | 2.726 | 1.357    | cuda 1.16G   |         |                                                                   
INFO  +-----------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                   
INFO  | paroquant | 0     | mlp.gate_proj             | 2048, 8192    | f16: 64.8MB  | 0.0000885336 | 2048    |         | 3.306 | 1.357    | cuda 4.93G   |         |                                                                   
INFO  +-----------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                   
INFO  | paroquant | 0     | mlp.up_proj               | 2048, 8192    | f16: 64.8MB  | 0.0000676560 | 2048    |         | 3.255 | 1.357    | cuda 4.99G   |         |                                                                   
INFO  +-----------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                   
INFO  | paroquant | 0     | mlp.down_proj             | 8192, 2048    | f16: 65.0MB  | 0.0000016590 | 2048    |         | 9.795 | 1.357    | cuda 4.99G   |         |                                                                   
INFO  +-----------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                   
INFO  | Process quant  | 14    | 28.244 | 2.017 | 28.244  | 65.4%  | model.layers.0.mlp.down_proj   |                                                                                                                                   
INFO  +----------------+-------+--------+-------+---------+--------+--------------------------------+                                                                                                                                   
INFO  | Capture inputs | 1     | 7.102  | 7.102 | 7.102   | 16.5%  | cache_inputs:LlamaDecoderLayer |                                                                                                                                   
INFO  +----------------+-------+--------+-------+---------+--------+--------------------------------+                                                                                                                                   
INFO  | Pre-quant forward | 8     | 1.357  | 0.502 | 4.013   | 9.3%   | model.layers.0:subset4/4       |                                                                                                                                
INFO  +-------------------+-------+--------+-------+---------+--------+--------------------------------+                                                                                                                                
INFO  | Forward hook      | 644   | 0.011  | 0.004 | 2.874   | 6.7%   | model.layers.0.mlp.down_proj   |                                                                                                                                
INFO  +-------------------+-------+--------+-------+---------+--------+--------------------------------+                                                                                                                                
INFO  | Post-quant replay | 1     | 0.664  | 0.664 | 0.664   | 1.5%   | model.layers.0:subset4/4       |                                                                                                                                
INFO  +-------------------+-------+--------+-------+---------+--------+--------------------------------+                                                                                                                                
INFO  | Turtle reload     | 1     | 0.257  | 0.257 | 0.257   | 0.6%   | auto:Embedding                 |                                                                                                                                
INFO  +-------------------+-------+--------+-------+---------+--------+--------------------------------+                                                                                                                                
INFO  | process   | layer | module                    | feat: in, out | dtype: size  | loss         | samples | damp    | time  | fwd_time | (v)ram       | dynamic |                                                                   
INFO  +-----------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                   
INFO  | paroquant | 1     | self_attn.q_proj          | 2048, 2048    | f16: 16.2MB  | 0.0006264690 | 2048    |         | 4.039 | 1.335    | cuda 5.G     |         |                                                                   
INFO  +-----------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                   
INFO  | paroquant | 1     | self_attn.k_proj          | 2048, 512     | f16: 4.1MB   | 0.0015308936 | 2048    |         | 2.722 | 1.335    | cuda 5.G     |         |                                                                   
INFO  +-----------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                   
INFO  | paroquant | 1     | self_attn.v_proj          | 2048, 512     | f16: 4.1MB   | 0.0000656697 | 2048    |         | 2.630 | 1.335    | cuda 5.G     |         |                                                                   
INFO  +-----------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                   
INFO  | paroquant | 1     | self_attn.o_proj          | 2048, 2048    | f16: 16.2MB  | 0.0000004194 | 2048    |         | 2.701 | 1.335    | cuda 5.G     |         |                                                                   
INFO  +-----------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                   
INFO  | paroquant | 1     | mlp.gate_proj             | 2048, 8192    | f16: 64.8MB  | 0.0001432050 | 2048    |         | 3.213 | 1.335    | cuda 6.25G   |         |                                                                   
INFO  +-----------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                   
INFO  | paroquant | 1     | mlp.up_proj               | 2048, 8192    | f16: 64.8MB  | 0.0001007438 | 2048    |         | 3.230 | 1.335    | cuda 6.31G   |         |                                                                   
INFO  +-----------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                   
INFO  | paroquant | 1     | mlp.down_proj             | 8192, 2048    | f16: 65.0MB  | 0.0000033229 | 2048    |         | 9.602 | 1.335    | cuda 6.31G   |         |                                                                   
INFO  +-----------+-------+---------------------------+---------------+--------------+--------------+---------+---------+-------+----------+--------------+---------+                                                                   
INFO  | Process quant     | 28    | 28.499 | 2.027 | 56.744  | 54.5%  | model.layers.1.mlp.down_proj   |                                                                                                                                
INFO  +-------------------+-------+--------+-------+---------+--------+--------------------------------+                                                                                                                                
INFO  | Submodule finalize | 14    | 6.342  | 1.471 | 20.595  | 19.8%  | model.layers.0.mlp.down_proj   |                                                                                                                               
INFO  +--------------------+-------+--------+-------+---------+--------+--------------------------------+                                                                                                                               
INFO  | Pre-quant forward  | 16    | 1.335  | 0.485 | 7.759   | 7.5%   | model.layers.1:subset4/4       |                                                                                                                               
INFO  +--------------------+-------+--------+-------+---------+--------+--------------------------------+                                                                                                                               
INFO  | Capture inputs     | 1     | 7.102  | 7.102 | 7.102   | 6.8%   | cache_inputs:LlamaDecoderLayer |                                                                                                                               
INFO  +--------------------+-------+--------+-------+---------+--------+--------------------------------+                                                                                                                               
INFO  | Forward hook       | 1288  | 0.011  | 0.004 | 5.791   | 5.6%   | model.layers.1.mlp.down_proj   |                                                                                                                               
INFO  +--------------------+-------+--------+-------+---------+--------+--------------------------------+                                                                                                                               
INFO  | Finalize offload   | 7     | 0.023  | 0.703 | 4.921   | 4.7%   | model.layers.0.mlp.down_proj   |                                                                                                                               
INFO  +--------------------+-------+--------+-------+---------+--------+--------------------------------+                                                                                                                               
INFO  | Post-quant replay  | 2     | 0.202  | 0.433 | 0.867   | 0.8%   | model.layers.1:subset4/4       |                                                                                                                               
INFO  +--------------------+-------+--------+-------+---------+--------+--------------------------------+                                                                                                                               
INFO  | Turtle reload      | 1     | 0.257  | 0.257 | 0.257   | 0.2%   | auto:Embedding                 |                                                                                                                               
INFO  +--------------------+-------+--------+-------+---------+--------+--------------------------------+   

btw, gptqmodel also has full moe routing control, so you can force all moe modules to receive all tokens, or a fraction, via QuantizeConfig. This is something paroquant needs, since many moe models have extremely biased routing: some modules get ~0.0001 percent of tokens, not enough for activation sampling. Multi-gpu acceleration for moe also works, with a final fallback mode where gptqmodel automatically switches to rtn-based quantization if a module that requires activations (like gptq/awq) does not get enough samples. I have not yet tested the rtn fallback interaction with paroquant, but it's there.
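The starvation problem is easy to see with a toy top-1 router (the geometric bias factor and the 32-sample floor are made-up numbers for illustration, not gptqmodel internals): under skewed routing a couple of experts absorb nearly all tokens while the tail gets too few activations to calibrate, which is exactly the case the rtn fallback covers.

```python
import random

def route_tokens(n_tokens, n_experts, bias=6.0, seed=0):
    """Toy top-1 router with a geometrically skewed expert preference."""
    rng = random.Random(seed)
    weights = [bias ** -e for e in range(n_experts)]  # few experts dominate
    counts = [0] * n_experts
    for e in rng.choices(range(n_experts), weights=weights, k=n_tokens):
        counts[e] += 1
    return counts

counts = route_tokens(n_tokens=8192, n_experts=8)
MIN_SAMPLES = 32  # hypothetical floor for activation-based quantization
starved = [e for e, c in enumerate(counts) if c < MIN_SAMPLES]
print("tokens per expert:", counts)
print("experts needing an rtn-style fallback:", starved)
```

With this bias the head expert sees thousands of tokens while the last few see almost none, so an activation-driven method simply has nothing to fit there.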

If you guys have free time, please help me test to see what I got right and what I got wrong. Feel free to PR fixes for anything you see amiss as well.

@liang2kl
Collaborator

@Qubitium Thanks for the information! I'll have a look and test it out. Thank you again for the strong support.
