In [2]:
import math
import torch
import triton
import triton.language as tl

# -----------------------------
# Your provided kernels (as-is)
# -----------------------------
@triton.jit
def quant_pack_kernel(
    X_ptr, P_ptr, S_ptr, M_ptr,
    stride_x0: tl.constexpr, stride_x1: tl.constexpr,
    stride_p0: tl.constexpr, stride_p1: tl.constexpr,
    BITS: tl.constexpr,
    VPW: tl.constexpr,
    NWORDS: tl.constexpr,
    QMAX: tl.constexpr,
):
    pid = tl.program_id(0)

    x_block_ptr = tl.make_block_ptr(
        base=X_ptr + pid * stride_x0,
        shape=(NWORDS, VPW),
        strides=(stride_x1 * VPW, stride_x1),
        offsets=(0, 0),
        block_shape=(NWORDS, VPW),
        order=(0, 1),
    )

    x = tl.load(x_block_ptr).to(tl.float32)   # [NWORDS, VPW]

    xmin = tl.min(x, axis=1)
    xmin = tl.min(xmin, axis=0)

    xmax = tl.max(x, axis=1)
    xmax = tl.max(xmax, axis=0)

    rng = xmax - xmin
    scale = (rng / tl.full([], QMAX, tl.float32))
    scale = tl.where(rng > 0.0, scale, tl.full([], 1.0, tl.float32))

    tl.store(S_ptr + pid, scale)
    tl.store(M_ptr + pid, xmin)

    inv_scale = tl.full([], 1.0, tl.float32) / scale

    j = tl.arange(0, VPW)
    shifts = (j * BITS).to(tl.int32)
    eps  = tl.full([VPW], 1e-6, tl.float32)
    half = tl.full([VPW], 0.5,  tl.float32)

    qf = (x - xmin) * inv_scale + (half - eps)
    qi = qf.to(tl.int32)
    qi = tl.maximum(qi, 0)
    qi = tl.minimum(qi, QMAX)

    words = tl.sum(qi << shifts[None, :], axis=1)  # [NWORDS]

    p_block_ptr = tl.make_block_ptr(
        base = P_ptr + pid * stride_p0,
        shape=(NWORDS,),
        strides=(stride_p1,),
        offsets=(0,),
        block_shape=(NWORDS,),
        order=(0,)
    )
    tl.store(p_block_ptr, words.to(tl.int32))


@triton.jit
def dequant_unpack_kernel(
    P_ptr, S_ptr, M_ptr, Y_ptr,
    stride_p0: tl.constexpr, stride_p1: tl.constexpr,
    stride_y0: tl.constexpr, stride_y1: tl.constexpr,
    BITS: tl.constexpr,
    VPW: tl.constexpr,
    NWORDS: tl.constexpr
):
    pid = tl.program_id(0)

    p_block_ptr = tl.make_block_ptr(
        base=P_ptr + pid * stride_p0,
        shape=(NWORDS,),
        strides=(stride_p1,),
        offsets=(0,),
        block_shape=(NWORDS,),
        order=(0,)
    )

    word = tl.load(p_block_ptr)

    scale = tl.load(S_ptr + pid)
    scale_dtype = scale.dtype
    scale = scale.to(tl.float32)
    xmin  = tl.load(M_ptr + pid).to(tl.float32)

    mask = (1 << BITS) - 1
    j = tl.arange(0, VPW)
    shifts = (j * BITS).to(tl.int32)

    q = ((word[:, None] >> shifts[None, :]) & mask).to(tl.float32)
    q = q * scale + xmin
    q_flat = tl.reshape(q, (NWORDS * VPW,))

    y_block_ptr = tl.make_block_ptr(
        base=Y_ptr + pid * stride_y0,
        shape=(NWORDS * VPW,),
        strides=(stride_y1,),
        offsets=(0,),
        block_shape=(NWORDS * VPW,),
        order=(0,)
    )

    tl.store(y_block_ptr, q_flat.to(scale_dtype))


