In [1]:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, TensorDataset
import time
from tqdm import tqdm

### Information
- We will do a few preliminary exercises and also build a character level MLP language model.
- This model will be similar to the model we did in class, except that we will have characters as tokens, not words.
- You will need a conda environment for this, here is general information on this.
 - https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html
 - PyTorch: https://anaconda.org/pytorch/pytorch
 
In the code below, FILL-IN the code necessary in the hint string provided.

### Preliminary exercises
- Please fill in the cells below with the asked for data.

In [2]:
torch.manual_seed(1)

<torch._C.Generator at 0x1dab7f4ca70>

In [3]:
# Create an embedding layer for a vocabulary of size 10 and the word vectors are each of dimension 5.
e = nn.Embedding(10, 5)

# Extract the embedding for the word whose token index is 3. What is the shape of this vector?
v = e(torch.tensor(3)) # should be a vector of size 5
print(v.shape)

# Extract the weight matrix from the layer e.
w = e.weight
# Create a linear layer (with no bias) of size 10 by 5 and set it's data to the embedding matrix.
l = nn.Linear(10, 5)
l.weight = w

# Insert inside of the assert below some sort of equality check between l.weight and e.weight; it should pass to true.
# Hint: look up torch.all() and torch.eq()
assert(torch.all(torch.eq(l.weight, e.weight)))

torch.Size([5])


In [4]:
# Create a batch of size 2 with entries [0, 1, 2] and [2, 3, 4] in the data batch.
x = torch.tensor([
    [0, 1, 2],
    [2, 3, 4]
])

In [5]:
# What is the dimension of this batch ran through the embeding layer?
# e will tranform each integer to a vector of size 5. So we hould expect an output of size (2, 3, 5)
assert(e(x).shape == torch.Size([2, 3, 5]))

### Constants and configs used below.

In [6]:
DEVICE = "cpu"
LR = 4.0
BATCH_SIZE = 16
NUM_EPOCHS = 5
MARKER = '.'
# N-gram level; P(w_t | w_{t-1}, ..., w_{t-n+1}).
# We use 3 words to predict the next word.
n = 4
# Hidden layer dimension.
h = 20
# Word embedding dimension.
m = 20

### Get the dataset and the tokenizer.

In [7]:
class CharDataset(Dataset):
    def __init__(self, words, chars):
        self.words = words
        self.chars = chars
        # Inverse dictionaries mapping char tokens to unique ids and the reverse.
        # Tokens in this case are the unique chars we passed in above.
        # Each token should be mappend to a unique integer and MARKER should have token 0.
        # For example, stoi should be like {'.' -> 0, 'a' -> 1, 'b' -> 2} if I pass in chars = '.ab'.
        self.stoi = {c:i for (c, i) in zip(chars, range(len(chars)))}
        self.itos = {i:c for (c, i) in self.stoi.items()}

    def __len__(self):
        # Number of words.
        return len(self.words)

    def contains(self, word):
        # Check if word is in self.words and return True/False if it is, is not.
        return (word in self.words)

    def get_vocab_size(self):
        # Return the vocabulary size.
        return len(set(self.words))

    def encode(self, word):
        # Express this word as a list of int ids. For example, maybe ".abc" -> [0, 1, 2, 3].
        # This assumes 'a' -> 1, etc.
        return [self.stoi[c] for c in word]
    
    def decode(self, tokens):
        # For a set of tokens, return back the string.
        # For example, maybe [1, 1, 2] -> "aac"
        return ''.join([self.itos[i] for i in tokens])

    def __getitem__(self, idx):
        # This is used so we can loop over the data.
        word = self.words[idx]
        return self.encode(word)

