In [17]:
%env CUDA_VISIBLE_DEVICES=3
%env TRANSFORMERS_CACHE=/mnt/LLM/hub
%env OMP_NUM_THREADS=16

import os
import sys
sys.path.insert(0, '..')

import time
import random
from tqdm.auto import trange
import ipynbname  # pip install ipynbname

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from src.aq import QuantizedWeight
import numpy as np
import torch
import faiss


torch.set_num_threads(min(16, 16))
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_loading_dir = '/extra_disk_1/vahe1994/BRRR/layer10.self_attn.q_proj.input_activation.pt'
num_codebooks = 2
nbits_per_codebook = 8
out_group_size = 1
in_group_size = 8
batch_size = 16384
beam_size = 1
beam_search_epochs = 100
print_frequency = 10
scale_nbits = 0    # 0 means no scales, 16 means no compression;
codebook_values_nbits = 16  # less than 16 means we quantize codebooks as well
init_max_iter = 100

env: CUDA_VISIBLE_DEVICES=3
env: TRANSFORMERS_CACHE=/mnt/LLM/hub
env: OMP_NUM_THREADS=16


In [18]:
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf", torch_dtype='auto', low_cpu_mem_usage=True)

X = torch.load(input_loading_dir, map_location='cpu').float().flatten(0, -2)
reference_weight = model.model.layers[10].self_attn.q_proj.weight.detach().to(device).float()

XTX = torch.zeros(X.shape[-1], X.shape[-1], device=device, dtype=torch.float64)
for i in range(0, len(X), batch_size):
    x_batch = X[i: i + batch_size].cuda().double()
    XTX.addmm_(x_batch.T, x_batch, alpha=1/len(X))
    del x_batch
XTX = XTX.float()
del X

Downloading shards:   0%|          | 0/15 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

In [19]:
quantized_weight = QuantizedWeight(
    reference_weight=reference_weight, num_codebooks=num_codebooks,
    nbits_per_codebook=nbits_per_codebook, scale_nbits=scale_nbits, 
    out_group_size=out_group_size, in_group_size=in_group_size,
    verbose=True, max_iter=init_max_iter,   # faster init, not tested
)


initializing with kmeans:   0%|          | 0/2 [00:00<?, ?it/s]

## Entropy penalty upper bounds (fixed!)

In [20]:
from typing import Optional
def _calculate_code_frequencies(codes: torch.LongTensor):
    code_counts = torch.zeros(num_codebooks, 2**nbits_per_codebook, dtype=torch.int64, device=codes.device)
    for codebook_index in range(num_codebooks):
        code_counts[codebook_index, :] = torch.bincount(
            codes[..., codebook_index].flatten(), minlength=2**nbits_per_codebook)
    return code_counts.float() / code_counts.sum(-1, keepdim=True)

def _calculate_code_entropy(codes: torch.LongTensor, eps: float = 1e-20):
    """Calculate per-codebook code entropy measured in bits (base-2)"""
    probs = _calculate_code_frequencies(codes)
    logprobs = torch.log2(probs.clamp_min(eps))
    return - torch.sum(probs * logprobs, dim=-1)

import huffman
def _get_huffman_penalties_upper_bound(codes: torch.LongTensor, regularizer: float):
    """Compute log-probability penalties that minimize a linearized upper bound on entropy """
    penalties = torch.empty(num_codebooks, 2 ** nbits_per_codebook, device=codes.device, dtype=torch.float32)
    probs = _calculate_code_frequencies(codes)
    
    for codebook_index in range(num_codebooks):
        num_codes = torch.as_tensor(codes[..., codebook_index].numel(), device=probs.device)
        missing_value_length = torch.log2(num_codes).item()
        
        huffman_codes = huffman.codebook([(i, probs[codebook_index, i].item()) for i in range(2 ** nbits_per_codebook)])
        code_lengths = torch.as_tensor([
            len(huffman_codes.get(i, missing_value_length)) for i in range(2 ** nbits_per_codebook)],
            device=probs.device, dtype=torch.float32)
        penalties[codebook_index] = (regularizer / probs.shape[-1]) * code_lengths
    return penalties

