Our goal is to implement an efficient CUDA kernel for FlashAttention v1

We will:
1. Do a minimal FlashAttention computation by hand to understand it
2. Implement a naive Python version
3. Implement a Python version with CUDA-like kernels
4. Implement a mamba version
5. Implement a CUDA version

## 1: Do a minimal FlashAttention computation by hand to understand it

This is not in obviously not in the notebook. ;)

## 2. Implement a naive Python version

In [1]:
import math
import torch
from torch import tensor, float32 as f32
from torch.nn.functional import softmax, scaled_dot_product_attention

torch.set_printoptions(sci_mode=False, precision=2, linewidth=200)

In [2]:
N, d = 2,1
M = 2

Q = tensor([1, 1], dtype=f32).reshape(N,d)
K = tensor([0, 2], dtype=f32).reshape(N,d)
V = tensor([0,-1], dtype=f32).reshape(N,d)

In [3]:
scaled_dot_product_attention(Q, K, V)

tensor([[-0.88],
        [-0.88]])

In [4]:
softmax(Q@K.t(), dim=1)@V

tensor([[-0.88],
        [-0.88]])

In [5]:
# Output
O = torch.zeros(N,d)
# softmax denominator, per row
l = torch.zeros(N)
# max for numerical stability, per row
m = torch.full((N,), float('-inf'))

O, l, m

(tensor([[0.],
         [0.]]),
 tensor([0., 0.]),
 tensor([-inf, -inf]))

In [6]:
# block sizes
bc = math.ceil(M / (4*d))
br = min(math.ceil(M  / (4*d)), d)

# block numbers
tc = math.ceil(N / bc)
tr = math.ceil(N / br)

print(f'Block sizes:   {bc}, {br}')
print(f'Block numbers: {tc}, {tr }')

Block sizes:   1, 1
Block numbers: 2, 2


In [7]:
# we don't want partial blocks, so pad Q, K, V with zeros
Q_ = torch.zeros(br*tr,d)
K_ = torch.zeros(bc*tc,d)
V_ = torch.zeros(bc*tc,d)

Q_[:N,:d] = Q
K_[:N,:d] = K
V_[:N,:d] = V

print(Q.shape, K.shape, V.shape)

Q,K,V = Q_,K_,V_

print(Q.shape, K.shape, V.shape)

torch.Size([2, 1]) torch.Size([2, 1]) torch.Size([2, 1])
torch.Size([2, 1]) torch.Size([2, 1]) torch.Size([2, 1])


In [8]:
softmax(Q@K.t(), dim=1)@V

tensor([[-0.88],
        [-0.88]])

In [9]:
# Loop for phase
for j in range(0, tc):
    print(f'Outer loop for phase: {j = }')
    
    # load from HBM -> SRAM
    Kj = K[j*bc:(j+1)*bc]
    Vj = V[j*bc:(j+1)*bc]    

    assert Kj.shape==(bc,d)
    assert Vj.shape==(bc,d)

    print(f'\t{Kj = }\t{Vj = }')

    # Loop for output rows
    for i in range(0, tr):
        print(f'\tInner loop for output rows: {i = }')
    
        # load from HBM -> SRAM
        Qi = Q[i*br:(i+1)*br]
        Oi = O[i*br:(i+1)*br]
        mi = m[i*br:(i+1)*br]
        li = l[i*br:(i+1)*br]

        print(f'\t\t{Qi = }\t{Oi = }\t{mi = }\t{li = }')
        
        assert Qi.shape==(br,d)
        assert Oi.shape==(br,d)
        assert li.shape==(br,)
        assert mi.shape==(br,)

        # compute
        Sij = Qi@Kj.t() / math.sqrt(d)
        assert Sij.shape==(br,bc)

        mij = Sij.max(1).values
        assert mij.shape==(br,)
        
        Pij = (Sij - mij[:,None]).exp()
        assert Pij.shape==(br,bc)

        lij = Pij.sum(1)
        assert lij.shape==(br,)

        print(f'\t\t{Sij = }\t{mij = }\t{Pij = }\t{lij = }')
        
        mi_new = max(mi, mij)
        li_new = (mi-mi_new).exp()*li + (mij-mi_new).exp()*lij
        assert mi_new.shape==(br,)
        assert li_new.shape==(br,)

        Oi_new_part1 = torch.diag(li)@(mi - mi_new).exp() # br,br   @ br   -> br
        Oi_new_part1 = Oi_new_part1[:,None] * Oi          # br,None * br,d -> br,d
        assert Oi_new_part1.shape==(br,d)        
        Oi_new_part2 = (mij-mi_new).exp()[:,None]*Pij     # br,None * br,bc -> br,bc
        Oi_new_part2 = Oi_new_part2@Vj                    # br,bc   @ bc,d  -> br,d
        assert Oi_new_part2.shape==(br,d)
        Oi_new = Oi_new_part1 + Oi_new_part2              # br,d + br,d  -> br,d
        Oi_new = torch.diag(li_new).inverse() @ Oi_new    # br,br @ br,d -> br,d        
        assert Oi_new.shape==(br,d)

        print(f'\t\t{Oi_new = }\t{mi_new = }\t{li_new = }')
        
        # Write SRAM -> HBM
        O[i*br:(i+1)*br] = Oi_new
        m[i*br:(i+1)*br] = mi_new
        l[i*br:(i+1)*br] = li_new

        print('\n\t\tResult:')
        print(f'\t\t{O = }\t{m = }\t{l = }')
        print()

