In [1]:
#!pip install --upgrade cupy-cuda112==8.5.0

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cupy as cp
import math
from time import time

In [3]:
#@title CustomKernel
import cupy as cp
import torch

@cp.memoize(for_each_device=True)
def cunnex(func_name, func_body):
  """Compile CUDA source text and return one kernel from it, memoized per device.

  Args:
    func_name: name of the kernel to look up in the compiled module.
    func_body: CUDA C source text handed to the compiler.

  Returns:
    The object returned by ``get_function(func_name)`` on the compiled module.
  """
  # `cp.memoize` is the top-level spelling of the decorator; looking it up via
  # `cp.util` fails with AttributeError on CuPy releases without that module.
  # NOTE(review): `cp.cuda.compile_with_cache` is a legacy entry point too -
  # confirm it exists in the installed CuPy before relying on this helper.
  return cp.cuda.compile_with_cache(func_body).get_function(func_name)

class Stream:
  """Minimal holder that exposes a raw stream handle through a ``ptr`` attribute."""
  def __init__(self, ptr):
    # Stored untouched; CustomKernel passes torch's `cuda_stream` value here.
    self.ptr = ptr
  
class CustomKernel:
  """Base class for kernels compiled with CuPy and fed with torch tensors.

  Creating an instance has two effects: CuPy's allocator is replaced by one
  that takes its memory from torch, and the handle of torch's current CUDA
  stream is captured in ``self.stream``.
  """

  def __init__(self):
    self._use_torch_in_cupy_malloc()
    current_stream = torch.cuda.current_stream()
    self.stream = Stream(current_stream.cuda_stream)

  @staticmethod
  def _torch_alloc(size):
    """Allocate ``size`` bytes with torch and hand them to CuPy.

    Args:
      size: number of bytes requested by CuPy.

    Returns:
      A ``cp.cuda.MemoryPointer`` at offset 0 of the torch-backed buffer.
    """
    device_id = cp.cuda.Device().id
    backing = torch.empty(size, dtype=torch.uint8, device=device_id)
    # `backing` is passed as the owner object of the unowned memory region.
    region = cp.cuda.UnownedMemory(backing.data_ptr(), size, backing)
    return cp.cuda.MemoryPointer(region, 0)

  def _use_torch_in_cupy_malloc(self):
    """Install `_torch_alloc` as CuPy's allocator."""
    cp.cuda.set_allocator(self._torch_alloc)

  def _compile_kernel_str(
      self,
      kernel,
      name,
      options=(),
      backend="nvrtc",
      max_dynamic_smem=None
    ):
    """Wrap CUDA source text in a ``cp.RawKernel``.

    Args:
      kernel: CUDA source text.
      name: name of the kernel inside ``kernel``.
      options: compiler options forwarded to ``cp.RawKernel``.
      backend: compiler backend forwarded to ``cp.RawKernel``.
      max_dynamic_smem: when truthy, assigned to the kernel's
        ``max_dynamic_shared_size_bytes`` attribute.

    Returns:
      The ``cp.RawKernel`` object.
    """
    compiled = cp.RawKernel(
      kernel,
      name,
      options=options,
      backend=backend,
    )
    if not max_dynamic_smem:
      return compiled
    compiled.max_dynamic_shared_size_bytes = max_dynamic_smem
    return compiled

