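## Dataset download

Training and testing use a subset of THUCNews: [https://www.lanzous.com/i5t0lsd](https://www.lanzous.com/i5t0lsd)

This run uses 10 of its categories, with 6,500 samples per category. The categories are:

Sports (体育), Finance (财经), Real Estate (房产), Home (家居), Education (教育), Technology (科技), Fashion (时尚), Politics (时政), Games (游戏), Entertainment (娱乐)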

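### cnews_loader.py is the data preprocessing file.

- read_file(): reads the data from a file;
- build_vocab(): builds the vocabulary using a character-level representation, and saves it to disk so it does not have to be rebuilt on every run;
- read_vocab(): reads the vocabulary saved in the previous step and converts it into a {word: id} mapping;
- read_category(): fixes the list of categories and converts it into a {category: id} mapping;
- to_words(): converts a sample represented as ids back into text;
- process_file(): converts the dataset from text into fixed-length id sequences;
- batch_iter(): prepares shuffled batches of data for training the neural network.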
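
The pipeline described above (character-level vocabulary, fixed-length id sequences, shuffled batches) can be illustrated with a minimal standard-library sketch. The helper names (`sketch_build_vocab`, `sketch_encode`, `sketch_batch_iter`), the `<PAD>` token with id 0, and the choice of left-padding are assumptions made for illustration; this is not the actual code in `cnews_loader.py`.

```python
import random
from collections import Counter

PAD = "<PAD>"  # hypothetical padding token, assumed to take id 0


def sketch_build_vocab(texts, vocab_size):
    """Keep the most frequent characters; reserve id 0 for padding."""
    counts = Counter(ch for text in texts for ch in text)
    chars = [ch for ch, _ in counts.most_common(vocab_size - 1)]
    words = [PAD] + chars
    return words, {w: i for i, w in enumerate(words)}


def sketch_encode(text, word_to_id, seq_length):
    """Map characters to ids, drop unknown ones, then left-pad or truncate."""
    ids = [word_to_id[ch] for ch in text if ch in word_to_id]
    ids = ids[-seq_length:]  # keep at most the last seq_length ids
    return [0] * (seq_length - len(ids)) + ids


def sketch_batch_iter(x, y, batch_size, seed=None):
    """Yield shuffled (x_batch, y_batch) pairs covering every sample once."""
    order = list(range(len(x)))
    random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield [x[i] for i in idx], [y[i] for i in idx]


texts = ["体育新闻", "财经新闻"]
labels = [0, 1]
words, word_to_id = sketch_build_vocab(texts, vocab_size=10)
encoded = [sketch_encode(t, word_to_id, seq_length=6) for t in texts]
batches = list(sketch_batch_iter(encoded, labels, batch_size=1, seed=0))
```

Each encoded sample has exactly `seq_length` entries, so the batches can be stacked into a rectangular array before being fed to a model.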

### The textCNN model and its configurable parameters are in cnn_model.py.

In [None]:
from __future__ import print_function

import os
import sys
import time
from datetime import timedelta

import numpy as np
import tensorflow as tf
from sklearn import metrics

from cnn_model import TCNNConfig, TextCNN
from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab

In [3]:
base_dir = 'cnews'
train_dir = os.path.join(base_dir, 'cnews.train.txt')
test_dir = os.path.join(base_dir, 'cnews.test.txt')
val_dir = os.path.join(base_dir, 'cnews.val.txt')
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')

save_dir = 'checkpoints/textcnn'
save_path = os.path.join(save_dir, 'best_validation')  # path for saving the best validation result


def get_time_dif(start_time):
    """Return the elapsed time since start_time."""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))


def feed_data(x_batch, y_batch, keep_prob):
    feed_dict = {
        model.input_x: x_batch,
        model.input_y: y_batch,
        model.keep_prob: keep_prob
    }
    return feed_dict


def evaluate(sess, x_, y_):
    """Evaluate the average loss and accuracy on a dataset."""
    data_len = len(x_)
    batch_eval = batch_iter(x_, y_, 128)
    total_loss = 0.0
    total_acc = 0.0
    for x_batch, y_batch in batch_eval:
        batch_len = len(x_batch)
        feed_dict = feed_data(x_batch, y_batch, 1.0)
        loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
        total_loss += loss * batch_len
        total_acc += acc * batch_len

    return total_loss / data_len, total_acc / data_len


def train():
    print("Configuring TensorBoard and Saver...")
    # Configure TensorBoard. When retraining, delete the tensorboard folder first,
    # otherwise the new curves are drawn over the old ones.
    tensorboard_dir = 'tensorboard/textcnn'
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)

    tf.summary.scalar("loss", model.loss)
    tf.summary.scalar("accuracy", model.acc)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)

    # Configure the Saver
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    print("Loading training and validation data...")
    # Load the training and validation sets
    start_time = time.time()
    x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
    x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # Create the session
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    writer.add_graph(session.graph)

    print('Training and evaluating...')
    start_time = time.time()
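    total_batch = 0  # total number of batches processed
    best_acc_val = 0.0  # best validation accuracy so far
    last_improved = 0  # batch at which validation accuracy last improved
    require_improvement = 1000  # stop early after more than 1000 batches without improvement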

    flag = False
    for epoch in range(config.num_epochs):
        print('Epoch:', epoch + 1)
        batch_train = batch_iter(x_train, y_train, config.batch_size)
        for x_batch, y_batch in batch_train:
            feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)

            if total_batch % config.save_per_batch == 0:
                # Every save_per_batch batches, write training results to the TensorBoard scalars
                s = session.run(merged_summary, feed_dict=feed_dict)
                writer.add_summary(s, total_batch)

            if total_batch % config.print_per_batch == 0:
                # Every print_per_batch batches, report performance on the training and validation sets
                feed_dict[model.keep_prob] = 1.0
                loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict)
                loss_val, acc_val = evaluate(session, x_val, y_val)  # todo

                if acc_val > best_acc_val:
                    # Save the best result
                    best_acc_val = acc_val
                    last_improved = total_batch
                    saver.save(sess=session, save_path=save_path)
                    improved_str = '*'
                else:
                    improved_str = ''

                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
                      + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
                print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))

            feed_dict[model.keep_prob] = config.dropout_keep_prob
            session.run(model.optim, feed_dict=feed_dict)  # 运行优化
            total_batch += 1

            if total_batch - last_improved > require_improvement:
                # Validation accuracy has not improved for a long time; stop training early
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break  # leave the batch loop
        if flag:  # likewise, leave the epoch loop
            break


def test():
    print("Loading test data...")
    start_time = time.time()
    x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)

    session = tf.Session()
    session.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess=session, save_path=save_path)  # load the saved model

    print('Testing...')
    loss_test, acc_test = evaluate(session, x_test, y_test)
    msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
    print(msg.format(loss_test, acc_test))

    batch_size = 128
    data_len = len(x_test)
    num_batch = int((data_len - 1) / batch_size) + 1

    y_test_cls = np.argmax(y_test, 1)
    y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32)  # holds the predictions
    for i in range(num_batch):  # process batch by batch
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        feed_dict = {
            model.input_x: x_test[start_id:end_id],
            model.keep_prob: 1.0
        }
        y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)

    # Evaluation
    print("Precision, Recall and F1-Score...")
    print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))

    # Confusion matrix
    print("Confusion Matrix...")
    cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
    print(cm)

    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)
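
`evaluate()` multiplies each batch's loss and accuracy by the batch length before dividing by the dataset size, so a smaller final batch does not skew the result. `test()` uses a ceiling division to count how many batches cover the test set. Both ideas are shown in the minimal sketch below; the function names and the per-batch values are made up for illustration.

```python
def num_batches(data_len, batch_size):
    """Ceiling division: number of batches needed to cover data_len samples."""
    return (data_len - 1) // batch_size + 1


def weighted_mean(batch_values, batch_lens):
    """Average per-batch metrics, weighted by the number of samples per batch."""
    total = sum(v * n for v, n in zip(batch_values, batch_lens))
    return total / sum(batch_lens)


# Made-up example: two full batches of 128 and a final batch of 44 samples.
batch_lens = [128, 128, 44]
batch_accs = [0.90, 0.90, 0.50]

naive = sum(batch_accs) / len(batch_accs)         # treats all batches equally
weighted = weighted_mean(batch_accs, batch_lens)  # true per-sample accuracy
```

Here the naive mean understates accuracy because the small last batch counts as much as a full one; the weighted mean equals the fraction of all 300 samples classified correctly.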


## Start training

`sys.argv` contains the name of the script being run as its first element.

In [5]:
type_ = 'train'

print('Configuring CNN model...')
config = TCNNConfig()
if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist
    build_vocab(train_dir, vocab_dir, config.vocab_size)
categories, cat_to_id = read_category()
words, word_to_id = read_vocab(vocab_dir)
config.vocab_size = len(words)
model = TextCNN(config)

if type_ == 'train':
    train()
else:
    test()

W0826 20:03:59.630935 140222699312960 module_wrapper.py:136] From /usr/local/python3/lib/python3.6/site-packages/tensorflow_core/python/util/module_wrapper.py:163: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.



Configuring CNN model...


W0826 20:03:59.934021 140222699312960 deprecation.py:323] From /home/python_home/WeiZhongChuang/ML/RNN/cnn_model.py:50: conv1d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.keras.layers.Conv1D` instead.
W0826 20:03:59.935921 140222699312960 deprecation.py:323] From /usr/local/python3/lib/python3.6/site-packages/tensorflow_core/python/layers/convolutional.py:218: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.__call__` method instead.
W0826 20:04:00.065614 140222699312960 deprecation.py:323] From /home/python_home/WeiZhongChuang/ML/RNN/cnn_model.py:56: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.Dense instead.
W0826 20:04:00.080801 140222699312960 lazy_loader.py:50] 
The TensorFlow contrib 

Configuring TensorBoard and Saver...
Loading training and validation data...
Time usage: 0:00:13
Training and evaluating...
Epoch: 1
Iter:      0, Train Loss:    2.3, Train Acc:  10.94%, Val Loss:    2.3, Val Acc:  10.00%, Time: 0:00:04 *
Iter:    100, Train Loss:   0.83, Train Acc:  76.56%, Val Loss:    1.2, Val Acc:  66.44%, Time: 0:00:24 *
Iter:    200, Train Loss:    0.2, Train Acc:  93.75%, Val Loss:   0.67, Val Acc:  78.18%, Time: 0:00:43 *
Iter:    300, Train Loss:   0.14, Train Acc:  95.31%, Val Loss:   0.42, Val Acc:  87.84%, Time: 0:01:04 *
Iter:    400, Train Loss:   0.12, Train Acc:  96.88%, Val Loss:   0.37, Val Acc:  90.06%, Time: 0:01:25 *
Iter:    500, Train Loss:   0.16, Train Acc:  96.88%, Val Loss:   0.35, Val Acc:  89.48%, Time: 0:01:46 
Iter:    600, Train Loss:    0.2, Train Acc:  92.19%, Val Loss:   0.26, Val Acc:  92.38%, Time: 0:02:08 *
Iter:    700, Train Loss:   0.13, Train Acc:  96.88%, Val Loss:   0.27, Val Acc:  92.50%, Time: 0:02:30 *
Epoch: 2
Iter:    80
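
A `*` at the end of a log line marks an evaluation where validation accuracy reached a new best and a checkpoint was saved. `train()` stops once more than `require_improvement` batches pass without such an improvement. The sketch below replays that rule on made-up accuracy values; the function name and its arguments are hypothetical.

```python
def early_stop_batch(val_accs, eval_every, patience):
    """Return the batch count at which training would stop, or None.

    val_accs[k] is the validation accuracy measured at batch k * eval_every.
    """
    best_acc = 0.0
    last_improved = 0
    total_batch = 0
    last_batch = (len(val_accs) - 1) * eval_every
    while total_batch <= last_batch:
        if total_batch % eval_every == 0:
            acc = val_accs[total_batch // eval_every]
            if acc > best_acc:  # new best: a checkpoint would be saved here
                best_acc = acc
                last_improved = total_batch
        total_batch += 1  # one optimization step
        if total_batch - last_improved > patience:
            return total_batch
    return None


# Made-up accuracies: the best value occurs at batch 200, then it declines.
accs = [0.10, 0.66, 0.78, 0.77, 0.76, 0.75]
stop_at = early_stop_batch(accs, eval_every=100, patience=250)
no_stop = early_stop_batch(accs, eval_every=100, patience=1000)
```

With a patience of 250 the loop ends 251 batches after the last improvement; with a patience of 1000 the same sequence never triggers the stop.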

## Model testing

In [7]:

test()

Loading test data...
Testing...
Test Loss:   0.11, Test Acc:  97.05%
Precision, Recall and F1-Score...
              precision    recall  f1-score   support

          体育       1.00      0.99      0.99      1000
          财经       0.98      0.98      0.98      1000
          房产       1.00      1.00      1.00      1000
          家居       0.97      0.92      0.94      1000
          教育       0.93      0.96      0.95      1000
          科技       0.95      0.98      0.96      1000
          时尚       0.98      0.97      0.97      1000
          时政       0.95      0.96      0.95      1000
          游戏       0.97      0.98      0.98      1000
          娱乐       0.98      0.97      0.98      1000

    accuracy                           0.97     10000
   macro avg       0.97      0.97      0.97     10000
weighted avg       0.97      0.97      0.97     10000
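
The per-class figures in the report above follow directly from a confusion matrix: precision divides the diagonal entry by its column sum, recall divides it by its row sum, and F1 is their harmonic mean. A minimal sketch, assuming the sklearn convention that rows are true classes and columns are predicted classes; the function name and the 3-class matrix are made up for illustration.

```python
def per_class_metrics(cm):
    """Compute (precision, recall, f1) per class from a square confusion matrix.

    cm[i][j] counts samples whose true class is i and predicted class is j.
    """
    n = len(cm)
    results = []
    for k in range(n):
        tp = cm[k][k]
        predicted_k = sum(cm[i][k] for i in range(n))  # column sum
        actual_k = sum(cm[k])                          # row sum
        precision = tp / predicted_k if predicted_k else 0.0
        recall = tp / actual_k if actual_k else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        results.append((precision, recall, f1))
    return results


# Made-up 3-class matrix for illustration (not taken from the run above).
toy_cm = [
    [8, 1, 1],
    [2, 7, 1],
    [0, 2, 8],
]
metrics_per_class = per_class_metrics(toy_cm)
```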

Confusion Matrix...
[[990   0   0   0   3   3   0   2   1   1]
 [  0 977   0   2   5   4   0  12   0   0]
 [  0   0 996   4   0   0   0   