## Transformer From Scratch

### Transformer Decoder

#### Next Token Prediction Training

In [1]:
import math

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
import torch.nn.init as init
from tqdm import tqdm

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super().__init__()
        pe = torch.zeros((max_seq_length, d_model))
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

In [15]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "Vetor de embedding precisa ser divisivel pelo número de cabeças da camada de atenção!"
        self.head_dim = d_model // num_heads
        self.d_model, self.num_heads = d_model, num_heads
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, encoder_output=None):
        # Entra Q, K, V com dimensão (batch_size, sequence_length, d_model)
        # Reshape para (batch_size, num_heads, sequence_length, d_model)
        # Reordering para (batch_size, num_heads, sequence_length, d_model)
        if encoder_output is None:
            x = torch.reshape(x, shape=(x.shape[0], self.num_heads, x.shape[1], self.head_dim)) #.contiguous()
        else:
            raise NotImplementedError("Modelo ainda não compatível com Encoder.")
        return x

    def compute_attention_scores(self, q_linear_out, k_linear_out, v_linear_out, mask=None):
        qk_dot_product = torch.matmul(q_linear_out, k_linear_out.transpose(2, 3)) / self.head_dim ** 0.5

        if mask is not None:
            qk_dot_product = qk_dot_product.masked_fill(mask == 0, float('-inf'))

        attn_scores = nn.functional.softmax(qk_dot_product, dim=-1)
        attn_weighted_v = torch.matmul(attn_scores, v_linear_out)

        return attn_weighted_v


    def combine_heads(self, x):
        return torch.reshape(x, shape=(x.shape[0], x.shape[2], int(x.shape[1] * x.shape[3]))).contiguous()

    def forward(self, x, mask):
        q_linear_out = self.split_heads(self.q(x))
        k_linear_out = self.split_heads(self.k(x))
        v_linear_out = self.split_heads(self.v(x))
        
        attn_weighted_v = self.compute_attention_scores(q_linear_out, k_linear_out, v_linear_out, mask=mask)
        attn_weighted_v = self.combine_heads(attn_weighted_v)
        return self.output_linear(attn_weighted_v)

In [16]:
matrix_1 = torch.rand(1, 8, 512, 10)
matrix_2 = torch.rand(1, 8, 512, 10)
print(torch.matmul(matrix_1, matrix_2.transpose(-1, -2)).shape)

torch.Size([1, 8, 512, 512])


In [17]:
class FeedForwardSubLayer(nn.Module):
    def __init__(self, d_model, hidden_size):
        super().__init__()
        self.ff_1 = nn.Linear(d_model, hidden_size)
        self.ff_2 = nn.Linear(hidden_size, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.ff_2(self.relu(self.ff_1(x)))

In [18]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        self.feed_forward = FeedForwardSubLayer(d_model, hidden_size)
        self.mha = MultiHeadAttention(d_model, num_heads) # nn.MultiheadAttention()
        self.norm_1 = nn.LayerNorm(d_model)
        self.norm_2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, tgt_mask):
        x = self.norm_1(x + self.dropout(self.mha(x, tgt_mask)))
        x = self.norm_2(x + self.dropout(self.feed_forward(x)))
        return x

In [19]:
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, max_sequence_length, n_layers, hidden_size, num_heads, dropout=0.1):
        super(TransformerDecoder, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self.pe = PositionalEncoding(d_model, max_sequence_length)
        self.layers = nn.ModuleList(
            [DecoderBlock(d_model, hidden_size, num_heads, dropout) for _ in range(n_layers)]
        )
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, x, tgt_mask):
        x = self.embedding(x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x, tgt_mask)
        out = self.output_layer(x)
        return out

In [20]:
### 259 tokens possíveis:
# ASCII + <UNK> + <SOS> + <EOS> + <PAD>

