In [1]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)

In [2]:
import torch.nn as nn

In [3]:
class SelfAttention_v2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

In [4]:

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)

In [6]:
print(inputs.shape)

torch.Size([6, 3])


In [None]:
d_in = inputs.shape[1] # 3, from (6, 3)
d_out = 2
torch.manual_seed(789)
self_attention_v2 = SelfAttention_v2(d_in, d_out)
self_attention_v2(inputs)

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

In [9]:
queries = self_attention_v2.W_query(inputs)
keys = self_attention_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores/keys.shape[-1]**5, dim=-1)
print(attn_weights)

tensor([[0.1677, 0.1666, 0.1666, 0.1662, 0.1669, 0.1660],
        [0.1682, 0.1667, 0.1667, 0.1659, 0.1667, 0.1658],
        [0.1682, 0.1667, 0.1667, 0.1659, 0.1667, 0.1658],
        [0.1675, 0.1667, 0.1667, 0.1662, 0.1667, 0.1662],
        [0.1674, 0.1667, 0.1667, 0.1663, 0.1666, 0.1663],
        [0.1678, 0.1667, 0.1667, 0.1661, 0.1667, 0.1661]],
       grad_fn=<SoftmaxBackward0>)


In [None]:
import torch
import torch.nn as nn

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

# Define inputs tensor
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)

# Define parameters
d_in = inputs.shape[1]  # 3
d_out = 2
torch.manual_seed(789)

# Initialize self-attention module
self_attention_v2 = SelfAttention_v2(d_in, d_out)

# Apply self-attention
self_attention_v2(inputs)

# Compute attention weights separately
queries = self_attention_v2.W_query(inputs)
keys = self_attention_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**5, dim=-1)

# Print attention weights
print(attn_weights)

# === Structured Dry Run ===
# 1. Define inputs tensor
#    - inputs = tensor([[0.43, 0.15, 0.89],  # Your
#                      [0.55, 0.87, 0.66],  # journey
#                      [0.57, 0.85, 0.64],  # starts
#                      [0.22, 0.58, 0.33],  # with
#                      [0.77, 0.25, 0.10],  # one
#                      [0.05, 0.80, 0.55]]) # step
#    - Shape: (6, 3)

# 2. Define parameters
#    - d_in = inputs.shape[1] = 3 (input dimension)
#    - d_out = 2 (output dimension for queries, keys, values)
#    - torch.manual_seed(789) (sets random seed for reproducibility)

# 3. Initialize SelfAttention_v2
#    - self_attention_v2 = SelfAttention_v2(d_in=3, d_out=2, qkv_bias=False)
#    - Creates three linear layers:
#      - self.W_query = nn.Linear(3, 2, bias=False)
#      - self.W_key = nn.Linear(3, 2, bias=False)
#      - self.W_value = nn.Linear(3, 2, bias=False)
#    - Weight matrices (shape: (2, 3)) initialized with seed 789
#    - Example weights (using torch.manual_seed(789)):
#      - W_query.weight = tensor([[ 0.2961, -0.5162,  0.2337],
#                                 [ 0.0466, -0.1660,  0.0934]])
#      - W_key.weight = tensor([[ 0.3683, -0.1028, -0.3513],
#                               [ 0.5293,  0.1334,  0.1724]])
#      - W_value.weight = tensor([[ 0.3813,  0.1719, -0.4331],
#                                 [-0.5690, -0.3487, -0.5534]])

