In [1]:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, TensorDataset
import time
from tqdm import tqdm

In [2]:
# Configurations
device = "cpu"
learning_rate = 4.0
batch_size = 16
num_epochs = 5
marker = '.'
ngram_level = 4
hidden_dim = 20
embedding_dim = 20

In [3]:
# Create set & mappings
class CharacterDataset(Dataset):
    def __init__(self, words, characters):
        self.words = words
        self.characters = characters
        # Mapping characters to indices and vice versa
        self.char_to_idx = {ch: i for i, ch in enumerate(characters)}
        self.idx_to_char = {i: s for s, i in self.char_to_idx.items()}

    def __len__(self):
        return len(self.words)

    def contains(self, word):
        return word in self.words

    def get_vocab_size(self):
        return len(self.characters)

    def encode(self, word):
        # Convert a word to a list of character indices
        return [self.char_to_idx[c] for c in word]

    def decode(self, tokens):
        # Convert a list of character indices to a word
        return ''.join(self.idx_to_char[i] for i in tokens)

    def __getitem__(self, idx):
        word = self.words[idx]
        return self.encode(word)

In [4]:
#load data and create datasets

def load_datasets(window, input_file='data/names.txt'):
    with open(input_file, 'r') as f:
        data = f.read()
    words = data.splitlines()
    words = [w.strip() for w in words]
    words = [w for w in words if w]

    characters = sorted(list(set(''.join(words))))
    characters = [marker] + characters

    # Add marker characters to the beginning and end of each word
    words = [''.join((window - 1) * [marker]) + word + ''.join((window - 1) * [marker]) for word in words]

    print(f"The number of examples in the dataset: {len(words)}")
    print(f"The number of unique characters in the vocabulary: {len(characters)}")
    print(f"The vocabulary we have is: {''.join(characters)}")

    # Split the dataset into training, validation, and test sets
    out_of_sample_set_size = min(2000, int(len(words) * 0.1))
    test_set_size = 1500

    rp = torch.randperm(len(words)).tolist()

    train_words = [words[i] for i in rp[:-out_of_sample_set_size]]
    validation_words = [words[i] for i in rp[-out_of_sample_set_size:-test_set_size]]
    test_words = [words[i] for i in rp[-test_set_size:]]

    print(f"We've split up the dataset into {len(train_words)}, {len(validation_words)}, {len(test_words)} training, validation, and test examples")

    train_dataset = CharacterDataset(train_words, characters)
    validation_dataset = CharacterDataset(validation_words, characters)
    test_dataset = CharacterDataset(test_words, characters)

    return train_dataset, validation_dataset, test_dataset

In [5]:
def create_data_loader(dataset, window):
    x_list = []
    y_list = []
    for i, word in enumerate(dataset):
        for j, _ in enumerate(word):
            if j + window > len(word) - 1:
                break
            word_window = word[j:j + window]
            x, y = word_window[:-1], word_window[-1]
            x_list.append(x)
            y_list.append(y)

    # Create a DataLoader with the input-output pairs
    return DataLoader(
        TensorDataset(torch.tensor(x_list), torch.tensor(y_list)),
        batch_size,
        shuffle=True
    )

