# Efficient Llama Training with Gradient Checkpointing and _Adapters_

In this notebook, we show how to efficiently fine-tune a **Llama 3** model using **gradient checkpointing** and adapter methods.

**Gradient checkpointing** is a technique that significantly reduces peak memory usage, making it possible to train larger models or use larger batch sizes. It trades compute for memory: during the forward pass, only a subset of activations is stored (saving memory), and during backpropagation, the activations that were not stored are recomputed. This can significantly reduce memory requirements at the cost of additional computation time, typically about one extra forward pass per training step.
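To make the trade-off concrete, here is a minimal sketch using PyTorch's `torch.utils.checkpoint.checkpoint` primitive on a toy stack of residual MLP blocks. The `CountingBlock` class and the sizes are made up for illustration. With checkpointing, each block's forward runs twice (once in the forward pass, once when its activations are recomputed during backpropagation), yet the loss and gradients match the plain run.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CountingBlock(nn.Module):
    """Toy residual MLP block that counts how often its forward is executed."""

    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        return x + self.ff(x)


torch.manual_seed(0)
blocks = nn.ModuleList(CountingBlock(64) for _ in range(4))
x = torch.randn(8, 64)


def run(use_checkpointing: bool):
    blocks.zero_grad(set_to_none=True)
    for block in blocks:
        block.calls = 0
    h = x
    for block in blocks:
        if use_checkpointing:
            # Activations inside the block are not kept; they are recomputed in backward.
            h = checkpoint(block, h, use_reentrant=False)
        else:
            h = block(h)
    loss = h.pow(2).mean()
    loss.backward()
    grads = [p.grad.clone() for p in blocks.parameters()]
    return loss.detach(), grads, [block.calls for block in blocks]


loss_plain, grads_plain, calls_plain = run(use_checkpointing=False)
loss_ckpt, grads_ckpt, calls_ckpt = run(use_checkpointing=True)

print("forward calls without checkpointing:", calls_plain)  # [1, 1, 1, 1]
print("forward calls with checkpointing:   ", calls_ckpt)   # [2, 2, 2, 2]
print("same gradients:", all(torch.allclose(a, b) for a, b in zip(grads_plain, grads_ckpt)))
```

The extra forward call per block is the compute that checkpointing spends to avoid storing the blocks' intermediate activations.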

In this notebook, we fine-tune Llama-3 8B on supervised instruction-tuning data collected by the [Open Assistant project](https://github.com/LAION-AI/Open-Assistant) for training chatbots.

Another way to reduce memory usage is quantization; have a look at the [QLoRA notebook](QLoRA_Llama_Finetuning.ipynb) for an example. This gradient checkpointing notebook is based on the QLoRA notebook. While we use a standard LoRA setup here, you can easily replace LoRA with QLoRA to reduce memory usage even further.

## Installation

We need `adapters`, `datasets`, and `torch` (PyTorch) for training; `adapters` also installs a compatible version of `transformers`.

In [None]:
%pip install -qq -U adapters datasets torch

## Load Open Assistant dataset

We use the [`timdettmers/openassistant-guanaco`](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) dataset, which contains a small subset of conversations from the full Open Assistant database.
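Each example stores a whole conversation in a single `text` field, with turns introduced by `### Human:` and `### Assistant:` markers (the same markers `prompt_model` uses for inference at the end of this notebook). As a quick sketch of that format, the hypothetical helper `split_turns` below splits such a string into `(speaker, message)` pairs; the example conversation is made up.

```python
import re


def split_turns(text: str) -> list[tuple[str, str]]:
    """Split a Guanaco-style conversation string into (speaker, message) pairs."""
    # The capture group keeps the speaker names in the result of re.split.
    parts = re.split(r"### (Human|Assistant):", text)
    # parts[0] is whatever precedes the first marker (usually an empty string).
    return [(parts[i], parts[i + 1].strip()) for i in range(1, len(parts) - 1, 2)]


example = "### Human: What is 2 + 2?### Assistant: 4.### Human: And 3 + 3?### Assistant: 6."
for speaker, message in split_turns(example):
    print(f"{speaker}: {message}")
```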

In [None]:
from datasets import load_dataset

dataset = load_dataset("timdettmers/openassistant-guanaco")
dataset

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 9846
    })
    test: Dataset({
        features: ['text'],
        num_rows: 518
    })
})

## Load and prepare model

We download the official Llama-3 8B checkpoint from the Hugging Face Hub. Note that you must request access to this model on the Hugging Face website and use an access token to download it.

The key difference in this notebook is that we'll enable gradient checkpointing to reduce memory usage during training.

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# To access the Llama 3 model, you need to provide your key:
HUGGINGFACE_ACCESS_TOKEN = "<PASTE_YOUR_TOKEN_HERE>"

modelpath="meta-llama/Meta-Llama-3-8B"

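# Load the model in bfloat16; gradient checkpointing is enabled later, after adding the adapter
model = AutoModelForCausalLM.from_pretrained(
    modelpath,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token=HUGGINGFACE_ACCESS_TOKEN,
)
# The KV cache is not needed for training and is incompatible with gradient checkpointing
model.config.use_cache = False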

tokenizer = AutoTokenizer.from_pretrained(modelpath, token=HUGGINGFACE_ACCESS_TOKEN)
tokenizer.pad_token = tokenizer.eos_token


If you get a message similar to `WARNING:accelerate.big_modeling:Some parameters are on the meta device because they were offloaded to the cpu and disk.`, the model is too big for your GPU. If you don't have a bigger or additional GPU at hand, you can use a quantization method as shown in the [QLoRA notebook](QLoRA_Llama_Finetuning.ipynb). Passing a `quantization_config` when loading the model and choosing a quantized `LoRAConfig` in the next step enables quantized training.

## Initialize adapter

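We initialize the adapter functionality and add a LoRA adapter. When combining gradient checkpointing with adapters, we need to enable input gradients explicitly: the embedding layer is frozen, so its outputs do not require gradients, and without this the checkpointed layers may pass no gradients to the adapter weights.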
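The following minimal PyTorch sketch illustrates why input gradients matter, using reentrant checkpointing (`use_reentrant=True`) on toy modules: a frozen `nn.Embedding` stands in for the frozen base model and a small `nn.Linear` for a trainable, adapter-bearing layer. Whether `adapters` uses the reentrant variant internally is an assumption here; the sketch only shows that, in reentrant mode, if no input to a checkpointed segment requires gradients, the trainable weights inside it receive none. Marking the embedding output as requiring gradients, roughly what `enable_input_require_grads()` in `transformers` does through a forward hook, fixes this.

```python
import warnings

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Toy stand-ins: a frozen embedding (like the frozen base model) and a trainable layer.
frozen_embedding = nn.Embedding(10, 16).requires_grad_(False)
trainable_layer = nn.Linear(16, 16)


def grad_reaches_trainable_layer(make_inputs_require_grad: bool) -> bool:
    trainable_layer.zero_grad(set_to_none=True)
    hidden = frozen_embedding(torch.tensor([[1, 2, 3]]))
    if make_inputs_require_grad:
        # Roughly what enable_input_require_grads() does via a hook on the embeddings.
        hidden.requires_grad_(True)
    with warnings.catch_warnings():
        # Reentrant checkpointing warns when none of its inputs require grad.
        warnings.simplefilter("ignore")
        out = checkpoint(trainable_layer, hidden, use_reentrant=True)
    if not out.requires_grad:
        # The checkpointed segment is cut off from autograd: nothing to backpropagate.
        return False
    out.sum().backward()
    return trainable_layer.weight.grad is not None


print(grad_reaches_trainable_layer(False))  # False
print(grad_reaches_trainable_layer(True))   # True
```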

In [None]:
import adapters
from adapters import LoRAConfig

adapters.init(model)

config = LoRAConfig()
model.add_adapter("lora_adapter", config=config)
model.train_adapter("lora_adapter")

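# Activate gradient checkpointing
model.gradient_checkpointing_enable()
# Make the frozen input embeddings produce outputs that require grad,
# so gradients can flow through the checkpointed layers to the adapter weights
model.enable_input_require_grads()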

print(model.adapter_summary())

Name                     Architecture         #Param      %Param  Active   Train
--------------------------------------------------------------------------------
lora_adapter             lora              3,407,872       0.085       1       1
--------------------------------------------------------------------------------
Full model                              4,015,263,744     100.000               0
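The 3,407,872 trainable parameters in the summary can be checked by hand. This sketch assumes the default `LoRAConfig` (rank `r=8`, applied to the query and value projections of every self-attention layer) and the Llama-3 8B dimensions from its published config (hidden size 4096, 32 layers, 32 query heads and 8 key/value heads of size 128). Each adapted weight of shape `d_out x d_in` receives two low-rank factors with `r * (d_in + d_out)` parameters in total.

```python
# Assumed Llama-3 8B dimensions (from the model's published config).
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
num_layers = 32
head_dim = hidden_size // num_attention_heads  # 128

# Assumed LoRAConfig defaults: rank 8 on the query and value projections.
r = 8


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    # LoRA adds A (rank x d_in) and B (d_out x rank) for each adapted weight.
    return rank * d_in + d_out * rank


q_out = num_attention_heads * head_dim  # 4096
v_out = num_key_value_heads * head_dim  # 1024 (grouped-query attention)
per_layer = lora_params(hidden_size, q_out, r) + lora_params(hidden_size, v_out, r)
total = per_layer * num_layers
print(f"{total:,}")  # 3,407,872
```

The value projection contributes fewer parameters than the query projection because Llama 3 uses grouped-query attention, so its key/value projections are narrower.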


## Prepare data for training

We tokenize the dataset and truncate each example to at most 512 tokens.

In [None]:
import os

def tokenize(element):
    return tokenizer(
        element["text"],
        truncation=True,
        max_length=512,
        add_special_tokens=False,
    )

dataset_tokenized = dataset.map(
    tokenize, 
    batched=True, 
    num_proc=os.cpu_count(),
    remove_columns=["text"]
)

## Training

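We specify training hyperparameters and train the model using the `AdapterTrainer` class. The memory saved by gradient checkpointing can be spent on longer sequences or larger batches; here we keep a per-device batch size of 1 and accumulate gradients over 16 steps, for an effective batch size of 16 sequences per optimizer step.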

In [None]:
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output/llama_gradient_checkpointing",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    eval_strategy="steps",  # called evaluation_strategy in older transformers versions
    logging_steps=10,
    save_steps=500,
    eval_steps=187,
    save_total_limit=3,
    gradient_accumulation_steps=16,
    max_steps=1875,
    learning_rate=0.0002,
    bf16=True,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="constant_with_warmup",  # plain "constant" would ignore warmup_ratio
    optim="adamw_torch"
)
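The hyperparameters above translate into a few derived quantities. The sketch below computes them, assuming a single GPU (`num_devices = 1`); the warmup step count uses `math.ceil(max_steps * warmup_ratio)`, which we believe mirrors how `TrainingArguments` converts `warmup_ratio` into steps. Warmup only takes effect with a scheduler that uses it, such as `constant_with_warmup`; the plain `constant` schedule ignores it.

```python
import math

# Values copied from the TrainingArguments above.
per_device_train_batch_size = 1
gradient_accumulation_steps = 16
max_steps = 1875
warmup_ratio = 0.03
num_train_examples = 9846  # size of the train split printed earlier
num_devices = 1  # assumption: single GPU

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices
sequences_seen = effective_batch_size * max_steps
epochs = sequences_seen / num_train_examples
warmup_steps = math.ceil(max_steps * warmup_ratio)

print(f"effective batch size: {effective_batch_size}")  # 16
print(f"sequences seen:       {sequences_seen:,}")      # 30,000
print(f"approx. epochs:       {epochs:.2f}")            # 3.05
print(f"warmup steps:         {warmup_steps}")          # 57
```

With these settings, training makes roughly three passes over the 9,846 training conversations.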

In [None]:
from adapters import AdapterTrainer
from transformers import DataCollatorForLanguageModeling

trainer = AdapterTrainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    train_dataset=dataset_tokenized["train"],
    eval_dataset=dataset_tokenized["test"],
    args=args,
)

trainer.train()
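`DataCollatorForLanguageModeling(tokenizer, mlm=False)` pads each batch and copies `input_ids` into `labels`, setting positions that hold the pad token ID to `-100` so the loss ignores them (the model shifts labels by one position internally). The sketch below re-implements that idea in plain PyTorch with made-up token IDs; `collate_causal_lm` and `PAD_ID` are hypothetical names, and right padding is assumed. Because the masking is by token ID and this notebook sets `pad_token = eos_token`, any genuine EOS token in a sequence would be masked as well.

```python
import torch

# Hypothetical pad ID; stands in for tokenizer.pad_token_id (equal to the EOS ID in this notebook).
PAD_ID = 0


def collate_causal_lm(sequences: list[list[int]], pad_id: int = PAD_ID) -> dict:
    """Sketch of causal-LM collation: right-pad, then mask pad positions in the labels."""
    max_len = max(len(s) for s in sequences)
    input_ids = torch.full((len(sequences), max_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((len(sequences), max_len), dtype=torch.long)
    for i, seq in enumerate(sequences):
        input_ids[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)
        attention_mask[i, : len(seq)] = 1
    labels = input_ids.clone()
    # -100 is ignored by the cross-entropy loss; masking is by ID, not by position.
    labels[labels == pad_id] = -100
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


batch = collate_causal_lm([[5, 6, 7], [8, 9]])
print(batch["labels"].tolist())  # [[5, 6, 7], [8, 9, -100]]
```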

## Inference

For inference, we disable gradient checkpointing, since no gradients are needed, and re-enable the KV cache to speed up generation:

In [None]:
# Disable gradient checkpointing for inference
model.gradient_checkpointing_disable()
model.config.use_cache = True

def prompt_model(model, text: str):
    batch = tokenizer(f"### Human: {text}\n### Assistant:", return_tensors="pt")
    batch = batch.to(model.device)
    
    model.eval()
    with torch.inference_mode():
        output_tokens = model.generate(**batch, max_new_tokens=50)

    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)

In [None]:
print(prompt_model(model, "Explain gradient checkpointing in simple terms"))
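`prompt_model` returns the decoded prompt together with the completion, and within the 50-token budget the model may start a new `### Human:` turn on its own. The hypothetical helper `extract_reply` below keeps only the text after the first `### Assistant:` marker and cuts it at the next `### Human:` marker; the example string is made up.

```python
def extract_reply(decoded: str) -> str:
    """Return only the assistant's answer from a decoded prompt + completion."""
    # Keep everything after the first assistant marker (the one from our prompt).
    _, _, completion = decoded.partition("### Assistant:")
    # Cut at the next human turn in case the model keeps generating a dialogue.
    completion = completion.split("### Human:")[0]
    return completion.strip()


decoded = "### Human: Explain X\n### Assistant: X is a thing.\n### Human: Thanks!"
print(extract_reply(decoded))  # X is a thing.
```

In the notebook, this could be combined with the cell above as `print(extract_reply(prompt_model(model, "Explain gradient checkpointing in simple terms")))`.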