# Speeding up Hugging Face Gemma model generation with CUDA Graphs, THD attention, and FP8 precision

As shown in the [tutorial for Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) and the [tutorial for Gemma](./tutorial_accelerate_hf_gemma_with_te.ipynb), transformer models can be accelerated by using Transformer Engine's `TransformerLayer`. In this tutorial we present a few more advanced features, namely:
1. THD attention layout.
2. FP8 weight calibration, for running inference in FP8 precision with models that were trained in higher precision.
3. CUDA Graphs API.

We will compare generation time on two benchmarks:
- long input sequences (max 256 tokens), short generation (max 128 tokens),
- short input sequences (max 64 tokens), long generation (max 1000 tokens).

All benchmarks run with batch size 64 on the "timdettmers/openassistant-guanaco" dataset.

<div class="alert alert-info">

<b>Note</b>
    
This tutorial demonstrates the Transformer Engine features mentioned above, using generation as the example. Note, however, that NVIDIA offers another library for inference, [TensorRT](https://developer.nvidia.com/tensorrt), which should be used in such cases.

</div>




## Dependencies for this tutorial

The following files and media are necessary to run this tutorial:

1. `te_gemma.py`
    - This file contains the code to load a Hugging Face Gemma checkpoint into Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. It also contains the code for generation with THD attention and for weight calibration.
2. `utils.py`
    - This file contains the code related to data loading, hyperparameters, setting up the model/optimizers/accelerator, model training, and other miscellaneous tasks such as restarting the Jupyter notebook from within a cell.
3. `media/`
    - This directory contains the images used in the following tutorial.

## Table of contents

1. [Baseline] Running Hugging Face generation with Gemma model
2. [Improvement 1] Speeding up generation by using Transformer Engine THD attention
3. [Improvement 2] Running generation of the model trained in high precision in FP8
4. [Improvement 3] Speeding up generation with CUDA Graphs
5. Conclusions

## [Baseline] Running Hugging Face generation with Gemma model

The Hugging Face Transformers library offers a generation API. We will treat it as our baseline.

In [None]:
# Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
#restart_jupyter_notebook()

# Import necessary packages and methods
import time
from utils import *


# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`
## !!! `model_name` attr must point to the location of the model weights !!!
## Weights can be downloaded from: https://llama.meta.com/llama-downloads/
hyperparams.model_name = "../../../../gemma-weights"  # <== Add model weight location here e.g. "/path/to/downloaded/llama/weights"
hyperparams.mixed_precision = "bf16"


# Init the model and accelerator wrapper
model = init_baseline_model(hyperparams).cuda()
model = model.to(torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)
inputs = tokenizer(["Some random initial str ", "Another string ... "] * 32, return_tensors="pt", padding=True)

inputs['input_ids'] = inputs['input_ids'].cuda()
inputs['attention_mask'] = inputs['attention_mask'].cuda()

start_time = time.time()

outputs = model.generate(
    **inputs,
    max_new_tokens=1000
)

end_time = time.time()
duration = end_time - start_time

print(duration)

# Decode the output tensor to text
generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Display the generated text
for text in generated_texts:
    print(text)
    print("=" * 100)

We will put these times into the table for later comparison.

| Models                                                      | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 |  
|-------------------------------------------------------------|---------------------------------------|--------------------------------------|
| HF (baseline)                                               | -      | -                         |  
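
When filling in the table, it helps to time every variant the same way. The helper below is a minimal sketch and is not part of `utils.py`; the names `timed` and `speedup` are hypothetical. GPU kernels run asynchronously, so when timing GPU work one would typically pass `torch.cuda.synchronize` as `sync`. Here a plain CPU function stands in so that the snippet runs anywhere.

```python
import time


def timed(fn, *args, sync=None, **kwargs):
    """Run fn and return (result, elapsed_seconds).

    `sync` is an optional zero-argument callable that is invoked before each
    clock read, e.g. torch.cuda.synchronize when timing GPU work.
    """
    if sync is not None:
        sync()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if sync is not None:
        sync()
    return result, time.perf_counter() - start


def speedup(baseline_seconds, new_seconds):
    """How many times faster the new run is compared to the baseline."""
    return baseline_seconds / new_seconds


# Stand-in workload; in the notebook this would be a call to model.generate.
result, elapsed = timed(sum, range(1000))
print(f"Generation time: {elapsed:.6f} seconds")
```

With two measured durations, `speedup(baseline, improved)` gives the factor to put into the corresponding table cell.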

## [Improvement 1] Speeding up generation by using Transformer Engine THD attention

Similarly to the Gemma tutorial, we substitute `GemmaDecoderLayer` with `TransformerLayer` from Transformer Engine. Since the initial sequences have different lengths, we have the following choices:
1. Use padding from the beginning and then use standard attention with `"bshd"` or `"sbhd"` layout.
2. Do not pad from the beginning and use THD attention.

In this tutorial we show the second option. The two pictures below illustrate the idea behind THD attention.

<center>
<img src="./media/pic1.png" alt="THD attention illustration (1 of 2)" width="200" height="200">
<img src="./media/pic2.png" alt="THD attention illustration (2 of 2)" width="200" height="200">
</center>
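
As a rough illustration of the second option, the sketch below packs sequences of different lengths into one flat token dimension (the `t` in THD) and records where each sequence starts and ends. The helper names and the token ids are made up for this example, and `cu_seqlens` is only the conventional name for cumulative sequence lengths; the actual implementation in `te_gemma.py` may organize this differently.

```python
from itertools import accumulate


def pack_thd(sequences):
    """Pack variable-length token sequences into one flat token list.

    Returns the packed tokens and the cumulative sequence lengths, where
    sequence i occupies packed[cu_seqlens[i]:cu_seqlens[i + 1]].
    """
    cu_seqlens = [0] + list(accumulate(len(seq) for seq in sequences))
    packed = [tok for seq in sequences for tok in seq]
    return packed, cu_seqlens


def padded_token_count(sequences):
    """Number of token slots a padded batch of the same sequences needs."""
    return len(sequences) * max(len(seq) for seq in sequences)


# Hypothetical token ids for three prompts of different lengths.
batch = [[11, 12, 13, 14], [21], [31, 32]]
packed, cu_seqlens = pack_thd(batch)

print(packed)      # [11, 12, 13, 14, 21, 31, 32]
print(cu_seqlens)  # [0, 4, 5, 7]
print(padded_token_count(batch), "padded slots vs", len(packed), "packed tokens")
```

The packed form holds only real tokens, whereas the padded form also spends slots on padding for every sequence shorter than the longest one.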



In [None]:
# Import necessary packages and methods
from utils import *

# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`
## !!! `model_name` attr must point to the location of the model weights !!!
## Weights can be downloaded from: https://llama.meta.com/llama-downloads/
hyperparams.model_name = "../../../../gemma-weights"  # <== Add model weight location here e.g. "/path/to/downloaded/llama/weights"
hyperparams.mixed_precision = "bf16"
hyperparams.fuse_qkv_params = False

# Init the model and accelerator wrapper
model = init_te_gemma_model(hyperparams).cuda()
#accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)

model = model.to(torch.bfloat16).cuda()

tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)
inputs = tokenizer(["I love when ", "I "] * 32, return_tensors="pt", padding=True)

inputs['input_ids'] = inputs['input_ids'].cuda()
inputs['attention_mask'] = inputs['attention_mask'].cuda()

import time

# Start timing
start_time = time.time()

outputs = model.generate(
    **inputs,
    max_new_tokens=40
)

# Stop timing
end_time = time.time()

# Compute the duration of the operation
duration = end_time - start_time
print(f"Generation time: {duration} seconds")


# Decode the output tensor to text
generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Display the generated text
for text in generated_texts:
    print(text)
    print("=" * 100)

By using THD attention we obtained the following speedups:

| Models                                                      | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 |  
|-------------------------------------------------------------|---------------------------------------|--------------------------------------|
| HF (baseline)                                               | -      | -                         |
| THD attention with TE                                               | -      | -                         |  

## [Improvement 2] Running generation of the model trained in high precision in FP8

Now we want to run FP8 generation with the Gemma model. But it's not that simple! Since the model was trained in BF16 precision, the FP8 scaling factors have not been computed. Running the model at such low precision without proper scaling leads to significant numerical error and, in turn, to wrong output.

##### Weight calibration

Weight calibration solves the problem described above. We run a few forward iterations in BF16 precision within the context `te.fp8_autocast(enabled=False, calibrating=True)`. This means that the forward pass is done in higher precision, but `amax_history` is recorded and later used to compute the FP8 scaling factors.

In the code below, we initialize the BF16 model and run a few forward passes within that context. Then we save the model; we will also use these weights in the next chapter.
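
To make the role of `amax_history` more concrete, here is a simplified sketch of how a scaling factor can be derived from recorded maxima. It assumes a "max" reduction over the history and the E4M3 format (largest finite value 448), ignores rounding, and uses made-up calibration values. The function names are hypothetical and this is not Transformer Engine's actual code.

```python
E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def compute_scale(amax_history, fp8_max=E4M3_MAX, margin=0):
    """Derive a scaling factor from recorded absolute maxima.

    Takes the maximum over the history window and maps it onto the top of
    the FP8 range; `margin` optionally leaves headroom as a power of two.
    """
    amax = max(amax_history)
    if amax == 0.0:
        return 1.0  # nothing observed yet, so keep values unscaled
    return fp8_max / amax / (2 ** margin)


def saturate(x, fp8_max=E4M3_MAX):
    """Clamp a value to the representable FP8 range (rounding is ignored)."""
    return max(-fp8_max, min(fp8_max, x))


amax_history = [3.0, 1200.0, 950.0]  # made-up maxima seen during calibration
scale = compute_scale(amax_history)

x = 1200.0
unscaled = saturate(x)                  # clipped to the top of the FP8 range
rescaled = saturate(x * scale) / scale  # fits after scaling, then is restored

print(f"scale={scale:.4f}, unscaled={unscaled}, rescaled={rescaled:.1f}")
```

Without a calibrated scale, the large value is clipped; with it, the value fits into the FP8 range and can be recovered by dividing by the same scale.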

In [None]:
# Import necessary packages and methods
import transformer_engine.pytorch as te
from utils import *
import accelerate
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.common.recipe import Format, DelayedScaling
import torch


hyperparams.model_name = "../../../../gemma-weights"
hyperparams.fuse_qkv_params = True
model = init_te_gemma_model(hyperparams, fp8_model_init=False).cuda()
model = model.to(torch.bfloat16)


accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=hyperparams.gradient_accumulation_steps,
        mixed_precision=hyperparams.mixed_precision
    )
train_dataloader = get_dataloaders(accelerator, hyperparams)

tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)

print("Calibration started")
with te.fp8_autocast(enabled=False, calibrating=True):
    model.train()
    train_dataloader = enumerate(train_dataloader)

    for i in range(100):
        step, batch = next(train_dataloader)
        batch["input_ids"] = batch["input_ids"].cuda()
        outputs = model.generate(
            **batch,
            max_new_tokens=10
        )
        generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        print(generated_texts[0][:50])
print("calibration_finished")

print("scale_fwd computation started")
with te.fp8_autocast(enabled=True):
    for i in range(10):
        step, batch = next(train_dataloader)
        batch["input_ids"] = batch["input_ids"].cuda()
        outputs = model.generate(
            **batch,
            max_new_tokens=1
        )
print("scale_fwd_computation ended")

print("Casting weights...")
model_fp8 = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda()
model_fp8.load_state_dict(model.state_dict())
print("Weights casted")


print("Saving model...")
torch.save(model_fp8.state_dict(), 'model_fp8_state_dict.pth')
print("Model saved!")

Now we are ready to run FP8 inference.

In [None]:
#Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
#restart_jupyter_notebook()
import transformer_engine.pytorch as te

import os
from torch.cuda.amp import autocast


# Import necessary packages and methods
from utils import *

from transformer_engine.pytorch import fp8_model_init
from transformer_engine.common.recipe import Format, DelayedScaling


hyperparams.model_name = "../../../../gemma-weights"
hyperparams.fuse_qkv_params = True
model = init_te_gemma_model(hyperparams, fp8_model_init=True, qkv_format="thd").cuda()

print("Loading model")
model_state_dict = torch.load('model_fp8_state_dict.pth')
model.load_state_dict(model_state_dict)
print("Model loaded")

tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)
inputs = tokenizer(["Some random initial str ", "Another string ... "] * 32, return_tensors="pt", padding=True)

inputs['input_ids'] = inputs['input_ids'].cuda()
inputs['attention_mask'] = inputs['attention_mask'].cuda()

import time



start_time = time.time()

fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
torch.manual_seed(1234)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    with autocast(dtype=torch.bfloat16, cache_enabled=False):
        with torch.no_grad():
            model.eval()
            outputs = model.generate(
                **inputs,
                max_new_tokens=40,
                use_cuda_graphs=False
            )


end_time = time.time()
duration = end_time - start_time

generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
for text in generated_texts[:12]:
    print("-" * 50)
    print(text)

print(f"Duration = {duration}")


We add the speedups to the table:

| Models                                                      | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 |  
|-------------------------------------------------------------|---------------------------------------|--------------------------------------|
| HF (baseline)                                               | -      | -                         |
| THD attention with TE                                               | -      | -                         | 
| THD attention + FP8 with TE                                               | -      | -                         |  

## [Improvement 3] Speeding up generation with CUDA Graphs

The inference code runs on the CPU, which launches the GPU kernels. When the GPU kernels are fast enough, the CPU launch overhead can become the bottleneck. That's where CUDA Graphs come in. When a series of kernel launches is repeatable, it can be recorded once and then replayed without involving the CPU in every launch. You can read more about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).

PyTorch supports CUDA Graphs through the `torch.cuda` API. Nevertheless, a sequence of tensor operations must meet some requirements to be captured and replayed correctly. Namely, all the operations need to be static, meaning that tensors should not "move" between iterations. PyTorch also offers a simpler way of using CUDA Graphs: the function `torch.cuda.make_graphed_callables`, which makes it easy to record a PyTorch module.

Transformer Engine, from version 1.6, also provides a `make_graphed_callables` API. In the code below we run the `generate` method from `te_gemma.py`. This is the excerpt responsible for building the graphed part:

```
graphed_generator = TeGraphed(...)
(...)
    if use_cuda_graphs:
        fp8_format = Format.HYBRID
        fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
        graphed_layers = te.pytorch.make_graphed_callables(
                graphed_generator, 
                args, 
                fp8_enabled=True, 
                fp8_recipe=fp8_recipe, 
                allow_unused_input=True,
                num_warmup_iters=10
            )
            
    for i in range(max_new_tokens):
        next_tokens = graphed_layers(*args) if use_cuda_graphs else graphed_generator(*args)
        output_tokens.append(next_tokens.clone())
```
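
The `.clone()` in the loop above hints at the static-buffer requirement: a replayed graph always reads from and writes to the same memory, so results must be copied out before the next replay, and new inputs must be copied in place. The following pure-Python toy mimics that behavior with lists. `ReplayableStep` is a hypothetical stand-in for illustration only and does not use any CUDA API.

```python
class ReplayableStep:
    """Toy stand-in for a captured graph.

    It remembers the buffer objects, not their values, and every replay
    reads from and writes to those same buffers.
    """

    def __init__(self, fn, static_input, static_output):
        self.fn = fn
        self.static_input = static_input
        self.static_output = static_output

    def replay(self):
        # Write in place so that the output buffer keeps its identity.
        self.static_output[:] = self.fn(self.static_input)


static_tokens = [0, 0]  # fixed input buffer
static_next = [0, 0]    # fixed output buffer
step = ReplayableStep(lambda toks: [t + 1 for t in toks], static_tokens, static_next)

generated = []
static_tokens[:] = [5, 9]  # copy new data INTO the static input buffer
for _ in range(3):
    step.replay()
    generated.append(list(static_next))  # copy out: the buffer is overwritten next step
    static_tokens[:] = static_next       # feed the output back in place

print(generated)  # [[6, 10], [7, 11], [8, 12]]
```

Appending `static_next` itself instead of a copy would leave three references to the same list, all showing only the last step's values, which is the toy equivalent of omitting `.clone()`.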

Now, let's see how big the speedup is going to be.

In [None]:
import os

# Optional cuDNN debug logging; leave disabled when measuring generation time.
# os.environ['CUDNN_LOGLEVEL_DBG'] = '3'
# os.environ['CUDNN_LOGDEST_DBG'] = 'backlog.txt'
#Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
#restart_jupyter_notebook()
import transformer_engine.pytorch as te

from torch.cuda.amp import autocast


# Import necessary packages and methods
from utils import *

from transformer_engine.pytorch import fp8_model_init
from transformer_engine.common.recipe import Format, DelayedScaling


hyperparams.model_name = "../../../../gemma-weights"
hyperparams.fuse_qkv_params = True
model = init_te_gemma_model(hyperparams, fp8_model_init=True, qkv_format="thd").cuda()

print("Loading model")
model_state_dict = torch.load('model_fp8_state_dict.pth')
model.load_state_dict(model_state_dict)
print("Model loaded")

tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)
inputs = tokenizer(["Some random initial str ", "Another string ... "] * 32, return_tensors="pt", padding=True)

inputs['input_ids'] = inputs['input_ids'].cuda()
inputs['attention_mask'] = inputs['attention_mask'].cuda()

import time

start_time = time.time()

fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
torch.manual_seed(1234)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    with autocast(dtype=torch.bfloat16, cache_enabled=False):
        with torch.no_grad():
            model.eval()
            outputs = model.generate(
                **inputs,
                max_new_tokens=10,
                use_cuda_graphs=True
            )

end_time = time.time()
duration = end_time - start_time

generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
for text in generated_texts[:12]:
    print("-" * 50)
    print(text)

print(f"Duration = {duration}")


We finally obtained the **??%** speedup.

| Models                                                      | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 |  
|-------------------------------------------------------------|---------------------------------------|--------------------------------------|
| HF (baseline)                                               | -      | -                         |
| THD attention with TE                                               | -      | -                         | 
| THD attention + FP8 with TE                                               | -      | -                         |  
| THD attention + FP8 + Cuda Graphs with TE                                               | -      | -                         |  

## Conclusions

In this tutorial we showed three features of Transformer Engine:
1. Support for the THD attention layout.
2. FP8 weight calibration.
3. Support for CUDA Graphs.

Each of them can be used in different contexts; here we showed how to use them to obtain fast inference. Remember, though, that this is not the fastest possible way of doing inference. For that, we recommend looking at the [TensorRT](https://developer.nvidia.com/tensorrt) library from NVIDIA.