<a href="https://colab.research.google.com/github/WSH032/wd-v1-4-tagger-feature-extractor-tutorials/blob/main/candidate_labels.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 安装依赖

!pip install torch huggingface_hub transformers toml numpy pandas


In [1]:
# 下载tag文件

import os

from huggingface_hub import hf_hub_download


DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
FILES = ["selected_tags.csv"]
CSV_FILE = FILES[-1]

download_dir = "wd14_tagger_model" # 模型保存目录

def download_keras_model(repo_id: str, download_dir: str) -> None:
    print(f"downloading wd14 tagger model from hf_hub. id: {repo_id}")
    for file in FILES:
        hf_hub_download(repo_id, file, force_download=True, local_dir=download_dir)


download_keras_model(DEFAULT_WD14_TAGGER_REPO, download_dir)


downloading wd14 tagger model from hf_hub. id: SmilingWolf/wd-v1-4-moat-tagger-v2


Downloading (…)in/selected_tags.csv:   0%|          | 0.00/254k [00:00<?, ?B/s]

In [2]:
# 读取selected_tags.csv中内容

import csv


selected_tags_csv_path = os.path.join(download_dir, CSV_FILE)
with open(selected_tags_csv_path, 'r') as f:
    reader = csv.reader(f)
    selected_tags_csv = [row for row in reader]

rating_tags = [row[1] for row in selected_tags_csv if row[2] == "9"]
general_tags = [row[1] for row in selected_tags_csv if row[2] == "0"]
character_tags = [row[1] for row in selected_tags_csv if row[2] == "4"]

print("rating_tags", len(rating_tags))
print("general_tags", len(general_tags))
print("character_tags", len(character_tags))

rating_tags 4
general_tags 6947
character_tags 2132


In [3]:
# 将用于对general_tag进行zero-shot分类的候选标签

candidate_labels_dict = {
    "items": ["items", "food", "flowers", "furniture"],
    "image_composition": ["image composition", "background", "color", "perspective", "style", "art style", "image style"],
    "environment": ["environment", "buildings", "cities", "indoors", "outdoors", "scene", "sky", "nature", "animals"],
    "characters": [
        "characters",
        "role",
        "human type",

        "body",
        "body parts",
        "body ornament",
        "body decoration",
        "breasts",
        "ears",
        "head ornament"
        "face",
        "facial features",
        "hair",
        "hair color",
        "hair style",
        "hair ornament",
        "neck",
        "eyes",
        "eyes color",
        "foot",
        "lowerbody",
        "upperbody",

        "ornament",
        "decorations",
        "clothes",
        "shoes",
        "legwear",

        "actions",
        "expressions",
        "mood",
        "character pose",

    ],
    "quality": ["picture quality"],
}

In [4]:
# 转为列表，便于输入transformers

candidate_labels_list = []
for value in candidate_labels_dict.values():
    if isinstance(value, list):
        candidate_labels_list.extend(value)
    else:
        candidate_labels_list.append(value)
print(f"共有{len(candidate_labels_list)}个候选标签")

共有51个候选标签


In [5]:
# 载入zero-shot模型

import torch
from transformers import pipeline


zero_shot_model_rep = "sileod/deberta-v3-base-tasksource-nli"
device = 0 if torch.cuda.is_available() else -1  # 有显卡就用显卡


classifier = pipeline(
    "zero-shot-classification",
    model=zero_shot_model_rep,
    device=device,
    framework="pt"
)


In [6]:
# 进行分类

batch_size = 2048  # 6g显存只能256，这里使用的是colab的T4显卡

general_classifier_result = classifier(
    [tag.replace("_", " ")for tag in general_tags],  # 一定要把"_"替换为" " ； 不然影响预测
    candidate_labels_list,
    batch_size=batch_size,
    multi_label=True,  # 注意要开启这个
)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [8]:
import numpy as np


# 声明空间
res_np = np.zeros((51, 6947))

# 将general_res的结果转换为numpy数组
for i, res in enumerate(general_classifier_result):
    # 每个label和各自对应的score绑定
    temp = list( zip( res["labels"], res["scores"] ) )
    # 统一按candidate_labels_list内顺序排序
    temp.sort(key = lambda x: candidate_labels_list.index( x[0] ) )
    # 将得分变成一维矩阵
    temp_array = np.array( [t[1] for t in temp] )
    # 储存到res_np中
    res_np[:,i] = temp_array