In [6]:
class NGramLanguageModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, ngram_level):
        super(NGramLanguageModel, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.ngram_level = ngram_level

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.hidden_layer = nn.Parameter(torch.zeros((ngram_level - 1) * embedding_dim, hidden_dim))
        self.output_layer_W = nn.Parameter(torch.zeros((ngram_level - 1) * embedding_dim, vocab_size))
        self.output_layer_U = nn.Parameter(torch.zeros(hidden_dim, vocab_size))

        self.bias_output = torch.nn.Parameter(torch.ones(vocab_size))
        self.bias_hidden = torch.nn.Parameter(torch.ones(hidden_dim))

        self.init_weights()

    def init_weights(self):
        initrange = 0.5

    def forward(self, x):
        # Get character embeddings
        x = self.embedding(x)

        # Get the batch size
        batch_size = x.shape[0]

        # Reshape the input to a 2D tensor
        x = x.view(batch_size, -1)

        # Compute the output using the hidden and output layer parameters
        y = self.bias_output + torch.matmul(x, self.output_layer_W) + torch.matmul(torch.tanh(self.bias_hidden + torch.matmul(x, self.hidden_layer)), self.output_layer_U)

        return y

In [7]:
def compute_perplexity(total_loss, total_batches):
    # Compute the perplexity given the total loss and number of batches
    return torch.exp(torch.tensor(total_loss / total_batches)).item()

In [8]:
#Set up model
def train_model(data_loader, model, optimizer, criterion, epoch):
    model.train()
    total_loss, total_batches = 0.0, 0.0
    log_interval = 500

    for idx, (x, y) in tqdm(enumerate(data_loader)):
        optimizer.zero_grad()

        if not x.nelement():
            continue

        # Forward pass
        logits = model(x)

        # Compute the loss
        loss = criterion(input=logits, target=y.squeeze(-1))
        # Backward pass
        loss.backward()

        # Clip gradients to avoid exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        # Update the model parameters
        optimizer.step()
        total_loss += loss.item()
        total_batches += 1

        # Log the training progress
        if idx % log_interval == 0 and idx > 0:
            perplexity = compute_perplexity(total_loss, total_batches)
            print(
                "| epoch {:3d} "
                "| {:5d}/{:5d} batches "
                "| perplexity {:8.3f} "
                "| loss {:8.3f} "
                .format(
                    epoch,
                    idx,
                    len(data_loader),
                    perplexity,
                    total_loss / total_batches,
                )
            )
            total_loss, total_batches = 0.0, 0

In [9]:
def evaluate_model(data_loader, model, criterion):
    model.eval()
    total_loss, total_batches = 0.0, 0

    with torch.no_grad():
        for idx, (x, y) in enumerate(data_loader):
            # Forward pass
            logits = model(x)
            # Compute the loss
            total_loss += criterion(input=logits, target=y.squeeze(-1)).item()
            total_batches += 1
    return total_loss / total_batches, compute_perplexity(total_loss, total_batches)

In [10]:
def generate_word(model, dataset, window=ngram_level):
    generated_word = []
    context = [0] * (ngram_level - 1)

    while True:
        # Get the logits for the current context
        logits = model(torch.tensor(context).view(1, -1))

        # Convert logits to probabilities using softmax
        probs = nn.Softmax(dim=1)(logits)

        # Sample a character index based on the probabilities
        token_id = torch.multinomial(probs, num_samples=1).item()

        generated_word.append(token_id)

        # Update the context with the generated character
        context = context[1:] + [token_id]

        # Stop generating if the end marker is encountered
        if token_id == 0:
            break

    # Decode the generated character indices to a word
    return ''.join(dataset.decode(generated_word))

In [11]:
#Load Data
train_dataset, validation_dataset, test_dataset = load_datasets(ngram_level)

# Create data loaders
train_loader = create_data_loader(train_dataset, ngram_level)
validation_loader = create_data_loader(validation_dataset, ngram_level)
test_loader = create_data_loader(test_dataset, ngram_level)

The number of examples in the dataset: 32033
The number of unique characters in the vocabulary: 27
The vocabulary we have is: .abcdefghijklmnopqrstuvwxyz
We've split up the dataset into 30033, 500, 1500 training, validation, and test examples


In [12]:
# Set up the model, loss function, and optimizer
criterion = torch.nn.CrossEntropyLoss().to(device)
model = NGramLanguageModel(
    train_dataset.get_vocab_size(), embedding_dim, hidden_dim, ngram_level
).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)


In [13]:
# Train the model
for epoch in range(1, num_epochs + 1):
    epoch_start_time = time.time()
    train_model(train_loader, model, optimizer, criterion, epoch)
    loss_val, perplexity_val = evaluate_model(validation_loader, model, criterion)
    scheduler.step()
    print("-" * 59)
    print(
        "| end of epoch {:3d} "
        "| time: {:5.2f}s "
        "| valid perplexity {:8.3f} "
        "| valid loss {:8.3f}".format(
            epoch,
            time.time() - epoch_start_time,
            perplexity_val,
            loss_val
        )
    )
    print("-" * 59)

print("Checking the results of test dataset.")
loss_test, perplexity_test = evaluate_model(test_loader, model, criterion)
print("test perplexity {:8.3f} | test loss {:8.3f} ".format(perplexity_test, loss_test))


663it [00:00, 951.43it/s]

| epoch   1 |   500/15242 batches | perplexity    9.989 | loss    2.301 


1169it [00:01, 1215.90it/s]

| epoch   1 |  1000/15242 batches | perplexity    8.874 | loss    2.183 


