In [27]:
import cutlass
import cutlass.cute as cute
import cuda.bindings as cu
%load_ext nvcc4jupyter

from nvcc4jupyter import set_defaults
set_defaults(compiler_args='-arch=sm_100a -Xptxas=-v -O0')
import numpy as np

The nvcc4jupyter extension is already loaded. To reload it, use:
  %reload_ext nvcc4jupyter


In [28]:
%%cuda
#include <stdio.h>
#include <cuda_runtime.h> 
#include <cuda.h> 
#include <mma.h> 
#include <cuda_bf16.h>
#include <cuda/barrier>

using barrier = cuda::barrier<cuda::thread_scope_block>;
namespace cde = cuda::device::experimental;

constexpr int M = 4096;
constexpr int N = 4096;
constexpr int K = 4096;
constexpr int bM = 32; 
constexpr int bN = 32; 
constexpr int bK = 32;

__global__ void matmul(__nv_bfloat16* A, __nv_bfloat16 *B, __nv_bfloat16 *C, 
                        const __grid_constant__ CUtensorMap tensor_map_A, 
                        const __grid_constant__ CUtensorMap tensor_map_B)
{
  __shared__ alignas(128) __nv_bfloat16 S0[2][(bM*bK) + (bK*bN)]; 
  __shared__ alignas(128) __nv_bfloat16 S1[2][(bM*bK) + (bK*bN)]; 

  int t = threadIdx.x; 

  __shared__ barrier S0_E, S0_F, S1_E, S1_F; 
  if (t == 0)
  {
    init(&S0_E, blockDim.x);
    init(&S1_E, blockDim.x);
    init(&S0_F, blockDim.x);
    init(&S1_F, blockDim.x);
    cde::fence_proxy_async_shared_cta();
  }
  __syncthreads(); 
}

void printDeviceStats(int deviceId) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, deviceId);

    printf("\n==========================================================\n");
    printf(" DEVICE: %s (Compute %d.%d)\n", prop.name, prop.major, prop.minor);
    printf("==========================================================\n");
    
    // SM and Threading Info
    printf(" [Compute]\n");
    printf("  Multiprocessors (SMs):       %d\n", prop.multiProcessorCount);
    printf("  Max Threads per SM:          %d\n", prop.maxThreadsPerMultiProcessor);
    printf("  Max Threads per Block:       %d\n", prop.maxThreadsPerBlock);
    printf("  Warp Size:                   %d\n", prop.warpSize);

    // Registers
    printf("\n [Registers]\n");
    printf("  Max 32-bit Regs per Block:   %d\n", prop.regsPerBlock);
    printf("  Max 32-bit Regs per SM:      %d\n", prop.regsPerMultiprocessor);

    // Shared Memory
    printf("\n [Memory Hierarchy]\n");
    // Standard static shared mem limit
    printf("  Smem per Block (Static):     %zu KB\n", prop.sharedMemPerBlock / 1024);
    // Max possible shared mem per block (requires opt-in via cudaFuncSetAttribute)
    printf("  Smem per Block (Max Dyn):    %zu KB\n", prop.sharedMemPerBlockOptin / 1024);
    // Total shared memory available per SM (partitioned among blocks resident on that SM)
    printf("  Smem per Multiprocessor:     %zu KB\n", prop.sharedMemPerMultiprocessor / 1024);
    
    // Global Memory & Cache
    printf("  L2 Cache Size:               %d MB\n", prop.l2CacheSize / (1024 * 1024));
    printf("  Total Global Memory:         %.2f GB\n", (float)prop.totalGlobalMem / (1024.0 * 1024.0 * 1024.0));
    printf("  Memory Bus Width:            %d-bit\n", prop.memoryBusWidth);
    //#printf("  Memory Clock Rate:           %.2f GHz\n", prop.memoryClockRate / 1.0e6);
    
    printf("==========================================================\n\n");
}

