In [21]:
import logging
import os
import time
import urllib.request
import zipfile
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Configure the root logger at INFO so the module-level `logger` messages
# (vocabulary size, per-batch/per-epoch training progress) are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [22]:
class Word2Vec:
    """Hyperparameter container plus corpus-preparation helpers for Word2Vec.

    Stores the training settings passed to ``__init__`` and, once
    ``build_vocab`` has run, the ``word_to_ix`` / ``ix_to_word`` mappings.
    Index 0 is always reserved for the ``'<UNK>'`` token.
    """

    def __init__(self, embedding_dim=300, window_size=5, min_count=5, 
                 batch_size=512, epochs=5, learning_rate=0.001):
        self.embedding_dim = embedding_dim
        self.window_size = window_size
        self.min_count = min_count
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate

        # Populated by build_vocab().
        self.word_to_ix = None
        self.ix_to_word = None
        # Not assigned by any method of this class; presumably set by callers.
        self.model = None

    def download_text8(self):
        """Download and extract the text8 corpus to ./text8 unless it already exists.

        Extraction uses the standard-library ``zipfile`` module instead of
        shelling out to ``unzip``: it needs no external tool, and a failed or
        corrupt extraction raises instead of being silently ignored (the exit
        status of ``os.system`` was never checked).
        """
        if os.path.exists('text8'):
            return
        logger.info("Downloading text8 dataset...")
        url = 'http://mattmahoney.net/dc/text8.zip'
        urllib.request.urlretrieve(url, 'text8.zip')
        with zipfile.ZipFile('text8.zip') as archive:
            archive.extractall()
        logger.info("Download complete")

    def preprocess_text(self, text, max_tokens=None):
        """Split ``text`` on whitespace, optionally keeping only the leading tokens.

        Args:
            text: Raw corpus string.
            max_tokens: Keep at most this many tokens; ``None`` keeps all of them.
                Compared against ``None`` explicitly so that ``0`` yields an
                empty list rather than silently meaning "no limit".

        Returns:
            List of token strings.
        """
        tokens = text.split()
        return tokens if max_tokens is None else tokens[:max_tokens]

    def build_vocab(self, tokens):
        """Build the vocabulary from tokens occurring at least ``min_count`` times.

        Sets ``self.word_to_ix`` (word -> index, ``'<UNK>'`` at 0, kept words at
        1..len(vocab) in first-occurrence order) and ``self.ix_to_word`` (its
        inverse).

        Args:
            tokens: Sequence of token strings.

        Returns:
            The kept words (excluding ``'<UNK>'``), in first-occurrence order.
        """
        word_counts = Counter(tokens)
        # '<UNK>' is reserved for index 0. If the corpus itself contains a
        # frequent literal '<UNK>', adding it again would make the dict
        # comprehension remap it to a later index, orphaning index 0 and
        # producing an index equal to len(word_to_ix) (out of range for an
        # embedding table of that size).
        vocab = [word for word, count in word_counts.items()
                 if count >= self.min_count and word != '<UNK>']

        self.word_to_ix = {word: i for i, word in enumerate(['<UNK>'] + vocab)}
        self.ix_to_word = {i: word for word, i in self.word_to_ix.items()}

        logger.info(f"Vocabulary size: {len(self.word_to_ix)}")
        return vocab

    def subsample_frequent_words(self, tokens, threshold=1e-5, rng=None):
        """Randomly drop frequent tokens, Word2Vec-paper style.

        Each occurrence of word ``w`` with corpus frequency ``f(w)`` is dropped
        with probability ``1 - sqrt(threshold / f(w))``; when that value is
        negative (rare words) the token is always kept.

        Args:
            tokens: Sequence of token strings.
            threshold: Subsampling threshold ``t``.
            rng: Optional object with a ``random()`` method returning floats in
                [0, 1) (e.g. ``np.random.default_rng(seed)``) for reproducible
                draws. ``None`` uses the global ``np.random`` state as before.

        Returns:
            The kept tokens, in their original order.
        """
        draw = np.random.random if rng is None else rng.random
        word_counts = Counter(tokens)
        total_count = len(tokens)
        word_freq = {word: count/total_count for word, count in word_counts.items()}

        prob_drop = {word: 1 - np.sqrt(threshold/freq) 
                    for word, freq in word_freq.items()}

        # One draw per token, in order, so seeded runs are reproducible.
        return [token for token in tokens 
                if draw() > prob_drop[token]]