In [8]:
def create_datasets(window, input_file = 'names.txt'):
    """
    This takes a file of words and separates all the words.
    It then gets all the characters present in the universe of words and then ouputs the statistics. 
    """
    with open(input_file, 'r') as f:
        data = f.read()
    # Split the file by new lines. You should get a list of names.
    words = data.split('\n')
    words = [w.strip() for w in words] # This gets rid of any trailing and starting white spaces.
    words = [w for w in words if w != ''] # Filter out all the empty words.
    
    chars = sorted(list(set(''.join(words)))) # This gets the universe of all characters.
    
    # Will force chars to have MARKER having index 0.
    chars= [MARKER] + chars
    
    # Pad each word with a context window of size n-1.
    # Why? a word like "abc" should becomes "..abc.." if the window is size 3.
    # This is some we can get pair of (x, y) data like this: ".." -> "a", ".a" -> "b", "ab" -> "c", "bc" -> ".", "c." -> "."
    # I.e. this allows us to know that "a" is a start character.
    # So you should get something like ["ab", "c"] -> ["..ab..", "..c.."], for example.
    words = ['.' * (window - 1) + w + '.' * (window - 1) for w in words]
            
    print(f"The number of examples in the dataset: {len(words)}")
    print(f"The number of unique characters in the vocabulary: {len(chars)}")
    print(f"The vocabulary we have is: {''.join(chars)}")

    # Partition the input data into a training, validation, and the test set.
    out_of_sample_set_size = min(2000, int(len(words) * 0.1)) # We use 10% of the training set, or up to 2000 examples.
    test_set_size = 1500
    
    # First, get a random permutation of randomly permute of size len(words).
    # Then, convert this to a list. 
    # This index list is used below to get the train, validation, and test sets.
    rp = torch.randperm(len(words)).tolist()
    
    # Get train, validation, and test set.
    train_words = [words[i] for i in rp[:-out_of_sample_set_size]]
    validation_words = [words[i] for i in rp[-out_of_sample_set_size:-test_set_size]]
    test_words = [words[i] for i in rp[-test_set_size:]]    
    
    print(f"We've split up the dataset into {len(train_words)}, {len(validation_words)}, {len(test_words)} training, validation, and test examples")

    # But the data in the data set objects.
    train_dataset = CharDataset(train_words, chars)
    validation_dataset = CharDataset(validation_words, chars)
    test_dataset = CharDataset(test_words, chars)

    return train_dataset, validation_dataset, test_dataset

In [9]:
train_dataset, validation_dataset, test_dataset = create_datasets(n)

The number of examples in the dataset: 32033
The number of unique characters in the vocabulary: 27
The vocabulary we have is: .abcdefghijklmnopqrstuvwxyz
We've split up the dataset into 30033, 500, 1500 training, validation, and test examples


## Explore the data

In [10]:
# Get the first word in "train_dataset"
train_dataset.words[0]

'...zaliya...'

In [11]:
# Get the stoi map of train_dataset. How many keys does it have?
print(train_dataset.stoi)
print(f'It has {len(train_dataset.stoi.keys())} keys')

{'.': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26}
It has 27 keys


### Get the dataloader

In [12]:
def create_dataloader(dataset, window):
    x_list = []
    y_list = []
    # For ech word.
    for i, word in enumerate(dataset):
        # Grab a context of size window and window-1 characters will be in x, 1 will be in y.
        for j, _ in enumerate(word):
            # If there is no window of size window left, break.
            if j + window > len(word) - 1:
                break
            word_window = word[j:j + window + 1]
            x, y = word_window[:-1], word_window[-1]
            x_list.append(x)
            y_list.append(y)
            
    return DataLoader(
        TensorDataset(torch.tensor(x_list), torch.tensor(y_list)),
        BATCH_SIZE,
        shuffle=True
    )

In [13]:
train_dataloader = create_dataloader(train_dataset, n-1)
validation_dataloader = create_dataloader(validation_dataset, n-1)
test_dataloader = create_dataloader(test_dataset, n-1)

### Set up the model
- Identical to lecture. Please look over that!

In [14]:
# One of the first Neural language models!
class CharacterNeuralLanguageModel(nn.Module):
    def __init__(self, V, m, h, n):
        super(CharacterNeuralLanguageModel, self).__init__()
        
        # Vocabulary size.
        self.V = V
        
        # Embedding dimension, per word.
        self.m = m
        
        # Hidden dimension.
        self.h = h
        
        # N in "N-gram"
        self.n = n
        print(n)
        
        # Can you change all this stuff to use nn.Linear?
        # Ca also use nn.Parameter(torch.zeros(V, m)) for self.C but then we need one-hot and this is slow.
        self.C = nn.Embedding(V, m) 
        self.H = nn.Parameter(torch.zeros((n-1) * m, h))
        self.W = nn.Parameter(torch.zeros((n-1) * m, V))
        self.U = nn.Parameter(torch.zeros(h, V))
        
        self.b = torch.nn.Parameter(torch.ones(V))
        self.d = torch.nn.Parameter(torch.ones(h))
        
        self.init_weights()

    def init_weights(self):
        # Intitialize C, H, W, U in a nice way. Use xavier initialization for the weights.
        # On a first run, just pass.
        nn.init.xavier_uniform_(self.C.weight)
        nn.init.xavier_uniform_(self.H)
        nn.init.xavier_uniform_(self.W)
        nn.init.xavier_uniform_(self.U)
        
    def forward(self, x):
        
        # x is of dimenson N = batch size X n-1
        
        # N X (n-1) X m 
        x = self.C(x)
        
        # N
        B = x.shape[0]
        
        # N X (n-1) * m
        x = x.view(B, -1)
    
        # N X V
        y = self.b + torch.matmul(x, self.W) + torch.matmul(nn.Tanh()(self.d + torch.matmul(x, self.H)), self.U)
        
        return y

### Set up the model.