# -------------------------------------------------------
# Fused kernel: bit-unpack + dequant + accumulate dw tile
# dw[m, n:n+256] += sum_k dy[k,m] * x[k,n:n+256]
# -------------------------------------------------------
@triton.jit
def fused_linear_dw_from_packed_kernel(
    P_ptr, S_ptr, M_ptr, DY_ptr, DW_ptr,
    K: tl.constexpr,
    COUT: tl.constexpr,
    CIN: tl.constexpr,
    G: tl.constexpr,
    stride_p0: tl.constexpr, stride_p1: tl.constexpr,
    stride_s0: tl.constexpr,
    stride_m0: tl.constexpr,
    stride_dy0: tl.constexpr, stride_dy1: tl.constexpr,
    stride_dw0: tl.constexpr, stride_dw1: tl.constexpr,
    BITS: tl.constexpr,
    VPW: tl.constexpr,
    NWORDS: tl.constexpr,
    GROUP: tl.constexpr,          # <-- add this
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_g = tl.program_id(1)

    tl.static_assert(GROUP == NWORDS * VPW)

    m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    mask_m = m < COUT

    n = pid_g * GROUP + tl.arange(0, GROUP)   # <-- now GROUP is constexpr
    mask_n = n < CIN

    g = pid_g
    valid_g = g < G

    acc = tl.zeros((BLOCK_M, GROUP), dtype=tl.float32)

    bitmask = (1 << BITS) - 1
    j = tl.arange(0, VPW)
    shifts = (j * BITS).to(tl.int32)
    w = tl.arange(0, NWORDS)

    for k0 in tl.range(0, K, BLOCK_K):
        k = k0 + tl.arange(0, BLOCK_K)
        mask_k = k < K

        dy_ptrs = DY_ptr + k[:, None] * stride_dy0 + m[None, :] * stride_dy1
        dy = tl.load(dy_ptrs, mask=mask_k[:, None] & mask_m[None, :], other=0.0).to(tl.bfloat16)
        a = tl.trans(dy)  # [BM, BK]

        pids = k * G + g
        p_ptrs = P_ptr + pids[:, None] * stride_p0 + w[None, :] * stride_p1
        words = tl.load(p_ptrs, mask=mask_k[:, None] & valid_g, other=0).to(tl.int32)

        scale = tl.load(S_ptr + pids * stride_s0, mask=mask_k & valid_g, other=1.0).to(tl.float32)
        xmin  = tl.load(M_ptr + pids * stride_m0, mask=mask_k & valid_g, other=0.0).to(tl.float32)

        q = ((words[:, :, None] >> shifts[None, None, :]) & bitmask).to(tl.float32)
        x = q * scale[:, None, None] + xmin[:, None, None]
        x = tl.reshape(x, (BLOCK_K, GROUP)).to(tl.bfloat16)

        acc += tl.dot(a, x)

    dw_ptrs = DW_ptr + m[:, None] * stride_dw0 + n[None, :] * stride_dw1
    tl.store(dw_ptrs, acc.to(tl.bfloat16), mask=mask_m[:, None] & mask_n[None, :])






# -----------------------------
# Timing helpers
# -----------------------------
@torch.no_grad()
def time_ms(fn, iters=50, warmup=10):
    torch.cuda.synchronize()
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end   = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


def cdiv(a, b): return (a + b - 1) // b


# -----------------------------
# Main benchmark
# -----------------------------
def main(
    K=8192,        # rows in x_flat / dy_flat
    CIN=768,       # input features
    COUT=768,      # output features
    BITS=2,        # 1/2/4/8 supported by your bitpack
    GROUP=256,     # your pack group size (must match kernels)
    iters=50,
):
    assert GROUP == 256, "This script assumes GROUP=256 like your kernels (NWORDS*VPW=256)."
    assert BITS in (1,2,4,8), "Test bits in {1,2,4,8} first."

    device = "cuda"
    torch.manual_seed(0)

    VPW = 32 // BITS
    NWORDS = GROUP // VPW
    QMAX = (1 << BITS) - 1

    # pad CIN to multiple of GROUP (for packing)
    CIN_PAD = cdiv(CIN, GROUP) * GROUP
    G = CIN_PAD // GROUP

    print(f"\nShapes: K={K}, CIN={CIN}, CIN_PAD={CIN_PAD}, COUT={COUT}, G={G}, BITS={BITS}, VPW={VPW}, NWORDS={NWORDS}")

    # inputs
    x = torch.randn((K, CIN), device=device, dtype=torch.bfloat16)
    if CIN_PAD != CIN:
        x = torch.nn.functional.pad(x, (0, CIN_PAD - CIN))
    x = x.contiguous()  # [K, CIN_PAD]

    dy = torch.randn((K, COUT), device=device, dtype=torch.bfloat16).contiguous()
    dy_T = dy.t().contiguous()  # [COUT, K]

    # view as groups and flatten to [K*G, GROUP]
    X_in = x.view(K, G, GROUP).reshape(K * G, GROUP).contiguous()

    # packed buffers (constant during benchmark)
    P = torch.empty((K * G, NWORDS), device=device, dtype=torch.int32)
    S = torch.empty((K * G,), device=device, dtype=torch.bfloat16)
    M = torch.empty((K * G,), device=device, dtype=torch.bfloat16)

    grid_q = (K * G,)
    quant_pack_kernel[grid_q](
        X_in, P, S, M,
        X_in.stride(0), X_in.stride(1),
        P.stride(0), P.stride(1),
        BITS=BITS, VPW=VPW, NWORDS=NWORDS, QMAX=QMAX,
        num_warps=4,
    )
    torch.cuda.synchronize()

    # unfused buffers (reused)
    Y = torch.empty((K * G, GROUP), device=device, dtype=torch.bfloat16)
    dw_unfused = torch.empty((COUT, CIN), device=device, dtype=torch.bfloat16)

    # fused output buffer (reused)
    dw_fused = torch.empty((COUT, CIN), device=device, dtype=torch.bfloat16)

    # --- unfused path: dequant -> reshape -> mm
    def run_unfused():
        # dequant into Y
        dequant_unpack_kernel[grid_q](
            P, S, M, Y,
            P.stride(0), P.stride(1),
            Y.stride(0), Y.stride(1),
            BITS=BITS, VPW=VPW, NWORDS=NWORDS,
            num_warps=4,
        )
        # rebuild x_deq: [K, CIN_PAD] then slice [:, :CIN]
        x_deq = Y.view(K, G, GROUP).reshape(K, CIN_PAD)[:, :CIN]
        # dw = dy_T @ x_deq  => [COUT, CIN]
        torch.mm(dy_T, x_deq, out=dw_unfused)

    # --- fused path: unpack+dequant inside GEMM-like kernel
    def run_fused():
        # zero dw_fused (kernel writes full tile; no atomic)
        dw_fused.zero_()
        grid = (triton.cdiv(COUT, 16), triton.cdiv(CIN, GROUP))
        fused_linear_dw_from_packed_kernel[grid](
            P, S, M, dy, dw_fused,
            K=K, COUT=COUT, CIN=CIN, G=G,
            stride_p0=P.stride(0), stride_p1=P.stride(1),
            stride_s0=S.stride(0),
            stride_m0=M.stride(0),
            stride_dy0=dy.stride(0), stride_dy1=dy.stride(1),
            stride_dw0=dw_fused.stride(0), stride_dw1=dw_fused.stride(1),
            BITS=BITS, VPW=VPW, NWORDS=NWORDS,
            GROUP=GROUP,                 # <-- add this
            BLOCK_M=16, BLOCK_K=16,
            num_warps=8, num_stages=3
        )


    # correctness check (allow small tolerance)
    run_unfused()
    run_fused()
    torch.cuda.synchronize()

    # These should be close (same quant format; different reduction order possible)
    torch.testing.assert_close(dw_fused, dw_unfused, rtol=1e-2, atol=1e-2)
    print("✅ correctness: fused ~ unfused (within tolerance)")

    # timing
    t_unfused = time_ms(run_unfused, iters=iters, warmup=10)
    t_fused   = time_ms(run_fused,   iters=iters, warmup=10)

    print(f"Unfused (dequant + mm): {t_unfused:.3f} ms/iter")
    print(f"Fused   (unpack+mm):    {t_fused:.3f} ms/iter")
    print(f"Speedup: {t_unfused / t_fused:.2f}x")

    # optional: isolate dequant cost alone
    def run_dequant_only():
        dequant_unpack_kernel[grid_q](
            P, S, M, Y,
            P.stride(0), P.stride(1),
            Y.stride(0), Y.stride(1),
            BITS=BITS, VPW=VPW, NWORDS=NWORDS,
            num_warps=4,
        )
    t_deq = time_ms(run_dequant_only, iters=iters, warmup=10)
    print(f"Dequant-only: {t_deq:.3f} ms/iter")


if __name__ == "__main__":
    assert torch.cuda.is_available()
    # adjust K/CIN/COUT to your layer; start moderate so compile+timing is quick
    main(K=8192, CIN=768, COUT=768, BITS=2, iters=50)



Shapes: K=8192, CIN=768, CIN_PAD=768, COUT=768, G=3, BITS=2, VPW=16, NWORDS=16
✅ correctness: fused ~ unfused (within tolerance)
Unfused (dequant + mm): 0.045 ms/iter
Fused   (unpack+mm):    0.694 ms/iter
Speedup: 0.06x
Dequant-only: 0.024 ms/iter


In [6]:
import triton
import triton.language as tl

@triton.jit
def fused_dw_kernel_splitk(
    P_ptr, S_ptr, M_ptr, DY_ptr, DW_ptr,   # DW is FP32 for atomic accumulation
    K: tl.constexpr,
    COUT: tl.constexpr,
    CIN: tl.constexpr,
    G: tl.constexpr,  # groups per row = CIN_PAD // GROUP
    stride_p0: tl.constexpr, stride_p1: tl.constexpr,   # P: [K*G, NWORDS]
    stride_s0: tl.constexpr,                            # S: [K*G]
    stride_m0: tl.constexpr,                            # M: [K*G]
    stride_dy0: tl.constexpr, stride_dy1: tl.constexpr, # dy: [K, COUT]
    stride_dw0: tl.constexpr, stride_dw1: tl.constexpr, # dw_fp32: [COUT, CIN]
    # quant format
    BITS: tl.constexpr,
    VPW: tl.constexpr,
    NWORDS: tl.constexpr,
    GROUP: tl.constexpr,          # 256
    # tiling
    BM: tl.constexpr,
    BN: tl.constexpr,
    BK: tl.constexpr,
    # split-K
    SPLIT_K: tl.constexpr,
    K_SLICE: tl.constexpr,
    # NEW: make these explicit constexpr so tl.arange accepts them
    WORDS_N: tl.constexpr,        # = BN // VPW
    TILES_PER_GROUP: tl.constexpr # = GROUP // BN
):
    pid_m  = tl.program_id(0)
    pid_n  = tl.program_id(1)
    pid_sk = tl.program_id(2)

    tl.static_assert(GROUP == 256)
    tl.static_assert((GROUP % BN) == 0)
    tl.static_assert((BN % VPW) == 0)
    tl.static_assert(WORDS_N * VPW == BN)
    tl.static_assert(TILES_PER_GROUP * BN == GROUP)

    # output indices
    m = pid_m * BM + tl.arange(0, BM)
    n = pid_n * BN + tl.arange(0, BN)
    mask_m = m < COUT
    mask_n = n < CIN

    # map BN-tile -> group g, and tile offset within that group
    g = pid_n // TILES_PER_GROUP
    tile_in_group = pid_n - g * TILES_PER_GROUP
    valid_g = g < G

    # which packed words inside this group does this BN-tile need?
    # word_base = (tile_in_group*BN) // VPW = tile_in_group * (BN//VPW) = tile_in_group * WORDS_N
    word_base = tile_in_group * WORDS_N

    acc = tl.zeros((BM, BN), dtype=tl.float32)

    bitmask = (1 << BITS) - 1
    shifts = (tl.arange(0, VPW) * BITS).to(tl.int32)  # [VPW]
    w = tl.arange(0, WORDS_N)                         # ✅ now WORDS_N is constexpr

    # split-K range
    k_start = pid_sk * K_SLICE

    for kk in tl.range(0, K_SLICE, BK):
        k = k_start + kk + tl.arange(0, BK)
        mask_k = k < K

        # dy: [BK, BM] -> A = [BM, BK]
        dy_ptrs = DY_ptr + k[:, None] * stride_dy0 + m[None, :] * stride_dy1
        dy = tl.load(dy_ptrs, mask=mask_k[:, None] & mask_m[None, :], other=0.0).to(tl.bfloat16)
        A = tl.trans(dy)

        # packed words: [BK, WORDS_N]
        pids = k * G + g
        p_ptrs = P_ptr + pids[:, None] * stride_p0 + (word_base + w)[None, :] * stride_p1
        words = tl.load(p_ptrs, mask=mask_k[:, None] & valid_g, other=0).to(tl.int32)

        scale = tl.load(S_ptr + pids * stride_s0, mask=mask_k & valid_g, other=1.0).to(tl.float16)
        xmin  = tl.load(M_ptr + pids * stride_m0, mask=mask_k & valid_g, other=0.0).to(tl.float16)

        # unpack only BN cols: [BK, WORDS_N, VPW] -> [BK, BN]
        q = ((words[:, :, None] >> shifts[None, None, :]) & bitmask).to(tl.float16)
        X = q * scale[:, None, None] + xmin[:, None, None]
        B = tl.reshape(X, (BK, BN)).to(tl.bfloat16)

        acc += tl.dot(A, B)  # MMA path

    # atomic add into dw_fp32
    dw_ptrs = DW_ptr + m[:, None] * stride_dw0 + n[None, :] * stride_dw1
    tl.atomic_add(dw_ptrs, acc, mask=mask_m[:, None] & mask_n[None, :] & valid_g)


In [7]:
import torch
import triton

# reuse your quant_pack_kernel and dequant_unpack_kernel here

def time_ms(fn, iters=50, warmup=10):
    torch.cuda.synchronize()
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(True); end = torch.cuda.Event(True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

def cdiv(a,b): return (a+b-1)//b

@torch.no_grad()
def bench(K=100864, CIN=768, COUT=768, BITS=2, GROUP=256, iters=50,
          BM=64, BN=64, BK=32, SPLIT_K=8):
    assert GROUP == 256
    assert BN % (32//BITS) == 0
    assert GROUP % BN == 0
    assert BK % 16 == 0

    VPW = 32 // BITS
    NWORDS = GROUP // VPW
    QMAX = (1 << BITS) - 1

    CIN_PAD = cdiv(CIN, GROUP) * GROUP
    G = CIN_PAD // GROUP

    print(f"\nShapes: K={K}, CIN={CIN}, CIN_PAD={CIN_PAD}, COUT={COUT}, G={G}, "
          f"BITS={BITS}, VPW={VPW}, NWORDS={NWORDS}")
    print(f"Tiling: BM={BM}, BN={BN}, BK={BK}, SPLIT_K={SPLIT_K}")

    # ---- inputs
    x = torch.randn((K, CIN), device="cuda", dtype=torch.bfloat16)
    if CIN_PAD != CIN:
        x = torch.nn.functional.pad(x, (0, CIN_PAD - CIN))
    x = x.contiguous()

    dy = torch.randn((K, COUT), device="cuda", dtype=torch.bfloat16).contiguous()
    dy_T = dy.t().contiguous()

    # pack row-wise by groups along Cin
    X_in = x.view(K, G, GROUP).reshape(K * G, GROUP).contiguous()

    P = torch.empty((K * G, NWORDS), device="cuda", dtype=torch.int32)
    S = torch.empty((K * G,), device="cuda", dtype=torch.bfloat16)
    M = torch.empty((K * G,), device="cuda", dtype=torch.bfloat16)

    grid_q = (K * G,)
    quant_pack_kernel[grid_q](
        X_in, P, S, M,
        X_in.stride(0), X_in.stride(1),
        P.stride(0), P.stride(1),
        BITS=BITS, VPW=VPW, NWORDS=NWORDS, QMAX=QMAX,
        num_warps=4
    )
    torch.cuda.synchronize()

    # unfused buffers
    Y = torch.empty((K * G, GROUP), device="cuda", dtype=torch.bfloat16)
    dw_unfused = torch.empty((COUT, CIN), device="cuda", dtype=torch.bfloat16)

    # fused buffers (fp32 for split-k accumulation)
    dw_fp32 = torch.empty((COUT, CIN), device="cuda", dtype=torch.float32)
    dw_fused = torch.empty((COUT, CIN), device="cuda", dtype=torch.bfloat16)

    # ---- unfused
    def run_unfused():
        dequant_unpack_kernel[grid_q](
            P, S, M, Y,
            P.stride(0), P.stride(1),
            Y.stride(0), Y.stride(1),
            BITS=BITS, VPW=VPW, NWORDS=NWORDS,
            num_warps=4
        )
        x_deq = Y.view(K, G, GROUP).reshape(K, CIN_PAD)[:, :CIN]
        torch.mm(dy_T, x_deq, out=dw_unfused)

    # ---- fused
    K_SLICE = cdiv(K, SPLIT_K)

    def run_fused():
        VPW = 32 // BITS
        WORDS_N = BN // VPW
        TILES_PER_GROUP = GROUP // BN
        K_SLICE = (K + SPLIT_K - 1) // SPLIT_K
        # atomic accumulation needs zero
        dw_fp32.zero_()
        grid = (triton.cdiv(COUT, BM), triton.cdiv(CIN, BN), SPLIT_K)
        fused_dw_kernel_splitk[grid](
            P, S, M, dy, dw_fp32,
            K=K, COUT=COUT, CIN=CIN, G=G,
            stride_p0=P.stride(0), stride_p1=P.stride(1),
            stride_s0=S.stride(0),
            stride_m0=M.stride(0),
            stride_dy0=dy.stride(0), stride_dy1=dy.stride(1),
            stride_dw0=dw_fp32.stride(0), stride_dw1=dw_fp32.stride(1),
            BITS=BITS, VPW=VPW, NWORDS=NWORDS, GROUP=GROUP,
            BM=BM, BN=BN, BK=BK,
            SPLIT_K=SPLIT_K, K_SLICE=K_SLICE,
            WORDS_N=WORDS_N,
            TILES_PER_GROUP=TILES_PER_GROUP,
            num_warps=8,
            num_stages=3
        )
        # match baseline dtype
        dw_fused.copy_(dw_fp32.to(torch.bfloat16))

    # correctness (loose tol because quant + different reduction order)
    run_unfused()
    run_fused()
    torch.cuda.synchronize()
    torch.testing.assert_close(dw_fused, dw_unfused, rtol=2e-2, atol=2e-2)
    print("✅ correctness OK")

    t_unf = time_ms(run_unfused, iters=iters)
    t_fus = time_ms(run_fused, iters=iters)

    print(f"Unfused (dequant + cuBLAS mm): {t_unf:.3f} ms/iter")
    print(f"Fused   (unpack-in-GEMM):      {t_fus:.3f} ms/iter")
    print(f"Speedup: {t_unf/t_fus:.2f}x")

# Example runs:
# - your old K=8192 will still likely favor cuBLAS
# - ViT-like K (~100k) is where fusion starts to matter
if __name__ == "__main__":
    bench(K=8192,   CIN=768, COUT=768, BITS=2, iters=100, SPLIT_K=8)      # likely slower than cuBLAS
    bench(K=100864, CIN=768, COUT=768, BITS=2, iters=50,  SPLIT_K=16)     # more realistic ViT case



Shapes: K=8192, CIN=768, CIN_PAD=768, COUT=768, G=3, BITS=2, VPW=16, NWORDS=16
Tiling: BM=64, BN=64, BK=32, SPLIT_K=8
✅ correctness OK
Unfused (dequant + cuBLAS mm): 0.064 ms/iter
Fused   (unpack-in-GEMM):      0.422 ms/iter
Speedup: 0.15x

Shapes: K=100864, CIN=768, CIN_PAD=768, COUT=768, G=3, BITS=2, VPW=16, NWORDS=16
Tiling: BM=64, BN=64, BK=32, SPLIT_K=16
✅ correctness OK
Unfused (dequant + cuBLAS mm): 0.764 ms/iter
Fused   (unpack-in-GEMM):      4.597 ms/iter
Speedup: 0.17x


In [4]:
import triton
import triton.language as tl 

In [5]:
@triton.jit
def pack(
    X_ptr, P_ptr, 
    stride_x0: tl.constexpr, stride_x1: tl.constexpr,
    stride_p0: tl.constexpr, stride_p1: tl.constexpr,
    BITS: tl.constexpr,
    VPW: tl.constexpr,
    NWORDS: tl.constexpr
):
    pid = tl.program_id(0)

    x_block_ptr = tl.make_block_ptr(
        base = X_ptr + pid * stride_x0,
        shape=(NWORDS * VPW,),
        strides=(stride_x1,),
        offsets=(0,),
        block_shape=(NWORDS * VPW,),
        order=(0,)
    )

    x = tl.load(x_block_ptr)
    x = tl.reshape(x, (NWORDS, VPW))

    j = tl.arange(0, VPW)
    shifts = (j * BITS).to(tl.int32)

    q = tl.sum(x >> shifts[None, :], axis=1).to(tl.int8)

    q_block_ptr = tl.make_block_ptr(
        base=P_ptr + pid * stride_p0,
        shape=(NWORDS,),
        strides=(stride_p1,),
        offsets=(0,),
        block_shape=(NWORDS,),
        order=(0,)
    )

    tl.store(q_block_ptr, q)


@triton.jit
def unpack(
    P_ptr, Y_ptr,
    stride_p0: tl.constexpr, stride_p1: tl.constexpr,
    stride_y0: tl.constexpr, stride_y1: tl.constexpr,
    BITS: tl.constexpr,
    VPW: tl.constexpr,
    NWORDS: tl.constexpr
):
    pid = tl.program_id(0)

    p_block_ptr = tl.make_block_ptr(
        base=P_ptr + pid * stride_p0,
        shape=(NWORDS,),
        strides=(stride_p1,),
        offsets=(0,),
        block_shape=(NWORDS,),
        order=(0,)
    )

    p = tl.load(p_block_ptr)

    mask = (1 << BITS) - 1
    j = tl.arange(0, VPW)
    shifts = (j * BITS).to(tl.int32)

    y = ((p[:, None] >> shifts[None, :]) & mask).to(tl.int8)

    y_block_ptr = tl.make_block_ptr(
        base=Y_ptr + pid * stride_y0,
        shape=(NWORDS, VPW),
        strides=(VPW, stride_y1),
        offsets=(0, 0),
        block_shape=(NWORDS, VPW),
        order=(0, 1)
    )

    tl.store(y_block_ptr, y)

In [33]:
@triton.jit
def relu_fwd_fused_pack(
    X_ptr, P_ptr, Y_ptr,
    stride_x0: tl.constexpr, stride_x1: tl.constexpr,
    stride_p0: tl.constexpr, stride_p1: tl.constexpr,
    stride_y0: tl.constexpr, stride_y1: tl.constexpr,
    BITS: tl.constexpr,
    VPW: tl.constexpr,
    NWORDS: tl.constexpr
):
    pid = tl.program_id(0)

    x_block_ptr = tl.make_block_ptr(
        base=X_ptr + pid * stride_x0,
        shape=(NWORDS * VPW,),
        strides=(stride_x1, ),
        offsets=(0,),
        block_shape=(NWORDS * VPW,),
        order=(0,)
    )

    x = tl.load(x_block_ptr)

    relu_mask = x > 0
    x = tl.where(relu_mask, x, 0)

    relu_mask = tl.reshape(relu_mask, (NWORDS, VPW))
    j = tl.arange(0, VPW)
    shifts = (j * BITS).to(tl.int32)
    p_relu_mask = tl.sum(relu_mask >> shifts[None, :], axis=1).to(tl.int8)

    p_block_ptr = tl.make_block_ptr(
        base=P_ptr + pid * stride_p1,
        shape=(NWORDS,),
        strides=(stride_p1,),
        offsets=(0,),
        block_shape=(NWORDS,),
        order=(0,)
    )

    y_block_ptr = tl.make_block_ptr(
        base=Y_ptr + pid * stride_y0,
        shape=(NWORDS * VPW,),
        strides=(stride_y1,),
        offsets=(0,),
        block_shape=(NWORDS * VPW,),
        order=(0,)
    )

    tl.store(p_block_ptr, p_relu_mask)
    tl.store(y_block_ptr, x)



@triton.jit
def relu_bwd_fused_unpack(
    P_ptr, DY_ptr, DX_ptr,
    stride_p0: tl.constexpr, stride_p1: tl.constexpr,
    stride_dy0: tl.constexpr, stride_dy1: tl.constexpr,
    stride_dx0: tl.constexpr, stride_dx1: tl.constexpr,
    BITS: tl.constexpr,
    VPW: tl.constexpr,
    NWORDS: tl.constexpr
):
    pid = tl.program_id(0)

    p_block_ptr = tl.make_block_ptr(
        base=P_ptr + pid * stride_p0,
        shape=(NWORDS,),
        strides=(stride_dy1,),
        offsets=(0,),
        block_shape=(NWORDS,),
        order=(0,)
    )

    dy_block_ptr = tl.make_block_ptr(
        base=DY_ptr + pid * stride_dy0,
        shape=(NWORDS, VPW),
        strides=(VPW, stride_dy1),
        offsets=(0,0),
        block_shape=(NWORDS, VPW),
        order=(0,1)
    )

    packed = tl.load(p_block_ptr)

    mask = (1 << BITS) - 1
    j = tl.arange(0, VPW)
    shifts = (j * BITS).to(tl.int32)

    relu_mask = ((packed[:, None] >> shifts[None, :]) & mask).to(tl.float32) # [NWORDS, VPW]
    dy = tl.load(dy_block_ptr).to(tl.float32)

    dx = dy * relu_mask
    dx = tl.reshape(dx, (NWORDS * VPW,))

    dx_block_ptr = tl.make_block_ptr(
        base=DX_ptr + pid * stride_dx0,
        shape=(NWORDS * VPW,),
        strides=(stride_dx1,),
        offsets=(0,),
        block_shape=(NWORDS * VPW,),
        order=(0,)
    )

    tl.store(dx_block_ptr, dx.to(tl.bfloat16))

In [7]:
import torch
import torch.nn.functional as F
import torch.cuda as cuda

In [31]:
N, G = 32, 100

x = torch.randint(-100, 100, (N*G, 256), dtype=torch.int8, device="cuda")
y = torch.empty_like(x)
dy = torch.randn(*x.shape, dtype=torch.bfloat16, device='cuda')
x2 = torch.empty_like(dy, dtype=torch.bfloat16, device='cuda')

p = torch.empty((N * G, 16), dtype=torch.int8, device='cuda')


start_timer = cuda.Event(enable_timing=True)
end_timer = cuda.Event(enable_timing=True)

In [27]:
BITS = 2
VPW = 32 // 2
NWORDS = 256 // VPW

In [46]:
cuda.synchronize()

start_timer.record()

grid = (N * G,)
# pack[grid](
#     x, p, 
#     x.stride(0), x.stride(1),
#     p.stride(0), p.stride(1),
#     BITS=BITS,
#     VPW=VPW,
#     NWORDS=NWORDS
# )
# y2 = F.relu(x)
# relu_fwd_fused_pack[grid](
#     x, p, y,
#     x.stride(0), x.stride(1),
#     p.stride(0), p.stride(1),
#     y.stride(0), y.stride(1),
#     BITS=BITS,
#     VPW=VPW,
#     NWORDS=NWORDS
# )
relu_bwd_fused_unpack[grid](
    p, dy, x2,
    p.stride(0), p.stride(1),
    dy.stride(0), dy.stride(1),
    x2.stride(0), x2.stride(1),
    BITS=BITS,
    VPW=VPW,
    NWORDS=NWORDS
)

end_timer.record()
cuda.synchronize()
time = start_timer.elapsed_time(end_timer)

print(f"Time: {time / 1000} s")

Time: 0.00025513601303100584 s


In [102]:
Time: 0.014269120216369629 s

SyntaxError: invalid syntax (3298410441.py, line 1)