Import PyTorch modules needed for the class.

In [1]:
import torch
import torch.nn as nn

Define the parameters: batch size (b), sequence length (num_tokens), input dimension (d_in), output dimension (d_out), number of heads (num_heads), and per-head dimension (head_dim = d_out // num_heads). Also create the dropout layer as an nn.Dropout module with the chosen rate (0.0 here, so nothing is dropped, but the module is ready for non-zero rates), and set qkv_bias = False so the query, key, and value projections have no bias term.

Efficiency: The dropout module is created once and reused on every forward pass instead of being recreated per batch. Dropout also adds no matrix multiplications, so the batched attention mechanism keeps its two core multiplications (one for the scores, one for the weighted sum of values).
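A minimal sketch of how nn.Dropout behaves, assuming a stand-in tensor w in place of real attention weights and an illustrative rate of 0.5 (the notebook itself uses 0.0). With p = 0.0 the output equals the input; in training mode a non-zero rate zeroes entries and scales the survivors by 1 / (1 - p); in evaluation mode dropout is switched off.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
w = torch.rand(2, 3, 3) + 0.1  # stand-in for attention weights, strictly positive

# Rate 0.0, as in this notebook: the output equals the input
out_zero = nn.Dropout(0.0)(w)

# Rate 0.5 (illustrative only)
drop = nn.Dropout(0.5)
drop.train()                    # modules start in training mode; made explicit here
out_train = drop(w)
kept = out_train != 0           # positions that survived (w has no zeros)
print("fraction kept:", kept.float().mean().item())

drop.eval()                     # evaluation mode disables dropout
out_eval = drop(w)
```

Calling .eval() on an enclosing model switches every dropout layer inside it off at once, which is why inference code typically calls model.eval().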

In [3]:
# Input shape
b = 2  # Batch size
num_tokens = 3  # Sequence length (context length)
d_in = 4  # Input embedding dimension

# Output shape
d_out = 6  # Output dimension (must be divisible by num_heads)
num_heads = 3  # Number of attention heads
head_dim = d_out // num_heads  # Per-head dimension (2 here)

dropout = nn.Dropout(0)  # Dropout rate (0 for no dropout)
qkv_bias = False  # No bias in projections

Generate a random input tensor x with shape (b, num_tokens, d_in), representing a batch of token embeddings. Every later computation starts from x, and its last dimension (d_in) must match the input size (in_features) of the projection layers for the matrix multiplication to work.

Efficiency: Random data allows testing without a real dataset, and the small sizes (3 tokens, 4 input dimensions) keep computations fast while still demonstrating every step of the mechanism.
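Because the notebook does not seed the random number generator, rerunning it produces values different from those printed here. A small sketch, assuming an arbitrary seed of 123: calling torch.manual_seed before sampling makes torch.randn return the same tensor every time.

```python
import torch

b, num_tokens, d_in = 2, 3, 4  # same sizes as in the notebook

torch.manual_seed(123)  # 123 is an arbitrary choice
x1 = torch.randn(b, num_tokens, d_in)

torch.manual_seed(123)  # reset the global RNG to the same state
x2 = torch.randn(b, num_tokens, d_in)

print(torch.equal(x1, x2))  # True: same seed, same draws
```

Seeding before the nn.Linear layers are created would make their random initial weights reproducible as well.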

In [4]:
# Sample input: batched random embeddings
x = torch.randn(b, num_tokens, d_in)
print("Input x shape:", x.shape)
print("Input x:\n", x)

Input x shape: torch.Size([2, 3, 4])
Input x:
 tensor([[[ 0.5846, -2.4947,  1.4039,  1.6177],
         [ 1.0283,  0.4905, -0.9044,  1.4440],
         [ 0.2894, -0.3367, -1.3002,  1.1622]],

        [[ 1.5172,  0.1696, -0.0997, -0.7608],
         [-0.5116,  1.0167,  0.2778, -0.4208],
         [-0.5266,  1.3310, -0.2707,  0.5006]]])


Create the linear projection layers for queries (W_query), keys (W_key), values (W_value), and the output (out_proj), plus the causal mask as an upper-triangular tensor. The projections map x from d_in to d_out dimensions in the query, key, and value spaces, and the mask prevents tokens from attending to future tokens.

Efficiency: Using one d_in-to-d_out projection each for Q, K, and V (rather than a separate small projection per head) means fewer, larger matrix multiplications, which GPUs execute more efficiently. The mask is computed once instead of being recreated for every batch.

In [5]:
# Projection layers
W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
out_proj = nn.Linear(d_out, d_out)  # For combining heads later

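# Causal mask (upper triangular: 1s above the diagonal mark future positions)
# A plain tensor is enough here; inside an nn.Module you would register it with
# register_buffer so it moves to the module's device (e.g. a GPU) along with the parameters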
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1)
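The comment in this cell points at register_buffer. As a hedged illustration, here is a tiny module (CausalMaskHolder is a hypothetical name) that stores the same mask as a buffer: it appears in state_dict(), it is not a trainable parameter, and it follows the module through .to() calls. The dtype cast below stands in for a device move so the sketch runs without a GPU.

```python
import torch
import torch.nn as nn

class CausalMaskHolder(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, context_length):
        super().__init__()
        # Buffers are saved in state_dict() and follow .to() / .cuda() calls,
        # but they are not returned by parameters(), so an optimizer ignores them
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

holder = CausalMaskHolder(context_length=3)
print(holder.mask)
print("mask" in holder.state_dict())   # True
print(len(list(holder.parameters())))  # 0

holder.to(torch.float64)               # dtype conversion stands in for a device move
print(holder.mask.dtype)               # torch.float64
```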

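Apply the linear projections to x to get keys, queries, and values, each with shape (b, num_tokens, d_out). This transforms the input into specialized representations for attention: queries look for relevant tokens, keys are matched against the queries, and values are what gets aggregated. At this point the features of all heads are still packed side by side in the last dimension (d_out). For the matrix multiplication, nn.Linear stores its weight with shape (d_out, d_in) and computes x @ W.T, so x's last dimension (d_in) must match the number of rows of W.T.

Efficiency: This takes only 3 matrix multiplications (one per projection) instead of 3 * num_heads = 9 in a naive per-head implementation. The parameter count and total FLOPs are the same either way; the gain comes from running fewer, larger operations, which matters in large LLMs.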
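To make the efficiency claim concrete, the sketch below (assuming the notebook's sizes and a fresh random seed) checks two things: nn.Linear stores its weight as (d_out, d_in) and computes x @ W.T, and one fused d_in-to-d_out projection gives the same result, with the same number of parameters, as num_heads hypothetical per-head d_in-to-head_dim layers built from slices of its weight. The per-head layers exist only for this comparison.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, num_tokens, d_in, d_out, num_heads = 2, 3, 4, 6, 3
head_dim = d_out // num_heads
x = torch.randn(b, num_tokens, d_in)

W_query = nn.Linear(d_in, d_out, bias=False)
print(W_query.weight.shape)  # torch.Size([6, 4]) -> (d_out, d_in)

# nn.Linear computes x @ W^T (no bias here)
manual = x @ W_query.weight.T
fused = W_query(x)

# Hypothetical per-head layers: head h gets rows h*head_dim:(h+1)*head_dim of the fused weight
per_head = []
for h in range(num_heads):
    layer = nn.Linear(d_in, head_dim, bias=False)
    with torch.no_grad():
        layer.weight.copy_(W_query.weight[h * head_dim:(h + 1) * head_dim])
    per_head.append(layer)

# Concatenating the per-head outputs reproduces the fused projection
stacked = torch.cat([layer(x) for layer in per_head], dim=-1)

fused_params = sum(p.numel() for p in W_query.parameters())
per_head_params = sum(p.numel() for layer in per_head for p in layer.parameters())
print(fused_params, per_head_params)  # 24 24
```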

In [6]:
# Project input to keys, queries, values
keys = W_key(x)  # Shape: (b, num_tokens, d_out)
print("Keys shape:", keys.shape)
print("Keys:\n", keys)

Keys shape: torch.Size([2, 3, 6])
Keys:
 tensor([[[-1.3128, -1.1525, -0.0044,  0.3529, -1.0175, -0.4947],
         [-0.0478,  0.2416,  0.3827, -0.2477,  0.8610, -0.1500],
         [ 0.3607,  0.4046,  0.1977, -0.3166,  0.5330, -0.0159]],

        [[-0.3900,  0.0901,  0.8316,  0.0871, -0.3479, -0.0886],
         [ 0.1938,  0.0591, -0.2975,  0.0132,  0.3347,  0.1114],
         [ 0.3892,  0.2314, -0.3749, -0.1714,  0.9555,  0.0942]]],
       grad_fn=<UnsafeViewBackward0>)


In [8]:
queries = W_query(x)
print("Queries shape:", queries.shape)
print("Queries:\n", queries)

Queries shape: torch.Size([2, 3, 6])
Queries:
 tensor([[[-0.7678, -1.3186, -0.2261, -0.1504, -1.0803, -0.5805],
         [ 0.9581,  0.6228, -0.7427, -0.2901,  0.8595, -1.0349],
         [ 0.3950,  0.3857, -0.2860,  0.1255,  0.3254, -0.5456]],

        [[ 0.6956,  1.0206, -0.2647,  0.7459,  0.4569, -0.4890],
         [ 0.0291, -0.0745,  0.0721, -0.4411,  0.2085,  0.3205],
         [ 0.3327, -0.0129, -0.2063, -0.8172,  0.5573, -0.0188]]],
       grad_fn=<UnsafeViewBackward0>)


In [9]:
values = W_value(x)
print("Values shape:", values.shape)
print("Values:\n", values)

Values shape: torch.Size([2, 3, 6])
Values:
 tensor([[[ 0.5184, -0.1331, -0.0145,  0.2668, -1.0190, -0.0328],
         [ 0.0262,  1.5439, -0.2288, -0.3577, -0.6426,  1.0325],
         [ 0.1131,  0.8719, -0.1355, -0.1839, -0.3956,  0.4536]],

        [[-0.9401,  0.3831, -0.6437,  0.0657, -0.0758,  0.5278],
         [ 0.1385, -0.1387,  0.2486, -0.0979,  0.3257, -0.1112],
         [ 0.4374,  0.5414,  0.3435, -0.3193,  0.0749,  0.2452]]],
       grad_fn=<UnsafeViewBackward0>)


Use .view() to reshape keys, values, and queries by splitting the last dimension (d_out) into num_heads and head_dim, giving shape (b, num_tokens, num_heads, head_dim). This divides the projected tensors into separate heads without copying data, so it is a cheap operation, and it lets each head compute attention independently in its own lower-dimensional space.
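A sketch, using an arange tensor so the values are easy to follow, that checks both claims: .view() shares memory with the original tensor, and head h simply owns columns h*head_dim through (h+1)*head_dim - 1 of d_out. With the notebook's keys, for example, the first token's six features [-1.3128, -1.1525, -0.0044, 0.3529, -1.0175, -0.4947] become the three head rows [-1.3128, -1.1525], [-0.0044, 0.3529], and [-1.0175, -0.4947].

```python
import torch

b, num_tokens, num_heads, head_dim = 2, 3, 3, 2
d_out = num_heads * head_dim
keys = torch.arange(b * num_tokens * d_out, dtype=torch.float32).reshape(b, num_tokens, d_out)

keys_split = keys.view(b, num_tokens, num_heads, head_dim)

# Same underlying memory: the view only reinterprets the shape
same_storage = keys_split.data_ptr() == keys.data_ptr()

# Head h holds the column block h*head_dim : (h+1)*head_dim of d_out
h = 1
matches = torch.equal(keys_split[:, :, h, :], keys[:, :, h * head_dim:(h + 1) * head_dim])
print(keys_split.shape, same_storage, matches)  # torch.Size([2, 3, 3, 2]) True True
```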

In [10]:
# Implicitly split the matrix by adding a num_heads dimension:
# (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, num_heads, head_dim)
print("Keys after view shape:", keys.shape)
print("Keys after view:\n", keys)

Keys after view shape: torch.Size([2, 3, 3, 2])
Keys after view:
 tensor([[[[-1.3128, -1.1525],
          [-0.0044,  0.3529],
          [-1.0175, -0.4947]],

         [[-0.0478,  0.2416],
          [ 0.3827, -0.2477],
          [ 0.8610, -0.1500]],

         [[ 0.3607,  0.4046],
          [ 0.1977, -0.3166],
          [ 0.5330, -0.0159]]],


        [[[-0.3900,  0.0901],
          [ 0.8316,  0.0871],
          [-0.3479, -0.0886]],

         [[ 0.1938,  0.0591],
          [-0.2975,  0.0132],
          [ 0.3347,  0.1114]],

         [[ 0.3892,  0.2314],
          [-0.3749, -0.1714],
          [ 0.9555,  0.0942]]]], grad_fn=<ViewBackward0>)


In [14]:
values = values.view(b, num_tokens, num_heads, head_dim)
print("Values after view shape:", values.shape)
print("Values after view:\n", values)

Values after view shape: torch.Size([2, 3, 3, 2])
Values after view:
 tensor([[[[ 0.5184, -0.1331],
          [-0.0145,  0.2668],
          [-1.0190, -0.0328]],

         [[ 0.0262,  1.5439],
          [-0.2288, -0.3577],
          [-0.6426,  1.0325]],

         [[ 0.1131,  0.8719],
          [-0.1355, -0.1839],
          [-0.3956,  0.4536]]],


        [[[-0.9401,  0.3831],
          [-0.6437,  0.0657],
          [-0.0758,  0.5278]],

         [[ 0.1385, -0.1387],
          [ 0.2486, -0.0979],
          [ 0.3257, -0.1112]],

         [[ 0.4374,  0.5414],
          [ 0.3435, -0.3193],
          [ 0.0749,  0.2452]]]], grad_fn=<ViewBackward0>)


In [13]:
queries = queries.view(b, num_tokens, num_heads, head_dim)
print("Queries after view shape:", queries.shape)
print("Queries after view:\n", queries)

Queries after view shape: torch.Size([2, 3, 3, 2])
Queries after view:
 tensor([[[[-0.7678, -1.3186],
          [-0.2261, -0.1504],
          [-1.0803, -0.5805]],

         [[ 0.9581,  0.6228],
          [-0.7427, -0.2901],
          [ 0.8595, -1.0349]],

         [[ 0.3950,  0.3857],
          [-0.2860,  0.1255],
          [ 0.3254, -0.5456]]],


        [[[ 0.6956,  1.0206],
          [-0.2647,  0.7459],
          [ 0.4569, -0.4890]],

         [[ 0.0291, -0.0745],
          [ 0.0721, -0.4411],
          [ 0.2085,  0.3205]],

         [[ 0.3327, -0.0129],
          [-0.2063, -0.8172],
          [ 0.5573, -0.0188]]]], grad_fn=<ViewBackward0>)


Transpose the sequence (num_tokens) and head dimensions to get shape (b, num_heads, num_tokens, head_dim). This turns num_heads into a batch-like dimension, so PyTorch can process all heads in parallel without loops. It also prepares for the attention-score multiplication, where the last dimension of queries (head_dim) must match the rows of keys.transpose(2, 3). Because num_tokens and num_heads are both 3 in this example, the printed shape stays torch.Size([2, 3, 3, 2]), but the values are rearranged.

Efficiency: transpose only swaps strides and does not copy data. In effect, it reorders the dimensions so the next step can be one batched matrix multiplication instead of num_heads separate ones, which matters for the large matrices used in LLMs.
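A short sketch of what transpose(1, 2) does to memory, using num_tokens = 4 (an assumption; the notebook uses 3) so that the shape change is visible. The result shares storage with the input, only the strides of dimensions 1 and 2 are swapped, and the tensor is no longer contiguous, which matters later when the heads are recombined with .view().

```python
import torch

b, num_tokens, num_heads, head_dim = 2, 4, 3, 2  # num_tokens = 4 so the shape change is visible
keys = torch.randn(b, num_tokens, num_heads, head_dim)

keys_t = keys.transpose(1, 2)                 # (b, num_heads, num_tokens, head_dim)

print(keys_t.shape)                           # torch.Size([2, 3, 4, 2])
print(keys_t.data_ptr() == keys.data_ptr())   # True: same memory, no copy
print(keys.stride(), keys_t.stride())         # (24, 6, 2, 1) (24, 2, 6, 1)
print(keys_t.is_contiguous())                 # False

# Head h, token t after the transpose is token t, head h before it
print(torch.equal(keys_t[1, 2, 0], keys[1, 0, 2]))  # True
```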

In [15]:
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
# Transpose to prepare for batched matrix multiplication (heads as batch dim)
keys = keys.transpose(1, 2)
print("Keys after transpose shape:", keys.shape)
print("Keys after transpose:\n", keys)

Keys after transpose shape: torch.Size([2, 3, 3, 2])
Keys after transpose:
 tensor([[[[-1.3128, -1.1525],
          [-0.0478,  0.2416],
          [ 0.3607,  0.4046]],

         [[-0.0044,  0.3529],
          [ 0.3827, -0.2477],
          [ 0.1977, -0.3166]],

         [[-1.0175, -0.4947],
          [ 0.8610, -0.1500],
          [ 0.5330, -0.0159]]],


        [[[-0.3900,  0.0901],
          [ 0.1938,  0.0591],
          [ 0.3892,  0.2314]],

         [[ 0.8316,  0.0871],
          [-0.2975,  0.0132],
          [-0.3749, -0.1714]],

         [[-0.3479, -0.0886],
          [ 0.3347,  0.1114],
          [ 0.9555,  0.0942]]]], grad_fn=<TransposeBackward0>)


In [17]:
queries = queries.transpose(1, 2)
print("Queries after transpose shape:", queries.shape)
print("Queries after transpose:\n", queries)

Queries after transpose shape: torch.Size([2, 3, 3, 2])
Queries after transpose:
 tensor([[[[-0.7678, -1.3186],
          [-0.2261, -0.1504],
          [-1.0803, -0.5805]],

         [[ 0.9581,  0.6228],
          [-0.7427, -0.2901],
          [ 0.8595, -1.0349]],

         [[ 0.3950,  0.3857],
          [-0.2860,  0.1255],
          [ 0.3254, -0.5456]]],


        [[[ 0.6956,  1.0206],
          [-0.2647,  0.7459],
          [ 0.4569, -0.4890]],

         [[ 0.0291, -0.0745],
          [ 0.0721, -0.4411],
          [ 0.2085,  0.3205]],

         [[ 0.3327, -0.0129],
          [-0.2063, -0.8172],
          [ 0.5573, -0.0188]]]], grad_fn=<TransposeBackward0>)


In [18]:
values = values.transpose(1, 2)
print("Values after transpose shape:", values.shape)
print("Values after transpose:\n", values)

Values after transpose shape: torch.Size([2, 3, 3, 2])
Values after transpose:
 tensor([[[[ 0.5184, -0.1331],
          [ 0.0262,  1.5439],
          [ 0.1131,  0.8719]],

         [[-0.0145,  0.2668],
          [-0.2288, -0.3577],
          [-0.1355, -0.1839]],

         [[-1.0190, -0.0328],
          [-0.6426,  1.0325],
          [-0.3956,  0.4536]]],


        [[[-0.9401,  0.3831],
          [ 0.1385, -0.1387],
          [ 0.4374,  0.5414]],

         [[-0.6437,  0.0657],
          [ 0.2486, -0.0979],
          [ 0.3435, -0.3193]],

         [[-0.0758,  0.5278],
          [ 0.3257, -0.1112],
          [ 0.0749,  0.2452]]]], grad_fn=<TransposeBackward0>)


Calculate attn_scores as the dot product between queries and transposed keys, queries @ keys.transpose(2, 3), with shape (b, num_heads, num_tokens, num_tokens). Each entry measures the similarity between a query token and a key token within one head. transpose(2, 3) swaps the last two dimensions of keys to form K^T; the scaling by sqrt(head_dim) is applied later, right before the softmax.

Efficiency: Thanks to the earlier reshape and transpose, this is one batched multiplication across all heads, with no Python loop. A naive approach needs num_heads separate multiplications, which adds overhead for LLMs with 32 or more heads.
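As a sanity check on the batched multiplication, this sketch (random queries and keys already in the (b, num_heads, num_tokens, head_dim) layout) compares it against the naive reference that loops over every batch element and head. Entry [i, h, q, k] of the result is the dot product of query q and key k in head h of batch element i.

```python
import torch

torch.manual_seed(0)
b, num_heads, num_tokens, head_dim = 2, 3, 3, 2
queries = torch.randn(b, num_heads, num_tokens, head_dim)
keys = torch.randn(b, num_heads, num_tokens, head_dim)

# One batched matmul over all (batch, head) pairs
batched = queries @ keys.transpose(2, 3)  # (b, num_heads, num_tokens, num_tokens)

# Naive reference: one small matmul per (batch, head) pair
looped = torch.empty(b, num_heads, num_tokens, num_tokens)
for i in range(b):
    for h in range(num_heads):
        looped[i, h] = queries[i, h] @ keys[i, h].T

print(torch.allclose(batched, looped, atol=1e-6))  # True
```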

In [19]:
# Compute scaled dot-product attention (aka self-attention) with a causal mask
# Compute raw attention scores (dot-product per head)
attn_scores = queries @ keys.transpose(2, 3)  # Batched matmul
print("Attention scores shape:", attn_scores.shape)
print("Attention scores:\n", attn_scores)

Attention scores shape: torch.Size([2, 3, 3, 3])
Attention scores:
 tensor([[[[ 2.5277e+00, -2.8185e-01, -8.1044e-01],
          [ 4.7021e-01, -2.5526e-02, -1.4241e-01],
          [ 2.0873e+00, -8.8567e-02, -6.2450e-01]],

         [[ 2.1561e-01,  2.1245e-01, -7.7876e-03],
          [-9.9141e-02, -2.1240e-01, -5.4949e-02],
          [-3.6900e-01,  5.8528e-01,  4.9750e-01]],

         [[-5.9269e-01,  2.8222e-01,  2.0436e-01],
          [ 2.2891e-01, -2.6508e-01, -1.5443e-01],
          [-6.1238e-02,  3.6205e-01,  1.8214e-01]]],


        [[[-1.7929e-01,  1.9518e-01,  5.0693e-01],
          [ 1.7044e-01, -7.2047e-03,  6.9587e-02],
          [-2.2228e-01,  5.9661e-02,  6.4687e-02]],

         [[ 1.7685e-02, -9.6325e-03,  1.8700e-03],
          [ 2.1516e-02, -2.7273e-02,  4.8576e-02],
          [ 2.0132e-01, -5.7797e-02, -1.3308e-01]],

         [[-1.1459e-01,  1.0990e-01,  3.1666e-01],
          [ 1.4421e-01, -1.6009e-01, -2.7410e-01],
          [-1.9220e-01,  1.8442e-01,  5.3070e-01]]]],

Truncate the mask to the current sequence length, convert it to boolean, and fill the masked positions of attn_scores with -inf. This prevents tokens from attending to future tokens; without the mask, the model could "cheat" by seeing ahead. Since exp(-inf) = 0, the softmax assigns these positions exactly zero weight. The (num_tokens, num_tokens) mask broadcasts across the batch and head dimensions.

Efficiency: The in-place fill (masked_fill_) avoids allocating a new tensor, and the precomputed mask is not recalculated for every batch, which saves time in LLMs with repeated forward passes.
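A minimal sketch, with random scores in the notebook's (b, num_heads, num_tokens, num_tokens) shape, showing that the (num_tokens, num_tokens) boolean mask broadcasts over the batch and head dimensions and that -inf entries become exact zeros after the softmax. It uses the out-of-place masked_fill so the original scores stay intact; the notebook's in-place masked_fill_ produces the same values.

```python
import torch

torch.manual_seed(0)
b, num_heads, num_tokens = 2, 3, 3
attn_scores = torch.randn(b, num_heads, num_tokens, num_tokens)

# Same construction as in the notebook: True above the diagonal marks future positions
mask_bool = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()

# The (num_tokens, num_tokens) mask broadcasts over the batch and head dimensions
masked = attn_scores.masked_fill(mask_bool, -torch.inf)
weights = torch.softmax(masked, dim=-1)

# exp(-inf) = 0, so every future position receives exactly zero weight
print(torch.all(weights.masked_select(mask_bool) == 0))  # tensor(True)
# The first token can only attend to itself, so its weight is exactly 1 in every head
print(weights[..., 0, 0])
```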

In [20]:
# Truncate the causal mask to the number of tokens and convert it to boolean
mask_bool = mask.bool()[:num_tokens, :num_tokens]
print("Mask shape:", mask_bool.shape)
print("Mask:\n", mask_bool)

Mask shape: torch.Size([3, 3])
Mask:
 tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])


In [22]:
# Fill scores with -inf where masked
attn_scores.masked_fill_(mask_bool, -torch.inf)
print("Masked attention scores shape:", attn_scores.shape)
print("Masked attention scores:\n", attn_scores)

Masked attention scores shape: torch.Size([2, 3, 3, 3])
Masked attention scores:
 tensor([[[[ 2.5277,    -inf,    -inf],
          [ 0.4702, -0.0255,    -inf],
          [ 2.0873, -0.0886, -0.6245]],

         [[ 0.2156,    -inf,    -inf],
          [-0.0991, -0.2124,    -inf],
          [-0.3690,  0.5853,  0.4975]],

         [[-0.5927,    -inf,    -inf],
          [ 0.2289, -0.2651,    -inf],
          [-0.0612,  0.3620,  0.1821]]],


        [[[-0.1793,    -inf,    -inf],
          [ 0.1704, -0.0072,    -inf],
          [-0.2223,  0.0597,  0.0647]],

         [[ 0.0177,    -inf,    -inf],
          [ 0.0215, -0.0273,    -inf],
          [ 0.2013, -0.0578, -0.1331]],

         [[-0.1146,    -inf,    -inf],
          [ 0.1442, -0.1601,    -inf],
          [-0.1922,  0.1844,  0.5307]]]], grad_fn=<MaskedFillBackward0>)


Compute attn_weights by dividing attn_scores by sqrt(head_dim), i.e. keys.shape[-1] ** 0.5, and applying softmax along the last dimension, which turns the scores into weights that sum to 1 in each row. Then apply the pre-initialized dropout module: during training, dropout randomly zeroes weights and rescales the remaining ones by 1 / (1 - p) as regularization; with the rate of 0.0 used here, it leaves the weights unchanged.

Efficiency: Scaling, softmax, and dropout each run once over the whole (b, num_heads, num_tokens, num_tokens) tensor instead of once per head as in a naive loop, and the dropout module is reused rather than recreated.
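Why divide by sqrt(head_dim)? If query and key entries are independent with mean 0 and variance 1 (an assumption made here for illustration), the dot product of two head_dim-dimensional vectors has variance head_dim, so larger heads produce larger scores and a more saturated softmax. The sketch estimates that variance empirically and shows that the scaling brings it back to about 1. In this notebook keys.shape[-1] is head_dim = 2, so the scores are divided by sqrt(2) ≈ 1.414.

```python
import torch

torch.manual_seed(0)
n_samples = 20_000
results = {}

for head_dim in (2, 64):
    # Assumption for illustration: i.i.d. query/key entries with mean 0, variance 1
    q = torch.randn(n_samples, head_dim)
    k = torch.randn(n_samples, head_dim)
    dots = (q * k).sum(dim=-1)       # one raw attention score per sample
    scaled = dots / head_dim**0.5    # the scaling applied before the softmax
    results[head_dim] = (dots.var().item(), scaled.var().item())
    print(f"head_dim={head_dim}: raw var ~ {results[head_dim][0]:.2f}, "
          f"scaled var ~ {results[head_dim][1]:.2f}")
```

The raw variance grows roughly like head_dim, while the scaled variance stays near 1 regardless of head size.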

In [23]:
# Compute attention weights (scaled softmax)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print("Attention weights after softmax shape:", attn_weights.shape)
print("Attention weights after softmax:\n", attn_weights)

Attention weights after softmax shape: torch.Size([2, 3, 3, 3])
Attention weights after softmax:
 tensor([[[[1.0000, 0.0000, 0.0000],
          [0.5867, 0.4133, 0.0000],
          [0.7344, 0.1577, 0.1079]],

         [[1.0000, 0.0000, 0.0000],
          [0.5200, 0.4800, 0.0000],
          [0.2079, 0.4083, 0.3837]],

         [[1.0000, 0.0000, 0.0000],
          [0.5864, 0.4136, 0.0000],
          [0.2827, 0.3814, 0.3358]]],


        [[[1.0000, 0.0000, 0.0000],
          [0.5314, 0.4686, 0.0000],
          [0.2902, 0.3543, 0.3555]],

         [[1.0000, 0.0000, 0.0000],
          [0.5086, 0.4914, 0.0000],
          [0.3814, 0.3175, 0.3011]],

         [[1.0000, 0.0000, 0.0000],
          [0.5536, 0.4464, 0.0000],
          [0.2517, 0.3286, 0.4197]]]], grad_fn=<SoftmaxBackward0>)


In [None]:
# Apply dropout using pre-initialized module (rate=0.0)
attn_weights = dropout(attn_weights)
print("After dropout shape:", attn_weights.shape)
print("After dropout:\n", attn_weights)

After dropout shape: torch.Size([2, 3, 3, 3])
After dropout:
 tensor([[[[1.0000, 0.0000, 0.0000],
          [0.5867, 0.4133, 0.0000],
          [0.7344, 0.1577, 0.1079]],

         [[1.0000, 0.0000, 0.0000],
          [0.5200, 0.4800, 0.0000],
          [0.2079, 0.4083, 0.3837]],

         [[1.0000, 0.0000, 0.0000],
          [0.5864, 0.4136, 0.0000],
          [0.2827, 0.3814, 0.3358]]],


        [[[1.0000, 0.0000, 0.0000],
          [0.5314, 0.4686, 0.0000],
          [0.2902, 0.3543, 0.3555]],

         [[1.0000, 0.0000, 0.0000],
          [0.5086, 0.4914, 0.0000],
          [0.3814, 0.3175, 0.3011]],

         [[1.0000, 0.0000, 0.0000],
          [0.5536, 0.4464, 0.0000],
          [0.2517, 0.3286, 0.4197]]]], grad_fn=<SoftmaxBackward0>)


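Multiply attn_weights by values to compute, for each head, a weighted sum of the value vectors (shape (b, num_heads, num_tokens, head_dim)), then transpose dimensions 1 and 2 back to get (b, num_tokens, num_heads, head_dim). The transpose places each token's per-head outputs next to each other, ready for combining the heads.

Efficiency: One batched multiplication covers all heads; a naive implementation needs num_heads separate ones.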
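The sketch below, with random weights and values, checks that the batched product followed by transpose(1, 2) gives, at position [i, t, h], the weighted sum of head h's value vectors using token t's attention weights. The same idea explains the notebook's output for this step: the first token's weight row is [1, 0, 0] in every head, so its context vector equals that head's first value vector (for example [0.5184, -0.1331] for batch 0, head 0).

```python
import torch

torch.manual_seed(0)
b, num_heads, num_tokens, head_dim = 2, 3, 3, 2
attn_weights = torch.softmax(torch.randn(b, num_heads, num_tokens, num_tokens), dim=-1)
values = torch.randn(b, num_heads, num_tokens, head_dim)

# Batched: (b, h, T, T) @ (b, h, T, head_dim) -> (b, h, T, head_dim), then swap h and T
context_vec = (attn_weights @ values).transpose(1, 2)  # (b, T, h, head_dim)

# Explicit weighted sum for one (batch, head, query token) triple
i, h, t = 1, 2, 2
manual = sum(attn_weights[i, h, t, s] * values[i, h, s] for s in range(num_tokens))
print(torch.allclose(context_vec[i, t, h], manual, atol=1e-6))  # True
```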

In [26]:
# Shape: (b, num_tokens, num_heads, head_dim)
# Weighted sum of values and transpose back
context_vec = (attn_weights @ values).transpose(1, 2)
print("Context vec after matrix multiplication and transpose shape:", context_vec.shape)
print("Context vec after matrix multiplication and transpose:\n", context_vec)

Context vec after matrix multiplication and transpose shape: torch.Size([2, 3, 3, 2])
Context vec after matrix multiplication and transpose:
 tensor([[[[ 0.5184, -0.1331],
          [-0.0145,  0.2668],
          [-1.0190, -0.0328]],

         [[ 0.3150,  0.5600],
          [-0.1174, -0.0329],
          [-0.8633,  0.4078]],

         [[ 0.3971,  0.2398],
          [-0.1484, -0.1611],
          [-0.6661,  0.5369]]],


        [[[-0.9401,  0.3831],
          [-0.6437,  0.0657],
          [-0.0758,  0.5278]],

         [[-0.4346,  0.1385],
          [-0.2052, -0.0147],
          [ 0.1035,  0.2425]],

         [[-0.0683,  0.2545],
          [-0.0631, -0.1022],
          [ 0.1194,  0.1993]]]], grad_fn=<TransposeBackward0>)


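Combine the heads: .contiguous() copies the transposed tensor into a contiguous memory layout, which .view() requires, and .view() then flattens num_heads and head_dim back into d_out, concatenating the heads into shape (b, num_tokens, d_out). Finally, apply out_proj, which mixes information across the heads for extra expressivity.

Efficiency: .contiguous() performs a single memory copy (no multiplication) and .view() itself is free; the output projection adds one final matrix multiplication.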
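A sketch of why .contiguous() is needed before .view(): after the transpose, the strides no longer allow num_heads and head_dim to be merged, so .view() raises a RuntimeError. .reshape() is an alternative that copies only when necessary. The last check confirms that flattening the heads is the same as concatenating each head's output along the feature dimension.

```python
import torch

b, num_tokens, num_heads, head_dim = 2, 3, 3, 2
d_out = num_heads * head_dim
per_head = torch.randn(b, num_heads, num_tokens, head_dim)  # like attn_weights @ values
context_vec = per_head.transpose(1, 2)                       # (b, T, h, head_dim), non-contiguous

try:
    context_vec.view(b, num_tokens, d_out)
    view_failed = False
except RuntimeError:
    view_failed = True  # the strides are incompatible with the requested shape

combined = context_vec.contiguous().view(b, num_tokens, d_out)
combined_reshape = context_vec.reshape(b, num_tokens, d_out)  # copies only when needed

# Flattening heads equals concatenating each head's (b, T, head_dim) output
concat = torch.cat([per_head[:, h] for h in range(num_heads)], dim=-1)
print(view_failed, torch.equal(combined, concat))  # True True
```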

In [27]:
# Combine heads (flatten back to d_out), where d_out = num_heads * head_dim
context_vec = context_vec.contiguous().view(b, num_tokens, d_out)
print("Context vec after view shape:", context_vec.shape)
print("Context vec after view:\n", context_vec)

Context vec after view shape: torch.Size([2, 3, 6])
Context vec after view:
 tensor([[[ 0.5184, -0.1331, -0.0145,  0.2668, -1.0190, -0.0328],
         [ 0.3150,  0.5600, -0.1174, -0.0329, -0.8633,  0.4078],
         [ 0.3971,  0.2398, -0.1484, -0.1611, -0.6661,  0.5369]],

        [[-0.9401,  0.3831, -0.6437,  0.0657, -0.0758,  0.5278],
         [-0.4346,  0.1385, -0.2052, -0.0147,  0.1035,  0.2425],
         [-0.0683,  0.2545, -0.0631, -0.1022,  0.1194,  0.1993]]],
       grad_fn=<ViewBackward0>)


In [28]:
# Optional output projection
context_vec = out_proj(context_vec)
print("Final output shape:", context_vec.shape)
print("Final output:\n", context_vec)

Final output shape: torch.Size([2, 3, 6])
Final output:
 tensor([[[ 0.3552,  0.2756, -0.2894,  0.2334, -0.6752, -0.3562],
         [ 0.4371,  0.2421, -0.0093,  0.6206, -0.9919, -0.2705],
         [ 0.4519,  0.0296,  0.0292,  0.5685, -0.9382, -0.3079]],

        [[ 0.9782,  0.5996,  0.2588,  0.4722, -0.5337, -0.0014],
         [ 0.6233,  0.2310,  0.1319,  0.2327, -0.3441,  0.0383],
         [ 0.4272,  0.0567,  0.1532,  0.2812, -0.4479,  0.0405]]],
       grad_fn=<ViewBackward0>)
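The notebook's opening mentions a class, and the cells above can be packaged into a single nn.Module. The class below is a hedged sketch of that packaging: the name MultiHeadAttention, the constructor signature (d_in, d_out, context_length, dropout, num_heads, qkv_bias), and the split_heads helper are assumptions for illustration, not necessarily the author's class. It registers the mask as a buffer and, as a sanity check, compares its output with PyTorch's torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+), which uses the same 1/sqrt(head_dim) scaling and a causal mask when is_causal=True.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """Sketch: the notebook's steps packaged as one causal multi-head attention module."""

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Buffer: follows the module across devices, but is not trained
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def split_heads(self, t):
        # Hypothetical helper: (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim)
        b, num_tokens, _ = t.shape
        return t.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        b, num_tokens, _ = x.shape
        queries = self.split_heads(self.W_query(x))
        keys = self.split_heads(self.W_key(x))
        values = self.split_heads(self.W_value(x))

        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)


torch.manual_seed(0)
mha = MultiHeadAttention(d_in=4, d_out=6, context_length=3, dropout=0.0, num_heads=3)
x = torch.randn(2, 3, 4)
out = mha(x)
print(out.shape)  # torch.Size([2, 3, 6])

# Cross-check the attention core against PyTorch's fused implementation
with torch.no_grad():
    q = mha.split_heads(mha.W_query(x))
    k = mha.split_heads(mha.W_key(x))
    v = mha.split_heads(mha.W_value(x))
    ref_ctx = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    ref = mha.out_proj(ref_ctx.transpose(1, 2).reshape(2, 3, 6))
print(torch.allclose(out, ref, atol=1e-5))  # True
```

Registering the mask as a buffer sized to context_length and slicing it to num_tokens in forward lets one module handle any input length up to context_length.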
