In [1]:
from numba import cuda

device = cuda.get_current_device()

# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html
print(f"Shared memory per multiprocessor: {device.MAX_SHARED_MEMORY_PER_MULTIPROCESSOR} bytes")
print(f"Shared memory per block: {device.MAX_SHARED_MEMORY_PER_BLOCK} bytes")
print(f"Registers per block: {device.MAX_REGISTERS_PER_BLOCK}")
print(f"Threads per block: {device.MAX_THREADS_PER_BLOCK}")
print(f"Number of multiprocessors: {device.MULTIPROCESSOR_COUNT}")

Shared memory per multiprocessor: 98304 bytes
Shared memory per block: 49152 bytes
Registers per block: 65536
Threads per block: 1024
Number of multiprocessors: 80


In [3]:
import math
import torch

sizeof = lambda x: torch.tensor(0, dtype=x).element_size()

B, C, SEQ_LEN, D_HEAD = (32, 32, 128, 128)
threadsperblock = (D_HEAD, 1, 1) # Should be max 1024
blockspergrid = (math.ceil(SEQ_LEN * D_HEAD / threadsperblock[0]), B, C)
shared_memory_size = (D_HEAD * sizeof(torch.float64)) * 3 + (SEQ_LEN**2 * sizeof(torch.float64))
    
print(f"blockspergrid: {blockspergrid}, threadsperblock: {threadsperblock}, shared_memory_size: {shared_memory_size}")

blockspergrid: (128, 32, 32), threadsperblock: (128, 1, 1), shared_memory_size: 134144


In [6]:
(SEQ_LEN**2 * sizeof(torch.float64))

131072

In [5]:
import torch
# change print precision
torch.set_printoptions(precision=4, sci_mode=False)

B, C, SEQ_LEN, D_HEAD = (1, 1, 2, 4)
Q = (torch.arange(B * C * SEQ_LEN * D_HEAD, dtype=torch.float64).reshape(B, C, SEQ_LEN, D_HEAD) + 1.) / 10
K = (torch.arange(B * C * SEQ_LEN * D_HEAD, dtype=torch.float64).reshape(B, C, SEQ_LEN, D_HEAD) + 1.) / 10
V = (torch.arange(B * C * SEQ_LEN * D_HEAD, dtype=torch.float64).reshape(B, C, SEQ_LEN, D_HEAD) + 1.) / 10
Q, Q.shape

  from .autonotebook import tqdm as notebook_tqdm


(tensor([[[[0.1000, 0.2000, 0.3000, 0.4000],
           [0.5000, 0.6000, 0.7000, 0.8000]]]], dtype=torch.float64),
 torch.Size([1, 1, 2, 4]))

In [6]:
m = torch.full((B, C, SEQ_LEN, 1), -torch.inf)
l = torch.zeros((B, C, SEQ_LEN, 1))

prev_rowmax = m 
prev_denominator = l
Oi = l.clone()
prev_rowmax.shape

torch.Size([1, 1, 2, 1])

In [14]:
Q.flatten()

tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000, 0.8000],
       dtype=torch.float64)

In [13]:
K.transpose(-2, -1).flatten()

tensor([0.1000, 0.5000, 0.2000, 0.6000, 0.3000, 0.7000, 0.4000, 0.8000],
       dtype=torch.float64)

In [8]:
S = torch.matmul(Q, K.transpose(2, 3))
S.shape, S

(torch.Size([1, 1, 2, 2]),
 tensor([[[[0.3000, 0.7000],
           [0.7000, 1.7400]]]], dtype=torch.float64))

In [55]:
tile_rowmax = torch.max(S, dim=-1).values[..., None]
tile_rowmax, tile_rowmax.shape

(tensor([[[[0.3000]]]], dtype=torch.float64), torch.Size([1, 1, 1, 1]))

In [33]:
tile_numerator = torch.exp(S - tile_rowmax)
tile_numerator, tile_numerator.shape

(tensor([[[[0.9900, 1.0000],
           [0.9802, 1.0000]]]], dtype=torch.float64),
 torch.Size([1, 1, 2, 2]))

In [40]:
tile_denominator = torch.sum(tile_numerator, dim=-1)[..., None]
tile_denominator

