In [1]:
import collections
import dataclasses
import itertools
import json
import os
import pathlib
import re
import shlex
import socket
import subprocess
import sys
from typing import *
print("stdlib", flush=True)

import markdown_strings
import numpy as np
from IPython.display import display, Markdown, Latex
import rich
import rich.console
import rich.markdown
import rich.table
import tensorflow.python.framework.ops as ops
import tensorflow as tf
import tensorflow.python.distribute.values as values
import toolz
import tqdm.notebook as tqdm
import transformers
print("3rd party", flush=True)

_PROJECT_DIRECTORY = pathlib.Path().resolve().parent
sys.path.append(str(_PROJECT_DIRECTORY))
import constants
# import task_specific
import tf_utils
import utils
print("custom", flush=True)



#------------------------------------------------------------------------------
# Flags
#------------------------------------------------------------------------------
# NOTE(review): not referenced by the code visible here; presumably a cap on the
# number of samples to process (None = no cap) — confirm before relying on it.
_MAX_QTY = None
# Hugging Face model name used to load both the model config and the tokenizer.
_MODEL_TYPE = "distilgpt2"
# NOTE(review): not referenced by the code visible here; presumably the expected
# number of samples per split — confirm.
_EXPECTED_SIZES = dict(train=272634, eval=1507, test=600)
# How many dataset paths `build_split_to_ds_paths` shows in the notebook.
_NUM_PATHS_DISPLAY = 10
# Presumably the expected number of accelerator replicas/devices — confirm where
# it is consumed.
_NUM_REPLICAS = 8
# Accelerator to use: "TPU" or "CPU"; other values raise in the setup code.
_ACCEL_TYPE = "TPU"
# GCP zone passed to `gcloud compute tpus list` when checking the TPU.
_ZONE = "europe-west4-a"
# Input-preparation approach: "naked_lm" (no retrieval) or "cached_pretok"
# (merge cached, pre-tokenized retrievals).
_APPROACH_TYPE = "naked_lm" # "cached_pretok"

print("done")

stdlib
3rd party
custom
done


In [2]:
@tf.function
def _tokenize_and_concat_while_loop(
    all_retrieved_tokens,
    indices,
    num_retrieved,
    batch_size,
):
  """Concatenates, per batch row, the already-tokenized retrievals chosen by `indices`.

  Despite the name, no tokenization happens here: `all_retrieved_tokens`
  already holds token ids. At loop step `index` (0 <= index < num_retrieved),
  column `index` of `indices` selects one retrieval per batch row via
  `tf.gather(..., batch_dims=1)`, and that retrieval is appended along axis 1.

  Args:
    all_retrieved_tokens: Token ids of the candidate retrievals; presumably a
      RaggedTensor of shape (batch, num_candidates, None) — TODO confirm
      against the caller.
    indices: Integer tensor whose column `index` holds, for each batch row, the
      candidate to append at step `index`; assumed (batch, >= num_retrieved).
    num_retrieved: Number of retrievals to append (the loop bound).
    batch_size: Number of rows; sizes the empty int32 starting accumulator.

  Returns:
    A RaggedTensor with `batch_size` rows: the concatenated retrievals.

  Raises:
    RuntimeError: If `batch_size` is None.
  """
  def condition(
      index,
      _  # pylint: disable=unused-argument
  ):
    """Keeps looping while fewer than `num_retrieved` steps have run."""
    return tf.less(index, num_retrieved)

  def body(
      index,
      concat_tokens,
  ):
    """Appends the retrieval picked by `indices[:, index]` to each row."""

    addition = tf.gather(all_retrieved_tokens, indices[:, index], batch_dims=1)

    concat_tokens = tf.concat([
        concat_tokens, addition
    ], axis=1)

    return index + 1, concat_tokens

  if batch_size is None:
    raise RuntimeError("batch_size is `None`. This should not happen.")

  # Loop state is (step, accumulator); the accumulator starts as `batch_size`
  # empty int32 rows. Only the accumulator is returned.
  return tf.while_loop(
      condition, body, [
          0, tf.RaggedTensor.from_tensor(
              tf.zeros(
                  shape=(batch_size, 0),
                  dtype=tf.int32
              ),
          )
      ])[1]


