In [1]:
import keras
import numpy as np

path = keras.utils.get_file(
    '/Users/bifnudozhao/Projects/ai-playground/datasets/nietzshe.txt',
    origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt'
)

text = open(path).read().lower()
print('Corpus length: ', len(text))

Corpus length:  600893


获取语料之后，先使用 `maxlen` 产生相互重叠的序列，用 one-hot 编码，然后将它们组织为一个形状为 `(sequences, maxlen, unique_charaters)` 的 numpy 数组。

In [2]:
maxlen = 60
step = 3 # sample a new sequence every three characters
sentences = []
next_chars = []

for i in range(0, len(text) - maxlen, step):
    sentences.append(text[i: i + maxlen])
    next_chars.append(text[i + maxlen])

print('Number of sequences: ', len(sentences))

chars = sorted(list(set(text)))
print('Unique characters: ', len(chars))
char_indices = dict((char, chars.index(char)) for char in chars)

x = np.zeros((len(sentences), maxlen, len(chars)), dtype=bool)
y = np.zeros((len(sentences), len(chars)), dtype=bool)

for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        x[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1

Number of sequences:  200278
Unique characters:  57


In [5]:
from keras import layers
import tensorflow as tf

model = keras.models.Sequential()
model.add(tf.compat.v1.keras.layers.CuDNNLSTM(128, input_shape=(maxlen, len(chars))))
model.add(layers.Dense(len(chars), activation='softmax'))

optimizer = tf.keras.optimizers.legacy.RMSprop(learning_rate=0.01)
model.compile(loss='categorical_crossentropy', optimizer=optimizer)

In [6]:
def sample(preds, temperature=1.0):
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)

In [7]:
import random
import sys

for epoch in range(1, 61):
    print('epoch', epoch)
    # fits the model for one iteration on the data
    model.fit(x, y, batch_size=128, epochs=1)

    # only outputs temperal result at epoch 10, 20, 30, 40, 50, 60
    if epoch % 10 != 0: continue

    start_index = random.randint(0, len(text) - maxlen - 1)
    generated_text = text[start_index:start_index + maxlen]
    print('--- Generateing with seed: ', generated_text)

    for temperature in [0.2, 0.5, 1.0, 1.2]:
        print('----- temperature: ', temperature)
        sys.stdout.write(generated_text)

        # generates 400 characters, starting from the seed text
        for i in range(400):
            sampled = np.zeros((1, maxlen, len(chars)))
            for t, char in enumerate(generated_text):
                sampled[0, t, char_indices[char]] = 1.

            preds = model.predict(sampled, verbose=0)[0]
            next_index = sample(preds, temperature)
            next_char = chars[next_index]

            generated_text += next_char
            generated_text = generated_text[1:]

            sys.stdout.write(next_char)

epoch 1


2023-10-06 11:13:45.693781: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


--- Generateing with seed:   heaven of clear, wicked spirituality,
which, from above, wo
----- temperature:  0.2
 heaven of clear, wicked spirituality,
which, from above, would in the deation of the seem from the sense and such the called and moral sould and in the consequents and the seems and in the sense and and be seem from the delight of the seem the understand in the seems and the seals and the consequents of the seems and the sender and in the seems the sense of the seems and the seem the realing of the sense of the sense and in the allower called and in the s----- temperature:  0.5
he sense of the sense and in the allower called and in the sacred and say, which the course of may the precessation. the deleation of the all one of seems in the seched. and it has and the compatised preciles the faith in a relects for the same in the sense of the deflain in the great and prease in the a faidity of the and the man which religion, all the a decesss of the compired of an all in a could

KeyboardInterrupt: 