In [1]:
# Notebook-level environment setup. These are IPython magics, so explanatory
# comments must stay on their own lines (anything after the `=` becomes the value).

# Expose only GPU index 3 to CUDA in this kernel.
%env CUDA_VISIBLE_DEVICES=3
# Point the Hugging Face caches at the shared model hub directory.
%env TRANSFORMERS_CACHE=/mnt/LLM/hub
%env HF_HOME=/mnt/LLM/hub
# Cap OpenMP worker threads.
%env OMP_NUM_THREADS=16
# NOTE(review): presumably read by the project's `src.aq` code to disable JIT compilation -- confirm.
%env AQ_USE_JIT=0
# Re-import edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.insert(0, '..')

import time
import random
from tqdm.auto import trange
import ipynbname  # pip install ipynbname

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from src.aq import QuantizedWeight


# Choose the compute device: CUDA when a GPU is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Turn TF32 off for both cuBLAS matmuls and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# Limit torch's intra-op CPU thread pool.
torch.set_num_threads(16)

# --- Input data -------------------------------------------------------------
# Saved input activations for layer 10's self_attn.q_proj (borrowed from a colleague's run).
# NOTE(review): despite the name, this points at a single .pt file, not a directory.
input_loading_dir = '/extra_disk_1/vahe1994/BRRR/layer10.self_attn.q_proj.input_activation.pt'

# --- Quantization format ----------------------------------------------------
num_codebooks = 1
nbits_per_codebook = 8
in_group_size = 8
out_group_size = 1
scale_nbits = 0  # 0 means no scales, 16 means no compression
codebook_values_nbits = 16  # less than 16 means we quantize codebooks as well
symmetric = True

# --- Initialization and optimization schedule -------------------------------
init_max_iter = 100
batch_size = 16384
beam_size = 1
beam_search_epochs = 100
print_frequency = 10

env: CUDA_VISIBLE_DEVICES=3
env: TRANSFORMERS_CACHE=/mnt/LLM/hub
env: HF_HOME=/mnt/LLM/hub
env: OMP_NUM_THREADS=16
env: AQ_USE_JIT=0




In [2]:
import wandb

# Resolve the notebook's name once; it is used both for wandb's code tracking
# and as the prefix of the run name.
notebook_name = ipynbname.name()
os.environ["WANDB_NOTEBOOK_NAME"] = os.path.join(os.getcwd(), notebook_name + ".ipynb")

