# Self Attention
**With trainable weights (Wq, Wk, Wv)**

In [25]:
import torch

In [26]:
# Embedding width going into the attention layer and projection width coming out.
D_IN, D_OUT = 4, 3
# Sequence-length settings; both are 6 here (CONTEXT_SIZE is not used in the cells shown below).
CONTEXT_SIZE = NUM_TOKENS = 6
# Seed torch's global RNG so the random tensors below are reproducible.
# Kept as the last expression so the cell still displays the Generator repr.
torch.manual_seed(0)

<torch._C.Generator at 0x1c3ce671d70>

## Input Embeddings

In [27]:
# One row per token, one column per embedding feature, drawn from a standard normal.
input_embeddings = torch.randn((NUM_TOKENS, D_IN))
print(f"Input Embeddings:\n{input_embeddings}\n")
print(input_embeddings.size())

Input Embeddings:
tensor([[-1.1258, -1.1524, -0.2506, -0.4339],
        [ 0.8487,  0.6920, -0.3160, -2.1152],
        [ 0.4681, -0.1577,  1.4437,  0.2660],
        [ 0.1665,  0.8744, -0.1435, -0.1116],
        [ 0.9318,  1.2590,  2.0050,  0.0537],
        [ 0.6181, -0.4128, -0.8411, -2.3160]])

torch.Size([6, 4])


## For One Token

### Weight Matrices

In [28]:
# Row of `input_embeddings` used as the query token for the single-token walkthrough.
TOKEN_INDEX = 2  # the token we are focusing on
print(f"Focusing on token index: {TOKEN_INDEX}\n")
# Show the selected embedding and its shape (a 1-D vector of length D_IN).
print(f"Input Embedding:\n{input_embeddings[TOKEN_INDEX]}\n")
print(f"{input_embeddings[TOKEN_INDEX].shape}\n")

Focusing on token index: 2

Input Embedding:
tensor([ 0.4681, -0.1577,  1.4437,  0.2660])

torch.Size([4])



In [29]:
# Trainable projection matrices for queries, keys and values, each (D_IN, D_OUT).
# Re-seed first so the weights are reproducible regardless of earlier RNG use.
torch.manual_seed(0)
# The generator is consumed left to right, so the draw order is Wq, then Wk, then Wv.
Wq, Wk, Wv = (torch.nn.Parameter(torch.randn(D_IN, D_OUT)) for _ in range(3))

### Query, Key, Value Vectors

In [30]:
# Project the focused token's embedding through each weight matrix in turn:
# a (D_IN,) vector times a (D_IN, D_OUT) matrix gives a (D_OUT,) vector.
query, key, value = (
    torch.matmul(input_embeddings[TOKEN_INDEX], weight) for weight in (Wq, Wk, Wv)
)

print(f"Query Vector:\n{query}\t Shape: {query.shape}\n")
print(f"Key Vector:\n{key}\t Shape: {key.shape}\n")
print(f"Value Vector:\n{value}\t Shape: {value.shape}\n")

Query Vector:
tensor([ 1.1067,  1.0848, -1.7892], grad_fn=<SqueezeBackward4>)	 Shape: torch.Size([3])

Key Vector:
tensor([-1.3206, -1.3241,  0.1806], grad_fn=<SqueezeBackward4>)	 Shape: torch.Size([3])

Value Vector:
tensor([ 1.1077, -1.8903, -0.1447], grad_fn=<SqueezeBackward4>)	 Shape: torch.Size([3])



### Key and Value Matrices
We need the key vectors of all tokens to compute the attention scores, and the value vectors of all tokens to compute the context vector.

In [31]:
# Project every token at once: (NUM_TOKENS, D_IN) x (D_IN, D_OUT) -> (NUM_TOKENS, D_OUT).
keys = torch.matmul(input_embeddings, Wk)
values = torch.matmul(input_embeddings, Wv)

print(f"All Key Vectors:\n{keys}\t Shape: {keys.shape}\n")
print(f"All Value Vectors:\n{values}\t Shape: {values.shape}\n")

All Key Vectors:
tensor([[ 0.4147, -0.6161,  0.3470],
        [-3.4411, -0.9688, -2.4070],
        [-1.3206, -1.3241,  0.1806],
        [-0.0699, -0.2004,  0.0083],
        [-2.3540, -2.6627,  0.2559],
        [-3.2032,  0.0066, -2.9222]], grad_fn=<MmBackward0>)	 Shape: torch.Size([6, 3])

All Value Vectors:
tensor([[ 0.3302, -0.1072,  2.1203],
        [-1.6966,  6.1661, -2.2009],
        [ 1.1077, -1.8903, -0.1447],
        [-0.7948,  0.6281, -0.8995],
        [ 0.2518, -1.5995, -1.7082],
        [-1.2481,  7.0324, -1.2623]], grad_fn=<MmBackward0>)	 Shape: torch.Size([6, 3])



### Attention Scores

In [32]:
# Dot product of the focused token's query with every key: one raw score per token.
attention_scores = torch.matmul(query, keys.t())

print(f"Attention Scores (unnormalized):\n{attention_scores}\nShape: {attention_scores.shape}\n")

Attention Scores (unnormalized):
tensor([-0.8304, -0.5525, -3.2210, -0.3096, -5.9513,  1.6908],
       grad_fn=<SqueezeBackward4>)
Shape: torch.Size([6])



### Attention Weights

In [33]:
# Key dimensionality, read from the last axis of `keys`; printed next to D_OUT for comparison.
dk = keys.size(-1)
print(f"dk: {dk}\nD_OUT: {D_OUT}\n")

dk: 3
D_OUT: 3