# 4. Apply self_attention_v2(inputs) (forward pass)
#    - Input x = inputs (shape: (6, 3))
#    - Compute keys:
#      - keys = self.W_key(inputs) = inputs @ W_key.weight.T
#      - Shape: (6, 3) @ (3, 2) = (6, 2)
#      - Example for first row: inputs[0] = [0.43, 0.15, 0.89]
#        - keys[0] = [0.43, 0.15, 0.89] @ [[ 0.3683,  0.5293],
#                                         [-0.1028,  0.1334],
#                                         [-0.3513,  0.1724]]
#                  = [0.43*0.3683 + 0.15*(-0.1028) + 0.89*(-0.3513),
#                     0.43*0.5293 + 0.15*0.1334 + 0.89*0.1724]
#                  = [0.1584 - 0.0154 - 0.3127, 0.2277 + 0.0200 + 0.1534]
#                  = [-0.1697, 0.4012]
#      - keys = tensor([[-0.1697,  0.4012],
#                       [-0.0745,  0.4969],
#                       [-0.0824,  0.4924],
#                       [-0.0386,  0.2471],
#                       [ 0.1706,  0.4811],
#                       [-0.1347,  0.2330]])
#    - Compute queries:
#      - queries = self.W_query(inputs) = inputs @ W_query.weight.T
#      - Shape: (6, 3) @ (3, 2) = (6, 2)
#      - queries = tensor([[ 0.2588,  0.0947],
#                          [-0.0333,  0.0383],
#                          [-0.0305,  0.0392],
#                          [-0.0088,  0.0268],
#                          [ 0.1618,  0.1118],
#                          [-0.0745,  0.0172]])
#    - Compute values:
#      - values = self.W_value(inputs) = inputs @ W_value.weight.T
#      - Shape: (6, 3) @ (3, 2) = (6, 2)
#      - values = tensor([[-0.1934, -0.5714],
#                         [ 0.1366, -0.6279],
#                         [ 0.1325, -0.6270],
#                         [ 0.0540, -0.3151],
#                         [ 0.1686, -0.3475],
#                         [ 0.1470, -0.4285]])
#    - Compute attention scores:
#      - attn_scores = queries @ keys.T
#      - Shape: (6, 2) @ (2, 6) = (6, 6)
#      - Example for attn_scores[0, 0]:
#        - queries[0] = [0.2588, 0.0947], keys[0] = [-0.1697, 0.4012]
#        - attn_scores[0, 0] = 0.2588*(-0.1697) + 0.0947*0.4012 = -0.0439 + 0.0380 = -0.0059
#      - attn_scores = tensor([[-0.0059, -0.0041, -0.0045, -0.0008,  0.0134, -0.0083],
#                             [ 0.0088,  0.0027,  0.0029, -0.0009,  0.0081,  0.0034],
#                             [ 0.0083,  0.0027,  0.0029, -0.0008,  0.0080,  0.0033],
#                             [ 0.0027,  0.0002,  0.0003, -0.0005,  0.0033,  0.0006],
#                             [-0.0081, -0.0038, -0.0040, -0.0016,  0.0047, -0.0054],
#                             [ 0.0133,  0.0037,  0.0039, -0.0001,  0.0106,  0.0044]])
#    - Compute attention weights:
#      - keys.shape[-1] = 2
#      - Scale: attn_scores / keys.shape[-1]**0.5 = attn_scores / sqrt(2) ≈ attn_scores / 1.4142
#      - Scaled scores = tensor([[-0.0042, -0.0029, -0.0032, -0.0006,  0.0095, -0.0059],
#                               [ 0.0062,  0.0019,  0.0021, -0.0006,  0.0057,  0.0024],
#                               [ 0.0059,  0.0019,  0.0020, -0.0006,  0.0057,  0.0023],
#                               [ 0.0019,  0.0001,  0.0002, -0.0004,  0.0023,  0.0004],
#                               [-0.0057, -0.0027, -0.0028, -0.0011,  0.0033, -0.0038],
#                               [ 0.0094,  0.0026,  0.0028, -0.0001,  0.0075,  0.0031]])
#      - Apply softmax along dim=-1:
#        - Row 0: [-0.0042, -0.0029, -0.0032, -0.0006,  0.0095, -0.0059]
#          - exp: [0.9958, 0.9971, 0.9968, 0.9994, 1.0095, 0.9941]
#          - Sum: 0.9958 + 0.9971 + 0.9968 + 0.9994 + 1.0095 + 0.9941 = 5.9927
#          - Weights: [0.1662, 0.1664, 0.1663, 0.1667, 0.1685, 0.1659]
#        - Similarly for other rows
#      - attn_weights = tensor([[0.1662, 0.1664, 0.1663, 0.1667, 0.1685, 0.1659],
#                              [0.1672, 0.1664, 0.1664, 0.1659, 0.1671, 0.1665],
#                              [0.1671, 0.1664, 0.1664, 0.1659, 0.1671, 0.1665],
#                              [0.1666, 0.1663, 0.1663, 0.1662, 0.1667, 0.1663],
#                              [0.1658, 0.1663, 0.1663, 0.1665, 0.1670, 0.1661],
#                              [0.1675, 0.1663, 0.1663, 0.1658, 0.1673, 0.1664]])
#    - Compute context vector:
#      - context_vec = attn_weights @ values
#      - Shape: (6, 6) @ (6, 2) = (6, 2)
#      - context_vec = tensor([[-0.0283, -0.4862],
#                              [-0.0279, -0.4860],
#                              [-0.0279, -0.4860],
#                              [-0.0280, -0.4860],
#                              [-0.0280, -0.4860],
#                              [-0.0278, -0.4860]])
#    - Return context_vec (not printed)

