# Quickly Building a Text Classification Application with Jupyter Notebook

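This tutorial shows how to train a text classifier with the EasyTransfer platform in PAI-DSW. With just one configuration file and one .ipynb notebook, you can extract features from the raw data, build the network, and call the loss function and the classification evaluation/prediction routines with a few simple calls. Running this demo requires: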

- python 3.6+
- tensorflow 1.12+
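A minimal sketch for checking these two requirements from inside the notebook before running anything else. It only compares major/minor version numbers against the minimums listed above. The helper names (`major_minor`, `installed_tf_version`, `check_environment`) are hypothetical. The list above states no upper bound on the TensorFlow version, so the sketch does not check one.

```python
import re
import sys
from importlib.util import find_spec

MIN_PYTHON = (3, 6)   # from the requirement list above
MIN_TF = (1, 12)      # from the requirement list above


def major_minor(version):
    """Extract (major, minor) from a version string such as '1.15.0-rc1'."""
    m = re.match(r"(\d+)\.(\d+)", version)
    if m is None:
        raise ValueError("unrecognized version string: %r" % version)
    return int(m.group(1)), int(m.group(2))


def installed_tf_version():
    """Return tensorflow.__version__, or None if TensorFlow cannot be imported."""
    if find_spec("tensorflow") is None:
        return None
    try:
        import tensorflow as tf
    except ImportError:
        return None
    return tf.__version__


def check_environment(python_version, tf_version):
    """Return a list of problems; an empty list means both minimums are met."""
    problems = []
    if tuple(python_version[:2]) < MIN_PYTHON:
        problems.append("Python >= %d.%d is required" % MIN_PYTHON)
    if tf_version is None:
        problems.append("TensorFlow is not installed")
    elif major_minor(tf_version) < MIN_TF:
        problems.append("TensorFlow >= %d.%d is required, found %s"
                        % (MIN_TF + (tf_version,)))
    return problems
```

In a notebook cell, `check_environment(sys.version_info, installed_tf_version())` returns an empty list when both minimums are satisfied.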

## (1) Data Preparation
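The following uses a BERT-based text classifier as an example to show how easy the platform is to use, walking through an end-to-end distributed training/evaluation/prediction workflow (this demo itself runs locally on a single GPU). "End-to-end" here means that training reads the raw data directly, with no need to convert it into BERT feature format beforehand.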


In [1]:
!mkdir -p data
!wget -O ./data/train.csv https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/tutorial/dsw/train.csv
!wget -O ./data/dev.csv https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/tutorial/dsw/dev.csv

--2020-10-27 10:51:04--  https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/tutorial/dsw/train.csv
Resolving atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)... 47.101.88.27
Connecting to atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)|47.101.88.27|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6008185 (5.7M) [text/csv]
Saving to: ‘./data/train.csv’


2020-10-27 10:51:05 (11.8 MB/s) - ‘./data/train.csv’ saved [6008185/6008185]

--2020-10-27 10:51:05--  https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/tutorial/dsw/dev.csv
Resolving atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)... 47.101.88.27
Connecting to atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)|47.101.88.27|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1125139 (1.1M) [text/csv]
Saving to: ‘./da

In [2]:
import pandas as pd

In [3]:
train_set = pd.read_csv('./data/train.csv', header=None, delimiter='\t', encoding='utf8')

In [4]:
dev_set = pd.read_csv('./data/dev.csv', header=None, delimiter='\t', encoding='utf8')

In [5]:
train_set.columns = ['label','content']

In [6]:
train_set.head(2)

|   | label | content |
|---|---|---|
| 0 | agriculture | 加快产城融合 以科技创新引领新城区建设 新城区,城镇化率,中心城区,科技新城,科技创新 |
| 1 | agriculture | 9X10米清雅型别墅，大容量简约民宿风，来自美墅建房的诱惑！ 农村,琉璃瓦,民宿,农村自建房... |


In [7]:
train_set.count()

label      53360
content    53360
dtype: int64

In [8]:
dev_set.count()

0    10000
1    10000
dtype: int64
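Before training, it can help to confirm that every label in the data appears in `label_enumerate_values` from the configuration in section (2), since the classifier has exactly `num_labels` = 15 output classes. Below is a minimal pandas sketch; `unknown_labels` and `label_counts` are hypothetical helpers. Note that `dev_set` keeps its default integer column names (0 and 1, as the count output above shows), so its label column is `0`.

```python
import pandas as pd

# Copied from preprocess_config["label_enumerate_values"] in section (2)
LABELS = ("tech,finance,entertainment,world,car,culture,sports,military,"
          "edu,game,travel,agriculture,house,story,stock").split(",")


def unknown_labels(df, label_col, allowed=LABELS):
    """Return the sorted label values in df[label_col] that are not in `allowed`."""
    return sorted(set(df[label_col].astype(str)) - set(allowed))


def label_counts(df, label_col):
    """Number of rows per label, most frequent first."""
    return df[label_col].value_counts()
```

For example, `unknown_labels(train_set, "label")` and `unknown_labels(dev_set, 0)` return `[]` when every label is covered, and `label_counts(train_set, "label")` shows how balanced the classes are.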

## (2) Define the Configuration File

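Below is our EasyTransfer configuration. For example, predict_checkpoint_path specifies the path of the checkpoint with the best metric on the validation set; here it points to model.ckpt-267, the checkpoint written at the end of this short training run.
For a detailed description of every option, see the EasyTransfer documentation: https://yuque.antfin-inc.com/pai-innovative-algo/apx4dp/dqy4go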

In [9]:
config_json = {
    "worker_hosts": "localhost",
    "task_index": 1,
    "job_name": "chief",
    "num_gpus": 1,
    "num_workers": 1,
    "modelZooBasePath": "/home/admin/jupyter/my_model_zoo",
    "preprocess_config": {
        "input_schema": "label:str:1,content:str:1",
        "first_sequence": "content",
        "second_sequence": None,
        "sequence_length": 16,
        "label_name": "label",
        "label_enumerate_values": "tech,finance,entertainment,world,car,culture,sports,military,edu,game,travel,agriculture,house,story,stock",
        "output_schema": "label,predictions"
    },
    "model_config": {
        "pretrain_model_name_or_path": "pai-bert-tiny-zh",
        "num_labels": 15
    },
    "train_config": {
        "train_input_fp": "./data/train.csv",
        "train_batch_size": 2,
        "num_epochs": 0.01,
        "model_dir": "model_dir",
        "optimizer_config": {
            "learning_rate": 1e-5
        },
        "distribution_config": {
            "distribution_strategy": None
        }
    },
    "evaluate_config": {
        "eval_input_fp": "./data/dev.csv",
        "eval_batch_size": 8
    },
    "predict_config": {
        "predict_checkpoint_path": "model_dir/model.ckpt-267",
        "predict_input_fp": "./data/dev.csv",
        "predict_output_fp": "./data/predict.csv",
        "predict_batch_size": 1
    }
}
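Several fields in this configuration must agree with one another: `first_sequence` and `label_name` must name columns declared in `input_schema`, and `num_labels` must equal the number of entries in `label_enumerate_values`. The sketch below checks these relations on the `config_json` dict. It assumes `input_schema` follows the comma-separated `name:type:length` pattern seen in `"label:str:1,content:str:1"`. The helpers `parse_schema` and `check_config` are hypothetical and not part of EasyTransfer.

```python
def parse_schema(schema):
    """Split 'name:type:length,...' into (name, type, length) tuples.

    Assumes the colon-separated convention used in "label:str:1,content:str:1".
    """
    fields = []
    for item in schema.split(","):
        name, dtype, length = item.strip().split(":")
        fields.append((name, dtype, int(length)))
    return fields


def check_config(cfg):
    """Return a list of inconsistencies between related config fields."""
    pre = cfg["preprocess_config"]
    problems = []
    columns = [name for name, _, _ in parse_schema(pre["input_schema"])]
    for key in ("first_sequence", "label_name"):
        if pre[key] not in columns:
            problems.append("%s=%r is not a column in input_schema" % (key, pre[key]))
    # second_sequence is optional (None for single-sentence classification)
    second = pre.get("second_sequence")
    if second and second not in columns:
        problems.append("second_sequence=%r is not a column in input_schema" % second)
    labels = pre["label_enumerate_values"].split(",")
    num_labels = cfg["model_config"]["num_labels"]
    if len(labels) != num_labels:
        problems.append("num_labels=%d but label_enumerate_values lists %d labels"
                        % (num_labels, len(labels)))
    if len(set(labels)) != len(labels):
        problems.append("label_enumerate_values contains duplicates")
    return problems
```

For the configuration above, `check_config(config_json)` returns `[]`.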

## (3) Define the Classification Application

### Import the EasyTransfer library modules
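- base_model: the parent class that every application inherits from
- Config: the parent class that parses the configuration file
- layers: basic building blocks, such as Embedding and Attention
- model_zoo: the library that manages pretrained models; its get_pretrained_model method loads a BERT model
- preprocessors: manages the preprocessing logic for the various applications
- CSVReader / CSVWriter: reader and writer for CSV-format data
- softmax_cross_entropy: the loss function for classification tasks
- classification_eval_metrics: evaluation metrics for classification tasks, such as Accuracy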

In [10]:
import tensorflow as tf
from easytransfer import base_model, Config
from easytransfer import layers
from easytransfer import model_zoo
from easytransfer import preprocessors
from easytransfer.datasets import CSVReader,CSVWriter
from easytransfer.losses import softmax_cross_entropy
from easytransfer.evaluators import classification_eval_metrics

W1027 10:51:40.673681 139683208804160 deprecation_wrapper.py:119] From /home/admin/.local/lib/python3.6/site-packages/easytransfer/engines/model.py:22: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.

W1027 10:51:40.674461 139683208804160 deprecation_wrapper.py:119] From /home/admin/.local/lib/python3.6/site-packages/easytra

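### Build the graph
The complete training/evaluation/prediction pipeline consists of four functions:
- build_logits: builds the graph
- build_loss: defines the loss function
- build_eval_metrics: defines the evaluation metrics
- build_predictions: defines the prediction output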

In [11]:
class TextClassification(base_model):

    def __init__(self, **kwargs):
        super(TextClassification, self).__init__(**kwargs)
        self.user_defined_config = kwargs["user_defined_config"]

    def build_logits(self, features, mode=None):
        # Preprocess the raw data into the features the model needs, such as input_ids, input_mask and segment_ids
        preprocessor = preprocessors.get_preprocessor(self.pretrain_model_name_or_path,
                                                      user_defined_config=self.user_defined_config)

        # Build the network backbone (the pretrained model)
        model = model_zoo.get_pretrained_model(self.pretrain_model_name_or_path)

        dense = layers.Dense(self.num_labels, kernel_initializer=layers.get_initializer(0.02), name='dense')

        input_ids, input_mask, segment_ids, label_ids = preprocessor(features)

        _, pooled_output = model([input_ids, input_mask, segment_ids], mode=mode)

        logits = dense(pooled_output)

        return logits, label_ids

    def build_loss(self, logits, labels):
        return softmax_cross_entropy(labels, self.num_labels, logits)
    
    def build_eval_metrics(self, logits, labels):
        return classification_eval_metrics(logits, labels, self.num_labels)

    def build_predictions(self, output):
        logits, _ = output
        predictions = dict()
        predictions["predictions"] = tf.argmax(logits, axis=-1, output_type=tf.int32)
        return predictions
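To make the roles of the last three methods concrete, here is a NumPy restatement of what they compute on one batch of logits. It is a sketch under assumptions: that `softmax_cross_entropy` averages the per-example cross-entropy over the batch using one-hot targets, and that accuracy is among the metrics `classification_eval_metrics` reports (the module list above mentions Accuracy). The `_np` functions are hypothetical and are not EasyTransfer APIs.

```python
import numpy as np


def softmax_cross_entropy_np(label_ids, num_labels, logits):
    """Mean softmax cross-entropy over the batch, with one-hot targets."""
    logits = np.asarray(logits, dtype=np.float64)
    one_hot = np.eye(num_labels)[label_ids]                 # shape [batch, num_labels]
    shifted = logits - logits.max(axis=-1, keepdims=True)   # subtract max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return float(-(one_hot * log_probs).sum(axis=-1).mean())


def predictions_np(logits):
    """Same idea as build_predictions: index of the largest logit per row."""
    return np.argmax(np.asarray(logits), axis=-1).astype(np.int32)


def accuracy_np(logits, label_ids):
    """Fraction of rows whose predicted id equals the label id."""
    return float((predictions_np(logits) == np.asarray(label_ids)).mean())
```

Because softmax is monotonic, taking the argmax of the logits, as build_predictions does, picks the same class as taking the argmax of the softmax probabilities, so no softmax is needed at prediction time.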

## (4) Start Training

In [12]:
config = Config(mode="train_and_evaluate_on_the_fly", config_json=config_json)

W1027 10:51:47.386891 139683208804160 deprecation_wrapper.py:119] From /home/admin/.local/lib/python3.6/site-packages/easytransfer/engines/model.py:62: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.

I1027 10:51:47.387405 139683208804160 model.py:62] ***************** modelZooBasePath /home/admin/jupyter/my_model_zoo ***************


In [13]:
app = TextClassification(user_defined_config=config)

W1027 10:51:47.413406 139683208804160 deprecation_wrapper.py:119] From /home/admin/.local/lib/python3.6/site-packages/easytransfer/engines/model.py:731: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.

I1027 10:51:47.597829 139683208804160 model.py:736] total number of training examples 53360
I1027 10:51:47.598403 139683208804160 model.py:239] ***********Running in train_and_evaluate_on_the_fly mode***********
I1027 10:51:47.598751 139683208804160 model.py:248] ***********Disable Tao***********
I1027 10:51:47.599076 139683208804160 model.py:255] ***********Disable AUTO_MIXED_PRECISION***********
I1027 10:51:47.599404 139683208804160 model.py:266] ***********NCCL_MAX_NRINGS 4***********
I1027 10:51:47.599712 139683208804160 model.py:267] ***********NCCL_MIN_NRINGS 4***********
I1027 10:51:47.600023 139683208804160 model.py:268] ***********TF_JIT_PROFILING False***********
I1027 10:51:47.600324 139683208804160 model.py:269] ***********PAI_ENABLE_HLO_DUMPER Fal

In [14]:
train_reader = CSVReader(input_glob=app.train_input_fp,
                         is_training=True,
                         input_schema=app.input_schema,
                         batch_size=app.train_batch_size)

eval_reader = CSVReader(input_glob=app.eval_input_fp,
                        is_training=False,
                        input_schema=app.input_schema,
                        batch_size=app.eval_batch_size)

I1027 10:51:47.652350 139683208804160 reader.py:78] num_parallel_batches 1
I1027 10:51:47.652874 139683208804160 reader.py:79] shuffle_buffer_size None
I1027 10:51:47.653255 139683208804160 reader.py:80] prefetch_buffer_size 1
I1027 10:51:47.653590 139683208804160 reader.py:81] batch_size 2
I1027 10:51:47.653918 139683208804160 reader.py:82] distribution_strategy None
I1027 10:51:47.654216 139683208804160 reader.py:83] num_micro_batches 1
I1027 10:51:47.655252 139683208804160 reader.py:84] input_schema label:str:1,content:str:1
I1027 10:51:47.842130 139683208804160 csv_reader.py:54] ./data/train.csv, total number of training examples 53360
I1027 10:51:47.842768 139683208804160 reader.py:78] num_parallel_batches 1
I1027 10:51:47.843145 139683208804160 reader.py:79] shuffle_buffer_size None
I1027 10:51:47.843453 139683208804160 reader.py:80] prefetch_buffer_size 1
I1027 10:51:47.843765 139683208804160 reader.py:81] batch_size 8
I1027 10:51:47.844077 139683208804160 reader.py:82] distribu

In [15]:
app.run_train_and_evaluate(train_reader=train_reader, eval_reader=eval_reader)

I1027 10:51:47.915409 139683208804160 estimator_training.py:186] Not using Distribute Coordinator.
I1027 10:51:47.915982 139683208804160 training.py:612] Running training and evaluation locally (non-distributed).
I1027 10:51:47.916495 139683208804160 training.py:700] Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 26680 or save_checkpoints_secs None.
W1027 10:51:47.923486 139683208804160 deprecation.py:323] From /home/admin/.local/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
I1027 10:51:47.946563 139683208804160 reader.py:89] Random shuffle on the whole 53360 training e
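The `predict_checkpoint_path` in section (2) ends in `model.ckpt-267`, and the log above reports `save_checkpoints_steps 26680`. Both numbers follow from the training settings: 53360 examples at `train_batch_size` 2 give 26680 steps per epoch, and `num_epochs` 0.01 of that is 266.8, i.e. 267 steps. The sketch below reproduces the arithmetic, which is useful when you change `num_epochs` or the batch size and need to update `predict_checkpoint_path` to match. Rounding up is an assumption about how the step count is derived (`int(x) + 1` would also give 267 here). The sketch also relies on a checkpoint being written at the last training step. The function names are hypothetical.

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Number of full batches in one pass over the training set."""
    return num_examples // batch_size


def total_train_steps(num_examples, batch_size, num_epochs):
    """Training steps for a (possibly fractional) number of epochs, rounded up."""
    return math.ceil(num_examples * num_epochs / batch_size)


def final_checkpoint_path(model_dir, num_examples, batch_size, num_epochs):
    """Checkpoint prefix expected after training: <model_dir>/model.ckpt-<last step>."""
    step = total_train_steps(num_examples, batch_size, num_epochs)
    return "%s/model.ckpt-%d" % (model_dir, step)
```

With `num_epochs` set to 1, for instance, the sketch gives `model_dir/model.ckpt-26680`.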

## (5) Run Prediction

In [16]:
config = Config(mode="predict_on_the_fly", config_json=config_json)   
app = TextClassification(user_defined_config=config)
pred_reader = CSVReader(input_glob=app.predict_input_fp,
                        is_training=False,
                        input_schema=app.input_schema,
                        batch_size=app.predict_batch_size)

pred_writer = CSVWriter(output_glob=app.predict_output_fp,
                        output_schema=app.output_schema)

app.run_predict(reader=pred_reader, writer=pred_writer, 
                checkpoint_path=app.predict_checkpoint_path)

I1027 10:53:17.824488 139683208804160 model.py:62] ***************** modelZooBasePath /home/admin/jupyter/my_model_zoo ***************
I1027 10:53:17.865181 139683208804160 model.py:772] total number of predicting examples 10000
I1027 10:53:17.865805 139683208804160 model.py:437] ***********Running in predict_on_the_fly mode***********
W1027 10:53:17.866708 139683208804160 estimator.py:1811] Using temporary folder as model directory: /tmp/tmpus3ibqy0
I1027 10:53:17.867390 139683208804160 estimator.py:209] Using config: {'_model_dir': '/tmp/tmpus3ibqy0', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': intra_op_parallelism_threads: 1024
inter_op_parallelism_threads: 1024
gpu_options {
  per_process_gpu_memory_fraction: 1.0
  allow_growth: true
  force_gpu_compatible: true
}
allow_soft_placement: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_

In [17]:
pred = pd.read_csv('./data/predict.csv', header=None, delimiter='\t', encoding='utf8')

In [18]:
pred.columns = ['true_label','pred_label_id']

In [19]:
pred.head(10)

|   | true_label | pred_label_id |
|---|---|---|
| 0 | b'agriculture' | 4 |
| 1 | b'agriculture' | 4 |
| 2 | b'agriculture' | 4 |
| 3 | b'agriculture' | 4 |
| 4 | b'agriculture' | 10 |
| 5 | b'agriculture' | 4 |
| 6 | b'agriculture' | 4 |
| 7 | b'agriculture' | 0 |
| 8 | b'agriculture' | 10 |
| 9 | b'agriculture' | 10 |
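The prediction file stores the true label as the text `b'agriculture'` (a Python bytes literal written out as a string) and the predicted class as an integer id. The sketch below cleans the first column and maps ids back to names. It assumes, without confirmation from this tutorial, that id *i* is the *i*-th entry of `label_enumerate_values`. Under that assumption, ids 4, 10 and 0 are `car`, `travel` and `tech`, while `agriculture` is id 11. None of the ten rows shown above is therefore correct, which is unsurprising for a demo run of only 267 steps with `pai-bert-tiny-zh` and `sequence_length` 16. `decode_label`, `add_readable_columns` and `accuracy` are hypothetical helpers.

```python
import pandas as pd

# Copied from preprocess_config["label_enumerate_values"]; the id -> name order is an assumption
LABELS = ("tech,finance,entertainment,world,car,culture,sports,military,"
          "edu,game,travel,agriculture,house,story,stock").split(",")


def decode_label(value):
    """Turn the string "b'agriculture'" into "agriculture"; leave other values as-is."""
    text = str(value)
    if text.startswith("b'") and text.endswith("'"):
        return text[2:-1]
    return text


def add_readable_columns(pred, labels=LABELS):
    """Return a copy of `pred` with a cleaned true_label and a pred_label name column."""
    out = pred.copy()
    out["true_label"] = out["true_label"].map(decode_label)
    out["pred_label"] = out["pred_label_id"].map(lambda i: labels[int(i)])
    return out


def accuracy(readable):
    """Fraction of rows whose predicted label name matches the true label."""
    return float((readable["true_label"] == readable["pred_label"]).mean())
```

Applied to the `pred` DataFrame above, `accuracy(add_readable_columns(pred))` gives the dev-set accuracy under the id-order assumption.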
