In [1]:
# Pin this notebook to a single physical GPU (index 4). The variable is set before
# `torch` / `aqlm` are imported so it is in place before any CUDA initialisation.
%env CUDA_VISIBLE_DEVICES=4

import aqlm
import torch
from torch import nn
from torch.nn import functional as F

env: CUDA_VISIBLE_DEVICES=4


In [2]:
# Benchmark configuration: `num_inputs` fp16 input rows against an
# out_features x in_features AQLM-quantized layer. Each code indexes one of
# `codebook_size` entries covering a group of out_group_size x in_group_size weights.
torch.manual_seed(0)  # reproducible random inputs / codes / codebooks

num_inputs = 1
out_features = 4096
in_features = 4096
out_group_size = 1
in_group_size = 8
num_codebooks = 1
codebook_size = 2**16

assert out_features % out_group_size == 0, "out_features must be divisible by out_group_size"
assert in_features % in_group_size == 0, "in_features must be divisible by in_group_size"
num_out_groups = out_features // out_group_size
num_in_groups = in_features // in_group_size

# Bits needed to index one codebook entry. Derived from codebook_size (instead of a
# hard-coded 16) so the random codes and the packing width stay consistent if
# codebook_size is changed.
nbits_per_codebook = codebook_size.bit_length() - 1
assert codebook_size == 1 << nbits_per_codebook, "codebook_size must be a power of two"

input = torch.rand((num_inputs, in_features), dtype=torch.float16, device='cuda')
codes = aqlm.utils.pack_int_data(
    torch.randint(
        0, codebook_size,
        size=(num_out_groups, num_in_groups, num_codebooks),
        dtype=torch.int32, device='cuda',
    ),
    nbits_per_codebook,
)
# torch.empty + normal_ samples the same N(1, 1) distribution as before without
# first filling the tensor with uniform values that are immediately overwritten.
codebooks = torch.empty(
    (num_codebooks, codebook_size, out_group_size, in_group_size),
    dtype=torch.float16, device='cuda',
).normal_(1, 1)
scales = torch.rand((num_out_groups, 1, 1, 1), device='cuda').half()

# Dense reconstruction of the quantized layer, used as the F.linear baseline below.
weight = aqlm.utils._dequantize_weight(codes, codebooks, scales)
# Set to e.g. torch.rand((out_features,), dtype=torch.float16, device='cuda') to benchmark with a bias.
bias = None

In [3]:
# One call of the Triton AQLM kernel; its output is compared with the CUDA kernel below.
triton_matmul = aqlm.inference_kernels.triton_kernel.triton_matmul
triton_result = triton_matmul(input, codes, codebooks, scales, bias)

In [4]:
# One call of the CUDA AQLM kernel on the same inputs as the Triton run above.
cuda_matmul = aqlm.cuda.cuda_kernel.cuda_matmul
cuda_result = cuda_matmul(input, codes, codebooks, scales, bias)

In [5]:
# Both kernels must agree element-wise to within fp16-level tolerance
# (relative and absolute tolerance both 5e-3); raises AssertionError otherwise.
torch.testing.assert_close(triton_result, cuda_result, atol=5e-3, rtol=5e-3)

In [6]:
# Warm-up: 100 untimed calls so first-call costs (presumably Triton compilation and
# allocator setup) are excluded from the timed cell; the final synchronize drains
# the GPU queue before timing starts.
triton_matmul = aqlm.inference_kernels.triton_kernel.triton_matmul
for _ in range(100):
    _ = triton_matmul(input, codes, codebooks, scales, bias)
torch.cuda.synchronize()

In [7]:
%%time
# Serialized latency of 1000 Triton AQLM matmuls. Synchronizing before every call
# means each launch starts on an idle GPU, so Wall time / 1000 approximates the
# per-call latency (including launch and sync overhead), not pipelined throughput.
for i in range(1000):
    torch.cuda.synchronize()
    _ = aqlm.inference_kernels.triton_kernel.triton_matmul(input, codes, codebooks, scales, bias)
torch.cuda.synchronize()  # wait for the last call so %%time covers all GPU work

CPU times: user 84.6 ms, sys: 45.2 ms, total: 130 ms
Wall time: 128 ms


