In [16]:
# Examples for understanding attention input structure in transformer models

import torch

# Set manual seed for reproducibility of random values
torch.manual_seed(1337)

# ----------------------------
# Define input tensor dimensions
# ----------------------------

B = 4   # Batch size: number of independent sequences processed in parallel
T = 8   # Sequence length (Time steps): number of tokens in each sequence
C = 2   # Number of channels (features per token): e.g., embedding dimension per token

# ----------------------------
# Create a random input tensor
# ----------------------------

# Simulate input data typically seen by attention layers: (Batch, Time, Channels)
# Each token in each sequence is represented as a vector of size C
x = torch.randn(B, T, C)

# Display the shape of the input tensor
# Expected: (4 sequences, each of length 8 tokens, with 2 features per token)
x.shape


torch.Size([4, 8, 2])

In [17]:
# Initialize an empty tensor to store "bag-of-words" representations
# Shape: (B, T, C) — for each batch and time step, we will compute a mean over the previous tokens
xbow = torch.zeros((B, T, C))

# Iterate over the batch dimension
for b in range(B):
    
    # Iterate over the sequence (time) dimension
    for t in range(T):
        
        # Extract all previous tokens (from position 0 to t, inclusive) for the current sequence
        # This produces a sub-tensor of shape (t+1, C) — a sequence of token vectors
        xprev = x[b, :t+1]
        
        # Compute the mean vector over the time dimension (averaging all previous token embeddings)
        # Result is a single vector of shape (C,) — representing the average "context" up to time t
        xbow[b, t] = torch.mean(xprev, dim=0)


In [18]:
# Inspect the first sequence in the batch (batch index 0)
# This returns a tensor of shape (T, C), representing all tokens in that sequence
# While useful for debugging, it doesn't reveal batch variation since it shows only one example
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [19]:
# Examine the first sequence in the output tensor (xbow[0])
# Note: The first two vectors may appear similar or identical because they are averaging only 1 or 2 tokens respectively.
# 
# For example, the vector at position t=2:
#     [ 0.3504, -0.2238]
# is the result of averaging the first three token vectors in the original input x[0][:3], which were:
#     [[ 0.1808, -0.0700],
#      [-0.3596, -0.9152],
#      [ 0.6258,  0.0255]]
#
# This demonstrates that xbow is a causal aggregation of previous embeddings — each position t
# contains the average of all token vectors from position 0 up to t (inclusive).
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [20]:
# Version 2: Use matrix multiplication to compute causal (left-to-right) weighted averages efficiently

# Create a lower-triangular matrix of shape (T, T) to simulate a causal attention mask
# Each row 't' allows attention only to tokens 0 through t (inclusive)
wei = torch.tril(torch.ones(T, T))  # Shape: (T, T)

# Normalize each row so that the weights sum to 1 (turning the mask into an averaging filter)
wei = wei / wei.sum(dim=1, keepdim=True)  # Still (T, T), now each row is a distribution over time steps

# Apply the weight matrix to the input x using batch matrix multiplication
# x has shape (B, T, C): batch of sequences with T tokens and C features each
# wei has shape (T, T) and is broadcasted over the batch
# Result xbow2 will have shape (B, T, C): each token is now the mean of all previous tokens
xbow2 = wei @ x  # Efficient causal aggregation using matrix multiply

# Verify that this result matches the previous (loop-based) implementation
# This should return True if both approaches produce the same output
torch.allclose(xbow, xbow2)


False

In [22]:
xbow[0], xbow2[0]

(tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))

In [24]:
import torch.nn.functional as F

# Version 3: Use softmax-based attention weights (learned-like behavior)

# Create a lower-triangular mask (T x T) to enforce causality
# This ensures each position only attends to itself and previous tokens
tril = torch.tril(torch.ones(T, T))

# Initialize an attention weight matrix with zeros
wei = torch.zeros((T, T))

# Apply causal masking:
# All elements above the diagonal are set to -inf so softmax will assign them zero probability
wei = wei.masked_fill(tril == 0, float('-inf'))

# Apply softmax over the last dimension (each row) to convert logits into a probability distribution
# Now each row sums to 1, but unlike uniform averaging, weights are non-uniform and adaptive
wei = F.softmax(wei, dim=-1)

# Apply the attention weights to the input x via matrix multiplication
# x: shape (B, T, C)
# wei: shape (T, T), broadcast over batch
# Output xbow3: each token is a weighted combination of all previous tokens (causally)
xbow3 = wei @ x