def _get_entropy_penalties_upper_bound(codes: torch.LongTensor, regularizer: float, eps: Optional[float] = None):
    """Compute log-probability penalties that minimize a linearized upper bound on entropy """
    probs = _calculate_code_frequencies(codes)
    num_codes = torch.as_tensor(codes[..., 0].numel(), device=probs.device)
    if eps is None:
        eps = 1. / num_codes
    logprobs = torch.log2(probs.clamp_min(eps))
    return (- regularizer / probs.shape[-1]) * logprobs 

In [21]:
_calculate_code_entropy(quantized_weight.get_codes())

tensor([7.9844, 7.9880], device='cuda:0')

In [22]:
penalties = _get_entropy_penalties_upper_bound(quantized_weight.get_codes(), regularizer=1.0)
penalties[0, quantized_weight.get_codes()[0]].sum()

tensor(64.0238, device='cuda:0')

In [23]:
penalties = _get_huffman_penalties_upper_bound(quantized_weight.get_codes(), regularizer=1.0)
penalties[0, quantized_weight.get_codes()[0]].sum()

tensor(64.0391, device='cuda:0')

### Copy of beam_search_l2 with entropy penalties introduced

In [24]:
"""Beam search that minimizes ||Wref - Wq||^2 w.r.t. Wq"""
import math
import random
import time
from typing import List, Optional

import torch
import torch.nn.functional as F

from src.utils import _dequantize_weight, maybe_script