In [12]:
candidate_labels_np = np.array(candidate_labels_list)

np_savez_name = "candidate_labels_scores_safetensors.npz"

np.savez(
    np_savez_name,
    candidate_labels = candidate_labels_np,
    scores = res_np,  # 默认好像是64位精度
)


In [13]:
import pandas as pd

# 以candidate_labels_list为行名，general_tags为列名
df = pd.DataFrame(
    res_np,
    index=candidate_labels_np,
    columns=general_tags,
    )

# 显示df每一列中最大值那一行的行名
print( df.idxmax(axis=0) )

df.head()


1girl                   human type
solo                    human type
long_hair                     hair
breasts                    breasts
looking_at_viewer      perspective
                         ...      
yellow_socks               legwear
animal_on_hand             animals
red_mittens                  color
rabbit_on_head       hair ornament
qingxin_flower             flowers
Length: 6947, dtype: object


Unnamed: 0,1girl,solo,long_hair,breasts,looking_at_viewer,blush,smile,short_hair,open_mouth,bangs,...,stuffed_dog,four-leaf_clover_hair_ornament,year_of_the_rooster,person_on_head,lifebuoy_ornament,yellow_socks,animal_on_hand,red_mittens,rabbit_on_head,qingxin_flower
items,0.130697,0.077979,0.061493,0.158774,0.122784,0.123849,0.101619,0.104703,0.052109,0.393595,...,0.499052,0.340415,0.058474,0.025486,0.258079,0.712743,0.041068,0.833878,0.10621,0.161397
food,0.127122,0.075766,0.003528,0.591244,0.124879,0.061938,0.072735,0.006649,0.613572,0.029793,...,0.25562,0.029732,0.152246,0.02108,0.01387,0.008024,0.393119,0.010228,0.367474,0.234317
flowers,0.060849,0.023409,0.008003,0.00709,0.113476,0.099993,0.052043,0.012484,0.010939,0.00847,...,0.002113,0.115879,0.018505,0.010001,0.020184,0.015922,0.009708,0.007625,0.009366,0.935317
furniture,0.051603,0.053115,0.00495,0.011781,0.081355,0.024549,0.06487,0.010401,0.018,0.023697,...,0.32261,0.002883,0.022658,0.012093,0.014367,0.00706,0.011791,0.006851,0.054742,0.005539
image composition,0.561482,0.201916,0.366101,0.292319,0.745837,0.418577,0.423961,0.522924,0.301614,0.309978,...,0.274987,0.597231,0.460572,0.528594,0.388638,0.528926,0.234076,0.668987,0.691157,0.403057


In [14]:
# 用toml来储存tags和候选标签

import toml


wd14_tags_toml_name = "wd14_tags.toml"


tags = [
    {
        "name": "rating",
        "tags": rating_tags,
    },
    {
        "name": "general",
        "tags": general_tags,
    },
    {
        "name": "character",
        "tags": character_tags,
    },

]

candidate_labels = candidate_labels_dict,


toml_dict = {
    "tags": tags,
    "candidate_labels": candidate_labels,
}

with open(wd14_tags_toml_name, "w") as f:
    toml.dump(toml_dict, f)

In [17]:
# 上传文件

repo_id = "your_repo_id"
token = "your_token"
upload_files_list = [wd14_tags_toml_name, np_savez_name]

from huggingface_hub import HfApi
api = HfApi()

for f in upload_files_list:
    api.upload_file(
        path_or_fileobj=f,
        path_in_repo=f,
        repo_id=repo_id,
        token=token,
    )

candidate_labels_scores_safetensors.npz:   0%|          | 0.00/2.84M [00:00<?, ?B/s]

## 换成pt进行预测

In [20]:
from huggingface_hub import snapshot_download


download_dir = "deberta-v3-base-tasksource-nli" # 模型保存目录


snapshot_download(
    repo_id = "sileod/deberta-v3-base-tasksource-nli",
    local_dir = download_dir,
    ignore_patterns = ["model.safetensors"],  # 使用torch模型而不是safetensors
)


Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

'/content/deberta-v3-base-tasksource-nli'

In [21]:
# 载入zero-shot模型


