In [None]:
from libs import stable_diffusion_service
from libs.stable_diffusion_service import GenerateImage, StableDiffusion
from PIL import Image
from IPython.display import display
from tqdm import tqdm
from tqdm.asyncio import tqdm as async_tqdm
import pandas as pd
import asyncio
import time
import os
import json

# Base URL handed to the StableDiffusion client below.
# Set the STABLE_DIFFUSION_BASE_URL environment variable to point at another
# host without editing this cell. Endpoints used previously:
#   http://100.119.72.3:7860    (Tencent Cloud IDE)
#   http://192.168.10.101:7860  (laptop)
# The fallback is the AutoDL endpoint; `or` also covers an empty variable.
STABLE_DIFFUSION_BASE_URL = (
    os.environ.get("STABLE_DIFFUSION_BASE_URL") or "http://localhost:6006"
)

# Shared client used by the generation cell.
# NOTE(review): concurrency_limit presumably caps simultaneous requests —
# confirm against libs.stable_diffusion_service.
SD = StableDiffusion(STABLE_DIFFUSION_BASE_URL, concurrency_limit=10)
# When True, the "-test" styles CSV is loaded instead of the full styles table.
TEST = False

In [2]:
# Read the CSV files and build the prompt list.
# TEST selects the "-test" styles table instead of the full one.
styles_csv_path = (
    "csv/300_NAI_Styles_Table-test.csv" if TEST else "csv/300_NAI_Styles_Table.csv"
)
sd_style_list: pd.DataFrame = pd.read_csv(styles_csv_path)
common_prompts: pd.DataFrame = pd.read_csv("csv/common_prompts.csv")


def _tag_text(value) -> str:
    """Render one CSV cell as prompt text.

    pandas reads an empty CSV cell as NaN, which would otherwise be formatted
    as the literal text "nan" inside the prompt; such cells become "".
    """
    return "" if pd.isna(value) else f"{value}"


# One prompt per (style, common prompt) pair, in style-major order. The tag
# groups are concatenated without a separator, exactly as stored in the CSVs.
prompt_list: list[str] = []
for _, sd_style in sd_style_list.iterrows():
    for _, common_prompt in common_prompts.iterrows():
        prompt = "".join(
            _tag_text(part)
            for part in (
                common_prompt["Gender tags"],
                common_prompt["Character(s) tags"],
                common_prompt["Series tags"],
                common_prompt["Rating tags"],
                sd_style["Artists"],
                common_prompt["General tags"],
                # Column name is spelled "Qulity" here; presumably it matches
                # the CSV header — verify before correcting the spelling.
                common_prompt["Qulity tags"],
            )
        )
        prompt_list.append(prompt)

# Preview the first few generated prompts
# print(f"Prompt list length: {len(prompt_list)}")
# for idx, prompt in enumerate(prompt_list[:6]):
#     print(f"Prompt {idx}: {prompt}")

In [3]:
# Build the parameter sets used for generation.
# Everything except the prompt is identical for every request, so those
# settings are declared once and merged into a fresh dict per prompt.
_shared_settings: dict = {
    "negative_prompt": r"text,watermark,bad anatomy,bad proportions,extra limbs,extra digit,extra legs,extra legs and arms,disfigured,missing arms,too many fingers,fused fingers,missing fingers,unclear eyes,watermark,username,logo,artist logo,patreon logo,weibo logo,arknights logo,",
    "width": 832,
    "height": 1216,
    "cfg_scale": 5,
    "steps": 28,
    "sampler_name": "Euler a",
    "seed": -1,
    "batch_size": 1,
}

# "prompt" comes first so each dict keeps the same key order as before.
SD_generate_Parameter_list: list[dict] = [
    {"prompt": prompt_text, **_shared_settings} for prompt_text in prompt_list
]

In [4]:
# 实际生成图片
tasks = [
    SD.aio_generate_images(params, i)
    for i, params in enumerate(SD_generate_Parameter_list)
]


images: list[Image.Image] = (
    [Image.Image] * len(tasks) * SD_generate_Parameter_list[0]["batch_size"]
)
image_info: list[dict] = (
    [{"parameters", "info"}] * len(tasks) * SD_generate_Parameter_list[0]["batch_size"]
)

for coro in async_tqdm(
    asyncio.as_completed(tasks), total=len(tasks), desc="Generating"
):
    result, index = await coro
    for i, img in enumerate(result):
        images[index * SD_generate_Parameter_list[0]["batch_size"] + i] = img.image
        image_info[index * SD_generate_Parameter_list[0]["batch_size"] + i] = {
            "parameters": img.parameters,
            "info": img.info,
        }

Generating: 100%|██████████| 9/9 [00:29<00:00,  3.30s/it]


In [5]:
# Save the images, their metadata, and the generated lookup table.
# Initialise variables.
# Rows are artist styles and columns are the common prompts' general tags;
# each cell receives the index of the image generated for that pair.
sd_style_table: pd.DataFrame = pd.DataFrame(
    index=sd_style_list["Artists"], columns=common_prompts["General tags"]
)
image_data: list[dict] = []
current_time: str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
tmp_dir = os.path.join("tmp", current_time)
# exist_ok avoids the check-then-create race and makes re-running the cell safe.
os.makedirs(tmp_dir, exist_ok=True)
image_index: int = 0

# Walk the pairs in the same style-major order used to build the prompts, so
# image_index lines up with the order of the generation tasks.
# NOTE(review): this reads one image per (style, prompt) pair, which matches the
# layout of `images` only while batch_size is 1 — confirm before raising it.
for _, sd_style in sd_style_list.iterrows():
    for _, common_prompt in common_prompts.iterrows():
        # Save the image
        img_storage_path: str = os.path.join(tmp_dir, f"{image_index}.webp")
        images[image_index].save(img_storage_path, format="webp")

        # Update the DataFrame
        sd_style_table.at[sd_style["Artists"], common_prompt["General tags"]] = image_index
        # Update image_data
        image_data.append({
            "index": image_index,
            "img_storage_path": img_storage_path,
            "parameters": image_info[image_index]["parameters"],
            "info": image_info[image_index]["info"],
        })

        image_index += 1

# Save the DataFrame to a CSV file
sd_style_table.to_csv(
    os.path.join(tmp_dir, "sd_style_table.csv"), index=True
)
# Save the image metadata to a JSON file
with open(os.path.join(tmp_dir, "image_data.json"), "w", encoding="utf-8") as f:
    json.dump(image_data, f, ensure_ascii=False, indent=4)

In [6]:
# Spot check: look up the image index stored for the first artist style and the
# first common prompt's general tags (label 0 in each CSV-derived column).
first_artist = sd_style_list["Artists"][0]
first_general_tags = common_prompts["General tags"][0]
sd_style_table.at[first_artist, first_general_tags]

0