def _prepare_samples_w_retrieval(
    split,
    batch_size,
    question_ids_inputs,
    answer_ids_inputs,
    gpt2_tokenized_retrieved,
    distances,
    num_retrievals,
    temperature,
    context_size,
    enable_debug_checks,
    use_helper_words,
    helper_word_token_ids,
    max_generation_length
):
  """Prepares the samples that use retrieval.

  Samples `num_retrievals` retrievals per row (from `distances / temperature`),
  concatenates them, and trims them. The trim leaves room in `context_size` for
  the question, the optional "answer" helper words, and `max_generation_length`
  tokens. It then assembles the model inputs and labels.

  Args:
    split: Dataset split. `answer_ids_inputs` must be None iff this is
      `constants.SplitChoices.test`.
    batch_size: Number of rows in the batch.
    question_ids_inputs: Question token ids (dense or ragged); dense inputs are
      made ragged by stripping `constants.RAGGED_PADDING_ID`.
    answer_ids_inputs: RaggedTensor of answer token ids, or None for test.
    gpt2_tokenized_retrieved: Token ids of the candidate retrievals.
    distances: Scores used to sample the retrievals.
    num_retrievals: Number of retrievals to sample per row.
    temperature: Divides `distances` before sampling.
    context_size: Total token budget of a sample.
    enable_debug_checks: Whether to add `tf.Assert` sanity checks.
    use_helper_words: Whether to insert the "context"/"answer" helper words.
    helper_word_token_ids: Dict with "context" and "answer" token-id tensors
      (one row per batch row); only used when `use_helper_words`.
    max_generation_length: Token positions reserved for the answer.

  Returns:
    `(input_ids, label_ids)`. `label_ids` is None for the test split. Otherwise
    it is -100 everywhere except on the answer tokens.
  """
  # If and only if: answers are absent exactly for the test split.
  assert (split == constants.SplitChoices.test) == (
      answer_ids_inputs is None
  ), (split == constants.SplitChoices.test, answer_ids_inputs)

  is_not_test = split != constants.SplitChoices.test

  if not isinstance(question_ids_inputs, tf.RaggedTensor):
    question_ids_inputs = tf.RaggedTensor.from_tensor(
        question_ids_inputs,
        padding=constants.RAGGED_PADDING_ID
    )

  if enable_debug_checks:
    asserts = []
    asserts.append(
        tf.Assert(
            tf.math.reduce_all(
                question_ids_inputs != constants.RAGGED_PADDING_ID,
            ),
            [question_ids_inputs.to_tensor()]
        )
    )
    if is_not_test:
      asserts.append(
          tf.Assert(
              tf.math.reduce_all(
                  answer_ids_inputs != constants.RAGGED_PADDING_ID,
              ),
              [answer_ids_inputs.to_tensor()]
          )
      )
    with tf.control_dependencies(asserts):
      question_ids_inputs = tf.identity(question_ids_inputs)

  # These checks are at graph composition time, so OK
  utils.check_isinstance(question_ids_inputs, tf.RaggedTensor)

  if is_not_test:
    utils.check_isinstance(answer_ids_inputs, tf.RaggedTensor)

  ##############################################################################
  # Sample from the possible retrievals
  ##############################################################################
  # Choose the indices
  indices = tf_utils.sample_without_replacement(
      distances / temperature, num_retrievals
  )

  # Concatenate the retrievals
  concat_retrieved = _tokenize_and_concat_while_loop(
      gpt2_tokenized_retrieved,
      indices=indices,
      batch_size=batch_size,
      num_retrieved=num_retrievals,
  )

  # Add Context Helper Words
  if use_helper_words:
    concat_retrieved = tf.concat([
        helper_word_token_ids["context"],
        concat_retrieved,
    ], axis=1)

  # Cut the lengths down to max_lens_retrieval. The budget is identical for all
  # splits: we always reserve `max_generation_length` positions for the answer.
  # The eventual length of the ["question"] helper tokens is already included
  # in question_ids_inputs.
  answer_helper_len = (
      helper_word_token_ids["answer"].shape[1] if use_helper_words else 0
  )
  max_lens_retrieval = (
      context_size * tf.ones(
          shape=(batch_size,),
          dtype=tf.int64,
      )
      - (question_ids_inputs.row_lengths()
         + max_generation_length
         + answer_helper_len)
  )

  concat_retrieved = tf.ragged.boolean_mask(
      concat_retrieved,
      (
          tf.ragged.range(concat_retrieved.row_lengths()) <
          tf.expand_dims(max_lens_retrieval, axis=1)
      )
  )

  if enable_debug_checks:
    asserts = [
        tf.Assert(
            tf.math.reduce_all(max_lens_retrieval < context_size),
            [max_lens_retrieval, context_size]
        ),
    ]
    with tf.control_dependencies(asserts):
      concat_retrieved = tf.identity(concat_retrieved)

  # Inputs: question, retrievals, [answer helper words], [answer].
  input_pieces = [question_ids_inputs, concat_retrieved]
  if use_helper_words:
    input_pieces.append(helper_word_token_ids["answer"])
  if is_not_test:
    input_pieces.append(answer_ids_inputs)
  new_input_ids = tf.concat(input_pieces, axis=1)

  # Labels mirror the inputs, with everything but the answer masked by -100.
  new_label_ids = None
  if is_not_test:
    label_pieces = [
        -100 * tf.ones_like(question_ids_inputs),
        -100 * tf.ones_like(concat_retrieved),
    ]
    if use_helper_words:
      label_pieces.append(-100 * tf.ones_like(helper_word_token_ids["answer"]))
    label_pieces.append(answer_ids_inputs)
    new_label_ids = tf.concat(label_pieces, axis=1)

  return new_input_ids, new_label_ids

