In [1]:
import functools
import os

import clu.data.dataset_iterator
import tensorflow as tf
import jax
from jax import random
from jax.experimental import multihost_utils
import jax.numpy as jnp
from flax import linen
import numpy as np
import seqio
import t5.data
from t5.evaluation import metrics as t5_metrics

import t5x
from t5x import partitioning
from t5x import train_state as train_state_lib
from t5x import utils
from t5x.examples.t5 import network
from t5x.examples.scalable_t5 import network as scalable_network
from t5x.interactive_model import InteractiveModel
from t5x.interactive_model import get_batches_from_seqio
from t5x.interactive_model import InferenceType
import nest_asyncio
nest_asyncio.apply()

2023-04-14 22:07:54.577490: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2023-04-14 22:07:55.249181: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2023-04-14 22:07:55.249303: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib


In [11]:
# Define EncoderDecoderModel constructor args (except the module).
input_vocabulary=t5.data.get_default_vocabulary()
output_vocabulary=t5.data.get_default_vocabulary()
optimizer=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0, logical_factor_rules=t5x.adafactor.standard_logical_factor_rules())
decode_fn=functools.partial(t5x.decoding.temperature_sample, temperature=1.0, topk=40)

# Define a model using the minimal T5 module.
t5_module = network.Transformer(config=network.T5Config(
    vocab_size=32128,
    dtype='bfloat16',
    emb_dim=512,
    num_heads=6,
    num_encoder_layers=8,
    num_decoder_layers=8,
    head_dim=64,
    mlp_dim=1024,
    mlp_activations=('gelu', 'linear'),
    dropout_rate=0.0,
    logits_via_embedding=False))
model = t5x.models.EncoderDecoderModel(
    module=t5_module,
    input_vocabulary=input_vocabulary,
    output_vocabulary=output_vocabulary,
    optimizer_def=optimizer,
    decode_fn=decode_fn)

In [12]:
# The checkpoint below is a T5-1.1-Small checkpoint (https://github.com/google-research/t5x/blob/main/docs/models.md) 
# that has additionally been finetuned on the (Open Domain) Natural Questions 
# benchmark (https://ai.google.com/research/NaturalQuestions).
checkpoint_path='gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'
dtype='bfloat16'

restore_mode='specific'

In [13]:
partitioner=partitioning.PjitPartitioner(
        num_partitions=1,
        model_parallel_submesh=None)

In [91]:
batch_size=8
task_feature_lengths = {'inputs': 256, 'targets': 18}
output_dir='/tmp/output_dir'
input_shapes = {
    'encoder_input_tokens': np.array([8, 256]),
    'decoder_target_tokens': np.array([8, 18]),
    'decoder_input_tokens': np.array([8, 18]),
    'decoder_loss_weights': np.array([8, 18])
}

interactive_model = InteractiveModel(
  batch_size=batch_size,
  task_feature_lengths=task_feature_lengths,
  output_dir=output_dir,
  partitioner=partitioner,
  model=model,
  dtype=dtype,
  restore_mode=restore_mode,
  checkpoint_path=checkpoint_path,
  input_shapes=input_shapes
)



In [102]:
question = f'nq question: ' \
'1241523426247511241523426247512425123634624572345756484568456845686341241523426247512425123634624572345756484568456845686342425123634624572345756484568456845686347124152342624756347124152342624756347124152342624756347124152342624756347124152342624756347124152342624756347' \
'what is the capital of france?'

In [103]:
len(input_vocabulary.encode(question))

139

In [104]:
validation_examples = [
  {
      'target': 'Joe Biden', 
      'input':'nq question: who is the president of the united states'
  }, 
  {
      'target': 'F. Scott Fitzgerald', 
      'input': question},
  {
      'target': '1914', 
      'input': 'nq question: in what year did the first world war begin'}, 
  {
      'target': 'Idina Menzel', 
      'input': 'nq question: who does the voice of elsa in Frozen'}, 
  {
      'target': 'Taylor Swift', 
      'input': 'nq question: who sings shake it off'}, 
  {
      'target': 'Tom Kenny', 
      'input': 'nq question: who voices spongebob squarepants'}, 
  {
      'target': '2010', 
      'input': 'nq question: when did the great british bake off start'}, 
  {
      'target': 'the Philadelphia Eagles', 
      'input': 'nq question: who won the superbowl in 2018'},
]

In [105]:
examples_and_predictions, _ = interactive_model.predict_with_aux(examples=validation_examples)
predictions = [prediction for example, prediction in examples_and_predictions]
for val_example, prediction in zip(validation_examples, predictions):
    print(f"Prediction: {prediction.decode('utf-8')}\n")

examples_and_scores = interactive_model.score(examples=validation_examples)
scores = [score for example, score in examples_and_scores]
print(f"Scores: {scores}\n")

Prediction: Donald Trump

Prediction: Paris

Prediction: 28 July 1914

Prediction: Essie Davis

Prediction: The Mavericks

Prediction: Kelsey Grammer

Prediction: 1992

Prediction: Philadelphia Eagles

Scores: [-30.693138122558594, -66.5424575805664, -8.097404479980469, -3.092555046081543, -10.69161319732666, -9.33945369720459, -15.275007247924805, -39.803077697753906]

