# Day 3: Embeddings - Part 3: Training Embeddings with Context

In this notebook, we'll explore how embeddings become meaningful when trained with context, and implement a simple skip-gram model that learns one (static) vector per word from the words that surround it.

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.decomposition import PCA
import seaborn as sns

# Plot styling: matplotlib's bundled seaborn-compatible style sheet plus the "husl" colour palette
plt.style.use("seaborn-v0_8")
sns.set_palette("husl")

## 4. Training Embeddings with Context

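Embeddings become meaningful when trained with context. Let's implement a simple skip-gram model for training embeddings: by learning which words tend to appear near each other, words used in similar contexts end up with similar vectors.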

In [None]:
class ContextualEmbedding(nn.Module):
    """Skip-gram embedding model with separate target and context tables.

    Despite the name, the learned vectors are static (one vector per word id);
    "contextual" refers to how they are trained: by scoring (target, context)
    word-id pairs taken from a sliding window over text.

    Args:
        vocab_size: Number of rows in each embedding table (valid ids are
            ``0 .. vocab_size - 1``).
        embedding_dim: Length of each embedding vector.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        # One table for a word acting as the centre ("target") word and a second
        # table for the same word acting as a neighbouring ("context") word.
        self.target_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.context_embeddings = nn.Embedding(vocab_size, embedding_dim)

        # Small uniform init in [-0.5/dim, 0.5/dim] keeps initial dot-product
        # scores close to zero.
        bound = 0.5 / embedding_dim
        nn.init.uniform_(self.target_embeddings.weight, -bound, bound)
        nn.init.uniform_(self.context_embeddings.weight, -bound, bound)

    def forward(self, target_words, context_words):
        """Return the dot-product score of each (target, context) id pair.

        Args:
            target_words: Integer tensor of target word ids.
            context_words: Integer tensor of context word ids, same shape as
                ``target_words``.

        Returns:
            Tensor of raw (pre-sigmoid) scores with the same shape as the inputs,
            e.g. ``[batch_size]`` for 1-D id tensors.
        """
        target_emb = self.target_embeddings(target_words)  # [..., emb_dim]
        context_emb = self.context_embeddings(context_words)  # [..., emb_dim]

        # Reduce over the embedding axis (last), so any batch shape is supported.
        return torch.sum(target_emb * context_emb, dim=-1)

    def get_embedding(self, word_id):
        """Return the learned target embedding for ``word_id`` as a NumPy array.

        The result is an independent copy: modifying it does not touch the
        model weights, and further training does not change it.
        """
        # .cpu() so this also works after the model is moved to an accelerator;
        # .copy() because .numpy() on a CPU tensor shares memory with the weights.
        return self.target_embeddings.weight[word_id].detach().cpu().numpy().copy()

### Creating Training Data

Let's create training data from a small corpus:

In [None]:
def create_training_data(sentences, vocab, window_size=2):
    """Build skip-gram (target_id, context_id) pairs from whitespace-split sentences.

    For every in-vocabulary word, each other in-vocabulary word at most
    ``window_size`` positions away in the same sentence yields one pair.
    Out-of-vocabulary words are skipped both as targets and as contexts.
    Pairs are returned in sentence order, then target position, then
    context position.
    """
    pairs = []

    for sentence in sentences:
        tokens = sentence.split()
        n_tokens = len(tokens)

        for center, word in enumerate(tokens):
            if word not in vocab:
                continue

            center_id = vocab[word]
            lo = max(0, center - window_size)
            hi = min(n_tokens, center + window_size + 1)

            pairs.extend(
                (center_id, vocab[tokens[pos]])
                for pos in range(lo, hi)
                if pos != center and tokens[pos] in vocab
            )

    return pairs

# Example training setup: a tiny toy corpus of short sentences
sentences = [
    "the cat sat on the mat",
    "the dog ran in the park",
    "cats and dogs are pets",
    "the park has many trees",
    "trees provide shade in summer",
    "the cat and dog played in the park",
    "birds fly in the sky above trees",
    "fish swim in the water",
    "pets like to play with toys",
    "children play in the park with dogs"
]

# Build vocabulary: every previously unseen word gets the next integer id,
# so ids are 0..len(vocab)-1 in first-seen order.
vocab = {}
for sentence in sentences:
    for word in sentence.split():
        # len(vocab) is evaluated before insertion, so it is the next free id.
        vocab.setdefault(word, len(vocab))

print(f"Vocabulary size: {len(vocab)}")
print(f"Vocabulary: {vocab}")

# Skip-gram pairs using two words of context on each side
training_pairs = create_training_data(sentences, vocab, window_size=2)
print(f"Training pairs: {len(training_pairs)}")
print(f"Sample pairs (target_id, context_id): {training_pairs[:5]}")

### Training the Embeddings

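Let's train our embeddings using negative sampling: for every real (target, context) pair we also draw a few random "negative" context words from the vocabulary, and train the model to give the real pair a high score and the random pairs a low score.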

In [None]:
def train_embeddings(training_pairs, vocab_size, embedding_dim=50, num_epochs=100, batch_size=64, num_negative=5,
                     lr=0.01, plot=True):
    """Train skip-gram embeddings with negative sampling (SGNS).

    For every observed (target, context) pair the model is pushed to give a
    high score. For ``num_negative`` context ids drawn uniformly at random from
    the vocabulary, it is pushed to give a low score. The per-pair loss is

        -log sigmoid(s_pos) - sum_k log sigmoid(-s_neg_k)

    averaged over the mini-batch.

    Args:
        training_pairs: Sequence of ``(target_id, context_id)`` integer pairs.
        vocab_size: Number of word ids (negatives are drawn from ``0..vocab_size-1``).
        embedding_dim: Length of each embedding vector.
        num_epochs: Number of full passes over ``training_pairs``.
        batch_size: Number of positive pairs per optimisation step.
        num_negative: Negative samples per positive pair. They are drawn
            uniformly and may occasionally coincide with the true context.
        lr: Adam learning rate.
        plot: If True, show the per-epoch loss curve after training.

    Returns:
        The trained ``ContextualEmbedding`` model.

    Raises:
        ValueError: If ``training_pairs`` is empty.
    """
    num_pairs = len(training_pairs)
    if num_pairs == 0:
        raise ValueError("training_pairs must not be empty")

    model = ContextualEmbedding(vocab_size, embedding_dim)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # nn.Embedding needs integer indices.
    target_words = torch.tensor([pair[0] for pair in training_pairs], dtype=torch.long)
    context_words = torch.tensor([pair[1] for pair in training_pairs], dtype=torch.long)
    num_batches = (num_pairs + batch_size - 1) // batch_size  # ceiling division

    losses = []

    for epoch in range(num_epochs):
        # Shuffle data
        indices = torch.randperm(num_pairs)
        target_words = target_words[indices]
        context_words = context_words[indices]

        total_loss = 0.0

        for start_idx in range(0, num_pairs, batch_size):
            target_batch = target_words[start_idx:start_idx + batch_size]
            context_batch = context_words[start_idx:start_idx + batch_size]
            batch_size_actual = len(target_batch)

            negative_samples = torch.randint(0, vocab_size, (batch_size_actual * num_negative,))

            optimizer.zero_grad()

            # logsigmoid is numerically stable. log(sigmoid(x)) underflows to
            # log(0) = -inf once scores grow large, which yields inf/NaN losses.
            positive_scores = model(target_batch, context_batch)
            positive_loss = -nn.functional.logsigmoid(positive_scores).mean()

            # repeat_interleave lays negatives out as [t0]*k + [t1]*k + ..., so a
            # reshape to (batch, k) groups each positive pair's k negatives in a row.
            negative_target = target_batch.repeat_interleave(num_negative)
            negative_scores = model(negative_target, negative_samples)
            negative_loss = -(
                nn.functional.logsigmoid(-negative_scores)
                .reshape(batch_size_actual, num_negative)
                .sum(dim=1)  # sum over the k negatives of each pair (SGNS objective)
                .mean()      # then average over pairs, like the positive term
            )

            loss = positive_loss + negative_loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / num_batches
        losses.append(avg_loss)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

    if plot:
        plt.figure(figsize=(10, 5))
        plt.plot(losses)
        plt.title('Training Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.grid(True, alpha=0.3)
        plt.show()

    return model

# Train the embeddings; a small dimension keeps the later visualisation manageable
embedding_dim = 20
trained_model = train_embeddings(
    training_pairs,
    len(vocab),
    embedding_dim=embedding_dim,
    num_epochs=100,
)

### Analyzing Trained Embeddings

Let's analyze our trained embeddings to see if they've captured semantic relationships:

In [None]:
def analyze_embedding_properties(model, vocab, top_k=10):
    """Print embedding statistics and the most similar word pairs, and plot a 2-D PCA view.

    Args:
        model: Object with ``get_embedding(word_id)`` returning a 1-D vector
            (e.g. ``ContextualEmbedding``).
        vocab: Mapping word -> id, where the ids are exactly ``0..len(vocab)-1``.
        top_k: Number of most-similar word pairs to print.

    Returns:
        Tuple ``(all_embeddings, similarities, embeddings_2d)``:
        the ``(V, D)`` embedding matrix, the ``(V, V)`` cosine-similarity matrix
        (returned unmodified, including its diagonal), and the ``(V, 2)`` PCA
        projection.
    """
    # Reverse mapping id -> word; relies on ids being 0..len(vocab)-1.
    vocab_reverse = {idx: word for word, idx in vocab.items()}
    word_ids = range(len(vocab))
    labels = [vocab_reverse[word_id] for word_id in word_ids]
    all_embeddings = np.array([model.get_embedding(word_id) for word_id in word_ids])

    # Compute embedding statistics
    norms = np.linalg.norm(all_embeddings, axis=1)
    print("Embedding Statistics:")
    print(f"Shape: {all_embeddings.shape}")
    print(f"Mean norm: {np.mean(norms):.4f}")
    print(f"Std norm: {np.std(norms):.4f}")

    # Find most similar word pairs
    print("\nMost similar word pairs:")
    similarities = cosine_similarity(all_embeddings)

    # Rank only the strict upper triangle (i < j). Self-pairs therefore never
    # appear, and the similarity matrix returned to the caller is left intact.
    rows, cols = np.triu_indices(len(labels), k=1)
    top_pairs = sorted(
        ((similarities[i, j], labels[i], labels[j]) for i, j in zip(rows, cols)),
        reverse=True,
    )

    for sim, word1, word2 in top_pairs[:top_k]:
        print(f"{word1} - {word2}: {sim:.4f}")

    # Visualize with PCA
    pca = PCA(n_components=2)
    embeddings_2d = pca.fit_transform(all_embeddings)

    plt.figure(figsize=(12, 10))
    plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], alpha=0.7)

    # Add labels
    for i, label in enumerate(labels):
        plt.annotate(label, (embeddings_2d[i, 0], embeddings_2d[i, 1]),
                     fontsize=12, alpha=0.8)

    plt.title('PCA of Trained Word Embeddings')
    plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
    plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
    plt.grid(True, alpha=0.3)
    plt.show()

    return all_embeddings, similarities, embeddings_2d

# Inspect the trained vectors: statistics, nearest word pairs and a 2-D PCA view
embeddings, similarities, embeddings_2d = analyze_embedding_properties(
    trained_model, vocab
)

### Finding Semantic Relationships

Let's see if our embeddings have captured semantic relationships by performing vector arithmetic:

In [None]:
def vector_analogy(word_a, word_b, word_c, vocab, model, top_n=5):
    """Solve the analogy "a is to b as c is to ?" by vector arithmetic.

    The query vector is ``emb(b) - emb(a) + emb(c)``. Vocabulary words are
    ranked by cosine similarity to it, and the three input words themselves
    are never returned.

    Args:
        word_a, word_b, word_c: The analogy words.
        vocab: Mapping word -> id, where the ids are exactly ``0..len(vocab)-1``.
        model: Object with ``get_embedding(word_id)`` returning a 1-D vector.
        top_n: Maximum number of candidates to return.

    Returns:
        A list of up to ``top_n`` ``(word, similarity)`` tuples, best first.
        If any input word is missing from ``vocab``, returns the string
        ``"One or more words not in vocabulary"`` instead (kept for backward
        compatibility).
    """
    if word_a not in vocab or word_b not in vocab or word_c not in vocab:
        return "One or more words not in vocabulary"

    a_id, b_id, c_id = vocab[word_a], vocab[word_b], vocab[word_c]

    # Vector arithmetic: b - a + c
    target_vector = model.get_embedding(b_id) - model.get_embedding(a_id) + model.get_embedding(c_id)

    vocab_reverse = {idx: word for word, idx in vocab.items()}
    all_embeddings = np.array([model.get_embedding(i) for i in range(len(vocab))])
    similarities = cosine_similarity([target_vector], all_embeddings)[0]

    # Drop the input words from the ranking entirely. Merely down-weighting them
    # would still return them whenever the vocabulary has few other words.
    excluded = {a_id, b_id, c_id}
    ranked = [int(idx) for idx in np.argsort(similarities)[::-1] if int(idx) not in excluded]

    return [(vocab_reverse[idx], similarities[idx]) for idx in ranked[:top_n]]

# Try some analogies; every word must appear in the training vocabulary.
analogies = [
    ('cat', 'cats', 'dog'),    # singular:plural
    ('dog', 'park', 'fish'),   # animal:habitat
    ('cat', 'pets', 'trees'),  # random test (no obvious relation)
]

for a, b, c in analogies:
    # Report, rather than silently skip, analogies that use unknown words.
    missing = [w for w in (a, b, c) if w not in vocab]
    if missing:
        print(f"Skipping {a} : {b} :: {c} : ? (not in vocabulary: {', '.join(missing)})")
        print()
        continue

    results = vector_analogy(a, b, c, vocab, trained_model)
    print(f"{a} : {b} :: {c} : ?")
    for word, sim in results:
        print(f"  {word}: {sim:.4f}")
    print()

## Key Takeaways

1. **Context is Key**: Embeddings become meaningful when trained with contextual information
2. **Skip-gram Model**: Predicts context words from target words, capturing semantic relationships
3. **Vector Arithmetic**: Semantic relationships can be expressed through vector operations
4. **Visualization**: PCA helps understand the structure of the embedding space
5. **Training Data**: The quality and quantity of training data significantly impacts embedding quality

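In our small example (ten short sentences and only a few dozen distinct words), we should not expect reliable analogies: there is far too little data for consistent relationships to emerge. In practice, embeddings are trained on billions of words to capture rich semantic relationships.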