In [1]:
"""
Full-Scale Text-to-Image Diffusion Pipeline
Fine-tune Stable Diffusion on custom datasets
Optimized for Google Colab (GPU required)
"""

'\nFull-Scale Text-to-Image Diffusion Pipeline\nFine-tune Stable Diffusion on custom datasets\nOptimized for Google Colab (GPU required)\n'

In [2]:
# ============================================================================
# 1. INSTALL DEPENDENCIES
# ============================================================================
print("Installing dependencies...")
!pip install -q diffusers transformers accelerate datasets torch torchvision
!pip install -q xformers  # Memory-efficient attention
!pip install -q gradio wandb  # For visualization and logging
!pip install -q ftfy regex tqdm  # Text processing
!pip install -q bitsandbytes  # For 8-bit optimization



Installing dependencies...
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.2/117.2 MB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
!pip install -q peft  # For LoRA implementation

In [4]:
# ============================================================================
# 2. IMPORTS
# ============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from diffusers import (
    StableDiffusionPipeline,
    AutoencoderKL,
    UNet2DConditionModel,
    DDPMScheduler,
    DDIMScheduler,
    PNDMScheduler,
    DPMSolverMultistepScheduler
)
from transformers import CLIPTextModel, CLIPTokenizer
from accelerate import Accelerator
from datasets import load_dataset
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import os
import gc
from dataclasses import dataclass
from typing import Optional, List
import wandb

# Set seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

In [17]:
# ============================================================================
# 3. CONFIGURATION
# ============================================================================
@dataclass
class TrainingConfig:
    # Model
    pretrained_model_name: str = "runwayml/stable-diffusion-v1-5"
    # Alternative: "CompVis/stable-diffusion-v1-4", "stabilityai/stable-diffusion-2-1"

    # Training
    train_batch_size: int = 4
    gradient_accumulation_steps: int = 4  # Effective batch size = 4 * 4 = 16
    num_epochs: int = 10
    learning_rate: float = 5e-6
    lr_scheduler: str = "cosine"  # "linear", "cosine", "constant"
    lr_warmup_steps: int = 500
    max_grad_norm: float = 1.0

    # Diffusion
    num_inference_steps: int = 50
    guidance_scale: float = 7.5  # Classifier-free guidance

    # Data
    dataset_name: str = "svjack/pokemon-blip-captions-en-zh"  # Small dataset for demo
    # Alternatives: "nlphuji/flickr30k", "ChristophSchuhmann/MS-COCO_2017_URL_TEXT"
    # "poloclub/diffusiondb" (2M images), "Gustavosta/Stable-Diffusion-Prompts"
    image_size: int = 512
    center_crop: bool = True
    random_flip: bool = True
    hf_token: Optional[str] = None  # Set your HuggingFace token if needed

    # Optimization
    mixed_precision: str = "no"  # "no", "fp16", "bf16" - Changed to "no" to avoid dtype errors
    use_8bit_adam: bool = False  # Requires bitsandbytes
    use_ema: bool = True
    ema_decay: float = 0.9999

    # Advanced
    use_lora: bool = True  # Low-Rank Adaptation (memory efficient)
    lora_rank: int = 4
    noise_offset: float = 0.1  # Improves dark/light generation
    prior_preservation: bool = False
    prior_loss_weight: float = 1.0

    # Validation
    validation_prompts: List[str] = None
    validation_epochs: int = 2
    num_validation_images: int = 4

    # Logging
    output_dir: str = "./text2img-model"
    logging_dir: str = "./logs"
    use_wandb: bool = False
    wandb_project: str = "text2img-diffusion"
    save_model_epochs: int = 5

    # Memory optimization
    gradient_checkpointing: bool = True

    def __post_init__(self):
        if self.validation_prompts is None:
            self.validation_prompts = [
                "a photo of a pikachu",
                "a cute dragon breathing fire",
                "a futuristic city at sunset",
                "an astronaut riding a horse"
            ]

config = TrainingConfig()

print("Configuration:")
print(f"  Model: {config.pretrained_model_name}")
print(f"  Dataset: {config.dataset_name}")
print(f"  Batch size: {config.train_batch_size} x {config.gradient_accumulation_steps} = {config.train_batch_size * config.gradient_accumulation_steps}")
print(f"  Mixed precision: {config.mixed_precision}")
print(f"  LoRA: {config.use_lora}")
print(f"  EMA: {config.use_ema}")

Configuration:
  Model: runwayml/stable-diffusion-v1-5
  Dataset: svjack/pokemon-blip-captions-en-zh
  Batch size: 4 x 4 = 16
  Mixed precision: no
  LoRA: True
  EMA: True


