# AWQ on Vicuna

In this notebook, we use the Vicuna model to demonstrate the performance of AWQ on instruction-tuned models. We implement AWQ real-INT4 inference kernels, which are wrapped as PyTorch modules and can be dropped into existing models. We also provide a simple example showing how to quantize a model with AWQ and how to save and load the quantized checkpoint.

To run this notebook, install the following packages:
- [AWQ](https://github.com/mit-han-lab/llm-awq)
- [PyTorch](https://pytorch.org/)
- [Accelerate](https://github.com/huggingface/accelerate)
- [Transformers](https://github.com/huggingface/transformers)

In [1]:
import sys
import os

# Current notebook path: autodl-tmp/llm-awq/examples
project_root = os.path.abspath("../")
if project_root not in sys.path:
    sys.path.insert(0, project_root)
import awq

In [14]:
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from awq.quantize.quantizer import real_quantize_model_weight
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from tinychat.demo import gen_params, stream_output
from tinychat.stream_generators import StreamGenerator
from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
from tinychat.utils.prompt_templates import get_prompter
import os
# This demo only supports a single GPU for now
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

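Download the Vicuna model by following the instructions in [FastChat](https://github.com/lm-sys/FastChat), then run the following command to generate a quantized model checkpoint. The `--load_awq` argument expects previously computed AWQ results at the given path.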

```bash
mkdir quant_cache
python -m awq.entry --model_path [vicuna-7b_model_path] \
    --w_bit 4 --q_group_size 128 \
    --load_awq awq_cache/vicuna-7b-w4-g128.pt \
    --q_backend real --dump_quant quant_cache/vicuna-7b-w4-g128-awq.pt
```
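
The `--w_bit 4 --q_group_size 128` flags request 4-bit weights with one scale and one zero point per group of 128 weights. As a rough illustration of what group-wise zero-point quantization does to a row of weights, here is a minimal sketch in plain Python. It is not the AWQ kernel and it omits AWQ's activation-aware scaling; the function names are hypothetical, and it assumes each group contains both negative and positive values.

```python
def quantize_group(weights, w_bit=4):
    """Zero-point quantization of one group of floats.

    Returns (codes, scale, zero); every code is an int in [0, 2**w_bit - 1].
    """
    max_int = 2 ** w_bit - 1
    lo, hi = min(weights), max(weights)
    scale = max(hi - lo, 1e-5) / max_int             # step between adjacent levels
    zero = min(max(-round(lo / scale), 0), max_int)  # integer code that represents 0.0
    codes = [min(max(round(w / scale) + zero, 0), max_int) for w in weights]
    return codes, scale, zero


def dequantize_group(codes, scale, zero):
    """Map integer codes back to approximate float weights."""
    return [(c - zero) * scale for c in codes]


def quantize_row(row, group_size=128, w_bit=4):
    """Quantize a row group by group; each group gets its own scale and zero."""
    return [
        quantize_group(row[start:start + group_size], w_bit)
        for start in range(0, len(row), group_size)
    ]


# Deterministic toy row with values in [-1, 1] (illustrative data only).
row = [((i * 37) % 101 - 50) / 50.0 for i in range(256)]
groups = quantize_row(row, group_size=128, w_bit=4)
for index, (codes, scale, zero) in enumerate(groups):
    restored = dequantize_group(codes, scale, zero)
    original = row[index * 128:(index + 1) * 128]
    worst = max(abs(a - b) for a, b in zip(original, restored))
    print(f"group {index}: scale={scale:.4f} zero={zero} max_error={worst:.4f}")
```

Smaller groups track the local weight range more closely at the cost of storing more scales and zero points.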

In [3]:
# Placeholder paths from the original Vicuna demo:
# model_path = ""  # the path of the vicuna-7b model
# load_quant_path = "quant_cache/vicuna-7b-w4-g128-awq.pt"
# Paths used in this run (a Llama-2-7B checkpoint):
model_path = "/root/autodl-tmp/models--meta-llama--Llama-2-7b-hf"
load_quant_path = "/root/autodl-tmp/quant_cache/llama2-7b-w4-g128-awq-v2.pt"

We first create an empty model and replace all of its linear layers with WQLinear layers. Then we load the quantized weights from the checkpoint.

In [4]:
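# Load the config and tokenizer from the original (unquantized) model directory.
config = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
# Create the model under accelerate's init_empty_weights context.
# Note: AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
# would build the same architecture without reading the fp16 checkpoint shards.
with init_empty_weights():
    model = AutoModelForCausalLM.from_pretrained(
        model_path, config=config, torch_dtype=torch.float16
    )
# 4-bit weights with one scale and zero point per group of 128 weights.
q_config = {"zero_point": True, "q_group_size": 128}
# init_only=True swaps in WQLinear layers without quantizing any weights.
real_quantize_model_weight(
    model, w_bit=4, q_config=q_config, init_only=True)

# Load the quantized checkpoint and dispatch the layers to the available device(s).
model = load_checkpoint_and_dispatch(
    model, load_quant_path,
    device_map="auto",
    no_split_module_classes=["LlamaDecoderLayer"]
)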

Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.74s/it]
real weight quantization...(init only): 100%|██████████| 32/32 [00:00<00:00, 1112.06it/s]


In [5]:
make_quant_attn(model, "cuda:0")
make_quant_norm(model)
make_fused_mlp(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): QuantLlamaAttention(
          (qkv_proj): WQLinear(in_features=4096, out_features=12288, bias=False, w_bit=4, group_size=128)
          (o_proj): WQLinear(in_features=4096, out_features=4096, bias=False, w_bit=4, group_size=128)
          (rotary_emb): QuantLlamaRotaryEmbedding()
        )
        (mlp): QuantLlamaMLP(
          (down_proj): WQLinear(in_features=11008, out_features=4096, bias=False, w_bit=4, group_size=128)
        )
        (input_layernorm): FTLlamaRMSNorm()
        (post_attention_layernorm): FTLlamaRMSNorm()
      )
    )
    (norm): FTLlamaRMSNorm()
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
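
To get a feel for what the `w_bit=4, group_size=128` layers in this printout save, the following back-of-the-envelope sketch compares dense fp16 storage with packed 4-bit storage for the three `WQLinear` shapes shown. The storage model is an assumption (one fp16 scale and one 4-bit zero point per group); the actual kernel layout may differ, and the helper names are hypothetical.

```python
def fp16_bytes(in_features, out_features):
    """Bytes for a dense fp16 weight matrix (2 bytes per weight)."""
    return in_features * out_features * 2


def int4_grouped_bytes(in_features, out_features, group_size=128, w_bit=4):
    """Bytes for packed low-bit weights plus per-group metadata.

    Assumption: one fp16 scale and one w_bit-wide zero point per group.
    """
    n_weights = in_features * out_features
    n_groups = n_weights // group_size
    weight_bytes = n_weights * w_bit // 8
    scale_bytes = n_groups * 2
    zero_bytes = n_groups * w_bit // 8
    return weight_bytes + scale_bytes + zero_bytes


# (in_features, out_features) copied from the WQLinear entries printed above.
shapes = {
    "qkv_proj": (4096, 12288),
    "o_proj": (4096, 4096),
    "down_proj": (11008, 4096),
}

for name, (n_in, n_out) in shapes.items():
    dense = fp16_bytes(n_in, n_out)
    packed = int4_grouped_bytes(n_in, n_out)
    print(f"{name}: {dense / 2**20:.1f} MiB fp16 -> {packed / 2**20:.1f} MiB packed "
          f"({dense / packed:.2f}x smaller)")
```

Under these assumptions each weight costs 4 + 16/128 + 4/128 = 4.15625 bits, so every layer shrinks by the same factor of roughly 3.85 relative to fp16.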

In [17]:
from tinychat.demo import gen_params, stream_output
print(gen_params)

{'seed': -1, 'n_threads': 1, 'n_predict': 512, 'n_parts': -1, 'n_ctx': 512, 'n_batch': 512, 'n_keep': 0, 'n_vocab': 50272, 'logit_bias': {}, 'top_k': 40, 'top_p': 0.95, 'tfs_z': 1.0, 'typical_p': 1.0, 'temp': 0.7, 'repeat_penalty': 1.1, 'repeat_last_n': 64, 'frequency_penalty': 0.0, 'presence_penalty': 0.0, 'mirostat': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
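
The `temp`, `top_k`, and `top_p` entries above control how the next token is sampled. As a hedged illustration of how such parameters are commonly combined, the sketch below applies them to a small list of logits in plain Python. It is a simplification rather than tinychat's or Transformers' implementation (libraries differ in the order and renormalization of these filters), and the function name is hypothetical.

```python
import math


def sampling_distribution(logits, temp=0.7, top_k=40, top_p=0.95):
    """Return {token_id: probability} after temperature, top-k and top-p filtering."""
    # Temperature: divide logits before the softmax; values below 1 sharpen it.
    scaled = [x / temp for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]  # subtract the max for stability
    total = sum(exps)
    probs = [e / total for e in exps]

    # Top-k: keep only the k most likely tokens.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]

    # Top-p: keep the shortest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for token_id in ranked:
        kept.append(token_id)
        cumulative += probs[token_id]
        if cumulative >= top_p:
            break

    # Renormalize over the surviving tokens.
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}


# Toy logits for a four-token vocabulary (illustrative values only).
print(sampling_distribution([2.0, 1.0, 0.0, -5.0], temp=1.0, top_k=3, top_p=0.9))
```

A token would then be drawn from the returned distribution, for example with `random.choices`.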


In [18]:

model_prompter = get_prompter("llama", model_path)
stream_generator = StreamGenerator
count = 0
while True:
    # Get input from the user
    input_prompt = input("USER: ")
    if input_prompt == "":
        print("EXIT...")
        break
    model_prompter.insert_prompt(input_prompt)
    
    # Earlier call without the extra positional argument, kept for reference:
    # output_stream = stream_generator(model, tokenizer, model_prompter.model_input, gen_params, device="cuda:0")
    output_stream = stream_generator(model, tokenizer, model_prompter.model_input, 0, gen_params, device="cuda:0")
    outputs = stream_output(output_stream)
    model_prompter.update_template(outputs)
    count += 1

USER:  hi


ASSISTANT: 

TypeError: forward() got an unexpected keyword argument 'cache_position'
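
The error indicates that some caller passes a `cache_position` keyword to a `forward` method that does not declare it. One plausible reading is a version mismatch between the installed Transformers release and the replacement modules, but the truncated traceback does not show which module is involved, so treat that as a hypothesis. The sketch below uses only the standard library to check whether a callable accepts a given keyword and to wrap one so that unknown keywords are discarded. The two `forward` functions are stand-ins, not real model code, and the helper names are hypothetical.

```python
import functools
import inspect

_NAMED = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)


def accepts_kwarg(fn, name):
    """True if fn can be called with keyword `name`, directly or via **kwargs."""
    params = inspect.signature(fn).parameters
    if name in params and params[name].kind in _NAMED:
        return True
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


def drop_unsupported_kwargs(fn):
    """Wrap fn so that keyword arguments it does not declare are discarded."""
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return fn  # fn already tolerates arbitrary keywords
    allowed = {n for n, p in params.items() if p.kind in _NAMED}

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **{k: v for k, v in kwargs.items() if k in allowed})

    return wrapper


# Stand-in functions that mimic an older and a newer forward signature.
def old_forward(hidden_states, attention_mask=None):
    return hidden_states


def new_forward(hidden_states, attention_mask=None, cache_position=None, **kwargs):
    return hidden_states


print(accepts_kwarg(old_forward, "cache_position"))
print(accepts_kwarg(new_forward, "cache_position"))
patched = drop_unsupported_kwargs(old_forward)
print(patched("h", cache_position=0))
```

Whether discarding `cache_position` is safe depends on the module that ignores it. Installing the Transformers version that the AWQ repository expects is likely the cleaner fix.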

In [12]:
from transformers import TextStreamer
import torch

model_prompter = get_prompter("llama", model_path)
count = 0

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

while True:
    input_prompt = input("USER: ")
    if input_prompt == "":
        print("EXIT...")
        break

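    # Build the prompt
    model_prompter.insert_prompt(input_prompt)
    prompt = model_prompter.model_input

    # Tokenize the prompt
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    # Generate directly with Hugging Face `generate`
    outputs = model.generate(
        **inputs,
        max_new_tokens=gen_params.n_predict,
        do_sample=True,           # temperature/top_p only take effect when sampling
        temperature=gen_params.temp,
        top_p=gen_params.top_p,
        streamer=streamer,        # stream tokens as they are generated
    )

    # Decode only the newly generated tokens (outputs[0] also contains the prompt)
    prompt_len = inputs["input_ids"].shape[1]
    final_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # Update the conversation template with the reply
    model_prompter.update_template(final_text)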
    count += 1


USER:  In the context of model compression, what is quantization and what is its purpose?


TypeError: forward() got an unexpected keyword argument 'cache_position'