In [None]:
# default_exp params
from nbdev.showdoc import show_doc

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# hide
import os
# Set before TensorFlow is imported by a later cell.
# NOTE(review): "-1" is presumably meant to hide every GPU so the notebook runs on CPU - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Params

`BaseParams` is the main object that controls the whole modeling process. It is supposed to be accessible anywhere.

In [None]:
# export

import json
import os
import re
import shutil
import logging
from typing import Callable, List, Tuple, Dict, Union
from collections import defaultdict
from distutils.dir_util import copy_tree
import tensorflow as tf

from bert_multitask_learning.utils import create_path, load_transformer_tokenizer, load_transformer_config
from bert_multitask_learning.special_tokens import BOS_TOKEN, EOS_TOKEN


class BaseParams():
    # pylint: disable=attribute-defined-outside-init
    def __init__(self):
        """Set every configurable attribute to its default value.

        Only plain attribute assignments happen here: no problem is
        registered, no directory is created and nothing is loaded.
        Attributes such as `ckpt_dir`, `params_path` or `train_steps` are
        not defined here; other methods of this class set them later.
        """
        self.run_problem_list = []

        # maps problem name -> problem type, filled by `add_problem`
        self.problem_type = {
        }

        # transformers params
        self.transformer_model_name = 'bert-base-chinese'
        self.transformer_tokenizer_name = 'bert-base-chinese'
        self.transformer_config_name = 'bert-base-chinese'
        self.transformer_model_loading = 'TFBertModel'
        self.transformer_config_loading = 'BertConfig'
        self.transformer_tokenizer_loading = 'BertTokenizer'
        # decoder names stay None unless a separate decoder is wanted
        self.transformer_decoder_model_name = None
        self.transformer_decoder_config_name = None
        self.transformer_decoder_tokenizer_name = None
        # self.transformer_decoder_model_name = "hfl/chinese-xlnet-base"
        # self.transformer_decoder_config_name = "hfl/chinese-xlnet-base"
        # self.transformer_decoder_tokenizer_name = "hfl/chinese-xlnet-base"
        self.transformer_decoder_model_loading = 'TFAutoModel'
        self.transformer_decoder_config_loading = 'AutoConfig'
        self.transformer_decoder_tokenizer_loading = 'AutoTokenizer'

        # multimodal params
        self.modal_segment_id = {
            'text': 0,
            'image': 0,
            'others': 0
        }
        self.modal_type_id = {
            'text': 0,
            'image': 1,
            'others': 2
        }
        self.enable_modal_type = False
        # bert config
        self.init_checkpoint = ''

        # specifying this will make the key reuse the value's top,
        # that is, a weibo_ner problem can use NER's top
        self.share_top = {
        }
        # NOTE(review): self.problem_type was set to {} a few lines above,
        # so the body of this loop never runs here.
        for p in self.problem_type:
            if p not in self.share_top:
                self.share_top[p] = p

        self.multitask_balance_type = 'data_balanced'
        # `add_problem` validates against problem_type_list;
        # `register_problem_type` appends to it.
        self.problem_type_list = ['cls', 'seq_tag', 'seq2seq_tag',
                                  'seq2seq_text', 'multi_cls', 'pretrain', 'masklm']
        self.predefined_problem_type = ['cls', 'seq_tag', 'seq2seq_tag',
                                        'seq2seq_text', 'multi_cls', 'pretrain', 'masklm']
        # the three dicts below are keyed by problem type and filled by
        # `register_problem_type`
        self.get_or_make_label_encoder_fn_dict: Dict[str, Callable] = {}
        self.label_handling_fn: Dict[str, Callable] = {}
        self.top_layer = {}
        self.num_classes = {}
        # self.multitask_balance_type = 'problem_balanced'

        # logging control
        self.log_every_n_steps = 100
        self.detail_log = True

        self.multiprocess = True
        self.num_cpus = 4
        self.per_cpu_buffer = 3000
        self.decode_vocab_file = None
        self.eval_throttle_secs = 600

        # training
        self.init_lr = 2e-5
        self.batch_size = 32
        self.train_epoch = 15
        self.freeze_step = 0
        self.prefetch = 5000
        self.dynamic_padding = True
        self.bucket_batch_sizes = [32, 32, 32, 16]
        self.bucket_boundaries = [30, 64, 128]
        self.shuffle_buffer = 200000

        # hyper-parameters
        self.dropout_keep_prob = 0.9
        self.max_seq_len = 256
        self.use_one_hot_embeddings = True
        self.label_smoothing = 0.0
        self.crf = False
        self.bert_num_hidden_layer = 12
        self.hidden_dense = False
        # threshold to calculate metrics for multi_cls
        self.multi_cls_threshold = 0.5
        self.multi_cls_positive_weight = 1.0
        self.custom_pooled_hidden_size = 0
        self.share_embedding = True

        # seq2seq
        self.decoder_num_hidden_layers = 3
        self.beam_size = 10
        self.init_decoder_from_encoder = False
        self.beam_search_alpha = 0.6
        self.decode_max_seq_len = 90

        # experimental multitask approach
        self.label_transfer = False
        # train mask lm and downstream task at the same time
        self.augument_mask_lm = False
        self.augument_rate = 0.5
        # NOT implemented
        self.distillation = False
        # Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics
        # ref: https://arxiv.org/abs/1705.07115
        self.uncertain_weight_loss = False
        # deprecated since it did not work well
        # self.mutual_prediction = False

        # add an extra attention for each task
        #   with BERT layers as encoder output, task logits as decoder inputs
        self.grid_transformer = False

        # add an extra attention for each task
        #   with other tasks' logits as encoder output, task logits as decoder inputs
        self.task_transformer = False

        # do a mean for gradients of BERT layers instead of sum
        self.mean_gradients = False

        # randomly replace punctuation with some probability to
        # ease the punctuation sensitivity problem
        self.punc_replace_prob = 0.0
        self.punc_list = list(',.!?！。？，、')
        self.hidden_gru = False
        self.label_transfer_gru = False
        # if None, we will use the same hidden_size as inputs
        # e.g. # of labels
        self.label_transfer_gru_hidden_size = None

        # pretrain hyper-parameters
        self.dupe_factor = 10
        self.short_seq_prob = 0.1
        self.masked_lm_prob = 0.15
        self.max_predictions_per_seq = 20
        self.mask_lm_hidden_size = 768
        self.mask_lm_hidden_act = 'gelu'
        self.mask_lm_initializer_range = 0.02

        self.train_problem = None
        self.tmp_file_dir = 'tmp'
        self.cache_dir = 'models/transformers_cache'
        # maps problem name -> data reading function, filled by `add_problem`
        self.read_data_fn = {}
        # flipped to True by `assign_problem`
        self.problem_assigned = False

    def add_problem(self, problem_name: str, problem_type='cls', processing_fn: Callable = None):

        if problem_type not in self.problem_type_list:
            raise ValueError('Provided problem type not valid, expect {0}, got {1}'.format(
                self.problem_type_list,
                problem_type))

        self.problem_type[problem_name] = problem_type
        self.read_data_fn[problem_name] = processing_fn

    def add_multiple_problems(self, problem_type_dict: Dict[str, str], processing_fn_dict: Dict[str, Callable] = None):
        # add new problem to params if problem_type_dict and processing_fn_dict provided
        for new_problem, problem_type in problem_type_dict.items():
            print('Adding new problem {0}, problem type: {1}'.format(
                new_problem, problem_type_dict[new_problem]))
            if processing_fn_dict:
                new_problem_processing_fn = processing_fn_dict[new_problem]
            else:
                new_problem_processing_fn = None
            self.add_problem(
                problem_name=new_problem, problem_type=problem_type, processing_fn=new_problem_processing_fn)

    def assign_problem(self,
                       flag_string: str,
                       gpu=2,
                       base_dir: str = None,
                       dir_name: str = None,
                       predicting=False):
        self.assigned_details = (
            flag_string, gpu, base_dir, dir_name, predicting)
        self.problem_assigned = True
        self.predicting = predicting

        self.problem_list, self.problem_chunk = self.parse_problem_string(
            flag_string)

        # create dir and get vocab, config
        self.prepare_dir(base_dir, dir_name, self.problem_list)

        self.get_data_info(self.problem_list, self.ckpt_dir)

        self.set_data_sampling_strategy()

        if not predicting:
            for problem in self.problem_list:
                if self.problem_type[problem] == 'pretrain':
                    dup_fac = self.dupe_factor
                    break
                else:
                    dup_fac = 1
            self.train_steps = int((
                self.data_num * self.train_epoch * dup_fac) / (self.batch_size*max(1, gpu)))
            self.train_steps_per_epoch = int(
                self.train_steps / self.train_epoch)
            self.num_warmup_steps = int(0.1 * self.train_steps)

            # linear scale learing rate
            self.lr = self.init_lr * gpu

    def to_json(self):
        """Save the params as json files. Please note that processing_fn is not saved.
        """
        dump_dict = {}
        for att_name, att in vars(self).items():
            try:
                json.dumps(att)
                dump_dict[att_name] = att
            except TypeError:
                pass

        with open(self.params_path, 'w', encoding='utf8') as f:
            json.dump(dump_dict, f)

    def from_json(self, json_path: str = None):
        """Restore attributes from a params JSON file.

        Every key in the file is set as an attribute of the same name,
        overwriting the current value. `bert_config` (and
        `bert_decoder_config` when `bert_decoder_config_dict` is present) is
        then rebuilt with `load_transformer_config`. If a problem had been
        assigned before the call, `assign_problem` is run again with the
        arguments remembered from that earlier call.

        Args:
            json_path (str, optional): path to the JSON file. When None,
                `self.params_path` is used. Defaults to None.

        Raises:
            AttributeError: `json_path` is None and `self.params_path` has
                not been set.
        """
        if json_path is not None:
            params_path = json_path
        else:
            try:
                params_path = self.params_path
            except AttributeError:
                raise AttributeError(
                    'Either json_path should not be None or problem is assigned.')

        # capture before loading, since the file may overwrite these attributes
        assign_details = self.assigned_details if self.problem_assigned else None

        with open(params_path, 'r', encoding='utf8') as f:
            loaded = json.load(f)
        for name, value in loaded.items():
            setattr(self, name, value)

        self.bert_config = load_transformer_config(
            self.bert_config_dict, self.transformer_config_loading)
        has_decoder_config = hasattr(self, 'bert_decoder_config_dict')
        if has_decoder_config:
            self.bert_decoder_config = load_transformer_config(
                self.bert_decoder_config_dict, self.transformer_decoder_config_loading
            )

        if assign_details:
            self.assign_problem(*assign_details)

    def get_data_info(self, problem_list: List[str], base: str):

        json_path = os.path.join(base, 'data_info.json')
        if os.path.exists(json_path):
            data_info = json.load(open(json_path, 'r', encoding='utf8'))
            self.data_num_dict = data_info['data_num']
            self.num_classes = data_info['num_classes']
        elif self.predicting:
            data_info = {
                'data_num': self.data_num_dict,
                'num_classes': self.num_classes,
            }
            return json.dump(data_info, open(json_path, 'w', encoding='utf8'))
        else:
            if not hasattr(self, 'data_num_dict'):
                self.data_num_dict = {}
            if not hasattr(self, 'num_classes'):
                self.num_classes = {}

        if not self.predicting:
            # update data_num and train_steps
            self.data_num = 0
            for problem in problem_list:
                if problem not in self.data_num_dict:

                    self.data_num_dict[problem], _ = self.read_data_fn[problem](
                        self, 'train', get_data_num=True)
                    self.data_num += self.data_num_dict[problem]
                else:
                    self.data_num += self.data_num_dict[problem]

            data_info = {
                'data_num': self.data_num_dict,
                'num_classes': self.num_classes,
            }

            json.dump(data_info, open(json_path, 'w', encoding='utf8'))
        return json_path

    def parse_problem_string(self, flag_string: str) -> Tuple[List[str], List[List[str]]]:

        self.problem_str = flag_string
        # Parse problem string
        self.run_problem_list = []
        problem_chunk = []
        for flag_chunk in flag_string.split('|'):

            if '&' not in flag_chunk:
                problem_type = {}
                problem_type[flag_chunk] = self.problem_type[flag_chunk]
                self.run_problem_list.append(problem_type)
                problem_chunk.append([flag_chunk])
            else:
                problem_type = {}
                problem_chunk.append([])
                for problem in flag_chunk.split('&'):
                    problem_type[problem] = self.problem_type[problem]
                    problem_chunk[-1].append(problem)
                self.run_problem_list.append(problem_type)
        # if (self.label_transfer or self.mutual_prediction) and self.train_problem is None:
        if self.train_problem is None:
            self.train_problem = [p for p in self.run_problem_list]

        problem_list = sorted(re.split(r'[&|]', flag_string))
        return problem_list, problem_chunk

    def prepare_dir(self, base_dir: str, dir_name: str, problem_list: List[str]):
        """Prepare the checkpoint dir and load transformer configs and tokenizers.

        Sets `self.ckpt_dir` and `self.params_path`. `self.predicting` must be
        set before this is called.

        When not predicting, the directory is created and filled in one of
        two ways:

        1. `<init_checkpoint>/bert_config` exists: config and tokenizer
           (and decoder config/tokenizer, if present) are copied from
           `self.init_checkpoint`.
        2. Otherwise config and tokenizer are loaded by their
           `transformer_*_name` and saved into the checkpoint dir; the same
           happens for the decoder when `transformer_decoder_model_name` is set.

        When predicting, the configs are only loaded from the checkpoint dir.

        Afterwards the `transformer_*_name` attributes are pointed at the
        copies inside the checkpoint dir, and `bert_config_dict` and
        `vocab_size` are set. When a decoder tokenizer is configured,
        `decoder_vocab_size`, `bos_id` and `eos_id` are set as well.

        Args:
            base_dir (str): parent directory of `self.ckpt_dir`; 'models' is
                used when None.
            dir_name (str): last path component of `self.ckpt_dir`; when None
                the problem names joined by '_' plus '_ckpt' are used.
            problem_list (List[str]): problem names, used only to build the
                default `dir_name`.
        """
        base = base_dir if base_dir is not None else 'models'

        dir_name = dir_name if dir_name is not None else '_'.join(
            problem_list)+'_ckpt'
        self.ckpt_dir = os.path.join(base, dir_name)

        # we need to make sure all configs, tokenizers are in ckpt_dir
        # configs
        from_config_path = os.path.join(self.init_checkpoint,
                                        'bert_config')
        from_decoder_config_path = os.path.join(self.init_checkpoint,
                                                'bert_decoder_config')
        to_config_path = os.path.join(self.ckpt_dir, 'bert_config')
        to_decoder_config_path = os.path.join(
            self.ckpt_dir, 'bert_decoder_config')

        # tokenizers
        from_tokenizer_path = os.path.join(self.init_checkpoint, 'tokenizer')
        to_tokenizer_path = os.path.join(self.ckpt_dir, 'tokenizer')

        from_decoder_tokenizer_path = os.path.join(
            self.init_checkpoint, 'decoder_tokenizer')
        to_decoder_tokenizer_path = os.path.join(
            self.ckpt_dir, 'decoder_tokenizer')

        self.params_path = os.path.join(self.ckpt_dir, 'params.json')

        if not self.predicting:
            create_path(self.ckpt_dir)

            # two ways to init model
            # 1. init from TF checkpoint dir created by bert-multitask-learning.
            # 2. init from huggingface checkpoint.

            # bert config exists, init from existing config
            if os.path.exists(from_config_path):
                # copy config
                copy_tree(from_config_path, to_config_path)
                self.bert_config = load_transformer_config(
                    to_config_path, self.transformer_config_loading)

                # copy tokenizer
                copy_tree(from_tokenizer_path, to_tokenizer_path)

                # copy decoder config
                if os.path.exists(from_decoder_config_path):
                    copy_tree(from_decoder_config_path,
                              to_decoder_config_path)
                    self.bert_decoder_config = load_transformer_config(
                        from_decoder_config_path, self.transformer_decoder_config_loading
                    )
                    self.bert_decoder_config_dict = self.bert_decoder_config.to_dict()
                # copy decoder tokenizer
                if os.path.exists(from_decoder_tokenizer_path):
                    copy_tree(from_decoder_tokenizer_path,
                              to_decoder_tokenizer_path)

                self.init_weight_from_huggingface = False
            else:
                # load config from huggingface
                logging.warning(
                    '%s not exists. will load model from huggingface checkpoint.', from_config_path)
                # get or download config
                self.init_weight_from_huggingface = True
                self.bert_config = load_transformer_config(
                    self.transformer_config_name, self.transformer_config_loading)
                self.bert_config.save_pretrained(to_config_path)

                # save tokenizer
                tokenizer = load_transformer_tokenizer(
                    self.transformer_tokenizer_name, self.transformer_tokenizer_loading)
                tokenizer.save_pretrained(to_tokenizer_path)
                # save_pretrained method of tokenizer saves the config as tokenizer_config.json, which will cause
                # OSError if use tokenizer.from_pretrained directly. we need to manually rename the json file
                try:
                    os.rename(os.path.join(to_tokenizer_path, 'tokenizer_config.json'), os.path.join(
                        to_tokenizer_path, 'config.json'))
                except OSError as err:
                    # best effort: the rename is optional, but only swallow
                    # filesystem errors (not KeyboardInterrupt/SystemExit)
                    logging.debug(
                        'could not rename tokenizer_config.json in %s: %s', to_tokenizer_path, err)

                # if decoder is specified
                if self.transformer_decoder_model_name:
                    self.bert_decoder_config = load_transformer_config(
                        self.transformer_decoder_config_name, self.transformer_decoder_config_loading
                    )
                    self.bert_decoder_config_dict = self.bert_decoder_config.to_dict()
                    self.bert_decoder_config.save_pretrained(
                        to_decoder_config_path)
                    decoder_tokenizer = load_transformer_tokenizer(
                        self.transformer_decoder_tokenizer_name, self.transformer_decoder_tokenizer_loading)
                    decoder_tokenizer.save_pretrained(
                        to_decoder_tokenizer_path)
                    try:
                        os.rename(os.path.join(to_decoder_tokenizer_path, 'tokenizer_config.json'), os.path.join(
                            to_decoder_tokenizer_path, 'config.json'))
                    except OSError as err:
                        # same best-effort rename as for the encoder tokenizer
                        logging.debug(
                            'could not rename tokenizer_config.json in %s: %s', to_decoder_tokenizer_path, err)
        else:
            self.bert_config = load_transformer_config(to_config_path)
            if os.path.exists(to_decoder_config_path):
                self.bert_decoder_config = load_transformer_config(
                    to_decoder_config_path)
            self.init_weight_from_huggingface = False

        # from here on the names point at the copies inside ckpt_dir
        self.transformer_config_name = to_config_path
        # set value if and only if decoder is assigned
        self.transformer_decoder_config_name = to_decoder_config_path if self.transformer_decoder_config_name is not None else None
        self.transformer_tokenizer_name = to_tokenizer_path
        # set value if and only if decoder is assigned
        self.transformer_decoder_tokenizer_name = to_decoder_tokenizer_path if self.transformer_decoder_tokenizer_name is not None else None

        self.bert_config_dict = self.bert_config.to_dict()

        tokenizer = load_transformer_tokenizer(
            self.transformer_tokenizer_name, self.transformer_tokenizer_loading)
        self.vocab_size = tokenizer.vocab_size
        if self.transformer_decoder_tokenizer_name:
            decoder_tokenizer = load_transformer_tokenizer(
                self.transformer_decoder_tokenizer_name,
                self.transformer_decoder_tokenizer_loading
            )

            # add bos and eos tokens when the decoder tokenizer has none
            if decoder_tokenizer.bos_token is None:
                decoder_tokenizer.add_special_tokens({'bos_token': BOS_TOKEN})

            if decoder_tokenizer.eos_token is None:
                decoder_tokenizer.add_special_tokens({'eos_token': EOS_TOKEN})

            # overwrite tokenizer
            decoder_tokenizer.save_pretrained(to_decoder_tokenizer_path)

            self.decoder_vocab_size = decoder_tokenizer.vocab_size
            self.bos_id = decoder_tokenizer.bos_token_id
            self.eos_id = decoder_tokenizer.eos_token_id

    def get_problem_type(self, problem: str) -> str:
        return self.problem_type[problem]

    def update_train_steps(self, train_steps_per_epoch: int, epoch: int = None, warmup_ratio=0.1) -> None:
        """If the batch_size is dynamic, we have to loop through the tf.data.Dataset
        to get the accurate number of training steps. In this case, we need a function to
        update the train_steps which will be used to calculate learning rate schedule.

        WARNING: updating should be called before the model is compiled! 

        Args:
            train_steps (int): new number of train_steps
        """
        if epoch:
            train_steps = train_steps_per_epoch * epoch
        else:
            train_steps = train_steps_per_epoch * self.train_epoch

        logging.info('Updating train_steps from {0} to {1}'.format(
            self.train_steps, train_steps))

        self.train_steps = train_steps
        self.train_steps_per_epoch = train_steps_per_epoch
        self.num_warmup_steps = int(self.train_steps * warmup_ratio)

    def get_problem_chunk(self, as_str=True) -> Union[List[str], List[List[str]]]:

        if as_str:
            res_list = []
            for problem_list in self.problem_chunk:
                res_list.append('_'.join(sorted(problem_list)))
            return res_list
        else:
            return self.problem_chunk

    def set_data_sampling_strategy(self,
                                   sampling_strategy='data_balanced',
                                   sampling_strategy_fn: Callable = None) -> Dict[str, float]:
        if sampling_strategy_fn:
            logging.info(
                'sampling_strategy_fn is provided, sampling_strategy arg will be ignored.')
            raise NotImplementedError

        problem_chunk_data_num = defaultdict(float)
        if sampling_strategy == 'data_balanced':
            problem_chunk = self.get_problem_chunk(as_str=False)
            for problem_list in problem_chunk:
                str_per_chunk = '_'.join(sorted(problem_list))
                for problem in problem_list:
                    problem_chunk_data_num[str_per_chunk] += self.data_num_dict[problem]
        elif sampling_strategy == 'problem_balanced':
            problem_chunk = self.get_problem_chunk(as_str=True)
            for str_per_chunk in problem_chunk:
                problem_chunk_data_num[str_per_chunk] = 1
        else:
            raise ValueError(
                'sampling strategy {} is not implemented by default. '
                'please provide sampling_strategy_fn.'.format(sampling_strategy))

        # devided by sum to get sampling prob
        sum_across_problems = sum(
            [v for _, v in problem_chunk_data_num.items()])
        self.problem_sampling_weight_dict = {
            k: v / sum_across_problems for k, v in problem_chunk_data_num.items()}
        return self.problem_sampling_weight_dict

    def register_problem_type(self,
                              problem_type: str,
                              top_layer: tf.keras.Model,
                              label_handling_fn: Callable = None,
                              get_or_make_label_encoder_fn: Callable = None):
        self.problem_type_list.append(problem_type)
        self.get_or_make_label_encoder_fn_dict[problem_type] = get_or_make_label_encoder_fn
        self.top_layer[problem_type] = top_layer
        self.label_handling_fn[problem_type] = label_handling_fn


