In [1]:
# Version specifiers are quoted: an unquoted `>=` is treated by the shell as an
# output redirect, which silently drops the constraint and writes pip's output
# to files named "=9.0.0" / "=1.12.0". %pip targets the running kernel's environment.
%pip install "pillow>=9.0.0" "onnxruntime>=1.12.0" huggingface-hub

In [3]:


#@markdown SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"

#@markdown CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"

#@markdown VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"

#@markdown

#@markdown MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"

#@markdown SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"

#@markdown CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"

#@markdown CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"

#@markdown VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

import argparse
import os
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from PIL import Image
from datetime import datetime

# Optional Hugging Face access token taken from the environment; None when the
# HF_TOKEN variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Dataset v3 series of models:
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"

# Dataset v2 series of models:
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

# Files to download from the repos: the ONNX model itself and the CSV tag
# table (read with pandas and split by load_labels).
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Emoticon-style tag names. load_labels keeps these verbatim when it replaces
# underscores with spaces in every other tag name.
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.",
    "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o",
    "u_u", "x_x", "|_|", "||_||",
]

def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
    """Split a tag table into display names and per-category row positions.

    Args:
        dataframe: Table with a ``name`` column of tag strings and a
            ``category`` column of integer codes (the caller reads it from
            ``selected_tags.csv``).

    Returns:
        A 4-tuple ``(tag_names, rating_indexes, general_indexes,
        character_indexes)``. ``tag_names`` contains every name with
        underscores replaced by spaces, except names listed in ``kaomojis``,
        which are kept verbatim. The three index lists hold the zero-based row
        positions whose category is 9 (rating), 0 (general) and 4 (character).
    """
    # Set membership avoids rescanning the kaomoji list for every tag.
    keep_verbatim = set(kaomojis)
    tag_names = [
        name if name in keep_verbatim else name.replace("_", " ")
        for name in dataframe["name"]
    ]

    categories = dataframe["category"].to_numpy()

    def positions(code):
        # .tolist() yields plain Python ints rather than numpy scalars.
        return np.flatnonzero(categories == code).tolist()

    return tag_names, positions(9), positions(0), positions(4)

