# SDXL Fine-tuning for Naruto Style Transfer using LoRA

This notebook demonstrates how to fine-tune Stable Diffusion XL on the Naruto dataset for style transfer using Low-Rank Adaptation (LoRA).

## What is LoRA?
LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning technique: the pretrained weights stay frozen, and only small low-rank update matrices injected into selected layers are trained, so just a small fraction of the parameters is learnable. This makes training feasible on consumer GPUs.

## Key Features:
- Uses SDXL base model (1024x1024 resolution)
- Loads Naruto anime caption dataset from HuggingFace
- Trains only LoRA adapters (memory efficient)
- Supports both UNet and text encoder fine-tuning
- Includes validation and checkpoint saving

## 1. Install Dependencies

In [None]:
# Install required packages
!pip install -q diffusers transformers accelerate peft datasets torch torchvision
!pip install -q huggingface-hub safetensors Pillow numpy
!pip install -q tqdm wandb

## 2. Import Libraries

In [None]:
import gc
import logging
import math
import os
import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from accelerate import Accelerator
from accelerate.utils import set_seed
from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from peft import LoraConfig, get_peft_model
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

# Configure root logging once for the whole notebook (timestamped, INFO and above).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Report the runtime environment so failures can be matched to versions/hardware.
cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 3. Configuration Setup

In [None]:
# Training configuration: every setting the rest of the notebook reads lives here.
config = dict(
    # --- Model paths ---
    pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
    pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix",
    # --- Dataset ---
    dataset_name="lambdalabs/naruto-blip-captions",
    caption_column="text",
    image_column="image",
    # --- Output ---
    output_dir="./sdxl-naruto-lora",
    # --- Image settings ---
    resolution=1024,
    center_crop=False,
    random_flip=True,
    # --- Training parameters ---
    train_batch_size=1,  # adjust based on GPU memory
    num_train_epochs=2,
    max_train_steps=None,  # None -> derived from num_train_epochs
    learning_rate=1e-4,
    lr_scheduler="constant",
    lr_warmup_steps=0,
    gradient_accumulation_steps=1,
    # --- LoRA parameters ---
    rank=64,
    lora_dropout=0.05,
    use_dora=False,  # True -> DoRA variant
    train_text_encoder=False,  # True may improve quality but needs more memory
    # --- Optimization ---
    mixed_precision="fp16",  # one of "no", "fp16", "bf16"
    enable_xformers_memory_efficient_attention=True,
    gradient_checkpointing=True,
    # --- Validation ---
    validation_prompt="naruto uzumaki, anime style, 4k, detailed",
    validation_epochs=1,
    num_validation_images=2,
    # --- Checkpointing ---
    checkpointing_steps=100,
    # --- Other ---
    seed=42,
    dataloader_num_workers=0,
)

# Make sure the output directory exists before anything is written to it.
os.makedirs(config["output_dir"], exist_ok=True)

print("Configuration:")
for setting_name, setting_value in config.items():
    print(f"  {setting_name}: {setting_value}")

## 4. Set Random Seeds for Reproducibility

In [None]:
# Seed all RNGs through accelerate's set_seed so runs are repeatable.
seed_value = config["seed"]
if seed_value is not None:
    set_seed(seed_value)
    print(f"Random seed set to {config['seed']}")

## 5. Load Dataset and Explore

In [None]:
# Download (or load from cache) the training split of the captioned dataset.
logger.info(f"Loading dataset: {config['dataset_name']}")
dataset = load_dataset(config["dataset_name"], split="train", streaming=False)

print(f"Dataset loaded with {len(dataset)} samples")
print(f"Dataset columns: {dataset.column_names}")

# Peek at the first row to confirm the configured column names are right.
first_row = dataset[0]
print("\nFirst sample:")
print(f"  Image: {first_row[config['image_column']]}")
print(f"  Caption: {first_row[config['caption_column']]}")

### Visualize Dataset Samples

In [None]:
# Show the first four image/caption pairs in a 2x2 grid and save the figure.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.ravel()

for sample_index, ax in enumerate(axes[:4]):
    row = dataset[sample_index]
    caption_text = row[config["caption_column"]]
    short_title = caption_text if len(caption_text) <= 50 else f"{caption_text[:50]}..."

    ax.imshow(row[config["image_column"]])
    ax.set_title(short_title, fontsize=10, wrap=True)
    ax.axis("off")

plt.tight_layout()
plt.savefig(os.path.join(config["output_dir"], "sample_images.png"), dpi=100, bbox_inches="tight")
plt.show()
print("Sample images saved!")

## 6. Create Custom Dataset Class