In [15]:
# Identical to lecture.
criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
model = CharacterNeuralLanguageModel(
    train_dataset.get_vocab_size(), m, h, n
).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

4


In [16]:
# How many parameters does the neural network have?
# Hint: look up model.named_parameters and the method "nelement" on a tensor.
# See also the XOR notebook where we count the gradients that are 0.
# There, we loop over the parameters.
number_parameters = sum([param.nelement() for _, param in model.named_parameters()])
print(f'There are {number_parameters} parameters in the model')

There are 2812151 parameters in the model


### Train the model.

In [17]:
def calculate_perplexity(total_loss, total_batches):
    return torch.exp(torch.tensor(total_loss / total_batches)).item()

In [18]:
def train(dataloader, model, optimizer, criterion, epoch):
    model.train()
    total_loss, total_batches = 0.0, 0.0
    log_interval = 500

    for idx, (x, y) in tqdm(enumerate(dataloader)):
        optimizer.zero_grad()
        
        logits = model(x)
                        
        # Get the loss.
        loss = criterion(input=logits, target=y.squeeze(-1))

        # Do back propagation.
        loss.backward()
                        
        # Clip the gradients so they don't explode. Look at how this is done in lecture.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        
        # Do an optimization step.
        optimizer.step()
        
        total_loss += loss.item()
        total_batches += 1
                
        if idx % log_interval == 0 and idx > 0:
            perplexity = calculate_perplexity(total_loss,  total_batches)
            print(
                "| epoch {:3d} "
                "| {:5d}/{:5d} batches "
                "| perplexity {:8.3f} "
                "| loss {:8.3f} "
                .format(
                    epoch,
                    idx,
                    len(dataloader),
                    perplexity,
                    total_loss / total_batches,
                )
            )
            total_loss, total_batches = 0.0, 0

In [19]:
def evaluate(dataloader, model, criterion):
    model.eval()
    total_loss, total_batches = 0.0, 0

    with torch.no_grad():
        for idx, (x, y) in enumerate(dataloader):
            logits = model(x)
            total_loss += criterion(input=logits, target=y.squeeze(-1)).item()
            total_batches += 1
    return total_loss / total_batches, calculate_perplexity(total_loss,  total_batches)

In [20]:
for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader, model, optimizer, criterion, epoch)
    loss_val, perplexity_val = evaluate(validation_dataloader, model, criterion)
    scheduler.step()
    print("-" * 59)
    print(
        "| end of epoch {:3d} "
        "| time: {:5.2f}s "
        "| valid perplexity {:8.3f} "
        "| valid loss {:8.3f}".format(
            epoch,
            time.time() - epoch_start_time,
            perplexity_val,
            loss_val
        )
    )
    print("-" * 59)

print("Checking the results of test dataset.")
loss_test, perplexity_test = evaluate(test_dataloader, model, criterion)
print("test perplexity {:8.3f} | test loss {:8.3f} ".format(perplexity_test, loss_test))

508it [00:12, 42.17it/s]

| epoch   1 |   500/17123 batches | perplexity   12.394 | loss    2.517 


1006it [00:26, 37.13it/s]

| epoch   1 |  1000/17123 batches | perplexity    7.784 | loss    2.052 


1507it [00:39, 38.94it/s]

| epoch   1 |  1500/17123 batches | perplexity    7.436 | loss    2.006 


2005it [00:52, 39.62it/s]

| epoch   1 |  2000/17123 batches | perplexity    7.357 | loss    1.996 


2505it [01:04, 40.99it/s]

| epoch   1 |  2500/17123 batches | perplexity    6.886 | loss    1.929 


3008it [01:16, 43.28it/s]

| epoch   1 |  3000/17123 batches | perplexity    6.971 | loss    1.942 


3509it [01:29, 42.63it/s]

| epoch   1 |  3500/17123 batches | perplexity    6.912 | loss    1.933 


4006it [01:41, 39.72it/s]

| epoch   1 |  4000/17123 batches | perplexity    6.794 | loss    1.916 


4508it [01:54, 39.11it/s]

| epoch   1 |  4500/17123 batches | perplexity    6.949 | loss    1.939 


5005it [02:07, 40.75it/s]

| epoch   1 |  5000/17123 batches | perplexity    6.886 | loss    1.929 


5509it [02:20, 41.52it/s]

| epoch   1 |  5500/17123 batches | perplexity    6.763 | loss    1.912 


6004it [02:32, 42.36it/s]

| epoch   1 |  6000/17123 batches | perplexity    6.921 | loss    1.935 


6507it [02:44, 41.90it/s]

| epoch   1 |  6500/17123 batches | perplexity    6.633 | loss    1.892 


7009it [02:56, 41.46it/s]

| epoch   1 |  7000/17123 batches | perplexity    6.703 | loss    1.903 