class Predictor:
    """Tag images with an ONNX tagger model downloaded from the Hugging Face Hub.

    The most recently loaded model is kept on the instance, so consecutive
    ``predict`` calls with the same ``model_repo`` reuse one inference session.
    """

    def __init__(self):
        # All of these are populated by load_model().
        self.model = None
        self.model_target_size = None
        self.last_loaded_repo = None
        self.tag_names = None
        self.rating_indexes = None
        self.general_indexes = None
        self.character_indexes = None

    def download_model(self, model_repo):
        """Download the tag CSV and the ONNX model file of ``model_repo``.

        Args:
            model_repo: Hub repository id, e.g. one of the ``*_REPO`` constants.

        Returns:
            ``(csv_path, model_path)`` as returned by ``hf_hub_download``.
        """
        # `token` is the current name of the authentication argument;
        # `use_auth_token` is its deprecated spelling.
        csv_path = huggingface_hub.hf_hub_download(
            model_repo,
            LABEL_FILENAME,
            token=HF_TOKEN,
        )
        model_path = huggingface_hub.hf_hub_download(
            model_repo,
            MODEL_FILENAME,
            token=HF_TOKEN,
        )
        return csv_path, model_path

    def load_model(self, model_repo):
        """Load ``model_repo`` into this predictor unless it is already loaded.

        Populates the tag names, the per-category index lists, the inference
        session and the model's input size.
        """
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)

        tags_df = pd.read_csv(csv_path)
        (
            self.tag_names,
            self.rating_indexes,
            self.general_indexes,
            self.character_indexes,
        ) = load_labels(tags_df)

        model = rt.InferenceSession(model_path)
        # The input shape must have four entries; the second one is used as the
        # side length of the square input.
        # NOTE(review): assumes an NHWC layout with height == width — confirm.
        _, height, _, _ = model.get_inputs()[0].shape
        self.model_target_size = height

        self.model = model
        # Recorded last so a failed load is retried on the next call.
        self.last_loaded_repo = model_repo

    def prepare_image(self, image):
        """Convert a PIL image into the model's input array.

        Args:
            image: A PIL image in any mode.

        Returns:
            A ``float32`` array of shape ``(1, size, size, 3)`` with channels
            in BGR order, where ``size`` is ``self.model_target_size``.
        """
        target_size = self.model_target_size

        # NOTE(review): convert("RGB") discards any alpha channel and the
        # resize below stretches non-square inputs instead of padding them —
        # confirm this matches the preprocessing the models expect.
        image = image.convert("RGB")

        # Resize the image to the target size
        image = image.resize((target_size, target_size), Image.BICUBIC)

        image_array = np.asarray(image, dtype=np.float32)
        # Reverse the channel axis: PIL-native RGB -> BGR.
        image_array = image_array[:, :, ::-1]

        # Add the leading batch dimension.
        return np.expand_dims(image_array, axis=0)

    def predict(
        self,
        image_path,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
    ):
        """Run the tagger on one image file.

        Args:
            image_path: Path of the image to tag.
            model_repo: Hub repository id of the model to use.
            general_thresh: Score a general tag must exceed to be kept.
            general_mcut_enabled: When true, ``general_thresh`` is replaced by
                the value computed by ``mcut_threshold`` over the general tags.
            character_thresh: Score a character tag must exceed to be kept.
            character_mcut_enabled: When true, ``character_thresh`` is replaced
                by the ``mcut_threshold`` value, but never below 0.15.

        Returns:
            ``(sorted_general_strings, rating, character_res, general_res)``:
            the kept general tags joined with ``", "`` in descending score
            order with parentheses backslash-escaped; a dict of every rating
            tag to its score; and dicts of the kept character and general tags
            to their scores.
        """
        self.load_model(model_repo)

        # The context manager closes the underlying file once the pixel data
        # has been copied into the input array.
        with Image.open(image_path) as pil_image:
            input_tensor = self.prepare_image(pil_image)

        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: input_tensor})[0]

        # Pair every tag name with its score for the single image in the batch.
        labels = list(zip(self.tag_names, preds[0].astype(float)))

        # Ratings are reported in full, without thresholding.
        rating = dict(labels[i] for i in self.rating_indexes)

        general_names = [labels[i] for i in self.general_indexes]
        if general_mcut_enabled:
            general_probs = np.array([score for _, score in general_names])
            general_thresh = self.mcut_threshold(general_probs)
        general_res = {
            tag: score for tag, score in general_names if score > general_thresh
        }

        character_names = [labels[i] for i in self.character_indexes]
        if character_mcut_enabled:
            character_probs = np.array([score for _, score in character_names])
            character_thresh = self.mcut_threshold(character_probs)
            character_thresh = max(0.15, character_thresh)
        character_res = {
            tag: score for tag, score in character_names if score > character_thresh
        }

        ranked_general = sorted(
            general_res.items(),
            key=lambda item: item[1],
            reverse=True,
        )
        # Raw strings produce the same backslash-paren text as before without
        # relying on an invalid escape sequence.
        sorted_general_strings = (
            ", ".join(tag for tag, _ in ranked_general)
            .replace("(", r"\(")
            .replace(")", r"\)")
        )

        return sorted_general_strings, rating, character_res, general_res

    def mcut_threshold(self, probs):
        """Return the midpoint of the largest gap between sorted probabilities.

        Args:
            probs: 1-D numpy array holding at least two scores.

        Returns:
            The mean of the two adjacent values (in descending order) that are
            separated by the largest difference.

        Raises:
            ValueError: If ``probs`` holds fewer than two values, since there
                is then no gap to measure.
        """
        if probs.size < 2:
            raise ValueError("mcut_threshold needs at least two probabilities")

        sorted_probs = probs[probs.argsort()[::-1]]
        difs = sorted_probs[:-1] - sorted_probs[1:]
        t = difs.argmax()
        thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
        return thresh
# def main():
#     predictor = Predictor()

