In [6]:
#!pip install --upgrade cupy-cuda112==8.5.0

In [7]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cupy as cp
import math
from time import time

In [8]:
#@title CustomKernel
import cupy as cp
import torch

@cp.util.memoize(for_each_device=True)
def cunnex(func_name, func_body):
  """Compile the CUDA source `func_body` via cp.cuda.compile_with_cache and
  return its kernel function named `func_name` (memoized per device)."""
  compiled_module = cp.cuda.compile_with_cache(func_body)
  return compiled_module.get_function(func_name)

class Stream:
  """Tiny holder for a raw CUDA stream handle.

  Only exposes `ptr`; it is passed as `stream=` to CuPy kernel launches
  (assumes those launches only read `.ptr` - confirm against CuPy version).
  """

  def __init__(self, ptr):
    # raw stream handle (an integer) taken from torch.cuda stream objects
    self.ptr = ptr
  
class CustomKernel:
  """Base class for CuPy raw-kernel wrappers used from PyTorch code.

  On construction it routes CuPy device allocations through PyTorch
  (see `_torch_alloc`) and records the handle of PyTorch's current CUDA
  stream in `self.stream` for subclasses to launch on.
  """

  def __init__(self):
    self._use_torch_in_cupy_malloc()
    torch_stream = torch.cuda.current_stream()
    self.stream = Stream(torch_stream.cuda_stream)

  @staticmethod
  def _torch_alloc(size):
    """CuPy allocator hook: serve a `size`-byte request from a torch tensor.

    The backing uint8 tensor is registered as the owner of the
    UnownedMemory, so it stays referenced while CuPy uses the pointer.
    """
    device_id = cp.cuda.Device().id
    backing = torch.empty(size, dtype=torch.uint8, device=device_id)
    memory = cp.cuda.UnownedMemory(backing.data_ptr(), size, backing)
    return cp.cuda.MemoryPointer(memory, 0)

  def _use_torch_in_cupy_malloc(self):
    """Install `_torch_alloc` as CuPy's allocator."""
    cp.cuda.set_allocator(self._torch_alloc)

  def _compile_kernel_str(
      self,
      kernel,
      name,
      options=(),
      backend="nvrtc",
      max_dynamic_smem=None
    ):
    """Build a cp.RawKernel for function `name` in the CUDA source `kernel`.

    A truthy `max_dynamic_smem` is applied as the kernel's
    max_dynamic_shared_size_bytes attribute.
    """
    compiled = cp.RawKernel(kernel, name, options=options, backend=backend)
    if max_dynamic_smem:
      compiled.max_dynamic_shared_size_bytes = max_dynamic_smem
    return compiled

In [9]:
#@title BitonicSort Kernel

kernel = """
typedef long long ll_t;
#define isnan(x) ( x != x )

#if (__CUDA_ARCH__ < 700)
// Fallback for GPUs older than sm_70, which lack the __nanosleep intrinsic:
// busy-waits until clock() has advanced by `ns`. clock() counts SM clock
// cycles, so this waits `ns` cycles rather than `ns` nanoseconds.
// NOTE(review): __CUDA_ARCH__ is undefined in the host compilation pass and
// then evaluates as 0 here - confirm this cannot clash with the toolkit's
// own __nanosleep declaration.
__device__ void __nanosleep(unsigned int ns){
  clock_t start_clock = clock();
  clock_t clock_offset = 0;
  while (clock_offset < ns)
  {
    clock_offset = clock() - start_clock;
  }
}
#endif 

/*
mutex lock code from:
https://stackoverflow.com/questions/18963293/cuda-atomics-change-flag/18968893#18968893
*/

__device__ void mutex_lock_v2(
  unsigned int *mutex
) {
  // Block-wide acquire of a global spin lock (0 = free, 1 = held).
  // Thread 0 spins with a back-off that doubles from 8 up to 256; the two
  // barriers keep the rest of the block waiting until it owns the lock.
  __syncthreads();
  const bool leader = (threadIdx.x == 0);
  if (leader){
    unsigned int backoff = 8;
    for (;;){
      if (atomicCAS(mutex, 0, 1) != 1){
        break;
      }
      __nanosleep(backoff);
      if (backoff < 256){
        backoff *= 2;
      }
    }
  }
  __syncthreads();
}

// Block-wide acquire of a global spin lock (0 = free, 1 = held). The CAS
// result is broadcast through shared memory (blockMutex), so the whole
// block backs off and retries together. Every thread must call this.
// Not used by bitonic_sort, which calls mutex_lock_v2.
__device__ void mutex_lock(
  unsigned int *mutex,
  unsigned int blockMutex[1]
) {
  unsigned int ns = 8;
  unsigned int old_value;
  if (threadIdx.x == 0){
    old_value = atomicCAS(mutex, 0, 1);
    blockMutex[0] = old_value;
  }
  __syncthreads();
  old_value = blockMutex[0];
  // Every thread must have read blockMutex before thread 0 may overwrite it
  // inside the loop. Otherwise a late thread could see the new value,
  // disagree about entering the loop, and leave the loop's __syncthreads()
  // calls divergent.
  __syncthreads();
  while (old_value == 1) {
    __nanosleep(ns);
    if (ns < 256) {
      ns *= 2;
    }

    if (threadIdx.x == 0){
      old_value = atomicCAS(mutex, 0, 1);
      blockMutex[0] = old_value;
    }
    __syncthreads();
    old_value = blockMutex[0];
    // Same ordering guard as above for the next iteration's write.
    __syncthreads();
  }
}

__device__ void mutex_unlock_v2(unsigned int *mutex) {
  // Block-wide release: publish this block's global writes, wait for every
  // thread to leave its critical section, then thread 0 frees the lock.
  __threadfence();
  __syncthreads();
  const bool leader = (threadIdx.x == 0);
  if (leader){
    atomicExch(mutex, 0u);
    __threadfence();
  }
  __syncthreads();
}

__device__ void mutex_unlock(unsigned int *mutex) {
  // Per-thread release with no fence and no barrier: atomically store 0
  // (free). Not used by bitonic_sort, which calls mutex_unlock_v2.
  const unsigned int unlocked = 0;
  atomicExch(mutex, unlocked);
}

// Returns bit `bitIndex` of `source` (0 or 1): PTX bfe.u32 extracts a
// 1-bit-wide field starting at position bitIndex.
__device__ __forceinline__ unsigned int bfe(
  unsigned int source,
  unsigned int bitIndex
) {
  unsigned int bit;
  asm volatile("bfe.u32 %0, %1, %2, %3;" : "=r"(bit) : "r"((unsigned int) source), "r"(bitIndex), "r"(1));
  return bit;
}

__device__ __forceinline__ void warpComparator(
  float &value,
  float &index,
  const int stride,
  const int direction
){
  // Compare-exchange with lane (lane ^ stride) of the same warp. The pair is
  // replaced by the partner's when (value < partner) equals `direction`:
  // direction 1 keeps the larger value, direction 0 the smaller one.
  // All 32 lanes must participate (full shuffle mask).
  const float partner_value = __shfl_xor_sync(0xFFFFFFFF, value, stride);
  const float partner_index = __shfl_xor_sync(0xFFFFFFFF, index, stride);
  const bool take_partner = ((value < partner_value) == direction);
  if (take_partner){
    value = partner_value;
    index = partner_index;
  }
}

__device__ __forceinline__ void blockComparator(
  float &value,
  float &index,
  const int stride,
  const int direction,
  const int laneID,
  float valSM[128],
  float idxSM[128]
){
  // Compare-exchange with lane laneID ^ stride through shared memory, for
  // partners in another warp. Same rule as warpComparator. Every thread of
  // the block must call this because it contains __syncthreads().
  valSM[laneID] = value;
  idxSM[laneID] = index;
  __syncthreads();

  const int partner = laneID ^ stride;
  const float partner_value = valSM[partner];
  const float partner_index = idxSM[partner];
  // Keep a later call from overwriting the slots while lanes still read.
  __syncthreads();

  if ((value < partner_value) == direction){
    value = partner_value;
    index = partner_index;
  }
}

// Merges this block's ascending 128-element run (one value/index per
// thread) into the global running list values/indices, kept in descending
// order, and stores its first Q entries. Call with the global mutex held and
// all 128 threads of the block present. Indices travel as block-local
// floats (global index minus gStartx) and are rebuilt as ll_t on store.
__device__ void bitonicSort256(
  float &value,
  float &index,
  float* values,
  ll_t* indices,
  float valSM[128],
  float idxSM[128],
  int gStartx, int Q
){
  // __ldcg reads through L2: the list was last written by a block that may
  // have run on another SM, and this SM's L1 could hold an older copy.
  float other_value = __ldcg(&values[threadIdx.x]);
  float other_index = (float)(__ldcg(&indices[threadIdx.x]) - gStartx);

  // Lane-wise max of the ascending run and the descending stored list; the
  // result is a bitonic sequence holding the 128 largest of both.
  if (!(value > other_value)){
    value = other_value;
    index = other_index;
  }

  // Bitonic merge into descending order (laneID < 128, so only bit j of the
  // lane decides the direction).
  int laneID = threadIdx.x % 128;
  for (int j = 6; j >= 0; j--){
    unsigned int direction = bfe(laneID, j);
    int stride = 1 << j;
    if (stride < 32){
      warpComparator(value, index, stride, !direction);
    } else {
      blockComparator(value, index, stride, !direction, laneID, valSM, idxSM);
    }
  }

  if (threadIdx.x < Q){
    values[threadIdx.x] = value;
    // Add gStartx in 64-bit integer arithmetic: adding it to the float index
    // drops low bits once the sum exceeds 2^24.
    indices[threadIdx.x] = (ll_t)index + gStartx;
  }
}

__device__ void bitonicSort(
  float &value,
  float &index,
  float valSM[128],
  float idxSM[128]
) {
  // Bitonic sort of the 128 (value, index) pairs held one per thread.
  // Stage s uses direction bit s+1 and strides 2^s down to 1; strides of 32
  // and up cross warps and go through shared memory. For the last stage bit
  // 7 of laneID is 0, so the result is ascending in laneID.
  const unsigned int laneID = threadIdx.x % 128;
  for (int stage = 0; stage < 7; ++stage){
    for (int bit = stage; bit >= 0; --bit){
      const unsigned int direction = bfe(laneID, stage + 1) ^ bfe(laneID, bit);
      const int stride = 1 << bit;
      if (stride >= 32){
        blockComparator(value, index, stride, direction, laneID, valSM, idxSM);
      } else {
        warpComparator(value, index, stride, direction);
      }
    }
  }
}

// Top-Q (Q <= 128) selection over arr[0..L): each 128-thread block sorts its
// 128-element chunk (lanes past L hold -INFINITY) and, holding the single
// global mutex, merges it into values/indices, which keep the running
// result in descending order.
extern "C"
__global__ void bitonic_sort(
   const float* __restrict__ arr,
   float* values,
   ll_t* indices,
   unsigned int* mutex,
   int L, int Q
){
  int gStartx = blockIdx.x * 128;
  int tid = threadIdx.x;
  __shared__ float valSM[128];
  __shared__ float idxSM[128];

  // Indices are carried block-locally (0..127) as floats. Every lane,
  // including padding lanes past L, gets a defined index.
  float index = tid;
  float value;
  int iL = gStartx + tid;
  if (iL < L){
    value = arr[iL];
  } else {
    value = -INFINITY;
  }

  // Ascending sort of this block's chunk.
  bitonicSort(value, index, valSM, idxSM);

  mutex_lock_v2(mutex);
  bitonicSort256(
    value, index, values, indices,
    valSM, idxSM, gStartx, Q
  );
  mutex_unlock_v2(mutex);
}
"""

