In [1]:
!pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu118

Looking in indexes: https://download.pytorch.org/whl/cu118


In [46]:
import torch
import triton
import triton.language as tl

# def getConfigs(names, ranges, num_stages_range, num_warps_range):
#     configs = [
#         triton.Config({f'{names[0]}': x0, f'{names[1]}': x1, f'{names[2]}': x2, f'{names[3]}': x3}, num_stages=s, num_warps=w) \
#         for x0 in ranges[0]\
#         for x1 in ranges[1]\
#         for x2 in ranges[2]\
#         for x3 in ranges[3]\
#         for s in num_stages_range\
#         for w in num_warps_range\
#     ]
#     return configs

# ranges = [[32, 64, 128, 256, 512, 1024], [32, 64, 128, 256, 512, 1024], [32, 64, 128, 256, 512, 1024], [2,4,8]]
# num_stages_range = [0, 1, 2, 3, 4, 5, 6, 7, 8]
# num_warps_range = [2, 4, 8, 16, 32]


@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_BSIZE": 32,
                "BLOCK_SIZE_K": 32,
                "BLOCK_SIZE_D2": 64,
                "GROUP_SIZE_BSIZE": 8,
            },
            num_stages=1,
            num_warps=4,
        ),
        # triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_D2': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_D2': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_D2': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_D2': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_D2': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_D2': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=5, num_warps=2),
        # triton.Config({'BLOCK_SIZE_BSIZE': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_D2': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=5, num_warps=2),
        # triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_D2': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=3, num_warps=8),
        # triton.Config({'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_D2': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=3, num_warps=8),
        # triton.Config({'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_D2': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_D2': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_D2': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_D2': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_D2': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_D2': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4)
    ],
    # configs=getConfigs(['BLOCK_SIZE_BSIZE', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D2', 'GROUP_SIZE_BSIZE'], ranges, num_stages_range, num_warps_range),
    key=["BSIZE", "K", "d2", "L"],
)
@triton.jit
def first_pass_kernel(
    hin_ptr,
    S1s_ptr,
    U2s_ptr,
    out1_ptr,
    out2_ptr,
    BSIZE,
    K,
    d2,
    L,
    stride_hin_bsize,
    stride_hin_d2,
    stride_su_l,
    stride_su_d2,
    stride_su_k,
    stride_out_l,
    stride_out_bsize,
    stride_out_k,
    BLOCK_SIZE_BSIZE: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_D2: tl.constexpr,
    GROUP_SIZE_BSIZE: tl.constexpr,
):
    pid = tl.program_id(axis=1)
    batch_id = tl.program_id(axis=0)

    num_pid_bsize = tl.cdiv(BSIZE, BLOCK_SIZE_BSIZE)
    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_BSIZE * num_pid_k
    group_id = pid // num_pid_in_group
    first_pid_bsize = group_id * GROUP_SIZE_BSIZE
    group_size_bsize = min(num_pid_bsize - first_pid_bsize, GROUP_SIZE_BSIZE)
    pid_bsize = first_pid_bsize + ((pid % num_pid_in_group) % group_size_bsize)
    pid_k = (pid % num_pid_in_group) // group_size_bsize

    offs_bsize = pid_bsize * BLOCK_SIZE_BSIZE + tl.arange(0, BLOCK_SIZE_BSIZE)
    offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_d2 = tl.arange(0, BLOCK_SIZE_D2)

    offs_bsize = tl.max_contiguous(
        tl.multiple_of(offs_bsize, BLOCK_SIZE_BSIZE), BLOCK_SIZE_BSIZE
    )
    offs_k = tl.max_contiguous(tl.multiple_of(offs_k, BLOCK_SIZE_K), BLOCK_SIZE_K)
    offs_d2 = tl.max_contiguous(tl.multiple_of(offs_d2, BLOCK_SIZE_D2), BLOCK_SIZE_D2)

    hin_ptrs = hin_ptr + (
        offs_bsize[:, None] * stride_hin_bsize + offs_d2[None, :] * stride_hin_d2
    )

    su_tmp = batch_id * stride_su_l + (
        offs_d2[:, None] * stride_su_d2 + offs_k[None, :] * stride_su_k
    )
    S1s_ptrs = S1s_ptr + su_tmp
    U2s_ptrs = U2s_ptr + su_tmp

    accumulator1 = tl.full(
        shape=(BLOCK_SIZE_BSIZE, BLOCK_SIZE_K), value=0.0, dtype=tl.float32
    )
    accumulator2 = tl.full(
        shape=(BLOCK_SIZE_BSIZE, BLOCK_SIZE_K), value=0.0, dtype=tl.float32
    )

    for d2_i in range(0, tl.cdiv(d2, BLOCK_SIZE_D2)):
        hin_mask = (offs_bsize[:, None] < BSIZE) & (
            offs_d2[None, :] < d2 - d2_i * BLOCK_SIZE_D2
        )
        hin = tl.load(hin_ptrs, mask=hin_mask, other=0.0)

        su_mask = (offs_d2[:, None] < d2 - d2_i * BLOCK_SIZE_D2) & (offs_k[None, :] < K)
        S1s = tl.load(S1s_ptrs, mask=su_mask, other=0.0)
        U2s = tl.load(U2s_ptrs, mask=su_mask, other=0.0)

        accumulator1 += tl.dot(hin, S1s, input_precision="ieee")
        accumulator2 += tl.dot(hin, U2s, input_precision="ieee")

        hin_ptrs += BLOCK_SIZE_D2 * stride_hin_d2
        S1s_ptrs += BLOCK_SIZE_D2 * stride_su_d2
        U2s_ptrs += BLOCK_SIZE_D2 * stride_su_d2

    out_tmp = (
        batch_id * stride_out_l
        + stride_out_bsize * offs_bsize[:, None]
        + stride_out_k * offs_k[None, :]
    )
    out1_ptrs = out1_ptr + out_tmp
    out2_ptrs = out2_ptr + out_tmp

    out_mask = (offs_bsize[:, None] < BSIZE) & (offs_k[None, :] < K)

    tl.store(out1_ptrs, accumulator1, mask=out_mask)
    tl.store(out2_ptrs, accumulator2, mask=out_mask)


def first_pass(hin, S1s, U2s):
    device = "cuda"
    # assert hin.shape[1] == S1s.shape[1], "Incompatible dimensions"
    # assert hin.shape[1] == U2s.shape[1], "Incompatible dimensions"
    # assert hin.is_contiguous(), "Matrix A must be contiguous"
    # assert S1s.is_contiguous(), "Matrix A must be contiguous"
    # assert U2s.is_contiguous(), "Matrix A must be contiguous"
    # assert S1s.stride() == U2s.stride(), "Matrix A must be contiguous"

    BSIZE, d2 = hin.shape
    L, _, K = S1s.shape

    out1 = torch.empty((L, BSIZE, K), dtype=torch.float32, device=device)
    out2 = torch.empty((L, BSIZE, K), dtype=torch.float32, device=device)

    # stride_hin_bsize, stride_hin_d2 = hin.stride()
    # stride_su_l, stride_su_d2, stride_su_k = S1s.stride()
    # stride_out_l, stride_out_bsize, stride_out_k = out1.stride()
    stride_hin_bsize, stride_hin_d2 = hin.shape[1], 1
    stride_su_l, stride_su_d2, stride_su_k = (
        S1s.shape[1] * S1s.shape[2],
        S1s.shape[2],
        1,
    )
    stride_out_l, stride_out_bsize, stride_out_k = (
        out1.shape[1] * out1.shape[2],
        out1.shape[2],
        1,
    )

    # assert out1.stride() == out2.stride(), "Matrix A must be contiguous"

    grid = lambda META: (
        L,
        triton.cdiv(BSIZE, META["BLOCK_SIZE_BSIZE"])
        * triton.cdiv(K, META["BLOCK_SIZE_K"]),
    )

    first_pass_kernel[grid](
        hin,
        S1s,
        U2s,
        out1,
        out2,
        BSIZE,
        K,
        d2,
        L,
        stride_hin_bsize,
        stride_hin_d2,
        stride_su_l,
        stride_su_d2,
        stride_su_k,
        stride_out_l,
        stride_out_bsize,
        stride_out_k,
    )

    return out1, out2


