<a href="https://colab.research.google.com/github/abhishekkumawat23/word2vec-embedding-model-from-scratch/blob/main/word2vec_embedding_model_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Word2Vec Embedding Model from Scratch

In this notebook, we create a simple version of the word2vec embedding model from scratch.
The aim is educational - this contains simplifications and is not fully optimized.

## Important Context

Modern LLMs learn/train their embeddings along with the main LLM training itself, so they
don't need a separate embedding model. When you train an LLM, you simultaneously train:
1. The model to predict/generate the next token (main objective)
2. The embedding vectors for the vocabulary

For non-LLM cases or pre-trained embeddings, we implement word2vec here. Other alternatives
exist (like GloVe), but we focus on word2vec with skip-gram and negative sampling.

## Algorithm Overview

**Core Task:**
Word2vec is a binary classification task that takes a pair of words as input and predicts
whether they belong to the same context. Instead of returning discrete 0 or 1, it outputs
the probability that the pair shares the same context.

**Creating Training Pairs:**

1. **Positive Pairs (Skip-gram approach):**
   - Use a sliding window of fixed size over the text corpus
   - For each window, the center word becomes the first token
   - All other words in the window become second tokens
   - Create pairs: (center_word, context_word) with label 1
   - Example: "the quick brown fox" with window_size=2
     - Center="brown" → pairs: (brown, the), (brown, quick), (brown, fox)

2. **Negative Pairs (Negative sampling):**
   - Keep the same center word from the sliding window
   - Randomly sample words from the entire vocabulary as second tokens
   - These become negative pairs with label 0
   - Risk: Might accidentally sample an actual context word, but probability is low
   - Typically sample 5-20 negative pairs per positive pair

**Training Input:**
For each center word, we provide:
- 1 positive pair (actual context word) with target = 1
- N negative pairs (random words) with target = 0

**Model Architecture (Shallow - No Hidden Layers):**

1. **Two Embedding Matrices:**
   - Center word embeddings: [vocab_size × embedding_dim]
   - Context word embeddings: [vocab_size × embedding_dim]
   - Each word has two representations (one as center, one as context)

2. **Forward Pass:**
   - Lookup embeddings for center word and context word
   - Compute dot product: similarity = center_embed · context_embed
   - Dot product measures semantic similarity between embeddings
   - Apply sigmoid activation: probability = 1 / (1 + e^(-similarity))

3. **Why Sigmoid?**
   - Introduces non-linearity (enables learning complex semantic relationships)
   - Maps output to [0, 1] range (perfect for probability)
   - Output represents: P(words are in same context)

4. **Loss and Optimization:**
   - Binary Cross-Entropy Loss: measures difference between predicted probability and target
   - Adam optimizer: updates embedding matrices to minimize loss
   - Alternative: SGD or other optimizers work too

**Final Embeddings:**
After training, average the center and context embeddings to get the final word vectors:
- final_embedding[word] = (center_embedding[word] + context_embedding[word]) / 2
- This captures both perspectives and typically gives better results

**Key Insight:**
No neural network layers needed! The embeddings themselves are the parameters being learned.
Through the training objective, words that appear in similar contexts naturally end up with
similar embedding vectors.

# [WARNING] Generated by Claude

Below code is only for reference to be used when I actually implement the word2vec. Below code is generated by claude when i asked to create a simple version of word2vec. Remove it once you have implemented the word2vec by yourself.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from collections import Counter
import numpy as np

class Word2Vec(nn.Module):
    """
    Simple Word2Vec implementation with Skip-gram and Negative Sampling
    """
    def __init__(self, vocab_size, embedding_dim):
        super(Word2Vec, self).__init__()

        # Center word embeddings (input)
        self.center_embeddings = nn.Embedding(vocab_size, embedding_dim)

        # Context word embeddings (output)
        self.context_embeddings = nn.Embedding(vocab_size, embedding_dim)

        # Initialize embeddings with small random values
        self.center_embeddings.weight.data.uniform_(-0.5/embedding_dim, 0.5/embedding_dim)
        self.context_embeddings.weight.data.uniform_(-0.5/embedding_dim, 0.5/embedding_dim)

    def forward(self, center_words, context_words):
        """
        Forward pass: compute probability that word pairs are in same context

        Args:
            center_words: [batch_size] - indices of center words
            context_words: [batch_size, num_samples] - indices of context/negative words
                          (first sample is positive, rest are negative)

        Returns:
            [batch_size, num_samples] - probabilities for each pair
        """
        # Get center word embeddings: [batch_size, embedding_dim]
        center_embeds = self.center_embeddings(center_words)

        # Get context word embeddings: [batch_size, num_samples, embedding_dim]
        context_embeds = self.context_embeddings(context_words)

        # Compute dot product (similarity measure): [batch_size, num_samples]
        # Using einsum for clarity: 'be' (batch, embed) × 'bse' (batch, samples, embed) -> 'bs'
        scores = torch.einsum('be,bse->bs', center_embeds, context_embeds)

        # Apply sigmoid to convert to probability [0, 1]
        # High score (similar embeddings) → probability near 1
        # Low score (dissimilar embeddings) → probability near 0
        probs = torch.sigmoid(scores)

        return probs


