In [10]:
import nltk
nltk.download('reuters')

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import Counter
import time
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
import logging
from nltk.corpus import reuters

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

[nltk_data] Downloading package reuters to /root/nltk_data...
[nltk_data]   Package reuters is already up-to-date!


In [11]:
# Cap on the number of lowercase Reuters tokens used for training; this bounds
# both the number of skip-gram pairs and the training time.
MAX_TOKENS = 400_000

tokens = []
for fileid in reuters.fileids():
    tokens.extend(word.lower() for word in reuters.words(fileid))
    if len(tokens) >= MAX_TOKENS:
        # Stop reading the corpus as soon as we have enough tokens instead of
        # materialising all of it and discarding the tail.
        break

tokens = tokens[:MAX_TOKENS]
print(f"number of tokens: {len(tokens)}")
print(f"sample of tokens: {tokens[:5]}")

number of tokens: 400000
sample of tokens: ['asian', 'exporters', 'fear', 'damage', 'from']


In [12]:
class Word2Vec:
    """Hyperparameters and vocabulary for a skip-gram word2vec run.

    Stores the training settings and, after ``build_vocab`` is called, the
    word <-> index mappings. Index 0 is reserved for the ``'<UNK>'`` token.
    """

    def __init__(self, embedding_dim=100, window_size=3, min_count=5,
                 batch_size=2048, epochs=3, learning_rate=0.001):
        self.embedding_dim = embedding_dim
        self.window_size = window_size
        self.min_count = min_count
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate

        self.word_to_ix = None  # str -> int, filled by build_vocab
        self.ix_to_word = None  # int -> str, inverse of word_to_ix
        self.model = None  # not assigned anywhere in this class

    def build_vocab(self, tokens):
        """Build word <-> index maps, keeping words seen at least ``min_count`` times.

        Rare words are dropped because they appear in too few contexts to learn
        good embeddings from.

        Args:
            tokens: sequence of word strings.

        Returns:
            List of kept words in first-occurrence order (``'<UNK>'`` excluded).
        """
        word_counts = Counter(tokens)
        # A literal '<UNK>' in the corpus is excluded so it cannot take a second
        # index and collide with the reserved slot 0.
        vocab = [word for word, count in word_counts.items()
                 if count >= self.min_count and word != '<UNK>']

        self.word_to_ix = {word: i for i, word in enumerate(['<UNK>'] + vocab)}
        self.ix_to_word = {i: word for word, i in self.word_to_ix.items()}

        logger.info(f"vocabulary size: {len(self.word_to_ix)}")
        return vocab

    def subsample_frequent_words(self, tokens, threshold=1e-5):
        """Randomly drop occurrences of very frequent words (e.g. "the").

        Uses the paper's formula. Each occurrence of word ``w`` is discarded with
        probability ``1 - sqrt(threshold / f(w))``, where ``f(w)`` is the relative
        frequency of ``w`` in ``tokens``. Words with ``f(w) <= threshold`` get a
        non-positive drop probability and are always kept. Randomness comes from
        NumPy's global RNG (``np.random.seed`` controls it).

        Args:
            tokens: list of word strings.
            threshold: subsampling threshold ``t``.

        Returns:
            New list of surviving tokens, in their original order.
        """
        # Frequencies must be computed from the tokens being subsampled here;
        # there is no shared count table on the instance.
        word_counts = Counter(tokens)
        total_count = len(tokens)
        word_freq = {word: count / total_count for word, count in word_counts.items()}

        prob_drop = {word: 1 - np.sqrt(threshold / freq)
                     for word, freq in word_freq.items()}

        return [token for token in tokens
                if np.random.random() > prob_drop[token]]