class CRFParams(BaseParams):
    """`BaseParams` preset with `crf` switched on."""

    def __init__(self):
        super().__init__()
        self.crf = True


class StaticBatchParams(BaseParams):
    """`BaseParams` preset with `dynamic_padding` switched off."""

    def __init__(self):
        super().__init__()
        self.dynamic_padding = False


class DynamicBatchSizeParams(BaseParams):
    """`BaseParams` preset whose `bucket_batch_sizes` are [128, 64, 32, 16]."""

    def __init__(self):
        super().__init__()
        self.bucket_batch_sizes = [128, 64, 32, 16]


In [None]:
# hide
# a freshly created params object has no problem assigned yet
params = BaseParams()
assert not params.problem_assigned

## Add Problems


In [None]:
# hide
# define a simple preprocessing function
import bert_multitask_learning
from bert_multitask_learning import preprocessing_fn
@preprocessing_fn
def toy_cls(params: BaseParams, mode: str):
    "Toy classification data: ten identical sentences, each labelled 'a'"
    is_train = mode == bert_multitask_learning.TRAIN
    sentence = 'this is a toy input' if is_train else 'this is a toy input for test'
    toy_input = [sentence] * 10
    toy_target = ['a'] * 10
    return toy_input, toy_target

@preprocessing_fn
def toy_seq_tag(params: BaseParams, mode: str):
    "Toy sequence-tagging data: ten identical token lists, each with one tag per token"
    if mode == bert_multitask_learning.TRAIN:
        sentence, tags = 'this is a toy input', ['a', 'b', 'c', 'd', 'e']
    else:
        sentence, tags = 'this is a toy input for test', ['a', 'b', 'c', 'd', 'e', 'e', 'e']
    # build a fresh list per example so the examples do not share one object
    toy_input = [sentence.split(' ') for _ in range(10)]
    toy_target = [list(tags) for _ in range(10)]
    return toy_input, toy_target

