 # FLUX/Flow-DiT Quickstart 🌊
 
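 **Learning objectives**:
 1. Understand how the Flow-based DiT architecture differs (flow matching vs. diffusion)
 2. Run basic inference with FLUX.1-dev/schnell (Playground v2.5 appears only in the model table)
 3. Explore how guidance_scale and steps affect Flow-DiT output
 4. Memory optimization strategies (12-16GB VRAM)
 5. Flow-DiT vs. SD performance comparison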

 **Prerequisites**: 12GB+ VRAM, diffusers>=0.30.0, torch>=2.3, accelerate (required for CPU offload)

## 📁 Shared Cache Bootstrap

In [None]:
# %% [1] Shared Cache Bootstrap
import os, pathlib, torch
import sys
from datetime import datetime

# Shared cache configuration (copy into every notebook)
AI_CACHE_ROOT = os.getenv("AI_CACHE_ROOT", "../ai_warehouse/cache")

for k, v in {
    "HF_HOME": f"{AI_CACHE_ROOT}/hf",
    "TRANSFORMERS_CACHE": f"{AI_CACHE_ROOT}/hf/transformers",
    "HF_DATASETS_CACHE": f"{AI_CACHE_ROOT}/hf/datasets",
    "HUGGINGFACE_HUB_CACHE": f"{AI_CACHE_ROOT}/hf/hub",
    "TORCH_HOME": f"{AI_CACHE_ROOT}/torch",
}.items():
    os.environ[k] = v
    pathlib.Path(v).mkdir(parents=True, exist_ok=True)
print("[Cache]", AI_CACHE_ROOT, "| GPU:", torch.cuda.is_available())

## 🔧 Environment Check and Imports

In [None]:
# %%
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

import torch
import gc
from PIL import Image
import numpy as np
from diffusers import FluxPipeline, DiffusionPipeline
from transformers import pipeline
import time
from typing import List, Tuple, Optional
import json

# Check environment
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")

# Environment flags
SMOKE_MODE = os.getenv("SMOKE_MODE", "false").lower() == "true"
print(f"SMOKE_MODE: {SMOKE_MODE}")
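The prerequisites at the top (diffusers>=0.30.0, torch>=2.3) can also be checked programmatically. This is a minimal sketch using only the standard library's `importlib.metadata`; `version_tuple`, `meets_min`, and `check_requirements` are hypothetical helpers, and the simple parser deliberately ignores pre-release and dev suffixes, so `0.30.0rc1` is treated as `0.30.0`.

```python
import re
from importlib import metadata
from typing import Dict, Optional, Tuple

# Minimum versions taken from this notebook's prerequisites.
MIN_VERSIONS = {"diffusers": "0.30.0", "torch": "2.3"}


def version_tuple(version: str) -> Tuple[int, ...]:
    """'2.3.1+cu121' -> (2, 3, 1); stops at the first non-numeric component."""
    core = version.split("+")[0]  # drop local build tags such as +cu121
    nums = []
    for part in core.split("."):
        match = re.match(r"\d+", part)
        if match is None:  # e.g. 'dev20240101'
            break
        nums.append(int(match.group()))
    return tuple(nums)


def meets_min(installed: str, minimum: str) -> bool:
    """Compare versions after zero-padding to equal length ('2.3' == '2.3.0')."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_requirements(
    minimums: Optional[Dict[str, str]] = None,
) -> Dict[str, Tuple[Optional[str], bool]]:
    """Return {package: (installed_version_or_None, requirement_met)}."""
    report = {}
    for pkg, minimum in (minimums or MIN_VERSIONS).items():
        try:
            installed = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            report[pkg] = (None, False)
            continue
        report[pkg] = (installed, meets_min(installed, minimum))
    return report


print(check_requirements())
```

Missing packages are reported as `(None, False)` rather than raising, so the cell still prints a full report.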

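 ## 🌊 Understanding the Flow-DiT Architecture
 
 ### Flow Matching vs Diffusion
 
 **Classic diffusion (SD)**:
 - Noise schedule: `β_t` → denoise step by step
 - DDPM/DDIM samplers
 - Relies heavily on classifier-free guidance (CFG)
 
 **Flow Matching (FLUX)**:
 - Velocity-field learning: directly learns the `noise → image` trajectory
 - Mainly Euler/Heun samplers
 - Guidance is integrated more naturally: FLUX.1-dev is guidance-distilled, so `guidance_scale` is fed to the model as an input instead of requiring a separate unconditional pass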
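To make the "straight trajectory" idea concrete, here is a toy 1-D sketch, not FLUX's actual model. It assumes the linear path `x_t = (1 - t) * x_data + t * noise` (t = 1 is pure noise, t = 0 is data), whose target velocity is `noise - x_data`, and a point-mass data distribution so the velocity field is known exactly instead of learned. Because the path is straight, Euler integration lands on the data point with any step count, even one; a trained network only approximates such a field, which is why real models still benefit from several steps.

```python
import random


def velocity(x_t: float, t: float, x_data: float) -> float:
    """Exact velocity dx/dt for a point-mass data distribution at x_data.

    On x_t = (1 - t) * x_data + t * noise we have dx/dt = noise - x_data,
    and noise = (x_t - (1 - t) * x_data) / t, so v = (x_t - x_data) / t.
    """
    return (x_t - x_data) / t


def euler_sample(noise: float, x_data: float, num_steps: int) -> float:
    """Integrate from t=1 (noise) to t=0 (data) with uniform Euler steps."""
    ts = [1.0 - i / num_steps for i in range(num_steps + 1)]
    x = noise
    for t, t_next in zip(ts[:-1], ts[1:]):
        # dt = t_next - t is negative: we move backwards in time toward the data.
        x = x + (t_next - t) * velocity(x, t, x_data)
    return x


random.seed(0)
noise_sample = random.gauss(0.0, 1.0)
for steps in (1, 4, 25):
    print(steps, euler_sample(noise_sample, x_data=3.0, num_steps=steps))
```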
 
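 ### FLUX family and related models
 
 | Model | Parameters | Highlights | Recommended use |
 |------|-------|------|----------|
 | FLUX.1-schnell | 12B | Fast 4-step inference | Real-time generation |
 | FLUX.1-dev | 12B | High quality (25-50 steps) | Detailed creative work |
 | Playground v2.5 | n/a (SDXL-based U-Net, not a Flow-DiT) | Aesthetics-tuned | Artistic styles |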
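The 12B parameter count in the table largely explains the VRAM pressure. This back-of-the-envelope sketch counts weight memory only (activations, the text encoders, and the VAE come on top, so real usage is higher); `BYTES_PER_PARAM` and `weight_gb` are illustrative names.

```python
# Bytes needed to store one parameter in common inference dtypes.
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "fp16": 2, "fp8": 1}


def weight_gb(num_params: int, dtype: str) -> float:
    """Memory for the weights only, in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


FLUX_TRANSFORMER_PARAMS = 12_000_000_000  # "12B" from the table above

for dtype in ("fp32", "bf16", "fp8"):
    print(f"{dtype}: {weight_gb(FLUX_TRANSFORMER_PARAMS, dtype):.0f} GB")
# fp32: 48 GB
# bf16: 24 GB
# fp8: 12 GB
```

Even in bf16 the transformer weights alone (~24GB) exceed a 16GB card, which is why the loaders in this notebook default to CPU offloading.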

 ## 🚀 FLUX.1-schnell Fast Inference (4-step)

In [None]:
# %%
def load_flux_schnell(enable_cpu_offload: bool = True) -> FluxPipeline:
    """Load FLUX.1-schnell with memory optimization"""
    print("Loading FLUX.1-schnell...")

    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=dtype,
        use_safetensors=True,
    )

    if enable_cpu_offload:
        # Stream submodules to the GPU one at a time: slow, but fits 12-16GB.
        # Do NOT call pipe.to("cuda") first; that would put all 12B weights on the GPU.
        pipe.enable_sequential_cpu_offload()
    else:
        # Only for GPUs that can hold the whole pipeline (well above 16GB).
        pipe.to("cuda")

    # Tiled VAE decoding lowers peak memory at high resolutions.
    pipe.vae.enable_tiling()
    # FLUX's default attention processor already uses PyTorch 2 SDPA, so no
    # xFormers or attention-slicing call is needed.
    return pipe


# Load FLUX schnell
flux_schnell = load_flux_schnell(enable_cpu_offload=True)
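Which offload mode to use depends on available VRAM. In diffusers, `enable_model_cpu_offload()` moves one whole component (e.g. the transformer) onto the GPU at a time, while `enable_sequential_cpu_offload()` streams individual submodules and therefore fits in far less memory at a large speed cost. The sketch below is a hypothetical rule of thumb, `choose_offload`; the component sizes are approximate assumptions (12B transformer and T5-XXL text encoder in bf16) and the 20% headroom is a guess, not a measured value.

```python
from typing import Dict

# Approximate bf16 weight sizes in GB (assumptions, not measured values).
COMPONENT_GB: Dict[str, float] = {"transformer": 24.0, "text_encoder_2": 9.5}


def choose_offload(
    vram_gb: float, components_gb: Dict[str, float], headroom: float = 1.2
) -> str:
    """Return 'none', 'model', or 'sequential' (hypothetical rule of thumb).

    headroom scales the weight sizes to leave room for activations.
    """
    total = sum(components_gb.values()) * headroom
    largest = max(components_gb.values()) * headroom
    if vram_gb >= total:
        return "none"        # everything fits: pipe.to("cuda")
    if vram_gb >= largest:
        return "model"       # one component at a time: pipe.enable_model_cpu_offload()
    return "sequential"      # submodule streaming: pipe.enable_sequential_cpu_offload()


for vram in (12, 16, 32, 48):
    print(vram, choose_offload(vram, COMPONENT_GB))
```

In this notebook `vram_gb` could come from `torch.cuda.get_device_properties(0).total_memory / 1e9`; on 12-16GB cards the rule lands on sequential offload, matching the loader's default.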

## 🎨 Basic FLUX Inference Test

In [None]:
# %%
def generate_flux_image(
    pipe: FluxPipeline,
    prompt: str,
    num_inference_steps: int = 4,
    guidance_scale: float = 0.0,  # schnell ignores guidance; dev typically uses ~3.5
    width: int = 1024,
    height: int = 1024,
    generator: Optional[torch.Generator] = None,
) -> Tuple[Image.Image, dict]:
    """Generate one image with a FLUX pipeline and return it with run metadata"""

    # schnell and dev both load as FluxPipeline, so the class name cannot tell them
    # apart; check the repo id that from_pretrained records in the config instead.
    model_id = str(pipe.config.get("_name_or_path", ""))
    if "schnell" in model_id.lower():
        num_inference_steps = min(num_inference_steps, 4)  # schnell is distilled for 1-4 steps
        guidance_scale = 0.0  # schnell doesn't use guidance

    start_time = time.perf_counter()
    with torch.no_grad():
        result = pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
    elapsed = time.perf_counter() - start_time

    metadata = {
        "prompt": prompt,
        "steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "size": f"{width}x{height}",
        "generation_time": f"{elapsed:.2f}s",
        "model": model_id.split("/")[-1] or "unknown",
    }

    return result.images[0], metadata


# Test prompt
test_prompt = "a majestic dragon soaring through cloudy mountain peaks, cinematic lighting, epic fantasy art"
if SMOKE_MODE:
    test_prompt = "a simple red apple"

# Generate with FLUX schnell
print("Generating with FLUX.1-schnell...")
generator = torch.Generator(device="cuda").manual_seed(42)

flux_image, flux_meta = generate_flux_image(
    flux_schnell,
    prompt=test_prompt,
    num_inference_steps=4 if not SMOKE_MODE else 1,
    width=1024 if not SMOKE_MODE else 512,
    height=1024 if not SMOKE_MODE else 512,
    generator=generator,
)

# Display result
print(f"✅ Generated in {flux_meta['generation_time']}")
print(f"Metadata: {json.dumps(flux_meta, indent=2)}")
if not SMOKE_MODE:
    flux_image.show()
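Generation time and memory grow quickly with resolution because FLUX's transformer works on a sequence of latent tokens. Assuming FLUX's 8x VAE downsampling followed by 2x2 latent patch packing (so each image token covers a 16x16 pixel area and width/height should be multiples of 16), a hypothetical helper can snap sizes and count image tokens:

```python
PIXELS_PER_TOKEN_SIDE = 16  # assumption: 8x VAE downsampling x 2x2 patch packing


def snap_to_multiple(size: int, multiple: int = PIXELS_PER_TOKEN_SIDE) -> int:
    """Round a pixel size down to the nearest valid multiple (minimum one token)."""
    return max(multiple, (size // multiple) * multiple)


def image_token_count(width: int, height: int) -> int:
    """Number of image tokens the transformer attends over at this resolution."""
    w, h = snap_to_multiple(width), snap_to_multiple(height)
    return (w // PIXELS_PER_TOKEN_SIDE) * (h // PIXELS_PER_TOKEN_SIDE)


for side in (256, 512, 1024):
    n = image_token_count(side, side)
    # Self-attention cost scales roughly with n**2 (text tokens ignored here).
    print(f"{side}x{side}: {n} tokens, ~{n**2 / 1024**2:.2f}x the attention cost of 512x512")
```

Going from 512x512 to 1024x1024 quadruples the token count, which is why SMOKE_MODE and the parameter experiments drop to 512x512.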

## 📊 Memory Usage Analysis

In [None]:
# %%
def get_memory_stats() -> dict:
    """Get current CUDA memory statistics"""
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    return {
        "allocated_gb": torch.cuda.memory_allocated() / 1e9,
        "reserved_gb": torch.cuda.memory_reserved() / 1e9,
        "max_allocated_gb": torch.cuda.max_memory_allocated() / 1e9,
        "total_gb": torch.cuda.get_device_properties(0).total_memory / 1e9,
    }


memory_after_flux = get_memory_stats()
print("Memory usage after FLUX generation:")
print(json.dumps(memory_after_flux, indent=2))
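`torch.cuda.max_memory_allocated()` reports the peak since the process started (or since the last reset), so the numbers above mix every operation run so far. A minimal sketch of a hypothetical `track_peak` context manager resets the peak counter with `torch.cuda.reset_peak_memory_stats()` before a block and records wall time and peak allocation afterwards; it falls back to timing only when torch or CUDA is unavailable.

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def track_peak(label: str, results: Dict[str, dict]) -> Iterator[None]:
    """Record wall time and peak CUDA allocation (GB) of the wrapped block."""
    try:
        import torch
        has_cuda = torch.cuda.is_available()
    except ImportError:
        has_cuda = False

    if has_cuda:
        torch.cuda.reset_peak_memory_stats()  # start a fresh peak window
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = {
            "time_s": time.perf_counter() - start,
            "peak_gb": torch.cuda.max_memory_allocated() / 1e9 if has_cuda else None,
        }
```

For example, wrapping a `generate_flux_image(...)` call in `with track_peak("schnell_4step", runs):` stores a per-run peak in `runs["schnell_4step"]` instead of a process-wide maximum.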

## 🔄 Clean up and load FLUX.1-dev

In [None]:
# %%
# Clean up schnell model
del flux_schnell
gc.collect()
torch.cuda.empty_cache()

print("Loading FLUX.1-dev...")


def load_flux_dev(enable_cpu_offload: bool = True) -> FluxPipeline:
    """Load FLUX.1-dev with memory optimization (gated repo: accept the license and log in first)"""

    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=dtype,
        use_safetensors=True,
    )

    if enable_cpu_offload:
        # Offload manages device placement itself; calling pipe.to("cuda") first
        # would try to put the full 12B-parameter model on the GPU.
        pipe.enable_sequential_cpu_offload()
    else:
        pipe.to("cuda")  # only for GPUs that can hold the whole pipeline

    # Tiled VAE decoding lowers peak memory at high resolutions.
    pipe.vae.enable_tiling()
    # FLUX's default attention processor already uses PyTorch 2 SDPA, so no
    # xFormers or attention-slicing call is needed.
    return pipe


# Skip FLUX.1-dev in smoke mode (memory intensive)
if not SMOKE_MODE:
    flux_dev = load_flux_dev(enable_cpu_offload=True)
else:
    print("⚠️ Skipping FLUX.1-dev in SMOKE_MODE")
    flux_dev = None
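FLUX.1-dev is a gated repository, so `from_pretrained` fails unless you have accepted its license on the Hugging Face Hub and are authenticated. A minimal sketch of a hypothetical pre-flight check, assuming credentials come either from the `HF_TOKEN` environment variable or from the token file that `huggingface-cli login` writes to `$HF_HOME/token`:

```python
import os
from pathlib import Path


def hf_token_available() -> bool:
    """True if a Hugging Face token is set via env var or the default token file.

    This only checks that credentials exist; it cannot tell whether the account
    has accepted the FLUX.1-dev license on the Hub.
    """
    if os.environ.get("HF_TOKEN"):
        return True
    hf_home = os.environ.get("HF_HOME", str(Path.home() / ".cache" / "huggingface"))
    return (Path(hf_home) / "token").is_file()


if not hf_token_available():
    print("⚠️ No HF token found: run `huggingface-cli login` or set HF_TOKEN before loading FLUX.1-dev")
```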

## 🎭 FLUX.1-dev High-Quality Inference (25-50 steps)

In [None]:
# %%
if flux_dev is not None:
    print("Generating with FLUX.1-dev...")

    # FLUX.1-dev supports guidance and more steps
    generator = torch.Generator(device="cuda").manual_seed(42)

    flux_dev_image, flux_dev_meta = generate_flux_image(
        flux_dev,
        prompt=test_prompt,
        num_inference_steps=25,  # dev supports more steps
        guidance_scale=3.5,  # dev supports guidance
        width=1024,
        height=1024,
        generator=generator,
    )

    flux_dev_meta["model"] = "FLUX.1-dev"
    print(f"✅ Generated in {flux_dev_meta['generation_time']}")
    print(f"Metadata: {json.dumps(flux_dev_meta, indent=2)}")
    flux_dev_image.show()

    memory_after_dev = get_memory_stats()
    print("Memory usage after FLUX.1-dev:")
    print(json.dumps(memory_after_dev, indent=2))
else:
    print("⚠️ FLUX.1-dev skipped in SMOKE_MODE")

## 📈 Parameter Experiments

In [None]:
# %%
def parameter_experiment(pipe: FluxPipeline, model_name: str):
    """Test different parameters on Flow-DiT models"""

    if SMOKE_MODE:
        print("⚠️ Skipping parameter experiments in SMOKE_MODE")
        return

    print(f"\n🧪 Parameter experiments for {model_name}")

    base_prompt = "a cozy coffee shop in cyberpunk Tokyo, neon lights, rain"
    experiments = []

    # Different step counts
    step_configs = [4, 10, 25] if "dev" in model_name.lower() else [1, 4]

    for steps in step_configs:
        generator = torch.Generator(device="cuda").manual_seed(42)

        # Adjust guidance for model type
        guidance = 3.5 if "dev" in model_name.lower() else 0.0

        start_time = time.time()
        result = pipe(
            prompt=base_prompt,
            num_inference_steps=steps,
            guidance_scale=guidance,
            width=512,  # Smaller for experiments
            height=512,
            generator=generator
        )
        generation_time = time.time() - start_time

        experiments.append({
            "steps": steps,
            "guidance": guidance,
            "time": f"{generation_time:.2f}s",
            "model": model_name
        })

        print(f"Steps {steps}: {generation_time:.2f}s")

    return experiments

# Run experiments
if flux_dev is not None:
    dev_experiments = parameter_experiment(flux_dev, "FLUX.1-dev")
else:
    print("⚠️ FLUX.1-dev experiments skipped")
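The experiment above varies only `num_inference_steps`, while the learning objectives also cover `guidance_scale`. This sketch of a hypothetical `build_sweep` helper builds a steps × guidance grid and collapses it to the settings schnell actually honors (guidance 0.0, at most 4 steps). The guidance values in the example are illustrative choices around the 3.5 used earlier, not recommended settings.

```python
from itertools import product
from typing import List, Sequence


def build_sweep(
    model_name: str, steps_list: Sequence[int], guidance_list: Sequence[float]
) -> List[dict]:
    """Return one config dict per (steps, guidance) combination for the model."""
    if "schnell" in model_name.lower():
        steps = [s for s in steps_list if s <= 4] or [4]  # distilled for <= 4 steps
        guidances = [0.0]  # schnell ignores guidance_scale
    else:
        steps, guidances = list(steps_list), list(guidance_list)
    return [{"steps": s, "guidance": g} for s, g in product(steps, guidances)]


sweep = build_sweep("FLUX.1-dev", steps_list=[10, 25], guidance_list=[2.0, 3.5, 5.0])
print(len(sweep), sweep[0])
# Each entry could then drive one call, e.g.
# pipe(prompt=..., num_inference_steps=cfg["steps"], guidance_scale=cfg["guidance"], ...)
```

Keeping the same seed per config, as `parameter_experiment` does, isolates the effect of each setting.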

## 🆚 Flow-DiT vs SD Comparison

In [None]:
# %%
def compare_architectures():
    """Compare Flow-DiT vs traditional Diffusion characteristics"""

    comparison = {
        "architecture": {
            "Flow-DiT (FLUX)": {
                "noise_schedule": "Rectified flow (continuous, linear noise-to-data path)",
                "sampler": "Euler/Heun (ODE)",
                "guidance": "dev: distilled guidance input, 1 pass/step; schnell: none",
                "steps_range": "1-50",
                "memory_footprint": "High (12B-parameter transformer)",
            },
            "Diffusion (SD)": {
                "noise_schedule": "DDPM beta_t schedule (discrete)",
                "sampler": "DPM++/DDIM/Euler",
                "guidance": "CFG (conditional + unconditional pass per step)",
                "steps_range": "20-150",
                "memory_footprint": "Moderate (U-Net)",
            },
        },
        "performance": {
            "FLUX.1-schnell": {
                "optimal_steps": 4,
                "speed": "Very Fast",
                "quality": "Good",
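                "vram_usage": "~24GB bf16 transformer weights; needs CPU offload on 12-16GB",
            },
            "FLUX.1-dev": {
                "optimal_steps": 25,
                "speed": "Medium",
                "quality": "Excellent",
                "vram_usage": "~24GB bf16 transformer weights; needs CPU offload on 12-16GB",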
            },
            "SD1.5": {
                "optimal_steps": 20,
                "speed": "Fast",
                "quality": "Good",
                "vram_usage": "~6GB",
            },
            "SDXL": {
                "optimal_steps": 30,
                "speed": "Medium",
                "quality": "Excellent",
                "vram_usage": "~8GB",
            },
        },
    }

    print("🆚 Architecture Comparison:")
    print(json.dumps(comparison, indent=2))

    return comparison


architecture_comparison = compare_architectures()
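One concrete cost behind the "guidance" rows: classic CFG evaluates the denoiser twice per step (conditional and unconditional, often batched together), whereas FLUX.1-dev's distilled guidance is a model input and needs one evaluation per step. A hypothetical counter makes the arithmetic explicit, using the step counts from the comparison above; it ignores text-encoder and VAE passes.

```python
def denoiser_evals(steps: int, guidance_mode: str) -> int:
    """Denoiser evaluations per image for a given guidance mode.

    'cfg'       : conditional + unconditional prediction each step (SD/SDXL)
    'distilled' : guidance fed as an input, one pass per step (FLUX.1-dev)
    'none'      : no guidance (FLUX.1-schnell)
    """
    passes_per_step = {"cfg": 2, "distilled": 1, "none": 1}[guidance_mode]
    return steps * passes_per_step


for name, steps, mode in [
    ("FLUX.1-schnell", 4, "none"),
    ("FLUX.1-dev", 25, "distilled"),
    ("SD1.5", 20, "cfg"),
    ("SDXL", 30, "cfg"),
]:
    print(f"{name}: {denoiser_evals(steps, mode)} denoiser evaluations")
```

Evaluation count is only part of the picture: each FLUX evaluation runs a 12B-parameter transformer, so the per-evaluation cost is much higher than an SD U-Net's.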

## 💾 Save Results and Clean Up

In [None]:
# %%
# Save results
output_dir = pathlib.Path("outputs") / "flowdit_quickstart"
output_dir.mkdir(parents=True, exist_ok=True)

# Save images and metadata
if "flux_image" in locals():
    flux_image.save(output_dir / "flux_schnell_result.png")
    with open(output_dir / "flux_schnell_meta.json", "w") as f:
        json.dump(flux_meta, f, indent=2)

if "flux_dev_image" in locals():
    flux_dev_image.save(output_dir / "flux_dev_result.png")
    with open(output_dir / "flux_dev_meta.json", "w") as f:
        json.dump(flux_dev_meta, f, indent=2)

# Save comparison analysis
with open(output_dir / "architecture_comparison.json", "w") as f:
    json.dump(architecture_comparison, f, indent=2)

print(f"✅ Results saved to {output_dir}")

# Clean up memory
if "flux_dev" in locals() and flux_dev is not None:
    del flux_dev
gc.collect()
torch.cuda.empty_cache()

final_memory = get_memory_stats()
print("Final memory usage:")
print(json.dumps(final_memory, indent=2))

## 🧪 Smoke Test (CI Mode)

In [None]:
# %%
def smoke_test():
    """Minimal test for CI pipeline"""

    print("🧪 Running smoke test...")

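    # CUDA is required for the generation test
    assert torch.cuda.is_available(), "CUDA required"

    # Test pipeline loading (schnell only in smoke mode). device_map="auto" is not a
    # supported pipeline-level strategy in diffusers, and a full load to the GPU would
    # not fit 12-16GB, so use sequential CPU offload as in the main loader.
    test_pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
    )
    test_pipe.enable_sequential_cpu_offload()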

    # Quick generation test
    generator = torch.Generator(device="cuda").manual_seed(42)
    result = test_pipe(
        "a red apple", num_inference_steps=1, width=256, height=256, generator=generator
    )

    assert len(result.images) == 1, "Should generate 1 image"
    assert result.images[0].size == (256, 256), "Wrong image size"

    print("✅ Smoke test passed!")

    # Cleanup
    del test_pipe
    gc.collect()
    torch.cuda.empty_cache()


if SMOKE_MODE:
    smoke_test()

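 ## 📋 Summary
 
 ### ✅ Completed
 1. **Flow-DiT architecture**: core differences between flow matching and diffusion
 2. **FLUX inference**: schnell (4 steps) and dev (25 steps)
 3. **Memory optimization**: sequential CPU offload to fit 12-16GB VRAM
 4. **Parameter tuning**: how steps/guidance affect Flow-DiT
 5. **Architecture comparison**: Flow-DiT vs SD trade-offs in speed, quality, and memory
 
 ### 🔑 Key concepts
 - **Flow Matching**: learns a velocity field along a near-straight noise→image path, which lets samplers take fewer, larger steps than typical diffusion sampling
 - **FLUX.1-schnell**: fast 1-4 step inference, guidance_scale=0.0
 - **FLUX.1-dev**: high quality at 25-50 steps, supports guidance_scale (3.5 in this notebook)
 - **DiT vs U-Net**: memory behavior of a transformer backbone vs a convolutional one (DiT attention cost grows with the number of latent tokens)
 
 ### ⚠️ Common pitfalls
 1. **High VRAM demand**: FLUX needs 12GB+ even with CPU offload; without offload, the bf16 transformer weights alone exceed 16GB
 2. **Licensing**: terms differ per model (FLUX.1-schnell is Apache-2.0, while FLUX.1-dev is gated under a non-commercial license); check them before fine-tuning or commercial use
 3. **Guidance differences**: schnell ignores guidance_scale; only dev uses it
 4. **Step counts**: schnell is distilled for at most 4 steps; dev is recommended at 25+ steps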
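 The step and guidance pitfalls can be caught before a long run. A sketch of a hypothetical `check_flux_params` returns warnings; its thresholds simply restate the rules of thumb in this notebook (schnell: at most 4 steps and no guidance; dev: 25+ steps with a positive guidance_scale) and are not enforced by diffusers itself.

```python
from typing import List


def check_flux_params(model_name: str, steps: int, guidance_scale: float) -> List[str]:
    """Return human-readable warnings for settings that contradict the pitfalls list."""
    name = model_name.lower()
    found: List[str] = []
    if "schnell" in name:
        if steps > 4:
            found.append(f"schnell is distilled for <= 4 steps (got {steps})")
        if guidance_scale != 0.0:
            found.append("schnell ignores guidance_scale; use 0.0")
    elif "dev" in name:
        if steps < 25:
            found.append(f"dev is recommended at 25+ steps (got {steps})")
        if guidance_scale <= 0.0:
            found.append("guidance_scale <= 0 is unusual for dev (this notebook uses 3.5)")
    return found


print(check_flux_params("FLUX.1-schnell", steps=8, guidance_scale=3.5))
print(check_flux_params("FLUX.1-dev", steps=25, guidance_scale=3.5))  # → []
```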
  
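 ### ➡️ Next steps
 - **Stage 2**: applicability of ControlNet/T2I-Adapter to Flow-DiT
 - **Stage 3**: Flow-DiT fine-tuning feasibility study (depends on the model license)
 - **Stage 4**: integrating Flow-DiT into a batch image-generation pipeline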