**載入文字資料**

In [79]:
import tensorflow as tf
import numpy as np

class DataLoader():
  def __init__(self):
    path = tf.keras.utils.get_file('nietzsche.txt',origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')

    with open(path,encoding='utf-8') as f:
      self.raw_text = f.read().lower()

    self.chars = sorted(list(set(self.raw_text)))
    self.char_indices = dict((c,i) for i, c in enumerate(self.chars))
    self.indices_char = dict((i,c) for i, c in enumerate(self.chars))
    self.text = [self.char_indices[c] for c in self.raw_text]

  def get_batch(self, seq_length, batch_size):
    seq = []
    next_char = []

    for i in range(batch_size):
      index = np.random.randint(0, len(self.text) - seq_length)
      seq.append(self.text[index:index+seq_length])
      next_char.append(self.text[index+seq_length])

      return np.array(seq), np.array(next_char)

**建立模型**

In [80]:
from numpy.core.multiarray import dtype
class RNN(tf.keras.Model):
  def __init__(self, num_chars, batch_size, seq_length):
    super().__init__()

    self.num_chars = num_chars
    self.seq_length = seq_length
    self.batch_size = batch_size

    self.cell = tf.keras.layers.LSTMCell(units=256)
    self.dense  = tf.keras.layers.Dense(units=self.num_chars)

  def call(self, inputs, from_logits=False):
    inputs = tf.one_hot(inputs, depth=self.num_chars)
    state = self.cell.get_initial_state(batch_size=self.batch_size,dtype=tf.float32)

    for t in range(self.seq_length):
      output, state = self.cell(inputs[:,t,:],state)

    logits = self.dense(output)

    if from_logits:
      return logits
    else:
      return tf.nn.softmax(logits)



**參數設定**

In [81]:
num_batches = 1000
seq_length = 40
batch_size = 1
learning_rate = 1e-3

 **訓練開始**

In [None]:
from keras.engine.training import optimizer
data_loader = DataLoader()
model = RNN(num_chars=len(data_loader.chars),batch_size=batch_size,seq_length=seq_length)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

for batch_index in range(num_batches):
  X, y = data_loader.get_batch(seq_length,batch_size)
  with tf.GradientTape() as tape:
    y_pred = model(X)
    loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y,y_pred=y_pred)
    loss = tf.reduce_mean(loss)
    print("batch %d loss %f" % (batch_index,loss.numpy()))
  grads = tape.gradient(loss,model.variables)
  optimizer.apply_gradients(grads_and_vars=zip(grads,model.variables))


batch 0 loss 4.057539
batch 1 loss 4.075435
batch 2 loss 4.075804
batch 3 loss 4.040296
batch 4 loss 3.993575
batch 5 loss 4.045423
batch 6 loss 3.972641
batch 7 loss 3.902284
batch 8 loss 4.059404
batch 9 loss 4.077812
batch 10 loss 4.126665
batch 11 loss 4.059216
batch 12 loss 4.065313
batch 13 loss 3.990505
batch 14 loss 3.793549
batch 15 loss 4.105025
batch 16 loss 3.806931
batch 17 loss 4.053829
batch 18 loss 3.636839
batch 19 loss 4.183213
batch 20 loss 3.140736
batch 21 loss 3.822880
batch 22 loss 2.644754
batch 23 loss 5.376073
batch 24 loss 3.918265
batch 25 loss 3.955055
batch 26 loss 5.548432
batch 27 loss 1.757520
batch 28 loss 4.011651
batch 29 loss 1.776685
batch 30 loss 5.233916
batch 31 loss 3.922043
batch 32 loss 3.799827
batch 33 loss 3.142810
batch 34 loss 3.323278
batch 35 loss 2.945462
batch 36 loss 1.839828
batch 37 loss 1.785256
batch 38 loss 3.579210
batch 39 loss 3.371750
batch 40 loss 2.546045
batch 41 loss 3.288194
batch 42 loss 4.705637
batch 43 loss 3.11726

**預測**

In [82]:
def predict(self, inputts,temperature=1.):
  batch_size, _ = tf.shape(inputs)
  logits = self(inputs, from_logits=True)
  prob = tf.nn.softmax(logits / temperature).numpy()

  return np.array([np.random.choice(self.num_chars,p=prob[i,:]) for i in range(batch_size.numpy())])



In [95]:
X_, _ = data_loader.get_batch(seq_length,1)

for diversity in [2,5,10,12]:
  X = X_
  print("diversity %f:" % diversity)
  data_loader.indices_char = {}
  for t in range(400):
    y_pred = model.predict(X,diversity)
    y_pred_tuple = tuple(y_pred[0])
    print(data_loader.indices_char[y_pred_tuple],end='',flush=True)
    X =np.concatenate([[X[:,1:]],np.expand_dims(y_pred,axis=1)],axis=-1)

  print("\n")


diversity 2.000000:


KeyError: ignored

In [97]:
print(data_loader.chars)

['\n', ' ', '!', '"', "'", '(', ')', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '=', '?', '[', ']', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'ä', 'æ', 'é', 'ë']
