In [None]:
from kernel_runner import KernelRunner
from dataclasses import dataclass

@dataclass(frozen=True)
class RegisterTiledKernelRunner(KernelRunner):
    """Frozen runner configuration selecting the ``register_tiled`` kernel.

    Sets two string fields on top of ``KernelRunner``. How those fields are
    consumed (compilation, launch, timing) is decided by the base class in
    ``kernel_runner``, which is not visible in this notebook.
    """

    # Template-argument string for the kernel. The values presumably map to
    # <dtype, BM, BN, BK, TM, TN> -- TODO confirm against the kernel source.
    template: str = "<float, 128, 128, 8, 8, 8>"
    # Name of the kernel this runner targets.
    kernel_name: str = "register_tiled"

# More loop transformations and storage insertion

Registers are the small, fast storage locations on the chip that compute
instructions operate on directly. This is the ideal place to keep values if at
all possible.

Instead of splitting loops, we can permute their order to change how the
computation happens. Register tiling relies on this permutation.

Here's the shared memory loop structure. Can you permute the loops to expose
more reuse that we can exploit with registers?
```python
for io in range(N // BM):                  # parallel outer
    for jo in range(N // BN):              # parallel outer
        
        for ii in range(BM):               # parallel inner
            for ji in range(BN):           # parallel inner

                for ko in range(N // BK):  # sequential outer
                    for ki in range(BK):   # sequential inner
                        # compute
```

# Note on register pressure
- We have a limited number of registers (64K 32-bit registers per SM)
- If we use too many registers, SMs will be short on resources and underutilized
  as a result
- Given a thread tile of size TM*TN, compute the number of registers needed per
  thread and per block

In [None]:
# Problem size, followed by the tiling parameters used for the launch.
N = 8192
BM, BN = 128, 128  # tile handled by one thread block (presumably rows x cols)
TM, TN = 8, 8      # thread tile: the TM*TN sub-tile discussed above

# Threads per block: one thread for each TM x TN sub-tile of the block tile.
block_dim = (BN // TN, BM // TM)
# Blocks per grid: ceiling division so every part of the N-sized problem is covered.
grid_dim = tuple((N + tile - 1) // tile for tile in (BM, BN))

runner = RegisterTiledKernelRunner()
# there is no "blank" version of this kernel: keep read_full_src=True
_ = runner(
    block_dim,
    grid_dim,
    (N,),
    read_full_src=True,
    niterations=20,
)

Using the above parameters, compute the following things:

In [None]:
# theoretical occupancy calculation
# Exercise scaffold: every `0` below is a placeholder for the reader to replace
# with a value derived from the launch parameters defined in the cell above.
registers_per_thread = 0
threads_per_block = 0
registers_per_block = 0

# Fixed limits used by the exercise: 65536 = 64K registers per SM (the limit
# quoted in the register-pressure note above); warp_size is the number of
# threads per warp.
max_registers_per_sm = 65536
warp_size = 32

# Placeholders for the per-SM quantities to derive from the values above.
blocks_per_sm_theoretical = 0
threads_per_sm_theoretical = 0
warps_per_sm_theoretical = 0
# Report the filled-in values in aligned columns.
print(f"Registers per thread        : {registers_per_thread}")
print(f"Registers per block         : {registers_per_block}")
print(f"Blocks    per SM theoretical: {blocks_per_sm_theoretical}")
print(f"Threads   per SM theoretical: {threads_per_sm_theoretical}")
print(f"Warps     per SM theoretical: {warps_per_sm_theoretical}")

Based on information in the first notebook:
1. Are the numbers above good or bad?
2. Which number matters the most from the perspective of the SM?