In [None]:

# Load the dataset with pandas
%cd ~
import pandas as pd

# Both splits are tab-separated with no header row
# (column 0 = text, column 1 = category label).
train, dev = (
    pd.read_table(split_path, sep='\t', header=None)
    for split_path in ('data/data103654/train.txt',   # training split
                       'data/data103654/dev.txt')     # validation split (officially pre-split)
)

# Stack train and dev so the statistics below cover every labeled sample
total = pd.concat([train, dev], axis=0)

In [None]:
# Class-label frequencies over train + dev (most frequent first)
total.loc[:, 1].value_counts()

科技    162245
股票    153949
体育    130982
娱乐     92228
时政     62867
社会     50541
教育     41680
财经     36963
家居     32363
游戏     24283
房产     19922
时尚     13335
彩票      7598
星座      3515
Name: 1, dtype: int64

In [None]:
import matplotlib
import matplotlib.pyplot as plt

import matplotlib.font_manager as font_manager

# Configure matplotlib so Chinese tick labels render
matplotlib.rcParams['font.sans-serif'] = ['FZSongYi-Z13S']  # default sans-serif font; must be installed in the environment
matplotlib.rcParams['axes.unicode_minus'] = False  # keep '-' from rendering as a box when a CJK font is active
# Global font size
matplotlib.rcParams['font.size'] = 16

# Relative frequency of each class label across train + dev.
# Plot on an explicit Axes and label it so the figure stands on its own.
fig, ax = plt.subplots()
total[1].value_counts(normalize=True).plot(kind='bar', ax=ax)
ax.set(title='Label distribution (train + dev)', xlabel='Label', ylabel='Fraction of samples')

# Text-length statistics (characters per text): the texts are short, at most 48 characters
total[0].map(len).describe()

count    832471.000000
mean         19.388112
std           4.097139
min           2.000000
25%          17.000000
50%          20.000000
75%          23.000000
max          48.000000
Name: 0, dtype: float64

In [None]:
# The 14 target classes, in order of first appearance in the training set;
# a label's position in this list becomes its integer class id.
label_list = list(train[1].unique())
print(label_list)

# Sanity checks: the ERNIE configs hard-code num_labels = 14, and every dev label
# needs an id, otherwise the later dev label conversion fails with KeyError.
assert len(label_list) == 14, f'expected 14 classes, got {len(label_list)}'
assert set(dev[1].unique()) <= set(label_list), \
    'dev contains labels not present in train (or labels were already converted to ids)'

['科技', '体育', '时政', '股票', '娱乐', '教育', '家居', '财经', '房产', '社会', '游戏', '彩票', '星座', '时尚']

In [None]:
# id -> label lookup: a label's position in label_list is its class id
id_label_dict = dict(enumerate(label_list))
print(id_label_dict)
#################################################
# label -> id lookup: the inverse of id_label_dict
label_id_dict = {label: idx for idx, label in id_label_dict.items()}
print(label_id_dict)

{0: '科技', 1: '体育', 2: '时政', 3: '股票', 4: '娱乐', 5: '教育', 6: '家居', 7: '财经', 8: '房产', 9: '社会', 10: '游戏', 11: '彩票', 12: '星座', 13: '时尚'}
{'科技': 0, '体育': 1, '时政': 2, '股票': 3, '娱乐': 4, '教育': 5, '家居': 6, '财经': 7, '房产': 8, '社会': 9, '游戏': 10, '彩票': 11, '星座': 12, '时尚': 13}

In [None]:

# Convert training labels (column 1) from category names to integer ids.
# Skipped when the column is already integer, so re-running the cell does not
# raise KeyError by looking up ids in label_id_dict.
if not pd.api.types.is_integer_dtype(train[1]):
    train[1] = train[1].map(lambda label: label_id_dict[label])
train.head()

                       0  1
0            网易第三季度业绩低于分析师预期  0
1  巴萨1年前地狱重现这次却是天堂 再赴魔鬼客场必翻盘  1
2         美国称支持向朝鲜提供紧急人道主义援助  2
3           增资交银康联 交行夺参股险商首单  3
4               午盘：原材料板块领涨大盘  3

In [None]:
import os

