In [None]:
# default_exp preproc_decorator


In [None]:
%load_ext autoreload
%autoreload 2

# Preprocessing Decorator

A decorator to simplify data preprocessing

In [None]:
# export
import logging
from types import GeneratorType
from typing import Callable
from inspect import signature

from sklearn.preprocessing import MultiLabelBinarizer

from bert_multitask_learning.read_write_tfrecord import (write_single_problem_chunk_tfrecord,
                                  write_single_problem_gen_tfrecord)
from bert_multitask_learning.special_tokens import PREDICT
from bert_multitask_learning.utils import LabelEncoder, get_or_make_label_encoder, load_transformer_tokenizer


def _known_encoder_num_classes(label_encoder):
    """Return the class count for encoder types that are handled directly.

    Args:
        label_encoder: the label encoder of a problem, or None.

    Returns:
        0 when `label_encoder` is None, the size of `encode_dict` for a
        `LabelEncoder`, the number of fitted classes for a
        `MultiLabelBinarizer`, and None for any other object (e.g. a
        tokenizer used as label encoder) so that the caller can apply its
        own vocab-size fallback.
    """
    if label_encoder is None:
        return 0
    if isinstance(label_encoder, LabelEncoder):
        return len(label_encoder.encode_dict)
    if isinstance(label_encoder, MultiLabelBinarizer):
        return label_encoder.classes_.shape[0]
    return None


