# Vanilla Self Attention

In [4]:
import torch
# Fix the global RNG seed so the torch.rand draws in the cells below are
# reproducible. Binding the return value to `_` keeps it from being echoed as
# this cell's output.
_ = torch.manual_seed(123)

In [5]:
# Toy 3-dimensional embeddings, one row per token of the example sentence.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1: "Your"
    [0.55, 0.87, 0.66],  # x^2: "journey"
    [0.57, 0.85, 0.64],  # x^3: "starts"
    [0.22, 0.58, 0.33],  # x^4: "with"
    [0.77, 0.25, 0.10],  # x^5: "one"
    [0.05, 0.80, 0.55],  # x^6: "step"
])

In [6]:
# Each token embedding is projected from its input width down to 2 dimensions.
input_emb_dim = inputs.shape[-1]
output_emb_dim = 2

# Three random projection matrices, created in query -> key -> value order so
# the random stream is consumed in the same sequence every run.
# requires_grad=False: these are fixed demo weights, no gradients are tracked.
W_query, W_key, W_val = (
    torch.nn.Parameter(torch.rand(input_emb_dim, output_emb_dim), requires_grad=False)
    for _ in range(3)
)

### Calculating self attention for a single token in a data instance

In [7]:
# Attention score that the second token ("journey"), acting as the query,
# assigns to the first token ("Your"), acting as the key.
token_1_emb = inputs[0]
token_2_emb = inputs[1]

token_2_query_emb = torch.matmul(token_2_emb, W_query)
token_1_key_emb = torch.matmul(token_1_emb, W_key)

# The score is the dot product of the query and key projections.
token_1_attn_score_for_token_2 = token_2_query_emb.dot(token_1_key_emb)
print(token_1_attn_score_for_token_2)

tensor(1.2705)


In [8]:
# Project every token to its key, then score the "journey" query against all
# of the keys with a single matrix product.
keys = torch.matmul(inputs, W_key)

token_2_attn_scores = torch.matmul(token_2_query_emb, keys.t())
print(token_2_attn_scores)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [9]:
# Divide the raw scores by sqrt(key dimension), then apply softmax to turn
# them into the second token's attention weights.
dim_k = keys.size(-1)
token_2_scaled_attn_scores = token_2_attn_scores.div(dim_k ** 0.5)

token_2_attn_wts = token_2_scaled_attn_scores.softmax(dim=-1)
print(token_2_attn_wts)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


In [10]:
# Context vector for the second token: the attention-weighted combination of
# every token's value projection.
values = torch.matmul(inputs, W_val)

token_2_context_vec = torch.matmul(token_2_attn_wts, values)
print(token_2_context_vec)

tensor([0.3061, 0.8210])


### Calculating self attention for all tokens in a data instance

In [11]:
# Same pipeline as for the single token, applied to every token at once.
# `keys`, `values` and `dim_k` are reused from the cells above.
queries = torch.matmul(inputs, W_query)

# Row i holds token i's (scaled, softmax-normalised) weights over all tokens.
attn_scores = torch.matmul(queries, keys.t())
attn_wts = attn_scores.div(dim_k ** 0.5).softmax(dim=-1)

context_vecs = torch.matmul(attn_wts, values)
print(context_vecs)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]])


### Self Attention Class V1

In [12]:
import torch.nn as nn

class SelfAttentionV1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    The query, key and value projections are stored directly as
    ``nn.Parameter`` matrices of shape ``(d_in, d_out)`` initialised with
    ``torch.rand``.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of each projected query/key/value vector, and therefore
            of each output context vector.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Creation order (query, key, value) determines which random draws
        # each matrix receives under a fixed seed.
        self.W_q = nn.Parameter(torch.rand(d_in, d_out))
        self.W_k = nn.Parameter(torch.rand(d_in, d_out))
        self.W_v = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute a context vector for every token in ``x``.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``, optionally with leading
                batch dimensions, i.e. ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        queries = x @ self.W_q
        keys = x @ self.W_k
        values = x @ self.W_v

        # Swap only the last two axes: identical to `keys.T` for 2-D input,
        # but also correct when leading batch dimensions are present.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(key dimension) before normalising over the key axis.
        attn_wts = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)

        context_vecs = attn_wts @ values
        return context_vecs

