In [None]:
# default_exp input_fn
import os
%load_ext autoreload
%autoreload 2
# "-1" matches no CUDA device index, so GPUs are hidden from libraries
# imported after this point (tensorflow is imported in a later cell).
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


# Function to Create Datasets

Functions to create datasets for training, evaluation and prediction.

## Imports

In [None]:
# export
from typing import List, Union, Dict
import json
from loguru import logger

import tensorflow as tf

from m3tl.params import Params
from m3tl.read_write_tfrecord import read_tfrecord, write_tfrecord
from m3tl.special_tokens import PREDICT, TRAIN
from m3tl.utils import infer_shape_and_type_from_dict, get_is_pyspark
from m3tl.preproc_decorator import preprocessing_fn


## Train and Eval Dataset
We can get the train and eval datasets by passing a `Params` object with problems assigned, together with a mode.

In [None]:
# export

def element_length_func(yield_dict: Dict[str, tf.Tensor]):  # pragma: no cover
    '''
    Compute the length used to bucket one example.

    Every feature whose name contains 'input_ids' contributes the size of
    its first dimension; the contributions are added together, so the
    result is the combined length over all such features.

    Arguments:
        yield_dict {Dict[str, tf.Tensor]} -- a single example keyed by feature name

    Returns:
        tf.Tensor -- scalar sum of the first-dimension sizes
    '''
    first_dim_sizes = []
    for feature_name, feature in yield_dict.items():
        if 'input_ids' in feature_name:
            first_dim_sizes.append(tf.shape(feature)[0])
    return tf.reduce_sum(first_dim_sizes)


def train_eval_input_fn(params: Params, mode=TRAIN) -> Union[tf.data.Dataset, None]:
    '''
    Build the batched dataset used for training or evaluation.

    TFRecords are first written with `write_tfrecord` and then read back
    with `read_tfrecord`. The per-problem datasets are mixed with
    `sample_from_datasets`, weighted by the probabilities returned from
    `params.calculate_data_sampling_prob()`. The mixed dataset is shuffled
    (TRAIN mode only), batched and finally prefetched.

    Batching depends on `params.dynamic_padding`:
      * truthy -- examples are bucketed by `element_length_func` using
        `params.bucket_batch_sizes` and `params.bucket_boundaries`;
      * falsy -- examples are padded-batched with shapes inferred from the
        first example, using `params.batch_size` for TRAIN and twice that
        for any other mode.

    Arguments:
        params {Params} -- Params object

    Keyword Arguments:
        mode {str} -- ModeKeys (default: {TRAIN})

    Returns:
        tf Dataset -- Tensorflow dataset, or None when running under
        pyspark (the records are written but not read back)
    '''
    write_tfrecord(params=params)

    # reading with pyspark is not supported
    if get_is_pyspark():
        return None

    dataset_dict = read_tfrecord(params=params, mode=mode)

    # freeze the key order once so that datasets and weights line up
    dataset_dict_keys = list(dataset_dict.keys())
    dataset_list = [dataset_dict[key] for key in dataset_dict_keys]
    sample_prob_dict = params.calculate_data_sampling_prob()
    weight_list = [sample_prob_dict[key] for key in dataset_dict_keys]

    logger.info('sampling weights: ')
    logger.info(json.dumps(params.problem_sampling_weight_dict, indent=4))

    dataset = tf.data.experimental.sample_from_datasets(
        datasets=dataset_list, weights=weight_list)
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    dataset = dataset.with_options(options)

    if mode == TRAIN:
        dataset = dataset.shuffle(params.shuffle_buffer)

    if params.dynamic_padding:
        dataset = dataset.apply(
            tf.data.experimental.bucket_by_sequence_length(
                element_length_func=element_length_func,
                bucket_batch_sizes=params.bucket_batch_sizes,
                bucket_boundaries=params.bucket_boundaries
            ))
    else:
        # peek at one example to learn the shapes needed for padding
        first_example = next(dataset.as_numpy_iterator())
        output_shapes, _ = infer_shape_and_type_from_dict(first_example)

        # evaluation uses a batch twice as large as training
        batch_size = params.batch_size if mode == TRAIN else params.batch_size*2
        dataset = dataset.padded_batch(batch_size, output_shapes)

    # prefetch goes last so that whole batches (including the padding /
    # bucketing work) are prepared while the consumer handles the previous one
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)


In [None]:
# hide
from m3tl.test_base import TestBase
import m3tl
import shutil
import numpy as np
# Shared fixture for the test cells below: take the params object from a
# TestBase instance and assign a multi-problem string to it.
test_base = TestBase()
params = test_base.params
params.assign_problem(
    'weibo_fake_ner&weibo_fake_cls|weibo_fake_multi_cls|weibo_masklm')


2021-06-15 17:19:05.812 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_ner, problem type: seq_tag
2021-06-15 17:19:05.812 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_multi_cls, problem type: multi_cls
2021-06-15 17:19:05.813 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_cls, problem type: cls
2021-06-15 17:19:05.813 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_masklm, problem type: masklm
2021-06-15 17:19:05.814 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_regression, problem type: regression
2021-06-15 17:19:05.814 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_vector_fit, problem type: vector_fit
2021-06-15 17:19:05.815 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_pre

In [None]:

# Smoke test: build the train and eval pipelines with the fixture params,
# then make sure each of them can produce a first batch.
train_dataset = train_eval_input_fn(mode=m3tl.TRAIN, params=params)
eval_dataset = train_eval_input_fn(mode=m3tl.EVAL, params=params)

for smoke_dataset in (train_dataset, eval_dataset):
    _ = next(smoke_dataset.as_numpy_iterator())