Outer loop for phase: j = 0
	Kj = tensor([[0.]])	Vj = tensor([[0.]])
	Inner loop for output rows: i = 0
		Qi = tensor([[1.]])	Oi = tensor([[0.]])	mi = tensor([-inf])	li = tensor([0.])
		Sij = tensor([[0.]])	mij = tensor([0.])	Pij = tensor([[1.]])	lij = tensor([1.])
		Oi_new = tensor([[0.]])	mi_new = tensor([0.])	li_new = tensor([1.])

		Result:
		O = tensor([[0.],
        [0.]])	m = tensor([0., -inf])	l = tensor([1., 0.])

	Inner loop for output rows: i = 1
		Qi = tensor([[1.]])	Oi = tensor([[0.]])	mi = tensor([-inf])	li = tensor([0.])
		Sij = tensor([[0.]])	mij = tensor([0.])	Pij = tensor([[1.]])	lij = tensor([1.])
		Oi_new = tensor([[0.]])	mi_new = tensor([0.])	li_new = tensor([1.])

		Result:
		O = tensor([[0.],
        [0.]])	m = tensor([0., 0.])	l = tensor([1., 1.])

Outer loop for phase: j = 1
	Kj = tensor([[2.]])	Vj = tensor([[-1.]])
	Inner loop for output rows: i = 0
		Qi = tensor([[1.]])	Oi = tensor([[0.]])	mi = tensor([0.])	li = tensor([1.])
		Sij = tensor([[2.]])	mij = tenso

In [10]:
O

tensor([[-0.88],
        [-0.88]])

In [11]:
(O==scaled_dot_product_attention(Q,K,V)).all()

tensor(True)

In [12]:
del Q,K,V,O,N,d,M,m,l

In [13]:
from tqdm.notebook import tqdm

