In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from typing import Literal
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from functools import partial
import gc
import awq
from awq.quantize.quantizer import (
    real_quantize_model_weight,
    pseudo_quantize_model_weight,
    pseudo_quantize_tensor,
)
from awq.quantize.wnan_salient import *

In [3]:
def evaluate(model, tokenizer, nsamples=40, seqlen=2048):
    """Compute the perplexity of ``model`` on the WikiText-2 raw test split.

    The test texts are joined with blank lines, tokenized as one stream, and
    the first ``nsamples`` consecutive windows of ``seqlen`` tokens are scored
    with next-token cross-entropy.

    Args:
        model: Causal LM that exposes ``.device`` and ``.eval()`` and is
            called as ``model(batch).logits``.
        tokenizer: Called as ``tokenizer(text, return_tensors="pt")``; the
            result must provide ``input_ids``.
        nsamples: Number of windows to score (default 40).
        seqlen: Number of tokens per window (default 2048).

    Returns:
        A 0-dim tensor holding ``exp(sum(nll) / (nsamples * seqlen))``.

    Raises:
        ValueError: If the tokenized corpus holds fewer than
            ``nsamples * seqlen`` tokens, which would otherwise yield short
            or empty windows and a skewed result.
    """
    testenc = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    testenc = tokenizer("\n\n".join(testenc["text"]), return_tensors="pt")
    testenc = testenc.input_ids.to(model.device)

    if testenc.size(1) < nsamples * seqlen:
        raise ValueError(
            f"need {nsamples * seqlen} tokens for {nsamples} windows of "
            f"{seqlen}, but the corpus only has {testenc.size(1)}"
        )

    model = model.eval()
    # Stateless, so a single instance can be shared by every window.
    loss_fct = nn.CrossEntropyLoss()

    nlls = []
    for i in tqdm.tqdm(range(nsamples), desc="evaluating..."):
        # testenc already lives on model.device, so the slice needs no move.
        batch = testenc[:, i * seqlen : (i + 1) * seqlen]
        with torch.no_grad():
            lm_logits = model(batch).logits
        # Position t predicts token t + 1: drop the last logit and the first label.
        shift_logits = lm_logits[:, :-1, :].contiguous().float()
        shift_labels = batch[:, 1:]
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.reshape(-1)
        )
        # The mean loss is scaled by the window length so that the final
        # division by nsamples * seqlen averages the per-window losses.
        nlls.append(loss.float() * seqlen)

    return torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen))

def get_model_size(model: nn.Module, data_width=16, group_size=-1):
    """Return the size of ``model`` in bits.

    Args:
        model: Module whose parameters are counted.
        data_width: Bits charged per parameter element.
        group_size: When not ``-1``, ``(16 + 4) / group_size`` extra bits are
            charged per element (presumably a 16-bit scale plus a 4-bit zero
            point shared by each group -- confirm against the quantizer).

    Returns:
        Total element count multiplied by the effective bits per element.
    """
    bits_per_element = data_width
    if group_size != -1:
        bits_per_element = bits_per_element + (16 + 4) / group_size

    total_elements = sum(p.numel() for p in model.parameters())
    return total_elements * bits_per_element


# Size units expressed in bits: dividing a bit count (such as the value
# returned by get_model_size) by one of these converts it to that unit.
Byte = 8
KiB = Byte * 1024
MiB = KiB * 1024
GiB = MiB * 1024

@torch.no_grad()
def pseudo_quantize_model_salient_weight_fp16(
    model, w_bit, q_group_size, input_feat
):
    """Pseudo-quantize all ``nn.Linear`` weights but keep the top 1% input channels intact.

    For every linear layer, the per-channel importance is the sum of the
    statistics stored under the layer's name in ``input_feat``. The weight
    columns of the 1% most important channels are saved, the whole weight is
    passed through ``pseudo_quantize_tensor``, and the saved columns are then
    written back unchanged. The model is modified in place; nothing is returned.

    Args:
        model: Module whose ``nn.Linear`` sub-modules are rewritten.
        w_bit: Bit-width forwarded to ``pseudo_quantize_tensor`` as ``n_bit``.
        q_group_size: Group size forwarded to ``pseudo_quantize_tensor``.
        input_feat: Mapping from module name to a list of per-channel tensors.
    """
    for layer_name, layer in model.named_modules():
        if not isinstance(layer, nn.Linear):
            continue

        channel_importance = sum(input_feat[layer_name]).float()

        # Step 1: pick the 1% most important input channels.
        n_salient = int(channel_importance.shape[0] / 100)
        _, salient_idx = torch.topk(channel_importance, k=n_salient, dim=0)
        assert salient_idx.dim() == 1

        # Keep a copy of the salient columns before the weight is overwritten.
        preserved_columns = layer.weight.data[:, salient_idx].clone()

        layer.weight.data = pseudo_quantize_tensor(
            layer.weight.data, n_bit=w_bit, q_group_size=q_group_size
        )

        # Step 2: put the salient columns back at their original values.
        layer.weight.data[:, salient_idx] = preserved_columns