class Word2VecTrainer:
    """
    Trainer for Word2Vec model with data preparation and training loop
    """
    def __init__(self, embedding_dim=100, window_size=5, num_negative_samples=5,
                 min_count=5, learning_rate=0.025):
        self.embedding_dim = embedding_dim
        self.window_size = window_size
        self.num_negative_samples = num_negative_samples
        self.min_count = min_count
        self.learning_rate = learning_rate

        self.word2idx = {}
        self.idx2word = {}
        self.word_counts = None
        self.model = None

    def build_vocab(self, sentences):
        """Build vocabulary from sentences"""
        # Count word frequencies
        word_counts = Counter()
        for sentence in sentences:
            word_counts.update(sentence.lower().split())

        # Filter by min_count
        word_counts = {word: count for word, count in word_counts.items()
                      if count >= self.min_count}

        # Create word-to-index mapping
        self.word2idx = {word: idx for idx, word in enumerate(word_counts.keys())}
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}

        # Store word counts for negative sampling (raised to 0.75 power)
        self.word_counts = np.array([word_counts[self.idx2word[i]]
                                     for i in range(len(self.word2idx))], dtype=np.float64)
        self.word_counts = np.power(self.word_counts, 0.75)
        self.word_counts /= self.word_counts.sum()

        print(f"Vocabulary size: {len(self.word2idx)}")

        return len(self.word2idx)

    def generate_training_data(self, sentences, batch_size=512):
        """
        Generate training batches with positive and negative samples

        For each center word in a sliding window:
        - Create 1 positive pair with actual context word (label=1)
        - Create N negative pairs with random words (label=0)

        Returns batches of:
        - center_words: [batch_size]
        - context_words: [batch_size, 1 + num_negative_samples]
        - labels: [batch_size, 1 + num_negative_samples]
        """
        for sentence in sentences:
            words = sentence.lower().split()
            word_indices = [self.word2idx[w] for w in words if w in self.word2idx]

            if len(word_indices) < 2:
                continue

            # Generate skip-gram pairs using sliding window
            for center_pos, center_word in enumerate(word_indices):
                # Define context window around center word
                start = max(0, center_pos - self.window_size)
                end = min(len(word_indices), center_pos + self.window_size + 1)

                # For each word in the window (except center itself)
                for context_pos in range(start, end):
                    if context_pos == center_pos:
                        continue

                    context_word = word_indices[context_pos]

                    # Create POSITIVE sample (words that actually appear together)
                    center_batch = [center_word]
                    context_batch = [[context_word]]  # First sample is positive
                    label_batch = [[1.0]]  # Label 1 = same context

                    # Add NEGATIVE samples (random words from vocabulary)
                    # Sampled with frequency^0.75 to give rare words better chance
                    negative_samples = np.random.choice(
                        len(self.word2idx),
                        size=self.num_negative_samples,
                        replace=False,
                        p=self.word_counts  # Weighted by word frequency^0.75
                    )

                    # Note: Small chance negative sample is actually a context word
                    # but probability is low enough to ignore
                    context_batch[0].extend(negative_samples.tolist())
                    label_batch[0].extend([0.0] * self.num_negative_samples)

                    yield (
                        torch.LongTensor(center_batch),
                        torch.LongTensor(context_batch),
                        torch.FloatTensor(label_batch)
                    )

    def train(self, sentences, epochs=5):
        """Train the Word2Vec model"""
        vocab_size = self.build_vocab(sentences)

        # Initialize model
        self.model = Word2Vec(vocab_size, self.embedding_dim)
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

        # Binary Cross-Entropy Loss for binary classification (same context or not)
        # Loss = -[y*log(p) + (1-y)*log(1-p)] where y is target, p is predicted probability
        criterion = nn.BCELoss()

        self.model.train()

        for epoch in range(epochs):
            total_loss = 0
            batch_count = 0

            for center, context, labels in self.generate_training_data(sentences):
                optimizer.zero_grad()

                # Forward pass: get probabilities for each pair
                probs = self.model(center, context)

                # Compute loss between predicted probabilities and true labels
                loss = criterion(probs, labels)

                # Backward pass: update embedding matrices
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                batch_count += 1

            avg_loss = total_loss / batch_count if batch_count > 0 else 0
            print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")

    def get_word_vector(self, word):
        """
        Get embedding vector for a word by averaging center and context embeddings

        Args:
            word: The word to get embedding for
        """
        if word not in self.word2idx:
            return None
        idx = self.word2idx[word]

        # Average both embedding matrices for best results
        center_vec = self.model.center_embeddings.weight[idx].detach().numpy()
        context_vec = self.model.context_embeddings.weight[idx].detach().numpy()
        return (center_vec + context_vec) / 2

    def get_embedding_matrix(self):
        """
        Get the final embedding matrix [vocab_size x embedding_dim]

        This is what you'd save and use for downstream tasks.
        Averages center and context embeddings for best representation.

        Why average both matrices?
        - Center embeddings learn: "when I'm the query word"
        - Context embeddings learn: "when I'm the context word"
        - Both capture similar semantic relationships but from different perspectives
        - Averaging combines both views for richer representations
        """
        center_matrix = self.model.center_embeddings.weight.detach().numpy()
        context_matrix = self.model.context_embeddings.weight.detach().numpy()
        return (center_matrix + context_matrix) / 2

    def find_similar_words(self, word, top_k=5):
        """Find most similar words using cosine similarity"""
        if word not in self.word2idx:
            return []

        word_vec = self.get_word_vector(word)

        # Get all word embeddings
        all_embeddings = self.model.center_embeddings.weight.detach().numpy()

        # Compute cosine similarities
        similarities = np.dot(all_embeddings, word_vec) / (
            np.linalg.norm(all_embeddings, axis=1) * np.linalg.norm(word_vec)
        )

        # Get top-k indices (excluding the word itself)
        similar_indices = similarities.argsort()[::-1][1:top_k+1]

        return [(self.idx2word[idx], similarities[idx]) for idx in similar_indices]