# Save the processed splits as tab-separated files (no header, no index).
# Paths are relative to the home directory set by `%cd ~` in the first cell and
# match the data_path entries of the training config (/home/aistudio/mydata/...).
os.makedirs('mydata/train', exist_ok=True)  # to_csv does not create missing directories
os.makedirs('mydata/test', exist_ok=True)
train.to_csv('mydata/train/train.txt', sep='\t', header=None, index=False)  # training split

# Convert dev labels to ids; skipped when already converted so the cell can be re-run
if not pd.api.types.is_integer_dtype(dev[1]):
    dev[1] = dev[1].map(lambda label: label_id_dict[label])
dev.to_csv('mydata/test/test.txt', sep='\t', header=None, index=False)  # official dev split, saved as the test set

In [None]:
# Unpack ERNIE3.zip (quiet, overwrite without prompting, auto-convert text files);
# it provides the ~/ERNIE directory entered on the next line
!unzip -qoa data/data158935/ERNIE3.zip
# Switch into ERNIE's text-classification task directory
%cd ~/ERNIE/applications/tasks/text_classification

/home/aistudio/ERNIE/applications/tasks/text_classification

# Fetch the ERNIE 3.0 base (Chinese) pretrained model with the repo's download script
# (presumably this populates models_hub/ernie_3.0_base_ch_dir, which the configs reference),
# then return to the text-classification task directory
%cd ../../models_hub
!bash ./download_ernie_3.0_base_ch.sh
%cd ../tasks/text_classification

In [None]:
## Config file: ./examples/cls_ernie_fc_ch.json
# Reference copy of the training config; the cell below copies ~/cls_ernie_fc_ch.json to this path.
# (Not meant to be executed as Python: true/false/null are JSON literals, not Python names.)
{
  "dataset_reader": {
    "train_reader": {
      "name": "train_reader",
      "type": "BasicDataSetReader",
      "fields": [
        {
          "name": "text_a",
          "data_type": "string",
          "reader": {
            "type": "ErnieTextFieldReader"
          },
          "tokenizer": {
            "type": "FullTokenizer",
            "split_char": " ",
            "unk_token": "[UNK]"
          },
          "need_convert": true,
          "vocab_path": "../../models_hub/ernie_3.0_base_ch_dir/vocab.txt",
          # NOTE(review): texts are at most 48 characters; if the reader pads to max_seq_len, a much smaller value would do
          "max_seq_len": 512,
          "truncation_type": 0,
          "padding_id": 0
        },
        {
          "name": "label",
          "data_type": "int",
          "reader": {
            "type": "ScalarFieldReader"
          },
          "tokenizer": null,
          "need_convert": false,
          "vocab_path": "",
          "max_seq_len": 1,
          "truncation_type": 0,
          "padding_id": 0,
          "embedding": null
        }
      ],
      "config": {
        # Directory holding train.txt written by the preprocessing cell (text \t label id)
        "data_path": "/home/aistudio/mydata/train",
        "shuffle": false,
        "batch_size": 32,
        "epoch": 5,
        "sampling_rate": 1.0,
        "need_data_distribute": true,
        "need_generate_examples": false
      }
    },
    "test_reader": {
      "name": "test_reader",
      "type": "BasicDataSetReader",
      "fields": [
        {
          "name": "text_a",
          "data_type": "string",
          "reader": {
            "type": "ErnieTextFieldReader"
          },
          "tokenizer": {
            "type": "FullTokenizer",
            "split_char": " ",
            "unk_token": "[UNK]"
          },
          "need_convert": true,
          "vocab_path": "../../models_hub/ernie_3.0_base_ch_dir/vocab.txt",
          "max_seq_len": 512,
          "truncation_type": 0,
          "padding_id": 0
        },
        {
          "name": "label",
          "data_type": "int",
          "need_convert": false,
          "reader": {
            "type": "ScalarFieldReader"
          },
          "tokenizer": null,
          "vocab_path": "",
          "max_seq_len": 1,
          "truncation_type": 0,
          "padding_id": 0,
          "embedding": null
        }
      ],
      "config": {
        # Directory holding test.txt, i.e. the official dev split with label ids
        "data_path": "/home/aistudio/mydata/test",
        "shuffle": false,
        "batch_size": 32,
        "epoch": 1,
        "sampling_rate": 1.0,
        "need_data_distribute": false,
        "need_generate_examples": false
      }
    }
  },
  "model": {
    "type": "ErnieClassification",
    "is_dygraph": 1,
    "optimization": {
      "learning_rate": 5e-05,
      "use_lr_decay": true,
      "warmup_steps": 0,
      "warmup_proportion": 0.1,
      "weight_decay": 0.01,
      "use_dynamic_loss_scaling": false,
      "init_loss_scaling": 128,
      "incr_every_n_steps": 100,
      "decr_every_n_nan_or_inf": 2,
      "incr_ratio": 2.0,
      "decr_ratio": 0.8
    },
    "embedding": {
      "config_path": "../../models_hub/ernie_3.0_base_ch_dir/ernie_config.json"
    },
    # Must equal the number of classes in label_list (14)
    "num_labels": 14
  },
  "trainer": {
    "type": "CustomDynamicTrainer",
    "PADDLE_PLACE_TYPE": "gpu",
    "PADDLE_IS_FLEET": 1,
    "train_log_step": 10,
    "use_amp": true,
    "is_eval_dev": 0,
    "is_eval_test": 1,
    "eval_step": 100,
    "save_model_step": 200,
    "load_parameters": "",
    "load_checkpoint": "",
    "pre_train_model": [
      {
        "name": "ernie_3.0_base_ch",
        "params_path": "../../models_hub/ernie_3.0_base_ch_dir/params"
      }
    ],
    # Checkpoints and inference models are written under this directory
    "output_path": "./output/cls_ernie_3.0_base_fc_ch_dy",
    "extra_param": {
      "meta":{
        "job_type": "text_classification"
      }

    }
  }
}

