<a href="https://colab.research.google.com/github/PYH1107/generative_ai/blob/main/HW10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install diffusers transformers accelerate safetensors huggingface_hub gradio --upgrade

In [None]:
from google.colab import userdata
from huggingface_hub import login

# Read the Hugging Face access token from the Colab secret stored under the
# name "huggungface" and use it to authenticate this session with the Hub.
# NOTE(review): the secret name looks like a misspelling of "huggingface"; it
# must match the name saved in the Colab Secrets panel - confirm before renaming.
hf_token = userdata.get("huggungface")
login(token=hf_token)

In [None]:
import torch
import gc
import random
import gradio as gr
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
from PIL import Image
import numpy as np

In [None]:
# ========== Cell 1: Model configuration ==========
# Identifier of the Stable Diffusion checkpoint passed to from_pretrained()
# by the model-loading cell.
model_name = "runwayml/stable-diffusion-v1-5"
print(f"🎯 選擇模型: {model_name}")

In [None]:
# ========== Cell 1: Load the model ==========
print("🔄 loading...")

# Select the compute device once. This module-level name is also read by the
# generation function and by the UI status line further down.
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    # Use float16 weights on CUDA and float32 otherwise.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        use_safetensors=True
    ).to(device)

    # Replace the default scheduler with UniPC, reusing the existing config.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

    # Device-specific memory options.
    if device == "cpu":
        pipe.enable_attention_slicing()
        print("OK")
    else:
        pipe.enable_model_cpu_offload()
        print("nono")

    print("✅")

except Exception as e:
    # Best effort: report the failure and leave `pipe` undefined. Later cells
    # check `'pipe' in globals()` before using it.
    print(f"❌ failed: {e}")

In [None]:
# ========== Cell 2: Prompt recommendation ==========
# Prompt templates grouped by category.
# The keys are offered as the choices of the category dropdown in the UI and
# get_random_prompt() samples one template from the selected category's list.
# NOTE(review): "portray" and "asthetic" look like misspellings of "portrait"
# and "aesthetic"; they are user-visible labels - confirm before renaming.
PROMPT_CATEGORIES = {
    "portray": [
        "a beautiful woman, portrait, photorealistic, detailed face, soft lighting",
        "handsome man in suit, professional photo, studio lighting, sharp focus",
        "elderly person with kind eyes, wrinkled face, wisdom, natural light",
        "child playing in garden, innocent smile, golden hour lighting"
    ],
    "scenic": [
        "mountain landscape at sunset, dramatic sky, golden light, 8k resolution",
        "peaceful lake reflection, misty morning, serene atmosphere",
        "ancient forest with sunbeams, mystical atmosphere, detailed foliage",
        "ocean waves at sunset, dramatic sky, seascape photography"
    ],
    "city": [
        "modern city skyline at night, neon lights, urban photography",
        "old european street, cobblestone, vintage architecture, warm lighting",
        "futuristic building design, glass and steel, architectural photography",
        "cozy cafe interior, warm lighting, books and coffee, inviting atmosphere"
    ],
    "animal": [
        "majestic lion in savanna, golden hour, wildlife photography",
        "cute kitten playing with yarn, soft focus, adorable expression",
        "colorful tropical bird, detailed feathers, natural habitat",
        "underwater scene with tropical fish, coral reef, clear water"
    ],
    "sci-fi": [
        "spaceship in deep space, stars and nebula, sci-fi concept art",
        "magical forest with glowing mushrooms, fantasy atmosphere",
        "dragon flying over medieval castle, epic fantasy scene",
        "cyberpunk city street, neon signs, futuristic vehicles"
    ],
    "asthetic": [
        "oil painting style portrait, classical art, detailed brushstrokes",
        "watercolor landscape, soft colors, artistic technique",
        "digital art character design, anime style, vibrant colors",
        "pencil sketch of old tree, detailed line art, artistic drawing"
    ]
}

# Style keyword strings. They populate the style dropdown in the UI and
# get_random_style() picks one of them at random.
STYLE_ENHANCERS = [
    "masterpiece, best quality, ultra detailed",
    "photorealistic, 8k resolution, sharp focus",
    "cinematic lighting, dramatic shadows",
    "soft natural lighting, warm tones",
    "vibrant colors, high contrast",
    "minimalist, clean composition",
    "vintage style, film grain",
    "digital art, concept art style"
]


# Negative-prompt presets keyed by content type ("general", "character",
# "scenes", "art"). get_negative_prompt() looks entries up in this table.
NEGATIVE_PROMPTS = {
    "general": "blurry, bad anatomy, bad hands, deformed, low quality, worst quality, jpeg artifacts",
    "character": "bad anatomy, bad hands, extra fingers, missing fingers, deformed face, ugly, disfigured",
    "scenes": "blurry, overexposed, bad composition, tilted horizon, noise, artifacts",
    "art": "low quality, bad art, amateur, poorly drawn, sketchy, unfinished"
}

# Report how many categories and prompt templates are available.
print("📚 finish loading")
# A ", " separator keeps the two counts from running together in the output.
print(f"✨ 包含 {len(PROMPT_CATEGORIES)} categories, {sum(len(prompts) for prompts in PROMPT_CATEGORIES.values())} templates")

