## Text Generation with PyTorch
Generate Text based on Shakespeare text

### Import libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import Counter
import os
from argparse import Namespace

In [20]:
flags = Namespace(
    train_file='Shakespeare.txt',
    seq_size=32,
    batch_size=16,
    embedding_size=64,
    lstm_size=64,
    gradients_norm=5,
    initial_words=['I', 'am'],
    predict_top_k=5,
    checkpoint_path='checkpoint',
)

### Process the raw data
Steps:
- Input dataset is a text data file
- Text should be splitted into word tokens to train a word-based model
- Convert word tokens into integer indices. These will be the input to the network
- Train a mini-batch each iteration

In [41]:
def get_data_from_file(train_file, batch_size, seq_size):
    with open(train_file, 'r') as f:
        text = f.read()
    text = text.split()
    
    word_counts = Counter(text)
    sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
    int_to_vocab = {k: w for k,w in enumerate(sorted_vocab)}
    vocab_to_int = {w: k for k,w in int_to_vocab.items()}
    n_vocab = len(int_to_vocab)
    
    print('Vocabulary size', n_vocab)
    
    int_text = [vocab_to_int[w] for w in text]
    num_batches = int(len(int_text) / (seq_size * batch_size))
    in_text = int_text[:num_batches * batch_size * seq_size]
    
    #target data for the network to learn
    #the target of each input word will be its consecutive word
    #shift the whole input data to the left by one step
    out_text = np.zeros_like(in_text)
    out_text[:-1] = in_text[1:]
    out_text[-1] = in_text[0]
    in_text = np.reshape(in_text, (batch_size, -1))
    out_text = np.reshape(out_text, (batch_size, -1))
    return int_to_vocab, vocab_to_int, n_vocab, in_text, out_text

In [42]:
##generate batches for training
def get_batches(in_text, out_text, batch_size, seq_size):
    num_batches = np.prod(in_text.shape) // (seq_size * batch_size)
    for i in range(0, num_batches * seq_size, seq_size):
        yield in_text[:, i:i+seq_size], out_text[:, i:i+seq_size]

In [14]:
int_to_vocab[0]

'the'

In [15]:
in_text[:10]

array([[1670,  966, 1671, ...,  414,    7,   49],
       [  46,   10,   26, ...,    1,  732,    7],
       [ 190,    0,  239, ..., 1187,   10,   29],
       ...,
       [   4, 3102,   19, ...,   58,  328,    1],
       [3305, 3306,  145, ...,    3,   45,  124],
       [   1, 3532, 3533, ...,   42, 3731,    5]])

In [16]:
out_text[:10]

array([[ 966, 1671,   93, ...,    7,   49,   46],
       [  10,   26,    0, ...,  732,    7,  190],
       [   0,  239,  554, ...,   10,   29,   56],
       ...,
       [3102,   19,   49, ...,  328,    1, 3305],
       [3306,  145,   23, ...,   45,  124,    1],
       [3532, 3533,    0, ..., 3731,    5,  399]])

### Model
- create a subclass of torch.nn.Module
- define the necessary layers in __init__ method
- implement the forward pass within forward method

In [48]:
class RNNModule(nn.Module):
    def __init__(self, n_vocab, seq_size, embedding_size, lstm_size):
        super(RNNModule, self).__init__()
        self.seq_size = seq_size
        self.lstm_size = lstm_size
        self.embedding = nn.Embedding(n_vocab, embedding_size)
        self.lstm = nn.LSTM(embedding_size,
                            lstm_size,
                            batch_first=True)
        self.dense = nn.Linear(lstm_size, n_vocab)
        
    #will take an input sequence and the previous states 
    #and produce the output together with states of the current timestep
    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.dense(output)

        return logits, state
    
    # reset states at the beginning of every epoch
    def zero_state(self, batch_size):
        return (torch.zeros(1, batch_size, self.lstm_size),
                torch.zeros(1, batch_size, self.lstm_size))    

### Loss

In [44]:
#a loss function and a training op
def get_loss_and_train_op(net, lr=0.001):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    return criterion, optimizer