In [6]:
# ============================================================================
# 4. INITIALIZE ACCELERATOR (FIXED FOR LORA + FP16)
# ============================================================================
# CRITICAL: When using LoRA with mixed precision, we need to handle gradient scaling carefully
accelerator = Accelerator(
    mixed_precision=config.mixed_precision,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    log_with="wandb" if config.use_wandb else None,
    project_dir=config.logging_dir,
)

os.makedirs(config.output_dir, exist_ok=True)
os.makedirs(config.logging_dir, exist_ok=True)

if accelerator.is_main_process:
    if config.use_wandb:
        wandb.init(project=config.wandb_project, config=config.__dict__)


In [18]:
# ============================================================================
# 5. LOAD PRETRAINED MODELS
# ============================================================================
print("\nLoading pretrained models...")

# Load tokenizer and text encoder (CLIP)
tokenizer = CLIPTokenizer.from_pretrained(
    config.pretrained_model_name,
    subfolder="tokenizer"
)
text_encoder = CLIPTextModel.from_pretrained(
    config.pretrained_model_name,
    subfolder="text_encoder"
)

# Load VAE (Variational Autoencoder)
vae = AutoencoderKL.from_pretrained(
    config.pretrained_model_name,
    subfolder="vae"
)

# Load U-Net
unet = UNet2DConditionModel.from_pretrained(
    config.pretrained_model_name,
    subfolder="unet"
)

# Load noise scheduler
noise_scheduler = DDPMScheduler.from_pretrained(
    config.pretrained_model_name,
    subfolder="scheduler"
)

print(f"✓ Loaded models from {config.pretrained_model_name}")

# Freeze VAE and text encoder (only train U-Net)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

# Enable gradient checkpointing for memory efficiency
if config.gradient_checkpointing:
    unet.enable_gradient_checkpointing()

# Enable xformers memory efficient attention
try:
    unet.enable_xformers_memory_efficient_attention()
    print("✓ Enabled xformers memory efficient attention")
except:
    print("⚠ xformers not available, using default attention")



Loading pretrained models...
✓ Loaded models from runwayml/stable-diffusion-v1-5
✓ Enabled xformers memory efficient attention


