# 【使い方】

ランタイム>すべてのセルを実行（**Ctrl+F9**）によりすべてのセルを実行し、セル[2]の最後に生成された**URL（Running on public URL）をクリック**して開いてください。（GUIが新しいタブで開かれる）

※このcolabの画面（タブ）は閉じないでください。




# [How to use]

Select Runtime > Run all (**Ctrl+F9**) to run every cell, then **click the URL printed at the end of cell [2] (on the line "Running on public URL")**. The GUI opens in a new browser tab.

Note: Do not close this Colab tab while you are using the GUI.

In [None]:
# Install only ONE OpenCV distribution: opencv-python and opencv-python-headless
# both provide the same `cv2` module and conflict when installed side by side.
!pip install gradio svgwrite numpy opencv-python

In [None]:
import gradio as gr
import cv2
import numpy as np
import zipfile
import os
import tempfile
import svgwrite

from PIL import Image
import requests
from io import BytesIO

# Global state shared between the Gradio callbacks.
logs: str = ""  # accumulated debug messages, returned to the "Debug Log" textbox
initial_process_done: bool = False  # True once the current upload has been loaded and recoloured
original_images: list = []  # RGB images as loaded from the upload
modified_images: list = []  # RGB images after every colour change applied so far

# Colour change function
def change_object_color(img, original_num, target_num):
    """Return a copy of ``img`` with one label colour repainted as another.

    ``original_num`` and ``target_num`` are 1-based label numbers, i.e. indices
    into ``color_labels`` plus one. Every pixel whose value equals exactly the
    colour of ``original_num`` is set to the colour of ``target_num``. The input
    array itself is left untouched. A line describing the change is appended
    to the global ``logs``.
    """
    global logs
    src_rgb, dst_rgb = color_labels[original_num - 1], color_labels[target_num - 1]

    recolored = img.copy()
    # A pixel matches only when all of its channels equal the source colour.
    hit = (recolored == src_rgb).all(axis=-1)
    recolored[hit] = dst_rgb

    logs += f"Applied color change from {src_rgb} to {dst_rgb}\n"
    return recolored

# SVG file saving
def _contour_to_svg_subpath(contour):
    """Convert an OpenCV contour (N x 1 x 2 point array) into a closed SVG sub-path string."""
    points = contour.reshape(-1, 2)
    return "M " + " L ".join(f"{int(x)} {int(y)}" for x, y in points) + " Z"


def save_mask_as_svg(img_rgb, base_name, output_dir="/content"):
    """Vectorise a colour-label mask and save it as ``<output_dir>/<base_name>.svg``.

    The SVG starts with a black rectangle covering the whole image. Then, for
    each colour in ``color_labels`` that occurs in the image, every connected
    region of that colour is drawn as a filled path.

    Holes inside a region are cut out with the even-odd fill rule. Without
    this, an outer region would be painted as solid and would cover any object
    of another label lying inside its hole.

    Args:
        img_rgb: H x W x 3 RGB image whose pixels are exact ``color_labels`` values.
        base_name: file name stem (without extension) for the SVG.
        output_dir: directory to write into (defaults to Colab's ``/content``).

    Returns:
        The path of the written SVG file.
    """
    image_height, image_width = img_rgb.shape[:2]
    svg_filename = os.path.join(output_dir, f"{base_name}.svg")
    dwg = svgwrite.Drawing(svg_filename, profile='tiny', size=(image_width, image_height))
    dwg.add(dwg.rect(insert=(0, 0), size=(image_width, image_height), fill='black'))

    for color_rgb in color_labels:
        mask = np.all(img_rgb == color_rgb, axis=-1)
        if not np.any(mask):
            continue
        hex_color = '#{:02x}{:02x}{:02x}'.format(*color_rgb)

        # RETR_CCOMP yields a two-level hierarchy: outer boundaries (parent == -1)
        # and the holes directly inside them. Rows are [next, prev, first_child, parent].
        contours, hierarchy = cv2.findContours(
            mask.astype(np.uint8), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE
        )
        if hierarchy is None:
            continue
        hierarchy = hierarchy[0]

        for idx, contour in enumerate(contours):
            if hierarchy[idx][3] != -1:
                continue  # a hole: emitted together with its outer boundary below
            rings = [_contour_to_svg_subpath(contour)]
            child = hierarchy[idx][2]  # first hole of this region
            while child != -1:
                rings.append(_contour_to_svg_subpath(contours[child]))
                child = hierarchy[child][0]  # next hole at the same level
            dwg.add(dwg.path(d=" ".join(rings), fill=hex_color, fill_rule='evenodd'))

    dwg.save()
    return svg_filename

# Colour-change handler: the first call loads the uploaded batch, and later
# calls keep recolouring the already-modified images.

# Identity (input paths) of the batch currently held in original_images /
# modified_images, and the output file names aligned index-by-index with
# modified_images (images that fail to load are left out of both).
_processed_sources = None
_processed_file_names = []


