In [1]:
# Environment for the whole notebook: expose only GPU index 7 to CUDA and
# request a single OpenMP / MKL thread.
%env CUDA_VISIBLE_DEVICES=7
%env OMP_NUM_THREADS=1
%env MKL_NUM_THREADS=1
import torch
# NOTE(review): torch is asked for 16 intra-op threads although OMP/MKL are set
# to 1 above — confirm which thread count the CPU benchmarks are meant to use.
torch.set_num_threads(16)
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import numpy as np
# Add the parent directory to the import path (presumably so that `src.aq`
# below resolves when the notebook is run from its own folder — confirm).
sys.path.append('..')

# Disable TF32 for CUDA matmuls and cuDNN.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

import triton
# Fail early if the installed triton is not a 2.1.0 build.
assert triton.__version__.startswith('2.1.0'), f"found triton {triton.__version__}, want 2.1.0*"
import torch
import triton.language as tl
from triton.ops.matmul import get_configs_io_bound, early_config_prune, estimate_matmul_time
from src.aq import _dequantize_weight

env: CUDA_VISIBLE_DEVICES=7
env: OMP_NUM_THREADS=1
env: MKL_NUM_THREADS=1


In [2]:
# Dimensions of the quantized linear layer used throughout the notebook.
out_features = 8192
in_features = 8192
in_group_size = 16       # input features covered by one code
out_group_size = 1       # output features covered by one code
num_codebooks = 4
nbits_per_codebook = 8   # each codebook holds 2 ** nbits_per_codebook entries
dtype = torch.float16
device = torch.device('cuda')

