Roughly what the code does is as follows:
- Duplicate `num_passes` times to get random chunks the full length of the corpus `N` which are concatenated
- Get a tensor of size `batch_size x ((N * num_passes) // batch_size)`
- Split each row into chunks of size `tgt_len`
- Divide the tokens in each chunk into bins

In [375]:
import pickle

from collections import Counter, OrderedDict
import json
import numpy as np
import tensorflow as tf
import os

class Vocab(object):
    def __init__(self, special=[], min_freq=0, max_size=None,
                 lower_case=True, delimiter=None, vocab_file=None):
        self.counter = Counter()
        self.special = special
        self.min_freq = min_freq
        self.max_size = max_size
        self.lower_case = lower_case
        self.delimiter = delimiter
        self.vocab_file = vocab_file

    def tokenize(self, line, add_eos=False, add_double_eos=False):
        line = line.strip()
        if self.lower_case:
            line = line.lower()

        if self.delimiter == "":
            symbols = line
        else:
            # If None splits by space 
            symbols = line.split(self.delimiter)

        if add_double_eos:  # lm1b
            return ['<S>'] + symbols + ['<S>']
        elif add_eos:
            return symbols + ['<eos>']
        else:
            return symbols

    def count_file(self, path, verbose=False, add_eos=False):
        if verbose: print('counting file {} ...'.format(path))
        assert tf.io.gfile.exists(path)

        sents = []
        with open(path, 'r') as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    print(' line {}'.format(idx))
                # Split each sentence into token and increment
                # counter for each token
                symbols = self.tokenize(line, add_eos=add_eos)
                self.counter.update(symbols)
                sents.append(symbols)

        return sents

    def count_sents(self, sents, verbose=False):
        if verbose: print('counting {} sents ...'.format(len(sents)))
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                print(' line {}'.format(idx))
            self.counter.update(symbols)

    def _build_from_file(self, vocab_file):
        self.idx2sym = []
        self.sym2idx = OrderedDict()

        with open(vocab_file, 'r') as f:
            for line in f:
                symb = line.strip().split()[0]
                self.add_symbol(symb)
        self.unk_idx = self.sym2idx['<UNK>']


    def build_vocab(self):
        if self.vocab_file:
            print('building vocab from {}'.format(self.vocab_file))
            self._build_from_file(self.vocab_file)
            print('final vocab size {}'.format(len(self)))

        else:
            print('building vocab with min_freq={}, max_size={}'.format(
                self.min_freq, self.max_size
            ))
            self.idx2sym = []
            self.sym2idx = OrderedDict()

            for sym in self.special:
                self.add_special(sym)

            for sym, cnt in self.counter.most_common(self.max_size):
                # Works because words in most_common are
                # in descending order of frequency so you
                # can just stop if count is below the min value
                if cnt < self.min_freq: break
                self.add_symbol(sym)

            print('final vocab size {} from {} unique tokens'.format(
                len(self), len(self.counter))
            )

    def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False):
        # Reads file, tokenises each line, converts tokens to symbols
        if verbose: print('encoding file {} ...'.format(path))
        assert tf.io.gfile.exists(path)
        encoded = []
        with open(path, 'r') as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    print('  line {}'.format(idx))
                symbols = self.tokenize(line, add_eos=add_eos, add_double_eos=add_double_eos)
                encoded.append(self.convert_to_nparray(symbols))

        if ordered:
            encoded = np.concatenate(encoded)

        return encoded

    def encode_sents(self, sents, ordered=False, verbose=False):
        if verbose: print('encoding {} sents ...'.format(len(sents)))
        encoded = []
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                print(' line {}'.format(idx))
            encoded.append(self.convert_to_nparray(symbols))

        if ordered:
            encoded = np.concatenate(encoded)

        return encoded

    def add_special(self, sym):
        # Essentially same as add_symbol
        # but also makes the symbol
        # an attribute
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1
            setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])

    def add_symbol(self, sym):
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1

    def get_sym(self, idx) -> str:
        assert 0 <= idx < len(self), 'Index {} out of range'
        return self.idx2sym[idx]

    def get_idx(self, sym):
        if sym in self.sym2idx:
            return self.sym2idx[sym]
        else:
            hasattr(self, 'unk_idx')
            return self.sym2idx.get(sym, self.unk_idx)

    def get_symbols(self, indices):
        return [self.get_sym(idx) for idx in indices]

    def get_indices(self, symbols):
        return [self.get_idx(sym) for sym in symbols]

    def convert_to_nparray(self, symbols):
        nparray = np.array(
            self.get_indices(symbols), dtype=np.int64
        )
        return nparray

    def convert_to_sent(self, indices, exclude=None):
        if exclude is None:
            return ' '.join([self.get_sym(idx) for idx in indices])
        else:
            return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])

    def __len__(self):
        return len(self.idx2sym)