1698it [00:01, 1274.74it/s]

| epoch   1 |  1500/15242 batches | perplexity    8.676 | loss    2.161 


2219it [00:02, 1239.78it/s]

| epoch   1 |  2000/15242 batches | perplexity    8.474 | loss    2.137 


2759it [00:02, 1329.60it/s]

| epoch   1 |  2500/15242 batches | perplexity    8.422 | loss    2.131 


3169it [00:02, 1348.99it/s]

| epoch   1 |  3000/15242 batches | perplexity    8.630 | loss    2.155 


3735it [00:03, 1373.46it/s]

| epoch   1 |  3500/15242 batches | perplexity    8.375 | loss    2.125 


4151it [00:03, 1373.38it/s]

| epoch   1 |  4000/15242 batches | perplexity    8.213 | loss    2.106 


4711it [00:03, 1380.69it/s]

| epoch   1 |  4500/15242 batches | perplexity    8.231 | loss    2.108 


5264it [00:04, 1353.43it/s]

| epoch   1 |  5000/15242 batches | perplexity    8.238 | loss    2.109 


5668it [00:04, 1328.94it/s]

| epoch   1 |  5500/15242 batches | perplexity    8.372 | loss    2.125 


6191it [00:05, 1267.53it/s]

| epoch   1 |  6000/15242 batches | perplexity    8.200 | loss    2.104 


6697it [00:05, 1207.04it/s]

| epoch   1 |  6500/15242 batches | perplexity    8.293 | loss    2.115 


7191it [00:05, 1204.62it/s]

| epoch   1 |  7000/15242 batches | perplexity    8.308 | loss    2.117 


7703it [00:06, 1261.07it/s]

| epoch   1 |  7500/15242 batches | perplexity    8.217 | loss    2.106 


8251it [00:06, 1359.65it/s]

| epoch   1 |  8000/15242 batches | perplexity    8.131 | loss    2.096 


8655it [00:07, 1303.93it/s]

| epoch   1 |  8500/15242 batches | perplexity    8.327 | loss    2.119 


9193it [00:07, 1320.02it/s]

| epoch   1 |  9000/15242 batches | perplexity    8.205 | loss    2.105 


9727it [00:07, 1254.52it/s]

| epoch   1 |  9500/15242 batches | perplexity    8.337 | loss    2.121 


10230it [00:08, 1190.28it/s]

| epoch   1 | 10000/15242 batches | perplexity    8.466 | loss    2.136 


10755it [00:08, 1271.69it/s]

| epoch   1 | 10500/15242 batches | perplexity    8.121 | loss    2.094 


11157it [00:09, 1317.34it/s]

| epoch   1 | 11000/15242 batches | perplexity    8.217 | loss    2.106 


11701it [00:09, 1314.82it/s]

| epoch   1 | 11500/15242 batches | perplexity    8.015 | loss    2.081 


12191it [00:09, 1033.81it/s]

| epoch   1 | 12000/15242 batches | perplexity    8.337 | loss    2.121 


12657it [00:10, 1120.45it/s]

| epoch   1 | 12500/15242 batches | perplexity    8.037 | loss    2.084 


13130it [00:10, 1122.35it/s]

| epoch   1 | 13000/15242 batches | perplexity    8.005 | loss    2.080 


13759it [00:11, 1270.45it/s]

| epoch   1 | 13500/15242 batches | perplexity    8.027 | loss    2.083 


14181it [00:11, 1362.95it/s]

| epoch   1 | 14000/15242 batches | perplexity    8.244 | loss    2.109 


14755it [00:12, 1356.30it/s]

| epoch   1 | 14500/15242 batches | perplexity    8.246 | loss    2.110 


15242it [00:12, 1231.27it/s]

| epoch   1 | 15000/15242 batches | perplexity    8.128 | loss    2.095 





-----------------------------------------------------------
| end of epoch   1 | time: 12.47s | valid perplexity    8.438 | valid loss    2.133
-----------------------------------------------------------


696it [00:00, 1394.67it/s]

| epoch   2 |   500/15242 batches | perplexity    7.791 | loss    2.053 


1258it [00:00, 1392.25it/s]

| epoch   2 |  1000/15242 batches | perplexity    7.540 | loss    2.020 


1664it [00:01, 1247.44it/s]

| epoch   2 |  1500/15242 batches | perplexity    7.466 | loss    2.010 


