**MVA 24/25 - LLM Final Project - On LLM Quantization**

# Optimal Brain Quantization

Samson GOUREVITCH, Thomas ROBERT

$\texttt{\{samson.gourevitch,thomas.robert.x21\}@polytechnique.edu}$

TODO: Project ABSTRACT


### Installation

In [1]:
# ! pip install datasets, transformers

import torch
from torch import nn
from torch.nn import functional as F
from sklearn.metrics import accuracy_score

from transformers import AutoModelForSequenceClassification, AutoTokenizer

from torch.utils.data import DataLoader
from datasets import load_dataset

import copy
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

Load BERT model

In [2]:
model_name = 'huawei-noah/TinyBERT_General_4L_312D'  # Example, check actual availability
tokenizer = AutoTokenizer.from_pretrained(model_name)

unquantized_model = AutoModelForSequenceClassification.from_pretrained(model_name)
unquantized_model = unquantized_model.to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at huawei-noah/TinyBERT_General_4L_312D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Prepare data

In [3]:
dataset = load_dataset("glue", "qnli")

# Tokenize the dataset
def tokenize_function(example):
    return tokenizer(example["sentence"], truncation=True, padding="max_length", max_length=128)

tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format("torch", columns=["input_ids", "attention_mask", "label"])

### Model fine-tuning on Downstream Task

In [4]:
num_epochs = 1
learning_rate = 5e-5
train_batch_size = 64
eval_batch_size = 64

optimizer = torch.optim.AdamW(unquantized_model.parameters(), lr=learning_rate)
# Create DataLoaders
train_dataloader = DataLoader(tokenized_datasets["train"], batch_size=train_batch_size, shuffle=True)
eval_dataloader = DataLoader(tokenized_datasets["validation"], batch_size=eval_batch_size)