def preprocessing_fn(func: Callable):
    """Usually used as a decorator.

    The input and output signature of decorated function should be:
    func(params: bert_multitask_learning.BaseParams,
         mode: str) -> Union[Generator[X, y], Tuple[List[X], List[y]]]

    The decorated function may take an optional third positional argument,
    get_data_num, which is forwarded from the wrapper.

    Where X can be:
    - Dictionary of 'a' and 'b' texts: {'a': 'a test', 'b': 'b test'}
    - Text: 'a test'
    - Dictionary of modalities: {'text': 'a test', 'image': np.array([1,2,3])}

    Where y can be:
    - Text or scalar: 'label_a'
    - List of text or scalar: ['label_a', 'label_a1'] (for seq2seq and seq_tag)

    This decorator will do the following things:
    - load tokenizer
    - call func, save as example_list
    - create label_encoder and count the number of rows of example_list
    - create bert features from example_list and write tfrecord

    The returned wrapper has the signature
    wrapper(params, mode, get_data_num=False, write_tfrecord=True) and returns:
    - (num_of_rows, num_of_classes) if get_data_num is True
    - (generator, label_encoder) or (inputs_list, target_list, label_encoder)
      if mode is PREDICT
    - the result of the tfrecord writer if write_tfrecord is True
    - otherwise a dict holding the problem name, the data, the label encoder
      and the tokenizer

    Args:
        func (Callable): preprocessing function for problem

    Returns:
        Callable: the wrapper described above. It keeps the name and
        docstring of `func`.

    Raises:
        ValueError: raised by the wrapper when func returns lists,
            get_data_num is True and the number of classes cannot be
            determined from the label encoder or params.
    """
    # Imported here so the block is self-contained with respect to the
    # module's import section.
    from functools import wraps

    @wraps(func)
    def wrapper(params, mode, get_data_num=False, write_tfrecord=True):
        problem = func.__name__

        tokenizer = load_transformer_tokenizer(
            params.transformer_tokenizer_name, params.transformer_tokenizer_loading)
        proc_fn_signature_names = list(signature(
            func).parameters.keys())

        # proc func can return generator or tuple of lists
        # and it can have an optional get_data_num argument to
        # avoid iterate through the whole dataset to create
        # label encoder and get number of rows of data
        if len(proc_fn_signature_names) == 2:
            example_list = func(params, mode)
        else:
            example_list = func(params, mode, get_data_num)

        if isinstance(example_list, GeneratorType):
            if get_data_num:
                # create label encoder and data num
                cnt = 0
                label_list = []
                logging.info(
                    "Preprocessing function returns generator, might take some time to create label encoder...")
                for example in example_list:
                    if isinstance(example[0], int):
                        # proc func short-circuited with
                        # (num_of_data, label_encoder).
                        # NOTE(review): the encoder is discarded and None is
                        # returned as number of classes, unlike the list
                        # branch below -- confirm this is intended.
                        data_num, _ = example
                        return data_num, None
                    cnt += 1
                    try:
                        _, label = example
                        label_list.append(label)
                    except ValueError:
                        # example does not unpack into (X, y); count the
                        # row without collecting a label
                        pass

                # create label encoder
                label_encoder = get_or_make_label_encoder(
                    params, problem=problem, mode=mode, label_list=label_list)

                num_classes = _known_encoder_num_classes(label_encoder)
                if num_classes is not None:
                    return cnt, num_classes

                # label_encoder is tokenizer
                # NOTE(review): this fallback reads
                # params.bert_decoder_config.vocab_size while the list branch
                # reads params.decoder_vocab_size -- kept as is.
                try:
                    return cnt, len(label_encoder.vocab)
                except AttributeError:
                    # models like xlnet's vocab size can only be retrieved from config instead of tokenizer
                    return cnt, params.bert_decoder_config.vocab_size

            else:
                # create label encoder
                label_encoder = get_or_make_label_encoder(
                    params, problem=problem, mode=mode, label_list=[])

            if mode == PREDICT:
                return example_list, label_encoder

            if write_tfrecord:
                return write_single_problem_gen_tfrecord(
                    problem,
                    example_list,
                    label_encoder,
                    params,
                    tokenizer,
                    mode)

            return {
                'problem': problem,
                'gen': example_list,
                'label_encoder': label_encoder,
                'tokenizer': tokenizer
            }

        # if proc func returns integer as the first element,
        # that means it returns (num_of_data, label_encoder)
        if isinstance(example_list[0], int):
            data_num, label_encoder = example_list
            inputs_list, target_list = None, None
        else:
            try:
                inputs_list, target_list = example_list
            except ValueError:
                # only inputs were returned, no targets
                inputs_list = example_list
                target_list = None

            label_encoder = get_or_make_label_encoder(
                params, problem=problem, mode=mode, label_list=target_list)
            data_num = len(inputs_list)

        if get_data_num:
            num_classes = _known_encoder_num_classes(label_encoder)
            if num_classes is not None:
                return data_num, num_classes
            if hasattr(label_encoder, 'vocab'):
                # label_encoder is tokenizer
                return data_num, len(label_encoder.vocab)
            if hasattr(params, 'decoder_vocab_size'):
                return data_num, params.decoder_vocab_size
            raise ValueError(
                'Cannot determine num of classes for problem {0}. '
                'This is usually caused by {1} not having attribute vocab. '
                'In this case, you should manually specify vocab size to params: '
                'params.decoder_vocab_size = 32000'.format(
                    problem, type(label_encoder).__name__))

        if mode == PREDICT:
            return inputs_list, target_list, label_encoder

        if write_tfrecord:
            return write_single_problem_chunk_tfrecord(
                problem,
                inputs_list,
                target_list,
                label_encoder,
                params,
                tokenizer,
                mode)

        return {
            'problem': problem,
            'inputs_list': inputs_list,
            'target_list': target_list,
            'label_encoder': label_encoder,
            'tokenizer': tokenizer
        }

    return wrapper


## User-Defined Preprocessing Function

The user-defined preprocessing function should return two elements, features and targets, except for the `pretrain` problem type.

Features and targets can be returned in one of the following formats:
- a tuple of lists
- a generator of tuples

Please note that if the preprocessing function returns a generator of tuples, the corresponding problem cannot be chained using `&`.

In [None]:
# hide
import os
import shutil
import tempfile
from typing import Generator, Tuple

import numpy as np

import bert_multitask_learning
from bert_multitask_learning.params import BaseParams


In [None]:
# setup params for testing
params = BaseParams()
# Use freshly created temporary directories for the checkpoint dir and the
# tmp file dir; the checks below look for written files under
# params.tmp_file_dir.
params.ckpt_dir = tempfile.mkdtemp()
params.tmp_file_dir = tempfile.mkdtemp()

