# Sketch to Image Application

Colab 환경에서 스케치 투 이미지 애플리케이션을 만들어봅시다.


## 패키지 및 예제 데이터 다운로드하기
Python 패키지들을 설치합니다. Colab에서 실행하지 않는 경우에는 이 셀을 실행하지 않아도 됩니다.

In [None]:
# Download the Colab requirements list, then install the packages it names
# into the current runtime.
!wget https://raw.githubusercontent.com/mrsyee/dl_apps/main/image_generation/requirements-colab.txt
!pip install -r requirements-colab.txt

## 패키지 불러오기

In [None]:
import os
from typing import IO

import gradio as gr
import requests
import torch
from tqdm import tqdm
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image

## 스케치 투 이미지 생성 UI 구현하기

In [None]:
# Side lengths passed as width/height/shape to both image inputs below.
WIDTH = 512
HEIGHT = 512

# Layout-only draft of the sketch-to-image screen: components are created
# here but no event handlers are attached yet.
with gr.Blocks() as app:
    gr.Markdown("## 프롬프트 입력")
    # Prompt and negative prompt each get their own row.
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
    with gr.Row():
        n_prompt = gr.Textbox(label="Negative Prompt")

    gr.Markdown("## 스케치 to 이미지 생성")
    with gr.Row():
        # Left column: two alternative sketch sources, one per tab, each with
        # its own "Generate" button.
        with gr.Column():
            with gr.Tab("Canvas"):
                with gr.Row():
                    # Blank drawing canvas handed to handlers as a PIL image.
                    canvas = gr.Image(
                        label="Draw",
                        source="canvas",
                        image_mode="RGB",
                        tool="color-sketch",
                        interactive=True,
                        width=WIDTH,
                        height=HEIGHT,
                        shape=(WIDTH, HEIGHT),
                        brush_radius=20,
                        type="pil",
                    )
                with gr.Row():
                    canvas_run_btn = gr.Button(value="Generate")

            with gr.Tab("File"):
                with gr.Row():
                    # Uploaded image; the same color-sketch tool is enabled on it.
                    file = gr.Image(
                        label="Upload",
                        source="upload",
                        image_mode="RGB",
                        tool="color-sketch",
                        interactive=True,
                        width=WIDTH,
                        height=HEIGHT,
                        shape=(WIDTH, HEIGHT),
                        type="pil",
                    )
                with gr.Row():
                    file_run_btn = gr.Button(value="Generate")

        # Right column: gallery labelled "Output" for the results.
        with gr.Column():
            result_gallery = gr.Gallery(label="Output", height=512)

In [None]:
# Start the Gradio server without embedding the UI in the cell output
# (inline=False) and request a public share link (share=True).
app.launch(inline=False, share=True)

In [None]:
# Shut down the Gradio server started by the launch cell.
app.close()

## 모델 다운로드 UI 구현하기

In [None]:
# Layout-only draft of the model download section: no click handler yet.
with gr.Blocks() as app:
    gr.Markdown("## 모델 다운로드")
    with gr.Row():
        # URL input and its download button share one row.
        model_url = gr.Textbox(label="모델 URL", placeholder="https://civitai.com/")
        download_model_btn = gr.Button(value="모델 다운로드")
    with gr.Row():
        # File component meant to hold the downloaded model.
        model_file = gr.File(label="모델 파일")

In [None]:
# Start the Gradio server without embedding the UI in the cell output
# (inline=False) and request a public share link (share=True).
app.launch(inline=False, share=True)

In [None]:
# Shut down the Gradio server started by the launch cell.
app.close()

## 모델 다운로드 기능 구현하기

