In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
import awq

In [4]:
import tqdm
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from functools import partial
import gc
from awq.quantize.quantizer import (
    real_quantize_model_weight,
    pseudo_quantize_model_weight,
    pseudo_quantize_tensor,
)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def evaluate(model, tokenizer):
    testenc = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    testenc = tokenizer("\n\n".join(testenc["text"]), return_tensors="pt")

    testenc = testenc.input_ids.to(model.device)
    nsamples = 40
    model = model.eval()

    nlls = []
    for i in tqdm.tqdm(range(nsamples), desc="evaluating..."):
        batch = testenc[:, (i * 2048) : ((i + 1) * 2048)].to(model.device)
        with torch.no_grad():
            lm_logits = model(batch).logits
        shift_logits = lm_logits[:, :-1, :].contiguous().float()
        shift_labels = testenc[:, (i * 2048) : ((i + 1) * 2048)][:, 1:]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
        neg_log_likelihood = loss.float() * 2048
        nlls.append(neg_log_likelihood)

    return torch.exp(torch.stack(nlls).sum() / (nsamples * 2048))

def get_model_size(model: nn.Module, data_width=16, group_size=-1):

    if group_size != -1:
        data_width += (16 + 4) / group_size

    num_elements = 0
    for param in model.parameters():
        num_elements += param.numel()
    return num_elements * data_width


Byte = 8
KiB = 1024 * Byte
MiB = 1024 * KiB
GiB = 1024 * MiB

In [6]:
@torch.no_grad()
def pseudo_quantize_model_salient_weight_fp16(
    model, w_bit, q_group_size, input_feat
):
    for n, m in model.named_modules():
        if isinstance(m, nn.Linear):
            importance = sum(input_feat[n]).float()

            """
            for i in range(len(input_feat[n])):
                print("input_feat[n][i].shape: {}".format(input_feat[n][i].shape))
            print("name = {}, importance.shape: {}".format(n, importance.shape))
            """

            ############### YOUR CODE STARTS HERE ###############

            # Step 1: Find 1% of the salient weight channels according to importance (hint: use torch.topk())
            outlier_indices = torch.topk(importance, k=int(importance.shape[0]/100), dim=0)[1]
            assert outlier_indices.dim() == 1

            ############### YOUR CODE ENDS HERE #################

            # Back up the values of the salient weight channels
            outlier = m.weight.data[:, outlier_indices].clone()

            m.weight.data = pseudo_quantize_tensor(m.weight.data, n_bit=w_bit, q_group_size=q_group_size)

            ############### YOUR CODE STARTS HERE ###############

            
            # Step 2: Restore the 1% salient weight channels to their original FP16 values
            m.weight.data[:, outlier_indices] = outlier

            ############### YOUR CODE ENDS HERE #################

def get_calib_dataset(tokenizer=None, n_samples=256, block_size=512):
    dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
    dataset = dataset.shuffle(seed=42)
    samples = []
    n_run = 0
    for data in dataset:
        line = data["text"]
        line = line.strip()
        line_encoded = tokenizer.encode(line)
        if len(line_encoded) > block_size:
            continue
        sample = torch.tensor([line_encoded])
        if sample.numel() == 0:
            continue
        samples.append(sample)
        n_run += 1
        if n_run == n_samples:
            break

    # now concatenate all samples and split according to block size
    cat_samples = torch.cat(samples, dim=1)
    n_split = cat_samples.shape[1] // block_size
    print(f" * Split into {n_split} blocks")
    return [cat_samples[:, i*block_size:(i+1)*block_size] for i in range(n_split)]