def _source_path(img_file):
    """Return the path of an uploaded file (a str path or an object with ``.name``)."""
    if isinstance(img_file, str):
        return img_file
    return getattr(img_file, "name", "") or ""


def _load_image_bgr(img_file):
    """Read an uploaded image as a BGR array; return None if it cannot be decoded."""
    path = _source_path(img_file)
    if isinstance(img_file, str) or os.path.exists(path):
        return cv2.imread(path, cv2.IMREAD_COLOR)
    # No file on disk: decode the bytes in memory, which needs no temporary
    # file (nothing to flush before reading, nothing left behind to delete).
    data = np.frombuffer(img_file.read(), dtype=np.uint8)
    if data.size == 0:
        return None
    return cv2.imdecode(data, cv2.IMREAD_COLOR)


def process_or_change_color(image_files, original_str, target_str):
    """Recolour label ``original_str`` to label ``target_str`` and rebuild the ZIPs.

    On the first call for an uploaded batch, the images are loaded and recoloured.
    A batch counts as new when ``initial_process_done`` is False or the uploaded
    file paths differ from the last processed batch. Later calls apply the
    change on top of the previously modified images, so several changes can
    be chained.

    Args:
        image_files: uploaded files (str paths or file objects with ``.name``).
        original_str: 1-based source label number as a string (from the radio).
        target_str: 1-based target label number as a string (from the radio).

    Returns:
        ``(modified_images, png_zip_path, svg_zip_path, logs, original_images)``.
        If nothing is uploaded, the images are empty lists and both zip paths are None.
    """
    global logs, original_images, modified_images, initial_process_done
    global _processed_sources, _processed_file_names

    if not image_files:
        logs += "No images uploaded\n"
        return [], None, None, logs, []

    original_num = int(original_str)
    target_num = int(target_str)

    sources = tuple(_source_path(img_file) for img_file in image_files)
    if initial_process_done and sources != _processed_sources:
        # A different batch was uploaded since the last run: start over instead
        # of recolouring the previous images under the new file names.
        initial_process_done = False

    if not initial_process_done:
        logs = ""
        original_images = []
        modified_images = []
        _processed_file_names = []
        logs += f"Processing {len(image_files)} images\n"

        for idx, img_file in enumerate(image_files):
            try:
                logs += f"Reading image {idx + 1}\n"
                img = _load_image_bgr(img_file)
                if img is None:
                    logs += f"Error: Image {idx + 1} could not be read\n"
                    continue

                img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                logs += f"Image {idx + 1} converted to RGB successfully\n"
                modified_img_rgb = change_object_color(img_rgb, original_num, target_num)

                # Append together so images and output names stay aligned.
                original_images.append(img_rgb)
                modified_images.append(modified_img_rgb)
                # basename: file objects carry a full temp path in ``.name``.
                name = os.path.basename(_source_path(img_file)) or f"image_{idx + 1}.png"
                _processed_file_names.append(name)
                logs += f"Image {idx + 1} processed successfully\n"

            except Exception as e:
                logs += f"Exception while processing image {idx + 1}: {str(e)}\n"

        _processed_sources = sources
        initial_process_done = True
    else:
        logs += "\nProcessing color change for another object\n"
        new_modified_images = []

        for idx, img_rgb in enumerate(modified_images):
            logs += f"Changing color for image {idx + 1}\n"
            # change_object_color already returns a fresh copy.
            new_modified_images.append(change_object_color(img_rgb, original_num, target_num))
            logs += f"Image {idx + 1} processed successfully\n"

        modified_images = new_modified_images

    png_zip_path, svg_zip_path = save_modified_images(_processed_file_names)
    return modified_images, png_zip_path, svg_zip_path, logs, original_images

# Show the uploaded images in the gallery
def display_uploaded_images(image_files):
    """Load the uploaded images as RGB arrays for the "Uploaded Images" gallery.

    Called whenever the upload widget changes. It resets ``original_images``
    and clears ``initial_process_done``, so the next "Apply" starts from the new
    batch instead of continuing to edit the previous one.

    Args:
        image_files: uploaded files (str paths or file objects with ``.name``),
            or None when the upload was cleared.

    Returns:
        The list of successfully decoded RGB images (unreadable files are skipped).
    """
    global original_images, initial_process_done
    original_images = []
    initial_process_done = False

    for img_file in image_files or []:
        if isinstance(img_file, str):
            img = cv2.imread(img_file, cv2.IMREAD_COLOR)
        else:
            path = getattr(img_file, "name", None)
            if path and os.path.exists(path):
                img = cv2.imread(path, cv2.IMREAD_COLOR)
            else:
                # Decode in memory: no temp file to flush before reading, none to leak.
                data = np.frombuffer(img_file.read(), dtype=np.uint8)
                img = cv2.imdecode(data, cv2.IMREAD_COLOR) if data.size else None
        if img is not None:
            original_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return original_images

