In [1]:
import sys

import torch

from aqlm import QuantizedLinear
from aqlm.utils import _dequantize_weight, unpack_int_data

In [2]:
# Build the AQLM ExecuTorch kernels via the repo's build script. The correctness
# check loads a shared library from ./cmake-out/, presumably produced here — confirm.
!bash build.sh

-- Using python executable 'python'
-- Resolved buck2 as /Users/blacksamorez/reps/AQLM/inference_lib/src/aqlm/inference_kernels/executorch/buck2-bin/buck2-99773fe6f7963a72ae5f7b737c02836e.
-- Killing buck2 daemon
-- executorch: Generating source lists
-- executorch: Using source file list /Users/blacksamorez/reps/AQLM/inference_lib/src/aqlm/inference_kernels/executorch/executorch_srcs.cmake
-- executorch: Using sources file /Users/blacksamorez/reps/AQLM/inference_lib/src/aqlm/inference_kernels/executorch/executorch_srcs.cmake
-- Proceeding with version: 24.3.25.0
-- CMAKE_CXX_FLAGS: 
-- Generating operator lib:
--   LIB_NAME: portable_ops_lib
--   OPS_SCHEMA_YAML: /Users/blacksamorez/reps/executorch/kernels/portable/functions.yaml
--   ROOT_OPS: 
--   INCLUDE_ALL_OPS: 
[0mCommand - python;-m;codegen.tools.gen_oplist;--output_path=/Users/blacksamorez/reps/AQLM/inference_lib/src/aqlm/inference_kernels/executorch/kernels/portable/portable_ops_lib/selected_operators.yaml;--ops_schema_yaml

In [3]:
SIZE = 1024

# A 2x8 AQLM layer: 2 codebooks of 2**8 entries, groups of 8 input features.
layer = QuantizedLinear(
    in_features=SIZE,
    out_features=SIZE * 3,
    in_group_size=8,
    out_group_size=1,
    num_codebooks=2,
    nbits_per_codebook=8,
    bias=True,
)

# NOTE(review): QuantizedLinear appears to allocate its parameters without
# initializing them (they are normally loaded from a checkpoint) — confirm.
# Checking the kernel against uninitialized memory is non-reproducible and can
# feed NaN/inf into assert_close, so fill every parameter with seeded values.
torch.manual_seed(0)
with torch.no_grad():
    # Cover the full range of the stored code dtype; randint's upper bound is exclusive.
    codes_info = torch.iinfo(layer.codes.dtype)
    layer.codes.copy_(
        torch.randint(codes_info.min, codes_info.max + 1, layer.codes.shape)
    )
    layer.codebooks.normal_()
    layer.scales.uniform_(0.5, 1.5)
    layer.bias.normal_()

In [4]:
# Dense reference weight: unpack the stored codes as 8-bit values, then
# dequantize them through the layer's codebooks and scales.
unpacked_codes = unpack_int_data(layer.codes, 8)
reference_weight = _dequantize_weight(unpacked_codes, layer.codebooks, layer.scales)

In [5]:
import torch

# Seed the RNG so the activations (and hence the correctness check and the
# benchmarks that use them) are reproducible under Restart & Run All.
torch.manual_seed(0)

# Random activations uniformly distributed in [-1, 1), shape (3, 2, SIZE).
# NOTE: `input` shadows the Python builtin; the name is kept because other
# cells refer to it.
input = torch.rand((3, 2, SIZE)) * 2 - 1

In [6]:
import torch

# Load the custom-op library built by build.sh. The shared-library suffix is
# chosen per platform instead of hard-coding macOS's `.dylib`.
# NOTE(review): assumes the build emits libaqlm_bindings.so on Linux and .dll
# on Windows — confirm.
_LIB_SUFFIX = {"darwin": ".dylib", "win32": ".dll"}.get(sys.platform, ".so")
torch.ops.load_library(f"./cmake-out/libaqlm_bindings{_LIB_SUFFIX}")

# Reference: dense matmul against the dequantized weight, plus bias.
reference = input @ reference_weight.T + layer.bias

# The fused op receives the codes with their first two dimensions swapped
# (and made contiguous), presumably the layout the kernel expects — confirm.
test = torch.ops.aqlm.code2x8_lut_matmat(
    input,
    torch.permute(layer.codes, (1, 0, 2)).contiguous(),
    layer.codebooks,
    layer.scales,
    layer.bias,
)

torch.testing.assert_close(
    test,
    reference,
    atol=0.01,
    rtol=1e-3,
)

In [7]:
%%time

for i in range(10):
    input @ reference_weight.T

CPU times: user 21.6 ms, sys: 17 ms, total: 38.6 ms
Wall time: 10.4 ms


In [8]:
%%time

for i in range(10):
    torch.ops.aqlm.code2x8_lut_matmat(
        input,
        torch.permute(layer.codes, (1, 0, 2)).contiguous(),
        layer.codebooks,
        layer.scales,
        layer.bias,
    )

CPU times: user 857 ms, sys: 239 ms, total: 1.1 s
Wall time: 201 ms


In [17]:
import torch
from torch.export import export, ExportedProgram, Dim
from executorch.exir import EdgeCompileConfig, to_edge

# One eager forward pass before exporting; its output is discarded.
_ = layer(input)

# The first two dimensions of `input` are dynamic (each >= 1). The last
# dimension is not listed, so it stays fixed at the example input's size.
batch_size, seq_len = Dim("batch_size", min=1), Dim("seq_len", min=1)
dynamic_shapes = dict(input={0: batch_size, 1: seq_len})

# Export to an ATen-dialect program without recording autograd.
with torch.no_grad():
    aten_dialect = export(layer, (input,), dynamic_shapes=dynamic_shapes)

# Lower to the Edge dialect, then to an ExecuTorch program. IR validity
# checking is disabled, presumably because the graph contains the custom
# aqlm op rather than only core ATen ops — confirm.
edge_manager = to_edge(
    aten_dialect,
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
et_program = edge_manager.to_executorch()

# Write the serialized program bytes to disk.
with open("aqlm.pte", "wb") as file:
    file.write(et_program.buffer)

