<a href="https://colab.research.google.com/github/PamelaVQ/Base-ML/blob/master/Pytorch_Basics/RNN_for_Sentence_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Use Recurrent Neural Network for Sentence Generation

Reference Links:

http://karpathy.github.io/2015/05/21/rnn-effectiveness/

https://gist.github.com/karpathy/d4dee566867f8291f086

https://github.com/keras-team/keras/blob/master/examples/lstm_text_generation.py


In [1]:
from tensorflow.keras.utils import get_file
import io
import sklearn
import sklearn.feature_extraction
import numpy as np
from torch import nn
import torchvision
import torch

In [24]:
def pattern_text(start_pattern, end_pattern, data):
  result = data[data.find(start_pattern)+len(start_pattern):data.rfind(end_pattern)]
  return result

In [25]:
path = get_file("agatha_christie", origin="https://www.gutenberg.org/files/863/863-0.txt")
with io.open(path, encoding='utf-8') as read_file:
  data_agatha_christie = read_file.read().lower()
start_text = """*** START OF THIS PROJECT GUTENBERG EBOOK THE MYSTERIOUS AFFAIR AT STYLES ***""".lower()
end_text = """*** END OF THIS PROJECT GUTENBERG EBOOK THE MYSTERIOUS AFFAIR AT STYLES ***""".lower()
data_agatha_christie = pattern_text(start_text, end_text, data_agatha_christie)
print(f'agatha_christie corpus length:{len(data_agatha_christie)}')

path = get_file("lewis_carroll", origin="https://www.gutenberg.org/files/11/11-0.txt")
with io.open(path, encoding='utf-8') as read_file:
  data_lewis_carroll = read_file.read().lower()
start_text = """*** START OF THIS PROJECT GUTENBERG EBOOK ALICE’S ADVENTURES IN WONDERLAND ***""".lower()
end_text = """*** END OF THIS PROJECT GUTENBERG EBOOK ALICE’S ADVENTURES IN WONDERLAND ***""".lower()
data_lewis_carroll = pattern_text(start_text, end_text, data_lewis_carroll)
print(f'lewis_carroll corpus length:{len(data_lewis_carroll)}')

agatha_christie corpus length:321018
lewis_carroll corpus length:144730


In [26]:
def get_character_vocab(corpus):
  data = list(corpus.strip())
  vocab = list(set(data))
  return data, vocab

# word2idx, idx2word, output_corpus = create_vocab(data_agatha_christie)
data, vocab = get_character_vocab(data_agatha_christie)
data_size, vocab_size = len(data), len(vocab)
print(data_size, vocab_size)
idx_to_char = {i:char for i, char in enumerate(vocab)}
char_to_idx = {char:i for i, char in enumerate(vocab)}
print(idx_to_char)

321011 63
{0: 'q', 1: 'é', 2: 't', 3: '0', 4: 'l', 5: 'ç', 6: 'p', 7: ';', 8: ')', 9: 'i', 10: '-', 11: '5', 12: 'a', 13: '7', 14: 'c', 15: '“', 16: '2', 17: ',', 18: '4', 19: 'd', 20: 'b', 21: 'x', 22: 'à', 23: 'y', 24: 'r', 25: '”', 26: '’', 27: 'k', 28: "'", 29: '3', 30: 'ó', 31: 'n', 32: 'j', 33: '!', 34: 'o', 35: '‘', 36: 'â', 37: 'z', 38: '6', 39: 'v', 40: ' ', 41: 's', 42: '(', 43: 'è', 44: 'f', 45: 'h', 46: '1', 47: 'm', 48: '—', 49: '9', 50: 'w', 51: '.', 52: 'u', 53: '\n', 54: 'î', 55: 'ê', 56: '8', 57: '?', 58: ':', 59: 'g', 60: '&', 61: '_', 62: 'e'}


In [27]:
X = [char_to_idx[i] for i in data]
Y = X[1:] + [char_to_idx['\n']]
print(X[:10])
print(Y[:10])

[6, 24, 34, 19, 52, 14, 62, 19, 40, 20]
[24, 34, 19, 52, 14, 62, 19, 40, 20, 23]