def get_calib_dataset(tokenizer=None, n_samples=256, block_size=512):
    """Build fixed-length calibration blocks from the pile validation backup.

    The dataset is shuffled with seed 42. Texts are stripped and encoded;
    encodings longer than ``block_size`` tokens or with no tokens are skipped.
    Collection stops once ``n_samples`` encodings have been kept. The kept
    encodings are concatenated along the token axis and cut into consecutive
    blocks of ``block_size`` tokens; a trailing remainder is dropped.

    Args:
        tokenizer: Object providing ``encode(text)``. The ``None`` default is
            not usable as-is because ``encode`` is called on it.
        n_samples: Number of encoded texts to collect before stopping.
        block_size: Token length of every returned block.

    Returns:
        List of tensors of shape ``(1, block_size)``.
    """
    corpus = load_dataset("mit-han-lab/pile-val-backup", split="validation")
    corpus = corpus.shuffle(seed=42)

    kept = []
    for record in corpus:
        token_ids = tokenizer.encode(record["text"].strip())
        if len(token_ids) > block_size:
            continue
        encoded = torch.tensor([token_ids])
        if encoded.numel() == 0:
            continue
        kept.append(encoded)
        if len(kept) == n_samples:
            break

    # now concatenate all samples and split according to block size
    merged = torch.cat(kept, dim=1)
    n_split = merged.shape[1] // block_size
    print(f" * Split into {n_split} blocks")
    block_starts = range(0, n_split * block_size, block_size)
    return [merged[:, start : start + block_size] for start in block_starts]

