<a href="https://colab.research.google.com/github/Vladus-CPU/IQ/blob/main/work2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ----------------------------------------------------------------------------
# app.py (Extended version with artist style detection from "Laxhar/noob-wiki")
# ----------------------------------------------------------------------------
!pip install --quiet colorcet diffusers gradio hf-transfer huggingface-hub==0.14.0 matplotlib \
    numpy==1.23.5 opencv-contrib-python-headless pandas Pillow rich git+https://github.com/huggingface/pytorch-image-models.git@main#egg=timm
!pip install tokenizers torch torchvision transformers
!pip install numpy pandas torch torchvision timm datasets Pillow huggingface_hub opencv-python colorcet gradio
!pip install numpy pandas torch torchvision timm gradio matplotlib colorcet pillow opencv-python datasets huggingface-hub
!pip install pandas
!pip install datasets huggingface_hub
!pip install scikit-learn seaborn
!pip install --upgrade pip
import math
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn, Tensor

import gradio as gr
import cv2
import colorcet as cc
from matplotlib.colors import LinearSegmentedColormap

import timm
from timm.data import create_transform, resolve_data_config
from timm.models import VisionTransformer
from torchvision import transforms as T

from PIL import Image
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError

# Additional dataset import for artist style detection
import datasets

# ----------------------------------------------------------------------------
# BEGIN: Extended data classes and label loading
# ----------------------------------------------------------------------------

@dataclass
class LabelData:
    """
    Tag vocabulary of a tagger model, split by tag category.

    The three index lists hold row positions into ``names``; the loader in
    this file fills them from the ``category`` column of ``selected_tags.csv``.
    """
    names: list[str]  # tag names, in CSV row order
    rating: list[np.int64]  # rows whose category is 9
    general: list[np.int64]  # rows whose category is 0
    character: list[np.int64]  # rows whose category is 4


@dataclass
class LabelDataExtended:
    """
    Label data with the same fields as ``LabelData`` plus artist-style indices.
    """
    names: list[str]  # tag names, in CSV row order
    rating: list[np.int64]  # indices into `names` for rating tags
    general: list[np.int64]  # indices into `names` for general tags
    character: list[np.int64]  # indices into `names` for character tags
    # Indices treated as artist-style tags.
    # NOTE(review): the loader in this file derives these from row numbers of a
    # separate dataset -- confirm they share the index space of `names`.
    artist_style: list[np.int64]


@dataclass
class Heatmap:
    """
    One rendered heatmap overlay for a single tag.
    """
    label: str  # tag name the overlay was rendered for
    score: float  # score of that tag, stored as a plain Python float
    image: Image.Image  # overlay picture (input blended with the colour-mapped heatmap)


@dataclass
class ImageLabels:
    """
    Tagging result for one image: two caption strings plus per-category scores.
    """
    caption: str  # comma-separated tag names with parentheses backslash-escaped
    booru: str  # `caption` with underscores replaced by spaces
    rating: dict[str, float]  # rating tag name -> score
    general: dict[str, float]  # general tag name -> score
    character: dict[str, float]  # character tag name -> score
    # Artist-style scores are not stored here; a field such as the following
    # could be added if they should travel with the other labels:
    # artist_style: dict[str, float]


@lru_cache(maxsize=5)
def load_labels_hf(repo_id: str, revision: Optional[str] = None, token: Optional[str] = None) -> LabelData:
    """
    Download ``selected_tags.csv`` from a Hugging Face repo and build the label table.

    Args:
        repo_id: Hub repository that hosts ``selected_tags.csv``.
        revision: Optional revision forwarded to the download call.
        token: Optional access token forwarded to the download call.

    Returns:
        A ``LabelData`` whose index lists are the CSV row positions with
        category 9 (rating), 0 (general) and 4 (character).

    Raises:
        FileNotFoundError: If the Hub reports an HTTP error for the download.

    Results are memoised per argument combination (up to five entries).
    """
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id,
            filename="selected_tags.csv",
            revision=revision,
            token=token,
        )
    except HfHubHTTPError as err:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from err

    tags_file = Path(downloaded).resolve()
    table: pd.DataFrame = pd.read_csv(tags_file, usecols=["name", "category"])
    categories = table["category"]

    def rows_in(category_id: int) -> list:
        # Row positions whose category column equals `category_id`.
        return list(np.where(categories == category_id)[0])

    return LabelData(
        names=table["name"].tolist(),
        rating=rows_in(9),
        general=rows_in(0),
        character=rows_in(4),
    )


