In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import Counter
import re

In [2]:
text = """
The quick brown fox jumped over the lazy dog.
The fox is quick and brown. The dog is lazy.
The cat is neither quick nor lazy.
"""

def preprocess(text):
    text = text.lower()
    tokens = re.findall(r'\b\w+\b', text)
    return tokens

tokens = preprocess(text)
vocab = list(set(tokens))

word_to_ix = {word: i for i, word in enumerate(vocab)}
ix_to_word = {i: word for word, i in word_to_ix.items()}

In [3]:
def get_context_pairs(tokens, window_size=2):
    pairs = []
    for i, target in enumerate(tokens):
        start = max(0, i - window_size) # to not look before the start
        end = min(len(tokens), i + window_size + 1) # to not look after the end
        context = tokens[start:i] + tokens[i+1:end] # context before + after of the "target"
        pairs.extend([(target, ctx) for ctx in context])
    return pairs

context_pairs = get_context_pairs(tokens)

In [4]:
class SkipGram(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(SkipGram, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.output = nn.Linear(embedding_dim, vocab_size)
        
    def forward(self, x):
        embedded = self.embeddings(x)
        out = self.output(embedded)
        return torch.log_softmax(out, dim=1) # becasue of NLLLoss

In [5]:
model = SkipGram(len(vocab), embedding_dim=50)
optimizer = optim.SGD(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()

for epoch in range(100):
    total_loss = 0
    for target, context in context_pairs:
        target_ix = torch.tensor([word_to_ix[target]])
        context_ix = torch.tensor([word_to_ix[context]])
        
        log_probs = model(target_ix)
        loss = criterion(log_probs, context_ix)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    if epoch % 10 == 0:
        print(f"epoch {epoch}, loss: {total_loss}")

epoch 0, loss: 268.04978477954865
epoch 10, loss: 228.10314893722534
epoch 20, loss: 212.74036729335785
epoch 30, loss: 204.69714057445526
epoch 40, loss: 199.6276605129242
epoch 50, loss: 196.04817807674408
epoch 60, loss: 193.3554549217224
epoch 70, loss: 191.25290977954865
epoch 80, loss: 189.56860542297363
epoch 90, loss: 188.1928254365921


In [6]:
def get_vector(word):
    word_ix = torch.tensor([word_to_ix[word]])
    return model.embeddings(word_ix).detach().numpy()[0]

def similar_words(word, n=3):
    target_vec = get_vector(word)
    similarities = []
    for w in vocab:
        if w != word:
            vec = get_vector(w)
            similarity = float(np.dot(target_vec, vec) / (np.linalg.norm(target_vec) * np.linalg.norm(vec)))
            similarities.append((w, similarity))
    return sorted(similarities, key=lambda x: x[1], reverse=True)[:n]


print("similar words to 'fox':")
similar = similar_words('fox', n=3)
for word, similarity in similar:
    print(f"{word}: {similarity:.3f}")

similar words to 'fox':
cat: 0.135
nor: 0.120
lazy: 0.097