7506it [03:08, 41.17it/s]

| epoch   1 |  7500/17123 batches | perplexity    6.392 | loss    1.855 


8007it [03:21, 38.47it/s]

| epoch   1 |  8000/17123 batches | perplexity    6.312 | loss    1.842 


8509it [03:34, 41.81it/s]

| epoch   1 |  8500/17123 batches | perplexity    6.569 | loss    1.882 


9005it [03:46, 41.37it/s]

| epoch   1 |  9000/17123 batches | perplexity    6.411 | loss    1.858 


9508it [03:59, 42.68it/s]

| epoch   1 |  9500/17123 batches | perplexity    6.569 | loss    1.882 


10008it [04:11, 41.89it/s]

| epoch   1 | 10000/17123 batches | perplexity    6.427 | loss    1.860 


10507it [04:23, 39.86it/s]

| epoch   1 | 10500/17123 batches | perplexity    6.454 | loss    1.865 


11009it [04:36, 39.72it/s]

| epoch   1 | 11000/17123 batches | perplexity    6.275 | loss    1.837 


11509it [04:49, 41.60it/s]

| epoch   1 | 11500/17123 batches | perplexity    6.462 | loss    1.866 


12008it [05:01, 39.19it/s]

| epoch   1 | 12000/17123 batches | perplexity    6.326 | loss    1.845 


12508it [05:15, 39.41it/s]

| epoch   1 | 12500/17123 batches | perplexity    6.418 | loss    1.859 


13008it [05:28, 39.79it/s]

| epoch   1 | 13000/17123 batches | perplexity    6.524 | loss    1.876 


13506it [05:41, 26.49it/s]

| epoch   1 | 13500/17123 batches | perplexity    6.403 | loss    1.857 


14005it [05:56, 37.71it/s]

| epoch   1 | 14000/17123 batches | perplexity    6.361 | loss    1.850 


14509it [06:10, 38.31it/s]

| epoch   1 | 14500/17123 batches | perplexity    6.483 | loss    1.869 


15006it [06:23, 41.54it/s]

| epoch   1 | 15000/17123 batches | perplexity    6.267 | loss    1.835 


15504it [06:35, 36.23it/s]

| epoch   1 | 15500/17123 batches | perplexity    6.249 | loss    1.832 


16006it [06:48, 40.18it/s]

| epoch   1 | 16000/17123 batches | perplexity    6.341 | loss    1.847 


16504it [07:00, 40.17it/s]

| epoch   1 | 16500/17123 batches | perplexity    6.333 | loss    1.846 


17007it [07:12, 40.72it/s]

| epoch   1 | 17000/17123 batches | perplexity    6.264 | loss    1.835 


17123it [07:15, 39.31it/s]


-----------------------------------------------------------
| end of epoch   1 | time: 438.14s | valid perplexity    6.260 | valid loss    1.834
-----------------------------------------------------------


506it [00:13, 39.15it/s]

| epoch   2 |   500/17123 batches | perplexity    5.919 | loss    1.778 


1007it [00:28, 39.32it/s]

| epoch   2 |  1000/17123 batches | perplexity    5.911 | loss    1.777 


1506it [00:41, 38.10it/s]

| epoch   2 |  1500/17123 batches | perplexity    5.854 | loss    1.767 


2007it [00:55, 40.95it/s]

| epoch   2 |  2000/17123 batches | perplexity    5.850 | loss    1.766 


2508it [01:07, 38.48it/s]

| epoch   2 |  2500/17123 batches | perplexity    5.747 | loss    1.749 


3006it [01:20, 40.77it/s]

| epoch   2 |  3000/17123 batches | perplexity    5.562 | loss    1.716 


3506it [01:34, 37.20it/s]

| epoch   2 |  3500/17123 batches | perplexity    5.875 | loss    1.771 


4004it [01:48, 35.96it/s]

| epoch   2 |  4000/17123 batches | perplexity    5.621 | loss    1.726 


4506it [02:01, 40.47it/s]

| epoch   2 |  4500/17123 batches | perplexity    5.780 | loss    1.754 


5009it [02:14, 40.80it/s]

| epoch   2 |  5000/17123 batches | perplexity    5.825 | loss    1.762 


5504it [02:26, 41.83it/s]

| epoch   2 |  5500/17123 batches | perplexity    5.829 | loss    1.763 


6006it [02:38, 41.96it/s]

| epoch   2 |  6000/17123 batches | perplexity    5.711 | loss    1.742 


6507it [02:50, 40.82it/s]

| epoch   2 |  6500/17123 batches | perplexity    5.988 | loss    1.790 


7008it [03:03, 41.10it/s]

| epoch   2 |  7000/17123 batches | perplexity    5.762 | loss    1.751 


7508it [03:15, 42.11it/s]

