# Homework 3 - Text generation with LSTM and Transformer networks



## Installs the unidecode library and downloads the Shakespeare dataset.

In [70]:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

## LSTM implementation

For this task you will implement the LSTM neural network architecture and train it on the task of character-level text generation. Implement a single layer LSTM and optionally extend your implementation to multiple layers to generate better results.

Links:

- https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html -- Lists the equations for each component of the LSTM cell.
- http://colah.github.io/posts/2015-08-Understanding-LSTMs/ -- Intuitive explanation of LSTM
- http://karpathy.github.io/2015/05/21/rnn-effectiveness/ -- Explanation and uses of RNNs.


Implement the initialization and the forward pass of a LSTMCell and use it as part of the LSTMSimple network class.

The input of the LSTM network will be a sequence of characters, whereas the input of the LSTMCell will be a single input character (x), the output of the previous iteration (C) and the hidden state of the previous iteration (h). Iteratively process the entire input character sequence and calculate the loss based on the prediction at each time step.

### Do NOT use the torch.nn.LSTM class.


In [71]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset

class LSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):

        super(LSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # compute the forgetting gate
        self.forgetting_gate = nn.Sequential(
            nn.Linear(self.input_dim+self.hidden_dim, self.output_dim),
            nn.Sigmoid()
        )

        # compute the input gate and new memories
        self.input_gate = nn.Sequential(
            nn.Linear(self.input_dim+self.hidden_dim, self.output_dim),
            nn.Sigmoid()
        )
        self.new_memories = nn.Sequential(
            nn.Linear(self.input_dim+self.hidden_dim, self.output_dim),
            nn.Tanh()
        )

        # compute the new hidden state
        self.new_hidden_mask = nn.Sequential(
            nn.Linear(self.input_dim+self.hidden_dim, self.hidden_dim),
            nn.Sigmoid()
        )
        self.new_hidden_vals = nn.Sequential(
            nn.Linear(self.input_dim+self.hidden_dim, self.hidden_dim),
            nn.Tanh()
        )

    def forward(self, x:torch.tensor, C:torch.tensor, h:torch.tensor):
        # x - batch of encoded characters
        # C - Cell state of the previous iteration
        # h - Hidden state of the previous iteration

        # concat h_t-1 and x 
        hidden_stack = torch.concat((x,h), dim=1)
        
        # calculate forgetting mask
        forgetting_mask = self.forgetting_gate(hidden_stack)
        # forget some C by the mask
        forgotten_C = C * forgetting_mask

        # calculat new memories
        input_mask = self.input_gate(hidden_stack)
        input_vals = self.new_memories(hidden_stack)
        masked_new_vals = input_mask * input_vals


        # modify cell state with new values
        new_C = forgotten_C + masked_new_vals
        
        # calculat new hidden dim
        hiden_mask = self.new_hidden_mask(hidden_stack)
        hiden_vals = self.new_hidden_vals(hidden_stack)
        new_hidden_state = hiden_mask * hiden_vals

        return new_C, new_hidden_state


class LSTMSimple(nn.Module):
    def __init__(self, seq_length, input_dim, hidden_dim, output_dim, batch_size):
        super(LSTMSimple, self).__init__()

        self.seq_length = seq_length
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.lstm_cell = LSTMCell(input_dim, hidden_dim, output_dim)
        self.proj = nn.Sequential(
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        # x - One hot encoded batch - Shape: (batch, seq_len,)
        batch, tokens, features = x.shape
        c = torch.zeros((batch, self.output_dim)).cuda()
        h = torch.zeros((batch, self.hidden_dim)).cuda()
        out = torch.zeros((batch, self.seq_length, self.output_dim)).cuda()

        for t in range(tokens):
            x_t = x[:, t, :]
            c, h = self.lstm_cell(x_t,c,h)

            o = self.proj(h)
            out[:, t,:] = o
        
        return  out, (c,h)


        # Returns the predicted next character for each character in the
        # sequence (outputs), also returns the cell state and hidden state of the
        # LSTMCell call on the last character. -- outputs, (c,t)



### LSTM Sampling Code

To generate text the network must predict the next character in a sequence, however networks do not produce a single character but rather estimate the likelihood for each possible character. Sampling characters from the network output can be done in different ways with common ones being the Greedy sampling process and Top-K sampling.

In the simple greedy sampling method the network takes a text prompt as input and generates an additional N tokens by always taking the token with the highest prediction score as the next token.

In the Top-K sampling, randomness is added to the sampling process as the network samples from K most likely predicitons at each step. This alleviates the problem of generative models repeating text but may generate incorrect text by sampling inappropriate tokens.


In [72]:
def greedy_sampling_lstm(lstm, x, num_chars):
    # x -- b x onehot_char
    outputs = torch.zeros((1,num_chars,x.shape[2]))
    t_outputs, (cell_state, hidden) = lstm(x.float())
    for c in range(num_chars):
        output_tmp = torch.softmax(lstm.proj(hidden),dim=1)
        top_ind = torch.argmax(output_tmp,dim=1)[0]
        tmp = torch.zeros_like(x[:,0,:]).cuda()
        tmp[:,top_ind] = 1
        outputs[:,c] = tmp

        cell_state, hidden = lstm.lstm_cell(tmp,cell_state,hidden)
    return outputs

def topk_sampling_lstm(lstm, x, num_chars):
    # x -- b x onehot_char
    outputs = torch.zeros((1,num_chars,x.shape[2]))
    t_outputs, (cell_state, hidden) = lstm(x.float())
    for c in range(num_chars):
        output_vals, output_ind = torch.topk(lstm.proj(hidden), 5, dim=1)
        output_tmp = torch.softmax(output_vals,dim=1)
        top_ind = torch.multinomial(output_tmp[0], 1)[0]
        tmp = torch.zeros_like(x[:,0,:]).cuda()
        tmp[:,output_ind[0,top_ind]] = 1
        outputs[:,c] = tmp

        cell_state, hidden = lstm.lstm_cell(tmp,cell_state,hidden)

    return outputs

### LSTM Dataset Code

In [73]:
import unidecode
import string
import random
from torch.autograd import Variable
from torch.utils.data import Dataset


class LSTMDataset(Dataset):
    def __init__(self, chunk_len=200, padded_chunks=False):
        # Character based dataset
        dataset_path = "./input.txt"
        # The tokens in the vocabulary (all_characters)
        # are just the printable characters of the string class
        self.all_characters = string.printable
        self.n_characters = len(self.all_characters)
        # Maps characters to indices
        self.char_dict = {x:i for i,x in enumerate(self.all_characters)}
        self.file, self.file_len = self.read_file(dataset_path)
        # Sequence length of the input
        self.chunk_len = chunk_len

    def read_file(self,filename):
        file = unidecode.unidecode(open(filename).read())
        return file, len(file)

    def char_tensor(self,in_str):
        # in_str - input sequence - String
        # Return one-hot encoded characters of in_str
        tensor = torch.zeros(len(in_str),self.n_characters).long()
        char_ind = [self.char_dict[c] for c in in_str]
        tensor[torch.arange(tensor.shape[0]),char_ind] = 1
        return tensor

    def __getitem__(self, idx):
        inp, target = self.get_random_text()
        return {"input":inp, "target":target}

    def __len__(self):
        return 10000

    def get_random_text(self):
        # Pick a random string of length self.chunk_len from the dataset
        start_index = np.random.randint(0, self.file_len - self.chunk_len)
        end_index = start_index + self.chunk_len + 1
        chunk = self.file[start_index:end_index]
        # One-hot encode the chosen string
        inp = self.char_tensor(chunk[:-1])
        # The target string is the same as the
        # input string but shifted by 1 character
        target = self.char_tensor(chunk[1:])
        inp = Variable(inp).cuda()
        target = Variable(target).cuda()
        return inp, target


### LSTM Training loop

With a correct implementation you should get sensible text generation results with the set parameters, however you should experiment with various parameters,
especially with the sequence length (chunk_len) used during training.

In [74]:
import os
from tqdm import tqdm
import torch.optim as optim

batch_size = 256
chunk_len = 128
model_name = "LSTM"
train_dataset = LSTMDataset(chunk_len=chunk_len)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=0, drop_last=True)
LSTM_MODEL_PATH = f"{model_name}-b{batch_size}-ch{chunk_len}.cktp"

In [77]:

#Sample parameters, use whatever you see fit.
input_dim = train_dataset.n_characters
hidden_dim = 256
output_dim = train_dataset.n_characters
learning_rate = 0.005
model = LSTMSimple(chunk_len,input_dim, hidden_dim, output_dim,batch_size)
model.train()
model.cuda()

# if(os.path.exists(LSTM_MODEL_PATH)):
#     model.load_state_dict(torch.load(LSTM_MODEL_PATH))
# else:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
epochs=30

for epoch in range(epochs):
    with tqdm(total=len(trainloader.dataset), desc ='Training - Epoch: '+str(epoch)+"/"+str(epochs), unit='chunks') as prog_bar:
        for i, data in enumerate(trainloader, 0):
            inputs = data['input'].float()
            labels = data['target'].float()
            # b x chunk_len x len(dataset.all_characters)
            optimizer.zero_grad()
            outputs, _ = model(inputs)
            target = torch.argmax(labels,dim=2)

            loss = criterion(outputs.view(inputs.shape[0]*inputs.shape[1],-1),target.view(labels.shape[0]*labels.shape[1]))
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                    max_norm=10.0)
            optimizer.step()
            prog_bar.set_postfix(**{'run:': model_name,'lr': learning_rate,
                                    'loss': loss.item()
                                    })
            prog_bar.update(batch_size)
        # Intermediate output
        sample_text = "O Romeo, wherefore art thou"
        inp = train_dataset.char_tensor(sample_text)
        sample_input = Variable(inp).cuda().unsqueeze(0).float()
        out_test = topk_sampling_lstm(model,sample_input, 300)[0]
        out_char_index = torch.argmax(out_test, dim=1).detach().cpu().numpy()
        out_chars = sample_text+"".join([train_dataset.all_characters[i] for i in out_char_index])
        print("Top-K sampling -----------------")
        print(out_chars)

        out_test = greedy_sampling_lstm(model,sample_input, 300)[0]
        out_char_index = torch.argmax(out_test, dim=1).detach().cpu().numpy()
        out_chars = sample_text+"".join([train_dataset.all_characters[i] for i in out_char_index])
        print("Greedy sampling ----------------")
        print(out_chars)