@torch.no_grad()
def get_calib_feat(model, tokenizer):
    input_dict = dict()
    def stat_input_max_hook(m, x, y, name):
        if isinstance(x, tuple):
            x = x[0]
        x_max = x.view(-1, x.shape[-1]).abs().mean(dim=0).cpu().detach()
        if name not in input_dict:
            input_dict[name] = [x_max]
        else:
            input_dict[name] += [x_max]

    hooks = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Linear):
            hooks.append(
                m.register_forward_hook(
                    partial(stat_input_max_hook, name=name)))

    print("Collecting activation scales...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    samples = get_calib_dataset(tokenizer)
    pbar = tqdm.tqdm(samples)
    for input_ids in pbar:
        input_ids = input_ids.to(device)
        model(input_ids)

    for hook in hooks:
        hook.remove()
    return input_dict

# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=4, q_group_size=-1):
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)

    assert w.dim() == 2

    # Calculate the maximum (\alpha) and minimum values (\beta) in the tensor.
    max_val = w.amax(dim=1, keepdim=True)
    assert max_val.dim() == 2 and max_val.size(0) == w.size(0) and max_val.size(1) == 1
    min_val = w.amin(dim=1, keepdim=True)
    assert min_val.dim() == 2 and min_val.size(0) == w.size(0) and min_val.size(1) == 1

    # Calculate the scale factor and zero point.  (Formula 1 & 2)
    max_int = 2 ** n_bit - 1
    scales = (max_val - min_val).clamp(min=1e-5) / max_int
    assert scales.shape == max_val.shape
    zeros = (-torch.round(min_val / scales)).clamp_(0, max_int)
    assert scales.shape == min_val.shape

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    # Quantize W: Map values in the range [\beta, \alpha] to lie within [0, 2^b - 1] (Formula 3)
    w = torch.clamp(torch.round(w / scales) + zeros, 0, max_int)
    assert w.dim() == 2 and w.size(0) == scales.size(0) and w.size(1) == q_group_size

    # Dequantize W (pseudo quantization, the inverse transformation of Formula 3)
    w = (w - zeros) * scales
    assert w.dim() == 2 and w.size(0) == scales.size(0) and w.size(1) == q_group_size

    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)
    return w

@torch.no_grad()
def pseudo_quantize_model_weight(
    model, w_bit, q_group_size,
):
    for n, m in model.named_modules():
        if isinstance(m, nn.Linear):
            m.weight.data = pseudo_quantize_tensor(m.weight.data, n_bit=w_bit, q_group_size=q_group_size)

In [7]:
# 1) FP16 model

#model_path = "facebook/opt-2.7b"
model_path = "facebook/opt-1.3b"
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
pseudo_quantize_model_weight(model, w_bit=16, q_group_size=128)

# Evaluate the model
model_perplexity = evaluate(model, tokenizer)
model_size = get_model_size(model, data_width=16, group_size=128)
print(f"\nmodel perplexity: {model_perplexity:.5f}")
print(f"model size: {model_size/MiB:.5f} MiB")

evaluating...: 100%|██████████| 40/40 [00:20<00:00,  1.91it/s]


model perplexity: 14.46904
model size: 2534.11728 MiB





In [8]:
# 2) W4A16 but with salient weight protection

# a) compute input features from calibration dataset
del model
gc.collect()
torch.cuda.empty_cache()
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
input_feat = get_calib_feat(model, tokenizer)

# b) quantize the model weights to 4 bits while protecting the salient channels of those weights
del model
gc.collect()
torch.cuda.empty_cache()
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
pseudo_quantize_model_salient_weight_fp16(model, w_bit=4, q_group_size=128, input_feat=input_feat)

# Evaluate the model
model_perplexity = evaluate(model, tokenizer)
model_size = get_model_size(model, data_width=4, group_size=128)
print(f"\nmodel perplexity: {model_perplexity:.5f}")
print(f"model size: {model_size/MiB:.5f} MiB")

Repo card metadata block was not found. Setting CardData to empty.


Collecting activation scales...
 * Split into 127 blocks


100%|██████████| 127/127 [00:13<00:00,  9.52it/s]
evaluating...: 100%|██████████| 40/40 [00:20<00:00,  1.94it/s]


model perplexity: 14.78181
model size: 651.91025 MiB





In [9]:
import torch.nn as nn
import torch.nn.functional as F
from awq.quantize.quantizer import pseudo_quantize_tensor
from typing import Literal


