In [1]:
import torch
import awq_inference_engine


# Seed the torch RNGs so the random weights/activations below are identical
# after every "Restart Kernel -> Run All".
torch.manual_seed(0)

# Benchmark configuration.
device     = torch.device("cuda")
dtype      = torch.float16
# Batch size, sequence length, and the input/output widths of the linear layer.
BS, SL, IN_CH, OUT_CH = 4, 4096, 4096, 11008
bit_width  = 4
group_size = 128

# Signed integer range representable with `bit_width` bits.
qmax = 2**(bit_width-1) - 1
qmin = -2**(bit_width-1)

# Random fp16 weight of shape (OUT_CH, IN_CH) and activations of shape (BS, SL, IN_CH).
W_fp16 = torch.randn(OUT_CH, IN_CH, dtype=dtype, device=device)
X_fp16 = torch.randn(BS, SL, IN_CH, dtype=dtype, device=device)

# # --- 2) MANUAL INT4 QUANTIZATION OF W_fp16 ------------------------------
# W_int = W_fp16.reshape(-1, group_size)
# w_scale = torch.maximum(
#     W_int.amin(dim=0) / qmin, 
#     W_int.amax(dim=0) / qmax, 
# )
# W_int /= w_scale
# W_int = torch.clamp(W_int, qmin, qmax).round_()
# W_int = W_int.reshape_as(W_fp16)
# W_int


In [2]:
def pack_intweight(unpacked_qweight, interleave, kstride):
    """Reorder a 2-D tensor of integer weight codes and pack four of them per int16.

    Args:
        unpacked_qweight: integer tensor indexed as ``[N, K]``. ``K`` must be a
            multiple of 32 and of ``kstride``; ``N`` must be a multiple of
            ``interleave``. The bit shifts below require an integer dtype.
        interleave: number of consecutive rows grouped together before packing.
            The packing step reads lanes 0..3 of the last axis, so it is
            written for ``interleave == 4``.
        kstride: width of the column tiles used when grouping rows.

    Returns:
        A contiguous ``torch.int16`` tensor of shape ``(N // interleave, K)``
        on the same device as ``unpacked_qweight``.
    """
    n_rows = unpacked_qweight.shape[0]
    n_cols = unpacked_qweight.shape[1]
    n_chunks = n_cols // 32
    row_groups = n_rows // interleave
    col_tiles = n_cols // kstride

    # All shuffling is done with NumPy on the host.
    codes = unpacked_qweight.cpu().numpy().reshape(n_rows, n_chunks, 32)

    # Within each run of 32 columns, view them as (4, 4, 2) and swap the two
    # leading axes: order becomes [0, 1, 8, 9, 16, 17, 24, 25, 2, 3, ...].
    codes = codes.reshape(n_rows, n_chunks, 4, 4, 2).transpose(0, 1, 3, 2, 4)
    codes = codes.reshape(n_rows, n_chunks, 32)

    # Within each run of 8 columns, put even positions first, then odd ones:
    # [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
    # (original comment: "reorder each 8 weights for fast dequantization").
    codes = codes.reshape(n_rows, n_chunks, 4, 4, 2).transpose(0, 1, 2, 4, 3)
    codes = codes.reshape(n_rows, n_cols)

    # Group `interleave` rows and tile the columns by `kstride`, move the tile
    # axis in front of the row-in-group axis, then reinterpret each
    # (interleave, kstride) tile as (kstride, interleave).
    codes = codes.reshape(row_groups, interleave, col_tiles, kstride)
    codes = codes.transpose(0, 2, 1, 3)
    codes = codes.reshape(row_groups, col_tiles, kstride, interleave)

    # Pack lanes 0..3 of the last axis into one word, 4 bits per lane, lane 0
    # in the least significant nibble.
    packed = codes[..., 0]
    for lane in (1, 2, 3):
        packed = packed | (codes[..., lane] << (4 * lane))

    # Flatten to (N // interleave, K) and store the 16-bit words as int16.
    packed = packed.reshape(row_groups, n_cols)
    as_int16 = torch.tensor(packed.astype("int16"))
    return as_int16.to(unpacked_qweight.device).contiguous()


def pseudo_quantize_tensor(
    w, n_bit=8, zero_point=True, q_group_size=-1, inplace=False, get_scale_zp=False
):
    """Fake-quantize ``w``: round it onto an ``n_bit`` integer grid and scale back.

    Args:
        w: floating-point tensor. With ``q_group_size > 0`` its last dimension
            must be divisible by the group size and it is processed as rows of
            ``q_group_size`` values; otherwise it must already be 2-D and each
            row is one quantization group.
        n_bit: number of bits of the integer grid.
        zero_point: if True, use an asymmetric grid ``[0, 2**n_bit - 1]`` with
            a per-group zero point; if False, use a symmetric grid
            ``[-2**(n_bit-1), 2**(n_bit-1) - 1]`` with a zero point of 0.
        q_group_size: group size along the last dimension, or a non-positive
            value to quantize per row.
        inplace: if True, overwrite the values of the (possibly reshaped) ``w``
            with in-place ops instead of allocating a new tensor.
        get_scale_zp: if True, also return the scales and zero points.

    Returns:
        The de-quantized tensor in the original shape of ``w``. With
        ``get_scale_zp=True``, a tuple ``(w, scales, zeros)`` where ``scales``
        and ``zeros`` are viewed as ``(w.shape[0], -1)``.

    Raises:
        AssertionError: if the shape constraints are violated or NaNs appear.
    """
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)
    assert w.dim() == 2
    if zero_point:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2**n_bit - 1
        min_int = 0
        # Clamp the range so constant groups do not produce a zero scale.
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    else:
        # Symmetric quantization. The previous `assert min_val is None` read a
        # local that is never assigned on this path and always raised
        # UnboundLocalError, so it has been removed.
        max_val = w.abs().amax(dim=1, keepdim=True)
        max_val = max_val.clamp(min=1e-5)
        max_int = 2 ** (n_bit - 1) - 1
        min_int = -(2 ** (n_bit - 1))
        scales = max_val / max_int
        # A tensor (not the int 0) so that `zeros.view(...)` works below.
        zeros = torch.zeros_like(scales)

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    if inplace:
        (
            (w.div_(scales).round_().add_(zeros)).clamp_(min_int, max_int).sub_(zeros)
        ).mul_(scales)
    else:
        w = (
            torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
        ) * scales
    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)

    if get_scale_zp:
        return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
    else:
        return w
    