def flash_attention(Q,K,V,M=10,verbose=False):
    # Q,K,V = attn matrices; M = shared mem size
    N,d = Q.shape
    assert V.shape==K.shape==(N,d), "Shape mismatch"
    
    # block sizes
    bc = math.ceil(M / (4*d))
    br = min(math.ceil(M / (4*d)), d)
    # block numbers
    tc = math.ceil(N / bc)
    tr = math.ceil(N / br)    

    print(f'Block sizes:   {bc}, {br}')
    print(f'Block numbers: {tc}, {tr }')
    
    O = torch.zeros(N,d)                 # output
    l = torch.zeros(N)                   # softmax denominator, per row
    m = torch.full((N,), float('-inf'))  # max for numerical stability, per row

    with tqdm(total=tc*tr) as pbar:
        # Loop for phase
        for j in range(0, tc):
            if verbose: print(f'Outer loop for phase: {j = }')
            # load from HBM -> SRAM
            Kj = K[j*bc:(j+1)*bc]
            Vj = V[j*bc:(j+1)*bc]    

            if verbose: print(f'\t{Kj = }\n\t{Vj = }')    

            # Loop for output rows
            for i in range(0, tr):
                if verbose: print(f'\tInner loop for output rows: {i = }') 
                # load from HBM -> SRAM
                Qi = Q[i*br:(i+1)*br]
                Oi = O[i*br:(i+1)*br]
                mi = m[i*br:(i+1)*br]
                li = l[i*br:(i+1)*br]

                if verbose: print(f'\t\t{Qi = }\n\t\t{Oi = }\n\t\t{mi = }\n\t\t{li = }')
                        
                # compute
                Sij = Qi@Kj.t() / math.sqrt(d)
                mij = Sij.max(1).values
                Pij = (Sij - mij[:,None]).exp()
                lij = Pij.sum(1)

                if verbose: print(f'\t\t{Sij = }\n\t\t{mij = }\n\t\t{Pij = }\n\t\t{lij = }')
                
                mi_new = torch.max(mi, mij)
                li_new = (mi-mi_new).exp()*li + (mij-mi_new).exp()*lij

                Oi_new_part1 = torch.diag(li)@(mi - mi_new).exp() # br,br   @ br   -> br
                Oi_new_part1 = Oi_new_part1[:,None] * Oi          # br,None * br,d -> br,d
                assert Oi_new_part1.shape==(br,d)        
                Oi_new_part2 = (mij-mi_new).exp()[:,None]*Pij     # br,None * br,bc -> br,bc
                Oi_new_part2 = Oi_new_part2@Vj                    # br,bc   @ bc,d  -> br,d
                assert Oi_new_part2.shape==(br,d)
                Oi_new = Oi_new_part1 + Oi_new_part2              # br,d + br,d  -> br,d
                Oi_new = torch.diag(li_new).inverse() @ Oi_new    # br,br @ br,d -> br,d        
                assert Oi_new.shape==(br,d)

                if verbose: print(f'\t\t{Oi_new = }\n\t\t{mi_new = }\n\t\t{li_new = }')
                            
                # Write SRAM -> HBM
                O[i*br:(i+1)*br] = Oi_new
                m[i*br:(i+1)*br] = mi_new
                l[i*br:(i+1)*br] = li_new

                if verbose: print(f'\n\t\tResult:\n\t\t{O = }\t{m = }\t{l = }\n')
                
                pbar.update(1)

    return O

In [14]:
Q = tensor([1, 1], dtype=f32).reshape(2,1)
K = tensor([0, 2], dtype=f32).reshape(2,1)
V = tensor([0,-1], dtype=f32).reshape(2,1)

O = flash_attention(Q,K,V)

Block sizes:   3, 1
Block numbers: 1, 2


  0%|          | 0/2 [00:00<?, ?it/s]

In [15]:
torch.isclose(O, scaled_dot_product_attention(Q,K,V)).all()

tensor(True)

A minimal example that didn't work (and now does):

In [16]:
Q = tensor([1,0,0,0], dtype=f32).view(2,2)
K = tensor([1,0,0,0], dtype=f32).view(2,2)
V = tensor([0,1,0,0], dtype=f32).view(2,2)

O = flash_attention(Q,K,V,M=9)

torch.isclose(O, scaled_dot_product_attention(Q,K,V)).all()

Block sizes:   2, 2
Block numbers: 1, 1


  0%|          | 0/1 [00:00<?, ?it/s]

tensor(True)

A random 2x2 matrix

In [17]:
Q = torch.rand(2,2, dtype=f32)
K = torch.rand(2,2, dtype=f32)
V = torch.rand(2,2, dtype=f32)

In [18]:
O = flash_attention(Q,K,V,M=2)

Block sizes:   1, 1
Block numbers: 2, 2


  0%|          | 0/4 [00:00<?, ?it/s]

In [19]:
torch.isclose(O, scaled_dot_product_attention(Q,K,V)).all()

tensor(True)

A random, larger matrix

In [20]:
Q = torch.rand(256,64, dtype=f32)
K = torch.rand(256,64, dtype=f32)
V = torch.rand(256,64, dtype=f32)

In [21]:
O = flash_attention(Q,K,V,M=100)

Block sizes:   1, 1
Block numbers: 256, 256


  0%|          | 0/65536 [00:00<?, ?it/s]

