In [1]:
import collections
import itertools
import json
import numpy as np
import os
import pathlib
import subprocess
import sys
import transformers
from typing import *

import markdown_strings
from IPython.display import display, Markdown, Latex

try:
    from rich import pretty
    pretty.install()
    from rich import print
except ImportError:
    pass

import tensorflow as tf
import tqdm.notebook as tqdm

_PROJECT_DIRECTORY = pathlib.Path().resolve().parent
sys.path.append(str(_PROJECT_DIRECTORY))
import constants
import utils

In [2]:
def normal(text, escape=False):
    """Render `text` in the notebook as a plain Markdown paragraph.

    Args:
        text: Markdown source to display.
        escape: When true, pass `text` through `markdown_strings.esc_format`
            first so Markdown control characters are shown literally.
    """
    rendered = markdown_strings.esc_format(text) if escape else text
    display(Markdown(rendered))

def h1(text, escape=False):
    """Render `text` as a top-level (`#`) Markdown heading.

    Args:
        text: Heading text.
        escape: When true, escape Markdown control characters in `text`
            (via `markdown_strings.esc_format`) before rendering.
    """
    heading_body = markdown_strings.esc_format(text) if escape else text
    display(Markdown(f"# {heading_body}"))
    
def h2(text, escape=False):
    """Render `text` as a secondary Markdown heading.

    Note: despite the name, this emits a level-4 (`####`) heading.

    Args:
        text: Heading text.
        escape: When true, escape Markdown control characters in `text`
            (via `markdown_strings.esc_format`) before rendering.
    """
    heading_body = markdown_strings.esc_format(text) if escape else text
    display(Markdown(f"#### {heading_body}"))
    
def quote(text, escape=True):
    """Render `text` as a Markdown blockquote.

    Unlike the heading helpers, escaping is ON by default, because quoted
    text (e.g. decoded model tokens) is displayed verbatim.

    Args:
        text: Text to quote.
        escape: When true (default), escape Markdown control characters via
            `markdown_strings.esc_format` before quoting.
    """
    content = markdown_strings.esc_format(text) if escape else text
    quoted = markdown_strings.blockquote(content)
    display(Markdown(quoted))

In [3]:
def build_per_split(num_paths_display=10):
    """List the dataset's TFRecord shards on GCS and group them by split.

    Reads `tfr_prefix` from `configs/train_configs/tpu_gpt2_eli5_kilt.json`,
    runs `gsutil ls` on it, and buckets each returned path by the part of its
    file name before the first underscore (e.g. `eval_12.tfr` -> `eval`).
    Each split's paths are then sorted by the integer after the underscore,
    so `eval_2` comes before `eval_10`.

    Args:
        num_paths_display: How many of the listed paths to show in the
            notebook output. Has a default so `build_per_split()` works.

    Returns:
        A `collections.defaultdict(list)` mapping split name -> sorted list of
        shard paths.

    Raises:
        subprocess.CalledProcessError: if `gsutil ls` exits non-zero.
        FileNotFoundError: if `gsutil` is not on PATH.
    """
    h1("Getting filenames.")
    h2("Loading json config.")
    config_path = _PROJECT_DIRECTORY / "configs" / "train_configs" / "tpu_gpt2_eli5_kilt.json"
    config = utils.from_json_file(config_path)

    h2("Calling `gsutil ls` on the dataset repo.")
    ds_path = config["tfr_prefix"]
    # Argument list (no `shell=True`) so the configured prefix is passed to
    # gsutil verbatim instead of being interpreted by a shell.
    listing = subprocess.check_output(["gsutil", "ls", ds_path]).decode()
    # Drop blank lines: an empty listing would otherwise yield `[""]`, which
    # crashes the numeric sort key below.
    filenames = [line.strip() for line in listing.splitlines() if line.strip()]

    h2("A few paths:")
    normal(f"There are actually {len(filenames)}.")
    normal(" - " + "\n - ".join(filenames[:num_paths_display]))

    h1("Building the `per_split` Path dict.")
    per_split = collections.defaultdict(list)
    for path in tqdm.tqdm(filenames, desc="Building `per_split` dict."):
        split = pathlib.Path(path).name.split("_")[0]
        per_split[split].append(path)

    def shard_index(path):
        # Ad-hoc shard index from the file name: "eval_12.tfr" -> 12.
        return int(pathlib.Path(path).name.split("_")[1].split(".")[0])

    normal("Sorting the `per_split` lists.")
    for split_paths in per_split.values():
        split_paths.sort(key=shard_index)

    normal("Len per split for the per_split dict:")
    print({split: len(split_paths) for split, split_paths in per_split.items()})
    return per_split

