In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import time
from collections import Counter
import nltk
from nltk.corpus import reuters

# 0. SETUP & DATA SOURCING
# Proper credit: Reuters-21578 Text Categorization Test Collection
# Each NLTK resource is fetched only when a local lookup fails, so re-running
# this cell does not contact the download server for data that is already
# installed. Both resources use the same find-then-download guard.
try:
    nltk.data.find('corpora/reuters')
except LookupError:
    nltk.download('reuters')

try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab')

def load_news_dataset(category='grain', max_vocab=5000):
    """Build a cleaned corpus and a frequency-ranked vocabulary from Reuters.

    Args:
        category: Value forwarded to ``reuters.sents(categories=...)``.
        max_vocab: Maximum number of most-frequent words kept in the
            vocabulary; the '<UNK>' entry is appended on top of this limit.

    Returns:
        A 4-tuple ``(clean_corpus, vocab, word2idx, idx2word)``:
        ``clean_corpus`` is a list of sentences, each a list of lower-cased,
        purely alphabetic tokens (it still contains words that did not make
        it into ``vocab``); ``vocab`` is the ranked word list ending with
        '<UNK>'; the two dicts map word -> position and position -> word.
    """
    # Keep alphabetic tokens only (drops punctuation and numbers), lower-cased.
    clean_corpus = []
    for sentence in reuters.sents(categories=category):
        clean_corpus.append([token.lower() for token in sentence if token.isalpha()])

    # Count words in corpus order so that ties keep first-seen ordering.
    counts = Counter()
    for sentence in clean_corpus:
        counts.update(sentence)

    ranked = sorted(counts, key=counts.get, reverse=True)
    vocab = ranked[:max_vocab] + ['<UNK>']

    word2idx = {}
    idx2word = {}
    for position, word in enumerate(vocab):
        word2idx[word] = position
        idx2word[position] = word

    return clean_corpus, vocab, word2idx, idx2word

# Initialize Data
# Called with the function defaults (category='grain', max_vocab=5000).
corpus, vocab, word2index, index2word = load_news_dataset()
# Vocabulary size; this count includes the '<UNK>' entry that
# load_news_dataset appends to the end of the vocabulary.
voc_size = len(vocab)



[nltk_data] Downloading package reuters to
[nltk_data]     C:\Users\alsto\AppData\Roaming\nltk_data...
[nltk_data]   Package reuters is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\alsto\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [2]:
# 1. DYNAMIC UTILITIES
def get_skipgrams(corpus, w2i, window_size=2):
    """Build (center_id, context_id) skip-gram pairs from a tokenised corpus.

    Args:
        corpus: Iterable of sentences, each a sequence of word strings.
        w2i: Mapping from word to integer index; words missing from it are
            skipped both as centers and as context words.
        window_size: Number of positions looked at on each side of the center.

    Returns:
        A list of ``(center_id, context_id)`` tuples, in corpus order.
    """
    pairs = []
    for tokens in corpus:
        n_tokens = len(tokens)
        for center_pos, center_word in enumerate(tokens):
            if center_word in w2i:
                center_id = w2i[center_word]

                # Clamp the window to the sentence boundaries.
                lo = max(0, center_pos - window_size)
                hi = min(n_tokens, center_pos + window_size + 1)

                pairs.extend(
                    (center_id, w2i[tokens[ctx_pos]])
                    for ctx_pos in range(lo, hi)
                    if ctx_pos != center_pos and tokens[ctx_pos] in w2i
                )
    return pairs

def get_unigram_table(vocab, corpus, w2i):
    """Build the word table used as the noise distribution for negative sampling.

    Each vocabulary word (except '<UNK>') is repeated in the table roughly in
    proportion to its relative frequency raised to the 3/4 power, with at
    least one entry per word.

    Args:
        vocab: Iterable of vocabulary words.
        corpus: Iterable of sentences (sequences of word strings).
        w2i: Mapping from word to index; only words present in it are counted.

    Returns:
        A list of word strings; sampling uniformly from it approximates the
        P(w)^(3/4) noise distribution.
    """
    counts = Counter()
    for sentence in corpus:
        counts.update(token for token in sentence if token in w2i)
    total = sum(counts.values())

    table = []
    for word in vocab:
        if word != '<UNK>':
            # Mikolov's heuristic: relative frequency ** 0.75, scaled by 1/0.001
            # and truncated to an integer repeat count.
            scaled = ((counts[word] / total) ** 0.75) / 0.001
            repeats = max(int(scaled), 1)
            table += [word] * repeats
    return table