@lru_cache(maxsize=5)
def load_labels_ext(repo_id: str, revision: Optional[str] = None, token: Optional[str] = None) -> LabelDataExtended:
    """
    Load the base tagger labels and attach a list of "artist style" indices.

    The base labels come from ``load_labels_hf``. The artist indices are the
    row numbers of the "Laxhar/noob-wiki" ``train`` split whose ``category``
    field equals 5.

    Args:
        repo_id: Hub repository of the tagger model (source of the base labels).
        revision: Optional revision forwarded to ``load_labels_hf``.
        token: Optional access token forwarded to ``load_labels_hf``.

    Returns:
        A ``LabelDataExtended``; ``artist_style`` is a ``list[np.int64]`` like
        the other index fields (it may be empty).

    Raises:
        FileNotFoundError: Propagated from ``load_labels_hf``.

    Results are memoised per argument combination (up to five entries).
    """
    # Load the base tag data from the WD Tagger
    base_data = load_labels_hf(repo_id=repo_id, revision=revision, token=token)

    # Illustrative example: adjust to the dataset's actual format.
    artist_ds = datasets.load_dataset("Laxhar/noob-wiki", split="train")

    # NOTE(review): these are row numbers inside the artist dataset, not
    # positions in `base_data.names` / the model output. Downstream code uses
    # them to index the model's probabilities, so they should be mapped onto
    # the tagger's label space before the results are trusted -- confirm the
    # dataset schema ("category" == 5 is an assumption) and add that mapping.
    artist_rows = [i for i, row in enumerate(artist_ds) if row.get("category") == 5]

    return LabelDataExtended(
        names=base_data.names,              # existing WD names
        rating=base_data.rating,            # rating indices
        general=base_data.general,          # general indices
        character=base_data.character,      # character indices
        # Same container/element type as the sibling fields (the dataclass
        # declares list[np.int64]); previously this was a bare ndarray.
        artist_style=list(np.asarray(artist_rows, dtype=np.int64)),
    )

# ----------------------------------------------------------------------------
# BEGIN: Utility functions
# ----------------------------------------------------------------------------

def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """
    Return the image in RGB mode, flattening any alpha channel onto white.

    Images in other modes are first converted to RGBA when they carry a
    "transparency" entry in ``image.info`` and to RGB otherwise.
    """
    if image.mode not in ("RGB", "RGBA"):
        has_transparency = "transparency" in image.info
        image = image.convert("RGBA" if has_transparency else "RGB")

    if image.mode != "RGBA":
        return image

    # Composite over an opaque white background, then drop the alpha channel.
    white = Image.new("RGBA", image.size, (255, 255, 255))
    white.alpha_composite(image)
    return white.convert("RGB")