### Tuple of List

#### Single Modal


In [None]:
@preprocessing_fn
def toy_cls(params: BaseParams, mode: str) -> Tuple[list, list]:
    """Toy single-modal problem that returns a tuple of lists.

    Produces ten identical text inputs, whose wording depends on whether
    `mode` is TRAIN, paired with ten 'a' labels.
    """
    is_train = mode == bert_multitask_learning.TRAIN
    sentence = 'this is a toy input' if is_train else 'this is a toy input for test'
    toy_input = [sentence] * 10
    toy_target = ['a'] * 10
    return toy_input, toy_target

In [None]:
# hide
def preproc_dec_test():
    """Check the currently defined `toy_cls` through the decorator.

    Registers `toy_cls` as a 'cls' problem, verifies that it reports ten
    rows and one class, then writes the tfrecord and verifies that the
    train feature description file was created.
    """
    params.add_problem(
        problem_name='toy_cls',
        problem_type='cls',
        processing_fn=toy_cls)

    counts = toy_cls(
        params=params,
        mode=bert_multitask_learning.TRAIN,
        get_data_num=True,
        write_tfrecord=False)
    assert (10, 1) == counts

    toy_cls(
        params=params,
        mode=bert_multitask_learning.TRAIN,
        get_data_num=False,
        write_tfrecord=True)
    feature_desc_path = os.path.join(
        params.tmp_file_dir, 'toy_cls', 'train_feature_desc.json')
    assert os.path.exists(feature_desc_path)


preproc_dec_test()

INFO:tensorflow:this is a toy input
INFO:tensorflow:input_ids: [[101, 8554, 8310, 143, 8228, 8179, 8217, 11300, 102]]
INFO:tensorflow:input_mask: [[1, 1, 1, 1, 1, 1, 1, 1, 1]]
INFO:tensorflow:segment_ids: [[0, 0, 0, 0, 0, 0, 0, 0, 0]]
INFO:tensorflow:toy_cls_label_ids: 0
INFO:tensorflow:this is a toy input
INFO:tensorflow:input_ids: [[101, 8554, 8310, 143, 8228, 8179, 8217, 11300, 102]]
INFO:tensorflow:input_mask: [[1, 1, 1, 1, 1, 1, 1, 1, 1]]
INFO:tensorflow:segment_ids: [[0, 0, 0, 0, 0, 0, 0, 0, 0]]
INFO:tensorflow:toy_cls_label_ids: 0
INFO:tensorflow:this is a toy input
INFO:tensorflow:input_ids: [[101, 8554, 8310, 143, 8228, 8179, 8217, 11300, 102]]
INFO:tensorflow:input_mask: [[1, 1, 1, 1, 1, 1, 1, 1, 1]]
INFO:tensorflow:segment_ids: [[0, 0, 0, 0, 0, 0, 0, 0, 0]]
INFO:tensorflow:toy_cls_label_ids: 0
INFO:tensorflow:this is a toy input
INFO:tensorflow:input_ids: [[101, 8554, 8310, 143, 8228, 8179, 8217, 11300, 102]]
INFO:tensorflow:input_mask: [[1, 1, 1, 1, 1, 1, 1, 1, 1]]
INFO:ten

#### Multi-modal

In [None]:
@preprocessing_fn
def toy_cls(params: BaseParams, mode: str) -> Tuple[list, list]:
    """Toy multi-modal problem that returns a tuple of lists.

    Each of the ten inputs is a dict with a 'text' entry, whose wording
    depends on whether `mode` is TRAIN, and an 'image' entry holding 16
    uniform random values. Every target is 'a'.
    """
    if mode == bert_multitask_learning.TRAIN:
        sentence = 'this is a toy input'
    else:
        sentence = 'this is a toy input for test'

    toy_input = []
    for _ in range(10):
        toy_input.append(
            {'text': sentence, 'image': np.random.uniform(size=16)})
    toy_target = ['a'] * 10

    return toy_input, toy_target

In [None]:
# hide
# Re-run the shared check; `toy_cls` now refers to the multi-modal
# tuple-of-lists version defined in the previous cell.
preproc_dec_test()

