In [4]:
import tensorflow as tf
import numpy as np

class DataLoader():
    """Character-level loader for the Nietzsche text corpus.

    The constructor fetches ``nietzsche.txt`` through ``tf.keras.utils.get_file``,
    lower-cases it and encodes every character as an integer id.

    Attributes:
        raw_text: the whole lower-cased corpus as one string.
        chars: sorted list of the distinct characters occurring in ``raw_text``.
        char_indices: dict mapping character -> id (its index in ``chars``).
        indices_char: dict mapping id -> character (inverse of ``char_indices``).
        text: the corpus encoded as a list of character ids.
    """

    def __init__(self):
        corpus_path = tf.keras.utils.get_file(
            'nietzsche.txt',
            origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt',
        )
        with open(corpus_path, encoding='utf-8') as corpus_file:
            self.raw_text = corpus_file.read().lower()
        self.chars = sorted(set(self.raw_text))
        self.char_indices = {ch: idx for idx, ch in enumerate(self.chars)}
        self.indices_char = dict(enumerate(self.chars))
        lookup = self.char_indices
        self.text = [lookup[ch] for ch in self.raw_text]

    def get_batch(self, seq_length, batch_size):
        """Sample random training windows from the encoded corpus.

        Each window starts at a position drawn with ``np.random.randint``. The
        target is the id that immediately follows the window.

        Args:
            seq_length: number of character ids per window.
            batch_size: number of windows to draw.

        Returns:
            A tuple ``(seq, next_char)``. ``seq`` has shape [batch_size, seq_length].
            ``next_char`` has shape [batch_size] and holds the id following each window.
        """
        # randint's upper bound is exclusive, so start + seq_length is always a valid index.
        upper = len(self.text) - seq_length
        windows, targets = [], []
        for _ in range(batch_size):
            start = np.random.randint(0, upper)
            stop = start + seq_length
            windows.append(self.text[start:stop])
            targets.append(self.text[stop])
        return np.array(windows), np.array(targets)
    

# Smoke-test the loader: draw 20 windows of 3 character ids each, together
# with the id that follows each window, and print both arrays.
data_loader = DataLoader()
batch = data_loader.get_batch(seq_length=3, batch_size=20)
print(batch)

(array([[44,  1, 27],
       [42, 47, 38],
       [ 9,  1, 46],
       [ 1, 40, 41],
       [39,  7,  1],
       [ 1, 27, 40],
       [31, 45,  1],
       [45, 42, 35],
       [30,  1, 32],
       [34, 31, 30],
       [30, 31, 44],
       [31, 38, 51],
       [32,  1, 32],
       [29, 41, 40],
       [46, 34, 31],
       [ 1, 46, 34],
       [35, 39, 27],
       [46, 34, 31],
       [27, 40, 30],
       [45, 41, 40]]), array([ 1, 42, 34,  1, 32, 30, 27, 44, 41,  1,  1,  1, 44, 30, 40, 31, 46,
        1,  1, 31]))


In [9]:
# Character-level RNN model; input character ids are one-hot encoded inside call()


