### Implementing a self-attention mechanism with trainable weights

In [1]:
import torch 
# Each row represents a word and each column represents an embedding dimension.

# Toy embedding table: six tokens (rows), three embedding dimensions (columns).
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1 -> "Your"
    [0.55, 0.87, 0.66],  # x^2 -> "journey"
    [0.57, 0.85, 0.64],  # x^3 -> "starts"
    [0.22, 0.58, 0.33],  # x^4 -> "with"
    [0.77, 0.25, 0.10],  # x^5 -> "one"
    [0.05, 0.80, 0.55],  # x^6 -> "step"
])

In [2]:
### A - the second input element
### B - the input embedding size, d_in=3
### C - the output embedding size, d_out=2

In [3]:
x_2 = inputs[1]        # A: the second input element
d_in = inputs.size(1)  # B: input embedding size (number of columns of `inputs`)
d_out = 2              # C: output embedding size

In [4]:
### In GPT-like models, the input and output dimensions are usually the same.

In [5]:
torch.manual_seed(123)

# Three (d_in, d_out) projection matrices, drawn in the order query, key, value
# so each one receives the same random values as separate sequential calls.
# requires_grad=False keeps gradient metadata out of the printed tensors.
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [6]:
# Show the query projection matrix (d_in rows, d_out columns).
print(W_query)

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])


In [7]:
# Show the key projection matrix (d_in rows, d_out columns).
print(W_key)

Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])


In [8]:
# Show the value projection matrix (d_in rows, d_out columns).
print(W_value)

Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]])


In [9]:
### Note that we set requires_grad=False here only to reduce clutter in the printed outputs.
### If we were to use these weight matrices for model training, we would set requires_grad=True so that they are updated during training.

In [10]:
# Project the second token's embedding into query, key and value space.
query_2 = torch.matmul(x_2, W_query)
key_2 = torch.matmul(x_2, W_key)
value_2 = torch.matmul(x_2, W_value)
print(query_2)


tensor([0.4306, 1.4551])


In [12]:
# Project every token at once: each result has one row per token.
keys = torch.matmul(inputs, W_key)
queries = torch.matmul(inputs, W_query)
values = torch.matmul(inputs, W_value)

print("keys.shape:", keys.size())
print("queries.shape:", queries.size())
print("values.shape:", values.size())

keys.shape: torch.Size([6, 2])
queries.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


In [13]:
### We have projected the 6 input tokens from a 3D onto a 2D embedding space

In [14]:
# Attention score omega_22: the second token's query dotted with its own key.

keys_2 = keys[1]
attn_score_22 = torch.dot(query_2, keys_2)
print(attn_score_22)

tensor(1.8524)


In [None]:
# Scores of the "journey" query against the key of every token.
attn_scores_2 = torch.matmul(query_2, keys.transpose(0, 1))
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [16]:
# omega: row i holds the scores of query i against every key.
attn_scores = torch.matmul(queries, keys.transpose(0, 1))
print(attn_scores)

tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])


In [17]:
### We compute the attention weights by scaling the attention scores and applying the softmax function we used earlier.
### The difference from earlier is that we now scale the attention scores by dividing them by the square root of the embedding
### dimension of the keys.

In [18]:
# Divide the scores by sqrt(d_k), the key embedding size, before the softmax.
d_k = keys.size(-1)
attn_weights_2 = (attn_scores_2 / d_k**0.5).softmax(dim=-1)
print(attn_weights_2)
print(d_k)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
2


In [21]:
# Row-wise softmax over the scaled scores: each row of weights sums to 1.
attn_weights = (attn_scores / d_k**0.5).softmax(dim=-1)
print(attn_weights)

tensor([[0.1551, 0.2104, 0.2059, 0.1413, 0.1074, 0.1799],
        [0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
        [0.1503, 0.2256, 0.2192, 0.1315, 0.0914, 0.1819],
        [0.1591, 0.1994, 0.1962, 0.1477, 0.1206, 0.1769],
        [0.1610, 0.1949, 0.1923, 0.1501, 0.1265, 0.1752],
        [0.1557, 0.2092, 0.2048, 0.1419, 0.1089, 0.1794]])


In [None]:
### We now compute the context vector as a weighted sum over the value vectors.
### Here, the attention weights serve as a weighting factor that weighs the respective importance of each value vector.

In [20]:
# Context vector for the second token: attention-weighted sum of all value vectors.
context_vec_2 = torch.matmul(attn_weights_2, values)
print(context_vec_2)

tensor([0.3061, 0.8210])


In [22]:
# One context vector per token: each row of weights mixes the value vectors.
context_vectors = torch.matmul(attn_weights, values)
print(context_vectors)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]])


### Implementing a compact self-attention Python class

In [26]:
import torch.nn as nn

class SelfAttention_V1(nn.Module):
    """Single-head scaled dot-product self-attention with raw weight matrices.

    The query, key and value projections are ``nn.Parameter`` matrices of
    shape ``(d_in, d_out)`` initialised with ``torch.rand``.

    Args:
        d_in: Size of each input embedding vector.
        d_out: Size of each projected query/key/value (and context) vector.
    """

    def __init__(self, d_in: int, d_out: int) -> None:
        super().__init__()
        # Creation order (query, key, value) fixes which random draws each
        # matrix receives under a given seed.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute one context vector per input token.

        Args:
            x: Token embeddings of shape ``(num_tokens, d_in)``. Leading
                batch dimensions, e.g. ``(batch, num_tokens, d_in)``, are
                also accepted.

        Returns:
            Context vectors of shape ``(..., num_tokens, d_out)``.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        # Swap only the last two axes: `.T` reverses every axis, which is
        # correct for 2-D input only and breaks batched input.
        attn_scores = queries @ keys.transpose(-2, -1)  # omega
        # Scale by sqrt(d_k) before normalising each row with softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vectors = attn_weights @ values
        return context_vectors

In [27]:
# Re-seed so the module draws the same random weights as the manual walk-through.
torch.manual_seed(123)
sa_v1 = SelfAttention_V1(d_in=d_in, d_out=d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [28]:
### Since inputs contains six embedding vectors, we get a matrix storing the six context vectors

In [29]:
### We can improve the SelfAttention_V1 implementation further by using PyTorch's nn.Linear layers, which effectively
### perform matrix multiplication when the bias units are disabled.
### Additionally, a significant advantage of using nn.Linear instead of manually implementing nn.Parameter(torch.rand(..)) is that
### nn.Linear has an optimized weight initialization scheme, contributing to more stable and effective model training.

In [30]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Args:
        d_in: Size of each input embedding vector.
        d_out: Size of each projected query/key/value (and context) vector.
        qkv_bias: Whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in: int, d_out: int, qkv_bias: bool = False) -> None:
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute one context vector per input token.

        Args:
            x: Token embeddings of shape ``(num_tokens, d_in)``. Leading
                batch dimensions, e.g. ``(batch, num_tokens, d_in)``, are
                also accepted.

        Returns:
            Context vectors of shape ``(..., num_tokens, d_out)``.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Swap only the last two axes: `.T` reverses every axis, which is
        # correct for 2-D input only and breaks batched input.
        attn_scores = queries @ keys.transpose(-2, -1)  # omega
        # Scale by sqrt(d_k) before normalising each row with softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vectors = attn_weights @ values
        return context_vectors

In [33]:
# Fix the seed so the nn.Linear weight initialisation is reproducible.
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in=d_in, d_out=d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
