<a href="https://colab.research.google.com/github/Vladus-CPU/IQ/blob/main/work2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:

!pip install --quiet colorcet diffusers gradio==4.25.0 hf-transfer huggingface-hub==0.14.0 matplotlib \
    numpy==1.23.5 opencv-contrib-python-headless pandas==2.0.0 Pillow==9.5.0 rich git+https://github.com/huggingface/pytorch-image-models.git@main#egg=timm \
    tokenizers torch>=2.1.0 torchvision transformers
# При бажанні можете клонувати код з репозиторію, якщо він опублікований на GitHub,
# або просто зкопіювати його у наступну комірку. Наведений нижче приклад показує,
# як скопіювати файли безпосередньо:
# !git clone https://github.com/user/wd-tagger-heatmap-example.git

[31mERROR: Cannot install gradio==4.25.0 and huggingface-hub==0.14.0 because these package versions have conflicting dependencies.[0m[31m
[0m[31mERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts[0m[31m
[0m

In [25]:
# Install required packages
# !pip install gradio
import gradio as gr

# Перевірка версії Gradio
print(gr.__version__)
import math
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from PIL import Image
from torch import Tensor, nn
import torch.nn.functional as F

import gradio as gr
import cv2
import colorcet as cc
from matplotlib.colors import LinearSegmentedColormap

import timm
from timm.data import create_transform, resolve_data_config
from timm.models import VisionTransformer
from torchvision import transforms as T

# ---------------------
# BEGIN: common.py code
# ---------------------

@dataclass
class Heatmap:
    label: str
    score: float
    image: Image.Image

@dataclass
class LabelData:
    names: list[str]
    rating: list[np.int64]
    general: list[np.int64]
    character: list[np.int64]

@dataclass
class ImageLabels:
    caption: str
    booru: str
    rating: dict[str, float]
    general: dict[str, float]
    character: dict[str, float]

@lru_cache(maxsize=5)
def load_labels_hf(repo_id: str, revision: Optional[str] = None, token: Optional[str] = None) -> LabelData:
    try:
        csv_path = hf_hub_download(
            repo_id=repo_id,
            filename="selected_tags.csv",
            revision=revision,
            token=token
        )
        csv_path = Path(csv_path).resolve()
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e

    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    tag_data = LabelData(
        names=df["name"].tolist(),
        rating=list(np.where(df["category"] == 9)[0]),
        general=list(np.where(df["category"] == 0)[0]),
        character=list(np.where(df["category"] == 4)[0]),
    )

    return tag_data

def mcut_threshold(probs: np.ndarray) -> float:
    """
    Maximum Cut Thresholding (MCut)
    """
    probs = probs[probs.argsort()[::-1]]
    diffs = probs[:-1] - probs[1:]
    idx = diffs.argmax()
    thresh = (probs[idx] + probs[idx + 1]) / 2
    return float(thresh)

def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    if image.mode not in ["RGB", "RGBA"]:
        image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image

def pil_pad_square(image: Image.Image, fill: tuple[int, int, int] = (255, 255, 255)) -> Image.Image:
    w, h = image.size
    px = max(image.size)
    canvas = Image.new("RGB", (px, px), fill)
    canvas.paste(image, ((px - w) // 2, (px - h) // 2))
    return canvas

def preprocess_image(image: Image.Image, size_px: int | tuple[int, int], upscale: bool = True) -> Image.Image:
    if isinstance(size_px, int):
        size_px = (size_px, size_px)
    image = pil_ensure_rgb(image)
    image = pil_pad_square(image)
    if image.size[0] < size_px[0] or image.size[1] < size_px[1]:
        if not upscale:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        image = image.resize(size_px, Image.LANCZOS)
    if image.size[0] > size_px[0] or image.size[1] > size_px[1]:
        image.thumbnail(size_px, Image.BICUBIC)
    return image

def pil_make_grid(
    images: list[Image.Image],
    max_cols: int = 8,
    padding: int = 4,
    bg_color: tuple[int, int, int] = (40, 42, 54),
    partial_rows: bool = True,
) -> Image.Image:
    n_cols = min(math.floor(math.sqrt(len(images))), max_cols)
    n_rows = math.ceil(len(images) / n_cols)
    if n_cols * n_rows > len(images) and not partial_rows:
        n_rows -= 1
    image_width, image_height = images[0].size
    canvas_width = ((image_width + padding) * n_cols) + padding
    canvas_height = ((image_height + padding) * n_rows) + padding
    canvas = Image.new("RGB", (canvas_width, canvas_height), bg_color)
    for i, img in enumerate(images):
        x = (i % n_cols) * (image_width + padding) + padding
        y = (i // n_cols) * (image_height + padding) + padding
        canvas.paste(img, (x, y))
    return canvas

# ---------------------
# END: common.py code
# ---------------------

# -----------------------
# BEGIN: model.py code
# -----------------------

class RGBtoBGR(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        if x.ndim == 4:
            return x[:, [2, 1, 0], :, :]
        return x[[2, 1, 0], :, :]

model_cache: dict[str, VisionTransformer] = {}
transform_cache: dict[str, T.Compose] = {}

def model_device(model: nn.Module) -> torch.device:
    return next(model.parameters()).device

def load_model(repo_id: str) -> VisionTransformer:
    global model_cache
    if model_cache.get(repo_id, None) is None:
        model_cache[repo_id] = timm.create_model("hf-hub:" + repo_id, pretrained=True).eval().to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    return model_cache[repo_id]

def load_model_and_transform(repo_id: str) -> tuple[VisionTransformer, T.Compose]:
    global transform_cache, model_cache
    if model_cache.get(repo_id, None) is None:
        model_cache[repo_id] = timm.create_model("hf-hub:" + repo_id, pretrained=True).eval()
    model = model_cache[repo_id]
    if transform_cache.get(repo_id, None) is None:
        transforms = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
        transform_cache[repo_id] = T.Compose(transforms.transforms + [RGBtoBGR()])
    return model, transform_cache[repo_id]

def get_tags(
    probs: Tensor,
    labels: LabelData,
    gen_threshold: float,
    char_threshold: float,
):
    probs = list(zip(labels.names, probs.numpy()))
    rating_labels = dict([probs[i] for i in labels.rating])
    gen_labels = [probs[i] for i in labels.general]
    gen_labels = dict([x for x in gen_labels if x[1] > gen_threshold])
    gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))
    char_labels = [probs[i] for i in labels.character]
    char_labels = dict([x for x in char_labels if x[1] > char_threshold])
    char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))
    combined_names = [x for x in gen_labels]
    combined_names.extend([x for x in char_labels])
    caption = ", ".join(combined_names).replace("(", "\\(").replace(")", "\\)")
    booru = caption.replace("_", " ")
    return caption, booru, rating_labels, char_labels, gen_labels

@torch.no_grad()
def render_heatmap(
    image: Tensor,
    gradients: Tensor,
    image_feats: Tensor,
    image_probs: Tensor,
    image_labels: list[str],
    cmap: LinearSegmentedColormap = cc.m_linear_bmy_10_95_c71,
    pos_embed_dim: int = 784,
    image_size: tuple[int, int] = (448, 448),
    font_args: dict = {
        "fontFace": cv2.FONT_HERSHEY_SIMPLEX,
        "fontScale": 1,
        "color": (255, 255, 255),
        "thickness": 2,
        "lineType": cv2.LINE_AA,
    },
    partial_rows: bool = True,
) -> tuple[list[Heatmap], Image.Image]:
    image_hmaps = gradients.mean(2, keepdim=True).mul(image_feats.unsqueeze(0)).squeeze()
    hmap_dim = int(math.sqrt(image_hmaps.mean(-1).numel() / len(image_labels)))
    image_hmaps = image_hmaps.mean(-1).reshape(len(image_labels), -1)
    image_hmaps = image_hmaps[..., -hmap_dim**2:]
    image_hmaps = image_hmaps.reshape(len(image_labels), hmap_dim, hmap_dim)
    image_hmaps = image_hmaps.max(torch.zeros_like(image_hmaps))
    image_hmaps /= image_hmaps.reshape(image_hmaps.shape[0], -1).max(-1)[0].unsqueeze(-1).unsqueeze(-1)
    image_hmaps = torch.stack([(x - x.min()) / (x.max() - x.min()) for x in image_hmaps]).unsqueeze(1)
    image_hmaps = F.interpolate(image_hmaps, size=image_size, mode="bilinear").squeeze(1)
    hmap_imgs = []
    for tag, hmap, score in zip(image_labels, image_hmaps, image_probs.cpu()):
        image_pixels = image.add(1).mul(127.5).squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8)
        hmap_pixels = cmap(hmap.cpu().numpy(), bytes=True)[:, :, :3]
        hmap_cv2 = cv2.cvtColor(hmap_pixels, cv2.COLOR_RGB2BGR)
        hmap_image = cv2.addWeighted(image_pixels, 0.5, hmap_cv2, 0.5, 0)
        if tag is not None:
            cv2.putText(hmap_image, tag, (10, 30), **font_args)
            cv2.putText(hmap_image, f"{score:.3f}", (10, 60), **font_args)
        hmap_pil = Image.fromarray(cv2.cvtColor(hmap_image, cv2.COLOR_BGR2RGB))
        hmap_imgs.append(Heatmap(tag, score.item(), hmap_pil))
    hmap_imgs = sorted(hmap_imgs, key=lambda x: x.score, reverse=True)
    hmap_grid = pil_make_grid([x.image for x in hmap_imgs], partial_rows=partial_rows)
    return hmap_imgs, hmap_grid

def process_heatmap(
    model: VisionTransformer,
    image: Tensor,
    labels: LabelData,
    threshold: float = 0.5,
    partial_rows: bool = True,
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    with torch.set_grad_enabled(True):
        features = model.forward_features(image.to(device))
        probs = model.forward_head(features)
        probs = torch.sigmoid(probs).squeeze(0)
        probs_mask = probs > threshold
        heatmap_probs = probs[probs_mask]
        label_indices = torch.nonzero(probs_mask, as_tuple=False).squeeze(1)
        image_labels = [labels.names[label_indices[i]] for i in range(len(label_indices))]
        eye = torch.eye(heatmap_probs.shape[0], device=device)
        grads = torch.autograd.grad(
            outputs=heatmap_probs,
            inputs=features,
            grad_outputs=eye,
            is_grads_batched=True,
            retain_graph=True,
        )
        grads = grads[0].detach().requires_grad_(False)[:, 0, :, :].unsqueeze(1)
    with torch.set_grad_enabled(False):
        hmap_imgs, hmap_grid = render_heatmap(
            image=image,
            gradients=grads,
            image_feats=features,
            image_probs=heatmap_probs,
            image_labels=image_labels,
            partial_rows=partial_rows,
        )
        caption, booru, ratings, character, general = get_tags(
            probs=probs.cpu(),
            labels=labels,
            gen_threshold=threshold,
            char_threshold=threshold,
        )
        image_labels_res = ImageLabels(caption, booru, ratings, general, character)
    return hmap_imgs, hmap_grid, image_labels_res

# -----------------------
# END: model.py code
# -----------------------

# -----------------------
# BEGIN: app.py code
# -----------------------
from os import getenv

TITLE = "WD Tagger Heatmap For More Models"
DESCRIPTION = """WD Tagger v3 Heatmap Generator."""
HF_TOKEN = getenv("HF_TOKEN", None)

AVAILABLE_MODEL_REPOS = [
    'SmilingWolf/wd-convnext-tagger-v3',
    'SmilingWolf/wd-swinv2-tagger-v3',
    'SmilingWolf/wd-vit-tagger-v3',
    'SmilingWolf/wd-vit-large-tagger-v3',
    "SmilingWolf/wd-eva02-large-tagger-v3",
]
MODEL_REPO = "SmilingWolf/wd-vit-tagger-v3"
WORK_DIR = Path(".").resolve()
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]

# Example images (can add paths to example images)
example_images = []

def predict(image: Image.Image, model_repo: str, threshold: float = 0.5):
    model, transform = load_model_and_transform(model_repo)
    labels: LabelData = load_labels_hf(model_repo)
    image = preprocess_image(image, (448, 448))
    image = transform(image).unsqueeze(0)
    heatmaps, heatmap_grid, image_labels = process_heatmap(model, image, labels, threshold)
    heatmap_images = [(x.image, x.label) for x in heatmaps]
    return (
        heatmap_images,
        heatmap_grid,
        image_labels.caption,
        image_labels.booru,
        image_labels.rating,
        image_labels.character,
        image_labels.general,
    )

# Enhanced CSS with scrollable functionality for Tags and Caption
css = """
/* Styling for Caption Textbox */
#caption_box {
    max-height: 300px;             /* Sets a maximum height */
    overflow-y: auto;              /* Adds vertical scroll if content exceeds max-height */
    padding: 10px;                 /* Adds padding for better readability */
    background-color: #3C3C3C;     /* Dark background for contrast */
    color: #FFFFFF;                /* White text for readability */
    border-radius: 1px;            /* Rounded corners */
}

/* Styling for Tags Textbox */
#tags_box {
    max-height: 300px;             /* Sets a maximum height */
    overflow-y: auto;              /* Adds vertical scroll if content exceeds max-height */
    padding: 10px;                 /* Adds padding for better readability */
    background-color: #3C3C3C;     /* Dark background for contrast */
    color: #FFFFFF;                /* White text for readability */
    border-radius: 1px;            /* Rounded corners */
}

/* Optional: Customize Scrollbar Appearance */
#caption_box::-webkit-scrollbar,
#tags_box::-webkit-scrollbar {
    width: 20px;                     /* Width of the scrollbar */
}

#caption_box::-webkit-scrollbar-track,
#tags_box::-webkit-scrollbar-track {
    background: #2D2D2D;            /* Track color */
    border-radius: 20px;
}

#caption_box::-webkit-scrollbar-thumb,
#tags_box::-webkit-scrollbar-thumb {
    background-color: #4CAF50;      /* Scrollbar thumb color */
    border-radius: 4px;
    border: 3px solid #2D2D2D;      /* Adds a border around the thumb */
}

/* Responsive Design Adjustments */
@media (max-width: 768px) {
    #caption_box, #tags_box {
        max-height: 100px;             /* Adjust max-height for smaller screens */
    }
}
"""

with gr.Blocks(theme="default", analytics_enabled=False, title=TITLE, css=css) as demo:
    with gr.Row(equal_height=False):
        with gr.Column(min_width=720):
            with gr.Group():
                img_input = gr.Image(
                    label="Input",
                    type="pil",
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                )
            with gr.Group():
                with gr.Row():
                    threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.35,
                        step=0.01,
                        label="Tag Threshold",
                        scale=5,
                        elem_id="threshold",
                    )
                    model_to_use = gr.Dropdown(
                        choices=AVAILABLE_MODEL_REPOS,
                        value=MODEL_REPO,
                        label="Model Repository"
                    )
            with gr.Row():
                clear = gr.ClearButton(
                    components=[img_input, heatmap_gallery, heatmap_grid, caption, tags, rating, character, general],
                    variant="secondary",
                    size="lg"
                )
                submit = gr.Button(value="Submit", variant="primary", size="lg")

        with gr.Column(min_width=720):
            with gr.Tab(label="Heatmaps"):
                heatmap_gallery = gr.Gallery(columns=3, show_label=False)
            with gr.Tab(label="Grid"):
                heatmap_grid = gr.Image(show_label=False)
            with gr.Tab(label="Tags"):
                with gr.Group():
                    caption = gr.Textbox(
                        label="Caption",
                        show_copy_button=True,
                        lines=20,                # Sets the number of visible lines
                        elem_id="caption_box"  # Assigned for CSS targeting
                    )
                    tags = gr.Textbox(
                        label="Tags",
                        show_copy_button=True,
                        lines=20,              # Sets the number of visible lines
                        elem_id="tags_box"   # Assigned for CSS targeting
                    )
                with gr.Group():
                    rating = gr.Label(label="Rating")
                with gr.Group():
                    character = gr.Label(label="Character")
                with gr.Group():
                    general = gr.Label(label="General")

    with gr.Row():
        # Якщо у вас є приклади зображень, вкажіть їх у масиві example_images
        example_inputs = [[img, MODEL_REPO, 0.35] for img in example_images]
        examples = gr.Examples(
            examples=example_inputs,
            inputs=[img_input, model_to_use, threshold],
        )

    # Define clear button functionality
    clear.click(
        lambda: None,
        outputs=[img_input, heatmap_gallery, heatmap_grid, caption, tags, rating, character, general],
    )

    # Define submit button functionality
    submit.click(
        predict,
        inputs=[img_input, model_to_use, threshold],
        outputs=[heatmap_gallery, heatmap_grid, caption, tags, rating, character, general],
        api_name="predict",
    )

# Launch the app
demo.queue(max_size=10)
demo.launch(server_name="0.0.0.0", server_port=1200, debug=False)

5.9.1
Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://10e43b84272cd74416.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


