# Accelerating Hugging Face Gemma generation with Transformer Engine

<div class="alert alert-info">

<b>Goal</b>

This tutorial shows how to accelerate text generation with a full Gemma model from [Hugging Face](https://huggingface.co/google/gemma-7b-it) by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` precision.

</div>


## Dependencies for this tutorial

The following files and media are needed to run this tutorial:

1. `te_gemma.py`
    - This file contains the code to load a Hugging Face Gemma checkpoint into Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. It also contains the generation logic implemented with Transformer Engine.
2. `utils.py`
    - This file contains the code for data loading, hyperparameters, setting up the model, optimizers, and accelerator, model training, and miscellaneous tasks such as restarting the Jupyter notebook from within a cell.
3. `media/`
    - This directory contains the images used in this tutorial.

## Baseline Hugging Face Gemma generation

<div class="alert alert-info">

<b>Note</b>
    
This tutorial loads a Gemma 7B model, which takes up most of the GPU memory, so the Jupyter notebook needs to be restarted before running each of the following sections. A small utility method, `restart_jupyter_notebook`, is defined in the accompanying `utils.py` file. It restarts the Jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint, which avoids OOM (Out Of Memory) errors.

In the following cell, the `restart_jupyter_notebook()` call is commented out. Uncomment it to restart automatically, or, if the utility doesn't work, restart the Jupyter notebook manually before running the cell. Do the same for the other sections of this tutorial.

</div>
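A restart helper of this kind can be sketched as follows, assuming the notebook runs on an IPython kernel (ipykernel) whose shell exposes `kernel.do_shutdown(restart)`. This is only an illustration: the name `restart_kernel` is hypothetical, and `utils.py` may implement `restart_jupyter_notebook` differently.

```python
def restart_kernel() -> bool:
    """Ask the running Jupyter kernel to shut down with restart=True.

    Hypothetical helper. Returns False and does nothing when the code is not
    running inside an IPython kernel (plain Python, terminal IPython, ...).
    """
    try:
        from IPython import get_ipython
    except ImportError:
        return False  # IPython is not installed, so there is no kernel to restart

    shell = get_ipython()  # None outside of IPython
    kernel = getattr(shell, "kernel", None)  # only kernel-backed shells have one
    if kernel is None:
        return False

    # After a restart every Python object, and the GPU memory it held, is gone,
    # so the model must be reloaded from the checkpoint afterwards.
    kernel.do_shutdown(True)
    return True
```

Restarting the whole process is a blunt but reliable way to release GPU memory, because it does not depend on every reference to the model being dropped first.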


```python
# Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
#restart_jupyter_notebook()


# Import necessary packages and methods
from utils import *
import time

import torch
from transformers import AutoTokenizer


# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`
## !!! `model_name` attr must point to the location of the model weights !!!
## Weights can be downloaded from: https://huggingface.co/google/gemma-7b-it
hyperparams.model_name = "../../../../gemma-weights"  # <== Add model weight location here e.g. "/path/to/downloaded/gemma/weights"
hyperparams.mixed_precision = "no"


# Init the baseline HF model and cast it to BF16
model = init_baseline_model(hyperparams).cuda()
model = model.to(torch.bfloat16)

# Tokenize a batch of 64 prompts (padded to equal length) and move it to the GPU
tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)
inputs = tokenizer(["I like", "I do not like"] * 32, return_tensors="pt", padding=True)

inputs['input_ids'] = inputs['input_ids'].cuda()
inputs['attention_mask'] = inputs['attention_mask'].cuda()


# CUDA kernels run asynchronously: synchronize before reading the clock
torch.cuda.synchronize()
start_time = time.time()

outputs = model.generate(
    **inputs,
    max_new_tokens=400
)

torch.cuda.synchronize()
end_time = time.time()
duration = end_time - start_time
print(f"Generation time: {duration} seconds")


# Decode the output tensor to text
generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Display the first two samples of the generated text
print(generated_texts[0][:80])
print(30 * "=")
print(generated_texts[1][:80])
```

```
Generation time: 26.482454538345337 seconds
I like the new look of the app. I like the new features. I like the new look of 
I do not like the way the new version of the app is set up. I do not like the fa
```
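CUDA kernels are launched asynchronously, so a wall-clock measurement can stop before queued GPU work has finished unless the device is synchronized first. The timing pattern used in the cells can be factored into a small helper. This is a sketch: the name `time_call` is hypothetical and not part of `utils.py`, and it uses `time.perf_counter`, which is intended for measuring short intervals, instead of `time.time`.

```python
import time
from typing import Any, Callable, Optional, Tuple


def time_call(
    fn: Callable[..., Any],
    *args: Any,
    sync: Optional[Callable[[], None]] = None,
    **kwargs: Any,
) -> Tuple[Any, float]:
    """Run fn(*args, **kwargs) and return (result, elapsed_seconds).

    `sync` (hypothetical parameter) is called before starting and before
    stopping the clock; pass torch.cuda.synchronize when timing GPU work so
    that all queued kernels are included in the measurement.
    """
    if sync is not None:
        sync()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if sync is not None:
        sync()
    return result, time.perf_counter() - start
```

In the cells of this tutorial it could be called as `outputs, duration = time_call(model.generate, **inputs, max_new_tokens=400, sync=torch.cuda.synchronize)`.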


Let's record this result in a table and compare it with the improvements in the following sections:

| Models                                                      | Precision | Generation time (s) | Speedup (over baseline) |
|-------------------------------------------------------------|-----------|---------------------|-------------------------|
| HF (baseline)                                               | BF16      | 26.48               | 1                       |

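## [Improvement] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` and run generation with TE

Besides swapping the decoder layers, `te_gemma.py` implements generation on top of Transformer Engine. Its `generate` method runs the whole prompt through the model once, stores the keys and values of processed tokens in TE's `InferenceParams` cache, and then produces one new token per step with greedy decoding (`argmax` over the logits):
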
```python
# Excerpt from `te_gemma.py`: `generate` is a method of the TE Gemma model class.
from typing import Optional

import torch
import transformer_engine.pytorch as te
from transformers import GenerationConfig


@torch.no_grad()
def generate(
    self,
    input_ids: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    max_new_tokens: int = 0,
    **kwargs,
):
    num_heads = self.model.config.num_attention_heads
    batch_size, seq_len = input_ids.shape
    max_seq_len = seq_len + max_new_tokens
    generation_config, _ = self._prepare_generation_config(generation_config, **kwargs)
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)

    # The inference_params object is a cache that stores the keys and values of previous tokens.
    inference_params = te.InferenceParams(
        max_batch_size=batch_size,
        max_sequence_length=max_seq_len + 1,
    )

    # `mask` has shape [batch_size, num_heads, 1, max_seq_len] and contains False
    # where the corresponding token is padding and True otherwise.
    pad_attention_mask = input_ids.ne(generation_config.pad_token_id)
    mask = torch.ones((batch_size, num_heads, 1, max_seq_len), dtype=torch.bool, device=input_ids.device)
    mask[..., :seq_len] = mask[..., :seq_len] & pad_attention_mask.unsqueeze(1).unsqueeze(2).expand(-1, num_heads, -1, -1)

    hidden_states = self.model.embed_tokens(input_ids)
    # Gemma scales the embeddings by sqrt(hidden_size); the factor is the same at every step.
    normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype, device=hidden_states.device)
    output_tokens = []
    for i in range(max_new_tokens):
        hidden_states = hidden_states * normalizer
        for decoder_layer in self.model.layers:
            hidden_states = decoder_layer(
                hidden_states,
                # Step 0 processes the whole prompt with a padding + causal mask; later steps
                # use an arbitrary mask over all cached tokens. In the case of the arbitrary
                # mask, the meaning of True and False is switched, so negation is needed.
                attention_mask=pad_attention_mask if i == 0 else ~mask[..., :seq_len],
                self_attn_mask_type="padding_causal" if i == 0 else "arbitrary",
                inference_params=inference_params,
            )[0]

        # inference_params.sequence_len_offset must hold the position of the current token in the sequence.
        inference_params.sequence_len_offset += hidden_states.shape[1]

        hidden_states = self.model.norm(hidden_states)
        logits = self.lm_head(hidden_states)
        logits = logits.float()
        logits = logits[:, -1, :]
        next_tokens = torch.argmax(logits, dim=-1)

        # Finished sequences should contain padding (logic taken from Hugging Face transformers).
        next_tokens = next_tokens * unfinished_sequences + generation_config.pad_token_id * (1 - unfinished_sequences)
        output_tokens.append(next_tokens)

        unfinished_sequences = unfinished_sequences & ~(next_tokens == generation_config.eos_token_id)

        # Only the newly generated token is fed to the next step; earlier tokens come from the cache.
        hidden_states = self.model.embed_tokens(next_tokens).unsqueeze(1)
        seq_len += 1

    result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1)
    return result
```
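The control flow of `generate` can be traced with a toy, pure-Python sketch that mirrors its structure without any model or GPU. This is not TE code: `toy_greedy_generate` and `next_token_fn` are hypothetical stand-ins, the "cache" holds token ids instead of keys and values, and the "model" is any function mapping the visible tokens to the next token id. It reproduces four details of the method above: step 0 feeds the whole prompt and later steps feed one token, the offset advances by the number of positions processed, left padding is hidden by the mask, and finished sequences keep emitting padding.

```python
from typing import Callable, List, Tuple


def toy_greedy_generate(
    prompts: List[List[int]],
    next_token_fn: Callable[[List[int]], int],
    max_new_tokens: int,
    pad_id: int,
    eos_id: int,
) -> Tuple[List[List[int]], List[int]]:
    """Toy loop with the same structure as the TE `generate` (hypothetical helper).

    `prompts` are assumed to be left-padded to equal length with `pad_id`.
    Returns the full sequences and the cache offset after each step.
    """
    batch_size = len(prompts)
    # keep[b][j] is True when position j of sequence b may be attended to
    # (mirrors `input_ids.ne(pad_token_id)`; generated positions are always True).
    keep = [[tok != pad_id for tok in row] for row in prompts]
    cache: List[List[int]] = [[] for _ in range(batch_size)]  # stand-in for the KV cache
    offset = 0  # stand-in for inference_params.sequence_len_offset
    offsets = []
    unfinished = [1] * batch_size
    step_tokens = [list(row) for row in prompts]  # step 0: the whole prompt
    results = [list(row) for row in prompts]

    for _ in range(max_new_tokens):
        # Write this step's tokens into the cache and advance the offset by the
        # number of positions processed (prompt length first, then 1 per step).
        for b in range(batch_size):
            cache[b].extend(step_tokens[b])
        offset += len(step_tokens[0])
        offsets.append(offset)

        for b in range(batch_size):
            # Attend only to cached positions that the mask keeps.
            visible = [t for t, k in zip(cache[b], keep[b]) if k]
            tok = next_token_fn(visible)
            # Finished sequences emit padding instead of the predicted token.
            tok = tok * unfinished[b] + pad_id * (1 - unfinished[b])
            if tok == eos_id:
                unfinished[b] = 0
            results[b].append(tok)
            keep[b].append(True)
            step_tokens[b] = [tok]  # only the new token is fed to the next step
    return results, offsets


if __name__ == "__main__":
    out, offsets = toy_greedy_generate(
        [[0, 0, 2]], lambda v: len(v) + 10, max_new_tokens=2, pad_id=0, eos_id=99
    )
    print(out, offsets)  # [[0, 0, 2, 11, 12]] [3, 4]
```

In the usage example the first generated token is 11 rather than 13 because the two padding positions of the left-padded prompt are excluded from what the toy "model" sees.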

```python
# Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
#restart_jupyter_notebook()


# Import necessary packages and methods
from utils import *
import time

import torch
from transformers import AutoTokenizer

# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`
## !!! `model_name` attr must point to the location of the model weights !!!
## Weights can be downloaded from: https://huggingface.co/google/gemma-7b-it
hyperparams.model_name = "../../../../gemma-weights"  # <== Add model weight location here e.g. "/path/to/downloaded/gemma/weights"

# Init the TE-based Gemma model and cast it to BF16
model = init_te_gemma_model(hyperparams)
model = model.to(torch.bfloat16).cuda()

# Tokenize the same batch of 64 prompts and move it to the GPU
tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)
inputs = tokenizer(["I like", "I do not like"] * 32, return_tensors="pt", padding=True)

inputs['input_ids'] = inputs['input_ids'].cuda()
inputs['attention_mask'] = inputs['attention_mask'].cuda()

# CUDA kernels run asynchronously: synchronize before reading the clock
torch.cuda.synchronize()
start_time = time.time()

outputs = model.generate(
    **inputs,
    max_new_tokens=400
)

torch.cuda.synchronize()
end_time = time.time()
duration = end_time - start_time
print(f"Generation time: {duration} seconds")


# Decode the output tensor to text
generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Display the first two samples of the generated text
print(generated_texts[0][:80])
print(30 * "=")
print(generated_texts[1][:80])
```

```
Generation time: 16.87099289894104 seconds
I like the idea of a "re-do" of the original "The Man from U.N.C.L.E." movie. I 
I do not like the way the "new" (2011) version of the 1099-MISC is set up.  I ha
```


| Models                                                      | Precision | Generation time (s) | Speedup (over baseline) |
|-------------------------------------------------------------|-----------|---------------------|-------------------------|
| HF (baseline)                                               | BF16      | 26.48               | 1                       |
| TE (replace `GemmaDecoderLayer` with `te.TransformerLayer`) | BF16      | 16.87               | 1.57                    |

Replacing the decoder layers with TE's `TransformerLayer` and running generation with TE gives a **1.57x** speedup over the baseline (26.48 s down to 16.87 s).
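The speedup column is the baseline generation time divided by the TE generation time. A small sketch computing it from the two measured durations (the helper name `speedup` is only for illustration):

```python
def speedup(baseline_s: float, candidate_s: float) -> float:
    """Return how many times faster the candidate is than the baseline."""
    if candidate_s <= 0:
        raise ValueError("candidate time must be positive")
    return baseline_s / candidate_s


baseline_s = 26.482454538345337  # HF baseline, measured above
te_s = 16.87099289894104         # TE generation, measured above
s = speedup(baseline_s, te_s)
print(f"speedup: {s:.2f}x, generation time reduced by {1 - te_s / baseline_s:.0%}")
# prints: speedup: 1.57x, generation time reduced by 36%
```

Reporting both numbers avoids a common confusion: a 1.57x speedup means about 57% more throughput, which corresponds to about 36% less wall-clock time.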

## Conclusion

Using the `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `GemmaDecoderLayer`, together with generation implemented on top of TE, speeds up Gemma generation over Hugging Face's native implementation: in BF16, generating with `max_new_tokens=400` for a batch of 64 prompts took 16.87 s instead of 26.48 s.