In [None]:
show_doc(BaseParams.add_problem)

<h4 id="BaseParams.add_problem" class="doc_header"><code>BaseParams.add_problem</code><a href="__main__.py#L168" class="source_link" style="float:right">[source]</a></h4>

> <code>BaseParams.add_problem</code>(**`problem_name`**:`str`, **`problem_type`**=*`'cls'`*, **`processing_fn`**:`Callable`=*`None`*)



Add problems.

Args:
- problem_name (str): problem name.
- problem_type (str, optional): One of the following problem types:
['cls', 'seq_tag', 'seq2seq_tag', 'seq2seq_text', 'multi_cls', 'pretrain'].
Defaults to 'cls'.
- processing_fn (Callable, optional): preprocessing function. Defaults to None.

Raises:
- ValueError: unexpected problem_type

In [None]:
# register the two toy problems one at a time
params.add_problem(problem_name='toy_cls', problem_type='cls', processing_fn=toy_cls)
params.add_problem(problem_name='toy_seq_tag', problem_type='seq_tag', processing_fn=toy_seq_tag)

In [None]:
show_doc(BaseParams.add_multiple_problems)

<h4 id="BaseParams.add_multiple_problems" class="doc_header"><code>BaseParams.add_multiple_problems</code><a href="__main__.py#L178" class="source_link" style="float:right">[source]</a></h4>

