In [None]:
# peft: used by the LoRA cell (LoraConfig / get_peft_model); torchvision + safetensors: used by the training script.
pip install diffusers transformers accelerate datasets bitsandbytes peft torchvision safetensors


In [None]:
from diffusers import StableDiffusionXLPipeline
from peft import LoraConfig, get_peft_model


In [None]:

# Load SDXL
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

# Apply LoRA (rank 4, alpha 16) to the query/value projections of the UNet's
# attention layers. get_peft_model wraps the UNet so that only the LoRA adapter
# weights are trainable.
lora_config = LoraConfig(r=4, lora_alpha=16, target_modules=["to_q", "to_v"])
pipe.unet = get_peft_model(pipe.unet, lora_config)

# Training loop -- SKETCH ONLY, kept as comments so this cell actually runs.
# As live code it always failed: `dataloader` and `optimizer` are never defined
# in this notebook, and StableDiffusionXLPipeline has no `train_step` method.
# A working loop (VAE encode -> add noise -> UNet noise prediction -> MSE loss)
# is implemented by `train()` in the training-script cell below. A real loop
# must also clear gradients every step (optimizer.zero_grad()).
#
# for batch in dataloader:
#     text, image = batch
#     loss = pipe.train_step(text, image)
#     loss.backward()
#     optimizer.step()
#     optimizer.zero_grad()


In [None]:
# Attach trained LoRA weights (replace the placeholder path) and render one image.
# NOTE(review): `pipe` comes from an earlier cell; if its UNet was wrapped with
# get_peft_model there, confirm load_lora_weights accepts that wrapped UNet.
pipe.load_lora_weights("path/to/lora")
image = pipe(prompt="young male, short curly hair, medium-dark skin, oval face").images[0]


In [None]:
#!/usr/bin/env python3
"""
Train SDXL (Stable Diffusion XL 1.0-base) with LoRA on a text–image dataset.

Requirements:
    pip install torch torchvision pillow diffusers transformers accelerate datasets safetensors

Data format (simple):
    data_dir/
      images/
        000001.jpg
        000002.jpg
        ...
      prompts.jsonl    # one JSON per line: {"image": "000001.jpg", "prompt": "compressed face description ..."}

Usage (example):
    python train_sdxl_lora.py \
        --data_dir /path/to/data \
        --output_dir ./sdxl_lora_out \
        --resolution 512 \
        --train_batch_size 4 \
        --gradient_accumulation_steps 2 \
        --learning_rate 5e-5 \
        --max_train_steps 5000 \
        --checkpointing_steps 500 \
        --seed 42

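Notes:
- This script follows a standard diffusion training loop for SDXL (a noise-prediction MSE loss) and applies LoRA by replacing the UNet's attention processors.
- It depends on `LoRAAttnProcessor` / `LoRAAttnProcessor2_0` from `diffusers.models.attention_processor`. Newer diffusers releases deprecate these classes, and later ones no longer ship working versions, so pin a diffusers release that still implements them.
- Prompts should already be compressed to <=77 tokens (the CLIP tokenizer limit). Longer prompts are truncated during encoding, so keep them short for consistent conditioning.
- Only the LoRA weights are saved, keyed by UNet parameter path rather than the diffusers/kohya LoRA format. Load them with `load_pipe_with_lora` (defined in this script); the LoRA rank injected at load time must match the rank used in training. `pipe.load_lora_weights` is not expected to recognize these keys.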
"""

import os
import math
import json
import random
import argparse
from dataclasses import dataclass
from typing import List, Dict, Any

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

from accelerate import Accelerator
from accelerate.utils import set_seed

from diffusers import (
    StableDiffusionXLPipeline,
    AutoencoderKL,
    DDIMScheduler,
)
from diffusers.utils import logging
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection

from safetensors.torch import save_file

# Module logger from diffusers.utils.logging (imported above), not stdlib logging.
# NOTE(review): diffusers' default verbosity is WARNING, so logger.info messages
# (e.g. the "Saved UNet LoRA weights" message) may not appear unless verbosity is raised.
logger = logging.get_logger(__name__)


# -------------------------
# Dataset
# -------------------------

class TextImageDataset(Dataset):
    """Paired image/prompt dataset read from a directory on disk.

    Layout expected under ``data_dir``:
        images/          image files
        prompts.jsonl    one JSON object per line: {"image": <file name>, "prompt": <text>}

    Blank lines in ``prompts.jsonl`` are skipped, as are entries whose image
    file does not exist. Each item is ``{"pixel_values": tensor, "prompt": str}``:
    the image is center-cropped to a square, resized to ``resolution`` x
    ``resolution`` (bicubic), converted to a tensor in [0, 1] and then
    normalized to [-1, 1].
    """

    def __init__(self, data_dir: str, resolution: int):
        self.images_dir = os.path.join(data_dir, "images")
        prompts_path = os.path.join(data_dir, "prompts.jsonl")
        assert os.path.exists(self.images_dir), f"Missing images dir: {self.images_dir}"
        assert os.path.exists(prompts_path), f"Missing prompts.jsonl: {prompts_path}"

        self.items: List[Dict[str, Any]] = []
        with open(prompts_path, "r", encoding="utf-8") as f:
            for line in f:
                # Tolerate blank/whitespace-only lines (e.g. a trailing empty
                # line); json.loads would otherwise raise "Expecting value".
                if not line.strip():
                    continue
                obj = json.loads(line)
                img_path = os.path.join(self.images_dir, obj["image"])
                if os.path.exists(img_path):
                    self.items.append({"image": img_path, "prompt": obj["prompt"]})

        assert len(self.items) > 0, "No valid items found."

        self.transform = transforms.Compose([
            CenterSquareCrop(),
            transforms.Resize((resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            # Single-element mean/std broadcast over all 3 channels: [0, 1] -> [-1, 1].
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx: int):
        item = self.items[idx]
        # The context manager closes the file handle promptly; convert() returns
        # a new, fully loaded image that stays valid after the with-block.
        with Image.open(item["image"]) as img:
            image = img.convert("RGB")
        return {"pixel_values": self.transform(image), "prompt": item["prompt"]}


class CenterSquareCrop:
    """Crop a PIL image to its largest centered square.

    The square's side equals the shorter image dimension. When the leftover
    margin is odd, the extra pixel is removed from the right/bottom edge,
    because the left/top offset is rounded down.
    """

    def __call__(self, img: Image.Image) -> Image.Image:
        width, height = img.size
        edge = width if width < height else height
        x0 = (width - edge) // 2
        y0 = (height - edge) // 2
        box = (x0, y0, x0 + edge, y0 + edge)
        return img.crop(box)


# -------------------------
# LoRA utilities
# -------------------------

def is_torch2_available():
    """Report whether the installed PyTorch has major version 2 or newer.

    Used to choose between the SDPA-based and the classic LoRA attention processor.
    """
    if not hasattr(torch, "__version__"):
        return False
    major = torch.__version__.split(".", 1)[0]
    return int(major) >= 2

def inject_lora_unet(unet, r: int = 4):
    """
    Replace every attention processor in ``unet`` with a LoRA-enabled one.

    Each module exposing ``set_processor`` gets a new ``LoRAAttnProcessor2_0``
    (torch >= 2) or ``LoRAAttnProcessor`` of rank ``r``. These processors do
    NOT infer layer sizes, so the sizes are read from the module's own
    projections:
      - ``hidden_size`` comes from ``to_q.in_features``;
      - ``cross_attention_dim`` comes from ``to_k.in_features``. This equals
        hidden_size for self-attention and is the text-embedding width for
        cross-attention.

    NOTE(review): these processor classes are only functional in older
    diffusers releases; pin diffusers to a version that still implements them.

    Args:
        unet: diffusers UNet, modified in place.
        r: LoRA rank.

    Returns:
        The same ``unet`` object.
    """
    lora_cls = LoRAAttnProcessor2_0 if is_torch2_available() else LoRAAttnProcessor
    # Snapshot the attention modules first: assigning an nn.Module processor
    # registers it as a submodule, which would otherwise be walked mid-iteration.
    attn_modules = [m for _, m in unet.named_modules() if hasattr(m, "set_processor")]
    for module in attn_modules:
        hidden_size = module.to_q.in_features
        to_k = getattr(module, "to_k", None)
        cross_attention_dim = to_k.in_features if to_k is not None else None
        # The processors' rank keyword is `rank` (there is no `r` argument).
        module.set_processor(
            lora_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=r)
        )
    return unet

def collect_lora_parameters(unet):
    """
    Gather trainable parameters from every module that owns an attention ``processor``.

    A parameter is returned when it lives (directly or in a submodule) under a
    module whose ``processor`` attribute is not None, and has ``requires_grad=True``.
    The result contains only LoRA weights ONLY if the rest of the UNet was frozen
    beforehand; otherwise the base attention projections are included too.
    Order follows ``named_modules()`` and then each module's ``parameters()``.
    """
    hosts = (m for _, m in unet.named_modules() if getattr(m, "processor", None) is not None)
    return [p for host in hosts for p in host.parameters() if p.requires_grad]

def save_unet_lora(unet, save_path: str):
    """
    Save LoRA weights from UNet to a safetensors file.

    Collects every parameter with ``requires_grad=True`` that lives under a
    module whose ``processor`` attribute is set. Keys are
    ``"<module path>.<param path>"``, the same names ``unet.state_dict()`` uses,
    so the file can be restored with ``load_state_dict(..., strict=False)``.
    Only LoRA weights are captured if the base UNet was frozen beforehand.

    Args:
        unet: UNet with LoRA processors injected. Pass the unwrapped model
            (e.g. ``accelerator.unwrap_model(unet)``) so a DDP ``module.``
            prefix does not leak into the key names.
        save_path: output ``.safetensors`` path; its parent directory is
            created if missing.
    """
    state = {}
    for name, module in unet.named_modules():
        if getattr(module, "processor", None) is None:
            continue
        for pname, param in module.named_parameters():
            if param.requires_grad:
                # safetensors refuses to serialize non-contiguous tensors.
                state[f"{name}.{pname}"] = param.detach().cpu().contiguous()
    parent_dir = os.path.dirname(save_path)
    # dirname() is "" for a bare file name, and os.makedirs("") raises.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    save_file(state, save_path)
    logger.info("Saved UNet LoRA weights to: %s", save_path)


# -------------------------
# SDXL prompt encoding
# -------------------------

@dataclass()
class PromptBatch:
    """Container pairing a batch of prompts with their negative prompts.

    NOTE(review): not referenced elsewhere in the visible script.
    """

    # Positive prompt strings, one per sample.
    prompt: List[str]
    # Negative prompt strings; presumably aligned index-for-index with `prompt`.
    negative_prompt: List[str]


def encode_sdxl_prompts(
    pipe: StableDiffusionXLPipeline,
    prompts: List[str],
    device: torch.device,
    do_classifier_free_guidance: bool = True,
):
    """
    Encode prompts with SDXL's two text encoders via ``pipe.encode_prompt``.

    Args:
        pipe: loaded SDXL pipeline; its tokenizers and text encoders are used.
        prompts: batch of prompt strings.
        device: device the embeddings are produced on.
        do_classifier_free_guidance: when True (default, the previous behavior),
            also encode an empty negative prompt for each item. Pass False to
            skip that work when the negatives are not needed (e.g. training).

    Returns:
        The tuple returned by ``pipe.encode_prompt`` (NOT a dict):
        ``(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
        negative_pooled_prompt_embeds)``. With
        ``do_classifier_free_guidance=False`` diffusers leaves the negative
        entries unset (``None``).
    """
    negs = [""] * len(prompts) if do_classifier_free_guidance else None
    return pipe.encode_prompt(
        prompt=prompts,
        negative_prompt=negs,
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=do_classifier_free_guidance,
    )


# -------------------------
# Training
# -------------------------

def train(args):
    """Fine-tune SDXL's UNet with LoRA attention processors on ``args.data_dir``.

    The VAE, both text encoders and the base UNet weights are frozen. Only the
    injected LoRA parameters are optimized (AdamW) with an MSE loss against the
    scheduler's prediction target (noise for epsilon models, velocity for
    v-prediction). LoRA weights are written to ``args.output_dir`` every
    ``args.checkpointing_steps`` optimizer steps and once at the end.

    Args:
        args: Namespace produced by ``parse_args()``.

    Raises:
        ValueError: if the dataset yields no full batch, or the scheduler's
            prediction type is neither "epsilon" nor "v_prediction".
    """
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        # Without this, accumulate() syncs on every batch and
        # --gradient_accumulation_steps is silently ignored.
        gradient_accumulation_steps=args.gradient_accumulation_steps,
    )
    set_seed(args.seed)

    # Frozen weights are stored in fp16 only for fp16 runs; LoRA params are created in fp32.
    weight_dtype = torch.float16 if accelerator.mixed_precision == "fp16" else torch.float32

    # Load base SDXL pipeline
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=weight_dtype,
        use_safetensors=True,
    )
    pipe.set_progress_bar_config(disable=True)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)  # training noise scheduling
    noise_scheduler = pipe.scheduler

    vae: AutoencoderKL = pipe.vae
    text_encoder: CLIPTextModel = pipe.text_encoder
    text_encoder_2: CLIPTextModelWithProjection = pipe.text_encoder_2
    unet = pipe.unet

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    text_encoder_2.requires_grad_(False)
    # Freeze the base UNet BEFORE injecting LoRA so that only the freshly created
    # LoRA parameters keep requires_grad=True. Otherwise the optimizer (and the
    # saved checkpoints) would include the full attention projection weights.
    unet.requires_grad_(False)

    # Inject LoRA into UNet attention processors
    inject_lora_unet(unet, r=args.lora_rank)
    lora_params = collect_lora_parameters(unet)
    assert len(lora_params) > 0, "No LoRA parameters found in UNet."

    # Dataset / Dataloader
    dataset = TextImageDataset(args.data_dir, resolution=args.resolution)
    dl = DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True,
        pin_memory=True,
    )
    if len(dl) == 0:
        # With drop_last=True a dataset smaller than one batch yields nothing,
        # which would otherwise divide by zero when computing epochs below.
        raise ValueError(
            f"Dataset has {len(dataset)} items, fewer than train_batch_size="
            f"{args.train_batch_size}; no training batches can be formed."
        )

    optimizer = torch.optim.AdamW(lora_params, lr=args.learning_rate)

    # Prepare with accelerator
    unet, optimizer, dl = accelerator.prepare(unet, optimizer, dl)
    # SDXL's VAE is known to overflow in fp16 (NaN latents); keep it in fp32.
    vae.to(accelerator.device, dtype=torch.float32)
    text_encoder.to(accelerator.device)
    text_encoder_2.to(accelerator.device)

    # SDXL's VAE has its own latent scale (stored in its config), unlike SD1.x's 0.18215.
    vae_scale_factor = vae.config.scaling_factor

    prediction_type = noise_scheduler.config.get("prediction_type", "epsilon")
    if prediction_type not in ("epsilon", "v_prediction"):
        raise ValueError(f"Unsupported prediction_type: {prediction_type}")

    # SDXL micro-conditioning ids: (orig_h, orig_w, crop_top, crop_left, target_h, target_w).
    # Simplification: every training image is treated as an uncropped square
    # at the training resolution.
    res = args.resolution
    base_time_ids = torch.tensor(
        [[res, res, 0, 0, res, res]], dtype=weight_dtype, device=accelerator.device
    )

    # Training state
    global_step = 0
    num_update_steps_per_epoch = math.ceil(len(dl) / args.gradient_accumulation_steps)
    max_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    os.makedirs(args.output_dir, exist_ok=True)

    accelerator.print(f"Starting training: steps={args.max_train_steps}, epochs={max_epochs}")

    for epoch in range(max_epochs):
        unet.train()

        for batch in dl:
            with accelerator.accumulate(unet):
                with torch.no_grad():
                    # Images -> latents (fp32 VAE), then cast to the UNet's weight dtype.
                    pixel_values = batch["pixel_values"].to(accelerator.device, dtype=torch.float32)
                    latents = vae.encode(pixel_values).latent_dist.sample()
                    latents = (latents * vae_scale_factor).to(weight_dtype)

                    # Text conditioning from both frozen encoders. encode_prompt
                    # returns a 4-tuple; the negative-prompt entries are unused here.
                    prompt_embeds, _, pooled_prompt_embeds, _ = encode_sdxl_prompts(
                        pipe, batch["prompt"], accelerator.device
                    )

                # Sample timesteps and add noise
                bsz = latents.shape[0]
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
                )
                noise = torch.randn_like(latents)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # UNet forward: SDXL takes the pooled text embedding and the size/crop
                # ids through `added_cond_kwargs`.
                model_pred = unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=prompt_embeds,
                    added_cond_kwargs={
                        "text_embeds": pooled_prompt_embeds,
                        "time_ids": base_time_ids.repeat(bsz, 1),
                    },
                ).sample

                if prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    target = noise
                # Compute the loss in fp32 for numerical stability under fp16 weights.
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                global_step += 1

                # Checkpointing (LoRA weights only, written once by the main process)
                if args.checkpointing_steps > 0 and global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        lora_path = os.path.join(args.output_dir, f"unet_lora_step_{global_step}.safetensors")
                        save_unet_lora(accelerator.unwrap_model(unet), lora_path)
                        accelerator.print(
                            f"[epoch {epoch}] step {global_step} | loss={loss.detach().item():.4f} | saved {lora_path}"
                        )

            # Stop condition
            if global_step >= args.max_train_steps:
                break

        accelerator.print(f"Epoch {epoch+1}/{max_epochs} completed. Last loss: {loss.detach().item():.4f}")
        if global_step >= args.max_train_steps:
            break

    # Final save
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        final_lora_path = os.path.join(args.output_dir, "unet_lora_final.safetensors")
        save_unet_lora(accelerator.unwrap_model(unet), final_lora_path)
        accelerator.print(f"Training complete. Final LoRA saved to {final_lora_path}")


# -------------------------
# Inference helper (optional)
# -------------------------

def load_pipe_with_lora(lora_path: str, fp16: bool = True, lora_rank: int = 4) -> StableDiffusionXLPipeline:
    """Build an SDXL base pipeline and load LoRA weights written by ``save_unet_lora``.

    Args:
        lora_path: ``.safetensors`` file, or a ``.pt`` file holding a state dict,
            whose keys are UNet parameter names.
        fp16: load the pipeline in float16 when CUDA is available.
        lora_rank: LoRA rank used in training (``--lora_rank``). The injected
            processors must match the saved tensor shapes; a mismatch makes
            ``load_state_dict`` raise a size-mismatch RuntimeError.

    Returns:
        The pipeline, moved to CUDA if available, else CPU.

    Raises:
        ValueError: if none of the file's keys match a UNet parameter (which
            would otherwise silently leave the LoRA layers untrained).
    """
    dtype = torch.float16 if fp16 and torch.cuda.is_available() else torch.float32
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=dtype,
        use_safetensors=True,
    )
    # Re-inject LoRA processors with the training rank so parameter names/shapes match.
    inject_lora_unet(pipe.unet, r=lora_rank)

    if lora_path.endswith(".pt"):
        # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
        state = torch.load(lora_path, map_location="cpu")
    else:
        from safetensors.torch import load_file
        state = load_file(lora_path, device="cpu")

    # strict=False: the file holds only LoRA tensors, so every base weight is "missing".
    result = pipe.unet.load_state_dict(state, strict=False)
    loaded = set(state) - set(result.unexpected_keys)
    if not loaded:
        raise ValueError(
            f"No tensors from {lora_path} matched UNet parameters; check that the "
            "LoRA rank and diffusers version match those used for training."
        )
    if result.unexpected_keys:
        print("Ignored LoRA keys not present in the UNet:", result.unexpected_keys)
    print(f"Loaded {len(loaded)} LoRA tensors from {lora_path}")

    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
    pipe.set_progress_bar_config(disable=False)
    return pipe


# -------------------------
# CLI
# -------------------------

def parse_args(argv=None):
    """Parse command-line options for SDXL LoRA training.

    Args:
        argv: optional list of argument strings. ``None`` (default) parses
            ``sys.argv[1:]`` exactly as before. Passing a list lets notebooks
            and tests supply arguments explicitly, e.g.
            ``parse_args(["--data_dir", "data", "--output_dir", "out"])``.

    Returns:
        argparse.Namespace with the options defined below.
    """
    parser = argparse.ArgumentParser(description="Train SDXL with LoRA on text-image pairs.")
    parser.add_argument("--data_dir", type=str, required=True, help="Directory containing images/ and prompts.jsonl")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory to save LoRA checkpoints")
    parser.add_argument("--resolution", type=int, default=512, help="Input resolution (SDXL trained on 1024; we use 512)")
    parser.add_argument("--train_batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
    parser.add_argument("--max_train_steps", type=int, default=5000)
    parser.add_argument("--checkpointing_steps", type=int, default=500)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--lora_rank", type=int, default=4, help="LoRA rank for attention processors")
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args(argv)


# Script entry point: parse CLI flags, then run training.
# NOTE: inside a Jupyter/IPython kernel `__name__` is also "__main__", so running
# this cell makes argparse read the kernel's own command line and most likely
# exit (missing --data_dir/--output_dir). Run this file as a standalone script instead.
if __name__ == "__main__":
    args = parse_args()
    train(args)
