In [1]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)

In [2]:
import torch.nn as nn

In [6]:
class SelfAttention_v2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

In [7]:

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)

In [8]:
d_in = inputs.shape[1]
d_out = 2
torch.manual_seed(789)
self_attention_v2 = SelfAttention_v2(d_in, d_out)
self_attention_v2(inputs)

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

In [9]:
queries = self_attention_v2.W_query(inputs)
keys = self_attention_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores/keys.shape[-1]**5, dim=-1)
print(attn_weights)

tensor([[0.1677, 0.1666, 0.1666, 0.1662, 0.1669, 0.1660],
        [0.1682, 0.1667, 0.1667, 0.1659, 0.1667, 0.1658],
        [0.1682, 0.1667, 0.1667, 0.1659, 0.1667, 0.1658],
        [0.1675, 0.1667, 0.1667, 0.1662, 0.1667, 0.1662],
        [0.1674, 0.1667, 0.1667, 0.1663, 0.1666, 0.1663],
        [0.1678, 0.1667, 0.1667, 0.1661, 0.1667, 0.1661]],
       grad_fn=<SoftmaxBackward0>)


In [11]:
# We use the PyTorch tril function to create a mask where the values above the diagonal are zero 

context_length = attn_weights.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [12]:
masked_simple = attn_weights * mask_simple
masked_simple

tensor([[0.1677, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1682, 0.1667, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1682, 0.1667, 0.1667, 0.0000, 0.0000, 0.0000],
        [0.1675, 0.1667, 0.1667, 0.1662, 0.0000, 0.0000],
        [0.1674, 0.1667, 0.1667, 0.1663, 0.1666, 0.0000],
        [0.1678, 0.1667, 0.1667, 0.1661, 0.1667, 0.1661]],
       grad_fn=<MulBackward0>)

In [None]:
# Normalize the attention weights so that each row sums to 1.
# This can be done by dividing each element in a row by the sum of that row.

In [13]:
row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple/row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5023, 0.4977, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3353, 0.3323, 0.3323, 0.0000, 0.0000, 0.0000],
        [0.2511, 0.2498, 0.2499, 0.2492, 0.0000, 0.0000],
        [0.2008, 0.1999, 0.1999, 0.1995, 0.1999, 0.0000],
        [0.1678, 0.1667, 0.1667, 0.1661, 0.1667, 0.1661]],
       grad_fn=<DivBackward0>)


Another way of the mask out future tokens

In [None]:

# mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
# Structured Dry Run:
# 1. context_length = 6
# 2. Create mask = torch.triu(torch.ones(6, 6), diagonal=1):
#    - torch.ones(6, 6) creates a 6x6 tensor filled with 1s
#    - torch.triu(..., diagonal=1) keeps elements above the main diagonal (i.e., i < j) as 1, others as 0
#    mask = tensor([[0., 1., 1., 1., 1., 1.],
#                   [0., 0., 1., 1., 1., 1.],
#                   [0., 0., 0., 1., 1., 1.],
#                   [0., 0., 0., 0., 1., 1.],
#                   [0., 0., 0., 0., 0., 1.],
#                   [0., 0., 0., 0., 0., 0.]])

# masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
# 3. Convert mask to boolean: mask.bool() = tensor([[False,  True,  True,  True,  True,  True],
#                                                 [False, False,  True,  True,  True,  True],
#                                                 [False, False, False,  True,  True,  True],
#                                                 [False, False, False, False,  True,  True],
#                                                 [False, False, False, False, False,  True],
#                                                 [False, False, False, False, False, False]])
# 4. Apply masked_fill:
#    - Replace elements where mask.bool() is True with -torch.inf
#    - masked = attn_scores where mask is 0, else -torch.inf
#    masked = tensor([[ 1.2825,    -inf,    -inf,    -inf,    -inf,    -inf],
#                     [ 1.7507,  2.0249,    -inf,    -inf,    -inf,    -inf],
#                     [ 1.7269,  1.9966,  1.9817,    -inf,    -inf,    -inf],
#                     [ 0.5938,  1.1114,  1.1028,  0.5625,    -inf,    -inf],
#                     [ 0.5294,  0.9639,  0.9560,  0.4898,  0.5778,    -inf],
#                     [ 1.2447,  1.4388,  1.4287,  0.7274,  0.8839,  0.8610]])