print("done")

done


In [3]:
def _make_maybe_retrieve_and_merge_fn(
    *,
    tokenizer,
    context_size,
    ds_split,
    approach_type,  # FLAG_APPROACH_TYPE.value
    use_helper_words,  # FLAG_USE_HELPER_WORDS
    retriever,  # pylint: disable=unused-argument
    temperature,
    num_retrievals,
    enable_debug_checks,
    max_length_generation,
    tf_function_kwargs = None,
):
  """Build the `maybe_retrieve_and_merge` closure.

  The returned function maps a batch dict (keyed by `constants.CTH5Fields.*`)
  to `dict(input_ids=..., label_ids=...)`. Both are dense and truncated/padded
  to `context_size` columns. `label_ids` is None for the test split.

  Args:
    tokenizer: Encodes the helper words. Its `eos_token_id` pads `input_ids`
      and terminates the labels.
    context_size: Number of columns of the returned tensors.
    ds_split: Split the closure is built for; the test split has no answers.
    approach_type: `constants.ApproachTypeChoices.cached_pretok` (merge sampled
      cached retrievals) or `constants.ApproachTypeChoices.naked_lm` (no
      retrieval).
    use_helper_words: Whether to insert the "Question:", "Context:" and
      "Answer:" helper words.
    retriever: Unused.
    temperature: Divides the retrieval distances before sampling
      (cached_pretok only).
    num_retrievals: Number of retrievals sampled per row (cached_pretok only).
    enable_debug_checks: Whether to add `tf.Assert` sanity checks.
    max_length_generation: Positions reserved for the answer when trimming
      the retrievals (cached_pretok only).
    tf_function_kwargs: Only meant for the commented-out `tf.function`
      decorator; currently has no effect.

  Returns:
    The `maybe_retrieve_and_merge(batch)` function.

  Raises:
    RuntimeError: (When the closure runs) for an unsupported `approach_type`.
  """
  tf_function_kwargs = {} if tf_function_kwargs is None else tf_function_kwargs
  not_test_split = ds_split != constants.SplitChoices.test

  # @tf.function(**tf_function_kwargs)
  def maybe_retrieve_and_merge(
      batch,
  ):
    """Retrieve if needed, then finalize the prep. for model consumption."""

    batch_size = tf.shape(batch[
        constants.CTH5Fields.gpt2_question_ids_inputs
    ])[0]

    # Prepare the question ids inputs (ragged, padding stripped).
    question_ids_inputs = tf.RaggedTensor.from_tensor(
        batch[constants.CTH5Fields.gpt2_question_ids_inputs],
        padding=constants.RAGGED_PADDING_ID
    )

    # Prepare the answer ids inputs; they double as the answer labels.
    answer_ids_inputs = None
    answer_ids_labels = None
    if not_test_split:
      answer_ids_inputs = tf.RaggedTensor.from_tensor(
          batch[constants.CTH5Fields.gpt2_answer_ids_inputs],
          padding=constants.RAGGED_PADDING_ID
      )
      answer_ids_labels = answer_ids_inputs

    ############################################################################
    # Prepare the helper words
    ############################################################################
    helper_word_token_ids = None
    if use_helper_words:
      helper_text = {
          "question": "Question:\n",
          "context": "\nContext:\n",
          "answer": "\nAnswer:\n"
      }

      # One copy of each helper word's token ids per batch row.
      helper_word_token_ids = {}
      for key, text in helper_text.items():
        ids = tf.constant(tokenizer.encode(text), dtype=tf.int32)
        helper_word_token_ids[key] = tf.repeat(
            tf.expand_dims(ids, 0), batch_size, axis=0
        )
      question_ids_inputs = tf.concat(
          [helper_word_token_ids["question"], question_ids_inputs],
          axis=1
      )

    ##########################################################################
    # Cached Retrievals.
    ##########################################################################
    label_ids = None
    if approach_type == constants.ApproachTypeChoices.cached_pretok:
      bpe_indices_gpt2 = tf.RaggedTensor.from_tensor(
          batch[constants.CTH5Fields.gpt2_retrieved_ids],
          ragged_rank=2,
          padding=constants.RAGGED_PADDING_ID
      )

      distances = batch[constants.CTH5Fields.distances]
      input_ids, label_ids = _prepare_samples_w_retrieval(
          split=ds_split,
          batch_size=batch_size,
          question_ids_inputs=question_ids_inputs,
          answer_ids_inputs=(
              answer_ids_inputs if not_test_split else None
          ),
          gpt2_tokenized_retrieved=bpe_indices_gpt2,
          num_retrievals=num_retrievals,
          temperature=temperature,
          context_size=context_size,
          enable_debug_checks=enable_debug_checks,
          distances=distances,
          max_generation_length=max_length_generation,
          helper_word_token_ids=(
              helper_word_token_ids if use_helper_words else None
          ),
          use_helper_words=use_helper_words,
      )

    elif approach_type == constants.ApproachTypeChoices.naked_lm:
      ##########################################################################
      # Without Retrievals
      ##########################################################################
      if use_helper_words:
        question_ids_inputs = tf.concat([
            question_ids_inputs,
            helper_word_token_ids["answer"],
        ], axis=1)

      # The question (and helper words) are never predicted.
      question_ids_labels = tf.ones_like(
          question_ids_inputs
      ) * constants.PPL_MASK_ID

      if not_test_split:
        input_ids = tf.concat((question_ids_inputs, answer_ids_inputs),
                              axis=1)
        label_ids = tf.concat((question_ids_labels, answer_ids_labels),
                              axis=1)
      else:
        input_ids = question_ids_inputs
    else:
      raise RuntimeError("Unsupported approach_type value"
                         f" {approach_type}")

    ############################################################################
    # Finalize the preparation
    ############################################################################
    # Convert to dense tensors
    input_ids = input_ids.to_tensor(tokenizer.eos_token_id)

    if not_test_split:
      # Terminate every row's labels with EOS before densifying.
      final_eos = tf.RaggedTensor.from_tensor(
          tokenizer.eos_token_id * tf.ones([batch_size, 1], dtype=tf.int32)
      )
      label_ids = tf.concat([label_ids, final_eos], axis=1)
      label_ids = label_ids.to_tensor(constants.PPL_MASK_ID)

    # All samples need to have at least one token != PPL_MASK_ID.
    if enable_debug_checks and not_test_split:
      # Per row: True iff at least one label is not masked.
      has_unmasked_label = tf.reduce_any(
          label_ids != constants.PPL_MASK_ID, axis=1
      )
      all_rows_have_label = tf.math.reduce_all(has_unmasked_label)
      # Number of fully-masked rows, reported if the check fails. `tf.cast`
      # requires an explicit dtype.
      qty_fully_masked = tf.reduce_sum(
          tf.cast(tf.logical_not(has_unmasked_label), tf.int32)
      )

      check_no_padding = tf.Assert(
          all_rows_have_label,
          [qty_fully_masked]
      )
      with tf.control_dependencies([check_no_padding]):
        label_ids = tf.identity(label_ids)

    # Limit size
    input_ids = input_ids[:, :context_size]
    if not_test_split:
      label_ids = label_ids[:, :context_size]

    ############################################################################
    # Pad `input_ids` and `label_ids` to context_size
    ############################################################################
    # Pad the inputs with EOS.
    pad_qty = tf.math.maximum(
        0, tf.constant(context_size) - tf.shape(input_ids)[1]
    )
    input_padding = tokenizer.eos_token_id * tf.ones(
        [batch_size, pad_qty],
        dtype=input_ids.dtype
    )
    input_ids = tf.concat((input_ids, input_padding), axis=1)

    # Pad the labels with the -100 mask.
    if not_test_split:
      pad_qty = tf.math.maximum(
          0, tf.constant(context_size) - tf.shape(label_ids)[1]
      )
      label_padding = -100 * tf.ones(
          [batch_size, pad_qty],
          dtype=label_ids.dtype
      )
      label_ids = tf.concat((label_ids, label_padding), axis=1)

    # Make checks
    if enable_debug_checks:
      control_dependencies = []
      control_dependencies.append(tf.Assert(
          tf.math.reduce_all(input_ids != -1),
          [input_ids],
          name="NoMinusOnesInputs"
      ))
      if not_test_split:
        control_dependencies.append(tf.Assert(
            tf.math.reduce_all(label_ids != -1),
            [label_ids],
            name="NoMinusOnesLabel"
        ))
        # Every row (after truncation) must keep at least one unmasked label.
        control_dependencies.append(tf.Assert(
            tf.math.reduce_all(
                tf.math.reduce_any(label_ids != -100, axis=1)
            ),
            [label_ids],
            name="NotAllMinusOneHundred"
        ))
      with tf.control_dependencies(control_dependencies):
        input_ids = tf.identity(input_ids)

    return dict(
        input_ids=input_ids,
        label_ids=label_ids if not_test_split else None
    )

  return maybe_retrieve_and_merge

