In [1]:
# Restrict this process to a single physical GPU. Index 4 is specific to the
# machine this notebook was run on — adjust for your host. It is set before
# `import torch` so it is in effect before any CUDA initialization happens.
%env CUDA_VISIBLE_DEVICES=4

import aqlm
import torch
from torch import nn
from torch.nn import functional as F

env: CUDA_VISIBLE_DEVICES=4


In [2]:
# Reproducible random test data: torch.manual_seed seeds the CPU and CUDA generators.
torch.manual_seed(0)

device = 'cuda'
dtype = torch.float16

# Problem size: `num_inputs` rows multiplied by an (out_features x in_features) weight.
num_inputs = 1
out_features = 8192
in_features = 8192

# Quantization layout: each code selects one (out_group_size x in_group_size)
# entry of a codebook (see the `codebooks` shape below).
out_group_size = 1
in_group_size = 8
# `//` would silently drop a remainder, so require exact division.
assert out_features % out_group_size == 0, "out_features must be divisible by out_group_size"
assert in_features % in_group_size == 0, "in_features must be divisible by in_group_size"
num_out_groups = out_features // out_group_size
num_in_groups = in_features // in_group_size
num_codebooks = 1
codebook_size = 2**16

# Bit width of one code index, derived from codebook_size instead of hard-coding 16,
# so that changing codebook_size keeps the random codes and their packing consistent.
assert codebook_size & (codebook_size - 1) == 0, "codebook_size must be a power of two"
nbits_per_codebook = codebook_size.bit_length() - 1

# NOTE: `input` shadows the Python builtin; the name is kept because later cells use it.
input = torch.rand((num_inputs, in_features), dtype=dtype, device=device)

# Random code indices in [0, codebook_size), then packed by aqlm. The second argument
# of pack_int_data is assumed to be the per-code bit width (the original passed 16
# for a 2**16-entry codebook) — TODO confirm against aqlm.utils.
codes = aqlm.utils.pack_int_data(
    torch.randint(
        0, codebook_size,
        size=(num_out_groups, num_in_groups, num_codebooks),
        dtype=torch.int32, device=device,
    ),
    nbits_per_codebook,
)

# Codebook entries drawn from N(1, 1); torch.empty avoids first filling the tensor
# with uniform values only to overwrite them.
codebooks = torch.empty(
    (num_codebooks, codebook_size, out_group_size, in_group_size),
    dtype=dtype, device=device,
).normal_(1, 1)

# One scale per output group, created directly in float16 (no float32 round-trip).
scales = torch.rand((num_out_groups, 1, 1, 1), dtype=dtype, device=device)

# Dense weight reconstructed from codes/codebooks/scales; used by the F.linear
# baseline. `_dequantize_weight` is underscore-prefixed (not public API) and may
# change between aqlm versions.
weight = aqlm.utils._dequantize_weight(codes, codebooks, scales)
bias = torch.rand((out_features,), dtype=dtype, device=device)

In [3]:
# Run the Triton implementation of the AQLM matmul on the shared test data.
triton_result = aqlm.inference_kernels.triton_kernel.triton_matmul(
    input,
    codes,
    codebooks,
    scales,
    bias,
)

In [4]:
# Run the CUDA implementation of the AQLM matmul on the same inputs.
cuda_result = aqlm.cuda.cuda_kernel.cuda_matmul(
    input,
    codes,
    codebooks,
    scales,
    bias,
)

In [5]:
# Validate the quantized kernels. Comparing them only with each other cannot catch
# a bug they share, so also compare each against a dense reference computed from
# the dequantized weight.
reference_result = F.linear(input, weight, bias)
torch.testing.assert_close(triton_result, cuda_result, rtol=0.005, atol=0.005)
torch.testing.assert_close(triton_result, reference_result, rtol=0.005, atol=0.005)
torch.testing.assert_close(cuda_result, reference_result, rtol=0.005, atol=0.005)

In [6]:
# Warm up the Triton kernel (early calls can pay one-off costs, e.g. JIT
# compilation) so the timed cell measures steady-state latency.
for _warmup_step in range(100):
    _ = aqlm.inference_kernels.triton_kernel.triton_matmul(
        input, codes, codebooks, scales, bias
    )
torch.cuda.synchronize()  # drain all queued launches before timing starts

In [7]:
%%time
# Sequential latency benchmark: 1000 calls. Synchronizing before every launch
# means each call starts only after all previously queued GPU work has finished,
# so (wall time / 1000) approximates per-call latency including launch and
# synchronization overhead, not pipelined throughput.
for i in range(1000):
    torch.cuda.synchronize()
    _ = aqlm.inference_kernels.triton_kernel.triton_matmul(input, codes, codebooks, scales, bias)
# Wait for the last launch so %%time covers all GPU work.
torch.cuda.synchronize()

CPU times: user 142 ms, sys: 97.3 ms, total: 240 ms
Wall time: 238 ms


In [8]:
# Warm up the CUDA kernel so the timed cell measures steady-state latency
# rather than one-off first-call costs.
for _warmup_step in range(100):
    _ = aqlm.cuda.cuda_kernel.cuda_matmul(
        input, codes, codebooks, scales, bias
    )
torch.cuda.synchronize()  # drain all queued launches before timing starts

In [9]:
%%time
# Sequential latency benchmark for the CUDA kernel, measured exactly like the
# Triton cell: syncing before each launch prevents calls from overlapping, so
# (wall time / 1000) approximates per-call latency including launch/sync overhead.
for i in range(1000):
    torch.cuda.synchronize()
    _ = aqlm.cuda.cuda_kernel.cuda_matmul(input, codes, codebooks, scales, bias)
# Wait for the last launch so %%time covers all GPU work.
torch.cuda.synchronize()

CPU times: user 92.1 ms, sys: 114 ms, total: 206 ms
Wall time: 205 ms


In [10]:
# Warm up the dense F.linear baseline (uses the dequantized `weight`) so the
# timed cell measures steady-state latency.
for _warmup_step in range(100):
    _ = F.linear(
        input, weight, bias
    )
torch.cuda.synchronize()  # drain all queued launches before timing starts

In [11]:
%%time
# Dense baseline: F.linear with the dequantized `weight`, timed the same way as
# the quantized kernels (sync before each launch), so (wall time / 1000) is
# directly comparable per-call latency.
for i in range(1000):
    torch.cuda.synchronize()
    _ = F.linear(input, weight, bias)
# Wait for the last launch so %%time covers all GPU work.
torch.cuda.synchronize()

CPU times: user 46.2 ms, sys: 72.7 ms, total: 119 ms
Wall time: 116 ms
