In [1]:
import torch
import sys
sys.path.append("/home/msst/repo/Quantization")
import qlib
DEVICE = 'cuda:0'
from qlib.utils.incoherence_preprocessing.incoherence_process_functions import incoherence_process, incoherence_preprocess

In [2]:
# Sanity check on synthetic data: quantize a random Gaussian matrix, pack the
# trellis states, and compare the two reconstruction paths against the input.
torch.manual_seed(0)  # fixed seed so the printed errors are reproducible across runs

m = 2048
n = 2048

w = torch.randn(m, n, device=DEVICE)

quantizer = qlib.trellis_quantizer(
    L=16,
    K=3,
    V=2,
    T=256,
    decode_mode="LowBitSym",
    tlut_bits=10,
).to(DEVICE)

# `reco` is the reconstruction returned directly by quantize(); it is kept for
# interactive inspection, the checks below go through the packed representation.
reco, states = quantizer.quantize(w)
packed = quantizer.pack_trellis(states)

# Reference reconstruction path: per-row mean squared error against the input.
reco_ref = quantizer.reconstruct_weight(packed, w.shape)
err_ref = torch.mean((reco_ref - w) ** 2, dim=-1)
print(f"error (unpack): {err_ref.mean():.3f} ± {err_ref.std():.3f}")

# Fast reconstruction path: should report the same error statistics as above.
reco_fast = quantizer.reconstruct_weight_fast(packed, w.shape)
err_fast = torch.mean((reco_fast - w) ** 2, dim=-1)
print(f"error (fast)  : {err_fast.mean():.3f} ± {err_fast.std():.3f}")


100%|██████████| 64/64 [00:05<00:00, 12.49it/s]
100%|██████████| 64/64 [00:04<00:00, 14.58it/s]


error (unpack): 0.021 ± 0.001
error (fast)  : 0.021 ± 0.001


In [3]:
# NOTE(review): a bare `raise` with no active exception fails with
# "RuntimeError: No active exception to reraise". Presumably this is a manual
# stop so that "Run All" halts before the model-loading cells below — confirm,
# and remove it before sharing the notebook.
raise

RuntimeError: No active exception to reraise

### Without per-group scaling (single global scale)

In [3]:
# Baseline experiment: quantize one Llama-2 projection weight after incoherence
# preprocessing, normalising with a SINGLE scale (the std over all entries).
GROUP_SIZE = 256  # width of the rows used for reshaping and for the per-group error

fp_model = qlib.load_model('Llama2-7b-hf', torch_dtype=torch.float16)
W = fp_model.get_submodule('model.layers.0.self_attn.q_proj').weight.data.to(DEVICE)
Wr, SU, SV = incoherence_preprocess(W)

# Normalise by one global standard deviation, then restore the original shape.
Wr = Wr.reshape(-1, GROUP_SIZE)
scales = Wr.std()
Wr_scaled = Wr / scales
Wr_scaled = Wr_scaled.reshape_as(W)

quantizer = qlib.trellis_quantizer(
    L=16,
    K=2,
    V=2,
    T=256,
    decode_mode="LowBitSym",
    tlut_bits=10,
).to(DEVICE)

Wr_scaled_q, states = quantizer.quantize(Wr_scaled)

# Undo the normalisation on the quantized weight.
Wr_scaled_q = Wr_scaled_q.reshape(-1, GROUP_SIZE)
Wr_q = Wr_scaled_q * scales
Wr_q = Wr_q.reshape_as(W)

# Map the quantized weight back through the incoherence transform.
W_q = incoherence_process(Wr_q, SU, SV)

# The model is requested in float16; squared errors of order 1e-5 are below the
# smallest normal float16 value, so the metric is evaluated in float32 to avoid
# rounding/underflow in the error itself.
sq_err = (W_q.float() - W.float()) ** 2
err = torch.mean(sq_err.reshape(-1, GROUP_SIZE), dim=-1)
print(f"error: {err.mean():.3e} ± {err.std():.3e}")

# Previously recorded results (presumably from earlier runs/settings — TODO confirm which):
#error: 1.512e-05 ± 2.830e-06

#error: 5.038e-05 ± 5.205e-06

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 256/256 [00:19<00:00, 12.94it/s]
100%|██████████| 256/256 [00:20<00:00, 12.79it/s]


error: 1.377e-05 ± 2.306e-06


### With per-group scaling (one scale per group of 256 values)

In [4]:
# Same experiment as the global-scale variant, but every row of GROUP_SIZE
# consecutive values is normalised by its own standard deviation.
GROUP_SIZE = 256  # width of the rows used for scaling and for the per-group error

fp_model = qlib.load_model('Llama2-7b-hf', torch_dtype=torch.float16)
W = fp_model.get_submodule('model.layers.0.self_attn.q_proj').weight.data.to(DEVICE)
Wr, SU, SV = incoherence_preprocess(W)

# One scale per group: `scales` has shape (num_groups, 1) and broadcasts over each row.
Wr = Wr.reshape(-1, GROUP_SIZE)
scales = Wr.std(dim=-1, keepdim=True)
Wr_scaled = Wr / scales
Wr_scaled = Wr_scaled.reshape_as(W)

quantizer = qlib.trellis_quantizer(
    L=16,
    K=2,
    V=2,
    T=256,
    decode_mode="LowBitSym",
    tlut_bits=10,
).to(DEVICE)

Wr_scaled_q, states = quantizer.quantize(Wr_scaled)

# Undo the per-group normalisation, then map back through the incoherence transform.
Wr_scaled_q = Wr_scaled_q.reshape(-1, GROUP_SIZE)
Wr_q = Wr_scaled_q * scales
Wr_q = Wr_q.reshape_as(W)
W_q = incoherence_process(Wr_q, SU, SV)

# The model is requested in float16; squared errors of order 1e-5 are below the
# smallest normal float16 value, so the metric is evaluated in float32 to avoid
# rounding/underflow in the error itself.
sq_err = (W_q.float() - W.float()) ** 2
err = torch.mean(sq_err.reshape(-1, GROUP_SIZE), dim=-1)
print(f"error: {err.mean():.3e} ± {err.std():.3e}")

# Previously recorded result (presumably from an earlier run/setting — TODO confirm which):
# error: 5.013e-05 ± 5.153e-06

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 256/256 [00:19<00:00, 13.41it/s]
100%|██████████| 256/256 [00:20<00:00, 12.79it/s]


error: 1.311e-05 ± 1.842e-06


In [5]:
# %%timeit
# reco = quantizer.reconstruct_weight(packed, w.shape)
# NOTE(review): disabled timing cell for the reference reconstruction path.
# If re-enabled at this position, `quantizer` is the K=2 instance rebound by the
# Llama cells above, while `packed` and `w` were produced with the K=3 quantizer
# of the synthetic test — re-create the matching quantizer before timing.

In [6]:
# %%timeit
# reco_fast = quantizer.reconstruct_weight_fast(packed, w.shape)
# NOTE(review): disabled timing cell for the fast reconstruction path.
# If re-enabled at this position, `quantizer` is the K=2 instance rebound by the
# Llama cells above, while `packed` and `w` were produced with the K=3 quantizer
# of the synthetic test — re-create the matching quantizer before timing.