In [27]:
# ============================================================================
# 6. LORA IMPLEMENTATION (Optional)
# ============================================================================
if config.use_lora:
    try:
        # Try newer API first (diffusers >= 0.20.0)
        from diffusers.loaders import AttnProcsLayers
        from diffusers.models.attention_processor import LoRAAttnProcessor2_0

        lora_attn_procs = {}
        for name, attn_processor in unet.attn_processors.items():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRAAttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                rank=config.lora_rank
            )

        unet.set_attn_processor(lora_attn_procs)
        lora_layers = AttnProcsLayers(unet.attn_processors)

    except (ImportError, TypeError):
        # Fallback to older API or simpler LoRA implementation
        print("⚠ Using simplified LoRA implementation")
        from peft import LoraConfig, get_peft_model

        lora_config = LoraConfig(
            r=config.lora_rank,
            lora_alpha=config.lora_rank,
            init_lora_weights="gaussian",
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        unet = get_peft_model(unet, lora_config)
        # Explicitly cast LoRA weights to float32 after applying PEFT
        for param in unet.parameters():
            if param.requires_grad:
                param.data = param.data.to(torch.float32)

        lora_layers = filter(lambda p: p.requires_grad, unet.parameters())


    # Only train LoRA parameters
    trainable_params = list(filter(lambda p: p.requires_grad, lora_layers))

    print(f"✓ Enabled LoRA with rank {config.lora_rank}")
    print(f"  Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
else:
    trainable_params = unet.parameters()
    print(f"  Training full U-Net: {sum(p.numel() for p in trainable_params):,} parameters")

⚠ Using simplified LoRA implementation
✓ Enabled LoRA with rank 4
  Trainable parameters: 797,184




In [20]:
# ============================================================================
# 7. EXPONENTIAL MOVING AVERAGE (EMA)
# ============================================================================
class EMAModel:
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = {}
        self.original = {}

        for name, param in model.named_parameters():
            if param.requires_grad:
                # Ensure shadow tensors are on the same device as the model parameters
                self.shadow[name] = param.data.clone().to(param.data.device)


    def update(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * param.data

    def apply_shadow(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.original[name] = param.data.clone()
                param.data = self.shadow[name]

    def restore(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data = self.original[name]

ema_model = EMAModel(unet, decay=config.ema_decay) if config.use_ema else None

if config.use_ema:
    print(f"✓ Enabled EMA with decay {config.ema_decay}")

✓ Enabled EMA with decay 0.9999


In [21]:
# ============================================================================
# 8. DATASET AND DATALOADER
# ============================================================================
from PIL import Image # Import Image at the class level

class TextImageDataset(Dataset):
    def __init__(self, dataset, tokenizer, config):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.config = config

        self.transforms = transforms.Compose([
            transforms.Resize(config.image_size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(config.image_size) if config.center_crop else transforms.Lambda(lambda x: x),
            transforms.RandomHorizontalFlip() if config.random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1]
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        try:
            item = self.dataset[idx]

            # Get image - handle different column names
            image = None
            for img_col in ['image', 'img', 'picture', 'photo']:
                if img_col in item:
                    image = item[img_col]
                    break

            if image is None:
                raise ValueError(f"No image column found in item keys: {item.keys()}")

            # Handle different image formats
            if isinstance(image, str):  # URL or file path
                try:
                    if image.startswith('http'):
                        import requests
                        from io import BytesIO
                        response = requests.get(image, timeout=5)
                        image = Image.open(BytesIO(response.content)).convert('RGB')
                    else:
                        image = Image.open(image).convert('RGB')
                except Exception as e:
                    # Return a blank image if download fails
                    image = Image.new('RGB', (self.config.image_size, self.config.image_size), (128, 128, 128))
            elif not isinstance(image, Image.Image):
                 # Convert numpy array or other format
                 if hasattr(image, 'shape'):
                     image = Image.fromarray(np.array(image)).convert('RGB')
                 else:
                     image = Image.new('RGB', (self.config.image_size, self.config.image_size), (128, 128, 128))


            image = self.transforms(image)

            # Get caption - handle different column names
            caption = None
            for text_col in ['text', 'en_text', 'caption', 'prompt', 'description', 'en_caption']:
                if text_col in item:
                    caption = item[text_col]
                    break

            if caption is None:
                caption = 'a photo'

            if not isinstance(caption, str):
                caption = str(caption)

            # Tokenize caption
            text_inputs = self.tokenizer(
                caption,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt"
            )

            return {
                "pixel_values": image,
                "input_ids": text_inputs.input_ids[0]
            }
        except Exception as e:
            print(f"Error loading item {idx}: {e}")
            # Return a default item
            default_image = torch.zeros(3, self.config.image_size, self.config.image_size)
            default_text = self.tokenizer(
                "a photo",
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt"
            ).input_ids[0]
            return {"pixel_values": default_image, "input_ids": default_text}

print("\nLoading dataset...")

# Try multiple datasets in order of preference
dataset_options = [
    ("svjack/pokemon-blip-captions-en-zh", None),  # Public Pokemon dataset
    ("poloclub/diffusiondb", "2m_random_1k"),  # 1k subset for quick testing
    ("Gustavosta/Stable-Diffusion-Prompts", None),  # Text-only prompts
]

dataset = None
for dataset_name, subset in dataset_options:
    try:
        print(f"Trying to load: {dataset_name}")
        if subset:
            dataset = load_dataset(dataset_name, subset, split="train")
        else:
            dataset = load_dataset(dataset_name, split="train")
        config.dataset_name = dataset_name
        print(f"✓ Successfully loaded {len(dataset)} examples from {dataset_name}")
        break
    except Exception as e:
        print(f"  ⚠ Failed to load {dataset_name}: {str(e)[:100]}")
        continue

# If all fail, create a simple synthetic dataset
if dataset is None:
    print("\n⚠ Could not load any dataset. Creating synthetic dataset for testing...")
    from datasets import Dataset as HFDataset
    from PIL import Image # Import Image for synthetic data creation

    # Create simple colored squares with labels
    def create_synthetic_data(num_samples=100):
        images = []
        texts = []
        colors = ["red", "blue", "green", "yellow", "purple", "orange"]

        for i in range(num_samples):
            # Create a colored square
            color_idx = i % len(colors)
            color_name = colors[color_idx]

            # RGB values for colors
            color_map = {
                "red": (255, 0, 0), "blue": (0, 0, 255), "green": (0, 255, 0),
                "yellow": (255, 255, 0), "purple": (128, 0, 128), "orange": (255, 165, 0)
            }

            img = Image.new('RGB', (512, 512), color_map[color_name])
            images.append(img)
            texts.append(f"a {color_name} square")

        return {"image": images, "text": texts}

    synthetic_data = create_synthetic_data(200)
    dataset = HFDataset.from_dict(synthetic_data)
    config.dataset_name = "synthetic_colored_squares"
    print(f"✓ Created synthetic dataset with {len(dataset)} examples")

# Verify dataset has required columns
print(f"\nDataset columns: {dataset.column_names}")

# Map columns to standard names if needed
needs_rename = False
column_mapping = {}

if 'image' not in dataset.column_names:
    for col in dataset.column_names:
        if 'image' in col.lower() or 'img' in col.lower() or 'picture' in col.lower():
            column_mapping[col] = 'image'
            needs_rename = True
            break

if 'text' not in dataset.column_names:
    for col in dataset.column_names:
        if any(keyword in col.lower() for keyword in ['text', 'caption', 'prompt', 'description']):
            column_mapping[col] = 'text'
            needs_rename = True
            break

if needs_rename and column_mapping:
    print(f"Renaming columns: {column_mapping}")
    dataset = dataset.rename_columns(column_mapping)
    print(f"✓ New columns: {dataset.column_names}")
elif 'image' not in dataset.column_names or 'text' not in dataset.column_names:
    print(f"⚠ Warning: Expected 'image' and 'text' columns. Available: {dataset.column_names}")
    print("⚠ The dataset loader will handle column mapping automatically")

print(f"\nDataset info:")
print(f"  Name: {config.dataset_name}")
print(f"  Size: {len(dataset)} examples")
print(f"  Columns: {dataset.column_names}")
if len(dataset) > 0:
    # Show sample based on available columns
    sample = dataset[0]
    text_col = next((col for col in dataset.column_names if any(k in col.lower() for k in ['text', 'caption', 'prompt'])), None)
    if text_col:
        text_sample = sample[text_col]
        if isinstance(text_sample, str):
            print(f"  Sample text: {text_sample[:100]}")
        else:
            print(f"  Sample text: {str(text_sample)[:100]}")
print()

train_dataset = TextImageDataset(dataset, tokenizer, config)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.train_batch_size,
    shuffle=True,
    num_workers=0  # Set to 0 for Colab
)


Loading dataset...
Trying to load: svjack/pokemon-blip-captions-en-zh
✓ Successfully loaded 833 examples from svjack/pokemon-blip-captions-en-zh

Dataset columns: ['image', 'en_text', 'zh_text']
Renaming columns: {'en_text': 'text'}
✓ New columns: ['image', 'text', 'zh_text']

Dataset info:
  Name: svjack/pokemon-blip-captions-en-zh
  Size: 833 examples
  Columns: ['image', 'text', 'zh_text']
  Sample text: a drawing of a green pokemon with red eyes



In [22]:
# ============================================================================
# 9. OPTIMIZER AND SCHEDULER
# ============================================================================
if config.use_8bit_adam:
    try:
        import bitsandbytes as bnb
        optimizer = bnb.optim.AdamW8bit(
            trainable_params,
            lr=config.learning_rate,
            betas=(0.9, 0.999),
            weight_decay=0.01,
            eps=1e-8
        )
        print("✓ Using 8-bit AdamW optimizer")
    except ImportError:
        optimizer = torch.optim.AdamW(
            trainable_params,
            lr=config.learning_rate,
            betas=(0.9, 0.999),
            weight_decay=0.01,
            eps=1e-8
        )
        print("⚠ 8-bit optimizer not available, using standard AdamW")
else:
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=config.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=0.01,
        eps=1e-8
    )

# Learning rate scheduler
from diffusers.optimization import get_scheduler

lr_scheduler = get_scheduler(
    config.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps * config.gradient_accumulation_steps,
    num_training_steps=len(train_dataloader) * config.num_epochs,
)

In [23]:
# ============================================================================
# 10. PREPARE FOR TRAINING WITH ACCELERATOR (FIXED)
# ============================================================================
# Prepare models - DON'T prepare unet yet if using LoRA, to avoid gradient issues
if config.use_lora:
    # Only prepare optimizer and dataloader first
    optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        optimizer, train_dataloader, lr_scheduler
    )
    # Move unet to device manually and keep it unwrapped for LoRA
    unet.to(accelerator.device)
else:
    # Standard preparation for full fine-tuning
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

vae.to(accelerator.device)
text_encoder.to(accelerator.device)
weight_dtype = torch.float32
if config.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif config.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

vae.to(dtype=weight_dtype)
text_encoder.to(dtype=weight_dtype)

CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e

In [None]:
# ============================================================================
# 11. TRAINING LOOP (FIXED)
# ============================================================================
def encode_prompt(batch):
    """Encode text prompt to embeddings"""
    text_input_ids = batch["input_ids"].to(accelerator.device)
    encoder_hidden_states = text_encoder(text_input_ids)[0]
    return encoder_hidden_states

def train_step(batch, global_step):
    """Single training step - FIXED for LoRA + FP16"""
    # For LoRA, we need to handle gradients manually without accelerator.accumulate
    if config.use_lora:
        # Convert images to latent space - ENSURE CORRECT DTYPE
        pixel_values = batch["pixel_values"].to(accelerator.device, dtype=weight_dtype)
        latents = vae.encode(pixel_values).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        # Sample noise
        noise = torch.randn_like(latents)

        # Add noise offset
        if config.noise_offset > 0:
            noise += config.noise_offset * torch.randn(
                (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
            )

        # Sample random timestep
        bsz = latents.shape[0]
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bsz,),
            device=latents.device
        ).long()

        # Add noise to latents
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Get text embeddings
        encoder_hidden_states = encode_prompt(batch)

        # Predict noise residual
        # CRITICAL FIX: Remove explicit .float() calls as mixed_precision is "no"
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # Compute loss
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        # Remove explicit .float() from loss calculation
        loss = F.mse_loss(model_pred, target, reduction="mean")

        # Manual backward without accelerator for LoRA
        loss = loss / config.gradient_accumulation_steps
        loss.backward()

        # Manual gradient step every gradient_accumulation_steps
        if (global_step + 1) % config.gradient_accumulation_steps == 0:
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(trainable_params, config.max_grad_norm)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # Update EMA
            if config.use_ema:
                ema_model.update(unet)

        return loss.detach().item() * config.gradient_accumulation_steps

    else:
        # Standard training step with accelerator for full fine-tuning
        with accelerator.accumulate(unet):
            # Convert images to latent space - ENSURE CORRECT DTYPE
            pixel_values = batch["pixel_values"].to(accelerator.device, dtype=weight_dtype)
            latents = vae.encode(pixel_values).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

            # Sample noise
            noise = torch.randn_like(latents)

            # Add noise offset
            if config.noise_offset > 0:
                noise += config.noise_offset * torch.randn(
                    (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
                )

            # Sample random timestep
            bsz = latents.shape[0]
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,),
                device=latents.device
            ).long()

            # Add noise to latents
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get text embeddings
            encoder_hidden_states = encode_prompt(batch)

            # Predict noise residual
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            # Compute loss
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Remove explicit .float() from loss calculation
            loss = F.mse_loss(model_pred, target, reduction="mean")

            # Backpropagation
            accelerator.backward(loss)

            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(trainable_params, config.max_grad_norm)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # Update EMA
            if config.use_ema and accelerator.sync_gradients:
                ema_model.update(accelerator.unwrap_model(unet))

        return loss.detach().item()

@torch.no_grad()
def validate(epoch):
    """Generate validation images - FIXED DTYPE HANDLING"""
    print(f"\nGenerating validation images for epoch {epoch}...")

    # Get the correct unet reference
    unet_for_inference = unet if config.use_lora else accelerator.unwrap_model(unet)

    # Use EMA weights if available
    if config.use_ema:
        ema_model.apply_shadow(unet_for_inference)

    # Create inference pipeline with EXPLICIT dtype handling
    pipeline = StableDiffusionPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet_for_inference,
        scheduler=DPMSolverMultistepScheduler.from_pretrained(
            config.pretrained_model_name, subfolder="scheduler"
        ),
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False,
    )

    # CRITICAL FIX: Set pipeline to use the correct dtype
    pipeline = pipeline.to(accelerator.device, dtype=weight_dtype)
    pipeline.set_progress_bar_config(disable=True)

    # Generate images
    images = []
    for prompt in config.validation_prompts[:config.num_validation_images]:
        try:
            image = pipeline(
                prompt,
                num_inference_steps=config.num_inference_steps,
                guidance_scale=config.guidance_scale,
                generator=torch.Generator(device=accelerator.device).manual_seed(42)
            ).images[0]
            images.append(image)
        except Exception as e:
            print(f"Error generating image for '{prompt}': {e}")
            # Create a placeholder image
            images.append(Image.new('RGB', (512, 512), (128, 128, 128)))

    # Display images
    fig, axes = plt.subplots(1, len(images), figsize=(5*len(images), 5))
    if len(images) == 1:
        axes = [axes]

    for idx, (img, prompt) in enumerate(zip(images, config.validation_prompts[:len(images)])):
        axes[idx].imshow(img)
        axes[idx].set_title(prompt, fontsize=10, wrap=True)
        axes[idx].axis('off')

    plt.tight_layout()
    plt.savefig(f"{config.output_dir}/validation_epoch_{epoch}.png", dpi=150, bbox_inches='tight')
    plt.show()

    # Log to wandb
    if config.use_wandb and accelerator.is_main_process:
        wandb.log({
            "validation": [wandb.Image(img, caption=prompt)
                          for img, prompt in zip(images, config.validation_prompts[:len(images)])]
        }, step=epoch)

    # Restore original weights
    if config.use_ema:
        ema_model.restore(unet_for_inference)

    del pipeline
    gc.collect()
    torch.cuda.empty_cache()

print("\n" + "="*70)
print("STARTING TRAINING")
print("="*70 + "\n")

global_step = 0
for epoch in range(config.num_epochs):
    unet.train()
    epoch_loss = 0

    progress_bar = tqdm(
        train_dataloader,
        desc=f"Epoch {epoch+1}/{config.num_epochs}",
        disable=not accelerator.is_local_main_process
    )

    for step, batch in enumerate(progress_bar):
        loss = train_step(batch, global_step)
        epoch_loss += loss

        # Update global step based on whether we're using manual accumulation (LoRA) or accelerator
        if config.use_lora:
            if (global_step + 1) % config.gradient_accumulation_steps == 0:
                global_step += 1
        else:
            if accelerator.sync_gradients:
                global_step += 1

        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{loss:.4f}',
            'lr': f'{lr_scheduler.get_last_lr()[0]:.2e}'
        })

        # Log to wandb
        if config.use_wandb and accelerator.is_main_process and global_step % 100 == 0:
            wandb.log({
                "train_loss": loss,
                "learning_rate": lr_scheduler.get_last_lr()[0],
                "epoch": epoch
            }, step=global_step)

    avg_loss = epoch_loss / len(train_dataloader)
    print(f"\nEpoch {epoch+1} - Average Loss: {avg_loss:.4f}")

    # Validation
    if (epoch + 1) % config.validation_epochs == 0:
        if accelerator.is_main_process:
            try:
                validate(epoch + 1)
            except Exception as e:
                print(f"⚠ Validation failed: {e}")
                print("Continuing training...")

    # Save model
    if (epoch + 1) % config.save_model_epochs == 0:
        if accelerator.is_main_process:
            print(f"\nSaving model checkpoint at epoch {epoch+1}...")

            try:
                if config.use_lora:
                    # For LoRA, save the attention processors
                    unet.save_attn_procs(f"{config.output_dir}/lora_weights_epoch_{epoch+1}.pt")
                    print(f"✓ Saved LoRA weights")
                else:
                    accelerator.unwrap_model(unet).save_pretrained(f"{config.output_dir}/unet_epoch_{epoch+1}")
                    print(f"✓ Saved U-Net weights")
            except Exception as e:
                print(f"⚠ Failed to save checkpoint: {e}")

