# QLoRA Training with FSDP2 for Llama 3.1 8B

This notebook demonstrates QLoRA training with FSDP2 on multiple GPUs using Meta-Llama-3.1-8B.

In [None]:
!pip install -q torch>=2.2.0 transformers>=4.36.0 accelerate>=0.26.0 bitsandbytes>=0.41.3 peft>=0.7.0 flash-attn>=2.5.0 datasets>=2.16.0 wandb>=0.16.0

In [None]:
import os
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig,
    LlamaConfig,
    LlamaForCausalLM,
    GenerationConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
import torch.distributed as dist
from torch.distributed.fsdp import *
from torch.distributed.fsdp.wrap import *
from huggingface_hub import notebook_login
from time import perf_counter

In [None]:
# Log in to the Hugging Face Hub using the interactive helper imported from
# huggingface_hub. Presumably required so the from_pretrained() calls in the
# next cell can download the meta-llama checkpoint — confirm the repository's
# access requirements.
notebook_login()

In [None]:
# Model configuration
# Hub repository ID; passed to get_model_and_tokenizer() to load both the
# tokenizer and the model weights.
model_id = "meta-llama/Meta-Llama-3.1-8B"

def get_model_and_tokenizer(model_id):
    """Load the tokenizer and a 4-bit quantized causal LM for ``model_id``.

    Args:
        model_id: Identifier forwarded to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    # Reuse the end-of-sequence token as the padding token.
    tok.pad_token = tok.eos_token

    # 4-bit NF4 quantization with double quantization and float16 compute.
    quant_kwargs = {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.float16,
        "bnb_4bit_use_double_quant": True,
    }
    llm = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=BitsAndBytesConfig(**quant_kwargs),
        device_map="auto",
    )

    # Adjust the loaded config in place before handing the model back.
    cfg = llm.config
    cfg.use_cache = False
    cfg.pretraining_tp = 1
    return llm, tok

# Load once at module level; generate_response() and the training cells read
# these `model` / `tokenizer` globals.
model, tokenizer = get_model_and_tokenizer(model_id)

In [None]:
def formatted_prompt(question):
    """Wrap ``question`` in the Alpaca instruction/response prompt template.

    Defined here because ``generate_response`` depends on it and no other cell
    in this notebook provides it. The template is the no-input variant used by
    the Alpaca dataset this notebook trains on — adjust it if training uses a
    different prompt format.

    Args:
        question: The instruction text to embed in the prompt.

    Returns:
        The prompt string, ending right after the ``### Response:`` header so
        the model continues with the answer.
    """
    return (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n{question}\n\n"
        "### Response:\n"
    )


def generate_response(user_input):
    """Generate, print and return the model's completion for ``user_input``.

    Uses the module-level ``model`` and ``tokenizer``. The elapsed time that is
    printed covers tokenization, generation and decoding.

    Args:
        user_input: The instruction to send to the model.

    Returns:
        The decoded output sequence (special tokens stripped). The whole
        sequence is decoded, so the prompt text is included.
    """
    prompt = formatted_prompt(user_input)
    # NOTE(review): penalty_alpha is a contrastive-search setting; combined
    # with do_sample=True it is likely ignored — confirm against the installed
    # transformers version.
    generation_config = GenerationConfig(
        penalty_alpha=0.6,
        do_sample=True,
        top_k=5,
        temperature=0.5,
        repetition_penalty=1.2,
        max_new_tokens=60,
        pad_token_id=tokenizer.eos_token_id
    )

    start_time = perf_counter()
    # Inputs are moved to the default CUDA device before generation.
    inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
    outputs = model.generate(**inputs, generation_config=generation_config)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    output_time = perf_counter() - start_time

    print(response)
    print(f"Time taken for inference: {round(output_time,2)} seconds")
    return response

In [None]:
# Test generation before training
# Baseline sample taken before LoRA adapters are attached; the same prompt is
# run again after training so the two outputs can be compared.
test_prompt = "What is machine learning?"
print("\nTesting generation before training:")
generate_response(test_prompt)

In [None]:
# Setup FSDP
def setup_fsdp_model(model):
    """Wrap ``model`` in FullyShardedDataParallel, one FSDP unit per decoder layer.

    Args:
        model: The module to shard.

    Returns:
        The FSDP-wrapped module.

    Note:
        FSDP needs an initialized ``torch.distributed`` process group, so this
        must run inside a distributed launch.
    """
    from functools import partial

    # Imported explicitly: the star import of torch.distributed.fsdp exposes
    # FullyShardedDataParallel but no ``FSDP`` alias.
    from torch.distributed.fsdp import FullyShardedDataParallel
    from transformers.models.llama.modeling_llama import LlamaDecoderLayer

    # Parameters, gradient reduction and buffers all use float16.
    mp_policy = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )

    # transformer_auto_wrap_policy is a predicate that FSDP itself calls with
    # (module, recurse, nonwrapped_numel); only the layer classes are bound
    # here. Calling it directly would fail on the missing positional arguments.
    wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LlamaDecoderLayer},
    )

    # NOTE(review): sharding bitsandbytes 4-bit weights with FSDP generally
    # requires bnb_4bit_quant_storage to match the model dtype — confirm
    # against the PEFT FSDP-QLoRA guide.
    return FullyShardedDataParallel(
        model,
        auto_wrap_policy=wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        cpu_offload=CPUOffload(offload_params=True),
        limit_all_gathers=True,
        use_orig_params=True,
    )

In [None]:
# LoRA adapter settings: rank 8, alpha 32, applied to the q/k/v/o projection
# modules, with 5% dropout and no bias training.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

# Step 1: ready the quantized model for k-bit training.
model = prepare_model_for_kbit_training(model)

# Step 2: attach the LoRA adapters.
model = get_peft_model(model, lora_config)

# Step 3: wrap the adapted model with FSDP.
model = setup_fsdp_model(model)

In [None]:
# Training setup
from transformers import DataCollatorForLanguageModeling

# Upper bound on tokens per example; longer examples are truncated.
MAX_SEQ_LENGTH = 512

dataset = load_dataset("tatsu-lab/alpaca")


def tokenize_batch(batch):
    """Tokenize the pre-formatted ``text`` column of a batch of examples."""
    return tokenizer(batch["text"], truncation=True, max_length=MAX_SEQ_LENGTH)


# Trainer needs token IDs rather than raw strings. The original string columns
# are dropped so only the tokenizer outputs reach the collator.
train_dataset = dataset["train"].map(
    tokenize_batch,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

# NOTE(review): bf16=True here, while the bitsandbytes compute dtype and the
# FSDP MixedPrecision policy elsewhere in this notebook use float16 — confirm
# which precision is intended.
training_args = TrainingArguments(
    output_dir="./llama3-qlora-fsdp",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    bf16=True,
    logging_steps=1,
    save_strategy="epoch",
    save_total_limit=3,
    ddp_backend="nccl",
    gradient_checkpointing=True,
    report_to="wandb"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    # mlm=False: causal-LM collation (pads each batch, labels copied from inputs).
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    tokenizer=tokenizer,
)

In [None]:
# Train
# Runs the training loop with the TrainingArguments configured above
# (3 epochs, checkpoint saved each epoch, metrics reported to wandb).
trainer.train()

In [None]:
# Test generation after training
# Same prompt as the pre-training check, so the two printed outputs can be
# compared side by side.
test_prompt = "What is machine learning?"
print("\nTesting generation after training:")
generate_response(test_prompt)

In [None]:
# Save model
# Called without a path, so the Trainer writes to its configured output_dir
# ("./llama3-qlora-fsdp" in the TrainingArguments above).
trainer.save_model()