> <code>BaseParams.add_multiple_problems</code>(**`problem_type_dict`**:`Dict`\[`str`, `str`\], **`processing_fn_dict`**:`Dict`\[`str`, `Callable`\]=*`None`*)



Add multiple problems.

processing_fn_dict is optional, if it's not provided, processing fn will be set as None.

Args:
- problem_type_dict (Dict[str, str]): problem type dict
- processing_fn_dict (Dict[str, Callable], optional): problem type fn. Defaults to None.

In [None]:
# make dicts and add both problems to params in one call;
# the same names were already added with add_problem above, so this just
# stores the same type and function for them again
problem_type_dict = {'toy_cls': 'cls', 'toy_seq_tag': 'seq_tag'}
processing_fn_dict = {'toy_cls': toy_cls, 'toy_seq_tag': toy_seq_tag}
params.add_multiple_problems(problem_type_dict=problem_type_dict, processing_fn_dict=processing_fn_dict)

Adding new problem toy_cls, problem type: cls
Adding new problem toy_seq_tag, problem type: seq_tag


## Assign Problems

In [None]:
show_doc(BaseParams.assign_problem)

<h4 id="BaseParams.assign_problem" class="doc_header"><code>BaseParams.assign_problem</code><a href="__main__.py#L190" class="source_link" style="float:right">[source]</a></h4>