class RNN(tf.keras.Model):
    """Character-level LSTM that scores the next character of a sequence.

    Each input step is one-hot encoded over ``num_chars`` characters and fed
    through a 256-unit ``LSTMCell``. The output after the last step is projected
    to ``num_chars`` scores by a dense layer.

    Args:
        num_chars: vocabulary size. It sets the one-hot depth and the width of
            the output layer.
        batch_size: training batch size. It is kept as an attribute, but
            ``call`` sizes the recurrent state from the actual input batch.
        seq_length: default number of time steps. It is used only when the
            sequence length of the inputs is not statically known.
    """

    def __init__(self, num_chars, batch_size, seq_length):
        super().__init__()
        self.num_chars = num_chars
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.cell = tf.keras.layers.LSTMCell(units=256)
        self.dense = tf.keras.layers.Dense(units=self.num_chars)

    def call(self, inputs, from_logits=False):
        """Run the LSTM over ``inputs`` and score the next character.

        Args:
            inputs: integer character ids, shape [batch, seq_length].
            from_logits: if True return raw logits, otherwise softmax
                probabilities.

        Returns:
            Tensor of shape [batch, num_chars].

        Raises:
            ValueError: if the inputs contain zero time steps.
        """
        inputs = tf.one_hot(inputs, depth=self.num_chars)       # [batch, seq_length, num_chars]
        # Size the initial state from the real batch, not the training-time
        # batch size. Sampling with a batch of 1 previously only "worked" by
        # broadcasting a [1, ...] input against a [self.batch_size, 256] state.
        batch_size = inputs.shape[0]
        if batch_size is None:
            batch_size = tf.shape(inputs)[0]
        seq_length = inputs.shape[1]
        if seq_length is None:
            seq_length = self.seq_length
        if seq_length == 0:
            raise ValueError("RNN.call requires at least one time step")
        state = self.cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
        for t in range(seq_length):
            # Feed one time step; the cell returns this step's output and the new state.
            output, state = self.cell(inputs[:, t, :], state)
        logits = self.dense(output)
        if from_logits:
            return logits
        return tf.nn.softmax(logits)

    def predict(self, inputs, temperature=1.):
        """Sample one next-character id for each row of ``inputs``.

        NOTE: this overrides ``tf.keras.Model.predict`` with different semantics.
        It draws stochastic samples rather than running batched inference.

        Args:
            inputs: integer character ids, shape [batch, seq_length].
            temperature: softmax temperature. Values below 1 sharpen the
                distribution; values above 1 flatten it.

        Returns:
            np.ndarray of shape [batch] holding the sampled character ids.
        """
        n_rows = int(tf.shape(inputs)[0])
        logits = self(inputs, from_logits=True)
        prob = tf.nn.softmax(logits / temperature).numpy().astype(np.float64)
        # Renormalise in float64 so float32 rounding in the softmax cannot make
        # np.random.choice reject p for not summing to 1.
        prob /= prob.sum(axis=1, keepdims=True)
        return np.array([np.random.choice(self.num_chars, p=prob[i, :])
                         for i in range(n_rows)])


In [10]:
num_batches = 1000
seq_length = 40
batch_size = 50
learning_rate = 1e-3
log_every = 1       # print the loss every `log_every` batches (1 = every batch)
seed = 42

# Seed both RNGs: DataLoader.get_batch samples windows with np.random, and
# TensorFlow initialises the model weights.
np.random.seed(seed)
tf.random.set_seed(seed)

data_loader = DataLoader()
model = RNN(num_chars=len(data_loader.chars), batch_size=batch_size, seq_length=seq_length)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(seq_length, batch_size)
    with tf.GradientTape() as tape:
        # Same cross-entropy as on softmax probabilities, but computed from the
        # logits, which avoids log(~0) and is numerically more stable.
        logits = model(X, from_logits=True)
        loss = tf.keras.losses.sparse_categorical_crossentropy(
            y_true=y, y_pred=logits, from_logits=True)
        loss = tf.reduce_mean(loss)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.trainable_variables))
    if batch_index % log_every == 0:
        print("batch %d: loss %f" % (batch_index, loss.numpy()))


batch 0: loss 4.047487
batch 1: loss 4.029623
batch 2: loss 4.009845
batch 3: loss 3.982798
batch 4: loss 3.951386
batch 5: loss 3.902879
batch 6: loss 3.762376
batch 7: loss 3.281501
batch 8: loss 3.721145
batch 9: loss 3.968124
batch 10: loss 3.261647
batch 11: loss 3.074416
batch 12: loss 3.357088
batch 13: loss 3.314607
batch 14: loss 3.257729
batch 15: loss 3.371844
batch 16: loss 3.148440
batch 17: loss 3.208633
batch 18: loss 3.299102
batch 19: loss 3.216666
batch 20: loss 3.033338
batch 21: loss 3.182354
batch 22: loss 3.209240
batch 23: loss 3.054554
batch 24: loss 3.148987
batch 25: loss 3.004602
batch 26: loss 3.223320
batch 27: loss 3.179885
batch 28: loss 2.847673
batch 29: loss 3.291892
batch 30: loss 2.824774
batch 31: loss 3.054077
batch 32: loss 3.085437
batch 33: loss 3.039536
batch 34: loss 3.246731
batch 35: loss 3.107517
batch 36: loss 2.922457
batch 37: loss 3.014835
batch 38: loss 3.219885
batch 39: loss 2.945192
batch 40: loss 2.810359
batch 41: loss 2.938333
ba

