In [1]:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, TensorDataset
import time
from tqdm import tqdm

### Constants and configs used below.

In [45]:
DEVICE = "cpu"
LR = 4.0
BATCH_SIZE = 16
NUM_EPOCHS = 5
MARKER = '.'
# N-gram level; P(w_t | w_{t-1}, ..., w_{t-n+1}).
# We use 3 words to predict the next word.
n = 4
# Hidden layer dimension.
h = 20
# Word embedding dimension.
m = 20

### Get the dataset and the tokenizer.

In [95]:
class CharDataset(Dataset):
    def __init__(self, words, chars):
        self.words = words
        self.chars = chars
        # Inverse dictionaries mapping char tokens to unique ids and the reverse.
        # Tokens in this case are the unique chars we passed in above.
        # Each token should be mappend to a unique integer and MARKER should have token 0.
        # For example, stoi should be like {'.' -> 0, 'a' -> 1, 'b' -> 2} if I pass in chars = '.ab'.
        self.stoi = {letter : i for i,letter in enumerate(self.chars)} | {'.':0}
        self.itos = {val:key for key,val in self.stoi.items()} # Inverse mapping.

    def __len__(self):
        # Number of words.
        return len(self.words)

    def contains(self, word):
        # Check if word is in self.words and return True/False if it is, is not.
        return word in self.words

    def get_vocab_size(self):
        # Return the vocabulary size.
        # not sure about that
        return len(self.chars)

    def encode(self, word):
        # Express this word as a list of int ids. For example, maybe ".abc" -> [0, 1, 2, 3].
        # This assumes 'a' -> 1, etc.
        return [self.stoi[char] for char in word]
    
    def decode(self, tokens):
        # For a set of tokens, return back the string.
        # For example, maybe [1, 1, 2] -> "aac"
        return ''.join([self.itos[token] for token in tokens])

    def __getitem__(self, idx):
        # This is used so we can loop over the data.
        word = self.words[idx]
        return self.encode(word)

In [96]:
def create_datasets(window, input_file = 'names.txt'):
    """
    This takes a file of words and separates all the words.
    It then gets all the characters present in the universe of words and then ouputs the statistics. 
    """
    with open(input_file, 'r') as f:
        data = f.read()
    # Split the file by new lines. You should get a list of names.
    words = data.split('\n')
    words = [ word.strip() for word in words ] # This gets rid of any trailing and starting white spaces.
    words = [ word for word in words if len(word)>0 ] # Filter out all the empty words.
    
    # This gets the universe of all characters.
    chars = set()
    for word in words:
        for char in word:
            chars.add(char)
    chars = sorted(list(chars))

    # Will force chars to have MARKER having index 0.
    chars= [MARKER] + chars
    
    # Pad each word with a context window of size n-1.
    # Why? a word like "abc" should becomes "..abc.." if the window is size 3.
    # This is some we can get pair of (x, y) data like this: ".." -> "a", ".a" -> "b", "ab" -> "c", "bc" -> ".", "c." -> "."
    # I.e. this allows us to know that "a" is a start character.
    # So you should get something like ["ab", "c"] -> ["..ab..", "..c.."], for example.
    window = '.'*(n-1)
    words = [ f"{window}{word}{window}" for word in words]
            
    print(f"The number of examples in the dataset: {len(words)}")
    print(f"The number of unique characters in the vocabulary: {len(chars)}")
    print(f"The vocabulary we have is: {''.join(chars)}")

    # Partition the input data into a training, validation, and the test set.
    out_of_sample_set_size = min(2000, int(len(words) * 0.1)) # We use 10% of the training set, or up to 2000 examples.
    test_set_size = 1500
    
    # First, get a random permutation of randomly permute of size len(words).
    # Then, convert this to a list. 
    # This index list is used below to get the train, validation, and test sets.
    rp = torch.randperm(len(words)).tolist()
    
    # Get train, validation, and test set.
    train_words = [words[i] for i in rp[:-out_of_sample_set_size]]
    validation_words = [words[i] for i in rp[-out_of_sample_set_size:-test_set_size]]
    test_words = [words[i] for i in rp[-test_set_size:]]    
    print(f"We've split up the dataset into {len(train_words)}, {len(validation_words)}, {len(test_words)} training, validation, and test examples")

    # But the data in the data set objects.
    train_dataset = CharDataset(train_words, chars)
    validation_dataset = CharDataset(validation_words, chars)
    test_dataset = CharDataset(test_words, chars)

    return train_dataset, validation_dataset, test_dataset

