In [1]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
import torch.nn.functional as F

from time import perf_counter
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from utils import generate

In [2]:
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [3]:
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"

In [4]:
# modify the dtype of the model to float32
def get_float32_dtype(self):
    return torch.float32

GPT2Model.dtype = property(get_float32_dtype)

In [5]:
# model memory footprint
model.get_memory_footprint()

510342192

In [6]:
# quantization function
def quantize(t):
    # min and max values of the tensor
    min_val, max_val = t.min(), t.max()

    # scale: used to quantize the tensor between 0 and 255
    # zero point: used to subtract from the tensor during quantization
    scale = (max_val - min_val) / 255
    zero_point = min_val

    # quantize and clamp to ensure we're in [0, 255]
    t_quant = (t - zero_point) / scale  # f = (q - z) * s
    t_quant = torch.clamp(t_quant, min=0, max=255) # handling rounding errors

    # keep track of scale and zero point to reverse the quantization
    # we cannot use quantized tensors in computation
    state = (scale, zero_point)

    # cast to uint8
    t_quant = t_quant.type(torch.uint8)
    return t_quant, state

In [7]:
# testing on a random tensor from the model
# tensor: from the 1st attention block layer
t = model.transformer.h[0].attn.c_attn.weight.data
print(t, t.shape)

tensor([[-0.4738, -0.2614, -0.0978,  ...,  0.0513, -0.0584,  0.0250],
        [ 0.0874,  0.1473,  0.2387,  ..., -0.0525, -0.0113, -0.0156],
        [ 0.0039,  0.0695,  0.3668,  ...,  0.1143,  0.0363, -0.0318],
        ...,
        [-0.2592, -0.0164,  0.1991,  ...,  0.0095, -0.0516,  0.0319],
        [ 0.1517,  0.2170,  0.1043,  ...,  0.0293, -0.0429, -0.0475],
        [-0.4100, -0.1924, -0.2400,  ..., -0.0046,  0.0070,  0.0198]]) torch.Size([768, 2304])


In [8]:
t_q, state = quantize(t)
print(t_q, t_q.shape, t_q.min(), t_q.max())

tensor([[107, 116, 124,  ..., 130, 125, 129],
        [132, 135, 139,  ..., 126, 128, 127],
        [128, 131, 145,  ..., 133, 130, 127],
        ...,
        [116, 127, 137,  ..., 129, 126, 130],
        [135, 138, 133,  ..., 129, 126, 126],
        [110, 119, 117,  ..., 128, 128, 129]], dtype=torch.uint8) torch.Size([768, 2304]) tensor(0, dtype=torch.uint8) tensor(255, dtype=torch.uint8)


In [20]:
# dequantization function
def dequantize(t, state):
    scale, zero_point = state
    return t.to(torch.float32) * scale + zero_point

In [10]:
t_rev = dequantize(t_q, state)
print(t_rev, t_rev.shape)

tensor([[-0.4774, -0.2783, -0.1014,  ...,  0.0313, -0.0793,  0.0092],
        [ 0.0755,  0.1419,  0.2303,  ..., -0.0572, -0.0129, -0.0351],
        [-0.0129,  0.0534,  0.3630,  ...,  0.0976,  0.0313, -0.0351],
        ...,
        [-0.2783, -0.0351,  0.1861,  ...,  0.0092, -0.0572,  0.0313],
        [ 0.1419,  0.2082,  0.0976,  ...,  0.0092, -0.0572, -0.0572],
        [-0.4110, -0.2120, -0.2562,  ..., -0.0129, -0.0129,  0.0092]]) torch.Size([768, 2304])


In [11]:
# Difference between the original and dequantized tensor
torch.abs(t - t_rev)

tensor([[0.0035, 0.0170, 0.0036,  ..., 0.0200, 0.0209, 0.0158],
        [0.0119, 0.0055, 0.0084,  ..., 0.0046, 0.0017, 0.0195],
        [0.0168, 0.0161, 0.0038,  ..., 0.0167, 0.0050, 0.0032],
        ...,
        [0.0191, 0.0187, 0.0131,  ..., 0.0004, 0.0056, 0.0006],
        [0.0098, 0.0088, 0.0067,  ..., 0.0202, 0.0143, 0.0097],
        [0.0010, 0.0196, 0.0162,  ..., 0.0084, 0.0199, 0.0107]])

In [12]:
response_expected = generate(
    model,
    tokenizer,
    [("The quick brown fox jumped over the", 10)]
)[0]
response_expected

'The quick brown fox jumped over the fence and ran to the other side of the fence'

In [13]:
# quantize the model
def quantize_model(model):
    states = {}
    for name, param in model.named_parameters():
        param.requires_grad = False
        param.data, state = quantize(param.data)
        states[name] = state
    return model, states

In [14]:
quantized_model, states = quantize_model(model)
quantized_model.get_memory_footprint()

137022768

In [15]:
# calculating the size of the additional overhead from the state dictionary
def size_in_bytes(t):
    return t.numel() * t.element_size()

sum([
    size_in_bytes(v[0]) + size_in_bytes(v[1])
    for v in states.values()
])

1184

In [21]:
# dequantize the model
# dequantize each layer of the model during the forward pass / production

def dequantize_model(model, states):
    for name, param in model.named_parameters():
        state = states[name]
        param.data = dequantize(param.data, state)
    return model

In [22]:
dequantized_model = dequantize_model(quantized_model, states)
dequantized_model.get_memory_footprint()

510342192

In [24]:
response_expected = generate(
    dequantized_model,
    tokenizer,
    [("The quick brown fox jumped over the", 10)]
)[0]

response_expected

'The quick brown fox jumped over the fence.\n\nThe fox jumped over the fence'

- Hence this is a lossy de-quantization process, the output is incorrect.