In [34]:
# Scaled dot-product attention: divide the scores by sqrt(dk), then softmax over the tokens.
attention_weights = (attention_scores / dk ** 0.5).softmax(dim=-1)
print(f"Attention Weights (normalized):\n{attention_weights}\nShape: {attention_weights.shape}\n")

Attention Weights (normalized):
tensor([0.1232, 0.1447, 0.0310, 0.1664, 0.0064, 0.5283],
       grad_fn=<SoftmaxBackward0>)
Shape: torch.Size([6])



In [35]:
# Sanity check: softmax output should sum to 1.
print(attention_weights.sum())

tensor(1., grad_fn=<SumBackward0>)


### Context Vector

In [36]:
# Weighted sum of all value vectors, using the attention weights as coefficients.
context_vector = torch.matmul(attention_weights, values)

print(f"Context Vector:\n{context_vector}\nShape: {context_vector.shape}\n")

Context Vector:
tensor([-0.9604,  4.6295, -0.8891], grad_fn=<SqueezeBackward4>)
Shape: torch.Size([3])



## All Tokens

In [37]:
import torch.nn as nn

In [38]:
class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention with raw ``nn.Parameter`` weights.

    The query, key and value projections are plain ``(d_in, d_out)`` matrices
    initialised with ``torch.randn``.

    Args:
        d_in: Size of the last dimension of the input embeddings.
        d_out: Size of each query/key/value vector, and of each context vector.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Creation order matters for reproducibility: each randn call advances the RNG.
        self.Wq = nn.Parameter(torch.randn(d_in, d_out))
        self.Wk = nn.Parameter(torch.randn(d_in, d_out))
        self.Wv = nn.Parameter(torch.randn(d_in, d_out))

    def forward(self, input_embeddings):
        """Compute one context vector per input token.

        Args:
            input_embeddings: Tensor of shape ``(num_tokens, d_in)``, optionally
                with leading batch dimensions, i.e. ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)`` holding the
            attention-weighted sums of the value vectors.
        """
        queries = input_embeddings @ self.Wq
        keys = input_embeddings @ self.Wk
        values = input_embeddings @ self.Wv

        dk = keys.shape[-1]  # d_out
        # Swap only the token and feature axes. `.T` reverses *all* dimensions
        # (and is deprecated for non-2-D tensors), which breaks batched input.
        attention_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(dk), then normalise each query's scores over the tokens.
        attention_weights = torch.softmax(attention_scores / (dk ** 0.5), dim=-1)
        context_vectors = attention_weights @ values

        return context_vectors

In [39]:
# Re-seed so the module's randn-initialised weights are reproducible.
torch.manual_seed(0)
self_attention_v1 = SelfAttention_v1(d_in=D_IN, d_out=D_OUT)
# Calling the module runs forward() on all tokens at once.
context_vectors = self_attention_v1(input_embeddings)
print(f"Context Vectors from SelfAttention_v1:\n{context_vectors}\nShape: {context_vectors.shape}\n")

Context Vectors from SelfAttention_v1:
tensor([[ 0.3741, -1.1071, -0.7528],
        [-0.7642,  3.6132, -0.4795],
        [-0.9604,  4.6295, -0.8891],
        [-0.7429,  3.4181, -1.0734],
        [-1.2533,  6.6111, -1.2935],
        [-0.2899,  1.0370,  0.2127]], grad_fn=<MmBackward0>)
Shape: torch.Size([6, 3])



### Better Initialization with nn.Linear

In [42]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention using bias-free ``nn.Linear`` layers.

    Same computation as a raw-parameter implementation, but the projections are
    ``nn.Linear(d_in, d_out, bias=False)`` modules, so the weights use
    ``nn.Linear``'s default initialisation instead of ``torch.randn``.

    Args:
        d_in: Size of the last dimension of the input embeddings.
        d_out: Size of each query/key/value vector, and of each context vector.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.Wq = nn.Linear(d_in, d_out, bias=False)
        self.Wk = nn.Linear(d_in, d_out, bias=False)
        self.Wv = nn.Linear(d_in, d_out, bias=False)

    def forward(self, input_embeddings):
        """Compute one context vector per input token.

        Args:
            input_embeddings: Tensor of shape ``(num_tokens, d_in)``, optionally
                with leading batch dimensions, i.e. ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)`` holding the
            attention-weighted sums of the value vectors.
        """
        queries = self.Wq(input_embeddings)
        keys = self.Wk(input_embeddings)
        values = self.Wv(input_embeddings)

        dk = keys.shape[-1]  # d_out
        # Swap only the token and feature axes. `.T` reverses *all* dimensions
        # (and is deprecated for non-2-D tensors), which breaks batched input.
        attention_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(dk), then normalise each query's scores over the tokens.
        attention_weights = torch.softmax(attention_scores / (dk ** 0.5), dim=-1)
        context_vectors = attention_weights @ values

        return context_vectors

In [43]:
# Re-seed so the nn.Linear weight initialisation is reproducible.
torch.manual_seed(0)
self_attention_v2 = SelfAttention_v2(d_in=D_IN, d_out=D_OUT)
# Calling the module runs forward() on all tokens at once.
context_vectors_v2 = self_attention_v2(input_embeddings)
print(f"Context Vectors from SelfAttention_v2:\n{context_vectors_v2}\nShape: {context_vectors_v2.shape}\n")

Context Vectors from SelfAttention_v2:
tensor([[ 0.4873, -0.5124,  0.0852],
        [ 0.3599, -0.2287,  0.2441],
        [ 0.5439, -0.6058, -0.0137],
        [ 0.4614, -0.4796,  0.1658],
        [ 0.5274, -0.5823,  0.0200],
        [ 0.3539, -0.2063,  0.2399]], grad_fn=<MmBackward0>)
Shape: torch.Size([6, 3])

