In [2]:
# %load run_han.py

from __future__ import print_function

import os
import sys
import time
from datetime import timedelta

import numpy as np
import tensorflow as tf
from sklearn import metrics

from HANModel import HANConfig, HANModel
from data.cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab

# Dataset files of the cnews corpus (train / test / validation splits and vocabulary).
base_dir = 'data/cnews'
train_dir, test_dir, val_dir, vocab_dir = (
    os.path.join(base_dir, file_name)
    for file_name in ('cnews.train.txt', 'cnews.test.txt', 'cnews.val.txt', 'cnews.vocab.txt')
)

# Checkpoint directory; save_path is the prefix used for the best-validation checkpoint.
save_dir = 'checkpoints/HANModel'
save_path = os.path.join(save_dir, 'best_validation')


def get_time_dif(start_time):
    """Return the wall-clock time elapsed since `start_time`.

    Args:
        start_time: a timestamp previously obtained from `time.time()`.

    Returns:
        A `timedelta` rounded to whole seconds.
    """
    elapsed_seconds = time.time() - start_time
    whole_seconds = int(round(elapsed_seconds))
    return timedelta(seconds=whole_seconds)


def feed_data(x_batch, word_nums, sentence_nums, labels, keep_prob):
    """Map the global `model`'s input placeholders to one batch of values.

    Args:
        x_batch: word ids of the batch, already reshaped into sentences.
        word_nums: per-sentence word counts (fed to `model.word_lengths`).
        sentence_nums: per-document sentence counts (fed to `model.sentence_lengths`).
        labels: batch labels (fed to `model.labels`).
        keep_prob: dropout keep probability (fed to `model.keep_prob`).

    Returns:
        A feed_dict suitable for `session.run`.
    """
    placeholders = (model.inputs, model.word_lengths, model.sentence_lengths,
                    model.labels, model.keep_prob)
    values = (x_batch, word_nums, sentence_nums, labels, keep_prob)
    return dict(zip(placeholders, values))


def evaluate(sess, x_, y_, batch_size=128, sentence_len=30):
    """Evaluate the global `model` on a whole dataset with dropout disabled.

    Each row of `x_` (`config.seq_length` word ids) is reshaped into
    `config.seq_length / sentence_len` "sentences" of `sentence_len` words
    before being fed to the hierarchical model.

    Args:
        sess: session holding the trained variables of `model`.
        x_: encoded documents (a sliceable array, e.g. from `process_file`).
        y_: labels aligned with `x_`, in whatever format `model.labels`
            expects (presumably one-hot; test() argmaxes it over axis 1).
        batch_size: examples per forward pass (default 128, as before).
        sentence_len: words per sentence; must divide `config.seq_length`.

    Returns:
        (y_pred, mean_loss, mean_acc). `y_pred` holds the predicted class id
        of EVERY example, in the same order as `x_`. The loss and accuracy are
        batch-size-weighted means over the dataset.

    Raises:
        ValueError: if `x_` is empty (the means would divide by zero).
    """
    data_len = len(x_)
    if data_len == 0:
        raise ValueError("evaluate() needs at least one example")
    num_sentences = int(config.seq_length / sentence_len)

    predictions = []
    total_loss = 0.0
    total_acc = 0.0
    # Slice batches in input order (rather than via batch_iter) so that the
    # concatenated predictions line up index-for-index with x_ / y_.
    for start in range(0, data_len, batch_size):
        x_batch = x_[start:start + batch_size]
        y_batch = y_[start:start + batch_size]
        batch_len = len(x_batch)
        x_batch = np.reshape(x_batch, (batch_len, num_sentences, sentence_len))
        # Every sentence has exactly sentence_len words and every document has
        # num_sentences sentences, so the length inputs are constant.
        word_lengths = np.zeros((batch_len, num_sentences)) + sentence_len
        sentence_lengths = np.zeros((batch_len,)) + num_sentences

        feed_dict = feed_data(x_batch, word_lengths, sentence_lengths, y_batch, 1.0)
        batch_pred, loss, acc = sess.run(
            [model.y_pred_cls, model.loss, model.acc], feed_dict=feed_dict)
        predictions.append(np.asarray(batch_pred))
        # Weight by batch size so a short final batch does not skew the mean.
        total_loss += loss * batch_len
        total_acc += acc * batch_len

    return np.concatenate(predictions), total_loss / data_len, total_acc / data_len


