In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as functional
import torch.optim as optim

# Seed torch's global RNG so the randomly initialised model weights are reproducible.
SEED = 1
torch.manual_seed(SEED)

<torch._C.Generator at 0x1cf35865a70>

In [3]:
# Training corpus: a passage from Macbeth, tokenised on whitespace only.
# Punctuation stays attached to words, so e.g. "Come," and "Come" are different tokens.
MACBETH_PASSAGE = """The raven himself is hoarse,
That croaks the fatal entrance of Duncan
Under my battlements. Come, you spirits
That tend on mortal thoughts, unsex me here,
And fill me from the crown to the toe top-full
Of direst cruelty! make thick my blood;
Stop up the access and passage to remorse,
That no compunctious visitings of nature
Shake my fell purpose, nor keep peace between
The effect and it! Come to my woman’s breasts,
And take my milk for gall, you murdering ministers,
Wherever in your sightless substances
You wait on nature’s mischief! Come, thick night,
And pall thee in the dunnest smoke of hell,
That my keen knife see not the wound it makes,
Nor heaven peep through the blanket of the dark,
To cry ‘Hold, hold!’"""
test_para = MACBETH_PASSAGE.split()

# Half-width of the context window: the CONTEXT_SIZE words before and the
# CONTEXT_SIZE words after the i-th word form its context.
CONTEXT_SIZE = 2

# Width of each word embedding; the embedding table is (vocab_size, EMBEDDING_DIM).
EMBEDDING_DIM = 10

In [5]:
# Build the vocabulary. Sorting gives a stable word -> index mapping: iteration
# order of a raw set of strings depends on Python's per-process hash
# randomisation (PYTHONHASHSEED), which torch.manual_seed does not control.
vocab = sorted(set(test_para))
vocab_size = len(vocab)
print("Vocabulary length ::: ", vocab_size)

word_to_ix = {word: i for i, word in enumerate(vocab)}

# Each sample pairs the CONTEXT_SIZE words on either side of position i
# (left words first, then right words) with the target word at position i.
data = []
for i in range(CONTEXT_SIZE, len(test_para) - CONTEXT_SIZE):
    context = test_para[i - CONTEXT_SIZE:i] + test_para[i + 1:i + CONTEXT_SIZE + 1]
    target = test_para[i]
    data.append((context, target))
print("Generated CBOW ::: ", data[:5])

Vocabulary length :::  98
Generated CBOW :::  [(['The', 'raven', 'is', 'hoarse,'], 'himself'), (['raven', 'himself', 'hoarse,', 'That'], 'is'), (['himself', 'is', 'That', 'croaks'], 'hoarse,'), (['is', 'hoarse,', 'croaks', 'the'], 'That'), (['hoarse,', 'That', 'the', 'fatal'], 'croaks')]


In [9]:
class CBOW(nn.Module):
    """Predict a centre word from the words surrounding it.

    The context-word embeddings are concatenated (rather than averaged) into a
    single row vector and passed through a two-layer MLP whose output is
    log-softmax scores over the vocabulary.

    Args:
        vocab_size: number of distinct words; rows of the embedding table and
            width of the output layer.
        embedding_dim: width of each word embedding.
        context_size: words taken on EACH side of the target, so one example
            has ``2 * context_size`` context words.
        hidden_dim: width of the hidden layer (defaults to 128).
    """

    def __init__(self, vocab_size, embedding_dim, context_size, hidden_dim=128):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # The flattened input holds 2 * context_size embeddings; context_size ** 2
        # only matches this when context_size == 2.
        self.linear1 = nn.Linear(2 * context_size * embedding_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, inputs):
        """Return log-probabilities over the vocabulary for a single context.

        Args:
            inputs: 1-D LongTensor with the ``2 * context_size`` context word
                indices of ONE example (no batch dimension).

        Returns:
            Tensor of shape ``(1, vocab_size)`` holding log-softmax scores.
        """
        # Flatten every context embedding into one row: (1, 2 * context_size * embedding_dim).
        embeds = self.embeddings(inputs).view((1, -1))
        out = functional.relu(self.linear1(embeds))
        out = self.linear2(out)
        return functional.log_softmax(out, dim=1)


def make_context_vector(context, word_to_ix):
    """Translate context words into a 1-D LongTensor of vocabulary indices.

    Indices keep the order of ``context``. A word missing from ``word_to_ix``
    raises KeyError.
    """
    lookup = word_to_ix.__getitem__
    indices = list(map(lookup, context))
    return torch.tensor(indices, dtype=torch.long)
    
# Train with negative log-likelihood on the model's log-softmax output,
# one (context, target) example per optimisation step.
NUM_EPOCHS = 10
losses = []
loss_function = nn.NLLLoss()
model = CBOW(len(vocab), EMBEDDING_DIM, CONTEXT_SIZE)
optimizer = optim.SGD(model.parameters(), lr=0.001)

for epoch in range(NUM_EPOCHS):
    total_loss = 0
    for context, target in data:
        # Indices of the context words as a 1-D LongTensor. make_context_vector
        # already returns a tensor; wrapping it in torch.tensor() again would
        # copy it and emit a UserWarning.
        context_idxs = make_context_vector(context, word_to_ix)

        # torch accumulates gradients across backward() calls, so clear them
        # before processing each new example.
        model.zero_grad()

        # Forward pass: log-probabilities over the vocabulary for the centre word.
        log_probs = model(context_idxs)

        # NLLLoss expects the target class index wrapped in a 1-element LongTensor.
        target_idx = torch.tensor([word_to_ix[target]], dtype=torch.long)
        loss = loss_function(log_probs, target_idx)

        # Backpropagate and take one SGD step.
        loss.backward()
        optimizer.step()

        # .item() extracts a plain Python float so total_loss does not hold the graph.
        total_loss += loss.item()
    losses.append(total_loss)
print(losses)  # summed training loss per epoch; expected to trend downward




[574.643773317337, 570.9957427978516, 567.3815424442291, 563.7986721992493, 560.243744134903, 556.715674161911, 553.2099783420563, 549.7271420955658, 546.2651727199554, 542.820720911026, 539.3921296596527, 535.9783999919891, 532.5780999660492, 529.1924240589142, 525.8173489570618, 522.4510202407837, 519.0959916114807, 515.7501153945923, 512.4106996059418, 509.07572984695435, 505.7439651489258, 502.4139461517334, 499.0902771949768, 495.77408957481384, 492.46336555480957, 489.16151547431946, 485.86507868766785, 482.5750743150711, 479.2907748222351, 476.0074745416641, 472.7292813062668, 469.4566503763199, 466.185830950737, 462.91646015644073, 459.6520085334778, 456.3902645111084, 453.1323436498642, 449.87439036369324, 446.61417031288147, 443.35293793678284, 440.0900717973709, 436.82505917549133, 433.5539928674698, 430.27351117134094, 426.98337280750275, 423.6858986020088, 420.3775392770767, 417.06152683496475, 413.7356404662132, 410.39505565166473, 407.0431424379349, 403.68037378787994, 4