print("done")

done


In [4]:
def normal(text, escape=False):
    """Renders `text` as a Markdown paragraph in the notebook.

    When `escape` is true, Markdown formatting characters are escaped first.
    """
    body = markdown_strings.esc_format(text) if escape else text
    display(Markdown(body))

    
def h1(text, escape=False):
    """Renders `text` as a level-1 Markdown heading (`# ...`).

    When `escape` is true, Markdown formatting characters are escaped first.
    """
    title = markdown_strings.esc_format(text) if escape else text
    display(Markdown(f"# {title}"))
    
    
def h2(text, escape=False):
    """Renders `text` as a sub-heading.

    Note: despite the name, this emits a level-4 Markdown heading (`#### ...`).
    When `escape` is true, Markdown formatting characters are escaped first.
    """
    title = markdown_strings.esc_format(text) if escape else text
    display(Markdown(f"#### {title}"))
    
    
def quote(text, escape=True):
    """Renders `text` as a Markdown blockquote.

    Unlike the heading helpers, escaping is on by default.
    """
    content = markdown_strings.esc_format(text) if escape else text
    display(Markdown(markdown_strings.blockquote(content)))
    
    
def build_split_to_ds_paths(project_directory, num_paths_display):
    """Lists the dataset shards and groups them by split.

    Reads `tfr_prefix` from `configs/train_configs/tpu_gpt2_eli5_kilt.json`
    under `project_directory` and lists it with `gsutil ls`. The paths are then
    grouped by the filename part before the first `_` (e.g. `eval_12.tfr` ->
    `eval`). Each split's list is sorted by the integer after that `_`.

    Args:
        project_directory: `pathlib.Path` of the project root.
        num_paths_display: How many of the listed paths to show.

    Returns:
        A `collections.defaultdict(list)` mapping split name -> sorted paths.
    """
    h1("Getting filenames.")
    h2("Loading json config.")
    config_path = (
        project_directory / "configs" / "train_configs" / "tpu_gpt2_eli5_kilt.json"
    )
    config = utils.from_json_file(config_path)

    h2("Calling `gsutil ls` on the dataset repo.")
    ds_path = config["tfr_prefix"]
    # An argv list (no shell) keeps the config value from being shell-parsed.
    # Blank lines are dropped so an empty listing yields no paths instead of "".
    listing = subprocess.check_output(["gsutil", "ls", str(ds_path)]).decode()
    filenames = [line.strip() for line in listing.splitlines() if line.strip()]

    h2("Printing a few paths:")
    normal(f"There are actually {len(filenames)}.")
    normal(" - " + "\n - ".join(filenames[:num_paths_display]))

    h1("Building the `per_split` Path dict.")
    per_split = collections.defaultdict(list)
    for path in tqdm.tqdm(filenames, desc="Building `per_split` dict."):
        split = pathlib.Path(path).name.split("_")[0]
        per_split[split].append(path)

    def shard_index(path):
        # Ad-hoc: "<split>_<index>.<ext>" -> index.
        return int(pathlib.Path(path).name.split("_")[1].split(".")[0])

    normal("Sorting the `per_split` lists.")
    for split in per_split:
        per_split[split].sort(key=shard_index)

    normal("Len per split for the per_split dict:")

    print({split: len(per_split[split]) for split in per_split})

    return per_split