In [23]:
class SkipGramDataset(Dataset):
    """(centre word, context word) pairs for skip-gram training.

    Every position in ``tokens`` contributes one pair per neighbour lying at
    most ``window_size`` positions to its left or right (fewer at the corpus
    edges), left neighbours first. Pairs hold raw tokens; they are converted
    to vocabulary indices lazily in ``__getitem__``, with out-of-vocabulary
    words mapped to the ``'<UNK>'`` index.
    """

    def __init__(self, tokens, word_to_ix, window_size=5):
        self.tokens = tokens
        self.word_to_ix = word_to_ix
        self.window_size = window_size
        self.pairs = self._generate_pairs()

    def _generate_pairs(self):
        """Enumerate every (centre, context) token pair inside the window."""
        corpus = self.tokens
        span = self.window_size
        total = len(corpus)
        return [
            (centre, corpus[pos])
            for i, centre in enumerate(corpus)
            for pos in range(max(0, i - span), min(total, i + span + 1))
            if pos != i
        ]

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(target_ix, context_ix)`` as 0-dim integer tensors."""
        centre_word, context_word = self.pairs[idx]
        fallback = self.word_to_ix['<UNK>']
        lookup = self.word_to_ix.get
        centre_ix = lookup(centre_word, fallback)
        context_ix = lookup(context_word, fallback)
        return torch.tensor(centre_ix), torch.tensor(context_ix)

In [25]:
# Cell 4: Model Class
class SkipGramModel(nn.Module):
    """Skip-gram model with a full-softmax output layer.

    Looks up an embedding for each centre-word index, projects it to one logit
    per vocabulary entry, and returns log-probabilities over the vocabulary
    (the input format expected by ``nn.NLLLoss``). Note the output layer costs
    O(vocab_size * embedding_dim) per example.

    Args:
        vocab_size: Number of embedding rows and of output classes.
        embedding_dim: Size of each word embedding.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.output = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Return log-probabilities over the vocabulary for index tensor ``x``.

        ``x`` may have any shape (a scalar, ``(batch,)``, ``(batch, seq)``, ...);
        the result has shape ``x.shape + (vocab_size,)``. Normalising over the
        last dimension instead of the previously hard-coded ``dim=1`` gives the
        same result for ``(batch,)`` input and also works for other shapes.
        """
        logits = self.output(self.embeddings(x))
        return torch.log_softmax(logits, dim=-1)

In [26]:
# Cell 5: Trainer Class
class Word2VecTrainer:
    """Mini-batch trainer for a model that outputs log-probabilities.

    Args:
        model: Module mapping a batch of target indices to log-probabilities
            over the vocabulary (e.g. ``SkipGramModel``).
        dataset: Dataset yielding ``(target_ix, context_ix)`` pairs.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(self, model, dataset, learning_rate=0.001):
        self.model = model
        self.dataset = dataset
        self.learning_rate = learning_rate

    def train(self, epochs, batch_size, log_every=100):
        """Train with Adam + NLL loss for ``epochs`` shuffled passes.

        A fresh Adam optimizer is created on every call, so optimizer state is
        not carried over between calls.

        Args:
            epochs: Number of passes over the dataset.
            batch_size: Mini-batch size for the shuffled DataLoader.
            log_every: Log the current batch loss every this many batches
                (default 100, as before); values <= 0 disable per-batch logs.

        Returns:
            List holding the average batch loss of each epoch.
        """
        dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        criterion = nn.NLLLoss()
        # Ensure train-mode behaviour even if the model was left in eval mode.
        self.model.train()

        epoch_losses = []
        start_time = time.time()
        for epoch in range(epochs):
            total_loss = 0.0
            num_batches = 0
            for batch_idx, (target, context) in enumerate(dataloader):
                log_probs = self.model(target)
                loss = criterion(log_probs, context)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                total_loss += batch_loss
                num_batches += 1

                if log_every > 0 and batch_idx % log_every == 0:
                    logger.info(f"Epoch {epoch}, Batch {batch_idx}, "
                              f"Loss: {batch_loss:.4f}")

            # max(..., 1) guards the division if the loader yielded no batches.
            avg_loss = total_loss / max(num_batches, 1)
            epoch_losses.append(avg_loss)
            logger.info(f"Epoch {epoch} completed, Average Loss: {avg_loss:.4f}")

        training_time = time.time() - start_time
        logger.info(f"Training completed in {training_time:.2f} seconds")
        return epoch_losses