In [97]:
train_dataset, validation_dataset, test_dataset = create_datasets(n)

The number of examples in the dataset: 32033
The number of unique characters in the vocabulary: 27
The vocabulary we have is: .abcdefghijklmnopqrstuvwxyz
We've split up the dataset into 30033, 500, 1500 training, validation, and test examples


## Explore the data

In [98]:
# Get the first word in "train_dataset"
train_dataset.decode(train_dataset.__getitem__(0))

'...maelani...'

In [100]:
# Get the stoi map of train_dataset. How many keys does it have?
len(train_dataset.stoi.keys())

27

### Get the dataloader

In [152]:
def create_dataloader(dataset, window):
    x_list = []
    y_list = []
    # For ech word.
    for i, word in enumerate(dataset):
        # Grab a context of size window and window-1 characters will be in x, 1 will be in y.
        for j, _ in enumerate(word):
            # If there is no widow of size window left, break.
            if j + window > len(word) - 1:
                break
            word_window = word[j:j + window-1]
            x, y = word_window, word[j + window-1]
            x_list.append(x)
            y_list.append(y)
            
    return DataLoader(
        TensorDataset(torch.tensor(x_list), torch.tensor(y_list)),
        BATCH_SIZE,
        shuffle=True
    )

In [153]:
train_dataloader = create_dataloader(train_dataset, n)
validation_dataloader = create_dataloader(validation_dataset, n)
test_dataloader = create_dataloader(test_dataset, n)

### Set up the model

In [234]:
# One of the first Neural language models!
class CharacterNeuralLanguageModel(nn.Module):
    def __init__(self, V, m, h, n):
        super(CharacterNeuralLanguageModel, self).__init__()
        
        # Vocabulary size.
        self.V = V
        
        # Embedding dimension, per word.
        self.m = m
        
        # Hidden dimension.
        self.h = h
        
        # N in "N-gram"
        self.n = n
        
        # Can you change all this stuff to use nn.Linear?
        # Ca also use nn.Parameter(torch.zeros(V, m)) for self.C but then we need one-hot and this is slow.
        self.C = nn.Embedding(V, m)
        self.H = nn.Parameter(torch.zeros((n-1) * m, h))
        self.W = nn.Parameter(torch.zeros((n-1) * m, V))
        self.U = nn.Parameter(torch.zeros(h, V))
        
        self.b = torch.nn.Parameter(torch.ones(V))
        self.d = torch.nn.Parameter(torch.ones(h))
        
        self.init_weights()

    def init_weights(self):
        # Intitialize C, H, W, U in a nice way. Use xavier initialization for the weights.
        # On a first run, just pass.
        torch.nn.init.xavier_uniform_(self.C.weight)
        torch.nn.init.xavier_uniform_(self.H)
        torch.nn.init.xavier_uniform_(self.W)
        torch.nn.init.xavier_uniform_(self.U)

        
    def forward(self, x):
        
        # x is of dimenson N = batch size X n-1
        
        # N X (n-1) X m 
        x = self.C(x)
        
        # N
        B = x.shape[0]
        
        # N X (n-1) * m -> concat les embeddings 
        x = x.view(B, -1)
    
        # N X V
        y = self.b + torch.matmul(x, self.W) + torch.matmul(nn.Tanh()(self.d + torch.matmul(x, self.H)), self.U)
        
        return y

### Set up the model.

In [235]:
# Identical to lecture.
criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
model = CharacterNeuralLanguageModel(
    train_dataset.get_vocab_size(), m, h, n
).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler =torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