In [None]:
def download_model(url: str) -> str:
    """Download a Civitai model file and return its local path.

    The model id is taken from ``url``, the Civitai REST API is queried for
    that model, and the first file of the first listed model version is
    saved under ``models/``.

    Args:
        url: A Civitai model page URL such as
            ``https://civitai.com/models/<id>/<slug>``, or a bare model id.
            Surrounding whitespace, query strings and fragments are ignored.

    Returns:
        Path of the model file inside ``models/``. A file that already
        exists at that path is reused instead of being downloaded again.

    Raises:
        ValueError: If no model id can be extracted from ``url``.
        Exception: Any error raised by the API request (including HTTP error
            statuses) is logged and re-raised unchanged.
    """
    model_id = url.strip().replace("https://civitai.com/models/", "").split("/")[0]
    # URLs copied from the browser may end in "?modelVersionId=..." or "#...".
    model_id = model_id.split("?")[0].split("#")[0]
    if not model_id:
        raise ValueError(f"Could not find a model id in URL: {url!r}")

    try:
        response = requests.get(f"https://civitai.com/api/v1/models/{model_id}", timeout=600)
        # Fail here on 4xx/5xx instead of with a confusing KeyError below.
        response.raise_for_status()
    except Exception as err:
        print(f"[ERROR] {err}")
        raise

    # Parse the body once and reuse it for both lookups.
    version = response.json()["modelVersions"][0]
    download_url = version["downloadUrl"]
    # The name comes from a remote API: keep only its last path component so
    # it cannot point outside the models directory.
    filename = os.path.basename(version["files"][0]["name"])

    file_path = f"models/{filename}"
    if os.path.exists(file_path):
        print(f"[INFO] File already exists: {file_path}")
        return file_path

    os.makedirs("models", exist_ok=True)
    download_from_url(download_url, file_path)
    print(f"[INFO] File downloaded: {file_path}")
    return file_path