In [None]:
class NarutoDataset(Dataset):
    """Image/caption dataset that feeds SDXL LoRA training.

    Wraps an indexable split whose rows hold an image under ``image_column``
    and a caption under ``caption_column``. Each item is returned as::

        {"pixel_values": <transformed image tensor>, "input_ids": <raw caption>}

    Note: despite the key name, ``input_ids`` carries the untokenized caption.
    Tokenization happens later, in ``encode_prompt``.
    """

    def __init__(
        self,
        dataset,
        caption_column: str,
        image_column: str,
        size: int = 1024,
        center_crop: bool = False,
        random_flip: bool = False,
    ):
        self.dataset = dataset
        self.caption_column = caption_column
        self.image_column = image_column

        # Resize, crop to a size x size square (center or random), optionally
        # flip, then convert to a tensor and map [0, 1] -> [-1, 1].
        steps = [transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)]
        steps.append(transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size))
        if random_flip:
            steps.append(transforms.RandomHorizontalFlip(p=0.5))
        steps += [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        self.image_transforms = transforms.Compose(steps)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        picture = record[self.image_column]

        # Grayscale/RGBA/palette images are normalized to 3-channel RGB first.
        if picture.mode != "RGB":
            picture = picture.convert("RGB")

        return {
            "pixel_values": self.image_transforms(picture),
            "input_ids": record[self.caption_column],
        }


def collate_fn(examples):
    """Merge dataset items into one batch.

    Image tensors are stacked along a new leading batch dimension. Captions stay
    a plain Python list of strings (under the ``input_ids`` key) for later tokenization.
    """
    images, captions = [], []
    for item in examples:
        images.append(item["pixel_values"])
        captions.append(item["input_ids"])

    return {"pixel_values": torch.stack(images), "input_ids": captions}

print("Dataset class defined!")

## 7. Create DataLoader

In [None]:
# Wrap the raw split with the training transforms.
train_dataset = NarutoDataset(
    dataset,
    caption_column=config["caption_column"],
    image_column=config["image_column"],
    size=config["resolution"],
    center_crop=config["center_crop"],
    random_flip=config["random_flip"],
)
print(f"Dataset created with {len(train_dataset)} samples")

# Shuffled batches; captions stay strings thanks to collate_fn.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config["train_batch_size"],
    shuffle=True,
    num_workers=config["dataloader_num_workers"],
    collate_fn=collate_fn,
)
print(f"DataLoader created with {len(train_dataloader)} batches")

# Smoke-test: pull one batch and report its shapes.
sample_batch = next(iter(train_dataloader))
sample_captions = sample_batch["input_ids"]
print("\nSample batch shapes:")
print(f"  pixel_values: {sample_batch['pixel_values'].shape}")
print(f"  input_ids (captions): {len(sample_captions)}")
print(f"  First caption: {sample_captions[0]}")

## 8. Load Pre-trained Models

In [None]:
# Determine the dtype for the frozen model weights.
# "bf16" is a documented config option, so map it explicitly instead of silently
# falling back to float32.
precision_to_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}
weight_dtype = precision_to_dtype.get(config["mixed_precision"], torch.float32)

logger.info("Loading pre-trained models...")

# Load VAE
vae = AutoencoderKL.from_pretrained(
    config["pretrained_vae_model_name_or_path"],
    torch_dtype=weight_dtype,
)
print("✓ VAE loaded")

# Load text encoders.
# SDXL's second encoder (OpenCLIP bigG) is a CLIPTextModelWithProjection. Its
# first output is the projected, pooled embedding [B, D] that the UNet expects
# as `text_embeds`. Loading it as a plain CLIPTextModel drops the projection,
# and output[0] becomes the per-token hidden states [B, T, D].
text_encoder_one = CLIPTextModel.from_pretrained(
    config["pretrained_model_name_or_path"],
    subfolder="text_encoder",
    torch_dtype=weight_dtype,
)
print("✓ Text Encoder 1 loaded")

text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    config["pretrained_model_name_or_path"],
    subfolder="text_encoder_2",
    torch_dtype=weight_dtype,
)
print("✓ Text Encoder 2 loaded")

# Load tokenizers
tokenizer_one = CLIPTokenizer.from_pretrained(
    config["pretrained_model_name_or_path"],
    subfolder="tokenizer",
)
print("✓ Tokenizer 1 loaded")

tokenizer_two = CLIPTokenizer.from_pretrained(
    config["pretrained_model_name_or_path"],
    subfolder="tokenizer_2",
)
print("✓ Tokenizer 2 loaded")

# Load UNet
unet = UNet2DConditionModel.from_pretrained(
    config["pretrained_model_name_or_path"],
    subfolder="unet",
    torch_dtype=weight_dtype,
)
print("✓ UNet loaded")

# Load noise scheduler
noise_scheduler = DDPMScheduler.from_pretrained(
    config["pretrained_model_name_or_path"],
    subfolder="scheduler",
)
print("✓ Noise Scheduler loaded")