@torch.no_grad()
def quantize_activation_per_token_absmax(t, n_bits=8):
    # t.shape = (input seq_len, hidden_size)

    t_shape = t.shape
    t.view(-1, t_shape[-1])
    scales = t.abs().max(dim=-1, keepdim=True)[0]           # scales.shape = (input seq_len, 1) max along the channel dimension
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    t.div_(scales).round_().mul_(scales)
    return t

@torch.no_grad()
def quantize_activation_per_tensor_absmax(t, n_bits=8):
    print("t.shape: {}".format(t.shape))
    t_shape = t.shape
    t.view(-1, t_shape[-1])
    scales = t.abs().max()                                  # scales.shape = (1) max along the entire tensor    
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    t.div_(scales).round_().mul_(scales)
    return t

@torch.no_grad()
def quantize_activation_per_token_absmax_salient(t, outlier_indices, n_bits=8):
    # t.shape = (input seq_len, hidden_size)
    # input_feats = list of tensors of shape (hidden_size,) <- my guess is that len(input_feats) = input seq_len

    assert outlier_indices.dim() == 1           # shape = (1% of hidden_size,)
    t_copy = t.clone()

    t_shape = t.shape
    t.view(-1, t_shape[-1])
    scales = t.abs().max(dim=-1, keepdim=True)[0]           # scales.shape = (input seq_len, 1) max along the channel dimension
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    t.div_(scales).round_().mul_(scales)

    t[:, outlier_indices] = t_copy[:, outlier_indices]

    return t

@torch.no_grad()
def quantize_activation_per_tensor_absmax_salient(t, outlier_indices, n_bits=8):
    # t.shape = (input seq_len, hidden_size)
    assert outlier_indices.dim() == 1           # shape = (1% of hidden_size,)
    t_copy = t.clone()

    t_shape = t.shape
    t.view(-1, t_shape[-1])
    scales = t.abs().max()                                  # scales.shape = (1) max along the entire tensor    
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    t.div_(scales).round_().mul_(scales)

    t[:, outlier_indices] = t_copy[:, outlier_indices]

    return t

