In [1]:
import numpy as np
from keras.models import Sequential
from keras.layers import LSTM
from keras.layers import Dense
from keras.utils import np_utils

Using TensorFlow backend.


1. 问题描述：对于原始字符序列“AB……Z”，根据前面连续的若干个字符（个数由窗口大小 window 决定）预测下一个字符。
2. 思路：首先定义问题的输入输出。对于本例这样的序列问题，模型输入是由连续 window 个时间步组成的字符序列（每个时间步一个字符），模型输出是紧随其后的下一个字符。然后，由于模型只能处理数值，必须将字符形式的输入输出转换成数值，这一过程借助字符与整数之间的字典映射（及其反向映射）完成。

In [2]:
raw_data = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
window = 3  # context length, analogous to the window size n in an n-gram model

# Character <-> integer lookup tables: the model only computes on numbers,
# and int_to_char maps its numeric predictions back to characters.
char_to_int = {ch: idx for idx, ch in enumerate(raw_data)}
int_to_char = {idx: ch for idx, ch in enumerate(raw_data)}

# Slide a `window`-wide span over the sequence: each span is one input and
# the character immediately after it is that input's target.
n_samples = len(raw_data) - window
windows = [raw_data[start:start + window] for start in range(n_samples)]
targets = [raw_data[start + window] for start in range(n_samples)]

# Integer-encoded inputs (one list of codes per window) and targets.
x_data = [[char_to_int[ch] for ch in chunk] for chunk in windows]
y_data = [char_to_int[ch] for ch in targets]

# Show the ideal mapping the model is expected to learn.
print("window=%s时理想的效果：" % window)
for chunk, nxt in zip(windows, targets):
    print(chunk, '->', nxt)
print("total samples: %s" % len(x_data))

window=3时理想的效果：
ABC -> D
BCD -> E
CDE -> F
DEF -> G
EFG -> H
FGH -> I
GHI -> J
HIJ -> K
IJK -> L
JKL -> M
KLM -> N
LMN -> O
MNO -> P
NOP -> Q
OPQ -> R
PQR -> S
QRS -> T
RST -> U
STU -> V
TUV -> W
UVW -> X
VWX -> Y
WXY -> Z
total samples: 23


In [3]:
# Arrange the integer-coded windows as (samples, time_steps, features) for the
# LSTM: one time step per character, one feature (the character's code) per step.
x = np.reshape(x_data, (len(x_data), window, 1))
# Codes come from enumerate(raw_data), so dividing by len(raw_data) scales them
# into [0, 1) and keeps the LSTM inputs small.
x = x / len(raw_data)
# One-hot encode the targets. num_classes is pinned to the vocabulary size so
# the softmax layer sized from y.shape[1] always has one unit per entry of
# int_to_char, instead of depending on the largest label that happens to occur.
y = np_utils.to_categorical(y_data, num_classes=len(raw_data))

In [4]:
# Model: a single 32-unit LSTM layer over the (time_steps, features) windows,
# followed by a softmax with one output per one-hot class in y.
lstm = Sequential()
lstm.add(LSTM(32, input_shape=(x.shape[1], x.shape[2])))
lstm.add(Dense(y.shape[1], activation="softmax"))
lstm.compile(loss="categorical_crossentropy", optimizer="adam", metrics=['accuracy'])
# batch_size=1: weights are updated after every sample; verbose=0 keeps the
# per-epoch log out of the notebook output.
# NOTE(review): no random seed is set in this notebook, so the trained weights
# (and possibly the reported accuracy) can differ between runs.
lstm.fit(x, y, batch_size=1, epochs=500, verbose=0)

# evaluate() returns [loss, accuracy] because metrics=['accuracy'] above.
# The data evaluated here are the training windows themselves (no held-out set).
score = lstm.evaluate(x, y, verbose=0)
print("train accuracy: %.2f" % (score[1]*100))

# Predict every window in one call, reusing the already-normalised `x` so the
# input scaling is defined in a single place. Rows of x follow x_data's order.
predicted_indices = np.argmax(lstm.predict(x, verbose=0), axis=1)
for term, index in zip(x_data, predicted_indices):
    result = int_to_char[int(index)]
    seq_in = [int_to_char[value] for value in term]
    print(seq_in, '->', result)

train accurancy: 100.00
['A', 'B', 'C'] -> D
['B', 'C', 'D'] -> E
['C', 'D', 'E'] -> F
['D', 'E', 'F'] -> G
['E', 'F', 'G'] -> H
['F', 'G', 'H'] -> I
['G', 'H', 'I'] -> J
['H', 'I', 'J'] -> K
['I', 'J', 'K'] -> L
['J', 'K', 'L'] -> M
['K', 'L', 'M'] -> N
['L', 'M', 'N'] -> O
['M', 'N', 'O'] -> P
['N', 'O', 'P'] -> Q
['O', 'P', 'Q'] -> R
['P', 'Q', 'R'] -> S
['Q', 'R', 'S'] -> T
['R', 'S', 'T'] -> U
['S', 'T', 'U'] -> V
['T', 'U', 'V'] -> W
['U', 'V', 'W'] -> X
['V', 'W', 'X'] -> Y
['W', 'X', 'Y'] -> Z