INFO:tensorflow:text: this is a toy input
INFO:tensorflow:image: [0.5809141  0.06545048 0.33736762 0.67150339 0.79172166 0.28670109
 0.35819524 0.16445301 0.63652557 0.58635403 0.01462962 0.31659283
 0.60157348 0.15251305 0.47542086 0.64718995]
INFO:tensorflow:input_ids: [[101, 8554, 8310, 143, 8228, 8179, 8217, 11300, 102]]
INFO:tensorflow:input_mask: [[1, 1, 1, 1, 1, 1, 1, 1, 1]]
INFO:tensorflow:segment_ids: [[0, 0, 0, 0, 0, 0, 0, 0, 0]]
INFO:tensorflow:toy_cls_label_ids: 0
INFO:tensorflow:image_input: [[0.5809141  0.06545048 0.33736762 0.67150339 0.79172166 0.28670109
  0.35819524 0.16445301 0.63652557 0.58635403 0.01462962 0.31659283
  0.60157348 0.15251305 0.47542086 0.64718995]]
INFO:tensorflow:image_mask: [1]
INFO:tensorflow:image_segment_ids: [0]
INFO:tensorflow:text: this is a toy input
INFO:tensorflow:image: [0.78141419 0.84232392 0.46107929 0.98250265 0.39704729 0.82511787
 0.13924664 0.91397845 0.46385502 0.50603803 0.8973713  0.26643603
 0.32537559 0.35824151 0.57058196 0.

#### A, B Token Multi-modal

TODO: Implement this. Not working yet.

In [None]:
# hide
@preprocessing_fn
def toy_cls(params: BaseParams, mode: str) -> Tuple[list, list]:
    """Toy A/B multi-modal problem that returns a tuple of lists.

    Each of the ten inputs is a dict with 'a' and 'b' entries; each entry
    is itself a dict with a 'text' string, whose wording depends on
    whether `mode` is TRAIN, and an 'image' of 16 uniform random values.
    Every target is 'a'.
    """
    is_train = mode == bert_multitask_learning.TRAIN
    sentence = 'this is a toy input' if is_train else 'this is a toy input for test'

    def _modalities():
        # a fresh dict and a fresh random image for every call
        return {'text': sentence, 'image': np.random.uniform(size=16)}

    # 'a' is built before 'b', matching the order of the random draws
    toy_input = [{'a': _modalities(), 'b': _modalities()} for _ in range(10)]
    toy_target = ['a'] * 10

    return toy_input, toy_target

In [None]:
# # hide
# params.add_problem(problem_name='toy_cls', problem_type='cls', processing_fn=toy_cls)
# assert (10, 1)==toy_cls(params=params, mode=bert_multitask_learning.TRAIN, get_data_num=True, write_tfrecord=False)

# shutil.rmtree(os.path.join(params.tmp_file_dir, 'toy_cls'))
# toy_cls(params=params, mode=bert_multitask_learning.TRAIN, get_data_num=False, write_tfrecord=True)
# assert os.path.exists(os.path.join(params.tmp_file_dir, 'toy_cls', 'train_feature_desc.json'))

### Generator of Tuple

#### Single Modal

In [None]:
@preprocessing_fn
def toy_cls(params: BaseParams, mode: str) -> Generator[Tuple[str, str], None, None]:
    """Simple example to demonstrate single modal generator of tuple return.

    Yields ten (text, label) pairs. The text depends on whether `mode` is
    TRAIN; the label is always 'a'.
    """
    if mode == bert_multitask_learning.TRAIN:
        toy_input = ['this is a toy input' for _ in range(10)]
    else:
        toy_input = ['this is a toy input for test' for _ in range(10)]
    toy_target = ['a' for _ in range(10)]
    for i, t in zip(toy_input, toy_target):
        yield i, t

In [None]:
# hide
# Re-run the shared check; `toy_cls` now refers to the single-modal
# generator version defined in the previous cell.
preproc_dec_test()

INFO:tensorflow:this is a toy input
INFO:tensorflow:input_ids: [[101, 8554, 8310, 143, 8228, 8179, 8217, 11300, 102]]
INFO:tensorflow:input_mask: [[1, 1, 1, 1, 1, 1, 1, 1, 1]]
INFO:tensorflow:segment_ids: [[0, 0, 0, 0, 0, 0, 0, 0, 0]]
INFO:tensorflow:toy_cls_label_ids: 0
INFO:tensorflow:this is a toy input
INFO:tensorflow:input_ids: [[101, 8554, 8310, 143, 8228, 8179, 8217, 11300, 102]]
INFO:tensorflow:input_mask: [[1, 1, 1, 1, 1, 1, 1, 1, 1]]
INFO:tensorflow:segment_ids: [[0, 0, 0, 0, 0, 0, 0, 0, 0]]
INFO:tensorflow:toy_cls_label_ids: 0
INFO:tensorflow:this is a toy input
INFO:tensorflow:input_ids: [[101, 8554, 8310, 143, 8228, 8179, 8217, 11300, 102]]
INFO:tensorflow:input_mask: [[1, 1, 1, 1, 1, 1, 1, 1, 1]]
INFO:tensorflow:segment_ids: [[0, 0, 0, 0, 0, 0, 0, 0, 0]]
INFO:tensorflow:toy_cls_label_ids: 0
INFO:tensorflow:this is a toy input
INFO:tensorflow:input_ids: [[101, 8554, 8310, 143, 8228, 8179, 8217, 11300, 102]]
INFO:tensorflow:input_mask: [[1, 1, 1, 1, 1, 1, 1, 1, 1]]
INFO:ten

#### Multi-modal

In [None]:
@preprocessing_fn
def toy_cls(params: BaseParams, mode: str) -> Generator[Tuple[dict, str], None, None]:
    """Simple example to demonstrate multi-modal generator of tuple return.

    Yields ten (features, label) pairs. `features` is a dict with a 'text'
    entry, whose wording depends on whether `mode` is TRAIN, and an 'image'
    entry of 16 uniform random values; the label is always 'a'.
    """
    if mode == bert_multitask_learning.TRAIN:
        toy_input = [{'text': 'this is a toy input', 'image': np.random.uniform(size=(16))} for _ in range(10)]
    else:
        toy_input = [{'text': 'this is a toy input for test', 'image': np.random.uniform(size=(16))} for _ in range(10)]
    toy_target = ['a' for _ in range(10)]
    for i, t in zip(toy_input, toy_target):
        yield i, t

In [None]:
# hide
# Re-run the shared check; `toy_cls` now refers to the multi-modal
# generator version defined in the previous cell.
preproc_dec_test()


INFO:tensorflow:text: this is a toy input
INFO:tensorflow:image: [0.18114737 0.52025721 0.30680549 0.23567995 0.61244397 0.21831956
 0.76148699 0.61470593 0.10734543 0.94891037 0.04046289 0.16569479
 0.76607886 0.90692863 0.26102512 0.97853116]
INFO:tensorflow:input_ids: [[101, 8554, 8310, 143, 8228, 8179, 8217, 11300, 102]]
INFO:tensorflow:input_mask: [[1, 1, 1, 1, 1, 1, 1, 1, 1]]
INFO:tensorflow:segment_ids: [[0, 0, 0, 0, 0, 0, 0, 0, 0]]
INFO:tensorflow:toy_cls_label_ids: 0
INFO:tensorflow:image_input: [[0.18114737 0.52025721 0.30680549 0.23567995 0.61244397 0.21831956
  0.76148699 0.61470593 0.10734543 0.94891037 0.04046289 0.16569479
  0.76607886 0.90692863 0.26102512 0.97853116]]
INFO:tensorflow:image_mask: [1]
INFO:tensorflow:image_segment_ids: [0]
INFO:tensorflow:text: this is a toy input
INFO:tensorflow:image: [0.18484544 0.80277546 0.19575457 0.38967686 0.22814962 0.346866
 0.69174051 0.62605842 0.62948146 0.01061797 0.38932132 0.52792525
 0.05108438 0.26307339 0.33575207 0.20