print("\n✓ All models loaded successfully!")

## 9. Freeze Base Models and Set Up LoRA

In [None]:
# Freeze base models - we only train LoRA adapters.
vae.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)

logger.info("Setting up LoRA...")

# LoRA config for the UNet attention projections.
unet_lora_config = LoraConfig(
    r=config["rank"],
    lora_alpha=config["rank"],
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    use_dora=config["use_dora"],
    lora_dropout=config["lora_dropout"],
)

# Apply LoRA to the UNet. get_peft_model marks only the adapter weights as
# trainable; the base UNet weights are left frozen.
unet = get_peft_model(unet, unet_lora_config)
print("✓ LoRA applied to UNet")
print(f"  - Rank: {config['rank']}")
print(f"  - Dropout: {config['lora_dropout']}")
print(f"  - DoRA: {config['use_dora']}")

# Optional: apply LoRA to the text encoders for better quality.
if config["train_text_encoder"]:
    # No requires_grad_(True) here: get_peft_model freezes everything except the
    # adapters, so un-freezing the whole encoder first would have no effect.
    text_lora_config = LoraConfig(
        r=config["rank"],
        lora_alpha=config["rank"],
        init_lora_weights="gaussian",
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        use_dora=config["use_dora"],
        lora_dropout=config["lora_dropout"],
    )

    text_encoder_one = get_peft_model(text_encoder_one, text_lora_config)
    text_encoder_two = get_peft_model(text_encoder_two, text_lora_config)
    print("✓ LoRA applied to text encoders")
else:
    print("⊘ Text encoders frozen (set train_text_encoder=True to train them)")

# Print trainable parameters
trainable_params = sum(p.numel() for p in unet.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in unet.parameters())
print(f"\nUNet trainable parameters: {trainable_params:,} / {total_params:,}")
print(f"Trainable percentage: {100 * trainable_params / total_params:.2f}%")

if config["train_text_encoder"]:
    for encoder_label, encoder in (("Text encoder 1", text_encoder_one), ("Text encoder 2", text_encoder_two)):
        encoder_trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
        encoder_total = sum(p.numel() for p in encoder.parameters())
        print(f"{encoder_label} trainable parameters: {encoder_trainable:,} / {encoder_total:,}")

## 10. Set Up Gradient Checkpointing and Memory Optimization

In [None]:
# Trade compute for memory: recompute activations in the backward pass.
if config["gradient_checkpointing"]:
    unet.enable_gradient_checkpointing()
    if config["train_text_encoder"]:
        text_encoder_one.gradient_checkpointing_enable()
        text_encoder_two.gradient_checkpointing_enable()
        # The token embeddings are frozen, so the checkpointed blocks may receive
        # inputs that do not require grad. With reentrant checkpointing the LoRA
        # layers inside them would then get no gradient. Forcing the embedding
        # output to require grad avoids that.
        text_encoder_one.enable_input_require_grads()
        text_encoder_two.enable_input_require_grads()
    print("✓ Gradient checkpointing enabled")

# xFormers is optional; fall back to the default attention if it is unavailable.
if config["enable_xformers_memory_efficient_attention"]:
    try:
        unet.enable_xformers_memory_efficient_attention()
        print("✓ xFormers memory efficient attention enabled")
    except Exception as e:
        logger.warning(f"Could not enable xFormers: {e}")
        print("⊘ xFormers not available (optional optimization)")

## 11. Set Up Optimizer and Learning Rate Scheduler

In [None]:
# Collect only the tensors that actually train (the LoRA adapters).
# Handing AdamW the frozen base weights too is harmless but wasteful: it bloats the
# param groups and makes every gradient-clipping pass iterate billions of elements.
models_to_train = [unet]
if config["train_text_encoder"]:
    models_to_train += [text_encoder_one, text_encoder_two]

params_to_optimize = [
    param
    for model in models_to_train
    for param in model.parameters()
    if param.requires_grad
]

# Create optimizer
optimizer = torch.optim.AdamW(
    params_to_optimize,
    lr=config["learning_rate"],
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-08,
)

print("✓ Optimizer created")
print(f"  - Learning rate: {config['learning_rate']}")
print(f"  - Parameters to optimize: {len(params_to_optimize):,} tensors "
      f"({sum(p.numel() for p in params_to_optimize):,} values)")

# Calculate the number of optimizer (update) steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config["gradient_accumulation_steps"])

if config["max_train_steps"] is None:
    config["max_train_steps"] = config["num_train_epochs"] * num_update_steps_per_epoch
else:
    config["num_train_epochs"] = math.ceil(config["max_train_steps"] / num_update_steps_per_epoch)

