In [None]:
# Read the secret labelled "github_repos_wildcard" through the kaggle_secrets client.
# The value is kept in `secret_value_0`, which the clone cell below uses as the
# GitHub token; never print or log it.
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("github_repos_wildcard")

In [None]:
import os
import shutil

# The token fetched in the secrets cell is embedded in the clone URL, so never
# echo CLONE_URL. NOTE(review): git also stores this URL (token included) as the
# remote in panther/.git/config -- consider a credential helper instead.
GITHUB_TOKEN = secret_value_0
USER = "gaserSami"
CLONE_URL = f"https://{USER}:{GITHUB_TOKEN}@github.com/{USER}/panther.git"

# Clone only when the checkout is missing so the cell can be re-run safely:
# `git clone` refuses to write into an existing, non-empty directory.
if not os.path.isdir("panther"):
    get_ipython().system(f"git clone --branch torch_compile {CLONE_URL}")

# To import the package straight from the checkout, extend the module path:
# import sys
# sys.path.append("panther")

In [None]:
!pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu118

In [None]:
# Sanity check: show which torch and triton builds the kernel actually imports.
import torch
import triton

print(torch.__version__, triton.__version__, sep="\n")

In [None]:
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia

In [None]:
%%writefile /kaggle/working/panther/panther/nn/linear_kernels/forward.py
import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config(
            {'BLOCK_SIZE_BSIZE': blk_bsize, 'BLOCK_SIZE_K': blk_k, 'BLOCK_SIZE_D2': blk_d2, 'GROUP_SIZE_BSIZE': 8},
            num_stages=stages, num_warps=warps,
        )
        # Each row: (BLOCK_SIZE_BSIZE, BLOCK_SIZE_K, BLOCK_SIZE_D2, num_stages, num_warps)
        for blk_bsize, blk_k, blk_d2, stages, warps in (
            (32, 32, 64, 1, 4),
            (64, 256, 32, 4, 4),
            (128, 128, 32, 4, 4),
            (128, 64, 32, 4, 4),
            (64, 128, 32, 4, 4),
            (128, 32, 32, 4, 4),
            (64, 32, 32, 5, 2),
            (32, 64, 32, 5, 2),
            (128, 256, 128, 3, 8),
            (256, 128, 128, 3, 8),
            (256, 64, 128, 4, 4),
            (64, 256, 128, 4, 4),
            (128, 128, 128, 4, 4),
            (128, 64, 64, 4, 4),
            (64, 128, 64, 4, 4),
            (128, 32, 64, 4, 4),
        )
    ],
    key=['BSIZE', 'K', 'd2', 'L'],
)
@triton.jit
def first_pass_kernel(
        hin_ptr, S1s_ptr, U2s_ptr, out1_ptr, out2_ptr,
        BSIZE, K, d2, L,
        stride_hin_bsize, stride_hin_d2,
        stride_su_l, stride_su_d2, stride_su_k,
        stride_out_l, stride_out_bsize, stride_out_k,
        BLOCK_SIZE_BSIZE: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_D2: tl.constexpr,
        GROUP_SIZE_BSIZE: tl.constexpr
):
    """Compute one output tile of ``hin @ S1s[l]`` and ``hin @ U2s[l]``.

    Grid axis 0 selects the slice ``l``; grid axis 1 enumerates the
    (BLOCK_SIZE_BSIZE x BLOCK_SIZE_K) output tiles. Through the given strides,
    ``hin`` is addressed as (BSIZE, d2), ``S1s``/``U2s`` as (l, d2, K) and both
    outputs as (l, BSIZE, K). The reduction over ``d2`` proceeds in chunks of
    BLOCK_SIZE_D2 and is accumulated in float32 with IEEE input precision.
    ``L`` is not read in the body; it only takes part in the autotune key.
    """
    l_idx = tl.program_id(axis=0)
    tile_idx = tl.program_id(axis=1)

    # Grouped ordering: consecutive program ids walk up to GROUP_SIZE_BSIZE row
    # tiles before moving on to the next column tile.
    tiles_along_bsize = tl.cdiv(BSIZE, BLOCK_SIZE_BSIZE)
    tiles_along_k = tl.cdiv(K, BLOCK_SIZE_K)
    tiles_per_group = GROUP_SIZE_BSIZE * tiles_along_k
    first_tile_bsize = (tile_idx // tiles_per_group) * GROUP_SIZE_BSIZE
    rows_in_group = min(tiles_along_bsize - first_tile_bsize, GROUP_SIZE_BSIZE)
    idx_in_group = tile_idx % tiles_per_group
    tile_bsize = first_tile_bsize + idx_in_group % rows_in_group
    tile_k = idx_in_group // rows_in_group

    rows = tile_bsize * BLOCK_SIZE_BSIZE + tl.arange(0, BLOCK_SIZE_BSIZE)
    cols = tile_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    red = tl.arange(0, BLOCK_SIZE_D2)

    # Compiler hints: every offset vector starts on a block multiple and is contiguous.
    rows = tl.max_contiguous(tl.multiple_of(rows, BLOCK_SIZE_BSIZE), BLOCK_SIZE_BSIZE)
    cols = tl.max_contiguous(tl.multiple_of(cols, BLOCK_SIZE_K), BLOCK_SIZE_K)
    red = tl.max_contiguous(tl.multiple_of(red, BLOCK_SIZE_D2), BLOCK_SIZE_D2)

    # Bounds along the two output axes do not change inside the reduction loop.
    row_ok = rows[:, None] < BSIZE
    col_ok = cols[None, :] < K

    hin_ptrs = hin_ptr + (rows[:, None] * stride_hin_bsize + red[None, :] * stride_hin_d2)
    su_offsets = l_idx * stride_su_l + (red[:, None] * stride_su_d2 + cols[None, :] * stride_su_k)
    S1s_ptrs = S1s_ptr + su_offsets
    U2s_ptrs = U2s_ptr + su_offsets

    acc_s1 = tl.full(shape=(BLOCK_SIZE_BSIZE, BLOCK_SIZE_K), value=0.0, dtype=tl.float32)
    acc_u2 = tl.full(shape=(BLOCK_SIZE_BSIZE, BLOCK_SIZE_K), value=0.0, dtype=tl.float32)

    hin_step = BLOCK_SIZE_D2 * stride_hin_d2
    su_step = BLOCK_SIZE_D2 * stride_su_d2
    for step in range(0, tl.cdiv(d2, BLOCK_SIZE_D2)):
        # Number of reduction elements still to consume; masks off the ragged last chunk.
        d2_left = d2 - step * BLOCK_SIZE_D2
        hin = tl.load(hin_ptrs, mask=row_ok & (red[None, :] < d2_left), other=0.0)

        su_ok = (red[:, None] < d2_left) & col_ok
        S1s = tl.load(S1s_ptrs, mask=su_ok, other=0.0)
        U2s = tl.load(U2s_ptrs, mask=su_ok, other=0.0)

        acc_s1 += tl.dot(hin, S1s, input_precision="ieee")
        acc_u2 += tl.dot(hin, U2s, input_precision="ieee")

        hin_ptrs += hin_step
        S1s_ptrs += su_step
        U2s_ptrs += su_step

    out_offsets = l_idx * stride_out_l + stride_out_bsize * rows[:, None] + stride_out_k * cols[None, :]
    out_ok = row_ok & col_ok

    tl.store(out1_ptr + out_offsets, acc_s1, mask=out_ok)
    tl.store(out2_ptr + out_offsets, acc_u2, mask=out_ok)
  
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_BSIZE': 32, 'BLOCK_SIZE_D1': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=1, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_D1': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_D1': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_D1': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_D1': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_D1': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_D1': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_BSIZE': 32, 'BLOCK_SIZE_D1': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_D1': 128, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_BSIZE': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_D1': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_D1': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_D1': 128, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_D1': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_D1': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_D1': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_D1': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4)
    ],
    key=['BSIZE', 'd1', 'K', 'L'],
)
@triton.jit
def second_pass_kernel(
        in1_ptr, in2_ptr, U1s_ptr, S2s_ptr, bias_ptr, out_ptr,
        BSIZE, d1, K, L,
        stride_in12_l, stride_in12_bsize, stride_in12_k,
        stride_us_l, stride_us_k, stride_us_d1,
        stride_bias_bsize, stride_bias_d1,
        stride_out_bsize, stride_out_d1,
        BLOCK_SIZE_BSIZE: tl.constexpr, BLOCK_SIZE_D1: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_BSIZE: tl.constexpr
):
    """Compute one (BLOCK_SIZE_BSIZE x BLOCK_SIZE_D1) tile of the combined output.

    For the tile owned by this program the kernel evaluates

        out = (sum over l of in1[l] @ U1s[l] + in2[l] @ S2s[l]) / (2 * L) + bias

    Through the given strides, ``in1``/``in2`` are addressed as (l, BSIZE, K),
    ``U1s``/``S2s`` as (l, K, d1), ``bias`` along d1 only (``stride_bias_bsize``
    is not read) and ``out`` as (BSIZE, d1). Accumulation is float32 with IEEE
    input precision.
    """
    pid = tl.program_id(axis=0)

    num_pid_bsize = tl.cdiv(BSIZE, BLOCK_SIZE_BSIZE)
    num_pid_d1 = tl.cdiv(d1, BLOCK_SIZE_D1)
    num_pid_in_group = GROUP_SIZE_BSIZE * num_pid_d1
    group_id = pid // num_pid_in_group
    first_pid_bsize = group_id * GROUP_SIZE_BSIZE
    # The runtime height of the (possibly ragged) last group lives in its own
    # local instead of rebinding the GROUP_SIZE_BSIZE constexpr, matching the
    # other kernels in this package.
    group_size_bsize = min(num_pid_bsize - first_pid_bsize, GROUP_SIZE_BSIZE)
    pid_bsize = first_pid_bsize + ((pid % num_pid_in_group) % group_size_bsize)
    pid_d1 = (pid % num_pid_in_group) // group_size_bsize

    offs_bsize = pid_bsize * BLOCK_SIZE_BSIZE + tl.arange(0, BLOCK_SIZE_BSIZE)
    offs_d1 = pid_d1 *  BLOCK_SIZE_D1 + tl.arange(0, BLOCK_SIZE_D1)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    offs_bsize = tl.max_contiguous(tl.multiple_of(offs_bsize, BLOCK_SIZE_BSIZE), BLOCK_SIZE_BSIZE)
    offs_d1 = tl.max_contiguous(tl.multiple_of(offs_d1, BLOCK_SIZE_D1), BLOCK_SIZE_D1)
    offs_k = tl.max_contiguous(tl.multiple_of(offs_k, BLOCK_SIZE_K), BLOCK_SIZE_K)

    # Edge tiles extend past BSIZE / d1. These bounds are loop-invariant and are
    # folded into every load mask so out-of-range rows/columns are never read.
    bsize_mask = offs_bsize[:, None] < BSIZE
    d1_mask = offs_d1[None, :] < d1

    in_tmp = offs_bsize[:, None] * stride_in12_bsize + offs_k[None, :] * stride_in12_k
    us_tmp = offs_k[:, None] * stride_us_k + offs_d1[None, :] * stride_us_d1

    accumulator = tl.full(shape=(BLOCK_SIZE_BSIZE, BLOCK_SIZE_D1), value=0.0, dtype=tl.float32)

    for l in range(0, L):
        l_in_offset = l * stride_in12_l
        l_us_offset = l * stride_us_l

        # Restart the K walk at the beginning of slice l.
        in1_ptrs = in1_ptr + l_in_offset + in_tmp
        in2_ptrs = in2_ptr + l_in_offset + in_tmp

        U1s_ptrs = U1s_ptr + l_us_offset + us_tmp
        S2s_ptrs = S2s_ptr + l_us_offset + us_tmp

        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            # Elements of the reduction axis still to consume in this chunk.
            k_remaining = K - k * BLOCK_SIZE_K

            in_mask = bsize_mask & (offs_k[None, :] < k_remaining)
            in1 = tl.load(in1_ptrs, mask=in_mask, other=0.0)
            in2 = tl.load(in2_ptrs, mask=in_mask, other=0.0)

            us_mask = (offs_k[:, None] < k_remaining) & d1_mask
            U1s = tl.load(U1s_ptrs, mask=us_mask, other=0.0)
            S2s = tl.load(S2s_ptrs, mask=us_mask, other=0.0)

            accumulator += tl.dot(in1, U1s, input_precision="ieee")
            accumulator += tl.dot(in2, S2s, input_precision="ieee")

            in_inc = BLOCK_SIZE_K * stride_in12_k
            in1_ptrs += in_inc
            in2_ptrs += in_inc

            us_inc = BLOCK_SIZE_K * stride_us_k
            U1s_ptrs += us_inc
            S2s_ptrs += us_inc

    # Bias is indexed by the d1 column only and broadcast over the rows of the tile.
    bias_ptrs = bias_ptr + offs_d1[None, :] * stride_bias_d1
    bias = tl.load(bias_ptrs, mask=d1_mask, other=0.0)

    accumulator *= (1.0/ (2.0 * L))
    accumulator += bias

    out_ptrs = out_ptr + stride_out_bsize * offs_bsize[:, None] + stride_out_d1 * offs_d1[None, :]
    out_mask = bsize_mask & d1_mask

    tl.store(out_ptrs, accumulator, mask=out_mask)

