In [1]:
import os
import requests
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torch import nn

# Get data

In [2]:
# Load the text used by the rest of the notebook into `les_miserables`.
# A local cache file is used when present; otherwise the five Project Gutenberg
# tomes listed below are downloaded, concatenated and cached for the next run.
if os.path.exists('./les_miserables.txt'):
    # 'utf-8-sig' decodes UTF-8 and drops a leading byte-order mark if there is one.
    with open('./les_miserables.txt', 'r', encoding='utf-8-sig') as file:
        les_miserables = file.read()
else:
    url_tomes = [
        'https://www.gutenberg.org/ebooks/17489.txt.utf-8',
        'https://www.gutenberg.org/ebooks/17493.txt.utf-8',
        'https://www.gutenberg.org/ebooks/17494.txt.utf-8',
        'https://www.gutenberg.org/ebooks/17518.txt.utf-8',
        'https://www.gutenberg.org/ebooks/17519.txt.utf-8'
        ]
    tomes = []
    for url in url_tomes:
        # A timeout keeps the cell from hanging forever on a stalled connection.
        response = requests.get(url, timeout=60)
        # Fail loudly instead of caching an HTTP error page as training text.
        response.raise_for_status()
        response.encoding = 'utf-8-sig'
        tomes.append(response.text.replace('\r\n', ' '))
    les_miserables = ''.join(tomes)
    # Write with an explicit encoding: the cache is read back as UTF-8 above, so
    # relying on the platform default encoding could make the next run fail.
    with open('./les_miserables.txt', 'w', encoding='utf-8') as file:
        file.write(les_miserables)

print(les_miserables[10000:10500])

nt monseigneur Bienvenu   Le palais épiscopal de Digne était attenant à l'hôpital.  Le palais épiscopal était un vaste et bel hôtel bâti en pierre au commencement du siècle dernier par monseigneur Henri Puget, docteur en théologie de la faculté de Paris, abbé de Simore, lequel était évêque de Digne en 1712. Ce palais était un vrai logis seigneurial. Tout y avait grand air, les appartements de l'évêque, les salons, les chambres, la cour d'honneur, fort large, avec promenoirs à arcades, selon l'an


# Build tokenizers

In [3]:
class CharacterLevelTokenizer:
    """Tokenizer that maps every distinct character of a text to an integer id.

    Ids are the positions of the characters in the sorted list of distinct
    characters found in the training text.
    """

    def __init__(self, train_text):
        """Build the vocabulary from ``train_text``."""
        self.train(train_text)

    def train(self, train_text):
        """(Re)build the vocabulary and both lookup tables from ``train_text``."""
        alphabet = list(set(train_text))
        alphabet.sort()
        self.characters = alphabet
        self.vocab_size = len(alphabet)
        self.char_to_int = {}
        self.int_to_char = {}
        for position, character in enumerate(alphabet):
            self.char_to_int[character] = position
            self.int_to_char[position] = character

    def encode(self, text):
        """Return the list of ids for ``text``.

        Raises:
            KeyError: if ``text`` contains a character absent from the vocabulary.
        """
        ids = []
        for character in text:
            ids.append(self.char_to_int[character])
        return ids

    def decode(self, text):
        """Return the string spelled by the iterable of ids ``text``.

        Raises:
            KeyError: if an id is not part of the vocabulary.
        """
        pieces = []
        for token_id in text:
            pieces.append(self.int_to_char[token_id])
        return ''.join(pieces)