### Train the model.

In [238]:
def calculate_perplexity(total_loss, total_batches):
    return torch.exp(torch.tensor(total_loss / total_batches)).item()

In [239]:
def train(dataloader, model, optimizer, criterion, epoch):
    model.train()
    total_loss, total_batches = 0.0, 0.0
    log_interval = 500

    for idx, (x, y) in tqdm(enumerate(dataloader)):
        optimizer.zero_grad()
        
        logits = model(x)
                        
        # Get the loss.
        loss = criterion(input=logits, target=y.squeeze(-1))

        # Do back propagation.
        loss.backward()
                        
        # Clip the gradients so they don't explode. Look at how this is done in lecture.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        
        # Do an optimization step.
        optimizer.step()
        
        total_loss += loss.item()
        total_batches += 1
                
        if idx % log_interval == 0 and idx > 0:
            perplexity = calculate_perplexity(total_loss,  total_batches)
            print(
                "| epoch {:3d} "
                "| {:5d}/{:5d} batches "
                "| perplexity {:8.3f} "
                "| loss {:8.3f} "
                .format(
                    epoch,
                    idx,
                    len(dataloader),
                    perplexity,
                    total_loss / total_batches,
                )
            )
            total_loss, total_batches = 0.0, 0

In [240]:
def evaluate(dataloader, model, criterion):
    model.eval()
    total_loss, total_batches = 0.0, 0

    with torch.no_grad():
        for idx, (x, y) in enumerate(dataloader):
            logits = model(x)
            total_loss += criterion(input=logits, target=y.squeeze(-1)).item()
            total_batches += 1
    return total_loss / total_batches, calculate_perplexity(total_loss,  total_batches)

In [241]:
for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader, model, optimizer, criterion, epoch)
    loss_val, perplexity_val = evaluate(validation_dataloader, model, criterion)
    scheduler.step()
    print("-" * 59)
    print(
        "| end of epoch {:3d} "
        "| time: {:5.2f}s "
        "| valid perplexity {:8.3f} "
        "| valid loss {:8.3f}".format(
            epoch,
            time.time() - epoch_start_time,
            perplexity_val,
            loss_val
        )
    )
    print("-" * 59)

print("Checking the results of test dataset.")
loss_test, perplexity_test =evaluate(test_dataloader, model, criterion)
print("test perplexity {:8.3f} | test loss {:8.3f} ".format(perplexity_test, loss_test))

999it [00:00, 2581.39it/s]

| epoch   1 |   500/15247 batches | perplexity    6.911 | loss    1.933 
| epoch   1 |  1000/15247 batches | perplexity    7.198 | loss    1.974 


1794it [00:00, 2640.36it/s]

| epoch   1 |  1500/15247 batches | perplexity    6.807 | loss    1.918 
| epoch   1 |  2000/15247 batches | perplexity    7.080 | loss    1.957 


2886it [00:01, 2717.80it/s]

| epoch   1 |  2500/15247 batches | perplexity    6.921 | loss    1.935 
| epoch   1 |  3000/15247 batches | perplexity    7.212 | loss    1.976 


4001it [00:01, 2759.64it/s]

| epoch   1 |  3500/15247 batches | perplexity    7.158 | loss    1.968 
| epoch   1 |  4000/15247 batches | perplexity    7.093 | loss    1.959 


4828it [00:01, 2729.75it/s]

| epoch   1 |  4500/15247 batches | perplexity    6.966 | loss    1.941 
| epoch   1 |  5000/15247 batches | perplexity    7.051 | loss    1.953 


5939it [00:02, 2768.69it/s]

| epoch   1 |  5500/15247 batches | perplexity    7.008 | loss    1.947 
| epoch   1 |  6000/15247 batches | perplexity    6.832 | loss    1.922 


7063it [00:02, 2797.95it/s]

| epoch   1 |  6500/15247 batches | perplexity    6.770 | loss    1.912 
| epoch   1 |  7000/15247 batches | perplexity    6.897 | loss    1.931 