def download_from_url(url: str, file_path: str, chunk_size: int = 1024):
    """Stream ``url`` to ``file_path`` while showing a tqdm progress bar.

    The body is first written to ``<file_path>.part`` and renamed to
    ``file_path`` only after the whole transfer has finished, so an
    interrupted download never leaves a truncated file at ``file_path``.
    A leftover ``.part`` file is overwritten by the next attempt.

    Args:
        url: Direct download URL.
        file_path: Destination path; its directory must already exist.
        chunk_size: Number of bytes requested per streamed chunk.

    Raises:
        requests.RequestException: If the request fails, times out or the
            server answers with an HTTP error status.
    """
    partial_path = f"{file_path}.part"
    # Using the response as a context manager releases the connection even
    # when the transfer fails part-way.
    with requests.get(url, stream=True, timeout=600) as resp:
        # Do not save an HTML/JSON error page as if it were the model file.
        resp.raise_for_status()
        total = int(resp.headers.get('content-length', 0))
        with open(partial_path, 'wb') as file, tqdm(
            desc=file_path,
            total=total,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in resp.iter_content(chunk_size=chunk_size):
                size = file.write(data)
                bar.update(size)
    # Publish the file under its final name only once it is complete.
    os.replace(partial_path, file_path)

In [None]:
# Model download section, now with the button wired to download_model.
with gr.Blocks() as app:
    gr.Markdown("## 모델 다운로드")
    with gr.Row():
        model_url = gr.Textbox(label="모델 URL", placeholder="https://civitai.com/")
        download_model_btn = gr.Button(value="모델 다운로드")
    with gr.Row():
        model_file = gr.File(label="모델 파일")

    # On click: call download_model with the URL text and show the returned
    # file path in the File component.
    download_model_btn.click(
        download_model,
        [model_url],
        [model_file],
    )

In [None]:
# Enable Gradio's request queue before launching (presumably so long-running
# handlers such as the model download are not cut off — TODO confirm), then
# serve with a public share link and without embedding in the cell output.
app.queue().launch(inline=False, share=True)

In [None]:
# Shut down the Gradio server started by the launch cell.
app.close()

## 모델 불러오기 UI 및 기능 구현하기

In [None]:
# Layout-only draft of the model loading section: no click handler yet.
with gr.Blocks() as app:
    gr.Markdown("## 모델 불러오기")
    with gr.Row():
        load_model_btn = gr.Button(value="모델 불러오기")
    with gr.Row():
        # Status text box; starts out reporting that no model is loaded.
        is_model_check = gr.Textbox(label="Model Load Check", value="Model Not Loaded")

In [None]:
# Image-to-image pipeline shared by the handlers; stays None until
# init_pipeline() has loaded a model.
PIPELINE = None

def init_pipeline(model_file: IO) -> str:
    """Load a Stable Diffusion img2img pipeline from a checkpoint file.

    The pipeline is built from the single file at ``model_file.name`` in
    float16, moved to the CUDA device and stored in the module-level
    ``PIPELINE``.

    Args:
        model_file: File object whose ``name`` attribute is the checkpoint
            path, or ``None`` when no file has been selected.

    Returns:
        A status message: ``"Model Loaded!"`` on success, or an explanatory
        message when ``model_file`` is ``None`` (``PIPELINE`` is then left
        untouched).
    """
    global PIPELINE
    if model_file is None:
        # Report the problem as status text instead of failing with an
        # AttributeError on `model_file.name`.
        print("[ERROR] No model file to load")
        return "Model Not Loaded: no model file selected"

    print("[INFO] Initialize pipeline")
    PIPELINE = StableDiffusionImg2ImgPipeline.from_single_file(
        model_file.name,
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    ).to("cuda")

    print("[INFO] Initialized pipeline")
    return "Model Loaded!"

In [None]:
# Download and load sections combined, with both buttons wired up.
with gr.Blocks() as app:
    gr.Markdown("## 모델 다운로드")
    with gr.Row():
        model_url = gr.Textbox(label="모델 URL", placeholder="https://civitai.com/")
        download_model_btn = gr.Button(value="모델 다운로드")
    with gr.Row():
        model_file = gr.File(label="모델 파일")

    gr.Markdown("## 모델 불러오기")
    with gr.Row():
        load_model_btn = gr.Button(value="모델 불러오기")
    with gr.Row():
        is_model_check = gr.Textbox(label="Model Load Check", value="Model Not Loaded")

    # On click: download the model named by the URL and show the resulting
    # file in the File component.
    download_model_btn.click(
        download_model,
        [model_url],
        [model_file],
    )
    # On click: load the file currently held by the File component and show
    # the returned status text.
    load_model_btn.click(
        init_pipeline,
        [model_file],
        [is_model_check],
    )

In [None]:
# Enable Gradio's request queue before launching (presumably so long-running
# handlers such as download and model loading are not cut off — TODO confirm),
# then serve with a public share link and without embedding in the cell output.
app.queue().launch(inline=False, share=True)

In [None]:
# Shut down the Gradio server started by the launch cell.
app.close()

## 스케치 투 이미지 생성 기능 구현하기

In [None]:
def sketch_to_image(sketch: Image.Image, prompt: str, negative_prompt: str):
    """Generate four images from a sketch with the loaded img2img pipeline.

    The output size follows the sketch size. Generation runs for 20
    inference steps with strength 0.7.

    Args:
        sketch: Input sketch as a PIL image.
        prompt: Text describing the desired result.
        negative_prompt: Text passed to the pipeline as ``negative_prompt``.

    Returns:
        The ``images`` attribute of the pipeline output.

    Raises:
        ValueError: If ``sketch`` is ``None`` (nothing drawn or uploaded).
        RuntimeError: If no pipeline has been loaded into ``PIPELINE`` yet.
    """
    if sketch is None:
        raise ValueError("No sketch provided: draw on the canvas or upload an image first")
    if PIPELINE is None:
        raise RuntimeError("Model is not loaded: load a model before generating images")

    width, height = sketch.size
    try:
        images = PIPELINE(
            image=sketch,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_images_per_prompt=4,
            num_inference_steps=20,
            strength=0.7,
        ).images
    finally:
        # Release cached GPU memory after every attempt, including failed
        # ones, so that one error does not leave the cache filled.
        with torch.cuda.device("cuda"):
            torch.cuda.empty_cache()

    return images

In [None]:
# NOTE(review): this message is printed before the UI below is built or launched.
print("[INFO] Gradio app ready")
# Complete application: model download, model loading, prompts and
# sketch-to-image generation on a single page.
with gr.Blocks() as app:
    gr.Markdown("# 스케치 to 이미지 애플리케이션")

    gr.Markdown("## 모델 다운로드")
    with gr.Row():
        model_url = gr.Textbox(label="Model Link", placeholder="https://civitai.com/")
        download_model_btn = gr.Button(value="Download model")
    with gr.Row():
        model_file = gr.File(label="Model File")

    gr.Markdown("## 모델 불러오기")
    with gr.Row():
        load_model_btn = gr.Button(value="Load model")
    with gr.Row():
        # Status text box; starts out reporting that no model is loaded.
        is_model_check = gr.Textbox(label="Model Load Check", value="Model Not loaded")

    gr.Markdown("## 프롬프트 입력")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
    with gr.Row():
        n_prompt = gr.Textbox(label="Negative Prompt")

    gr.Markdown("## 스케치 to 이미지 생성")
    with gr.Row():
        # Left column: two alternative sketch sources, one per tab, each with
        # its own "Generate" button.
        with gr.Column():
            with gr.Tab("Canvas"):
                with gr.Row():
                    # Blank drawing canvas handed to handlers as a PIL image.
                    canvas = gr.Image(
                        label="Draw",
                        source="canvas",
                        image_mode="RGB",
                        tool="color-sketch",
                        interactive=True,
                        width=WIDTH,
                        height=HEIGHT,
                        shape=(WIDTH, HEIGHT),
                        brush_radius=20,
                        type="pil",
                    )
                with gr.Row():
                    canvas_run_btn = gr.Button(value="Generate")

            with gr.Tab("File"):
                with gr.Row():
                    # Uploaded image; the same color-sketch tool is enabled on it.
                    file = gr.Image(
                        label="Upload",
                        source="upload",
                        image_mode="RGB",
                        tool="color-sketch",
                        interactive=True,
                        width=WIDTH,
                        height=HEIGHT,
                        shape=(WIDTH, HEIGHT),
                        type="pil",
                    )
                with gr.Row():
                    file_run_btn = gr.Button(value="Generate")

        # Right column: gallery that receives the generated images.
        with gr.Column():
            result_gallery = gr.Gallery(label="Output", height=512)


    # Event wiring: each click passes the listed input components to the
    # handler and writes the handler's return value to the output component.
    download_model_btn.click(
        download_model,
        [model_url],
        [model_file],
    )
    load_model_btn.click(
        init_pipeline,
        [model_file],
        [is_model_check],
    )
    # Both Generate buttons share one handler and one gallery; only the
    # image source (canvas vs. uploaded file) differs.
    canvas_run_btn.click(
        sketch_to_image,
        [canvas, prompt, n_prompt],
        [result_gallery],
    )
    file_run_btn.click(
        sketch_to_image,
        [file, prompt, n_prompt],
        [result_gallery],
    )

In [None]:
# Enable Gradio's request queue before launching (presumably so long-running
# handlers such as download and generation are not cut off — TODO confirm),
# then serve with a public share link and without embedding in the cell output.
app.queue().launch(inline=False, share=True)

In [None]:
# Shut down the Gradio server started by the launch cell.
app.close()

## 최종 App 구현

In [None]:
import os
from typing import IO

import gradio as gr
import requests
import torch
from tqdm import tqdm
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image

In [None]:
# Side lengths passed as width/height/shape to the image inputs of the UI.
WIDTH = 512
HEIGHT = 512

# Image-to-image pipeline shared by the handlers; stays None until
# init_pipeline() has loaded a model.
PIPELINE = None

In [None]:
def download_model(url: str) -> str:
    """Download a Civitai model file and return its local path.

    The model id is taken from ``url``, the Civitai REST API is queried for
    that model, and the first file of the first listed model version is
    saved under ``models/``.

    Args:
        url: A Civitai model page URL such as
            ``https://civitai.com/models/<id>/<slug>``, or a bare model id.
            Surrounding whitespace, query strings and fragments are ignored.

    Returns:
        Path of the model file inside ``models/``. A file that already
        exists at that path is reused instead of being downloaded again.

    Raises:
        ValueError: If no model id can be extracted from ``url``.
        Exception: Any error raised by the API request (including HTTP error
            statuses) is logged and re-raised unchanged.
    """
    model_id = url.strip().replace("https://civitai.com/models/", "").split("/")[0]
    # URLs copied from the browser may end in "?modelVersionId=..." or "#...".
    model_id = model_id.split("?")[0].split("#")[0]
    if not model_id:
        raise ValueError(f"Could not find a model id in URL: {url!r}")

    try:
        response = requests.get(f"https://civitai.com/api/v1/models/{model_id}", timeout=600)
        # Fail here on 4xx/5xx instead of with a confusing KeyError below.
        response.raise_for_status()
    except Exception as err:
        print(f"[ERROR] {err}")
        raise

    # Parse the body once and reuse it for both lookups.
    version = response.json()["modelVersions"][0]
    download_url = version["downloadUrl"]
    # The name comes from a remote API: keep only its last path component so
    # it cannot point outside the models directory.
    filename = os.path.basename(version["files"][0]["name"])

    file_path = f"models/{filename}"
    if os.path.exists(file_path):
        print(f"[INFO] File already exists: {file_path}")
        return file_path

    os.makedirs("models", exist_ok=True)
    download_from_url(download_url, file_path)
    print(f"[INFO] File downloaded: {file_path}")
    return file_path


def download_from_url(url: str, file_path: str, chunk_size: int = 1024):
    """Stream ``url`` to ``file_path`` while showing a tqdm progress bar.

    The body is first written to ``<file_path>.part`` and renamed to
    ``file_path`` only after the whole transfer has finished, so an
    interrupted download never leaves a truncated file at ``file_path``.
    A leftover ``.part`` file is overwritten by the next attempt.

    Args:
        url: Direct download URL.
        file_path: Destination path; its directory must already exist.
        chunk_size: Number of bytes requested per streamed chunk.

    Raises:
        requests.RequestException: If the request fails, times out or the
            server answers with an HTTP error status.
    """
    partial_path = f"{file_path}.part"
    # Using the response as a context manager releases the connection even
    # when the transfer fails part-way.
    with requests.get(url, stream=True, timeout=600) as resp:
        # Do not save an HTML/JSON error page as if it were the model file.
        resp.raise_for_status()
        total = int(resp.headers.get('content-length', 0))
        with open(partial_path, 'wb') as file, tqdm(
            desc=file_path,
            total=total,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in resp.iter_content(chunk_size=chunk_size):
                size = file.write(data)
                bar.update(size)
    # Publish the file under its final name only once it is complete.
    os.replace(partial_path, file_path)


def init_pipeline(model_file: IO) -> str:
    """Load a Stable Diffusion img2img pipeline from a checkpoint file.

    The pipeline is built from the single file at ``model_file.name`` in
    float16, moved to the CUDA device and stored in the module-level
    ``PIPELINE``.

    Args:
        model_file: File object whose ``name`` attribute is the checkpoint
            path, or ``None`` when no file has been selected.

    Returns:
        A status message: ``"Model Loaded!"`` on success, or an explanatory
        message when ``model_file`` is ``None`` (``PIPELINE`` is then left
        untouched).
    """
    global PIPELINE
    if model_file is None:
        # Report the problem as status text instead of failing with an
        # AttributeError on `model_file.name`.
        print("[ERROR] No model file to load")
        return "Model Not Loaded: no model file selected"

    print("[INFO] Initialize pipeline")
    PIPELINE = StableDiffusionImg2ImgPipeline.from_single_file(
        model_file.name,
        torch_dtype=torch.float16,
        use_safetensors=True,
    ).to("cuda")
    print("[INFO] Initialized pipeline")
    return "Model Loaded!"


def sketch_to_image(sketch: Image.Image, prompt: str, negative_prompt: str):
    """Generate four images from a sketch with the loaded img2img pipeline.

    The output size follows the sketch size. Generation runs for 20
    inference steps with strength 0.7.

    Args:
        sketch: Input sketch as a PIL image.
        prompt: Text describing the desired result.
        negative_prompt: Text passed to the pipeline as ``negative_prompt``.

    Returns:
        The ``images`` attribute of the pipeline output.

    Raises:
        ValueError: If ``sketch`` is ``None`` (nothing drawn or uploaded).
        RuntimeError: If no pipeline has been loaded into ``PIPELINE`` yet.
    """
    if sketch is None:
        raise ValueError("No sketch provided: draw on the canvas or upload an image first")
    if PIPELINE is None:
        raise RuntimeError("Model is not loaded: load a model before generating images")

    width, height = sketch.size
    try:
        images = PIPELINE(
            image=sketch,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_images_per_prompt=4,
            num_inference_steps=20,
            strength=0.7,
        ).images
    finally:
        # The cache release used to sit after `return` and never ran; doing
        # it in `finally` makes it run after every attempt, failed or not.
        with torch.cuda.device("cuda"):
            torch.cuda.empty_cache()

    return images


In [None]:
# NOTE(review): this message is printed before the UI below is built or launched.
print("[INFO] Gradio app ready")
# Complete application: model download, model loading, prompts and
# sketch-to-image generation on a single page.
with gr.Blocks() as app:
    gr.Markdown("# 스케치 to 이미지 애플리케이션")

    gr.Markdown("## 모델 다운로드")
    with gr.Row():
        model_url = gr.Textbox(label="Model Link", placeholder="https://civitai.com/")
        download_model_btn = gr.Button(value="Download model")
    with gr.Row():
        model_file = gr.File(label="Model File")

    gr.Markdown("## 모델 불러오기")
    with gr.Row():
        load_model_btn = gr.Button(value="Load model")
    with gr.Row():
        # Status text box; starts out reporting that no model is loaded.
        is_model_check = gr.Textbox(label="Model Load Check", value="Model Not loaded")

    gr.Markdown("## 프롬프트 입력")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
    with gr.Row():
        n_prompt = gr.Textbox(label="Negative Prompt")

    gr.Markdown("## 스케치 to 이미지 생성")
    with gr.Row():
        # Left column: two alternative sketch sources, one per tab, each with
        # its own "Generate" button.
        with gr.Column():
            with gr.Tab("Canvas"):
                with gr.Row():
                    # Blank drawing canvas handed to handlers as a PIL image.
                    canvas = gr.Image(
                        label="Draw",
                        source="canvas",
                        image_mode="RGB",
                        tool="color-sketch",
                        interactive=True,
                        width=WIDTH,
                        height=HEIGHT,
                        shape=(WIDTH, HEIGHT),
                        brush_radius=20,
                        type="pil",
                    )
                with gr.Row():
                    canvas_run_btn = gr.Button(value="Generate")

            with gr.Tab("File"):
                with gr.Row():
                    # Uploaded image; the same color-sketch tool is enabled on it.
                    file = gr.Image(
                        label="Upload",
                        source="upload",
                        image_mode="RGB",
                        tool="color-sketch",
                        interactive=True,
                        width=WIDTH,
                        height=HEIGHT,
                        shape=(WIDTH, HEIGHT),
                        type="pil",
                    )
                with gr.Row():
                    file_run_btn = gr.Button(value="Generate")

        # Right column: gallery that receives the generated images.
        with gr.Column():
            result_gallery = gr.Gallery(label="Output", height=512)


    # Event wiring: each click passes the listed input components to the
    # handler and writes the handler's return value to the output component.
    download_model_btn.click(
        download_model,
        [model_url],
        [model_file],
    )
    load_model_btn.click(
        init_pipeline,
        [model_file],
        [is_model_check],
    )
    # Both Generate buttons share one handler and one gallery; only the
    # image source (canvas vs. uploaded file) differs.
    canvas_run_btn.click(
        sketch_to_image,
        [canvas, prompt, n_prompt],
        [result_gallery],
    )
    file_run_btn.click(
        sketch_to_image,
        [file, prompt, n_prompt],
        [result_gallery],
    )

In [None]:
# Enable Gradio's request queue before launching (presumably so long-running
# handlers such as download and generation are not cut off — TODO confirm),
# then serve with a public share link and without embedding in the cell output.
app.queue().launch(inline=False, share=True)

In [None]:
# Shut down the Gradio server started by the launch cell.
app.close()