In [None]:
%%writefile /kaggle/working/panther/panther/nn/linear_kernels/backward.py
import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_BSIZE': 32, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_d1': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=1, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_d1': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_d1': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_d1': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_BSIZE': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_d1': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_d1': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_d1': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4)
    ],
    key=['BSIZE', 'K', 'd1', 'L'],
)
@triton.jit
def first_pass_gU1s_g_S2s_kernel(
        g_ptr, U1s_ptr, S2s_ptr, g_U1s_ptr, g_S2s_ptr,
        BSIZE, K, d1, L,
        stride_g_bsize, stride_g_d1,
        stride_su_l, stride_su_d1, stride_su_k,
        stride_out_l, stride_out_bsize, stride_out_k,
        BLOCK_SIZE_BSIZE: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_d1: tl.constexpr,
        GROUP_SIZE_BSIZE: tl.constexpr
):
    """Compute one output tile of ``g @ U1s[l]`` and ``g @ S2s[l]``.

    Grid axis 0 selects the slice ``l``; grid axis 1 enumerates the
    (BLOCK_SIZE_BSIZE x BLOCK_SIZE_K) output tiles. Through the given strides,
    ``g`` is addressed as (BSIZE, d1), ``U1s``/``S2s`` as (l, d1, K) and both
    outputs as (l, BSIZE, K). The reduction over ``d1`` proceeds in chunks of
    BLOCK_SIZE_d1 and is accumulated in float32 with IEEE input precision.
    ``L`` is not read in the body; it only takes part in the autotune key.
    """
    pid = tl.program_id(axis=1)
    batch_id = tl.program_id(axis=0)

    num_pid_bsize = tl.cdiv(BSIZE, BLOCK_SIZE_BSIZE)
    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_BSIZE * num_pid_k
    group_id = pid // num_pid_in_group
    first_pid_bsize = group_id * GROUP_SIZE_BSIZE
    group_size_bsize = min(num_pid_bsize - first_pid_bsize, GROUP_SIZE_BSIZE)
    pid_bsize = first_pid_bsize + ((pid % num_pid_in_group) % group_size_bsize)
    pid_k = (pid % num_pid_in_group) // group_size_bsize

    offs_bsize = pid_bsize * BLOCK_SIZE_BSIZE + tl.arange(0, BLOCK_SIZE_BSIZE)
    offs_k = pid_k *  BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_d1 = tl.arange(0, BLOCK_SIZE_d1)

    # Edge tiles extend past BSIZE / K. These bounds are loop-invariant and are
    # folded into every load mask so out-of-range rows/columns are never read.
    bsize_mask = offs_bsize[:, None] < BSIZE
    k_mask = offs_k[None, :] < K

    g_ptrs = g_ptr + (offs_bsize[:, None] * stride_g_bsize + offs_d1[None, :] * stride_g_d1)

    su_tmp = batch_id * stride_su_l + (offs_d1[:, None] * stride_su_d1 + offs_k[None, :] * stride_su_k)
    U1s_ptrs = U1s_ptr + su_tmp
    S2s_ptrs = S2s_ptr + su_tmp

    accumulator1 = tl.full(shape=(BLOCK_SIZE_BSIZE, BLOCK_SIZE_K), value=0.0, dtype=tl.float32)
    accumulator2 = tl.full(shape=(BLOCK_SIZE_BSIZE, BLOCK_SIZE_K), value=0.0, dtype=tl.float32)

    for d1_i in range(0, tl.cdiv(d1, BLOCK_SIZE_d1)):
        # Elements of the reduction axis still to consume in this chunk.
        d1_remaining = d1 - d1_i * BLOCK_SIZE_d1

        g_mask = bsize_mask & (offs_d1[None, :] < d1_remaining)
        g = tl.load(g_ptrs, mask=g_mask, other=0.0)

        su_mask = (offs_d1[:, None] < d1_remaining) & k_mask
        U1s = tl.load(U1s_ptrs, mask=su_mask, other=0.0)
        S2s = tl.load(S2s_ptrs, mask=su_mask, other=0.0)

        accumulator1 += tl.dot(g, U1s, input_precision="ieee")
        accumulator2 += tl.dot(g, S2s, input_precision="ieee")

        g_ptrs += BLOCK_SIZE_d1 * stride_g_d1
        U1s_ptrs += BLOCK_SIZE_d1 * stride_su_d1
        S2s_ptrs += BLOCK_SIZE_d1 * stride_su_d1

    out_tmp = batch_id * stride_out_l + stride_out_bsize * offs_bsize[:, None] + stride_out_k * offs_k[None, :]
    g_U1s_ptrs = g_U1s_ptr + out_tmp
    g_S2s_ptrs = g_S2s_ptr + out_tmp

    out_mask = bsize_mask & k_mask

    tl.store(g_U1s_ptrs, accumulator1, mask=out_mask)
    tl.store(g_S2s_ptrs, accumulator2, mask=out_mask)

  
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_BSIZE': 32, 'BLOCK_SIZE_d2': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=1, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_d2': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_d2': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_BSIZE': 32, 'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_BSIZE': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_BSIZE': 256, 'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_d2': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 64, 'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_BSIZE': 128, 'BLOCK_SIZE_d2': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_BSIZE': 8}, num_stages=4, num_warps=4)
    ],
    key=['BSIZE', 'd2', 'K', 'L'],
)
@triton.jit
def second_pass_gUS11_22_kernel(
        g_U1s_ptr, g_S2s_ptr, S1s_ptr, U2s_ptr, out_ptr,
        BSIZE, d2, K, L,
        stride_g_U1s2_l, stride_g_U1s2_bsize, stride_g_U1s2_k,
        stride_us_l, stride_us_k, stride_us_d2,
        stride_out_bsize, stride_out_d2,
        BLOCK_SIZE_BSIZE: tl.constexpr, BLOCK_SIZE_d2: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_BSIZE: tl.constexpr
):
    """Compute one (BLOCK_SIZE_BSIZE x BLOCK_SIZE_d2) tile of the summed products.

    For the tile owned by this program the kernel evaluates

        out = sum over l of g_U1s[l] @ S1s[l] + g_S2s[l] @ U2s[l]

    Through the given strides, ``g_U1s``/``g_S2s`` are addressed as
    (l, BSIZE, K), ``S1s``/``U2s`` as (l, K, d2) and ``out`` as (BSIZE, d2).
    Accumulation is float32 with IEEE input precision.
    """
    pid = tl.program_id(axis=0)

    num_pid_bsize = tl.cdiv(BSIZE, BLOCK_SIZE_BSIZE)
    num_pid_d2 = tl.cdiv(d2, BLOCK_SIZE_d2)
    num_pid_in_group = GROUP_SIZE_BSIZE * num_pid_d2
    group_id = pid // num_pid_in_group
    first_pid_bsize = group_id * GROUP_SIZE_BSIZE
    # The runtime height of the (possibly ragged) last group lives in its own
    # local instead of rebinding the GROUP_SIZE_BSIZE constexpr, matching the
    # first-pass kernels in this file.
    group_size_bsize = min(num_pid_bsize - first_pid_bsize, GROUP_SIZE_BSIZE)
    pid_bsize = first_pid_bsize + ((pid % num_pid_in_group) % group_size_bsize)
    pid_d2 = (pid % num_pid_in_group) // group_size_bsize

    offs_bsize = pid_bsize * BLOCK_SIZE_BSIZE + tl.arange(0, BLOCK_SIZE_BSIZE)
    offs_d2 = pid_d2 *  BLOCK_SIZE_d2 + tl.arange(0, BLOCK_SIZE_d2)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Edge tiles extend past BSIZE / d2. These bounds are loop-invariant and are
    # folded into every load mask so out-of-range rows/columns are never read.
    bsize_mask = offs_bsize[:, None] < BSIZE
    d2_mask = offs_d2[None, :] < d2

    in_tmp = offs_bsize[:, None] * stride_g_U1s2_bsize + offs_k[None, :] * stride_g_U1s2_k
    us_tmp = offs_k[:, None] * stride_us_k + offs_d2[None, :] * stride_us_d2

    accumulator = tl.full(shape=(BLOCK_SIZE_BSIZE, BLOCK_SIZE_d2), value=0.0, dtype=tl.float32)

    for l in range(0, L):
        g_l_offset = l * stride_g_U1s2_l  # Offset for g_U1s and g_S2s
        s_l_offset = l * stride_us_l      # Offset for S1s and U2s

        # Restart the K walk at the beginning of slice l.
        g_U1s_ptrs = g_U1s_ptr + g_l_offset + in_tmp
        g_S2s_ptrs = g_S2s_ptr + g_l_offset + in_tmp

        S1s_ptrs = S1s_ptr + s_l_offset + us_tmp
        U2s_ptrs = U2s_ptr + s_l_offset + us_tmp

        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            # Elements of the reduction axis still to consume in this chunk.
            k_remaining = K - k * BLOCK_SIZE_K

            in_mask = bsize_mask & (offs_k[None, :] < k_remaining)
            g_U1s = tl.load(g_U1s_ptrs, mask=in_mask, other=0.0)
            g_S2s = tl.load(g_S2s_ptrs, mask=in_mask, other=0.0)

            us_mask = (offs_k[:, None] < k_remaining) & d2_mask
            S1s = tl.load(S1s_ptrs, mask=us_mask, other=0.0)
            U2s = tl.load(U2s_ptrs, mask=us_mask, other=0.0)

            accumulator += tl.dot(g_U1s, S1s, input_precision="ieee")
            accumulator += tl.dot(g_S2s, U2s, input_precision="ieee")

            in_inc = BLOCK_SIZE_K * stride_g_U1s2_k
            g_U1s_ptrs += in_inc
            g_S2s_ptrs += in_inc

            us_inc = BLOCK_SIZE_K * stride_us_k
            S1s_ptrs += us_inc
            U2s_ptrs += us_inc

    # NOTE(review): the 1/(2*L) scaling was left disabled in the original kernel
    # (`accumulator *= (1.0/ (2.0 * L))`); confirm the caller applies it if needed.

    out_ptrs = out_ptr + stride_out_bsize * offs_bsize[:, None] + stride_out_d2 * offs_d2[None, :]
    out_mask = bsize_mask & d2_mask

    tl.store(out_ptrs, accumulator, mask=out_mask)

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_d2': 32, 'BLOCK_SIZE_k': 32, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_d2': 8}, num_stages=1, num_warps=4),
        triton.Config({'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_k': 256, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_k': 128, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_k': 64, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_k': 128, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_k': 32, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_k': 32, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_d2': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_d2': 32, 'BLOCK_SIZE_k': 64, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_d2': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_k': 256, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_d2': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_d2': 256, 'BLOCK_SIZE_k': 128, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_d2': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_d2': 256, 'BLOCK_SIZE_k': 64, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_k': 256, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_k': 128, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_k': 64, 'BLOCK_SIZE_BSIZE': 64, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_d2': 64, 'BLOCK_SIZE_k': 128, 'BLOCK_SIZE_BSIZE': 64, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_d2': 128, 'BLOCK_SIZE_k': 32, 'BLOCK_SIZE_BSIZE': 64, 'GROUP_SIZE_d2': 8}, num_stages=4, num_warps=4)
    ],
    # Key entries must be argument names of this kernel. The column count is the
    # lowercase `k` here; an uppercase 'K' matches no argument and, depending on
    # the Triton version, is either rejected or silently dropped from the key.
    key=['d2', 'k', 'BSIZE', 'L'],
)
@triton.jit
def calc_grad_S1s_kernel(
        hin_ptr, g_U1s_ptr, grad_g_S1s_ptr,
        d2, k, BSIZE, L,
        stride_hin_bsize, stride_hin_BSIZE,
        stride_su_l, stride_su_BSIZE, stride_su_k,
        stride_out_l, stride_out_bsize, stride_out_k,
        BLOCK_SIZE_d2: tl.constexpr, BLOCK_SIZE_k: tl.constexpr, BLOCK_SIZE_BSIZE: tl.constexpr,
        GROUP_SIZE_d2: tl.constexpr
):
    """Compute one (BLOCK_SIZE_d2 x BLOCK_SIZE_k) tile of ``hin @ g_U1s[l]``.

    Grid axis 0 selects the slice ``l``; grid axis 1 enumerates output tiles.
    Through the given strides, ``hin`` is addressed as (d2, BSIZE), ``g_U1s`` as
    (l, BSIZE, k) and the output as (l, d2, k). The reduction runs over the
    BSIZE axis in chunks of BLOCK_SIZE_BSIZE, accumulated in float32 with IEEE
    input precision.

    Naming caveat: ``stride_hin_bsize`` and ``stride_out_bsize`` are applied to
    the d2-indexed axis, while ``stride_hin_BSIZE`` / ``stride_su_BSIZE`` step
    along the reduced BSIZE axis. The parameter names are kept unchanged for
    callers. ``L`` is not read in the body; it only takes part in the autotune key.
    """
    pid = tl.program_id(axis=1)
    batch_id = tl.program_id(axis=0)

    num_pid_d2 = tl.cdiv(d2, BLOCK_SIZE_d2)
    num_pid_k = tl.cdiv(k, BLOCK_SIZE_k)
    num_pid_in_group = GROUP_SIZE_d2 * num_pid_k
    group_id = pid // num_pid_in_group
    first_pid_d2 = group_id * GROUP_SIZE_d2
    group_size_d2 = min(num_pid_d2 - first_pid_d2, GROUP_SIZE_d2)
    pid_d2 = first_pid_d2 + ((pid % num_pid_in_group) % group_size_d2)
    pid_k = (pid % num_pid_in_group) // group_size_d2

    offs_d2 = pid_d2 * BLOCK_SIZE_d2 + tl.arange(0, BLOCK_SIZE_d2)
    offs_k = pid_k *  BLOCK_SIZE_k + tl.arange(0, BLOCK_SIZE_k)
    offs_BSIZE = tl.arange(0, BLOCK_SIZE_BSIZE)

    offs_d2 = tl.max_contiguous(tl.multiple_of(offs_d2, BLOCK_SIZE_d2), BLOCK_SIZE_d2)
    offs_k = tl.max_contiguous(tl.multiple_of(offs_k, BLOCK_SIZE_k), BLOCK_SIZE_k)
    offs_BSIZE = tl.max_contiguous(tl.multiple_of(offs_BSIZE, BLOCK_SIZE_BSIZE), BLOCK_SIZE_BSIZE)

    hin_ptrs = hin_ptr + (offs_d2[:, None] * stride_hin_bsize + offs_BSIZE[None, :] * stride_hin_BSIZE)

    su_tmp = batch_id * stride_su_l + (offs_BSIZE[:, None] * stride_su_BSIZE + offs_k[None, :] * stride_su_k)
    g_U1s_ptrs = g_U1s_ptr + su_tmp

    accumulator = tl.full(shape=(BLOCK_SIZE_d2, BLOCK_SIZE_k), value=0.0, dtype=tl.float32)

    for BSIZE_i in range(0, tl.cdiv(BSIZE, BLOCK_SIZE_BSIZE)):
        # Elements of the reduction axis still to consume in this chunk.
        bsize_remaining = BSIZE - BSIZE_i * BLOCK_SIZE_BSIZE

        hin_mask = (offs_d2[:, None] < d2) & (offs_BSIZE[None, :] < bsize_remaining)
        hin = tl.load(hin_ptrs, mask=hin_mask, other=0.0)

        su_mask = (offs_BSIZE[:, None] < bsize_remaining) & (offs_k[None, :] < k)
        g_U1s = tl.load(g_U1s_ptrs, mask=su_mask, other=0.0)

        accumulator += tl.dot(hin, g_U1s, input_precision="ieee")

        hin_ptrs += BLOCK_SIZE_BSIZE * stride_hin_BSIZE
        g_U1s_ptrs += BLOCK_SIZE_BSIZE * stride_su_BSIZE

    out_tmp = batch_id * stride_out_l + stride_out_bsize * offs_d2[:, None] + stride_out_k * offs_k[None, :]
    grad_g_S1s_ptrs = grad_g_S1s_ptr + out_tmp

    out_mask = (offs_d2[:, None] < d2) & (offs_k[None, :] < k)

    tl.store(grad_g_S1s_ptrs, accumulator, mask=out_mask)