@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_BSIZE": 32,
                "BLOCK_SIZE_D1": 32,
                "BLOCK_SIZE_K": 32,
                "GROUP_SIZE_BSIZE": 8,
            },
            num_stages=1,
            num_warps=4,
        ),
        # triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_D1': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_D1': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_D1': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_D1': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_D1': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_D1': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=5, num_warps=2),
        # triton.Config({'BLOCK_SIZE_BSIZE': 32, 'BLOCK_SIZE_D1': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=5, num_warps=2),
        # triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_D1': 128, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_BSIZE': 8}, num_stages=3, num_warps=8),
        # triton.Config({'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_D1': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=3, num_warps=8),
        # triton.Config({'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_D1': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_D1': 128, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_D1': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_D1': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_D1': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_D1': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4)
    ],
    key=["BSIZE", "d1", "K", "L"],
)
@triton.jit
def second_pass_kernel(
    in1_ptr,
    in2_ptr,
    U1s_ptr,
    S2s_ptr,
    bias_ptr,
    out_ptr,
    BSIZE,
    d1,
    K,
    L,
    stride_in12_l,
    stride_in12_bsize,
    stride_in12_k,
    stride_us_l,
    stride_us_k,
    stride_us_d1,
    stride_bias_bsize,
    stride_bias_d1,
    stride_out_bsize,
    stride_out_d1,
    BLOCK_SIZE_BSIZE: tl.constexpr,
    BLOCK_SIZE_D1: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_BSIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)

    num_pid_bsize = tl.cdiv(BSIZE, BLOCK_SIZE_BSIZE)
    num_pid_d1 = tl.cdiv(d1, BLOCK_SIZE_D1)
    num_pid_in_group = GROUP_SIZE_BSIZE * num_pid_d1
    group_id = pid // num_pid_in_group
    first_pid_bsize = group_id * GROUP_SIZE_BSIZE
    GROUP_SIZE_BSIZE = min(num_pid_bsize - first_pid_bsize, GROUP_SIZE_BSIZE)
    pid_bsize = first_pid_bsize + ((pid % num_pid_in_group) % GROUP_SIZE_BSIZE)
    pid_d1 = (pid % num_pid_in_group) // GROUP_SIZE_BSIZE

    offs_bsize = pid_bsize * BLOCK_SIZE_BSIZE + tl.arange(0, BLOCK_SIZE_BSIZE)
    offs_d1 = pid_d1 * BLOCK_SIZE_D1 + tl.arange(0, BLOCK_SIZE_D1)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    offs_bsize = tl.max_contiguous(
        tl.multiple_of(offs_bsize, BLOCK_SIZE_BSIZE), BLOCK_SIZE_BSIZE
    )
    offs_d1 = tl.max_contiguous(tl.multiple_of(offs_d1, BLOCK_SIZE_D1), BLOCK_SIZE_D1)
    offs_k = tl.max_contiguous(tl.multiple_of(offs_k, BLOCK_SIZE_K), BLOCK_SIZE_K)

    in_tmp = offs_bsize[:, None] * stride_in12_bsize + offs_k[None, :] * stride_in12_k
    us_tmp = offs_k[:, None] * stride_us_k + offs_d1[None, :] * stride_us_d1

    accumulator = tl.full(
        shape=(BLOCK_SIZE_BSIZE, BLOCK_SIZE_D1), value=0.0, dtype=tl.float32
    )

    for l in range(0, L):
        l_in_offset = l * stride_in12_l
        l_us_offset = l * stride_us_l

        in1_ptrs = in1_ptr + l_in_offset + in_tmp
        in2_ptrs = in2_ptr + l_in_offset + in_tmp

        U1s_ptrs = U1s_ptr + l_us_offset + us_tmp
        S2s_ptrs = S2s_ptr + l_us_offset + us_tmp

        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            in_mask = offs_k[None, :] < K - k * BLOCK_SIZE_K
            in1 = tl.load(in1_ptrs, mask=in_mask, other=0.0)
            in2 = tl.load(in2_ptrs, mask=in_mask, other=0.0)

            us_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
            U1s = tl.load(U1s_ptrs, mask=us_mask, other=0.0)
            S2s = tl.load(S2s_ptrs, mask=us_mask, other=0.0)

            accumulator += tl.dot(in1, U1s, input_precision="ieee")
            accumulator += tl.dot(in2, S2s, input_precision="ieee")

            in_inc = BLOCK_SIZE_K * stride_in12_k
            in1_ptrs += in_inc
            in2_ptrs += in_inc

            us_inc = BLOCK_SIZE_K * stride_us_k
            U1s_ptrs += us_inc
            S2s_ptrs += us_inc

    bias_ptrs = bias_ptr + offs_d1[None, :] * stride_bias_d1
    bias_mask = offs_d1[None, :] < d1
    bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0)

    accumulator *= 1.0 / (2.0 * L)
    accumulator += bias

    out_ptrs = (
        out_ptr
        + stride_out_bsize * offs_bsize[:, None]
        + stride_out_d1 * offs_d1[None, :]
    )
    out_mask = (offs_bsize[:, None] < BSIZE) & (offs_d1[None, :] < d1)

    tl.store(out_ptrs, accumulator, mask=out_mask)


def second_pass(in1, in2, U1s, S2s, bias):
    # assert in1.shape[2] == U1s.shape[1], "Incompatible dimensions"
    # assert in2.shape[2] == S2s.shape[1], "Incompatible dimensions"
    # assert in1.is_contiguous(), "Matrix A must be contiguous"
    # assert in2.is_contiguous(), "Matrix A must be contiguous"
    # assert U1s.is_contiguous(), "Matrix A must be contiguous"
    # assert S2s.is_contiguous(), "Matrix A must be contiguous"
    # assert bias.is_contiguous(), "Matrix A must be contiguous"
    # assert U1s.stride() == S2s.stride(), "Matrix A must be contiguous"
    # assert in1.stride() == in2.stride(), "Matrix A must be contiguous"

    L, BSIZE, K = in1.shape
    _, _, d1 = U1s.shape

    out = torch.empty((BSIZE, d1), dtype=torch.float32, device=device)

    # stride_in12_l, stride_in12_bsize, stride_in12_k = in1.stride()
    # stride_us_l, stride_us_k, stride_us_d1 = U1s.stride()
    # stride_bias_bsize, stride_bias_d1 = bias.stride()
    # stride_out_bsize, stride_out_d1 = out.stride()
    stride_in12_l, stride_in12_bsize, stride_in12_k = (
        in1.shape[1] * in1.shape[2],
        in1.shape[2],
        1,
    )
    stride_us_l, stride_us_k, stride_us_d1 = (
        U1s.shape[1] * U1s.shape[2],
        U1s.shape[2],
        1,
    )
    stride_bias_bsize, stride_bias_d1 = bias.shape[1], 1
    stride_out_bsize, stride_out_d1 = out.shape[1], 1

    grid = lambda META: (
        triton.cdiv(BSIZE, META["BLOCK_SIZE_BSIZE"])
        * triton.cdiv(d1, META["BLOCK_SIZE_D1"]),
    )

    second_pass_kernel[grid](
        in1,
        in2,
        U1s,
        S2s,
        bias,
        out,
        BSIZE,
        d1,
        K,
        L,
        stride_in12_l,
        stride_in12_bsize,
        stride_in12_k,
        stride_us_l,
        stride_us_k,
        stride_us_d1,
        stride_bias_bsize,
        stride_bias_d1,
        stride_out_bsize,
        stride_out_d1,
    )

    return out

In [67]:
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_BSIZE": 32,
                "BLOCK_SIZE_K": 32,
                "BLOCK_SIZE_d1": 32,
                "GROUP_SIZE_BSIZE": 8,
            },
            num_stages=1,
            num_warps=4,
        ),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_d1': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_d1': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_d1': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=5,
        #                num_warps=2),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=5,
        #                num_warps=2),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_d1': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=3,
        #                num_warps=8),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=3,
        #                num_warps=8),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_d1': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_d1': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4)
    ],
    key=["BSIZE", "K", "d1", "L"],
)
@triton.jit
def first_pass_gU1s_g_S2s_kernel(
    g_ptr,
    U1s_ptr,
    S2s_ptr,
    g_U1s_ptr,
    g_S2s_ptr,
    BSIZE,
    K,
    d1,
    L,
    stride_g_bsize,
    stride_g_d1,
    stride_su_l,
    stride_su_d1,
    stride_su_k,
    stride_out_l,
    stride_out_bsize,
    stride_out_k,
    BLOCK_SIZE_BSIZE: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_d1: tl.constexpr,
    GROUP_SIZE_BSIZE: tl.constexpr,
):
    pid = tl.program_id(axis=1)
    batch_id = tl.program_id(axis=0)

    num_pid_bsize = tl.cdiv(BSIZE, BLOCK_SIZE_BSIZE)
    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_BSIZE * num_pid_k
    group_id = pid // num_pid_in_group
    first_pid_bsize = group_id * GROUP_SIZE_BSIZE
    group_size_bsize = min(num_pid_bsize - first_pid_bsize, GROUP_SIZE_BSIZE)
    pid_bsize = first_pid_bsize + ((pid % num_pid_in_group) % group_size_bsize)
    pid_k = (pid % num_pid_in_group) // group_size_bsize

    offs_bsize = pid_bsize * BLOCK_SIZE_BSIZE + tl.arange(0, BLOCK_SIZE_BSIZE)
    offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_d1 = tl.arange(0, BLOCK_SIZE_d1)

    g_ptrs = g_ptr + (
        offs_bsize[:, None] * stride_g_bsize + offs_d1[None, :] * stride_g_d1
    )

    su_tmp = batch_id * stride_su_l + (
        offs_d1[:, None] * stride_su_d1 + offs_k[None, :] * stride_su_k
    )
    U1s_ptrs = U1s_ptr + su_tmp
    S2s_ptrs = S2s_ptr + su_tmp

    accumulator1 = tl.full(
        shape=(BLOCK_SIZE_BSIZE, BLOCK_SIZE_K), value=0.0, dtype=tl.float32
    )
    accumulator2 = tl.full(
        shape=(BLOCK_SIZE_BSIZE, BLOCK_SIZE_K), value=0.0, dtype=tl.float32
    )

    for d1_i in range(0, tl.cdiv(d1, BLOCK_SIZE_d1)):
        g = tl.load(
            g_ptrs, mask=(offs_d1[None, :] < d1 - d1_i * BLOCK_SIZE_d1), other=0.0
        )

        su_mask = offs_d1[:, None] < d1 - d1_i * BLOCK_SIZE_d1
        U1s = tl.load(U1s_ptrs, mask=su_mask, other=0.0)
        S2s = tl.load(S2s_ptrs, mask=su_mask, other=0.0)

        accumulator1 += tl.dot(g, U1s, input_precision="ieee")
        accumulator2 += tl.dot(g, S2s, input_precision="ieee")

        g_ptrs += BLOCK_SIZE_d1 * stride_g_d1
        U1s_ptrs += BLOCK_SIZE_d1 * stride_su_d1
        S2s_ptrs += BLOCK_SIZE_d1 * stride_su_d1

    out_tmp = (
        batch_id * stride_out_l
        + stride_out_bsize * offs_bsize[:, None]
        + stride_out_k * offs_k[None, :]
    )
    g_U1s_ptrs = g_U1s_ptr + out_tmp
    g_S2s_ptrs = g_S2s_ptr + out_tmp

    out_mask = (offs_bsize[:, None] < BSIZE) & (offs_k[None, :] < K)

    tl.store(g_U1s_ptrs, accumulator1, mask=out_mask)
    tl.store(g_S2s_ptrs, accumulator2, mask=out_mask)


