In [1]:
%env OMP_NUM_THREADS=1
%env MKL_NUM_THREADS=1
import torch
torch.set_num_threads(1)
import torch
import numpy as np

from src.aq import _dequantize_weight

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

env: OMP_NUM_THREADS=1
env: MKL_NUM_THREADS=1


In [2]:
out_features = 4096
in_features = 4096
out_group_size = 1
in_group_size = 32
out_group_size = 1
num_codebooks = 4
nbits_per_codebook = 8
dtype = torch.float32

num_input_groups = in_features // in_group_size
codebooks = torch.randn(num_codebooks, 2**nbits_per_codebook, out_group_size, in_group_size, dtype=dtype)
codes = torch.randint(0, 2**nbits_per_codebook, 
                  size=(out_features, num_input_groups, num_codebooks), dtype=torch.uint8)
scales = torch.rand(out_features, 1, 1, 1, dtype=dtype)
x = torch.randn(1, in_features, dtype=dtype)/ 100

w = _dequantize_weight(codes, codebooks, scales)


In [3]:
%%time
for i in range(100):
    y = x @ w.T

CPU times: user 446 ms, sys: 129 µs, total: 447 ms
Wall time: 444 ms


In [4]:
x_numpy = x.numpy()
w_numpy = w.numpy()

In [5]:
%%time
for i in range(100):
    x_numpy @ w_numpy.T

CPU times: user 437 ms, sys: 13 µs, total: 437 ms
Wall time: 437 ms


In [6]:
codes_alt = torch.permute(codes, (1, 0, 2)).contiguous()  #  [num_in_groups, num_out_groups, num_codebooks]
# x, codebooks, codes_alt, scales, w = x.numpy(), codebooks.numpy(), codes_alt.numpy(), scales.numpy(), w.numpy()

In [30]:
import numba
@numba.njit(nopython=True, parallel=False)
def aqlm_gemv_lut(x, codebooks, codes_alt, scales, in_group_size: int):
    lut = x.reshape(-1, in_group_size) @ codebooks.reshape(-1, in_group_size).T
    lut = lut.reshape(-1, num_codebooks, 2**nbits_per_codebook)
    
    output_vec = np.zeros(out_features, dtype=x.dtype)
    for j in range(num_input_groups):
        for i in range(out_features):
            for c in range(num_codebooks):
                output_vec[i] += lut[j, c, codes_alt[j, i, c]]
    output_vec *= scales.flatten()
    return output_vec

# _ = aqlm_gemv_lut(x, codebooks, codes_alt, scales)
_ = aqlm_gemv_lut(x.numpy(), codebooks.numpy(), torch.permute(codes, (1, 0, 2)).contiguous().numpy(), scales.numpy(), in_group_size)



In [31]:
%%time
for i in range(100):
    # aqlm_gemv_lut(x, codebooks, codes_alt, scales)
    aqlm_gemv_lut(x.numpy(), codebooks.numpy(), codes.numpy(), scales.numpy(), in_group_size)

CPU times: user 110 ms, sys: 0 ns, total: 110 ms
Wall time: 109 ms


In [None]:
aqlm_gemv_lut(x.numpy(), codebooks.numpy(), torch.permute(codes, (1, 0, 2)).contiguous().numpy(), scales.numpy())

array([-2.81641752e-01,  1.00420463e+00, -8.17836404e-01,  2.20113322e-01,
        3.20841908e-01,  7.90404499e-01,  1.88450608e-02,  3.29190046e-02,
        2.43881062e-01,  4.57227618e-01,  2.08502576e-01,  1.17066443e+00,
       -1.50245607e-01,  5.11013091e-01,  1.17427751e-01,  1.93431258e-01,
        1.66756463e+00, -1.48235768e-01, -8.21565628e-01, -2.23852754e+00,
        8.29873443e-01,  3.76370281e-01,  7.36643136e-01, -3.63634154e-02,
        5.99597931e-01,  1.42398670e-01, -5.82772315e-01, -3.35145175e-01,
       -1.25073993e+00, -3.20138186e-01,  2.45220947e+00, -2.49513894e-01,
        8.57899427e-01,  2.57896245e-01, -5.26075475e-02, -1.07554531e+00,
       -1.08862269e+00, -1.40455818e+00, -1.82120653e-03, -1.09655166e+00,
        6.54247105e-01,  6.17258549e-02, -1.26817403e-02, -1.65564388e-01,
        1.38416708e+00, -7.27667734e-02,  1.82184052e+00, -3.75634968e-01,
        4.39066291e-02,  3.37525159e-01, -1.12479830e+00,  1.00341213e+00,
        3.44569683e-02, -

In [None]:
x @ w.T

tensor([[-0.2816,  1.0042, -0.8178,  ..., -0.2378, -1.2057,  0.1868]])