# My Learning Notes: Embeddings vs Linear Layers

**What I'm exploring**: Understanding why PyTorch uses embedding layers instead of linear layers for token embeddings

**My goal**: Build intuition for the computational efficiency difference between these two approaches

**Reference**: Based on concepts from "Build a Large Language Model From Scratch" by Sebastian Raschka

## My Understanding

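I'm learning that an embedding layer computes exactly what a bias-free linear layer computes when that linear layer is fed one-hot encoded token IDs. The difference is that the embedding layer gets there with a cheap row lookup instead of a full matrix multiplication. I want to see this relationship step by step using code examples.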
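To pin down what "the same thing mathematically" means, here is a short derivation. The symbols are my own labels: $E \in \mathbb{R}^{V \times d}$ is the embedding weight matrix ($V$ = vocabulary size, $d$ = embedding dimension), and $\mathbf{e}_i \in \{0,1\}^{V}$ is the one-hot vector for token ID $i$ (IDs counted from 0):

$$
\mathbf{e}_i^T E = \sum_{k=0}^{V-1} (\mathbf{e}_i)_k \, E_{k,:} = E_{i,:}
$$

For $n$ tokens stacked into a one-hot matrix $X \in \{0,1\}^{n \times V}$, $XE$ stacks the corresponding rows of $E$. Since `nn.Linear` (without bias) computes $X W^T$ with $W \in \mathbb{R}^{d \times V}$, choosing $W = E^T$ gives $X W^T = X E$: the same rows a lookup would return.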

```python
import torch

print("PyTorch version:", torch.__version__)
```

## Part 1: Using nn.Embedding (The Efficient Way)

I'm starting with `nn.Embedding` because this is what I'll use in practice. Then I'll compare it to the linear layer approach to understand why embedding layers are preferred.

```python
# I'm setting up a simple example with 3 token IDs
# These could represent tokens in an LLM context
token_ids = torch.tensor([2, 3, 1])

# I need to determine the vocabulary size
# If the highest token ID is 3, I need 4 rows (for IDs 0, 1, 2, 3)
# (This shortcut only works for the toy example; a real LLM takes it from the tokenizer)
# .item() turns the 0-d tensor returned by .max() into a plain Python int
vocab_size = token_ids.max().item() + 1

# I'm choosing an embedding dimension (this is a hyperparameter)
embedding_dim = 5

print(f"Token IDs: {token_ids}")
print(f"Vocab size: {vocab_size}")
print(f"Embedding dimension: {embedding_dim}")
```
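To see why the table needs exactly `max ID + 1` rows, here is a small check I put together (the toy values mirror the cell above). Every ID in `0..vocab_size-1` maps to a row; an ID equal to `vocab_size` has no row to look up, and on CPU PyTorch rejects it with an `IndexError`. GPU behavior may surface differently, so this sketch assumes CPU tensors.

```python
import torch

# Same toy setup as above: IDs 0..3 are valid, so the table needs 4 rows
token_ids = torch.tensor([2, 3, 1])
vocab_size = token_ids.max().item() + 1  # 4
embedding_dim = 5

torch.manual_seed(123)
embedding_layer = torch.nn.Embedding(vocab_size, embedding_dim)

# In-range IDs are plain row lookups: one 5-dim row per ID
valid_out = embedding_layer(token_ids)
print(valid_out.shape)  # torch.Size([3, 5])

# ID 4 == vocab_size has no row; on CPU this raises IndexError
out_of_range_rejected = False
try:
    embedding_layer(torch.tensor([vocab_size]))
except IndexError as err:
    out_of_range_rejected = True
    print("Out-of-range ID rejected:", err)
```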

Now I'm creating the embedding layer. I'm setting a random seed so I can reproduce the same results later:

```python
# Setting random seed for reproducibility
torch.manual_seed(123)

# Creating the embedding layer
# It will have vocab_size rows and embedding_dim columns
embedding_layer = torch.nn.Embedding(vocab_size, embedding_dim)

print(f"Embedding layer shape: {embedding_layer.weight.shape}")
```

Let me look at the embedding weight matrix. Each row corresponds to one token ID:

```python
embedding_layer.weight
```

Now I'll get the embedding vector for token ID 1. This is just a lookup: it returns row 1 of the weight matrix.

```python
embedding_layer(torch.tensor([1]))
```

**My understanding**: The embedding layer simply looks up the row corresponding to the token ID. For token ID 1, it returns row 1 of the weight matrix.
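One way to confirm that this is "just a lookup" is to compare the layer's output with plain indexing into `embedding_layer.weight`. This is my own sanity check, re-creating a layer with the same seed and sizes (4 × 5) as above so it runs on its own:

```python
import torch

torch.manual_seed(123)
embedding_layer = torch.nn.Embedding(4, 5)  # vocab_size=4, embedding_dim=5 as above

# Forward pass for token ID 1: shape (1, 5)
looked_up = embedding_layer(torch.tensor([1]))
# Plain row indexing into the weight matrix: shape (5,)
row_1 = embedding_layer.weight[1]

print(looked_up.shape, row_1.shape)
print(torch.equal(looked_up[0], row_1))  # True: same values, no arithmetic applied
```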

Similarly, I can get the embedding for token ID 2:

```python
embedding_layer(torch.tensor([2]))
```

Now let me convert all my token IDs at once. The embedding layer handles batches efficiently:

```python
token_ids = torch.tensor([2, 3, 1])
embedding_layer(token_ids)
```

**My takeaway**: The embedding layer returns a matrix whose row $i$ is the embedding vector for the $i$-th token ID in the input. This is just a series of row lookups with no arithmetic on the weights, which is why it is cheap.
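The batched call can be checked the same way: it matches advanced indexing `weight[token_ids]`. The same lookup also works on a 2-D tensor of IDs shaped `(batch, seq_len)`, which is how an LLM usually feeds it; the output simply gains a trailing `embedding_dim` axis. The second batch below is a made-up example, not data from the book:

```python
import torch

torch.manual_seed(123)
embedding_layer = torch.nn.Embedding(4, 5)  # vocab_size=4, embedding_dim=5

token_ids = torch.tensor([2, 3, 1])
out = embedding_layer(token_ids)  # shape (3, 5)
# Row i of the output is row token_ids[i] of the weight matrix
same = torch.equal(out, embedding_layer.weight[token_ids])

# Hypothetical batch of 2 sequences, 3 tokens each
batch = torch.tensor([[2, 3, 1],
                      [0, 0, 3]])
batch_out = embedding_layer(batch)  # shape (2, 3, 5)

print(same, batch_out.shape)
# A repeated ID (0 appears twice in the second sequence) yields the same vector twice
print(torch.equal(batch_out[1, 0], batch_out[1, 1]))
```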

## Part 2: Using nn.Linear (The Matrix Multiplication Way)

Now I want to understand how a linear layer can achieve the same result, and why it is less efficient. I'll use one-hot encoding and matrix multiplication.

First, I need to convert my token IDs into one-hot encoded vectors:

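```python
# Passing num_classes keeps the one-hot width equal to vocab_size,
# even if the highest possible ID does not occur in this batch
onehot = torch.nn.functional.one_hot(token_ids, num_classes=vocab_size)
print(f"One-hot encoded shape: {onehot.shape}")
print("\nOne-hot encoded vectors:")
print(onehot)
```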
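Why pass `num_classes`? When it is omitted, `torch.nn.functional.one_hot` infers the width from the largest ID present in the input. That happens to equal `vocab_size` here, but it would not for a batch that never uses the top IDs. Such a batch would then no longer fit a linear layer expecting `vocab_size` inputs. The input below is a hypothetical batch chosen to show the mismatch:

```python
import torch
import torch.nn.functional as F

vocab_size = 4
ids_missing_top = torch.tensor([0, 1])  # hypothetical input that never uses IDs 2 or 3

inferred = F.one_hot(ids_missing_top)                          # width = max ID + 1 = 2
explicit = F.one_hot(ids_missing_top, num_classes=vocab_size)  # width = vocab_size = 4

print(inferred.shape)  # torch.Size([2, 2])
print(explicit.shape)  # torch.Size([2, 4])
print(explicit.dtype)  # torch.int64, hence the .float() before feeding nn.Linear
```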

Now I'm creating a Linear layer. This performs matrix multiplication: $X W^T$ where $X$ is the input and $W$ is the weight matrix:

```python
torch.manual_seed(123)
linear_layer = torch.nn.Linear(vocab_size, embedding_dim, bias=False)
print(f"Linear layer weight shape: {linear_layer.weight.shape}")
```
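To check that `nn.Linear` with `bias=False` really computes $X W^T$, I can redo the multiplication by hand with the `@` operator. For one-hot rows, each output row turns out to be one row of $W^T$ (a column of $W$). This is a self-contained sketch using the same toy sizes:

```python
import torch

torch.manual_seed(123)
vocab_size, embedding_dim = 4, 5
linear_layer = torch.nn.Linear(vocab_size, embedding_dim, bias=False)

# nn.Linear stores W as (out_features, in_features) = (5, 4)
print(linear_layer.weight.shape)

ids = torch.tensor([2, 3, 1])
x = torch.nn.functional.one_hot(ids, num_classes=vocab_size).float()
manual = x @ linear_layer.weight.T  # X W^T, shape (3, 5)

print(torch.allclose(linear_layer(x), manual))
# With one-hot inputs, output row j is row ids[j] of W^T
print(torch.allclose(manual, linear_layer.weight.T[ids]))
```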

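**Important**: To compare fairly with the embedding layer, I need to use the same weights. `nn.Linear` stores its weight with shape `(out_features, in_features)`, here `(embedding_dim, vocab_size) = (5, 4)`, which is the transpose of the embedding matrix's `(vocab_size, embedding_dim) = (4, 5)`. So I'm assigning the transpose of the embedding weights: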

```python
linear_layer.weight = torch.nn.Parameter(embedding_layer.weight.T)
print("Assigned embedding weights (transposed) to linear layer")
```
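An alternative I could use instead of wrapping a new `Parameter` is to copy $E^T$ into the linear layer's existing weight in place, inside `torch.no_grad()` so the copy is not recorded by autograd. This keeps the original parameter object and gives the linear layer its own independent values. It is a design choice on my part, not the approach used above:

```python
import torch

torch.manual_seed(123)
vocab_size, embedding_dim = 4, 5
embedding_layer = torch.nn.Embedding(vocab_size, embedding_dim)
linear_layer = torch.nn.Linear(vocab_size, embedding_dim, bias=False)

# Copy E^T (5 x 4) into the existing (embedding_dim, vocab_size) weight
with torch.no_grad():
    linear_layer.weight.copy_(embedding_layer.weight.T)

token_ids = torch.tensor([2, 3, 1])
onehot = torch.nn.functional.one_hot(token_ids, num_classes=vocab_size).float()
print(torch.allclose(linear_layer(onehot), embedding_layer(token_ids)))
```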

Now I'll apply the linear layer to the one-hot encoded input. This performs matrix multiplication:

```python
result_linear = linear_layer(onehot.float())
print("Result from linear layer:")
print(result_linear)
```

Let me verify this matches the embedding layer output:

```python
result_embedding = embedding_layer(token_ids)
print("Result from embedding layer:")
print(result_embedding)

print("\nAre they equal?", torch.allclose(result_linear, result_embedding))
```

## Part 3: Understanding Why Embeddings Are More Efficient

**My key insight**: When we multiply a one-hot encoded vector with a weight matrix, we're doing a lot of wasteful multiplications by zero.

For example, for token ID 2, the one-hot vector is `[0, 0, 1, 0]`. When I multiply it by $W^T$ (the transposed linear weight, which here is exactly the embedding matrix, one row per token ID):
- Rows 0, 1, and 3 are multiplied by 0 (wasted computation)
- Only row 2 is multiplied by 1 (the only useful part)

This is exactly the same as just looking up row 2 directly, which is what the embedding layer does!
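To get a feel for how much is wasted at realistic sizes, I can count operations instead of timing them. The sizes below are assumptions for illustration (roughly GPT-2-small-like: a 50,257-token vocabulary, 768-dim embeddings, 1,024 tokens), not measurements. Per token, the one-hot route does `vocab_size × embedding_dim` multiplications, of which only `embedding_dim` are by 1. The lookup just copies `embedding_dim` values:

```python
# Assumed, GPT-2-small-like sizes (illustrative only)
vocab_size = 50_257
embedding_dim = 768
num_tokens = 1_024

# One-hot route: materialize a (num_tokens x vocab_size) matrix, then multiply by W^T
onehot_entries = num_tokens * vocab_size
onehot_mults = num_tokens * vocab_size * embedding_dim
useful_mults = num_tokens * embedding_dim  # the multiplications by 1

# Lookup route: copy one row of length embedding_dim per token
lookup_copies = num_tokens * embedding_dim

print(f"One-hot matrix entries:          {onehot_entries:,}")
print(f"Multiplications (one-hot @ W^T): {onehot_mults:,}")
print(f"  of which multiplications by 1: {useful_mults:,}")
print(f"Values copied by lookup:         {lookup_copies:,}")
# The ratio is exactly vocab_size: every extra row of the vocabulary is pure waste
print(f"Ratio: {onehot_mults // lookup_copies:,}x")
```

Counting like this ignores hardware effects (memory bandwidth, kernel launch overhead), so it shows the scale of the waste rather than an exact speedup.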

## My Takeaways

✅ **Mathematical equivalence**: An embedding layer and a bias-free linear layer applied to one-hot inputs produce identical results when the linear weight is the transposed embedding matrix

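✅ **Computational efficiency**: Embedding layers avoid building a `num_tokens × vocab_size` one-hot matrix and skip the `vocab_size × embedding_dim` multiplications per token, almost all of which are multiplications by zero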

✅ **When to use which**:
- Use `nn.Embedding` in practice for token embeddings (efficient lookup)
- Use `nn.Linear` when the inputs are dense vectors, where a real matrix multiplication is needed

**What I still need to practice**: Understanding how gradients flow back through embedding layers during training
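As a first small experiment on that question (not a full explanation), I can backpropagate a toy loss through both layers with matching weights. The loss (the sum of all outputs) is an arbitrary choice for illustration. The embedding's weight gradient should equal the transposed linear gradient, and rows for IDs that never appear should get zero gradient:

```python
import torch

torch.manual_seed(123)
vocab_size, embedding_dim = 4, 5
embedding_layer = torch.nn.Embedding(vocab_size, embedding_dim)
linear_layer = torch.nn.Linear(vocab_size, embedding_dim, bias=False)
with torch.no_grad():
    linear_layer.weight.copy_(embedding_layer.weight.T)

token_ids = torch.tensor([2, 3, 1])
onehot = torch.nn.functional.one_hot(token_ids, num_classes=vocab_size).float()

# Hypothetical toy loss: sum of all output entries
embedding_layer(token_ids).sum().backward()
linear_layer(onehot).sum().backward()

# Embedding grad is (vocab_size, embedding_dim); the linear grad is its transpose
print(embedding_layer.weight.grad)
print(torch.allclose(embedding_layer.weight.grad, linear_layer.weight.grad.T))

# Token ID 0 never appears, so its row receives zero gradient
print(embedding_layer.weight.grad[0])
```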