print("\n✓ Training steps calculated")
print(f"  - Steps per epoch: {num_update_steps_per_epoch}")
print(f"  - Total epochs: {config['num_train_epochs']}")
print(f"  - Total steps: {config['max_train_steps']}")

# Create learning rate scheduler
lr_scheduler = get_scheduler(
    config["lr_scheduler"],
    optimizer=optimizer,
    num_warmup_steps=config["lr_warmup_steps"],
    num_training_steps=config["max_train_steps"],
)

print("✓ Learning rate scheduler created")
print(f"  - Scheduler: {config['lr_scheduler']}")
print(f"  - Warmup steps: {config['lr_warmup_steps']}")

## 12. Helper Functions for Training

In [None]:
def encode_prompt(text_encoders, tokenizers, prompts, device, weight_dtype):
    """
    Encode prompts with SDXL's two text encoders (CLIP-L and OpenCLIP-bigG).

    Args:
        text_encoders: sequence of text encoders, paired positionally with `tokenizers`.
        tokenizers: sequence of tokenizers (each exposes `model_max_length`).
        prompts: list of caption strings.
        device: device the token ids are moved to before encoding.
        weight_dtype: dtype of the returned tensors.

    Returns:
        (prompt_embeds, pooled_prompt_embeds):
          - prompt_embeds: each encoder's penultimate hidden layer, concatenated
            along the last (feature) dimension.
          - pooled_prompt_embeds: the pooled embedding of the LAST encoder, as a
            2-D tensor [batch, dim].
    """
    per_encoder_embeds = []
    pooled_prompt_embeds = None

    for tokenizer, text_encoder in zip(tokenizers, text_encoders):
        token_ids = tokenizer(
            prompts,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids

        outputs = text_encoder(token_ids.to(device), output_hidden_states=True)

        # SDXL conditions on the penultimate hidden layer of each encoder.
        per_encoder_embeds.append(outputs.hidden_states[-2])

        # Pooled embedding: CLIPTextModelWithProjection exposes `text_embeds`
        # (projected, [B, D]); a plain CLIPTextModel only has `pooler_output`.
        # Indexing outputs[0] is wrong for the latter, because there it is
        # last_hidden_state [B, T, D].
        projected = getattr(outputs, "text_embeds", None)
        pooled_prompt_embeds = projected if projected is not None else outputs.pooler_output

    prompt_embeds = torch.cat(per_encoder_embeds, dim=-1)

    return prompt_embeds.to(dtype=weight_dtype), pooled_prompt_embeds.to(dtype=weight_dtype)


def generate_validation_images(epoch, global_step):
    """
    Generate validation images with the current (in-training) weights and save them.

    The trainable modules are switched to eval() for the duration, so LoRA dropout
    is off. Their previous train/eval state is restored afterwards, even if
    generation fails.

    Args:
        epoch: epoch number, used for logging and file names.
        global_step: optimizer step count, used for logging and file names.

    Returns:
        List of PIL images produced by the pipeline.
    """
    logger.info(f"Running validation at epoch {epoch}, step {global_step}...")

    def _unwrap(model):
        # get_peft_model() wraps modules in a PeftModel; the pipeline expects the
        # underlying diffusers/transformers module (LoRA layers live inside it).
        return model.get_base_model() if hasattr(model, "get_base_model") else model

    trained_modules = [unet]
    if config["train_text_encoder"]:
        trained_modules += [text_encoder_one, text_encoder_two]
    previous_modes = [module.training for module in trained_modules]

    pipeline = None
    try:
        for module in trained_modules:
            module.eval()

        pipeline = StableDiffusionXLPipeline.from_pretrained(
            config["pretrained_model_name_or_path"],
            unet=_unwrap(unet),
            vae=vae,
            text_encoder=_unwrap(text_encoder_one),
            text_encoder_2=_unwrap(text_encoder_two),
            tokenizer=tokenizer_one,
            tokenizer_2=tokenizer_two,
            torch_dtype=weight_dtype,
        )
        pipeline = pipeline.to("cuda")
        pipeline.set_progress_bar_config(disable=True)

        # A fixed generator makes images comparable across epochs.
        generator = None
        if config["seed"] is not None:
            generator = torch.Generator(device="cuda").manual_seed(config["seed"])

        with torch.no_grad():
            images = pipeline(
                prompt=config["validation_prompt"],
                num_inference_steps=30,
                guidance_scale=7.5,
                num_images_per_prompt=config["num_validation_images"],
                generator=generator,
            ).images
    finally:
        for module, was_training in zip(trained_modules, previous_modes):
            module.train(was_training)
        # Drops only the pipeline wrapper; the shared modules stay referenced globally.
        del pipeline
        torch.cuda.empty_cache()

    # Save images
    validation_dir = os.path.join(config["output_dir"], "validation")
    os.makedirs(validation_dir, exist_ok=True)

    for idx, image in enumerate(images):
        image.save(os.path.join(validation_dir, f"epoch_{epoch}_step_{global_step}_{idx}.png"))

    return images


def _lora_state_dict(model):
    """Return only the trainable adapter tensors of `model`, detached and on CPU."""
    if hasattr(model, "peft_config"):
        # PEFT-wrapped model: keep just the LoRA (and DoRA magnitude) tensors.
        from peft import get_peft_model_state_dict
        state = get_peft_model_state_dict(model)
    else:
        # Fallback for non-PEFT modules: save whatever is being trained.
        state = {name: param for name, param in model.named_parameters() if param.requires_grad}
    return {name: tensor.detach().cpu() for name, tensor in state.items()}


def save_model(save_path, is_final=False):
    """
    Save the LoRA adapter weights (not the full base models) to `save_path`.

    Writes unet_lora.pth and, when text-encoder training is enabled,
    text_encoder_one_lora.pth / text_encoder_two_lora.pth.

    Previously the full state_dict of each model was saved, i.e. the entire
    UNet with several GB per checkpoint, instead of only the adapter weights.

    Args:
        save_path: directory to write into (created if missing).
        is_final: currently unused; kept for call-site compatibility.
    """
    os.makedirs(save_path, exist_ok=True)

    torch.save(_lora_state_dict(unet), os.path.join(save_path, "unet_lora.pth"))

    if config["train_text_encoder"]:
        torch.save(_lora_state_dict(text_encoder_one), os.path.join(save_path, "text_encoder_one_lora.pth"))
        torch.save(_lora_state_dict(text_encoder_two), os.path.join(save_path, "text_encoder_two_lora.pth"))

    logger.info(f"Model saved to {save_path}")

print("✓ Helper functions defined")

## 13. Training Loop

In [None]:
# Move models to the GPU in the frozen-weight dtype.
unet = unet.to("cuda", dtype=weight_dtype)
vae = vae.to("cuda", dtype=weight_dtype)
text_encoder_one = text_encoder_one.to("cuda", dtype=weight_dtype)
text_encoder_two = text_encoder_two.to("cuda", dtype=weight_dtype)

# Keep the trainable LoRA parameters in float32. Updating fp16 weights directly
# with AdamW (no master copy, no loss scaling) is numerically unstable. PEFT's
# LoRA layers cast their inputs to the adapter dtype, so fp32 adapters work on
# top of fp16 base weights. Assigning `.data` keeps the same Parameter objects,
# so an optimizer created earlier still points at them.
modules_with_trainable_params = [unet]
if config["train_text_encoder"]:
    modules_with_trainable_params += [text_encoder_one, text_encoder_two]
for module in modules_with_trainable_params:
    for param in module.parameters():
        if param.requires_grad:
            param.data = param.data.to(torch.float32)

print("✓ Models moved to GPU")
print(f"GPU Memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB / {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

In [None]:
# Training loop
logger.info("Starting training...")

global_step = 0
epoch_losses = []
step_losses = []

train_text_encoder = config["train_text_encoder"]
grad_accum_steps = max(1, config["gradient_accumulation_steps"])
resolution = config["resolution"]

# Only tensors that require grad are clipped/updated (the LoRA adapters).
trainable_params = [p for p in params_to_optimize if p.requires_grad]

# fp16 gradients can underflow without loss scaling. GradScaler refuses to unscale
# fp16 gradients, so it is enabled only when every trainable tensor is float32
# (i.e. the LoRA weights were up-cast after moving the models to the GPU).
use_grad_scaler = weight_dtype == torch.float16 and all(p.dtype == torch.float32 for p in trainable_params)
if hasattr(torch.amp, "GradScaler"):
    grad_scaler = torch.amp.GradScaler("cuda", enabled=use_grad_scaler)
else:
    grad_scaler = torch.cuda.amp.GradScaler(enabled=use_grad_scaler)

# SDXL micro-conditioning ids: (orig_h, orig_w, crop_top, crop_left, target_h, target_w).
# By default StableDiffusionXLPipeline uses (H, W, 0, 0, H, W) at inference, so train
# with those values instead of the previous all-zeros tensor, which the model never
# sees at inference.
# NOTE(review): this ignores each image's true original size and random-crop offset.
add_time_ids = torch.tensor(
    [resolution, resolution, 0, 0, resolution, resolution],
    device="cuda",
    dtype=weight_dtype,
)

# Set models to training mode
unet.train()
if train_text_encoder:
    text_encoder_one.train()
    text_encoder_two.train()

optimizer.zero_grad(set_to_none=True)
reached_max_steps = False

for epoch in range(config["num_train_epochs"]):
    epoch_loss = 0.0
    num_steps = 0  # micro-batches processed this epoch

    progress_bar = tqdm(
        total=num_update_steps_per_epoch,
        desc=f"Epoch {epoch + 1}/{config['num_train_epochs']}",
        position=0,
        leave=True
    )

    for step, batch in enumerate(train_dataloader):
        # Move batch to GPU
        pixel_values = batch["pixel_values"].to("cuda", dtype=weight_dtype)
        captions = batch["input_ids"]  # raw caption strings (see NarutoDataset)

        # Encode images into the scaled VAE latent space (the VAE is frozen).
        with torch.no_grad():
            latents = vae.encode(pixel_values).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

        # Sample noise and timesteps
        noise = torch.randn_like(latents)
        batch_size = latents.shape[0]
        timesteps = torch.randint(
            0,
            noise_scheduler.config.num_train_timesteps,
            (batch_size,),
            device=latents.device,
        ).long()

        # Add noise to latents (forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Build the text-encoder graph only when their LoRA adapters are trained.
        # Under torch.no_grad() they would never receive gradients.
        with torch.set_grad_enabled(train_text_encoder):
            prompt_embeds, pooled_prompt_embeds = encode_prompt(
                [text_encoder_one, text_encoder_two],
                [tokenizer_one, tokenizer_two],
                captions,
                "cuda",
                weight_dtype,
            )

        # ensure device/dtype match
        pooled_prompt_embeds = pooled_prompt_embeds.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
        if pooled_prompt_embeds.dim() == 3:
            # Defensive fallback: a text_encoder_2 without projection can yield
            # per-token states [B, T, D]; average them to the [B, D] the UNet expects.
            pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1)

        # Get UNet predictions
        model_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs={
                "text_embeds": pooled_prompt_embeds,
                "time_ids": add_time_ids.repeat(batch_size, 1),
            },
        ).sample

        # Compute loss against the scheduler's prediction target
        prediction_type = noise_scheduler.config.prediction_type
        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {prediction_type}")

        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # Backward pass; dividing by the window size averages accumulated gradients.
        grad_scaler.scale(loss / grad_accum_steps).backward()

        loss_value = loss.detach().item()
        epoch_loss += loss_value
        step_losses.append(loss_value)
        num_steps += 1

        # Only step the optimizer at the end of each accumulation window (or epoch).
        is_last_batch = step == len(train_dataloader) - 1
        if (step + 1) % grad_accum_steps != 0 and not is_last_batch:
            continue

        grad_scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
        torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
        grad_scaler.step(optimizer)
        grad_scaler.update()
        lr_scheduler.step()
        optimizer.zero_grad(set_to_none=True)

        # Update progress (one tick per optimizer step)
        global_step += 1
        progress_bar.update(1)
        progress_bar.set_postfix({"loss": f"{loss_value:.4f}"})

        # Save checkpoint
        if global_step % config["checkpointing_steps"] == 0:
            checkpoint_dir = os.path.join(config["output_dir"], f"checkpoint-{global_step}")
            save_model(checkpoint_dir)
            logger.info(f"Saved checkpoint at step {global_step}")

        # Honour max_train_steps, which may end training mid-epoch.
        if global_step >= config["max_train_steps"]:
            reached_max_steps = True
            break

    progress_bar.close()

    # Log epoch loss (guard against an empty dataloader)
    if num_steps > 0:
        avg_epoch_loss = epoch_loss / num_steps
        epoch_losses.append(avg_epoch_loss)
        logger.info(f"Epoch {epoch + 1}: Average loss = {avg_epoch_loss:.4f}")

    # Validation at the end of every `validation_epochs`-th epoch
    if (epoch + 1) % config["validation_epochs"] == 0:
        try:
            images = generate_validation_images(epoch + 1, global_step)
            # Display validation images
            fig, axes = plt.subplots(1, len(images), figsize=(4 * len(images), 4))
            if len(images) == 1:
                axes = [axes]
            for idx, img in enumerate(images):
                axes[idx].imshow(img)
                axes[idx].axis("off")
            plt.suptitle(f"Epoch {epoch + 1}, Step {global_step}")
            plt.tight_layout()
            plt.show()
        except Exception as e:
            logger.warning(f"Validation failed: {e}")
        finally:
            # Validation may switch modules to eval(); resume training in train mode.
            unet.train()
            if train_text_encoder:
                text_encoder_one.train()
                text_encoder_two.train()

    if reached_max_steps:
        break

