The goal of a language model is to estimate the joint probability of a token sequence:  
$ P(x_1,x_2,...,x_T) $  
Such a model can generate text by drawing one token at a time: $ x_t \sim P(x_t | x_{t-1},...,x_1) $  

In [2]:
import torch
from d2l import torch as d2l

The joint probability factorizes by the chain rule of probability:  
$ P(x_1, x_2, ..., x_T) = \prod_{t=1}^T P(x_t | x_1, ..., x_{t-1}) $

To simplify, we use Markov models, which condition only on the most recent tokens.  
The higher the order, the more of the past we have to look at to predict the future. For a first-order model:  
$ P(x_{t+1} | x_t,...,x_1) = P(x_{t+1} | x_t) $
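
As a rough illustration of the first-order Markov assumption, the sketch below estimates $P(x_{t+1} | x_t)$ from bigram counts and scores a sequence by summing log conditionals. The names `fit_bigram`, `sequence_log_prob`, and `toy_corpus` are made up for this example, and no smoothing is applied, so an unseen bigram would raise a `KeyError`.

```python
from collections import Counter, defaultdict
import math

def fit_bigram(tokens):
    """Estimate P(next | current) from raw bigram counts (no smoothing)."""
    counts = defaultdict(Counter)
    for cur, nxt in zip(tokens[:-1], tokens[1:]):
        counts[cur][nxt] += 1
    return {cur: {nxt: c / sum(ctr.values()) for nxt, c in ctr.items()}
            for cur, ctr in counts.items()}

def sequence_log_prob(model, tokens):
    """log P(x_2, ..., x_T | x_1) under the first-order Markov assumption."""
    return sum(math.log(model[cur][nxt])
               for cur, nxt in zip(tokens[:-1], tokens[1:]))

toy_corpus = list("abababac")  # hypothetical toy corpus of character tokens
bigram = fit_bigram(toy_corpus)
print(bigram["a"])             # conditional distribution over tokens following 'a'
print(sequence_log_prob(bigram, list("abab")))
```

A higher-order model would key the counts on the last two or more tokens instead of only the current one, at the cost of needing far more data to estimate each conditional.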

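As an evaluation metric, the raw likelihood is a poor choice because it depends strongly on the length of the sequence: longer sequences are assigned much smaller probabilities.  
That is why we use the cross-entropy averaged over the tokens of a sequence, or **perplexity**, which is exp(cross-entropy).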
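
A small sketch of why the per-token average helps, assuming we already have the probabilities a model assigned to each correct next token. The function name `perplexity` and the probability lists are hypothetical.

```python
import math

def perplexity(token_probs):
    """exp of the average negative log-likelihood per token."""
    avg_nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(avg_nll)

# Hypothetical probabilities assigned to the true next tokens.
short_seq = [0.5] * 2
long_seq = [0.5] * 20

# The raw likelihood shrinks as the sequence gets longer...
print(math.prod(short_seq), math.prod(long_seq))
# ...while perplexity stays at about 2 for both sequences.
print(perplexity(short_seq), perplexity(long_seq))
```

A perplexity of about 2 here can be read as the model being, on average, as uncertain as if it were choosing uniformly between two tokens at each step.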

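Training: 
* The dataset is a token sequence of length T
* For randomness, at the start of every epoch discard the first d tokens of the sequence, where d is drawn uniformly at random from [0, n)
* Split the rest into subsequences of length n, which are then grouped into minibatches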
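
The steps above could be sketched in plain Python as follows. The function `partition` and its `rng` parameter are hypothetical names for this example. Note that the `TimeMachine` code in these notes takes a different route and builds every overlapping window instead of non-overlapping subsequences.

```python
import random

def partition(corpus, num_steps, rng=random):
    """Drop d random leading tokens, then cut into (input, target) pairs."""
    d = rng.randrange(num_steps)        # random offset d in [0, num_steps)
    corpus = corpus[d:]
    # Subtract 1 so the last input still has a shifted-by-one target.
    m = (len(corpus) - 1) // num_steps
    pairs = []
    for k in range(m):
        start = k * num_steps
        x = corpus[start:start + num_steps]
        y = corpus[start + 1:start + num_steps + 1]
        pairs.append((x, y))
    return pairs

# Toy "corpus" of 35 token indices, subsequences of length 5.
for x, y in partition(list(range(35)), num_steps=5, rng=random.Random(0)):
    print("X:", x, "Y:", y)
```

Because the offset is redrawn each epoch, the subsequence boundaries fall at different positions from one epoch to the next.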

**The goal**  
For language modeling, the goal is to predict the next token based on the tokens we have seen so far; hence the targets (labels) are the original sequence, shifted by one token. For an input subsequence of length $n$ starting at token $x_t$, the target is the subsequence of the same length $n$ starting at $x_{t+1}$.

In [3]:
@d2l.add_to_class(d2l.TimeMachine)  #@save
def __init__(self, batch_size, num_steps, num_train=10000, num_val=5000):
    super(d2l.TimeMachine, self).__init__()
    self.save_hyperparameters()
    corpus, self.vocab = self.build(self._download())
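    # Every window of num_steps + 1 consecutive token indices
    array = torch.tensor([corpus[i:i+num_steps+1]
                        for i in range(len(corpus)-num_steps)])
    # Inputs drop the last token; targets are the inputs shifted by one
    self.X, self.Y = array[:,:-1], array[:,1:]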

In [4]:
@d2l.add_to_class(d2l.TimeMachine)  #@save
def get_dataloader(self, train):
    # First num_train windows for training, the next num_val for validation
    idx = slice(0, self.num_train) if train else slice(
        self.num_train, self.num_train + self.num_val)
    return self.get_tensorloader([self.X, self.Y], train, idx)

In [5]:
data = d2l.TimeMachine(batch_size=2, num_steps=10)
for X, Y in data.train_dataloader():
    print('X:', X, '\nY:', Y)
    break


X: tensor([[ 6,  0,  8,  6, 19, 14,  0, 16,  7,  0],
        [ 0, 17,  6, 19, 20, 16, 15,  0, 24, 10]]) 
Y: tensor([[ 0,  8,  6, 19, 14,  0, 16,  7,  0, 14],
        [17,  6, 19, 20, 16, 15,  0, 24, 10, 21]])
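
As a quick sanity check on the printed minibatch, the sketch below copies the values into plain lists and confirms that each row of `Y` is the matching row of `X` shifted by one token. The helper `is_shifted_by_one` is a hypothetical name for this example.

```python
# Values copied from the printed minibatch above.
X = [[6, 0, 8, 6, 19, 14, 0, 16, 7, 0],
     [0, 17, 6, 19, 20, 16, 15, 0, 24, 10]]
Y = [[0, 8, 6, 19, 14, 0, 16, 7, 0, 14],
     [17, 6, 19, 20, 16, 15, 0, 24, 10, 21]]

def is_shifted_by_one(x_row, y_row):
    """True if y_row equals x_row shifted left by one position."""
    return x_row[1:] == y_row[:-1]

print(all(is_shifted_by_one(x, y) for x, y in zip(X, Y)))  # prints True
```

Only the last element of each target row is new information; the rest overlaps with the input.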
