In [74]:
import tensorflow as tf
import numpy as np

In [75]:
class DataLoader():
    """Character-level view of the Nietzsche corpus for next-character prediction.

    Attributes set by the constructor:
        raw_text: the lower-cased corpus as a single string.
        chars: sorted list of the distinct characters in ``raw_text``.
        char_indices: maps each character to its position in ``chars``.
        indices_char: maps each position in ``chars`` back to its character.
        text: ``raw_text`` encoded as a list of integer character ids.
    """

    def __init__(self):
        # Resolve a local path for the corpus through tf.keras.utils.get_file.
        corpus_path = tf.keras.utils.get_file(
            'nietzsche.txt',
            origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')
        with open(corpus_path, encoding='utf-8') as corpus_file:
            self.raw_text = corpus_file.read().lower()

        # Vocabulary and the two lookup tables (char -> id, id -> char).
        self.chars = sorted(set(self.raw_text))
        self.char_indices = {ch: idx for idx, ch in enumerate(self.chars)}
        self.indices_char = dict(enumerate(self.chars))

        # The whole corpus as a list of character ids.
        self.text = [self.char_indices[ch] for ch in self.raw_text]

    def get_batch(self, seq_length, batch_size):
        """Sample random windows of character ids and the id that follows each window.

        Training on a batch lets the loss be averaged over several samples,
        which damps the influence of any single outlier.

        Args:
            seq_length: number of character ids in each sampled window.
            batch_size: number of windows to sample.

        Returns:
            A pair ``(seq, next_char)`` of NumPy arrays with shapes
            ``[batch_size, seq_length]`` and ``[batch_size]``.
        """
        # randint's upper bound is exclusive, so start + seq_length is always a valid index.
        upper = len(self.text) - seq_length
        starts = [np.random.randint(0, upper) for _ in range(batch_size)]
        windows = [self.text[start:start + seq_length] for start in starts]
        targets = [self.text[start + seq_length] for start in starts]
        return np.array(windows), np.array(targets)

In [76]:
class RNN(tf.keras.Model):
    """Character-level LSTM that predicts the next character id from a window of ids.

    Args:
        num_chars: vocabulary size; used as the one-hot depth and as the
            number of output logits.
        batch_size: nominal batch size. Stored for backward compatibility;
            ``call`` sizes the recurrent state from the actual input instead.
        seq_length: nominal window length. Only used by ``call`` as a fallback
            when the input's time dimension is not statically known.

    Note:
        ``predict`` below replaces ``tf.keras.Model.predict`` with a sampling
        method that has a different signature and return value.
    """

    def __init__(self,num_chars,batch_size,seq_length):
        super().__init__()
        self.num_chars = num_chars
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.cell = tf.keras.layers.LSTMCell(units = 256)
        self.dense = tf.keras.layers.Dense(units = self.num_chars)

    def call(self,inputs,from_logits = False):
        """Run the LSTM over a batch of id windows and score the next character.

        Args:
            inputs: integer character ids, indexed as [batch, time].
            from_logits: if True return raw logits, otherwise softmax probabilities.

        Returns:
            A tensor with one row of ``num_chars`` scores per input window.
        """
        x = tf.one_hot(inputs,depth = self.num_chars)    # [batch, time, num_chars]

        # Size the initial state from the real input. Using the constructor's
        # batch_size here made a batch of 1 (as used for text generation) run
        # against a 50-row state and only work by accidental broadcasting.
        batch_size = x.shape[0]
        if batch_size is None:
            batch_size = tf.shape(x)[0]

        # Walk the whole window that was passed in; fall back to the configured
        # length only when the time dimension is unknown (a Python loop needs an int).
        num_steps = x.shape[1]
        if num_steps is None:
            num_steps = self.seq_length

        state = self.cell.get_initial_state(batch_size = batch_size,dtype = tf.float32)
        for t in range(num_steps):
            output,state = self.cell(x[:,t,:],state)

        # Only the output of the last time step is used for the prediction.
        logits = self.dense(output)
        if from_logits:
            return logits
        return tf.nn.softmax(logits)

    def predict(self, inputs, temperature=1.):
        """Sample the next character id for each window in ``inputs``.

        Args:
            inputs: integer character ids, indexed as [batch, time].
            temperature: divisor applied to the logits before the softmax;
                smaller values concentrate the distribution, larger values
                flatten it.

        Returns:
            A NumPy array holding one sampled character id per input window.
        """
        logits = self(inputs, from_logits=True)                  # scores from the trained model
        prob = tf.nn.softmax(logits / temperature).numpy()       # temperature-scaled distribution
        # ``call`` now yields exactly one row per input window, so sample once per row.
        return np.array([np.random.choice(self.num_chars, p=row) for row in prob])
        
    
    

In [77]:
num_batches = 1000    # number of training steps (one mini-batch per step)
seq_length = 40       # length of each input character window
batch_size = 50       # number of windows per mini-batch
learning_rate = 1e-3  # Adam learning rate

# Batch sampling and character sampling use NumPy's global RNG, and weight
# initialisation uses TensorFlow's; seed both so a fresh
# "Restart Kernel -> Run All" reproduces the same run.
seed = 42
np.random.seed(seed)
tf.random.set_seed(seed)

In [78]:
# Build the data source, the model and the optimizer, then run num_batches
# gradient steps, each on a freshly sampled mini-batch.
data_loader = DataLoader()
model = RNN(num_chars = len(data_loader.chars),batch_size = batch_size,seq_length = seq_length)
optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate)
for batch_index in range(num_batches):
    X,y = data_loader.get_batch(seq_length,batch_size)
    with tf.GradientTape() as tape:
        # Only the forward pass and the loss need to be recorded by the tape.
        y_pred = model(X)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true = y,y_pred = y_pred)
        loss = tf.reduce_mean(loss)
    # Logging happens outside the tape context; it is not part of the computation.
    print("batch_index %d,loss %f" % (batch_index,loss.numpy()))
    # Differentiate with respect to trainable variables only: a non-trainable
    # variable would produce a None gradient and must not reach the optimizer.
    grads = tape.gradient(loss,model.trainable_variables)
    optimizer.apply_gradients(grads_and_vars = zip(grads,model.trainable_variables))