# Persist the CUDA source; the BitonicSort wrapper reads it back from disk.
with open("BitonicSort.cu", mode="w") as f:
  f.write(kernel)

In [10]:
#@title BitonicSort
import torch
import cupy as cp
import numpy as np
import math

class BitonicSort(CustomKernel):
  """Top-128 selection over a 1-D float32 CUDA tensor.

  Wraps the `bitonic_sort` kernel read from BitonicSort.cu: each thread block
  bitonic-sorts one 128-element chunk of the input and, under a global
  mutex, merges it into a shared 128-slot result.
  """
  def __init__(self):
    super(BitonicSort, self).__init__()

    with open("BitonicSort.cu", 'r') as f:
      self.kernel = f.read()

    self._fn = cp.RawKernel(
      code=self.kernel,
      name="bitonic_sort",
      backend='nvcc',
      options=(
        '--maxrregcount=128',
        '--use_fast_math',
      )
    )

  def __call__(self, arr):
    """
      Returns the 128 largest entries of `arr` and their positions.
      arr: 1-D torch.float32 tensor on a CUDA device
      returns (values, indices): float32 / int64 tensors of shape [128] on
        arr's device, values in descending order. When `arr` has fewer than
        128 elements the tail of `values` is -inf and the matching `indices`
        do not point into `arr`.
    """
    assert arr.dim() == 1, "arr must be 1-D"
    assert arr.device.type == "cuda", "arr must be a CUDA tensor"
    # The kernel reads `const float*`; any other dtype would be reinterpreted.
    assert arr.dtype == torch.float, "arr must be float32"
    arr = arr.contiguous()

    l = arr.shape[0]
    q = 128
    threads_per_block = (128,)
    blocks_per_grid = (math.ceil(l / 128),)
    # The kernel merges into these buffers, so they start as an empty
    # result: -inf values and -1 indices.
    values = torch.full((128,), float("-inf"), device=arr.device, dtype=torch.float)
    indices = torch.full((128,), -1, device=arr.device, dtype=torch.long)
    mutex = torch.zeros(1, device=arr.device, dtype=torch.int)

    if l == 0:
      # A grid with zero blocks is not a valid launch configuration.
      return values, indices

    self._fn(
      grid=blocks_per_grid,
      block=threads_per_block,
      args=[
        arr.data_ptr(),
        values.data_ptr(),
        indices.data_ptr(),
        mutex.data_ptr(),
        # CuPy passes Python ints as 64-bit `long long`; the kernel declares
        # `int L, int Q`, so pass explicit 32-bit scalars.
        np.int32(l), np.int32(q)
      ],
      stream=self.stream
    )
    return values, indices

# x = torch.randn(128*1024, device="cuda:0")
# bitonic_sort = BitonicSort()

# v1, i1 = torch.topk(x, k=128)
# v2, i2 = bitonic_sort(x)
# print(i1)
# print(i2)

# # plt.plot(x.cpu())
# # plt.show()

# # print(v2)
# plt.plot(v1.cpu())
# plt.show()

# plt.plot(v2.cpu())
# plt.show()

# # x2 = x.sort(dim=0, descending=True)[0]
# # plt.plot(x2.cpu())
# # plt.show()


# val_dif = (v1 - v2).abs()
# idx_dif = (i1 != i2)

# print("val error", val_dif.sum())
# print("idx error", idx_dif.sum())

