In [1]:
from aqlm import QuantizedLinear
from aqlm.utils import _dequantize_weight, unpack_int_data

In [2]:
# !cd /Users/blacksamorez/reps/executorch && bash build.sh

In [3]:
# Hidden size shared by the quantized layer and the random test input.
SIZE = 1024

# Layer configuration gathered in one mapping so it is easy to scan:
# 2 codebooks of 8 bits each, input groups of 8, output groups of 1, no bias.
_layer_kwargs = {
    "in_features": SIZE,
    "out_features": SIZE * 3,
    "in_group_size": 8,
    "out_group_size": 1,
    "num_codebooks": 2,
    "nbits_per_codebook": 8,
    "bias": False,
}
layer = QuantizedLinear(**_layer_kwargs)

In [4]:
# Unpack the stored integer codes first (the 8 is presumably the bit width,
# matching nbits_per_codebook=8 used to build the layer — TODO confirm), then
# rebuild a dense weight from codes, codebooks and scales. This dense weight
# is the ground truth for the custom kernel checked later.
codes_unpacked = unpack_int_data(layer.codes, 8)
reference_weight = _dequantize_weight(codes_unpacked, layer.codebooks, layer.scales)

In [5]:
import torch

# Fix the RNG seed so the random test input — and every comparison that
# depends on it — is identical on each Restart & Run All.
torch.manual_seed(0)

# Uniform values in [-1, 1) with shape (3, 2, SIZE).
# NOTE: this name shadows the built-in `input()`; it is kept because the
# following cells refer to it.
input = torch.rand((3, 2, SIZE)) * 2 - 1

In [6]:
import torch

# Load the compiled bindings; the `aqlm` operator called below is expected to
# be registered by this shared library.
torch.ops.load_library("./cmake-out/libaqlm_bindings.dylib")

# Dense baseline: plain matmul against the dequantized weight, plus the bias
# when the layer has one (0 otherwise).
bias_term = layer.bias if layer.bias is not None else 0
reference = input @ reference_weight.T + bias_term

# Custom kernel under test. It is fed the codes with their first two
# dimensions swapped and made contiguous.
test = torch.ops.aqlm.code2x8_lut_matmat(
    input,
    layer.codes.permute(1, 0, 2).contiguous(),
    layer.codebooks,
    layer.scales,
    bias=layer.bias,
)

# The kernel output must agree with the dense baseline within tolerance.
torch.testing.assert_close(test, reference, atol=0.01, rtol=1e-3)

In [7]:
%%time

# Baseline timing: 10 dense matmuls of the test input against the
# dequantized reference weight. The products are discarded; only the
# elapsed time reported by %%time is of interest.
for i in range(10):
    input @ reference_weight.T

CPU times: user 22.7 ms, sys: 11.9 ms, total: 34.6 ms
Wall time: 6.31 ms


In [8]:
%%time

for i in range(10):
    torch.ops.aqlm.code2x8_lut_matmat(
        input,
        torch.permute(layer.codes, (1, 0, 2)).contiguous(),
        layer.codebooks,
        layer.scales,
        bias=layer.bias,
    )

CPU times: user 996 ms, sys: 127 ms, total: 1.12 s
Wall time: 193 ms


In [9]:
import torch
from torch.export import export, ExportedProgram, Dim
from executorch.exir import EdgeCompileConfig, to_edge

# Run the layer once eagerly before exporting; the output is discarded.
_ = layer(input)

# Keep the first two input dimensions symbolic (each at least 1) so the
# exported program is not specialised to the (3, 2, SIZE) example input.
batch_size = Dim("batch_size", min=1)
seq_len = Dim("seq_len", min=1)
dynamic_shapes = {
    "input": {
        0: batch_size,
        1: seq_len,
    },
}

# Capture the layer as an exported program without tracking gradients.
with torch.no_grad():
    aten_dialect = export(layer, (input,), dynamic_shapes=dynamic_shapes)

# Lower to the edge dialect with IR validity checking switched off, then
# convert to an ExecuTorch program.
edge_config = EdgeCompileConfig(_check_ir_validity=False)
edge_manager = to_edge(aten_dialect, compile_config=edge_config)
et_program = edge_manager.to_executorch()