class BPETokenizer:
    """Byte-level Byte Pair Encoding tokenizer.

    Text is first turned into its UTF-8 bytes (ids 0-255). Training then
    repeatedly merges the most frequent pair of adjacent ids into a new id,
    starting at 256, until ``vocab_size`` ids exist or nothing is left to merge.

    Attributes:
        merges: dict mapping a pair of ids to the id that replaces it, in the
            order the merges were learned.
        vocab: dict mapping every id to the bytes it stands for.
        vocab_size: the vocabulary size requested at construction. Fewer ids
            may actually be in use if training ran out of pairs to merge.
    """

    def __init__(self, train_text, vocab_size):
        """Learn the merges from ``train_text`` for a vocabulary of ``vocab_size`` ids."""
        self.train(train_text, vocab_size)
        self.vocab_size = vocab_size

    def _replace_pair(self, tokens, pair, idx):
        """Return a copy of ``tokens`` where each occurrence of ``pair`` becomes ``idx``.

        The scan goes left to right and occurrences do not overlap.
        """
        tokens_bpe = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i+1] == pair[1]:
                tokens_bpe.append(idx)
                i += 2
            else:
                tokens_bpe.append(tokens[i])
                i += 1
        return tokens_bpe

    def train(self, train_text, vocab_size):
        """Learn up to ``vocab_size - 256`` merges from ``train_text``.

        Resets ``self.merges`` and ``self.vocab``. Training stops early when
        fewer than two tokens remain, because there is no pair left to merge.
        """
        tokens = list(train_text.encode('utf-8'))
        self.merges = {}
        self.vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
        for idx in range(256, vocab_size):
            # count every pair of adjacent tokens
            pair_counts = {}
            for pair in zip(tokens, tokens[1:]):
                pair_counts[pair] = pair_counts.get(pair, 0) + 1
            if not pair_counts:
                # empty text, or everything already merged into a single token:
                # max() on an empty dict would raise ValueError
                break
            # ties go to the pair that was seen first
            max_pair = max(pair_counts, key=pair_counts.get)

            # save results
            self.merges[max_pair] = idx
            self.vocab[idx] = self.vocab[max_pair[0]] + self.vocab[max_pair[1]]

            # replace most frequent pair by new idx, others stay the same
            tokens = self._replace_pair(tokens, max_pair, idx)

    def encode(self, text):
        """Return the list of ids for ``text``.

        The learned merges are applied in the order they were learned, which
        is the insertion order of ``self.merges``.
        """
        tokens = list(text.encode('utf-8'))
        for pair, idx in self.merges.items():
            tokens = self._replace_pair(tokens, pair, idx)
        return tokens

    def decode(self, tokens):
        """Return the text for an iterable of ids.

        Byte sequences that are not valid UTF-8 (possible when decoding
        arbitrary ids) are rendered with the Unicode replacement character
        instead of raising.

        Raises:
            KeyError: if an id is not in ``self.vocab``.
        """
        text_bytes = b''.join([self.vocab[t] for t in tokens])
        text = text_bytes.decode("utf-8", errors="replace")
        return text


def _show_round_trip(tok, text):
    """Print ``text``, its token ids, and the text obtained by decoding those ids."""
    print('Example text:', text)
    token_ids = tok.encode(text)
    print('Encoded:', token_ids)
    print('Encoded+decoded:', tok.decode(token_ids))


# Character-level tokenizer, demonstrated on its own full alphabet.
print('Character level tokenizer')
tokenizer = CharacterLevelTokenizer(les_miserables)
example_text = ''.join(tokenizer.characters)
_show_round_trip(tokenizer, example_text)

# BPE tokenizer trained on the first 100000 characters only, demonstrated on a
# 100-character excerpt taken from much further into the text.
print('\nByte Pair Encoding tokenizer')
tokenizer = BPETokenizer(les_miserables[:100000], vocab_size=1000)
example_text = les_miserables[2000000:2000100]
_show_round_trip(tokenizer, example_text)

Character level tokenizer
Example text:  !"#$%'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxyz«°º»ÀÂÇÈÉÊÔàâæçèéêëîïñôöùûü—‘’“”•™
Encoded: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116]
Encoded+decoded:  !"#$%'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxyz«°º»ÀÂÇÈÉÊÔàâæçèéêëîïñôöùûü—‘’“”•™