In [None]:
# ========== Cell 7: Core generation function ==========
def generate_images(prompt, use_enhance, enhance_text, use_negative, negative_text,
                    use_custom_seed, custom_seed, height, width, steps, num_images):
    """Generate one or more images with the module-level pipeline.

    Reads two module-level names: ``pipe`` (the loaded pipeline) and
    ``device`` (the device string used for the random generator).

    Args:
        prompt: Main text prompt.
        use_enhance: When true and ``enhance_text`` is non-empty,
            ``enhance_text`` is appended to the prompt after ", ".
        enhance_text: Style keywords to append.
        use_negative: When true, ``negative_text`` is passed as the negative
            prompt; otherwise no negative prompt is used.
        negative_text: Negative prompt text.
        use_custom_seed: When true, ``custom_seed`` is the base seed;
            otherwise a random base seed is drawn.
        custom_seed: Base seed used when ``use_custom_seed`` is true.
        height: Image height; any value ``int()`` accepts, multiple of 8.
        width: Image width; any value ``int()`` accepts, multiple of 8.
        steps: Number of inference steps; coerced with ``int()``.
        num_images: Number of images to generate; coerced with ``int()``.
            Image ``i`` uses seed ``base_seed + i``.

    Returns:
        A ``(images, status)`` tuple. ``images`` is the list of generated
        images (empty on any failure) and ``status`` is a message describing
        the outcome.
    """

    # The pipeline is created by the model-loading cell; bail out if absent.
    if 'pipe' not in globals():
        return [], "❌ 模型未加载！请先运行模型加载单元格"

    try:
        # UI widgets may deliver strings or floats, while range() and the
        # pipeline need real integers. Conversion errors are reported below.
        height = int(height)
        width = int(width)
        steps = int(steps)
        num_images = int(num_images)

        # Reject sizes that are not multiples of 8.
        if height % 8 != 0 or width % 8 != 0:
            return [], "❌ 高度和宽度必须是8的倍数！"

        # One seed per image, counted up from the base seed so that a batch
        # can be reproduced from the reported list.
        if use_custom_seed:
            base_seed = int(custom_seed)
        else:
            base_seed = random.randint(0, 2**32 - 1)

        seeds = [base_seed + i for i in range(num_images)]

        # Assemble the final prompt and negative prompt.
        final_prompt = prompt
        if use_enhance and enhance_text:
            final_prompt = prompt + ", " + enhance_text

        final_negative = negative_text if use_negative else None

        # Free memory before generating.
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

        # Generate the images one at a time, each with its own seeded generator.
        images = []
        for index, seed in enumerate(seeds, start=1):
            print(f"🎨 正在生成第 {index}/{num_images} 张图像...")

            generator = torch.Generator(device).manual_seed(seed)

            with torch.no_grad():
                image = pipe(
                    prompt=final_prompt,
                    negative_prompt=final_negative,
                    height=height,
                    width=width,
                    num_inference_steps=steps,
                    guidance_scale=7.5,
                    generator=generator
                ).images[0]
                images.append(image)

        print("✅ 所有图像生成完成！")
        return images, f"✅ 生成完成！使用的 seeds: {seeds}"

    except Exception as e:
        # UI boundary: surface the error in the status box instead of raising.
        error_msg = f"❌ 生成失败: {str(e)}"
        print(error_msg)
        return [], error_msg

# Helper functions
def get_random_prompt(category):
    """Return a random template from ``category``, or "" if it is unknown."""
    if category not in PROMPT_CATEGORIES:
        return ""
    templates = PROMPT_CATEGORIES[category]
    return random.choice(templates)

def get_random_style():
    """Return one randomly chosen entry of STYLE_ENHANCERS."""
    chosen_style = random.choice(STYLE_ENHANCERS)
    return chosen_style

def get_negative_prompt(category):
    """Return the negative prompt for ``category``.

    Falls back to the "general" preset when ``category`` has no entry in
    NEGATIVE_PROMPTS.
    """
    # The fallback is evaluated on every call, so it must name an existing key.
    return NEGATIVE_PROMPTS.get(category, NEGATIVE_PROMPTS["general"])

# Progress message printed once the generation function and helpers are defined.
print("🔧 核心函数定义完成！")

In [None]:
# ========== Cell 8: Build the Gradio interface ==========
# Custom stylesheet passed to gr.Blocks(css=...).
custom_css = """
.gradio-container {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    font-family: 'Arial', sans-serif;
}
.gr-button {
    background: linear-gradient(45deg, #FF6B6B, #4ECDC4);
    border: none;
    color: white;
    font-weight: bold;
    border-radius: 10px;
}
"""

# Initial contents of the style-enhancement and negative-prompt text boxes.
default_enhance = "masterpiece, ultra high quality, intricate skin details, cinematic lighting"
default_negative = "bad anatomy, blurry, disfigured, poorly drawn hands, extra fingers, mutated hands, low quality, worst quality"

