In [1]:
import torch

In [54]:
# Problem sizes: batch, heads, sequence length, Q/K feature dim, V feature dim.
B = 1
H = 1
N = 2048
D = 256
DV = 64

# Prefer the GPU, but fall back to CPU so this reference cell also runs on
# machines without CUDA. With CUDA available the behaviour is unchanged.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Inputs are created in bfloat16 and only then widened to float32, so every
# value is bf16-rounded before the float32 scaling is applied.
# NOTE(review): presumably this mirrors the bf16 inputs of the kernel under
# test -- confirm; large arange values are not exactly representable in bf16.
Q = torch.ones(B*H*N*D, dtype=torch.bfloat16, device=DEVICE).reshape(B, H, N, D).to(torch.float32)/(D*DV)
K = torch.arange(B*H*N*D, dtype=torch.bfloat16, device=DEVICE).reshape(B, H, N, D).to(torch.float32)/(D*DV*2)
V = torch.ones(B*H*N*DV, dtype=torch.bfloat16, device=DEVICE).reshape(B, H, N, DV).to(torch.float32)

# Broadcast layout: q, k -> (B, H, N, 1, D); v -> (B, H, N, DV, 1).
q, k, v = Q.unsqueeze(-2), K.unsqueeze(-2), V.unsqueeze(-1)
# Running sum along the sequence axis of the per-position products v * k
# -> (B, H, N, DV, D).
kv_state = (k * v).cumsum(dim=2)
# Contract each position's running state with its query over D -> (B, H, N, DV).
out = (q * kv_state).sum(dim=-1)
# State at the last sequence position -> (B, H, DV, D).
last_kv = kv_state[:, :, -1]

print(f"{q.shape=}, {k.shape=}, {v.shape=}")
print(f"{kv_state.shape=}, {out.shape=}")

q.shape=torch.Size([1, 1, 2048, 1, 256]), k.shape=torch.Size([1, 1, 2048, 1, 256]), v.shape=torch.Size([1, 1, 2048, 64, 1])
kv_state.shape=torch.Size([1, 1, 2048, 64, 256]), out.shape=torch.Size([1, 1, 2048, 64])


In [55]:
# Last entry along dim 2 of the final KV state, shown across all of dim 3.
last_kv[:, :, -1, :]

