# LSTM Text Generation

Based on: https://github.com/keras-team/keras/blob/master/examples/lstm_text_generation.py

Example script to generate text from Nietzsche's writings.

At least 20 epochs are required before the generated text
starts sounding coherent.

It is recommended to run this script on GPU, as recurrent
networks are quite computationally intensive.

If you try this script on new data, make sure your corpus
has at least ~100k characters. ~1M is better.

In [26]:
import conx as cx

In [27]:
cx.download('https://s3.amazonaws.com/text-datasets/nietzsche.txt')

Using cached https://s3.amazonaws.com/text-datasets/nietzsche.txt as './nietzsche.txt'.


In [28]:
text = open("nietzsche.txt").read().lower()

In [29]:
len(text)

600893

In [30]:
text[:100]

'preface\n\n\nsupposing that truth is a woman--what then? is there not ground\nfor suspecting that all ph'

In [31]:
chars = sorted(list(set(text)))
print('total chars:', len(chars))

total chars: 57


In [32]:
"".join(chars)

'\n !"\'(),-.0123456789:;=?[]_abcdefghijklmnopqrstuvwxyzäæéë'

In [33]:
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))

In [34]:
print("char to index:", char_indices)
print("index to char:", indices_char)

char to index: {'\n': 0, ' ': 1, '!': 2, '"': 3, "'": 4, '(': 5, ')': 6, ',': 7, '-': 8, '.': 9, '0': 10, '1': 11, '2': 12, '3': 13, '4': 14, '5': 15, '6': 16, '7': 17, '8': 18, '9': 19, ':': 20, ';': 21, '=': 22, '?': 23, '[': 24, ']': 25, '_': 26, 'a': 27, 'b': 28, 'c': 29, 'd': 30, 'e': 31, 'f': 32, 'g': 33, 'h': 34, 'i': 35, 'j': 36, 'k': 37, 'l': 38, 'm': 39, 'n': 40, 'o': 41, 'p': 42, 'q': 43, 'r': 44, 's': 45, 't': 46, 'u': 47, 'v': 48, 'w': 49, 'x': 50, 'y': 51, 'z': 52, 'ä': 53, 'æ': 54, 'é': 55, 'ë': 56}
index to char: {0: '\n', 1: ' ', 2: '!', 3: '"', 4: "'", 5: '(', 6: ')', 7: ',', 8: '-', 9: '.', 10: '0', 11: '1', 12: '2', 13: '3', 14: '4', 15: '5', 16: '6', 17: '7', 18: '8', 19: '9', 20: ':', 21: ';', 22: '=', 23: '?', 24: '[', 25: ']', 26: '_', 27: 'a', 28: 'b', 29: 'c', 30: 'd', 31: 'e', 32: 'f', 33: 'g', 34: 'h', 35: 'i', 36: 'j', 37: 'k', 38: 'l', 39: 'm', 40: 'n', 41: 'o', 42: 'p', 43: 'q', 44: 'r', 45: 's', 46: 't', 47: 'u', 48: 'v', 49: 'w', 50: 'x', 51: 'y', 52: '

Cut the text in semi-redundant sequences of maxlen characters:

In [35]:
maxlen = 40
step = 3
sequences = []
for i in range(0, len(text) - maxlen - 1, step):
    sequences.append(text[i: i + maxlen + 1])
print('sequences:', len(sequences))


sequences: 200284


In [36]:
sequences[0:10]

['preface\n\n\nsupposing that truth is a woman',
 'face\n\n\nsupposing that truth is a woman--w',
 'e\n\n\nsupposing that truth is a woman--what',
 '\nsupposing that truth is a woman--what th',
 'pposing that truth is a woman--what then?',
 'sing that truth is a woman--what then? is',
 'g that truth is a woman--what then? is th',
 'hat truth is a woman--what then? is there',
 ' truth is a woman--what then? is there no',
 'uth is a woman--what then? is there not g']

In [37]:
len(sequences[0])

41

In [38]:
(len(sequences), maxlen, len(chars))

(200284, 40, 57)

## Vectorization

In [39]:
cx.onehot(2, 5, values=[False, True])

[False, False, True, False, False]

In [40]:
char_encode = {ch: cx.onehot(char_indices[ch], len(chars), values=[False, True]) for ch in chars}

In [41]:
print(char_encode["a"])

[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]


In [42]:
inputs = []
targets = []
for s in range(len(sequences)):
    current = [char_encode[ch] for ch in sequences[s]]
    inputs.append(current[:-1])
    targets.append(current[-1])

In [43]:
cx.shape(inputs)

(200284, 40, 57)

In [44]:
cx.shape(targets)

(200284, 57)

In [45]:
net = cx.Network("LSTM Text Generation")
net.add(
    cx.Layer("input", (maxlen, len(chars))),
    cx.LSTMLayer("lstm", 128),
    cx.Layer("output", len(chars), activation="softmax"),
)
net.connect()
net.compile(error="categorical_crossentropy", optimizer="RMSProp", lr=0.01)

In [47]:
net.dataset.load(inputs=inputs, targets=targets)

In [48]:
net.dataset.summary()

_________________________________________________________________
LSTM Text Generation Dataset:
Patterns    Shape                 Range                         
inputs      (40, 57)              (0.0, 1.0)                    
targets     (57,)                 (0.0, 1.0)                    
Total patterns: 200284
   Training patterns: 200284
   Testing patterns: 0
_________________________________________________________________


In [None]:
net.dashboard()

In [None]:
net.dataset.chop(.99)

In [38]:
net.dataset.summary()

_________________________________________________________________
LSTM Text Generation Dataset:
Patterns    Shape                 Range                         
inputs      (40, 57)              (0.0, 1.0)                    
targets     (57,)                 (0.0, 1.0)                    
Total patterns: 2003
   Training patterns: 2003
   Testing patterns: 0
_________________________________________________________________


In [49]:
"".join([indices_char[cx.argmax(v)] for v in net.dataset.inputs[0]])
probs = sorted(enumerate(net.propagate(net.dataset.inputs[0])), 
               key=lambda v: v[1], reverse=True)
[(indices_char[w[0]], round(w[1], 2)) for w in probs]

[('é', 0.02),
 (';', 0.02),
 ('d', 0.02),
 (']', 0.02),
 ('\n', 0.02),
 ('y', 0.02),
 ('z', 0.02),
 ('m', 0.02),
 ('a', 0.02),
 ('2', 0.02),
 ('k', 0.02),
 (':', 0.02),
 ('n', 0.02),
 ('_', 0.02),
 ('o', 0.02),
 ('l', 0.02),
 ('3', 0.02),
 ('p', 0.02),
 ('x', 0.02),
 (' ', 0.02),
 ('v', 0.02),
 ('t', 0.02),
 ('æ', 0.02),
 ('c', 0.02),
 ('(', 0.02),
 ('9', 0.02),
 ('j', 0.02),
 ('ë', 0.02),
 ('?', 0.02),
 ('6', 0.02),
 ('e', 0.02),
 ('g', 0.02),
 ('h', 0.02),
 ('1', 0.02),
 ('=', 0.02),
 ('i', 0.02),
 (',', 0.02),
 ('!', 0.02),
 ('r', 0.02),
 ('f', 0.02),
 ('w', 0.02),
 ('b', 0.02),
 ('u', 0.02),
 ('.', 0.02),
 ('[', 0.02),
 ('5', 0.02),
 (')', 0.02),
 ('ä', 0.02),
 ('8', 0.02),
 ('"', 0.02),
 ('q', 0.02),
 ("'", 0.02),
 ('0', 0.02),
 ('-', 0.02),
 ('7', 0.02),
 ('s', 0.02),
 ('4', 0.02)]

In [50]:
from IPython.display import clear_output

In [51]:
def on_epoch_end(network, epoch=None, logs=None):
    import io
    epoch = epoch if epoch is not None else network.epoch_count
    s = io.StringIO()
    s.write("\n")
    s.write('----- Generating text after Epoch: %d\n' % epoch)
    start_index = cx.choice(len(text) - maxlen - 1)
    for diversity in [0.2, 0.5, 1.0, 1.2]:
        sentence = text[start_index: start_index + maxlen]
        s.write('----- diversity: %s\n' % diversity)
        s.write('----- Generating with seed: "' + sentence + '"\n\n')
        s.write(sentence)
        current = [char_encode[ch] for ch in sentence]
        for i in range(400):
            output = network.propagate(current)
            next_index = cx.choice(p=output, temperature=diversity, index=True)
            s.write(indices_char[next_index])
            next_char = char_encode[indices_char[next_index]]
            current = current[1:]
            current.append(next_char)
        s.write("\n")
    clear_output()
    print(s.getvalue())

In [52]:
import tensorflow as tf

In [54]:
%%time
session = tf.Session( config = tf.ConfigProto(log_device_placement=True))
net.train(1, batch_size=128, plot=False)

Evaluating initial training metrics...
Training...
       |  Training |  Training 
Epochs |     Error |  Accuracy 
------ | --------- | --------- 
#    0 |   4.04202 |   0.02643 
#    1 |   2.00208 |   0.41688 
#    1 |   2.00208 |   0.41688 
CPU times: user 9min 12s, sys: 1min 16s, total: 10min 28s
Wall time: 3min 29s


In [69]:
%%time
net.train(1, batch_size=128, plot=False)

Training...
       |  Training |  Training 
Epochs |     Error |  Accuracy 
------ | --------- | --------- 
#    4 |   2.64944 |   0.28108 
#    5 |   2.51173 |   0.29805 
#    5 |   2.51173 |   0.29805 
CPU times: user 3.83 s, sys: 627 ms, total: 4.46 s
Wall time: 1.55 s


In [70]:
"".join([indices_char[cx.argmax(v)] for v in net.dataset.inputs[0]])

'preface\n\n\nsupposing that truth is a woma'

In [71]:
probs = sorted(enumerate(net.propagate(net.dataset.inputs[0])), 
               key=lambda v: v[1], reverse=True)

In [72]:
probs[0]

(40, 0.2107514590024948)

In [73]:
[(indices_char[w[0]], round(w[1], 2)) for w in probs]

[('n', 0.21),
 ('t', 0.11),
 ('l', 0.09),
 ('s', 0.09),
 (' ', 0.08),
 ('d', 0.08),
 ('c', 0.06),
 ('g', 0.05),
 ('e', 0.03),
 ('r', 0.03),
 ('y', 0.02),
 ('i', 0.02),
 ('m', 0.02),
 ('o', 0.01),
 ('v', 0.01),
 ('b', 0.01),
 ('\n', 0.01),
 ('a', 0.01),
 ('u', 0.01),
 ('-', 0.01),
 ('p', 0.01),
 (',', 0.01),
 ('w', 0.0),
 ('f', 0.0),
 ('?', 0.0),
 ('h', 0.0),
 ('k', 0.0),
 ('"', 0.0),
 ('q', 0.0),
 ('.', 0.0),
 (':', 0.0),
 (';', 0.0),
 ('!', 0.0),
 ('z', 0.0),
 ('x', 0.0),
 ('8', 0.0),
 ('1', 0.0),
 ('j', 0.0),
 ('3', 0.0),
 (')', 0.0),
 ('ë', 0.0),
 ('æ', 0.0),
 ('4', 0.0),
 ('7', 0.0),
 ('5', 0.0),
 ('0', 0.0),
 ('9', 0.0),
 ('=', 0.0),
 (']', 0.0),
 ('(', 0.0),
 ('_', 0.0),
 ('2', 0.0),
 ('ä', 0.0),
 ('é', 0.0),
 ('[', 0.0),
 ("'", 0.0),
 ('6', 0.0)]

In [74]:
on_epoch_end(net)


----- Generating text after Epoch: 5
----- diversity: 0.2
----- Generating with seed: "s of years be powerful enough to endow m"

s of years be powerful enough to endow me the ge the the the the the the the the the the the the the the the the the the the the the the the the the se the the ind whe the the the and were wig out ind and whe the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the te ind whe the the the the the the the the the the the 
----- diversity: 0.5
----- Generating with seed: "s of years be powerful enough to endow m"

s of years be powerful enough to endow mengeg or whe bede tio eica the s be ied are of erent in y tot  or an e th em gort of ne at inn te the in e th an werese sed theute thencl an  ind then t thed et le beded of an st thelo ie then the were th when t tire the whens er woy an is 
pe weme wan tire wers and then end def dews thr theg p