In [None]:
# default_exp read_write_tfrecord
import os
%load_ext autoreload
%autoreload 2
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


# Read and Write TFRecord

Write BERT features to TFRecord files and read them back as `tf.data` datasets.

## Imports

In [None]:
# export
import json
import os
from fastcore.basics import partial
from glob import glob
from typing import Dict, Iterator, Callable
import tempfile

from loguru import logger
import numpy as np
import tensorflow as tf
from fastcore.basics import listify

from m3tl.bert_preprocessing.create_bert_features import create_multimodal_bert_features
from m3tl.special_tokens import EVAL, TRAIN
from m3tl.params import Params
from m3tl.utils import get_is_pyspark


## Write TFRecords

### Serialize Functions

In [None]:
# export

def _float_list_feature(value):
    """Wrap an iterable of floats in a ``tf.train.Feature`` float list."""
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)


def _float_feature(value):
    """Wrap a single float in a one-element ``tf.train.Feature`` float list."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)


def _int64_list_feature(value):
    """Wrap an iterable of ints in a ``tf.train.Feature`` int64 list."""
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)


def _int64_feature(value):
    """Wrap a single int in a one-element ``tf.train.Feature`` int64 list."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)


def _bytes_feature(value):
    """Wrap a single bytes value in a one-element ``tf.train.Feature`` bytes list."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def _bytes_list_feature(value):
    """Wrap an iterable of bytes values in a ``tf.train.Feature`` bytes list."""
    bytes_list = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=bytes_list)


def serialize_fn(features: dict, return_feature_desc=False):
    """Serialize one feature dict into a ``tf.train.Example`` byte string.

    Every feature ``name`` is stored next to a companion ``name_shape`` int64
    list so that the flattened values can be reshaped when they are read back.

    Arguments:
        features {dict} -- feature name -> value. Lists and ndarrays are
            flattened (integer dtypes as int64 lists, every other dtype as
            float lists). Float and integer scalars, Python or NumPy, become
            one-element lists. Anything else is stored as a bytes value,
            with ``str`` encoded as UTF-8 first.

    Keyword Arguments:
        return_feature_desc {bool} -- also return the feature description
            (default: {False})

    Returns:
        bytes -- the serialized example; or ``(bytes, dict)`` when
        ``return_feature_desc`` is true. The dict maps ``name`` and
        ``name_shape`` to a type name ('int64', 'float32' or 'string') and
        ``name_shape_value`` to the static shape, where the first dimension
        of an array (or its only dimension) is recorded as None.
    """
    features_tuple = {}
    feature_desc = {}

    def _add_scalar(name, tf_feature, type_name):
        # Scalars are stored with an empty shape list.
        features_tuple[name] = tf_feature
        features_tuple['{}_shape'.format(name)] = _int64_list_feature([])
        feature_desc[name] = type_name
        feature_desc['{}_shape'.format(name)] = 'int64'
        feature_desc['{}_shape_value'.format(name)] = []

    for feature_name, feature in features.items():
        if type(feature) is list:
            feature = np.array(feature)

        if type(feature) is np.ndarray:
            if issubclass(feature.dtype.type, np.integer):
                features_tuple[feature_name] = _int64_list_feature(
                    feature.flatten())
                feature_desc[feature_name] = 'int64'
            else:
                features_tuple[feature_name] = _float_list_feature(
                    feature.flatten())
                feature_desc[feature_name] = 'float32'

            features_tuple['{}_shape'.format(
                feature_name)] = _int64_list_feature(feature.shape)

            # The leading dimension is left dynamic (None); for 1-D arrays
            # that is the only dimension.
            if len(feature.shape) > 1:
                shape_value = [None] + list(feature.shape[1:])
            else:
                shape_value = [None for _ in feature.shape]
            feature_desc['{}_shape_value'.format(feature_name)] = shape_value
            feature_desc['{}_shape'.format(feature_name)] = 'int64'

        # np.floating (not the removed ``np.float`` alias) matches Python
        # floats as well as every NumPy float width, e.g. np.float32.
        elif np.issubdtype(type(feature), np.floating):
            _add_scalar(feature_name, _float_feature(feature), 'float32')
        elif np.issubdtype(type(feature), np.integer):
            _add_scalar(feature_name, _int64_feature(feature), 'int64')
        else:
            if isinstance(feature, str):
                feature = feature.encode('utf8')
            _add_scalar(feature_name, _bytes_feature(feature), 'string')

    example_proto = tf.train.Example(
        features=tf.train.Features(feature=features_tuple)).SerializeToString()

    if return_feature_desc:
        return example_proto, feature_desc

    return example_proto


In [None]:
# hide
# Smoke test for serialize_fn: one feature of every supported kind
# (int/float scalar, 1-D and n-D arrays, string).
from m3tl.test_base import TestBase
test_base = TestBase()
test_features = {
    'int_scalar': 1,
    'float_scalar': 2.0,
    'int_array': [1, 2, 3],
    'float_array': np.array([4, 5, 6], dtype='float32'),
    'int_matrix': [[1, 2, 3], [4, 5, 6]],
    'float_matrix': np.random.uniform(size=(32, 5, 5)),
    'string': 'this is test'
}
# Expected description: type names for each feature and its `_shape`
# companion, plus the static shape with the leading dimension left as None.
expected_desc = {'int_scalar': 'int64', 'int_scalar_shape': 'int64', 'int_scalar_shape_value': [],
                 'float_scalar': 'float32', 'float_scalar_shape': 'int64', 'float_scalar_shape_value': [],
                 'int_array': 'int64', 'int_array_shape_value': [None], 'int_array_shape': 'int64',
                 'float_array': 'float32', 'float_array_shape_value': [None], 'float_array_shape': 'int64',
                 'int_matrix': 'int64', 'int_matrix_shape_value': [None, 3], 'int_matrix_shape': 'int64',
                 'float_matrix': 'float32', 'float_matrix_shape_value': [None, 5, 5], 'float_matrix_shape': 'int64',
                 'string': 'string', 'string_shape': 'int64', 'string_shape_value': []}
ser_str, feat_desc = serialize_fn(
    features=test_features, return_feature_desc=True)
assert feat_desc == expected_desc

# Round-trip the serialized bytes through tf.train.Example.
example = tf.train.Example()
example.ParseFromString(ser_str)
assert example.features.feature['int_array'].int64_list.value == [1, 2, 3]


2021-06-19 21:40:51.722 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_ner, problem type: seq_tag
2021-06-19 21:40:51.723 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_multi_cls, problem type: multi_cls
2021-06-19 21:40:51.723 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_cls, problem type: cls
2021-06-19 21:40:51.723 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_masklm, problem type: masklm
2021-06-19 21:40:51.724 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_regression, problem type: regression
2021-06-19 21:40:51.724 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_vector_fit, problem type: vector_fit
2021-06-19 21:40:51.724 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_pre

### Make TFRecord

In [None]:
# export
def make_tfrecord_local(data_list, output_dir, serialize_fn, mode='train', example_per_file=100000, prefix='', **kwargs) -> int:
    """Write examples to sharded TFRecord files on the local file system.

    Shards are written to ``<output_dir>/<prefix>/<mode>_<shard:05d>.tfrecord``
    with at most ``example_per_file`` examples each. A
    ``<mode>_feature_desc.json`` file is written next to the shards from the
    first serialized example, unless a file with that name already exists.

    Arguments:
        data_list {iterable} -- feature dicts accepted by ``serialize_fn``
        output_dir {str} -- base output directory
        serialize_fn {Callable} -- called as
            ``serialize_fn(features, return_feature_desc=True)`` and expected
            to return ``(serialized_bytes, feature_desc_dict)``

    Keyword Arguments:
        mode {str} -- file name prefix of shards and description (default: {'train'})
        example_per_file {int} -- maximum examples per shard (default: {100000})
        prefix {str} -- sub-directory below output_dir (default: {''})
        **kwargs -- accepted for signature compatibility and ignored

    Returns:
        int -- number of examples written
    """
    # create output tfrecord path
    os.makedirs(os.path.join(output_dir, prefix), exist_ok=True)

    def _shard_path(shard_idx):
        return os.path.join(
            output_dir, prefix, '{}_{:05d}.tfrecord'.format(mode, shard_idx))

    def _write_fn(d_list, path):
        logger.debug('Writing {}'.format(path))
        feature_desc_path = os.path.join(os.path.dirname(
            path), '{}_feature_desc.json'.format(mode))
        # Checked once per shard rather than once per example.
        need_feature_desc = not os.path.exists(feature_desc_path)

        with tf.io.TFRecordWriter(path) as writer:
            for features in d_list:
                example, feature_desc = serialize_fn(
                    features, return_feature_desc=True)
                writer.write(example)
                if need_feature_desc:
                    # ``with`` guarantees the description is flushed and closed.
                    with open(feature_desc_path, 'w', encoding='utf8') as desc_file:
                        json.dump(feature_desc, desc_file)
                    need_feature_desc = False

    buffered = []
    total_count = 0
    shard_count = 0
    for example in data_list:
        buffered.append(example)
        total_count += 1
        # Flush as soon as the shard is full so every shard, including the
        # first one, holds exactly example_per_file examples.
        if len(buffered) >= example_per_file:
            _write_fn(buffered, _shard_path(shard_count))
            shard_count += 1
            buffered = []

    # Remaining examples were already counted in the loop above.
    if buffered:
        _write_fn(buffered, _shard_path(shard_count))
    return total_count


def make_tfrecord_pyspark(data_list, output_dir: str, serialize_fn: Callable, mode='train', example_per_file=100000, prefix='', **kwargs) -> int:
    """Write an RDD of feature dicts as TFRecord files with pyspark.

    The records are saved to ``output_dir`` through the
    ``org.tensorflow.hadoop.io.TFRecordFileOutputFormat`` Hadoop output
    format. The feature descriptions of all records are merged into one
    dict, dumped to ``<mode>_feature_desc.json`` in the current working
    directory and then copied to ``output_dir`` with ``Hdfs``.

    Arguments:
        data_list {RDD} -- RDD of feature dicts accepted by ``serialize_fn``
        output_dir {str} -- destination path of the TFRecord files
        serialize_fn {Callable} -- called as
            ``serialize_fn(features, return_feature_desc=True)``

    Keyword Arguments:
        mode {str} -- prefix of the feature description file (default: {'train'})
        example_per_file {int} -- passed to ``repar_rdd`` as ``example_per_par`` (default: {100000})
        prefix {str} -- unused here (default: {''})
        **kwargs -- accepted for signature compatibility and ignored

    Returns:
        int -- exact number of records in the RDD
    """
    from m3tl.pyspark_utils import Hdfs, repar_rdd
    from pyspark import RDD

    # write RDD to TFRecords
    # ref: https://github.com/yahoo/TensorFlowOnSpark/blob/master/examples/mnist/mnist_data_setup.py
    # just for type hint
    data_list: RDD = data_list

    # since single record might not contain all problem labels
    # we create feature desc for all record and aggregate
    # TODO: poor performance, optimize this
    feat_desc_tfrecord_tuple_rdd = data_list.map(
        lambda x: serialize_fn(x, return_feature_desc=True)
    )
    feat_desc_tfrecord_tuple_rdd = feat_desc_tfrecord_tuple_rdd.cache()
    # Exact count: extrapolating from a 1% sample reports 0 for small RDDs,
    # and sampling has to scan every record anyway.
    rdd_count = int(feat_desc_tfrecord_tuple_rdd.count())
    feat_desc_tfrecord_tuple_rdd = repar_rdd(
        rdd=feat_desc_tfrecord_tuple_rdd,
        rdd_count=rdd_count,
        example_per_par=example_per_file
    )
    # Key every description with 0 so they all reduce into a single dict.
    feature_desc_pair_rdd = feat_desc_tfrecord_tuple_rdd.map(
        lambda x: (0, x[1]))
    tfrecord_rdd = feat_desc_tfrecord_tuple_rdd.map(
        lambda x: (bytearray(x[0]), None))

    tfrecord_rdd.saveAsNewAPIHadoopFile(
        path=output_dir,
        outputFormatClass="org.tensorflow.hadoop.io.TFRecordFileOutputFormat",
        keyClass="org.apache.hadoop.io.BytesWritable",
        valueClass="org.apache.hadoop.io.NullWritable"
    )

    # create feature desc
    def _update_dict(ld: dict, rd: dict) -> dict:
        ld.update(rd)
        return ld
    # An empty RDD reduces to an empty mapping; fall back to an empty desc.
    feature_desc = feature_desc_pair_rdd.reduceByKeyLocally(
        _update_dict).get(0, {})

    local_feature_desc_path = '{}_feature_desc.json'.format(mode)
    # Close (and thereby flush) the file before it is copied.
    with open(local_feature_desc_path, 'w') as desc_file:
        json.dump(feature_desc, desc_file, indent=4)
    hdfs_client = Hdfs()
    hdfs_client.copyFromLocalFile(
        local_feature_desc_path,
        os.path.join(output_dir, local_feature_desc_path))
    return rdd_count


def make_tfrecord(data_list, output_dir, serialize_fn, mode='train', example_per_file=100000, prefix='', **kwargs):
    """Write TFRecords with the pyspark writer or the local writer.

    When ``get_is_pyspark()`` is true, ``output_dir`` is replaced by
    ``<kwargs['pyspark_dir']>/<mode>`` and ``make_tfrecord_pyspark`` is used;
    otherwise ``make_tfrecord_local`` writes to ``output_dir``. All arguments,
    including ``**kwargs``, are forwarded to the chosen writer and its return
    value is returned.
    """
    writer = make_tfrecord_local
    if get_is_pyspark():
        writer = make_tfrecord_pyspark
        output_dir = os.path.join(kwargs['pyspark_dir'], mode)

    return writer(
        data_list=data_list,
        output_dir=output_dir,
        serialize_fn=serialize_fn,
        mode=mode,
        example_per_file=example_per_file,
        prefix=prefix,
        **kwargs)


#### Local write tfrecord test

In [None]:
# hide
# Write the single test example locally and check that both the feature
# description and the first shard were created.
make_tfrecord(
    [test_features], output_dir=test_base.tmpfiledir, serialize_fn=serialize_fn)
assert os.path.exists(os.path.join(
    test_base.tmpfiledir, 'train_feature_desc.json'))
assert os.path.exists(os.path.join(
    test_base.tmpfiledir, 'train_00000.tfrecord'))


2021-06-19 21:40:57.758 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/train_00000.tfrecord


#### Pyspark write tfrecord test

In [None]:
# init spark
from pyspark import SparkContext, SparkConf
# NOTE(review): hard-coded absolute path; presumably this jar provides the
# org.tensorflow.hadoop TFRecord output format used by make_tfrecord_pyspark
# -- confirm and make the location configurable.
jar_path = '/data/m3tl/tmp/tensorflow-hadoop-1.10.0.jar'


conf = SparkConf().set('spark.jars', jar_path)

sc = SparkContext(conf=conf)


In [None]:
from m3tl.utils import set_is_pyspark
import tempfile
# Switch make_tfrecord to the pyspark writer; it stays enabled until a later
# cell calls set_is_pyspark(False).
set_is_pyspark(True)
test_features_rdd = sc.parallelize([test_features]).coalesce(1)
pyspark_dir = tempfile.mkdtemp()
# In pyspark mode the records go to <pyspark_dir>/train; output_dir is
# replaced inside make_tfrecord.
make_tfrecord(
    test_features_rdd, output_dir=test_base.tmpfiledir, serialize_fn=serialize_fn, pyspark_dir=pyspark_dir)


0

In [None]:
# check data
from m3tl.read_write_tfrecord import make_feature_desc
# NOTE(review): this reads the feature description written by the *local*
# write test into test_base.tmpfiledir, not the one the pyspark writer
# uploads next to its records -- confirm this is intended.
json_path = os.path.join(test_base.tmpfiledir, 'train_feature_desc.json')
feature_desc = make_feature_desc(json.load(open(json_path, 'r')))

tfrecord_path = os.path.join(pyspark_dir, 'train', 'part-r-00000')

tfr_dataset = tf.data.TFRecordDataset(tfrecord_path)


def _parse_fn(x):
    """Parse one serialized example with the feature description above."""
    return tf.io.parse_single_example(x, feature_desc)


tfr_dataset = tfr_dataset.map(_parse_fn)

# The parsed features are sparse tensors; densify before comparing.
for i in tfr_dataset.take(1):
    assert np.all(tf.sparse.to_dense(i['float_array']).numpy(
    ) == np.array([4., 5., 6.], dtype='float32'))


### Chain problems and write API

In [None]:
# export

def chain_processed_data(problem_preproc_gen_dict: Dict[str, Iterator]) -> Iterator:
    """Merge the preprocessed records of several problems record by record.

    With a single problem its data is returned untouched. In pyspark mode
    the RDDs are joined with ``join_dict_of_rdd``. Otherwise every problem's
    data is materialised as a list and the i-th records of all problems are
    merged into one dict (the record dict of the first problem is updated in
    place). Features present in the first record of every problem must have
    equal values in the records being merged.

    Arguments:
        problem_preproc_gen_dict {dict} -- problem name -> iterable of
            feature dicts (or RDD in pyspark mode)

    Raises:
        IndexError -- a problem has no data, or fewer records than the
            first problem
        AssertionError -- a shared feature differs between problems

    Returns:
        The single problem's data, an RDD, or a list of merged feature dicts
    """
    # problem chunk size is 1, return generator directly
    if len(problem_preproc_gen_dict) == 1:
        return next(iter(problem_preproc_gen_dict.values()))

    if get_is_pyspark():
        from m3tl.pyspark_utils import join_dict_of_rdd

        return join_dict_of_rdd(rdd_dict=problem_preproc_gen_dict)

    logger.warning('Chaining problems with & may consume a lot of memory if'
                   ' data is not pyspark RDD.')
    data_dict = {}
    column_list = []
    for pro, problem_data in problem_preproc_gen_dict.items():
        data_dict[pro] = listify(problem_data)
        if not data_dict[pro]:
            raise IndexError("Problem {} has no data".format(pro))
        column_list.append(list(data_dict[pro][0].keys()))

    # get intersection and use as ensure features are the same
    join_key = list(set(column_list[0]).intersection(*column_list[1:]))

    first_problem, *other_problems = data_dict
    first_data = data_dict[first_problem]
    for pro in other_problems:
        if len(data_dict[pro]) < len(first_data):
            raise IndexError(
                "Problem {} has {} records, fewer than the {} records of problem {}".format(
                    pro, len(data_dict[pro]), len(first_data), first_problem))
        if len(data_dict[pro]) > len(first_data):
            logger.warning(
                'Problem {} has more records than problem {}, extra records are ignored.'.format(
                    pro, first_problem))

    # Index into the lists instead of pop(0): linear instead of quadratic,
    # and list inputs are not emptied as a side effect.
    flat_data_list = []
    for idx, d in enumerate(first_data):
        for pro in other_problems:
            other = data_dict[pro][idx]
            for k in join_key:
                assert d[k] == other[k], 'At iteration {}, feature {} not align. Expected {}, got: {}'.format(
                    idx, k, d[k], other[k]
                )
            d.update(other)
        flat_data_list.append(d)
    return flat_data_list


def write_tfrecord(params: Params, replace=False):
    """Write TFRecord for every problem chunk

    For every problem chunk the train and eval data of its problems are
    read with ``params.read_data_fn``, chained and written with
    ``make_tfrecord``. The record count of the train split is stored as
    ``data_num`` problem info. In pyspark mode the problem info file is
    additionally merged with the executor's copy and uploaded.

    Output location: params.tmp_file_dir

    Arguments:
        params {params} -- params

    Keyword Arguments:
        replace {bool} -- Whether to replace existing tfrecord (default: {False})
    """

    read_data_fn_dict = params.read_data_fn
    for problem_list in params.problem_chunk:
        problem_str = '_'.join(sorted(problem_list))
        file_dir = os.path.join(params.tmp_file_dir, problem_str)
        if params.pyspark_output_path is not None:
            pyspark_dir = os.path.join(params.pyspark_output_path, problem_str)
        else:
            pyspark_dir = None

        # Existing output is kept unless the caller asks for a rewrite.
        if os.path.exists(file_dir) and not replace:
            continue

        for mode in [TRAIN, EVAL]:
            problem_preproc_gen_dict = {}
            for p in problem_list:
                problem_preproc_gen_dict[p] = read_data_fn_dict[p](
                    params=params, mode=mode)

            chained_data = chain_processed_data(problem_preproc_gen_dict)

            total_count = make_tfrecord(data_list=chained_data, output_dir=file_dir,
                                        mode=mode, serialize_fn=serialize_fn, pyspark_dir=pyspark_dir,
                                        example_per_file=params.example_per_file)
            if mode == TRAIN:
                params.set_problem_info(
                    problem=problem_str, info_name='data_num', info=total_count)

            if get_is_pyspark():
                from m3tl.pyspark_utils import Hdfs, get_text_file_from_executor
                # upload problem_info if pyspark
                local_problem_info_path = params.get_problem_info_path(problem_str)
                # Use a not-yet-existing path inside a private temporary
                # directory as scratch file: unlike the name of an already
                # deleted NamedTemporaryFile it cannot be claimed by another
                # process, and it is removed afterwards.
                # NOTE(review): assumes merge_problem_info_file reads the
                # file before returning -- confirm.
                with tempfile.TemporaryDirectory() as scratch_dir:
                    tempfile_name = os.path.join(scratch_dir, 'problem_info')
                    get_text_file_from_executor(local_problem_info_path, tempfile_name)
                    params.merge_problem_info_file(tempfile_name)
                Hdfs().copyFromLocalFile(local_problem_info_path, pyspark_dir)
                



In [None]:
# Back to local mode (an earlier cell enabled pyspark mode).
set_is_pyspark(False)
# Per the assertions below, `&`-joined problems share one output directory
# named after the sorted problem names; `|`-separated ones get their own.
test_base.params.assign_problem(
    'weibo_fake_ner&weibo_fake_cls|weibo_fake_multi_cls|weibo_masklm|weibo_premask_mlm', base_dir=test_base.tmpckptdir)
write_tfrecord(
    params=test_base.params, replace=True)
assert os.path.exists(os.path.join(
    test_base.tmpfiledir, 'weibo_fake_cls_weibo_fake_ner'))
assert os.path.exists(os.path.join(
    test_base.tmpfiledir, 'weibo_fake_multi_cls'))


2021-06-19 21:41:03.942 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/weibo_fake_cls_weibo_fake_ner/train_00000.tfrecord
2021-06-19 21:41:03.979 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/weibo_fake_cls_weibo_fake_ner/eval_00000.tfrecord
2021-06-19 21:41:04.005 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/weibo_fake_multi_cls/train_00000.tfrecord
2021-06-19 21:41:04.030 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/weibo_fake_multi_cls/eval_00000.tfrecord
2021-06-19 21:41:04.110 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/weibo_masklm/train_00000.tfrecord
2021-06-19 21:41:04.162 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/weibo_masklm/eval_00000.tfrecord
2021-06-19 21:41:04.229 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/weibo_premask_mlm/train_00000.tfrecord
2021-06-19 21:41:04.297 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/weibo_premask_mlm/eval

## Read TFRecords

In [None]:
# export
def make_feature_desc(feature_desc_dict: dict):
    """Build the parsing spec for ``tf.io.parse_single_example``.

    Entries whose value is 'int64' or 'float32' become ``tf.io.VarLenFeature``
    of the matching dtype; every other entry (e.g. 'string' features and the
    ``*_shape_value`` lists) is left out.
    """
    dtype_by_name = (('int64', tf.int64), ('float32', tf.float32))

    parsing_spec = {}
    for name, type_name in feature_desc_dict.items():
        # Compare with == (no dict lookup): values may be unhashable lists.
        for known_name, dtype in dtype_by_name:
            if type_name == known_name:
                parsing_spec[name] = tf.io.VarLenFeature(dtype)
                break
    return parsing_spec


def reshape_tensors_in_dataset(example, feature_desc_dict: dict):
    """Reshape serialized tensor back to its original shape

    Every value in ``example`` is densified first. Each feature ``name`` is
    then reshaped with its companion ``name_shape`` tensor, and the
    ``*_shape`` companions are removed from the example afterwards.
    Companions are recognised by the ``_shape`` suffix, so features that
    merely contain ``_shape`` somewhere in their name are kept.

    Arguments:
        example {Example} -- Example
        feature_desc_dict {dict} -- must provide a ``<name>_shape_value``
            entry without None for every feature; it is the shape of the
            zero tensor used for an empty scalar feature

    Returns:
        Example -- Example
    """

    def _is_shape_key(feature_key):
        return feature_key.endswith('_shape')

    for feature_key in example:
        example[feature_key] = tf.sparse.to_dense(example[feature_key])

    @tf.function
    def _reshape_tensor(tensor: tf.Tensor, shape_tensor: tf.Tensor, shape_tensor_in_dict: tf.Tensor):
        """
        avoid empty tensor reshape error

        we need to fill tensor with zeros to make sure 
        that loss multiplier aligns with features correctly
        """
        if tf.equal(tf.size(tensor), 0):
            # scalar
            if tf.equal(tf.size(shape_tensor), 0):
                return tf.zeros(shape=shape_tensor_in_dict, dtype=tensor.dtype)
            else:
                return tf.zeros(shape=shape_tensor, dtype=tensor.dtype)

        return tf.reshape(tensor, shape=shape_tensor)

    for feature_key in example:
        if _is_shape_key(feature_key):
            continue

        shape_tensor = example['{}_shape'.format(feature_key)]
        shape_tensor_in_dict = tf.convert_to_tensor(
            feature_desc_dict[feature_key+'_shape_value'], dtype=tf.int32)

        example[feature_key] = _reshape_tensor(
            example[feature_key], shape_tensor, shape_tensor_in_dict)

    # Iterate over a copy of the keys because entries are deleted.
    for feature_key in list(example.keys()):
        if _is_shape_key(feature_key):
            del example[feature_key]

    return example


def add_loss_multiplier(example, problem):  # pragma: no cover
    """Ensure ``example`` has a ``<problem>_loss_multiplier`` entry.

    An existing entry is kept; a missing one is set to a scalar int32
    constant 1. The (mutated) example is returned.
    """
    multiplier_key = '{}_loss_multiplier'.format(problem)
    if multiplier_key in example:
        return example

    example[multiplier_key] = tf.constant(value=1, shape=(), dtype=tf.int32)
    return example


def set_shape_for_dataset(example, feature_desc_dict):  # pragma: no cover
    """Apply the recorded static shape to every tensor of ``example``.

    For each feature ``name`` the value of ``feature_desc_dict['<name>_shape_value']``
    is passed to the tensor's ``set_shape``; the example is returned.
    """
    for name, tensor in example.items():
        static_shape = feature_desc_dict['{}_shape_value'.format(name)]
        tensor.set_shape(static_shape)
    return example


def get_dummy_features(dataset_dict, feature_desc_dict):
    """Build zero-valued placeholder tensors for problem-specific features.

    Placeholders let every feature dict expose the same keys no matter
    which problem it was sampled from.

    Example:
        problem A: {'input_ids': [1,2,3], 'A_label_ids': 1}
        problem B: {'input_ids': [1,2,3], 'B_label_ids': 2}

    Then dummy features:
        {'A_label_ids': 0, 'B_label_ids': 0}

    If problem A is sampled at some iteration:
        feature dict without dummy:
            {'input_ids': [1,2,3], 'A_label_ids': 1}
        feature dict with dummy:
            {'input_ids': [1,2,3], 'A_label_ids': 1, 'B_label_ids':0}

    Arguments:
        dataset_dict {dict} -- dict of datasets of all problems
        feature_desc_dict {dict} -- provides ``<name>_shape_value``; None
            dimensions become 1 and a missing entry means a scalar

    Returns:
        dummy_features -- dict of dummy tensors
    """
    datasets = list(dataset_dict.values())
    key_lists = [list(ds.element_spec.keys()) for ds in datasets]
    # Features every problem has never need a placeholder.
    shared_keys = set(key_lists[0]).intersection(*key_lists[1:])

    dummy_features = {}
    for ds in datasets:
        for name, spec in ds.element_spec.items():
            if name in shared_keys:
                continue
            static_shape = feature_desc_dict.get(
                '{}_shape_value'.format(name), [])
            dummy_shape = [1 if dim is None else dim for dim in static_shape]
            zeros = tf.constant(shape=dummy_shape, value=0)
            dummy_features[name] = tf.cast(zeros, spec.dtype)

    return dummy_features


def add_dummy_features_to_dataset(example, dummy_features):  # pragma: no cover
    """Fill features missing from ``example`` with their dummy tensors.

    feature dict without dummy:
        {'input_ids': [1,2,3], 'A_label_ids': 1}
    feature dict with dummy:
        {'input_ids': [1,2,3], 'A_label_ids': 1, 'B_label_ids':0}

    Features already present are left untouched; the example is returned.

    Arguments:
        example {data example} -- dataset example
        dummy_features {dict} -- dict of dummy tensors
    """
    for name, dummy_tensor in dummy_features.items():
        if name in example:
            continue
        example[name] = tf.identity(dummy_tensor)
    return example


def read_tfrecord(params: Params, mode: str):
    """Read and parse TFRecord for every problem

    The returned dataset is parsed, reshaped, to_dense tensors
    with dummy features.

    For every problem chunk the feature description is looked up at
    ``<tmp_file_dir>/<problem>/<mode>_feature_desc.json`` (local layout,
    records matching ``<mode>_*.tfrecord``). If that file does not exist the
    pyspark layout ``<tmp_file_dir>/<problem>/<mode>/`` is used instead
    (records matching ``part*``).

    Arguments:
        params {params} -- params
        mode {str} -- mode, train, eval or predict

    Returns:
        dict -- dict with keys: problem name, values: dataset
    """
    dataset_dict = {}
    all_feature_desc_dict = {}
    desc_file_name = '{}_feature_desc.json'.format(mode)
    for problem_list in params.problem_chunk:
        problem = '_'.join(sorted(problem_list))
        file_dir = os.path.join(params.tmp_file_dir, problem)

        feature_desc_path = os.path.join(file_dir, desc_file_name)
        if os.path.exists(feature_desc_path):
            tfrecord_pattern = os.path.join(
                file_dir, '{}_*.tfrecord'.format(mode))
        else:
            # pyspark path is different
            feature_desc_path = os.path.join(file_dir, mode, desc_file_name)
            tfrecord_pattern = os.path.join(file_dir, mode, 'part*')
        tfrecord_path_list = glob(tfrecord_pattern)
        # ``with`` closes the description file instead of leaking the handle.
        with open(feature_desc_path) as desc_file:
            feature_desc_dict = json.load(desc_file)

        all_feature_desc_dict.update(feature_desc_dict)
        feature_desc = make_feature_desc(feature_desc_dict)
        dataset = tf.data.TFRecordDataset(
            tfrecord_path_list, num_parallel_reads=tf.data.experimental.AUTOTUNE)
        # when using hvd, we need to shard dataset
        if params.use_horovod:
            import horovod.tensorflow.keras as hvd
            dataset = dataset.shard(hvd.size(), hvd.rank())
        dataset = dataset.map(lambda x: tf.io.parse_single_example(
            serialized=x, features=feature_desc), num_parallel_calls=tf.data.experimental.AUTOTUNE)

        # Reshaping needs concrete sizes: replace None dimensions with 1.
        feature_desc_dict_replace_none = {}
        for name, desc in feature_desc_dict.items():
            if not isinstance(desc, list):
                feature_desc_dict_replace_none[name] = desc
            else:
                desc_without_none = [i if i is not None else 1 for i in desc]
                feature_desc_dict_replace_none[name] = desc_without_none

        dataset = dataset.map(
            lambda x: reshape_tensors_in_dataset(x, feature_desc_dict_replace_none),
            num_parallel_calls=tf.data.experimental.AUTOTUNE).map(  # pylint: disable=no-member
            lambda x: set_shape_for_dataset(
                x, feature_desc_dict),
            num_parallel_calls=tf.data.experimental.AUTOTUNE  # pylint: disable=no-member
        )
        for p in problem_list:
            dataset = dataset.map(lambda x: add_loss_multiplier(x, p),
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset_dict[problem] = dataset

    # add dummy features
    dummy_features = get_dummy_features(dataset_dict, all_feature_desc_dict)
    for problem in params.get_problem_chunk(as_str=True):
        dataset_dict[problem] = dataset_dict[problem].map(
            lambda x: add_dummy_features_to_dataset(x, dummy_features),
            num_parallel_calls=tf.data.experimental.AUTOTUNE
        )
    return dataset_dict


### Local read tfrecord test

In [None]:
# hide
# Write (only if missing, replace=False) and read back the train split,
# then check the merged feature keys of the chained problems.
test_base.params.assign_problem(
    'weibo_fake_ner&weibo_fake_cls|weibo_fake_multi_cls|weibo_masklm|weibo_premask_mlm', base_dir=test_base.tmpckptdir)
write_tfrecord(
    params=test_base.params, replace=False)
dataset_dict = read_tfrecord(
    params=test_base.params, mode='train')
dataset: tf.data.Dataset = dataset_dict['weibo_fake_cls_weibo_fake_ner']
# The spec also contains the dummy features of the other problem chunks.
assert sorted(list(dataset.element_spec.keys())) == [
    'array_input_ids',
    'array_mask',
    'array_segment_ids',
    'cate_input_ids',
    'cate_mask',
    'cate_segment_ids',
    'masked_lm_ids',
    'masked_lm_positions',
    'masked_lm_weights',
    'text_input_ids',
    'text_mask',
    'text_segment_ids',
    'weibo_fake_cls_label_ids',
    'weibo_fake_cls_loss_multiplier',
    'weibo_fake_multi_cls_label_ids',
    'weibo_fake_multi_cls_loss_multiplier',
    'weibo_fake_ner_label_ids',
    'weibo_fake_ner_loss_multiplier',
    'weibo_masklm_loss_multiplier',
    'weibo_premask_mlm_loss_multiplier',
    'weibo_premask_mlm_masked_lm_ids',
    'weibo_premask_mlm_masked_lm_positions',
    'weibo_premask_mlm_masked_lm_weights']
# make sure loss multiplier is correct
# (1 for the problems of this chunk, 0 for the dummy of another problem)
ele = next(dataset.as_numpy_iterator())
assert ele['weibo_fake_cls_loss_multiplier'] == 1
assert ele['weibo_fake_ner_loss_multiplier'] == 1
assert ele['weibo_fake_multi_cls_loss_multiplier'] == 0

# multimodal dataset
dataset: tf.data.Dataset = dataset_dict['weibo_fake_multi_cls']
_ = next(dataset.as_numpy_iterator())




### Pyspark read tfrecord test

NOTE: Test pyspark generated tfrecord

In [None]:
from m3tl.test_base import PysparkTestBase
pyspark_test_base = PysparkTestBase()

# Name of the first problem chunk, i.e. its output sub-directory.
problem_chunk_str = pyspark_test_base.params.get_problem_chunk(as_str=True)[0]

# Writing in pyspark mode must also upload the problem info file.
write_tfrecord(params=pyspark_test_base.params, replace=True)
assert os.path.exists(os.path.join(
    pyspark_test_base.params.pyspark_output_path, problem_chunk_str, 'problem_info.txt'))


2021-06-19 21:41:07.551 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem pyspark_fake_seq_tag, problem type: seq_tag
2021-06-19 21:41:07.552 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem pyspark_fake_multi_cls, problem type: multi_cls
2021-06-19 21:41:07.552 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem pyspark_fake_cls, problem type: cls
2021-06-19 21:41:13.188 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train
