# NanoQuant: Investigating W4A4 Quantization on OPT-6.7B

### Amadou Ngom, Sylvia Zhang, Bowen Zhu, Qihang Chen
#### 6.5940 TinyML final project

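In this notebook, we use the OPT-6.7B model to demonstrate the prospects and limitations of W4A4 quantization obtained by combining SmoothQuant and AWQ. Note that the cells below are currently configured for the smaller OPT-125M checkpoint (`short_model_name = "opt-125m"`).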

In order to run this notebook, you need to install the following packages:

- smoothquant
- awq
- PyTorch
- Transformers
- Datasets

`smoothquant` and `awq` should be installed from this repository's git submodules.

In [1]:
%env HF_HOME="/state/partition1/user/zzhang1/cache/huggingface"

env: HF_HOME="/state/partition1/user/zzhang1/cache/huggingface"


In [2]:
# Local snapshot directories of the OPT checkpoints; all of them live under the same
# Hugging Face hub cache, so the shared prefix is spelled out once.
_hf_hub_cache = "/state/partition1/user/zzhang1/cache/huggingface/hub"
opt_125m_model_path = (
    f"{_hf_hub_cache}/models--facebook--opt-125m/snapshots/"
    "27dcfa74d334bc871f3234de431e71c6eeba5dd6/"
)
opt_6_7b_model_path = (
    f"{_hf_hub_cache}/models--facebook--opt-6.7b/snapshots/"
    "a45aa65bbeb77c1558bc99bedc6779195462dab0/"
)
opt_13b_model_path = (
    f"{_hf_hub_cache}/models--facebook--opt-13b/snapshots/"
    "e515202d1e7750da62d245fbccb2723b9c1790f5/"
)

In [18]:
import importlib

# Pick up on-disk edits to the project modules without restarting the kernel:
# import each module (a no-op when it is already loaded) and then reload it.
for _module_name in ("nanoquant.investigate", "smoothquant.fake_quant"):
    importlib.reload(importlib.import_module(_module_name))

from nanoquant.investigate import Investigation, report_sweep, sweep

short_model_name = "opt-125m"
repo_dir = "."
# report_sweep(short_model_name=short_model_name, save_dir=".")

KeyError: 'fp16'

In [7]:
import importlib

# Re-import the project modules so that edits on disk take effect without a kernel restart.
for _module_name in ("nanoquant.investigate", "smoothquant.fake_quant"):
    importlib.reload(importlib.import_module(_module_name))

from nanoquant.investigate import Investigation, report_sweep, sweep

repo_dir = "."
short_model_name = "opt-125m"
# Full sweep / report entry points, left disabled in this cell:
#sweep(short_model_name=short_model_name, repo_dir=repo_dir, save_dir=".")
#report_sweep(short_model_name=short_model_name, save_dir=".")

# Quantization settings for the single, manually configured investigation.
n_bits = 8
q_group_size = 0           # 0 means no grouping
q_protect = False          # False means no protection
q_protection_scale = 0.0   # 0.0 means mixed-precision; >= 1.0 means actual scale up/down
q_protection_ratio = 0.01  # 0.01 means 1% of the weights are protected
q_smoothing_strength = 0.5

# Same key layout as the setup dicts that the sweep cell unpacks into Investigation.
_manual_setup = {
    "n_bits": n_bits,
    "q_group_size": q_group_size,
    "q_protect": q_protect,
    "q_protection_scale": q_protection_scale,
    "q_protection_ratio": q_protection_ratio,
    "q_smoothing_strength": q_smoothing_strength,
}

investigation = Investigation(
    short_model_name=short_model_name,
    local_model_path=opt_125m_model_path,
    local_files_only=True,
    repo_dir=repo_dir,
    **_manual_setup,
)


Using the latest cached version of the dataset since lambada couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'plain_text' at /state/partition1/user/zzhang1/cache/huggingface/datasets/lambada/plain_text/0.0.0/5953bd97664b64b95754f299b2309ecfbfbe81b9 (last modified on Thu Dec 12 11:00:51 2024).
Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /state/partition1/user/zzhang1/cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Thu Dec 12 11:00:56 2024).


Map:   0%|          | 0/40 [00:00<?, ? examples/s]

In [4]:
# AWQ Investigate
# Names and locations used for the AWQ investigation.
# Note: this rebinds the notebook-global `repo_dir` to the llm-awq checkout.
short_model_name = "opt-125m"
repo_dir = "llm-awq"
awq_zoo = "mit-han-lab/awq-model-zoo"
# Pre-computed AWQ search results (4-bit weights, group size 128) inside that checkout.
awq_pt_name = f"{repo_dir}/awq_cache/{short_model_name}-w4-g128.pt"

import torch
from awq.quantize.pre_quant import apply_awq

# torch.load unpickles the file, so only load AWQ caches that come from a trusted source.
awq_results = torch.load(awq_pt_name, map_location="cpu")


In [5]:
from nanoquant.investigate import make_setups, report_sweep, sweep

# Build the list of sweep configurations and show the first two as a sanity check.
setups = make_setups()
for _preview_index in (0, 1):
    print(setups[_preview_index])


{'n_bits': 4, 'q_group_size': 128, 'q_protect': False, 'q_protection_scale': -1.0, 'q_protection_ratio': -1.0, 'q_smoothing_strength': 0.5}
{'n_bits': 4, 'q_group_size': 128, 'q_protect': True, 'q_protection_scale': 0.0, 'q_protection_ratio': 0.0, 'q_smoothing_strength': 0.5}


#### Base fp16 Model Evaluation
First, let us evaluate the base fp16 model.

In [14]:
# Evaluate the base model through whichever `investigation` object is currently bound.
# perp=True presumably selects the perplexity metric — TODO confirm against
# Investigation.evaluate_base_model.
base_res = investigation.evaluate_base_model(perp=True)
print(f"Base Result: {base_res}")

Making base model...
Done making base model.


evaluating...: 100%|██████████| 40/40 [00:02<00:00, 19.78it/s]

Base Result: 27.56900978088379





#### 8-bit Quantization
Next, let us evaluate existing approaches to 8-bit quantization.

In [6]:
# Evaluate the base model after quantization only (the recorded log shows no smoothing step).
# The bit width and grouping come from how `investigation` was constructed.
# perp=True presumably selects the perplexity metric — TODO confirm.
quantized_res = investigation.evaluate_base_quantized_model(perp=True)
print(f"Quantized Result: {quantized_res}")

Making base model...
Done making base model.
Quantizing model...
Done quantizing model.


evaluating...: 100%|██████████| 40/40 [00:06<00:00,  6.44it/s]


Quantized Result: 27.811603546142578


In [10]:
# Evaluate the base model after smoothing followed by quantization (order taken from the
# recorded log below). Settings come from how `investigation` was constructed.
# perp=True presumably selects the perplexity metric — TODO confirm.
smooth_quantized_res = investigation.evaluate_base_smooth_model(perp=True)
print(f"Smooth Quantized Result: {smooth_quantized_res}")

Making base model...
Done making base model.
Smoothing model...
Done smoothing model.
Quantizing model...
Done quantizing model.


evaluating...: 100%|██████████| 40/40 [00:02<00:00, 17.71it/s]

Smooth Quantized Result: 27.627756118774414





#### W4A4 Quantization Variants
The following sweep evaluates the different W4A4 quantization variants that we implemented.

In [11]:
def setup_name(setup):
    """Build the human-readable experiment label for one sweep setup.

    Args:
        setup: Mapping with the keys ``n_bits``, ``q_group_size`` and ``q_protect``.
            ``q_protection_scale`` and ``q_protection_ratio`` are read only when
            ``q_protect`` is truthy.

    Returns:
        A label of the form ``W{n}A{n}``, followed by ``" G{group}"`` when the group
        size is positive and by ``" AWQ-{Scaled|Mixed}-{Act|NoAct}"`` when protection
        is enabled. ``Scaled`` is chosen when the protection scale exceeds 1e-5 and
        ``Act`` when the protection ratio exceeds 1e-5.
    """
    bits = setup["n_bits"]
    parts = [f"W{bits}A{bits}"]

    group = setup["q_group_size"]
    if group > 0:
        parts.append(f"G{group}")

    # Without protection the label is complete; the protection keys are not touched.
    if not setup["q_protect"]:
        return " ".join(parts)

    scale = setup["q_protection_scale"]
    ratio = setup["q_protection_ratio"]
    act_tag = "Act" if ratio > 1e-5 else "NoAct"
    kind_tag = "Scaled" if scale > 1e-5 else "Mixed"
    parts.append(f"AWQ-{kind_tag}-{act_tag}")
    return " ".join(parts)

def make_baselines():
    """Return the names of the baseline experiments as a fresh list on every call."""
    baseline_names = (
        "fp16",
        "awq",
        "smoothquant",
        "smoothquant-g",
        "w4a4",
        "smooth-w4a4",
        "w8a8",
    )
    return list(baseline_names)

In [12]:
import gc
import os
import pickle as pkl
import torch

save_dir = "."
perp = True
os.makedirs(save_dir, exist_ok=True)
result_file = f"{save_dir}/results_{short_model_name}.pkl"


def _checkpoint_results(results, result_file):
    """Overwrite `result_file` with a pickle of the current `results` dict."""
    with open(result_file, "wb") as f:
        pkl.dump(results, f)


def _evaluate_or_reuse(results, result_file, investigation, expt_name, apply_smooth, perp):
    """Return the metric stored under `expt_name`, evaluating it first if it is missing.

    A fresh evaluation is stored in `results` and checkpointed to `result_file`
    immediately, so a crash later in the sweep does not lose it.
    """
    if expt_name not in results:
        print(f"Running setup {expt_name}")
        results[expt_name] = investigation.evaluate_setup_model(perp=perp, apply_smooth=apply_smooth)
        _checkpoint_results(results, result_file)
    metric = results[expt_name]
    print(f"{expt_name}: {metric}")
    return metric


# The sweep always starts from an empty dict, so an existing result_file is overwritten.
# To resume a previous run instead, load it first (only from a file you trust, since
# pickle executes code on load):
#if os.path.exists(result_file):
#    with open(result_file, "rb") as f:
#        results = pkl.load(f)
#else:
results = {}
for setup in setups:
    setup_key = str(setup)
    base_expt_name = setup_name(setup)
    if setup_key in results and base_expt_name != "W4A4 G128":
        print(f"Setup {base_expt_name} already run. Results={results[setup_key]['q_res']}, SmoothResults={results[setup_key]['q_smooth_res']}")
        continue

    # Drop the previous Investigation *before* collecting: while the name still refers to
    # it, neither gc.collect() nor empty_cache() can release whatever it holds.
    investigation = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    investigation = Investigation(
        short_model_name=short_model_name,
        repo_dir=repo_dir,
        local_model_path=opt_125m_model_path,
        local_files_only=True,
        **setup)

    # Plain variant first, then the same setup with smoothing applied.
    q_res = _evaluate_or_reuse(
        results, result_file, investigation, f"{base_expt_name}", apply_smooth=False, perp=perp)
    q_smooth_res = _evaluate_or_reuse(
        results, result_file, investigation, f"Smooth {base_expt_name}", apply_smooth=True, perp=perp)

    # `results` intentionally holds both per-experiment entries (keyed by label) and this
    # per-setup summary (keyed by str(setup)).
    results[setup_key] = {
        "setup": setup,
        "q_res": q_res,
        "q_smooth_res": q_smooth_res,
    }
    # Checkpointing
    _checkpoint_results(results, result_file)

Using the latest cached version of the dataset since lambada couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'plain_text' at /state/partition1/user/zzhang1/cache/huggingface/datasets/lambada/plain_text/0.0.0/5953bd97664b64b95754f299b2309ecfbfbe81b9 (last modified on Thu Dec 12 23:37:06 2024).
Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /state/partition1/user/zzhang1/cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Thu Dec 12 11:00:56 2024).


Running setup W4A4 G128
Making base model...
Done making base model.
Quantizing model...
Quantizing model... False
Done quantizing model.


evaluating...: 100%|██████████| 40/40 [00:02<00:00, 15.19it/s]


W4A4 G128: 36.21949005126953
Running setup Smooth W4A4 G128
Making base model...
Done making base model.
Smoothing model...
Done smoothing model.
Quantizing model...
Quantizing model... False
Done quantizing model.


evaluating...: 100%|██████████| 40/40 [00:02<00:00, 15.21it/s]


Smooth W4A4 G128: 32.42618179321289


Using the latest cached version of the dataset since lambada couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'plain_text' at /state/partition1/user/zzhang1/cache/huggingface/datasets/lambada/plain_text/0.0.0/5953bd97664b64b95754f299b2309ecfbfbe81b9 (last modified on Thu Dec 12 23:37:06 2024).
Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /state/partition1/user/zzhang1/cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Thu Dec 12 11:00:56 2024).


Running setup W4A4 G128 AWQ-Mixed-NoAct
Making base model...
Done making base model.
Applying AWQ...
Done applying AWQ.
Quantizing model...
Quantizing model... True
Done quantizing model.


evaluating...: 100%|██████████| 40/40 [00:02<00:00, 15.19it/s]


W4A4 G128 AWQ-Mixed-NoAct: 29.848878860473633
Running setup Smooth W4A4 G128 AWQ-Mixed-NoAct
Making base model...
Done making base model.
Applying AWQ...
Done applying AWQ.
Computing scales after AWQ...
Smoothing model...
Done smoothing model.
Quantizing model...
Quantizing model... True
Done quantizing model.


evaluating...: 100%|██████████| 40/40 [00:02<00:00, 15.19it/s]

Smooth W4A4 G128 AWQ-Mixed-NoAct: 30.360246658325195





In [10]:
import torch
import tqdm
from torch import nn
from transformers.models.opt.modeling_opt import (
    OPTAttention,
    OPTDecoderLayer,
    OPTForCausalLM,
)
from transformers import GPT2Tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
from smoothquant.smooth import smooth_lm
from smoothquant.fake_quant import WQAQLinear, quantize_opt

In this notebook, we simulate the 8-bit dynamic per-tensor weight and activation quantization with FP16, i.e., fake quantization. We have implemented the real 8-bit quantization with INT8 CUTLASS GEMM kernels for both PyTorch and FasterTransformer. Please stay tuned for the release.

The following is an evaluator to see the performance of the model. We use a toy dataset (the first `n_samples` examples in the validation set of the Lambada dataset; `n_samples = 40` in the cell below) to evaluate the model. You can replace it with your own dataset. The conclusion should be the same.

In [15]:
from datasets import load_dataset

# Tokenizers are loaded from the local OPT-125M snapshot. `model_path` is bound here
# explicitly so the cell does not depend on a value left over from an earlier kernel state.
# (To fetch from the Hub instead, pass "facebook/opt-125m" to from_pretrained.)
model_path = opt_125m_model_path
acc_tokenizer = GPT2Tokenizer.from_pretrained(model_path)
perp_tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# Local Hugging Face datasets cache; load_dataset is pointed at it explicitly.
cache_dir = "/state/partition1/user/zzhang1/cache/huggingface/datasets"
acc_dataset_name = f"{cache_dir}/lambada"
perp_dataset_name = f"{cache_dir}/wikitext"

# Number of Lambada validation examples used by the accuracy evaluator.
n_samples = 40
# The trailing slash keeps the cache_dir argument identical to the previously hard-coded value.
acc_dataset = load_dataset("lambada", split=f"validation[:{n_samples}]", cache_dir=f"{cache_dir}/")
perp_dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test', cache_dir=f"{cache_dir}/")

# Prefer CUDA, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# NOTE(review): AccuracyEvaluator and PerplexityEvaluator are not defined or imported in the
# visible part of this notebook — confirm they are defined before this cell runs.
acc_evaluator = AccuracyEvaluator(acc_dataset, acc_tokenizer, device)
perp_evaluator = PerplexityEvaluator(perp_dataset, perp_tokenizer, device, n_samples=15)

Using the latest cached version of the dataset since lambada couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'plain_text' at /state/partition1/user/zzhang1/cache/huggingface/datasets/lambada/plain_text/0.0.0/5953bd97664b64b95754f299b2309ecfbfbe81b9 (last modified on Wed Dec 11 00:28:24 2024).
Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /state/partition1/user/zzhang1/cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Dec 11 00:23:01 2024).


Map:   0%|          | 0/40 [00:00<?, ? examples/s]