> <code>BaseParams.assign_problem</code>(**`flag_string`**:`str`, **`gpu`**=*`2`*, **`base_dir`**:`str`=*`None`*, **`dir_name`**:`str`=*`None`*, **`predicting`**=*`False`*)



Assign the actual run problem to param. This function will
do the following things:

1. parse the flag string to form the run_problem_list
2. create checkpoint saving path
3. calculate total number of training data and training steps
4. scale learning rate with the number of gpu linearly

Arguments:
- flag_string {str} -- run problem string
- example: cws|POS|weibo_ner&weibo_cws

Keyword Arguments:
- gpu {int} -- number of gpu use for training, this will affect the training steps and learning rate (default: {2})
- base_dir {str} -- base dir for ckpt, if None, then "models" is assigned (default: {None})
- dir_name {str} -- dir name for ckpt, if None, will be created automatically (default: {None})
- predicting {bool} -- whether is predicting

In [None]:
# gpu keeps its default of 2 here, which affects train_steps and lr;
# base_dir/dir_name are None, so the checkpoint dir is models/toy_cls_toy_seq_tag_ckpt
params.assign_problem(flag_string='toy_seq_tag|toy_cls')
assert params.problem_assigned



After the problem is assigned, the model path should have been created with the tokenizer and label encoder files in it.