7904it [00:02, 2780.11it/s]

| epoch   1 |  7500/15247 batches | perplexity    7.047 | loss    1.953 
| epoch   1 |  8000/15247 batches | perplexity    6.766 | loss    1.912 


9006it [00:03, 2684.47it/s]

| epoch   1 |  8500/15247 batches | perplexity    7.046 | loss    1.952 
| epoch   1 |  9000/15247 batches | perplexity    6.959 | loss    1.940 


9817it [00:03, 2602.89it/s]

| epoch   1 |  9500/15247 batches | perplexity    7.050 | loss    1.953 
| epoch   1 | 10000/15247 batches | perplexity    7.029 | loss    1.950 


10902it [00:04, 2687.96it/s]

| epoch   1 | 10500/15247 batches | perplexity    7.001 | loss    1.946 
| epoch   1 | 11000/15247 batches | perplexity    6.833 | loss    1.922 


12002it [00:04, 2684.70it/s]

| epoch   1 | 11500/15247 batches | perplexity    6.919 | loss    1.934 
| epoch   1 | 12000/15247 batches | perplexity    6.949 | loss    1.939 


12826it [00:04, 2712.03it/s]

| epoch   1 | 12500/15247 batches | perplexity    6.820 | loss    1.920 
| epoch   1 | 13000/15247 batches | perplexity    6.985 | loss    1.944 


13935it [00:05, 2745.53it/s]

| epoch   1 | 13500/15247 batches | perplexity    7.195 | loss    1.973 
| epoch   1 | 14000/15247 batches | perplexity    6.977 | loss    1.943 


15043it [00:05, 2759.05it/s]

| epoch   1 | 14500/15247 batches | perplexity    6.990 | loss    1.944 
| epoch   1 | 15000/15247 batches | perplexity    6.968 | loss    1.941 


15247it [00:05, 2697.48it/s]


-----------------------------------------------------------
| end of epoch   1 | time:  5.68s | valid perplexity    7.084 | valid loss    1.958
-----------------------------------------------------------


259it [00:00, 2583.29it/s]

| epoch   2 |   500/15247 batches | perplexity    6.876 | loss    1.928 


808it [00:00, 2664.13it/s]

| epoch   2 |  1000/15247 batches | perplexity    7.101 | loss    1.960 


1345it [00:00, 2649.67it/s]

| epoch   2 |  1500/15247 batches | perplexity    7.041 | loss    1.952 


1887it [00:00, 2679.04it/s]

| epoch   2 |  2000/15247 batches | perplexity    7.002 | loss    1.946 


2444it [00:00, 2732.90it/s]

| epoch   2 |  2500/15247 batches | perplexity    6.849 | loss    1.924 


2992it [00:01, 2719.37it/s]

| epoch   2 |  3000/15247 batches | perplexity    7.132 | loss    1.965 


3533it [00:01, 2669.50it/s]

| epoch   2 |  3500/15247 batches | perplexity    6.864 | loss    1.926 


3801it [00:01, 2648.28it/s]

| epoch   2 |  4000/15247 batches | perplexity    7.046 | loss    1.952 


4339it [00:01, 2659.34it/s]

| epoch   2 |  4500/15247 batches | perplexity    7.036 | loss    1.951 


4875it [00:01, 2665.24it/s]

| epoch   2 |  5000/15247 batches | perplexity    7.010 | loss    1.947 


5424it [00:02, 2699.62it/s]

| epoch   2 |  5500/15247 batches | perplexity    6.931 | loss    1.936 


5964it [00:02, 2640.54it/s]

| epoch   2 |  6000/15247 batches | perplexity    7.000 | loss    1.946 


6516it [00:02, 2620.84it/s]

| epoch   2 |  6500/15247 batches | perplexity    6.869 | loss    1.927 


6786it [00:02, 2641.69it/s]

| epoch   2 |  7000/15247 batches | perplexity    6.934 | loss    1.936 


7327it [00:02, 2674.95it/s]

