In [1]:
import torch

import os
# os.environ['TRITON_INTERPRET'] = '1'
import triton
import triton.language as tl
# import math

DEVICE = triton.runtime.driver.active.get_active_torch_device()

if torch.cuda.is_available():
    device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    
    print(f"Device: {props.name}")
    print(f"Compute Capability: {props.major}.{props.minor}")
    print(f"Total Memory: {props.total_memory / 1024**3:.2f} GB")
    # print(f"CUDA Cores: {props.multi_processor_count}")

Device: NVIDIA GeForce RTX 5070 Ti
Compute Capability: 12.0
Total Memory: 15.45 GB


In [2]:
def add_speedup_columns(df):
    df_result = df.copy()
    base_column = 'torch_fp16'
    
    new_column_order = []
    
    for column in df.columns:
        new_column_order.append(column)
        if column not in ['K', 'M', 'N', base_column]:
            speedup_column = f"{column}_speedup"
            df_result[speedup_column] = df[column] / df[base_column]
            new_column_order.append(speedup_column)
    df_result = df_result[new_column_order]
    return df_result

In [3]:
def get_cuda_autotune_config_small_bs():
    configs = []
    for num_warps, num_stages in [
        (4, 2),
        (4, 3),
        (4, 4),
    ]:
        for BLOCK_SIZE_M in [16]:
            for BLOCK_SIZE_N in [16, 32, 64, 128]:
                for BLOCK_SIZE_K in [128, 256, 512]:
                    configs.append(
                        triton.Config(
                            {
                                "BLOCK_SIZE_M" : BLOCK_SIZE_M,
                                "BLOCK_SIZE_N" : BLOCK_SIZE_N,
                                "BLOCK_SIZE_K" : BLOCK_SIZE_K,
                            },
                            num_stages=num_stages, 
                            num_warps=num_warps
                        ),
                    )
    return configs


@triton.jit()
def get_pid_point_grouped(
        pid,
        M, N,
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        GROUP_SIZE: tl.constexpr
    ):
    
    grid_m = tl.cdiv(M, BLOCK_SIZE_M)
    grid_n = tl.cdiv(N, BLOCK_SIZE_N)

    width = GROUP_SIZE * grid_n
    group_id = pid // width
    group_size = tl.minimum(grid_m - group_id * group_m, group_m)

    pid_m = group_id * group_m + (pid % group_size)
    pid_n = (pid % width) // group_size

    return pid_m, pid_n


@triton.jit()
def get_pid_point_base(
        pid,
        M, N,
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
    ):
    grid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = 0
    pid_n = pid % grid_n
    return pid_m, pid_n