In [11]:
#@title TopkBMM Kernel
kernel = """
#define _VOLATILE_  

#define likely(x)      __builtin_expect(!!(x), 1)
#define unlikely(x)    __builtin_expect(!!(x), 0)
#define load(x)        __ldcg(x)
#define store(x, value) __stcs(x, value)
#define isnan(x) ( x != x )
#if (__CUDA_ARCH__ < 700)
// Fallback for GPUs older than sm_70, which lack the __nanosleep intrinsic:
// busy-waits until clock() has advanced by `ns`. clock() counts SM clock
// cycles, so this waits `ns` cycles rather than `ns` nanoseconds.
// NOTE(review): __CUDA_ARCH__ is undefined in the host compilation pass and
// then evaluates as 0 here - confirm this cannot clash with the toolkit's
// own __nanosleep declaration.
__device__ void __nanosleep(unsigned int ns){
  clock_t start_clock = clock();
  clock_t clock_offset = 0;
  while (clock_offset < ns)
  {
    clock_offset = clock() - start_clock;
  }
}
#endif 

typedef long long ll_t;
typedef unsigned long long ull_t;

// Eight consecutive floats as one 32-byte-aligned struct, so one row of a
// thread's 8 x 8 accumulator tile can be moved with a single vector access
// (write_c does this); the union's `val` gives indexed access to the same
// storage.
typedef struct __builtin_align__(32) {
  float s0, s1, s2, s3, s4, s5, s6, s7;
} _float8;

typedef union {
  _float8 f8;
  float val[8];
} float8;

__device__ void mutex_lock(
  unsigned int *mutex
) {
  // Block-wide acquire of a global spin lock (0 = free, 1 = held).
  // Only thread 0 takes the lock, spinning with a back-off that doubles from
  // 8 up to 256; the barriers hold the rest of the block until it owns it.
  __syncthreads();
  const bool leader = (threadIdx.x == 0);
  if (leader){
    unsigned int backoff = 8;
    for (;;){
      if (atomicCAS(mutex, 0, 1) != 1){
        break;
      }
      __nanosleep(backoff);
      if (backoff < 256){
        backoff *= 2;
      }
    }
  }
  __syncthreads();
}

// Stand-in for mutex_lock that keeps only a barrier and takes no lock.
// Not called anywhere in this kernel source.
__device__ void mutex_lock_noop(
) {
  __syncthreads();
}

__device__ void mutex_unlock(
  unsigned int *mutex
) {
  // Block-wide release: publish this block's global writes, wait for every
  // thread to leave the critical section, then thread 0 frees the lock.
  __threadfence();
  __syncthreads();
  const bool leader = (threadIdx.x == 0);
  if (leader){
    atomicExch(mutex, 0u);
    __threadfence();
  }
  __syncthreads();
}

// Stand-in for mutex_unlock with the same two barriers and no release.
// Not called anywhere in this kernel source.
__device__ void mutex_unlock_noop(){
  __syncthreads();
  __syncthreads();
}

// Returns bit `bitIndex` of `source` (0 or 1): PTX bfe.u32 extracts a
// 1-bit-wide field starting at position bitIndex.
__device__ __forceinline__ unsigned int bfe(
  unsigned int source,
  unsigned int bitIndex
) {
  unsigned int bit;
  asm volatile("bfe.u32 %0, %1, %2, %3;" : "=r"(bit) : "r"((unsigned int) source), "r"(bitIndex), "r"(1));
  return bit;
}

__device__ __forceinline__ void warpComparator(
  float &value,
  float &index,
  const int stride,
  const int direction
){
  // Compare-exchange with lane (lane ^ stride) of the same warp. The pair is
  // replaced by the partner's when (value < partner) equals `direction`:
  // direction 1 keeps the larger value, direction 0 the smaller one.
  // All 32 lanes must participate (full shuffle mask).
  const float partner_value = __shfl_xor_sync(0xFFFFFFFF, value, stride);
  const float partner_index = __shfl_xor_sync(0xFFFFFFFF, index, stride);
  const bool take_partner = ((value < partner_value) == direction);
  if (take_partner){
    value = partner_value;
    index = partner_index;
  }
}

__device__ __forceinline__ void blockComparator(
  float &value,
  float &index,
  const int stride,
  const int direction,
  const int laneID,
  _VOLATILE_ float valSM[128+4],
  _VOLATILE_ float idxSM[128+4]
){
  // Compare-exchange with lane laneID ^ stride through shared memory, used
  // for strides of 32 or more where the partner sits in another warp.
  // Same rule as warpComparator: take the partner's pair when
  // (value < partner) equals direction. Every thread of the block must call
  // this, since it contains __syncthreads().
  valSM[laneID] = value;
  idxSM[laneID] = index;
  // A single barrier is enough to publish the writes.
  __syncthreads();

  float other_value = valSM[laneID ^ stride];
  float other_index = idxSM[laneID ^ stride];
  // Keep a later call from overwriting the slots while lanes still read.
  __syncthreads();

  bool condition = (value < other_value) == direction;
  index = condition ? other_index : index;
  value = condition ? other_value : value;
}

// Sorts the 128 (value, index) pairs held one per thread by each group of
// 128 consecutive threads (laneID = threadIdx.x % 128) with a fully unrolled
// bitonic network. Each paragraph below is one stage: bfe(laneID, s + 1)
// alternates the order of the runs being built and the stride halves down
// to 1. Strides of 32 and up cross warps and go through valSM/idxSM
// (blockComparator); smaller strides use warp shuffles. The final stage's
// direction is bfe(laneID, j) alone, so the values end up ascending in
// laneID. Every thread of the block must call this, because
// blockComparator contains __syncthreads().
__device__ void bitonicSort128(
  float &value,
  float &index,
  _VOLATILE_ float valSM[128+4],
  _VOLATILE_ float idxSM[128+4]
) {
  unsigned int laneID = threadIdx.x % 128;
  // Runs of 2.
  warpComparator(value, index, 1, bfe(laneID, 1) ^ bfe(laneID, 0));

  // Runs of 4.
  warpComparator(value, index, 2, bfe(laneID, 2) ^ bfe(laneID, 1));
  warpComparator(value, index, 1, bfe(laneID, 2) ^ bfe(laneID, 0));

  // Runs of 8.
  warpComparator(value, index, 4, bfe(laneID, 3) ^ bfe(laneID, 2));
  warpComparator(value, index, 2, bfe(laneID, 3) ^ bfe(laneID, 1));
  warpComparator(value, index, 1, bfe(laneID, 3) ^ bfe(laneID, 0));

  // Runs of 16.
  warpComparator(value, index, 8, bfe(laneID, 4) ^ bfe(laneID, 3));
  warpComparator(value, index, 4, bfe(laneID, 4) ^ bfe(laneID, 2));
  warpComparator(value, index, 2, bfe(laneID, 4) ^ bfe(laneID, 1));
  warpComparator(value, index, 1, bfe(laneID, 4) ^ bfe(laneID, 0));

  // Runs of 32 (one warp).
  warpComparator(value, index, 16, bfe(laneID, 5) ^ bfe(laneID, 4));
  warpComparator(value, index, 8, bfe(laneID, 5) ^ bfe(laneID, 3));
  warpComparator(value, index, 4, bfe(laneID, 5) ^ bfe(laneID, 2));
  warpComparator(value, index, 2, bfe(laneID, 5) ^ bfe(laneID, 1));
  warpComparator(value, index, 1, bfe(laneID, 5) ^ bfe(laneID, 0));

  // Runs of 64: the stride-32 step crosses warps.
  blockComparator(value, index, 32, bfe(laneID, 6) ^ bfe(laneID, 5), laneID, valSM, idxSM);
  warpComparator(value, index, 16, bfe(laneID, 6) ^ bfe(laneID, 4));
  warpComparator(value, index, 8, bfe(laneID, 6) ^ bfe(laneID, 3));
  warpComparator(value, index, 4, bfe(laneID, 6) ^ bfe(laneID, 2));
  warpComparator(value, index, 2, bfe(laneID, 6) ^ bfe(laneID, 1));
  warpComparator(value, index, 1, bfe(laneID, 6) ^ bfe(laneID, 0));

  // Final merge of all 128 lanes into ascending order.
  blockComparator(value, index, 64, bfe(laneID, 6), laneID, valSM, idxSM);
  blockComparator(value, index, 32, bfe(laneID, 5), laneID, valSM, idxSM);
  warpComparator(value, index, 16, bfe(laneID, 4));
  warpComparator(value, index, 8, bfe(laneID, 3));
  warpComparator(value, index, 4, bfe(laneID, 2));
  warpComparator(value, index, 2, bfe(laneID, 1));
  warpComparator(value, index, 1, bfe(laneID, 0));
}

// Merges this thread group's ascending 128-element run (one value/index per
// lane) into one output row gValue/gIndex (this lane's slot), whose first Q
// slots hold the running top-Q in descending order, and stores the new
// top-Q back. Call with the row's mutex held; every thread of the block
// must call it (blockComparator has barriers).
__device__ void bitonicSort256(
  float &value,
  float &index,
  float* gValue,
  ll_t* gIndex,
  float valSM[128+4],
  float idxSM[128+4],
  int Q
){
  int laneID = threadIdx.x % 128;

  // Only the first Q slots of the row hold candidates. Lanes at or past Q
  // treat their slot as empty instead of reading memory that belongs to the
  // next row or lies past the buffer. load() (__ldcg) reads through L2 so
  // writes made by other blocks on other SMs are visible.
  float other_value = -INFINITY;
  float other_index = -1.f;
  if (laneID < Q){
    other_value = load(gValue);
    other_index = (float) load(gIndex);
  }

  // Lane-wise max of the ascending run and the descending stored row; the
  // result is a bitonic sequence holding the 128 largest of both. Plain
  // assignment is used because an add/subtract swap turns -INFINITY into NaN
  // and rounds finite values.
  if (!(value > other_value)){
    value = other_value;
    index = other_index;
  }

  // Bitonic merge into descending order.
  blockComparator(value, index, 64, !bfe(laneID, 6), laneID, valSM, idxSM);
  blockComparator(value, index, 32, !bfe(laneID, 5), laneID, valSM, idxSM);
  warpComparator(value, index, 16, !bfe(laneID, 4));
  warpComparator(value, index, 8, !bfe(laneID, 3));
  warpComparator(value, index, 4, !bfe(laneID, 2));
  warpComparator(value, index, 2, !bfe(laneID, 1));
  warpComparator(value, index, 1, !bfe(laneID, 0));

  if ( laneID < Q){
    gValue[0] = value;
    gIndex[0] = (ll_t) index;
  }
}

// Top-k along dim 1 (the M axis): for every column iN of this block's
// 128 x 128 tile of C (held in cCache, 8 x 8 per thread), merge the tile's
// 128 entries of that column into values/indices[bid][iN][0..Q).
// valSM/idxSM are 16 x (128+4) shared scratch rows. Must be called by all
// 256 threads of the block: it contains barriers.
__device__ void topk_dim_1(
  float8 cCache[8],
  _VOLATILE_ float valSM[16][128+4],
  _VOLATILE_ float idxSM[16][128+4],
  float* values,
  ll_t* indices,
  unsigned int* mutex,
  int gStartx, int gStarty, int bid,
  int M, int N, int Q
){
  int tid = threadIdx.x;
  int vx = tid % 16;
  int vy = tid / 16;
  int hx = tid % 128;
  int hy = tid / 128;
  #pragma unroll
  for (int ni=0; ni<8; ni++){
    // Shared row r holds column gStartx + r*8 + ni, so once row 0 is past N
    // every row is. Testing this block-uniform bound (not a per-thread one)
    // keeps every thread on the same sequence of __syncthreads() calls.
    if (gStartx + ni >= N)
      break;

    // Transpose column (gStartx + vx*8 + ni) of the tile into shared row vx.
    #pragma unroll
    for (int mi=0; mi<8; mi++){
      int iM = gStarty + vy*8 + mi;
      valSM[vx][vy*8 + mi] = likely(iM < M) ? cCache[mi].val[ni] : -INFINITY;
      idxSM[vx][vy*8 + mi] = iM;
    }
    __syncthreads();

    // Each 128-thread half (hy) sorts and merges one shared row per step.
    #pragma unroll
    for (int i=0; i<8; i++){
      int row = hy*8 + i;
      float value = valSM[row][hx];
      float index = idxSM[row][hx];
      bitonicSort128(value, index, valSM[row], idxSM[row]);

      // Columns owned by the lower (hy == 0) and upper (hy == 1) halves.
      // mutex_lock only lets thread 0 acquire, so thread 0 takes both locks,
      // always the lower column first; that fixed order prevents deadlock
      // between blocks sharing these columns.
      int iN0 = gStartx + i*8 + ni;
      int iN1 = gStartx + (8 + i)*8 + ni;
      int iN = gStartx + row*8 + ni;
      bool inRange = iN < N;
      // An out-of-range half still runs the merge (it has barriers) but with
      // Q = 0 and pointers into a valid column, so it writes nothing.
      int iNs = inRange ? iN : N - 1;

      if (iN0 < N) mutex_lock( &mutex[(bid)*N + iN0] );
      if (iN1 < N) mutex_lock( &mutex[(bid)*N + iN1] );
      bitonicSort256(
        value, index,
        &values[(bid)*N*Q + iNs*Q + hx],
        &indices[(bid)*N*Q + iNs*Q + hx],
        valSM[row], idxSM[row],
        inRange ? Q : 0
      );
      if (iN1 < N) mutex_unlock( &mutex[(bid)*N + iN1] );
      if (iN0 < N) mutex_unlock( &mutex[(bid)*N + iN0] );
    }
    // All reads of valSM/idxSM for this ni finish before the next ni
    // overwrites them.
    __syncthreads();
  }
}

// Top-k along dim 2 (the N axis): for every row iM of this block's
// 128 x 128 tile of C (held in cCache, 8 x 8 per thread), merge the tile's
// 128 entries of that row into values/indices[bid][iM][0..Q).
// valSM/idxSM are 16 x (128+4) shared scratch rows. Must be called by all
// 256 threads of the block: it contains barriers.
__device__ void topk_dim_2(
  float8 cCache[8],
  _VOLATILE_ float valSM[16][128+4],
  _VOLATILE_ float idxSM[16][128+4],
  float* values,
  ll_t* indices,
  unsigned int* mutex,
  int gStartx, int gStarty, int bid,
  int M, int N, int Q
){
  int tid = threadIdx.x;
  int vx = tid % 16;
  int vy = tid / 16;
  int hx = tid % 128;
  int hy = tid / 128;
  #pragma unroll
  for (int mi=0; mi<8; mi++){
    // Shared row r holds row gStarty + r*8 + mi of C, so once row 0 is past
    // M every row is. Testing this block-uniform bound (not a per-thread one)
    // keeps every thread on the same sequence of __syncthreads() calls.
    if (gStarty + mi >= M)
      break;

    // Copy row (gStarty + vy*8 + mi) of the tile into shared row vy.
    #pragma unroll
    for (int ni=0; ni<8; ni++){
      int iN = gStartx + vx*8 + ni;
      valSM[vy][vx*8 + ni] = likely(iN < N) ? cCache[mi].val[ni] : -INFINITY;
      idxSM[vy][vx*8 + ni] = iN;
    }
    __syncthreads();

    // Each 128-thread half (hy) sorts and merges one shared row per step.
    #pragma unroll
    for (int i=0; i<8; i++){
      int row = hy*8 + i;
      float value = valSM[row][hx];
      float index = idxSM[row][hx];
      bitonicSort128(value, index, valSM[row], idxSM[row]);

      // Rows of C owned by the lower (hy == 0) and upper (hy == 1) halves.
      // mutex_lock only lets thread 0 acquire, so thread 0 takes both locks,
      // always the lower row first; that fixed order prevents deadlock
      // between blocks sharing these rows.
      int iM0 = gStarty + i*8 + mi;
      int iM1 = gStarty + (8 + i)*8 + mi;
      int iM = gStarty + row*8 + mi;
      bool inRange = iM < M;
      // An out-of-range half still runs the merge (it has barriers) but with
      // Q = 0 and pointers into a valid row, so it writes nothing.
      int iMs = inRange ? iM : M - 1;

      if (iM0 < M) mutex_lock( &mutex[(bid)*M + iM0] );
      if (iM1 < M) mutex_lock( &mutex[(bid)*M + iM1] );
      bitonicSort256(
        value, index,
        &values[(bid)*M*Q + iMs*Q + hx],
        &indices[(bid)*M*Q + iMs*Q + hx],
        valSM[row], idxSM[row],
        inRange ? Q : 0
      );
      if (iM1 < M) mutex_unlock( &mutex[(bid)*M + iM1] );
      if (iM0 < M) mutex_unlock( &mutex[(bid)*M + iM0] );
    }
    // All reads of valSM/idxSM for this mi finish before the next mi
    // overwrites them.
    __syncthreads();
  }
}

__device__ void init_cCache(
  float8 cCache[8]
) {
  // Zero all 64 accumulators (8 rows x 8 floats) owned by this thread.
  #pragma unroll
  for (int slot=0; slot<64; slot++){
    cCache[slot / 8].val[slot % 8] = 0.f;
  }
}

// Variant of the per-thread tile product for an 8-deep k tile (aSM/bSM are
// 8 x (128+4)): for ki = 0..7 accumulates
//   cCache[mi].val[ni] += aSM[ki][8*vy + mi] * bSM[ki][8*vx + ni + vx/4]
// while double-buffering the A values in registers: aCache1 and aCache2
// alternate with the parity of ki, and row ki+1 is loaded before row ki is
// consumed. Not called in this kernel source (topk_bmm_nn uses
// thread_matmul_v3).
__device__ void thread_matmul_v4(
  _VOLATILE_ float aSM[8][128+4],
  _VOLATILE_ float bSM[8][128+4],
  float8 cCache[8],
  int vx, int vy
) {
  float aCache1[8];
  float aCache2[8];
  #pragma unroll
  for (int mi=0; mi<8; mi++){
    aCache1[mi] = aSM[0][8*vy + mi];
  }

  #pragma unroll
  for (int ki=0; ki<8; ki++){
    int is_odd = ki & 1;
    if (is_odd == 0){
      // Even step: prefetch row ki+1 into aCache2, compute with aCache1.
      if (likely(ki < 7)){
        #pragma unroll
        for (int mi=0; mi<8; mi++){
          aCache2[mi] = aSM[ki+1][8*vy + mi];
        }
      }
      #pragma unroll
      for (int ni=0; ni<8; ni++){
        // vx/4 skips the one padding float per 32 columns of a bSM row.
        float b = bSM[ki][vx/4 + 8*vx + ni];
        #pragma unroll
        for (int mi=0; mi<8; mi++){
          float a = aCache1[mi];
          cCache[mi].val[ni] = fmaf(a, b, cCache[mi].val[ni]);
        }
      }
    } else {
      // Odd step: prefetch row ki+1 into aCache1, compute with aCache2.
      if (likely(ki < 7)){
        #pragma unroll
        for (int mi=0; mi<8; mi++){
          aCache1[mi] = aSM[ki+1][8*vy + mi];
        }
      }
      #pragma unroll
      for (int ni=0; ni<8; ni++){
        float b = bSM[ki][vx/4 + 8*vx + ni];
        #pragma unroll
        for (int mi=0; mi<8; mi++){
          float a = aCache2[mi];
          cCache[mi].val[ni] = fmaf(a, b, cCache[mi].val[ni]);
        }
      }
    }
  }
}

__device__ void thread_matmul_v3(
  _VOLATILE_ float aSM[16][128+4],
  _VOLATILE_ float bSM[16][128+4],
  float8 cCache[8],
  int vx, int vy
) {
  // Accumulates this thread's 8 x 8 block of C over a 16-deep k tile:
  //   cCache[r].val[c] += aSM[k][8*vy + r] * bSM[k][8*vx + c + vx/4]
  // The vx/4 term matches the one padding float per 32 columns that
  // topk_bmm_nn inserts when it stores B rows into bSM.
  const int aBase = 8*vy;
  const int bBase = 8*vx + vx/4;
  float aCache[8];

  #pragma unroll
  for (int k=0; k<16; k++){
    #pragma unroll
    for (int r=0; r<8; r++){
      aCache[r] = aSM[k][aBase + r];
    }
    #pragma unroll
    for (int c=0; c<8; c++){
      const float b = bSM[k][bBase + c];
      #pragma unroll
      for (int r=0; r<8; r++){
        cCache[r].val[c] = fmaf(aCache[r], b, cCache[r].val[c]);
      }
    }
  }
}

// Writes this thread's 8 x 8 block of C (rows gStarty + vy*8 + i, columns
// gStartx + vx*8 + j) from cCache to C[bid]. Rows past M and columns past N
// are skipped. Not called in this kernel source.
__device__ void write_c(
  float8 cCache[8],
  float* C,
  int gStartx, int gStarty,
  int vx, int vy, int bid,
  int M, int N
) {
  int iN_start = gStartx + vx*8;
  // The 32-byte float8 store needs all 8 columns in range and an 8-float
  // aligned address. The offset is a multiple of 8 only when N is (assumes
  // C itself is 32-byte aligned - confirm at the call site).
  bool vectorOK = (iN_start + 7 < N) && ((N & 7) == 0);
  #pragma unroll
  for (int i=0; i<8; i++){
    int iM = gStarty + vy*8 + i;
    if (likely(iM < M)){
      float* cRow = C + (bid)*M*N + (iM)*N;
      if (likely(vectorOK)){
        reinterpret_cast<float8*>(cRow + iN_start)[0] = cCache[i];
      } else {
        #pragma unroll
        for (int j=0; j<8; j++){
          int iN = iN_start + j;
          if (iN < N){
            cRow[iN] = cCache[i].val[j];
          }
        }
      }
    }
  }
}

// Writes this thread's 8 x 8 block of C to C[bid] with coalesced streaming
// stores: each row is staged in shared memory so the 16 threads sharing it
// store 16 consecutive columns at a time. Rows past M and columns past N
// are skipped. Assumes vx = tid % 16 and vy = tid / 16 (as in topk_bmm_nn),
// so the 16 threads sharing cSM[vy] sit in one warp, and that all 32 lanes
// of each warp call this function (full-mask __syncwarp). Not called in this
// kernel source.
__device__ void write_c_v3(
  float8 cCache[8],
  float* C,
  int gStartx, int gStarty,
  int vx, int vy, int bid,
  int M, int N
) {
  __shared__ volatile float cSM[16][128];
  #pragma unroll
  for (int mi=0; mi<8; mi++){
    int iM = gStarty + vy*8 + mi;
    // Stage row mi of this thread's block in cSM[vy].
    #pragma unroll
    for (int ni=0; ni<8; ni++){
      cSM[vy][vx*8 + ni] = cCache[mi].val[ni];
    }
    // Lanes read values written by other lanes; with independent thread
    // scheduling (sm_70+) warp lockstep cannot be assumed.
    __syncwarp();
    if (iM < M){
      #pragma unroll
      for (int ni=0; ni<8; ni++){
        int iN = gStartx + 16*ni + vx;
        if (iN < N){
          float cVal = cSM[vy][16*ni + vx];
          store(C+(bid)*M*N + (iM)*N + (iN), cVal);
        }
      }
    }
    // Do not let the next mi overwrite cSM[vy] while lanes are still reading.
    __syncwarp();
  }
}

// Placeholder for the A.T @ B variant: the body is empty, so a launch writes
// nothing. TopkBMMCUDA._call_tn raises NotImplementedError instead of
// launching it. Unlike topk_bmm_nn, the signature has no mutex or Q.
extern "C"
__global__ void topk_bmm_tn(
  const float* __restrict__ A,
  const float* __restrict__ B,
  float* __restrict__ values,
  ll_t* __restrict__ indices,
  int M, int N, int K, int DIM
){
}

// Placeholder for the A @ B.T variant: the body is empty, so a launch writes
// nothing. TopkBMMCUDA._call_nt raises NotImplementedError instead of
// launching it. Unlike topk_bmm_nn, the signature has no mutex or Q.
extern "C"
__global__ void topk_bmm_nt(
  const float* __restrict__ A,
  const float* __restrict__ B,
  float* __restrict__ values,
  ll_t* __restrict__ indices,
  int M, int N, int K, int DIM
){
}

// Fused batched matmul + top-k, "nn" layout: C[bid] = A[bid] @ B[bid] with A
// row-major [M, K] and B row-major [K, N] per batch (see the load offsets).
// Each 256-thread block computes one 128 x 128 tile of C in registers
// (8 x 8 per thread) and merges it into the running top-Q lists along DIM
// (1: per column over M, 2: per row over N); C itself is never stored.
extern "C"
__global__ void topk_bmm_nn(
  const float* __restrict__ A,
  const float* __restrict__ B,
  float* values,
  ll_t* indices,
  unsigned int* mutex,
  int M, int N, int K, int DIM, int Q
){
  int tid = threadIdx.x;     // thread idx
  int bid = blockIdx.z;      // batch idx

  // Neighboring blocks are grouped into PN x PM block groups in order to increase
  // L1 cache hit rate.
  // There are ceil(M/(128*PM)) x ceil(N/(128*PN)) block groups in total.
  // Blocks within block groups are indexed with blockIdx.x % PN and blockIdx.x / PN
  int px = blockIdx.x % _PN_;
  int py = blockIdx.x / _PN_;
  int bDimX = (N + (128*_PN_) - 1) / (128*_PN_);
  int bIdxX = (blockIdx.y % bDimX) * _PN_ + px;
  int bIdxY = (blockIdx.y / bDimX) * _PM_ + py;
  int gStartx = bIdxX * 128;   // starting index of block on N axis
  int gStarty = bIdxY * 128;   // starting index of block on M axis
  // Whole block groups are launched, so blocks at the ragged edge can start
  // at or past the end of an axis; they own no output and leave here. The
  // test is uniform across the block, so no thread is left at a barrier.
  if (gStartx >= N || gStarty >= M){
    return;
  }
  // These are used to re-arrange threads into different shapes
  // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8)
  int vx = tid % 16;
  int vy = tid / 16;
  int wx = tid % 32; // thread idx in warp
  int wy = tid / 32; // warp id
  int dx = tid % 8;
  int dy = tid / 8;

  __shared__ _VOLATILE_ float aSM[16][128+4];
  __shared__ _VOLATILE_ float bSM[16][128+4];

  float aBuffer1[4];
  float bBuffer1[4];
  float aBuffer2[4];
  float bBuffer2[4];

  float8 cCache[8];
  init_cCache(cCache);

  // Load initial 16 x 128 tile of A and B to buffer1 and buffer2.
  // Every slot starts at 0, so elements outside A or B (which the prefetch
  // below also skips) never feed uninitialized registers into cCache.
  #pragma unroll
  for (int i=0; i<4; i++){
    int iM = gStarty + dy + i*32;
    int iN = gStartx + wx + i*32;
    aBuffer1[i] = 0.f;
    aBuffer2[i] = 0.f;
    bBuffer1[i] = 0.f;
    bBuffer2[i] = 0.f;
    if (likely(iM < _M_)){
      if (likely(dx < _K_)){
        aBuffer1[i] = load(A + (bid)*_M_*_K_ + (iM)*_K_ + (dx));
      }
      if (likely(dx+8 < _K_)){
        aBuffer2[i] = load(A + (bid)*_M_*_K_ + (iM)*_K_ + (dx+8));
      }
    }
    if (likely(iN < _N_)){
      if (likely(wy < _K_)){
        bBuffer1[i] = load(B + (bid)*_N_*_K_ + (wy)*_N_ + (iN));
      }
      if (likely(wy+8 < _K_)){
        bBuffer2[i] = load(B + (bid)*_N_*_K_ + (wy+8)*_N_ + (iN));
      }
    }
  }

  // Number of main loop iterations is ceil(k/16)
  int nIt = (_K_ + 16 - 1) / 16;
  #pragma unroll
  for (int itr=0; itr<nIt; itr++){
    int gStartk = itr * 16;

    // Index on K axis of A and B for the next tile
    int iKA = gStartk + 16 + dx;
    int iKB = gStartk + 16 + wy;

    #pragma unroll
    for (int i=0; i<4; i++){
      // Store buffered tiles into shared memory. bSM rows get one padding
      // float per 32 columns (the +i); thread_matmul_v3 reads with the same
      // offset.
      aSM[dx][dy+i*32] = aBuffer1[i];
      bSM[wy][wx+i*32+i] = bBuffer1[i];
      aSM[8 + dx][dy+i*32] = aBuffer2[i];
      bSM[8 + wy][wx+i*32+i] = bBuffer2[i];

      // Start loading next 16*128 tile of A and B to buffer1 and buffer2.
      // Don't load anything on the last iteration.
      // Loading from global memory will not block thread_matmul
      if (likely(itr < nIt - 1)){
        int iM = gStarty + i*32 + dy;
        int iN = gStartx + i*32 + wx;

        if (likely(iM < _M_)){
          if (likely(iKA < _K_)){
            aBuffer1[i] = load(A + (bid)*_M_*_K_ + (iM)*_K_ + (iKA));
          } else {
            aBuffer1[i] = 0.f;
          }
          if (likely(iKA+8 < _K_)){
            aBuffer2[i] = load(A + (bid)*_M_*_K_ + (iM)*_K_ + (iKA+8));
          } else {
            aBuffer2[i] = 0.f;
          }
        }

        if (likely(iN < _N_)){
          if (likely(iKB < _K_)){
            bBuffer1[i] = load(B + (bid)*_N_*_K_ + (iKB)*_N_ + (iN));
          } else {
            bBuffer1[i] = 0.f;
          }
          if (likely(iKB+8 < _K_)){
            bBuffer2[i] = load(B + (bid)*_N_*_K_ + (iKB+8)*_N_ + (iN));
          } else {
            bBuffer2[i] = 0.f;
          }
        }
      }
    }
    // synchronize threads in order to make sure tiles of A and B are fully
    // loaded to shared memory.
    __syncthreads();

    thread_matmul_v3(aSM, bSM, cCache, vx, vy);

    // synchronize threads to signal that shared memory is consumed.
    __syncthreads();
  }

  // TopK sort along DIM; aSM/bSM are free now and serve as sort scratch.
  if (DIM == 1){
    topk_dim_1(
      cCache, aSM, bSM,
      values, indices, mutex,
      gStartx, gStarty, bid, M, N, Q);
  } else if (DIM == 2){
    topk_dim_2(
      cCache, aSM, bSM,
      values, indices, mutex,
      gStartx, gStarty, bid, M, N, Q);
  }
}

// Placeholder for the A.T @ B.T variant: the body is empty, so a launch
// writes nothing. TopkBMMCUDA._call_tt raises NotImplementedError instead of
// launching it. Unlike topk_bmm_nn, the signature has no mutex or Q.
extern "C"
__global__ void topk_bmm_tt(
  const float* __restrict__ A,
  const float* __restrict__ B,
  float* __restrict__ values,
  ll_t* __restrict__ indices,
  int M, int N, int K, int DIM
){
}
"""
# Persist the CUDA source; TopkBMMCUDA reads it back and substitutes the
# size/patch placeholders before compiling.
with open("TopkBMMCUDA.cu", mode="w") as f:
  f.write(kernel)

