In [None]:
import warnings
warnings.filterwarnings("ignore")

import os
import jieba
import torch
import pickle
import torch.nn as nn
import torch.optim as optim
import pandas as pd

from ark_nlp.model.tm.bert import Bert
from ark_nlp.model.tm.bert import BertConfig
from ark_nlp.model.tm.bert import Dataset
from ark_nlp.model.tm.bert import Task
from ark_nlp.model.tm.bert import get_default_model_optimizer
from ark_nlp.model.tm.bert import Tokenizer

### 一、数据读入与处理

#### 1. 数据读入

In [None]:
def _load_qtr_split(json_path):
    """Read one KUAKE-QTR split and return it in text-pair layout.

    The ``query`` and ``title`` columns are renamed to ``text_a`` and
    ``text_b``; only those two columns plus ``label`` are kept, in that order.
    """
    split_df = pd.read_json(json_path)
    split_df = split_df.rename(columns={'query': 'text_a', 'title': 'text_b'})
    return split_df.loc[:, ['text_a', 'text_b', 'label']]


train_data_df = _load_qtr_split('../data/source_datasets/KUAKE-QTR/KUAKE-QTR_train.json')
dev_data_df = _load_qtr_split('../data/source_datasets/KUAKE-QTR/KUAKE-QTR_dev.json')

In [None]:
# Wrap each DataFrame in an ark_nlp Dataset (training split first, then dev).
tm_train_dataset, tm_dev_dataset = [
    Dataset(split_df) for split_df in (train_data_df, dev_data_df)
]

#### 2. 词典创建和生成分词器

In [None]:
# Build the tokenizer from the 'hfl/chinese-macbert-base' vocabulary.
# NOTE(review): max_seq_len=50 presumably caps the encoded length of each
# text pair — confirm against ark_nlp's Tokenizer.
tokenizer = Tokenizer(vocab='hfl/chinese-macbert-base', max_seq_len=50)

#### 3. ID化

In [None]:
# Run convert_to_ids on both splits with the shared tokenizer (train first).
for _tm_dataset in (tm_train_dataset, tm_dev_dataset):
    _tm_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [None]:
# The number of output labels is the size of the training set's cat2id mapping.
label_count = len(tm_train_dataset.cat2id)
config = BertConfig.from_pretrained(
    'hfl/chinese-macbert-base',
    num_labels=label_count,
)

#### 2. 模型创建

In [None]:
# Ask PyTorch to release its cached CUDA memory before the model is loaded.
torch.cuda.empty_cache()

In [None]:
# Instantiate the model from the 'hfl/chinese-macbert-base' checkpoint,
# using the config object (and its num_labels) prepared earlier.
dl_module = Bert.from_pretrained(
    'hfl/chinese-macbert-base',
    config=config,
)

In [None]:
# Override the pooling attribute on the already-constructed module.
# NOTE(review): 'last_avg' presumably selects mean pooling over the last
# hidden layer — confirm against ark_nlp's Bert implementation.
dl_module.pooling = 'last_avg'

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [None]:
# 设置运行次数
# Training hyper-parameters: number of epochs and batch size handed to model.fit.
num_epoches = 3
batch_size = 16

In [None]:
# Create the optimizer for dl_module via ark_nlp's get_default_model_optimizer helper.
optimizer = get_default_model_optimizer(dl_module)

#### 2. 任务创建

In [None]:
# Bundle the module and optimizer into a Task.
# NOTE(review): 'ce' presumably selects cross-entropy loss and cuda_device=0
# the first GPU — confirm against ark_nlp's Task signature.
model = Task(dl_module, optimizer, 'ce', cuda_device=0)

#### 3. 训练

In [None]:
# Train on the training split; the dev split is passed as the second
# positional argument. Keyword options are collected first for readability.
fit_options = {
    'lr': 3e-5,
    'epochs': num_epoches,
    'batch_size': batch_size,
}
model.fit(tm_train_dataset, tm_dev_dataset, **fit_options)

<br>

### 四、模型预测与结果保存

In [None]:
import json
from ark_nlp.model.tm.bert import Predictor

In [None]:
# Build a predictor from the task's module (model.module), the tokenizer and
# the training set's cat2id label mapping.
tm_predictor_instance = Predictor(model.module, tokenizer, tm_train_dataset.cat2id)

In [None]:
test_df = pd.read_json('../data/source_datasets/KUAKE-QTR/KUAKE-QTR_test.json')

# One submission record per test row: the original id/query/title plus the
# first element returned by predict_one_sample for the (query, title) pair.
submit = [
    {
        'id': row_id,
        'query': query_text,
        'title': title_text,
        'label': tm_predictor_instance.predict_one_sample([query_text, title_text])[0],
    }
    for row_id, query_text, title_text in zip(
        test_df['id'], test_df['query'], test_df['title']
    )
]

In [None]:
output_path = '../data/output_datasets/KUAKE-QTR_test.json'

# Create the output directory up front so that a missing folder does not make
# the write fail after training and prediction have already finished.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# ensure_ascii=False writes non-ASCII text as-is rather than as \uXXXX
# escapes; the file itself is encoded as UTF-8.
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(submit, f, ensure_ascii=False)