In [28]:
class RNN(nn.Module):
    def __init__(self, max_length, vocab_size):
      super(RNN, self).__init__()
      # self.embed = nn.Embedding(input_size, 128)
      self.lstm = nn.LSTM(128, input_shape=(max_length, vocab_size))
      self.dense = nn.Dense(vocab_size, activation='softmax')

    def forward(self, x):
      x = self.lstm(x)
      x = self.dense(x)

Understanding LSTM: [Pytorch LSTM Docs](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html)

In [67]:
# testing LSTM
lstm = nn.LSTM(3, 3)  # Input dim is 3, output dim is 3
inputs = [torch.randn(1, 3) for _ in range(5)]  # make a sequence of length 5
print(inputs[0].shape)
# initialize the hidden state.
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))
print(hidden)

AttributeError: ignored

In [31]:
for i in inputs:
    # Step through the sequence one element at a time.
    # after each step, hidden contains the hidden state.
    out, hidden = lstm(i.view(1, 1, -1), hidden) # (3,3) ((3,3), (3,3))

In [32]:
# do all 3 inputs at once
inputs_1 = torch.cat(inputs).view(len(inputs), 1, -1)
hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3))
output, hidden = lstm(inputs_1, hidden)
print(output.shape)
print(hidden[0].shape)

torch.Size([5, 1, 3])
torch.Size([1, 1, 3])


Learning: [An LSTM for Part-of-Speech Tagging](https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html#example-an-lstm-for-part-of-speech-tagging)

In [33]:
def prepare_sequence(seq):
    return torch.tensor(seq, dtype=torch.long)

In [35]:
EMBEDDING_DIM = 126
HIDDEN_DIM = 126

In [61]:
class LSTMSequenceGenerator(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size):
      super(LSTMSequenceGenerator, self).__init__()
      self.hidden_dim = hidden_dim
      self.embeddings = nn.Embedding(1, embedding_dim)
      self.lstm = nn.LSTM(embedding_dim, hidden_dim)
      self.dense = nn.Linear(embedding_dim, vocab_size)
      # self.softmax = nn.functional.softmax(vocab_size, dim=0)

    def forward(self, x):
      x = self.embeddings(x)
      # x, _ = self.lstm(x.view(len(x), 1, -1))
      # x = self.dense(lstm_oxut.view(len(x), 1, -1))
      return x


In [62]:
lstm = LSTMSequenceGenerator(embedding_dim=EMBEDDING_DIM, hidden_dim=HIDDEN_DIM, vocab_size=vocab_size)

In [68]:
from torch.autograd import Variable
max_len = max_len

In [75]:
inputs = [prepare_sequence(seq) for seq in X]
# need padding here
inputs
# inputs_1 = torch.cat(inputs) .view(len(inputs), 1, -1)
# hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3))
# output, hidden = lstm(inputs_1, hidden)
# inputs_1
# output = lstm(x_input)
# output.shape

[tensor(6),
 tensor(24),
 tensor(34),
 tensor(19),
 tensor(52),
 tensor(14),
 tensor(62),
 tensor(19),
 tensor(40),
 tensor(20),
 tensor(23),
 tensor(40),
 tensor(14),
 tensor(45),
 tensor(12),
 tensor(24),
 tensor(4),
 tensor(62),
 tensor(41),
 tensor(40),
 tensor(27),
 tensor(62),
 tensor(4),
 tensor(4),
 tensor(62),
 tensor(24),
 tensor(53),
 tensor(53),
 tensor(53),
 tensor(53),
 tensor(53),
 tensor(53),
 tensor(53),
 tensor(53),
 tensor(53),
 tensor(2),
 tensor(45),
 tensor(62),
 tensor(40),
 tensor(47),
 tensor(23),
 tensor(41),
 tensor(2),
 tensor(62),
 tensor(24),
 tensor(9),
 tensor(34),
 tensor(52),
 tensor(41),
 tensor(40),
 tensor(12),
 tensor(44),
 tensor(44),
 tensor(12),
 tensor(9),
 tensor(24),
 tensor(40),
 tensor(12),
 tensor(2),
 tensor(40),
 tensor(41),
 tensor(2),
 tensor(23),
 tensor(4),
 tensor(62),
 tensor(41),
 tensor(53),
 tensor(53),
 tensor(20),
 tensor(23),
 tensor(40),
 tensor(12),
 tensor(59),
 tensor(12),
 tensor(2),
 tensor(45),
 tensor(12),
 tensor(40)