In [8]:
# Warm-up: 100 untimed calls of the CUDA AQLM kernel so first-call costs are excluded
# from the timed cell; the final synchronize drains the GPU queue before timing starts.
cuda_matmul = aqlm.cuda.cuda_kernel.cuda_matmul
for _ in range(100):
    _ = cuda_matmul(input, codes, codebooks, scales, bias)
torch.cuda.synchronize()

In [9]:
%%time
# Serialized latency of 1000 CUDA AQLM matmuls. Synchronizing before every call
# means each launch starts on an idle GPU, so Wall time / 1000 approximates the
# per-call latency (including launch and sync overhead), not pipelined throughput.
for i in range(1000):
    torch.cuda.synchronize()
    _ = aqlm.cuda.cuda_kernel.cuda_matmul(input, codes, codebooks, scales, bias)
torch.cuda.synchronize()  # wait for the last call so %%time covers all GPU work

CPU times: user 95.9 ms, sys: 6.64 ms, total: 103 ms
Wall time: 101 ms


In [10]:
# Warm-up for the dense fp16 baseline (F.linear on the dequantized weight):
# 100 untimed calls, then drain the GPU queue before the timed cell.
for _ in range(100):
    _ = F.linear(
        input,
        weight,
        bias,
    )
torch.cuda.synchronize()

In [11]:
%%time
# Serialized latency of 1000 dense fp16 F.linear calls on the dequantized weight
# (the uncompressed baseline). Synchronizing before every call makes
# Wall time / 1000 approximate per-call latency including launch and sync overhead.
for i in range(1000):
    torch.cuda.synchronize()
    _ = F.linear(input, weight, bias)
torch.cuda.synchronize()  # wait for the last call so %%time covers all GPU work

CPU times: user 29.9 ms, sys: 21.8 ms, total: 51.6 ms
Wall time: 50.5 ms


In [12]:
import os
import sys

# quip-sharp is imported from a local checkout; its repo root must be on sys.path
# so that `model.llama` resolves. Override the location with QUIP_SHARP_PATH
# instead of editing this hard-coded, machine-specific default.
quip_sharp_path = os.environ.get("QUIP_SHARP_PATH", "/home/blacksamorez/quip-sharp")
if quip_sharp_path not in sys.path:  # don't grow sys.path on every re-run of this cell
    sys.path.append(quip_sharp_path)
from model.llama import LlamaForCausalLM

llama = LlamaForCausalLM.from_pretrained("relaxml/Llama-2-7b-E8P-2Bit")
# Benchmark a single QuIP# layer: the first decoder layer's attention output projection,
# moved to the GPU on its own.
quip_linear = llama.model.layers[0].self_attn.o_proj.cuda()
print(quip_linear.in_features, quip_linear.out_features)
# The timing comparison with the AQLM / dense cells is only meaningful at the same shape.
assert (quip_linear.in_features, quip_linear.out_features) == (in_features, out_features), (
    "QuIP# layer shape does not match the benchmark shape"
)

Using `is_flash_attn_available` is deprecated and will be removed in v4.38. Please use `is_flash_attn_2_available` instead.
I0130 00:40:47.090702 598054 utils.py:145] Note: detected 255 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
I0130 00:40:47.092173 598054 utils.py:148] Note: NumExpr detected 255 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
I0130 00:40:47.093024 598054 utils.py:160] NumExpr defaulting to 8 threads.
I0130 00:40:47.257316 598054 config.py:58] PyTorch version 2.1.2 available.


4096 4096


In [13]:
# Warm-up for the QuIP# layer: 100 untimed forward calls on the same fp16 input,
# then drain the GPU queue before the timed cell.
for _ in range(100):
    _ = quip_linear(
        input
    )
torch.cuda.synchronize()

In [14]:
%%time
# Serialized latency of 1000 QuIP# o_proj forward calls. Synchronizing before every
# call makes Wall time / 1000 approximate per-call latency including launch and
# sync overhead, directly comparable with the AQLM and dense timing cells.
for i in range(1000):
    torch.cuda.synchronize()
    _ = quip_linear(input)
torch.cuda.synchronize()  # wait for the last call so %%time covers all GPU work

CPU times: user 321 ms, sys: 1 s, total: 1.32 s
Wall time: 1.37 s
