### GPTQ quantization implemented from scratch

In [None]:
# Install the Hugging Face libraries used below (IPython shell escapes).
!pip install datasets
!pip install transformers



In [None]:
import numpy as np
import torch
import torch.nn as nn

from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

We start by loading the model and its tokenizer from the Hugging Face Hub in the usual way.

In [None]:
checkpoint = "HuggingFaceTB/SmolLM2-135M-Instruct"

# Use the first GPU when available; fall back to CPU so the notebook still runs
# on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
model.eval()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576, padding_idx=2)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm

We need data both for evaluation and for computing the Hessian matrices used by GPTQ. I chose C4 because it is a popular, generic, cleaned text dataset. I also wrote a small function for perplexity evaluation; it reports the mean of the per-text perplexities.

In [None]:
import itertools

NUM_VALIDATION_TEXTS = 5000
NUM_CALIBRATION_TEXTS = 128

dataset = load_dataset("c4", "en", split="validation", streaming=True)
# Keep the first 1024 characters of every non-empty document. Filtering happens
# *before* the length limit, so the two splits always get their intended sizes
# even when some documents are empty.
input_texts = list(
    itertools.islice(
        (s["text"][:1024] for s in dataset if s["text"] != ""),
        NUM_VALIDATION_TEXTS + NUM_CALIBRATION_TEXTS,
    )
)
calibration_texts = input_texts[-NUM_CALIBRATION_TEXTS:]
validation_texts = input_texts[:-NUM_CALIBRATION_TEXTS]



In [None]:
def compute_perplexity(model, tokenizer, input_texts):
    """Return the mean of per-text perplexities of ``model`` on ``input_texts``.

    Each text is tokenized on its own (truncated to 1024 tokens). The model is
    run with the inputs as labels, and ``exp(loss)`` of its mean token-level
    loss is recorded. The result is the arithmetic mean of these per-text
    perplexities, not a token-weighted corpus perplexity.

    Texts that tokenize to fewer than two tokens are skipped. A causal LM has
    no next-token target for them and reports a NaN loss, which would turn the
    whole mean into NaN.

    Returns ``float("nan")`` if no text is usable.
    """
    # Run on the device the model actually lives on instead of a global.
    model_device = next(model.parameters()).device
    perplexities = []

    for text in tqdm(input_texts):
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
        if inputs["input_ids"].shape[-1] < 2:
            continue
        inputs = {key: value.to(model_device) for key, value in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss.item()

        perplexities.append(np.exp(loss))

    if not perplexities:
        return float("nan")
    return np.mean(perplexities)

In [None]:
# Baseline: perplexity of the unquantized bf16 model on the validation texts.
model = model.to(device)
initial_res = compute_perplexity(model, tokenizer, input_texts=validation_texts)
print("Initial model perplexity: ", initial_res)

100%|██████████| 5000/5000 [04:30<00:00, 18.51it/s]

Initial model perplexity:  44.00592542669673





We see that our small model has quite a good perplexity for its size. Let's check how much memory it uses and what that memory consists of.

In [None]:
model.get_memory_footprint() / 10**9

0.269033984

In [None]:
# Memory of the parameters and buffers, using each tensor's own element size
# instead of hard-coding bytes per element. model.parameters() yields shared
# parameters only once.
param_memory = sum(p.numel() * p.element_size() for p in model.parameters())
buffer_memory = sum(b.numel() * b.element_size() for b in model.buffers())

param_memory, buffer_memory = param_memory / 1e9, buffer_memory / 1e9

In [None]:
buffer_memory + param_memory

0.269033984

Here each model parameter takes 2 bytes, as the parameters are stored in bf16, and each buffer element takes 4 bytes.

Below I implemented simple symmetric linear quantization. It quantizes every column of the weight matrix separately, each with its own scale. We use columns because the weight is transposed before multiplication, so each column corresponds to one input feature. We will use the same technique further on for GPTQ.

In [None]:
class ColumnQuantizedLinear(nn.Module):
    """Bias-free linear layer with symmetric per-input-column weight quantization.

    The weight is an ``int8`` tensor of shape ``(out_features, in_features)``
    holding integer codes in ``[low_limit, up_limit]``, i.e.
    ``[-(2**(nbits-1) - 1), 2**(nbits-1) - 1]``. Every input column ``j`` has
    its own bf16 scale, and the dequantized weight is ``weight[:, j] / scales[j]``.
    Codes with ``nbits < 8`` are still stored one per int8 element (no packing).
    """

    def __init__(self, in_features, out_features, nbits=8):
        super().__init__()
        self.register_buffer("scales", torch.zeros(in_features, dtype=torch.bfloat16))
        # Symmetric range, e.g. [-127, 127] for 8 bits and [-7, 7] for 4 bits.
        self.register_buffer("low_limit", torch.tensor(- 2 ** (nbits - 1) + 1))
        self.register_buffer("up_limit", torch.tensor(2 ** (nbits - 1) - 1))

        self.weight = nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.int8, requires_grad=False), requires_grad=False)

    def update_weight_column(self, target, idx):
        """Quantize ``target`` (one weight column, length ``out_features``) into column ``idx``.

        The scale maps the largest magnitude of the column onto ``up_limit``;
        an all-zero column gets scale 1.

        The scale is first stored in the bf16 ``scales`` buffer. The codes are
        then computed in float32 with that stored value, so that:

        * rounding is not distorted by bf16 arithmetic (bf16 spacing is 0.5
          for values in [64, 128));
        * quantization uses exactly the scale that :meth:`forward` divides by.
        """
        max_abs = target.abs().max()
        if max_abs == 0:
            scale = 1
        else:
            scale = self.up_limit / max_abs
        self.scales[idx] = scale
        # Read back the bf16-rounded scale actually stored for dequantization.
        stored_scale = self.scales[idx].float()

        self.weight[:, idx] = torch.clamp(
            torch.round(target.float() * stored_scale),
            min=self.low_limit,
            max=self.up_limit,
        ).to(torch.int8)

    def forward(self, x):
        # Dequantize column-wise (scales broadcast over the last, input dim).
        weight_bf16 = self.weight.to(torch.bfloat16) / self.scales
        return x @ weight_bf16.t()

In [None]:
# Snapshot the linear layers before replacing any of them. Swapping modules
# while walking model.named_modules() mutates the tree being iterated, and
# rebuilding dict(model.named_modules()) for every layer is quadratic.
linear_layers = [(name, layer) for name, layer in model.named_modules() if isinstance(layer, nn.Linear)]
for name, layer in tqdm(linear_layers):
    out_features, in_features = layer.weight.data.shape
    quantized_layer = ColumnQuantizedLinear(in_features, out_features).to(device)
    for i in range(in_features):
        quantized_layer.update_weight_column(layer.weight.data[:, i], i)
    parent_module_name, attr_name = name.rsplit('.', 1) if '.' in name else ('', name)
    # get_submodule('') returns the model itself (for top-level layers).
    setattr(model.get_submodule(parent_module_name), attr_name, quantized_layer)
del linear_layers  # release the original bf16 layers

427it [00:37, 11.45it/s]


In [None]:
# Perplexity after naive 8-bit per-column quantization.
linquant_res = compute_perplexity(model, tokenizer, input_texts=validation_texts)
print("Linear symmetrical quantization model perplexity: ", linquant_res)

100%|██████████| 5000/5000 [05:21<00:00, 15.55it/s]

Linear symmetrical quantization model perplexity:  44.55137823154533





Here we achieve nice performance, not far from the original perplexity. Let's see how much memory this model needs. I recompute the memory used by the linear layers and add the memory of the buffers.

In [None]:
# Actual storage after 8-bit quantization. ColumnQuantizedLinear keeps its
# weight as int8 (1 byte), while the remaining parameters stay bf16 (2 bytes).
# Buffers have mixed dtypes (e.g. the bf16 scales and int64 limits registered by
# ColumnQuantizedLinear), so each tensor's element_size() is used instead of a
# fixed bytes-per-element guess.
param_memory = sum(p.numel() * p.element_size() for p in model.parameters())
buffer_memory = sum(b.numel() * b.element_size() for b in model.buffers())

param_memory, buffer_memory = param_memory / 1e9, buffer_memory / 1e9

In [None]:
buffer_memory + param_memory

0.191780248

Let's reload the initial model and see how it works with only 4 bits:

In [None]:
# Start again from the original bf16 checkpoint for the 4-bit experiment.
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).eval()
model.to(device)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576, padding_idx=2)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm

In [None]:
# Snapshot the linear layers before replacing any of them. Swapping modules
# while walking model.named_modules() mutates the tree being iterated, and
# rebuilding dict(model.named_modules()) for every layer is quadratic.
linear_layers = [(name, layer) for name, layer in model.named_modules() if isinstance(layer, nn.Linear)]
for name, layer in tqdm(linear_layers):
    out_features, in_features = layer.weight.data.shape
    quantized_layer = ColumnQuantizedLinear(in_features, out_features, nbits=4).to(device)
    for i in range(in_features):
        quantized_layer.update_weight_column(layer.weight.data[:, i], i)
    parent_module_name, attr_name = name.rsplit('.', 1) if '.' in name else ('', name)
    # get_submodule('') returns the model itself (for top-level layers).
    setattr(model.get_submodule(parent_module_name), attr_name, quantized_layer)
del linear_layers  # release the original bf16 layers

427it [00:37, 11.40it/s]


In [None]:
# Perplexity after naive 4-bit per-column quantization.
linquant_res = compute_perplexity(model, tokenizer, input_texts=validation_texts)
print("Linear symmetrical 4bit quantization model perplexity: ", linquant_res)

100%|██████████| 5000/5000 [05:12<00:00, 15.98it/s]

Linear symmetrical 4bit quantization model perplexity:  1140.3871696837778





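Here the performance is much worse: at this perplexity it does not make sense to quantize the model. Let's see how much memory we would save. Note that the 4-bit codes are still stored one per int8 element, so the estimate below assumes that two 4-bit codes are packed into each byte.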

In [None]:
# Memory estimate for the 4-bit model.
# NOTE: ColumnQuantizedLinear still stores the 4-bit codes in an int8 tensor,
# so its real footprint equals the 8-bit one. The 0.5 byte per quantized weight
# below is the hypothetical size if two codes were packed into each byte.
# Every other tensor (including buffers) is counted with its real element size.
quantized_weight_ids = {
    id(module.weight)
    for module in model.modules()
    if isinstance(module, ColumnQuantizedLinear)
}
param_memory = sum(
    p.numel() / 2 if id(p) in quantized_weight_ids else p.numel() * p.element_size()
    for p in model.parameters()
)
buffer_memory = sum(b.numel() * b.element_size() for b in model.buffers())

param_memory, buffer_memory = param_memory / 1e9, buffer_memory / 1e9

In [None]:
buffer_memory + param_memory

0.124540312

Now let's implement GPTQ. From here on I do not compute memory savings, as they are the same as above, apart from the Hessian matrix. We don't need the Hessian once quantization is done, so we can omit it.

In [None]:
# Fresh bf16 model for 8-bit GPTQ; it is moved to `device` during calibration.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
)
model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576, padding_idx=2)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm

For GPTQ quantization we first accumulate, for every linear layer, the matrix $2XX^T$ built from the activations $X$ passed into that layer, averaged over the calibration samples. PyTorch's `register_forward_hook` is designed for exactly this kind of purpose.

In [None]:
# Tokenize every calibration text once and keep the tensors on the target device.
calibration_inputs = [
    {
        key: value.to(device)
        for key, value in tokenizer(text, return_tensors="pt", truncation=True, max_length=1024).items()
    }
    for text in calibration_texts
]

model.to(device)

# Handles of the forward hooks registered below, so they can be removed later.
hooks = []

def update_hessian_hook(module, inp, out):
    """Forward hook that accumulates ``2 * X^T X`` of a linear layer's input.

    ``inp`` is the positional-argument tuple PyTorch passes to forward hooks;
    its first element is the activation fed to ``module``. All leading
    dimensions (batch, sequence) are flattened into rows, so each row of ``X``
    has ``in_features`` entries.

    The running sum lives in ``module.hessian``. It is created lazily as a
    float32 tensor on the activation's device; the caller normalizes it
    afterwards. The product is computed in float32, because a bf16 matmul would
    round every per-sample product to bf16 precision before it is accumulated.
    """
    input_tensor = inp[0]
    n_features = module.weight.shape[1]
    # (..., in_features) -> (rows, in_features), in float32 for an accurate sum.
    x = input_tensor.reshape(-1, input_tensor.shape[-1]).to(torch.float32)
    if not hasattr(module, "hessian"):
        module.hessian = torch.zeros((n_features, n_features), dtype=torch.float32, device=x.device)

    module.hessian += 2 * (x.t() @ x)

linear_layers = [layer for layer in model.modules() if isinstance(layer, nn.Linear)]

# The hook adds to an existing `hessian` attribute, so drop any left over from a
# previous run of this cell to avoid double-counting.
for layer in linear_layers:
    if hasattr(layer, "hessian"):
        del layer.hessian

for layer in tqdm(linear_layers):
    hooks.append(layer.register_forward_hook(update_hessian_hook))

# Remove the hooks even if a forward pass raises; otherwise they stay attached
# and the next run would accumulate every Hessian twice.
try:
    with torch.no_grad():
        for inp in tqdm(calibration_inputs):
            # Only the layer inputs are needed, so no labels / loss computation.
            model(**inp)
finally:
    for hook in hooks:
        hook.remove()
    hooks.clear()

for layer in tqdm(linear_layers):
    layer.hessian /= len(calibration_inputs)

427it [00:00, 183896.48it/s]
100%|██████████| 128/128 [00:09<00:00, 14.20it/s]
427it [00:00, 67492.00it/s]


Here is the implementation of GPTQ quantization. It updates the remaining, not-yet-quantized weights to minimize the squared L2 norm of the difference in layer outputs. More precisely, it approximately solves the following problem by quantizing the weight matrix column by column:

$$
\arg\min_{\hat{W}} \; \|WX - \hat{W} X\|_F^2
$$

To solve this problem, we compute the Hessian of this objective with respect to each row of $\hat{W}$ (all rows share it), which is

$$H =  2 X X^T$$

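After quantizing weight $w_q$ (entry $q$ of a row $w$; all rows are processed in parallel), we update the remaining, not-yet-quantized weights of that row using the following formula:


$$\delta = -\frac{w_q - \operatorname{quant}(w_q)}{[H^{-1}]_{qq}} \, (H^{-1})_{:, q}$$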

We don't want to recompute the inverse Hessian from scratch after every column. Instead, the paper updates it with a single step of Gaussian elimination, which removes row and column $q$ exactly (in the code they are zeroed rather than removed):

$$H^{-1}_{-q} = \left(H^{-1} - \frac{H^{-1}_{:,q} H^{-1}_{q,:}}{[H^{-1}]_{qq}}\right)_{-q}$$

For simplicity, we do not pack 4-bit integers and store them as 8-bit integers (int8). In a real-world scenario we should pack them to actually save memory.

In [None]:
class GPTQQuantizedLinear(nn.Module):
    """Bias-free linear layer whose int8-stored weight is quantized with GPTQ.

    The layout matches a symmetric per-input-column quantizer. ``weight`` holds
    integer codes in ``[low_limit, up_limit]`` with shape
    ``(out_features, in_features)``, and column ``j`` is dequantized as
    ``weight[:, j] / scales[j]``. Codes with ``nbits < 8`` are not packed.
    """

    def __init__(self, in_features, out_features, nbits=8):
        super().__init__()
        self.register_buffer("scales", torch.zeros(in_features, dtype=torch.bfloat16))
        # Symmetric range, e.g. [-127, 127] for 8 bits and [-7, 7] for 4 bits.
        self.register_buffer("low_limit", torch.tensor(- 2 ** (nbits - 1) + 1))
        self.register_buffer("up_limit", torch.tensor(2 ** (nbits - 1) - 1))

        self.weight = nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.int8, requires_grad=False), requires_grad=False)

    def init_weights(self, target_weight, hessian, damp_ratio=0.1):
        """Quantize ``target_weight`` column by column with GPTQ error compensation.

        Args:
            target_weight: full-precision weight, shape ``(out_features, in_features)``.
                Not modified.
            hessian: ``(in_features, in_features)`` matrix built from the layer
                inputs (``2 X X^T``). Not modified.
            damp_ratio: fraction of the mean Hessian diagonal added to the
                diagonal before inversion, for numerical stability. The
                default 0.1 keeps the previous behaviour.

        Raises:
            AssertionError: if ``target_weight`` does not match the layer shape.

        All work happens in float32 on private copies. The caller's tensors stay
        untouched, and the many small error-compensation updates are not lost
        to bf16 rounding.
        """
        assert target_weight.shape == self.weight.shape

        W = target_weight.detach().to(device=self.weight.device, dtype=torch.float32, copy=True)
        H = hessian.detach().to(device=self.weight.device, dtype=torch.float32, copy=True)

        diag_index = torch.arange(H.shape[0], device=H.device)
        mean_diag = torch.mean(torch.diag(H))
        if mean_diag == 0:
            # All-zero Hessian (the layer never saw a nonzero input): damp with
            # the identity so the matrix stays invertible.
            mean_diag = torch.ones_like(mean_diag)
        H[diag_index, diag_index] += damp_ratio * mean_diag

        invH = torch.inverse(H)

        for idx in range(W.shape[1]):
            target = W[:, idx]
            max_abs = target.abs().max()
            if max_abs == 0:
                scale = 1
            else:
                scale = self.up_limit / max_abs
            self.scales[idx] = scale
            # Quantize with the bf16-rounded scale that is actually stored, so
            # the compensated error is the one forward() will produce.
            stored_scale = self.scales[idx].float()

            quantized = torch.clamp(torch.round(target * stored_scale), min=self.low_limit, max=self.up_limit).to(torch.int8)
            self.weight[:, idx] = quantized
            # Same bf16 dequantization as in forward().
            dequant = (quantized.to(torch.bfloat16) / self.scales[idx]).float()

            # Spread this column's quantization error over the not-yet-quantized
            # columns: W[:, idx+1:] += -(w - q) / [H^-1]_qq * H^-1[q, idx+1:].
            d = invH[idx, idx].clone()
            err = (target - dequant) / d
            W[:, idx + 1:] -= err.unsqueeze(1) * invH[idx, idx + 1:].unsqueeze(0)
            # Remove row/column idx from the inverse (one Gaussian-elimination step).
            invH -= torch.outer(invH[:, idx], invH[idx, :]) / d

    def forward(self, x):
        # Dequantize column-wise (scales broadcast over the last, input dim).
        weight_bf16 = self.weight.to(torch.bfloat16) / self.scales
        return x @ weight_bf16.t()

In [None]:
# Snapshot the linear layers before replacing any of them. Swapping modules
# while walking model.named_modules() mutates the tree being iterated, and
# rebuilding dict(model.named_modules()) for every layer is quadratic.
linear_layers = [(name, layer) for name, layer in model.named_modules() if isinstance(layer, nn.Linear)]
for name, layer in tqdm(linear_layers):
    out_features, in_features = layer.weight.data.shape
    quantized_layer = GPTQQuantizedLinear(in_features, out_features).to(device)
    quantized_layer.init_weights(layer.weight.data, layer.hessian)
    del layer.hessian  # no longer needed once this layer is quantized
    parent_module_name, attr_name = name.rsplit('.', 1) if '.' in name else ('', name)
    # get_submodule('') returns the model itself (for top-level layers).
    setattr(model.get_submodule(parent_module_name), attr_name, quantized_layer)
del linear_layers  # release the original bf16 layers

427it [01:25,  4.97it/s]


In [None]:
# Perplexity after 8-bit GPTQ quantization.
GPTQ_res = compute_perplexity(model, tokenizer, input_texts=validation_texts)
print("GPTQ model perplexity: ", GPTQ_res)

100%|██████████| 5000/5000 [05:07<00:00, 16.26it/s]

GPTQ model perplexity:  44.36175512440009





Compared to our linear quantization, we see a slight improvement. Let's see how our model works in the 4-bit setting.

In [None]:
# Fresh bf16 model for the 4-bit GPTQ experiment.
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
model = model.eval()
model.to(device)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576, padding_idx=2)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm

In [None]:
# Fresh list: the previous one only holds handles of already-removed hooks.
hooks = []
linear_layers = [layer for layer in model.modules() if isinstance(layer, nn.Linear)]

# The hook adds to an existing `hessian` attribute, so drop any left over from a
# previous run of this cell to avoid double-counting.
for layer in linear_layers:
    if hasattr(layer, "hessian"):
        del layer.hessian

for layer in tqdm(linear_layers):
    hooks.append(layer.register_forward_hook(update_hessian_hook))

# Remove the hooks even if a forward pass raises; otherwise they stay attached
# and the next run would accumulate every Hessian twice.
try:
    with torch.no_grad():
        for inp in tqdm(calibration_inputs):
            # Only the layer inputs are needed, so no labels / loss computation.
            model(**inp)
finally:
    for hook in hooks:
        hook.remove()
    hooks.clear()

for layer in tqdm(linear_layers):
    layer.hessian /= len(calibration_inputs)

427it [00:00, 280145.13it/s]
100%|██████████| 128/128 [00:09<00:00, 13.52it/s]
427it [00:00, 106018.34it/s]


In [None]:
# Snapshot the linear layers before replacing any of them. Swapping modules
# while walking model.named_modules() mutates the tree being iterated, and
# rebuilding dict(model.named_modules()) for every layer is quadratic.
linear_layers = [(name, layer) for name, layer in model.named_modules() if isinstance(layer, nn.Linear)]
for name, layer in tqdm(linear_layers):
    out_features, in_features = layer.weight.data.shape
    quantized_layer = GPTQQuantizedLinear(in_features, out_features, nbits=4).to(device)
    quantized_layer.init_weights(layer.weight.data, layer.hessian)
    del layer.hessian  # no longer needed once this layer is quantized
    parent_module_name, attr_name = name.rsplit('.', 1) if '.' in name else ('', name)
    # get_submodule('') returns the model itself (for top-level layers).
    setattr(model.get_submodule(parent_module_name), attr_name, quantized_layer)
del linear_layers  # release the original bf16 layers

427it [01:28,  4.80it/s]


In [None]:
# Perplexity after 4-bit GPTQ quantization.
GPTQ_res = compute_perplexity(model, tokenizer, input_texts=validation_texts)
print("GPTQ 4bit model perplexity: ", GPTQ_res)

100%|██████████| 5000/5000 [05:09<00:00, 16.16it/s]

GPTQ 4bit model perplexity:  143.13055899315862





Here the performance is still quite far from ideal, but it is much better than with the naive approach.

Possible further steps:

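- 1 Investigate how 4-bit quantization can be improved. Test the implementation and make sure the method is implemented 100% correctly, as the perplexity for 4-bit quantization looks suspicious. For example, the paper quantizes the model sequentially, collecting each layer's Hessian from the outputs of the already-quantized preceding layers. Here, all Hessians are collected once from the full-precision model.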
- 2 For real-world scenarios, we need to control data types for buffers and possibly compute GPTQ layer by layer to avoid extra memory allocations.