def first_pass_gU1s_g_S2s(g, U1s, S2s):
    # assert g.shape[1] == U1s.shape[1], "Incompatible dimensions"
    # assert g.shape[1] == S2s.shape[1], "Incompatible dimensions"
    # assert g.is_contiguous(), "Matrix A must be contiguous"
    # assert U1s.is_contiguous(), "Matrix A must be contiguous"
    # assert S2s.is_contiguous(), "Matrix A must be contiguous"
    # assert U1s.stride() == S2s.stride(), "Matrix A must be contiguous"

    BSIZE, d1 = g.shape
    L, _, K = U1s.shape

    g_U1s = torch.empty((L, BSIZE, K), dtype=torch.float32, device="cuda")
    g_S2s = torch.empty((L, BSIZE, K), dtype=torch.float32, device="cuda")

    # stride_g_bsize, stride_g_d1 = g.stride()
    # stride_su_l, stride_su_d1, stride_su_k = U1s.stride()
    # stride_out_l, stride_out_bsize, stride_out_k = g_U1s.stride()
    stride_g_bsize, stride_g_d1 = g.shape[1], 1
    stride_su_l, stride_su_d1, stride_su_k = (
        U1s.shape[1] * U1s.shape[2],
        U1s.shape[2],
        1,
    )
    stride_out_l, stride_out_bsize, stride_out_k = (
        g_U1s.shape[1] * g_U1s.shape[2],
        g_U1s.shape[2],
        1,
    )

    # assert g_U1s.stride() == g_S2s.stride(), "Matrix A must be contiguous"

    grid = lambda META: (
        L,
        triton.cdiv(BSIZE, META["BLOCK_SIZE_BSIZE"])
        * triton.cdiv(K, META["BLOCK_SIZE_K"]),
    )

    first_pass_gU1s_g_S2s_kernel[grid](
        g,
        U1s,
        S2s,
        g_U1s,
        g_S2s,
        BSIZE,
        K,
        d1,
        L,
        stride_g_bsize,
        stride_g_d1,
        stride_su_l,
        stride_su_d1,
        stride_su_k,
        stride_out_l,
        stride_out_bsize,
        stride_out_k,
    )

    return g_U1s, g_S2s


@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_BSIZE": 32,
                "BLOCK_SIZE_d2": 32,
                "BLOCK_SIZE_K": 32,
                "GROUP_SIZE_BSIZE": 8,
            },
            num_stages=1,
            num_warps=4,
        ),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_d2': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_d2': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=5,
        #                num_warps=2),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 32, 'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=5,
        #                num_warps=2),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=3,
        #                num_warps=8),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=3,
        #                num_warps=8),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_d2': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4),
        #  triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4,
        #                num_warps=4)
    ],
    key=["BSIZE", "d2", "K", "L"],
)
@triton.jit
def second_pass_gUS11_22_kernel(
    g_U1s_ptr,
    g_S2s_ptr,
    S1s_ptr,
    U2s_ptr,
    out_ptr,
    BSIZE,
    d2,
    K,
    L,
    stride_g_U1s2_l,
    stride_g_U1s2_bsize,
    stride_g_U1s2_k,
    stride_us_l,
    stride_us_k,
    stride_us_d2,
    stride_out_bsize,
    stride_out_d2,
    BLOCK_SIZE_BSIZE: tl.constexpr,
    BLOCK_SIZE_d2: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_BSIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)

    num_pid_bsize = tl.cdiv(BSIZE, BLOCK_SIZE_BSIZE)
    num_pid_d2 = tl.cdiv(d2, BLOCK_SIZE_d2)
    num_pid_in_group = GROUP_SIZE_BSIZE * num_pid_d2
    group_id = pid // num_pid_in_group
    first_pid_bsize = group_id * GROUP_SIZE_BSIZE
    GROUP_SIZE_BSIZE = min(num_pid_bsize - first_pid_bsize, GROUP_SIZE_BSIZE)
    pid_bsize = first_pid_bsize + ((pid % num_pid_in_group) % GROUP_SIZE_BSIZE)
    pid_d2 = (pid % num_pid_in_group) // GROUP_SIZE_BSIZE

    offs_bsize = pid_bsize * BLOCK_SIZE_BSIZE + tl.arange(0, BLOCK_SIZE_BSIZE)
    offs_d2 = pid_d2 * BLOCK_SIZE_d2 + tl.arange(0, BLOCK_SIZE_d2)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    in_tmp = (
        offs_bsize[:, None] * stride_g_U1s2_bsize + offs_k[None, :] * stride_g_U1s2_k
    )
    us_tmp = offs_k[:, None] * stride_us_k + offs_d2[None, :] * stride_us_d2

    accumulator = tl.full(
        shape=(BLOCK_SIZE_BSIZE, BLOCK_SIZE_d2), value=0.0, dtype=tl.float32
    )

    for l in range(0, L):
        g_l_offset = l * stride_g_U1s2_l  # Offset for g_U1s and g_S2s
        s_l_offset = l * stride_us_l  # Offset for S1s and U2s

        g_U1s_ptrs = g_U1s_ptr + g_l_offset + in_tmp
        g_S2s_ptrs = g_S2s_ptr + g_l_offset + in_tmp

        S1s_ptrs = S1s_ptr + s_l_offset + us_tmp
        U2s_ptrs = U2s_ptr + s_l_offset + us_tmp

        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            in_mask = offs_k[None, :] < K - k * BLOCK_SIZE_K
            g_U1s = tl.load(g_U1s_ptrs, mask=in_mask, other=0.0)
            g_S2s = tl.load(g_S2s_ptrs, mask=in_mask, other=0.0)

            us_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
            S1s = tl.load(S1s_ptrs, mask=us_mask, other=0.0)
            U2s = tl.load(U2s_ptrs, mask=us_mask, other=0.0)

            accumulator += tl.dot(g_U1s, S1s, input_precision="ieee")
            accumulator += tl.dot(g_S2s, U2s, input_precision="ieee")

            in_inc = BLOCK_SIZE_K * stride_g_U1s2_k
            g_U1s_ptrs += in_inc
            g_S2s_ptrs += in_inc

            us_inc = BLOCK_SIZE_K * stride_us_k
            S1s_ptrs += us_inc
            U2s_ptrs += us_inc

    accumulator *= 1.0 / (2.0 * L)

    out_ptrs = (
        out_ptr
        + stride_out_bsize * offs_bsize[:, None]
        + stride_out_d2 * offs_d2[None, :]
    )
    out_mask = (offs_bsize[:, None] < BSIZE) & (offs_d2[None, :] < d2)

    tl.store(out_ptrs, accumulator, mask=out_mask)


def second_pass_gUS11_22(g_U1s, g_S2s, S1s, U2s):
    # assert g_U1s.shape[2] == S1s.shape[1], "Incompatible dimensions"
    # assert g_S2s.shape[2] == U2s.shape[1], "Incompatible dimensions"
    # assert g_U1s.is_contiguous(), "Matrix A must be contiguous"
    # assert g_S2s.is_contiguous(), "Matrix A must be contiguous"
    # assert S1s.is_contiguous(), "Matrix A must be contiguous"
    # assert U2s.is_contiguous(), "Matrix A must be contiguous"
    # assert S1s.stride() == U2s.stride(), "Matrix A must be contiguous"
    # assert g_U1s.stride() == g_S2s.stride(), "Matrix A must be contiguous"

    L, BSIZE, K = g_U1s.shape
    _, _, d2 = S1s.shape

    out = torch.empty((BSIZE, d2), dtype=torch.float32, device="cuda")

    # stride_g_U1s2_l, stride_g_U1s2_bsize, stride_g_U1s2_k = g_U1s.stride()
    # stride_us_l, stride_us_k, stride_us_d2 = S1s.stride()
    # stride_out_bsize, stride_out_d2 = out.stride()
    stride_g_U1s2_l, stride_g_U1s2_bsize, stride_g_U1s2_k = (
        g_U1s.shape[1] * g_U1s.shape[2],
        g_U1s.shape[2],
        1,
    )
    stride_us_l, stride_us_k, stride_us_d2 = (
        S1s.shape[1] * S1s.shape[2],
        S1s.shape[2],
        1,
    )
    stride_out_bsize, stride_out_d2 = out.shape[1], 1

    grid = lambda META: (
        triton.cdiv(BSIZE, META["BLOCK_SIZE_BSIZE"])
        * triton.cdiv(d2, META["BLOCK_SIZE_d2"]),
    )

    second_pass_gUS11_22_kernel[grid](
        g_U1s,
        g_S2s,
        S1s,
        U2s,
        out,
        BSIZE,
        d2,
        K,
        L,
        stride_g_U1s2_l,
        stride_g_U1s2_bsize,
        stride_g_U1s2_k,
        stride_us_l,
        stride_us_k,
        stride_us_d2,
        stride_out_bsize,
        stride_out_d2,
    )

    return out  # grad