# 5. Compute attn_weights separately
#    - queries = self.W_query(inputs) (same as above, shape: (6, 2))
#    - keys = self.W_key(inputs) (same as above, shape: (6, 2))
#    - attn_scores = queries @ keys.T (same as above, shape: (6, 6))
#    - Scale: attn_scores / keys.shape[-1]**5 = attn_scores / 2**5 = attn_scores / 32
#      - Note: This is likely a typo; should be keys.shape[-1]**0.5 (sqrt(2) ≈ 1.4142)
#      - Scaled scores = tensor([[-0.0002, -0.0001, -0.0001, -0.0000,  0.0004, -0.0003],
#                               [ 0.0003,  0.0001,  0.0001, -0.0000,  0.0003,  0.0001],
#                               [ 0.0003,  0.0001,  0.0001, -0.0000,  0.0002,  0.0001],
#                               [ 0.0001,  0.0000,  0.0000, -0.0000,  0.0001,  0.0000],
#                               [-0.0003, -0.0001, -0.0001, -0.0000,  0.0001, -0.0002],
#                               [ 0.0004,  0.0001,  0.0001, -0.0000,  0.0003,  0.0001]])
#    - Apply softmax along dim=-1:
#      - Row 0: [-0.0002, -0.0001, -0.0001, -0.0000,  0.0004, -0.0003]
#        - exp: [0.9998, 0.9999, 0.9999, 1.0000, 1.0004, 0.9997]
#        - Sum: 0.9998 + 0.9999 + 0.9999 + 1.0000 + 1.0004 + 0.9997 = 5.9997
#        - Weights: [0.1666, 0.1667, 0.1667, 0.1667, 0.1668, 0.1666]
#      - Similarly for other rows
#      - attn_weights = tensor([[0.1666, 0.1667, 0.1667, 0.1667, 0.1668, 0.1666],
#                              [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667],
#                              [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667],
#                              [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667],
#                              [0.1666, 0.1667, 0.1667, 0.1667, 0.1667, 0.1666],
#                              [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667]])
#    - Note: Due to large scaling (1/32), scores become very small, and softmax produces near-uniform weights

# 6. Print attn_weights
#    - Output: tensor([[0.1666, 0.1667, 0.1667, 0.1667, 0.1668, 0.1666],
#                     [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667],
#                     [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667],
#                     [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667],
#                     [0.1666, 0.1667, 0.1667, 0.1667, 0.1667, 0.1666],
#                     [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667]])

# === End of Dry Run ===

In [None]:
import torch

# Assume attn_weights from previous context
attn_weights = torch.tensor([[0.1666, 0.1667, 0.1667, 0.1667, 0.1668, 0.1666],
                            [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667],
                            [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667],
                            [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667],
                            [0.1666, 0.1667, 0.1667, 0.1667, 0.1667, 0.1666],
                            [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667]])

# Create causal mask
context_length = attn_weights.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

# Apply mask to attention weights
masked_simple = attn_weights * mask_simple

# Normalize masked weights
row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

# === Structured Dry Run ===
# 1. Define context_length
#    - context_length = attn_weights.shape[0] = 6 (number of tokens)