In [4]:
#@title BMMv2.5 Kernel
kernel = """
// Qualifier applied to the shared-memory tiles; defined empty here, so the
// tiles that use it are declared without any extra qualifier.
#define _VOLATILE_  

// Branch-prediction hints.
#define likely(x)      __builtin_expect(!!(x), 1)
#define unlikely(x)    __builtin_expect(!!(x), 0)
// Global-memory accessors used by the kernels: loads go through __ldcg and
// stores through __stcs.
#define load(x)        __ldcg(x)
#define store(x, value) __stcs(x, value)

typedef long long ll_t;
typedef unsigned long long ull_t;

// Eight floats packed in a 32-byte aligned struct.
typedef struct __builtin_align__(32) {
  float s0, s1, s2, s3, s4, s5, s6, s7;
} _float8;

// Same eight floats, addressable either as one _float8 or as val[0..7].
typedef union {
  _float8 f8;
  float val[8];
} float8;

// Zero the 8x8 accumulator tile held by one thread.
__device__ void init_cCache(
  float8 cCache[8]
) {
  // One flat pass over the 64 entries; idx/8 selects the row, idx%8 the column.
  #pragma unroll
  for (int idx=0; idx<64; idx++){
    cCache[idx / 8].val[idx % 8] = 0.f;
  }
}

// Accumulate an 8x8 outer-product update per k-step for k = 0..7, reading the
// A operand from aSM and the B operand from bSM. The slice of A for step k+1
// is fetched into a second register set while step k is being computed.
//   aSM, bSM : 8 x (128+4) shared-memory tiles
//   cCache   : this thread's 8x8 accumulators, updated in place
//   vx, vy   : thread coordinates selecting the 8 columns / 8 rows it owns
__device__ void thread_matmul_v4(
  _VOLATILE_ float aSM[8][128+4],
  _VOLATILE_ float bSM[8][128+4],
  float8 cCache[8],
  int vx, int vy
) {
  // aReg[0] serves even k-steps, aReg[1] serves odd k-steps.
  float aReg[2][8];
  const int aCol = 8*vy;
  // Column of this thread's first B value, including the one-slot skew per
  // 32 columns (vx/4) that the tile was stored with.
  const int bCol = vx/4 + 8*vx;

  #pragma unroll
  for (int mi=0; mi<8; mi++){
    aReg[0][mi] = aSM[0][aCol + mi];
  }

  #pragma unroll
  for (int ki=0; ki<8; ki++){
    const int cur = ki & 1;
    const int nxt = cur ^ 1;

    // Fetch the A slice of the following k-step, except after the last one.
    if (likely(ki < 7)){
      #pragma unroll
      for (int mi=0; mi<8; mi++){
        aReg[nxt][mi] = aSM[ki+1][aCol + mi];
      }
    }

    #pragma unroll
    for (int ni=0; ni<8; ni++){
      const float bVal = bSM[ki][bCol + ni];
      #pragma unroll
      for (int mi=0; mi<8; mi++){
        cCache[mi].val[ni] = fmaf(aReg[cur][mi], bVal, cCache[mi].val[ni]);
      }
    }
  }
}

// Accumulate an 8x8 outer-product update per k-step for k = 0..7.
//   aSM, bSM : 8 x (128+4) shared-memory tiles
//   cCache   : this thread's 8x8 accumulators, updated in place
//   vx, vy   : thread coordinates selecting the 8 columns / 8 rows it owns
__device__ void thread_matmul_v3(
  _VOLATILE_ float aSM[8][128+4],
  _VOLATILE_ float bSM[8][128+4],
  float8 cCache[8],
  int vx, int vy
) {
  const int aCol = 8*vy;
  // vx/4 accounts for the one-slot skew per 32 columns used when bSM is filled.
  const int bCol = vx/4 + 8*vx;

  #pragma unroll
  for (int kStep=0; kStep<8; kStep++){
    // Copy this thread's 8 A values for the current k-step into registers.
    float aRow[8];
    #pragma unroll
    for (int r=0; r<8; r++){
      aRow[r] = aSM[kStep][aCol + r];
    }

    #pragma unroll
    for (int c=0; c<8; c++){
      const float bVal = bSM[kStep][bCol + c];
      #pragma unroll
      for (int r=0; r<8; r++){
        cCache[r].val[c] = fmaf(aRow[r], bVal, cCache[r].val[c]);
      }
    }
  }
}

// Write one thread's 8x8 accumulator tile to C.
//   cCache           : the 8 accumulator rows of this thread
//   C                : output, indexed as [bid][row][col] with row length N
//   gStartx, gStarty : first column / row of the block's tile
//   vx, vy           : thread coordinates; the thread owns rows
//                      gStarty + 8*vy .. +7 and columns gStartx + 8*vx .. +7
//   M, N             : output dimensions used for bounds checks
// A whole row of 8 floats is stored with a single float8 assignment only when
// all 8 columns are inside N and the destination address is a multiple of
// sizeof(float8); otherwise the row is written element by element with a
// per-column bounds check. The previous version always used the float8 store,
// which wrote past column N on partial tiles and assumed the alignment.
__device__ void write_c(
  float8 cCache[8],
  float* C,
  int gStartx, int gStarty,
  int vx, int vy, int bid,
  int M, int N
) {
  int iN_start = gStartx + vx*8;
  #pragma unroll
  for (int i=0; i<8; i++){
    int iM = gStarty + vy*8 + i;
    if (likely(iM < M)){
      float* dst = C + (bid)*M*N + (iM)*N + (iN_start);
      bool in_bounds = (iN_start + 7 < N);
      bool aligned = (reinterpret_cast<ull_t>(dst) % sizeof(float8)) == 0;
      if (likely(in_bounds && aligned)){
        reinterpret_cast<float8*>(dst)[0] = cCache[i];
      } else {
        #pragma unroll
        for (int j=0; j<8; j++){
          if (iN_start + j < N){
            dst[j] = cCache[i].val[j];
          }
        }
      }
    }
  }
}

// Write one thread's 8x8 accumulator tile to C through a shared-memory
// staging row, so that neighbouring threads store to neighbouring columns.
//   cCache           : the 8 accumulator rows of this thread
//   C                : output, indexed as [bid][row][col] with row length N
//   gStartx, gStarty : first column / row of the block's tile
//   vx, vy           : thread coordinates (caller uses tid % 16 and tid / 16)
//   M, N             : output dimensions used for bounds checks
// For each of its 8 rows a thread puts its 8 values at columns 8*vx .. 8*vx+7
// of cSM[vy], then reads back columns vx, 16+vx, ..., 112+vx, which were
// written by the other threads that share the same vy.
__device__ void write_c_v3(
  float8 cCache[8],
  float* C,
  int gStartx, int gStarty,
  int vx, int vy, int bid,
  int M, int N
) {
  __shared__ volatile float cSM[16][128];
  #pragma unroll
  for (int mi=0; mi<8; mi++){
    int iM = gStarty + vy*8 + mi;
    // Stage one row from cCache to cSM. Done by every thread, including
    // those whose row is outside M, so that the barriers below are reached
    // by the whole warp.
    #pragma unroll
    for (int ni=0; ni<8; ni++){
      cSM[vy][vx*8 + ni] = cCache[mi].val[ni];
    }
    // The values read below come from other lanes. With vy = tid / 16 the 16
    // threads sharing a cSM row belong to one warp, so a warp barrier is
    // enough to make their writes visible before the reads.
    __syncwarp();
    if (iM < M){
      // Store to C
      #pragma unroll
      for (int ni=0; ni<8; ni++){
        int iN = gStartx + 16*ni + vx;
        if (iN < N){
          float cVal = cSM[vy][16*ni + vx];
          store(C+(bid)*M*N + (iM)*N + (iN), cVal);
        }
      }
    }
    // Keep the next row's staging writes from overtaking reads of this row.
    __syncwarp();
  }
}

// Empty placeholder: the kernel body is not implemented in this file.
extern "C"
__global__ void bmm_tn(
  const float* __restrict__ A,
  const float* __restrict__ B,
  float* __restrict__ C,
  int M, int N, int K
){
}

// Empty placeholder: the kernel body is not implemented in this file.
extern "C"
__global__ void bmm_nt(
  const float* __restrict__ A,
  const float* __restrict__ B,
  float* __restrict__ C,
  int M, int N, int K
){
}

// Batched C = A * B. Each thread block produces one 128x128 tile of C for one
// batch entry; each thread accumulates an 8x8 sub-tile in registers.
//   A : indexed as [batch][m][k]
//   B : indexed as [batch][k][n]
//   C : indexed as [batch][m][n]
//   M, N, K : problem sizes
// Index layout read by this kernel:
//   blockIdx.z = batch entry
//   blockIdx.y = index of a group of PM x PN neighbouring tiles
//   blockIdx.x = position of the tile inside its group
// The thread index math (tid / 32 used as a row of an 8-row tile, tid / 16
// used as one of 16 tile rows of 8) assumes 256 threads per block.
extern "C"
__global__ void bmm_nn(
  const float* __restrict__ A,
  const float* __restrict__ B,
  float* __restrict__ C,
  int M, int N, int K
){
  int tid = threadIdx.x;     // thread idx
  int bid = blockIdx.z;      // batch idx

  // Neighboring tiles are handled by blocks of the same group (PN tiles wide,
  // PM tiles high) - presumably to increase cache hit rate.
  // There are ceil(N / (128*PN)) groups along N.
  // Blocks within a group are indexed with blockIdx.x % PN and blockIdx.x / PN
  int px = blockIdx.x % _PN_;
  int py = blockIdx.x / _PN_;
  int bDimX = (N + (128*_PN_) - 1) / (128*_PN_); 
  int bIdxX = (blockIdx.y % bDimX) * _PN_ + px;
  int bIdxY = (blockIdx.y / bDimX) * _PM_ + py;
  int gStartx = bIdxX * 128;   // starting index of block on N axis
  int gStarty = bIdxY * 128;   // starting index of block on M axis
  // Tiles that start at or beyond the matrix edge hold no output element.
  // The test is uniform for the block and sits before any barrier, so the
  // whole block can leave. (A strict '>' let tiles starting exactly at N or M
  // run the full main loop without writing anything.)
  if (gStartx >= N || gStarty >= M){
    return;
  }
  // These are used to re-arrange threads into different shapes
  // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8)
  int vx = tid % 16;
  int vy = tid / 16;
  int wx = tid % 32; // thread idx in warp
  int wy = tid / 32; // warp id
  int dx = tid % 8;
  int dy = tid / 8;

  __shared__ _VOLATILE_ float aSM1[8][128+4];
  __shared__ _VOLATILE_ float bSM1[8][128+4];
  __shared__ _VOLATILE_ float aSM2[8][128+4];
  __shared__ _VOLATILE_ float bSM2[8][128+4];
  float aBuffer1[4];
  float bBuffer1[4];
  float aBuffer2[4];
  float bBuffer2[4];

  float8 cCache[8];
  init_cCache(cCache);

  // Load the initial 16 x 128 tile of A and B: k = 0..7 goes to buffer1 and
  // k = 8..15 to buffer2. Entries beyond K are set to zero. Entries whose row
  // of A (or column of B) lies outside the matrix are left unset; they only
  // feed accumulators that write_c_v3 never stores.
  #pragma unroll
  for (int i=0; i<4; i++){
    int iM = gStarty + dy + i*32;
    int iN = gStartx + wx + i*32;
    if (likely(iM < _M_)){
      if (likely(dx < _K_)){
        aBuffer1[i] = load(A + (bid)*_M_*_K_ + (iM)*_K_ + (dx));
      } else {
        aBuffer1[i] = 0.f;
      }
      if (likely(dx+8 < _K_)){
        aBuffer2[i] = load(A + (bid)*_M_*_K_ + (iM)*_K_ + (dx+8));
      } else {
        aBuffer2[i] = 0.f;
      }
    }
    if (likely(iN < _N_)){
      if (likely(wy < _K_)){
        bBuffer1[i] = load(B + (bid)*_N_*_K_ + (wy)*_N_ + (iN));
      } else {
        bBuffer1[i] = 0.f;
      }
      if (likely(wy+8 < _K_)){
        bBuffer2[i] = load(B + (bid)*_N_*_K_ + (wy+8)*_N_ + (iN));
      } else {
        bBuffer2[i] = 0.f;
      }
    }
  }

  // Number of main loop iterations is ceil(K/16)
  int nIt = (_K_ + 16 - 1) / 16;
  #pragma unroll
  for (int itr=0; itr<nIt; itr++){
    int gStartk = itr * 16;

    // Index on K axis of A and B for the tile of the NEXT iteration
    int iKA = gStartk + 16 + dx;
    int iKB = gStartk + 16 + wy;

    #pragma unroll
    for (int i=0; i<4; i++){
      // Store buffered tiles into shared memory. Columns of B are shifted by
      // one slot per group of 32 (the "+i"); thread_matmul_v3 applies the
      // same shift (vx/4) when it reads bSM.
      aSM1[dx][dy+i*32] = aBuffer1[i];
      bSM1[wy][wx+i*32+i] = bBuffer1[i];
      aSM2[dx][dy+i*32] = aBuffer2[i];
      bSM2[wy][wx+i*32+i] = bBuffer2[i];

      // Start loading next 16*128 tile of A and B to buffer1 and buffer2.
      // Don't load anything on the last iteration.
      // The loaded values are not needed until the next iteration's stores.
      if (likely(itr < nIt - 1)){
        int iM = gStarty + i*32 + dy;
        int iN = gStartx + i*32 + wx;
        
        if (likely(iM < _M_)){
          if (likely(iKA < _K_)){
            aBuffer1[i] = load(A + (bid)*_M_*_K_ + (iM)*_K_ + (iKA));
          } else {
            aBuffer1[i] = 0.f;
          }
          if (likely(iKA+8 < _K_)){
            aBuffer2[i] = load(A + (bid)*_M_*_K_ + (iM)*_K_ + (iKA+8));
          } else {
            aBuffer2[i] = 0.f;
          }
        }

        if (likely(iN < _N_)){
          if (likely(iKB < _K_)){
            bBuffer1[i] = load(B + (bid)*_N_*_K_ + (iKB)*_N_ + (iN));
          } else {
            bBuffer1[i] = 0.f;
          }
          if (likely(iKB+8 < _K_)){
            bBuffer2[i] = load(B + (bid)*_N_*_K_ + (iKB+8)*_N_ + (iN));
          } else {
            bBuffer2[i] = 0.f;
          }
        }
      }
    }
    // Synchronize threads in order to make sure the tiles of A and B are
    // fully stored to shared memory.
    __syncthreads();

    // Each thread computes 8 x 8 matrix multiplication
    // Accumulate intermediate results in cCache
    // aSM1, bSM1, aSM2, bSM2 are consumed
    thread_matmul_v3(aSM1, bSM1, cCache, vx, vy);
    thread_matmul_v3(aSM2, bSM2, cCache, vx, vy);

    // synchronize threads to signal that shared memory is consumed.
    __syncthreads();
  }
  
  // At the end of main loop, store cCache to C
  write_c_v3(cCache, C, gStartx, gStarty, vx, vy, bid, M, N);
}

// Empty placeholder: the kernel body is not implemented in this file.
extern "C"
__global__ void bmm_tt(
  const float* __restrict__ A,
  const float* __restrict__ B,
  float* __restrict__ C,
  int M, int N, int K
){
}
"""
# Persist the kernel source; BMMCUDAv2_5.__init__ reads this file back.
with open("BMMCUDAv2_5.cu", "w") as kernel_file:
  kernel_file.write(kernel)