print("\n" + "="*70)
print("TRAINING COMPLETE!")
print("="*70)


STARTING TRAINING



Epoch 1/10:   0%|          | 0/209 [00:00<?, ?it/s]


Epoch 1 - Average Loss: 0.0642


Epoch 2/10:   0%|          | 0/209 [00:00<?, ?it/s]


Epoch 2 - Average Loss: 0.0624

Generating validation images for epoch 2...
⚠ Validation failed: 'base_model.model.base_model.model.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.default.weight'
Continuing training...


Epoch 3/10:   0%|          | 0/209 [00:00<?, ?it/s]


Epoch 3 - Average Loss: 0.0598


Epoch 4/10:   0%|          | 0/209 [00:00<?, ?it/s]

In [29]:
# ============================================================================
# 11. TRAINING LOOP
# ============================================================================
def train_step(batch, global_step):
    """Single training step"""
    with accelerator.accumulate(unet):
        # Convert images to latent space
        pixel_values = batch["pixel_values"].to(accelerator.device, dtype=weight_dtype)
        latents = vae.encode(pixel_values).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        # Sample noise
        noise = torch.randn_like(latents)

        # Add noise offset for better dark/light generation
        if config.noise_offset > 0:
            noise += config.noise_offset * torch.randn(
                (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
            )

        # Sample random timestep
        bsz = latents.shape[0]
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bsz,),
            device=latents.device
        ).long()

        # Add noise to latents (forward diffusion)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Get text embeddings
        text_input_ids = batch["input_ids"].to(accelerator.device)
        encoder_hidden_states = text_encoder(text_input_ids, return_dict=False)[0]

        # Convert encoder hidden states to correct dtype
        encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)

        # Predict noise residual
        model_pred = unet(
            noisy_latents.to(dtype=weight_dtype),
            timesteps,
            encoder_hidden_states,
            return_dict=False
        )[0]

        # Compute loss
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # Backpropagation
        accelerator.backward(loss)

        if accelerator.sync_gradients:
            params_to_clip = trainable_params
            accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        # Update EMA
        if config.use_ema and accelerator.sync_gradients:
            ema_model.update(accelerator.unwrap_model(unet))

    return loss.detach().item()