def train():
    """Train the global `model` with periodic validation and early stopping.

    Reads the module-level `config`, `model`, `word_to_id` and `cat_to_id`
    (set up in the ``__main__`` cell).
    - Every `config.save_per_batch` batches, the merged loss/accuracy
      summaries are written to TensorBoard.
    - Every `config.print_per_batch` batches, train and validation metrics
      are printed. When validation accuracy improves, the variables are saved
      to `save_path`.
    - Training stops early once more than `require_improvement` batches pass
      without a validation improvement.

    The session and the summary writer are always closed on exit.
    """
    print("Configuring TensorBoard and Saver...")
    # When retraining, delete the tensorboard folder first; otherwise the new
    # graph and scalars are written alongside the old event files.
    tensorboard_dir = 'tensorboard/HANModel'
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)

    tf.summary.scalar("loss", model.loss)
    tf.summary.scalar("accuracy", model.acc)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)

    # Saver used for the best-validation checkpoint.
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    print("Loading training and validation data...")
    start_time = time.time()
    x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
    x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # Each document is split into fixed-length "sentences" of 30 words.
    sentence_len = 30
    num_sentences = int(config.seq_length / sentence_len)

    session = tf.Session()
    try:
        session.run(tf.global_variables_initializer())
        writer.add_graph(session.graph)

        print('Training and evaluating...')
        start_time = time.time()
        total_batch = 0  # batches processed so far
        best_acc_val = 0.0  # best validation accuracy seen so far
        last_improved = 0  # batch index of the last validation improvement
        require_improvement = 1000  # stop after this many batches without improvement

        stop_training = False
        for epoch in range(config.num_epochs):
            print('Epoch:', epoch + 1)
            for x_batch, y_batch in batch_iter(x_train, y_train, config.batch_size):
                batch_len = len(x_batch)
                x_batch = np.reshape(x_batch, (batch_len, num_sentences, sentence_len))
                word_lengths = np.zeros((batch_len, num_sentences)) + sentence_len
                sentence_lengths = np.zeros((batch_len,)) + num_sentences

                feed_dict = feed_data(x_batch,
                                      word_lengths,
                                      sentence_lengths,
                                      y_batch,
                                      config.dropout_keep_prob)

                if total_batch % config.save_per_batch == 0:
                    # Periodically write the training scalars to TensorBoard.
                    s = session.run(merged_summary, feed_dict=feed_dict)
                    writer.add_summary(s, total_batch)

                if total_batch % config.print_per_batch == 0:
                    # Measure training metrics with dropout disabled on a COPY of
                    # the feed, so the optimizer step below still applies
                    # config.dropout_keep_prob on this batch.
                    eval_feed = dict(feed_dict)
                    eval_feed[model.keep_prob] = 1.0
                    loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=eval_feed)
                    _, loss_val, acc_val = evaluate(session, x_val, y_val)

                    if acc_val > best_acc_val:
                        # Keep the checkpoint with the best validation accuracy.
                        best_acc_val = acc_val
                        last_improved = total_batch
                        saver.save(sess=session, save_path=save_path)
                        improved_str = '*'
                    else:
                        improved_str = ''

                    time_dif = get_time_dif(start_time)
                    msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
                          + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
                    print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))

                session.run(model.optim, feed_dict=feed_dict)  # one optimization step
                total_batch += 1

                if total_batch - last_improved > require_improvement:
                    # Validation accuracy has not improved for a long time: stop early.
                    print("No optimization for a long time, auto-stopping...")
                    stop_training = True
                    break
            if stop_training:
                break
    finally:
        # Flush pending summaries and release the session's resources.
        writer.close()
        session.close()