# Check whether this result still matches xbow (uniform averaging from version 1)
# It likely won't match exactly anymore, because softmax weights differ from uniform weights
torch.allclose(xbow, xbow3)


False

In [13]:
# Demonstration: Efficiently compute causal (left-to-right) averages using matrix multiplication

# Set manual seed for reproducibility
torch.manual_seed(42)

# Create a lower-triangular 3x3 matrix filled with 1s
# This simulates a causal mask: each position sees only itself and previous positions
a = torch.tril(torch.ones(3, 3))  # Shape: (3, 3)

# Normalize each row to turn the matrix into a causal averaging kernel
# Now each row i averages over the first i+1 elements
a = a / torch.sum(a, dim=1, keepdim=True)  # Still shape (3, 3)

# Create a matrix `b` of 3 token embeddings, each with 2 features (random integers from 0 to 9)
b = torch.randint(0, 10, (3, 2)).float()  # Shape: (3, 2)

# Apply matrix multiplication to compute causal averages of embeddings
# 'a' (3x3) is the averaging matrix
# 'b' (3x2) contains token vectors
# Result 'c' (3x2) is the causal average of embeddings:
#     - Row 0 of `c` is just token 0
#     - Row 1 of `c` is the average of tokens 0 and 1
#     - Row 2 of `c` is the average of tokens 0, 1, and 2
c = a @ b  # Shape: (3, 2)

# Print all matrices for inspection
print('a = (causal averaging weights)')
print(a)
print('--')
print('b = (original token embeddings)')
print(b)
print('--')
print('c = (causal averaged embeddings via matrix multiplication)')
print(c)


a = (causal averaging weights)
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b = (original token embeddings)
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c = (causal averaged embeddings via matrix multiplication)
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [25]:
# Version 4: Implement a basic self-attention head with learnable parameters
import torch
import torch.nn as nn
import torch.nn.functional as F

# Set random seed for reproducibility
torch.manual_seed(1337)

# Define tensor dimensions
B, T, C = 4, 8, 32  # B: batch size, T: sequence length, C: embedding dimension per token

# Create a random input tensor representing a batch of token sequences
# Shape: (B, T, C)
x = torch.randn(B, T, C)

In [26]:
# Define the size of the attention head (i.e., dimensionality of query/key/value projections)
head_size = 16

# Create three linear layers to project the input embeddings into query, key, and value vectors
# These are learnable transformations
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# Project the input x into keys and queries
# Both will have shape (B, T, head_size)
k = key(x)    # Key vectors per token
q = query(x)  # Query vectors per token


In [27]:
# Compute attention scores via scaled dot-product attention
# Matrix multiply each query with all keys (transposed), for each token
# Shape: (B, T, head_size) @ (B, head_size, T) → (B, T, T)
# Each [b, t1, t2] element represents how much token t1 attends to token t2
wei = q @ k.transpose(-2, -1)

# Create a lower-triangular mask to enforce causality (autoregressive attention)
# Shape: (T, T) — broadcasted across batch
tril = torch.tril(torch.ones(T, T))

# Mask out future positions by setting them to -inf (so softmax will zero them out)
wei = wei.masked_fill(tril == 0, float('-inf'))

# Apply softmax to turn attention scores into attention weights (probabilities)
# Each row now sums to 1 — represents how much to attend to each token in the past
wei = F.softmax(wei, dim=-1)


In [28]:
# Project the input into value vectors (same shape as q and k: B, T, head_size)
v = value(x)

# Perform the attention-weighted aggregation of value vectors
# Shape: (B, T, T) @ (B, T, head_size) → (B, T, head_size)
# Each output vector is a weighted sum of the value vectors from all previous tokens
out = wei @ v

# Alternative (commented): if you used `out = wei @ x`, you’d be directly mixing raw input embeddings
# That would not learn useful structure compared to projecting into values
# out = wei @ x


In [29]:
# Check the shape of the output
# Expected: (B, T, head_size) → one new vector per token, per sequence
out.shape


torch.Size([4, 8, 16])

In [30]:
# Inspect the attention weight matrix for the first example in the batch (batch index 0)
# Shape of wei: (B, T, T), where:
#   B = batch size
#   T = sequence length
#
# Each wei[b] is a (T, T) matrix where:
#   - Row i contains the attention weights (probabilities) used to compute the output for token i
#   - Column j in row i represents how much token i attends to token j
#
# By printing wei[0], we're examining the full attention pattern for the first sequence in the batch,
# which helps us visualize and debug how each token in the sequence attends to its past context
wei[0]


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)