@torch.no_grad()
def validate(epoch):
    """Generate validation images - FIXED for LoRA"""
    print(f"\nGenerating validation images for epoch {epoch}...")

    try:
        if config.use_lora:
            # For LoRA: Create a fresh pipeline and load LoRA weights
            print("Setting up inference pipeline with LoRA...")

            # Create base pipeline
            pipeline = StableDiffusionPipeline.from_pretrained(
                config.pretrained_model_name,
                torch_dtype=weight_dtype,
                safety_checker=None,
                requires_safety_checker=False,
            )

            # Save current LoRA weights temporarily
            temp_lora_path = f"{config.output_dir}/temp_lora_validation.pt"
            unet.save_attn_procs(temp_lora_path)

            # Load LoRA weights into the pipeline's unet
            pipeline.unet.load_attn_procs(temp_lora_path)

            # Clean up temp file
            import os
            if os.path.exists(temp_lora_path):
                os.remove(temp_lora_path)

        else:
            # For full fine-tuning: use the unwrapped model
            unet_for_inference = accelerator.unwrap_model(unet)

            # Use EMA weights if available
            if config.use_ema:
                ema_model.apply_shadow(unet_for_inference)

            pipeline = StableDiffusionPipeline(
                vae=vae,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                unet=unet_for_inference,
                scheduler=DPMSolverMultistepScheduler.from_pretrained(
                    config.pretrained_model_name, subfolder="scheduler"
                ),
                safety_checker=None,
                feature_extractor=None,
                requires_safety_checker=False,
            )

        pipeline = pipeline.to(accelerator.device)
        pipeline.set_progress_bar_config(disable=True)

        # Enable memory optimizations
        if hasattr(pipeline, 'enable_attention_slicing'):
            pipeline.enable_attention_slicing()
        if hasattr(pipeline, 'enable_vae_slicing'):
            pipeline.enable_vae_slicing()

        # Generate images
        images = []
        for prompt in config.validation_prompts[:config.num_validation_images]:
            try:
                image = pipeline(
                    prompt,
                    num_inference_steps=config.num_inference_steps,
                    guidance_scale=config.guidance_scale,
                    generator=torch.Generator(device=accelerator.device).manual_seed(42)
                ).images[0]
                images.append(image)
            except Exception as e:
                print(f"  ⚠ Error generating '{prompt}': {str(e)[:100]}")
                # Create a placeholder
                images.append(Image.new('RGB', (512, 512), (128, 128, 128)))

        # Display images
        if len(images) > 0:
            fig, axes = plt.subplots(1, len(images), figsize=(5*len(images), 5))
            if len(images) == 1:
                axes = [axes]

            for idx, (img, prompt) in enumerate(zip(images, config.validation_prompts[:len(images)])):
                axes[idx].imshow(img)
                axes[idx].set_title(prompt, fontsize=10, wrap=True)
                axes[idx].axis('off')

            plt.tight_layout()
            plt.savefig(f"{config.output_dir}/validation_epoch_{epoch}.png", dpi=150, bbox_inches='tight')
            plt.show()

            # Log to wandb
            if config.use_wandb and accelerator.is_main_process:
                wandb.log({
                    "validation": [wandb.Image(img, caption=prompt)
                                  for img, prompt in zip(images, config.validation_prompts[:len(images)])]
                }, step=epoch)

        # Restore original weights for non-LoRA
        if not config.use_lora and config.use_ema:
            ema_model.restore(accelerator.unwrap_model(unet))

        # Cleanup
        del pipeline
        gc.collect()
        torch.cuda.empty_cache()

        print("✓ Validation complete")

    except Exception as e:
        print(f"⚠ Validation failed with error: {str(e)}")
        import traceback
        traceback.print_exc()
        print("Continuing training...")