In [13]:
class SkipGramDataset(Dataset):
    """Skip-gram (target, context) training pairs served as vocabulary-index tensors.

    For every position ``i``, each token up to ``window_size`` positions before
    or after it (clipped at the sequence ends) is paired with ``tokens[i]``.
    Words missing from ``word_to_ix`` map to the ``'<UNK>'`` index.

    Attributes:
        pairs: list of ``(target_word, context_word)`` string tuples.
        targets, contexts: int64 tensors of the encoded pair indices.
    """

    def __init__(self, tokens, word_to_ix, window_size=3):
        self.tokens = tokens
        self.word_to_ix = word_to_ix
        self.window_size = window_size
        self.pairs = self._generate_pairs()
        # Encode once up front. Otherwise every sample of every epoch repeats
        # two dict lookups and two tensor constructions.
        self.targets, self.contexts = self._encode_pairs()

    def _generate_pairs(self):
        """Return (target, context) word tuples for every position in ``tokens``."""
        pairs = []
        n_tokens = len(self.tokens)
        for i, target in enumerate(self.tokens):
            start = max(0, i - self.window_size)  # don't look before the start
            end = min(n_tokens, i + self.window_size + 1)  # don't look past the end
            context = self.tokens[start:i] + self.tokens[i + 1:end]
            pairs.extend((target, ctx) for ctx in context)
        return pairs

    def _encode_pairs(self):
        """Map ``self.pairs`` to two parallel int64 tensors of vocabulary indices."""
        if not self.pairs:
            empty = torch.empty(0, dtype=torch.long)
            return empty, empty.clone()
        unk_ix = self.word_to_ix['<UNK>']
        lookup = self.word_to_ix.get
        targets = [lookup(target, unk_ix) for target, _ in self.pairs]
        contexts = [lookup(context, unk_ix) for _, context in self.pairs]
        return (torch.tensor(targets, dtype=torch.long),
                torch.tensor(contexts, dtype=torch.long))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(target_ix, context_ix)`` as 0-d int64 tensors."""
        return self.targets[idx], self.contexts[idx]

In [14]:
class SkipGramModel(nn.Module):
    """Full-softmax skip-gram model: embed a target word, score every vocab word as context.

    Args:
        vocab_size: number of embedding rows and of output classes.
        embedding_dim: length of each word vector.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)  # target-word vectors
        self.output = nn.Linear(embedding_dim, vocab_size)  # scores each vocab word as context

    def forward(self, x):
        """Return log-probabilities over the vocabulary for each index in ``x``.

        ``x`` is an integer index tensor of any shape ``S``. The result has shape
        ``S + (vocab_size,)`` and is normalised over its last dimension.
        """
        embedded = self.embeddings(x)
        logits = self.output(embedded)
        # The vocabulary axis is always last. dim=1 only happened to be right
        # for 1-D batches of indices.
        return torch.log_softmax(logits, dim=-1)

In [15]:
class Word2VecTrainer:
    """Train a model that maps target indices to log-probabilities over context words.

    The model output is passed to ``nn.NLLLoss`` together with the context
    indices, so it must return log-probabilities of shape (batch, vocab_size).
    """

    def __init__(self, model, dataset, learning_rate=0.001):
        self.model = model
        self.dataset = dataset
        self.learning_rate = learning_rate
        self.epoch_losses = []  # mean loss of each completed epoch, across all train() calls

    def train(self, epochs, batch_size):
        """Optimise the model with Adam for ``epochs`` shuffled passes over the dataset.

        Prints per-epoch average loss and the total wall-clock time.

        Args:
            epochs: number of full passes over ``self.dataset``.
            batch_size: (target, context) pairs per optimisation step.

        Raises:
            ValueError: if ``epochs > 0`` and the dataset is empty, since there
                would be no batches to average the loss over.
        """
        if epochs > 0 and len(self.dataset) == 0:
            raise ValueError("dataset is empty; nothing to train on")

        dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        criterion = nn.NLLLoss()
        self.model.train()  # make sure any dropout/batch-norm layers are in training mode

        start_time = time.perf_counter()
        for epoch in range(epochs):
            total_loss = 0.0
            progress_bar = tqdm(dataloader, desc=f'epoch {epoch}')

            for target, context in progress_bar:
                log_probs = self.model(target)
                loss = criterion(log_probs, context)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()  # extract the scalar once per batch
                total_loss += batch_loss
                progress_bar.set_postfix({'loss': f'{batch_loss:.4f}'})

            avg_loss = total_loss / len(dataloader)
            self.epoch_losses.append(avg_loss)
            print(f"epoch {epoch} completed, avg loss: {avg_loss:.4f}")

        training_time = time.perf_counter() - start_time
        print(f"training completed in {training_time:.2f} seconds")