for epoch in range(num_epochs):
    unquantized_model.train()
    total_loss = 0

    ### Training
    for batch in train_dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        outputs = unquantized_model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    avg_loss = total_loss / len(train_dataloader)

    ### Evaluation
    unquantized_model.eval()
    eval_loss = 0
    preds = []
    true_labels = []

    with torch.no_grad():
        for batch in eval_dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = unquantized_model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            eval_loss += loss.item()

            logits = outputs.logits
            preds.extend(torch.argmax(logits, dim=-1).cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
    
    avg_eval_loss = eval_loss / len(eval_dataloader)
    eval_accuracy = accuracy_score(true_labels, preds)
    print(f"Epoch {epoch+1}, Eval Loss: {avg_loss:.4f}, Eval Loss: {avg_eval_loss:.4f}, Eval Accuracy: {eval_accuracy:.4f}")

Epoch 1, Eval Loss: 0.6552, Eval Loss: 0.6429, Eval Accuracy: 0.6231


### Surrogate OBC Quantization

In [5]:
def quantize(weights, grid):
  """Quantize a weight to the nearest value in the quantization grid."""
  return torch.round(weights / grid) * grid

def naive_quantization(W, quantization_grid):
  return quantize(W, quantization_grid)

def compute_H(X):
  """Compute the average Hessian matrix H_F = 2 * (X_reshaped^T * X_reshaped) / N,
  where N is the total number of vectors from X.
  Assumes X is a tensor with shape [batch_size, seq_length, hidden_size].
  """
  # Reshape to merge batch and sequence dimensions.
  X_reshaped = X.view(-1, X.shape[-1])
  N = X_reshaped.size(0)
  H_avg = 2 * (X_reshaped.T @ X_reshaped) / N
  return H_avg

def compute_H_inv(H, damp=1e-3):
  """Compute the inverse of the Hessian matrix H_F.
  Adds a small damping factor to the diagonal for numerical stability.
  """
  I = torch.eye(H.shape[0], device=H.device, dtype=H.dtype)

  L = torch.linalg.cholesky(H + damp * I)
  L_inv = torch.inverse(L)
  H_inv = L_inv.T @ L_inv

  return H_inv

In [6]:
def quantize_OBC(W,X, quantization_grid):
  """Optimal Brain Quantization (OBQ) "
      W: weight matrix"
      X: input vector"
      quantization_grid: quantization precision"
  """
  H = compute_H(X)
  H_inv_save = compute_H_inv(H)

  Q = torch.zeros_like(W)

  for row in tqdm(range(W.shape[0])):
    w = W[row,:]
    H_inv = H_inv_save.clone()

    for col in range(W.shape[1]):
      q = quantize(w[col], quantization_grid)
      Q[row, col] = q

      error = w[col] - q
      d = torch.diag(H_inv)[col]

      Hinv_row = H_inv[col,]

      w = w - error  / d * (Hinv_row)
      H_inv = H_inv - torch.outer(Hinv_row, Hinv_row) / d

  return Q

### Quantizing the Model

Attach quantization hooks to perform quantization during the forward pass

In [7]:
model_OBC = copy.deepcopy(unquantized_model).to(device)
model_naive = copy.deepcopy(unquantized_model).to(device)

In [8]:
def save_hook(module, input, output):
    module.__dict__["unquantized_input"] = input
    module.__dict__["unquantized_output"] = output
    # module.__dict__["unquantized_weight"] = module.weight
    
    module._forward_hooks.popitem()

save_hooks = []

for layer_idx, layer in enumerate(model_OBC.bert.encoder.layer):
    save_hooks.append(layer.attention.self.query.register_forward_hook(save_hook))
    save_hooks.append(layer.attention.self.key.register_forward_hook(save_hook))
    save_hooks.append(layer.attention.self.value.register_forward_hook(save_hook))
    save_hooks.append(layer.attention.output.dense.register_forward_hook(save_hook))
    save_hooks.append(layer.intermediate.dense.register_forward_hook(save_hook))
    save_hooks.append(layer.output.dense.register_forward_hook(save_hook))

for layer_idx, layer in enumerate(model_naive.bert.encoder.layer):
    save_hooks.append(layer.attention.self.query.register_forward_hook(save_hook))
    save_hooks.append(layer.attention.self.key.register_forward_hook(save_hook))
    save_hooks.append(layer.attention.self.value.register_forward_hook(save_hook))
    save_hooks.append(layer.attention.output.dense.register_forward_hook(save_hook))
    save_hooks.append(layer.intermediate.dense.register_forward_hook(save_hook))
    save_hooks.append(layer.output.dense.register_forward_hook(save_hook))

In [9]:
# Load Calibration data
calibration_batch_size = 64
calibration_dataloader = DataLoader(tokenized_datasets["train"], batch_size=calibration_batch_size, shuffle=True)

calibration_batch = next(iter(calibration_dataloader))
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)

# Forward pass through the model to trigger the hooks with calibration data from
with torch.no_grad():
  print("Saving input/output for unquantized version to OBC model")
  outputs = model_OBC(input_ids=input_ids, attention_mask=attention_mask)

  print("Saving input/output for unquantized version to naive model")
  outputs = model_naive(input_ids=input_ids, attention_mask=attention_mask)

Saving input/output for unquantized version to OBC model
Saving input/output for unquantized version to naive model


In [10]:
quantization_grid = 0.01

weight_approximation_error_OBC = list()
reconstruction_error_OBC = list()
def OBC_hook(module, input, output, quantization_grid=quantization_grid):
    X_data = module.unquantized_input[0]
    W = module.weight.clone()
    quantized_weight = quantize_OBC(W, X_data, quantization_grid=quantization_grid)
    module.weight = nn.Parameter(quantized_weight)
    
    module._forward_hooks.popitem()
    new_output = module(*input)

    weight_approximation_error_OBC.append(F.mse_loss(W,module.weight))
    reconstruction_error_OBC.append(F.mse_loss(module.unquantized_output, new_output))

    del module.unquantized_input
    del module.unquantized_output