Byte Pair Encoding tokenizer
Example text: ns çà et là dans l'azur pâle et profond, la terre toute noire, le ciel tout blanc, un frisson dans l
Encoded: [110, 257, 529, 291, 283, 651, 821, 97, 122, 117, 271, 112, 481, 275

# Build dataset

In [4]:
class LesMiserablesDataset(Dataset):
    """Next-token prediction dataset over a tokenized text.

    Item ``i`` is the pair ``(context, target)``, where ``context`` holds the
    ``seq_len`` tokens starting at position ``i`` and ``target`` is the same
    window shifted one token to the right.

    Args:
        str_data: the raw text.
        seq_len: number of tokens in each context/target window.
        tokenizer: object whose ``encode(str)`` returns a list of integer ids.
    """

    def __init__(self, str_data, seq_len, tokenizer):
        # Explicit integer dtype: an empty encoding would otherwise yield a float tensor.
        self.data = torch.tensor(tokenizer.encode(str_data), dtype=torch.long)
        self.seq_len = seq_len

    def __len__(self):
        # Each item needs seq_len + 1 tokens. Clamp at 0 because len() raises
        # on a negative value when the text is shorter than one window.
        return max(0, len(self.data) - self.seq_len)

    def __getitem__(self, idx):
        """Return the ``(context, target)`` pair at ``idx``.

        Negative indices count from the end.

        Raises:
            IndexError: if ``idx`` is out of range. Plain iteration over the
                dataset relies on this to stop, and it prevents returning
                windows that were silently truncated by slicing.
        """
        length = len(self)
        if idx < 0:
            idx += length
        if not 0 <= idx < length:
            raise IndexError('dataset index out of range')
        context = self.data[idx:idx+self.seq_len]
        target = self.data[idx+1:idx+self.seq_len+1]
        return context, target

tokenizer = CharacterLevelTokenizer(les_miserables)
example_dataset = LesMiserablesDataset(les_miserables[:10000], 10, tokenizer)

# Show one (context, target) pair decoded back to text.
for label, token_ids in zip(('x:', 'y:'), example_dataset[1000]):
    print(label, [tokenizer.decode(token_ids.tolist())])

x: ['Monsieur M']
y: ['onsieur My']


# Build model

In [5]:
class MaskedMultiHeadAttention(nn.Module):
    """Multi-head self-attention in which a position only attends to itself and earlier positions.

    Args:
        seq_len: longest sequence the stored mask supports.
        embed_size: size of the input and output feature vectors.
        nb_heads: number of attention heads.
        head_size: feature size of each head.
    """

    def __init__(self, seq_len, embed_size, nb_heads, head_size):
        super().__init__()
        self.nb_heads = nb_heads
        self.head_size = head_size
        inner_size = nb_heads * head_size
        self.query = nn.Linear(embed_size, inner_size, bias=False)
        self.key = nn.Linear(embed_size, inner_size, bias=False)
        self.value = nn.Linear(embed_size, inner_size, bias=False)
        self.projection = nn.Linear(inner_size, embed_size)
        # True strictly above the diagonal, i.e. on the positions a query may not see.
        lower_triangle = torch.tril(torch.ones(seq_len, seq_len))
        self.register_buffer('mask', lower_triangle == 0)

    def _split_heads(self, projected):
        """Reshape (batch, seq, heads*head_size) into (batch, heads, seq, head_size)."""
        batch_size, seq_len, _ = projected.shape
        per_head = projected.view(batch_size, seq_len, self.nb_heads, self.head_size)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Apply masked self-attention to ``x`` (batch_size x seq_len x embed_size); the output has the same shape."""
        batch_size, seq_len, _ = x.shape
        q = self._split_heads(self.query(x))
        k = self._split_heads(self.key(x))
        v = self._split_heads(self.value(x))

        # scaled dot-product scores: batch_size x nb_heads x seq_len x seq_len
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.head_size**0.5
        # masked positions get -inf so that softmax gives them zero weight
        scores = scores.masked_fill(self.mask[:seq_len, :seq_len], float('-inf'))
        weights = F.softmax(scores, dim=-1)
        mixed = torch.matmul(weights, v) # batch_size x nb_heads x seq_len x head_size

        # put the heads back side by side, then project to embed_size
        merged = mixed.transpose(1, 2).reshape(batch_size, seq_len, self.nb_heads * self.head_size)
        return self.projection(merged)


class Block(nn.Module):
    """One transformer block: masked self-attention, then a two-layer feed-forward network.

    Each sub-layer is wrapped in a residual connection followed by layer
    normalization.
    """

    def __init__(self, seq_len, embed_size, nb_heads, head_size):
        super().__init__()
        self.masked_multi_head_attention = MaskedMultiHeadAttention(seq_len, embed_size, nb_heads, head_size)
        hidden_size = 4 * embed_size
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )
        # NOTE(review): this single LayerNorm (one set of affine parameters) is
        # applied after both sub-layers in forward(); confirm the sharing is intended.
        self.layer_norm = nn.LayerNorm(embed_size)

    def forward(self, x):
        """Transform ``x``; the output has the same shape as the input."""
        attended = self.layer_norm(x + self.masked_multi_head_attention(x))
        return self.layer_norm(attended + self.feed_forward(attended))
        
        
class LesMiserablesLanguageModel(nn.Module):
    """Decoder-only transformer language model over the tokens of ``tokenizer``.

    Args:
        tokenizer: object exposing ``vocab_size``, ``encode(str)`` and
            ``decode(list of ids)``.
        seq_len: maximum context length, in tokens.
        embed_size: size of the token and position embeddings.
        nb_heads: number of attention heads in each block.
        head_size: feature size of each attention head.
        n_blocks: number of stacked ``Block`` modules.
    """

    def __init__(self, tokenizer, seq_len, embed_size, nb_heads, head_size, n_blocks):
        super().__init__()
        self.tokenizer = tokenizer
        self.token_embedding = nn.Embedding(tokenizer.vocab_size, embed_size)
        self.position_embedding = nn.Embedding(seq_len, embed_size)
        self.blocks = nn.Sequential(*[Block(seq_len, embed_size, nb_heads, head_size) for _ in range(n_blocks)])
        self.linear = nn.Linear(embed_size, tokenizer.vocab_size)
        self.seq_len = seq_len
        # Position ids kept as a buffer so they follow the module across devices.
        self.register_buffer('seq_arange', torch.arange(seq_len))

    def forward(self, x):
        """Return next-token logits (batch_size x seq_len x vocab_size) for token ids ``x`` (batch_size x seq_len).

        The sequence length of ``x`` must not exceed the ``seq_len`` given at
        construction.
        """
        seq_len = x.size(1)
        x = self.token_embedding(x) + self.position_embedding(self.seq_arange[:seq_len]) # batch_size x seq_len x embedding_dim
        x = self.blocks(x) # batch_size x seq_len x embedding_dim
        x = self.linear(x) # batch_size x seq_len x vocab_size
        return x

    @torch.no_grad()  # sampling never backpropagates; do not build an autograd graph per token
    def generate(self, x, nb_tokens):
        """Continue the prompt ``x`` with ``nb_tokens`` sampled tokens and return the full text.

        Args:
            x: non-empty prompt string.
            nb_tokens: number of tokens to sample after the prompt.

        Returns:
            The decoded prompt followed by the generated continuation.

        Raises:
            ValueError: if the prompt encodes to no token at all.
        """
        assert isinstance(x, str)
        prompt_ids = self.tokenizer.encode(x)
        if len(prompt_ids) == 0:
            # An empty id list would become a float tensor and fail inside the embedding.
            raise ValueError('generate() needs a non-empty prompt')
        x = torch.tensor(prompt_ids, dtype=torch.long, device=self.seq_arange.device).unsqueeze(0)
        for _ in range(nb_tokens):
            # the model only sees the last seq_len tokens
            logits = self(x[:, -self.seq_len:])
            logits = logits[:, -1, :]
            probas = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probas, 1)
            x = torch.cat((x, next_token), dim=1)
        txt = self.tokenizer.decode(x.tolist()[0])
        return txt

# Train

In [6]:
def train(log_every=2000):
    """Run exactly ``nb_iters`` optimizer steps on ``train_dataloader``.

    Reads the notebook globals ``model``, ``optimizer``, ``train_dataloader``,
    ``nb_iters`` and ``device``, and calls ``evaluate()`` for the periodic
    report. The dataloader is restarted as many times as needed, so training
    does not stop at the end of an epoch.

    Args:
        log_every: a report (losses from ``evaluate()`` plus a short generated
            sample) is printed after every step whose 0-based index is a
            multiple of this value.

    Raises:
        RuntimeError: if ``train_dataloader`` yields no batch, which would
            otherwise make this function loop forever.
    """
    itr = 0
    while itr < nb_iters: # otherwise dataloader ends the training when all dataset iterated
        steps_before_epoch = itr
        for x, y in train_dataloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            # Flatten batch and sequence dimensions; the class dimension is read from the logits.
            loss = F.cross_entropy(y_pred.view(-1, y_pred.size(-1)), y.view(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if itr % log_every == 0:
                losses = evaluate()
                print(f"iter {itr}: train loss = {losses['train'].item():.3f}, test loss = {losses['test'].item():.3f}, generated text = {model.generate('Jean', 30)}")
            itr += 1
            # Check after counting the step so that no extra step is taken.
            if itr >= nb_iters:
                break
        if itr == steps_before_epoch:
            raise RuntimeError('train_dataloader produced no batches')
            


@torch.no_grad()
def evaluate():
    """Estimate the mean loss on the train and test dataloaders.

    Reads the notebook globals ``model``, ``train_dataloader``,
    ``test_dataloader``, ``eval_iters`` and ``device``. At most ``eval_iters``
    batches are drawn from each dataloader, and the mean is taken over the
    batches actually evaluated.

    Returns:
        dict with keys ``'train'`` and ``'test'``, each a 0-dim tensor holding
        the mean loss (NaN when no batch was evaluated).
    """
    from itertools import islice

    was_training = model.training
    model.eval()
    try:
        losses = {}
        for split, dataloader in zip(['train', 'test'], [train_dataloader, test_dataloader]):
            batch_losses = []
            # islice stops after eval_iters batches without fetching an extra one.
            for x, y in islice(dataloader, eval_iters):
                x, y = x.to(device), y.to(device)
                y_pred = model(x)
                loss = F.cross_entropy(y_pred.view(-1, y_pred.size(-1)), y.view(-1))
                batch_losses.append(loss.item())
            # Average only what was measured: a dataloader with fewer than
            # eval_iters batches must not drag the mean towards zero.
            losses[split] = torch.tensor(batch_losses).mean()
    finally:
        # Restore the mode the model was in, even if evaluation failed.
        model.train(was_training)
    return losses

In [7]:
# Hyper-parameters and run configuration (read as globals by train() and evaluate()).
seq_len = 64
vocab_size = 1000
batch_size = 256
nb_iters = 20000
eval_iters = 10
seed = 0
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed PyTorch before anything random happens (weight initialization, dataloader
# shuffling, sampling in generate) so that a re-run is repeatable.
torch.manual_seed(seed)

# tokenizer = CharacterLevelTokenizer(train_text=les_miserables)
tokenizer = BPETokenizer(train_text=les_miserables[:100000], vocab_size=vocab_size)

# 90% / 10% train/test split, cut on characters of the raw text
split_idx = int(0.9*len(les_miserables))
train_dataset = LesMiserablesDataset(les_miserables[:split_idx], seq_len, tokenizer)
test_dataset = LesMiserablesDataset(les_miserables[split_idx:], seq_len, tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

model = LesMiserablesLanguageModel(tokenizer, seq_len, embed_size=64, nb_heads=4, head_size=16, n_blocks=6).to(device)
optimizer = torch.optim.AdamW(model.parameters())

train()

iter 0: train loss = 6.960, test loss = 6.958, generated text = Jeanademoisdu Kmaiferementmane toute anlaissation �kaissmarest ez-livres_ _�	our �opi�bre ��
iter 2000: train loss = 3.289, test loss = 3.294, generated text = Jean Valjean, par le fait de la vierge inthe Jonffle jusqu'avait une gouche 
iter 4000: train loss = 3.045, test loss = 3.086, generated text = Jean Valjean, auprès d'y seulement trois cent mille trois hommes.  Aucun crime.     Chapitre X
iter 6000: train loss = 2.954, test loss = 3.023, generated text = Jean Valjean se comprit.  Il recommenait Muphant. Toute de Marius et D
iter 8000: train loss = 2.882, test loss = 3.001, generated text = Jean Valjean tournait main. C'était la longue figure sortait qu'elle s'abaissa 
iter 10000: train loss = 2.862, test loss = 2.963, generated text = Jean Valjean chez lui, avec une sortie qui perverne à aimer le désinsi st
iter 12000: train loss = 2.828, test loss = 2.941, generated text = Jean Valjean alors, artillé se retenant d