@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_d2": 32,
                "BLOCK_SIZE_k": 32,
                "BLOCK_SIZE_BSIZE": 32,
                "GROUP_SIZE_d2": 8,
            },
            num_stages=1,
            num_warps=4,
        ),
        # triton.Config({'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_k': 256, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_k': 128, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_k': 64, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_k': 128, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_k': 32, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_k': 32, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_d2': 8}, num_stages=5, num_warps=2),
        # triton.Config({'BLOCK_SIZE_d2': 32, 'BLOCK_SIZE_k': 64, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_d2': 8}, num_stages=5, num_warps=2),
        # triton.Config({'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_k': 256, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_d2': 8}, num_stages=3, num_warps=8),
        # triton.Config({'BLOCK_SIZE_d2': 256, 'BLOCK_SIZE_k': 128, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_d2': 8}, num_stages=3, num_warps=8),
        # triton.Config({'BLOCK_SIZE_d2': 256, 'BLOCK_SIZE_k': 64, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_k': 256, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_k': 128, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_k': 64, 'BLOCK_SIZE_BSIZE': 64, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_k': 128, 'BLOCK_SIZE_BSIZE': 64, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_k': 32, 'BLOCK_SIZE_BSIZE': 64, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4)
    ],
    key=["d2", "k", "BSIZE", "L"],
)
@triton.jit
def calc_grad_S1s_kernel(
    hin_ptr,
    g_U1s_ptr,
    grad_g_S1s_ptr,
    d2,
    k,
    BSIZE,
    L,
    stride_hin_bsize,
    stride_hin_BSIZE,
    stride_su_l,
    stride_su_BSIZE,
    stride_su_k,
    stride_out_l,
    stride_out_bsize,
    stride_out_k,
    BLOCK_SIZE_d2: tl.constexpr,
    BLOCK_SIZE_k: tl.constexpr,
    BLOCK_SIZE_BSIZE: tl.constexpr,
    GROUP_SIZE_d2: tl.constexpr,
):
    pid = tl.program_id(axis=1)
    batch_id = tl.program_id(axis=0)

    num_pid_bsize = tl.cdiv(d2, BLOCK_SIZE_d2)
    num_pid_k = tl.cdiv(k, BLOCK_SIZE_k)
    num_pid_in_group = GROUP_SIZE_d2 * num_pid_k
    group_id = pid // num_pid_in_group
    first_pid_bsize = group_id * GROUP_SIZE_d2
    group_size_bsize = min(num_pid_bsize - first_pid_bsize, GROUP_SIZE_d2)
    pid_bsize = first_pid_bsize + ((pid % num_pid_in_group) % group_size_bsize)
    pid_k = (pid % num_pid_in_group) // group_size_bsize

    offs_bsize = pid_bsize * BLOCK_SIZE_d2 + tl.arange(0, BLOCK_SIZE_d2)
    offs_k = pid_k * BLOCK_SIZE_k + tl.arange(0, BLOCK_SIZE_k)
    offs_BSIZE = tl.arange(0, BLOCK_SIZE_BSIZE)

    offs_bsize = tl.max_contiguous(
        tl.multiple_of(offs_bsize, BLOCK_SIZE_d2), BLOCK_SIZE_d2
    )
    offs_k = tl.max_contiguous(tl.multiple_of(offs_k, BLOCK_SIZE_k), BLOCK_SIZE_k)
    offs_BSIZE = tl.max_contiguous(
        tl.multiple_of(offs_BSIZE, BLOCK_SIZE_BSIZE), BLOCK_SIZE_BSIZE
    )

    hin_ptrs = hin_ptr + (
        offs_bsize[:, None] * stride_hin_bsize + offs_BSIZE[None, :] * stride_hin_BSIZE
    )

    su_tmp = batch_id * stride_su_l + (
        offs_BSIZE[:, None] * stride_su_BSIZE + offs_k[None, :] * stride_su_k
    )
    g_U1s_ptrs = g_U1s_ptr + su_tmp

    accumulator1 = tl.full(
        shape=(BLOCK_SIZE_d2, BLOCK_SIZE_k), value=0.0, dtype=tl.float32
    )
    accumulator2 = tl.full(
        shape=(BLOCK_SIZE_d2, BLOCK_SIZE_k), value=0.0, dtype=tl.float32
    )

    for BSIZE_i in range(0, tl.cdiv(BSIZE, BLOCK_SIZE_BSIZE)):
        hin_mask = (offs_bsize[:, None] < d2) & (
            offs_BSIZE[None, :] < BSIZE - BSIZE_i * BLOCK_SIZE_BSIZE
        )
        hin = tl.load(hin_ptrs, mask=hin_mask, other=0.0)

        su_mask = (offs_BSIZE[:, None] < BSIZE - BSIZE_i * BLOCK_SIZE_BSIZE) & (
            offs_k[None, :] < k
        )
        g_U1s = tl.load(g_U1s_ptrs, mask=su_mask, other=0.0)

        accumulator1 += tl.dot(hin, g_U1s, input_precision="ieee")

        hin_ptrs += BLOCK_SIZE_BSIZE * stride_hin_BSIZE
        g_U1s_ptrs += BLOCK_SIZE_BSIZE * stride_su_BSIZE

    out_tmp = (
        batch_id * stride_out_l
        + stride_out_bsize * offs_bsize[:, None]
        + stride_out_k * offs_k[None, :]
    )
    grad_g_S1s_ptrs = grad_g_S1s_ptr + out_tmp

    out_mask = (offs_bsize[:, None] < d2) & (offs_k[None, :] < k)

    tl.store(grad_g_S1s_ptrs, accumulator1, mask=out_mask)


def calc_grad_S1s(hin, g_U1s):
    device = "cuda"
    # assert hin.shape[1] == g_U1s.shape[1], "Incompatible dimensions"
    # assert hin.is_contiguous(), "Matrix A must be contiguous"
    # assert g_U1s.is_contiguous(), "Matrix A must be contiguous"

    d2, BSIZE = hin.shape
    L, _, k = g_U1s.shape

    grad_g_S1s = torch.empty((L, d2, k), dtype=torch.float32, device=device)

    # stride_hin_bsize, stride_hin_BSIZE = hin.stride()
    # stride_su_l, stride_su_BSIZE, stride_su_k = g_U1s.stride()
    # stride_out_l, stride_out_bsize, stride_out_k = grad_g_S1s.stride()
    stride_hin_bsize, stride_hin_BSIZE = hin.shape[1], 1
    stride_su_l, stride_su_BSIZE, stride_su_k = (
        g_U1s.shape[1] * g_U1s.shape[2],
        g_U1s.shape[2],
        1,
    )
    stride_out_l, stride_out_bsize, stride_out_k = (
        grad_g_S1s.shape[1] * grad_g_S1s.shape[2],
        grad_g_S1s.shape[2],
        1,
    )

    grid = lambda META: (
        L,
        triton.cdiv(d2, META["BLOCK_SIZE_d2"]) * triton.cdiv(k, META["BLOCK_SIZE_k"]),
    )

    calc_grad_S1s_kernel[grid](
        hin,
        g_U1s,
        grad_g_S1s,
        d2,
        k,
        BSIZE,
        L,
        stride_hin_bsize,
        stride_hin_BSIZE,
        stride_su_l,
        stride_su_BSIZE,
        stride_su_k,
        stride_out_l,
        stride_out_bsize,
        stride_out_k,
    )

    return grad_g_S1s


