In [1]:
import torch

inputs = torch.tensor(
    [[0.43, 0.5, 0.57], # your
    [0.5, 0.6, 0.7],    # x_2
    [0.57, 0.7, 0.8],
    [0.6, 0.7, 0.8],      
    [0.7, 0.8, 0.9],
    [0.8, 0.9, 1.0]],
    dtype=torch.float32
)

In [2]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

In [3]:
# Now we initialize the q, k, v matrices
torch.manual_seed(0)  # For reproducibility
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) # d_input = 3, d_model = 3
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)


In [4]:
print("W_query:", W_query)

W_query: Parameter containing:
tensor([[0.4963, 0.7682],
        [0.0885, 0.1320],
        [0.3074, 0.6341]])


In [5]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print("query_2:", query_2)

query_2: tensor([0.5164, 0.9072])


In [6]:
keys = inputs @ W_key
queries = inputs @ W_query
values = inputs @ W_value
print("keys:", keys.shape)
print("queries:", queries.shape)
print("values:", values.shape)

keys: torch.Size([6, 2])
queries: torch.Size([6, 2])
values: torch.Size([6, 2])


In [7]:
# These embeddings have been projected from a 3D space to a 2D space. ( d_in = 3, d_out = 2 )

In [8]:
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2) # Compute attention score
print("attn_score_2:", attn_score_22)

attn_score_2: tensor(1.3997)


In [9]:
attn_scores_2 = query_2 @ keys.T  # Compute attention scores for all pairs
print("attn_scores_2:", attn_scores_2)

attn_scores_2: tensor([1.1734, 1.3997, 1.6097, 1.6417, 1.8837, 2.1257])


In [10]:
attn_scores = queries @ keys.T  # Compute attention scores for all inputs
print("attn_scores:", attn_scores)

attn_scores: tensor([[0.9811, 1.1703, 1.3459, 1.3727, 1.5750, 1.7773],
        [1.1734, 1.3997, 1.6097, 1.6417, 1.8837, 2.1257],
        [1.3421, 1.6010, 1.8412, 1.8778, 2.1545, 2.4313],
        [1.3731, 1.6379, 1.8836, 1.9211, 2.2042, 2.4874],
        [1.5727, 1.8761, 2.1575, 2.2004, 2.5247, 2.8490],
        [1.7724, 2.1142, 2.4314, 2.4797, 2.8452, 3.2107]])


In [11]:
# Now we will get attention weights by scaling by the square root of d_out
attn_scores_2 = attn_scores_2 / (d_out ** 0.5)  # Scale the attention scores
print("attn_scores_2 (scaled):", attn_scores_2)
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)  # Softmax over the attention scores
print("attn_weights_2:", attn_weights_2)

attn_scores_2 (scaled): tensor([0.8297, 0.9898, 1.1382, 1.1609, 1.3320, 1.5031])
attn_weights_2: tensor([0.1171, 0.1374, 0.1594, 0.1630, 0.1935, 0.2296])


In [12]:
# Now we will calculate the context vector

In [13]:
context_2 = attn_weights_2 @ values  # Weighted sum of values
print("context_2:", context_2)

context_2: tensor([0.8033, 1.1431])


In [14]:
import torch.nn as nn
# Now we will compile this into a class
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super(SelfAttention_v1, self).__init__()
        self.W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

    def forward(self, inputs):
        queries = inputs @ self.W_query
        keys = inputs @ self.W_key
        values = inputs @ self.W_value
        
        attn_scores = queries @ keys.T
        attn_scores = attn_scores / (self.W_query.shape[1] ** 0.5)
        attn_weights = torch.softmax(attn_scores, dim=0)
        
        context = attn_weights @ values
        return context

In [15]:
torch.manual_seed(0)  # For reproducibility
sa_v1 = SelfAttention_v1(d_in, d_out)
context_v1 = sa_v1(inputs)
print("context_v1:", context_v1)

context_v1: tensor([[0.5045, 0.7172],
        [0.6124, 0.8707],
        [0.7267, 1.0334],
        [0.7500, 1.0664],
        [0.9197, 1.3079],
        [1.1293, 1.6061]])


In [16]:
# This is the self Attention mechanism implemented in a class.

In [25]:
import torch.nn as nn
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out):
        super(SelfAttention_v2, self).__init__()
        self.W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

    def forward(self, inputs):
        queries = inputs @ self.W_query
        keys = inputs @ self.W_key
        values = inputs @ self.W_value
        
        attn_scores = queries @ keys.T
        attn_scores = attn_scores / (self.W_query.shape[1] ** 0.5)
        attn_weights = torch.softmax(attn_scores, dim=0)
        
        context = attn_weights @ values
        return attn_weights

