In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

# Sample corpus: a single sentence is enough to demonstrate trigram training.
text = "the sun rises in the east and sets in the west"

# Tokenize: lower-case, then split on whitespace.
words = text.lower().split()

# Sort the unique tokens so the word <-> index mapping is identical on every run.
# Iteration order of a set of str depends on hash randomization (PYTHONHASHSEED),
# so `list(set(words))` silently reshuffles embedding/output rows between sessions.
vocab = sorted(set(words))
word_to_ix = {word: i for i, word in enumerate(vocab)}
ix_to_word = {i: word for word, i in word_to_ix.items()}

# Generate trigrams: context = [w_i, w_{i+1}], target = w_{i+2}.
trigrams = [([w1, w2], w3) for w1, w2, w3 in zip(words, words[1:], words[2:])]
print(trigrams[:3])


[(['the', 'sun'], 'rises'), (['sun', 'rises'], 'in'), (['rises', 'in'], 'the')]


In [2]:
class NGramLanguageModel(nn.Module):
    """Feed-forward n-gram language model.

    Embeds the ``context_size`` preceding tokens, concatenates their embeddings,
    passes them through one ReLU hidden layer, and projects to vocabulary logits.

    Args:
        vocab_size: Number of distinct tokens (rows of the embedding table and
            width of the output layer).
        embedding_dim: Size of each token embedding.
        context_size: Number of context tokens fed to the model per prediction.
        hidden_dim: Width of the hidden layer (default 128, the value that was
            previously hard-coded).
    """

    def __init__(self, vocab_size, embedding_dim, context_size, hidden_dim=128):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(context_size * embedding_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, inputs):
        """Return unnormalised next-token logits.

        Args:
            inputs: LongTensor of token indices, either one context of shape
                ``(context_size,)`` or a batch of shape ``(batch, context_size)``.

        Returns:
            Tensor of shape ``(1, vocab_size)`` for a single context, or
            ``(batch, vocab_size)`` for a batch. These are raw logits (no
            softmax applied), suitable for ``nn.CrossEntropyLoss`` or ``argmax``.
        """
        if inputs.dim() == 1:
            # A lone context is treated as a batch of one, preserving the
            # original (1, vocab_size) output shape.
            inputs = inputs.unsqueeze(0)
        # (batch, context_size, embedding_dim) -> (batch, context_size * embedding_dim)
        embeds = self.embeddings(inputs).flatten(start_dim=1)
        out = F.relu(self.linear1(embeds))
        return self.linear2(out)


In [5]:
# Number of distinct tokens; this is also the model's output dimension.
vocab_size = len(vocab)
vocab_size

8

In [9]:
CONTEXT_SIZE = 2
EMBEDDING_DIM = 10
NUM_EPOCHS = 100
LEARNING_RATE = 0.01
LOG_EVERY = 10
SEED = 42

# Seed before building the model so weight initialisation (and therefore the
# loss curve printed below) is reproducible across kernel restarts.
torch.manual_seed(SEED)

model = NGramLanguageModel(len(vocab), EMBEDDING_DIM, CONTEXT_SIZE)
print(model)
# CrossEntropyLoss takes raw logits and applies log-softmax internally.
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

# Encode every (context, target) pair once: the tensors do not change between
# epochs, so rebuilding them inside the epoch loop is wasted work.
training_examples = [
    (
        torch.tensor([word_to_ix[w] for w in context], dtype=torch.long),
        torch.tensor([word_to_ix[target]], dtype=torch.long),
    )
    for context, target in trigrams
]

# Training loop: one SGD step per trigram.
model.train()
for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    for context_idxs, target_idx in training_examples:
        # Forward + loss (the model returns unnormalised logits, not log-probs).
        logits = model(context_idxs)
        loss = loss_function(logits, target_idx)

        # Backward + update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    if epoch % LOG_EVERY == 0:
        print(f"Epoch {epoch}, Loss: {total_loss:.4f}")


NGramLanguageModel(
  (embeddings): Embedding(8, 10)
  (linear1): Linear(in_features=20, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=8, bias=True)
)
Epoch 0, Loss: 19.8202
Epoch 10, Loss: 7.5456
Epoch 20, Loss: 4.0914
Epoch 30, Loss: 2.9292
Epoch 40, Loss: 2.4596
Epoch 50, Loss: 2.2263
Epoch 60, Loss: 2.0916
Epoch 70, Loss: 2.0061
Epoch 80, Loss: 1.9452
Epoch 90, Loss: 1.9022


In [7]:
# Predict the most likely next word for a two-word context.
context = ['the', 'sun']
context_idxs = torch.tensor([word_to_ix[w] for w in context], dtype=torch.long)

# Inference only: disable autograd so no computation graph is built. eval() is a
# no-op for this model (it has no dropout/batch-norm) but is the standard guard.
model.eval()
with torch.no_grad():
    output = model(context_idxs)  # raw logits, shape (1, vocab_size)
predicted_idx = torch.argmax(output).item()
print(f"Predicted word: {ix_to_word[predicted_idx]}")


Predicted word: rises