device = 0 if torch.cuda.is_available() else -1  # 有显卡就用显卡


classifier = pipeline(
    "zero-shot-classification",
    model=download_dir,  # 载入pt模型
    device=device,
    framework="pt"
)


In [22]:
# 进行分类

batch_size = 2048  # 6g显存只能256，这里使用的是colab的T4显卡

general_classifier_result = classifier(
    [tag.replace("_", " ")for tag in general_tags],  # 一定要把"_"替换为" " ； 不然影响预测
    candidate_labels_list,
    batch_size=batch_size,
    multi_label=True,  # 注意要开启这个
)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [23]:
import numpy as np


# 声明空间
res_np = np.zeros((51, 6947))

# 将general_res的结果转换为numpy数组
for i, res in enumerate(general_classifier_result):
    # 每个label和各自对应的score绑定
    temp = list( zip( res["labels"], res["scores"] ) )
    # 统一按candidate_labels_list内顺序排序
    temp.sort(key = lambda x: candidate_labels_list.index( x[0] ) )
    # 将得分变成一维矩阵
    temp_array = np.array( [t[1] for t in temp] )
    # 储存到res_np中
    res_np[:,i] = temp_array


In [24]:
candidate_labels_np = np.array(candidate_labels_list)

np_savez_name = "candidate_labels_scores_pt.npz"  # 把名字换成pt

np.savez(
    np_savez_name,
    candidate_labels = candidate_labels_np,
    scores = res_np,  # 默认好像是64位精度
)


In [25]:
import pandas as pd

# 以candidate_labels_list为行名，general_tags为列名
df = pd.DataFrame(
    res_np,
    index=candidate_labels_np,
    columns=general_tags,
    )

# 显示df每一列中最大值那一行的行名
print( df.idxmax(axis=0) )

df.head()


1girl                human type
solo                   outdoors
long_hair                  hair
breasts                 breasts
looking_at_viewer          eyes
                        ...    
yellow_socks            clothes
animal_on_hand          animals
red_mittens             clothes
rabbit_on_head          animals
qingxin_flower           nature
Length: 6947, dtype: object


Unnamed: 0,1girl,solo,long_hair,breasts,looking_at_viewer,blush,smile,short_hair,open_mouth,bangs,...,stuffed_dog,four-leaf_clover_hair_ornament,year_of_the_rooster,person_on_head,lifebuoy_ornament,yellow_socks,animal_on_hand,red_mittens,rabbit_on_head,qingxin_flower
items,0.266413,0.402745,0.581564,0.788755,0.232187,0.627783,0.366624,0.564226,0.467037,0.813162,...,0.912031,0.841842,0.242005,0.424968,0.863428,0.907297,0.650708,0.91563,0.463008,0.688104
food,0.034859,0.064432,0.022804,0.910414,0.021792,0.066484,0.040706,0.026217,0.650144,0.063456,...,0.941469,0.463964,0.288845,0.060912,0.032233,0.015502,0.599672,0.019072,0.623993,0.673381
flowers,0.079797,0.180416,0.044189,0.091722,0.09888,0.658574,0.109768,0.046611,0.167661,0.071966,...,0.041132,0.920508,0.210883,0.088215,0.388897,0.416009,0.057904,0.116387,0.074662,0.95034
furniture,0.046291,0.252386,0.025946,0.171819,0.09004,0.055395,0.082526,0.040002,0.287637,0.037053,...,0.67625,0.193443,0.060459,0.167433,0.399497,0.031865,0.032507,0.056228,0.044833,0.027556
image composition,0.497671,0.702992,0.460509,0.589578,0.652932,0.654274,0.601347,0.482562,0.676661,0.487945,...,0.572065,0.615834,0.458377,0.545359,0.781091,0.476714,0.501315,0.62097,0.583186,0.576092


In [26]:
# 上传文件

repo_id = "your_repo_id"
token = "your_token"
upload_files_list = [np_savez_name]

from huggingface_hub import HfApi
api = HfApi()

for f in upload_files_list:
    api.upload_file(
        path_or_fileobj=f,
        path_in_repo=f,
        repo_id=repo_id,
        token=token,
    )

candidate_labels_scores_pt.npz:   0%|          | 0.00/2.84M [00:00<?, ?B/s]