2177it [00:01, 1251.57it/s]

| epoch   2 |  2000/15242 batches | perplexity    7.527 | loss    2.019 


2699it [00:02, 1237.90it/s]

| epoch   2 |  2500/15242 batches | perplexity    7.620 | loss    2.031 


3203it [00:02, 1241.42it/s]

| epoch   2 |  3000/15242 batches | perplexity    7.572 | loss    2.025 


3702it [00:02, 1156.41it/s]

| epoch   2 |  3500/15242 batches | perplexity    7.598 | loss    2.028 


4216it [00:03, 1190.84it/s]

| epoch   2 |  4000/15242 batches | perplexity    7.509 | loss    2.016 


4715it [00:03, 1230.86it/s]

| epoch   2 |  4500/15242 batches | perplexity    7.627 | loss    2.032 


5194it [00:04, 1142.23it/s]

| epoch   2 |  5000/15242 batches | perplexity    7.498 | loss    2.015 


5701it [00:04, 1227.31it/s]

| epoch   2 |  5500/15242 batches | perplexity    7.368 | loss    1.997 


6241it [00:05, 1319.34it/s]

| epoch   2 |  6000/15242 batches | perplexity    7.524 | loss    2.018 


6645it [00:05, 1312.06it/s]

| epoch   2 |  6500/15242 batches | perplexity    7.572 | loss    2.024 


7197it [00:05, 1280.95it/s]

| epoch   2 |  7000/15242 batches | perplexity    7.666 | loss    2.037 


7724it [00:06, 1285.54it/s]

| epoch   2 |  7500/15242 batches | perplexity    7.300 | loss    1.988 


8274it [00:06, 1368.14it/s]

| epoch   2 |  8000/15242 batches | perplexity    7.473 | loss    2.011 


8689it [00:06, 1296.87it/s]

| epoch   2 |  8500/15242 batches | perplexity    7.535 | loss    2.020 


9245it [00:07, 1346.99it/s]

| epoch   2 |  9000/15242 batches | perplexity    7.453 | loss    2.009 


9671it [00:07, 1397.29it/s]

| epoch   2 |  9500/15242 batches | perplexity    7.462 | loss    2.010 


10241it [00:08, 1407.26it/s]

| epoch   2 | 10000/15242 batches | perplexity    7.564 | loss    2.023 


10634it [00:08, 1189.31it/s]

| epoch   2 | 10500/15242 batches | perplexity    7.447 | loss    2.008 


11124it [00:08, 1187.09it/s]

| epoch   2 | 11000/15242 batches | perplexity    7.475 | loss    2.012 


11626it [00:09, 1194.93it/s]

| epoch   2 | 11500/15242 batches | perplexity    7.416 | loss    2.004 


12151it [00:09, 1296.62it/s]

| epoch   2 | 12000/15242 batches | perplexity    7.483 | loss    2.013 


12702it [00:10, 1360.41it/s]

| epoch   2 | 12500/15242 batches | perplexity    7.341 | loss    1.993 


13219it [00:10, 1228.01it/s]

| epoch   2 | 13000/15242 batches | perplexity    7.585 | loss    2.026 


13738it [00:10, 1282.67it/s]

| epoch   2 | 13500/15242 batches | perplexity    7.457 | loss    2.009 


14248it [00:11, 1233.46it/s]

| epoch   2 | 14000/15242 batches | perplexity    7.524 | loss    2.018 


14755it [00:11, 1244.41it/s]

| epoch   2 | 14500/15242 batches | perplexity    7.535 | loss    2.019 


15125it [00:12, 1191.13it/s]

| epoch   2 | 15000/15242 batches | perplexity    7.477 | loss    2.012 


15242it [00:12, 1255.74it/s]


-----------------------------------------------------------
| end of epoch   2 | time: 12.20s | valid perplexity    7.647 | valid loss    2.034
-----------------------------------------------------------


673it [00:00, 1375.25it/s]

| epoch   3 |   500/15242 batches | perplexity    7.638 | loss    2.033 


1223it [00:00, 1331.78it/s]

| epoch   3 |  1000/15242 batches | perplexity    7.353 | loss    1.995 


1764it [00:01, 1266.37it/s]

| epoch   3 |  1500/15242 batches | perplexity    7.414 | loss    2.003 


2158it [00:01, 1290.33it/s]

| epoch   3 |  2000/15242 batches | perplexity    7.311 | loss    1.989 