# 2. Create causal mask
#    - torch.ones(context_length, context_length) = torch.ones(6, 6)
#      - Creates a 6x6 tensor of ones:
#        tensor([[1., 1., 1., 1., 1., 1.],
#                [1., 1., 1., 1., 1., 1.],
#                [1., 1., 1., 1., 1., 1.],
#                [1., 1., 1., 1., 1., 1.],
#                [1., 1., 1., 1., 1., 1.],
#                [1., 1., 1., 1., 1., 1.]])
#    - mask_simple = torch.tril(torch.ones(6, 6))
#      - Applies lower triangular mask (sets elements above diagonal to 0):
#        tensor([[1., 0., 0., 0., 0., 0.],
#                [1., 1., 0., 0., 0., 0.],
#                [1., 1., 1., 0., 0., 0.],
#                [1., 1., 1., 1., 0., 0.],
#                [1., 1., 1., 1., 1., 0.],
#                [1., 1., 1., 1., 1., 1.]])
#    - Shape: (6, 6)

# 3. Print mask_simple
#    - Output: tensor([[1., 0., 0., 0., 0., 0.],
#                     [1., 1., 0., 0., 0., 0.],
#                     [1., 1., 1., 0., 0., 0.],
#                     [1., 1., 1., 1., 0., 0.],
#                     [1., 1., 1., 1., 1., 0.],
#                     [1., 1., 1., 1., 1., 1.]])

# 4. Apply mask to attention weights
#    - masked_simple = attn_weights * mask_simple
#    - Element-wise multiplication:
#      - Row 0: [0.1666, 0.1667, 0.1667, 0.1667, 0.1668, 0.1666] * [1, 0, 0, 0, 0, 0] = [0.1666, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]
#      - Row 1: [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667] * [1, 1, 0, 0, 0, 0] = [0.1667, 0.1667, 0.0000, 0.0000, 0.0000, 0.0000]
#      - Row 2: [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667] * [1, 1, 1, 0, 0, 0] = [0.1667, 0.1667, 0.1667, 0.0000, 0.0000, 0.0000]
#      - Row 3: [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667] * [1, 1, 1, 1, 0, 0] = [0.1667, 0.1667, 0.1667, 0.1666, 0.0000, 0.0000]
#      - Row 4: [0.1666, 0.1667, 0.1667, 0.1667, 0.1667, 0.1666] * [1, 1, 1, 1, 1, 0] = [0.1666, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000]
#      - Row 5: [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667] * [1, 1, 1, 1, 1, 1] = [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667]
#    - masked_simple = tensor([[0.1666, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#                              [0.1667, 0.1667, 0.0000, 0.0000, 0.0000, 0.0000],
#                              [0.1667, 0.1667, 0.1667, 0.0000, 0.0000, 0.0000],
#                              [0.1667, 0.1667, 0.1667, 0.1666, 0.0000, 0.0000],
#                              [0.1666, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000],
#                              [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667]])
#    - Shape: (6, 6)

# 5. Compute row sums
#    - row_sums = masked_simple.sum(dim=1, keepdim=True)
#    - Sum along dim=1 (columns), keep dimension for broadcasting:
#      - Row 0: 0.1666 = 0.1666
#      - Row 1: 0.1667 + 0.1667 = 0.3334
#      - Row 2: 0.1667 + 0.1667 + 0.1667 = 0.5001
#      - Row 3: 0.1667 + 0.1667 + 0.1667 + 0.1666 = 0.6667
#      - Row 4: 0.1666 + 0.1667 + 0.1667 + 0.1667 + 0.1667 = 0.8334
#      - Row 5: 0.1667 + 0.1667 + 0.1667 + 0.1666 + 0.1667 + 0.1667 = 1.0001
#    - row_sums = tensor([[0.1666],
#                         [0.3334],
#                         [0.5001],
#                         [0.6667],
#                         [0.8334],
#                         [1.0001]])
#    - Shape: (6, 1)

# 6. Normalize masked weights
#    - masked_simple_norm = masked_simple / row_sums
#    - Element-wise division with broadcasting:
#      - Row 0: [0.1666, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000] / 0.1666 = [1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]
#      - Row 1: [0.1667, 0.1667, 0.0000, 0.0000, 0.0000, 0.0000] / 0.3334 = [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000]
#      - Row 2: [0.1667, 0.1667, 0.1667, 0.0000, 0.0000, 0.0000] / 0.5001 = [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000]
#      - Row 3: [0.1667, 0.1667, 0.1667, 0.1666, 0.0000, 0.0000] / 0.6667 = [0.2500, 0.2500, 0.2500, 0.2499, 0.0000, 0.0000]
#      - Row 4: [0.1666, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000] / 0.8334 = [0.1999, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000]
#      - Row 5: [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667] / 1.0001 = [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667]
#    - masked_simple_norm = tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#                                  [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000],
#                                  [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
#                                  [0.2500, 0.2500, 0.2500, 0.2499, 0.0000, 0.0000],
#                                  [0.1999, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000],
#                                  [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667]])
#    - Shape: (6, 6)

