# Inference with Trained LoRA
This notebook loads a trained LoRA adapter and uses it with the Stable Diffusion pipeline to generate images based on prompts.

**Key steps:**
1. Load the base Stable Diffusion model.
2. Load the trained LoRA adapter weights.
3. Sample a random row from the dataset metadata and build a text prompt from it.
4. Generate and display the resulting image.

---

In [None]:
# Essential Imports
import gc
import random
import pandas as pd
from pathlib import Path
from typing import Any, Dict, Optional
import matplotlib.pyplot as plt
from PIL import Image as PILImage
import torch
from diffusers import StableDiffusionPipeline
from transformers import CLIPTokenizer
from peft import LoraConfig

In [None]:
# --- Setup and Path Configuration ---
# This cell defines the filesystem locations used by the rest of the notebook
# and eagerly loads the training metadata CSV and the CLIP tokenizer.
#
# 💡 NOTE: If you're running this in Colab, you need to mount your drive first!
# from google.colab import drive
# drive.mount('/content/drive')
# project_folder = Path('/content/drive/MyDrive/My_Cool_Project')
# dataset_root = project_folder / 'datasets/appa-real-dataset_v2'
# lora_checkpoint_base_dir = project_folder / 'lora_training_runs'

# Local paths for demonstration
# (relative to the current working directory of the notebook kernel)
dataset_root: Path = Path('./datasets/appa-real-dataset_v2')
lora_checkpoint_base_dir: Path = Path("./lora_training_runs")

# Define paths to your metadata and image data
# NOTE(review): ds_train is not referenced by any other cell visible here.
labels_md_train = dataset_root / 'labels_metadata_train.csv'
ds_train = dataset_root / 'train_data'

# Load the dataset metadata to generate a sample prompt
# build_prompt() reads the 'age', 'gender', 'ethnicity' and 'emotion' columns,
# so the CSV must provide them. pd.read_csv raises if the file is missing.
df_md_train = pd.read_csv(labels_md_train)
# NOTE(review): the tokenizer is loaded but not used in the cells visible
# here - presumably kept for prompt inspection; confirm before removing.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Helper function to build a prompt from a metadata row
def build_prompt(row: pd.Series) -> str:
    """Compose the text prompt describing the person in one metadata row.

    Args:
        row: A mapping-like metadata record providing the 'age', 'gender',
            'ethnicity' and 'emotion' entries.

    Returns:
        A sentence of the form "A <age> years old <ethnicity> <gender>
        <expression phrase>". Emotion labels that are not in the lookup
        table fall back to the generic phrase "with an expression".
    """
    years = int(row['age'])
    gender = row['gender']
    ethnicity = row['ethnicity']

    # Known emotion labels and the phrase each one contributes to the prompt.
    phrase_by_emotion = {
        'neutral': "with a neutral expression",
        'happy': "smiling happily",
        'slightlyhappy': "smiling slightly",
        'other': "showing a subtle emotion"
    }
    emotion_label = row['emotion']
    if emotion_label in phrase_by_emotion:
        expression = phrase_by_emotion[emotion_label]
    else:
        expression = "with an expression"

    return f"A {years} years old {ethnicity} {gender} {expression}"

In [None]:
# --------------------------------------------------------------------------------------------------
## Core Inference Functions
# --------------------------------------------------------------------------------------------------

def load_lora_for_inference(
    lora_checkpoint_path: Path,
    device: str = 'cuda'
) -> Optional[StableDiffusionPipeline]:
    """
    Loads the Stable Diffusion pipeline and applies the trained LoRA weights for inference.

    Args:
        lora_checkpoint_path (Path): The path to the saved LoRA adapter directory.
        device (str): Torch device string to load the model onto
            (e.g. 'cuda', 'cuda:0' or 'cpu').

    Returns:
        Optional[StableDiffusionPipeline]: The configured pipeline, or None if
        ``lora_checkpoint_path`` does not exist. Errors raised while loading
        the base model or the adapter are not caught and propagate to the caller.
    """
    if not lora_checkpoint_path.exists():
        print(f"❌ Error: LoRA checkpoint path not found at {lora_checkpoint_path}.")
        return None

    # Half precision is only safe on accelerators; many CPU kernels do not
    # support float16, so fall back to float32 when running on the CPU.
    runs_on_cpu = torch.device(device).type == 'cpu'
    weights_dtype = torch.float32 if runs_on_cpu else torch.float16

    print("🚀 Loading base Stable Diffusion model...")
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=weights_dtype
    )

    # 💡 IMPORTANT: Load the LoRA weights into the pipeline's UNet.
    # NOTE(review): assumes the adapter was saved in a format that
    # `load_attn_procs` understands - confirm against the training notebook.
    pipe.unet.load_attn_procs(lora_checkpoint_path)

    # Move the entire pipeline to the specified device
    pipe = pipe.to(device)

    # Disable the safety checker as it can be resource-intensive and is not
    # typically needed for fine-tuned models if you trust the dataset.
    # The replacement passes images through and flags none of them.
    pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))

    print(f"✅ Base model loaded and LoRA weights applied from {lora_checkpoint_path}")
    return pipe

