In [1]:
## Self-Attention With Trainable Weights

To implement the self-attention mechanism, we introduce three trainable weight matrices:
$W_q$, $W_k$ and $W_v$. \
These three matrices project the embedded input tokens into query, key and value vectors, respectively. \
In the first step of the self-attention mechanism with trainable weights, we compute the
* query ($q$)
* key ($k$)
* value ($v$)

vectors for each input element $x$.

We designate the second input, $x^{(2)}$, as the query input:
* the query vector is obtained by matrix multiplication of the input with the weight matrix $W_q$: $q^{(2)} = x^{(2)} W_q$
* the key vector is obtained by matrix multiplication of the input with the weight matrix $W_k$: $k^{(2)} = x^{(2)} W_k$
* the value vector is obtained by matrix multiplication of the input with the weight matrix $W_v$: $v^{(2)} = x^{(2)} W_v$

In [2]:
import torch

# Embeddings of the six tokens in "Your journey starts with one step", one 3-d row per token.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.01],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

x_2 = inputs[1]        # second token ("journey"), used as the query input
d_in = inputs.size(1)  # input embedding size, d_in = 3
d_out = 2              # output (query/key/value) embedding size

  cpu = _conversion_method_template(device=torch.device("cpu"))


In [3]:
# Display the query input x^(2), the embedding of the token "journey".
x_2

tensor([0.5500, 0.8700, 0.6600])

In [4]:
# Create the three projection matrices W_q, W_k, W_v, each of shape (d_in, d_out).
# The generator is consumed in query, key, value order, so the seeded random draws
# land in that order and the values are reproducible.
# requires_grad=False because these matrices are not trained in this notebook;
# set it to True if the matrices should be optimized.
torch.manual_seed(123)
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [5]:
# Display the key projection matrix W_k, shape (d_in, d_out) = (3, 2).
W_key

Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])

In [6]:
# Project the query input x^(2) into its query, key and value vectors
# (q^(2) = x^(2) W_q, k^(2) = x^(2) W_k, v^(2) = x^(2) W_v).
query_2, key_2, value_2 = (x_2 @ weight for weight in (W_query, W_key, W_value))

In [7]:
# Unnormalized attention score omega_22: dot product of the query and key of input 2.
score_2 = query_2.dot(key_2)
for shown in (query_2, key_2, score_2):
    print(shown)

tensor([0.4306, 1.4551])
tensor([0.4433, 1.1419])
tensor(1.8524)


## Weight Parameters vs Attention Weights

The weight matrices here are not the same as the attention weights. \
The weight parameters of a network are the values that are optimized during the training phase. \
Attention weights are the values that determine the extent to which a context vector depends on the different parts of the input. \
Weight parameters are the fundamental learned coefficients that define the network, while attention weights are
dynamic, context-specific values computed from the input.

In [8]:
# Project every input at once: row i of `keys` / `values` is k^(i) / v^(i).
keys, values = inputs @ W_key, inputs @ W_value
for label, projected in (("keys.shape:", keys), ("values.shape:", values)):
    print(label, projected.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


In [9]:
# Display the key vectors of all six inputs: one row per token, d_out = 2 columns.
keys

tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1543, 0.2674],
        [0.3275, 0.9642]])

In [10]:
# Row 1 of `keys` is the key of the second input, k^(2); it equals key_2 computed earlier.
keys_2 = keys[1, :]
keys_2

tensor([0.4433, 1.1419])

In [11]:
# Single attention score omega_22: the query of input 2 against its own key k^(2).
# For two 1-D tensors, `@` is the dot product and yields a 0-d tensor.
attn_score_22 = query_2 @ keys_2
print(attn_score_22)

tensor(1.8524)


In [12]:
# Attention scores omega_2j of query 2 against all six keys, as one vector-matrix product.
attention_scores_2 = torch.matmul(query_2, keys.T)
print(attention_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.4555, 1.5440])


In [13]:
# Turn the attention scores into attention weights (scaled dot-product attention):
# divide the scores by sqrt(d_k), the embedding dimension of the keys, then apply softmax.
# The scaling keeps large dot products from pushing the softmax into a nearly one-hot,
# small-gradient regime.
d_k = keys.shape[-1]  # size of the last dimension of the keys tensor (d_out = 2 here)
scale = d_k ** 0.5
attn_weight_2 = torch.softmax(attention_scores_2 / scale, dim=-1)
print(attn_weight_2)

tensor([0.1510, 0.2278, 0.2213, 0.1319, 0.0848, 0.1832])


In [14]:
# Context vector z^(2): the attention-weighted sum of all value vectors.
context_vec_2 = torch.matmul(attn_weight_2, values)
print(context_vec_2)

tensor([0.3062, 0.8178])


In [15]:
# Assignment: compute the context vector for every input token.
# The key and value projections do not depend on the current query, so they are
# computed once here rather than on every loop iteration.
all_keys = inputs @ W_key
all_values = inputs @ W_value
key_dim = all_keys.shape[-1]  # d_k, used for the sqrt(d_k) scaling

context_vectors = torch.empty(inputs.size(0), W_value.size(1))
for index, item in enumerate(inputs):
    # project the current input into its query vector
    query = item @ W_query

    # attention scores of this query against every key
    attention_scores = query @ all_keys.T

    # scaled softmax turns the scores into attention weights that sum to 1
    attention_weights = torch.softmax(attention_scores / key_dim**0.5, dim=-1)

    # context vector = attention-weighted sum of the value vectors
    context_vectors[index] = attention_weights @ all_values

# Sanity check: row 1 must reproduce the single context vector computed step by step above.
assert torch.allclose(context_vectors[1], context_vec_2)

context_vectors

tensor([[0.2993, 0.8003],
        [0.3062, 0.8178],
        [0.3059, 0.8170],
        [0.2942, 0.7874],
        [0.2906, 0.7782],
        [0.2987, 0.7989]])

In [16]:
# Recompute all context vectors with the reusable SelfAttention_v1 class.
# Re-seeding with 123 before construction: the printed result matches the manual
# loop above, which suggests the class draws its weights the same way.
# NOTE(review): confirm the initialization in utilities/SelfAttention.py.
import torch
from utilities.SelfAttention import SelfAttention_v1
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2993, 0.8003],
        [0.3062, 0.8178],
        [0.3059, 0.8170],
        [0.2942, 0.7874],
        [0.2906, 0.7782],
        [0.2987, 0.7989]], grad_fn=<MmBackward0>)


In [17]:
# Using the improved self-attention class (SelfAttention_v2).
# A different seed (789) is used here, and the printed values differ from the
# SelfAttention_v1 output above.
# NOTE(review): the v2 weight initialization lives in utilities/SelfAttention.py —
# check it there before comparing weights between the two classes.
from utilities.SelfAttention import SelfAttention_v2
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.5308, -0.1104],
        [-0.5293, -0.1133],
        [-0.5293, -0.1133],
        [-0.5266, -0.1131],
        [-0.5275, -0.1122],
        [-0.5268, -0.1136]], grad_fn=<MmBackward0>)