@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_K": 32,
                "BLOCK_SIZE_BSIZE": 32,
                "BLOCK_SIZE_d2": 32,
                "GROUP_SIZE_K": 8,
            },
            num_stages=1,
            num_warps=4,
        ),
        # triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_d2': 32, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 32, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_d2': 32, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 32, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_BSIZE': 32, 'BLOCK_SIZE_d2': 32, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_BSIZE': 32, 'BLOCK_SIZE_d2': 32, 'GROUP_SIZE_K': 8}, num_stages=5, num_warps=2),
        # triton.Config({'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_d2': 32, 'GROUP_SIZE_K': 8}, num_stages=5, num_warps=2),
        # triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_d2': 128, 'GROUP_SIZE_K': 8}, num_stages=3, num_warps=8),
        # triton.Config({'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 128, 'GROUP_SIZE_K': 8}, num_stages=3, num_warps=8),
        # triton.Config({'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_d2': 128, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_d2': 128, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 128, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_d2': 64, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 64, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_BSIZE': 32, 'BLOCK_SIZE_d2': 64, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4)
    ],
    key=["K", "d2", "BSIZE", "L"],
)
@triton.jit
def first_pass_U2s_hin_kernel(
    hin_ptr,
    U2s_ptr,
    U2s_h_in_ptr,
    K,
    d2,
    BSIZE,
    L,
    stride_hin_d2,
    stride_hin_BSIZE,
    stride_su_l,
    stride_su_K,
    stride_su_d2,
    stride_out_l,
    stride_out_K,
    stride_out_BSIZE,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_BSIZE: tl.constexpr,
    BLOCK_SIZE_d2: tl.constexpr,
    GROUP_SIZE_K: tl.constexpr,
):
    pid = tl.program_id(axis=1)
    batch_id = tl.program_id(axis=0)

    num_pid_K = tl.cdiv(K, BLOCK_SIZE_K)
    num_pid_BSIZE = tl.cdiv(BSIZE, BLOCK_SIZE_BSIZE)
    num_pid_in_group = GROUP_SIZE_K * num_pid_BSIZE
    group_id = pid // num_pid_in_group
    first_pid_K = group_id * GROUP_SIZE_K
    group_size_BSIZE = min(num_pid_K - first_pid_K, GROUP_SIZE_K)
    pid_K = first_pid_K + ((pid % num_pid_in_group) % group_size_BSIZE)
    pid_BSIZE = (pid % num_pid_in_group) // group_size_BSIZE

    offs_K = pid_K * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_BSIZE = pid_BSIZE * BLOCK_SIZE_BSIZE + tl.arange(0, BLOCK_SIZE_BSIZE)
    offs_d2 = tl.arange(0, BLOCK_SIZE_d2)

    offs_K = tl.max_contiguous(tl.multiple_of(offs_K, BLOCK_SIZE_K), BLOCK_SIZE_K)
    offs_BSIZE = tl.max_contiguous(
        tl.multiple_of(offs_BSIZE, BLOCK_SIZE_BSIZE), BLOCK_SIZE_BSIZE
    )
    offs_d2 = tl.max_contiguous(tl.multiple_of(offs_d2, BLOCK_SIZE_d2), BLOCK_SIZE_d2)

    hin_ptrs = hin_ptr + (
        offs_d2[:, None] * stride_hin_d2 + offs_BSIZE[None, :] * stride_hin_BSIZE
    )

    su_tmp = batch_id * stride_su_l + (
        offs_K[:, None] * stride_su_K + offs_d2[None, :] * stride_su_d2
    )
    U2s_ptrs = U2s_ptr + su_tmp

    accumulator1 = tl.full(
        shape=(BLOCK_SIZE_K, BLOCK_SIZE_BSIZE), value=0.0, dtype=tl.float32
    )

    for d2_i in range(0, tl.cdiv(d2, BLOCK_SIZE_d2)):
        hin_mask = (offs_d2[:, None] < d2 - d2_i * BLOCK_SIZE_d2) & (
            offs_BSIZE[None, :] < BSIZE
        )
        hin = tl.load(hin_ptrs, mask=hin_mask, other=0.0)

        su_mask = (offs_K[:, None] < K) & (offs_d2[None, :] < d2 - d2_i * BLOCK_SIZE_d2)
        U2s = tl.load(U2s_ptrs, mask=su_mask, other=0.0)

        accumulator1 += tl.dot(U2s, hin, input_precision="ieee")

        hin_ptrs += BLOCK_SIZE_d2 * stride_hin_d2
        U2s_ptrs += BLOCK_SIZE_d2 * stride_su_d2

    out_tmp = (
        batch_id * stride_out_l
        + stride_out_K * offs_K[:, None]
        + stride_out_BSIZE * offs_BSIZE[None, :]
    )
    U2s_h_in_ptrs = U2s_h_in_ptr + out_tmp

    out_mask = (offs_K[:, None] < K) & (offs_BSIZE[None, :] < BSIZE)

    tl.store(U2s_h_in_ptrs, accumulator1, mask=out_mask)


def first_pass_U2s_hin(U2s, hin):
    device = "cuda"
    # assert U2s.shape[2] == hin.shape[0], "Incompatible dimensions"
    # assert hin.is_contiguous(), "Matrix A must be contiguous"
    # assert U2s.is_contiguous(), "Matrix A must be contiguous"

    L, K, d2 = U2s.shape
    _, BSIZE = hin.shape

    U2s_h_in = torch.empty((L, K, BSIZE), dtype=torch.float32, device=device)

    # stride_hin_d2, stride_hin_BSIZE = hin.stride()
    # stride_su_l, stride_su_K, stride_su_d2 = U2s.stride()
    # stride_out_l, stride_out_K, stride_out_BSIZE = U2s_h_in.stride()
    stride_hin_d2, stride_hin_BSIZE = hin.shape[1], 1
    stride_su_l, stride_su_K, stride_su_d2 = (
        U2s.shape[1] * U2s.shape[2],
        U2s.shape[2],
        1,
    )
    stride_out_l, stride_out_K, stride_out_BSIZE = (
        U2s_h_in.shape[1] * U2s_h_in.shape[2],
        U2s_h_in.shape[2],
        1,
    )

    BLOCK_SIZE_K, BLOCK_SIZE_BSIZE, BLOCK_SIZE_d2 = 128, 256, 64
    GROUP_SIZE_K = 8

    grid = lambda META: (
        L,
        triton.cdiv(K, META["BLOCK_SIZE_K"])
        * triton.cdiv(BSIZE, META["BLOCK_SIZE_BSIZE"]),
    )

    first_pass_U2s_hin_kernel[grid](
        hin,
        U2s,
        U2s_h_in,
        K,
        d2,
        BSIZE,
        L,
        stride_hin_d2,
        stride_hin_BSIZE,
        stride_su_l,
        stride_su_K,
        stride_su_d2,
        stride_out_l,
        stride_out_K,
        stride_out_BSIZE,
    )

    return U2s_h_in


@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_K": 32,
                "BLOCK_SIZE_d1": 32,
                "BLOCK_SIZE_BSIZE": 32,
                "GROUP_SIZE_K": 8,
            },
            num_stages=1,
            num_warps=4,
        ),
        # triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 256, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 128, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 64, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 128, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 32, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 32, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_K': 8}, num_stages=5, num_warps=2),
        # triton.Config({'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_d1': 64, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_K': 8}, num_stages=5, num_warps=2),
        # triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 256, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_K': 8}, num_stages=3, num_warps=8),
        # triton.Config({'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_d1': 128, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_K': 8}, num_stages=3, num_warps=8),
        # triton.Config({'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_d1': 64, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 256, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 128, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 64, 'BLOCK_SIZE_BSIZE': 64, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 128, 'BLOCK_SIZE_BSIZE': 64, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 32, 'BLOCK_SIZE_BSIZE': 64, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4)
    ],
    key=["K", "BSIZE", "d1", "L"],
)
@triton.jit
def calc_grad_S2s_kernel(
    g_ptr,
    U2s_hin_ptr,
    grad_S2s_ptr,
    K,
    BSIZE,
    d1,
    L,
    stride_g_BSIZE,
    stride_g_d1,
    stride_su_l,
    stride_su_K,
    stride_su_BSIZE,
    stride_out_l,
    stride_out_K,
    stride_out_d1,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_d1: tl.constexpr,
    BLOCK_SIZE_BSIZE: tl.constexpr,
    GROUP_SIZE_K: tl.constexpr,
):
    pid = tl.program_id(axis=1)
    batch_id = tl.program_id(axis=0)

    num_pid_K = tl.cdiv(K, BLOCK_SIZE_K)
    num_pid_d1 = tl.cdiv(d1, BLOCK_SIZE_d1)
    num_pid_in_group = GROUP_SIZE_K * num_pid_d1
    group_id = pid // num_pid_in_group
    first_pid_K = group_id * GROUP_SIZE_K
    group_size_d1 = min(num_pid_K - first_pid_K, GROUP_SIZE_K)
    pid_K = first_pid_K + ((pid % num_pid_in_group) % group_size_d1)
    pid_d1 = (pid % num_pid_in_group) // group_size_d1

    offs_K = pid_K * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_d1 = pid_d1 * BLOCK_SIZE_d1 + tl.arange(0, BLOCK_SIZE_d1)
    offs_BSIZE = tl.arange(0, BLOCK_SIZE_BSIZE)

    offs_K = tl.max_contiguous(tl.multiple_of(offs_K, BLOCK_SIZE_K), BLOCK_SIZE_K)
    offs_d1 = tl.max_contiguous(tl.multiple_of(offs_d1, BLOCK_SIZE_d1), BLOCK_SIZE_d1)
    offs_BSIZE = tl.max_contiguous(
        tl.multiple_of(offs_BSIZE, BLOCK_SIZE_BSIZE), BLOCK_SIZE_BSIZE
    )

    g_ptrs = g_ptr + (
        offs_BSIZE[:, None] * stride_g_BSIZE + offs_d1[None, :] * stride_g_d1
    )

    su_tmp = batch_id * stride_su_l + (
        offs_K[:, None] * stride_su_K + offs_BSIZE[None, :] * stride_su_BSIZE
    )
    U2s_hin_ptrs = U2s_hin_ptr + su_tmp

    accumulator1 = tl.full(
        shape=(BLOCK_SIZE_K, BLOCK_SIZE_d1), value=0.0, dtype=tl.float32
    )

    for BSIZE_i in range(0, tl.cdiv(BSIZE, BLOCK_SIZE_BSIZE)):
        g_mask = (offs_BSIZE[:, None] < BSIZE - BSIZE_i * BLOCK_SIZE_BSIZE) & (
            offs_d1[None, :] < d1
        )
        g = tl.load(g_ptrs, mask=g_mask, other=0.0)

        su_mask = (offs_K[:, None] < K) & (
            offs_BSIZE[None, :] < BSIZE - BSIZE_i * BLOCK_SIZE_BSIZE
        )
        U2s_hin = tl.load(U2s_hin_ptrs, mask=su_mask, other=0.0)

        accumulator1 += tl.dot(U2s_hin, g, input_precision="ieee")

        g_ptrs += BLOCK_SIZE_BSIZE * stride_g_BSIZE
        U2s_hin_ptrs += BLOCK_SIZE_BSIZE * stride_su_BSIZE

    out_tmp = (
        batch_id * stride_out_l
        + stride_out_K * offs_K[:, None]
        + stride_out_d1 * offs_d1[None, :]
    )
    grad_S2s_ptrs = grad_S2s_ptr + out_tmp

    out_mask = (offs_K[:, None] < K) & (offs_d1[None, :] < d1)

    tl.store(grad_S2s_ptrs, accumulator1, mask=out_mask)