# One positive (log-normal) scale per output group.
scales = torch.exp(torch.randn(out_features // out_group_size, 1, 1, 1)).to(device=device, dtype=dtype)

# Random codes, one per (output group, input group, codebook); drawn as int32
# on the CPU and then stored as uint8 on the target device.
codes = torch.randint(0, 2 ** nbits_per_codebook,
                      size=(out_features // out_group_size, in_features // in_group_size, num_codebooks),
                      dtype=torch.int32)
codes = codes.to(dtype=torch.uint8, device=device)

# Random codebooks of shape
# (num_codebooks, codebook_size, out_group_size, in_group_size), scaled down by 100.
codebooks = torch.randn(
    num_codebooks, 2 ** nbits_per_codebook, out_group_size, in_group_size, dtype=dtype, device=device) / 100

# Single input row; uses the shared `device` variable like every other tensor
# in this cell instead of a hard-coded 'cuda' string.
input_vec = torch.randn(1, in_features, device=device, dtype=dtype)

In [3]:
import torch

# CPU-side float32 data for the benchmarks below.
# The random draws happen in the order C, b, x, w.
num_input_groups = in_features // in_group_size
dtype = torch.float32

# Codebooks: (num_codebooks, codebook_size, in_group_size).
C = torch.randn((num_codebooks, 2**nbits_per_codebook, in_group_size))

# Codes: (out_features, num_input_groups, num_codebooks), stored as uint8.
b = torch.randint(
    low=0,
    high=2**nbits_per_codebook,
    size=(out_features, num_input_groups, num_codebooks),
    dtype=torch.uint8,
)

# Single input row, scaled down by 100.
x = torch.randn((1, in_features)) / 100

# Dense weight matrix used as the unquantized baseline.
w = torch.randn((out_features, in_features), dtype=dtype)


In [4]:
%%time
# Baseline: 100 dense CPU products of the single-row input with the full
# weight matrix via F.linear (no bias is passed).
for i in range(100):
    y = F.linear(x, w)

CPU times: user 2.43 s, sys: 5.23 ms, total: 2.43 s
Wall time: 2.43 s


In [5]:
%%time
# Same baseline written as an explicit matmul against the transposed weights;
# 100 iterations, matching the F.linear timing cell.
for i in range(100):
    y = x @ w.T

CPU times: user 2.4 s, sys: 2.77 ms, total: 2.4 s
Wall time: 2.4 s


In [6]:
# NumPy counterparts of the torch input and weight tensors for the NumPy baseline.
x_numpy, w_numpy = x.numpy(), w.numpy()

In [7]:
%%time
# NumPy baseline for the same product.
# NOTE: this loop runs 1000 iterations, 10x more than the torch timing cells
# (100 each), so divide by the iteration count before comparing wall times.
for i in range(1000):
    y = x_numpy @ w_numpy.T

CPU times: user 24.3 s, sys: 7.57 ms, total: 24.3 s
Wall time: 24.3 s


In [8]:
# Re-lay the codes as (num_input_groups, num_codebooks, out_features) and make
# the result contiguous, so the output-feature axis is the innermost one.
b_alt = b.permute(1, 2, 0).contiguous()
b_alt.shape

torch.Size([512, 4, 8192])

In [9]:
import numba
@numba.njit(parallel=False)
def aqlm_gemv_lut(x, C, b_alt):
    """Matrix-vector product with an additively quantized weight via a lookup table.

    For every input group ``j`` and codebook ``c`` the dot product between the
    ``j``-th slice of ``x`` and every codebook entry is precomputed once
    (``lut[j, c, k]``); the output is then accumulated by table lookups only:

        out[i] = sum over j, c of lut[j, c, b_alt[j, c, i]]

    All sizes are taken from the argument shapes, so the function does not
    depend on module-level globals (which Numba would freeze at compile time).

    Args:
        x: input values; reshaped to ``(num_input_groups, in_group_size)``.
        C: codebooks of shape ``(num_codebooks, codebook_size, in_group_size)``.
        b_alt: integer codes of shape
            ``(num_input_groups, num_codebooks, out_features)``.

    Returns:
        1-D array of length ``out_features`` with the dtype of ``x``.
    """
    num_codebooks, codebook_size, in_group_size = C.shape
    num_input_groups = b_alt.shape[0]
    out_features = b_alt.shape[2]

    # (num_input_groups, num_codebooks * codebook_size): every input group
    # against every entry of every codebook.
    flat_lut = x.reshape(-1, in_group_size) @ C.reshape(-1, in_group_size).T
    lut = flat_lut.reshape(-1, num_codebooks, codebook_size)

    # One accumulator per OUTPUT feature (the original sized this buffer with
    # in_features, which is only correct for square layers).
    output_vec = np.zeros(out_features, dtype=x.dtype)
    for j in range(num_input_groups):
        for c in range(num_codebooks):
            # Innermost loop walks the contiguous last axis of b_alt.
            for i in range(out_features):
                output_vec[i] += lut[j, c, b_alt[j, c, i]]
    return output_vec



In [10]:
# Switch the benchmark inputs from torch tensors to NumPy arrays.
# np.asarray accepts both CPU torch tensors and ndarrays, so this cell can be
# re-run safely (calling .numpy() on an ndarray would raise AttributeError).
x, C, b = (np.asarray(t) for t in (x, C, b))
b_alt = np.copy(b_alt)
# Warm-up call: triggers Numba JIT compilation so that it is not included in
# the timing loop that follows.
aqlm_gemv_lut(x, C, b_alt);

In [11]:
%%time
# Time 1000 calls of the Numba lookup-table kernel; assumes the function has
# already been called once so JIT compilation is not part of the measurement.
for i in range(1000):
    aqlm_gemv_lut(x, C, b_alt)

CPU times: user 3min 19s, sys: 68 ms, total: 3min 19s
Wall time: 13 s


In [12]:
import cpp_kernel

RECOMPILING


In [13]:
@numba.njit(parallel=False)
def get_lut(x, C):
    """Build the lookup table of dot products between input groups and codebook entries.

    Sizes are read from ``C.shape`` rather than from module-level globals
    (which Numba would freeze at compile time).

    Args:
        x: input values; reshaped to ``(num_input_groups, in_group_size)``.
        C: codebooks of shape ``(num_codebooks, codebook_size, in_group_size)``.

    Returns:
        Array of shape ``(num_input_groups, num_codebooks, codebook_size)``
        where entry ``[j, c, k]`` is the dot product of input group ``j`` with
        entry ``k`` of codebook ``c``.
    """
    num_codebooks, codebook_size, in_group_size = C.shape
    flat_lut = x.reshape(-1, in_group_size) @ C.reshape(-1, in_group_size).T
    return flat_lut.reshape(-1, num_codebooks, codebook_size)

def aqlm_gemv_lut_cpp(x, C, b_alt):
    """Lookup-table matrix-vector product with the accumulation done by ``cpp_kernel``.

    The table is built by ``get_lut(x, C)`` and handed, together with the
    codes ``b_alt``, to ``cpp_kernel.triple_for``.

    NOTE(review): the trailing argument ``1`` is forwarded as-is; its meaning
    is defined by ``cpp_kernel`` and is not visible here (possibly a thread
    count — confirm).
    """
    return cpp_kernel.triple_for(get_lut(x, C), b_alt, 1)



In [14]:
%%time
# Time 1000 calls of the C++-backed lookup-table kernel.
# NOTE(review): get_lut is JIT-compiled and is not called before this cell, so
# the first iteration presumably includes its compilation time — consider a
# warm-up call in a preceding cell.
for i in range(1000):
    aqlm_gemv_lut_cpp(x, C, b_alt)

CPU times: user 3min 41s, sys: 314 ms, total: 3min 42s
Wall time: 18.7 s


In [15]:
# Sanity check of the code layout: (num_input_groups, num_codebooks, out_features).
b_alt.shape

(512, 4, 8192)