In [49]:
import tensorflow as tf

In [50]:
import  numpy as np

In [51]:
class DataLoader():
    """Loads the Nietzsche corpus and samples fixed-length character windows from it.

    Attributes set in ``__init__``:
        raw_text: the whole corpus, lower-cased.
        chars: sorted list of the distinct characters in ``raw_text``.
        char_indices: mapping character -> integer id (position in ``chars``).
        indices_char: inverse mapping integer id -> character.
        text: the corpus encoded as a list of character ids.
    """

    def __init__(self):
        # get_file returns a local path to the corpus, which is then read as UTF-8.
        path = tf.keras.utils.get_file('nietzsche.txt',
            origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')
        with open(path, encoding='utf-8') as f:
            self.raw_text = f.read().lower()
        self.chars = sorted(set(self.raw_text))
        self.char_indices = {c: i for i, c in enumerate(self.chars)}
        self.indices_char = {i: c for i, c in enumerate(self.chars)}
        self.text = [self.char_indices[c] for c in self.raw_text]

    def get_batch(self, seq_length, batch_size):
        """Sample ``batch_size`` random windows and the character following each one.

        Args:
            seq_length: number of character ids per window.
            batch_size: number of windows to sample.

        Returns:
            (seq, next_char): integer arrays of shape [batch_size, seq_length] and
            [batch_size]; next_char[i] is the id right after window i in the corpus.

        Raises:
            ValueError: if the encoded corpus is not longer than ``seq_length``.
        """
        if len(self.text) <= seq_length:
            raise ValueError("corpus of length %d is too short for seq_length=%d"
                             % (len(self.text), seq_length))
        seq = []
        next_char = []
        for _ in range(batch_size):
            # randint's upper bound is exclusive, so index + seq_length is always a valid index.
            index = np.random.randint(0, len(self.text) - seq_length)
            seq.append(self.text[index:index + seq_length])
            next_char.append(self.text[index + seq_length])
        return np.array(seq), np.array(next_char)       # [batch_size, seq_length], [batch_size]

In [60]:
class RNN(tf.keras.Model):
    """Character-level LSTM that scores the next character after a fixed-length window.

    Args:
        num_chars: vocabulary size; inputs are integer character ids in [0, num_chars).
        batch_size: kept for backward compatibility. The recurrent state is sized from
            the actual input batch, so inputs with a different batch size (e.g. a single
            sequence during generation) are handled correctly.
        seq_length: number of time steps consumed from each input sequence.
    """

    def __init__(self, num_chars, batch_size, seq_length):
        super().__init__()
        self.num_chars = num_chars
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.cell = tf.keras.layers.LSTMCell(units=256)
        self.dense = tf.keras.layers.Dense(units=self.num_chars)

    def call(self, inputs, from_logits=False):
        """Run the LSTM over the first ``seq_length`` steps and score the next character.

        Args:
            inputs: integer ids, assumed shape [batch, >= seq_length].
            from_logits: if True return raw logits, otherwise softmax probabilities.

        Returns:
            Tensor of shape [batch, num_chars].

        Raises:
            ValueError: if ``seq_length`` is smaller than 1 (no step would be run).
        """
        if self.seq_length < 1:
            raise ValueError("seq_length must be >= 1, got %r" % (self.seq_length,))
        inputs = tf.one_hot(inputs, depth=self.num_chars)       # [batch, seq_length, num_chars]
        # Build the zero initial state for the *actual* batch. Using self.batch_size here
        # makes a differently-sized input either fail or silently broadcast against the
        # state. Creating the zeros directly also avoids LSTMCell.get_initial_state,
        # whose signature differs between Keras versions.
        batch = tf.shape(inputs)[0]
        state = [tf.zeros([batch, self.cell.units]), tf.zeros([batch, self.cell.units])]
        for t in range(self.seq_length):
            # Feed step t together with the previous state; keep the new output and state.
            output, state = self.cell(inputs[:, t, :], state)
        logits = self.dense(output)
        if from_logits:                     # caller asked for unnormalised scores
            return logits
        return tf.nn.softmax(logits)

    def predict(self, inputs, temperature=1.):
        """Sample one next-character id for each sequence in ``inputs``.

        This intentionally shadows ``tf.keras.Model.predict``.

        Args:
            inputs: integer ids, assumed shape [batch, seq_length].
            temperature: softmax temperature (> 0); logits are divided by it, so smaller
                values sharpen the distribution and larger values flatten it.

        Returns:
            numpy integer array of shape [batch] with the sampled character ids.

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive, got %r" % (temperature,))
        logits = self(inputs, from_logits=True)
        # Renormalise in float64: np.random.choice rejects p vectors that do not sum to 1.
        prob = tf.nn.softmax(logits / temperature).numpy().astype(np.float64)
        prob /= prob.sum(axis=1, keepdims=True)
        return np.array([np.random.choice(self.num_chars, p=row) for row in prob])

In [61]:
# Load the corpus and build the character vocabulary / encoded text via DataLoader.
data_loader = DataLoader()

In [64]:
num_batches = 20  # number of training iterations; likely far too few for readable samples — raise it for real training
seq_length = 40   # length of each input character window
batch_size = 50   # number of sequences per training batch
learning_rate = 1e-3
seed = 42         # makes batch sampling, weight init and generation repeatable (GPU kernels may still be nondeterministic)
np.random.seed(seed)
tf.random.set_seed(seed)

In [65]:
# Vocabulary size: number of distinct characters in the lower-cased corpus.
len(data_loader.chars)

57

In [66]:
# Train the character-level LSTM with an eager GradientTape loop.
data_loader = DataLoader()
model = RNN(num_chars=len(data_loader.chars), batch_size=batch_size, seq_length=seq_length)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(seq_length, batch_size)
    with tf.GradientTape() as tape:
        # Give the loss raw logits and let it apply softmax internally. This is more
        # numerically stable than taking the cross-entropy of already-softmaxed probabilities.
        logits = model(X, from_logits=True)
        loss = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(
            y_true=y, y_pred=logits, from_logits=True))
    print("batch %d: loss %f" % (batch_index, loss.numpy()))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

batch 0: loss 4.043886
batch 1: loss 4.032377
batch 2: loss 4.017450
batch 3: loss 3.999233
batch 4: loss 3.958704
batch 5: loss 3.873814
batch 6: loss 3.874806
batch 7: loss 3.545548
batch 8: loss 3.199974
batch 9: loss 3.311888
batch 10: loss 2.892821
batch 11: loss 3.078509
batch 12: loss 3.374426
batch 13: loss 3.111451
batch 14: loss 3.406809
batch 15: loss 3.226335
batch 16: loss 3.216307
batch 17: loss 3.150036
batch 18: loss 3.045131
batch 19: loss 3.164841


In [68]:
# Generate 400 characters from one random seed window at four temperatures
# ("diversity"), from small to large; smaller values make sampling more conservative.
X_, _ = data_loader.get_batch(seq_length, 1)
for diversity in (0.2, 0.5, 1.0, 1.2):
    X = X_
    print("diversity %f:" % diversity)
    for _step in range(400):
        y_pred = model.predict(X, diversity)           # sampled id of the next character
        sampled_char = data_loader.indices_char[y_pred[0]]
        print(sampled_char, end='', flush=True)
        # Slide the window: drop the oldest character and append the sampled one,
        # so X keeps its original length.
        X = np.concatenate([X[:, 1:], y_pred[:, np.newaxis]], axis=-1)
    print("\n")

diversity 0.200000:
sststt seshstststthsssstsotsntssstssssatttssssssstssssstsssstssssssssset ss ossssssthsestssaseassttts tssas

KeyboardInterrupt: 

In [58]:
# Scratch check: the seed window used for generation (one row of seq_length ids).
X_

array([[40, 33,  1, 42, 41, 49, 31, 44,  8,  8, 49, 34, 41,  1, 49, 41,
        47, 38, 30,  1, 29, 27, 44, 31,  0, 46, 41,  1, 44, 31, 30, 47,
        29, 31,  1, 39, 31, 40,  1, 46]])

In [26]:
# Scratch check: X without its first column (the window shift used during generation).
# NOTE(review): this cell's execution count is lower than the cells above it, so its
# recorded output reflects out-of-order kernel state.
X[:, 1:]

array([[32,  1, 34, 35, 45,  1, 46, 31, 27, 29, 34, 31, 44,  7,  1, 27,
        40, 30,  1, 27, 28, 41, 48, 31,  1, 27, 38, 38,  1, 46, 41,  1,
        35, 40, 46, 31, 44, 42, 44]])

In [27]:
# Scratch: draw one training batch to inspect its encoding in the following cells.
X, y = data_loader.get_batch(seq_length, batch_size)

In [28]:
# Display the sampled batch of character ids.
X

array([[40, 46, 45, ..., 38, 41, 48],
       [40, 31, 44, ..., 45, 46, 35],
       [ 1, 30, 35, ..., 34, 35, 38],
       ...,
       [41, 40, 46, ..., 38, 29, 47],
       [35, 44, 35, ..., 21,  1, 46],
       [27, 29, 31, ..., 40,  1, 28]])

In [29]:
# Illustrate tf.one_hot: each id becomes a length-23 indicator row.
tf.one_hot([1,2,3], 23)

<tf.Tensor: shape=(3, 23), dtype=float32, numpy=
array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>

In [30]:
# One-hot encode the batch the same way RNN.call does (depth = vocabulary size).
inputs = tf.one_hot(X, depth=len(data_loader.chars))

In [31]:
# Display the encoded batch; expected shape [batch_size, seq_length, vocabulary size].
inputs

<tf.Tensor: shape=(50, 40, 57), dtype=float32, numpy=
array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 1., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 