In [3]:
# 2. WORD2VEC (NEGATIVE SAMPLING)
class Word2VecNeg(nn.Module):
    """Skip-gram Word2Vec trained with negative sampling.

    Holds two embedding tables of shape ``[v_size, emb_dim]``: one used for
    center words and one used for outside (context and noise) words.
    """

    def __init__(self, v_size, emb_dim):
        super(Word2VecNeg, self).__init__()
        self.v_embeddings = nn.Embedding(v_size, emb_dim) # Center
        self.u_embeddings = nn.Embedding(v_size, emb_dim) # Outside
        self.log_sigmoid = nn.LogSigmoid()

    def forward(self, center, target, negative):
        """Return the mean negative-sampling loss for a batch.

        Args:
            center: LongTensor of center-word indices, one per batch row.
            target: LongTensor of true outside-word indices, one per batch row.
            negative: LongTensor of noise-word indices, expected as [Batch, K].

        Returns:
            A scalar tensor: the negated batch mean of
            ``log sigmoid(u_o . v_c) + sum_k log sigmoid(-u_k . v_c)``.
        """
        batch_size = center.size(0)

        # Reshaping to ensure [Batch, 1, Emb] to avoid IndexErrors
        v_vecs = self.v_embeddings(center).view(batch_size, 1, -1)
        u_vecs = self.u_embeddings(target).view(target.size(0), 1, -1)
        n_vecs = -self.u_embeddings(negative) # [Batch, K, Emb]

        # Positive score: dot product of center and target, flattened to [Batch].
        # It must be 1-D: adding a [Batch, 1] tensor to the [Batch] negative term
        # below would broadcast to a [Batch, Batch] matrix instead of pairing
        # each row's positive score with its own negative scores.
        pos_score = torch.bmm(u_vecs, v_vecs.transpose(1, 2)).view(batch_size)
        # Negative score: dot product of center and K noise samples -> [Batch, K]
        neg_score = torch.bmm(n_vecs, v_vecs.transpose(1, 2)).squeeze(2)

        # Per-example log-likelihood, shape [Batch]
        loss = self.log_sigmoid(pos_score) + torch.sum(self.log_sigmoid(neg_score), 1)
        return -torch.mean(loss)





In [4]:
# 3. GLOVE IMPLEMENTATION
class GloVeModel(nn.Module):
    """GloVe model: two embedding tables plus a scalar bias per word in each."""

    def __init__(self, v_size, emb_dim):
        super().__init__()
        # Word vectors for the two roles of a co-occurrence pair.
        self.v_emb = nn.Embedding(v_size, emb_dim)
        self.u_emb = nn.Embedding(v_size, emb_dim)
        # One scalar bias per word for each role.
        self.v_bias = nn.Embedding(v_size, 1)
        self.u_bias = nn.Embedding(v_size, 1)

    def forward(self, i_indices, j_indices, cooc_counts, weights):
        """Return the mean weighted squared error of the GloVe objective.

        Args:
            i_indices: LongTensor of word indices looked up in the ``v`` tables.
            j_indices: LongTensor of word indices looked up in the ``u`` tables.
            cooc_counts: Co-occurrence counts for each (i, j) pair; their
                logarithm is the regression target.
            weights: Per-pair weighting applied to the squared error.

        Returns:
            A scalar tensor with the mean weighted loss.
        """
        center_bias = self.v_bias(i_indices).squeeze(1)
        context_bias = self.u_bias(j_indices).squeeze(1)

        # Row-wise dot product of the paired embeddings.
        similarity = torch.sum(self.v_emb(i_indices) * self.u_emb(j_indices), dim=1)

        # Main GloVe objective: weighted squared gap to log co-occurrence.
        residual = similarity + center_bias + context_bias - torch.log(cooc_counts)
        weighted_sq_err = weights * residual.pow(2)
        return weighted_sq_err.mean()



In [5]:
# 4. FULL TRAINING EXECUTION