2021-06-15 17:19:11.740 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_fake_cls_weibo_fake_ner/train_00000.tfrecord
2021-06-15 17:19:11.777 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_fake_cls_weibo_fake_ner/eval_00000.tfrecord
2021-06-15 17:19:11.803 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_fake_multi_cls/train_00000.tfrecord
2021-06-15 17:19:11.827 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_fake_multi_cls/eval_00000.tfrecord
2021-06-15 17:19:11.905 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_masklm/train_00000.tfrecord
2021-06-15 17:19:11.955 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_masklm/eval_00000.tfrecord
2021-06-15 17:19:12.697 | INFO     | __main__:train_eval_input_fn:37 - sampling weights: 
2021-06-15 17:19:12.698 | INFO     | __ma

In [None]:
# hide
# Same smoke test with dynamic padding switched off.
# The tfrecords written by the previous run have to be removed first.
shutil.rmtree(test_base.tmpfiledir)
test_base.params.dynamic_padding = False

train_dataset = train_eval_input_fn(mode=m3tl.TRAIN, params=test_base.params)
eval_dataset = train_eval_input_fn(mode=m3tl.EVAL, params=test_base.params)

for smoke_dataset in (train_dataset, eval_dataset):
    _ = next(smoke_dataset.as_numpy_iterator())


2021-06-15 17:19:14.010 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_fake_cls_weibo_fake_ner/train_00000.tfrecord
2021-06-15 17:19:14.047 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_fake_cls_weibo_fake_ner/eval_00000.tfrecord
2021-06-15 17:19:14.072 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_fake_multi_cls/train_00000.tfrecord
2021-06-15 17:19:14.098 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_fake_multi_cls/eval_00000.tfrecord
2021-06-15 17:19:14.180 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_masklm/train_00000.tfrecord
2021-06-15 17:19:14.231 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_masklm/eval_00000.tfrecord
2021-06-15 17:19:14.518 | INFO     | __main__:train_eval_input_fn:37 - sampling weights: 
2021-06-15 17:19:14.519 | INFO     | __ma

## Predict Dataset

We can create a predict dataset by passing a list/generator of inputs and a `Params` object with problems assigned.

In [None]:
# export
def predict_input_fn(input_file_or_list: Union[str, List[str]],
                     params: Params,
                     mode=PREDICT,
                     labels_in_input=False) -> tf.data.Dataset:
    '''Input function that takes a file path or a list of inputs and
    converts it to a padded, batched tf dataset.

    The inputs are run through `preprocessing_fn` once up front to infer
    output shapes and types from the first example, and again every time
    the returned dataset is iterated.

    Example:
        predict_fn = lambda: predict_input_fn('test.txt', params)
        pred = estimator.predict(predict_fn)

    Arguments:
        input_file_or_list {str or list} -- file path or list of inputs.
            A path is read as utf8, one input per line. One-shot
            iterators (e.g. generators) are materialised into a list so
            that they can be traversed more than once.
        params {Params} -- Params object

    Keyword Arguments:
        mode {str} -- ModeKeys (default: {PREDICT})
        labels_in_input {bool} -- not used by this function
            (default: {False})

    Returns:
        tf dataset -- tf dataset batched with `params.batch_size`

    Raises:
        ValueError -- if preprocessing the inputs yields no example
    '''
    from collections.abc import Iterator

    if isinstance(input_file_or_list, str):
        # treat a string as a path; read every line eagerly so the handle
        # is closed and the inputs can be iterated more than once
        with open(input_file_or_list, 'r', encoding='utf8') as input_file:
            inputs = input_file.readlines()
    elif isinstance(input_file_or_list, Iterator):
        # a one-shot iterator would be consumed by the shape-inference
        # pass below, leaving nothing (or a truncated stream) for the dataset
        inputs = list(input_file_or_list)
    else:
        inputs = input_file_or_list

    # preprocessing_fn decorates a function of (params, mode) that returns
    # the raw inputs, so wrap the already-loaded inputs in such a function
    def gen():
        @preprocessing_fn
        def gen_wrapper(params, mode):
            return inputs
        return gen_wrapper(params, mode)

    try:
        first_dict = next(gen())
    except StopIteration:
        raise ValueError(
            'predict_input_fn received no inputs to predict on') from None

    output_shapes, output_type = infer_shape_and_type_from_dict(first_dict)
    dataset = tf.data.Dataset.from_generator(
        gen, output_types=output_type, output_shapes=output_shapes)

    dataset = dataset.padded_batch(
        params.batch_size,
        output_shapes
    )

    return dataset


### Single-modal inputs

In [None]:
from m3tl.utils import set_phase
from m3tl.special_tokens import PREDICT

In [None]:

# Predict on five copies of one sentence and compare the token ids of the
# first example in the first batch with the expected ids.
set_phase(PREDICT)
single_dataset = predict_input_fn(['this is a test'] * 5, params=params)
first_batch = next(single_dataset.as_numpy_iterator())
assert first_batch['text_input_ids'][0].tolist() == [
    101, 8554, 8310, 143, 10060, 102]


2021-06-15 17:19:16.349 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer


### Multi-modal inputs

In [None]:
# Multi-modal input: every example carries a text and an all-zero
# float32 image feature of shape (5, 10).
mm_input = [
    {
        'text': 'this is a test',
        'image': np.zeros(shape=(5, 10), dtype='float32'),
    }
] * 5
mm_dataset = predict_input_fn(mm_input, params=params)
first_batch = next(mm_dataset.as_numpy_iterator())

# the text ids match the single-modal case; the image comes back unchanged
assert first_batch['text_input_ids'][0].tolist() == [
    101, 8554, 8310, 143, 10060, 102]
assert first_batch['image_input_ids'][0].tolist() == np.zeros(
    shape=(5, 10), dtype='float32').tolist()