| epoch   2 |  7500/17123 batches | perplexity    5.758 | loss    1.751 


8005it [03:27, 40.98it/s]

| epoch   2 |  8000/17123 batches | perplexity    5.655 | loss    1.733 


8508it [03:39, 40.59it/s]

| epoch   2 |  8500/17123 batches | perplexity    5.605 | loss    1.724 


9009it [03:51, 43.18it/s]

| epoch   2 |  9000/17123 batches | perplexity    5.796 | loss    1.757 


9504it [04:03, 40.04it/s]

| epoch   2 |  9500/17123 batches | perplexity    5.779 | loss    1.754 


10007it [04:16, 41.65it/s]

| epoch   2 | 10000/17123 batches | perplexity    5.834 | loss    1.764 


10504it [04:28, 40.68it/s]

| epoch   2 | 10500/17123 batches | perplexity    5.801 | loss    1.758 


11004it [04:40, 41.72it/s]

| epoch   2 | 11000/17123 batches | perplexity    5.818 | loss    1.761 


11506it [04:53, 41.11it/s]

| epoch   2 | 11500/17123 batches | perplexity    5.760 | loss    1.751 


12005it [05:05, 39.62it/s]

| epoch   2 | 12000/17123 batches | perplexity    5.811 | loss    1.760 


12506it [05:18, 40.23it/s]

| epoch   2 | 12500/17123 batches | perplexity    5.576 | loss    1.719 


13005it [05:30, 41.21it/s]

| epoch   2 | 13000/17123 batches | perplexity    5.743 | loss    1.748 


13504it [05:43, 39.84it/s]

| epoch   2 | 13500/17123 batches | perplexity    5.669 | loss    1.735 


14005it [05:55, 36.00it/s]

| epoch   2 | 14000/17123 batches | perplexity    5.687 | loss    1.738 


14509it [06:09, 38.99it/s]

| epoch   2 | 14500/17123 batches | perplexity    5.688 | loss    1.738 


15005it [06:21, 38.32it/s]

| epoch   2 | 15000/17123 batches | perplexity    5.659 | loss    1.733 


15509it [06:34, 41.79it/s]

| epoch   2 | 15500/17123 batches | perplexity    5.457 | loss    1.697 


16008it [06:46, 41.61it/s]

| epoch   2 | 16000/17123 batches | perplexity    5.726 | loss    1.745 


16505it [06:58, 40.39it/s]

| epoch   2 | 16500/17123 batches | perplexity    5.563 | loss    1.716 


17009it [07:11, 40.11it/s]

| epoch   2 | 17000/17123 batches | perplexity    5.686 | loss    1.738 


17123it [07:13, 39.46it/s]


-----------------------------------------------------------
| end of epoch   2 | time: 436.45s | valid perplexity    5.719 | valid loss    1.744
-----------------------------------------------------------


505it [00:12, 40.90it/s]

| epoch   3 |   500/17123 batches | perplexity    5.728 | loss    1.745 


1005it [00:24, 39.64it/s]

| epoch   3 |  1000/17123 batches | perplexity    5.580 | loss    1.719 


1507it [00:37, 41.65it/s]

| epoch   3 |  1500/17123 batches | perplexity    5.844 | loss    1.765 


2006it [00:49, 40.19it/s]

| epoch   3 |  2000/17123 batches | perplexity    5.689 | loss    1.739 


2505it [01:01, 42.51it/s]

| epoch   3 |  2500/17123 batches | perplexity    5.730 | loss    1.746 


3007it [01:14, 39.35it/s]

| epoch   3 |  3000/17123 batches | perplexity    5.517 | loss    1.708 


3509it [01:26, 41.59it/s]

| epoch   3 |  3500/17123 batches | perplexity    5.738 | loss    1.747 


4004it [01:38, 41.45it/s]

| epoch   3 |  4000/17123 batches | perplexity    5.645 | loss    1.731 


4509it [01:51, 41.06it/s]

| epoch   3 |  4500/17123 batches | perplexity    5.686 | loss    1.738 


5009it [02:03, 41.70it/s]

| epoch   3 |  5000/17123 batches | perplexity    5.617 | loss    1.726 


5509it [02:16, 41.46it/s]

| epoch   3 |  5500/17123 batches | perplexity    5.472 | loss    1.700 


6005it [02:28, 39.56it/s]

| epoch   3 |  6000/17123 batches | perplexity    5.715 | loss    1.743 


6509it [02:41, 39.92it/s]

| epoch   3 |  6500/17123 batches | perplexity    5.692 | loss    1.739 


7006it [02:53, 43.02it/s]

| epoch   3 |  7000/17123 batches | perplexity    5.544 | loss    1.713 


7504it [03:05, 39.80it/s]