In [12]:
#@title TopkBMM
import torch
import cupy as cp
import numpy as np
import math

class TopkBMMCUDA(CustomKernel):
  """Fused batched matmul + top-k: the largest entries of A @ B along one
  axis, computed without materializing the product.

  m, n, k: optional fixed sizes baked into the kernel source as constants
    (every call must then match them); None keeps them runtime arguments.
  patch_m, patch_n: size of the thread-block groups (in 128 x 128 tiles)
    used to order blocks for cache locality.
  """
  def __init__(self, m=None, n=None, k=None, patch_m=4, patch_n=4):
    super(TopkBMMCUDA, self).__init__()
    self.m = m
    self.n = n
    self.k = k
    self.patch_m = patch_m
    self.patch_n = patch_n

    with open("TopkBMMCUDA.cu", 'r') as f:
      self.kernel = f.read()

    # Fill in the placeholders of the CUDA source.
    self.kernel = (self.kernel
      .replace("_M_", str(m) if m else "M")
      .replace("_N_", str(n) if n else "N")
      .replace("_K_", str(k) if k else "K")
      .replace("_PM_", str(self.patch_m))
      .replace("_PN_", str(self.patch_n))
    )

    self._fn_tt = cp.RawKernel(
      code=self.kernel,
      name="topk_bmm_tt",
      backend='nvcc',
      options=('--maxrregcount=128', '--use_fast_math')
    )
    self._fn_nn = cp.RawKernel(
      code=self.kernel,
      name="topk_bmm_nn",
      backend='nvcc',
      options=('--maxrregcount=128', '--use_fast_math')
    )
    self._fn_tn = cp.RawKernel(
      code=self.kernel,
      name="topk_bmm_tn",
      backend='nvcc',
      options=('--maxrregcount=128', '--use_fast_math')
    )
    self._fn_nt = cp.RawKernel(
      code=self.kernel,
      name="topk_bmm_nt",
      backend='nvcc',
      options=('--maxrregcount=128', '--use_fast_math')
    )

  def _call_nn(self, A, B, n_candidates, dim):
    """
      Top-`n_candidates` entries of C = A @ B along `dim`.
      A: shape = [l, m, k]
      B: shape = [l, k, n]
      dim: 1 -> for every column of C, the largest entries over the m axis
           2 -> for every row of C, the largest entries over the n axis
      returns (values, indices): shape [l, n, n_candidates] for dim == 1,
        [l, m, n_candidates] for dim == 2; values are in descending order and
        in A's dtype, indices (int64) index the reduced axis.
    """
    assert A.shape[0] == B.shape[0]
    assert A.shape[2] == B.shape[1]
    assert A.device.type == "cuda"
    assert B.device.type == "cuda"
    assert A.dtype in (torch.float, torch.half)
    assert B.dtype in (torch.float, torch.half)
    assert dim in [1, 2]
    assert 0 < n_candidates <= 128

    out_dtype = A.dtype
    # The kernel only reads and writes float32, so fp16 data is upcast here
    # rather than having its raw bits reinterpreted as float32.
    A = A.float().contiguous()
    B = B.float().contiguous()

    l, m, k = A.shape
    l, k, n = B.shape

    if self.m is not None: assert m == self.m
    if self.n is not None: assert n == self.n
    if self.k is not None: assert k == self.k

    device = A.device
    # Give every output row a full 128-slot stride, so the kernel's merge
    # never reads another row's candidates whatever n_candidates is; the
    # first n_candidates slots are sliced off at the end.
    q = 128
    n_rows = n if dim == 1 else m
    values = torch.full([l, n_rows, q], float("-inf"), device=device, dtype=torch.float)
    indices = torch.full([l, n_rows, q], -1, device=device, dtype=torch.int64)
    mutex = torch.zeros([l, n_rows], device=device, dtype=torch.int32)

    # A grid with a zero dimension is not a valid launch configuration.
    if l > 0 and m > 0 and n > 0:
      threads_per_block = (256,)
      n_ = math.ceil(n / (128 * self.patch_n))
      m_ = math.ceil(m / (128 * self.patch_m))
      blocks_per_grid = (self.patch_n * self.patch_m, n_ * m_, l)

      self._fn_nn(
        grid=blocks_per_grid,
        block=threads_per_block,
        args=[
          A.data_ptr(),
          B.data_ptr(),
          values.data_ptr(),
          indices.data_ptr(),
          mutex.data_ptr(),
          # CuPy passes Python ints as 64-bit `long long`; the kernel
          # declares these parameters as `int`.
          np.int32(m), np.int32(n), np.int32(k), np.int32(dim), np.int32(q)
        ],
        stream=self.stream
      )

    values = values[..., :n_candidates].contiguous().to(out_dtype)
    indices = indices[..., :n_candidates].contiguous()
    return values, indices

  def _call_tt(self, A, B, n_candidates, dim):
    raise NotImplementedError

  def _call_tn(self, A, B, n_candidates, dim):
    raise NotImplementedError

  def _call_nt(self, A, B, n_candidates, dim):
    raise NotImplementedError

  def __call__(self, A, B, k=128, dim=1, mode="nn"):
    """
      Fused top-k of a (batched) matrix product.
      A: torch.Tensor, shape : [l, m, k] or [m, k] for mode "nn"
      B: torch.Tensor, shape : [l, k, n] or [k, n] for mode "nn"
      k: int, number of candidates to keep, 1..128 (not the inner dimension)
      dim: int, axis of the batched [l, m, n] product to reduce, 1 or 2
      mode: str, default: "nn"
      returns (values, indices): the k largest entries of f(A) @ g(B) along
        `dim` in descending order and their positions along that axis; the
        batch axis is dropped for 2-d inputs.
      Notes:
        f() and g() are determined by mode
        "nn" --> A @ B
        "tt" --> A.T @ B.T
        "nt" --> A @ B.T
        "tn" --> A.T @ B
        Only "nn" is implemented; the other modes raise NotImplementedError.
      Raises ValueError if A and B are not both 2-d or both 3-d, or if mode is
      not one of the four above.
    """
    assert len(A.shape) == len(B.shape)
    A = A.contiguous()
    B = B.contiguous()
    if len(A.shape) == 2 and len(B.shape) == 2:
      A2 = A[None]
      B2 = B[None]
    elif len(A.shape) == 3 and len(B.shape) == 3:
      A2 = A
      B2 = B
    else:
      raise ValueError("shape of A and B need to be 2d or 3d")

    if mode == "nn":
      values, indices = self._call_nn(A2, B2, k, dim)
    elif mode == "tt":
      values, indices = self._call_tt(A2, B2, k, dim)
    elif mode == "tn":
      values, indices = self._call_tn(A2, B2, k, dim)
    elif mode == "nt":
      values, indices = self._call_nt(A2, B2, k, dim)
    else:
      raise ValueError(f"unknown mode {mode!r}; expected 'nn', 'tt', 'tn' or 'nt'")

    if len(A.shape) == 2 and len(B.shape) == 2:
      indices = indices[0]
      values = values[0]

    return values, indices