2696it [00:02, 1317.93it/s]

| epoch   3 |  2500/15242 batches | perplexity    7.334 | loss    1.993 


3242it [00:02, 1350.00it/s]

| epoch   3 |  3000/15242 batches | perplexity    7.481 | loss    2.012 


3648it [00:02, 1276.55it/s]

| epoch   3 |  3500/15242 batches | perplexity    7.339 | loss    1.993 


4163it [00:03, 1244.42it/s]

| epoch   3 |  4000/15242 batches | perplexity    7.458 | loss    2.009 


4678it [00:03, 1142.66it/s]

| epoch   3 |  4500/15242 batches | perplexity    7.385 | loss    1.999 


5190it [00:04, 1259.51it/s]

| epoch   3 |  5000/15242 batches | perplexity    7.383 | loss    1.999 


5717it [00:04, 1255.72it/s]

| epoch   3 |  5500/15242 batches | perplexity    7.397 | loss    2.001 


6247it [00:04, 1313.48it/s]

| epoch   3 |  6000/15242 batches | perplexity    7.323 | loss    1.991 


6646it [00:05, 1314.11it/s]

| epoch   3 |  6500/15242 batches | perplexity    7.449 | loss    2.008 


7176it [00:05, 1304.75it/s]

| epoch   3 |  7000/15242 batches | perplexity    7.345 | loss    1.994 


7710it [00:06, 1313.05it/s]

| epoch   3 |  7500/15242 batches | perplexity    7.422 | loss    2.004 


8245it [00:06, 1322.72it/s]

| epoch   3 |  8000/15242 batches | perplexity    7.424 | loss    2.005 


8645it [00:06, 1322.72it/s]

| epoch   3 |  8500/15242 batches | perplexity    7.479 | loss    2.012 


9175it [00:07, 1304.98it/s]

| epoch   3 |  9000/15242 batches | perplexity    7.359 | loss    1.996 


9706it [00:07, 1313.06it/s]

| epoch   3 |  9500/15242 batches | perplexity    7.335 | loss    1.993 


10091it [00:07, 1177.24it/s]

| epoch   3 | 10000/15242 batches | perplexity    7.386 | loss    2.000 


10543it [00:08, 677.05it/s] 

| epoch   3 | 10500/15242 batches | perplexity    7.537 | loss    2.020 


11122it [00:09, 972.42it/s]

| epoch   3 | 11000/15242 batches | perplexity    7.337 | loss    1.993 


11700it [00:09, 1145.30it/s]

| epoch   3 | 11500/15242 batches | perplexity    7.473 | loss    2.011 


12195it [00:10, 1194.79it/s]

| epoch   3 | 12000/15242 batches | perplexity    7.429 | loss    2.005 


12686it [00:10, 1173.81it/s]

| epoch   3 | 12500/15242 batches | perplexity    7.390 | loss    2.000 


13219it [00:11, 1262.31it/s]

| epoch   3 | 13000/15242 batches | perplexity    7.560 | loss    2.023 


13781it [00:11, 1335.19it/s]

| epoch   3 | 13500/15242 batches | perplexity    7.528 | loss    2.019 


14193it [00:12, 1200.24it/s]

| epoch   3 | 14000/15242 batches | perplexity    7.316 | loss    1.990 


14716it [00:12, 1285.72it/s]

| epoch   3 | 14500/15242 batches | perplexity    7.546 | loss    2.021 


15234it [00:12, 1255.47it/s]

| epoch   3 | 15000/15242 batches | perplexity    7.164 | loss    1.969 


15242it [00:12, 1185.95it/s]


-----------------------------------------------------------
| end of epoch   3 | time: 12.91s | valid perplexity    7.570 | valid loss    2.024
-----------------------------------------------------------


655it [00:00, 1326.71it/s]

| epoch   4 |   500/15242 batches | perplexity    7.458 | loss    2.009 


1201it [00:00, 1334.39it/s]

| epoch   4 |  1000/15242 batches | perplexity    7.344 | loss    1.994 


1712it [00:01, 1176.24it/s]

| epoch   4 |  1500/15242 batches | perplexity    7.249 | loss    1.981 


2235it [00:01, 1293.54it/s]

| epoch   4 |  2000/15242 batches | perplexity    7.357 | loss    1.996 


2651it [00:02, 1359.40it/s]