print("\n" + "="*70)
print("STARTING TRAINING")
print("="*70 + "\n")

global_step = 0
for epoch in range(config.num_epochs):
    unet.train()
    epoch_loss = 0

    progress_bar = tqdm(
        train_dataloader,
        desc=f"Epoch {epoch+1}/{config.num_epochs}",
        disable=not accelerator.is_local_main_process
    )

    for step, batch in enumerate(progress_bar):
        loss = train_step(batch, global_step)
        epoch_loss += loss

        if accelerator.sync_gradients:
            global_step += 1

        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{loss:.4f}',
            'lr': f'{lr_scheduler.get_last_lr()[0]:.2e}'
        })

        # Log to wandb
        if config.use_wandb and accelerator.is_main_process and global_step % 100 == 0:
            wandb.log({
                "train_loss": loss,
                "learning_rate": lr_scheduler.get_last_lr()[0],
                "epoch": epoch
            }, step=global_step)

    avg_loss = epoch_loss / len(train_dataloader)
    print(f"\nEpoch {epoch+1} - Average Loss: {avg_loss:.4f}")

    # Validation
    if (epoch + 1) % config.validation_epochs == 0:
        if accelerator.is_main_process:
            validate(epoch + 1)

    # Save model
    if (epoch + 1) % config.save_model_epochs == 0:
        if accelerator.is_main_process:
            print(f"\nSaving model checkpoint at epoch {epoch+1}...")

            # Save U-Net weights
            if config.use_lora:
                unet_lora = accelerator.unwrap_model(unet)
                unet_lora.save_attn_procs(f"{config.output_dir}/lora_weights_epoch_{epoch+1}.pt")
                print(f"✓ Saved LoRA weights")
            else:
                accelerator.unwrap_model(unet).save_pretrained(f"{config.output_dir}/unet_epoch_{epoch+1}")
                print(f"✓ Saved U-Net weights")