tensor([[[16376.0000, 16383.0000, 16383.0000, 16383.0000, 16383.0000,
          16383.0020, 16383.0020, 16383.0020, 16383.0020, 16383.0059,
          16383.0059, 16383.0059, 16383.0078, 16383.0078, 16383.0078,
          16383.0078, 16383.0078, 16383.0234, 16383.0234, 16383.0234,
          16383.0234, 16383.0254, 16383.0254, 16383.0254, 16383.0293,
          16383.0293, 16383.0293, 16383.0293, 16383.0312, 16383.0312,
          16383.0312, 16383.0312, 16383.0312, 16383.0938, 16383.0938,
          16383.0938, 16383.0938, 16383.0957, 16383.0957, 16383.0957,
          16383.0957, 16383.0996, 16383.0996, 16383.0996, 16383.1016,
          16383.1016, 16383.1016, 16383.1016, 16383.1172, 16383.1172,
          16383.1172, 16383.1172, 16383.1172, 16383.1191, 16383.1191,
          16383.1191, 16383.1230, 16383.1230, 16383.1230, 16383.1230,
          16383.1250, 16383.1250, 16383.1250, 16383.1250, 16383.1250,
          16383.3750, 16383.3750, 16383.3750, 16383.3750, 16383.3770,
          16383.3770

In [3]:
def pytorch_test(Q, K, V, TESTNAME='all'):
    """Reference causal attention without normalisation, via an explicit score matrix.

    Builds the full ``Q @ K^T`` score matrix, zeroes every entry above the
    main diagonal (position ``n`` only sees positions ``m <= n``), and
    multiplies the masked scores by ``V``.

    Args:
        Q: tensor indexed as ``(b, h, n, d)`` by the einsum below.
        K: tensor indexed as ``(b, h, m, d)``.
        V: tensor indexed as ``(b, h, m, dv)``.
        TESTNAME: unused; kept so existing calls keep working.

    Returns:
        Tensor of shape ``(b, h, n, dv)`` cast to ``torch.bfloat16``.
    """

    def make_causal(X):
        """Return X with everything above the main diagonal of its last two dims zeroed."""
        # Debug trace of the score-matrix shape; kept so the cell's printed
        # output stays the same.
        print(f"{X.shape=}")
        # torch.tril runs on X's own device. The previous hand-built mask was
        # created on CPU and expanded to (b, h, n, n) before indexing X.
        return torch.tril(X)

    ATT = make_causal(torch.einsum("bhnd,bhmd->bhnm", Q, K))
    out = torch.einsum("bhnm,bhmd->bhnd", ATT, V).to(torch.bfloat16)
    return out

# Run the reference on the current Q, K, V and display the output shape.
o = pytorch_test(Q, K, V)
o.shape

X.shape=torch.Size([1, 1, 2048, 2048])


torch.Size([1, 1, 2048, 64])

In [4]:
# Values of the reference output at index [0, 0, 0].
print(o[0, 0, 0])

tensor([0.0000e+00, 1.2666e-06, 2.5332e-06, 3.7998e-06, 5.0664e-06, 6.3181e-06,
        7.5996e-06, 8.8215e-06, 1.0133e-05, 1.1384e-05, 1.2636e-05, 1.3888e-05,
        1.5199e-05, 1.6451e-05, 1.7643e-05, 1.8954e-05, 2.0266e-05, 2.1458e-05,
        2.2769e-05, 2.3961e-05, 2.5272e-05, 2.6584e-05, 2.7776e-05, 2.9087e-05,
        3.0398e-05, 3.1710e-05, 3.2902e-05, 3.4094e-05, 3.5286e-05, 3.6716e-05,
        3.7909e-05, 3.9101e-05, 4.0531e-05, 4.1723e-05, 4.2915e-05, 4.4346e-05,
        4.5538e-05, 4.6730e-05, 4.7922e-05, 4.9353e-05, 5.0545e-05, 5.1737e-05,
        5.3167e-05, 5.4359e-05, 5.5552e-05, 5.6982e-05, 5.8174e-05, 5.9366e-05,
        6.0797e-05, 6.1989e-05, 6.3419e-05, 6.4373e-05, 6.5804e-05, 6.7234e-05,
        6.8188e-05, 6.9618e-05, 7.0572e-05, 7.2002e-05, 7.3433e-05, 7.4387e-05,
        7.5817e-05, 7.7248e-05, 7.8201e-05, 7.9632e-05], device='cuda:0',
       dtype=torch.bfloat16)


In [34]:
# TK outputs: refresh the printouts folder if you are seeing strange results.
import os

path = "printouts"


def _load_tk_output(filename, shape=(1, 1, 2048, 64)):
    """Read a whitespace-separated text dump of floats from `path` into a tensor.

    Args:
        filename: name of the file inside the `path` folder.
        shape: target shape; the file must hold exactly that many numbers,
            otherwise `reshape` raises.

    Returns:
        A tensor of the requested shape holding the parsed values.
    """
    with open(os.path.join(path, filename), "r") as f:
        values = [float(x) for x in f.read().split()]
    return torch.tensor(values).reshape(shape)


# The earlier bare `os.listdir(path)` call was dropped: its result was never
# displayed or used.
# load o.txt
o_tk = _load_tk_output("o.txt")
# load o_ref.txt
o_ref_tk = _load_tk_output("o_ref.txt")


In [28]:
# Shape of the loaded o_ref.txt dump, then its values at index [0, 0, 0].
print(o_ref_tk.shape)
o_ref_tk[0, 0, 0]

torch.Size([1, 1, 2048, 64])


tensor([0.0000e+00, 1.2666e-06, 2.5332e-06, 3.7998e-06, 5.0664e-06, 6.3181e-06,
        7.5996e-06, 8.8215e-06, 1.0133e-05, 1.1384e-05, 1.2636e-05, 1.3888e-05,
        1.5199e-05, 1.6451e-05, 1.7643e-05, 1.8954e-05, 2.0266e-05, 2.1458e-05,
        2.2769e-05, 2.3961e-05, 2.5272e-05, 2.6584e-05, 2.7776e-05, 2.9087e-05,
        3.0398e-05, 3.1710e-05, 3.2902e-05, 3.4094e-05, 3.5286e-05, 3.6717e-05,
        3.7909e-05, 3.9101e-05, 4.0531e-05, 4.1723e-05, 4.2915e-05, 4.4346e-05,
        4.5538e-05, 4.6730e-05, 4.7922e-05, 4.9353e-05, 5.0545e-05, 5.1737e-05,
        5.3167e-05, 5.4359e-05, 5.5551e-05, 5.6982e-05, 5.8174e-05, 5.9366e-05,
        6.0797e-05, 6.1989e-05, 6.3419e-05, 6.4373e-05, 6.5803e-05, 6.7234e-05,
        6.8188e-05, 6.9618e-05, 7.0572e-05, 7.2002e-05, 7.3433e-05, 7.4387e-05,
        7.5817e-05, 7.7248e-05, 7.8201e-05, 7.9632e-05])

In [29]:
# Shape of the loaded o.txt dump, then its values at index [0, 0, 0].
print(o_tk.shape)
o_tk[0, 0, 0]

torch.Size([1, 1, 2048, 64])


tensor([0.0000e+00, 1.2666e-06, 2.5332e-06, 3.7998e-06, 5.0664e-06, 6.3181e-06,
        7.5996e-06, 8.8811e-06, 1.0133e-05, 1.1384e-05, 1.2636e-05, 1.3947e-05,
        1.5199e-05, 1.6451e-05, 1.7762e-05, 1.8954e-05, 2.0266e-05, 2.1577e-05,
        2.2769e-05, 2.4080e-05, 2.5272e-05, 2.6584e-05, 2.7895e-05, 2.9087e-05,
        3.0398e-05, 3.1710e-05, 3.2902e-05, 3.4094e-05, 3.5524e-05, 3.6717e-05,
        3.7909e-05, 3.9339e-05, 4.0531e-05, 4.1723e-05, 4.3154e-05, 4.4346e-05,
        4.5538e-05, 4.6969e-05, 4.8161e-05, 4.9353e-05, 5.0545e-05, 5.1975e-05,
        5.3167e-05, 5.4359e-05, 5.5790e-05, 5.6982e-05, 5.8174e-05, 5.9605e-05,
        6.0797e-05, 6.1989e-05, 6.3419e-05, 6.4373e-05, 6.5803e-05, 6.7234e-05,
        6.8188e-05, 6.9618e-05, 7.1049e-05, 7.2002e-05, 7.3433e-05, 7.4863e-05,
        7.5817e-05, 7.7248e-05, 7.8678e-05, 7.9632e-05])

In [30]:
# arange.txt is read as one flat list of floats and cut into four consecutive
# segments, in this order: q, k, v, o.
fpath = "arange.txt"
with open(fpath, "r") as f:
    arange_tensor = [float(tok) for tok in f.read().split()]

# Number of values in each segment.
num_q_elements = B*H*N*D
num_k_elements = B*H*N*D
num_v_elements = B*H*N*DV
num_o_elements = B*H*N*DV

# Cumulative offsets marking where each segment starts and where the last one ends.
k_start = num_q_elements
v_start = k_start + num_k_elements
o_start = v_start + num_v_elements
o_stop = o_start + num_o_elements

q_in = torch.tensor(arange_tensor[:k_start]).reshape(B, H, N, D)
k_in = torch.tensor(arange_tensor[k_start:v_start]).reshape(B, H, N, D)
v_in = torch.tensor(arange_tensor[v_start:o_start]).reshape(B, H, N, DV)
o_in = torch.tensor(arange_tensor[o_start:o_stop]).reshape(B, H, N, DV)


In [32]:
# Values of the `o` segment loaded from arange.txt at index [0, 0, 0].
o_in[0, 0, 0]

tensor([0.0000e+00, 1.2666e-06, 2.5332e-06, 3.7998e-06, 5.0664e-06, 6.3181e-06,
        7.5996e-06, 8.8215e-06, 1.0133e-05, 1.1384e-05, 1.2636e-05, 1.3888e-05,
        1.5199e-05, 1.6451e-05, 1.7643e-05, 1.8954e-05, 2.0266e-05, 2.1458e-05,
        2.2769e-05, 2.3961e-05, 2.5272e-05, 2.6584e-05, 2.7776e-05, 2.9087e-05,
        3.0398e-05, 3.1710e-05, 3.2902e-05, 3.4094e-05, 3.5286e-05, 3.6716e-05,
        3.7909e-05, 3.9101e-05, 4.0531e-05, 4.1723e-05, 4.2915e-05, 4.4346e-05,
        4.5538e-05, 4.6730e-05, 4.7922e-05, 4.9353e-05, 5.0545e-05, 5.1737e-05,
        5.3167e-05, 5.4359e-05, 5.5552e-05, 5.6982e-05, 5.8174e-05, 5.9366e-05,
        6.0797e-05, 6.1989e-05, 6.3419e-05, 6.4373e-05, 6.5804e-05, 6.7234e-05,
        6.8188e-05, 6.9618e-05, 7.0572e-05, 7.2002e-05, 7.3433e-05, 7.4387e-05,
        7.5817e-05, 7.7248e-05, 7.8201e-05, 7.9632e-05])

In [33]:
# Values of the PyTorch reference output at index [0, 0, 0], for side-by-side comparison.
o[0, 0, 0]

tensor([0.0000e+00, 1.2666e-06, 2.5332e-06, 3.7998e-06, 5.0664e-06, 6.3181e-06,
        7.5996e-06, 8.8215e-06, 1.0133e-05, 1.1384e-05, 1.2636e-05, 1.3888e-05,
        1.5199e-05, 1.6451e-05, 1.7643e-05, 1.8954e-05, 2.0266e-05, 2.1458e-05,
        2.2769e-05, 2.3961e-05, 2.5272e-05, 2.6584e-05, 2.7776e-05, 2.9087e-05,
        3.0398e-05, 3.1710e-05, 3.2902e-05, 3.4094e-05, 3.5286e-05, 3.6716e-05,
        3.7909e-05, 3.9101e-05, 4.0531e-05, 4.1723e-05, 4.2915e-05, 4.4346e-05,
        4.5538e-05, 4.6730e-05, 4.7922e-05, 4.9353e-05, 5.0545e-05, 5.1737e-05,
        5.3167e-05, 5.4359e-05, 5.5552e-05, 5.6982e-05, 5.8174e-05, 5.9366e-05,
        6.0797e-05, 6.1989e-05, 6.3419e-05, 6.4373e-05, 6.5804e-05, 6.7234e-05,
        6.8188e-05, 6.9618e-05, 7.0572e-05, 7.2002e-05, 7.3433e-05, 7.4387e-05,
        7.5817e-05, 7.7248e-05, 7.8201e-05, 7.9632e-05], device='cuda:0',
       dtype=torch.bfloat16)