In [22]:
torch.isclose(O, scaled_dot_product_attention(Q,K,V)).all()

tensor(True)

## 3. Implement a Python version with CUDA-like kernels

In [23]:
import math
import torch
from torch import tensor, float32 as f32
from torch.nn.functional import softmax, scaled_dot_product_attention

torch.set_printoptions(sci_mode=False, precision=2, linewidth=200)

This kernel runner is taken from 
> https://github.com/cuda-mode/lectures/blob/main/lecture5/matmul_l5.ipynb

In [24]:
from collections import namedtuple
dim3 = namedtuple('dim3', ['x','y','z'], defaults=(1,1))
def cdiv(a,b): return (a + b - 1) // b

In [25]:
from threading import Thread, Barrier

def blk_kernel2d_shar(f, blocks, tpb, sh_sz, *args, **kwargs):
    for i0 in range(blocks.y):
        for i1 in range(blocks.x):
            shar = torch.zeros(sh_sz)
            syncb = Barrier(tpb.y*tpb.x)
            threads = [
                Thread(
                    target=f,
                    args=(dim3(i1,i0), dim3(p,o), tpb, shar, syncb, *args),
                    kwargs=kwargs
                )
                for o in range(tpb.y)
                for p in range(tpb.x)
            ]
            for tr in threads: tr.start()
            for tr in threads: tr.join()

In [26]:
def flash_attn(kernel_func, Q, K, V, sh_sz):
    # Q,K,V = attn matrices; sh_sz = shared mem size; n x d = shape of Q/K/V/O
    n,d = Q.shape
    assert V.shape==K.shape==(n,d), "Shape mismatch"

    O = torch.zeros(n,d)                 # output
    l = torch.zeros(n)                   # softmax denominator, per row
    m = torch.full((n,), float('-inf'))  # max for numerical stability, per row
    
    # block sizes
    bc = cdiv(sh_sz, 4*d)
    br = min(cdiv(sh_sz, 4*d), d)
    block_sizes = dim3(bc,br)
    
    # block numbers
    tc = cdiv(n, bc)
    tr = cdiv(n, br)
    block_numbers = dim3(tc,tr)
    
    print(f'Block sizes:   {bc}, {br}')
    print(f'Block numbers: {tc}, {tr }')
        
    blk_kernel2d_shar(
        f=kernel_func,
        blocks=block_numbers, tpb=block_sizes,
        sh_sz=sh_sz,
        Q=Q.flatten(), K=K.flatten(), V=V.flatten(), O=O.flatten(),
        m=m,l=l,
        n=n, d=d
    )

In [27]:
Q = tensor([1,0,0,0], dtype=f32).view(2,2)
K = tensor([1,0,0,0], dtype=f32).view(2,2)
V = tensor([0,1,0,0], dtype=f32).view(2,2)
sh_sz = 9

We need to decide how to split up shared memory. Let's just store the objects in SRAM by order of read (`Kj`, `Vj`, `Qi`, `Oi`, `mi`, `li`).

Then the offsets for each are:
- `Kj`: `0`
- `Vj`: `bc * d`
- `Qi`: `2*bc*d`
- `Oi`: `2*bc*d + br*d`
- `mi`: `2*bc*d + 2*br*d`
- `li`: `2*bc*d + 2*br*d + br`