# Save the PNG and SVG files into two separate ZIP archives
def save_modified_images(file_names):
    """Export every image in ``modified_images`` as PNG and SVG and zip them.

    For each image, ``<stem>`` is ``file_names[idx]`` without its extension.
    The PNG is written to ``/content/<stem>.png`` and the SVG is produced by
    ``save_mask_as_svg``. PNGs go into ``/content/modified_masks_png.zip`` and
    SVGs into ``/content/modified_masks_svg.zip``; both archives are rewritten
    on every call.

    Args:
        file_names: names aligned index-by-index with ``modified_images``.

    Returns:
        ``(png_zip_path, svg_zip_path)``.

    Raises:
        OSError: if a PNG cannot be written. Failing here avoids zipping a
            missing or stale file from an earlier run.
    """
    png_zip_path = '/content/modified_masks_png.zip'
    svg_zip_path = '/content/modified_masks_svg.zip'

    # PNG data is already compressed; SVG text shrinks a lot with deflate.
    with zipfile.ZipFile(png_zip_path, 'w') as png_zip, zipfile.ZipFile(
        svg_zip_path, 'w', compression=zipfile.ZIP_DEFLATED
    ) as svg_zip:
        for idx, img_rgb in enumerate(modified_images):
            base_name = os.path.splitext(file_names[idx])[0]
            png_file_path = f"/content/{base_name}.png"
            img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
            if not cv2.imwrite(png_file_path, img_bgr):
                raise OSError(f"Failed to write PNG file: {png_file_path}")
            png_zip.write(png_file_path, os.path.basename(png_file_path))

            svg_file_path = save_mask_as_svg(img_rgb, base_name)
            svg_zip.write(svg_file_path, os.path.basename(svg_file_path))

    return png_zip_path, svg_zip_path

# Colour label definitions. UI label number N (1-based) maps to color_labels[N - 1].
# Each entry is an (R, G, B) tuple; comments give the matching CSS colour name.
color_labels = [
    (255, 0, 0),      # 1  red
    (0, 0, 255),      # 2  blue
    (0, 255, 0),      # 3  lime
    (255, 255, 0),    # 4  yellow
    (128, 0, 128),    # 5  purple
    (255, 165, 0),    # 6  orange
    (0, 255, 255),    # 7  cyan / aqua
    (173, 255, 47),   # 8  greenyellow
    (128, 128, 128),  # 9  gray
    (0, 128, 128),    # 10 teal
    (255, 192, 203),  # 11 pink
    (255, 20, 147),   # 12 deeppink
    (0, 128, 0),      # 13 green
    (128, 0, 0),      # 14 maroon
    (0, 255, 230),    # 15 aqua-like (no exact CSS name)
    (255, 215, 0),    # 16 gold
    (255, 69, 0),     # 17 orangered
    (0, 0, 128),      # 18 navy
    (220, 20, 60),    # 19 crimson
    (128, 128, 0),    # 20 olive
]

# Download the colour-list reference image shown next to the label selectors.
# It is only a visual aid, so a network or decoding failure falls back to an
# empty image rather than stopping the notebook before the UI is launched.
url = "https://github.com/SatoruMuro/SAM2GUIfor3Drecon/blob/main/images/colorlist.png?raw=true"
try:
    response = requests.get(url, timeout=30)  # without a timeout this can hang indefinitely
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
except (requests.RequestException, OSError) as e:
    # OSError covers PIL's UnidentifiedImageError (e.g. an HTML error page).
    print(f"Could not load color list image: {e}")
    image = None

# Gradio interface definition.
# Radio choices are the 1-based label numbers, derived from color_labels so the
# selectable range always matches the defined colours.
label_choices = [str(i) for i in range(1, len(color_labels) + 1)]

with gr.Blocks() as demo:
    gr.Markdown("## Object Color Changer")

    with gr.Row():
        # Left column: upload, preview of the uploaded masks, debug log.
        with gr.Column():
            image_files_input = gr.File(label="Upload Mask Images (PNG)", file_types=["image"], file_count="multiple")
            original_gallery = gr.Gallery(label="Uploaded Images")
            debug_output = gr.Textbox(label="Debug Log", interactive=False, lines=20)

        # Right column: colour legend, label selection, results and downloads.
        with gr.Column():
            gr.Image(value=image, label="Color List")
            original_num = gr.Radio(choices=label_choices, label="Original Object Number", value="1")
            target_num = gr.Radio(choices=label_choices, label="Target Object Number", value="2")
            btn_process = gr.Button("Apply Object Color Change")
            image_output = gr.Gallery(label="Modified Masks")
            download_png_zip = gr.File(label="Download PNG Zip File", visible=True)
            download_svg_zip = gr.File(label="Download SVG Zip File", visible=True)

    # Re-load and preview the images whenever the upload changes.
    image_files_input.change(
        display_uploaded_images,
        [image_files_input],
        [original_gallery]
    )

    # Apply the selected label-to-label colour change and refresh every output.
    btn_process.click(
        process_or_change_color,
        [image_files_input, original_num, target_num],
        [image_output, download_png_zip, download_svg_zip, debug_output, original_gallery]
    )

demo.launch()
