<a href="https://colab.research.google.com/github/Vladus-CPU/IQ/blob/main/work3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install --quiet colorcet diffusers gradio hf-transfer huggingface-hub==0.14.0 matplotlib \
    numpy==1.23.5 opencv-contrib-python-headless pandas Pillow rich git+https://github.com/huggingface/pytorch-image-models.git@main#egg=timm \
    tokenizers torch>=2.1.0 torchvision transformers

# If you prefer, you can clone the code from its repository, provided it is published on GitHub,
# or simply copy it into the next cell. The example below shows
# how to copy the files directly:
# !git clone https://github.com/user/wd-tagger-heatmap-example.git

[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
accelerate 1.2.1 requires huggingface-hub>=0.21.0, but you have huggingface-hub 0.14.0 which is incompatible.
albucore 0.0.19 requires numpy>=1.24.4, but you have numpy 1.23.5 which is incompatible.
albumentations 1.4.20 requires numpy>=1.24.4, but you have numpy 1.23.5 which is incompatible.
bigframes 1.29.0 requires numpy>=1.24.0, but you have numpy 1.23.5 which is incompatible.
chex 0.1.88 requires numpy>=1.24.1, but you have numpy 1.23.5 which is incompatible.
google-genai 0.3.0 requires websockets<15.0dev,>=13.0, but you have websockets 11.0.3 which is incompatible.
jax 0.4.33 requires numpy>=1.24, but you have numpy 1.23.5 which is incompatible.
jaxlib 0.4.33 requires numpy>=1.24, but you have numpy 1.23.5 which is incompatible.
peft 0.14.0 requires huggingface-hub>=0.25.0, but you have huggingface-

In [9]:
# Install all required packages
!pip install gradio torch torchvision timm huggingface_hub colorcet
!pip install --upgrade gradio

import gradio as gr
import math
import time
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Optional, List

import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from PIL import Image
from torch import Tensor, nn
import torch.nn.functional as F

import cv2
import colorcet as cc
from matplotlib.colors import LinearSegmentedColormap

import timm
from timm.data import create_transform, resolve_data_config
from timm.models import VisionTransformer
from torchvision import transforms as T
import json
import csv
import io

# ---------------------
# BEGIN: common.py code
# ---------------------

@dataclass
class Heatmap:
    """One rendered heatmap overlay for a single predicted tag."""

    label: str  # tag name the overlay was rendered for
    score: float  # probability associated with the tag
    image: bytes  # Serialized image bytes (render_heatmap stores PNG data here)

@dataclass
class LabelData:
    """Tag vocabulary of a tagger model, as read from its ``selected_tags.csv``."""

    names: list[str]  # tag names, in CSV row order
    rating: list[np.int64]  # row indices whose category is 9
    general: list[np.int64]  # row indices whose category is 0
    character: list[np.int64]  # row indices whose category is 4

@dataclass
class ImageLabels:
    """Tagging result for one image, grouped by tag category."""

    caption: str  # comma-separated tag names
    booru: str  # caption with underscores replaced by spaces
    rating: dict[str, float]  # rating tag -> score
    general: dict[str, float]  # general tag -> score
    character: dict[str, float]  # character tag -> score

@lru_cache(maxsize=5)
def _load_labels_hf_cached(repo_id: str, revision: Optional[str], token: Optional[str]) -> LabelData:
    """Download and parse ``selected_tags.csv``; only successful loads are memoised."""
    csv_path = Path(
        hf_hub_download(
            repo_id=repo_id,
            filename="selected_tags.csv",
            revision=revision,
            token=token,
        )
    ).resolve()
    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    return LabelData(
        names=df["name"].tolist(),
        rating=list(np.where(df["category"] == 9)[0]),
        general=list(np.where(df["category"] == 0)[0]),
        character=list(np.where(df["category"] == 4)[0]),
    )


def load_labels_hf(repo_id: str, revision: Optional[str] = None, token: Optional[str] = None) -> LabelData:
    """Load the tag vocabulary of a Hugging Face tagger repository.

    Args:
        repo_id: Hub repository that contains ``selected_tags.csv``.
        revision: Optional revision passed through to ``hf_hub_download``.
        token: Optional access token passed through to ``hf_hub_download``.

    Returns:
        The parsed ``LabelData``. When the download fails with
        ``HfHubHTTPError`` the error is printed and an empty ``LabelData`` is
        returned instead.

    The failure fallback is deliberately produced outside the cached helper:
    caching it would make a transient network error stick for the rest of the
    session.
    """
    try:
        return _load_labels_hf_cached(repo_id, revision, token)
    except HfHubHTTPError as e:
        print(f"Error downloading selected_tags.csv from {repo_id}: {e}")
        return LabelData(names=[], rating=[], general=[], character=[])


# Keep the cache-management helpers that the lru_cache-decorated original exposed.
load_labels_hf.cache_clear = _load_labels_hf_cached.cache_clear
load_labels_hf.cache_info = _load_labels_hf_cached.cache_info

def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return *image* in RGB mode, flattening any alpha onto a white background."""
    if image.mode not in ["RGB", "RGBA"]:
        # Keep transparency information (if declared) so it can be composited below.
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)
    if image.mode != "RGBA":
        return image
    background = Image.new("RGBA", image.size, (255, 255, 255))
    background.alpha_composite(image)
    return background.convert("RGB")

def pil_pad_square(image: Image.Image, fill: tuple[int, int, int] = (255, 255, 255)) -> Image.Image:
    """Centre *image* on a square RGB canvas whose side is the image's longer edge."""
    width, height = image.size
    side = max(image.size)
    offset = ((side - width) // 2, (side - height) // 2)
    squared = Image.new("RGB", (side, side), fill)
    squared.paste(image, offset)
    return squared

def preprocess_image(image: Image.Image, size_px: int | tuple[int, int], upscale: bool = True) -> Image.Image:
    """Convert *image* to a white-padded RGB square and fit it to *size_px*.

    Args:
        image: Source image in any PIL mode.
        size_px: Target size; an int is used for both dimensions.
        upscale: Whether an image smaller than the target may be enlarged.

    Raises:
        ValueError: If the image is smaller than the target and *upscale* is False.
    """
    target = (size_px, size_px) if isinstance(size_px, int) else size_px
    squared = pil_pad_square(pil_ensure_rgb(image))

    too_small = squared.size[0] < target[0] or squared.size[1] < target[1]
    if too_small:
        if not upscale:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        squared = squared.resize(target, Image.LANCZOS)

    too_large = squared.size[0] > target[0] or squared.size[1] > target[1]
    if too_large:
        # thumbnail() shrinks in place and preserves the aspect ratio.
        squared.thumbnail(target, Image.BICUBIC)
    return squared

def pil_make_grid(
    images: list[Image.Image],
    max_cols: int = 8,
    padding: int = 4,
    bg_color: tuple[int, int, int] = (40, 42, 54),
    partial_rows: bool = True,
) -> Image.Image:
    """Tile *images* onto one canvas, row by row.

    The cell size is taken from the first image. An empty input yields a
    200x200 canvas filled with *bg_color*. With ``partial_rows=False`` a last
    row that would not be completely filled is not allocated on the canvas.
    """
    if not images:
        return Image.new("RGB", (200, 200), bg_color)

    count = len(images)
    n_cols = min(math.ceil(math.sqrt(count)), max_cols)
    n_rows = math.ceil(count / n_cols)
    if not partial_rows and n_cols * n_rows > count:
        n_rows -= 1

    cell_w, cell_h = images[0].size
    step_x = cell_w + padding
    step_y = cell_h + padding
    canvas = Image.new("RGB", (step_x * n_cols + padding, step_y * n_rows + padding), bg_color)

    for index, tile in enumerate(images):
        row, col = divmod(index, n_cols)
        canvas.paste(tile, (col * step_x + padding, row * step_y + padding))
    return canvas

# ---------------------
# END: common.py code
# ---------------------

# -----------------------
# BEGIN: model.py code
# -----------------------

class RGBtoBGR(nn.Module):
    """Transform step that reverses the order of the first three channels."""

    # Source channel indices, listed in output order.
    _CHANNEL_ORDER = (2, 1, 0)

    def forward(self, x: Tensor) -> Tensor:
        order = list(self._CHANNEL_ORDER)
        batched = x.ndim == 4
        # A 4-D input has its channels on dim 1; anything else is indexed on dim 0.
        return x[:, order, :, :] if batched else x[order, :, :]

# Process-wide caches keyed by Hugging Face repo id, so each model and its
# preprocessing pipeline are only built once.
model_cache: dict[str, VisionTransformer] = {}
transform_cache: dict[str, T.Compose] = {}

def model_device(model: nn.Module) -> torch.device:
    """Return the device that holds the first parameter of *model*."""
    first_param = next(model.parameters())
    return first_param.device

def load_model(repo_id: str) -> VisionTransformer:
    """Return the eval-mode model for *repo_id*, creating and caching it on first use."""
    global model_cache
    try:
        return model_cache[repo_id]
    except KeyError:
        pass

    # Prefer CUDA when it is available, otherwise stay on the CPU.
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = timm.create_model("hf-hub:" + repo_id, pretrained=True).eval().to(target)
    model_cache[repo_id] = network
    return network

def load_model_and_transform(repo_id: str) -> tuple[VisionTransformer, T.Compose]:
    """Return the cached model for *repo_id* together with its preprocessing pipeline.

    The pipeline is the timm transform resolved from the model's pretrained
    config, followed by an ``RGBtoBGR`` step.

    The model is always fetched through ``load_model`` first. Previously the
    local ``model`` was only bound when this function created the model itself,
    so a model cached earlier (for example by ``load_model``) without a cached
    transform raised ``UnboundLocalError``.
    """
    model = load_model(repo_id)
    if repo_id not in transform_cache:
        data_config = resolve_data_config(model.pretrained_cfg, model=model)
        base_transform = create_transform(**data_config)
        transform_cache[repo_id] = T.Compose(base_transform.transforms + [RGBtoBGR()])
    return model, transform_cache[repo_id]

def get_tags(probs: Tensor, labels: LabelData, gen_threshold: float, char_threshold: float):
    """Split per-tag probabilities into rating, general and character tags.

    Args:
        probs: Per-tag probabilities, aligned with ``labels.names``.
        labels: Tag names plus the row indices belonging to each category.
        gen_threshold: A general tag is kept when its score is strictly above this.
        char_threshold: A character tag is kept when its score is strictly above this.

    Returns:
        ``(caption, booru, rating_labels, char_labels, gen_labels)``.
        ``caption`` is the comma-separated list of kept general tags followed
        by kept character tags, ``booru`` is the caption with underscores
        replaced by spaces, and the three dicts map tag name to score. Rating
        tags are never thresholded. Empty or missing *probs* give empty results.
    """
    if probs is None or not probs.numel():
        return "", "", {}, {}, {}

    # detach() makes this safe for tensors that still require grad, and
    # tolist() yields plain Python floats that serialize cleanly.
    scores = probs.detach().cpu().flatten().tolist()
    names = labels.names
    usable = min(len(scores), len(names))

    def select(indices, threshold=None):
        """Map tag name -> score for *indices*, optionally keeping only scores above *threshold*."""
        picked = {}
        for raw_index in indices:
            index = int(raw_index)
            if index >= usable:
                # The vocabulary and the model output disagree in length; skip the stray index.
                continue
            if threshold is None or scores[index] > threshold:
                picked[names[index]] = scores[index]
        return picked

    # Category membership is decided by row index. The previous code compared
    # tag *names* against these integer indices, so both dicts were always empty.
    rating_labels = select(labels.rating)
    gen_labels = select(labels.general, gen_threshold)
    char_labels = select(labels.character, char_threshold)

    caption = ", ".join([*gen_labels, *char_labels])
    booru = caption.replace("_", " ")

    return caption, booru, rating_labels, char_labels, gen_labels

@torch.no_grad()
def render_heatmap(
    image: Tensor,
    gradients: Tensor,
    image_feats: Tensor,
    image_probs: Tensor,
    image_labels: list[str],
    cmap: LinearSegmentedColormap = cc.m_linear_bmy_10_95_c71,
    image_size: tuple[int, int] = (448, 448),
    font_args: Optional[dict] = None,
    partial_rows: bool = True,
) -> tuple[List[Heatmap], bytes]:
    """Render one gradient-weighted heatmap overlay per label, plus a grid of all overlays.

    Args:
        image: Preprocessed input image tensor.
            NOTE(review): assumed to be normalised to [-1, 1] and in BGR channel
            order (the transform pipeline appends RGBtoBGR) - confirm.
        gradients: Gradients of each label's probability w.r.t. the features.
        image_feats: Features the gradients were taken against.
        image_probs: One probability per entry of *image_labels*.
        image_labels: Tag names, aligned with *image_probs*.
        cmap: Colormap applied to the normalised heatmaps.
        image_size: Size the heatmaps are interpolated to; it must match the
            image size for the blend to succeed.
        font_args: Keyword arguments for ``cv2.putText``. ``None`` selects the
            default white Hershey Simplex style.
        partial_rows: Forwarded to ``pil_make_grid``.

    Returns:
        ``(heatmaps, grid_png)``: the overlays sorted by descending score and
        the PNG bytes of the grid image. With no labels the list is empty and
        the grid is the placeholder produced by ``pil_make_grid([])``.
    """
    # Build the default per call instead of sharing one mutable dict across calls.
    if font_args is None:
        font_args = {
            "fontFace": cv2.FONT_HERSHEY_SIMPLEX,
            "fontScale": 1,
            "color": (255, 255, 255),
            "thickness": 2,
            "lineType": cv2.LINE_AA,
        }

    n_labels = len(image_labels)
    if n_labels == 0:
        # Nothing to render; the size computation below would divide by zero.
        placeholder = io.BytesIO()
        pil_make_grid([], partial_rows=partial_rows).save(placeholder, format='PNG')
        return [], placeholder.getvalue()

    # Weight the features by the position-averaged gradients, then average over
    # the last (channel) axis to get one value per position and label.
    image_hmaps = gradients.mean(2, keepdim=True).mul(image_feats.unsqueeze(0)).squeeze()
    image_hmaps = image_hmaps.mean(-1)
    hmap_dim = int(math.sqrt(image_hmaps.numel() / n_labels))
    image_hmaps = image_hmaps.reshape(n_labels, -1)
    # Keep only the trailing hmap_dim**2 positions so they form a square map
    # (presumably this drops leading non-patch tokens - TODO confirm).
    image_hmaps = image_hmaps[..., -hmap_dim**2:]
    image_hmaps = image_hmaps.reshape(n_labels, hmap_dim, hmap_dim)
    image_hmaps = image_hmaps.clamp(min=0)

    # Scale each map by its peak. An all-zero map keeps a divisor of 1 so it
    # stays zero instead of turning into NaN through 0/0.
    peaks = image_hmaps.reshape(n_labels, -1).max(-1)[0].unsqueeze(-1).unsqueeze(-1)
    image_hmaps = image_hmaps / torch.where(peaks > 0, peaks, torch.ones_like(peaks))
    image_hmaps = torch.stack([(x - x.min()) / (x.max() - x.min() + 1e-8) for x in image_hmaps]).unsqueeze(1)
    image_hmaps = F.interpolate(image_hmaps, size=image_size, mode="bilinear").squeeze(1)

    # The base image is identical for every label, so convert it only once.
    # Clamping keeps out-of-range values from wrapping around in the uint8 cast.
    image_pixels = (
        image.add(1).mul(127.5).clamp(0, 255).squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    )

    hmap_imgs = []
    for tag, hmap, score in zip(image_labels, image_hmaps, image_probs.cpu()):
        score_value = score.item()
        hmap_pixels = cmap(hmap.cpu().numpy(), bytes=True)[:, :, :3]
        hmap_cv2 = cv2.cvtColor(hmap_pixels, cv2.COLOR_RGB2BGR)
        hmap_image = cv2.addWeighted(image_pixels, 0.5, hmap_cv2, 0.5, 0)

        if tag:
            cv2.putText(hmap_image, tag, (10, 30), **font_args)
            cv2.putText(hmap_image, f"{score_value:.3f}", (10, 60), **font_args)

        hmap_pil = Image.fromarray(cv2.cvtColor(hmap_image, cv2.COLOR_BGR2RGB))
        img_byte_arr = io.BytesIO()
        hmap_pil.save(img_byte_arr, format='PNG')

        hmap_imgs.append(Heatmap(label=tag, score=score_value, image=img_byte_arr.getvalue()))

    hmap_imgs.sort(key=lambda x: x.score, reverse=True)
    hmap_grid = pil_make_grid([Image.open(io.BytesIO(x.image)) for x in hmap_imgs], partial_rows=partial_rows)

    grid_byte_arr = io.BytesIO()
    hmap_grid.save(grid_byte_arr, format='PNG')

    return hmap_imgs, grid_byte_arr.getvalue()

def process_heatmap(
    model: VisionTransformer,
    image: Tensor,
    labels: LabelData,
    threshold: float = 0.5,
    partial_rows: bool = True,
):
    """Tag one image and render a heatmap for every tag above *threshold*.

    Args:
        model: Tagger model exposing ``forward_features`` and ``forward_head``.
        image: Preprocessed image tensor with a leading batch dimension of 1.
        labels: Tag vocabulary of *model*.
        threshold: Probability a tag must exceed to get a heatmap; also used as
            both the general and the character threshold for the caption.
        partial_rows: Forwarded to ``render_heatmap``.

    Returns:
        ``(heatmaps, grid_png_bytes, image_labels, elapsed_seconds, device_type)``
        where ``device_type`` is ``"GPU"`` or ``"CPU"``. When no tag exceeds the
        threshold the heatmap list is empty and the grid is ``b''``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    start_time = time.time()

    grads = None
    with torch.set_grad_enabled(True):
        features = model.forward_features(image.to(device))
        probs = model.forward_head(features)
        if probs is None:
            print("Model output is None")
            return ([], b'', ImageLabels("", "", {}, {}, {}), 0, "CPU")

        probs = torch.sigmoid(probs).squeeze(0)
        label_indices = torch.nonzero(probs > threshold, as_tuple=False).squeeze(1)
        # Drop indices that have no name so labels and probabilities stay aligned.
        label_indices = label_indices[label_indices < len(labels.names)]
        heatmap_probs = probs[label_indices]
        image_labels = [labels.names[i] for i in label_indices.tolist()]

        if image_labels:
            # One batched backward pass: row k of the identity selects label k.
            eye = torch.eye(heatmap_probs.shape[0], device=device)
            grads = torch.autograd.grad(
                outputs=heatmap_probs,
                inputs=features,
                grad_outputs=eye,
                is_grads_batched=True,
                retain_graph=True,
            )
            grads = grads[0].detach()[:, 0, :, :].unsqueeze(1)

    with torch.set_grad_enabled(False):
        if grads is not None:
            hmap_imgs, hmap_grid = render_heatmap(
                image=image,
                gradients=grads,
                image_feats=features.detach(),
                image_probs=heatmap_probs.detach(),
                image_labels=image_labels,
                partial_rows=partial_rows,
            )
        else:
            # No tag passed the threshold, so there is nothing to differentiate or render.
            hmap_imgs, hmap_grid = [], b''

        # Detach before handing over: the probabilities still carry autograd
        # history, which a numpy conversion downstream would reject.
        caption, booru, rating_dict, character_dict, general_dict = get_tags(
            probs=probs.detach().cpu(),
            labels=labels,
            gen_threshold=threshold,
            char_threshold=threshold,
        )

    # This used to bind the `time` module itself and then evaluate
    # `time - start_time`, which raises TypeError.
    total_time = time.time() - start_time
    device_type = "GPU" if device.type == "cuda" else "CPU"
    image_labels_res = ImageLabels(caption, booru, rating_dict, general_dict, character_dict)

    return (hmap_imgs, hmap_grid, image_labels_res, total_time, device_type)

# -----------------------
# END: model.py code
# -----------------------

# -----------------------
# BEGIN: app.py code
# -----------------------
from os import getenv
from io import BytesIO

# Application metadata shown at the top of the Gradio page.
TITLE = "WD Tagger Heatmap With Advanced Features"
DESCRIPTION = """WD Tagger extended with batch processing, model comparison, searchable tags, exporting, and advanced customization."""
# Optional Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = getenv("HF_TOKEN", None)

# Tagger repositories offered in the model selection checkbox group.
AVAILABLE_MODEL_REPOS = [
    'SmilingWolf/wd-convnext-tagger-v3',
    'SmilingWolf/wd-swinv2-tagger-v3',
    'SmilingWolf/wd-vit-tagger-v3',
    'SmilingWolf/wd-vit-large-tagger-v3',
    "SmilingWolf/wd-eva02-large-tagger-v3",
]
# Model that is pre-selected in the UI.
DEFAULT_MODEL = "SmilingWolf/wd-vit-tagger-v3"
# NOTE(review): WORK_DIR, IMAGE_EXTENSIONS and example_images are not referenced
# anywhere in the visible part of this file.
WORK_DIR = Path(".").resolve()
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]

# Example images placeholder (optional)
example_images = []

def predict_batch(
    images: List[Image.Image],
    model_repos: List[str],
    threshold: float,
    partial_rows: bool,
    custom_gen_threshold: float,
    custom_char_threshold: float,
):
    """Run every selected model on every image.

    Args:
        images: Images to tag.
        model_repos: Hugging Face repo ids of the models to run.
        threshold: Probability threshold forwarded to ``process_heatmap``.
        partial_rows: Forwarded to ``process_heatmap``.
        custom_gen_threshold: Accepted for interface compatibility.
        custom_char_threshold: Accepted for interface compatibility.

    Returns:
        One list per image, holding one result dict per model that succeeded.
        A model that raises during inference is reported on stdout and skipped.

    NOTE(review): the two custom thresholds are not used by this function;
    ``process_heatmap`` applies ``threshold`` to both tag categories.
    """
    all_results = []
    for img in images:
        per_model_outputs = []
        # Preprocessing does not depend on the model, so run it at most once per image.
        processed_img = None
        for repo in model_repos:
            model, transform = load_model_and_transform(repo)
            labels: LabelData = load_labels_hf(repo, token=HF_TOKEN)
            if processed_img is None:
                processed_img = preprocess_image(img, (448, 448))
            input_tensor = transform(processed_img).unsqueeze(0)

            # Device placement and timing are handled inside process_heatmap.
            try:
                hmap_imgs, hmap_grid, image_labels_res, total_time, device_type = process_heatmap(
                    model=model,
                    image=input_tensor,
                    labels=labels,
                    threshold=threshold,
                    partial_rows=partial_rows,
                )
            except Exception as e:
                # Best effort: one failing model must not abort the whole batch.
                print(f"Error processing model {repo}: {e}")
                continue

            per_model_outputs.append({
                "model_repo": repo,
                "heatmap_images": hmap_imgs,
                "heatmap_grid": hmap_grid,
                "caption": image_labels_res.caption,
                "tags": image_labels_res.booru,
                "rating": image_labels_res.rating,
                "character": image_labels_res.character,
                "general": image_labels_res.general,
                "inference_time_s": f"{total_time:.2f}",
                "device_type": device_type,
            })
        all_results.append(per_model_outputs)
    return all_results

def build_searchable_tags(results):
    """Return the sorted, de-duplicated general and character tag names found in *results*."""
    unique_tags = {
        tag
        for per_image in results
        for entry in per_image
        for category in ("general", "character")
        for tag in entry[category].keys()
    }
    return sorted(unique_tags)

def _json_fallback(obj):
    """Serialize numpy scalars as plain numbers; describe anything else by its type."""
    if isinstance(obj, np.generic):
        return obj.item()
    return f"<not serializable> {type(obj)}"


def export_csv_json(results):
    """Render batch results as CSV text and JSON text.

    Args:
        results: One list per image, each holding one result dict per model.

    Returns:
        ``(csv_content, json_content)``. The CSV has one row per image/model
        pair with a 1-based image index. The JSON is a flat list of the
        per-model dicts; values JSON cannot represent (for example raw image
        bytes) are replaced by a ``"<not serializable> <type>"`` marker.
    """
    csv_buffer = io.StringIO()
    csv_writer = csv.writer(csv_buffer)
    csv_writer.writerow(["ImageIndex", "ModelRepo", "Caption", "Tags", "InferenceTime_s", "DeviceType"])
    for img_idx, img_res in enumerate(results, start=1):
        for model_res in img_res:
            csv_writer.writerow([
                img_idx,
                model_res["model_repo"],
                model_res["caption"],
                model_res["tags"],
                model_res["inference_time_s"],
                model_res["device_type"],
            ])
    csv_content = csv_buffer.getvalue()

    # The per-model results are plain dicts. They were previously accessed via
    # `.__dict__`, which dicts do not have, so every export raised AttributeError.
    flat_results = [model_res for img_res in results for model_res in img_res]
    json_content = json.dumps(flat_results, default=_json_fallback, indent=2)
    return csv_content, json_content

# Custom CSS passed to gr.Blocks: adds top padding to the elements with ids
# "use_mcut" and "char_mcut".
css = """
#use_mcut, #char_mcut {
    padding-top: var(--scale-3);
}
"""

# Build the Gradio interface: inputs in the left column, results in the right one.
with gr.Blocks(theme="default", analytics_enabled=False, title=TITLE, css=css) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    with gr.Tab("Main Inference"):
        with gr.Row():
            # Left column: uploads, model selection, thresholds and action buttons.
            with gr.Column():
                # type="binary" hands each upload to the handler as raw bytes.
                input_images = gr.Files(
                    label="Upload Multiple Images",
                    file_types=["image"],
                    type="binary",
                    interactive=True
                )
                model_choices = gr.CheckboxGroup(
                    choices=AVAILABLE_MODEL_REPOS,
                    value=[DEFAULT_MODEL],
                    label="Select one or more Models"
                )

                threshold_slider = gr.Slider(0.0, 1.0, 0.35, 0.01, label="Base Threshold")

                partial_rows_toggle = gr.Checkbox(
                    value=True, label="Allow Partial Rows in Heatmap Grid"
                )

                custom_gen_threshold_slider = gr.Slider(
                    0.0, 1.0, 0.35, 0.01, label="Custom General Tag Threshold"
                )
                custom_char_threshold_slider = gr.Slider(
                    0.0, 1.0, 0.35, 0.01, label="Custom Character Tag Threshold"
                )

                run_button = gr.Button("Process", variant="primary")
                export_button = gr.Button("Export Results as CSV & JSON")
                csv_output = gr.File(label="CSV Export", file_count="single")
                json_output = gr.File(label="JSON Export", file_count="single")

            # Right column: textual results, heatmap grid and tag search.
            with gr.Column():
                # Holds the latest batch results so the search and export handlers can reuse them.
                result_data = gr.State([])

                all_results_text = gr.Textbox(
                    label="Consolidated Results (Textual)",
                    interactive=False,
                    lines=20
                )

                heatmap_grid_output = gr.Image(
                    label="Heatmap Grid",
                    interactive=False
                )

                search_input = gr.Textbox(
                    label="Search Tags",
                    placeholder="Enter text to filter detected tags"
                )
                search_results = gr.Label(label="Matching Tags")

                def update_search(search_text, rdata):
                    """Return the detected tags that contain *search_text*, ignoring case."""
                    if not rdata:
                        return "No results loaded."
                    candidates = build_searchable_tags(rdata)
                    hits = [tag for tag in candidates if search_text.lower() in tag.lower()]
                    if hits:
                        return ", ".join(hits)
                    return "No matching tags found."

                # Re-filter the tag list whenever the search text changes.
                search_input.change(
                    update_search,
                    inputs=[search_input, result_data],
                    outputs=search_results,
                )

        def run_inference_fn(files, model_repos, threshold, partial_rows, gen_thresh, char_thresh):
            """Decode the uploads, run the selected models and summarise the results.

            Returns:
                ``(results, text, grid_image)`` for the result state, the text
                box and the heatmap grid. The grid is only set when exactly one
                image was processed by exactly one model; otherwise it is
                cleared with ``None``.
            """
            try:
                if not files:
                    return [], "No images provided.", None
                if not model_repos:
                    return [], "No models selected.", None

                images = []
                for file_number, file in enumerate(files, start=1):
                    try:
                        images.append(Image.open(BytesIO(file)))
                    except Exception as e:
                        # Report the position only; the upload itself is raw bytes.
                        print(f"Error processing file #{file_number}: {e}")

                if not images:
                    return [], "No valid images were uploaded.", None

                results = predict_batch(images, model_repos, threshold, partial_rows, gen_thresh, char_thresh)

                lines = []
                for i, per_img in enumerate(results):
                    lines.append(f"Image {i + 1}:")
                    for pmodel in per_img:
                        lines.append(f"  Model: {pmodel['model_repo']} (Device: {pmodel['device_type']})")
                        lines.append(f"    Inference Time: {pmodel['inference_time_s']} s")
                        lines.append(f"    Tags: {pmodel['tags']}")
                        lines.append(f"    Caption: {pmodel['caption']}")
                    lines.append("-----")
                textual_output = "\n".join(lines)

                # Calling a component's update() inside a handler does not change
                # the UI; the new value has to be returned as one of the outputs.
                grid_image = None
                if len(results) == 1 and len(results[0]) == 1:
                    grid_bytes = results[0][0]["heatmap_grid"]
                    if grid_bytes:
                        grid_image = Image.open(BytesIO(grid_bytes))

                return results, textual_output, grid_image

            except Exception as e:
                error_message = f"An unexpected error occurred: {str(e)}"
                print(error_message)
                return [], error_message, None

        run_button.click(
            run_inference_fn,
            inputs=[
                input_images,
                model_choices,
                threshold_slider,
                partial_rows_toggle,
                custom_gen_threshold_slider,
                custom_char_threshold_slider
            ],
            outputs=[result_data, all_results_text, heatmap_grid_output]
        )

        def export_results(rdata):
            """Write the current results to CSV and JSON files for download.

            Returns:
                ``(csv_path, json_path)`` as strings, or ``(None, None)`` when
                there are no results yet.
            """
            import tempfile

            if not rdata:
                return None, None
            csv_content, json_content = export_csv_json(rdata)

            # A gr.File output is given a file path, so the exports are written
            # to a fresh temporary directory instead of in-memory buffers.
            export_dir = Path(tempfile.mkdtemp(prefix="wd_tagger_export_"))
            csv_path = export_dir / "results.csv"
            json_path = export_dir / "results.json"
            # Write bytes so the CSV's own line endings are kept untouched.
            csv_path.write_bytes(csv_content.encode("utf-8"))
            json_path.write_bytes(json_content.encode("utf-8"))
            return (str(csv_path), str(json_path))

        export_button.click(
            export_results,
            inputs=[result_data],
            outputs=[csv_output, json_output]
        )

    # Enable the request queue with a size limit of 10.
    demo.queue(max_size=10)
    # share=True asks Gradio for a public share link in addition to the local server.
    demo.launch(share=True, debug=False)

# -----------------------
# END: app.py code
# -----------------------

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://5613563c27ac0dee18.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