# Build the interface: controls in the left column, results in the right
# column, then the event wiring and a short usage guide.
with gr.Blocks(css=custom_css, title="AI 圖片產生器") as demo:

    gr.Markdown("""
    # 🎨 AI 圖像生成器 (完整作業版)
    ### 📝 基於 Stable Diffusion 的智慧文字生圖工具

    **✨ 核心功能：**
    - 🎯 6大分類 Prompt 推薦系統
    - 🎨 8種專業風格增強
    - ⚙️ 完整參數控制
    - 🔄 批量生成功能
    """)

    # Show whether the pipeline was loaded by the model-loading cell.
    if 'pipe' not in globals():
        gr.Markdown("⚠️ **fffffailed**")
    else:
        gr.Markdown(f"✅ **模型已就緒** (設備: {device.upper()})")

    with gr.Row():
        # Left column: control panel
        with gr.Column(scale=1):
            gr.Markdown("### 📝 創作控制台")

            # Prompt recommendation
            with gr.Group():
                gr.Markdown("**🎯 prompt recommendation**")
                category_selector = gr.Dropdown(
                    choices=list(PROMPT_CATEGORIES.keys()),
                    label="選擇分類",
                    # The default must be one of the choices, so use the
                    # first category key instead of a hard-coded label.
                    value=next(iter(PROMPT_CATEGORIES), None)
                )
                recommend_btn = gr.Button("🎲 recommendation", variant="secondary")

            # Main prompt input
            prompt_input = gr.Textbox(
                label="🖼️ main description (Prompt)",
                placeholder="please describe...",
                lines=3,
                value="a beautiful landscape with mountains and lake"
            )

            # Style enhancement controls
            with gr.Row():
                use_enhance = gr.Checkbox(label="啟用風格增強", value=True)
                enhance_text = gr.Textbox(
                    label="風格增強內容",
                    value=default_enhance
                )

            style_selector = gr.Dropdown(
                choices=STYLE_ENHANCERS,
                label="default style",
                value=STYLE_ENHANCERS[0]
            )
            random_style_btn = gr.Button("🎲 random style", variant="secondary")

            # Negative prompt
            with gr.Row():
                use_negative = gr.Checkbox(label="使用 negative prompt", value=True)
                negative_text = gr.Textbox(
                    label="negative prompt",
                    value=default_negative,
                    lines=2
                )

            # Generation parameters
            gr.Markdown("### ⚙️ 生成参数")

            with gr.Row():
                height = gr.Dropdown(["512", "768", "1024"], label="高度", value="512")
                width = gr.Dropdown(["512", "768", "1024"], label="寬度", value="512")

            steps = gr.Slider(10, 50, value=20, step=5, label="推理步數")
            num_images = gr.Slider(1, 4, step=1, value=1, label="生成數量")

            # Seed control
            with gr.Row():
                use_custom_seed = gr.Checkbox(label="self0defined seed", value=False)
                custom_seed = gr.Number(label="seed value", value=42)

            # Generate button
            generate_btn = gr.Button("🚀 start", variant="primary", size="lg")

        # Right column: results
        with gr.Column(scale=1):
            gr.Markdown("### 🖼️ 生成结果")

            gallery = gr.Gallery(
                label="生成結果",
                columns=2,
                object_fit="contain",
                height="auto"
            )

            status_info = gr.Textbox(
                label="生成狀態",
                value="等待...",
                interactive=False
            )

    # ========== Event wiring ==========

    # Fill the prompt box with a random template from the selected category.
    recommend_btn.click(
        fn=get_random_prompt,
        inputs=[category_selector],
        outputs=[prompt_input]
    )

    # Fill the enhancement box with a random style.
    random_style_btn.click(
        fn=get_random_style,
        outputs=[enhance_text]
    )

    # Copy the selected style into the enhancement box.
    style_selector.change(
        fn=lambda x: x,
        inputs=[style_selector],
        outputs=[enhance_text]
    )

    # Main generation action; the input order matches generate_images' parameters.
    generate_btn.click(
        fn=generate_images,
        inputs=[
            prompt_input, use_enhance, enhance_text,
            use_negative, negative_text,
            use_custom_seed, custom_seed,
            height, width, steps, num_images
        ],
        outputs=[gallery, status_info]
    )

    # Usage guide
    gr.Markdown("""
    ### 📖 使用指南

    **🎯 快速開始：**
    1. 選擇分類 → 點擊「獲取推薦」→ 選擇風格 → 開始生成

    **💡 進階技巧：**
    - **推理步數**：20步適合快速預覽，30-50步獲得更高品質
    - **引導強度**：7.5是標準值，過高可能過度飽和
    - **批量生成**：可以一次生成多張不同seed的圖像

    **🔧 參數建議：**
    - CPU用戶：推薦512x512，20步
    - GPU用戶：可嘗試768x768或更高解析度
    """)

print("🎨 Done")

In [None]:
# ========== Cell 9: Launch the app ==========
# Run this cell to start the web app.
if __name__ == "__main__":
    print("🚀 launching...")
    demo.launch(
        show_error=True,  # show error messages
        height=800,       # interface height
        debug=True,       # enable debug mode
        share=True,       # create a public link
    )