# Dataset for a Language Model

The core goal of a language model is to predict next best word. So, here's how we can prepare a dataset for that.

Let's imagine this is our text

```
We usually eat ice-cream in a sunny day
```

So, let's assumue our language model as `F`. Then it should give us values like this:

```
F(We) => usually
F(We usually) => eat
F(We usually eat) => ice-cream
F(We usually eat ice-cream) => in
F(We usually eat ice-cream in) => a
F(We usually eat ice-cream in a) => sunny
F(We usually eat ice-cream in a) => day
```

So, we need to prepare a data set like this:

* Input: `[We usually eat]`
* Output: `[usually eat ice-cream]`

Instead of words, now we have numbers.

## Tokenizer

In [1]:
!pip install -q d2l==1.0.0-alpha1.post0

[0m

In [2]:
import torch
from d2l import torch as d2l
import re
from matplotlib import pyplot as plt
import numpy as np

In [3]:
class TimeMachine(d2l.DataModule): #@save
    def _download(self):
        fname = d2l.download(d2l.DATA_URL + 'timemachine.txt', self.root,
                             '090b5e7e70c295757f55df93cb0a180b9691891a')
        with open(fname) as f:
            return f.read()

data = TimeMachine()
raw_text = data._download()

In [70]:
class Vocab:
    def __init__(self):
        self.tokens = []
        self.token_to_id = {}
        self.token_freq = {}
        
        self._process_word("<unk>")
        
    def to_id(self, word):
        if word in self.token_to_id:
            return self.token_to_id[word]
        else:
            return self.token_to_id["<unk>"]
        
    def to_token(self, id):
        if id < len(self.tokens):
            return self.tokens[id]
        else:
            return self.tokens[0]
    
    def _process_word(self, word):
        idx = 0
        if word in self.token_to_id:
            idx = self.token_to_id[word]
            self.token_freq[idx] += 1;
        else:
            self.tokens.append(word)
            idx = len(self.tokens) - 1
            self.token_to_id[word] = idx
            self.token_freq[idx] = 1
            
        return idx
            
    def tokenize(self, text):
        cleaned = re.sub('[^A-Za-z]+', ' ', text).lower().strip()
        return cleaned.split(" ")
        
    def build(self, text):
        corpus = [self._process_word(word) for word in self.tokenize(text)]
        self.token_freq = dict(sorted(self.token_freq.items(), key=lambda item: item[1], reverse=True))
        return corpus
    
    
vocab = Vocab()
corpus = vocab.build(raw_text)

## Dataset

In [65]:
class LanguageDataset(torch.utils.data.Dataset):
    def __init__(self, corpus, segment_len):
        self.corpus = corpus
        self.segment_len = segment_len
        self.total_segments = len(corpus) - segment_len - 1
    
        self.input_list = []
        self.output_list = []
        for i in range(self.total_segments):
            self.input_list.append(corpus[i: i + segment_len])
            self.output_list.append(corpus[i + 1: i + 1 + segment_len])
        
    def __len__(self):
        return self.total_segments
    
    def __getitem__(self, i):
        return (torch.FloatTensor(self.input_list[i]), torch.FloatTensor(self.output_list[i]))
            
    
dataset = LanguageDataset(corpus, 10)
len(dataset), dataset[-1]

(32764,
 (tensor([3.0000e+01, 2.2000e+01, 4.5710e+03, 2.3250e+03, 4.1500e+02, 1.2390e+03,
          1.1800e+02, 4.4000e+01, 1.0000e+00, 2.8760e+03]),
  tensor([2.2000e+01, 4.5710e+03, 2.3250e+03, 4.1500e+02, 1.2390e+03, 1.1800e+02,
          4.4000e+01, 1.0000e+00, 2.8760e+03, 1.8000e+01])))

In [69]:
batch = next(iter(torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)))
batch

[tensor([[8.0000e+00, 4.1200e+02, 3.9280e+03, 1.6000e+01, 4.2020e+03, 4.5240e+03,
          1.0000e+00, 4.5250e+03, 1.1400e+02, 1.0400e+02],
         [1.9840e+03, 3.0000e+01, 4.4000e+01, 2.6710e+03, 8.0000e+00, 1.0390e+03,
          3.3000e+02, 2.6720e+03, 3.3300e+02, 3.0000e+01],
         [3.7870e+03, 4.4000e+01, 2.2800e+03, 3.2580e+03, 8.0000e+00, 9.2700e+02,
          2.5700e+02, 1.2840e+03, 2.9980e+03, 1.6890e+03],
         [8.3000e+01, 8.4000e+01, 8.5000e+01, 8.6000e+01, 6.2000e+01, 3.0000e+01,
          8.7000e+01, 8.8000e+01, 2.6000e+01, 8.9000e+01]]),
 tensor([[4.1200e+02, 3.9280e+03, 1.6000e+01, 4.2020e+03, 4.5240e+03, 1.0000e+00,
          4.5250e+03, 1.1400e+02, 1.0400e+02, 8.0000e+00],
         [3.0000e+01, 4.4000e+01, 2.6710e+03, 8.0000e+00, 1.0390e+03, 3.3000e+02,
          2.6720e+03, 3.3300e+02, 3.0000e+01, 6.4000e+01],
         [4.4000e+01, 2.2800e+03, 3.2580e+03, 8.0000e+00, 9.2700e+02, 2.5700e+02,
          1.2840e+03, 2.9980e+03, 1.6890e+03, 3.7880e+03],
         [8