logger.info("Training completed!")

## 14. Save Final Model

In [None]:
# Save final model (LoRA weights) next to the intermediate checkpoints.
final_save_path = os.path.join(config["output_dir"], "final_model")
save_model(final_save_path, is_final=True)

print(f"✓ Final model saved to {final_save_path}")

## 15. Plot Training Loss

In [None]:
# Plot training loss: per micro-batch on the left, per-epoch average on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Step loss
axes[0].plot(step_losses, linewidth=1.5)
axes[0].set_xlabel("Step")
axes[0].set_ylabel("Loss")
axes[0].set_title("Training Loss per Step")
axes[0].grid(True, alpha=0.3)

# Epoch loss
axes[1].plot(epoch_losses, marker='o', linewidth=2, markersize=8)
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Average Loss")
axes[1].set_title("Training Loss per Epoch")
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(config["output_dir"], "training_loss.png"), dpi=100, bbox_inches="tight")
plt.show()

print("✓ Training loss plot saved!")

## 16. Inference with Fine-tuned Model

### Load and Test the Fine-tuned Model

In [None]:
def create_pipeline_with_lora(model_dir):
    """
    Create an SDXL pipeline with the LoRA weights written by save_model() applied.

    Previously the weights file was read but never applied, so the base model
    was returned. Now the training LoRA configuration is re-created, the saved
    weights are loaded, and the adapters are merged into the base weights. The
    returned pipeline therefore holds plain diffusers/transformers modules.

    Args:
        model_dir: directory containing unet_lora.pth and, optionally,
            text_encoder_one_lora.pth / text_encoder_two_lora.pth.

    Returns:
        StableDiffusionXLPipeline moved to CUDA.
    """
    from peft import set_peft_model_state_dict

    def _apply_lora(model, weights_path, target_modules):
        # Must mirror the LoraConfig used during training so the keys line up.
        lora_config = LoraConfig(
            r=config["rank"],
            lora_alpha=config["rank"],
            init_lora_weights="gaussian",
            target_modules=target_modules,
            use_dora=config["use_dora"],
            lora_dropout=config["lora_dropout"],
        )
        peft_model = get_peft_model(model, lora_config)
        state = torch.load(weights_path, map_location="cpu")
        if any(".default." in key for key in state):
            # Full PeftModel.state_dict() (adapter name kept in the keys).
            peft_model.load_state_dict(state, strict=False)
        else:
            # Adapter-only dict as produced by get_peft_model_state_dict().
            set_peft_model_state_dict(peft_model, state)
        # Fold the adapters into the base weights and drop the PEFT wrapper.
        return peft_model.merge_and_unload()

    # Load base pipeline
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        config["pretrained_model_name_or_path"],
        torch_dtype=weight_dtype,
        use_safetensors=True,
        variant="fp16",
    )

    # Load custom VAE
    pipeline.vae = AutoencoderKL.from_pretrained(
        config["pretrained_vae_model_name_or_path"],
        torch_dtype=weight_dtype,
    )

    # Apply UNet LoRA weights
    unet_lora_path = os.path.join(model_dir, "unet_lora.pth")
    if os.path.exists(unet_lora_path):
        pipeline.unet = _apply_lora(pipeline.unet, unet_lora_path, ["to_k", "to_q", "to_v", "to_out.0"])
    else:
        logger.warning(f"No UNet LoRA weights found at {unet_lora_path}; using the base UNet")

    # Apply text-encoder LoRA weights when they were trained and saved
    text_targets = ["q_proj", "k_proj", "v_proj", "out_proj"]
    for component_name, filename in (
        ("text_encoder", "text_encoder_one_lora.pth"),
        ("text_encoder_2", "text_encoder_two_lora.pth"),
    ):
        weights_path = os.path.join(model_dir, filename)
        if os.path.exists(weights_path):
            setattr(pipeline, component_name, _apply_lora(getattr(pipeline, component_name), weights_path, text_targets))

    # Move to GPU
    pipeline = pipeline.to("cuda")

    return pipeline