print("\n" + "="*70)
print("TRAINING COMPLETE!")
print("="*70)


STARTING TRAINING



Epoch 1/10:   0%|          | 0/209 [00:00<?, ?it/s]

KeyError: 'base_model.model.base_model.model.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.default.weight'

In [None]:
# ============================================================================
# 12. FINAL MODEL SAVE
# ============================================================================
if accelerator.is_main_process:
    print("\nSaving final model...")

    if config.use_ema:
        ema_model.apply_shadow(accelerator.unwrap_model(unet))

    if config.use_lora:
        accelerator.unwrap_model(unet).save_attn_procs(f"{config.output_dir}/lora_weights_final.pt")
    else:
        accelerator.unwrap_model(unet).save_pretrained(f"{config.output_dir}/unet_final")

    print(f"✓ Model saved to {config.output_dir}")


In [None]:
# ============================================================================
# 13. INFERENCE PIPELINE
# ============================================================================
print("\n" + "="*70)
print("CREATING INFERENCE PIPELINE")
print("="*70)

# Load the trained model
if config.use_ema:
    ema_model.apply_shadow(accelerator.unwrap_model(unet))

inference_pipeline = StableDiffusionPipeline.from_pretrained(
    config.pretrained_model_name,
    unet=accelerator.unwrap_model(unet),
    torch_dtype=weight_dtype,
    safety_checker=None,
    requires_safety_checker=False,
)
inference_pipeline.scheduler = DPMSolverMultistepScheduler.from_pretrained(
    config.pretrained_model_name, subfolder="scheduler"
)
inference_pipeline.to(accelerator.device)

