<a href="https://colab.research.google.com/github/Spycsh/tf-saved-model-text-generation/blob/main/tf_xla_generate_saved_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Faster Text Generation with TensorFlow and XLA

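This notebook is a companion to the 🤗 [blog post with the same title](https://huggingface.co/blog/tf-xla-generate).
It illustrates how to use XLA to speed up text generation in TensorFlow.

It contains two stand-alone examples, one for encoder-decoder models and another for decoder-only models. In between, a few extra cells export the XLA-compiled `generate` of the encoder-decoder model as a SavedModel signature, reload it, and compare it with the eager baseline.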

⚠️ If you are running this on Colab, you might not have access to a GPU. The benefits of XLA are best observed with a GPU!

In [1]:
# Preparing the environment
# Quote the requirement: unquoted, the shell treats ">" as an output redirect
!pip install "transformers>=4.21.0"

In [2]:
# Stand-alone TF XLA generate example for Encoder-Decoder Models.

# Note: execution times are deeply dependent on hardware.
# If you have a machine with a powerful GPU, I highly recommend trying this example there!
import time
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# 1. Load model and tokenizer
model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

# 2. Prepare tokenization and generation arguments -- don't forget padding to avoid retracing!
tokenization_kwargs = {"pad_to_multiple_of": 32, "padding": True, "return_tensors": "tf"}
generation_kwargs = {"num_beams": 4, "max_new_tokens": 32}

# 3. Create your XLA generate function a̶n̶d̶ ̶m̶a̶k̶e̶ ̶P̶y̶T̶o̶r̶c̶h̶ ̶e̶a̶t̶ ̶d̶u̶s̶t̶
# This is the only change with respect to the original generate workflow!
xla_generate = tf.function(model.generate, jit_compile=True)

# 4. Generate! Remember -- the first call will be slow, but all subsequent calls will be fast if you've done things right.
input_prompts = [
    f"translate English to {language}: I have four cats and three dogs." for language in ["German", "French", "Romanian"]
]
for input_prompt in input_prompts:
    tokenized_inputs = tokenizer([input_prompt], **tokenization_kwargs)
    start = time.time_ns()
    generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)
    end = time.time_ns()
    decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)
    print(f"Original prompt -- {input_prompt}")
    print(f"Generated -- {decoded_text}")
    print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n")

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at t5-small.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


Original prompt -- translate English to German: I have four cats and three dogs.
Generated -- Ich habe vier Katzen und drei Hunde.
Execution time -- 34819.4 ms

Original prompt -- translate English to French: I have four cats and three dogs.
Generated -- J'ai quatre chats et trois chiens.
Execution time -- 3938.4 ms

Original prompt -- translate English to Romanian: I have four cats and three dogs.
Generated -- Am patru pisici şi trei câini.
Execution time -- 2270.4 ms
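The comment "don't forget padding to avoid retracing" can be made concrete. To our understanding, `tf.function` traces (and, with `jit_compile=True`, XLA-compiles) a new program for every new input shape, which is consistent with the first call above taking far longer than the next two. The sketch below is plain Python and only models that idea; `padded_length` and `ToyCompileCache` are hypothetical names, not TensorFlow APIs. The length 14 matches the 14 unmasked tokens of the German and French prompts shown later in In [46]; the other lengths are made up.

```python
import math


def padded_length(num_tokens: int, multiple: int = 32) -> int:
    """Length after padding up to the next multiple, like pad_to_multiple_of=32."""
    return math.ceil(num_tokens / multiple) * multiple


class ToyCompileCache:
    """Toy model of a shape-keyed compile cache (NOT TensorFlow's implementation)."""

    def __init__(self):
        self.compiled_shapes = {}  # padded shape -> placeholder for a compiled program
        self.num_compilations = 0

    def __call__(self, batch_size: int, num_tokens: int) -> tuple:
        shape = (batch_size, padded_length(num_tokens))
        if shape not in self.compiled_shapes:
            # First time this shape is seen: pay the slow trace + compile cost.
            self.num_compilations += 1
            self.compiled_shapes[shape] = f"program for {shape}"
        # Later calls with the same padded shape reuse the cached program.
        return shape


cache = ToyCompileCache()
shapes = [cache(1, n) for n in (14, 14, 13, 40)]
print(shapes)                  # [(1, 32), (1, 32), (1, 32), (1, 64)]
print(cache.num_compilations)  # 2
```

With padding, prompts of 13 or 14 tokens all land on the same `(1, 32)` shape and share one compiled program; only a prompt longer than 32 tokens triggers another compilation.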



In [46]:
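# Baseline for comparison: plain eager `model.generate`, without XLA.
# Note: `generation_kwargs` are not passed, so this uses the model's default
# generation settings rather than the beam search (num_beams=4) of the XLA cell above.
for input_prompt in input_prompts:
    tokenized_inputs = tokenizer([input_prompt], **tokenization_kwargs)
    print(tokenized_inputs)
    start = time.time_ns()
    generated_text = model.generate(**tokenized_inputs)
    end = time.time_ns()  # stop the clock before printing, so only generation is timed
    print("xxxxxxxxx")
    print(generated_text)
    print("xxxxxxxxx")
    decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)
    print(f"Original prompt -- {input_prompt}")
    print(f"Generated -- {decoded_text}")
    print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n")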

{'input_ids': <tf.Tensor: shape=(1, 32), dtype=int32, numpy=
array([[13959,  1566,    12,  2968,    10,    27,    43,   662, 10003,
           11,   386,  3887,     5,     1,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0]], dtype=int32)>, 'attention_mask': <tf.Tensor: shape=(1, 32), dtype=int32, numpy=
array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)>}




xxxxxxxxx
tf.Tensor([[    0  1674  2010  7629 16699    35    64  4052 14216     5     1]], shape=(1, 11), dtype=int32)
xxxxxxxxx
Original prompt -- translate English to German: I have four cats and three dogs.
Generated -- Ich habe vier Katzen und drei Hunde.
Execution time -- 3203.2 ms

{'input_ids': <tf.Tensor: shape=(1, 32), dtype=int32, numpy=
array([[13959,  1566,    12,  2379,    10,    27,    43,   662, 10003,
           11,   386,  3887,     5,     1,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0]], dtype=int32)>, 'attention_mask': <tf.Tensor: shape=(1, 32), dtype=int32, numpy=
array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)>}
xxxxxxxxx
tf.Tensor(
[[    0   446    31     9    23 12081  3582     7     3    15    17  5611
  17826     7     5     1]], shape=(1, 16), dtype=int32)
xxxxxxxxx
Original prompt -- tr
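The timings in these cells come from a single `time.time_ns()` measurement per prompt, and the eager cell above uses different generation settings (no `generation_kwargs`) than the first XLA cell, so those numbers are not directly comparable. A fairer comparison, sketched below under our own assumptions, runs the same call several times, reports the first call separately (it includes tracing and compilation for XLA functions), and takes the median of the rest. `benchmark_ms` is a hypothetical helper; `time.perf_counter_ns` is used because it is a monotonic clock intended for measuring intervals.

```python
import statistics
import time
from typing import Callable, List


def benchmark_ms(fn: Callable[[], object], num_runs: int = 5) -> dict:
    """Time `fn` several times; the first run is reported separately as warm-up."""
    if num_runs < 2:
        raise ValueError("need at least one warm-up run and one timed run")
    timings: List[float] = []
    for _ in range(num_runs):
        start = time.perf_counter_ns()
        fn()  # for TF, call .numpy() inside fn so any pending GPU work is included
        timings.append((time.perf_counter_ns() - start) / 1e6)
    return {
        "first_call_ms": timings[0],  # includes tracing/compilation for XLA functions
        "median_ms": statistics.median(timings[1:]),
        "runs_ms": timings,
    }


# Stand-in workload; in this notebook `fn` could be, for example,
# lambda: xla_generate(**tokenized_inputs, **generation_kwargs).numpy()
result = benchmark_ms(lambda: sum(range(100_000)), num_runs=4)
print(f"first call: {result['first_call_ms']:.2f} ms, median of the rest: {result['median_ms']:.2f} ms")
```

Passing the same `generation_kwargs` to both the eager and the XLA version keeps the comparison about XLA rather than about search settings.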

In [40]:
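# XLA-compiled generate; no generation_kwargs, so the export uses default generation settings.
@tf.function(jit_compile=True)
def generates(**kwargs):
    return model.generate(**kwargs)

# The signature is traced for fixed [1, 32] inputs, matching the padded tokenizer output.
concrete_generates = generates.get_concrete_function(
    input_ids=tf.TensorSpec(shape=[1, 32], dtype=tf.int32),
    attention_mask=tf.TensorSpec(shape=[1, 32], dtype=tf.int32),
)
tf.saved_model.save(model, "./model", signatures={"generates": concrete_generates})
d_model = tf.saved_model.load("./model")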


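The exported `generates` signature was traced with fixed `[1, 32]` input specs, so a prompt that tokenizes to more than 32 tokens (and is therefore padded to 64) will not match it. One defensive option, sketched below in plain Python with a hypothetical `fit_to_signature` helper, is to check the length and right-pad `input_ids` and `attention_mask` to exactly 32 before calling the signature. Pad ID 0 is T5's padding token, as seen in the zero-padded tokenizer outputs in this notebook.

```python
from typing import List, Tuple

SIGNATURE_LENGTH = 32  # must match the TensorSpec shape [1, 32] used at export time
T5_PAD_ID = 0          # T5's padding token id


def fit_to_signature(token_ids: List[int], length: int = SIGNATURE_LENGTH,
                     pad_id: int = T5_PAD_ID) -> Tuple[List[int], List[int]]:
    """Right-pad one tokenized prompt to `length`; refuse prompts that do not fit."""
    if len(token_ids) > length:
        raise ValueError(f"prompt has {len(token_ids)} tokens; the signature accepts at most {length}")
    num_pad = length - len(token_ids)
    input_ids = token_ids + [pad_id] * num_pad
    attention_mask = [1] * len(token_ids) + [0] * num_pad
    return input_ids, attention_mask


# Token IDs of the German prompt, copied from the tokenizer output in this notebook (unpadded).
german_prompt_ids = [13959, 1566, 12, 2968, 10, 27, 43, 662, 10003, 11, 386, 3887, 5, 1]
ids, mask = fit_to_signature(german_prompt_ids)
print(len(ids), sum(mask))  # 32 14
```

To call the signature, each list would be wrapped in an outer list (batch size 1) and converted with `tf.constant(..., dtype=tf.int32)`. With a Hugging Face tokenizer, `padding="max_length", max_length=32, truncation=True` is another way to guarantee the `(1, 32)` shape, at the cost of silently cutting long prompts.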

In [41]:
dir(d_model)

['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_add_trackable_child',
 '_add_variable_with_custom_getter',
 '_checkpoint_dependencies',
 '_default_save_signature',
 '_deferred_dependencies',
 '_delete_tracking',
 '_deserialization_dependencies',
 '_deserialize_from_proto',
 '_export_to_saved_model_graph',
 '_gather_saveables_for_checkpoint',
 '_handle_deferred_dependencies',
 '_lookup_dependency',
 '_map_resources',
 '_maybe_initialize_trackable',
 '_name_based_attribute_restore',
 '_name_based_restores',
 '_no_dependency',
 '_object_identifier',
 '_preload_simple_restoration',
 '_restore_from_tensors',
 '_self_name_based_restores',
 '_self_saveable_obje

In [42]:
d_model.signatures["generates"]

<ConcreteFunction signature_wrapper(*, attention_mask, input_ids) at 0x7FE547B4AA90>

In [50]:
# Call the "generates" signature of the reloaded SavedModel instead of `model.generate`.
# Inputs must match the exported TensorSpecs: int32 tensors of shape (1, 32).
gen = d_model.signatures["generates"]
for input_prompt in input_prompts:
    tokenized_inputs = tokenizer([input_prompt], **tokenization_kwargs)
    print(tokenized_inputs)
    start = time.time_ns()
    generated_text = gen(**tokenized_inputs)
    end = time.time_ns()
    print("xxxxxxxxx")
    print(generated_text)
    print("xxxxxxxxx")
    # Signature functions return a dict of output tensors; the generated IDs are under "output_0".
    decoded_text = tokenizer.decode(generated_text["output_0"].numpy()[0], skip_special_tokens=True)
    print(f"Original prompt -- {input_prompt}")
    print(f"Generated -- {decoded_text}")
    print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n")

{'input_ids': <tf.Tensor: shape=(1, 32), dtype=int32, numpy=
array([[13959,  1566,    12,  2968,    10,    27,    43,   662, 10003,
           11,   386,  3887,     5,     1,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0]], dtype=int32)>, 'attention_mask': <tf.Tensor: shape=(1, 32), dtype=int32, numpy=
array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)>}
xxxxxxxxx
{'output_0': <tf.Tensor: shape=(1, 20), dtype=int32, numpy=
array([[    0,  1674,  2010,  7629, 16699,    35,    64,  4052, 14216,
            5,     1,     0,     0,     0,     0,     0,     0,     0,
            0,     0]], dtype=int32)>}
xxxxxxxxx
Original prompt -- translate English to German: I have four cats and three dogs.
Generated -- Ich habe vier Katzen und drei Hunde.
Execution time -- 379.5 ms

{'input_ids': <tf.Tensor: shape=(1, 32), dtype=int32
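The exported signature returns a `(1, 20)` tensor padded with zeros after the end-of-sequence token, while the eager `model.generate` call in In [46] returned only `(1, 11)` tokens for the same prompt. Presumably this is because XLA needs static shapes, so the output buffer is allocated at the maximum length and filled with padding. `skip_special_tokens=True` already drops these pads when decoding; if the raw IDs are needed, a small plain-Python helper (hypothetical `trim_after_eos`) can cut them at the first EOS. It assumes T5's conventions visible in the outputs above: ID 0 is padding (also used as the decoder start token at position 0) and ID 1 is `</s>`.

```python
from typing import List

T5_EOS_ID = 1  # T5's </s> token


def trim_after_eos(token_ids: List[int], eos_id: int = T5_EOS_ID) -> List[int]:
    """Keep tokens up to and including the first EOS; drop the padding after it."""
    for i, token in enumerate(token_ids):
        if token == eos_id:
            return token_ids[: i + 1]
    return token_ids  # no EOS: generation hit the length limit, nothing to trim


# Output of the "generates" signature (In [50]) and of eager model.generate (In [46]).
xla_output = [0, 1674, 2010, 7629, 16699, 35, 64, 4052, 14216, 5, 1,
              0, 0, 0, 0, 0, 0, 0, 0, 0]
eager_output = [0, 1674, 2010, 7629, 16699, 35, 64, 4052, 14216, 5, 1]
print(trim_after_eos(xla_output) == eager_output)  # True
```

After trimming, both paths produce identical token IDs, which is consistent with both decoding to "Ich habe vier Katzen und drei Hunde."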

In [17]:
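# Attempt: call the model's default serving signature instead of "generates".
# This does not work for generation: "serving_default" wraps the model's forward pass
# (which returns logits, not generated token IDs), and for T5 it likely also expects
# decoder inputs, hence the TypeError below. Saving with only the "generates"
# signature, as in In [40], does not export "serving_default" at all.
infer = d_model.signatures["serving_default"]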
for input_prompt in input_prompts:
    tokenized_inputs = tokenizer([input_prompt], **tokenization_kwargs)
    start = time.time_ns()
    print(tokenized_inputs)
    generated_text = infer(**tokenized_inputs) #generated_text = model.generate(**tokenized_inputs)
    end = time.time_ns()
    decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)
    print(f"Original prompt -- {input_prompt}")
    print(f"Generated -- {decoded_text}")
    print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n")

{'input_ids': <tf.Tensor: shape=(1, 32), dtype=int32, numpy=
array([[13959,  1566,    12,  2968,    10,    27,    43,   662, 10003,
           11,   386,  3887,     5,     1,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0]], dtype=int32)>, 'attention_mask': <tf.Tensor: shape=(1, 32), dtype=int32, numpy=
array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)>}


TypeError: ignored

In [None]:
# Stand-alone TF XLA generate example for Decoder-Only Models.

# Note: execution times are deeply dependent on hardware.
# If you have a machine with a powerful GPU, I highly recommend trying this example there!
import time
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

# 1. Load model and tokenizer
model_name = "gpt2"
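# remember: decoder-only models need left-padding. GPT-2 has no padding token,
# so its end-of-sequence token "<|endoftext|>" is reused for padding.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", pad_token="<|endoftext|>")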
model = TFAutoModelForCausalLM.from_pretrained(model_name)

# 2. Prepare tokenization and generation arguments -- don't forget padding to avoid retracing!
tokenization_kwargs = {"pad_to_multiple_of": 32, "padding": True, "return_tensors": "tf"}
generation_kwargs = {"num_beams": 4, "max_new_tokens": 32}

# 3. Create your XLA generate function a̶n̶d̶ ̶m̶a̶k̶e̶ ̶P̶y̶T̶o̶r̶c̶h̶ ̶e̶a̶t̶ ̶d̶u̶s̶t̶
# This is the only change with respect to the original generate workflow!
xla_generate = tf.function(model.generate, jit_compile=True)

# 4. Generate! Remember -- the first call will be slow, but all subsequent calls will be fast if you've done things right.
input_prompts = [f"The best thing about {country} is" for country in ["Spain", "Japan", "Angola"]]
for input_prompt in input_prompts:
    tokenized_inputs = tokenizer([input_prompt], **tokenization_kwargs)
    start = time.time_ns()
    generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)
    end = time.time_ns()
    decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)
    print(f"Original prompt -- {input_prompt}")
    print(f"Generated -- {decoded_text}")
    print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n")
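Why left-padding for decoder-only models: generation appends new tokens after the last position of each row, so every prompt must end there rather than be followed by pad tokens. The plain-Python sketch below (hypothetical `left_pad_batch`, made-up token IDs, and a made-up pad ID of 99) mimics what `padding_side="left"` combined with `pad_to_multiple_of=32` does to a batch; it is not the tokenizer's actual implementation.

```python
import math
from typing import List, Tuple


def left_pad_batch(batch: List[List[int]], pad_id: int,
                   multiple: int = 32) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad every sequence to a shared length that is a multiple of `multiple`."""
    longest = max(len(seq) for seq in batch)
    target = math.ceil(longest / multiple) * multiple
    input_ids, attention_mask = [], []
    for seq in batch:
        num_pad = target - len(seq)
        input_ids.append([pad_id] * num_pad + seq)             # pads go in front...
        attention_mask.append([0] * num_pad + [1] * len(seq))  # ...and are masked out
    return input_ids, attention_mask


# Made-up token IDs for two prompts of different lengths; 99 is a made-up pad ID.
ids, mask = left_pad_batch([[11, 12, 13], [21, 22, 23, 24, 25]], pad_id=99)
print([len(row) for row in ids])  # [32, 32]
print([row[-1] for row in ids])   # [13, 25]
```

Each row ends with its real last token, so the model continues each prompt directly. With right-padding, the first row would end in 29 pad tokens and generation would start after them.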