def build_dataset(paths, context_window_size, split, batch_size, num_candidates=10):
    """Builds a batched `tf.data` dataset of parsed samples from TFRecord files.

    Every feature is stored as a serialized tensor in a scalar `tf.string`. It
    is decoded with `tf.io.parse_tensor` and given a static shape. The answer
    ids are only read for non-test splits.

    Args:
        paths: TFRecord file paths.
        context_window_size: Length of the token-id sequences.
        split: Split name; the test split has no answers and keeps its last
            partial batch.
        batch_size: Batch size.
        num_candidates: Number of cached retrieval candidates per sample (the
            leading dimension of the distances and retrieved ids). Defaults
            to 10.

    Returns:
        A `tf.data.Dataset` of dicts keyed by `constants.CTH5Fields.*`.
    """
    has_answers = split != constants.SplitChoices.test

    # key -> (dtype of the serialized tensor, its static shape)
    feature_specs = {
        constants.CTH5Fields.distances:
            (tf.float32, (num_candidates,)),
        constants.CTH5Fields.gpt2_retrieved_ids:
            (tf.int32, (num_candidates, context_window_size)),
        constants.CTH5Fields.gpt2_question_ids_inputs:
            (tf.int32, (context_window_size,)),
    }
    if has_answers:
        # A real 1-tuple; `(context_window_size)` would just be an int.
        feature_specs[constants.CTH5Fields.gpt2_answer_ids_inputs] = (
            tf.int32, (context_window_size,)
        )

    description = {
        key: tf.io.FixedLenFeature((), tf.string) for key in feature_specs
    }

    @tf.function
    def parse(sample):
        example = tf.io.parse_single_example(sample, description)
        output = {}
        for key, serialized in example.items():
            dtype, shape = feature_specs[key]
            output[key] = tf.io.parse_tensor(serialized, out_type=dtype)
            output[key].set_shape(shape)
        return output

    ds = tf.data.TFRecordDataset(paths)
    ds = ds.map(
        parse,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        deterministic=False
    )

    ds = ds.batch(
        batch_size,
        drop_remainder=has_answers
    )

    return ds


def decode_line(tokenizer, line):
    """Decodes `line` with `tokenizer`, skipping negative ids (padding/masks)."""
    kept_ids = list(filter(lambda token_id: token_id >= 0, line))
    return tokenizer.decode(kept_ids)


def is_all_neg(tensor):
    """Returns whether every element of `tensor` is negative.

    Accepts lists/tuples, NumPy arrays and eager TF tensors: all of them go
    through `np.asarray`, so no private TensorFlow type is needed. An empty
    input counts as all-negative (vacuous truth), like the builtin `all`.
    """
    return bool(np.all(np.asarray(tensor) < 0))
    
    