print("✓ Pipeline creation function defined")

In [None]:
# Create the inference pipeline around the in-memory trained modules.

def _unwrap_peft(model):
    """Return the module wrapped by a PeftModel, or `model` itself if it is not wrapped."""
    return model.get_base_model() if hasattr(model, "get_base_model") else model

# The pipeline expects a plain UNet2DConditionModel, not the PeftModel wrapper.
# eval() disables LoRA dropout, which would otherwise stay active from training.
trained_unet = _unwrap_peft(unet).eval()

# Reuse the already-loaded fp16-fix VAE instead of downloading it again.
pipeline_overrides = {"unet": trained_unet, "vae": vae}

if config["train_text_encoder"]:
    # Without this, inference would silently use the untrained hub encoders.
    pipeline_overrides["text_encoder"] = _unwrap_peft(text_encoder_one).eval()
    trained_text_encoder_two = _unwrap_peft(text_encoder_two)
    # Only reuse encoder 2 if it is the projection model SDXL expects.
    if hasattr(trained_text_encoder_two, "text_projection"):
        pipeline_overrides["text_encoder_2"] = trained_text_encoder_two.eval()

inference_pipeline = StableDiffusionXLPipeline.from_pretrained(
    config["pretrained_model_name_or_path"],
    torch_dtype=weight_dtype,
    use_safetensors=True,
    variant="fp16",
    **pipeline_overrides,
)

