Understanding how positional embeddings work

In [2]:
import torch
import torch.nn as nn

# Walk-through of learned position embeddings: build the table, look up
# positions for several sequence lengths, show the out-of-range failure, and
# finally combine token and position embeddings the way a forward pass does.

# Set up the same parameters as in the original code
block_size = 8   # number of rows in the position table (positions 0 .. block_size-1)
n_embd = 32      # width of every embedding vector

# Create the position embedding table: one n_embd-wide row per position
position_embedding_table = nn.Embedding(block_size, n_embd)

print("=== Understanding Position Embeddings ===")
print(f"Embedding table shape: {position_embedding_table.weight.shape}")
print(f"This can handle positions 0 to {block_size-1}")

# This is what the code should do
T = 5  # sequence length (example)
device = 'cpu'

# Create position indices for sequence of length T
position_indices = torch.arange(T, device=device)
print(f"\nPosition indices for sequence length {T}: {position_indices}")

# Get position embeddings: one table row per index -> shape (T, n_embd)
pos_emb = position_embedding_table(position_indices)
print(f"Position embeddings shape: {pos_emb.shape}")
print(f"Each position gets a {n_embd}-dimensional embedding vector")

# Show what happens with different sequence lengths.
# The longest case is derived from block_size so it always exercises the
# largest sequence the table can serve, even if block_size is changed above.
print("\n=== Testing different sequence lengths ===")
for seq_len in [1, 3, block_size]:
    indices = torch.arange(seq_len, device=device)
    embeddings = position_embedding_table(indices)
    print(f"Sequence length {seq_len}: indices={indices.tolist()}, embedding_shape={embeddings.shape}")

# This would cause an error (index out of range).
# Two indices past the end of the table (block_size and block_size + 1), so the
# demonstration keeps failing as intended whatever block_size is set to.
print("\n=== What goes wrong ===")
bad_indices = torch.arange(block_size + 2)
try:
    bad_embeddings = position_embedding_table(bad_indices)
except IndexError as e:
    print(f"Error with indices {bad_indices.tolist()}: {e}")

# Demonstrate the full forward pass concept
print("\n=== How it works in the forward pass ===")
batch_size = 2
seq_length = 6
vocab_size = 100

# Simulate input: a (B, T) batch of random token ids in [0, vocab_size)
idx = torch.randint(0, vocab_size, (batch_size, seq_length))
print(f"Input shape (B, T): {idx.shape}")

# Token embeddings: (B, T) ids -> (B, T, n_embd) vectors
token_embedding_table = nn.Embedding(vocab_size, n_embd)
tok_emb = token_embedding_table(idx)
print(f"Token embeddings shape: {tok_emb.shape}")

# Position embeddings (corrected version)
B, T = idx.shape  # Extract T from input shape
# Guard the constraint shown in "What goes wrong": the table only has
# block_size rows, so longer sequences cannot be embedded.
assert T <= block_size, f"sequence length {T} exceeds block_size {block_size}"
pos_indices = torch.arange(T, device=device)
pos_emb = position_embedding_table(pos_indices)
print(f"Position embeddings shape: {pos_emb.shape}")

# Broadcasting addition
x = tok_emb + pos_emb  # Broadcasting: (B,T,n_embd) + (T,n_embd) -> (B,T,n_embd)
print(f"Combined embeddings shape: {x.shape}")

print("\nBroadcasting works because:")
print(f"  tok_emb: {tok_emb.shape} (token info for each position in each batch)")
print(f"  pos_emb: {pos_emb.shape} (position info, same for all batches)")
print(f"  Result:  {x.shape} (each token gets both token and position info)")

=== Understanding Position Embeddings ===
Embedding table shape: torch.Size([8, 32])
This can handle positions 0 to 7

Position indices for sequence length 5: tensor([0, 1, 2, 3, 4])
Position embeddings shape: torch.Size([5, 32])
Each position gets a 32-dimensional embedding vector

=== Testing different sequence lengths ===
Sequence length 1: indices=[0], embedding_shape=torch.Size([1, 32])
Sequence length 3: indices=[0, 1, 2], embedding_shape=torch.Size([3, 32])
Sequence length 8: indices=[0, 1, 2, 3, 4, 5, 6, 7], embedding_shape=torch.Size([8, 32])

=== What goes wrong ===
Error with indices [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]: index out of range in self

=== How it works in the forward pass ===
Input shape (B, T): torch.Size([2, 6])
Token embeddings shape: torch.Size([2, 6, 32])
Position embeddings shape: torch.Size([6, 32])
Combined embeddings shape: torch.Size([2, 6, 32])

Broadcasting works because:
  tok_emb: torch.Size([2, 6, 32]) (token info for each position in each batch)
  pos_

<span style="color:#FF0000; font-family: 'Bebas Neue'; font-size: 1em;">PyTorch nuance:</span><br>

So an `nn.Embedding` module can be called with a tensor of indices, but a regular tensor cannot be called at all!

In [9]:
# Two objects holding a (block_size, n_embd) matrix of numbers:
# a learnable embedding module ...
x_embd = nn.Embedding(num_embeddings=block_size, embedding_dim=n_embd)

# ... and a plain tensor filled with standard-normal samples.
x_regular = torch.randn(block_size, n_embd)

In [10]:
# Call the embedding module with every valid index, then show only the row
# returned for index 0.
row_ids = torch.arange(block_size)
print(x_embd(row_ids)[0])

tensor([-0.7521,  1.8812,  0.8526,  0.5902,  1.5629,  0.9514, -0.7993,  1.8407,
        -1.2865, -0.2323,  0.6199, -1.2639,  0.0909, -0.1176, -0.5648, -0.1640,
        -0.8622,  0.6527, -0.0118,  0.1396,  0.1709,  1.1068, -1.0915, -0.9353,
         0.7777,  1.2527, -0.8601, -1.0533, -0.9846, -0.0555, -0.3534,  0.0202],
       grad_fn=<SelectBackward0>)


In [11]:
# A plain tensor is not callable, so "calling" it with indices raises TypeError.
# The exception is the point of this cell, so catch and print it instead of
# letting it abort a Restart & Run All. (Rows of a plain tensor are selected by
# indexing, e.g. x_regular[indices], not by calling.)
try:
    print(x_regular(torch.arange(block_size)))
except TypeError as e:
    print(f"TypeError: {e}")

TypeError: 'Tensor' object is not callable

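So an `nn.Embedding` object is callable, while a regular tensor is not! To look up rows of a regular tensor, index it instead, e.g. `x_regular[torch.arange(block_size)]`.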