In [5]:
#@title BMMv2.5
import torch
import cupy as cp
import numpy as np
import math

class BMMCUDAv2_5(CustomKernel): 
  """Batched float32 matrix multiply using the kernels in BMMCUDAv2_5.cu.

  Only mode "nn" (C = A @ B) is implemented; "tt", "tn" and "nt" raise
  NotImplementedError.

  Args:
    m, n, k: optional fixed problem sizes. A size that is given is pasted into
      the kernel source as a literal and asserted against the inputs of every
      call; a size left as None is taken from the kernel's runtime arguments.
    patch_m, patch_n: how many 128x128 output tiles are grouped together
      along M and N when thread blocks are mapped to tiles.
  """

  def __init__(self, m=None, n=None, k=None, patch_m=4, patch_n=4):
    super(BMMCUDAv2_5, self).__init__()
    self.m = m
    self.n = n
    self.k = k
    self.patch_m = patch_m
    self.patch_n = patch_n
    
    with open("BMMCUDAv2_5.cu",'r') as f:
      self.kernel = f.read()
      
    # Fill in the template tokens of the kernel source.
    self.kernel = (self.kernel
      .replace("_M_", str(m) if m else "M")
      .replace("_N_", str(n) if n else "N")
      .replace("_K_", str(k) if k else "K")
      .replace("_PM_", str(self.patch_m))
      .replace("_PN_", str(self.patch_n))
    )

    compile_options = ('--maxrregcount=128', '--use_fast_math')

    def compile_kernel(name):
      # All four entry points come from the same source with the same options.
      return cp.RawKernel(
        code=self.kernel,
        name=name,
        backend='nvcc',
        options=compile_options
      )

    self._fn_tt = compile_kernel("bmm_tt")
    self._fn_nn = compile_kernel("bmm_nn")
    self._fn_tn = compile_kernel("bmm_tn")
    self._fn_nt = compile_kernel("bmm_nt")

  def _call_nn(self, A, B):
    """
      Performs C = A @ B
      A: shape = [l, m, k]
      B: shape = [l, k, n]
      returns C: shape = [l, m, n], allocated on A's device
      Raises AssertionError when shapes, devices or dtypes do not fit.
    """
    assert A.shape[0] == B.shape[0]
    assert A.shape[2] == B.shape[1]
    assert A.device.type == "cuda"
    assert B.device.type == "cuda"
    assert A.device == B.device, "A and B must be on the same device"
    # The kernels take `const float*`; a half tensor would be read as float
    # and indexed past its end, so only float32 is accepted.
    assert A.dtype == torch.float32, "A must be float32"
    assert B.dtype == torch.float32, "B must be float32"
    
    l, m, k = A.shape
    l, k, n = B.shape

    if self.m is not None: assert m == self.m
    if self.n is not None: assert n == self.n
    if self.k is not None: assert k == self.k

    # Allocate the result next to the inputs instead of always on "cuda:0".
    C = torch.zeros([l, m, n], device=A.device, dtype=A.dtype)

    threads_per_block = (256,)
    
    # grid.x walks the tiles inside one patch_m x patch_n group,
    # grid.y walks the groups, grid.z walks the batch.
    n_ = math.ceil(n/(128*self.patch_n))
    m_ = math.ceil(m/(128*self.patch_m))
    blocks_per_grid = (self.patch_n*self.patch_m, n_ * m_, l)

    self._fn_nn(
      grid=blocks_per_grid,
      block=threads_per_block,
      args=[
        A.data_ptr(),
        B.data_ptr(),
        C.data_ptr(),
        m, n, k,
      ],
      stream=self.stream
    )
    return C

  def _call_tt(self, A, B):
    """C = A.T @ B.T is not implemented for this kernel version."""
    raise NotImplementedError

  def _call_tn(self, A, B):
    """C = A.T @ B is not implemented for this kernel version."""
    raise NotImplementedError

  def _call_nt(self, A, B):
    """C = A @ B.T is not implemented for this kernel version."""
    raise NotImplementedError

  def __call__(self, A, B, mode="nn"):
    """
      Performs C = f(A) @ g(B)
      A: torch.Tensor, shape : [l, m, k] or [l, k, m]
      B: torch.Tensor, shape : [l, n, k] or [l, k, n]
      returns C: torch.Tensor, shape : [l, m, n]
      mode: str, default: "nn"
      Notes:
        f() and g() are determined by mode
        "nn" --> A @ B
        "tt" --> A.T @ B.T
        "nt" --> A @ B.T
        "tn" --> A.T @ B
        2d inputs are treated as a batch of one and a 2d result is returned.
      Raises:
        ValueError: inputs are not both 2d or both 3d, or mode is unknown.
        NotImplementedError: mode is "tt", "tn" or "nt".
    """
    assert len(A.shape) == len(B.shape)
    A = A.contiguous()
    B = B.contiguous()
    if len(A.shape) == 2 and len(B.shape) == 2:
      A2 = A[None]
      B2 = B[None]
    elif len(A.shape) == 3 and len(B.shape) == 3:
      A2 = A
      B2 = B
    else:
      raise ValueError("shape of A and B need to be 2d or 3d")

    if mode == "nn":
      C = self._call_nn(A2, B2)
    elif mode == "tt":
      C = self._call_tt(A2, B2)
    elif mode == "tn":
      C = self._call_tn(A2, B2)
    elif mode == "nt":
      C = self._call_nt(A2, B2)
    else:
      # Previously fell through and failed later with UnboundLocalError on C.
      raise ValueError(
        f"mode must be one of 'nn', 'tt', 'tn', 'nt', got {mode!r}")

    if len(A.shape) == 2 and len(B.shape) == 2:
      C = C[0]
    return C