def calc_grad_S2s(U2s_hin, g):
    device = "cuda"
    # assert U2s_hin.shape[2] == g.shape[0], "Incompatible dimensions"
    # assert g.is_contiguous(), "Matrix A must be contiguous"
    # assert U2s_hin.is_contiguous(), "Matrix A must be contiguous"

    L, K, BSIZE = U2s_hin.shape
    _, d1 = g.shape

    grad_S2s = torch.empty((L, K, d1), dtype=torch.float32, device=device)

    # stride_g_BSIZE, stride_g_d1 = g.stride()
    # stride_su_l, stride_su_K, stride_su_BSIZE = U2s_hin.stride()
    # stride_out_l, stride_out_K, stride_out_d1 = grad_S2s.stride()
    stride_g_BSIZE, stride_g_d1 = g.shape[1], 1
    stride_su_l, stride_su_K, stride_su_BSIZE = (
        U2s_hin.shape[1] * U2s_hin.shape[2],
        U2s_hin.shape[2],
        1,
    )
    stride_out_l, stride_out_K, stride_out_d1 = (
        grad_S2s.shape[1] * grad_S2s.shape[2],
        grad_S2s.shape[2],
        1,
    )

    BLOCK_SIZE_K, BLOCK_SIZE_d1, BLOCK_SIZE_BSIZE = 128, 256, 64
    GROUP_SIZE_K = 8

    grid = lambda META: (
        L,
        triton.cdiv(K, META["BLOCK_SIZE_K"]) * triton.cdiv(d1, META["BLOCK_SIZE_d1"]),
    )

    calc_grad_S2s_kernel[grid](
        g,
        U2s_hin,
        grad_S2s,
        K,
        BSIZE,
        d1,
        L,
        stride_g_BSIZE,
        stride_g_d1,
        stride_su_l,
        stride_su_K,
        stride_su_BSIZE,
        stride_out_l,
        stride_out_K,
        stride_out_d1,
    )

    return grad_S2s

In [31]:
import torch

# DISCLAIMER: THIS FILE NEEDS TO BE CHECKED FOR CORRECTNESS