@torch.inference_mode
def beam_search_optimal_codes(
    reference_weight: torch.Tensor,
    codebooks: torch.Tensor,
    prev_codes: torch.Tensor,
    scales: Optional[torch.Tensor],
    beam_size: int,
    stochastic_rounding_tau: float = 0.0,
    chunk_size_bytes: int = 2**32,
    penalties: Optional[torch.Tensor] = None,
    penalty_weights: Optional[torch.Tensor] = None,
    dim_rng: Optional[random.Random] = None,
    force_update: bool = False,
    max_update_fraction: float = 1.0,
    code_selection_temperature: float = 0,
    trust_ratio: Optional[float] = None,
) -> torch.Tensor:
    """
    Update codes using beam search to minimize L2 error in code values (regardless of activations)
    :param reference_weight: a target for L2 error, [out_features, in_features]
    :param codebooks: look-up tables of codes, shape: [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param prev_codes: previous-best integer weight codes, shape: [num_output_groups, num_input_groups, num_codebooks]
    :param scales: weight will be multiplied by this factor, shape = [num_output_groups, num_input_groups or 1, 1, 1]
    :param dim_rng: a source of randomness to (optionally) shuffle the order in which the beam search runs
      None = update dimensions and codebooks in their natural order (0, 1, ..., n)
      random.Random(optional_seed) = shuffle dimensions at random, optionally using the specified seed

    :param beam_size: consider up to this many best encoding combinations
    :param stochastic_rounding_tau: if positive, each time the algorithm chooses a code, it will have a probability
        of replacing it with the second-best choice. If the two best codes increase the error by delta1 and delta2,
        then the probability of choosing each code is P_i = delta_i ^ -1/tau / (sum_j_in_choices delta_j ^ -1/tau).
        Note that if there is a code that has zero error, the algorithm will choose allways choose such a code
    :param chunk_size_bytes: process this many candidates at a time; reduce to save memory
    :param penalties: additional penalty to loss for picking a given code, [num_codebooks, codebook_size]
    :param penalty_weigths: multiplicative scale to penalties (same shape as codes) so that the final objective is 
        ||reference_weight - dequantize(codes)||^2 + sum_[i,j,c] penalties[codes[i, j, c]] * penalty_weights[i,j,c]
    
    :param force_update: if True, the algorithm will force codes to change even if code is optimal in terms
     of mean squared error. By default, the algorithm forces *all* weights to update this way, which may change weights
     too much. To limit the numer of updated weights, set max_code_change and trust_ratio.
    :param max_update_fraction: the maximum portion of discrete code groups that *can* be updated;
        By default, all codes can be updated. If < 1, only this portion of all code groups is allowed to update.
        The algorithm selects the codes for update based on the difference between de-quantized and reference_weight.
        If there are multiple codebooks, changing any one code responsible for the group counts as code group changed.
        Note that small max_code_change also speeds up computation since not all codes need beam search.
        If the number of weights do not divide evenly, the algoritm will round the number of updates up.
    :param code_selection_temperature: only used if max_code_change > 1; by default, prioritize updating the codes with
        the largest delta = ||(reference_weight - quantized_weight) * mask_only_weights_that_depend_on_this_code||_2 .
        If temperature > 0, the updated codes are instead *sampled* at random, proportionally to delta^(1/temperature) .
    :param trust_ratio: if not None, the algorithm only admits code changes as long as they do not change too much.
        Formally, ||new_quantized_weight - prev_quantized_weight|| / ||prev_quantized_weight|| <= trust_ratio
        If this is not true, the algorithm will reset some of the new quantized weights to their old values until the
        constraint becomes satisfied. The algorithm still prioritizes changes to weights with largest delta (see above).
        If code_change_temperature > 0, the algorithm instead samples which weights to change with the same probability.
        The algorithm will always allow changing exactly *one* code in excess of trust ratio to ensure that at least
        one weight is updated. If both this and max_code_change is set, both these constraints are enforced.
    :return: the best quantization codes found within constraints, same shape as prev_codes

    """
    assert 0 < max_update_fraction <= 1 and (trust_ratio is None or trust_ratio > 0)
    # reshape references, codes and codebooks so they are no longer group-wise
    num_output_groups, num_input_groups, num_codebooks = prev_codes.shape
    _num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
    
    if penalty_weights is not None:
        assert penalties is not None
        assert penalty_weights.shape == codes.shape  # (num_output_groups, num_input_groups, num_codebooks)
    if penalties is not None:
        assert penalties.shape == codebooks.shape[:2]  # (num_codebooks, codebook_size)
        if penalty_weights is None:
            penalty_weights = torch.ones_like(prev_codes, dtype=codebooks.dtype)
        penalty_weights = penalty_weights / scales.square().squeeze(-1)
        # why: in the inner beam search, we minimize un-scaled MSE; dividing penalty weights by scales^2 is ...
        # ... equivalent to optimizing scaled MSE; weights with larger scales are more important.
        
        penalty_weights = penalty_weights.flatten(0, 1)  # [num_groups, num_codebooks]

    flat_unscaled_reference = reference_weight.reshape(
        num_output_groups, out_group_size, num_input_groups, in_group_size
    ).permute(
        0, 2, 1, 3
    )  # [num_output_groups, num_input_groups, out_group_size, in_group_size]
    if scales is not None:
        flat_unscaled_reference = flat_unscaled_reference / scales
        # divide by scales; the resulting problem is equivalent to multiplying dequantized weight
    flat_unscaled_reference = flat_unscaled_reference.flatten(2, 3).flatten(0, 1) # [num_output_groups*num_input_groups,out_group_size*in_group_size]
    flat_prev_codes = prev_codes.flatten(0, -2) 
    flat_codebooks = codebooks.flatten(-2, -1).detach()
    dim_order = list(range(num_codebooks))
    if dim_rng is not None:
        dim_rng.shuffle(dim_order)

    def _update_flat_codes(_flat_reference, _flat_codes):
        """update _flat_codes [num_groups, num_codebooks] to approximate _flat_reference [num_groups, group_size]"""
        if num_codebooks == 1 and beam_size == 1 and stochastic_rounding_tau == 0 and not force_update:
            # a faster algorithm for a special case of one codebook
            return _greedy_find_best_codes(
                reference=_flat_reference,
                codebook=flat_codebooks[0],
                chunk_size_values=chunk_size_bytes // _flat_reference[0, 0].nbytes,
                code_dtype=prev_codes.dtype,
                penalties=penalties,
                penalty_weights=penalty_weights,
            )
        else:
            return _beam_search_update_codes_groupwise(
                reference=_flat_reference,
                codebooks=flat_codebooks,
                codes=_flat_codes,
                beam_size=beam_size,
                stochastic_rounding_tau=stochastic_rounding_tau,
                force_update=force_update,
                chunk_size_values=chunk_size_bytes // _flat_reference[0, 0].nbytes,
                dim_order=dim_order,
                penalties=penalties,
                penalty_weights=penalty_weights,
            )

    def _groupwise_squared_norms(delta: torch.Tensor):
        """
        Given a matrix delta [out_features, in_features], compute a tensor [num_output_groups, num_input_groups] that
        contains the squared sum of elements of delta from each tile of (out_group_size, in_group_size) values.
        """
        return (
            delta.view(delta.shape[0] // out_group_size, out_group_size, delta.shape[1] // in_group_size, in_group_size)
            .square()
            .sum(dim=(1, 3))
        )

    flat_indices_to_update = prev_dequantized_weight = None
    if max_update_fraction < 1 or trust_ratio is not None:
        # precompute ordered code indices to be used for constraints on the number of updates
        prev_dequantized_weight = _dequantize_weight(prev_codes, codebooks, scales)
        num_codes_to_update = int(math.ceil(max_update_fraction * num_output_groups * num_input_groups))
        difference_with_reference_squared_norms = _groupwise_squared_norms(reference_weight - prev_dequantized_weight)
        # ^-- [num_output_groups, num_input_groups]
        if code_selection_temperature > 0:
            flat_indices_to_update = torch.pow(
                difference_with_reference_squared_norms.flatten(),
                0.5 / code_selection_temperature,
                # note: temperature is multuplied by 0.5 because sampling is proportional to norms without square
            ).multinomial(num_samples=num_codes_to_update, replacement=False)
        else:
            flat_indices_to_update = torch.topk(
                difference_with_reference_squared_norms.flatten(), k=num_codes_to_update, largest=True, sorted=True
            ).indices

    if max_update_fraction == 1:
        print(f"{flat_unscaled_reference.shape=}",f"{flat_prev_codes.shape=}",)
        print(f"{penalties.shape=}",f"{penalty_weights.shape=}")
        flat_new_codes = _update_flat_codes(flat_unscaled_reference, flat_prev_codes)
    else:
        # may be penalty_weights -> penalty_weights[flat_indices_to_update]
        penalty_weights = penalty_weights[flat_indices_to_update]
        flat_new_codes = flat_prev_codes.index_put(  # note: this is an out-of-place op that does not modify prev codes
            (flat_indices_to_update[:, None], torch.arange(num_codebooks, device=codebooks.device)[None, :]),
            _update_flat_codes(
                flat_unscaled_reference[flat_indices_to_update], flat_prev_codes[flat_indices_to_update]
            ),
        )

    if trust_ratio is not None:
        assert isinstance(flat_indices_to_update, torch.Tensor) and isinstance(prev_dequantized_weight, torch.Tensor)
        new_dequantized_weight = _dequantize_weight(flat_new_codes.view_as(prev_codes), codebooks, scales)
        weight_change_squared_norms = _groupwise_squared_norms(new_dequantized_weight - prev_dequantized_weight)
        # ^-- shape: [num_output_groups, num_input_groups]

        flat_ordered_weight_change_squared_norms = weight_change_squared_norms.flatten()[flat_indices_to_update]
        flat_ordered_cumulative_norms = flat_ordered_weight_change_squared_norms.cumsum(0).sqrt()
        # [num_codes_to_update]

        num_codes_selected = 1 + torch.searchsorted(
            flat_ordered_cumulative_norms, trust_ratio * prev_dequantized_weight.norm(), side="left"
        )
        truncated_flat_indices_to_update = flat_indices_to_update[:num_codes_selected]  # sorted most to least important
        flat_new_codes = flat_prev_codes.index_put(  # <-- note: this is an out-of-place operation
            (truncated_flat_indices_to_update[:, None], torch.arange(num_codebooks, device=codebooks.device)[None, :]),
            flat_new_codes[truncated_flat_indices_to_update],
        )
    return flat_new_codes.view_as(prev_codes)


@maybe_script
def _beam_search_update_codes_groupwise(
    reference: torch.Tensor,
    codebooks: torch.Tensor,
    codes: torch.Tensor,
    *,
    beam_size: int,
    stochastic_rounding_tau: float,
    chunk_size_values: int,
    dim_order: Optional[List[int]],
    force_update: bool,
    penalties: Optional[torch.Tensor],
    penalty_weights: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    :param reference: [num_groups, group_size]
    :param codes: [num_groups, num_codebooks]
    :param codebooks: [num_codebooks, codebook_size, group_size]
    :returns: [num_groups, num_codebooks]
    """
    assert (penalty_weights is None) == (penalties is None)
    if penalty_weights is not None:
        assert penalties is not None
        assert penalty_weights.shape == (len(reference), codes.shape[-1]), penalty_weights.shape
    if stochastic_rounding_tau > 0:
        assert beam_size >= 2, "with stochastic rounding, we need at least 2 hypotheses to choose from"

    prev_codes = codes
    device = reference.device
    num_groups, group_size = reference.shape
    num_codebooks, codebook_size, group_size = codebooks.shape
    codebook_offsets = torch.arange(0, num_codebooks * codebook_size, codebook_size, device=device)  # [num_codebooks]
    original_dequantized_vectors = F.embedding_bag(
        codes + codebook_offsets, codebooks.flatten(0, 1), mode="sum"
    )  # [num_groups, group_size]
    if dim_order is None:
        dim_order = list(range(num_codebooks))

    code_norms_sq = codebooks.square().sum(-1)  # [num_codebooks, codebook_size]
    beam_codes = codes.clone().unsqueeze(1)  # [num_groups, current_beam_size, num_codebooks]
    residue = (reference - original_dequantized_vectors).view(num_groups, 1, group_size)
    # shape: [num_groups, current_beam_size, group_size]
    direction = residue.clone().view(num_groups, group_size) if force_update else torch.empty(0)

    for i, codebook_index in enumerate(dim_order):
        current_beam_size = residue.shape[1]
        is_last_step = i == len(dim_order) - 1
        # ^-- [num_groups, current_beam_size, group_size]
        residue = residue + F.embedding(beam_codes[..., codebook_index], codebooks[codebook_index, ...])
        if beam_size > 1 or stochastic_rounding_tau > 0:
            residue_norms_sq = residue.square().sum(-1).unsqueeze(-1)  # [num_groups, current beam size, 1]
        else:
            residue_norms_sq = torch.empty(0, device=device)  # when doing greedy search, these are const

        if not is_last_step:
            target_num_candidates = beam_size + int(stochastic_rounding_tau > 0)
        else:
            target_num_candidates = 2 if stochastic_rounding_tau > 0 or force_update else 1

        flat_best_indices = torch.empty(num_groups, target_num_candidates, device=device, dtype=codes.dtype)
        chunk_size_rows = chunk_size_values // (codebook_size * current_beam_size) // 32
        for chunk_start in range(0, num_groups, chunk_size_rows):
            chunk_end = min(chunk_start + chunk_size_rows, num_groups)
            scores = torch.matmul(residue[chunk_start:chunk_end], codebooks[codebook_index].T)
            if beam_size > 1 or stochastic_rounding_tau > 0:
                scores = residue_norms_sq[chunk_start:chunk_end] - 2 * scores + code_norms_sq[codebook_index]
            else:
                scores = -2 * scores + code_norms_sq[codebook_index]  # residue norms are const(j)
            # ^-- [num_groups_chunk, beam_size, codebook_size]
            if penalty_weights is not None:
                assert isinstance(penalties, torch.Tensor)
                assert isinstance(penalty_weights, torch.Tensor)
                scores = scores.add_(
                    penalties[codebook_index, None, None, :] * 
                    penalty_weights[chunk_start: chunk_start + chunk_size_rows, codebook_index, None, None]
                )
        

            flat_best_losses_chunk, flat_best_indices_chunk = torch.topk(
                scores.flatten(1, 2),
                k=target_num_candidates,
                largest=False,
                sorted=is_last_step or beam_size > 1 or stochastic_rounding_tau > 0,
            )  # [num_groups_chunk, target_num_candidates]

            if stochastic_rounding_tau > 0:
                errors = flat_best_losses_chunk.relu().sqrt()  # non-squared errors
                scores = torch.pow(errors / errors.sum(-1, keepdim=True), -1 / stochastic_rounding_tau)
                # ^-- [num_groups_chunk, beam_size + 1]
                keep_prob = scores[:, :-1] / (scores[:, :-1] + scores[:, 1:])  # [num_groups, k_best]
                keep_prob = torch.where(torch.isinf(scores[:, :-1]), 1.0, keep_prob)
                keep = torch.less_equal(torch.rand_like(keep_prob), keep_prob)
                flat_best_indices_chunk = torch.where(
                    keep, flat_best_indices_chunk[:, :-1], flat_best_indices_chunk[:, 1:]
                )

            flat_best_indices[chunk_start:chunk_end] = flat_best_indices_chunk

        arange_num_groups = torch.arange(num_groups, device=device)
        best_hypo_source_ids = flat_best_indices // codebook_size
        best_hypo_codes = flat_best_indices % codebook_size
        beam_codes = beam_codes[arange_num_groups[:, None], best_hypo_source_ids, :]
        beam_codes[:, :, codebook_index] = best_hypo_codes.to(beam_codes.dtype)
        # ^-- [num_groups, beam_size, num_codebooks]

        if not is_last_step:
            residue = residue - F.embedding(beam_codes[..., codebook_index], codebooks[codebook_index, ...])

    if force_update:
        assert beam_codes.shape[1] == 2
        best_codes = beam_codes[:, 0, :]
        second_best_codes = beam_codes[:, 1, :]
        best_code_changed = torch.ne(best_codes, prev_codes).any(dim=-1)
        return torch.where(best_code_changed.unsqueeze(-1), best_codes, second_best_codes)
    else:
        return beam_codes[:, 0, :]


@maybe_script
def _greedy_find_best_codes(
    reference: torch.Tensor, codebook: torch.Tensor,
    penalties: Optional[torch.Tensor],
    penalty_weights: Optional[torch.Tensor],
    chunk_size_values: int, code_dtype: torch.dtype,
) -> torch.Tensor:
    """
    :param reference: [num_groups, group_size]
    :param codebook: [codebook_size, group_size]
    :param penalties: [codebook_size, 1]
    :param penalty_weights: [num_groups, 1]
    :param chunk_size_values: how many values can be materialized in memory simultaneously
    :parma code_dtype the dtype of optimal codes returned by this function
    :returns: codes [num_groups, 1]
    """
    assert (penalty_weights is None) == (penalties is None)
    if penalty_weights is not None:
        assert penalties is not None
        assert penalty_weights.shape == (len(reference), 1)
        penalties = penalties.flatten() # [codebook_size]
        penalty_weights = penalty_weights.flatten() # [num_groups]
    codebook_t = codebook.T.contiguous()
    chunk_size = chunk_size_values // len(codebook)
    codebook_norms_sq = codebook.square().sum(dim=-1)
    new_codes = torch.empty((len(reference),), dtype=code_dtype, device=reference.device)
    for chunk_start in range(0, len(reference), chunk_size):
        scores = torch.addmm(
            codebook_norms_sq[None], reference[chunk_start : chunk_start + chunk_size], codebook_t, alpha=-2
        )  # ||quantized^2|| - 2 * <reference * codebook> + omitted_as_const[ || reference ||^2 ]
        if penalty_weights is not None:
            assert isinstance(penalties, torch.Tensor)
            assert isinstance(penalty_weights, torch.Tensor)
            scores = scores.add_(penalties[None, :] * penalty_weights[chunk_start: chunk_start + chunk_size, None])
        new_codes[chunk_start : chunk_start + chunk_size] = scores.argmin(-1)
    return new_codes.unsqueeze(-1)


def _find_optimal_codebooks(
    reference: torch.Tensor,
    codebooks: torch.Tensor,
    codes: torch.Tensor,
) -> torch.Tensor:
    num_samples = len(reference)
    num_codebooks, codebook_size, group_size = codebooks.shape

    # compute optimal codebooks via linsolve
    codebook_offsets = torch.arange(num_codebooks, device=codes.device) * codebook_size
    code_indicators = torch.sparse_coo_tensor(
        indices=torch.stack(
            [
                torch.arange(num_samples * num_codebooks, device=codes.device) // num_codebooks,
                (codes + codebook_offsets).flatten(),
            ],
            0,
        ),
        values=torch.ones(num_samples * num_codebooks, device=codes.device),
        size=(num_samples, num_codebooks * codebook_size),
    )
    cooc = (code_indicators.T @ code_indicators).coalesce()
    rhs = code_indicators.T @ reference

    try:
        cooc = cooc.to_dense()
        cooc[torch.arange(len(cooc)), torch.arange(len(cooc))].clamp_min_(1.0)
        optimal_codebooks = (torch.linalg.lstsq(cooc, rhs)).solution.reshape(num_codebooks, codebook_size, group_size)
    except Exception as e:
        print(f"Linsolve failed with {e}")
        optimal_codebooks = codebooks
    return optimal_codebooks


## Demo: optimize entropy

* (not tested for >1 codebooks!)
* (not tested with user-defined penalty weights!)

In [34]:
quantized_weight = QuantizedWeight(
    reference_weight=reference_weight, num_codebooks=num_codebooks,
    nbits_per_codebook=nbits_per_codebook, scale_nbits=scale_nbits, 
    out_group_size=out_group_size, in_group_size=in_group_size,
    verbose=True, max_iter=init_max_iter,   # faster init, not tested
)

initializing with kmeans:   0%|          | 0/2 [00:00<?, ?it/s]

In [35]:


print('entropy:', _calculate_code_entropy(quantized_weight.get_codes()))

for i in range(10):
#     quantized_weight.get_codes()[...] =  beam_search_optimal_codes(
#         reference_weight, quantized_weight.get_codebooks(), quantized_weight.get_codes(), quantized_weight.get_scales(),
#         penalties=_get_entropy_penalties_upper_bound(quantized_weight.get_codes(), regularizer=0.1),
#         beam_size=1,
#     )
    prev_codes = quantized_weight.get_codes().clone()
    new_codes = quantized_weight.beam_search_update_codes_(
                        reference_weight=reference_weight,
                        beam_size=1,
                        stochastic_rounding_tau=0.0,
                        max_update_fraction=0.2,
                        force_update=False,
                        code_selection_temperature=0,
                        trust_ratio=None,
                        dim_rng=random.Random(None),
                        penalties=_get_entropy_penalties_upper_bound(prev_codes,  regularizer=0.1)
                    )

    print('entropy:', _calculate_code_entropy(quantized_weight.get_codes()))

entropy: tensor([7.9842, 7.9918], device='cuda:0')
entropy: tensor([7.9827, 7.9753], device='cuda:0')
entropy: tensor([7.9825, 7.7727], device='cuda:0')
entropy: tensor([7.9787, 7.1288], device='cuda:0')
entropy: tensor([7.9732, 7.0851], device='cuda:0')
entropy: tensor([7.9714, 7.0846], device='cuda:0')
entropy: tensor([7.9709, 7.0846], device='cuda:0')
entropy: tensor([7.9707, 7.0846], device='cuda:0')
entropy: tensor([7.9706, 7.0846], device='cuda:0')
entropy: tensor([7.9706, 7.0846], device='cuda:0')
entropy: tensor([7.9706, 7.0846], device='cuda:0')


In [15]:
_get_entropy_penalties_upper_bound(quantized_weight.get_codes(), regularizer=0.1)

torch.Size([1, 1024])

In [17]:
quantized_weight.get_codes().shape

torch.Size([8192, 1024, 1])

In [19]:
torch.ones_like(quantized_weight.get_codes()).shape

torch.Size([8192, 1024, 1])

In [21]:
2**10

1024