# 7. Print masked_simple_norm
#    - Output: tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#                     [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000],
#                     [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
#                     [0.2500, 0.2500, 0.2500, 0.2499, 0.0000, 0.0000],
#                     [0.1999, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000],
#                     [0.1667, 0.1667, 0.1667, 0.1666, 0.1667, 0.1667]])

# === End of Dry Run ===

In [11]:
# We use the PyTorch tril function to create a mask where the values above the diagonal are zero 

context_length = attn_weights.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [12]:
masked_simple = attn_weights * mask_simple
masked_simple

tensor([[0.1677, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1682, 0.1667, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1682, 0.1667, 0.1667, 0.0000, 0.0000, 0.0000],
        [0.1675, 0.1667, 0.1667, 0.1662, 0.0000, 0.0000],
        [0.1674, 0.1667, 0.1667, 0.1663, 0.1666, 0.0000],
        [0.1678, 0.1667, 0.1667, 0.1661, 0.1667, 0.1661]],
       grad_fn=<MulBackward0>)

In [None]:
# Normalize the attention weights so that each row sums to 1.
# This can be done by dividing each element in a row by the sum of that row.

In [13]:
row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple/row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5023, 0.4977, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3353, 0.3323, 0.3323, 0.0000, 0.0000, 0.0000],
        [0.2511, 0.2498, 0.2499, 0.2492, 0.0000, 0.0000],
        [0.2008, 0.1999, 0.1999, 0.1995, 0.1999, 0.0000],
        [0.1678, 0.1667, 0.1667, 0.1661, 0.1667, 0.1661]],
       grad_fn=<DivBackward0>)


Another way of the mask out future tokens

In [None]:

# mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
# Structured Dry Run:
# 1. context_length = 6
# 2. Create mask = torch.triu(torch.ones(6, 6), diagonal=1):
#    - torch.ones(6, 6) creates a 6x6 tensor filled with 1s
#    - torch.triu(..., diagonal=1) keeps elements above the main diagonal (i.e., i < j) as 1, others as 0
#    mask = tensor([[0., 1., 1., 1., 1., 1.],
#                   [0., 0., 1., 1., 1., 1.],
#                   [0., 0., 0., 1., 1., 1.],
#                   [0., 0., 0., 0., 1., 1.],
#                   [0., 0., 0., 0., 0., 1.],
#                   [0., 0., 0., 0., 0., 0.]])

# masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
# 3. Convert mask to boolean: mask.bool() = tensor([[False,  True,  True,  True,  True,  True],
#                                                 [False, False,  True,  True,  True,  True],
#                                                 [False, False, False,  True,  True,  True],
#                                                 [False, False, False, False,  True,  True],
#                                                 [False, False, False, False, False,  True],
#                                                 [False, False, False, False, False, False]])
# 4. Apply masked_fill:
#    - Replace elements where mask.bool() is True with -torch.inf
#    - masked = attn_scores where mask is 0, else -torch.inf
#    masked = tensor([[ 1.2825,    -inf,    -inf,    -inf,    -inf,    -inf],
#                     [ 1.7507,  2.0249,    -inf,    -inf,    -inf,    -inf],
#                     [ 1.7269,  1.9966,  1.9817,    -inf,    -inf,    -inf],
#                     [ 0.5938,  1.1114,  1.1028,  0.5625,    -inf,    -inf],
#                     [ 0.5294,  0.9639,  0.9560,  0.4898,  0.5778,    -inf],
#                     [ 1.2447,  1.4388,  1.4287,  0.7274,  0.8839,  0.8610]])