int main()
{
    int deviceId = 0;
    cudaGetDevice(&deviceId);
    
    // 1. Print the pretty stats
    printDeviceStats(deviceId);

    // 2. Your occupancy calculation
    int numBlocks;
    int blockSize = 5*32; 
    size_t dynamicSMemSize = 0; 

    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &numBlocks, 
        matmul, 
        blockSize, 
        dynamicSMemSize
    );

    printf("Occupancy Check:\n");
    printf("  Kernel: matmul\n");
    printf("  Block Size: %d threads\n", blockSize);
    printf("  Max Active Blocks per SM: %d\n", numBlocks);

    return 0;
}


 DEVICE: NVIDIA GeForce RTX 5090 (Compute 12.0)
 [Compute]
  Multiprocessors (SMs):       170
  Max Threads per SM:          1536
  Max Threads per Block:       1024
  Warp Size:                   32

 [Registers]
  Max 32-bit Regs per Block:   65536
  Max 32-bit Regs per SM:      65536

 [Memory Hierarchy]
  Smem per Block (Static):     48 KB
  Smem per Block (Max Dyn):    99 KB
  Smem per Multiprocessor:     100 KB
  L2 Cache Size:               96 MB
  Total Global Memory:         31.36 GB
  Memory Bus Width:            512-bit

Occupancy Check:
  Kernel: matmul
  Block Size: 160 threads
  Max Active Blocks per SM: 9



In [29]:
M = 4096 
N = 4096 
K = 4096
wmma_m = 16
wmma_n = 16
wmma_k = 16
n_warps = 32 
n_buffers = 2
n_stages_per_buffer = 8
work_per_block = wmma_m*wmma_n*n_warps*n_buffers*n_stages_per_buffer
work = M*N
n_blocks = work//work_per_block

n_registers_per_block = ((wmma_m*wmma_k)//2 + (wmma_k*wmma_n)//2  + wmma_n*wmma_m)*n_warps


In [30]:
n_registers_per_block

16384

In [33]:
A_tiled = np.arange(M*K//(wmma_m*wmma_n)).reshape(M//wmma_m, K//wmma_k)

In [34]:
A_tiled

array([[    0,     1,     2, ...,   253,   254,   255],
       [  256,   257,   258, ...,   509,   510,   511],
       [  512,   513,   514, ...,   765,   766,   767],
       ...,
       [64768, 64769, 64770, ..., 65021, 65022, 65023],
       [65024, 65025, 65026, ..., 65277, 65278, 65279],
       [65280, 65281, 65282, ..., 65533, 65534, 65535]], shape=(256, 256))

In [36]:
A_wmma_units, B_wmma_units, C_wmma_units = np.empty((256,256)), np.empty((256,256)), np.empty((256,256))


In [75]:
N_warps_per_block = 32 
N_c_warp_units = 256*256 
N_sms = 170 
N_stages = 512 
work_per_block = N_stages*N_warps_per_block
n_blocks = N_c_warp_units//work_per_block
while n_blocks < 170: 

  work_per_block = N_stages*N_warps_per_block
  n_blocks = N_c_warp_units//work_per_block
  if N_stages == 16: 
    print(n_blocks, work_per_block, N_stages)
  N_stages -= 1
  

128 512 16


In [61]:
work_per_block

384

In [62]:
n_blocks

170

In [63]:
N_stages

11

In [65]:
print(32*16)


512


In [73]:
print(170*384)

65280


In [74]:
print(256*256)

65536


In [86]:
smem_dyn_kb = 99
# Use 1024 for KiB to Bytes conversion
smem_total_bytes = smem_dyn_kb * 1024  
# 99 * 1024 = 101,376 bytes

# bf16 is 2 bytes
n_max_bf16s_per_sm = smem_total_bytes // 2 
# 101,376 // 2 = 50,688 elements

In [87]:
n_max_bf16s_per_sm

50688

In [None]:
def calculate_smem_n_elements(BM,BN,BK): 
  return 16*16*(BM*BK + BK*BN)

In [95]:
X = calculate_smem_n_elements(8,8,8)

In [96]:
X

32768