# Text generation models

One very basic kind of text generation model is the Markov model.  In
such a model, we have a state which consists of the previous character.
We also have a matrix of transitions from one character to another.  We
*train* the model by feeding it some text, and observing the
transitions.  We can then generate more text from the model.

In [7]:
from collections import defaultdict
import re

def train_sentence(transitions, sentence, depth=1):
    # We need a "special" character to represent the beginning of a sentence.
    # This is also the character we'll use to feed the generator, below.
    prevchar = "•" * depth
    for char in sentence:
        transitions[prevchar][char] += 1
        prevchar = prevchar[1:] + char

    return transitions

def split_text(text):
    for sentence in re.finditer(".*?([.?!][”’]?|\n\n)", text, re.DOTALL):
        # Turn all sequences of whitespace into a single space
        sentence = re.sub("[ \t\n\r]+", " ", sentence.group(0)).strip()
        if sentence != '':
            yield sentence

def train(filename, depth=1):
    transitions = defaultdict(lambda: defaultdict(int))
    with open(filename) as fin:
        text = fin.read()
        for sentence in split_text(text):
            if len(sentence) < 3:
                continue
            transitions = train_sentence(transitions, sentence, depth)

    return transitions

In [8]:
import pandas as pd
import numpy as np

def format_transitions(trs):
    rows = []
    for key in trs:
        for key2 in trs[key]:
            rows.append({'from': key, 'to': key2, 'n': trs[key][key2]})
    data = pd.DataFrame(rows)
    data = data.pivot_table(index='from', columns='to', values='n')
    data = data.div(data.sum(axis=1), axis=0)
    data[np.isnan(data)] = 0

    return data

In [3]:
def produce(transitions):
    if isinstance(transitions, defaultdict):
        transitions = format_transitions(transitions)

    # Nifty trick: auto-calculate the depth we were trained on
    depth = len(transitions.index[0])

    output = ""
    last = "•" * depth
    nxt = ""

    while nxt not in [".", "?", "!"]:
        trs = transitions.loc[last]
        nxt = np.random.choice(trs.index, p=trs)
        last = last[1:] + nxt
        output += nxt

    return output

In [4]:
tr = train("alice.txt", 4)

In [5]:
[produce(tr) for x in range(5)]

['Supposed up and ther is,’ thoughtfull over unpleast in accur: but states’ll trying--‘Catch the elbow what?',
 'Down,’ said that promoting riddle into them, and the rose things, we should not be folding then!',
 'How CAN I haven’t makes also, and put once, when sat she talk.',
 'This cut it see as shrink I must really far don’t might-eyed the Project Gutenberg License was sense, which was herself.',
 '‘No,’ said Alice if no idea,’ ther.']

Ideas for extension:

- Train on a different text
- Try normalizing the text in different ways (e.g. what happens if you take
  out quotation marks?)
- Play around with different depths.  Do different ones work better for
  different texts?
- Try generating longer passages (you will need to alter the training
  also)
- Convert to word-based rather than character based

# Neural networks

Neural networks in the news