| epoch   2 |  7500/15247 batches | perplexity    6.962 | loss    1.940 


7884it [00:02, 2733.25it/s]

| epoch   2 |  8000/15247 batches | perplexity    7.041 | loss    1.952 


8430it [00:03, 2704.15it/s]

| epoch   2 |  8500/15247 batches | perplexity    6.900 | loss    1.932 


8974it [00:03, 2704.54it/s]

| epoch   2 |  9000/15247 batches | perplexity    6.941 | loss    1.937 


9520it [00:03, 2704.99it/s]

| epoch   2 |  9500/15247 batches | perplexity    7.178 | loss    1.971 


9794it [00:03, 2714.83it/s]

| epoch   2 | 10000/15247 batches | perplexity    7.172 | loss    1.970 


10331it [00:03, 2610.40it/s]

| epoch   2 | 10500/15247 batches | perplexity    7.003 | loss    1.946 


10866it [00:04, 2609.08it/s]

| epoch   2 | 11000/15247 batches | perplexity    6.907 | loss    1.933 


11397it [00:04, 2611.17it/s]

| epoch   2 | 11500/15247 batches | perplexity    6.902 | loss    1.932 


11956it [00:04, 2700.42it/s]

| epoch   2 | 12000/15247 batches | perplexity    6.855 | loss    1.925 


12508it [00:04, 2681.09it/s]

| epoch   2 | 12500/15247 batches | perplexity    6.961 | loss    1.940 


12781it [00:04, 2694.64it/s]

| epoch   2 | 13000/15247 batches | perplexity    6.978 | loss    1.943 


13331it [00:04, 2723.10it/s]

| epoch   2 | 13500/15247 batches | perplexity    6.942 | loss    1.938 


13877it [00:05, 2724.47it/s]

| epoch   2 | 14000/15247 batches | perplexity    7.082 | loss    1.958 


14956it [00:05, 2616.81it/s]

| epoch   2 | 14500/15247 batches | perplexity    7.103 | loss    1.960 
| epoch   2 | 15000/15247 batches | perplexity    6.912 | loss    1.933 


15247it [00:05, 2659.07it/s]


-----------------------------------------------------------
| end of epoch   2 | time:  5.76s | valid perplexity    7.084 | valid loss    1.958
-----------------------------------------------------------


257it [00:00, 2568.61it/s]

| epoch   3 |   500/15247 batches | perplexity    6.899 | loss    1.931 


817it [00:00, 2739.58it/s]

| epoch   3 |  1000/15247 batches | perplexity    6.902 | loss    1.932 


1361it [00:00, 2627.26it/s]

| epoch   3 |  1500/15247 batches | perplexity    7.139 | loss    1.966 


1889it [00:00, 2626.79it/s]

| epoch   3 |  2000/15247 batches | perplexity    6.803 | loss    1.917 


2419it [00:00, 2621.59it/s]

| epoch   3 |  2500/15247 batches | perplexity    6.788 | loss    1.915 


2966it [00:01, 2636.31it/s]

| epoch   3 |  3000/15247 batches | perplexity    7.003 | loss    1.946 


3514it [00:01, 2686.09it/s]

| epoch   3 |  3500/15247 batches | perplexity    7.030 | loss    1.950 


3787it [00:01, 2696.89it/s]

| epoch   3 |  4000/15247 batches | perplexity    7.103 | loss    1.961 


4330it [00:01, 2705.40it/s]

| epoch   3 |  4500/15247 batches | perplexity    7.018 | loss    1.948 


4877it [00:01, 2702.59it/s]

| epoch   3 |  5000/15247 batches | perplexity    6.889 | loss    1.930 


5427it [00:02, 2727.45it/s]

| epoch   3 |  5500/15247 batches | perplexity    6.909 | loss    1.933 


5978it [00:02, 2742.29it/s]

| epoch   3 |  6000/15247 batches | perplexity    7.029 | loss    1.950 


6531it [00:02, 2720.59it/s]

