In [1]:
import torch
import torch.nn as nn

import attention  # local module: SelfAttention / MultiHeadAttention are used at the end of the notebook

## TODO:
- create a vocab of sorts for demonstration
- consider an example sentence and tokenize it (using the vocab)
- create token embeddings

In [2]:
# Toy vocabulary: each word's id is its position in the list below.
word_list_str = "the and of to a in is it you that he was for on are with as I his they at be this have from or one had by word but not what all were we when your can said there use an each which she do how their if will up other about out many then them these so some her would make like him into time has look two more write go see number no way could people my than first water been call who oil its now find long down day did get come made may part"

word_list = word_list_str.split()
rev_vocab = dict(enumerate(word_list))  # id -> word

# map words to numbers (inverse of rev_vocab)
vocab = {word: idx for idx, word in rev_vocab.items()}
# A repeated word would collapse two ids into one vocab entry, leaving ids >= len(vocab)
# that nn.Embedding(vocab_size, ...) could not index -- fail loudly instead.
assert len(vocab) == len(word_list), "word_list contains duplicate words"

vocab_size = len(vocab)
dim_tok_emb = 120  # token-embedding width

# Seed torch's global RNG so the random embedding table and projection matrices created
# in later cells are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

In [3]:
sentence: str = "the number was long"  # example input; every word must be a key of `vocab`

In [4]:
def tokenize(sentence: str, word_to_id: "dict[str, int] | None" = None) -> torch.Tensor:
    """Map a whitespace-separated sentence to a 1-D tensor of token ids.

    Args:
        sentence: Text to tokenize. It is split on whitespace only, so punctuation stays
            attached to words.
        word_to_id: Word -> id mapping to use. Defaults to the notebook-level ``vocab``.

    Returns:
        A 1-D ``torch.long`` tensor with one id per word (empty for an empty sentence),
        suitable as index input to ``nn.Embedding``.

    Raises:
        KeyError: if a word is not in the mapping.
    """
    mapping = vocab if word_to_id is None else word_to_id
    ids = []
    for word in sentence.split():
        try:
            ids.append(mapping[word])
        except KeyError:
            raise KeyError(f"{word!r} is not in the vocabulary") from None
    # Explicit dtype: torch.tensor([]) would otherwise default to float32, which
    # nn.Embedding rejects as index input.
    return torch.tensor(ids, dtype=torch.long)

# Token ids for the example sentence; the bare expression displays their shape.
tokenized = tokenize(sentence)
tokenized.shape

torch.Size([4])

In [5]:
# Token ids for the example sentence (each is the word's position in word_list).
tokenized

tensor([ 0, 75, 11, 91])

In [6]:
# Lookup table holding one dim_tok_emb-wide row per vocabulary entry.
# The width (120) is an arbitrary choice for this demo.
embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=dim_tok_emb)

In [7]:
# One embedding row per token. detach() drops the autograd graph so the
# following cells work on plain tensors.
sentence_embeddings = embed(tokenized).detach()
sentence_embeddings.shape

torch.Size([4, 120])

## TODO - single attention head

- create W_q, W_k, W_v
- try out the entire process for a single token embedding
- create the attention pattern initially having just q.k
- then apply softmax column-wise: lay the attention pattern out with one row per key and one column per query, so normalizing each column turns that query's scores into a probability distribution over the keys

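Attention formula: $\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$ (this notebook builds the transposed score matrix $KQ^T$, with keys as rows and queries as columns, so its softmax runs down each column instead of along each row)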

In [8]:
# --- Projection matrices for a single attention head (random, untrained) ---
# Queries and keys share a small space of width dim_key_query (12), much narrower than
# the 120-d token embeddings. A token embedding is projected there by right-multiplying
# it with W_q (to get its query) or with W_k (to get its key).
dim_key_query = 12
W_q = torch.randn((dim_tok_emb, dim_key_query))
W_k = torch.randn((dim_tok_emb, dim_key_query))

# W_v maps a token embedding to a value vector that stays in the embedding space, so the
# attention-weighted sum of values can be added back onto the embeddings. The 3b1b video
# factors this map into two low-rank matrices; a single full matrix is used here.
# NOTE(review): plain N(0, 1) entries are not scaled by dim_tok_emb**-0.5, so q/k scores
# come out large and the later softmax is likely close to one-hot.
W_v = torch.randn((dim_tok_emb, dim_tok_emb))

W_q & W_k -> 120 x 12 <br>
W_v -> 120 x 120

k -> 12 <br>
q -> 12 <br>
v -> 120 <br>

tok_emb -> 120

In [9]:
# Expect (dim_tok_emb, dim_key_query).
W_q.shape

torch.Size([120, 12])

In [10]:
# Walk through one head for the first token only.
x_1 = sentence_embeddings[0]                     # (dim_tok_emb,)
q_1, k_1, v_1 = x_1 @ W_q, x_1 @ W_k, x_1 @ W_v  # (12,), (12,), (120,)

In [11]:
# The key lives in the dim_key_query-wide space.
k_1.shape

torch.Size([12])

In [12]:
# Scaled score of the first query against the first key.
qk_1 = torch.dot(q_1, k_1) / dim_key_query**0.5
# The first token may only attend to itself, so every later key in its column gets a
# score of -inf, which becomes a weight of exactly 0 after the softmax.
neg_infs = torch.full((sentence_embeddings.shape[0] - 1,), -torch.inf)
col_temp = torch.cat((qk_1.reshape(1), neg_infs))
col_norm = col_temp.softmax(dim=0)

In [13]:
# One weight per token for this column, and a value vector in the embedding space.
col_norm.shape, v_1.shape

(torch.Size([4]), torch.Size([120]))

In [14]:
# Weighted sum over the column. Every row is scaled against v_1 here, which only works
# because col_norm is [1, 0, ..., 0]; the full-sentence version uses each key's own value.
deltaE = (col_norm[:, None] * v_1).sum(0)
deltaE.shape


torch.Size([120])

In [15]:
# Sanity check: the -inf entries softmax to exactly 0, so col_norm is exactly [1, 0, 0, 0]
# and the weighted sum above reduces to 1.0 * v_1. An exact (bitwise) comparison is safe here.
torch.equal(deltaE, v_1)

True

### Repeat the above process for the entire sentence

In [16]:
# (tokens, dim_tok_emb) and (dim_tok_emb, dim_key_query): each token row can be projected by W_q.
sentence_embeddings.shape, W_q.shape

(torch.Size([4, 120]), torch.Size([120, 12]))

In [17]:
# Same projections as the single-token walk-through, applied to all tokens at once:
# each row of x is one token embedding, so each row of q/k/v belongs to one token.
x = sentence_embeddings
q, k, v = x @ W_q, x @ W_k, x @ W_v

In [18]:
# q and k are (tokens, dim_key_query); v stays (tokens, dim_tok_emb).
q.shape, k.shape, v.shape

(torch.Size([4, 12]), torch.Size([4, 12]), torch.Size([4, 120]))

In [19]:
# Transposed: (dim_key_query, tokens).
q.T.shape

torch.Size([12, 4])

In [20]:
# q @ k.T and (k @ q.T).T contain the same dot products in transposed layouts. They come
# from two separate matmuls, so bitwise equality is not guaranteed; compare with a tolerance.
torch.allclose(q @ k.T, (k @ q.T).T)

True

In [21]:
# Scores laid out as in the 3b1b video: qk[i, j] = (k_i . q_j) / sqrt(d_k), so each row is
# a key and each column is a query (the transpose of the usual Q K^T).
qk = (k @ q.T) / dim_key_query**0.5
# Causal mask built from *positions*: query j may only attend to keys i <= j, so every
# entry strictly below the diagonal (i > j) is hidden. Deriving the mask from
# `torch.triu(qk) == 0` would also hide any visible score that happens to be exactly 0.
mask = torch.ones_like(qk, dtype=torch.bool).tril(diagonal=-1)
qk = qk.masked_fill(mask, -torch.inf)
# Softmax down each column, so every query gets a probability distribution over the keys.
attn_pattern = torch.softmax(qk, dim=0)

## TODO:
- compute deltaE values from the attention pattern
- add them to the original embeddings
- unlike before, where the variable names made it look as if the v values were used column-wise, use them so that each key has its own v value (each row of v corresponds to a different key)

In [22]:
# Attention pattern is (keys, queries); v has one row per key.
attn_pattern.shape, v.shape

(torch.Size([4, 4]), torch.Size([4, 120]))

In [23]:
# attn_pattern.T @ v and (v.T @ attn_pattern).T are the same product written two ways.
# Two separate matmuls need not agree bit-for-bit, so compare with a tolerance.
torch.allclose(attn_pattern.T @ v, (v.T @ attn_pattern).T)

True

In [24]:
# Row j of deltaE is sum_i attn_pattern[i, j] * v[i]: query j's attention-weighted mix of
# the value vectors, with one value per key (row of v).
deltaE = torch.matmul(attn_pattern.T, v)
deltaE.shape

torch.Size([4, 120])

In [25]:
# Recompute by hand the update for the token at index 1 -- the *second* token ("number"),
# not the first: weight every key's value row by column 1 of the attention pattern, then
# sum over the keys.
deltaE1 = (attn_pattern[:, 1].unsqueeze(-1) * v).sum(dim=0)
deltaE1.shape

torch.Size([120])

In [26]:
# Row 1 of the matrix product should match the hand-computed weighted sum. The two use
# different reduction orders, so compare with a tolerance rather than bit-for-bit.
torch.allclose(deltaE[1], deltaE1)

True

In [27]:
# Cross-check against the packaged single-head implementation, reusing this notebook's
# dimensions: token-embedding width, query/key width, and value width (values stay in
# the embedding space, as with W_v above).
d_in, d_out_kq, d_out_v = dim_tok_emb, dim_key_query, dim_tok_emb

s = attention.SelfAttention(d_in, d_out_kq, d_out_v)
s(sentence_embeddings).shape

torch.Size([4, 120])

In [28]:
# Same dimensions as the single head; the 4th argument is presumably the number of heads
# (confirm in attention.py).
m = attention.MultiHeadAttention(dim_tok_emb, dim_key_query, dim_tok_emb, 4)
# NOTE(review): the output was (4, 4, 120) -- presumably one (tokens, d_out_v) result per
# head, not yet concatenated/projected back to (tokens, d_out_v); confirm in attention.py.
m(sentence_embeddings).shape

torch.Size([4, 4, 120])