@triton.jit()
def get_pid_point_swizzle(
        pid,
        M, N,
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
    ):
    pid = tl.program_id(axis=0)
    pid_m = 0
    block_id = pid // 64
    in_block = pid % 64
    s_block = in_block % 16
    s_pos = in_block // 16
    pid_n = 64 * block_id + (64 // 16) * s_block + s_pos
    
    return pid_m, pid_n



In [4]:
@triton.autotune(
    configs=get_cuda_autotune_config_small_bs(),
    key=['M', 'N', 'K'],
)
@triton.jit
def kernel_w4t_a16_matmul_small(
        a_ptr, b_ptr, c_ptr, scale_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
):
    pid = tl.program_id(axis=0)
    pid_m, pid_n = get_pid_point_base(
        pid,
        M, N,
        BLOCK_SIZE_M,
        BLOCK_SIZE_N
    )

    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)
    tl.assume(stride_am > 0)
    tl.assume(stride_ak > 0)
    tl.assume(stride_bn > 0)
    tl.assume(stride_bk > 0)
    tl.assume(stride_cm > 0)
    tl.assume(stride_cn > 0)

    scale = tl.load(scale_ptr)

    offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    
    offs_bk = tl.arange(0, BLOCK_SIZE_K // 8)
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    b_ptrs = b_ptr + ((offs_bk[:, None]) * stride_bk + offs_bn[None, :] * stride_bn)

    shifter = tl.arange(0, 8) * 4

    accumulator_dtype = tl.float32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)

    a_mask0 = (offs_am[:, None] < M) & (offs_k[None, :] < K)
    b_mask0 = (offs_bn[None, :] < N) & ((offs_bk[:, None]) < K // 8)
    a = tl.load(a_ptrs, mask=a_mask0, other=0.0, eviction_policy="evict_last")
    b_bits = tl.load(b_ptrs, mask=b_mask0, other=0)
    for k_idx in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        next_k_offset = (k_idx + 1) * BLOCK_SIZE_K
        
        if k_idx + 1 < tl.cdiv(K, BLOCK_SIZE_K):
            a_next_mask = (offs_am[:, None] < M) & (offs_k[None, :] + next_k_offset < K)
            b_next_mask = (offs_bn[None, :] < N) & ((offs_bk[:, None] + (next_k_offset // 8)) < K // 8)
            a_next = tl.load(a_ptrs + next_k_offset * stride_ak, mask=a_next_mask, other=0.0, eviction_policy="evict_last")
            b_bits_next = tl.load(b_ptrs + (next_k_offset // 8) * stride_bk, mask=b_next_mask, other=0)
        else:
            a_next = tl.zeros_like(a)
            b_bits_next = tl.zeros_like(b_bits)

        b = (b_bits[:, None, :] >> shifter[None, :, None]) & 0xF
        b = tl.reshape(b, (BLOCK_SIZE_K, BLOCK_SIZE_N))
        b = (b.to(tl.float16) - 7.5)
        
        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

        a = a_next
        b_bits = b_bits_next

    accumulator = accumulator * scale

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator.to(tl.float16), mask=c_mask)


def triton_matmul_w4a16_tensor(a, b, scale):
    # Check constraints.
    assert a.shape[1] == b.shape[0] * 8, "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"

    assert a.dtype == torch.float16
    assert b.dtype == torch.int32
    
    M, K = a.shape
    N = b.shape[1]
    
    assert scale.dtype == torch.float16
    assert len(scale.shape) == 1
    assert scale.shape[0] == 1
    
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    kernel_w4t_a16_matmul_small[grid](
        a, b, c, scale, #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
    )
    return c


In [5]:
@triton.autotune(
    configs=get_cuda_autotune_config_small_bs(),
    key=['M', 'N', 'K'],
)
@triton.jit
def kernel_w4t_a16_matmul_small_test2(
        a_ptr, b_ptr, c_ptr, scale_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
):
    pid = tl.program_id(axis=0)
    pid_m, pid_n = get_pid_point_base(
        pid,
        M, N,
        BLOCK_SIZE_M,
        BLOCK_SIZE_N
    )

    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)
    tl.assume(stride_am > 0)
    tl.assume(stride_ak > 0)
    tl.assume(stride_bn > 0)
    tl.assume(stride_bk > 0)
    tl.assume(stride_cm > 0)
    tl.assume(stride_cn > 0)

    scale = tl.load(scale_ptr)

    offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    
    offs_bk = tl.arange(0, BLOCK_SIZE_K)
    offs_bn = (pid_n * BLOCK_SIZE_N // 8 + tl.arange(0, BLOCK_SIZE_N // 8)) % (N // 8)
    b_ptrs = b_ptr + ((offs_bk[:, None]) * stride_bk + offs_bn[None, :] * stride_bn)

    shifter = tl.arange(0, 8) * 4

    accumulator_dtype = tl.float32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)

    a_mask0 = (offs_am[:, None] < M) & (offs_k[None, :] < K)
    b_mask0 = (offs_bn[None, :] < (N // 8)) & ((offs_bk[:, None]) < K)
    a = tl.load(a_ptrs, mask=a_mask0, other=0.0, eviction_policy="evict_last")
    b_bits = tl.load(b_ptrs, mask=b_mask0, other=0)
    for k_idx in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        next_k_offset = (k_idx + 1) * BLOCK_SIZE_K
        
        if k_idx + 1 < tl.cdiv(K, BLOCK_SIZE_K):
            a_next_mask = (offs_am[:, None] < M) & (offs_k[None, :] + next_k_offset < K)
            b_next_mask = (offs_bn[None, :] < N // 8) & ((offs_bk[:, None] + (next_k_offset)) < K)
            a_next = tl.load(a_ptrs + next_k_offset * stride_ak, mask=a_next_mask, other=0.0, eviction_policy="evict_last")
            b_bits_next = tl.load(b_ptrs + next_k_offset * stride_bk, mask=b_next_mask, other=0)
        else:
            a_next = tl.zeros_like(a)
            b_bits_next = tl.zeros_like(b_bits)

        b = (b_bits[:, :, None] >> shifter[None, None, :]) & 0xF
        b = tl.reshape(b, (BLOCK_SIZE_K, BLOCK_SIZE_N))
        b = (b.to(tl.float16) - 7.5)
        
        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

        a = a_next
        b_bits = b_bits_next

    accumulator = accumulator * scale

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator.to(tl.float16), mask=c_mask)


def triton_matmul_w4a16_tensor_test2(a, b, scale):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"

    assert a.dtype == torch.float16
    assert b.dtype == torch.int32
    
    M, K = a.shape
    N = b.shape[1] * 8
    
    assert scale.dtype == torch.float16
    assert len(scale.shape) == 1
    assert scale.shape[0] == 1
    
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    kernel_w4t_a16_matmul_small_test2[grid](
        a, b, c, scale, #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
    )
    return c


In [6]:
@triton.autotune(
    configs=get_cuda_autotune_config_small_bs(),
    key=['M', 'N', 'K'],
)
@triton.jit
def kernel_w4t_a16_matmul_small_test3(
        a_ptr, b_ptr, c_ptr, scale_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bn, stride_bk,
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
):
    pid = tl.program_id(axis=0)
    pid_m, pid_n = get_pid_point_base(
        pid,
        M, N,
        BLOCK_SIZE_M,
        BLOCK_SIZE_N
    )

    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)
    tl.assume(stride_am > 0)
    tl.assume(stride_ak > 0)
    tl.assume(stride_bn > 0)
    tl.assume(stride_bk > 0)
    tl.assume(stride_cm > 0)
    tl.assume(stride_cn > 0)

    scale = tl.load(scale_ptr)

    offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    
    offs_bk = tl.arange(0, BLOCK_SIZE_K // 8)
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    # b_ptrs = b_ptr + ((offs_bk[:, None]) * stride_bk + offs_bn[None, :] * stride_bn)
    b_ptrs = b_ptr + (offs_bn[:, None] * stride_bn + (offs_bk[None, :]) * stride_bk)

    shifter = tl.arange(0, 8) * 4

    accumulator_dtype = tl.float32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)

    a_mask0 = (offs_am[:, None] < M) & (offs_k[None, :] < K)
    b_mask0 = (offs_bn[:, None] < N) & ((offs_bk[None, :]) < K // 8)
    a = tl.load(a_ptrs, mask=a_mask0, other=0.0, eviction_policy="evict_last")
    b_bits = tl.load(b_ptrs, mask=b_mask0, other=0)
    for k_idx in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        next_k_offset = (k_idx + 1) * BLOCK_SIZE_K
        
        if k_idx + 1 < tl.cdiv(K, BLOCK_SIZE_K):
            a_next_mask = (offs_am[:, None] < M) & (offs_k[None, :] + next_k_offset < K)
            b_next_mask = (offs_bn[:, None] < N) & ((offs_bk[None, :] + (next_k_offset // 8)) < K // 8)
            a_next = tl.load(a_ptrs + next_k_offset * stride_ak, mask=a_next_mask, other=0.0, eviction_policy="evict_last")
            b_bits_next = tl.load(b_ptrs + (next_k_offset // 8) * stride_bk, mask=b_next_mask, other=0)
        else:
            a_next = tl.zeros_like(a)
            b_bits_next = tl.zeros_like(b_bits)

        b = (b_bits[:, None, :] >> shifter[None, :, None]) & 0xF
        b = tl.reshape(b, (BLOCK_SIZE_N, BLOCK_SIZE_K)).trans(1, 0)
        b = (b.to(tl.float16) - 7.5)
        
        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

        a = a_next
        b_bits = b_bits_next

    accumulator = accumulator * scale

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator.to(tl.float16), mask=c_mask)


def triton_matmul_w4a16_tensor_test3(a, b, scale):
    # Check constraints.
    assert a.shape[1] == b.shape[1] * 8, "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"

    assert a.dtype == torch.float16
    assert b.dtype == torch.int32
    
    M, K = a.shape
    N = b.shape[0]
    
    assert scale.dtype == torch.float16
    assert len(scale.shape) == 1
    assert scale.shape[0] == 1
    
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    kernel_w4t_a16_matmul_small_test3[grid](
        a, b, c, scale, #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
    )
    return c


In [7]:
@torch.no_grad()
def deepseek_dist(x, y):
    x, y = x.double(), y.double()
    denom = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denom
    return 1 - sim


def run_w4a16_tensor_test(M, N, K, print_dist=False):
    y_fp16 = torch.randn(M, K, dtype=torch.float16, device="cuda") / (M * K)
    x_compressed = torch.randint(-2**31, 2**31, (K // 8, N), dtype=torch.int32, device="cuda")
    tensor_scale = torch.randn(1, dtype=torch.float16, device="cuda").abs()

    # triton
    triton_out = triton_matmul_w4a16_tensor(y_fp16, x_compressed, tensor_scale)

    # torch
    shifter = (torch.arange(0, 8) * 4).cuda()
    x_decompressed = (x_compressed[:, None, :] >> shifter[None, :, None]) & 0xF
    x_decompressed = x_decompressed.reshape(K, N).to(torch.float16) - 7.5
    torch_out = (y_fp16 @ x_decompressed) * tensor_scale

    # results
    dist = deepseek_dist(triton_out, torch_out)
    
    if print_dist:
        print(f"[M x K x N]: [{M} x {K} x {N}], dist = {dist}")
        
    assert dist < 0.001


def run_w4a16_tensor_test2_test(M, N, K, print_dist=False):
    y_fp16 = torch.randn(M, K, dtype=torch.float16, device="cuda") / (M * K)
    x_compressed = torch.randint(-2**31, 2**31, (K, N // 8), dtype=torch.int32, device="cuda")
    tensor_scale = torch.randn(1, dtype=torch.float16, device="cuda").abs()

    # triton
    triton_out = triton_matmul_w4a16_tensor_test2(y_fp16, x_compressed, tensor_scale)

    # torch
    shifter = (torch.arange(0, 8) * 4).cuda()
    x_decompressed = (x_compressed[:, :, None] >> shifter[None, None, :]) & 0xF
    x_decompressed = x_decompressed.reshape(K, N).to(torch.float16) - 7.5
    torch_out = (y_fp16 @ x_decompressed) * tensor_scale

    # results
    dist = deepseek_dist(triton_out, torch_out)
    
    if print_dist:
        print(f"[M x K x N]: [{M} x {K} x {N}], dist = {dist}")
        
    assert dist < 0.001


def run_w4a16_tensor_test3_test(M, N, K, print_dist=False):
    y_fp16 = torch.randn(M, K, dtype=torch.float16, device="cuda") / (M * K)
    x_compressed = torch.randint(-2**31, 2**31, (N, K // 8), dtype=torch.int32, device="cuda")
    tensor_scale = torch.randn(1, dtype=torch.float16, device="cuda").abs()

    # triton
    triton_out = triton_matmul_w4a16_tensor_test3(y_fp16, x_compressed, tensor_scale)

    # torch
    shifter = (torch.arange(0, 8) * 4).cuda()
    x_decompressed = (x_compressed[:, :, None] >> shifter[None, None, :]) & 0xF
    x_decompressed = x_decompressed.reshape(K, N).to(torch.float16) - 7.5
    torch_out = (y_fp16 @ x_decompressed) * tensor_scale

    # results
    dist = deepseek_dist(triton_out, torch_out)
    
    if print_dist:
        print(f"[M x K x N]: [{M} x {K} x {N}], dist = {dist}")
        
    assert dist < 0.001

In [8]:
BS = 1
sizes = [2**11, 2**12, 2**13, 2**14]
# sizes = [2**11,] #, 2**12, 2**13, 2**14]

BSs = [1, 16, 32, 128, 1024]
size = 4096

llama_sizes = [
    (8, 11008, 4096),
    (8, 4096, 11008),
    (8, 4096, 4096),
]

experiments = [
            "torch_fp16",
            # "triton_w4a16_tensor",
            # "triton_w4a16_tensor_test2",
            "triton_w4a16_tensor_test3",
            ]

configs = []
configs.append(
    triton.testing.Benchmark(
        # x_names=["M", "K", "N"],
        # x_vals=llama_sizes,

        x_names=["K", "M", "N"],  # Argument names to use as an x-axis for the plot
        x_vals=[(size, BS, size) for size in sizes],

        # x_names=["M", "K", "N"],  # Argument names to use as an x-axis for the plot
        # x_vals=[(BS, size, size) for BS in BSs],
        # x_log=True,

        line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
        line_vals=experiments,
        line_names=experiments,
        ylabel="TFLOPS",  # Label name for the y-axis
        xlabel="Matrix size",
        plot_name="matmul-performance",  # Name for the plot, used also as a file name for saving the plot.
        args={},
    ))


@triton.testing.perf_report(configs)
def benchmark(M, K, N, provider):
    y_fp16 = torch.randn(M, K, dtype=torch.float16, device="cuda") / (M * K)
    x_fp16 = torch.randn(K, N, dtype=torch.float16, device="cuda") / (M * K)
    
    # x_compressed_4bit = torch.randint(-2**31, 2**31, (K // 8, N), dtype=torch.int32, device="cuda")
    x_compressed_4bit = torch.randint(-2**31, 2**31, (N, K // 8), dtype=torch.int32, device="cuda").T
    
    # x_compressed_4bit_test2 = torch.randint(-2**31, 2**31, (K, N // 8), dtype=torch.int32, device="cuda")
    x_compressed_4bit_test2 = torch.randint(-2**31, 2**31, (N // 8, K), dtype=torch.int32, device="cuda").T
    
    # x_compressed_4bit_test3 = torch.randint(-2**31, 2**31, (N, K // 8), dtype=torch.int32, device="cuda")
    x_compressed_4bit_test3 = torch.randint(-2**31, 2**31, (K // 8, N), dtype=torch.int32, device="cuda").T
    
    
    # x_compressed_4bit_test4 = torch.randint(-2**31, 2**31, (N // 8, K), dtype=torch.int32, device="cuda")


    
    tensor_scale = torch.randn(1, dtype=torch.float16, device="cuda").abs()
    # channel_scale = torch.randn(N, dtype=torch.float16, device="cuda")
    # group256_scale = torch.randn((K // 256, N), dtype=torch.float16, device="cuda")
    
    quantiles = [0.5, 0.2, 0.8]
    if provider == "torch_fp16":
        print(f"\n[M x K x N]: [{M} x {K} x {N}]")
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(y_fp16, x_fp16), quantiles=quantiles)
    
    if provider == "triton_w4a16_tensor":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_matmul_w4a16_tensor(y_fp16, x_compressed_4bit, tensor_scale), quantiles=quantiles)
        print("matmul_kernel_w4a16_tensor:", kernel_w4t_a16_matmul_small.best_config)

    if provider == "triton_w4a16_tensor_test2":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_matmul_w4a16_tensor_test2(y_fp16, x_compressed_4bit_test2, tensor_scale), quantiles=quantiles)
        print("matmul_kernel_w4a16_tensor_test2:", kernel_w4t_a16_matmul_small_test2.best_config)

    if provider == "triton_w4a16_tensor_test3":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_matmul_w4a16_tensor_test3(y_fp16, x_compressed_4bit_test3, tensor_scale), quantiles=quantiles)
        print("matmul_kernel_w4a16_tensor_test3:", kernel_w4t_a16_matmul_small_test3.best_config)

    if provider == "triton_w4a16_channel":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_matmul_w4a16_channel(y_fp16, x_compressed_4bit, channel_scale), quantiles=quantiles)
        print("matmul_kernel_w4a16_channel:", matmul_kernel_w4a16_channel.best_config)

    if provider == "triton_w4a16_gs256":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_matmul_w4a16_gs256(y_fp16, x_compressed_4bit, group256_scale), quantiles=quantiles)
        print("matmul_kernel_w4a16_gs256:", matmul_kernel_w4a16_gs256.best_config)

    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)

# bench_data = benchmark.run(show_plots=False, print_data=True, return_df=True)[0]

In [9]:
for size in sizes:
    M, N, K = BS, size, size
    # run_w4a16_tensor_test(M, N, K, print_dist=True)
    # run_w4a16_tensor_test2_test(M, N, K, print_dist=True)
    run_w4a16_tensor_test3_test(M, N, K, print_dist=True)

[M x K x N]: [1 x 2048 x 2048], dist = 1.0446353062541676


AssertionError: 

In [None]:
add_speedup_columns(bench_data)

Unnamed: 0,K,M,N,torch_fp16,triton_w4a16_tensor_test3,triton_w4a16_tensor_test3_speedup
0,2048.0,1.0,2048.0,0.229347,0.842907,3.675241
1,4096.0,1.0,4096.0,0.399077,1.6384,4.105469
2,8192.0,1.0,8192.0,0.606815,2.114065,3.483871
3,16384.0,1.0,16384.0,0.802008,2.561016,3.193253


In [None]:
# def get_cuda_autotune_config():
#     configs = []
#     for num_warps, num_stages in [
#         (4, 2),
#         (4, 3),
#         (4, 4),
#         # (4, 5),
#         # (8, 2),
#         # (8, 4),
#     ]:
#         # configs.append(
#         #     triton.Config({"GROUP_SIZE_M" : 8, "BLOCK_SIZE_M" : 64, "BLOCK_SIZE_N" : 128, "BLOCK_SIZE_K" : 64}, num_stages=num_stages, num_warps=num_warps),
#         # )
#         for GROUP_SIZE_M in [1]:
#             for BLOCK_SIZE_M in [16]:
#                 for BLOCK_SIZE_N in [16, 32, 64, 128]:
#                     for BLOCK_SIZE_K in [128, 256, 512]:
#                         configs.append(
#                             triton.Config(
#                                 {
#                                     "GROUP_SIZE_M" : GROUP_SIZE_M,
#                                     "BLOCK_SIZE_M" : BLOCK_SIZE_M,
#                                     "BLOCK_SIZE_N" : BLOCK_SIZE_N,
#                                     "BLOCK_SIZE_K" : BLOCK_SIZE_K,
#                                 }, 
#                                 num_stages=num_stages, 
#                                 num_warps=num_warps
#                             ),
#                         )                        
#     return configs
#     return [triton.Config(
#                                 {
#                                     "GROUP_SIZE_M" : 1,
#                                     "BLOCK_SIZE_M" : 16,
#                                     "BLOCK_SIZE_N" : 32,
#                                     "BLOCK_SIZE_K" : 128,
#                                 },
#                                 num_stages=4,
#                                 num_warps=4
#                             )]