In [None]:
# hide
# assert os.listdir(params.ckpt_dir) == ['data_info.json',
#  'tokenizer',
#  'toy_cls_label_encoder.pkl',
#  'toy_seq_tag_label_encoder.pkl',
#  'bert_config']

## Register new problem type

You can also implement your own problem type. Essentially, a problem type has:
- name
- top layer
- label handling function
- label encoder creating function

Here we register a vector fitting (vector annealing) problem type as an example.

Note: This is originally designed as an internal API for development. So it's not user-friendly.

In [None]:
show_doc(BaseParams.register_problem_type)

<h4 id="BaseParams.register_problem_type" class="doc_header"><code>BaseParams.register_problem_type</code><a href="__main__.py#L552" class="source_link" style="float:right">[source]</a></h4>

> <code>BaseParams.register_problem_type</code>(**`problem_type`**:`str`, **`top_layer`**:`Model`, **`label_handling_fn`**:`Callable`=*`None`*, **`get_or_make_label_encoder_fn`**:`Callable`=*`None`*)



API to register a new problem type

Args:
- problem_type: string, problem type name
- top_layer: a keras model with some specific reqirements
- label_handling_fn: function to convert labels to label ids
- get_or_make_label_encoder_fn: function to create label encoder, num_classes has to be specified here

