In [7]:
from collections import Counter
import torch
import torch.nn as nn
import math
import re
import numpy as np
from nltk.tokenize import word_tokenize


In [None]:
# Load the training corpus and lowercase it so the vocabulary is case-insensitive.
with open("data/wiki.train.tokens", "r", encoding="utf-8") as f:
    text = f.read().lower()

# Tokenize, then keep purely alphabetic tokens only (numbers, punctuation and
# anything containing a non-letter are dropped).
tokens = word_tokenize(text)
tokens = [word for word in tokens if word.isalpha()]

# A literal "<unk>" cannot survive the isalpha() filter above, so bare "unk"
# tokens are mapped back to the "<unk>" placeholder.
# NOTE(review): presumably these come from the corpus' own <unk> markers being
# split by the tokenizer - confirm against the raw file.
tokens = [word if word != "unk" else "<unk>" for word in tokens]

# Words seen fewer than MIN_WORD_COUNT times are folded into "<unk>".
MIN_WORD_COUNT = 5
counter = Counter(tokens)
rare_words = set(word for word, count in counter.items() if count < MIN_WORD_COUNT)
tokens = [word if word not in rare_words else "<unk>" for word in tokens]

# Recount after the replacement so the vocabulary only contains surviving words.
counter = Counter(tokens)

# Assign ids 0..N-1 in first-seen order.
vocab = {word: i for i, word in enumerate(counter)}

# Only add "<unk>" if it is not already present. Re-assigning an existing
# "<unk>" to len(vocab) would leave its old id orphaned (no inv_vocab entry)
# and push the largest id past len(vocab) - 1.
vocab.setdefault("<unk>", len(vocab))
inv_vocab = {i: w for w, i in vocab.items()}

# Ids are contiguous, so the vocabulary size is exactly the number of entries.
vocab_size = len(vocab)

# Encode tokens as a 1-D tensor of vocabulary ids.
unk_id = vocab["<unk>"]
encoded = [vocab.get(word, unk_id) for word in tokens]
data = torch.tensor(encoded, dtype=torch.long)

In [9]:
def batchify(data, batch_size):
    """Arrange a flat sequence into `batch_size` parallel columns.

    Trailing elements that do not fill a complete row are discarded. The
    result has shape (len(data) // batch_size, batch_size); each column holds
    one contiguous run of the original sequence.
    """
    rows_per_stream = len(data) // batch_size
    usable = rows_per_stream * batch_size
    # One row per stream first, then transpose so streams become columns.
    streams = data[:usable].view(batch_size, -1)
    return streams.t().contiguous()

# Number of parallel token streams; after batchify each column of train_data
# is one contiguous stream of the encoded corpus.
batch_size = 32
train_data = batchify(data, batch_size)

# Get batch: source and target
def get_batch(source, i, bptt=35):
    """Return an (input, target) pair starting at row `i` of `source`.

    The input is at most `bptt` rows long and is shortened near the end so a
    next-row target always exists. The target is the same window shifted one
    row forward, flattened to 1-D.
    """
    remaining = len(source) - 1 - i
    seq_len = bptt if bptt < remaining else remaining
    stop = i + seq_len
    inputs = source[i:stop]
    labels = source[i + 1:stop + 1].reshape(-1)
    return inputs, labels