def check_and_decode(feature_key, item, tokenizer):
    """Decodes `item[feature_key]` after asserting it is not all-negative.

    Raises AssertionError (with `feature_key` as the message) if every id is
    negative, presumably meaning the feature holds only padding.
    """
    token_ids = item[feature_key]
    assert not is_all_neg(token_ids), feature_key
    return decode_line(tokenizer, token_ids)
    
    
def display_item(major, minor, max_minor, tokenizer, item, split):
    """Prints one sample as a `rich` table: question, retrievals, then answer.

    Args:
        major: Outer index shown in the table title.
        minor: Inner index, shown as `[minor/max_minor]` in the title.
        max_minor: Upper bound shown next to `minor`.
        tokenizer: Decodes the token ids.
        item: Object whose `vars()` maps `constants.CTH5Fields.*` names to
            token-id sequences. NOTE(review): `Sample`'s field names
            (`gpt2_question_ids`, ...) may not match these keys — confirm.
        split: Split name; the answer is decoded and shown unless it is "test".
    """
    fields = vars(item)

    ############################################################################
    # Produce information
    ############################################################################
    question = check_and_decode(
        constants.CTH5Fields.gpt2_question_ids_inputs,
        fields,
        tokenizer
    )

    answer = None
    if split != "test":
        answer = check_and_decode(
            constants.CTH5Fields.gpt2_answer_ids_inputs,
            fields,
            tokenizer
        )

    retrieved_segments = [
        decode_line(tokenizer, line)
        for line in fields[constants.CTH5Fields.gpt2_retrieved_ids]
    ]

    ############################################################################
    # Display
    ############################################################################
    console = rich.console.Console()
    table = rich.table.Table(
        title=f"{major}:[{minor}/{max_minor}] - Item from split `{split}`",
        show_lines=True,
    )
    table.add_column("Field", style="bold")
    table.add_column("Value")
    table.add_row("Question:", question)
    for i, segment in enumerate(retrieved_segments):
        table.add_row(f"Retrieved segment {i}:", segment)

    if answer:
        table.add_row("Answer:", answer)

    console.print(table)
    
    

def check_all_unique(iterable):
    """Checks that `iterable` yields no duplicate (hashable) items.

    Consumes `iterable` exactly once, so generators and progress-bar wrappers
    work. Memory and computation scale in O(N) with N the number of items.
    On duplicates, fails through `utils.check_equal`.
    """
    seen = set()
    item_count = 0
    for item_count, element in enumerate(iterable, start=1):
        seen.add(element)

    utils.check_equal(item_count, len(seen))

    

def check_still_got_tpus():
    """Asserts that exactly one TPU named after this host is listed in `_ZONE`.

    Runs `gcloud compute tpus list --zone _ZONE` and counts the rows whose
    first column equals `socket.gethostname()`. The first column is assumed to
    be NAME in gcloud's default table output — TODO confirm for your gcloud
    version.

    Raises:
        AssertionError: If the count is not exactly 1.
        subprocess.CalledProcessError: If `gcloud` itself fails.
    """
    tpu_name = socket.gethostname()
    listing = subprocess.check_output(
        ["gcloud", "compute", "tpus", "list", "--zone", _ZONE]
    ).decode()
    # Exact match on the name column: a substring match (grep) would also
    # count other TPUs whose names merely contain this hostname.
    instance_count = sum(
        1 for line in listing.splitlines() if line.split()[:1] == [tpu_name]
    )
    assert instance_count == 1, (
        f"instance count: {instance_count}"
    )
    
    
    
def check_setup_is_as_expected(strategy):
    """Sanity-checks that the accelerator setup matches `_ACCEL_TYPE`.

    For "TPU":
    - the first device from `tf_utils.devices_to_use()` must be a TPU;
    - `strategy` must be a `tf.distribute.TPUStrategy`;
    - `check_still_got_tpus()` must pass.

    "CPU" needs no check.

    Raises:
        AssertionError: If a TPU check fails.
        ValueError: If `_ACCEL_TYPE` is neither "TPU" nor "CPU".
    """
    if _ACCEL_TYPE == "TPU":
        # Read once so the failure message reports the value that was checked
        # (the message previously referenced the undefined name `f_utils`).
        device_type = tf_utils.devices_to_use()[0].device_type
        assert device_type == "TPU", device_type
        assert isinstance(strategy, tf.distribute.TPUStrategy), strategy
        check_still_got_tpus()
    elif _ACCEL_TYPE != "CPU":
        raise ValueError(_ACCEL_TYPE)

    rich.print(f"[blue] < Things are good : {_ACCEL_TYPE} > [/]")
    
    
    