In [4]:
def build_dataset(paths, context_window_size, split, num_retrievals=10):
    """Build a `tf.data` pipeline that parses the cached ELI5/KILT TFRecords.

    Each record is a `tf.train.Example` whose features are bytes strings
    holding serialized tensors; every one is decoded with `tf.io.parse_tensor`
    and given a static shape.

    Args:
        paths: Filenames accepted by `tf.data.TFRecordDataset` (a list of
            strings or a dataset of strings).
        context_window_size: Token length of the id features.
        split: Split name. The answer feature is only parsed when this is not
            `constants.SplitChoices.test`.
        num_retrievals: Number of retrieved segments per example (leading
            dimension of the distances / retrieved-ids features). Defaults to
            the previously hard-coded 10.

    Returns:
        A `tf.data.Dataset` of dicts keyed by `constants.CTH5Fields` names.
        Parsing uses `deterministic=False`, so element order may vary between
        runs.
    """
    ds = tf.data.TFRecordDataset(paths)

    # Single source of truth per feature: (dtype, static shape).
    feature_specs = {
        constants.CTH5Fields.distances:
            (tf.float32, (num_retrievals,)),
        constants.CTH5Fields.gpt2_retrieved_ids:
            (tf.int32, (num_retrievals, context_window_size)),
        constants.CTH5Fields.gpt2_question_ids_inputs:
            (tf.int32, (context_window_size,)),
    }
    if split != constants.SplitChoices.test:
        # A real 1-tuple; previously `(context_window_size)`, a bare int.
        feature_specs[constants.CTH5Fields.gpt2_answer_ids_inputs] = (
            tf.int32, (context_window_size,)
        )

    # Every feature is stored as one serialized tensor in a bytes feature.
    description = {
        key: tf.io.FixedLenFeature((), tf.string) for key in feature_specs
    }

    @tf.function
    def parse(sample):
        example = tf.io.parse_single_example(sample, description)
        output = {}
        for key, serialized in example.items():
            dtype, shape = feature_specs[key]
            tensor = tf.io.parse_tensor(serialized, out_type=dtype)
            tensor.set_shape(shape)
            output[key] = tensor
        return output

    return ds.map(
        parse,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        deterministic=False,
    )


In [None]:
def decode_line(tokenizer, line):
    """Decode one row of token ids to text, skipping negative ids.

    Negative ids are presumably padding; only ids `>= 0` are passed on.

    Args:
        tokenizer: Any object with a `decode(list_of_ids)` method, e.g. a
            `transformers` tokenizer.
        line: An iterable of integer token ids.

    Returns:
        Whatever `tokenizer.decode` returns for the kept ids.
    """
    kept_ids = list(filter(lambda token_id: token_id >= 0, line))
    return tokenizer.decode(kept_ids)

def count(paths, context_window_size, split, tokenizer, min_length=7, show_short=False):
    """Count a split's examples and print token-length histograms.

    For the question feature and, when present, the answer feature, the
    number of ids `>= 0` in each example is tallied (the same filter
    `decode_line` applies). One `[(length, occurrences), ...]` list sorted by
    length is printed per feature at the end.

    Args:
        paths: Shard filenames, forwarded to `build_dataset`.
        context_window_size: Token length of the id features.
        split: Split name (the test split has no answer feature).
        tokenizer: Used only to decode short examples when `show_short`
            is set.
        min_length: Examples whose length is `<= min_length` are displayed
            when `show_short` is set.
        show_short: If true, display the decoded text of every short
            example (this was commented-out debug code before).

    Returns:
        Number of examples in the split.
    """
    ds = build_dataset(paths, context_window_size, split)
    length_keys = (
        constants.CTH5Fields.gpt2_question_ids_inputs,
        constants.CTH5Fields.gpt2_answer_ids_inputs,
    )
    feature_lengths = collections.defaultdict(collections.Counter)

    num_items = 0
    for item in tqdm.tqdm(ds, desc=f"Counting items for split `{split}`"):
        num_items += 1
        for feature_key in length_keys:
            if feature_key not in item:
                continue  # The answer feature is absent in the test split.
            ids = np.asarray(item[feature_key])
            # Plain int so the printed histogram is readable.
            length = int((ids >= 0).sum())
            feature_lengths[feature_key][length] += 1
            if show_short and length <= min_length:
                normal(f"{split} - `{feature_key}`: {decode_line(tokenizer, ids)}")

    for feature_key, histogram in feature_lengths.items():
        print(f"{split} - {feature_key}: ")
        print(sorted(histogram.items()))
    return num_items