tensor([[[[1.9900],
          [1.9802]]]], dtype=torch.float64)

In [41]:
new_rowmax = torch.max(torch.column_stack([prev_rowmax, tile_rowmax]), dim=1).values[:, None]
new_rowmax

tensor([[[[0.0200],
          [0.0400]]]], dtype=torch.float64)

In [42]:
update_prev_exponent = torch.exp(prev_rowmax - new_rowmax)
update_prev_exponent

tensor([[[[0.],
          [0.]]]], dtype=torch.float64)

In [43]:
new_denominator = prev_denominator * update_prev_exponent + torch.exp(tile_rowmax - new_rowmax) * tile_denominator
new_denominator

tensor([[[[1.9900],
          [1.9802]]]], dtype=torch.float64)

In [44]:
left = (Oi * (prev_denominator * update_prev_exponent) / new_denominator)
left

tensor([[[[0.],
          [0.]]]], dtype=torch.float64)

In [45]:
right = ((tile_numerator * torch.exp(tile_rowmax - new_rowmax)) / new_denominator)
right

tensor([[[[0.4975, 0.5025],
          [0.4950, 0.5050]]]], dtype=torch.float64)

In [48]:
V

tensor([[[[0.1000],
          [0.2000]]]], dtype=torch.float64)

In [49]:
right @ V

tensor([[[[0.1502],
          [0.1505]]]], dtype=torch.float64)

In [None]:
print(torch.sum(Q[0, 0, 0, :] * K.transpose(2,3)[0, 0, :, 0]), S[0, 0, 0, 0])
print(torch.sum(Q[0, 0, 0, :] * K.transpose(2,3)[0, 0, :, 1]), S[0, 0, 0, 1])
print(torch.sum(Q[0, 0, 1, :] * K.transpose(2,3)[0, 0, :, 0]), S[0, 0, 1, 0])
print(torch.sum(Q[0, 0, 1, :] * K.transpose(2,3)[0, 0, :, 1]), S[0, 0, 1, 1])

In [39]:
# S.shape, V.shape
torch.matmul(S, V)