In [376]:
kwargs = dict()
# I am using settings given for WikiText-103
kwargs["special"] = ["<eos>"]
kwargs["lower_case"] = False

In [377]:
vocab = Vocab(**kwargs)

In [378]:
!trash data_tmp/cache.pkl

In [379]:
x = vocab.count_file('train_pride_and_prejudice.txt')

In [380]:
vocab.build_vocab()

building vocab with min_freq=0, max_size=None
final vocab size 6702 from 6701 unique tokens


In [381]:
x = vocab.encode_file('train_pride_and_prejudice.txt', ordered=True)

In [382]:
class Corpus(object):
    def __init__(self, path, dataset, *args, **kwargs):
        # In all of this I am using settings given for WikiText-103

        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)
        self.vocab.count_file(
            os.path.join(path, "train.txt")
        )

        self.vocab.build_vocab()

        for fold in ['train', 'valid', 'test']:
            setattr(self, fold,
                    self.vocab.encode_file(os.path.join(path, f"{fold}.txt"), ordered=True))

        # TODO: maybe these should be different?
        self.cutoffs = [0, 20000, 40000, 200000] + [len(self.vocab)]

    def convert_to_tfrecords(self, split, save_dir, bsz, tgt_len,
                             num_core_per_host, **kwargs):
        FLAGS = kwargs.get('FLAGS')

        file_names = []
        use_tpu = FLAGS.use_tpu and not (split == "test" and num_core_per_host == 1)

        if use_tpu:
            record_name = "record_info-{}.bsz-{}.tlen-{}.core-{}.json".format(
                split, bsz, tgt_len, num_core_per_host)
        else:
            record_name = "record_info-{}.bsz-{}.tlen-{}.json".format(
                split, bsz, tgt_len)

        record_info_path = os.path.join(save_dir, record_name)

        data = getattr(self, split)

        bin_sizes = get_bin_sizes(data, bsz // num_core_per_host, tgt_len, self.cutoffs)
        file_name, num_batch = create_ordered_tfrecords(
            save_dir, split, data, bsz, tgt_len, num_core_per_host,
            self.cutoffs, bin_sizes,
            num_passes=FLAGS.num_passes if split == 'train' and use_tpu else 1,
            use_tpu=use_tpu)
        file_names.append(file_name)

        with open(record_info_path, "w") as fp:
            record_info = {
                "filenames": file_names,
                "bin_sizes": bin_sizes,
                "num_batch": num_batch
            }
            json.dump(record_info, fp)

In [383]:
def get_lm_corpus(data_dir, dataset):
    fn = os.path.join(data_dir, "cache.pkl")

    if tf.io.gfile.exists(fn):
        print("Loading cached dataset")
        with open(fn, "rb") as fp:
            corpus = pickle.load(fp)

    else:
        print("Producing dataset...")
        kwargs = dict()
        # I am using settings given for WikiText-103
        kwargs["special"] = ["<eos>", '<pad>', '<unk>']
        kwargs["lower_case"] = False
        corpus = Corpus(data_dir, dataset, **kwargs)

        print("Saving dataset...")

        with open(fn, "wb") as fp:
            # TODO: why protocol = 2 ?
            pickle.dump(corpus, fp, protocol=2)

        corpus_info = dict(
            vocab_size=len(corpus.vocab),
            cutoffs=corpus.cutoffs,
            dataset=corpus.dataset
        )

        with open(os.path.join(data_dir, "corpus-info.json"), "w") as fp:
            json.dump(corpus_info, fp)

    return corpus

In [384]:
from easydict import EasyDict

In [385]:
# for fold in ['valid', 'test']:
#     with open(f'data_tmp/{fold}.txt') as f:
#         x = f.read()[:1000]

#     with open(f'data_tmp/{fold}.txt', 'w') as f:
#         f.write(x)


In [386]:
FLAGS = EasyDict(
    data_dir='data_tmp',
    dataset='jane_austen',
    per_host_test_bsz=0,
    num_hosts=1,
    num_core_per_host=8,
    per_host_train_bsz=2,
    per_host_valid_bsz=2,
    tgt_len=10,
)

In [387]:
corpus = get_lm_corpus(FLAGS.data_dir, FLAGS.dataset)

Producing dataset...
building vocab with min_freq=0, max_size=None
final vocab size 121 from 118 unique tokens
Saving dataset...


In [388]:
file_names = []
record_name = "record_info-{}.bsz-{}.tlen-{}.core-{}.json".format(
                'train', FLAGS.per_host_train_bsz, FLAGS.tgt_len, 
    FLAGS.num_core_per_host)
record_name

'record_info-train.bsz-2.tlen-10.core-8.json'

In [389]:
def _nearest_multiple_of_eight(x):
    # i.e. nearest that is greater than 0
    y = x - x % 8
    return y + 8 if x % 8 >= 4 else max(8, y)

In [390]:
bin_sizes = []
num_batch = int(np.ceil(len(corpus.train) // FLAGS.per_host_train_bsz // FLAGS.tgt_len))
total = FLAGS.per_host_train_bsz * num_batch * FLAGS.tgt_len

In [391]:
len(corpus.train), num_batch, total

(245, 12, 240)

In [392]:
if total - len(corpus.train) > 0:
    data = np.concatenate([corpus.train, corpus.vocab.pad_idx * np.ones(total - len(corpus.train))])
else:
    data = corpus.train[:total]
data_tmp = data.reshape(FLAGS.per_host_train_bsz, num_batch, FLAGS.tgt_len)
data_tmp.shape

(2, 12, 10)

In [393]:
len(corpus.train), np.product(data_tmp.shape)

(245, 240)

data

In [394]:
tot = FLAGS.per_host_train_bsz * FLAGS.tgt_len

In [395]:
import math

Is this for model parallel? The data is getting split based on value of token. Then you estimate how what fraction tokens there are in each token range. Each batch would have `FLAGS.tgt_len * FLAGS.per_host_train_bsz` tokens. You multiply this by a factor `mean + std_mult[b] * std` to get an estimate for the bin size.

In [396]:
std_mult=[2.5, 2.5, 2.5]
for b, (left, right) in enumerate(zip(corpus.cutoffs[1:-1], corpus.cutoffs[2:])):
    mask = (data_tmp >= left) * (data_tmp < right)
    percents = mask.astype(np.float64).sum(2).sum(0) / tot
    mean = np.mean(percents)
    std = np.std(percents)

    bin_size = int(math.ceil(FLAGS.tgt_len * FLAGS.per_host_train_bsz * (mean + std_mult[b] * std)))
    bin_size = _nearest_multiple_of_eight(bin_size)
    bin_sizes.append(bin_size)

In [397]:
batch_size = FLAGS.per_host_train_bsz
tgt_len = FLAGS.tgt_len
num_core_per_host = FLAGS.num_core_per_host
basename = 'train'

In [398]:
file_name = "{}.bsz-{}.tlen-{}.core-{}.tfrecords".format(
        basename, batch_size, tgt_len, num_core_per_host)
file_name

'train.bsz-2.tlen-10.core-8.tfrecords'

In [399]:
save_path = os.path.join('data_tmp/', file_name)
record_writer = tf.io.TFRecordWriter(save_path)

In [400]:
num_passes = 2
data_len = len(data)
double_data = np.concatenate([data, data])
data_list = list()

Note that this step shuffles the data in a wrap-around fashion so that order is maintained except where the corpus ends and restarts.

In [401]:
for i in range(num_passes):
    start = np.random.randint(0, data_len)
    data_list.append(double_data[start:start+data_len])

In [402]:
data2 = np.concatenate(data_list)

In [403]:
num_step = (len(data2) // batch_size)
#assert num_step == len(data2) / batch_size
data2 = data2[:batch_size * num_step]
data2 = data2.reshape((batch_size, num_step))

In [404]:
def as_text(x):
    x = [' '.join(corpus.vocab.get_symbols(i.astype('int'))) for i in x]
    return '\n\n'.join(x)
        

In [405]:
as_text(data2[:1])

'she ; " for Mrs. Long has just been here , and <eos> she told me all about it . " <eos> <eos> Mr. Bennet made no answer . <eos> <eos> " Do you not want to know who has taken it ? " cried his wife <eos> impatiently . <eos> <eos> " _ You _ want to tell me , and I have no objection to hearing it . " <eos> <eos> This was invitation enough . <eos> <eos> " Why , my dear , you must know , Mrs. Long says that Netherfield is <eos> taken by a young man of Chapter 1 <eos> <eos> It is a truth universally acknowledged , that a single man in <eos> possession of a good fortune , must be in want of a wife . <eos> <eos> However little known the feelings or views of such a man may be <eos> on his first entering a neighbourhood , this truth is so well <eos> fixed in the minds of the surrounding families , that he is <eos> considered the rightful property of some one or other of their <eos> daughters . <eos> <eos> " My dear Mr. Bennet , " said his lady to him one day , " have you <eos> heard that Netherf

In [406]:
data2[:1]

array([[ 33,  90,   3,  91,  34,  35,  36,  92,  93,  94,   4,  37,   0,
         33,  95,  38,  96,  97,  12,   5,   3,   0,   0,  17,  18,  98,
         39,  99,   5,   0,   0,   3, 100,  19,  32,  15,  11,  40, 101,
         36,  41,  12,  31,   3, 102,  16,  24,   0, 103,   5,   0,   0,
          3,  42, 104,  42,  15,  11, 105,  38,   4,  37, 106,  29,  39,
        107,  11, 108,  12,   5,   3,   0,   0, 109, 110, 111, 112,   5,
          0,   0,   3, 113,   4, 114,  28,   4,  19,  22,  40,   4,  34,
         35, 115,   9,  30,   8,   0,  41, 116,   6, 117,  13,   7,  43,
         44,   0,   0,  45,   8,   6,  20,  46,  47,   4,   9,   6,  48,
         13,  14,   0,  49,   7,   6,  50,  21,   4,  22,  23,  14,  15,
          7,   6,  24,   5,   0,   0,  51,  52,  53,  10,  54,  25,  55,
          7,  56,   6,  13,  57,  23,   0,  58,  16,  59,  60,   6,  61,
          4,  62,  20,   8,  63,  64,   0,  65,  14,  10,  66,   7,  10,
         67,  68,   4,   9,  26,   8,   0,  69,  10

In [407]:
for l in [' '.join(corpus.vocab.get_symbols(i.astype('int'))) for i in data2]:
    print(l + '\n\n')

she ; " for Mrs. Long has just been here , and <eos> she told me all about it . " <eos> <eos> Mr. Bennet made no answer . <eos> <eos> " Do you not want to know who has taken it ? " cried his wife <eos> impatiently . <eos> <eos> " _ You _ want to tell me , and I have no objection to hearing it . " <eos> <eos> This was invitation enough . <eos> <eos> " Why , my dear , you must know , Mrs. Long says that Netherfield is <eos> taken by a young man of Chapter 1 <eos> <eos> It is a truth universally acknowledged , that a single man in <eos> possession of a good fortune , must be in want of a wife . <eos> <eos> However little known the feelings or views of such a man may be <eos> on his first entering a neighbourhood , this truth is so well <eos> fixed in the minds of the surrounding families , that he is <eos> considered the rightful property of some one or other of their <eos> daughters . <eos> <eos> " My dear Mr. Bennet , " said his lady to him one day , " have you <eos> heard that Netherfi

In [408]:
num_batch = 0

In [409]:
data2.shape

(2, 240)

What this does:
- Splits each batch element into chunks of size `tgt_len` dropping the any remainder whose length is less that `tgt_len`
- Inputs and labels otherwise as usual (label shifted ahead by one)

TODO:
- [x] Why are there double `<eos>` symbols here? - because of blank lines which get converted to a single `<eos>`

In [410]:
for t in range(0, data2.shape[1] - 1, tgt_len):
    cur_tgt_len = min(data2.shape[1] - 1 - t, tgt_len)
    print(f'tgt_len={tgt_len}, data2.shape[1] - 1 - t = {data2.shape[1] - 1 - t}, cur_tgt_len = {cur_tgt_len}\n')
    if cur_tgt_len < tgt_len:
        print('Breaking')
        break
        
    for idx in range(batch_size):
        inputs = data2[idx, t:t + cur_tgt_len]
        labels = data2[idx, t + 1:t + cur_tgt_len + 1]
        print(f'idx={idx}, t={t}, (t, t + cur_tgt_len)={(t, t + cur_tgt_len)}, (t + 1, t + cur_tgt_len + 1)={(t, t + cur_tgt_len)}')
        print('inputs||' + as_text([inputs]))
        print('targets||' + as_text([labels]))
        print()
    
    print('=' * 80 + '\n')

tgt_len=10, data2.shape[1] - 1 - t = 239, cur_tgt_len = 10

idx=0, t=0, (t, t + cur_tgt_len)=(0, 10), (t + 1, t + cur_tgt_len + 1)=(0, 10)
inputs||she ; " for Mrs. Long has just been here
targets||; " for Mrs. Long has just been here ,

idx=1, t=0, (t, t + cur_tgt_len)=(0, 10), (t + 1, t + cur_tgt_len + 1)=(0, 10)
inputs||. <eos> <eos> " Do you not want to know
targets||<eos> <eos> " Do you not want to know who


tgt_len=10, data2.shape[1] - 1 - t = 229, cur_tgt_len = 10

idx=0, t=10, (t, t + cur_tgt_len)=(10, 20), (t + 1, t + cur_tgt_len + 1)=(10, 20)
inputs||, and <eos> she told me all about it .
targets||and <eos> she told me all about it . "

idx=1, t=10, (t, t + cur_tgt_len)=(10, 20), (t + 1, t + cur_tgt_len + 1)=(10, 20)
inputs||who has taken it ? " cried his wife <eos>
targets||has taken it ? " cried his wife <eos> impatiently


tgt_len=10, data2.shape[1] - 1 - t = 219, cur_tgt_len = 10

idx=0, t=20, (t, t + cur_tgt_len)=(20, 30), (t + 1, t + cur_tgt_len + 1)=(20, 30)
inputs||" 

In [411]:
def _int64_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _float_feature(values):
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))

In [412]:
feature = {
          "inputs": _int64_feature(inputs),
          "labels": _int64_feature(labels),
      }
feature

{'inputs': int64_list {
   value: 93
   value: 94
   value: 4
   value: 37
   value: 0
   value: 33
   value: 95
   value: 38
   value: 96
   value: 97
 },
 'labels': int64_list {
   value: 94
   value: 4
   value: 37
   value: 0
   value: 33
   value: 95
   value: 38
   value: 96
   value: 97
   value: 12
 }}

In [413]:
corpus.cutoffs

[0, 20000, 40000, 200000, 121]

In [414]:
len(corpus.cutoffs), len(bin_sizes)

(5, 3)

In [415]:
left, right = corpus.cutoffs[:2]
inp_mask = ((inputs >= left) * (inputs < right)).astype(np.float32)
tgt_mask = ((labels >= left) * (labels < right)).astype(np.float32)

In [416]:
left, right

(0, 20000)

In [417]:
inp_mask

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)

In [418]:
tgt_mask

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)

In [419]:
feature["inp_mask"] = _float_feature(inp_mask)
feature["tgt_mask"] = _float_feature(tgt_mask)
feature 

{'inputs': int64_list {
   value: 93
   value: 94
   value: 4
   value: 37
   value: 0
   value: 33
   value: 95
   value: 38
   value: 96
   value: 97
 },
 'labels': int64_list {
   value: 94
   value: 4
   value: 37
   value: 0
   value: 33
   value: 95
   value: 38
   value: 96
   value: 97
   value: 12
 },
 'inp_mask': float_list {
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
 },
 'tgt_mask': float_list {
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
 }}

In [420]:
tmp_batch_size = 10
tmp_batch_size // num_core_per_host

1

In [421]:
inp_cnts = [0] * len(bin_sizes)
tgt_cnts = [0] * len(bin_sizes)

inp_cnts, tgt_cnts

([0, 0, 0], [0, 0, 0])

In [436]:
inputs, labels

(array([93, 94,  4, 37,  0, 33, 95, 38, 96, 97]),
 array([94,  4, 37,  0, 33, 95, 38, 96, 97, 12]))

In [444]:
inp_pos_per_bin, tgt_pos_per_bin

([array([], dtype=int64), array([3, 5, 7]), array([0, 1, 6, 8, 9])],
 [array([9]), array([2, 4, 6]), array([0, 5, 7, 8])])

In [445]:
head_labels = np.copy(labels)

left, right = cutoffs[:2]
inp_mask = ((inputs >= left) * (inputs < right)).astype(np.float32)
tgt_mask = ((labels >= left) * (labels < right)).astype(np.float32)

feature["inp_mask"] = _float_feature(inp_mask)
feature["tgt_mask"] = _float_feature(tgt_mask)

inp_pos_per_bin, tgt_pos_per_bin = [], []
cutoffs = [0, 10, 25, 50, 100]
print(
     f'inp_pos_per_bin = {inp_pos_per_bin}, ' + 
     f'tgt_pos_per_bin = {tgt_pos_per_bin}, ' +
     f'head_labels = {head_labels}' )
print()
for b, (left, right) in enumerate(zip(cutoffs[1:-1], cutoffs[2:])):

    inp_pos = np.where((inputs >= left) * (inputs < right))[0]
    tgt_pos = np.where((labels >= left) * (labels < right))[0]
    inp_pos_per_bin.append(inp_pos)
    tgt_pos_per_bin.append(tgt_pos)

    # All these target locations have this value
    head_labels[tgt_pos] = cutoffs[1] + b
    print(f'b={b}, left={left}, right={right}')
    print(f'inp_pos = {inp_pos}, ' + 
          f'tgt_pos = {tgt_pos}, ' + 
         f'inp_pos_per_bin = {inp_pos_per_bin}, ' + 
         f'tgt_pos_per_bin = {tgt_pos_per_bin}, ' +
         f'head_labels = {head_labels}, ' 
         f'cutoffs[1] + b = {cutoffs[1] + b}')
    print()

inp_pos_per_bin = [], tgt_pos_per_bin = [], head_labels = [94  4 37  0 33 95 38 96 97 12]

b=0, left=10, right=25
inp_pos = [], tgt_pos = [9], inp_pos_per_bin = [array([], dtype=int64)], tgt_pos_per_bin = [array([9])], head_labels = [94  4 37  0 33 95 38 96 97 10], cutoffs[1] + b = 10

b=1, left=25, right=50
inp_pos = [3 5 7], tgt_pos = [2 4 6], inp_pos_per_bin = [array([], dtype=int64), array([3, 5, 7])], tgt_pos_per_bin = [array([9]), array([2, 4, 6])], head_labels = [94  4 11  0 11 95 11 96 97 10], cutoffs[1] + b = 11

b=2, left=50, right=100
inp_pos = [0 1 6 8 9], tgt_pos = [0 5 7 8], inp_pos_per_bin = [array([], dtype=int64), array([3, 5, 7]), array([0, 1, 6, 8, 9])], tgt_pos_per_bin = [array([9]), array([2, 4, 6]), array([0, 5, 7, 8])], head_labels = [12  4 11  0 11 12 11 12 12 10], cutoffs[1] + b = 12



In [446]:
feature

{'inputs': int64_list {
   value: 93
   value: 94
   value: 4
   value: 37
   value: 0
   value: 33
   value: 95
   value: 38
   value: 96
   value: 97
 },
 'labels': int64_list {
   value: 94
   value: 4
   value: 37
   value: 0
   value: 33
   value: 95
   value: 38
   value: 96
   value: 97
   value: 12
 },
 'inp_mask': float_list {
   value: 0.0
   value: 0.0
   value: 1.0
   value: 0.0
   value: 1.0
   value: 0.0
   value: 0.0
   value: 0.0
   value: 0.0
   value: 0.0
 },
 'tgt_mask': float_list {
   value: 0.0
   value: 1.0
   value: 0.0
   value: 1.0
   value: 0.0
   value: 0.0
   value: 0.0
   value: 0.0
   value: 0.0
   value: 0.0
 },
 'head_labels': int64_list {
   value: 12
   value: 4
   value: 11
   value: 0
   value: 11
   value: 12
   value: 11
   value: 12
   value: 12
   value: 10
 },
 'inp_cnt_0': int64_list {
   value: 0
 },
 'inp_tup_0': int64_list {
 },
 'inp_cnt_1': int64_list {
   value: 3
 },
 'inp_tup_1': int64_list {
   value: 3
   value: 3
   value: 5
   valu

In [439]:
feature["head_labels"] = _int64_feature(head_labels)

In [440]:
feature

{'inputs': int64_list {
   value: 93
   value: 94
   value: 4
   value: 37
   value: 0
   value: 33
   value: 95
   value: 38
   value: 96
   value: 97
 },
 'labels': int64_list {
   value: 94
   value: 4
   value: 37
   value: 0
   value: 33
   value: 95
   value: 38
   value: 96
   value: 97
   value: 12
 },
 'inp_mask': float_list {
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
 },
 'tgt_mask': float_list {
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
 },
 'head_labels': int64_list {
   value: 12
   value: 4
   value: 11
   value: 0
   value: 11
   value: 12
   value: 11
   value: 12
   value: 12
   value: 10
 },
 'inp_cnt_0': int64_list {
   value: 0
 },
 'inp_tup_0': int64_list {
 },
 'inp_cnt_1': int64_list {
   value: 3
 },
 'inp_tup_1': int64_list {
   value: 3
   value: 0
   value: 5
   valu

In [441]:
def _add_perm_feature(feature, pos_per_bin, cnts, prefix):
  # Go through each bin
  print(f'feature={feature}, pos_per_bin={pos_per_bin}, cnts={cnts}, prefix={prefix}')
  for b, pos in enumerate(pos_per_bin):
    print(f'\tfor1 --- b={b}, pos={pos}')
    idx_tuple = []
    # if bin is not full add to bin (bin_id, num_so_far)
    for p in pos:
        print(f'\t\tfor2 --- p={p}')
        print(f'\t\tfor2 --- cnts[b]={cnts[b]}, bin_sizes[b]={bin_sizes[b]}, cnts[b] < bin_sizes[b]={cnts[b] < bin_sizes[b]}')
        if cnts[b] < bin_sizes[b]:
            idx_tuple.append([p, cnts[b]])
            print(f'\t\t\tif --- idx_tuple={idx_tuple}')
            cnts[b] += 1
        else:
            print('\t\t\telse --- break')
            break
            
        print()

    n_tup = len(idx_tuple)
    tup = np.array(idx_tuple).reshape(n_tup * 2)
    
    print(f'\tfor1 --- n_tup={n_tup}, tup={tup}')
    
    print(f'\tfor1 --- prefix={prefix}, b={b}')
    
    print()

    # inp_cnt_<bin_id> is total number in bin
    feature["{}_cnt_{}".format(prefix, b)] = _int64_feature([n_tup])
    # inp_tup_<bin_id> is the list of (bin_id, num_so_far)
    feature["{}_tup_{}".format(prefix, b)] = _int64_feature(tup)

In [442]:
inp_pos_per_bin, tgt_pos_per_bin

([array([], dtype=int64), array([3, 5, 7]), array([0, 1, 6, 8, 9])],
 [array([9]), array([2, 4, 6]), array([0, 5, 7, 8])])

In [443]:
_add_perm_feature(feature, inp_pos_per_bin, inp_cnts, "inp")
_add_perm_feature(feature, tgt_pos_per_bin, tgt_cnts, "tgt")

feature={'inputs': int64_list {
  value: 93
  value: 94
  value: 4
  value: 37
  value: 0
  value: 33
  value: 95
  value: 38
  value: 96
  value: 97
}
, 'labels': int64_list {
  value: 94
  value: 4
  value: 37
  value: 0
  value: 33
  value: 95
  value: 38
  value: 96
  value: 97
  value: 12
}
, 'inp_mask': float_list {
  value: 1.0
  value: 1.0
  value: 1.0
  value: 1.0
  value: 1.0
  value: 1.0
  value: 1.0
  value: 1.0
  value: 1.0
  value: 1.0
}
, 'tgt_mask': float_list {
  value: 1.0
  value: 1.0
  value: 1.0
  value: 1.0
  value: 1.0
  value: 1.0
  value: 1.0
  value: 1.0
  value: 1.0
  value: 1.0
}
, 'head_labels': int64_list {
  value: 12
  value: 4
  value: 11
  value: 0
  value: 11
  value: 12
  value: 11
  value: 12
  value: 12
  value: 10
}
, 'inp_cnt_0': int64_list {
  value: 0
}
, 'inp_tup_0': int64_list {
}
, 'inp_cnt_1': int64_list {
  value: 3
}
, 'inp_tup_1': int64_list {
  value: 3
  value: 0
  value: 5
  value: 1
  value: 7
  value: 2
}
, 'inp_cnt_2': int64_list {

In [434]:
feature

{'inputs': int64_list {
   value: 93
   value: 94
   value: 4
   value: 37
   value: 0
   value: 33
   value: 95
   value: 38
   value: 96
   value: 97
 },
 'labels': int64_list {
   value: 94
   value: 4
   value: 37
   value: 0
   value: 33
   value: 95
   value: 38
   value: 96
   value: 97
   value: 12
 },
 'inp_mask': float_list {
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
 },
 'tgt_mask': float_list {
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
   value: 1.0
 },
 'head_labels': int64_list {
   value: 12
   value: 4
   value: 11
   value: 0
   value: 11
   value: 12
   value: 11
   value: 12
   value: 12
   value: 10
 },
 'inp_cnt_0': int64_list {
   value: 0
 },
 'inp_tup_0': int64_list {
 },
 'inp_cnt_1': int64_list {
   value: 3
 },
 'inp_tup_1': int64_list {
   value: 3
   value: 0
   value: 5
   valu

In [433]:
for split, batch_size in zip(
  ["train", "valid"],
  [FLAGS.per_host_train_bsz, FLAGS.per_host_valid_bsz]):

if batch_size <= 0: continue
print("Converting {} set...".format(split))
corpus.convert_to_tfrecords(split, save_dir, batch_size, FLAGS.tgt_len,
                            FLAGS.num_core_per_host, FLAGS=FLAGS)

IndentationError: expected an indented block (<ipython-input-433-e2a5c87b608f>, line 5)

In [None]:
# def main(unused_argv):
#   del unused_argv  # Unused

#   corpus = get_lm_corpus(FLAGS.data_dir, FLAGS.dataset)

#   save_dir = os.path.join(FLAGS.data_dir, "tfrecords")
#   if not exists(save_dir):
#     makedirs(save_dir)

#   # test mode
#   if FLAGS.per_host_test_bsz > 0:
#     corpus.convert_to_tfrecords("test", save_dir, FLAGS.per_host_test_bsz,
#                                 FLAGS.tgt_len, FLAGS.num_core_per_host, 
#                                 FLAGS=FLAGS)
#     return

