In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
from keras.layers import *
from keras.models import Sequential, Model
from keras import backend as K
from keras import metrics
import re

In [None]:
def make_one_hot(index, count):
    """Return a list of ``count`` zeros with a single 1 at position ``index``.

    ``index`` follows ordinary list indexing: negative values count from the
    end, and an out-of-range index raises IndexError.
    """
    vector = [0 for _ in range(count)]
    vector[index] = 1
    return vector

## Text processing

In [None]:
class TextController:
    """Normalise a raw text and expose it as one-hot encoded character windows.

    Args:
        text: raw input text.
        seq_length: length (in characters) of every window in ``sequences``.

    Attributes:
        text: ``text`` lower-cased, with runs of newlines, carriage returns,
            tabs and spaces each replaced by a single space.
        chars: sorted list of the distinct characters of ``self.text``.
        n_vocab: ``len(chars)``.
        char_to_int / int_to_char: character <-> position-in-``chars`` maps.
        sequences: int array of shape (n_windows, seq_length, n_vocab) where
            entry ``i`` is ``self.text[i:i + seq_length]`` one-hot encoded.
    """

    def __init__(self, text, seq_length):
        self.seq_length = seq_length
        self.text = self.format_text(text)
        self.chars = sorted(set(self.text))
        self.n_vocab = len(self.chars)
        self.char_to_int = {char: i for i, char in enumerate(self.chars)}
        self.int_to_char = dict(enumerate(self.chars))
        self.sequences = self.make_sequences()

    def format_text(self, text):
        """Lower-case ``text`` and collapse newline/CR/tab/space runs to one space."""
        text = text.lower()
        # Applied in order: line breaks and tabs become spaces first so the
        # last rule can collapse the resulting runs of spaces.
        format_items = [
            {'from': '\n+', 'to': ' '},
            {'from': '\r+', 'to': ' '},
            {'from': '\t+', 'to': ' '},
            {'from': ' +', 'to': ' '},
        ]
        for format_item in format_items:
            text = re.sub(format_item['from'], format_item['to'], text)
        return text

    def make_sequences(self):
        """Return every length-``seq_length`` window of ``self.text``, one-hot encoded.

        Windows start at every offset 0 .. len(text) - seq_length inclusive, so
        the final full window is included. Returns an int array of shape
        (n_windows, seq_length, n_vocab); n_windows is 0 when the text is
        shorter than ``seq_length``.
        """
        n_windows = max(0, len(self.text) - self.seq_length + 1)
        codes = np.array(self.chars2nums(self.text), dtype=int)
        # Row i of `offsets` holds the text positions i, i+1, ..., i+seq_length-1.
        offsets = np.arange(n_windows)[:, None] + np.arange(self.seq_length)[None, :]
        windows = codes[offsets]
        # Indexing rows of the identity matrix one-hot encodes every index at once.
        return np.eye(self.n_vocab, dtype=int)[windows]

    def chars2nums(self, chars):
        """Map each character of ``chars`` to its integer index (KeyError if unknown)."""
        return [self.char_to_int[char] for char in chars]

    def nums2chars(self, nums):
        """Map each integer index in ``nums`` back to its character."""
        return [self.int_to_char[num] for num in nums]

    def nums2str(self, nums):
        """Decode a sequence of integer indices into a string."""
        return ''.join(self.nums2chars(nums))

In [None]:
# Path is relative to the notebook's working directory.
file = 'robinson_crusoe.txt'
seq_length = 20
# Only the first MAX_CHARS characters of the raw file (before formatting) are used.
MAX_CHARS = 9999
# NOTE(review): uses the platform default encoding -- pass encoding= explicitly
# once the file's encoding is confirmed.
with open(file) as fh:
    text = fh.read()[:MAX_CHARS]
TC = TextController(text, seq_length)

## Text generator (conditional LSTM variational autoencoder)