# Enable memory optimizations
inference_pipeline.enable_attention_slicing()

# Move to GPU
inference_pipeline = inference_pipeline.to("cuda")

print("✓ Inference pipeline created")

In [None]:
# Generate images with the fine-tuned model
test_prompts = [
    "naruto uzumaki eating ramen",
    "Bill Gates in naruto style",
    "A boy with blue eyes in Naruto style",
]

print("Generating images with fine-tuned model...\n")

generated_images = {}

for prompt in test_prompts:
    print(f"Prompt: {prompt}")

    # Seeded generator -> reproducible images across runs (None -> random).
    generator = (
        torch.Generator(device="cuda").manual_seed(config["seed"])
        if config["seed"] is not None
        else None
    )

    with torch.no_grad():
        result = inference_pipeline(
            prompt=prompt,
            num_inference_steps=30,
            guidance_scale=7.5,
            height=config["resolution"],
            width=config["resolution"],
            generator=generator,
        )

    image = result.images[0]
    generated_images[prompt] = image

    # Save image; keep only filename-safe characters, since prompts may contain "/" etc.
    safe_prompt = re.sub(r"[^A-Za-z0-9_-]+", "_", prompt[:50]).strip("_")
    image.save(os.path.join(config["output_dir"], f"generated_{safe_prompt}.png"))
    print("✓ Image saved\n")

