1. 数据集准备

In [1]:
# Open the corpus file and read its whole content into one string (decoded as UTF-8)
with open('alice_in_wonderland.txt', 'r', encoding='utf-8') as file:
    text = file.read()

In [2]:
# Report the corpus size in characters
print('Text length:', len(text))

Text length: 141518


In [3]:
# Preview the first 1000 characters of the corpus
text[0:1000]

"\u3000\u3000\n\n\u3000\u3000ALICE'S ADVENTURES IN WONDERLAND\n\n\u3000\u3000Lewis Carroll\n\n\u3000\u3000CHAPTER I\n\n\u3000\u3000Down the Rabbit-Hole\n\n\u3000\u3000Alice was beginning to get very tired of sitting by her sisteron the bank, and of having nothing to do:  once or twice she hadpeeped into the book her sister was reading, but it had nopictures or conversations in it, `and what is the use of a book,'thought Alice `without pictures or conversation?'So she was considering in her own mind (as well as she could,for the hot day made her feel very sleepy and stupid), whetherthe pleasure of making a daisy-chain would be worth the troubleof getting up and picking the daisies, when suddenly a WhiteRabbit with pink eyes ran close by her.There was nothing so VERY remarkable in that; nor did Alicethink it so VERY much out of the way to hear the Rabbit say toitself, `Oh dear!  Oh dear!  I shall be late!'  (when she thoughtit over afterwards, it occurred to her that she ought to havewon

In [4]:
# Sliding-window settings
sequence_length = 60  # number of characters in each input segment
step = 3  # distance between the starts of consecutive segments

# Containers for the training pairs
segments, next_chars = [], []  # 60-character segments / the character following each one

In [5]:
# Build the training pairs: every `step` characters, take a window of
# `sequence_length` characters together with the single character that follows it.
# Both lists are rebuilt from scratch instead of being appended to, so running
# this cell more than once cannot duplicate samples.
window_starts = range(0, len(text) - sequence_length, step)
segments = [text[i:i + sequence_length] for i in window_starts]
next_chars = [text[i + sequence_length] for i in window_starts]

# Print a few values as a sanity check
print(f"Number of sequences: {len(segments)}")
print(f"Example segment: {segments[50]}")
print(f"Example next char: {next_chars[50]}")

Number of sequences: 47153
Example segment: r sisteron the bank, and of having nothing to do:  once or t
Example next char: w


2. Character to Vector

这里不需要 Word Embedding。此前的任务采用 word-level tokenization，常用英文单词约有 10000 个，one-hot 向量维度太高，因此需要用 word embedding 将其降维为低维词向量；而文本生成采用 character-level tokenization，用到的字符只有几十到一百个左右，one-hot 向量维度不高，可以直接作为输入。

In [6]:
import numpy as np

In [7]:
# Build the character vocabulary: the distinct characters of the corpus in sorted order
chars = sorted(set(text))
# Forward lookup: character -> integer index
char_to_index = dict(zip(chars, range(len(chars))))
# Reverse lookup: integer index -> character
index_to_char = dict(enumerate(chars))

In [8]:
# Vocabulary size, i.e. the number of distinct characters in the corpus
num_chars=len(char_to_index)
num_chars

71

In [9]:
num_sequences = len(segments)
# Allocate the one-hot input tensor, indexed as (sample, time step, character index).
# The builtin `bool` is used as dtype because the `np.bool` alias does not exist in
# NumPy 1.24-1.26, whereas `bool` is accepted by every NumPy version.
X = np.zeros((num_sequences, sequence_length, num_chars), dtype=bool)

In [10]:
# Allocate the one-hot target matrix, indexed as (sample, character index).
# The builtin `bool` is used as dtype because the `np.bool` alias does not exist in
# NumPy 1.24-1.26, whereas `bool` is accepted by every NumPy version.
y = np.zeros((num_sequences, num_chars), dtype=bool)

In [11]:
# One-hot encode every (segment, next character) pair into X and y
for row, (window, target) in enumerate(zip(segments, next_chars)):
    # Input: mark the character found at each position of the window
    for pos, ch in enumerate(window):
        X[row, pos, char_to_index[ch]] = 1
    # Target: mark the character that follows the window
    y[row, char_to_index[target]] = 1

# Show one encoded sample as a sanity check
print("Example input sequence (One-Hot encoded):")
print(X[50])
print("Example target character (One-Hot encoded):")
print(y[50])

Example input sequence (One-Hot encoded):
[[False False False ... False False False]
 [False  True False ... False False False]
 [False False False ... False False False]
 ...
 [False False False ... False False False]
 [False  True False ... False False False]
 [False False False ... False False False]]
Example target character (One-Hot encoded):
[False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False  True False False False False]


3. Build a Neural Network

In [12]:
import tensorflow as tf
from tensorflow.keras import optimizers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Input