def uniform_dense_sketch(m, n, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    return torch.empty(m, n, **factory_kwargs).uniform_(-1, 1)


def gaussian_dense_sketch(m, n, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    return torch.randn(m, n, **factory_kwargs)


def hadamard_sketch(m, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    if m & (m - 1) != 0:
        raise ValueError("m must be a power of 2")

    H = torch.tensor([[1.0]])
    while H.shape[0] < m:
        H = torch.cat((torch.cat((H, H), dim=1), torch.cat((H, -H), dim=1)), dim=0)

    return H / torch.sqrt(torch.tensor(m, **factory_kwargs))


def gaussian_orthonormal_sketch(m, n, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    return torch.qr(torch.randn(m, n, **factory_kwargs))[0]


def gen_U(m, n, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    return (torch.randint(0, 2, (m, n), **factory_kwargs) * 2 - 1) / torch.sqrt(
        torch.tensor(m, **factory_kwargs)
    )


def clarkson_woodruff_sketch(m, n, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    indices = torch.randint(0, m, (n,), **factory_kwargs)
    signs = torch.randint(0, 2, (n,), **factory_kwargs) * 2 - 1
    sketch = torch.zeros(m, n, **factory_kwargs)
    sketch[indices, torch.arange(n)] = signs
    return sketch


def sparse_sign_embeddings_sketch(m, n, sparsity=0.1):
    mask = torch.rand(m, n) < sparsity
    signs = torch.randint(0, 2, (m, n)) * 2 - 1
    return mask.float() * signs.float()

In [57]:
import math
from typing import Any, Tuple, List, Optional

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.nn import init
import triton
from torch.library import triton_op, wrap_triton


@triton_op("panther::forward_op", mutates_args={})
def forward_op(
    hin: torch.Tensor,
    S1s: torch.Tensor,
    S2s: torch.Tensor,
    U1s: torch.Tensor,
    U2s: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    device = "cuda"

    # first pass
    ############
    L = S2s.shape[0]
    BSIZE, d2 = hin.shape
    L, _, K = S1s.shape

    in1 = torch.empty((L, BSIZE, K), dtype=torch.float32, device=device)
    in2 = torch.empty((L, BSIZE, K), dtype=torch.float32, device=device)

    stride_hin_bsize, stride_hin_d2 = hin.stride(0), hin.stride(1)
    stride_su_l, stride_su_d2, stride_su_k = S1s.stride(0), S1s.stride(1), S1s.stride(2)
    stride_out_l, stride_out_bsize, stride_out_k = (
        in1.stride(0),
        in1.stride(1),
        in1.stride(2),
    )

    grid = lambda META: (
        L,
        triton.cdiv(BSIZE, META["BLOCK_SIZE_BSIZE"])
        * triton.cdiv(K, META["BLOCK_SIZE_K"]),
    )

    wrap_triton(first_pass_kernel)[grid](
        hin,
        S1s,
        U2s,
        in1,
        in2,
        BSIZE,
        K,
        d2,
        L,
        stride_hin_bsize,
        stride_hin_d2,
        stride_su_l,
        stride_su_d2,
        stride_su_k,
        stride_out_l,
        stride_out_bsize,
        stride_out_k,
    )

    # torch equivlant
    num_terms = S2s.shape[0]
    input_torch = hin.unsqueeze(0).expand(num_terms, hin.shape[0], hin.shape[1])
    in1_torch = input_torch.bmm(S1s)
    in2_torch = input_torch.bmm(U2s)

    # compare using torch allclose
    outputs_match = torch.allclose(in1_torch, in1)
    print(f"forward in1 and in1_torch Outputs match: {outputs_match}")
    if not outputs_match:
        max_diff = torch.max(torch.abs(in1_torch - in1))
        print(f"Max difference: {max_diff.item()}")
        print(f"out_triton: {in1_torch}")
        print(f"out_normal: {in1}")

    outputs_match = torch.allclose(in2_torch, in2)
    print(f"forward in2 and in2_torch Outputs match: {outputs_match}")
    if not outputs_match:
        max_diff = torch.max(torch.abs(in2_torch - in2))
        print(f"Max difference: {max_diff.item()}")
        print(f"out_triton: {in2_torch}")
        print(f"out_normal: {in2}")

    # second pass
    #############
    bias_unsqueezed = bias.unsqueeze(0)
    L, BSIZE, K = in1.shape
    _, _, d1 = U1s.shape

    out = torch.empty((BSIZE, d1), dtype=torch.float32, device=device)

    stride_in12_l, stride_in12_bsize, stride_in12_k = (
        in1.stride(0),
        in1.stride(1),
        in1.stride(2),
    )
    stride_us_l, stride_us_k, stride_us_d1 = U1s.stride(0), U1s.stride(1), U1s.stride(2)
    stride_bias_bsize, stride_bias_d1 = (
        bias_unsqueezed.stride(0),
        bias_unsqueezed.stride(1),
    )
    stride_out_bsize, stride_out_d1 = out.stride(0), out.stride(1)

    grid = lambda META: (
        triton.cdiv(BSIZE, META["BLOCK_SIZE_BSIZE"])
        * triton.cdiv(d1, META["BLOCK_SIZE_D1"]),
    )

    wrap_triton(second_pass_kernel)[grid](
        in1,
        in2,
        U1s,
        S2s,
        bias_unsqueezed,
        out,
        BSIZE,
        d1,
        K,
        L,
        stride_in12_l,
        stride_in12_bsize,
        stride_in12_k,
        stride_us_l,
        stride_us_k,
        stride_us_d1,
        stride_bias_bsize,
        stride_bias_d1,
        stride_out_bsize,
        stride_out_d1,
    )

    torch_out = (
        ((input_torch.bmm(S1s)).bmm(U1s)).mean(0) / 2
        + ((input_torch.bmm(U2s)).bmm(S2s)).mean(0) / 2
        + bias
    )
    # compare using torch allclose
    outputs_match = torch.allclose(torch_out, out)
    print(f"forward out and torch_out Outputs match: {outputs_match}")
    if not outputs_match:
        max_diff = torch.max(torch.abs(torch_out - out))
        print(f"Max difference: {max_diff.item()}")
        print(f"torch_out: {torch_out}")
        print(f"out: {out}")

    return out


@forward_op.register_kernel("cpu")
def _(input, S1s, S2s, U1s, U2s, bias):
    num_terms = S2s.shape[0]
    # Efficiently perform the sum over all l terms
    input = input.unsqueeze(0).expand(num_terms, input.shape[0], input.shape[1])
    return (
        ((input.bmm(S1s)).bmm(U1s)).mean(0) / 2
        + ((input.bmm(U2s)).bmm(S2s)).mean(0) / 2
        + bias
    )


@triton_op("panther::backward_op", mutates_args={})
def backward_op(
    hin: torch.Tensor,
    S1s: torch.Tensor,
    S2s: torch.Tensor,
    U1s: torch.Tensor,
    U2s: torch.Tensor,
    g: torch.Tensor,
) -> List[torch.Tensor]:
    device = "cuda"
    num_terms = S2s.shape[0]

    hin = hin.transpose(0, 1)
    U1s = U1s.transpose(1, 2)
    S1s = S1s.transpose(1, 2)
    U2s = U2s.transpose(1, 2)
    S2s = S2s.transpose(1, 2)

    # first_pass_gU1s_g_S2s
    #######################
    BSIZE, d1 = g.shape
    L, _, K = U1s.shape

    g_U1s = torch.empty((L, BSIZE, K), dtype=torch.float32, device="cuda")
    g_S2s = torch.empty((L, BSIZE, K), dtype=torch.float32, device="cuda")

    stride_g_bsize, stride_g_d1 = g.stride(0), g.stride(1)
    stride_su_l, stride_su_d1, stride_su_k = U1s.stride(0), U1s.stride(1), U1s.stride(2)
    stride_out_l, stride_out_bsize, stride_out_k = (
        g_U1s.stride(0),
        g_U1s.stride(1),
        g_U1s.stride(2),
    )

    grid = lambda META: (
        L,
        triton.cdiv(BSIZE, META["BLOCK_SIZE_BSIZE"])
        * triton.cdiv(K, META["BLOCK_SIZE_K"]),
    )

    wrap_triton(first_pass_gU1s_g_S2s_kernel)[grid](
        g,
        U1s,
        S2s,
        g_U1s,
        g_S2s,
        BSIZE,
        K,
        d1,
        L,
        stride_g_bsize,
        stride_g_d1,
        stride_su_l,
        stride_su_d1,
        stride_su_k,
        stride_out_l,
        stride_out_bsize,
        stride_out_k,
    )

    num_terms = S2s.shape[0]
    g_torch_unsqueezed = g.unsqueeze(0).expand(num_terms, g.shape[0], g.shape[1])
    g_U1s_torch = g_torch_unsqueezed.bmm(U1s)
    g_S2s_torch = g_torch_unsqueezed.bmm(S2s)

    # check if the outputs match
    outputs_match = torch.allclose(g_U1s_torch, g_U1s)
    print(f"backward g_U1s and g_U1s_torch Outputs match: {outputs_match}")
    if not outputs_match:
        max_diff = torch.max(torch.abs(g_U1s_torch - g_U1s))
        print(f"Max difference: {max_diff.item()}")
        print(f"g_U1s_torch: {g_U1s_torch}")
        print(f"g_U1s: {g_U1s}")

    outputs_match = torch.allclose(g_S2s_torch, g_S2s)
    print(f"backward g_S2s and g_S2s_torch Outputs match: {outputs_match}")
    if not outputs_match:
        max_diff = torch.max(torch.abs(g_S2s_torch - g_S2s))
        print(f"Max difference: {max_diff.item()}")
        print(f"g_S2s_torch: {g_S2s_torch}")
        print(f"g_S2s: {g_S2s}")

    # second_pass_gUS11_22
    #######################
    L, BSIZE, K = g_U1s.shape
    _, _, d2 = S1s.shape

    grad = torch.empty((BSIZE, d2), dtype=torch.float32, device="cuda")

    stride_g_U1s2_l, stride_g_U1s2_bsize, stride_g_U1s2_k = (
        g_U1s.stride(0),
        g_U1s.stride(1),
        g_U1s.stride(2),
    )
    stride_us_l, stride_us_k, stride_us_d2 = S1s.stride(0), S1s.stride(1), S1s.stride(2)
    stride_out_bsize, stride_out_d2 = grad.stride(0), grad.stride(1)

    grid = lambda META: (
        triton.cdiv(BSIZE, META["BLOCK_SIZE_BSIZE"])
        * triton.cdiv(d2, META["BLOCK_SIZE_d2"]),
    )

    wrap_triton(second_pass_gUS11_22_kernel)[grid](
        g_U1s,
        g_S2s,
        S1s,
        U2s,
        grad,
        BSIZE,
        d2,
        K,
        L,
        stride_g_U1s2_l,
        stride_g_U1s2_bsize,
        stride_g_U1s2_k,
        stride_us_l,
        stride_us_k,
        stride_us_d2,
        stride_out_bsize,
        stride_out_d2,
    )

    grad_torch = (
        g_torch_unsqueezed.bmm(U1s).bmm(S1s).sum(0)
        + g_torch_unsqueezed.bmm(S2s).bmm(U2s).sum(0)
    ) / (2 * num_terms)

    outputs_match = torch.allclose(grad_torch, grad)
    print(f"backward grad and grad_torch Outputs match: {outputs_match}")
    if not outputs_match:
        max_diff = torch.max(torch.abs(grad_torch - grad))
        print(f"Max difference: {max_diff.item()}")
        print(f"grad_torch: {grad_torch}")
        print(f"grad: {grad}")

    # calc_grad_S1s
    ################
    d2, BSIZE = hin.shape
    L, _, k = g_U1s.shape

    grad_S1s = torch.empty((L, d2, k), dtype=torch.float32, device=device)

    stride_hin_bsize, stride_hin_BSIZE = hin.stride(0), hin.stride(1)
    stride_su_l, stride_su_BSIZE, stride_su_k = (
        g_U1s.stride(0),
        g_U1s.stride(1),
        g_U1s.stride(2),
    )
    stride_out_l, stride_out_bsize, stride_out_k = (
        grad_S1s.stride(0),
        grad_S1s.stride(1),
        grad_S1s.stride(2),
    )

    grid = lambda META: (
        L,
        triton.cdiv(d2, META["BLOCK_SIZE_d2"]) * triton.cdiv(k, META["BLOCK_SIZE_k"]),
    )

    wrap_triton(calc_grad_S1s_kernel)[grid](
        hin,
        g_U1s,
        grad_S1s,
        d2,
        k,
        BSIZE,
        L,
        stride_hin_bsize,
        stride_hin_BSIZE,
        stride_su_l,
        stride_su_BSIZE,
        stride_su_k,
        stride_out_l,
        stride_out_bsize,
        stride_out_k,
    )

    input_torch = hin.unsqueeze(0).expand(num_terms, hin.shape[0], hin.shape[1])
    grad_S1s_torch = input_torch.bmm(g_torch_unsqueezed.bmm(U1s))

    # check if the outputs match
    outputs_match = torch.allclose(grad_S1s_torch, grad_S1s)
    print(f"backward grad_S1s and grad_S1s_torch Outputs match: {outputs_match}")

    if not outputs_match:
        max_diff = torch.max(torch.abs(grad_S1s_torch - grad_S1s))
        print(f"Max difference: {max_diff.item()}")
        print(f"grad_S1s_torch: {grad_S1s_torch}")
        print(f"grad_S1s: {grad_S1s}")

    # first_pass_U2s_hin
    ####################
    L, K, d2 = U2s.shape
    _, BSIZE = hin.shape

    U2s_hin = torch.empty((L, K, BSIZE), dtype=torch.float32, device=device)

    stride_hin_d2, stride_hin_BSIZE = hin.stride(0), hin.stride(1)
    stride_su_l, stride_su_K, stride_su_d2 = U2s.stride(0), U2s.stride(1), U2s.stride(2)
    stride_out_l, stride_out_K, stride_out_BSIZE = (
        U2s_hin.stride(0),
        U2s_hin.stride(1),
        U2s_hin.stride(2),
    )

    grid = lambda META: (
        L,
        triton.cdiv(K, META["BLOCK_SIZE_K"])
        * triton.cdiv(BSIZE, META["BLOCK_SIZE_BSIZE"]),
    )

    wrap_triton(first_pass_U2s_hin_kernel)[grid](
        hin,
        U2s,
        U2s_hin,
        K,
        d2,
        BSIZE,
        L,
        stride_hin_d2,
        stride_hin_BSIZE,
        stride_su_l,
        stride_su_K,
        stride_su_d2,
        stride_out_l,
        stride_out_K,
        stride_out_BSIZE,
    )

    U2s_hin_torch = U2s.bmm(input_torch)

    # check if the outputs match
    outputs_match = torch.allclose(U2s_hin_torch, U2s_hin)
    print(f"backward U2s_hin and U2s_hin_torch Outputs match: {outputs_match}")
    if not outputs_match:
        max_diff = torch.max(torch.abs(U2s_hin_torch - U2s_hin))
        print(f"Max difference: {max_diff.item()}")
        print(f"U2s_hin_torch: {U2s_hin_torch}")
        print(f"U2s_hin: {U2s_hin}")

    # calc_grad_S2s
    ###############
    L, K, BSIZE = U2s_hin.shape
    _, d1 = g.shape

    grad_S2s = torch.empty((L, K, d1), dtype=torch.float32, device=device)

    stride_g_BSIZE, stride_g_d1 = g.stride(0), g.stride(1)
    stride_su_l, stride_su_K, stride_su_BSIZE = (
        U2s_hin.stride(0),
        U2s_hin.stride(1),
        U2s_hin.stride(2),
    )
    stride_out_l, stride_out_K, stride_out_d1 = (
        grad_S2s.stride(0),
        grad_S2s.stride(1),
        grad_S2s.stride(2),
    )

    grid = lambda META: (
        L,
        triton.cdiv(K, META["BLOCK_SIZE_K"]) * triton.cdiv(d1, META["BLOCK_SIZE_d1"]),
    )

    wrap_triton(calc_grad_S2s_kernel)[grid](
        g,
        U2s_hin,
        grad_S2s,
        K,
        BSIZE,
        d1,
        L,
        stride_g_BSIZE,
        stride_g_d1,
        stride_su_l,
        stride_su_K,
        stride_su_BSIZE,
        stride_out_l,
        stride_out_K,
        stride_out_d1,
    )

    grad_S2s_torch = (U2s.bmm(input_torch)).bmm(g_torch_unsqueezed)

    outputs_match = torch.allclose(grad_S2s_torch, grad_S2s)
    print(f"backward grad_S2s and grad_S2s_torch Outputs match: {outputs_match}")

    if not outputs_match:
        max_diff = torch.max(torch.abs(grad_S2s_torch - grad_S2s))
        print(f"Max difference: {max_diff.item()}")
        print(f"grad_S2s_torch: {grad_S2s_torch}")
        print(f"grad_S2s: {grad_S2s}")

    return [grad, grad_S1s, grad_S2s, g.sum(0) / (2 * num_terms)]


@backward_op.register_kernel("cpu")
def _(input, S1s, S2s, U1s, U2s, grad_output):
    num_terms = S2s.shape[0]
    g = grad_output / (2 * num_terms)
    g = g.unsqueeze(0).expand(num_terms, g.shape[0], g.shape[1])
    input = (
        input.unsqueeze(0)
        .expand(num_terms, input.shape[0], input.shape[1])
        .transpose(1, 2)
    )
    U1s = U1s.transpose(1, 2)
    S1s = S1s.transpose(1, 2)
    U2s = U2s.transpose(1, 2)
    S2s = S2s.transpose(1, 2)
    t1 = g.bmm(U1s)
    grad = t1.bmm(S1s).sum(0) + g.bmm(S2s).bmm(U2s).sum(0)
    grad_S2s = (U2s.bmm(input)).bmm(g)
    grad_S1s = input.bmm(g.bmm(U1s))

    g = g[0]
    return [
        grad,
        grad_S1s,
        grad_S2s,
        # sum g on batch dimension input.shape[0]
        g.reshape(input.shape[2], -1).sum(0),
    ]


class SketchedLinearFunction_triton(Function):
    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(
        input: torch.Tensor,
        S1s: torch.Tensor,
        S2s: torch.Tensor,
        U1s: torch.Tensor,
        U2s: torch.Tensor,
        bias: torch.Tensor,
    ):
        return forward_op(input, S1s, S2s, U1s, U2s, bias)

    @staticmethod
    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any):
        input, S1s, S2s, U1s, U2s, bias = inputs
        ctx.save_for_backward(input, S1s, S2s, U1s, U2s, bias)

    @staticmethod
    def backward(ctx: Any, *grad_output: Any) -> Any:
        # dl/dS2_i = U1_i g h_in^T / 2 * l
        # dl/dS1_i = g h_in^T U2_i^T / 2 * l
        # dl/dh_in = 1/(2*l) * (sum_{i=1}^{l} (S1_i^T U1_i g) + sum_{i=1}^{l} (U2_i^T S2_i g))
        # dl/db = g
        hin, S1s, S2s, U1s, U2s, _ = ctx.saved_tensors
        grad_input, grad_S1s, grad_S2s, grad_bias = backward_op(
            hin, S1s, S2s, U1s, U2s, grad_output[0]
        )
        return grad_input, grad_S1s, grad_S2s, None, None, grad_bias


class SKLinear_triton(nn.Module):
    __constants__ = ["in_features", "out_features", "num_terms", "low_rank"]
    in_features: int
    out_features: int
    num_terms: int
    low_rank: int
    S1s: torch.Tensor
    S2s: torch.Tensor
    U1s: torch.Tensor
    U2s: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_terms: int,
        low_rank: int,
        W_init=None,
        bias: bool = True,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super(SKLinear_triton, self).__init__()

        # if (
        #     2 * num_terms * low_rank * (out_features + in_features)
        #     > out_features * in_features
        # ):
        #     raise ValueError(
        #         "The number of parameters in the sketching layer is larger "
        #         + "than the number of parameters in the fully connected layer."
        #     )

        self.num_terms = num_terms  # l
        self.low_rank = low_rank  # k
        self.out_features = out_features
        self.in_features = in_features

        # Register U1s and U2s as buffers since they are not learnable
        self.register_buffer(
            "U1s",
            torch.stack(
                [
                    gen_U(low_rank, out_features, **factory_kwargs)
                    for _ in range(num_terms)
                ]
            ),
        )  # k(low rank)xd1(out) stacked along the zeros dimension (l) -> l x k x d1
        self.register_buffer(
            "U2s",
            torch.stack(
                [
                    gen_U(in_features, low_rank, **factory_kwargs)
                    for _ in range(num_terms)
                ]
            ),
        )  # d2xk stacked along the zeros dimension (l) -> l x d2 x k

        # W is used to only initialize S
        if W_init is None:
            W = torch.empty(in_features, out_features, **factory_kwargs)  # d2 * d1
            init.kaiming_uniform_(W, a=math.sqrt(5))
        else:
            W = W_init.T.detach().clone()

        # S1s and S2s are precomputed but not updated in the backward pass
        self.S1s = nn.Parameter(
            torch.stack([torch.matmul(W, self.U1s[i].T) for i in range(num_terms)])
        )  # d2xk stacked along the zeros dimension (l) -> l x d2 x k
        self.S2s = nn.Parameter(
            torch.stack([torch.matmul(self.U2s[i].T, W) for i in range(num_terms)])
        )  # kxd1 stacked along the zeros dimension (l) -> l x k x d1

        # Bias term initialized with a small standard deviation
        if bias:
            self.bias = nn.Parameter(
                torch.empty(out_features, **factory_kwargs)
            )  # 1 * d1
            fan_in, _ = init._calculate_fan_in_and_fan_out(W)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)
        else:
            self.register_parameter("bias", None)

    def forward(self, h_in):
        # TODO: Make sure all the things are contiguos
        return SketchedLinearFunction_triton.apply(
            h_in, self.S1s, self.S2s, self.U1s, self.U2s, self.bias
        )

In [68]:
# Test function for SKLinear_triton with float32
torch.manual_seed(42)

# Parameters
batch_size = 32
in_features = 256
out_features = 256
num_terms = 3
low_rank = 8

# Create input tensor with float32 dtype
x = torch.randn(batch_size, in_features, dtype=torch.float32, device="cuda")

# Create model with float32 dtype
model = SKLinear_triton(
    in_features=in_features,
    out_features=out_features,
    num_terms=num_terms,
    low_rank=low_rank,
    dtype=torch.float32,
    device="cuda",
)

# Forward pass
output = model(x)

torch.cuda.synchronize()
print("backward----------------------\n\n\n\n\n\n")

# # Test backward pass
loss = output.sum()
loss.backward()

forward in1 and in1_torch Outputs match: True
forward in2 and in2_torch Outputs match: True
forward out and torch_out Outputs match: False
Max difference: 1.430511474609375e-06
torch_out: tensor([[ 0.0104,  1.0086,  0.6721,  ..., -0.3649,  0.1341,  0.5287],
        [-0.8548, -0.6688,  0.2248,  ...,  0.2155,  1.3069,  2.5197],
        [ 0.4949,  0.1686,  0.1352,  ..., -1.8021, -1.7875, -0.9252],
        ...,
        [ 0.4069, -0.9734, -0.0315,  ..., -0.0174,  1.1333,  0.9707],
        [-0.7300, -0.3609, -1.6043,  ..., -0.4308, -0.9627, -0.7490],
        [-0.6174, -0.3168,  0.8720,  ...,  0.3366,  0.1846,  0.9859]],
       device='cuda:0')
out: tensor([[ 0.0104,  1.0086,  0.6721,  ..., -0.3649,  0.1341,  0.5287],
        [-0.8548, -0.6688,  0.2248,  ...,  0.2155,  1.3069,  2.5197],
        [ 0.4949,  0.1686,  0.1352,  ..., -1.8021, -1.7875, -0.9252],
        ...,
        [ 0.4069, -0.9734, -0.0315,  ..., -0.0174,  1.1333,  0.9707],
        [-0.7300, -0.3609, -1.6043,  ..., -0.4308, -0.96