| epoch   4 |  2500/15242 batches | perplexity    7.577 | loss    2.025 


3182it [00:02, 1164.48it/s]

| epoch   4 |  3000/15242 batches | perplexity    7.494 | loss    2.014 


3633it [00:03, 1007.37it/s]

| epoch   4 |  3500/15242 batches | perplexity    7.106 | loss    1.961 


4238it [00:03, 1191.56it/s]

| epoch   4 |  4000/15242 batches | perplexity    7.344 | loss    1.994 


4739it [00:03, 1143.98it/s]

| epoch   4 |  4500/15242 batches | perplexity    7.193 | loss    1.973 


5221it [00:04, 1123.57it/s]

| epoch   4 |  5000/15242 batches | perplexity    7.477 | loss    2.012 


5690it [00:04, 1124.97it/s]

| epoch   4 |  5500/15242 batches | perplexity    7.357 | loss    1.996 


6193it [00:05, 1230.64it/s]

| epoch   4 |  6000/15242 batches | perplexity    7.244 | loss    1.980 


6690it [00:05, 1217.42it/s]

| epoch   4 |  6500/15242 batches | perplexity    7.436 | loss    2.006 


7179it [00:06, 1205.11it/s]

| epoch   4 |  7000/15242 batches | perplexity    7.374 | loss    1.998 


7694it [00:06, 1268.44it/s]

| epoch   4 |  7500/15242 batches | perplexity    7.472 | loss    2.011 


8243it [00:06, 1324.15it/s]

| epoch   4 |  8000/15242 batches | perplexity    7.328 | loss    1.992 


8650it [00:07, 1326.11it/s]

| epoch   4 |  8500/15242 batches | perplexity    7.386 | loss    2.000 


9192it [00:07, 1279.58it/s]

| epoch   4 |  9000/15242 batches | perplexity    7.425 | loss    2.005 


9736it [00:08, 1339.36it/s]

| epoch   4 |  9500/15242 batches | perplexity    7.456 | loss    2.009 


10139it [00:08, 1327.64it/s]

| epoch   4 | 10000/15242 batches | perplexity    7.445 | loss    2.008 


10693it [00:08, 1369.19it/s]

| epoch   4 | 10500/15242 batches | perplexity    7.414 | loss    2.003 


11254it [00:09, 1394.59it/s]

| epoch   4 | 11000/15242 batches | perplexity    7.277 | loss    1.985 


11673it [00:09, 1257.60it/s]

| epoch   4 | 11500/15242 batches | perplexity    7.474 | loss    2.011 


12188it [00:09, 1197.59it/s]

| epoch   4 | 12000/15242 batches | perplexity    7.417 | loss    2.004 


12746it [00:10, 1349.12it/s]

| epoch   4 | 12500/15242 batches | perplexity    7.561 | loss    2.023 


13164it [00:10, 1330.42it/s]

| epoch   4 | 13000/15242 batches | perplexity    7.309 | loss    1.989 


13723it [00:11, 1380.97it/s]

| epoch   4 | 13500/15242 batches | perplexity    7.489 | loss    2.013 


14281it [00:11, 1381.30it/s]

| epoch   4 | 14000/15242 batches | perplexity    7.536 | loss    2.020 


14703it [00:11, 1385.33it/s]

| epoch   4 | 14500/15242 batches | perplexity    7.270 | loss    1.984 


15100it [00:12, 1175.71it/s]

| epoch   4 | 15000/15242 batches | perplexity    7.504 | loss    2.015 


15242it [00:12, 1243.36it/s]


-----------------------------------------------------------
| end of epoch   4 | time: 12.32s | valid perplexity    7.567 | valid loss    2.024
-----------------------------------------------------------


632it [00:00, 1067.83it/s]

| epoch   5 |   500/15242 batches | perplexity    7.283 | loss    1.985 


1111it [00:01, 1106.92it/s]

| epoch   5 |  1000/15242 batches | perplexity    7.332 | loss    1.992 


1649it [00:01, 1009.57it/s]

| epoch   5 |  1500/15242 batches | perplexity    7.336 | loss    1.993 


2172it [00:02, 1032.16it/s]

| epoch   5 |  2000/15242 batches | perplexity    7.404 | loss    2.002 


2607it [00:02, 1034.70it/s]

| epoch   5 |  2500/15242 batches | perplexity    7.393 | loss    2.001 


3200it [00:03, 1178.10it/s]