In [13]:
#@title BMMv2.5 Kernel
kernel = """
#define _VOLATILE_  

#define likely(x)      __builtin_expect(!!(x), 1)
#define unlikely(x)    __builtin_expect(!!(x), 0)
#define load(x)        __ldcg(x)
#define store(x, value) __stcs(x, value)

typedef long long ll_t;
typedef unsigned long long ull_t;

typedef struct __builtin_align__(32) {
  float s0, s1, s2, s3, s4, s5, s6, s7;
} _float8;

typedef union {
  _float8 f8;
  float val[8];
} float8;

__device__ void init_cCache(
  float8 cCache[8]
) {
  #pragma unroll
  for (int i=0; i<8; i++){
    #pragma unroll
    for (int j=0; j<8; j++){
      cCache[i].val[j] = 0.f;
    }
  }
}

__device__ void thread_matmul_v4(
  _VOLATILE_ float aSM[8][128+4],
  _VOLATILE_ float bSM[8][128+4],
  float8 cCache[8],
  int vx, int vy
) {
  float aCache1[8];
  float aCache2[8];
  #pragma unroll
  for (int mi=0; mi<8; mi++){
    aCache1[mi] = aSM[0][8*vy + mi];
  }

  #pragma unroll
  for (int ki=0; ki<8; ki++){
    int is_odd = ki & 1;
    if (is_odd == 0){
      if (likely(ki < 7)){
        #pragma unroll
        for (int mi=0; mi<8; mi++){
          aCache2[mi] = aSM[ki+1][8*vy + mi];
        }
      }
      #pragma unroll
      for (int ni=0; ni<8; ni++){
        float b = bSM[ki][vx/4 + 8*vx + ni];
        #pragma unroll
        for (int mi=0; mi<8; mi++){
          float a = aCache1[mi];
          cCache[mi].val[ni] = fmaf(a, b, cCache[mi].val[ni]);
        }
      }
    } else {
      if (likely(ki < 7)){
        #pragma unroll
        for (int mi=0; mi<8; mi++){
          aCache1[mi] = aSM[ki+1][8*vy + mi];
        }
      }
      #pragma unroll
      for (int ni=0; ni<8; ni++){
        float b = bSM[ki][vx/4 + 8*vx + ni];
        #pragma unroll
        for (int mi=0; mi<8; mi++){
          float a = aCache2[mi];
          cCache[mi].val[ni] = fmaf(a, b, cCache[mi].val[ni]);
        }
      }
    }
  }
}

__device__ void thread_matmul_v3(
  _VOLATILE_ float aSM[8][128+4],
  _VOLATILE_ float bSM[8][128+4],
  float8 cCache[8],
  int vx, int vy
) {
  float aCache[8];

  #pragma unroll
  for (int ki=0; ki<8; ki++){
    #pragma unroll
    for (int mi=0; mi<8; mi++){
      aCache[mi] = aSM[ki][8*vy + mi];
    }
    #pragma unroll
    for (int ni=0; ni<8; ni++){
      float b = bSM[ki][vx/4 + 8*vx + ni];
      #pragma unroll
      for (int mi=0; mi<8; mi++){
        float a = aCache[mi];
        cCache[mi].val[ni] = fmaf(a, b, cCache[mi].val[ni]);
      }
    }
  }
}

// Writes this thread's 8 x 8 accumulator tile directly from registers to C
// (no shared-memory staging). Rows >= M and columns >= N are skipped.
// A whole 8-wide row segment is written with one float8 store only when it is
// fully in range and N is a multiple of 8. Then the element offset is a
// multiple of 8 floats, provided gStartx is a multiple of 8 (it is a multiple
// of 128 in bmm_nn). This assumes C itself is at least 32-byte aligned, as
// fresh CUDA allocations are. Every other segment is written element by element.
__device__ void write_c(
  float8 cCache[8],
  float* C,
  int gStartx, int gStarty,
  int vx, int vy, int bid,
  int M, int N
) {
  const int iN_start = gStartx + vx*8;
  const bool vec_ok = ((N & 7) == 0) && (iN_start + 7 < N);
  #pragma unroll
  for (int i=0; i<8; i++){
    int iM = gStarty + vy*8 + i;
    if (likely(iM < M)){
      float* row = C + (bid)*M*N + (iM)*N;
      if (likely(vec_ok)){
        reinterpret_cast<float8*>(row + iN_start)[0] = cCache[i];
      } else {
        #pragma unroll
        for (int j=0; j<8; j++){
          int iN = iN_start + j;
          if (iN < N){
            row[iN] = cCache[i].val[j];
          }
        }
      }
    }
  }
}

// Writes this thread's 8 x 8 accumulator tile to C, one output row at a time.
// The row is staged through shared memory so that, on the way out, the 16
// threads sharing vy store 16 consecutive columns (16*ni + vx) per step.
// Rows >= M and columns >= N are skipped.
// Assumes vx = tid % 16 and vy = tid / 16 with 256 threads, as in bmm_nn. The
// 16 threads that share a row of cSM are then one half of a single warp. The
// function must be reached by every lane of each warp because it uses
// full-mask __syncwarp().
__device__ void write_c_v3(
  float8 cCache[8],
  float* C,
  int gStartx, int gStarty,
  int vx, int vy, int bid,
  int M, int N
) {
  __shared__ volatile float cSM[16][128];
  #pragma unroll
  for (int mi=0; mi<8; mi++){
    int iM = gStarty + vy*8 + mi;
    // Stage row mi of this thread's tile into cSM[vy]. Do it unconditionally:
    // the two half-warps of a warp can disagree on iM < M, and all lanes must
    // reach the __syncwarp() calls below.
    #pragma unroll
    for (int ni=0; ni<8; ni++){
      cSM[vy][vx*8 + ni] = cCache[mi].val[ni];
    }
    // Make the other 15 threads' staged values visible before reading them
    // back. Lockstep execution is not guaranteed under independent thread
    // scheduling (Volta and later).
    __syncwarp();
    if (iM < M){
      // Store to C
      #pragma unroll
      for (int ni=0; ni<8; ni++){
        int iN = gStartx + 16*ni + vx;
        if (iN < N){
          float cVal = cSM[vy][16*ni + vx];
          store(C+(bid)*M*N + (iM)*N + (iN), cVal);
        }
      }
    }
    // Prevent the next iteration from overwriting cSM[vy] while another lane
    // may still be reading it.
    __syncwarp();
  }
}

// Placeholder for the "tn" variant (C = A.T @ B). The body is intentionally
// empty; the Python wrapper's _call_tn raises NotImplementedError instead of
// launching it.
extern "C" __global__ void bmm_tn(
  const float* __restrict__ A, const float* __restrict__ B,
  float* __restrict__ C, int M, int N, int K
){
  // not implemented
}

// Placeholder for the "nt" variant (C = A @ B.T). The body is intentionally
// empty; the Python wrapper's _call_nt raises NotImplementedError instead of
// launching it.
extern "C" __global__ void bmm_nt(
  const float* __restrict__ A, const float* __restrict__ B,
  float* __restrict__ C, int M, int N, int K
){
  // not implemented
}

extern "C"
__global__ void bmm_nn(
  const float* __restrict__ A,
  const float* __restrict__ B,
  float* __restrict__ C,
  int M, int N, int K
){
  // C[bid] = A[bid] @ B[bid], with A stored as [batch][M][K], B as
  // [batch][K][N] and C as [batch][M][N] (row-major).
  // A block of 256 threads computes one 128 x 128 tile of C, and each thread
  // owns an 8 x 8 sub-tile accumulated in registers (cCache).
  // K is consumed 16 at a time through two 8-deep shared tiles per operand,
  // while the next 16-deep slice is prefetched into registers.
  int tid = threadIdx.x;     // thread idx
  int bid = blockIdx.z;      // batch idx

  // Neighboring blocks are grouped into PN x PM block groups in order to
  // increase the L1 cache hit rate. There are
  // ceil(M / (128*PM)) x ceil(N / (128*PN)) block groups in total.
  // blockIdx.y selects the group and blockIdx.x the block inside it
  // (column blockIdx.x % PN, row blockIdx.x / PN).
  int px = blockIdx.x % _PN_;
  int py = blockIdx.x / _PN_;
  int bDimX = (N + (128*_PN_) - 1) / (128*_PN_);
  int bIdxX = (blockIdx.y % bDimX) * _PN_ + px;
  int bIdxY = (blockIdx.y / bDimX) * _PM_ + py;
  int gStartx = bIdxX * 128;   // starting index of block on N axis
  int gStarty = bIdxY * 128;   // starting index of block on M axis
  // Blocks of an edge group may start at or past the end of C and then have
  // no output at all. The test depends only on blockIdx, so the whole block
  // leaves together and the __syncthreads() below remain block-uniform.
  if (gStartx >= N || gStarty >= M){
    return;
  }
  // These are used to re-arrange threads into different shapes
  // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8)
  int vx = tid % 16;
  int vy = tid / 16;
  int wx = tid % 32; // thread idx in warp
  int wy = tid / 32; // warp id
  int dx = tid % 8;
  int dy = tid / 8;

  // aSM tiles are indexed [k][m]. bSM tiles are indexed [k][n], with one
  // padding float inserted after every 32 columns (see the stores below).
  __shared__ _VOLATILE_ float aSM1[8][128+4];
  __shared__ _VOLATILE_ float bSM1[8][128+4];
  __shared__ _VOLATILE_ float aSM2[8][128+4];
  __shared__ _VOLATILE_ float bSM2[8][128+4];
  // Register prefetch buffers. They are zero-initialised because a thread
  // whose row (iM >= M) or column (iN >= N) is out of range never loads them,
  // yet their contents are still copied to shared memory and used in the
  // FMAs of the (masked-out) outputs.
  float aBuffer1[4] = {0.f, 0.f, 0.f, 0.f};
  float bBuffer1[4] = {0.f, 0.f, 0.f, 0.f};
  float aBuffer2[4] = {0.f, 0.f, 0.f, 0.f};
  float bBuffer2[4] = {0.f, 0.f, 0.f, 0.f};

  float8 cCache[8];
  init_cCache(cCache);

  // Load the initial 16 x 128 tiles of A and B into buffer1 (k 0..7) and
  // buffer2 (k 8..15).
  #pragma unroll
  for (int i=0; i<4; i++){
    int iM = gStarty + dy + i*32;
    int iN = gStartx + wx + i*32;
    if (likely(iM < _M_)){
      if (likely(dx < _K_)){
        aBuffer1[i] = load(A + (bid)*_M_*_K_ + (iM)*_K_ + (dx));
      } else {
        aBuffer1[i] = 0.f;
      }
      if (likely(dx+8 < _K_)){
        aBuffer2[i] = load(A + (bid)*_M_*_K_ + (iM)*_K_ + (dx+8));
      } else {
        aBuffer2[i] = 0.f;
      }
    }
    if (likely(iN < _N_)){
      if (likely(wy < _K_)){
        bBuffer1[i] = load(B + (bid)*_N_*_K_ + (wy)*_N_ + (iN));
      } else {
        bBuffer1[i] = 0.f;
      }
      if (likely(wy+8 < _K_)){
        bBuffer2[i] = load(B + (bid)*_N_*_K_ + (wy+8)*_N_ + (iN));
      } else {
        bBuffer2[i] = 0.f;
      }
    }
  }

  // Number of main loop iterations is ceil(k/16)
  int nIt = (_K_ + 16 - 1) / 16;
  #pragma unroll
  for (int itr=0; itr<nIt; itr++){
    int gStartk = itr * 16;

    // Index on the K axis of the *next* tiles of A and B
    int iKA = gStartk + 16 + dx;
    int iKB = gStartk + 16 + wy;

    #pragma unroll
    for (int i=0; i<4; i++){
      // Store buffered tiles into shared memory (B column c goes to c + c/32)
      aSM1[dx][dy+i*32] = aBuffer1[i];
      bSM1[wy][wx+i*32+i] = bBuffer1[i];
      aSM2[dx][dy+i*32] = aBuffer2[i];
      bSM2[wy][wx+i*32+i] = bBuffer2[i];

      // Start loading next 16*128 tile of A and B to buffer1 and buffer2.
      // Don't load anything on the last iteration.
      // Loading from global memory will not block thread_matmul
      if (likely(itr < nIt - 1)){
        int iM = gStarty + i*32 + dy;
        int iN = gStartx + i*32 + wx;

        if (likely(iM < _M_)){
          if (likely(iKA < _K_)){
            aBuffer1[i] = load(A + (bid)*_M_*_K_ + (iM)*_K_ + (iKA));
          } else {
            aBuffer1[i] = 0.f;
          }
          if (likely(iKA+8 < _K_)){
            aBuffer2[i] = load(A + (bid)*_M_*_K_ + (iM)*_K_ + (iKA+8));
          } else {
            aBuffer2[i] = 0.f;
          }
        }

        if (likely(iN < _N_)){
          if (likely(iKB < _K_)){
            bBuffer1[i] = load(B + (bid)*_N_*_K_ + (iKB)*_N_ + (iN));
          } else {
            bBuffer1[i] = 0.f;
          }
          if (likely(iKB+8 < _K_)){
            bBuffer2[i] = load(B + (bid)*_N_*_K_ + (iKB+8)*_N_ + (iN));
          } else {
            bBuffer2[i] = 0.f;
          }
        }
      }
    }
    // Synchronize threads to make sure the tiles of A and B are fully
    // stored in shared memory.
    __syncthreads();

    // Each thread computes 8 x 8 matrix multiplication
    // Accumulate intermediate results in cCache
    // aSM1, bSM1, aSM2, bSM2 are consumed
    thread_matmul_v3(aSM1, bSM1, cCache, vx, vy);
    thread_matmul_v3(aSM2, bSM2, cCache, vx, vy);

    // synchronize threads to signal that shared memory is consumed.
    __syncthreads();
  }

  // At the end of the main loop, store cCache to C. write_c_v3 stages each
  // row through shared memory so consecutive threads store consecutive columns.
  write_c_v3(cCache, C, gStartx, gStarty, vx, vy, bid, M, N);
}

// Placeholder for the "tt" variant (C = A.T @ B.T). The body is intentionally
// empty; the Python wrapper's _call_tt raises NotImplementedError instead of
// launching it.
extern "C" __global__ void bmm_tt(
  const float* __restrict__ A, const float* __restrict__ B,
  float* __restrict__ C, int M, int N, int K
){
  // not implemented
}
"""
# Persist the kernel source; BMMCUDAv2_5 reads it back from this path.
with open("BMMCUDAv2_5.cu", mode="w") as cu_file:
  cu_file.write(kernel)

