<a href="https://colab.research.google.com/github/Autobot37/jupyter/blob/main/attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

In [39]:
from dataclasses import dataclass
@dataclass
class config:
    n_embd: int = 4
    n_head: int = 2
    block_size: int = 4
    batch:int= 2
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

B = config.batch
T = config.block_size
C = config.n_embd
n_head = config.n_head
n_embd = config.n_embd

In [40]:
bias = torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size)

In [41]:
qkv = torch.randn((B,T,3*C))
# print("qkv", qkv)
Q, K, V = qkv.split(n_embd, dim=2)

In [42]:
qkv = qkv.clone().view(B*T*3*C)
print(len(qkv))
qx = [0]*B*T*C
kx = [0]*B*T*C
vx = [0]*B*T*C
print(len(qx))

for i in range(B):
  for j in range(T):
    for k in range(C):

      qx[i*T*C + j*C + k] = qkv[i*T*3*C + j*3*C + k]
      kx[i*T*C + j*C + k] = qkv[i*T*3*C + j*3*C + k + C]
      vx[i*T*C + j*C + k] = qkv[i*T*3*C + j*3*C + k + 2*C]

print(len(qx), len(Q.reshape(B*T*C)))
torch.allclose(torch.tensor(qx).view(B, T, C), Q.contiguous())

96
32
32 32


True