def make_divisible(c, divisor):
    """Return ``(c + divisor - 1) // divisor``.

    For a non-negative integer ``c`` and a positive integer ``divisor`` this is
    the ceiling of ``c / divisor``, i.e. the number of ``divisor``-sized chunks
    needed to hold ``c`` items (despite the name, it does not round ``c`` up to
    a multiple of ``divisor``).
    """
    padded = c + divisor - 1
    return padded // divisor

def calculate_zeros_width(in_features, group_size=128, pack_num=8):
    """Compute the padded width used for the scales / zero-points buffers.

    The number of groups (``in_features // group_size``) is divided by
    ``pack_num`` with ``make_divisible`` and the result is then rounded up to a
    multiple of a size multiplier chosen from the group size: 1 for
    ``group_size >= 128``, 2 for 64 and 4 for 32.

    Raises:
        NotImplementedError: for any other group size.
    """
    if group_size >= 128:
        multiplier = 1
    else:
        # Supported small group sizes and their multipliers.
        for supported_size, factor in ((64, 2), (32, 4)):
            if group_size == supported_size:
                multiplier = factor
                break
        else:
            raise NotImplementedError

    n_groups = in_features // group_size
    width = make_divisible(n_groups, pack_num)
    return make_divisible(width, multiplier) * multiplier

In [3]:
# Fake-quantize the fp16 weight group-wise with an asymmetric (zero-point)
# grid, keeping the per-group scales and zero points for the packing step.
Wq_fp16, scales, zeros = pseudo_quantize_tensor(
    W_fp16,
    n_bit=bit_width,
    zero_point=True,
    q_group_size=group_size,
    get_scale_zp=True,
)

In [16]:
# Build the three tensors passed to the AWQ GEMM kernel from the fake-quantized
# weight: packed integer codes, per-group scales, and per-group scaled zeros.
scale_zeros = zeros * scales
dtype = scales.dtype

pack_num = 32 // bit_width
n_groups = IN_CH // group_size

# Scales copied into a zero-initialised buffer whose width comes from
# calculate_zeros_width; the transposed copy is what the kernel receives.
qscales = torch.zeros(
    (
        scales.shape[0],
        calculate_zeros_width(IN_CH, group_size) * pack_num,
    ),
    dtype=dtype,
    device=scales.device,
)
qscales[:, : scales.shape[1]] = scales
prepared_scales = qscales.transpose(1, 0).contiguous()

# Recover the integer codes: round((w_q + zero * scale) / scale) per element.
# One broadcast over (rows, groups, group_size) replaces a Python loop over all
# IN_CH columns followed by torch.cat; the elementwise arithmetic is the same.
grouped_weight = Wq_fp16[:, :IN_CH].reshape(Wq_fp16.shape[0], n_groups, group_size)
intweight = (
    torch.round(
        (grouped_weight + scale_zeros[:, :n_groups, None])
        / qscales[:, :n_groups, None]
    )
    .to(torch.int)
    .reshape(Wq_fp16.shape[0], IN_CH)
)

qweight = pack_intweight(
    intweight.contiguous(), interleave=4, kstride=64
)

# Scaled zeros: -(scale * zero_point), computed in float32, cast back to
# `dtype`, stored in a buffer shaped like `qscales` and transposed the same way.
# NOTE: `zeros` is rebound to int32 here, as in the original cell.
zeros = zeros.to(dtype=torch.int32)
scaled_zeros = torch.zeros_like(qscales)
scaled_zeros[:, : scales.shape[1]] = -(
    qscales[:, : scales.shape[1]] * (zeros.to(torch.float32))
).to(dtype)
scaled_zeros = scaled_zeros.transpose(1, 0).contiguous()

In [24]:
%%timeit -n 100 -r 7
out = awq_inference_engine.gemm_forward_cuda_new(
                X_fp16, qweight, prepared_scales, scaled_zeros
            )

54.8 ms ± 19.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [23]:
%%timeit -n 100 -r 7
out = torch.nn.functional.linear(input=X_fp16, weight=W_fp16)

The slowest run took 3612.18 times longer than the fastest. This could mean that an intermediate result is being cached.
39.5 ms ± 24.3 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