In [None]:
from bert_multitask_learning.top import BaseTop
# top layer
class VectorFit(BaseTop):
    """Top layer that projects the pooled hidden feature to `num_classes` values.

    NOTE(review): `empty_tensor_handling_loss`, `cosine_wrapper` and
    `nan_loss_handling` are not imported in the cells visible here; they must
    be in scope before `call` runs in a non-predict mode.
    """

    def __init__(self, params: BaseParams, problem_name: str) -> None:
        """Create the dense layer, sized by `params.num_classes[problem_name]`."""
        super(VectorFit, self).__init__(
            params=params, problem_name=problem_name)
        self.num_classes = self.params.num_classes[problem_name]
        self.dense = tf.keras.layers.Dense(self.num_classes)

    def call(self, inputs: Tuple[Dict], mode: str):
        """Return the logits; outside predict mode also add the loss and a metric.

        Args:
            inputs: a pair `(feature, hidden_feature)`; `hidden_feature['pooled']`
                is fed to the dense layer and
                `feature['<problem_name>_label_ids']` is used as the label.
            mode: compared against `tf.estimator.ModeKeys.PREDICT`.
        """
        feature, hidden_feature = inputs
        pooled_hidden = hidden_feature['pooled']

        logits = self.dense(pooled_hidden)
        if mode != tf.estimator.ModeKeys.PREDICT:
            # this is the same as the label_id returned by vector_fit_label_handling_fn
            label = feature['{}_label_ids'.format(self.problem_name)]

            loss = empty_tensor_handling_loss(label, logits, cosine_wrapper)
            loss = nan_loss_handling(loss)
            self.add_loss(loss)

            # the negated loss is reported under the name '<problem_name>_cos_sim'
            self.add_metric(tf.math.negative(
                loss), name='{}_cos_sim'.format(self.problem_name), aggregation='mean')
        return logits

# label handling fn
def vector_fit_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None):
    """Convert a vector target to a float32 array; no label encoding is needed.

    Args:
        target: array-like label values.
        label_encoder, tokenizer, decoding_length: accepted but not used here.

    Returns:
        tuple: `(label_id, None)`, where `label_id` is `target` as a float32
        numpy array and the second element is the (absent) label mask.
    """
    # numpy is not imported by any earlier cell of this notebook, so import
    # it here instead of relying on a leftover kernel variable
    import numpy as np

    label_id = np.array(target, dtype='float32')
    return label_id, None

# make label encoder
def vector_fit_get_or_make_label_encoder_fn(params: BaseParams, problem, mode, label_list):
    """Record the target vector width for `problem`; no label encoder is built.

    Args:
        params: params object whose `num_classes` mapping is updated in place.
        problem: problem name used as the key in `params.num_classes`.
        mode: unused.
        label_list: array-like of target vectors; the size of its last axis
            becomes the number of classes for this problem.

    Returns:
        None, since this problem type does not need a label encoder.
    """
    # the last axis of the stacked labels is the dimension the top layer must output
    params.num_classes[problem] = np.array(label_list).shape[-1]
    return None

# Register the custom 'vectorfit' problem type together with its top layer,
# label handling function and label-encoder factory defined above.
params.register_problem_type(
    problem_type='vectorfit',
    top_layer=VectorFit,
    label_handling_fn=vector_fit_label_handling_fn,
    get_or_make_label_encoder_fn=vector_fit_get_or_make_label_encoder_fn,
)

## Utils

In [None]:
show_doc(BaseParams.from_json)

<h4 id="BaseParams.from_json" class="doc_header"><code>BaseParams.from_json</code><a href="__main__.py#L241" class="source_link" style="float:right">[source]</a></h4>

> <code>BaseParams.from_json</code>(**`json_path`**:`str`=*`None`*)

Load json file as params. 

json_path must not be None if the problem is not assigned to params

Args:
    json_path (str, optional): Path to json file. Defaults to None.

Raises:
    AttributeError

In [None]:
show_doc(BaseParams.to_json)

<h4 id="BaseParams.to_json" class="doc_header"><code>BaseParams.to_json</code><a href="__main__.py#L227" class="source_link" style="float:right">[source]</a></h4>

> <code>BaseParams.to_json</code>()

Save the params as json files. Please note that processing_fn is not saved.
        

In [None]:
show_doc(BaseParams.parse_problem_string)