@triton.autotune(
    configs=[
        triton.Config(
            {'BLOCK_SIZE_K': blk_k, 'BLOCK_SIZE_BSIZE': blk_bsize, 'BLOCK_SIZE_d2': blk_d2, 'GROUP_SIZE_K': 8},
            num_stages=stages, num_warps=warps,
        )
        # Each row: (BLOCK_SIZE_K, BLOCK_SIZE_BSIZE, BLOCK_SIZE_d2, num_stages, num_warps)
        for blk_k, blk_bsize, blk_d2, stages, warps in (
            (32, 32, 32, 1, 4),
            (64, 256, 32, 4, 4),
            (128, 128, 32, 4, 4),
            (128, 64, 32, 4, 4),
            (64, 128, 32, 4, 4),
            (128, 32, 32, 4, 4),
            (64, 32, 32, 5, 2),
            (32, 64, 32, 5, 2),
            (128, 256, 128, 3, 8),
            (256, 128, 128, 3, 8),
            (256, 64, 128, 4, 4),
            (64, 256, 128, 4, 4),
            (128, 128, 128, 4, 4),
            (128, 64, 64, 4, 4),
            (64, 128, 64, 4, 4),
            (128, 32, 64, 4, 4),
        )
    ],
    key=['K', 'd2', 'BSIZE', 'L'],
)
@triton.jit
def first_pass_U2s_hin_kernel(
        hin_ptr, U2s_ptr, U2s_h_in_ptr,
        K, d2, BSIZE, L,
        stride_hin_d2, stride_hin_BSIZE,
        stride_su_l, stride_su_K, stride_su_d2,
        stride_out_l, stride_out_K, stride_out_BSIZE,
        BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_BSIZE: tl.constexpr, BLOCK_SIZE_d2: tl.constexpr,
        GROUP_SIZE_K: tl.constexpr
):
    """Compute one (BLOCK_SIZE_K x BLOCK_SIZE_BSIZE) tile of ``U2s[l] @ hin``.

    Grid axis 0 selects the slice ``l``; grid axis 1 enumerates output tiles.
    Through the given strides, ``U2s`` is addressed as (l, K, d2), ``hin`` as
    (d2, BSIZE) and the output as (l, K, BSIZE). The reduction over ``d2``
    proceeds in chunks of BLOCK_SIZE_d2 and is accumulated in float32 with IEEE
    input precision. ``L`` is not read in the body; it only takes part in the
    autotune key.
    """
    l_idx = tl.program_id(axis=0)
    tile_idx = tl.program_id(axis=1)

    # Grouped ordering: consecutive program ids walk up to GROUP_SIZE_K row
    # tiles (K axis) before moving on to the next column tile (BSIZE axis).
    tiles_along_k = tl.cdiv(K, BLOCK_SIZE_K)
    tiles_along_bsize = tl.cdiv(BSIZE, BLOCK_SIZE_BSIZE)
    tiles_per_group = GROUP_SIZE_K * tiles_along_bsize
    first_tile_k = (tile_idx // tiles_per_group) * GROUP_SIZE_K
    rows_in_group = min(tiles_along_k - first_tile_k, GROUP_SIZE_K)
    idx_in_group = tile_idx % tiles_per_group
    tile_k = first_tile_k + idx_in_group % rows_in_group
    tile_bsize = idx_in_group // rows_in_group

    rows = tile_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    cols = tile_bsize * BLOCK_SIZE_BSIZE + tl.arange(0, BLOCK_SIZE_BSIZE)
    red = tl.arange(0, BLOCK_SIZE_d2)

    # Compiler hints: every offset vector starts on a block multiple and is contiguous.
    rows = tl.max_contiguous(tl.multiple_of(rows, BLOCK_SIZE_K), BLOCK_SIZE_K)
    cols = tl.max_contiguous(tl.multiple_of(cols, BLOCK_SIZE_BSIZE), BLOCK_SIZE_BSIZE)
    red = tl.max_contiguous(tl.multiple_of(red, BLOCK_SIZE_d2), BLOCK_SIZE_d2)

    # Bounds along the two output axes do not change inside the reduction loop.
    row_ok = rows[:, None] < K
    col_ok = cols[None, :] < BSIZE

    hin_ptrs = hin_ptr + (red[:, None] * stride_hin_d2 + cols[None, :] * stride_hin_BSIZE)
    su_offsets = l_idx * stride_su_l + (rows[:, None] * stride_su_K + red[None, :] * stride_su_d2)
    U2s_ptrs = U2s_ptr + su_offsets

    acc = tl.full(shape=(BLOCK_SIZE_K, BLOCK_SIZE_BSIZE), value=0.0, dtype=tl.float32)

    hin_step = BLOCK_SIZE_d2 * stride_hin_d2
    su_step = BLOCK_SIZE_d2 * stride_su_d2
    for step in range(0, tl.cdiv(d2, BLOCK_SIZE_d2)):
        # Number of reduction elements still to consume; masks off the ragged last chunk.
        d2_left = d2 - step * BLOCK_SIZE_d2
        hin = tl.load(hin_ptrs, mask=(red[:, None] < d2_left) & col_ok, other=0.0)
        U2s = tl.load(U2s_ptrs, mask=row_ok & (red[None, :] < d2_left), other=0.0)

        acc += tl.dot(U2s, hin, input_precision="ieee")

        hin_ptrs += hin_step
        U2s_ptrs += su_step

    out_offsets = l_idx * stride_out_l + stride_out_K * rows[:, None] + stride_out_BSIZE * cols[None, :]

    tl.store(U2s_h_in_ptr + out_offsets, acc, mask=row_ok & col_ok)

# Candidate tile shapes / pipeline depths tried by the autotuner; the best one
# is cached per distinct (K, BSIZE, d1, L) combination (see `key` below).
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_d1': 32, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_K': 8}, num_stages=1, num_warps=4),
        triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 256, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 128, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 64, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 128, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 32, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 32, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_K': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_d1': 64, 'BLOCK_SIZE_BSIZE': 32, 'GROUP_SIZE_K': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 256, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_K': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_d1': 128, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_K': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_d1': 64, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 256, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 128, 'BLOCK_SIZE_BSIZE': 128, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 64, 'BLOCK_SIZE_BSIZE': 64, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_d1': 128, 'BLOCK_SIZE_BSIZE': 64, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_d1': 32, 'BLOCK_SIZE_BSIZE': 64, 'GROUP_SIZE_K': 8}, num_stages=4, num_warps=4)
    ],
    key=['K', 'BSIZE', 'd1', 'L'],
)
@triton.jit
def calc_grad_S2s_kernel(
        g_ptr, U2s_hin_ptr, grad_S2s_ptr,
        K, BSIZE, d1, L,
        stride_g_BSIZE, stride_g_d1,
        stride_su_l, stride_su_K, stride_su_BSIZE,
        stride_out_l, stride_out_K, stride_out_d1,
        BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_d1: tl.constexpr, BLOCK_SIZE_BSIZE: tl.constexpr,
        GROUP_SIZE_K: tl.constexpr
):
    # Batched tiled matmul: for every term l (grid axis 0) compute
    #     grad_S2s[l] = U2s_hin[l] @ g
    # where the host launcher passes U2s_hin as (L, K, BSIZE), g as (BSIZE, d1)
    # and grad_S2s as (L, K, d1). Each program writes one
    # (BLOCK_SIZE_K x BLOCK_SIZE_d1) output tile, reducing over BSIZE in chunks
    # of BLOCK_SIZE_BSIZE. `g` is shared by all terms (it has no `l` stride).
    #
    # Grid layout expected from the launcher:
    #   axis 0 -> term index l
    #   axis 1 -> flattened (K tile, d1 tile) index
    pid = tl.program_id(axis=1)
    batch_id = tl.program_id(axis=0)
    
    # Grouped program ordering: consecutive pids walk down up to GROUP_SIZE_K
    # K-tiles before moving to the next d1-tile (the usual Triton matmul
    # scheme for improving cache reuse).
    num_pid_K = tl.cdiv(K, BLOCK_SIZE_K)
    num_pid_d1 = tl.cdiv(d1, BLOCK_SIZE_d1)
    num_pid_in_group = GROUP_SIZE_K * num_pid_d1
    group_id = pid // num_pid_in_group
    first_pid_K = group_id * GROUP_SIZE_K
    # Despite its name this is the number of K-tiles in the current group
    # (the last group may hold fewer than GROUP_SIZE_K).
    group_size_d1 = min(num_pid_K - first_pid_K, GROUP_SIZE_K)
    pid_K = first_pid_K + ((pid % num_pid_in_group) % group_size_d1)
    pid_d1 = (pid % num_pid_in_group) // group_size_d1

    # Element offsets covered by this program's tile. offs_BSIZE starts at 0
    # because the pointers (not the offsets) are advanced along the reduction.
    offs_K = pid_K * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_d1 = pid_d1 *  BLOCK_SIZE_d1 + tl.arange(0, BLOCK_SIZE_d1)
    offs_BSIZE = tl.arange(0, BLOCK_SIZE_BSIZE)

    # Compiler hints only: each offset vector is a contiguous run that starts
    # at a multiple of its block size.
    offs_K = tl.max_contiguous(tl.multiple_of(offs_K, BLOCK_SIZE_K), BLOCK_SIZE_K)
    offs_d1 = tl.max_contiguous(tl.multiple_of(offs_d1, BLOCK_SIZE_d1), BLOCK_SIZE_d1)
    offs_BSIZE = tl.max_contiguous(tl.multiple_of(offs_BSIZE, BLOCK_SIZE_BSIZE), BLOCK_SIZE_BSIZE)
    
    # (BLOCK_SIZE_BSIZE x BLOCK_SIZE_d1) tile of g.
    g_ptrs = g_ptr + (offs_BSIZE[:, None] * stride_g_BSIZE + offs_d1[None, :] * stride_g_d1)

    # (BLOCK_SIZE_K x BLOCK_SIZE_BSIZE) tile of U2s_hin for term `batch_id`.
    su_tmp = batch_id * stride_su_l + (offs_K[:, None] * stride_su_K + offs_BSIZE[None, :] * stride_su_BSIZE)
    U2s_hin_ptrs = U2s_hin_ptr + su_tmp

    # float32 accumulator for the output tile.
    accumulator1 = tl.full(shape=(BLOCK_SIZE_K, BLOCK_SIZE_d1), value=0.0, dtype=tl.float32)
    
    for BSIZE_i in range(0, tl.cdiv(BSIZE, BLOCK_SIZE_BSIZE)):
        # Mask the tail of the reduction dimension and out-of-range columns;
        # masked lanes load 0.0 so they do not contribute to the dot product.
        g_mask = (offs_BSIZE[:, None] < BSIZE - BSIZE_i * BLOCK_SIZE_BSIZE) & (offs_d1[None, :] < d1)
        g = tl.load(g_ptrs, mask=g_mask, other=0.0)
        
        su_mask = (offs_K[:, None] < K) & (offs_BSIZE[None, :] < BSIZE - BSIZE_i * BLOCK_SIZE_BSIZE)
        U2s_hin = tl.load(U2s_hin_ptrs, mask=su_mask, other=0.0)
        
        # "ieee" input precision requests full-precision fp32 multiplication
        # (i.e. no TF32 rounding of the operands).
        accumulator1 += tl.dot(U2s_hin, g, input_precision="ieee")
        
        # Advance both operand tiles to the next chunk along BSIZE.
        g_ptrs += BLOCK_SIZE_BSIZE * stride_g_BSIZE
        U2s_hin_ptrs += BLOCK_SIZE_BSIZE * stride_su_BSIZE

    # Destination tile inside grad_S2s[batch_id].
    out_tmp = batch_id * stride_out_l + stride_out_K * offs_K[:, None] + stride_out_d1 * offs_d1[None, :]
    grad_S2s_ptrs = grad_S2s_ptr + out_tmp
    
    # Only write elements that fall inside the (K, d1) output.
    out_mask = (offs_K[:, None] < K) & (offs_d1[None, :] < d1)
    
    tl.store(grad_S2s_ptrs, accumulator1, mask=out_mask)

