In [1]:
!pip install keras==2.3.1

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple


In [None]:
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
import random
import numpy as np
from keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from keras import Sequential
from keras.layers import Embedding, LSTM, Dense, SpatialDropout1D
from keras.callbacks import ModelCheckpoint

# Load the tab-separated corpus: each line is expected to be "<text>\t<integer label>".
# The context manager closes the file handle as soon as the lines have been read
# (previously the handle stayed open for the lifetime of the kernel).
with open('./dataset.txt', 'r', encoding='utf-8') as f:
    data = f.readlines()
random.shuffle(data)  # shuffle the samples in place
all_data = []
labels = []
for raw_line in data:
    # Skip blank lines (e.g. a trailing newline at the end of the file); they
    # would otherwise raise IndexError when the label column is read.
    if not raw_line.strip():
        continue
    line = raw_line.split('\t')
    all_data.append(line[0])
    labels.append(int(line[1].strip('\n')))
labels = np.array(labels)
# Vocabulary cap passed to the tokenizer as num_words (only the most frequent
# words are kept when texts are converted to sequences).
MAX_NB_WORDS = 100000
# Every sequence is padded/truncated to this many tokens.
MAX_SEQUENCE_LENGTH = 50
# Dimensionality of the Embedding layer's output vectors.
EMBEDDING_DIM = 150


# Build the tokenizer. The filter set is written as a raw string so the
# backslash before "]" is a literal backslash instead of the invalid escape
# sequence "\]" (same characters as before, but no invalid-escape warning).
tokenizer = Tokenizer(num_words=MAX_NB_WORDS, filters=r'!"#$%&()*+,-./:;<=>?@[\]^_`{|}~', lower=True)
# Fit the vocabulary on the whole corpus.
tokenizer.fit_on_texts(all_data)
# Convert each text into a sequence of integer word indices.
sequences = tokenizer.texts_to_sequences(all_data)
# Word -> index mapping learned by the tokenizer.
word_index = tokenizer.word_index

# Pad/truncate every sequence to MAX_SEQUENCE_LENGTH.
features = pad_sequences(sequences, maxlen=MAX_SEQUENCE_LENGTH)
print('共有 %s 个不相同的词语.' % len(word_index))

# One-hot encode the integer class labels.
labels = to_categorical(labels)
# Hold out 10% of the samples as the test set (fixed random_state for a repeatable split).
X_train, X_test, Y_train, Y_test = train_test_split(
    features,
    labels,
    test_size=0.10,
    random_state=42,
)
# Report the (features, labels) shapes of the training split, then the test split.
for split_x, split_y in ((X_train, Y_train), (X_test, Y_test)):
    print(split_x.shape, split_y.shape)

# Define the model: embedding -> spatial dropout -> LSTM -> 8-way softmax classifier.
model = Sequential()
model.add(Embedding(MAX_NB_WORDS, EMBEDDING_DIM, input_length=features.shape[1]))
model.add(SpatialDropout1D(0.2))
model.add(LSTM(100, dropout=0.2, recurrent_dropout=0.2))
model.add(Dense(8, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# summary() prints the table itself and returns None, so wrapping it in
# print() only appended a stray "None" to the output.
model.summary()

epochs = 10
batch_size = 128

# Keep only checkpoints that improve accuracy on the held-out data. Monitoring
# the training accuracy (as before) combined with save_best_only=True would
# retain the weights that fit the training set best, i.e. favour overfitting,
# even though validation data is supplied to fit().
checkpoint = ModelCheckpoint('weights.{epoch:03d}-{val_accuracy:.4f}.hdf5', monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
model.fit(X_train, Y_train, epochs=epochs, batch_size=batch_size, verbose=1, callbacks=[checkpoint], validation_data=(X_test, Y_test))

Using TensorFlow backend.


共有 425995 个不相同的词语.
(1024667, 50) (1024667, 8)
(113852, 50) (113852, 8)
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_1 (Embedding)      (None, 50, 150)           15000000  
_________________________________________________________________
spatial_dropout1d_1 (Spatial (None, 50, 150)           0         
_________________________________________________________________
lstm_1 (LSTM)                (None, 100)               100400    
_________________________________________________________________
dense_1 (Dense)              (None, 8)                 808       
Total params: 15,101,208
Trainable params: 15,101,208
Non-trainable params: 0
_________________________________________________________________
None




Train on 1024667 samples, validate on 113852 samples
Epoch 1/1

In [4]:
!pip install jieba

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
You should consider upgrading via the '/Users/maliyu/miniconda3/bin/python -m pip install --upgrade pip' command.[0m


In [5]:
import jieba
def predict(text):
    """Predict the dynasty/era label for a piece of text.

    The text is cleaned with remove_punctuation, segmented with jieba,
    converted to a padded index sequence with the module-level ``tokenizer``
    and scored by the module-level ``model``.

    Args:
        text: the input text to classify.

    Returns:
        The era name (str) of the highest-scoring class. Class indices 0-6 map
        to the names listed below; any other index maps to '清'.
    """
    eras = ('秦', '汉、三国', '魏晋南北朝', '隋唐五代', '宋、金', '元', '明')
    cleaned = remove_punctuation(text)
    # The tokenizer expects space-separated words, so join the jieba segments.
    segmented = " ".join(jieba.cut(cleaned))
    seq = tokenizer.texts_to_sequences([segmented])
    padded = pad_sequences(seq, maxlen=MAX_SEQUENCE_LENGTH)
    scores = model.predict(padded)
    cat_id = int(scores.argmax(axis=1)[0])
    return eras[cat_id] if cat_id < len(eras) else '清'

In [6]:
#定义删除除字母,数字，汉字以外的所有符号的函数
import re
def remove_punctuation(line):
    """Strip everything except ASCII letters, digits and CJK characters.

    Args:
        line: any object; it is converted with ``str()`` first.

    Returns:
        The string with every character outside a-z, A-Z, 0-9 and the range
        U+4E00..U+9FA5 removed. Empty or whitespace-only input yields ''.
    """
    # Whitespace is not in the kept character classes, so blank input
    # collapses to '' without needing a separate early return.
    return re.sub(u"[^a-zA-Z0-9\u4E00-\u9FA5]", '', str(line))

In [14]:
predict('青青园中葵，朝露待日晞')

'汉、三国'

In [13]:
predict('我愿平东海，身沉心不改。')

'清'

In [11]:
predict('大漠孤烟直')

'隋唐五代'

In [10]:
predict('白毛浮绿水')

'隋唐五代'

In [34]:
predict('采菊东篱下，悠然现南山')

'魏晋南北朝'

In [16]:
predict('小桥流水人家，古道西风瘦马')

'元'

In [17]:
predict('大风起兮云飞扬')

'汉、三国'

In [26]:
predict('明月松间照，清泉石上流')

'隋唐五代'

In [30]:
predict('琼姿只合在瑶台，谁向江南处处栽')

'明'

In [33]:
predict('一蓑一笠一扁舟，一丈丝纶一寸钩')

'元'

In [35]:
predict('东临碣石，以观沧海')

'汉、三国'