print("✓ All images generated!")

In [None]:
# Display generated images, one per row.
# Size the grid from the actual results: duplicate prompts collapse in the dict,
# which would otherwise leave empty subplots.
num_prompts = len(generated_images)
fig, axes = plt.subplots(num_prompts, 1, figsize=(8, 6 * num_prompts), squeeze=False)
axes = axes[:, 0]

for ax, (prompt, image) in zip(axes, generated_images.items()):
    ax.imshow(image)
    ax.set_title(f"{prompt}", fontsize=10, wrap=True)
    ax.axis("off")

plt.tight_layout()
plt.savefig(os.path.join(config["output_dir"], "generated_images.png"), dpi=100, bbox_inches="tight")
plt.show()

print("✓ Generated images displayed and saved!")

## 17. Model Summary and Results

In [None]:
print("\n" + "=" * 60)
print("TRAINING SUMMARY")
print("=" * 60)

print("\n📊 Training Statistics:")
print(f"  - Total epochs: {config['num_train_epochs']}")
print(f"  - Total steps: {global_step}")
if step_losses:
    print(f"  - Final loss: {step_losses[-1]:.4f}")
    print(f"  - Best loss: {min(step_losses):.4f}")
else:
    print("  - No loss values recorded")
print(f"  - Learning rate: {config['learning_rate']}")

print("\n🎯 Model Configuration:")
print("  - Base model: SDXL v1.0")
print(f"  - LoRA rank: {config['rank']}")
print(f"  - LoRA dropout: {config['lora_dropout']}")
print(f"  - Text encoder training: {config['train_text_encoder']}")
print(f"  - Resolution: {config['resolution']}x{config['resolution']}")

print("\n💾 Output Directory:")
print(f"  {config['output_dir']}")

# List saved files as an indented tree with sizes.
print("\n📁 Saved Files:")
output_root = config["output_dir"]
if os.path.exists(output_root):
    for root, dirs, files in os.walk(output_root):
        # Depth relative to the output dir (0 for the root itself); relpath avoids
        # the string-replace approach, which breaks if the name recurs in the path.
        relative = os.path.relpath(root, output_root)
        level = 0 if relative == "." else relative.count(os.sep) + 1
        indent = " " * 2 * level
        print(f"{indent}{os.path.basename(root)}/")
        subindent = " " * 2 * (level + 1)
        for file in files:
            size_mb = os.path.getsize(os.path.join(root, file)) / 1024 / 1024
            print(f"{subindent}{file} ({size_mb:.2f} MB)")

print("\n✅ Training completed successfully!")
print("=" * 60)

## 18. Clean Up GPU Memory

In [None]:
# Clean up GPU memory.
# Only memory no longer referenced by Python objects can be released. Models
# still bound to notebook variables (unet, vae, inference_pipeline, ...) stay
# resident until those names are deleted.
if torch.cuda.is_available():
    gc.collect()  # drop unreachable tensors first so their blocks can be freed
    torch.cuda.empty_cache()
    print("✓ GPU memory cleared")
    print(f"Current GPU memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")