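# QLoRA Training with FSDP for Llama 3.1 8B

This notebook demonstrates QLoRA fine-tuning (4-bit NF4 base weights plus trainable LoRA adapters) of Meta-Llama-3.1-8B on multiple GPUs with PyTorch Fully Sharded Data Parallel. It uses the `FullyShardedDataParallel` wrapper (FSDP1, not the FSDP2 `fully_shard` API), together with mixed precision, per-decoder-layer activation checkpointing, backward prefetching, and CPU parameter offloading.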

In [None]:
!pip install -q torch>=2.2.0 transformers>=4.36.0 accelerate>=0.26.0 bitsandbytes>=0.41.3 peft>=0.7.0 flash-attn>=2.5.0 datasets>=2.16.0 wandb>=0.16.0

In [None]:
import os
import random
from dataclasses import dataclass, field
from functools import partial
from time import perf_counter
from typing import Optional, Dict, Sequence

import bitsandbytes as bnb
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from datasets import load_dataset
from huggingface_hub import notebook_login
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.distributed.fsdp import *
from torch.distributed.fsdp.wrap import *
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import *
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    GenerationConfig,
    HfArgumentParser,
    LlamaConfig,
    LlamaForCausalLM,
    Trainer,
    TrainingArguments,
)

In [None]:
# Login to HuggingFace.
# `huggingface_hub` reads an `HF_TOKEN` environment variable on its own, so only fall back
# to the interactive login widget when no token is provided. This keeps
# "Restart & Run All" (and non-interactive launches) from blocking on a prompt.
if not os.environ.get("HF_TOKEN"):
    notebook_login()

In [None]:
# Utility functions
def set_seed(seed):
    """Seed every RNG this notebook touches with the same `seed`.

    Covers Python's `random`, NumPy's legacy global RNG, PyTorch's CPU generator,
    and the generators of all CUDA devices. Returns None.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

def malloc_in_gb():
    """Return the bytes currently held by tensors on the current CUDA device, in GiB."""
    bytes_per_gib = 1024 ** 3
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / bytes_per_gib

def free_memory():
    """Release as much unused GPU memory as possible.

    First runs a Python garbage-collection pass, so tensors kept alive only by
    reference cycles are released. Then, if CUDA is available, it waits for queued CUDA
    work to finish and returns the caching allocator's unused blocks to the driver.
    On CPU-only machines only the GC pass runs; the CUDA calls are skipped instead
    of raising.
    """
    import gc

    gc.collect()
    if not torch.cuda.is_available():
        return
    # Finish outstanding kernels before releasing cached blocks.
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

# Architecture overrides applied on top of a default `LlamaConfig()` (see `create_model`).
# "7B" is an empty override, i.e. it keeps `LlamaConfig()`'s defaults unchanged.
_MODEL_SIZE_CONFIGS = {
    "DEBUG": dict(hidden_size=128,
                  num_hidden_layers=2,
                  num_attention_heads=2,
                  num_key_value_heads=2,
                  intermediate_size=256),
    "60M": dict(hidden_size=512,
                num_hidden_layers=4,
                num_attention_heads=4,
                num_key_value_heads=4,
                intermediate_size=1024),
    "120M": dict(hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 num_key_value_heads=12,
                 intermediate_size=1536),
    "290M": dict(hidden_size=1024,
                 num_hidden_layers=12,
                 num_attention_heads=16,
                 num_key_value_heads=16,
                 intermediate_size=4096),
    "1B": dict(hidden_size=2048,
               num_hidden_layers=24,
               num_attention_heads=16,
               num_key_value_heads=16,
               intermediate_size=4096),
    "7B": {},
}


def get_model_size_config(model_size):
    """Return the `LlamaConfig` overrides for a named size preset.

    Args:
        model_size: one of "DEBUG", "60M", "120M", "290M", "1B", "7B".

    Returns:
        A new dict of config overrides. Mutating it does not affect later calls.

    Raises:
        ValueError: if `model_size` is not a known preset. Previously this path hit an
            `UnboundLocalError`.
    """
    try:
        overrides = _MODEL_SIZE_CONFIGS[model_size]
    except KeyError:
        raise ValueError(
            f"Unknown model_size {model_size!r}; expected one of {sorted(_MODEL_SIZE_CONFIGS)}"
        ) from None
    # Copy so a caller mutating the result cannot corrupt the shared preset table.
    return dict(overrides)

def create_model(model_size="1B"):
    """Build a fresh (not pretrained) `LlamaForCausalLM` for a named size preset.

    The preset's overrides from `get_model_size_config` are layered on top of a
    default `LlamaConfig()`.
    """
    llama_config = LlamaConfig()
    llama_config.update(get_model_size_config(model_size))
    return LlamaForCausalLM(llama_config)

In [None]:
def replace_with_bnb_4bit_linear(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
    quant_storage=torch.uint8, 
    keep_trainable=False,
):
    """Recursively swap `nn.Linear` submodules of `model` for `bnb.nn.Linear4bit` layers.

    Each replacement is constructed from the original layer's in/out features and bias
    flag plus the settings on `quantization_config`. Weights are NOT copied from the
    replaced `nn.Linear`, so converted layers hold freshly constructed parameters that
    must be loaded/filled separately.

    Args:
        model: module whose children are converted in place (recursively).
        modules_to_not_convert: child names to skip. A layer is also skipped when any
            entry is a substring of its dotted path (relative to the top-level `model`).
        current_key_name: internal; list of path components for the module currently
            being visited. Shared across the whole recursion and pushed/popped around
            each child.
        quantization_config: object providing `bnb_4bit_compute_dtype`,
            `bnb_4bit_use_double_quant` and `bnb_4bit_quant_type` (e.g. a
            `BitsAndBytesConfig`). Must not be None if any layer gets converted.
        has_been_replaced: internal accumulator threaded through the recursion.
        quant_storage: dtype passed as `Linear4bit(quant_storage=...)`.
        keep_trainable: if True the new layers get `requires_grad=True`, else they are
            frozen.

    Returns:
        `(model, has_been_replaced)`: the same module object (modified in place) and
        whether at least one layer was replaced anywhere in the tree.
    """
    if modules_to_not_convert is None:
        modules_to_not_convert = []
    
    for name, module in model.named_children():
        # Top-level call: start an empty path list that the recursion then shares.
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # Substring match against the full dotted path, e.g. "lm_head" or "layers.0".
            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                # Positional args: in_features, out_features, bias flag, compute dtype.
                model._modules[name] = bnb.nn.Linear4bit(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    quantization_config.bnb_4bit_compute_dtype,
                    compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                    quant_type=quantization_config.bnb_4bit_quant_type,
                    quant_storage=quant_storage
                )
                has_been_replaced = True
                # Remember which class was replaced (presumably for later introspection —
                # confirm any consumer of `source_cls`).
                model._modules[name].source_cls = type(module)
                if keep_trainable:
                    model._modules[name].requires_grad_(True)
                else:
                    model._modules[name].requires_grad_(False)

        # `module` is still the ORIGINAL child here. A replaced nn.Linear has no children,
        # so recursion only descends into container modules.
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_bnb_4bit_linear(
                module,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
                quant_storage=quant_storage,
                keep_trainable=keep_trainable
            )
        # Pop this child's name so siblings see the correct path.
        current_key_name.pop(-1)
    return model, has_been_replaced

In [None]:
# Model configuration
model_id = "meta-llama/Meta-Llama-3.1-8B"


def get_model_and_tokenizer(model_id, compute_dtype=torch.bfloat16, device_map=None):
    """Load the tokenizer and a 4-bit (NF4, double-quantized) base model for QLoRA.

    Args:
        model_id: Hugging Face Hub repo id or local path.
        compute_dtype: one dtype used in three places:
            - the 4-bit matmul compute dtype;
            - the packed quantized-weight storage (`bnb_4bit_quant_storage`);
            - the non-quantized modules (`torch_dtype`).
            Defaults to bfloat16 so it matches `bf16=True` in the `TrainingArguments`.
            FSDP can only flatten the 4-bit weights together with the other parameters
            of a wrapped unit when they share a floating-point storage dtype.
        device_map: forwarded to `from_pretrained`. Defaults to placing the whole model
            on this process's GPU (`LOCAL_RANK`, 0 when unset). This deliberately avoids
            `"auto"`, which spreads layers over every visible GPU and conflicts with
            FSDP's one-device-per-rank sharding.

    Returns:
        `(model, tokenizer)`, with `model.config.use_cache` disabled for training.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Reuse EOS for padding only when the tokenizer does not define a pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if device_map is None:
        device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=compute_dtype,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        torch_dtype=compute_dtype,
        device_map=device_map,
    )
    # The KV cache is not needed while training.
    model.config.use_cache = False
    model.config.pretraining_tp = 1
    return model, tokenizer


model, tokenizer = get_model_and_tokenizer(model_id)

In [None]:
# Alpaca-style prompt (no-input variant). This is intended to mirror the formatting of the
# `tatsu-lab/alpaca` training examples so pre/post-training generations are comparable.
# Verify against the dataset's `text` field before changing it.
ALPACA_PROMPT_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:\n"
)


def formatted_prompt(user_input):
    """Wrap a bare instruction in the Alpaca prompt, ending right at the response header."""
    return ALPACA_PROMPT_TEMPLATE.format(instruction=user_input)


def generate_response(user_input):
    """Sample a short completion for `user_input` using the notebook's `model` and `tokenizer`.

    Prints two things: the decoded text (prompt followed by the continuation, special
    tokens removed) and the wall-clock time spent tokenizing and generating. Returns the
    decoded text.
    """
    prompt = formatted_prompt(user_input)
    # Pure sampling. `penalty_alpha` is omitted: it only takes effect in contrastive
    # search (do_sample=False), so with do_sample=True it was silently ignored.
    generation_config = GenerationConfig(
        do_sample=True,
        top_k=5,
        temperature=0.5,
        repetition_penalty=1.2,
        max_new_tokens=60,
        pad_token_id=tokenizer.eos_token_id,
    )

    start_time = perf_counter()
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(**inputs, generation_config=generation_config)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    output_time = perf_counter() - start_time

    print(response)
    print(f"Time taken for inference: {round(output_time,2)} seconds")
    return response

In [None]:
# Test generation before training, as a baseline to compare against after fine-tuning.
# The trailing `;` stops Jupyter from echoing the returned string a second time;
# `generate_response` already prints it.
test_prompt = "What is machine learning?"
print("\nTesting generation before training:")
generate_response(test_prompt);

In [None]:
def setup_distributed(backend="nccl"):
    """Initialise the default process group for this process. Safe to call repeatedly.

    Works under a launcher such as `torchrun`, which exports RANK, WORLD_SIZE,
    LOCAL_RANK, MASTER_ADDR and MASTER_PORT. Also works in a plain notebook kernel,
    where single-process defaults are filled in. Values already present in the
    environment are never overwritten, so launcher-provided settings win.

    Args:
        backend: process-group backend passed to `init_process_group` ("nccl" for GPUs).
    """
    if dist.is_initialized():
        # Re-running the cell must not try to create a second default group.
        return
    for key, default in (
        ("MASTER_ADDR", "localhost"),
        ("MASTER_PORT", "29500"),
        ("RANK", "0"),
        ("WORLD_SIZE", "1"),
        ("LOCAL_RANK", "0"),
    ):
        os.environ.setdefault(key, default)
    if torch.cuda.is_available():
        # Bind this process to its GPU before the backend creates communicators.
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group(backend)

def get_policies(model, layer_cls=None):
    """Return an FSDP auto-wrap policy that gives each transformer block its own FSDP unit.

    Args:
        model: unused; kept so existing call sites keep working.
        layer_cls: a module class, or an iterable of classes, to wrap. Defaults to
            `LlamaDecoderLayer`.

    Returns:
        A callable taking `(module, recurse, nonwrapped_numel)`, suitable for FSDP's
        `auto_wrap_policy` argument.
    """
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

    if layer_cls is None:
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer
        layer_cls = LlamaDecoderLayer
    layer_classes = {layer_cls} if isinstance(layer_cls, type) else set(layer_cls)
    # `transformer_auto_wrap_policy` IS the policy function. Calling it directly with only
    # `transformer_layer_cls` raises TypeError (missing module/recurse/nonwrapped_numel).
    # Bind the layer classes and let FSDP supply the remaining arguments per module.
    return partial(transformer_auto_wrap_policy, transformer_layer_cls=layer_classes)

def setup_fsdp_model(model, dtype=torch.bfloat16, cpu_offload=True):
    """Shard `model` with FSDP (one unit per decoder layer) and add activation checkpointing.

    Args:
        model: the (PEFT-wrapped) Llama model to shard.
        dtype: FSDP mixed-precision dtype for parameters, gradient reduction and buffers.
            Defaults to bfloat16 to match `bf16=True` in the `TrainingArguments`. fp16
            would normally also need loss scaling (e.g. a sharded grad scaler), which
            this setup does not provide.
        cpu_offload: if True, keep sharded parameters on CPU between uses. This trades
            host<->device transfers for GPU memory.

    Returns:
        The FSDP-wrapped model. Every `LlamaDecoderLayer` is additionally wrapped with
        non-reentrant activation checkpointing.
    """
    # `from torch.distributed.fsdp import *` exports `FullyShardedDataParallel` but not
    # an `FSDP` alias, so the short name must be imported explicitly.
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from transformers.models.llama.modeling_llama import LlamaDecoderLayer

    mp_policy = MixedPrecision(
        param_dtype=dtype,
        reduce_dtype=dtype,
        buffer_dtype=dtype,
    )

    model = FSDP(
        model,
        auto_wrap_policy=get_policies(model),
        mixed_precision=mp_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        cpu_offload=CPUOffload(offload_params=cpu_offload),
        limit_all_gathers=True,
        # Lets frozen base weights and trainable LoRA weights share an FSDP unit.
        use_orig_params=True,
    )

    non_reentrant_wrapper = partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=lambda submodule: isinstance(submodule, LlamaDecoderLayer),
    )
    return model

In [None]:
# Setup distributed training
setup_distributed()

# Profile model memory before training
if dist.get_rank() == 0:
    print(f"Memory allocated [MODEL]: {malloc_in_gb():.3f} GB")

# Prepare the 4-bit model for training with PEFT's k-bit helper.
# Gradient checkpointing is deliberately NOT enabled here. `setup_fsdp_model` already
# wraps every decoder layer in a non-reentrant activation-checkpoint wrapper. Also
# enabling PEFT/Hugging Face gradient checkpointing would nest two checkpoints around
# each layer, so its forward pass would be recomputed more than once during backward.
# NOTE(review): this helper may also upcast non-quantized floating-point params to fp32,
# while FSDP needs a uniform dtype inside each wrapped unit. Confirm the chosen
# `bnb_4bit_quant_storage`/dtype combination still flattens cleanly.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)

# LoRA config
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Apply LoRA
model = get_peft_model(model, lora_config)

# Setup FSDP
model = setup_fsdp_model(model)

if dist.get_rank() == 0:
    print(f"Memory allocated [AFTER SETUP]: {malloc_in_gb():.3f} GB")

In [None]:
# Training setup
MAX_SEQ_LENGTH = 512  # truncate longer examples to bound activation memory

raw_dataset = load_dataset("tatsu-lab/alpaca")


def tokenize_batch(batch):
    """Tokenize the dataset's pre-formatted `text` column (prompt plus response)."""
    return tokenizer(batch["text"], truncation=True, max_length=MAX_SEQ_LENGTH)


# The Trainer needs `input_ids`. The raw rows are plain text fields, which
# `remove_unused_columns` would drop, leaving nothing to train on.
train_dataset = raw_dataset["train"].map(
    tokenize_batch,
    batched=True,
    remove_columns=raw_dataset["train"].column_names,
)
# Causal-LM collator: pads each batch and builds `labels` from `input_ids`, with
# pad-token positions set to -100. pad == eos here, so EOS labels are masked too.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="./llama3-qlora-fsdp",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    bf16=True,
    logging_steps=1,
    save_strategy="epoch",
    save_total_limit=3,
    ddp_backend="nccl",
    # `setup_fsdp_model` already applies activation checkpointing per decoder layer.
    # Enabling it here too would nest a second checkpoint around each layer.
    gradient_checkpointing=False,
    report_to="wandb"
)

# NOTE(review): `model` was FSDP-wrapped by hand and `training_args` does not configure
# `fsdp`, so the Trainer does not treat this as an FSDP run. Confirm multi-process
# launches do not add a DDP wrapper on top of FSDP.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
# Run training; the returned training summary is displayed as the cell output.
train_output = trainer.train()
train_output

In [None]:
# Test generation after training, with the same prompt as the pre-training check.
# The trailing `;` stops Jupyter from echoing the returned string a second time;
# `generate_response` already prints it.
test_prompt = "What is machine learning?"
print("\nTesting generation after training:")
generate_response(test_prompt);

In [None]:
# Save model: only rank 0 requests the save.
# NOTE(review): in a multi-rank FSDP run, gathering the full (unsharded) state dict is a
# collective operation. Confirm the save path used here does not need every rank to
# participate; otherwise rank 0 can block waiting for the others.
is_main_process = dist.get_rank() == 0
if is_main_process:
    trainer.save_model()