In [None]:
from kernel_runner import KernelRunner
from dataclasses import dataclass

@dataclass(frozen=True)
class SharedMemoryKernelRunner(KernelRunner):
    template: str = "<float, 32, 32, 32>"
    kernel_name: str = "shared_memory"

# What is shared memory?

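Shared memory (sometimes called scratchpad memory) is fast on-chip memory that behaves like an explicitly managed L1 cache: the program, not the hardware, decides what is stored in it.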

We can load "tiles" of our inputs into shared memory and access them from there
instead of global memory. This pays off when each entry of our input arrays is
used more than once, because the reuse amortizes the cost of the initial load.
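
To make the reuse argument concrete, here is a rough count of global-memory element loads for an N x N matmul with and without tiling. It is only a back-of-the-envelope sketch: the function names are made up for illustration, hardware caches are ignored, and the naive version is assumed to read every operand straight from global memory.

```python
def naive_global_loads(n: int) -> int:
    # Each of the n*n outputs reads n elements of A and n elements of B.
    return 2 * n**3


def tiled_global_loads(n: int, bm: int, bn: int, bk: int) -> int:
    # Each (bm x bn) output tile takes n // bk steps and, at every step,
    # loads one (bm x bk) tile of A and one (bk x bn) tile of B exactly once.
    n_output_tiles = (n // bm) * (n // bn)
    loads_per_output_tile = (n // bk) * (bm * bk + bk * bn)
    return n_output_tiles * loads_per_output_tile


n = 8192
ratio = naive_global_loads(n) / tiled_global_loads(n, 32, 32, 32)
print(f"tiling cuts global loads by a factor of {ratio:.0f}")
```

Under these assumptions, square 32 x 32 tiles reduce global-memory traffic by a factor of 32, which is the reuse the tiles are meant to capture.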

# Transforming loops and inserting fast memory storage

The naive loop structure looks something like this:
```python
for i in range(N):         # parallel
    for j in range(N):     # parallel
        for k in range(N): # sequential
            # compute
```

This corresponds to the math statement:
$$
C_{ij} = \sum_{k=1}^N A_{ik} B_{kj}
$$
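
As a sanity check, the loop nest and the formula can be tied together in a small runnable sketch. The sizes and array names below are chosen only for illustration, and the code uses 0-based indices where the formula uses 1-based ones.

```python
import numpy as np

N = 8  # kept tiny so the pure-Python loops finish instantly
rng = np.random.default_rng(0)
A = rng.random((N, N))
B = rng.random((N, N))
C = np.zeros((N, N))

for i in range(N):          # parallel: each (i, j) is independent
    for j in range(N):      # parallel
        for k in range(N):  # sequential: accumulates into a single output
            C[i, j] += A[i, k] * B[k, j]

# The loop nest should agree with NumPy's built-in matrix product.
assert np.allclose(C, A @ B)
```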

Our goal for the shared memory example is to implement a rewritten math
expression, in which each index is split into an outer (block) index and an
inner (within-block) index:
$$
C_{(io,ii),(jo,ji)} =
    \sum_{ko=1}^{N/BK}\sum_{ki=1}^{BK} A_{(io,ii),(ko,ki)} B_{(ko,ki),(jo,ji)}
$$

This is just *block matrix multiplication*, described by the following picture:

<img src="../../resources/block-matrix-multiplication.svg" width=500 />

- Picture from: https://en.wikipedia.org/wiki/Block_matrix#Multiplication

So, we can just split the loops:
```python
for io in range(N // BM):                  # parallel outer
    for jo in range(N // BN):              # parallel outer
        for ii in range(BM):               # parallel inner
            for ji in range(BN):           # parallel inner

                for ko in range(N // BK):  # sequential outer
                    for ki in range(BK):   # sequential inner
                        # compute
```
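
To check that splitting the loops leaves the result unchanged, here is a runnable sketch of the same nest with the compute step filled in. The tiny sizes are assumptions chosen so it runs quickly; each block size must divide `N`.

```python
import numpy as np

N, BM, BN, BK = 8, 4, 4, 2  # illustrative sizes only
rng = np.random.default_rng(0)
A = rng.random((N, N))
B = rng.random((N, N))
C = np.zeros((N, N))

for io in range(N // BM):                  # parallel outer
    for jo in range(N // BN):              # parallel outer
        for ii in range(BM):               # parallel inner
            for ji in range(BN):           # parallel inner
                i = io * BM + ii           # recombine outer/inner into a row
                j = jo * BN + ji           # recombine outer/inner into a column
                for ko in range(N // BK):  # sequential outer
                    for ki in range(BK):   # sequential inner
                        k = ko * BK + ki
                        C[i, j] += A[i, k] * B[k, j]

# Same arithmetic as the naive nest, just visited block by block.
assert np.allclose(C, A @ B)
```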

However, splitting the loops alone doesn't really help us: every operand is
still read from global memory, so we aren't exploiting the fast memory of the GPU.

Mark the points in the loop nest below at which it would be useful to move data
from our inputs into shared memory:
```python
for io in range(N // BM):                  # parallel outer
    for jo in range(N // BN):              # parallel outer
        for ii in range(BM):               # parallel inner
            for ji in range(BN):           # parallel inner

                for ko in range(N // BK):  # sequential outer
                    for ki in range(BK):   # sequential inner
                        # compute
```

Take a look at <a href="../src/numpy-examples/shared-memory-matmul.py">
the shared memory numpy example</a>, which covers loop structure and storage, to
see if you were correct.
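
For reference, here is one possible placement, written as a sequential NumPy sketch. It is not necessarily the same as the linked example. The names `A_tile`, `B_tile`, and `acc` are made up for illustration, and plain array copies stand in for shared memory. Note that the `ko` loop moves outside the inner parallel loops, so that each tile is loaded once and then reused by every `(ii, ji)` pair.

```python
import numpy as np

N, BM, BN, BK = 8, 4, 4, 2  # illustrative sizes; block sizes must divide N
rng = np.random.default_rng(0)
A = rng.random((N, N))
B = rng.random((N, N))
C = np.zeros((N, N))

for io in range(N // BM):                 # parallel outer
    for jo in range(N // BN):             # parallel outer
        acc = np.zeros((BM, BN))          # one accumulator per (ii, ji)
        for ko in range(N // BK):         # sequential outer
            # Stand-ins for shared memory: copy one tile of each input per step.
            A_tile = A[io*BM:(io+1)*BM, ko*BK:(ko+1)*BK].copy()
            B_tile = B[ko*BK:(ko+1)*BK, jo*BN:(jo+1)*BN].copy()
            for ii in range(BM):          # parallel inner
                for ji in range(BN):      # parallel inner
                    for ki in range(BK):  # sequential inner
                        acc[ii, ji] += A_tile[ii, ki] * B_tile[ki, ji]
        C[io*BM:(io+1)*BM, jo*BN:(jo+1)*BN] = acc

assert np.allclose(C, A @ B)
```

On a real GPU the inner parallel loops run as concurrent threads, so a barrier is needed between loading a tile and computing with it, and again before the tile is overwritten. This sequential sketch does not model that synchronization.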

In [None]:
N = 8192 
BM = BN = 32
block_dim = (BM, BN)
grid_dim = (N // BM, N // BN)

# Q: what's with the GB/s? should it be higher? lower? why or why not?
runner = SharedMemoryKernelRunner()
_ = runner(block_dim, grid_dim, (N,), read_full_src=False, niterations=20)