class QuantizedLinear(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        w_n_bits=4,
        a_n_bits=16,
        act_quant: Literal["per_token", "per_tensor", "none"] = "per_token",
        quantize_output: bool = False,
        outlier_indices: torch.Tensor = None,
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer(
            "weight",
            torch.randn(
                self.out_features,
                self.in_features,
                dtype=torch.float16,
                requires_grad=False,
            ),
        )

        if bias:
            self.register_buffer(
                "bias",
                torch.zeros(
                    (1, self.out_features), dtype=torch.float16, requires_grad=False
                ),
            )
        else:
            self.register_buffer("bias", None)

        if act_quant == "per_token":
            self.act_quant_name = "per_token"
            self.act_quant = partial(
                quantize_activation_per_token_absmax, n_bits=a_n_bits
            )
        elif act_quant == "per_tensor":
            self.act_quant_name = "per_tensor"
            self.act_quant = partial(
                quantize_activation_per_tensor_absmax, n_bits=a_n_bits
            )
        elif act_quant == "per_token_salient":
            self.act_quant_name = "per_token_salient"
            self.act_quant = partial(
                quantize_activation_per_token_absmax_salient, 
                outlier_indices=outlier_indices,
                n_bits=a_n_bits,
            )
        elif act_quant == "per_tensor_salient":
            self.act_quant_name = "per_tensor_salient"
            self.act_quant = partial(
                quantize_activation_per_tensor_absmax_salient, 
                outlier_indices=outlier_indices,
                n_bits=a_n_bits,
            )
        else:
            self.act_quant_name = "None"
            self.act_quant = lambda x: x

        if quantize_output:
            self.output_quant_name = self.act_quant_name
            self.output_quant = self.act_quant
        else:
            self.output_quant_name = "None"
            self.output_quant = lambda x: x

    def to(self, *args, **kwargs):
        super(QuantizedLinear, self).to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        if self.bias is not None:
            self.bias = self.bias.to(*args, **kwargs)
        return self

    @torch.no_grad()
    def forward(self, x):
        q_x = self.act_quant(x)
        y = F.linear(q_x, self.weight, self.bias)
        q_y = self.output_quant(y)
        return q_y

    @classmethod
    def from_linear(
        cls,
        linear: nn.Linear,
        w_n_bits: int = 4,
        a_n_bits: int = 4,
        zero_point: bool = True,
        group_size: int = 128,
        act_quant: Literal["per_token", "per_tensor", "none"] = "per_token",
        quantize_output: bool = False,
    ):

        # this is a linear layer that will eventually enhouse the quantized weights
        awq_linear = cls(
            linear.in_features,
            linear.out_features,
            bias=linear.bias is not None,
            w_n_bits=w_n_bits,
            a_n_bits=a_n_bits,  
            act_quant=act_quant,
            quantize_output=quantize_output,
        )

        awq_linear.weight.data = pseudo_quantize_tensor(
            w=linear.weight.data,
            n_bit=w_n_bits,
            zero_point=zero_point,
            q_group_size=group_size,
        )

        if linear.bias is not None:
            awq_linear.bias.data = linear.bias.data

        return awq_linear

    @classmethod
    def from_linear_salient_weight(
        cls,
        linear: nn.Linear,
        input_feats: torch.Tensor,
        w_n_bits: int = 4,
        a_n_bits: int = 4,
        zero_point: bool = True,
        group_size: int = 128,
        act_quant: Literal["per_token", "per_tensor", "none"] = "per_token",
        quantize_output: bool = False,
    ):

        # this is a linear layer that will eventually enhouse the quantized weights
        awq_linear = cls(
            linear.in_features,
            linear.out_features,
            bias=linear.bias is not None,
            w_n_bits=w_n_bits,
            a_n_bits=a_n_bits,
            act_quant=act_quant,
            quantize_output=quantize_output,
        )

        awq_linear.weight.data = pseudo_quantize_tensor(
            w=linear.weight.data,
            n_bit=w_n_bits,
            zero_point=zero_point,
            q_group_size=group_size,
        ) 

        # Step 1: Find 1% of the salient weight channels according to importance (hint: use torch.topk())
        importance = sum(input_feats).float()
        outlier_indices = torch.topk(importance, k=int(importance.shape[0]/100), dim=0)[1]
        assert outlier_indices.dim() == 1
            
        # Step 2: Restore the 1% salient weight channels to their original FP16 values
        outlier = linear.weight.data[:, outlier_indices].clone()
        awq_linear.weight.data[:, outlier_indices] = outlier

        if linear.bias is not None:
            awq_linear.bias.data = linear.bias.data

        return awq_linear

    @classmethod
    def from_linear_salient_weight_act(
        cls,
        linear: nn.Linear,
        input_feats: torch.Tensor,
        w_n_bits: int = 4,
        a_n_bits: int = 4,
        zero_point: bool = True,
        group_size: int = 128,
        act_quant: Literal["per_token", "per_tensor", "none"] = "per_token",
        quantize_output: bool = False,
    ):

        # Step 1: Find 1% of the salient weight channels according to importance (hint: use torch.topk())
        importance = sum(input_feats).float()
        outlier_indices = torch.topk(importance, k=int(importance.shape[0]/100), dim=0)[1]
        assert outlier_indices.dim() == 1

        # this is a linear layer that will eventually enhouse the quantized weights
        awq_linear = cls(
            linear.in_features,
            linear.out_features,
            bias=linear.bias is not None,
            w_n_bits=w_n_bits,
            a_n_bits=a_n_bits,
            act_quant=act_quant + "_salient",
            quantize_output = quantize_output,
            outlier_indices = outlier_indices,
        )

        awq_linear.weight.data = pseudo_quantize_tensor(
            w=linear.weight.data,
            n_bit=w_n_bits,
            zero_point=zero_point,
            q_group_size=group_size,
        )
            
        # Step 2: Restore the 1% salient weight channels to their original FP16 values
        outlier = linear.weight.data[:, outlier_indices].clone()
        awq_linear.weight.data[:, outlier_indices] = outlier

        if linear.bias is not None:
            awq_linear.bias.data = linear.bias.data

        return awq_linear

def quantize_opt(
    model,
    w_n_bits: int = 4,
    a_n_bits: int = 4,
    zero_point: bool = True,
    group_size: int = 128,
    act_quant: Literal["per_token", "per_tensor", "none"] = "per_token",
    quantize_bmm_input: bool = True,
):
    from transformers.models.opt.modeling_opt import (
        OPTAttention,
        OPTDecoderLayer,
    )

    for name, m in model.model.named_modules():
        if isinstance(m, OPTDecoderLayer):
            m.fc1 = QuantizedLinear.from_linear(
                m.fc1,
                w_n_bits=w_n_bits,
                a_n_bits=a_n_bits,
                zero_point=zero_point,
                group_size=group_size,
                act_quant=act_quant,
            )
            m.fc2 = QuantizedLinear.from_linear(
                m.fc2,
                w_n_bits=w_n_bits,          # new input param
                a_n_bits=a_n_bits,          # new input param
                zero_point=zero_point,      # new input param
                group_size=group_size,      # new input param
                act_quant=act_quant,        # new input param
            )
        elif isinstance(m, OPTAttention):
            # Her we simulate quantizing BMM inputs by quantizing the output of q_proj, k_proj, v_proj
            m.q_proj = QuantizedLinear.from_linear(
                m.q_proj,
                w_n_bits=w_n_bits,
                a_n_bits=a_n_bits,
                zero_point=zero_point,
                group_size=group_size,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.k_proj = QuantizedLinear.from_linear(
                m.k_proj,
                w_n_bits=w_n_bits,
                a_n_bits=a_n_bits,
                zero_point=zero_point,
                group_size=group_size,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.v_proj = QuantizedLinear.from_linear(
                m.v_proj,
                w_n_bits=w_n_bits,
                a_n_bits=a_n_bits,
                zero_point=zero_point,
                group_size=group_size,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.out_proj = QuantizedLinear.from_linear(
                m.out_proj,
                w_n_bits=w_n_bits,          # new input param
                a_n_bits=a_n_bits,          # new input param
                zero_point=zero_point,      # new input param
                group_size=group_size,      # new input param
                act_quant=act_quant,        # new input param
            )

    return model

def quantize_opt_salient_weight_fp16(
    model,
    input_feats,
    w_n_bits: int = 4,
    a_n_bits: int = 4,
    zero_point: bool = True,
    group_size: int = 128,
    act_quant: Literal["per_token", "per_tensor", "none"] = "per_token",
    quantize_bmm_input: bool = True,
):
    from transformers.models.opt.modeling_opt import (
        OPTAttention,
        OPTDecoderLayer,
    )

    for name, m in model.model.named_modules():
        if isinstance(m, OPTDecoderLayer):
            m.fc1 = QuantizedLinear.from_linear_salient_weight(
                m.fc1,
                input_feats["model." + name + ".fc1"],
                w_n_bits=w_n_bits,
                a_n_bits=a_n_bits,
                zero_point=zero_point,
                group_size=group_size,
                act_quant=act_quant,
            )
            m.fc2 = QuantizedLinear.from_linear_salient_weight(
                m.fc2,
                input_feats["model." + name + ".fc2"],
                w_n_bits=w_n_bits,
                a_n_bits=a_n_bits,
                zero_point=zero_point,
                group_size=group_size,
                act_quant=act_quant,
            )
        elif isinstance(m, OPTAttention):
            # Her we simulate quantizing BMM inputs by quantizing the output of q_proj, k_proj, v_proj
            m.q_proj = QuantizedLinear.from_linear_salient_weight(
                m.q_proj,
                input_feats["model." + name + ".q_proj"],
                w_n_bits=w_n_bits,
                a_n_bits=a_n_bits,
                zero_point=zero_point,
                group_size=group_size,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.k_proj = QuantizedLinear.from_linear_salient_weight(
                m.k_proj,
                input_feats["model." + name + ".k_proj"],
                w_n_bits=w_n_bits,
                a_n_bits=a_n_bits,
                zero_point=zero_point,
                group_size=group_size,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.v_proj = QuantizedLinear.from_linear_salient_weight(
                m.v_proj,
                input_feats["model." + name + ".v_proj"],
                w_n_bits=w_n_bits,
                a_n_bits=a_n_bits,
                zero_point=zero_point,
                group_size=group_size,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.out_proj = QuantizedLinear.from_linear_salient_weight(
                m.out_proj,
                input_feats["model." + name + ".out_proj"],
                w_n_bits=w_n_bits,          # new input param
                a_n_bits=a_n_bits,          # new input param
                zero_point=zero_point,      # new input param
                group_size=group_size,      # new input param
                act_quant=act_quant,        # new input param
            )

    return model

def quantize_opt_salient_weight_act_fp16(
    model,
    input_feats,
    w_n_bits: int = 4,
    a_n_bits: int = 4,
    zero_point: bool = True,
    group_size: int = 128,
    act_quant: Literal["per_token", "per_tensor", "none"] = "per_token",
    quantize_bmm_input: bool = True,
):
    from transformers.models.opt.modeling_opt import (
        OPTAttention,
        OPTDecoderLayer,
    )

    for name, m in model.model.named_modules():
        if isinstance(m, OPTDecoderLayer):
            m.fc1 = QuantizedLinear.from_linear_salient_weight_act(
                m.fc1,
                input_feats["model." + name + ".fc1"],
                w_n_bits=w_n_bits,
                a_n_bits=a_n_bits,
                zero_point=zero_point,
                group_size=group_size,
                act_quant=act_quant,
            )
            m.fc2 = QuantizedLinear.from_linear_salient_weight_act(
                m.fc2,
                input_feats["model." + name + ".fc2"],
                w_n_bits=w_n_bits,
                a_n_bits=a_n_bits,
                zero_point=zero_point,
                group_size=group_size,
                act_quant=act_quant,
            )
        elif isinstance(m, OPTAttention):
            # Her we simulate quantizing BMM inputs by quantizing the output of q_proj, k_proj, v_proj
            m.q_proj = QuantizedLinear.from_linear_salient_weight_act(
                m.q_proj,
                input_feats["model." + name + ".q_proj"],
                w_n_bits=w_n_bits,
                a_n_bits=a_n_bits,
                zero_point=zero_point,
                group_size=group_size,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.k_proj = QuantizedLinear.from_linear_salient_weight_act(
                m.k_proj,
                input_feats["model." + name + ".k_proj"],
                w_n_bits=w_n_bits,
                a_n_bits=a_n_bits,
                zero_point=zero_point,
                group_size=group_size,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.v_proj = QuantizedLinear.from_linear_salient_weight_act(
                m.v_proj,
                input_feats["model." + name + ".v_proj"],
                w_n_bits=w_n_bits,
                a_n_bits=a_n_bits,
                zero_point=zero_point,
                group_size=group_size,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.out_proj = QuantizedLinear.from_linear_salient_weight_act(
                m.out_proj,
                input_feats["model." + name + ".out_proj"],
                w_n_bits=w_n_bits,          # new input param
                a_n_bits=a_n_bits,          # new input param
                zero_point=zero_point,      # new input param
                group_size=group_size,      # new input param
                act_quant=act_quant,        # new input param
            )

    return model

In [10]:
# naive W8A8 quantization (FP 16 model = 14.46904 perplexity for opt1.3b and 12.35642 perplexity for opt2.7b)

w_n_bits = 8
a_n_bits = 8

del model
gc.collect()
torch.cuda.empty_cache()
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
model = quantize_opt(model, w_n_bits=w_n_bits, a_n_bits=a_n_bits, act_quant="per_token")              # perplexity = 14.55xxx
model.cuda()

# Evaluate the model
model_perplexity = evaluate(model, tokenizer)
model_size = get_model_size(model, data_width=w_n_bits, group_size=128)
print(f"\nmodel perplexity: {model_perplexity:.5f}")
print(f"model size: {model_size/MiB:.5f} MiB")

evaluating...: 100%|██████████| 40/40 [00:23<00:00,  1.73it/s]


model perplexity: 14.93691
model size: 104.38248 MiB





In [11]:
# W8A8 quantization with salient weight protection (FP 16 model = 14.46904 perplexity for opt1.3b and 12.35642 perplexity for opt2.7b)

w_n_bits = 8
a_n_bits = 8

# a) compute input features from calibration dataset
del model
gc.collect()
torch.cuda.empty_cache()
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
input_feats = get_calib_feat(model, tokenizer)

# print out all keys in input_feats
print(input_feats.keys())

del model
gc.collect()
torch.cuda.empty_cache()
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
model = quantize_opt_salient_weight_fp16(model, input_feats, w_n_bits=w_n_bits, a_n_bits=a_n_bits, act_quant="per_token")
#model = quantize_opt_salient_weight_fp16(model, input_feats, w_n_bits=w_n_bits, a_n_bits=a_n_bits, act_quant="per_channel")

model.cuda()

# Evaluate the model
model_perplexity = evaluate(model, tokenizer)
model_size = get_model_size(model, data_width=w_n_bits, group_size=128)
print(f"\nmodel perplexity: {model_perplexity:.5f}")
print(f"model size: {model_size/MiB:.5f} MiB")

Collecting activation scales...


Repo card metadata block was not found. Setting CardData to empty.


 * Split into 127 blocks


100%|██████████| 127/127 [00:13<00:00,  9.22it/s]


dict_keys(['model.decoder.layers.0.self_attn.q_proj', 'model.decoder.layers.0.self_attn.k_proj', 'model.decoder.layers.0.self_attn.v_proj', 'model.decoder.layers.0.self_attn.out_proj', 'model.decoder.layers.0.fc1', 'model.decoder.layers.0.fc2', 'model.decoder.layers.1.self_attn.q_proj', 'model.decoder.layers.1.self_attn.k_proj', 'model.decoder.layers.1.self_attn.v_proj', 'model.decoder.layers.1.self_attn.out_proj', 'model.decoder.layers.1.fc1', 'model.decoder.layers.1.fc2', 'model.decoder.layers.2.self_attn.q_proj', 'model.decoder.layers.2.self_attn.k_proj', 'model.decoder.layers.2.self_attn.v_proj', 'model.decoder.layers.2.self_attn.out_proj', 'model.decoder.layers.2.fc1', 'model.decoder.layers.2.fc2', 'model.decoder.layers.3.self_attn.q_proj', 'model.decoder.layers.3.self_attn.k_proj', 'model.decoder.layers.3.self_attn.v_proj', 'model.decoder.layers.3.self_attn.out_proj', 'model.decoder.layers.3.fc1', 'model.decoder.layers.3.fc2', 'model.decoder.layers.4.self_attn.q_proj', 'model.dec

evaluating...: 100%|██████████| 40/40 [00:23<00:00,  1.72it/s]


model perplexity: 14.91101
model size: 104.38248 MiB





In [12]:
# W4A16 quantization with salient weight protection (reference implementation from homework = 14.78181)
w_n_bits = 4
a_n_bits = 16

# a) compute input features from calibration dataset
del model
gc.collect()
torch.cuda.empty_cache()
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
input_feats = get_calib_feat(model, tokenizer)

# print out all keys in input_feats
print(input_feats.keys())

del model
gc.collect()
torch.cuda.empty_cache()
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
model = quantize_opt_salient_weight_fp16(model, input_feats, w_n_bits=w_n_bits, a_n_bits=a_n_bits, act_quant="per_token")
#model = quantize_opt_salient_weight_fp16(model, input_feats, w_n_bits=w_n_bits, a_n_bits=a_n_bits, act_quant="per_channel")

model.cuda()

# Evaluate the model
model_perplexity = evaluate(model, tokenizer)
model_size = get_model_size(model, data_width=w_n_bits, group_size=128)
print(f"\nmodel perplexity: {model_perplexity:.5f}")
print(f"model size: {model_size/MiB:.5f} MiB")

Repo card metadata block was not found. Setting CardData to empty.


Collecting activation scales...
 * Split into 127 blocks


100%|██████████| 127/127 [00:13<00:00,  9.40it/s]


dict_keys(['model.decoder.layers.0.self_attn.q_proj', 'model.decoder.layers.0.self_attn.k_proj', 'model.decoder.layers.0.self_attn.v_proj', 'model.decoder.layers.0.self_attn.out_proj', 'model.decoder.layers.0.fc1', 'model.decoder.layers.0.fc2', 'model.decoder.layers.1.self_attn.q_proj', 'model.decoder.layers.1.self_attn.k_proj', 'model.decoder.layers.1.self_attn.v_proj', 'model.decoder.layers.1.self_attn.out_proj', 'model.decoder.layers.1.fc1', 'model.decoder.layers.1.fc2', 'model.decoder.layers.2.self_attn.q_proj', 'model.decoder.layers.2.self_attn.k_proj', 'model.decoder.layers.2.self_attn.v_proj', 'model.decoder.layers.2.self_attn.out_proj', 'model.decoder.layers.2.fc1', 'model.decoder.layers.2.fc2', 'model.decoder.layers.3.self_attn.q_proj', 'model.decoder.layers.3.self_attn.k_proj', 'model.decoder.layers.3.self_attn.v_proj', 'model.decoder.layers.3.self_attn.out_proj', 'model.decoder.layers.3.fc1', 'model.decoder.layers.3.fc2', 'model.decoder.layers.4.self_attn.q_proj', 'model.dec

evaluating...: 100%|██████████| 40/40 [00:23<00:00,  1.72it/s]


model perplexity: 14.68109
model size: 53.19107 MiB





In [13]:
# W8A8 quantization with salient weight AND activation protection
w_n_bits = 8
a_n_bits = 8

# a) compute input features from calibration dataset
del model
gc.collect()
torch.cuda.empty_cache()
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
input_feats = get_calib_feat(model, tokenizer)

# print out all keys in input_feats
print(input_feats.keys())

del model
gc.collect()
torch.cuda.empty_cache()
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
model = quantize_opt_salient_weight_act_fp16(model, input_feats, w_n_bits=w_n_bits, a_n_bits=a_n_bits, act_quant="per_token")

model.cuda()

# Evaluate the model
model_perplexity = evaluate(model, tokenizer)
model_size = get_model_size(model, data_width=w_n_bits, group_size=128)
print(f"\nmodel perplexity: {model_perplexity:.5f}")
print(f"model size: {model_size/MiB:.5f} MiB")

Repo card metadata block was not found. Setting CardData to empty.


Collecting activation scales...
 * Split into 127 blocks


100%|██████████| 127/127 [00:13<00:00,  9.27it/s]


dict_keys(['model.decoder.layers.0.self_attn.q_proj', 'model.decoder.layers.0.self_attn.k_proj', 'model.decoder.layers.0.self_attn.v_proj', 'model.decoder.layers.0.self_attn.out_proj', 'model.decoder.layers.0.fc1', 'model.decoder.layers.0.fc2', 'model.decoder.layers.1.self_attn.q_proj', 'model.decoder.layers.1.self_attn.k_proj', 'model.decoder.layers.1.self_attn.v_proj', 'model.decoder.layers.1.self_attn.out_proj', 'model.decoder.layers.1.fc1', 'model.decoder.layers.1.fc2', 'model.decoder.layers.2.self_attn.q_proj', 'model.decoder.layers.2.self_attn.k_proj', 'model.decoder.layers.2.self_attn.v_proj', 'model.decoder.layers.2.self_attn.out_proj', 'model.decoder.layers.2.fc1', 'model.decoder.layers.2.fc2', 'model.decoder.layers.3.self_attn.q_proj', 'model.decoder.layers.3.self_attn.k_proj', 'model.decoder.layers.3.self_attn.v_proj', 'model.decoder.layers.3.self_attn.out_proj', 'model.decoder.layers.3.fc1', 'model.decoder.layers.3.fc2', 'model.decoder.layers.4.self_attn.q_proj', 'model.dec

evaluating...: 100%|██████████| 40/40 [00:24<00:00,  1.64it/s]


model perplexity: 14.84560
model size: 104.38248 MiB