In [None]:
# Install the training config (presumably the one shown above) into ERNIE's examples directory
!cp ~/cls_ernie_fc_ch.json ./examples/cls_ernie_fc_ch.json

In [None]:
# Fine-tune ERNIE with the installed config; inference models are saved under the config's output_path
!python run_trainer.py --param_path ./examples/cls_ernie_fc_ch.json

INFO: 07-20 09:34:05: custom_dynamic_trainer.py:96 * 139931327883008 epoch 1 progress 234497/752471 current learning rate: 3.69e-05
INFO: 07-20 09:34:05: base_cls.py:88 * 139931327883008 phase = training loss = 0.1123419776558876 acc = 0.9678 precision = 0.9711 step = 960 time_cost = 6.1331
INFO: 07-20 09:34:12: custom_dynamic_trainer.py:96 * 139931327883008 epoch 1 progress 244737/752471 current learning rate: 3.68e-05
INFO: 07-20 09:34:12: base_cls.py:88 * 139931327883008 phase = training loss = 0.13989800214767456 acc = 0.9629 precision = 0.9569 step = 970 time_cost = 6.2258
INFO: 07-20 09:34:18: custom_dynamic_trainer.py:96 * 139931327883008 epoch 1 progress 254977/752471 current learning rate: 3.67e-05
INFO: 07-20 09:34:18: base_cls.py:88 * 139931327883008 phase = training loss = 0.09835822135210037 acc = 0.9678 precision = 0.9618 step = 980 time_cost = 6.3737
INFO: 07-20 09:34:24: custom_dynamic_trainer.py:96 * 139931327883008 epoch 1 progress 265217/752471 current learning rate: 3.65e-05
INFO: 07-20 09:34:24: base_cls.py:88 * 139931327883008 phase = training loss = 0.12024983763694763 acc = 0.9678 precision = 0.9678 step = 990 time_cost = 6.1577
INFO: 07-20 09:34:31: custom_dynamic_trainer.py:96 * 139931327883008 epoch 1 progress 275457/752471 current learning rate: 3.64e-05
INFO: 07-20 09:34:31: base_cls.py:88 * 139931327883008 phase = training loss = 0.08035919815301895 acc = 0.9688 precision = 0.9679 step = 1000 time_cost = 6.3367
INFO: 07-20 09:35:19: base_cls.py:93 * 139931327883008 phase = test acc = 0.9702 precision = 0.969 time_cost = 48.1985 step = 79
INFO: 07-20 09:35:19: custom_dynamic_trainer.py:157 * 139931327883008 eval step = 79
INFO: 07-20 09:35:21: dynamic_trainer.py:170 * 139931327883008 save path: ./output/cls_ernie_3.0_base_fc_ch_dy/save_inference_model/inference_step_1000