@dataclasses.dataclass
class Sample:
    """Container for one dataset sample; every field is annotated as `tf.Tensor`.

    NOTE(review): not instantiated in the code visible here. `display_item`
    reads items via `vars(item)` with `constants.CTH5Fields.*` keys (e.g.
    `gpt2_question_ids_inputs`), which may not match these field names —
    confirm.
    """
    # Presumably the retrieval distances/scores per candidate — TODO confirm.
    distances: tf.Tensor
    # Presumably the answer's GPT-2 token ids.
    gpt2_answer_ids: tf.Tensor
    # Presumably the question's GPT-2 token ids.
    gpt2_question_ids: tf.Tensor
    # Presumably the GPT-2 token ids of the retrieved passages.
    gpt2_retrieved_ids: tf.Tensor

        
print("done")

done


In [5]:
###############################################################################
# Long configuration stuff
###############################################################################


#------------------------------------------------------------------------------
# TPU Stuff
#------------------------------------------------------------------------------
# Build the distribution `strategy` for `_ACCEL_TYPE`.
if _ACCEL_TYPE == "TPU":
    tpu_name = socket.gethostname()
    check_still_got_tpus()
    tpu_setup = tf_utils.init_tpus(tpu_name)
    utils.check_equal(tf_utils.devices_to_use()[0].device_type, "TPU")
    # Use the configured replica count rather than a second hard-coded 8.
    utils.check_equal(len(tf_utils.devices_to_use()), _NUM_REPLICAS)
    strategy = tf.distribute.TPUStrategy(tpu_setup.resolver)

elif _ACCEL_TYPE == "CPU":
    device = tf_utils.devices_to_use()[0]
    utils.check_equal(len(tf_utils.devices_to_use()), 1)
    utils.check_equal(device.device_type, "CPU")
    strategy = tf.distribute.OneDeviceStrategy(device)

else:
    raise RuntimeError(_ACCEL_TYPE)


#------------------------------------------------------------------------------
# Huggingface Stuff
#------------------------------------------------------------------------------
model_config = transformers.AutoConfig.from_pretrained(_MODEL_TYPE)
tokenizer = transformers.GPT2TokenizerFast.from_pretrained(_MODEL_TYPE)
splits_to_ds_paths = build_split_to_ds_paths(_PROJECT_DIRECTORY, _NUM_PATHS_DISPLAY)
print("Making sure all paths are unique")
for k, v in splits_to_ds_paths.items():
    print(k)
    check_all_unique(tqdm.tqdm(v))
# Model context length; used as the sequence length everywhere below.
context_window_size = model_config.n_ctx

print("done")

INFO:absl:Entering into master device scope: /job:worker/replica:0/task:0/device:CPU:0


INFO:tensorflow:Initializing the TPU system: jules


INFO:tensorflow:Initializing the TPU system: jules


INFO:tensorflow:Clearing out eager caches


INFO:tensorflow:Clearing out eager caches


INFO:tensorflow:Finished initializing TPU system.


INFO:tensorflow:Finished initializing TPU system.


INFO:tensorflow:Found TPU system:


INFO:tensorflow:Found TPU system:


INFO:tensorflow:*** Num TPU Cores: 8


INFO:tensorflow:*** Num TPU Cores: 8


INFO:tensorflow:*** Num TPU Workers: 1


INFO:tensorflow:*** Num TPU Workers: 1


INFO:tensorflow:*** Num TPU Cores Per Worker: 8


INFO:tensorflow:*** Num TPU Cores Per Worker: 8


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


# Getting filenames.

#### Loading json config.

#### Calling `gsutil ls` on the dataset repo.

#### Printing a few paths:

There are actually 8192.

 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_0.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_1.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_10.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_100.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_1000.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_1001.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_1002.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_1003.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_1004.tfr
 - gs://julesgm-research-v3/tfrecord_query_cache/20210225-191356/eval_1005.tfr

# Building the `per_split` Path dict.