weight_approximation_error_naive = list()
reconstruction_error_naive = list()
def naive_hook(module, input, output, quantization_grid=quantization_grid):
    W = module.weight.clone()
    quantized_weight = naive_quantization(W, quantization_grid=quantization_grid)
    module.weight = nn.Parameter(quantized_weight)

    module._forward_hooks.popitem()
    new_output = module(*input)

    weight_approximation_error_naive.append(F.mse_loss(W,module.weight))
    reconstruction_error_naive.append(F.mse_loss(module.unquantized_output, new_output))

In [11]:
OBC_hooks = []
for layer_idx, layer in enumerate(model_OBC.bert.encoder.layer):
    OBC_hooks.append(layer.attention.self.query.register_forward_hook(OBC_hook))
    OBC_hooks.append(layer.attention.self.key.register_forward_hook(OBC_hook))
    OBC_hooks.append(layer.attention.self.value.register_forward_hook(OBC_hook))
    OBC_hooks.append(layer.attention.output.dense.register_forward_hook(OBC_hook))
    OBC_hooks.append(layer.intermediate.dense.register_forward_hook(OBC_hook))
    OBC_hooks.append(layer.output.dense.register_forward_hook(OBC_hook))

naive_hooks = []
for layer_idx, layer in enumerate(model_naive.bert.encoder.layer):
    naive_hooks.append(layer.attention.self.query.register_forward_hook(naive_hook))
    naive_hooks.append(layer.attention.self.key.register_forward_hook(naive_hook))
    naive_hooks.append(layer.attention.self.value.register_forward_hook(naive_hook))
    naive_hooks.append(layer.attention.output.dense.register_forward_hook(naive_hook))
    naive_hooks.append(layer.intermediate.dense.register_forward_hook(naive_hook))
    naive_hooks.append(layer.output.dense.register_forward_hook(naive_hook))

In [12]:
with torch.no_grad():
  print("Performing OBC quantization")
  outputs = model_OBC(input_ids=input_ids, attention_mask=attention_mask)

  print("Performing naive quantization")
  outputs = model_naive(input_ids=input_ids, attention_mask=attention_mask)

Performing OBC quantization


