<a href="https://colab.research.google.com/github/adithyab100/smoothquant-mixedprecision/blob/main/examples/smoothquant_llama_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SmoothQuant on Llama 2 7B

In this notebook, we use the Llama-2-7B model to demonstrate that SmoothQuant can use 8-bit quantization for both weights and activations while achieving a perplexity similar to that of the FP16 model.

In order to run this notebook, you need to install the following packages:

- smoothquant
- PyTorch
- Transformers
- Accelerate

In [6]:
!git clone https://github.com/adithyab100/smoothquant-mixedprecision.git
%cd smoothquant-mixedprecision
!pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
!pip install transformers==4.36.0 accelerate datasets zstandard
!python setup.py install

Cloning into 'smoothquant-mixedprecision'...
remote: Enumerating objects: 443, done.[K
remote: Counting objects: 100% (247/247), done.[K
remote: Compressing objects: 100% (110/110)[Kcts: 100% (110/110), done.[K
remote: Total 443 (delta 179), reused 157 (delta 137), pack-reused 196 (from 1)[K
Receiving objects: 100% (443/443), 6.90 MiB | 5.62 MiB/s, done.
Resolving deltas: 100% (266/266), done.
/Users/ningshanma/Desktop/fa24/6.5940/smoothquant-mixedprecision/examples/smoothquant-mixedprecision/smoothquant-mixedprecision
Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu113
[31mERROR: Could not find a version that satisfies the requirement torch==1.12.1+cu113 (from versions: 2.2.0, 2.2.1, 2.2.2, 2.3.0, 2.3.1, 2.4.0, 2.4.1, 2.5.0, 2.5.1)[0m[31m
[0m[31mERROR: No matching distribution found for torch==1.12.1+cu113[0m[31m
running install
!!

        ********************************************************************************
        Please avoid 

In [7]:
%reload_ext autoreload
%autoreload 2

import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
)
from transformers import LlamaTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
from smoothquant.smooth import smooth_lm
from smoothquant.fake_quant import quantize_llama_like, quantize_opt, quantize_model
import tqdm
import gc
from functools import partial

The following is an evaluator that measures model quality as perplexity. We use a toy dataset (the first 40 blocks of 2048 tokens from the test set of the Wikitext-2 dataset) to evaluate the model. You can replace it with your own dataset; the conclusion should be the same.

In [8]:
class Evaluator:
    """Perplexity evaluator over fixed-length windows of a tokenized corpus.

    The texts in ``dataset["text"]`` are joined with blank lines and tokenized
    once at construction time; ``evaluate`` then scores consecutive,
    non-overlapping windows of ``seq_len`` tokens.

    Args:
        dataset: Mapping-like object whose ``"text"`` entry is an iterable of strings.
        tokenizer: Callable returning an object with an ``input_ids`` tensor
            that is indexed as ``[:, start:end]``.
        device: Device on which the tokenized corpus is stored.
        n_samples: Maximum number of windows to score.
        seq_len: Number of tokens per window.
    """

    def __init__(self, dataset, tokenizer, device, n_samples=40, seq_len=2048):
        self.tokenizer = tokenizer
        self.device = device

        # Tokenize the whole corpus once; windows are sliced from it later.
        self.dataset = tokenizer(
            "\n\n".join(dataset["text"]), return_tensors="pt"
        ).input_ids.to(device)

        self.n_samples = n_samples
        self.seq_len = seq_len

    @torch.no_grad()
    def evaluate(self, model):
        """Return the perplexity of ``model`` as a 0-dim tensor.

        Only complete windows are scored: if the corpus holds fewer than
        ``n_samples`` full windows, the available ones are used instead of
        feeding empty slices to the model (which would produce NaN).

        Raises:
            ValueError: If not even one full window is available.
        """
        model.eval()

        n_windows = min(self.n_samples, self.dataset.shape[1] // self.seq_len)
        if n_windows <= 0:
            raise ValueError(
                "Not enough tokens for a single evaluation window "
                f"(seq_len={self.seq_len}, n_samples={self.n_samples})."
            )

        loss_fct = nn.CrossEntropyLoss()
        nlls = []
        for i in tqdm.tqdm(range(n_windows), desc="Evaluating..."):
            start = i * self.seq_len
            batch = self.dataset[:, start : start + self.seq_len].to(model.device)
            lm_logits = model(batch).logits

            # Next-token prediction: position t is scored against token t + 1.
            shift_logits = lm_logits[:, :-1, :].contiguous().float()
            # Take the labels from the batch that was fed to the model and put
            # them next to the logits, so the loss never mixes devices.
            shift_labels = batch[:, 1:].to(shift_logits.device)

            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
            )
            # The mean loss is scaled by the window length so that the final
            # division yields an average per-token negative log-likelihood.
            nlls.append(loss.float() * self.seq_len)

        return torch.exp(torch.stack(nlls).sum() / (n_windows * self.seq_len))

def evaluate(model, tokenizer, n_samples=40, seq_len=2048):
    """Compute perplexity of ``model`` on the Wikitext-2 (raw) test split.

    The test texts are joined with blank lines, tokenized once, and scored in
    consecutive, non-overlapping windows of ``seq_len`` tokens.

    Args:
        model: Causal LM exposing ``.device``, ``.eval()`` and a forward call
            that returns an object with ``.logits``.
        tokenizer: Callable returning an object with an ``input_ids`` tensor.
        n_samples: Maximum number of windows to score.
        seq_len: Number of tokens per window.

    Returns:
        0-dim tensor holding the perplexity.

    Raises:
        ValueError: If the tokenized corpus does not contain one full window.
    """
    corpus = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    token_ids = tokenizer("\n\n".join(corpus['text']), return_tensors='pt')
    token_ids = token_ids.input_ids.to(model.device)

    # Score only complete windows; empty slices would turn the result into NaN.
    n_windows = min(n_samples, token_ids.shape[1] // seq_len)
    if n_windows <= 0:
        raise ValueError(
            "Not enough tokens for a single evaluation window "
            f"(seq_len={seq_len}, n_samples={n_samples})."
        )

    model = model.eval()
    loss_fct = nn.CrossEntropyLoss()

    nlls = []
    with torch.no_grad():
        for i in tqdm.tqdm(range(n_windows), desc="evaluating..."):
            start = i * seq_len
            batch = token_ids[:, start:start + seq_len].to(model.device)
            lm_logits = model(batch).logits

            # Next-token prediction: position t is scored against token t + 1.
            shift_logits = lm_logits[:, :-1, :].contiguous().float()
            shift_labels = batch[:, 1:].to(shift_logits.device)

            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
            )
            # Scale the mean loss by the window length; the final division
            # then yields an average per-token negative log-likelihood.
            nlls.append(loss.float() * seq_len)

    return torch.exp(torch.stack(nlls).sum() / (n_windows * seq_len))

def get_model_size(model: nn.Module, data_width=16, group_size=-1):
    """Return the storage size of all model parameters, in bits.

    Args:
        model: Module whose parameters are counted.
        data_width: Bits used per parameter element.
        group_size: Quantization group size; ``-1`` means no grouping. Any
            other value adds ``(16 + 4) / group_size`` bits per element
            (presumably a 16-bit scale plus a 4-bit zero point per group —
            TODO confirm).

    Returns:
        Number of parameter elements multiplied by the effective bit width.
    """
    bits_per_element = data_width
    if group_size != -1:
        bits_per_element = bits_per_element + (16 + 4) / group_size

    total_elements = sum(tensor.numel() for tensor in model.parameters())
    return total_elements * bits_per_element

# Size units expressed in bits, so that a bit count (such as the value
# returned by get_model_size) can be divided by them directly: `size / MiB`.
Byte = 8  # bits per byte
KiB = 1024 * Byte
MiB = 1024 * KiB
GiB = 1024 * MiB

In [9]:
from datasets import load_dataset


# Baseline: load the unquantized OPT checkpoint and measure perplexity and size.
model_path = "facebook/opt-1.3b"
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")

# Evaluate the model
model_perplexity = evaluate(model, tokenizer)
# The baseline is not quantized, so no per-group overhead applies to its size
# (no `group_size`). NOTE(review): 32 bits assumes the checkpoint is loaded in
# float32, since no `torch_dtype` is passed — confirm.
model_size = get_model_size(model, data_width=32)
print(f"\nOriginal model perplexity: {model_perplexity:.2f}")
print(f"model size: {model_size/MiB:.2f} MiB")



KeyboardInterrupt: 

In [5]:
def get_calib_dataset(tokenizer=None, n_samples=256, block_size=512):
    """Build fixed-size calibration blocks from the Wikitext-2 validation split.

    Lines are visited in shuffled order (seed 42), stripped and encoded with
    ``tokenizer.encode``. Lines longer than ``block_size`` tokens and lines
    that encode to nothing are skipped. Collection stops once ``n_samples``
    lines were kept. The kept lines are concatenated along the token axis and
    cut into blocks of exactly ``block_size`` tokens; a trailing remainder is
    dropped.

    Returns:
        List of tensors, each with ``block_size`` columns.
    """
    shuffled = load_dataset(
        'wikitext', 'wikitext-2-raw-v1', split="validation"
    ).shuffle(seed=42)

    kept = []
    for record in shuffled:
        token_ids = tokenizer.encode(record["text"].strip())
        if len(token_ids) > block_size:
            continue
        encoded = torch.tensor([token_ids])
        if encoded.numel() == 0:
            continue
        kept.append(encoded)
        if len(kept) == n_samples:
            break

    # Concatenate everything, then keep only the complete blocks.
    merged = torch.cat(kept, dim=1)
    n_split = merged.shape[1] // block_size
    print(f" * Split into {n_split} blocks")

    blocks = []
    for block_idx in range(n_split):
        begin = block_idx * block_size
        blocks.append(merged[:, begin:begin + block_size])
    return blocks

@torch.no_grad()
def get_calib_feat(model, tokenizer):
    """Collect per-channel input statistics for every ``nn.Linear`` in ``model``.

    A forward hook is attached to each linear layer; the calibration blocks
    from ``get_calib_dataset(tokenizer)`` are then run through the model.

    Args:
        model: Module to profile; it is called as ``model(input_ids)``.
        tokenizer: Passed through to ``get_calib_dataset``.

    Returns:
        Dict mapping each linear layer's module name to a list with one CPU
        tensor per forward call: the mean absolute input value per channel
        (last dimension).
    """
    input_dict = dict()

    def stat_input_max_hook(m, x, y, name):
        if isinstance(x, tuple):
            x = x[0]
        # Despite the "max" in the names, this is the per-channel MEAN of |x|
        # over all leading dimensions.
        x_max = x.view(-1, x.shape[-1]).abs().mean(dim=0).cpu().detach()
        input_dict.setdefault(name, []).append(x_max)

    hooks = [
        m.register_forward_hook(partial(stat_input_max_hook, name=name))
        for name, m in model.named_modules()
        if isinstance(m, nn.Linear)
    ]

    print("Collecting activation scales...")
    # Send inputs to wherever the model's parameters live, rather than
    # guessing from CUDA availability (the two can disagree).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    try:
        samples = get_calib_dataset(tokenizer)
        for input_ids in tqdm.tqdm(samples):
            model(input_ids.to(device))
    finally:
        # Always detach the hooks, even if a forward pass fails, so that the
        # model is not left instrumented.
        for hook in hooks:
            hook.remove()
    return input_dict
# Per-layer activation statistics of the currently loaded `model`, keyed by module name.
input_feat = get_calib_feat(model, tokenizer)

Collecting activation scales...
 * Split into 33 blocks


100%|██████████| 33/33 [00:14<00:00,  2.32it/s]


In [39]:
# Drop any previously loaded or quantized model before reloading, so that its
# memory can be reclaimed (this also makes the cell safe to re-run).
if 'model' in globals():
    del model
if 'model_w4a4' in globals():
    del model_w4a4
gc.collect()
torch.cuda.empty_cache()

# Reload the checkpoint from `model_path` and quantize it with 4-bit
# per-group settings, passing the collected activation statistics.
# NOTE(review): whether quantize_opt modifies `model` in place or returns a
# new module is not visible here — confirm before reusing `model` afterwards.
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
model_w4a4 = quantize_opt(model, weight_quant="per_group", act_quant="per_group", input_feat=input_feat, min_prop=0.03, max_prop=0.5, quant_bits=4)
print("Finished quantizing")

Quantizing OPT model: 270it [00:53,  5.02it/s]

Finished quantizing





In [None]:
# Perplexity and estimated size of the 4-bit quantized model.
ppl_w4a4 = evaluate(model_w4a4,tokenizer)
print(f"\nNaive W4A4 quantized model perplexity: {ppl_w4a4}")
# NOTE(review): data_width=2 looks inconsistent with the quant_bits=4 used to
# build model_w4a4 — confirm the intended bit width for the size estimate.
model_size = get_model_size(model_w4a4, data_width=2, group_size=128)
print(f"model size: {model_size/MiB:.2f} MiB")
# Reference value recorded by the author: with no mitigation applied, the
# perplexity was 32997.125.

evaluating...:  75%|███████▌  | 30/40 [01:13<00:24,  2.43s/it]

In [None]:
@torch.no_grad()
def pseudo_quantize_model_salient_weight_fp16(
    model, w_bit, q_group_size, input_feat
):
    """Quantize every ``nn.Linear`` weight in place, sparing the salient input channels.

    For each linear layer, the input channels are ranked by the summed
    statistics in ``input_feat[name]``. The top 1% (``in_features // 100``)
    keep their original weight values; all other weights are replaced by the
    output of ``pseudo_quantize_tensor``.

    Args:
        model: Module whose linear layers are modified in place.
        w_bit: Bit width forwarded to ``pseudo_quantize_tensor`` as ``n_bit``.
        q_group_size: Group size forwarded to ``pseudo_quantize_tensor``.
        input_feat: Dict mapping module names to lists of per-channel tensors.

    NOTE(review): ``pseudo_quantize_tensor`` is not defined in the visible
    cells of this notebook; it must be defined or imported before this
    function is called.
    """
    for n, m in model.named_modules():
        if isinstance(m, nn.Linear):
            importance = sum(input_feat[n]).float()

            # Step 1: pick the 1% most important input channels.
            outlier_indices = torch.topk(importance, m.weight.data.shape[1]//100)[1]
            assert outlier_indices.dim() == 1

            # Back up the values of the salient weight channels
            outlier = m.weight.data[:, outlier_indices].clone()

            # Quantize the full weight matrix (this assignment was previously
            # commented out, which made the whole cell a SyntaxError).
            m.weight.data = pseudo_quantize_tensor(m.weight.data, n_bit=w_bit, q_group_size=q_group_size)

            # Step 2: restore the salient channels to their original values.
            m.weight.data[:, outlier_indices] = outlier

## FP16 Model Perplexity

Let's first check the performance of the original FP16 model.

In [None]:
# Load the Llama-2-7B checkpoint in half precision as the FP16 baseline.
model_fp16 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Build the perplexity evaluator for the Llama model: it needs the Llama
# tokenizer (the `tokenizer` loaded earlier belongs to the OPT checkpoint) and
# the Wikitext-2 test split. `evaluator` is reused by the following cells.
llama_tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
wikitext_test = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
evaluator = Evaluator(wikitext_test, llama_tokenizer, model_fp16.device)

ppl_fp16 = evaluator.evaluate(model_fp16)
print(f"Original model (fp16) perplexity: {ppl_fp16}")

Evaluating...: 100%|██████████| 40/40 [00:09<00:00,  4.04it/s]


Original model (fp16) perplexity: 5.822948932647705


We then quantize the model to W8A8 and check the performance.

## Naive W8A8 Quantized Model Perplexity

In [None]:
# Quantize the FP16 Llama model loaded above. (`model` refers to the OPT
# checkpoint from the earlier cells, not to the Llama model.)
# NOTE(review): quantize_llama_like may modify its argument in place — confirm
# before reusing `model_fp16` as an unquantized baseline afterwards.
model_w8a8 = quantize_llama_like(model_fp16)
print(model_w8a8)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (k_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (v_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (o_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (up_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=Non

In [None]:
# Perplexity of the naively quantized W8A8 model, measured with the
# notebook-level `evaluator` object.
ppl_w8a8 = evaluator.evaluate(model_w8a8)
print(f"Naive W8A8 quantized model perplexity: {ppl_w8a8}")

Evaluating...: 100%|██████████| 40/40 [00:07<00:00,  5.02it/s]

Naive W8A8 quantized model perplexity: 5.931240558624268





We can see that the perplexity increases relative to the FP16 model (5.93 vs. 5.82). We then use SmoothQuant to quantize the model and check the performance.

## SmoothQuant W8A8 Quantized Model Perplexity

In [None]:
# Reload a fresh FP16 Llama model, smooth it with the precomputed activation
# scales (smoothing strength 0.85), then quantize it.
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
)
# NOTE(review): this path is relative to the current working directory, which
# the setup cell changes with `%cd` — confirm that it resolves.
# torch.load unpickles the file: only load activation scales from a trusted source.
act_scales = torch.load("../act_scales/llama-2-7b.pt")
smooth_lm(model, act_scales, 0.85)
model_smoothquant_w8a8 = quantize_llama_like(model)
print(model_smoothquant_w8a8)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (k_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (v_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (o_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (up_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=Non

We can see that the smoothed model has a lower perplexity, which is close to that of the FP16 model. This is because SmoothQuant smooths the outliers in the activations and balances the quantization difficulty between activations and weights.

In [None]:
# Perplexity of the SmoothQuant W8A8 model, measured with the notebook-level
# `evaluator` object.
ppl_smoothquant_w8a8 = evaluator.evaluate(model_smoothquant_w8a8)
print(f"SmoothQuant W8A8 quantized model perplexity: {ppl_smoothquant_w8a8}")

Evaluating...: 100%|██████████| 40/40 [00:08<00:00,  4.49it/s]

SmoothQuant W8A8 quantized model perplexity: 5.85634183883667





In [5]:
def evaluate_group_sizes(model, tokenizer, group_sizes=None, device="cuda"):
    """Evaluate model perplexity with different quantization group sizes.

    Args:
        model: Model handed to ``quantize_model`` for every group size.
        tokenizer: Tokenizer used for calibration data and for evaluation.
        group_sizes: Iterable of group sizes to test; defaults to
            ``[8, 16, 32, 64, 128]``.
        device: Forwarded to ``quantize_model``.

    Returns:
        ``(group_sizes, perplexities)``: the tested sizes as a new list and
        the matching perplexities as plain Python floats.

    NOTE(review): the same ``model`` object is passed to ``quantize_model``
    for every group size; if that function modifies its argument in place,
    later sizes start from an already-quantized model — confirm.
    """
    # A fresh list per call: a shared mutable default would leak caller
    # mutations into later calls.
    group_sizes = [8, 16, 32, 64, 128] if group_sizes is None else list(group_sizes)
    perplexities = []

    # Get calibration dataset
    calib_dataset = get_calib_dataset(tokenizer)

    for group_size in group_sizes:
        print(f"\nTesting group size: {group_size}")

        # Quantize model with current group size
        model_quantized = quantize_model(
            model,
            calib_dataset=calib_dataset,
            group_size=group_size,
            device=device
        )

        # Evaluate perplexity; store a host-side float so that the values can
        # be plotted and formatted regardless of the tensor's device.
        ppl = float(evaluate(model_quantized, tokenizer))
        perplexities.append(ppl)
        print(f"Group size {group_size} perplexity: {ppl:.2f}")

    return group_sizes, perplexities

def plot_group_size_results(group_sizes, perplexities, baseline_ppl=None):
    """Draw perplexity against group size on log2/log2 axes and show the figure.

    Args:
        group_sizes: X values (quantization group sizes).
        perplexities: Y values, one per group size.
        baseline_ppl: Optional reference perplexity, drawn as a dashed red
            horizontal line when given.

    NOTE(review): `plt` is presumably `matplotlib.pyplot`, but no import of it
    is visible in this notebook — confirm it is in scope before calling.
    """
    plt.figure(figsize=(10, 6))

    # Data series first, so that the legend picks up both labels.
    plt.plot(group_sizes, perplexities, marker='o', label='Group Quantization')
    has_baseline = baseline_ppl is not None
    if has_baseline:
        plt.axhline(y=baseline_ppl, color='r', linestyle='--', label='FP16 Baseline')

    # Text decorations.
    plt.title('Impact of Group Size on Model Perplexity')
    plt.xlabel('Group Size')
    plt.ylabel('Perplexity')

    # Both axes use a base-2 logarithmic scale.
    for set_scale in (plt.xscale, plt.yscale):
        set_scale('log', base=2)

    plt.legend()
    plt.grid(True)
    plt.show()

# Run group size analysis
print("Running group size analysis...")
group_sizes = [8, 16, 32, 64, 128]
# `evaluate` returns a tensor that may live on the GPU; convert it to a plain
# float so that it can be drawn as the baseline line.
baseline_ppl = float(evaluate(model, tokenizer))  # Get baseline perplexity
group_sizes, perplexities = evaluate_group_sizes(model, tokenizer, group_sizes=group_sizes)
# Same conversion for the per-group-size results before plotting.
perplexities = [float(ppl) for ppl in perplexities]

# Plot results
plot_group_size_results(group_sizes, perplexities, baseline_ppl)

Running group size analysis...


NameError: name 'evaluate' is not defined

In [None]:
%reload_ext autoreload
%autoreload 2

import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
)
from transformers import LlamaTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
from smoothquant.smooth import smooth_lm
from smoothquant.fake_quant import quantize_llama_like, quantize_opt
import tqdm
import gc
from functools import partial

ModuleNotFoundError: No module named 'smoothquant.smooth'

In [None]:
%reload_ext autoreload
%autoreload 2

import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
)
from transformers import LlamaTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
from smoothquant.smooth import smooth_lm
from smoothquant.fake_quant import quantize_llama_like, quantize_opt
import tqdm
import gc
from functools import partial

ModuleNotFoundError: No module named 'smoothquant.smooth'

In [None]:
%reload_ext autoreload
%autoreload 2

import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
)
from transformers import LlamaTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
from smoothquant.smooth import smooth_lm
from smoothquant.fake_quant import quantize_llama_like, quantize_opt
import tqdm
import gc
from functools import partial

ModuleNotFoundError: No module named 'smoothquant.smooth'