HBox(children=(FloatProgress(value=0.0, description='Building `per_split` dict.', max=8192.0, style=ProgressSt…




Sorting the `per_split` lists.

Len per split for the per_split dict:

{'eval': 2048, 'test': 2048, 'train': 2048, 'validation': 2048}
Making sure all paths are unique
eval


HBox(children=(FloatProgress(value=0.0, max=2048.0), HTML(value='')))


test


HBox(children=(FloatProgress(value=0.0, max=2048.0), HTML(value='')))


train


HBox(children=(FloatProgress(value=0.0, max=2048.0), HTML(value='')))


validation


HBox(children=(FloatProgress(value=0.0, max=2048.0), HTML(value='')))


done


In [6]:
check_still_got_tpus()

# One preprocessing closure per split; only "train" is built for now.
maybe_retrieve_and_merge = {
    split_name: _make_maybe_retrieve_and_merge_fn(
        tokenizer=tokenizer,
        context_size=context_window_size,
        ds_split=split_name,
        approach_type=_APPROACH_TYPE,  # FLAG_APPROACH_TYPE.value
        use_helper_words=True,  # FLAG_USE_HELPER_WORDS
        retriever=None,  # pylint: disable=unused-argument
        temperature=0.03,
        num_retrievals=5,
        enable_debug_checks=False,
        max_length_generation=350,
        tf_function_kwargs=None,
    )
    for split_name in ("train",)
}

print("done")


def prepare_special_token(bpe_id, eos_token_id=None):
    """Returns a printable placeholder for a special (non-text) token id.

    Args:
        bpe_id: Token id to render.
        eos_token_id: EOS id to compare against. Defaults to the notebook-level
            `tokenizer.eos_token_id`.

    Returns:
        " << eos >> " for the EOS id, or " << {bpe_id} >> " for a negative id
        (e.g. a -100 label mask).

    Raises:
        ValueError: For any other id.
    """
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    if bpe_id == eos_token_id:
        return " << eos >> "
    if bpe_id < 0:
        # Must format `bpe_id`; the placeholder used an undefined name `x`.
        return f" << {bpe_id} >> "
    raise ValueError(bpe_id)
    

    
def filter_and_decode(tokenizer, token_ids):
    """Decodes `token_ids`, skipping negative ids and the tokenizer's EOS id."""
    kept = []
    for token_id in token_ids:
        if token_id < 0 or token_id == tokenizer.eos_token_id:
            continue
        kept.append(token_id)
    return tokenizer.decode(kept)
    
    
def format_output_text(text, mode="normal", echo=True):
    """Formats decoded model text for display.

    Args:
        text: Decoded text.
        mode: "normal" or "plain".
            - "normal": bolds the "Question:", "Answer:" and "Context:" helper
              words (each starts a new paragraph) and collapses runs of three
              or more newlines into a single blank line.
            - "plain": returns `text` unchanged.
        echo: In "normal" mode, also print the result with
            `rich.jupyter.print` (the default, as before).

    Returns:
        The formatted text, stripped of surrounding whitespace in "normal" mode.

    Raises:
        RuntimeError: For an unknown `mode`.
    """
    if mode == "plain":
        return text
    if mode != "normal":
        raise RuntimeError(mode)

    for word in ("Question:", "Answer:", "Context:"):
        text = text.replace(word, f"\n\n**{word}**")
    # The quantifier must be written without a space: `{3, +}` is parsed by
    # `re` as literal characters and never matches newline runs.
    text = re.sub(r"\n{3,}", "\n\n", text)

    if echo:
        rich.jupyter.print(text)

    return text.strip()
        

done


In [7]:
print("Starting", flush=True)
check_setup_is_as_expected(strategy)
console = rich.console.Console()
vocab = tokenizer.get_vocab()
# token id -> BPE string, used to show every position of inputs/labels.
inversed_vocab = {v: k for k, v in vocab.items()}
print("Checked", flush=True)


for split in [
    # "eval",
    # "test",
    "train"
]:
    ds_paths = splits_to_ds_paths[split]
    ###########################################################################
    # Build and Distribute the DS
    ###########################################################################
    print("Building DS")
    # Global batch = number of devices, so each replica gets one sample.
    ds = build_dataset(
        ds_paths,
        context_window_size,
        split,
        len(tf_utils.devices_to_use())
    )

    ds = ds.map(maybe_retrieve_and_merge[split])
    dds = strategy.experimental_distribute_dataset(ds)

    print("Starting loop.", flush=True)
    num_batches = 1 if _ACCEL_TYPE == "TPU" else 8
    for major, dist_items in enumerate(toolz.take(num_batches, dds)):
        # `experimental_local_results` returns the per-replica values of a
        # distributed value, or a 1-tuple for a non-distributed one, so no
        # branching on the private `PerReplica` type is needed.
        per_replica_inputs = strategy.experimental_local_results(
            dist_items["input_ids"]
        )
        per_replica_labels = strategy.experimental_local_results(
            dist_items["label_ids"]
        )
        utils.check_equal(len(per_replica_inputs), len(per_replica_labels))

        print("inputs, labels loop", flush=True)
        print(type(dist_items["input_ids"]), flush=True)
        for inputs, labels in zip(per_replica_inputs, per_replica_labels):
            utils.check_equal(inputs.shape[0], 1)
            utils.check_equal(labels.shape[0], 1)
            inputs = inputs[0]
            labels = labels[0]

            input_tokens = [
                inversed_vocab.get(x, f"<< {x} >>") for x in inputs.numpy().tolist()
            ]
            label_tokens = [
                inversed_vocab.get(x, f"<< {x} >>") for x in labels.numpy().tolist()
            ]

            # One row per position: index, input token, label token.
            table = rich.table.Table(show_header=False, show_lines=True)
            for index, input_token, label_token in zip(
                map(str, range(context_window_size)), input_tokens, label_tokens
            ):
                table.add_row(index, input_token, label_token)

            rich.jupyter.print(table)


print("done")

Starting


Checked
Building DS
Starting loop.


UnavailableError: Socket closed
Additional GRPC error information from remote target /job:worker/replica:0/task:0:
:{"created":"@1617243730.211673429","description":"Error received from peer ipv4:10.79.32.74:8470","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Socket closed","grpc_status":14}