batch_index 0,loss 4.046930
batch_index 1,loss 4.023759
batch_index 2,loss 4.012225
batch_index 3,loss 3.974755
batch_index 4,loss 3.959831
batch_index 5,loss 3.888484
batch_index 6,loss 3.741380
batch_index 7,loss 3.456250
batch_index 8,loss 3.312930
batch_index 9,loss 3.769981
batch_index 10,loss 3.195498
batch_index 11,loss 3.174865
batch_index 12,loss 3.027397
batch_index 13,loss 3.234847
batch_index 14,loss 3.376699
batch_index 15,loss 3.104056
batch_index 16,loss 3.066119
batch_index 17,loss 3.109254
batch_index 18,loss 3.186244
batch_index 19,loss 2.985195
batch_index 20,loss 3.010646
batch_index 21,loss 3.104559
batch_index 22,loss 3.439594
batch_index 23,loss 3.189495
batch_index 24,loss 3.137863
batch_index 25,loss 2.824996
batch_index 26,loss 2.992568
batch_index 27,loss 2.920878
batch_index 28,loss 2.891249
batch_index 29,loss 3.231491
batch_index 30,loss 3.112705
batch_index 31,loss 3.126866
batch_index 32,loss 3.228296
batch_index 33,loss 2.992580
batch_index 34,loss 3.15

batch_index 277,loss 2.823972
batch_index 278,loss 2.827854
batch_index 279,loss 2.732025
batch_index 280,loss 2.975671
batch_index 281,loss 2.864226
batch_index 282,loss 2.801941
batch_index 283,loss 2.881842
batch_index 284,loss 2.798435
batch_index 285,loss 3.143430
batch_index 286,loss 2.875626
batch_index 287,loss 2.619751
batch_index 288,loss 3.037164
batch_index 289,loss 2.388178
batch_index 290,loss 3.079540
batch_index 291,loss 2.711489
batch_index 292,loss 2.895314
batch_index 293,loss 2.693599
batch_index 294,loss 2.480186
batch_index 295,loss 3.090853
batch_index 296,loss 2.703168
batch_index 297,loss 2.688894
batch_index 298,loss 2.977446
batch_index 299,loss 2.895598
batch_index 300,loss 2.617867
batch_index 301,loss 2.795693
batch_index 302,loss 2.629225
batch_index 303,loss 2.727459
batch_index 304,loss 2.603723
batch_index 305,loss 2.881585
batch_index 306,loss 2.754404
batch_index 307,loss 3.038110
batch_index 308,loss 2.769368
batch_index 309,loss 2.799619
batch_inde