# Write the serialized program next to the notebook.
with open("aqlm.pte", "wb") as pte_file:
    pte_file.write(et_program.buffer)

  @impl_abstract("quantized_decomposed::embedding_byte.out")
  @impl_abstract("quantized_decomposed::embedding_byte.dtype_out")
  @impl_abstract("quantized_decomposed::embedding_4bit.out")
  @impl_abstract("quantized_decomposed::embedding_4bit.dtype_out")


In [1]:
import torch
from safetensors.torch import load_file

# Checkpoint directory: the source safetensors file is read from here and the
# converted checkpoint is written back next to it.
MODEL_DIR = "/Users/blacksamorez/models/Llama-2-7b-AQLM-2Bit-2x8-hf"

# Geometry used by the q/k row re-ordering below. Presumably the number of
# attention heads and the per-head dimension of the model (32 * 128 = 4096
# rows, matching the shapes printed by this cell) — TODO confirm.
N_HEADS = 32
HEAD_DIM = 128

# Named `hf_state_dict` rather than `dict` so the built-in `dict` type is not
# shadowed for the rest of the kernel session.
hf_state_dict = load_file(f"{MODEL_DIR}/model.safetensors")

# Substring renames applied, in insertion order, to every checkpoint key.
mapping = {
    "model.": "",

    "self_attn.q_proj": "attention.wq",
    "self_attn.k_proj": "attention.wk",
    "self_attn.v_proj": "attention.wv",
    "self_attn.o_proj": "attention.wo",

    "mlp.up_proj": "feed_forward.w3",
    "mlp.gate_proj": "feed_forward.w1",
    "mlp.down_proj": "feed_forward.w2",

    "input_layernorm": "attention_norm",
    "post_attention_layernorm": "ffn_norm",

    "lm_head": "output",
    "embed_tokens": "tok_embeddings",
}


new_dict = {}

for key, value in hf_state_dict.items():
    # Rewrite the key by applying every rename in turn.
    for old, new in mapping.items():
        key = key.replace(old, new)

    # For the q/k projections, the rows are viewed as (N_HEADS, 2, HEAD_DIM // 2)
    # and the two inner axes are swapped, i.e. the two halves of each head are
    # interleaved. Presumably this converts between the rotary-embedding row
    # layouts of the source and target checkpoint formats — TODO confirm.
    if "attention.wq.codes" in key or "attention.wk.codes" in key:
        # [num_out_groups, num_in_groups, num_codebooks]
        print(f"Transposing codes {key} {value.shape=}")
        value = (
            value.reshape(N_HEADS, 2, HEAD_DIM // 2, -1, 2)
            .transpose(1, 2)
            .reshape(HEAD_DIM * N_HEADS, -1, 2)
        )

    # The per-row scales must follow the same row re-ordering as the codes.
    if "attention.wq.scales" in key or "attention.wk.scales" in key:
        # [num_out_groups, 1, 1, 1]
        print(f"Transposing scales {key} {value.shape=}")
        value = (
            value.reshape(N_HEADS, 2, HEAD_DIM // 2, 1)
            .transpose(1, 2)
            .reshape(HEAD_DIM * N_HEADS, 1, 1, 1)
        )

    new_dict[key] = value

# Dropping the output head / token embeddings is left disabled.
# del new_dict["output.weight"]
# del new_dict["tok_embeddings.weight"]

torch.save(new_dict, f"{MODEL_DIR}/executorch.pth")

Transposing codes layers.0.attention.wk.codes value.shape=torch.Size([4096, 512, 2])
Transposing scales layers.0.attention.wk.scales value.shape=torch.Size([4096, 1, 1, 1])
Transposing codes layers.0.attention.wq.codes value.shape=torch.Size([4096, 512, 2])
Transposing scales layers.0.attention.wq.scales value.shape=torch.Size([4096, 1, 1, 1])
Transposing codes layers.1.attention.wk.codes value.shape=torch.Size([4096, 512, 2])
Transposing scales layers.1.attention.wk.scales value.shape=torch.Size([4096, 1, 1, 1])
Transposing codes layers.1.attention.wq.codes value.shape=torch.Size([4096, 512, 2])
Transposing scales layers.1.attention.wq.scales value.shape=torch.Size([4096, 1, 1, 1])
Transposing codes layers.10.attention.wk.codes value.shape=torch.Size([4096, 512, 2])
Transposing scales layers.10.attention.wk.scales value.shape=torch.Size([4096, 1, 1, 1])
Transposing codes layers.10.attention.wq.codes value.shape=torch.Size([4096, 512, 2])
Transposing scales layers.10.attention.wq.scale