In [None]:
class TextGenerator:
    """Conditional LSTM variational autoencoder over one-hot character windows.

    The encoder LSTM maps a (timesteps, original_dim) window to the mean and
    log-variance of a ``z_dim``-dimensional Gaussian. A sample ``z`` is
    concatenated with a ``c_dim``-dimensional condition input ``c`` and decoded
    back to a window of per-character sigmoid outputs.

    Args:
        n_chars: size of the one-hot alphabet (last axis of each window).
        seq_length: number of time steps per window.
        batch_size: mini-batch size used by ``train``.
        z_dim: latent dimensionality.
        c_dim: dimensionality of the condition input.
        intermediate_dim: number of units in the encoder and decoder LSTMs.
        epsilon_std: standard deviation of the sampling noise.
    """

    def __init__(self, n_chars, seq_length, batch_size=1, z_dim=100, c_dim=1,
                 intermediate_dim=1000, epsilon_std=1):
        self.batch_size = batch_size
        self.timesteps = seq_length
        self.original_dim = n_chars
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.intermediate_dim = intermediate_dim
        self.epsilon_std = epsilon_std
        self.build_model()

    def build_model(self):
        """Wire encoder -> sampling -> [z, c] -> decoder and compile ``self.model``."""
        self.c = Input(shape=(self.c_dim,))
        self.x, self.z_mean, self.z_log_sigma = self.build_encoder()
        self.z = Lambda(self.sampling, output_shape=(self.z_dim,))([self.z_mean, self.z_log_sigma])
        self.z_c = concatenate([self.z, self.c])
        self.x_gen = self.build_generator()
        self.model = Model([self.x, self.c], self.x_gen)
        self.model.compile(optimizer='rmsprop', loss=self.vae_loss)

    def sampling(self, args):
        """Reparameterised sample z = mean + std * epsilon.

        ``z_log_sigma`` is treated as a log-variance (matching the KL term in
        ``vae_loss``), so the standard deviation is exp(0.5 * z_log_sigma).
        """
        z_mean, z_log_sigma = args
        batch_size = K.shape(z_mean)[0]
        epsilon = K.random_normal(shape=(batch_size, self.z_dim), mean=0., stddev=self.epsilon_std)
        return z_mean + K.exp(0.5 * z_log_sigma) * epsilon

    def build_encoder(self):
        """Return the encoder input and the (z_mean, z_log_sigma) tensors."""
        x = Input(shape=(self.timesteps, self.original_dim))
        h = LSTM(self.intermediate_dim, activation='relu')(x)
        z_mean = Dense(self.z_dim)(h)
        # Name kept for compatibility; used as a log-variance throughout.
        z_log_sigma = Dense(self.z_dim)(h)
        return x, z_mean, z_log_sigma

    def build_generator(self):
        """Decode ``self.z_c`` into a (timesteps, original_dim) sigmoid output."""
        repeated = RepeatVector(self.timesteps)(self.z_c)
        decoder_h = LSTM(self.intermediate_dim, activation='relu', return_sequences=True)
        decoder_mean = TimeDistributed(Dense(self.original_dim, activation='sigmoid'))
        h_decoded = decoder_h(repeated)
        x_decoded_mean = decoder_mean(h_decoded)
        return x_decoded_mean

    def vae_loss(self, x, x_decoded_mean):
        """Reconstruction cross-entropy plus KL divergence to a standard normal.

        Both terms are reduced to shape (batch,) before being added, so the
        loss is correct for any batch size.
        """
        # binary_crossentropy reduces the feature axis -> (batch, timesteps);
        # average over timesteps so it lines up with kl_loss's (batch,) shape.
        xent_loss = self.original_dim * K.mean(metrics.binary_crossentropy(x, x_decoded_mean), axis=-1)
        kl_loss = - 0.5 * K.sum(1 + self.z_log_sigma - K.square(self.z_mean) - K.exp(self.z_log_sigma), axis=-1)
        # NOTE(review): closing over encoder tensors inside a loss function is the
        # legacy Keras graph-mode pattern; newer tf.keras may require model.add_loss.
        return K.mean(xent_loss + kl_loss)

    def train(self, x_train, c_train, epochs):
        """Fit the autoencoder to reconstruct ``x_train`` given conditions ``c_train``."""
        x_train = np.array(x_train)
        c_train = np.array(c_train)
        self.model.fit(x=[x_train, c_train], y=x_train, batch_size=self.batch_size, epochs=epochs)

    def predict(self, x, c):
        """Reconstruct a single window ``x`` under condition ``c``; returns one window."""
        x = np.array([x])
        c = np.array([c])
        return self.model.predict([x, c])[0]

In [None]:
# One input per vocabulary character; windows of seq_length time steps.
TG = TextGenerator(n_chars=TC.n_vocab, seq_length=seq_length)

In [None]:
# Every seq_length-th window, i.e. windows starting at offsets 0, seq_length, ...
# which therefore do not overlap.
x_train = TC.sequences[::seq_length]
# One condition value per window; every window uses condition 0.
c_train = [0 for _ in range(len(x_train))]

In [None]:
TG.train(x_train, c_train, epochs=5)

In [None]:
# Reconstruct one training window under its own condition and decode it to text.
sample_idx = 1
predictions = TG.predict(x_train[sample_idx], c_train[sample_idx])
# Most probable character index at each time step.
indexes = predictions.argmax(axis=-1).tolist()
TC.nums2str(indexes)