In [14]:
#@title BMMv2.5
import torch
import cupy as cp
import numpy as np
import math

class BMMCUDAv2_5(CustomKernel):
  """Batched float32 matmul backed by the CUDA kernels in BMMCUDAv2_5.cu.

  Only the "nn" mode (C = A @ B) has a kernel implementation; "tt", "tn" and
  "nt" raise NotImplementedError.

  Args:
    m, n, k: optional fixed problem sizes. When given, they are substituted
      into the kernel source as literals and every call asserts that the
      inputs match. When None, the kernel uses its runtime M / N / K arguments.
    patch_m, patch_n: number of 128 x 128 tiles per block group along M / N.
      The kernel groups neighbouring blocks to raise the L1 cache hit rate.
  """

  # Path of the kernel source written by the kernel cell.
  KERNEL_PATH = "BMMCUDAv2_5.cu"
  # nvcc flags shared by every kernel variant.
  COMPILE_OPTIONS = ('--maxrregcount=128', '--use_fast_math')

  def __init__(self, m=None, n=None, k=None, patch_m=4, patch_n=4):
    super().__init__()
    self.m = m
    self.n = n
    self.k = k
    self.patch_m = patch_m
    self.patch_n = patch_n

    with open(self.KERNEL_PATH, 'r') as f:
      self.kernel = f.read()

    # Fill in the source placeholders: fixed sizes become literals; otherwise
    # the placeholders fall back to the kernel's runtime arguments.
    self.kernel = (self.kernel
      .replace("_M_", str(m) if m else "M")
      .replace("_N_", str(n) if n else "N")
      .replace("_K_", str(k) if k else "K")
      .replace("_PM_", str(self.patch_m))
      .replace("_PN_", str(self.patch_n))
    )

    # One RawKernel per transpose mode.
    self._fn_tt = self._compile_variant("bmm_tt")
    self._fn_nn = self._compile_variant("bmm_nn")
    self._fn_tn = self._compile_variant("bmm_tn")
    self._fn_nt = self._compile_variant("bmm_nt")

  def _compile_variant(self, name):
    """Build the RawKernel `name` from self.kernel with the shared nvcc options."""
    return self._compile_kernel_str(
      self.kernel,
      name,
      options=self.COMPILE_OPTIONS,
      backend="nvcc",
    )

  def _call_nn(self, A, B):
    """
      Performs C = A @ B
      A: shape = [l, m, k], float32, on a CUDA device
      B: shape = [l, k, n], float32, on the same device as A
      returns C: shape = [l, m, n], float32
    """
    assert A.shape[0] == B.shape[0]
    assert A.shape[2] == B.shape[1]
    assert A.device.type == "cuda"
    assert B.device.type == "cuda"
    assert A.device == B.device
    # The kernel only handles `const float*`. Half inputs would be read with
    # the wrong element size, and a half C buffer would be written out of bounds.
    assert A.dtype == torch.float
    assert B.dtype == torch.float

    # The kernel indexes with dense row-major strides.
    A = A.contiguous()
    B = B.contiguous()

    l, m, k = A.shape
    n = B.shape[2]

    if self.m is not None: assert m == self.m
    if self.n is not None: assert n == self.n
    if self.k is not None: assert k == self.k

    C = torch.zeros([l, m, n], device=A.device, dtype=A.dtype)

    threads_per_block = (256,)
    # grid.x enumerates the blocks inside one patch_n x patch_m block group,
    # grid.y the block groups, and grid.z the batch.
    n_ = math.ceil(n / (128 * self.patch_n))
    m_ = math.ceil(m / (128 * self.patch_m))
    blocks_per_grid = (self.patch_n * self.patch_m, n_ * m_, l)

    self._fn_nn(
      grid=blocks_per_grid,
      block=threads_per_block,
      args=[
        A.data_ptr(),
        B.data_ptr(),
        C.data_ptr(),
        # The kernel declares `int M, int N, int K`; CuPy maps plain Python
        # ints to `long long`, so pass explicit 32-bit integers.
        np.int32(m), np.int32(n), np.int32(k),
      ],
      stream=self.stream
    )
    return C

  def _call_tt(self, A, B):
    raise NotImplementedError

  def _call_tn(self, A, B):
    raise NotImplementedError

  def _call_nt(self, A, B):
    raise NotImplementedError

  def __call__(self, A, B, mode="nn"):
    """
      Performs C = f(A) @ g(B)
      A: torch.Tensor, shape : [l, m, k] or [l, k, m]
      B: torch.Tensor, shape : [l, n, k] or [l, k, n]
      returns C: torch.Tensor, shape : [l, m, n]
      mode: str, default: "nn"
      Notes:
        f() and g() are determined by mode
        "nn" --> A @ B
        "tt" --> A.T @ B.T
        "nt" --> A @ B.T
        "tn" --> A.T @ B
        2-D inputs are treated as a batch of one and give a 2-D result.
      Raises:
        ValueError: if mode is unknown, or A and B are not both 2-D or 3-D.
        NotImplementedError: for every mode other than "nn".
    """
    dispatch = {
      "nn": self._call_nn,
      "tt": self._call_tt,
      "tn": self._call_tn,
      "nt": self._call_nt,
    }
    if mode not in dispatch:
      raise ValueError(f"unknown mode {mode!r}; expected 'nn', 'tt', 'tn' or 'nt'")

    assert len(A.shape) == len(B.shape)
    A = A.contiguous()
    B = B.contiguous()
    if len(A.shape) == 2 and len(B.shape) == 2:
      A2 = A[None]
      B2 = B[None]
    elif len(A.shape) == 3 and len(B.shape) == 3:
      A2 = A
      B2 = B
    else:
      raise ValueError("shape of A and B need to be 2d or 3d")

    C = dispatch[mode](A2, B2)

    if len(A.shape) == 2 and len(B.shape) == 2:
      C = C[0]
    return C

