In [96]:
from Orch import Bijection, GemmTopology, SliceIR_node, SliceIR_tree, generate_dot_source, GemmNode
import numpy as np
import math 
from typing import Tuple, List, Union, Optional, Dict
from collections import defaultdict


In [118]:
def calculate_sM_count(C_shape, wmma_atom_shape, max_sm_count=170, max_tile_multiple=255):
  """Pick the (bM, bN) output tile that yields the largest usable tile count.

  Candidate tiles are integer multiples of the WMMA atom's M and N extents
  (``bM = wmma_M*kM``, ``bN = wmma_N*kN``). Only tiles that divide the output
  matrix exactly are considered. Among those whose tile count
  ``(gM//bM)*(gN//bN)`` does not exceed ``max_sm_count``, the candidates with
  the largest tile count are kept; the winner is the one with the smallest
  ``|bM - bN|``, ties broken in favour of the larger ``bN``.

  Args:
    C_shape: ``(gM, gN)`` extents of the output matrix.
    wmma_atom_shape: ``(wmma_M, wmma_K, wmma_N)``; the K extent is not used
      by this search.
    max_sm_count: upper bound on the tile count (defaults to the 170 that was
      previously hard-coded).
    max_tile_multiple: largest atom multiple tried per dimension (defaults to
      255, matching the previous ``range(1, 256)``).

  Returns:
    The selected ``(bM, bN)`` tuple.

  Raises:
    ValueError: if no candidate tile divides ``C_shape`` within the limits.
  """
  wmma_M, _wmma_K, wmma_N = wmma_atom_shape
  gM, gN = C_shape

  best_sm_count = 0
  best_schedules = []
  for kM in range(1, max_tile_multiple + 1):
    bM = wmma_M*kM
    if bM > gM:
      # Larger multiples can never divide gM exactly.
      break
    if gM % bM:
      continue
    for kN in range(1, max_tile_multiple + 1):
      bN = wmma_N*kN
      if bN > gN:
        break
      if gN % bN:
        continue
      # Exact integer arithmetic: avoids float keys such as 128.0 and any
      # rounding surprises from comparing a float product to an int product.
      sm_count = (gM//bM)*(gN//bN)
      if sm_count > max_sm_count:
        continue
      if sm_count > best_sm_count:
        best_sm_count = sm_count
        best_schedules = [(bM, bN)]
      elif sm_count == best_sm_count:
        best_schedules.append((bM, bN))

  if not best_schedules:
    raise ValueError(
      f"no tile built from WMMA atom {tuple(wmma_atom_shape)} evenly divides "
      f"C_shape {tuple(C_shape)} with at most {max_sm_count} tiles"
    )

  # Most square first, then widest N.
  top_schedule = min(best_schedules, key=lambda tile: (abs(tile[0]-tile[1]), -tile[1]))

  print(f"Best SM Count: {best_sm_count}")
  print("Top Schedule (Most Square + Widest N):", top_schedule)

  return top_schedule

def make_wmma_atom(wmma_m, wmma_k, wmma_n):
  """Build the GemmTopology for a single WMMA atom.

  Only the atom shapes (16,16,8) and (16,32,8) are accepted; anything else
  trips the assertion. The tiler is a 3x1 integer column holding the M, K and
  N extents, and the topology is created with the axis order (0,1,2) and a
  ``Bijection(1, 0,0,0,1,0)``.

  Args:
    wmma_m: M extent of the atom.
    wmma_k: K extent of the atom.
    wmma_n: N extent of the atom.

  Returns:
    The GemmTopology constructed from the shape, tiler, axis order and
    bijection above.
  """
  bijection = Bijection(1, 0,0,0,1,0)
  atom_shape = (wmma_m, wmma_k, wmma_n)
  assert atom_shape in [(16,16,8),(16,32,8)]
  # One column, one row per GEMM dimension (M, K, N), stored as integers.
  atom_tiler = np.array(atom_shape).reshape(3, 1).astype(int)
  return GemmTopology(atom_shape, atom_tiler, (0,1,2), bijection)


def make_sm_schedule(gM, gK, gN, bM, bN, bits, base, shift, scale, step, sigma:Tuple[int,int,int]):
  """Build the GemmTopology that tiles a (gM, gK, gN) GEMM into bM x bN blocks.

  The tiler has three rows (M, K, N) and ``max(grid_M, grid_N)`` columns:
  row 0 holds ``bM`` repeated ``grid_M`` times, row 1 holds ``gK`` once, and
  row 2 holds ``bN`` repeated ``grid_N`` times; unused slots are zero. The
  tiler is a float array (``np.zeros`` default dtype) and is printed before
  the topology is built.

  Args:
    gM, gK, gN: extents of the full GEMM problem.
    bM, bN: block extents along M and N; the grid sizes are ``gM//bM`` and
      ``gN//bN`` (any remainder is dropped by the floor division).
    bits, base, shift, scale, step: forwarded unchanged to ``Bijection``.
    sigma: axis order forwarded unchanged to ``GemmTopology``.

  Returns:
    ``(topology, grid_M, grid_N)``.
  """
  grid_M, grid_N = gM//bM, gN//bN
  width = max(grid_M, grid_N)

  tiler = np.zeros((3, width))
  # Each row is written exactly once, zero-padded up to the common width.
  tiler[0] = [bM]*grid_M + [0]*(width-grid_M)
  tiler[1] = [gK] + [0]*(width-1)
  tiler[2] = [bN]*grid_N + [0]*(width-grid_N)
  print(tiler)

  bijection = Bijection(width*width*width, bits, base, shift, scale, step)
  schedule = GemmTopology((gM,gK,gN), tiler, sigma, bijection)
  return schedule, grid_M, grid_N


In [119]:
# WMMA atom extents in (M, K, N) order; make_wmma_atom asserts that this is
# one of the shapes it supports.
wmma_shape = (16,32,8)
wmma_atom = make_wmma_atom(*wmma_shape)

In [120]:
# Full GEMM problem extents along M, K and N.
gM, gK, gN = 4096, 4096,4096


# Output matrix extents used for the tile search.
C_shape = (gM, gN)

# Block tile (bM, bN) chosen by calculate_sM_count for this output shape.
bM, bN = calculate_sM_count(C_shape, wmma_shape)

# Number of blocks along M and along N for the chosen tile.
print(gM//bM)
print(gN//bN)


Best SM Count: 128.0
Top Schedule (Most Square + Widest N): (256, 512)
16
8


In [121]:
# Bijection parameters forwarded as-is by make_sm_schedule.
bits = 0 
base = 0 
shift = 0 
scale = 1 
step = 0
# Axis order handed to GemmTopology.
sigma = (2,1,0)


# M is the GemmTopology returned by make_sm_schedule (not a matrix extent).
M,grid_M, grid_N = make_sm_schedule(gM,gK,gN,bM,bN,bits,base,shift,scale,step,sigma)
# Total number of blocks in the grid.
print(grid_M*grid_N)
  

[[ 256.  256.  256.  256.  256.  256.  256.  256.  256.  256.  256.  256.
   256.  256.  256.  256.]
 [4096.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
     0.    0.    0.    0.]
 [ 512.  512.  512.  512.  512.  512.  512.  512.    0.    0.    0.    0.
     0.    0.    0.    0.]]
128


In [122]:
def get_usage(wmma_shape):
  """Estimate per-atom resource usage for a WMMA shape.

  With ``e = m*k + k*n`` (element count of the A and B operands), returns:
    * ``e*2``  - shared-memory bytes (presumably 2 bytes per element, i.e.
      16-bit operands; confirm against the kernel),
    * ``e//2`` - register count for the A/B fragments (presumably two
      elements per register; confirm),
    * ``m*n``  - register count for the C fragment.

  Args:
    wmma_shape: ``(m, k, n)`` extents of the atom.

  Returns:
    ``(smem_usage_bytes, a_b_frag_registers, c_frag_registers)``.
  """
  m, k, n = wmma_shape
  a_b_elements = (m*k) + (k*n)
  return a_b_elements*2, a_b_elements//2, m*n
  

In [123]:
# Per-atom shared-memory bytes and A/B and C fragment register counts.
smem, a_b_frag, c_frag = get_usage(wmma_shape)

In [124]:
# Show the per-atom usage figures computed by get_usage.
print(smem, a_b_frag, c_frag)

1536 384 128


In [125]:
# Resource budgets passed to get_max_atom_tiles.
max_reg_per_sm = 65536
# NOTE(review): 1000*99 is 99000 bytes; if the intended budget is 99 KiB this
# would be 99*1024 = 101376. Confirm against the target GPU.
max_smem_bytes = 1000*99

In [126]:
def get_max_atom_tiles(max_reg_per_sm, max_smem_bytes, wmma_shape):
  """Largest power-of-two number of atom tiles that fits both budgets.

  Per-tile cost comes from ``get_usage(wmma_shape)``: its shared-memory bytes,
  and the sum of its A/B and C fragment register counts. The largest tile
  count whose totals stay within ``max_smem_bytes`` and ``max_reg_per_sm`` is
  computed directly and then rounded down to a power of two.

  Unlike a scan that stops at the first tile count reaching a limit, this
  never returns a count that exceeds a budget, and it has no upper search
  bound after which it would fall through and return ``None``.

  Args:
    max_reg_per_sm: register budget.
    max_smem_bytes: shared-memory budget in bytes.
    wmma_shape: ``(m, k, n)`` atom shape forwarded to ``get_usage``.

  Returns:
    The power-of-two tile count as an int, or 0 when not even a single tile
    fits within the budgets.

  Raises:
    ValueError: if the per-tile usage reported by ``get_usage`` is not
      positive, since no finite limit could be derived from it.
  """
  smem_usage, a_b_frags, c_frags = get_usage(wmma_shape)
  regs_per_tile = a_b_frags + c_frags
  if smem_usage <= 0 or regs_per_tile <= 0:
    raise ValueError(f"per-tile usage must be positive for wmma_shape {wmma_shape!r}")

  # Largest tile count that satisfies both constraints at once.
  fitting_tiles = int(min(max_smem_bytes // smem_usage, max_reg_per_sm // regs_per_tile))
  if fitting_tiles < 1:
    return 0
  # Round down to a power of two: keep only the highest set bit.
  return 1 << (fitting_tiles.bit_length() - 1)
    

In [127]:
# Power-of-two atom tile count for the budgets and atom shape defined above.
x = get_max_atom_tiles(max_reg_per_sm, max_smem_bytes, wmma_shape)

In [128]:
# Display the tile count returned by get_max_atom_tiles.
x

64