In [None]:
from kernel_runner import KernelRunner
from dataclasses import dataclass

@dataclass(frozen=True)
class SharedMemoryKernelRunner(KernelRunner):
    """KernelRunner preset for the ``shared_memory`` kernel.

    Sets the ``template`` and ``kernel_name`` fields; everything else is
    inherited from ``KernelRunner``. Instances are immutable (frozen).
    """

    # Template argument list passed along for the kernel.
    # NOTE(review): presumably <element type, BM, BN, BK> tile sizes --
    # confirm against the kernel source.
    template: str = "<float, 32, 32, 32>"
    # Name of the kernel this runner targets.
    kernel_name: str = "shared_memory"

# What is shared memory?

Shared memory (sometimes called scratchpad) is an explicitly managed L1 cache.

We can load "tiles" of our inputs into shared memory and access them from there
instead of global memory. This is useful if each entry in our input arrays is
used more than once, offsetting the cost of the load before use.

# Transforming loops and inserting fast memory storage

The naive loop structure looks something like this:
```python
for i in range(N):         # parallel
    for j in range(N):     # parallel
        for k in range(N): # sequential
            # compute
```

Corresponding to the math statement:
$$
C_{ij} = \sum_{k=1}^N A_{ik} B_{kj}
$$

Our goal for the shared memory example is to implement a rewritten math
expression
$$
C_{(io,ii),(jo,ji)} =
    \sum_{ko=1}^{N/BK}\sum_{ki=1}^{BK} A_{(io,ii),(ko,ki)} B_{(ko,ki),(jo,ji)}
$$

This is just *block matrix multiplication*, described by the following picture:

<img src="../../resources/block-matrix-multiplication.svg" width=500 />

- Picture from: https://en.wikipedia.org/wiki/Block_matrix#Multiplication

So, we can just split the loops:
```python
for io in range(N // BM):                  # parallel outer
    for jo in range(N // BN):              # parallel outer
        for ii in range(BM):               # parallel inner
            for ji in range(BN):           # parallel inner

                for ko in range(N // BK):  # sequential outer
                    for ki in range(BK):   # sequential inner
                        # compute
```

However, this doesn't really help us because we aren't exploiting the fast memory
of the GPU.

Insert the points at which it would be useful to move data from our inputs into
shared memory:
```python
for io in range(N // BM):                  # parallel outer
    for jo in range(N // BN):              # parallel outer
        for ii in range(BM):               # parallel inner
            for ji in range(BN):           # parallel inner

                for ko in range(N // BK):  # sequential outer
                    for ki in range(BK):   # sequential inner
                        # compute
```

Take a look at <a href="../src/numpy-examples/shared-memory-matmul.py">
the shared memory numpy example</a> to discuss loop structure and storage, and
to see if you were correct.

In [None]:
# Problem size and thread-block shape.
N = 8192
BM, BN = 32, 32
block_dim = (BM, BN)
# Number of blocks along each grid axis: N divided by the block extent.
grid_dim = tuple(N // extent for extent in block_dim)

# Q: what's with the GB/s? should it be higher? lower? why or why not?
runner = SharedMemoryKernelRunner()
# flip read_full_src to True if you want to use the provided implementation
_ = runner(
    block_dim,
    grid_dim,
    (N,),
    read_full_src=False,
    niterations=20,
)

# A note on bank conflicts

Shared memory is divided into 32 "banks", each 32 bits (4 bytes) wide.
Successive 32-bit words map to successive banks, so the bank that data belongs
to is determined by the low 5 bits of its word index in shared memory (its
byte address divided by 4).

Bank conflicts arise when multiple threads in a single warp access different
addresses in the same bank.

*Padding* is a trick that adds an offset (at the cost of more shared memory
usage) to the contiguous axis to prevent these conflicts.

Given the array:
$$
A = \begin{bmatrix}
0 & 1 & 2\\
3 & 4 & 5\\
6 & 7 & 8
\end{bmatrix}
$$

Place each entry, in row-major order, into the banks below. In particular, use
the formula $(i * 3 + j) \mod 3$ to compute the correct bank for an entry 
$A_{ij}$. Is there a potential for bank conflicts?
| bank 0 | bank 1 | bank 2 |
| ------ | ------ | ------ |
|        |        |        |
|        |        |        |
|        |        |        |
|        |        |        |
|        |        |        |
|        |        |        |
|        |        |        |
|        |        |        |
|        |        |        |

Now, pad the column-axis by 1. In particular: Use the formula 
$(i * 4 + j) \mod 3$ to compute the correct bank for an entry $A_{ij}$. Is there
a potential for bank conflicts?
| bank 0 | bank 1 | bank 2 |
| ------ | ------ | ------ |
|        |        |        |
|        |        |        |
|        |        |        |
|        |        |        |
|        |        |        |
|        |        |        |
|        |        |        |
|        |        |        |
|        |        |        |


In [None]:
def print_shared_memory_banks(nbanks, index_fn, n):
    """Print a table of which bank each entry of an ``n x n`` array falls in.

    The array holds ``0 .. n*n - 1`` in row-major order, so entry ``(i, j)``
    has value ``i * n + j``. Each entry is placed in bank
    ``index_fn(i, j) % nbanks``; every bank is one table column listing its
    entries, in row-major visiting order, formatted as ``value[i,j]``.
    A column with more than one entry from the same access pattern is where
    bank conflicts can occur.

    Args:
        nbanks: Number of banks (table columns); must be at least 1.
        index_fn: Callable mapping ``(i, j)`` to the entry's address.
        n: Side length of the square array; must be non-negative.

    Raises:
        ValueError: If ``nbanks < 1`` or ``n < 0``.
    """
    # Fail early with a clear message instead of an opaque ZeroDivisionError /
    # KeyError from the modulo and lookup below.
    if nbanks < 1:
        raise ValueError(f"nbanks must be at least 1, got {nbanks}")
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    # bank id -> rendered cells, in the order the entries are visited.
    banks = {b: [] for b in range(nbanks)}
    for i in range(n):
        for j in range(n):
            value = i * n + j  # row-major value stored at (i, j)
            bank = index_fn(i, j) % nbanks
            banks[bank].append(f"{value}[{i},{j}]")

    columns = list(banks.values())
    depth = max(len(cells) for cells in columns)

    # Transpose the per-bank columns into table rows, padding short banks
    # with empty cells so every row has exactly nbanks cells.
    table_rows = [
        [cells[k] if k < len(cells) else "" for cells in columns]
        for k in range(depth)
    ]

    header_cells = [str(b) for b in range(nbanks)]
    # Each column is as wide as its widest cell, header included.
    col_widths = [
        max(len(text) for text in [header, *cells])
        for header, cells in zip(header_cells, columns)
    ]
    sep_cells = ["-" * w for w in col_widths]

    def render_row(cells):
        padded = (cell.ljust(w) for cell, w in zip(cells, col_widths))
        return "| " + " | ".join(padded) + " |"

    for row_cells in [header_cells, sep_cells, *table_rows]:
        print(render_row(row_cells))


# Unpadded layout: entry (i, j) sits at address i*3 + j, spread over 3 "banks".
print_shared_memory_banks(nbanks=3, index_fn=lambda row, col: 3 * row + col, n=3)

# Padded layout: same logical 3x3 array, but each row is strided by 4.
print_shared_memory_banks(nbanks=3, index_fn=lambda row, col: 4 * row + col, n=3)