100%|██████████| 312/312 [00:11<00:00, 27.02it/s]
100%|██████████| 312/312 [00:11<00:00, 26.97it/s]
100%|██████████| 312/312 [00:11<00:00, 27.03it/s]
100%|██████████| 312/312 [00:11<00:00, 26.98it/s]
100%|██████████| 1200/1200 [00:44<00:00, 26.77it/s]
100%|██████████| 312/312 [00:45<00:00,  6.90it/s]
100%|██████████| 312/312 [00:11<00:00, 26.85it/s]
100%|██████████| 312/312 [00:11<00:00, 26.98it/s]
100%|██████████| 312/312 [00:11<00:00, 26.53it/s]
100%|██████████| 312/312 [00:11<00:00, 27.01it/s]
100%|██████████| 1200/1200 [00:44<00:00, 26.84it/s]
100%|██████████| 312/312 [00:45<00:00,  6.92it/s]
100%|██████████| 312/312 [00:11<00:00, 26.97it/s]
100%|██████████| 312/312 [00:11<00:00, 26.98it/s]
100%|██████████| 312/312 [00:11<00:00, 27.11it/s]
100%|██████████| 312/312 [00:11<00:00, 26.93it/s]
100%|██████████| 1200/1200 [00:44<00:00, 26.88it/s]
100%|██████████| 312/312 [00:45<00:00,  6.91it/s]
100%|██████████| 312/312 [00:11<00:00, 27.01it/s]
100%|██████████| 312/312 [00:11<00:00, 26.92

Performing naive quantization





### Saving Results

In [13]:
result_types = ["weight_approximation_error_OBC", "reconstruction_error_OBC", "weight_approximation_error_naive", "reconstruction_error_naive"]
layer_idx = [0, 1, 2, 3]
layer_types = ["query", "key", "value", "proj", "intermediate", "output"]


results = dict()
for result_type in result_types:
    results[result_type] = dict()

for layer_idx in layer_idx:
    for result_type in result_types:
        results[result_type]['layer_'+str(layer_idx)] = dict()    
    for layer_type in layer_types:
        results['weight_approximation_error_OBC']['layer_'+str(layer_idx)][layer_type] = weight_approximation_error_OBC.pop(0).item()
        results['reconstruction_error_OBC']['layer_'+str(layer_idx)][layer_type] = reconstruction_error_OBC.pop(0).item()
        results['weight_approximation_error_naive']['layer_'+str(layer_idx)][layer_type] = weight_approximation_error_naive.pop(0).item()
        results['reconstruction_error_naive']['layer_'+str(layer_idx)][layer_type] = reconstruction_error_naive.pop(0).item()

In [14]:
models = {'unquantized': unquantized_model, 'OBC': model_OBC, 'naive': model_naive}

for k,model in models.items():
    model.eval()

    eval_loss = 0
    preds = []
    true_labels = []

    with torch.no_grad():
        for batch in eval_dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            eval_loss += loss.item()

            logits = outputs.logits
            preds.extend(torch.argmax(logits, dim=-1).cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    avg_eval_loss = eval_loss / len(eval_dataloader)
    eval_accuracy = accuracy_score(true_labels, preds)

    results['loss_'+k] = avg_eval_loss
    results['accuracy_'+k]= eval_accuracy

    print(f"Model type {k}, Eval Loss: {avg_eval_loss:.4f}, Eval Accuracy: {eval_accuracy:.4f}")

Model type unquantized, Eval Loss: 0.6429, Eval Accuracy: 0.6231
Model type OBC, Eval Loss: 0.6429, Eval Accuracy: 0.6238
Model type naive, Eval Loss: 0.6428, Eval Accuracy: 0.6227


In [15]:
torch.save(unquantized_model.state_dict(), "model_unquantized_weights_"+str(quantization_grid)+".pth")
torch.save(model_OBC.state_dict(), "model_OBC_weights_"+str(quantization_grid)+".pth")
torch.save(model_naive.state_dict(), "model_naive_weights_"+str(quantization_grid)+".pth")

torch.save(results, "results_"+str(quantization_grid)+".pth")

In [16]:
# # Load the model weights
# quantization_grid = 0.01
# model_name = 'huawei-noah/TinyBERT_General_4L_312D'
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# unquantized_model = AutoModelForSequenceClassification.from_pretrained(model_name)
# model_OBC = AutoModelForSequenceClassification.from_pretrained(model_name)
# model_naive = AutoModelForSequenceClassification.from_pretrained(model_name)

# unquantized_model.load_state_dict(torch.load("model_unquantized_weights_"+str(quantization_grid)+".pth"))
# model_OBC.load_state_dict(torch.load("model_OBC_weights_"+str(quantization_grid)+".pth"))
# model_naive.load_state_dict(torch.load("model_naive_weights_"+str(quantization_grid)+".pth"))

# results = torch.load("results_"+str(quantization_grid)+".pth")

## References

Hassibi, B., Stork, D. G., & Wolff, G. J. (1993, March). **Optimal brain surgeon and general network pruning**. In IEEE international conference on neural networks (pp. 293-299). IEEE.

Frantar, E., & Alistarh, D. (2022). **Optimal brain compression: A framework for accurate post-training quantization and pruning**. Advances in Neural Information Processing Systems, 35, 4475-4488

Frantar, E., Ashkboos, S., Hoefler, T., & Alistarh, D. (2022). **Gptq: Accurate post-training quantization for generative pre-trained transformers**. arXiv preprint arXiv:2210.17323.

Jiao, X., Yin, Y., Shang, L., Jiang, X., Chen, X., Li, L., ... & Liu, Q. (2019). **Tinybert: Distilling bert for natural language understanding.** arXiv preprint arXiv:1909.10351.

Wang, A., Singh, A., Michael, J., Hill, F., Levy, O., & Bowman, S. R. (2018). **GLUE: A multi-task benchmark and analysis platform for natural language understanding**. arXiv preprint arXiv:1804.07461.