In [16]:
# Seed every RNG this run touches (embedding/linear init, DataLoader shuffling,
# NumPy-based subsampling) so Restart & Run All reproduces the losses.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

w2v = Word2Vec(
    embedding_dim=100,
    window_size=3,
    min_count=5,
    batch_size=2048,
    epochs=3
)

vocab = w2v.build_vocab(tokens)

dataset = SkipGramDataset(tokens, w2v.word_to_ix, window_size=w2v.window_size)
print(f"total training pairs: {len(dataset)}")

model = SkipGramModel(len(w2v.word_to_ix), w2v.embedding_dim)

trainer = Word2VecTrainer(model, dataset, learning_rate=w2v.learning_rate)
trainer.train(epochs=w2v.epochs, batch_size=w2v.batch_size)

total training pairs: 2399988


epoch 0:   0%|          | 0/1172 [00:00<?, ?it/s]

epoch 0 completed, avg loss: 6.5924


epoch 1:   0%|          | 0/1172 [00:00<?, ?it/s]

epoch 1 completed, avg loss: 5.9060


epoch 2:   0%|          | 0/1172 [00:00<?, ?it/s]

epoch 2 completed, avg loss: 5.8094
trianing completed in 1029.16 seconds


In [17]:
class Word2VecEvaluator:
    """Nearest-neighbour queries over a model's input embedding matrix.

    Args:
        model: object exposing an ``embeddings`` ``nn.Embedding`` (e.g. SkipGramModel).
        word_to_ix: word -> row index in the embedding matrix.
        ix_to_word: row index -> word (stored for callers; not used here).
    """

    def __init__(self, model, word_to_ix, ix_to_word):
        self.model = model
        self.word_to_ix = word_to_ix
        self.ix_to_word = ix_to_word

    def similar_words(self, word, n=5):
        """Return the ``n`` words whose embeddings are most cosine-similar to ``word``.

        ``word`` itself and ``'<UNK>'`` are excluded. Equal scores keep
        ``word_to_ix`` iteration order. A zero-length vector has no direction,
        so any similarity involving one is reported as 0.0 rather than NaN
        (NaN would corrupt the sort order).

        Returns:
            List of ``(word, similarity)`` pairs by descending similarity, or
            ``[]`` if ``word`` is not in the vocabulary.
        """
        if word not in self.word_to_ix:
            return []

        # One matrix-vector product replaces a separate lookup per vocab word.
        weights = self.model.embeddings.weight.detach().cpu().numpy()
        query = weights[self.word_to_ix[word]]
        dots = weights @ query
        denom = np.linalg.norm(weights, axis=1) * np.linalg.norm(query)
        sims = np.divide(dots, denom, out=np.zeros_like(dots), where=denom > 0)

        candidates = [(w, float(sims[ix])) for w, ix in self.word_to_ix.items()
                      if w != word and w != '<UNK>']
        return sorted(candidates, key=lambda pair: pair[1], reverse=True)[:n]

evaluator = Word2VecEvaluator(model, w2v.word_to_ix, w2v.ix_to_word)

# Print the five nearest neighbours of each probe word present in the vocabulary.
test_words = ['trade', 'oil', 'bank', 'market']
for query in (w for w in test_words if w in w2v.word_to_ix):
    print(f"\nsimilar to '{query}':")
    for neighbour, score in evaluator.similar_words(query, n=5):
        print(f"{neighbour}: {score:.3f}")


similar to 'trade':
adviser: 0.359
economic: 0.343
merchandise: 0.335
barco: 0.331
stks: 0.330

similar to 'oil':
raised: 0.374
communique: 0.363
43: 0.345
william: 0.337
1984: 0.330

similar to 'bank':
philippines: 0.408
delivered: 0.397
interstate: 0.385
cause: 0.374
doesn: 0.361

similar to 'market':
grow: 0.395
delaware: 0.394
measured: 0.375
urges: 0.373
stable: 0.371