Training - Epoch: 0/30:   0%|          | 0/10000 [00:00<?, ?chunks/s]

Training - Epoch: 0/30:   0%|          | 0/10000 [00:00<?, ?chunks/s]


ValueError: not enough values to unpack (expected 3, got 2)

In [None]:
# output checkpoint
torch.save(model.state_dict(), LSTM_MODEL_PATH)

# Task 2: Character generation transformer network implementation
Our simple transformer-like network will take as input a sequence of characters and predict the next character in the sequence. To ensure an efficient training procedure, masked attention modules will be used as in the [GPT model](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf).

For this task you must implement the Scaled dot product attention module and the Masked multi-head attention module. Both of these modules are described in the [Attention is all you need](https://arxiv.org/pdf/1706.03762.pdf) paper (See Figure 2 in the paper as well as Sections 3.2.1, 3.2.2 and 3.2.3). They are the core operations of transformers. As we will use our model for text generation also add the masking operation shown as (mask opt.) in Figure 2, implemented as AttentionMasking in the code.

**Implement the modules in the ScaledDotProductAttention class and the MultiHeadAttention class.**

Read the GPT paper and the Attention is all you need paper for a better understanding of the components. For a more high level overview, this [post](https://jalammar.github.io/illustrated-gpt2/) may also be helpful.


In [79]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset
import math
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1000):
        super().__init__()
        # Positional encoding adds the positional information to the
        # embedding. Without it the model would not be able to differentiate
        # between different characters orders such as between "dog" and "god".
        position = torch.arange(max_len).unsqueeze(1).float()
        div_term = 10000.0**(torch.arange(0,d_model,2).float()/d_model)
        print(div_term.shape)
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position / div_term)
        pe[:, 1::2] = torch.cos(position / div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.cuda()
        self.pe.requires_grad = False

    def forward(self, x):
        p = self.pe[:, :x.size(1)]
        return p

class AttentionMasking(nn.Module):
    def __init__(self, max_len):
        super(AttentionMasking, self).__init__()
        self.register_buffer("mask", torch.tril(torch.ones(max_len, max_len))
                                     .view(1, 1, max_len, max_len))
    def forward(self,x):
        length = x.shape[-1]
        out = x.masked_fill(self.mask[:,:,:length,:length] == 0, float('-inf'))
        return out


class ScaledDotProductAttention(nn.Module):
    def __init__(self, max_len):
        super(ScaledDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.mask_opt = AttentionMasking(max_len)
        self.scale = 1.0 /  math.sqrt(max_len)


    def forward(self,q,k,v):
        # Implement the scaled dot product attention as described in
        # the Attention is all you need paper in Equation 1
        
        attention = q @ k.transpose(-1, -2) * self.scale
        attention = attention + self.mask_opt(attention)
        
        attention_weights = self.softmax(attention)
        
        outputs = attention_weights @ v
        
        return outputs

class MultiHeadAttention(nn.Module):
    def __init__(self, dim_model, num_neuron, n_head, max_len):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.num_neuron = num_neuron

        # self implemented attention
        self.attention = ScaledDotProductAttention(max_len)

        # liner tranforms in the begining of the block
        self.q_transform = nn.Linear(dim_model, n_head * num_neuron)
        self.k_transform = nn.Linear(dim_model, n_head * num_neuron)
        self.v_transform = nn.Linear(dim_model, n_head * num_neuron)

        # output linear transforms
        self.out_linear = nn.Linear(n_head * num_neuron, dim_model)
        

    def split(self,tensor):
        batch_size, length, total_dim = tensor.size()
        # Reshape the tensor to enable the use in
        # the ScaledDotProductAttention module
        split_tensor = tensor.view(batch_size, length, self.n_head, self.num_neuron).transpose(1,2)
        return split_tensor

    def concat(self,tensor):
        batch_size, num_heads, length, num_neuron = tensor.size()
        # Reshape the tensor to its original size before the split operation.
        concat_tensor = tensor.transpose(1,2).contiguous().view(batch_size, length, self.n_head*self.num_neuron)
        return concat_tensor

    def forward(self, q, k, v):
        # Apply linear layer to make them fit the corect size
        q_trans = self.q_transform(q)
        k_trans = self.k_transform(k)
        v_trans = self.v_transform(v)
        
        # Split into multiple heads with the provided function
        q_split = self.split(q_trans)
        k_split = self.split(k_trans)
        v_split = self.split(v_trans)
        
        # Process attention and merge them back
        out = self.concat(
            self.attention(q_split, k_split, v_split)
        )

        return self.out_linear(out)

class PositionFeedForwardNet(nn.Module):
    def __init__(self, dim_model):
        super(PositionFeedForwardNet, self).__init__()
        self.ff_net1 = nn.Linear(dim_model, dim_model*4)
        self.ff_net2 = nn.Linear(dim_model*4, dim_model)
    def forward(self,x):
        ff_out = self.ff_net1(x)
        ff_out = torch.nn.functional.relu(ff_out)
        ff_out = self.ff_net2(ff_out)
        return ff_out

class TransformerBlock(nn.Module):
    def __init__(self, dim_model, num_neuron, n_head, max_len):
        super(TransformerBlock, self).__init__()
        self.mha = MultiHeadAttention(dim_model, num_neuron, n_head, max_len)
        self.l_norm = torch.nn.LayerNorm(dim_model)
        self.l_norm2 = torch.nn.LayerNorm(dim_model)
        self.ff_net = PositionFeedForwardNet(dim_model)
        # b, len_seq, n_head, num_neuron

    def forward(self, x):
      # A Transformer block as described in the
      # Attention is all you need paper. In Figure 1 the transformer
      # block is marked with a gray rectangle right of the text "Nx"
      _x = x
      mha1 = self.mha(x,x,x)
      lnorm = self.l_norm(_x+mha1)
      _x = lnorm
      ff_out = self.ff_net(lnorm)
      out = self.l_norm2(ff_out+_x)

      return out

class TransformerSimple(nn.Module):
    def __init__(self, seq_length, input_dim, output_dim,
                 batch_size):
        super(TransformerSimple, self).__init__()
        num_neuron = 64
        n_head = 8
        dim_model=256
        max_len = 512
        self.start_embedding = nn.Embedding(input_dim, dim_model)

        self.pos_embedding = PositionalEncoding(dim_model)

        # b x l x c*n_head
        self.t_block1 = TransformerBlock(dim_model, num_neuron, n_head, max_len)
        self.t_block2 = TransformerBlock(dim_model, num_neuron, n_head, max_len)
        self.t_block3 = TransformerBlock(dim_model, num_neuron, n_head, max_len)
        self.t_block4 = TransformerBlock(dim_model, num_neuron, n_head, max_len)
        self.t_block5 = TransformerBlock(dim_model, num_neuron, n_head, max_len)

        #self.out_layer_1 = nn.Linear(dim_model, dim_model)
        self.output_layer = nn.Linear(dim_model,output_dim)
        

    def forward(self,x):
      # x - Tensor - (b, seq_len)
      # Embeds the input tensor from tokens to features
      s_emb = self.start_embedding(x)
      # Adds positional embeddings
      p_emb = self.pos_embedding(s_emb)
      b_out = p_emb + s_emb
      # Transformer blocks - You can experiment with varying depth
      # For example GPT uses 12 blocks but this might be a bit memory intensive
      b_out = self.t_block1(b_out)
      b_out = self.t_block2(b_out)
      b_out = self.t_block3(b_out)
      b_out = self.t_block4(b_out)
      b_out = self.t_block5(b_out)

      # Output mapping to a classification of output tokens
      # For each token the network tries to predict the next token
      # based only on the previous tokens.
      # Output shape: (b x seq_len x vocabulary_size)
      out = self.output_layer(b_out)

      return out


## Dataset class


In [80]:
import unidecode
import string
import random
from torch.autograd import Variable
from torch.utils.data import Dataset

class TextDataset(Dataset):
    def __init__(self, chunk_len=200, padded_chunks=False):
        # Character based dataset
        dataset_path = "./input.txt"
        # The tokens in the vocabulary (all_characters)
        # are just the printable characters of the string class
        self.all_characters = string.printable
        self.n_characters = len(self.all_characters)
        # Maps characters to indices
        self.char_dict = {x:i for i,x in enumerate(self.all_characters)}
        self.file, self.file_len = self.read_file(dataset_path)
        # Sequence length of the input
        self.chunk_len = chunk_len
        self.encoded_file = [self.char_dict[x] for x in self.file]

    def read_file(self,filename):
        file = unidecode.unidecode(open(filename).read())
        return file, len(file)

    def encode_text(self,in_str):
        # in_str - input sequence - String
        # Returns - in_str mapped to tokens in char_dict
        tensor = torch.LongTensor([self.char_dict[x] for x in in_str])
        return tensor

    def __getitem__(self, idx):
        inp, target = self.get_random_text()
        return {"input":inp, "target":target}

    def __len__(self):
        return 10000

    def get_random_text(self):
        # Pick a random string of length self.chunk_len from the dataset
        start_index = np.random.randint(0, self.file_len - self.chunk_len)
        end_index = start_index + self.chunk_len + 1
        chunk = self.encoded_file[start_index:end_index]
        # input_tokens - random sequence of tokens from the dataset
        input_tokens = torch.LongTensor(chunk[:-1])
        # target - input token sequence shifted by 1
        # the idea is to predict next token for each token in the input sequence
        # therefore if the input is [1,2,3,4] the target is [2,3,4,5]
        target = torch.LongTensor(chunk[1:])
        input_tokens = input_tokens.cuda()
        target = target.cuda()
        return input_tokens, target


## Character sampling

To generate text the network must predict the next character in a sequence, however networks do not produce a single character but rather estimate the likelihood for each possible character. Sampling characters from the network output can be done in different ways with common ones being the Greedy sampling process and Top-K sampling.

In the simple greedy sampling method the network takes a text prompt as input and generates an additional N tokens by always taking the token with the highest prediction score as the next token.

In the Top-K sampling, randomness is added to the sampling process as the network samples from K most likely predicitons at each step. This alleviates the problem of generative models repeating text but may generate incorrect text by sampling inappropriate tokens.


In [81]:
def topk_sampling_iter_transformer(model, x, num_chars, chunk_len, output_token):
    # x -- b x onehot_char
    # x = b x l
    outputs = torch.zeros((1,num_chars))
    inp = x

    for t in range(num_chars):
        # b x onehot_char
        output = model(inp.long())[0,-1:]
        #output = torch.softmax(output, dim=1)
        # b x 3
        output_vals, output_ind = torch.topk(output, 5, dim=1)
        # 3 -> int
        output_vals = torch.softmax(output_vals, dim=1)
        top_ind = torch.multinomial(output_vals[0], 1)[0]
        # int
        out_char_index = output_ind[0,top_ind]
        # int -> 1
        out_char_index = torch.ones(1).cuda() * out_char_index

        outputs[:,t] = out_char_index.item()
        if inp.shape[1] > chunk_len:
          inp = torch.cat((inp[:,1:], out_char_index.unsqueeze(0)), dim=1)
        else:
          inp = torch.cat((inp, out_char_index.unsqueeze(0)), dim=1)

    return outputs


def greedy_sampling_iter_transformer(model, x, num_chars, chunk_len, output_token):
    # x -- shape (batch, tokens in x)
    outputs = torch.zeros((1,num_chars))
    inp = x

    for t in range(num_chars):
        # b x l x onehot_char
        output = model(inp.long())[0,-1:]
        output = torch.softmax(output, dim=1)
        out_char_index = torch.argmax(output, dim=1)
        outputs[:,t] = out_char_index.item()
        if inp.shape[1] > chunk_len:
          inp = torch.cat((inp[:,1:], out_char_index.unsqueeze(0)), dim=1)
        else:
          inp = torch.cat((inp, out_char_index.unsqueeze(0)), dim=1)

    return outputs


## Transformer model training

With a correct implementation you should get sensible text generation results with the set parameters, however you should experiment with various parameters,
especially with the sequence length (chunk_len) used during training.

In [82]:
from tqdm import tqdm
import torch.optim as optim


#Sample parameters, use whatever you see fit.
batch_size = 256
chunk_len = 128
train_dataset = TextDataset(chunk_len=chunk_len)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=0)

input_dim = train_dataset.n_characters
output_dim = train_dataset.n_characters
learning_rate = 0.0006

model = TransformerSimple(chunk_len, input_dim, output_dim,batch_size)
model.train()
model.cuda()

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

epochs=10

for epoch in range(epochs):
    with tqdm(total=len(trainloader.dataset), desc ='Training - Epoch: '+str(epoch)+"/"+str(epochs), unit='chunks') as prog_bar:
        for i, data in enumerate(trainloader, 0):
            # inputs - shape (batch_size, chunk_len) - Tensor of vocabulary tokens
            inputs = data['input'].long()
            # labels - shape (batch_size, chunk_len) - Tensor of vocabulary tokens
            labels = data['target'].long()

            optimizer.zero_grad()
            outputs = model(inputs)
            target_t = labels
            loss = criterion(outputs.view(inputs.shape[0]*inputs.shape[1],-1),target_t.view(labels.shape[0]*labels.shape[1]))
            loss.backward()
            optimizer.step()
            prog_bar.set_postfix(**{'run:': "Transformer", 'lr': learning_rate,
                                    'loss': loss.item()
                                    })
            prog_bar.update(batch_size)

        # Intermediate text output
        sample_texts = ["What authority surfeits on",
                        "I say unto you, what he hath done famously, he did it to that end:",
                        "That in submission will return to us: And then, as we have ta'en the sacrament,"]
        output_token = torch.zeros(1,1).cuda()
        output_token[0,0] = train_dataset.n_characters-1
        print("Top-K sampling")
        for sample_text in sample_texts:
            sample_encoding = train_dataset.encode_text(sample_text)
            sample_input = Variable(sample_encoding).cuda().unsqueeze(0).long()

            #out_test= greedy_sampling_iter_transformer(model, sample_input, 400, chunk_len, output_token)[0]
            out_test= topk_sampling_iter_transformer(model, sample_input, 400, chunk_len, output_token)[0]
            out_char_index = out_test.long().detach().cpu().numpy()
            out_chars = sample_text+" "+"".join([train_dataset.all_characters[i] for i in out_char_index])
            print("----------------------------------------")
            print(out_chars)




torch.Size([128])


Training - Epoch: 0/10: 100%|█████████▉| 9984/10000 [00:12<00:00, 788.91chunks/s, loss=2.56, lr=0.0006, run:=Transformer]

Top-K sampling
----------------------------------------
What authority surfeits on  me s thit mere m s t tho s mind se thes hange we at ate te thin she my me that me se al ar t s my arorind we miche and,


Ant mathathe s ather touthear man me wathillerd alanghour thar tho tind s t stishe she s she t mant w s t s my wind tin an mathit s the thirs an meris s my shan mathou m t my, thour sthar tin my shanondeangr astheand, an s m sto m manderoreathe m m arenghat s al ser angh whe m
----------------------------------------
I say unto you, what he hath done famously, he did it to that end: 

I meshe me thasthathingout s w t wher sthise t me wh se te at as windsthas ar t t thale me me wice wirithas mand t then ar h theantheshand antind s ticoule s al st t me aser me my me st ateren t tith w my t m st masthe thingen t thang ale ast win se there thal mind thin t me weroton my meast wingous me s me thato alle me mear th my mash thotis m athiles w an whit t s we atind anor thas s mard th


Training - Epoch: 0/10: 10240chunks [00:18, 552.91chunks/s, loss=2.56, lr=0.0006, run:=Transformer]                      


----------------------------------------
That in submission will return to us: And then, as we have ta'en the sacrament,  te t s mit marererd theard at wash allllerind ateng wichan m thiche athise myo tis she wis sh s the thot m se t th he st s t m andinds se as an wat s wear s weritous hous whour wil sise t s manthashanouris my t wind s t an me angrenge ase thill me shean s wire m mar the an t ate t we alisheanth withit t my, at t w t w win theatheat weser tharishe wis sear mer and as sh wealesease we are an tharer


Training - Epoch: 1/10: 100%|█████████▉| 9984/10000 [00:12<00:00, 803.01chunks/s, loss=2.41, lr=0.0006, run:=Transformer]

Top-K sampling
----------------------------------------
What authority surfeits on  t ho tithe thild s s se the ate t ore on as mere thend thir hesther thin thor the me t mas t soullle antom al ass athilllir s t boul alll seathotong bust thor me sthas t s sthes tertin tiste bareands t s thin t s asthe seanes ar te athaten tistound t sthile thenting mase t mat the mind t tous stheate t m sen toull te buner as t ase at ar ane tengor thor theant bend athesthe s t alle tin be bre be
----------------------------------------
I say unto you, what he hath done famously, he did it to that end: 
And alou to he st mingo thow mate s thes,
I alerd, her thin hand bund ald myore me hond be tous bo atho s angr misther tomyon ar bu stounde t sse alon men me stho t be bestis be ther tom t sthes be se s mango and silis m tent blllin bre to se ters bl ates t thoth miler ber s an berst stondind me tour sthe an s thathe thotire thatothingr mane t bere stis s
Anengharert s se s thes test man sthesend


Training - Epoch: 1/10: 10240chunks [00:18, 553.27chunks/s, loss=2.41, lr=0.0006, run:=Transformer]                      


----------------------------------------
That in submission will return to us: And then, as we have ta'en the sacrament,  our tond totit t anors,
Thist t his s maliril ato to s t malir thang thast athit at at thit and,
The tir t atothande br s st be t ano te s s sthorers be arere ss serire the tors an to tis my
Allind thalang as, ass br bate anthas,
Thare stis meres, as my.
Malend,

My,

Whithit t my s al by buthe s te at stissthing t mound t at mo tor ath moutes tomouthanou t tomend ar athir t sthasst my
Ard me te 


Training - Epoch: 2/10: 100%|█████████▉| 9984/10000 [00:12<00:00, 785.22chunks/s, loss=2.25, lr=0.0006, run:=Transformer]

Top-K sampling
----------------------------------------
What authority surfeits on  and has there thall he
I har hean the the marer s alore wins
An the s alome she mas sthind m tin hen alis,
s sthallll titime ts mange ss areates s maro to s andis, the ss thereeres,
An heat he t hiromat, meng t s tharin s sheas merere at merare s heseange st heares s
Thasen me my thitheas athe t as s has han ssss hindares here matile me s tistis ande t mange ands
Thitere s, therat he s wind shate
----------------------------------------
I say unto you, what he hath done famously, he did it to that end:  han wimar then
Trir se thus thasth te hangre hous,
Anche whe s himen heealll angst hat sstis, t areand, te atimitis ang, tes h witime s
Angheand t thalale, me thouse s, hent healis we se as s hanengan tere s
s anote therar he marere m atomendes an at hear s thy sheeeround heas m t mitin hisend, tes, athe henthes,
Thean t henounghe ss has mirinds mar mas thanghe, ssean te as an thistes my marith t


Training - Epoch: 2/10: 10240chunks [00:18, 549.35chunks/s, loss=2.25, lr=0.0006, run:=Transformer]                      


----------------------------------------
That in submission will return to us: And then, as we have ta'en the sacrament, 
To the hast the thass s andeangenger ath ar wo tores ss hale s tomaran ar the than sterangh man he to my mishin th thithe this
Ther mathe se he shende thear me te here hinthareng s as thes aller ss thes thares s s, tatiserest
The t thit theat mas tharist at hasssest t s shon teanthe s,
Thas s wofeat alle the man tinghe s an are ales s meather merongere aresterte ar se athearer s seatere st
s andi


Training - Epoch: 3/10: 100%|█████████▉| 9984/10000 [00:12<00:00, 786.61chunks/s, loss=2.12, lr=0.0006, run:=Transformer]

Top-K sampling
----------------------------------------
What authority surfeits on enth and
tourd sist hathis tound, wis toundesen
This to henow arith he tof my hours aves tande theall
Whave hise ast ofive an beresisesedstes teneed
Thimous ouge ancheneser t oreeale thyof henener broust
Ant hind therar andis haver t seis ancard ous then arise
Ind, ane aneses o thof ha anour hanes t bea o ondeed
I thaner,
Andis thive and an tashis hean athe bleerent teease

Then asend anonened ato
----------------------------------------
I say unto you, what he hath done famously, he did it to that end: 
I's mat than to she ave soulld sownt our han sim house
The hat seange t anere anen henomenged, tousstht tathat
Tan buse andound oustont aver hyowest heree tee ond anoullle
The oune ount onowst then beartessse aris t thasees t
Wit heanderd hive aves,
Henghy speatye omere avind, ben beset ondsthenen
Wese t sispeangsty anoued t oun thind heaneds of theeas
Thesser bene tiste aras tono toulld, then he


Training - Epoch: 3/10: 10240chunks [00:18, 550.26chunks/s, loss=2.12, lr=0.0006, run:=Transformer]                      


----------------------------------------
That in submission will return to us: And then, as we have ta'en the sacrament, 
Thy terare his hingne tated steees oull hit.

Thas an t hyourest asp thanded, heit ath bee arer beaste heand.
I anand I hanough t thoouer hthan t heist,
Isthen t tinoule be bratowes and tenous t tashiss ouses
Whousted t sono himeer hist ate teee tyoullds.

Andis seast ano hese and heanes, hand as spave beave
Issth oueend heand ave hee tinde outous and henendes
Ise as beard ome and ounes an beaned


Training - Epoch: 4/10: 100%|█████████▉| 9984/10000 [00:12<00:00, 794.16chunks/s, loss=2.04, lr=0.0006, run:=Transformer]

Top-K sampling
----------------------------------------
What authority surfeits on tred, and hor as whavengind and,
To min histe the sen mary my hereser ontand weald a thase where
And heand the him touse hthaves omoun onchend,
Wis an hore owase tt heenger oneansthios owar thast.


CLAUCELIUS:
And my hispessesen ta ofee hise as myon a heant.
An Reass my anousent his splean an toupse teand
A o myshe my hise hye trean ame hindseng trinces
I sthentes hourd ano of him thissess.
This 
----------------------------------------
I say unto you, what he hath done famously, he did it to that end: 
Thist theard ay severences hee areass offortheres.

PERIO:
Thavers to hea m anof hyond sheave ad a shan ands
Aspe ontr theay ass andserd and anoter and a thises
s hea tea he thee as tra thime omppe send ifracesss
Tofore ano indithis are o theasest.

An ELICHABET:
Ay whor hyou avengnce tooman t hofee her thofe
I him sthee tere t o hear the ofee sof hofour,
An hyone avasht han heand shy the team t 


Training - Epoch: 4/10: 10240chunks [00:18, 552.59chunks/s, loss=2.04, lr=0.0006, run:=Transformer]                      


----------------------------------------
That in submission will return to us: And then, as we have ta'en the sacrament, 
In and the herees and there hear tinges and
Ast hind theass herar and thendes hand amin owend asthere
Wo he hast asear te ofend trend tofee arrrthe toue hear
Tomenes ane andy ofoned and honomer tasps an toone,
A han wim hy shin own ane theate own hent harid ome
And sthen trimenes on hea ast hand he and ond ares
And weand the a hend anen one hean here heave,
And winde hise ass a hand therer teas h


Training - Epoch: 5/10: 100%|█████████▉| 9984/10000 [00:12<00:00, 779.96chunks/s, loss=1.88, lr=0.0006, run:=Transformer]

Top-K sampling
----------------------------------------
What authority surfeits on er ontard,
As well to my haverthouns, will and sperss,
But the shee strest themst ander'd ther wis my drecot.


ClORUMNIO:

My whest word art,
I malllty truesty, annd toursessts, o stof willitstst
I trevestt ave a themppe on to tof splo trisest
Who the the that at ande aswnd allo te heatst
Wis that at ournchts oforrd, werrd heldsst tinde.
Wit My, trerry thou herrd atherts orer hand.
Whowndsh shall
----------------------------------------
I say unto you, what he hath done famously, he did it to that end: 
Thaven, to warder all the ther my serrdest,
Whill word he hom wnentst ofo thas ourt,
Buntt are won withelld ouch a a alll tontsssend feertth.

What Clisttt, ay mond sheat theence wend
To mance otatte in tound, the wit havencht the imard,
Whe his ade t soulld thanckst at o wontht wa thith
The sheaved d sprestst oundsir of tof tree t
Who hy sthent ourth theer at, ta toffte to whee
Whertst were toff


Training - Epoch: 5/10: 10240chunks [00:18, 548.88chunks/s, loss=1.88, lr=0.0006, run:=Transformer]                      


----------------------------------------
That in submission will return to us: And then, as we have ta'en the sacrament, 
As ward wand my the so deved arnd word this at
that t hyo herld an twellt o heastst are to this
Then peat it thengh tree tthath tho and tallly a a a dond.
A Stharrvou dof ar ane as thist and weastct trestst,
Withtch wis thingnght thisst trit o spevichst
Whath arid wand wandlly by thee t sha oulllt and,
I wear that serrtt ifoong honghird,
Whit, thought tinde a an and terrt windlst,
I my worrd ashi


Training - Epoch: 6/10: 100%|█████████▉| 9984/10000 [00:12<00:00, 782.10chunks/s, loss=1.81, lr=0.0006, run:=Transformer]

Top-K sampling
----------------------------------------
What authority surfeits on  ans bright,
As bure sone a with him hone off meent of mye.

KING RICHARD III:
Ay me sight the sake would has wand them aterre and thee
A spoffftires on and a wa whis one of hear,
Whath spare tandy ale tharm he arm on ast teall and
Off wit of at a mastcend your toundse, aye beash'd,
Stird alo of an that and a whan bearnce ound.


KING RICHARD:
HIIf h mam he seam orme wande thit tooo ommme.


DUCHO
----------------------------------------
I say unto you, what he hath done famously, he did it to that end: 
The wort wall a with tho well woul a be so winded.
And thy sore wende, are it ande ond thele torld
And amally ont amirtty allid,
On thar as thame wend ofe would weay as heasse
I wit ore wisth ounghtiong on thame orreanth,
Thereint off the in tomee,--beancht and win hath at tremante, as wouch's and the
Whe that heave offf and thime is tond outht an
theme oforthe they was witheres oreanth wit ofthi


Training - Epoch: 6/10: 10240chunks [00:18, 551.24chunks/s, loss=1.81, lr=0.0006, run:=Transformer]                      


----------------------------------------
That in submission will return to us: And then, as we have ta'en the sacrament, 
We and the whan she sporeast and a brown
As should tree att arert thonk an weante ist,
Inde tomee think of thee our ar toff theem,
Ang wit as were as a the and ofell on oust.
There hase ave off hince theat, ore and and thome
As he the shee weas ore aspireniciarions of head,
That I stree ivow an a as a whybeare armace ount,
Whit wit are on tof take amirt ore amppost
Of ther hane our an our thasst 


Training - Epoch: 7/10: 100%|█████████▉| 9984/10000 [00:12<00:00, 799.49chunks/s, loss=1.7, lr=0.0006, run:=Transformer] 

Top-K sampling
----------------------------------------
What authority surfeits on  and to sell.
Botcke you: there steervet me shall, attered,
And to me sears. O a sare, to this the hearth,
And ware the tought are and the are of hof theest,
And the toure though has bow oran and and,
And be the are ofther ache areasts it heared,
Orether wince hims treand, ond an beatther,
I seaven, or and theathen, thoushe he head,
Of with wife to ale ove ore other our hand,
And seay tallle to an
----------------------------------------
I say unto you, what he hath done famously, he did it to that end: 
Who tell that wartedere to said, was warent to my son.
Then ane welll so to the an the sare, and thows here
Tould tongumes, and be thowse illes wead,
In wan we has andst wheat, and the wering hasst.
Whond an woult, then, I ast and word at alay ond this
Tones all the and on steeed to shand and tand the telll
Ouse to and so heaved hinghe if an worth his wonest.
Arot:
I welll, wear in windomer are o


Training - Epoch: 7/10: 10240chunks [00:18, 553.25chunks/s, loss=1.7, lr=0.0006, run:=Transformer]                      


----------------------------------------
That in submission will return to us: And then, as we have ta'en the sacrament, 
And here thangets of hanguereds, sorret heart.
I woulll weare, and, at I thoust hand ofing,
That thims soraninst arught, oun,
This at bloookes ane tofend theares of thee.


BOLINDEO:
Wetweren at if anccame the and,
Wand the weare how sarrre owes shille awe sto
Tond selling sealings, and toue therencand tates,
Werre the are ign torath at herllieds and and
As and are thim s shayst fore, at warere,



Training - Epoch: 8/10: 100%|█████████▉| 9984/10000 [00:12<00:00, 711.29chunks/s, loss=1.55, lr=0.0006, run:=Transformer]

Top-K sampling
----------------------------------------
What authority surfeits on  her speass folk my shall mother,
Where may statther both seat me strans
Thy be sensen the for shalt my son or on thim
Were hor onour athe andverting them our tare.
And das speed on of sherms,
O sor and moves ore helpser are ims alll as as
ddarrd im, the toumb dofth and the opes,
O of the thouse shire and his onste of of shigld.


BRAMPSTEY:
Whas dredged.



MAUCAREN:
Itheren, is this this denerat
----------------------------------------
I say unto you, what he hath done famously, he did it to that end: 
Which thour stand mother shir his stand all for
The chird toons the oward thy ound spent on the
Tourn sonstran o hither.'s I weard as as onde wofes.
Thave and on andird thence ore spirs.
The desep irems as as offfffircerss
To mene at our of heasps your ourse.


GRUMIO:
Nay, meay mead is the them, forsh, treminds,
Is denem othing of the thare.


PENETER:

Weshy has denombed are of athishen ass a t


Training - Epoch: 8/10: 10240chunks [00:18, 544.98chunks/s, loss=1.55, lr=0.0006, run:=Transformer]                      


----------------------------------------
That in submission will return to us: And then, as we have ta'en the sacrament, 
There staing them many the sonss a my shar.

BRUCKINGHANG:
Welll, so mee spropist of mis theme on tof oust,
The and ormes aragand as and as a a devidin
I the out dong one atrmisted.


BRUCKENGHARORD:
As I darrm, me as a off them.


Prevent:
Where and arm as offftaicess.


PROMPERO:

MEddgham, this

Wheren:
Weld marrd to to men at a ass thygarrl and
Is say to ammpps then is teld.


DUCHESSSSS:
Ay,


Training - Epoch: 9/10: 100%|█████████▉| 9984/10000 [00:12<00:00, 780.70chunks/s, loss=1.56, lr=0.0006, run:=Transformer]

Top-K sampling
----------------------------------------
What authority surfeits on 
That that her but and the power thy, both stand.

KING RICHARD III:
Ay, sir? his those!

MERCUTIUS:
My darrre toff thy your
The sowe our of or acclliaie.


MESCUMIUS:
Nay am send and thee terell impond,
I would womb an our of our tresth,
My anny wish our athen his are the bouth
Are wencesss, wis an my trelfiaght.


Kerspeat:
Here thare ond tingle:
I am wish our are our hear at thouse then,
Our se
----------------------------------------
I say unto you, what he hath done famously, he did it to that end: 
I thank his me show hom his her hurse?

First My Moscenaly:
And, werat if atwirst thim angreroan thim.


MARARGHEN:
Welll, I mearr, the I have I wom woesh:
My hare iss freather a wourd with as though.


PEOMNEY:
Were thown:
Whis thate are in thouse bacck at in itre.


MORCUSTIO:
Wenthing ano man ore of acccissaly?

MERCUMIO:
My dourtur once torrme, and word I'lll again.


KING RICHARD:
What I wer


Training - Epoch: 9/10: 10240chunks [00:18, 542.71chunks/s, loss=1.56, lr=0.0006, run:=Transformer]                      

----------------------------------------
That in submission will return to us: And then, as we have ta'en the sacrament, 
To so murder of a comes man my breach,
Where a slive an ass ime.


MENENENIUS:
I am then it tof and that and.


KING RWAMP:
Witch, is ir myow lord merrrie it.


KIRGHAM:
Misshaged is merrce it the and in tolly.


KINGHARD IIII:
No welll I, heav I well I am and agrry.


KARINGHARD:
And wenre tire:
Then ange thare I tount thim sence ome
And senatiur an werre there.


GLOUCOMESTE:
Near, there ance o





## Text sampling - Transformers


In [85]:
sample_text = "To be or not to be?"
sample_encoding = train_dataset.encode_text(sample_text)
sample_input = Variable(sample_encoding).cuda().unsqueeze(0).long()

out_test= greedy_sampling_iter_transformer(model, sample_input, 400, chunk_len, output_token)[0]
# out_test= topk_sampling_iter_transformer(model, sample_input, 400, chunk_len, output_token)[0]
out_char_index = out_test.long().detach().cpu().numpy()
out_chars = sample_text+" "+"".join([train_dataset.all_characters[i] for i in out_char_index])
print("----------------------------------------")
print(out_chars)


----------------------------------------
To be or not to be? 

MENENIUS:
I will not the so the so the so the some.

MENENIUS:
I will the some of the the some of the some
To then are of the our an the our are of thee
The wear our an the our are of the our are
The our off our the our an the our are of thee
And the and the our an the our are of thee
And the and the our an the our are of thee
And the and the our an the our are of thee
And the and the our an the