| epoch   3 |  6500/15247 batches | perplexity    7.067 | loss    1.955 


6806it [00:02, 2728.70it/s]

| epoch   3 |  7000/15247 batches | perplexity    6.963 | loss    1.941 


7361it [00:02, 2754.03it/s]

| epoch   3 |  7500/15247 batches | perplexity    7.074 | loss    1.956 


7920it [00:02, 2759.25it/s]

| epoch   3 |  8000/15247 batches | perplexity    6.960 | loss    1.940 


8473it [00:03, 2699.84it/s]

| epoch   3 |  8500/15247 batches | perplexity    6.809 | loss    1.918 


9028it [00:03, 2738.99it/s]

| epoch   3 |  9000/15247 batches | perplexity    6.998 | loss    1.946 


9306it [00:03, 2749.46it/s]

| epoch   3 |  9500/15247 batches | perplexity    7.069 | loss    1.956 


9861it [00:03, 2727.31it/s]

| epoch   3 | 10000/15247 batches | perplexity    7.062 | loss    1.955 


10407it [00:03, 2720.45it/s]

| epoch   3 | 10500/15247 batches | perplexity    6.924 | loss    1.935 


10951it [00:04, 2686.61it/s]

| epoch   3 | 11000/15247 batches | perplexity    7.155 | loss    1.968 


11513it [00:04, 2745.98it/s]

| epoch   3 | 11500/15247 batches | perplexity    7.021 | loss    1.949 


11788it [00:04, 2736.06it/s]

| epoch   3 | 12000/15247 batches | perplexity    6.948 | loss    1.938 


12335it [00:04, 2721.71it/s]

| epoch   3 | 12500/15247 batches | perplexity    6.821 | loss    1.920 


12895it [00:04, 2760.74it/s]

| epoch   3 | 13000/15247 batches | perplexity    7.076 | loss    1.957 


13454it [00:04, 2763.03it/s]

| epoch   3 | 13500/15247 batches | perplexity    6.744 | loss    1.909 


14016it [00:05, 2787.91it/s]

| epoch   3 | 14000/15247 batches | perplexity    7.006 | loss    1.947 


14295it [00:05, 2764.63it/s]

| epoch   3 | 14500/15247 batches | perplexity    7.110 | loss    1.961 


14849it [00:05, 2760.76it/s]

| epoch   3 | 15000/15247 batches | perplexity    7.075 | loss    1.957 


15247it [00:05, 2710.54it/s]


-----------------------------------------------------------
| end of epoch   3 | time:  5.66s | valid perplexity    7.087 | valid loss    1.958
-----------------------------------------------------------


261it [00:00, 2601.34it/s]

| epoch   4 |   500/15247 batches | perplexity    7.009 | loss    1.947 


803it [00:00, 2677.38it/s]

| epoch   4 |  1000/15247 batches | perplexity    7.054 | loss    1.954 


1352it [00:00, 2690.32it/s]

| epoch   4 |  1500/15247 batches | perplexity    6.795 | loss    1.916 


1888it [00:00, 2645.74it/s]

| epoch   4 |  2000/15247 batches | perplexity    7.065 | loss    1.955 


2968it [00:01, 2600.40it/s]

| epoch   4 |  2500/15247 batches | perplexity    6.752 | loss    1.910 
| epoch   4 |  3000/15247 batches | perplexity    6.729 | loss    1.906 


3775it [00:01, 2658.21it/s]

| epoch   4 |  3500/15247 batches | perplexity    7.090 | loss    1.959 
| epoch   4 |  4000/15247 batches | perplexity    7.056 | loss    1.954 


4888it [00:01, 2756.06it/s]

| epoch   4 |  4500/15247 batches | perplexity    6.845 | loss    1.924 
| epoch   4 |  5000/15247 batches | perplexity    6.915 | loss    1.934 


6004it [00:02, 2738.98it/s]

| epoch   4 |  5500/15247 batches | perplexity    6.898 | loss    1.931 
| epoch   4 |  6000/15247 batches | perplexity    6.959 | loss    1.940 