# Hyperparameters
WINDOW_SIZE = 2
EMB_DIM = 100
BATCH_SIZE = 256
EPOCHS = 15
LEARNING_RATE = 0.001
SEED = 42

# Seed every RNG before any stochastic step (embedding initialisation,
# per-epoch shuffling, negative sampling) so that a fresh
# "Restart Kernel -> Run All" reproduces the same training run.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prepare Data
# Skip-gram (center, context) index pairs and the noise table for negative sampling.
pairs = get_skipgrams(corpus, word2index, window_size=WINDOW_SIZE)
noise_table = get_unigram_table(vocab, corpus, word2index)

# Initialize Model & Optimizer
model = Word2VecNeg(voc_size, EMB_DIM)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

def sample_negative(targets, table, k, w2i):
    """Draw ``k`` negative word indices for every row of the batch.

    Words are drawn uniformly from ``table`` (so a word's chance is
    proportional to how often it appears there); a draw is rejected and
    repeated when it maps to the row's own target index.

    Args:
        targets: LongTensor with one target word index per batch row.
        table: Sequence of word strings to sample from.
        k: Number of negative samples per row.
        w2i: Mapping from word string to integer index.

    Returns:
        A LongTensor of shape ``[batch_size, k]`` with the sampled indices.
    """
    batch_size = targets.size(0)
    negs = []
    for i in range(batch_size):
        # Read the target index out of the tensor once per row rather than on
        # every rejection-loop iteration.
        forbidden = targets[i].item()
        s = []
        while len(s) < k:
            candidate = w2i[random.choice(table)]
            if candidate == forbidden:
                continue
            s.append(candidate)
        negs.append(torch.LongTensor(s).view(1, -1))
    return torch.cat(negs)

# Report the run configuration. The category name in the message is a fixed
# string; it matches the default `category` of load_news_dataset.
print(f"Training Word2Vec (NEG) on Reuters Category: 'grain'")
print(f"Total pairs: {len(pairs)} | Window Size: {WINDOW_SIZE}")



Training Word2Vec (NEG) on Reuters Category: 'grain'
Total pairs: 352144 | Window Size: 2


In [6]:
NUM_NEGATIVES = 5  # noise words sampled per positive (center, outside) pair

start_time = time.time()

for epoch in range(EPOCHS):
    epoch_loss = 0.0
    num_batches = 0  # batches actually trained on this epoch
    random.shuffle(pairs) # Shuffle for each epoch

    for i in range(0, len(pairs), BATCH_SIZE):
        batch = pairs[i : i + BATCH_SIZE]
        # The trailing partial batch is skipped so every step sees BATCH_SIZE pairs.
        if len(batch) < BATCH_SIZE: continue

        centers = torch.LongTensor([p[0] for p in batch]) # Center words
        targets = torch.LongTensor([p[1] for p in batch]) # Outside words
        negatives = sample_negative(targets, noise_table, NUM_NEGATIVES, word2index)

        optimizer.zero_grad()
        loss = model(centers, targets, negatives) # Compute loss
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Average over the batches that really ran; with fewer pairs than one full
    # batch nothing is trained, so report NaN instead of dividing by zero.
    avg_loss = epoch_loss / num_batches if num_batches else float('nan')
    print(f"Epoch [{epoch+1}/{EPOCHS}] - Average Loss: {avg_loss:.4f}")

end_time = time.time()
print(f"\nTraining Complete!")
print(f"Total time taken: {end_time - start_time:.2f} seconds")

Epoch [1/15] - Average Loss: 20.2847
Epoch [2/15] - Average Loss: 14.8388
Epoch [3/15] - Average Loss: 11.1705
Epoch [4/15] - Average Loss: 8.4469
Epoch [5/15] - Average Loss: 6.5545
Epoch [6/15] - Average Loss: 5.2548
Epoch [7/15] - Average Loss: 4.3565
Epoch [8/15] - Average Loss: 3.7561
Epoch [9/15] - Average Loss: 3.3312
Epoch [10/15] - Average Loss: 3.0127
Epoch [11/15] - Average Loss: 2.7689
Epoch [12/15] - Average Loss: 2.5746
Epoch [13/15] - Average Loss: 2.4191
Epoch [14/15] - Average Loss: 2.2873
Epoch [15/15] - Average Loss: 2.1816

Training Complete!
Total time taken: 261.11 seconds
