# SmoothQuant on OPT-13B

### Guangxuan Xiao\*, Ji Lin\*, Mickael Seznec, Julien Demouth, Song Han

In this notebook, we use the OPT-13B model to demonstrate that SmoothQuant can use 8-bit weights and activations while matching the accuracy of the FP16 model. Unlike the previous method [[Dettmers *et al.*, 2022]](https://arxiv.org/abs/2208.07339), SmoothQuant enables fully INT8 GEMMs for linear layers and does not need to keep outliers in higher precision.
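Concretely, for a linear layer $Y = XW$ with activations $X \in \mathbb{R}^{T \times C_{in}}$ and weights $W \in \mathbb{R}^{C_{in} \times C_{out}}$, SmoothQuant divides each input channel $j$ of the activations by a smoothing factor $s_j$ and multiplies the matching weight row by the same factor. The product is mathematically unchanged, but activation outliers shrink. $\alpha$ is the migration strength; we assume it corresponds to `q_smoothing_strength` and to the `0.5` passed to `smooth_lm` later in this notebook:

```latex
Y = XW = \left(X\,\mathrm{diag}(s)^{-1}\right)\left(\mathrm{diag}(s)\,W\right) = \hat{X}\hat{W},
\qquad
s_j = \frac{\max_t |X_{t,j}|^{\alpha}}{\max_k |W_{j,k}|^{1-\alpha}}, \quad j = 1,\dots,C_{in}
```

In practice $\mathrm{diag}(s)^{-1}$ can be folded offline into the parameters of the preceding LayerNorm, so smoothing adds no runtime cost.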

This notebook demonstrates SmoothQuant on OPT-13B to accommodate users' resource constraints. We have tested SmoothQuant on models of up to 176 billion parameters (OPT-175B, BLOOM-176B, GLM-130B). You can change the model name to validate SmoothQuant on other models. `../act_scales/` provides the activation channel scales for OPT and BLOOM models. Note that the cells below are configured for OPT-125M (`short_model_name = "opt-125m"`) and also sweep 4-bit settings (W4A4 with group size 128).

To run this notebook, you need to install the following packages:

- smoothquant
- nanoquant
- PyTorch
- Transformers
- Accelerate
- Datasets

In [1]:
%env HF_HOME="/state/partition1/user/zzhang1/cache/huggingface"

env: HF_HOME="/state/partition1/user/zzhang1/cache/huggingface"


In [2]:
opt_125m_model_path = "/state/partition1/user/zzhang1/cache/huggingface/hub/models--facebook--opt-125m/snapshots/27dcfa74d334bc871f3234de431e71c6eeba5dd6/"
opt_6_7b_model_path = "/state/partition1/user/zzhang1/cache/huggingface/hub/models--facebook--opt-6.7b/snapshots/a45aa65bbeb77c1558bc99bedc6779195462dab0/"
opt_13b_model_path = "/state/partition1/user/zzhang1/cache/huggingface/hub/models--facebook--opt-13b/snapshots/e515202d1e7750da62d245fbccb2723b9c1790f5/"

In [3]:
import importlib
# Reload the modules so local edits take effect without restarting the kernel
importlib.reload(importlib.import_module("nanoquant.investigate"))
importlib.reload(importlib.import_module("smoothquant.fake_quant"))

from nanoquant.investigate import sweep, report_sweep, Investigation

repo_dir = "."
short_model_name = "opt-125m"
#sweep(short_model_name=short_model_name, repo_dir=repo_dir, save_dir=".")
#report_sweep(short_model_name=short_model_name, save_dir=".")


n_bits = 8
q_group_size = 0 # 0 means no grouping
q_protect = False # False means no protection
q_protection_scale = 0.0 # 0.0 means mixed-precision. >= 1.0 means actual scale up/down.
q_protection_ratio = 0.01 # 0.01 means 1% of the weights are protected.
q_smoothing_strength = 0.5

investigation = Investigation(
    short_model_name=short_model_name,
    local_model_path=opt_125m_model_path,
    local_files_only=True,
    repo_dir=repo_dir,
    n_bits=n_bits,
    q_group_size=q_group_size,
    q_protect=q_protect,
    q_protection_scale=q_protection_scale,
    q_protection_ratio=q_protection_ratio,
    q_smoothing_strength=q_smoothing_strength
)


Using the latest cached version of the dataset since lambada couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'plain_text' at /state/partition1/user/zzhang1/cache/huggingface/datasets/lambada/plain_text/0.0.0/5953bd97664b64b95754f299b2309ecfbfbe81b9 (last modified on Wed Dec 11 00:22:59 2024).
Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /state/partition1/user/zzhang1/cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Dec 11 00:23:01 2024).
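The `q_protect`, `q_protection_ratio`, and `q_protection_scale` settings above are documented only by their comments ("1% of the weights are protected", "0.0 means mixed-precision"). The sketch below shows one way mixed-precision protection could work, under an assumption of ours that has not been verified against `nanoquant`: protected input channels are chosen AWQ-style by mean activation magnitude, and their weight columns are left unquantized. `fake_quant_per_row` and `protect_and_quantize` are hypothetical helpers.

```python
import torch


def fake_quant_per_row(w: torch.Tensor, n_bits: int) -> torch.Tensor:
    """Symmetric quantize-dequantize with one scale per output row."""
    qmax = 2 ** (n_bits - 1) - 1
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / qmax
    return torch.round(w / scale).clamp(-qmax - 1, qmax) * scale


def protect_and_quantize(w, act_mean_abs, n_bits=4, ratio=0.01):
    """Hypothetical mixed-precision protection.

    w: nn.Linear-style weight of shape (out_features, in_features), so each
       column corresponds to one input channel.
    act_mean_abs: (in_features,) mean |activation| per input channel.
    Keeps the top `ratio` fraction of input channels in full precision and
    fake-quantizes everything else (scales here still include the protected
    columns, for simplicity).
    """
    n_in = w.shape[1]
    k = max(1, int(round(ratio * n_in)))
    salient = torch.topk(act_mean_abs, k).indices
    w_q = fake_quant_per_row(w, n_bits)
    w_q[:, salient] = w[:, salient]  # protected channels stay unquantized
    return w_q, salient
```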


In [4]:
from nanoquant.investigate import sweep, report_sweep, make_setups

setups = make_setups()
print(setups[0])
print(setups[1])


{'n_bits': 4, 'q_group_size': 128, 'q_protect': False, 'q_protection_scale': -1.0, 'q_protection_ratio': -1.0, 'q_smoothing_strength': 0.5}
{'n_bits': 4, 'q_group_size': 128, 'q_protect': True, 'q_protection_scale': 0.0, 'q_protection_ratio': 0.0, 'q_smoothing_strength': 0.5}


In [5]:
base_res = investigation.evaluate_base_model(perp=True)
print(f"Base Result: {base_res}")

Making base model...
Done making base model.


evaluating...:   0%|          | 0/40 [00:00<?, ?it/s]

Base Result: 27.56900978088379


In [6]:
quantized_res = investigation.evaluate_base_quantized_model(perp=True)
print(f"Quantized Result: {quantized_res}")

Making base model...
Done making base model.
Quantizing model...
Done quantizing model.


evaluating...: 100%|██████████| 40/40 [00:06<00:00,  6.44it/s]


Quantized Result: 27.811603546142578


In [7]:
smooth_quantized_res = investigation.evaluate_base_smooth_model(perp=True)
print(f"Smooth Quantized Result: {smooth_quantized_res}")

Making base model...
Done making base model.
Smoothing model...
Done smoothing model.
Quantizing model...
Done quantizing model.


evaluating...: 100%|██████████| 40/40 [00:02<00:00, 13.82it/s]

Smooth Quantized Result: 27.627756118774414





In [7]:
def setup_name(setup):
    n_bits = setup["n_bits"]
    base_name = f"W{n_bits}A{n_bits}"
    q_group_size = setup["q_group_size"]
    if q_group_size > 0:
        base_name += f" G{q_group_size}"
    q_protect = setup["q_protect"]
    if q_protect:
        q_protection_scale = setup["q_protection_scale"]
        q_protection_ratio = setup["q_protection_ratio"]
        with_act = "Act" if q_protection_ratio > 1e-5 else "NoAct"
        if q_protection_scale > 1e-5:
            base_name += f" AWQ-Scaled-{with_act}"
        else:
            base_name += f" AWQ-Mixed-{with_act}"
    return base_name

def make_baselines():
    return ["fp16", "awq", "smoothquant", "smoothquant-g", "w4a4", "smooth-w4a4", "w8a8"]

In [8]:
import gc
import os
import pickle as pkl
import torch

save_dir = "."
perp = True
os.makedirs(save_dir, exist_ok=True)
result_file = f"{save_dir}/results_{short_model_name}.pkl"
#if os.path.exists(result_file):
#    with open(result_file, "rb") as f:
#        results = pkl.load(f)
#else:
results = {}
for setup in setups:
    setup_key = str(setup)
    base_expt_name = setup_name(setup)
    if setup_key in results and base_expt_name != "W4A4 G128":
        print(f"Setup {base_expt_name} already run. Results={results[setup_key]['q_res']}, SmoothResults={results[setup_key]['q_smooth_res']}")
        continue
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    investigation = Investigation(
        short_model_name=short_model_name,
        repo_dir=repo_dir,
        local_model_path=opt_125m_model_path,
        local_files_only=True,
        **setup)
    # Quantized model (no smoothing)
    simple_expt_name = base_expt_name
    if simple_expt_name not in results:
        print(f"Running setup {base_expt_name}")
        q_res = investigation.evaluate_setup_model(perp=perp, apply_smooth=False)
        results[simple_expt_name] = q_res
        with open(result_file, "wb") as f:
            pkl.dump(results, f)
    else:
        q_res = results[simple_expt_name]
    print(f"{simple_expt_name}: {q_res}")
    # Smoothed model
    smooth_expt_name = f"Smooth {base_expt_name}"
    if smooth_expt_name not in results:
        print(f"Running setup {smooth_expt_name}")
        q_smooth_res = investigation.evaluate_setup_model(perp=perp, apply_smooth=True)
        results[smooth_expt_name] = q_smooth_res
        with open(result_file, "wb") as f:
            pkl.dump(results, f)
    else:
        q_smooth_res = results[smooth_expt_name]
    print(f"{smooth_expt_name}: {q_smooth_res}")
    res = {
        "setup": setup,
        "q_res": q_res,
        "q_smooth_res": q_smooth_res,
    }
    results[setup_key] = res
    # Checkpointing
    with open(result_file, "wb") as f:
        pkl.dump(results, f)

Using the latest cached version of the dataset since lambada couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'plain_text' at /state/partition1/user/zzhang1/cache/huggingface/datasets/lambada/plain_text/0.0.0/5953bd97664b64b95754f299b2309ecfbfbe81b9 (last modified on Wed Dec 11 00:22:59 2024).
Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /state/partition1/user/zzhang1/cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Dec 11 00:23:01 2024).


Running setup W4A4 G128
Making base model...
Done making base model.
Quantizing model...
Quantizing model... False
Done quantizing model.


evaluating...: 100%|██████████| 40/40 [00:15<00:00,  2.60it/s]


W4A4 G128: 36.21949005126953
Running setup Smooth W4A4 G128
Making base model...
Done making base model.
Smoothing model...
Done smoothing model.
Quantizing model...
Quantizing model... False
Done quantizing model.


evaluating...: 100%|██████████| 40/40 [00:04<00:00,  9.61it/s]


Smooth W4A4 G128: 32.42618179321289


Using the latest cached version of the dataset since lambada couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'plain_text' at /state/partition1/user/zzhang1/cache/huggingface/datasets/lambada/plain_text/0.0.0/5953bd97664b64b95754f299b2309ecfbfbe81b9 (last modified on Wed Dec 11 00:22:59 2024).
Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /state/partition1/user/zzhang1/cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Dec 11 00:23:01 2024).


Running setup W4A4 G128 AWQ-Mixed-NoAct
Making base model...
Done making base model.
Applying AWQ...
Done applying AWQ.
Quantizing model...
Quantizing model... True
Done quantizing model.


evaluating...: 100%|██████████| 40/40 [00:23<00:00,  1.69it/s]


W4A4 G128 AWQ-Mixed-NoAct: 29.848878860473633
Running setup Smooth W4A4 G128 AWQ-Mixed-NoAct
Making base model...
Done making base model.
Applying AWQ...
Done applying AWQ.
Smoothing model...
Done smoothing model.
Quantizing model...
Quantizing model... True
Done quantizing model.


evaluating...: 100%|██████████| 40/40 [00:18<00:00,  2.15it/s]

Smooth W4A4 G128 AWQ-Mixed-NoAct: 30.167171478271484
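The sweep above rewrites the whole pickle after every run, and its resume branch is commented out. If you want restartable sweeps, the following is a minimal sketch of load-or-initialize plus atomic saving; `load_results` and `save_results` are hypothetical helpers, not part of `nanoquant`. Writing to a temporary file in the same directory and then calling `os.replace` means an interrupted write cannot corrupt the existing checkpoint.

```python
import os
import pickle
import tempfile


def load_results(path):
    """Resume from an existing checkpoint, or start from an empty dict."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)
    return {}


def save_results(results, path):
    """Write to a temp file next to `path`, then atomically rename it over
    `path`, so an interrupted run never leaves a truncated pickle behind."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(results, f)
        os.replace(tmp_path, path)
    except BaseException:
        os.remove(tmp_path)
        raise
```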





In [10]:
import torch
import tqdm
from torch import nn
from transformers.models.opt.modeling_opt import (
    OPTAttention,
    OPTDecoderLayer,
    OPTForCausalLM,
)
from transformers import GPT2Tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
from smoothquant.smooth import smooth_lm
from smoothquant.fake_quant import WQAQLinear, quantize_opt

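In this notebook, we simulate weight and activation quantization in FP16, i.e., fake quantization: each tensor is mapped onto the integer grid and immediately dequantized, so the matrix multiplications still run in floating point. The original SmoothQuant setting is 8-bit dynamic per-tensor quantization; the cells below also use 4-bit group-wise settings (W4A4 with group size 128). We have implemented real 8-bit quantization with INT8 CUTLASS GEMM kernels for both PyTorch and FasterTransformer. Please stay tuned for the release.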
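A minimal sketch of dynamic per-tensor symmetric fake quantization in PyTorch. `fake_quant_per_tensor` is a hypothetical helper for illustration; `smoothquant.fake_quant` and `nanoquant` may differ in details such as rounding, clamping range, and per-channel or per-token granularity.

```python
import torch


@torch.no_grad()
def fake_quant_per_tensor(x: torch.Tensor, n_bits: int = 8) -> torch.Tensor:
    """Dynamic per-tensor symmetric fake quantization.

    The scale is recomputed from the current tensor (dynamic), values are
    rounded onto the signed integer grid [-qmax, qmax], and then immediately
    mapped back to floating point, so downstream matmuls stay in FP16/FP32.
    """
    qmax = 2 ** (n_bits - 1) - 1  # 127 for 8 bits
    scale = x.abs().max().clamp(min=1e-5) / qmax
    return (x / scale).round().clamp(-qmax, qmax) * scale
```

Because only the largest magnitude sets the scale, the rounding error of every element is bounded by half a quantization step, `max|x| / (2 * qmax)`.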

The following cells define two evaluators. `AccuracyEvaluator` measures last-token prediction accuracy on the LAMBADA validation set, and `PerplexityEvaluator` measures perplexity (lower is better) on the WikiText-2 test set over `n_samples` consecutive windows of 2048 tokens. The results in this notebook are reported as perplexity. You can replace either dataset with your own; we expect the relative comparison between methods to hold.

**In this demo, we simplify the evaluation by using only a small prefix of the LAMBADA validation set (the first `n_samples` = 40 examples in the cell below). We use "Last Token Prediction Accuracy" as the metric. This approximate evaluation is intended for demonstration purposes, giving simple but meaningful comparisons of relative performance between methods. For a stricter assessment, we recommend using the [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) to obtain the "Last Word Prediction Accuracy" for the LAMBADA dataset, which is the metric reported in our paper.**
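Perplexity is the exponential of the mean token-level negative log-likelihood. The sketch below, using hypothetical helpers `perplexity_from_logits` and `windowed_perplexity`, mirrors the bookkeeping in `PerplexityEvaluator` on toy tensors. The evaluator multiplies each window's mean loss by 2048 even though each window predicts 2047 tokens, but since every window is scaled by the same constant, the constant cancels and the result equals the plain token average whenever all windows have the same length.

```python
import math

import torch
import torch.nn.functional as F


def perplexity_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """exp(mean token-level negative log-likelihood).

    logits: (num_tokens, vocab) scores; labels: (num_tokens,) target ids.
    """
    nll = F.cross_entropy(logits.float(), labels, reduction="sum")
    return math.exp(nll.item() / labels.numel())


def windowed_perplexity(logits, labels, window):
    """Evaluator-style bookkeeping: mean loss per window, scaled by a
    constant, summed, then divided by (number of windows * constant)."""
    losses = [
        F.cross_entropy(logits[i:i + window].float(), labels[i:i + window])
        for i in range(0, labels.numel(), window)
    ]
    return math.exp(sum(l.item() * window for l in losses) / (len(losses) * window))
```

As a sanity check, all-zero logits over a vocabulary of V tokens give a perplexity of exactly V, i.e. a uniform guess.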

In [11]:
class PerplexityEvaluator:
    """Perplexity over `n_samples` consecutive windows of `seqlen` tokens."""

    def __init__(self, dataset, tokenizer, device, n_samples=40, seqlen=2048):
        self.tokenizer = tokenizer
        self.device = device
        self.n_samples = n_samples
        self.seqlen = seqlen
        # Concatenate the whole split into one long token stream.
        self.dataset = tokenizer(
            "\n\n".join(dataset["text"]), return_tensors="pt"
        ).input_ids.to(device)

    @torch.no_grad()
    def evaluate(self, model):
        model.eval()
        testenc = self.dataset
        nlls = []
        for i in tqdm.tqdm(range(self.n_samples), desc="evaluating..."):
            window = testenc[:, (i * self.seqlen):((i + 1) * self.seqlen)]
            batch = window.to(model.device)
            lm_logits = model(batch).logits
            # Logits at position t predict the token at position t + 1.
            shift_logits = lm_logits[:, :-1, :].contiguous().float()
            shift_labels = batch[:, 1:].to(shift_logits.device)
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.reshape(-1)
            )
            # Mean loss times window length approximates the summed NLL.
            nlls.append(loss.float() * self.seqlen)

        return torch.exp(torch.stack(nlls).sum() / (self.n_samples * self.seqlen)).item()

class AccuracyEvaluator:
    """Last-token prediction accuracy on a dataset with a "text" column."""

    def __init__(self, dataset, tokenizer, device):
        self.tokenizer = tokenizer
        self.device = device

        # Tokenize the whole dataset once, up front.
        def tokenize_function(examples):
            return self.tokenizer(examples["text"])

        self.dataset = dataset.map(tokenize_function, batched=True)
        self.dataset.set_format(type="torch", columns=["input_ids"])

    @torch.no_grad()
    def evaluate(self, model):
        model.eval()
        # Predict the final token of each example from the tokens before it.
        total, hit = 0, 0
        for batch in self.dataset:
            input_ids = batch["input_ids"].to(self.device).unsqueeze(0)
            label = input_ids[:, -1]
            outputs = model(input_ids)
            # Logits at position -2 score the token at position -1.
            last_token_logits = outputs.logits[:, -2, :]
            pred = last_token_logits.argmax(dim=-1)
            total += label.size(0)
            hit += (pred == label.to(pred.device)).sum().item()
        acc = hit / total
        return acc
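`AccuracyEvaluator` reads its prediction from position `-2` because a causal LM's logits at position t score the token at position t + 1. Feeding the full sequence, including the label token, is safe: causal attention prevents position `-2` from seeing the last token. The toy check below uses a hypothetical stand-in model (`toy_next_token_model`) that always predicts "current token + 1" in place of a real OPT model; `last_token_accuracy` is likewise a hypothetical helper.

```python
import torch


def last_token_accuracy(model_fn, sequences):
    """Fraction of sequences whose final token equals the argmax prediction
    made at position -2."""
    hits = 0
    for ids in sequences:
        input_ids = ids.unsqueeze(0)            # (1, T)
        logits = model_fn(input_ids)            # (1, T, V)
        pred = logits[:, -2, :].argmax(dim=-1)  # prediction for the last token
        hits += int((pred == input_ids[:, -1]).sum().item())
    return hits / len(sequences)


def toy_next_token_model(input_ids, vocab_size=10):
    """Hypothetical stand-in for a causal LM: always predicts token + 1."""
    next_ids = (input_ids + 1) % vocab_size
    return torch.nn.functional.one_hot(next_ids, vocab_size).float()
```

Because a LAMBADA target word can span several tokens, this token-level metric only approximates the word-level accuracy reported by lm-eval-harness.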

In [15]:
from datasets import load_dataset

# Load tokenizers from the local snapshot of facebook/opt-125m.
# To download from the Hub instead, set model_path = "facebook/opt-125m".
model_path = opt_125m_model_path
acc_tokenizer = GPT2Tokenizer.from_pretrained(model_path)
perp_tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

cache_dir = "/state/partition1/user/zzhang1/cache/huggingface/datasets/"
n_samples = 40
# First n_samples examples of LAMBADA validation (accuracy) and the full WikiText-2 test split (perplexity).
acc_dataset = load_dataset("lambada", split=f"validation[:{n_samples}]", cache_dir=cache_dir)
perp_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir)

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
acc_evaluator = AccuracyEvaluator(acc_dataset, acc_tokenizer, device)
# Perplexity over 15 windows of 2048 tokens.
perp_evaluator = PerplexityEvaluator(perp_dataset, perp_tokenizer, device, n_samples=15)

Using the latest cached version of the dataset since lambada couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'plain_text' at /state/partition1/user/zzhang1/cache/huggingface/datasets/lambada/plain_text/0.0.0/5953bd97664b64b95754f299b2309ecfbfbe81b9 (last modified on Wed Dec 11 00:28:24 2024).
Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /state/partition1/user/zzhang1/cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Dec 11 00:23:01 2024).


Map:   0%|          | 0/40 [00:00<?, ? examples/s]

## FP16 Model Perplexity

Let's first check the perplexity of the original FP16 model.

In [17]:
model_fp16 = OPTForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16, device_map="auto"
)
res_fp16 = perp_evaluator.evaluate(model_fp16)
print(f"Original model (fp16) result: {res_fp16}")

evaluating...: 100%|██████████| 15/15 [00:00<00:00, 23.37it/s]

Original model (fp16) result: 29.11785888671875





## Naive Quantized Model Perplexity

We then quantize the model without smoothing, using the settings in the cell below (`n_bits = 4`, `q_group_size = 128`, 1% protection), and check the perplexity.

In [19]:
model_fp16 = OPTForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16, device_map="auto"
)

n_bits = 4
q_group_size = 128 # 0 means no grouping
q_protect = True # False means no protection
q_protection_ratio = 0.01 # 0.01 means 1% of the weights are protected.
q_protection_scale = 0.0 # 0.0 means mixed precision. >= 1.0 means actual scale up/down.
q_name = f"W{n_bits}A{n_bits}"
q_model = quantize_opt(
    model_fp16,
    n_bits=n_bits,
    q_group_size=q_group_size,
    q_protect=q_protect,
    q_protection_ratio=q_protection_ratio,
    q_protection_scale=q_protection_scale,
)
q_res = perp_evaluator.evaluate(q_model)
print(f"Naive {q_name} quantized model result: {q_res}")

evaluating...: 100%|██████████| 15/15 [00:02<00:00,  6.54it/s]

Naive W4A4 quantized model result: 35.55803680419922
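With `q_group_size = 128`, values share a quantization scale only within groups of 128, so one large value inflates the scale of its own group instead of the whole tensor. The following is a minimal sketch, assuming groups are contiguous along the last (input) dimension of the weight; we have not checked which axis `quantize_opt` groups along. `fake_quant_groupwise` is a hypothetical helper.

```python
import torch


@torch.no_grad()
def fake_quant_groupwise(w: torch.Tensor, n_bits: int = 4, group_size: int = 128) -> torch.Tensor:
    """Symmetric quantize-dequantize with one scale per contiguous group of
    `group_size` values along the last dimension.

    group_size <= 0 falls back to a single scale for the whole tensor.
    """
    qmax = 2 ** (n_bits - 1) - 1  # 7 for 4 bits
    if group_size <= 0:
        scale = w.abs().max().clamp(min=1e-5) / qmax
        return (w / scale).round().clamp(-qmax, qmax) * scale
    assert w.shape[-1] % group_size == 0, "last dim must be divisible by group_size"
    g = w.reshape(-1, group_size)  # one row per group
    scale = g.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / qmax
    return ((g / scale).round().clamp(-qmax, qmax) * scale).reshape(w.shape)
```

With a single outlier, per-tensor 4-bit quantization rounds nearly every other weight to zero, while group-wise quantization confines the damage to the outlier's own group.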





In [20]:
print(q_model)

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 768, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-11): 12 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): WALinear(768, 768, bias=True, weight_quant=protected_group_quant_128, act_quant=protected_group_quant_128, output_quant=protected_group_quant_128)
            (v_proj): WALinear(768, 768, bias=True, weight_quant=protected_group_quant_128, act_quant=protected_group_quant_128, output_quant=protected_group_quant_128)
            (q_proj): WALinear(768, 768, bias=True, weight_quant=protected_group_quant_128, act_quant=protected_group_quant_128, output_quant=protected_group_quant_128)
            (out_proj): WALinear(768, 768, bias=True, weight_quant=protected_group_quant_128, act_quant=protected_group_q

Perplexity rises noticeably, from 29.12 (FP16) to 35.56 (naive W4A4 G128). This is consistent with LLM.int8()'s finding that once models grow beyond 6.7B parameters, systematic outliers emerge in activations and make naive fully INT8 quantization inaccurate. At 4 bits, the quantization error is large enough that even this 125M model degrades.
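A toy illustration of the outlier effect on synthetic data: with per-tensor INT8 quantization, a single channel scaled by 100 inflates the shared scale, and the remaining channels lose most of their resolution. `int8_per_tensor` and `rel_err_on_normal_channels` are hypothetical helpers.

```python
import torch


def int8_per_tensor(x):
    """Symmetric per-tensor INT8 quantize-dequantize."""
    scale = x.abs().max().clamp(min=1e-5) / 127
    return (x / scale).round().clamp(-127, 127) * scale


def rel_err_on_normal_channels(a, outlier_channel=3):
    """Relative quantization error measured only on the non-outlier channels."""
    mask = torch.ones(a.shape[1], dtype=torch.bool)
    mask[outlier_channel] = False
    err = (int8_per_tensor(a) - a)[:, mask]
    return (err.norm() / a[:, mask].norm()).item()


torch.manual_seed(0)
x = torch.randn(16, 64)       # well-behaved activations
x_outlier = x.clone()
x_outlier[:, 3] *= 100.0      # one systematic outlier channel

err_clean = rel_err_on_normal_channels(x)
err_outlier = rel_err_on_normal_channels(x_outlier)
```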

## SmoothQuant Quantized Model Perplexity

Let's smooth the model, quantize it with the same settings, and check the perplexity. In `../act_scales`, we provide the activation scales for OPT and BLOOM models; the cell below loads `smoothquant/act_scales/opt-125m.pt`. You can also use this notebook to test quantizing those models.
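Activation scales are per-input-channel maxima of |activation| collected on calibration data. The following is a minimal sketch using PyTorch forward pre-hooks on every `nn.Linear`; `collect_act_scales` is hypothetical and is not the SmoothQuant repository's calibration script, although we assume the saved `.pt` file holds a similar mapping from module names to tensors.

```python
import torch
from torch import nn


@torch.no_grad()
def collect_act_scales(model: nn.Module, batches) -> dict:
    """Per-input-channel max |activation| for every nn.Linear in `model`,
    accumulated over the calibration `batches` via forward pre-hooks."""
    scales, hooks = {}, []

    def make_hook(name):
        def hook(module, inputs):
            x = inputs[0].detach()
            x = x.reshape(-1, x.shape[-1]).abs().float()  # (tokens, C_in)
            cur = x.amax(dim=0)
            scales[name] = torch.maximum(scales[name], cur) if name in scales else cur
        return hook

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            hooks.append(module.register_forward_pre_hook(make_hook(name)))
    try:
        for batch in batches:
            model(batch)
    finally:
        for h in hooks:  # always detach hooks, even if a forward pass fails
            h.remove()
    return scales
```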

In [22]:
model = OPTForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16, device_map="auto"
)
scales_path = "smoothquant/act_scales/opt-125m.pt"
act_scales = torch.load(scales_path)  # per-channel activation scales for each linear layer
smooth_lm(model, act_scales, 0.5)  # 0.5 = migration strength alpha
q_model_smooth = quantize_opt(
    model,
    n_bits=n_bits,
    q_group_size=q_group_size,
    q_protect=q_protect,
    q_protection_ratio=q_protection_ratio,
    q_protection_scale=q_protection_scale,
)
q_res_smooth = perp_evaluator.evaluate(q_model_smooth)
print(f"Smoothed {q_name} quantized model result: {q_res_smooth}")

evaluating...: 100%|██████████| 15/15 [00:01<00:00,  8.45it/s]

Smoothed W4A4 quantized model result: 33.575801849365234





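Smoothing recovers part of the gap: perplexity drops from 35.56 (naive W4A4 G128) to 33.58, although it stays above the FP16 perplexity of 29.12. In the 8-bit runs near the top of this notebook, smoothing brings perplexity to 27.63, versus 27.81 for naive quantization and 27.57 for the base model. The improvement comes from SmoothQuant smoothing the outliers in activations, which migrates quantization difficulty from activations to weights.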
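A single-layer sketch of the smoothing step on synthetic data. It computes $s_j = \max|X_j|^{\alpha} / \max|W_j|^{1-\alpha}$, divides the activations by `s`, multiplies the weight's input columns by `s`, and checks that the layer output is unchanged while the INT8 matmul error shrinks. `smooth_linear`, `int8_fake_quant`, and `output_error` are hypothetical helpers; to our understanding `smooth_lm` instead folds `1/s` into the preceding LayerNorm and takes the activation maxima from `act_scales`.

```python
import torch


@torch.no_grad()
def smooth_linear(x, w, alpha=0.5):
    """Migrate activation outliers into the weights (SmoothQuant-style).

    x: (tokens, C_in) activations; w: (C_out, C_in) nn.Linear weight.
    Returns (x_hat, w_hat, s) with x_hat @ w_hat.T == x @ w.T up to rounding.
    """
    act_max = x.abs().amax(dim=0).clamp(min=1e-5)  # per input channel
    w_max = w.abs().amax(dim=0).clamp(min=1e-5)    # per input channel
    s = act_max.pow(alpha) / w_max.pow(1 - alpha)
    return x / s, w * s, s


def int8_fake_quant(t):
    """Symmetric per-tensor INT8 quantize-dequantize."""
    scale = t.abs().max().clamp(min=1e-5) / 127
    return (t / scale).round().clamp(-127, 127) * scale


def output_error(x, w):
    """Relative matmul error when both operands are INT8 fake-quantized."""
    y = x @ w.T
    y_q = int8_fake_quant(x) @ int8_fake_quant(w).T
    return ((y_q - y).norm() / y.norm()).item()
```

With `alpha = 0.5` the outlier channel's range is split evenly between activations and weights; a larger `alpha` pushes more of it into the weights.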