In [29]:
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import Adam, AdamW

import matplotlib.pyplot as plt

from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

import sys
sys.path.append("../../") 
from qlib.quantizers.lut_quantizer import QuantizerLUT
from qlib.quantizers.lut_quantizer_reparametrized import QuantizerLUT_reparametrized

from qlib.quantizers.lsq_quantizer_autograd import QuantizerLSQwithAutograd
from qlib.quantizers.lsq_quantizer import QuantizerLSQ

from qlib.initializers.criterias import MomentCriteria
from qlib.initializers.greedy_step_offset_initializer import GreedyInitializer


# Compute device for the experiments: use the GPU when one is available, otherwise fall back to CPU
# (a hard-coded 'cuda' makes every later `.to(DEVICE)` fail on a CPU-only machine).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
def load_llama(path_to_pretrained):
	"""Load a causal-LM checkpoint together with its tokenizer.

	Args:
		path_to_pretrained: directory (or any identifier accepted by
			``from_pretrained``) holding the tokenizer and model weights.

	Returns:
		``(tokenizer, model)`` tuple.
	"""
	tok = AutoTokenizer.from_pretrained(path_to_pretrained)
	lm = AutoModelForCausalLM.from_pretrained(path_to_pretrained)
	return tok, lm

#path_to_pretrained = '/home/msst/repo/pretrained_models/AMD-Llama-135m'
#path_to_pretrained = '/home/msst/repo/pretrained_models/TinyLlama_v1.1'
path_to_pretrained = '/home/msst/repo/pretrained_models/Llama2-7b-hf'

tokenizer, model = load_llama(path_to_pretrained)
# Target tensor for the quantizer experiments: the first decoder layer's query-projection weight.
# `.detach()` cuts it off from the model parameter (which requires grad), so the later
# `loss.backward()` calls train only the quantizers and never write gradients into the model.
# Note: detach() shares storage with the parameter; nothing below modifies `w` in place.
w = model.get_decoder().layers[0].self_attn.q_proj.weight.detach()
w.shape

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

torch.Size([4096, 4096])

In [51]:
GROUP_SIZE = 64
BIT_WIDTH = 2

### LUT ###

# Initialisation settings for the LUT quantizers.
# NOTE(review): the same dict object (including its single MomentCriteria instance) is passed to
# both LUT quantizers below; that is only safe if neither quantizer mutates it — confirm.
initialization_params = {
	'optim' : 'Adam',
	'lr' : 3e-3,
	'steps' : 500,
	'grad_norm': False,
	'criteria' : MomentCriteria(p=4)
}

lut_quantizer = QuantizerLUT(
	group_size=GROUP_SIZE,
	bit_width=BIT_WIDTH,
	initialization_params=initialization_params
)

lut_quantizer_r = QuantizerLUT_reparametrized(
	group_size=GROUP_SIZE,
	bit_width=BIT_WIDTH,
	initialization_params=initialization_params
)

### LSQ ###

# Greedy initializer handed to both LSQ quantizers.
# NOTE(review): one instance is shared by both; fine only if it keeps no per-quantizer state — confirm.
greed_initializer = GreedyInitializer(
	criteria=MomentCriteria(p=4, sum_along_axis=-1),
	n_grid_steps=25,
	n_grid_zooms=2
)

lsq_quantizer = QuantizerLSQ(
	group_size=GROUP_SIZE,
	bit_width=BIT_WIDTH,
	use_offset=False,
	initializer=greed_initializer
)

lsq_quantizer_ag = QuantizerLSQwithAutograd(
	group_size=GROUP_SIZE,
	bit_width=BIT_WIDTH,
	use_offset=False,
	initializer=greed_initializer
)


# Detach before moving: when `w` is still tied to the pretrained model's parameter (requires_grad=True),
# a plain `.to()` keeps it in the autograd graph and every later `loss.backward()` would also push a
# gradient back into the model's weight. The experiments only optimise quantizer parameters.
w = w.detach().to(DEVICE)

lut_quantizer = lut_quantizer.to(DEVICE)
lut_quantizer_r = lut_quantizer_r.to(DEVICE)

lsq_quantizer = lsq_quantizer.to(DEVICE)
lsq_quantizer_ag = lsq_quantizer_ag.to(DEVICE)

In [31]:
def moment_loss(x, x_q, p=4):
	"""Unnormalised p-th absolute moment of the quantisation error.

	Args:
		x: reference tensor.
		x_q: quantised tensor, broadcast-compatible with ``x``.
		p: exponent applied to the element-wise absolute error (default 4).

	Returns:
		0-dim tensor equal to ``sum(|x - x_q| ** p)`` over all elements.
	"""
	residual = (x - x_q).abs()
	return residual.pow(p).sum()

# Baseline p=4 moment error of every quantizer before any fine-tuning.
# Not wrapped in torch.no_grad(): the first forward call may run a quantizer's own
# (optimisation-based) initialisation — TODO confirm before adding no_grad.
for label, quantizer in (
	("LUT", lut_quantizer),
	("LUT reparametrized", lut_quantizer_r),
	("LSQ", lsq_quantizer),
	("LSQ_AG", lsq_quantizer_ag),
):
	print(f"{label}:", moment_loss(w, quantizer(w)).item())

LUT: 0.0832228735089302
LUT reparametrized: 0.0754116028547287
LSQ: 0.48644912242889404
LSQ_AG: 0.48644912242889404


In [52]:
# Re-measure the two LSQ variants (e.g. after re-running the setup cell).
for label, quantizer in (("LSQ", lsq_quantizer), ("LSQ_AG", lsq_quantizer_ag)):
	print(f"{label}:", moment_loss(w, quantizer(w)).item())

LSQ: 0.7248697280883789
LSQ_AG: 0.7248697280883789


In [53]:
# Learning rate and number of optimisation steps used by the LSQ fine-tuning loops.
lr, n_steps = 1e-4, 1000

In [54]:
# Fine-tune the autograd-based LSQ quantizer with Adam + cosine LR decay, minimising the
# p=4 moment error against `w`.
# `w` is detached so that only the quantizer's parameters receive gradients: if `w` is still
# attached to the pretrained model's parameter, every backward() would also accumulate a
# gradient into that weight (extra compute and a device-to-host copy per step).
w_target = w.detach()
optim_lsq_ag = Adam(lsq_quantizer_ag.parameters(), lr=lr)
scheduler_lsq_ag = CosineAnnealingLR(optim_lsq_ag, T_max=n_steps)

for i in range(n_steps):
	optim_lsq_ag.zero_grad()
	loss = moment_loss(w_target, lsq_quantizer_ag(w_target))
	loss.backward()
	optim_lsq_ag.step()
	scheduler_lsq_ag.step()
# Loss of the last iteration, measured before that iteration's optimizer step.
print(loss)

tensor(0.5224, device='cuda:0', grad_fn=<SumBackward0>)


In [55]:
# Fine-tune the LSQ quantizer with Adam + cosine LR decay, minimising the p=4 moment
# error against `w`.
# `w` is detached so that only the quantizer's parameters receive gradients: if `w` is still
# attached to the pretrained model's parameter, every backward() would also accumulate a
# gradient into that weight (extra compute and a device-to-host copy per step).
w_target = w.detach()
optim_lsq = Adam(lsq_quantizer.parameters(), lr=lr)
scheduler_lsq = CosineAnnealingLR(optim_lsq, T_max=n_steps)

for i in range(n_steps):
	optim_lsq.zero_grad()
	loss = moment_loss(w_target, lsq_quantizer(w_target))
	loss.backward()
	optim_lsq.step()
	scheduler_lsq.step()
# Loss of the last iteration, measured before that iteration's optimizer step.
print(loss)

tensor(0.5588, device='cuda:0', grad_fn=<SumBackward0>)


### Plot

In [7]:
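# Disabled: overlay the LUT levels (red) and borders (yellow) at index 0 on a histogram of the weights.
# NOTE: `lut_quantizer_` is not defined in this notebook — point it at `lut_quantizer` or `lut_quantizer_r` before re-enabling.
# NOTE: `plt.hist` on a 2-D array draws one histogram per column; flatten `x` (or select a single group) first.
# borders = lut_quantizer_.borders[0].detach().cpu()
# levels = lut_quantizer_.levels.data[0].detach().cpu()
# x = w.detach().cpu()

# plt.hist(x)
# plt.vlines(levels, 0, 1, colors='r', linestyles='solid', label='')
# plt.vlines(borders, -0.5, 0.5, colors='y', linestyles='solid', label='')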