# Word2Vec Implementation in PyTorch
## A Beginner-Friendly Step-by-Step Guide

This notebook implements **Word2Vec** (both **Skip-Gram** and **CBOW**) from scratch using PyTorch.

It serves as a practical companion to the theory notebook (`BTL_WordEmbedding.ipynb`).

---

### What you'll learn:
1. How to preprocess text for Word2Vec
2. How to build training data for Skip-Gram and CBOW
3. How to implement both models in PyTorch
4. How to train and visualize word embeddings

### Prerequisites:
- Basic Python knowledge
- Understanding of the theory (see `BTL_WordEmbedding.ipynb`)

---

### Table of Contents:
1. [Setup & Imports](#1.-Setup-&-Imports)
2. [Data Preparation](#2.-Data-Preparation)
3. [Training Data Generation](#3.-Training-Data-Generation)
4. [PyTorch Dataset & DataLoader](#4.-PyTorch-Dataset-&-DataLoader)
5. [Skip-Gram Model](#5.-Skip-Gram-Model)
6. [CBOW Model](#6.-CBOW-Model)
7. [Training Loop](#7.-Training-Loop)
8. [Evaluation & Visualization](#8.-Evaluation-&-Visualization)
9. [Playing with Results](#9.-Playing-with-Results)
10. [Summary & Next Steps](#10.-Summary-&-Next-Steps)

---
## 1. Setup & Imports

First, let's import all the libraries we need. Each library has a specific purpose:

| Library | Purpose |
|---------|--------|
| `torch` | Deep learning framework (PyTorch) |
| `numpy` | Numerical computations |
| `matplotlib` | Plotting and visualization |
| `sklearn` | Machine learning utilities (t-SNE) |
| `collections` | Data structures (Counter) |

In [None]:
# ============================================
# SECTION 1: IMPORTS AND SETUP
# ============================================

# Core Python libraries
import random                          # For random number generation
from collections import Counter        # For counting word frequencies

# NumPy for numerical operations
import numpy as np                     # Array operations

# PyTorch for deep learning
import torch                           # Core PyTorch
import torch.nn as nn                  # Neural network modules
import torch.nn.functional as F        # Activation functions, loss functions
import torch.optim as optim            # Optimizers (Adam, SGD, etc.)
from torch.utils.data import Dataset, DataLoader  # Data handling

# Visualization
import matplotlib.pyplot as plt        # Plotting
from sklearn.manifold import TSNE      # Dimensionality reduction for visualization

# Make plots look nicer
plt.style.use('seaborn-v0_8-whitegrid')

print("All imports successful!")

In [None]:
# ============================================
# SET RANDOM SEEDS FOR REPRODUCIBILITY
# ============================================
# Why? So we get the same results every time we run the notebook

SEED = 42  # The answer to everything :)

random.seed(SEED)           # Python's random
np.random.seed(SEED)        # NumPy's random
torch.manual_seed(SEED)     # PyTorch's random

# If using GPU, set seed for that too
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

print(f"Random seed set to: {SEED}")

In [None]:
# ============================================
# CHECK DEVICE (CPU or GPU)
# ============================================
# GPU makes training faster, but CPU works fine for our small example

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("=" * 50)
print("SYSTEM INFORMATION")
print("=" * 50)
print(f"PyTorch version: {torch.__version__}")
print(f"Using device: {device}")

if device.type == 'cuda':
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("No GPU available, using CPU (this is fine for our small example!)")

print("=" * 50)

---
## 2. Data Preparation

Before we can train Word2Vec, we need to prepare our text data. This involves:

1. **Creating a corpus** (collection of sentences)
2. **Tokenization** (breaking sentences into words)
3. **Building vocabulary** (mapping words to numbers)

Let's go step by step!

### 2.1 Our Custom Corpus

We'll use a simple, hand-crafted corpus about **animals and food**. 

**Why a simple corpus?**
- Easy to understand what the model learns
- Easy to debug when something goes wrong
- We can predict what words should be similar:
  - `"cat"` and `"dog"` should be similar (both are pets)
  - `"ate"` and `"drinks"` should be similar (both are actions)
  - `"fish"` appears in multiple contexts (food for cats, swims in water)
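The expectation that `"cat"` and `"dog"` end up close rests on the idea that words sharing contexts get similar vectors. As a rough sanity check (not part of the original notebook), the sketch below takes three sentences copied from the corpus, counts each word's neighbours within a window of 2, and compares those raw count vectors with cosine similarity. `context_counts` and `count_cosine` are hypothetical helper names.

```python
import math
from collections import Counter

# Three sentences copied from the corpus below
mini_corpus = [
    "the cat drinks milk",
    "the dog drinks water",
    "the fish swims in the water",
]

def context_counts(sentences, target, window=2):
    """Count how often each word appears within `window` positions of `target`."""
    counts = Counter()
    for sentence in sentences:
        tokens = sentence.split()
        for pos, word in enumerate(tokens):
            if word == target:
                lo, hi = max(0, pos - window), min(len(tokens), pos + window + 1)
                counts.update(tokens[i] for i in range(lo, hi) if i != pos)
    return counts

def count_cosine(a, b):
    """Cosine similarity between two sparse count vectors stored as Counters."""
    dot = sum(a[w] * b[w] for w in a)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

cat_ctx = context_counts(mini_corpus, "cat")
dog_ctx = context_counts(mini_corpus, "dog")
fish_ctx = context_counts(mini_corpus, "fish")

print(f"cos(cat, dog)  = {count_cosine(cat_ctx, dog_ctx):.3f}")   # 0.667: share "the" and "drinks"
print(f"cos(cat, fish) = {count_cosine(cat_ctx, fish_ctx):.3f}")  # 0.333: share only "the"
```

Word2Vec learns dense vectors rather than raw counts, but this shared-context signal is what it exploits.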

In [None]:
# ============================================
# 2.1 CREATE OUR CUSTOM CORPUS
# ============================================
# A corpus is simply a collection of text (sentences in our case)

corpus = [
    "the cat sat on the mat",
    "the dog sat on the rug",
    "the cat ate the fish",
    "the dog ate the meat",
    "the cat and the dog are friends",
    "the fish swims in the water",
    "the cat drinks milk",
    "the dog drinks water",
    "the cat chased the dog",
    "the dog chased the cat",
    "i love my cat",
    "i love my dog",
    "the cat is sleeping",
    "the dog is sleeping",
    "cats and dogs are pets",
]

# Let's see our corpus
print("=" * 60)
print("OUR CORPUS")
print("=" * 60)
print()

for i, sentence in enumerate(corpus, 1):
    print(f"  Sentence {i:2d}: \"{sentence}\"")

print()
print(f"Total sentences: {len(corpus)}")
print("=" * 60)

### 2.2 Tokenization

**Tokenization** = Breaking text into individual units called **tokens** (usually words)

For example:
```
"the cat sat on the mat"  →  ["the", "cat", "sat", "on", "the", "mat"]
```

Our tokenizer will:
1. Convert to lowercase (our corpus is already lowercase, but this keeps the tokenizer correct for other input)
2. Split on whitespace

In real projects, you'd also handle punctuation, special characters, etc.
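For input that does contain punctuation, one option is a regular expression that keeps runs of ASCII letters, digits and apostrophes. This is a minimal sketch under that assumption (`tokenize_with_punct` is a hypothetical name that the rest of the notebook does not use); it ignores hyphenated words, decimals, non-ASCII letters and other cases a production tokenizer would handle.

```python
import re

def tokenize_with_punct(text):
    """
    Hypothetical tokenizer sketch: lowercase the text, then keep runs of
    letters, digits and apostrophes. Everything else (spaces, commas,
    periods, exclamation marks, ...) acts as a separator and is dropped.
    """
    return re.findall(r"[a-z0-9']+", text.lower())

print(tokenize_with_punct("The cat sat on the mat!"))
# ['the', 'cat', 'sat', 'on', 'the', 'mat']
print(tokenize_with_punct("I love my dog, and my dog doesn't bite."))
# ['i', 'love', 'my', 'dog', 'and', 'my', 'dog', "doesn't", 'bite']
```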

In [None]:
# ============================================
# 2.2 TOKENIZATION
# ============================================

def tokenize(text):
    """
    Simple tokenizer that splits text into words.
    
    Args:
        text (str): Input sentence
        
    Returns:
        list: List of words (tokens)
    
    Example:
        >>> tokenize("the cat sat")
        ['the', 'cat', 'sat']
    """
    # Convert to lowercase and split by spaces
    return text.lower().split()

# Let's test our tokenizer on one sentence
test_sentence = "the cat sat on the mat"
tokens = tokenize(test_sentence)

print("=" * 60)
print("TOKENIZATION EXAMPLE")
print("=" * 60)
print()
print(f"  Input:  \"{test_sentence}\"")
print(f"  Output: {tokens}")
print()
print(f"  Number of tokens: {len(tokens)}")
print("=" * 60)

In [None]:
# ============================================
# TOKENIZE THE ENTIRE CORPUS
# ============================================

# Apply tokenization to all sentences
tokenized_corpus = [tokenize(sentence) for sentence in corpus]

print("=" * 60)
print("TOKENIZED CORPUS")
print("=" * 60)
print()

for i, tokens in enumerate(tokenized_corpus, 1):
    print(f"  Sentence {i:2d}: {tokens}")

print()

# Flatten into a single list of all words
all_words = [word for sentence in tokenized_corpus for word in sentence]

print(f"Total words in corpus: {len(all_words)}")
print(f"First 20 words: {all_words[:20]}")
print("=" * 60)

### 2.3 Building the Vocabulary

**Vocabulary** = The set of all unique words in our corpus

Neural networks work with **numbers**, not strings! So we need:

1. **`word_to_idx`**: Dictionary mapping each word to a unique integer
   - Example: `{"the": 0, "cat": 1, "dog": 2, ...}`

2. **`idx_to_word`**: Dictionary mapping each integer back to its word
   - Example: `{0: "the", 1: "cat", 2: "dog", ...}`

3. **`vocab_size`**: Total number of unique words
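The three pieces above fit together as an encode/decode round trip. This sketch uses a hypothetical two-sentence toy corpus and `toy_`-prefixed names so it does not overwrite the notebook's real vocabulary. It sorts by descending frequency and breaks ties alphabetically, the same ordering `build_vocabulary` uses in the next cell.

```python
from collections import Counter

# Hypothetical toy corpus with a frequency tie between "sat" and "the"
toy_corpus = [["the", "cat", "sat"], ["the", "dog", "sat"]]

toy_counts = Counter(word for sentence in toy_corpus for word in sentence)
# Sort by descending frequency, then alphabetically, so ties get stable indices
toy_vocab = sorted(toy_counts, key=lambda w: (-toy_counts[w], w))
toy_word_to_idx = {w: i for i, w in enumerate(toy_vocab)}
toy_idx_to_word = {i: w for w, i in toy_word_to_idx.items()}

def toy_encode(tokens):
    """Map a list of words to their integer indices."""
    return [toy_word_to_idx[w] for w in tokens]

def toy_decode(indices):
    """Map a list of indices back to words."""
    return [toy_idx_to_word[i] for i in indices]

print(toy_word_to_idx)             # {'sat': 0, 'the': 1, 'cat': 2, 'dog': 3}
ids = toy_encode(["the", "dog", "sat"])
print(ids)                         # [1, 3, 0]
print(toy_decode(ids))             # ['the', 'dog', 'sat']
```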

In [None]:
# ============================================
# 2.3 BUILD THE VOCABULARY
# ============================================

def build_vocabulary(tokenized_corpus):
    """
    Build vocabulary from tokenized corpus.
    
    Args:
        tokenized_corpus: List of tokenized sentences
        
    Returns:
        word_to_idx: Dictionary mapping word -> index
        idx_to_word: Dictionary mapping index -> word
        word_counts: Counter object with word frequencies
    """
    # Step 1: Count word frequencies
    word_counts = Counter()
    for sentence in tokenized_corpus:
        word_counts.update(sentence)
    
    print("Step 1: Count word frequencies")
    print(f"  Found {len(word_counts)} unique words")
    print(f"  Most common: {word_counts.most_common(5)}")
    print()
    
    # Step 2: Sort words by frequency (most common first), breaking ties
    # alphabetically so the word -> index mapping is the same on every run
    sorted_words = sorted(word_counts.keys(), 
                          key=lambda x: (-word_counts[x], x))
    
    print("Step 2: Sort words by frequency")
    print(f"  First 10 words: {sorted_words[:10]}")
    print()
    
    # Step 3: Create word-to-index mapping
    word_to_idx = {word: idx for idx, word in enumerate(sorted_words)}
    
    print("Step 3: Create word_to_idx mapping")
    print(f"  'the' -> {word_to_idx['the']}")
    print(f"  'cat' -> {word_to_idx['cat']}")
    print(f"  'dog' -> {word_to_idx['dog']}")
    print()
    
    # Step 4: Create index-to-word mapping (reverse of word_to_idx)
    idx_to_word = {idx: word for word, idx in word_to_idx.items()}
    
    print("Step 4: Create idx_to_word mapping")
    print(f"  0 -> '{idx_to_word[0]}'")
    print(f"  1 -> '{idx_to_word[1]}'")
    print(f"  2 -> '{idx_to_word[2]}'")
    
    return word_to_idx, idx_to_word, word_counts

# Build vocabulary
print("=" * 60)
print("BUILDING VOCABULARY")
print("=" * 60)
print()

word_to_idx, idx_to_word, word_counts = build_vocabulary(tokenized_corpus)
vocab_size = len(word_to_idx)

print()
print("=" * 60)
print(f"VOCABULARY SIZE: {vocab_size} unique words")
print("=" * 60)

In [None]:
# ============================================
# DISPLAY FULL VOCABULARY
# ============================================

print("=" * 60)
print("COMPLETE VOCABULARY")
print("=" * 60)
print()
print(f"{'Index':<8} {'Word':<15} {'Frequency':<10}")
print("-" * 35)

for idx in range(vocab_size):
    word = idx_to_word[idx]
    freq = word_counts[word]
    print(f"{idx:<8} {word:<15} {freq:<10}")

print("-" * 35)
print(f"Total: {vocab_size} words")
print("=" * 60)

### 2.4 Visualizing Word Frequencies

Let's visualize our vocabulary to better understand our data.

In [None]:
# ============================================
# 2.4 VISUALIZE WORD FREQUENCIES
# ============================================

# Get words and their frequencies
words = [idx_to_word[i] for i in range(vocab_size)]
frequencies = [word_counts[word] for word in words]

# Create figure with 2 subplots
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Bar chart of word frequencies
ax1 = axes[0]
bars = ax1.barh(words[::-1], frequencies[::-1], color='steelblue', edgecolor='navy')
ax1.set_xlabel('Frequency', fontsize=12)
ax1.set_ylabel('Word', fontsize=12)
ax1.set_title('Word Frequencies in Our Corpus', fontsize=14, fontweight='bold')

# Add value labels on bars
for bar, freq in zip(bars, frequencies[::-1]):
    ax1.text(bar.get_width() + 0.1, bar.get_y() + bar.get_height()/2, 
             str(freq), va='center', fontsize=9)

# Plot 2: Pie chart of top words vs others
ax2 = axes[1]
top_n = 5
top_words = words[:top_n]
top_freqs = frequencies[:top_n]
other_freq = sum(frequencies[top_n:])

pie_labels = top_words + ['Others']
pie_sizes = top_freqs + [other_freq]
colors = plt.cm.Blues(np.linspace(0.3, 0.9, len(pie_labels)))

wedges, texts, autotexts = ax2.pie(pie_sizes, labels=pie_labels, autopct='%1.1f%%',
                                    colors=colors, startangle=90)
ax2.set_title(f'Top {top_n} Words vs Others', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

# Print summary statistics
print("=" * 50)
print("VOCABULARY STATISTICS")
print("=" * 50)
print(f"Total unique words: {vocab_size}")
print(f"Total word occurrences: {sum(frequencies)}")
print(f"Most common word: '{words[0]}' (appears {frequencies[0]} times)")
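# List every word tied for the lowest frequency, not just the last three
min_freq = min(frequencies)
least_common = [w for w, f in zip(words, frequencies) if f == min_freq]
print(f"Least common words: {least_common} (appear {min_freq} time(s) each)")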
print(f"Average frequency: {sum(frequencies)/len(frequencies):.2f}")
print("=" * 50)

---
## 3. Training Data Generation

Now we need to create training data for our models:

- **Skip-Gram**: Given a center word, predict context words
- **CBOW**: Given context words, predict the center word

Both use the concept of a **context window**.

### 3.1 Understanding the Context Window

The **context window** defines how many words on each side of the center word we consider as "context".

**Example with window size = 2:**

```
Sentence: "the  cat  sat  on  the  mat"
                      ↑
              center word = "sat"
              
Context window (m=2):
    ← 2 words →  sat  ← 2 words →
    "the" "cat"       "on" "the"
    
So context words for "sat" are: ["the", "cat", "on", "the"]
```
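The window can also be expressed with list slicing. This compact sketch (the helper name `get_context_slice` is hypothetical) returns the same context words as the position-by-position loop used in this section. Python slices already stop at the end of the list, so only the left start index needs clamping: a negative start would wrap around to the end of the list.

```python
def get_context_slice(tokens, center_idx, window_size):
    """Return up to `window_size` words on each side of tokens[center_idx]."""
    left = tokens[max(0, center_idx - window_size):center_idx]      # clamp the start at 0
    right = tokens[center_idx + 1:center_idx + 1 + window_size]     # slicing clips at the end
    return left + right

example = "the cat sat on the mat".split()
print(get_context_slice(example, 2, 2))  # ['the', 'cat', 'on', 'the']
print(get_context_slice(example, 0, 2))  # ['cat', 'sat']
print(get_context_slice(example, 5, 2))  # ['on', 'the']
```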

Let's visualize this:

In [None]:
# ============================================
# 3.1 VISUALIZE CONTEXT WINDOW
# ============================================

def visualize_context_window(sentence, center_idx, window_size):
    """
    Visualize the context window around a center word.
    
    Args:
        sentence: String sentence
        center_idx: Index of the center word
        window_size: Number of words on each side to consider
    """
    tokens = sentence.split()
    
    print("=" * 70)
    print("CONTEXT WINDOW VISUALIZATION")
    print("=" * 70)
    print()
    print(f"Sentence: \"{sentence}\"")
    print(f"Window size (m): {window_size}")
    print(f"Center word position: {center_idx}")
    print(f"Center word: \"{tokens[center_idx]}\"")
    print()
    
    # Build visualization string
    visual = ""
    for i, token in enumerate(tokens):
        if i == center_idx:
            visual += f"[{token.upper()}] "  # Center word in brackets, uppercase
        elif center_idx - window_size <= i <= center_idx + window_size and i != center_idx:
            visual += f"({token}) "  # Context words in parentheses
        else:
            visual += f" {token}  "  # Other words
    
    print("Visualization:")
    print(f"  {visual}")
    print()
    print("  Legend: [CENTER] (context) other")
    print()
    
    # Get context words
    context_words = []
    context_positions = []
    for offset in range(-window_size, window_size + 1):
        if offset != 0:  # Skip center word
            pos = center_idx + offset
            if 0 <= pos < len(tokens):
                context_words.append(tokens[pos])
                context_positions.append(pos)
    
    print(f"Context words: {context_words}")
    print(f"Context positions: {context_positions}")
    print("=" * 70)
    
    return context_words

# Test with different center positions
sentence = "the cat sat on the mat"
print()
visualize_context_window(sentence, center_idx=2, window_size=2)  # "sat" as center

In [None]:
# Let's see what happens at the edge of a sentence
print("\nWhat happens at the BEGINNING of a sentence?")
print("(We can only get words to the RIGHT)")
print()
visualize_context_window(sentence, center_idx=0, window_size=2)  # "the" at start

print()
print("\nWhat happens at the END of a sentence?")
print("(We can only get words to the LEFT)")
print()
visualize_context_window(sentence, center_idx=5, window_size=2)  # "mat" at end

### 3.2 Skip-Gram Training Pairs

**Skip-Gram Goal**: Given a center word, predict each context word.

For each center word, we create **multiple training pairs** - one for each context word.

```
For sentence "the cat sat on the mat" with center="sat" and window=2:

Training pairs (center, context):
    (sat, the)   ← predict "the" from "sat"
    (sat, cat)   ← predict "cat" from "sat"
    (sat, on)    ← predict "on" from "sat"
    (sat, the)   ← predict "the" from "sat"
```
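How many pairs does one sentence produce? In a sentence of `n` words there are `n - d` position pairs at distance `d`, and each yields two ordered (center, context) pairs. Summing over `d = 1..m` gives `2 * (m*n - m*(m+1)/2) = 2*m*n - m*(m+1)` for `n >= m`, which is `4*n - 6` for `m = 2`. The sketch below checks this closed form against brute-force enumeration; `sentence_lengths` is transcribed by hand from the 15-sentence corpus, so its total should agree with `len(skipgram_pairs)` printed by the next cell.

```python
def count_pairs_brute_force(n, m):
    """Count (center, context) position pairs by enumerating every window."""
    total = 0
    for center in range(n):
        for offset in range(-m, m + 1):
            if offset != 0 and 0 <= center + offset < n:
                total += 1
    return total

def count_pairs_formula(n, m):
    """Closed form 2*m*n - m*(m+1), valid when the sentence has n >= m words."""
    return 2 * m * n - m * (m + 1)

# Sentence lengths of the 15-sentence corpus, transcribed by hand
sentence_lengths = [6, 6, 5, 5, 7, 6, 4, 4, 5, 5, 4, 4, 4, 4, 5]

for n in sorted(set(sentence_lengths)):
    assert count_pairs_brute_force(n, 2) == count_pairs_formula(n, 2)
    print(f"n={n}: {count_pairs_formula(n, 2)} pairs")

corpus_total = sum(count_pairs_formula(n, 2) for n in sentence_lengths)
print("Total for the corpus:", corpus_total)  # 206
```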

Let's generate these pairs:

In [None]:
# ============================================
# 3.2 GENERATE SKIP-GRAM TRAINING PAIRS
# ============================================

def generate_skipgram_pairs(tokenized_corpus, word_to_idx, window_size=2, verbose=False):
    """
    Generate Skip-Gram training pairs.
    
    For each word in the corpus:
        - Treat it as the center word
        - Create pairs with each context word within the window
    
    Args:
        tokenized_corpus: List of tokenized sentences
        word_to_idx: Dictionary mapping words to indices
        window_size: Number of words on each side to consider
        verbose: If True, print details for first sentence
        
    Returns:
        pairs: List of (center_idx, context_idx) tuples
    """
    pairs = []
    
    for sent_idx, sentence in enumerate(tokenized_corpus):
        # Print details for first sentence only
        if verbose and sent_idx == 0:
            print(f"Processing sentence: {sentence}")
            print("-" * 50)
        
        for center_pos, center_word in enumerate(sentence):
            center_idx = word_to_idx[center_word]
            
            if verbose and sent_idx == 0:
                print(f"\n  Center word: '{center_word}' (position {center_pos}, index {center_idx})")
            
            # Get context words within the window
            for offset in range(-window_size, window_size + 1):
                if offset == 0:  # Skip the center word itself
                    continue
                
                context_pos = center_pos + offset
                
                # Check if position is valid (within sentence bounds)
                if 0 <= context_pos < len(sentence):
                    context_word = sentence[context_pos]
                    context_idx = word_to_idx[context_word]
                    
                    pairs.append((center_idx, context_idx))
                    
                    if verbose and sent_idx == 0:
                        print(f"    -> Context: '{context_word}' (position {context_pos}, index {context_idx})")
                        print(f"       Pair: ({center_idx}, {context_idx}) = ('{center_word}', '{context_word}')")
    
    return pairs

# Set window size
WINDOW_SIZE = 2

print("=" * 70)
print("GENERATING SKIP-GRAM TRAINING PAIRS")
print("=" * 70)
print(f"Window size: {WINDOW_SIZE}")
print()

# Generate pairs with verbose output for first sentence
skipgram_pairs = generate_skipgram_pairs(tokenized_corpus, word_to_idx, WINDOW_SIZE, verbose=True)

print()
print("=" * 70)
print(f"Total Skip-Gram pairs generated: {len(skipgram_pairs)}")
print("=" * 70)

In [None]:
# ============================================
# DISPLAY SAMPLE SKIP-GRAM PAIRS
# ============================================

print("=" * 70)
print("SAMPLE SKIP-GRAM PAIRS (First 30)")
print("=" * 70)
print()
print(f"{'#':<4} {'Center Idx':<12} {'Context Idx':<12} {'Center Word':<15} {'Context Word':<15}")
print("-" * 70)

for i, (center_idx, context_idx) in enumerate(skipgram_pairs[:30]):
    center_word = idx_to_word[center_idx]
    context_word = idx_to_word[context_idx]
    # !r wraps each word in quotes; :<15 pads it to the header's column width
    print(f"{i+1:<4} {center_idx:<12} {context_idx:<12} {center_word!r:<15} {context_word!r:<15}")

print("-" * 70)
print(f"... and {len(skipgram_pairs) - 30} more pairs")
print()

# Show statistics
print("STATISTICS:")
print(f"  Total pairs: {len(skipgram_pairs)}")
print(f"  Unique center words used: {len(set(p[0] for p in skipgram_pairs))}")
print(f"  Unique context words used: {len(set(p[1] for p in skipgram_pairs))}")
print("=" * 70)

### 3.3 CBOW Training Pairs

**CBOW Goal**: Given context words, predict the center word.

Unlike Skip-Gram, CBOW creates **one training sample** per center word, with **all context words** as input.

```
For sentence "the cat sat on the mat" with center="sat" and window=2:

Training pair (context_words, center):
    ([the, cat, on, the], sat)  ← predict "sat" from all context words
```

Inside the model, the **embeddings** of the context words are typically averaged (or summed) into a single hidden vector, which is then used to predict the center word.
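To make the averaging step concrete, here is a minimal sketch with a tiny, hypothetical embedding table. It uses `nn.Embedding.from_pretrained` with fixed weights (row `i` is `[4i, 4i+1, 4i+2, 4i+3]`) so the numbers are easy to check by hand and no random initialization is needed.

```python
import torch
import torch.nn as nn

# Hypothetical 10-word vocabulary with 4-dimensional embeddings
demo_weights = torch.arange(40, dtype=torch.float32).view(10, 4)
demo_embedding = nn.Embedding.from_pretrained(demo_weights)

# One CBOW sample: the indices of its four context words
demo_context = torch.tensor([0, 3, 5, 0])

context_vectors = demo_embedding(demo_context)  # shape (4, 4): one row per context word
hidden_mean = context_vectors.mean(dim=0)       # shape (4,): the averaged context vector
hidden_sum = context_vectors.sum(dim=0)         # summing is the other common choice

print(hidden_mean.tolist())  # [8.0, 9.0, 10.0, 11.0]
print(hidden_sum.tolist())   # [32.0, 36.0, 40.0, 44.0]
```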

In [None]:
# ============================================
# 3.3 GENERATE CBOW TRAINING PAIRS
# ============================================

def generate_cbow_pairs(tokenized_corpus, word_to_idx, window_size=2, verbose=False):
    """
    Generate CBOW training pairs.
    
    For each word in the corpus:
        - Treat it as the center word (target)
        - Collect all context words within the window (input)
    
    Args:
        tokenized_corpus: List of tokenized sentences
        word_to_idx: Dictionary mapping words to indices
        window_size: Number of words on each side to consider
        verbose: If True, print details for first sentence
        
    Returns:
        pairs: List of (context_indices_list, center_idx) tuples
    """
    pairs = []
    
    for sent_idx, sentence in enumerate(tokenized_corpus):
        if verbose and sent_idx == 0:
            print(f"Processing sentence: {sentence}")
            print("-" * 60)
        
        for center_pos, center_word in enumerate(sentence):
            center_idx = word_to_idx[center_word]
            
            # Collect ALL context words within the window
            context_indices = []
            context_words = []
            
            for offset in range(-window_size, window_size + 1):
                if offset == 0:  # Skip center word
                    continue
                
                context_pos = center_pos + offset
                
                if 0 <= context_pos < len(sentence):
                    ctx_word = sentence[context_pos]
                    ctx_idx = word_to_idx[ctx_word]
                    context_indices.append(ctx_idx)
                    context_words.append(ctx_word)
            
            # Only add if we have at least one context word
            if context_indices:
                pairs.append((context_indices, center_idx))
                
                if verbose and sent_idx == 0:
                    print(f"\n  Center word (TARGET): '{center_word}' (index {center_idx})")
                    print(f"  Context words (INPUT): {context_words}")
                    print(f"  Context indices: {context_indices}")
                    print(f"  Pair: ({context_indices}, {center_idx})")
    
    return pairs

print("=" * 70)
print("GENERATING CBOW TRAINING PAIRS")
print("=" * 70)
print(f"Window size: {WINDOW_SIZE}")
print()

# Generate CBOW pairs
cbow_pairs = generate_cbow_pairs(tokenized_corpus, word_to_idx, WINDOW_SIZE, verbose=True)

print()
print("=" * 70)
print(f"Total CBOW pairs generated: {len(cbow_pairs)}")
print("=" * 70)

In [None]:
# ============================================
# DISPLAY SAMPLE CBOW PAIRS
# ============================================

print("=" * 80)
print("SAMPLE CBOW PAIRS (First 15)")
print("=" * 80)
print()

for i, (context_indices, center_idx) in enumerate(cbow_pairs[:15]):
    center_word = idx_to_word[center_idx]
    context_words = [idx_to_word[idx] for idx in context_indices]
    
    print(f"Pair {i+1}:")
    print(f"  Context (INPUT):  {context_indices} = {context_words}")
    print(f"  Center (TARGET):  {center_idx} = '{center_word}'")
    print()

print("=" * 80)
print()

# Compare Skip-Gram vs CBOW
print("COMPARISON: Skip-Gram vs CBOW")
print("-" * 67)
print(f"{'Aspect':<25} {'Skip-Gram':<20} {'CBOW':<20}")
print("-" * 67)
print(f"{'Number of pairs':<25} {len(skipgram_pairs):<20} {len(cbow_pairs):<20}")
print(f"{'Input':<25} {'1 word (center)':<20} {'Multiple (context)':<20}")
print(f"{'Output':<25} {'1 word (context)':<20} {'1 word (center)':<20}")
print("-" * 67)

---
## 4. PyTorch Dataset & DataLoader

PyTorch provides two important classes for data handling:

1. **`Dataset`**: Stores your data samples and labels
   - Must implement `__len__()` and `__getitem__()`

2. **`DataLoader`**: Wraps a Dataset to enable:
   - **Batching**: Group samples together
   - **Shuffling**: Randomize order each epoch
   - **Parallel loading**: Load data in background worker processes (when `num_workers > 0`)
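The batching and shuffling behaviour described above can be seen on a toy dataset. In this sketch (`SquaresDataset` is a hypothetical example class), the number of batches is `ceil(len(dataset) / batch_size)`, `drop_last=True` discards the incomplete final batch, and passing a seeded `torch.Generator` through the `generator` argument makes the shuffled order repeatable.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    """Hypothetical toy Dataset: sample i is the pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.tensor(idx), torch.tensor(idx ** 2)

toy_ds = SquaresDataset(10)

# Batching: 10 samples with batch_size=4 -> batches of 4, 4 and 2 samples
toy_loader = DataLoader(toy_ds, batch_size=4, shuffle=False)
print(len(toy_loader))                        # 3
print([x.tolist() for x, _ in toy_loader])    # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# drop_last=True discards the incomplete final batch
print(len(DataLoader(toy_ds, batch_size=4, drop_last=True)))  # 2

def shuffled_order(seed):
    """Sample order produced by a shuffling loader that uses a seeded generator."""
    g = torch.Generator().manual_seed(seed)
    loader = DataLoader(toy_ds, batch_size=4, shuffle=True, generator=g)
    return [x.tolist() for x, _ in loader]

# Shuffling: the same seed gives the same order every time
print(shuffled_order(0) == shuffled_order(0))  # True
```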

Let's create custom Datasets for Skip-Gram and CBOW.

### 4.1 Skip-Gram Dataset

For Skip-Gram, each sample is simple:
- **Input**: center word index (single integer)
- **Target**: context word index (single integer)
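Because every Skip-Gram sample is just two integers, an alternative to a custom class is to convert all pairs to one tensor up front and wrap its columns in `torch.utils.data.TensorDataset`. A minimal sketch with hypothetical pairs:

```python
import torch
from torch.utils.data import TensorDataset

# Hypothetical (center_idx, context_idx) pairs in the same format as skipgram_pairs
demo_pairs = [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0)]

# Convert everything once: shape (num_pairs, 2), dtype int64
pairs_tensor = torch.tensor(demo_pairs, dtype=torch.long)

# Column 0 = center words (inputs), column 1 = context words (targets)
demo_dataset = TensorDataset(pairs_tensor[:, 0], pairs_tensor[:, 1])

sample_center, sample_context = demo_dataset[2]
print(len(demo_dataset), sample_center.item(), sample_context.item())  # 5 1 0
```

This avoids building two new tensors on every `__getitem__` call; the explicit `SkipGramDataset` class below is easier to read and extend, which suits a tutorial.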

In [None]:
# ============================================
# 4.1 SKIP-GRAM DATASET
# ============================================

class SkipGramDataset(Dataset):
    """
    PyTorch Dataset for Skip-Gram model.
    
    Each sample is a (center_word_idx, context_word_idx) pair.
    
    The model will learn to:
        Given center_word_idx -> predict context_word_idx
    """
    
    def __init__(self, pairs):
        """
        Args:
            pairs: List of (center_idx, context_idx) tuples
        """
        self.pairs = pairs
        print(f"SkipGramDataset created with {len(pairs)} samples")
    
    def __len__(self):
        """
        Return the total number of samples.
        
        This is required by PyTorch to know how many samples are in the dataset.
        """
        return len(self.pairs)
    
    def __getitem__(self, idx):
        """
        Return a single sample by index.
        
        Args:
            idx: Index of the sample to return
            
        Returns:
            center_idx: Tensor with center word index (input)
            context_idx: Tensor with context word index (target)
        """
        center_idx, context_idx = self.pairs[idx]
        
        # Convert to PyTorch tensors
        # PyTorch models expect tensors, not Python integers
        return torch.tensor(center_idx, dtype=torch.long), \
               torch.tensor(context_idx, dtype=torch.long)

# Create Skip-Gram dataset
print("=" * 60)
print("CREATING SKIP-GRAM DATASET")
print("=" * 60)
print()

skipgram_dataset = SkipGramDataset(skipgram_pairs)

print()
print("Testing the dataset:")
print("-" * 40)

# Test getting samples
for i in [0, 1, 2]:
    center, context = skipgram_dataset[i]
    print(f"Sample {i}:")
    print(f"  center_idx (input):  {center.item()} -> '{idx_to_word[center.item()]}'")
    print(f"  context_idx (target): {context.item()} -> '{idx_to_word[context.item()]}'")
    print(f"  Tensor types: center={center.dtype}, context={context.dtype}")
    print()

print("=" * 60)

### 4.2 CBOW Dataset

For CBOW, each sample has:
- **Input**: List of context word indices (variable length!)
- **Target**: center word index (single integer)

**Challenge**: Batching requires fixed-size tensors, but context lengths vary!

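**Solution**: Pad shorter contexts with a special value (-1) and track the actual length. Since -1 is not a valid row index for `nn.Embedding`, the model has to mask out (or replace) the padded positions before the embedding lookup.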
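One way a model could consume -1-padded contexts is a masked mean: replace the padding with any valid index for the lookup, zero out those rows, and divide by the real length. This is a sketch under that assumption (`masked_context_mean` is a hypothetical helper; the notebook's own CBOW model may handle padding differently), using a fixed embedding table so the result can be checked by hand.

```python
import torch
import torch.nn as nn

def masked_context_mean(embedding, padded, pad_value=-1):
    """
    Average the embeddings of the real (non-padded) context words.
    padded: LongTensor (batch, max_context_size) holding pad_value in unused slots.
    Assumes pad_value is negative, so clamp(min=0) turns it into a valid index.
    """
    mask = padded != pad_value                                     # (batch, max_context) bool
    vectors = embedding(padded.clamp(min=0)) * mask.unsqueeze(-1)  # padded rows become 0
    return vectors.sum(dim=1) / mask.sum(dim=1, keepdim=True)      # divide by true lengths

# Fixed toy table (hypothetical): row i = [2i, 2i + 1]
pad_demo_embedding = nn.Embedding.from_pretrained(torch.arange(12, dtype=torch.float32).view(6, 2))

# Two padded contexts with real lengths 4 and 2
pad_demo_batch = torch.tensor([[1, 2, 3, 4],
                               [5, 1, -1, -1]])

print(masked_context_mean(pad_demo_embedding, pad_demo_batch).tolist())  # [[5.0, 6.0], [6.0, 7.0]]
```

A different design is to reserve a dedicated padding index and pass it as `padding_idx` to `nn.Embedding`, whose row is then initialized to zeros and not updated by gradients. Either way, the division assumes every sample has at least one real context word, which `generate_cbow_pairs` guarantees.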

In [None]:
# ============================================
# 4.2 CBOW DATASET
# ============================================

class CBOWDataset(Dataset):
    """
    PyTorch Dataset for CBOW model.
    
    Each sample is a (context_word_indices, center_word_idx) pair.
    
    The model will learn to:
        Given context_word_indices -> predict center_word_idx
    
    Note: We pad context to a fixed size for batching.
    """
    
    def __init__(self, pairs, max_context_size):
        """
        Args:
            pairs: List of (context_indices_list, center_idx) tuples
            max_context_size: Maximum number of context words (2 * window_size)
        """
        self.pairs = pairs
        self.max_context_size = max_context_size
        print(f"CBOWDataset created with {len(pairs)} samples")
        print(f"Maximum context size: {max_context_size}")
    
    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        """
        Return a single sample by index.
        
        Returns:
            context_indices: Tensor of shape (max_context_size,) - padded with -1
            center_idx: Tensor with center word index
            context_length: Tensor with actual number of context words
        """
        context_indices, center_idx = self.pairs[idx]
        actual_length = len(context_indices)
        
        # Pad context to fixed size with -1
        # Example: [1, 2, 3] with max_size=4 -> [1, 2, 3, -1]
        padded_context = context_indices + [-1] * (self.max_context_size - actual_length)
        
        return (
            torch.tensor(padded_context, dtype=torch.long),      # Context (padded)
            torch.tensor(center_idx, dtype=torch.long),          # Target
            torch.tensor(actual_length, dtype=torch.long)        # Actual context length
        )

# Maximum context size = 2 * window_size (words on both sides)
MAX_CONTEXT_SIZE = 2 * WINDOW_SIZE

print("=" * 60)
print("CREATING CBOW DATASET")
print("=" * 60)
print()

cbow_dataset = CBOWDataset(cbow_pairs, MAX_CONTEXT_SIZE)

print()
print("Testing the dataset:")
print("-" * 40)

# Test getting samples - show different context lengths
for i in [0, 1, 2]:  # "the", "cat", "sat" from sentence 1: context lengths 2, 3 and 4
    context, center, length = cbow_dataset[i]
    
    # Get actual context words (non-padded)
    actual_context = context[:length].tolist()
    context_words = [idx_to_word[idx] for idx in actual_context]
    
    print(f"Sample {i}:")
    print(f"  Context tensor (padded):  {context.tolist()}")
    print(f"  Actual context length:    {length.item()}")
    print(f"  Actual context indices:   {actual_context}")
    print(f"  Context words:            {context_words}")
    print(f"  Center (target):          {center.item()} -> '{idx_to_word[center.item()]}'")
    print()

print("=" * 60)

### 4.3 Creating DataLoaders

DataLoaders handle:
- **Batching**: Group multiple samples into one tensor
- **Shuffling**: Randomize order (good for training)
- **Iteration**: Easy to loop over batches
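Fixed-size padding is one option; another is to let the `DataLoader` pad each batch only to the longest context it contains, through a custom `collate_fn` built on `torch.nn.utils.rnn.pad_sequence`. A minimal sketch with hypothetical samples in the same `(context_indices, center_idx)` format as `cbow_pairs` (`pad_collate` is a hypothetical name):

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical samples: (context indices of varying length, center index)
demo_samples = [([1, 2], 0), ([0, 2, 3], 1), ([0, 1, 3, 4], 2)]

def pad_collate(batch, pad_value=-1):
    """Pad the contexts in one batch to the longest context in THAT batch."""
    contexts = [torch.tensor(ctx, dtype=torch.long) for ctx, _ in batch]
    centers = torch.tensor([center for _, center in batch], dtype=torch.long)
    lengths = torch.tensor([len(ctx) for ctx in contexts], dtype=torch.long)
    padded = pad_sequence(contexts, batch_first=True, padding_value=pad_value)
    return padded, centers, lengths

# A plain list works as a map-style dataset (it has __len__ and __getitem__)
demo_loader = DataLoader(demo_samples, batch_size=2, shuffle=False, collate_fn=pad_collate)

for b_context, b_center, b_len in demo_loader:
    print(b_context.tolist(), b_center.tolist(), b_len.tolist())
# [[1, 2, -1], [0, 2, 3]] [0, 1] [2, 3]
# [[0, 1, 3, 4]] [2] [4]
```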

In [None]:
# ============================================
# 4.3 CREATE DATALOADERS
# ============================================

# Hyperparameter: batch size
BATCH_SIZE = 16  # Number of samples per batch

print("=" * 60)
print("CREATING DATALOADERS")
print("=" * 60)
print()
print(f"Batch size: {BATCH_SIZE}")
print()

# Create DataLoaders
skipgram_loader = DataLoader(
    skipgram_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,  # Shuffle for training
    drop_last=False  # Keep incomplete last batch
)

cbow_loader = DataLoader(
    cbow_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=False
)

print(f"Skip-Gram DataLoader:")
print(f"  Total samples: {len(skipgram_dataset)}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Number of batches: {len(skipgram_loader)}")
print()

print(f"CBOW DataLoader:")
print(f"  Total samples: {len(cbow_dataset)}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Number of batches: {len(cbow_loader)}")
print()
print("=" * 60)

In [None]:
# ============================================
# INSPECT A BATCH FROM EACH DATALOADER
# ============================================

print("=" * 70)
print("EXAMPLE BATCH FROM SKIP-GRAM DATALOADER")
print("=" * 70)
print()

# Get one batch from Skip-Gram
for batch_center, batch_context in skipgram_loader:
    print(f"Batch center words (input):")
    print(f"  Shape: {batch_center.shape}")
    print(f"  Values: {batch_center.tolist()}")
    print(f"  Words: {[idx_to_word[i] for i in batch_center.tolist()]}")
    print()
    print(f"Batch context words (target):")
    print(f"  Shape: {batch_context.shape}")
    print(f"  Values: {batch_context.tolist()}")
    print(f"  Words: {[idx_to_word[i] for i in batch_context.tolist()]}")
    break  # Only show first batch

print()
print("=" * 70)
print("EXAMPLE BATCH FROM CBOW DATALOADER")
print("=" * 70)
print()

# Get one batch from CBOW
for batch_context, batch_center, batch_lengths in cbow_loader:
    print(f"Batch context words (input):")
    print(f"  Shape: {batch_context.shape} (batch_size x max_context_size)")
    print(f"  First 3 samples: {batch_context[:3].tolist()}")
    print()
    print(f"Batch lengths:")
    print(f"  Shape: {batch_lengths.shape}")
    print(f"  Values: {batch_lengths.tolist()}")
    print()
    print(f"Batch center words (target):")
    print(f"  Shape: {batch_center.shape}")
    print(f"  Values: {batch_center.tolist()}")
    print(f"  Words: {[idx_to_word[i] for i in batch_center.tolist()]}")
    break  # Only show first batch

print()
print("=" * 70)

---
## 5. Skip-Gram Model

Now let's implement the Skip-Gram model in PyTorch!

We'll implement two versions:
1. **Basic Skip-Gram**: Uses full softmax (slower, but easier to understand)
2. **Skip-Gram with Negative Sampling**: Much more efficient for large vocabularies

### 5.1 Skip-Gram Architecture

```
┌─────────────────────────────────────────────────────────────────┐
│                     SKIP-GRAM ARCHITECTURE                       │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  INPUT: center_word_idx (e.g., 1 for "cat")                     │
│         ↓                                                        │
│  ┌──────────────────────────────────────────┐                   │
│  │         EMBEDDING LAYER                   │                   │
│  │   (vocab_size × embedding_dim matrix)    │                   │
│  │                                          │                   │
│  │   Row 0: [0.1, 0.2, ..., 0.5]  ← "the"  │                   │
│  │   Row 1: [0.3, 0.1, ..., 0.8]  ← "cat"  │                   │
│  │   Row 2: [0.2, 0.4, ..., 0.3]  ← "dog"  │                   │
│  │   ...                                    │                   │
│  │                                          │                   │
│  │   Lookup row 1 → word vector for "cat"  │                   │
│  └──────────────────────────────────────────┘                   │
│         ↓                                                        │
│  word_vector: [0.3, 0.1, 0.4, ..., 0.8]  (embedding_dim values) │
│         ↓                                                        │
│  ┌──────────────────────────────────────────┐                   │
│  │         LINEAR LAYER                      │                   │
│  │   (embedding_dim → vocab_size)           │                   │
│  │                                          │                   │
│  │   Computes score for EACH word in vocab  │                   │
│  └──────────────────────────────────────────┘                   │
│         ↓                                                        │
│  scores: [1.2, 3.5, 2.1, ..., 0.8]  (vocab_size values)         │
│         ↓                                                        │
│  ┌──────────────────────────────────────────┐                   │
│  │         SOFTMAX                           │                   │
│  │   Convert scores to probabilities        │                   │
│  │   (all values sum to 1.0)                │                   │
│  └──────────────────────────────────────────┘                   │
│         ↓                                                        │
│  OUTPUT: P(context_word | center_word) for all words            │
│          [0.05, 0.35, 0.15, ..., 0.02]                           │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘
```
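
The "Lookup row" step in the diagram is equivalent to multiplying a one-hot vector by the embedding matrix. `nn.Embedding` indexes the row directly and skips the multiplication. The small sketch below uses toy sizes (4 words, 3 dimensions, chosen only for illustration) and confirms that both approaches give the same vector:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, embedding_dim = 4, 3      # toy sizes, for illustration only
emb = nn.Embedding(vocab_size, embedding_dim)

word_idx = 1                          # row 1 = "cat" in the diagram above

# Option 1: direct row lookup (what nn.Embedding does)
lookup = emb(torch.tensor([word_idx]))            # shape (1, 3)

# Option 2: one-hot row vector times the full embedding matrix
one_hot = torch.zeros(1, vocab_size)
one_hot[0, word_idx] = 1.0
matmul = one_hot @ emb.weight                     # shape (1, 3)

print(torch.allclose(lookup, matmul))             # True
```

The lookup is preferred in practice because it costs one row copy, while the one-hot product touches the whole `vocab_size × embedding_dim` matrix.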

### 5.2 Basic Skip-Gram Implementation

Let's implement the basic version with full softmax first.

In [None]:
# ============================================
# 5.2 BASIC SKIP-GRAM MODEL
# ============================================

class SkipGramBasic(nn.Module):
    """
    Basic Skip-Gram model using full softmax.
    
    This is the simplest implementation:
    - Input: center word index
    - Output: log-probabilities over the whole vocabulary (every candidate context word)
    
    Architecture:
        center_idx -> Embedding -> Linear -> LogSoftmax -> log_probs
    """
    
    def __init__(self, vocab_size, embedding_dim):
        """
        Initialize the Skip-Gram model.
        
        Args:
            vocab_size: Number of unique words in vocabulary
            embedding_dim: Dimension of word embeddings (e.g., 50, 100, 300)
        """
        super(SkipGramBasic, self).__init__()
        
        # Save parameters
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        
        # =============================================
        # LAYER 1: Embedding Layer
        # =============================================
        # This is the main thing we're learning!
        # Maps each word index to a dense vector
        # Shape: (vocab_size, embedding_dim)
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        
        # =============================================
        # LAYER 2: Output Linear Layer
        # =============================================
        # Projects embedding to vocabulary size
        # Each output unit represents score for one word
        # Maps embedding_dim -> vocab_size; PyTorch stores the weight
        # as (vocab_size, embedding_dim) plus a bias of size vocab_size
        self.linear = nn.Linear(embedding_dim, vocab_size)
        
        # Initialize weights with small random values
        self._init_weights()
        
        print(f"SkipGramBasic model created:")
        print(f"  Vocab size: {vocab_size}")
        print(f"  Embedding dim: {embedding_dim}")
        print(f"  Embedding layer shape: ({vocab_size}, {embedding_dim})")
        print(f"  Linear layer: {embedding_dim} -> {vocab_size} (weight shape: ({vocab_size}, {embedding_dim}))")
    
    def _init_weights(self):
        """Initialize weights with small random values."""
        # Small values help with stable training
        initrange = 0.5 / self.embedding_dim
        self.embeddings.weight.data.uniform_(-initrange, initrange)
        self.linear.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
    
    def forward(self, center_idx):
        """
        Forward pass: predict context word probabilities from center word.
        
        Args:
            center_idx: Tensor of shape (batch_size,) with center word indices
            
        Returns:
            log_probs: Tensor of shape (batch_size, vocab_size) with log probabilities
        """
        # Step 1: Look up embeddings for center words
        # Input shape: (batch_size,)
        # Output shape: (batch_size, embedding_dim)
        embeds = self.embeddings(center_idx)
        
        # Step 2: Compute scores for all words in vocabulary
        # Input shape: (batch_size, embedding_dim)
        # Output shape: (batch_size, vocab_size)
        scores = self.linear(embeds)
        
        # Step 3: Convert scores to log probabilities
        # LogSoftmax is more numerically stable than Softmax + Log
        # Output shape: (batch_size, vocab_size)
        log_probs = F.log_softmax(scores, dim=1)
        
        return log_probs
    
    def get_word_embedding(self, word_idx):
        """
        Get the embedding vector for a specific word.
        
        Args:
            word_idx: Index of the word
            
        Returns:
            Embedding vector of shape (embedding_dim,)
        """
        return self.embeddings.weight[word_idx].detach().cpu()

# Hyperparameter: embedding dimension
EMBEDDING_DIM = 50

print("=" * 60)
print("CREATING SKIP-GRAM MODEL (BASIC)")
print("=" * 60)
print()

# Create the model
skipgram_basic = SkipGramBasic(vocab_size, EMBEDDING_DIM)

print()
print("=" * 60)
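
The comment in `forward` says `log_softmax` is more numerically stable than applying `softmax` and then `log`. The quick sketch below demonstrates this with an artificially large score gap (the value 200 is arbitrary). The small probability underflows to zero in float32, so its log becomes `-inf`. `log_softmax`, by contrast, computes the log-probability directly and stays finite.

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([[0.0, 200.0]])      # one very large score gap (toy values)

# Naive: exp(-200) underflows to 0 in float32, and log(0) = -inf
naive = torch.log(F.softmax(scores, dim=1))

# Stable: log_softmax subtracts the log-sum-exp, giving about -200 for the first entry
stable = F.log_softmax(scores, dim=1)

print(naive)
print(stable)
```

An `-inf` log-probability would turn the NLL loss into `inf` and break training, which is why the model returns `log_softmax` and pairs it with `nn.NLLLoss`.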

In [None]:
# ============================================
# INSPECT THE MODEL
# ============================================

print("=" * 60)
print("MODEL INSPECTION")
print("=" * 60)
print()

# Print model architecture
print("Model Architecture:")
print(skipgram_basic)
print()

# Count parameters
total_params = sum(p.numel() for p in skipgram_basic.parameters())
trainable_params = sum(p.numel() for p in skipgram_basic.parameters() if p.requires_grad)

print("Parameter Count:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print()

print("Parameter Breakdown:")
for name, param in skipgram_basic.named_parameters():
    print(f"  {name}: {param.shape} = {param.numel():,} parameters")

print()
print("=" * 60)
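
The parameter breakdown above can be checked by hand. The embedding table holds `vocab_size × embedding_dim` values, and the linear layer holds the same number of weights plus `vocab_size` biases. The minimal sketch below uses hypothetical sizes V = 1000 and D = 50 rather than this notebook's vocabulary:

```python
import torch.nn as nn

def skipgram_basic_param_count(vocab_size, embedding_dim):
    """Closed-form parameter count for Embedding(V, D) + Linear(D, V)."""
    embedding = vocab_size * embedding_dim        # input embedding table
    linear_w = embedding_dim * vocab_size         # output weight, stored as (V, D)
    linear_b = vocab_size                         # output bias
    return embedding + linear_w + linear_b

V, D = 1000, 50   # hypothetical sizes for illustration
layers = nn.ModuleList([nn.Embedding(V, D), nn.Linear(D, V)])
actual = sum(p.numel() for p in layers.parameters())

print(skipgram_basic_param_count(V, D), actual)   # 101000 101000
```

For these sizes the total is 50,000 + 50,000 + 1,000 = 101,000. Both large terms grow linearly with the vocabulary, which is the cost that negative sampling (below) targets.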

### 5.3 Forward Pass Walkthrough

Let's trace through the forward pass step by step to understand exactly what happens.

In [None]:
# ============================================
# 5.3 FORWARD PASS WALKTHROUGH
# ============================================

print("=" * 70)
print("FORWARD PASS WALKTHROUGH")
print("=" * 70)
print()

# Let's trace what happens when we input "cat"
center_word = "cat"
center_idx = word_to_idx[center_word]
center_tensor = torch.tensor([center_idx])  # Shape: (1,) - batch of 1

print(f"INPUT:")
print(f"  Word: '{center_word}'")
print(f"  Index: {center_idx}")
print(f"  Tensor shape: {center_tensor.shape}")
print()

# Set model to evaluation mode (disables dropout, etc.)
skipgram_basic.eval()

with torch.no_grad():  # Don't compute gradients for this demo
    
    # ========== STEP 1: Embedding Lookup ==========
    embedding = skipgram_basic.embeddings(center_tensor)
    
    print("STEP 1: EMBEDDING LOOKUP")
    print("-" * 50)
    print(f"  Input shape: {center_tensor.shape}")
    print(f"  Output shape: {embedding.shape}")
    print(f"  Embedding for '{center_word}':")
    print(f"    First 10 values: {embedding[0, :10].numpy().round(4)}")
    print(f"    Min: {embedding.min().item():.4f}, Max: {embedding.max().item():.4f}")
    print()
    
    # ========== STEP 2: Linear Transformation ==========
    scores = skipgram_basic.linear(embedding)
    
    print("STEP 2: LINEAR TRANSFORMATION (Scores)")
    print("-" * 50)
    print(f"  Input shape: {embedding.shape}")
    print(f"  Output shape: {scores.shape} (one score per word in vocab)")
    print(f"  Score statistics:")
    print(f"    Min: {scores.min().item():.4f}")
    print(f"    Max: {scores.max().item():.4f}")
    print(f"    Mean: {scores.mean().item():.4f}")
    print()
    
    # ========== STEP 3: Softmax ==========
    probs = F.softmax(scores, dim=1)
    log_probs = F.log_softmax(scores, dim=1)
    
    print("STEP 3: SOFTMAX (Probabilities)")
    print("-" * 50)
    print(f"  Input shape: {scores.shape}")
    print(f"  Output shape: {probs.shape}")
    print(f"  Sum of probabilities: {probs.sum().item():.6f} (should be 1.0)")
    print()
    
    # ========== Show Top Predictions ==========
    print("TOP 10 PREDICTED CONTEXT WORDS (before training):")
    print("-" * 50)
    
    top_probs, top_indices = probs[0].topk(10)
    
    for rank, (prob, idx) in enumerate(zip(top_probs, top_indices), 1):
        word = idx_to_word[idx.item()]
        print(f"  {rank:2d}. '{word}': {prob.item():.4f} ({prob.item()*100:.2f}%)")

print()
print("=" * 70)
print("Note: Before training, probabilities are nearly uniform (~1/vocab_size each), so this ranking is essentially random!")
print("=" * 70)
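
Because `_init_weights` draws every weight from a tiny range (±0.5 / embedding_dim), the initial scores are all close to zero and the softmax is close to uniform. A useful sanity check follows from this: before any update, the NLL loss should be close to ln(vocab_size). The sketch below reproduces that initialization with plain tensors. It assumes a hypothetical 200-word vocabulary and random center/context indices:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size, embedding_dim, batch = 200, 50, 32    # hypothetical sizes
initrange = 0.5 / embedding_dim                   # same init range as SkipGramBasic

emb = torch.empty(vocab_size, embedding_dim).uniform_(-initrange, initrange)
out_w = torch.empty(vocab_size, embedding_dim).uniform_(-initrange, initrange)

center = torch.randint(0, vocab_size, (batch,))
context = torch.randint(0, vocab_size, (batch,))

scores = emb[center] @ out_w.T                    # (batch, vocab_size); bias starts at zero
loss = F.nll_loss(F.log_softmax(scores, dim=1), context)

print(f"initial loss {loss.item():.4f} vs ln(V) = {math.log(vocab_size):.4f}")
```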

### 5.4 Skip-Gram with Negative Sampling

**Problem with Basic Skip-Gram**:
- Computing the softmax over the entire vocabulary is expensive!
- With a 100,000-word vocabulary, every training pair needs 100,000 dot products (one score per word) plus a normalizing sum, and the backward pass updates all 100,000 output weight vectors

**Solution - Negative Sampling**:
- Instead of computing probabilities for ALL words...
- ...only distinguish the CORRECT context word from a few RANDOM "negative" words
- This turns one multi-class classification into a handful of binary classifications
- Much faster! Each training pair costs only k + 1 dot products (1 positive + k negatives, e.g. k = 5), no matter how large the vocabulary is

**How it works**:
- For positive pair (center, context): Model should output HIGH probability
- For negative pairs (center, random_word): Model should output LOW probability
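
Consider a single training pair with center vector v_c, true context vector u_o and k negative vectors u_1 … u_k. The loss is −log σ(u_o · v_c) − Σᵢ log σ(−uᵢ · v_c). The minimal sketch below evaluates that formula with hand-picked 2-D vectors (toy values, not learned embeddings). The `SkipGramNegSampling.forward` method that follows computes the same quantity, averaged over a batch:

```python
import torch
import torch.nn.functional as F

def ns_loss_single(v_center, u_pos, u_negs):
    """Negative-sampling loss for ONE (center, context) pair.

    v_center: (D,)    input embedding of the center word
    u_pos:    (D,)    output embedding of the true context word
    u_negs:   (k, D)  output embeddings of k negative words
    """
    pos_term = F.logsigmoid(torch.dot(u_pos, v_center))      # log σ(u_pos · v_c)
    neg_term = F.logsigmoid(-(u_negs @ v_center)).sum()     # Σ log σ(-u_neg · v_c)
    return -(pos_term + neg_term)

v_c = torch.tensor([1.0, 0.0])
u_o = torch.tensor([2.0, 0.0])                     # points the same way as v_c
u_neg = torch.tensor([[-2.0, 0.0],                 # points the opposite way
                      [0.0, 1.0]])                 # orthogonal to v_c

loss = ns_loss_single(v_c, u_o, u_neg)
print(f"{loss.item():.4f}")                        # 0.9470
```

The aligned positive and the opposite-pointing negative each contribute log(1 + e^−2) ≈ 0.127. The orthogonal negative contributes log 2 ≈ 0.693 because σ(0) = 0.5, meaning the model is still undecided about that word.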

In [None]:
# ============================================
# 5.4 SKIP-GRAM WITH NEGATIVE SAMPLING
# ============================================

class SkipGramNegSampling(nn.Module):
    """
    Skip-Gram model with Negative Sampling.
    
    Key difference from basic Skip-Gram:
    - Uses TWO embedding matrices (input and output)
    - Only computes scores for positive and negative samples
    - Uses a logistic (log-sigmoid) loss, i.e. binary cross-entropy, instead of a full softmax
    
    This is MUCH more efficient for large vocabularies!
    """
    
    def __init__(self, vocab_size, embedding_dim):
        """
        Initialize the model.
        
        Args:
            vocab_size: Number of unique words
            embedding_dim: Dimension of word embeddings
        """
        super(SkipGramNegSampling, self).__init__()
        
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        
        # =============================================
        # INPUT EMBEDDINGS (for center words)
        # =============================================
        # These are the embeddings we'll use as our final word vectors
        self.input_embeddings = nn.Embedding(vocab_size, embedding_dim)
        
        # =============================================
        # OUTPUT EMBEDDINGS (for context/negative words)
        # =============================================
        # Separate embeddings used during training
        # We compute similarity between input and output embeddings
        self.output_embeddings = nn.Embedding(vocab_size, embedding_dim)
        
        # Initialize weights
        self._init_weights()
        
        print(f"SkipGramNegSampling model created:")
        print(f"  Vocab size: {vocab_size}")
        print(f"  Embedding dim: {embedding_dim}")
        print(f"  Input embeddings: ({vocab_size}, {embedding_dim})")
        print(f"  Output embeddings: ({vocab_size}, {embedding_dim})")
    
    def _init_weights(self):
        """Initialize weights with small random values."""
        initrange = 0.5 / self.embedding_dim
        self.input_embeddings.weight.data.uniform_(-initrange, initrange)
        self.output_embeddings.weight.data.uniform_(-initrange, initrange)
    
    def forward(self, center_idx, context_idx, negative_indices):
        """
        Compute negative sampling loss.
        
        Args:
            center_idx: (batch_size,) center word indices
            context_idx: (batch_size,) positive context word indices
            negative_indices: (batch_size, num_neg) negative sample indices
            
        Returns:
            loss: Scalar loss value
        """
        batch_size = center_idx.size(0)
        num_neg = negative_indices.size(1)
        
        # Get embeddings
        # center: (batch_size, embedding_dim)
        center_embeds = self.input_embeddings(center_idx)
        
        # positive context: (batch_size, embedding_dim)
        pos_embeds = self.output_embeddings(context_idx)
        
        # negative samples: (batch_size, num_neg, embedding_dim)
        neg_embeds = self.output_embeddings(negative_indices)
        
        # =============================================
        # POSITIVE SCORE
        # =============================================
        # Dot product of center and positive context embeddings
        # Higher = more likely to be a real context word
        # Shape: (batch_size,)
        pos_score = torch.sum(center_embeds * pos_embeds, dim=1)
        
        # Log-sigmoid of positive score
        # We want this to be HIGH (close to 0, since log(1) = 0)
        pos_loss = F.logsigmoid(pos_score)
        
        # =============================================
        # NEGATIVE SCORES
        # =============================================
        # Dot product of center with each negative sample
        # We want these to be LOW
        # Shape: (batch_size, num_neg)
        neg_score = torch.bmm(neg_embeds, center_embeds.unsqueeze(2)).squeeze(2)
        
        # Log-sigmoid of NEGATIVE score (note the minus sign!)
        # We want sigmoid(-neg_score) to be HIGH (close to 1)
        neg_loss = F.logsigmoid(-neg_score).sum(dim=1)
        
        # =============================================
        # TOTAL LOSS
        # =============================================
        # Negative because we want to MAXIMIZE log likelihood
        # But PyTorch optimizers MINIMIZE, so we negate
        loss = -(pos_loss + neg_loss).mean()
        
        return loss
    
    def get_word_embedding(self, word_idx):
        """Get the embedding vector for a word (from input embeddings)."""
        return self.input_embeddings.weight[word_idx].detach().cpu()

print("=" * 60)
print("CREATING SKIP-GRAM MODEL (NEGATIVE SAMPLING)")
print("=" * 60)
print()

# Create the model
skipgram_ns = SkipGramNegSampling(vocab_size, EMBEDDING_DIM)

print()
print("=" * 60)
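
The negative-score line in `forward` uses `torch.bmm` to compute one dot product per (sample, negative) pair in a single call. The shape bookkeeping is easy to get wrong. The small sketch below uses random toy tensors to check that the `bmm` result matches `torch.einsum` and a plain Python loop:

```python
import torch

torch.manual_seed(0)
batch_size, num_neg, dim = 3, 5, 4                 # toy sizes

center = torch.randn(batch_size, dim)              # (B, D)
negs = torch.randn(batch_size, num_neg, dim)       # (B, K, D)

# (B, K, D) @ (B, D, 1) -> (B, K, 1) -> squeeze(2) -> (B, K)
bmm_scores = torch.bmm(negs, center.unsqueeze(2)).squeeze(2)

# The same quantity written with einsum and with an explicit loop
einsum_scores = torch.einsum("bkd,bd->bk", negs, center)
loop_scores = torch.stack([negs[b] @ center[b] for b in range(batch_size)])

print(bmm_scores.shape)                            # torch.Size([3, 5])
print(torch.allclose(bmm_scores, einsum_scores, atol=1e-5),
      torch.allclose(bmm_scores, loop_scores, atol=1e-5))
```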

### 5.5 Negative Sampling Function

We need a function to sample random "negative" words. These are words that are NOT the actual context word.

In [None]:
# ============================================
# 5.5 NEGATIVE SAMPLING FUNCTION
# ============================================

def get_negative_samples(batch_size, num_neg, vocab_size, positive_indices):
    """
    Sample negative words for negative sampling.
    
    For each positive sample, we randomly select num_neg words from the vocabulary
    that are NOT the positive context word.
    
    Args:
        batch_size: Number of samples in the batch
        num_neg: Number of negative samples per positive sample
        vocab_size: Size of vocabulary
        positive_indices: Tensor of positive context word indices to avoid
        
    Returns:
        Tensor of shape (batch_size, num_neg) with negative sample indices
    """
    negative_samples = []
    
    for i in range(batch_size):
        neg = []
        positive = positive_indices[i].item()
        
        while len(neg) < num_neg:
            # Random word from vocabulary
            sample = random.randint(0, vocab_size - 1)
            
            # Make sure it's not the positive word
            if sample != positive:
                neg.append(sample)
        
        negative_samples.append(neg)
    
    return torch.tensor(negative_samples, dtype=torch.long)

# Test the negative sampling function
NUM_NEGATIVE = 5  # Number of negative samples

print("=" * 70)
print("NEGATIVE SAMPLING EXAMPLE")
print("=" * 70)
print()

# Example: positive words are "cat" and "dog"
test_positive = torch.tensor([word_to_idx["cat"], word_to_idx["dog"]])
test_negative = get_negative_samples(2, NUM_NEGATIVE, vocab_size, test_positive)

print(f"Number of negative samples per positive: {NUM_NEGATIVE}")
print()

for i in range(2):
    pos_word = idx_to_word[test_positive[i].item()]
    neg_words = [idx_to_word[idx] for idx in test_negative[i].tolist()]
    
    print(f"Sample {i+1}:")
    print(f"  Positive word: '{pos_word}' (index {test_positive[i].item()})")
    print(f"  Negative indices: {test_negative[i].tolist()}")
    print(f"  Negative words: {neg_words}")
    print()

print("=" * 70)
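
`get_negative_samples` draws negatives uniformly from the vocabulary and loops in Python. The original word2vec instead samples from the unigram distribution raised to the 3/4 power. This flattens the distribution, so very frequent words are picked less often than their raw counts suggest. The sketch below shows one way to do this with `torch.multinomial`. `make_noise_distribution`, `sample_negatives` and the word counts are hypothetical names and data for illustration. Unlike `get_negative_samples`, this sketch does not filter out the positive word.

```python
import torch

def make_noise_distribution(word_counts, power=0.75):
    """Unigram counts raised to `power` and renormalized (hypothetical helper)."""
    counts = torch.tensor(word_counts, dtype=torch.float)
    weights = counts.pow(power)
    return weights / weights.sum()

def sample_negatives(noise_dist, batch_size, num_neg):
    """Draw a (batch_size, num_neg) tensor of negative indices in one call.

    Does NOT exclude the positive word; with a realistic vocabulary
    such collisions are rare.
    """
    flat = torch.multinomial(noise_dist, batch_size * num_neg, replacement=True)
    return flat.view(batch_size, num_neg)

# Hypothetical counts for a 4-word vocabulary (index 0 is very frequent)
word_counts = [1000, 100, 10, 1]
noise = make_noise_distribution(word_counts)
print(noise)            # index 0 gets less than its raw share of 1000/1111
negs = sample_negatives(noise, batch_size=2, num_neg=5)
print(negs.shape)       # torch.Size([2, 5])
```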

---
## 6. CBOW Model

CBOW (Continuous Bag of Words) is the "opposite" of Skip-Gram:
- **Skip-Gram**: center word → predict context words
- **CBOW**: context words → predict center word

The key difference is that CBOW **averages** all context word embeddings before making a prediction.

### 6.1 CBOW Architecture

```
┌─────────────────────────────────────────────────────────────────┐
│                       CBOW ARCHITECTURE                          │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  INPUT: context_word_indices (e.g., [0, 1, 3, 0] for            │
│         ["the", "cat", "on", "the"])                            │
│         ↓                                                        │
│  ┌──────────────────────────────────────────┐                   │
│  │         EMBEDDING LAYER                   │                   │
│  │   Look up embedding for EACH context word │                   │
│  │                                          │                   │
│  │   "the" → [0.1, 0.2, ..., 0.5]          │                   │
│  │   "cat" → [0.3, 0.1, ..., 0.8]          │                   │
│  │   "on"  → [0.4, 0.3, ..., 0.2]          │                   │
│  │   "the" → [0.1, 0.2, ..., 0.5]          │                   │
│  └──────────────────────────────────────────┘                   │
│         ↓                                                        │
│  4 vectors, each of shape (embedding_dim,)                      │
│         ↓                                                        │
│  ┌──────────────────────────────────────────┐                   │
│  │         AVERAGE                           │                   │
│  │   Average all context embeddings         │                   │
│  │   (0.1+0.3+0.4+0.1)/4 = 0.225, ...      │                   │
│  └──────────────────────────────────────────┘                   │
│         ↓                                                        │
│  avg_vector: [0.225, 0.2, ..., 0.5]  (one vector!)              │
│         ↓                                                        │
│  ┌──────────────────────────────────────────┐                   │
│  │         LINEAR + SOFTMAX                  │                   │
│  │   Same as Skip-Gram output layer         │                   │
│  └──────────────────────────────────────────┘                   │
│         ↓                                                        │
│  OUTPUT: P(center_word | context_words)                         │
│          Probability for each word in vocabulary                 │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘
```
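
The AVERAGE box can be reproduced in a couple of lines. The sketch below uses only the three visible components of each vector in the diagram (first, second and last). The `...` entries are left out.

```python
import torch

# Visible components (first, second, last) of each context vector in the diagram
context_embeds = torch.tensor([
    [0.1, 0.2, 0.5],   # "the"
    [0.3, 0.1, 0.8],   # "cat"
    [0.4, 0.3, 0.2],   # "on"
    [0.1, 0.2, 0.5],   # "the" (a repeated word is counted twice)
])

# Element-wise mean over the 4 context words
avg_vector = context_embeds.mean(dim=0)
print(avg_vector)   # matches avg_vector in the diagram: 0.225, 0.2, 0.5
```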

### 6.2 CBOW Implementation

In [None]:
# ============================================
# 6.2 CBOW MODEL IMPLEMENTATION
# ============================================

class CBOW(nn.Module):
    """
    Continuous Bag of Words (CBOW) model.
    
    Predicts the center word given surrounding context words.
    
    Key steps:
    1. Look up embeddings for ALL context words
    2. AVERAGE the context embeddings into one vector
    3. Use that vector to predict the center word
    
    Architecture:
        context_indices -> Embeddings -> Average -> Linear -> LogSoftmax -> log_probs
    """
    
    def __init__(self, vocab_size, embedding_dim):
        """
        Initialize the CBOW model.
        
        Args:
            vocab_size: Number of unique words
            embedding_dim: Dimension of word embeddings
        """
        super(CBOW, self).__init__()
        
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        
        # Embedding layer (same as Skip-Gram)
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        
        # Output linear layer (same as Skip-Gram)
        self.linear = nn.Linear(embedding_dim, vocab_size)
        
        # Initialize weights
        self._init_weights()
        
        print(f"CBOW model created:")
        print(f"  Vocab size: {vocab_size}")
        print(f"  Embedding dim: {embedding_dim}")
    
    def _init_weights(self):
        """Initialize weights with small random values."""
        initrange = 0.5 / self.embedding_dim
        self.embeddings.weight.data.uniform_(-initrange, initrange)
        self.linear.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
    
    def forward(self, context_indices, context_lengths):
        """
        Forward pass: predict center word from context words.
        
        Args:
            context_indices: (batch_size, max_context_size) context word indices
                            Padded with -1 for variable length contexts
            context_lengths: (batch_size,) actual number of context words per sample
            
        Returns:
            log_probs: (batch_size, vocab_size) log probabilities
        """
        batch_size = context_indices.size(0)
        
        # =============================================
        # STEP 1: Handle padding (-1 values)
        # =============================================
        # Replace -1 (padding) with 0 for embedding lookup
        # We'll zero out these positions later
        context_safe = context_indices.clamp(min=0)
        
        # =============================================
        # STEP 2: Get embeddings for all context words
        # =============================================
        # Shape: (batch_size, max_context_size, embedding_dim)
        embeds = self.embeddings(context_safe)
        
        # =============================================
        # STEP 3: Create mask for padding
        # =============================================
        # 1.0 where the position holds a real word, 0.0 where it is padding
        # Shape: (batch_size, max_context_size, 1)
        mask = (context_indices >= 0).float().unsqueeze(2)
        
        # =============================================
        # STEP 4: Apply mask and compute average
        # =============================================
        # Zero out padded positions
        masked_embeds = embeds * mask
        
        # Sum of context embeddings
        # Shape: (batch_size, embedding_dim)
        sum_embeds = masked_embeds.sum(dim=1)
        
        # Divide by actual context length (not max_context_size!)
        # Shape: (batch_size, 1)
        lengths = context_lengths.float().unsqueeze(1)
        avg_embeds = sum_embeds / lengths
        
        # =============================================
        # STEP 5: Linear layer and softmax
        # =============================================
        scores = self.linear(avg_embeds)
        log_probs = F.log_softmax(scores, dim=1)
        
        return log_probs
    
    def get_word_embedding(self, word_idx):
        """Get the embedding vector for a word."""
        return self.embeddings.weight[word_idx].detach().cpu()

print("=" * 60)
print("CREATING CBOW MODEL")
print("=" * 60)
print()

# Create the model
cbow_model = CBOW(vocab_size, EMBEDDING_DIM)

print()
print("Model Architecture:")
print(cbow_model)
print()

# Count parameters
total_params = sum(p.numel() for p in cbow_model.parameters())
print(f"Total parameters: {total_params:,}")
print()
print("=" * 60)
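
The padding logic in `CBOW.forward` (clamp, mask, sum, divide by the true length) should give the same result as averaging only the real context words. The small sketch below repeats those steps outside the class, using a toy embedding table and one padded sample:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(10, 4)                          # toy vocab of 10 words, dim 4

# Two samples: the second has only 2 real context words, padded with -1
context = torch.tensor([[1, 2, 3, 4],
                        [5, 6, -1, -1]])
lengths = torch.tensor([4, 2])

# Same steps as CBOW.forward
embeds = emb(context.clamp(min=0))                 # (2, 4, 4); -1 is looked up as index 0
mask = (context >= 0).float().unsqueeze(2)         # (2, 4, 1); zeroes out the padded rows
avg = (embeds * mask).sum(dim=1) / lengths.float().unsqueeze(1)

# Reference: plain mean over only the real words of sample 2
expected = emb(torch.tensor([5, 6])).mean(dim=0)
print(torch.allclose(avg[1], expected))            # True
```

Dividing by `max_context_size` instead of the true length would shrink the average for short contexts. This is why the model divides by `context_lengths`.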

### 6.3 CBOW Forward Pass Walkthrough

Let's trace through the CBOW forward pass step by step.

In [None]:
# ============================================
# 6.3 CBOW FORWARD PASS WALKTHROUGH
# ============================================

print("=" * 70)
print("CBOW FORWARD PASS WALKTHROUGH")
print("=" * 70)
print()

# Example: predicting "sat" from context ["the", "cat", "on", "the"]
context_words = ["the", "cat", "on", "the"]
center_word = "sat"

context_indices = [word_to_idx[w] for w in context_words]
context_tensor = torch.tensor([context_indices])  # Shape: (1, 4)
length_tensor = torch.tensor([len(context_indices)])

print(f"INPUT:")
print(f"  Context words: {context_words}")
print(f"  Context indices: {context_indices}")
print(f"  Context tensor shape: {context_tensor.shape}")
print(f"  Target (center word): '{center_word}'")
print()

cbow_model.eval()

with torch.no_grad():
    
    # ========== STEP 1: Get embeddings for each context word ==========
    embeds = cbow_model.embeddings(context_tensor)
    
    print("STEP 1: EMBEDDING LOOKUP FOR EACH CONTEXT WORD")
    print("-" * 60)
    print(f"  Output shape: {embeds.shape} (batch, num_context, embed_dim)")
    print()
    
    for i, word in enumerate(context_words):
        print(f"  '{word}' embedding (first 8 values): {embeds[0, i, :8].numpy().round(4)}")
    print()
    
    # ========== STEP 2: Average the embeddings ==========
    avg_embed = embeds.mean(dim=1)
    
    print("STEP 2: AVERAGE CONTEXT EMBEDDINGS")
    print("-" * 60)
    print(f"  Output shape: {avg_embed.shape} (batch, embed_dim)")
    print(f"  Averaged embedding (first 8 values): {avg_embed[0, :8].numpy().round(4)}")
    print()
    
    # ========== STEP 3: Linear transformation ==========
    scores = cbow_model.linear(avg_embed)
    
    print("STEP 3: LINEAR TRANSFORMATION")
    print("-" * 60)
    print(f"  Output shape: {scores.shape}")
    print(f"  Score statistics: min={scores.min():.4f}, max={scores.max():.4f}")
    print()
    
    # ========== STEP 4: Softmax ==========
    probs = F.softmax(scores, dim=1)
    
    print("STEP 4: SOFTMAX (Probabilities)")
    print("-" * 60)
    print(f"  Sum of probabilities: {probs.sum().item():.6f}")
    print()
    
    # ========== Show top predictions ==========
    print("TOP 10 PREDICTIONS (before training):")
    print("-" * 60)
    
    top_probs, top_indices = probs[0].topk(10)
    
    for rank, (prob, idx) in enumerate(zip(top_probs, top_indices), 1):
        word = idx_to_word[idx.item()]
        marker = " <-- TARGET" if word == center_word else ""
        print(f"  {rank:2d}. '{word}': {prob.item():.4f} ({prob.item()*100:.2f}%){marker}")

print()
print("=" * 70)
print("Note: Before training, the target word 'sat' is likely not in top predictions!")
print("=" * 70)

---
## 7. Training Loop

Now let's train our models! The training loop follows this pattern:

```
for each epoch:
    for each batch:
        1. Forward pass: compute predictions
        2. Compute loss: how wrong were we?
        3. Backward pass: compute gradients
        4. Update weights: improve the model
```

### Key Components:
- **Loss Function**: Measures prediction error
- **Optimizer**: Updates weights to minimize loss
- **Learning Rate**: How big the weight updates are
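
The four steps in the loop above map one-to-one onto PyTorch calls. The minimal sketch below uses a hypothetical toy problem: a 5-word vocabulary and a single center/context pair, repeated for 30 steps. It uses the same log-softmax + `NLLLoss` + Adam combination as the training functions below:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical toy setup: 5-word vocabulary, one (center=0, context=3) pair
model = nn.Sequential(
    nn.Embedding(5, 8),
    nn.Linear(8, 5),
    nn.LogSoftmax(dim=1),                  # log-probabilities, as in SkipGramBasic
)
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.05)

center = torch.tensor([0])
context = torch.tensor([3])

losses = []
for step in range(30):
    log_probs = model(center)              # 1. forward pass
    loss = criterion(log_probs, context)   # 2. compute loss
    optimizer.zero_grad()                  #    clear gradients from the previous step
    loss.backward()                        # 3. backward pass
    optimizer.step()                       # 4. update weights
    losses.append(loss.item())

print(f"loss went from {losses[0]:.3f} to {losses[-1]:.3f}")
```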

In [None]:
# ============================================
# 7.1 TRAINING HYPERPARAMETERS
# ============================================

# Training settings
LEARNING_RATE = 0.01      # Step size for the Adam optimizer
NUM_EPOCHS = 100          # How many times to go through the data
NUM_NEGATIVE = 5          # Number of negative samples (for neg sampling)

print("=" * 60)
print("TRAINING HYPERPARAMETERS")
print("=" * 60)
print()
print(f"Embedding dimension:    {EMBEDDING_DIM}")
print(f"Batch size:             {BATCH_SIZE}")
print(f"Learning rate:          {LEARNING_RATE}")
print(f"Number of epochs:       {NUM_EPOCHS}")
print(f"Negative samples:       {NUM_NEGATIVE}")
print()
print("=" * 60)

### 7.2 Training Skip-Gram (Basic)

In [None]:
# ============================================
# 7.2 TRAIN SKIP-GRAM (BASIC)
# ============================================

def train_skipgram_basic(model, dataloader, num_epochs, learning_rate, device):
    """
    Train the basic Skip-Gram model.
    
    Args:
        model: SkipGramBasic model
        dataloader: DataLoader with training data
        num_epochs: Number of training epochs
        learning_rate: Learning rate for optimizer
        device: Device to train on (CPU/GPU)
        
    Returns:
        losses: List of average loss per epoch
    """
    # Move model to device
    model = model.to(device)
    model.train()  # Set to training mode
    
    # Loss function: Negative Log-Likelihood
    # This is the standard loss for classification with log_softmax output
    criterion = nn.NLLLoss()
    
    # Optimizer: Adam (adaptive learning rate)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Track losses
    losses = []
    
    print("=" * 60)
    print("TRAINING SKIP-GRAM (BASIC)")
    print("=" * 60)
    print(f"Device: {device}")
    print(f"Epochs: {num_epochs}")
    print(f"Learning rate: {learning_rate}")
    print("-" * 60)
    
    for epoch in range(num_epochs):
        total_loss = 0
        num_batches = 0
        
        for center, context in dataloader:
            # Move data to device
            center = center.to(device)
            context = context.to(device)
            
            # ========== STEP 1: Forward pass ==========
            # Get predictions (log probabilities)
            log_probs = model(center)
            
            # ========== STEP 2: Compute loss ==========
            # Compare predictions to actual context words
            loss = criterion(log_probs, context)
            
            # ========== STEP 3: Backward pass ==========
            # Clear previous gradients
            optimizer.zero_grad()
            # Compute gradients
            loss.backward()
            
            # ========== STEP 4: Update weights ==========
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
        
        # Calculate average loss for this epoch
        avg_loss = total_loss / num_batches
        losses.append(avg_loss)
        
        # Print progress on the first epoch and every 10th epoch after that
        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:3d}/{num_epochs} | Loss: {avg_loss:.4f}")
    
    print("-" * 60)
    print(f"Training complete! Final loss: {losses[-1]:.4f}")
    print("=" * 60)
    
    return losses

# Re-create the model (fresh weights)
skipgram_basic = SkipGramBasic(vocab_size, EMBEDDING_DIM)

# Train the model
skipgram_basic_losses = train_skipgram_basic(
    skipgram_basic, 
    skipgram_loader, 
    NUM_EPOCHS, 
    LEARNING_RATE, 
    device
)
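Because `SkipGramBasic` returns log-probabilities, `nn.NLLLoss` is the matching criterion. A quick sanity check on made-up logits: `log_softmax` followed by `nll_loss` gives the same value as `cross_entropy` applied to raw scores, so a model that returned raw logits could use `nn.CrossEntropyLoss` instead. The batch size, vocabulary size, and targets below are arbitrary illustration values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical batch: 4 center words, vocabulary of 10 words
logits = torch.randn(4, 10)            # raw scores, one row per example
targets = torch.tensor([1, 0, 7, 3])   # true context-word indices

# Route used in this notebook: log-probabilities, then NLL loss
loss_nll = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# Equivalent single call on the raw logits
loss_ce = F.cross_entropy(logits, targets)

print(loss_nll.item(), loss_ce.item())
```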

### 7.3 Training Skip-Gram (Negative Sampling)

In [None]:
# ============================================
# 7.3 TRAIN SKIP-GRAM (NEGATIVE SAMPLING)
# ============================================

def train_skipgram_negative_sampling(model, pairs, num_epochs, learning_rate, 
                                     batch_size, num_neg, vocab_size, device):
    """
    Train Skip-Gram model with Negative Sampling.
    
    Note: We use the raw pairs instead of DataLoader because we need
    to generate negative samples dynamically.
    """
    model = model.to(device)
    model.train()
    
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    losses = []
    
    print("=" * 60)
    print("TRAINING SKIP-GRAM (NEGATIVE SAMPLING)")
    print("=" * 60)
    print(f"Device: {device}")
    print(f"Epochs: {num_epochs}")
    print(f"Negative samples: {num_neg}")
    print("-" * 60)
    
    for epoch in range(num_epochs):
        # Shuffle pairs at the start of each epoch
        random.shuffle(pairs)
        
        total_loss = 0
        num_batches = 0
        
        # Process in batches
        for i in range(0, len(pairs), batch_size):
            batch = pairs[i:i+batch_size]
            
            # Extract center and context indices
            center_indices = torch.tensor([p[0] for p in batch], dtype=torch.long).to(device)
            context_indices = torch.tensor([p[1] for p in batch], dtype=torch.long).to(device)
            
            # Generate negative samples
            negative_indices = get_negative_samples(
                len(batch), num_neg, vocab_size, context_indices
            ).to(device)
            
            # Forward pass (returns loss directly)
            loss = model(center_indices, context_indices, negative_indices)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
        
        avg_loss = total_loss / num_batches
        losses.append(avg_loss)
        
        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:3d}/{num_epochs} | Loss: {avg_loss:.4f}")
    
    print("-" * 60)
    print(f"Training complete! Final loss: {losses[-1]:.4f}")
    print("=" * 60)
    
    return losses

# Re-create the model (fresh weights)
skipgram_ns = SkipGramNegSampling(vocab_size, EMBEDDING_DIM)

# Train the model
skipgram_ns_losses = train_skipgram_negative_sampling(
    skipgram_ns,
    skipgram_pairs,
    NUM_EPOCHS,
    LEARNING_RATE,
    BATCH_SIZE,
    NUM_NEGATIVE,
    vocab_size,
    device
)
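The original Negative Sampling paper draws negative words from the unigram distribution raised to the 3/4 power, which flattens the distribution so rare words are sampled more often than their raw frequency would suggest. A minimal sketch of such a sampler, assuming hypothetical per-word `counts` indexed like the vocabulary; it is an illustration, not necessarily how `get_negative_samples` above works, and for simplicity it does not exclude the positive context word.

```python
import torch

def build_noise_distribution(counts, power=0.75):
    """Unigram counts raised to `power`, renormalized to sum to 1."""
    weights = counts.float() ** power
    return weights / weights.sum()

def sample_negatives(noise_dist, batch_size, num_neg):
    """Draw a (batch_size, num_neg) tensor of word indices from noise_dist."""
    flat = torch.multinomial(noise_dist, batch_size * num_neg, replacement=True)
    return flat.view(batch_size, num_neg)

torch.manual_seed(0)

# Hypothetical counts for a 5-word vocabulary (most frequent word first)
counts = torch.tensor([50, 20, 10, 5, 1])
noise = build_noise_distribution(counts)
print(noise)  # the rarest word gets a larger share than its raw frequency 1/86

negatives = sample_negatives(noise, batch_size=3, num_neg=4)
print(negatives.shape)  # torch.Size([3, 4])
```

Setting `power=1.0` recovers plain frequency-based sampling, and `power=0.0` gives uniform sampling over the vocabulary.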

### 7.4 Training CBOW

In [None]:
# ============================================
# 7.4 TRAIN CBOW
# ============================================

def train_cbow(model, dataloader, num_epochs, learning_rate, device):
    """
    Train the CBOW model.
    """
    model = model.to(device)
    model.train()
    
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    losses = []
    
    print("=" * 60)
    print("TRAINING CBOW")
    print("=" * 60)
    print(f"Device: {device}")
    print(f"Epochs: {num_epochs}")
    print(f"Learning rate: {learning_rate}")
    print("-" * 60)
    
    for epoch in range(num_epochs):
        total_loss = 0
        num_batches = 0
        
        for context, center, lengths in dataloader:
            context = context.to(device)
            center = center.to(device)
            lengths = lengths.to(device)
            
            # Forward pass
            log_probs = model(context, lengths)
            
            # Compute loss
            loss = criterion(log_probs, center)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
        
        avg_loss = total_loss / num_batches
        losses.append(avg_loss)
        
        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:3d}/{num_epochs} | Loss: {avg_loss:.4f}")
    
    print("-" * 60)
    print(f"Training complete! Final loss: {losses[-1]:.4f}")
    print("=" * 60)
    
    return losses

# Re-create the model (fresh weights)
cbow_model = CBOW(vocab_size, EMBEDDING_DIM)

# Train the model
cbow_losses = train_cbow(
    cbow_model,
    cbow_loader,
    NUM_EPOCHS,
    LEARNING_RATE,
    device
)
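`train_cbow` passes `lengths` alongside each context batch, which suggests contexts are padded to a common length (for example, near sentence boundaries where the window is cut short). A minimal sketch of a length-aware average that ignores padding positions, assuming padding sits at the end of each row; `masked_mean` is a hypothetical helper, not necessarily how the `CBOW` class above computes its average.

```python
import torch
import torch.nn as nn

def masked_mean(embedded, lengths):
    """Average each row of `embedded` over its first lengths[b] positions.

    embedded: (batch, max_len, dim) context embeddings, padding at the end
    lengths:  (batch,) number of real context words in each example
    """
    max_len = embedded.size(1)
    positions = torch.arange(max_len, device=embedded.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)          # True = real word
    mask = mask.unsqueeze(-1).to(embedded.dtype)                  # (batch, max_len, 1)
    summed = (embedded * mask).sum(dim=1)                         # padding adds 0
    return summed / lengths.clamp(min=1).unsqueeze(1).to(embedded.dtype)

# Hypothetical setup: 10-word vocabulary, 3-dim embeddings, index 0 reserved for padding
torch.manual_seed(0)
emb = nn.Embedding(10, 3, padding_idx=0)
context = torch.tensor([[4, 5, 6, 7],    # full context window
                        [2, 3, 0, 0]])   # 2 real words + 2 padding slots
lengths = torch.tensor([4, 2])
avg = masked_mean(emb(context), lengths)
print(avg.shape)  # torch.Size([2, 3])
```

Dividing by `lengths` rather than by `max_len` is what keeps short, padded contexts from being pulled toward the zero vector.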

### 7.5 Visualize Training Loss

In [None]:
# ============================================
# 7.5 VISUALIZE TRAINING LOSS
# ============================================

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Plot 1: Skip-Gram Basic
ax1 = axes[0]
ax1.plot(skipgram_basic_losses, color='blue', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Skip-Gram (Basic)\nTraining Loss', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)

# Plot 2: Skip-Gram Negative Sampling
ax2 = axes[1]
ax2.plot(skipgram_ns_losses, color='green', linewidth=2)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Loss', fontsize=12)
ax2.set_title('Skip-Gram (Neg Sampling)\nTraining Loss', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3)

# Plot 3: CBOW
ax3 = axes[2]
ax3.plot(cbow_losses, color='red', linewidth=2)
ax3.set_xlabel('Epoch', fontsize=12)
ax3.set_ylabel('Loss', fontsize=12)
ax3.set_title('CBOW\nTraining Loss', fontsize=14, fontweight='bold')
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

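# Summary table
# Note: the Negative Sampling loss (sum of binary log-losses) is on a different
# scale from the full-softmax NLL of the other two models, so compare each
# model's trend rather than raw values across models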
print("=" * 60)
print("TRAINING SUMMARY")
print("=" * 60)
print()
print(f"{'Model':<30} {'Initial Loss':<15} {'Final Loss':<15}")
print("-" * 60)
print(f"{'Skip-Gram (Basic)':<30} {skipgram_basic_losses[0]:<15.4f} {skipgram_basic_losses[-1]:<15.4f}")
print(f"{'Skip-Gram (Neg Sampling)':<30} {skipgram_ns_losses[0]:<15.4f} {skipgram_ns_losses[-1]:<15.4f}")
print(f"{'CBOW':<30} {cbow_losses[0]:<15.4f} {cbow_losses[-1]:<15.4f}")
print("-" * 60)
print()
print("=" * 60)

---
## 8. Evaluation & Visualization

Now that our models are trained, let's evaluate the learned embeddings!

We'll look at:
1. **Word Similarity**: Using cosine similarity
2. **t-SNE Visualization**: 2D visualization of embeddings
3. **Model Comparison**: Compare Skip-Gram vs CBOW

### 8.1 Word Similarity with Cosine Similarity

**Cosine similarity** measures the cosine of the angle between two vectors, ignoring their lengths:

$$\text{cosine}(\vec{u}, \vec{v}) = \frac{\vec{u} \cdot \vec{v}}{||\vec{u}|| \times ||\vec{v}||}$$

- **1.0**: Identical direction (most similar)
- **0.0**: Perpendicular (no similarity)
- **-1.0**: Opposite direction (least similar)
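A quick numeric check of the three reference values, using PyTorch's built-in `torch.nn.functional.cosine_similarity` on hand-picked 2-D vectors (the vectors are illustrative, not learned embeddings). The built-in also takes a small `eps` that guards against division by zero.

```python
import torch
import torch.nn.functional as F

u = torch.tensor([1.0, 0.0])
same = torch.tensor([3.0, 0.0])       # same direction, three times longer
perp = torch.tensor([0.0, 2.0])       # perpendicular
opposite = torch.tensor([-1.0, 0.0])  # opposite direction

for name, v in [("same", same), ("perpendicular", perp), ("opposite", opposite)]:
    # dim=0 because these are single 1-D vectors; eps avoids division by zero
    sim = F.cosine_similarity(u, v, dim=0, eps=1e-8)
    print(f"{name:<14} {sim.item():+.1f}")
```

Note that `same` scores a full 1.0 even though it is three times longer than `u`: only direction matters.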

In [None]:
# ============================================
# 8.1 WORD SIMILARITY FUNCTIONS
# ============================================

def cosine_similarity(vec1, vec2):
    """
    Compute cosine similarity between two vectors.
    
    Args:
        vec1, vec2: Tensors of same shape
        
    Returns:
        Cosine similarity (float between -1 and 1)
    """
    dot_product = torch.dot(vec1, vec2)
    norm1 = torch.norm(vec1)
    norm2 = torch.norm(vec2)
    # Small epsilon avoids division by zero (e.g. an all-zero result vector)
    return (dot_product / (norm1 * norm2 + 1e-8)).item()

def find_similar_words(model, word, word_to_idx, idx_to_word, top_k=5):
    """
    Find the most similar words to a given word.
    
    Args:
        model: Trained Word2Vec model
        word: Word to find similar words for
        word_to_idx: Word to index mapping
        idx_to_word: Index to word mapping
        top_k: Number of similar words to return
        
    Returns:
        List of (word, similarity) tuples
    """
    if word not in word_to_idx:
        print(f"'{word}' not in vocabulary!")
        return []
    
    word_idx = word_to_idx[word]
    word_vec = model.get_word_embedding(word_idx)
    
    similarities = []
    
    for idx in range(len(idx_to_word)):
        if idx != word_idx:  # Skip the word itself
            other_vec = model.get_word_embedding(idx)
            sim = cosine_similarity(word_vec, other_vec)
            similarities.append((idx_to_word[idx], sim))
    
    # Sort by similarity (descending)
    similarities.sort(key=lambda x: -x[1])
    
    return similarities[:top_k]

# Test the similarity function
print("=" * 60)
print("TESTING COSINE SIMILARITY")
print("=" * 60)
print()

# Get embeddings for 'cat', 'dog', and 'the'
cat_vec = skipgram_ns.get_word_embedding(word_to_idx["cat"])
dog_vec = skipgram_ns.get_word_embedding(word_to_idx["dog"])
the_vec = skipgram_ns.get_word_embedding(word_to_idx["the"])

print(f"Cosine similarity (cat, dog): {cosine_similarity(cat_vec, dog_vec):.4f}")
print(f"Cosine similarity (cat, the): {cosine_similarity(cat_vec, the_vec):.4f}")
print(f"Cosine similarity (cat, cat): {cosine_similarity(cat_vec, cat_vec):.4f}")
print()
print("=" * 60)

In [None]:
# ============================================
# FIND SIMILAR WORDS FOR TEST WORDS
# ============================================

test_words = ["cat", "dog", "the", "sat", "love"]

print("=" * 70)
print("MOST SIMILAR WORDS (Skip-Gram with Negative Sampling)")
print("=" * 70)
print()

for word in test_words:
    if word in word_to_idx:
        print(f"Words most similar to '{word}':")
        similar = find_similar_words(skipgram_ns, word, word_to_idx, idx_to_word, top_k=5)
        for rank, (sim_word, score) in enumerate(similar, 1):
            print(f"  {rank}. '{sim_word}': {score:.4f}")
        print()

print("=" * 70)

### 8.2 t-SNE Visualization

**t-SNE** (t-Distributed Stochastic Neighbor Embedding) reduces high-dimensional embeddings to 2D for visualization.

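t-SNE tries to preserve *local* neighborhoods: words that are close in the embedding space should stay close in the 2D plot, but distances between far-apart clusters (and the axes themselves) carry little meaning.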

In [None]:
# ============================================
# 8.2 t-SNE VISUALIZATION
# ============================================

def visualize_embeddings_tsne(model, idx_to_word, title="Word Embeddings"):
    """
    Visualize word embeddings using t-SNE.
    
    Args:
        model: Trained Word2Vec model
        idx_to_word: Index to word mapping
        title: Plot title
    """
    # Collect all embeddings
    embeddings = []
    words = []
    
    for idx in range(len(idx_to_word)):
        embedding = model.get_word_embedding(idx).numpy()
        embeddings.append(embedding)
        words.append(idx_to_word[idx])
    
    embeddings = np.array(embeddings)
    
    print(f"Embeddings shape: {embeddings.shape}")
    
    # Apply t-SNE
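    # perplexity must be smaller than the number of points
    perplexity = min(5, len(words) - 1)
    
    # Iterations are left at scikit-learn's default (1000): the old `n_iter`
    # keyword was renamed to `max_iter` in scikit-learn 1.5 and later removed
    tsne = TSNE(n_components=2, random_state=SEED, perplexity=perplexity)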
    embeddings_2d = tsne.fit_transform(embeddings)
    
    # Create plot
    plt.figure(figsize=(12, 10))
    
    # Scatter plot
    plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], 
                c='steelblue', s=100, alpha=0.7, edgecolors='navy')
    
    # Add labels for each point
    for i, word in enumerate(words):
        plt.annotate(word, 
                    xy=(embeddings_2d[i, 0], embeddings_2d[i, 1]),
                    xytext=(5, 5), textcoords='offset points',
                    fontsize=11, fontweight='bold',
                    color='darkred')
    
    plt.title(title, fontsize=16, fontweight='bold')
    plt.xlabel('t-SNE Dimension 1', fontsize=12)
    plt.ylabel('t-SNE Dimension 2', fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# Visualize embeddings from Skip-Gram (Negative Sampling)
print("=" * 60)
print("t-SNE VISUALIZATION")
print("=" * 60)
print()

visualize_embeddings_tsne(skipgram_ns, idx_to_word, 
                          "Skip-Gram (Neg Sampling) Word Embeddings")

In [None]:
# Visualize CBOW embeddings for comparison
visualize_embeddings_tsne(cbow_model, idx_to_word, 
                          "CBOW Word Embeddings")

### 8.3 Model Comparison

Let's compare the similarity results from different models side by side.

In [None]:
# ============================================
# 8.3 MODEL COMPARISON
# ============================================

def compare_models(models_dict, test_words, word_to_idx, idx_to_word, top_k=3):
    """
    Compare similar words across multiple models.
    """
    print("=" * 90)
    print("MODEL COMPARISON: Most Similar Words")
    print("=" * 90)
    print()
    
    for word in test_words:
        if word not in word_to_idx:
            continue
            
        print(f"Similar to '{word}':")
        print("-" * 90)
        
        # Header
        header = f"{'Rank':<6}"
        for model_name in models_dict.keys():
            header += f"{model_name:<28}"
        print(header)
        print("-" * 90)
        
        # Get similar words from each model
        all_similar = {}
        for model_name, model in models_dict.items():
            all_similar[model_name] = find_similar_words(
                model, word, word_to_idx, idx_to_word, top_k
            )
        
        # Print row by row
        for i in range(top_k):
            row = f"{i+1:<6}"
            for model_name in models_dict.keys():
                if i < len(all_similar[model_name]):
                    w, s = all_similar[model_name][i]
                    # Pad the whole cell to the header width so columns line up
                    cell = f"'{w}' ({s:.3f})"
                    row += f"{cell:<28}"
                else:
                    row += f"{'N/A':<28}"
            print(row)
        
        print()
    
    print("=" * 90)

# Create dictionary of models
models = {
    "Skip-Gram (Basic)": skipgram_basic,
    "Skip-Gram (NegSamp)": skipgram_ns,
    "CBOW": cbow_model
}

# Compare models
compare_models(models, ["cat", "dog", "the"], word_to_idx, idx_to_word, top_k=5)

---
## 9. Playing with Results

Let's have some fun with our learned embeddings!

### 9.1 Word Arithmetic (Vector Operations)

One of the most famous properties of Word2Vec is that word embeddings can capture **semantic relationships** through vector arithmetic.

The classic example:
```
king - man + woman ≈ queen
```

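The intuition is that the offset from `man` to `woman` points in roughly the same direction as the offset from `king` to `queen`. Subtracting `man` from `king` and adding `woman` therefore lands near `queen`, which we recover as the nearest remaining word by cosine similarity (the input words themselves are excluded from the search).

Note: our corpus is tiny and pet-focused, so instead of king/queen we try combinations of words it actually contains. Expect noisy results, but let's try!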
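To see why the arithmetic can work when the geometry cooperates, here is a hand-built example with made-up 2-D vectors in which one axis plays the role of "royalty" and the other "male (+) / female (-)". It follows the same recipe as `word_arithmetic` below (add, subtract, then rank the remaining words by cosine similarity while excluding the input words); the vectors are invented for illustration, not learned.

```python
import torch
import torch.nn.functional as F

# Made-up 2-D vectors: axis 0 ~ "royalty", axis 1 ~ "male (+) / female (-)"
toy = {
    "king":  torch.tensor([1.0,  1.0]),
    "queen": torch.tensor([1.0, -1.0]),
    "man":   torch.tensor([0.0,  1.0]),
    "woman": torch.tensor([0.0, -1.0]),
    "girl":  torch.tensor([0.1, -0.9]),
    "apple": torch.tensor([-0.5, 0.0]),
}

def analogy(vectors, positive, negative, top_k=3):
    """Sum positive vectors, subtract negative ones, rank the other words by cosine."""
    result = sum(vectors[w] for w in positive) - sum(vectors[w] for w in negative)
    exclude = set(positive) | set(negative)
    scores = [(w, F.cosine_similarity(result, v, dim=0).item())
              for w, v in vectors.items() if w not in exclude]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)[:top_k]

# king - man + woman
print(analogy(toy, positive=["king", "woman"], negative=["man"]))
```

Here the offsets are exactly parallel, so `queen` matches perfectly. In a learned space they are only approximately parallel, which is why the target tends to appear as the nearest neighbor rather than an exact match, and usually only with enough training data.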

In [None]:
# ============================================
# 9.1 WORD ARITHMETIC
# ============================================

def word_arithmetic(model, positive_words, negative_words, word_to_idx, idx_to_word, top_k=5):
    """
    Perform word vector arithmetic: sum of positive words - sum of negative words.
    
    Example: word_arithmetic(model, ["king", "woman"], ["man"], word_to_idx, idx_to_word)
    should rank "queen" near the top, given embeddings trained on a large corpus.
    
    Args:
        model: Trained Word2Vec model
        positive_words: List of words to add
        negative_words: List of words to subtract
        word_to_idx: Word to index mapping
        idx_to_word: Index to word mapping
        top_k: Number of results to return
        
    Returns:
        List of (word, similarity) tuples
    """
    # Start with zero vector
    result_vec = torch.zeros(model.embedding_dim)
    
    # Build equation string for display
    equation = " + ".join(positive_words)
    if negative_words:
        equation += " - " + " - ".join(negative_words)
    
    print(f"Computing: {equation}")
    print("-" * 50)
    
    # Add positive word vectors
    for word in positive_words:
        if word in word_to_idx:
            vec = model.get_word_embedding(word_to_idx[word])
            result_vec += vec
            print(f"  + '{word}'")
        else:
            print(f"  Warning: '{word}' not in vocabulary!")
    
    # Subtract negative word vectors
    for word in negative_words:
        if word in word_to_idx:
            vec = model.get_word_embedding(word_to_idx[word])
            result_vec -= vec
            print(f"  - '{word}'")
        else:
            print(f"  Warning: '{word}' not in vocabulary!")
    
    print("-" * 50)
    
    # Find words most similar to the result vector
    # Exclude input words from results
    exclude = set(positive_words + negative_words)
    
    similarities = []
    for idx in range(len(idx_to_word)):
        word = idx_to_word[idx]
        if word not in exclude:
            word_vec = model.get_word_embedding(idx)
            sim = cosine_similarity(result_vec, word_vec)
            similarities.append((word, sim))
    
    # Sort by similarity
    similarities.sort(key=lambda x: -x[1])
    
    print(f"Result (top {top_k} matches):")
    for rank, (word, sim) in enumerate(similarities[:top_k], 1):
        print(f"  {rank}. '{word}': {sim:.4f}")
    
    return similarities[:top_k]

print("=" * 60)
print("WORD ARITHMETIC EXAMPLES")
print("=" * 60)
print()

# Try some word arithmetic with our vocabulary
# Since we have a pet-focused corpus, let's try relevant examples

print("Example 1: cat + dog (combining pet concepts)")
print()
word_arithmetic(skipgram_ns, ["cat", "dog"], [], word_to_idx, idx_to_word)

print()
print()

print("Example 2: cat - the (removing common word influence)")
print()
word_arithmetic(skipgram_ns, ["cat"], ["the"], word_to_idx, idx_to_word)

print()
print()

print("Example 3: sat + sleeping (action combination)")
print()
word_arithmetic(skipgram_ns, ["sat", "sleeping"], [], word_to_idx, idx_to_word)

print()
print("=" * 60)

### 9.2 Interactive Word Explorer

Let's create a function to explore a word's embedding across all models.

In [None]:
# ============================================
# 9.2 INTERACTIVE WORD EXPLORER
# ============================================

def explore_word(word, models_dict, word_to_idx, idx_to_word):
    """
    Explore a word's embedding across multiple models.
    
    Args:
        word: Word to explore
        models_dict: Dictionary of {model_name: model}
        word_to_idx: Word to index mapping
        idx_to_word: Index to word mapping
    """
    print("=" * 70)
    print(f"EXPLORING WORD: '{word}'")
    print("=" * 70)
    print()
    
    if word not in word_to_idx:
        print(f"'{word}' is not in the vocabulary!")
        return
    
    word_idx = word_to_idx[word]
    print(f"Word index: {word_idx}")
    print(f"Word frequency in corpus: {word_counts[word]}")
    print()
    
    for model_name, model in models_dict.items():
        print(f"--- {model_name} ---")
        
        # Get embedding
        embedding = model.get_word_embedding(word_idx)
        
        # Embedding statistics
        print(f"  Embedding shape: {embedding.shape}")
        print(f"  Embedding norm: {torch.norm(embedding).item():.4f}")
        print(f"  Min value: {embedding.min().item():.4f}")
        print(f"  Max value: {embedding.max().item():.4f}")
        print(f"  Mean value: {embedding.mean().item():.4f}")
        print(f"  First 10 values: {embedding[:10].numpy().round(4)}")
        
        # Similar words
        similar = find_similar_words(model, word, word_to_idx, idx_to_word, top_k=5)
        print(f"  Most similar: {[(w, f'{s:.3f}') for w, s in similar]}")
        print()
    
    print("=" * 70)

# Explore some words
for word in ["cat", "dog"]:
    explore_word(word, models, word_to_idx, idx_to_word)
    print()
    print()

### 9.3 Similarity Heatmap

Let's visualize the similarity between all words as a heatmap.

In [None]:
# ============================================
# 9.3 SIMILARITY HEATMAP
# ============================================

def plot_similarity_heatmap(model, idx_to_word, title="Word Similarity Matrix"):
    """
    Plot a heatmap of word similarities.
    """
    n_words = len(idx_to_word)
    
    # Compute similarity matrix
    sim_matrix = np.zeros((n_words, n_words))
    
    for i in range(n_words):
        vec_i = model.get_word_embedding(i)
        for j in range(n_words):
            vec_j = model.get_word_embedding(j)
            sim_matrix[i, j] = cosine_similarity(vec_i, vec_j)
    
    # Create heatmap
    fig, ax = plt.subplots(figsize=(12, 10))
    
    im = ax.imshow(sim_matrix, cmap='RdYlBu_r', aspect='auto', vmin=-1, vmax=1)
    
    # Add colorbar
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Cosine Similarity', fontsize=12)
    
    # Set ticks
    words = [idx_to_word[i] for i in range(n_words)]
    ax.set_xticks(range(n_words))
    ax.set_yticks(range(n_words))
    ax.set_xticklabels(words, rotation=45, ha='right', fontsize=9)
    ax.set_yticklabels(words, fontsize=9)
    
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Word', fontsize=12)
    ax.set_ylabel('Word', fontsize=12)
    
    plt.tight_layout()
    plt.show()

# Plot heatmap for Skip-Gram (Negative Sampling)
print("=" * 60)
print("WORD SIMILARITY HEATMAP")
print("=" * 60)
print()

plot_similarity_heatmap(skipgram_ns, idx_to_word, 
                        "Skip-Gram (Neg Sampling) Word Similarities")

---
## 10. Summary & Next Steps

Congratulations! You've successfully implemented Word2Vec from scratch in PyTorch!

### What We Learned

| Topic | Key Points |
|-------|-----------|
| **Data Preparation** | Tokenization, vocabulary building, word-to-index mappings |
| **Training Data** | Context windows, Skip-Gram pairs (center→context), CBOW pairs (context→center) |
| **Skip-Gram** | Predicts context words from center word, works better for rare words |
| **CBOW** | Predicts center word from context, faster training, works better for frequent words |
| **Negative Sampling** | Replaces the full softmax with K+1 binary classifications (true context word vs. K sampled noise words), so each update touches K+1 output vectors instead of all V |
| **Evaluation** | Cosine similarity, t-SNE visualization, word arithmetic |

### Key Equations Recap

**Skip-Gram Objective:**
$$P(w_o | w_c) = \frac{\exp(\vec{u}_{w_o}^T \vec{v}_{w_c})}{\sum_{i \in V} \exp(\vec{u}_i^T \vec{v}_{w_c})}$$

**CBOW Objective:**
$$P(w_c | w_{o_1}, ..., w_{o_{2m}}) = \frac{\exp(\vec{u}_{w_c}^T \bar{\vec{v}}_o)}{\sum_{i \in V} \exp(\vec{u}_i^T \bar{\vec{v}}_o)}$$

where $\bar{\vec{v}}_o$ is the average of context embeddings.

**Negative Sampling Loss:**
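$$\mathcal{L} = -\log \sigma(\vec{u}_{w_o}^T \vec{v}_{w_c}) - \sum_{k=1}^{K} \log \sigma(-\vec{u}_{w_k}^T \vec{v}_{w_c})$$

where $w_1, \ldots, w_K$ are the $K$ sampled negative (noise) words and $\sigma$ is the logistic sigmoid.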
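The Negative Sampling loss above translates almost term by term into PyTorch using `F.logsigmoid`, which is numerically safer than `torch.log(torch.sigmoid(x))`. A minimal sketch for a batch of center vectors `v_c` ($\vec{v}_{w_c}$), context vectors `u_o` ($\vec{u}_{w_o}$), and K negative vectors `u_neg` ($\vec{u}_{w_k}$); the function name, argument names, and shapes are assumptions for illustration, not the signature of `SkipGramNegSampling` defined earlier.

```python
import torch
import torch.nn.functional as F

def neg_sampling_loss(v_c, u_o, u_neg):
    """Batch-averaged negative sampling loss.

    v_c:   (B, D) center-word vectors
    u_o:   (B, D) true context-word vectors
    u_neg: (B, K, D) negative-sample vectors
    """
    pos_score = (u_o * v_c).sum(dim=1)                         # (B,)   u_o^T v_c
    neg_score = torch.bmm(u_neg, v_c.unsqueeze(2)).squeeze(2)  # (B, K) u_k^T v_c
    loss = -F.logsigmoid(pos_score) - F.logsigmoid(-neg_score).sum(dim=1)
    return loss.mean()

torch.manual_seed(0)
B, K, D = 4, 5, 8
loss = neg_sampling_loss(torch.randn(B, D), torch.randn(B, D), torch.randn(B, K, D))
print(loss.item())
```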

### Next Steps

Now that you understand the fundamentals, here's how to continue your learning:

#### 1. Try Larger Datasets
Our small corpus was great for learning, but real Word2Vec models train on:
- **Text8**: 100MB Wikipedia text
- **Wikipedia dump**: Billions of words
- **Google News**: about 100 billion words (used to train Google's released pre-trained Word2Vec vectors)

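```python
# Example: load the text8 dataset
# (download first from http://mattmahoney.net/dc/text8.zip)
import zipfile
with zipfile.ZipFile("text8.zip") as zf:
    tokens = zf.read("text8").decode("utf-8").split()
```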

#### 2. Use Pre-trained Embeddings
Instead of training from scratch, use embeddings trained on massive datasets:
- **Word2Vec** (Google): 3 million words, 300 dimensions
- **GloVe** (Stanford): Trained on several corpora (Wikipedia + Gigaword, Common Crawl, Twitter), in sizes from 25 to 300 dimensions
- **FastText** (Facebook): Handles out-of-vocabulary words

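```python
# Example with Gensim (the .bin file must be downloaded separately)
from gensim.models import KeyedVectors
model = KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin', binary=True)
# Analogy query: king - man + woman
print(model.most_similar(positive=['king', 'woman'], negative=['man'], topn=1))
```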

#### 3. Explore Modern Embeddings
Word2Vec was groundbreaking, but newer methods exist:
- **FastText**: Uses subword information (handles rare words better)
- **ELMo**: Contextualized embeddings from a bidirectional LSTM language model
- **BERT**: Transformer-based, context-dependent

#### 4. Apply to Downstream Tasks
Use your embeddings for:
- **Text Classification**: Sentiment analysis, spam detection
- **Named Entity Recognition**: Finding names, places, organizations
- **Machine Translation**: As input features
- **Information Retrieval**: Document similarity

### References

1. **Original Word2Vec Paper**: Mikolov et al., "Efficient Estimation of Word Representations in Vector Space" (2013)
2. **Negative Sampling Paper**: Mikolov et al., "Distributed Representations of Words and Phrases and their Compositionality" (2013)
3. **GloVe Paper**: Pennington et al., "GloVe: Global Vectors for Word Representation" (2014)

---

**Thank you for following along! Happy learning!**

In [None]:
# ============================================
# FINAL SUMMARY
# ============================================

print("=" * 70)
print("                    NOTEBOOK COMPLETE!")
print("=" * 70)
print()
print("What you accomplished:")
print("-" * 70)
print(f"  1. Built vocabulary from {len(corpus)} sentences ({vocab_size} unique words)")
print(f"  2. Generated {len(skipgram_pairs)} Skip-Gram training pairs")
print(f"  3. Generated {len(cbow_pairs)} CBOW training pairs")
print(f"  4. Implemented 3 models:")
print(f"     - Skip-Gram (Basic) with full softmax")
print(f"     - Skip-Gram with Negative Sampling")
print(f"     - CBOW")
print(f"  5. Trained all models for {NUM_EPOCHS} epochs")
print(f"  6. Learned {EMBEDDING_DIM}-dimensional word embeddings")
print(f"  7. Evaluated using cosine similarity and t-SNE visualization")
print(f"  8. Explored word arithmetic and similarity heatmaps")
print()
print("=" * 70)
print("                  Great job! Keep learning!")
print("=" * 70)