# Accelerating Generation of the Hugging Face Gemma Model with Transformer Engine

Generative AI has made remarkable strides in recent years, with Large Language Models (LLMs) like ChatGPT at the forefront. These models have revolutionized how we interact with machine-generated content, providing capabilities that range from writing assistance to complex decision support. The core functionality of these models is the generation process, which involves predicting the next token in a sequence based on the preceding text. This task is critical for applications such as automated content creation, translation, and more, emphasizing the importance of efficient implementation.

For a deeper look at how text generation works in Hugging Face Transformers, see the [Hugging Face generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).

In our previous tutorials on [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) and [Gemma](./tutorial_accelerate_hf_gemma_with_te.ipynb), we demonstrated how finetuning can be accelerated using the Transformer Engine's `TransformerLayer`. Building on this foundation, our current objective is to enhance the generation speed of the Gemma model.

This tutorial will introduce and explain several advanced features of the Transformer Engine that contribute to this goal:

##### 1. THD Attention Layout.

A common way to compute attention over sequences of different lengths is to pad them to a common length and apply an attention mask. Transformer Engine offers a more efficient approach: given the lengths and offsets of the sequences, it computes attention directly, without a mask. Instead of passing a tensor of shape `[b, s, h, d]` together with a mask, one passes a tensor of shape `[t, h, d]`, where the `t` dimension holds the tokens of all sequences stacked together, along with tensors describing the sequence lengths and offsets. This attention layout is referred to as the **THD layout**.
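
To make the layout concrete, here is a minimal sketch in plain Python that packs a right-padded batch into THD order and builds the cumulative sequence lengths. Strings stand in for the `[h, d]` slice of each token, and `pack_to_thd` is a hypothetical helper for illustration, not a Transformer Engine function.

```python
from itertools import accumulate


def pack_to_thd(padded, lengths):
    """Pack a right-padded batch [b][s] into a flat token list [t].

    Only the first lengths[i] tokens of padded[i] are real; the rest is padding.
    Returns the packed tokens and the cumulative sequence lengths (cu_seqlens).
    """
    packed = []
    for seq, n in zip(padded, lengths):
        packed.extend(seq[:n])  # keep real tokens, drop padding
    cu_seqlens = [0] + list(accumulate(lengths))
    return packed, cu_seqlens


# Toy batch with b = 3, s = 4; each string is a placeholder for an [h, d] slice.
PAD = None
padded = [
    ["a0", "a1", PAD, PAD],
    ["b0", "b1", "b2", "b3"],
    ["c0", PAD, PAD, PAD],
]
lengths = [2, 4, 1]

packed, cu_seqlens = pack_to_thd(padded, lengths)
print(packed)      # ['a0', 'a1', 'b0', 'b1', 'b2', 'b3', 'c0']
print(cu_seqlens)  # [0, 2, 6, 7]
```

The packed buffer holds 7 tokens instead of `b * s = 12`, and the tokens of sequence `i` are `packed[cu_seqlens[i]:cu_seqlens[i + 1]]`, so no mask is needed to tell sequences apart.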

<center>
<img src="./media/atn1.png" alt="" width= "400"><br>
Fig. 1. The sequences and the mask for standard attention layout - padding from the end.<br><br>
<img src="./media/atn2.png" alt="" width="400"><br>
Fig. 2. The sequences and the mask for standard attention layout - padding from the beginning.<br><br>
<img src="./media/atn3.png" alt="" width="400"><br>
Fig. 3. Attention with the THD layout.<br><br>
</center>

##### 2. FP8 Weight Calibration.

Suppose we have a model trained in FP32/BF16 precision and we want to run it in FP8. This is not straightforward, because appropriate FP8 scaling factors are missing. In this scenario, FP8 calibration becomes essential: by running several forward passes on sample data, we can compute the FP8 scaling parameters, which allow the model to operate correctly in FP8 precision.
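
A minimal sketch of why the scaling factor matters, assuming the common per-tensor recipe `scale = fp8_max / amax` and E4M3's largest finite value of 448. It only models the limited range (values are clamped, not rounded to the FP8 grid), and `fp8_scale` and `saturate` are hypothetical helpers, not Transformer Engine's implementation.

```python
E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def fp8_scale(amax, fp8_max=E4M3_MAX):
    """Scale that maps the observed range [-amax, amax] onto the FP8 range."""
    if amax <= 0.0:
        return 1.0  # nothing observed yet: leave values unscaled
    return fp8_max / amax


def saturate(x, fp8_max=E4M3_MAX):
    """Clamp to the representable range (ignores rounding to the FP8 grid)."""
    return max(-fp8_max, min(fp8_max, x))


activations = [0.5, -1200.0, 37.0, 900.0]
amax = max(abs(v) for v in activations)  # 1200.0

# Without a scale, large values are clipped.
unscaled = [saturate(v) for v in activations]

# With a scale, values are scaled into range, "stored", then unscaled.
scale = fp8_scale(amax)
scaled = [saturate(v * scale) / scale for v in activations]

print(unscaled)  # [0.5, -448.0, 37.0, 448.0]
```

Without a scale, every value above 448 in magnitude is clipped; with a scale derived from `amax`, the whole range fits and the original values are recovered after dividing by the scale (up to FP8 rounding, which this sketch ignores). Calibration is what supplies a sensible `amax` for a model that was never run in FP8.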

We highly recommend familiarizing yourself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.

##### 3. CUDA Graphs API.

GPUs are getting faster at a rapid pace. As a result, a kernel's runtime is sometimes shorter than the time the CPU needs to submit it, which can introduce significant overhead. CUDA Graphs were developed to address this issue: when the same sequence of kernels is executed repeatedly, it can be recorded once and then replayed with a single launch, instead of submitting each kernel from the CPU. This is particularly useful in text generation, where a `TransformerLayer` is run for every generated token.

We recommend reading further about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).

PyTorch exposes graphs via a raw `torch.cuda.CUDAGraph` class and two convenience wrappers, `torch.cuda.graph` and `torch.cuda.make_graphed_callables`. More information about CUDA Graphs in PyTorch can be found [here](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).

Transformer Engine supports CUDA Graphs starting from version 1.5.
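
The launch-overhead argument above can be made concrete with a back-of-the-envelope model. This is a toy sketch under simplifying assumptions (fully asynchronous submission, a graph launch costing about as much as one kernel launch, no other host work); `step_time_us` and its numbers are hypothetical, not measurements.

```python
def step_time_us(num_kernels, launch_us, kernel_us, graphed=False):
    """Toy estimate of one generation step's wall time, in microseconds.

    Assumes the CPU enqueues kernels asynchronously, so the step lasts as long
    as the slower side (CPU submission or GPU execution). With a CUDA Graph,
    the whole kernel sequence is submitted with a single launch.
    """
    cpu_us = launch_us if graphed else num_kernels * launch_us
    gpu_us = num_kernels * kernel_us
    return max(cpu_us, gpu_us)


# Hypothetical numbers for illustration only, not measurements:
# 500 small kernels per step, 5 us of CPU time to launch each, 2 us of GPU time each.
eager = step_time_us(500, launch_us=5.0, kernel_us=2.0)
graphed = step_time_us(500, launch_us=5.0, kernel_us=2.0, graphed=True)
print(eager, graphed)  # 2500.0 1000.0
```

In this toy setting the eager step is CPU-bound and graphing makes it GPU-bound. Real gains depend on the model, batch size and hardware, which is why this tutorial measures them directly.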

#### Benchmarking

We'll evaluate the generation time across two benchmarks:
- Long input sequences (up to 256 tokens) with short generation (up to 128 tokens),
- Short input sequences (up to 64 tokens) with long generation (up to 1000 tokens).

All benchmarks are conducted with a batch size of 64 using the dataset "timdettmers/openassistant-guanaco".

<div class="alert alert-info">
<b>Note</b>
    
This tutorial focuses on showcasing the mentioned features of Transformer Engine in the context of generation. It's important to note, however, that NVIDIA provides another library, [TensorRT](https://developer.nvidia.com/tensorrt), which is optimized for inference tasks and should be considered for such use cases.
</div>

## Dependencies for this tutorial

The following files and media are needed to run this tutorial:

1. `te_gemma.py`
    - This file contains the code to load a Hugging Face Gemma checkpoint into Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. It also contains code for generation with THD attention and for weight calibration.
2. `utils.py`
    - This file contains the code for data loading, hyperparameters, setting up the model/optimizers/accelerator, model training, and other miscellaneous tasks such as restarting the Jupyter notebook from within a cell.
3. `media/`
    - This directory contains the images used in this tutorial.

## [Baseline] Running Hugging Face generation with Gemma model

The Hugging Face Transformers library offers a generation API. We use Hugging Face generation with the Gemma model as our baseline.

In [None]:
# Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
#restart_jupyter_notebook()

# Import necessary packages and methods
from utils import *

# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`
## !!! `model_name` attr must point to the location of the model weights !!!
## Weights can be downloaded from: https://huggingface.co/google/gemma-7b
hyperparams.model_name = "../../../../gemma-weights"  # <== Add model weight location here e.g. "/path/to/downloaded/gemma/weights"
hyperparams.mixed_precision = "bf16"

# Baseline: the original Hugging Face Gemma model (no Transformer Engine layers)
model = init_baseline_model(hyperparams).cuda()

generate_sample_text(model)
benchmark_generation(model)

We put these times into the table for later comparison.

| Models                                                      | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 |  
|-------------------------------------------------------------|---------------------------------------|--------------------------------------|
| HF (baseline)                                               | -      | -                         |  

## [Improvement 1] Speeding up generation by using Transformer Engine with THD attention

As in the Gemma finetuning tutorial, we substitute `GemmaDecoderLayer` with `TransformerLayer` from Transformer Engine.

Input sequences can have various lengths. The most common approach in this situation is to use padding and attention masks. We will use a more direct method instead: the THD attention layout with offsets.

<center>
<span style="display: flex; flex-direction: row; justify-content: center">
<span style="display: flex; flex-direction: column; align-items: center">
Query layer   
<img src="./media/pic1.png" alt="" height="200">
</span>
<span style="display: flex; flex-direction: column; align-items: center">
Key layer and value layer  
<img src="./media/pic2.png" alt="" height="200">
</span>
</span>
cu_seqlens_q = [0, 1, 3, 7, 9, 12] <br>
cu_seqlens_kv = [0, 1, 3, 6, 8, 10] <br>
seq_offsets_q = [0, 5, 10, 15, 20, 25] * h * d <br>
seq_offsets_k = [0, 7, 14, 21, 28, 35] * h * d <br>
seq_offsets_v = [0, 7, 14, 21, 28, 35] * h * d <br>
</center>

The class `transformer_engine.pytorch.DotProductAttention` supports this format. To use it, pass the following arguments to its `forward` method:
- `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` - the offsets at which each sequence begins in the query, key, and value tensors,
- `cu_seqlens_q`, `cu_seqlens_kv` - cumulative sums of the sequence lengths in the query and key/value tensors,
- `max_seqlen_q` - maximum sequence length in query layer,
- `max_seqlen_kv` - maximum sequence length in key-value layer.
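
The values in the figure above can be reproduced from the per-sequence lengths. This sketch assumes, as the offsets in the figure suggest, that each sequence occupies a fixed-size slot of 5 tokens in the query buffer and 7 tokens in the key/value buffer. `thd_metadata` is a hypothetical helper; in practice these lists would be turned into tensors before being passed to `forward`. Keys and values share the same layout here, so `seq_offsets_k` and `seq_offsets_v` both equal `seq_offsets_kv`.

```python
from itertools import accumulate


def thd_metadata(lengths, slot_len, h, d):
    """Return (cu_seqlens, seq_offsets, max_seqlen) for sequences stored in slots.

    Sequence i is assumed to start at token i * slot_len of its buffer, and every
    token holds h * d elements, so offsets are expressed in elements.
    """
    cu_seqlens = [0] + list(accumulate(lengths))
    seq_offsets = [i * slot_len * h * d for i in range(len(lengths) + 1)]
    return cu_seqlens, seq_offsets, max(lengths)


# Per-sequence lengths read off the figure (h = d = 1 to compare numbers directly).
q_lens = [1, 2, 4, 2, 3]
kv_lens = [1, 2, 3, 2, 2]

cu_seqlens_q, seq_offsets_q, max_seqlen_q = thd_metadata(q_lens, slot_len=5, h=1, d=1)
cu_seqlens_kv, seq_offsets_kv, max_seqlen_kv = thd_metadata(kv_lens, slot_len=7, h=1, d=1)

print(cu_seqlens_q, seq_offsets_q, max_seqlen_q)     # [0, 1, 3, 7, 9, 12] [0, 5, 10, 15, 20, 25] 4
print(cu_seqlens_kv, seq_offsets_kv, max_seqlen_kv)  # [0, 1, 3, 6, 8, 10] [0, 7, 14, 21, 28, 35] 3
```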

<div class="alert alert-info">

<b>Note</b>
Currently, THD attention in `TransformerLayer` is supported only for inference.
</div>

Let's see how using Transformer Engine with THD attention affects generation speed:

In [None]:
# Import necessary packages and methods
from utils import *

# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`
## !!! `model_name` attr must point to the location of the model weights !!!
## Weights can be downloaded from: https://huggingface.co/google/gemma-7b
hyperparams.model_name = "../../../../gemma-weights"  # <== Add model weight location here e.g. "/path/to/downloaded/gemma/weights"
hyperparams.mixed_precision = "bf16"
hyperparams.fuse_qkv_params = False

# Initialize the model in BF16 and move it to the GPU
model = init_te_gemma_model(hyperparams).to(torch.bfloat16).cuda()

generate_sample_text(model)
benchmark_generation(model)

Using THD attention, we obtained the following speedups:

| Models                                                      | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 |  
|-------------------------------------------------------------|---------------------------------------|--------------------------------------|
| HF (baseline)                                               | -      | -                         |
| THD attention with TE                                               | -      | -                         |  

## [Improvement 2] Running generation in FP8 of the model trained in higher precision 

We are now preparing to execute FP8 generation using the Gemma model. However, this process is not straightforward. Since the model was originally trained with BF16 precision, the FP8 scaling factors have not been computed. Operating the model at such low precision without the correct scaling could result in significant numerical errors, which in turn would produce incorrect results.

We highly recommend familiarizing yourself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the necessity of scaling.

##### Weight Calibration

To address this issue, we will perform weight calibration. This involves running several forward iterations in BF16 precision within the context `te.fp8_autocast(enabled=False, calibrating=True)`. This setup lets the forward pass run in higher precision while we collect `amax_history` and other FP8-related statistics, which are needed to compute the FP8 scaling factors.
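
Conceptually, calibration keeps a rolling window of the largest absolute value seen in each forward pass and derives a scale from it. The sketch below is a simplified, pure-Python illustration, not Transformer Engine's implementation (which tracks these statistics on the GPU per tensor); `AmaxHistory` is hypothetical, while the window length and the "max" reduction mirror the `amax_history_len` and `amax_compute_algo="max"` settings of the `DelayedScaling` recipe used in `te_gemma.py`. The 448 constant assumes the E4M3 format.

```python
from collections import deque


class AmaxHistory:
    """Rolling window of per-step amax values for one tensor (illustrative only)."""

    def __init__(self, history_len=32):
        self.history = deque(maxlen=history_len)  # oldest entries drop out

    def observe(self, values):
        # Record the largest absolute value seen in this forward pass.
        self.history.append(max(abs(v) for v in values))

    def amax(self):
        # "max" reduction over the window, as with amax_compute_algo="max".
        return max(self.history) if self.history else 0.0


hist = AmaxHistory(history_len=3)
for batch in ([0.1, -2.0], [4.0, 1.0], [-0.5, 3.0], [1.0, 1.5]):
    hist.observe(batch)  # one "calibration" forward pass per batch

scale = 448.0 / hist.amax()  # assumes E4M3 (max 448) and no margin
print(list(hist.history))  # [4.0, 3.0, 1.5]
print(hist.amax(), scale)  # 4.0 112.0
```

The first observation (2.0) has already left the three-step window, which shows why a longer history makes the scale more robust to occasional outliers.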

The code below initializes the BF16 model and runs several forward iterations within this context. After these iterations, we save the model; these weights will be used in the subsequent sections.

In [None]:
# Import necessary packages and methods
import transformer_engine.pytorch as te
from utils import *
import torch


hyperparams.model_name = "../../../../gemma-weights"  # <== Add model weight location here
hyperparams.fuse_qkv_params = True

# Initialize the model with BF16 weights (FP8 scaling factors not computed yet)
model = init_te_gemma_model(hyperparams, fp8_model_init=False).cuda()
model = model.to(torch.bfloat16)

accelerator = Accelerator(
    log_with="wandb",
    gradient_accumulation_steps=hyperparams.gradient_accumulation_steps,
    mixed_precision=hyperparams.mixed_precision,
)
train_dataloader = get_dataloaders(accelerator, hyperparams)

tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)

# Step 1: run generation in BF16 while collecting FP8 statistics (amax history)
print("Calibration started")
with te.fp8_autocast(enabled=False, calibrating=True):
    model.train()
    train_dataloader = enumerate(train_dataloader)

    for i in range(100):
        step, batch = next(train_dataloader)
        batch["input_ids"] = batch["input_ids"].cuda()
        outputs = model.generate(
            **batch,
            max_new_tokens=10
        )
        generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print("Calibration finished")

# Step 2: run a few FP8 iterations so that the forward scaling factors
# (scale_fwd) are computed from the collected statistics
print("scale_fwd computation started")
with te.fp8_autocast(enabled=True):
    for i in range(10):
        step, batch = next(train_dataloader)
        batch["input_ids"] = batch["input_ids"].cuda()
        outputs = model.generate(
            **batch,
            max_new_tokens=1
        )
print("scale_fwd computation finished")

# Step 3: create a model with FP8 weights and copy in the calibrated state
print("Casting weights...")
model_fp8 = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda()
model_fp8.load_state_dict(model.state_dict())
print("Weights cast")

print("Saving model...")
torch.save(model_fp8.state_dict(), 'model_fp8_state_dict.pth')
print("Model saved!")

##### Generation in FP8

Now we are ready to run FP8 inference.

In [None]:
#Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
#restart_jupyter_notebook()

from utils import *

hyperparams.model_name = "../../../../gemma-weights"
hyperparams.fuse_qkv_params = True
model = init_te_gemma_model(hyperparams, fp8_model_init=True, qkv_format="thd").cuda()

# Load weights of the model with the proper scaling factors.
model.load_state_dict(torch.load('model_fp8_state_dict.pth'))

generate_sample_text(model, fp8=True)
benchmark_generation(model, fp8=True)

We add the speedups to the table:

| Models                                                      | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 |  
|-------------------------------------------------------------|---------------------------------------|--------------------------------------|
| HF (baseline)                                               | -      | -                         |
| THD attention with TE                                               | -      | -                         | 
| THD attention + FP8 with TE                                               | -      | -                         |  

## [Improvement 3] Speeding up generation with CUDA Graphs

Transformer Engine provides `transformer_engine.pytorch.make_graphed_callables`, which works similarly to its PyTorch counterpart `torch.cuda.make_graphed_callables` and can record any Transformer Engine module. Below is an excerpt from `te_gemma.py`:
```python
generator = GemmaGenerator(
    lm_head=self.lm_head,
    model=self.model,
    inference_params=inference_params,
    generation_config=generation_config,
    dtype=hidden_states.dtype,
)

# ...
if use_cuda_graphs:
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
    graphed_generator = te.pytorch.make_graphed_callables(
        generator,
        args,
        fp8_enabled=True,
        fp8_recipe=fp8_recipe,
        allow_unused_input=True,
        num_warmup_iters=10
    )

# ...

for i in range(max_new_tokens):
    next_tokens = graphed_generator(*args) if use_cuda_graphs else generator(*args)
    output_tokens.append(next_tokens.clone())
```

Let us now proceed to evaluate the performance improvement offered by CUDA Graphs.

In [None]:
#Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
#restart_jupyter_notebook()

from utils import *

hyperparams.model_name = "../../../../gemma-weights"
hyperparams.fuse_qkv_params = True
model = init_te_gemma_model(hyperparams, fp8_model_init=True, qkv_format="thd").cuda()

# Load weights of the model with the proper scaling factors.
model.load_state_dict(torch.load('model_fp8_state_dict.pth'))

generate_sample_text(model, fp8=True, use_cuda_graphs=True)
benchmark_generation(model, fp8=True, use_cuda_graphs=True)

We finally obtained the **??%** speedup.

| Models                                                      | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 |  
|-------------------------------------------------------------|---------------------------------------|--------------------------------------|
| HF (baseline)                                               | -      | -                         |
| THD attention with TE                                               | -      | -                         | 
| THD attention + FP8 with TE                                               | -      | -                         |  
| THD attention + FP8 + CUDA Graphs with TE                                               | -      | -                         |  

We can also see how using CUDA Graphs reduces CPU overhead. Here are two screenshots from the profiler:

<center>
<img src="./media/pic2.png" alt="Profiler trace of generation without CUDA Graphs" height="200">
<br>
Generation without CUDA Graphs
<br>

<img src="./media/pic2.png" alt="Profiler trace of generation with CUDA Graphs" height="200">
<br>
Generation with CUDA Graphs
</center>

## Conclusions

In this tutorial, we've explored three features of the Transformer Engine:
1. Support for the THD attention layout,
2. FP8 weight calibration,
3. Integration with CUDA Graphs.

Each of these features can be applied in various contexts; here we demonstrated their use for fast inference. For the fastest possible inference, however, NVIDIA's inference-optimized [TensorRT](https://developer.nvidia.com/tensorrt) library should be considered.