@torch.no_grad()
def get_calib_feat(model, tokenizer):
    """Collect per-channel mean absolute input activations of every ``nn.Linear``.

    A forward hook is attached to each linear layer, the calibration blocks
    from ``get_calib_dataset(tokenizer)`` are run through the model, and for
    every forward call the hook stores the mean of ``|x|`` over all leading
    dimensions (one value per channel of the last dimension) on the CPU.

    Args:
        model: Module to profile; called as ``model(input_ids)``.
        tokenizer: Forwarded to ``get_calib_dataset``.

    Returns:
        Dict mapping module name to a list of 1-D CPU tensors, one entry per
        forward call of that module.
    """
    input_dict = dict()

    def stat_input_max_hook(m, x, y, name):
        if isinstance(x, tuple):
            x = x[0]
        # reshape (rather than view) so non-contiguous inputs are accepted too.
        x_mean = x.reshape(-1, x.shape[-1]).abs().mean(dim=0).cpu().detach()
        input_dict.setdefault(name, []).append(x_mean)

    hooks = []
    try:
        for name, m in model.named_modules():
            if isinstance(m, nn.Linear):
                hooks.append(
                    m.register_forward_hook(
                        partial(stat_input_max_hook, name=name)))

        print("Collecting activation scales...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        samples = get_calib_dataset(tokenizer)
        for input_ids in tqdm.tqdm(samples):
            model(input_ids.to(device))
    finally:
        # Always detach the hooks: if loading or a forward pass fails they
        # would otherwise stay on the model and keep recording.
        for hook in hooks:
            hook.remove()
    return input_dict

@torch.no_grad()
def get_calib_act(model, tokenizer):
    """Collect the absolute input activations of every ``nn.Linear`` layer.

    A forward hook is attached to each linear layer, the calibration blocks
    from ``get_calib_dataset(tokenizer)`` are run through the model, and for
    every forward call the hook stores ``|x|`` flattened to two dimensions
    (all leading dimensions merged, last dimension kept) on the CPU. No
    reduction is applied, so the result grows with every token processed.

    Args:
        model: Module to profile; called as ``model(input_ids)``.
        tokenizer: Forwarded to ``get_calib_dataset``.

    Returns:
        Dict mapping module name to a list of 2-D CPU tensors, one entry per
        forward call of that module.
    """
    input_dict = dict()

    def stat_input_max_hook(m, x, y, name):
        if isinstance(x, tuple):
            x = x[0]
        # reshape (rather than view) so non-contiguous inputs are accepted too.
        x_abs = x.reshape(-1, x.shape[-1]).abs().cpu().detach()
        input_dict.setdefault(name, []).append(x_abs)

    hooks = []
    try:
        for name, m in model.named_modules():
            if isinstance(m, nn.Linear):
                hooks.append(
                    m.register_forward_hook(
                        partial(stat_input_max_hook, name=name)))

        print("Collecting activation scales...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        samples = get_calib_dataset(tokenizer)
        for input_ids in tqdm.tqdm(samples):
            model(input_ids.to(device))
    finally:
        # Always detach the hooks: if loading or a forward pass fails they
        # would otherwise stay on the model and keep recording.
        for hook in hooks:
            hook.remove()
    return input_dict

# core quantization method (simulated quantization)
def pseudo_quantize_tensor_wo_group(w, n_bit=4, q_group_size=-1):
    """Simulate asymmetric min/max integer quantization of ``w``.

    Values are quantized to ``n_bit`` unsigned integers with one scale and one
    zero point per row, then immediately dequantized, so the result has the
    same shape as the input but only takes representable values.

    Args:
        w: Tensor to quantize. Must be 2-D when ``q_group_size <= 0``.
        n_bit: Number of bits of the simulated integer format.
        q_group_size: When positive, ``w`` is reshaped to rows of this many
            elements (the last dimension must be divisible by it) and each row
            is quantized independently. When not positive, every row of the
            2-D input is quantized as a whole.

    Returns:
        The quantize-dequantize result, reshaped to the shape of ``w``.

    Raises:
        AssertionError: If a shape requirement is violated or NaNs are found.
    """
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)

    assert w.dim() == 2
    # Shape every intermediate must keep. Checking against this instead of
    # q_group_size keeps the ungrouped mode (q_group_size=-1) working.
    grouped_shape = w.shape

    # Calculate the maximum (\alpha) and minimum values (\beta) in the tensor.
    max_val = w.amax(dim=1, keepdim=True)
    assert max_val.shape == (w.size(0), 1)
    min_val = w.amin(dim=1, keepdim=True)
    assert min_val.shape == (w.size(0), 1)

    # Calculate the scale factor and zero point.  (Formula 1 & 2)
    max_int = 2 ** n_bit - 1
    # The clamp avoids a zero scale for rows whose values are all equal.
    scales = (max_val - min_val).clamp(min=1e-5) / max_int
    assert scales.shape == max_val.shape
    zeros = (-torch.round(min_val / scales)).clamp_(0, max_int)
    assert zeros.shape == min_val.shape

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    # Quantize W: Map values in the range [\beta, \alpha] to lie within [0, 2^b - 1] (Formula 3)
    w = torch.clamp(torch.round(w / scales) + zeros, 0, max_int)
    assert w.shape == grouped_shape

    # Dequantize W (pseudo quantization, the inverse transformation of Formula 3)
    w = (w - zeros) * scales
    assert w.shape == grouped_shape

    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)
    return w

@torch.no_grad()
def pseudo_quantize_model_weight(
    model, w_bit, q_group_size,
):
    """Pseudo-quantize the weight of every ``nn.Linear`` in ``model`` in place.

    Each linear layer's weight is replaced by the output of
    ``pseudo_quantize_tensor_wo_group``. Nothing is returned.

    Note: this definition rebinds the name ``pseudo_quantize_model_weight``
    that the file also imports from ``awq.quantize.quantizer``.

    Args:
        model: Module whose linear layers are rewritten.
        w_bit: Bit-width passed on as ``n_bit``.
        q_group_size: Group size passed on unchanged.
    """
    for layer in model.modules():
        if not isinstance(layer, nn.Linear):
            continue
        layer.weight.data = pseudo_quantize_tensor_wo_group(
            layer.weight.data, n_bit=w_bit, q_group_size=q_group_size
        )

In [None]:
# Results for Table 2 from final report
#
# For every model: collect calibration features once, then evaluate the
# 16-bit baseline followed by each quantized configuration listed below.
# All configurations in this table use per-channel activation quantization.

model_paths = ["facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b"]

# (printed label, quantizer, quantizer takes the calibration features?,
#  bit-width used for both weights and activations)
table2_experiments = [
    ("Naive W4A4", quantize_opt, False, 4),
    ("W4A4 with 1% protected salient weights AND activations",
     quantize_opt_salient_weight_act_fp16, True, 4),
    ("Naive W8A8", quantize_opt, False, 8),
    ("W8A8 with 1% protected salient weights AND activations",
     quantize_opt_salient_weight_act_fp16, True, 8),
]