In [6]:
#@title BMM Kernel
kernel = """
typedef long long ll_t;
typedef unsigned long long ull_t;

// Eight floats packed in a 32-byte aligned struct.
typedef struct __builtin_align__(32) {
  float s0, s1, s2, s3, s4, s5, s6, s7;
} _float8;

// Same eight floats, addressable either as one _float8 or as val[0..7].
typedef union {
  _float8 f8;
  float val[8];
} float8;

// Zero the 8x8 accumulator tile held by one thread.
__device__ void init_cCache(
  float8 cCache[8]
) {
  // One flat pass over the 64 entries; idx/8 selects the row, idx%8 the column.
#pragma unroll
  for (int idx=0; idx<64; idx++){
    cCache[idx / 8].val[idx % 8] = 0.f;
  }
}

// Copy an 8 (k) x 4 (m) slice of a shared-memory tile into a per-thread array.
//   cache : destination, indexed [k][m]
//   SM    : 8 x (128+4) shared-memory tile
//   vy    : thread row coordinate; the thread owns columns 8*vy .. 8*vy+7
//   p     : 0 or 1 in the callers, selects the lower or upper 4 columns
__device__ void SM2Cache(
  float cache[8][4],
  volatile float SM[8][128+4],
  int vy, int p
) {
  const int firstCol = 8*vy + 4*p;
#pragma unroll
  for (int kStep=0; kStep<8; kStep++){
#pragma unroll
    for (int r=0; r<4; r++){
      cache[kStep][r] = SM[kStep][firstCol + r];
    }
  }
}

// Accumulate a 4x8 outer-product update per k-step for k = 0..7 into rows
// 4*p .. 4*p+3 of the thread's accumulator tile.
//   aCache : A values for this thread, indexed [k][m] (see SM2Cache)
//   bSM    : 8 x (128+4) shared-memory tile holding the B operand
//   cCache : this thread's 8x8 accumulators, updated in place
//   vx     : thread column coordinate; the thread owns columns 8*vx .. 8*vx+7
//   p      : which half of the accumulator rows to update
__device__ void thread_matmul(
  float aCache[8][4],
  volatile float bSM[8][128+4],
  float8 cCache[8],
  int vx, int p
) {
  // vx/4 accounts for the one-slot skew per 32 columns used when bSM is filled.
  const int bCol = vx/4 + 8*vx;
  const int rowBase = 4*p;
#pragma unroll
  for (int kStep=0; kStep<8; kStep++){
#pragma unroll
    for (int c=0; c<8; c++){
      const float bVal = bSM[kStep][bCol + c];
#pragma unroll
      for (int r=0; r<4; r++){
        cCache[rowBase + r].val[c] = fmaf(aCache[kStep][r], bVal, cCache[rowBase + r].val[c]);
      }
    }
  }
}

// Write one thread's 8x8 accumulator tile to C.
//   cCache           : the 8 accumulator rows of this thread
//   C                : output, indexed as [bid][row][col] with row length N
//   gStartx, gStarty : first column / row of the block's tile
//   vx, vy           : thread coordinates; the thread owns rows
//                      gStarty + 8*vy .. +7 and columns gStartx + 8*vx .. +7
//   M, N             : output dimensions used for bounds checks
// A whole row of 8 floats is stored with a single float8 assignment only when
// all 8 columns are inside N and the destination address is a multiple of
// sizeof(float8); otherwise the row is written element by element with a
// per-column bounds check. The previous version always used the float8 store,
// which wrote past column N on partial tiles and assumed the alignment.
__device__ void write_c(
  float8 cCache[8],
  float* C,
  int gStartx, int gStarty,
  int vx, int vy, int bid,
  int M, int N
) {
  int iN_start = gStartx + vx*8;
#pragma unroll
  for (int i=0; i<8; i++){
    int iM = gStarty + vy*8 + i;
    if (iM < M){
      float* dst = C + (bid)*M*N + (iM)*N + (iN_start);
      bool in_bounds = (iN_start + 7 < N);
      bool aligned = (reinterpret_cast<ull_t>(dst) % sizeof(float8)) == 0;
      if (in_bounds && aligned){
        reinterpret_cast<float8*>(dst)[0] = cCache[i];
      } else {
#pragma unroll
        for (int j=0; j<8; j++){
          if (iN_start + j < N){
            dst[j] = cCache[i].val[j];
          }
        }
      }
    }
  }
}

// Batched C = A^T * B. Each thread block produces one 128x128 tile of C for
// one batch entry; each thread accumulates an 8x8 sub-tile in cCache.
//   A : indexed as [batch][k][m]
//   B : indexed as [batch][k][n]
//   C : written by write_c as [batch][m][n]
// Index layout read by this kernel: blockIdx.x = batch entry, blockIdx.y =
// tile index along N, blockIdx.z = tile index along M. The thread index math
// (tid / 32 used as a row of the 8-row shared tiles) assumes 256 threads per
// block.
// The K axis is consumed 8 rows per iteration. Register buffers 1 and 2
// alternate: while one is copied to shared memory and multiplied, the other
// receives the rows of the next iteration.
extern "C"
__global__ void bmm_tn(
  const float* __restrict__ A,
  const float* __restrict__ B,
  float* __restrict__ C,
  int M, int N, int K
){
  int tid = threadIdx.x;
  int bid = blockIdx.x;
  int gStartx = blockIdx.y * 128;
  int gStarty = blockIdx.z * 128;

  int vx = tid % 16;
  int vy = tid / 16;
  int wx = tid % 32; // thread idx in warp
  int wy = tid / 32; // warp id
  // dx, dy are not used by this kernel.
  int dx = tid % 8;
  int dy = tid / 8;

  __shared__ volatile float aSM[8][128+4];
  __shared__ volatile float bSM[8][128+4];
  float aBuffer1[4];
  float bBuffer1[4];
  float aBuffer2[4];
  float bBuffer2[4];

  float8 cCache[8];
  init_cCache(cCache);

  // Number of main loop iterations: ceil(K / 8).
  int nIt = (_K_ + 8 - 1) / 8;
  // init_value is not used by this kernel.
  float init_value = 0.f;
  // Load the first 8 x 128 slab of A and B into buffer1; rows beyond K are
  // set to zero.
#pragma unroll
  for (int i=0; i<4; i++){

    int iM = gStarty + wx + i*32;
    int iN = gStartx + wx + i*32;
    if (wy < _K_){
      if (iM < _M_)
        aBuffer1[i] = A[(bid)*_M_*_K_ + (wy)*_M_ + (iM)];
      if (iN < _N_)
        bBuffer1[i] = B[(bid)*_N_*_K_ + (wy)*_N_ + (gStartx + wx + i*32)];
    } else {
      aBuffer1[i] = 0.f;
      bBuffer1[i] = 0.f;
    }
  }
#pragma unroll
  for (int itr=0; itr<nIt; itr++){
    
    int gStartk = itr * 8;
    // K index of the row this thread fetches for the NEXT iteration.
    int iK = gStartk + 8 + wy;
    int is_odd = itr & 1;
    if (is_odd == 0){
      // Even iteration: prefetch into buffer2, publish buffer1.
#pragma unroll
      for (int i=0; i<4; i++){
        if (itr < nIt - 1){
          int iM = gStarty + i*32 + wx;
          int iN = gStartx + i*32 + wx;
          
          if (iK < _K_){
            if (iM < _M_)
              aBuffer2[i] = A[(bid)*_M_*_K_ + (iK)*_M_ + (iM)];
            if (iN < _N_)
              bBuffer2[i] = B[(bid)*_N_*_K_ + (iK)*_N_ + (iN)];
          } else {
            aBuffer2[i] = 0.f;
            bBuffer2[i] = 0.f;
          }
        }
        aSM[wy][wx+i*32] = aBuffer1[i];
        // Columns of B are shifted by one slot per group of 32 (the "+i");
        // thread_matmul applies the same shift (vx/4) when it reads bSM.
        bSM[wy][wx+i*32+i] = bBuffer1[i];
      }
    } else {
      // Odd iteration: prefetch into buffer1, publish buffer2.
#pragma unroll
      for (int i=0; i<4; i++){
        if (itr < nIt - 1){
          int iM = gStarty + i*32 + wx;
          int iN = gStartx + i*32 + wx;
          if (iK < _K_){
            if (iM < _M_)
              aBuffer1[i] = A[(bid)*_M_*_K_ + (iK)*_M_ + (iM)];
            if (iN < N)
              bBuffer1[i] = B[(bid)*_N_*_K_ + (iK)*_N_ + (iN)];
          } else {
            aBuffer1[i] = 0.f;
            bBuffer1[i] = 0.f;
          }
        }
        aSM[wy][wx+i*32] = aBuffer2[i];
        bSM[wy][wx+i*32+i] = bBuffer2[i];
      }
    }
    // Shared tiles must be completely written before any thread reads them.
    __syncthreads();

    float aCache[8][4];

    // The 8 accumulator rows are updated in two halves of 4 rows each.
#pragma unroll
    for (int p=0; p<2; p++){
      SM2Cache(aCache, aSM, vy, p);
      // thread_matmul(aCache, bSM, cCache, vx, p);
      thread_matmul(aCache, bSM, cCache, vx, p);
    }
    // Shared tiles must be completely read before the next iteration
    // overwrites them.
    __syncthreads();
  }

  write_c(cCache, C, gStartx, gStarty, vx, vy, bid, M, N);
}

// Batched C = A * B^T. Each thread block produces one 128x128 tile of C for
// one batch entry; each thread accumulates an 8x8 sub-tile in cCache.
//   A : indexed as [batch][m][k]
//   B : indexed as [batch][n][k]
//   C : written by write_c as [batch][m][n]
// Index layout read by this kernel: blockIdx.x = batch entry, blockIdx.y =
// tile index along N, blockIdx.z = tile index along M. The thread index math
// (tid / 8 plus up to 96 used as a column of the 128-wide tiles) assumes 256
// threads per block.
// The K axis is consumed 8 columns per iteration. Register buffers 1 and 2
// alternate: while one is copied to shared memory and multiplied, the other
// receives the values of the next iteration.
extern "C"
__global__ void bmm_nt(
  const float* __restrict__ A,
  const float* __restrict__ B,
  float* __restrict__ C,
  int M, int N, int K
){
  int tid = threadIdx.x;
  int bid = blockIdx.x;
  int gStartx = blockIdx.y * 128;
  int gStarty = blockIdx.z * 128;

  int vx = tid % 16;
  int vy = tid / 16;
  // wx, wy are not used by this kernel.
  int wx = tid % 32; // thread idx in warp
  int wy = tid / 32; // warp id
  int dx = tid % 8;
  int dy = tid / 8;

  __shared__ volatile float aSM[8][128+4];
  __shared__ volatile float bSM[8][128+4];
  float aBuffer1[4];
  float bBuffer1[4];
  float aBuffer2[4];
  float bBuffer2[4];

  float8 cCache[8];
  init_cCache(cCache);

  // Number of main loop iterations: ceil(K / 8).
  int nIt = (_K_ + 8 - 1) / 8;
  // init_value is not used by this kernel.
  float init_value = 0.f;
  // Load the first 128 x 8 slab of A and B into buffer1; columns beyond K are
  // set to zero.
#pragma unroll
  for (int i=0; i<4; i++){

    int iM = gStarty + dy + i*32;
    int iN = gStartx + dy + i*32;
    if (dx < _K_){
      if (iM < _M_)
        aBuffer1[i] = A[(bid)*_M_*_K_ + (iM)*_K_ + (dx)];
      if (iN < _N_)
        bBuffer1[i] = B[(bid)*_N_*_K_ + (iN)*_K_ + (dx)];
    } else {
      aBuffer1[i] = 0.f;
      bBuffer1[i] = 0.f;
    }
  }
#pragma unroll
  for (int itr=0; itr<nIt; itr++){
    
    int gStartk = itr * 8;
    // K index of the column this thread fetches for the NEXT iteration.
    int iK = gStartk + 8 + dx;
    int is_odd = itr & 1;
    if (is_odd == 0){
      // Even iteration: prefetch into buffer2, publish buffer1.
#pragma unroll
      for (int i=0; i<4; i++){
        if (itr < nIt - 1){
          int iM = gStarty + i*32 + dy;
          int iN = gStartx + i*32 + dy;
          
          if (iK < _K_){
            if (iM < _M_)
              aBuffer2[i] = A[(bid)*_M_*_K_ + (iM)*_K_ + (iK)];
            if (iN < _N_)
              bBuffer2[i] = B[(bid)*_N_*_K_ + (iN)*_K_ + (iK)];
          } else {
            aBuffer2[i] = 0.f;
            bBuffer2[i] = 0.f;
          }
        }
        aSM[dx][dy+i*32] = aBuffer1[i];
        // Columns of B are shifted by one slot per group of 32 (the "+i");
        // thread_matmul applies the same shift (vx/4) when it reads bSM.
        bSM[dx][dy+i*32+i] = bBuffer1[i];
      }
    } else {
      // Odd iteration: prefetch into buffer1, publish buffer2.
#pragma unroll
      for (int i=0; i<4; i++){
        if (itr < nIt - 1){
          int iM = gStarty + i*32 + dy;
          int iN = gStartx + i*32 + dy;
          if (iK < _K_){
            if (iM < _M_)
              aBuffer1[i] = A[(bid)*_M_*_K_ + (iM)*_K_ + (iK)];
            if (iN < N)
              bBuffer1[i] = B[(bid)*_N_*_K_ + (iN)*_K_ + (iK)];
          } else {
            aBuffer1[i] = 0.f;
            bBuffer1[i] = 0.f;
          }
        }
        aSM[dx][dy+i*32] = aBuffer2[i];
        bSM[dx][dy+i*32+i] = bBuffer2[i];
      }
    }
    // Shared tiles must be completely written before any thread reads them.
    __syncthreads();

    float aCache[8][4];

    // The 8 accumulator rows are updated in two halves of 4 rows each.
#pragma unroll
    for (int p=0; p<2; p++){
      SM2Cache(aCache, aSM, vy, p);
      thread_matmul(aCache, bSM, cCache, vx, p);
    }
    // Shared tiles must be completely read before the next iteration
    // overwrites them.
    __syncthreads();
  }

  write_c(cCache, C, gStartx, gStarty, vx, vy, bid, M, N);
}

// Batched C = A * B. Each thread block produces one 128x128 tile of C for one
// batch entry; each thread accumulates an 8x8 sub-tile in cCache.
//   A : indexed as [batch][m][k]
//   B : indexed as [batch][k][n]
//   C : written by write_c as [batch][m][n]
// Index layout read by this kernel: blockIdx.x = batch entry, blockIdx.y =
// tile index along N, blockIdx.z = tile index along M. The thread index math
// (tid / 32 used as a row of the 8-row shared tiles) assumes 256 threads per
// block.
// The K axis is consumed 8 entries per iteration. Register buffers 1 and 2
// alternate: while one is copied to shared memory and multiplied, the other
// receives the values of the next iteration.
extern "C"
__global__ void bmm_nn(
  const float* __restrict__ A,
  const float* __restrict__ B,
  float* __restrict__ C,
  int M, int N, int K
){
  int tid = threadIdx.x;
  int bid = blockIdx.x;
  int gStartx = blockIdx.y * 128;
  int gStarty = blockIdx.z * 128;

  int vx = tid % 16;
  int vy = tid / 16;
  int wx = tid % 32; // thread idx in warp
  int wy = tid / 32; // warp id
  int dx = tid % 8;
  int dy = tid / 8;

  __shared__ volatile float aSM[8][128+4];
  __shared__ volatile float bSM[8][128+4];
  float aBuffer1[4];
  float bBuffer1[4];
  float aBuffer2[4];
  float bBuffer2[4];

  float8 cCache[8];
  init_cCache(cCache);

  // Number of main loop iterations: ceil(K / 8).
  int nIt = (_K_ + 8 - 1) / 8;
  // init_value is not used by this kernel.
  float init_value = 0.f;
  // Load the first slab of A (indexed with dx, dy) and of B (indexed with
  // wx, wy) into buffer1; entries beyond K are set to zero.
#pragma unroll
  for (int i=0; i<4; i++){

    int iM = gStarty + dy + i*32;
    int iN = gStartx + wx + i*32;
    if (iM < _M_){
      if (dx < _K_){
        aBuffer1[i] = A[(bid)*_M_*_K_ + (iM)*_K_ + (dx)];
      } else {
        aBuffer1[i] = 0.f;
      }
    }
    if (iN < N){
      if (wy < _K_){
        bBuffer1[i] = B[(bid)*_N_*_K_ + (wy)*_N_ + (iN)];
      } else {
        bBuffer1[i] = 0.f;
      }
    }

  }
#pragma unroll
  for (int itr=0; itr<nIt; itr++){
    
    int gStartk = itr * 8;
    // K indices this thread fetches from A and B for the NEXT iteration.
    int iKA = gStartk + 8 + dx;
    int iKB = gStartk + 8 + wy;
    int is_odd = itr & 1;
    if (is_odd == 0){
      // Even iteration: prefetch into buffer2, publish buffer1.
#pragma unroll
      for (int i=0; i<4; i++){
        if (itr < nIt - 1){
          int iM = gStarty + i*32 + dy;
          int iN = gStartx + i*32 + wx;
          
          if (iKA < _K_){
            if (iM < _M_){
              aBuffer2[i] = A[(bid)*_M_*_K_ + (iM)*_K_ + (iKA)];
            }
          } else {
            aBuffer2[i] = 0.f;
          }

          if (iKB < _K_){
            if (iN < _N_){
              bBuffer2[i] = B[(bid)*_N_*_K_ + (iKB)*_N_ + (iN)];
            }
          } else {
            bBuffer2[i] = 0.f;
          }
        }
        aSM[dx][dy+i*32] = aBuffer1[i];
        // Columns of B are shifted by one slot per group of 32 (the "+i");
        // thread_matmul applies the same shift (vx/4) when it reads bSM.
        bSM[wy][wx+i*32+i] = bBuffer1[i];
      }
    } else {
      // Odd iteration: prefetch into buffer1, publish buffer2.
#pragma unroll
      for (int i=0; i<4; i++){
        if (itr < nIt - 1){
          int iM = gStarty + i*32 + dy;
          int iN = gStartx + i*32 + wx;

          if (iKA < _K_){
            if (iM < _M_){
              aBuffer1[i] = A[(bid)*_M_*_K_ + (iM)*_K_ + (iKA)];
            }
          } else {
            aBuffer1[i] = 0.f;
          }
          

          if (iKB < _K_){
            if (iN < _N_){
              bBuffer1[i] = B[(bid)*_N_*_K_ + (iKB)*_N_ + (iN)];
            }
          } else {
            bBuffer1[i] = 0.f;
          }
        }
        aSM[dx][dy+i*32] = aBuffer2[i];
        bSM[wy][wx+i*32+i] = bBuffer2[i];
      }
    }
    // Shared tiles must be completely written before any thread reads them.
    __syncthreads();

    float aCache[8][4];

    // The 8 accumulator rows are updated in two halves of 4 rows each.
#pragma unroll
    for (int p=0; p<2; p++){
      SM2Cache(aCache, aSM, vy, p);
      thread_matmul(aCache, bSM,cCache, vx, p);
    }
    // Shared tiles must be completely read before the next iteration
    // overwrites them.
    __syncthreads();
  }

  write_c(cCache, C, gStartx, gStarty, vx, vy, bid, M, N);
}

// Batched C = A^T * B^T. Each thread block produces one 128x128 tile of C for
// one batch entry; each thread accumulates an 8x8 sub-tile in cCache.
//   A : indexed as [batch][k][m]
//   B : indexed as [batch][n][k]
//   C : written by write_c as [batch][m][n]
// Index layout read by this kernel: blockIdx.x = batch entry, blockIdx.y =
// tile index along N, blockIdx.z = tile index along M. The thread index math
// (tid / 32 used as a row of the 8-row shared tiles) assumes 256 threads per
// block.
// The K axis is consumed 8 entries per iteration. Register buffers 1 and 2
// alternate: while one is copied to shared memory and multiplied, the other
// receives the values of the next iteration.
extern "C"
__global__ void bmm_tt(
  const float* __restrict__ A,
  const float* __restrict__ B,
  float* __restrict__ C,
  int M, int N, int K
){
  int tid = threadIdx.x;
  int bid = blockIdx.x;
  int gStartx = blockIdx.y * 128;
  int gStarty = blockIdx.z * 128;

  int vx = tid % 16;
  int vy = tid / 16;
  int wx = tid % 32; // thread idx in warp
  int wy = tid / 32; // warp id
  int dx = tid % 8;
  int dy = tid / 8;

  __shared__ volatile float aSM[8][128+4];
  __shared__ volatile float bSM[8][128+4];
  float aBuffer1[4];
  float bBuffer1[4];
  float aBuffer2[4];
  float bBuffer2[4];

  float8 cCache[8];
  init_cCache(cCache);

  // Number of main loop iterations: ceil(K / 8).
  int nIt = (_K_ + 8 - 1) / 8;
  // init_value is not used by this kernel.
  float init_value = 0.f;
  // Load the first slab of A (indexed with wx, wy) and of B (indexed with
  // dx, dy) into buffer1; entries beyond K are set to zero.
#pragma unroll
  for (int i=0; i<4; i++){

    int iM = gStarty + wx + i*32;
    int iN = gStartx + dy + i*32;
    if (iM < _M_){
      if (wy < _K_){
        aBuffer1[i] = A[(bid)*_M_*_K_ + (wy)*_M_ + (iM)];
      } else {
        aBuffer1[i] = 0.f;
      }
    }
    if (iN < _N_){
      if (dx < _K_){
        bBuffer1[i] = B[(bid)*_N_*_K_ + (iN)*_K_ + (dx)];
      } else {
        bBuffer1[i] = 0.f;
      }
    }
  }
#pragma unroll
  for (int itr=0; itr<nIt; itr++){
    
    int gStartk = itr * 8;
    // K indices this thread fetches from A and B for the NEXT iteration.
    int iKA = gStartk + 8 + wy;
    int iKB = gStartk + 8 + dx;
    int is_odd = itr & 1;
    if (is_odd == 0){
      // Even iteration: prefetch into buffer2, publish buffer1.
#pragma unroll
      for (int i=0; i<4; i++){
        if (itr < nIt - 1){
          int iM = gStarty + i*32 + wx;
          int iN = gStartx + i*32 + dy;
          
          if (iKA < _K_){
            if (iM < _M_){
              aBuffer2[i] = A[(bid)*_M_*_K_ + (iKA)*_M_ + (iM)];
            }
          } else {
            aBuffer2[i] = 0.f;
          }

          if (iKB < _K_){
            if (iN < _N_){
              bBuffer2[i] = B[(bid)*_N_*_K_ + (iN)*_K_ + (iKB)];
            }
          } else {
            bBuffer2[i] = 0.f;
          }
        }
        aSM[wy][wx+i*32] = aBuffer1[i];
        // Columns of B are shifted by one slot per group of 32 (the "+i");
        // thread_matmul applies the same shift (vx/4) when it reads bSM.
        bSM[dx][dy+i*32+i] = bBuffer1[i];
      }
    } else {
      // Odd iteration: prefetch into buffer1, publish buffer2.
#pragma unroll
      for (int i=0; i<4; i++){
        if (itr < nIt - 1){
          int iM = gStarty + i*32 + wx;
          int iN = gStartx + i*32 + dy;
          if (iKA < _K_){
            if (iM < _M_){
              aBuffer1[i] = A[(bid)*_M_*_K_ + (iKA)*_M_ + (iM)];
            }
          } else {
            aBuffer1[i] = 0.f;
          }

          if (iKB < _K_){
            if (iN < _N_){
              bBuffer1[i] = B[(bid)*_N_*_K_ + (iN)*_K_ + (iKB)];
            }
          } else {
            bBuffer1[i] = 0.f;
          }
        }
        aSM[wy][wx+i*32] = aBuffer2[i];
        bSM[dx][dy+i*32+i] = bBuffer2[i];
      }
    }
    // Shared tiles must be completely written before any thread reads them.
    __syncthreads();

    float aCache[8][4];

    // The 8 accumulator rows are updated in two halves of 4 rows each.
#pragma unroll
    for (int p=0; p<2; p++){
      SM2Cache(aCache, aSM, vy, p);
      thread_matmul(aCache, bSM, cCache, vx, p);
    }
    // Shared tiles must be completely read before the next iteration
    // overwrites them.
    __syncthreads();
  }
  write_c(cCache, C, gStartx, gStarty, vx, vy, bid, M, N);
}
"""
# Persist the kernel source; BMMCUDA.__init__ reads this file back.
with open("BMMCUDA.cu", "w") as kernel_file:
  kernel_file.write(kernel)

