# T5 inference with Tensor Parallelism

This is an extension to the [t5 inference tutorial](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/torch-neuronx/t5-inference-tutorial.html). Here we will use NeuronxDistributed to improve the inference performance using tensor parallelism.

This tutorial has the following main sections:

1. Install dependencies
1. Plug in `NeuronxDistributed` layers into T5
1. Compile the T5 model
1. Run distributed inference with beam search

This Jupyter notebook should be run on an Inf2 instance (`inf2.24xlarge`) or a Trn1 instance (`trn1.32xlarge`).


> Do note that flan-t5 models do not work with the code in this tutorial. We are working on fixing that. 

## Install dependencies

The code in this tutorial is written for Jupyter Notebooks. To use Jupyter Notebook on the Neuron instance, you
can use this [guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/notebook/setup-jupyter-notebook-steps-troubleshooting.html).

It is recommended to go through the [t5 inference tutorial](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/torch-neuronx/t5-inference-tutorial.html) before you start this tutorial. 
In addition to the dependencies in the [t5 inference tutorial](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/torch-neuronx/t5-inference-tutorial.html), we need to install neuronx-distributed. 

This tutorial requires the following pip packages:

- `torch-neuronx`
- `neuronx-cc`
- `transformers`
- `optimum-neuron`
- `neuronx-distributed`

Most of these packages will be installed when configuring your environment using the Trn1/Inf2 setup guide. The additional dependencies must be installed here:

In [None]:
! pip install --upgrade transformers==4.31.0 optimum-neuron==0.0.8 neuronx_distributed --extra-index-url https://pip.repos.neuron.amazonaws.com

## Plug in NeuronxDistributed layers into T5

We extend Hugging Face's T5 model to use the `NeuronxDistributed` parallel layers. To do so, we swap the linear layers in the `T5LayerSelfAttention`, `T5LayerCrossAttention`, and `T5LayerFF` definitions with `ColumnParallelLinear` and `RowParallelLinear`. We also need to swap the `Embedding` layer with `ParallelEmbedding`.

Let us take `T5Attention` as an example. The [attention block](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L363-L366) has q, k, v, and o linear layers.
The multi-head attention block uses q, k, and v to compute the attention scores and the per-head attention outputs. These outputs are then passed through o to compute the attention block output.
So let us swap the q, k, and v layers with `ColumnParallelLinear` and o with `RowParallelLinear`. Placing a `RowParallelLinear` after a `ColumnParallelLinear` is a performance optimization: the per-head outputs computed from q, k, and v are already split across Neuron devices, so the row parallel layer can consume this sharded output directly without gathering it first.
The embedding layer is simply swapped with `ParallelEmbedding`.

```
class ParallelAttention(T5Attention):
    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__(config, has_relative_attention_bias)
        # Per attention head and per partition values
        world_size = parallel_state.get_tensor_model_parallel_size()
        self.num_attention_heads_per_partition = divide(self.n_heads, world_size)
        self.hidden_size_per_partition = self.num_attention_heads_per_partition * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = ColumnParallelLinear(self.d_model,
                                      self.inner_dim,
                                      bias=False,
                                      gather_output=False)
        self.k = ColumnParallelLinear(self.d_model,
                                      self.inner_dim,
                                      bias=False,
                                      gather_output=False)
        self.v = ColumnParallelLinear(self.d_model,
                                      self.inner_dim,
                                      bias=False,
                                      gather_output=False)
        self.o = RowParallelLinear(self.inner_dim,
                                   self.d_model,
                                   bias=False,
                                   input_is_parallel=True)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = ParallelEmbedding(self.relative_attention_num_buckets, self.n_heads)
        self.n_heads = self.num_attention_heads_per_partition
...
```
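
To make the column-then-row pairing concrete, here is a minimal pure-Python sketch of the underlying arithmetic. It is not the `NeuronxDistributed` implementation: the helper names (`matmul`, `split_columns`, `split_rows`, `add`) and the tiny matrices are made up for illustration, the ranks are simulated by a loop, and the per-head attention computation that sits between q/k/v and o is left out. It only shows that when the first weight is split by columns and the second by rows, each rank can work on its own shard and a single sum at the end recovers the unsharded result.

```python
# Pure-Python sketch: a column-parallel layer followed by a row-parallel layer.
# Ranks are simulated by a loop; no Neuron or torch APIs are used.

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def split_columns(m, parts):
    """Split a matrix into `parts` equal blocks of columns."""
    width = len(m[0]) // parts
    return [[row[i * width:(i + 1) * width] for row in m] for i in range(parts)]

def split_rows(m, parts):
    """Split a matrix into `parts` equal blocks of rows."""
    height = len(m) // parts
    return [m[i * height:(i + 1) * height] for i in range(parts)]

def add(a, b):
    """Element-wise sum of two matrices (stands in for the final reduction)."""
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]

num_ranks = 2                               # hypothetical tensor parallel degree
x = [[1, 2], [3, 4]]                        # input activations, shape (2, 2)
w_in = [[1, 0, 2, 1], [0, 1, 1, 3]]         # first weight, shape (2, 4)
w_out = [[1, 2], [0, 1], [2, 0], [1, 1]]    # second weight, shape (4, 2)

# Reference: unsharded computation.
expected = matmul(matmul(x, w_in), w_out)

# Sharded: each rank holds a column block of w_in and the matching row block of w_out.
partials = []
for w_in_shard, w_out_shard in zip(split_columns(w_in, num_ranks),
                                   split_rows(w_out, num_ranks)):
    hidden_shard = matmul(x, w_in_shard)    # stays local to the rank, never gathered
    partials.append(matmul(hidden_shard, w_out_shard))

# One sum across ranks reproduces the full result.
result = partials[0]
for partial in partials[1:]:
    result = add(result, partial)

assert result == expected
```

In the `ParallelAttention` code above, `gather_output=False` and `input_is_parallel=True` express the same idea: the intermediate activations stay sharded between the two layers.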

You can find all the modified T5 layers defined in [t5_model_layers.py](https://github.com/aws-neuron/aws-neuron-sdk/tree/master/src/examples/pytorch/neuronx_distributed/t5-inference/t5_model_layers.py).


Once we have the modified T5 layers, we can plug the parallel attention and feed-forward layers into the pretrained model. Here is how to do that.

```
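import neuronx_distributed
from transformers import T5ForConditionalGeneration


def load_pretrained_with_parallel_attn(model_name):
    # Start from the stock Hugging Face model; its decoder layers are replaced below.
    model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")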

    # Parallel implementation of Attention modules.
    from t5_model_layers import ParallelSelfAttention, ParallelFF, ParallelCrossAttention

    for index, block in enumerate(model.decoder.block):
        if index == 0:
            block.layer[0] = ParallelSelfAttention(model.config,
                                                   has_relative_attention_bias=True)
        else:
            block.layer[0] = ParallelSelfAttention(model.config)
        block.layer[1] = ParallelCrossAttention(model.config)
        block.layer[2] = ParallelFF(model.config)
    # Load the weights into the parallel layers        
    neuronx_distributed.parallel_layers.load(model_name + ".pt", model, sharded=False)

    return model

```


## Compile the parallel T5 model

Let us set some model parameters:

In [None]:
model_name = "t5-3b"
max_length = 128
num_beams = 4
tp_degree = 8 # tensor parallelism degree
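
Because `ParallelAttention` computes its per-partition head count with `divide(self.n_heads, world_size)`, the chosen tensor parallelism degree should split the attention heads evenly. The sketch below is an optional sanity check, not part of the tutorial code: `heads_per_partition` is a hypothetical helper, and the values 32 and 128 are illustrative placeholders that should be replaced with `model.config.num_heads` and `model.config.d_kv` from the model you load.

```python
def heads_per_partition(num_heads: int, tp_degree: int) -> int:
    """Return how many attention heads each tensor-parallel rank would own."""
    if tp_degree <= 0:
        raise ValueError("tp_degree must be positive")
    if num_heads % tp_degree != 0:
        raise ValueError(
            f"{num_heads} heads cannot be split evenly across {tp_degree} ranks"
        )
    return num_heads // tp_degree


# Illustrative values only (assumptions); read the real ones from
# model.config.num_heads and model.config.d_kv.
example_num_heads = 32
example_d_kv = 128

local_heads = heads_per_partition(example_num_heads, tp_degree=8)
local_hidden = local_heads * example_d_kv  # per-partition hidden size
print(local_heads, local_hidden)
```

With these placeholder values each rank would own 4 heads and a per-partition hidden size of 512; a degree that does not divide the head count raises an error before any compilation time is spent.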

Download and save the model that we want to trace. 

In [None]:
import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
torch.save({"model":model.state_dict()}, model_name + ".pt")
model.config.use_cache = True

To run Hugging Face T5 models on Neuron, we need to make a couple of changes. Let us reuse the code from the [t5 inference tutorial](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/torch-neuronx/t5-inference-tutorial.html), which makes T5 compatible with Neuron. For your convenience, the code has been copied into [wrapper.py](https://github.com/aws-neuron/aws-neuron-sdk/tree/master/src/examples/pytorch/neuronx_distributed/t5-inference/wrapper.py) and [t5_models.py](https://github.com/aws-neuron/aws-neuron-sdk/tree/master/src/examples/pytorch/neuronx_distributed/t5-inference/t5_models.py). This notebook imports these files.

The only change made to this code is that we use `neuronx_distributed.trace` instead of `torch_neuronx.trace`. 

Let us trace the encoder and decoder. 

In [None]:
import t5_models  
import neuronx_distributed

traced_encoder = t5_models.parallel_trace_encoder(model_name, max_length, num_beams, tp_degree)
neuronx_distributed.trace.parallel_model_save(traced_encoder, "TracedParallelEncoder.pt")


In [None]:
traced_decoder = t5_models.parallel_trace_decoder(model, model_name, num_beams, max_length, tp_degree)
neuronx_distributed.trace.parallel_model_save(traced_decoder, "TracedParallelDecoder.pt")

## Inference with the traced parallel T5 model

With the traced model, let us try using beam search for inference.

In [None]:
import neuronx_distributed
from wrapper import T5Wrapper
from transformers import T5Tokenizer


num_return_sequences = 4

traced_encoder = neuronx_distributed.trace.parallel_model_load("TracedParallelEncoder.pt")
traced_decoder = neuronx_distributed.trace.parallel_model_load("TracedParallelDecoder.pt")

tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5Wrapper.from_pretrained(model_name)

model.encoder = traced_encoder
model.decoder = traced_decoder
model.encoder.main_input_name = 'input_ids'  # Attribute required by beam search

output = model.parallel_infer(tokenizer=tokenizer,
                              prompt="translate English to German: Lets eat good food.",
                              max_length=max_length,
                              num_beams=num_beams,
                              num_return_sequences=num_return_sequences,
                              device="xla")

results = [tokenizer.decode(t, skip_special_tokens=True) for t in output]

print('Results:')
for i, translation in enumerate(results):
    print(i + 1, translation)


Results:
1 Lassen Sie uns gutes Essen essen.
2 Lassen Sie uns gut essen.
3 Lassen Sie uns gutes Essen zu essen.
4 Lassen Sie uns gutes Essen zu sich nehmen.
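
For readers who want an intuition for what `num_beams` and `num_return_sequences` control, here is a toy beam search in plain Python. It is only a sketch under simplifying assumptions and is not the implementation used by `T5Wrapper` or Hugging Face: the `beam_search` function, the `toy_model` callable, and the probability table are invented for illustration, and end-of-sequence handling and length penalties are omitted.

```python
import math


def beam_search(next_log_probs, start, num_beams, max_length, num_return_sequences):
    """Toy beam search.

    `next_log_probs(seq)` must return a dict {token: log_prob} for the next step.
    """
    beams = [(0.0, [start])]
    for _ in range(max_length):
        candidates = []
        for score, seq in beams:
            for token, log_prob in next_log_probs(seq).items():
                candidates.append((score + log_prob, seq + [token]))
        # Keep only the num_beams highest-scoring partial sequences.
        beams = sorted(candidates, key=lambda c: c[0], reverse=True)[:num_beams]
    return [seq for _, seq in beams[:num_return_sequences]]


# Hypothetical next-token distribution that depends only on the last token.
TABLE = {
    "<s>": {"a": math.log(0.6), "b": math.log(0.4)},
    "a": {"a": math.log(0.1), "b": math.log(0.9)},
    "b": {"a": math.log(0.7), "b": math.log(0.3)},
}


def toy_model(seq):
    return TABLE[seq[-1]]


top = beam_search(toy_model, "<s>", num_beams=2, max_length=2, num_return_sequences=2)
for rank, seq in enumerate(top, start=1):
    print(rank, " ".join(seq[1:]))
```

The search keeps `num_beams` candidates alive at every step and returns the best `num_return_sequences` of them at the end, which is why `num_return_sequences` cannot exceed `num_beams` in the inference cell above.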