<h4 id="BaseParams.parse_problem_string" class="doc_header"><code>BaseParams.parse_problem_string</code><a href="__main__.py#L314" class="source_link" style="float:right">[source]</a></h4>

> <code>BaseParams.parse_problem_string</code>(**`flag_string`**:`str`)



Parse problem string

Arguments: flag_string {str} -- problem string

Returns: list -- problem list

In [None]:
# Each call returns a flat problem list plus a list of problem chunks:
# `|` puts every problem into its own chunk, `&` groups the problems into one chunk.
print('chained with |: ', params.parse_problem_string('toy_seq_tag|toy_cls'))
print('chained with &: ', params.parse_problem_string('toy_seq_tag&toy_cls'))

chained with |:  (['toy_cls', 'toy_seq_tag'], [['toy_seq_tag'], ['toy_cls']])
chained with &:  (['toy_cls', 'toy_seq_tag'], [['toy_seq_tag', 'toy_cls']])


In [None]:
show_doc(BaseParams.get_data_info)

<h4 id="BaseParams.get_data_info" class="doc_header"><code>BaseParams.get_data_info</code><a href="__main__.py#L275" class="source_link" style="float:right">[source]</a></h4>

> <code>BaseParams.get_data_info</code>(**`problem_list`**:`List`\[`str`\], **`base`**:`str`)



Get number of data, number of classes of data and eos_id of data.

Arguments:
- problem_list {list} -- problem list
- base {str} -- path to store data_info.json

In [None]:
# Collect the data info for the current problems (data_info.json lives under ckpt_dir),
# then show the per-problem number of rows and number of classes.
params.get_data_info(params.problem_list, params.ckpt_dir)
print(params.data_num_dict, params.num_classes)

{'toy_cls': 10, 'toy_seq_tag': 10} {'toy_cls': 2, 'toy_seq_tag': 5}


In [None]:
show_doc(BaseParams.get_problem_type)

<h4 id="BaseParams.get_problem_type" class="doc_header"><code>BaseParams.get_problem_type</code><a href="__main__.py#L486" class="source_link" style="float:right">[source]</a></h4>

> <code>BaseParams.get_problem_type</code>(**`problem`**:`str`)



In [None]:
# Look up the problem type of a single problem by its name.
params.get_problem_type('toy_seq_tag')

'seq_tag'

In [None]:
show_doc(BaseParams.update_train_steps)

<h4 id="BaseParams.update_train_steps" class="doc_header"><code>BaseParams.update_train_steps</code><a href="__main__.py#L489" class="source_link" style="float:right">[source]</a></h4>

> <code>BaseParams.update_train_steps</code>(**`train_steps_per_epoch`**:`int`, **`epoch`**:`int`=*`None`*, **`warmup_ratio`**=*`0.1`*)

If the batch_size is dynamic, we have to loop through the tf.data.Dataset
to get the accurate number of training steps. In this case, we need a function to
update the train_steps which will be used to calculate learning rate schedule.

WARNING: updating should be called before the model is compiled! 

Args:
    train_steps_per_epoch (int): new number of train steps per epoch

If the batch_size is dynamic, we have to loop through the tf.data.Dataset
to get the accurate number of training steps. In this case, we need a function to
update the train_steps which will be used to calculate learning rate schedule.

WARNING: updating should be called before the model is compiled! 

Args:
- train_steps_per_epoch (int): new number of train steps per epoch

In [None]:
# Print the step counts before and after the update to show that both
# train_steps and num_warmup_steps are recomputed from train_steps_per_epoch.
print(params.train_steps, params.num_warmup_steps)
params.update_train_steps(train_steps_per_epoch=100)
print(params.train_steps, params.num_warmup_steps)

4 0
1500 150


In [None]:
show_doc(BaseParams.set_data_sampling_strategy)

<h4 id="BaseParams.set_data_sampling_strategy" class="doc_header"><code>BaseParams.set_data_sampling_strategy</code><a href="__main__.py#L521" class="source_link" style="float:right">[source]</a></h4>

> <code>BaseParams.set_data_sampling_strategy</code>(**`sampling_strategy`**=*`'data_balanced'`*, **`sampling_strategy_fn`**:`Callable`=*`None`*)



Set data sampling strategy for multi-task learning.

'data_balanced' and 'problem_balanced' are implemented by default.
data_balanced: sampling weight equals the number of rows of that problem chunk.
problem_balanced: sampling weight equals 1 for every problem chunk.

Args:
- sampling_strategy (str, optional): sampling strategy. Defaults to 'data_balanced'.
- sampling_strategy_fn (Callable, optional): function to create weight dict. Defaults to None.

Raises:
- NotImplementedError: sampling_strategy_fn is not implemented yet
- ValueError: invalid sampling_strategy provided

Returns:
- Dict[str, float]: sampling weight for each problem_chunk

In [None]:
# Switch to 'problem_balanced' sampling; the returned dict holds the sampling
# weight assigned to each problem chunk.
params.set_data_sampling_strategy(sampling_strategy='problem_balanced')

{'toy_seq_tag': 0.5, 'toy_cls': 0.5}