# LoRA Fine-tuning with Llama 3.1: A Complete Guide

This notebook demonstrates Low-Rank Adaptation (LoRA) techniques for efficiently fine-tuning Llama 3.1 on a custom dataset.

## Table of Contents
1. Theory: Understanding LoRA
2. Setup and Installation
3. Project Overview: Code Documentation Assistant
4. Data Preparation
5. Model Loading and LoRA Configuration
6. Training Process
7. Evaluation and Inference
8. Advanced Techniques: QLoRA
9. Model Comparison and Analysis
10. Saving and Loading the Adapter
11. Best Practices and Tips

## 1. Theory: Understanding LoRA

**LoRA (Low-Rank Adaptation) Theory:**

Traditional fine-tuning updates all parameters of a pre-trained model:
```
W_new = W_original + ΔW
```

LoRA decomposes the weight update ΔW into low-rank matrices:
```
ΔW = B × A
```

Where:
- A is a matrix of shape (rank, input_dim)
- B is a matrix of shape (output_dim, rank)
- rank << min(input_dim, output_dim)

For a square d × d weight matrix, this reduces the trainable parameters from d² to 2×d×r (where r is the rank).
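
As a rough worked example of that count, the sketch below compares a full update with a rank-16 LoRA update for a single 4096 × 4096 projection. The helper name `lora_param_counts` and the chosen dimensions are illustrative assumptions, not something defined elsewhere in this notebook.

```python
def lora_param_counts(input_dim: int, output_dim: int, rank: int) -> tuple[int, int]:
    """Return (full_update_params, lora_params) for a single weight matrix."""
    full_update = input_dim * output_dim            # every entry of ΔW is trainable
    lora_update = rank * (input_dim + output_dim)   # A: (rank, input_dim), B: (output_dim, rank)
    return full_update, lora_update


full, lora = lora_param_counts(input_dim=4096, output_dim=4096, rank=16)
print(f"Full update: {full:,} parameters")
print(f"LoRA update: {lora:,} parameters ({lora / full:.2%} of the full update)")
```

For non-square layers the same helper applies, since the LoRA count is r × (input_dim + output_dim) in general.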

**Benefits:**
- 99%+ reduction in trainable parameters
- Faster training, and no added inference latency once the adapter is merged into the base weights
- Lower memory requirements
- Modular: can switch between different adaptations
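
The modularity point can be checked numerically: because the update is just a matrix added to the frozen weight, an adapter can either be kept as a separate low-rank path or folded into the base weight. The following is a minimal NumPy sketch, assuming the common `alpha / rank` scaling convention; the sizes and random values are purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
input_dim, output_dim, rank, alpha = 64, 48, 4, 8    # illustrative sizes only

W = rng.normal(size=(output_dim, input_dim))          # frozen pre-trained weight
A = rng.normal(scale=0.01, size=(rank, input_dim))    # LoRA matrix A: (rank, input_dim)
B = rng.normal(scale=0.01, size=(output_dim, rank))   # LoRA matrix B (random here, as if already trained)
scaling = alpha / rank

x = rng.normal(size=(input_dim,))

# Unmerged: base path plus a separate low-rank path
y_unmerged = W @ x + scaling * (B @ (A @ x))

# Merged: fold the update into the weight once, then use a single matmul
W_merged = W + scaling * (B @ A)
y_merged = W_merged @ x

print("Outputs match:", np.allclose(y_unmerged, y_merged))
print("Rank of the update B @ A:", np.linalg.matrix_rank(B @ A))
```

Swapping adapters amounts to replacing `A` and `B` while leaving `W` untouched.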

## 2. Setup and Installation

First, install the required packages:

In [1]:
# Install required packages
!pip install torch torchvision torchaudio
!pip install transformers
!pip install peft
!pip install datasets
!pip install bitsandbytes
!pip install accelerate
!pip install wandb  # optional for experiment tracking

Collecting torchvision
  Downloading torchvision-0.22.1-cp310-cp310-manylinux_2_28_x86_64.whl (7.5 MB)
Collecting torchaudio
  Downloading torchaudio-2.7.1-cp310-cp310-manylinux_2_28_x86_64.whl (3.5 MB)
Collecting pillow!=8.3.*,>=5.3.0
  Using cached pillow-11.2.1-cp310-cp310-manylinux_2_28_x86_64.whl (4.6 MB)
Installing collected packages: pillow, torchvision, torchaudio
Successfully installed pillow-11.2.1 torchaudio-2.7.1 torchvision-0.22.1
You should consider upgrading via the '/home/mohdasimkhan/.pyenv/versions/3.10.2/envs/lora/bin/python -m pip install --upgrade pip' command.

In [2]:
import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import (
    LoraConfig, 
    get_peft_model, 
    TaskType,
    PeftModel
)
from datasets import Dataset, load_dataset
import json
import numpy as np
from typing import Dict, List
import warnings
warnings.filterwarnings('ignore')

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


## 3. Project Overview: Code Documentation Assistant

**Project Goal:** Fine-tune Llama 3.1 to generate high-quality code documentation

**Task:** Given a Python function, generate comprehensive docstrings including:
- Function description
- Parameter explanations
- Return value description
- Usage examples

This demonstrates LoRA's effectiveness for domain-specific adaptation.

## 4. Data Preparation

In [None]:
# Sample dataset for code documentation
code_examples = [
    {
        "instruction": "Generate a comprehensive docstring for this Python function:",
        "input": """def calculate_compound_interest(principal, rate, time, compound_frequency=1):
    amount = principal * (1 + rate/compound_frequency)**(compound_frequency * time)
    return amount - principal""",
        "output": """Calculate compound interest on an investment.
        
        This function computes the compound interest earned on a principal amount
        over a specified time period with a given interest rate and compounding frequency.
        
        Args:
            principal (float): The initial investment amount in dollars
            rate (float): Annual interest rate as a decimal (e.g., 0.05 for 5%)
            time (float): Investment period in years
            compound_frequency (int, optional): Number of times interest compounds per year. Defaults to 1.
        
        Returns:
            float: The compound interest earned (not including principal)
            
        Example:
            >>> interest = calculate_compound_interest(1000, 0.05, 2, 4)
            >>> print(f"Interest earned: ${interest:.2f}")
            Interest earned: $104.49
        """
    },
    {
        "instruction": "Generate a comprehensive docstring for this Python function:",
        "input": """def merge_sorted_lists(list1, list2):
    result = []
    i = j = 0
    while i < len(list1) and j < len(list2):
        if list1[i] <= list2[j]:
            result.append(list1[i])
            i += 1
        else:
            result.append(list2[j])
            j += 1
    result.extend(list1[i:])
    result.extend(list2[j:])
    return result""",
        "output": """Merge two sorted lists into a single sorted list.
        
        This function takes two pre-sorted lists and efficiently merges them
        into a single sorted list using a two-pointer approach. The algorithm
        has O(n+m) time complexity where n and m are the lengths of the input lists.
        
        Args:
            list1 (List[Any]): First sorted list to merge
            list2 (List[Any]): Second sorted list to merge
            
        Returns:
            List[Any]: A new sorted list containing all elements from both input lists
            
        Example:
            >>> merged = merge_sorted_lists([1, 3, 5], [2, 4, 6])
            >>> print(merged)
            [1, 2, 3, 4, 5, 6]
        """
    },
    {
        "instruction": "Generate a comprehensive docstring for this Python function:",
        "input": """def find_prime_factors(n):
    factors = []
    d = 2
    while d * d <= n:
        while n % d == 0:
            factors.append(d)
            n //= d
        d += 1
    if n > 1:
        factors.append(n)
    return factors""",
        "output": """Find all prime factors of a given positive integer.
        
        This function decomposes a positive integer into its prime factors using
        trial division. The algorithm iterates through potential divisors up to
        the square root of the number for efficiency.
        
        Args:
            n (int): Integer to factorize. Values below 2 yield an empty list.
            
        Returns:
            List[int]: List of prime factors in ascending order. For n >= 2, the
                      product of all factors equals the original number.
            
        Example:
            >>> factors = find_prime_factors(60)
            >>> print(factors)
            [2, 2, 3, 5]
            >>> # Verification: 2 * 2 * 3 * 5 = 60
        """
    }
]
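
Since consistent formatting matters for instruction data, it can help to validate records before building the dataset. The sketch below is one possible check, not part of the original pipeline: `validate_example` and `REQUIRED_KEYS` are hypothetical names, and the two demo records are made up for illustration.

```python
import ast

REQUIRED_KEYS = ("instruction", "input", "output")


def validate_example(example: dict) -> list[str]:
    """Return a list of problems found in one training record (empty if none)."""
    problems = []
    for key in REQUIRED_KEYS:
        value = example.get(key)
        if not isinstance(value, str) or not value.strip():
            problems.append(f"missing or empty field: {key}")
    if isinstance(example.get("input"), str):
        try:
            ast.parse(example["input"])  # the code snippet should be valid Python
        except SyntaxError as err:
            problems.append(f"input is not valid Python: {err.msg}")
    return problems


# Tiny self-contained demo records (not the notebook's dataset)
good = {"instruction": "Document this:", "input": "def f(x):\n    return x", "output": "Return x."}
bad = {"instruction": "Document this:", "input": "def f(x) return x", "output": ""}

print(validate_example(good))
print(validate_example(bad))
```

The same function could be run over `code_examples` to catch malformed records early.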

In [None]:
def create_training_prompt(instruction: str, input_code: str, output_doc: str) -> str:
    """Create a formatted training prompt for the model."""
    return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful AI assistant that generates comprehensive Python docstrings.<|eot_id|><|start_header_id|>user<|end_header_id|>

{instruction}

```python
{input_code}
```<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{output_doc}<|eot_id|>"""

def create_inference_prompt(instruction: str, input_code: str) -> str:
    """Create a formatted prompt for inference."""
    return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful AI assistant that generates comprehensive Python docstrings.<|eot_id|><|start_header_id|>user<|end_header_id|>

{instruction}

```python
{input_code}
```<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

def prepare_dataset(examples: List[Dict]) -> Dataset:
    """Convert examples to a HuggingFace Dataset."""
    formatted_examples = []
    
    for example in examples:
        prompt = create_training_prompt(
            example["instruction"],
            example["input"],
            example["output"]
        )
        formatted_examples.append({"text": prompt})
    
    return Dataset.from_list(formatted_examples)

# Create dataset
train_dataset = prepare_dataset(code_examples)
print(f"Created dataset with {len(train_dataset)} examples")

## 5. Model Loading and LoRA Configuration

In [None]:
# Model configuration
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # or smaller version for testing
max_length = 1024

# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
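# Reuse EOS as the padding token in case the tokenizer does not define one.
# Caveat: DataCollatorForLanguageModeling masks every pad-token id in the labels,
# so with pad == EOS the end-of-sequence token is also excluded from the loss.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"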

In [None]:
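# Load model
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    # 8-bit quantization to reduce memory (bitsandbytes typically needs a CUDA GPU);
    # passing load_in_8bit directly to from_pretrained is deprecated
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
# Freeze the base weights and prepare the quantized model for stable training
model = prepare_model_for_kbit_training(model)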

In [None]:
# Configure LoRA
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,  # Rank of adaptation - higher rank = more parameters but potentially better performance
    lora_alpha=32,  # LoRA scaling parameter - typically 2x the rank
    lora_dropout=0.1,  # Dropout for LoRA layers
    target_modules=[
        "q_proj",  # Query projection
        "k_proj",  # Key projection  
        "v_proj",  # Value projection
        "o_proj",  # Output projection
        "gate_proj",  # Gate projection (for LLaMA)
        "up_proj",   # Up projection
        "down_proj", # Down projection
    ],
    bias="none",
)

# Apply LoRA to model
print("Applying LoRA configuration...")
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
# Tokenize dataset
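def tokenize_function(examples):
    """Tokenize the text examples."""
    return tokenizer(
        examples["text"],
        truncation=True,
        padding=False,
        max_length=max_length,
        return_overflowing_tokens=False,
        # The prompt template already starts with <|begin_of_text|>,
        # so don't let the tokenizer prepend a second BOS token
        add_special_tokens=False,
    )

# Tokenize the dataset and drop the raw "text" column: with
# remove_unused_columns=False the collator would otherwise try to pad strings
print("Tokenizing dataset...")
tokenized_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])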

## 6. Training Process

In [None]:
# Training arguments
training_args = TrainingArguments(
    output_dir="./llama-lora-docstring",
    num_train_epochs=3,
    per_device_train_batch_size=1,  # Small batch size for memory efficiency
    gradient_accumulation_steps=4,  # Simulate larger batch size
    warmup_steps=1,  # Keep warmup shorter than the run: this tiny dataset yields only a few optimizer steps
    learning_rate=2e-4,  # Higher learning rate for LoRA
    fp16=True,  # Mixed precision training
    logging_steps=10,
    save_strategy="epoch",
    # No validation set for this example, so evaluation stays at its default ("no")
    remove_unused_columns=False,
    dataloader_pin_memory=False,
)

# Data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # We're not doing masked language modeling
)
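
To make the collator's role concrete, here is a simplified pure-Python sketch of what causal-LM collation with `mlm=False` does: pad each sequence to the longest one in the batch and copy the inputs into the labels, with pad positions replaced by `-100`. This is an illustration of the idea, not the library's implementation; `collate_causal_lm` and the toy token ids are hypothetical.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def collate_causal_lm(sequences: list[list[int]], pad_token_id: int) -> dict:
    """Right-pad token-id lists to a common length and build causal-LM labels."""
    longest = max(len(seq) for seq in sequences)
    input_ids, attention_mask, labels = [], [], []
    for seq in sequences:
        padding = longest - len(seq)
        ids = seq + [pad_token_id] * padding
        input_ids.append(ids)
        attention_mask.append([1] * len(seq) + [0] * padding)
        # Labels copy the inputs; every position holding the pad id is ignored
        labels.append([IGNORE_INDEX if tok == pad_token_id else tok for tok in ids])
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


# Toy token ids; 9 plays the role of an EOS token that is also used for padding
batch = collate_causal_lm([[5, 6, 7, 9], [5, 9]], pad_token_id=9)
print(batch["input_ids"])
print(batch["labels"])
```

Note that the genuine trailing `9` in each sequence is masked too, which is the side effect of sharing one id between EOS and padding.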

In [None]:
# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

# Start training
print("Starting training...")
trainer.train()

# Save the model
trainer.save_model()
print("Training completed and model saved!")
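
It is worth working out how many optimizer updates a run like this actually performs, since warmup and logging intervals are counted in those steps. The sketch below is back-of-the-envelope arithmetic under the assumption that partial accumulation windows are rounded up; the exact rounding inside `Trainer` may differ between library versions, and `training_schedule` is a hypothetical helper.

```python
import math


def training_schedule(num_examples: int, per_device_batch_size: int,
                      grad_accum_steps: int, num_epochs: int, num_devices: int = 1) -> dict:
    """Estimate the effective batch size and the number of optimizer updates."""
    effective_batch = per_device_batch_size * grad_accum_steps * num_devices
    batches_per_epoch = math.ceil(num_examples / (per_device_batch_size * num_devices))
    updates_per_epoch = max(1, math.ceil(batches_per_epoch / grad_accum_steps))
    return {
        "effective_batch_size": effective_batch,
        "updates_per_epoch": updates_per_epoch,
        "total_updates": updates_per_epoch * num_epochs,
    }


# The three-example toy dataset with batch size 1 and 4 accumulation steps
toy = training_schedule(num_examples=3, per_device_batch_size=1, grad_accum_steps=4, num_epochs=3)
# A hypothetical 1,000-example dataset with the same settings
larger = training_schedule(num_examples=1000, per_device_batch_size=1, grad_accum_steps=4, num_epochs=3)
print(toy)
print(larger)
```

With only a few updates in total, settings such as `logging_steps=10` never trigger, which is expected for a toy run.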

## 7. Evaluation and Inference

In [None]:
# Function to generate docstring
def generate_docstring(code: str, instruction: str = "Generate a comprehensive docstring for this Python function:"):
    """Generate a docstring for the given code using our fine-tuned model."""
    
    # Create the prompt
    prompt = create_inference_prompt(instruction, code)
    
    # Tokenize (the prompt already contains <|begin_of_text|>, so skip auto-added special tokens)
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    
    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    
    # Decode only the newly generated tokens; splitting on the assistant header
    # would fail because skip_special_tokens=True strips it from the decoded text
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
    
    return response

In [None]:
# Test the model with a new function
test_code = """def binary_search(arr, target):
    left, right = 0, len(arr) - 1
    while left <= right:
        mid = (left + right) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1"""

print("Testing the fine-tuned model:")
print("=" * 50)
print("Input code:")
print(test_code)
print("\nGenerated docstring:")
generated_docstring = generate_docstring(test_code)
print(generated_docstring)
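
Reading one generated docstring is a weak signal, so a cheap automatic check can complement it. The sketch below measures what fraction of a function's parameters are mentioned in a docstring. It is a hypothetical, deliberately crude metric (`parameter_coverage` is not part of any library), and the demo strings are made up rather than real model output.

```python
import ast
import re


def parameter_coverage(function_source: str, docstring: str) -> float:
    """Fraction of the function's parameters that are mentioned in the docstring."""
    tree = ast.parse(function_source)
    func = next(
        node for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    )
    args = func.args
    names = [a.arg for a in args.posonlyargs + args.args + args.kwonlyargs]
    if args.vararg:
        names.append(args.vararg.arg)
    if args.kwarg:
        names.append(args.kwarg.arg)
    names = [n for n in names if n not in ("self", "cls")]
    if not names:
        return 1.0  # nothing to document
    mentioned = sum(1 for n in names if re.search(rf"\b{re.escape(n)}\b", docstring))
    return mentioned / len(names)


source = "def binary_search(arr, target):\n    return -1"
partial_doc = "Search a sorted list.\n\nArgs:\n    arr (list): Sorted values."
full_doc = partial_doc + "\n    target (int): Value to look for."

print(parameter_coverage(source, partial_doc))
print(parameter_coverage(source, full_doc))
```

Averaging such a score over held-out functions, for both the base and the fine-tuned model, gives a simple before/after comparison.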

## 8. Advanced Techniques: QLoRA

QLoRA (Quantized LoRA) combines LoRA with 4-bit quantization for even greater efficiency.

In [None]:
# QLoRA configuration example
from transformers import BitsAndBytesConfig

# 4-bit quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,  # Double quantization
    bnb_4bit_quant_type="nf4",       # Normal Float 4 quantization
    bnb_4bit_compute_dtype=torch.float16,
)

# Load model with QLoRA (uncomment to use)
"""
qlora_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
)

qlora_model = get_peft_model(qlora_model, lora_config)
"""

print("QLoRA configuration ready!")
print("Uncomment the code above to use 4-bit quantization")
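
To see why lower-precision loading matters, the arithmetic below estimates the memory needed just to hold the base weights at different precisions. It assumes a nominal 8 billion parameters and ignores quantization constants, activations, gradients, and optimizer state, so real usage will be higher; `weight_memory_gb` is a hypothetical helper.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate storage for the weights alone, in gigabytes (10^9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NOMINAL_PARAMS = 8e9  # assumed round figure for an "8B" model

estimates = {
    "fp16/bf16": weight_memory_gb(NOMINAL_PARAMS, 16),
    "int8": weight_memory_gb(NOMINAL_PARAMS, 8),
    "4-bit (NF4)": weight_memory_gb(NOMINAL_PARAMS, 4),
}
for label, gigabytes in estimates.items():
    print(f"{label:>12}: ~{gigabytes:.0f} GB for weights")
```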

## 9. Model Comparison and Analysis

In [None]:
def compare_model_sizes():
    """Compare parameter counts of the frozen base model and the LoRA adapter."""
    print("Model Size Comparison:")
    print("=" * 40)
    
    # Frozen base-model parameters (model.base_model would also include the LoRA weights)
    original_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    
    # Trainable LoRA parameters
    lora_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"Original model parameters: {original_params:,}")
    print(f"LoRA trainable parameters: {lora_params:,}")
    print(f"Reduction factor: {original_params / lora_params:.1f}x")
    print(f"LoRA parameters as % of original: {(lora_params / original_params) * 100:.2f}%")

compare_model_sizes()

In [None]:
def analyze_lora_config(config: LoraConfig):
    """Analyze the LoRA configuration."""
    print("LoRA Configuration Analysis:")
    print("=" * 35)
    print(f"Rank: {config.r}")
    print(f"Alpha: {config.lora_alpha}")
    print(f"Scaling factor: {config.lora_alpha / config.r}")
    print(f"Target modules: {config.target_modules}")
    print(f"Dropout: {config.lora_dropout}")

analyze_lora_config(lora_config)
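
The configuration can also be turned into an expected trainable-parameter count without loading the model. The sketch below does this by hand under assumed Llama 3.1 8B dimensions (hidden size 4096, MLP size 14336, 32 layers, key/value projection width 1024); these constants are assumptions that should be verified against `model.config`, and `lora_trainable_params` is a hypothetical helper.

```python
# Assumed Llama 3.1 8B dimensions; verify against model.config before relying on them
HIDDEN = 4096
INTERMEDIATE = 14336
NUM_LAYERS = 32
KV_DIM = 1024  # 8 key/value heads x 128 head dim (grouped-query attention)

# (input_dim, output_dim) of each linear layer that LoRA can target
MODULE_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_trainable_params(rank: int, target_modules: list[str]) -> int:
    """Sum rank * (input_dim + output_dim) over the targeted modules in every layer."""
    per_layer = sum(
        rank * (MODULE_SHAPES[name][0] + MODULE_SHAPES[name][1])
        for name in target_modules
    )
    return per_layer * NUM_LAYERS


all_linear = lora_trainable_params(16, list(MODULE_SHAPES))
attention_only = lora_trainable_params(16, ["q_proj", "v_proj"])
print(f"All seven projections, r=16: {all_linear:,}")
print(f"q_proj + v_proj only, r=16: {attention_only:,}")
```

If the assumptions hold, the first figure should agree with what `model.print_trainable_parameters()` reports for the configuration used here.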

## 10. Saving and Loading the Adapter

In [None]:
def save_lora_adapter(model, save_path: str):
    """Save only the LoRA adapter weights."""
    model.save_pretrained(save_path)
    print(f"LoRA adapter saved to {save_path}")

def load_lora_adapter(base_model_name: str, adapter_path: str):
    """Load a base model with LoRA adapter."""
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    
    model_with_adapter = PeftModel.from_pretrained(base_model, adapter_path)
    return model_with_adapter

# Save the adapter
save_lora_adapter(model, "./lora-docstring-adapter")

## 11. Best Practices and Tips

### Key LoRA Hyperparameters and Their Effects:

**1. Rank (r):**
- Lower rank (4-8): Fewer parameters, faster training, may limit expressiveness
- Higher rank (16-64): More parameters, potentially better performance
- Rule of thumb: Start with 8-16 for most tasks

**2. Alpha (lora_alpha):**
- Controls the scaling of LoRA updates
- Typically set to 2x the rank value
- Higher alpha = stronger adaptation

**3. Target Modules:**
- More modules = more parameters but better coverage
- Focus on attention layers (q_proj, v_proj) for efficiency
- Include MLP layers for more comprehensive adaptation

**4. Dropout:**
- Prevents overfitting in LoRA layers
- Typical range: 0.05-0.1

### LoRA Fine-tuning Best Practices:

**1. Data Quality:**
- Use high-quality, domain-specific data
- Ensure consistent formatting
- Include diverse examples

**2. Hyperparameter Tuning:**
- Start with rank=8, alpha=16
- Use higher learning rates than for full fine-tuning (e.g., 1e-4 to 5e-4)
- Experiment with different target modules

**3. Training Strategy:**
- Use gradient accumulation for effective larger batch sizes
- Monitor for overfitting (especially with small datasets)
- Consider warmup steps for stability
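
To illustrate what warmup does to the learning rate, here is a small pure-Python sketch of a linear warmup followed by linear decay. It mirrors the general shape of such schedules rather than any specific library's scheduler, and `linear_warmup_decay` is a hypothetical helper.

```python
def linear_warmup_decay(step: int, warmup_steps: int, total_steps: int, peak_lr: float) -> float:
    """Learning rate at a given step: linear ramp-up, then linear decay to zero."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / max(1, total_steps - warmup_steps)


# Illustrative run: 1,000 optimizer steps, 100 of them warmup, peak LR 2e-4
for step in (0, 50, 100, 550, 1000):
    lr = linear_warmup_decay(step, warmup_steps=100, total_steps=1000, peak_lr=2e-4)
    print(f"step {step:>4}: lr = {lr:.2e}")
```

If `warmup_steps` is larger than `total_steps`, the ramp never finishes and the peak learning rate is never reached, so warmup should be sized relative to the length of the run.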

**4. Memory Optimization:**
- Use gradient checkpointing
- Enable fp16/bf16 mixed precision
- Use DeepSpeed ZeRO for large models

**5. Evaluation:**
- Test on held-out data
- Compare with base model performance
- Evaluate task-specific metrics

In [None]:
print("LoRA Fine-tuning Complete!")
print("\nThis notebook demonstrated:")
print("- LoRA theory and implementation")
print("- Fine-tuning Llama 3.1 for code documentation")
print("- Advanced techniques like QLoRA")
print("- Best practices and optimization strategies")
print("\nYou can now adapt this framework for your own tasks!")