# Accelerate HF Llama model with TransformerEngine

<div class="alert alert-info">

<b>Goal</b>

This tutorial showcases three incrementally more efficient ways to use the [TransformerEngine library](https://github.com/NVIDIA/TransformerEngine) for full finetuning of a Llama 2 model from [Hugging Face](https://huggingface.co/meta-llama/Llama-2-7b-hf).

</div>


<div class="alert alert-info">

<b>Note</b>
    
This tutorial showcases full finetuning of the 7B [Llama 2 model](https://huggingface.co/meta-llama/Llama-2-7b-hf) on H100 GPUs (80GB of HBM each). Because the model occupies most of the GPU memory, benchmarking each of the following portions (Baseline, Improvement 1, Improvement 2 and Improvement 3) requires restarting the Jupyter notebook kernel before each run.

</div>


## Table of contents
1. From "Transformer" to "Llama"
2. Hugging Face's `LlamaModel`
    - Hugging Face's `LlamaDecoderLayer`
3. TransformerEngine's `TransformerLayer`
    - `TransformerLayer` options explained
4. Necessary Imports
5. [Baseline] Running HF `LlamaModel` (Precision: `BF16`)
6. [Improvement 1] Replace `nn.Linear` with TE's `Linear` layers (Precision: `FP8`)
7. [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)
8. [Improvement 3] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)
9. Benchmarks revisited and Conclusion

## From "Transformer" to "Llama" 

<figure align="center">
<img src="media/transformer_llama.png" width="50%">
    <figcaption> Fig 1: Llama visualized as a transformer. (generated with <a href="https://catalog.ngc.nvidia.com/orgs/nvidia/teams/ai-foundation/models/sdxl">Nvidia's AI-foundation models</a>)</figcaption>
</figure>

A flashback:
- 2017: The ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762) paper introduced the pioneering Transformer architecture and changed the NLP field forever.
- 2018-2020: The GPT model series emerged, showing that causal decoder architectures are a great fit for pretraining as well as few-shot and zero-shot learning.
- Fast forward to 2023-2024: Following the success of GPT-3 and GPT-4, researchers and companies raced to produce the next best pretrained model that could be further finetuned for application-specific use cases.
- One of the latest open-source models in this line is Meta's [Llama 2](https://llama.meta.com/llama2) family (Llama stands for "Large Language Model Meta AI").
    - These models range from 7B to 70B parameters.
    - Llama 2 was pretrained on 2 trillion tokens.

For more information on Llama 2, consider reading the [Hugging Face blog post](https://huggingface.co/blog/llama2). As a quick summary, here are some of the important differences between the conventional transformer decoder architecture and the Llama 2 architecture:

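1. Decoder-only model (causal language modeling and next-word prediction)
2. RMSNorm in place of LayerNorm
3. SwiGLU activation function
4. RoPE (rotary position embeddings) as positional embeddings
5. Grouped Query Attention (in the larger Llama 2 variants; the 7B model used here uses standard multi-head attention)
6. Trained with a 4K context length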
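To make item 2 concrete, here is a minimal pure-Python sketch of both normalizations applied to a single vector. It assumes a learnable per-element scale `weight`, omits the LayerNorm bias, and uses an illustrative `eps`; it is not the Hugging Face or TransformerEngine code. The only difference is that RMSNorm skips the mean subtraction and divides by the root mean square.

```python
import math

def layer_norm(x, weight, eps=1e-5):
    # LayerNorm: subtract the mean, divide by the standard deviation, then scale
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [w * (v - mean) / math.sqrt(var + eps) for v, w in zip(x, weight)]

def rms_norm(x, weight, eps=1e-5):
    # RMSNorm: no mean subtraction; divide by the root mean square, then scale
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]

x = [1.0, 2.0, 3.0, 4.0]
ones = [1.0] * len(x)
print(layer_norm(x, ones))  # zero-mean output
print(rms_norm(x, ones))    # keeps the signs of x, root mean square of about 1
```

Dropping the mean computation makes RMSNorm slightly cheaper than LayerNorm while keeping the rescaling behavior.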

<figure align="center">
<img src="media/transformers_vs_llama.png" width="100%">
    <figcaption> Fig 2: Comparing conventional Transformer architecture with Llama architecture. </figcaption>
</figure>

## Hugging Face's `LlamaModel`
Hugging Face provides an open-source implementation of the Llama model in [`modeling_llama.py`](https://github.com/huggingface/transformers/blob/3d2900e829ab16757632f9dde891f1947cfc4be0/src/transformers/models/llama/modeling_llama.py#L4).

Here's a block diagram that shows how the Llama model is implemented in the Hugging Face repo. Notice the modular, encapsulated structure with `LlamaDecoderLayer` at the core of the model implementation. This core layer is the target of the optimizations in two of the improvements later in this tutorial (Improvement 2 and Improvement 3).

<figure align="center">
<img src="media/llama_for_causal_lm.png" width="40%">
    <figcaption> Fig 3: Causal Llama Model Block Diagram. </figcaption>
</figure>

The diagram above corresponds to the following PyTorch printout of the model. Notice that the core of the model consists of 32 `LlamaDecoderLayer`s.

```
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
```
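As a quick sanity check on the "7B" in the model name, the shapes in this printout are enough to count the parameters. The sketch assumes only what the printout shows: bias-free `Linear` layers, one 4096-element weight per `LlamaRMSNorm`, and an `lm_head` that does not share weights with `embed_tokens`.

```python
# Shapes read off the printout above (Llama 2 7B)
vocab, hidden, inter, n_layers = 32000, 4096, 11008, 32

embed = vocab * hidden          # embed_tokens
attn = 4 * hidden * hidden      # q_proj, k_proj, v_proj, o_proj
mlp = 3 * hidden * inter        # gate_proj, up_proj, down_proj
norms = 2 * hidden              # input_layernorm, post_attention_layernorm
per_layer = attn + mlp + norms

# embeddings + decoder stack + final norm + lm_head
total = embed + n_layers * per_layer + hidden + vocab * hidden
print(f"{per_layer:,} params per decoder layer")
print(f"{total:,} params in total")  # 6,738,415,616
print(f"{n_layers * per_layer / total:.1%} of them are in the decoder layers")
```

Roughly 96% of the weights sit in the 32 decoder layers, which is consistent with the later improvements focusing on `LlamaDecoderLayer`.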

#### HF `LlamaDecoderLayer`

Let's take a closer look at `LlamaDecoderLayer`. It's composed of `input_layernorm`, `self_attn`, `post_attention_layernorm` and `mlp` modules. Each module has associated weights as shown in the diagram.

<figure align="center">
<img src="media/llama_zoom.png" width="70%">
    <figcaption> Fig 4: Causal Llama Model Block Diagram (with simplified illustration of the <a href="https://github.com/huggingface/transformers/blob/e770f0316d2a9b787c9d1440f204fcb65e176682/src/transformers/models/llama/modeling_llama.py#L695">LlamaDecoderLayer</a>). </figcaption>
</figure>

##### Self_Attn Layer
For simplicity, the "self_attn" box in the block diagram omits the "Grouped Query Attention" operation and only shows the modules that have associated weights.
   
##### MLP Layer

The MLP layer uses the SwiGLU activation, implemented as follows in the [modeling_llama.py](https://github.com/huggingface/transformers/blob/7c4995f93d8d24aae05e1e43279c96dce736e5c8/src/transformers/models/llama/modeling_llama.py#L236) file in the Hugging Face GitHub repo:
```
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
```
It requires three weight matrices (`gate_proj`, `up_proj` and `down_proj`) compared to two in a conventional MLP layer, such as the one in the original transformer architecture.
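The formula above can be traced with a minimal pure-Python sketch, assuming tiny made-up weight matrices (`W_gate`, `W_up` and `W_down` are hypothetical stand-ins for `gate_proj`, `up_proj` and `down_proj`) and no biases. `act_fn` is SiLU, `silu(z) = z * sigmoid(z)`. The gate branch scales the up branch element-wise before `down_proj` maps back to the hidden size.

```python
import math

def matvec(W, x):
    # y = W @ x, with W stored as a list of rows (out_features x in_features)
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def silu(z):
    # SiLU (a.k.a. swish): z * sigmoid(z), written to avoid overflow for large |z|
    if z >= 0:
        return z / (1.0 + math.exp(-z))
    e = math.exp(z)
    return z * e / (1.0 + e)

def llama_mlp(x, W_gate, W_up, W_down):
    # down_proj(act_fn(gate_proj(x)) * up_proj(x)): three weight matrices
    gate = [silu(g) for g in matvec(W_gate, x)]
    up = matvec(W_up, x)
    return matvec(W_down, [g * u for g, u in zip(gate, up)])

# Toy sizes: hidden=2, intermediate=3 (Llama 2 7B uses 4096 and 11008)
W_gate = [[0.5, -0.5], [1.0, 0.0], [0.0, 1.0]]   # 3 x 2
W_up = [[1.0, 1.0], [0.5, 0.5], [-1.0, 1.0]]     # 3 x 2
W_down = [[1.0, 0.0, 1.0], [0.0, 1.0, -1.0]]     # 2 x 3

y = llama_mlp([1.0, 2.0], W_gate, W_up, W_down)
print(y)  # back to the hidden size: a list of 2 values
```

With a zero gate the output is zero regardless of `up_proj`, which shows how the gate controls the information flowing through the block.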

## [Baseline] Running HF `LlamaModel` (Precision: `BF16`)

Llama2 weights are loaded into the Hugging Face native implementation `LlamaForCausalLM` (refer to [modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)). 

The baseline uses `batch_size=8` and `BF16` precision.

The `LlamaDecoderLayer` is left unchanged in the baseline as follows:

<figure align="center">
<img src="media/llamadecoderlayer.png" width="30%">
    <figcaption> Fig 5: Revisiting "LlamaDecoderLayer". </figcaption>
</figure>


```python
import torch
torch.set_warn_always(False)
```

```python
# Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
restart_jupyter_notebook()


# Import the helper functions and the default `hyperparams` object from `utils.py`
from utils import *


## Default hyperparams, also defined in `utils.py` in class `Hyperparameters`
## !!! `model_name` attr must point to the location of the model weights !!!
hyperparams.model_name = "" # <== Add model weight location here
hyperparams.mixed_precision = "bf16"


## Init the model and accelerator wrapper
model = init_baseline_model(hyperparams)
accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)


## Finetune the model
finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)
```

```
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.25s/it]
Map: 100%|██████████| 9846/9846 [00:00<00:00, 12152.20 examples/s]

10 finetuning steps complete!
Average time taken per step: 289 milliseconds
```


Let's add this information in a table and keep comparing it with a few possible improvements in future sections:

| Models                                                      | Precision | Step Time (or ms per batch) | Speedup (over baseline) |
|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|
| HF (baseline)                                               | BF16      | 289                         | 1                       |

## [Improvement 1] Replace `nn.Linear` with TE's `Linear` layers (Precision: `FP8`)

[Hugging Face accelerate](https://github.com/huggingface/accelerate) provides a [`convert_model`](https://github.com/huggingface/accelerate/blob/97d2168e5953fe7373a06c69c02c5a00a84d5344/src/accelerate/utils/transformer_engine.py#L24) function that replaces `torch.nn.Linear` layers with `transformer_engine.pytorch.module.Linear` layers. The following diagram illustrates this.

<figure align="center">
<img src="media/llamadecoderlayer_replace_with_telinear.png" width="70%">
    <figcaption> Fig 6: Replacing "nn.Linear" with "TE.Linear". </figcaption>
</figure>


This is the most straightforward way to use TransformerEngine's `FP8` precision for training/finetuning the HF Llama model. Notice that the `LlamaDecoderLayer` is mostly left unchanged; only its `nn.Linear` layers are replaced with `TE.Linear` layers.

#### How to run the model in `FP8` precision

After the substitution, these layers can be run in `FP8` precision with the following change relative to the previous BF16 run. (For more information, refer to the `wrap_with_accelerator` function in the accompanying `utils.py` file.)

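```python
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

# Specify the `FP8RecipeKwargs` (additional argument required to run in `fp8` precision)
fp8_kwarg_handler = [FP8RecipeKwargs(backend="te")]

# Pass the `FP8RecipeKwargs` to the `Accelerator` init call
accelerator = Accelerator(
    mixed_precision="fp8",
    # ... (other arguments, see `wrap_with_accelerator` in `utils.py`)
    kwargs_handlers=fp8_kwarg_handler,
)
```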


```python
# Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
restart_jupyter_notebook()


# Import the helper functions and the default `hyperparams` object from `utils.py`
from utils import *


## Default hyperparams, also defined in `utils.py` in class `Hyperparameters`
## !!! `model_name` attr must point to the location of the model weights !!!
hyperparams.model_name = "" # <== Add model weight location here
hyperparams.mixed_precision = "fp8"


## Init the model and accelerator wrapper
model = init_baseline_model(hyperparams)
accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)


## Finetune the model
finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)
```

```
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.43s/it]

10 finetuning steps complete!
Average time taken per step: 310 milliseconds
```


Based on the above run, here is how Improvement 1 compares with the baseline:

| Models                                                      | Precision | Step Time (or ms per batch) | Speedup (over baseline) |
|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|
| HF (baseline)                                               | BF16      | 289                         | 1                       |
| HF (replace `nn.Linear` with `TE.Linear`)                   | FP8       | 310                         | _**0.93**_              |

The performance with TE `Linear` layers has actually decreased: the speedup is **0.93** (about **7% slower**). Let's try to understand the reason for the slowdown.
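The speedup column is the baseline step time divided by the new step time. A quick check of the numbers in the table (the helper `speedup` is just for illustration):

```python
def speedup(baseline_ms, new_ms):
    # > 1 means faster than the baseline, < 1 means slower
    return baseline_ms / new_ms

s = speedup(289, 310)
print(f"speedup: {s:.2f}x")                         # speedup: 0.93x
print(f"step time change: {310 / 289 - 1:+.1%}")    # step time change: +7.3%
```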

### Understanding the slowdown in Improvement 1 (CPU Overheads)

<div class="alert alert-info">

<b>Note</b>

<a href="https://developer.nvidia.com/nsight-systems">NVIDIA Nsight Systems</a> allows taking a closer look at the underlying CPU and GPU activity of an application or workload. This process is called "profiling" and the graphical visualization produced after running this process is called a "profile".

In the following explanation, profiles are annotated manually to allow the reader to better comprehend the time spent in various computations. 

</div>

<figure align="center">
<img src="media/profile_hf_llama_bf16.png" width="100%">
    <figcaption> Fig 7: Profile of HF Llama baseline implementation. (Top) A cross-section of the CPU/GPU activity showing a single transformer layer (mainly "SelfAttention" and "MLP" modules) forward. (Middle) An emphasized cross-section of a portion of the "SelfAttention" module that shows the "Q", "K" and "V" projection operations (basically `nn.Linear` layers). (Bottom) A further emphasized cross-section of only the "Q" projection operation which shows the corresponding CPU and GPU activity for a single `nn.Linear` layer. </figcaption>
</figure>


<figure align="center">
<img src="media/profile_hf_llama_fp8.png" width="100%">
    <figcaption> Fig 8: Profile of HF Llama with "Improvement 1" (replacing `nn.Linear` layers with TE's `Linear` layers) implementation. (Top) A cross-section of the CPU/GPU activity showing a single transformer layer (mainly "MultiheadAttention" and "LayerNormMLP" modules) forward. (Middle) An emphasized cross-section of a portion of the "MultiheadAttention" module that shows the Q, K and V projection operations (basically TE's `Linear` layers). (Bottom) A further emphasized cross-section of only the Q projection operation which shows the CPU and GPU activity for a single TE's `Linear` layer.</figcaption>
</figure>

<div class="alert alert-light">

<b>Insight</b>
    
In the profiles above, whenever GPU activity is absent, the GPU is generally idle, waiting for the CPU to dispatch the next kernel. Ideally, the GPU should stay busy as much as possible and spend little time waiting on the CPU.
    
One thing that stands out is that GPU kernels keep the GPU busy for a larger fraction of the time in the "baseline" implementation than in the "Improvement 1" (replacing `nn.Linear` with TE's `Linear`) implementation. Let's dig into why that is the case.

</div>

The following table summarizes the timings from the profiles above:

| Operation                                                          | Baseline (microseconds) | Improvement 1: TE `Linear` (microseconds) | Speedup   |
|--------------------------------------------------------------------|-------------------------|-------------------------------------------|-----------|
| Single transformer layer forward ("attn" and "mlp")                | 2443                    | 6299                                      | -         |
| &emsp;"attn" layer                                                 | 1470                    | 3842                                      | -         |
| &emsp;&emsp;Q, K and V projections                                 | 326                     | 2027                                      | -         |
| &emsp;&emsp;&emsp;Q projection                                     | 117                     | 1001                                      | -         |
| &emsp;&emsp;&emsp;&emsp;Amax/Scale update                          | -                       | 72                                        | -         |
| &emsp;&emsp;&emsp;&emsp;Buffer allocation                          | -                       | 49                                        | -         |
| &emsp;&emsp;&emsp;&emsp;Cast+Transpose                             | -                       | 35 (10 + 25)                              | -         |
| &emsp;&emsp;&emsp;&emsp;MatMul                                     | 106                     | 52                                        | **2.03x** |
| &emsp;"mlp" layer                                                  | 790                     | 2057                                      | -         |


Now let's make a few observations:

1. For a single transformer layer, the Improvement 1 implementation (`nn.Linear` layers replaced with TE's `Linear` layers, FP8 precision) takes more time than the baseline implementation (BF16 precision).
2. Zooming into the individual "attn" (`SelfAttention` for the baseline, `MultiheadAttention` for Improvement 1) and "mlp" (`MLP` for the baseline, `LayerNormMLP` for Improvement 1) layers shows the same trend: Improvement 1 is slower than the baseline.
3. At the core, the TE `Linear` layers in Improvement 1 are slower than the `nn.Linear` layers in the baseline (1001 microseconds vs 117 microseconds, i.e. _slower by a factor of about **8.5x**_).

#### Why is TE's `Linear` slower than `nn.Linear` layer?

Looking closely, TE's `Linear` layer runs several kernels for the following tasks:
1. Amax and scale update
2. FP8 weight and transpose buffer allocation
3. Cast+Transpose kernels for inputs and weights (to cast BF16 inputs and weights to their FP8 counterparts)
4. Matrix multiplication kernel (in FP8 precision)

Whenever the GPU is idle, it is usually waiting for the CPU to finish its work and dispatch the next kernel. Moreover, all of these kernels are quite short for this tutorial's workload (`batch_size=8`, `max_seq_length=256`). As a result, TE's `Linear` layer takes **1001 microseconds** overall.

Compare this to the `nn.Linear` layer (Fig 7, bottom), which contains only a single matrix multiplication kernel (in BF16 precision). Almost all of the time inside the layer is spent running that kernel, and since the GPU doesn't have to wait for the CPU to dispatch more kernels, the `nn.Linear` layer takes only **117 microseconds**, roughly a ninth of the time taken by TE's `Linear` layer.
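Adding up the timings listed in the table for a single Q projection makes the point numerically. This sketch only sums the table rows; the table may not list every operation, so the "unaccounted" remainder is an upper bound on idle time rather than a measured value.

```python
# Q projection, Improvement 1 (TE Linear, FP8), microseconds, from the table above
q_proj_total = 1001
ops = {"amax/scale update": 72, "buffer allocation": 49, "cast+transpose": 35, "fp8 matmul": 52}

listed = sum(ops.values())
print(f"listed operations: {listed} us ({listed / q_proj_total:.0%} of the Q projection)")
print(f"matmul share:      {ops['fp8 matmul'] / q_proj_total:.0%}")
print(f"unaccounted:       {q_proj_total - listed} us")

# Baseline (nn.Linear, BF16): the matmul is almost the whole layer
print(f"baseline matmul share: {106 / 117:.0%}")
```

In the baseline, the matmul accounts for about 91% of the layer time; with TE's `Linear` at this small workload it accounts for only about 5%.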

<div class="alert alert-light">

<b>Insight</b>
    
Note that in Improvement 1, the FP8 matrix multiplication itself is faster than the BF16 matrix multiplication in the baseline implementation (52 vs 106 microseconds)!
    
What if we could force the GPU to spend more time in Matrix Multiplication?

</div>

#### How to make TE's `Linear` faster than `nn.Linear` layer?

As noted earlier, the workload for this finetuning tutorial (`batch_size=8`, `max_seq_length=256`) is small and therefore can't fully exploit TE's `Linear` layers.

_Generally, we'd want to increase the workload so that the GPU spends more of its time computing and less time waiting on the CPU._

As a small experiment, let's see what the profiles look like for the "Q" projection operation when we increase `batch_size` from `8` to `128`. Since this would make the GPU run out of memory, let's simultaneously shrink the model from `32` layers to just `4` (i.e. `config.num_hidden_layers=4`).


<figure align="center">
<img src="media/profile_hf_llama_bf16_bs_128.png" width="100%">
    <figcaption> Fig 9: (baseline implementation) Profile of the "Q" projection, especially the `nn.Linear` layer when batch_size=128 and num_hidden_layers=4. </figcaption>
</figure>


<figure align="center">
<img src="media/profile_hf_llama_fp8_bs_128.png" width="100%">
    <figcaption> Fig 10: (Improvement 1 implementation) Profile of the "Q" projection, especially the TE's `Linear` layer when batch_size=128 and num_hidden_layers=4. </figcaption>
</figure>

As the figures above show, the "Q" projection with TE's `Linear` layer (**928 microseconds**) is now faster than with the `nn.Linear` layer (**1356 microseconds**). This is mainly due to the following reasons:
1. The FP8 matrix multiplication in TE's `Linear` layer (**691 microseconds**) is faster than the BF16 matrix multiplication in the `nn.Linear` layer (**1353 microseconds**), an **almost 2x** speedup. Moreover, the FP8 matrix multiplication now takes much longer than it did with the smaller tutorial workload (**52 microseconds**), i.e. more of the time is spent doing work on the GPU rather than waiting on the CPU.
2. The other kernels in TE's `Linear` layer (especially the transposes) also take more time, but the speedup from the matrix multiplication more than compensates for this, giving an overall speedup of **46%** (**1356** microseconds for `nn.Linear` vs **928** for TE `Linear`)!
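The same ratio (baseline time divided by new time) applied to the numbers quoted above contrasts the small tutorial workload with the `batch_size=128` experiment. The helper `speedup` is illustrative only.

```python
def speedup(baseline_us, new_us):
    # > 1 means TE Linear is faster than nn.Linear
    return baseline_us / new_us

small = speedup(117, 1001)   # whole Q projection, batch_size=8
large = speedup(1356, 928)   # whole Q projection, batch_size=128, num_hidden_layers=4
matmul = speedup(1353, 691)  # matmul kernel only, batch_size=128

print(f"bs=8:   {small:.2f}x")   # bs=8:   0.12x
print(f"bs=128: {large:.2f}x")   # bs=128: 1.46x
print(f"matmul: {matmul:.2f}x")  # matmul: 1.96x
```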

<div class="alert alert-light">
<b>Insight</b>
    
As the workload increases, more time is spent on the GPU relative to the CPU, and the speedup from FP8 matrix multiplication outweighs the time spent on the CPU and in other GPU kernels (e.g. cast+transpose).

In the above example, `batch_size` was increased to enlarge the workload. Increasing the sequence length (`max_seq_length`) has a similar effect!

</div>
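To see why a larger `batch_size` (or sequence length) helps, count the work in a single Q projection: a `[tokens, 4096] x [4096, 4096]` matmul costs about `2 * tokens * 4096 * 4096` FLOPs, with `tokens = batch_size * seq_len`. This sketch assumes `max_seq_length=256` in both experiments (the text states it only for the tutorial workload) and counts FLOPs only; it does not model kernel efficiency or launch overhead.

```python
def q_proj_gflops(batch_size, seq_len, hidden=4096):
    tokens = batch_size * seq_len
    # one multiply and one add per weight per token = 2 FLOPs
    return 2 * tokens * hidden * hidden / 1e9

small = q_proj_gflops(8, 256)    # tutorial workload
large = q_proj_gflops(128, 256)  # profiling experiment (seq_len assumed unchanged)
print(f"{small:.1f} GFLOP vs {large:.1f} GFLOP ({large / small:.0f}x more work)")
# 68.7 GFLOP vs 1099.5 GFLOP (16x more work)
```

The matmul work grows linearly with the number of tokens, while the number of kernels per layer stays the same, so the fixed per-kernel CPU cost is spread over much more GPU work.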

## [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)

In addition to basic layers like `Linear` and `LayerNorm`, TransformerEngine offers larger modules like `MultiheadAttention` (combines "LayerNorm" and "Self Attention") and `LayerNormMLP` (combines "LayerNorm" and "MLP") that could replace their counterparts in `LlamaDecoderLayer` and potentially provide more speedup. TransformerEngine also offers a full `TransformerLayer` (which combines `MultiheadAttention` and `LayerNormMLP`) that can be substituted for `LlamaDecoderLayer`, with careful mapping of the weights, since the weight names differ between the two layers. Let's take a closer look at TransformerEngine's `TransformerLayer`.

### TransformerEngine's `TransformerLayer`

At a high level, TE's `TransformerLayer` can be seen as a suitable replacement for `LlamaDecoderLayer`, but its internals are organized a bit differently.

<figure align="center">
<img src="media/tellamadecoderlayer.png" width="30%">
    <figcaption> Fig 11: TransformerEngine's `TransformerLayer` </figcaption>
</figure>

Just like Hugging Face's `LlamaDecoderLayer`, TransformerEngine's `TransformerLayer` encapsulates self-attention (as `self_attention`, a `MultiheadAttention`) and the MLP (as `layernorm_mlp`, a `LayerNormMLP`). A major difference is that the two norms are folded into the `MultiheadAttention` and `LayerNormMLP` layers, as shown in the following printout:

```
TransformerLayer(
    (self_attention): MultiheadAttention(
      (layernorm_qkv): LayerNormLinear()
      (core_attention): DotProductAttention()
      (proj): Linear()
    )
    (layernorm_mlp): LayerNormMLP()
)
```

### `TransformerLayer` options explained

<div class="alert alert-info">

<b>Note</b>
    
Here, we go over some of the `TransformerLayer` options needed for this tutorial. For a complete list of options, refer to the [TransformerLayer API documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html?highlight=transformerlayer#transformer_engine.pytorch.TransformerLayer).

</div>

In the accompanying `te_llama.py` file, `TELlamaDecoderLayer` is defined as a wrapper over TE's `TransformerLayer`, with the options needed to make `TransformerLayer` a drop-in replacement for HF's `LlamaDecoderLayer`.

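```python
import transformer_engine as te
from transformer_engine.pytorch.attention import RotaryPositionEmbedding

class TELlamaDecoderLayer(te.pytorch.TransformerLayer):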
    def __init__(self, config):
        super().__init__(
            config.hidden_size,
            config.intermediate_size,
            config.num_attention_heads,
            bias=False,
            layernorm_epsilon=config.rms_norm_eps,
            hidden_dropout=0,
            attention_dropout=0,
            fuse_qkv_params=False,
            normalization="RMSNorm",
            activation="swiglu",
            attn_input_format="bshd",
        )
        te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)
        self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
```

Here's a list summarizing each option briefly:

1. `hidden_size`: size of each input sample.
2. `ffn_hidden_size`: intermediate size to which samples are projected.
3. `num_attention_heads`: number of attention heads in the transformer layer.
4. `bias`: whether to add additive biases to the submodule layers.
5. `layernorm_epsilon`: a value added to the denominator of layer normalization for numerical stability. Default is `1e-5`.
6. `hidden_dropout`: dropout probability for the dropout op after the FC2 layer (fully connected layer no. 2). Default is `0.1`.
7. `attention_dropout`: dropout probability for the dropout op during multi-head attention. Default is `0.1`.
8. `fuse_qkv_params`: if set to `True`, the `TransformerLayer` module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatenations/splits and also enables the argument `fuse_wgrad_accumulation`.
9. `normalization`: type of normalization applied. Default is `LayerNorm`.
10. `activation`: type of activation used in the MLP block. Default is `gelu`.
11. `attn_input_format`: controls whether the dimensions of the intermediate hidden states are 'batch first' (`bshd`) or 'sequence first' (`sbhd`). `s` stands for the sequence length, `b` for the batch size, `h` for the number of heads and `d` for the head size. Note that these formats are closely related to the `qkv_format` in the `MultiheadAttention` and `DotProductAttention` modules.
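Option 11 is easiest to see in terms of tensor shapes. This pure-Python sketch assumes the Llama 2 7B sizes from the printout (hidden size 4096, 32 heads, so head size `d = 4096 // 32 = 128`) and the tutorial's `batch_size=8`, `max_seq_length=256`. The helper `attn_shape` is hypothetical; it only spells out the dimension order for each format.

```python
def attn_shape(fmt, batch, seq, hidden, heads):
    # Per-head layout of Q/K/V once the hidden dimension is split into heads
    assert hidden % heads == 0, "hidden size must be divisible by the number of heads"
    dims = {"b": batch, "s": seq, "h": heads, "d": hidden // heads}
    return tuple(dims[c] for c in fmt)

print(attn_shape("bshd", 8, 256, 4096, 32))  # (8, 256, 32, 128): batch first
print(attn_shape("sbhd", 8, 256, 4096, 32))  # (256, 8, 32, 128): sequence first
```

Hugging Face passes `hidden_states` as `[batch, seq, hidden]`, so choosing `bshd` lets `TransformerLayer` consume them in that layout without a transpose.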


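Further, note that `RotaryPositionEmbedding` is set up in the `TELlamaDecoderLayer` wrapper itself (as `te_rope_emb`), because TE's `TransformerLayer` expects this precomputed RoPE cache (passed as `rotary_pos_emb` in the forward call) if RoPE is used in the model.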
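For intuition on what the cached RoPE frequencies are used for, here is a minimal pure-Python RoPE sketch. It assumes the "rotate-half" pairing (dimension `i` paired with `i + d/2`) and base `10000`, as in the Hugging Face Llama implementation; it is not TE's `RotaryPositionEmbedding` code, and `rope` and `dot` are illustrative helpers. The key property is that the query/key dot product depends only on the relative offset between positions.

```python
import math

def rope(x, pos, base=10000.0):
    # Rotate each pair (x[i], x[i + d/2]) by the angle pos * base**(-2i/d)
    d = len(x)
    half = d // 2
    out = list(x)
    for i in range(half):
        theta = pos * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[i], x[i + half]
        out[i] = a * c - b * s
        out[i + half] = a * s + b * c
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

q = [0.3, -1.2, 0.8, 0.5]
k = [1.0, 0.4, -0.7, 0.2]
# Same offset (5 - 2 == 13 - 10) at different absolute positions:
# the two scores match up to float rounding
print(dot(rope(q, 5), rope(k, 2)), dot(rope(q, 13), rope(k, 10)))
```

Because each pair is rotated rather than shifted, vector norms are preserved and position enters the attention score only through the difference of the two positions.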

### Comparing Hugging Face's `LlamaDecoderLayer` with TransformerEngine's `TransformerLayer`

Let's revisit how `LlamaDecoderLayer`s form the core of the decoder layer stack in HF's llama implementation:
```
ModuleList(
  (0-31): 32 x LlamaDecoderLayer(
    (self_attn): LlamaAttention(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (mlp): LlamaMLP(
      (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
      (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
      (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): LlamaRMSNorm()
    (post_attention_layernorm): LlamaRMSNorm()
  )
)
```

A major portion of the Hugging Face model implementation (`LlamaDecoderLayer`) could potentially be replaced with TransformerEngine's implementation.


### Mapping weights from HF's `LlamaDecoderLayer` to TE's `TransformerLayer`

Refer to the accompanying file `te_llama.py`, which provides a reference for creating a Llama 2 model with TE's `TransformerLayer` in place of HF's `LlamaDecoderLayer`.

Briefly, the following pieces of code are put together:

1. `TELlamaDecoderLayer` is added as a wrapper for `TransformerLayer`. 
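```python
import transformer_engine as te
from transformer_engine.pytorch.attention import RotaryPositionEmbedding

class TELlamaDecoderLayer(te.pytorch.TransformerLayer):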
    """
    Wrapper class over TE's `TransformerLayer`. This makes the wrapper very
    similar to HF's `LlamaDecoderLayer` and easier to replace it in the code.

    Args:
        config: LlamaConfig
        args: positional args (for compatibility with `LlamaDecoderLayer`)
        kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)
    """
    def __init__(self, config, *args, **kwargs):
        super().__init__(
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            num_attention_heads=config.num_attention_heads,
            bias=False,
            layernorm_epsilon=config.rms_norm_eps,
            hidden_dropout=0,
            attention_dropout=0,
            fuse_qkv_params=False,
            normalization="RMSNorm",
            activation="swiglu",
            attn_input_format="bshd",
        )
        te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)
        self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()

    def forward(self,
                hidden_states,
                *args,
                attention_mask,
                **kwargs):
        """
        Custom forward to make sure we only pass relevant arguments to the
        forward pass of the `TransformerLayer`. Also, make sure the output
        format matches the output of the HF's `LlamaDecoderLayer`.
        """
        return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb),)
```

2. Before creating a `LlamaForCausalLM`, `replace_decoder` context manager is used to monkey-patch `LlamaDecoderLayer` with `TELlamaDecoderLayer`.

```python
from contextlib import contextmanager

import transformers
from transformers import LlamaConfig, LlamaForCausalLM

@contextmanager
def replace_decoder(te_decoder_cls):
    """
    Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.
    """
    original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer
    transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
    try:
        yield
    finally:
        # Restore the original class even if model construction fails
        transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls

# ...

class TELlamaForCausalLM:
    """
    Causal LM created with `LlamaModel`. The underlying `LlamaDecoderLayer`
    class is monkey-patched with `TELlamaDecoderLayer` class before
    initializing the causal LM with `LlamaForCausalLM`.

    Args:
        config: LlamaConfig
    """

    def __new__(cls, config: LlamaConfig):
        with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
            llama_for_causal_lm = LlamaForCausalLM(config)
        return llama_for_causal_lm

# ...
```
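Why does the monkey-patch work? `LlamaModel.__init__` looks up the module-level name `LlamaDecoderLayer` when it builds its list of layers, so swapping that name for the duration of the constructor call changes which class gets instantiated. The sketch below reproduces the pattern with a stand-in module; `fake_modeling`, `OrigLayer`, `TELayer` and `build_model` are all hypothetical, and this `replace_decoder` variant takes the module as an argument so it can run without `transformers`.

```python
import types
from contextlib import contextmanager

# Hypothetical stand-in for transformers.models.llama.modeling_llama
fake_modeling = types.ModuleType("fake_modeling")

class OrigLayer:
    pass

class TELayer:
    pass

fake_modeling.LlamaDecoderLayer = OrigLayer

def build_model(n_layers):
    # Like LlamaModel.__init__, look the layer class up when the model is built
    return [fake_modeling.LlamaDecoderLayer() for _ in range(n_layers)]

@contextmanager
def replace_decoder(module, te_decoder_cls):
    original = module.LlamaDecoderLayer
    module.LlamaDecoderLayer = te_decoder_cls
    try:
        yield
    finally:
        # Always restore the original class, even if construction raises
        module.LlamaDecoderLayer = original

with replace_decoder(fake_modeling, TELayer):
    patched = build_model(2)
unpatched = build_model(2)
print([type(m).__name__ for m in patched])    # ['TELayer', 'TELayer']
print([type(m).__name__ for m in unpatched])  # ['OrigLayer', 'OrigLayer']
```

One consequence of this design: the patch only affects code that resolves the name through the module at call time. A reference captured earlier (e.g. via `from ... import LlamaDecoderLayer` in another module) still points at the original class.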

3. A custom `pretrained_from_local` method is added that copies the weights from the checkpoint (which is meant for the HF Llama implementation) to the modified `TELlamaForCausalLM` by carefully mapping the weights from `LlamaDecoderLayer` (HF) to `TransformerLayer` (TE). The `replace_params` function maps and copies the corresponding weights from `LlamaDecoderLayer` to `TransformerLayer`. Refer to the following diagram for more details.

```python
import re

def replace_params(hf_state_dict, te_state_dict):
    # collect all layer prefixes to update
    all_layer_prefixes = set()
    for param_key in hf_state_dict.keys():
        layer_prefix_pat = r'model\.layers\.\d+\.'
        m = re.match(layer_prefix_pat, param_key)
        if m is not None:
            all_layer_prefixes.add(m.group())

    for layer_prefix in all_layer_prefixes:
        # When loading weights into models with less number of layers, skip the
        # copy if the corresponding layer doesn't exist in TE model
        if layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:]

        if layer_prefix + 'self_attention.layernorm_qkv.query_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:]

        if layer_prefix + 'self_attention.layernorm_qkv.key_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:]

        # ... (remaining weight mappings elided)

    return all_layer_prefixes
```

<figure align="center">
<img src="media/weight_swap.png" width="70%">
    <figcaption> Fig 12: Replace `LlamaDecoderLayer` with `TransformerLayer`. </figcaption>
</figure>
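The core of `replace_params` is a key-renaming scheme: every HF parameter under a `model.layers.<N>.` prefix is copied to the TE parameter with the same prefix and a different suffix. The minimal sketch below isolates that logic, assuming plain Python values in place of tensors (the real function copies tensor data in place via `.data[:]`) and using only the three mappings that appear in the excerpt above. `HF_TO_TE_SUFFIX` and `map_layer_params` are hypothetical names, and the full helper maps more parameters than this.

```python
import re

# Only the three HF -> TE suffix pairs shown in the `replace_params` excerpt
HF_TO_TE_SUFFIX = {
    "input_layernorm.weight": "self_attention.layernorm_qkv.layer_norm_weight",
    "self_attn.q_proj.weight": "self_attention.layernorm_qkv.query_weight",
    "self_attn.k_proj.weight": "self_attention.layernorm_qkv.key_weight",
}

# Raw string with escaped dots, so "." matches only a literal dot
LAYER_PREFIX = re.compile(r"model\.layers\.\d+\.")


def map_layer_params(hf_state_dict, te_state_dict):
    """Copy per-layer HF entries into `te_state_dict` under their TE names."""
    prefixes = set()
    for key in hf_state_dict:
        m = LAYER_PREFIX.match(key)
        if m is not None:
            prefixes.add(m.group())

    for prefix in prefixes:
        for hf_suffix, te_suffix in HF_TO_TE_SUFFIX.items():
            te_key = prefix + te_suffix
            # Skip layers that a (smaller) TE model does not have
            if te_key in te_state_dict:
                te_state_dict[te_key] = hf_state_dict[prefix + hf_suffix]
    return prefixes


# A two-layer "HF checkpoint"; the embedding key has no layer prefix and is ignored
hf = {"model.embed_tokens.weight": "emb"}
for i in range(2):
    hf[f"model.layers.{i}.input_layernorm.weight"] = f"ln{i}"
    hf[f"model.layers.{i}.self_attn.q_proj.weight"] = f"q{i}"
    hf[f"model.layers.{i}.self_attn.k_proj.weight"] = f"k{i}"

# A one-layer TE model: the layer-1 entries have no destination and are skipped
te = {"model.layers.0." + suffix: None for suffix in HF_TO_TE_SUFFIX.values()}
prefixes = map_layer_params(hf, te)

print(sorted(prefixes))  # ['model.layers.0.', 'model.layers.1.']
print(te["model.layers.0.self_attention.layernorm_qkv.query_weight"])  # q0
```

The membership check before each copy is what lets a checkpoint be loaded into a TE model with fewer layers, as the comment in `replace_params` describes.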


After the modified Llama model is initialized this way, its core decoder layers are `TELlamaDecoderLayer`s (wrappers around `TransformerLayer`), as shown in the following output:
```
ModuleList(
  (0-31): 32 x TELlamaDecoderLayer(
    (self_attention): MultiheadAttention(
      (layernorm_qkv): LayerNormLinear()
      (core_attention): DotProductAttention(
        (flash_attention): FlashAttention()
        (fused_attention): FusedAttention()
        (unfused_attention): UnfusedDotProductAttention(
          (scale_mask_softmax): FusedScaleMaskSoftmax()
          (attention_dropout): Dropout(p=0, inplace=False)
        )
      )
      (proj): Linear()
    )
    (layernorm_mlp): LayerNormMLP()
  )
)
```

In summary, the model changes as follows, with a larger share of the implementation (the core decoder layers) now coming from TransformerEngine.

<figure align="center">
<img src="media/model_change.png" width="80%">
    <figcaption> Fig 13: Language model after the HF's `LlamaDecoderLayer`s are replaced with TE's `TransformerLayer`s. </figcaption>
</figure>


<div class="alert alert-info">
<b>Note</b>

Let's first run this "TELlama" implementation in `BF16` precision.
</div>

```
# Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
restart_jupyter_notebook()


# Minimize the bloat by keeping all imports (and helper functions) in `utils.py`
from utils import *


# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`
## !!! `model_name` attr must point to the location of the model weights !!!
hyperparams.model_name = ""  # <== Add model weight location here
hyperparams.mixed_precision = "bf16"


# Init the model and accelerator wrapper
model = init_te_llama_model(hyperparams)
accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)


# Finetune the model
finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)
```



```
10 finetuning steps complete!
Average time taken per step: 242 milliseconds
```


Compared to the baseline implementation, replacing larger chunks of the model with Transformer Engine's layers gives a speedup of **19%** (from 289 ms to 242 ms per step), even when using only BF16 precision!

| Models                                                      | Precision | Step Time (ms per batch) | Speedup (over baseline) |
|-------------------------------------------------------------|-----------|--------------------------|-------------------------|
| HF (baseline)                                               | BF16      | 289                      | 1                       |
| HF (replace `nn.Linear` with `TE.Linear`)                   | FP8       | 310                      | 0.93                    |
| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16      | 242                      | 1.19                    |
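
Each step time above is an average over the finetuning steps reported by `finetune_model` (10 steps in the runs shown). The sketch below shows one plausible way to compute such an average; it is an assumption, not the tutorial's `utils.py` implementation, and `average_step_time_ms`, `fake_step`, and the warm-up count are hypothetical. Excluding a few warm-up steps keeps one-time setup costs (for example, memory allocation) out of the average. On a GPU you would also call `torch.cuda.synchronize()` before reading the clock, because CUDA kernels launch asynchronously and the CPU could otherwise stop the timer before the GPU has finished.

```python
import time


def average_step_time_ms(step_fn, num_steps=10, warmup_steps=2):
    """Average wall-clock time of `step_fn` in milliseconds, excluding warm-up calls."""
    for _ in range(warmup_steps):
        step_fn()  # one-time setup costs land here, outside the measured window

    # On a GPU, call torch.cuda.synchronize() before each clock read so that
    # queued kernels are included in the measurement.
    start = time.perf_counter()
    for _ in range(num_steps):
        step_fn()
    elapsed_s = time.perf_counter() - start
    return elapsed_s * 1000 / num_steps


calls = []


def fake_step():
    # Stand-in for one training step; sleeps for at least 5 ms
    calls.append(1)
    time.sleep(0.005)


avg_ms = average_step_time_ms(fake_step, num_steps=10, warmup_steps=2)
print(f"Average time taken per step: {avg_ms:.0f} milliseconds")
```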

## [Improvement 3] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)

Now that most of the HF Llama model implementation (the `LlamaDecoderLayer`s) has been swapped for the TransformerEngine implementation (`TELlamaDecoderLayer`, a wrapper around `TransformerLayer`), let's see how `FP8` training further improves performance.

```
# Restart the notebook (to flush the GPU memory)
from utils import restart_jupyter_notebook
restart_jupyter_notebook()


# Minimize the bloat by keeping all imports (and helper functions) in `utils.py`
from utils import *


# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`
## !!! `model_name` attr must point to the location of the model weights !!!
hyperparams.model_name = ""  # <== Add model weight location here
hyperparams.mixed_precision = "fp8"


# Init the model and accelerator wrapper
model = init_te_llama_model(hyperparams)
accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)


# Finetune the model
finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)
```



```
10 finetuning steps complete!
Average time taken per step: 231 milliseconds
```


| Models                                                      | Precision | Step Time (ms per batch) | Speedup (over baseline) |
|-------------------------------------------------------------|-----------|--------------------------|-------------------------|
| HF (baseline)                                               | BF16      | 289                      | 1                       |
| HF (replace `nn.Linear` with `TE.Linear`)                   | FP8       | 310                      | 0.93                    |
| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16      | 242                      | 1.19                    |
| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8       | 231                      | 1.25                    |
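
The speedup column is the baseline step time divided by each configuration's step time, so values above 1 mean faster than the baseline and values below 1 mean slower. A short check of the table's numbers, with the step times copied from the table (the dictionary keys are just illustrative labels):

```python
BASELINE_MS = 289

# Step times (ms per batch) from the table above
step_times_ms = {
    "HF baseline (BF16)": 289,
    "HF + TE Linear (FP8)": 310,
    "TE TransformerLayer (BF16)": 242,
    "TE TransformerLayer (FP8)": 231,
}


def speedup(baseline_ms, step_ms):
    """Speedup over the baseline; > 1 means faster, < 1 means slower."""
    return baseline_ms / step_ms


speedups = {name: round(speedup(BASELINE_MS, ms), 2) for name, ms in step_times_ms.items()}
for name, s in speedups.items():
    # e.g. 1.19x corresponds to the "19%" speedup quoted for the BF16 TE model
    print(f"{name:28s} {s:.2f}x ({(s - 1) * 100:+.0f}%)")
```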


## Conclusion

The major takeaways from this tutorial are:
1. Transformer Engine offers layers with differing levels of granularity:
    - Basic modules: `Linear` and `LayerNormLinear` (TE also provides the core attention operation in `DotProductAttention`).
    - Larger modules that combine basic modules: `MultiheadAttention` and `LayerNormMLP`.
    - `TransformerLayer`, which combines `MultiheadAttention` and `LayerNormMLP` and is the most performant implementation.
2. Replace basic modules like `nn.Linear` with TE's `Linear` layer with the workload size in mind: at small workload sizes, the time spent on the CPU can dominate the time spent on the GPU, which can actually result in a slowdown (the FP8 `TE.Linear` variant in the table above runs at 0.93x the baseline speed).
3. Using the larger `TransformerLayer` module from Transformer Engine provides a speedup over Hugging Face's native Llama 2 implementation. This requires careful model initialization: HF's `LlamaDecoderLayer`s are substituted with TE's `TransformerLayer`s, and the model weights (which are meant for `LlamaDecoderLayer`) are correctly mapped to their counterparts in TE's `TransformerLayer`.
4. NVIDIA Nsight Systems lets you peek into the CPU and GPU activity while running a workload (this tutorial, in our case) and reason about appropriate optimizations.
    - In this tutorial, we demonstrated that increasing the workload size makes even basic TE modules like `Linear` run faster (in FP8 precision) than `nn.Linear` (in BF16 precision).