
8. Quantization


Aphrodite supports a variety of quantization methods (most of them, in fact). All the supported model classes can be run quantized, provided the originating quant library supports them. The weight quantization methods we currently support are:

  • GPTQ
  • AWQ
  • Exllama V2
  • GGUF
  • Marlin
  • SmoothQuant+
  • Bitsandbytes
  • QuIP#
  • SqueezeLLM
  • AQLM

We also support KV cache quantization, which allows for higher context lengths when memory is constrained:

  • FP8 E5M2
  • INT8

Let's go through how you can run each of them.

Weight Quantization

GPTQ

The GPTQ quantization in Aphrodite uses the ExllamaV2 kernels to boost throughput. The supported bit sizes are 2, 3, 4, and 8. Most models already have a GPTQ-converted version on Hugging Face, but you can also convert one manually using the transformers library:

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "/path/to/model"  # can also be a HF model
tokenizer = AutoTokenizer.from_pretrained(model_id)
gptq_config = GPTQConfig(
    bits=4,
    dataset="wikitext2",
    group_size=128,
    desc_act=True,
    use_cuda_fp16=True,
    tokenizer=tokenizer
)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=gptq_config, attn_implementation="sdpa")
# Remove the calibration dataset reference from the config before saving
model.config.quantization_config.dataset = None
model.save_pretrained(f"{model_id}-GPTQ")

Replace the model_id with either the local path to your model on disk, or a HuggingFace model ID.

To load your GPTQ model into Aphrodite, simply pass the path or ID to the --model flag; Aphrodite will detect the quantization and figure everything out.
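For example, assuming the OpenAI-compatible API server entrypoint (the exact command may differ between Aphrodite versions):

python -m aphrodite.endpoints.openai.api_server --model /path/to/model-GPTQ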

AWQ

Similar to GPTQ, you can either find existing conversions on HF or convert a model yourself. Note that transformers' AwqConfig only supports loading pre-quantized AWQ models, so the conversion itself is done with the AutoAWQ library:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_id = "/path/to/model"  # can also be a HF model
quant_config = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "GEMM",
}

model = AutoAWQForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Calibrate and quantize the weights to 4-bit AWQ
model.quantize(tokenizer, quant_config=quant_config)

model.save_quantized(f"{model_id}-AWQ")
tokenizer.save_pretrained(f"{model_id}-AWQ")

The loading process is the same as for GPTQ: point --model at the quantized weights.

Exllama V2

There aren't as many Exllama V2 models on HF, so you may need to convert some of them. Please refer to this page for instructions: https://github.com/turboderp/exllamav2/blob/master/doc/convert.md

As with the other quant methods, just point --model to the directory containing your exl2 model.

GGUF

Most models should have a GGUF variant uploaded to HF. Unlike other formats, GGUF models are distributed as standalone files rather than a full repository layout, so you cannot pass a HuggingFace ID to the --model flag; download the GGUF file(s) and point --model at them instead. Starting from v0.5.3, Aphrodite extends GGUF support to all available model architectures (not just Llama), as well as to sharded (multi-file) GGUF. There are two ways to use GGUF in Aphrodite:

Pre-convert to a PyTorch state_dict (recommended)

Pre-conversion has several advantages: faster loading after the one-time conversion, much lower system RAM usage when running on multiple GPUs, and support for model architectures other than Llama. To convert, run examples/gguf_to_torch.py (see the example invocation after this list). Important args:

  • --input: Path to the single .gguf file, or the directory containing sharded .gguf files.
  • --output: The path to the output directory.
  • --unquantized-path: The path to the unquantized model, used to copy the config and tokenizer. For Llama 1 & 2 models this can be skipped, in which case the converter tries to extract the config and tokenizer from the GGUF file itself, but it is recommended to always supply this because the tokenizer embedded in a GGUF can sometimes be broken.
  • --no-tokenizer: Do not try to copy or extract the tokenizer. Useful if you plan to supply another tokenizer via --tokenizer to Aphrodite later.
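
A typical conversion invocation looks like this (all paths are placeholders):

python examples/gguf_to_torch.py \
    --input /path/to/model.gguf \
    --output /path/to/converted-model \
    --unquantized-path /path/to/original-fp16-model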

Run directly

Start Aphrodite with --model pointing to the single .gguf file, or to the directory containing sharded .gguf files. If it is a directory, Aphrodite will try to use the config.json, the other JSON configs, and the transformers-format tokenizer inside that directory. The model must be of the LlamaForCausalLM architecture to be loaded directly from GGUF; otherwise the original config.json and other JSON configs must be present in the directory. Similarly, the tokenizer must be of the LlamaTokenizer architecture to be loaded directly from GGUF; otherwise the original tokenizer must be present in the directory, or you can supply another tokenizer via --tokenizer.
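For example, to run a Llama-architecture GGUF directly while borrowing the tokenizer from the original model (assuming the OpenAI-compatible server entrypoint; adjust for your version):

python -m aphrodite.endpoints.openai.api_server \
    --model /path/to/model.gguf \
    --tokenizer /path/to/original-fp16-model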

To convert an FP16 model to GGUF in the first place, see here.

Marlin

Marlin needs a GPTQ model that satisfies a few conditions:

  • bits=4
  • group_size= -1 or 128
  • desc_act=False

You can follow the GPTQ conversion guide above, but with these specific parameters; a sketch is shown below. Then see here for how to convert the GPTQ checkpoint to Marlin. The loading process is the same as for the other methods.
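
For reference, a minimal sketch of the earlier transformers-based GPTQ conversion with Marlin-compatible settings (the output path is just an example name):

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "/path/to/model"  # can also be a HF model
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Marlin requires 4-bit weights, group_size of 128 (or -1), and no act-order
gptq_config = GPTQConfig(
    bits=4,
    dataset="wikitext2",
    group_size=128,
    desc_act=False,
    tokenizer=tokenizer,
)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=gptq_config)
model.config.quantization_config.dataset = None
model.save_pretrained(f"{model_id}-GPTQ-marlin")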

SmoothQuant+

No conversion is needed here. Simply load an FP16 model with either of these two flags:

  • --load-in-4bit

This is the fastest quant method currently available, beating both GPTQ and ExllamaV2. Start-up takes a bit longer because the model is converted to 4-bit on load, but the quality is very good, reportedly as good as or better than AWQ. You can also load AWQ models with this flag for faster speeds! Example launch commands for both flags are shown after this list.

  • --load-in-smooth

This flag loads your FP16 model in 8bit using SmoothQuant+. It's slower than FP16, but is almost lossless at half the memory usage.
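
Example launch commands for both flags, assuming the OpenAI-compatible API server entrypoint (adjust for your install and version):

# 4-bit SmoothQuant+
python -m aphrodite.endpoints.openai.api_server --model /path/to/fp16-model --load-in-4bit

# 8-bit SmoothQuant+
python -m aphrodite.endpoints.openai.api_server --model /path/to/fp16-model --load-in-smooth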

Bitsandbytes

Similar to SmoothQuant+, Bitsandbytes quantization is implicit and done automatically. Only 8bit is currently supported. Simply load your FP16 model with --load-in-8bit. This option is very slow, so please use --load-in-smooth if you need implicit 8-bit conversion.

QuIP#

Please use this library to quantize your model to QuIP#. These are very rare, so you likely won't find an HF upload.

SqueezeLLM

This 4-bit quantization method is also rare on HF. The speeds aren't good, so they're not very popular. Please refer to this repo for conversion.

AQLM

AQLM, similar to what QuIP# offers, is a state-of-the-art 2-bit quantization method. It takes a long time to quantize a model (12 days on 8xA100s for a 70B model), but inference is quite fast. See here for a list of quantized AQLM models, and here for how to quantize a model yourself.

KV Cache Quantization

FP8 E5M2

We support automatic quantization of the KV cache using NVIDIA's FP8 Intrinsics. This feature is only available on CUDA 11.8 and above. To use this, load your model with --kv-cache-dtype fp8_e5m2. This leads to a performance boost of ~20% and negligible accuracy loss.
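
For example (assuming the OpenAI-compatible server entrypoint; the exact command may differ between versions):

python -m aphrodite.endpoints.openai.api_server --model /path/to/model --kv-cache-dtype fp8_e5m2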

INT8

INT8 quantization of the KV cache needs a bit more work. It's much better in quality compared to FP8, since we perform calibration.

You will need to prepare scales and zero points for the specific model you're using.

  1. Install the extra dependencies:
pip install fire
  2. Generate the state cache:
python aphrodite/kv_quant/calibrate.py --model meta-llama/Llama-2-7b-hf --calib_dataset wikitext2 --calib_samples 128 --calib_seqlen 4096 --work_dir kv_cache_states/llama-2-7b
  3. Export the scales and zero points:
python aphrodite/kv_quant/export_kv_params.py --work_dir kv_cache_states/llama-2-7b --kv_params_dir quant_params/llama-2-7b

Then you can load your model with these two extra args:

--kv-cache-dtype int8 --kv-quant-params-path quant_params/llama-2-7b

This does not provide a performance boost, but it should reduce the memory used by the KV cache for context.
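
Putting it all together, a full launch command might look like this (assuming the OpenAI-compatible server entrypoint; adjust for your version):

python -m aphrodite.endpoints.openai.api_server \
    --model meta-llama/Llama-2-7b-hf \
    --kv-cache-dtype int8 \
    --kv-quant-params-path quant_params/llama-2-7b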