| epoch   3 |  7500/17123 batches | perplexity    5.575 | loss    1.718 


8006it [03:18, 41.12it/s]

| epoch   3 |  8000/17123 batches | perplexity    5.636 | loss    1.729 


8507it [03:30, 40.58it/s]

| epoch   3 |  8500/17123 batches | perplexity    5.571 | loss    1.718 


9007it [03:43, 41.84it/s]

| epoch   3 |  9000/17123 batches | perplexity    5.864 | loss    1.769 


9505it [03:55, 39.95it/s]

| epoch   3 |  9500/17123 batches | perplexity    5.684 | loss    1.738 


10004it [04:07, 39.98it/s]

| epoch   3 | 10000/17123 batches | perplexity    5.632 | loss    1.728 


10506it [04:19, 40.51it/s]

| epoch   3 | 10500/17123 batches | perplexity    5.562 | loss    1.716 


11005it [04:32, 39.47it/s]

| epoch   3 | 11000/17123 batches | perplexity    5.734 | loss    1.746 


11507it [04:44, 40.98it/s]

| epoch   3 | 11500/17123 batches | perplexity    5.654 | loss    1.732 


12007it [04:56, 38.82it/s]

| epoch   3 | 12000/17123 batches | perplexity    5.620 | loss    1.726 


12506it [05:09, 40.42it/s]

| epoch   3 | 12500/17123 batches | perplexity    5.786 | loss    1.756 


13006it [05:21, 40.67it/s]

| epoch   3 | 13000/17123 batches | perplexity    5.460 | loss    1.697 


13507it [05:33, 42.59it/s]

| epoch   3 | 13500/17123 batches | perplexity    5.739 | loss    1.747 


14009it [05:46, 41.96it/s]

| epoch   3 | 14000/17123 batches | perplexity    5.618 | loss    1.726 


14508it [05:58, 42.24it/s]

| epoch   3 | 14500/17123 batches | perplexity    5.618 | loss    1.726 


15006it [06:10, 41.11it/s]

| epoch   3 | 15000/17123 batches | perplexity    5.616 | loss    1.726 


15505it [06:23, 41.56it/s]

| epoch   3 | 15500/17123 batches | perplexity    5.708 | loss    1.742 


16007it [06:35, 39.45it/s]

| epoch   3 | 16000/17123 batches | perplexity    5.473 | loss    1.700 


16505it [06:47, 39.94it/s]

| epoch   3 | 16500/17123 batches | perplexity    5.497 | loss    1.704 


17007it [06:59, 40.44it/s]

| epoch   3 | 17000/17123 batches | perplexity    5.637 | loss    1.729 


17123it [07:02, 40.50it/s]


-----------------------------------------------------------
| end of epoch   3 | time: 425.36s | valid perplexity    5.674 | valid loss    1.736
-----------------------------------------------------------


506it [00:12, 39.58it/s]

| epoch   4 |   500/17123 batches | perplexity    5.686 | loss    1.738 


1006it [00:24, 42.24it/s]

| epoch   4 |  1000/17123 batches | perplexity    5.431 | loss    1.692 


1508it [00:37, 40.87it/s]

| epoch   4 |  1500/17123 batches | perplexity    5.588 | loss    1.721 


2009it [00:49, 40.46it/s]

| epoch   4 |  2000/17123 batches | perplexity    5.814 | loss    1.760 


2505it [01:01, 41.16it/s]

| epoch   4 |  2500/17123 batches | perplexity    5.790 | loss    1.756 


3006it [01:14, 39.48it/s]

| epoch   4 |  3000/17123 batches | perplexity    5.575 | loss    1.718 


3506it [01:26, 40.23it/s]

| epoch   4 |  3500/17123 batches | perplexity    5.614 | loss    1.725 


4004it [01:39, 34.83it/s]

| epoch   4 |  4000/17123 batches | perplexity    5.549 | loss    1.714 


4509it [01:53, 37.48it/s]

| epoch   4 |  4500/17123 batches | perplexity    5.727 | loss    1.745 


5006it [02:05, 39.11it/s]

| epoch   4 |  5000/17123 batches | perplexity    5.572 | loss    1.718 


5508it [02:18, 41.28it/s]

| epoch   4 |  5500/17123 batches | perplexity    5.555 | loss    1.715 


6008it [02:30, 40.66it/s]

| epoch   4 |  6000/17123 batches | perplexity    5.564 | loss    1.716 


6504it [02:42, 36.12it/s]

| epoch   4 |  6500/17123 batches | perplexity    5.695 | loss    1.740 


7008it [02:55, 39.38it/s]

| epoch   4 |  7000/17123 batches | perplexity    5.672 | loss    1.735 


7507it [03:08, 39.59it/s]

| epoch   4 |  7500/17123 batches | perplexity    5.664 | loss    1.734 


