In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [33]:
vocab_size = 27
head_size = 4
block_size = 6  # MAX POSSIBLE SEQ LENGTH
n_embd = 4


class MyTorchModule(nn.Module):
    """Single causal self-attention head, instrumented with prints.

    The module embeds token ids, projects them to keys/queries/values and
    applies lower-triangular (causal) masked attention. No 1/sqrt(head_size)
    scaling is applied to the attention scores; every intermediate tensor is
    printed so the computation can be followed step by step.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones matrix; registered as a buffer so it is part
        # of the module state without being a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        print("Tril Buffer ",self.tril)

    def forward(self, token_seq):
        """Run causal attention over ``token_seq`` and return the result.

        Args:
            token_seq: integer tensor of token ids with shape ``(T,)`` or,
                with leading batch dimensions, ``(..., T)``. ``T`` must not
                exceed ``block_size`` because the causal mask is sliced from
                a ``block_size x block_size`` buffer.

        Returns:
            Tensor of shape ``(..., T, head_size)``: for each position, the
            attention-weighted sum of the values at that position and all
            earlier positions.
        """
        x = self.token_embedding_table(token_seq)  # (..., T, n_embd)
        # Read T from the second-to-last axis so optional leading batch
        # dimensions are accepted as well as a plain (T, n_embd) input.
        T = x.shape[-2]
        print("Current Sequence Length ",T)
        k = self.key(x)    # (..., T, head_size)
        q = self.query(x)  # (..., T, head_size)
        v = self.value(x)  # (..., T, head_size)
        print("Called Forward with Input ",x)
        # Not doing any scaling here; because the focus is just on KV-caching
        wei = q @ k.transpose(-2, -1)  # (..., T, T)
        print("Weight = ", wei)
        # Positions above the diagonal are future tokens: set them to -inf so
        # softmax assigns them zero weight.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (..., T, T)
        print("Weight After Mask Filling ", wei)
        wei = F.softmax(wei, dim=-1)
        print("Raw Values to Offer for Attention ",v)
        print("Final Weights for Value Accumulation ",wei)
        out = wei @ v  # (..., T, T) @ (..., T, head_size) => (..., T, head_size)
        print("Accumulated Values for Attention  ", out)
        # Hand the result back to the caller; without this the call evaluates
        # to None and the attention output is lost.
        return out


In [34]:
# Build one attention head; construction prints the causal mask buffer.
attention_model = MyTorchModule()

Tril Buffer  tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [35]:
# First pass: a 4-token sequence.
curr_token_seq = [10, 12, 24, 3]
ix = torch.tensor(curr_token_seq, dtype=torch.long)
attention_output1 = attention_model(ix)

Current Sequence Length  4
Called Forward with Input  tensor([[-0.1114,  0.2369,  1.2289,  0.0469],
        [ 1.1885,  0.8388,  0.4477, -1.8496],
        [-1.3062,  0.8522, -0.7679,  1.1046],
        [-1.3318,  0.5408,  0.4993,  0.5870]], grad_fn=<EmbeddingBackward0>)
Weight =  tensor([[ 0.1432,  0.7684, -0.5193, -0.2819],
        [-0.0383, -0.8364,  0.3031,  0.2034],
        [ 0.0155,  0.8748, -0.3244, -0.2943],
        [ 0.2540,  1.1855, -0.8233, -0.4735]], grad_fn=<MmBackward0>)
Weight After Mask Filling  tensor([[ 0.1432,    -inf,    -inf,    -inf],
        [-0.0383, -0.8364,    -inf,    -inf],
        [ 0.0155,  0.8748, -0.3244,    -inf],
        [ 0.2540,  1.1855, -0.8233, -0.4735]], grad_fn=<MaskedFillBackward0>)
Raw Values to Offer for Attention  tensor([[ 0.6596, -0.3412,  0.6854,  0.4623],
        [-0.3994,  0.1570,  0.5419,  0.6956],
        [ 0.7482, -0.7134, -0.0972, -0.1077],
        [ 0.9776, -0.8043,  0.2481,  0.1621]], grad_fn=<MmBackward0>)
Final Weights for Value Acc

In [37]:
# Second pass: the same 4 tokens with one extra token (6) appended.
curr_token_seq = [10, 12, 24, 3, 6]
ix = torch.tensor(curr_token_seq, dtype=torch.long)
attention_output2 = attention_model(ix)

Current Sequence Length  5
Called Forward with Input  tensor([[-1.1140e-01,  2.3685e-01,  1.2289e+00,  4.6872e-02],
        [ 1.1885e+00,  8.3883e-01,  4.4766e-01, -1.8496e+00],
        [-1.3062e+00,  8.5220e-01, -7.6790e-01,  1.1046e+00],
        [-1.3318e+00,  5.4085e-01,  4.9931e-01,  5.8696e-01],
        [-4.0275e-04, -1.0988e+00, -5.7268e-01, -1.2227e+00]],
       grad_fn=<EmbeddingBackward0>)
Weight =  tensor([[ 0.1432,  0.7684, -0.5193, -0.2819,  0.0923],
        [-0.0383, -0.8364,  0.3031,  0.2034, -0.1206],
        [ 0.0155,  0.8748, -0.3244, -0.2943,  0.0160],
        [ 0.2540,  1.1855, -0.8233, -0.4735,  0.0055],
        [ 0.1698, -1.2351,  0.1965,  0.2963, -0.2878]], grad_fn=<MmBackward0>)
Weight After Mask Filling  tensor([[ 0.1432,    -inf,    -inf,    -inf,    -inf],
        [-0.0383, -0.8364,    -inf,    -inf,    -inf],
        [ 0.0155,  0.8748, -0.3244,    -inf,    -inf],
        [ 0.2540,  1.1855, -0.8233, -0.4735,    -inf],
        [ 0.1698, -1.2351,  0.1965,  0.296

In [40]:
# Hand-computed weighted sum: each attention weight times one value entry.
# wei => [0.2293, 0.5819, 0.0781, 0.1108, 0.0000]
c1 = (
    0.2293 * 0.6596
    + 0.5819 * -0.3994
    + 0.0781 * 0.7482
    + 0.1108 * 0.9776
    + 0.0000 * -1.1223  # masked position: zero weight, contributes nothing
)
print(c1)

0.08558792000000001
