### Import libraries

In [60]:
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt

### Load Data

In [69]:
def load_dataset(files, text = ''):
     # Load text from files
     for file in files:
          text += open(file, 'rb').read().decode(encoding='UTF-8')
     # Create a sorted list of unique characters
     vocab = sorted(set(text))
     print('Text length:', len(text), 'Unique characters:', len(vocab))
     return text, vocab

def split_input_target(chunk):
     # Split the text into input and target
     input_text = chunk[:-1]
     target_text = chunk[1:]
     return input_text, target_text

In [70]:
text, vocab = load_dataset(['mmebovary.txt', 'thebluecastle.txt'])

Text length: 1060629 Unique characters: 89


### Preprocessing

In [72]:
# Create a mapping from characters to numbers and vice versa
ids_from_chars = tf.keras.layers.StringLookup(
     vocabulary = list(vocab), mask_token = None
)
chars_from_ids = tf.keras.layers.StringLookup(
     vocabulary = ids_from_chars.get_vocabulary(), invert = True, mask_token = None
)

# create sequences
seq_length = 100
sequences = tf.data.Dataset.from_tensor_slices(
     ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))
).batch(seq_length+1, drop_remainder=True)


# Create a dataset of sequences
dataset = (
     sequences.map(split_input_target)
     .shuffle(10000)
     .batch(64, drop_remainder=True)
     .prefetch(tf.data.experimental.AUTOTUNE)
)

### Model

In [73]:
# Build the model
model = tf.keras.Sequential([
     tf.keras.layers.Embedding(len(ids_from_chars.get_vocabulary()), 256, batch_input_shape=[64, None]),
     tf.keras.layers.GRU(1024, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
     tf.keras.layers.Dense(len(ids_from_chars.get_vocabulary()))
])

# Compile the model
model.compile(
     optimizer='adam',
     loss=tf.losses.SparseCategoricalCrossentropy(from_logits = True),
     metrics=['accuracy']
)

history = {
     'loss': [],
     'accuracy': []
}

model.summary()

Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding_4 (Embedding)     (64, None, 256)           23040     
                                                                 
 gru_4 (GRU)                 (64, None, 1024)          3938304   
                                                                 
 dense_4 (Dense)             (64, None, 90)            92250     
                                                                 
Total params: 4,053,594
Trainable params: 4,053,594
Non-trainable params: 0
_________________________________________________________________


### Train & Save the model

In [None]:
# load model
model.load_weights('params/model.h5')
model.summary()

In [None]:
modelLink = "model.h5"


model.fit(dataset, epochs=1)
model.save_weights(modelLink)

for i in range(len(model.history.history["loss"])):
     history["loss"].append(model.history.history["loss"][i])
     history["accuracy"].append(model.history.history["accuracy"][i])


# afficher l'evolution de l'apprentissage
plt.plot(history["accuracy"])

plt.legend(['train accuracy'], loc='lower left')
plt.title('accuracy')

plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.ylim(0, 1)

# ligne au verticale au point de changement de dataset
plt.axvline(len(history["loss"]) - len(model.history.history["loss"]) - 1, color='r', linestyle='--')
plt.show()

# afficher l'evolution de la perte
plt.plot(history["loss"])

plt.legend(['train loss'], loc='upper left')
plt.title('loss')

plt.xlabel('epoch')
plt.ylabel('loss')

# ligne au verticale au point de changement de dataset
plt.axvline(len(history["loss"]) - len(model.history.history["loss"]) - 1, color='r', linestyle ='--')
plt.show()

### Predict

In [35]:
def predict(text):
     # Convert the text to numbers
     input_eval = [ids_from_chars(v) for v in text]
     input_eval = tf.expand_dims(input_eval, 0)
     # Predict the next character
     res = model.predict(input_eval)
     res = tf.argmax(res[0], axis=1).numpy()
     res = chars_from_ids(res).numpy()
     return res[len(res)-1].decode("utf-8")

In [None]:
initial_text = "The men"
textLen = len(initial_text)

In [None]:
# Generate x characters
for i in range(100):
     pred = predict(initial_text[-textLen:])
     initial_text += pred

In [None]:
print(initial_text)