class TokenizerChar:
    def __init__(self):
        self.chr_to_idx = {chr(v): v for v in range(1, 257)}
        self.chr_to_idx['<SOS>'] = 257
        self.chr_to_idx['<EOS>'] = 258
        self.chr_to_idx['<PAD>'] = 0
        self.chr_to_idx['<UNK>'] = 259

        self.idx_to_chr = {v: k for k, v in self.chr_to_idx.items()}

        self.vocab_size = len(self.chr_to_idx.keys())

    def encode(self, char):
        if char in self.chr_to_idx.keys():
            return self.chr_to_idx[char]
        else:
            return 259
    
    def decode(self, token_idx):
        return self.idx_to_chr[token_idx]
    
    def sos_token(self):
        return '<SOS>'
    
    def sos_token_idx(self):
        return self.chr_to_idx['<SOS>']

    def eos_token(self):
        return '<EOS>'
    
    def eos_token_idx(self):
        return self.chr_to_idx['<EOS>']
    
    def pad_token(self):
        return '<PAD>'
    
    def pad_token_idx(self):
        return self.chr_to_idx['<PAD>']
    
    def get_vocab_size(self):
        return self.vocab_size


In [21]:
from torch.utils.data import Dataset


class DatasetDialogs(Dataset):
    def __init__(self, dataset_path, sentence_length):
        self.dataset_path = dataset_path
        self.sentence_length = sentence_length
        self.tokenizer = TokenizerChar()

    def __len__(self):
        with open(self.dataset_path, 'r') as dataset:
            num_of_sentences = len(dataset.read().split('\n'))
        return num_of_sentences
    
    def get_shape(self):
        with open(self.dataset_path, 'r') as dataset:
            num_of_sentences = len(dataset.read().split('\n'))
        return (num_of_sentences, self.sentence_length)

    def __getitem__(self, line_idx):
        with open(self.dataset_path, 'r') as dataset:
            selected_sentence = dataset.read().split('\n')[line_idx]
            if len(selected_sentence) < self.sentence_length:
                input_tokens = [self.tokenizer.sos_token_idx()] + [self.tokenizer.encode(char) for char in selected_sentence]
                input_tokens.append(self.tokenizer.eos_token_idx())
                pad_length = self.sentence_length - len(input_tokens) + 1
                pad_tokens = [self.tokenizer.pad_token_idx()] * pad_length
                input_tokens += pad_tokens
            elif len(selected_sentence) == self.sentence_length:
                input_tokens = [self.tokenizer.sos_token_idx()] + [self.tokenizer.encode(char) for char in selected_sentence]
                input_tokens[-1] = self.tokenizer.eos_token_idx()
            elif len(selected_sentence) > self.sentence_length:
                selected_sentence = selected_sentence[:self.sentence_length]
                input_tokens = [self.tokenizer.sos_token_idx()] + [self.tokenizer.encode(char) for char in selected_sentence]
                input_tokens[-1] = self.tokenizer.eos_token_idx()
            # print(f'{len(input_tokens)} - {self.sentence_length}') # debug only
            assert len(input_tokens) == self.sentence_length + 1, f"Lista de índices de tokens não possui mesmo tamanho que 'sentence_length'! len(input_tokens): {len(input_tokens)} - self.sentence_length: {self.sentence_length}"
            try:
                x = torch.tensor(input_tokens[:-1])
                y = torch.tensor(input_tokens[1:])
            except RuntimeError as e:
                print(e)
                print(f"Input tokens: {input_tokens}")
                raise e
        return x, y

In [40]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
sequence_length = 250
batch_size = 32
dataset_train = DatasetDialogs('dataset_text/dialogs_train.txt', sequence_length)
dataset_test = DatasetDialogs('dataset_text/dialogs_test.txt', sequence_length)
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=True)
vocab_size = dataset_train.tokenizer.get_vocab_size()

d_model = 512
num_layers = 6
num_heads = 8
d_ff = 2048
dropout = 0.1
max_seq_length = sequence_length
# model = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)
model = TransformerDecoder(vocab_size, d_model, max_seq_length, num_layers, d_ff, num_heads, dropout=0.1)
model.to(device)