- [An excellent introduction to RNNs and character level language models](http://karpathy.github.io/2015/05/21/rnn-effectiveness/)
- [Generating Harry Potter text with a RNN (with tutorial)](https://medium.com/deep-writing/harry-potter-written-by-artificial-intelligence-8a9431803da6)
  - “Yeah, I know, I saw him run off the balls of the Three Broomsticks around the Daily Prophet that we met Potter’s name!” said Hermione. “We’ve done all right, Draco, and Karkaroff would have to spell the Imperius Curse,” said Dumbledore. “But Harry, never found out about the happy against the school.”
- [The AI that's "too dangerous to release"](https://openai.com/blog/better-language-models/) (NB this specific claim is propaganda, and the model probably does not actually work as it's claimed to)
  - *John F. Kennedy was just elected President of the United States after rising from the grave decades after his assassination. Due to miraculous developments in nanotechnology, Kennedy’s brain was rebuilt from his remains and installed in the control center of a state-of-the art humanoid robot. Below is a transcript of his acceptance speech.*
  
    “It is time once again. I believe this nation can do great things if the people make their voices heard. The men and women of America must once more summon our best elements, all our ingenuity, and find a way to turn such overwhelming tragedy into the opportunity for a greater good and the fulfillment of all our dreams. In the months and years to come, there will be many battles in which we will have to be strong and we must give all of our energy, not to repel invaders, but rather to resist aggression and to win the freedom and the equality for all of our people. The destiny of the human race hangs in the balance; we cannot afford for it to slip away. Now and in the years to come, the challenge before us is to work out how we achieve our ultimate destiny. If we fail to do so, we are doomed.”
    
    [...]

    After the election, the President of the United States asked for a “more complete” body of information on Kennedy’s brains’ origin. In an e-mail exchange with me, Dr. Robert Wood, who, according to Wikipedia is the head of the University of Maryland Advanced Nanotechnology Research project, provided the following information:

    (1) The brain of JFK was harvested and reconstructed via tissue sampling. There was no way that the tissue could be transported by air. (2) A sample was collected from the area around his upper chest and sent to the University of Maryland for analysis.
    
Tutorials:

- [Tensorflow text generation (Shakespeare)](https://www.tensorflow.org/beta/tutorials/text/text_generation)
- [PyTorch name generation/classification](https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html) -- a different task, but with useful conceptual background

Code:
- [Keras](https://github.com/ekzhang/char-rnn-keras) -- on which some of the below code is based

For our neural network, we have a series of layers that do different things:
- Embedding: turn a single number (character-number) into a vector of numbers
- LSTM (long short term memory): a very general kind of neural network gate.  Intuitively, it consists of a single "cell" of memory, with three controls:
  - input: how likely this cell is to "learn" something new
  - forget: how likely this cell is to hold on to what it already knows
  - output: how likely this cell is to pass what it knows on to the rest of the network
- Dropout: a layer that randomly zeroes out part of its contents.  While this might seem like a bad thing, it turns out to prevent overfitting in practice
- TimeDistributed: somewhat misleading name (in this application).  Since we are feeding multiple slices in at once, and predicting multiple outputs, we need this layer to glue the inputs and outputs together.  The multiple characters are the "time" in the name.  ([explanation](https://github.com/keras-team/keras/issues/1029#issuecomment-158105579))
- Activation: transforms to probabilities/weights per character

In [2]:
from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout, Activation, TimeDistributed, Embedding

EMBEDDING_SZ = 128
LSTM_SZ = 64

def build_model(simultaneous_batches, chars_per_batch, vocab_size):
    model = Sequential()
    model.add(Embedding(vocab_size, EMBEDDING_SZ, batch_input_shape=(simultaneous_batches, chars_per_batch)))
    for i in range(3):
        model.add(LSTM(LSTM_SZ, return_sequences=True, stateful=True))
        model.add(Dropout(0.2))
        
    model.add(TimeDistributed(Dense(vocab_size)))
    model.add(Activation('softmax'))
    
    return model

Using TensorFlow backend.


We want to feed the text into the network in simultaneous batches.  Imagine that we want 4 batches, each 8 characters long.  Imagine too that the text is 96 characters long.  For convenience we'll split the text into 4 lines, and use the 24 letters a-x to represent the text.  The first batch will consist of 4 copies of the string "abcdefgh" (characters 0-7, 24-31, 48-55, and 72-79).  The second batch will consist of 4 copies of "ijklmnop" (8-15, 32-39, 56-63, 80-87).  And the third will consist of 4 copies of "qrstuvwx" (16-23, 40-47, 64-71, 88-95).  Thus, each character is fed into the network exactly once.  The advantage of this scheme is that it spreads any local variation in the text across different training periods, rather than combining it into one.  (Remember: a real text is more likely to be like "aaabbbccc" than "abcabcabc")

```
abcdefghijklmnopqrstuvwx
abcdefghijklmnopqrstuvwx
abcdefghijklmnopqrstuvwx
abcdefghijklmnopqrstuvwx
```

In [3]:
CHARS_PER_BATCH = 64
SIMULTANEOUS_BATCHES = 16
def batches(text):
    alphabet = sorted(set(list(text)))
    distance_between_batches = len(text) // SIMULTANEOUS_BATCHES
    for start in range(0, distance_between_batches - CHARS_PER_BATCH, CHARS_PER_BATCH):
        x = np.zeros((SIMULTANEOUS_BATCHES, CHARS_PER_BATCH))
        y = np.zeros((SIMULTANEOUS_BATCHES, CHARS_PER_BATCH, len(alphabet)))
        for i in range(0, SIMULTANEOUS_BATCHES):
            for j in range(0, CHARS_PER_BATCH):
                x[i, j] = alphabet.index(text[distance_between_batches * i + start + j])
                y[i, j, alphabet.index(text[distance_between_batches * i + start + j + 1])] = 1
        yield x, y

In [4]:
with open("alice.txt") as fin:
    text = fin.read()

In [5]:
alphabet = sorted(list(set(text)))
vocab_size = len(alphabet)
model = build_model(simultaneous_batches=SIMULTANEOUS_BATCHES,
                    chars_per_batch=CHARS_PER_BATCH,
                    vocab_size=vocab_size)
model.compile(loss='categorical_crossentropy',
              optimizer='adam', metrics=['accuracy'])

print(model.summary())

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_1 (Embedding)      (16, 64, 128)             9088      
_________________________________________________________________
lstm_1 (LSTM)                (16, 64, 64)              49408     
_________________________________________________________________
dropout_1 (Dropout)          (16, 64, 64)              0         
_________________________________________________________________
lstm_2 (LSTM)                (16, 64, 64)              33024     
_________________________________________________________________
dropout_2 (Dropout)          (16, 64, 64)              0         
_________________________________________________________________
lstm_3 (LSTM)        

In [46]:
from tqdm import tqdm_notebook

n_batches = len(list(range(0, len(text) // SIMULTANEOUS_BATCHES - CHARS_PER_BATCH, CHARS_PER_BATCH)))

for epoch in range(0, 100):
    for x, y in tqdm_notebook(batches(text), total=n_batches, desc=f'Epoch {epoch}'):
        losses = []
        accs = []
        loss, acc = model.train_on_batch(x, y)
        losses.append(loss)
        accs.append(acc)
    print(f"Epoch: {epoch}, losses {np.mean(losses)}, accuracy {np.mean(accs)}")

HBox(children=(IntProgress(value=0, description='Epoch 0', max=140, style=ProgressStyle(description_width='ini…

Epoch: 0, losses 1.4618688821792603, accuracy 0.560546875


HBox(children=(IntProgress(value=0, description='Epoch 1', max=140, style=ProgressStyle(description_width='ini…

Epoch: 1, losses 1.485126256942749, accuracy 0.56640625


HBox(children=(IntProgress(value=0, description='Epoch 2', max=140, style=ProgressStyle(description_width='ini…

Epoch: 2, losses 1.4572129249572754, accuracy 0.5517578125


HBox(children=(IntProgress(value=0, description='Epoch 3', max=140, style=ProgressStyle(description_width='ini…

Epoch: 3, losses 1.4663128852844238, accuracy 0.5498046875


HBox(children=(IntProgress(value=0, description='Epoch 4', max=140, style=ProgressStyle(description_width='ini…

Epoch: 4, losses 1.4179461002349854, accuracy 0.5859375


HBox(children=(IntProgress(value=0, description='Epoch 5', max=140, style=ProgressStyle(description_width='ini…

Epoch: 5, losses 1.4777932167053223, accuracy 0.564453125


HBox(children=(IntProgress(value=0, description='Epoch 6', max=140, style=ProgressStyle(description_width='ini…

Epoch: 6, losses 1.4612092971801758, accuracy 0.5712890625


HBox(children=(IntProgress(value=0, description='Epoch 7', max=140, style=ProgressStyle(description_width='ini…

Epoch: 7, losses 1.4690296649932861, accuracy 0.5576171875


HBox(children=(IntProgress(value=0, description='Epoch 8', max=140, style=ProgressStyle(description_width='ini…

Epoch: 8, losses 1.458778977394104, accuracy 0.5537109375


HBox(children=(IntProgress(value=0, description='Epoch 9', max=140, style=ProgressStyle(description_width='ini…

Epoch: 9, losses 1.4613113403320312, accuracy 0.5693359375


HBox(children=(IntProgress(value=0, description='Epoch 10', max=140, style=ProgressStyle(description_width='in…

Epoch: 10, losses 1.4536778926849365, accuracy 0.5556640625


HBox(children=(IntProgress(value=0, description='Epoch 11', max=140, style=ProgressStyle(description_width='in…

Epoch: 11, losses 1.4411113262176514, accuracy 0.560546875


HBox(children=(IntProgress(value=0, description='Epoch 12', max=140, style=ProgressStyle(description_width='in…

Epoch: 12, losses 1.4363384246826172, accuracy 0.5634765625


HBox(children=(IntProgress(value=0, description='Epoch 13', max=140, style=ProgressStyle(description_width='in…

Epoch: 13, losses 1.47456693649292, accuracy 0.5595703125


HBox(children=(IntProgress(value=0, description='Epoch 14', max=140, style=ProgressStyle(description_width='in…

Epoch: 14, losses 1.460085391998291, accuracy 0.5654296875


HBox(children=(IntProgress(value=0, description='Epoch 15', max=140, style=ProgressStyle(description_width='in…

Epoch: 15, losses 1.4746644496917725, accuracy 0.5595703125


HBox(children=(IntProgress(value=0, description='Epoch 16', max=140, style=ProgressStyle(description_width='in…

Epoch: 16, losses 1.450880765914917, accuracy 0.57421875


HBox(children=(IntProgress(value=0, description='Epoch 17', max=140, style=ProgressStyle(description_width='in…

Epoch: 17, losses 1.4569833278656006, accuracy 0.5634765625


HBox(children=(IntProgress(value=0, description='Epoch 18', max=140, style=ProgressStyle(description_width='in…

Epoch: 18, losses 1.463867425918579, accuracy 0.56640625


HBox(children=(IntProgress(value=0, description='Epoch 19', max=140, style=ProgressStyle(description_width='in…

Epoch: 19, losses 1.4164955615997314, accuracy 0.576171875


HBox(children=(IntProgress(value=0, description='Epoch 20', max=140, style=ProgressStyle(description_width='in…

Epoch: 20, losses 1.4289519786834717, accuracy 0.583984375


HBox(children=(IntProgress(value=0, description='Epoch 21', max=140, style=ProgressStyle(description_width='in…

Epoch: 21, losses 1.4555165767669678, accuracy 0.5625


HBox(children=(IntProgress(value=0, description='Epoch 22', max=140, style=ProgressStyle(description_width='in…

Epoch: 22, losses 1.4699397087097168, accuracy 0.5498046875


HBox(children=(IntProgress(value=0, description='Epoch 23', max=140, style=ProgressStyle(description_width='in…

Epoch: 23, losses 1.461560606956482, accuracy 0.5517578125


HBox(children=(IntProgress(value=0, description='Epoch 24', max=140, style=ProgressStyle(description_width='in…

Epoch: 24, losses 1.4780645370483398, accuracy 0.5556640625


HBox(children=(IntProgress(value=0, description='Epoch 25', max=140, style=ProgressStyle(description_width='in…

Epoch: 25, losses 1.4652488231658936, accuracy 0.564453125


HBox(children=(IntProgress(value=0, description='Epoch 26', max=140, style=ProgressStyle(description_width='in…

Epoch: 26, losses 1.4645494222640991, accuracy 0.564453125


HBox(children=(IntProgress(value=0, description='Epoch 27', max=140, style=ProgressStyle(description_width='in…

Epoch: 27, losses 1.457063913345337, accuracy 0.57421875


HBox(children=(IntProgress(value=0, description='Epoch 28', max=140, style=ProgressStyle(description_width='in…

Epoch: 28, losses 1.4627666473388672, accuracy 0.56640625


HBox(children=(IntProgress(value=0, description='Epoch 29', max=140, style=ProgressStyle(description_width='in…

Epoch: 29, losses 1.4414362907409668, accuracy 0.55078125


HBox(children=(IntProgress(value=0, description='Epoch 30', max=140, style=ProgressStyle(description_width='in…

Epoch: 30, losses 1.4400426149368286, accuracy 0.5654296875


HBox(children=(IntProgress(value=0, description='Epoch 31', max=140, style=ProgressStyle(description_width='in…

Epoch: 31, losses 1.502441644668579, accuracy 0.5537109375


HBox(children=(IntProgress(value=0, description='Epoch 32', max=140, style=ProgressStyle(description_width='in…

Epoch: 32, losses 1.4443989992141724, accuracy 0.5576171875


HBox(children=(IntProgress(value=0, description='Epoch 33', max=140, style=ProgressStyle(description_width='in…

Epoch: 33, losses 1.453369140625, accuracy 0.55859375


HBox(children=(IntProgress(value=0, description='Epoch 34', max=140, style=ProgressStyle(description_width='in…

Epoch: 34, losses 1.4408341646194458, accuracy 0.572265625


HBox(children=(IntProgress(value=0, description='Epoch 35', max=140, style=ProgressStyle(description_width='in…

Epoch: 35, losses 1.46270751953125, accuracy 0.5791015625


HBox(children=(IntProgress(value=0, description='Epoch 36', max=140, style=ProgressStyle(description_width='in…

Epoch: 36, losses 1.432405710220337, accuracy 0.580078125


HBox(children=(IntProgress(value=0, description='Epoch 37', max=140, style=ProgressStyle(description_width='in…

Epoch: 37, losses 1.4550193548202515, accuracy 0.55859375


HBox(children=(IntProgress(value=0, description='Epoch 38', max=140, style=ProgressStyle(description_width='in…

Epoch: 38, losses 1.4655499458312988, accuracy 0.556640625


HBox(children=(IntProgress(value=0, description='Epoch 39', max=140, style=ProgressStyle(description_width='in…

Epoch: 39, losses 1.4403514862060547, accuracy 0.564453125


HBox(children=(IntProgress(value=0, description='Epoch 40', max=140, style=ProgressStyle(description_width='in…

Epoch: 40, losses 1.457787036895752, accuracy 0.5771484375


HBox(children=(IntProgress(value=0, description='Epoch 41', max=140, style=ProgressStyle(description_width='in…

Epoch: 41, losses 1.4736101627349854, accuracy 0.55859375


HBox(children=(IntProgress(value=0, description='Epoch 42', max=140, style=ProgressStyle(description_width='in…

Epoch: 42, losses 1.468393325805664, accuracy 0.5673828125


HBox(children=(IntProgress(value=0, description='Epoch 43', max=140, style=ProgressStyle(description_width='in…

Epoch: 43, losses 1.471571445465088, accuracy 0.5634765625


HBox(children=(IntProgress(value=0, description='Epoch 44', max=140, style=ProgressStyle(description_width='in…

Epoch: 44, losses 1.4733247756958008, accuracy 0.5576171875


HBox(children=(IntProgress(value=0, description='Epoch 45', max=140, style=ProgressStyle(description_width='in…

Epoch: 45, losses 1.4601458311080933, accuracy 0.5537109375


HBox(children=(IntProgress(value=0, description='Epoch 46', max=140, style=ProgressStyle(description_width='in…

Epoch: 46, losses 1.4400534629821777, accuracy 0.556640625


HBox(children=(IntProgress(value=0, description='Epoch 47', max=140, style=ProgressStyle(description_width='in…

Epoch: 47, losses 1.485327959060669, accuracy 0.580078125


HBox(children=(IntProgress(value=0, description='Epoch 48', max=140, style=ProgressStyle(description_width='in…

Epoch: 48, losses 1.4615252017974854, accuracy 0.55859375


HBox(children=(IntProgress(value=0, description='Epoch 49', max=140, style=ProgressStyle(description_width='in…

Epoch: 49, losses 1.4330434799194336, accuracy 0.572265625


HBox(children=(IntProgress(value=0, description='Epoch 50', max=140, style=ProgressStyle(description_width='in…

Epoch: 50, losses 1.4360542297363281, accuracy 0.5673828125


HBox(children=(IntProgress(value=0, description='Epoch 51', max=140, style=ProgressStyle(description_width='in…

Epoch: 51, losses 1.4591615200042725, accuracy 0.568359375


HBox(children=(IntProgress(value=0, description='Epoch 52', max=140, style=ProgressStyle(description_width='in…

Epoch: 52, losses 1.4619404077529907, accuracy 0.5634765625


HBox(children=(IntProgress(value=0, description='Epoch 53', max=140, style=ProgressStyle(description_width='in…

Epoch: 53, losses 1.4366556406021118, accuracy 0.5693359375


HBox(children=(IntProgress(value=0, description='Epoch 54', max=140, style=ProgressStyle(description_width='in…

Epoch: 54, losses 1.4352331161499023, accuracy 0.5888671875


HBox(children=(IntProgress(value=0, description='Epoch 55', max=140, style=ProgressStyle(description_width='in…

Epoch: 55, losses 1.4608805179595947, accuracy 0.5673828125


HBox(children=(IntProgress(value=0, description='Epoch 56', max=140, style=ProgressStyle(description_width='in…

Epoch: 56, losses 1.438961386680603, accuracy 0.5673828125


HBox(children=(IntProgress(value=0, description='Epoch 57', max=140, style=ProgressStyle(description_width='in…

Epoch: 57, losses 1.4514296054840088, accuracy 0.5595703125


HBox(children=(IntProgress(value=0, description='Epoch 58', max=140, style=ProgressStyle(description_width='in…

Epoch: 58, losses 1.454436182975769, accuracy 0.5546875


HBox(children=(IntProgress(value=0, description='Epoch 59', max=140, style=ProgressStyle(description_width='in…

Epoch: 59, losses 1.401925802230835, accuracy 0.5693359375


HBox(children=(IntProgress(value=0, description='Epoch 60', max=140, style=ProgressStyle(description_width='in…

Epoch: 60, losses 1.4376635551452637, accuracy 0.5556640625


HBox(children=(IntProgress(value=0, description='Epoch 61', max=140, style=ProgressStyle(description_width='in…

Epoch: 61, losses 1.4388678073883057, accuracy 0.564453125


HBox(children=(IntProgress(value=0, description='Epoch 62', max=140, style=ProgressStyle(description_width='in…

Epoch: 62, losses 1.4764814376831055, accuracy 0.5546875


HBox(children=(IntProgress(value=0, description='Epoch 63', max=140, style=ProgressStyle(description_width='in…

Epoch: 63, losses 1.4351048469543457, accuracy 0.5625


HBox(children=(IntProgress(value=0, description='Epoch 64', max=140, style=ProgressStyle(description_width='in…

Epoch: 64, losses 1.4446271657943726, accuracy 0.5673828125


HBox(children=(IntProgress(value=0, description='Epoch 65', max=140, style=ProgressStyle(description_width='in…

Epoch: 65, losses 1.4472249746322632, accuracy 0.55078125


HBox(children=(IntProgress(value=0, description='Epoch 66', max=140, style=ProgressStyle(description_width='in…

Epoch: 66, losses 1.4443583488464355, accuracy 0.568359375


HBox(children=(IntProgress(value=0, description='Epoch 67', max=140, style=ProgressStyle(description_width='in…

Epoch: 67, losses 1.4590065479278564, accuracy 0.548828125


HBox(children=(IntProgress(value=0, description='Epoch 68', max=140, style=ProgressStyle(description_width='in…

Epoch: 68, losses 1.4640425443649292, accuracy 0.560546875


HBox(children=(IntProgress(value=0, description='Epoch 69', max=140, style=ProgressStyle(description_width='in…

Epoch: 69, losses 1.447706699371338, accuracy 0.5615234375


HBox(children=(IntProgress(value=0, description='Epoch 70', max=140, style=ProgressStyle(description_width='in…

Epoch: 70, losses 1.4379727840423584, accuracy 0.5849609375


HBox(children=(IntProgress(value=0, description='Epoch 71', max=140, style=ProgressStyle(description_width='in…

Epoch: 71, losses 1.4465597867965698, accuracy 0.55859375


HBox(children=(IntProgress(value=0, description='Epoch 72', max=140, style=ProgressStyle(description_width='in…

Epoch: 72, losses 1.4500224590301514, accuracy 0.5654296875


HBox(children=(IntProgress(value=0, description='Epoch 73', max=140, style=ProgressStyle(description_width='in…

Epoch: 73, losses 1.4511027336120605, accuracy 0.5517578125


HBox(children=(IntProgress(value=0, description='Epoch 74', max=140, style=ProgressStyle(description_width='in…

Epoch: 74, losses 1.4357295036315918, accuracy 0.5458984375


HBox(children=(IntProgress(value=0, description='Epoch 75', max=140, style=ProgressStyle(description_width='in…

Epoch: 75, losses 1.442272424697876, accuracy 0.55859375


HBox(children=(IntProgress(value=0, description='Epoch 76', max=140, style=ProgressStyle(description_width='in…

Epoch: 76, losses 1.4324414730072021, accuracy 0.560546875


HBox(children=(IntProgress(value=0, description='Epoch 77', max=140, style=ProgressStyle(description_width='in…

Epoch: 77, losses 1.4359469413757324, accuracy 0.580078125


HBox(children=(IntProgress(value=0, description='Epoch 78', max=140, style=ProgressStyle(description_width='in…

Epoch: 78, losses 1.4509236812591553, accuracy 0.5703125


HBox(children=(IntProgress(value=0, description='Epoch 79', max=140, style=ProgressStyle(description_width='in…

Epoch: 79, losses 1.4320318698883057, accuracy 0.568359375


HBox(children=(IntProgress(value=0, description='Epoch 80', max=140, style=ProgressStyle(description_width='in…

Epoch: 80, losses 1.438913345336914, accuracy 0.5732421875


HBox(children=(IntProgress(value=0, description='Epoch 81', max=140, style=ProgressStyle(description_width='in…

Epoch: 81, losses 1.4336243867874146, accuracy 0.5712890625


HBox(children=(IntProgress(value=0, description='Epoch 82', max=140, style=ProgressStyle(description_width='in…

Epoch: 82, losses 1.4854052066802979, accuracy 0.556640625


HBox(children=(IntProgress(value=0, description='Epoch 83', max=140, style=ProgressStyle(description_width='in…

Epoch: 83, losses 1.4349448680877686, accuracy 0.560546875


HBox(children=(IntProgress(value=0, description='Epoch 84', max=140, style=ProgressStyle(description_width='in…

Epoch: 84, losses 1.4406530857086182, accuracy 0.5673828125


HBox(children=(IntProgress(value=0, description='Epoch 85', max=140, style=ProgressStyle(description_width='in…

Epoch: 85, losses 1.4281706809997559, accuracy 0.556640625


HBox(children=(IntProgress(value=0, description='Epoch 86', max=140, style=ProgressStyle(description_width='in…

Epoch: 86, losses 1.434483289718628, accuracy 0.5810546875


HBox(children=(IntProgress(value=0, description='Epoch 87', max=140, style=ProgressStyle(description_width='in…

Epoch: 87, losses 1.4200862646102905, accuracy 0.583984375


HBox(children=(IntProgress(value=0, description='Epoch 88', max=140, style=ProgressStyle(description_width='in…

Epoch: 88, losses 1.4266871213912964, accuracy 0.5703125


HBox(children=(IntProgress(value=0, description='Epoch 89', max=140, style=ProgressStyle(description_width='in…

Epoch: 89, losses 1.4264721870422363, accuracy 0.5830078125


HBox(children=(IntProgress(value=0, description='Epoch 90', max=140, style=ProgressStyle(description_width='in…

Epoch: 90, losses 1.421893835067749, accuracy 0.5673828125


HBox(children=(IntProgress(value=0, description='Epoch 91', max=140, style=ProgressStyle(description_width='in…

Epoch: 91, losses 1.4321389198303223, accuracy 0.57421875


HBox(children=(IntProgress(value=0, description='Epoch 92', max=140, style=ProgressStyle(description_width='in…

Epoch: 92, losses 1.4669698476791382, accuracy 0.5546875


HBox(children=(IntProgress(value=0, description='Epoch 93', max=140, style=ProgressStyle(description_width='in…

Epoch: 93, losses 1.4072277545928955, accuracy 0.57421875


HBox(children=(IntProgress(value=0, description='Epoch 94', max=140, style=ProgressStyle(description_width='in…

Epoch: 94, losses 1.4708144664764404, accuracy 0.5615234375


HBox(children=(IntProgress(value=0, description='Epoch 95', max=140, style=ProgressStyle(description_width='in…

Epoch: 95, losses 1.4481927156448364, accuracy 0.5703125


HBox(children=(IntProgress(value=0, description='Epoch 96', max=140, style=ProgressStyle(description_width='in…

Epoch: 96, losses 1.4628201723098755, accuracy 0.5595703125


HBox(children=(IntProgress(value=0, description='Epoch 97', max=140, style=ProgressStyle(description_width='in…

Epoch: 97, losses 1.4198832511901855, accuracy 0.5693359375


HBox(children=(IntProgress(value=0, description='Epoch 98', max=140, style=ProgressStyle(description_width='in…

Epoch: 98, losses 1.4525489807128906, accuracy 0.5576171875


HBox(children=(IntProgress(value=0, description='Epoch 99', max=140, style=ProgressStyle(description_width='in…

Epoch: 99, losses 1.4368534088134766, accuracy 0.5771484375


In [47]:
model.save_weights("model.h5")

In [48]:
def build_sample_model(vocab_size):
    model = Sequential()
    model.add(Embedding(vocab_size, EMBEDDING_SZ, batch_input_shape=(1, 1)))
    for i in range(3):
        model.add(LSTM(LSTM_SZ, return_sequences=(i != 2), stateful=True))
        model.add(Dropout(0.2))

    model.add(Dense(vocab_size))
    model.add(Activation('softmax'))
    return model

pred_model = build_sample_model(len(alphabet))

In [45]:
pred_model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_9 (Embedding)      (1, 1, 128)               9088      
_________________________________________________________________
lstm_25 (LSTM)               (1, 1, 64)                49408     
_________________________________________________________________
dropout_25 (Dropout)         (1, 1, 64)                0         
_________________________________________________________________
lstm_26 (LSTM)               (1, 1, 64)                33024     
_________________________________________________________________
dropout_26 (Dropout)         (1, 1, 64)                0         
_________________________________________________________________
lstm_27 (LSTM)               (1, 64)                   33024     
_________________________________________________________________
dropout_27 (Dropout)         (1, 64)                   0         
__________

In [60]:
pred_model.load_weights("model.h5")
inp = np.zeros((1,1))
inp[0,0] = np.random.randint(len(alphabet))
predicted = []

for i in "The":
    inp[0,0] = alphabet.index(i)
    pred_model.predict_on_batch(inp)
    predicted.append(alphabet.index(i))
    
inp[0,0] = alphabet.index(' ')
predicted.append(alphabet.index(' '))

for i in range(80):
    probs = pred_model.predict_on_batch(inp).ravel()
    pred = np.random.choice(list(range(len(alphabet))), p=probs)
    predicted.append(pred)
    inp[0,0] = pred
    
''.join([alphabet[x] for x in predicted])

'The laved bink younsfor\ntick, and was firsing?’\n\n‘It I’ve knew I do it CITE!\nE garde'