# Quickly Building a Text Classification Application with Jupyter Notebook

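This tutorial shows how to train a text classifier with the EasyTransfer platform in PAI-DSW. With just one configuration file and one .ipynb notebook, you can extract features from the raw data, build the network, and invoke the loss function and classification evaluation/prediction with a few simple calls. Running this demo requires: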

- python 3.6+
- tensorflow 1.12+
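If you want to confirm that the notebook kernel meets these minimums before installing anything, a small check like the one below is enough. This is a sketch, not part of EasyTransfer: `version_at_least` and `check_environment` are hypothetical helper names, and the check only enforces the lower bounds listed above (it does not detect versions that are too new).

```python
import importlib.util
import re
import sys


def version_at_least(version, minimum):
    """Compare the leading numeric parts of a version string with a tuple."""
    parts = tuple(int(n) for n in re.findall(r"\d+", version)[:len(minimum)])
    return parts >= tuple(minimum)


def check_environment():
    """Return a list of problems; an empty list means the minimums are met."""
    problems = []
    if sys.version_info < (3, 6):
        problems.append("Python 3.6+ is required, found %d.%d" % sys.version_info[:2])
    if importlib.util.find_spec("tensorflow") is None:
        problems.append("tensorflow is not installed")
    else:
        import tensorflow as tf  # imported lazily so the check runs without TF
        if not version_at_least(tf.__version__, (1, 12)):
            problems.append("TensorFlow 1.12+ is required, found " + tf.__version__)
    return problems


print(check_environment() or "environment meets the listed minimums")
```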

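## (1) Data Preparation
The following example uses BERT-based text classification to walk through an end-to-end distributed training/evaluation/prediction workflow and show how easy the platform is to use. "End-to-end" here means the raw data can be read in and trained on directly, without first converting it to BERT feature format.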
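To make "BERT feature format" concrete, the sketch below shows the kind of conversion the platform performs for you: one sentence becomes `input_ids`, `input_mask` and `segment_ids` of a fixed length. It is a simplified illustration, not EasyTransfer's preprocessor: it tokenizes character by character instead of using WordPiece, and `TOY_VOCAB` and `to_bert_features` are hypothetical, with illustrative ids only (a real model's vocab.txt defines them).

```python
# Hypothetical toy vocabulary; the ids are illustrative only.
TOY_VOCAB = {"[PAD]": 0, "[UNK]": 100, "[CLS]": 101, "[SEP]": 102, "农": 1, "村": 2}


def to_bert_features(text, vocab, sequence_length=16):
    """Convert one sentence into single-sequence BERT-style features.

    Simplification: every character is treated as a token (no WordPiece).
    Returns (input_ids, input_mask, segment_ids), each of length sequence_length.
    """
    tokens = list(text)[: sequence_length - 2]  # leave room for [CLS] and [SEP]
    tokens = ["[CLS]"] + tokens + ["[SEP]"]
    input_ids = [vocab.get(tok, vocab["[UNK]"]) for tok in tokens]
    input_mask = [1] * len(input_ids)   # 1 marks real tokens
    segment_ids = [0] * len(input_ids)  # single sentence, so every token is segment 0
    padding = sequence_length - len(input_ids)
    input_ids += [vocab["[PAD]"]] * padding
    input_mask += [0] * padding          # 0 marks padding
    segment_ids += [0] * padding
    return input_ids, input_mask, segment_ids


ids, mask, segs = to_bert_features("农村创业", TOY_VOCAB, sequence_length=8)
print(ids)   # [101, 1, 2, 100, 100, 102, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 1, 0, 0]
```

The default `sequence_length=16` mirrors the value used in the configuration later in this notebook; longer texts are truncated, which is why that setting matters.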


In [9]:
!wget -O easytransfer-1.0.0-py3-none-any.whl https://pai-public-data.oss-cn-beijing.aliyuncs.com/public_whl/easytransfer-1.0.0-py3-none-any.whl

--2020-09-16 20:43:23--  https://pai-public-data.oss-cn-beijing.aliyuncs.com/public_whl/easytransfer-1.0.0-py3-none-any.whl
Resolving pai-public-data.oss-cn-beijing.aliyuncs.com (pai-public-data.oss-cn-beijing.aliyuncs.com)... 59.110.185.63
Connecting to pai-public-data.oss-cn-beijing.aliyuncs.com (pai-public-data.oss-cn-beijing.aliyuncs.com)|59.110.185.63|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 179210 (175K) [application/octet-stream]
Saving to: ‘easytransfer-1.0.0-py3-none-any.whl’


2020-09-16 20:43:23 (1.70 MB/s) - ‘easytransfer-1.0.0-py3-none-any.whl’ saved [179210/179210]



In [10]:
!pip install ./easytransfer-1.0.0-py3-none-any.whl --user

Processing ./easytransfer-1.0.0-py3-none-any.whl
Collecting joblib==0.14.1 (from easytransfer==1.0.0)
  Downloading https://mirrors.aliyun.com/pypi/packages/28/5c/cf6a2b65a321c4a209efcdf64c2689efae2cb62661f8f6f4bb28547cf1bf/joblib-0.14.1-py2.py3-none-any.whl (294kB)
    100% |████████████████████████████████| 296kB 72.4MB/s ta 0:00:01
Collecting sentencepiece (from easytransfer==1.0.0)
  Downloading https://mirrors.aliyun.com/pypi/packages/68/e5/0366f50a00db181f4b7f3bdc408fc7c4177657f5bf45cb799b79fb4ce15c/sentencepiece-0.1.92-cp36-cp36m-manylinux1_x86_64.whl (1.2MB)
    100% |████████████████████████████████| 1.2MB 82.7MB/s ta 0:00:01
Installing collected packages: joblib, sentencepiece, easytransfer
  Found existing installation: joblib 0.16.0
    Uninstalling joblib-0.16.0:
      Successfully uninstalled joblib-0.16.0
Successfully installed easytransfer-1.0.0 joblib-0.14.1 sentencepiece-0.1.92

In [1]:
!mkdir data
!wget -O ./data/train.csv https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/tutorial/dsw/train.csv
!wget -O ./data/dev.csv https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/tutorial/dsw/dev.csv

--2020-09-16 20:39:38--  https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/tutorial/dsw/train.csv
Resolving atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)... 106.14.228.37
Connecting to atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)|106.14.228.37|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6008185 (5.7M) [text/csv]
Saving to: ‘./data/train.csv’


2020-09-16 20:39:38 (25.1 MB/s) - ‘./data/train.csv’ saved [6008185/6008185]

--2020-09-16 20:39:39--  https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/tutorial/dsw/dev.csv
Resolving atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)... 106.14.228.37
Connecting to atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)|106.14.228.37|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1125139 (1.1M) [text/csv]
Saving to: ‘

In [2]:
import pandas as pd

In [3]:
!head ./data/train.csv

agriculture	加快产城融合 以科技创新引领新城区建设 新城区,城镇化率,中心城区,科技新城,科技创新
agriculture	9X10米清雅型别墅，大容量简约民宿风，来自美墅建房的诱惑！ 农村,琉璃瓦,民宿,农村自建房,自建房
agriculture	15000亿基建投资！为特色小镇、田园综合体带来哪些新的机遇？ 仙新路,和燕路,南京,总投资,2018年南京市城乡建设计划,高速公路,轨道交通,省干线公路
agriculture	“滞销大爷”家里没滞销水果 原图拍摄者称要维权 店家,著作权,滥用,悲情牌,滞销大爷
agriculture	史上最全经济作物需肥总结 盛期,生育期,营养生长,氧化钾,K2O,需肥量,生长量,硝态氮,五氧化二磷,P2O5
agriculture	植保无人机飞控品牌有哪些？
agriculture	「牛羊催肥」肉羊吃什么饲料催肥增重长得快，一天长7两-1斤 美力盾小杨,旺长素,玉米粒,肉羊,育肥羊
agriculture	农村创业别养猪，今年养猪只剩哭，猪价持续下跌原来是这些原因 养猪,利润率,农村
agriculture	你什么时候回来，不要在大城市打工了 嫁祸于人,农村
agriculture	到底怎样才能称之为田园综合体！


In [4]:
df = pd.read_csv('./data/train.csv', header=None, delimiter='\t', encoding='utf8')

In [5]:
df.columns = ['label','content']

In [6]:
df.head(2)

|   | label | content |
|---|---|---|
| 0 | agriculture | 加快产城融合 以科技创新引领新城区建设 新城区,城镇化率,中心城区,科技新城,科技创新 |
| 1 | agriculture | 9X10米清雅型别墅，大容量简约民宿风，来自美墅建房的诱惑！ 农村,琉璃瓦,民宿,农村自建房... |
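Before training, it can help to check that every label in the data is one of the 15 names listed in `label_enumerate_values` in the configuration below. The sketch uses only the standard library and assumes, as `input_schema` implies, that each line has exactly two tab-separated fields. `label_counts`, `ALLOWED_LABELS` and `SAMPLE` are hypothetical names; `SAMPLE` holds two rows copied from the `head` output above so the snippet runs on its own.

```python
import csv
import io
from collections import Counter

# The 15 labels from label_enumerate_values in the configuration.
ALLOWED_LABELS = set(
    "tech,finance,entertainment,world,car,culture,sports,military,"
    "edu,game,travel,agriculture,house,story,stock".split(",")
)

# Two rows copied from the `head` output. In the notebook, pass
# open("./data/train.csv", encoding="utf8", newline="") instead.
SAMPLE = (
    "agriculture\t加快产城融合 以科技创新引领新城区建设 新城区,城镇化率,中心城区,科技新城,科技创新\n"
    "agriculture\t植保无人机飞控品牌有哪些？\n"
)


def label_counts(fileobj):
    """Count the first column of a two-column, tab-separated file."""
    counts = Counter()
    for row in csv.reader(fileobj, delimiter="\t", quoting=csv.QUOTE_NONE):
        if not row:
            continue  # skip blank lines
        if len(row) != 2:
            raise ValueError("expected 2 tab-separated fields, got %r" % (row,))
        counts[row[0]] += 1
    return counts


counts = label_counts(io.StringIO(SAMPLE))
unknown = set(counts) - ALLOWED_LABELS
print(counts, "unknown labels:", unknown or "none")
```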


## （二）定义配置文件

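Below is our EasyTransfer configuration. For example, predict_checkpoint_path specifies the path of the checkpoint that achieved the best metric on the validation set.
For a detailed explanation of each option, see the EasyTransfer documentation: https://yuque.antfin-inc.com/pai/transfer-learning/zyib3t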

In [7]:
config_json = {
    "worker_hosts": "localhost",
    "task_index": 1,
    "job_name": "chief",
    "num_gpus": 1,
    "num_workers": 1,
    "modelZooBasePath": "/home/admin/jupyter/my_model_zoo",
    "preprocess_config": {
        "input_schema": "label:str:1,content:str:1",
        "first_sequence": "content",
        "second_sequence": None,
        "sequence_length": 16,
        "label_name": "label",
        "label_enumerate_values": "tech,finance,entertainment,world,car,culture,sports,military,edu,game,travel,agriculture,house,story,stock",
        "output_schema": "label,predictions"
    },
    "model_config": {
        "pretrain_model_name_or_path": "pai-bert-tiny-zh",
        "num_labels": 15
    },
    "train_config": {
        "train_input_fp": "./data/train.csv",
        "train_batch_size": 2,
        "num_epochs": 1,
        "model_dir": "model_dir",
        "optimizer_config": {
            "learning_rate": 1e-5
        },
        "distribution_config": {
            "distribution_strategy": None
        }
    },
    "evaluate_config": {
        "eval_input_fp": "./data/dev.csv",
        "eval_batch_size": 8
    },
    "predict_config": {
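        # Must name a checkpoint that actually exists in model_dir; with the train_config above,
        # the training log reports save_checkpoints_steps 26680.
        "predict_checkpoint_path": "model_dir/model.ckpt-834",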
        "predict_input_fp": "./data/dev.csv",
        "predict_output_fp": "./data/predict.csv"
    }
}
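Two things in this configuration are easy to get out of sync: `num_labels` must equal the number of names in `label_enumerate_values`, and `predict_checkpoint_path` must point to a step that training actually reaches. The sketch below checks both, assuming the step count is ceil(examples × epochs / batch_size); the 53360 figure comes from the training log in part (4). With `train_batch_size` 2 and `num_epochs` 1 this gives 26680, which suggests `model.ckpt-834` was produced under different settings. `check_labels` and `expected_final_step` are hypothetical helpers, not EasyTransfer APIs.

```python
import math

# Values copied from config_json so the snippet runs on its own;
# in the notebook you could pass config_json itself.
config_subset = {
    "preprocess_config": {
        "label_enumerate_values": "tech,finance,entertainment,world,car,culture,sports,"
                                  "military,edu,game,travel,agriculture,house,story,stock",
    },
    "model_config": {"num_labels": 15},
    "train_config": {"train_batch_size": 2, "num_epochs": 1},
}


def check_labels(cfg):
    """Ensure num_labels matches the number of distinct label names."""
    labels = cfg["preprocess_config"]["label_enumerate_values"].split(",")
    if len(set(labels)) != len(labels):
        raise ValueError("label_enumerate_values contains duplicates")
    if len(labels) != cfg["model_config"]["num_labels"]:
        raise ValueError("num_labels=%d but %d labels are listed"
                         % (cfg["model_config"]["num_labels"], len(labels)))
    return labels


def expected_final_step(cfg, num_train_examples):
    """Assumed step count: ceil(examples * epochs / batch_size)."""
    train = cfg["train_config"]
    return math.ceil(num_train_examples * train["num_epochs"] / train["train_batch_size"])


labels = check_labels(config_subset)
final_step = expected_final_step(config_subset, 53360)
print(len(labels), "model.ckpt-%d" % final_step)  # 15 model.ckpt-26680
```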

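## (3) Define the Classification Application

### Importing the easytransfer Library
- base_model: the parent class that every application inherits from
- Config: the parent class used to parse the configuration file
- layers: basic building blocks such as Embedding and Attention
- model_zoo: the component library that manages pretrained models; its get_pretrained_model method loads a BERT model
- preprocessors: the preprocessing logic for the various applications
- CSVReader: a reader for CSV-format data
- softmax_cross_entropy: the loss function for classification tasks
- classification_eval_metrics: evaluation metrics for classification tasks, such as accuracy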

In [11]:
import tensorflow as tf
from easytransfer import base_model, Config
from easytransfer import layers
from easytransfer import model_zoo
from easytransfer import preprocessors
from easytransfer.datasets import CSVReader,CSVWriter
from easytransfer.losses import softmax_cross_entropy
from easytransfer.evaluators import classification_eval_metrics

W0916 20:44:11.559335 140053906782016 deprecation_wrapper.py:119] From /home/admin/.local/lib/python3.6/site-packages/easytransfer/engines/model.py:22: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.

W0916 20:44:11.560436 140053906782016 deprecation_wrapper.py:119] From /home/admin/.local/lib/python3.6/site-packages/easytransfer/engines/model.py:22: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.

W0916 20:44:11.561143 140053906782016 deprecation_wrapper.py:119] From /home/admin/.local/lib/python3.6/site-packages/easytransfer/engines/model.py:27: The name tf.set_random_seed is deprecated. Please use tf.compat.v1.set_random_seed instead.

W0916 20:44:11.577982 140053906782016 deprecation_wrapper.py:119] From /home/admin/.local/lib/python3.6/site-packages/easytransfer/engines/model.py:28: The name tf.reset_default_graph is deprecated. Please use tf.compat.v1.reset_default_graph instead.



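### Building the Graph
The complete training/evaluation/prediction pipeline is made up of four functions:
- build_logits: builds the graph and returns the logits together with the label ids
- build_loss: defines the loss function
- build_eval_metrics: defines the evaluation metrics
- build_predictions: defines the prediction outputs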
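For intuition about what `build_loss` and `build_eval_metrics` compute, here is a plain-Python version of the loss and of accuracy for one batch. It assumes `softmax_cross_entropy` is the usual mean softmax cross-entropy against one-hot labels of depth `num_labels` and that accuracy is among the metrics, as the import list above states; it is not EasyTransfer's implementation, and `mean_softmax_cross_entropy` and `accuracy` are hypothetical names.

```python
import math


def mean_softmax_cross_entropy(label_ids, num_labels, logits):
    """Mean of -log(softmax(logits)[label]) over a batch (one-hot targets)."""
    total = 0.0
    for label, row in zip(label_ids, logits):
        if len(row) != num_labels or not 0 <= label < num_labels:
            raise ValueError("logits/labels do not match num_labels")
        m = max(row)  # subtract the max before exp() for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]
    return total / len(label_ids)


def accuracy(label_ids, logits):
    """Fraction of rows whose highest logit is at the true label index."""
    hits = sum(
        1 for label, row in zip(label_ids, logits)
        if max(range(len(row)), key=row.__getitem__) == label
    )
    return hits / len(label_ids)


# With all-zero logits every class gets probability 1/15, so the loss is log(15).
print(mean_softmax_cross_entropy([11], 15, [[0.0] * 15]))  # ≈ 2.70805
print(accuracy([1, 1], [[0.0, 1.0], [2.0, 0.0]]))  # 0.5
```

The log(15) case is a useful sanity check: an untrained 15-class classifier should start with a loss near 2.7.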

In [12]:
class TextClassification(base_model):

    def __init__(self, **kwargs):
        super(TextClassification, self).__init__(**kwargs)
        self.user_defined_config = kwargs["user_defined_config"]

    def build_logits(self, features, mode=None):
        # Preprocess the raw data into the features the model needs, e.g. input_ids, input_mask, segment_ids
        preprocessor = preprocessors.get_preprocessor(self.pretrain_model_name_or_path,
                                                      user_defined_config=self.user_defined_config)

        # Build the network backbone
        model = model_zoo.get_pretrained_model(self.pretrain_model_name_or_path)

        dense = layers.Dense(self.num_labels, kernel_initializer=layers.get_initializer(0.02), name='dense')

        input_ids, input_mask, segment_ids, label_ids = preprocessor(features)

        _, pooled_output = model([input_ids, input_mask, segment_ids], mode=mode)

        logits = dense(pooled_output)

        return logits, label_ids

    def build_loss(self, logits, labels):
        return softmax_cross_entropy(labels, self.num_labels, logits)
    
    def build_eval_metrics(self, logits, labels):
        return classification_eval_metrics(logits, labels, self.num_labels)

    def build_predictions(self, output):
        logits, _ = output
        predictions = dict()
        predictions["predictions"] = tf.argmax(logits, axis=-1, output_type=tf.int32)
        return predictions
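`build_predictions` outputs integer class ids (the `tf.argmax` of the logits), not label names. To read them as names you need the id-to-label mapping; the sketch below assumes id i corresponds to the i-th entry of `label_enumerate_values`, which is a common convention but is not confirmed by this notebook. `LABELS`, `argmax` and `decode_predictions` are hypothetical helpers.

```python
# Assumption: class id i is the i-th name in label_enumerate_values.
LABELS = (
    "tech,finance,entertainment,world,car,culture,sports,military,"
    "edu,game,travel,agriculture,house,story,stock"
).split(",")


def argmax(row):
    """Index of the largest value (the first one if there is a tie)."""
    return max(range(len(row)), key=row.__getitem__)


def decode_predictions(logits_batch, labels=LABELS):
    """Map each row of logits to its predicted label name."""
    return [labels[argmax(row)] for row in logits_batch]


logits = [0.0] * len(LABELS)
logits[11] = 3.0
print(decode_predictions([logits]))  # ['agriculture']
```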

## (4) Start Training

In [13]:
config = Config(mode="train_and_evaluate_on_the_fly", config_json=config_json)

W0916 20:44:43.581403 140053906782016 deprecation_wrapper.py:119] From /home/admin/.local/lib/python3.6/site-packages/easytransfer/engines/model.py:61: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.

I0916 20:44:43.582599 140053906782016 model.py:61] ***************** modelZooBasePath /home/admin/jupyter/my_model_zoo ***************


In [14]:
app = TextClassification(user_defined_config=config)

W0916 20:44:44.966868 140053906782016 deprecation_wrapper.py:119] From /home/admin/.local/lib/python3.6/site-packages/easytransfer/engines/model.py:731: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.

I0916 20:44:45.229859 140053906782016 model.py:736] total number of training examples 53360
I0916 20:44:45.230817 140053906782016 model.py:238] ***********Running in train_and_evaluate_on_the_fly mode***********
I0916 20:44:45.231338 140053906782016 model.py:247] ***********Disable Tao***********
I0916 20:44:45.231803 140053906782016 model.py:254] ***********Disable AUTO_MIXED_PRECISION***********
I0916 20:44:45.232288 140053906782016 model.py:265] ***********NCCL_MAX_NRINGS 4***********
I0916 20:44:45.232758 140053906782016 model.py:266] ***********NCCL_MIN_NRINGS 4***********
I0916 20:44:45.233205 140053906782016 model.py:267] ***********TF_JIT_PROFILING False***********
I0916 20:44:45.233658 140053906782016 model.py:268] ***********PAI_ENABLE_HLO_DUMPER Fal

In [15]:
train_reader = CSVReader(input_glob=app.train_input_fp,
                         is_training=True,
                         input_schema=app.input_schema,
                         batch_size=app.train_batch_size)

eval_reader = CSVReader(input_glob=app.eval_input_fp,
                        is_training=False,
                        input_schema=app.input_schema,
                        batch_size=app.eval_batch_size)

I0916 20:44:47.737787 140053906782016 reader.py:78] num_parallel_batches 1
I0916 20:44:47.739200 140053906782016 reader.py:79] shuffle_buffer_size None
I0916 20:44:47.739801 140053906782016 reader.py:80] prefetch_buffer_size 1
I0916 20:44:47.740364 140053906782016 reader.py:81] batch_size 2
I0916 20:44:47.740909 140053906782016 reader.py:82] distribution_strategy None
I0916 20:44:47.741471 140053906782016 reader.py:83] num_micro_batches 1
I0916 20:44:47.742014 140053906782016 reader.py:84] input_schema label:str:1,content:str:1
I0916 20:44:48.012047 140053906782016 csv_reader.py:54] ./data/train.csv, total number of training examples 53360
I0916 20:44:48.012912 140053906782016 reader.py:78] num_parallel_batches 1
I0916 20:44:48.013436 140053906782016 reader.py:79] shuffle_buffer_size None
I0916 20:44:48.013882 140053906782016 reader.py:80] prefetch_buffer_size 1
I0916 20:44:48.014325 140053906782016 reader.py:81] batch_size 8
I0916 20:44:48.014764 140053906782016 reader.py:82] distribu
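The two readers differ mainly in shuffling and batch size: the training reader uses batch_size 2 and, according to the training log, shuffles all 53360 training examples, while the evaluation reader uses batch_size 8. The sketch below models that behavior in plain Python. `iterate_batches` is a hypothetical helper, not `CSVReader`'s implementation, and whether the real reader keeps or drops a final partial batch is not shown in this notebook.

```python
import math
import random


def iterate_batches(rows, batch_size, is_training, seed=None):
    """Yield lists of at most batch_size rows.

    When is_training is True the whole dataset is shuffled first; the final
    batch may be smaller than batch_size.
    """
    rows = list(rows)
    if is_training:
        random.Random(seed).shuffle(rows)
    for start in range(0, len(rows), batch_size):
        yield rows[start:start + batch_size]


batches = list(iterate_batches(range(10), batch_size=8, is_training=False))
print([len(b) for b in batches])  # [8, 2]
print(math.ceil(53360 / 2))       # 26680 training batches per epoch at batch size 2
```

At batch size 2, one epoch is 26680 batches, which matches the save_checkpoints_steps value in the training log below.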

In [16]:
app.run_train_and_evaluate(train_reader=train_reader, eval_reader=eval_reader)

I0916 20:44:49.485316 140053906782016 estimator_training.py:186] Not using Distribute Coordinator.
I0916 20:44:49.486536 140053906782016 training.py:612] Running training and evaluation locally (non-distributed).
I0916 20:44:49.487238 140053906782016 training.py:700] Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 26680 or save_checkpoints_secs None.
W0916 20:44:49.497976 140053906782016 deprecation.py:323] From /opt/conda/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
I0916 20:44:49.532213 140053906782016 reader.py:89] Random shuffle on the whole 53360 training examples


KeyboardInterrupt: 