8006it [03:20, 41.38it/s]

| epoch   4 |  8000/17123 batches | perplexity    5.552 | loss    1.714 


8508it [03:33, 40.32it/s]

| epoch   4 |  8500/17123 batches | perplexity    5.736 | loss    1.747 


9004it [03:45, 38.57it/s]

| epoch   4 |  9000/17123 batches | perplexity    5.615 | loss    1.726 


9505it [03:58, 40.86it/s]

| epoch   4 |  9500/17123 batches | perplexity    5.643 | loss    1.730 


10009it [04:10, 39.82it/s]

| epoch   4 | 10000/17123 batches | perplexity    5.497 | loss    1.704 


10507it [04:23, 39.33it/s]

| epoch   4 | 10500/17123 batches | perplexity    5.723 | loss    1.744 


11008it [04:35, 38.87it/s]

| epoch   4 | 11000/17123 batches | perplexity    5.642 | loss    1.730 


11507it [04:48, 41.64it/s]

| epoch   4 | 11500/17123 batches | perplexity    5.556 | loss    1.715 


12006it [05:01, 39.17it/s]

| epoch   4 | 12000/17123 batches | perplexity    5.642 | loss    1.730 


12507it [05:13, 41.10it/s]

| epoch   4 | 12500/17123 batches | perplexity    5.566 | loss    1.717 


13007it [05:26, 40.16it/s]

| epoch   4 | 13000/17123 batches | perplexity    5.657 | loss    1.733 


13505it [05:38, 40.12it/s]

| epoch   4 | 13500/17123 batches | perplexity    5.536 | loss    1.711 


14008it [05:51, 40.89it/s]

| epoch   4 | 14000/17123 batches | perplexity    5.773 | loss    1.753 


14505it [06:03, 41.19it/s]

| epoch   4 | 14500/17123 batches | perplexity    5.809 | loss    1.759 


15008it [06:16, 39.05it/s]

| epoch   4 | 15000/17123 batches | perplexity    5.570 | loss    1.717 


15508it [06:28, 39.57it/s]

| epoch   4 | 15500/17123 batches | perplexity    5.420 | loss    1.690 


16005it [06:41, 41.43it/s]

| epoch   4 | 16000/17123 batches | perplexity    5.578 | loss    1.719 


16508it [06:53, 38.77it/s]

| epoch   4 | 16500/17123 batches | perplexity    5.577 | loss    1.719 


17008it [07:06, 41.39it/s]

| epoch   4 | 17000/17123 batches | perplexity    5.691 | loss    1.739 


17123it [07:09, 39.89it/s]


-----------------------------------------------------------
| end of epoch   4 | time: 431.72s | valid perplexity    5.667 | valid loss    1.735
-----------------------------------------------------------


507it [00:12, 41.24it/s]

| epoch   5 |   500/17123 batches | perplexity    5.638 | loss    1.730 


1005it [00:24, 41.13it/s]

| epoch   5 |  1000/17123 batches | perplexity    5.447 | loss    1.695 


1506it [00:37, 41.43it/s]

| epoch   5 |  1500/17123 batches | perplexity    5.491 | loss    1.703 


2008it [00:50, 37.44it/s]

| epoch   5 |  2000/17123 batches | perplexity    5.610 | loss    1.724 


2509it [01:02, 41.39it/s]

| epoch   5 |  2500/17123 batches | perplexity    5.575 | loss    1.718 


3008it [01:15, 40.30it/s]

| epoch   5 |  3000/17123 batches | perplexity    5.539 | loss    1.712 


3508it [01:27, 40.86it/s]

| epoch   5 |  3500/17123 batches | perplexity    5.560 | loss    1.716 


4007it [01:39, 41.96it/s]

| epoch   5 |  4000/17123 batches | perplexity    5.855 | loss    1.767 


4507it [01:52, 41.71it/s]

| epoch   5 |  4500/17123 batches | perplexity    5.739 | loss    1.747 


5007it [02:04, 40.61it/s]

| epoch   5 |  5000/17123 batches | perplexity    5.682 | loss    1.737 


5505it [02:16, 41.38it/s]

| epoch   5 |  5500/17123 batches | perplexity    5.761 | loss    1.751 


6005it [02:28, 41.16it/s]

| epoch   5 |  6000/17123 batches | perplexity    5.556 | loss    1.715 


6507it [02:41, 41.58it/s]

| epoch   5 |  6500/17123 batches | perplexity    5.744 | loss    1.748 


7009it [02:53, 42.06it/s]

| epoch   5 |  7000/17123 batches | perplexity    5.628 | loss    1.728 


7507it [03:05, 40.52it/s]

| epoch   5 |  7500/17123 batches | perplexity    5.720 | loss    1.744 


8002it [03:18, 24.68it/s]