## Note
- `TopkBMM` only works correctly when *M* and *N* are divisible by 128.
- `n_candidates` is the *k* in top-k.
- `n_candidates` should be at most 128 (the benchmarks below use exactly 128).

In [None]:
#@title test TopkBMM
def test_topk_bmm(l, m, n, k, mode="nn", n_iter=1, dim=1, n_candidates=128, verbose=0):
  """Benchmark TopkBMMCUDA against torch.bmm followed by Tensor.topk.

  Both paths take the top-`n_candidates` of f(A) @ g(B) along `dim`. `mode`
  states which operand is stored transposed ("t") or not ("n").

  Args:
    l, m, n, k: batch size and GEMM sizes; the product has shape [l, m, n].
    mode: "nn", "nt", "tn" or "tt" (first letter for A, second for B).
    n_iter: number of warmup iterations and of timed iterations per path.
    dim: dimension of the product along which the top-k is taken.
    n_candidates: the k of top-k.
    verbose: when > 0, print timings, TFLOPS and the differences between the
      two results.

  Returns:
    (time_cost_0, time_cost_1): mean wall-clock seconds per iteration for the
    torch path and for the custom kernel.

  Raises:
    ValueError: if `mode` is not one of "nn", "nt", "tn", "tt".
  """
  if mode not in ("nn", "nt", "tn", "tt"):
    raise ValueError(f"mode must be one of 'nn', 'nt', 'tn', 'tt', got {mode!r}")
  print(f"l={l}  m={m}  n={n}  k={k}")
  if mode[0] == "n":
    A = torch.randn(l, m, k, device="cuda:0")
  else:
    A = torch.randn(l, k, m, device="cuda:0")

  if mode[1] == "n":
    B = torch.randn(l, k, n, device="cuda:0")
  else:
    B = torch.randn(l, n, k, device="cuda:0")
  custom_topk_bmm = TopkBMMCUDA(patch_m=1, patch_n=16)
  # 2 flops per multiply-add of the GEMM, plus one op per output element.
  flop = l * m * n * k * 2 + l * m * n

  # "nn"-layout views of the operands for the torch reference.
  At = A.transpose(1, 2) if mode[0] == "t" else A
  Bt = B.transpose(1, 2) if mode[1] == "t" else B

  # warmup
  for _ in range(n_iter):
    C = torch.bmm(At, Bt)
    C.topk(k=n_candidates, dim=dim)
    torch.cuda.synchronize()

  tm = time()
  for _ in range(n_iter):
    C = torch.bmm(At, Bt)
    C_v, C_i = C.topk(k=n_candidates, dim=dim)
    torch.cuda.synchronize()
  time_cost_0 = (time() - tm) / n_iter
  flops0 = (flop / time_cost_0) / 1000**4   # TFLOPS
  if dim == 1:
    # Move the candidates to the last axis so they line up with
    # TopkBMMCUDA's result in the comparison below (presumed layout - confirm).
    C_v = C_v.transpose(1,2)
    C_i = C_i.transpose(1,2)
  if verbose > 0:
    print("time spent for torch.bmm + topk:", time_cost_0)
    print("tflops:", flops0)
  else:
    del C_v, C_i
  del C

  # warmup: use the same arguments as the timed loop so that it exercises the
  # same configuration.
  for _ in range(n_iter):
    custom_topk_bmm(A, B, mode=mode, dim=dim, k=n_candidates)
    torch.cuda.synchronize()
  tm = time()
  for _ in range(n_iter):
    C1_v, C1_i = custom_topk_bmm(A, B, mode=mode, dim=dim, k=n_candidates)
    torch.cuda.synchronize()
  time_cost_1 = (time() - tm) / n_iter
  flops1 = (flop / time_cost_1) / 1000**4   # TFLOPS
  if verbose > 0:
    print("time spent for custom_topk_bmm:", time_cost_1)
    print("tflops:", flops1)
  else:
    del C1_v, C1_i

  if verbose > 0:
    val_dif = (C1_v - C_v).abs()
    idx_dif = (C1_i != C_i)
    print("Max Val Error", val_dif.max())
    print("Val Error:", val_dif.sum())
    print("Idx Error:", idx_dif.sum())
    print("ratio:", time_cost_1 / time_cost_0)

  return time_cost_0, time_cost_1
  