In [7]:
#@title BMM
import torch
import cupy as cp
import numpy as np
import math

class BMMCUDA(CustomKernel): 
  """Batched float32 matrix multiply using the kernels in BMMCUDA.cu.

  Supports the four transpose combinations "nn", "tt", "tn" and "nt".

  Args:
    m, n, k: optional fixed problem sizes. A size that is given is pasted into
      the kernel source as a literal and asserted against the inputs of every
      call; a size left as None is taken from the kernel's runtime arguments.
  """

  def __init__(self, m=None, n=None, k=None):
    super(BMMCUDA, self).__init__()
    self.m = m
    self.n = n
    self.k = k
    with open("BMMCUDA.cu",'r') as f:
      self.kernel = f.read()
      
    # Fill in the template tokens of the kernel source.
    self.kernel = (self.kernel
      .replace("_M_", str(m) if m else "M")
      .replace("_N_", str(n) if n else "N")
      .replace("_K_", str(k) if k else "K")
    )

    compile_options = ('--maxrregcount=128', '--use_fast_math')

    def compile_kernel(name):
      # All four entry points come from the same source with the same options.
      return cp.RawKernel(
        code=self.kernel,
        name=name,
        backend='nvcc',
        options=compile_options
      )

    self._fn_tt = compile_kernel("bmm_tt")
    self._fn_nn = compile_kernel("bmm_nn")
    self._fn_tn = compile_kernel("bmm_tn")
    self._fn_nt = compile_kernel("bmm_nt")

  def _launch(self, fn, A, B, m, n, k):
    """Check the operands, allocate C and run one kernel.

    Args:
      fn: compiled kernel to launch.
      A, B: 3d CUDA tensors whose leading dimension is the batch size.
      m, n, k: problem sizes as unpacked by the calling `_call_*` method.

    Returns:
      C with shape [l, m, n], allocated on A's device.

    Raises:
      AssertionError: operands are not float32 CUDA tensors on one device, or
        a size differs from the one fixed at construction time.
    """
    assert A.device.type == "cuda"
    assert B.device.type == "cuda"
    assert A.device == B.device, "A and B must be on the same device"
    # The kernels take `const float*`; a half tensor would be read as float
    # and indexed past its end, so only float32 is accepted.
    assert A.dtype == torch.float32, "A must be float32"
    assert B.dtype == torch.float32, "B must be float32"

    if self.m is not None: assert m == self.m
    if self.n is not None: assert n == self.n
    if self.k is not None: assert k == self.k

    l = A.shape[0]
    # Allocate the result next to the inputs instead of always on "cuda:0".
    C = torch.zeros([l, m, n], device=A.device, dtype=A.dtype)

    # One 256-thread block per 128x128 output tile; grid.x walks the batch.
    threads_per_block = (256,)
    blocks_per_grid = (l, math.ceil(n/128), math.ceil(m/128))

    fn(
      grid=blocks_per_grid,
      block=threads_per_block,
      args=[
        A.data_ptr(),
        B.data_ptr(),
        C.data_ptr(),
        m, n, k
      ],
      stream=self.stream
    )
    return C

  def _call_nn(self, A, B):
    """
      Performs C = A @ B
      A: shape = [l, m, k]
      B: shape = [l, k, n]
      returns C: shape = [l, m, n]
    """
    assert A.shape[0] == B.shape[0]
    assert A.shape[2] == B.shape[1]
    l, m, k = A.shape
    l, k, n = B.shape
    return self._launch(self._fn_nn, A, B, m, n, k)

  def _call_tt(self, A, B):
    """
      Performs C = A.t @ B.t
      A: shape = [l, k, m]
      B: shape = [l, n, k]
      returns C: shape = [l, m, n]
    """
    assert A.shape[0] == B.shape[0]
    assert A.shape[1] == B.shape[2]
    l, k, m = A.shape
    l, n, k = B.shape
    return self._launch(self._fn_tt, A, B, m, n, k)

  def _call_tn(self, A, B):
    """
      Performs C = A.t @ B
      A: shape = [l, k, m]
      B: shape = [l, k, n]
      returns C: shape = [l, m, n]
    """
    assert A.shape[0] == B.shape[0]
    assert A.shape[1] == B.shape[1]
    l, k, m = A.shape
    l, k, n = B.shape
    return self._launch(self._fn_tn, A, B, m, n, k)

  def _call_nt(self, A, B):
    """
      Performs C = A @ B.t
      A: shape = [l, m, k]
      B: shape = [l, n, k]
      returns C: shape = [l, m, n]
    """
    assert A.shape[0] == B.shape[0]
    assert A.shape[2] == B.shape[2]
    l, m, k = A.shape
    l, n, k = B.shape
    return self._launch(self._fn_nt, A, B, m, n, k)

  def __call__(self, A, B, mode="nn"):
    """
      Performs C = f(A) @ g(B)
      A: torch.Tensor, shape : [l, m, k] or [l, k, m]
      B: torch.Tensor, shape : [l, n, k] or [l, k, n]
      returns C: torch.Tensor, shape : [l, m, n]
      mode: str, default: "nn"
      Notes:
        f() and g() are determined by mode
        "nn" --> A @ B
        "tt" --> A.T @ B.T
        "nt" --> A @ B.T
        "tn" --> A.T @ B
        2d inputs are treated as a batch of one and a 2d result is returned.
      Raises:
        ValueError: inputs are not both 2d or both 3d, or mode is unknown.
    """
    assert len(A.shape) == len(B.shape)
    A = A.contiguous()
    B = B.contiguous()
    if len(A.shape) == 2 and len(B.shape) == 2:
      A2 = A[None]
      B2 = B[None]
    elif len(A.shape) == 3 and len(B.shape) == 3:
      A2 = A
      B2 = B
    else:
      raise ValueError("shape of A and B need to be 2d or 3d")

    if mode == "nn":
      C = self._call_nn(A2, B2)
    elif mode == "tt":
      C = self._call_tt(A2, B2)
    elif mode == "tn":
      C = self._call_tn(A2, B2)
    elif mode == "nt":
      C = self._call_nt(A2, B2)
    else:
      # Previously fell through and failed later with UnboundLocalError on C.
      raise ValueError(
        f"mode must be one of 'nn', 'tt', 'tn', 'nt', got {mode!r}")

    if len(A.shape) == 2 and len(B.shape) == 2:
      C = C[0]
    return C