| epoch   5 |  8000/17123 batches | perplexity    5.643 | loss    1.730 


8507it [03:31, 38.45it/s]

| epoch   5 |  8500/17123 batches | perplexity    5.640 | loss    1.730 


9006it [03:43, 40.80it/s]

| epoch   5 |  9000/17123 batches | perplexity    5.675 | loss    1.736 


9508it [03:56, 39.33it/s]

| epoch   5 |  9500/17123 batches | perplexity    5.524 | loss    1.709 


10008it [04:08, 42.96it/s]

| epoch   5 | 10000/17123 batches | perplexity    5.602 | loss    1.723 


10506it [04:20, 39.66it/s]

| epoch   5 | 10500/17123 batches | perplexity    5.719 | loss    1.744 


11005it [04:34, 37.75it/s]

| epoch   5 | 11000/17123 batches | perplexity    5.605 | loss    1.724 


11509it [04:47, 39.43it/s]

| epoch   5 | 11500/17123 batches | perplexity    5.576 | loss    1.718 


12005it [04:59, 40.12it/s]

| epoch   5 | 12000/17123 batches | perplexity    5.661 | loss    1.734 


12508it [05:11, 41.58it/s]

| epoch   5 | 12500/17123 batches | perplexity    5.547 | loss    1.713 


13009it [05:24, 42.47it/s]

| epoch   5 | 13000/17123 batches | perplexity    5.602 | loss    1.723 


13504it [05:36, 39.13it/s]

| epoch   5 | 13500/17123 batches | perplexity    5.615 | loss    1.725 


14005it [05:48, 41.39it/s]

| epoch   5 | 14000/17123 batches | perplexity    5.594 | loss    1.722 


14509it [06:01, 41.72it/s]

| epoch   5 | 14500/17123 batches | perplexity    5.587 | loss    1.720 


15007it [06:13, 40.87it/s]

| epoch   5 | 15000/17123 batches | perplexity    5.615 | loss    1.725 


15506it [06:25, 40.40it/s]

| epoch   5 | 15500/17123 batches | perplexity    5.448 | loss    1.695 


16007it [06:38, 39.32it/s]

| epoch   5 | 16000/17123 batches | perplexity    5.723 | loss    1.745 


16506it [06:50, 40.61it/s]

| epoch   5 | 16500/17123 batches | perplexity    5.615 | loss    1.725 


17006it [07:02, 41.90it/s]

| epoch   5 | 17000/17123 batches | perplexity    5.668 | loss    1.735 


17123it [07:05, 40.24it/s]


-----------------------------------------------------------
| end of epoch   5 | time: 428.17s | valid perplexity    5.667 | valid loss    1.735
-----------------------------------------------------------
Checking the results of test dataset.
test perplexity    5.712 | test loss    1.742 


Hint: For the above, you should see your loss around 2.0 and going down. Similarly to perplexity which should be aroud 7 to 8.

## Generate some text.

In [21]:
def generate_word(model, dataset, window):
    generated_word = []
    # Set the context to a window-1 length array having just the MARKER character's token_id.
    context = [0] * (n-1)
    
    while True:
        logits = model(torch.tensor(context).view(1, -1))
        
        # Get the probabilities from the logits.
        # Hint: softmax!
        softmax = torch.nn.Softmax()
        probs = softmax(logits)
        
        # Get 1 sample from a multinomial having the above probabilities.
        token_id = torch.multinomial(probs, 1, replacement=True).item()
        
        # Append the token_id to the generated word.
        generated_word.append(token_id)
        
        # Move the context over 1, drop the first (oldest) token and apped the new one above.
        # The size of the resulting context should be the same.
        # For exaple, if it was "[0, 1, 2]" and you generated 4, it should now be [1, 2, 4].
        context = context[1:] + [token_id]
        
        if token_id == 0:
            # If you generate token_id = 0, i.e. '.', break out.
            break
    # Return and decode the generated word to a string.        
    return ''.join(dataset.decode(generated_word))

In [22]:
torch.manual_seed(1)
for _ in range(50):
    print(generate_word(model, train_dataset, n))

aka.
kamilih.
nacdon.
nuzorla.
zon.
jovrayah.
calie.
rajandhorandu.
lydah.
jandelaigh.
cho.
dumiriye.
zelise.
truya.
mereiyanny.


  probs = softmax(logits)


beya.
kinza.
kambrisy.
mayah.
boria.
dalleya.
pellia.
dhi.
amar.
ymandilistiee.
adanf.
damekinzaiy.
briel.
rirkanie.
just.
esa.
yra.
sum.
azon.
jerison.
killa.
ell.
est.
sen.
smande.
malanna.
dakerrishayn.
rosy.
rynne.
mayshantm.
caelyliyana.
maen.
asialarynn.
jaiton.
rhetly.