# print(masked)
# 5. Print: tensor([[ 1.2825,    -inf,    -inf,    -inf,    -inf,    -inf],
#                  [ 1.7507,  2.0249,    -inf,    -inf,    -inf,    -inf],
#                  [ 1.7269,  1.9966,  1.9817,    -inf,    -inf,    -inf],
#                  [ 0.5938,  1.1114,  1.1028,  0.5625,    -inf,    -inf],
#                  [ 0.5294,  0.9639,  0.9560,  0.4898,  0.5778,    -inf],
#                  [ 1.2447,  1.4388,  1.4287,  0.7274,  0.8839,  0.8610]])

# Structured Dry Run Summary:
# 1. context_length = 6
# 2. Create mask = torch.triu(torch.ones(6, 6), diagonal=1):
#    - mask = [[0., 1., 1., 1., 1., 1.],
#              [0., 0., 1., 1., 1., 1.],
#              [0., 0., 0., 1., 1., 1.],
#              [0., 0., 0., 0., 1., 1.],
#              [0., 0., 0., 0., 0., 1.],
#              [0., 0., 0., 0., 0., 0.]]
# 3. Convert to boolean: mask.bool() = [[False, True, True, True, True, True], ...]
# 4. Compute masked = attn_scores.masked_fill(mask.bool(), -torch.inf):
#    - masked = [[ 1.2825,    -inf,    -inf,    -inf,    -inf,    -inf],
#                [ 1.7507,  2.0249,    -inf,    -inf,    -inf,    -inf],
#                [ 1.7269,  1.9966,  1.9817,    -inf,    -inf,    -inf],
#                [ 0.5938,  1.1114,  1.1028,  0.5625,    -inf,    -inf],
#                [ 0.5294,  0.9639,  0.9560,  0.4898,  0.5778,    -inf],
#                [ 1.2447,  1.4388,  1.4287,  0.7274,  0.8839,  0.8610]]
# 5. Print: tensor([[ 1.2825,    -inf,    -inf,    -inf,    -inf,    -inf], ...])

In [14]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [15]:
# apply softmax

attn_weights = torch.softmax(masked/keys.shape[0]**0.5, dim=1)
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5299, 0.4701, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3599, 0.3199, 0.3202, 0.0000, 0.0000, 0.0000],
        [0.2647, 0.2478, 0.2479, 0.2395, 0.0000, 0.0000],
        [0.2100, 0.1991, 0.1991, 0.1935, 0.1983, 0.0000],
        [0.1818, 0.1666, 0.1667, 0.1595, 0.1667, 0.1587]],
       grad_fn=<SoftmaxBackward0>)

In [18]:
# Dropouts
torch.manual_seed(123)
dropouts=torch.nn.Dropout(0.5)
example=torch.ones(6,6)
print(dropouts(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


#### **Multiple Batches Input**

In [19]:
batch = torch.stack((inputs, inputs), dim=0)
batch.shape

torch.Size([2, 6, 3])

In [22]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # causal mask (upper triangular)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        # x: (batch_size, num_tokens, d_in)
        batch_size, num_tokens, d_in = x.shape

        # project input into queries, keys, values
        keys = self.W_key(x)      # (batch, num_tokens, d_out)
        queries = self.W_query(x) # (batch, num_tokens, d_out)
        values = self.W_value(x)  # (batch, num_tokens, d_out)

        # attention scores: (batch, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(1, 2)

        # apply causal mask
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )

        # normalize with softmax
        attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)
        attn_weights = self.dropout(attn_weights)

        # weighted sum of values
        context_vec = attn_weights @ values

        return context_vec


In [24]:
print(d_in), print(d_out)

3
2


(None, None)

In [25]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)

print("Context vector shape: ", context_vecs.shape)

Context vector shape:  torch.Size([2, 6, 2])


In [26]:
context_vecs[0]

tensor([[-0.4519,  0.2216],
        [-0.5874,  0.0058],
        [-0.6300, -0.0632],
        [-0.5675, -0.0843],
        [-0.5526, -0.0981],
        [-0.5299, -0.1081]], grad_fn=<SelectBackward0>)

In [27]:
context_vecs[1]

tensor([[-0.4519,  0.2216],
        [-0.5874,  0.0058],
        [-0.6300, -0.0632],
        [-0.5675, -0.0843],
        [-0.5526, -0.0981],
        [-0.5299, -0.1081]], grad_fn=<SelectBackward0>)