<a href="https://colab.research.google.com/github/JasonObeid/InteractiveVisualizations/blob/master/T5%20Rotowire.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


<a href="https://colab.research.google.com/github/google-research/text-to-text-transfer-transformer/blob/master/notebooks/t5-trivia.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2019 The T5 Authors

Licensed under the Apache License, Version 2.0 (the "License");

In [0]:
# Copyright 2019 The T5 Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Fine-Tuning the Text-To-Text Transfer Transformer (T5) for Context-Free Trivia
## _Or: What does T5 know?_

*The following tutorial guides you through the process of fine-tuning a pre-trained T5 model, evaluating its accuracy, and using it for prediction,
all on a free Google Cloud TPU <a href="https://colab.research.google.com/github/google-research/text-to-text-transfer-transformer/blob/master/notebooks/t5-trivia.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>.*

### Background

T5 was introduced in the paper [_Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer_](https://arxiv.org/abs/1910.10683). In this paper, we provide a comprehensive picture of how we pre-trained a standard text-to-text Transformer model on a large text corpus, achieving state-of-the-art results on many NLP tasks after fine-tuning.

We pre-trained T5 on a mixture of supervised and unsupervised tasks with the majoriy of data coming from an unlabeled dataset we developed called [C4](https://www.tensorflow.org/datasets/catalog/c4). C4 is based on a massive scrape of the web produced by [Common Crawl](https://commoncrawl.org). Loosely speaking, pre-training on C4 ideally gives T5 an understanding of natural language in addition to general world knowledge.

### How can we assess what T5 knows?

As the name implies, T5 is a text-to-text model, which enables us to train it on arbitrary tasks involving a textual input and output. As we showed in our paper, a huge variety of NLP tasks can be cast in this format, including translation, summarization, and even classification and regression tasks.

One way to use this text-to-text framework is on question-answering problems, where the model is fed some context along with a question and is trained to predict the question's answer. For example, we might feed the model the text from the Wikipedia article about [Hurrican Connie](https://en.wikipedia.org/wiki/Hurricane_Connie) along with the question "On what date did Hurricane Connie occur?" and train the model to predict the answer "August 3rd, 1955".

In this notebook, we'll be training T5 on a variant of this task which we call **context-free question answering (QA)**. In context-free QA, we feed the model a question *without any context* and train it to predict the answer. Since the model doesn't receive any context, the primary way it can learn to answer these questions is based on the "knowledge" it obtained during pre-training. We don't expect T5 to contain super specific information, so we will be focusing on two question-answering datasets which largely include trivia questions (i.e. facts about well-known subjects). [Similar](https://arxiv.org/abs/1909.01066) [investigations](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) have recently been done on BERT and GPT-2.

T5 was not pre-trained on context-free QA, so in this notebook we'll first create two new tasks and then use the [`t5`](https://github.com/google-research/text-to-text-transfer-transformer) library to fine-tune, evaluate, and obtain predictions from T5. In the end, T5's performance on this context-free trivia QA can give us a sense of what kind (and how much) information T5 managed to learn during pre-training.


### Caveats

* While we provide instructions for running on a [Cloud TPU](https://cloud.google.com/tpu/) via Colab for free, a [Google Cloud Storage (GCS)](http://console.cloud.google.com/storage) bucket is required for storing model parameters and data. The [GCS free tier](https://cloud.google.com/free/) provides 5 GB of storage, which should be enough to train the `large` model and smaller but not the `3B` or `11B` parameter models. You can use part of your initial $300 credit to get more space.
* The Cloud TPU provided by Colab (a `v2-8`) does not have enough memory to fine-tune the `11B` parameter model. For this model, you will need to fine-tune inside of a GCP instance (see [README](https://github.com/google-research/text-to-text-transfer-transformer/)).


# Set Up

<h3><a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>  &nbsp;&nbsp;Train on TPU</h3>




   1. Create a Cloud Storage bucket for your data and model checkpoints at http://console.cloud.google.com/storage, and fill in the `BASE_DIR` parameter in the following form. There is a [free tier](https://cloud.google.com/free/) if you do not yet have an account.
 
   1. On the main menu, click Runtime and select **Change runtime type**. Set "TPU" as the hardware accelerator.
   1. Run the following cell and follow instructions to:
    *  Set up a Colab TPU running environment
    *   Verify that you are connected to a TPU device
    *   Upload your credentials to TPU to access your GCS bucket


In [1]:
import datetime
import functools
import json
import os
import pprint
import random
import string
import sys
import tensorflow as tf

BASE_DIR = "gs://t5storage" #@param { type: "string" }
DATA_DIR = os.path.join(BASE_DIR, "data")
MODELS_DIR = os.path.join(BASE_DIR, "models")
ON_CLOUD = True


if ON_CLOUD:
  assert "COLAB_TPU_ADDR" in os.environ, "ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!"
  TPU_ADDRESS = "grpc://" + os.environ["COLAB_TPU_ADDR"] 
  TPU_TOPOLOGY = "2x2"
  print("TPU address is", TPU_ADDRESS)

  from google.colab import auth
  auth.authenticate_user()
  with tf.Session(TPU_ADDRESS) as session:
    print('TPU devices:')
    pprint.pprint(session.list_devices())

    # Upload credentials to TPU.
    with open('/content/adc.json', 'r') as f:
      auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
    # Now credentials are set for all future sessions on this TPU.

TPU address is grpc://10.122.200.90:8470
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

TPU devices:
[_DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:CPU:0, CPU, -1, 1681300757859025885),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 10214636958195339468),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 8813983523005277219),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 8084431980772698488),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 16355736990175730119),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0

In [2]:
#@title Install and import required packages
if ON_CLOUD:
  !pip install -qU "t5>=0.1.7"
  !pip install -U "tfds-nightly>=1.3.2.dev201912070105"

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import t5
import tensorflow as tf
import tensorflow_datasets as tfds
import time

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  tf.get_logger().propagate = False
  py_logging.root.setLevel('INFO')

@contextmanager
def tf_verbosity_level(level):
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  yield
  tf.logging.set_verbosity(og_level)

[K     |████████████████████████████████| 122kB 3.4MB/s 
[K     |████████████████████████████████| 8.6MB 8.8MB/s 
[K     |████████████████████████████████| 1.0MB 55.4MB/s 
[K     |████████████████████████████████| 7.6MB 46.4MB/s 
[K     |████████████████████████████████| 245kB 47.5MB/s 
[K     |████████████████████████████████| 163kB 50.1MB/s 
[K     |████████████████████████████████| 256kB 49.4MB/s 
[K     |████████████████████████████████| 51kB 6.8MB/s 
[K     |████████████████████████████████| 133kB 50.7MB/s 
[K     |████████████████████████████████| 61kB 7.8MB/s 
[K     |████████████████████████████████| 204kB 47.1MB/s 
[?25h  Building wheel for overrides (setup.py) ... [?25l[?25hdone
  Building wheel for word2number (setup.py) ... [?25l[?25hdone
  Building wheel for jsonnet (setup.py) ... [?25l[?25hdone
  Building wheel for parsimonious (setup.py) ... [?25l[?25hdone
  Building wheel for numpydoc (setup.py) ... [?25l[?25hdone
  Building wheel for ftfy (setup.p

# Creating new Tasks and Mixture

Two core components of the T5 library are `Task` and `Mixture` objects.

A `Task` is a dataset along with preprocessing functions and evaluation metrics. A `Mixture` is a collection of `Task` objects along with a mixing rate or a function defining how to compute a mixing rate based on the properties of the constituent `Tasks`.

For this example, we will fine-tune the model to do context-free trivia question answering.

In [30]:
import pandas as pd
import numpy as np

from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)
root_dir = "/content/gdrive/My Drive/"
base_dir = root_dir + 'rotowire/'

train = ['src_train.txt', 'tgt_train.txt']
valid = ['src_valid.txt', 'tgt_valid.txt']
test = ['src_test.txt', 'tgt_test.txt']

trainFrame = pd.DataFrame()
validFrame = pd.DataFrame()
testFrame = pd.DataFrame()

src_path = os.path.join(base_dir, valid[0])
tgt_path = os.path.join(base_dir, valid[1])
with open(src_path) as src:
  with open(tgt_path) as tgt:
    for src_line, tgt_line in zip(src, tgt):
      row = pd.Series(index=["data", "target"], data=[src_line, tgt_line])
      validFrame = validFrame.append(row, ignore_index=True)

src_path = os.path.join(base_dir, test[0])
tgt_path = os.path.join(base_dir, test[1])
with open(src_path) as src:
  with open(tgt_path) as tgt:
    for src_line, tgt_line in zip(src, tgt):
      row = pd.Series(index=["data", "target"], data=[src_line, tgt_line])
      testFrame = testFrame.append(row, ignore_index=True)

src_path = os.path.join(base_dir, train[0])
tgt_path = os.path.join(base_dir, train[1])
with open(src_path) as src:
  with open(tgt_path) as tgt:
    for src_line, tgt_line in zip(src, tgt):
      row = pd.Series(index=["data", "target"], data=[src_line, tgt_line])
      trainFrame = trainFrame.append(row, ignore_index=True)

Mounted at /content/gdrive


Next, we define a function to load the TSV data as a `tf.data.Dataset` in TensorFlow.

In [0]:
RT_JSON_DIR = os.path.join(DATA_DIR, "rotowire")
rt_tsv_path = {
    "train": os.path.join(RT_JSON_DIR, "train.tsv"),
    "test": os.path.join(RT_JSON_DIR, "test.tsv"),
    "validation": os.path.join(RT_JSON_DIR, "valid.tsv")
}

In [28]:
x = validFrame.to_dict()
print(x)

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [85]:
def nq_dataset_fn(split, shuffle_files=False):
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  if(split == "validation"):
    #validDs = validFrame.to_dict()
    ds = tf.data.Dataset.from_tensor_slices(validFrame.values)
  if(split == "test"):
    #testDs = testFrame.to_dict()
    ds = tf.data.Dataset.from_tensor_slices(validFrame.values)
  if(split == "train"):
    #trainDs = trainFrame.to_dict()
    ds = tf.data.Dataset.from_tensor_slices(validFrame.values)
  ds = ds.map(lambda x: dict(data=x[0],text=x[1]))
  return ds

print("A few raw validation examples...")
for ex in tfds.as_numpy(nq_dataset_fn("validation").take(10)):
  print(ex)

A few raw validation examples...
{'data': b'F\xef\xbf\xa8DeMarre_Carroll\xef\xbf\xa8START_POSITION\xef\xbf\xa8HOME 21\xef\xbf\xa8DeMarre_Carroll\xef\xbf\xa8MIN\xef\xbf\xa8HOME 10\xef\xbf\xa8DeMarre_Carroll\xef\xbf\xa8PTS\xef\xbf\xa8HOME 4\xef\xbf\xa8DeMarre_Carroll\xef\xbf\xa8FGM\xef\xbf\xa8HOME 5\xef\xbf\xa8DeMarre_Carroll\xef\xbf\xa8FGA\xef\xbf\xa8HOME 80\xef\xbf\xa8DeMarre_Carroll\xef\xbf\xa8FG_PCT\xef\xbf\xa8HOME 2\xef\xbf\xa8DeMarre_Carroll\xef\xbf\xa8FG3M\xef\xbf\xa8HOME 3\xef\xbf\xa8DeMarre_Carroll\xef\xbf\xa8FG3A\xef\xbf\xa8HOME 67\xef\xbf\xa8DeMarre_Carroll\xef\xbf\xa8FG3_PCT\xef\xbf\xa8HOME 0\xef\xbf\xa8DeMarre_Carroll\xef\xbf\xa8FTM\xef\xbf\xa8HOME 0\xef\xbf\xa8DeMarre_Carroll\xef\xbf\xa8FTA\xef\xbf\xa8HOME 0\xef\xbf\xa8DeMarre_Carroll\xef\xbf\xa8FT_PCT\xef\xbf\xa8HOME 1\xef\xbf\xa8DeMarre_Carroll\xef\xbf\xa8OREB\xef\xbf\xa8HOME 4\xef\xbf\xa8DeMarre_Carroll\xef\xbf\xa8DREB\xef\xbf\xa8HOME 5\xef\xbf\xa8DeMarre_Carroll\xef\xbf\xa8REB\xef\xbf\xa8HOME 3\xef\xbf\xa8DeMarre_Carrol

Now, we write a preprocess function to convert the examples in the `tf.data.Dataset` into a text-to-text format, with both `inputs` and `targets` fields. The preprocessor also normalizes the text by lowercasing it and removing quotes since the answers are sometimes formatted in odd ways. Finally, we prepend 'trivia question:' to the inputs so that the model knows what task it's trying to solve.

In [0]:
def trivia_preprocessor(ds):
  def normalize_text(text):
    """Lowercase and remove quotes from a TensorFlow string."""
    text = tf.strings.lower(text)
    text = tf.strings.regex_replace(text,"'(.*)'", r"\1")
    return text

  def to_inputs_and_targets(ex):
    """Map {"question": ..., "answer": ...}->{"inputs": ..., "targets": ...}."""
    return {
        "inputs":
             tf.strings.join(
                 ["rotowire: ", normalize_text(ex["data"])]),
        "targets": normalize_text(ex["text"])
    }
  return ds.map(to_inputs_and_targets, 
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [0]:
rt_tsv_path = {
    "train": trainFrame,
    "test": testFrame,
    "validation": validFrame
}
num_nq_examples = testFrame.shape[0] + trainFrame.shape[0] + validFrame.shape[0]

Finally, we put everything together to create a `Task`.

In [88]:
t5.data.TaskRegistry.add(
    "nq_context_free",
    # Supply a function which returns a tf.data.Dataset.
    dataset_fn=nq_dataset_fn,
    splits=["train", "validation", "test"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[trivia_preprocessor],
    # Use the same vocabulary that we used for pre-training.
    sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
    # Lowercase targets before computing metrics.
    postprocess_fn=t5.data.postprocessors.lower_text, 
    # We'll use accuracy as our evaluation metric.
    metric_fns=[t5.evaluation.metrics.accuracy],
    # Not required, but helps for mixing and auto-caching.
    num_input_examples=num_nq_examples
)

ValueError: ignored

Let's look at a few pre-processed examples from the validation set. Note they contain both the tokenized (integer) and plain-text inputs and targets.


In [90]:
nq_task = t5.data.TaskRegistry.get("nq_context_free")
ds = nq_task.get_dataset(split="validation", sequence_length={"inputs": 128, "targets": 32})
print("A few preprocessed validation examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)

TypeError: ignored

## TriviaQA

A second dataset we will use is related to [TriviaQA](https://nlp.cs.washington.edu/triviaqa/). It is also intended for reading comprehension, but, once again, we will modify the task here by ignoring the provided context.

Since the dataset has been imported into [TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets/catalog/trivia_qa), we can let it handle the data parsing for us. It will take a few minutes to download and preprocess the first time, but we'll be able to access it instantly from our data directory afterward.

In [0]:
ds = tfds.load(
    "trivia_qa/unfiltered.nocontext",
    data_dir=DATA_DIR,
    # Download data locally for preprocessing to avoid using GCS space.
    download_and_prepare_kwargs={"download_dir": "./downloads"})
print("A few raw validation examples...")
for ex in tfds.as_numpy(ds["validation"].take(2)):
  print(ex)

INFO:absl:Load pre-computed datasetinfo (eg: splits) from bucket.
INFO:absl:Generating dataset trivia_qa (gs://t5storage/data/trivia_qa/unfiltered.nocontext/1.1.0)


[1mDownloading and preparing dataset trivia_qa (?? GiB) to gs://t5storage/data/trivia_qa/unfiltered.nocontext/1.1.0...[0m


HBox(children=(IntProgress(value=1, bar_style='info', description='Dl Completed...', max=1, style=ProgressStyl…

HBox(children=(IntProgress(value=1, bar_style='info', description='Dl Size...', max=1, style=ProgressStyle(des…

HBox(children=(IntProgress(value=1, bar_style='info', description='Extraction completed...', max=1, style=Prog…

INFO:absl:Downloading http://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz into ./downloads/nlp.cs.washin.edu_trivia_trivia-unfiltFwGuJqGCvLUAj7fbBJcIJwITr8d6aTz4A3xk1zDUQ-A.tar.gz.tmp.988c39b998d44fb5a052cfdc83ceaffc...
INFO:absl:Generating split train








HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

INFO:absl:generating examples from = ./downloads/extracted/TAR_GZ.nlp.cs.washin.edu_trivia_trivia-unfiltfwyUIH_Qoet7uj1Szf4HNcmN6FC55apOdfJ3bvmlMdA.tar.gz/triviaqa-unfiltered/unfiltered-web-train.json


As with Natural Questions, we need to preprocess the raw examples into `inputs` and `targets` features. We can reuse the `trivia_preprocessor` above, but first we need to convert the TriviaQA examples into the correct format, ignoring the fields we don't need for our task.

We'll then define our `Task` and print out a few preprocessed examples from the validation set.

Note that we do not need to specify the splits or number of examples since that information is provided by TFDS.

In [0]:
def tiviaqa_extract_qa(ds):
  def exract_qa(ex):
    return {
        "question": ex["question"],
        "answer": ex["answer"]["value"]
    }
  return ds.map(exract_qa, num_parallel_calls=tf.data.experimental.AUTOTUNE)

t5.data.TaskRegistry.add(
    "triviaqa_context_free",
    # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
    t5.data.TfdsTask,
    tfds_name="trivia_qa/unfiltered.nocontext:1.1.0",
    tfds_data_dir=DATA_DIR,
    sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
    text_preprocessor=[tiviaqa_extract_qa, trivia_preprocessor],
    postprocess_fn=t5.data.postprocessors.lower_text,
    metric_fns=[t5.evaluation.metrics.accuracy]
)

# Load and print a few examples.
triviaqa_task = t5.data.TaskRegistry.get("triviaqa_context_free")
ds = triviaqa_task.get_dataset(split="validation", sequence_length={"inputs": 128, "targets": 32})
print("A few preprocessed validation examples...")
for ex in tfds.as_numpy(ds.take(3)):
  print(ex)

NameError: ignored

## Dataset Mixture

We now create a `Mixture` from the above `Tasks`, which we will fine-tune on.

There are different ways to automatically set the rate (for example, based on the number of examples using `rate_num_examples`), but we will just hardcode an equal mixture for simplicity.

In [0]:
t5.data.MixtureRegistry.remove("trivia_all")
t5.data.MixtureRegistry.add(
    "trivia_all",
    ["nq_context_free", "triviaqa_context_free"],
     default_rate=1.0
)

# Transferring to new Tasks

We are now ready to fine-tune one of the pre-trained T5 models on our new mixture of context-free QA tasks.

First, we'll instantiate a `Model` object using the model size of your choice. Note that larger models are slower to train and use but will likely achieve higher accuracy. You also may be able to increase accuracy by training longer with more `FINETUNE_STEPS` below.


## Caveats

* Due to its memory requirements, you will not be able to train the `11B` parameter model on the TPU provided by Colab. Instead, you will need to fine-tune inside of a GCP instance (see [README](https://github.com/google-research/text-to-text-transfer-transformer/)).
* Due to the checkpoint size, you will not be able use the 5GB GCS free tier for the `3B` parameter models. You will need at least 25GB of space, which you can purchase with your $300 of initial credit on GCP.
* While `large` can achieve decent results, it is recommended that you fine-tune at least the `3B` parameter model.


## Define Model

In [0]:
MODEL_SIZE = "small" #@param["small", "base", "large", "3B", "11B"]
# Public GCS path for T5 pre-trained model checkpoints
BASE_PRETRAINED_DIR = "gs://t5-data/pretrained_models"
PRETRAINED_DIR = os.path.join(BASE_PRETRAINED_DIR, MODEL_SIZE)
MODEL_DIR = os.path.join(MODELS_DIR, MODEL_SIZE)

if ON_CLOUD and MODEL_SIZE == "3B":
  tf.logging.warn(
      "The `3B` model is too large to use with the 5GB GCS free tier. "
      "Make sure you have at least 25GB on GCS before continuing."
  )
elif ON_CLOUD and MODEL_SIZE == "11B":
  raise ValueError(
      "The `11B` parameter is too large to fine-tune on the `v2-8` TPU "
      "provided by Colab. Please comment out this Error if you're running "
      "on a larger TPU."
  )

# Set parallelism and batch size to fit on v2-8 TPU (if possible).
# Limit number of checkpoints to fit within 5GB (if possible).
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1)}[MODEL_SIZE]

tf.io.gfile.makedirs(MODEL_DIR)
# The models from our paper are based on the Mesh Tensorflow Transformer.
model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length={"inputs": 128, "targets": 32},
    learning_rate_schedule=0.003,
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)

Before we continue, let's load a [TensorBoard](https://www.tensorflow.org/tensorboard) visualizer so that we can keep monitor our progress. The page should automatically update as fine-tuning and evaluation proceed.

In [0]:
if ON_CLOUD:
  %reload_ext tensorboard
  import tensorboard as tb
tb.notebook.start("--logdir " + MODELS_DIR)

## Fine-tune

We are now ready to fine-tune our model. This will take a while (~2 hours with default settings), so please be patient! The larger the model and more `FINETUNE_STEPS` you use, the longer it will take.

Don't worry, you can always come back later and increase the number of steps, and it will automatically pick up where you left off.

In [0]:
FINETUNE_STEPS = 25000 #@param {type: "integer"}

model.finetune(
    mixture_or_task_name="trivia_all",
    pretrained_model_dir=PRETRAINED_DIR,
    finetune_steps=FINETUNE_STEPS
)

## Expected Results [SPOILER ALERT]

Below are the expected accuracies on the Natural Question (NQ) and TriviQA validation sets for various model sizes. The full 11B model achieves a maximum validation accuracy of 34.5% and 25.1% on TriviaQA and NQ, respectively. The 3B parameter model, which is the largest that can be trained with a free Cloud TPU in Colab, achieves 29.7% and 23.7%, respectively. These results may not sound very good, but this it is actually fairly impressive considering the model is given no context.

Furthermore, as you'll see in the Evaluate section, these numbers are actually a lower bound since the model often outputs the correct answer in a slightly different format than is expected, which is counted as incorrect. This helps to explain why the model appears to perform better on TriviaQA than NQ for our metric, since the latter tends to include more long-form answers extracted from the context.

<img src="https://storage.googleapis.com/t5-data/assets/t5_trivia_expected.png">

## Evaluate

We now evaluate on the validation sets of the tasks in our mixture. Accuracy results will be logged and added to the TensorBoard above.

In [0]:
# Use a larger batch size for evaluation, which requires less memory.
model.batch_size = train_batch_size * 4
model.eval(
    mixture_or_task_name="trivia_all",
    checkpoint_steps="all"
)

Let's look at a few random predictions from the validation sets. Note that we measure accuracy based on an *exact match* of the predicted answer and the ground-truth answer. As a result, some of the answers are semantically correct but are counted wrong by the exact match score.

In [0]:
def print_random_predictions(task_name, n=10):
  """Print n predictions from the validation split of a task."""
  # Grab the dataset for this task.
  ds = t5.data.TaskRegistry.get(task_name).get_dataset(
      split="validation",
      sequence_length={"inputs": 128, "targets": 32},
      shuffle=False)

  def _prediction_file_to_ckpt(path):
    """Extract the global step from a prediction filename."""
    return int(path.split("_")[-2])

  # Grab the paths of all logged predictions.
  prediction_files = tf.io.gfile.glob(
      os.path.join(
          MODEL_DIR,
          "validation_eval/%s_*_predictions" % task_name))
  # Get most recent prediction file by sorting by their step.
  latest_prediction_file = sorted(
      prediction_files, key=_prediction_file_to_ckpt)[-1]

  # Collect (inputs, targets, prediction) from the dataset and predictions file
  results = []
  with tf.io.gfile.GFile(latest_prediction_file) as preds:
    for ex, pred in zip(tfds.as_numpy(ds), preds):
      results.append((tf.compat.as_text(ex["inputs_plaintext"]),
                      tf.compat.as_text(ex["targets_plaintext"]),
                      pred.strip()))

  print("<== Random predictions for %s using checkpoint %s ==>\n" %
        (task_name, 
         _prediction_file_to_ckpt(latest_prediction_file)))

  for inp, tgt, pred in random.choices(results, k=10):
    print("Input:", inp)
    print("Target:", tgt)
    print("Prediction:", pred)
    print("Counted as Correct?", tgt == pred)
    print()

print_random_predictions("triviaqa_context_free")
print_random_predictions("nq_context_free")

<== Random predictions for triviaqa_context_free using checkpoint 1100000 ==>

Input: trivia question: jackpot counter, ghost drop and drop zone are all terms used in which uk television game show?
Target: tipping point
Prediction: countdown
Counted as Correct? False

Input: trivia question: cursed to sail around the cape of good hope, which ghost ship is the theme of an 1841 opera by richard wagner?
Target: the flying dutchman
Prediction: baron von munchhausen
Counted as Correct? False

Input: trivia question: at what fret are found the same notes as the open strings, but an octave higher, on a standard guitar?
Target: 12th
Prediction: 12th
Counted as Correct? True

Input: trivia question: how many legs does a ladybird have?
Target: six
Prediction: six
Counted as Correct? True

Input: trivia question: in which city’s harbour was the ship queen elizabeth ravaged by fire in 1972?
Target: hong kong
Prediction: hong kong
Counted as Correct? True

Input: trivia question: what are the three

## Predict

Now that we have fine-tuned the model, we can feed T5 arbitrary questions and have it predict the answers!

There is a significant amount of overhead in initializing the model so this may take a few minutes to run each time even though the prediction itself is quite fast.


To avoid this overhead, you might consider exporting a `SavedModel` and running it on [Cloud ML Engine](https://cloud.google.com/ml-engine/).



In [0]:
question_1 = "Where is the Google headquarters located?" #@param {type:"string"}
question_2 = "What is the most populous country in the world?" #@param {type:"string"}
question_3 = "Who are the 4 members of The Beatles?" #@param {type:"string"}
question_4 = "How many teeth do humans have?" #@param {type:"string"}

questions = [question_1, question_2, question_3, question_4]

now = time.time()
# Write out the supplied questions to text files.
predict_inputs_path = os.path.join(MODEL_DIR, "predict_inputs_%d.txt" % now)
predict_outputs_path = os.path.join(MODEL_DIR, "predict_outputs_%d.txt" % now)
# Manually apply preprocessing by prepending "triviaqa question:".
with tf.io.gfile.GFile(predict_inputs_path, "w") as f:
  for q in questions:
    f.write("trivia question: %s\n" % q.lower())

# Ignore any logging so that we only see the model's answers to the questions.
with tf_verbosity_level('ERROR'):
  model.batch_size = len(questions)
  model.predict(
      input_file=predict_inputs_path,
      output_file=predict_outputs_path,
      # Select the most probable output token at each step.
      temperature=0,
  )

# The output filename will have the checkpoint appended so we glob to get 
# the latest.
prediction_files = sorted(tf.io.gfile.glob(predict_outputs_path + "*"))
print("\nPredictions using checkpoint %s:\n" % prediction_files[-1].split("-")[-1])
with tf.io.gfile.GFile(prediction_files[-1]) as f:
  for q, a in zip(questions, f):
    if q:
      print("Q: " + q)
      print("A: " + a)
      print()


Predictions using checkpoint 1100000:

Q: Where is the Google headquarters located?
A: mountain view, california


Q: What is the most populous country in the world?
A: china


Q: Who are the 4 members of The Beatles?
A: john lennon, paul mccartney, george harrison and ringo starr


Q: How many teeth do humans have?
A: 30




## Export

As mentioned in the previous section, exporting a [`SavedModel`](https://www.tensorflow.org/guide/saved_model) can be useful for improving performance during inference or allowing your model to be deployed on a variety of platforms (e.g., TFLite, TensorFlow.js, TensorFlow Serving, or TensorFlow Hub).

In [0]:
model.export(
    os.path.join(MODEL_DIR, "export"),
    checkpoint_step=-1,  # use most recent
    beam_size=1,  # no beam search
    temperature=1.0,  # sample according to predicted distribution
)