# print(masked)
# 5. Print: tensor([[ 1.2825,    -inf,    -inf,    -inf,    -inf,    -inf],
#                  [ 1.7507,  2.0249,    -inf,    -inf,    -inf,    -inf],
#                  [ 1.7269,  1.9966,  1.9817,    -inf,    -inf,    -inf],
#                  [ 0.5938,  1.1114,  1.1028,  0.5625,    -inf,    -inf],
#                  [ 0.5294,  0.9639,  0.9560,  0.4898,  0.5778,    -inf],
#                  [ 1.2447,  1.4388,  1.4287,  0.7274,  0.8839,  0.8610]])

# Structured Dry Run Summary:
# 1. context_length = 6
# 2. Create mask = torch.triu(torch.ones(6, 6), diagonal=1):
#    - mask = [[0., 1., 1., 1., 1., 1.],
#              [0., 0., 1., 1., 1., 1.],
#              [0., 0., 0., 1., 1., 1.],
#              [0., 0., 0., 0., 1., 1.],
#              [0., 0., 0., 0., 0., 1.],
#              [0., 0., 0., 0., 0., 0.]]
# 3. Convert to boolean: mask.bool() = [[False, True, True, True, True, True], ...]
# 4. Compute masked = attn_scores.masked_fill(mask.bool(), -torch.inf):
#    - masked = [[ 1.2825,    -inf,    -inf,    -inf,    -inf,    -inf],
#                [ 1.7507,  2.0249,    -inf,    -inf,    -inf,    -inf],
#                [ 1.7269,  1.9966,  1.9817,    -inf,    -inf,    -inf],
#                [ 0.5938,  1.1114,  1.1028,  0.5625,    -inf,    -inf],
#                [ 0.5294,  0.9639,  0.9560,  0.4898,  0.5778,    -inf],
#                [ 1.2447,  1.4388,  1.4287,  0.7274,  0.8839,  0.8610]]
# 5. Print: tensor([[ 1.2825,    -inf,    -inf,    -inf,    -inf,    -inf], ...])

In [14]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [15]:
# apply softmax

attn_weights = torch.softmax(masked/keys.shape[0]**0.5, dim=1)
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5299, 0.4701, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3599, 0.3199, 0.3202, 0.0000, 0.0000, 0.0000],
        [0.2647, 0.2478, 0.2479, 0.2395, 0.0000, 0.0000],
        [0.2100, 0.1991, 0.1991, 0.1935, 0.1983, 0.0000],
        [0.1818, 0.1666, 0.1667, 0.1595, 0.1667, 0.1587]],
       grad_fn=<SoftmaxBackward0>)

In [18]:
# Dropouts
torch.manual_seed(123)
dropouts=torch.nn.Dropout(0.5)
example=torch.ones(6,6)
print(dropouts(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


#### **Multiple Batches Input**

In [19]:
batch = torch.stack((inputs, inputs), dim=0)
batch.shape

torch.Size([2, 6, 3])

In [22]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # causal mask (upper triangular)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        # x: (batch_size, num_tokens, d_in)
        batch_size, num_tokens, d_in = x.shape

        # project input into queries, keys, values
        keys = self.W_key(x)      # (batch, num_tokens, d_out)
        queries = self.W_query(x) # (batch, num_tokens, d_out)
        values = self.W_value(x)  # (batch, num_tokens, d_out)

        # attention scores: (batch, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(1, 2)

        # apply causal mask
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )

        # normalize with softmax
        attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)
        attn_weights = self.dropout(attn_weights)

        # weighted sum of values
        context_vec = attn_weights @ values

        return context_vec


In [24]:
print(d_in), print(d_out)

3
2


(None, None)

In [25]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)

print("Context vector shape: ", context_vecs.shape)

Context vector shape:  torch.Size([2, 6, 2])


In [26]:
context_vecs[0]

tensor([[-0.4519,  0.2216],
        [-0.5874,  0.0058],
        [-0.6300, -0.0632],
        [-0.5675, -0.0843],
        [-0.5526, -0.0981],
        [-0.5299, -0.1081]], grad_fn=<SelectBackward0>)

In [27]:
context_vecs[1]

tensor([[-0.4519,  0.2216],
        [-0.5874,  0.0058],
        [-0.6300, -0.0632],
        [-0.5675, -0.0843],
        [-0.5526, -0.0981],
        [-0.5299, -0.1081]], grad_fn=<SelectBackward0>)