In [None]:
# Do NOT run `!pip install gguf`: the gguf package is installed from the
# local ik_llama.cpp checkout instead.
# Python 3.10 is required -- with Python 3.12 the generated "-dequant" file came out wrong.
#   conda create -yn mmlupro python=3.10
# This install must also happen after exllamav3 has been pip-installed.
%cd ik_llama.cpp
!pip install .
%cd -

In [None]:
import sys
import os
# The NextN/MTP tensor entries for GLM4_MOE in tensor_mapping.py do not seem to work,
# so that part is left out for now.
# Directory of the installed gguf package inside the active environment
# (built by hand, so it assumes the lib/pythonX.Y/site-packages layout).
gguf_sys_path = f"{sys.prefix}/lib/python{sys.version_info.major}.{sys.version_info.minor}/site-packages/gguf"
# Overwrite the installed modules with the local copies from gguf-py/.
!cp gguf-py/constants.py   {gguf_sys_path}
!cp gguf-py/gguf_reader.py {gguf_sys_path}
!cp gguf-py/quants.py      {gguf_sys_path}

In [None]:
import os
import torch
import numpy as np
import gguf
from gguf import GGUFReader
from gguf.constants import GGMLQuantizationType
import re
import subprocess
from safetensors import safe_open
from typing import Any
from pathlib import Path

torch.set_grad_enabled(False)

def data_get(
    data, offset: int, dtype: np.typing.DTypeLike, count: int = 1) -> np.typing.NDArray[Any]:
    """Reinterpret a window of a byte buffer as `count` items of `dtype`.

    Args:
        data: array-like byte buffer supporting slicing and `.view`
            (e.g. a uint8 ndarray or memmap).
        offset: byte position of the first item.
        dtype: element type to reinterpret the bytes as.
        count: number of items to return.

    Returns:
        A view (no copy) of at most `count` items starting at `offset`.
    """
    n_items = int(count)
    item_bytes = int(np.dtype(dtype).itemsize)
    window = data[offset:offset + item_bytes * n_items]
    return window.view(dtype = dtype)[:n_items]

def data_read_string(data, offset):
    """Read a length-prefixed UTF-8 string from a byte buffer.

    The string is stored as a 4-byte int32 length followed by that many
    bytes of UTF-8 text.

    Args:
        data: byte buffer accepted by `data_get` (must also support
            slicing and `.tobytes()`).
        offset: byte position of the length prefix.

    Returns:
        Tuple `(value, next_offset)`: the decoded string and the byte
        position immediately after it, as a plain Python int.
    """
    # Convert the length to a Python int right away. Left as an np.int32
    # scalar, `offset + str_length` is fixed-width arithmetic: under NumPy 2
    # promotion rules it fails for offsets beyond the int32 range (multi-GB
    # files) and the returned offset would be a NumPy scalar.
    str_length = int(data_get(data, offset, np.int32)[0])
    start = offset + 4
    end = start + str_length
    value = data[start:end].tobytes().decode('utf-8')
    return value, end

# Dequant

<details>
<summary>permute and inverse_permute demo</summary>

https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py#L222

```python
def permute(w, n_heads, dim1, dim2):
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

def inverse_permute(w, n_heads, dim1, dim2):
    w = w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
    w = w.transpose(1, 2)
    w = w.reshape(dim1, dim2)
    return w

n_heads = 2
dim1 = 12
dim2 = 12
w= torch.arange(dim1 * dim2).view(dim1, dim2)
permuted_w = permute(w, n_heads, dim1, dim2)
print(w)
print(permuted_w)
w = inverse_permute(permuted_w, n_heads, dim1, dim2)
print(w)
```

```python
# for llama and mistral series
metadata = {}
for key, field in bf16_reader.fields.items():
    metadata[key] = field.parts[field.data[0]][0]

n_heads     = metadata['llama.attention.head_count']
n_kv_heads  = metadata['llama.attention.head_count_kv']

# inverse permute for sliced rotary; Qwen3 does not need it
def inverse_permute(w, name, num_heads, num_kv_heads):
    if 'attn_q' in name:
        dim3 = num_heads
    elif 'attn_k' in name:
        dim3 = num_kv_heads
    else:
        return w
        
    dim1, dim2 = w.shape
    return w.view(dim3, 2, dim1 // dim3 // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
```

</details>

## EXL3

In [None]:
from exllamav3 import Config, Model
device = "cuda:0"

'''
    Attention: the name is already converted:
    if name == 'output.weight': name = 'lm_head'
'''
def load_exl3_tensor(reader, name):
    """Load one module from an EXL3 model and return its weight on the CPU.

    Args:
        reader: object exposing `find_module(name)`.
        name: module name, already converted to the EXL3 naming
            (e.g. 'output.weight' -> 'lm_head').

    Returns:
        The module's `weight` for names containing 'norm'; otherwise the
        transposed tensor returned by `inner.get_weight_tensor()`.
    """
    module = reader.find_module(name)
    module.load(device = device)
    if 'norm' not in name:
        # Non-norm modules expose the weight through `inner`; it is transposed
        # before being returned.
        return module.inner.get_weight_tensor().cpu().T
    return module.weight.cpu()

## legacy

In [None]:
def pt_get_tensor(reader, prefix, name):
    """Fetch tensors that are stored unquantized (embedding and norm weights).

    Args:
        reader: object exposing `get_tensor(name)`.
        prefix: per-layer prefix, e.g. 'model.layers.3.'.
        name: GGUF-style tensor name.

    Returns:
        The tensor for 'token_embd.weight' or for an attn/ffn norm weight;
        `None` for every other name.
    """
    if name == 'token_embd.weight':
        return reader.get_tensor('model.embed_tokens.weight')
    norm_key = next((key for key in ('attn_norm', 'ffn_norm') if key in name), None)
    if norm_key is None:
        return None
    return reader.get_tensor(prefix + name_map[norm_key] + '.weight')

def load_fakequant_tensor(reader, name):
    """Load a tensor from a fake-quantized (plain-weight) checkpoint.

    Args:
        reader: object exposing `get_tensor(name)`.
        name: GGUF-style tensor name of the form 'blk.{layer}.xxx'.

    Returns:
        The embedding/norm tensor as stored, or the projection weight
        converted to float and passed through `inverse_permute`.

    Raises:
        AssertionError: if no `name_map` key occurs in `name`.
    """
    layer = name.split('.')[1] # f'blk.{layer}.xxx'
    prefix = f'model.layers.{layer}.'
    tensor = pt_get_tensor(reader, prefix, name)
    if tensor is not None:
        return tensor

    # Start from None and assert, like the other loaders, so an unknown name
    # fails with a clear message instead of an unbound-variable error.
    pt_name = None
    for k in name_map:
        if k in name:
            pt_name = prefix+name_map[k]+ '_proj.weight'
            break
    assert pt_name is not None, f"no name_map entry matches {name!r}"
    return inverse_permute(name, reader.get_tensor(pt_name).float())

### AWQ

https://github.com/mit-han-lab/llm-awq/blob/main/awq/quantize/quantizer.py

<details>
<summary>awq quant logic</summary>

[Question about the zero point](https://github.com/mit-han-lab/llm-awq/issues/116)

I noticed that only negative minimum values are preserved as zero points with the code.

```python
    if zero_point:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2**n_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
```

Then, why not preserve all the minimum values including the positive values?
</details>

In [None]:
# https://medium.com/@crclq2018/awq-how-its-code-works-1ea92fb80bd2
def load_awq_tensor(reader, name):
    """Dequantize one tensor from an AWQ 4-bit checkpoint.

    Args:
        reader: object exposing `get_tensor(name)`.
        name: GGUF-style tensor name of the form 'blk.{layer}.xxx'.

    Returns:
        The embedding/norm tensor as stored, or the dequantized projection
        weight (float, transposed) passed through `inverse_permute`.

    Raises:
        AssertionError: if no `name_map` key occurs in `name`.
    """
    layer = name.split('.')[1] # f'blk.{layer}.xxx'
    prefix = f'model.layers.{layer}.'

    tensor = pt_get_tensor(reader, prefix, name)
    if tensor is not None:
        return tensor

    pt_name = None
    for k in name_map:
        if k in name:
            pt_name = prefix+name_map[k]+ '_proj.'
            break
    assert pt_name is not None

    qweight = reader.get_tensor(pt_name+'qweight')
    qzeros  = reader.get_tensor(pt_name+'qzeros')
    scales  = reader.get_tensor(pt_name+'scales')

    # dequantize
    # qzeros holds one row per quantization group, so the group size follows
    # from the shapes (128 for the usual checkpoints) instead of being fixed.
    group_size = qweight.shape[0] // qzeros.shape[0]
    # Bit offsets of the eight 4-bit values inside each int32, in AWQ's
    # interleaved packing order.
    wf = torch.tensor([x * 4 for x in [0, 4, 1, 5, 2, 6, 3, 7]], dtype=torch.int32).unsqueeze(0)
    zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 8), wf.unsqueeze(0)).to(torch.int8)
    zeros = torch.bitwise_and(zeros, 0xf)
    zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])

    weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 2).expand(-1, -1, 8), wf.unsqueeze(0)).to(torch.int8)
    weight = torch.bitwise_and(weight, 0xf)
    weight = weight.reshape(-1, group_size, weight.shape[1] * weight.shape[2])

    scales = scales.reshape(-1, 1, scales.shape[-1])
    weight = scales * (weight - zeros) # per-group scale/zero applied by implicit broadcasting
    weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

    return inverse_permute(name, weight.float().T)

### GPTQ

https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/triton/gptq.py

https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/triton/quant/quantizer.py

<details>
<summary>gptq quant logic</summary>


```python

# if actorder:
H = torch.tensor([[8,100,100,100], [100,7,100,100], [100,100,5,100], [100,100,100,9]])
g_idx = torch.tensor([i//2 for i in range(4)])
perm = torch.argsort(torch.diag(H), descending=True)
# [3, 0, 1, 2]
invperm = torch.argsort(perm)
# [1, 2, 3, 0]
g_idx = g_idx[invperm]
# [0, 1, 1, 0]
 
if self.maxq < 0:
    self.scale = xmax
    self.zero = xmin
else:
    self.scale = (xmax - xmin) / self.maxq
    if self.sym:
        self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
    else:
        self.zero = torch.round(-xmin / self.scale)
```
</details>

In [None]:
# https://github.com/AutoGPTQ/AutoGPTQ/blob/main/auto_gptq/nn_modules/qlinear/qlinear_cuda.py
def load_gptq_tensor(reader, name):
    """Dequantize one tensor from a GPTQ 4-bit checkpoint (with `g_idx`).

    Args:
        reader: object exposing `get_tensor(name)`.
        name: GGUF-style tensor name of the form 'blk.{layer}.xxx'.

    Returns:
        The embedding/norm tensor as stored, or the dequantized projection
        weight (float, transposed) passed through `inverse_permute`.

    Raises:
        AssertionError: if no `name_map` key occurs in `name`.
    """
    layer = name.split('.')[1] # f'blk.{layer}.xxx'
    prefix = f'model.layers.{layer}.'

    plain = pt_get_tensor(reader, prefix, name)
    if plain is not None:
        return plain

    proj_key = next((key for key in name_map if key in name), None)
    assert proj_key is not None
    pt_name = prefix + name_map[proj_key] + '_proj.'

    qzeros  = reader.get_tensor(pt_name+'qzeros')
    qweight = reader.get_tensor(pt_name+'qweight')
    g_idx   = reader.get_tensor(pt_name+'g_idx')
    scales  = reader.get_tensor(pt_name+'scales')

    # Bit offsets 0, 4, ..., 28 of the eight 4-bit values inside each int32.
    shifts = torch.arange(0, 32, 4, dtype=torch.int32).unsqueeze(0)

    # Zero points are packed along the last axis of qzeros.
    unpacked_zeros = torch.bitwise_right_shift(
        qzeros.unsqueeze(2).expand(-1, -1, 8), shifts.unsqueeze(0)
    ).to(torch.int8)
    unpacked_zeros = torch.bitwise_and(unpacked_zeros, 0xf) + 1  # stored zero points are offset by one
    unpacked_zeros = unpacked_zeros.reshape(scales.shape)

    # Weights are packed along the first axis of qweight.
    unpacked_weight = torch.bitwise_right_shift(
        qweight.unsqueeze(1).expand(-1, 8, -1), shifts.unsqueeze(-1)
    ).to(torch.int8)
    unpacked_weight = torch.bitwise_and(unpacked_weight, 0xf)
    unpacked_weight = unpacked_weight.reshape(
        unpacked_weight.shape[0] * unpacked_weight.shape[1], unpacked_weight.shape[2]
    )

    # g_idx selects, for every input row, the scale/zero row of its group.
    group_of_row = g_idx.long()
    dequantized = scales[group_of_row] * (unpacked_weight - unpacked_zeros[group_of_row])

    return inverse_permute(name, dequantized.float().T)

### AutoRound

In [None]:
def load_autoround_tensor(reader, name):
    """Dequantize one tensor from an AutoRound 4-bit checkpoint.

    Same packing as the GPTQ loader but without `g_idx`: consecutive rows
    form the quantization groups.

    Args:
        reader: object exposing `get_tensor(name)`.
        name: GGUF-style tensor name of the form 'blk.{layer}.xxx'.

    Returns:
        The embedding/norm tensor as stored, or the dequantized projection
        weight (float, transposed) passed through `inverse_permute`.

    Raises:
        AssertionError: if no `name_map` key occurs in `name`.
    """
    layer = name.split('.')[1] # f'blk.{layer}.xxx'
    prefix = f'model.layers.{layer}.'

    tensor = pt_get_tensor(reader, prefix, name)
    if tensor is not None:
        return tensor

    pt_name = None
    for k in name_map:
        if k in name:
            pt_name = prefix+name_map[k]+ '_proj.'
            break
    assert pt_name is not None

    qzeros  = reader.get_tensor(pt_name+'qzeros')
    qweight = reader.get_tensor(pt_name+'qweight')
    scales  = reader.get_tensor(pt_name+'scales')

    # dequantize
    # Each int32 row of qweight unpacks to 8 weight rows and qzeros has one row
    # per group, so the group size follows from the shapes (128 for the usual
    # checkpoints) instead of being fixed.
    group_size = (qweight.shape[0] * 8) // qzeros.shape[0]
    wf = torch.tensor(list(range(0, 32, 4)), dtype=torch.int32).unsqueeze(0)
    zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 8), wf.unsqueeze(0)).to(torch.int8)
    zeros = torch.bitwise_and(zeros, 0xf)
    zeros = zeros + 1 # stored zero points are offset by one
    zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])

    weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 8, -1), wf.unsqueeze(-1)).to(torch.int8)
    weight = torch.bitwise_and(weight, 0xf)
    weight = weight.reshape(-1, group_size, weight.shape[2])

    scales = scales.reshape(-1, 1, scales.shape[-1])
    weight = scales * (weight - zeros) # per-group scale/zero applied by implicit broadcasting
    weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

    return inverse_permute(name, weight.float().T)

### OmniQuant/MLC-LLM

https://github.com/mlc-ai/mlc-llm/blob/main/python/mlc_llm/op/moe_matmul.py#L130

https://github.com/mlc-ai/mlc-llm/blob/main/python/mlc_llm/quantization/group_quantization.py#L61
<details>
<summary>dequantization</summary>

```python
def _dequantize(w, s, e, i, j):
    tir_bin_mask = tir.const((2**quantize_dtype_bits) - 1, storage_dtype)
    tir_max_int = tir.const((2 ** (quantize_dtype_bits - 1)) - 1, model_dtype)
    w = w[e, i, j // num_elem_per_storage]
    s = s[e, i, j // group_size]
    shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype)
    w = tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask).astype(model_dtype)
    return (w - tir_max_int) * s
```
</details>

In [None]:
import json
class mlcbin_open:
    """Reader for MLC-LLM weight shards described by `ndarray-cache.json`.

    Attributes:
        local: local directory prefix (concatenated directly with file names,
            so it must end with a path separator).
        metadata: tensor name -> dict with `dataPath`, `shape`, `dtype`,
            `format`, `nbytes` and `byteOffset`.
        groups: number of weights covered by one scale, derived from the
            shapes of layer 0's `down_proj` q_weight / q_scale.
    """

    def __init__(self, local, remote):
        """Index the model in `local`, downloading it from `remote` first if
        the directory does not exist.

        Args:
            local: local directory prefix ending with a path separator.
            remote: URL prefix the manifest and shard names are appended to.

        Raises:
            subprocess.CalledProcessError: if a `wget` download fails.
        """
        self.local = local
        file_path = local + "ndarray-cache.json"
        needs_download = not os.path.exists(local)
        if needs_download:
            print("Model doesn't exist, start downloading...")
            os.makedirs(local, exist_ok=True)
            # check=True: fail here rather than later on a truncated/empty file.
            subprocess.run(["wget", "-O", file_path, remote + "ndarray-cache.json"], check=True)

        # Context manager so the manifest file handle is closed deterministically.
        with open(file_path, encoding="utf-8") as manifest:
            param_metadata = json.load(manifest)["records"]

        if needs_download:
            for record in param_metadata:
                record_name = record["dataPath"]
                subprocess.run(["wget", "-O", self.local + record_name, remote + record_name], check=True)

        self.metadata = {}
        for record in param_metadata:
            dataPath = record["dataPath"]
            for i in record['records']:
                self.metadata[i["name"]] = {
                    "dataPath": dataPath,
                    "shape": i["shape"],
                    "dtype": i["dtype"],
                    "format": i["format"],
                    "nbytes": i["nbytes"],
                    "byteOffset": i["byteOffset"]
                }
        q_weight = self.metadata["model.layers.0.mlp.down_proj.q_weight"]
        q_scale  = self.metadata["model.layers.0.mlp.down_proj.q_scale"]
        # Each uint32 packs 8 four-bit weights; dividing the unpacked column
        # count by the scale column count gives the weights per scale.
        self.groups = 8 * q_weight["shape"][1] // q_scale["shape"][1]

    def get_tensor(self, name):
        """Return the tensor `name` as a torch tensor copied out of its shard.

        Only 1-D and 2-D tensors are handled; the dtype is read as uint32 when
        the manifest says "uint32" and as float16 otherwise.
        """
        data = np.memmap(self.local + self.metadata[name]["dataPath"], mode = "r")
        shape = self.metadata[name]["shape"]
        dtype = np.uint32 if self.metadata[name]["dtype"] == "uint32" else np.float16
        cnt = shape[0]*shape[1] if len(shape) == 2 else shape[0]
        w = data_get(data, self.metadata[name]["byteOffset"], dtype, cnt)
        # Copy so the tensor does not keep referencing the read-only memmap.
        w = torch.from_numpy(w.copy())
        if len(shape) == 2:
            w = w.view(shape[0], shape[1])
        return w

In [None]:
def load_mlc_tensor(reader, name):
    """Dequantize one tensor from an MLC-LLM 4-bit group-quantized checkpoint.

    The checkpoint only stores `q_weight` and `q_scale` (no zero points), and
    fuses gate/up and q/k/v projections into single tensors, so the requested
    projection is sliced out of the fused tensor before dequantizing.

    Args:
        reader: an `mlcbin_open`-like object exposing `get_tensor(name)` and
            a `groups` attribute.
        name: GGUF-style tensor name ('blk.{layer}.xxx' or 'token_embd.weight').

    Returns:
        The norm weight as stored, or the dequantized weight (float,
        transposed) passed through `inverse_permute`.

    Raises:
        AssertionError: if the name cannot be mapped, or if `q_weight` and
            `q_scale` disagree on their first dimension.
    """
    # only has q_weight and q_scale, and stored in mlc-llm format
    layer = name.split('.')[1] # f'blk.{layer}.xxx'
    prefix = f'model.layers.{layer}.'

    for k in ['attn_norm', 'ffn_norm']:
        if k in name:
            return reader.get_tensor(prefix+name_map[k]+'.weight')

    pt_name = None
    # Remember which key actually matched. Reading the loop variable after the
    # loop is wrong when nothing matched (e.g. 'token_embd.weight'): it would
    # then hold an unrelated key and could select a bogus slice.
    matched_key = None
    for k in name_map:
        if k in name:
            matched_key = k
            if k in ['ffn_gate', 'ffn_up']:
                pt_name = prefix + "mlp.gate_up_proj."
            elif k in ['attn_q', 'attn_k', 'attn_v']:
                pt_name = prefix + "self_attn.qkv_proj."
            else:
                pt_name = prefix + name_map[k] + "_proj."
            break

    if name ==  "token_embd.weight":
        pt_name = "model.embed_tokens."
    assert pt_name is not None

    qweight = reader.get_tensor(pt_name + "q_weight")
    scales  = reader.get_tensor(pt_name + "q_scale")
    assert qweight.shape[0] == scales.shape[0]
    dim0 = qweight.shape[0]
    # Row boundaries inside the fused tensors: gate|up split in half,
    # q|k|v split in proportion to n_heads : n_kv_heads : n_kv_heads.
    unit_dim = dim0 // (n_heads+2*n_kv_heads)
    split_dim = [dim0//2, unit_dim*n_heads, unit_dim*(n_heads+n_kv_heads)]
    slice_dict = {
        'ffn_gate': slice(0, split_dim[0]),
        'ffn_up': slice(split_dim[0], dim0),

        'attn_q': slice(0, split_dim[1]),
        'attn_k': slice(split_dim[1], split_dim[2]),
        'attn_v': slice(split_dim[2], dim0)
    }
    # Unfused tensors (and the embedding) keep all rows.
    slice_range = slice_dict.get(matched_key, slice(0, dim0))

    # use autoround dequantize style
    qweight = qweight[slice_range].to(torch.int32).T
    scales  = scales[slice_range].T

    # dequantize
    wf = torch.tensor(list(range(0, 32, 4)), dtype=torch.int32).unsqueeze(0)

    weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 8, -1), wf.unsqueeze(-1)).to(torch.int8)
    weight = torch.bitwise_and(weight, 0xf)
    weight = weight.reshape(-1, reader.groups, weight.shape[2])

    scales = scales.reshape(-1, 1, scales.shape[-1])

    weight = scales * (weight - 7) # without zero point: fixed offset of 7
    weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

    # return weight.float().T
    return inverse_permute(name, weight.float().T)

### HQQ

In [None]:
class hqqpt_open:
    """Reader for an HQQ checkpoint saved as a single torch file.

    `metadata` maps each entry name to the dict stored for it in the file.
    """

    def __init__(self, model_path):
        """Load every entry of the checkpoint at `model_path` into memory."""
        # NOTE: torch.load may unpickle arbitrary objects -- only use it on
        # trusted checkpoint files.
        checkpoint = torch.load(model_path)
        self.metadata = {key: entry for key, entry in checkpoint.items()}

    def get_tensor(self, name):
        """Return the weight stored under `name`, dequantizing it if needed.

        Entries carrying a 'shape' key are treated as quantized: `W_q` packs
        two 4-bit values per element (high nibbles first, then low nibbles,
        concatenated along dim 0), which are mapped through
        `scale * (q - zero)` and reshaped to `shape`. Other entries return
        their 'weight' unchanged.
        """
        entry = self.metadata[name]
        if 'shape' not in entry:
            return entry['weight']
        packed = entry['W_q']
        nibbles = torch.cat((packed >> 4, packed & 0xf), dim=0)
        rows, cols = entry['shape'][0], entry['shape'][1]
        dequantized = entry['scale'] * (nibbles - entry['zero'])
        return dequantized.reshape(rows, cols)

In [None]:
def load_hqq_tensor(reader, name):
    """Load one tensor from an HQQ checkpoint reader.

    Args:
        reader: object exposing `get_tensor(name)` (dequantization happens
            inside the reader).
        name: GGUF-style tensor name of the form 'blk.{layer}.xxx'.

    Returns:
        The norm tensor as returned by the reader, or the projection weight
        converted to float and passed through `inverse_permute`.

    Raises:
        AssertionError: if no `name_map` key occurs in `name`.
    """
    layer = name.split('.')[1] # f'blk.{layer}.xxx'
    prefix = f'model.layers.{layer}.'

    # Norm entries are addressed by the mapped name without any suffix.
    norm_key = next((key for key in ('attn_norm', 'ffn_norm') if key in name), None)
    if norm_key is not None:
        return reader.get_tensor(prefix + name_map[norm_key])

    proj_key = next((key for key in name_map if key in name), None)
    assert proj_key is not None
    dequantized = reader.get_tensor(prefix + name_map[proj_key] + '_proj')
    return inverse_permute(name, dequantized.float())

## ik_llama.cpp (``gguf-py/gguf/quants.py``)

<details>
<summary>gguf-py/gguf/gguf_reader.py</summary>

```diff
diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py
index e8e61abf..5b0717c0 100644
--- a/gguf-py/gguf/gguf_reader.py
+++ b/gguf-py/gguf/gguf_reader.py
@@ -277,6 +277,8 @@ class GGUFReader:
             np_dims = tuple(reversed(dims.tolist()))
             block_size, type_size = GGML_QUANT_SIZES[ggml_type]
             n_bytes = n_elems * type_size // block_size
+            if ggml_type == GGMLQuantizationType.IQ2_KS:
+                n_bytes += 2 * int(dims[1])
             data_offs = int(start_offs + offset_tensor[0])
             item_type: npt.DTypeLike
             if ggml_type == GGMLQuantizationType.F16:
```
</details>


<details>
<summary>gguf-py/gguf/quants.py</summary>

```diff
diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py
index ff589b85..bff524b4 100644
--- a/gguf-py/gguf/quants.py
+++ b/gguf-py/gguf/quants.py
@@ -15,11 +15,15 @@ def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantization
     block_size, type_size = GGML_QUANT_SIZES[quant_type]
     if shape[-1] % block_size != 0:
         raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})")
+    if quant_type == GGMLQuantizationType.IQ2_KS:
+        return (*shape[:-1], 2 + shape[-1] // block_size * type_size)
     return (*shape[:-1], shape[-1] // block_size * type_size)
 
 
 def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
     block_size, type_size = GGML_QUANT_SIZES[quant_type]
+    if quant_type == GGMLQuantizationType.IQ2_KS:
+        return (*shape[:-1], (shape[-1] - 2) // type_size * block_size)
     if shape[-1] % type_size != 0:
         raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})")
     return (*shape[:-1], shape[-1] // type_size * block_size)
@@ -148,12 +152,21 @@ class __Quant(ABC):
     def dequantize_rows(cls, rows: np.ndarray) -> np.ndarray:
         rows = rows.view(np.uint8)
         shape = rows.shape
+        d = np.ones((shape[0], 1), dtype=float)
+        if cls.qtype == GGMLQuantizationType.IQ2_KS:
+            d, rows = np.hsplit(rows, [2])
+            d = d.view(np.float16).astype(np.float32)
         n_blocks = rows.size // cls.type_size
         blocks = rows.reshape((n_blocks, cls.type_size))
         blocks = cls.dequantize_blocks(blocks)
         assert blocks.dtype == np.float32
         assert blocks.shape[-1] == cls.block_size
-        return blocks.reshape(cls.__shape_from_bytes(shape))
+        # print(cls.qtype)
+        # tmpa = d * blocks.reshape(cls.__shape_from_bytes(shape))
+        # tmpa = tmpa.reshape(1, -1, 16)
+        # print(tmpa[0][1])
+        # exit(0)
+        return d * blocks.reshape(cls.__shape_from_bytes(shape))
 
     @classmethod
     def __shape_to_bytes(cls, shape: Sequence[int]):
@@ -1186,3 +1199,177 @@ class IQ4_XS(__Quant, qtype=GGMLQuantizationType.IQ4_XS):
         qs = np.take_along_axis(kvalues, qs, axis=-1).astype(np.float32).reshape((n_blocks, -1, 32))
 
         return (dl * qs).reshape((n_blocks, -1))
+
```
</details>


In [None]:
class IQ2_K(__Quant, qtype=GGMLQuantizationType.IQ2_K):
    """Dequantizer for IQ2_K blocks: 2-bit indices into a 4-entry value table."""
    # Value table addressed by each 2-bit quant.
    kvalues = (-31, -13, 1, 17)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Expand raw IQ2_K block bytes into float values.

        Args:
            blocks: 2-D byte array with one raw block per row.

        Returns:
            Array of shape (n_blocks, values_per_block) with the dequantized values.
        """
        # according to dequantize_row_iq2_k in ggml/src/iqk/iqk_quantize.cpp
        n_blocks = blocks.shape[0]

        # Byte layout of a block: d (2) | extra (2) | scales (QK_K // 32) | qs (remainder).
        d, rest = np.hsplit(blocks, [2])
        extra, rest = np.hsplit(rest, [2])
        scales, qs = np.hsplit(rest, [QK_K // 32])

        d = d.view(np.float16).astype(np.float32)
        extra = extra.view(np.uint16)

        # extra: one flag bit per sub-block (bit i -> sub-block i).
        # scales: two 4-bit values per byte, low nibble first.
        extra = extra.reshape((n_blocks, 1, -1)) >> np.array(list(range(QK_K // 16)), dtype=np.uint16).reshape((1, -1, 1))
        scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
        extra = extra.reshape((n_blocks, -1)).astype(np.uint8) & np.uint8(0x01)
        scales = scales.reshape((n_blocks, -1)) & np.uint8(0x0F)
        # The 4-bit scale is stored with an offset of 8. The author flags the cast to
        # int8 as essential -- presumably so the subtraction is done on signed values.
        scales = scales.astype(np.int8) - np.int8(8)

        # Per-sub-block scale: block scale times the signed sub-block scale.
        dl = (d * scales.astype(np.float32)).reshape((n_blocks, -1, 1))

        # Key point: the shift axis comes before the byte axis, so within every
        # 32-byte group all bytes are read at shift 0 first, then at shift 2, 4 and 6.
        qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
        qs = qs.reshape((n_blocks, -1, 128, 1)) & np.uint8(0x03)

        # Map indices through the value table, then regroup into sub-blocks of 16 values.
        kvalues = np.array(cls.kvalues, dtype=np.int8).reshape((1, 1, 1, -1))
        qs = np.take_along_axis(kvalues, qs, axis=-1).astype(np.float32).reshape((n_blocks, -1, 16))
        # Sub-blocks whose extra bit is set have every value shifted by +5.
        qs = qs + 5 * extra.reshape(n_blocks, -1, 1)

        return (dl * qs).reshape((n_blocks, -1))

In [None]:
class IQ2_KS(__Quant, qtype=GGMLQuantizationType.IQ2_KS):
    """Dequantizer for IQ2_KS blocks: 2-bit indices into the IQ2_K value table.

    The block itself carries no float scale; this method returns values scaled
    only by the integer sub-block scales. (The patched `dequantize_rows` shown
    earlier in this notebook splits a float16 scale off each row and multiplies
    it in afterwards.)
    """

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Expand raw IQ2_KS block bytes into float values (row scale not applied).

        Args:
            blocks: 2-D byte array with one raw block per row.

        Returns:
            Array of shape (n_blocks, values_per_block).
        """
        # according to dequantize_row_iq2_ks in ggml/src/iqk/iqk_quantize.cpp
        n_blocks = blocks.shape[0]

        # Byte layout of a block: extra (1) | scales_h (1) | scales_l (QK_K // 64) | qs (remainder).
        extra, rest = np.hsplit(blocks, [1])
        scales_h, rest = np.hsplit(rest, [1])
        scales_l, qs = np.hsplit(rest, [QK_K // 64])

        # extra and scales_h: one bit per sub-block (bit i -> sub-block i).
        # scales_l: two 4-bit values per byte, low nibble first.
        extra = extra.reshape((n_blocks, 1, -1)) >> np.array(list(range(QK_K // 32)), dtype=np.uint8).reshape((1, -1, 1))
        scales_h = scales_h.reshape((n_blocks, 1, -1)) >> np.array(list(range(QK_K // 32)), dtype=np.uint8).reshape((1, -1, 1))
        scales_l = scales_l.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
        extra = extra.reshape((n_blocks, -1)) & np.uint8(0x01)
        scales_h = scales_h.reshape((n_blocks, -1)) & np.uint8(0x01)
        scales_l = scales_l.reshape((n_blocks, -1)) & np.uint8(0x0F)
        # 5-bit scale = low nibble | high bit << 4, stored with an offset of 16.
        scales = (scales_l | (scales_h << np.uint8(4))).astype(np.int8) - np.int8(16)
        dl = scales.astype(np.float32).reshape((n_blocks, -1, 1))

        # Key point: the shift axis comes before the byte axis, so within every
        # 32-byte group all bytes are read at shift 0 first, then at shift 2, 4 and 6.
        qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
        qs = qs.reshape((n_blocks, -1, 128, 1)) & np.uint8(0x03)

        # Same value table as IQ2_K; values are regrouped into sub-blocks of 32.
        kvalues = np.array(IQ2_K.kvalues, dtype=np.int8).reshape((1, 1, 1, -1))
        qs = np.take_along_axis(kvalues, qs, axis=-1).astype(np.float32).reshape((n_blocks, -1, 32))
        # Sub-blocks whose extra bit is set have every value shifted by +5.
        qs = qs + 5 * extra.reshape(n_blocks, -1, 1)

        return (dl * qs).reshape((n_blocks, -1))

In [None]:
class IQ2_KL(__Quant, qtype=GGMLQuantizationType.IQ2_KL):
    """Dequantizer for IQ2_KL blocks: 5-bit indices, each selecting a pair of values.

    No float scale is applied in this class.
    NOTE(review): confirm that an overall (row-level) scale is applied by the caller.
    """
    # 32 table entries of 16 bits; each entry is later reinterpreted as two int8 values.
    kvalues = (0xe9c1, 0x0dc1, 0xc1d8, 0xf6d8, 0x0dd8, 0x2fd8, 0xd8e9, 0xe9e9,
               0x01e9, 0x0de9, 0x1ce9, 0xc1f6, 0x01f6, 0x0df6, 0x2ff6, 0xe901,
               0xf601, 0x0101, 0x0d01, 0x1c01, 0xd80d, 0xe90d, 0xf60d, 0x010d,
               0x0d0d, 0xc11c, 0xe91c, 0x011c, 0x1c1c, 0x2f1c, 0xe92f, 0x0d2f)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Expand raw IQ2_KL block bytes into float values.

        Args:
            blocks: 2-D byte array with one raw block per row.

        Returns:
            Array of shape (n_blocks, values_per_block).
        """
        # according to dequantize_row_iq2_kl in ggml/src/iqk/iqk_quantize.cpp
        n_blocks = blocks.shape[0]

        # Byte layout of a block: scales_h (2) | scales_l (QK_K // 64) | qs (QK_K // 4) | qh (remainder).
        scales_h, rest = np.hsplit(blocks, [2])
        scales_l, rest = np.hsplit(rest, [QK_K // 64])
        qs, qh = np.hsplit(rest, [QK_K // 4])
        
        # scales_h: 2 bits per sub-block, read at even bit offsets of the 16-bit word.
        scales_h = scales_h.view(np.uint16)
        scales_h = scales_h.reshape((n_blocks, 1, -1)) >> np.array(list(range(0, QK_K//16, 2)), dtype=np.uint16).reshape((1, -1, 1))
        scales_h = scales_h.reshape((n_blocks, -1)) & np.uint16(0x03)
        # scales_l: the shift axis comes before the byte axis, so the low nibbles of
        # all bytes come first, followed by the high nibbles.
        scales_l = scales_l.reshape((n_blocks, -1, 1, 4)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        scales_l = scales_l.reshape((n_blocks, -1)).astype(np.uint16) & np.uint16(0x0F)
        # 6-bit scale = low nibble | high 2 bits << 4, stored with an offset of 32.
        scales = (scales_l.astype(np.uint16) | scales_h << np.uint16(4)).astype(np.int16) - np.int16(32)
        dl = scales.astype(np.float32).reshape((n_blocks, -1, 1))
        
        # qs: low 4 index bits -- within every 16-byte group, low nibbles first, then high nibbles.
        qs = qs.reshape((n_blocks, -1, 1, 16)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        qs = qs.reshape((n_blocks, -1, 128, 1)) & np.uint8(0x0F)
        # qh: fifth index bit -- bit plane by bit plane over each 16-byte group.
        qh = qh.reshape((n_blocks, -1, 1, 16)) >> np.array(list(range(QK_K//32)), dtype=np.uint8).reshape((1, 1, 8, 1))
        qh = qh.reshape((n_blocks, -1, 128, 1)) & np.uint8(0x01)
        qs = (qs | (qh << np.uint8(4))).astype(np.uint8)
        kvalues = np.array(cls.kvalues, dtype=np.uint16).reshape((1, 1, 1, -1))
        # Look up the 16-bit entry and reinterpret it as two int8 values, then
        # regroup into sub-blocks of 32 values.
        qs = np.take_along_axis(kvalues, qs, axis=-1).view(np.int8).astype(np.float32).reshape((n_blocks, -1, 32))
        return (dl * qs).reshape((n_blocks, -1))

In [None]:
class IQ3_K(__Quant, qtype=GGMLQuantizationType.IQ3_K):
    """Dequantizer for IQ3_K blocks: 3-bit indices into an 8-entry value table."""
    # Value table addressed by each 3-bit quant.
    kvalues = (-63, -40, -23, -10, 1, 13, 28, 47)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Expand raw IQ3_K block bytes into float values.

        Args:
            blocks: 2-D byte array with one raw block per row.

        Returns:
            Array of shape (n_blocks, values_per_block) with the dequantized values.
        """
        # according to dequantize_row_iq3_k in ggml/src/iqk/iqk_quantize.cpp
        n_blocks = blocks.shape[0]

        # Byte layout of a block:
        # d (2) | extra (2) | scales_h (2) | scales_l (QK_K // 32) | qs (QK_K // 4) | qh (remainder).
        d, rest = np.hsplit(blocks, [2])
        extra, rest = np.hsplit(rest, [2])
        scales_h, rest = np.hsplit(rest, [2])
        scales_l, rest = np.hsplit(rest, [QK_K // 32])
        qs, qh = np.hsplit(rest, [QK_K // 4])

        d = d.view(np.float16).astype(np.float32)
        extra = extra.view(np.uint16)
        scales_h = scales_h.view(np.uint16)

        # extra and scales_h: one bit per sub-block (bit i -> sub-block i).
        # scales_l: two 4-bit values per byte, low nibble first.
        extra = extra.reshape((n_blocks, 1, -1)) >> np.array(list(range(QK_K // 16)), dtype=np.uint16).reshape((1, -1, 1))
        scales_h = scales_h.reshape((n_blocks, 1, -1)) >> np.array(list(range(QK_K // 16)), dtype=np.uint16).reshape((1, -1, 1))
        scales_l = scales_l.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
        extra = extra.reshape((n_blocks, -1)).astype(np.uint8) & np.uint8(0x01)
        scales_h = scales_h.reshape((n_blocks, -1)).astype(np.uint8) & np.uint8(0x01)
        scales_l = scales_l.reshape((n_blocks, -1)) & np.uint8(0x0F)
        # Magnitude is the odd number 2*l + 1; the scales_h bit flips the sign.
        scales = (1 + 2 * scales_l.astype(np.float32)) * (1 - 2 * scales_h.astype(np.float32))
        dl = (d * scales).reshape((n_blocks, -1, 1))

        # Key point: the shift axis comes before the byte axis, so within every
        # 32-byte group all bytes are read at one shift before moving to the next.
        qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
        qs = qs.reshape((n_blocks, -1, 256, 1)) & np.uint8(0x03)
        qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array(list(range(8)), dtype=np.uint8).reshape((1, 1, 8, 1))
        qh = qh.reshape((n_blocks, -1, 256, 1)) & np.uint8(0x01)
        # 3-bit index = low 2 bits from qs | third bit from qh.
        qs = (qs | (qh << np.uint8(2))).astype(np.uint8)

        # Map indices through the value table, then regroup into sub-blocks of 16 values.
        kvalues = np.array(cls.kvalues, dtype=np.int8).reshape((1, 1, 1, -1))
        qs = np.take_along_axis(kvalues, qs, axis=-1).astype(np.float32).reshape((n_blocks, -1, 16))
        # Sub-blocks whose extra bit is set have every value shifted by +4.
        qs = qs + 4 * extra.reshape(n_blocks, -1, 1)

        return (dl * qs).reshape((n_blocks, -1))

In [None]:
class IQ4_K(__Quant, qtype=GGMLQuantizationType.IQ4_K):
    """Dequantizer for IQ4_K blocks: 4-bit indices into the IQ4_NL value table."""

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Expand raw IQ4_K block bytes into float values.

        Args:
            blocks: 2-D byte array with one raw block per row.

        Returns:
            Array of shape (n_blocks, values_per_block) with the dequantized values.
        """
        # according to dequantize_row_iq4_k in ggml/src/iqk/iqk_quantize.cpp
        n_blocks = blocks.shape[0]

        # Byte layout of a block:
        # d (2) | extra (2) | scales_h (QK_K // 64) | scales_l (QK_K // 32) | qs (remainder).
        d, rest = np.hsplit(blocks, [2])
        extra, rest = np.hsplit(rest, [2])
        scales_h, rest = np.hsplit(rest, [QK_K // 64])
        scales_l, qs = np.hsplit(rest, [QK_K // 32])

        d = d.view(np.float16).astype(np.float32)
        extra = extra.view(np.uint16)

        # extra: one flag bit per sub-block (bit i -> sub-block i).
        # scales_h: four 2-bit values per byte; scales_l: two 4-bit values per byte
        # (lowest bits first in both cases).
        extra = extra.reshape((n_blocks, 1, -1)) >> np.array(list(range(QK_K // 16)), dtype=np.uint16).reshape((1, -1, 1))
        scales_h = scales_h.reshape((n_blocks, -1, 1)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4))
        scales_l = scales_l.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
        extra = extra.reshape((n_blocks, -1)).astype(np.uint8) & np.uint8(0x01)
        scales_h = scales_h.reshape((n_blocks, -1)) & np.uint8(0x03)
        scales_l = scales_l.reshape((n_blocks, -1)) & np.uint8(0x0F)

        # 6-bit scale = low nibble | high 2 bits << 4, stored with an offset of 32.
        scales = (scales_l | (scales_h << np.uint8(4))).astype(np.int8) - np.int8(32)
        dl = (d * scales.astype(np.float32)).reshape((n_blocks, -1, 1))

        # Key point: the shift axis comes before the byte axis, so within every
        # 16-byte group the low nibbles of all bytes come first, then the high nibbles.
        qs = qs.reshape((n_blocks, -1, 1, 16)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        qs = qs.reshape((n_blocks, -1, 32, 1)) & np.uint8(0x0F)

        # Map indices through the IQ4_NL value table, then regroup into sub-blocks of 16 values.
        kvalues = np.array(IQ4_NL.kvalues, dtype=np.int8).reshape((1, 1, 1, -1))
        qs = np.take_along_axis(kvalues, qs, axis=-1).astype(np.float32).reshape((n_blocks, -1, 16))
        # Sub-blocks whose extra bit is set have every value shifted by +4.
        qs = qs + 4 * extra.reshape(n_blocks, -1, 1)

        return (dl * qs).reshape((n_blocks, -1))

In [None]:
class IQ4_KS(__Quant, qtype=GGMLQuantizationType.IQ4_KS):
    """Dequantizer for IQ4_KS blocks: 4-bit indices into the IQ4_NL value table.

    No float scale is applied in this class.
    NOTE(review): confirm that an overall (row-level) scale is applied by the caller.
    """

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Expand raw IQ4_KS block bytes into float values.

        Args:
            blocks: 2-D byte array with one raw block per row.

        Returns:
            Array of shape (n_blocks, values_per_block).
        """
        # according to dequantize_row_iq4_ks in ggml/src/iqk/iqk_quantize.cpp
        n_blocks = blocks.shape[0]
        # Byte layout of a block: scales (QK_K // 32) | qs (remainder).
        scales, qs = np.hsplit(blocks, [QK_K // 32])
        scales = scales.reshape((n_blocks, -1))
        # Bit 0 of each scale byte is the "extra" flag; the remaining bits
        # (mask 254) are the scale, stored with an offset of 127.
        extra  = scales & np.uint8(0x1)
        scales = (scales & np.uint8(254)).astype(np.int8) - np.int8(127)
        dl = scales.astype(np.float32).reshape((n_blocks, -1, 1))

        # Key point: the shift axis comes before the byte axis, so within every
        # 16-byte group the low nibbles of all bytes come first, then the high nibbles.
        qs = qs.reshape((n_blocks, -1, 1, 16)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        qs = qs.reshape((n_blocks, -1, 32, 1)) & np.uint8(0x0F)

        # Map indices through the IQ4_NL value table, then regroup into sub-blocks of 32 values.
        kvalues = np.array(IQ4_NL.kvalues, dtype=np.int8).reshape((1, 1, 1, -1))
        qs = np.take_along_axis(kvalues, qs, axis=-1).astype(np.float32).reshape((n_blocks, -1, 32))
        # Sub-blocks whose extra bit is set have every value shifted by +4.
        qs = qs + 4 * extra.reshape(n_blocks, -1, 1)

        return (dl * qs).reshape((n_blocks, -1))

In [None]:
class IQ5_K(__Quant, qtype=GGMLQuantizationType.IQ5_K):
    """Dequantizer for IQ5_K blocks: 5-bit indices into a 32-entry value table."""
    # Value table addressed by each 5-bit quant.
    kvalues = (-126, -114, -103, -92, -83, -74, -65, -57, -50, -43, -36, -30, -24, -18, -12, -6, -1, 5, 11, 17, 23, 29, 36, 43, 51, 59, 68, 77, 87, 97, 109, 121)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Expand raw IQ5_K block bytes into float values.

        Args:
            blocks: 2-D byte array with one raw block per row.

        Returns:
            Array of shape (n_blocks, values_per_block) with the dequantized values.
        """
        # according to dequantize_row_iq5_k in ggml/src/iqk/iqk_quantize.cpp
        n_blocks = blocks.shape[0]

        # Byte layout of a block:
        # d (2) | extra (2) | scales_h (QK_K // 64) | scales_l (QK_K // 32) | qs (QK_K // 2) | qh (remainder).
        d, rest = np.hsplit(blocks, [2])
        extra, rest = np.hsplit(rest, [2])
        scales_h, rest = np.hsplit(rest, [QK_K // 64])
        scales_l, rest = np.hsplit(rest, [QK_K // 32])
        qs, qh = np.hsplit(rest, [QK_K // 2])

        d = d.view(np.float16).astype(np.float32)
        extra = extra.view(np.uint16)

        # extra: one flag bit per sub-block (bit i -> sub-block i).
        # scales_h: four 2-bit values per byte; scales_l: two 4-bit values per byte
        # (lowest bits first in both cases).
        extra = extra.reshape((n_blocks, 1, -1)) >> np.array(list(range(QK_K // 16)), dtype=np.uint16).reshape((1, -1, 1))
        scales_h = scales_h.reshape((n_blocks, -1, 1)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4))
        scales_l = scales_l.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
        extra = extra.reshape((n_blocks, -1)).astype(np.uint8) & np.uint8(0x01)
        scales_h = scales_h.reshape((n_blocks, -1)) & np.uint8(0x03)
        scales_l = scales_l.reshape((n_blocks, -1)) & np.uint8(0x0F)

        # 6-bit scale = low nibble | high 2 bits << 4, stored with an offset of 32.
        scales = (scales_l | (scales_h << np.uint8(4))).astype(np.int8) - np.int8(32)
        dl = (d * scales.astype(np.float32)).reshape((n_blocks, -1, 1))

        # Key point: the shift axis comes before the byte axis, so within every
        # 32-byte group all bytes are read at one shift before moving to the next.
        qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        qs = qs.reshape((n_blocks, -1, 256, 1)) & np.uint8(0x0F)
        qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array(list(range(8)), dtype=np.uint8).reshape((1, 1, 8, 1))
        qh = qh.reshape((n_blocks, -1, 256, 1)) & np.uint8(0x01)
        # 5-bit index = low 4 bits from qs | fifth bit from qh.
        qs = (qs | (qh << np.uint8(4))).astype(np.uint8)

        # Map indices through the value table, then regroup into sub-blocks of 16 values.
        kvalues = np.array(cls.kvalues, dtype=np.int8).reshape((1, 1, 1, -1))
        qs = np.take_along_axis(kvalues, qs, axis=-1).astype(np.float32).reshape((n_blocks, -1, 16))
        # Sub-blocks whose extra bit is set have every value shifted by +2.
        qs = qs + 2 * extra.reshape(n_blocks, -1, 1)

        return (dl * qs).reshape((n_blocks, -1))

# Quant

In [None]:
def find_best_neighbour(neighbours, grid, scale, xval, weight):
    """Pick the candidate grid point with the smallest weighted squared error.

    Args:
        neighbours: non-empty sequence of candidate indices into `grid`.
        grid: array whose rows are grid points.
        scale: factor applied to a grid point before comparing it to `xval`.
        xval: target values.
        weight: per-element error weights.

    Returns:
        Tuple `(grid_index, levels)` where `levels` is `(grid[grid_index] - 1) // 2`.
    """
    assert np.size(neighbours) > 0
    errors = []
    for candidate in neighbours:
        residual = scale * grid[candidate] - xval
        errors.append(np.sum(weight * residual**2))
    best = neighbours[np.argmin(errors)]
    assert best >= 0, "No valid grid index found"
    return best, (grid[best] - 1)//2

In [None]:
def iq1_find_best_neighbour(neighbours, grid, scale, xval, weight, xg):
    """Choose the best candidate grid point, comparing through the level table ``xg``.

    Each candidate ``n`` is first mapped to level indices ``(grid[n] - 1)//2``;
    those indices select values from ``xg``, and the error
    ``sum(weight * (scale * xg[...] - xval)**2)`` is minimised (first minimum
    wins on ties).  There is no brute-force fallback over the whole grid: a
    negative winning index is treated as an error.

    Args:
        neighbours: non-empty sequence of candidate indices into ``grid``.
        grid: array indexed by those candidates.
        scale: multiplier applied to the looked-up levels.
        xval: values being approximated.
        weight: per-element error weights.
        xg: NumPy array of level values, indexed by ``(grid[n] - 1)//2``.

    Returns:
        ``(grid_index, (grid[grid_index] - 1)//2)``.

    Raises:
        AssertionError: if ``neighbours`` is empty or the winner is negative.
    """
    assert np.size(neighbours) > 0

    def weighted_error(candidate):
        levels = xg[(grid[candidate] - 1)//2]
        return np.sum(weight * (scale * levels - xval)**2)

    errors = [weighted_error(candidate) for candidate in neighbours]
    best = neighbours[np.argmin(errors)]
    assert best >=0, "No valid grid index found"
    return best, (grid[best] - 1)//2

In [None]:
# Sample inputs for trying out the quantizers below (see the "Model file GGUF"
# section of this notebook for how such values are read).
# Taken from Llama-3.2-1B-Instruct-f16.gguf: first block of blk.0.attn_k.weight.
# The names match the (xb, qw, sigma2) parameters of iq1_s_quant, which
# combines them as qw * sqrt(sigma2 + xb**2).
# sigma2: scalar added to xb**2 in that weighting.
sigma2 = 0.00406796578
# xb: the 32 weight values of the block.
xb = np.array([
    0.0581054688, 0.117675781, 0.0615234375, -0.0395507812, -0.0522460938, -0.10546875, -0.06640625, -0.0693359375,
    0.0246582031, -0.0502929688, 0.00503540039, -0.0181884766, 0.012878418, 0.0458984375, 0.0310058594, -0.044921875,
    0.00939941406, -0.000237464905, 0.145507812, -0.0366210938, -0.0554199219, 0.134765625, -0.0476074219, -0.0244140625,
    0.00415039062, 0.0250244141, -0.0712890625, 0.0522460938, 0.0339355469, 0.078125, 0.048828125, -0.0270996094
])
# qw: 32 per-element weights for the same block
# (presumably importance-matrix values -- TODO confirm the source).
qw = np.array([
    0.0176409613, 0.0372294672, 0.108566001, 0.000205380886, 0.0152213955, 0.0277738478, 0.0150570357, 0.020145515,
    0.0244456567, 0.0230933968, 0.0182256084, 0.0103802737, 0.023744056, 0.0159365162, 0.0502867512, 0.0228893068,
    1.04835891e-07, 0.044989869, 0.501476467, 0.0156725124, 0.0488786846, 0.0643265024, 0.0239071101, 0.019364614,
    0.194378719, 0.0222598892, 0.0501582734, 0.0344416499, 0.030363027, 0.209668413, 0.362394035, 0.0164955128
])

## IQ1/IQ2

In [None]:
kgrid_1bit_2048 = [
            0,     2,     5,     8,    10,    17,    21,    32,    34,    40,    42,    69,    81,    84,    86,   101,
          128,   130,   136,   138,   149,   160,   162,   168,   170,   260,   261,   273,   276,   278,   281,   282,
          293,   321,   326,   329,   338,   341,   346,   353,   356,   358,   360,   389,   401,   404,   406,   421,
          512,   514,   520,   522,   533,   544,   546,   552,   554,   581,   593,   601,   612,   617,   640,   642,
          648,   650,   657,   661,   665,   672,   674,   680,   682,  1041,  1044,  1046,  1061,  1089,  1097,  1109,
         1114,  1124,  1125,  1169,  1177,  1189,  1281,  1284,  1285,  1286,  1301,  1304,  1306,  1321,  1344,  1349,
         1354,  1360,  1361,  1364,  1365,  1366,  1369,  1376,  1378,  1381,  1384,  1386,  1409,  1425,  1429,  1432,
         1434,  1441,  1444,  1445,  1446,  1449,  1556,  1561,  1601,  1604,  1616,  1618,  1621,  1624,  1632,  1633,
         1638,  1641,  1669,  1681,  1684,  1689,  2048,  2050,  2056,  2058,  2069,  2080,  2082,  2088,  2090,  2117,
         2129,  2134,  2149,  2176,  2178,  2184,  2186,  2197,  2208,  2210,  2216,  2218,  2309,  2321,  2324,  2329,
         2340,  2341,  2369,  2384,  2385,  2389,  2401,  2404,  2409,  2449,  2452,  2454,  2457,  2469,  2560,  2562,
         2568,  2570,  2581,  2592,  2594,  2600,  2602,  2629,  2641,  2649,  2657,  2661,  2688,  2690,  2693,  2696,
         2698,  2709,  2720,  2722,  2728,  2730,  4112,  4113,  4116,  4121,  4132,  4133,  4161,  4164,  4176,  4181,
         4184,  4193,  4196,  4197,  4201,  4241,  4244,  4246,  4257,  4261,  4353,  4356,  4358,  4361,  4368,  4370,
         4373,  4376,  4385,  4388,  4393,  4421,  4426,  4432,  4433,  4434,  4436,  4437,  4438,  4441,  4448,  4453,
         4484,  4498,  4501,  4513,  4516,  4625,  4628,  4630,  4645,  4672,  4678,  4681,  4690,  4693,  4696,  4698,
         4708,  4710,  4741,  4753,  4756,  4758,  4773,  5121,  5126,  5129,  5140,  5141,  5144,  5145,  5153,  5158,
         5185,  5189,  5190,  5192,  5194,  5201,  5204,  5205,  5206,  5209,  5218,  5221,  5224,  5252,  5257,  5264,
         5268,  5269,  5272,  5273,  5274,  5281,  5284,  5285,  5289,  5378,  5381,  5386,  5393,  5396,  5397,  5398,
         5401,  5408,  5410,  5413,  5416,  5418,  5441,  5444,  5445,  5446,  5457,  5458,  5460,  5461,  5462,  5465,
         5466,  5473,  5476,  5477,  5478,  5481,  5504,  5506,  5508,  5509,  5512,  5514,  5520,  5521,  5524,  5525,
         5526,  5529,  5530,  5536,  5538,  5541,  5633,  5636,  5637,  5638,  5653,  5654,  5656,  5658,  5665,  5670,
         5696,  5698,  5700,  5701,  5704,  5706,  5713,  5717,  5718,  5720,  5721,  5729,  5732,  5733,  5736,  5737,
         5738,  5766,  5770,  5778,  5781,  5796,  5801,  6161,  6166,  6181,  6209,  6212,  6214,  6217,  6224,  6229,
         6232,  6234,  6240,  6241,  6244,  6246,  6249,  6277,  6289,  6292,  6309,  6416,  6418,  6421,  6426,  6433,
         6437,  6466,  6468,  6469,  6472,  6481,  6484,  6485,  6486,  6489,  6490,  6496,  6501,  6506,  6537,  6545,
         6546,  6549,  6552,  6561,  6566,  6569,  6665,  6678,  6692,  6694,  6724,  6726,  6729,  6736,  6738,  6741,
         6744,  6753,  6758,  6761,  6789,  6801,  6806,  6810,  8192,  8194,  8200,  8202,  8213,  8224,  8226,  8229,
         8232,  8234,  8261,  8273,  8281,  8289,  8293,  8320,  8322,  8328,  8330,  8341,  8352,  8354,  8357,  8360,
         8362,  8453,  8465,  8468,  8473,  8485,  8514,  8516,  8521,  8533,  8536,  8538,  8545,  8548,  8549,  8550,
         8581,  8592,  8598,  8601,  8613,  8705,  8712,  8714,  8721,  8725,  8736,  8738,  8744,  8746,  8773,  8785,
         8790,  8793,  8805,  8833,  8840,  8842,  8849,  8853,  8864,  8866,  8872,  8874,  9221,  9236,  9238,  9241,
         9253,  9284,  9285,  9286,  9289,  9298,  9301,  9304,  9306,  9318,  9349,  9361,  9364,  9369,  9377,  9381,
         9481,  9493,  9505,  9513,  9536,  9541,  9544,  9553,  9556,  9557,  9561,  9570,  9573,  9576,  9609,  9616,
         9620,  9621,  9624,  9626,  9633,  9636,  9638,  9641,  9733,  9744,  9746,  9753,  9765,  9793,  9801,  9813,
         9824,  9825,  9833,  9860,  9862,  9872,  9882, 10240, 10242, 10248, 10250, 10261, 10272, 10274, 10280, 10282,
        10309, 10321, 10324, 10341, 10368, 10370, 10376, 10378, 10400, 10402, 10408, 10410, 10505, 10513, 10516, 10521,
        10533, 10566, 10569, 10578, 10581, 10593, 10596, 10598, 10601, 10629, 10640, 10646, 10649, 10660, 10661, 10752,
        10754, 10760, 10762, 10784, 10786, 10792, 10794, 10821, 10833, 10838, 10841, 10853, 10880, 10882, 10888, 10890,
        10901, 10912, 10914, 10920, 10922, 16389, 16401, 16406, 16421, 16457, 16466, 16469, 16472, 16474, 16481, 16484,
        16486, 16532, 16537, 16545, 16550, 16640, 16641, 16644, 16646, 16649, 16658, 16661, 16662, 16664, 16666, 16673,
        16678, 16681, 16709, 16712, 16714, 16721, 16724, 16725, 16726, 16729, 16730, 16741, 16744, 16746, 16769, 16772,
        16774, 16784, 16786, 16789, 16800, 16801, 16802, 16901, 16913, 16916, 16918, 16933, 16961, 16978, 16981, 16986,
        16996, 17001, 17033, 17044, 17061, 17409, 17429, 17433, 17449, 17477, 17480, 17482, 17489, 17492, 17493, 17494,
        17505, 17506, 17509, 17512, 17514, 17537, 17542, 17545, 17552, 17554, 17557, 17568, 17569, 17577, 17665, 17666,
        17669, 17674, 17681, 17684, 17685, 17686, 17689, 17696, 17701, 17706, 17729, 17732, 17733, 17734, 17737, 17744,
        17745, 17748, 17749, 17750, 17752, 17753, 17761, 17764, 17765, 17766, 17769, 17794, 17796, 17797, 17800, 17809,
        17812, 17813, 17814, 17817, 17818, 17829, 17832, 17834, 17921, 17925, 17929, 17940, 17941, 17944, 17946, 17953,
        17956, 17961, 17984, 17986, 17989, 17992, 18000, 18001, 18002, 18005, 18006, 18009, 18018, 18021, 18024, 18049,
        18053, 18058, 18068, 18069, 18081, 18084, 18086, 18437, 18449, 18453, 18458, 18469, 18498, 18505, 18512, 18517,
        18520, 18529, 18532, 18534, 18537, 18565, 18577, 18580, 18582, 18585, 18597, 18689, 18693, 18694, 18698, 18704,
        18708, 18709, 18712, 18721, 18724, 18726, 18752, 18757, 18762, 18769, 18770, 18772, 18773, 18774, 18777, 18784,
        18786, 18789, 18790, 18794, 18822, 18825, 18834, 18837, 18838, 18840, 18849, 18852, 18854, 18857, 18966, 19012,
        19014, 19017, 19029, 19032, 19034, 19044, 19049, 19092, 19109, 20481, 20484, 20485, 20486, 20489, 20498, 20501,
        20506, 20513, 20516, 20521, 20544, 20549, 20552, 20561, 20564, 20565, 20566, 20569, 20581, 20584, 20614, 20617,
        20629, 20632, 20640, 20641, 20646, 20649, 20741, 20744, 20745, 20746, 20753, 20756, 20757, 20758, 20760, 20761,
        20768, 20773, 20774, 20776, 20778, 20801, 20804, 20805, 20806, 20809, 20816, 20817, 20818, 20820, 20821, 20822,
        20824, 20825, 20826, 20833, 20836, 20837, 20838, 20841, 20866, 20869, 20881, 20884, 20885, 20886, 20889, 20896,
        20901, 20906, 20993, 20998, 21010, 21013, 21018, 21025, 21028, 21058, 21061, 21066, 21073, 21076, 21077, 21078,
        21081, 21090, 21093, 21125, 21136, 21138, 21141, 21145, 21146, 21156, 21508, 21509, 21521, 21524, 21525, 21526,
        21528, 21529, 21537, 21541, 21544, 21546, 21569, 21572, 21573, 21574, 21577, 21578, 21584, 21585, 21588, 21589,
        21590, 21592, 21593, 21594, 21601, 21602, 21604, 21605, 21606, 21609, 21632, 21640, 21642, 21649, 21652, 21653,
        21654, 21657, 21665, 21668, 21669, 21674, 21761, 21762, 21764, 21765, 21766, 21769, 21776, 21777, 21778, 21780,
        21781, 21782, 21785, 21786, 21793, 21796, 21797, 21798, 21801, 21824, 21825, 21826, 21828, 21829, 21830, 21832,
        21833, 21840, 21841, 21842, 21844, 21845, 21846, 21848, 21849, 21850, 21856, 21857, 21860, 21861, 21862, 21864,
        21865, 21866, 21889, 21892, 21893, 21897, 21898, 21904, 21905, 21908, 21909, 21910, 21912, 21913, 21921, 21924,
        21925, 21926, 21929, 22016, 22017, 22018, 22020, 22022, 22024, 22025, 22033, 22036, 22037, 22040, 22041, 22048,
        22049, 22050, 22052, 22053, 22054, 22056, 22057, 22081, 22085, 22086, 22088, 22089, 22090, 22096, 22097, 22098,
        22100, 22101, 22102, 22104, 22105, 22106, 22113, 22116, 22117, 22121, 22146, 22149, 22150, 22152, 22153, 22154,
        22161, 22165, 22170, 22178, 22181, 22182, 22184, 22185, 22532, 22533, 22534, 22537, 22544, 22549, 22552, 22561,
        22570, 22597, 22600, 22602, 22609, 22612, 22613, 22614, 22616, 22617, 22624, 22626, 22628, 22629, 22658, 22665,
        22672, 22674, 22677, 22680, 22689, 22697, 22785, 22786, 22789, 22794, 22801, 22804, 22805, 22806, 22809, 22821,
        22849, 22852, 22853, 22854, 22857, 22864, 22865, 22866, 22868, 22869, 22870, 22872, 22873, 22874, 22881, 22884,
        22885, 22886, 22889, 22913, 22917, 22921, 22929, 22932, 22933, 22934, 22936, 22937, 22949, 23044, 23048, 23061,
        23066, 23072, 23077, 23078, 23081, 23109, 23112, 23113, 23121, 23125, 23126, 23128, 23129, 23138, 23141, 23144,
        23146, 23169, 23178, 23186, 23189, 23190, 23192, 23194, 23201, 24581, 24596, 24598, 24601, 24613, 24644, 24656,
        24661, 24662, 24664, 24666, 24673, 24676, 24678, 24681, 24705, 24726, 24741, 24833, 24836, 24838, 24841, 24850,
        24853, 24865, 24866, 24870, 24873, 24901, 24905, 24913, 24917, 24918, 24921, 24933, 24934, 24938, 24964, 24970,
        24978, 24981, 24993, 24998, 25001, 25105, 25110, 25113, 25152, 25153, 25158, 25173, 25174, 25176, 25184, 25221,
        25233, 25238, 25253, 25617, 25618, 25621, 25622, 25626, 25633, 25638, 25641, 25664, 25666, 25669, 25672, 25674,
        25681, 25684, 25685, 25686, 25689, 25690, 25696, 25698, 25701, 25732, 25733, 25737, 25744, 25746, 25748, 25749,
        25750, 25752, 25754, 25761, 25764, 25769, 25861, 25864, 25866, 25873, 25877, 25878, 25881, 25924, 25925, 25926,
        25929, 25936, 25937, 25940, 25941, 25942, 25945, 25953, 25956, 25957, 25958, 25961, 25990, 25993, 25994, 26001,
        26005, 26006, 26009, 26010, 26018, 26021, 26022, 26024, 26114, 26121, 26133, 26144, 26150, 26152, 26153, 26176,
        26181, 26184, 26186, 26193, 26196, 26197, 26198, 26200, 26202, 26208, 26213, 26216, 26240, 26242, 26245, 26250,
        26260, 26262, 26264, 26265, 26272, 26276, 26278, 26282, 26646, 26649, 26661, 26689, 26706, 26709, 26714, 26721,
        26729, 26757, 26769, 26776, 26790, 26881, 26884, 26896, 26901, 26913, 26916, 26918, 26921, 26944, 26945, 26949,
        26950, 26952, 26961, 26964, 26965, 26966, 26969, 26976, 26981, 26986, 27010, 27012, 27018, 27029, 27041, 27044,
        27045, 27049, 27153, 27158, 27160, 27201, 27204, 27209, 27216, 27221, 27224, 27226, 27236, 27237, 27241, 27270,
        27284, 27288, 27290, 27302, 32768, 32770, 32776, 32778, 32800, 32802, 32808, 32810, 32837, 32848, 32849, 32852,
        32854, 32857, 32869, 32896, 32898, 32904, 32906, 32917, 32928, 32930, 32936, 32938, 33029, 33041, 33044, 33046,
        33049, 33061, 33089, 33092, 33097, 33104, 33106, 33109, 33110, 33112, 33113, 33124, 33126, 33129, 33157, 33161,
        33172, 33174, 33177, 33189, 33280, 33282, 33288, 33290, 33301, 33312, 33314, 33320, 33322, 33361, 33364, 33369,
        33381, 33408, 33410, 33416, 33418, 33429, 33440, 33442, 33448, 33450, 33812, 33817, 33857, 33860, 33873, 33877,
        33882, 33889, 33892, 33897, 33940, 33945, 34049, 34057, 34066, 34069, 34074, 34086, 34089, 34112, 34113, 34117,
        34120, 34129, 34132, 34133, 34134, 34137, 34138, 34149, 34150, 34152, 34154, 34177, 34180, 34182, 34185, 34192,
        34194, 34197, 34200, 34214, 34321, 34326, 34329, 34341, 34369, 34372, 34377, 34378, 34384, 34389, 34393, 34394,
        34401, 34406, 34410, 34437, 34449, 34458, 34468, 34816, 34818, 34824, 34826, 34837, 34848, 34850, 34856, 34858,
        34881, 34885, 34897, 34900, 34905, 34917, 34921, 34944, 34946, 34952, 34954, 34965, 34976, 34978, 34984, 34986,
        35077, 35078, 35089, 35092, 35094, 35109, 35137, 35140, 35142, 35145, 35152, 35154, 35157, 35162, 35169, 35172,
        35205, 35222, 35225, 35237, 35328, 35330, 35336, 35338, 35349, 35360, 35362, 35368, 35370, 35397, 35409, 35412,
        35414, 35456, 35458, 35464, 35466, 35477, 35488, 35490, 35496, 35498, 36869, 36881, 36886, 36888, 36889, 36901,
        36929, 36934, 36937, 36949, 36952, 36954, 36969, 36970, 36997, 37009, 37012, 37014, 37017, 37029, 37121, 37124,
        37126, 37129, 37136, 37141, 37144, 37146, 37153, 37156, 37158, 37161, 37184, 37189, 37200, 37201, 37204, 37205,
        37206, 37209, 37218, 37221, 37252, 37254, 37266, 37269, 37272, 37281, 37284, 37286, 37289, 37381, 37393, 37396,
        37401, 37413, 37444, 37446, 37449, 37456, 37458, 37461, 37464, 37478, 37481, 37509, 37524, 37526, 37545, 37889,
        37892, 37894, 37904, 37909, 37912, 37926, 37952, 37962, 37969, 37972, 37973, 37974, 37976, 37977, 37984, 37985,
        37986, 37989, 38020, 38022, 38034, 38036, 38037, 38040, 38049, 38057, 38144, 38149, 38152, 38154, 38160, 38161,
        38164, 38165, 38166, 38169, 38177, 38181, 38185, 38186, 38209, 38212, 38213, 38214, 38217, 38224, 38225, 38226,
        38228, 38229, 38230, 38232, 38233, 38234, 38241, 38244, 38245, 38246, 38249, 38273, 38277, 38280, 38289, 38290,
        38292, 38293, 38294, 38297, 38298, 38304, 38306, 38309, 38312, 38314, 38401, 38404, 38416, 38421, 38425, 38432,
        38438, 38441, 38469, 38472, 38473, 38481, 38482, 38485, 38486, 38489, 38501, 38504, 38530, 38532, 38537, 38538,
        38546, 38548, 38549, 38564, 38566, 38569, 38917, 38934, 38937, 38949, 38977, 38982, 38992, 38994, 38997, 38998,
        39002, 39012, 39013, 39045, 39057, 39062, 39065, 39077, 39172, 39174, 39177, 39184, 39186, 39189, 39192, 39194,
        39200, 39201, 39204, 39206, 39232, 39234, 39237, 39240, 39242, 39249, 39252, 39253, 39254, 39257, 39266, 39269,
        39270, 39274, 39297, 39300, 39312, 39314, 39317, 39322, 39329, 39334, 39429, 39445, 39461, 39492, 39494, 39497,
        39504, 39509, 39512, 39521, 39557, 39569, 39572, 39573, 39574, 40960, 40962, 40968, 40970, 40981, 40992, 40994,
        41000, 41002, 41029, 41041, 41044, 41046, 41049, 41088, 41090, 41096, 41098, 41109, 41120, 41122, 41128, 41130,
        41221, 41225, 41233, 41236, 41238, 41241, 41242, 41286, 41289, 41297, 41301, 41304, 41306, 41313, 41316, 41349,
        41360, 41362, 41366, 41369, 41474, 41480, 41482, 41488, 41497, 41506, 41512, 41514, 41541, 41553, 41558, 41561,
        41573, 41600, 41602, 41608, 41610, 41621, 41632, 41634, 41640, 41642, 42009, 42021, 42049, 42052, 42064, 42068,
        42069, 42072, 42074, 42081, 42085, 42086, 42088, 42089, 42117, 42246, 42249, 42256, 42258, 42261, 42264, 42278,
        42281, 42306, 42309, 42321, 42324, 42325, 42326, 42329, 42341, 42346, 42369, 42372, 42373, 42374, 42377, 42386,
        42389, 42392, 42501, 42513, 42518, 42522, 42529, 42533, 42564, 42566, 42570, 42578, 42581, 42582, 42584, 42592,
        42594, 42630, 42640, 42645, 42646, 42649, 42657, 42660, 42662, 43008, 43010, 43016, 43018, 43040, 43042, 43048,
        43050, 43089, 43092, 43094, 43097, 43136, 43138, 43144, 43146, 43157, 43168, 43170, 43176, 43178, 43269, 43284,
        43289, 43297, 43301, 43329, 43344, 43349, 43354, 43361, 43366, 43369, 43408, 43414, 43520, 43522, 43528, 43530,
        43552, 43554, 43560, 43562, 43601, 43604, 43606, 43648, 43650, 43656, 43658, 43669, 43680, 43682, 43688, 43690,
]
# Select the 2048-entry 1-bit grid for the lookup-table cell further down:
# `grid_size` is what its `match grid_size` statement dispatches on, and
# `nwant` is the number of distinct distance levels kept when neighbour lists
# are built.  Both names are re-assigned by every grid cell, so the grid cell
# executed last decides which table is used.
grid_size = 2048
nwant = 3

In [None]:
kgrid_2bit_256 = [
            0,     2,     5,     8,    10,    17,    20,    32,    34,    40,    42,    65,    68,    80,    88,    97,
          100,   128,   130,   138,   162,   257,   260,   272,   277,   320,   388,   408,   512,   514,   546,   642,
         1025,  1028,  1040,  1057,  1060,  1088,  1090,  1096,  1120,  1153,  1156,  1168,  1188,  1280,  1282,  1288,
         1312,  1350,  1385,  1408,  1425,  1545,  1552,  1600,  1668,  1700,  2048,  2053,  2056,  2068,  2088,  2113,
         2116,  2128,  2130,  2184,  2308,  2368,  2562,  2580,  4097,  4100,  4112,  4129,  4160,  4192,  4228,  4240,
         4245,  4352,  4360,  4384,  4432,  4442,  4480,  4644,  4677,  5120,  5128,  5152,  5157,  5193,  5248,  5400,
         5474,  5632,  5654,  6145,  6148,  6160,  6208,  6273,  6400,  6405,  6560,  6737,  8192,  8194,  8202,  8260,
         8289,  8320,  8322,  8489,  8520,  8704,  8706,  9217,  9220,  9232,  9280,  9302,  9472,  9537,  9572,  9872,
        10248, 10272, 10388, 10820, 16385, 16388, 16400, 16408, 16417, 16420, 16448, 16456, 16470, 16480, 16513, 16516,
        16528, 16640, 16672, 16737, 16768, 16773, 16897, 16912, 16968, 16982, 17000, 17408, 17416, 17440, 17536, 17561,
        17682, 17700, 17920, 18433, 18436, 18448, 18496, 18501, 18688, 18776, 18785, 18818, 19013, 19088, 20480, 20488,
        20497, 20505, 20512, 20608, 20616, 20740, 20802, 20900, 21137, 21648, 21650, 21770, 22017, 22100, 22528, 22545,
        22553, 22628, 22848, 23048, 24580, 24592, 24640, 24680, 24832, 24917, 25112, 25184, 25600, 25605, 25872, 25874,
        25988, 26690, 32768, 32770, 32778, 32833, 32898, 33028, 33048, 33088, 33297, 33793, 33796, 33808, 33813, 33856,
        33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
        37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
]
# Select the 256-entry 2-bit grid for the lookup-table cell further down:
# `grid_size` is what its `match grid_size` statement dispatches on, and
# `nwant` is the number of distinct distance levels kept when neighbour lists
# are built.  Both names are re-assigned by every grid cell, so the grid cell
# executed last decides which table is used.
grid_size = 256
nwant = 2

In [None]:
kgrid_2bit_512 = [
            0,     2,     5,     8,    10,    17,    20,    22,    25,    32,    34,    37,    40,    65,    68,    70,
           73,    80,    82,    85,    88,    97,   100,   128,   130,   133,   136,   145,   148,   153,   160,   257,
          260,   262,   265,   272,   274,   277,   280,   282,   289,   292,   320,   322,   325,   328,   337,   340,
          352,   360,   385,   388,   400,   512,   514,   517,   520,   529,   532,   544,   577,   580,   592,   597,
          640,   650,  1025,  1028,  1030,  1033,  1040,  1042,  1045,  1048,  1057,  1060,  1088,  1090,  1093,  1096,
         1105,  1108,  1110,  1120,  1153,  1156,  1168,  1280,  1282,  1285,  1288,  1297,  1300,  1312,  1345,  1348,
         1360,  1377,  1408,  1537,  1540,  1552,  1574,  1600,  1602,  1668,  2048,  2050,  2053,  2056,  2058,  2065,
         2068,  2080,  2085,  2113,  2116,  2128,  2136,  2176,  2208,  2218,  2305,  2308,  2320,  2368,  2433,  2441,
         2560,  2592,  2600,  2710,  2720,  4097,  4100,  4102,  4105,  4112,  4114,  4117,  4120,  4129,  4132,  4160,
         4162,  4165,  4168,  4177,  4180,  4192,  4202,  4225,  4228,  4240,  4352,  4354,  4357,  4360,  4369,  4372,
         4384,  4417,  4420,  4432,  4480,  4500,  4502,  4609,  4612,  4614,  4624,  4672,  4704,  5120,  5122,  5125,
         5128,  5137,  5140,  5152,  5185,  5188,  5193,  5200,  5220,  5248,  5377,  5380,  5392,  5440,  5632,  5652,
         5705,  6145,  6148,  6160,  6162,  6208,  6228,  6278,  6400,  6405,  6502,  6737,  6825,  8192,  8194,  8197,
         8200,  8202,  8209,  8212,  8224,  8257,  8260,  8272,  8320,  8352,  8449,  8452,  8464,  8512,  8520,  8549,
         8704,  8738,  8832,  8872,  9217,  9220,  9232,  9257,  9280,  9472,  9537,  9554,  9625,  9729,  9754,  9894,
        10240, 10248, 10250, 10272, 10325, 10376, 10402, 10600, 10640, 10760, 10784, 10882, 10888, 10890, 16385, 16388,
        16390, 16393, 16400, 16402, 16405, 16408, 16417, 16420, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16480,
        16485, 16513, 16516, 16528, 16640, 16642, 16645, 16648, 16657, 16660, 16672, 16705, 16708, 16720, 16768, 16773,
        16802, 16897, 16900, 16912, 16914, 16937, 16960, 17408, 17410, 17413, 17416, 17425, 17428, 17433, 17440, 17473,
        17476, 17488, 17536, 17556, 17665, 17668, 17680, 17700, 17728, 17818, 17920, 17930, 17988, 18000, 18433, 18436,
        18448, 18496, 18501, 18516, 18530, 18688, 18705, 18756, 18768, 18793, 18948, 20480, 20482, 20485, 20488, 20497,
        20500, 20512, 20520, 20545, 20548, 20560, 20608, 20737, 20740, 20752, 20757, 20800, 20802, 20992, 21060, 21162,
        21505, 21508, 21520, 21537, 21568, 21600, 21633, 21665, 21760, 21768, 21888, 21896, 22049, 22120, 22177, 22528,
        22548, 22593, 22608, 22681, 22810, 22848, 22850, 23173, 24577, 24580, 24592, 24640, 24660, 24674, 24710, 24745,
        24832, 25124, 25162, 25234, 25600, 25622, 25872, 25920, 25925, 26020, 26625, 26730, 26917, 27142, 27220, 27234,
        32768, 32770, 32773, 32776, 32785, 32788, 32800, 32810, 32833, 32836, 32848, 32896, 32898, 32936, 32938, 33025,
        33028, 33030, 33040, 33088, 33105, 33113, 33280, 33312, 33408, 33410, 33440, 33448, 33793, 33796, 33808, 33810,
        33813, 33856, 33888, 33929, 34048, 34116, 34213, 34328, 34410, 34816, 34824, 34853, 34906, 34944, 34946, 34984,
        35078, 35362, 35456, 35464, 35478, 35496, 36865, 36868, 36880, 36928, 36950, 36996, 37120, 37154, 37220, 37462,
        37513, 37888, 37893, 37956, 37968, 37976, 38185, 38288, 38290, 38465, 38993, 39078, 39241, 39445, 39520, 40960,
        40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
        42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
]
# Select the 512-entry 2-bit grid for the lookup-table cell further down:
# `grid_size` is what its `match grid_size` statement dispatches on, and
# `nwant` is the number of distinct distance levels kept when neighbour lists
# are built.  Both names are re-assigned by every grid cell, so the grid cell
# executed last decides which table is used.
grid_size = 512
nwant = 2

In [None]:
kgrid_2bit_1024 = [
            0,     2,     5,     8,    10,    17,    20,    22,    25,    32,    34,    37,    40,    65,    68,    70,
           73,    80,    82,    85,    88,    97,   100,   102,   105,   128,   130,   133,   136,   145,   148,   160,
          165,   170,   257,   260,   262,   265,   272,   274,   277,   280,   289,   292,   320,   322,   325,   328,
          337,   340,   342,   345,   352,   357,   360,   385,   388,   400,   402,   405,   417,   420,   512,   514,
          517,   520,   529,   532,   544,   554,   577,   580,   582,   585,   592,   597,   640,   645,   650,   660,
          674,  1025,  1028,  1030,  1033,  1040,  1042,  1045,  1048,  1057,  1060,  1062,  1065,  1088,  1090,  1093,
         1096,  1098,  1105,  1108,  1110,  1113,  1120,  1122,  1125,  1153,  1156,  1158,  1161,  1168,  1173,  1176,
         1185,  1188,  1280,  1282,  1285,  1288,  1290,  1297,  1300,  1302,  1305,  1312,  1317,  1320,  1345,  1348,
         1350,  1353,  1360,  1362,  1365,  1368,  1377,  1380,  1408,  1410,  1413,  1416,  1425,  1428,  1440,  1537,
         1540,  1542,  1545,  1552,  1557,  1600,  1605,  1608,  1617,  1620,  1632,  1665,  1668,  1680,  2048,  2050,
         2053,  2056,  2065,  2068,  2070,  2073,  2080,  2085,  2090,  2113,  2116,  2118,  2121,  2128,  2130,  2133,
         2136,  2145,  2148,  2176,  2181,  2196,  2218,  2305,  2308,  2320,  2322,  2325,  2328,  2337,  2368,  2373,
         2376,  2385,  2388,  2400,  2433,  2448,  2560,  2577,  2580,  2594,  2600,  2602,  2640,  2713,  4097,  4100,
         4102,  4105,  4112,  4114,  4117,  4120,  4129,  4132,  4134,  4160,  4162,  4165,  4168,  4177,  4180,  4182,
         4185,  4192,  4194,  4197,  4200,  4225,  4228,  4230,  4240,  4245,  4248,  4257,  4260,  4352,  4354,  4357,
         4360,  4362,  4369,  4372,  4374,  4377,  4384,  4386,  4389,  4392,  4417,  4420,  4422,  4425,  4432,  4434,
         4437,  4440,  4449,  4452,  4480,  4482,  4485,  4488,  4497,  4500,  4609,  4612,  4617,  4624,  4629,  4641,
         4644,  4672,  4677,  4689,  4692,  4737,  4740,  4752,  5120,  5122,  5125,  5128,  5137,  5140,  5142,  5145,
         5152,  5157,  5160,  5185,  5188,  5190,  5193,  5200,  5202,  5205,  5208,  5217,  5220,  5248,  5250,  5253,
         5256,  5265,  5268,  5280,  5377,  5380,  5382,  5385,  5392,  5394,  5397,  5400,  5409,  5412,  5440,  5442,
         5445,  5448,  5457,  5460,  5472,  5505,  5508,  5520,  5632,  5637,  5640,  5649,  5652,  5664,  5697,  5700,
         5712,  5760,  5802,  6145,  6148,  6150,  6153,  6160,  6165,  6168,  6177,  6208,  6210,  6213,  6216,  6225,
         6228,  6240,  6273,  6276,  6400,  6402,  6405,  6408,  6417,  6420,  6432,  6465,  6468,  6480,  6505,  6562,
         6660,  6672,  6720,  6742,  8192,  8194,  8197,  8200,  8209,  8212,  8214,  8217,  8224,  8229,  8234,  8257,
         8260,  8272,  8274,  8277,  8292,  8320,  8330,  8340,  8362,  8449,  8452,  8464,  8466,  8469,  8481,  8512,
         8514,  8517,  8529,  8532,  8544,  8577,  8580,  8592,  8704,  8714,  8738,  8744,  8746,  8772,  8784,  8840,
         8842,  8872,  9217,  9220,  9222,  9225,  9232,  9237,  9240,  9249,  9252,  9280,  9282,  9285,  9288,  9297,
         9300,  9312,  9345,  9348,  9360,  9472,  9477,  9480,  9489,  9492,  9504,  9537,  9540,  9552,  9574,  9600,
         9729,  9732,  9744,  9792,  9817, 10240, 10245, 10257, 10260, 10305, 10308, 10320, 10378, 10410, 10497, 10500,
        10512, 10645, 10762, 10786, 10852, 10888, 10890, 16385, 16388, 16390, 16393, 16400, 16402, 16405, 16408, 16410,
        16417, 16420, 16422, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16470, 16473, 16480, 16482, 16485, 16513,
        16516, 16528, 16533, 16536, 16545, 16548, 16640, 16642, 16645, 16648, 16657, 16660, 16662, 16665, 16672, 16674,
        16677, 16705, 16708, 16710, 16713, 16720, 16722, 16725, 16728, 16737, 16740, 16768, 16770, 16773, 16776, 16785,
        16788, 16800, 16897, 16900, 16912, 16914, 16917, 16920, 16932, 16960, 16965, 16968, 16977, 16980, 16992, 17025,
        17028, 17408, 17410, 17413, 17416, 17418, 17425, 17428, 17430, 17433, 17440, 17442, 17445, 17448, 17473, 17476,
        17478, 17481, 17488, 17490, 17493, 17496, 17505, 17508, 17536, 17538, 17541, 17544, 17553, 17556, 17568, 17665,
        17668, 17670, 17673, 17680, 17682, 17685, 17688, 17697, 17700, 17728, 17730, 17733, 17736, 17745, 17748, 17760,
        17770, 17793, 17796, 17808, 17920, 17922, 17925, 17928, 17937, 17940, 17952, 17985, 17988, 18000, 18048, 18085,
        18433, 18436, 18441, 18448, 18450, 18453, 18456, 18465, 18468, 18496, 18498, 18501, 18504, 18513, 18516, 18528,
        18564, 18576, 18688, 18690, 18693, 18696, 18705, 18708, 18720, 18753, 18756, 18768, 18816, 18838, 18945, 18948,
        18960, 19008, 20480, 20482, 20485, 20488, 20497, 20500, 20502, 20505, 20512, 20514, 20517, 20520, 20545, 20548,
        20550, 20553, 20560, 20562, 20565, 20568, 20577, 20580, 20608, 20610, 20613, 20616, 20625, 20628, 20737, 20740,
        20742, 20745, 20752, 20754, 20757, 20760, 20769, 20772, 20800, 20802, 20805, 20808, 20817, 20820, 20832, 20865,
        20868, 20880, 20992, 20997, 21000, 21009, 21012, 21024, 21057, 21060, 21072, 21097, 21120, 21505, 21508, 21510,
        21513, 21520, 21522, 21525, 21528, 21537, 21540, 21568, 21570, 21573, 21576, 21585, 21588, 21600, 21633, 21636,
        21648, 21760, 21762, 21765, 21768, 21777, 21780, 21792, 21825, 21828, 21840, 21888, 22017, 22020, 22032, 22054,
        22080, 22528, 22530, 22533, 22536, 22545, 22548, 22560, 22593, 22596, 22608, 22618, 22656, 22785, 22788, 22800,
        22848, 23040, 23065, 23173, 23208, 24577, 24580, 24582, 24592, 24594, 24597, 24600, 24609, 24612, 24640, 24645,
        24648, 24657, 24660, 24672, 24708, 24720, 24832, 24834, 24837, 24840, 24849, 24852, 24864, 24897, 24900, 24912,
        24960, 24985, 25092, 25104, 25152, 25174, 25249, 25600, 25605, 25608, 25617, 25620, 25632, 25665, 25668, 25680,
        25728, 25857, 25860, 25872, 25920, 25930, 25960, 26002, 26112, 26260, 26625, 26628, 26640, 26725, 26776, 26880,
        26922, 27202, 27297, 32768, 32770, 32773, 32776, 32785, 32788, 32793, 32800, 32805, 32833, 32836, 32848, 32850,
        32853, 32856, 32865, 32896, 32901, 32913, 32916, 33025, 33028, 33033, 33040, 33042, 33045, 33048, 33057, 33060,
        33088, 33090, 33093, 33096, 33105, 33108, 33153, 33156, 33168, 33193, 33280, 33285, 33290, 33297, 33300, 33345,
        33348, 33360, 33793, 33796, 33798, 33801, 33808, 33810, 33813, 33816, 33825, 33856, 33858, 33861, 33864, 33873,
        33876, 33888, 33921, 33924, 33936, 34048, 34050, 34053, 34056, 34065, 34068, 34080, 34113, 34116, 34128, 34176,
        34186, 34305, 34308, 34320, 34345, 34368, 34816, 34821, 34833, 34836, 34881, 34884, 34896, 34978, 35073, 35076,
        35136, 35173, 35362, 35416, 35418, 35458, 35490, 36865, 36868, 36873, 36880, 36882, 36885, 36888, 36900, 36928,
        36930, 36933, 36936, 36945, 36948, 36960, 36993, 36996, 37008, 37120, 37125, 37137, 37140, 37185, 37188, 37200,
        37210, 37377, 37380, 37392, 37440, 37542, 37888, 37890, 37893, 37896, 37905, 37908, 37920, 37953, 37956, 37968,
        38016, 38038, 38145, 38148, 38160, 38208, 38296, 38305, 38400, 38470, 38500, 38913, 38916, 38928, 38950, 38976,
        39081, 39168, 39241, 39250, 39568, 40960, 40965, 40970, 40980, 40994, 41002, 41025, 41028, 41040, 41122, 41130,
        41280, 41317, 41474, 41482, 41506, 41512, 41514, 41602, 41608, 41610, 41640, 41985, 41988, 42000, 42048, 42121,
        42148, 42240, 42265, 42577, 43018, 43048, 43170, 43348, 43398, 43528, 43530, 43552, 43554, 43560, 43656, 43690,
]
# Select the 1024-entry 2-bit grid for the lookup-table cell further down:
# `grid_size` is what its `match grid_size` statement dispatches on, and
# `nwant` is the number of distinct distance levels kept when neighbour lists
# are built.  Both names are re-assigned by every grid cell, so the grid cell
# executed last decides which table is used.
grid_size = 1024
nwant = 1

In [None]:
kMaxQ = 3
GROUP_MAX_EPS = 1e-15
GROUP_MAX_EPS_IQ1_S = 1e-12
GROUP_MAX_EPS_IQ1_M = 1e-7
GROUP_MAX_EPS_IQ2_S = 1e-8

kmap_size = 43692
match grid_size:
    case 256:
        kgrid = kgrid_2bit_256
    case 512:
        kgrid = kgrid_2bit_512
    case 1024:
        kgrid = kgrid_2bit_1024
    case _:
        kgrid = kgrid_1bit_2048

kgrid_q2xs = 2 * ((np.array(kgrid).reshape((-1,1)) >> (2 * np.arange(8))) & 0x3) + 1
kmap_q2xs = np.full(kmap_size, -1)
# 在原来c++的代码逻辑里还要用the_grid变量倒来倒去的，这里直接就省了
for i, k in enumerate(kgrid):
    kmap_q2xs[k] = i

num_neighbors = 0
for i, k in enumerate(kmap_q2xs):
    if k < 0:
        pos = 2 * ((i >> (2 * np.arange(8))) & 0x3) + 1
        dist2 = np.column_stack((np.sum((kgrid_q2xs - pos)**2, axis=-1), np.arange(grid_size)))
        sorted_indices = np.lexsort((dist2[:, 1], dist2[:, 0]))
        dist2 = dist2[sorted_indices]
        nhave = 1
        d2 = dist2[0][0]
        for j,_ in dist2:
            if j > d2:
                if nhave == nwant:
                    break
                d2 = j
                nhave += 1
            num_neighbors += 1

counter = 0
kneighbors_q2xs = np.empty(num_neighbors + kmap_size - grid_size, dtype=np.uint16)
for i, k in enumerate(kmap_q2xs):
    if k < 0:
        pos = 2 * ((i >> (2 * np.arange(8))) & 0x3) + 1
        dist2 = np.column_stack((np.sum((kgrid_q2xs - pos)**2, axis=-1), np.arange(grid_size)))
        sorted_indices = np.lexsort((dist2[:, 1], dist2[:, 0]))
        dist2 = dist2[sorted_indices]
        kmap_q2xs[i] = -(counter + 1)
        nhave = 1
        start = counter
        counter += 1
        d2 = dist2[0][0]
        for j,ii in dist2:
            if j > d2:
                if nhave == nwant:
                    break
                d2 = j
                nhave += 1
            kneighbors_q2xs[counter] = ii
            counter += 1
        kneighbors_q2xs[start] = counter - 1 - start

### IQ1_S (from IQ1_S file)

In [None]:
# Offset applied to the three base levels (-1, 0, 1).
IQ1S_DELTA = 0.125
# Levels shifted up by the offset; iq1_s_quant records this choice as best_shift = 1.
x_p = np.array([-1 + IQ1S_DELTA,  IQ1S_DELTA, 1 + IQ1S_DELTA])
# Levels shifted down by the offset; iq1_s_quant records this choice as best_shift = -1.
x_m = np.array([-1 - IQ1S_DELTA, -IQ1S_DELTA, 1 - IQ1S_DELTA])

def iq1_s_quant(xb, qw, sigma2, Kmap, Kneighbors, Kgrid):
    """Quantize one block to three levels and snap each group of 8 onto the grid.

    The values are sorted and every split of the sorted sequence into three
    consecutive runs (level 0 / 1 / 2) is scored, once with the module-level
    value set `x_p` and once with `x_m`. The split and value set with the
    largest weighted `sumqx**2 / sumq2` win. Each run of 8 levels is packed
    into a 2-bit-per-value code and looked up in `Kmap`; codes that are not on
    the grid are replaced via `iq1_find_best_neighbour`.

    Args:
        xb: 1-D array of values; its size is the block size (groups are sliced
            as `8*k:8*k+8`, so a multiple of 8 is assumed).
        qw: per-value importance weights, or None to use `xb**2` as weight.
        sigma2: term added to `xb**2` under the square root when `qw` is given.
        Kmap: packed code -> grid index; a negative entry `-n` points at
            `Kneighbors[n:]` with the neighbour count stored at `Kneighbors[n-1]`.
        Kneighbors: flat neighbour table described above.
        Kgrid: grid rows; `(row - 1) // 2` is used as the level of each value.

    Returns:
        `(index, scale)` where `index` is a list with one grid index per group
        of 8 values and `scale` is the non-negative block scale. For a block
        whose largest magnitude is below `GROUP_MAX_EPS_IQ1_S` the result is
        all-zero indices with scale 0.0.
    """
    block_size = np.size(xb)
    n_groups = block_size // 8

    if qw is not None:
        weight = qw * np.sqrt(sigma2 + xb**2)
    else:
        weight = xb**2

    l_values = np.ones(block_size, dtype=np.int8)
    scale = np.max(np.abs(xb))
    if scale < GROUP_MAX_EPS_IQ1_S:
        # Fixed: this used to return (0.0, l_values), i.e. the two items in the
        # opposite order (and with a different meaning) than the normal path.
        return [0] * n_groups, 0.0

    # Prefix sums over the values in ascending order, so that the weighted sums
    # of any consecutive run can be read off as a difference.
    indices = np.argsort(xb)
    sumx = np.zeros(block_size + 1)
    sumw = np.zeros(block_size + 1)
    for j in range(block_size):
        i = indices[j]
        sumx[j+1] = sumx[j] + weight[i]*xb[i]
        sumw[j+1] = sumw[j] + weight[i]

    best_score = -np.inf
    besti1 = -1; besti2 = -1; best_shift = 0
    for i1 in range(block_size + 1):
        for i2 in range(i1, block_size + 1):
            # Try the up-shifted set first, then the down-shifted one.
            for shift, values in ((1, x_p), (-1, x_m)):
                sumqx = (sumx[i1] - sumx[0]) * values[0] + \
                        (sumx[i2] - sumx[i1]) * values[1] + \
                        (sumx[block_size] - sumx[i2]) * values[2]
                sumq2 = (sumw[i1] - sumw[0]) * values[0]**2 + \
                        (sumw[i2] - sumw[i1]) * values[1]**2 + \
                        (sumw[block_size] - sumw[i2]) * values[2]**2
                if sumq2 > 0 and sumqx**2 > best_score * sumq2:
                    scale = sumqx / sumq2
                    best_score = scale * sumqx
                    besti1 = i1; besti2 = i2; best_shift = shift

    assert besti1 >= 0 and besti2 >= 0 and best_shift != 0

    # Assign levels according to the winning split of the sorted order.
    for j in range(besti1):
        l_values[indices[j]] = 0
    for j in range(besti1, besti2):
        l_values[indices[j]] = 1
    for j in range(besti2, block_size):
        l_values[indices[j]] = 2

    if scale < 0:
        # Keep the scale non-negative: mirror the levels and swap the value set.
        l_values = 2 - l_values
        scale = -scale
        best_shift = - best_shift

    any_miss_grid = False
    xx = x_p if best_shift == 1 else x_m
    index = [Kmap[np.sum(l_values[8*k:8*k+8] << (2 * np.arange(8)))] for k in range(n_groups)]
    for k in range(n_groups):
        if index[k] < 0:
            any_miss_grid = True
            nb_index = -index[k]
            index[k], _ = iq1_find_best_neighbour(
                Kneighbors[nb_index : nb_index + Kneighbors[nb_index - 1]], Kgrid,
                scale, xb[8*k:8*k+8], weight[8*k:8*k+8], xx)
    if any_miss_grid:
        # Some group was moved onto a neighbouring grid point: refit the scale
        # against the values that are actually representable.
        q = [xx[(Kgrid[index[k]] - 1)//2] for k in range(n_groups)]
        sumqx = np.sum([np.sum(weight[8*k:8*k+8] * q[k] * xb[8*k:8*k+8]) for k in range(n_groups)])
        sumq2 = np.sum([np.sum(weight[8*k:8*k+8] * q[k]**2) for k in range(n_groups)])
        if sumqx > 0 and sumq2 > 0:
            scale = sumqx/sumq2

    return index, scale

In [None]:
# Run iq1_s_quant on the first IQ1S_BLOCK_SIZE values and print the result.
# `xb`, `qw` and `sigma2` are not defined in this cell and must come from an earlier one.
# The gdb commands below stop llama-quantize inside quantize_row_iq1_s_impl,
# presumably to capture reference values to compare against (TODO confirm).
# gdb llama-quantize
# r --imatrix Llama-3.2-1B-Instruct.imatrix Llama-3.2-1B-Instruct-f16.gguf Llama-3.2-1B-Instruct-IQ1_S.gguf IQ1_S 1
# b quantize_row_iq1_s_impl
# c
IQ1S_BLOCK_SIZE = 32
# Note: twice `sigma2` is passed here.
index, scale = iq1_s_quant(xb[:IQ1S_BLOCK_SIZE], qw[:IQ1S_BLOCK_SIZE], 2*sigma2, kmap_q2xs, kneighbors_q2xs, kgrid_q2xs)
print(index)
print(scale)

### IQ1_M (from IQ1_M file)

In [None]:
# Offset added to / subtracted from the three base levels (-1, 0, 1).
IQ1M_DELTA = 0.125
# Candidate value set with the levels shifted up by IQ1M_DELTA (indexed by level 0/1/2).
# Note: this rebinds the module-level name `x_p` that iq1_s_quant also reads.
x_p = np.array([-1 + IQ1M_DELTA,  IQ1M_DELTA, 1 + IQ1M_DELTA])
# Candidate value set with the levels shifted down by IQ1M_DELTA (indexed by level 0/1/2).
x_m = np.array([-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA])

def get_inc(w, x, p, m):
    """Per-element contributions to the weighted sums for two candidate values.

    Returns an array with one row per element and four columns:
    `w*x*p`, `w*x*m`, `w*p**2`, `w*m**2`.
    """
    weighted_x = w * x
    columns = (
        weighted_x * p,   # numerator term when the element takes value p
        weighted_x * m,   # numerator term when the element takes value m
        w * p**2,         # denominator term for value p
        w * m**2,         # denominator term for value m
    )
    return np.column_stack(columns)

def iq1_m_quant(xb, qw, sigma2, Kmap, Kneighbors, Kgrid):
    """Quantize one block to three levels with a per-half choice of value set.

    Like `iq1_s_quant`, but the first and second half of the block may use
    different value sets (`x_p` or `x_m`). All four combinations are scored
    for every split of the sorted values into three consecutive runs.

    Args:
        xb: 1-D array of values; its size is the block size (groups are sliced
            as `8*k:8*k+8`, so a multiple of 8 is assumed).
        qw: per-value importance weights, or None to use `xb**2` as weight.
        sigma2: term added to `xb**2` under the square root when `qw` is given.
        Kmap: packed code -> grid index; a negative entry `-n` points at
            `Kneighbors[n:]` with the neighbour count stored at `Kneighbors[n-1]`.
        Kneighbors: flat neighbour table described above.
        Kgrid: grid rows; `(row - 1) // 2` is used as the level of each value.

    Returns:
        `(index, scale)` where `index` is a list with one grid index per group
        of 8 values and `scale` is the non-negative block scale. For a block
        whose largest magnitude is below `GROUP_MAX_EPS_IQ1_M` the result is
        all-zero indices with scale 0.0.
    """
    block_size = np.size(xb)
    n_groups = block_size // 8

    if qw is not None:
        weight = qw * np.sqrt(sigma2 + xb**2)
    else:
        weight = xb**2

    l_values = np.ones(block_size, dtype=np.int8)
    scale = np.max(np.abs(xb))
    if scale < GROUP_MAX_EPS_IQ1_M:
        # Fixed: this used to return (0.0, l_values), i.e. the two items in the
        # opposite order (and with a different meaning) than the normal path.
        return [0] * n_groups, 0.0

    # Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
    # With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
    # boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
    # in ascending order and, for each candidate split, accumulate the weighted sums
    # needed to get the optimum scale and score of that split.
    indices = np.argsort(xb)
    best_score = -np.inf
    besti1 = -1; besti2 = -1; best_k = -1

    # inc_segments[level][i] holds element i's contributions when it is given `level`.
    inc_segments = [get_inc(weight, xb, x_p[i], x_m[i]) for i in range(3)]

    for i1 in range(block_size + 1):
        for i2 in range(i1, block_size + 1):
            # Value set used for (first half, second half) of the block:
            # 0: +, +
            # 1: +, -
            # 2: -, +
            # 3: -, -
            sumqx = np.zeros(4)
            sumq2 = np.zeros(4)

            segments = [
                (0, i1, 0),
                (i1, i2, 1),
                (i2, block_size, 2)
            ]
            for start, end, seg_idx in segments:
                for j in range(start, end):
                    i = indices[j]
                    inc = inc_segments[seg_idx][i]

                    sumqx[0] += inc[0]; sumqx[3] += inc[1]
                    sumq2[0] += inc[2]; sumq2[3] += inc[3]
                    if i < block_size // 2:
                        sumqx[1] += inc[0]; sumqx[2] += inc[1]
                        sumq2[1] += inc[2]; sumq2[2] += inc[3]
                    else:
                        sumqx[2] += inc[0]; sumqx[1] += inc[1]
                        sumq2[2] += inc[2]; sumq2[1] += inc[3]

            for k in range(4):
                if sumq2[k] > 0 and sumqx[k] * sumqx[k] > best_score * sumq2[k]:
                    scale = sumqx[k] / sumq2[k]
                    best_score = scale * sumqx[k]
                    besti1 = i1; besti2 = i2; best_k = k

    assert besti1 >= 0 and besti2 >= 0 and best_k >= 0

    # Assign levels according to the winning split of the sorted order.
    for j in range(besti1):
        l_values[indices[j]] = 0
    for j in range(besti1, besti2):
        l_values[indices[j]] = 1
    for j in range(besti2, block_size):
        l_values[indices[j]] = 2

    if scale < 0:
        # Keep the scale non-negative: mirror the levels and the +/- combination.
        l_values = 2 - l_values
        scale = -scale
        best_k = 3 - best_k

    def value_set(k):
        # Group 0 follows the first-half choice, every other group the second-half choice.
        return x_p if (k == 0 and best_k < 2) or (k != 0 and best_k%2 == 0) else x_m

    any_miss_grid = False
    index = [Kmap[np.sum(l_values[8*k:8*k+8] << (2 * np.arange(8)))] for k in range(n_groups)]
    for k in range(n_groups):
        if index[k] < 0:
            any_miss_grid = True
            nb_index = -index[k]
            index[k], _ = iq1_find_best_neighbour(
                Kneighbors[nb_index : nb_index + Kneighbors[nb_index - 1]], Kgrid,
                scale, xb[8*k:8*k+8], weight[8*k:8*k+8], value_set(k))
    if any_miss_grid:
        # Some group was moved onto a neighbouring grid point: refit the scale
        # against the values that are actually representable.
        q = [value_set(k)[(Kgrid[index[k]] - 1)//2] for k in range(n_groups)]
        sumqx_f = np.sum([np.sum(weight[8*k:8*k+8] * q[k] * xb[8*k:8*k+8]) for k in range(n_groups)])
        sumq2_f = np.sum([np.sum(weight[8*k:8*k+8] * q[k]**2) for k in range(n_groups)])
        if sumqx_f > 0 and sumq2_f > 0:
            scale = sumqx_f/sumq2_f

    return index, scale

In [None]:
# Run iq1_m_quant on the first IQ1M_BLOCK_SIZE values and print the result.
# `xb`, `qw` and `sigma2` are not defined in this cell and must come from an earlier one.
# The gdb commands below stop llama-quantize inside quantize_row_iq1_m_impl,
# presumably to capture reference values to compare against (TODO confirm).
# gdb llama-quantize
# r --imatrix Llama-3.2-1B-Instruct.imatrix Llama-3.2-1B-Instruct-f16.gguf Llama-3.2-1B-Instruct-IQ1_M.gguf IQ1_M 1
# b quantize_row_iq1_m_impl
# c
IQ1M_BLOCK_SIZE = 16
# Note: twice `sigma2` is passed here.
index, scale = iq1_m_quant(xb[:IQ1M_BLOCK_SIZE], qw[:IQ1M_BLOCK_SIZE], 2*sigma2, kmap_q2xs, kneighbors_q2xs, kgrid_q2xs)
print(index)
print(scale)

### IQ2_XXS (from IQ2_XXS file)

In [None]:
def make_qp_quants(nmax, x, quant_weights):
    """Fit a scale and integer levels in `[.., nmax]` to `x` by weighted least squares.

    A small search over inverse scales around `nmax / max(x)` picks the one
    with the lowest weighted squared error. Up to five refinement sweeps then
    change single levels whenever that increases `sumlx**2 / suml2`.

    Args:
        nmax: largest level that may be assigned.
        x: 1-D array of values to quantize.
        quant_weights: per-value weights, same length as `x`.

    Returns:
        `(scale, l_values)`: the fitted scale `sumlx / suml2` and the levels as
        a `uint8` array. When `max(x) == 0` the result is `(0.0, zeros)`.
    """
    # NOTE: the .astype(int) / uint8 conversions are deliberate; without them the
    # results do not match the output of the C++ reference code.
    max_val = np.max(x)
    if max_val == 0:
        # Fixed: a bare float used to be returned here although callers unpack
        # two values (scale, levels).
        return 0.0, np.zeros(np.size(x), dtype=np.uint8)
    iscale = nmax / max_val
    scale = 1 / iscale
    # TODO: a level of -1 becomes 255 here; intentional or accidental in the reference?
    # L[i] = nearest_int(iscale * x[i]);
    # Going through int first makes the wrap-around well defined; a direct
    # float -> uint8 cast of a negative number is platform dependent.
    l_values = np.round(iscale * x).astype(int).astype(np.uint8)
    best_mse = np.sum(quant_weights * (x - scale * l_values) ** 2)
    for is_val in range(-4, 5):
        if is_val != 0:
            iscale_is = (0.1 * is_val + nmax) / max_val
            scale_is = 1 / iscale_is
            l_values = np.minimum(np.round(iscale_is * x), nmax).astype(int)
            mse = np.sum(quant_weights * (x - scale_is * l_values) ** 2)
            if mse < best_mse:
                best_mse = mse
                iscale = iscale_is

    l_values = np.minimum(np.round(iscale * x), nmax).astype(int)
    # The sums are accumulated with signed levels, before the uint8 wrap below.
    sumlx = np.sum(quant_weights * x * l_values)
    suml2 = np.sum(quant_weights * l_values**2)
    l_values = l_values.astype(np.uint8)

    for _ in range(5):
        n_changed = 0
        for i, xx in enumerate(x):
            w = quant_weights[i]
            # Use a Python int: squaring a uint8 scalar wraps around for large
            # levels (e.g. 255) under NumPy 2 promotion rules.
            cur_l = int(l_values[i])
            slx = sumlx - w * xx * cur_l
            sl2 = suml2 - w * cur_l**2
            if slx > 0 and sl2 > 0:
                new_l = np.minimum(np.round(xx * sl2 / slx), nmax).astype(int)
                if new_l != cur_l:
                    slx += w * xx * new_l
                    sl2 += w * new_l * new_l

                    # Accept only if the objective sumlx**2 / suml2 improves.
                    if (slx**2 * suml2) > (sumlx**2 * sl2):
                        l_values[i] = new_l.astype(np.uint8)
                        sumlx = slx
                        suml2 = sl2
                        n_changed += 1
        if n_changed == 0:
            break

    return sumlx / suml2, l_values

In [None]:
def iq2_xxs_quant(xb, qw, sigma2, Kmap, Kneighbors, Kgrid):
    """Quantize one 32-value block (4 groups of 8) onto the IQ2 grid.

    Signs are split off into one bit mask per group. If a group has an odd
    number of negative values, the sign of the value with the smallest
    `weight * x**2` is flipped and only the low 7 bits of the mask are kept.
    Magnitudes are then quantized for a range of trial scales, off-grid groups
    are replaced through `find_best_neighbour`, and the scale with the best
    weighted fit is refined once more.

    Args:
        xb: 1-D array with 32 values.
        qw: per-value importance weights (required here).
        sigma2: term added to `xb**2` under the square root of the weight.
        Kmap: packed code -> grid index; a negative entry `-n` points at
            `Kneighbors[n:]` with the neighbour count stored at `Kneighbors[n-1]`.
        Kneighbors: flat neighbour table described above.
        Kgrid: grid rows holding odd values `2*level + 1`.

    Returns:
        `(grid_index, block_signs, scale)`: one grid index and one 7-bit sign
        mask per group, plus the non-negative block scale. For a block whose
        largest magnitude is below `GROUP_MAX_EPS` the result is all-zero
        indices, the sign masks, and scale 0.0.
    """
    weight = qw * np.sqrt(sigma2 + xb**2)
    waux = np.sqrt(weight)
    xval = np.abs(xb)
    block_signs = [np.sum((xb[8*k:8*k+8] < 0) << np.arange(8)) for k in range(4)]
    nflip = [np.sum(xb[8*k:8*k+8] < 0) for k in range(4)]
    for k in range(4):
        if nflip[k] % 2:
            # Odd number of negatives: flip the least important value.
            imin = np.argmin(weight[8*k:8*k+8] * xb[8*k:8*k+8]**2)
            xval[8*k+imin] = -xval[8*k+imin]
            block_signs[k] ^= (1 << imin)
        block_signs[k] = block_signs[k] & 127

    max_xval = np.max(xval)
    if max_xval < GROUP_MAX_EPS:
        # Fixed: a bare 0.0 used to be returned here although callers unpack
        # three values.
        return [0] * 4, block_signs, 0.0
    best = 0
    scale, best_l_values = make_qp_quants(kMaxQ+1, xval, weight)
    eff_max = scale*kMaxQ
    for grid_step in range(-6, 7):
        id_val = (2 * kMaxQ - 1 + grid_step * 0.1) / eff_max
        this_scale = 1 / id_val
        l_values = np.clip(np.round(0.5 * (id_val * xval - 1)), 0, kMaxQ - 1).astype(int)
        grid_index = [Kmap[np.sum(l_values[8*k:8*k+8] << (2 * np.arange(8)))] for k in range(4)]
        for k in range(4):
            if grid_index[k] < 0:
                nb_index = -grid_index[k]
                _, l_values[8*k:8*k+8] = find_best_neighbour(
                    Kneighbors[nb_index : nb_index + Kneighbors[nb_index - 1]], Kgrid,
                    this_scale, xval[8*k:8*k+8], waux[8*k:8*k+8])

        sumqx = np.sum(weight * xval * (2 * l_values + 1))
        sumq2 = np.sum(weight * (2 * l_values + 1)**2)

        if sumq2 > 0 and sumqx**2 > best * sumq2:
            scale = sumqx / sumq2
            best  = scale * sumqx
            best_l_values   = l_values.copy()

    if scale > 0:
        # Re-quantize every group with the winning scale and refit the scale.
        id_val = 1 / scale
        l_values = np.clip(np.round(0.5 * (id_val * xval - 1)), 0, kMaxQ - 1).astype(int)
        grid_index = [Kmap[np.sum(l_values[8*k:8*k+8] << (2 * np.arange(8)))] for k in range(4)]
        for k in range(4):
            if grid_index[k] < 0:
                nb_index = -grid_index[k]
                grid_index[k], _ = find_best_neighbour(
                    Kneighbors[nb_index : nb_index + Kneighbors[nb_index - 1]], Kgrid,
                    scale, xval[8*k:8*k+8], waux[8*k:8*k+8])
        best_l_values = np.concatenate([(Kgrid[grid_index[k]] - 1)//2 for k in range(4)])

        sumqx = np.sum(weight * xval * (2 * best_l_values + 1))
        sumq2 = np.sum(weight * (2 * best_l_values + 1)**2)

        if sumq2 > 0:
            scale = sumqx / sumq2

    if scale < 0:
        # Keep the scale non-negative by inverting all (7-bit) sign masks.
        scale = -scale
        block_signs = [(~b) & 127 for b in block_signs]

    grid_index = [Kmap[np.sum(best_l_values[8*k:8*k+8] << (2 * np.arange(8)))] for k in range(4)]
    assert np.all(np.array(grid_index) >= 0), "Oops: found point not on grid"

    return grid_index, block_signs, scale

In [None]:
# Run iq2_xxs_quant on the first 32 values and print the result.
# `xb`, `qw` and `sigma2` are not defined in this cell and must come from an earlier one.
# The gdb commands below stop llama-quantize inside quantize_row_iq2_xxs_impl,
# presumably to capture reference values to compare against (TODO confirm).
# gdb llama-quantize
# r --imatrix Llama-3.2-1B-Instruct.imatrix Llama-3.2-1B-Instruct-f16.gguf Llama-3.2-1B-Instruct-IQ2_XXS.gguf IQ2_XXS 1
# b quantize_row_iq2_xxs_impl
# c
grid_indexs, block_signs, scale = iq2_xxs_quant(xb[:32], qw[:32], sigma2, kmap_q2xs, kneighbors_q2xs, kgrid_q2xs)
print(grid_indexs)
print(block_signs)
print(scale)

### IQ2_XS (from IQ2_XS/IQ2_S file)

In [None]:
def iq2_xs_quant(xb, qw, sigma2, Kmap, Kneighbors, Kgrid):
    """Quantize one 16-value block (2 groups of 8) onto the IQ2 grid.

    Signs are split off into one bit mask per group. If a group has an odd
    number of negative values, the sign of the value with the smallest
    `weight * x**2` is flipped and only the low 7 bits of the mask are kept.
    Magnitudes are quantized for 19 trial scales around
    `max / (2*kMaxQ - 1)`; off-grid groups are replaced through
    `find_best_neighbour`, and the best weighted fit wins.

    Args:
        xb: 1-D array with 16 values.
        qw: per-value importance weights (required here).
        sigma2: term added to `xb**2` under the square root of the weight.
        Kmap: packed code -> grid index; a negative entry `-n` points at
            `Kneighbors[n:]` with the neighbour count stored at `Kneighbors[n-1]`.
        Kneighbors: flat neighbour table described above.
        Kgrid: grid rows holding odd values `2*level + 1`.

    Returns:
        `(grid_index, block_signs, scale)`: one grid index and one 7-bit sign
        mask per group, plus the non-negative block scale. For a block whose
        largest magnitude is below `GROUP_MAX_EPS` the result is all-zero
        indices, the sign masks, and scale 0.0.
    """
    weight = qw * np.sqrt(sigma2 + xb**2)
    waux = np.sqrt(weight)
    xval = np.abs(xb)
    block_signs = [np.sum((xb[8*k:8*k+8] < 0) << np.arange(8)) for k in range(2)]
    nflip = [np.sum(xb[8*k:8*k+8] < 0) for k in range(2)]
    for k in range(2):
        if nflip[k] % 2:
            # Odd number of negatives: flip the least important value.
            imin = np.argmin(weight[8*k:8*k+8] * xb[8*k:8*k+8]**2)
            xval[8*k+imin] = -xval[8*k+imin]
            block_signs[k] ^= (1 << imin)
        block_signs[k] = block_signs[k] & 127

    max_xval = np.max(xval)
    if max_xval < GROUP_MAX_EPS:
        # Fixed: a bare 0.0 used to be returned here although callers unpack
        # three values.
        return [0] * 2, block_signs, 0.0

    best = 0
    best_grid_index = [None for _ in range(2)]
    best_l_values = np.zeros(np.size(xb), dtype=int)

    scale = max_xval / (2 * kMaxQ - 1)
    for grid_step in range(-9, 10):
        id_val = (2 * kMaxQ - 1 + grid_step * 0.1) / max_xval
        this_scale = 1 / id_val
        l_values = np.clip(np.round(0.5 * (id_val * xval - 1)), 0, kMaxQ - 1).astype(int)
        grid_index = [Kmap[np.sum(l_values[8*k:8*k+8] << (2 * np.arange(8)))] for k in range(2)]
        for k in range(2):
            if grid_index[k] < 0:
                # grid_index[k] stays negative on purpose: it marks the group
                # as "was off grid" for the refinement step below.
                nb_index = -grid_index[k]
                _, l_values[8*k:8*k+8] = find_best_neighbour(
                    Kneighbors[nb_index : nb_index + Kneighbors[nb_index - 1]], Kgrid,
                    this_scale, xval[8*k:8*k+8], waux[8*k:8*k+8])

        sumqx = np.sum(weight * xval * (2 * l_values + 1))
        sumq2 = np.sum(weight * (2 * l_values + 1)**2)

        if sumq2 > 0 and sumqx**2 > best * sumq2:
            scale = sumqx / sumq2
            best  = scale * sumqx
            best_grid_index = grid_index.copy()
            best_l_values   = l_values.copy()

    # Fixed: the old truthiness check also failed for the valid grid index 0.
    assert best_grid_index[0] is not None, "no trial scale produced a usable fit"
    if scale > 0 and np.any(np.array(best_grid_index) < 0):
        # Re-quantize only the groups that were off grid, now with the winning scale.
        for k in range(2):
            if best_grid_index[k] < 0:
                l_values = np.clip(np.round(0.5 * (1 / scale * xval[8*k:8*k+8] - 1)), 0, kMaxQ - 1).astype(int)
                grid_index = Kmap[np.sum(l_values << (2 * np.arange(8)))]
                if grid_index < 0:
                    nb_index = -grid_index
                    _, best_l_values[8*k:8*k+8] = find_best_neighbour(
                        Kneighbors[nb_index : nb_index + Kneighbors[nb_index - 1]], Kgrid,
                        scale, xval[8*k:8*k+8], waux[8*k:8*k+8])
                else:
                    best_l_values[8*k:8*k+8] = l_values

        sumqx = np.sum(weight * xval * (2 * best_l_values + 1))
        sumq2 = np.sum(weight * (2 * best_l_values + 1)**2)

        if sumq2 > 0:
            scale = sumqx / sumq2

    if scale < 0:
        # Keep the scale non-negative by inverting all (7-bit) sign masks.
        scale = -scale
        block_signs = [(~b) & 127 for b in block_signs]

    grid_index = [Kmap[np.sum(best_l_values[8*k:8*k+8] << (2 * np.arange(8)))] for k in range(2)]
    assert np.all(np.array(grid_index) >= 0), "Oops: found point not on grid"

    return grid_index, block_signs, scale

In [None]:
# Run iq2_xs_quant on the first 16 values and print the result.
# `xb`, `qw` and `sigma2` are not defined in this cell and must come from an earlier one.
# The gdb commands below stop llama-quantize inside quantize_row_iq2_xs_impl,
# presumably to capture reference values to compare against (TODO confirm).
# gdb llama-quantize
# r --imatrix Llama-3.2-1B-Instruct.imatrix Llama-3.2-1B-Instruct-f16.gguf Llama-3.2-1B-Instruct-IQ2_XS.gguf IQ2_XS 1
# b quantize_row_iq2_xs_impl
# c
grid_indexs, block_signs, scale = iq2_xs_quant(xb[:16], qw[:16], sigma2, kmap_q2xs, kneighbors_q2xs, kgrid_q2xs)
print(grid_indexs)
print(block_signs)
print(scale)

### IQ2_S (from IQ2_M file)

In [None]:
def iq2_s_quant(xb, qw, sigma2, Kmap, Kneighbors, Kgrid):
    """Quantize one 16-value block (2 groups of 8) onto the IQ2 grid with full sign masks.

    Unlike `iq2_xs_quant` no sign is flipped: each group keeps an 8-bit sign
    mask. Magnitudes are quantized for 19 trial scales around
    `max / (2*kMaxQ - 1)`; off-grid groups are replaced through
    `find_best_neighbour`, and the best weighted fit wins.

    Args:
        xb: 1-D array with 16 values.
        qw: per-value importance weights, or None to use
            `0.25 * sigma2 + xb**2` as weight.
        sigma2: variance-like term used in the weight.
        Kmap: packed code -> grid index; a negative entry `-n` points at
            `Kneighbors[n:]` with the neighbour count stored at `Kneighbors[n-1]`.
        Kneighbors: flat neighbour table described above.
        Kgrid: grid rows holding odd values `2*level + 1`.

    Returns:
        `(grid_index, block_signs, scale)`: one grid index and one 8-bit sign
        mask per group, plus the non-negative block scale. For a block whose
        largest magnitude is below `GROUP_MAX_EPS_IQ2_S` the result is all-zero
        indices, the sign masks, and scale 0.0.
    """
    if qw is not None:
        weight = qw * np.sqrt(sigma2 + xb**2)
    else:
        weight = 0.25 * sigma2 + xb**2

    waux = np.sqrt(weight)
    xval = np.abs(xb)
    block_signs = [np.sum((xb[8*k:8*k+8] < 0) << np.arange(8)) for k in range(2)]
    max_xval = np.max(xval)
    if max_xval < GROUP_MAX_EPS_IQ2_S:
        # Fixed: a bare 0.0 used to be returned here although callers unpack
        # three values.
        return [0] * 2, block_signs, 0.0

    best = 0
    best_grid_index = [None for _ in range(2)]
    best_l_values = np.zeros(np.size(xb), dtype=int)

    scale = max_xval / (2 * kMaxQ - 1)
    for grid_step in range(-9, 10):
        id_val = (2 * kMaxQ - 1 + grid_step * 0.1) / max_xval
        this_scale = 1 / id_val
        l_values = np.clip(np.round(0.5 * (id_val * xval - 1)), 0, kMaxQ - 1).astype(int)
        grid_index = [Kmap[np.sum(l_values[8*k:8*k+8] << (2 * np.arange(8)))] for k in range(2)]
        for k in range(2):
            if grid_index[k] < 0:
                # grid_index[k] stays negative on purpose: it marks the group
                # as "was off grid" for the refinement step below.
                nb_index = -grid_index[k]
                _, l_values[8*k:8*k+8] = find_best_neighbour(
                    Kneighbors[nb_index : nb_index + Kneighbors[nb_index - 1]], Kgrid,
                    this_scale, xval[8*k:8*k+8], waux[8*k:8*k+8])

        sumqx = np.sum(weight * xval * (2 * l_values + 1))
        sumq2 = np.sum(weight * (2 * l_values + 1)**2)

        if sumq2 > 0 and sumqx**2 > best * sumq2:
            scale = sumqx / sumq2
            best  = scale * sumqx
            best_grid_index = grid_index.copy()
            best_l_values   = l_values.copy()

    # Fixed: the old truthiness check also failed for the valid grid index 0.
    assert best_grid_index[0] is not None, "no trial scale produced a usable fit"
    if scale > 0 and np.any(np.array(best_grid_index) < 0):
        # Re-quantize only the groups that were off grid, now with the winning scale.
        for k in range(2):
            if best_grid_index[k] < 0:
                l_values = np.clip(np.round(0.5 * (1 / scale * xval[8*k:8*k+8] - 1)), 0, kMaxQ - 1).astype(int)
                grid_index = Kmap[np.sum(l_values << (2 * np.arange(8)))]
                if grid_index < 0:
                    nb_index = -grid_index
                    _, best_l_values[8*k:8*k+8] = find_best_neighbour(
                        Kneighbors[nb_index : nb_index + Kneighbors[nb_index - 1]], Kgrid,
                        scale, xval[8*k:8*k+8], waux[8*k:8*k+8])
                else:
                    best_l_values[8*k:8*k+8] = l_values

        sumqx = np.sum(weight * xval * (2 * best_l_values + 1))
        sumq2 = np.sum(weight * (2 * best_l_values + 1)**2)

        if sumq2 > 0:
            scale = sumqx / sumq2

    if scale < 0:
        scale = -scale
        # Fixed: mask to 8 bits; a plain `~b` on a signed integer yields a
        # negative number instead of the inverted sign byte.
        block_signs = [(~b) & 0xFF for b in block_signs]

    grid_index = [Kmap[np.sum(best_l_values[8*k:8*k+8] << (2 * np.arange(8)))] for k in range(2)]
    assert np.all(np.array(grid_index) >= 0), "Oops: found point not on grid"

    return grid_index, block_signs, scale

In [None]:
# Run iq2_s_quant on the first 16 values and print the result.
# `xb`, `qw` and `sigma2` are not defined in this cell and must come from an earlier one.
# The gdb commands below stop llama-quantize inside quantize_row_iq2_s_impl,
# presumably to capture reference values to compare against (TODO confirm).
# gdb llama-quantize
# r --imatrix Llama-3.2-1B-Instruct.imatrix Llama-3.2-1B-Instruct-f16.gguf Llama-3.2-1B-Instruct-IQ2_M.gguf IQ2_M 1
# b quantize_row_iq2_s_impl
# c
# Note: twice `sigma2` is passed here.
grid_indexs, block_signs, scale = iq2_s_quant(xb[:16], qw[:16], 2*sigma2, kmap_q2xs, kneighbors_q2xs, kgrid_q2xs)
print(grid_indexs)
print(block_signs)
print(scale)

## IQ3

In [None]:
kgrid_256 = [
            0,     2,     4,     9,    11,    15,    16,    18,    25,    34,    59,    61,    65,    67,    72,    74,
           81,    85,    88,    90,    97,   108,   120,   128,   130,   132,   137,   144,   146,   153,   155,   159,
          169,   175,   189,   193,   199,   200,   202,   213,   248,   267,   287,   292,   303,   315,   317,   321,
          327,   346,   362,   413,   436,   456,   460,   462,   483,   497,   513,   515,   520,   522,   529,   531,
          536,   538,   540,   551,   552,   576,   578,   585,   592,   594,   641,   643,   648,   650,   657,   664,
          698,   704,   706,   720,   729,   742,   758,   769,   773,   808,   848,   852,   870,   889,   901,   978,
          992,  1024,  1026,  1033,  1035,  1040,  1042,  1046,  1049,  1058,  1089,  1091,  1093,  1096,  1098,  1105,
         1112,  1139,  1143,  1144,  1152,  1154,  1161,  1167,  1168,  1170,  1183,  1184,  1197,  1217,  1224,  1228,
         1272,  1276,  1309,  1323,  1347,  1367,  1377,  1404,  1473,  1475,  1486,  1509,  1537,  1544,  1546,  1553,
         1555,  1576,  1589,  1594,  1600,  1602,  1616,  1625,  1636,  1638,  1665,  1667,  1672,  1685,  1706,  1722,
         1737,  1755,  1816,  1831,  1850,  1856,  1862,  1874,  1901,  1932,  1950,  1971,  2011,  2032,  2052,  2063,
         2077,  2079,  2091,  2095,  2172,  2192,  2207,  2208,  2224,  2230,  2247,  2277,  2308,  2345,  2356,  2389,
         2403,  2424,  2501,  2504,  2506,  2520,  2570,  2593,  2616,  2624,  2630,  2646,  2669,  2700,  2714,  2746,
         2754,  2795,  2824,  2835,  2839,  2874,  2882,  2905,  2984,  3028,  3042,  3092,  3108,  3110,  3124,  3153,
         3185,  3215,  3252,  3288,  3294,  3364,  3397,  3434,  3483,  3523,  3537,  3587,  3589,  3591,  3592,  3610,
         3626,  3670,  3680,  3722,  3749,  3754,  3776,  3789,  3803,  3824,  3857,  3873,  3904,  3906,  3924,  3992,
]
# Configuration read by the IQ3 table-construction cell when the 256-entry grid is used.
# Running the kgrid_512 cell afterwards overwrites these three names.
grid_size = 256  # number of entries in kgrid_256
kmap_size = 4096  # number of slots allocated for the code -> grid-index map
nwant = 2  # how many distinct distance levels of neighbours are kept per off-grid code

In [None]:
kgrid_512 = [
            0,     1,     2,     5,     7,     8,     9,    10,    12,    14,    16,    17,    21,    27,    32,    34,
           37,    39,    41,    43,    48,    50,    57,    60,    63,    64,    65,    66,    68,    72,    73,    77,
           80,    83,    87,    89,    93,   100,   113,   117,   122,   128,   129,   133,   135,   136,   139,   142,
          145,   149,   152,   156,   162,   165,   167,   169,   171,   184,   187,   195,   201,   205,   208,   210,
          217,   219,   222,   228,   232,   234,   247,   249,   253,   256,   267,   271,   273,   276,   282,   288,
          291,   297,   312,   322,   324,   336,   338,   342,   347,   353,   357,   359,   374,   379,   390,   393,
          395,   409,   426,   441,   448,   450,   452,   464,   466,   470,   475,   488,   492,   512,   513,   514,
          516,   520,   521,   523,   525,   527,   528,   530,   537,   540,   542,   556,   558,   561,   570,   576,
          577,   579,   582,   584,   588,   593,   600,   603,   609,   616,   618,   632,   638,   640,   650,   653,
          655,   656,   660,   666,   672,   675,   685,   688,   698,   705,   708,   711,   712,   715,   721,   727,
          728,   732,   737,   754,   760,   771,   773,   778,   780,   793,   795,   802,   806,   808,   812,   833,
          840,   843,   849,   856,   858,   873,   912,   916,   919,   932,   934,   961,   963,   968,   970,   977,
          989,   993,  1010,  1016,  1024,  1025,  1027,  1029,  1031,  1032,  1034,  1036,  1038,  1041,  1043,  1047,
         1048,  1050,  1057,  1059,  1061,  1064,  1066,  1079,  1080,  1083,  1085,  1088,  1090,  1096,  1099,  1103,
         1106,  1109,  1113,  1116,  1122,  1129,  1153,  1156,  1159,  1169,  1171,  1176,  1183,  1185,  1195,  1199,
         1209,  1212,  1216,  1218,  1221,  1225,  1234,  1236,  1241,  1243,  1250,  1256,  1270,  1281,  1287,  1296,
         1299,  1306,  1309,  1313,  1338,  1341,  1348,  1353,  1362,  1375,  1376,  1387,  1400,  1408,  1410,  1415,
         1425,  1453,  1457,  1477,  1481,  1494,  1496,  1507,  1512,  1538,  1545,  1547,  1549,  1551,  1554,  1561,
         1563,  1565,  1570,  1572,  1575,  1577,  1587,  1593,  1601,  1603,  1605,  1612,  1617,  1619,  1632,  1648,
         1658,  1662,  1664,  1674,  1680,  1690,  1692,  1704,  1729,  1736,  1740,  1745,  1747,  1751,  1752,  1761,
         1763,  1767,  1773,  1787,  1795,  1801,  1806,  1810,  1817,  1834,  1840,  1844,  1857,  1864,  1866,  1877,
         1882,  1892,  1902,  1915,  1934,  1953,  1985,  1987,  2000,  2002,  2013,  2048,  2052,  2058,  2064,  2068,
         2071,  2074,  2081,  2088,  2104,  2114,  2119,  2121,  2123,  2130,  2136,  2141,  2147,  2153,  2157,  2177,
         2179,  2184,  2189,  2193,  2203,  2208,  2223,  2226,  2232,  2244,  2249,  2251,  2256,  2258,  2265,  2269,
         2304,  2306,  2324,  2335,  2336,  2361,  2373,  2375,  2385,  2418,  2443,  2460,  2480,  2504,  2509,  2520,
         2531,  2537,  2562,  2568,  2572,  2578,  2592,  2596,  2599,  2602,  2614,  2620,  2625,  2627,  2629,  2634,
         2641,  2650,  2682,  2688,  2697,  2707,  2712,  2718,  2731,  2754,  2759,  2760,  2775,  2788,  2793,  2805,
         2811,  2817,  2820,  2832,  2842,  2854,  2890,  2902,  2921,  2923,  2978,  3010,  3012,  3026,  3081,  3083,
         3085,  3097,  3099,  3120,  3136,  3152,  3159,  3188,  3210,  3228,  3234,  3245,  3250,  3256,  3264,  3276,
         3281,  3296,  3349,  3363,  3378,  3392,  3395,  3420,  3440,  3461,  3488,  3529,  3531,  3584,  3588,  3591,
         3600,  3602,  3614,  3616,  3628,  3634,  3650,  3657,  3668,  3683,  3685,  3713,  3716,  3720,  3726,  3729,
         3736,  3753,  3778,  3802,  3805,  3819,  3841,  3845,  3851,  3856,  3880,  3922,  3938,  3970,  3993,  4032,
]
# Configuration read by the IQ3 table-construction cell when the 512-entry grid is used.
# These assignments overwrite the values set by the kgrid_256 cell.
grid_size = 512  # number of entries in kgrid_512
kmap_size = 4096  # number of slots allocated for the code -> grid-index map
nwant = 3  # how many distinct distance levels of neighbours are kept per off-grid code

In [None]:
# Number of quantization levels per value used by the IQ3 quantizers (levels 0 .. kMaxQ - 1).
kMaxQ = 8
# Magnitude threshold read by iq3_xxs_quant: a block whose largest |x| is below it is not quantized.
GROUP_MAX_EPS_IQ3_XXS = 1e-8

# Build the three IQ3 lookup tables (kgrid_q3xs, kmap_q3xs, kneighbors_q3xs) from
# whichever grid the most recently executed grid cell selected via `grid_size`.
kgrid = kgrid_256 if grid_size == 256 else kgrid_512
# Unpack every grid code into four 3-bit fields f and store them as odd values 2*f + 1
# -> array of shape (len(kgrid), 4).
kgrid_q3xs = 2 * ((np.array(kgrid).reshape((-1,1)) >> (3 * np.arange(4))) & 0x7) + 1
# kmap_q3xs[code] is the grid row for on-grid codes; -1 marks "not on the grid" for now.
kmap_q3xs = np.full(kmap_size, -1)
# The original C++ code shuffles the data through a `the_grid` variable; that step is simply skipped here.
for i, k in enumerate(kgrid):
    kmap_q3xs[k] = i

# Pass 1: count how many neighbour entries will be stored in total, so the
# neighbour table can be allocated with its exact size.
num_neighbors = 0
for i, k in enumerate(kmap_q3xs):
    if k < 0:
        # Decode the off-grid code i the same way as the grid rows.
        pos = 2 * ((i >> (3 * np.arange(4))) & 0x7) + 1
        # Column 0: squared distance to every grid row, column 1: the row index.
        dist2 = np.column_stack((np.sum((kgrid_q3xs - pos)**2, axis=-1), np.arange(grid_size)))
        # Sort by distance first, ties broken by row index.
        sorted_indices = np.lexsort((dist2[:, 1], dist2[:, 0]))
        dist2 = dist2[sorted_indices]
        nhave = 1
        d2 = dist2[0][0]
        # Keep every row whose distance is among the `nwant` smallest distinct distances.
        for j,_ in dist2:
            if j > d2:
                if nhave == nwant:
                    break
                d2 = j
                nhave += 1
            num_neighbors += 1

# Pass 2: fill the neighbour table. For every off-grid code the layout is
# [count, idx_0, ..., idx_{count-1}], and kmap_q3xs[code] becomes -(offset_of_count + 1).
counter = 0
# One extra slot per off-grid code holds the neighbour count.
kneighbors_q3xs = np.empty(num_neighbors + kmap_size - grid_size, dtype=np.uint16)
for i, k in enumerate(kmap_q3xs):
    if k < 0:
        pos = 2 * ((i >> (3 * np.arange(4))) & 0x7) + 1
        dist2 = np.column_stack((np.sum((kgrid_q3xs - pos)**2, axis=-1), np.arange(grid_size)))
        sorted_indices = np.lexsort((dist2[:, 1], dist2[:, 0]))
        dist2 = dist2[sorted_indices]
        # Negative entry encodes where this code's neighbour list starts.
        kmap_q3xs[i] = -(counter + 1)
        nhave = 1
        # `start` is the slot reserved for the neighbour count.
        start = counter
        counter += 1
        d2 = dist2[0][0]
        for j,ii in dist2:
            if j > d2:
                if nhave == nwant:
                    break
                d2 = j
                nhave += 1
            kneighbors_q3xs[counter] = ii
            counter += 1
        # Store how many neighbour indices were written after `start`.
        kneighbors_q3xs[start] = counter - 1 - start

### IQ3_XXS (from IQ3_XXS/IQ3_XS)

In [None]:
def iq3_xxs_quant(xb, qw, sigma2, Kmap, Kneighbors, Kgrid):
    """Quantize one 32-value block onto the IQ3 grid (8 groups of 4 values).

    Signs are handled per group of 8 values. If such a group has an odd number
    of negative values, the sign of the value with the smallest
    `weight * x**2` is flipped and only the low 7 bits of the mask are kept.
    Magnitudes are quantized in groups of 4 for 31 trial scales around
    `max / (2*kMaxQ - 1)`; off-grid groups are replaced through
    `find_best_neighbour`, and the best weighted fit wins.

    Args:
        xb: 1-D array with 32 values.
        qw: per-value importance weights, or None to use `xb**2` as weight.
        sigma2: term added to `xb**2` under the square root when `qw` is given.
        Kmap: packed code -> grid index; a negative entry `-n` points at
            `Kneighbors[n:]` with the neighbour count stored at `Kneighbors[n-1]`.
        Kneighbors: flat neighbour table described above.
        Kgrid: grid rows holding odd values `2*level + 1`.

    Returns:
        `(grid_index, block_signs, scale)`: 8 grid indices, 4 sign masks
        (7 bits each) and the non-negative block scale. For a block whose
        largest magnitude is below `GROUP_MAX_EPS_IQ3_XXS` the result is
        all-zero indices, the sign masks, and scale 0.0.
    """
    if qw is not None:
        weight = qw * np.sqrt(sigma2 + xb**2)
    else:
        weight = xb**2

    waux = np.sqrt(weight)
    xval = np.abs(xb)
    block_signs = [np.sum((xb[8*k:8*k+8] < 0) << np.arange(8)) for k in range(4)]
    nflip = [np.sum(xb[8*k:8*k+8] < 0) for k in range(4)]
    for k in range(4):
        if nflip[k] % 2:
            # Odd number of negatives: flip the least important value.
            imin = np.argmin(weight[8*k:8*k+8] * xb[8*k:8*k+8]**2)
            xval[8*k+imin] = -xval[8*k+imin]
            block_signs[k] ^= (1 << imin)
        block_signs[k] = block_signs[k] & 127

    max_xval = np.max(xval)
    if max_xval < GROUP_MAX_EPS_IQ3_XXS:
        # Fixed: a bare 0.0 used to be returned here although callers unpack
        # three values.
        return [0] * 8, block_signs, 0.0

    best = 0
    best_grid_index = [None for _ in range(8)]
    best_l_values = np.zeros(np.size(xb), dtype=int)

    scale = max_xval / (2 * kMaxQ - 1)
    for grid_step in range(-15, 16):
        id_val = (2 * kMaxQ - 1 + grid_step * 0.2) / max_xval
        this_scale = 1 / id_val
        l_values = np.clip(np.round(0.5 * (id_val * xval - 1)), 0, kMaxQ - 1).astype(int)
        grid_index = [Kmap[np.sum(l_values[4*k:4*k+4] << (3 * np.arange(4)))] for k in range(8)]
        for k in range(8):
            if grid_index[k] < 0:
                # grid_index[k] stays negative on purpose: it marks the group
                # as "was off grid" for the refinement step below.
                nb_index = -grid_index[k]
                _, l_values[4*k:4*k+4] = find_best_neighbour(
                    Kneighbors[nb_index : nb_index + Kneighbors[nb_index - 1]], Kgrid,
                    this_scale, xval[4*k:4*k+4], waux[4*k:4*k+4])

        sumqx = np.sum(weight * xval * (2 * l_values + 1))
        sumq2 = np.sum(weight * (2 * l_values + 1)**2)

        if sumq2 > 0 and sumqx**2 > best * sumq2:
            scale = sumqx / sumq2
            best  = scale * sumqx
            best_grid_index = grid_index.copy()
            best_l_values   = l_values.copy()

    # Fixed: the old truthiness check also failed for the valid grid index 0.
    assert best_grid_index[0] is not None, "no trial scale produced a usable fit"
    if scale > 0 and np.any(np.array(best_grid_index) < 0):
        # Re-quantize only the groups that were off grid, now with the winning scale.
        id_val = 1 / scale
        for k in range(8):
            if best_grid_index[k] < 0:
                l_values = np.clip(np.round(0.5 * (id_val * xval[4*k:4*k+4] - 1)), 0, kMaxQ - 1).astype(int)
                grid_index = Kmap[np.sum(l_values << (3 * np.arange(4)))]
                if grid_index < 0:
                    nb_index = -grid_index
                    grid_index, _ = find_best_neighbour(
                        Kneighbors[nb_index : nb_index + Kneighbors[nb_index - 1]], Kgrid,
                        scale, xval[4*k:4*k+4], waux[4*k:4*k+4])
                best_l_values[4*k:4*k+4] = (Kgrid[grid_index] - 1)//2 # TODO

        sumqx = np.sum(weight * xval * (2 * best_l_values + 1))
        sumq2 = np.sum(weight * (2 * best_l_values + 1)**2)

        if sumq2 > 0:
            scale = sumqx / sumq2

    if scale < 0:
        # Keep the scale non-negative by inverting all (7-bit) sign masks.
        scale = -scale
        block_signs = [(~b) & 127 for b in block_signs]

    grid_index = [Kmap[np.sum(best_l_values[4*k:4*k+4] << (3 * np.arange(4)))] for k in range(8)]
    assert np.all(np.array(grid_index)>= 0), "Oops: found point not on grid"

    return grid_index, block_signs, scale

In [None]:
# Run iq3_xxs_quant on the first 32 values and print the result.
# `xb`, `qw` and `sigma2` are not defined in this cell and must come from an earlier one.
# The gdb commands below stop llama-quantize inside quantize_row_iq3_xxs_impl,
# presumably to capture reference values to compare against (TODO confirm).
# gdb llama-quantize
# r --imatrix Llama-3.2-1B-Instruct.imatrix Llama-3.2-1B-Instruct-f16.gguf Llama-3.2-1B-Instruct-IQ3_XS.gguf IQ3_XS 1
# b quantize_row_iq3_xxs_impl
# c
# Note: twice `sigma2` is passed here.
grid_indexs, block_signs, scale = iq3_xxs_quant(xb[:32], qw[:32], 2*sigma2, kmap_q3xs, kneighbors_q3xs, kgrid_q3xs)
print(grid_indexs)
print(block_signs)
print(scale)

### IQ3_S (from IQ3_S/IQ3_M)

In [None]:
def iq3_s_quant(xb, qw, sigma2, Kmap, Kneighbors, Kgrid):
    """Quantize one 32-value block onto the IQ3 grid with full 8-bit sign masks.

    Unlike `iq3_xxs_quant` no sign is flipped: each group of 8 values keeps an
    8-bit sign mask. Magnitudes are quantized in groups of 4 for 19 trial
    scales around `max / (2*kMaxQ - 1)`; off-grid groups are replaced through
    `find_best_neighbour`, and the best weighted fit wins. If the winning
    candidate had any off-grid group, all groups are re-quantized with the
    winning scale.

    Args:
        xb: 1-D array with 32 values.
        qw: per-value importance weights, or None to use `xb**2` as weight.
        sigma2: term added to `xb**2` under the square root when `qw` is given.
        Kmap: packed code -> grid index; a negative entry `-n` points at
            `Kneighbors[n:]` with the neighbour count stored at `Kneighbors[n-1]`.
        Kneighbors: flat neighbour table described above.
        Kgrid: grid rows holding odd values `2*level + 1`.

    Returns:
        `(grid_index, block_signs, scale)`: 8 grid indices, 4 sign masks
        (8 bits each) and the non-negative block scale. For an all-zero block
        the result is all-zero indices, the (all-zero) sign masks, and scale 0.0.
    """
    if qw is not None:
        weight = qw * np.sqrt(sigma2 + xb**2)
    else:
        weight = xb**2

    waux = np.sqrt(weight)
    xval = np.abs(xb)
    block_signs = [np.sum((xb[8*k:8*k+8] < 0) << np.arange(8)) for k in range(4)]
    max_xval = np.max(xval)
    if max_xval == 0.0:
        # Fixed: a bare 0.0 used to be returned here although callers unpack
        # three values.
        return [0] * 8, block_signs, 0.0

    best = 0
    best_grid_index = [None for _ in range(8)]
    best_l_values = np.zeros(np.size(xb), dtype=int)

    scale = max_xval / (2 * kMaxQ - 1)
    for grid_step in range(-9, 10):
        id_val = (2 * kMaxQ - 1 + grid_step * 0.2) / max_xval
        this_scale = 1 / id_val
        l_values = np.clip(np.round(0.5 * (id_val * xval - 1)), 0, kMaxQ - 1).astype(int)
        grid_index = [Kmap[np.sum(l_values[4*k:4*k+4] << (3 * np.arange(4)))] for k in range(8)]
        for k in range(8):
            if grid_index[k] < 0:
                # grid_index[k] stays negative on purpose: it marks that this
                # candidate needed a neighbour replacement.
                nb_index = -grid_index[k]
                _, l_values[4*k:4*k+4] = find_best_neighbour(
                    Kneighbors[nb_index : nb_index + Kneighbors[nb_index - 1]], Kgrid,
                    this_scale, xval[4*k:4*k+4], waux[4*k:4*k+4])

        sumqx = np.sum(weight * xval * (2 * l_values + 1))
        sumq2 = np.sum(weight * (2 * l_values + 1)**2)

        if sumq2 > 0 and sumqx**2 > best * sumq2:
            scale = sumqx / sumq2
            best  = scale * sumqx
            best_grid_index = grid_index.copy()
            best_l_values   = l_values.copy()

    # Fixed: the old truthiness check also failed for the valid grid index 0.
    assert best_grid_index[0] is not None, "no trial scale produced a usable fit"
    if scale > 0 and np.any(np.array(best_grid_index) < 0):
        id_val = 1 / scale
        # //if (is_on_grid[k]) continue; IQ2 is used this code!!!
        # i.e. here every group is re-quantized, not only the off-grid ones.
        best_l_values = np.clip(np.round(0.5 * (id_val * xval - 1)), 0, kMaxQ - 1).astype(int)
        grid_index = [Kmap[np.sum(best_l_values[4*k:4*k+4] << (3 * np.arange(4)))] for k in range(8)]
        for k in range(8):
            if grid_index[k] < 0:
                nb_index = -grid_index[k]
                grid_index[k], _ = find_best_neighbour(
                    Kneighbors[nb_index : nb_index + Kneighbors[nb_index - 1]], Kgrid,
                    scale, xval[4*k:4*k+4], waux[4*k:4*k+4])
        best_l_values = np.concatenate([(Kgrid[grid_index[k]] - 1)//2 for k in range(8)])

        sumqx = np.sum(weight * xval * (2 * best_l_values + 1))
        sumq2 = np.sum(weight * (2 * best_l_values + 1)**2)

        if sumq2 > 0:
            scale = sumqx / sumq2

    if scale < 0:
        scale = -scale
        # Fixed: mask to 8 bits; a plain `~b` on a signed integer yields a
        # negative number instead of the inverted sign byte.
        block_signs = [(~b) & 0xFF for b in block_signs]

    grid_index = [Kmap[np.sum(best_l_values[4*k:4*k+4] << (3 * np.arange(4)))] for k in range(8)]
    assert np.all(np.array(grid_index)>= 0), "Oops: found point not on grid"

    return grid_index, block_signs, scale

In [None]:
# Run iq3_s_quant on the first 32 values and print the result.
# `xb`, `qw` and `sigma2` are not defined in this cell and must come from an earlier one.
# The gdb commands below stop llama-quantize inside quantize_row_iq3_s_impl,
# presumably to capture reference values to compare against (TODO confirm).
# gdb llama-quantize
# r --imatrix Llama-3.2-1B-Instruct.imatrix Llama-3.2-1B-Instruct-f16.gguf Llama-3.2-1B-Instruct-IQ3_S.gguf IQ3_S 1
# b quantize_row_iq3_s_impl
# c
# Note: twice `sigma2` is passed here.
grid_indexs, block_signs, scale = iq3_s_quant(xb[:32], qw[:32], 2*sigma2, kmap_q3xs, kneighbors_q3xs, kgrid_q3xs)
print(grid_indexs)
print(block_signs)
print(scale)

## Q4_K/Q5_K

In [None]:
def make_qkx_quants(nmax: int, x: np.ndarray, weights: np.ndarray, rmin, rdelta, nstep, use_mad: bool = False):
    """Search for an affine quantization ``x ~= scale * l + min`` with ``l`` in ``[0, nmax]``.

    Args:
        nmax: Largest quantization level.
        x: Values of one block.
        weights: Per-element importance weights used in the error metric.
        rmin, rdelta: Step ``s`` tries the inverse scale
            ``(rmin + rdelta * s + nmax) / (max - min)``.
        nstep: Number of candidate steps; values below 2 skip the search.
        use_mad: Use weighted absolute error instead of weighted squared error.

    Returns:
        ``(scale, -min)``; the minimum is clamped to be <= 0, so the second
        element is non-negative.
    """
    def _weighted_error(residual):
        # Weighted absolute or squared reconstruction error.
        penalty = np.abs(residual) if use_mad else residual ** 2
        return np.sum(weights * penalty)

    def _levels(inv_scale, offset):
        # Nearest integer level for every element, clamped to [0, nmax].
        return np.clip(np.round(inv_scale * (x - offset)), 0, nmax)

    lo = min(np.min(x), 0)
    hi = np.max(x)
    if hi == lo:
        # Degenerate block: nothing to scale.
        return 0.0, -lo

    inv_scale = nmax / (hi - lo)
    best_scale = 1 / inv_scale
    if nstep < 2:
        return best_scale, -lo

    # Baseline: plain min/max quantization.
    levels = _levels(inv_scale, lo)
    best_err = _weighted_error(best_scale * levels + lo - x)

    w_total = np.sum(weights)
    wx_total = np.sum(weights * x)

    for step in range(nstep):
        # `lo` is deliberately re-read here: once a better candidate is accepted,
        # later steps use the updated minimum for both the range and the rounding.
        inv_scale = (rmin + rdelta * step + nmax) / (hi - lo)
        levels = _levels(inv_scale, lo)
        wl  = np.sum(weights * levels)
        wll = np.sum(weights * levels **2)
        wlx = np.sum(weights * levels * x)

        det = w_total * wll - wl**2
        if not det > 0:
            continue

        # Weighted least-squares fit of (scale, min) for the fixed levels.
        cand_scale = (w_total * wlx - wx_total * wl) / det
        cand_min = (wll * wx_total - wl * wlx) / det
        if cand_min > 0:
            # The minimum must not be positive: pin it to 0 and refit the scale alone.
            cand_min = 0
            cand_scale = wlx / wll

        err = _weighted_error(cand_scale * levels + cand_min - x)
        if err < best_err:
            best_err = err
            best_scale = cand_scale
            lo = cand_min

    return best_scale, -lo

In [None]:
# gdb llama-quantize
# r --imatrix Llama-3.2-1B-Instruct.imatrix Llama-3.2-1B-Instruct-f16.gguf Llama-3.2-1B-Instruct-Q4_K_M.gguf Q4_K_M 1
# b quantize_row_q4_K_impl
# b make_qkx3_quants
# c
# c
# fin
weights = qw[:32] * np.sqrt(2*sigma2 + xb[:32]**2)
scale, min_val = make_qkx_quants(15, xb[:32], weights, -0.9, 0.05, 37)
print(scale, min_val)

In [None]:
# gdb llama-quantize
# r --imatrix Llama-3.2-1B-Instruct.imatrix Llama-3.2-1B-Instruct-f16.gguf Llama-3.2-1B-Instruct-Q5_K_M.gguf Q5_K_M 1
# b quantize_row_q5_K_impl
# b make_qkx3_quants
# c
# c
# fin
weights = qw[:32] * np.sqrt(2*sigma2 + xb[:32]**2)
scale, min_val = make_qkx_quants(31, xb[:32], weights, -0.9, 0.05, 37)
print(scale, min_val)

# Model file

In [None]:
# model_name = "Llama-3.2-1B-Instruct"
# model_name = "Llama-3.1-8B-Instruct"
# model_name = "Llama-3.3-70B-Instruct"
# model_name = "Llama-3_3-Nemotron-Super-49B-v1_5"
# model_name = "DeepSeek-R1-Distill-Llama-70B"
# model_name = "DeepSeek-R1-Distill-Qwen-14B"
# model_name = "DeepSeek-R1-Distill-Qwen-32B"
# model_name = "Qwen3-0.6B"
# model_name = "Qwen3-8B"
# model_name = "Qwen3-14B"
model_name = "Qwen3-32B"
# model_name = "gemma-3-12b-it"
# model_name = "gemma-3-27b-it"
# model_name = "phi-4"
# model_name = "Phi-4-reasoning-plus"
# model_name = "Mistral-Small-3.2-24B-Instruct-2506"
# model_name = "Magistral-Small-2509"

bf16_gguf_file = f"models/{model_name}-BF16.gguf"
assert os.path.exists(bf16_gguf_file)
bf16_reader = GGUFReader(bf16_gguf_file, 'r')

## SOTA | EXL3

In [None]:
def is_iq2ks(tensor): return False

# (ik_)llama.cpp
# acted also as base template model
quant_type = "IQ2_K"
test_gguf_file = f"models/{model_name}-{quant_type}.gguf"
gguf_reader = GGUFReader(test_gguf_file, 'r')

use_gguf = True

------

In [None]:
# exl3
bpw = 2.25
quant_type = f"exl3-{str(bpw)}bpw"
test_model = f"models/{model_name}-{quant_type}"
config = Config.from_directory(test_model)
reader = Model.from_config(config)
load_tensor = load_exl3_tensor

use_gguf = False

In [None]:
# Plain conversion mode: no dequantization and no residual handling.
do_dequant, is_residual, add_residual = False, False, False
if use_gguf:
    # NOTE(review): quant_type[3:] drops the first three characters of quant_type
    # (e.g. "IQ2_K" -> "_K") — confirm the resulting file name is intended.
    output_gguf = f"models/{model_name}-{quant_type[3:]}.gguf"
else:
    test_path = Path(test_model)
    # All *.safetensors files directly inside the model directory, in sorted path order.
    safetensors_files = sorted([str(file) for file in list(test_path.glob("*.safetensors"))])
    # Load every tensor from every file into one dict keyed by tensor name.
    model_safetensors = {}
    for sf in safetensors_files:
        with safe_open(sf, framework="pt") as f:
            for key in f.keys():
                model_safetensors[key] = f.get_tensor(key)
    output_gguf = f"models/{model_name}-{quant_type}.gguf"

------

In [None]:
do_dequant, is_residual, add_residual = True, False, False
output_gguf = f"models/{model_name}-{quant_type}-dequant.gguf"

In [None]:
do_dequant, is_residual, add_residual = True, True, False
output_gguf = f"models/{model_name}-{quant_type}-residual.gguf"

In [None]:
do_dequant, is_residual, add_residual = True, False, True
output_gguf = f"models/{model_name}-{quant_type}-IQ2_KS.gguf"
residual_reader = GGUFReader(f"models/{model_name}-{quant_type}-residual-IQ2_KS.gguf", 'r')

## Transpose

In [None]:
!gcc -shared -o lib_transpose.so -fPIC transpose.c -Ofast

In [None]:
from ctypes import cdll, c_int, c_bool
from numpy.ctypeslib import ndpointer

lib = cdll.LoadLibrary('./lib_transpose.so')
lib.do_my_transpose.argtypes = [
    ndpointer(dtype=np.uint8, flags="C_CONTIGUOUS"), # dst
    ndpointer(dtype=np.uint8, flags="C_CONTIGUOUS"), # src
    c_int,  # rows
    c_int,  # cols
    c_int,  # stride
    c_int,  # avxlen
    c_bool, # full_trans
]

In [None]:
stride = 256
avxlen = 512
full_trans = True
# gguf_reader = GGUFReader(f"models/{model_name}-IQ2_KS.gguf", 'r')
gguf_reader = GGUFReader("tmp.bin", 'r')
output_gguf = f"models/{model_name}-IQ2_KS-FT-avx{avxlen}.gguf"
def is_iq2ks(tensor): return tensor.tensor_type == GGMLQuantizationType.IQ2_KS
use_gguf = True

In [None]:
# Debug demo: run the C transpose helper on one IQ2_KS tensor and print the raw
# bytes before and after.
for tensor in gguf_reader.tensors:
    if is_iq2ks(tensor) and tensor.name == "blk.0.attn_q.weight":
        nx, ny = tensor.data.shape
        # Destination buffer with the same shape as the source array.
        data_t = np.empty((nx, ny), np.uint8)
        # NOTE(review): stride/avxlen/full_trans are hard-coded to 256/256/True here,
        # while the settings cell defines stride=256, avxlen=512, full_trans=True —
        # confirm that avxlen=256 is intended for this demo.
        lib.do_my_transpose(data_t, tensor.data, nx, ny, 256, 256, True)
        print(tensor.data)
        print(tensor.data.dtype, data_t.dtype, nx, ny, 256, 256, True)
        # Flatten to a single row for printing.
        data_t = data_t.reshape(1, nx*ny)
        print(data_t)
        break

## Write a new GGUF file

### Setting

In [None]:
from tqdm import tqdm
from typing import Any, Sequence, NamedTuple
from gguf import GGUFWriter
from gguf.constants import GGMLQuantizationType
from scripts.gguf_new_metadata import decode_field, get_byteorder, get_field_data, MetadataDetails

def tensor_is_2d(name):
    """Return True for per-block tensors (name starts with 'blk.') that are neither norms nor biases."""
    if not name.startswith('blk.'):
        return False
    return not ('_norm' in name or '.bias' in name)

def tensor_is_misc(name):
    """Return True for norm tensors, bias tensors, and any tensor whose name does not start with 'blk.'."""
    if any(marker in name for marker in ('_norm', '.bias')):
        return True
    return not name.startswith('blk.')

In [None]:
# Prefix of per-layer tensor names on the safetensors side. The substring 'istral-'
# matches both "Mistral-..." and "Magistral-..." entries of the model list.
layer_prefix = 'language_model.model.layers' if 'gemma-3-' in model_name or 'istral-' in model_name else 'model.layers'

# name_map: GGUF per-block tensor type (the third dot-separated component of
# "blk.<layer>.<type>.weight") -> module name under `layer_prefix`. A list value
# means several source tensors feed one GGUF tensor; the writer cell concatenates
# them along dim 0 in the listed order.
if 'gemma-3' in model_name:
    name_map = {
        'attn_q'     : 'self_attn.q_proj',
        'attn_k'     : 'self_attn.k_proj',
        'attn_v'     : 'self_attn.v_proj',
        'attn_output': 'self_attn.o_proj',
        'ffn_down'   : 'mlp.down_proj',
        'ffn_gate'   : 'mlp.gate_proj',
        'ffn_up'     : 'mlp.up_proj',
        'attn_norm'          : 'input_layernorm',
        'attn_k_norm': 'self_attn.k_norm',
        'attn_q_norm': 'self_attn.q_norm',
        'ffn_norm'           : 'pre_feedforward_layernorm',
        'post_attention_norm': 'post_attention_layernorm',
        'post_ffw_norm'      : 'post_feedforward_layernorm',
    }
# 'hi-' matches both "phi-4" and "Phi-4-..." regardless of the leading letter's case.
elif 'hi-' in model_name or 'GLM-' in model_name:
    name_map = {
        'attn_output': 'self_attn.o_proj',
        'ffn_up'     : ['mlp.gate_proj', 'mlp.up_proj'],
        'ffn_down'   : 'mlp.down_proj',
        'attn_norm'  : 'input_layernorm',
        'ffn_norm'   : 'post_attention_layernorm',
        # for GLM
        'attn_q'     : 'self_attn.q_proj',
        'attn_k'     : 'self_attn.k_proj',
        'attn_v'     : 'self_attn.v_proj',
        'post_ffw_norm'      : 'post_mlp_layernorm',
        'post_attention_norm': 'post_self_attn_layernorm',
        # for Phi
        'attn_qkv'   : ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj'],

    }
else: # llama/qwen/mistral
    name_map = {
        'attn_q'     : 'self_attn.q_proj',
        'attn_k'     : 'self_attn.k_proj',
        'attn_v'     : 'self_attn.v_proj',
        'attn_output': 'self_attn.o_proj',
        'ffn_down'   : 'mlp.down_proj',
        'ffn_gate'   : 'mlp.gate_proj',
        'ffn_up'     : 'mlp.up_proj',    
        'attn_norm'  : 'input_layernorm',
        'ffn_norm'   : 'post_attention_layernorm',
        # for Qwen3
        'attn_k_norm': 'self_attn.k_norm',
        'attn_q_norm': 'self_attn.q_norm',
    }

model_name

-----
### 

In [None]:
# first run setting below
#
# Writes `output_gguf` in two passes over `gguf_reader.tensors`:
#   1. register metadata + tensor infos (name, shape, dtype, byte size, type),
#   2. stream the tensor payloads in exactly the same order.
# Both passes must take the same branch for every tensor, otherwise the declared
# sizes and the written bytes diverge.

def find_bf16_tensor(name):
    """Return the tensor called `name` from `bf16_reader`.

    Raises KeyError when the BF16 reference file has no such tensor, instead of
    silently falling back to whatever tensor the scan ended on.
    """
    for candidate in bf16_reader.tensors:
        if candidate.name == name:
            return candidate
    raise KeyError(f"tensor {name!r} not found in the BF16 reference GGUF")

arch = get_field_data(gguf_reader, gguf.Keys.General.ARCHITECTURE)
writer = gguf.GGUFWriter(output_gguf, arch=arch, endianess=get_byteorder(gguf_reader))
alignment = get_field_data(gguf_reader, gguf.Keys.General.ALIGNMENT)
# A custom alignment is not carried over to the new file, so refuse inputs that set one.
assert alignment is None

# Copy every metadata key/value of the template file.
for field in gguf_reader.fields.values():
    # Suppress virtual fields and fields written by GGUFWriter
    if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
        continue
    val = MetadataDetails(field.types[0], decode_field(field))
    if val.value is not None:
        writer.add_key_value(field.name, val.value, val.type)

total_bytes = 0

# ---- Pass 1: tensor infos -------------------------------------------------
for tensor in gguf_reader.tensors:
    if is_iq2ks(tensor):
        # Same bytes, re-laid out by the C transpose helper and tagged IQ2_KS_T.
        total_bytes += tensor.n_bytes
        writer.add_tensor_info(tensor.name, tensor.data.shape, tensor.data.dtype, tensor.data.nbytes, GGMLQuantizationType.IQ2_KS_T)
    elif do_dequant and tensor_is_2d(tensor.name):
        # Dequantized block weights are stored as F16: 2 bytes per element.
        dim1, dim2 = gguf.quant_shape_from_byte_shape(tensor.data.shape, tensor.tensor_type)
        nbytes = dim1 * dim2 * 2
        writer.add_tensor_info(tensor.name, (dim1, dim2), np.float16, nbytes, GGMLQuantizationType.F16)
        total_bytes += nbytes
    elif not is_residual:
        if do_dequant or 'blk.' in tensor.name[:4]:
            if use_gguf or tensor_is_misc(tensor.name):
                # if is gguf file or is norm or bias 1-d tensor
                writer.add_tensor_info(tensor.name, tensor.data.shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
                total_bytes += tensor.n_bytes
            else:
                # if is exl3 file and is 2-d quant weight tensor
                name_split = tensor.name.split('.')
                layer_id = name_split[1]
                layer_type = name_split[2]
                mapped_name = name_map[layer_type]
                # here we don't consider Phi & GLM for easy
                assert not isinstance(mapped_name, list)

                layer_name = f'{layer_prefix}.{layer_id}.{mapped_name}'
                trellis = model_safetensors[f"{layer_name}.trellis"]
                suh     = model_safetensors[f"{layer_name}.suh"]
                svh     = model_safetensors[f"{layer_name}.svh"]
                assert suh.shape[0] == 16*trellis.shape[0] and svh.shape[0] == 16*trellis.shape[1]
                dim1, dim2 = gguf.quant_shape_from_byte_shape(tensor.data.shape, tensor.tensor_type)
                assert suh.shape[0] == dim2 and svh.shape[0] == dim1
                # Payload is trellis + suh + svh concatenated (see pass 2).
                exl3_tensor_nbytes = trellis.nbytes + suh.nbytes + svh.nbytes
                # 401: presumably a custom tensor-type id for raw EXL3 payloads — TODO confirm.
                writer.add_tensor_info(tensor.name, (trellis.shape[2], 16*trellis.shape[1], 16*trellis.shape[0]),
                                       np.int16, exl3_tensor_nbytes, 401)
                total_bytes += exl3_tensor_nbytes
        else:
            # Tensors outside the blocks are taken from the BF16 reference file.
            bf16_tensor = find_bf16_tensor(tensor.name)
            total_bytes += bf16_tensor.n_bytes
            writer.add_tensor_info(bf16_tensor.name,
                                   bf16_tensor.data.shape, bf16_tensor.data.dtype, bf16_tensor.data.nbytes,
                                   bf16_tensor.tensor_type)

bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_ti_data_to_file()

# ---- Pass 2: tensor data (same order and same branches as pass 1) -----------
for tensor in gguf_reader.tensors:
    if is_iq2ks(tensor):
        nx, ny = tensor.data.shape
        data = np.empty((nx, ny), np.uint8)
        lib.do_my_transpose(data, tensor.data, nx, ny, stride, avxlen, full_trans)
        writer.write_tensor_data(data)
        bar.update(tensor.n_bytes)
    elif do_dequant and tensor_is_2d(tensor.name):
        if use_gguf:
            data = torch.from_numpy(gguf.dequantize(tensor.data, tensor.tensor_type))
        else:
            name_split = tensor.name.split('.')
            layer_id = name_split[1]
            layer_type = name_split[2]
            mapped_names = name_map[layer_type]
            if not isinstance(mapped_names, list):
                mapped_names = [mapped_names]
            # Fused tensors (list entries in name_map) are concatenated along dim 0.
            data = torch.cat([load_tensor(reader, f'{layer_prefix}.{layer_id}.{i}') for i in mapped_names], dim=0)

        if add_residual:
            # quantized residual + dequantized weights
            data = next(torch.from_numpy(gguf.dequantize(t.data, t.tensor_type))
                        for t in residual_reader.tensors if t.name == tensor.name) + data
        elif is_residual:
            # BF16 reference - dequantized weights
            data = next(torch.from_numpy(gguf.dequantize(t.data, t.tensor_type))
                        for t in bf16_reader.tensors if t.name == tensor.name) - data

        data = data.numpy().astype(np.float16)
        writer.write_tensor_data(data)
        bar.update(data.shape[0] * data.shape[1] * 2)
    elif not is_residual: # if is_residual, no need for other 1d tensor in blocks or tensor outside of blocks
        if do_dequant or 'blk.' in tensor.name[:4]:
            if use_gguf or tensor_is_misc(tensor.name):
                # if is gguf file or is norm or bias 1-d tensor
                data = tensor.data
                writer.write_tensor_data(data)
                bar.update(tensor.n_bytes)
            else:
                # if is exl3 file and is 2-d quant weight tensor
                name_split = tensor.name.split('.')
                layer_id = name_split[1]
                layer_type = name_split[2]
                mapped_name = name_map[layer_type]
                layer_name = f'{layer_prefix}.{layer_id}.{mapped_name}'
                # Reinterpret all three parts as raw bytes and store them back to back.
                trellis = model_safetensors[f"{layer_name}.trellis"].flatten().view(torch.uint8)
                suh     = model_safetensors[f"{layer_name}.suh"].view(torch.uint8)
                svh     = model_safetensors[f"{layer_name}.svh"].view(torch.uint8)
                data = torch.cat([trellis, suh, svh])
                writer.write_tensor_data(data.numpy())
                bar.update(data.nbytes)
        else:
            bf16_tensor = find_bf16_tensor(tensor.name)
            writer.write_tensor_data(bf16_tensor.data)
            bar.update(bf16_tensor.n_bytes)

writer.close()

output_gguf

## legacy

### FakeQuant/AWQ/GPTQ

<details>
<summary>analysis</summary>

```python
tensor_names = awq_reader.keys()
for name in tensor_names:
    tensor = awq_reader.get_tensor(name)
    print(f"name: {name}")
    print(f"shape: {tensor.shape}")
    print(f"type: {tensor.dtype}")
    # print(f"tensor:\n{tensor}")
    print("-" * 50)

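# Spooky fact: the tensors read from gguf and from awq/gptq are transposes of each other.
# Moreover, the gguf tensor layout follows the same ordering as the original Meta-released model weights,
# whereas awq/gptq match the Hugging Face layout, since awq/gptq are part of the Hugging Face ecosystem.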
print(gguf_reader.get_tensor(tensor_idx[f'blk.{0}.ffn_down.weight']).shape)
print(awq_reader.get_tensor(f'model.layers.{0}.mlp.down_proj.qweight').shape)
print(gptq_reader.get_tensor(f'model.layers.{0}.mlp.down_proj.qweight').shape)

print(load_gguf_tensor(gguf_reader, 'output_norm.weight'))
print(awq_reader.get_tensor('model.norm.weight'))
print(gptq_reader.get_tensor('model.norm.weight'))

layer=0
print(load_gguf_tensor(gguf_reader, f'blk.{layer}.attn_k.weight').shape)
print(load_awq_tensor(awq_reader, f'blk.{layer}.attn_k.weight').shape)
print(load_gptq_tensor(gptq_reader, f'blk.{layer}.attn_k.weight').shape)
```
    
</details>

In [None]:
url = "llmc generate"
model_path = "models/Llama-3.2-1B-Instruct-llmc-awq.safetensors"
output_gguf = "Llama-3.2-1B-Instruct-llmc-awq.gguf"
load_tensor = load_fakequant_tensor

In [None]:
url = "llmc generate"
model_path = "models/Llama-3.2-1B-Instruct-llmc-awq-omniq.safetensors"
output_gguf = "Llama-3.2-1B-Instruct-llmc-awq-omniq.gguf"
load_tensor = load_fakequant_tensor

In [None]:
url = "llmc generate"
model_path = "models/Llama-3.2-1B-Instruct-llmc-hqq.safetensors"
output_gguf = "Llama-3.2-1B-Instruct-llmc-hqq.gguf"
load_tensor = load_fakequant_tensor

In [None]:
url = "https://huggingface.co/AMead10/Llama-3.2-1B-Instruct-AWQ/resolve/main/model.safetensors"
model_path = "models/Llama-3.2-1B-Instruct-AWQ.safetensors"
output_gguf = "Llama-3.2-1B-Instruct-AWQ.gguf"
load_tensor = load_awq_tensor

In [None]:
# Intel's GPTQ/AutoRound family dequantizes to the same values; only the storage format differs.
url = "https://huggingface.co/fbaldassarri/meta-llama_Llama-3.2-1B-Instruct-auto_awq-int4-gs128-asym/resolve/main/model.safetensors"
model_path = "models/Llama-3.2-1B-Instruct-auto_awq-int4-gs128-asym.safetensors"
output_gguf = "Llama-3.2-1B-Instruct-auto_awq-int4-gs128-asym.gguf"
load_tensor = load_awq_tensor

In [None]:
url = "https://huggingface.co/fbaldassarri/meta-llama_Llama-3.2-1B-Instruct-auto_awq-int4-gs128-sym/resolve/main/model.safetensors"
model_path = "models/Llama-3.2-1B-Instruct-auto_awq-int4-gs128-sym.safetensors"
output_gguf = "Llama-3.2-1B-Instruct-auto_awq-int4-gs128-sym.gguf"
load_tensor = load_awq_tensor

In [None]:
url = "https://huggingface.co/ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v2.5/resolve/main/model.safetensors"
model_path = "models/Llama-3.2-1B-Instruct-GPTQ-g32.safetensors"
output_gguf = "Llama-3.2-1B-Instruct-GPTQ-g32.gguf"
load_tensor = load_gptq_tensor

In [None]:
url = "https://huggingface.co/shuyuej/Llama-3.2-1B-Instruct-GPTQ/resolve/main/model.safetensors"
model_path = "models/Llama-3.2-1B-Instruct-GPTQ-g128.safetensors"
output_gguf = "Llama-3.2-1B-Instruct-GPTQ-g128.gguf"
load_tensor = load_gptq_tensor

In [None]:
url = "https://huggingface.co/Almheiri/Llama-3.2-1B-Instruct-GPTQ-INT4/resolve/main/model.safetensors"
model_path = "models/Llama-3.2-1B-Instruct-GPTQ-INT4.safetensors"
output_gguf = "Llama-3.2-1B-Instruct-GPTQ-INT4.gguf"
load_tensor = load_gptq_tensor

In [None]:
# Fetch the checkpoint once, then open it lazily with safetensors.
if not os.path.exists(model_path):
    print("file not found, download from internet...")
    try:
        # check=True: a failed download must stop the cell instead of continuing
        # with a missing or truncated file.
        subprocess.run(["wget", "-O", model_path, url], check=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        # `wget -O` creates the output file even when the transfer fails; remove it
        # so the existence check above does not skip the download on the next run.
        if os.path.exists(model_path):
            os.remove(model_path)
        raise

reader = safe_open(model_path, framework="pt")
model_path

### HQQ

In [None]:
url = "https://hf-mirror.com/brunopio/Llama-3.2-1B-Instruct-nbits4-GSNone-Axis0-HQQ/resolve/main/qmodel.pt"
model_path = "models/Llama-3.2-1B-Instruct-nbits4-GSNone-Axis0-HQQ.pt"
output_gguf = "Llama-3.2-1B-Instruct-nbits4-GSNone-Axis0-HQQ.gguf"
load_tensor = load_hqq_tensor

In [None]:
url = "https://hf-mirror.com/brunopio/Llama-3.2-1B-Instruct-nbits4-GS64-Axis1-HQQ/resolve/main/qmodel.pt"
model_path = "models/Llama-3.2-1B-Instruct-nbits4-GS64-Axis1-HQQ.pt"
output_gguf = "Llama-3.2-1B-Instruct-nbits4-GS64-Axis1-HQQ.gguf"
load_tensor = load_hqq_tensor

In [None]:
# Fetch the HQQ checkpoint once, then open it with the HQQ reader.
if not os.path.exists(model_path):
    print("file not found, download from internet...")
    try:
        # check=True: a failed download must stop the cell instead of continuing
        # with a missing or truncated file.
        subprocess.run(["wget", "-O", model_path, url], check=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        # `wget -O` creates the output file even when the transfer fails; remove it
        # so the existence check above does not skip the download on the next run.
        if os.path.exists(model_path):
            os.remove(model_path)
        raise

reader = hqqpt_open(model_path)
model_path

### OmniQuant/MLC

In [None]:
url = "https://hf-mirror.com/numen-tech/Llama-3.2-1B-Instruct-w4a16g128asym/resolve/main/"
model_path = "models/Llama-3.2-1B-Instruct-w4a16g128asym/"
output_gguf = "Llama-3.2-1B-Instruct-w4a16g128asym.gguf"
load_tensor = load_mlc_tensor

In [None]:
url = "https://hf-mirror.com/mlc-ai/Llama-3.2-1B-Instruct-q4f16_1-MLC/resolve/main/"
model_path = "models/Llama-3.2-1B-Instruct-q4f16_1-MLC/"
output_gguf = "Llama-3.2-1B-Instruct-q4f16_1-MLC.gguf"
load_tensor = load_mlc_tensor

In [None]:
reader = mlcbin_open(model_path, url)
model_path

### Running demo

In [None]:
%env HF_ENDPOINT=https://hf-mirror.com
%env HF_HUB_ENABLE_HF_TRANSFER=1
# !pip install -U transformers peft accelerate optimum auto-gptq autoawq
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# question = "火影忍者的作者是谁？" # Q5_K_S/Q4_K_S/IQ4_XS/Q3_K_XL fail on this question
# question = "Naruto的作者是谁？" # Q5/IQ4_XS/Q3_K_XL fail on this question
question = "Who is the author of 'Chainsaw Man'?" # Q4_0/IQ3_M/Q3_K_XL will be failed

In [None]:
# model_path = "AMead10/Llama-3.2-1B-Instruct-AWQ"
# model_path = "Almheiri/Llama-3.2-1B-Instruct-GPTQ-INT4"
model_path = "ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v2.5"

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="cuda")

In [None]:
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
model_path = 'brunopio/Llama-3.2-1B-Instruct-nbits4-GS64-Axis1-HQQ'
# model_path = 'brunopio/Llama-3.2-1B-Instruct-nbits4-GSNone-Axis0-HQQ' # rubbish

tokenizer = AutoTokenizer.from_pretrained(model_path)
model     = HQQModelForCausalLM.from_quantized(model_path)

In [None]:
# NOTE(review): this cell reloads the tokenizer and model with AutoModelForCausalLM
# from `model_path`, replacing whatever an earlier cell loaded (including the HQQ
# loader cell) — confirm that the reload is intended.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="cuda")

# Chat-style prompt: a fixed system message followed by the question chosen above.
prompt = [
    {"role": "system", "content": "\n\nYou are a helpful assistant"},
    {"role": "user", "content": question},
]

input_tensor = tokenizer.apply_chat_template(prompt, add_generation_prompt=True, return_tensors="pt")

# Greedy decoding (do_sample=False), at most 512 new tokens.
outputs = model.generate(input_ids=input_tensor.to(model.device), max_new_tokens=512, do_sample=False)
# result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)
# Decodes the full sequence, i.e. the prompt tokens followed by the generated tokens.
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)

# Analysis

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
def display_heatmap_2d(tensor, xname, yname, name=None, vmin=None, vmax=None):
    """Show `tensor` as a 16x4-inch diverging heatmap centred on 0.

    Args:
        tensor: 2-D data handed to ``sns.heatmap``.
        xname, yname: Axis labels.
        name: Optional figure title.
        vmin, vmax: Optional colour-scale limits.
    """
    plt.figure(figsize=(16, 4))
    # Alternative palette that was tried: cmap='viridis'.
    heatmap_options = dict(cmap='coolwarm', center=0, vmin=vmin, vmax=vmax)
    sns.heatmap(tensor, **heatmap_options)
    plt.xlabel(xname)
    plt.ylabel(yname)
    if name:
        plt.title(name)
    plt.show()

def sns_histplot(tensor, name=None):
    """Plot a 50-bin histogram of `tensor` with a KDE overlay; `name` is an optional title."""
    plt.figure(figsize=(8, 4))
    sns.histplot(tensor, kde=True, bins=50)
    plt.xlabel("Value")
    plt.ylabel("Frequency")
    if name:
        plt.title(name)
    plt.grid()
    plt.show()

def print_channal_import_factor(acts, nchannals):
    """Plot the sorted per-channel values in four panels and print quantile markers.

    `acts` is sorted ascending; the panels show the lower half, the next quarter,
    and the two top eighths. The values at several top-fraction and
    bottom-fraction positions are printed alongside.
    """
    half = nchannals//2
    ranked = np.sort(acts)

    def _draw(ax, segment, title):
        # One panel: blue line, dashed grid, descriptive title.
        ax.plot(segment, color="blue", linestyle="-", linewidth=2)
        ax.grid(True, linestyle="--", alpha=0.5)
        ax.set_title(title)

    fig, ((ax2, ax1), (ax4, ax3)) = plt.subplots(nrows=2, ncols=2, figsize=(10, 5))

    # Top eighth.
    _draw(ax1, ranked[-half//4:], f"{nchannals-half//4}~{nchannals} (1/8)")
    print("top-1/16", ranked[-half//8], "| top-1/8", ranked[-half//4])

    # Second eighth from the top.
    _draw(ax2, ranked[-half//2:-half//4], f"{nchannals-half//2}~{nchannals-half//4} (1/8)")
    print("top-3/16", ranked[-3*half//8], "| top-1/4", ranked[-half//2])

    # Third quarter.
    _draw(ax3, ranked[-half:-half//2], f"{nchannals-half}~{nchannals-half//2} (1/4)")
    print("top-3/8", ranked[-3*half//4], "| top-1/2", ranked[-half])

    # Lower half.
    _draw(ax4, ranked[:-half], f"0~{nchannals-half} (1/2)")

    print("1/16:", ranked[nchannals//16], "1/8:", ranked[nchannals//8], "1/4:", ranked[nchannals//4])

    plt.tight_layout()
    plt.show()

def compare_resdiual_factor(acts0, acts1, nchannals):
    """Overlay two sorted curves: `acts0` (red) cut at ``-nchannals//2`` and
    `acts1` (blue) sliced from ``nchannals//4`` to ``-nchannals//4``."""
    curves = (
        (np.sort(acts0)[0:-nchannals//2], "red"),
        (np.sort(acts1)[nchannals//4:-nchannals//4], "blue"),
    )
    plt.figure(figsize=(10, 4))
    for curve, color in curves:
        plt.plot(curve, color=color, linestyle="-", linewidth=2)
    plt.show()

<details>
<summary>save_imatrix</summary>

```python
def data_write_string(f, string: str):
    encoded = string.encode('utf-8')
    np.array(len(encoded), dtype=np.int32).tofile(f)
    f.write(encoded)

def save_imatrix(filename: str, entries: dict, m_last_call: int, prompt_file: str):
    with open(filename, 'wb') as f:
        # Write number of entries
        np.array(len(entries), dtype=np.int32).tofile(f)
        
        # Write each entry
        for name, values in entries.items():
            # Write string (name)
            data_write_string(f, name)
            
            # Write metadata (ncall=1, nval=len(values))
            np.array(1, dtype=np.int32).tofile(f)  # ncall
            np.array(len(values), dtype=np.int32).tofile(f)  # nval
            
            # Write float32 values (assume values are already normalized)
            np.asarray(values, dtype=np.float32).tofile(f)
        
        # Write final metadata
        np.array(m_last_call, dtype=np.int32).tofile(f)
        data_write_string(f, prompt_file)
```

</details>

In [None]:
def load_gguf_tensor(main_reader, name, residual_reader=None):
    """Dequantize the tensor called `name` and optionally add its residual.

    Args:
        main_reader: Reader whose ``tensors`` are searched for `name`.
        name: Exact tensor name, e.g. ``'blk.0.attn_q.weight'``.
        residual_reader: Optional second reader; when given, and `name` is not a
            ``_norm`` tensor, its dequantized tensor of the same name is added.

    Returns:
        A torch tensor with the dequantized values.

    Raises:
        KeyError: If `name` is missing from a reader that has to provide it.
    """
    def _dequantized(reader, role):
        # Only the matching tensor is dequantized; the others are skipped untouched.
        for t in reader.tensors:
            if t.name == name:
                return torch.from_numpy(gguf.dequantize(t.data, t.tensor_type))
        raise KeyError(f"tensor {name!r} not found in {role} reader")

    main_t = _dequantized(main_reader, "main")

    # Decide on the requested `name`, not on a loop variable leaked from another cell.
    if residual_reader is not None and '_norm' not in name:
        main_t = main_t + _dequantized(residual_reader, "residual")

    return main_t

def load_gguf_imatrix(imatrix_name, has_acts=False):
    """Parse an imatrix file into per-tensor value arrays.

    Layout as read here: int32 entry count; per entry a length-prefixed name,
    int32 ncall, int32 nval and nval float32 values; with `has_acts`, each entry
    is followed by int32 nact and nact float32 activations. The file ends with
    int32 m_last_call and a length-prefixed prompt file name.

    Returns:
        ``(entries, m_last_call, prompt_file)``, or
        ``(entries, act_entries, m_last_call, prompt_file)`` when `has_acts` is
        True. Values are divided by ncall when ncall > 0; activations are
        reshaped to ``(-1, nval)``.
    """
    raw = np.memmap(imatrix_name, mode = 'r')
    cursor = 0
    entry_count = data_get(raw, cursor, np.int32)[0]
    cursor += 4

    entries, act_entries = {}, {}
    for _ in range(entry_count):
        name, cursor = data_read_string(raw, cursor)
        # Fixed 8-byte header (ncall, nval) followed by the float32 payload.
        ncall = data_get(raw, cursor,   np.int32)[0]
        nval  = data_get(raw, cursor+4, np.int32)[0]
        vals  = data_get(raw, cursor+8, np.float32, nval)
        cursor += 4*nval+8
        if ncall > 0:
            entries[name] = vals / ncall
        else:
            entries[name] = vals
        if not has_acts:
            continue
        # Optional activation block: int32 count followed by float32 data.
        nact = data_get(raw, cursor, np.int32)[0]
        acts = data_get(raw, cursor+4, np.float32, nact)
        cursor += 4*nact+4
        act_entries[name] = acts.reshape((-1, nval))

    m_last_call = data_get(raw, cursor, np.int32)[0]
    prompt_file, _ = data_read_string(raw, cursor+4)
    if has_acts:
        return entries, act_entries, m_last_call, prompt_file
    return entries, m_last_call, prompt_file

<details>
<summary>IQ2_KS quant log</summary>
    
```log
[   3/ 292]    blk.0.attn_norm.weight - [ 4096,     1], type =    f32, size =    0.016 MB
[   4/ 292]     blk.0.ffn_down.weight - [14336,  4096], type = bf16, converting to q2_k
[   5/ 292]     blk.0.ffn_gate.weight - [ 4096, 14336], type = bf16, converting to iq2_ks
[   6/ 292]       blk.0.ffn_up.weight - [ 4096, 14336], type = bf16, converting to iq2_ks
[   7/ 292]     blk.0.ffn_norm.weight - [ 4096,     1], type =    f32, size =    0.016 MB
[   8/ 292]       blk.0.attn_k.weight - [ 4096,  1024], type = bf16, converting to iq2_ks
[   9/ 292]  blk.0.attn_output.weight - [ 4096,  4096], type = bf16, converting to iq2_ks
[  10/ 292]       blk.0.attn_q.weight - [ 4096,  4096], type = bf16, converting to iq2_ks
[  11/ 292]       blk.0.attn_v.weight - [ 4096,  1024], type = bf16, converting to iq4_k
[ 273/ 292]   blk.30.attn_norm.weight - [ 4096,     1], type =    f32, size =    0.016 MB
[ 274/ 292]    blk.30.ffn_down.weight - [14336,  4096], type = bf16, converting to iq2_ks
[ 275/ 292]    blk.30.ffn_gate.weight - [ 4096, 14336], type = bf16, converting to iq2_ks
[ 276/ 292]      blk.30.ffn_up.weight - [ 4096, 14336], type = bf16, converting to iq2_ks
[ 277/ 292]    blk.30.ffn_norm.weight - [ 4096,     1], type =    f32, size =    0.016 MB
[ 278/ 292]      blk.30.attn_k.weight - [ 4096,  1024], type = bf16, converting to iq2_ks
[ 279/ 292] blk.30.attn_output.weight - [ 4096,  4096], type = bf16, converting to iq2_ks
[ 280/ 292]      blk.30.attn_q.weight - [ 4096,  4096], type = bf16, converting to iq2_ks
[ 281/ 292]      blk.30.attn_v.weight - [ 4096,  1024], type = bf16, converting to iq4_k
```
</details>

<details>
<summary>IQ2_K quant log</summary>

```log
[   3/ 292]    blk.0.attn_norm.weight - [ 4096,     1], type =    f32, size =    0.016 MB
[   4/ 292]     blk.0.ffn_down.weight - [14336,  4096], type = bf16, converting to iq2_k .. size = 112.00 MiB -> 16.62 MiB
[   5/ 292]     blk.0.ffn_gate.weight - [ 4096, 14336], type = bf16, converting to iq2_k .. size = 112.00 MiB -> 16.62 MiB
[   6/ 292]       blk.0.ffn_up.weight - [ 4096, 14336], type = bf16, converting to iq2_k .. size = 112.00 MiB -> 16.62 MiB
[   7/ 292]     blk.0.ffn_norm.weight - [ 4096,     1], type =    f32, size =    0.016 MB
[   8/ 292]       blk.0.attn_k.weight - [ 4096,  1024], type = bf16, converting to iq2_k .. size =   8.00 MiB ->  1.19 MiB
[   9/ 292]  blk.0.attn_output.weight - [ 4096,  4096], type = bf16, converting to iq3_k .. size =  32.00 MiB ->  6.88 MiB
[  10/ 292]       blk.0.attn_q.weight - [ 4096,  4096], type = bf16, converting to iq2_k .. size =  32.00 MiB ->  4.75 MiB
[  11/ 292]       blk.0.attn_v.weight - [ 4096,  1024], type = bf16, converting to iq4_k .. size =   8.00 MiB ->  2.25 MiB
[ 273/ 292]   blk.30.attn_norm.weight - [ 4096,     1], type =    f32, size =    0.016 MB
[ 274/ 292]    blk.30.ffn_down.weight - [14336,  4096], type = bf16, converting to iq2_k .. size = 112.00 MiB -> 16.62 MiB
[ 275/ 292]    blk.30.ffn_gate.weight - [ 4096, 14336], type = bf16, converting to iq2_k .. size = 112.00 MiB -> 16.62 MiB
[ 276/ 292]      blk.30.ffn_up.weight - [ 4096, 14336], type = bf16, converting to iq2_k .. size = 112.00 MiB -> 16.62 MiB
[ 277/ 292]    blk.30.ffn_norm.weight - [ 4096,     1], type =    f32, size =    0.016 MB
[ 278/ 292]      blk.30.attn_k.weight - [ 4096,  1024], type = bf16, converting to iq2_k .. size =   8.00 MiB ->  1.19 MiB
[ 279/ 292] blk.30.attn_output.weight - [ 4096,  4096], type = bf16, converting to iq3_k .. size =  32.00 MiB ->  6.88 MiB
[ 280/ 292]      blk.30.attn_q.weight - [ 4096,  4096], type = bf16, converting to iq2_k .. size =  32.00 MiB ->  4.75 MiB
[ 281/ 292]      blk.30.attn_v.weight - [ 4096,  1024], type = bf16, converting to iq4_k .. size =   8.00 MiB ->  2.25 MiB
```
</details>


In [None]:
# quant_type = "IQ2_KS"
quant_type = "IQ2_K"
# quant_type = "IQ2_KL"
# quant_type = "IQ2_M"
# quant_type = "IQ2_XXS"

test_gguf_file = f"models/{model_name}-{quant_type}.gguf"
reader = GGUFReader(test_gguf_file, 'r')
load_tensor = load_gguf_tensor

<details>
<summary>exl3-2.25bpw quant log</summary>

```log
 -- Loading unquantized module: model.layers.0
 -- Captured: model.layers.0
 -- Quantized: model.layers.0.self_attn.q_proj     bpw:  3.00  proxy_err: 0.000374  .  g_sc: 0.724476  [2.98 s]
 -- Quantized: model.layers.0.self_attn.k_proj     bpw:  4.00  proxy_err: 0.000127  .  g_sc: 0.744701  [0.95 s]
 -- Quantized: model.layers.0.self_attn.v_proj     bpw:  4.00  proxy_err: 0.000660  .  g_sc: 0.868692  [0.94 s]
 -- Quantized: model.layers.0.self_attn.o_proj     bpw:  3.00  proxy_err: 0.004458  o  g_sc: 0.945322  [3.00 s]
 -- Quantized: model.layers.0.mlp.up_proj          bpw:  2.00  proxy_err: 0.034301  o  g_sc: 0.998273  [11.61 s]
 -- Quantized: model.layers.0.mlp.gate_proj        bpw:  2.00  proxy_err: 0.028172  o  g_sc: 0.998273  [11.65 s]
 -- Quantized: model.layers.0.mlp.down_proj        bpw:  2.00  proxy_err: 0.023181  o  g_sc: 1.001727  [12.60 s]
 -- Quantized: model.layers.0                      bpw:  2.24  rfn: 0.156698  cos: 0.011249  sqnr: 16.935898  [51.25 s]
```

</details>

<details>
<summary>exl3-2.5bpw quant log</summary>

```log
 -- Loading unquantized module: model.layers.0
 -- Captured: model.layers.0
 -- Quantized: model.layers.0.self_attn.q_proj     bpw:  3.00  proxy_err: 0.000374  .  g_sc: 0.724476  [2.99 s]
 -- Quantized: model.layers.0.self_attn.k_proj     bpw:  4.00  proxy_err: 0.000127  .  g_sc: 0.744701  [0.95 s]
 -- Quantized: model.layers.0.self_attn.v_proj     bpw:  4.00  proxy_err: 0.000660  .  g_sc: 0.868692  [0.94 s]
 -- Quantized: model.layers.0.self_attn.o_proj     bpw:  3.00  proxy_err: 0.004458  o  g_sc: 0.945322  [3.00 s]
 -- Quantized: model.layers.0.mlp.up_proj          bpw:  2.00  proxy_err: 0.034301  o  g_sc: 0.998273  [11.61 s]
 -- Quantized: model.layers.0.mlp.gate_proj        bpw:  2.00  proxy_err: 0.028172  o  g_sc: 0.998273  [11.65 s]
 -- Quantized: model.layers.0.mlp.down_proj        bpw:  3.00  proxy_err: 0.005810  o  g_sc: 0.939732  [10.98 s]
 -- Quantized: model.layers.0                      bpw:  2.51  rfn: 0.124460  cos: 0.006881  sqnr: 19.005521  [49.64 s]
```

</details>



In [None]:
bpw = 2.5
quant_type = f"exl3-{str(bpw)}bpw"

test_model = f"models/{model_name}-{quant_type}"
config = Config.from_directory(test_model)
reader = Model.from_config(config)
load_tensor = load_exl3_tensor
test_model

<details>
<summary>slice the weight tensor</summary>

```python
part0 = 0; stride0 = 512; part1 = 0; stride1 = 512
idx0, idx1 = (slice(stride0*part0, stride0*(part0+1)), slice(stride1*part1, stride1*(part1+1)))
```
</details>

In [None]:
layer = 5
name = f'blk.{layer}.attn_q.weight'
entries, _, _ = load_gguf_imatrix(f'models/{model_name}.imatrix')

bf16_t = load_gguf_tensor(bf16_reader, name)
quant_t = load_tensor(reader, name)

diff_t = bf16_t-quant_t

# display_heatmap_2d(bf16_t[:256, -1024:], "inputC", "outputC", name)
# sns_histplot(torch.sum(torch.abs(bf16_t), dim=0))
bf16_t_collapse = torch.sum(torch.pow(bf16_t, 2), dim=0).flatten().numpy()
sns_histplot(bf16_t_collapse, name=f"{model_name}/{name}")

# display_heatmap_2d(diff_t[:256, :1024], "inputC", "outputC", quant_type+'/'+name)
# sns_histplot(torch.sum(torch.abs(diff_t), dim=0))
# sns_histplot(diff_t)
diff_t_collapse = torch.sum(torch.pow(diff_t, 2), dim=0).flatten().numpy()
sns_histplot(diff_t_collapse, name=f"{model_name}-{quant_type}-residual/{name}")

In [None]:
_, act_entries, _, _ = load_gguf_imatrix(f'models/{model_name}-{quant_type}-IQ2_KS.imatrix', has_acts=True)

In [None]:
acts = act_entries[name]
acts_sq = acts[252] ** 2

In [None]:
wbf16 = acts_sq * bf16_t_collapse
print_channal_import_factor(wbf16, acts_sq.size)

In [None]:
wdiff = acts_sq * diff_t_collapse
print_channal_import_factor(wdiff, acts_sq.size)

In [None]:
compare_resdiual_factor(wbf16, wdiff, acts_sq.size)

In [None]:
def get_top_k_index(a, top_k, cached_index=None):
    """Return the indices of the ``top_k`` largest entries of ``a``.

    Intended for 1-D NumPy arrays.

    Args:
        a: array of scores.
        top_k: number of indices to select. Values <= 0 select nothing.
        cached_index: optional array of reference indices. When given, the
            number of selected indices that are NOT in ``cached_index`` is
            printed (width 4, followed by a space, no newline).

    Returns:
        Ascending index array with exactly ``top_k`` entries. When several
        entries tie at the cut-off value, only as many of them as needed are
        kept, so the result never exceeds ``top_k``.
    """
    if top_k <= 0:
        # -0 == 0 would otherwise address the smallest element and select everything.
        top_index = np.empty(0, dtype=np.intp)
    else:
        # argpartition puts the top_k largest entries in the last top_k slots;
        # sorting restores the ascending order that np.where used to give.
        top_index = np.sort(np.argpartition(a, -top_k)[-top_k:])
    if cached_index is not None:
        print(f'{top_k - np.sum(np.isin(top_index, cached_index)):4d}', end=' ')
    return top_index

In [None]:
# Selection budget: half of the entries of acts_sq.
top_k = acts_sq.size//2
imatrix_index = get_top_k_index(entries[name], top_k)

batch_size = 4
start_idx = 252
# Rows start_idx .. start_idx+batch_size-1 of the activations.
act_batch = acts[start_idx:start_idx+batch_size]
# Reference selection built from the summed activations of the batch.
cached_index = get_top_k_index(np.sum(act_batch, axis=0), top_k)
assert cached_index.size == top_k and imatrix_index.size == top_k

# Each row weighted by the collapsed squared residual.
wacts = act_batch*diff_t_collapse
for test_acts in wacts:
    print("miss channels ", end='')
    # Sweep the selection size in eighths of top_k; every call prints how many picks
    # are missing from the batch-derived selection, then from the imatrix-derived one.
    for eighth in range(1, 9):
        step = eighth*top_k//8
        print(f'|{step}', end=':')
        get_top_k_index(test_acts, step, cached_index)
        get_top_k_index(test_acts, step, imatrix_index)
    print()

In [None]:
# Reference selection rebuilt from the summed residual-weighted rows.
cached_index = get_top_k_index(np.sum(wacts, axis=0), top_k)
assert cached_index.size == top_k

for test_acts in wacts:
    print("miss channels ", end='')
    # Same sweep as the previous cell: selection size grows in eighths of top_k.
    for eighth in range(1, 9):
        step = eighth*top_k//8
        print(f'|{step}', end=':')
        get_top_k_index(test_acts, step, cached_index)
        get_top_k_index(test_acts, step, imatrix_index)
    print()

In [None]:
# the below scheme is obsolete
# NOTE(review): `wq` and `start` are not defined in the neighbouring cells, which use
# `acts_sq` and `start_idx` instead. Presumably these are stale names; confirm before
# re-running this cell on a fresh kernel.
# Extra per-row quota that is only added when batch_size is 4.
addition = top_k//8 if batch_size == 4 else 0

# 0/1 map that is set to 1 at every index selected so far.
act_map = np.zeros(wq.size, dtype=np.uint8)
cached_index = None

for test_acts in acts[start:start+batch_size]:
    # Select this row's top entries (the miss count against the running union is printed
    # once cached_index is not None), then fold them into the union.
    cached_index = get_top_k_index(test_acts, top_k//4+addition, cached_index) 
    act_map[cached_index] = 1
    cached_index = np.where(act_map == 1)[0]
    
# Size of the final union.
print(cached_index.size)

for test_acts in acts[start:start+batch_size]:
    # Sweep the selection size in quarters of top_k and print the misses against the
    # union, then against the imatrix-derived selection.
    for i in range(4):
        step = (i+1)*top_k//4
        print(f'|{step}', end=':')
        get_top_k_index(test_acts, step, cached_index)
        get_top_k_index(test_acts, step, imatrix_index)
    print()

# Inference

[other implementation](https://github.com/nicolaswilde/llama.pytorch/tree/main)

In [None]:
# Names of every tensor exposed by bf16_reader; used below to test for optional tensors.
tensor_names = [tensor.name for tensor in bf16_reader.tensors]

In [None]:
def rms_norm(tensor, norm_weights, norm_eps):
    """RMS-normalise ``tensor`` over its last dimension and scale by ``norm_weights``.

    Args:
        tensor: input tensor; the mean of squares is taken over the last dim.
        norm_weights: scale that must broadcast against ``tensor``.
        norm_eps: value added to the mean square before the reciprocal square root.

    Returns:
        ``tensor * rsqrt(mean(tensor**2, -1) + norm_eps) * norm_weights``.
    """
    mean_square = tensor.pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + norm_eps)
    normalized = tensor * inv_rms
    return normalized * norm_weights

In [None]:
# llama.cpp style
def compute_rope(x, cos, sin):
    """Apply a rotary embedding that rotates adjacent (even, odd) feature pairs.

    Args:
        x: tensor whose last dimension has even length.
        cos, sin: tensors that broadcast against ``x``.

    Returns:
        ``x * cos + rotated * sin`` cast to the dtype of ``x``, where ``rotated``
        maps every neighbouring pair ``(a, b)`` of ``x`` to ``(-b, a)``.
    """
    evens = x[..., ::2]
    odds = x[..., 1::2]
    # Re-interleave as [-x1, x0, -x3, x2, ...].
    rotated = torch.stack((-odds, evens), dim=-1).flatten(start_dim=-2)
    x_rotated = (x * cos) + (rotated * sin)
    return x_rotated.to(dtype=x.dtype)

# Per-dimension frequency for the adjacent-pair layout: dimensions 2j and 2j+1 share exponent j.
theta_scale = 1.0 / rope_theta ** (2.0/n_dims)
freqs = theta_scale ** (torch.arange(n_dims) // 2)

if 'rope_freqs.weight' in tensor_names:
    print("rope_scaling has been found")
    # Each stored factor is duplicated for both members of a pair, then divides the frequencies.
    freq_factors = load_gguf_tensor(gguf_reader, 'rope_freqs.weight').repeat_interleave(2)
    freqs = freqs / freq_factors

In [None]:
# huggingface style
def compute_rope(x, cos, sin):
    """Apply a rotary embedding that pairs element ``i`` with element ``i + half``.

    Args:
        x: tensor whose last dimension is split in two at ``x.shape[-1] // 2``.
        cos, sin: tensors that broadcast against ``x``.

    Returns:
        ``x * cos + rotated * sin`` cast to the dtype of ``x``, where ``rotated``
        is the negated second half followed by the first half.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    rotated = torch.cat((-second, first), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)
    return x_rotated.to(dtype=x.dtype)

# freqs = 1.0 / (rope_theta ** (2*torch.tensor(range(n_dims//2))/n_dims))
# Per-dimension frequency for the half-split layout: exponents 0..n_dims//2-1, listed twice.
theta_scale = 1.0 / rope_theta ** (2.0/n_dims)
freqs = theta_scale ** torch.arange(n_dims // 2).repeat(2)

if 'rope_freqs.weight' in tensor_names:
    print("rope_scaling has been found")
    # The stored factors are tiled once more to cover both halves, then divide the frequencies.
    freq_factors = load_gguf_tensor(gguf_reader, 'rope_freqs.weight').repeat(2)
    freqs = freqs / freq_factors

In [None]:
# Display the tensor `name` as returned by the currently selected reader/loader pair.
load_tensor(reader, name)

## Tokenization

In [None]:
# Collect every ":<digits>" match from the two text files as token ids.
# The files are opened with context managers so the handles are closed promptly.
with open("prefill.in") as prefill_file:
    prefill_tokens = re.findall(r':(\d+)', prefill_file.read())
with open("generate.out") as generate_file:
    generate_tokens = re.findall(r':(\d+)', generate_file.read())

tokens = torch.tensor([int(token) for token in prefill_tokens])
# tokens = torch.tensor([int(token) for token in prefill_tokens+generate_tokens[:4]])
seq_len = len(tokens)

# Rotation angle for every (position, dimension) pair, and its cosine / sine tables.
freqs_for_each_token = torch.outer(torch.arange(seq_len), freqs)
cos = torch.cos(freqs_for_each_token)
sin = torch.sin(freqs_for_each_token)

# Causal mask: -inf strictly above the diagonal, 0 elsewhere.
mask = torch.full((seq_len, seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
tokens

## Input

<details>
<summary>llama.cpp</summary>

```
b ggml_compute_forward
p *tensor # op = GGML_OP_GET_ROWS, type = GGML_TYPE_F32, name = "inp_embd"
tb ggml_compute_forward_get_rows
c
p *dst->src[0] # type = GGML_TYPE_F16, name = "token_embd.weight"
p *dst->src[1] # type = GGML_TYPE_I32, name = "inp_tokens"
p *(unsigned short *)((char *) dst->src[0]->data + 0*dst->src[0]->nb[3] + 0*dst->src[0]->nb[2] + 0*dst->src[0]->nb[1] + 0*dst->src[0]->nb[0])
p *(int *)((char *) dst->src[1]->data + 0*dst->src[1]->nb[3] + 0*dst->src[1]->nb[2] + 0*dst->src[1]->nb[1] + 0*dst->src[1]->nb[0])
fin
p *(float *)((char *) tensor->data + 0*tensor->nb[3] + 0*tensor->nb[2] + 0*tensor->nb[1] + 0*tensor->nb[0])
```
</details>

In [None]:
# Embedding table with vocab_size rows of width hidden_size; its initial weights are
# overwritten in place with the GGUF token-embedding tensor.
embedding_layer = torch.nn.Embedding(vocab_size, hidden_size)
# embedding_layer.weight.data.copy_(load_tensor(reader, 'token_embd.weight'))
embedding_layer.weight.data.copy_(load_gguf_tensor(gguf_reader, 'token_embd.weight'))
# Look up one embedding row per token id.
token_embeddings_unnormalized = embedding_layer(tokens)

In [None]:
# Start the residual stream from a copy so the raw token embeddings stay untouched.
final_embedding = token_embeddings_unnormalized.clone()

# Choose the reader / loader pair used by the layer cells; alternatives are kept commented out.
reader, load_tensor = gguf_reader, load_gguf_tensor
# reader, load_tensor = awq_reader, load_awq_tensor
# reader, load_tensor = gptq_reader, load_gptq_tensor

## Output

In [None]:
# Final RMSNorm. NOTE: this overwrites final_embedding, so re-running the cell normalises
# twice; run it once, after the layer loop in the "Transformer" section.
final_embedding = rms_norm(final_embedding, load_gguf_tensor(gguf_reader, 'output_norm.weight'), norm_eps)

# Use the dedicated output projection when the file has one; otherwise reuse the
# token-embedding matrix (weight tying).
lm_head_name = 'output.weight' if 'output.weight' in tensor_names else 'token_embd.weight'
embed_w = load_gguf_tensor(gguf_reader, lm_head_name)
# Logits for the last position only, then greedy pick.
logits = final_embedding[-1] @ embed_w.T
next_token = logits.argmax(dim=-1)
next_token

## Transformer

In [None]:
n_layer = n_blocks
# n_layer = 10
for layer in range(n_layer):
    # Pre-attention RMSNorm of the residual stream.
    layer_embedding_norm = rms_norm(final_embedding, load_tensor(reader, f'blk.{layer}.attn_norm.weight'), norm_eps)

    q_layer  = load_tensor(reader, f'blk.{layer}.attn_q.weight')
    k_layer  = load_tensor(reader, f'blk.{layer}.attn_k.weight')
    v_layer  = load_tensor(reader, f'blk.{layer}.attn_v.weight')
    w_layer  = load_tensor(reader, f'blk.{layer}.attn_output.weight')
    ffn_down = load_tensor(reader, f'blk.{layer}.ffn_down.weight')
    ffn_gate = load_tensor(reader, f'blk.{layer}.ffn_gate.weight')
    ffn_up   = load_tensor(reader, f'blk.{layer}.ffn_up.weight')

    # Project, then split the last dimension into one chunk per (query or K/V) head.
    q_per_token = torch.chunk(torch.matmul(layer_embedding_norm, q_layer.T), chunks=n_heads, dim=-1)
    k_per_token = torch.chunk(torch.matmul(layer_embedding_norm, k_layer.T), chunks=n_kv_heads, dim=-1)
    v_per_token = torch.chunk(torch.matmul(layer_embedding_norm, v_layer.T), chunks=n_kv_heads, dim=-1)

    qkv_attention_list = []
    for head in range(n_heads):
        # Query heads are mapped onto the (possibly fewer) K/V heads.
        kv_head = head*n_kv_heads//n_heads
        # compute query with location
        q_per_token_rotated = compute_rope(q_per_token[head], cos, sin)
        # compute key with location
        k_per_token_rotated = compute_rope(k_per_token[kv_head], cos, sin)
        # TODO: use kv-cache
        # K and V are rounded to float16 and converted back to float32 before use.
        k_per_token_rotated_cached = k_per_token_rotated.to(torch.float16)
        v_per_token_cached = v_per_token[kv_head].to(torch.float16)
        # compute attention score
        qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated_cached.float().T) / (n_dims)**0.5
        qk_per_token_after_masking = qk_per_token + mask
        qk_per_token_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1)
        # use score to query the value
        qkv_attention = torch.matmul(qk_per_token_after_softmax, v_per_token_cached.float())
        qkv_attention_list.append(qkv_attention)

    # Concatenate the per-head outputs along the feature dimension.
    stacked_qkv_attention = torch.cat(qkv_attention_list, dim=-1)

    embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)
    embedding_after_edit = final_embedding + embedding_delta
    embedding_after_edit_normalized = rms_norm(embedding_after_edit, load_tensor(reader, f'blk.{layer}.ffn_norm.weight'), norm_eps)
    # Gated feed-forward: silu(gate(x)) * up(x), projected back by down.
    # torch.nn.functional is the public home of silu (and is already used for softmax above).
    ffn_hidden = torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, ffn_gate.T)) * torch.matmul(embedding_after_edit_normalized, ffn_up.T)
    embedding_after_ffn = torch.matmul(ffn_hidden, ffn_down.T)
    final_embedding = embedding_after_edit + embedding_after_ffn

### Attention break down

#### rms_norm

<details>
<summary>llama.cpp</summary>

```
p *tensor # type = GGML_TYPE_F32, op = GGML_OP_RMS_NORM, name = "norm-0"
p *tensor->src[0] # type = GGML_TYPE_F32, op = GGML_OP_GET_ROWS, name = "inp_embd"
tb ggml_compute_forward_rms_norm
c
c
p *tensor # type = GGML_TYPE_F32, op = GGML_OP_MUL, name = "attn_norm-0"
p *tensor->src[0] # type = GGML_TYPE_F32, op = GGML_OP_RMS_NORM, name = "norm-0"
p *tensor->src[1] # type = GGML_TYPE_F32, name = "blk.0.attn_norm.weight"
tb ggml_compute_forward_mul
```
</details>

In [None]:
# NOTE(review): the block index used here is `n_layer`. After the full Transformer loop it
# holds n_blocks, so presumably it is meant to be set by hand (e.g. n_layer = 0) before
# stepping through this breakdown; confirm.
layer_embedding_norm = rms_norm(final_embedding,
    load_tensor(reader, f'blk.{n_layer}.attn_norm.weight'), norm_eps)

#### query

<details>
<summary>llama.cpp</summary>

```
p *tensor # type = GGML_TYPE_F32, op = GGML_OP_MUL_MAT, name = "Qcur-0"
p *tensor->src[0] # type = GGML_TYPE_F16, name = "blk.0.attn_q.weight"
p *tensor->src[1] # type = GGML_TYPE_F32, op = GGML_OP_MUL, name = "attn_norm-0"
tb ggml_compute_forward_mul_mat
c
fin
c
p *tensor # type = GGML_TYPE_F32, op = GGML_OP_RESHAPE, name = "Qcur-0 (reshaped)" # for head dimension

========
p *tensor # type = GGML_TYPE_F32, op = GGML_OP_ROPE, name = "Qcur-0"
p *tensor->src[0] # type = GGML_TYPE_F32, op = GGML_OP_RESHAPE, name = "Qcur-0 (reshaped)"
p *tensor->src[1] # type = GGML_TYPE_I32, name = "inp_pos"
tb ggml_compute_forward_rope
```
</details>

In [None]:
# Query projection of the normalised hidden states for block `n_layer`.
q_layer = load_tensor(reader, f'blk.{n_layer}.attn_q.weight')
q_per_token = torch.matmul(layer_embedding_norm, q_layer.T)

In [None]:
# Query head inspected by the step-by-step cells that follow.
head = 0
# Split the projected queries into n_heads chunks along the last dim, then apply RoPE
# to the chosen head.
q_per_token_chunk = torch.chunk(q_per_token, chunks=n_heads, dim=-1)
q_per_token_rotated = compute_rope(q_per_token_chunk[head], cos, sin)

#### key

<details>
<summary>llama.cpp</summary>

```
p *tensor # type = GGML_TYPE_F32, op = GGML_OP_MUL_MAT, name = "Kcur-0"
p *tensor # type = GGML_TYPE_F32, op = GGML_OP_RESHAPE, name = "Kcur-0 (reshaped)"
p *tensor # type = GGML_TYPE_F32, op = GGML_OP_ROPE, name = "Kcur-0"
tb ggml_compute_forward_rope
```
</details>

In [None]:
# Key projection of the normalised hidden states for block `n_layer`.
k_layer = load_tensor(reader, f'blk.{n_layer}.attn_k.weight')
k_per_token = torch.matmul(layer_embedding_norm, k_layer.T)

#### value

<details>
<summary>llama.cpp</summary>

```
p *tensor # type = GGML_TYPE_F32, op = GGML_OP_MUL_MAT, name = "Vcur-0"
p *tensor->src[0] # type = GGML_TYPE_F16, name = "blk.0.attn_v.weight"
p *tensor->src[1] # type = GGML_TYPE_F32, name = "attn_norm-0"
tb ggml_compute_forward_mul_mat
```
</details>

In [None]:
# Value projection of the normalised hidden states for block `n_layer`.
v_layer = load_tensor(reader, f'blk.{n_layer}.attn_v.weight')
v_per_token = torch.matmul(layer_embedding_norm, v_layer.T)

#### KV-cache

<details>
<summary>llama.cpp K cache</summary>

```
p *tensor # type = GGML_TYPE_F16, op = GGML_OP_VIEW, name = "k_cache_view-0"
p *tensor->src[0] # type = GGML_TYPE_F16, name = "cache_k_l0"
p *tensor # type = GGML_TYPE_F16, op = GGML_OP_CPY, name = "k_cache_view-0 (copy of Kcur-0)"
p *tensor->src[0] # type = GGML_TYPE_F32, op = GGML_OP_ROPE, name = "Kcur-0"
p *tensor->src[1] # type = GGML_TYPE_F16, op = GGML_OP_VIEW, name = "k_cache_view-0"
tb ggml_compute_forward_cpy
c # fp32->fp16
fin
```
</details>

<details>
<summary>llama.cpp V cache</summary>

```
p *tensor # type = GGML_TYPE_F32, op = GGML_OP_TRANSPOSE, name = "Vcur-0 (transposed)"
p *tensor->src[0] # type = GGML_TYPE_F32, op = GGML_OP_MUL_MAT, name = "Vcur-0"

p *tensor # type = GGML_TYPE_F16, op = GGML_OP_VIEW, name = "v_cache_view-0"
p *tensor->src[0] # type = GGML_TYPE_F16, name = "cache_v_l0"
p *tensor # type = GGML_TYPE_F16, op = GGML_OP_CPY, name = "v_cache_view-0 (copy of Vcur-0)"
p *tensor->src[0] # type = GGML_TYPE_F32, op = GGML_OP_ROPE, name = "Vcur-0"
p *tensor->src[1] # type = GGML_TYPE_F16, op = GGML_OP_VIEW, name = "v_cache_view-0"
```
</details>

<details>
<summary>llama.cpp prepare and reshape</summary>

```
p *tensor # type = GGML_TYPE_F16, op = GGML_OP_VIEW, name = "v-0"
#  ne = {[0] = 32, [1] = 16, [2] = 4, [3] = 1},
#  nb = {[0] = 2, [1] = 8192, [2] = 131072, [3] = 524288},
p *tensor->src[0] # type = GGML_TYPE_F16, name = "cache_v_l0"

p *tensor # type = GGML_TYPE_F16, op = GGML_OP_VIEW, name = "k-0"
#  ne = {[0] = 16, [1] = 32, [2] = 4, [3] = 1},
#  nb = {[0] = 2, [1] = 128, [2] = 32, [3] = 128},
p *tensor->src[0] # type = GGML_TYPE_F16, name = "cache_k_l0"

p *tensor # type = GGML_TYPE_F32, op = GGML_OP_PERMUTE, name = "q-0"
p *tensor->src[0] # type = GGML_TYPE_F32, op = GGML_OP_ROPE, name = "Qcur-0"
```
</details>

In [None]:
# K and V are split into n_kv_heads chunks, which can be fewer than the n_heads query
# chunks. Map the query-head index to its K/V head with the same formula as the full
# layer loop, instead of indexing the K/V chunks with the query-head index directly.
kv_head = head*n_kv_heads//n_heads

k_per_token_chunk = torch.chunk(k_per_token, chunks=n_kv_heads, dim=-1)
k_per_token_rotated = compute_rope(k_per_token_chunk[kv_head], cos, sin)
# Rounded to float16, matching the F16 cache tensors in the llama.cpp notes above.
k_per_token_rotated_cached = k_per_token_rotated.to(torch.float16)

v_per_token_chunk = torch.chunk(v_per_token, chunks=n_kv_heads, dim=-1)
v_per_token_cached = v_per_token_chunk[kv_head].to(torch.float16)

#### QK attention score

<details>
<summary>llama.cpp</summary>

```
p *tensor # type = GGML_TYPE_F32, op = GGML_OP_MUL_MAT, name = "kq-0"
p *tensor->src[0] # type = GGML_TYPE_F16, op = GGML_OP_VIEW, name = "k-0"
p *tensor->src[1] # type = GGML_TYPE_F32, op = GGML_OP_PERMUTE, name = "q-0"
tb ggml_compute_forward_mul_mat

========

p *tensor # type = GGML_TYPE_F32, op = GGML_OP_SOFT_MAX, name = "kq_soft_max_ext-0"
p *tensor->src[0] # type = GGML_TYPE_F32, op = GGML_OP_MUL_MAT, name = "kq-0"
p *tensor->src[1] # type = GGML_TYPE_F32, name = "KQ_mask"
tb ggml_compute_forward_soft_max

========

p *tensor # type = GGML_TYPE_F32, op = GGML_OP_MUL_MAT, name = "kqv-0"
p *tensor->src[0] # type = GGML_TYPE_F16, op = GGML_OP_VIEW, name = "v-0"
p *tensor->src[1] # type = GGML_TYPE_F32, op = GGML_OP_SOFT_MAX, name = "kq_soft_max_ext-0"
tb ggml_compute_forward_mul_mat
```
</details>

In [None]:
# Scores: Q @ K.T divided by sqrt(n_dims); the float16 keys are converted back to float32 first.
qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated_cached.float().T) / (n_dims)**0.5
# Adding the mask (-inf above the diagonal) removes attention to later positions.
qk_per_token_after_masking = qk_per_token + mask
# Softmax along dim 1, i.e. across the key positions of each query row.
qk_per_token_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1)
# Attention-weighted sum of the (float16 round-tripped) values.
qkv_attention = torch.matmul(qk_per_token_after_softmax, v_per_token_cached.float())

In [None]:
# Plot each intermediate of the single-head attention computed in the previous cell:
# raw scores, masked scores, softmax weights, and the weighted values.
display_heatmap_2d(qk_per_token, 'token', 'token')
display_heatmap_2d(qk_per_token_after_masking, 'token', 'token')
display_heatmap_2d(qk_per_token_after_softmax, 'token', 'token')
display_heatmap_2d(qkv_attention, 'inputC', 'token')

### Attention

<details>
<summary>llama.cpp</summary>

```
p *tensor # type = GGML_TYPE_F32, op = GGML_OP_PERMUTE, name = "kqv_merged-0"
p *tensor->src[0] # type = GGML_TYPE_F32, op = GGML_OP_MUL_MAT, name = "kqv-0"
c
p *tensor # type = GGML_TYPE_F32, op = GGML_OP_CONT, name = "kqv_merged_cont-0"
p *tensor->src[0] # type = GGML_TYPE_F32, op = GGML_OP_PERMUTE, name = "kqv_merged-0"

========

p *tensor # type = GGML_TYPE_F32, op = GGML_OP_MUL_MAT, name = "kqv_out-0"
p *tensor->src[0] # type = GGML_TYPE_F16, op = GGML_OP_MUL_MAT, name = "blk.0.attn_output.weight"
p *tensor->src[1] # type = GGML_TYPE_F32, op = GGML_OP_CONT, name = "kqv_merged_cont-0"
tb ggml_compute_forward_mul_mat

========

p *tensor # type = GGML_TYPE_F32, op = GGML_OP_ADD, name = "ffn_inp-0"
p *tensor->src[0] # type = GGML_TYPE_F32, op = GGML_OP_MUL_MAT, name = "kqv_out-0"
p *tensor->src[1] # type = GGML_TYPE_F32, op = GGML_OP_GET_ROWS, name = "inp_embd"
tb ggml_compute_forward_add
```
</details>

In [None]:
# n_layer = 0
# Attention for the single block `n_layer`, same maths as the body of the full layer loop.
attn_norm_w = load_tensor(reader, f'blk.{n_layer}.attn_norm.weight')
layer_embedding_norm = rms_norm(final_embedding, attn_norm_w, norm_eps)

q_layer = load_tensor(reader, f'blk.{n_layer}.attn_q.weight')
k_layer = load_tensor(reader, f'blk.{n_layer}.attn_k.weight')
v_layer = load_tensor(reader, f'blk.{n_layer}.attn_v.weight')
w_layer = load_tensor(reader, f'blk.{n_layer}.attn_output.weight')

# Projections, then one chunk per query head / per K/V head along the last dim.
q_per_token = layer_embedding_norm @ q_layer.T
k_per_token = layer_embedding_norm @ k_layer.T
v_per_token = layer_embedding_norm @ v_layer.T
q_per_token_chunk = q_per_token.chunk(n_heads, dim=-1)
k_per_token_chunk = k_per_token.chunk(n_kv_heads, dim=-1)
v_per_token_chunk = v_per_token.chunk(n_kv_heads, dim=-1)

attn_scale = n_dims ** 0.5
qkv_attention_list = []
for head in range(n_heads):
    # Query heads are mapped onto the (possibly fewer) K/V heads.
    kv_head = head*n_kv_heads//n_heads
    # Rotary position encoding for this head's queries and keys.
    q_per_token_rotated = compute_rope(q_per_token_chunk[head], cos, sin)
    k_per_token_rotated = compute_rope(k_per_token_chunk[kv_head], cos, sin)
    # TODO: use kv-cache
    # K and V are rounded to float16 and converted back to float32 before use.
    k_per_token_rotated_cached = k_per_token_rotated.to(torch.float16)
    v_per_token_cached = v_per_token_chunk[kv_head].to(torch.float16)
    # Scaled scores, causal mask, softmax across key positions.
    qk_per_token = (q_per_token_rotated @ k_per_token_rotated_cached.float().T) / attn_scale
    qk_per_token_after_masking = qk_per_token + mask
    qk_per_token_after_softmax = torch.softmax(qk_per_token_after_masking, dim=1)
    # Attention-weighted sum of the values.
    qkv_attention = qk_per_token_after_softmax @ v_per_token_cached.float()
    qkv_attention_list.append(qkv_attention)

# Concatenate heads, apply the output projection, add the residual.
stacked_qkv_attention = torch.cat(qkv_attention_list, dim=-1)
################
embedding_delta = stacked_qkv_attention @ w_layer.T
################
embedding_after_edit = final_embedding + embedding_delta

In [None]:
# Plot the concatenated per-head attention output. The title reuses `name`, which was
# set in an earlier cell (presumably the tensor picked in the Dequant analysis; verify).
display_heatmap_2d(stacked_qkv_attention, 'inputC', 'token', f'{name}-imatrix')

In [None]:
# Plot the normalised attention input next to the un-normalised residual stream.
display_heatmap_2d(layer_embedding_norm, "inputC", "token")
display_heatmap_2d(final_embedding, "inputC", "token")

### FFN

<details>
<summary>llama.cpp</summary>

```
p *tensor # type = GGML_TYPE_F32, op = GGML_OP_RMS_NORM, name = "norm-0"
p *tensor->src[0] # type = GGML_TYPE_F32, op = GGML_OP_ADD, name = "ffn_inp-0"
c
p *tensor # type = GGML_TYPE_F32, op = GGML_OP_MUL, name = "ffn_norm-0"
p *tensor->src[0] # type = GGML_TYPE_F32, op = GGML_OP_RMS_NORM, name = "norm-0"
p *tensor->src[1] # type = GGML_TYPE_F32, name = "blk.0.ffn_norm.weight"
tb ggml_compute_forward_mul

========

p *tensor # type = GGML_TYPE_F32, op = GGML_OP_MUL_MAT, name = "ffn_gate-0"
p *tensor->src[0] # type = GGML_TYPE_F16, name = "blk.0.ffn_gate.weight"
p *tensor->src[1] # type = GGML_TYPE_F32, op = GGML_OP_MUL, name = "ffn_norm-0"
tb ggml_compute_forward_mul_mat

========

p *tensor # type = GGML_TYPE_F32, op = GGML_OP_UNARY, name = "ffn_silu-0"
p *tensor->src[0] # type = GGML_TYPE_F32, op = GGML_OP_MUL_MAT, name = "ffn_gate-0"
tb ggml_compute_forward_unary

========

p *tensor # type = GGML_TYPE_F32, op = GGML_OP_MUL_MAT, name = "ffn_up-0"
p *tensor->src[0] # type = GGML_TYPE_F16, name = "blk.0.ffn_up.weight"
p *tensor->src[1] # type = GGML_TYPE_F32, op = GGML_OP_MUL, name = "ffn_norm-0"
tb ggml_compute_forward_mul_mat

========

p *tensor # type = GGML_TYPE_F32, op = GGML_OP_MUL, name = "ffn_gate_par-0"
p *tensor->src[0] # type = GGML_TYPE_F32, op = GGML_OP_UNARY, name = "ffn_silu-0"
p *tensor->src[1] # type = GGML_TYPE_F32, op = GGML_OP_MUL_MAT, name = "ffn_up-0"
tb ggml_compute_forward_mul

========

p *tensor # type = GGML_TYPE_F32, op = GGML_OP_MUL_MAT, name = "ffn_out-0"
p *tensor->src[0] # type = GGML_TYPE_F16, name = "blk.0.ffn_down.weight"
p *tensor->src[1] # type = GGML_TYPE_F32, op = GGML_OP_MUL, name = "ffn_gate_par-0"
tb ggml_compute_forward_mul_mat

========

p *tensor # type = GGML_TYPE_F32, op = GGML_OP_ADD, name = "l_out-0"
p *tensor->src[0] # type = GGML_TYPE_F32, op = GGML_OP_MUL_MAT, name = "ffn_out-0"
p *tensor->src[1] # type = GGML_TYPE_F32, op = GGML_OP_ADD, name = "ffn_inp-0"
tb ggml_compute_forward_add
```
</details>

In [None]:
# Feed-forward part of block `n_layer`, one intermediate per step so that each can be
# compared with the matching llama.cpp tensor listed above.
embedding_after_edit_normalized = rms_norm(embedding_after_edit, load_tensor(reader, f'blk.{n_layer}.ffn_norm.weight'), norm_eps)
################
# Gate projection.
ffn_gate = load_tensor(reader, f'blk.{n_layer}.ffn_gate.weight')
x1 = torch.matmul(embedding_after_edit_normalized, ffn_gate.T)
################
# SiLU via the public torch.nn.functional namespace (not the incidental
# `torch.functional.F` alias).
x2 = torch.nn.functional.silu(x1)
################
# Up projection.
ffn_up   = load_tensor(reader, f'blk.{n_layer}.ffn_up.weight')
x3 = torch.matmul(embedding_after_edit_normalized, ffn_up.T)
################
# Element-wise gating.
x4 = x2 * x3
################
# Down projection.
ffn_down = load_tensor(reader, f'blk.{n_layer}.ffn_down.weight')
output_after_feedforward = torch.matmul(x4 , ffn_down.T)
################
# Residual connection around the feed-forward block.
final_embedding = embedding_after_edit + output_after_feedforward

In [None]:
# Plot the normalised feed-forward input; the title reuses `name` from an earlier cell.
display_heatmap_2d(embedding_after_edit_normalized, "inputC", "token",  f'{name}-imatrix')