In [13]:
# Character-level next-character predictor, assembled layer by layer
model = Sequential()
# Each sample is a (sequence_length, num_chars) one-hot matrix: 60 time steps x 71 characters
model.add(Input(shape=(sequence_length, num_chars)))
# Recurrent layer with a 128-dimensional state vector
model.add(LSTM(128))
# Fully connected softmax head: one probability per character of the vocabulary
model.add(Dense(num_chars, activation="softmax"))
model.summary()

In [14]:
# RMSprop optimizer with a learning rate of 0.01
optimizer = optimizers.RMSprop(learning_rate=0.01)

In [15]:
# Categorical cross-entropy matches the one-hot targets and the softmax output layer
model.compile(loss='categorical_crossentropy', optimizer=optimizer)

4. 训练神经网络

In [16]:
# Train on the full (X, y) set for 5 epochs with mini-batches of 128 samples;
# the returned History object is displayed as the cell output
model.fit(X, y, batch_size=128, epochs=5)

Epoch 1/5
[1m369/369[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 55ms/step - loss: 2.8125
Epoch 2/5
[1m369/369[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 54ms/step - loss: 1.9719
Epoch 3/5
[1m369/369[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 54ms/step - loss: 1.7496
Epoch 4/5
[1m369/369[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 54ms/step - loss: 1.6083
Epoch 5/5
[1m369/369[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 53ms/step - loss: 1.4793


<keras.src.callbacks.history.History at 0x1f46c71d1f0>

In [17]:
# Persist the trained model to a .keras file in the working directory
model.save("alice_generator.keras")

5. 预测下一个字符

In [18]:
def sample_with_temperature(preds, temperature=1.0):
    """
    Draw a character index from a predicted probability distribution.

    The distribution is reshaped as ``p_i ** (1 / temperature)`` and renormalized
    before sampling. The reshaping is evaluated in log space, which is
    mathematically the same but cannot underflow to an all-zero vector (and thus
    to NaN probabilities) when the temperature is very small.

    :param preds: 1-D array-like of non-negative probabilities predicted by the model
    :param temperature: diversity control; values below 1 sharpen the distribution,
        values above 1 flatten it, and 0 deterministically selects the most
        probable index
    :return: index of the selected entry of ``preds``
    :raises ValueError: if ``temperature`` is negative
    """
    preds = np.asarray(preds).astype('float64')
    if temperature < 0:
        raise ValueError("temperature must be non-negative")
    if temperature == 0:
        # Limit of the sharpening formula as temperature -> 0; also avoids 1 / 0
        return np.argmax(preds)

    # Step 1: adjust the sharpness. log(p) / T is the logarithm of p ** (1 / T).
    # Zero probabilities map to -inf here and back to exactly 0 after exp().
    with np.errstate(divide='ignore'):
        scaled = np.log(preds) / temperature
    # Shift so the largest entry is 0: exp() then yields at least one 1.0
    scaled = scaled - np.max(scaled)
    preds = np.exp(scaled)
    # Step 2: renormalize into a probability distribution
    preds = preds / np.sum(preds)

    # Draw a single one-hot sample and return the position of its 1
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)

In [19]:
# Text generation helper
def generate_text(model, seed_text, length=400, temperature=0.5):
    """
    Generate text one character at a time, starting from a seed string.

    On every step the current context window is one-hot encoded, passed to
    ``model.predict``, and the next character is drawn with
    ``sample_with_temperature``; the window then slides forward by one character.

    :param model: object whose ``predict`` accepts an array of shape
        (1, sequence_length, num_chars) and returns one row of probabilities
    :param seed_text: starting text; if it is longer than ``sequence_length``
        only its last ``sequence_length`` characters are fed to the model
    :param length: number of characters to generate
    :param temperature: forwarded to ``sample_with_temperature``
    :return: ``seed_text`` followed by the ``length`` generated characters
    :raises KeyError: if the context contains a character missing from ``char_to_index``
    """
    # Keep only the most recent characters: a longer context would index past
    # the time axis of x_pred and raise IndexError.
    sentence = seed_text[-sequence_length:]
    new_chars = []  # collected in a list and joined once at the end

    for _ in range(length):
        # One-hot encode the current context window as a batch of size 1
        x_pred = np.zeros((1, sequence_length, num_chars))
        for t, char in enumerate(sentence):
            x_pred[0, t, char_to_index[char]] = 1

        # [0] selects the single row of the batch
        preds = model.predict(x_pred, verbose=0)[0]
        next_index = sample_with_temperature(preds, temperature)
        next_char = index_to_char[next_index]

        new_chars.append(next_char)
        # Slide the window: drop the oldest character, append the new one
        sentence = sentence[1:] + next_char

    return seed_text + ''.join(new_chars)

In [20]:
# Prepare the seed text; its length is displayed to check it against the segment length
seed_text = "Alice was captured by Queen when the rabbit says: `why do cr"
len(seed_text)

60

In [21]:
# Sanity check: predict the probability distribution of the character following the seed
# (the window length comes from `sequence_length` instead of a hard-coded 60, and the
# builtin `bool` replaces `np.bool`, which does not exist in NumPy 1.24-1.26)
X_in = np.zeros((1, sequence_length, num_chars), dtype=bool)  # one-hot input, batch of 1
for t, char in enumerate(seed_text):
    X_in[0, t, char_to_index[char]] = 1

# verbose=0 suppresses the progress log; [0] selects the only row because the
# first axis of the output is the batch size
preds = model.predict(X_in, verbose=0)[0]

preds

array([4.05739083e-05, 3.05844995e-04, 4.69993647e-05, 4.53846906e-06,
       4.09607719e-06, 4.17648198e-06, 3.88996341e-06, 3.11852114e-06,
       8.10987258e-05, 1.26622954e-05, 3.94520575e-05, 3.85735220e-05,
       1.04751516e-05, 7.66049998e-05, 4.96368040e-04, 2.97916231e-05,
       1.78040555e-05, 6.07294896e-05, 1.22988538e-04, 8.10615802e-06,
       3.41269952e-05, 2.03033182e-04, 1.55877133e-04, 4.59129251e-06,
       1.10423771e-05, 8.92318349e-05, 8.66631453e-05, 2.77621904e-04,
       7.29588326e-04, 2.82847705e-05, 1.14942413e-05, 1.01993253e-04,
       7.54460489e-05, 9.49585374e-05, 2.05365068e-04, 1.45691829e-05,
       2.41421476e-05, 5.93641971e-06, 1.20172284e-04, 7.50842582e-06,
       6.50267430e-06, 3.99120654e-06, 5.26177428e-06, 3.53123978e-05,
       5.02086133e-02, 5.69160002e-05, 1.44778924e-05, 8.35243839e-07,
       2.99240619e-01, 9.93505364e-06, 1.24537346e-05, 6.25295856e-04,
       3.86385500e-01, 6.49239882e-05, 1.99518436e-05, 2.85638584e-04,
      

In [22]:
# Sample one index from the predicted distribution at temperature 0.5 and map it back to its character
next_index = sample_with_temperature(preds, temperature=0.5)
next_char = index_to_char[next_index]
next_char

'i'

In [23]:
# Generate 400 characters from the first seed at temperature 0.5
generated_text = generate_text(model, seed_text, length=400, temperature=0.5)
print("Generated Text:")
print(generated_text)

Generated Text:
Alice was captured by Queen when the rabbit says: `why do cried the same said the was ot the tood her sist and the gat of the Hatter went on the to see the face.`What you cale the great dreat breet to be the same wittle of her been a was a got to leaved repted to hard again.`I'll you thing I sit!' he did not she to do upened the was the was the was of the formout the gat of the was the long at the reat on the thing of it out the gat of the words the more o


In [24]:
# Generate 400 characters from the first seed again, this time at temperature 0.3
generated_text = generate_text(model, seed_text, length=400, temperature=0.3)
print("Generated Text:")
print(generated_text)

Generated Text:
Alice was captured by Queen when the rabbit says: `why do cried the same been a nittle was the rept of the long and the same as she was the court, and was the tried to herself and the long and was the time she was not a little had for first this said to herself and she was the reat of the was on the same to lear her at the same and the was on the game again.`I with her so cause it as she was to herself to her her head of the was the other seepting to herse


In [25]:
# Second seed text
seed_text2 = "Alice was beginning to get very tired of sitting by her sist"

In [26]:
# Generate 400 characters from the second seed at temperature 0.3
generated_text = generate_text(model, seed_text2, length=400, temperature=0.3)
print("Generated Text:")
print(generated_text)

Generated Text:
Alice was beginning to get very tired of sitting by her sister the on the to see found her her head it was the had had for she spilled at the Caterpillar of the long of the gaterpillar to her a thing the door of the gating to her her hear her sister while the gat of the game of the at her finish and the rest of the ratter of the tried.`I'm a great be a ground of the matter sisterplied to herself was the one of the gater the same been a little a began inthe


In [27]:
# Generate 400 characters from the second seed again, this time at temperature 0.5
generated_text = generate_text(model, seed_text2, length=400, temperature=0.5)
print("Generated Text:")
print(generated_text)

Generated Text:
Alice was beginning to get very tired of sitting by her sister her here without of the Rabbit was the was not to like a the sway of the moment had faning oneer the rest hard.  `She parious at a must a mind, and the crowers, and she was the got of the get arden about at a little firit out the reat of the ratter little first again, and the sent of the reppining to herself and the ormouse again, and she had was a get inthe back inthe now, and the sea.'`I'm th