In [43]:
K = K.view(B, T, n_head, C // n_head)
Q = Q.view(B, T, n_head, C // n_head)
V = V.view(B, T, n_head, C // n_head)
# print("q", Q),print("k", K) ,print("V", V)

In [44]:

head_size = C // n_head
print("head_size", head_size)
for i in range(B):
  for j in range(T):
    for k in range(n_head):
      for l in range(head_size):
        if(qx[i*T*C + j*C + k*head_size + l] != Q[i][j][k][l]):
          print("incorrect")

print("correct")

head_size 2
correct


In [45]:
torch.allclose(torch.tensor(qx).view(B, T, n_head, head_size) ,Q.contiguous())

True

In [46]:
tensor_q = torch.tensor(qx).view(B, T, n_head, head_size).contiguous()
torch.allclose(tensor_q, Q.contiguous())


True

In [47]:
K = K.transpose(1, 2) # (B, nh, T, hs)
Q = Q.transpose(1, 2) # (B, nh, T, hs)
V = V.transpose(1, 2) # (B, nh, T, hs)

# print("q", Q),print("k", K) ,print("V", V)

In [48]:
head_size = C // n_head
tensor_q = torch.tensor(qx).clone().view(B, T, n_head, head_size)
tensor_qt = torch.zeros_like(tensor_q).view(B, n_head, T, head_size)

for i in range(B):
    for k in range(n_head):
        for j in range(T):
            for l in range(head_size):
                tensor_qt[i][k][j][l] = tensor_q[i][j][k][l]

print(torch.allclose(tensor_q.transpose(1,2).contiguous(), Q.contiguous()))
print(torch.allclose(tensor_qt, tensor_q.transpose(1, 2).contiguous()))

True
True


In [49]:
head_size = C // n_head

qxt = [0] * B * T * C

# Q[B,T,NH,HS] MATCHES = qx[B,T,NH,HS]
# Now Q became [B,NH,T,HS]
# qxt = [B,NH,T,HS]

print("head_size", head_size)
for i in range(B):
    for j in range(T):
        for k in range(n_head):
            for l in range(head_size):
                qxt[i * T * n_head * head_size + k * T * head_size + j * head_size + l] = qx[i * T * C + j * C + k * head_size + l]

print(torch.allclose(torch.tensor(qxt).view(B, n_head, T, head_size), Q.contiguous()))
# print(torch.tensor(qxt).view(B, n_head, T, head_size).contiguous(), "\n", Q.contiguous())


head_size 2
True


In [50]:
tensor_q = torch.tensor(qx).clone().view(B, T, n_head, head_size)
print(torch.allclose(tensor_q.transpose(1,2).contiguous(), Q.contiguous()))

True


In [51]:
def transpose_inplace(matrix):
    n = len(matrix)
    for i in range(n):
        for j in range(i + 1, n):
            matrix[i][j], matrix[j][i] = matrix[j][i], matrix[i][j]

matrix = [
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
]

transpose_inplace(matrix)
print(matrix)


[[1, 4, 7], [2, 5, 8], [3, 6, 9]]


In [52]:
idx2_list = []
# Q is shape B, nh, T, hs
print("head_size", head_size)
for i in range(B):
    for k in range(n_head):
        for j in range(k+1, T):
            for l in range(head_size):
                idx1 = i * T * C + j * C + k * head_size + l
                idx2 = i * n_head * T * head_size + k * T * head_size + j * head_size + l
                temp = qx[idx1]
                qx[idx1] = qx[idx2]
                qx[idx2] = temp

qx_tensor = torch.tensor(qx).view(B, n_head, T, head_size).contiguous()
print(torch.allclose(qx_tensor, Q.contiguous()))

head_size 2
True


In [53]:
att = (Q @ K.transpose(-2, -1))
# print("att", att)

In [54]:
#CUBLAS SGEMM

kt = [0]*B*T*C

#k orignial shape = B,T,NH,HS
#K required shape = B,NH,HS,T

for i in range(B):
    for j in range(T):
      for k in range(n_head):
        for l in range(head_size):
          from_idx = i * T * C + j * C + k * head_size + l
          to_idx =   i * T * C + k * T * head_size + l * T + j
          kt[to_idx] = kx[from_idx]

print(torch.allclose(torch.tensor(kt).view(B, n_head, head_size, T).contiguous(), K.transpose(-2, -1).contiguous()))
# torch.tensor(kt).view(B, n_head, head_size, T).contiguous(), K.transpose(-2, -1).contiguous()

True


In [56]:
attx = [0]*B*T*C

def mul(a,b):
  return a @ b

for b in range(B):
  xx = torch.tensor(qx[b*T*C : (b+1)*T*C]).view(n_head, T, head_size)
  yy = torch.tensor(kt[b*T*C : (b+1)*T*C]).view(n_head, head_size, T)
  res = mul(xx,yy)
  print(res.shape)
  attx[b*n_head*T*T : (b+1)*T*n_head*T] = res.view(-1)

tensor_attx = torch.tensor(attx).view(B, n_head, T, T)
print(torch.allclose(tensor_attx, att))

torch.Size([2, 4, 4])
torch.Size([2, 4, 4])
True


In [57]:
attx = [0]*B*T*C

def mul(a,b):
  return a @ b

xx = torch.tensor(qx).view(B*n_head, T, head_size)
yy = torch.tensor(kt).view(B*n_head, head_size, T)
res = mul(xx,yy)
print(res.shape)
attx = res.view(-1)

tensor_attx = torch.tensor(attx).view(B, n_head, T, T)
print(torch.allclose(tensor_attx, att))

torch.Size([4, 4, 4])
True


  tensor_attx = torch.tensor(attx).view(B, n_head, T, T)


In [58]:
#then softmax
#then last one is accumulating.

for i in range(B):
    for k in range(n_head):
        for j in range(k+1, T):
            for l in range(head_size):
                idx1 = i * T * C + j * C + k * head_size + l
                idx2 = i * n_head * T * head_size + k * T * head_size + j * head_size + l
                temp = vx[idx1]
                vx[idx1] = vx[idx2]
                vx[idx2] = temp

vx_tensor = torch.tensor(vx).view(B, n_head, T, head_size).contiguous()
print(torch.allclose(vx_tensor, V.contiguous()))

True


In [59]:
# att = att.masked_fill(bias[:,:,:T,:T] == 0, float('-inf'))
# print("masked att", att)

In [60]:
# att = F.softmax(att, dim=-1)
# print("softmaxed att", att)

In [61]:
y = att @ V # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
# print("y", y)

In [62]:
B,T,C

(2, 4, 4)

In [63]:
yx = [0]*B*T*C

for b in range(B):
  xx = torch.tensor(attx[b*n_head*T*T : (b+1)*n_head*T*T]).view(n_head, T, T)
  yy = torch.tensor(vx[b*T*C : (b+1)*T*C]).view(n_head, T, head_size)
  yx[b*T*C:(b+1)*T*C] = (xx @ yy).view(-1)

torch.allclose(torch.tensor(yx).view(B,n_head,T, head_size).contiguous(), y)


  xx = torch.tensor(attx[b*n_head*T*T : (b+1)*n_head*T*T]).view(n_head, T, T)


True

In [64]:

#from B,NH,T,HS
#TO B,T,NH,HS
yxt = [0]*B*T*C

for i in range(B):
    for j in range(n_head):
        for k in range(T):
            for l in range(head_size):
                idx1 = i * n_head * T * head_size + j * T * head_size + k * head_size + l
                idx2 = i * n_head * T * head_size + k * n_head * head_size + j * head_size + l
                yxt[idx2] = yx[idx1]

In [65]:
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# print("final", y)
# print(torch.tensor(yxt).view(B,T,C).contiguous())

In [66]:
torch.allclose(torch.tensor(yxt).view(B,T,C).contiguous(), y)

True