In [1]:
import sys
from pathlib import Path

# Locate the repository root: the first directory, starting at the notebook's working
# directory and walking upward, that contains a `src` folder. Put it on sys.path so
# `src.*` packages import no matter how deep the notebook sits in the repo.
# Falls back to two levels up (the old heuristic), or to the cwd itself if it is too shallow.
here = Path.cwd().resolve()
repo_root = next(
    (candidate for candidate in (here, *here.parents) if (candidate / "src").exists()),
    here.parents[1] if len(here.parents) > 1 else here,
)

if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

### Simplified attention mechanism (no trainable weights): computing the context vector for a single token to build intuition

### Assume an embedding dimension of 3 and that the input embeddings have already been computed

![image.png](attachment:image.png)

In [2]:
import torch
# Toy 3-dimensional input embeddings for the sentence "Your journey starts with one step".
# Dict insertion order is preserved, so row i of `inputs` is token x^(i+1).
_token_embeddings = {
    "Your":    [0.43, 0.15, 0.89],  # x^1
    "journey": [0.55, 0.87, 0.66],  # x^2
    "starts":  [0.57, 0.85, 0.64],  # x^3
    "with":    [0.22, 0.58, 0.33],  # x^4
    "one":     [0.77, 0.25, 0.10],  # x^5
    "step":    [0.05, 0.80, 0.55],  # x^6
}
inputs = torch.tensor(list(_token_embeddings.values()))

![image.png](attachment:image.png)

In [3]:
# Query token: "journey" (x^2, row index 1). Below we compute its attention scores
# against every token in the sequence.
query = inputs[1, :]

In [4]:
# Attention scores of every token with respect to the query "journey": the dot
# product x^(i) . q for each row. (One-shot equivalent: torch.matmul(inputs, query).)
attention_scores = torch.stack(
    [torch.dot(token_embedding, query) for token_embedding in inputs]
)
print("Attention scores: ", attention_scores)

Attention scores:  tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


![image.png](attachment:image.png)

In [5]:
# Attention weights: softmax-normalize the scores so they are positive and sum to 1.
attention_weights = attention_scores.softmax(dim=-1)

In [6]:
# Softmax-normalized weights of each token for the query "journey" (they sum to 1).
attention_weights

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])

![image.png](attachment:image.png)

In [7]:
# Context vector for "journey": the attention-weighted sum of all input embeddings,
# z^(2) = sum_i alpha_2i * x^(i).
context_vector = torch.zeros(inputs.shape[1])
for weight, embedding in zip(attention_weights, inputs):
    context_vector += weight * embedding

context_vector.shape


torch.Size([3])

In [8]:
# The context vector for "journey" computed in the previous cell.
print(context_vector)

tensor([0.4419, 0.6515, 0.5683])


### Now compute all attention weights at once

In [9]:
# All pairwise attention scores at once: entry (i, j) is x^(i) . x^(j).
attention_scores = inputs @ inputs.T
# Normalize each row (one query per row) with softmax.
attention_weights = attention_scores.softmax(dim=-1)
print(attention_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [10]:
# All context vectors at once: row i is sum_j attention_weights[i, j] * x^(j).
context_vectors = attention_weights @ inputs
print(context_vectors)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


# With trainable weights

![image.png](attachment:image.png)

In [11]:
# The query is the second token, "journey" (x^2).
x_2 = inputs[1, :]
print(x_2)

tensor([0.5500, 0.8700, 0.6600])


In [12]:
# Input embedding size and projected (output) size. GPT-style models usually use
# d_in == d_out; different sizes are used here to make the projection easy to follow.
d_in, d_out = 3, 2

In [13]:
# Randomly initialize the query, key, and value projection matrices, each of shape
# (d_in, d_out), i.e. output embedding size 2. Created in order query, key, value
# from the same seeded RNG stream. requires_grad=False: nothing is trained here.
torch.manual_seed(123)
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [14]:
# Inspect the randomly initialized (d_in, d_out) query projection.
W_query 

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])

In [15]:
# Project the "journey" embedding x^2 into its query, key, and value vectors.
query_2, key_2, value_2 = (x_2 @ W for W in (W_query, W_key, W_value))
for label, vector in (("query_2:", query_2), ("key_2:", key_2), ("value_2:", value_2)):
    print(label, vector)

query_2: tensor([0.4306, 1.4551])
key_2: tensor([0.4433, 1.1419])
value_2: tensor([0.3951, 1.0037])


In [16]:
# Project every token at once: row i of keys/values belongs to token x^(i+1).
keys, values = inputs @ W_key, inputs @ W_value
for label, projected in (("keys.shape:", keys), ("values.shape:", values)):
    print(label, projected.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


In [17]:
# Unscaled attention scores of every token w.r.t. the "journey" query: k^(i) . q^(2).
unscaled_attention_scores = torch.stack([torch.dot(key_i, query_2) for key_i in keys])
print("Unscaled attention scores:", unscaled_attention_scores)

Unscaled attention scores: tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


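Now we scale the attention scores by dividing them by the square root of the embedding dimension of the keys, $d_k$ (taking the square root is mathematically the same as exponentiating by 0.5), and then apply the softmax:

$$\alpha_{2i} = \operatorname{softmax}_i\left(\frac{\mathbf{q}^{(2)} \cdot \mathbf{k}^{(i)}}{\sqrt{d_k}}\right)$$

The scaling keeps the dot products from growing with $d_k$, which would otherwise push the softmax toward a near one-hot output with very small gradients.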

In [18]:
# d_k: embedding dimension of the keys (last axis of `keys`), used to scale the scores.
d_k = keys.shape[-1]
d_k  # last expression, so the notebook actually displays it

In [19]:

# Scaled dot-product: divide the scores by sqrt(d_k) before the softmax.
scale = d_k ** 0.5
attn_weights_2 = torch.softmax(unscaled_attention_scores / scale, dim=-1)

In [20]:
# Scaled, softmax-normalized attention weights for the query "journey".
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


In [21]:
# Context vector for "journey": the attention-weighted sum of the value vectors.
context_vector_2 = torch.matmul(attn_weights_2, values)
print("context_vector_2:", context_vector_2)


context_vector_2: tensor([0.3061, 0.8210])


In [22]:
from src.gpt_blocks.self_attention import SelfAttention_v1, SelfAttention_v2

In [None]:
# Same computation packaged as a module. Seeding first makes its (presumably random)
# weight initialization reproducible.
torch.manual_seed(123)
self_attention = SelfAttention_v1(d_in=3, d_out=2)
# NOTE: this overwrites the `context_vectors` computed in the matrix-form cell above.
context_vectors = self_attention(inputs)
print(context_vectors)

In [None]:
# Second implementation; seed 789 fixes its initial weights so they can be copied
# into a SelfAttention_v1 instance below for comparison.
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

In [None]:
# sa_v2's projections are presumably nn.Linear layers, which store weight as
# (out_features, in_features). Transpose to get (d_in, d_out) matrices in the same
# layout as the manual W_query/W_key/W_value used above.
Wq, Wk, Wv = (
    layer.weight.T for layer in (sa_v2.W_query, sa_v2.W_key, sa_v2.W_value)
)

In [None]:
# Transposed key weights from sa_v2.
print(Wk)

In [None]:
# Load sa_v2's (transposed) weights into a SelfAttention_v1 so both implementations
# can be compared on identical parameters.
# copy_ under no_grad copies the values in place and requires broadcast-compatible
# shapes. Assigning `.data = ...` would instead make sa_v1's parameters alias sa_v2's
# (non-contiguous, transposed) storage and silently accept any shape.
torch.manual_seed(789)
sa_v1 = SelfAttention_v1(d_in, d_out)
with torch.no_grad():
    sa_v1.W_key.copy_(Wk)
    sa_v1.W_value.copy_(Wv)
    sa_v1.W_query.copy_(Wq)

In [None]:
# Show sa_v1's parameters after loading sa_v2's weights (query, key, value).
for param in (sa_v1.W_query, sa_v1.W_key, sa_v1.W_value):
    print(param)


In [None]:
# With identical weights, sa_v1 should reproduce sa_v2's output above.
# The reseed also resets the global RNG state that later cells draw from.
torch.manual_seed(789)
print(sa_v1(inputs))


# Causal Attention

![image.png](attachment:image.png)

In [None]:
# Module used for the causal-attention walkthrough. Seed explicitly so its initial
# weights do not depend on whatever RNG state earlier cells left behind. With seed
# 789 and the same sizes, it should match sa_v2's initialization.
torch.manual_seed(789)
self_attn = SelfAttention_v2(d_in=3, d_out=2)

In [None]:
# Queries and keys for every token, from the causal-attention module's projections.
Q, K = self_attn.W_query(inputs), self_attn.W_key(inputs)

In [None]:
# Scaled dot-product attention weights for every query/key pair (unmasked).
# Scale by sqrt of the dimension of *these* keys (K), not the `keys` tensor from the
# manual example earlier. The value is the same (2) here, but this removes a hidden
# dependency on an unrelated cell.
attn_scores = Q @ K.T
attn_weights = torch.softmax(attn_scores / K.shape[-1] ** 0.5, dim=-1)
print(attn_weights)

In [None]:
# Simple causal mask: ones on and below the diagonal (current and past tokens),
# zeros above it (future tokens). (`context_lenght` spelling is kept because later
# cells use that name.)
context_lenght = attn_scores.shape[0]
mask_simple = torch.ones(context_lenght, context_lenght).tril()
print(mask_simple)

In [None]:
# Mask the attention weights: zero out entries above the diagonal (future tokens).
# The lower-triangular mask is rebuilt here rather than multiplied into the existing
# `mask_simple`, so re-running this cell gives the same result instead of masking
# already-masked weights (which would square them).
# NOTE: `mask_simple` now holds the masked weights; the next cell renormalizes them.
mask_simple = attn_weights * torch.tril(torch.ones_like(attn_weights))
print(mask_simple)


In [None]:
# Renormalize each row so the surviving (non-future) weights sum to 1 again.
row_sums = mask_simple.sum(dim=-1, keepdim=True)
renorm_attention_weights = mask_simple / row_sums
print(renorm_attention_weights)

In [None]:
# More efficient causal masking: set the scores above the diagonal (future tokens) to
# -inf *before* the softmax, so they get exactly zero weight without renormalizing.
mask = torch.ones(context_lenght, context_lenght).triu(diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

In [None]:
# Row-wise softmax over the -inf-masked scores: future tokens get exactly zero weight,
# so no renormalization is needed. Scale by sqrt(d_k) of K, matching the unmasked
# computation, and normalize along the last axis like every other softmax here.
attn_weights = torch.softmax(masked / K.shape[-1] ** 0.5, dim=-1)
print(attn_weights)

## Dropout

In [None]:
# Dropout demo on an all-ones 6x6 matrix. In training mode each entry is zeroed with
# probability p = 0.5, and the survivors are rescaled by 1 / (1 - p) = 2.
torch.manual_seed(123)
dropout = torch.nn.Dropout(p=0.5)
example = torch.ones((6, 6))
dropped = dropout(example)
print(dropped)

In [None]:
# Apply the same dropout layer to the causal attention weights, which is where dropout
# is used in attention during training. The original cell applied it to
# `attention_scores`, the raw unscaled scores from the first section.
# Reseed so the dropped positions are reproducible across re-runs.
torch.manual_seed(123)
print(dropout(attn_weights))

In [None]:
# A batch of two identical sequences, stacked along a new leading axis:
# shape (2 sequences, 6 tokens, 3 embedding dims).
batch = torch.stack([inputs, inputs])
print(batch.shape)

In [None]:
from self_attention import CausalAttention
# NOTE(review): CausalAttention (and the multi-head classes below) are imported from a
# top-level `self_attention` module, while SelfAttention_v1/v2 came from
# `src.gpt_blocks.self_attention`. Confirm these are the same file and unify the path.
torch.manual_seed(123)
context_length = batch.size(1)  # tokens per sequence
ca = CausalAttention(d_in, d_out, context_length, 0.0)  # 4th arg: presumably the dropout rate
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

In [None]:
# Causal-attention context vectors for both sequences in the batch.
context_vecs

# Multi-head attention

![image.png](attachment:image.png)

In [None]:
from self_attention import MultiHeadAttentionWrapper

In [None]:
torch.manual_seed(123)
context_length = batch.size(1)  # number of tokens per sequence
d_in = 3
d_out = 1  # per-head output size; NOTE: replaces the earlier global d_out = 2

In [None]:
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)
print(context_vecs)
# Each head projects to d_out (= 1 here) dimensions. If the wrapper concatenates the
# heads' outputs (not visible here; confirm in MultiHeadAttentionWrapper), each token
# gets a num_heads * d_out = 2-dimensional context vector, i.e. shape (2, 6, 2).
print("context_vecs.shape:", context_vecs.shape)

In [None]:
# Final Multi-head attention with masking and dropout and mixing of heads
from self_attention import MultiHeadAttention
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
# Two heads with masking and (zero) dropout; the heads' outputs are mixed inside the module.
mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    num_heads=2,
    dropout=0.0,
)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

In [None]:
# Multi-head attention sized like the smallest GPT-2 model: 12 heads, 768-dim input
# and output embeddings, and a context length of 1,024 tokens. (Previously d_in=3 and
# context_length=768, which does not match GPT-2.)
gpt_mha = MultiHeadAttention(d_in=768, d_out=768, context_length=1024, dropout=0.0, num_heads=12)

In [None]:
# Count the trainable parameters of the GPT-sized attention module.
trainable_params = [param for param in gpt_mha.parameters() if param.requires_grad]
total_params = sum(param.numel() for param in trainable_params)
print(f"Total trainable parameters: {total_params}")