PixArt-α/DiT Architecture Deep Dive

Stage 2 - Cascade Series | 20_cascade/nb-cascade-pixart-dit.ipynb
Learning DiT (Diffusion Transformers) architecture through PixArt-α implementation

📚 Learning Goals

Understand DiT Architecture Evolution - From CNN UNet to Transformer backbone
Implement PixArt-α Inference - Master T5 text encoder + DiT combination
Compare PixArt vs Stable Diffusion - Architecture, performance, memory trade-offs
Explore DiT Family Models - PixArt-α/σ, DiT-XL variants
Establish DiT Evaluation Baseline - Prepare for future fine-tuning stages

🔧 Prerequisites

VRAM: 6GB+ (12GB+ recommended for 1024px)
Environment: Conda t2i-lab with diffusers>=0.30.0
License: PixArt-α (OpenRAIL++), T5-XXL (Apache 2.0)
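
A quick pre-flight check can confirm the `diffusers>=0.30.0` requirement before any large download starts. This is a minimal sketch: `version_tuple` and `meets_minimum` are hypothetical helper names, and the comparison only looks at the leading numeric parts of a version string.

```python
from importlib import metadata


def version_tuple(version):
    """Turn '0.30.1' or '0.31.0.dev0' into a tuple of leading integers."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed, required="0.30.0"):
    """True if the installed version is at least the required one."""
    return version_tuple(installed) >= version_tuple(required)


try:
    installed = metadata.version("diffusers")
    print("diffusers", installed, "OK" if meets_minimum(installed) else "too old")
except metadata.PackageNotFoundError:
    print("diffusers is not installed in this environment")
```

The check uses only the standard library, so it can run before the conda environment is fully set up.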

In [None]:
# %% [1] Shared Cache Bootstrap
import os, pathlib, torch
import sys
from datetime import datetime

# Shared cache configuration (copy into every notebook)
AI_CACHE_ROOT = os.getenv("AI_CACHE_ROOT", "../ai_warehouse/cache")

for k, v in {
    "HF_HOME": f"{AI_CACHE_ROOT}/hf",
    "TRANSFORMERS_CACHE": f"{AI_CACHE_ROOT}/hf/transformers",
    "HF_DATASETS_CACHE": f"{AI_CACHE_ROOT}/hf/datasets",
    "HUGGINGFACE_HUB_CACHE": f"{AI_CACHE_ROOT}/hf/hub",
    "TORCH_HOME": f"{AI_CACHE_ROOT}/torch",
}.items():
    os.environ[k] = v
    pathlib.Path(v).mkdir(parents=True, exist_ok=True)
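print("[Cache]", AI_CACHE_ROOT, "| GPU:", torch.cuda.is_available())

# SMOKE_MODE is used by later cells but is never defined in this notebook.
# Assumption: it is toggled through an environment variable of the same name.
SMOKE_MODE = os.getenv("SMOKE_MODE", "0").lower() in ("1", "true")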

Cell 2: DiT Architecture Theory & Comparison

In [None]:
# DiT vs UNet Architecture Comparison
import matplotlib.pyplot as plt
import numpy as np


def plot_architecture_comparison():
    """Visualize UNet vs DiT architecture differences"""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 8))

    # UNet Architecture (Stable Diffusion)
    ax1.set_title(
        "UNet Architecture (Stable Diffusion)", fontsize=14, fontweight="bold"
    )
    ax1.text(
        0.5,
        0.9,
        "Text Encoder (CLIP)\n↓",
        ha="center",
        va="center",
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue"),
    )
    ax1.text(
        0.5,
        0.7,
        "Cross-Attention Layers\n(Text Conditioning)",
        ha="center",
        va="center",
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgreen"),
    )
    ax1.text(
        0.5,
        0.5,
        "UNet Backbone\n(CNN + ResNet + Attention)",
        ha="center",
        va="center",
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightyellow"),
    )
    ax1.text(
        0.5,
        0.3,
        "VAE Decoder\n↓",
        ha="center",
        va="center",
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral"),
    )
    ax1.text(
        0.5,
        0.1,
        "Final Image",
        ha="center",
        va="center",
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"),
    )

    # DiT Architecture (PixArt-α)
    ax2.set_title("DiT Architecture (PixArt-α)", fontsize=14, fontweight="bold")
    ax2.text(
        0.5,
        0.9,
        "T5-XXL Text Encoder\n↓",
        ha="center",
        va="center",
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue"),
    )
    ax2.text(
        0.5,
        0.7,
        "Cross-Attention (Text)\n+ adaLN-single (Timestep)",
        ha="center",
        va="center",
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgreen"),
    )
    ax2.text(
        0.5,
        0.5,
        "DiT Blocks\n(Pure Transformer)",
        ha="center",
        va="center",
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightyellow"),
    )
    ax2.text(
        0.5,
        0.3,
        "VAE Decoder\n↓",
        ha="center",
        va="center",
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral"),
    )
    ax2.text(
        0.5,
        0.1,
        "Final Image",
        ha="center",
        va="center",
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"),
    )

    for ax in [ax1, ax2]:
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis("off")

    plt.tight_layout()
    plt.show()


# Key Differences Table
print("🏗️ DiT vs UNet Key Differences:\n")
comparison_data = [
    ["Architecture", "CNN-based UNet", "Pure Transformer (DiT)"],
    # PixArt-α injects text through cross-attention; adaLN-single carries the timestep
    ["Text Conditioning", "Cross-Attention", "Cross-Attention (adaLN-single for timestep)"],
    ["Text Encoder", "CLIP (77 tokens)", "T5-XXL (120 tokens)"],
    ["Scalability", "Limited by CNN constraints", "Highly scalable with model size"],
    ["Training Stability", "Mature & stable", "More complex but flexible"],
    ["Memory Efficiency", "Better for inference", "Higher memory requirements"],
    ["Fine-tuning", "LoRA on UNet blocks", "LoRA on Transformer blocks"],
]

for row in comparison_data:
    print(f"{'📊 ' + row[0]:<20} | {'🔵 ' + row[1]:<25} | {'🟢 ' + row[2]}")

plot_architecture_comparison()
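
To make the adaptive layer norm idea concrete, here is a minimal pure-Python sketch of the modulation used in DiT-style blocks: a token is normalized without learned affine parameters, then scaled and shifted by values that a real model would regress from the conditioning signal (the timestep embedding in PixArt-α). The function name `adaln_modulate` and the hard-coded `shift`/`scale` values are illustrative assumptions, not diffusers code.

```python
import math


def adaln_modulate(token, shift, scale, eps=1e-6):
    """LayerNorm without learned affine, then x * (1 + scale) + shift."""
    mean = sum(token) / len(token)
    var = sum((x - mean) ** 2 for x in token) / len(token)
    normed = [(x - mean) / math.sqrt(var + eps) for x in token]
    return [n * (1.0 + s) + b for n, s, b in zip(normed, scale, shift)]


token = [1.0, 2.0, 3.0, 4.0]
zeros = [0.0] * 4

# With zero shift and scale the result is a plain layer norm
plain = adaln_modulate(token, shift=zeros, scale=zeros)
# Non-zero values let the conditioning signal rescale and offset every feature
modulated = adaln_modulate(token, shift=[0.5] * 4, scale=[1.0] * 4)

print(plain)
print(modulated)
```

With zero shift and scale the conditioning has no effect, which is why zero-initializing the modulation is a convenient starting point for training.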

Cell 3: PixArt-α Model Loading & VRAM Optimization

In [None]:
# PixArt-α model loading with VRAM optimization
from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler
import torch
from diffusers.utils import logging

logging.set_verbosity_error()  # Reduce diffusers logging


def load_pixart_pipeline(
    model_id="PixArt-alpha/PixArt-XL-2-1024-MS",
    torch_dtype=torch.float16,
    enable_cpu_offload=True,
    enable_attention_slicing=True,
    enable_xformers=True,
):
    """Load PixArt-α pipeline with memory optimizations"""

    print(f"🚀 Loading PixArt-α: {model_id}")
    print(f"   • dtype: {torch_dtype}")
    print(f"   • cpu_offload: {enable_cpu_offload}")
    print(f"   • attention_slicing: {enable_attention_slicing}")

    # Load pipeline with optimizations
    pipe = PixArtAlphaPipeline.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        use_safetensors=True,
        variant="fp16" if torch_dtype == torch.float16 else None,
    )

    # Apply memory optimizations
    if enable_cpu_offload:
        pipe.enable_model_cpu_offload()
        print("   ✅ CPU offload enabled")
    else:
        pipe = pipe.to("cuda")

    if enable_attention_slicing:
        pipe.enable_attention_slicing()
        print("   ✅ Attention slicing enabled")

    if enable_xformers and torch.cuda.is_available():
        try:
            pipe.enable_xformers_memory_efficient_attention()
            print("   ✅ xFormers attention enabled")
        except Exception as e:
            print(f"   ⚠️ xFormers failed: {e}")

    # Use DPM++ scheduler for better quality
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="dpmsolver++"
    )
    print("   ✅ DPM++ scheduler configured")

    return pipe


# Load the pipeline
try:
    # Choose model variant based on VRAM
    if SMOKE_MODE:
        model_id = "PixArt-alpha/PixArt-XL-2-512-MS"  # Smaller for testing
        print("🧪 SMOKE_MODE: Using 512px model")
    else:
        model_id = "PixArt-alpha/PixArt-XL-2-1024-MS"  # Full 1024px model

    pixart_pipe = load_pixart_pipeline(
        model_id=model_id,
        torch_dtype=torch.float16,
        enable_cpu_offload=True,  # Essential for 6-8GB VRAM
        enable_attention_slicing=True,
        enable_xformers=True,
    )

    print(f"\n✅ PixArt-α pipeline loaded successfully!")
    print(f"   Model: {model_id}")
    print(f"   Text Encoder: {pixart_pipe.text_encoder.__class__.__name__}")
    print(f"   Transformer: {pixart_pipe.transformer.__class__.__name__}")

except Exception as e:
    print(f"❌ Failed to load PixArt-α: {e}")
    pixart_pipe = None

Cell 4: T5 Text Encoder Analysis

In [None]:
# Analyze T5-XXL text encoder capabilities
if pixart_pipe is not None:
    print("🔍 T5-XXL Text Encoder Analysis:\n")

    # Get text encoder details
    text_encoder = pixart_pipe.text_encoder
    tokenizer = pixart_pipe.tokenizer

    print(f"📝 Model: {text_encoder.config.name_or_path}")
    print(f"📏 Hidden size: {text_encoder.config.d_model}")
    # T5 uses relative position biases, so this is the tokenizer limit, not a position table size
    print(f"🔢 Tokenizer max length: {tokenizer.model_max_length}")
    print(
        f"💾 Parameters: ~{sum(p.numel() for p in text_encoder.parameters()) / 1e9:.1f}B"
    )

    # Test tokenization with complex prompt
    test_prompts = [
        "A majestic dragon breathing fire",
        "A cyberpunk cityscape at night with neon lights reflecting on wet streets, highly detailed, cinematic lighting, 8k resolution",
        "Portrait of a beautiful woman with flowing hair in Renaissance style, oil painting, dramatic chiaroscuro lighting, masterpiece",
    ]

    print(f"\n🧪 Tokenization Analysis:")
    for i, prompt in enumerate(test_prompts, 1):
        tokens = tokenizer.encode(prompt, add_special_tokens=True)
        print(f"\n{i}. Prompt: {prompt[:50]}...")
        print(f"   Token count: {len(tokens)}")
        print(
            f"   Token capacity usage: {len(tokens)}/{tokenizer.model_max_length} ({len(tokens)/tokenizer.model_max_length*100:.1f}%)"
        )

        # Show first few tokens for analysis
        decoded_tokens = [tokenizer.decode([t]) for t in tokens[:10]]
        print(f"   First tokens: {decoded_tokens}")
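
The capacity figure in this cell is computed against `tokenizer.model_max_length`, but `PixArtAlphaPipeline` truncates prompts to its own `max_sequence_length`, which defaults to 120. The following sketch shows that budget in isolation; `token_budget` is a hypothetical helper name, and the token counts passed in are arbitrary examples.

```python
def token_budget(n_tokens, limit=120):
    """Report how much of a token limit a prompt uses and how much is cut off."""
    kept = min(n_tokens, limit)
    return {
        "kept": kept,
        "truncated": max(n_tokens - limit, 0),
        "usage_pct": 100.0 * kept / limit,
    }


print(token_budget(30))   # short prompt, fits comfortably
print(token_budget(150))  # long prompt, the tail would be dropped
```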

Cell 5: Basic PixArt-α Inference

In [None]:
# Basic PixArt-α inference with parameter exploration
import time
from pathlib import Path


def generate_pixart_image(
    prompt,
    negative_prompt="blurry, low quality, distorted",
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=4.5,
    num_images_per_prompt=1,
    generator=None,
    save_path=None,
):
    """Generate image with PixArt-α pipeline"""

    if pixart_pipe is None:
        print("❌ PixArt pipeline not loaded")
        return None, 0  # keep the (image, time) shape that callers unpack

    print(f"🎨 Generating with PixArt-α:")
    print(f"   Prompt: {prompt}")
    print(f"   Size: {width}x{height}")
    print(f"   Steps: {num_inference_steps}")
    print(f"   CFG: {guidance_scale}")

    start_time = time.time()

    try:
        # Generate image
        with torch.inference_mode():
            result = pixart_pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_images_per_prompt,
                generator=generator,
                return_dict=True,
            )

        generation_time = time.time() - start_time
        print(f"⏱️ Generation time: {generation_time:.2f}s")

        image = result.images[0]

        # Save if path provided
        if save_path:
            image.save(save_path)
            print(f"💾 Saved to: {save_path}")

        return image, generation_time

    except Exception as e:
        print(f"❌ Generation failed: {e}")
        return None, 0


# Test generation with different parameters
test_prompts = [
    "A serene landscape with mountains and a lake, golden hour lighting",
    "Portrait of a wise old wizard with a long beard, fantasy art style",
    "Futuristic city with flying cars, cyberpunk aesthetic, neon lights",
]

# Create output directory
output_dir = Path("outputs/pixart_test")
output_dir.mkdir(parents=True, exist_ok=True)

# Generate test images
generator = torch.Generator(device="cpu").manual_seed(42)
results = []

for i, prompt in enumerate(test_prompts):
    print(f"\n{'='*50}")
    print(f"Test {i+1}/{len(test_prompts)}")

    # Adjust parameters for SMOKE_MODE
    if SMOKE_MODE:
        height, width = 512, 512
        num_steps = 10
        print("🧪 SMOKE_MODE: Using 512px, 10 steps")
    else:
        height, width = 1024, 1024
        num_steps = 28

    image, gen_time = generate_pixart_image(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=num_steps,
        guidance_scale=4.5,
        generator=generator,
        save_path=output_dir / f"pixart_test_{i+1}.png",
    )

    if image:
        results.append(
            {
                "prompt": prompt,
                "image": image,
                "generation_time": gen_time,
                "resolution": f"{width}x{height}",
                "steps": num_steps,
            }
        )
        display(image)

    # Memory cleanup (only query CUDA statistics when a GPU is present)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(f"📊 VRAM usage: {torch.cuda.memory_allocated() / 1e9:.2f}GB")

print(f"\n✅ Generated {len(results)} images")

Cell 6: PixArt vs SD Performance Comparison

In [None]:
# Load Stable Diffusion for comparison
from diffusers import StableDiffusionPipeline


def load_sd_pipeline():
    """Load SD1.5 for comparison"""
    try:
        sd_pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        sd_pipe.enable_model_cpu_offload()
        sd_pipe.enable_attention_slicing()

        # Use same scheduler as PixArt for fair comparison
        sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            sd_pipe.scheduler.config,
            use_karras_sigmas=True,
            algorithm_type="dpmsolver++",
        )

        print("✅ SD1.5 pipeline loaded for comparison")
        return sd_pipe
    except Exception as e:
        print(f"❌ Failed to load SD1.5: {e}")
        return None


# Comparison test
if not SMOKE_MODE and pixart_pipe is not None:
    print("🔬 PixArt-α vs Stable Diffusion Comparison\n")

    sd_pipe = load_sd_pipeline()

    if sd_pipe is not None:
        comparison_prompt = (
            "A beautiful sunset over a mountain lake with perfect reflections"
        )

        # Test same prompt with both models
        generator = torch.Generator(device="cpu").manual_seed(42)

        print("🎨 Generating with PixArt-α...")
        pixart_img, pixart_time = generate_pixart_image(
            prompt=comparison_prompt,
            height=1024,
            width=1024,
            num_inference_steps=28,
            guidance_scale=4.5,
            generator=generator,
            save_path=output_dir / "comparison_pixart.png",
        )

        print("\n🎨 Generating with SD1.5...")
        generator = torch.Generator(device="cpu").manual_seed(42)
        start_time = time.time()

        with torch.inference_mode():
            sd_result = sd_pipe(
                prompt=comparison_prompt,
                height=512,  # SD1.5 native resolution
                width=512,
                num_inference_steps=28,
                guidance_scale=7.5,
                generator=generator,
            )

        sd_time = time.time() - start_time
        sd_img = sd_result.images[0]
        sd_img.save(output_dir / "comparison_sd15.png")

        # Display comparison
        fig, axes = plt.subplots(1, 2, figsize=(15, 7))

        axes[0].imshow(pixart_img)
        axes[0].set_title(
            f"PixArt-α (1024x1024)\nTime: {pixart_time:.2f}s", fontsize=12
        )
        axes[0].axis("off")

        axes[1].imshow(sd_img)
        axes[1].set_title(f"SD1.5 (512x512)\nTime: {sd_time:.2f}s", fontsize=12)
        axes[1].axis("off")

        plt.suptitle(f"Prompt: {comparison_prompt[:50]}...", fontsize=14)
        plt.tight_layout()
        plt.show()

        # Performance metrics
        print(f"\n📊 Performance Comparison:")
        print(
            f"   PixArt-α: {pixart_time:.2f}s for 1024x1024 ({1024*1024/pixart_time/1000:.1f}K pixels/s)"
        )
        print(
            f"   SD1.5:    {sd_time:.2f}s for 512x512 ({512*512/sd_time/1000:.1f}K pixels/s)"
        )
        print(
            f"   Resolution advantage: PixArt-α generates {(1024/512)**2:.1f}x more pixels"
        )

        # Cleanup
        del sd_pipe
        torch.cuda.empty_cache()

Cell 7: DiT Model Architecture Inspection

In [None]:
# Inspect DiT transformer architecture
if pixart_pipe is not None:
    print("🔍 DiT Transformer Architecture Analysis:\n")

    transformer = pixart_pipe.transformer

    # Model configuration (the diffusers config stores attention_head_dim, not hidden_size)
    config = transformer.config
    hidden_size = config.num_attention_heads * config.attention_head_dim
    print(f"📋 Model Configuration:")
    print(f"   • Model type: {transformer.__class__.__name__}")
    print(f"   • Hidden size: {hidden_size}")
    print(f"   • Num layers: {config.num_layers}")
    print(f"   • Num attention heads: {config.num_attention_heads}")
    print(f"   • Patch size: {config.patch_size}")
    print(f"   • Latent sample size: {config.sample_size}")

    # Parameter counting
    total_params = sum(p.numel() for p in transformer.parameters())
    trainable_params = sum(
        p.numel() for p in transformer.parameters() if p.requires_grad
    )

    print(f"\n📊 Parameter Statistics:")
    print(f"   • Total parameters: {total_params / 1e6:.1f}M")
    print(f"   • Trainable parameters: {trainable_params / 1e6:.1f}M")
    print(f"   • Memory footprint (fp16): ~{total_params * 2 / 1e9:.2f}GB")

    # Layer inspection
    print(f"\n🏗️ Layer Structure:")
    for i, (name, module) in enumerate(transformer.named_children()):
        if i < 5:  # Show first 5 layers
            param_count = sum(p.numel() for p in module.parameters())
            print(
                f"   • {name}: {module.__class__.__name__} ({param_count / 1e6:.1f}M params)"
            )

    # Attention pattern analysis
    if hasattr(transformer, "transformer_blocks"):
        block = transformer.transformer_blocks[0]
        print(f"\n🧠 Transformer Block Structure:")
        for name, layer in block.named_children():
            param_count = sum(p.numel() for p in layer.parameters())
            print(
                f"   • {name}: {layer.__class__.__name__} ({param_count / 1e6:.1f}M params)"
            )
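
As a sanity check on the parameter count printed above, the bulk of a PixArt-style transformer can be estimated from its width and depth alone. The sketch below assumes the DiT-XL/2 shape (hidden size 1152, 28 blocks), an MLP expansion ratio of 4, and one self-attention plus one cross-attention per block. It ignores biases, embeddings, and the conditioning MLPs, so it should land somewhat under the real total. `estimate_block_params` is a hypothetical helper name.

```python
def estimate_block_params(hidden, mlp_ratio=4):
    """Weights in one block: self-attn + cross-attn (4*d^2 each) + MLP (2*ratio*d^2)."""
    attention = 4 * hidden * hidden  # Q, K, V and output projections
    mlp = 2 * mlp_ratio * hidden * hidden  # up- and down-projection
    return 2 * attention + mlp


hidden, depth = 1152, 28
total = depth * estimate_block_params(hidden)

print(f"~{total / 1e6:.0f}M weights in transformer blocks")
print(f"~{total * 2 / 1e9:.2f}GB in fp16")
```

The estimate comes out just under 600M, which is consistent with the roughly 0.6B figure quoted for the PixArt-α transformer.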

Cell 8: Advanced Parameter Tuning

In [None]:
# Advanced parameter exploration for PixArt-α
def parameter_sweep_experiment():
    """Test different parameter combinations"""

    if pixart_pipe is None or SMOKE_MODE:
        print("⏭️ Skipping parameter sweep (pipeline not loaded or SMOKE_MODE)")
        return

    print("🧪 Parameter Sweep Experiment\n")

    base_prompt = "A magical forest with glowing mushrooms and fireflies"

    # Test different guidance scales
    guidance_scales = [3.0, 4.5, 6.0, 8.0]
    step_counts = [20, 28, 40]

    results_grid = []

    for steps in step_counts:
        for cfg in guidance_scales:
            print(f"🎯 Testing: Steps={steps}, CFG={cfg}")

            generator = torch.Generator(device="cpu").manual_seed(42)

            image, gen_time = generate_pixart_image(
                prompt=base_prompt,
                height=512,  # Smaller for faster testing
                width=512,
                num_inference_steps=steps,
                guidance_scale=cfg,
                generator=generator,
                save_path=output_dir / f"param_test_s{steps}_cfg{cfg}.png",
            )

            if image:
                results_grid.append(
                    {"steps": steps, "cfg": cfg, "image": image, "time": gen_time}
                )

            torch.cuda.empty_cache()

    # Create comparison grid
    if results_grid:
        fig, axes = plt.subplots(
            len(step_counts), len(guidance_scales), figsize=(16, 12)
        )

        for i, steps in enumerate(step_counts):
            for j, cfg in enumerate(guidance_scales):
                # Find matching result
                result = next(
                    (
                        r
                        for r in results_grid
                        if r["steps"] == steps and r["cfg"] == cfg
                    ),
                    None,
                )

                if result:
                    axes[i, j].imshow(result["image"])
                    axes[i, j].set_title(
                        f"Steps: {steps}, CFG: {cfg}\n{result['time']:.1f}s"
                    )
                else:
                    axes[i, j].text(0.5, 0.5, "Failed", ha="center", va="center")

                axes[i, j].axis("off")

        plt.suptitle(f"PixArt-α Parameter Sweep\nPrompt: {base_prompt}", fontsize=14)
        plt.tight_layout()
        plt.show()

        # Performance analysis
        print(f"\n📊 Performance Summary:")
        for result in results_grid:
            efficiency = 512 * 512 / result["time"] / 1000  # K pixels/second
            print(
                f"   Steps {result['steps']:2d}, CFG {result['cfg']:3.1f}: "
                f"{result['time']:5.1f}s ({efficiency:4.1f}K px/s)"
            )


# Run parameter sweep
parameter_sweep_experiment()
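
The `guidance_scale` values swept above control classifier-free guidance, which extrapolates from the unconditional noise prediction toward the text-conditioned one. A minimal sketch of that combination on plain Python lists follows; `apply_cfg` and the toy vectors are illustrative assumptions, whereas the pipeline applies the same formula to latent tensors.

```python
def apply_cfg(noise_uncond, noise_text, guidance_scale):
    """Classifier-free guidance: uncond + scale * (text - uncond), element-wise."""
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]


uncond = [0.0, 1.0, -1.0]
text = [1.0, 1.0, 0.0]

for scale in (1.0, 4.5, 8.0):
    print(scale, apply_cfg(uncond, text, scale))
```

A scale of 1.0 reproduces the text-conditioned prediction unchanged. Larger values push further along the direction in which the two predictions differ, which is why high CFG tends to increase prompt adherence at the cost of more saturated or distorted results.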

Cell 9: Memory Usage Analysis

In [None]:
# Detailed memory usage analysis
def analyze_memory_usage():
    """Analyze memory usage patterns during generation"""

    if not torch.cuda.is_available() or pixart_pipe is None:
        print("⏭️ Skipping memory analysis (no CUDA or pipeline)")
        return

    print("🧠 Memory Usage Analysis\n")

    def get_memory_stats():
        return {
            "allocated": torch.cuda.memory_allocated() / 1e9,
            "reserved": torch.cuda.memory_reserved() / 1e9,
            "max_allocated": torch.cuda.max_memory_allocated() / 1e9,
        }

    # Clear memory and reset stats
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    print("📊 Memory checkpoints:")

    # Baseline
    baseline = get_memory_stats()
    print(
        f"1. Baseline: {baseline['allocated']:.2f}GB allocated, {baseline['reserved']:.2f}GB reserved"
    )

    # During text encoding
    prompt = "A detailed cyberpunk cityscape with neon signs and rain"

    with torch.inference_mode():
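        # Text encoding phase (encode_prompt returns embeddings and attention masks)
        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = pixart_pipe.encode_prompt(
            prompt=prompt,
            negative_prompt="blurry, low quality",
            do_classifier_free_guidance=True,
            num_images_per_prompt=1,
        )

        text_encoding = get_memory_stats()
        print(f"2. After text encoding: {text_encoding['allocated']:.2f}GB allocated")

        # During denoising (sample a few steps)
        if not SMOKE_MODE:
            # Initialize latents
            height, width = 1024, 1024
            latents = pixart_pipe.prepare_latents(
                batch_size=1,
                num_channels_latents=4,
                height=height,
                width=width,
                dtype=torch.float16,
                device="cuda",
                generator=torch.Generator(device="cpu").manual_seed(42),
            )

            latent_init = get_memory_stats()
            print(f"3. After latent init: {latent_init['allocated']:.2f}GB allocated")

            # Single denoising step (conditional branch only). The 1024px
            # checkpoints also expect resolution/aspect-ratio micro-conditioning.
            timestep = torch.tensor([500], device="cuda")
            added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
            if pixart_pipe.transformer.config.sample_size == 128:
                added_cond_kwargs = {
                    "resolution": torch.tensor(
                        [[height, width]], dtype=torch.float16, device="cuda"
                    ),
                    "aspect_ratio": torch.tensor(
                        [[height / width]], dtype=torch.float16, device="cuda"
                    ),
                }
            noise_pred = pixart_pipe.transformer(
                latents,
                encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=prompt_attention_mask,
                timestep=timestep,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]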

            denoising = get_memory_stats()
            print(
                f"4. After transformer forward: {denoising['allocated']:.2f}GB allocated"
            )

            # VAE decode
            with torch.no_grad():
                image = pixart_pipe.vae.decode(
                    latents / pixart_pipe.vae.config.scaling_factor, return_dict=False
                )[0]

            vae_decode = get_memory_stats()
            print(f"5. After VAE decode: {vae_decode['allocated']:.2f}GB allocated")

        # Peak memory
        peak = get_memory_stats()
        print(f"\n🔝 Peak memory: {peak['max_allocated']:.2f}GB")

    # Memory optimization recommendations
    print(f"\n💡 Memory Optimization Tips:")
    print(f"   • Enable CPU offload: Reduces VRAM by ~2-4GB")
    print(f"   • Use attention slicing: Reduces peak memory during attention")
    print(f"   • Lower resolution: 512x512 has 4x fewer latent tokens than 1024x1024")
    print(f"   • Batch size 1: Avoid multiple images simultaneously")
    print(f"   • torch.cuda.empty_cache(): Clear memory between generations")


# Run memory analysis
analyze_memory_usage()
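
The resolution tip can be made concrete by counting the tokens the transformer actually sees. The sketch below assumes the usual 8x VAE downsampling and the patch size of 2 implied by the "XL-2" model name; `latent_tokens` is a hypothetical helper name.

```python
def latent_tokens(height, width, vae_factor=8, patch_size=2):
    """Number of transformer tokens for an image of the given pixel size."""
    side = vae_factor * patch_size
    return (height // side) * (width // side)


small = latent_tokens(512, 512)
large = latent_tokens(1024, 1024)

print(small, large)
print("token ratio:", large / small)
print("attention-pair ratio:", (large / small) ** 2)
```

Token count grows 4x from 512px to 1024px, while the number of token pairs in self-attention grows 16x. Activation memory therefore depends heavily on whether the attention implementation materializes the full attention matrix.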

Cell 10: DiT Family Model Comparison

In [None]:
# Compare different DiT family models
def compare_dit_models():
    """Compare PixArt-α variants and other DiT models"""

    print("🔬 DiT Family Model Comparison\n")

    # Model specifications
    dit_models = {
        "PixArt-α-512": {
            "model_id": "PixArt-alpha/PixArt-XL-2-512-MS",
            "resolution": "512x512",
            "parameters": "611M",
            "text_encoder": "T5-XXL",
            "strengths": "Fast inference, lower VRAM",
            "use_case": "Quick prototyping, low-VRAM setups",
        },
        "PixArt-α-1024": {
            "model_id": "PixArt-alpha/PixArt-XL-2-1024-MS",
            "resolution": "1024x1024",
            "parameters": "611M",
            "text_encoder": "T5-XXL",
            "strengths": "High resolution, detailed outputs",
            "use_case": "High-quality image generation",
        },
        "PixArt-Σ": {
            "model_id": "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
            "resolution": "1024x1024",
            "parameters": "611M",
            "text_encoder": "T5-XXL",
            "strengths": "Improved training, better quality",
            "use_case": "Production use, fine-tuning base",
        },
    }

    # Display comparison table
    print("📋 Model Comparison Table:\n")
    headers = ["Model", "Resolution", "Params", "Text Encoder", "Key Strengths"]

    # Print header
    print(
        f"{'Model':<15} | {'Resolution':<12} | {'Params':<8} | {'Text Enc':<10} | {'Strengths'}"
    )
    print("-" * 80)

    for name, info in dit_models.items():
        # Print the full strengths text; a fixed "..." suffix mislabels short entries
        print(
            f"{name:<15} | {info['resolution']:<12} | {info['parameters']:<8} | "
            f"{info['text_encoder']:<10} | {info['strengths']}"
        )

    print(f"\n🎯 Model Selection Guide:")
    for name, info in dit_models.items():
        print(f"\n• {name}:")
        print(f"  Use case: {info['use_case']}")
        print(f"  Model ID: {info['model_id']}")

    # Performance expectations
    print(f"\n⚡ Performance Expectations (approximate):")
    print(f"   • PixArt-α-512:  ~15-25s on RTX 3080 (28 steps)")
    print(f"   • PixArt-α-1024: ~25-40s on RTX 3080 (28 steps)")
    print(f"   • PixArt-Σ:      ~25-40s on RTX 3080 (28 steps, better quality)")

    print(f"\n🔧 VRAM Requirements:")
    print(f"   • 6GB VRAM:  PixArt-α-512 with CPU offload")
    print(f"   • 8GB VRAM:  PixArt-α-1024 with CPU offload")
    print(f"   • 12GB VRAM: All models without CPU offload")
    print(f"   • 16GB+ VRAM: Batch generation, fine-tuning preparation")


compare_dit_models()
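
The VRAM table printed by this cell can be turned into a small selection helper. This is a sketch that encodes only the thresholds listed above; `suggest_config` is a hypothetical name, and the thresholds are the notebook's approximate guidance rather than measured limits.

```python
def suggest_config(vram_gb):
    """Map available VRAM to a (model_id, cpu_offload) suggestion from the table above."""
    if vram_gb >= 12:
        return "PixArt-alpha/PixArt-XL-2-1024-MS", False
    if vram_gb >= 8:
        return "PixArt-alpha/PixArt-XL-2-1024-MS", True
    if vram_gb >= 6:
        return "PixArt-alpha/PixArt-XL-2-512-MS", True
    return None, True  # below the notebook's stated 6GB minimum


for gb in (4, 6, 8, 12, 24):
    print(gb, suggest_config(gb))
```

In a live session the input could come from `torch.cuda.get_device_properties(0).total_memory / 1e9`, and the result could be passed to `load_pixart_pipeline` as `model_id` and `enable_cpu_offload`.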

Cell 11: Smoke Test & CI Validation

In [None]:
# Smoke test for CI/CD validation
def run_smoke_test():
    """Minimal test for CI validation"""

    print("🧪 Running Smoke Test for CI Validation\n")

    test_results = {
        "environment": False,
        "model_loading": False,
        "basic_inference": False,
        "memory_management": False,
    }

    try:
        # Test 1: Environment validation
        assert torch.cuda.is_available(), "CUDA not available"
        assert pixart_pipe is not None, "PixArt pipeline not loaded"
        test_results["environment"] = True
        print("✅ Environment validation passed")

        # Test 2: Model loading check
        assert hasattr(pixart_pipe, "transformer"), "Transformer not found"
        assert hasattr(pixart_pipe, "text_encoder"), "Text encoder not found"
        assert hasattr(pixart_pipe, "vae"), "VAE not found"
        test_results["model_loading"] = True
        print("✅ Model loading validation passed")

        # Test 3: Basic inference (minimal)
        if SMOKE_MODE:
            test_prompt = "A simple cat"
            test_image, test_time = generate_pixart_image(
                prompt=test_prompt,
                height=256,  # Very small for speed
                width=256,
                num_inference_steps=4,  # Minimal steps
                guidance_scale=4.5,
                generator=torch.Generator(device="cpu").manual_seed(42),
            )

            assert test_image is not None, "Failed to generate test image"
            assert test_time > 0, "Invalid generation time"
            test_results["basic_inference"] = True
            print(f"✅ Basic inference passed ({test_time:.2f}s)")
        else:
            test_results["basic_inference"] = True
            print("✅ Basic inference skipped (not SMOKE_MODE)")

        # Test 4: Memory management
        if torch.cuda.is_available():
            initial_memory = torch.cuda.memory_allocated()
            torch.cuda.empty_cache()
            final_memory = torch.cuda.memory_allocated()
            assert final_memory <= initial_memory, "Memory not properly cleared"
            test_results["memory_management"] = True
            print("✅ Memory management passed")

    except Exception as e:
        print(f"❌ Smoke test failed: {e}")
        return False

    # Summary
    passed_tests = sum(test_results.values())
    total_tests = len(test_results)

    print(f"\n📊 Smoke Test Results: {passed_tests}/{total_tests} passed")

    if passed_tests == total_tests:
        print("🎉 All smoke tests passed! Ready for CI/CD")
        return True
    else:
        print("⚠️ Some tests failed. Check configuration.")
        return False


# Run smoke test
smoke_test_passed = run_smoke_test()

Cell 12: Performance Benchmarking

In [None]:
# Comprehensive performance benchmark
def benchmark_pixart_performance():
    """Benchmark PixArt-α performance across different configurations"""

    if SMOKE_MODE or pixart_pipe is None:
        print("⏭️ Skipping benchmark (SMOKE_MODE or no pipeline)")
        return

    print("🏁 PixArt-α Performance Benchmark\n")

    benchmark_configs = [
        {"res": (512, 512), "steps": 20, "cfg": 4.5, "desc": "Fast"},
        {"res": (512, 512), "steps": 28, "cfg": 4.5, "desc": "Balanced"},
        {"res": (768, 768), "steps": 28, "cfg": 4.5, "desc": "High-Res"},
        {"res": (1024, 1024), "steps": 28, "cfg": 4.5, "desc": "Max-Res"},
    ]

    test_prompt = "A photorealistic portrait of a person with natural lighting"
    benchmark_results = []

    for i, config in enumerate(benchmark_configs):
        print(
            f"🎯 Test {i+1}/{len(benchmark_configs)} - {config['desc']} ({config['res'][0]}x{config['res'][1]}, {config['steps']} steps)"
        )

        try:
            # Clear memory before each test
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

            generator = torch.Generator(device="cpu").manual_seed(42)

            start_time = time.time()

            image, gen_time = generate_pixart_image(
                prompt=test_prompt,
                height=config["res"][1],
                width=config["res"][0],
                num_inference_steps=config["steps"],
                guidance_scale=config["cfg"],
                generator=generator,
            )

            if image:
                total_pixels = config["res"][0] * config["res"][1]
                pixels_per_second = total_pixels / gen_time
                peak_memory = torch.cuda.max_memory_allocated() / 1e9

                result = {
                    "config": config["desc"],
                    "resolution": f"{config['res'][0]}x{config['res'][1]}",
                    "steps": config["steps"],
                    "time": gen_time,
                    "pixels_per_sec": pixels_per_second,
                    "peak_memory": peak_memory,
                    "image": image,
                }

                benchmark_results.append(result)

                print(f"   ⏱️ Time: {gen_time:.2f}s")
                print(f"   🖼️ Throughput: {pixels_per_second/1000:.1f}K pixels/s")
                print(f"   🧠 Peak VRAM: {peak_memory:.2f}GB")

        except Exception as e:
            print(f"   ❌ Failed: {e}")

    # Results summary
    if benchmark_results:
        print(f"\n📊 Benchmark Summary:")
        print(
            f"{'Config':<12} | {'Resolution':<12} | {'Time (s)':<10} | {'Throughput':<12} | {'VRAM (GB)'}"
        )
        print("-" * 65)

        for result in benchmark_results:
            throughput_str = f"{result['pixels_per_sec']/1000:.1f}K px/s"
            print(
                f"{result['config']:<12} | {result['resolution']:<12} | "
                f"{result['time']:<10.2f} | {throughput_str:<12} | {result['peak_memory']:<8.2f}"
            )

        # Find best configurations
        fastest = min(benchmark_results, key=lambda x: x["time"])
        most_efficient = max(benchmark_results, key=lambda x: x["pixels_per_sec"])

        print(f"\n🏆 Performance Leaders:")
        print(f"   • Fastest: {fastest['config']} ({fastest['time']:.2f}s)")
        print(
            f"   • Most efficient: {most_efficient['config']} ({most_efficient['pixels_per_sec']/1000:.1f}K px/s)"
        )


# Run benchmark
benchmark_pixart_performance()
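
The benchmark above times each configuration once, so the first run can absorb warm-up costs such as kernel selection and cache population. A minimal sketch of a more robust timing helper follows. `time_callable` and its parameters are hypothetical names, not part of this notebook, and the workload is a stand-in. On a GPU, the wrapped function should call `torch.cuda.synchronize()` before returning so the timer covers all queued kernels.

```python
import statistics
import time


def time_callable(fn, warmup=1, repeats=3):
    """Return (median_seconds, all_timings) for fn(). Hypothetical helper."""
    for _ in range(warmup):
        fn()  # warm-up runs are executed but not timed
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings), timings


# Stand-in workload; in the notebook this would wrap generate_pixart_image(...)
median_s, runs = time_callable(lambda: sum(range(100_000)), warmup=1, repeats=5)
print(f"median: {median_s:.4f}s over {len(runs)} runs")
```

Reporting the median rather than a single measurement makes the per-configuration comparison less sensitive to one slow outlier.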

Cell 13: Cleanup & Summary

In [None]:
# Cleanup and learning summary
def cleanup_and_summarize():
    """Clean up resources and summarize learning outcomes"""

    print("🧹 Cleaning up resources...")

    # Drop the pipeline reference first so its tensors become collectable
    global pixart_pipe
    if "pixart_pipe" in globals():
        pixart_pipe = None

    # Force garbage collection before releasing cached GPU memory
    gc.collect()

    # Only now can the allocator return the freed blocks to the driver
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    print("✅ Resources cleaned up")

    # Learning summary
    print(f"\n📚 Learning Summary - PixArt-α/DiT Architecture:")
    print(f"\n🎯 Key Concepts Learned:")
    print(f"   • DiT (Diffusion Transformers) replaces UNet with pure Transformer")
    print(f"   • T5-XXL text encoder provides richer text understanding (up to 120 tokens vs. CLIP's 77)")
    print(
        f"   • adaLN-single injects timestep conditioning; text features enter via cross-attention"
    )
    print(f"   • Scalable architecture: DiT quality improves as model size and compute grow")
    print(f"   • Higher memory requirements (mostly from T5-XXL) but better text-image alignment")

    print(f"\n⚙️ Technical Skills Gained:")
    print(f"   • Loading and optimizing DiT models for different VRAM setups")
    print(f"   • Understanding memory patterns in Transformer-based diffusion")
    print(f"   • Parameter tuning specific to DiT architecture (lower CFG scales)")
    print(f"   • Performance benchmarking across resolutions and step counts")
    print(f"   • Comparison methodology between different model architectures")

    print(f"\n🚨 Common Pitfalls Identified:")
    print(f"   • The PixArt-α pipeline needs more VRAM than SD 1.5, mostly due to T5-XXL")
    print(f"   • T5-XXL text encoder is large (~11GB) - use CPU offload when VRAM is limited")
    print(f"   • Lower CFG scales (3.0-6.0) work better than SD's 7.5-15.0")
    print(f"   • Generation time scales significantly with resolution")
    print(f"   • Memory fragmentation can cause OOM even with sufficient total VRAM")

    print(f"\n🔄 Next Steps:")
    print(f"   • Explore PixArt-Σ for improved quality")
    print(f"   • Investigate DiT fine-tuning strategies (Stage 3)")
    print(f"   • Compare with other DiT models (DiT-XL) and UNet-based Playground v2.5")
    print(f"   • Integrate into batch pipeline workflows (Stage 4)")
    print(f"   • Test prompt engineering techniques specific to T5-XXL")

    return True


# Execute cleanup and summary
success = cleanup_and_summarize()

if success:
    print(f"\n🎉 Notebook completed successfully!")
    print(f"📁 Outputs saved to: {output_dir}")
    print(f"🔗 Ready to proceed to next notebook in the series")

🎯 Key Takeaways
✅ Completed Learning Objectives

DiT Architecture Understanding - Learned how Transformers replace UNet backbone
PixArt-α Implementation - Successfully loaded and used T5-XXL + DiT pipeline
Performance Analysis - Benchmarked across resolutions and compared with SD
Memory Optimization - Applied CPU offload and attention slicing for low-VRAM setups
Parameter Exploration - Discovered optimal CFG scales and step counts for DiT
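
The CFG scale explored above controls how far the conditional noise prediction is pushed away from the unconditional one. A minimal sketch of the standard classifier-free guidance combination follows, using plain Python lists in place of tensors. The function name `apply_cfg` and the sample values are illustrative only.

```python
def apply_cfg(eps_uncond, eps_cond, guidance_scale):
    """Classifier-free guidance: uncond + s * (cond - uncond), element-wise."""
    return [u + guidance_scale * (c - u) for u, c in zip(eps_uncond, eps_cond)]


eps_uncond = [0.0, 1.0, -1.0]  # prediction for the empty prompt (made-up values)
eps_cond = [1.0, 1.0, 0.0]     # prediction for the actual prompt (made-up values)

for scale in (1.0, 4.5, 7.5):
    print(scale, apply_cfg(eps_uncond, eps_cond, scale))
```

At a scale of 1.0 the result equals the conditional prediction. Larger scales extrapolate beyond it, which is why overly high values tend to oversaturate the output.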

🧠 Core Concepts

DiT = Diffusion Transformers: Pure Transformer architecture for diffusion models
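T5-XXL Integration: Powerful text encoder; PixArt-α feeds it prompts of up to 120 tokens (CLIP in SD is limited to 77)
adaLN-single Conditioning: Timestep is injected via adaptive layer normalization, while text features enter through cross-attention
Scalability: In the DiT paper, sample quality improved steadily as the Transformer grew in size and compute
Memory Trade-offs: Higher VRAM requirements (mostly from T5-XXL) in exchange for better text-image alignment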
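
The adaptive layer normalization named in the list above can be sketched in a few lines. This is a simplified, pure-Python illustration of the adaLN-Zero pattern from the DiT paper, not PixArt-α's actual implementation. In a real model the `shift`, `scale`, and `gate` vectors are produced by a small MLP from the conditioning embedding. Here they are passed in directly, and `identity` stands in for the attention or MLP sublayer.

```python
import math


def layer_norm(x, eps=1e-6):
    """Normalize a feature vector to zero mean and unit variance (no learned affine)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def modulate(x, shift, scale):
    """Adaptive part: scale and shift come from the conditioning, not fixed weights."""
    return [v * (1.0 + s) + b for v, s, b in zip(x, scale, shift)]


def adaln_zero_block(x, shift, scale, gate, sublayer):
    """Residual update: x + gate * sublayer(modulate(LN(x)))."""
    h = sublayer(modulate(layer_norm(x), shift, scale))
    return [xi + g * hi for xi, g, hi in zip(x, gate, h)]


x = [1.0, 2.0, 3.0, 4.0]
zeros = [0.0] * len(x)
identity = lambda h: h  # stand-in for attention or MLP

# With the gate at zero (its initial value in adaLN-Zero) the block is an identity map
print(adaln_zero_block(x, zeros, zeros, zeros, identity))
# A non-zero gate lets the modulated branch contribute
print(adaln_zero_block(x, [0.5] * 4, [0.1] * 4, [1.0] * 4, identity))
```

Starting with a zero gate means every block initially passes its input through unchanged, which the DiT authors report helps training stability.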

⚠️ Common Pitfalls

The PixArt-α pipeline needs more VRAM than SD 1.5, mostly because of the T5-XXL text encoder
T5-XXL text encoder is very large (~11GB) - use CPU offload when VRAM is limited
Lower CFG scales (3.0-6.0) work better than traditional SD values
Memory fragmentation can cause unexpected OOM errors
Generation time increases significantly with resolution
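
The resolution pitfall above follows from how a DiT tokenizes the latent image. A rough back-of-the-envelope sketch, assuming an 8x VAE downsampling factor and a patch size of 2 (commonly cited for PixArt-α, but check the model config before relying on them). The helper name `dit_token_count` is hypothetical.

```python
def dit_token_count(height, width, vae_factor=8, patch_size=2):
    """Number of latent patches (tokens) the transformer processes per image."""
    latent_h, latent_w = height // vae_factor, width // vae_factor
    return (latent_h // patch_size) * (latent_w // patch_size)


base = dit_token_count(512, 512)
for side in (256, 512, 1024):
    tokens = dit_token_count(side, side)
    # Self-attention compares every token with every other token: cost ~ tokens^2
    rel_attention = (tokens / base) ** 2
    print(f"{side}x{side}: {tokens} tokens, attention cost x{rel_attention:.2f} vs 512px")
```

Under these assumptions, doubling the side length quadruples the token count and raises the self-attention cost roughly sixteenfold, so time grows faster than pixel count.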

➡️ Next Steps

Compare with Stable Cascade in nb-cascade-quickstart.ipynb
Explore fine-tuning preparation for Stage 3 DiT adaptation
Integration planning for Stage 4 batch pipelines
Advanced conditioning techniques in Stage 2 continuation

