# Accelerating a Hugging Face Gemma model with Transformer Engine

In the previous [tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), we demonstrated how to accelerate HF Llama models using the Transformer Engine library. We replaced `LlamaDecoderLayer` with `TransformerLayer` from Transformer Engine, achieving a **25%** speedup. We then ran the finetuning in FP8 precision, which raised the total speedup over the baseline to **39%**.

Now, we will apply a similar enhancement to Google's [Gemma](https://blog.google/technology/developers/gemma-open-models/) model.

# Dependencies for this tutorial

The following files and media are needed to run this tutorial:

1. `te_gemma.py`
    - This file contains the code to load a Hugging Face Gemma checkpoint into Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. It is used in two sections of this tutorial: "Improvement 1" and "Improvement 2".
2. `utils.py`
    - This file contains the code for data loading, hyperparameters, setting up the model/optimizer/accelerator, model training, and miscellaneous tasks such as restarting the Jupyter notebook from within a cell.
3. `media/`
    - This directory contains the images used in this tutorial.
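`te_gemma.py` itself is not reproduced here. One common way to make a Hugging Face model build different decoder layers is to temporarily swap the decoder-layer class in its modeling module while the model is being constructed. The sketch below shows that pattern with placeholder classes only; `replace_decoder`, `TEGemmaDecoderLayer`, `build_layers`, and the `SimpleNamespace` stand-in for `transformers.models.gemma.modeling_gemma` are hypothetical, and `te_gemma.py` may do this differently.

```python
from contextlib import contextmanager
from types import SimpleNamespace

# Hypothetical stand-in for the `transformers.models.gemma.modeling_gemma` module
modeling_gemma = SimpleNamespace()


class GemmaDecoderLayer:
    """Placeholder for Hugging Face's decoder layer."""


class TEGemmaDecoderLayer:
    """Placeholder for a decoder layer built on Transformer Engine's TransformerLayer."""


modeling_gemma.GemmaDecoderLayer = GemmaDecoderLayer


@contextmanager
def replace_decoder(module, te_decoder_cls):
    """Temporarily swap the decoder-layer class that the model constructor looks up."""
    original_cls = module.GemmaDecoderLayer
    module.GemmaDecoderLayer = te_decoder_cls
    try:
        yield
    finally:
        # Restore the original class even if model construction raises
        module.GemmaDecoderLayer = original_cls


def build_layers(module, num_layers):
    # Mimics a model __init__ that instantiates layers through the module-level name
    return [module.GemmaDecoderLayer() for _ in range(num_layers)]


with replace_decoder(modeling_gemma, TEGemmaDecoderLayer):
    layers = build_layers(modeling_gemma, 2)

print([type(layer).__name__ for layer in layers])
```

The swap only takes effect if the constructor resolves the class by name at construction time; restoring it in `finally` keeps any model built afterwards unaffected.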

# Differences between Llama and Gemma

Llama and Gemma are very similar models: both are based on the Transformer decoder architecture. The most important architectural differences between them (comparing Llama 2 7B with Gemma 7B) are the following:


| Feature                                      | Llama                              | Gemma                                      |
|----------------------------------------------|------------------------------------|--------------------------------------------|
| **Norm Layer**                               | Standard RMSNorm <br> $ y = \frac{x}{\sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2 + \varepsilon}} \cdot \gamma $ | RMSNorm with zero-centered gamma parameter <br> $ y = \frac{x}{\sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2 + \varepsilon}} \cdot (\textcolor{red}{1 +} \gamma) $ |
| **Hidden Size / Attention Size** (heads × head dim) | 4096 / 4096 (32 × 128)      | 3072 / 4096 (16 × 256)                     |
| **Activation Function**                      | SwiGLU                             | GeGLU                                      |

In the norm formulas, $n$ is the hidden size and $\varepsilon$ is a small constant for numerical stability.
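To make the norm and activation rows concrete, here is a minimal NumPy sketch of both RMSNorm variants and of the two gated activations. It is an illustration, not the Hugging Face or Transformer Engine code: it assumes the tanh approximation of GELU for GeGLU and omits the MLP's down-projection. It shows that Gemma's norm with a zero-centered $\gamma$ equals a standard RMSNorm whose weight is $1 + \gamma$, which a replacement layer has to account for when loading Gemma weights (Transformer Engine's normalization modules expose a `zero_centered_gamma` option for this convention).

```python
import numpy as np


def rms_norm(x, gamma, eps=1e-6):
    # Standard RMSNorm (Llama): divide by the root mean square over the last axis, scale by gamma
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * gamma


def gemma_rms_norm(x, gamma, eps=1e-6):
    # Gemma variant: gamma is stored zero-centered and applied as (1 + gamma)
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * (1.0 + gamma)


def gelu_tanh(x):
    # Tanh approximation of GELU (assumed here for Gemma's GeGLU)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def silu(x):
    # SiLU / Swish, the gate nonlinearity used in SwiGLU
    return x / (1.0 + np.exp(-x))


def swiglu(x, w_gate, w_up):
    # Llama-style gated MLP: SiLU(x W_gate) * (x W_up)
    return silu(x @ w_gate) * (x @ w_up)


def geglu(x, w_gate, w_up):
    # Gemma-style gated MLP: GELU(x W_gate) * (x W_up)
    return gelu_tanh(x @ w_gate) * (x @ w_up)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))
gamma = 0.1 * rng.standard_normal(8)  # a zero-centered gamma, as Gemma stores it

# Gemma's norm matches a standard RMSNorm whose weight is (1 + gamma)
print(np.allclose(gemma_rms_norm(x, gamma), rms_norm(x, 1.0 + gamma)))  # True
```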


# [Baseline] Running HF `GemmaModel` (Precision: `BF16`)

As in the Llama tutorial, we begin the experiments by running baseline training in BF16 precision.

<div class="alert alert-info">

<b>Note</b>
    
This tutorial loads and trains a Gemma 7B model, which takes up most of the GPU memory, so we need to restart the Jupyter notebook before running each of the following sections. A small utility method, `restart_jupyter_notebook`, is defined in the accompanying `utils.py` file. It restarts the notebook so that GPU memory is flushed before the model is loaded again from the checkpoint, which avoids OOM (out-of-memory) errors.

If the utility doesn't work, comment out the `restart_jupyter_notebook()` line in the following cell and restart the notebook manually before running the cell. Do the same for the other sections of this tutorial.

</div>


```python
# Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
restart_jupyter_notebook()

# Import necessary packages and methods
from utils import *

# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`
## !!! `model_name` attr must point to the location of the model weights !!!
## Gemma weights can be downloaded from the Hugging Face Hub (e.g. `google/gemma-7b`)
hyperparams.model_name = "../../../../gemma-weights"  # <== Add model weight location here, e.g. "/path/to/downloaded/gemma/weights"
hyperparams.mixed_precision = "bf16"

# Init the model and accelerator wrapper
model = init_baseline_model(hyperparams).cuda()
accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)

# Finetune the model
finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)
```

Let's record the step time in a table and compare it against the improvements in the following sections:

| Models                                                      | Precision | Step Time (or ms per batch) | Speedup (over baseline) |
|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|
| HF (baseline)                                               | BF16      | -                         | 1                       |
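How `finetune_model` in `utils.py` measures step time is not shown in this tutorial. Below is a minimal sketch of one way to measure average milliseconds per batch: time each step with `time.perf_counter`, skip a few warmup steps, and optionally call a synchronization function such as `torch.cuda.synchronize` so that queued GPU work is included in the measurement. `measure_step_time_ms` and `dummy_step` are hypothetical names.

```python
import time
from statistics import mean


def measure_step_time_ms(step_fn, num_steps=20, warmup_steps=5, sync_fn=None):
    """Average wall-clock time per step in milliseconds, excluding warmup steps.

    `sync_fn` is an optional callable (e.g. torch.cuda.synchronize) that blocks until
    queued GPU work finishes; without it, asynchronous kernel launches are under-counted.
    """
    for _ in range(warmup_steps):
        step_fn()
    if sync_fn is not None:
        sync_fn()

    times_ms = []
    for _ in range(num_steps):
        start = time.perf_counter()
        step_fn()
        if sync_fn is not None:
            sync_fn()
        times_ms.append((time.perf_counter() - start) * 1000.0)
    return mean(times_ms)


def dummy_step():
    # Hypothetical stand-in for one forward/backward/optimizer step
    sum(i * i for i in range(10_000))


print(f"{measure_step_time_ms(dummy_step):.3f} ms per batch")
```

Warmup steps are excluded because the first iterations typically include one-time costs (memory allocation, kernel selection) that would inflate the average.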

# [Improvement 1] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)

Now we replace `GemmaDecoderLayer` with Transformer Engine's highly tuned `TransformerLayer`. Let's see how this affects the speed of the model.

```python
# Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
restart_jupyter_notebook()

# Import necessary packages and methods
from utils import *

# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`
## !!! `model_name` attr must point to the location of the model weights !!!
## Gemma weights can be downloaded from the Hugging Face Hub (e.g. `google/gemma-7b`)
hyperparams.model_name = "../../../../gemma-weights"  # <== Add model weight location here, e.g. "/path/to/downloaded/gemma/weights"
hyperparams.mixed_precision = "bf16"

# Init the TE-based model and accelerator wrapper
model = init_te_gemma_model(hyperparams).cuda()
accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)

# Finetune the model
finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)
```

Compared to the "baseline" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Hugging Face's `GemmaDecoderLayer` gives a speedup of **??%** even when using only BF16 precision!

| Models                                                      | Precision | Step Time (or ms per batch) | Speedup (over baseline) |
|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|
| HF (baseline)                                               | BF16      | 315                         | 1                       |
| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16      | 252                         | 1.25                    |
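The *Speedup* column is the baseline step time divided by the variant's step time, and a speedup of $s$ corresponds to $(s - 1) \times 100\%$ faster. A small sketch of that arithmetic, using hypothetical step times rather than measurements from this tutorial:

```python
def speedup(baseline_ms, variant_ms):
    # Speedup over baseline: ratio of baseline step time to variant step time
    return baseline_ms / variant_ms


def percent_faster(baseline_ms, variant_ms):
    # A speedup of s corresponds to (s - 1) * 100 percent faster
    return (speedup(baseline_ms, variant_ms) - 1.0) * 100.0


# Hypothetical step times in ms per batch, not results from this tutorial
baseline_ms, variant_ms = 100.0, 80.0
print(f"{speedup(baseline_ms, variant_ms):.2f}x ({percent_faster(baseline_ms, variant_ms):.0f}% faster)")
# prints "1.25x (25% faster)"
```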

# [Improvement 2] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)

The final improvement is enabling FP8 precision. The cell below is identical to the previous one, except that `hyperparams.mixed_precision` is set to `"fp8"`.
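FP8 formats have a much narrower range than BF16 (the largest finite E4M3 value is 448), so tensors are multiplied by a per-tensor scaling factor before being cast. Transformer Engine's delayed scaling recipe derives that factor from a history of recent absolute maxima (amax) instead of from the current tensor. The NumPy sketch below illustrates the idea under simplifying assumptions: it ignores mantissa rounding, and `delayed_scale` and `scale_and_clip` are hypothetical helpers, not Transformer Engine functions. How `utils.py` turns `mixed_precision = "fp8"` into FP8 execution is not shown here.

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite value in the FP8 E4M3 format


def delayed_scale(amax_history, fp8_max=E4M3_MAX, margin=0):
    # Scale is chosen from the max |x| recorded over recent steps, not the current tensor
    amax = max(amax_history)
    return (fp8_max / amax) / (2 ** margin)


def scale_and_clip(x, scale, fp8_max=E4M3_MAX):
    # Map values into the FP8 range; a real FP8 cast would also round the mantissa
    return np.clip(x * scale, -fp8_max, fp8_max)


history = [3.5, 4.0, 2.0]          # amax values from previous steps
scale = delayed_scale(history)     # 448 / 4.0 = 112.0
x = np.array([-5.0, 1.0, 4.0])
q = scale_and_clip(x, scale)       # -5 * 112 = -560 saturates at -448
x_back = q / scale                 # dequantize: [-4.0, 1.0, 4.0]
print(scale, q, x_back)
```

Because the scale comes from past amax values, a new outlier (here $-5$) saturates at $\pm 448$ instead of being represented exactly; a larger `margin` trades some precision for headroom against such outliers.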

```python
# Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
restart_jupyter_notebook()

# Import necessary packages and methods
from utils import *

# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`
## !!! `model_name` attr must point to the location of the model weights !!!
## Gemma weights can be downloaded from the Hugging Face Hub (e.g. `google/gemma-7b`)
hyperparams.model_name = "../../../../gemma-weights"  # <== Add model weight location here, e.g. "/path/to/downloaded/gemma/weights"
hyperparams.mixed_precision = "fp8"  # the only change from the previous cell

# Init the TE-based model and accelerator wrapper
model = init_te_gemma_model(hyperparams).cuda()
accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)

# Finetune the model
finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)
```

| Models                                                      | Precision | Step Time (or ms per batch) | Speedup (over baseline) |
|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|
| HF (baseline)                                               | BF16      | -                         | 1                       |
| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16      | -                         | -                    |
| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | FP8       | -                         | -                    |


After turning on FP8 precision, we get an even larger speedup of almost **??%**!

# Conclusion

We can see that, as with the Llama model, using the `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `GemmaDecoderLayer` provides a speedup over Hugging Face's native Gemma implementation.

## See more

We have also prepared a [tutorial](./tutorial_generation_gemma_with_te.ipynb) covering CUDA graphs and THD attention, which we use to speed up Gemma generation.