batch_index 551,loss 2.605426
batch_index 552,loss 2.433233
batch_index 553,loss 2.397949
batch_index 554,loss 2.604249
batch_index 555,loss 2.657391
batch_index 556,loss 2.468696
batch_index 557,loss 2.846173
batch_index 558,loss 2.325582
batch_index 559,loss 2.517137
batch_index 560,loss 2.332819
batch_index 561,loss 2.484801
batch_index 562,loss 2.687213
batch_index 563,loss 2.627560
batch_index 564,loss 2.634017
batch_index 565,loss 2.585071
batch_index 566,loss 2.675633
batch_index 567,loss 2.806278
batch_index 568,loss 2.470644
batch_index 569,loss 2.635962
batch_index 570,loss 2.518394
batch_index 571,loss 2.579055
batch_index 572,loss 2.664146
batch_index 573,loss 2.544859
batch_index 574,loss 2.742432
batch_index 575,loss 2.543057
batch_index 576,loss 2.526941
batch_index 577,loss 2.822198
batch_index 578,loss 2.481163
batch_index 579,loss 2.304783
batch_index 580,loss 2.506709
batch_index 581,loss 2.261612
batch_index 582,loss 2.704406
batch_index 583,loss 2.825309
batch_inde

batch_index 825,loss 2.579637
batch_index 826,loss 2.396679
batch_index 827,loss 2.452386
batch_index 828,loss 2.302875
batch_index 829,loss 2.298754
batch_index 830,loss 2.325796
batch_index 831,loss 2.631603
batch_index 832,loss 2.386598
batch_index 833,loss 2.205257
batch_index 834,loss 2.409565
batch_index 835,loss 2.434512
batch_index 836,loss 2.769735
batch_index 837,loss 2.530623
batch_index 838,loss 2.298646
batch_index 839,loss 2.419894
batch_index 840,loss 2.230025
batch_index 841,loss 2.582532
batch_index 842,loss 2.341814
batch_index 843,loss 2.485786
batch_index 844,loss 2.677834
batch_index 845,loss 2.499001
batch_index 846,loss 2.598781
batch_index 847,loss 2.531510
batch_index 848,loss 2.401077
batch_index 849,loss 2.344452
batch_index 850,loss 2.457129
batch_index 851,loss 2.557630
batch_index 852,loss 2.178030
batch_index 853,loss 2.123830
batch_index 854,loss 2.491939
batch_index 855,loss 2.617218
batch_index 856,loss 2.369057
batch_index 857,loss 2.147678
batch_inde

In [79]:
# Draw one seed window and reuse it for every temperature so the samples are comparable.
X_, _ = data_loader.get_batch(seq_length, 1)
for diversity in (0.2, 0.5, 1.0, 1.2):      # diversity (i.e. temperature), from low to high
    X = X_
    print("diversity %f:" % diversity)
    for t in range(400):
        # Sample the id of the next character and emit the character itself.
        y_pred = model.predict(X, diversity)
        next_char = data_loader.indices_char[y_pred[0]]
        print(next_char, end='', flush=True)
        # Slide the window: drop the oldest id and append the sampled one,
        # so X keeps the same length.
        shifted = X[:, 1:]
        appended = y_pred[:, np.newaxis]
        X = np.concatenate([shifted, appended], axis=-1)
    print("\n")

diversity 0.200000:
 the the the the the the the the the the the the the the the the the the the the the the the the the the the core the the the the the the the the the ther the the the the the the ther the the the the the the the the the the the the the the the the the the the the the the the the ther the the the the the the the tore the the the the the the the the the the the the the the the the the the the the th

diversity 0.500000:
 in the the tore the thevenit ond the irinite the the the the core fore hererty be the rores ent the gnrin tuling thicant on the the the pres or the the the the the srether the nor verules in the the the mon the tere of manenan the the eth the an treres the the thit ther ent ind rather pretn ther of the heren the indery the coreres the
reratire ther thit tise anders sont cinise
the tor the prerenc

diversity 1.000000:
 pore gad vedd, de iple boremo onlo
gis if vers therceen aft5rindy init or ted iele inct perted mongiup,, anlencb
ctd-yfetetns
r mulivar