# Single large benchmark: (16384 x 256) @ (256 x 2560), top-128 along dim 1.
_ = test_topk_bmm(
    l=1, m=16384, n=1024*2 + 512, k=256,
    mode="nn",
    dim=1,
    n_iter=10,
    n_candidates=128,
    verbose=1,
)
# topk_dim1_v1 0.41 0.03
# topk_dim1_v2 0.38 0.0305

In [16]:
import os
# Output directory for the benchmark figures. exist_ok keeps the cell
# idempotent and avoids the exists-then-create race of a separate check.
os.makedirs("imgs", exist_ok=True)

In [None]:
#@title Grid test TopkBMM
ls = [1]
ms = [256* 2**i for i in range(1, 3)]
ns = [1024]

ks = [1024]
mode="nn"

# Results are keyed by m alone, and the plot title uses the loop values of
# l, n and k. Only `ms` may therefore hold more than one value.
assert len(ls) == len(ns) == len(ks) == 1, "only `ms` may contain several values"

custom_res = dict()    # m -> milliseconds for custom_topk_bmm
cublass_res = dict()   # m -> milliseconds for torch.bmm + torch.topk
for l in ls:
  for n in ns:
    for m in ms:
      for k in ks:
        res = test_topk_bmm(
          l, m, n, k,
          mode=mode, n_iter=15, dim=1,
          n_candidates = 128,
        )
        cublass_res[m] = res[0]*1e3
        custom_res[m] = res[1]*1e3

# In the title and x label, "N" denotes the swept m dimension of A.
title = f"A[{l},N,{k}] B[{l},{k},{n}]"
colors = ["red", "blue"]
labels = ["custom_topk_bmm", "torch.bmm -> torch.topk"]

# Scope the 17pt font to this figure, so it applies to every text element
# (including the title) without leaking into later cells.
with plt.rc_context({"font.size": 17}):
  fig, ax = plt.subplots(figsize=(15, 10))
  ax.set_xlabel("N", fontsize=17)
  ax.set_ylabel("milliseconds", fontsize=17)
  ax.set_title(title)
  ax.grid()
  for i, res in enumerate([custom_res, cublass_res]):
    ax.semilogx(
      list(res.keys()),
      list(res.values()),
      color=colors[i],
      label=labels[i],
    )
  ax.legend()
  # Lay out once everything is drawn; on an empty figure this has no effect.
  fig.tight_layout()
  fig.savefig("imgs/topk_bmm_" + title + "_semilogx")
  plt.show()