In [None]:
# Reference copy of the inference config (~/cls_ernie_fc_ch_infer.json, copied into ./examples before run_infer.py)
{
  "dataset_reader": {
    "predict_reader": {
      "name": "predict_reader",
      "type": "BasicDataSetReader",
      "fields": [
        {
          "name": "text_a",
          "data_type": "string",
          "reader": {
            "type": "ErnieTextFieldReader"
          },
          "tokenizer": {
            "type": "FullTokenizer",
            "split_char": " ",
            "unk_token": "[UNK]",
            "params": null
          },
          "need_convert": true,
          "vocab_path": "../../models_hub/ernie_3.0_base_ch_dir/vocab.txt",
          "max_seq_len": 512,
          "truncation_type": 0,
          "padding_id": 0,
          "embedding": null
        }
      ],
      "config": {
      # Directory with the data to predict: the prediction-input cell copies the
      # raw test split to /home/aistudio/mydata/predict_data/infer.txt
        "data_path": "/home/aistudio/mydata/predict_data",
        "shuffle": false,
        "batch_size": 8,
        "epoch": 1,
        "sampling_rate": 1.0,
        "need_data_distribute": false,
        "need_generate_examples": true
      }
    }
  },

  "inference": {
    "type": "CustomInference",
    # File the prediction results are written to
    "output_path": "./output/predict_result.txt",
    "PADDLE_PLACE_TYPE": "cpu",
    # Must equal the number of classes the model was trained on (num_labels = 14 in the training config)
    "num_labels": 14,
    "thread_num": 2,
    # NOTE(review): must name a directory that training actually saved; the training log shows
    # .../save_inference_model/inference_step_1000 — confirm inference_step_251 exists.
    "inference_model_path": "./output/cls_ernie_3.0_base_fc_ch_dy/save_inference_model/inference_step_251/",
    "extra_param": {
      "meta":{
        "job_type": "text_classification"
      }

    }
  }
}

In [None]:
# Create the prediction input directory and copy the raw test split into it.
# The directory must match the cp target below; -p also makes re-runs succeed.
!mkdir -p /home/aistudio/mydata/predict_data/
!cp /home/aistudio/data/data103654/test.txt /home/aistudio/mydata/predict_data/infer.txt

In [None]:
# Return to the text-classification task directory and copy run_infer.py from the
# home directory over the repo's copy (presumably a customized version)
%cd ~/ERNIE/applications/tasks/text_classification
!cp  ~/run_infer.py run_infer.py 

/home/aistudio/ERNIE/applications/tasks/text_classification

In [None]:
# Install the inference config, then run prediction with it
! cp ~/cls_ernie_fc_ch_infer.json ./examples/cls_ernie_fc_ch_infer.json
! python ./run_infer.py  --param_path ./examples/cls_ernie_fc_ch_infer.json

In [None]:
# Inspect the first lines of the raw prediction output
!head ./output/predict_result.txt

In [None]:
import json

import pandas as pd

# Load the inference output: column 0 is the input text, column 1 a string-encoded
# list of per-class probabilities.
predict = pd.read_csv('./output/predict_result.txt', header=None, sep='\t')

# Decode each probability list and map its arg-max index back to a label name.
# Assumes the list is JSON-compatible and that index i is class id i from
# id_label_dict (the training label ids) — TODO confirm against run_infer.py.
probabilities = predict[1].map(json.loads)
predict_labeled = predict.assign(
    pred_label=[id_label_dict[max(range(len(p)), key=p.__getitem__)] for p in probabilities]
)
predict_labeled.head()

                         0                                                  1
0  北京君太百货璀璨秋色 满100省353020元  [0.053806886076927185, 0.003187569323927164, 0...
1        教育部：小学高年级将开始学习性知识  [0.003998706582933664, 0.0019221440888941288, ...
2    专业级单反相机 佳能7D单机售价9280元  [0.9984349608421326, 5.53653335373383e-05, 0.0...
3      星展银行起诉内地客户 银行强硬客户无奈  [0.003201978513970971, 0.00043116536107845604,...
4   脱离中国的实际 强压人民币大幅升值只能是梦想  [0.003938914742320776, 0.003025110810995102, 0...