<a href="https://colab.research.google.com/github/CurtesMalteser/text-generator-rnn/blob/master/text-generator-rnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from google.colab import drive
drive.mount('/content/drive')

In [0]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

### **Load Data**

In [21]:
with open('drive/My Drive/anna.txt', 'r') as f:
    text = f.read()

text[:100]

'Chapter 1\n\n\nHappy families are all alike; every unhappy family is unhappy in its own\nway.\n\nEverythin'

### **Tokenization**

Here the chars will be converted to and from integers.


In [0]:
# we create two dictionaries:
# variable names are self explanatory
chars = tuple(set(text))
int2char = dict(enumerate(chars))
char2int = {ch: ii for ii, ch in int2char.items()}

# encode the text
encoded = np.array([char2int[ch] for ch in text])

In [32]:
# print encoded chars
encoded[:100]

array([77, 45, 70, 11, 12,  6, 10, 15, 61, 76, 76, 76, 53, 70, 11, 11,  2,
       15, 69, 70, 46, 36, 24, 36,  6,  9, 15, 70, 10,  6, 15, 70, 24, 24,
       15, 70, 24, 36, 43,  6, 57, 15,  6, 54,  6, 10,  2, 15, 78, 28, 45,
       70, 11, 11,  2, 15, 69, 70, 46, 36, 24,  2, 15, 36,  9, 15, 78, 28,
       45, 70, 11, 11,  2, 15, 36, 28, 15, 36, 12,  9, 15, 16, 34, 28, 76,
       34, 70,  2,  4, 76, 76, 39, 54,  6, 10,  2, 12, 45, 36, 28])

### **Pre-Precessing the data**
The LSTM expects as input a char converted into int, and the will be converted into one column vector where only the correspending index will have value one and remaing will be 0. This is the **one-hot encoded**.

In [0]:
def one_hot_encoded(arr, n_labels):

  # Initialize the encoded array
  one_hot = np.zeros((np.multiply(*arr.shape), n_labels), dtype=np.float32)

  # Fill the appropriate elements with ones
  one_hot[np.arange(one_hot.shape[0]), arr.flatten()] = 1.

  # Finally reshape it to get back to the original array
  one_hot = one_hot.reshape((*arr.shape, n_labels))

  return one_hot

In [38]:
# check that one_hot_encoded works as expected
test_seq = np.array([(3, 5, 1)])
one_hot = one_hot_encoded(test_seq, 8)

print(one_hot)

[[[0. 0. 0. 1. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 1. 0. 0.]
  [0. 1. 0. 0. 0. 0. 0. 0.]]]


### **Make mini-batches**

In [0]:
def get_batches(arr, batch_size, seq_length):
  ''' Create a generator tat returns batches of size: batch_size*seq_length
  from arr.

  Arguments
  ---------
  arr: Array to generate batches from
  batch_size: The number of sequences per batch
  seq_length: Number of encoded chars in a sequence
  '''

  batch_size_total = batch_size * seq_length

  # total number of batches we can make
  n_batches = len(arr)//batch_size_total

  # Keep only enough chars. to make full batches
  arr = arr[:n_batches * batch_size_total]

  # Reshape into batch_size_rows
  arr = arr.reshape((batch_size, -1))

  # iterate through the array, one sequence at a time
  for n in range(0, arr.shape[1], seq_length):
    # The features
    x = arr[:, n:n + seq_length]

    # The targets, shifted by one
    y = np.zeros_like(x)
    try:
      y[:, :-1], y[:, -1], = x[:, 1:], arr[:, n + seq_length]
    except:
      y[:, :-1], y[:, -1], = x[:, 1:], arr[:, 0]
    yield x, y

### **Test Implementation**


*   Batch Size: 8
*   Sequence Steps: 50



In [0]:
batches = get_batches(encoded, 8, 50)
x, y = next(batches)

In [45]:
# print first 10 times in a sequence
print('x\n', x[:10, :10])
print('\ny\n', y[:10, :10])

x
 [[77 45 70 11 12  6 10 15 61 76]
 [ 9 16 28 15 12 45 70 12 15 70]
 [ 6 28 62 15 16 10 15 70 15 69]
 [ 9 15 12 45  6 15 19 45 36  6]
 [15  9 70 34 15 45  6 10 15 12]
 [19 78  9  9 36 16 28 15 70 28]
 [15 35 28 28 70 15 45 70 62 15]
 [ 7 67 24 16 28  9 43  2  4 15]]

y
 [[45 70 11 12  6 10 15 61 76 76]
 [16 28 15 12 45 70 12 15 70 12]
 [28 62 15 16 10 15 70 15 69 16]
 [15 12 45  6 15 19 45 36  6 69]
 [ 9 70 34 15 45  6 10 15 12  6]
 [78  9  9 36 16 28 15 70 28 62]
 [35 28 28 70 15 45 70 62 15  9]
 [67 24 16 28  9 43  2  4 15 68]]