In [None]:
%%writefile /kaggle/working/panther/panther/nn/linear_tr.py
import math
from typing import Any, Tuple, List, Optional

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.nn import init
import triton
from panther.random import scaled_sign_sketch as gen_U
from torch.library import triton_op, wrap_triton
from .linear_kernels import (
    first_pass_kernel,
    second_pass_kernel,
    first_pass_gU1s_g_S2s_kernel,
    second_pass_gUS11_22_kernel,
    calc_grad_S1s_kernel,
    first_pass_U2s_hin_kernel,
    calc_grad_S2s_kernel
)

@triton_op("panther::forward_op", mutates_args={})
def forward_op(hin: torch.Tensor, S1s: torch.Tensor, S2s: torch.Tensor, U1s: torch.Tensor, U2s: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Sketched-linear forward pass built from two Triton kernel launches.

    The CPU kernel registered for this op computes
    ``mean_l(hin @ S1s[l] @ U1s[l]) / 2 + mean_l(hin @ U2s[l] @ S2s[l]) / 2 + bias``;
    the Triton path is presumably the fused equivalent (the kernels themselves
    live in ``linear_kernels``).

    Args:
        hin: Input activations of shape ``(BSIZE, d2)``.
        S1s: Tensor of shape ``(L, d2, K)``.
        S2s: Passed straight to the second-pass kernel.
            NOTE(review): expected to be ``(L, K, d1)`` as in the CPU kernel.
        U1s: Tensor of shape ``(L, K, d1)``.
        U2s: Passed straight to the first-pass kernel.
            NOTE(review): expected to be ``(L, d2, K)`` as in the CPU kernel.
        bias: 1-D bias; it is viewed as ``(1, d1)`` for the kernel.

    Returns:
        A freshly allocated float32 tensor of shape ``(BSIZE, d1)``.
    """
    # Allocate scratch/output buffers next to the input instead of on the
    # default CUDA device, so the op also works when the tensors live on a
    # non-default GPU (e.g. cuda:1).
    device = hin.device

    # first pass
    ############
    BSIZE, d2 = hin.shape
    L, _, K = S1s.shape

    # Per-term intermediates written by the first kernel and consumed by the
    # second one.
    in1 = torch.empty((L, BSIZE, K), dtype=torch.float32, device=device)
    in2 = torch.empty((L, BSIZE, K), dtype=torch.float32, device=device)

    stride_hin_bsize, stride_hin_d2 = hin.stride(0), hin.stride(1)
    stride_su_l, stride_su_d2, stride_su_k = S1s.stride(0), S1s.stride(1), S1s.stride(2)
    stride_out_l, stride_out_bsize, stride_out_k = in1.stride(0), in1.stride(1), in1.stride(2)

    # Axis 0: one program row per term; axis 1: flattened (BSIZE tile, K tile).
    grid = lambda META: (L, triton.cdiv(BSIZE, META["BLOCK_SIZE_BSIZE"]) * triton.cdiv(K, META["BLOCK_SIZE_K"]), )

    # Only S1s' strides are passed, so U2s is assumed to share S1s' layout and
    # in2 to share in1's layout.
    wrap_triton(first_pass_kernel)[grid](
        hin, S1s, U2s, in1, in2,
        BSIZE, K, d2, L,
        stride_hin_bsize, stride_hin_d2,
        stride_su_l, stride_su_d2, stride_su_k,
        stride_out_l, stride_out_bsize, stride_out_k,
    )

    # in1 / in2 come from torch.empty and are therefore already contiguous;
    # no extra .contiguous() round-trip is needed before the second pass.

    # second pass
    #############
    bias_unsqueezed = bias.unsqueeze(0)
    _, _, d1 = U1s.shape

    out = torch.empty((BSIZE, d1), dtype=torch.float32, device=device)

    stride_in12_l, stride_in12_bsize, stride_in12_k = in1.stride(0), in1.stride(1), in1.stride(2)
    stride_us_l, stride_us_k, stride_us_d1 = U1s.stride(0), U1s.stride(1), U1s.stride(2)
    stride_bias_bsize, stride_bias_d1 = bias_unsqueezed.stride(0), bias_unsqueezed.stride(1)
    stride_out_bsize, stride_out_d1 = out.stride(0), out.stride(1)

    # 1-D grid: flattened (BSIZE tile, d1 tile); the reduction over the L terms
    # is presumably done inside the kernel.
    grid = lambda META: (triton.cdiv(BSIZE, META["BLOCK_SIZE_BSIZE"]) * triton.cdiv(d1, META["BLOCK_SIZE_D1"]), )

    # Only U1s' strides are passed, so S2s is assumed to share U1s' layout.
    wrap_triton(second_pass_kernel)[grid](
        in1, in2, U1s, S2s, bias_unsqueezed, out,
        BSIZE, d1, K, L,
        stride_in12_l, stride_in12_bsize, stride_in12_k,
        stride_us_l, stride_us_k, stride_us_d1,
        stride_bias_bsize, stride_bias_d1,
        stride_out_bsize, stride_out_d1,
    )

    # `out` is a fresh contiguous allocation, so it can be returned as is.
    return out

@forward_op.register_kernel("cpu")
def _(input, S1s, S2s, U1s, U2s, bias):
    """CPU implementation of ``panther::forward_op`` using batched matmuls.

    The 2-D input is broadcast (without copying) across the term dimension,
    pushed through both sketched factor chains, averaged over the terms, and
    the two halves are combined with the bias.
    """
    term_count = S2s.shape[0]
    # View the (BSIZE, d2) input as (num_terms, BSIZE, d2) so one bmm handles
    # every term at once.
    stacked = input.unsqueeze(0).expand(term_count, input.shape[0], input.shape[1])

    left_branch = stacked.bmm(S1s).bmm(U1s).mean(0) / 2
    right_branch = stacked.bmm(U2s).bmm(S2s).mean(0) / 2
    return left_branch + right_branch + bias
    
@triton_op("panther::backward_op", mutates_args={})
def backward_op(hin: torch.Tensor, S1s: torch.Tensor, S2s: torch.Tensor, U1s: torch.Tensor, U2s: torch.Tensor, g: torch.Tensor) -> List[torch.Tensor]:
    """Backward pass of the sketched linear layer, built from five Triton launches.

    Args:
        hin: Forward input, shape ``(BSIZE, d2)``.
        S1s, S2s, U1s, U2s: The stacked factors given to the forward op. Each is
            transposed on its last two dims before use; after that transpose the
            code reads ``U1s`` as ``(L, d1, K)``, ``S1s`` as ``(L, K, d2)`` and
            ``U2s`` as ``(L, K, d2)``.
        g: Gradient of the loss w.r.t. the forward output, shape ``(BSIZE, d1)``.

    Returns:
        ``[grad_input (BSIZE, d2), grad_S1s (L, d2, K), grad_S2s (L, K, d1),
        grad_bias (d1,)]``.
    """
    # Allocate every buffer on the device of the incoming gradient rather than
    # on the default CUDA device, so non-default GPUs work too.
    device = g.device
    num_terms = S2s.shape[0]

    # The forward adds `bias` unscaled (out = ... + bias), so its gradient is
    # the plain batch-sum of the incoming gradient. It must be taken *before*
    # `g` is multiplied by 1 / (2 * num_terms) below; that factor only belongs
    # to the two averaged sketch branches.
    grad_bias = g.sum(0)

    # Transpose and make contiguous
    hin = hin.transpose(0, 1).contiguous()
    U1s = U1s.transpose(1, 2).contiguous()
    S1s = S1s.transpose(1, 2).contiguous()
    U2s = U2s.transpose(1, 2).contiguous()
    S2s = S2s.transpose(1, 2).contiguous()

    # Fold the forward's mean-over-terms and the /2 into the gradient once.
    g = g.contiguous() / (2 * num_terms)

    # first_pass_gU1s_g_S2s
    #######################
    BSIZE, d1 = g.shape
    L, _, K = U1s.shape

    # Per-term intermediates filled by the kernel (presumably g @ U1s^T and
    # g @ S2s^T, matching `t1` / `g.bmm(S2s)` in the CPU kernel).
    g_U1s = torch.empty((L, BSIZE, K), dtype=torch.float32, device=device)
    g_S2s = torch.empty((L, BSIZE, K), dtype=torch.float32, device=device)

    stride_g_bsize, stride_g_d1 = g.stride(0), g.stride(1)
    stride_su_l, stride_su_d1, stride_su_k = U1s.stride(0), U1s.stride(1), U1s.stride(2)
    stride_out_l, stride_out_bsize, stride_out_k = g_U1s.stride(0), g_U1s.stride(1), g_U1s.stride(2)

    grid = lambda META: (L, triton.cdiv(BSIZE, META["BLOCK_SIZE_BSIZE"]) * triton.cdiv(K, META["BLOCK_SIZE_K"]), )

    # Only U1s' / g_U1s' strides are passed; S2s and g_S2s are assumed to
    # share those layouts (both pairs are contiguous with identical shapes).
    wrap_triton(first_pass_gU1s_g_S2s_kernel)[grid](
        g, U1s, S2s, g_U1s, g_S2s,
        BSIZE, K, d1, L,
        stride_g_bsize, stride_g_d1,
        stride_su_l, stride_su_d1, stride_su_k,
        stride_out_l, stride_out_bsize, stride_out_k
    )

    # second_pass_gUS11_22
    #######################
    _, _, d2 = S1s.shape

    # Gradient w.r.t. the layer input.
    grad = torch.empty((BSIZE, d2), dtype=torch.float32, device=device)

    stride_g_U1s2_l, stride_g_U1s2_bsize, stride_g_U1s2_k = g_U1s.stride(0), g_U1s.stride(1), g_U1s.stride(2)
    stride_us_l, stride_us_k, stride_us_d2 = S1s.stride(0), S1s.stride(1), S1s.stride(2)
    stride_out_bsize, stride_out_d2 = grad.stride(0), grad.stride(1)

    grid = lambda META: (triton.cdiv(BSIZE, META["BLOCK_SIZE_BSIZE"]) * triton.cdiv(d2, META["BLOCK_SIZE_d2"]), )

    wrap_triton(second_pass_gUS11_22_kernel)[grid](
        g_U1s, g_S2s, S1s, U2s, grad,
        BSIZE, d2, K, L,
        stride_g_U1s2_l, stride_g_U1s2_bsize, stride_g_U1s2_k,
        stride_us_l, stride_us_k, stride_us_d2,
        stride_out_bsize, stride_out_d2,
    )

    # calc_grad_S1s
    ################
    # `hin` was transposed above, so it is (d2, BSIZE) here.
    d2, BSIZE = hin.shape
    L, _, k = g_U1s.shape

    grad_S1s = torch.empty((L, d2, k), dtype=torch.float32, device=device)

    stride_hin_d2, stride_hin_BSIZE = hin.stride(0), hin.stride(1)
    stride_su_l, stride_su_BSIZE, stride_su_k = g_U1s.stride(0), g_U1s.stride(1), g_U1s.stride(2)
    stride_out_l, stride_out_d2, stride_out_k = grad_S1s.stride(0), grad_S1s.stride(1), grad_S1s.stride(2)

    grid = lambda META: (L, triton.cdiv(d2, META["BLOCK_SIZE_d2"]) * triton.cdiv(k, META["BLOCK_SIZE_k"]), )

    wrap_triton(calc_grad_S1s_kernel)[grid](
        hin, g_U1s, grad_S1s,
        d2, k, BSIZE, L,
        stride_hin_d2, stride_hin_BSIZE,
        stride_su_l, stride_su_BSIZE, stride_su_k,
        stride_out_l, stride_out_d2, stride_out_k
    )

    # first_pass_U2s_hin
    ####################
    L, K, d2 = U2s.shape
    _, BSIZE = hin.shape

    # Per-term product of the transposed U2s with the transposed input.
    U2s_hin = torch.empty((L, K, BSIZE), dtype=torch.float32, device=device)

    stride_hin_d2, stride_hin_BSIZE = hin.stride(0), hin.stride(1)
    stride_su_l, stride_su_K, stride_su_d2 = U2s.stride(0), U2s.stride(1), U2s.stride(2)
    stride_out_l, stride_out_K, stride_out_BSIZE = U2s_hin.stride(0), U2s_hin.stride(1), U2s_hin.stride(2)

    grid = lambda META: (L, triton.cdiv(K, META["BLOCK_SIZE_K"]) * triton.cdiv(BSIZE, META["BLOCK_SIZE_BSIZE"]), )

    wrap_triton(first_pass_U2s_hin_kernel)[grid](
        hin, U2s, U2s_hin,
        K, d2, BSIZE, L,
        stride_hin_d2, stride_hin_BSIZE,
        stride_su_l, stride_su_K, stride_su_d2,
        stride_out_l, stride_out_K, stride_out_BSIZE
    )

    # calc_grad_S2s
    ###############
    _, d1 = g.shape

    # grad_S2s[l] = U2s_hin[l] @ g  (g already carries the 1/(2L) factor).
    grad_S2s = torch.empty((L, K, d1), dtype=torch.float32, device=device)

    stride_g_BSIZE, stride_g_d1 = g.stride(0), g.stride(1)
    stride_su_l, stride_su_K, stride_su_BSIZE = U2s_hin.stride(0), U2s_hin.stride(1), U2s_hin.stride(2)
    stride_out_l, stride_out_K, stride_out_d1 = grad_S2s.stride(0), grad_S2s.stride(1), grad_S2s.stride(2)

    grid = lambda META: (L, triton.cdiv(K, META["BLOCK_SIZE_K"]) * triton.cdiv(d1, META["BLOCK_SIZE_d1"]), )

    wrap_triton(calc_grad_S2s_kernel)[grid](
        g, U2s_hin, grad_S2s,
        K, BSIZE, d1, L,
        stride_g_BSIZE, stride_g_d1,
        stride_su_l, stride_su_K, stride_su_BSIZE,
        stride_out_l, stride_out_K, stride_out_d1
    )

    # Every output above is a fresh torch.empty allocation and therefore
    # already contiguous.
    return [
        grad,
        grad_S1s,
        grad_S2s,
        grad_bias
    ]
    
@backward_op.register_kernel("cpu")
def _(input, S1s, S2s, U1s, U2s, grad_output):
    """CPU implementation of ``panther::backward_op`` using batched matmuls.

    Args:
        input: Forward input, shape ``(BSIZE, d2)``.
        S1s, S2s, U1s, U2s: Stacked factors as given to the forward op
            (first dimension is the number of terms).
        grad_output: Gradient w.r.t. the forward output, shape ``(BSIZE, d1)``.

    Returns:
        ``[grad_input, grad_S1s, grad_S2s, grad_bias]``.
    """
    num_terms = S2s.shape[0]

    # The forward adds `bias` unscaled, so its gradient is the plain batch-sum
    # of the incoming gradient; the 1 / (2 * num_terms) factor below applies
    # only to the two averaged sketch branches.
    grad_bias = grad_output.sum(0)

    g = grad_output / (2 * num_terms)
    # Broadcast (no copy) the scaled gradient across the term dimension.
    g = g.unsqueeze(0).expand(num_terms, g.shape[0], g.shape[1])
    # (num_terms, d2, BSIZE): transposed input, broadcast across terms.
    input = (
        input.unsqueeze(0)
        .expand(num_terms, input.shape[0], input.shape[1])
        .transpose(1, 2)
    )
    U1s = U1s.transpose(1, 2)
    S1s = S1s.transpose(1, 2)
    U2s = U2s.transpose(1, 2)
    S2s = S2s.transpose(1, 2)

    # g @ U1s^T is needed both for the input gradient and for grad_S1s.
    t1 = g.bmm(U1s)
    grad = t1.bmm(S1s).sum(0) + g.bmm(S2s).bmm(U2s).sum(0)
    grad_S2s = (U2s.bmm(input)).bmm(g)
    grad_S1s = input.bmm(t1)

    return [
        grad,
        grad_S1s,
        grad_S2s,
        grad_bias
    ]
    
class SketchedLinearFunction_triton(Function):
    """Autograd wrapper tying ``forward_op`` and ``backward_op`` together.

    Uses the ``forward`` / ``setup_context`` / ``backward`` static-method
    layout, so ``forward`` itself receives no ``ctx`` argument.
    """

    @staticmethod
    def forward(
        input: torch.Tensor,
        S1s: torch.Tensor,
        S2s: torch.Tensor,
        U1s: torch.Tensor,
        U2s: torch.Tensor,
        bias: torch.Tensor,
    ):
        # All the work happens inside the registered custom op.
        result = forward_op(input, S1s, S2s, U1s, U2s, bias)
        return result

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any):
        # `inputs` holds exactly the six arguments given to forward();
        # `output` (the forward result) is not needed for the backward pass.
        layer_input, s1, s2, u1, u2, b = inputs
        ctx.save_for_backward(layer_input, s1, s2, u1, u2, b)

    @staticmethod
    def backward(ctx: Any, *grad_output: Any) -> Any:
        # The saved bias is not needed to compute any gradient.
        layer_input, s1, s2, u1, u2, _unused_bias = ctx.saved_tensors
        upstream = grad_output[0]
        d_input, d_s1, d_s2, d_bias = backward_op(layer_input, s1, s2, u1, u2, upstream)
        # One entry per forward() argument, in order; U1s and U2s receive
        # no gradient (None).
        return d_input, d_s1, d_s2, None, None, d_bias
        


class SKLinear_triton(nn.Module):
    """Sketched linear layer whose forward/backward run through Triton ops.

    The dense weight is replaced by ``num_terms`` pairs of low-rank factors:
    fixed sketch matrices ``U1s`` / ``U2s`` (buffers) and trainable factors
    ``S1s`` / ``S2s`` (parameters) that are initialised from a dense ``W``.

    Args:
        in_features: Input feature size (d2).
        out_features: Output feature size (d1).
        num_terms: Number of sketch terms (l).
        low_rank: Rank of each sketch (k).
        W_init: Optional dense weight of shape ``(out_features, in_features)``
            used to initialise ``S1s`` / ``S2s``; a Kaiming-uniform matrix is
            drawn when omitted.
        bias: Whether the layer has a learnable bias.
        dtype, device: Forwarded to the tensor factories.
    """

    __constants__ = ["in_features", "out_features", "num_terms", "low_rank"]
    in_features: int
    out_features: int
    num_terms: int
    low_rank: int
    S1s: torch.Tensor
    S2s: torch.Tensor
    U1s: torch.Tensor
    U2s: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_terms: int,
        low_rank: int,
        W_init=None,
        bias: bool = True,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()

        # if (
        #     2 * num_terms * low_rank * (out_features + in_features)
        #     > out_features * in_features
        # ):
        #     raise ValueError(
        #         "The number of parameters in the sketching layer is larger "
        #         + "than the number of parameters in the fully connected layer."
        #     )

        self.num_terms = num_terms # l
        self.low_rank = low_rank # k
        self.out_features = out_features
        self.in_features = in_features

        # Register U1s and U2s as buffers since they are not learnable
        self.register_buffer(
            "U1s",
            torch.stack(
                [
                    gen_U(low_rank, out_features, **factory_kwargs)
                    for _ in range(num_terms)
                ]
            ),
        )  # k(low rank)xd1(out) stacked along the zeros dimension (l) -> l x k x d1
        self.register_buffer(
            "U2s",
            torch.stack(
                [
                    gen_U(in_features, low_rank, **factory_kwargs)
                    for _ in range(num_terms)
                ]
            ),
        )  # d2xk stacked along the zeros dimension (l) -> l x d2 x k

        # W is used to only initialize S
        if W_init is None:
            W = torch.empty(in_features, out_features, **factory_kwargs) # d2 * d1
            init.kaiming_uniform_(W, a=math.sqrt(5))
        else:
            W = W_init.T.detach().clone()

        # S1s and S2s are initialised from W and the fixed sketches. They are
        # trainable parameters: SketchedLinearFunction_triton.backward returns
        # gradients for both of them.
        self.S1s = nn.Parameter(
            torch.stack([torch.matmul(W, self.U1s[i].T) for i in range(num_terms)])
        )  # d2xk stacked along the zeros dimension (l) -> l x d2 x k
        self.S2s = nn.Parameter(
            torch.stack([torch.matmul(self.U2s[i].T, W) for i in range(num_terms)])
        )  # kxd1 stacked along the zeros dimension (l) -> l x k x d1

        # Bias drawn uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)]
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) #1 * d1
            # NOTE(review): W is (in_features, out_features) here, so for this
            # 2-D tensor fan_in is read from dim 1 (= out_features) - confirm
            # that this is the intended bound.
            fan_in, _ = init._calculate_fan_in_and_fan_out(W)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)
        else:
            self.register_parameter("bias", None)

    def forward(self, h_in):
        """Apply the sketched linear map to ``h_in`` of shape ``(batch, in_features)``."""
        bias = self.bias
        if bias is None:
            # The op requires a bias tensor. For a bias-free layer pass zeros;
            # the tensor does not require grad, so the bias gradient returned
            # by the backward pass is simply discarded by autograd.
            bias = self.S2s.new_zeros(self.out_features)

        # The Triton kernels are driven by explicit strides; every operand is
        # made contiguous before being handed to the op.
        return SketchedLinearFunction_triton.apply(
            h_in.contiguous(),
            self.S1s.contiguous(),
            self.S2s.contiguous(),
            self.U1s.contiguous(),
            self.U2s.contiguous(),
            bias.contiguous(),
        )

# checking 

In [None]:
# import torch
# import torch.nn.functional as F
# from panther.nn.linear import SKLinear
# from panther.nn.linear_tr import SKLinear_triton

# def compare_linear_implementations(batch_size=64, in_features=128, out_features=64, num_terms=4, low_rank=8, raise_error=True):
#     """
#     Compare the outputs of the forward pass and gradients between regular SKLinear and Triton-accelerated SKLinear_triton.
    
#     Args:
#         batch_size: Batch size for the input tensor
#         in_features: Number of input features
#         out_features: Number of output features
#         num_terms: Number of terms (l) in the sketched linear layer
#         low_rank: The low-rank dimension (k) in the sketched linear layer
#         raise_error: If True, raise AssertionError when comparisons fail
    
#     Returns:
#         bool: True if all comparisons pass, False otherwise
    
#     Raises:
#         AssertionError: If any comparison fails and raise_error is True
#     """
#     # Set seed for reproducibility
#     torch.manual_seed(42)
    
#     # Create identical layers
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
#     # Generate the same initialization for both layers
#     W_init = torch.randn(out_features, in_features, device=device)
    
#     # Create both layer implementations with identical parameters
#     sk_linear = SKLinear(in_features, out_features, num_terms, low_rank, W_init=W_init, device=device)
#     sk_linear_triton = SKLinear_triton(in_features, out_features, num_terms, low_rank, W_init=W_init, device=device)
    
#     # Ensure parameters match exactly by copying from sk_linear to sk_linear_triton
#     with torch.no_grad():
#         sk_linear_triton.S1s.copy_(sk_linear.S1s)
#         sk_linear_triton.S2s.copy_(sk_linear.S2s)
#         sk_linear_triton.U1s.copy_(sk_linear.U1s)
#         sk_linear_triton.U2s.copy_(sk_linear.U2s)
#         sk_linear_triton.bias.copy_(sk_linear.bias)
    
#     # Create identical input
#     input_data = torch.randn(batch_size, in_features, device=device, requires_grad=True)
#     input_data_triton = input_data.clone().detach().requires_grad_(True)
    
#     # Forward pass
#     output = sk_linear(input_data)
#     output_triton = sk_linear_triton(input_data_triton)
    
#     # Check if forward outputs match
#     forward_match = torch.allclose(output, output_triton, rtol=1e-4, atol=1e-5)
#     print(f"Forward outputs match: {forward_match}")
#     if not forward_match:
#         max_diff = torch.max(torch.abs(output - output_triton))
#         error_msg = f"Forward outputs don't match. Max difference: {max_diff.item()}"
#         print(error_msg)
#         if raise_error:
#             assert forward_match, error_msg
    
#     # Create identical gradients
#     grad_output = torch.randn_like(output)
#     grad_output_triton = grad_output.clone()
    
#     # Backward pass
#     output.backward(grad_output)
#     output_triton.backward(grad_output_triton)
    
#     # Check if input gradients match
#     input_grad_match = torch.allclose(input_data.grad, input_data_triton.grad, rtol=1e-4, atol=1e-5)
#     print(f"Input gradients match: {input_grad_match}")
#     if not input_grad_match:
#         max_diff = torch.max(torch.abs(input_data.grad - input_data_triton.grad))
#         error_msg = f"Input gradients don't match. Max difference: {max_diff.item()}"
#         print(error_msg)
#         if raise_error:
#             assert input_grad_match, error_msg
    
#     # Check if S1s gradients match
#     S1s_grad_match = torch.allclose(sk_linear.S1s.grad, sk_linear_triton.S1s.grad, rtol=1e-4, atol=1e-5)
#     print(f"S1s gradients match: {S1s_grad_match}")
#     if not S1s_grad_match:
#         max_diff = torch.max(torch.abs(sk_linear.S1s.grad - sk_linear_triton.S1s.grad))
#         error_msg = f"S1s gradients don't match. Max difference: {max_diff.item()}"
#         print(error_msg)
#         if raise_error:
#             assert S1s_grad_match, error_msg
    
#     # Check if S2s gradients match
#     S2s_grad_match = torch.allclose(sk_linear.S2s.grad, sk_linear_triton.S2s.grad, rtol=1e-4, atol=1e-5)
#     print(f"S2s gradients match: {S2s_grad_match}")
#     if not S2s_grad_match:
#         max_diff = torch.max(torch.abs(sk_linear.S2s.grad - sk_linear_triton.S2s.grad))
#         error_msg = f"S2s gradients don't match. Max difference: {max_diff.item()}"
#         print(error_msg)
#         if raise_error:
#             assert S2s_grad_match, error_msg
    
#     # U1s and U2s are buffers, not parameters, so they may not have gradients
#     # Only check if both gradients exist
#     if hasattr(sk_linear.U1s, 'grad') and hasattr(sk_linear_triton.U1s, 'grad') and sk_linear.U1s.grad is not None and sk_linear_triton.U1s.grad is not None:
#         U1s_grad_match = torch.allclose(sk_linear.U1s.grad, sk_linear_triton.U1s.grad, rtol=1e-4, atol=1e-5)
#         print(f"U1s gradients match: {U1s_grad_match}")
#         if not U1s_grad_match:
#             max_diff = torch.max(torch.abs(sk_linear.U1s.grad - sk_linear_triton.U1s.grad))
#             error_msg = f"U1s gradients don't match. Max difference: {max_diff.item()}"
#             print(error_msg)
#             if raise_error:
#                 assert U1s_grad_match, error_msg
#     else:
#         print("U1s gradients not available - skipping comparison")
#         U1s_grad_match = True  # Skip this comparison
    
#     # Check if U2s gradients match - only if both exist
#     if hasattr(sk_linear.U2s, 'grad') and hasattr(sk_linear_triton.U2s, 'grad') and sk_linear.U2s.grad is not None and sk_linear_triton.U2s.grad is not None:
#         U2s_grad_match = torch.allclose(sk_linear.U2s.grad, sk_linear_triton.U2s.grad, rtol=1e-4, atol=1e-5)
#         print(f"U2s gradients match: {U2s_grad_match}")
#         if not U2s_grad_match:
#             max_diff = torch.max(torch.abs(sk_linear.U2s.grad - sk_linear_triton.U2s.grad))
#             error_msg = f"U2s gradients don't match. Max difference: {max_diff.item()}"
#             print(error_msg)
#             if raise_error:
#                 assert U2s_grad_match, error_msg
#     else:
#         print("U2s gradients not available - skipping comparison")
#         U2s_grad_match = True  # Skip this comparison
    
#     # Check if bias gradients match
#     bias_grad_match = torch.allclose(sk_linear.bias.grad, sk_linear_triton.bias.grad, rtol=1e-4, atol=1e-5)
#     print(f"Bias gradients match: {bias_grad_match}")
#     if not bias_grad_match:
#         max_diff = torch.max(torch.abs(sk_linear.bias.grad - sk_linear_triton.bias.grad))
#         error_msg = f"Bias gradients don't match. Max difference: {max_diff.item()}"
#         print(error_msg)
#         if raise_error:
#             assert bias_grad_match, error_msg
    
#     # Return overall result
#     all_matches = all([forward_match, input_grad_match, S1s_grad_match, S2s_grad_match, U1s_grad_match, U2s_grad_match, bias_grad_match])
#     if raise_error:
#         assert all_matches, "At least one comparison failed. See detailed errors above."
#     return all_matches


# if __name__ == "__main__":
#     print("Comparing SKLinear and SKLinear_triton implementations...")
    
#     # Test with default parameters
#     try:
#         result = compare_linear_implementations()
#         print(f"All checks passed: {result}")
#     except AssertionError as e:
#         print(f"Default test failed: {e}")
    
#     # Test with different parameters
#     try:
#         result_small = compare_linear_implementations(batch_size=32, in_features=64, out_features=32, num_terms=2, low_rank=4)
#         print(f"Small model checks passed: {result_small}")
#     except AssertionError as e:
#         print(f"Small model test failed: {e}")

In [1]:
%%writefile /kaggle/working/panther/tests/run.py
import time
import numpy as np
import torch
import torch._dynamo
import torch._inductor.config as config
import itertools
import pandas as pd
from panther.nn import SKLinear

# Configure torch
# Turn off Inductor's max-autotune search for GEMMs.
config.max_autotune_gemm = False
# Raise Dynamo's compile-cache limits - presumably so the many distinct
# layer/shape combinations benchmarked below do not hit the recompilation
# limit part-way through the sweep.
torch._dynamo.config.cache_size_limit = 2**16
torch._dynamo.config.accumulated_cache_size_limit = 2**16

def is_valid_params(in_features, out_features, num_terms, low_rank):
    """
    Tell whether a sketched layer would be strictly smaller than a dense one.

    The sketched layer is counted as
    2 * num_terms * low_rank * (out_features + in_features) parameters and the
    dense layer as out_features * in_features. Returns True only when the
    sketched count is strictly below the dense count.
    """
    sketched_param_count = 2 * num_terms * low_rank * (out_features + in_features)
    dense_param_count = out_features * in_features
    return sketched_param_count < dense_param_count

class BenchmarkParams:
    """Plain container for one benchmark configuration.

    Holds the layer geometry (in_features, out_features, num_terms, low_rank),
    the batch size, how many warmup / timed iterations to run, and the
    device / dtype used to build the model and input.
    """

    def __init__(self,
                 in_features=256,
                 out_features=256,
                 num_terms=3,
                 low_rank=8,
                 batch_size=64,
                 num_runs=200,
                 warmup=15,
                 device='cuda',
                 dtype=torch.float32):
        # Layer geometry.
        self.in_features, self.out_features = in_features, out_features
        self.num_terms, self.low_rank = num_terms, low_rank
        # Workload size and iteration counts.
        self.batch_size = batch_size
        self.num_runs, self.warmup = num_runs, warmup
        # Tensor placement / precision.
        self.device, self.dtype = device, dtype

def benchmark_model(model, x, model_name, params, time_forward=False):
    """
    Generic benchmarking function for any PyTorch model.

    The model is compiled with torch.compile (Inductor, fullgraph, static
    shapes) and the backward pass is timed. Forward timing is opt-in.

    Args:
        model: The PyTorch model to benchmark
        x: Input tensor (must require grad so that backward() can run)
        model_name: Name of the model for logging
        params: Benchmark parameters; ``warmup`` and ``num_runs`` are read
        time_forward: When True, also time the forward pass in eval mode under
            ``torch.no_grad()``. Defaults to False, in which case the forward
            statistics are reported as 0 with an empty ``times`` list.

    Returns:
        Dictionary with benchmark results: ``{"forward": {...}, "backward": {...}}``,
        each holding ``mean``, ``std`` (milliseconds) and the raw ``times`` list.
    """
    def _sync():
        # CUDA work is asynchronous, so wait for it before reading the clock.
        # Skipped for non-CUDA inputs so the benchmark also runs with
        # params.device == 'cpu' instead of crashing in torch.cuda.
        if x.is_cuda:
            torch.cuda.synchronize()

    def _elapsed_ms(fn):
        # Wall-clock duration of fn() in milliseconds, fenced by syncs.
        _sync()
        start = time.perf_counter()
        fn()
        _sync()
        return (time.perf_counter() - start) * 1000

    # Compile the model
    model_compiled = torch.compile(
        model,
        backend="inductor",
        fullgraph=True,
        dynamic=False
    )

    # Benchmark forward pass
    print(f"\n=== {model_name} FORWARD PASS BENCHMARK ===")

    forward_times = []
    if time_forward:
        model_compiled.eval()
        with torch.no_grad():
            # Warmup runs for forward pass (includes compilation)
            for _ in range(params.warmup):
                model_compiled(x)
            _sync()

            # Actual timed runs for forward
            for _ in range(params.num_runs):
                forward_times.append(_elapsed_ms(lambda: model_compiled(x)))

        mean_forward = np.mean(forward_times)
        std_forward = np.std(forward_times)
    else:
        # Forward timing disabled: report zeros so the result schema is stable.
        mean_forward = 0
        std_forward = 0
    print(f"{model_name} forward: {mean_forward:.3f} ± {std_forward:.3f} ms")

    def infer():
        # The forward pass runs outside the timed region; only
        # loss.backward() is measured.
        # NOTE: gradients are not cleared between calls, so they accumulate
        # on `x` and on the model parameters across iterations.
        out = model_compiled(x)
        loss = out.sum()
        return _elapsed_ms(loss.backward)

    # Warmup runs for backward pass
    model_compiled.train()
    for _ in range(params.warmup):
        infer()

    _sync()

    # Actual timed runs for backward
    backward_times = [infer() for _ in range(params.num_runs)]

    mean_backward = np.mean(backward_times)
    std_backward = np.std(backward_times)
    print(f"{model_name} backward: {mean_backward:.3f} ± {std_backward:.3f} ms")

    return {
        "forward": {
            "mean": mean_forward,
            "std": std_forward,
            "times": forward_times
        },
        "backward": {
            "mean": mean_backward,
            "std": std_backward,
            "times": backward_times
        }
    }

def benchmark_model_factory(model_factory, model_name, params):
    """
    Build a model through a factory and benchmark it on a fresh random input.

    Args:
        model_factory: Callable that receives ``params`` and returns the model
        model_name: Label forwarded to ``benchmark_model`` for logging
        params: Benchmark parameters; ``batch_size``, ``in_features``,
            ``dtype`` and ``device`` are read here to shape the input

    Returns:
        The value returned by ``benchmark_model`` for the built model
    """
    # Fix the RNG seed before the factory runs and before the input is drawn.
    torch.manual_seed(42)
    built_model = model_factory(params)

    # Random (batch_size, in_features) input that tracks gradients.
    input_shape = (params.batch_size, params.in_features)
    sample_input = torch.randn(
        input_shape,
        dtype=params.dtype,
        device=params.device,
        requires_grad=True,
    )

    return benchmark_model(built_model, sample_input, model_name, params)

if __name__ == "__main__":
    import torch.nn as nn
    from panther.nn import SKLinear, SKLinear_triton
    
    # Parameter combinations to test
    ratios = [(1, 128), (128, 1), (1, 1), (2, 1), (1, 2)]
    base_sizes = [256, 512, 1024, 8192, 16384]
    num_terms_options = [1, 2, 3]
    # low_rank_options = [16, 32, 64, 128]
    low_rank_options = [32, 64, 128]
    
    # Define model factories
    def create_sklinear_triton(p):
        """Build an SKLinear_triton layer from the fields of benchmark params ``p``."""
        return SKLinear_triton(p.in_features, p.out_features, 
                             p.num_terms, p.low_rank, 
                             dtype=p.dtype, device=p.device)
    
    models_to_benchmark = [
        (create_sklinear_triton, "SKLinear_triton")
    ]

    def _make_record(model_name, combo, results=None, is_valid=True, error=None):
        """Build one row of the results table.

        Args:
            model_name: Name of the benchmarked model.
            combo: Dict describing the parameter combination (in_features,
                out_features, ratio, base_size, num_terms, low_rank).
            results: Dict returned by ``benchmark_model_factory``; when None
                the timing columns are filled with NaN.
            is_valid: Whether the parameter combination passed validation.
            error: Optional error text; the 'error' key is only added when
                this is given, so rows without an error do not carry it.

        Returns:
            Dict with the row's columns, in the order used by the CSV.
        """
        nan = float('nan')
        if results is not None:
            fwd_mean = results['forward']['mean']
            fwd_std = results['forward']['std']
            bwd_mean = results['backward']['mean']
            bwd_std = results['backward']['std']
        else:
            fwd_mean = fwd_std = bwd_mean = bwd_std = nan

        record = {
            'model': model_name,
            **combo,
            'forward_mean_ms': fwd_mean,
            'forward_std_ms': fwd_std,
            'backward_mean_ms': bwd_mean,
            'backward_std_ms': bwd_std,
            'is_valid': is_valid,
        }
        if error is not None:
            record['error'] = error
        return record
    
    # Prepare data structure to store all results
    results_data = []
    results_file = "benchmark_results.csv"
    # Separate file for interrupted sweeps, so an aborted run never replaces
    # the results of an earlier complete run.
    partial_results_file = "benchmark_results.partial.csv"
    sweep_completed = False
    
    # Iterate through all parameter combinations
    total_combinations = len(ratios) * len(base_sizes) * len(num_terms_options) * len(low_rank_options)
    current_combo = 0
    
    try:
        for ratio, base_size in itertools.product(ratios, base_sizes):
            ratio_in, ratio_out = ratio
            
            # Calculate actual dimensions based on ratio and base size
            if ratio_in == 1:
                in_features = base_size
                out_features = base_size * ratio_out
            else:
                out_features = base_size
                in_features = base_size * ratio_in
            
            for num_terms, low_rank in itertools.product(num_terms_options, low_rank_options):
                current_combo += 1
                print(f"\n\n{'='*20} COMBINATION {current_combo}/{total_combinations} {'='*20}")
                print(f"In features: {in_features}, Out features: {out_features}, Ratio: {ratio_in}:{ratio_out}")
                print(f"Base size: {base_size}, Num terms: {num_terms}, Low rank: {low_rank}")

                # Columns shared by every row produced for this combination
                combo = {
                    'in_features': in_features,
                    'out_features': out_features,
                    'ratio': f"{ratio_in}:{ratio_out}",
                    'base_size': base_size,
                    'num_terms': num_terms,
                    'low_rank': low_rank,
                }
                
                # Check if parameters are valid
                is_valid = is_valid_params(in_features, out_features, num_terms, low_rank)
                
                if not is_valid:
                    print(f"INVALID COMBINATION: 2 * {num_terms} * {low_rank} * ({out_features} + {in_features}) >= {out_features} * {in_features}")
                    print("Skipping benchmarks for this invalid combination")
                    
                    # Add invalid entry to results data
                    for _, model_name in models_to_benchmark:
                        results_data.append(_make_record(
                            model_name, combo,
                            is_valid=False,
                            error="Invalid parameter combination",
                        ))
                    continue
                
                # Create parameter object for this combination
                params = BenchmarkParams(
                    in_features=in_features,
                    out_features=out_features,
                    num_terms=num_terms,
                    low_rank=low_rank
                )
                
                all_results = {}
                for model_factory, model_name in models_to_benchmark:
                    print(f"\n{'='*20} Benchmarking {model_name} {'='*20}")
                    try:
                        results = benchmark_model_factory(model_factory, model_name, params)
                        all_results[model_name] = results
                        
                        # Add result to our data collection
                        results_data.append(_make_record(model_name, combo, results=results))
                    except Exception as e:
                        print(f"Error benchmarking {model_name}: {e}")
                        # Add error entry to data (timings are NaN, combination itself was valid)
                        results_data.append(_make_record(model_name, combo, error=str(e)))
                
                # Print comparative summary for this combination
                if all_results:
                    print("\n" + "="*60)
                    print(f"{'='*20} SUMMARY FOR CURRENT COMBINATION {'='*20}")
                    print("="*60)
                    print(f"{'Model':<20} {'Forward (ms)':<25} {'Backward (ms)':<25}")
                    print("-"*60)
                    
                    for model_name, results in all_results.items():
                        fwd = f"{results['forward']['mean']:.3f} ± {results['forward']['std']:.3f}"
                        bwd = f"{results['backward']['mean']:.3f} ± {results['backward']['std']:.3f}"
                        print(f"{model_name:<20} {fwd:<25} {bwd:<25}")

        sweep_completed = True
    finally:
        # Runs on normal completion as well as on KeyboardInterrupt / fatal
        # errors, so rows collected before an abort are not lost.
        # Create a DataFrame with all results
        df = pd.DataFrame(results_data)
        
        if sweep_completed:
            # Save results to CSV
            df.to_csv(results_file, index=False)
            print(f"\nAll benchmark results saved to {results_file}")
        elif results_data:
            df.to_csv(partial_results_file, index=False)
            print(f"\nSweep interrupted after {current_combo}/{total_combinations} combinations; "
                  f"{len(results_data)} partial result rows saved to {partial_results_file}")

Overwriting /kaggle/working/panther/tests/run.py


In [None]:
import os
# Make /kaggle/working/panther/ (the directory the earlier %%writefile cells
# write into) the kernel's working directory; presumably so that relative
# paths in later cells resolve from there.
os.chdir("/kaggle/working/panther/")

In [None]:
!pwd

In [None]:
import os

# CUDA debugging switches, set on the kernel's environment.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 forces synchronous kernel launches;
# processes started from this kernel presumably inherit it, so timings from
# a benchmark run launched afterwards may be skewed — confirm this is intended.
# NOTE(review): PyTorch's error messages describe TORCH_USE_CUDA_DSA as a
# compile-time option; confirm that setting it as an env var has any effect.
os.environ.update({
    'CUDA_LAUNCH_BLOCKING': "1",
    'TORCH_USE_CUDA_DSA': "1",
})

In [2]:
!PYTHONPATH=/kaggle/working/panther python /kaggle/working/panther/tests/run.py



In features: 256, Out features: 32768, Ratio: 1:128
Base size: 256, Num terms: 1, Low rank: 32


=== SKLinear_triton FORWARD PASS BENCHMARK ===
SKLinear_triton forward: 0.000 ± 0.000 ms
SKLinear_triton backward: 2.745 ± 0.766 ms

Model                Forward (ms)              Backward (ms)            
------------------------------------------------------------
SKLinear_triton      0.000 ± 0.000             2.745 ± 0.766            


In features: 256, Out features: 32768, Ratio: 1:128
Base size: 256, Num terms: 1, Low rank: 64


=== SKLinear_triton FORWARD PASS BENCHMARK ===
SKLinear_triton forward: 0.000 ± 0.000 ms
SKLinear_triton backward: 2.369 ± 0.026 ms

Model                Forward (ms)              Backward (ms)            
------------------------------------------------------------
SKLinear_triton      0.000 ± 0.000             2.369 ± 0.026            


In features: 256, Out features: 32768, Ratio: 1:128
Base size: 256, Num terms: 1, Low rank: 128
INVALID COMBINATION: 2 *

In [None]:
# !CUDA_LAUNCH_BLOCKING=1 PYTHONPATH=/kaggle/working/panther python /kaggle/working/panther/tests/run.py