In [14]:
# Re-seed right before construction so the layer's torch.rand weight draws are
# reproducible.
torch.manual_seed(123)

# Output width is one less than the input embedding width.
self_attn_layer_v1 = SelfAttentionV1(inputs.size(-1), inputs.size(-1) - 1)
outputs = self_attn_layer_v1(inputs)
print(outputs)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


### Self Attention Class V2

In [17]:
class SelfAttentionV2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of each projected query/key/value vector, and therefore
            of each output context vector.
        qkv_bias: Passed as ``bias`` to the three ``nn.Linear`` projections;
            ``False`` (the default) creates them without a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute a context vector for every token in ``x``.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``, optionally with leading
                batch dimensions, i.e. ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v(x)

        # Swap only the last two axes: identical to `keys.T` for 2-D input,
        # but also correct when leading batch dimensions are present.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(key dimension) before normalising over the key axis.
        attn_wts = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)

        context_vecs = attn_wts @ values
        return context_vecs

In [19]:
# Re-seed right before construction so the layer's initial weights are
# reproducible.
torch.manual_seed(789)

# Output width is one less than the input embedding width.
self_attn_layer_v2 = SelfAttentionV2(inputs.size(-1), inputs.size(-1) - 1)
outputs = self_attn_layer_v2(inputs)
print(outputs)

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


### Weight Initialization & Storage Differences (V1 vs V2)

In [30]:
# Show the query weights of both layers one after the other (V2 first, then
# V1) so their values and layouts can be compared.
print(self_attn_layer_v2.W_q.weight, self_attn_layer_v1.W_q, sep="\n")

print("\n")

# Then the Python type of each weight object, in the same order.
print(type(self_attn_layer_v2.W_q.weight), type(self_attn_layer_v1.W_q), sep="\n")

Parameter containing:
tensor([[ 0.3161,  0.4568,  0.5118],
        [-0.1683, -0.3379, -0.0918]], requires_grad=True)
Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]], requires_grad=True)


<class 'torch.nn.parameter.Parameter'>
<class 'torch.nn.parameter.Parameter'>


In [32]:
# Copy V2's weights into V1 so both layers compute the same function.
# nn.Linear keeps its weight as (d_out, d_in) while SelfAttentionV1 multiplies
# by a (d_in, d_out) matrix, hence the transpose.
# detach().clone() gives V1 its own storage: wrapping the transposed view
# directly in nn.Parameter would share memory with V2's weights, so an in-place
# update to one layer would silently change the other.
self_attn_layer_v1.W_q = nn.Parameter(self_attn_layer_v2.W_q.weight.detach().T.clone())
self_attn_layer_v1.W_k = nn.Parameter(self_attn_layer_v2.W_k.weight.detach().T.clone())
self_attn_layer_v1.W_v = nn.Parameter(self_attn_layer_v2.W_v.weight.detach().T.clone())

# Recompute V2's output here instead of relying on whatever `outputs` was last
# assigned by an earlier cell, so the comparison is correct on any re-run.
outputs_v2 = self_attn_layer_v2(inputs)
print("*** From SelfAttentionV2 Layer ***\n", outputs_v2)
outputs = self_attn_layer_v1(inputs)
print("\n*** From SelfAttentionV1 (Weights Updated) Layer ***\n", outputs)

# With identical weights the two implementations must agree.
assert torch.allclose(outputs_v2, outputs)

*** From SelfAttentionV2 Layer ***
 tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

*** From SelfAttentionV1 (Weights Updated) Layer ***
 tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
