Skip to content

intel/neural-compressor

Repository files navigation

Intel® Neural Compressor

An open-source Python library supporting popular model compression techniques on all mainstream deep learning frameworks (TensorFlow, PyTorch, ONNX Runtime, and MXNet)

python version license coverage Downloads

Architecture   |   Workflow   |   LLMs Recipes   |   Results   |   Documentations


Intel® Neural Compressor aims to provide popular model compression techniques such as quantization, pruning (sparsity), distillation, and neural architecture search on mainstream frameworks such as TensorFlow, PyTorch, ONNX Runtime, and MXNet, as well as Intel extensions such as Intel Extension for TensorFlow and Intel Extension for PyTorch. In particular, the tool provides the key features, typical examples, and open collaborations as below:

What's New

Installation

Install from pypi

pip install neural-compressor

Note: Further installation methods can be found under Installation Guide. check out our FAQ for more details.

Getting Started

Setting up the environment:

pip install "neural-compressor>=2.3" "transformers>=4.34.0" torch torchvision

After successfully installing these packages, try your first quantization program.

Weight-Only Quantization (LLMs)

Following example code demonstrates Weight-Only Quantization on LLMs, it supports Intel CPU, Intel Gauid2 AI Accelerator, Nvidia GPU, best device will be selected automatically.

To try on Intel Gaudi2, docker image with Gaudi Software Stack is recommended, please refer to following script for environment setup. More details can be found in Gaudi Guide.

docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04//habanalabs/pytorch-installer-2.1.1:latest

# Check the container ID
docker ps

# Login into container
docker exec -it <container_id> bash

# Install the optimum-habana
pip install --upgrade-strategy eager optimum[habana]

# Install INC/auto_round
pip install neural-compressor auto_round

Run the example:

from transformers import AutoModel, AutoTokenizer

from neural_compressor.config import PostTrainingQuantConfig
from neural_compressor.quantization import fit
from neural_compressor.adaptor.torch_utils.auto_round import get_dataloader

model_name = "EleutherAI/gpt-neo-125m"
float_model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
dataloader = get_dataloader(tokenizer, seqlen=2048)

woq_conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {  # match all ops
            "weight": {
                "dtype": "int",
                "bits": 4,
                "algorithm": "AUTOROUND",
            },
        }
    },
)
quantized_model = fit(model=float_model, conf=woq_conf, calib_dataloader=dataloader)

Note:

To try INT4 model inference, please directly use Intel Extension for Transformers, which leverages Intel Neural Compressor for model quantization.

Static Quantization (Non-LLMs)

from torchvision import models

from neural_compressor.config import PostTrainingQuantConfig
from neural_compressor.data import DataLoader, Datasets
from neural_compressor.quantization import fit

float_model = models.resnet18()
dataset = Datasets("pytorch")["dummy"](shape=(1, 3, 224, 224))
calib_dataloader = DataLoader(framework="pytorch", dataset=dataset)
static_quant_conf = PostTrainingQuantConfig()
quantized_model = fit(model=float_model, conf=static_quant_conf, calib_dataloader=calib_dataloader)

Documentation

Overview
Architecture Workflow APIs LLMs Recipes Examples
Python-based APIs
Quantization Advanced Mixed Precision Pruning (Sparsity) Distillation
Orchestration Benchmarking Distributed Compression Model Export
Neural Coder (Zero-code Optimization)
Launcher JupyterLab Extension Visual Studio Code Extension Supported Matrix
Advanced Topics
Adaptor Strategy Distillation for Quantization SmoothQuant
Weight-Only Quantization (INT8/INT4/FP4/NF4) FP8 Quantization Layer-Wise Quantization
Innovations for Productivity
Neural Insights Neural Solution

Note: Further documentations can be found at User Guide.

Selected Publications/Events

Note: View Full Publication List.

Additional Content

Communication

  • GitHub Issues: mainly for bug reports, new feature requests, question asking, etc.
  • Email: welcome to raise any interesting research ideas on model compression techniques by email for collaborations.
  • Discord Channel: join the discord channel for more flexible technical discussion.
  • WeChat group: scan the QA code to join the technical discussion.