In [26]:
torch.manual_seed(0)  # For reproducibility
sa_v2 = SelfAttention_v2(d_in, d_out)
attn_weights_v2 = sa_v2(inputs)
print("attn_weights_v2:", attn_weights_v2)


attn_weights_v2: tensor([[0.1246, 0.1174, 0.1109, 0.1099, 0.1028, 0.0961],
        [0.1428, 0.1380, 0.1336, 0.1330, 0.1279, 0.1229],
        [0.1608, 0.1591, 0.1574, 0.1571, 0.1549, 0.1526],
        [0.1644, 0.1634, 0.1622, 0.1620, 0.1605, 0.1587],
        [0.1893, 0.1933, 0.1969, 0.1974, 0.2013, 0.2050],
        [0.2180, 0.2288, 0.2389, 0.2405, 0.2525, 0.2647]])


In [27]:
attn_scores_lower = attn_scores.clone()
attn_scores_lower[torch.triu_indices(*attn_scores_lower.shape, offset=1).unbind(0)] = 0
print(attn_scores_lower)

tensor([[0.9811, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.1734, 1.3997, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.3421, 1.6010, 1.8412, 0.0000, 0.0000, 0.0000],
        [1.3731, 1.6379, 1.8836, 1.9211, 0.0000, 0.0000],
        [1.5727, 1.8761, 2.1575, 2.2004, 2.5247, 0.0000],
        [1.7724, 2.1142, 2.4314, 2.4797, 2.8452, 3.2107]])


In [28]:
# renormalize the attention scores
attn_scores_lower = attn_scores_lower / attn_scores_lower.sum(dim=1, keepdim=True)
print("renormalized attn_scores_lower:", attn_scores_lower)

renormalized attn_scores_lower: tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4560, 0.5440, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2805, 0.3346, 0.3848, 0.0000, 0.0000, 0.0000],
        [0.2015, 0.2403, 0.2764, 0.2819, 0.0000, 0.0000],
        [0.1522, 0.1816, 0.2088, 0.2130, 0.2444, 0.0000],
        [0.1193, 0.1423, 0.1637, 0.1669, 0.1915, 0.2162]])


In [29]:
# this causes data leakage, so we will not use this in the future
# but it is useful to understand how the attention mechanism works in a self-supervised manner

In [31]:
# Set the upper triangular part of the attention scores to negative infinity
attn_scores_upper = attn_scores.clone()
attn_scores_upper[torch.triu_indices(*attn_scores_upper.shape, offset=1).unbind(0)] = float('-inf')
print("attn_scores_upper:", attn_scores_upper)

attn_scores_upper: tensor([[0.9811,   -inf,   -inf,   -inf,   -inf,   -inf],
        [1.1734, 1.3997,   -inf,   -inf,   -inf,   -inf],
        [1.3421, 1.6010, 1.8412,   -inf,   -inf,   -inf],
        [1.3731, 1.6379, 1.8836, 1.9211,   -inf,   -inf],
        [1.5727, 1.8761, 2.1575, 2.2004, 2.5247,   -inf],
        [1.7724, 2.1142, 2.4314, 2.4797, 2.8452, 3.2107]])


In [33]:
#renormalize the attention scores
attn_scores_upper = torch.softmax(attn_scores_upper, dim=1)
print("renormalized attn_scores_upper:", attn_scores_upper)

renormalized attn_scores_upper: tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4437, 0.5563, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2536, 0.3286, 0.4178, 0.0000, 0.0000, 0.0000],
        [0.1755, 0.2287, 0.2924, 0.3035, 0.0000, 0.0000],
        [0.1161, 0.1572, 0.2084, 0.2175, 0.3008, 0.0000],
        [0.0740, 0.1042, 0.1431, 0.1502, 0.2165, 0.3120]])


In [None]:
# Now we will implement a causal attention mechanism with dropout
class CausalSelfAttention(nn.Module):
    def __init__(self, d_in, d_out, dropout=0.1):
        super(CausalSelfAttention, self).__init__()
        self.W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs):
        queries = inputs @ self.W_query
        keys = inputs @ self.W_key
        values = inputs @ self.W_value
        
        attn_scores = queries @ keys.T
        attn_scores = attn_scores / (self.W_query.shape[1] ** 0.5)
        
        # Set the upper triangular part of the attention scores to negative infinity
        attn_scores[torch.triu_indices(*attn_scores.shape, offset=1).unbind(0)] = float('-inf')
        
        attn_weights = torch.softmax(attn_scores, dim=1)
        attn_weights = self.dropout(attn_weights)  # Apply dropout
        
        context = attn_weights @ values
        return context