# Training a text generator with Word-RNN 

In this notebook I use the code from NLP Week 7 to load my trained word-level RNN and generate text with it.

In [10]:
import torch
import random

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from torch.distributions import Categorical

In [11]:
device = 'cpu'

**Setting the hyperparameters:**

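Whilst working with this notebook I kept going back to the cell below to change the hidden state size and number of layers. These must match the values used to train the saved model, otherwise its weights will not load into the network.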

In [43]:
hidden_size = 512   # size of hidden state
num_layers = 3      # number of layers in LSTM layer stack
gen_seq_len = 100   # number of words to generate
temperature = 1.0   # sampling temperature: below 1 is more conservative, above 1 more random
load_path = "/Users/loiskelly/Documents/GitHub/LoisNLPProject/all_data/Beatles Word RNN Model 5.pt"
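
The `temperature` only changes anything if it rescales the model's raw output scores (logits) *before* the softmax. `torch.distributions.Categorical` renormalises whatever probabilities it is given, so dividing the softmax output by a constant leaves the sampling distribution exactly as it was. The pure-Python sketch below shows the difference. The toy logits are hypothetical, and `softmax` and `normalise` are illustrative helpers rather than library functions.

```python
import math


def softmax(logits, temperature=1.0):
    """Convert raw scores into probabilities, dividing by a temperature first."""
    scaled = [x / temperature for x in logits]
    top = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def normalise(weights):
    """Mimic what a categorical distribution does with unnormalised weights."""
    total = sum(weights)
    return [w / total for w in weights]


logits = [2.0, 1.0, 0.1]  # hypothetical scores for a three-word vocabulary

probs = softmax(logits)
# Dividing the *probabilities* by a constant changes nothing once renormalised
same = normalise([p / 10 for p in probs])

# Dividing the *logits* by the temperature does reshape the distribution
cool = softmax(logits, temperature=0.5)  # sharper: favours the top word
hot = softmax(logits, temperature=10)    # flatter: closer to uniform

print([round(p, 3) for p in probs])
print([round(p, 3) for p in same])
print([round(p, 3) for p in cool])
print([round(p, 3) for p in hot])
```

In the notebook's cells this corresponds to dividing `output` by `temperature` before calling `F.softmax`. Passing the scaled scores as `Categorical(logits=...)` would be equivalent.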

***Defining the network***

In [44]:
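class RNN(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, num_layers):
        super(RNN, self).__init__()
        # Word indices -> dense vectors (the embedding dimension equals the vocabulary size here)
        self.embedding = nn.Embedding(input_size, input_size)
        # Stacked LSTM; expects input of shape (seq_len, batch, input_size)
        self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        # Hidden state -> one score (logit) per word in the vocabulary
        self.decoder = nn.Linear(hidden_size, output_size)
    
    def forward(self, input_batch, hidden_state):
        embedding = self.embedding(input_batch)
        output, hidden_state = self.rnn(embedding, hidden_state)
        output = self.decoder(output)
        # Detach (h, c) so gradients do not flow back through earlier batches during training
        return output, (hidden_state[0].detach(), hidden_state[1].detach())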
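
`hidden_size` and `num_layers` fix the shapes of the LSTM weight matrices. This is why they must match the values used when the checkpoint was saved; otherwise `load_state_dict` fails with size-mismatch errors. The hypothetical helper `count_parameters` below counts the parameters of the `RNN` class from its hyperparameters. It follows the weight and bias shapes listed in the `nn.LSTM` documentation (four gates per layer, two bias vectors). The vocabulary size of 4000 in the usage line is a made-up placeholder, not the real vocabulary size.

```python
def count_parameters(vocab_size, hidden_size, num_layers):
    """Count the trainable parameters of the RNN class defined above.

    Assumes the layers used in this notebook: an nn.Embedding whose embedding
    dimension equals vocab_size, a unidirectional nn.LSTM with biases, and an
    nn.Linear decoder.
    """
    embedding = vocab_size * vocab_size

    lstm = 0
    for layer in range(num_layers):
        # Layer 0 reads the embeddings; later layers read the previous layer's hidden state
        layer_input = vocab_size if layer == 0 else hidden_size
        weight_ih = 4 * hidden_size * layer_input   # input, forget, cell and output gates
        weight_hh = 4 * hidden_size * hidden_size
        biases = 2 * 4 * hidden_size                # bias_ih and bias_hh
        lstm += weight_ih + weight_hh + biases

    decoder = hidden_size * vocab_size + vocab_size  # weight matrix + bias
    return embedding + lstm + decoder


# Hypothetical vocabulary size, for illustration only
print(count_parameters(vocab_size=4000, hidden_size=512, num_layers=3))
```

Because the embedding dimension equals the vocabulary size, the embedding layer alone holds `vocab_size ** 2` parameters. A smaller, fixed embedding dimension is the more common choice.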

***Loading the trained network from the checkpoint:***

This is from the notebooks:

[PyTorch tensors](https://pytorch.org/docs/stable/tensors.html) have been designed to work in almost exactly the same way as [numpy arrays](https://numpy.org/doc/stable/reference/generated/numpy.array.html).

In [45]:
# map_location lets a checkpoint saved on a GPU be loaded onto the CPU
checkpoint = torch.load(load_path, map_location=device)

# Load word_to_ix and ix_to_word dictionaries from checkpoint file
word_to_ix = checkpoint['word_to_ix']
ix_to_word = checkpoint['ix_to_word']

# Calculate vocab size
vocab_size = len(word_to_ix)

# Instantiate RNN
rnn = RNN(vocab_size, vocab_size, hidden_size, num_layers).to(device)

# Load model weights from checkpoint file 
rnn.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>
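
The checkpoint stores the model weights together with the two lookup tables used during training, so the indices the model predicts can be turned back into words. How those tables were built is not shown in this notebook. A common approach, sketched here as an assumption, is to give each distinct whitespace-separated word the next free index and then invert the mapping. The `toy_` names and `build_vocab` are hypothetical, chosen so the cell does not overwrite the real `word_to_ix` and `ix_to_word`.

```python
def build_vocab(corpus):
    """Build word->index and index->word lookups from a whitespace-tokenised corpus."""
    word_to_ix = {}
    for word in corpus.split():
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)  # next free index
    ix_to_word = {ix: word for word, ix in word_to_ix.items()}
    return word_to_ix, ix_to_word


# Hypothetical toy corpus; the real tables come from the checkpoint
toy_word_to_ix, toy_ix_to_word = build_vocab("the cat sat on the mat")
toy_vocab_size = len(toy_word_to_ix)
print(toy_word_to_ix)
print(toy_vocab_size)
```

Using `len(word_to_ix)` as the vocabulary size, as the notebook does, only works if the indices run contiguously from 0 to `len(word_to_ix) - 1`. This construction guarantees that.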

#### Generate a random sequence

In [46]:
with torch.no_grad():
    hidden_state = None
    
    # Pick a random starting word index from the vocabulary
    random_start = random.choice(list(word_to_ix.values()))
    
    # Convert to a PyTorch tensor of shape (seq_len=1, batch=1), as nn.LSTM expects
    # For more info on unsqueeze see: https://stackoverflow.com/questions/57237352/what-does-unsqueeze-do-in-pytorch
    input_seq = torch.tensor(random_start, dtype=torch.int64, device=device)
    input_seq = input_seq.unsqueeze(0).unsqueeze(0)

    # Iterate over our sequence length
    for i in range(gen_seq_len):
        # Forward pass
        output, hidden_state = rnn(input_seq, hidden_state)
        
        # Scale the logits by the temperature, turn them into probabilities and sample a word
        probs = F.softmax(torch.squeeze(output) / temperature, dim=0)
        dist = Categorical(probs)
        index = dist.sample()
        
        # Print the sampled word
        print(ix_to_word[index.item()], end=' ')
        
        # Next input is current output
        input_seq[0][0] = index.item()

of all comes Well I feel happy 

just it did you girl you hurt me if you know how hard to understand cause Its been unhappy and hoping Im changing it up your feet back home things in love never be alright? Alright? Alright? You better The girl Here there It feels like Ive waited only goes on your head Bang bang Maxwells silver hammer Came very long Please dont know my own mind I could be sad just of ten its all I said move over Beethoven Roll over Even throws it calls me oh yeh Imagine Im with 
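
The generation loop above is autoregressive. The model only ever predicts one word, and that sampled word is fed back in as the next input while `hidden_state` carries the context forward. The sketch below reproduces the same feedback loop in plain Python, with a hypothetical bigram table (`toy_model`) standing in for the RNN and a hypothetical `generate` helper. Unlike the LSTM, this toy model conditions only on the previous word.

```python
import random


def generate(next_word_weights, start_word, length, rng):
    """Autoregressive sampling: each sampled word becomes the next input.

    next_word_weights maps a word to a {candidate: weight} dict and stands in
    for the RNN's softmax output (a hypothetical toy model).
    """
    words = []
    current = start_word
    for _ in range(length):
        candidates = next_word_weights[current]
        # Sample the next word in proportion to its weight
        current = rng.choices(list(candidates), weights=list(candidates.values()))[0]
        words.append(current)
    return words


toy_model = {
    "love": {"me": 3, "you": 5},
    "me": {"do": 1, "love": 2},
    "you": {"love": 4, "do": 1},
    "do": {"love": 1},
}

rng = random.Random(0)  # seeded so the run is repeatable
print(" ".join(generate(toy_model, "love", 8, rng)))
```

As in the notebook, the starting word itself is not part of the returned sequence; only the sampled words are.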

***Map string to indexes***

In [47]:
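def map_str_to_ix(input_str):
    """Convert a whitespace-separated string into a list of vocabulary indices, skipping unknown words."""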
    index_list = []
    for word in input_str.split():
        ix = word_to_ix.get(word, None)
        if ix is not None:
            index_list.append(ix)
        else:
            print(f'The word {word} is not in the dictionary')
    return index_list
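
`map_str_to_ix` splits only on whitespace and uses exact dictionary lookups. This means matching is case- and punctuation-sensitive: `'Hello'`, `'hello'` and `'hello,'` are three different keys. The sketch below is a self-contained variant that returns the skipped words instead of printing them, plus an inverse mapping for checking the result. Both helpers and the toy vocabulary are hypothetical.

```python
def map_str_to_ix_checked(input_str, word_to_ix):
    """Like map_str_to_ix, but returns the skipped words instead of printing them."""
    indices, unknown = [], []
    for word in input_str.split():
        if word in word_to_ix:
            indices.append(word_to_ix[word])
        else:
            unknown.append(word)
    return indices, unknown


def map_ix_to_str(index_list, ix_to_word):
    """Inverse lookup: turn a list of indices back into a space-separated string."""
    return " ".join(ix_to_word[ix] for ix in index_list)


# Hypothetical toy vocabulary; the real one comes from the checkpoint
toy_word_to_ix = {"Why": 0, "hello": 1, "there": 2}
toy_ix_to_word = {ix: word for word, ix in toy_word_to_ix.items()}

indices, unknown = map_str_to_ix_checked("Why Hello there!", toy_word_to_ix)
print(indices, unknown)
print(map_ix_to_str([0, 1, 2], toy_ix_to_word))
```

If every word in the prompt is unknown, the index list is empty and the model has nothing to condition on, so it is worth checking `unknown` before generating.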

Creating the index list from an input string.
This is the cell I came back to and updated with a new input string for each test.

In [52]:
input_str = 'Why hello there'
index_list = map_str_to_ix(input_str)
print(f'Our list is: {index_list}')

Our list is: [984, 2158, 3582]


***Generate from a user-supplied starting sequence***

In [53]:
with torch.no_grad():
    hidden_state = None

    # Condition the model on the starting sequence, one word at a time
    for ix in index_list:
        
        # Print current input word
        print(ix_to_word[ix], end=' ')

        # Convert the current word index to a PyTorch tensor of shape (seq_len=1, batch=1)
        # For more info on unsqueeze see: https://stackoverflow.com/questions/57237352/what-does-unsqueeze-do-in-pytorch
        input_seq = torch.tensor(ix, dtype=torch.int64, device=device).unsqueeze(0).unsqueeze(0)
        
        # Update the hidden state; the output for the last prompt word is used for the first sample
        output, hidden_state = rnn(input_seq, hidden_state)


    # Iterate over our sequence length
    for i in range(gen_seq_len):
        # Scale the logits by the temperature, turn them into probabilities and sample a word
        probs = F.softmax(torch.squeeze(output) / temperature, dim=0)
        dist = Categorical(probs)
        index = dist.sample()
        
        # Print the sampled word
        print(ix_to_word[index.item()], end=' ')
        
        # Feed the sampled word back in to get the next output
        input_seq[0][0] = index.item()
        output, hidden_state = rnn(input_seq, hidden_state)

Why hello there To become energy wearing rings it easy You wont see shell matter what you too much Though the single reason When you do me do they all together dressed in the same if you dont care Theres nothing to dance with a hard it on baby Yeeeeh baby you Shell remember just look at all belong? All thru this world around you evry fool and friends and that when I want her She like me ‘Cause I even listen to pretend But till your own girl You tell you can talk to lose affection known high Newspaper taxis appear to 
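
When priming with a prompt, each prompt word should be fed to the model exactly once. The first new word should then be sampled from the output produced by the *last* prompt word. A common slip is to run one extra forward pass on the last prompt word before sampling, which conditions the model on that word twice. The sketch below checks the ordering with a hypothetical `RecordingModel` stub in place of the RNN; `prime_and_generate` is likewise illustrative.

```python
import random


class RecordingModel:
    """Hypothetical stand-in for the RNN that records every word it is fed."""

    def __init__(self, vocab):
        self.vocab = vocab
        self.fed = []

    def step(self, word):
        self.fed.append(word)
        return self.vocab  # stands in for "a distribution over the vocabulary"


def prime_and_generate(model, prompt, length, rng):
    if not prompt:
        raise ValueError("prompt must contain at least one known word")
    # 1. Condition on every prompt word exactly once
    for word in prompt:
        candidates = model.step(word)
    # 2. Sample from the output of the latest step, then feed the sample back in
    generated = []
    for _ in range(length):
        word = rng.choice(candidates)
        generated.append(word)
        candidates = model.step(word)
    return generated


model = RecordingModel(["yeah", "oh", "baby"])
prompt = ["Why", "hello", "there"]
generated = prime_and_generate(model, prompt, 5, random.Random(1))

# The model saw each prompt word once, followed by every generated word in order
print(model.fed == prompt + generated)
```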