def generate_image_from_prompt(
    pipe: StableDiffusionPipeline,
    prompt: str,
    device: str,
    *,
    num_inference_steps: int = 75,
    guidance_scale: float = 7.5,
    seed: Optional[int] = 42
) -> PILImage.Image:
    """Generates an image from a prompt using the given pipeline.

    Args:
        pipe: The pipeline to run; its UNet is switched to evaluation mode.
        prompt: Text prompt passed to the pipeline.
        device: Torch device string the pipeline lives on
            (e.g. 'cuda', 'cuda:0' or 'cpu').
        num_inference_steps: Number of denoising steps (default 75).
        guidance_scale: Guidance scale passed to the pipeline (default 7.5).
        seed: Seed for the random generator; the default of 42 makes repeated
            calls reproducible. Pass None to let the pipeline draw its own noise.

    Returns:
        The first image of the pipeline result.
    """
    # Ensure the model is in evaluation mode
    pipe.unet.eval()

    # autocast expects a bare device type ('cuda', 'cpu'), not an indexed
    # device string such as 'cuda:0', so normalise it first.
    device_type = torch.device(device).type

    # A seeded generator makes the sampled noise reproducible.
    generator = None
    if seed is not None:
        generator = torch.Generator(device=device).manual_seed(seed)

    with torch.no_grad(), torch.amp.autocast(device_type=device_type):
        result = pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        )
    return result.images[0]

def unload_model(pipe: StableDiffusionPipeline) -> None:
    """Unloads the pipeline and clears GPU memory.

    Note that ``del pipe`` only removes this function's local name. The
    memory can be reclaimed only once the caller has also dropped every
    reference it holds to the pipeline (e.g. ``pipe = None``).

    Args:
        pipe: The pipeline that is no longer needed.
    """
    print("🔻 Unloading pipeline to free GPU memory...")
    del pipe
    # Collect garbage first so tensors kept alive only by reference cycles
    # are freed before the CUDA caching allocator is asked to release blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("✅ Model unloaded and GPU cache cleared.")

In [None]:
######################################
##### SET RUN_ID TO GENERATE FROM ######
######################################

# You have two options:
# 1. Load the best-performing checkpoint
# RESUME_RUN_ID = 'af641a_run_20250805-173847'
# lora_path = lora_checkpoint_base_dir / RESUME_RUN_ID / "lora_checkpoints/best_lora_adapter"

# 2. Load a checkpoint from a specific epoch
RESUME_RUN_ID = 'af641a_run_20250805-173847'
RESUME_EPOCH = 10
lora_path = lora_checkpoint_base_dir / RESUME_RUN_ID / f"lora_checkpoints/lora_adapters_epoch_{RESUME_EPOCH}"

# Fall back to the CPU when no CUDA device is visible.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Load LoRA Weights and Model ----
pipe: Optional[StableDiffusionPipeline] = load_lora_for_inference(lora_path, device)

if pipe is None:
    print("❌ Failed to load the model. Please check the paths.")
else:
    try:
        # ---- Access Dataset & Sample Prompt ----
        # Pick one random metadata row directly instead of converting the
        # whole DataFrame to a list of dicts just to choose a single element.
        sample_row: pd.Series = df_md_train.iloc[random.randrange(len(df_md_train))]
        sample_prompt: str = build_prompt(sample_row)

        print(f"\n🔹 Generating image with prompt: {sample_prompt}")

        # ---- Generate Image ----
        img: PILImage.Image = generate_image_from_prompt(pipe, sample_prompt, device)

        # ---- Display Image ----
        plt.figure(figsize=(10, 10))
        plt.imshow(img)
        plt.axis('off')
        plt.title(f"Generated Image\nPrompt: {sample_prompt}", fontsize=12, wrap=True)
        plt.show()
    finally:
        # Unload model after use to free GPU memory, even if generation or
        # plotting failed part-way through.
        unload_model(pipe)
        # unload_model can only delete its own local name; this cell's
        # reference must be dropped as well before the memory can be
        # reclaimed, so clear it and collect once more.
        pipe = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()