def test():
    """Restore the best-validation checkpoint and report test-set metrics.

    Prints the mean test loss/accuracy, a per-class precision/recall/F1 report
    and the confusion matrix. Reads the module-level `config`, `model`,
    `word_to_id`, `cat_to_id` and `categories`. The session is always closed
    on exit.
    """
    print("Loading test data...")
    start_time = time.time()
    x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)

    # Same document layout as in training: sentences of 30 words each.
    sentence_len = 30
    num_sentences = int(config.seq_length / sentence_len)
    batch_size = 128

    session = tf.Session()
    try:
        session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=session, save_path=save_path)  # load the saved model

        print('Testing...')
        _, loss_test, acc_test = evaluate(session, x_test, y_test)
        msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
        print(msg.format(loss_test, acc_test))

        # Predict batch by batch in input order so predictions align with y_test.
        data_len = len(x_test)
        y_test_cls = np.argmax(y_test, 1)
        y_pred_cls = np.zeros(shape=data_len, dtype=np.int32)
        for start_id in range(0, data_len, batch_size):
            end_id = min(start_id + batch_size, data_len)
            batch_len = end_id - start_id
            x_batch = np.reshape(x_test[start_id:end_id], (batch_len, num_sentences, sentence_len))
            feed_dict = {
                model.inputs: x_batch,
                model.word_lengths: np.zeros((batch_len, num_sentences)) + sentence_len,
                model.sentence_lengths: np.zeros((batch_len,)) + num_sentences,
                model.keep_prob: 1.0,
            }
            y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)
    finally:
        session.close()

    # Per-class evaluation
    print("Precision, Recall and F1-Score...")
    print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))

    # Confusion matrix
    print("Confusion Matrix...")
    cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
    print(cm)

    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


if __name__ == '__main__':
    # Build the config, vocabulary and model that train()/test() read as globals.
    print('Configuring HAN model...')
    config = HANConfig()
    if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist
        build_vocab(train_dir, vocab_dir, config.vocab_size)
    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(vocab_dir)
    config.vocab_size = len(words)
    # Start from an empty default graph. Otherwise re-running this cell in the
    # same kernel would add a second copy of the model's ops (and, after
    # train(), duplicate summary ops) to the existing graph.
    tf.reset_default_graph()
    model = HANModel(config)

    # 'train' fits the model; any other value evaluates the saved checkpoint.
    option = 'train'
    if option == 'train':
        train()
    else:
        test()


Configuring HAN model...
Instructions for updating:
keep_dims is deprecated, use keepdims instead
Instructions for updating:
dim is deprecated, use axis instead
Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See tf.nn.softmax_cross_entropy_with_logits_v2.

Configuring TensorBoard and Saver...
Loading training and validation data...
Time usage: 0:00:15
Training and evaluating...
Epoch: 1
Iter:      0, Train Loss:    2.3, Train Acc:  12.50%, Val Loss:    2.3, Val Acc:   8.18%, Time: 0:00:06 *
Iter:    100, Train Loss:    2.1, Train Acc:  15.62%, Val Loss:    2.4, Val Acc:   8.34%, Time: 0:00:38 *
Iter:    200, Train Loss:    1.7, Train Acc:  28.12%, Val Loss:    1.9, Val Acc:  23.68%, Time: 0:01:10 *
Iter:    300, Train Loss:    1.1, Train Acc:  57.03%, Val Loss:    1.3, Val Acc:  49.42%, Time: 0:01:40 *
Epoch: 2
Iter:    400, Train Loss:    1.1, Train Acc:  53.12%, Val Loss:    1.2, Val Acc:  56

In [6]:
# Restore the best-validation checkpoint and print test-set metrics.
test()

Loading test data...
INFO:tensorflow:Restoring parameters from checkpoints/HANModel\best_validation
Testing...
Test Loss:   0.36, Test Acc:  88.89%
Precision, Recall and F1-Score...
              precision    recall  f1-score   support

          体育       0.98      0.98      0.98      1000
          财经       0.92      0.98      0.95      1000
          房产       0.98      0.98      0.98      1000
          家居       0.88      0.51      0.64      1000
          教育       0.89      0.88      0.88      1000
          科技       0.77      0.89      0.83      1000
          时尚       0.90      0.92      0.91      1000
          时政       0.74      0.89      0.81      1000
          游戏       0.95      0.91      0.93      1000
          娱乐       0.94      0.95      0.94      1000

    accuracy                           0.89     10000
   macro avg       0.89      0.89      0.88     10000
weighted avg       0.89      0.89      0.88     10000

Confusion Matrix...
[[978   0   0   0   3   0   1   1   0  