In [22]:
# imports ...
import os
import tensorflow as tf
import numpy as np
import collections
from tqdm import tqdm

In [2]:
# Directory that holds the packed TFRecord files.
path = '/mnt/gdn-workloads/ihubara/tfrecord_dir_packed'
# Derive the listing from `path` (previously the directory was hard-coded a
# second time inside a `! ls` shell call, so the two could drift apart).
# Hidden entries are skipped and names are sorted, mirroring plain `ls`;
# NOTE(review): `ls` sorts by locale collation, `sorted` by code point --
# confirm the order is equivalent for the file names in this directory.
list_ = sorted(name for name in os.listdir(path) if not name.startswith('.'))

In [16]:
def record2dict(record:bytes) -> collections.OrderedDict:
    """Decode one serialized ``tf.train.Example`` into NumPy arrays.

    Args:
        record: The serialized example, either as raw ``bytes`` or as an
            object exposing ``.numpy()`` that returns those bytes (as the
            elements iterated from ``tf.data.TFRecordDataset`` do).

    Returns:
        An ``OrderedDict`` mapping, in this order, ``input_ids``,
        ``input_mask``, ``segment_ids``, ``masked_lm_positions``,
        ``masked_lm_ids``, ``masked_lm_weights`` and ``next_sentence_labels``
        to ``np.ndarray`` values. ``masked_lm_weights`` is read from the
        feature's float list, every other key from its int64 list.
    """
    # Unwrap tensor-like inputs, but also honour the `bytes` annotation.
    serialized = record.numpy() if hasattr(record, 'numpy') else record
    example = tf.train.Example()
    example.ParseFromString(serialized)
    # The Example message exposes its feature map as `features.feature`;
    # `example.feature` is not a field of the proto and raises AttributeError.
    feature = example.features.feature

    # (key, name of the typed list to read) in output order.
    fields = (
        ('input_ids', 'int64_list'),
        ('input_mask', 'int64_list'),
        ('segment_ids', 'int64_list'),
        ('masked_lm_positions', 'int64_list'),
        ('masked_lm_ids', 'int64_list'),
        ('masked_lm_weights', 'float_list'),
        ('next_sentence_labels', 'int64_list'),
    )
    result = collections.OrderedDict()
    for key, list_name in fields:
        result[key] = np.array(getattr(feature[key], list_name).value)
    return result

def packed_record2dict(record:bytes) -> collections.OrderedDict:
    """Decode one serialized packed ``tf.train.Example`` into NumPy arrays.

    Args:
        record: The serialized example, either as raw ``bytes`` or as an
            object exposing ``.numpy()`` that returns those bytes (as the
            elements iterated from ``tf.data.TFRecordDataset`` do).

    Returns:
        An ``OrderedDict`` mapping, in this order, ``input_ids``,
        ``input_mask``, ``segment_ids``, ``positions``,
        ``masked_lm_positions``, ``masked_lm_ids``, ``masked_lm_weights``,
        ``next_sentence_positions``, ``next_sentence_labels`` and
        ``next_sentence_weights`` to ``np.ndarray`` values. The two
        ``*_weights`` keys are read from the feature's float list, every
        other key from its int64 list.
    """
    # Unwrap tensor-like inputs, but also honour the `bytes` annotation.
    serialized = record.numpy() if hasattr(record, 'numpy') else record
    example = tf.train.Example()
    example.ParseFromString(serialized)
    # The Example message exposes its feature map as `features.feature`;
    # `example.feature` is not a field of the proto and raises AttributeError.
    feature = example.features.feature

    # (key, name of the typed list to read) in output order.
    fields = (
        ('input_ids', 'int64_list'),
        ('input_mask', 'int64_list'),
        ('segment_ids', 'int64_list'),
        ('positions', 'int64_list'),
        ('masked_lm_positions', 'int64_list'),
        ('masked_lm_ids', 'int64_list'),
        ('masked_lm_weights', 'float_list'),
        ('next_sentence_positions', 'int64_list'),
        ('next_sentence_labels', 'int64_list'),
        ('next_sentence_weights', 'float_list'),
    )
    result = collections.OrderedDict()
    for key, list_name in fields:
        result[key] = np.array(getattr(feature[key], list_name).value)
    return result

