In [1]:
import torch
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # your     (x^1)
 [0.55, 0.87, 0.66], # journey  (x^2)
 [0.55, 0.87, 0.66], # starts   (x^3)
 [0.22, 0.58, 0.33], # with     (x^4)
 [0.77, 0.25, 0.10], # one      (x^5)
 [0.05, 0.80, 0.55]] # step     (x^6)
)

In [2]:
x_2 = inputs[1] 
d_in = inputs.shape[1] # input embedding size , d = 3
d_out = 2 # output embedding size

In [3]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [4]:
W_query

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])

In [5]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)

tensor([0.4306, 1.4551])


In [6]:
keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


In [7]:
keys

tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4433, 1.1419],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]])

In [8]:
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22) # unnormalized attention score

tensor(1.8524)


In [9]:
attn_scores_2 = query_2 @ keys.T # computaton to all attention scores
print(attn_scores_2) # 2nd element matches we computed prev (attn_scores_22)

tensor([1.2705, 1.8524, 1.8524, 1.0795, 0.5577, 1.5440])


In [10]:
d_k = keys.shape[-1] 
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1) # qk^T / root_dk(dim of the key matrix)
print(attn_weights_2)

tensor([0.1490, 0.2249, 0.2249, 0.1302, 0.0900, 0.1808])


In [11]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3082, 0.8267])


In [12]:
# compact self-attention class

In [13]:
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
        
    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax( attn_scores / keys.shape[-1]**0.5, dim=-1 )
        context_vec = attn_weights @ values
        return context_vec

In [14]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs)) # inputs contains 6 embedding vectors , results in a matrix storing 6 context vectors

tensor([[0.3015, 0.8104],
        [0.3082, 0.8267],
        [0.3082, 0.8267],
        [0.2965, 0.7986],
        [0.2944, 0.7936],
        [0.3009, 0.8091]], grad_fn=<MmBackward0>)


In [15]:
# self-attention using linear layer.
# instead of manually implementing nn.Parameter(torch.rand(...)) is that nn.Linear
# has an optimized weight initialization scheme, contributing to more stable and
# effective model training.

In [16]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)
        
    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax( attn_scores / keys.shape[-1]**0.5, dim=-1 )
        context_vec = attn_weights @ values
        return context_vec

In [17]:
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs)) # SelfAttention_v1 and SelfAttention_v2 give different outputs because
                     # they use different initial weights for the weight matrices since nn.Linear uses a more
                     # sophisticated weight initialization scheme.

tensor([[-0.5340, -0.1049],
        [-0.5326, -0.1078],
        [-0.5326, -0.1078],
        [-0.5300, -0.1074],
        [-0.5313, -0.1064],
        [-0.5302, -0.1079]], grad_fn=<MmBackward0>)


In [18]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
                                # Reuses the query and key weight matrices
                                # of the SelfAttention_v2 object from the
                                # previous section for convenience
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0.1716, 0.1762, 0.1762, 0.1555, 0.1626, 0.1579],
        [0.1635, 0.1749, 0.1749, 0.1611, 0.1604, 0.1651],
        [0.1635, 0.1749, 0.1749, 0.1611, 0.1604, 0.1651],
        [0.1636, 0.1703, 0.1703, 0.1651, 0.1632, 0.1674],
        [0.1667, 0.1722, 0.1722, 0.1617, 0.1633, 0.1639],
        [0.1624, 0.1708, 0.1708, 0.1654, 0.1624, 0.1681]],
       grad_fn=<SoftmaxBackward0>)


In [19]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length)) #tril fn to create a mask where the values above the diagonal are zero
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [20]:
masked_simple = attn_weights * mask_simple
masked_simple

tensor([[0.1716, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1635, 0.1749, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1635, 0.1749, 0.1749, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1703, 0.1703, 0.1651, 0.0000, 0.0000],
        [0.1667, 0.1722, 0.1722, 0.1617, 0.1633, 0.0000],
        [0.1624, 0.1708, 0.1708, 0.1654, 0.1624, 0.1681]],
       grad_fn=<MulBackward0>)

In [21]:
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm) # renormalize the attention weights to sum up to 1 again in each row

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3186, 0.3407, 0.3407, 0.0000, 0.0000, 0.0000],
        [0.2444, 0.2544, 0.2544, 0.2467, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2060, 0.1934, 0.1953, 0.0000],
        [0.1624, 0.1708, 0.1708, 0.1654, 0.1624, 0.1681]],
       grad_fn=<DivBackward0>)


In [23]:
row_sums

tensor([[0.1716],
        [0.3384],
        [0.5133],
        [0.6694],
        [0.8361],
        [1.0000]], grad_fn=<SumBackward1>)

In [24]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.3111,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1655, 0.2602,   -inf,   -inf,   -inf,   -inf],
        [0.1655, 0.2602, 0.2602,   -inf,   -inf,   -inf],
        [0.0510, 0.1080, 0.1080, 0.0643,   -inf,   -inf],
        [0.1415, 0.1875, 0.1875, 0.0987, 0.1121,   -inf],
        [0.0476, 0.1192, 0.1192, 0.0731, 0.0477, 0.0966]],
       grad_fn=<MaskedFillBackward0>)


In [25]:
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3186, 0.3407, 0.3407, 0.0000, 0.0000, 0.0000],
        [0.2444, 0.2544, 0.2544, 0.2467, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2060, 0.1934, 0.1953, 0.0000],
        [0.1624, 0.1708, 0.1708, 0.1654, 0.1624, 0.1681]],
       grad_fn=<SoftmaxBackward0>)


In [26]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


In [27]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6373, 0.6814, 0.6814, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5089, 0.5089, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4119, 0.0000, 0.3869, 0.0000, 0.0000],
        [0.0000, 0.3417, 0.3417, 0.3307, 0.3248, 0.0000]],
       grad_fn=<MulBackward0>)


In [28]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

torch.Size([2, 6, 3])
