## Define new problem type and data reading function

We'll use the IMDB movie-review dataset as an example.

In [None]:
!pip install tensorflow-gpu
!pip install tensorflow-addons==0.11.2
!pip install bert-multitask-learning==0.5.7b8
!pip install transformers==3.5.1

In [2]:
import tensorflow as tf
from tensorflow import keras

In [3]:
from bert_multitask_learning import (train_bert_multitask, 
                                     eval_bert_multitask, DynamicBatchSizeParams, TRAIN, EVAL, PREDICT, preprocessing_fn, get_or_make_label_encoder)
import pickle

In [4]:
# Problem name -> problem type; passed to train_bert_multitask as
# problem_type_dict.
new_problem_type = dict(imdb_cls='cls')

@preprocessing_fn
def imdb_cls(params, mode):
    """Read the Keras IMDB dataset for the 'imdb_cls' problem.

    Args:
        params: Params object forwarded to get_or_make_label_encoder.
        mode: Run mode. The training split is returned when it equals TRAIN;
            the test split is returned for any other value.

    Returns:
        Tuple (input_list, target_list). input_list is a list of reviews, each
        a list of word strings; target_list holds the labels of the same split.
    """
    # Offset applied by load_data to every word index; ids 0-2 are reserved for
    # the special tokens defined below.
    index_from = 3
    (train_data, train_labels), (test_data, test_labels) = keras.datasets.imdb.load_data(
        num_words=10000, index_from=index_from)

    # The labels are NumPy arrays, so `train_labels + test_labels` would add
    # them element-wise (producing values 0/1/2) instead of concatenating.
    # Convert to lists so `+` joins the two splits.
    all_labels = list(train_labels) + list(test_labels)
    # Called for its side effect of getting/creating the label encoder for this
    # problem; the returned encoder is not needed here.
    get_or_make_label_encoder(params, 'imdb_cls', mode, all_labels)

    # Build the id -> word lookup, shifting raw indices to match load_data.
    word_to_id = {
        word: idx + index_from
        for word, idx in keras.datasets.imdb.get_word_index().items()
    }
    word_to_id["<PAD>"] = 0
    word_to_id["<START>"] = 1
    word_to_id["<UNK>"] = 2
    id_to_word = {idx: word for word, idx in word_to_id.items()}

    if mode == TRAIN:
        sentences, target_list = train_data, train_labels
    else:
        sentences, target_list = test_data, test_labels

    # Decode only the split that is actually returned.
    input_list = [[id_to_word[i] for i in sentence] for sentence in sentences]

    return input_list, target_list
# Problem name -> data-reading function; passed to train_bert_multitask as
# processing_fn_dict.
new_problem_process_fn_dict = dict(imdb_cls=imdb_cls)
    

## Train Model

Please make sure you're using the correct pretrained checkpoint to initialize the model.

In [None]:
# The IMDB reviews are English text, so initialize from an English BERT
# checkpoint (the model and tokenizer must use the same checkpoint name).
pretrained_name = 'bert-base-uncased'

params = DynamicBatchSizeParams()
params.transformer_config_loading = 'BertConfig'
params.transformer_model_name = pretrained_name
params.transformer_tokenizer_name = pretrained_name
params.transformer_tokenizer_loading = 'BertTokenizer'

# Train the custom 'imdb_cls' problem for one epoch on a single GPU, using the
# problem type and data-reading function registered above.
train_bert_multitask(
    problem='imdb_cls',
    num_gpus=1,
    num_epochs=1,
    params=params,
    problem_type_dict=new_problem_type,
    processing_fn_dict=new_problem_process_fn_dict,
)