!pip -q install trax
!pip install t5
!pip install -U jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [1]:
import string
import t5
import numpy as np
import trax
import gc
import time
from trax.supervised import decoding
from numba import cuda 
import textwrap 
# Will come in handy later (used to wrap long printed output).

2022-09-26 18:19:15.056319: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
# Reserved token ids (presumably padding / end-of-sequence / unknown in the
# SentencePiece vocabulary — verify against sentencepiece.model).
PAD = 0
EOS = 1
UNK = 2

def detokenize(np_array):
    """Turn a sequence of token ids back into text.

    Delegates to ``trax.data.detokenize`` using the SentencePiece vocabulary
    file ``sentencepiece.model`` in the current directory.

    Args:
        np_array: sequence of integer token ids.

    Returns:
        Whatever ``trax.data.detokenize`` returns for these ids.
    """
    vocab_kwargs = dict(
        vocab_type='sentencepiece',
        vocab_file='sentencepiece.model',
        vocab_dir="./",
    )
    return trax.data.detokenize(np_array, **vocab_kwargs)

def tokenize(s):
    """Encode one string into SentencePiece token ids.

    ``trax.data.tokenize`` consumes and produces streams, so the string is
    wrapped in a one-element iterator and the single encoded item is pulled
    back out with ``next``.

    Args:
        s: the text to encode.

    Returns:
        The token ids produced by ``trax.data.tokenize`` for ``s``.
    """
    one_item_stream = iter([s])
    encoded_stream = trax.data.tokenize(
        one_item_stream,
        vocab_type='sentencepiece',
        vocab_file='sentencepiece.model',
        vocab_dir="./",
    )
    return next(encoded_stream)
 
# Number of entries in the SentencePiece vocabulary at ./sentencepiece.model.
vocab_size = trax.data.vocab_size(vocab_type='sentencepiece',
                                  vocab_file='sentencepiece.model',
                                  vocab_dir="./")

def get_sentinels(vocab_size):
    """Map the decoded text of the last 52 vocabulary ids to readable markers.

    Id ``vocab_size - 1`` maps to ``'<Z>'``, counting down through the
    letters of ``string.ascii_letters`` in reverse, so id
    ``vocab_size - 52`` maps to ``'<a>'``.

    Args:
        vocab_size: total vocabulary size; the mapped ids are counted
            down from ``vocab_size - 1``.

    Returns:
        dict from decoded token text to its ``'<letter>'`` marker. If two
        ids decode to the same text, the lower id's marker wins.
    """
    return {
        detokenize([vocab_size - offset]): f'<{letter}>'
        for offset, letter in enumerate(reversed(string.ascii_letters), start=1)
    }

# Lookup table: decoded tail-of-vocabulary token text -> readable '<letter>' marker.
sentinels = get_sentinels(vocab_size=vocab_size)


def pretty_decode(encoded_str_list, sentinels=sentinels):
    """Make model output readable by swapping sentinel tokens for markers.

    Args:
        encoded_str_list: either already-decoded text (``str``, or ``bytes``
            which are decoded as UTF-8) or a sequence of token ids that is
            passed through ``detokenize`` first.
        sentinels: mapping from decoded token text to the marker that
            replaces it. Defaults to the module-level ``sentinels`` mapping
            as it was when this function was defined.

    Returns:
        str: the text with every (non-empty) sentinel token replaced.
    """
    if isinstance(encoded_str_list, bytes):
        # bytes.replace() rejects the str tokens used below, so work on text.
        encoded_str_list = encoded_str_list.decode('utf-8')

    if isinstance(encoded_str_list, str):
        text = encoded_str_list
        for token, marker in sentinels.items():
            # str.replace('', marker) would insert the marker between every
            # character, so an empty decoded token is skipped.
            if token:
                text = text.replace(token, marker)
        return text

    # Token ids: decode first, then prettify with the SAME mapping the caller
    # supplied (previously the recursion silently fell back to the default).
    return pretty_decode(detokenize(encoded_str_list), sentinels)