In [None]:
#@title test BMMv2_5
def test_bmm_v2_5(l, m, n, k, mode="nn", n_iter=1, verbose=0):
  """
    Benchmarks BMMCUDAv2_5 against torch.bmm on random inputs.
    l, m, n, k: batch size and matrix dimensions, the product is [l, m, n]
    mode: two characters ("n" or "t"); the first one selects the layout
      of A, the second one the layout of B
    n_iter: number of warmup iterations and number of timed iterations
    verbose:
      0 --> only the problem size is printed
      1 --> timings, tflops and the error against torch.bmm are printed
      2 --> the error mask and the custom result are also shown as images
    returns (time per call of the custom kernel, time per call of torch.bmm),
      both in seconds
  """
  print(f"l={l}  m={m}  n={n}  k={k}")

  # random operands, stored in the layout requested by mode
  if mode[0] == "n":
    lhs = torch.randn(l, m, k, device="cuda:0")
  elif mode[0] == "t":
    lhs = torch.randn(l, k, m, device="cuda:0")

  if mode[1] == "n":
    rhs = torch.randn(l, k, n, device="cuda:0")
  elif mode[1] == "t":
    rhs = torch.randn(l, n, k, device="cuda:0")

  kernel = BMMCUDAv2_5(patch_m=4, patch_n=4)
  flop = l * m * n * k * 2

  # torch.bmm has no mode argument, so it gets transposed views instead
  lhs_view = lhs.transpose(1, 2) if mode[0] == "t" else lhs
  rhs_view = rhs.transpose(1, 2) if mode[1] == "t" else rhs

  def run_timed(fn):
    # n_iter warmup calls, then n_iter timed calls; the device is
    # synchronized after every call in both phases
    for _ in range(n_iter):
      fn()
      torch.cuda.synchronize()
    start = time()
    for _ in range(n_iter):
      out = fn()
      torch.cuda.synchronize()
    return (time() - start) / n_iter, out

  # reference: torch.bmm
  torch_sec, ref_out = run_timed(lambda: torch.bmm(lhs_view, rhs_view))
  torch_tflops = (flop / torch_sec) / 1000**4
  if verbose > 0:
    print("time spent for torch.bmm:", torch_sec)
    print("tflops:", torch_tflops)
  else:
    # the result is only needed for the error report
    del ref_out

  # custom kernel
  custom_sec, custom_out = run_timed(lambda: kernel(lhs, rhs, mode=mode))
  custom_tflops = (flop / custom_sec) / 1000**4
  if verbose > 0:
    print("time spent for custom_bmm_v2_5:", custom_sec)
    print("tflops:", custom_tflops)
  else:
    del custom_out

  if verbose > 0:
    abs_err = (custom_out - ref_out).abs()
    print("Max Error", abs_err.max())
    print("Error:", abs_err.sum())
    print("ratio:", custom_sec / torch_sec)

  if verbose > 1:
    # mask of the entries of the first batch element with error below 1e-4
    plt.imshow((abs_err < 1e-4)[0].cpu())
    plt.show()
    plt.imshow(custom_out[0].cpu())
    plt.show()

  return custom_sec, torch_sec
  
