In [1]:
# Environment: pin the visible GPU, point the Hugging Face caches at a shared directory,
# and set the OpenMP thread count.
%env CUDA_VISIBLE_DEVICES=0
%env TRANSFORMERS_CACHE=/mnt/LLM/hub
%env HF_HOME=/mnt/LLM/hub
%env OMP_NUM_THREADS=16

import os
import sys
# Make the parent directory importable so that `src.aq` below resolves.
sys.path.insert(0, '..')

import time
import random
from tqdm.auto import trange
import ipynbname  # pip install ipynbname

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from src.aq import QuantizedWeight


torch.set_num_threads(16)
# Disable TF32 so matmuls/convolutions on GPU run in full fp32 precision.
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): despite the name this is a single .pt file (loaded with torch.load) holding the
# input activations of layer 10's self_attn.q_proj.
input_loading_dir = '/extra_disk_1/vahe1994/BRRR/layer10.self_attn.q_proj.input_activation.pt'  # <-- stealing from Vahe
# Quantizer shape: number of codebooks, bits per code (2**nbits entries per codebook),
# and the group of weights (out_group_size x in_group_size) that one code set represents.
num_codebooks = 4
nbits_per_codebook = 8
out_group_size = 1
in_group_size = 16
# Number of activation rows per chunk when accumulating XTX (also logged to wandb).
batch_size = 16384
# Passed to beam_search_update_codes_; presumably the beam width (1 = greedy) — TODO confirm in src.aq.
beam_size = 1
# Run a beam-search code update every `beam_search_epochs` Adam epochs.
beam_search_epochs = 100
# Print the loss every `print_frequency` epochs.
print_frequency = 10
scale_nbits = 0    # 0 means no scales, 16 means no compression;
# NOTE(review): only logged to wandb in this notebook; it is not passed to QuantizedWeight — confirm it has an effect.
codebook_values_nbits = 16  # less than 16 means we quantize codebooks as well
# Passed to QuantizedWeight as max_iter (the init step reports "initializing with kmeans").
init_max_iter = 100
# Scale of the entropy-derived code penalties used during beam search.
entropy_regularizer = 1e-3
# Beam-search updates before this epoch use zero penalties (no entropy regularization).
entropy_warmup_epochs = 500

env: CUDA_VISIBLE_DEVICES=0
env: TRANSFORMERS_CACHE=/mnt/LLM/hub
env: HF_HOME=/mnt/LLM/hub
env: OMP_NUM_THREADS=16




In [2]:
import wandb

# Resolve the notebook name once: it is used both to point wandb at the notebook file
# (for save_code) and to build the run name.
notebook_name = ipynbname.name()
os.environ["WANDB_NOTEBOOK_NAME"] = os.path.join(os.getcwd(), notebook_name + ".ipynb")

# start a new wandb run to track this script
run = wandb.init(
    # set the wandb project where this run will be logged
    dir=os.getcwd(),
    project="AddQuantization",
    entity="rock-and-roll",
    save_code=True,
    name=f"{notebook_name}_AQ_{num_codebooks=}_{out_group_size=}_{in_group_size=}_{nbits_per_codebook=}_{beam_search_epochs=}",
    settings=wandb.Settings(code_dir="."),
    # track hyperparameters and run metadata; every config-cell knob that changes training belongs here
    config={
        "num_codebooks": num_codebooks,
        "out_group_size": out_group_size,
        "in_group_size": in_group_size,
        "group_size": out_group_size * in_group_size,
        "batch_size": batch_size,
        "beam_size": beam_size,
        "nbits_per_codebook": nbits_per_codebook,
        "codebook_values_nbits": codebook_values_nbits,
        "scale_nbits": scale_nbits,
        "beam_search_epochs": beam_search_epochs,
        "init_max_iter": init_max_iter,
        "entropy_regularizer": entropy_regularizer,
        # controls whether early beam-search rounds are regularized; previously not tracked
        "entropy_warmup_epochs": entropy_warmup_epochs,
    },
)

[34m[1mwandb[0m: Currently logged in as: [33mjustheuristic[0m ([33mrock-and-roll[0m). Use [1m`wandb login --relogin`[0m to force relogin




In [3]:
# Load the reference weight of layer 10's q_proj, then accumulate the input second-moment
# matrix XTX = X^T X / num_rows (in float64 for accuracy, stored as float32) on `device`.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf", torch_dtype='auto', low_cpu_mem_usage=True)
reference_weight = model.model.layers[10].self_attn.q_proj.weight.detach().to(device).float()
# Only one layer's weight is needed: release the full model before loading the activations
# instead of keeping it alive for the rest of the session.
del model

# NOTE: torch.load unpickles the file — only load activations from a trusted location.
X = torch.load(input_loading_dir, map_location='cpu').float().flatten(0, -2)

XTX = torch.zeros(X.shape[-1], X.shape[-1], device=device, dtype=torch.float64)
for i in range(0, len(X), batch_size):
    # use `device` (not .cuda()) so this also works on the CPU fallback and matches XTX's device
    x_batch = X[i: i + batch_size].to(device).double()
    XTX.addmm_(x_batch.T, x_batch, alpha=1/len(X))
    del x_batch
XTX = XTX.float()
del X

Downloading shards:   0%|          | 0/15 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

In [4]:
def _calculate_code_frequencies(codes: torch.LongTensor):
    code_counts = torch.zeros(num_codebooks, 2**nbits_per_codebook, dtype=torch.int64, device=codes.device)
    for codebook_index in range(num_codebooks):
        code_counts[codebook_index, :] = torch.bincount(
            codes[..., codebook_index].flatten(), minlength=2**nbits_per_codebook)
    return code_counts.float() / code_counts.sum(-1, keepdim=True)

def _calculate_code_entropy(codes: torch.LongTensor):
    """Calculate per-codebook code entropy measured in bits (base-2)"""
    probs = _calculate_code_frequencies(codes)
    logprobs = torch.log2(probs.clamp_min(1e-12))
    return - torch.sum(probs * logprobs, dim=-1)

def _get_entropy_penalties_upper_bound(codes: torch.LongTensor, regularizer: float):
    """Compute log-probability penalties that minimize a linearized upper bound on entropy """
    probs = _calculate_code_frequencies(codes)
    logprobs = torch.log2(probs.clamp_min(1e-12))
    return (- regularizer / logprobs.shape[-1]) * logprobs


# Debugging: apply the entropy regularizer through beam search and compare per-codebook entropy before/after (cells below).

In [5]:
# Debug copy of the quantizer, used only to probe the entropy regularizer in the next cells.
quantized_weight = QuantizedWeight(
    XTX=XTX,
    reference_weight=reference_weight,
    num_codebooks=num_codebooks,
    nbits_per_codebook=nbits_per_codebook,
    out_group_size=out_group_size,
    in_group_size=in_group_size,
    scale_nbits=scale_nbits,
    max_iter=init_max_iter,  # capped for a faster init; effect on quality not tested
    verbose=True,
)

initializing with kmeans:   0%|          | 0/4 [00:00<?, ?it/s]

  codebook_i, _, _ = fit_kmeans(


In [6]:
# Baseline: per-codebook entropy (bits) of the codes right after initialization.
print("Entropy before:", _calculate_code_entropy(codes=quantized_weight.codes))

Entropy before: tensor([7.9989, 7.9997, 7.9943, 7.9942], device='cuda:0')


In [7]:
# Debug pass 1: one beam-search update of the codes using entropy-derived code penalties
# (regularizer=0.01), then report how per-codebook entropy changed.
code_penalties = _get_entropy_penalties_upper_bound(
    quantized_weight.codes,
    regularizer=0.01,
)
quantized_weight.beam_search_update_codes_(
    XTX,
    reference_weight,
    beam_size=beam_size,
    code_penalties=code_penalties,
    dim_rng=random.Random(),
    verbose=True,
)
print("Entropy after:", _calculate_code_entropy(quantized_weight.codes))

  0%|          | 0/2048 [00:00<?, ?it/s]

Entropy after: tensor([7.9732, 7.8986, 7.7041, 7.2989], device='cuda:0')


In [8]:
# Debug pass 2: identical update, but the penalties are recomputed from the codes that pass 1
# produced, so the two passes compound (compare the two entropy printouts).
code_penalties = _get_entropy_penalties_upper_bound(codes=quantized_weight.codes, regularizer=0.01)
quantized_weight.beam_search_update_codes_(
    XTX, reference_weight,
    verbose=True, beam_size=beam_size,
    dim_rng=random.Random(), code_penalties=code_penalties,
)
print("Entropy after:", _calculate_code_entropy(codes=quantized_weight.codes))

  0%|          | 0/2048 [00:00<?, ?it/s]

Entropy after: tensor([5.2803, 2.5830, 2.0586, 0.4563], device='cuda:0')


In [9]:
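# Penalized objective for the code update:
#   MSE + sum_c count(c) * code_penalty[c],   where count(c) = number of positions whose code equals c.

# Let code_penalty[c] = -log p(c), where p(c) = count(c) / N is the empirical frequency of code c
# and N = sum_c count(c) is the total number of codes. Then
#   sum_c count(c) * code_penalty[c] = -N * sum_c p(c) * log p(c) = N * H(codes),

# so the objective is MSE + const * entropy (const = N).
# If p is taken from the codes *before* the update (as in the cells above), the penalty term is
# N * cross_entropy(p_new, p_old) >= N * H(p_new), i.e. an upper bound on the entropy term.
# (_get_entropy_penalties_upper_bound uses log2 and scales the penalties by regularizer / codebook_size.)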

# Main calibration code

In [None]:
# Main calibration loop. Each epoch takes one Adam step on the trainable parameters of
# `quantized_weight`; every `beam_search_epochs` epochs the discrete codes are re-optimized by
# beam search with entropy-derived code penalties (zeroed during the warmup epochs).
num_epochs = 1000
learning_rate = 1e-4

quantized_weight = QuantizedWeight(
    XTX=XTX, reference_weight=reference_weight, num_codebooks=num_codebooks,
    nbits_per_codebook=nbits_per_codebook, scale_nbits=scale_nbits,
    out_group_size=out_group_size, in_group_size=in_group_size,
    verbose=True, max_iter=init_max_iter,   # faster init, not tested
)
avg_bits = quantized_weight.estimate_nbits_per_parameter()
# Log at an explicit step=0: a step-less run.log() advances wandb's internal step counter to 1,
# after which the epoch-0 run.log(..., step=0) below would be rejected as non-monotonic.
run.log({"Avg_bits": avg_bits}, step=0)
print("AVG bits:", avg_bits)
opt = torch.optim.Adam(quantized_weight.parameters(), lr=learning_rate, betas=(0.0, 0.95), amsgrad=True)

# XTX does not change during calibration: build the float64 copy once instead of every epoch.
XTX_double = XTX.double()

for epoch in range(num_epochs):
    start = time.perf_counter()
    delta_weight = (quantized_weight() - reference_weight).double()
    # tr(delta_weight @ XTX @ delta_weight^T) / num_output_rows, computed in float64
    loss = (delta_weight @ XTX_double).flatten() @ delta_weight.flatten() / len(delta_weight)
    opt.zero_grad()
    loss.backward()
    opt.step()

    loss_value = loss.item()
    run.log({'loss': loss_value}, step=epoch)

    if epoch % print_frequency == 0:
        print(f"loss={loss_value:.10f}\t",
              f"time_on_epoch {epoch} = {time.perf_counter() - start}")
    if (epoch + 1) % beam_search_epochs == 0:
        print("Entropy before beam search:", _calculate_code_entropy(quantized_weight.codes))
        code_penalties = _get_entropy_penalties_upper_bound(quantized_weight.codes, regularizer=entropy_regularizer)
        if epoch < entropy_warmup_epochs:
            # hypothesis: this could help last codebooks become more meaningful before we prune them via entropy regularizer
            code_penalties *= 0
            print("Not regularizing for epoch", epoch)
        quantized_weight.beam_search_update_codes_(
            XTX, reference_weight, beam_size=beam_size, code_penalties=code_penalties,
            dim_rng=random.Random(), verbose=True)
        entropy_after = _calculate_code_entropy(quantized_weight.codes)
        print("Entropy after beam search:", entropy_after)
        # track how compressible the codes are over training (replaces the dead commented-out logging)
        run.log({'Mean code entropy (bits)': entropy_after.mean().item()}, step=epoch)

initializing with kmeans:   0%|          | 0/4 [00:00<?, ?it/s]

AVG bits: 2.005859375
loss=0.0321929083	 time_on_epoch 0 = 0.30653247493319213
loss=0.0159933244	 time_on_epoch 10 = 0.13437715300824493
loss=0.0127318873	 time_on_epoch 20 = 0.13457738293800503
loss=0.0116782057	 time_on_epoch 30 = 0.134450733079575
loss=0.0111255542	 time_on_epoch 40 = 0.13483770296443254
loss=0.0107775138	 time_on_epoch 50 = 0.1345619229832664
loss=0.0105360060	 time_on_epoch 60 = 0.13469453295692801
loss=0.0103587484	 time_on_epoch 70 = 0.1347333239391446
loss=0.0102224418	 time_on_epoch 80 = 0.13468431401997805
loss=0.0101137650	 time_on_epoch 90 = 0.13480222295038402
Entropy before beam search: tensor([7.9987, 7.9996, 7.9997, 7.9941], device='cuda:0')


  0%|          | 0/2048 [00:00<?, ?it/s]

Entropy after beam search: tensor([7.9985, 7.9988, 7.9960, 7.9831], device='cuda:0')
loss=0.0049323175	 time_on_epoch 100 = 0.134709634934552
loss=0.0049000231	 time_on_epoch 110 = 0.13448562601115555
loss=0.0048899715	 time_on_epoch 120 = 0.1345701760146767
loss=0.0048844423	 time_on_epoch 130 = 0.13458515598904341
loss=0.0048811730	 time_on_epoch 140 = 0.1345395160606131
loss=0.0048786117	 time_on_epoch 150 = 0.13452253595460206
loss=0.0048758113	 time_on_epoch 160 = 0.13480562600307167
loss=0.0048737549	 time_on_epoch 170 = 0.13468538597226143
loss=0.0048721806	 time_on_epoch 180 = 0.13461041601840407
loss=0.0048709052	 time_on_epoch 190 = 0.13450881594326347
Entropy before beam search: tensor([7.9985, 7.9988, 7.9960, 7.9831], device='cuda:0')


  0%|          | 0/2048 [00:00<?, ?it/s]

Entropy after beam search: tensor([7.9982, 7.9963, 7.9506, 7.5419], device='cuda:0')
loss=0.0038284234	 time_on_epoch 200 = 0.13733786495868117
loss=0.0038126490	 time_on_epoch 210 = 0.13716789591126144
loss=0.0038079356	 time_on_epoch 220 = 0.13720458594616503
loss=0.0038053051	 time_on_epoch 230 = 0.1372052360093221
loss=0.0038036157	 time_on_epoch 240 = 0.1372614960419014
loss=0.0038024541	 time_on_epoch 250 = 0.13720554602332413
loss=0.0038016361	 time_on_epoch 260 = 0.13716848602052778
loss=0.0038010355	 time_on_epoch 270 = 0.1372344959527254
loss=0.0038005370	 time_on_epoch 280 = 0.13724346598610282
loss=0.0038000721	 time_on_epoch 290 = 0.13733416702598333
Entropy before beam search: tensor([7.9982, 7.9963, 7.9506, 7.5419], device='cuda:0')


  0%|          | 0/2048 [00:00<?, ?it/s]

Entropy after beam search: tensor([7.9961, 7.9541, 6.9827, 3.4024], device='cuda:0')
loss=0.0051695032	 time_on_epoch 300 = 0.16384631302207708
loss=0.0050173219	 time_on_epoch 310 = 0.16333806398324668
loss=0.0050003570	 time_on_epoch 320 = 0.16364737402182072
loss=0.0049963343	 time_on_epoch 330 = 0.1634368139784783
loss=0.0049941654	 time_on_epoch 340 = 0.16428753302898258
loss=0.0049926795	 time_on_epoch 350 = 0.16437282401602715
loss=0.0049915527	 time_on_epoch 360 = 0.16335051401983947
loss=0.0049906475	 time_on_epoch 370 = 0.16335509507916868
loss=0.0049898923	 time_on_epoch 380 = 0.16339253494516015
loss=0.0049892443	 time_on_epoch 390 = 0.16330697503872216
Entropy before beam search: tensor([7.9961, 7.9541, 6.9827, 3.4024], device='cuda:0')


  0%|          | 0/2048 [00:00<?, ?it/s]

Entropy after beam search: tensor([7.9856, 7.4455, 3.5256, 1.0036], device='cuda:0')
loss=0.0076876331	 time_on_epoch 400 = 0.17860544798895717
loss=0.0075240223	 time_on_epoch 410 = 0.1784091090084985
loss=0.0075106094	 time_on_epoch 420 = 0.17836942803114653
loss=0.0075060319	 time_on_epoch 430 = 0.17836837901268154
loss=0.0075032820	 time_on_epoch 440 = 0.17839083902072161
loss=0.0075013630	 time_on_epoch 450 = 0.17844732908997685
loss=0.0074999108	 time_on_epoch 460 = 0.17835793900303543
loss=0.0074987534	 time_on_epoch 470 = 0.17834684904664755
loss=0.0074977965	 time_on_epoch 480 = 0.17841326992493123
loss=0.0074969833	 time_on_epoch 490 = 0.17838564002886415
Entropy before beam search: tensor([7.9856, 7.4455, 3.5256, 1.0036], device='cuda:0')


  0%|          | 0/2048 [00:00<?, ?it/s]

Entropy after beam search: tensor([7.9439, 5.7458, 1.9318, 0.6045], device='cuda:0')
loss=0.0094117466	 time_on_epoch 500 = 0.18043602001853287
loss=0.0092709140	 time_on_epoch 510 = 0.18020800000522286
loss=0.0092241957	 time_on_epoch 520 = 0.18038407107815146
loss=0.0092177592	 time_on_epoch 530 = 0.18025261105503887
loss=0.0092142588	 time_on_epoch 540 = 0.18031079100910574
loss=0.0092117680	 time_on_epoch 550 = 0.18030279001686722
loss=0.0092098558	 time_on_epoch 560 = 0.18029104091692716
loss=0.0092083175	 time_on_epoch 570 = 0.18037664098665118
loss=0.0092070379	 time_on_epoch 580 = 0.1802686209557578
loss=0.0092059457	 time_on_epoch 590 = 0.18062070093583316
Entropy before beam search: tensor([7.9439, 5.7458, 1.9318, 0.6045], device='cuda:0')


  0%|          | 0/2048 [00:00<?, ?it/s]

Entropy after beam search: tensor([7.8416, 4.2590, 1.3671, 0.5135], device='cuda:0')
loss=0.0105941560	 time_on_epoch 600 = 0.18127939198166132
loss=0.0104853044	 time_on_epoch 610 = 0.18106419208925217
loss=0.0104659060	 time_on_epoch 620 = 0.18145907192956656
loss=0.0104495522	 time_on_epoch 630 = 0.1815927519928664
loss=0.0104453061	 time_on_epoch 640 = 0.18115997198037803
loss=0.0104424215	 time_on_epoch 650 = 0.18114822206553072
loss=0.0104402021	 time_on_epoch 660 = 0.181067563011311
loss=0.0104384160	 time_on_epoch 670 = 0.18107824202161282
loss=0.0104369319	 time_on_epoch 680 = 0.18131924199406058
