Weights Compression

OpenVINO is the preferred backend for running Weights Compression; PyTorch is also supported.

The algorithm description

The Weights Compression algorithm compresses the weights of a model. It can be used to reduce the footprint and improve the performance of large models in which the weights take up considerably more memory than the activations, for example, Large Language Models (LLMs). The algorithm compresses the weights of Linear, Convolution and Embedding layers.

Supported modes

By default, weights are compressed asymmetrically to the 8-bit integer data type ("INT8_ASYM" mode). The OpenVINO backend also supports three modes of mixed-precision weight quantization with a 4-bit data type as the primary precision: INT4_SYM, INT4_ASYM and NF4. In INT4_SYM mode, the primary precision is an unsigned 4-bit integer, and weights are quantized to it symmetrically with a fixed zero point equal to 8. In INT4_ASYM mode, the primary precision is also an unsigned 4-bit integer, but weights are quantized to it asymmetrically with a typical non-fixed zero point. In NF4 mode, the primary precision is the nf4 data type, without a zero point.

All 4-bit modes support grouped quantization, in which a small group of weights (e.g. 128) in the channel dimension shares quantization parameters (scale). By default, all embeddings, convolutions and the last linear layers are compressed to the 8-bit integer data type; to quantize embeddings and the last linear layers to 4-bit, use all_layers=True. The percentage of the remaining layers compressed to 4-bit is configured by the "ratio" parameter. For example, ratio=0.9 means that 90% of the layers are compressed to the corresponding 4-bit data type and the rest to the 8-bit asymmetric integer data type.
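The sketch below illustrates grouped INT4_SYM and INT4_ASYM quantization of a 2-D weight matrix, with groups taken along the last (channel) dimension. It assumes plain absmax scaling (the largest magnitude in a group maps to ±7) for the symmetric case and min/max scaling for the asymmetric case. NNCF's exact scale and rounding formulas may differ, and all function names are hypothetical.

```python
import numpy as np

def quantize_int4_sym(w: np.ndarray, group_size: int = 128):
    """Grouped symmetric 4-bit quantization stored as unsigned codes with zero point 8."""
    rows, cols = w.shape
    assert cols % group_size == 0, "this sketch assumes the channel dim divides evenly"
    g = w.reshape(rows, cols // group_size, group_size)
    # One scale per group (assumption: absmax / 7, so values land in [-7, 7]).
    scale = np.abs(g).max(axis=-1, keepdims=True) / 7.0
    scale = np.where(scale == 0, 1.0, scale)  # guard all-zero groups
    # Shift by the fixed zero point 8 to get unsigned codes in [0, 15].
    q = np.clip(np.round(g / scale) + 8, 0, 15).astype(np.uint8)
    return q, scale

def dequantize_int4_sym(q, scale):
    return ((q.astype(np.float32) - 8) * scale).reshape(q.shape[0], -1)

def quantize_int4_asym(w: np.ndarray, group_size: int = 128):
    """Grouped asymmetric 4-bit quantization with a per-group (non-fixed) zero point."""
    rows, cols = w.shape
    assert cols % group_size == 0, "this sketch assumes the channel dim divides evenly"
    g = w.reshape(rows, cols // group_size, group_size)
    w_min = g.min(axis=-1, keepdims=True)
    w_max = g.max(axis=-1, keepdims=True)
    scale = (w_max - w_min) / 15.0  # spread the group's range over 16 levels
    scale = np.where(scale == 0, 1.0, scale)
    zero_point = np.clip(np.round(-w_min / scale), 0, 15)  # differs per group
    q = np.clip(np.round(g / scale) + zero_point, 0, 15).astype(np.uint8)
    return q, scale, zero_point

def dequantize_int4_asym(q, scale, zero_point):
    return ((q.astype(np.float32) - zero_point) * scale).reshape(q.shape[0], -1)

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 256)).astype(np.float32)
q_s, s_s = quantize_int4_sym(w)
q_a, s_a, zp_a = quantize_int4_asym(w)
print("INT4_SYM  max abs error:", np.abs(w - dequantize_int4_sym(q_s, s_s)).max())
print("INT4_ASYM max abs error:", np.abs(w - dequantize_int4_asym(q_a, s_a, zp_a)).max())
```

Each group stores only its scale (plus a zero point in the asymmetric case), which is why smaller groups track the weights more closely but add storage overhead.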

User guide

  • Compress weights asymmetrically to 8-bit integer data type.
from nncf import compress_weights
compressed_model = compress_weights(model) # model is openvino.Model object
  • Compress weights symmetrically to 8-bit integer data type.
from nncf import compress_weights, CompressWeightsMode
compressed_model = compress_weights(model, mode=CompressWeightsMode.INT8_SYM) # model is openvino.Model object
  • Compress weights symmetrically to the 4-bit integer data type with group size = 128, except for embeddings, convolutions and the last linear layers, which are compressed asymmetrically to the 8-bit integer data type.
from nncf import compress_weights, CompressWeightsMode
compressed_model = compress_weights(model, mode=CompressWeightsMode.INT4_SYM) # model is openvino.Model object
  • Generally, INT4_SYM mode is the fastest mixed-precision mode, but it may lead to significant accuracy degradation or perplexity increase. Compressing weights asymmetrically (INT4_ASYM mode) increases accuracy, but in turn slows down inference a bit. If the accuracy or perplexity is still not satisfactory, there are two more hyper-parameters to tune: group_size and ratio. Please refer to the example of how to tune these parameters automatically. A lower group size and a lower ratio of 4-bit layers usually improve accuracy at the expense of inference speed. Below is an example of how to compress the weights of 90% of layers to the 4-bit integer data type asymmetrically with group size 64, and the rest of the layers to the 8-bit asymmetric integer data type. The same parametrization is applicable to INT4_SYM mode.
from nncf import compress_weights, CompressWeightsMode
compressed_model = compress_weights(model, mode=CompressWeightsMode.INT4_ASYM, group_size=64, ratio=0.9) # model is openvino.Model object
  • The accuracy of 4-bit compressed models can be improved by using the data-aware mixed-precision algorithm. It finds outliers in the input activations and assigns different quantization precisions to minimize accuracy degradation. Below is an example of how to compress 80% of layers to the 4-bit integer data type with the default data-aware mixed-precision algorithm. It requires just one extra parameter: an NNCF wrapper of the dataset. Refer to the full example of data-aware weight compression for more details. If the dataset is not specified, the data-free mixed-precision algorithm works based on weights only. Refer to the second table below for an evaluation of the data-free and data-aware methods on the wikitext dataset. On average, data-aware mixed-precision weight compression takes more time than the data-free one (~30% slower on Intel(R) Xeon(R) Gold 6430L), since it runs inference on the calibration dataset to find outliers in the input activations.
from nncf import compress_weights, CompressWeightsMode, Dataset
nncf_dataset = Dataset(data_source, transform_fn) # data_source and transform_fn are user-defined
compressed_model = compress_weights(model, mode=CompressWeightsMode.INT4_SYM, ratio=0.8, dataset=nncf_dataset) # model is openvino.Model object
  • The accuracy of 4-bit compressed models can also be improved by applying the AWQ algorithm on top of the data-aware mixed-precision algorithm. AWQ equalizes a subset of weights to minimize the difference between the original-precision and the 4-bit model. Below is an example of how to compress 80% of layers to the 4-bit integer data type with the default data-aware mixed-precision algorithm and AWQ. In addition to the dataset required by data-aware mixed precision, it only requires setting awq=True.
import numpy as np
from datasets import load_dataset
from functools import partial
from nncf import compress_weights, CompressWeightsMode, Dataset
from optimum.intel.openvino import OVModelForCausalLM
from transformers import AutoTokenizer

def transform_func(item, tokenizer, input_shapes):
    text = item['text']
    tokens = tokenizer(text)

    res = {'input_ids': np.expand_dims(np.array(tokens['input_ids']), 0),
           'attention_mask': np.expand_dims(np.array(tokens['attention_mask']), 0)}

    if 'position_ids' in input_shapes:
        position_ids = np.cumsum(res['attention_mask'], axis=1) - 1
        position_ids[res['attention_mask'] == 0] = 1
        res['position_ids'] = position_ids

    for name, shape in input_shapes.items():
        if name in res:
            continue
        res[name] = np.zeros(shape)

    return res

def get_input_shapes(model, batch_size=1):
    inputs = {}

    for val in model.model.inputs:
        name = val.any_name
        shape = list(val.partial_shape.get_min_shape())
        shape[0] = batch_size
        inputs[name] = shape

    return inputs

# load your model and tokenizer
model = OVModelForCausalLM.from_pretrained(...)
tokenizer = AutoTokenizer.from_pretrained(...)

# prepare dataset for compression
dataset = load_dataset('wikitext', 'wikitext-2-v1', split='train')
dataset = dataset.filter(lambda example: len(example["text"]) > 80)
input_shapes = get_input_shapes(model)
nncf_dataset = Dataset(dataset, partial(transform_func, tokenizer=tokenizer,
                                        input_shapes=input_shapes))

model.model = compress_weights(model.model, mode=CompressWeightsMode.INT4_SYM, ratio=0.8, dataset=nncf_dataset, awq=True)

model.save_pretrained(...)
  • NF4 mode can be considered for improving accuracy, but currently models quantized to nf4 should not be expected to be faster than models quantized to the 8-bit asymmetric integer data type. Here is an example of how to compress weights to the nf4 data type with group size = 128. Different group_size and ratio values are also supported.
from nncf import compress_weights, CompressWeightsMode
compressed_model = compress_weights(model, mode=CompressWeightsMode.NF4)
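The idea behind AWQ, mentioned above, can be sketched in a few lines of NumPy. Scaling input channel j of a weight matrix up by s_j and the matching activation down by s_j leaves a float layer's output unchanged but redistributes the 4-bit rounding error, so channels with large activations can be protected. The sketch follows the scale search described in the original AWQ paper (s = mean|x|^alpha, with alpha chosen on a grid). It is a hypothetical illustration, not NNCF's modified AWQ, and `awq_like_search` and `fake_quant_int4_sym` are illustrative names.

```python
import numpy as np

def fake_quant_int4_sym(w, group_size=32):
    """Quantize-dequantize W of shape (out, in) with groups along the input dim."""
    out_f, in_f = w.shape
    g = w.reshape(out_f, in_f // group_size, group_size)
    scale = np.abs(g).max(axis=-1, keepdims=True) / 7.0
    scale = np.where(scale == 0, 1.0, scale)
    return (np.clip(np.round(g / scale), -7, 7) * scale).reshape(out_f, in_f)

def awq_like_search(w, x, group_size=32, n_grid=20):
    """Grid-search per-input-channel scales s = mean|x| ** alpha.

    Since (x / s) @ (w * s).T equals x @ w.T, the float layer is unchanged;
    only the 4-bit rounding error moves, shrinking on high-activation channels.
    """
    ref = x @ w.T
    act_mag = np.abs(x).mean(axis=0) + 1e-8       # per-input-channel activation size
    best_err, best_s = np.inf, np.ones(w.shape[1])
    for alpha in np.linspace(0.0, 1.0, n_grid):   # alpha = 0 is plain 4-bit quantization
        s = act_mag ** alpha
        s = s / np.sqrt(s.max() * s.min())        # keep the scales centred around 1
        out = (x / s) @ fake_quant_int4_sym(w * s, group_size).T
        err = float(np.mean((out - ref) ** 2))
        if err < best_err:
            best_err, best_s = err, s
    return best_s, best_err

rng = np.random.default_rng(0)
x = rng.standard_normal((256, 64))
x[:, :4] *= 20.0                                  # a few outlier input channels
w = rng.standard_normal((128, 64)) * 0.05
baseline = float(np.mean((x @ fake_quant_int4_sym(w).T - x @ w.T) ** 2))
scales, best = awq_like_search(w, x)
print(f"plain int4 MSE: {baseline:.3e}, AWQ-like MSE: {best:.3e}")
```

Because alpha = 0 is part of the grid, the search can never do worse than plain 4-bit quantization on the calibration data it was given.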

Evaluation results

Here are the perplexity and model size before and after weight compression for different language models on the Lambada OpenAI dataset. g32 refers to a group size of 32, and r60 to a ratio of 0.6.

| Model | Mode | Perplexity (↓) | Perplexity Increase (↓) | Model Size (GB) |
|---|---|---|---|---|
| databricks/dolly-v2-3b | fp32 | 5.01 | 0 | 10.3 |
| databricks/dolly-v2-3b | int8_asym | 5.07 | 0.05 | 2.6 |
| databricks/dolly-v2-3b | int4_asym_g32_r50 | 5.28 | 0.26 | 2.2 |
| databricks/dolly-v2-3b | nf4_g128_r60 | 5.19 | 0.18 | 1.9 |
| facebook/opt-6.7b | fp32 | 4.25 | 0 | 24.8 |
| facebook/opt-6.7b | int8_asym | 4.27 | 0.01 | 6.2 |
| facebook/opt-6.7b | int4_asym_g64_r80 | 4.32 | 0.07 | 4.1 |
| facebook/opt-6.7b | nf4_g64 | 4.35 | 0.1 | 3.6 |
| meta-llama/Llama-2-7b-chat-hf | fp32 | 3.28 | 0 | 25.1 |
| meta-llama/Llama-2-7b-chat-hf | int8_asym | 3.29 | 0.01 | 6.3 |
| meta-llama/Llama-2-7b-chat-hf | int4_asym_g128_r80 | 3.41 | 0.14 | 4.0 |
| meta-llama/Llama-2-7b-chat-hf | nf4_g128 | 3.41 | 0.13 | 3.5 |
| togethercomputer/RedPajama-INCITE-7B-Instruct | fp32 | 4.15 | 0 | 25.6 |
| togethercomputer/RedPajama-INCITE-7B-Instruct | int8_asym | 4.17 | 0.02 | 6.4 |
| togethercomputer/RedPajama-INCITE-7B-Instruct | nf4_ov_g32_r60 | 4.28 | 0.13 | 5.1 |
| togethercomputer/RedPajama-INCITE-7B-Instruct | int4_asym_g128 | 4.17 | 0.02 | 3.6 |
| meta-llama/Llama-2-13b-chat-hf | fp32 | 2.92 | 0 | 48.5 |
| meta-llama/Llama-2-13b-chat-hf | int8_asym | 2.91 | 0 | 12.1 |
| meta-llama/Llama-2-13b-chat-hf | int4_sym_g64_r80 | 2.98 | 0.06 | 8.0 |
| meta-llama/Llama-2-13b-chat-hf | nf4_g128 | 2.95 | 0.04 | 6.6 |
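The model sizes above follow largely from the bit width used for each part of the model. As a rough back-of-the-envelope sketch, the hypothetical helper below estimates the average storage cost per weight for a 4-bit/8-bit mix. It assumes one 16-bit scale per group and one 4-bit zero point per group in asymmetric mode, ignores the per-channel parameters of 8-bit layers and any uncompressed layers, and treats ratio as a fraction of weights rather than of layers, so it will not reproduce the table exactly.

```python
def bits_per_weight(ratio: float, group_size: int, asym: bool,
                    scale_bits: int = 16, zero_point_bits: int = 4) -> float:
    """Average storage bits per weight for a 4-bit/8-bit mix (rough estimate)."""
    # Assumed per-group parameters: one scale, plus a zero point in asymmetric mode.
    per_group_overhead = scale_bits + (zero_point_bits if asym else 0)
    bits_4 = 4 + per_group_overhead / group_size  # 4-bit codes + amortized group params
    bits_8 = 8.0                                  # per-channel 8-bit params ignored
    return ratio * bits_4 + (1 - ratio) * bits_8

def size_gb(n_params: float, bits: float) -> float:
    """Storage in gigabytes (1e9 bytes) for n_params weights at `bits` bits each."""
    return n_params * bits / 8 / 1e9

for ratio, gs, asym in [(1.0, 128, False), (1.0, 32, False), (0.8, 64, True)]:
    b = bits_per_weight(ratio, gs, asym)
    print(f"ratio={ratio}, group_size={gs}, asym={asym}: "
          f"{b:.3f} bits/weight, {size_gb(7e9, b):.2f} GB for 7e9 weights")
```

Under these assumptions, the per-group overhead term shows why smaller groups cost extra memory: with a 16-bit scale, group size 32 adds 0.5 bits per weight, while group size 128 adds only 0.125.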

Here is the word perplexity with data-free and data-aware mixed-precision INT4-INT8 weight compression for different language models on the wikitext dataset. The data suffix refers to data-aware mixed precision. The data_awq suffix refers to data-aware mixed precision with a modified AWQ algorithm; this modification applies only to MatMul-Multiply-MatMul patterns (for example, the MLP block in LLaMA).

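| Model | Mode | Word Perplexity (↓) |
|---|---|---|
| meta-llama/llama-7b-chat-hf | fp16 | 11.57 |
| meta-llama/llama-7b-chat-hf | int4_sym_g128_r80_data | 11.87 |
| meta-llama/llama-7b-chat-hf | int4_sym_g128_r80 | 11.92 |
| meta-llama/llama-7b-chat-hf | int4_sym_g128_r100_data_awq | 12.34 |
| meta-llama/llama-7b-chat-hf | int4_sym_g128_r100 | 12.35 |
| stabilityai_stablelm-3b-4e1t | fp16 | 10.16 |
| stabilityai_stablelm-3b-4e1t | int4_sym_g64_r80_data | 10.67 |
| stabilityai_stablelm-3b-4e1t | int4_sym_g64_r80 | 10.83 |
| stabilityai_stablelm-3b-4e1t | int4_sym_g64_r100_data_awq | 10.89 |
| stabilityai_stablelm-3b-4e1t | int4_sym_g64_r100 | 11.07 |
| stable-zephyr-3b-dpo | int4_sym_g64_r80_data_awq | 21.62 |
| stable-zephyr-3b-dpo | int4_sym_g64_r80_data | 21.74 |
| stable-zephyr-3b-dpo | int4_sym_g64_r80 | 23.10 |
| stable-zephyr-3b-dpo | int4_sym_g64_r100_data_awq | 21.76 |
| stable-zephyr-3b-dpo | int4_sym_g64_r100 | 23.19 |
| HuggingFaceH4/zephyr-7b-beta | fp16 | 9.82 |
| HuggingFaceH4/zephyr-7b-beta | int4_sym_g128_r80_data | 10.13 |
| HuggingFaceH4/zephyr-7b-beta | int4_sym_g128 | 10.22 |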
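The difference between the data-free and data-aware rows can be illustrated with a simplified precision-selection procedure. The sketch below is not NNCF's implementation: the hypothetical sensitivity score (mean squared 4-bit weight error, optionally weighted by the mean squared calibration input of each input channel) and the greedy assignment by weight count are assumptions chosen for clarity.

```python
import numpy as np

def fake_quant_int4_sym(w, group_size=32):
    """Quantize-dequantize W of shape (out, in) with groups along the input dim."""
    out_f, in_f = w.shape
    g = w.reshape(out_f, in_f // group_size, group_size)
    scale = np.abs(g).max(axis=-1, keepdims=True) / 7.0
    scale = np.where(scale == 0, 1.0, scale)
    return (np.clip(np.round(g / scale), -7, 7) * scale).reshape(out_f, in_f)

def sensitivity(w, group_size=32, x=None):
    """Hypothetical score: higher means the layer suffers more from 4-bit quantization."""
    sq_err = (w - fake_quant_int4_sym(w, group_size)) ** 2  # (out, in)
    if x is None:
        return float(sq_err.mean())                         # data-free: weights only
    channel_energy = (x ** 2).mean(axis=0)                  # (in,) from calibration inputs
    return float((sq_err * channel_energy).mean())          # data-aware: weight by activations

def assign_precision(layers, ratio, group_size=32, calib_inputs=None):
    """Greedily put the least sensitive layers into 4-bit until `ratio` of all weights is used."""
    scores = {
        name: sensitivity(w, group_size, None if calib_inputs is None else calib_inputs[name])
        for name, w in layers.items()
    }
    budget = ratio * sum(w.size for w in layers.values())
    used, plan = 0, {}
    for name in sorted(scores, key=scores.get):             # least sensitive first
        if used + layers[name].size <= budget:
            plan[name] = "int4"
            used += layers[name].size
        else:
            plan[name] = "int8"
    return plan

rng = np.random.default_rng(0)
w = rng.standard_normal((16, 64))
x = rng.standard_normal((128, 64))
layers = {"loud": w, "quiet": w.copy()}                     # identical weights: data-free scores tie
calib = {"loud": 10.0 * x, "quiet": 0.1 * x}                # very different input activations
print("data-free: ", assign_precision(layers, ratio=0.5))
print("data-aware:", assign_precision(layers, ratio=0.5, calib_inputs=calib))
```

With identical weights, the data-free score cannot tell the two layers apart, while the data-aware score keeps the layer that sees large activations in 8-bit.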

Limitations

  • The algorithm is supported for OpenVINO and PyTorch models.
  • The compression applies in-place.
  • The compressed model is not trainable.
  • INT8_SYM, INT4_SYM, INT4_ASYM and NF4 modes, grouped quantization and mixed-precision selection are available for the OpenVINO backend only.
  • NF4 support is experimental: models quantized to nf4 should not be expected to be faster than models quantized to the 8-bit integer data type.

Additional resources

List of notebooks demonstrating OpenVINO conversion and inference together with NNCF weight compression for models from various domains: