# Lab 2.6.5: LoRA Style Training - Create Your Own Style

**Module:** 2.6 - Diffusion Models  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐⭐ (Intermediate-Advanced)

---

## 🎯 Learning Objectives

By the end of this notebook, you will:
- [ ] Understand how LoRA adapters work for diffusion models
- [ ] Prepare a dataset for style training
- [ ] Train an SDXL LoRA on custom style images
- [ ] Apply and combine multiple LoRAs
- [ ] Adjust LoRA strength for style control

---

## 📚 Prerequisites

- Completed: Labs 2.6.1-2.6.4
- Knowledge of: SDXL generation, basic training concepts
- **Required packages:**
  - `diffusers>=0.27.0`
  - `peft>=0.10.0` (for LoRA adapter support)
  - `transformers>=4.38.0`
  - `datasets`

**Version check:**
```python
# Run this to verify your versions
import diffusers, peft, transformers
print(f"diffusers: {diffusers.__version__}")  # Need >=0.27.0
print(f"peft: {peft.__version__}")            # Need >=0.10.0
print(f"transformers: {transformers.__version__}")  # Need >=4.38.0
```

---

## 🌍 Real-World Context

**LoRA lets you create custom AI art styles:**

- **Artists** create their signature style as a LoRA
- **Game studios** train LoRAs for consistent game art
- **Brands** develop on-brand image generation
- **Researchers** adapt models for specific domains

With DGX Spark's 128GB memory, you can train LoRAs comfortably at full precision!

---

## 🧒 ELI5: What is LoRA?

> **Imagine you have a master artist (SDXL) who can paint anything.**
>
> Instead of retraining them completely (expensive!), you give them a small
> "style guide" notebook (LoRA) that shows examples of a specific style.
>
> Now when they paint, they reference the notebook to add that style!
>
> **Benefits:**
> - The notebook is tiny (10-100MB vs 7GB for the full model)
> - You can swap notebooks (styles) instantly
> - You can combine multiple notebooks
> - The original skills aren't forgotten

### How LoRA Works (Technical)

```
Original Model:      With LoRA:
                     
   W                 W + ΔW
   │                  │
   │                  │  where ΔW = A × B
   ▼                  ▼  (low-rank matrices)
[Input] ──► [Output]  [Input] ──► [Output + Style Shift]

- W: Original weights (frozen, not trained)
- A, B: Small trainable matrices (rank 4-32)
- ΔW = A × B: The "style adjustment"
```

LoRA trains only A and B, typically around 1% of the U-Net's parameters or less, depending on the rank and the target modules!

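To make the saving concrete, the sketch below counts the weights in one full projection matrix and in its rank-r LoRA factors. The 1280×1280 layer size is an illustrative assumption, and `lora_param_count` and `full_param_count` are hypothetical helpers, not library functions. The per-layer fraction it prints is higher than the whole-model fraction, because most U-Net weights (convolutions, feed-forward layers) receive no adapter at all.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in the two low-rank factors: (d_out x rank) and (rank x d_in)."""
    return rank * (d_in + d_out)


def full_param_count(d_in: int, d_out: int) -> int:
    """Parameters in the frozen full weight matrix W (d_out x d_in)."""
    return d_in * d_out


# Assumed layer size, for illustration only
d_in, d_out = 1280, 1280

for rank in (4, 16, 32):
    lora = lora_param_count(d_in, d_out, rank)
    full = full_param_count(d_in, d_out)
    print(f"rank={rank:>2}: {lora:>7,} trainable vs {full:,} frozen "
          f"({100 * lora / full:.2f}% of this layer)")
```
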
---

## Part 1: Setting Up

In [None]:
# Core imports
import torch
import gc
import time
import os
from pathlib import Path

# Diffusers and PEFT
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
from diffusers.utils import load_image
from peft import LoraConfig

# Data handling
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Visualization
import matplotlib.pyplot as plt
import numpy as np

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Memory: {mem:.1f} GB")
    
# Set random seed
torch.manual_seed(42)

In [None]:
# Helper functions
def get_memory_usage():
    """Get current GPU memory usage."""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        return f"{allocated:.2f}GB"
    return "N/A"

def show_images_grid(images, titles=None, ncols=4, figsize=(16, 4)):
    """Display images in a grid."""
    n = len(images)
    nrows = (n + ncols - 1) // ncols
    
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    axes = axes.flatten() if nrows > 1 else [axes] if ncols == 1 else axes
    
    for i, ax in enumerate(axes):
        if i < n:
            ax.imshow(images[i])
            if titles and i < len(titles):
                ax.set_title(titles[i], fontsize=10)
        ax.axis('off')
    
    plt.tight_layout()
    plt.show()

print("Helper functions ready!")

---

## Part 2: Understanding LoRA Configuration

Before training, let's understand the key LoRA parameters.

In [None]:
# LoRA configuration explanation
print("LoRA Configuration Parameters:")
print("=" * 50)
print("""
r (rank): How much capacity the LoRA has
  - 4:   Minimal capacity, subtle changes
  - 16:  Good balance (recommended)
  - 32:  More capacity, stronger styles
  - 64+: Maximum capacity, can overfit

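lora_alpha: Scaling factor
  - Usually set to 2× rank (e.g., r=16, alpha=32)
  - The update is scaled by alpha / r, so higher alpha = stronger effect

target_modules: Which layers to adapt
  - For SDXL U-Net attention (matches both self- and cross-attention):
    ["to_q", "to_k", "to_v", "to_out.0"]
  - "add_k_proj" / "add_v_proj" exist only in attention blocks with
    added key/value projections; SDXL's U-Net does not appear to use them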

lora_dropout: Regularization
  - 0.0: No dropout (can overfit)
  - 0.05-0.1: Light regularization (recommended)
""")

# Example configuration
example_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    lora_dropout=0.05,
    bias="none",
)

print("\nExample LoRA Config:")
print(f"  Rank: {example_config.r}")
print(f"  Alpha: {example_config.lora_alpha}")
print(f"  Target modules: {example_config.target_modules}")
print(f"  Dropout: {example_config.lora_dropout}")

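In standard LoRA as implemented by PEFT (without rank-stabilized scaling), the low-rank update is multiplied by `lora_alpha / r`, which is why alpha is usually chosen relative to the rank. The sketch below uses a hypothetical `effective_scale` helper to show that keeping alpha at 2× rank holds the multiplier constant, while raising the rank with a fixed alpha weakens the update.

```python
def effective_scale(lora_alpha: float, rank: int) -> float:
    """Multiplier applied to the LoRA update in standard LoRA: alpha / r."""
    return lora_alpha / rank


# Keeping alpha = 2 x rank holds the multiplier constant across ranks
for rank in (4, 16, 32):
    alpha = 2 * rank
    print(f"r={rank:>2}, alpha={alpha:>2} -> scale {effective_scale(alpha, rank):.2f}")

# Raising the rank with a fixed alpha shrinks the multiplier
for rank in (4, 16, 32):
    print(f"r={rank:>2}, alpha=32 -> scale {effective_scale(32, rank):.2f}")
```
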
---

## Part 3: Preparing Training Data

For this example, we'll create a synthetic "art style" dataset.
In practice, you would use 10-50 images in your desired style.

In [None]:
# Create a simple dataset class
class StyleDataset(Dataset):
    """
    Dataset for LoRA training.
    
    Expected structure:
    data_dir/
        image1.jpg
        image1.txt  (caption)
        image2.jpg
        image2.txt
        ...
    """
    
    def __init__(self, data_dir, resolution=1024, center_crop=True):
        self.data_dir = Path(data_dir)
        self.resolution = resolution
        self.center_crop = center_crop
        
        # Find all images
        self.image_paths = []
        self.captions = []
        
        for ext in ['*.jpg', '*.jpeg', '*.png', '*.webp']:
            for img_path in self.data_dir.glob(ext):
                caption_path = img_path.with_suffix('.txt')
                if caption_path.exists():
                    with open(caption_path) as f:
                        caption = f.read().strip()
                    self.image_paths.append(img_path)
                    self.captions.append(caption)
        
        print(f"Found {len(self.image_paths)} images with captions")
        
        # Transforms
        self.transform = transforms.Compose([
            transforms.Resize(resolution, interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.CenterCrop(resolution) if center_crop else transforms.RandomCrop(resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # [-1, 1] range
        ])
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        image = self.transform(image)
        caption = self.captions[idx]
        
        return {
            'pixel_values': image,
            'caption': caption,
        }

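`StyleDataset` silently skips any image that lacks a matching `.txt` file, so a quick pre-flight check can catch missing captions before training. The helper below is a hypothetical sketch (`find_uncaptioned_images` is not part of the lab code). It covers the same four extensions as the dataset class, but matches them case-insensitively and also flags caption files that are empty.

```python
from pathlib import Path

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}


def find_uncaptioned_images(data_dir):
    """Return sorted image paths that have no non-empty sibling .txt caption."""
    missing = []
    for path in sorted(Path(data_dir).iterdir()):
        if path.suffix.lower() not in IMAGE_EXTENSIONS:
            continue  # Not an image: ignore captions, notes, etc.
        caption_path = path.with_suffix(".txt")
        if not caption_path.exists() or not caption_path.read_text().strip():
            missing.append(path)
    return missing


# Example usage (the directory name is an assumption):
# for path in find_uncaptioned_images("./sample_lora_data"):
#     print(f"Missing caption: {path.name}")
```
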
In [None]:
# Create sample training data for demonstration
# In practice, you would use your own style images

sample_data_dir = Path("./sample_lora_data")
sample_data_dir.mkdir(exist_ok=True)

# Create synthetic "style" images using base SDXL
# This is just for demonstration - normally you'd have real images

print("Creating sample training data...")

# Load base model for generating sample data
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.bfloat16,
    variant="fp16",
)
pipe = pipe.to(device)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# Generate sample images in a specific "style" (watercolor for demo)
style_prefix = "watercolor painting of"
subjects = [
    "a serene mountain landscape",
    "a beautiful flower garden",
    "a cozy cottage in the woods",
    "a peaceful lake at sunset",
    "rolling hills with wildflowers",
    "a charming village street",
    "autumn trees with golden leaves",
    "a rustic bridge over a stream",
]

for i, subject in enumerate(subjects):
    prompt = f"{style_prefix} {subject}, soft colors, artistic, delicate brushstrokes"
    generator = torch.Generator(device=device).manual_seed(42 + i)
    
    image = pipe(
        prompt=prompt,
        num_inference_steps=25,
        generator=generator,
    ).images[0]
    
    # Save image and caption
    image_path = sample_data_dir / f"sample_{i:02d}.jpg"
    caption_path = sample_data_dir / f"sample_{i:02d}.txt"
    
    image.save(image_path)
    with open(caption_path, 'w') as f:
        f.write(f"a watercolor painting of {subject}")
    
    print(f"  Created: {image_path.name}")

print(f"\n✅ Created {len(subjects)} sample training images")

In [None]:
# Display the sample training images
sample_images = [Image.open(f) for f in sorted(sample_data_dir.glob("*.jpg"))]
show_images_grid(sample_images[:8], ncols=4, figsize=(16, 8))
print("These are the sample images we'll train our LoRA on.")
print("The goal: Learn the 'watercolor' style to apply to any prompt!")

---

## Part 4: LoRA Training Loop

Now let's set up and run the training.

In [None]:
from peft import get_peft_model, LoraConfig
import torch.nn.functional as F
from tqdm.auto import tqdm

# Training configuration
config = {
    'learning_rate': 1e-4,
    'num_epochs': 50,  # Reduce for faster demo, increase for quality
    'batch_size': 1,   # Increase if memory allows
    'gradient_accumulation_steps': 4,
    'lora_rank': 16,
    'lora_alpha': 32,
    'output_dir': './lora_output',
}

print("Training Configuration:")
for k, v in config.items():
    print(f"  {k}: {v}")

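With gradient accumulation, the effective batch size and the number of optimizer updates follow directly from this configuration. The sketch below works through the arithmetic for an assumed dataset of 8 images, the size of the sample set created earlier. `training_schedule` is a hypothetical helper. It assumes the step counter restarts every epoch, so micro-batches left over at the end of an epoch do not trigger an update on their own.

```python
import math


def training_schedule(num_images, batch_size, accumulation_steps, num_epochs):
    """Summarize the batch arithmetic for a gradient-accumulation training loop."""
    batches_per_epoch = math.ceil(num_images / batch_size)
    # An optimizer step fires once every `accumulation_steps` micro-batches
    updates_per_epoch = batches_per_epoch // accumulation_steps
    return {
        "effective_batch_size": batch_size * accumulation_steps,
        "batches_per_epoch": batches_per_epoch,
        "optimizer_updates_per_epoch": updates_per_epoch,
        "total_optimizer_updates": updates_per_epoch * num_epochs,
    }


summary = training_schedule(num_images=8, batch_size=1,
                            accumulation_steps=4, num_epochs=50)
for name, value in summary.items():
    print(f"{name}: {value}")
```
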
In [None]:
# Create dataset and dataloader
dataset = StyleDataset(sample_data_dir, resolution=1024)
dataloader = DataLoader(
    dataset, 
    batch_size=config['batch_size'], 
    shuffle=True,
    num_workers=0,  # Set to 0 for notebook compatibility
)

print(f"Dataset size: {len(dataset)}")
print(f"Batches per epoch: {len(dataloader)}")

In [None]:
# Configure LoRA for the U-Net
lora_config = LoraConfig(
    r=config['lora_rank'],
    lora_alpha=config['lora_alpha'],
    target_modules=[
        # Attention projections; these names match both the self-attention
        # (attn1) and cross-attention (attn2) layers of the SDXL U-Net
        "to_q", "to_k", "to_v", "to_out.0",
    ],
    lora_dropout=0.05,
    bias="none",
)

# Add LoRA adapters to U-Net
pipe.unet.add_adapter(lora_config)

# Count trainable parameters
trainable = sum(p.numel() for p in pipe.unet.parameters() if p.requires_grad)
total = sum(p.numel() for p in pipe.unet.parameters())
print(f"\nLoRA adapter added!")
print(f"Trainable parameters: {trainable:,} ({100*trainable/total:.2f}%)")
print(f"Total U-Net parameters: {total:,}")

In [None]:
# Prepare for training
from diffusers import DDPMScheduler

# Use DDPM scheduler for training noise schedule
noise_scheduler = DDPMScheduler.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="scheduler"
)

# Freeze everything except LoRA
pipe.vae.requires_grad_(False)
pipe.text_encoder.requires_grad_(False)
pipe.text_encoder_2.requires_grad_(False)

# Move VAE to float32 for stability
pipe.vae.to(dtype=torch.float32)

# Optimizer
optimizer = torch.optim.AdamW(
    [p for p in pipe.unet.parameters() if p.requires_grad],
    lr=config['learning_rate'],
)

print("Training setup complete!")

In [None]:
# Training loop
print(f"\n🚀 Starting LoRA training...")
print(f"   Epochs: {config['num_epochs']}")
print(f"   Memory: {get_memory_usage()}")
print()

losses = []
pipe.unet.train()

for epoch in range(config['num_epochs']):
    epoch_loss = 0
    
    for step, batch in enumerate(dataloader):
        # Move to device
        pixel_values = batch['pixel_values'].to(device, dtype=torch.float32)
        captions = batch['caption']
        
        # Encode images to latent space
        with torch.no_grad():
            latents = pipe.vae.encode(pixel_values).latent_dist.sample()
            latents = latents * pipe.vae.config.scaling_factor
            latents = latents.to(dtype=torch.bfloat16)
        
        # Sample noise and timesteps
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, 
            (latents.shape[0],), device=device
        ).long()
        
        # Add noise
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        
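        # Get text embeddings (SDXL's encode_prompt returns four tensors;
        # the negative embeddings are unused without classifier-free guidance)
        with torch.no_grad():
            prompt_embeds, _, pooled_embeds, _ = pipe.encode_prompt(
                prompt=captions,
                device=device,
                num_images_per_prompt=1,
                do_classifier_free_guidance=False,
            )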
        
        # Add time embeddings for SDXL
        add_time_ids = pipe._get_add_time_ids(
            (1024, 1024),  # original_size
            (0, 0),        # crops_coords_top_left
            (1024, 1024),  # target_size
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=pipe.text_encoder_2.config.projection_dim,
        ).to(device)
        add_time_ids = add_time_ids.repeat(latents.shape[0], 1)
        
        # Predict noise
        added_cond_kwargs = {
            "text_embeds": pooled_embeds,
            "time_ids": add_time_ids,
        }
        
        noise_pred = pipe.unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs=added_cond_kwargs,
        ).sample
        
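        # Compute loss in float32 for numerical stability
        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
        epoch_loss += loss.item()  # Log the unscaled loss

        # Scale so accumulated gradients average over the micro-batches
        (loss / config['gradient_accumulation_steps']).backward()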
        
        # Gradient accumulation
        if (step + 1) % config['gradient_accumulation_steps'] == 0:
            torch.nn.utils.clip_grad_norm_(pipe.unet.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
    
    avg_loss = epoch_loss / len(dataloader)
    losses.append(avg_loss)
    
    if (epoch + 1) % 10 == 0 or epoch == 0:
        print(f"Epoch {epoch+1}/{config['num_epochs']} - Loss: {avg_loss:.4f}")

print(f"\n✅ Training complete!")
print(f"Final loss: {losses[-1]:.4f}")

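The `add_noise` call in the loop mixes each clean latent with Gaussian noise using timestep-dependent weights: `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`. The sketch below recomputes those two weights in plain Python. The schedule values (a `scaled_linear` schedule with `beta_start=0.00085`, `beta_end=0.012`, and 1000 steps) are assumptions based on the usual Stable Diffusion defaults, so check `noise_scheduler.config` for the values actually in use. `noise_mix_weights` is a hypothetical helper.

```python
import math


def noise_mix_weights(num_steps=1000, beta_start=0.00085, beta_end=0.012):
    """Return per-timestep (signal_weight, noise_weight) for a scaled-linear schedule."""
    weights = []
    alpha_bar = 1.0
    sqrt_start, sqrt_end = math.sqrt(beta_start), math.sqrt(beta_end)
    for i in range(num_steps):
        # scaled_linear: betas are spaced linearly in sqrt-space, then squared
        frac = i / (num_steps - 1)
        beta = (sqrt_start + frac * (sqrt_end - sqrt_start)) ** 2
        alpha_bar *= 1.0 - beta  # Cumulative product of (1 - beta)
        weights.append((math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)))
    return weights


weights = noise_mix_weights()
for t in (0, 250, 500, 999):
    signal, noise = weights[t]
    print(f"t={t:>3}: signal weight {signal:.3f}, noise weight {noise:.3f}")
```
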
In [None]:
# Plot training loss
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('LoRA Training Loss')
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Save the trained LoRA
output_dir = Path(config['output_dir'])
output_dir.mkdir(exist_ok=True)

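# Save LoRA weights: extract the PEFT adapter state dict and write it in the
# diffusers LoRA format so load_lora_weights() can read it back.
# (save_attn_procs may not capture PEFT-based adapters on some diffusers versions.)
from peft.utils import get_peft_model_state_dict
from diffusers.utils import convert_state_dict_to_diffusers

unet_lora_state_dict = convert_state_dict_to_diffusers(
    get_peft_model_state_dict(pipe.unet)
)
StableDiffusionXLPipeline.save_lora_weights(
    save_directory=output_dir / "watercolor_lora",
    unet_lora_layers=unet_lora_state_dict,
)
print(f"\n💾 LoRA saved to {output_dir / 'watercolor_lora'}")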

# Check file sizes
for f in (output_dir / "watercolor_lora").iterdir():
    size = f.stat().st_size / 1e6
    print(f"   {f.name}: {size:.2f} MB")

---

## Part 5: Testing the Trained LoRA

In [None]:
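# Set to eval mode
pipe.unet.eval()

# The VAE was moved to float32 for training; cast it back so it matches
# the bfloat16 latents the pipeline produces at inference time
pipe.vae.to(dtype=torch.bfloat16)

# Use the DPM-Solver++ scheduler for inference
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)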

# Test prompts - subjects NOT in training data
test_prompts = [
    "a watercolor painting of a cat sleeping on a windowsill",
    "a watercolor painting of a bustling city street",
    "a watercolor painting of a spaceship in deep space",
    "a watercolor painting of a dragon flying over mountains",
]

print("Testing trained LoRA with new subjects...\n")

test_images = []
for prompt in test_prompts:
    print(f"Generating: {prompt[:50]}...")
    generator = torch.Generator(device=device).manual_seed(42)
    
    with torch.no_grad():
        image = pipe(
            prompt=prompt,
            num_inference_steps=25,
            generator=generator,
        ).images[0]
    
    test_images.append(image)

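# Display results (strip the shared style prefix to keep titles short)
prefix = "a watercolor painting of "
titles = [p[len(prefix):len(prefix) + 27] + "..." for p in test_prompts]
show_images_grid(test_images, titles, ncols=2, figsize=(14, 14))

print("\n🎨 The watercolor style is now applied to new subjects!")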

---

## Part 6: Adjusting LoRA Strength

In [None]:
# Test different LoRA strengths
prompt = "a watercolor painting of a majestic eagle soaring through clouds"
strengths = [0.0, 0.3, 0.5, 0.7, 1.0]

images = []
for strength in strengths:
    print(f"Generating with LoRA scale={strength}...")
    
    # Set LoRA scale (the U-Net method takes `weights`, not `adapter_weights`)
    pipe.unet.set_adapters(["default"], weights=[strength])
    
    generator = torch.Generator(device=device).manual_seed(42)
    with torch.no_grad():
        image = pipe(
            prompt=prompt,
            num_inference_steps=25,
            generator=generator,
        ).images[0]
    
    images.append(image)

# Reset to full strength
pipe.unet.set_adapters(["default"], weights=[1.0])

# Display
titles = [f"Scale: {s}" for s in strengths]
show_images_grid(images, titles, ncols=5, figsize=(20, 4))

print("\n📊 LoRA Scale Guide:")
print("  0.0: No LoRA effect (base model)")
print("  0.3-0.5: Subtle style influence")
print("  0.7-1.0: Strong style application")
print("  >1.0: Over-stylized (use carefully)")

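Conceptually, the adapter weight multiplies the low-rank update before it is added to the frozen layer's output, so a weight of 0 reproduces the base model and larger weights push further toward the learned style. The toy sketch below illustrates this with plain Python lists. The tiny matrices are made-up values, `matvec` and `lora_forward` are hypothetical helpers, and the alpha / r factor is folded into `scale` for brevity. Here `A` projects down to the rank and `B` projects back up; the order in which the two factors are written is a notational choice.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_forward(x, W, A, B, scale):
    """Compute W x + scale * B (A x) without forming the full update matrix."""
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))  # A projects down to rank r, B projects back up
    return [b + scale * d for b, d in zip(base, delta)]


# Toy 2x2 layer with a rank-1 adapter (values chosen for illustration)
W = [[1.0, 0.0], [0.0, 1.0]]   # Frozen weights
A = [[1.0, 1.0]]               # rank x d_in
B = [[0.5], [-0.5]]            # d_out x rank
x = [2.0, 4.0]

for scale in (0.0, 0.5, 1.0):
    print(f"scale={scale}: {lora_forward(x, W, A, B, scale)}")
```
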
---

## Part 7: Loading Pre-trained LoRAs from CivitAI

You can also use community-created LoRAs!

In [None]:
# Example of loading a LoRA from file
print("Loading a LoRA from file:")
print("""
# Download a LoRA from CivitAI or Hugging Face
# Then load it like this:

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.bfloat16,
)

# Load LoRA weights
pipe.load_lora_weights("./my_lora.safetensors")

# Generate with LoRA
image = pipe(
    prompt="...",
    cross_attention_kwargs={"scale": 0.8},  # LoRA strength
).images[0]

# Unload LoRA when done
pipe.unload_lora_weights()
""")

print("💡 Popular LoRA sources:")
print("  - https://civitai.com (largest community)")
print("  - https://huggingface.co/models (official)")

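Several LoRAs can also be active at once. A minimal sketch, assuming a diffusers pipeline with the PEFT backend available: each file is loaded under its own adapter name via `load_lora_weights(..., adapter_name=...)`, and the pipeline-level `set_adapters` then assigns a weight to each. The helper name `apply_lora_mix`, the adapter names, and the file paths are all hypothetical.

```python
def apply_lora_mix(pipe, loras):
    """Load several LoRAs into `pipe` and activate them with per-adapter weights.

    `loras` maps an adapter name to a (path, weight) tuple.
    """
    names, weights = [], []
    for name, (path, weight) in loras.items():
        pipe.load_lora_weights(path, adapter_name=name)
        names.append(name)
        weights.append(weight)
    # Activate all loaded adapters together, each with its own strength
    pipe.set_adapters(names, adapter_weights=weights)
    return names, weights


# Example usage (paths are placeholders, not files shipped with this lab):
# apply_lora_mix(pipe, {
#     "watercolor": ("./lora_output/watercolor_lora", 0.8),
#     "second_style": ("./my_lora.safetensors", 0.4),
# })
```
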
---

## ⚠️ Common Mistakes

### Mistake 1: Too High Learning Rate

```python
# ❌ Wrong: Will cause training instability
learning_rate = 1e-3

# ✅ Right: Start lower
learning_rate = 1e-4  # Good starting point
```

### Mistake 2: Too Few Training Images

```python
# ❌ Wrong: Will overfit to specific images
num_training_images = 3

# ✅ Right: Need variety
num_training_images = 30  # Aim for 10-50 images with varied subjects
```

### Mistake 3: Poor Captions

```python
# ❌ Wrong: Generic captions
caption = "a picture"

# ✅ Right: Descriptive with trigger word
caption = "a watercolor painting of a mountain landscape, soft colors"
```

---

## 🎉 Checkpoint

You've learned:
- ✅ How LoRA adapters work (low-rank updates)
- ✅ Preparing datasets for style training
- ✅ Training a custom SDXL LoRA
- ✅ Adjusting LoRA strength at inference
- ✅ Loading pre-trained LoRAs

---

## 🧹 Cleanup

In [None]:
# Clean up
del pipe
gc.collect()
torch.cuda.empty_cache()
print("GPU memory cleared!")

---

## Next Steps

Proceed to **Lab 2.6.6: Image Generation Pipeline** to build a complete end-to-end generation system!