batch 333: loss 2.421344
batch 334: loss 2.542543
batch 335: loss 2.874084
batch 336: loss 2.586570
batch 337: loss 2.810910
batch 338: loss 2.804030
batch 339: loss 2.839488
batch 340: loss 3.143677
batch 341: loss 2.957455
batch 342: loss 2.904529
batch 343: loss 2.518216
batch 344: loss 2.913491
batch 345: loss 2.739369
batch 346: loss 2.766927
batch 347: loss 2.804340
batch 348: loss 2.609743
batch 349: loss 2.731372
batch 350: loss 2.803982
batch 351: loss 2.841900
batch 352: loss 2.506616
batch 353: loss 2.670060
batch 354: loss 2.639583
batch 355: loss 2.777249
batch 356: loss 2.630954
batch 357: loss 2.649018
batch 358: loss 2.843201
batch 359: loss 2.766170
batch 360: loss 2.852554
batch 361: loss 2.817475
batch 362: loss 3.214310
batch 363: loss 2.735746
batch 364: loss 2.913682
batch 365: loss 2.470592
batch 366: loss 2.717864
batch 367: loss 2.425551
batch 368: loss 2.852311
batch 369: loss 2.706606
batch 370: loss 2.613817
batch 371: loss 2.687792
batch 372: loss 2.945539


batch 661: loss 2.511185
batch 662: loss 2.415679
batch 663: loss 2.211957
batch 664: loss 2.365398
batch 665: loss 2.562698
batch 666: loss 2.329716
batch 667: loss 2.798088
batch 668: loss 2.744644
batch 669: loss 2.379339
batch 670: loss 2.579828
batch 671: loss 2.476270
batch 672: loss 2.294713
batch 673: loss 2.596537
batch 674: loss 2.609087
batch 675: loss 2.518300
batch 676: loss 2.449132
batch 677: loss 2.295431
batch 678: loss 2.917285
batch 679: loss 2.479810
batch 680: loss 2.702585
batch 681: loss 2.581052
batch 682: loss 2.347682
batch 683: loss 2.502775
batch 684: loss 2.370778
batch 685: loss 2.554116
batch 686: loss 2.724463
batch 687: loss 2.490904
batch 688: loss 2.599313
batch 689: loss 2.514360
batch 690: loss 2.409863
batch 691: loss 2.246351
batch 692: loss 2.524272
batch 693: loss 2.238588
batch 694: loss 2.401886
batch 695: loss 2.184940
batch 696: loss 2.121734
batch 697: loss 2.767905
batch 698: loss 2.291572
batch 699: loss 2.717912
batch 700: loss 2.587994


batch 989: loss 2.381066
batch 990: loss 1.967214
batch 991: loss 2.216496
batch 992: loss 2.336154
batch 993: loss 2.398797
batch 994: loss 2.210247
batch 995: loss 2.665572
batch 996: loss 2.313117
batch 997: loss 2.538114
batch 998: loss 2.477353
batch 999: loss 2.497485


In [11]:
X_, _ = data_loader.get_batch(seq_length, 1)
# Generate 400 characters at each of four temperatures ("diversity"), from low
# to high. Every run starts from the same seed window X_.
for diversity in (0.2, 0.5, 1.0, 1.2):
    X = X_
    print("diversity %f:" % diversity)
    for t in range(400):
        y_pred = model.predict(X, diversity)    # sampled id of the next character
        print(data_loader.indices_char[y_pred[0]], end='', flush=True)
        # Slide the window: drop the oldest id and append the sampled one,
        # so the input length stays fixed.
        X = np.concatenate([X[:, 1:], y_pred[:, np.newaxis]], axis=1)
    print("\n")

diversity 0.200000:
he the the whe he the whe the the whe the he the he the mon the the whe the whe the whe the and the be the se the he the he the he the the for the the he the on the mand and and the and and the whe the the whe the the the the he the he the the pe the whe the whe the he the the the the the the the the he the of the the the mere the whe the the he the whe whe the whe he the the he the s and be the a

diversity 0.500000:
hithe whe of the wet ins in he are and of hin the hare porence the hom on whithe wor the whe whe he the he mand the ind whe the the whe be the the he the the the mere ane se the wand son the ince pe the wis une ne in the cose the the fve fomt ind reng in the in led in ind the ang and of pthe the ner and the sid in th the whith th in weud the the co te ber red an the rece the whingent of the the an

diversity 1.000000:
ave maulha
-er
mopr co0e arlle sel end boo s ficcich  g the iple on laoun turd and whe med dpmin therpenj? wewithe ges pyisulltemell io 