In [27]:
class Word2VecEvaluator:
    """Cosine-similarity neighbour and analogy queries over learned embeddings.

    Args:
        model: Module exposing an ``embeddings`` ``nn.Embedding`` attribute
            (as ``SkipGramModel`` does).
        word_to_ix: Mapping from word to embedding row index.
        ix_to_word: Inverse mapping; stored for callers, not used by the
            queries below.
    """

    def __init__(self, model, word_to_ix, ix_to_word):
        self.model = model
        self.word_to_ix = word_to_ix
        self.ix_to_word = ix_to_word

    def get_vector(self, word):
        """Return the embedding of ``word`` as a 1-D numpy array, or None if unknown."""
        if word not in self.word_to_ix:
            return None
        weight = self.model.embeddings.weight
        word_ix = torch.tensor([self.word_to_ix[word]], device=weight.device)
        with torch.no_grad():
            vec = self.model.embeddings(word_ix)
        # .cpu() so this also works when the model lives on an accelerator.
        return vec.cpu().numpy()[0]

    def _rank(self, query, exclude, n):
        """Return the top-``n`` ``(word, cosine_similarity)`` pairs for ``query``.

        Similarities against the whole embedding table are computed with one
        matrix-vector product instead of one embedding lookup per word. Words
        in ``exclude`` are skipped. Ties keep ``word_to_ix`` iteration order
        (``sorted`` is stable). Pairs involving a zero-norm vector get 0.0
        instead of NaN, which would otherwise make the ordering meaningless.
        """
        table = self.model.embeddings.weight.detach().cpu().numpy()
        numer = table @ query
        denom = np.linalg.norm(table, axis=1) * np.linalg.norm(query)
        sims = np.divide(numer, denom, out=np.zeros_like(numer), where=denom > 0)

        scored = [(w, float(sims[ix])) for w, ix in self.word_to_ix.items()
                  if w not in exclude]
        return sorted(scored, key=lambda x: x[1], reverse=True)[:n]

    def similar_words(self, word, n=5):
        """Find the ``n`` words most cosine-similar to ``word``.

        Returns:
            List of ``(word, similarity)`` with Python ``float`` similarities,
            best first; ``[]`` if ``word`` is unknown. ``word`` itself and
            ``'<UNK>'`` are excluded.
        """
        if word not in self.word_to_ix:
            return []
        return self._rank(self.get_vector(word), {word, '<UNK>'}, n)

    def get_analogy(self, word1, word2, word3, n=5):
        """Find word4 such that word1:word2 :: word3:word4.

        Ranks words by cosine similarity to ``v(word2) - v(word1) + v(word3)``.

        Returns:
            List of ``(word, similarity)`` with Python ``float`` similarities,
            best first; ``[]`` if any input word is unknown. The three input
            words and ``'<UNK>'`` are excluded.
        """
        if any(w not in self.word_to_ix for w in (word1, word2, word3)):
            return []

        v1, v2, v3 = map(self.get_vector, [word1, word2, word3])
        target = v2 - v1 + v3
        return self._rank(target, {word1, word2, word3, '<UNK>'}, n)

In [28]:
# Seed every RNG the pipeline touches (weight init and DataLoader shuffling via
# torch; np.random for optional subsampling) so Restart & Run All reproduces
# the same run.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Number of leading corpus tokens to use; set to None for the full text8 corpus.
MAX_TOKENS = 1_000_000

# Initialize Word2Vec
w2v = Word2Vec(embedding_dim=300, window_size=5, min_count=5, 
               batch_size=512, epochs=5)

# Download and load data
w2v.download_text8()
with open('text8', 'r', encoding='utf-8') as f:
    text = f.read()

# Preprocess text (first MAX_TOKENS tokens for quick testing)
tokens = w2v.preprocess_text(text, max_tokens=MAX_TOKENS)
print(f"Total tokens: {len(tokens)}")

# Build vocabulary
vocab = w2v.build_vocab(tokens)

# Create dataset
dataset = SkipGramDataset(tokens, w2v.word_to_ix, window_size=w2v.window_size)
print(f"Total training pairs: {len(dataset)}")

# Initialize model
model = SkipGramModel(len(w2v.word_to_ix), w2v.embedding_dim)

# Train model
# NOTE: every step computes a full softmax over the vocabulary, and ~10M pairs
# at batch_size=512 is ~19.5k batches per epoch; the recorded run below was
# interrupted during epoch 0. Lower MAX_TOKENS / epochs for quick iteration.
trainer = Word2VecTrainer(model, dataset, learning_rate=w2v.learning_rate)
trainer.train(epochs=w2v.epochs, batch_size=w2v.batch_size)

INFO:__main__:Vocabulary size: 13967


Total tokens: 1000000
Total training pairs: 9999970


INFO:__main__:Epoch 0, Batch 0, Loss: 9.7407
INFO:__main__:Epoch 0, Batch 100, Loss: 8.5167
INFO:__main__:Epoch 0, Batch 200, Loss: 7.9896
INFO:__main__:Epoch 0, Batch 300, Loss: 7.6763
INFO:__main__:Epoch 0, Batch 400, Loss: 7.4904
INFO:__main__:Epoch 0, Batch 500, Loss: 7.2885
INFO:__main__:Epoch 0, Batch 600, Loss: 7.4762
INFO:__main__:Epoch 0, Batch 700, Loss: 6.9735
INFO:__main__:Epoch 0, Batch 800, Loss: 7.0200
INFO:__main__:Epoch 0, Batch 900, Loss: 7.3512
INFO:__main__:Epoch 0, Batch 1000, Loss: 7.0595
INFO:__main__:Epoch 0, Batch 1100, Loss: 7.1793
INFO:__main__:Epoch 0, Batch 1200, Loss: 7.1391
INFO:__main__:Epoch 0, Batch 1300, Loss: 6.8866
INFO:__main__:Epoch 0, Batch 1400, Loss: 7.1409
INFO:__main__:Epoch 0, Batch 1500, Loss: 7.1071
INFO:__main__:Epoch 0, Batch 1600, Loss: 6.9509
INFO:__main__:Epoch 0, Batch 1700, Loss: 6.8570
INFO:__main__:Epoch 0, Batch 1800, Loss: 6.7480
INFO:__main__:Epoch 0, Batch 1900, Loss: 6.8451


KeyboardInterrupt: 