# Accelerating token generation of the Hugging Face Gemma Model with Transformer Engine

Generative AI has made remarkable strides in recent years, with Large Language Models (LLMs) like ChatGPT at the forefront. These models have revolutionized how we interact with machine-generated content, providing capabilities that range from writing assistance to complex decision support. The core functionality of these models is the generation process, which involves predicting the next token in a sequence based on the preceding text. This task is critical for applications such as automated content creation, translation, and more, emphasizing the importance of efficient implementation.



<figure align="center">
<img src="./media/generation_animation.gif" alt="" >
<figcaption>
Animation 1: Hugging Face Gemma model token generation.
</figcaption>
</figure>

For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [Hugging Face generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).

In the previous tutorials on [Llama](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb) and [Gemma](./tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb), it was demonstrated how finetuning can be accelerated using the Transformer Engine's `TransformerLayer`. Building on this foundation, the current objective is to enhance the generation speed of the Gemma model.

This tutorial will introduce and explain several advanced features of the Transformer Engine that contribute to this goal:

###### **1. THD Attention Layout.**

A common way to compute attention for sequences of varying lengths is to pad them and apply an attention mask. The Transformer Engine, however, offers a more optimized approach: by specifying the lengths and offsets of the sequences, attention can be computed directly. Instead of passing a tensor of shape `[b, s, h, d]` together with an attention mask, one can pass a tensor of shape `[t, h, d]` along with tensors holding the cumulative sequence lengths and offsets, and run attention optimized for this case. This attention layout is referred to as the **THD layout**.


In the standard `[t, h, d]` layout, `t` is the total length of all sequences, namely `t = s_1 + s_2 + ... + s_b`, where `s_i` denotes the length of sequence `i`. Transformer Engine also supports a THD layout with gaps between the sequences; in that case, the sequence offsets need to be passed in an additional parameter.
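To make the layout concrete, here is a minimal, framework-free sketch of how a padded batch maps to the packed form and its cumulative sequence lengths. The helper `pack_thd`, the `PAD` value and the toy token IDs are hypothetical and serve only as an illustration; the `h` and `d` dimensions are omitted, so each token is a single number rather than an `[h, d]` slice.

```python
from itertools import accumulate

PAD = 0  # hypothetical padding value


def pack_thd(padded_batch, seq_lens):
    """Drop the padding from a [b, s] batch and return (packed tokens, cu_seqlens)."""
    packed = []
    for row, length in zip(padded_batch, seq_lens):
        packed.extend(row[:length])  # keep only the real tokens of this sequence
    # cu_seqlens[i] is the position in `packed` where sequence i starts
    cu_seqlens = [0] + list(accumulate(seq_lens))
    return packed, cu_seqlens


# Toy batch: b = 3 sequences padded to s = 5
padded_batch = [
    [11, 12, 13, PAD, PAD],
    [21, 22, 23, 24, 25],
    [31, 32, PAD, PAD, PAD],
]
seq_lens = [3, 5, 2]

packed, cu_seqlens = pack_thd(padded_batch, seq_lens)
print(len(packed))   # t = s_1 + s_2 + s_3
print(cu_seqlens)
```

Sequence `i` can then be recovered as `packed[cu_seqlens[i]:cu_seqlens[i + 1]]`, which is why no attention mask is needed in this layout.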

<figure align="center">
<img src="./media/thd_bshd.svg" alt="">
<figcaption>
Figure 1: The difference between BSHD (default) and THD attention layouts is as follows: with BSHD, one needs to provide the attention mask, while with THD, one needs to provide cumulative sequence lengths and sequence offsets.
</figcaption>
</figure>

###### **2. CUDA Graphs API.**

The speed of GPUs is increasing at a rapid pace. Sometimes the runtime of a kernel is shorter than the time it takes the CPU to submit it, which can lead to significant overhead. CUDA Graphs can address this issue. When certain kernels are executed repeatedly, CUDA Graphs allow us to record them once and replay them with less CPU involvement. This becomes particularly useful in applications like token generation, where a `TransformerLayer` is run for every token that needs to be generated.
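The effect of launch overhead can be illustrated with a toy timeline model. This is only a sketch under simplifying assumptions: kernels run strictly one after another, every launch costs the same, and all numbers are made up rather than measured. The function names are hypothetical.

```python
def eager_time_us(num_kernels, launch_us, kernel_us):
    """Without graphs: a kernel starts once it is submitted and the GPU is free."""
    cpu_t = 0  # time at which the CPU finishes submitting the current kernel
    gpu_t = 0  # time at which the GPU finishes the current kernel
    for _ in range(num_kernels):
        cpu_t += launch_us
        gpu_t = max(gpu_t, cpu_t) + kernel_us
    return gpu_t


def graphed_time_us(num_kernels, graph_launch_us, kernel_us):
    """With graphs: one launch replays all recorded kernels back to back."""
    return graph_launch_us + num_kernels * kernel_us


# Hypothetical numbers: 100 kernels, 5 us to launch each, 2 us to execute each.
eager = eager_time_us(num_kernels=100, launch_us=5, kernel_us=2)
graphed = graphed_time_us(num_kernels=100, graph_launch_us=10, kernel_us=2)
print(eager, graphed)
```

In this model the eager run is bound by the CPU submission time, while the graphed run is bound by the GPU execution time.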

One can read more about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).

PyTorch exposes graphs via a raw `torch.cuda.CUDAGraph` class and two convenience wrappers: `torch.cuda.graph` and `torch.cuda.make_graphed_callables`. More information about CUDA Graphs in PyTorch can be found [here](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).

<figure align="center">
<img src="./media/graphs.svg" alt="">
<figcaption>
Figure 2: CUDA Graphs reduce the overhead generated by the long time it takes to launch a single kernel. It enables the recording and replaying of subsequent launches, thus reducing the total time used by the CPU.
</figcaption>
</figure>


###### **3. FP8 Weights Calibration.**

Assuming that the model is trained in FP32/BF16 precision and the goal is to execute it in FP8 precision, the process isn't straightforward due to the absence of appropriate FP8 scaling factors. In this scenario, FP8 calibration becomes essential. By conducting several forward passes on sample data, the FP8 scaling parameters can be computed. This calibration allows the model to operate correctly in FP8 precision.

It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.


<figure align="center">
<img src="./media/calibration.svg" alt="">
<figcaption>
Figure 3:
If the model is trained in BF16/FP32, it does not include the computed FP8 scaling factors. When it is run under <b>fp8_autocast()</b>, the value of these scaling factors will default to their initial values, which can cause numerical errors. Weight calibration involves calculating FP8 scaling factors from higher precision forward passes. Once these factors are computed, the model becomes numerically stable. 
</figcaption>
</figure>

###### **4. FP8 Model Weights.**

The typical approach is to store weights in higher precision and then cast them to FP8 before operations. This may prevent accuracy drops in training. However, for inference, this level of precision is not necessary.

Transformer Engine includes a context manager, `fp8_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast from higher precision to FP8, saving the time spent on this cast.
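As a rough back-of-the-envelope sketch, the memory needed for the weights alone scales with the number of bytes per parameter: 2 for BF16 and 1 for FP8. The parameter count below is a hypothetical round number, not the exact size of the Gemma model, and the estimate ignores scaling factors, activations and other buffers.

```python
def weight_memory_gb(num_params, bytes_per_param):
    """Memory taken by the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


num_params = 7_000_000_000  # hypothetical parameter count, for illustration only

bf16_gb = weight_memory_gb(num_params, bytes_per_param=2)  # BF16 weights
fp8_gb = weight_memory_gb(num_params, bytes_per_param=1)   # FP8-only weights
print(bf16_gb, fp8_gb)
```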

<figure align="center">
<img src="./media/fp8_model_init.svg" alt="">
<figcaption>
Figure 4: A model under <b>fp8_autocast()</b> stores weights in high precision by default and casts them when needed. This can lead to a slowdown and increased memory usage. Using <i>fp8_model_init()</i> results in storing the weights in FP8.
</figcaption>
</figure>

###### Benchmarking

We'll evaluate the generation time on one benchmark: generation with a maximum context-phase sequence length of 128, a batch size of 64 and 896 generated tokens, on random texts with random lengths.
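The benchmark parameters determine the buffer sizes used later in the tutorial. A small sketch of the arithmetic, with illustrative variable names:

```python
batch_size = 64
max_context_len = 128        # maximum length of the context (prompt) phase
num_generated_tokens = 896   # tokens generated per sequence

# Longest sequence the model ever sees: full context plus all generated tokens.
max_total_seq_len = max_context_len + num_generated_tokens

# Number of tokens generated across the whole batch.
total_generated_tokens = batch_size * num_generated_tokens

print(max_total_seq_len, total_generated_tokens)
```

The resulting maximum length of 1024 matches the `cuda_graphs_static_max_seq_len` value set in the CUDA Graphs cells of this tutorial.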

<div class="alert alert-info">
<b>Note</b>
    
This tutorial focuses on showcasing the mentioned features of Transformer Engine in the context of token generation. It's important to note, however, that NVIDIA provides [TensorRT](https://developer.nvidia.com/tensorrt), which is optimized for inference tasks and should be considered for such use cases.
</div>

## Dependencies for this tutorial

The following files and media are necessary to run this tutorial:

1. `te_gemma.py`
    - This file contains the code to load a Hugging Face Gemma checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. It also contains the code for generation with THD attention, CUDA Graphs and weight calibration.
2. `te_gemma_loading_weights.py`
    - This file contains the logic for mapping the parameters from `GemmaDecoderLayer` to `TransformerLayer`.
3. `utils.py`
    - This file contains the code related to data loading, hyperparameters, setting up the model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the Jupyter notebook from within a cell.
4. `requirements.txt`
    - This file contains necessary Python packages for this tutorial.
5. `media/`
    - This directory contains the images used in the following tutorial.

In [1]:
%pip install -r requirements.txt

import torch
cudnn_version = torch.backends.cudnn.version()
assert cudnn_version >= 90100, "cuDNN version >= 9.1.0 is needed to run this tutorial."

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting transformers==4.41.1 (from -r requirements.txt (line 1))
  Downloading transformers-4.41.1-py3-none-any.whl.metadata (43 kB)
Collecting accelerate==0.30.1 (from -r requirements.txt (line 2))
  Downloading accelerate-0.30.1-py3-none-any.whl.metadata (18 kB)
Collecting datasets==2.19.1 (from -r requirements.txt (line 3))
  Downloading datasets-2.19.1-py3-none-any.whl.metadata (19 kB)
Collecting sentencepiece==0.2.0 (from -r requirements.txt (line 4))
  Downloading sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)
Collecting huggingface-hub<1.0,>=0.23.0 (from transformers==4.41.1->-r requirements.txt (line 1))
  Downloading huggingface_hub-0.26.2-py3-none-any.whl.metadata (13 kB)
Collecting tokenizers<0.20,>=0.19 (from transformers==4.41.1->-r requirements.txt (line 1))
  Downloading tokenizers-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.


## [Baseline] Running Hugging Face generation with Gemma model

The Hugging Face Transformers library offers a generation API.
Hugging Face generation for the Gemma model will be used as a baseline.

In [2]:
# Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
restart_jupyter_notebook()

from utils import *

# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`
# !!! `model_name` attr must point to the location of the model weights !!!
# Weights can be downloaded from: https://huggingface.co/google/gemma-7b.
# Weights should be in the *.safetensors HF format, not in the original format.
hyperparams.model_name = "/tmp/gemma-7b-hf"  # <== Add model weight location here e.g. "/path/to/downloaded/gemma/weights"

model = init_baseline_model(hyperparams)

print_sample_of_generated_texts(model)
# benchmark_generation(model)

  from .autonotebook import tqdm as notebook_tqdm
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00,  1.02s/it]


Prompt:
Here are the two facts about GPUs:
Generated text:


1. GPUs are very good at doing the same thing over and over again.
2. GPUs are very bad at doing different things at the same time.

The first fact is why GPUs are so good at graphics. The second fact is
Prompt:
Some facts about NVIDIA:
Generated text:


* NVIDIA is a global technology company that designs and develops high-performance computer graphics and computer processing units (CPUs) for the gaming and professional markets.
* The company was founded in 1993 and is headquartered in Santa Clara
Prompt:
Here are the two facts about GPUs:
Generated text:


1. GPUs are very good at doing the same thing over and over again.
2. GPUs are very bad at doing different things at the same time.

The first fact is why GPUs are so good at graphics. The second fact is
Prompt:
Some facts about NVIDIA:
Generated text:


* NVIDIA is a global technology company that designs and develops high-performance computer graphics and computer proce

Let's put this time into the table for later comparison.

| Models                                                      | Time (s) | Speedup |  
|-------------------------------------------------------------|---------------------------------------|--------------------------------------|
| HF (baseline)                                               | 87.68      | 1                         |

## [Improvement 1] Using TransformerLayer from Transformer Engine instead of GemmaDecoderLayer.

As in the [Gemma](./tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb) finetuning tutorial, a GemmaDecoderLayer is substituted by a tuned TransformerLayer from the Transformer Engine. Let's run it and compare the time with the baseline.

In [2]:
# Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
restart_jupyter_notebook()

from utils import *

hyperparams.model_name = "/tmp/gemma-7b-hf" # <== Add model weight location here e.g. "/path/to/downloaded/gemma/weights"

model = init_te_gemma_model(hyperparams)

print_sample_of_generated_texts(model)
# benchmark_generation(model)

  from .autonotebook import tqdm as notebook_tqdm
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in TEGemmaForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GemmaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`



A speedup of **62%** over the baseline was obtained.

| Models                                                      | Time (s) | Speedup |  
|-------------------------------------------------------------|---------------------------------------|--------------------------------------|
| HF (baseline)                                               | 87.68      | 1                         |
| TE (substitution of GemmaDecoderLayer with te.TransformerLayer)                                              | 54.11      | 1.62                         |

## [Improvement 2] Use of THD attention layout.

Input sequences can have various lengths. Hugging Face generation, as can be seen in Animation 1, pads the sequences and then uses an attention mask. In the THD attention layout, cumulative sequence lengths and offsets need to be provided instead of an attention mask. Since padding tokens do not need to be processed, the THD layout is more efficient than the BSHD layout for batches of sequences with varying lengths.

The class `transformer_engine.pytorch.DotProductAttention` supports this format. One needs to pass the following arguments to its forward method:
- `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` – offsets at which the consecutive sequences begin,
- `cu_seqlens_q`, `cu_seqlens_kv` – cumulative sums of the sequence lengths of the query and key-value layers,
- `max_seqlen_q` – maximum sequence length in the query layer,
- `max_seqlen_kv` – maximum sequence length in the key-value layer.
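
To illustrate how the query and key-value arguments differ, consider a single decoding step: every sequence contributes exactly one new query token, while its keys and values cover all tokens seen so far. The sketch below computes the length-related arguments in plain Python. The helper name is hypothetical, real inputs would be tensors rather than lists, and the `seq_offsets_*` arguments are left out because their values depend on how the underlying buffers are laid out.

```python
from itertools import accumulate


def thd_generation_step_args(kv_lens):
    """Length metadata for one decoding step with one new query token per sequence."""
    q_lens = [1] * len(kv_lens)
    return {
        "cu_seqlens_q": [0] + list(accumulate(q_lens)),
        "cu_seqlens_kv": [0] + list(accumulate(kv_lens)),
        "max_seqlen_q": max(q_lens),
        "max_seqlen_kv": max(kv_lens),
    }


# Toy batch of 3 sequences whose current lengths are 3, 5 and 2 tokens.
args = thd_generation_step_args([3, 5, 2])
for name, value in args.items():
    print(name, value)
```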

<div class="alert alert-info">
<b>Note</b>

Currently, the THD attention for `TransformerLayer` is supported only for token generation.
</div>

Let's look at how using Transformer Engine with THD attention impacts the speed of token generation:

In [1]:
# Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
restart_jupyter_notebook()

from utils import *

hyperparams.model_name = "/tmp/gemma-7b-hf/" # <== Add model weight location here e.g. "/path/to/downloaded/gemma/weights"
hyperparams.qkv_format = "thd"

model = init_te_gemma_model(hyperparams)

print_sample_of_generated_texts(model)
# benchmark_generation(model)

  from .autonotebook import tqdm as notebook_tqdm
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in TEGemmaForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GemmaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`


Prompt:
Here are the two facts about GPUs:
Generated text:


1. They are very good at doing the same thing over and over again.
2. They are very bad at doing different things at the same time.

This is why they are so good at rendering 3D graphics.

The GPU
Prompt:
Some facts about NVIDIA:
Generated text:


* NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and professional markets.
* NVIDIA was founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem.
Prompt:
Here are the two facts about GPUs:
Generated text:


1. They are very good at doing the same thing over and over again.
2. They are very bad at doing different things at the same time.

This is why they are so good at rendering 3D graphics.

The GPU
Prompt:
Some facts about NVIDIA:
Generated text:


* NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and professional markets.
* NVIDIA was founde

By using THD attention, the following speedup was obtained:

| Models                                                      | Time (s) | Speedup |  
|-------------------------------------------------------------|---------------------------------------|--------------------------------------|
| HF (baseline)                                               | 87.68      | 1                         |
| TE (substitution of GemmaDecoderLayer with te.TransformerLayer)                                              | 54.11      | 1.62                         |
| TE + THD attention                                               | 28.22      | 3.11                         |  

## [Improvement 3] Speeding up generation with CUDA Graphs

TransformerEngine includes a function `transformer_engine.pytorch.make_graphed_callables`, which functions similarly to the corresponding feature in PyTorch. It is capable of recording any modules from the Transformer Engine. Below is a code excerpt from `te_gemma.py` from class `TEGemmaForCausalLMCudaGraphs`:
```
    def __init__(self, config: GemmaConfig):
        (...)

        # Here "the trick" happens. We override methods from TEGemmaForCausalLM
        # with their recorded versions. After each of them is invoked,
        # the captured graph is replayed with minimal CPU usage,
        # which leads to a large speedup.
        (...)
        self._model_context_phase = self.record_graph(
            self._model_context_phase, self.hidden_states_buffer
        )  # CUDA Graphs recording

        (...)
        self._model_generation_phase = self.record_graph(
            self._model_generation_phase, self.generation_buffer
        )  # CUDA Graphs recording

    @torch.no_grad()
    def record_graph(self, function, input_tensor):
        (...)
        # function is invoked on argument (self.hidden_states,) and all kernels are recorded.
        # record_graph() returns the captured function, which can be run later with minimal use of the CPU.
        fp8_format = Format.HYBRID
        fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
        with autocast(dtype=torch.bfloat16, cache_enabled=False):
            graphed_function = te.pytorch.make_graphed_callables(
                function,
                (input_tensor,),
                fp8_enabled=True,
                fp8_recipe=fp8_recipe,
                allow_unused_input=True,
                num_warmup_iters=3,
            )
        return graphed_function
```

It is strongly recommended to review the entire code of the class `TEGemmaForCausalLMCudaGraphs`. Let's now proceed to evaluate the performance improvement offered by CUDA Graphs.

In [1]:
#Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
restart_jupyter_notebook()


from utils import *

hyperparams.model_name = "/tmp/gemma-7b-hf/" # <== Add model weight location here e.g. "/path/to/downloaded/gemma/weights"
hyperparams.qkv_format = "thd"

hyperparams.generation_cuda_graphs = True

# It is necessary to preallocate a static buffer.
# CUDA graphs require static input tensors for every kernel.
# This approach may result in a slight increase in memory consumption;
# however, the substantial speedup achieved makes it worthwhile.
hyperparams.cuda_graphs_static_batch_size = 64
hyperparams.cuda_graphs_static_max_seq_len = 1024
hyperparams.cuda_graphs_static_max_context_len = 128
model = init_te_gemma_model(hyperparams)

print_sample_of_generated_texts(model)
# benchmark_generation(model)

  from .autonotebook import tqdm as notebook_tqdm
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in TEGemmaForCausalLMCudaGraphs is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GemmaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`


Prompt:
Here are the two facts about GPUs:
Generated text:


1. They are very good at doing the same thing over and over again.
2. They are very bad at doing different things at the same time.

This is why they are so good at rendering 3D graphics.

The GPU
Prompt:
Some facts about NVIDIA:
Generated text:


* NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and professional markets.
* NVIDIA was founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem.
Prompt:
Here are the two facts about GPUs:
Generated text:


1. They are very good at doing the same thing over and over again.
2. They are very bad at doing different things at the same time.

This is why they are so good at rendering 3D graphics.

The GPU
Prompt:
Some facts about NVIDIA:
Generated text:


* NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and professional markets.
* NVIDIA was founde

A **5.23x** speedup over the baseline was obtained.
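
The speedup values reported in this tutorial are the baseline time divided by the time of each variant. A small sketch that recomputes them from the times quoted in this tutorial (the dictionary keys and variable names are illustrative):

```python
# Generation times in seconds, as reported in this tutorial.
times_s = {
    "HF (baseline)": 87.68,
    "TE": 54.11,
    "TE + THD attention": 28.22,
    "TE + THD attention + CUDA Graphs": 16.75,
}

baseline = times_s["HF (baseline)"]
speedups = {name: round(baseline / t, 2) for name, t in times_s.items()}

for name, speedup in speedups.items():
    print(f"{name}: {speedup}x")
```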

| Models                                                      | Time (s) | Speedup |  
|-------------------------------------------------------------|---------------------------------------|--------------------------------------|
| HF (baseline)                                               | 87.68      | 1                         |
| TE (substitution of GemmaDecoderLayer with te.TransformerLayer)                                              | 54.11      | 1.62                         |
| TE + THD attention                                               | 28.22      | 3.11                         |  
| TE + THD attention + CUDA Graphs                                             | 16.75      | 5.23                         |  


Let's look at screenshots from the *NVIDIA Nsight Systems* profiler to see where this speedup comes from:

<figure align="center">
<img src="./media/graphs_1.png" width="80%">
<figcaption>
Figure 5: Without CUDA Graphs. One can see that the GPU (blue) is idle for a large portion of the time.
</figcaption>
</figure>

<figure align="center">
<img src="./media/graphs_2.png" width="80%">
<figcaption>
Figure 6: With CUDA Graphs. One can see that the GPU (orange) is fully utilized.
</figcaption>
</figure>

## [Improvement 4] Running generation in FP8 of the model trained in higher precision 

Implementing FP8 generation with the Gemma model is not straightforward, because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing. Running the model at this lower precision without proper scaling could lead to significant errors and incorrect results.

It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the necessity of scaling.


<figure align="center">
<img src="./media/calibration_1_half.svg">
<figcaption>
    Figure 7: The FP8 scaling factors are incorrect, which leads to numerical errors. Weight calibration allows us to compute FP8 metadata during forward passes in higher precision.
</figcaption>
</figure>

### Weight Calibration

To address the issue outlined above, weight calibration will be used. This involves running several forward iterations at BF16 precision within the context `te.fp8_autocast(enabled=False, calibrating=True)`. This setup allows the forward pass to operate at higher precision while simultaneously collecting `amax_history` and other FP8-related parameters, which are essential for computing accurate FP8 scaling factors.
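
To give an intuition for what calibration collects, the following simplified sketch derives a scaling factor from a history of observed absolute maxima, so that the largest recent value maps to the largest value representable in FP8 E4M3 (448). It is only an illustration of the idea behind delayed scaling with the `"max"` amax algorithm; the function name, the `margin` parameter and the sample history are assumptions, and the actual Transformer Engine implementation may differ in its details.

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in the E4M3 format


def compute_scale(amax_history, fp8_max=FP8_E4M3_MAX, margin=0):
    """Return a scale that maps the largest recorded magnitude to fp8_max."""
    amax = max(amax_history)  # the "max" algorithm: take the largest recorded amax
    if amax == 0.0:
        return 1.0  # nothing observed yet, keep a neutral scale
    return fp8_max / amax / (2 ** margin)


# Hypothetical absolute maxima recorded over four calibration forward passes.
amax_history = [0.8, 3.5, 2.1, 1.9]
scale = compute_scale(amax_history)
print(scale)
```

A tensor is multiplied by this scale before being cast to FP8. Without calibration, the scale stays at its initial value, and values may overflow or lose precision.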

The code below outlines the steps to initialize the BF16 model and conduct several forward iterations within the specified context. After these iterations, the model is saved, and these weights will be utilized in subsequent chapters.

In [1]:
#Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
restart_jupyter_notebook()

from utils import *
import transformer_engine.pytorch as te

hyperparams.model_name = "/tmp/gemma-7b-hf/" # <== Add model weight location here e.g. "/path/to/downloaded/gemma/weights"
hyperparams.fuse_qkv_params = True # This is needed by the last improvement.

model = init_te_gemma_model(hyperparams)

# Calibration
with te.fp8_autocast(enabled=False, calibrating=True), \
    torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    model.train()
    run_forward_pass(model, hyperparams, num_iters=512)

# Compute scale_fwd with enabled fp8 autocast
with te.fp8_autocast(enabled=True), \
    torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    run_forward_pass(model, hyperparams, 1)

# Some parameters point to the same tensors, so saving them twice is avoided here.
dict_to_save = {k: v for k, v in model.state_dict().items() \
                if ("_context_phase" not in k and "_generation_phase" not in k)}
torch.save(dict_to_save, 'calibrated_weights.pth') # <== Add path to save calibrated weights.

  from .autonotebook import tqdm as notebook_tqdm
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in TEGemmaForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GemmaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
Repo card metadata block was not f

### Generation in FP8

<figure align="center">
<img src="./media/calibration_2_half.svg">
<figcaption>
    Figure 8: After weight calibration, the FP8 scaling factors are correct, which prevents numerical errors.
</figcaption>
</figure>

Now FP8 inference is ready to be run.


In [1]:
#Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
restart_jupyter_notebook()

from utils import *

hyperparams.model_name = "/tmp/gemma-7b-hf/"   # <== Add model weight location here e.g. "/path/to/downloaded/gemma/weights"
hyperparams.qkv_format = "thd"
hyperparams.fuse_qkv_params = True # This is needed by the last improvement.

hyperparams.fp8 = True
# Calibrated fp8 weights are loaded directly from the file.

hyperparams.fp8_model_weights_filename = "calibrated_weights.pth" # <== Add calibrated weights location here.

hyperparams.generation_cuda_graphs = True
hyperparams.cuda_graphs_static_batch_size = 64
hyperparams.cuda_graphs_static_max_seq_len = 1024
hyperparams.cuda_graphs_static_max_context_len = 128
model = init_te_gemma_model(hyperparams)

print_sample_of_generated_texts(model)
# benchmark_generation(model)

  from .autonotebook import tqdm as notebook_tqdm
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in TEGemmaForCausalLMCudaGraphs is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GemmaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`


Prompt:
Here are the two facts about GPUs:
Generated text:


1. GPUs are very good at doing the same thing over and over again.
2. GPUs are very bad at doing different things at the same time.

This is a very important distinction to make.

The first fact is a good thing
Prompt:
Some facts about NVIDIA:
Generated text:


* NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and professional markets.
* NVIDIA was founded in 1993 and is headquartered in Santa Clara, California.
* NVIDIA's
Prompt:
Here are the two facts about GPUs:
Generated text:


1. GPUs are very good at doing the same thing over and over again.
2. GPUs are very bad at doing different things at the same time.

This is a very important distinction to make.

The first fact is a good thing
Prompt:
Some facts about NVIDIA:
Generated text:


* NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and profess

One can observe that the outputs are coherent; however, the generation time has increased. Why is this the case?


<figure align="center">
<img src="./media/fp8_model_init_1_half.svg">
<figcaption>
    Figure 9: Running the model at higher precision involves only one GEMM operation. However, when the model operates in FP8, it requires not just the low-precision GEMM but also weight casting.
</figcaption>
</figure>

Running the model in FP8 does not imply that all weights are stored in FP8. By default, they are stored in higher precision and are cast to FP8, using saved scaling factors, before operations such as GEMMs.

This approach pays off during training: a single cast serves both the forward and the backward pass, so its cost is amortized. During inference, however, each cast serves only one forward pass, and the casting overhead outweighs the gain from the faster FP8 GEMM. The next section of the tutorial addresses this issue.
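
The cost of casting on every forward pass can be illustrated with a toy, pure-Python model. This is a sketch, not Transformer Engine code: `fake_fp8_cast`, `dot`, `CastEveryForward`, and `CastOnce` are hypothetical names, and the "cast" only applies a scaling factor and clamps to the FP8 E4M3 range (±448) without modelling rounding. It simply counts how many casts each strategy performs over a series of forward passes.

```python
# Toy illustration only; none of these names come from Transformer Engine.
E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3


def fake_fp8_cast(values, scale):
    """Stand-in for an FP8 cast: apply the scaling factor, then clamp.

    Rounding to the FP8 grid is deliberately not modelled.
    """
    return [max(-E4M3_MAX, min(E4M3_MAX, v * scale)) for v in values]


def dot(weights, inputs, scale):
    """Toy 'GEMM': a dot product, followed by undoing the scaling factor."""
    return sum(w * x for w, x in zip(weights, inputs)) / scale


class CastEveryForward:
    """Keeps higher-precision weights and casts them on every forward pass."""

    def __init__(self, weights, scale):
        self.weights = weights
        self.scale = scale
        self.casts = 0

    def forward(self, inputs):
        self.casts += 1
        return dot(fake_fp8_cast(self.weights, self.scale), inputs, self.scale)


class CastOnce:
    """Casts the weights once at construction and keeps only that copy."""

    def __init__(self, weights, scale):
        self.fp8_weights = fake_fp8_cast(weights, scale)
        self.scale = scale
        self.casts = 1

    def forward(self, inputs):
        return dot(self.fp8_weights, inputs, self.scale)


weights, scale, inputs = [0.5, -1.25, 2.0], 64.0, [1.0, 2.0, 3.0]
per_step, once = CastEveryForward(weights, scale), CastOnce(weights, scale)

for _ in range(16):  # e.g. 16 generated tokens
    assert per_step.forward(inputs) == once.forward(inputs)

print(per_step.casts, once.casts)  # prints "16 1"
```

Both variants return the same result, but the first one repeats the cast for every generated token, while the second pays for it only once.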

### Use of only FP8 model weights

By default, Transformer Engine stores parameters in higher precision and casts them to FP8 only when they are needed. Keeping the higher-precision copy may be necessary to maintain accuracy during training. However, it is not needed for inference.

Transformer Engine supports keeping only FP8 weights through the `fp8_model_init` context manager. Let's look at an example:
```
import torch
import transformer_engine.pytorch as te

linear = te.Linear(1024, 1024)  # this module is initialized with full-precision weights
with te.fp8_model_init(enabled=True):
    linear_fp8 = te.Linear(1024, 1024)  # this module is initialized with FP8 weights only

assert type(linear.weight.data) is torch.Tensor
assert type(linear_fp8.weight.data) is te.float8_tensor.Float8Tensor
```

<figure align="center">
<img src="./media/fp8_model_init_2_half.svg">
<figcaption>
    Figure 10: Using fp8_model_init stores the weights directly in FP8 format, which reduces both time and memory usage.
</figcaption>
</figure>
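
As a rough, back-of-the-envelope illustration of the memory saving, the snippet below computes the storage needed for the weight of a single 1024×1024 linear layer at different precisions. It assumes 4, 2, and 1 bytes per element for FP32, BF16, and FP8 respectively, and ignores the small amount of extra metadata (such as scaling factors) that accompanies FP8 weights. `weight_mib` is a hypothetical helper, not part of Transformer Engine.

```python
# Bytes needed to store one element at each precision.
BYTES_PER_ELEMENT = {"fp32": 4, "bf16": 2, "fp8": 1}


def weight_mib(rows, cols, dtype):
    """Size of a rows x cols weight matrix in MiB (metadata ignored)."""
    return rows * cols * BYTES_PER_ELEMENT[dtype] / 2**20


for dtype in BYTES_PER_ELEMENT:
    print(f"{dtype}: {weight_mib(1024, 1024, dtype):.1f} MiB")
# fp32: 4.0 MiB
# bf16: 2.0 MiB
# fp8: 1.0 MiB
```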

Let's run the code with `fp8_model_init`:

In [1]:
# Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
restart_jupyter_notebook()

# Import necessary packages and methods
from utils import *

hyperparams.model_name = "/tmp/gemma-7b-hf/" # <== Add model weight location here e.g. "/path/to/downloaded/gemma/weights"
hyperparams.fuse_qkv_params = True # Needed for fp8_model_init().
hyperparams.qkv_format = "thd"

hyperparams.fp8 = True
hyperparams.fp8_model_init = True # This will result in storing only fp8 weights.
hyperparams.fp8_model_weights_filename = "calibrated_weights.pth" # <== Add calibrated weights location here.

hyperparams.generation_cuda_graphs = True
hyperparams.cuda_graphs_static_batch_size = 64
hyperparams.cuda_graphs_static_max_seq_len = 1024
hyperparams.cuda_graphs_static_max_context_len = 128
model = init_te_gemma_model(hyperparams)

print_sample_of_generated_texts(model)
# benchmark_generation(model)



Prompt:
Here are the two facts about GPUs:
Generated text:


1. GPUs are very good at doing the same thing over and over again.
2. GPUs are very bad at doing different things at the same time.

This is a very important distinction to make.

The first fact is a good thing
Prompt:
Some facts about NVIDIA:
Generated text:


* NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and professional markets.
* NVIDIA was founded in 1993 and is headquartered in Santa Clara, California.
* NVIDIA's
Prompt:
Here are the two facts about GPUs:
Generated text:


1. GPUs are very good at doing the same thing over and over again.
2. GPUs are very bad at doing different things at the same time.

This is a very important distinction to make.

The first fact is a good thing
Prompt:
Some facts about NVIDIA:
Generated text:


* NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and profess

| Models                                                          | Time (s) | Speedup |
|-----------------------------------------------------------------|----------|---------|
| HF (baseline)                                                   | 87.68    | 1       |
| TE (substitution of GemmaDecoderLayer with te.TransformerLayer) | 54.11    | 1.62    |
| TE + THD attention                                              | 28.22    | 3.11    |
| TE + THD attention + CUDA Graphs                                | 16.75    | 5.23    |
| TE + THD attention + FP8                                        | 12.13    | 7.23    |

The final speedup is **7.23x**.
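
Each speedup is the baseline time divided by the time of the given configuration. A short check using the times reported in the table (the variable names and shortened row labels are illustrative):

```python
# Generation times in seconds, copied from the table above.
times_s = {
    "HF (baseline)": 87.68,
    "TE": 54.11,
    "TE + THD attention": 28.22,
    "TE + THD attention + CUDA Graphs": 16.75,
    "TE + THD attention + FP8": 12.13,
}

baseline = times_s["HF (baseline)"]
# Speedup relative to the Hugging Face baseline, rounded to two decimals.
speedups = {name: round(baseline / t, 2) for name, t in times_s.items()}

for name, speedup in speedups.items():
    print(f"{name}: {speedup:.2f}x")
```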

## Conclusions


<figure align="center">
<img src="./media/plot.svg">
<figcaption>
    Figure 11: Times obtained with optimizations using Transformer Engine (seconds).
</figcaption>
</figure>

In this tutorial, we've explored four features of the Transformer Engine:
1. Support for the THD attention layout,
2. Integration with CUDA Graphs,
3. FP8 weight calibration,
4. Models that store only the FP8 version of their parameters.

Each of these features can be applied in various contexts, such as fast token generation. Note that the fastest possible inference speeds can be achieved using NVIDIA's inference-optimized [TensorRT](https://developer.nvidia.com/tensorrt) library.