# Accelerate HF Llama model with Transformer-Engine

Goal: This tutorial showcases three incrementally more efficient ways to use the [Transformer-Engine library](https://github.com/NVIDIA/TransformerEngine) to fully finetune a Llama2 model from [HuggingFace](https://huggingface.co/meta-llama/Llama-2-7b-hf).

#### Things to note about this tutorial:
1. It showcases full finetuning of the 7B Llama2 model (https://huggingface.co/meta-llama/Llama-2-7b-hf).
2. The GPU memory requirements are therefore large (this tutorial targets H100 GPUs, which have 80GB of HBM).
3. As a result, each of the following parts should be benchmarked in a fresh kernel, i.e. restart the Jupyter notebook kernel before running each part.
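
As a rough illustration of why an 80GB GPU is needed, the sketch below estimates the memory taken by the training state alone. It assumes about 6.74B parameters (Llama-2-7B) and that weights, gradients and both AdamW moment buffers are stored in `BF16` (2 bytes each), which is what you get when the model is loaded with `torch_dtype=torch.bfloat16` and optimized with `torch.optim.AdamW`. Activations, the CUDA context and allocator overhead are ignored, so the real peak is higher.

```python
# Back-of-the-envelope training-state memory for full finetuning (a sketch, not a measurement).
NUM_PARAMS = 6_738_415_616  # Llama-2-7B parameter count
BYTES_PER_VALUE = 2         # assumption: weights, grads and optimizer states all in BF16
GIB = 1024 ** 3


def training_state_gib(num_params: int, bytes_per_value: int = BYTES_PER_VALUE) -> dict:
    """Return the GiB needed for weights, gradients and the two AdamW moment buffers."""
    per_tensor = num_params * bytes_per_value / GIB
    parts = {
        "weights": per_tensor,
        "gradients": per_tensor,
        "adamw_exp_avg": per_tensor,     # first moment
        "adamw_exp_avg_sq": per_tensor,  # second moment
    }
    parts["total"] = sum(parts.values())
    return parts


if __name__ == "__main__":
    for name, size in training_state_gib(NUM_PARAMS).items():
        print(f"{name:>16}: {size:6.2f} GiB")
```

Under these assumptions the training state alone comes to roughly 50 GiB before any activations are stored, which leaves only a modest margin on an 80GB GPU.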

## Table of contents
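1. From "Transformer" to "Llama"
2. HuggingFace's `LlamaModel`
    - HF `LlamaDecoderLayer`
3. Transformer-Engine's `TransformerLayer`
    - `TransformerLayer` options explained
4. Tutorial overview and benchmarks preview
5. Necessary imports
6. Tutorial part 1 (Baseline): HF `LlamaModel` in `BF16`
7. Tutorial part 2 (Improvement 1): replace `nn.Linear` with TE's `Linear` in `FP8`
8. Tutorial part 3 (Improvement 2): replace `LlamaDecoderLayer` with TE's `TransformerLayer` in `BF16`
9. Tutorial part 4 (Improvement 3): replace `LlamaDecoderLayer` with TE's `TransformerLayer` in `FP8`
10. Benchmarks revisited and conclusion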

## From "Transformer" to "Llama" 
![Transformers to Llama](media/transformer_llama.png "transformer llama")

A flashback:
- 2017: The ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762) paper introduced the pioneering "Transformer" architecture and changed the NLP field forever.
- 2018-2020: The GPT model series emerged and showed that causal decoder architectures are a great fit for pretraining, few-shot and zero-shot learning.
- Fast forward to 2023-2024: Following the success of GPT-3/GPT-4, researchers and companies raced to produce the next best pretrained model that could be further finetuned for application-specific use cases.
- Among the latest open-source models in this line are Meta's Llama-1/Llama-2 models (Large Language Model Meta AI).
    - These models range from 7B to 70B parameters (Llama-1 goes up to 65B, Llama-2 up to 70B).
    - Llama-2 was pretrained on 2 trillion tokens.

A lot has already been written about Llama (we consider Llama 2 here; [check out Meta AI for more details](https://llama.meta.com/)). A few important details are:
1. Decoder-only model (causal language modeling, i.e. next-token prediction)
2. RMSNorm in place of LayerNorm
3. SwiGLU activation function
4. RoPE (rotary positional embeddings)
5. Grouped Query Attention (used by the larger Llama-2 variants; the 7B model used in this tutorial uses standard multi-head attention)
6. Trained with a 4K-token context length
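
Two of these components fit in a few lines. The sketch below is plain Python on lists (no PyTorch), for illustration only: RMSNorm rescales a vector by the reciprocal of its root mean square (no mean subtraction and no bias, unlike LayerNorm), and the SwiGLU MLP gates one bias-free projection with the SiLU of another, which is what the `gate_proj`/`up_proj`/`down_proj` trio in the HF Llama MLP computes. The `eps` default here is arbitrary; the real model reads it from `config.rms_norm_eps`.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major, shape [out_features][in_features]


def rms_norm(x: Vector, weight: Vector, eps: float = 1e-5) -> Vector:
    """RMSNorm: x / sqrt(mean(x^2) + eps) * weight (no mean-centering, no bias)."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def silu(v: float) -> float:
    """SiLU / Swish activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def matvec(m: Matrix, x: Vector) -> Vector:
    """Bias-free linear layer y = m @ x (all Llama projections use bias=False)."""
    return [sum(w * v for w, v in zip(row, x)) for row in m]


def swiglu_mlp(x: Vector, gate_proj: Matrix, up_proj: Matrix, down_proj: Matrix) -> Vector:
    """Llama-style MLP: down_proj(silu(gate_proj(x)) * up_proj(x))."""
    gate = [silu(g) for g in matvec(gate_proj, x)]
    up = matvec(up_proj, x)
    return matvec(down_proj, [g * u for g, u in zip(gate, up)])


if __name__ == "__main__":
    print(rms_norm([1.0, -2.0, 3.0, -4.0], weight=[1.0] * 4))
```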

## HuggingFace's `LlamaModel`
HuggingFace is the go-to place for open-source NLP model implementations, and it provides an implementation of the `Llama` model in [`modeling_llama.py`](https://github.com/huggingface/transformers/blob/3d2900e829ab16757632f9dde891f1947cfc4be0/src/transformers/models/llama/modeling_llama.py#L4).

Here's a block diagram that shows how the Llama model is implemented in the HuggingFace repo. Notice the modular, encapsulated structure and the `LlamaDecoderLayer` at the core of the model implementation. This core layer is the target of one of the improvements discussed in this tutorial.

![Causal Llama Model Block Diagram](media/llama_for_causal_lm.png "LlamaForCausalLM")

Printing the model in PyTorch gives the following text representation of the diagram above. Notice the stack of 32 `LlamaDecoderLayer`s at the core of the model.

```
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
```
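
As a sanity check on this printout, the parameter count can be reconstructed from the layer shapes alone: every projection is bias-free, each `LlamaRMSNorm` holds one weight vector of size `hidden_size`, and `k_proj`/`v_proj` are as wide as `q_proj` (no key/value sharing in the 7B model). The arithmetic lands on roughly 6.74B parameters, which is where the "7B" in the name comes from.

```python
# Shapes read off the LlamaForCausalLM printout above.
VOCAB = 32_000
HIDDEN = 4_096
INTERMEDIATE = 11_008
NUM_LAYERS = 32


def llama_param_count(vocab: int, hidden: int, intermediate: int, num_layers: int) -> int:
    embed = vocab * hidden               # embed_tokens
    attn = 4 * hidden * hidden           # q_proj, k_proj, v_proj, o_proj
    mlp = 3 * hidden * intermediate      # gate_proj, up_proj, down_proj
    norms = 2 * hidden                   # input_layernorm, post_attention_layernorm
    per_layer = attn + mlp + norms
    final_norm = hidden                  # model.norm
    lm_head = hidden * vocab             # separate (untied) output projection
    return embed + num_layers * per_layer + final_norm + lm_head


if __name__ == "__main__":
    total = llama_param_count(VOCAB, HIDDEN, INTERMEDIATE, NUM_LAYERS)
    print(f"{total:,} parameters")  # 6,738,415,616 parameters
```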

#### HF `LlamaDecoderLayer`

Let's take a closer look at `LlamaDecoderLayer`. It's composed of `input_layernorm`, `self_attn`, `post_attention_layernorm` and `mlp` modules. Each module has associated weights as shown in the diagram.

![LlamaDecoderLayer](media/llama_zoom.png "LlamaDecoderLayer")
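
The data flow through these four modules is a pre-norm residual block: each norm is applied to the input of its sub-block, and the sub-block output is added back to the un-normalized residual stream. The schematic sketch below uses plain callables as stand-ins for the real modules and omits attention masks, position ids and caching, which the real `forward` also handles.

```python
from typing import Callable, List

Vector = List[float]
Fn = Callable[[Vector], Vector]


def add(a: Vector, b: Vector) -> Vector:
    """Elementwise residual addition."""
    return [x + y for x, y in zip(a, b)]


def decoder_layer(hidden: Vector, input_layernorm: Fn, self_attn: Fn,
                  post_attention_layernorm: Fn, mlp: Fn) -> Vector:
    """Schematic pre-norm residual block in the order used by LlamaDecoderLayer."""
    # Attention sub-block: normalize, attend, add back to the residual stream.
    hidden = add(hidden, self_attn(input_layernorm(hidden)))
    # MLP sub-block: normalize, apply the MLP, add back to the residual stream.
    hidden = add(hidden, mlp(post_attention_layernorm(hidden)))
    return hidden


if __name__ == "__main__":
    def identity(v: Vector) -> Vector:
        return v

    def double(v: Vector) -> Vector:
        return [2 * x for x in v]

    print(decoder_layer([1.0, 2.0], identity, double, identity, double))  # [9.0, 18.0]
```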

## Transformer-Engine's `TransformerLayer`

At a high level, TE's `TransformerLayer` can be viewed as a replacement for `LlamaDecoderLayer`. However, its internals are organized a bit differently.

![TransformerLayer](media/tellamadecoderlayer.png "TELlamaDecoderLayer")

Like HuggingFace's `LlamaDecoderLayer`, TE's `TransformerLayer` mainly encapsulates `self_attention` and `mlp`. The major difference is that the two norms are fused into the `self_attention` and `mlp` blocks: `layernorm_qkv` (a `LayerNormLinear`) applies the first norm right before the QKV projection, and `layernorm_mlp` (a `LayerNormMLP`) applies the second norm right before the MLP.

```
TransformerLayer(
    (self_attention): MultiheadAttention(
      (layernorm_qkv): LayerNormLinear()
      (core_attention): DotProductAttention()
      (proj): Linear()
    )
    (layernorm_mlp): LayerNormMLP()
)
```


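Just as with HF's `LlamaDecoderLayer`s, it is possible to build a custom model (let's call it `MyGPT`) out of `TransformerLayer`s only:
```python
import torch
from transformer_engine.pytorch import TransformerLayer

# Define the model with only TransformerLayer
class MyGPT(torch.nn.Module):
    def __init__(self, num_layers=1, hidden_size=128, ffn_hidden_size=512, num_attention_heads=16):
        super().__init__()
        # One TransformerLayer per requested layer
        self.layers = torch.nn.ModuleList(
            [TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads) for _ in range(num_layers)]
        )

    def forward(self, x):
        # With TE's default attn_input_format="sbhd", x has shape [seq_len, batch_size, hidden_size]
        for layer in self.layers:
            x = layer(x)
        return x

# Init the model (TE creates its parameters on the GPU by default)
mygpt = MyGPT()
```
```python
# Training loop: `dataset` and `loss_fn` are placeholders, and the learning rate is illustrative
optimizer = torch.optim.AdamW(mygpt.parameters(), lr=1e-4)
for data, labels in dataset:
    optimizer.zero_grad()
    out = mygpt(data)
    loss = loss_fn(out, labels)
    loss.backward()
    optimizer.step()
```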

### `TransformerLayer` options explained
[Refer to [Transformer-Engine's docs](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html?highlight=transformerlayer#transformer_engine.pytorch.TransformerLayer) for more details]

In the accompanying `te_llama.py` file, `TELlamaDecoderLayer` (a wrapper over TE's `TransformerLayer`) is defined as follows with a bunch of options:

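```python
import transformer_engine as te
# Import path as used in TE's Llama example; it may move between TE versions
from transformer_engine.pytorch.attention import RotaryPositionEmbedding

class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
    def __init__(self, config):
        super().__init__(
            config.hidden_size,
            config.intermediate_size,
            config.num_attention_heads,
            bias=False,
            layernorm_epsilon=config.rms_norm_eps,
            hidden_dropout=0,
            attention_dropout=0,
            fuse_qkv_params=False,
            normalization="RMSNorm",
            activation="swiglu",
            attn_input_format="bshd",
        )
        # RoPE cache sized for the per-head dimension (hidden_size // num_attention_heads)
        te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)
        self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
```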

1. `hidden_size`: size of each input sample.
2. `ffn_hidden_size`: intermediate size to which samples are projected in the MLP.
3. `num_attention_heads`: number of attention heads in the transformer layer.
4. `bias`: whether to add additive biases to the submodule layers (Llama's projections have none, hence `bias=False`).
5. `layernorm_epsilon`: a value added to the denominator of the normalization for numerical stability. Default is `1e-5`.
6. `hidden_dropout`: dropout probability for the dropout op after the FC2 layer (fully connected layer no. 2). Default is `0.1`.
7. `attention_dropout`: dropout probability for the dropout op during multi-head attention. Default is `0.1`.
8. `fuse_qkv_params`: if set to `True`, the `TransformerLayer` module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatenations/splits and also enables the argument `fuse_wgrad_accumulation`.
9. `normalization`: type of normalization applied. Default is `LayerNorm`.
10. `activation`: type of activation used in the MLP block. Default is `gelu`.
11. `attn_input_format`: controls whether the dimensions of the intermediate hidden states are 'batch first' ('bshd') or 'sequence first' ('sbhd'), where `s` stands for the sequence length, `b` the batch size, `h` the number of heads and `d` the head size. Note that these formats are closely related to the `qkv_format` in the `MultiheadAttention` and `DotProductAttention` modules.


Further, note that the `RotaryPositionEmbedding` cache (`te_rope_emb`) is created inside the `TELlamaDecoderLayer` wrapper itself, because `TransformerLayer` expects this precomputed RoPE tensor to be passed in when RoPE is used in the model.
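
For intuition about what that RoPE cache contains: with head dimension `d = hidden_size // num_attention_heads` (4096 // 32 = 128 for Llama-2-7B), RoPE rotates each pair of channels at position `m` by the angle `m * theta_i`, where `theta_i = base ** (-2i / d)`. The pure-Python sketch below assumes `base = 10000` and the "rotate-half" pairing of `x[i]` with `x[i + d/2]` used by the HF Llama code; it illustrates the math only and is not TE's `RotaryPositionEmbedding`.

```python
import math
from typing import List


def rope_inv_freq(head_dim: int, base: float = 10000.0) -> List[float]:
    """theta_i = base ** (-2i / head_dim) for i in [0, head_dim / 2)."""
    return [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]


def apply_rope(x: List[float], position: int, base: float = 10000.0) -> List[float]:
    """Rotate each channel pair (x[i], x[i + d/2]) by the angle position * theta_i."""
    d = len(x)
    half = d // 2
    out = list(x)
    for i, theta in enumerate(rope_inv_freq(d, base)):
        cos, sin = math.cos(position * theta), math.sin(position * theta)
        a, b = x[i], x[i + half]
        out[i] = a * cos - b * sin
        out[i + half] = b * cos + a * sin
    return out


if __name__ == "__main__":
    head_dim = 4096 // 32  # hidden_size // num_attention_heads for Llama-2-7B
    print(head_dim, len(rope_inv_freq(head_dim)))  # 128 64
```

Because every pair is rotated rather than shifted, the dot product between a rotated query and a rotated key depends only on their relative position, which is the property RoPE is designed for.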

## Tutorial Overview and Benchmarks Preview

#### Part 1 (Baseline): HF Llama2 (precision: `BF16`)
Llama2 weights are loaded into the HuggingFace native implementation `LlamaForCausalLM` (refer to [modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)).

#### Part 2 (Improvement 1): HF Llama2 (replace `nn.Linear` with `TE.Linear` | precision: `FP8`)
[HuggingFace accelerate](https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/transformer_engine.py) provides a way to replace `torch.nn.Linear` layers with `transformer_engine.pytorch.module.Linear` layers, which can be run in `FP8` precision. This is the most straightforward way to use Transformer-Engine's `FP8` precision when training/finetuning the HF Llama model.

#### Part 3 (Improvement 2): TE Llama2 (replace `LlamaDecoderLayer` with `TE.TransformerLayer` | precision: `BF16`)
```
ModuleList(
  (0-31): 32 x LlamaDecoderLayer(
    (self_attn): LlamaAttention(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (mlp): LlamaMLP(
      (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
      (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
      (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): LlamaRMSNorm()
    (post_attention_layernorm): LlamaRMSNorm()
  )
)
```

A major portion of the HuggingFace model implementation (`LlamaDecoderLayer`) can be replaced with TransformerEngine's implementation (`TransformerLayer`, wrapped in `TELlamaDecoderLayer`).

```
ModuleList(
  (0-31): 32 x TELlamaDecoderLayer(
    (self_attention): MultiheadAttention(
      (layernorm_qkv): LayerNormLinear()
      (core_attention): DotProductAttention(
        (flash_attention): FlashAttention()
        (fused_attention): FusedAttention()
        (unfused_attention): UnfusedDotProductAttention(
          (scale_mask_softmax): FusedScaleMaskSoftmax()
          (attention_dropout): Dropout(p=0, inplace=False)
        )
      )
      (proj): Linear()
    )
    (layernorm_mlp): LayerNormMLP()
  )
)
```

This needs some special care when loading the weights from the Llama2 model checkpoint, because the weight keys are named differently in the HF and TE implementations. This tutorial comes with a `te_llama.py` file that contains the necessary reference wrappers for replacing the larger model chunk and also handles the mapping of weight keys between the two layer implementations.
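
The key mapping can be sketched as a pure function from an HF checkpoint key to a TE parameter name. This is a hypothetical helper, not the code in `te_llama.py`, and the TE-side names (`layernorm_qkv.query_weight`, `layernorm_mlp.fc1_weight`, ...) are an assumption about how TE names parameters when `fuse_qkv_params=False`; check them against `model.state_dict().keys()` for your TE version. Note that `gate_proj` and `up_proj` both land in the single, twice-as-wide `fc1_weight` of `LayerNormMLP`, so the mapping must also record which half each one fills; the gate-first order is likewise an assumption.

```python
import re
from typing import Optional, Tuple

# Hypothetical HF -> TE name table for one decoder layer (verify against your TE version).
_LAYER_KEY_MAP = {
    "input_layernorm.weight": ("self_attention.layernorm_qkv.layer_norm_weight", None),
    "self_attn.q_proj.weight": ("self_attention.layernorm_qkv.query_weight", None),
    "self_attn.k_proj.weight": ("self_attention.layernorm_qkv.key_weight", None),
    "self_attn.v_proj.weight": ("self_attention.layernorm_qkv.value_weight", None),
    "self_attn.o_proj.weight": ("self_attention.proj.weight", None),
    "post_attention_layernorm.weight": ("layernorm_mlp.layer_norm_weight", None),
    # gate_proj and up_proj are stacked row-wise into one fc1 weight (assumed gate first).
    "mlp.gate_proj.weight": ("layernorm_mlp.fc1_weight", "first_half"),
    "mlp.up_proj.weight": ("layernorm_mlp.fc1_weight", "second_half"),
    "mlp.down_proj.weight": ("layernorm_mlp.fc2_weight", None),
}

_LAYER_PREFIX = re.compile(r"^(model\.layers\.\d+\.)(.+)$")


def hf_to_te_key(hf_key: str) -> Tuple[Optional[str], Optional[str]]:
    """Map an HF Llama key to (TE key, which rows of the TE tensor it fills)."""
    match = _LAYER_PREFIX.match(hf_key)
    if match is None:
        # Embeddings, the final norm and lm_head sit outside the decoder layers and keep their names.
        return hf_key, None
    prefix, suffix = match.groups()
    if suffix not in _LAYER_KEY_MAP:
        # e.g. a cached rotary `inv_freq` buffer that has no TE counterpart
        return None, None
    te_suffix, rows = _LAYER_KEY_MAP[suffix]
    return prefix + te_suffix, rows
```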

[NOTE: This improvement still runs in `BF16` for reference and the next improvement is switching on `FP8` precision in addition to this improvement]

#### Part 4 (Improvement 3): TE Llama2 (replace `LlamaDecoderLayer` with `TE.TransformerLayer` | precision: `FP8`)
Same as Improvement 2, but the precision is `FP8`.


#### Benchmarks Preview
A discussion is provided at the end of the tutorial.

| Models                                                      | Precision | Step time (ms per batch) | Speedup (over baseline) |
|-------------------------------------------------------------|-----------|--------------------------|-------------------------|
| HF (baseline)                                               | BF16      | 288                      | 1                       |
| HF (replace `nn.Linear` with `TE.Linear`)                   | FP8       | 307                      | 0.94                    |
| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16      | 243                      | 1.19                    |
| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8       | 231                      | 1.24                    |

## Necessary imports

```python
import time
import sys

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader

from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup, AutoConfig
from transformers import DataCollatorForLanguageModeling
from datasets import load_dataset
from accelerate import Accelerator
from accelerate.utils.dataclasses import FP8RecipeKwargs
```



```python
def get_dataloaders(accelerator: Accelerator, batch_size: int = 8):
    # Uses the globals `dataset_name`, `model_name` and `max_seq_length` defined in the next cell
    dataset = load_dataset(dataset_name, split="train")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize(element):
        outputs = tokenizer(
            element["text"],
            truncation=True,
            padding=False,
            max_length=max_seq_length,
            return_overflowing_tokens=False,
            return_length=False
        )
        return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}

    with accelerator.main_process_first():
        dataset = dataset.map(
            tokenize,
            batched=True,
            remove_columns=dataset.column_names
        )

    # Pad sequence lengths to multiples of 16 for FP8 and of 8 for other mixed-precision modes
    if accelerator.mixed_precision == "fp8":
        pad_to_multiple_of = 16
    elif accelerator.mixed_precision != "no":
        pad_to_multiple_of = 8
    else:
        pad_to_multiple_of = None

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=pad_to_multiple_of,
    )

    dataloader_params = {
        "batch_size": batch_size,
        "collate_fn": data_collator,
        "drop_last": True,
    }
    train_dataloader = DataLoader(dataset, **dataloader_params)
    return train_dataloader
```
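
`DataCollatorForLanguageModeling` pads each batch to its longest sequence and then rounds that length up to the next multiple of `pad_to_multiple_of`. The hypothetical helper below (not part of `transformers`) reproduces just that rounding. Multiples of 8 or 16 keep the sequence dimension aligned with what tensor-core GEMM kernels prefer, and TE's FP8 path additionally checks tensor dimensions for divisibility, which is presumably why the FP8 branch uses 16.

```python
from typing import Optional


def padded_length(longest: int, pad_to_multiple_of: Optional[int] = None) -> int:
    """Length a batch is padded to: its longest sequence, rounded up to a multiple if requested."""
    if not pad_to_multiple_of:
        return longest
    return -(-longest // pad_to_multiple_of) * pad_to_multiple_of  # ceiling division


if __name__ == "__main__":
    for multiple in (None, 8, 16):
        print(multiple, padded_length(241, multiple))  # 241, 248, 256
```

With `max_seq_length = 256`, most batches in this tutorial are truncated to exactly 256 tokens, which is already a multiple of both 8 and 16.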

#### Default Hyperparameters

```python
mixed_precision = "bf16"
model_name = "" # <== Add model weight location here
dataset_name = "timdettmers/openassistant-guanaco"
dataset_text_field = "text"
learning_rate = 1.41e-5
batch_size = 8
max_seq_length = 256
gradient_accumulation_steps = 1
num_training_steps=10
```

## [Baseline] Running HF `LlamaModel` (Precision: `BF16`)

Llama2 weights are loaded into the HuggingFace native implementation `LlamaForCausalLM` (refer to [modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)). 

`batch_size` is `8` and precision is `BF16`

To recap, the `LlamaDecoderLayer` is left unchanged in the baseline as follows:

![LlamaDecoderLayer](media/llamadecoderlayer.png "LlamaDecoderLayer")

```python
# Init the model
config = AutoConfig.from_pretrained(model_name)
config._attn_implementation = "flash_attention_2"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.bfloat16,
)
# Needed for the cases when using TELlamaForCausalLM. So adding here for 1:1 comparison
model.config.use_cache=False

batch_size = 8
```

```
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.26s/it]
```


```python
# Init HF accelerator that's used for training
accelerator = Accelerator(log_with="wandb", gradient_accumulation_steps=gradient_accumulation_steps, mixed_precision=mixed_precision)
accelerator.print(f'State: {accelerator.state}')
train_dataloader = get_dataloaders(accelerator, batch_size)

# Wrap model, optimizer/scheduler, dataloaders in accelerate
optimizer = AdamW(params = model.parameters(), lr=learning_rate)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=num_training_steps,
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)
accelerator.init_trackers("fp8-benchmarks", config={
    "model_name": model_name,
    "dataset_name": dataset_name,
    "batch_size": batch_size,
    "accelerator_state": accelerator.state,
    "mixed_precision": accelerator.mixed_precision,
},
init_kwargs={"wandb": {"name": f'{accelerator.mixed_precision}_bs_{batch_size}_{accelerator.num_processes}_gpus'}})
```

```
State: Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: bf16
```
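
One detail of this setup is worth noting: the scheduler is created with `num_warmup_steps=100` while the benchmark runs only `num_training_steps=10`, so the run never leaves the warmup phase and the learning rate used by those steps stays below 10% of `learning_rate`. This does not affect step time, but it matters if the loop is reused for real finetuning. The function below re-implements the linear warmup/decay multiplier that `get_linear_schedule_with_warmup` applies, as a sketch for checking the numbers rather than the library code.

```python
def linear_warmup_multiplier(step: int, num_warmup_steps: int, num_training_steps: int) -> float:
    """LR multiplier: linear ramp from 0 to 1 over warmup, then linear decay to 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


if __name__ == "__main__":
    learning_rate = 1.41e-5
    for step in range(10):
        print(step, learning_rate * linear_warmup_multiplier(step, 100, 10))
```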





```python
model.model.layers
```

```
ModuleList(
  (0-31): 32 x LlamaDecoderLayer(
    (self_attn): LlamaFlashAttention2(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (mlp): LlamaMLP(
      (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
      (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
      (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): LlamaRMSNorm()
    (post_attention_layernorm): LlamaRMSNorm()
  )
)
```

```python
# Fine-tune the model
model.train()
completed_steps = 0
total_loss = 0
optimizer.zero_grad()

for _ in range(10):
    if completed_steps >= num_training_steps:
        break
    for step, batch in enumerate(train_dataloader):
        start_time = time.time()
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            print(f"Step {step}: loss: {loss.item()}, batch shape: {batch['input_ids'].shape}, peak gpu mem: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
            total_loss += loss.detach().float()
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if accelerator.sync_gradients:
            completed_steps += 1

        end_time = time.time()
        total_time = end_time - start_time
        print(f"Step {step} time {total_time}")
        accelerator.log({"batch_time": total_time, "input_ids": batch["input_ids"].cpu().numpy(), "attention_mask": batch["attention_mask"].cpu().numpy()})
        start_time = end_time

        if completed_steps >= num_training_steps:
            break

accelerator.end_training()
```

```
Step 0: loss: 1.742400050163269, batch shape: torch.Size([8, 256]), peak gpu mem: 25.39 GB
Step 0 time 0.7563114166259766
Step 1: loss: 2.0216784477233887, batch shape: torch.Size([8, 256]), peak gpu mem: 63.09 GB
Step 1 time 0.2901787757873535
Step 2: loss: 2.2615396976470947, batch shape: torch.Size([8, 256]), peak gpu mem: 63.09 GB
Step 2 time 0.2868986129760742
Step 3: loss: 2.092174530029297, batch shape: torch.Size([8, 256]), peak gpu mem: 63.09 GB
Step 3 time 0.289764404296875
Step 4: loss: 2.0941758155822754, batch shape: torch.Size([8, 256]), peak gpu mem: 63.09 GB
Step 4 time 0.28885555267333984
Step 5: loss: 2.010141372680664, batch shape: torch.Size([8, 256]), peak gpu mem: 63.09 GB
Step 5 time 0.287872314453125
Step 6: loss: 2.3474209308624268, batch shape: torch.Size([8, 256]), peak gpu mem: 63.09 GB
Step 6 time 0.2875847816467285
Step 7: loss: 1.8907099962234497, batch shape: torch.Size([8, 256]), peak gpu mem: 63.09 GB
Step 7 time 0.28859663009643555
Step 8: loss: 1.920
```
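
The step time reported in the benchmark table can be approximately reproduced from this log. The first iteration includes one-off costs (CUDA initialization, allocator and kernel warm-up), so it is excluded; averaging the remaining printed step times gives about 288 ms, in line with the baseline row. The table may well have been computed over a longer run, so treat this only as a consistency check.

```python
from statistics import mean

# Step times (seconds) copied from the baseline log above; step 0 is warm-up.
baseline_step_times = [
    0.7563114166259766,
    0.2901787757873535,
    0.2868986129760742,
    0.289764404296875,
    0.28885555267333984,
    0.287872314453125,
    0.2875847816467285,
    0.28859663009643555,
]


def steady_state_ms(step_times, warmup_steps: int = 1) -> float:
    """Mean step time in milliseconds, skipping the first `warmup_steps` iterations."""
    return 1000.0 * mean(step_times[warmup_steps:])


if __name__ == "__main__":
    print(f"{steady_state_ms(baseline_step_times):.1f} ms per batch")
```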

## [Improvement 1] Replace `nn.Linear` with TE's `Linear` layers (Precision: `FP8`)

[HuggingFace accelerate](https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/transformer_engine.py) provides a way to replace `torch.nn.Linear` layers with `transformer_engine.pytorch.module.Linear` layers as shown in the diagram below: 

![Replace `nn.Linear` with `TE.Linear`](media/llamadecoderlayer_replace_with_telinear.png "llamadecoderlayer_replace_with_telinear")


This is the most straightforward way to use Transformer-Engine's `FP8` precision when training/finetuning the HF Llama model. Notice that the `LlamaDecoderLayer` is left mostly unchanged; only its `nn.Linear` layers are replaced with `TE.Linear` layers. After the replacement, these layers can be run in `FP8` precision.
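
What makes `FP8` usable is per-tensor scaling. FP8 has a tiny dynamic range (the E4M3 format tops out at 448, E5M2 at 57344), so each tensor is multiplied by a scale factor before the cast, chosen so that its largest magnitude lands near the top of the representable range. TE's "delayed scaling" recipe derives that factor from a history of recent absolute maxima (amax) rather than from the current tensor. The pure-Python sketch below shows only the scale computation under those assumptions (a power-of-two `margin` as a safety factor, amax taken as the maximum over the history); it does not emulate FP8 rounding, and the real recipe has more options.

```python
from typing import List

E4M3_MAX = 448.0    # largest finite FP8 E4M3 value
E5M2_MAX = 57344.0  # largest finite FP8 E5M2 value


def delayed_scale(amax_history: List[float], fp8_max: float = E4M3_MAX,
                  margin: int = 0, prev_scale: float = 1.0) -> float:
    """Scale that maps max(amax_history) to fp8_max / 2**margin.

    Keeps the previous scale when the history holds no usable (positive, finite) amax.
    """
    amax = max(amax_history, default=0.0)
    if not (amax > 0.0 and amax != float("inf")):
        return prev_scale
    return (fp8_max / amax) / (2 ** margin)


def quantize_clamped(x: float, scale: float, fp8_max: float = E4M3_MAX) -> float:
    """Scale, then saturate to the FP8 range (real FP8 would also round the mantissa)."""
    return max(-fp8_max, min(fp8_max, x * scale))


if __name__ == "__main__":
    scale = delayed_scale([1.0, 4.0, 2.0])
    print(scale, quantize_clamped(3.0, scale))  # 112.0 336.0
```

Because the scale lags behind the data, a sudden spike larger than anything in the history saturates at `fp8_max`; the history length and `margin` trade robustness against precision.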



```python
# Init the model
config = AutoConfig.from_pretrained(model_name)
config._attn_implementation = "flash_attention_2"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.bfloat16,
)
# Needed for the cases when using TELlamaForCausalLM. So adding here for 1:1 comparison
model.config.use_cache=False

batch_size = 8
mixed_precision="fp8"
```

```
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.24s/it]
```


```python
# Init HF accelerator that's used for training.
# Notice the use of `fp8` recipe
fp8_kwarg_handler = [FP8RecipeKwargs(backend="te")] if mixed_precision == "fp8" else None
accelerator = Accelerator(
    log_with="wandb", gradient_accumulation_steps=gradient_accumulation_steps,
    mixed_precision=mixed_precision,
    kwargs_handlers=fp8_kwarg_handler
)
accelerator.print(f'State: {accelerator.state}')

# Wrap model, optimizer/lr-scheduler, dataloaders in accelerator
train_dataloader = get_dataloaders(accelerator, batch_size)
optimizer = AdamW(params = model.parameters(), lr=learning_rate)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=num_training_steps,
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)
accelerator.init_trackers("fp8-benchmarks", config={
    "model_name": model_name,
    "dataset_name": dataset_name,
    "batch_size": batch_size,
    "accelerator_state": accelerator.state,
    "mixed_precision": accelerator.mixed_precision,
},
init_kwargs={"wandb": {"name": f'{accelerator.mixed_precision}_bs_{batch_size}_{accelerator.num_processes}_gpus'}})
```

```
State: Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp8
```





```python
model.model.layers
```

```
ModuleList(
  (0-31): 32 x LlamaDecoderLayer(
    (self_attn): LlamaFlashAttention2(
      (q_proj): Linear()
      (k_proj): Linear()
      (v_proj): Linear()
      (o_proj): Linear()
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (mlp): LlamaMLP(
      (gate_proj): Linear()
      (up_proj): Linear()
      (down_proj): Linear()
      (act_fn): SiLU()
    )
    (input_layernorm): LlamaRMSNorm()
    (post_attention_layernorm): LlamaRMSNorm()
  )
)
```

```python
# Train the model
model.train()
completed_steps = 0
total_loss = 0
optimizer.zero_grad()

for _ in range(10):
    if completed_steps >= num_training_steps:
        break
    for step, batch in enumerate(train_dataloader):
        start_time = time.time()
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            print(f"Step {step}: loss: {loss.item()}, batch shape: {batch['input_ids'].shape}, peak gpu mem: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
            total_loss += loss.detach().float()
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if accelerator.sync_gradients:
            completed_steps += 1

        end_time = time.time()
        total_time = end_time - start_time
        print(f"Step {step} time {total_time}")
        accelerator.log({"batch_time": total_time, "input_ids": batch["input_ids"].cpu().numpy(), "attention_mask": batch["attention_mask"].cpu().numpy()})
        start_time = end_time

        if completed_steps >= num_training_steps:
            break

accelerator.end_training()
```

```
Step 0: loss: 1.7408778667449951, batch shape: torch.Size([8, 256]), peak gpu mem: 30.95 GB
Step 0 time 1.494196891784668
Step 1: loss: 2.0046780109405518, batch shape: torch.Size([8, 256]), peak gpu mem: 68.52 GB
Step 1 time 0.926837682723999
Step 2: loss: 2.2610206604003906, batch shape: torch.Size([8, 256]), peak gpu mem: 68.52 GB
Step 2 time 0.30959415435791016
Step 3: loss: 2.0951247215270996, batch shape: torch.Size([8, 256]), peak gpu mem: 68.55 GB
Step 3 time 0.3102715015411377
Step 4: loss: 2.0994796752929688, batch shape: torch.Size([8, 256]), peak gpu mem: 68.55 GB
Step 4 time 0.30820536613464355
Step 5: loss: 2.005664587020874, batch shape: torch.Size([8, 256]), peak gpu mem: 68.55 GB
Step 5 time 0.3083922863006592
Step 6: loss: 2.3457205295562744, batch shape: torch.Size([8, 256]), peak gpu mem: 68.55 GB
Step 6 time 0.30889034271240234
Step 7: loss: 1.8973335027694702, batch shape: torch.Size([8, 256]), peak gpu mem: 68.55 GB
Step 7 time 0.30847978591918945
Step 8: loss: 1
```

## [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)

In addition to basic layers like `Linear` and `LayerNorm`, Transformer-Engine offers larger modules such as `MultiheadAttention` and `LayerNormMLP` that can replace their counterparts in `LlamaDecoderLayer` and potentially provide more speedup. Transformer-Engine also offers a complete `TransformerLayer` that can be substituted for `LlamaDecoderLayer` as a whole.

More concretely, here's how `LlamaDecoderLayer`s form the core of the decoder layer stack in HF's llama implementation:
```
ModuleList(
  (0-31): 32 x LlamaDecoderLayer(
    (self_attn): LlamaAttention(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (mlp): LlamaMLP(
      (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
      (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
      (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): LlamaRMSNorm()
    (post_attention_layernorm): LlamaRMSNorm()
  )
)
```

As is clear above, a major portion of the HuggingFace model implementation (`LlamaDecoderLayer`) could be replaced with TransformerEngine's implementation.

This is a bit more involved, and the accompanying `te_llama.py` file provides a reference implementation.
Briefly,
1. `TELlamaDecoderLayer` is added as a wrapper for `TransformerLayer`.
2. Before creating a `LlamaForCausalLM`, `LlamaDecoderLayer` is monkey-patched with `TELlamaDecoderLayer`.
3. A custom `from_pretrained_local` method is added that copies the weights from the checkpoint (which is meant for the HF Llama implementation) into the modified `TELlamaForCausalLM` by carefully mapping the weights of `LlamaDecoderLayer` (HF) to `TransformerLayer` (TE). Refer to the following diagram for more details.

![Replace `LlamaDecoderLayer` with `TransformerLayer`](media/weight_swap.png "weight swap")
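
The monkey-patching in step 2 above boils down to temporarily rebinding a class name in the module where `LlamaModel` looks it up, then restoring it. The sketch below shows that pattern as a context manager. It uses a stand-in module built with `types.ModuleType` so that it runs without `transformers`; with the real library the target would be `transformers.models.llama.modeling_llama` and the attribute `LlamaDecoderLayer`. The helper name `swap_attr` and the stand-in classes are hypothetical, not the code in `te_llama.py`.

```python
import types
from contextlib import contextmanager


@contextmanager
def swap_attr(module, name, replacement):
    """Temporarily replace `module.<name>`; always restore the original on exit."""
    original = getattr(module, name)
    setattr(module, name, replacement)
    try:
        yield
    finally:
        setattr(module, name, original)


# Stand-in for transformers.models.llama.modeling_llama (illustration only).
modeling = types.ModuleType("fake_modeling_llama")


class HFDecoderLayer:
    pass


class TEDecoderLayer:
    pass


modeling.LlamaDecoderLayer = HFDecoderLayer


def build_model(num_layers: int):
    # Like a model constructor, resolve the layer class through the module at build time.
    return [modeling.LlamaDecoderLayer() for _ in range(num_layers)]


if __name__ == "__main__":
    with swap_attr(modeling, "LlamaDecoderLayer", TEDecoderLayer):
        patched = build_model(2)
    print([type(layer).__name__ for layer in patched])  # ['TEDecoderLayer', 'TEDecoderLayer']
    print(modeling.LlamaDecoderLayer.__name__)          # HFDecoderLayer
```

Restoring the original class in `finally` means that models built later in the same process, or after an exception during loading, are not silently patched.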


After initializing the modified Llama model this way, the core decoder layers get changed to `TELlamaDecoderLayer` (wrapper around `TransformerLayer`) as shown in the following output:
```
ModuleList(
  (0-31): 32 x TELlamaDecoderLayer(
    (self_attention): MultiheadAttention(
      (layernorm_qkv): LayerNormLinear()
      (core_attention): DotProductAttention(
        (flash_attention): FlashAttention()
        (fused_attention): FusedAttention()
        (unfused_attention): UnfusedDotProductAttention(
          (scale_mask_softmax): FusedScaleMaskSoftmax()
          (attention_dropout): Dropout(p=0, inplace=False)
        )
      )
      (proj): Linear()
    )
    (layernorm_mlp): LayerNormMLP()
  )
)
```
In summary, the model gets changed as follows with a bigger chunk of the implementation (core decoder layers) coming from TransformerEngine.

![model change in action](media/model_change.png "model change")

[NOTE: This implementation still runs in `BF16`]

```python
# Init the model
from te_llama import TELlamaForCausalLM
config = AutoConfig.from_pretrained(model_name)
model = TELlamaForCausalLM.from_pretrained_local(
    model_name,
    config=config,
    torch_dtype=torch.bfloat16,
)
# Needed for the cases when using TELlamaForCausalLM
model.config.use_cache=False

batch_size = 8
```


```python
# Init HF accelerator that's used for training.
accelerator = Accelerator(
    log_with="wandb",
    gradient_accumulation_steps=gradient_accumulation_steps,
    mixed_precision=mixed_precision
)
accelerator.print(f'State: {accelerator.state}')

# Wrap model, optimizer/lr_scheduler, dataloaders in accelerator
train_dataloader = get_dataloaders(accelerator, batch_size)
optimizer = AdamW(params = model.parameters(), lr=learning_rate)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=num_training_steps,
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)
accelerator.init_trackers("fp8-benchmarks", config={
    "model_name": model_name,
    "dataset_name": dataset_name,
    "batch_size": batch_size,
    "accelerator_state": accelerator.state,
    "mixed_precision": accelerator.mixed_precision,
},
init_kwargs={"wandb": {"name": f'{accelerator.mixed_precision}_bs_{batch_size}_{accelerator.num_processes}_gpus'}})
```

```
State: Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: bf16
```





```python
model.model.layers
```

```
ModuleList(
  (0-31): 32 x TELlamaDecoderLayer(
    (self_attention): MultiheadAttention(
      (layernorm_qkv): LayerNormLinear()
      (core_attention): DotProductAttention(
        (flash_attention): FlashAttention()
        (fused_attention): FusedAttention()
        (unfused_attention): UnfusedDotProductAttention(
          (scale_mask_softmax): FusedScaleMaskSoftmax()
          (attention_dropout): Dropout(p=0, inplace=False)
        )
      )
      (proj): Linear()
    )
    (layernorm_mlp): LayerNormMLP()
  )
)
```

```python
# Train the model
model.train()
completed_steps = 0
total_loss = 0
optimizer.zero_grad()

for _ in range(10):
    if completed_steps >= num_training_steps:
        break
    for step, batch in enumerate(train_dataloader):
        start_time = time.time()
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            print(f"Step {step}: loss: {loss.item()}, batch shape: {batch['input_ids'].shape}, peak gpu mem: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
            total_loss += loss.detach().float()
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if accelerator.sync_gradients:
            completed_steps += 1

        end_time = time.time()
        total_time = end_time - start_time
        print(f"Step {step} time {total_time}")
        accelerator.log({"batch_time": total_time, "input_ids": batch["input_ids"].cpu().numpy(), "attention_mask": batch["attention_mask"].cpu().numpy()})
        start_time = end_time

        if completed_steps >= num_training_steps:
            break

accelerator.end_training()
```

```
Step 0: loss: 2.285313367843628, batch shape: torch.Size([8, 256]), peak gpu mem: 24.39 GB
Step 0 time 1.4395570755004883
Step 1: loss: 3.433112144470215, batch shape: torch.Size([8, 256]), peak gpu mem: 63.13 GB
Step 1 time 0.2440946102142334
Step 2: loss: 4.062510967254639, batch shape: torch.Size([8, 256]), peak gpu mem: 63.13 GB
Step 2 time 0.24254989624023438
Step 3: loss: 3.3462915420532227, batch shape: torch.Size([8, 256]), peak gpu mem: 63.13 GB
Step 3 time 0.24259161949157715
Step 4: loss: 3.0718109607696533, batch shape: torch.Size([8, 256]), peak gpu mem: 63.13 GB
Step 4 time 0.24408507347106934
Step 5: loss: 3.7262790203094482, batch shape: torch.Size([8, 256]), peak gpu mem: 63.13 GB
Step 5 time 0.24119019508361816
Step 6: loss: 5.115753650665283, batch shape: torch.Size([8, 256]), peak gpu mem: 63.13 GB
Step 6 time 0.24044084548950195
Step 7: loss: 3.9757368564605713, batch shape: torch.Size([8, 256]), peak gpu mem: 63.13 GB
Step 7 time 0.24254655838012695
Step 8: loss: 
```

## [Improvement 3] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)

Now that most of the HF Llama model implementation (the `LlamaDecoderLayer`s) has been swapped out for the Transformer-Engine implementation (`TELlamaDecoderLayer`, built on TE's `TransformerLayer`), let's see how much `FP8` training improves performance.
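The `fp8` mixed-precision mode used below relies on a scaling recipe. Before each cast, a tensor is multiplied by a scale factor so that its values fit the narrow range of an FP8 format. The sketch below is a simplified, pure-Python illustration of per-tensor *delayed* scaling. It assumes the scale is derived from the maximum absolute value (amax) recorded in earlier iterations. The helper names and the `margin` handling are our own and are not Transformer-Engine's actual code. The constants are the largest finite values of the E4M3 and E5M2 FP8 formats, and rounding to the FP8 grid is omitted.

```python
# Simplified illustration of per-tensor delayed scaling for FP8.
# Hypothetical helpers for illustration only, NOT Transformer-Engine's implementation.

E4M3_MAX = 448.0     # largest finite FP8 E4M3 value (typically used for forward tensors)
E5M2_MAX = 57344.0   # largest finite FP8 E5M2 value (typically used for gradients)


def delayed_scale(amax_history, fp8_max=E4M3_MAX, margin=0):
    """Scale factor computed from amax values observed in *previous* iterations."""
    amax = max(amax_history)
    if amax == 0.0:
        return 1.0  # nothing observed yet: leave values unscaled
    return fp8_max / amax / (2 ** margin)


def scale_and_saturate(values, scale, fp8_max=E4M3_MAX):
    """Multiply by the scale and clip to the FP8 range (mantissa rounding omitted)."""
    return [max(-fp8_max, min(fp8_max, v * scale)) for v in values]


history = [1.0, 3.5, 2.0]                    # amax of this tensor in earlier steps
s = delayed_scale(history)                   # 448 / 3.5
print(s)                                     # 128.0
print(scale_and_saturate([3.5, -1.0], s))    # [448.0, -128.0]
# A value larger than anything in the history saturates at the FP8 maximum.
print(scale_and_saturate([7.0], s))          # [448.0]
```

Because the scale comes from past amax values, a sudden outlier can saturate. Transformer-Engine's `DelayedScaling` recipe exposes parameters such as `amax_history_len` and `margin` to tune this trade-off.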

In [4]:
# Init the model
from te_llama import TELlamaForCausalLM
config = AutoConfig.from_pretrained(model_name)

model = TELlamaForCausalLM.from_pretrained_local(
        model_name,
        config=config,
        torch_dtype=torch.bfloat16,
)
# Needed for the cases when using TELlamaForCausalLM
model.config.use_cache=False

batch_size = 8
mixed_precision="fp8"

In [5]:
# Init HF accelerator that's used for training. 
# Notice the use of `fp8` recipe
fp8_kwarg_handler = [FP8RecipeKwargs(backend="te")] if mixed_precision == "fp8" else None
accelerator = Accelerator(
    log_with="wandb", gradient_accumulation_steps=gradient_accumulation_steps, 
    mixed_precision=mixed_precision, 
    kwargs_handlers=fp8_kwarg_handler
)
accelerator.print(f'State: {accelerator.state}')

# Wrap model, optimizer/lr-scheduler, dataloaders in accelerator
train_dataloader = get_dataloaders(accelerator, batch_size)
optimizer = AdamW(params = model.parameters(), lr=learning_rate)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=num_training_steps,
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)
accelerator.init_trackers("fp8-benchmarks", config={
    "model_name": model_name,
    "dataset_name": dataset_name,
    "batch_size": batch_size,
    "accelerator_state": accelerator.state,
    "mixed_precision": accelerator.mixed_precision,
},
init_kwargs={"wandb": {"name": f'{accelerator.mixed_precision}_bs_{batch_size}_{accelerator.num_processes}_gpus'}})




State: Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp8





In [6]:
model.model.layers

ModuleList(
  (0-31): 32 x TELlamaDecoderLayer(
    (self_attention): MultiheadAttention(
      (layernorm_qkv): LayerNormLinear()
      (core_attention): DotProductAttention(
        (flash_attention): FlashAttention()
        (fused_attention): FusedAttention()
        (unfused_attention): UnfusedDotProductAttention(
          (scale_mask_softmax): FusedScaleMaskSoftmax()
          (attention_dropout): Dropout(p=0, inplace=False)
        )
      )
      (proj): Linear()
    )
    (layernorm_mlp): LayerNormMLP()
  )
)

In [7]:
# Train the model
model.train()
completed_steps = 0
total_loss = 0
optimizer.zero_grad()

for _ in range(10):
    if completed_steps >= num_training_steps:
        break
    for step, batch in enumerate(train_dataloader):
        start_time = time.time()
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            print(f"Step {step}: loss: {loss.item()}, batch shape: {batch['input_ids'].shape}, peak gpu mem: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
            total_loss += loss.detach().float()
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if accelerator.sync_gradients:
            completed_steps += 1

        end_time = time.time()
        total_time = end_time - start_time
        print(f"Step {step} time {total_time}")
        accelerator.log({"batch_time": total_time, "input_ids": batch["input_ids"].cpu().numpy(), "attention_mask": batch["attention_mask"].cpu().numpy()})
        start_time = end_time

        if completed_steps >= num_training_steps:
            break

accelerator.end_training()

Step 0: loss: 2.2819559574127197, batch shape: torch.Size([8, 256]), peak gpu mem: 26.61 GB
Step 0 time 3.093244791030884
Step 1: loss: 3.3127150535583496, batch shape: torch.Size([8, 256]), peak gpu mem: 64.54 GB
Step 1 time 0.7673726081848145
Step 2: loss: 3.993682384490967, batch shape: torch.Size([8, 256]), peak gpu mem: 64.54 GB
Step 2 time 0.23313021659851074
Step 3: loss: 3.22248911857605, batch shape: torch.Size([8, 256]), peak gpu mem: 64.54 GB
Step 3 time 0.23214364051818848
Step 4: loss: 3.0897724628448486, batch shape: torch.Size([8, 256]), peak gpu mem: 64.54 GB
Step 4 time 0.23146915435791016
Step 5: loss: 3.6932754516601562, batch shape: torch.Size([8, 256]), peak gpu mem: 64.54 GB
Step 5 time 0.23121118545532227
Step 6: loss: 5.0839033126831055, batch shape: torch.Size([8, 256]), peak gpu mem: 64.54 GB
Step 6 time 0.23106837272644043
Step 7: loss: 3.907317638397217, batch shape: torch.Size([8, 256]), peak gpu mem: 64.54 GB
Step 7 time 0.23175501823425293
Step 8: loss: 3

## Benchmarks (revisited)
Let's take a look at the summary of the performance numbers with various configurations.

| Model                                                       | Precision | Step time (ms per batch)    | Speedup (over baseline) |
|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|
| HF (baseline)                                               | BF16      | 288                         | 1                       |
| HF (replace `nn.Linear` with `TE.Linear`)                   | FP8       | 307                         | 0.94                    |
| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16      | 243                         | 1.19                    |
| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8       | 231                         | 1.24                    |

Replacing larger chunks of the model with Transformer-Engine layers yields larger performance improvements. Enabling `FP8` precision on top of that gives an even larger speedup, as expected.
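The speedup column is the baseline step time divided by each configuration's step time. The sketch below shows one way to derive a steady-state step time from the per-step timings printed in the FP8 run above. It assumes the first two steps are warm-up (step 1 of that run is still noticeably slower) and discards them. The tutorial does not state exactly how its table values were averaged, so the helper names and the warm-up cutoff are our own choices.

```python
from statistics import mean

# Per-step wall-clock times (seconds) printed by the FP8 TransformerLayer run above.
fp8_step_times = [
    3.093244791030884, 0.7673726081848145, 0.23313021659851074,
    0.23214364051818848, 0.23146915435791016, 0.23121118545532227,
    0.23106837272644043, 0.23175501823425293,
]

BASELINE_MS = 288  # HF baseline (BF16) step time from the table


def steady_state_ms(step_times_s, warmup_steps=2):
    """Mean step time in ms, ignoring the first `warmup_steps` steps."""
    return mean(step_times_s[warmup_steps:]) * 1000.0


def speedup(baseline_ms, step_ms):
    """Values above 1 mean faster than the baseline."""
    return baseline_ms / step_ms


step_ms = steady_state_ms(fp8_step_times)
print(f"{step_ms:.1f} ms, speedup {speedup(BASELINE_MS, step_ms):.2f}x")
# 231.8 ms, speedup 1.24x
```

A speedup below 1, as in the `TE.Linear` row (288 / 307 ≈ 0.94), means that configuration is slower than the baseline.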

### Understanding the aberration with Improvement 1 (CPU Overheads)
When the `nn.Linear` layers in the model are naively replaced with `TE.Linear` layers, performance drops (the step time rises from 288 ms to 307 ms). Although this seems like a bummer, it's not unexpected.

1. `TE.Linear` has to do some extra work in Python (i.e., on the CPU) before it can issue the `FP8` versions of the GEMMs (the matrix multiplies in the linear layers) to the GPU.
2. If the batch size and/or sequence length are small, the GEMMs are small too, so the GPU usually finishes each one before the CPU can issue the next. In this case, we say the workload is CPU bound. Ideally, we want the workload to be GPU bound, i.e., the GPU should stay busy so that its capabilities are fully utilized.
3. We can verify this by looking at the profiles of the two workloads (captured with Nsight Systems).
    - For the baseline case (with `nn.Linear` layers):
        ![baseline profile](media/baseline.png "baseline")
        
    - Improvement 1 - when `nn.Linear`s are replaced with `TE.Linear`s:
        ![replace linears](media/replace_nnlinear_with_telinear.png "replace linears")
        
    - As we can see, the GPU is idle more often in the second case than in the first: the CPU overhead is not letting the GPU stay fully busy.
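The CPU-bound versus GPU-bound distinction can be illustrated with a toy model of asynchronous kernel launches. The CPU spends a fixed time preparing and issuing each GEMM. The GPU can start a GEMM only once it has been issued and the previous one has finished. This is a deliberately simplified sketch that ignores launch-queue limits, synchronization points, and overlapping kernels. The millisecond figures are made-up illustrative values, not measurements from this tutorial.

```python
def simulate(num_kernels, cpu_issue_ms, gpu_exec_ms):
    """Toy model of async launches. Returns (total time in ms, GPU utilization)."""
    cpu_t = 0.0      # time at which the CPU has issued the current kernel
    gpu_t = 0.0      # time at which the GPU finishes its most recent kernel
    gpu_busy = 0.0   # total time the GPU spends executing kernels
    for _ in range(num_kernels):
        cpu_t += cpu_issue_ms          # CPU-side work (e.g. Python overhead) per kernel
        start = max(cpu_t, gpu_t)      # GPU waits until the kernel is issued and it is free
        gpu_t = start + gpu_exec_ms
        gpu_busy += gpu_exec_ms
    return gpu_t, gpu_busy / gpu_t


# Small GEMMs: each finishes before the CPU can issue the next, so the run is CPU bound.
print(simulate(100, cpu_issue_ms=0.05, gpu_exec_ms=0.03))
# Large GEMMs: the GPU is the bottleneck, so the run is GPU bound (utilization near 1).
print(simulate(100, cpu_issue_ms=0.05, gpu_exec_ms=0.2))
```

In the first case, the total time is set by the CPU (about 100 × 0.05 ms) and the GPU sits idle roughly 40% of the time. In the second case, the GPU is busy almost continuously. Extra CPU overhead per kernel would then barely change the total, as long as it stays below the per-kernel GPU time.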

But how do we alleviate this issue?

To fully utilize Transformer-Engine's capabilities, the workload needs to be larger. In this tutorial, the batch size of 8 is fine, but the sequence length of 256 is quite short. With a large enough sequence length, even this case could be faster than the baseline.
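One way to see why sequence length matters is to look at GEMM shapes. In a linear layer, the batch and sequence dimensions are flattened together, so the number of tokens per step (batch size × sequence length) becomes the M dimension of the GEMM. The work per GEMM therefore grows linearly with the token count, while the CPU-side overhead per GEMM stays roughly fixed. The sketch below counts forward-pass FLOPs (2·M·K·N) for one hidden-to-intermediate MLP projection of Llama-2-7B (e.g. HF's `up_proj`; hidden size 4096, intermediate size 11008). The 2048-token sequence length is just an illustrative larger value, not one benchmarked in this tutorial.

```python
HIDDEN_SIZE = 4096          # Llama-2-7B hidden size
INTERMEDIATE_SIZE = 11008   # Llama-2-7B MLP intermediate size


def gemm_flops(m, k, n):
    """FLOPs of an (m x k) @ (k x n) matrix multiply (one multiply + one add per term)."""
    return 2 * m * k * n


def mlp_up_proj_flops(batch_size, seq_len):
    tokens = batch_size * seq_len  # GEMM M dimension after flattening batch and sequence
    return gemm_flops(tokens, HIDDEN_SIZE, INTERMEDIATE_SIZE)


for seq_len in (256, 2048):
    gflops = mlp_up_proj_flops(batch_size=8, seq_len=seq_len) / 1e9
    print(f"seq_len={seq_len}: {gflops:.1f} GFLOP per GEMM")
# seq_len=256: 184.7 GFLOP per GEMM
# seq_len=2048: 1477.5 GFLOP per GEMM
```

An eightfold longer sequence gives eight times more GPU work per GEMM, over which the roughly fixed per-GEMM CPU cost is amortized.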

    