tgt_mask = (1 - torch.triu(
  torch.ones(1, sequence_length, sequence_length), diagonal=1)
).bool()

def init_weights(module):
    if isinstance(module, (nn.Linear)):
        init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            init.zeros_(module.bias)
model.apply(init_weights)

optimizer = Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
n_epochs = 50
n_batches = int(dataset_train.__len__() // batch_size)

print("Starting model training...")
for epoch in range(n_epochs):
    print(f"Epoch: {epoch + 1}")
    avg_loss = 0
    model.train()
    for batch_idx, batch in enumerate(tqdm(dataloader_train, total=n_batches)):
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        outputs = model(x, tgt_mask.to(device))
        loss = loss_fn(outputs.view(-1, vocab_size), y.view(-1))
        avg_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    torch.save(model.state_dict(), f'model_checkpoints/model_checkpoint_{epoch+1}.pth')

    avg_loss /= (batch_idx + 1)
    print(f"Average epoch training loss: {avg_loss}")
    print(f"Last batch training loss: {loss}")

    model.eval()
    avg_loss = 0
    for batch_idx, batch in enumerate(dataloader_test):
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        outputs = model(x, tgt_mask.to(device))
        loss = loss_fn(outputs.view(-1, vocab_size), y.view(-1))
        avg_loss += loss.item()
    
    avg_loss /= (batch_idx + 1)
    print(f"Epoch validation loss: {avg_loss}")
    

cuda
Starting model training...
Epoch: 1


100%|██████████| 107/107 [00:18<00:00,  5.74it/s]


Average epoch training loss: 3.1735775448451533
Last batch training loss: 2.7660892009735107
Epoch validation loss: 2.6767372846603394
Epoch: 2


100%|██████████| 107/107 [00:18<00:00,  5.80it/s]


Average epoch training loss: 2.580989764115521
Last batch training loss: 2.481311798095703
Epoch validation loss: 2.4199381828308106
Epoch: 3


100%|██████████| 107/107 [00:18<00:00,  5.76it/s]


Average epoch training loss: 2.4017743930638393
Last batch training loss: 2.3187801837921143
Epoch validation loss: 2.303889751434326
Epoch: 4


100%|██████████| 107/107 [00:18<00:00,  5.72it/s]


Average epoch training loss: 2.28008285861149
Last batch training loss: 2.1804916858673096
Epoch validation loss: 2.1839465618133547
Epoch: 5


100%|██████████| 107/107 [00:18<00:00,  5.66it/s]


Average epoch training loss: 2.157245627073484
Last batch training loss: 2.099167823791504
Epoch validation loss: 2.0813934564590455
Epoch: 6


100%|██████████| 107/107 [00:19<00:00,  5.50it/s]


Average epoch training loss: 2.0475467042388202
Last batch training loss: 2.074234962463379
Epoch validation loss: 1.9872362971305848
Epoch: 7


100%|██████████| 107/107 [00:20<00:00,  5.29it/s]


Average epoch training loss: 1.950556634742523
Last batch training loss: 1.8038610219955444
Epoch validation loss: 1.9098064064979554
Epoch: 8


100%|██████████| 107/107 [00:20<00:00,  5.18it/s]


Average epoch training loss: 1.870883761165298
Last batch training loss: 1.8429877758026123
Epoch validation loss: 1.8441076040267945
Epoch: 9


100%|██████████| 107/107 [00:20<00:00,  5.12it/s]


Average epoch training loss: 1.8056340384706158
Last batch training loss: 1.823001742362976
Epoch validation loss: 1.8065894365310669
Epoch: 10


100%|██████████| 107/107 [00:21<00:00,  5.09it/s]


Average epoch training loss: 1.7506003446668108
Last batch training loss: 1.669348955154419
Epoch validation loss: 1.770089852809906
Epoch: 11


100%|██████████| 107/107 [00:21<00:00,  5.07it/s]


Average epoch training loss: 1.6962462819625284
Last batch training loss: 1.6796238422393799
Epoch validation loss: 1.747429883480072
Epoch: 12


100%|██████████| 107/107 [00:21<00:00,  5.09it/s]


Average epoch training loss: 1.652255013724354
Last batch training loss: 1.5694397687911987
Epoch validation loss: 1.7166443705558776
Epoch: 13


100%|██████████| 107/107 [00:21<00:00,  4.97it/s]


Average epoch training loss: 1.6114624415602639
Last batch training loss: 1.6831997632980347
Epoch validation loss: 1.6984066724777223
Epoch: 14


100%|██████████| 107/107 [00:21<00:00,  4.95it/s]


Average epoch training loss: 1.5713848889431106
Last batch training loss: 1.4785460233688354
Epoch validation loss: 1.6779712915420533
Epoch: 15


100%|██████████| 107/107 [00:21<00:00,  4.95it/s]


Average epoch training loss: 1.5383966147342576
Last batch training loss: 1.4607819318771362
Epoch validation loss: 1.6440895080566407
Epoch: 16


100%|██████████| 107/107 [00:22<00:00,  4.85it/s]


Average epoch training loss: 1.5069312046621448
Last batch training loss: 1.5084620714187622
Epoch validation loss: 1.6411741733551026
Epoch: 17


100%|██████████| 107/107 [00:21<00:00,  5.02it/s]


Average epoch training loss: 1.478037105542477
Last batch training loss: 1.4786267280578613
Epoch validation loss: 1.6250544905662536
Epoch: 18


100%|██████████| 107/107 [00:21<00:00,  4.99it/s]


Average epoch training loss: 1.4514438437524242
Last batch training loss: 1.382041573524475
Epoch validation loss: 1.6198889136314392
Epoch: 19


100%|██████████| 107/107 [00:21<00:00,  5.05it/s]


Average epoch training loss: 1.4256321559442537
Last batch training loss: 1.4371919631958008
Epoch validation loss: 1.6077964663505555
Epoch: 20


100%|██████████| 107/107 [00:21<00:00,  4.90it/s]


Average epoch training loss: 1.4022864381843638
Last batch training loss: 1.4393110275268555
Epoch validation loss: 1.6071891784667969
Epoch: 21


100%|██████████| 107/107 [00:21<00:00,  4.99it/s]


Average epoch training loss: 1.3799509467365585
Last batch training loss: 1.4499001502990723
Epoch validation loss: 1.595920741558075
Epoch: 22


100%|██████████| 107/107 [00:21<00:00,  4.89it/s]


Average epoch training loss: 1.3591369212230788
Last batch training loss: 1.3573778867721558
Epoch validation loss: 1.5823466777801514
Epoch: 23


100%|██████████| 107/107 [00:21<00:00,  4.98it/s]


Average epoch training loss: 1.3381482942082057
Last batch training loss: 1.3940792083740234
Epoch validation loss: 1.585340976715088
Epoch: 24


100%|██████████| 107/107 [00:21<00:00,  4.99it/s]


Average epoch training loss: 1.320044987669615
Last batch training loss: 1.3111461400985718
Epoch validation loss: 1.5944262623786927
Epoch: 25


100%|██████████| 107/107 [00:21<00:00,  5.04it/s]


Average epoch training loss: 1.3019802470073523
Last batch training loss: 1.2995924949645996
Epoch validation loss: 1.5765890717506408
Epoch: 26


100%|██████████| 107/107 [00:21<00:00,  5.04it/s]


Average epoch training loss: 1.2816898622245432
Last batch training loss: 1.286665678024292
Epoch validation loss: 1.589464020729065
Epoch: 27


100%|██████████| 107/107 [00:21<00:00,  5.05it/s]


Average epoch training loss: 1.2667390275224346
Last batch training loss: 1.2584857940673828
Epoch validation loss: 1.5821871757507324
Epoch: 28


100%|██████████| 107/107 [00:21<00:00,  5.02it/s]


Average epoch training loss: 1.2498530182883003
Last batch training loss: 1.2928460836410522
Epoch validation loss: 1.581800389289856
Epoch: 29


100%|██████████| 107/107 [00:21<00:00,  4.98it/s]


Average epoch training loss: 1.234554877905088
Last batch training loss: 1.2436082363128662
Epoch validation loss: 1.577298104763031
Epoch: 30


100%|██████████| 107/107 [00:21<00:00,  5.05it/s]


Average epoch training loss: 1.2190044126778006
Last batch training loss: 1.1904098987579346
Epoch validation loss: 1.5823105454444886
Epoch: 31


100%|██████████| 107/107 [00:21<00:00,  4.97it/s]


Average epoch training loss: 1.2035604247422975
Last batch training loss: 1.2013874053955078
Epoch validation loss: 1.5917557954788208
Epoch: 32


100%|██████████| 107/107 [00:21<00:00,  4.98it/s]


Average epoch training loss: 1.1885405756602778
Last batch training loss: 1.1575632095336914
Epoch validation loss: 1.6000290870666505
Epoch: 33


100%|██████████| 107/107 [00:21<00:00,  5.02it/s]


Average epoch training loss: 1.176967635332981
Last batch training loss: 1.2055195569992065
Epoch validation loss: 1.5867462873458862
Epoch: 34


100%|██████████| 107/107 [00:21<00:00,  4.98it/s]


Average epoch training loss: 1.1612085672182457
Last batch training loss: 1.1793259382247925
Epoch validation loss: 1.5880399346351624
Epoch: 35


100%|██████████| 107/107 [00:21<00:00,  4.98it/s]


Average epoch training loss: 1.1463404637630854
Last batch training loss: 1.177831768989563
Epoch validation loss: 1.5902249097824097
Epoch: 36


100%|██████████| 107/107 [00:21<00:00,  4.98it/s]


Average epoch training loss: 1.1334810034136906
Last batch training loss: 1.178155541419983
Epoch validation loss: 1.6176453351974487
Epoch: 37


100%|██████████| 107/107 [00:21<00:00,  4.91it/s]


Average epoch training loss: 1.119857259999926
Last batch training loss: 1.1363736391067505
Epoch validation loss: 1.6042938590049745
Epoch: 38


100%|██████████| 107/107 [00:21<00:00,  4.94it/s]


Average epoch training loss: 1.1052395762684188
Last batch training loss: 1.1142548322677612
Epoch validation loss: 1.6196529746055603
Epoch: 39


100%|██████████| 107/107 [00:22<00:00,  4.84it/s]


Average epoch training loss: 1.0954942279886977
Last batch training loss: 1.0758405923843384
Epoch validation loss: 1.6311343550682067
Epoch: 40


100%|██████████| 107/107 [00:21<00:00,  5.02it/s]


Average epoch training loss: 1.0820913247973005
Last batch training loss: 1.0827819108963013
Epoch validation loss: 1.6220815896987915
Epoch: 41


100%|██████████| 107/107 [00:21<00:00,  5.05it/s]


Average epoch training loss: 1.0676533618820048
Last batch training loss: 1.0470728874206543
Epoch validation loss: 1.6483017086982727
Epoch: 42


100%|██████████| 107/107 [00:22<00:00,  4.83it/s]


Average epoch training loss: 1.057369582563917
Last batch training loss: 1.0569454431533813
Epoch validation loss: 1.646824264526367
Epoch: 43


100%|██████████| 107/107 [00:22<00:00,  4.80it/s]


Average epoch training loss: 1.0416490819966682
Last batch training loss: 1.059358835220337
Epoch validation loss: 1.656047785282135
Epoch: 44


100%|██████████| 107/107 [00:21<00:00,  4.94it/s]


Average epoch training loss: 1.0287096349992484
Last batch training loss: 0.9842891693115234
Epoch validation loss: 1.6568758845329286
Epoch: 45


100%|██████████| 107/107 [00:21<00:00,  5.09it/s]


Average epoch training loss: 1.0167588881243055
Last batch training loss: 1.037386417388916
Epoch validation loss: 1.6532402515411377
Epoch: 46


100%|██████████| 107/107 [00:22<00:00,  4.80it/s]


Average epoch training loss: 1.0047794285221634
Last batch training loss: 1.039312720298767
Epoch validation loss: 1.6715827345848084
Epoch: 47


100%|██████████| 107/107 [00:21<00:00,  5.00it/s]


Average epoch training loss: 0.9949499858874027
Last batch training loss: 1.0006977319717407
Epoch validation loss: 1.673180067539215
Epoch: 48


100%|██████████| 107/107 [00:22<00:00,  4.77it/s]


Average epoch training loss: 0.9795880869170216
Last batch training loss: 1.0079280138015747
Epoch validation loss: 1.700710701942444
Epoch: 49


100%|██████████| 107/107 [00:21<00:00,  4.88it/s]


Average epoch training loss: 0.9691606559486032
Last batch training loss: 0.9730602502822876
Epoch validation loss: 1.6907180190086364
Epoch: 50


100%|██████████| 107/107 [00:22<00:00,  4.82it/s]


Average epoch training loss: 0.9598854412542326
Last batch training loss: 1.0286651849746704
Epoch validation loss: 1.7168082237243651


In [43]:
def pad_sequence_to_length(sequence, sequence_length, pad_token_idx):
    """
    Completa a sequência com tokens de padding até atingir o comprimento desejado.

    Args:
        sequence (list): Sequência de tokens (índices) a ser completada.
        sequence_length (int): Comprimento desejado da sequência.
        pad_token_idx (int): Índice do token de padding.

    Returns:
        list: Sequência completada com tokens de padding.
    """
    if len(sequence) < sequence_length:
        # Adiciona tokens de padding no final da sequência
        sequence += [pad_token_idx] * (sequence_length - len(sequence))
    elif len(sequence) > sequence_length:
        # Trunca a sequência se for maior que o comprimento desejado
        sequence = sequence[:sequence_length]
    return sequence

In [44]:
import torch

def predict(start_text, model, tokenizer, sequence_length=250, temperature=1.0):
    model.eval()
    device = next(model.parameters()).device

    # Tokeniza e preenche a sequência
    sequence = [tokenizer.sos_token_idx()] + [tokenizer.encode(c) for c in start_text]
    sequence = pad_sequence_to_length(sequence, sequence_length, tokenizer.pad_token_idx())
    input_tokens = torch.tensor(sequence, dtype=torch.long).unsqueeze(0).to(device)

    current_text = start_text

    with torch.no_grad():
        for i in range(len(start_text), sequence_length):
            tgt_mask = torch.tril(torch.ones(sequence_length, sequence_length)).to(device).bool()
            outputs = model(input_tokens, tgt_mask=tgt_mask)
            log_probs = outputs[0, i - 1] / temperature

            predicted_token_idx = torch.distributions.Categorical(logits=log_probs).sample().item()

            if predicted_token_idx == tokenizer.eos_token_idx():
                break

            current_text += tokenizer.decode(predicted_token_idx)
            input_tokens[0, i] = predicted_token_idx

    print('Texto predito:', current_text)
    return current_text


In [45]:
dataset = DatasetDialogs('dataset_text/dialogs.txt', 50)

In [46]:
predict('Hello ', model=model, tokenizer=dataset.tokenizer, temperature=1)

Texto predito: Hello , i went.	it sure isn't. because i was nice to do.


"Hello , i went.\tit sure isn't. because i was nice to do."

In [47]:
predict('How are you doing? ', model=model, tokenizer=dataset.tokenizer, temperature=0.1)

Texto predito: How are you doing? 	i was hoping.


'How are you doing? \ti was hoping.'