In [28]:
def flash_attn_kernel(blockIdx, threadIdx, blockDim, shared, syncb, Q, K, V, O, m, l, n, d):
    # Q,K,V = attn matrices; O = output; m,l = running statistic; n x d = shape of Q/K/V/O

    print(f'Executing kernel: {blockIdx = } , {threadIdx = }')

    br,bc,_ = blockDim
    
    # split SRAM
    Kj = shared[                     : bc+d                ]
    Vj = shared[  bc+d               : 2*bc*d              ]
    Qi = shared[2*bc*d               : 2*bc*d +   br*d     ]
    Oi = shared[2*bc*d +   br*d      : 2*bc*d + 2*br*d     ]
    mi = shared[2*bc*d + 2*br*d      : 2*bc*d + 2*br*d + br]
    li = shared[2*bc*d + 2*br*d + br :                     ]
        
    for j in range(0, tc):
        # load from HBM -> SRAM
        Kj = K[j*bc:(j+1)*bc]
        Vj = V[j*bc:(j+1)*bc]

        for i in range(0, tr):
            # load from HBM -> SRAM
            Qi = Q[i*br : (i+1)*br]  
            Oi = O[i*br : (i+1)*br]    
            mi = m[i*br:(i+1)*br]
            li = l[i*br:(i+1)*br]

            # currently, loading is done by each thread. how can we split loading?
            
            for n,o in zip('Kj, Vj, Qi, Oi, mi, li'.split(', '), [Kj, Vj, Qi, Oi, mi, li]): print(n, o)
            return
            
            # compute
            Sij = Qi@Kj.t() / math.sqrt(d)
            mij = Sij.max(1).values
            Pij = (Sij - mij[:,None]).exp()
            lij = Pij.sum(1)

            mi_new = torch.max(mi, mij)
            li_new = (mi-mi_new).exp()*li + (mij-mi_new).exp()*lij

            Oi_new_part1 = torch.diag(li)@(mi - mi_new).exp() # br,br   @ br   -> br
            Oi_new_part1 = Oi_new_part1[:,None] * Oi          # br,None * br,d -> br,d
            assert Oi_new_part1.shape==(br,d)        
            Oi_new_part2 = (mij-mi_new).exp()[:,None]*Pij     # br,None * br,bc -> br,bc
            Oi_new_part2 = Oi_new_part2@Vj                    # br,bc   @ bc,d  -> br,d
            assert Oi_new_part2.shape==(br,d)
            Oi_new = Oi_new_part1 + Oi_new_part2              # br,d + br,d  -> br,d
            Oi_new = torch.diag(li_new).inverse() @ Oi_new    # br,br @ br,d -> br,d        
            assert Oi_new.shape==(br,d)

            # Write SRAM -> HBM
            O[i*br:(i+1)*br] = Oi_new
            m[i*br:(i+1)*br] = mi_new
            l[i*br:(i+1)*br] = li_new

In [29]:
flash_attn(flash_attn_kernel, Q, K, V, sh_sz)

Block sizes:   2, 2
Block numbers: 1, 1
Executing kernel: blockIdx = dim3(x=0, y=0, z=1) , threadIdx = dim3(x=0, y=0, z=1)
Kj Executing kernel: blockIdx = dim3(x=0, y=0, z=1) , threadIdx = dim3(x=1, y=0, z=1)
Kj Executing kernel: blockIdx = dim3(x=0, y=0, z=1) , threadIdx = dim3(x=0, y=1, z=1)
Kj Executing kernel: blockIdx = dim3(x=0, y=0, z=1) , threadIdx = dim3(x=1, y=1, z=1)
Kj tensor([1., 0.])
Vj tensor([1., 0.])
Vj tensor([1., 0.])
Vj tensor([1., 0.])
Vj tensor([0., 1.])
Qi tensor([0., 1.])
Qi tensor([0., 1.])
Qi tensor([0., 1.])
Qi tensor([1., 0.])
Oi tensor([1., 0.])
Oi tensor([1., 0.])
Oi tensor([0., 0.])
mi tensor([0., 0.])
mi tensor([0., 0.])
mi tensor([-inf, -inf])
li tensor([-inf, -inf])
li tensor([-inf, -inf])
li tensor([1., 0.])
Oi tensor([0., 0.])
tensor([0., 0.])
tensor([0., 0.])
mi tensor([-inf, -inf])
li tensor([0., 0.])
tensor([0., 0.])


The inputs Q,K,V,O are flattened, so we need to changed access to them from 2d into 1d