6820it [00:02, 2650.68it/s]

| epoch   4 |  6500/15247 batches | perplexity    7.094 | loss    1.959 
| epoch   4 |  7000/15247 batches | perplexity    7.034 | loss    1.951 


7899it [00:02, 2677.94it/s]

| epoch   4 |  7500/15247 batches | perplexity    6.860 | loss    1.926 
| epoch   4 |  8000/15247 batches | perplexity    7.074 | loss    1.956 


9009it [00:03, 2750.47it/s]

| epoch   4 |  8500/15247 batches | perplexity    6.972 | loss    1.942 
| epoch   4 |  9000/15247 batches | perplexity    7.017 | loss    1.948 


9829it [00:03, 2700.56it/s]

| epoch   4 |  9500/15247 batches | perplexity    6.995 | loss    1.945 
| epoch   4 | 10000/15247 batches | perplexity    7.206 | loss    1.975 


10927it [00:04, 2703.85it/s]

| epoch   4 | 10500/15247 batches | perplexity    7.001 | loss    1.946 
| epoch   4 | 11000/15247 batches | perplexity    7.048 | loss    1.953 


12015it [00:04, 2695.12it/s]

| epoch   4 | 11500/15247 batches | perplexity    7.132 | loss    1.965 
| epoch   4 | 12000/15247 batches | perplexity    6.945 | loss    1.938 


12829it [00:04, 2686.21it/s]

| epoch   4 | 12500/15247 batches | perplexity    6.942 | loss    1.938 
| epoch   4 | 13000/15247 batches | perplexity    6.968 | loss    1.941 


13918it [00:05, 2701.15it/s]

| epoch   4 | 13500/15247 batches | perplexity    6.848 | loss    1.924 
| epoch   4 | 14000/15247 batches | perplexity    7.037 | loss    1.951 


15003it [00:05, 2699.90it/s]

| epoch   4 | 14500/15247 batches | perplexity    7.146 | loss    1.967 
| epoch   4 | 15000/15247 batches | perplexity    7.089 | loss    1.959 


15247it [00:05, 2679.23it/s]


-----------------------------------------------------------
| end of epoch   4 | time:  5.72s | valid perplexity    7.081 | valid loss    1.957
-----------------------------------------------------------


246it [00:00, 2454.60it/s]

| epoch   5 |   500/15247 batches | perplexity    6.994 | loss    1.945 


789it [00:00, 2652.76it/s]

| epoch   5 |  1000/15247 batches | perplexity    6.866 | loss    1.927 


1321it [00:00, 2641.52it/s]

| epoch   5 |  1500/15247 batches | perplexity    6.852 | loss    1.925 


1878it [00:00, 2725.71it/s]

| epoch   5 |  2000/15247 batches | perplexity    6.941 | loss    1.937 


2424it [00:00, 2716.05it/s]

| epoch   5 |  2500/15247 batches | perplexity    7.036 | loss    1.951 


2968it [00:01, 2708.67it/s]

| epoch   5 |  3000/15247 batches | perplexity    7.012 | loss    1.948 


3510it [00:01, 2693.30it/s]

| epoch   5 |  3500/15247 batches | perplexity    7.035 | loss    1.951 


3780it [00:01, 2679.56it/s]

| epoch   5 |  4000/15247 batches | perplexity    7.025 | loss    1.949 


4844it [00:01, 2594.19it/s]

| epoch   5 |  4500/15247 batches | perplexity    6.965 | loss    1.941 
| epoch   5 |  5000/15247 batches | perplexity    6.872 | loss    1.928 


5910it [00:02, 2645.15it/s]

| epoch   5 |  5500/15247 batches | perplexity    7.065 | loss    1.955 
| epoch   5 |  6000/15247 batches | perplexity    7.148 | loss    1.967 


6969it [00:02, 2567.57it/s]

| epoch   5 |  6500/15247 batches | perplexity    6.905 | loss    1.932 
| epoch   5 |  7000/15247 batches | perplexity    6.893 | loss    1.931 


