In [1]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import random
import time
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from helpers import generate

In [2]:
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


In [3]:
# Define PAD token = EOS Token = 50256
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# pad on the left so we can append new tokens on the right
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"

In [4]:
# fix dtype post quantization to "pretend" to be fp32
def get_float32_dtype(self):
    return torch.float32
GPT2Model.dtype = property(get_float32_dtype)


In [5]:
model.get_memory_footprint()

510342192

In [6]:
def quantize(t):
    # obtain range of values in the tensor to map between 0 and 255
    min_val, max_val = t.min(), t.max()

    # determine the "zero-point", or value in the tensor to map to 0
    scale = (max_val - min_val) / 255
    zero_point = min_val

    # quantize and clamp to ensure we're in [0, 255]
    t_quant = (t - zero_point) / scale
    t_quant = torch.clamp(t_quant, min=0, max=255)

    # keep track of scale and zero_point for reversing quantization
    state = (scale, zero_point)

    # cast to uint8 and return
    t_quant = t_quant.type(torch.uint8)
    return t_quant, state

In [7]:
t = model.transformer.h[0].attn.c_attn.weight.data
print(t, t.shape)

tensor([[-0.4738, -0.2614, -0.0978,  ...,  0.0513, -0.0584,  0.0250],
        [ 0.0874,  0.1473,  0.2387,  ..., -0.0525, -0.0113, -0.0156],
        [ 0.0039,  0.0695,  0.3668,  ...,  0.1143,  0.0363, -0.0318],
        ...,
        [-0.2592, -0.0164,  0.1991,  ...,  0.0095, -0.0516,  0.0319],
        [ 0.1517,  0.2170,  0.1043,  ...,  0.0293, -0.0429, -0.0475],
        [-0.4100, -0.1924, -0.2400,  ..., -0.0046,  0.0070,  0.0198]]) torch.Size([768, 2304])


In [8]:
t_q, state = quantize(t)
print(t_q, t_q.min(), t_q.max())

tensor([[107, 116, 124,  ..., 130, 125, 129],
        [132, 135, 139,  ..., 126, 128, 127],
        [128, 131, 145,  ..., 133, 130, 127],
        ...,
        [116, 127, 137,  ..., 129, 126, 130],
        [135, 138, 133,  ..., 129, 126, 126],
        [110, 119, 117,  ..., 128, 128, 129]], dtype=torch.uint8) tensor(0, dtype=torch.uint8) tensor(255, dtype=torch.uint8)


In [9]:
def dequantize(t, state):
    scale, zero_point = state
    return t.to(torch.float32) * scale + zero_point

In [10]:
t_rev = dequantize(t_q, state)
print(t_rev)

tensor([[-0.4774, -0.2783, -0.1014,  ...,  0.0313, -0.0793,  0.0092],
        [ 0.0755,  0.1419,  0.2303,  ..., -0.0572, -0.0129, -0.0351],
        [-0.0129,  0.0534,  0.3630,  ...,  0.0976,  0.0313, -0.0351],
        ...,
        [-0.2783, -0.0351,  0.1861,  ...,  0.0092, -0.0572,  0.0313],
        [ 0.1419,  0.2082,  0.0976,  ...,  0.0092, -0.0572, -0.0572],
        [-0.4110, -0.2120, -0.2562,  ..., -0.0129, -0.0129,  0.0092]])


In [11]:
torch.abs(t - t_rev)

tensor([[0.0035, 0.0170, 0.0036,  ..., 0.0200, 0.0209, 0.0158],
        [0.0119, 0.0055, 0.0084,  ..., 0.0046, 0.0017, 0.0195],
        [0.0168, 0.0161, 0.0038,  ..., 0.0167, 0.0050, 0.0032],
        ...,
        [0.0191, 0.0187, 0.0131,  ..., 0.0004, 0.0056, 0.0006],
        [0.0098, 0.0088, 0.0067,  ..., 0.0202, 0.0143, 0.0097],
        [0.0010, 0.0196, 0.0162,  ..., 0.0084, 0.0199, 0.0107]])

In [12]:
response_expected = generate(
    model,
    tokenizer,
    [("The quick brown fox jumped over the", 10)]
)[0]
response_expected

NameError: name 'generate' is not defined

In [13]:
def quantize_model(model):
    states = {}
    for name, param in model.named_parameters():
        param.requires_grad = False
        param.data, state = quantize(param.data)
        states[name] = state
    return model, states




In [14]:
quant_model, states = quantize_model(model)

In [15]:
quant_model.get_memory_footprint()

137022768

In [17]:
states

{'transformer.wte.weight': (tensor(0.0120), tensor(-1.2698)),
 'transformer.wpe.weight': (tensor(0.0337), tensor(-4.5381)),
 'transformer.h.0.ln_1.weight': (tensor(0.0008), tensor(0.0419)),
 'transformer.h.0.ln_1.bias': (tensor(0.0018), tensor(-0.2589)),
 'transformer.h.0.attn.c_attn.weight': (tensor(0.0221), tensor(-2.8436)),
 'transformer.h.0.attn.c_attn.bias': (tensor(0.0099), tensor(-1.3371)),
 'transformer.h.0.attn.c_proj.weight': (tensor(0.0250), tensor(-3.3171)),
 'transformer.h.0.attn.c_proj.bias': (tensor(0.0185), tensor(-2.6844)),
 'transformer.h.0.ln_2.weight': (tensor(0.0057), tensor(0.0453)),
 'transformer.h.0.ln_2.bias': (tensor(0.0055), tensor(-0.6648)),
 'transformer.h.0.mlp.c_fc.weight': (tensor(0.0271), tensor(-2.3131)),
 'transformer.h.0.mlp.c_fc.bias': (tensor(0.0042), tensor(-0.7462)),
 'transformer.h.0.mlp.c_proj.weight': (tensor(0.0479), tensor(-6.1433)),
 'transformer.h.0.mlp.c_proj.bias': (tensor(0.0098), tensor(-1.0288)),
 'transformer.h.1.ln_1.weight': (tenso

In [16]:
def size_in_bytes(t):
    return t.numel() * t.element_size()



In [18]:
sum([
    size_in_bytes(v[0]) + size_in_bytes(v[1])
    for v in states.values()
])

1184

In [19]:
def dequantize_model(model, states):
    for name, param in model.named_parameters():
        param.data = dequantize(param.data, states[name])
    return model

In [20]:
dequant_model = dequantize_model(quant_model, states)

In [22]:
dequant_model.get_memory_footprint()

510342192

In [None]:
response_expected = generate(
    dequant_model,
    tokenizer,
    [("The quick brown fox jumped over the", 10)]

)[0]
response_expected