# Example usage
if __name__ == "__main__":
    # Sample corpus
    sentences = [
        "the quick brown fox jumps over the lazy dog",
        "the dog is very lazy and sleeps all day",
        "the cat is quick and agile",
        "a quick brown fox runs through the forest",
        "the lazy cat sleeps on the couch",
        "dogs and cats are great pets",
        "the fox is clever and quick",
        "brown bears live in the forest"
    ] * 20  # Repeat for more training data

    # Initialize and train
    trainer = Word2VecTrainer(
        embedding_dim=50,
        window_size=2,
        num_negative_samples=5,
        min_count=2,
        learning_rate=0.01
    )

    print("Training Word2Vec model...")
    trainer.train(sentences, epochs=10)

    # Test the model
    print("\nWord vectors learned!")

    # Get the final embedding matrix (averaged from both center and context)
    embedding_matrix = trainer.get_embedding_matrix()
    print(f"\nFinal embedding matrix shape: {embedding_matrix.shape}")
    print(f"(vocab_size={len(trainer.word2idx)}, embedding_dim={trainer.embedding_dim})")

    print("\nSimilar words to 'dog':")
    similar = trainer.find_similar_words('dog', top_k=3)
    for word, similarity in similar:
        print(f"  {word}: {similarity:.4f}")

    print("\nSimilar words to 'quick':")
    similar = trainer.find_similar_words('quick', top_k=3)
    for word, similarity in similar:
        print(f"  {word}: {similarity:.4f}")

Training Word2Vec model...
Vocabulary size: 31
Epoch 1/10, Average Loss: 0.4063
Epoch 2/10, Average Loss: 0.3408
Epoch 3/10, Average Loss: 0.3284
Epoch 4/10, Average Loss: 0.3226
Epoch 5/10, Average Loss: 0.3156
Epoch 6/10, Average Loss: 0.3132
Epoch 7/10, Average Loss: 0.3094
Epoch 8/10, Average Loss: 0.3094
Epoch 9/10, Average Loss: 0.3121
Epoch 10/10, Average Loss: 0.3092

Word vectors learned!

Final embedding matrix shape: (31, 50)
(vocab_size=31, embedding_dim=50)

Similar words to 'dog':
  lazy: 0.4625
  very: 0.3182
  is: 0.2009

Similar words to 'quick':
  quick: 0.2955
  a: 0.2879
  agile: 0.2383