# Single verbose benchmark run; the returned timings are bound to `_`
# so the cell does not echo them.
_ = test_bmm_v2_5(
    l=1,
    m=1024 * 16,
    n=1024,
    k=256,
    mode="nn",
    n_iter=100,
    verbose=1,
)

In [11]:
import os
# Output folder for the figures saved by the plotting cell below.
# exist_ok=True keeps the cell safe to re-run and avoids the
# check-then-create race of an exists()/mkdir() pair.
os.makedirs("imgs", exist_ok=True)

In [None]:
#@title Grid test BMM
# Benchmark grid: the batch size varies, the matrix sizes are fixed.
ls = [128 * i for i in (1, 2)]
ms = [512]
ns = ms
ks = [64]
mode = "nn"

# batch size --> time per call in milliseconds
custom_res = {}
cublass_res = {}
for l in ls:
  for m in ms:
    for k in ks:
      # square output (n == m); test_bmm_v2_5 returns seconds
      res = test_bmm_v2_5(l, m, m, k, mode=mode, n_iter=50)
      custom_sec, torch_sec = res[0], res[1]
      custom_res[l] = custom_sec * 1e3
      cublass_res[l] = torch_sec * 1e3


# Plot both timing curves over the batch size.
plt.figure(figsize=(15, 10))
plt.tight_layout()
plt.xlabel("X", fontsize=17)
plt.ylabel("milliseconds", fontsize=17)
# m and k still hold the values of the last loop iteration above
title = f"A[X,{m},{k}] B[X,{k},{m}]"
plt.title(title)
plt.rcParams["font.size"] = "17"
plt.grid()
colors = ["red", "blue"]
labels = ["custom_bmm", "torch.bmm"]
for i, res in enumerate((custom_res, cublass_res)):
  res_x, res_y = list(res.keys()), list(res.values())
  plt.plot(res_x, res_y, color=colors[i], label=labels[i])
plt.legend()
plt.savefig("imgs/mbmm_" + title)
plt.show()