In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

In [132]:
# Toy corpus: one lower-case, whitespace-separated sentence per string.
# It is tokenized and turned into (centre, context) skip-gram pairs in later cells.
corpus = [
    'he is king',
    'she is a queen',
    'he is a man',
    'she is a woman',
    'warsaw is poland capital',
    'berlin is germany capital',
    'dhaka is bangladesh capital'
]

In [133]:
def tokenize_corpus(corpus):
    """Split every sentence of ``corpus`` on whitespace.

    Args:
        corpus: iterable of sentences (each must support ``.split()``).

    Returns:
        A list with one token list per sentence, in input order. No
        lower-casing or punctuation stripping is performed.
    """
    token_lists = []
    for sentence in corpus:
        token_lists.append(sentence.split())
    return token_lists

# Tokenize the toy corpus; the bare name on the last line renders it as the cell output.
tokenized_corpus = tokenize_corpus(corpus)
tokenized_corpus

[['he', 'is', 'king'],
 ['she', 'is', 'a', 'queen'],
 ['he', 'is', 'a', 'man'],
 ['she', 'is', 'a', 'woman'],
 ['warsaw', 'is', 'poland', 'capital'],
 ['berlin', 'is', 'germany', 'capital'],
 ['dhaka', 'is', 'bangladesh', 'capital']]

In [134]:
# Vocabulary in first-seen order. dict.fromkeys de-duplicates in a single pass
# while preserving insertion order; the previous `token not in list` check was
# quadratic in vocabulary size.
vocabulary = list(dict.fromkeys(
    token
    for sentence in tokenized_corpus
    for token in sentence
))

# Bidirectional lookup tables between words and their integer ids.
word2idx = {word: idx for idx, word in enumerate(vocabulary)}
idx2word = dict(enumerate(vocabulary))

In [242]:
vocabulary_size = len(vocabulary)

# Half-width of the skip-gram window: up to `context_size` words on each side
# of a centre word become its prediction targets.
context_size = 2
# The pair-generation loop below previously referenced an undefined name
# `window_size`, so this cell raised NameError on a fresh kernel.
window_size = context_size

idx_pairs = []
for sentence in tokenized_corpus:
    indices = [word2idx[word] for word in sentence]
    n_words = len(indices)

    # Treat every word in the sentence as a centre word.
    for center_pos, center_idx in enumerate(indices):
        # Clip the window to the sentence bounds up front instead of
        # range-checking each offset; iteration order is unchanged.
        start = max(0, center_pos - window_size)
        stop = min(n_words, center_pos + window_size + 1)
        for context_pos in range(start, stop):
            if context_pos != center_pos:
                idx_pairs.append((center_idx, indices[context_pos]))

# When non-empty: shape (num_pairs, 2), column 0 = centre word id,
# column 1 = context word id.
idx_pairs = np.array(idx_pairs)

In [220]:
class SkipGramLanguageModeler(nn.Module):
    """Skip-gram context-word scorer: embedding -> 128-unit ReLU layer -> vocab logits.

    Args:
        vocab_size: number of words (embedding rows and output classes).
        embedding_dim: width of each word embedding.
        context_size: accepted for call-site compatibility; not used by the model.
    """

    def __init__(self, vocab_size, embedding_dim, context_size):
        super().__init__()
        hidden_dim = 128
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, word_index):
        """Return log-probabilities over the vocabulary for each input word id.

        The softmax is taken over dim 1, so this assumes ``word_index`` is a
        1-D batch of ids (the notebook passes shape ``(1,)``), giving an
        output of shape ``(batch, vocab_size)``.
        """
        hidden = F.relu(self.linear1(self.embeddings(word_index)))
        logits = self.linear2(hidden)
        return F.log_softmax(logits, dim=1)

