# Greedy Decoding with T5 models on Trn1 or Inf2

## Introduction

In this tutorial we will compile and deploy a pretrained T5 model for accelerated inference on Neuron. 

This tutorial will use the [t5-large](https://huggingface.co/t5-large) model. The T5 model can be used for machine translation, document summarization, question answering, and classification tasks. 

This tutorial has the following main sections:

1. Install dependencies
1. Compile the T5 model
1. Run inference with greedy decoding on Neuron

This Jupyter notebook should be run on a Trn1 instance (`trn1.2xlarge` or larger) or an Inf2 instance (`inf2.xlarge` or larger).

## Install dependencies

The code in this tutorial is written for Jupyter Notebooks. To use Jupyter Notebook on the Neuron instance, you
can use this [guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/notebook/setup-jupyter-notebook-steps-troubleshooting.html).

This tutorial requires the following pip packages:

- `torch-neuronx`
- `neuronx-cc`
- `transformers`
- `optimum-neuron`

Most of these packages will be installed when configuring your environment using the Trn1/Inf2 setup guide. The additional dependencies must be installed here:

In [None]:
!pip install --upgrade transformers==4.31.0 optimum-neuron==0.0.8


🤗 Optimum Neuron is the interface between the 🤗 Transformers library and AWS accelerators, including AWS Trainium and AWS Inferentia. It provides a set of tools for easy model loading, training, and inference in single- and multi-accelerator settings for different downstream tasks. In this tutorial we use the `generate()` method from 🤗 Optimum Neuron (through its `NeuronGenerationMixin`) instead of 🤗 [Transformers' `generate()`](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate) to perform greedy decoding. Optimum Neuron takes care of padding the inputs, which is necessary for inference on Neuron.
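Greedy decoding keeps a single hypothesis and, at every step, appends the token with the highest score until an end-of-sequence token or the length limit is reached. The plain-Python sketch below illustrates only that control flow; `toy_logits` is a hypothetical stand-in for the decoder, and this is not Optimum Neuron's implementation.

```python
from typing import Callable, List


def greedy_decode(
    next_token_logits: Callable[[List[int]], List[float]],
    start_token: int,
    eos_token: int,
    max_length: int,
) -> List[int]:
    """Repeatedly append the argmax token until EOS or max_length tokens."""
    tokens = [start_token]
    while len(tokens) < max_length:
        logits = next_token_logits(tokens)
        # Greedy choice: index of the largest logit (no sampling, one beam).
        best = max(range(len(logits)), key=lambda i: logits[i])
        tokens.append(best)
        if best == eos_token:
            break
    return tokens


# Hypothetical toy "model" with a 5-token vocabulary: the preferred next
# token is (last token + 1), and token 4 plays the role of EOS.
def toy_logits(tokens: List[int]) -> List[float]:
    target = min(tokens[-1] + 1, 4)
    return [1.0 if i == target else 0.0 for i in range(5)]


print(greedy_decode(toy_logits, start_token=0, eos_token=4, max_length=10))
# [0, 1, 2, 3, 4]
```

The loop is deterministic: with the same logits it always produces the same sequence, which is why the notebook passes `do_sample=False` and `num_beams=1`.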


## Compile the model into an AWS Neuron optimized TorchScript

In the following section, we load the T5 model, compile the model's encoder and decoder for Neuron using `torch_neuronx.trace()`, and save the optimized encoder and decoder as `TorchScript`. 

`torch_neuronx` can only trace functions with positional arguments, but the T5 encoder and decoder are called with keyword arguments. To trace them, we write wrappers that accept positional arguments and forward them to the underlying modules as keyword arguments.
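A tracer such as `torch_neuronx.trace()` receives its example inputs as a tuple and calls the module positionally, so the wrapper only has to map that tuple onto keyword names. The plain-Python sketch below shows the pattern; `keyword_only_encoder` is a hypothetical stand-in for a module that must be called with keywords, not the real T5 encoder.

```python
from typing import List


def keyword_only_encoder(*, input_ids: List[int], attention_mask: List[int],
                         output_attentions: bool = False) -> List[int]:
    """Hypothetical stand-in: accepts keyword arguments only."""
    # Keep only the positions the mask marks as real tokens.
    # output_attentions is accepted but unused in this toy.
    return [tok for tok, m in zip(input_ids, attention_mask) if m == 1]


class PositionalWrapper:
    """Exposes a positional interface and forwards to the keyword-only API."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, input_ids, attention_mask):
        # Fixed flags are baked in here, so callers pass only the tensors.
        return self.fn(input_ids=input_ids,
                       attention_mask=attention_mask,
                       output_attentions=False)


wrapped = PositionalWrapper(keyword_only_encoder)
example_inputs = ([13, 7, 0, 0], [1, 1, 0, 0])  # like the tuple given to trace()
print(wrapped(*example_inputs))  # [13, 7]

try:
    keyword_only_encoder([13, 7], [1, 1])  # positional call is rejected
except TypeError as err:
    print("positional call rejected:", err)
```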

In [None]:
import torch
from transformers.models.t5.modeling_t5 import T5Stack

class EncoderWrapper(torch.nn.Module):
    '''
        We will trace an instance of the EncoderWrapper. 
        This wrapper just converts positional args to kwargs. 
    '''

    def __init__(self, encoder: T5Stack):
        super().__init__()
        self.encoder = encoder
    
    def forward(self, input_ids, attention_mask):
        
        # This is the core functionality we want to trace. 
        return self.encoder(input_ids=input_ids,
                            attention_mask=attention_mask,
                            output_attentions=False,
                            output_hidden_states=False)


In the decoder wrapper, in addition to exposing a positional-argument interface, we add support for attention caching. Generating text with an encoder-decoder model is an autoregressive process: each decoder invocation produces one new token, and without a cache the decoder would recompute the key and value states of every previously generated token on every call. To avoid this repeated work, we cache the key and value states. This cache is what HuggingFace Transformers refers to as `past_key_values`.
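To see why caching matters, the plain-Python sketch below counts per-token key/value projections with and without a cache. `project` is a hypothetical stand-in for the attention key/value projection; the point is only the count: recomputing the whole prefix at every step costs 1 + 2 + ... + n projections, while a cache needs just n.

```python
from typing import List, Tuple


def project(token: int) -> int:
    """Hypothetical stand-in for the key/value projection of one token."""
    return 2 * token


def keys_without_cache(tokens: List[int]) -> Tuple[List[int], int]:
    calls = 0
    keys: List[int] = []
    for t in range(1, len(tokens) + 1):
        # Every step re-projects the whole prefix tokens[:t].
        keys = [project(tok) for tok in tokens[:t]]
        calls += t
    return keys, calls


def keys_with_cache(tokens: List[int]) -> Tuple[List[int], int]:
    calls = 0
    cache: List[int] = []
    for tok in tokens:
        # Only the newest token is projected; earlier keys are reused.
        cache.append(project(tok))
        calls += 1
    return cache, calls


tokens = list(range(128))  # 128 decoding steps, as with max_length = 128
k1, c1 = keys_without_cache(tokens)
k2, c2 = keys_with_cache(tokens)
assert k1 == k2  # same keys either way
print(c1, c2)    # 8256 128
```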

In HuggingFace Transformers, `past_key_values` is updated outside the decoder. This works for training and evaluation, but for inference we want the decoder step and the cache update to happen within a single trace, so that the compiler can optimize across both. So we move the cache update into the decoder wrapper.
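The sketch below models the fixed-size cache update in plain Python, assuming one list element stands for one sequence position of a self-attention key or value tensor. With `use_cache=True` the decoder appends the new token's state, so the cache grows to `max_length + 1`; dropping the oldest position (what `update_past` in the next cell does with `[:, :, 1:]`) restores the shape the trace was compiled for. In this notebook, the unused placeholder slots are masked out by `decoder_attention_mask`, which starts with `max_length` zeros.

```python
from typing import List


def decoder_step(cache: List[int], new_state: int) -> List[int]:
    # With use_cache=True the decoder concatenates the new token's
    # state along the sequence axis, so the cache grows by one.
    return cache + [new_state]


def update_past(cache: List[int]) -> List[int]:
    # Drop the oldest position so the length returns to max_length
    # (the list analogue of tensor[:, :, 1:]).
    return cache[1:]


max_length = 4
cache = [0] * max_length  # pre-allocated, fixed-size placeholder slots
for new_state in [11, 12, 13]:
    grown = decoder_step(cache, new_state)
    assert len(grown) == max_length + 1
    cache = update_past(grown)
    assert len(cache) == max_length  # shape is the same on every step

print(cache)  # [0, 11, 12, 13]
```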

In [3]:
from torch.nn import Parameter

class DecoderWrapper(torch.nn.Module):

    def __init__(self, 
                 decoder: T5Stack, 
                 lm_head: torch.nn.Linear,
                 model_config,
                 num_beams: int, 
                 max_length: int,
                 device: str):
                 
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.model_dim=model_config.d_model
        self.device = device
        self.num_beams = num_beams

        num_heads=model_config.num_heads
        num_decoder_layers=model_config.num_decoder_layers

        # Initialize the cache
        shape = (num_beams,num_heads,max_length,model_config.d_kv)
        if device == "cpu":
            cache = []
            for _ in range(num_decoder_layers * 4):
                cache.append(torch.ones(shape, dtype=torch.float32))
            self.past_key_values = cache
        elif device == "xla":
            cache = []
            for _ in range(num_decoder_layers * 4):
                cache.append(Parameter(torch.ones(shape, dtype=torch.float32), 
                                      requires_grad=False))    
            self.past_key_values = torch.nn.ParameterList(cache)

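    # The decoder appends the new token's key/value states to the self-attention
    # cache (sequence length max_length + 1). Drop the oldest position so the
    # cache keeps its fixed length of max_length for the next traced call.
    def update_past(self, past_key_values):
        new_past = []
        for past_layer in past_key_values:
            new_past_layer = list(past_layer)
            # Only the self-attention key/value (first two entries) change;
            # the cross-attention states (last two) are left untouched.
            for i in range(2):
                new_past_layer[i] = past_layer[i][:, :, 1:]
            new_past += [new_past_layer,]
        return new_past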

    def forward(self,
                input_ids,
                decoder_attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                beam_idx,
                **kwargs):

        past_key_values = self.past_key_values

        # The cache is stored in flattened form.
        # We regroup it per layer before passing it to the decoder.
        # Each layer has 4 tensors, so we group by 4.
        past_key_values = [past_key_values[i*4:i*4+4] for i in range(0, int(len(past_key_values)/4))]

        decoder_output = self.decoder(
            input_ids=input_ids,
            attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False)

        last_hidden_state = decoder_output['last_hidden_state']
        past_key_values = decoder_output['past_key_values']

        last_hidden_state = last_hidden_state * (self.model_dim**-0.5)
        lm_logits = self.lm_head(last_hidden_state)

        past_key_values = self.update_past(past_key_values)

        # We flatten the cache to a single array. 
        # This is required for the input output aliasing to work
        past_key_values = [vec for kv_per_layer in past_key_values for vec in kv_per_layer]

        if self.device == "cpu":
            self.past_key_values = past_key_values

        return [lm_logits] + past_key_values
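
The cache is passed around as one flat list so that each tensor keeps a fixed position, which the input-output aliasing relies on, while the T5 decoder expects one group of four tensors per layer (self-attention key and value, then cross-attention key and value, matching the order `CacheInitializer` produces). The plain-Python sketch below shows the two conversions `DecoderWrapper.forward` performs, with strings standing in for tensors.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def group_per_layer(flat: Sequence[T], per_layer: int = 4) -> List[List[T]]:
    """Flat cache -> one [self_k, self_v, cross_k, cross_v] group per layer."""
    return [list(flat[i:i + per_layer]) for i in range(0, len(flat), per_layer)]


def flatten(per_layer_cache: List[List[T]]) -> List[T]:
    """Per-layer groups -> flat list, preserving every tensor's position."""
    return [t for layer in per_layer_cache for t in layer]


num_layers = 3
flat = [f"L{layer}.{name}" for layer in range(num_layers)
        for name in ("self_k", "self_v", "cross_k", "cross_v")]

grouped = group_per_layer(flat)
print(grouped[1])  # ['L1.self_k', 'L1.self_v', 'L1.cross_k', 'L1.cross_v']
assert flatten(grouped) == flat  # the round trip keeps the original order
```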



The key-value cache that the decoder uses holds both the self-attention and the cross-attention states. The self-attention states are updated on every decoder call, but the cross-attention states depend only on the encoder output and remain unchanged, so they have to be computed just once per input. The Transformers T5 model initializes the cross-attention cache on the first decoder invocation, but this is harder to do on Neuron: like `torch.jit.trace()`, `torch_neuronx.trace()` produces a function with fixed control flow, i.e. there is no conditional execution. So we cannot conditionally initialize the cache in the first iteration. Instead, we compute the initial cache state outside the generation loop and pass it in. To do so, we create a cache initializer, which is run once before each generation to produce the initial cache state.
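A toy illustration of why a first-step `if` does not survive tracing: a tracer records the operations that actually execute for the example input, and the Python branch itself is not recorded. The `Recorder` class below is a hypothetical, heavily simplified tracer written for this sketch; it is not how `torch_neuronx` works internally, but it reproduces the behavior described above for Python control flow.

```python
from typing import Callable, List

Op = Callable[[float], float]


class Recorder:
    """Hypothetical toy tracer: records each op that runs on the example input."""

    def __init__(self) -> None:
        self.ops: List[Op] = []

    def apply(self, op: Op, x: float) -> float:
        self.ops.append(op)
        return op(x)


def step(x: float, is_first_step: bool, rec: Recorder) -> float:
    if is_first_step:  # plain Python branch: evaluated once, while tracing
        x = rec.apply(lambda v: v + 100.0, x)  # "initialize the cache"
    return rec.apply(lambda v: v * 2.0, x)     # "regular decoder work"


def replay(ops: List[Op], x: float) -> float:
    """Run the recorded ops in order, as a traced graph would."""
    for op in ops:
        x = op(x)
    return x


rec = Recorder()
step(1.0, is_first_step=True, rec=rec)  # trace with a first-step example
traced_ops = rec.ops

# The replayed graph always runs the initialization op, even on later steps.
print(replay(traced_ops, 1.0))                          # 202.0
print(step(1.0, is_first_step=False, rec=Recorder()))  # 2.0 in eager Python
```

Moving the one-time work into a separate traced function (the cache initializer) avoids relying on a branch that the trace would freeze.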

In [4]:
from transformers.models.t5.modeling_t5 import T5LayerCrossAttention

class CacheInitializer(torch.nn.Module):
    '''
        Cache initializer is used once per input to compute the 
        cross attention key and value states. 
    '''

    def __init__(self, decoder, model_config, batch_size, max_length, device):
        super().__init__()
        self.decoder = decoder
        self.batch_size = batch_size
        self.max_length = max_length
        self.model_config = model_config
        self.device = device

    def forward(self, encoder_hidden_states):
        decoder_blocks = self.decoder.block

        present_key_value_states = []
        for block in decoder_blocks:

            # Self attention kv states are initialized to zeros.
            self_attn_kv_state = list(
                torch.zeros((self.batch_size, self.model_config.num_heads, 
                             self.max_length, self.model_config.d_kv), 
                            dtype=torch.float32, 
                            device=self.device) for i in range(0, 2))

            # Cross attention has to be initialized with the encoder hidden state
            cross_attention: T5LayerCrossAttention = block.layer[1]
            attention = cross_attention.EncDecAttention

            def shape(states):
                return states.view(self.batch_size, -1, attention.n_heads, 
                                   attention.key_value_proj_dim).transpose(1, 2)

            key_states = shape(attention.k(encoder_hidden_states))
            value_states = shape(attention.v(encoder_hidden_states))
            cross_attn_kv_state = [key_states, value_states]

            # We concatenate the self-attention and cross-attention states
            # to create the KV cache for the decoder block
            present_key_value_state = self_attn_kv_state + cross_attn_kv_state
            present_key_value_states = present_key_value_states + present_key_value_state

        return present_key_value_states
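
`CacheInitializer` projects the encoder output with the cross-attention `k` and `v` layers and then reshapes the result so that every attention head gets its own contiguous slice, giving the `(batch, num_heads, seq_len, d_kv)` layout of the cache. The plain-Python sketch below reproduces the `view(...).transpose(1, 2)` step for a single batch row, with small illustrative sizes instead of t5-large's configuration.

```python
from typing import List

Matrix = List[List[int]]


def split_heads(states: Matrix, n_heads: int) -> List[Matrix]:
    """[seq_len][n_heads * d_kv] -> [n_heads][seq_len][d_kv].

    List analogue of view(batch, -1, n_heads, d_kv).transpose(1, 2)
    for one batch row.
    """
    inner_dim = len(states[0])
    assert inner_dim % n_heads == 0
    d_kv = inner_dim // n_heads
    # Head h owns columns [h * d_kv, (h + 1) * d_kv) of every position.
    return [[row[h * d_kv:(h + 1) * d_kv] for row in states]
            for h in range(n_heads)]


# 3 positions, 2 heads of width 2 (illustrative sizes only).
states = [[0, 1, 2, 3],
          [4, 5, 6, 7],
          [8, 9, 10, 11]]
heads = split_heads(states, n_heads=2)
print(heads[1])  # [[2, 3], [6, 7], [10, 11]]
```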

Now let's create a T5 model wrapper to make it compatible with our traced encoder and decoder.

There are two reasons for having this wrapper:

1. The encoder and decoder traces can only be invoked with positional arguments, but the HuggingFace Transformers code uses keyword arguments. So we override the functions that invoke the encoder and decoder so that they call them with positional arguments.
1. The `generate()` function in `NeuronGenerationMixin` performs the cache update on the CPU. As we handle the cache within the `DecoderWrapper`, we disable the cache update on the CPU.

Let's also override the `generate()` function so that it initializes the cache using the cache initializer before starting greedy decoding.

In [5]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.generation.utils import ModelOutput
from typing import Any, Dict, Optional, Tuple, Union

from optimum.neuron.generation import NeuronGenerationMixin

class T5Wrapper(T5ForConditionalGeneration, NeuronGenerationMixin):

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, 
        inputs_tensor: torch.Tensor, 
        model_kwargs, 
        model_input_name: Optional[str] = None
    ) -> Dict[str, Any]:
        encoder = self.get_encoder()
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(inputs_tensor, 
                                                               model_kwargs["attention_mask"])
        return model_kwargs

    # Override to cut the input_ids down to just the last token
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        decoder_attention_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids as past is cached
        input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }
    
    # We update the cache in the decoder trace,
    # so let's disable the update in NeuronGenerationMixin
    def _update_model_kwargs_for_xla_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        batch_size: int,
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
        max_length: Optional[int] = None,
        seq_length: Optional[int] = None,
        use_cache: bool = True,
    ) -> Dict[str, Any]:

        def _update_attention(model_kwargs, is_encoder_decoder):
            """Updates the appropriate attention mask -- encoder-decoder models use `decoder_attention_mask`"""

            attention_mask_name = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
            attention_mask = model_kwargs.pop(attention_mask_name)
            attention_mask_update_slice = torch.ones(
                (batch_size, 1), 
                dtype=attention_mask.dtype, 
                device=attention_mask.device
            )
            attention_mask = torch.cat([attention_mask[:, 1:], 
                                        attention_mask_update_slice], dim=-1)
            mask = {attention_mask_name: attention_mask}
            return mask

        mask = _update_attention(model_kwargs, is_encoder_decoder)
        # sets the updated variables (mask and past_key_values)
        model_kwargs.update(mask)

        # Set a mock cache tensor for NeuronGenerationMixin
        model_kwargs["past_key_values"] = torch.tensor([])

        return model_kwargs

    def generate(self,
                 cache_initializer,
                 tokenizer: T5Tokenizer,
                 prompt: str,
                 max_length: int,
                 num_beams: int,
                 num_return_sequences: int,
                 device: str):

        batch_encoding = tokenizer(prompt, 
                                   max_length=max_length, 
                                   truncation=True, 
                                   padding='max_length',
                                   return_tensors="pt")

        encoder_output = self.encoder(batch_encoding['input_ids'],
                                      batch_encoding['attention_mask'])
        last_hidden_state = encoder_output["last_hidden_state"]
        
        encoder_hidden_states = torch.concat(
            [tensor.unsqueeze(0).repeat(num_beams, 1, 1) for tensor in last_hidden_state])

        # Initialize the cache and mask
        past_key_values = cache_initializer(encoder_hidden_states)
        decoder_attention_mask = torch.cat([torch.zeros((1, max_length), dtype=torch.int32),
                                            torch.ones((1, 1), dtype=torch.int32)], dim=1)

        # copy the new cache state to the decoder
        if device == "xla":
            for state, tensor in zip(self.decoder.parameters(), past_key_values):
                state.copy_(tensor)
        else:
            self.decoder.past_key_values = past_key_values
        
        output = super().generate(**batch_encoding,
                                  max_length=max_length,
                                  num_beams=num_beams,
                                  num_return_sequences=num_return_sequences,
                                  do_sample=False,
                                  use_cache=True,
                                  decoder_attention_mask=decoder_attention_mask)

        return output

    def forward(
        self,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        **kwargs
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:

        hidden_states = encoder_outputs["last_hidden_state"]

        if not hasattr(self, 'beam_idx'):
            # Inferring the number of beams from the attention mask
            num_beams = attention_mask.shape[0]
            self.beam_idx = torch.arange(0, num_beams, dtype=torch.int64)

        decoder_outputs = self.decoder(
            decoder_input_ids,
            decoder_attention_mask,
            hidden_states,
            attention_mask,
            self.beam_idx
        )

        lm_logits = decoder_outputs[0]

        return Seq2SeqLMOutput(logits=lm_logits)
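
Together, `prepare_inputs_for_generation` and `_update_model_kwargs_for_xla_generation` keep every decoder call the same shape: only the last generated token is fed in, and the decoder attention mask keeps length `max_length + 1` by dropping its oldest entry and appending a 1 for the new token. The plain-Python sketch below traces how the mask evolves, starting from the mask `generate()` builds; lists stand in for the `(1, max_length + 1)` tensor.

```python
from typing import List


def initial_decoder_mask(max_length: int) -> List[int]:
    # Like generate(): max_length zeros (empty cache slots) followed by
    # a single 1 for the token currently being processed.
    return [0] * max_length + [1]


def update_mask(mask: List[int]) -> List[int]:
    # Like _update_attention: drop the oldest entry, append a 1 for the new token.
    return mask[1:] + [1]


mask = initial_decoder_mask(max_length=4)
history = [mask]
for _ in range(3):
    mask = update_mask(mask)
    history.append(mask)

for m in history:
    print(m)
# [0, 0, 0, 0, 1]
# [0, 0, 0, 1, 1]
# [0, 0, 1, 1, 1]
# [0, 1, 1, 1, 1]
```

Because the mask and the self-attention cache both slide by one position per step, the 1s line up with the cache slots that hold real tokens plus the current token.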



Now let's test inference on CPU with all the wrappers before tracing.

In [6]:
# Let's set some run parameters

model_name = "t5-large"
num_beams = 1
num_return_sequences = 1
max_length = 128

In [7]:
from transformers import T5Tokenizer


prompt="translate English to German: Lets eat good food."
        
tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=max_length)
model = T5Wrapper.from_pretrained(model_name)

cache_initializer = CacheInitializer(model.decoder, model.config, num_beams, max_length, "cpu")
model.encoder = EncoderWrapper(model.encoder)
setattr(model.encoder, 'main_input_name', 'input_ids')  # Attribute required by beam search

model.decoder = DecoderWrapper(decoder=model.decoder,
                                lm_head=model.lm_head,
                                model_config=model.config,
                                num_beams=num_beams,
                                max_length=max_length,
                                device="cpu")

output = model.generate(tokenizer=tokenizer,
                        cache_initializer=cache_initializer,
                        prompt=prompt,
                        max_length=max_length,
                        num_beams=num_beams,
                        num_return_sequences=num_return_sequences,
                        device="cpu")

results = [tokenizer.decode(t, skip_special_tokens=True) for t in output]

print('Results:')
for i, summary in enumerate(results):
    print(i + 1, summary)


Results:
1 Lassen Sie uns gutes Essen essen.


Now that the wrappers are running as expected, let's trace the encoder, the decoder, and the cache initializer. To trace each of them, we pass the module and a sample input to the trace function. Tracing produces a static executable: the operations to run at inference time are fixed during compilation. As a result, the compiled Neuron model must be called with tensors of exactly the same shapes as the sample inputs used during tracing; a tensor with a different shape causes an error at inference time. This is why the tokenizer pads every prompt to `max_length`.
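The sketch below illustrates the fixed-shape requirement in plain Python: a helper truncates or right-pads token ids to a compile-time length and builds the matching attention mask (similar in spirit to the tokenizer call with `padding='max_length'` and `truncation=True`), and a hypothetical `run_compiled` stand-in rejects any other length. The pad id 0 matches T5's pad token; the length 8 is illustrative only.

```python
from typing import List, Tuple


def pad_to_fixed_length(ids: List[int], max_length: int,
                        pad_id: int = 0) -> Tuple[List[int], List[int]]:
    """Truncate or right-pad ids to exactly max_length, plus the attention mask."""
    ids = ids[:max_length]
    mask = [1] * len(ids) + [0] * (max_length - len(ids))
    ids = ids + [pad_id] * (max_length - len(ids))
    return ids, mask


COMPILED_LENGTH = 8  # hypothetical compile-time sequence length


def run_compiled(ids: List[int]) -> int:
    """Hypothetical stand-in for a traced model with a fixed input shape."""
    if len(ids) != COMPILED_LENGTH:
        raise ValueError(f"expected length {COMPILED_LENGTH}, got {len(ids)}")
    return sum(ids)


ids, mask = pad_to_fixed_length([37, 12, 5], COMPILED_LENGTH)
print(ids)   # [37, 12, 5, 0, 0, 0, 0, 0]
print(mask)  # [1, 1, 1, 0, 0, 0, 0, 0]
run_compiled(ids)  # accepted: the shape matches the compiled one
```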

The decoder wrapper returns the new cache state as an output, and by default outputs are copied back to the CPU. As the cache is large, copying it to and from the XLA device on every decoder invocation would significantly slow down inference. Instead, we use input-output aliasing, a feature of `torch_neuronx` that keeps these tensors on the device rather than copying them back to the CPU. To use input-output aliasing, we map each cache output to the input parameter it updates when tracing.
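Because `DecoderWrapper.forward` returns `[lm_logits]` followed by the flattened cache, the cache tensor stored at index `i` comes back as output `i + 1`; that offset is what the `input_output_aliases` dictionary in `trace_decoder` encodes. The plain-Python sketch below rebuilds the same mapping with strings standing in for the `Parameter` objects and checks that every alias points at the updated version of its own cache entry.

```python
from typing import Dict, List

num_layers = 2
# Flat cache parameters, in the order DecoderWrapper stores them (4 per layer).
cache_params: List[str] = [f"past_{i}" for i in range(num_layers * 4)]

# Output 0 is lm_logits, so cache entry i is returned at output position i + 1.
aliases: Dict[str, int] = {p: i + 1 for i, p in enumerate(cache_params)}

# Simulated outputs of one decoder call: logits first, then the updated cache.
outputs = ["lm_logits"] + [f"new_{p}" for p in cache_params]
for param, out_idx in aliases.items():
    # Each aliased output must be the updated version of the same parameter.
    assert outputs[out_idx] == f"new_{param}"

print(aliases["past_0"], aliases["past_7"])  # 1 8
```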

In [None]:
import torch
import torch_neuronx

from transformers import T5Tokenizer, T5ForConditionalGeneration

def trace_encoder(model: T5ForConditionalGeneration,
                  tokenizer: T5Tokenizer,
                  max_length: int):

    # Trace encoder
    batch_encoding = tokenizer("translate English to German: Lets go home now",
                               max_length=max_length,
                               truncation=True,
                               padding='max_length', 
                               return_tensors="pt")
    input_ids = batch_encoding['input_ids']
    attention_mask = batch_encoding['attention_mask']

    encoder = EncoderWrapper(model.encoder)
    traced_encoder = torch_neuronx.trace(encoder, 
                                         (input_ids, attention_mask), 
                                         compiler_workdir="/tmp/encoder/")
    # Attribute required by beam search
    setattr(traced_encoder, 'main_input_name', 'input_ids')  

    return traced_encoder


def trace_decoder(model: T5ForConditionalGeneration,
                  num_beams: int,
                  max_length: int):

    decoder = DecoderWrapper(decoder=model.decoder,
                             lm_head=model.lm_head,
                             model_config=model.config,
                             num_beams=num_beams,
                             max_length=max_length,
                             device="xla")

    # We create mock inputs so we can trace the decoder
    decoder_input_ids = torch.ones((num_beams, 1), dtype=torch.int64)
    decoder_attention_mask = torch.ones((num_beams, max_length + 1), dtype=torch.int32)
    encoder_attention_mask = torch.ones((num_beams, max_length), dtype=torch.int64)
    encoder_hidden_states = torch.ones((num_beams, max_length, model.config.d_model), dtype=torch.float32)

    beam_idx = torch.arange(0, num_beams, dtype=torch.int64)

    traced_decoder = torch_neuronx.trace(decoder, (
        decoder_input_ids,
        decoder_attention_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        beam_idx
    ), 
    input_output_aliases={decoder.past_key_values[i]:i+1 for i in range(len(decoder.past_key_values))}, 
    compiler_workdir="/tmp/decoder/")

    return traced_decoder

def trace_cache_initializer(model: T5ForConditionalGeneration,
                            num_beams: int,
                            max_length: int):
        
    encoder_hidden_states = torch.ones((num_beams, 
                                        max_length, 
                                        model.config.d_model), dtype=torch.float32)

    cache_initializer = CacheInitializer(model.decoder,
                                         model.config, 
                                         num_beams, 
                                         max_length, 
                                         "xla")
    cache_initializer = torch_neuronx.trace(cache_initializer, 
                                            encoder_hidden_states, 
                                            compiler_workdir="/tmp/cache_init/")

    return cache_initializer



tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=max_length)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# We enable this flag to ensure the model uses attention key-value caching
model.config.use_cache = True

traced_encoder = trace_encoder(model, tokenizer, max_length)
traced_decoder = trace_decoder(model, num_beams, max_length)
traced_cache_initializer = trace_cache_initializer(model, num_beams, max_length)

torch.jit.save(traced_encoder, "TracedEncoder.pt")
torch.jit.save(traced_decoder, "TracedDecoder.pt")
torch.jit.save(traced_cache_initializer, "TracedCacheInitializer.pt")

## Run inference with greedy decoding
Now that we have the traced model, let's use it for inference. 

In [9]:
runtime = torch.classes.neuron.Runtime()
runtime.initialize()
runtime.set_default_neuron_cores(0, 1)

tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5Wrapper.from_pretrained(model_name)

model.encoder = torch.jit.load("TracedEncoder.pt")
# Attribute required by beam search
setattr(model.encoder, 'main_input_name', 'input_ids')  

model.decoder = torch.jit.load("TracedDecoder.pt")
torch_neuronx.move_trace_to_device(model.decoder, 0)

cache_init_trace = torch.jit.load("TracedCacheInitializer.pt")


output = model.generate(tokenizer=tokenizer,
                        cache_initializer=cache_init_trace,
                        prompt="translate English to German: Lets eat good food.",
                        max_length=max_length,
                        num_beams=num_beams,
                        num_return_sequences=num_return_sequences,
                        device="xla")

results = [tokenizer.decode(t, skip_special_tokens=True) for t in output]

print('XLA Texts:')
for i, summary in enumerate(results):
    print(i + 1, summary)



XLA Texts:
1 Lassen Sie uns gutes Essen essen.