In [3]:
# Build the encoder-decoder Transformer. Its hyperparameters must match the
# pretrained weights loaded from ./model_squad.pkl.gz in the next cell.
model = trax.models.Transformer(
    input_vocab_size=32000,
    d_model=1024,
    d_ff=4096,
    n_heads=16,
    n_encoder_layers=24,
    n_decoder_layers=24,
    max_len=2048,
    dropout=0.1,
    mode='predict',  # Change to 'eval' for slow decoding.
)

In [4]:
# Load pretrained weights into the model (this can take about a minute).
# Both inputs of the signature are int32 arrays of shape (1, 1); presumably
# (batch, length) — confirm against the trax predict-mode docs.
shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)
model.init_from_file(
    './model_squad.pkl.gz',
    weights_only=True,
    input_signature=(shape11, shape11),
)

# Keep a handle on the freshly loaded state so it can be put back after
# decoding (the sampling cell assigns it back to model.state).
empty_state = model.state



In [10]:
# An extended example in "question: ... context: ..." form (presumably the
# SQuAD-style format the model_squad weights expect).
# NOTE(review): the question contains typos ("doesthe", "definiton") and mixed
# tenses, which may hurt the answer; it is left unchanged so the recorded
# outputs below still correspond to this input.
inputs = '''question: when doesthe definiton of thermodynamics was stated? context: Historically, thermodynamics developed out of a desire to increase the efficiency of early steam engines, particularly through the work of French physicist Sadi Carnot (1824) who believed that engine efficiency was the key that could help France win the Napoleonic Wars.[1] Scots-Irish physicist Lord Kelvin was the first to formulate a concise definition of thermodynamics in 1854'''
wrapper = textwrap.TextWrapper(width=70)  # wraps long decoded answers for display

# Tokenize once and reuse the ids (previously tokenize() ran twice).
test_inputs = tokenize(inputs)
print(test_inputs)

[  822    10   125   405     3 11599  2116    58  2625    10 22139    19
     8   793  2056    24  2116  1052     6   165  4431 17429     7     6
   165  4644    11  3889   190   628    11    97     6    11     8  1341
 12311    13   827    11  2054     5 22139    19    80    13     8   167
  4431  4290 15015     6    28   165   711  1288   271    12   734   149
     8  8084 19790     7     5  6306   115   908  6306   519   908  6306
   591   908  6306   755   908    71 17901   113     3 17095    16     8
  1057    13     3 11599    19   718     3     9     3  6941     7   447
   343     5 22139    19    80    13     8 10043  2705 15015    11     6
   190   165 11980    13     3 12466    63     6  2361     8 10043     5
 23847   231    13     8   657   192  3293    35    29    23     9     6
     3 11599     6     3 11366     6 15651     6    11   824  9678    13
 17082   130     3     9   294    13   793  8156     6    68   383     8
 19268 12197    16     8  1003   189  2646   175   

In [11]:
# Decode an answer with temperature=0.0 (greedy per the trax decoding docs),
# print it, and report how long decoding + printing took.
start_time = time.perf_counter()  # monotonic clock, suited to measuring durations
try:
    output = decoding.autoregressive_sample(
        model,
        inputs=np.array(test_inputs)[None, :],  # add a leading axis of size 1
        temperature=0.0,
        max_length=100,
        accelerate=False,
        eval_mode=False,
    )
    print(wrapper.fill(pretty_decode(output[0])))
    end_time = time.perf_counter()
    # Report the measured span; the old code ignored end_time and re-read the clock.
    print("--- command took %s seconds ---" % (end_time - start_time))
finally:
    # Put back the state snapshot taken after loading, even if decoding
    # raised, so the next run does not start from a modified model.state.
    model.state = empty_state

its fundamental constituents
--- command took 9.610665559768677 seconds ---
