# Lab 2.4.6: Mamba Fine-tuning

**Module:** 2.4 - Efficient Architectures  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐⭐ (Advanced)

---

## 🎯 Learning Objectives

By the end of this lab, you will:
- [ ] Understand that LoRA works with Mamba (not just transformers!)
- [ ] Fine-tune a Mamba model on a custom instruction dataset
- [ ] Compare fine-tuning memory requirements vs transformers
- [ ] Evaluate fine-tuned model performance

---

## 📚 Prerequisites

- Completed: Labs 2.4.1-2.4.5
- Knowledge of: LoRA, fine-tuning basics, PEFT library
- Hardware: DGX Spark recommended (fine-tuning needs more memory than inference)

---

## 🌍 Real-World Context

**Why Fine-tune Mamba?**

Mamba excels at long-context tasks, but base models need customization for:
- Domain-specific vocabulary (legal, medical, code)
- Following your organization's style guidelines
- Specific output formats
- Custom knowledge injection

**Good news**: LoRA works with Mamba! The linear projections in Mamba can be adapted just like transformer attention layers.
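What "adapting a linear projection" means can be sketched in a few lines of pure Python, assuming the standard LoRA formulation from the LoRA paper: the frozen weight `W` is left untouched and the layer output becomes `W x + (alpha / r) * B (A x)`. The helper names and the tiny 3×2 matrix are illustrative only; PEFT applies the same idea to tensors inside each wrapped `nn.Linear`.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def lora_forward(W: Matrix, A: Matrix, B: Matrix, x: Vector, alpha: float) -> Vector:
    """Frozen base output plus the scaled low-rank update (alpha / r) * B @ (A @ x)."""
    r = len(A)                        # A is (r x in_features), B is (out_features x r)
    base = matvec(W, x)               # frozen pretrained projection
    update = matvec(B, matvec(A, x))  # low-rank path: in -> r -> out
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


# Frozen projection: 2 inputs -> 3 outputs
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# Rank-1 adapter: A is (1 x 2); B is (3 x 1) and starts at zero, as in LoRA
A = [[0.5, -0.5]]
B_init = [[0.0], [0.0], [0.0]]
x = [2.0, 4.0]

print(lora_forward(W, A, B_init, x, alpha=2.0))     # [2.0, 4.0, 6.0], same as W @ x
B_trained = [[1.0], [0.0], [0.0]]                   # pretend training moved B
print(lora_forward(W, A, B_trained, x, alpha=2.0))  # [0.0, 4.0, 6.0]
```

Because `B` starts at zero, the adapted model initially behaves exactly like the base model; training only has to learn the small `A` and `B` matrices.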

---

## 🧒 ELI5: LoRA on Mamba

> **Remember our LoRA analogy?**
>
> LoRA is like adding thin wallpaper instead of repainting the entire house.
>
> **For transformers**: We add wallpaper to attention layers (Q, K, V, O projections)
>
> **For Mamba**: We add wallpaper to:
> - Input projections (entering the SSM)
> - Output projections (leaving the SSM)
> - The selective parameter generators (B, C, delta)
>
> The principle is the same: low-rank updates to existing linear layers!
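To make the "selective parameter generators" bullet concrete: in the Hugging Face Mamba implementation (an assumption about naming that the layer printout in Part 3 lets you verify), `x_proj` maps each inner-dimension vector to `dt_rank + 2 * d_state` values, which are split into the low-rank Δ input (later expanded by `dt_proj`), B, and C. The sketch below reproduces only that split on a plain list; the sizes are the Mamba-1.4B config values we assume (`dt_rank=128`, `d_state=16`).

```python
from typing import List, Tuple


def split_x_proj_output(
    values: List[float], dt_rank: int, d_state: int
) -> Tuple[List[float], List[float], List[float]]:
    """Split one x_proj output vector into (delta_low_rank, B, C), in that order."""
    expected = dt_rank + 2 * d_state
    if len(values) != expected:
        raise ValueError(f"expected {expected} values, got {len(values)}")
    delta = values[:dt_rank]                   # fed to dt_proj, which expands it to d_inner
    B = values[dt_rank:dt_rank + d_state]      # input matrix of the SSM
    C = values[dt_rank + d_state:]             # output matrix of the SSM
    return delta, B, C


# Assumed Mamba-1.4B sizes
DT_RANK, D_STATE = 128, 16
fake_output = [float(i) for i in range(DT_RANK + 2 * D_STATE)]
delta, B, C = split_x_proj_output(fake_output, DT_RANK, D_STATE)
print(len(delta), len(B), len(C))  # 128 16 16
```

A LoRA adapter on `x_proj` therefore changes how Δ, B, and C are computed from the input, which is the input-dependent "selection" mechanism itself.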

---

## Part 1: Setup

In [None]:
# Install required packages (uncomment if needed)
# !pip install "peft>=0.8.0" datasets trl

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Optional
import gc
import time

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from datasets import load_dataset, Dataset

try:
    from peft import (
        LoraConfig,
        get_peft_model,
        TaskType,
        prepare_model_for_kbit_training,
    )
    HAS_PEFT = True
except ImportError:
    HAS_PEFT = False
    print("⚠️ PEFT not installed. Run: pip install 'peft>=0.8.0'")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

In [None]:
# Clear memory
def clear_memory():
    gc.collect()
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        print(f"GPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")

clear_memory()

---

## Part 2: Load Mamba Model

In [None]:
# Choose Mamba model - smaller for faster fine-tuning
MODEL_NAME = "state-spaces/mamba-1.4b-hf"  # Good balance of size and capability
# Alternatives:
# MODEL_NAME = "state-spaces/mamba-130m-hf"  # Fastest for testing
# MODEL_NAME = "state-spaces/mamba-2.8b-hf"  # Best quality

print(f"Loading {MODEL_NAME}...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Report
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
memory = torch.cuda.memory_allocated() / 1e9

print(f"\n✅ Model loaded!")
print(f"   Total params: {total_params/1e9:.2f}B")
print(f"   Trainable: {trainable_params/1e6:.1f}M (all for now)")
print(f"   Memory: {memory:.2f} GB")
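The reported memory should be close to a back-of-the-envelope number: BF16 stores each parameter in 2 bytes, so the weights alone take about `params × 2` bytes. A small hypothetical helper (not part of any library) for that estimate:

```python
def bf16_weight_gb(num_params: float) -> float:
    """Approximate weight memory in GB (1e9 bytes) for a model stored in BF16 (2 bytes/param)."""
    BYTES_PER_BF16 = 2
    return num_params * BYTES_PER_BF16 / 1e9


# Nominal sizes from the model names; real parameter counts differ slightly
for label, n in [("mamba-130m", 130e6), ("mamba-1.4b", 1.4e9), ("mamba-2.8b", 2.8e9)]:
    print(f"{label}: ~{bf16_weight_gb(n):.2f} GB of BF16 weights")
```

Since the names round the parameter counts, expect the measured value to differ a little from these estimates.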

---

## Part 3: Explore Model Structure for LoRA

In [None]:
# Find linear layers suitable for LoRA
def find_lora_targets(model) -> List[str]:
    """
    Find linear layers that can be targeted with LoRA.
    """
    linear_layers = []
    
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            # Extract the layer name without indices
            layer_type = name.split('.')[-1]
            if layer_type not in linear_layers:
                linear_layers.append(layer_type)
    
    return linear_layers

lora_targets = find_lora_targets(model)
print("Available LoRA targets:")
for target in lora_targets:
    print(f"  - {target}")

In [None]:
# Look at specific layer structure
print("\nMamba layer structure (first layer):")
for name, module in model.named_modules():
    if '.0.' in name and isinstance(module, nn.Linear):
        print(f"  {name}: {module.in_features} -> {module.out_features}")
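Each `nn.Linear` (`in_features -> out_features`) that LoRA wraps gains `r × (in_features + out_features)` trainable parameters, because matrix A is `r × in` and B is `out × r`. Plugging in the per-layer shapes printed above gives the adapter size before applying anything. The shapes below are our assumption of the Mamba-1.4B HF config (hidden size 2048, inner size 4096, `dt_rank` 128, `d_state` 16, 48 layers); replace them with what your printout shows.

```python
from typing import Dict, Tuple


def lora_params_per_linear(in_features: int, out_features: int, r: int) -> int:
    """LoRA adds A (r x in_features) and B (out_features x r); the frozen weight is unchanged."""
    return r * (in_features + out_features)


# Assumed per-layer projection shapes for Mamba-1.4B: (in_features, out_features)
D_MODEL, D_INNER, DT_RANK, D_STATE, N_LAYERS = 2048, 4096, 128, 16, 48
shapes: Dict[str, Tuple[int, int]] = {
    "in_proj": (D_MODEL, 2 * D_INNER),           # SSM input branch + gate branch
    "x_proj": (D_INNER, DT_RANK + 2 * D_STATE),  # delta (low rank), B, C
    "dt_proj": (DT_RANK, D_INNER),               # expands delta to the inner dimension
    "out_proj": (D_INNER, D_MODEL),              # back to the residual stream
}

r = 16
per_layer = sum(lora_params_per_linear(i, o, r) for i, o in shapes.values())
total = per_layer * N_LAYERS
print(f"LoRA params per layer: {per_layer:,}")
print(f"LoRA params total:     {total:,} (~{total / 1e6:.1f}M)")
```

This count should be close to what `model.print_trainable_parameters()` reports in Part 4 when exactly those four projections are targeted.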

---

## Part 4: Apply LoRA

In [None]:
if HAS_PEFT:
    # Auto-detect available LoRA target modules
    # Different Mamba versions may have different layer names
    available_targets = find_lora_targets(model)
    
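    # Preferred Mamba target modules (HF Mamba naming): in_proj/out_proj wrap the SSM,
    # x_proj generates delta, B and C, and dt_proj expands delta.
    # (The embedding is an nn.Embedding, so find_lora_targets never lists it.)
    preferred_targets = ["in_proj", "x_proj", "dt_proj", "out_proj"]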
    
    # Find which preferred targets actually exist in the model
    valid_targets = [t for t in preferred_targets if t in available_targets]
    
    if not valid_targets:
        # Fallback: use first few available linear layers
        print("⚠️ Standard Mamba targets not found, using available layers...")
        valid_targets = available_targets[:3]
    
    if not valid_targets:
        print("❌ No linear layers found for LoRA! Check model architecture.")
    else:
        print(f"Detected valid LoRA targets: {valid_targets}")
        
        # Configure LoRA for Mamba
        lora_config = LoraConfig(
            r=16,  # LoRA rank
            lora_alpha=32,  # Scaling factor
            target_modules=valid_targets,  # Auto-detected layers
            lora_dropout=0.05,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
        )
        
        print("\nLoRA Configuration:")
        print(f"  Rank: {lora_config.r}")
        print(f"  Alpha: {lora_config.lora_alpha}")
        print(f"  Targets: {lora_config.target_modules}")
        
        # Apply LoRA with error handling
        print("\nApplying LoRA...")
        try:
            model = get_peft_model(model, lora_config)
            
            # Check trainable parameters
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            total_params = sum(p.numel() for p in model.parameters())
            
            print(f"\n✅ LoRA applied!")
            print(f"   Trainable params: {trainable_params/1e6:.2f}M")
            print(f"   Total params: {total_params/1e9:.2f}B")
            print(f"   Trainable %: {100 * trainable_params / total_params:.2f}%")
            
            # Print trainable modules
            model.print_trainable_parameters()
        except ValueError as e:
            print(f"❌ LoRA application failed: {e}")
            print(f"\nAvailable targets in this model: {available_targets}")
            print("Try modifying target_modules to match available layers.")
else:
    print("PEFT not available. Skipping LoRA application.")
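Why do short names like `in_proj` work as `target_modules`? To our understanding, when given a list of strings PEFT treats an entry as matching a module whose full dotted name equals it or ends with `"." + entry`. The function below is a simplified, hypothetical re-implementation of that rule for illustration, not PEFT's actual code; the module names follow the HF Mamba naming we assume (`backbone.layers.N.mixer.*`).

```python
from typing import Iterable, List


def matches_target(module_name: str, targets: Iterable[str]) -> bool:
    """Simplified suffix rule: exact match, or the name ends with '.<target>'."""
    return any(module_name == t or module_name.endswith("." + t) for t in targets)


module_names: List[str] = [
    "backbone.embeddings",
    "backbone.layers.0.mixer.in_proj",
    "backbone.layers.0.mixer.conv1d",
    "backbone.layers.0.mixer.x_proj",
    "backbone.layers.0.mixer.dt_proj",
    "backbone.layers.0.mixer.out_proj",
    "lm_head",
]
targets = ["in_proj", "x_proj", "dt_proj", "out_proj"]
selected = [n for n in module_names if matches_target(n, targets)]
for name in selected:
    print(name)
```

The same rule explains why transformer names such as `q_proj` match nothing in Mamba, which is the kind of `ValueError` the cell above catches.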

---

## Part 5: Prepare Training Data

In [None]:
# Create a simple instruction dataset
# For demonstration, we'll create a small synthetic dataset

instruction_examples = [
    {
        "instruction": "Summarize the following text in one sentence.",
        "input": "Machine learning is a subset of artificial intelligence that enables systems to learn from data.",
        "output": "Machine learning is an AI subset that allows systems to learn from data."
    },
    {
        "instruction": "Translate to French.",
        "input": "Hello, how are you?",
        "output": "Bonjour, comment allez-vous?"
    },
    {
        "instruction": "Write a Python function to calculate factorial.",
        "input": "",
        "output": "def factorial(n):\n    if n <= 1:\n        return 1\n    return n * factorial(n-1)"
    },
    {
        "instruction": "What is the capital of France?",
        "input": "",
        "output": "The capital of France is Paris."
    },
    {
        "instruction": "Explain photosynthesis simply.",
        "input": "",
        "output": "Photosynthesis is how plants make food using sunlight, water, and carbon dioxide."
    },
] * 100  # Repeat for more training data

def format_instruction(example):
    """Format an instruction example for training."""
    if example["input"]:
        text = f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:\n{example['output']}"
    else:
        text = f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['output']}"
    return text

# Create dataset
formatted_data = [format_instruction(ex) for ex in instruction_examples]

# Tokenize
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=256,
        padding="max_length",
    )

dataset = Dataset.from_dict({"text": formatted_data})
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"]
)

print(f"Dataset size: {len(tokenized_dataset)}")
print(f"\nExample formatted instruction:")
print("-" * 50)
print(formatted_data[0][:300])

In [None]:
# Split dataset
split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]

print(f"Train size: {len(train_dataset)}")
print(f"Eval size: {len(eval_dataset)}")
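With `mlm=False`, `DataCollatorForLanguageModeling` copies `input_ids` into `labels` and replaces every position equal to `tokenizer.pad_token_id` with -100, the index the loss ignores. Because this notebook pads to `max_length=256` and falls back to EOS as the pad token, that masking also hides any genuine EOS tokens. The pure-Python sketch below mimics the rule on plain lists; the token IDs are made up.

```python
from typing import List

IGNORE_INDEX = -100  # positions with this label are skipped by the cross-entropy loss


def causal_lm_labels(input_ids: List[int], pad_token_id: int) -> List[int]:
    """Copy input_ids into labels, masking every pad position (mirrors the mlm=False collator)."""
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]


PAD = 0  # made-up pad id; in this notebook the pad token may equal EOS
ids = [501, 502, 503, PAD, PAD, PAD]  # 3 real tokens, then right padding
labels = causal_lm_labels(ids, PAD)
print(labels)  # [501, 502, 503, -100, -100, -100]
```

Short examples padded to 256 tokens therefore contribute loss only on their real tokens; masking the prompt as well (training on the response only) would need a custom collator.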

---

## Part 6: Fine-tuning

In [None]:
# Training arguments
training_args = TrainingArguments(
    output_dir="./mamba-lora-output",
    num_train_epochs=1,  # Quick demo - increase for better results
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=10,
    learning_rate=2e-4,
    bf16=True,  # Use bfloat16 for DGX Spark
    logging_steps=10,
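    eval_strategy="steps",  # older transformers releases call this evaluation_strategy
    eval_steps=10,  # the demo run is only ~28 optimizer steps, so 50 would never fire
    save_strategy="steps",
    save_steps=20,
    save_total_limit=2,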
    report_to="none",  # Disable wandb for demo
    gradient_checkpointing=True,  # Save memory
)

# Data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # Causal LM, not masked
)

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

print("Trainer configured!")
print(f"Training steps: {len(train_dataset) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps) * training_args.num_train_epochs}")
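The printed step count is worth checking against `eval_steps` and `save_steps`: evaluation and checkpointing only happen at multiples of those values, so if they exceed the total number of optimizer steps they never trigger. A rough count, assuming a single device and ignoring how your `transformers` version handles a leftover partial accumulation:

```python
import math


def optimizer_steps(num_examples: int, batch_size: int, grad_accum: int, epochs: int = 1) -> int:
    """Approximate optimizer updates: micro-batches per epoch // accumulation, times epochs."""
    micro_batches = math.ceil(num_examples / batch_size)  # last micro-batch may be partial
    return (micro_batches // grad_accum) * epochs


# 500 synthetic examples with a 10% eval split -> 450 training examples
steps = optimizer_steps(num_examples=450, batch_size=4, grad_accum=4)
print(f"~{steps} optimizer steps")  # ~28
for name, every in [("eval_steps=50", 50), ("eval_steps=10", 10)]:
    print(f"{name}: fires {steps // every} time(s)")
```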

In [None]:
# Train!
print("Starting training...")
print(f"Initial GPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")

start_time = time.time()

try:
    train_result = trainer.train()
    
    training_time = time.time() - start_time
    print(f"\n✅ Training complete!")
    print(f"   Time: {training_time/60:.1f} minutes")
    print(f"   Final loss: {train_result.training_loss:.4f}")
    print(f"   Peak GPU memory: {torch.cuda.max_memory_allocated()/1e9:.2f} GB")
    
except RuntimeError as e:
    if "out of memory" in str(e).lower():
        print("❌ Out of memory! Try:")
        print("   1. Reduce per_device_train_batch_size (raise gradient_accumulation_steps to compensate)")
        print("   2. Reduce max_length")
        print("   3. Use a smaller model")
        print("   4. Keep gradient_checkpointing=True")
    else:
        raise

---

## Part 7: Evaluate Fine-tuned Model

In [None]:
# Test the fine-tuned model
def generate_response(model, tokenizer, instruction: str, input_text: str = "") -> str:
    """Generate a response from the fine-tuned model."""
    if input_text:
        prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
    else:
        prompt = f"### Instruction:\n{instruction}\n\n### Response:\n"
    
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    
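    # Decode only the newly generated tokens; generate() also returns the prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Stop at the next "###" header if the model starts a new instruction block
    response = response.split("###")[0].strip()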
    
    return response

# Test with various prompts
test_prompts = [
    ("What is machine learning?", ""),
    ("Write a haiku about programming.", ""),
    ("Summarize this text:", "The sun rose over the mountains, casting golden light across the valley."),
]

print("\n🤖 Testing Fine-tuned Model:")
print("=" * 60)

model.eval()
for instruction, input_text in test_prompts:
    print(f"\n📝 Instruction: {instruction}")
    if input_text:
        print(f"   Input: {input_text}")
    
    response = generate_response(model, tokenizer, instruction, input_text)
    print(f"   Response: {response}")
    print("-" * 40)
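`generate` returns the prompt tokens followed by the new ones, so decoding the whole output repeats the prompt. A sturdier way to pull out the answer, sketched on plain lists and strings (the example text is invented): slice off the prompt tokens first (in the notebook, the prompt length is `inputs["input_ids"].shape[1]`), then cut at the next `###` header in case the model starts a new instruction block.

```python
from typing import List


def new_token_ids(output_ids: List[int], prompt_len: int) -> List[int]:
    """generate() echoes the prompt; keep only the tokens produced after it."""
    return output_ids[prompt_len:]


def clean_response(generated_text: str, stop_marker: str = "###") -> str:
    """Truncate at the next section header the model may have started, then trim."""
    return generated_text.split(stop_marker, 1)[0].strip()


print(new_token_ids([11, 12, 13, 90, 91], prompt_len=3))  # [90, 91]
raw = "Paris is the capital.\n\n### Instruction:\nWhat is 2 + 2?"
print(clean_response(raw))  # Paris is the capital.
```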

---

## Part 8: Compare Fine-tuning Memory

In [None]:
# Memory comparison: Mamba LoRA vs Full Fine-tuning

print("📊 Memory Comparison: Mamba Fine-tuning")
print("=" * 60)

# Theoretical calculations for comparison
models_comparison = {
    "Mamba-1.4B": {
        "params_b": 1.4,
        "lora_params_m": 19,  # Approx.: r=16 on in_proj/x_proj/dt_proj/out_proj, 48 layers
    },
    "Mamba-2.8B": {
        "params_b": 2.8,
        "lora_params_m": 32,  # Approx.: same targets, 64 layers
    },
    "Llama-3B (Transformer)": {
        "params_b": 3.0,
        "lora_params_m": 15,  # Assumed; depends on which projections are targeted
    },
}

# Assumed bytes per trainable parameter:
# BF16 weights (2) + BF16 grads (2) + FP32 Adam first and second moments (8)
BYTES_TRAINABLE = 2 + 2 + 8
BYTES_FROZEN = 2  # frozen BF16 weights: no grads, no optimizer state

print(f"\n{'Model':<25} {'Full FT (GB)':<15} {'LoRA (GB)':<15} {'Savings':<10}")
print("-" * 65)

for name, specs in models_comparison.items():
    # Full fine-tuning: every parameter carries weights + grads + Adam states
    full_ft_memory = specs["params_b"] * BYTES_TRAINABLE
    
    # LoRA: frozen base weights + weights/grads/Adam states for the adapter only
    lora_memory = (specs["params_b"] * BYTES_FROZEN) + (specs["lora_params_m"] / 1000 * BYTES_TRAINABLE)
    
    savings = (full_ft_memory - lora_memory) / full_ft_memory * 100
    
    print(f"{name:<25} {full_ft_memory:<15.1f} {lora_memory:<15.1f} {savings:.0f}%")

print("\n💡 In this estimate LoRA cuts weight + optimizer memory by roughly 80%;")
print("   activation memory is extra and depends on batch size and sequence length.")
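The savings column follows from a per-parameter byte count. With P base parameters, L LoRA parameters, 12 bytes per fully trained parameter (the multiplier the cell above uses for full fine-tuning), 2 bytes per frozen BF16 parameter, and b_L bytes per adapter parameter:

$$
\begin{aligned}
M_{\text{full}} &= 12P \\
M_{\text{LoRA}} &= 2P + b_L L \\
\text{savings} &= 1 - \frac{M_{\text{LoRA}}}{M_{\text{full}}} = \frac{5}{6} - \frac{b_L L}{12P} \approx 83\% \quad \text{when } L \ll P
\end{aligned}
$$

With L around 1% of P or less, the correction term is small, which is why every row lands near the same percentage regardless of model size.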

---

## Part 9: Save and Load LoRA Weights

In [None]:
if HAS_PEFT:
    # Save LoRA weights (much smaller than full model!)
    output_dir = "./mamba-lora-adapter"
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    
    # Check size
    import os
    total_size = sum(
        os.path.getsize(os.path.join(output_dir, f))
        for f in os.listdir(output_dir)
        if os.path.isfile(os.path.join(output_dir, f))
    )
    
    print(f"\n✅ LoRA adapter saved to {output_dir}")
    print(f"   Size: {total_size / 1e6:.1f} MB")
    print(f"   (vs ~{1.4 * 2 * 1000:.0f} MB for the full Mamba-1.4B model in BF16)")
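The adapter size can be predicted from the LoRA parameter count: roughly parameters × bytes per stored value, plus small config files. The storage dtype is an assumption to check: depending on your PEFT version and settings, adapter weights on a BF16 base may be kept in FP32 (4 bytes) or BF16 (2 bytes). The parameter count below is illustrative (r=16 on `in_proj`/`x_proj`/`dt_proj`/`out_proj` of an assumed Mamba-1.4B config).

```python
def adapter_size_mb(lora_params: int, bytes_per_value: int) -> float:
    """Approximate tensor payload in MB (1e6 bytes), ignoring file headers and config files."""
    return lora_params * bytes_per_value / 1e6


lora_params = 19_095_552  # illustrative count, see the lead-in for assumptions
for dtype, nbytes in [("fp32", 4), ("bf16", 2)]:
    print(f"{dtype}: ~{adapter_size_mb(lora_params, nbytes):.1f} MB")
```

Either way, the adapter is a small fraction of the ~2.8 GB full BF16 checkpoint, which is what makes sharing and swapping adapters cheap.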

In [None]:
# To load the LoRA adapter later:
# from peft import PeftModel
# 
# base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
# model = PeftModel.from_pretrained(base_model, "./mamba-lora-adapter")

print("\n📋 To load the adapter later:")
print("""
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-1.4b-hf")
model = PeftModel.from_pretrained(base_model, "./mamba-lora-adapter")
""")

---

## ⚠️ Common Mistakes

### Mistake 1: Wrong Target Modules
```python
# ❌ Using transformer target names for Mamba
target_modules=["q_proj", "v_proj"]  # These don't exist in Mamba!

# ❌ Hardcoding without checking
target_modules=["in_proj", "out_proj"]  # May not exist in all Mamba variants

# ✅ Auto-detect available targets
available = find_lora_targets(model)
target_modules = [t for t in ["in_proj", "out_proj", "x_proj"] if t in available]
```

### Mistake 2: Learning Rate Too Low
```python
# ❌ Standard full fine-tuning LR
learning_rate=3e-5  # Too low for LoRA

# ✅ Higher LR works for LoRA
learning_rate=2e-4  # LoRA can handle higher LR
```

### Mistake 3: Not Enabling Gradient Checkpointing
```python
# ❌ Memory pressure on long sequences
gradient_checkpointing=False

# ✅ Save memory with gradient checkpointing (costs some extra compute)
gradient_checkpointing=True
```

---

## 🎉 Checkpoint

You've learned:
- ✅ LoRA works on Mamba (not just transformers!)
- ✅ Which Mamba layers to target for LoRA
- ✅ How to prepare instruction data for fine-tuning
- ✅ Memory savings from LoRA vs full fine-tuning
- ✅ How to save and load LoRA adapters

---

## ✋ Try It Yourself

### Exercise: Custom Domain Fine-tuning
1. Create a dataset specific to your domain (e.g., customer service, medical, legal)
2. Fine-tune Mamba on this dataset
3. Evaluate on domain-specific test prompts
4. Compare with base model responses
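For step 1, one low-friction approach is to write your domain examples in the same `instruction`/`input`/`output` schema used in Part 5 and store them as JSON Lines, so `format_instruction` can be reused unchanged. The examples and the `domain_train.jsonl` filename below are placeholders, and the validation helper is a hypothetical convenience, not part of any library.

```python
import json
from typing import Dict, List

REQUIRED_KEYS = ("instruction", "input", "output")


def validate_example(example: Dict[str, str]) -> None:
    """Raise if an example is missing a key or has an empty instruction/output."""
    missing = [k for k in REQUIRED_KEYS if k not in example]
    if missing:
        raise ValueError(f"missing keys: {missing}")
    if not example["instruction"].strip() or not example["output"].strip():
        raise ValueError("instruction and output must be non-empty")


def write_jsonl(examples: List[Dict[str, str]], path: str) -> int:
    """Validate every example, then write one JSON object per line. Returns the count."""
    for ex in examples:
        validate_example(ex)
    with open(path, "w", encoding="utf-8") as f:
        for ex in examples:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")
    return len(examples)


# Placeholder customer-service examples
domain_examples = [
    {"instruction": "Reply politely to a customer asking about a late order.",
     "input": "Where is my package? It was due yesterday.",
     "output": "I'm sorry for the delay. Could you share your order number so I can check its status?"},
    {"instruction": "Explain the return window.",
     "input": "",
     "output": "You can return most items within 30 days of delivery."},
]
n = write_jsonl(domain_examples, "domain_train.jsonl")
print(f"Wrote {n} examples")
```

The file can then be loaded with `load_dataset("json", data_files="domain_train.jsonl")["train"]` and formatted with `format_instruction` as before.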

In [None]:
# Your code here



---

## 📖 Further Reading

- [PEFT Documentation](https://huggingface.co/docs/peft)
- [LoRA Paper](https://arxiv.org/abs/2106.09685)
- [QLoRA Paper](https://arxiv.org/abs/2305.14314)
- [Mamba official repository (state-spaces/mamba)](https://github.com/state-spaces/mamba)

---

## 🧹 Cleanup

In [None]:
# Cleanup
if 'model' in globals():
    del model
if 'trainer' in globals():
    del trainer

gc.collect()
torch.cuda.empty_cache()

print("✅ Cleanup complete!")