| epoch   5 |  3000/15242 batches | perplexity    7.457 | loss    2.009 


3705it [00:03, 1190.58it/s]

| epoch   5 |  3500/15242 batches | perplexity    7.432 | loss    2.006 


4220it [00:03, 1265.24it/s]

| epoch   5 |  4000/15242 batches | perplexity    7.402 | loss    2.002 


4783it [00:04, 1375.46it/s]

| epoch   5 |  4500/15242 batches | perplexity    7.358 | loss    1.996 


5205it [00:04, 1391.97it/s]

| epoch   5 |  5000/15242 batches | perplexity    7.435 | loss    2.006 


5612it [00:04, 1183.88it/s]

| epoch   5 |  5500/15242 batches | perplexity    7.553 | loss    2.022 


6247it [00:05, 1223.03it/s]

| epoch   5 |  6000/15242 batches | perplexity    7.460 | loss    2.010 


6658it [00:05, 1319.58it/s]

| epoch   5 |  6500/15242 batches | perplexity    7.427 | loss    2.005 


7222it [00:06, 1371.72it/s]

| epoch   5 |  7000/15242 batches | perplexity    7.489 | loss    2.013 


7615it [00:06, 1129.17it/s]

| epoch   5 |  7500/15242 batches | perplexity    7.272 | loss    1.984 


8162it [00:07, 996.46it/s] 

| epoch   5 |  8000/15242 batches | perplexity    7.304 | loss    1.988 


8648it [00:07, 922.44it/s]

| epoch   5 |  8500/15242 batches | perplexity    7.496 | loss    2.014 


9135it [00:08, 939.55it/s]

| epoch   5 |  9000/15242 batches | perplexity    7.371 | loss    1.997 


9665it [00:08, 983.53it/s]

| epoch   5 |  9500/15242 batches | perplexity    7.255 | loss    1.982 


10164it [00:09, 977.64it/s]

| epoch   5 | 10000/15242 batches | perplexity    7.444 | loss    2.007 


10659it [00:09, 922.00it/s]

| epoch   5 | 10500/15242 batches | perplexity    7.196 | loss    1.974 


11174it [00:10, 983.58it/s] 

| epoch   5 | 11000/15242 batches | perplexity    7.574 | loss    2.025 


11699it [00:10, 1006.16it/s]

| epoch   5 | 11500/15242 batches | perplexity    7.657 | loss    2.036 


12135it [00:11, 979.90it/s] 

| epoch   5 | 12000/15242 batches | perplexity    7.227 | loss    1.978 


12623it [00:11, 957.09it/s]

| epoch   5 | 12500/15242 batches | perplexity    7.226 | loss    1.978 


13102it [00:12, 916.90it/s]

| epoch   5 | 13000/15242 batches | perplexity    7.357 | loss    1.996 


13597it [00:12, 926.61it/s]

| epoch   5 | 13500/15242 batches | perplexity    7.407 | loss    2.002 


14094it [00:13, 902.86it/s]

| epoch   5 | 14000/15242 batches | perplexity    7.347 | loss    1.994 


14607it [00:13, 975.52it/s]

| epoch   5 | 14500/15242 batches | perplexity    7.379 | loss    1.999 


15101it [00:14, 966.35it/s]

| epoch   5 | 15000/15242 batches | perplexity    7.213 | loss    1.976 


15242it [00:14, 1044.87it/s]


-----------------------------------------------------------
| end of epoch   5 | time: 14.64s | valid perplexity    7.567 | valid loss    2.024
-----------------------------------------------------------
Checking the results of test dataset.
test perplexity    7.400 | test loss    2.001 


In [14]:
# Generate some words with model
for _ in range(50):
    print(generate_word(model, train_dataset))

brconi.
morijush.
orisa.
bellitylie.
digh.
praylee.
ley.
chril.
ceiblayvennse.
jadi.
nev.
lastapraeno.
corsen.
lavi.
simanttyn.
stre.
ngeri.
rah.
ern.
adrey.
mangeo.
for.
kharo.
barol.
kachanaca.
ren.
kayne.
lee.
ania.
mai.
bondy.
laver.
ayne.
shi.
myani.
shimo.
bran.
keid.
moristovisspperisaree.
mer.
sumbritonta.
eavay.
shilonn.
sodel.
maestoa.
rayadisiny.
rann.
tik.
gona.
faydhele.