In [10]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to embeddings, then apply dropout.

    The table is precomputed for `maxlen` positions and stored as a buffer of
    shape (maxlen, 1, emb_size), so it is indexed by the first dimension of the
    input and broadcast over the second.

    Args:
        emb_size: Embedding width. Both even and odd sizes are supported.
        dropout: Dropout probability applied to the sum.
        maxlen: Number of positions to precompute; inputs longer than this
            along dim 0 cannot be encoded.
    """

    def __init__(self, emb_size, dropout=0.1, maxlen=5000):
        super().__init__()
        # One frequency per sine column: 10000 ** (-2k / emb_size).
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        angles = pos * den
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(angles)
        # For odd emb_size there is one fewer cosine column than sine columns,
        # so trim the angles to match; for even sizes this slice is a no-op.
        pos_embedding[:, 1::2] = torch.cos(angles[:, :emb_size // 2])
        self.dropout = nn.Dropout(dropout)
        # Buffer, not parameter: moves with .to(device) but is never trained.
        self.register_buffer('pos_embedding', pos_embedding.unsqueeze(1))

    def forward(self, x):
        """Return dropout(x + positional table for the first x.size(0) positions)."""
        return self.dropout(x + self.pos_embedding[:x.size(0), :])

class TransformerModel(nn.Module):
    """Token-level language model: embedding -> positional encoding ->
    Transformer encoder stack -> linear projection to vocabulary logits.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        emb_size: Embedding / model width.
        nhead: Attention heads per encoder layer.
        nhid: Feed-forward width inside each encoder layer.
        nlayers: Number of stacked encoder layers.
        dropout: Dropout used by the positional encoder and encoder layers.
    """

    def __init__(self, vocab_size, emb_size, nhead, nhid, nlayers, dropout=0.2):
        super().__init__()
        self.emb_size = emb_size
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.pos_encoder = PositionalEncoding(emb_size, dropout)
        layer = nn.TransformerEncoderLayer(
            d_model=emb_size,
            nhead=nhead,
            dim_feedforward=nhid,
            dropout=dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=nlayers)
        self.decoder = nn.Linear(emb_size, vocab_size)

    def forward(self, src, src_mask):
        """Map token ids to vocabulary logits, attending under `src_mask`."""
        # Embeddings are scaled by sqrt(emb_size) before positions are added.
        scaled = self.embedding(src) * math.sqrt(self.emb_size)
        encoded = self.transformer_encoder(self.pos_encoder(scaled), src_mask)
        return self.decoder(encoded)

    def generate_square_subsequent_mask(self, sz):
        """Return an (sz, sz) float mask: 0.0 on and below the diagonal,
        -inf strictly above it, so position i cannot attend to positions > i."""
        blocked = torch.full((sz, sz), float('-inf'))
        # triu with diagonal=1 keeps -inf above the diagonal and zeroes the rest.
        return torch.triu(blocked, diagonal=1)


In [11]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Small two-layer, two-head encoder with 200-wide embeddings and feed-forward.
model = TransformerModel(
    vocab_size=vocab_size,
    emb_size=200,
    nhead=2,
    nhid=200,
    nlayers=2,
    dropout=0.2
).to(device)

# Cross-entropy over the vocabulary logits, optimized with Adam.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

def train(model, data, epochs=5, bptt=35):
    """Train `model` on batchified `data` for a number of epochs.

    Uses the notebook-level `optimizer`, `loss_fn` and `device`. After each
    epoch it prints the sum of the per-batch loss values, so the figure
    depends on how many batches the data yields.

    Args:
        model: Model exposing `generate_square_subsequent_mask` and callable
            as `model(input_seq, src_mask)`.
        data: Tensor whose first dimension is walked in windows of `bptt` rows.
        epochs: Number of full passes over `data`.
        bptt: Window length. It drives both the loop stride and the window
            requested from `get_batch`, so the two cannot drift apart.
    """
    model.train()
    for epoch in range(epochs):
        total_loss = 0.
        for i in range(0, data.size(0) - 1, bptt):
            input_seq, target = get_batch(data, i, bptt)

            # The last window can be shorter, so size the causal mask per batch.
            src_mask = model.generate_square_subsequent_mask(input_seq.size(0)).to(device)
            optimizer.zero_grad()
            output = model(input_seq.to(device), src_mask)
            # Flatten to (positions, classes) using the logits' own width
            # instead of relying on a global vocabulary size.
            loss = loss_fn(output.reshape(-1, output.size(-1)), target.to(device))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1} | Loss: {total_loss:.2f}")

def predict_next(model, text, vocab, inv_vocab, n_words=5):
    """Return the `n_words` most likely next words after `text`.

    The prompt is lowercased and split on whitespace; words missing from
    `vocab` are mapped to the "<unk>" id. The model is switched to eval mode
    and left in it.

    Args:
        model: Model exposing `generate_square_subsequent_mask` and callable
            as `model(input_ids, mask)`, returning logits indexed as
            [position, batch, vocabulary id].
        text: Prompt string; must contain at least one word.
        vocab: Mapping word -> id; must contain "<unk>".
        inv_vocab: Mapping id -> word.
        n_words: Number of candidates to return (capped at the number of logits).

    Returns:
        List of words, best candidate first. Ids with no `inv_vocab` entry are
        reported as "<unk>".

    Raises:
        ValueError: If `text` contains no words.
    """
    model.eval()
    tokens = text.lower().split()
    if not tokens:
        raise ValueError("predict_next needs a prompt with at least one word")

    # Run on whatever device the model lives on rather than a global setting.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    unk_id = vocab["<unk>"]
    input_ids = torch.tensor(
        [vocab.get(w, unk_id) for w in tokens], dtype=torch.long
    ).unsqueeze(1).to(model_device)
    mask = model.generate_square_subsequent_mask(input_ids.size(0)).to(model_device)
    with torch.no_grad():
        output = model(input_ids, mask)
    last_logits = output[-1, 0, :]  # last timestep, first batch index

    # Softmax is monotonic, so ranking the raw logits gives the same order.
    k = min(n_words, last_logits.size(-1))
    top_ids = torch.topk(last_logits, k=k).indices
    return [inv_vocab.get(i, "<unk>") for i in top_ids.tolist()]



In [12]:
# Run 30 passes over train_data; train() prints one loss line per epoch.
train(model, train_data, epochs=30)

Epoch 1 | Loss: 9965.29
Epoch 2 | Loss: 9252.32
Epoch 3 | Loss: 8897.11
Epoch 4 | Loss: 8644.11
Epoch 5 | Loss: 8454.86
Epoch 6 | Loss: 8308.26
Epoch 7 | Loss: 8190.66
Epoch 8 | Loss: 8094.43
Epoch 9 | Loss: 8014.35
Epoch 10 | Loss: 7941.40
Epoch 11 | Loss: 7877.91
Epoch 12 | Loss: 7823.14
Epoch 13 | Loss: 7774.18
Epoch 14 | Loss: 7727.42
Epoch 15 | Loss: 7685.10
Epoch 16 | Loss: 7646.91
Epoch 17 | Loss: 7609.03
Epoch 18 | Loss: 7574.76
Epoch 19 | Loss: 7541.29
Epoch 20 | Loss: 7510.80
Epoch 21 | Loss: 7480.18
Epoch 22 | Loss: 7451.75
Epoch 23 | Loss: 7423.63
Epoch 24 | Loss: 7397.26
Epoch 25 | Loss: 7372.31
Epoch 26 | Loss: 7348.31
Epoch 27 | Loss: 7323.03
Epoch 28 | Loss: 7299.12
Epoch 29 | Loss: 7277.05
Epoch 30 | Loss: 7254.97


In [29]:
# Show the top next-word candidates for the prompt (n_words defaults to 5).
print(predict_next(model, "in several ", vocab, inv_vocab))

['weeks', 'countries', 'years', 'days', 'other']