In [30]:
def flash_attn_kernel(blockIdx, threadIdx, blockDim, shared, syncb, Q, K, V, O, m, l, n, d):
    # Q,K,V = attn matrices; O = output; m,l = running statistic; n x d = shape of Q/K/V/O

    print(f'Executing kernel: {blockIdx = } , {threadIdx = }')

    br,bc,_ = blockDim
    
    # split SRAM
    Kj = shared[                     : bc+d                ]
    Vj = shared[  bc+d               : 2*bc*d              ]
    Qi = shared[2*bc*d               : 2*bc*d +   br*d     ]
    Oi = shared[2*bc*d +   br*d      : 2*bc*d + 2*br*d     ]
    mi = shared[2*bc*d + 2*br*d      : 2*bc*d + 2*br*d + br]
    li = shared[2*bc*d + 2*br*d + br :                     ]
        
    for j in range(0, tc):
        # load from HBM -> SRAM
        # remember, everything is 1d
        Kj = K[j*n*bc:(j+1)*n*bc]
        Vj = V[j*n*bc:(j+1)*n*bc]

        for i in range(0, tr):
            # load from HBM -> SRAM
            # remember, everything is 1d
            Qi = Q[i*n*br : (i+1)*n*br]  
            Oi = O[i*n*br : (i+1)*n*br]    
            mi = m[i*br:(i+1)*br]
            li = l[i*br:(i+1)*br]

            # currently, loading is done by each thread. how can we split loading?
            
            for n,o in zip('Kj, Vj, Qi, Oi, mi, li'.split(', '), [Kj, Vj, Qi, Oi, mi, li]): print(n, o)
            return
            
            # compute
            Sij = Qi@Kj.t() / math.sqrt(d)
            mij = Sij.max(1).values
            Pij = (Sij - mij[:,None]).exp()
            lij = Pij.sum(1)

            mi_new = torch.max(mi, mij)
            li_new = (mi-mi_new).exp()*li + (mij-mi_new).exp()*lij

            Oi_new_part1 = torch.diag(li)@(mi - mi_new).exp() # br,br   @ br   -> br
            Oi_new_part1 = Oi_new_part1[:,None] * Oi          # br,None * br,d -> br,d
            assert Oi_new_part1.shape==(br,d)        
            Oi_new_part2 = (mij-mi_new).exp()[:,None]*Pij     # br,None * br,bc -> br,bc
            Oi_new_part2 = Oi_new_part2@Vj                    # br,bc   @ bc,d  -> br,d
            assert Oi_new_part2.shape==(br,d)
            Oi_new = Oi_new_part1 + Oi_new_part2              # br,d + br,d  -> br,d
            Oi_new = torch.diag(li_new).inverse() @ Oi_new    # br,br @ br,d -> br,d        
            assert Oi_new.shape==(br,d)

            # Write SRAM -> HBM
            O[i*br:(i+1)*br] = Oi_new
            m[i*br:(i+1)*br] = mi_new
            l[i*br:(i+1)*br] = li_new

In [31]:
flash_attn(flash_attn_kernel, Q, K, V, sh_sz)

Block sizes:   2, 2
Block numbers: 1, 1
Executing kernel: blockIdx = dim3(x=0, y=0, z=1) , threadIdx = dim3(x=0, y=0, z=1)
Kj Executing kernel: blockIdx = dim3(x=0, y=0, z=1) , threadIdx = dim3(x=1, y=0, z=1)
Kj Executing kernel: blockIdx = dim3(x=0, y=0, z=1) , threadIdx = dim3(x=0, y=1, z=1)
Kj Executing kernel: blockIdx = dim3(x=0, y=0, z=1) , threadIdx = dim3(x=1, y=1, z=1)
Kj tensor([1., 0., 0., 0.])
Vj tensor([1., 0., 0., 0.])
Vj tensor([1., 0., 0., 0.])
Vj tensor([1., 0., 0., 0.])
Vj tensor([0., 1., 0., 0.])
Qi tensor([0., 1., 0., 0.])
Qi tensor([0., 1., 0., 0.])
Qi tensor([0., 1., 0., 0.])
Qi tensor([1., 0., 0., 0.])
Oi tensor([0., 0., 0., 0.])
mi tensor([1., 0., 0., 0.])
Oi tensor([1., 0., 0., 0.])
Oi tensor([1., 0., 0., 0.])
Oi tensor([0., 0., 0., 0.])
mi tensor([0., 0., 0., 0.])
mi tensor([-inf, -inf])
li tensor([0., 0., 0., 0.])
mi tensor([-inf, -inf])
li tensor([-inf, -inf])
li tensor([0., 0.])
tensor([0., 0.])
tensor([0., 0.])
tensor([-inf, -inf])
li tensor([0., 0.])