tensor([[[[   946.5,    986.3,   1026.1,   1065.9,   1105.6,   1145.4,   1185.2,
             1225.0,   1264.7,   1304.5,   1344.3,   1384.1,   1423.8,   1463.6,
             1503.4,   1543.2,   1582.9,   1622.7,   1662.5,   1702.3,   1742.0,
             1781.8,   1821.6,   1861.4,   1901.2,   1940.9,   1980.7,   2020.5,
             2060.3,   2100.0,   2139.8,   2179.6],
          [  2602.3,   2708.7,   2815.0,   2921.3,   3027.7,   3134.0,   3240.4,
             3346.7,   3453.0,   3559.4,   3665.7,   3772.0,   3878.4,   3984.7,
             4091.0,   4197.4,   4303.7,   4410.0,   4516.4,   4622.7,   4729.1,
             4835.4,   4941.7,   5048.1,   5154.4,   5260.7,   5367.1,   5473.4,
             5579.7,   5686.1,   5792.4,   5898.8]]],


        [[[ 42031.5,  42534.1,  43036.7,  43539.3,  44042.0,  44544.6,  45047.2,
            45549.8,  46052.4,  46555.1,  47057.7,  47560.3,  48062.9,  48565.6,
            49068.2,  49570.8,  50073.4,  50576.1,  51078.7,  51581.3,  52083.9,
 

In [38]:
S[0, 0, 0, 0] * V[0, 0, 0, :] + S[0, 0, 0, 1] * V[0, 0, 1, :]

tensor([ 946.5,  986.3, 1026.1, 1065.9, 1105.6, 1145.4, 1185.2, 1225.0, 1264.7,
        1304.5, 1344.3, 1384.1, 1423.8, 1463.6, 1503.4, 1543.2, 1582.9, 1622.7,
        1662.5, 1702.3, 1742.0, 1781.8, 1821.6, 1861.4, 1901.2, 1940.9, 1980.7,
        2020.5, 2060.3, 2100.0, 2139.8, 2179.6], dtype=torch.float64)

In [None]:
def ref_attention(Q, K, V):
    # B, C, SEQ_LEN, D_HEAD
    S = torch.matmul(Q, K.transpose(2, 3))
    # (3, 2, SEQ_LEN, SEQ_LEN)
    print(S.shape)
    # Compute softmax for each row of S
    P = torch.softmax(S, dim=2)
    O =  torch.matmul(P, V)
    return O

ref = ref_attention(Q.clone(), K.clone(), V.clone())
print(ref)

In [None]:
import torch

seq_len = 4
d_head = 4

Q = (torch.arange(seq_len * d_head).reshape(seq_len, d_head) + 1.)  / 10
K = (torch.arange(seq_len * d_head).reshape(seq_len, d_head) + 1.)  / 10
V = (torch.arange(seq_len * d_head).reshape(seq_len, d_head) + 1.)  / 10


In [None]:
Q

In [None]:
torch.softmax(Q @ K.T, dim=1) @ V

In [None]:
torch.softmax(Q[0, :] @ K.T, dim=0)

# One-shot

In [None]:
seq_len = 4
d_head = 4

Q = (torch.arange(seq_len * d_head).reshape(seq_len, d_head) + 1.)  / 10
K = (torch.arange(seq_len * d_head).reshape(seq_len, d_head) + 1.)  / 10
V = (torch.arange(seq_len * d_head).reshape(seq_len, d_head) + 1.)  / 10
O = torch.zeros((seq_len, d_head))


m = torch.full((seq_len, 1), -torch.inf)
l = torch.zeros((seq_len, 1))

prev_rowmax = m  # shape Br x 1
prev_denominator = l  # shape Br x 1

# print("====== Q =======")
# print(Q)
# print()

# print("====== Sij =======")
Sij = Q @ K # TODO: I removed the transpose here to match GPU version 
# print(Sij)
# print()

print("====== tile_rowmax =======")
tile_rowmax = torch.max(Sij, dim=1).values[:, None]
print(tile_rowmax)
print()

print("====== tile_numerator =======")
tile_numerator = torch.exp(Sij - tile_rowmax)
print(tile_numerator)
print()

print("====== tile_denominator =======")
tile_denominator = torch.sum(tile_numerator, dim=1)[:, None]
print(tile_denominator)
print()

print("====== new_rowmax =======")
new_rowmax = torch.max(torch.column_stack([prev_rowmax, tile_rowmax]), dim=1).values[:, None]
print(new_rowmax)
print()

print("====== update_prev_exponent  =======")
update_prev_exponent = torch.exp(prev_rowmax - new_rowmax)
print(update_prev_exponent)

print("====== new_denominator =======")
new_denominator = prev_denominator * update_prev_exponent + torch.exp(tile_rowmax - new_rowmax) * tile_denominator
print(new_denominator)
print()

# print("====== O =======")
# O = (li * torch.exp(mi - mi_new) * O / li_new) + (torch.exp(tile_rowmax - mi_new) * pij_hat / li_new) @ V
# print(O)
# print()


In [None]:
0.2454 * 0.1 + 0.3660 * 0.5

In [None]:
import numpy as np
np.exp(-float("inf"))

# 2-shot on K

In [None]:
seq_len = 4
d_head = 4

Q = (torch.arange(seq_len * d_head).reshape(seq_len, d_head) + 1.)  / 10
K = (torch.arange(seq_len * d_head).reshape(seq_len, d_head) + 1.)  / 10
V = (torch.arange(seq_len * d_head).reshape(seq_len, d_head) + 1.)  / 10
O = torch.zeros((seq_len, d_head))


m = torch.full((seq_len, 1), -torch.inf)
l = torch.zeros((seq_len, 1))

mi = m  # shape Br x 1
li = l  # shape Br x 1

# print("1st pass on K")
print("====== Q =======")
print(Q)
print()

print("====== Sij =======")
Sij = Q @ K[0:2, :].T
print(Sij)
print()

print("====== mij_hat =======")
mij_hat = torch.max(Sij, dim=1).values[:, None]
print(mij_hat)
print()

print("====== pij_hat =======")
pij_hat = torch.exp(Sij - mij_hat)
print(pij_hat)
print()

print("====== lij_hat =======")
lij_hat = torch.sum(pij_hat, dim=1)[:, None]
print(lij_hat)
print()

print("====== mi_new =======")
mi_new = torch.max(torch.column_stack([mi, mij_hat]), dim=1).values[:, None]
print(mi_new)
print()

print("====== li_new =======")
li_new = torch.exp(mi - mi_new) * li + torch.exp(mij_hat - mi_new) * lij_hat
print(li_new)
print()

print("====== O =======")
O = (li * torch.exp(mi - mi_new) * O / li_new) + (torch.exp(mij_hat - mi_new) * pij_hat / li_new) @ V[0:2, :]
print(O)
print()

mi = mi_new


In [None]:
# print("2nd pass on K")
print("====== Q =======")
print(Q)
print()

print("====== Sij =======")
Sij = Q @ K[2:4, :].T
print(Sij)
print()

print("====== mij_hat =======")
mij_hat = torch.max(Sij, dim=1).values[:, None]
print(mij_hat)
print()

print("====== pij_hat =======")
pij_hat = torch.exp(Sij - mij_hat)
print(pij_hat)
print()

print("====== lij_hat =======")
lij_hat = torch.sum(pij_hat, dim=1)[:, None]
print(lij_hat)
print()

print("====== mi_new =======")
mi_new = torch.max(torch.column_stack([mi, mij_hat]), dim=1).values[:, None]
print(mi_new)
print()

print("====== li_new =======")
li_new = torch.exp(mi - mi_new) * li + torch.exp(mij_hat - mi_new) * lij_hat
print(li_new)
print()

print("====== O =======")
O = (li * torch.exp(mi - mi_new) * O / li_new) + (torch.exp(mij_hat - mi_new) * pij_hat / li_new) @ V[2:4, :]
print(O)
print()

In [None]:
print(torch.exp(mi - mi_new) * li)
print(torch.exp(mij_hat - mi_new) * lij_hat)
print(li_new)

In [None]:
1.6703 + 0.7505

In [None]:
tensor([[0.8915, 0.9915, 1.0915, 1.1915],
        [1.1067, 1.2067, 1.3067, 1.4067],
        [1.2103, 1.3103, 1.4103, 1.5103],
        [1.2566, 1.3566, 1.4566, 1.5566]])

# Softmax

> https://arxiv.org/pdf/1805.02867.pdf

In [None]:
torch.manual_seed(0)

x = torch.arange(10, dtype=torch.float32)
x

In [None]:
def naive_softmax(vec):
   return torch.exp(vec) / torch.sum(torch.exp(vec))

torch.allclose(torch.softmax(x.clone(), dim=0), naive_softmax(x.clone()))

In [None]:
def safe_softmax(vec):
    # Compute row max in 1st pass
    row_max = torch.max(vec)
    # Compute denominator in 2nd pass
    denominator = torch.sum(torch.exp(vec - row_max))
    return torch.exp(vec - row_max) / denominator


torch.allclose(torch.softmax(x.clone(), dim=0), safe_softmax(x.clone()))

In [None]:
#TODO: Cf goodnotes ipad "Flash attention   "

def online_softmax(vec):
    N = len(vec)
    rowmax = -float("inf")
    denominator = 0
    # Compute the max and denominator in single pass
    for i in range(N):
        prev_rowmax = rowmax
        rowmax = max(prev_rowmax, vec[i])
        denominator = denominator * torch.exp(prev_rowmax - rowmax) + torch.exp(vec[i] - rowmax)
        print(f"prev_rowmax = {prev_rowmax} | rowmax = {rowmax} | denominator = {denominator}")

    print(torch.exp(vec - rowmax))
    print(denominator)
    return torch.exp(vec - rowmax) / denominator

torch.allclose(torch.softmax(x.clone(), dim=0), online_softmax(x.clone()))

In [None]:
print(x)

In [None]:
online_softmax(x.clone())

In [None]:
torch.exp(x - torch.max(x))

In [None]:
torch.sum(torch.exp(x - torch.max(x)))

In [None]:
torch.exp(x - torch.max(x)) / torch.sum(torch.exp(x - torch.max(x)))