for model_path in model_paths:
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
    input_feats = get_calib_feat(model, tokenizer)       # compute input features from calibration dataset
    del model
    gc.collect()
    torch.cuda.empty_cache()

    # FP16 baseline.
    # NOTE(review): this row is produced by 16-bit group-wise pseudo
    # quantization (w_bit=16, group size 128), not by a cast to float16 --
    # confirm that this is the intended baseline.
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
    pseudo_quantize_model_weight(model, w_bit=16, q_group_size=128)
    model_perplexity = evaluate(model, tokenizer)
    model_size = get_model_size(model, data_width=16, group_size=128)
    print(f"\nmodel perplexity (FP16, {model_path}): {model_perplexity:.5f}")
    print(f"model size: {model_size/MiB:.5f} MiB")
    del model
    gc.collect()
    torch.cuda.empty_cache()

    for label, quantize_fn, needs_feats, n_bits in table2_experiments:
        # Reload a fresh copy so that every configuration starts from the
        # original weights.
        model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
        if needs_feats:
            model = quantize_fn(model, input_feats, w_n_bits=n_bits, a_n_bits=n_bits, act_quant="per_channel")
        else:
            model = quantize_fn(model, w_n_bits=n_bits, a_n_bits=n_bits, act_quant="per_channel")
        model.cuda()
        model_perplexity = evaluate(model, tokenizer)
        model_size = get_model_size(model, data_width=n_bits, group_size=128)
        print(f"\nmodel perplexity ({label}, {model_path}): {model_perplexity:.5f}")
        print(f"model size: {model_size/MiB:.5f} MiB")
        # Drop the only reference before the next load to release GPU memory.
        del model
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
# Results for Table 5 from final report
#
# For every model: collect calibration features once, then evaluate the
# 16-bit baseline followed by every (setting, variant) combination below.

model_paths = ["facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b"]

# (bit-width used for both weights and activations, activation-quantization
# mode), in the order the rows are reported.
table5_settings = [(4, "per_channel"), (4, "per_token"), (8, "per_token")]

# (label template, quantizer, quantizer takes the calibration features?)
table5_variants = [
    ("Naive W{b}A{b}", quantize_opt, False),
    ("W{b}A{b} with 1% protected salient weights",
     quantize_opt_salient_weight_fp16, True),
    ("W{b}A{b} with 1% protected salient activations",
     quantize_opt_salient_act_fp16, True),
    ("W{b}A{b} with 1% protected salient weights AND activations",
     quantize_opt_salient_weight_act_fp16, True),
]

for model_path in model_paths:
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
    input_feats = get_calib_feat(model, tokenizer)       # compute input features from calibration dataset
    del model
    gc.collect()
    torch.cuda.empty_cache()

    # FP16 baseline.
    # NOTE(review): this row is produced by 16-bit group-wise pseudo
    # quantization (w_bit=16, group size 128), not by a cast to float16 --
    # confirm that this is the intended baseline.
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
    pseudo_quantize_model_weight(model, w_bit=16, q_group_size=128)
    model_perplexity = evaluate(model, tokenizer)
    model_size = get_model_size(model, data_width=16, group_size=128)
    print(f"\nmodel perplexity (FP16, {model_path}): {model_perplexity:.5f}")
    print(f"model size: {model_size/MiB:.5f} MiB")
    del model
    gc.collect()
    torch.cuda.empty_cache()

    for n_bits, act_quant in table5_settings:
        # The mode is part of the printed label because the per-channel and
        # per-token runs otherwise print identical lines.
        mode = act_quant.replace("_", "-")
        for label_template, quantize_fn, needs_feats in table5_variants:
            label = label_template.format(b=n_bits)
            # Reload a fresh copy so that every configuration starts from the
            # original weights.
            model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
            if needs_feats:
                model = quantize_fn(model, input_feats, w_n_bits=n_bits, a_n_bits=n_bits, act_quant=act_quant)
            else:
                model = quantize_fn(model, w_n_bits=n_bits, a_n_bits=n_bits, act_quant=act_quant)
            model.cuda()
            model_perplexity = evaluate(model, tokenizer)
            model_size = get_model_size(model, data_width=n_bits, group_size=128)
            print(f"\nmodel perplexity ({label}, {mode} act. quant, {model_path}): {model_perplexity:.5f}")
            print(f"model size: {model_size/MiB:.5f} MiB")
            # Drop the only reference before the next load to release GPU memory.
            del model
            gc.collect()
            torch.cuda.empty_cache()