### Training
- get training data
- create the network
- create loss and training op
- for each epoch, we will loop through the batches to compute loss values and update network’s parameters
- Call the train() method on the network’s instance (it will inform inner mechanism that we are about to train, not execute the training)
- Reset all gradients
- Compute output, loss value, accuracy, etc
- Perform back-propagation
- Update the network’s parameters

In [50]:
def predict(net, words, n_vocab, vocab_to_int, int_to_vocab, top_k=5):
    net.eval()

    state_h, state_c = net.zero_state(1)

    for w in words:
        ix = torch.tensor([[vocab_to_int[w]]]).to(torch.int64)
        output, (state_h, state_c) = net(ix, (state_h, state_c))
    
    _, top_ix = torch.topk(output[0], k=top_k)
    choices = top_ix.tolist()
    choice = np.random.choice(choices[0])

    words.append(int_to_vocab[choice])
    
    for _ in range(100):
        ix = torch.tensor([[choice]]).to(torch.int64)
        output, (state_h, state_c) = net(ix, (state_h, state_c))

        _, top_ix = torch.topk(output[0], k=top_k)
        choices = top_ix.tolist()
        choice = np.random.choice(choices[0])
        words.append(int_to_vocab[choice])

    print(' '.join(words))

In [52]:
def main():
    int_to_vocab, vocab_to_int, n_vocab, in_text, out_text = get_data_from_file(
        flags.train_file, flags.batch_size, flags.seq_size)
    
    net = RNNModule(n_vocab, flags.seq_size,
                    flags.embedding_size, flags.lstm_size)  
    
    criterion, optimizer = get_loss_and_train_op(net, 0.01)
    
    iteration = 0
    
    if not os.path.exists(flags.checkpoint_path):
        os.mkdir(flags.checkpoint_path)
    
    for e in range(50):
        batches = get_batches(in_text, out_text, flags.batch_size, flags.seq_size)
        state_h, state_c = net.zero_state(flags.batch_size)
        
        for x,y in batches:
            iteration += 1
            # Tell it we are in training mode
            net.train()
            # Reset all gradients
            optimizer.zero_grad()
            
            x = torch.tensor(x).to(torch.int64)
            y = torch.tensor(y).to(torch.int64)
            
            logits, (state_h, state_c) = net(x, (state_h, state_c))
            loss = criterion(logits.transpose(1, 2), y)
            
            state_h = state_h.detach()
            state_c = state_c.detach()
            
            loss_value = loss.item()

            # Perform back-propagation
            loss.backward()
            
            # gradient clipping
            _ = torch.nn.utils.clip_grad_norm_(
                net.parameters(), flags.gradients_norm)

            # Update the network's parameters
            optimizer.step()
            
            # print loss value to console
            if iteration % 100 == 0:
                print('Epoch: {}/{}'.format(e, 200),
                      'Iteration: {}'.format(iteration),
                      'Loss: {}'.format(loss_value))            
            
            # the model generate some text 
            if iteration % 1000 == 0:
                predict(net, flags.initial_words, n_vocab,
                        vocab_to_int, int_to_vocab, top_k=5)
                torch.save(net.state_dict(),
                           'checkpoint/model-{}.pth'.format(iteration))

In [53]:
main()

Vocabulary size 5232
Epoch: 2/200 Iteration: 100 Loss: 6.114497184753418
Epoch: 5/200 Iteration: 200 Loss: 4.590234756469727
Epoch: 7/200 Iteration: 300 Loss: 3.6697957515716553
Epoch: 10/200 Iteration: 400 Loss: 3.264127254486084
Epoch: 12/200 Iteration: 500 Loss: 2.620922803878784
Epoch: 15/200 Iteration: 600 Loss: 2.296342134475708
Epoch: 17/200 Iteration: 700 Loss: 2.0858142375946045
Epoch: 20/200 Iteration: 800 Loss: 1.9138472080230713
Epoch: 23/200 Iteration: 900 Loss: 1.5676251649856567
Epoch: 25/200 Iteration: 1000 Loss: 1.297070026397705
I am scarcely twins, when an absolute year. This, for that purpose. The Merchant for refusing and to be legally to be a on time the First Part perhaps from the Poet stood to be thought in certain property; thing but an old English paraphrase as a poet. to the demands by its origin, and there can well shown them all I felt, found it on its victims. been acted at Whitehall has vitiated but at that year, they are not then in all been beholden to 