# Start a new wandb run to track this experiment.
run = wandb.init(
    dir=os.getcwd(),
    # wandb project / team the run is logged under
    project="AddQuantization",
    entity="rock-and-roll",
    save_code=True,
    name=f"{notebook_name}_AQ_{num_codebooks=}_{out_group_size=}_{in_group_size=}_{nbits_per_codebook=}_{beam_search_epochs=}",
    settings=wandb.Settings(code_dir="."),
    # Hyperparameters and run metadata. Every setting that is passed to the
    # quantizer or the training loop should be recorded here so runs are comparable.
    config={
        "num_codebooks": num_codebooks,
        "out_group_size": out_group_size,
        "in_group_size": in_group_size,
        "group_size": out_group_size * in_group_size,
        "batch_size": batch_size,
        "beam_size": beam_size,
        "nbits_per_codebook": nbits_per_codebook,
        "codebook_values_nbits": codebook_values_nbits,
        "scale_nbits": scale_nbits,
        "beam_search_epochs": beam_search_epochs,
        "init_max_iter": init_max_iter,
        # Previously missing: `symmetric` is forwarded to QuantizedWeight and
        # therefore changes the result, so it must be tracked as well.
        "symmetric": symmetric,
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mjustheuristic[0m ([33mrock-and-roll[0m). Use [1m`wandb login --relogin`[0m to force relogin




In [3]:
# Load the pretrained checkpoint; only one projection weight is read from it below.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf", torch_dtype='auto', low_cpu_mem_usage=True)

# Load the saved input activations onto the CPU and collapse all leading
# dimensions so that X is 2-D: (num_samples, in_features).
# NOTE(review): torch.load unpickles the file -- only load dumps from a trusted source.
X = torch.load(input_loading_dir, map_location='cpu').float().flatten(0, -2)

# Full-precision copy of the weight that will be quantized, placed on the compute device.
reference_weight = model.model.layers[10].self_attn.q_proj.weight.detach().to(device).float()

# Accumulate XTX = (X^T X) / num_samples in float64, one batch at a time, so
# the whole activation matrix never has to be on the device at once.
XTX = torch.zeros(X.shape[-1], X.shape[-1], device=device, dtype=torch.float64)
for i in range(0, len(X), batch_size):
    # Move the batch to `device` (not unconditionally to CUDA) so it always
    # lives on the same device as XTX, including the CPU fallback.
    x_batch = X[i: i + batch_size].to(device).double()
    XTX.addmm_(x_batch.T, x_batch, alpha=1/len(X))
    del x_batch
XTX = XTX.float()
del X  # the raw activations are no longer needed once XTX is built

Downloading shards:   0%|          | 0/15 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

In [4]:
# Build the quantized representation of `reference_weight` from the config-cell settings.
# NOTE(review): `codebook_values_nbits` is defined and logged to wandb but is not
# passed here -- confirm whether QuantizedWeight is supposed to receive it.
quantized_weight = QuantizedWeight(
    reference_weight=reference_weight,
    XTX=XTX,
    num_codebooks=num_codebooks,
    nbits_per_codebook=nbits_per_codebook,
    in_group_size=in_group_size,
    out_group_size=out_group_size,
    scale_nbits=scale_nbits,
    symmetric=symmetric,
    max_iter=init_max_iter,
    verbose=True,
)

initializing with kmeans:   0%|          | 0/1 [00:00<?, ?it/s]

  clusters[0]


In [None]:
# Estimate the storage cost once and report it.
# It is logged with an explicit step=0: a step-less run.log() would commit
# step 0 and advance wandb's internal counter, after which the epoch-0 loss
# (also logged at step=0) would be dropped as a non-monotonic step.
avg_bits = quantized_weight.estimate_nbits_per_parameter()
run.log({"Avg_bits": avg_bits}, step=0)
print("AVG bits:", avg_bits)
opt = torch.optim.Adam(quantized_weight.parameters(), lr=1e-4, betas=(0.0, 0.95), amsgrad=True)

# XTX does not change during training, so convert it to float64 once instead
# of on every epoch.
XTX_double = XTX.double()

for epoch in range(1000):
    start = time.perf_counter()

    # Loss is sum(delta_weight @ XTX * delta_weight) divided by the number of
    # rows of delta_weight, evaluated in float64.
    delta_weight = (quantized_weight() - reference_weight).double()
    loss = (delta_weight @ XTX_double).flatten() @ delta_weight.flatten() / len(delta_weight)
    opt.zero_grad()
    loss.backward()
    opt.step()

    # Read the scalar back once and reuse it for both logging and printing.
    loss_value = loss.item()
    run.log({'loss': loss_value}, step=epoch)

    if epoch % print_frequency == 0:
        print(f"loss={loss_value:.10f}\t",
              f"time_on_epoch {epoch} = {time.perf_counter() - start}")
    # After every `beam_search_epochs` optimizer steps, update the codes via beam search.
    if (epoch + 1) % beam_search_epochs == 0:
        quantized_weight.beam_search_update_codes_(
            XTX, reference_weight, beam_size=beam_size, code_penalties=None,
            dim_rng=random.Random(), verbose=True)
        


AVG bits: 2.00244140625
loss=0.0218783334	 time_on_epoch 0 = 0.1448188559152186
loss=0.0155558518	 time_on_epoch 10 = 0.13680543913505971
loss=0.0132976436	 time_on_epoch 20 = 0.13702892884612083
loss=0.0124233356	 time_on_epoch 30 = 0.1369958990253508
loss=0.0119796950	 time_on_epoch 40 = 0.1369604878127575
loss=0.0117023647	 time_on_epoch 50 = 0.1370939170010388
loss=0.0115072961	 time_on_epoch 60 = 0.1370458968449384




loss=0.0113597860	 time_on_epoch 70 = 0.13706981693394482
loss=0.0112423442	 time_on_epoch 80 = 0.13707614596933126
loss=0.0111450235	 time_on_epoch 90 = 0.13706385693512857


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0042712641	 time_on_epoch 100 = 0.1376357520930469
loss=0.0042563154	 time_on_epoch 110 = 0.13718988094478846
loss=0.0042491754	 time_on_epoch 120 = 0.1372395809739828
loss=0.0042444595	 time_on_epoch 130 = 0.1378013410139829
loss=0.0042409266	 time_on_epoch 140 = 0.13743733102455735
loss=0.0042380960	 time_on_epoch 150 = 0.13724016002379358
loss=0.0042357343	 time_on_epoch 160 = 0.13795034098438919
loss=0.0042337104	 time_on_epoch 170 = 0.13740576012060046
loss=0.0042319429	 time_on_epoch 180 = 0.13723252899944782
loss=0.0042303771	 time_on_epoch 190 = 0.1370161089580506


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0030785928	 time_on_epoch 200 = 0.13724402501247823
loss=0.0030707106	 time_on_epoch 210 = 0.137341083958745
loss=0.0030684524	 time_on_epoch 220 = 0.13695154408924282
loss=0.0030672045	 time_on_epoch 230 = 0.13692927290685475
loss=0.0030663771	 time_on_epoch 240 = 0.13695572409778833
loss=0.0030657751	 time_on_epoch 250 = 0.13697975385002792
loss=0.0030653117	 time_on_epoch 260 = 0.13693975191563368
loss=0.0030649412	 time_on_epoch 270 = 0.13702199282124639
loss=0.0030646363	 time_on_epoch 280 = 0.13690252206288278