print("✓ Inference pipeline ready!")

In [None]:
# ============================================================================
# 14. GENERATE SAMPLE IMAGES
# ============================================================================
def generate_image(prompt, num_images=1, seed=None):
    """Generate images from text prompt"""
    if seed is not None:
        generator = torch.Generator(device=accelerator.device).manual_seed(seed)
    else:
        generator = None

    images = inference_pipeline(
        prompt,
        num_images_per_prompt=num_images,
        num_inference_steps=config.num_inference_steps,
        guidance_scale=config.guidance_scale,
        generator=generator
    ).images

    return images

print("\n" + "="*70)
print("GENERATING SAMPLE IMAGES")
print("="*70 + "\n")

sample_prompts = [
    "a beautiful landscape with mountains and lake",
    "a cute robot playing guitar",
    "a magical forest with glowing mushrooms",
    "a cyberpunk city at night"
]

for prompt in sample_prompts:
    print(f"Generating: '{prompt}'")
    images = generate_image(prompt, num_images=1, seed=42)

    plt.figure(figsize=(8, 8))
    plt.imshow(images[0])
    plt.title(prompt, fontsize=12, wrap=True)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# ============================================================================
# 15. GRADIO INTERFACE (OPTIONAL)
# ============================================================================
print("\n" + "="*70)
print("CREATING GRADIO INTERFACE")
print("="*70)

import gradio as gr

def generate_gradio(prompt, num_steps, guidance, seed):
    """Gradio generation function"""
    if seed == -1:
        seed = None

    temp_pipeline = inference_pipeline
    temp_pipeline.scheduler.config.num_train_timesteps = num_steps

    image = temp_pipeline(
        prompt,
        num_inference_steps=num_steps,
        guidance_scale=guidance,
        generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed else None
    ).images[0]

    return image

demo = gr.Interface(
    fn=generate_gradio,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter your text prompt here..."),
        gr.Slider(10, 100, value=50, step=5, label="Inference Steps"),
        gr.Slider(1, 20, value=7.5, step=0.5, label="Guidance Scale"),
        gr.Number(value=42, label="Seed (-1 for random)")
    ],
    outputs=gr.Image(label="Generated Image", type="pil"),
    title="Text-to-Image Diffusion Model",
    description="Generate images from text descriptions using your fine-tuned model!",
    examples=[
        ["a photo of a cute puppy", 50, 7.5, 42],
        ["a futuristic cityscape", 50, 7.5, 123],
        ["an astronaut riding a horse on mars", 50, 7.5, 456],
    ]
)

print("\n✓ Gradio interface created!")
print("\nLaunch with: demo.launch(share=True)")


In [None]:
# ============================================================================
# 16. USAGE INSTRUCTIONS
# ============================================================================
print("\n" + "="*70)
print("USAGE INSTRUCTIONS")
print("="*70)
print("""
# Generate single image:
images = generate_image("a beautiful sunset over ocean", num_images=1, seed=42)

# Generate multiple images:
images = generate_image("a cute cat", num_images=4, seed=123)

# Launch Gradio interface:
demo.launch(share=True)

# Load model later:
from diffusers import StableDiffusionPipeline
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    unet=UNet2DConditionModel.from_pretrained("./text2img-model/unet_final")
)

# For LoRA weights:
pipeline.unet.load_attn_procs("./text2img-model/lora_weights_final.pt")
""")

print("\n✓ Setup complete! Your text-to-image model is ready to use.")