# Lab 1.3.3: Custom Embedding Lookup

**Module:** 1.3 - CUDA Python & GPU Programming  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐ (Intermediate)

---

## 🎯 Learning Objectives

By the end of this notebook, you will:
- [ ] Understand how embeddings work in neural networks
- [ ] Implement a custom CUDA kernel for batched embedding lookup
- [ ] Optimize memory access patterns for embedding tables
- [ ] Compare with PyTorch's `nn.Embedding` performance

---

## 📚 Prerequisites

- Completed: Labs 1.3.1 and 1.3.2
- Knowledge of: How text is tokenized into IDs

---

## 🌍 Real-World Context

**Embeddings are the foundation of every LLM.**

When you type "Hello, how are you?" into ChatGPT:
1. **Tokenizer** converts text → `[15496, 11, 703, 389, 345, 30]` (token IDs, shown here with the GPT-2 tokenizer for illustration)
2. **Embedding layer** converts IDs → vectors (the first neural network operation)
3. The rest of the model processes these vectors

**Scale of embeddings in modern LLMs:**

| Model | Vocab Size | Embedding Dim | Embedding Table Size (FP32) |
|-------|------------|---------------|---------------------|
| GPT-2 | 50,257 | 768 | 147 MB |
| Llama-2-7B | 32,000 | 4,096 | 500 MB |
| Llama-3-70B | 128,256 | 8,192 | **4 GB** |
| GPT-4 (estimated) | ~100K | ~12K | **~5 GB** |

With batch sizes of thousands and sequences of thousands of tokens, a single forward pass can gather millions of rows from this table, so the lookup's memory access pattern largely determines its cost.
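
The table sizes above are consistent with `vocab_size × embedding_dim × bytes_per_element`. Here is a minimal sketch of that arithmetic, assuming FP32 (4-byte) weights and binary units (1 MiB = 2^20 bytes). The helper name `embedding_table_bytes` is illustrative and not part of any library:

```python
def embedding_table_bytes(vocab_size: int, embedding_dim: int, bytes_per_element: int = 4) -> int:
    """Size of a dense embedding table in bytes (FP32 by default)."""
    return vocab_size * embedding_dim * bytes_per_element


MIB = 2**20
models = {
    "GPT-2": (50_257, 768),
    "Llama-2-7B": (32_000, 4_096),
    "Llama-3-70B": (128_256, 8_192),
}

for name, (vocab_size, embedding_dim) in models.items():
    size_mib = embedding_table_bytes(vocab_size, embedding_dim) / MIB
    print(f"{name:<12} {size_mib:8.1f} MiB")
```

Passing `bytes_per_element=2` (FP16/BF16 storage) halves every figure.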

---

## 🧒 ELI5: What Are Embeddings?

> **Imagine a giant library** where every word in every language has its own book. There are 50,000+ books (vocabulary size).
>
> Each book contains a **secret code** - a list of 768 numbers that describe everything important about that word:
> - Its meaning
> - What other words it's related to
> - Whether it's a noun, verb, emotion, etc.
>
> When you give the AI a word like "cat" (book #1234), it looks up book #1234 and reads the secret code. That's the **embedding**!
>
> **The magic:** Similar words have similar codes. "Cat" and "kitten" have codes that are close together. "Cat" and "refrigerator" have very different codes.
>
> **In code terms:**
> ```python
> # embedding_table shape: (50000, 768)
> # Like a 2D array: 50000 rows (words), 768 columns (features)
> token_id = 1234  # The word "cat"
> embedding = embedding_table[token_id]  # Get the 768-dim vector for "cat"
> ```

```
Token ID: 1234 ("cat")
    ↓
┌───────────────────────────────┐
│     Embedding Table           │
│  (50,000 × 768)               │
├───────────────────────────────┤
│ Row 0:    [0.1, -0.3, ...]    │ ← "the"
│ Row 1:    [0.5,  0.2, ...]    │ ← "a"
│ ...                           │
│ Row 1234: [0.8, -0.1, ...]    │ ← "cat" ✓ (we want this row!)
│ ...                           │
│ Row 49999: [-0.2, 0.7, ...]   │ ← "zebra"
└───────────────────────────────┘
    ↓
Output: [0.8, -0.1, 0.3, -0.5, ...] (768 numbers)
```
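
One common way to make "codes that are close together" concrete is cosine similarity. The sketch below uses hypothetical 4-dimensional vectors invented for illustration. Real embeddings are learned during training and have hundreds or thousands of dimensions:

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: 1.0 means same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Hypothetical 4-dim embeddings, invented for illustration only
toy_embeddings = {
    "cat": [0.8, -0.1, 0.3, -0.5],
    "kitten": [0.7, -0.2, 0.4, -0.4],
    "refrigerator": [-0.3, 0.9, -0.6, 0.2],
}

cat = toy_embeddings["cat"]
for word in ("kitten", "refrigerator"):
    print(f"cat vs {word}: {cosine_similarity(cat, toy_embeddings[word]):+.2f}")
```

With these toy values, "cat" and "kitten" score close to +1, while "cat" and "refrigerator" score below zero.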

---

## Part 0: Environment Setup

In [None]:
import numpy as np
import time
from typing import Tuple
import warnings
warnings.filterwarnings('ignore')

from numba import cuda, float32, int32
import numba

# Check for PyTorch
try:
    import torch
    import torch.nn as nn
    HAS_TORCH = True
    print(f"✅ PyTorch {torch.__version__} available")
    print(f"   CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"   Device: {torch.cuda.get_device_name()}")
except ImportError:
    HAS_TORCH = False
    print("⚠️ PyTorch not available")

print(f"\n✅ Numba {numba.__version__} available")
print(f"   CUDA available: {cuda.is_available()}")

---

## Part 1: Understanding PyTorch's nn.Embedding

First, let's understand how the standard PyTorch embedding works.

In [None]:
# Simple example
if HAS_TORCH:
    # Create a small embedding table: 10 words, 4-dimensional embeddings
    vocab_size = 10
    embedding_dim = 4
    
    embedding = nn.Embedding(vocab_size, embedding_dim)
    
    print("Embedding table (10 words × 4 features):")
    print(embedding.weight.data)
    
    # Look up embeddings for specific tokens
    token_ids = torch.tensor([0, 3, 7])  # Look up words 0, 3, and 7
    result = embedding(token_ids)
    
    print(f"\nInput token IDs: {token_ids.tolist()}")
    print(f"Output embeddings (shape {result.shape}):")
    print(result)
    
    print("\n💡 Notice: Output row 0 matches embedding.weight[0]")
    print(f"   embedding.weight[0]: {embedding.weight[0].tolist()}")
    print(f"   result[0]:           {result[0].tolist()}")

### Batched Embedding Lookup

In practice, we process entire sequences in batches:

In [None]:
if HAS_TORCH:
    # Simulate a batch of 3 sequences, each 5 tokens long
    batch_size = 3
    seq_length = 5
    
    # Random token IDs (simulating tokenized text)
    token_ids = torch.randint(0, vocab_size, (batch_size, seq_length))
    
    print(f"Input shape: {token_ids.shape} (batch × sequence length)")
    print(f"Token IDs:\n{token_ids}")
    
    # Embedding lookup
    embeddings = embedding(token_ids)
    
    print(f"\nOutput shape: {embeddings.shape} (batch × seq_len × embed_dim)")
    print(f"\n💡 For each of {batch_size * seq_length} tokens, we get a {embedding_dim}-dim vector")

---

## Part 2: CPU Implementation (NumPy)

Before writing CUDA, let's implement embedding lookup in pure NumPy to understand the operation.

In [None]:
def embedding_lookup_cpu(embedding_table: np.ndarray, token_ids: np.ndarray) -> np.ndarray:
    """
    CPU embedding lookup using NumPy advanced indexing.
    
    Args:
        embedding_table: Shape (vocab_size, embedding_dim)
        token_ids: Shape (batch_size, seq_length) or any shape
    
    Returns:
        embeddings: Shape (*token_ids.shape, embedding_dim)
    """
    # NumPy's advanced indexing does all the work!
    # embedding_table[token_ids] gathers rows from embedding_table
    return embedding_table[token_ids]


# Test
vocab_size = 10
embedding_dim = 4

# Create embedding table
np.random.seed(42)
embed_table = np.random.randn(vocab_size, embedding_dim).astype(np.float32)

# Token IDs to look up
token_ids_np = np.array([[0, 3, 7], [2, 5, 9]], dtype=np.int32)  # 2 sequences × 3 tokens

result_cpu = embedding_lookup_cpu(embed_table, token_ids_np)

print(f"Embedding table shape: {embed_table.shape}")
print(f"Token IDs shape: {token_ids_np.shape}")
print(f"Output shape: {result_cpu.shape}")

print(f"\nToken IDs:\n{token_ids_np}")
print(f"\nFirst sequence embeddings (tokens 0, 3, 7):")
print(result_cpu[0])

### 🔍 Understanding the Memory Access Pattern

```
Embedding table in memory (row-major):
┌─────────────────────────────────────────────────┐
│ [Row 0: emb[0,0], emb[0,1], ..., emb[0,767]]    │ ← Contiguous in memory
│ [Row 1: emb[1,0], emb[1,1], ..., emb[1,767]]    │
│ [Row 2: ...]                                    │
│ ...                                             │
└─────────────────────────────────────────────────┘

Token IDs: [1234, 5678, 9012, ...]

Access pattern:
1. Jump to row 1234 → read 768 floats
2. Jump to row 5678 → read 768 floats  ← Random access! Cache miss!
3. Jump to row 9012 → read 768 floats  ← Random access! Cache miss!
```

**Key insight:** Embedding lookup is essentially **random memory access**. Unlike matrix multiplication where we can reuse data, each token accesses a different, unpredictable row. This makes it **memory bandwidth bound**, not compute bound.
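
To see how far apart consecutive lookups land, the sketch below computes the byte offset of each requested row in a row-major FP32 table. The helper `row_byte_offsets` is a hypothetical name used only for this illustration:

```python
def row_byte_offsets(token_ids, embedding_dim, bytes_per_element=4):
    """Byte offset of each looked-up row in a row-major embedding table."""
    row_bytes = embedding_dim * bytes_per_element
    return [token_id * row_bytes for token_id in token_ids]


token_ids = [1234, 5678, 9012]
offsets = row_byte_offsets(token_ids, embedding_dim=768)

previous = None
for token_id, offset in zip(token_ids, offsets):
    jump = "" if previous is None else f"  (jump of {abs(offset - previous):,} bytes)"
    print(f"row {token_id}: starts at byte {offset:,}{jump}")
    previous = offset
```

Each row is 3,072 contiguous bytes, so reads within a row are sequential. It is the multi-megabyte jumps between rows that defeat caching.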

---

## Part 3: CUDA Kernel for Embedding Lookup

### Design Choices

**Option 1:** One thread per token (each thread reads entire embedding)
- Simple, but poor parallelism for large embedding_dim, and adjacent threads touch different rows, so accesses are not coalesced

**Option 2:** One thread per embedding element (our choice)
- Better parallelism
- Need to handle the token ID broadcast

**Option 3:** One block per token, threads cooperate
- Best for very large embedding dimensions
- More complex
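
As a rough comparison, the sketch below computes the grid and block shapes each option would launch for 16,384 tokens with 4,096-dim embeddings. The 256-thread block for Option 1 is an assumption made for this illustration. The 32 × 8 and 256-thread shapes match the kernels in this lab:

```python
import math


def launch_shapes(total_tokens: int, embedding_dim: int) -> dict:
    """Return (grid, block) for each design option; block sizes are assumptions."""
    return {
        # Option 1: one thread per token, 256 threads per block (assumed)
        "one thread per token": ((math.ceil(total_tokens / 256),), (256,)),
        # Option 2: one thread per element, using the 32 x 8 block from this lab
        "one thread per element": (
            (math.ceil(embedding_dim / 32), math.ceil(total_tokens / 8)),
            (32, 8),
        ),
        # Option 3: one block per token, at most 256 cooperating threads
        "one block per token": ((total_tokens,), (min(256, embedding_dim),)),
    }


def total_threads(grid, block) -> int:
    """Number of threads launched for a given grid and block shape."""
    return math.prod(grid) * math.prod(block)


for name, (grid, block) in launch_shapes(total_tokens=16_384, embedding_dim=4_096).items():
    print(f"{name:<24} grid={grid} block={block} threads={total_threads(grid, block):,}")
```

Option 1 launches far fewer threads than there are elements to copy, which is why it parallelizes poorly for large embedding dimensions.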

In [None]:
@cuda.jit
def embedding_lookup_kernel(embedding_table, token_ids, output):
    """
    CUDA kernel for batched embedding lookup.
    
    Grid layout:
    - x dimension: embedding dimension (which feature)
    - y dimension: flattened token index (which token)
    
    Each thread copies one float from embedding_table to output.
    
    Args:
        embedding_table: (vocab_size, embedding_dim)
        token_ids: (total_tokens,) - flattened token IDs
        output: (total_tokens, embedding_dim) - flattened output
    """
    # Which embedding dimension (feature) this thread handles
    embed_idx = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
    # Which token this thread handles
    token_idx = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y
    
    vocab_size, embedding_dim = embedding_table.shape
    total_tokens = token_ids.shape[0]
    
    # Bounds check
    if embed_idx < embedding_dim and token_idx < total_tokens:
        # Look up which row of embedding table to use
        token_id = token_ids[token_idx]
        
        # Copy the embedding value
        # Note: This is coalesced access across threads in x-dimension!
        output[token_idx, embed_idx] = embedding_table[token_id, embed_idx]


def embedding_lookup_gpu(embedding_table: np.ndarray, token_ids: np.ndarray) -> np.ndarray:
    """
    GPU embedding lookup.
    
    Args:
        embedding_table: (vocab_size, embedding_dim)
        token_ids: Any shape, will be flattened
    
    Returns:
        embeddings: (*token_ids.shape, embedding_dim)
    """
    original_shape = token_ids.shape
    vocab_size, embedding_dim = embedding_table.shape
    
    # Flatten token_ids
    flat_tokens = token_ids.flatten().astype(np.int32)
    total_tokens = flat_tokens.shape[0]
    
    # Transfer to GPU
    d_embed = cuda.to_device(embedding_table)
    d_tokens = cuda.to_device(flat_tokens)
    d_output = cuda.device_array((total_tokens, embedding_dim), dtype=np.float32)
    
    # Configure grid
    threads_per_block = (32, 8)  # 256 threads per block
    blocks_x = (embedding_dim + 31) // 32
    blocks_y = (total_tokens + 7) // 8
    blocks = (blocks_x, blocks_y)
    
    # Launch kernel
    embedding_lookup_kernel[blocks, threads_per_block](d_embed, d_tokens, d_output)
    
    # Get result and reshape
    result = d_output.copy_to_host()
    return result.reshape(*original_shape, embedding_dim)


# Test
result_gpu = embedding_lookup_gpu(embed_table, token_ids_np)

print(f"GPU result shape: {result_gpu.shape}")
print(f"Results match CPU: {np.allclose(result_gpu, result_cpu)}")
print(f"\nFirst sequence embeddings:")
print(result_gpu[0])

### 🔍 Why This Kernel Design?

```
Block (0, 0):
  Thread (0,0) → output[0, 0]
  Thread (1,0) → output[0, 1]  ← Coalesced: adjacent threads access adjacent memory
  Thread (2,0) → output[0, 2]
  ...
  Thread (31,0) → output[0, 31]
  Thread (0,1) → output[1, 0]
  Thread (1,1) → output[1, 1]
  ...

Memory access pattern (output):
  32 adjacent threads write to 32 adjacent memory locations
  = Perfect coalescing = Maximum bandwidth utilization!

Memory access pattern (embedding_table):
  Each row (token_id) may be different
  But within a row, access is coalesced
```
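
The coalescing claim can be checked on the CPU by listing the output offsets that one warp touches. The sketch below assumes the standard CUDA rule that threads in a block are linearized with `threadIdx.x` varying fastest and that a warp holds 32 threads. The swapped layout is a hypothetical counter-example and is not used in this lab:

```python
def warp_offsets(block_dim, warp_index, element_offset, warp_size=32):
    """Element offsets touched by one warp of a 2D thread block.

    CUDA linearizes threads with threadIdx.x varying fastest, so thread t of
    the block has tx = t % block_dim[0] and ty = t // block_dim[0].
    """
    first = warp_index * warp_size
    offsets = []
    for t in range(first, first + warp_size):
        tx, ty = t % block_dim[0], t // block_dim[0]
        offsets.append(element_offset(tx, ty))
    return offsets


def is_contiguous(offsets):
    """True when every thread touches the element right after its neighbour's."""
    return all(b - a == 1 for a, b in zip(offsets, offsets[1:]))


EMBEDDING_DIM = 768

# Layout used in this lab: x -> embedding element, y -> token (block 0, 0)
lab_layout = warp_offsets((32, 8), 0, lambda tx, ty: ty * EMBEDDING_DIM + tx)
# Hypothetical swapped layout: x -> token, y -> embedding element
swapped_layout = warp_offsets((32, 8), 0, lambda tx, ty: tx * EMBEDDING_DIM + ty)

print("lab layout contiguous:    ", is_contiguous(lab_layout))
print("swapped layout contiguous:", is_contiguous(swapped_layout))
print("swapped layout stride:    ", swapped_layout[1] - swapped_layout[0], "elements")
```

In the lab layout a warp covers 32 consecutive floats of one output row. In the swapped layout each thread of the warp lands in a different row, one full embedding apart.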

---

## Part 4: Optimized Version with Better Coalescing

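The previous version already coalesces reads and writes within each warp, but every 32×8 block touches eight different tokens and copies only 32 floats of each row. Assigning one block per token lets all threads of a block work on the same row and stream it with a strided loop: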

In [None]:
EMBED_BLOCK_SIZE = 256  # Threads per block for embedding dimension

@cuda.jit
def embedding_lookup_optimized_kernel(embedding_table, token_ids, output):
    """
    Optimized embedding lookup with better thread organization.
    
    Each block handles one token. Threads within the block cooperatively
    copy the entire embedding vector.
    
    Benefits:
    - Threads in same warp read from same row = better cache utilization
    - Write coalescing maintained
    - Works well for any embedding_dim
    """
    # Block index = which token
    token_idx = cuda.blockIdx.x
    # Thread index = position in embedding
    tx = cuda.threadIdx.x
    
    vocab_size, embedding_dim = embedding_table.shape
    total_tokens = token_ids.shape[0]
    
    if token_idx >= total_tokens:
        return
    
    # All threads in this block read the same token_id
    token_id = token_ids[token_idx]
    
    # Each thread copies multiple elements if embedding_dim > block_size
    for embed_idx in range(tx, embedding_dim, cuda.blockDim.x):
        output[token_idx, embed_idx] = embedding_table[token_id, embed_idx]


def embedding_lookup_gpu_optimized(embedding_table: np.ndarray, token_ids: np.ndarray) -> np.ndarray:
    """
    Optimized GPU embedding lookup.
    """
    original_shape = token_ids.shape
    vocab_size, embedding_dim = embedding_table.shape
    
    flat_tokens = token_ids.flatten().astype(np.int32)
    total_tokens = flat_tokens.shape[0]
    
    d_embed = cuda.to_device(embedding_table)
    d_tokens = cuda.to_device(flat_tokens)
    d_output = cuda.device_array((total_tokens, embedding_dim), dtype=np.float32)
    
    # One block per token, up to 256 threads per block
    threads = min(EMBED_BLOCK_SIZE, embedding_dim)
    blocks = total_tokens
    
    embedding_lookup_optimized_kernel[blocks, threads](d_embed, d_tokens, d_output)
    
    result = d_output.copy_to_host()
    return result.reshape(*original_shape, embedding_dim)


# Test
result_opt = embedding_lookup_gpu_optimized(embed_table, token_ids_np)
print(f"Optimized GPU result matches: {np.allclose(result_opt, result_cpu)}")

---

## Part 5: Benchmarking

Let's compare our implementations against PyTorch at realistic LLM scales.

In [None]:
# Realistic LLM-scale parameters
configs = [
    {"name": "GPT-2 Small", "vocab_size": 50257, "embedding_dim": 768, "batch_size": 32, "seq_length": 512},
    {"name": "Llama-2-7B", "vocab_size": 32000, "embedding_dim": 4096, "batch_size": 8, "seq_length": 1024},
    {"name": "Llama-3-8B", "vocab_size": 128256, "embedding_dim": 4096, "batch_size": 8, "seq_length": 2048},
]

print("📊 Embedding Lookup Benchmark")
print("="*80)

for config in configs:
    name = config["name"]
    vocab_size = config["vocab_size"]
    embedding_dim = config["embedding_dim"]
    batch_size = config["batch_size"]
    seq_length = config["seq_length"]
    total_tokens = batch_size * seq_length
    
    print(f"\n{name}:")
    print(f"  Vocab: {vocab_size:,}, Embed: {embedding_dim}, Batch×Seq: {batch_size}×{seq_length} = {total_tokens:,} tokens")
    
    # Create data
    embed_table = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
    token_ids = np.random.randint(0, vocab_size, (batch_size, seq_length), dtype=np.int32)
    
    # Embedding table size
    table_size_mb = vocab_size * embedding_dim * 4 / 1024 / 1024
    output_size_mb = total_tokens * embedding_dim * 4 / 1024 / 1024
    print(f"  Table size: {table_size_mb:.1f} MB, Output size: {output_size_mb:.1f} MB")
    
    # Pre-transfer to GPU
    d_embed = cuda.to_device(embed_table)
    d_tokens = cuda.to_device(token_ids.flatten())
    d_output = cuda.device_array((total_tokens, embedding_dim), dtype=np.float32)
    
    # Benchmark NumPy
    start = time.perf_counter()
    for _ in range(5):
        _ = embed_table[token_ids]
    time_numpy = (time.perf_counter() - start) / 5
    
    # Benchmark basic GPU
    threads = (32, 8)
    blocks = ((embedding_dim + 31) // 32, (total_tokens + 7) // 8)
    
    # Warm up
    embedding_lookup_kernel[blocks, threads](d_embed, d_tokens, d_output)
    cuda.synchronize()
    
    start = time.perf_counter()
    for _ in range(10):
        embedding_lookup_kernel[blocks, threads](d_embed, d_tokens, d_output)
    cuda.synchronize()
    time_gpu_basic = (time.perf_counter() - start) / 10
    
    # Benchmark optimized GPU
    threads_opt = min(EMBED_BLOCK_SIZE, embedding_dim)
    blocks_opt = total_tokens
    
    embedding_lookup_optimized_kernel[blocks_opt, threads_opt](d_embed, d_tokens, d_output)
    cuda.synchronize()
    
    start = time.perf_counter()
    for _ in range(10):
        embedding_lookup_optimized_kernel[blocks_opt, threads_opt](d_embed, d_tokens, d_output)
    cuda.synchronize()
    time_gpu_opt = (time.perf_counter() - start) / 10
    
    # Benchmark PyTorch if available
    time_pytorch = None
    if HAS_TORCH and torch.cuda.is_available():
        torch_embed = nn.Embedding(vocab_size, embedding_dim).cuda()
        torch_tokens = torch.from_numpy(token_ids).cuda()
        
        # Warm up
        _ = torch_embed(torch_tokens)
        torch.cuda.synchronize()
        
        start = time.perf_counter()
        for _ in range(10):
            _ = torch_embed(torch_tokens)
        torch.cuda.synchronize()
        time_pytorch = (time.perf_counter() - start) / 10
        
        del torch_embed, torch_tokens
        torch.cuda.empty_cache()
    
    # Calculate bandwidth: every token reads one table row and writes one output row
    bytes_transferred = 2 * total_tokens * embedding_dim * 4
    bandwidth_gpu = bytes_transferred / time_gpu_opt / 1e9
    
    print(f"\n  {'Method':<20} {'Time (ms)':<12} {'Speedup':<12}")
    print(f"  {'-'*44}")
    print(f"  {'NumPy':<20} {time_numpy*1000:<12.3f} {'1.0x':<12}")
    print(f"  {'GPU Basic':<20} {time_gpu_basic*1000:<12.3f} {f'{time_numpy/time_gpu_basic:.1f}x':<12}")
    print(f"  {'GPU Optimized':<20} {time_gpu_opt*1000:<12.3f} {f'{time_numpy/time_gpu_opt:.1f}x':<12}")
    if time_pytorch:
        print(f"  {'PyTorch':<20} {time_pytorch*1000:<12.3f} {f'{time_numpy/time_pytorch:.1f}x':<12}")
        print(f"\n  Optimized kernel time / PyTorch time: {time_gpu_opt/time_pytorch:.2f}x (above 1.0 means PyTorch is faster)")
    print(f"  Effective bandwidth: {bandwidth_gpu:.1f} GB/s")

### 🔍 Understanding the Results

**Why might our kernel be slower than PyTorch?**

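PyTorch's `nn.Embedding` dispatches to tuned, precompiled CUDA kernels. Hand-tuned gather kernels typically rely on techniques that our Numba kernels do not use:
1. **Vectorized loads/stores** - `float4` (128-bit) accesses move four floats per instruction
2. **Memory prefetching** - Loads are issued early so memory latency overlaps with other work
3. **Warp-level optimizations** - Cooperative-group and warp-level primitives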

**Why embedding lookup is memory-bound:**

```
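Arithmetic operations: 0 (just copying data)
Memory operations: read tokens×embed (one row per token), write tokens×embed

For Llama-3-8B (8 × 2048 = 16K tokens, 4K-dim FP32 rows):
- Read: 16K × 4K × 4B = 256 MB (worst case, all tokens different)
- Write: 16K × 4K × 4B = 256 MB
(Only looked-up rows are read, never the full 128K × 4K × 4B = 2 GB table.)

At 273 GB/s bandwidth: theoretical minimum ~2 ms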
```

The operation is fundamentally limited by memory bandwidth, not compute!
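
A back-of-the-envelope version of this bound is sketched below. It assumes the kernel moves exactly one table row, one token ID, and one output row per token, and that memory bandwidth is the only limit. The 273 GB/s figure is the one quoted above, and `min_lookup_time_ms` is a hypothetical helper name:

```python
def min_lookup_time_ms(total_tokens, embedding_dim, bandwidth_gb_s,
                       bytes_per_element=4, bytes_per_index=4):
    """Lower bound on lookup time if memory bandwidth is the only limit."""
    row_bytes = embedding_dim * bytes_per_element
    bytes_read = total_tokens * (row_bytes + bytes_per_index)  # one row + one ID per token
    bytes_written = total_tokens * row_bytes                   # one output row per token
    seconds = (bytes_read + bytes_written) / (bandwidth_gb_s * 1e9)
    return seconds * 1e3


# Llama-3-8B benchmark config: 8 x 2048 tokens, 4096-dim FP32 rows
t_ms = min_lookup_time_ms(total_tokens=8 * 2048, embedding_dim=4096, bandwidth_gb_s=273)
print(f"Bandwidth-bound minimum: {t_ms:.2f} ms")
```

Because only the looked-up rows are touched, traffic scales with the number of tokens rather than with the vocabulary size. Comparing a measured kernel time against this bound shows how close the kernel gets to peak bandwidth.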

---

## Part 6: Implementing Backward Pass (Gradient)

For training, we need the gradient of the embedding lookup. This is a **scatter add** operation: gradients are accumulated back into the embedding table rows.
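
Before moving to the GPU, it helps to see scatter add as a plain loop. The pure-Python reference below is a sketch for intuition only. `embedding_backward_reference` is a hypothetical name, and the tiny inputs are made up:

```python
def embedding_backward_reference(grad_output, token_ids, vocab_size):
    """Scatter-add: accumulate each token's gradient into its table row."""
    embedding_dim = len(grad_output[0])
    grad_table = [[0.0] * embedding_dim for _ in range(vocab_size)]
    for grad_row, token_id in zip(grad_output, token_ids):
        for j, g in enumerate(grad_row):
            grad_table[token_id][j] += g  # accumulate, never overwrite
    return grad_table


# Token 0 appears twice, so its row receives the sum of two gradients
token_ids = [0, 2, 0]
grad_output = [[1.0, 2.0], [10.0, 20.0], [3.0, 4.0]]

grad_table = embedding_backward_reference(grad_output, token_ids, vocab_size=4)
for row_index, row in enumerate(grad_table):
    print(row_index, row)
```

In NumPy, `np.add.at(grad_table, token_ids, grad_output)` performs the same accumulation, whereas `grad_table[token_ids] += grad_output` applies only one update per duplicated index. On the GPU the loop iterations run concurrently, which is why the kernel needs atomic adds.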

In [None]:
@cuda.jit
def embedding_backward_kernel(grad_output, token_ids, grad_embedding):
    """
    Backward pass for embedding lookup.
    
    For each token, add its gradient back to the corresponding row
    of the embedding table gradient.
    
    Note: This uses atomic add because multiple tokens might map
    to the same embedding row!
    
    Args:
        grad_output: (total_tokens, embedding_dim) - gradient from next layer
        token_ids: (total_tokens,) - which rows to update
        grad_embedding: (vocab_size, embedding_dim) - gradient to accumulate
    """
    token_idx = cuda.blockIdx.x
    embed_idx = cuda.threadIdx.x
    
    total_tokens, embedding_dim = grad_output.shape
    
    if token_idx >= total_tokens:
        return
    
    token_id = token_ids[token_idx]
    
    for idx in range(embed_idx, embedding_dim, cuda.blockDim.x):
        # Atomic add because multiple tokens might update same row!
        cuda.atomic.add(grad_embedding, (token_id, idx), grad_output[token_idx, idx])


def embedding_backward_gpu(grad_output: np.ndarray, token_ids: np.ndarray, 
                           vocab_size: int) -> np.ndarray:
    """
    Compute gradient of embedding table.
    """
    total_tokens, embedding_dim = grad_output.shape
    
    d_grad_output = cuda.to_device(grad_output)
    d_token_ids = cuda.to_device(token_ids.flatten().astype(np.int32))
    d_grad_embedding = cuda.to_device(np.zeros((vocab_size, embedding_dim), dtype=np.float32))
    
    threads = min(256, embedding_dim)
    blocks = total_tokens
    
    embedding_backward_kernel[blocks, threads](d_grad_output, d_token_ids, d_grad_embedding)
    
    return d_grad_embedding.copy_to_host()


# Test backward pass
print("Testing Embedding Backward Pass")
print("="*50)

vocab_size = 10
embedding_dim = 4
batch_size = 2
seq_length = 3

# Random gradients coming from next layer
grad_output = np.random.randn(batch_size * seq_length, embedding_dim).astype(np.float32)
# Some tokens repeat to test atomic add
token_ids = np.array([0, 1, 0, 2, 1, 0], dtype=np.int32)  # Token 0 appears 3 times!

# GPU backward
grad_embedding_gpu = embedding_backward_gpu(grad_output, token_ids, vocab_size)

# CPU reference
grad_embedding_cpu = np.zeros((vocab_size, embedding_dim), dtype=np.float32)
for i, token_id in enumerate(token_ids):
    grad_embedding_cpu[token_id] += grad_output[i]

print(f"Token IDs: {token_ids}")
print(f"Token 0 appears {np.sum(token_ids == 0)} times")
print(f"\nGradient for embedding row 0 (should be sum of 3 gradients):")
print(f"  CPU: {grad_embedding_cpu[0]}")
print(f"  GPU: {grad_embedding_gpu[0]}")
print(f"\nResults match: {np.allclose(grad_embedding_gpu, grad_embedding_cpu)}")

---

## ⚠️ Common Mistakes

### Mistake 1: Forgetting Atomic Operations in Backward Pass

In [None]:
# ❌ WRONG: Regular assignment overwrites instead of accumulating
# (a plain += is also wrong: concurrent read-modify-write is a data race)
# grad_embedding[token_id, idx] = grad_output[token_idx, idx]  # BUG!

# ✅ CORRECT: Use atomic add
# cuda.atomic.add(grad_embedding, (token_id, idx), grad_output[token_idx, idx])

print("💡 When multiple tokens map to the same row (very common!),")
print("   you MUST use atomic add to accumulate gradients correctly.")
print("   Without atomics, gradients get overwritten and training fails!")

### Mistake 2: Int64 Token IDs

In [None]:
# ❌ WASTEFUL: Using int64 doubles the size of the index array for no benefit here
# token_ids = np.array([1, 2, 3], dtype=np.int64)

# ✅ CORRECT: Use int32 (sufficient for vocab sizes up to 2 billion)
# token_ids = np.array([1, 2, 3], dtype=np.int32)

print("💡 Token IDs should be int32:")
print("   - Vocab sizes are typically < 200K (well under 2B limit)")
print("   - int32 uses half the memory of int64")
print("   - Half as many index bytes to copy to the GPU and load in the kernel")

### Mistake 3: Not Handling Out-of-Vocabulary Tokens

In [None]:
# In production, always validate token IDs!
def safe_embedding_lookup(embedding_table, token_ids):
    vocab_size = embedding_table.shape[0]
    
    # Check for invalid token IDs
    if np.any(token_ids < 0) or np.any(token_ids >= vocab_size):
        invalid_tokens = token_ids[(token_ids < 0) | (token_ids >= vocab_size)]
        raise ValueError(f"Invalid token IDs: {invalid_tokens[:5]}... (vocab_size={vocab_size})")
    
    return embedding_lookup_gpu_optimized(embedding_table, token_ids)

print("💡 Always validate token IDs before embedding lookup!")
print("   Out-of-bounds access = undefined behavior or crash.")

---

## ✋ Try It Yourself: Implement Positional Embeddings

**Challenge:** Extend the embedding kernel to also add positional embeddings.

In transformers with learned absolute positional embeddings (such as GPT-2), the input to the model is:
```
input = token_embedding + position_embedding
```

Where position_embedding depends on the position (0, 1, 2, ...) in the sequence.
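
For a flattened `(batch_size, seq_length)` input, the kernel needs one position index per token. The sketch below builds that index list, assuming positions restart at 0 for every sequence and that tokens are flattened in row-major order. `make_positions` is a hypothetical helper:

```python
def make_positions(batch_size: int, seq_length: int) -> list:
    """Flattened position index for every token of a (batch, seq) input."""
    return [pos for _ in range(batch_size) for pos in range(seq_length)]


positions = make_positions(batch_size=2, seq_length=4)
print(positions)
```

The NumPy equivalent is `np.tile(np.arange(seq_length, dtype=np.int32), batch_size)`, which can be passed to the kernel as the `positions` argument.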

In [None]:
# TODO: Implement combined token + position embedding lookup

@cuda.jit
def combined_embedding_kernel(token_embed_table, pos_embed_table, 
                               token_ids, positions, output):
    """
    Combined token and positional embedding lookup.
    
    Args:
        token_embed_table: (vocab_size, embedding_dim)
        pos_embed_table: (max_seq_length, embedding_dim)
        token_ids: (total_tokens,)
        positions: (total_tokens,) - position index for each token
        output: (total_tokens, embedding_dim)
    
    Output should be: token_embedding[token_id] + pos_embedding[position]
    """
    # TODO: Implement this!
    # Hint: Similar to embedding_lookup_optimized_kernel, but add two lookups
    pass


# When implemented, test with:
# vocab_size = 100
# max_seq_length = 512
# embedding_dim = 64
# token_embed = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
# pos_embed = np.random.randn(max_seq_length, embedding_dim).astype(np.float32)
# token_ids = np.array([5, 10, 15, 20], dtype=np.int32)
# positions = np.array([0, 1, 2, 3], dtype=np.int32)
# 
# Expected: token_embed[5] + pos_embed[0], token_embed[10] + pos_embed[1], ...

<details>
<summary>💡 Hint</summary>

```python
@cuda.jit
def combined_embedding_kernel(token_embed_table, pos_embed_table, 
                               token_ids, positions, output):
    token_idx = cuda.blockIdx.x
    tx = cuda.threadIdx.x
    
    _, embedding_dim = token_embed_table.shape
    total_tokens = token_ids.shape[0]
    
    if token_idx >= total_tokens:
        return
    
    token_id = token_ids[token_idx]
    position = positions[token_idx]
    
    for embed_idx in range(tx, embedding_dim, cuda.blockDim.x):
        # Add both embeddings!
        output[token_idx, embed_idx] = (
            token_embed_table[token_id, embed_idx] + 
            pos_embed_table[position, embed_idx]
        )
```
</details>

---

## 🎉 Checkpoint

Congratulations! You've learned:

- ✅ **How embeddings work** - The foundation of all language models
- ✅ **Memory access patterns** - Why embedding lookup is memory-bound
- ✅ **Custom CUDA kernels** - Forward and backward passes
- ✅ **Atomic operations** - Essential for gradient accumulation
- ✅ **Performance analysis** - Understanding bandwidth limitations

You now understand one of the most fundamental operations in deep learning!

---

## 🚀 Challenge (Optional)

**Advanced Challenge: Implement Sparse Gradient Updates**

In a single training batch, most embedding rows receive no gradient at all. Instead of computing gradients for the entire embedding table, you can:

1. Find unique token IDs in the batch
2. Only compute gradients for those rows
3. Use sparse tensor representations

This is how efficient embedding implementations work in practice!
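
A CPU reference for these steps is sketched below. It can serve as a correctness check for a GPU implementation and is not a performance recipe. The name `sparse_embedding_grad` and the toy inputs are hypothetical:

```python
def sparse_embedding_grad(grad_output, token_ids):
    """Return (unique_ids, grad_rows): gradients only for rows used in the batch."""
    accumulated = {}
    for grad_row, token_id in zip(grad_output, token_ids):
        if token_id not in accumulated:
            accumulated[token_id] = list(grad_row)  # copy so inputs stay untouched
        else:
            row = accumulated[token_id]
            for j, g in enumerate(grad_row):
                row[j] += g
    unique_ids = sorted(accumulated)
    return unique_ids, [accumulated[token_id] for token_id in unique_ids]


token_ids = [0, 1, 0, 2, 1, 0]
grad_output = [[1.0, 1.0]] * len(token_ids)

unique_ids, grad_rows = sparse_embedding_grad(grad_output, token_ids)
print(unique_ids)
print(grad_rows)
```

The `(unique_ids, grad_rows)` pair mirrors the index/value layout of a sparse gradient. PyTorch exposes the same idea through `nn.Embedding(..., sparse=True)`.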

---

## 📖 Further Reading

- [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781) - The original Word2Vec paper that popularized embeddings
- [Distributed Representations of Words and Phrases and their Compositionality](https://arxiv.org/abs/1310.4546) - Negative sampling
- [Rotary Position Embedding (RoPE)](https://arxiv.org/abs/2104.09864) - Used in modern LLMs like Llama
- [FlashAttention](https://github.com/Dao-AILab/flash-attention) - Reference repository for memory-efficient fused GPU attention kernels

---

## 🧹 Cleanup

In [None]:
import gc

# Clean up large arrays safely
if 'embed_table' in dir():
    del embed_table
if 'd_embed' in dir():
    del d_embed
if 'd_tokens' in dir():
    del d_tokens
if 'd_output' in dir():
    del d_output

gc.collect()

if HAS_TORCH:
    torch.cuda.empty_cache()

try:
    cuda.current_context().reset()
except Exception:
    pass  # Context might not exist

print("✅ GPU memory cleared!")
print("\n➡️ Ready for Lab 1.3.4: CuPy Integration")