8046it [00:03, 2671.66it/s]

| epoch   5 |  7500/15247 batches | perplexity    6.925 | loss    1.935 
| epoch   5 |  8000/15247 batches | perplexity    6.732 | loss    1.907 


8843it [00:03, 2592.95it/s]

| epoch   5 |  8500/15247 batches | perplexity    6.997 | loss    1.946 
| epoch   5 |  9000/15247 batches | perplexity    7.144 | loss    1.966 


9930it [00:03, 2697.03it/s]

| epoch   5 |  9500/15247 batches | perplexity    6.970 | loss    1.942 
| epoch   5 | 10000/15247 batches | perplexity    6.902 | loss    1.932 


11046it [00:04, 2776.75it/s]

| epoch   5 | 10500/15247 batches | perplexity    7.122 | loss    1.963 
| epoch   5 | 11000/15247 batches | perplexity    7.010 | loss    1.947 


11871it [00:04, 2587.63it/s]

| epoch   5 | 11500/15247 batches | perplexity    7.045 | loss    1.952 
| epoch   5 | 12000/15247 batches | perplexity    7.011 | loss    1.948 


12953it [00:04, 2643.77it/s]

| epoch   5 | 12500/15247 batches | perplexity    7.064 | loss    1.955 
| epoch   5 | 13000/15247 batches | perplexity    6.975 | loss    1.942 


13785it [00:05, 2694.66it/s]

| epoch   5 | 13500/15247 batches | perplexity    7.064 | loss    1.955 
| epoch   5 | 14000/15247 batches | perplexity    6.988 | loss    1.944 


14901it [00:05, 2723.44it/s]

| epoch   5 | 14500/15247 batches | perplexity    6.790 | loss    1.915 
| epoch   5 | 15000/15247 batches | perplexity    7.068 | loss    1.956 


15247it [00:05, 2648.59it/s]


-----------------------------------------------------------
| end of epoch   5 | time:  5.79s | valid perplexity    7.084 | valid loss    1.958
-----------------------------------------------------------
Checking the results of test dataset.
test perplexity    6.974 | test loss    1.942 


## Generate some text.

In [249]:
def generate_word(model, dataset, window):
    generated_word = []
    # Set the context to a window-1 length array having just the MARKER character's token_id.
    context = dataset.__getitem__(0)[:window-1]
    
    while True:
        logits = model(torch.tensor(context).view(1, -1))
        
        # Get the probabilities from the logits.
        # Hint: softmax!
        softmax = nn.Softmax(dim=1)
        probs = softmax(logits)
        
        # Get 1 sample from a multinomial having the above probabilities.
        token_id = torch.multinomial(probs,1).item()
        
        # Append the token_id to the generated word.
        generated_word.append(token_id)
        
        # Move the context over 1, drop the first (oldest) token and apped the new one above.
        # The size of the resulting context should be the same.
        # For exaple, if it was "[0, 1, 2]" and you generated 4, it should now be [1, 2, 4].
        context = context[1:] + [token_id]
        
        if token_id == 0:
            # If you generate token_id = 0, i.e. '.', break out.
            break
    # Return and decode the generated word to a string.        
    return ''.join(dataset.decode(generated_word))

In [250]:
torch.manual_seed(1)
for _ in range(50):
    print(generate_word(model, train_dataset, n))

anzorie.
lia.
aldin.
xalrest.
dez.
briartai.
rielciylend.
maderion.
caera.
dacelian.
alalie.
shais.
mayas.
tysi.
braxia.
tye.
karie.
gros.
auk.
kanaran.
anyaamius.
kelee.
har.
jami.
naeka.
reem.
kaylen.
quyla.
namius.
bylly.
jutram.
ahazoriexsunyrin.
jen.
tirooni.
evfiah.
rosi.
rouitta.
ynnlaydon.
kenassavoli.
wlynn.
nalaira.
anir.
ilyn.
marri.
alevante.
kalyn.
desleeshanna.
daniullae.
rmanaililah.
cyle.