def get_next_sentence_weights(record:bytes) -> np.ndarray:
    """Return only the ``next_sentence_weights`` feature of a serialized example.

    Args:
        record: The serialized ``tf.train.Example``, either as raw ``bytes``
            or as an object exposing ``.numpy()`` that returns those bytes
            (as the elements iterated from ``tf.data.TFRecordDataset`` do).

    Returns:
        ``np.ndarray`` built from the float list stored under the
        ``next_sentence_weights`` key.
    """
    # Unwrap tensor-like inputs, but also honour the `bytes` annotation.
    serialized = record.numpy() if hasattr(record, 'numpy') else record
    example = tf.train.Example()
    example.ParseFromString(serialized)
    # The Example message exposes its feature map as `features.feature`;
    # `example.feature` is not a field of the proto and raises AttributeError.
    weights = example.features.feature['next_sentence_weights']
    return np.array(weights.float_list.value)

In [None]:
# Count the total number of records across all TFRecord files in `path`.
counter = 0
for file in tqdm(list_):
    records = tf.data.TFRecordDataset(os.path.join(path, file))
    # Stream through the dataset rather than materialising every record in a
    # Python list only to take its length; the resulting count is identical.
    counter += sum(1 for _ in records)

In [None]:
# Compute the average number of samples per record (avg_seq_per_pack).

# Only used as the progress-bar total.
# NOTE(review): presumably the record count produced by the counting cell --
# confirm it still matches the files under `path`.
NUM_RECORDS = 4746826

pack_numbers = []
# The context manager closes the bar even if reading a file raises midway.
with tqdm(total=NUM_RECORDS) as pbar:
    for file in list_:
        records = tf.data.TFRecordDataset(os.path.join(path, file))
        for record in records:
            # The sum of a record's next-sentence weights is what this cell
            # treats as the number of samples in that record.
            next_sentence_weights = get_next_sentence_weights(record)
            pack_numbers.append(next_sentence_weights.sum())
            pbar.update(1)

# Finish the computation the cell describes; NaN when no record was read.
avg_seq_per_pack = float(np.mean(pack_numbers)) if pack_numbers else float('nan')
avg_seq_per_pack

In [23]:
# Parse the first record of the first listed file for manual inspection.
first_file = os.path.join(path, list_[0])
records = tf.data.TFRecordDataset(first_file)
first_record = next(iter(records))
dict_ = record2dict(first_record)

In [24]:
# Show the 'masked_lm_positions' array of the parsed record.
dict_['masked_lm_positions']

array([  5,   6,   7,   8,  14,  22,  31,  36,  37,  53,  55,  59,  62,
        73,  87, 101, 102, 105, 120, 130, 131, 133, 135, 161, 162, 164,
       168, 171, 180, 198, 205, 219, 230, 235, 237, 238, 241, 249, 251,
       263, 270, 271, 288, 289, 293, 300, 303, 304, 310, 313, 315, 320,
       321, 324, 328, 347, 353, 356, 366, 369, 378, 384, 390, 395, 403,
       404, 407, 431, 448, 451, 456, 465, 484, 491, 493, 495, 507,   0,
         0])

In [25]:
# Show the 'masked_lm_ids' array of the parsed record.
dict_['masked_lm_ids']

array([10578,  2003,  1037,  7960,  1010,  2094, 11842,  4856,  1998,
        2062,  1010,  1011,  2018, 11762,  7170,  1996, 12731,  4221,
        2029, 15185,  6590,  1006,  3126,  2988,  1010,  1037,  2006,
        2340, 15185,  1024,  1024,  1996,  1012,  2470,  3919,  3934,
        2752,  1998,  1999,  2047,  8324,  1010,  2015,  1010,  2464,
        3934,  4981,  1998,  1038,  1012,  2504,  1006,  4816,  3330,
        5992,  1006,  5992,  4806,  1007,  1997,  4513,  1997,  2008,
        2007,  2095,  1010,  1998,  1997,  8065,  2048,  1998,  5179,
        2627,  4249,  1038,  1055,  2089,     0,     0])

In [21]:
# Show the 'masked_lm_weights' array of the parsed record.
# NOTE(review): the saved output holds 1., 2. and 3. besides 0. -- presumably
# the index of the packed sequence each mask belongs to; confirm.
dict_['masked_lm_weights']

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 3., 3., 3., 3.,
       3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
       3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
       3., 3., 3., 3., 3., 3., 3., 3., 3., 0., 0.])