def pil_pad_square(image: Image.Image, fill: tuple[int, int, int] = (255, 255, 255)) -> Image.Image:
    """
    Centre the image on a new square RGB canvas filled with `fill`.

    The canvas side equals the longer edge of the input.
    """
    width, height = image.size
    side = max(width, height)
    offset = ((side - width) // 2, (side - height) // 2)

    square = Image.new("RGB", (side, side), fill)
    square.paste(image, offset)
    return square


def preprocess_image(image: Image.Image, size_px: int | tuple[int, int], upscale: bool = True) -> Image.Image:
    """
    Convert to RGB, pad to a square and bring the image to `size_px`.

    Args:
        image: Input picture.
        size_px: Target size; a single int means a square of that side.
        upscale: Whether an image smaller than the target may be enlarged.

    Returns:
        The processed image (enlarged with LANCZOS, or shrunk via
        ``thumbnail`` with BICUBIC).

    Raises:
        ValueError: If the padded image is smaller than the target and
            `upscale` is false.
    """
    target = (size_px, size_px) if isinstance(size_px, int) else size_px
    image = pil_pad_square(pil_ensure_rgb(image))

    too_small = image.size[0] < target[0] or image.size[1] < target[1]
    if too_small and not upscale:
        raise ValueError("Image is smaller than target size, and upscaling is disabled")
    if too_small:
        image = image.resize(target, Image.LANCZOS)

    too_large = image.size[0] > target[0] or image.size[1] > target[1]
    if too_large:
        # thumbnail() shrinks in place
        image.thumbnail(target, Image.BICUBIC)
    return image


def pil_make_grid(
    images: list[Image.Image],
    max_cols: int = 8,
    padding: int = 4,
    bg_color: tuple[int, int, int] = (40, 42, 54),
    partial_rows: bool = True,
) -> Image.Image:
    """
    Creates a grid of images (like a contact sheet).

    Args:
        images: Tiles to lay out; every tile is placed on a cell sized like
            the first image.
        max_cols: Upper bound on the number of columns (values below 1 are
            treated as 1).
        padding: Gap in pixels between tiles and around the border.
        bg_color: Canvas background colour.
        partial_rows: If false, a trailing row that would not be full is
            dropped together with the images in it.

    Returns:
        The composed RGB canvas. An empty `images` list yields a minimal
        background-coloured canvas instead of raising.
    """
    if not images:
        # Nothing to lay out: avoid the division by zero / images[0] lookup.
        edge = max(padding, 1)
        return Image.new("RGB", (edge, edge), bg_color)

    # Roughly square layout, capped at max_cols but never fewer than 1 column.
    n_cols = max(1, min(math.isqrt(len(images)), max_cols))
    n_rows = math.ceil(len(images) / n_cols)
    if n_cols * n_rows > len(images) and not partial_rows:
        n_rows -= 1

    image_width, image_height = images[0].size
    canvas_width = ((image_width + padding) * n_cols) + padding
    canvas_height = ((image_height + padding) * n_rows) + padding
    canvas = Image.new("RGB", (canvas_width, canvas_height), bg_color)

    # Only tiles that have a cell on the canvas are pasted.
    for i, img in enumerate(images[: n_cols * n_rows]):
        x = (i % n_cols) * (image_width + padding) + padding
        y = (i // n_cols) * (image_height + padding) + padding
        canvas.paste(img, (x, y))
    return canvas


def mcut_threshold(probs: np.ndarray) -> float:
    """
    Maximum Cut Thresholding (MCut)

    Sorts the probabilities in descending order, finds the largest gap
    between two neighbours and returns the midpoint of that gap.

    Args:
        probs: Probabilities; any array-like is accepted and flattened.

    Returns:
        The threshold as a Python float.

    Raises:
        ValueError: If fewer than two probabilities are given (no gap exists).
    """
    ordered = np.sort(np.asarray(probs).ravel())[::-1]
    if ordered.size < 2:
        raise ValueError("mcut_threshold needs at least two probabilities to find a cut")

    gaps = ordered[:-1] - ordered[1:]
    cut = int(gaps.argmax())
    return float((ordered[cut] + ordered[cut + 1]) / 2)

# ----------------------------------------------------------------------------
# BEGIN: Model creation and caching
# ----------------------------------------------------------------------------

class RGBtoBGR(nn.Module):
    """
    Reverses the order of the first three channels (RGB -> BGR).

    A 4-D input is indexed on dimension 1; any other input is indexed on
    dimension 0.
    """
    def forward(self, x: Tensor) -> Tensor:
        bgr_order = [2, 1, 0]
        batched = x.ndim == 4
        return x[:, bgr_order, :, :] if batched else x[bgr_order, :, :]


# Module-level caches keyed by Hub repo id, so each model and its transform
# pipeline are created once per process.
model_cache: dict[str, VisionTransformer] = {}
transform_cache: dict[str, T.Compose] = {}


def model_device(model: nn.Module) -> torch.device:
    """
    Return the device of the module's first parameter.

    Modules without parameters fall back to their first buffer, and modules
    with neither report the CPU device instead of leaking ``StopIteration``.
    """
    first = next(model.parameters(), None)
    if first is None:
        first = next(model.buffers(), None)
    if first is None:
        return torch.device("cpu")
    return first.device


def load_model(repo_id: str) -> VisionTransformer:
    """
    Return the cached model for `repo_id`, creating it on first use.

    A new model is built from the "hf-hub:" timm reference with pretrained
    weights, switched to eval mode and moved to CUDA when available (CPU
    otherwise) before it is stored in ``model_cache``.
    """
    cached = model_cache.get(repo_id)
    if cached is None:
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        cached = timm.create_model("hf-hub:" + repo_id, pretrained=True).eval().to(target)
        model_cache[repo_id] = cached
    return cached


def load_model_and_transform(repo_id: str) -> tuple[VisionTransformer, T.Compose]:
    """
    Return the cached model for `repo_id` together with its input transform.

    The model is created in eval mode on first use (no device move happens
    here). The transform is built from the model's pretrained config, with an
    ``RGBtoBGR`` step appended as the final stage.
    """
    if model_cache.get(repo_id) is None:
        model_cache[repo_id] = timm.create_model(f"hf-hub:{repo_id}", pretrained=True).eval()
    model = model_cache[repo_id]

    if transform_cache.get(repo_id) is None:
        data_cfg = resolve_data_config(model.pretrained_cfg, model=model)
        base_pipeline = create_transform(**data_cfg)
        transform_cache[repo_id] = T.Compose([*base_pipeline.transforms, RGBtoBGR()])

    return model, transform_cache[repo_id]

# ----------------------------------------------------------------------------
# BEGIN: Tag extraction with extended functionality
# ----------------------------------------------------------------------------

def get_tags_extended(
    probs: Tensor,
    labels: LabelDataExtended,
    gen_threshold: float,
    char_threshold: float,
    artist_threshold: float = 0.5,
):
    """
    Extended function to retrieve rating, general, character, and artist styles.

    Args:
        probs: Per-label scores, paired positionally with ``labels.names``.
            The tensor is detached and moved to CPU here, so it may still
            require grad or live on another device.
        labels: Label table providing names and per-category index lists.
        gen_threshold: Minimum score (exclusive) for general tags.
        char_threshold: Minimum score (exclusive) for character tags.
        artist_threshold: Minimum score (exclusive) for artist-style tags.

    Returns:
        A tuple ``(caption, booru, rating_labels, char_labels, gen_labels,
        style_labels)``. The dicts map tag name to score; all except
        ``rating_labels`` are filtered by their threshold and sorted by
        descending score. ``caption`` joins the general, character and style
        names with ", " and backslash-escapes parentheses; ``booru`` is the
        caption with underscores replaced by spaces.
    """
    # detach() first: Tensor.numpy() refuses tensors that require grad.
    scores = probs.detach().cpu().numpy()
    probs_list = list(zip(labels.names, scores))

    def _select(indices, threshold: float) -> dict:
        # Keep entries above the threshold, highest score first.
        picked = {
            name: score
            for name, score in (probs_list[i] for i in indices)
            if score > threshold
        }
        return dict(sorted(picked.items(), key=lambda item: item[1], reverse=True))

    # Ratings are reported unfiltered.
    rating_labels = dict(probs_list[i] for i in labels.rating)
    gen_labels = _select(labels.general, gen_threshold)
    char_labels = _select(labels.character, char_threshold)

    # Artist-style indices may exceed the label table, so out-of-range ones
    # are skipped.
    # NOTE(review): confirm these indices refer to positions in `labels.names`.
    in_range = [i for i in labels.artist_style if i < len(probs_list)]
    style_labels = _select(in_range, artist_threshold)

    # Combine names for caption
    combined_names = [*gen_labels, *char_labels, *style_labels]

    # "\\(" is the same text the old "\(" literal produced, without relying on
    # an invalid escape sequence.
    caption = ", ".join(combined_names).replace("(", "\\(").replace(")", "\\)")
    booru = caption.replace("_", " ")

    return caption, booru, rating_labels, char_labels, gen_labels, style_labels


@torch.no_grad()
def render_heatmap(
    image: Tensor,
    gradients: Tensor,
    image_feats: Tensor,
    image_probs: Tensor,
    image_labels: list[str],
    cmap: LinearSegmentedColormap = cc.m_linear_bmy_10_95_c71,
    image_size: tuple[int, int] = (448, 448),
    font_args: Optional[dict] = None,
    partial_rows: bool = True,
) -> tuple[list[Heatmap], Image.Image]:
    """
    Renders one heatmap overlay per label and a contact sheet of all overlays.

    Args:
        image: Model input tensor used as the overlay background.
            NOTE(review): the pixel conversion assumes values in [-1, 1] and
            the channel order cv2 expects -- confirm against the transform.
        gradients: Per-label gradients with respect to `image_feats`.
        image_feats: Features the gradients were taken against.
        image_probs: Score of each label, in the same order as `image_labels`.
        image_labels: Tag names to render; ``None`` entries get no caption.
        cmap: Colour map applied to the normalised heatmaps.
        image_size: Size the heatmaps are interpolated to; it has to match the
            spatial size of `image` for the blend to work.
        font_args: Keyword arguments for ``cv2.putText``; ``None`` selects the
            built-in white Hershey-Simplex style.
        partial_rows: Forwarded to ``pil_make_grid``.

    Returns:
        ``(heatmaps, grid)``: the ``Heatmap`` items sorted by descending score
        and the grid image. With no labels the result is an empty list and a
        blank canvas of `image_size`.
    """
    if not image_labels:
        # Nothing to render; avoid dividing by len(image_labels) below.
        # image_size is passed to F.interpolate as-is (treated here as
        # (height, width)); PIL wants (width, height).
        blank = Image.new("RGB", (image_size[1], image_size[0]), (40, 42, 54))
        return [], blank

    if font_args is None:
        # Built per call so the default is never a shared mutable object.
        font_args = {
            "fontFace": cv2.FONT_HERSHEY_SIMPLEX,
            "fontScale": 1,
            "color": (255, 255, 255),
            "thickness": 2,
            "lineType": cv2.LINE_AA,
        }

    n_labels = len(image_labels)

    # Weight the features by the token-averaged gradients, then reduce the
    # last dimension to get one value per token and label.
    image_hmaps = gradients.mean(2, keepdim=True).mul(image_feats.unsqueeze(0)).squeeze()
    token_scores = image_hmaps.mean(-1)
    hmap_dim = int(math.sqrt(token_scores.numel() / n_labels))
    image_hmaps = token_scores.reshape(n_labels, -1)
    # Keep only the trailing hmap_dim**2 entries so they form a square map.
    image_hmaps = image_hmaps[..., -hmap_dim**2:]
    image_hmaps = image_hmaps.reshape(n_labels, hmap_dim, hmap_dim)
    image_hmaps = image_hmaps.clamp_min(0)

    # Normalise each map to [0, 1]. The denominators are floored at the
    # smallest normal float so an all-zero or constant map becomes 0, not NaN.
    tiny = torch.finfo(image_hmaps.dtype).tiny
    peak = image_hmaps.reshape(n_labels, -1).max(-1)[0].clamp_min(tiny)
    image_hmaps = image_hmaps / peak.unsqueeze(-1).unsqueeze(-1)
    image_hmaps = torch.stack(
        [(x - x.min()) / (x.max() - x.min()).clamp_min(tiny) for x in image_hmaps]
    ).unsqueeze(1)
    image_hmaps = F.interpolate(image_hmaps, size=image_size, mode="bilinear").squeeze(1)

    # Convert the input to 0..255 pixels once; it is identical for every tag.
    # The clamp keeps out-of-range values from wrapping in the uint8 cast.
    image_pixels = (
        image.add(1).mul(127.5).clamp(0, 255).squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    )

    hmap_imgs = []
    for tag, hmap, score in zip(image_labels, image_hmaps, image_probs.cpu()):
        hmap_pixels = cmap(hmap.cpu().numpy(), bytes=True)[:, :, :3]
        hmap_cv2 = cv2.cvtColor(hmap_pixels, cv2.COLOR_RGB2BGR)
        hmap_image = cv2.addWeighted(image_pixels, 0.5, hmap_cv2, 0.5, 0)
        if tag is not None:
            cv2.putText(hmap_image, tag, (10, 30), **font_args)
            cv2.putText(hmap_image, f"{score:.3f}", (10, 60), **font_args)
        hmap_pil = Image.fromarray(cv2.cvtColor(hmap_image, cv2.COLOR_BGR2RGB))
        hmap_imgs.append(Heatmap(tag, score.item(), hmap_pil))

    hmap_imgs.sort(key=lambda x: x.score, reverse=True)
    hmap_grid = pil_make_grid([x.image for x in hmap_imgs], partial_rows=partial_rows)
    return hmap_imgs, hmap_grid


def process_heatmap_extended(
    model: VisionTransformer,
    image: Tensor,
    labels: LabelDataExtended,
    threshold: float = 0.5,
    partial_rows: bool = True,
):
    """
    Core processing with the extended label data which includes artist style detection.

    Runs the model on `image`, renders a gradient-based heatmap for every tag
    scoring above `threshold`, and extracts the tag dictionaries.

    Args:
        model: Tagger model; it is moved to CUDA when available, else CPU.
        image: Preprocessed input tensor; the leading batch dimension is
            squeezed away, so a single image is expected.
        labels: Extended label table.
        threshold: Score cut-off used for heatmaps and for all tag categories.
        partial_rows: Forwarded to ``render_heatmap``.

    Returns:
        ``(heatmaps, heatmap_grid, image_labels, style_labels)``. When no tag
        exceeds the threshold, ``heatmaps`` is empty and ``heatmap_grid`` is a
        blank canvas sized like the input.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    with torch.set_grad_enabled(True):
        features = model.forward_features(image.to(device))
        probs = model.forward_head(features)
        probs = torch.sigmoid(probs).squeeze(0)

        # Select all tags over the threshold. Outputs without a name are
        # masked out so labels and heatmap_probs always stay aligned.
        probs_mask = probs > threshold
        probs_mask[len(labels.names):] = False
        heatmap_probs = probs[probs_mask]
        label_indices = torch.nonzero(probs_mask, as_tuple=False).squeeze(1)
        image_labels = [labels.names[idx] for idx in label_indices.tolist()]

        # Calculate gradients: one backward pass per selected tag, batched
        # through an identity matrix. Skipped when nothing was selected.
        grads = None
        if image_labels:
            eye = torch.eye(heatmap_probs.shape[0], device=device)
            grads = torch.autograd.grad(
                outputs=heatmap_probs,
                inputs=features,
                grad_outputs=eye,
                is_grads_batched=True,
                retain_graph=True,
            )[0]
            grads = grads.detach()[:, 0, :, :].unsqueeze(1)

    with torch.no_grad():
        if grads is None:
            # No tag passed the threshold: return empty heatmaps and a blank
            # grid instead of feeding empty tensors into the renderer.
            hmap_imgs = []
            hmap_grid = Image.new("RGB", (image.shape[-1], image.shape[-2]), (40, 42, 54))
        else:
            # Render heatmaps
            hmap_imgs, hmap_grid = render_heatmap(
                image=image,
                gradients=grads,
                image_feats=features,
                image_probs=heatmap_probs,
                image_labels=image_labels,
                partial_rows=partial_rows,
            )

        # Get extended tags (including artist styles). detach() is required:
        # on CPU `.cpu()` returns the same grad-tracking tensor, which cannot
        # be converted to numpy.
        caption, booru, ratings, character, general, style_labels = get_tags_extended(
            probs=probs.detach().cpu(),
            labels=labels,
            gen_threshold=threshold,
            char_threshold=threshold,
            artist_threshold=threshold,  # reuse same threshold or set a new one
        )

        # Create a result object for simpler return
        image_labels_res = ImageLabels(
            caption=caption,
            booru=booru,
            rating=ratings,
            general=general,
            character=character,
        )

    # style_labels is returned separately so it can be displayed on its own
    return hmap_imgs, hmap_grid, image_labels_res, style_labels

# ----------------------------------------------------------------------------
# BEGIN: Gradio UI with extended functionality
# ----------------------------------------------------------------------------

from os import getenv

# Window title of the Gradio app and a short description string.
TITLE = "WD Tagger Heatmap + Artist Style Enhanced"
DESCRIPTION = """Example Gradio app: WD Tagger with heatmap and additional detection for artist style."""
# Optional Hugging Face access token read from the environment (None if unset).
HF_TOKEN = getenv("HF_TOKEN", None)

# Model repositories offered in the UI dropdown.
AVAILABLE_MODEL_REPOS = [
    'SmilingWolf/wd-convnext-tagger-v3',
    'SmilingWolf/wd-swinv2-tagger-v3',
    'SmilingWolf/wd-vit-tagger-v3',
    'SmilingWolf/wd-vit-large-tagger-v3',
    "SmilingWolf/wd-eva02-large-tagger-v3",
]
# Repository preselected in the dropdown and used for the example rows.
MODEL_REPO = "SmilingWolf/wd-vit-tagger-v3"

# Absolute path of the current working directory.
WORK_DIR = Path(".").resolve()
# File suffixes regarded as images.
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]

# If you have example images, you can list them here
example_images = []  # e.g., ["examples/image1.png", "examples/image2.jpg"]

def predict_extended(image: Image.Image, model_repo: str, threshold: float = 0.5):
    """
    Main inference function for Gradio UI that uses extended label data
    for artist style detection.

    Args:
        image: Picture supplied by the UI; ``None`` when nothing was uploaded.
        model_repo: Hub repository id of the tagger model to run.
        threshold: Score cut-off for heatmaps and tags.

    Returns:
        An 8-tuple matching the UI outputs: gallery items ``(image, label)``,
        the heatmap grid, the caption, the booru-style tag string, the rating,
        character and general score dicts, and a text summary of the
        artist-style scores.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of an internal attribute error.
    """
    if image is None:
        raise gr.Error("Please provide an input image before submitting.")

    # Load the model and transforms
    model, transform = load_model_and_transform(model_repo)
    # Load extended label data; the environment token (may be None) is passed
    # through so gated or private repositories can be read.
    labels_ext = load_labels_ext(model_repo, token=HF_TOKEN)
    # Preprocess the input image
    image = preprocess_image(image, (448, 448))
    image_tensor = transform(image).unsqueeze(0)

    # Process with extended logic
    heatmaps, heatmap_grid, image_labels, style_labels = process_heatmap_extended(
        model=model,
        image=image_tensor,
        labels=labels_ext,
        threshold=threshold,
    )

    heatmap_images = [(x.image, x.label) for x in heatmaps]

    # Present the style labels as "name (score)" pairs in a single string
    style_str = ", ".join(f"{name} ({score:.3f})" for name, score in style_labels.items())

    return (
        heatmap_images,
        heatmap_grid,
        image_labels.caption,
        image_labels.booru,
        image_labels.rating,
        image_labels.character,
        image_labels.general,
        style_str,
    )

# Custom CSS handed to gr.Blocks below.
# NOTE(review): the layout in this file only assigns elem_id "threshold"; no
# component uses "use_mcut" / "char_mcut", and nothing here adds the "dimmed"
# class -- confirm whether these rules are still needed.
css = """
#use_mcut, #char_mcut {
    padding-top: var(--scale-3);
}
#threshold.dimmed {
    filter: brightness(75%);
}
"""

# Build the UI: inputs in the left column, results in tabs on the right.
with gr.Blocks(theme="default", analytics_enabled=False, title=TITLE, css=css) as demo:
    with gr.Row(equal_height=False):
        # Left column: image input, threshold/model selection, action buttons.
        with gr.Column(min_width=720):
            with gr.Group():
                img_input = gr.Image(
                    label="Input",
                    type="pil",
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                )
            with gr.Group():
                with gr.Row():
                    threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.35,
                        step=0.01,
                        label="Tag Threshold",
                        scale=5,
                        elem_id="threshold",
                    )
                    model_to_use = gr.Dropdown(choices=AVAILABLE_MODEL_REPOS, value=MODEL_REPO)
            with gr.Row():
                # Created without components; the list to clear is attached
                # further down, once all output components exist.
                clear = gr.ClearButton(components=[], variant="secondary", size="lg")
                submit = gr.Button(value="Submit", variant="primary", size="lg")

        # Right column: heatmap gallery, heatmap grid and tag outputs.
        with gr.Column(min_width=720):
            with gr.Tab(label="Heatmaps"):
                heatmap_gallery = gr.Gallery(columns=3, show_label=False)
            with gr.Tab(label="Grid"):
                heatmap_grid = gr.Image(show_label=False)
            with gr.Tab(label="Tags"):
                with gr.Group():
                    caption = gr.Textbox(label="Caption", show_copy_button=True)
                    tags = gr.Textbox(label="Tags", show_copy_button=True, lines=2)
                with gr.Group():
                    rating = gr.Label(label="Rating")
                    character = gr.Label(label="Character")
                    general = gr.Label(label="General")
                    artist_style_str = gr.Textbox(label="Artist Style Detection")

    with gr.Row():
        # If you have example images, set them up for demonstration.
        # Each example row is [image, model repo, threshold], matching `inputs`.
        example_inputs = [[img, MODEL_REPO, 0.35] for img in example_images]
        examples = gr.Examples(
            examples=example_inputs,
            inputs=[img_input, model_to_use, threshold],
        )

    # The clear button resets the input image and every output component.
    clear.add([img_input, heatmap_gallery, heatmap_grid, caption, tags, rating, character, general, artist_style_str])

    # The outputs list must stay in the same order as the tuple returned by
    # predict_extended.
    submit.click(
        predict_extended,
        inputs=[img_input, model_to_use, threshold],
        outputs=[
            heatmap_gallery,
            heatmap_grid,
            caption,
            tags,
            rating,
            character,
            general,
            artist_style_str
        ],
        api_name="predict_extended",
    )

# Enable the request queue with max_size=10, then start the server on all
# interfaces at port 7871.
demo.queue(max_size=10)
demo.launch(server_name="0.0.0.0", server_port=7871, debug=False)

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://00b68b32c48fc84c3e.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


