In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# Training corpus: a single toy sentence (lower-cased and whitespace-split later).
TEXT: str = "deep learning models are cool deep learning is fun"

# Data / model shape
CONTEXT_SIZE: int = 2   # number of preceding words used to predict the next one
EMBED_DIM: int = 10     # width of the word embeddings and of the hidden layer

# Optimisation
EPOCHS: int = 500
LR: float = 0.1         # SGD learning rate
SEED: int = 42          # passed to torch.manual_seed for reproducible runs

In [3]:
torch.manual_seed(SEED)

# Lower-case and split on whitespace; sorting the vocabulary makes the
# word -> index assignment deterministic from run to run.
words = TEXT.lower().split()
vocab = sorted(set(words))

word_to_ix = {w: i for i, w in enumerate(vocab)}
ix_to_word = {i: w for w, i in word_to_ix.items()}
vocab_size = len(vocab)

# Fail fast with a clear message on configurations that would otherwise
# produce no training pairs (training would silently report 0 loss) or empty
# contexts (which only fail later, opaquely, inside the embedding lookup).
if CONTEXT_SIZE < 1:
    raise ValueError(f"CONTEXT_SIZE must be >= 1, got {CONTEXT_SIZE}")
if len(words) <= CONTEXT_SIZE:
    raise ValueError(
        f"TEXT has {len(words)} word(s); need more than CONTEXT_SIZE={CONTEXT_SIZE} "
        "to form a single (context, target) pair"
    )

# Sliding window: each pair is CONTEXT_SIZE consecutive words plus the word
# that immediately follows them. Using a comprehension keeps the loop
# variables out of the notebook namespace (the training cell reuses the
# names `context` / `target`).
data = [
    (words[i:i + CONTEXT_SIZE], words[i + CONTEXT_SIZE])
    for i in range(len(words) - CONTEXT_SIZE)
]

In [4]:
class NextWordModel(nn.Module):
    """Predict the next word from the average of the context word embeddings.

    The context embeddings are mean-pooled, passed through one hidden
    ReLU layer, and projected to unnormalised scores (logits) over the
    vocabulary.

    Args:
        vocab_size: Number of distinct words; size of the embedding table and
            of the output layer.
        embed_dim: Width of the embeddings and of the hidden layer.
    """

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        # Creation order is significant: it fixes how the seeded RNG is
        # consumed during weight initialisation.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc1 = nn.Linear(embed_dim, embed_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Score every vocabulary word as the continuation of the context ``x``.

        Args:
            x: Integer word indices, either one context of shape
               ``[context_size]`` or a batch of shape ``[batch, context_size]``.

        Returns:
            Logits of shape ``[vocab_size]`` for unbatched input, or
            ``[batch, vocab_size]`` for batched input.
        """
        e = self.embedding(x)      # [..., context_size, embed_dim]
        # Pool over the context axis (second-to-last), not axis 0, so a
        # leading batch dimension is preserved instead of being averaged away.
        e = e.mean(dim=-2)         # [..., embed_dim]
        h = self.relu(self.fc1(e))
        return self.fc2(h)         # [..., vocab_size]

In [5]:
# Seed right before weight initialisation so re-running this cell on its own
# reproduces the same model. On a fresh Run All nothing draws from the torch
# RNG between the earlier seeding and here, so the results are unchanged.
torch.manual_seed(SEED)

model = NextWordModel(vocab_size, EMBED_DIM)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR)

# Encode every (context, target) pair once; these tensors do not change
# between epochs, so there is no need to rebuild them inside the loop.
encoded_data = [
    (
        torch.tensor([word_to_ix[w] for w in context], dtype=torch.long),
        torch.tensor([word_to_ix[target]], dtype=torch.long),
    )
    for context, target in data
]

model.train()
# range(EPOCHS + 1) runs epochs 0..EPOCHS inclusive (EPOCHS + 1 passes), so the
# final epoch number is also logged by the `% 50` check.
for epoch in range(EPOCHS + 1):
    total_loss = 0.0  # summed (not averaged) over all training pairs this epoch

    for x, y in encoded_data:
        optimizer.zero_grad()
        logits = model(x)
        # CrossEntropyLoss expects [batch, classes]; add a batch dim of 1.
        loss = criterion(logits.unsqueeze(0), y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    if epoch % 50 == 0:
        print(f"Epoch {epoch:04d} | Loss: {total_loss:.4f}")

Epoch 0000 | Loss: 14.2626
Epoch 0050 | Loss: 2.0488
Epoch 0100 | Loss: 1.7592
Epoch 0150 | Loss: 1.6764
Epoch 0200 | Loss: 1.6340
Epoch 0250 | Loss: 1.6071
Epoch 0300 | Loss: 1.5882
Epoch 0350 | Loss: 1.5740
Epoch 0400 | Loss: 1.5628
Epoch 0450 | Loss: 1.5537
Epoch 0500 | Loss: 1.5462


In [6]:
def predict_next(context, top_k=1):
    """Return the ``top_k`` most likely next words following ``context``.

    Uses the notebook-level trained ``model`` and the ``word_to_ix`` /
    ``ix_to_word`` vocabulary maps.

    Args:
        context: Sequence of words; every word must be in the training vocabulary.
        top_k: Number of candidates to return, most likely first. Values larger
            than the vocabulary size are clamped to the vocabulary size.

    Returns:
        List of up to ``top_k`` words ordered by descending predicted probability.

    Raises:
        ValueError: If ``context`` is empty or ``top_k`` is negative.
        KeyError: If a context word is not in the vocabulary.
    """
    if len(context) == 0:
        raise ValueError("context must contain at least one word")
    if top_k < 0:
        raise ValueError(f"top_k must be >= 0, got {top_k}")

    unknown = [w for w in context if w not in word_to_ix]
    if unknown:
        raise KeyError(f"out-of-vocabulary word(s): {unknown}")

    x = torch.tensor([word_to_ix[w] for w in context], dtype=torch.long)
    # torch.topk raises if k exceeds the number of classes.
    k = min(top_k, len(ix_to_word))

    # Restore the caller's mode afterwards so a prediction does not silently
    # leave the shared model in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(x)
            probs = torch.softmax(logits, dim=-1)
            topk = torch.topk(probs, k)
    finally:
        model.train(was_training)

    return [ix_to_word[i.item()] for i in topk.indices]

In [7]:
print("\nPredictions:")

# Each label (including its hand-aligned padding) is printed verbatim,
# followed by the model's top-1 prediction for the paired context.
demo_cases = [
    ("deep learning  ->", ["deep", "learning"]),
    ("learning models ->", ["learning", "models"]),
    ("models are     ->", ["models", "are"]),
]
for label, ctx in demo_cases:
    print(label, predict_next(ctx))


Predictions:
deep learning  -> ['is']
learning models -> ['are']
models are     -> ['cool']