In [222]:
# Embedding width for the skip-gram model. This name was referenced but never
# defined anywhere in the notebook, so the cell failed on a fresh kernel.
# NOTE(review): 5 mirrors the TestModeler embedding size used later; the value
# originally intended is unknown — confirm.
embedding_dims = 5
n_epochs = 1000
learning_rate = 1e-2

torch.manual_seed(0)  # reproducible weight init under Restart & Run All

losses = []  # summed NLL loss over all pairs, one entry per epoch
loss_function = nn.NLLLoss()
skip_gram_model = SkipGramLanguageModeler(vocabulary_size, embedding_dims, context_size)
optimizer = optim.SGD(skip_gram_model.parameters(), lr=learning_rate)

for epoch in range(n_epochs):
    total_loss = 0.0
    # Plain per-pair SGD: one (centre, context) example per optimizer step.
    for word, target in idx_pairs:
        word_tensor = torch.tensor([word], dtype=torch.long)
        target_tensor = torch.tensor([target], dtype=torch.long)

        skip_gram_model.zero_grad()
        log_probs = skip_gram_model(word_tensor)  # (1, vocabulary_size)
        loss = loss_function(log_probs, target_tensor)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    losses.append(total_loss)

In [227]:
class TestModeler(nn.Module):
    """Linear word2vec baseline on explicit one-hot inputs (no nonlinearity).

    Args:
        vocab_size: length of the one-hot input and of the output distribution.
        embedding_dim: width of the hidden (embedding) layer.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.linear1 = nn.Linear(vocab_size, embedding_dim)
        self.linear2 = nn.Linear(embedding_dim, vocab_size)

    def forward(self, word_idx, vocab_size):
        """Score every vocabulary word as a context for ``word_idx``.

        Builds a float vector of length ``vocab_size`` with 1.0 at ``word_idx``,
        passes it through both linear layers, and returns log-probabilities
        reshaped to ``(1, vocab_size)`` so they can feed ``nn.NLLLoss`` directly.
        """
        one_hot = torch.zeros(vocab_size, dtype=torch.float)
        one_hot[word_idx] = 1.0
        hidden = self.linear1(one_hot)
        logits = self.linear2(hidden)
        return F.log_softmax(logits, dim=0).unsqueeze(0)

In [228]:
test_embedding_dim = 5
n_epochs = 1000
learning_rate = 1e-2
log_every = 0  # set e.g. to 100 to print the mean per-pair loss periodically

torch.manual_seed(0)  # reproducible weight init under Restart & Run All

test_losses = []  # summed NLL loss over all pairs, one entry per epoch
loss_function = nn.NLLLoss()
test_model = TestModeler(vocabulary_size, test_embedding_dim)
optimizer = optim.SGD(test_model.parameters(), lr=learning_rate)

for epoch in range(n_epochs):
    epoch_loss = 0.0
    # Plain per-pair SGD: one (centre, context) example per optimizer step.
    for center_idx, context_idx in idx_pairs:
        test_model.zero_grad()
        log_probs = test_model(center_idx, vocabulary_size)  # (1, vocabulary_size)
        target_tensor = torch.tensor([context_idx], dtype=torch.long)

        loss = loss_function(log_probs, target_tensor)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    test_losses.append(epoch_loss)
    if log_every and epoch % log_every == 0:
        print(f'Loss at epoch {epoch}: {epoch_loss / len(idx_pairs)}')

In [241]:
word = 'berlin'
word_id = word2idx[word]

# Inference only: disable autograd so these forward passes build no graph.
with torch.no_grad():
    skip_gram_log_probs = skip_gram_model(torch.tensor([word_id], dtype=torch.long))
    test_log_probs = test_model(word_id, vocabulary_size)

# Both outputs have shape (1, vocabulary_size); row 0 holds the scores for `word`.
print("Prediction of skip gram: ", idx2word[torch.argmax(skip_gram_log_probs[0]).item()])
print("Prediction of test model: ", idx2word[torch.argmax(test_log_probs[0]).item()])

Prediction of skip gram:  germany
Prediction of test model:  germany