#     input_folder = "/content/images"  # Provide the path to the folder containing input images
#     output_folder = "output_folder"  # Provide the path to the folder where output text files will be saved

#     model_repo = SWINV2_MODEL_DSV3_REPO  # You can change the model repository if needed

#     general_thresh = 0.35  # Set your desired general threshold
#     general_mcut_enabled = False  # Set to True if you want to use MCut threshold for general tags
#     character_thresh = 0.85  # Set your desired character threshold
#     character_mcut_enabled = False  # Set to True if you want to use MCut threshold for character tags

#     # Iterate through each file in the input folder
#     for filename in os.listdir(input_folder):
#         if filename.endswith(('.jpg', '.jpeg', '.png')):  # Check if file is an image
#             image_path = os.path.join(input_folder, filename)
#             output_file = os.path.splitext(filename)[0] + ".txt"
#             output_path = os.path.join(output_folder, output_file)

#             # Predict tags for the current image
#             sorted_general_tags, rating, character_res, general_res = predictor.predict(
#                 image_path,
#                 model_repo,
#                 general_thresh,
#                 general_mcut_enabled,
#                 character_thresh,
#                 character_mcut_enabled,
#             )

#             # Write the sorted general tags to a text file
#             with open(output_path, 'w') as f:
#                 f.write("Sorted General Tags: {}\n".format(sorted_general_tags))
#                 f.write("Ratings: {}\n".format(rating))
#                 f.write("Character Results: {}\n".format(character_res))
#                 f.write("General Results: {}\n".format(general_res))

# if __name__ == "__main__":
#     main()


def main():
    """Tag every image in ``input_folder`` and write one ``.txt`` file per image.

    Each output file is named after its image (extension replaced by ``.txt``)
    and contains the general tags that passed the threshold, in alphabetical
    order, joined by ``", "``. Rating and character results are not written.
    The assignments below are Colab form fields; edit them in the form or here.
    """
    predictor = Predictor()

    input_folder = "/content/images" # @param {type:"string"}
    output_folder = "/content/output_txt" # @param {type:"string"}
    model_repo = SWINV2_MODEL_DSV3_REPO # @param ["SWINV2_MODEL_DSV3_REPO","CONV_MODEL_DSV3_REPO","VIT_MODEL_DSV3_REPO","MOAT_MODEL_DSV2_REPO","SWIN_MODEL_DSV2_REPO","CONV_MODEL_DSV2_REPO","CONV2_MODEL_DSV2_REPO","VIT_MODEL_DSV2_REPO"] {type:"raw"}
    general_thresh = 0.35 #@param
    general_mcut_enabled = False #@param
    character_thresh = 0.85 #@param
    character_mcut_enabled = False #@param

    # open(..., 'w') does not create missing directories, so make sure the
    # destination exists before the first write.
    os.makedirs(output_folder, exist_ok=True)

    image_suffixes = ('.jpg', '.jpeg', '.png')

    # Sorted for a reproducible processing order.
    for filename in sorted(os.listdir(input_folder)):
        # Compare case-insensitively so files such as "IMG_001.JPG" are included.
        if not filename.lower().endswith(image_suffixes):
            continue

        image_path = os.path.join(input_folder, filename)
        output_file = os.path.splitext(filename)[0] + ".txt"
        output_path = os.path.join(output_folder, output_file)

        # Predict tags for the current image; only the general-tag dict is used.
        _, _, _, general_res = predictor.predict(
            image_path,
            model_repo,
            general_thresh,
            general_mcut_enabled,
            character_thresh,
            character_mcut_enabled,
        )

        # Alphabetical order of the kept general tag names.
        alphabetical_tags = sorted(general_res)

        with open(output_path, 'w', encoding="utf-8") as f:
            f.write(", ".join(alphabetical_tags))

# Run the batch tagging job when this code is executed directly rather than
# imported as a module.
if __name__ == "__main__":
    main()



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


selected_tags.csv:   0%|          | 0.00/308k [00:00<?, ?B/s]

model.onnx:   0%|          | 0.00/467M [00:00<?, ?B/s]