def main():
    """List the dataset shards, then count and length-profile each split.

    For every split, prints the item count and flags a mismatch against
    `_EXPECTED_SIZES` when all shards were read.
    """
    _MAX_QTY = None  # Max shards per split; None means all shards.
    _MODEL_TYPE = "gpt2-xl"
    # Take the context window from the same model as the tokenizer so the two
    # stay in sync if `_MODEL_TYPE` changes.
    _MODEL_CONFIG = transformers.AutoConfig.from_pretrained(_MODEL_TYPE)
    _CONTEXT_WINDOW_SIZE = _MODEL_CONFIG.n_ctx
    _EXPECTED_SIZES = dict(train=272634, eval=1507, test=600)

    tokenizer = transformers.AutoTokenizer.from_pretrained(_MODEL_TYPE)
    per_split = build_per_split(num_paths_display=10)

    for split in ["eval", "test", "train"]:
        shard_paths = per_split[split][:_MAX_QTY]
        if not shard_paths:
            # TFRecordDataset cannot be built from an empty filename list.
            normal(f"No shards found for split `{split}`; skipping.")
            continue
        num_items = count(shard_paths, _CONTEXT_WINDOW_SIZE, split, tokenizer)

        expected = _EXPECTED_SIZES.get(split)
        # A size check only makes sense when every shard was read.
        if _MAX_QTY is None and expected is not None and num_items != expected:
            print(f"WARNING: split `{split}` has {num_items} items, expected {expected}.")
        else:
            print(f"Split `{split}`: {num_items} items.")
main()

# Getting filenames.

#### Loading json config.

#### Calling `gsutil ls` on the dataset repo.

#### A few paths:

 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_0.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_1.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_10.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_100.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_1000.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_1001.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_1002.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_1003.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_1004.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_1005.tfr

# Building the `per_split` Path dict.

HBox(children=(FloatProgress(value=0.0, description='Building `per_split` dict.', max=8192.0, style=ProgressSt…




Sorting the `per_split` lists.

Len per split for the per_split dict:

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Counting items for split `eval`', layou…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Counting items for split `test`', layou…

In [None]:
def display_item(item, tokenizer=None):
    """Render one parsed example as Markdown.

    Shows the question, then the answer (if the example has one), then each
    retrieved segment, each decoded with `decode_line` and shown as a quote.

    Args:
        item: A dict as yielded by `build_dataset`, keyed by
            `constants.CTH5Fields` names.
        tokenizer: Tokenizer used to decode token ids (e.g. the one created
            in `main`). Required.

    Raises:
        ValueError: if `tokenizer` is None.
    """
    if tokenizer is None:
        raise ValueError("display_item needs a tokenizer to decode token ids.")

    def show(ids):
        quote(decode_line(tokenizer, ids))

    h2("Question:")
    show(item[constants.CTH5Fields.gpt2_question_ids_inputs])

    # `build_dataset` only adds the answer feature outside the test split, so
    # checking for the key replaces the undefined `_SPLIT` global read here before.
    answer_key = constants.CTH5Fields.gpt2_answer_ids_inputs
    if answer_key in item:
        h2("Answer:")
        show(item[answer_key])

    h2("Retrieved segments:")
    for segment in item[constants.CTH5Fields.gpt2_retrieved_ids]:
        show(segment)