To Do:
- Split loading into shared memeory accros threads
- Figure out how to split comutation accross threads
- Add guards

Let's first split the computation naively: Each computation step will be split separately, and we sync in between.

In [32]:
def flash_attn_kernel(blockIdx, threadIdx, blockDim, shared, syncb, Q, K, V, O, m, l, n, d):
    # Q,K,V = attn matrices; O = output; m,l = running statistic; n x d = shape of Q/K/V/O

    print(f'Executing kernel: {blockIdx = } , {threadIdx = }')

    # do idx start at highest (ie z) or lowest (ie x) dim?
    bx,by,_ = blockIdx
    tx,ty,_ = threadIdx
    
    br,bc,_ = blockDim
    
    # split SRAM
    Kj = shared[                     : bc+d                ]
    Vj = shared[  bc+d               : 2*bc*d              ]
    Qi = shared[2*bc*d               : 2*bc*d +   br*d     ]
    Oi = shared[2*bc*d +   br*d      : 2*bc*d + 2*br*d     ]
    mi = shared[2*bc*d + 2*br*d      : 2*bc*d + 2*br*d + br]
    li = shared[2*bc*d + 2*br*d + br :                     ]
        
    for j in range(0, tc):
        # load from HBM -> SRAM
        # remember, everything is 1d
        Kj = K[j*n*bc:(j+1)*n*bc]
        Vj = V[j*n*bc:(j+1)*n*bc]

        for i in range(0, tr):
            # load from HBM -> SRAM
            # remember, everything is 1d
            Qi = Q[i*n*br : (i+1)*n*br]  
            Oi = O[i*n*br : (i+1)*n*br]    
            mi = m[i*br:(i+1)*br]
            li = l[i*br:(i+1)*br]
            # currently, loading is done by each thread. how can we split loading?

            
            # compute
            # Q: Where are the intermediate variables (Sij, mij, ...) allocated? In SRAM?

            # Sij = Qi@Kj.t() / math.sqrt(d)
            for ctn_d in range(d): Sij[tx,ty] += Qi[ty][ctn_d]*Kj[tx][ctn_d]
            Sij[tx,ty] /= math.sqrt(d)
            syncb.wait()     

            # mij = Sij.max(1).values
            mij[ty] = 0
            for ctn_d in range(d): mij[ty] = max(mij[ty], Sij[ty][ctn_d)]
            syncb.wait()

            continue
            
            Pij = (Sij - mij[:,None]).exp()
            syncb.wait()
            
            lij = Pij.sum(1)
            syncb.wait()
            
            mi_new = torch.max(mi, mij)
            li_new = (mi-mi_new).exp()*li + (mij-mi_new).exp()*lij

            Oi_new_part1 = torch.diag(li)@(mi - mi_new).exp() # br,br   @ br   -> br
            Oi_new_part1 = Oi_new_part1[:,None] * Oi          # br,None * br,d -> br,d
            assert Oi_new_part1.shape==(br,d)        
            Oi_new_part2 = (mij-mi_new).exp()[:,None]*Pij     # br,None * br,bc -> br,bc
            Oi_new_part2 = Oi_new_part2@Vj                    # br,bc   @ bc,d  -> br,d
            assert Oi_new_part2.shape==(br,d)
            Oi_new = Oi_new_part1 + Oi_new_part2              # br,d + br,d  -> br,d
            Oi_new = torch.diag(li_new).inverse() @ Oi_new    # br,br @ br,d -> br,d        
            assert Oi_new.shape==(br,d)

            # Write SRAM -> HBM
            O[i*br:(i+1)*br] = Oi_new
            m[i*br:(i+1)*br] = mi_new
            l[i*br:(i+1)*br] = li_new

SyntaxError: closing parenthesis ')' does not match opening parenthesis '[' (1398450284.py, line 46)

In [None]:
A = torch.rand(3,2)
B = torch.rand(3,2)

C = torch.zeros(3,3)

In [None]:
from itertools import product

n = A.shape[0]
for a,b in product(range(n), range(n)):
    for c in range(d):
        C[a,b] += A[a,c]*B[b,c]

In [None]:
torch.isclose(C, A@B.t()).all()