In [None]:
from bert_multitask_learning import train_bert_multitask, train_eval_input_fn, BertMultiTask, params
# Stable alias for the params module: the bare name `params` is easy to shadow
# with a params *instance* in a later cell.
from bert_multitask_learning import params as params_module
from bert_multitask_learning.predefined_problems import get_weibo_ner_fn, get_weibo_cws_fn

In [None]:
# Both tasks read the same Weibo NER files (the CWS processing fn is pointed
# at the NER glob as well), so keep the pattern in one place.
WEIBO_NER_DATA_GLOB = '../data/ner/weiboNER*'

# Problem name -> problem type; both tasks are sequence tagging.
problem_type_dict = {
    'weibo_cws': 'seq_tag',
    'weibo_ner': 'seq_tag',
}

# Problem name -> data processing fn used to read that problem's data.
processing_fn_dict = {
    'weibo_ner': get_weibo_ner_fn(file_path=WEIBO_NER_DATA_GLOB),
    'weibo_cws': get_weibo_cws_fn(file_path=WEIBO_NER_DATA_GLOB),
}

## Train Models
If you don't need fine-grained control, you can simply call the `train_bert_multitask` function.

In [None]:
# High-level API: configure a params object and hand everything to
# `train_bert_multitask`. A dedicated name is used so the imported `params`
# module is not rebound to an instance (that broke later cells).
train_params = params_module.DynamicBatchSizeParams()
# Shrink BERT to a single hidden layer (presumably to keep the demo fast).
train_params.bert_num_hidden_layer = 1
train_bert_multitask(
    problem='weibo_ner&weibo_cws',
    num_gpus=1,
    num_epochs=1,
    # Without passing the params, the layer override above is never used.
    params=train_params,
    problem_type_dict=problem_type_dict,
    processing_fn_dict=processing_fn_dict,
)

If you want more control over the training process, you can use the lower-level API instead: build the params, model function and `Estimator` yourself.

In [None]:
import tensorflow as tf
from tensorflow.estimator import Estimator
from bert_multitask_learning.ckpt_restore_hook import RestoreCheckpointHook

problem = 'weibo_ner&weibo_cws'
num_gpus = 1

# Every setting below must be applied to `bert_multitask_params` itself --
# this is the object handed to the model and Estimator in the next cell.
# Create it from the module alias: the bare name `params` may have been
# rebound to a params instance by an earlier cell.
bert_multitask_params = params_module.DynamicBatchSizeParams()
bert_multitask_params.bert_num_hidden_layer = 1

# Register the custom problems on the same params object.
for new_problem, new_problem_processing_fn in processing_fn_dict.items():
    print('Adding new problem {0}, problem type: {1}'.format(
        new_problem, problem_type_dict[new_problem]))
    bert_multitask_params.add_problem(
        problem_name=new_problem,
        problem_type=problem_type_dict[new_problem],
        processing_fn=new_problem_processing_fn)

# assign problem to params; train_epoch is set before assign_problem as in the
# original flow (NOTE(review): presumably assign_problem derives train_steps
# from it -- confirm).
bert_multitask_params.train_epoch = 1
# Use the configured GPU count rather than a hard-coded 1.
bert_multitask_params.assign_problem(problem, gpu=int(num_gpus))

In [None]:
# Build the model fn, a mirrored distribution strategy and the Estimator
# for the lower-level training path.
model = BertMultiTask(params=bert_multitask_params)
model_fn = model.get_model_fn()

# NCCL all-reduce across the mirrored devices, one pack per GPU.
gpu_count = int(num_gpus)
all_reduce_ops = tf.contrib.distribute.AllReduceCrossDeviceOps(
    'nccl', num_packs=gpu_count)
dist_strategy = tf.contrib.distribute.MirroredStrategy(
    num_gpus=gpu_count, cross_tower_ops=all_reduce_ops)

# The same strategy is used for both training and evaluation.
run_config = tf.estimator.RunConfig(
    train_distribute=dist_strategy,
    eval_distribute=dist_strategy,
    log_step_count_steps=bert_multitask_params.log_every_n_steps)

# create estimator; checkpoints go to the params' ckpt_dir
estimator = Estimator(
    model_fn=model_fn,
    model_dir=bert_multitask_params.ckpt_dir,
    params=bert_multitask_params,
    config=run_config)

# pretrained bert restore hook, passed to estimator.train below
train_hook = RestoreCheckpointHook(bert_multitask_params)

In [None]:
# train
def train_input_fn():
    """Input fn for `Estimator.train`; delegates to `train_eval_input_fn`."""
    return train_eval_input_fn(bert_multitask_params)


estimator.train(
    input_fn=train_input_fn,
    max_steps=bert_multitask_params.train_steps,
    hooks=[train_hook],
)


## Evaluate and Predict

NER and CWS need different evaluation logic: below, CWS is evaluated with `eval_scheme='acc'` and NER with `eval_scheme='ner'`.

In [None]:
from bert_multitask_learning import eval_bert_multitask, predict_bert_multitask

In [None]:
# Evaluate the CWS head of the jointly trained checkpoint with the 'acc' scheme.
eval_bert_multitask(
    model_dir='models/weibo_cws_weibo_ner_ckpt/',
    problem='weibo_cws',
    eval_scheme='acc')

In [None]:
# Evaluate the NER head of the jointly trained checkpoint with the 'ner' scheme.
eval_bert_multitask(
    model_dir='models/weibo_cws_weibo_ner_ckpt/',
    problem='weibo_ner',
    eval_scheme='ner')

In [None]:
# predict
import numpy as np
from bert_multitask_learning.utils import get_or_make_label_encoder

# Create the params from the module alias: the bare name `params` may have
# been rebound to a params instance by an earlier cell.
predict_params = params_module.DynamicBatchSizeParams()
# Match the `bert_num_hidden_layer = 1` used by the training cells above so
# the prediction graph lines up with the trained checkpoint.
predict_params.bert_num_hidden_layer = 1

# get prediction generator
pred_prob = predict_bert_multitask(
    inputs=['中国和美国在打贸易战'], problem='weibo_cws&weibo_ner', params=predict_params)
# get label encoders (predict mode)
ner_label_encoder = get_or_make_label_encoder(params=predict_params, problem='weibo_ner', mode='predict')
cws_label_encoder = get_or_make_label_encoder(params=predict_params, problem='weibo_cws', mode='predict')

# Decode the highest-scoring label at each position for both tasks
# (the CWS encoder was previously built but never used).
for prob in pred_prob:
    ner_pred = np.argmax(prob['weibo_ner'], axis=-1)
    cws_pred = np.argmax(prob['weibo_cws'], axis=-1)
    print('NER:', ner_label_encoder.inverse_transform(ner_pred.tolist()))
    print('CWS:', cws_label_encoder.inverse_transform(cws_pred.tolist()))

## Export Model for Serving

You can export the trained model for [serving](https://github.com/JayYip/bert-as-service).

In [None]:
from bert_multitask_learning import export_model

In [None]:
# Export for serving using the params object configured in the lower-level
# training cells (`bert_multitask_params`, with its problem already assigned).
export_model(bert_multitask_params)