<a href="https://colab.research.google.com/github/koya-jp/AA-google-colab-kohya/blob/master/Diffusers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Diffusers ライブラリを用いて、画像を生成するスクリプト。**

In [27]:
#@title Driveに接続 { display-mode: "form" }
# Mount Google Drive. The later cells read LoRA files from, and write generated
# images to, paths under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [28]:
#@title ライブラリの追加, LoRAの読み込み { display-mode: "form" }

# !pip install diffusers==0.12.1
# diffusers[torch] 以外の のインストール
!pip install --upgrade diffusers==0.17.1 transformers accelerate scipy ftfy safetensors >/dev/null 2>&1

import torch
from safetensors.torch import load_file


def _resolve_lora_target(root, name_parts, key):
    """Return the submodule of ``root`` addressed by underscore-split ``name_parts``.

    LoRA module names flatten the attribute path with ``_`` (for example
    ``text_model_encoder_layers_0_self_attn_k_proj``). Real attribute names can
    contain underscores too, so tokens are re-joined greedily until
    ``Module.__getattr__`` finds a registered child, then the search continues
    from that child.

    ``__getattr__`` is called directly, not via ``getattr``. It only looks at
    registered parameters, buffers and submodules, so a token such as ``to`` can
    never match the ``Module.to`` method.

    Raises:
        ValueError: if the tokens do not resolve to a submodule.
    """
    layer = root
    pending = ""
    for part in name_parts:
        pending = part if not pending else pending + "_" + part
        try:
            layer = layer.__getattr__(pending)
        except AttributeError:
            continue  # not a child name yet; absorb the next token
        pending = ""
    if pending or layer is root:
        raise ValueError(f"cannot resolve LoRA key {key!r} to a submodule (unmatched: {pending!r})")
    return layer


def load_safetensors_lora(pipeline, checkpoint_path, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.75, use_network_alpha=False):
    """Merge a LoRA checkpoint into ``pipeline.unet`` / ``pipeline.text_encoder`` in place.

    For every ``<module>.lora_up.weight`` / ``<module>.lora_down.weight`` pair,
    the dense delta ``up @ down`` (reshaped to the target weight's shape) is
    scaled and added to the ``weight`` of the submodule named by ``<module>``.

    Args:
        pipeline: object exposing ``unet`` and ``text_encoder`` torch modules
            (a diffusers pipeline in this notebook).
        checkpoint_path: path of a ``.safetensors`` LoRA file, or an already
            loaded mapping of key -> tensor.
        LORA_PREFIX_UNET: key prefix of UNet entries.
        LORA_PREFIX_TEXT_ENCODER: key prefix of text-encoder entries.
        alpha: strength multiplier applied to every delta (negative subtracts).
        use_network_alpha: if True, additionally scale each delta by the file's
            ``<module>.alpha`` divided by the LoRA rank. The default False keeps
            the historical behaviour, which ignores ``.alpha`` entries entirely.

    Returns:
        The same ``pipeline`` object; its weights are modified in place.

    Raises:
        ValueError: if a LoRA module name cannot be resolved to a submodule.
        KeyError: if a ``lora_up`` / ``lora_down`` entry has no partner.
    """
    from collections.abc import Mapping

    if isinstance(checkpoint_path, Mapping):
        state_dict = checkpoint_path
    else:
        state_dict = load_file(checkpoint_path)

    visited = set()  # keys already merged as half of an up/down pair

    for key in state_dict:
        # ".alpha" entries are scale factors (read below when use_network_alpha is
        # set), not weights. Any other non-up/down entry (e.g. LyCORIS "hada_*")
        # is unsupported and skipped rather than being multiplied with itself.
        if key in visited or ".alpha" in key:
            continue
        if "lora_down" not in key and "lora_up" not in key:
            continue

        # e.g. "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
        module_name = key.split(".")[0]
        if module_name.startswith(LORA_PREFIX_TEXT_ENCODER):
            name_parts = module_name.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            root = pipeline.text_encoder
        else:
            name_parts = module_name.split(LORA_PREFIX_UNET + "_")[-1].split("_")
            root = pipeline.unet
        curr_layer = _resolve_lora_target(root, name_parts, key)

        if "lora_down" in key:
            up_key, down_key = key.replace("lora_down", "lora_up"), key
        else:
            up_key, down_key = key, key.replace("lora_up", "lora_down")

        weight_up = state_dict[up_key].to(torch.float32)
        weight_down = state_dict[down_key].to(torch.float32)

        scale = alpha
        alpha_key = module_name + ".alpha"
        if use_network_alpha and alpha_key in state_dict:
            rank = weight_down.shape[0]
            scale = alpha * float(state_dict[alpha_key]) / rank

        # flatten(1) turns conv kernels (out, r, kh, kw) / (r, in, kh, kw) into 2-D
        # matrices. The product, reshaped to the weight, covers Linear, 1x1 conv and
        # kxk conv alike. The delta is moved to the weight's device so merging also
        # works after the pipeline was moved to the GPU.
        weight = curr_layer.weight
        delta = torch.mm(weight_up.flatten(start_dim=1), weight_down.flatten(start_dim=1)).reshape(weight.shape)
        weight.data += (scale * delta).to(weight.device)

        visited.update((up_key, down_key))

    return pipeline


In [29]:
#@title LoRAを設定 　★ memo:　majicMIX_realistic_v6（アジア美女：リアル）,　stable-diffusion-v1-5（猫：リアル） { display-mode: "form" }

from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
from diffusers.models import AutoencoderKL
import torch

# Base model used for image generation
model_id = "runwayml/stable-diffusion-v1-5" #@param ["runwayml/stable-diffusion-v1-5", "emilianJR/majicMIX_realistic_v6"]
# VAE handed to the pipeline in place of the model's own
vae_id = "stabilityai/sd-vae-ft-ema" #@param {type:"string"}
vae = AutoencoderKL.from_pretrained(vae_id)
# Scheduler (sampler), loaded from the base model's "scheduler" sub-folder
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")

# Build the pipeline. It is rebuilt from scratch every time this cell runs, so the
# LoRA deltas merged in place below never accumulate across re-runs.
pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, vae=vae, custom_pipeline="lpw_stable_diffusion")

# LoRA slot 1
LoRA_USE = False #@param {type:"boolean"}
LoRA="/content/drive/MyDrive/Lora/add_detail.safetensors" #@param ["/content/drive/MyDrive/Lora/flat2.safetensors", "/content/drive/MyDrive/Lora/EkunePOVFellatioV2.safetensors", "/content/drive/MyDrive/Lora/pretty-cat-rum-sama.safetensors", "/content/drive/MyDrive/Lora/koreanDollLikeness.safetensors", "/content/drive/MyDrive/Lora/add_detail.safetensors"]
LoRA_alpha = 0.7 #@param {type:"number"}

# LoRA slot 2 (weights used before: flat2 -1, LickingOralLoRA 0.5, koreanDollLikeness 0.8, DDpovbj_1ot)
LoRA_USE_2= False #@param {type:"boolean"}
LoRA_2="/content/drive/MyDrive/Lora/flat2.safetensors" #@param ["/content/drive/MyDrive/Lora/flat2.safetensors", "/content/drive/MyDrive/Lora/EkunePOVFellatioV2.safetensors", "/content/drive/MyDrive/Lora/pretty-cat-rum-sama.safetensors", "/content/drive/MyDrive/Lora/koreanDollLikeness.safetensors", "/content/drive/MyDrive/Lora/add_detail.safetensors"]
LoRA_alpha_2 = -1 #@param {type:"number"}

# LoRA slot 3
LoRA_USE_3= False #@param {type:"boolean"}
LoRA_3="/content/drive/MyDrive/Lora/koreanDollLikeness.safetensors" #@param ["/content/drive/MyDrive/Lora/flat2.safetensors", "/content/drive/MyDrive/Lora/EkunePOVFellatioV2.safetensors", "/content/drive/MyDrive/Lora/pretty-cat-rum-sama.safetensors", "/content/drive/MyDrive/Lora/koreanDollLikeness.safetensors", "/content/drive/MyDrive/Lora/add_detail.safetensors"]
LoRA_alpha_3 = 0.6 #@param {type:"number"}

# LoRA slot 4 (free-form path)
LoRA_USE_4= True #@param {type:"boolean"}
LoRA_4="/content/drive/MyDrive/Lora/pretty-cat-rum-sama.safetensors" #@param {type:"string"}
LoRA_alpha_4 = 0.8 #@param {type:"number"}

# Merge every enabled slot into the pipeline weights, in slot order (1 -> 4).
lora_slots = [
  (LoRA_USE, LoRA, LoRA_alpha),
  (LoRA_USE_2, LoRA_2, LoRA_alpha_2),
  (LoRA_USE_3, LoRA_3, LoRA_alpha_3),
  (LoRA_USE_4, LoRA_4, LoRA_alpha_4),
]
for lora_enabled, lora_path, lora_alpha in lora_slots:
  if lora_enabled:
    pipe = load_safetensors_lora(pipe, lora_path, alpha=lora_alpha)


pipe = pipe.to("cuda")

# Disable the NSFW safety checker: replace it with a pass-through that returns the
# images unchanged and flags nothing.
if pipe.safety_checker is not None:
  pipe.safety_checker = lambda images, **kwargs: (images, False)


`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


In [None]:
#@title 画像を生成 { display-mode: "form" }
import datetime
import os

import torch

# Output folder for txt2img images, one sub-folder per day (YYYYMMDD)
today = datetime.date.today()
output_dir = f"/content/drive/MyDrive/txt2img_output/{today.strftime('%Y%m%d')}"
os.makedirs(output_dir, exist_ok=True)

# strftime pattern for output file names. It has one-second resolution; a numeric
# suffix is appended when a file with the same name already exists.
file_format = "%Y%m%d_%H%M%S"

# Positive prompt
prompt = "cat, realistic, high quality, masterpiece, HD, looking at viewer, full body, no humans, realistic, animal focus, black eyes, solo, white fur, gray fur, in room" #@param {type:"string"}

# Negative prompt
n_prompt = "(worst quality, low quality:1.4), (sketch, interlocked fingers,comic)" #@param {type:"string"}

# Number of images to generate
num_images = 3 #@param {type:"integer"}

# Seed, e.g. 11897334222. -1 means a random seed for every image. A value >= 0 makes
# the run reproducible: image i uses seed + i, so a batch does not repeat one picture.
seed = -1 #@param {type:"integer"}

# Image size and sampler settings (sizes used before: 512x768, 768x1152)
width = 512 # @param [512, 768]
height = 768 # @param [512, 768, 1152]
guidance_scale = 7.5 #@param {type:"number"}
num_inference_steps = 20 #@param {type:"integer"}


def generate_and_save_image(prompt, n_prompt, seed, output_dir, file_format,
                            width=width, height=height,
                            guidance_scale=guidance_scale,
                            num_inference_steps=num_inference_steps):
  """Generate one image with the global ``pipe`` and save it as a PNG in ``output_dir``.

  Args:
    prompt: positive prompt.
    n_prompt: negative prompt.
    seed: RNG seed. A negative value (the form default -1) leaves seeding to the
      pipeline (random); any other value makes the image reproducible.
    output_dir: existing folder the PNG is written to.
    file_format: strftime pattern for the file name stem.
    width, height, guidance_scale, num_inference_steps: generation settings,
      defaulting to the form values above.

  Returns:
    Path of the saved PNG.
  """
  # Build a generator only for an explicit seed; generator=None keeps it random.
  generator = None
  if seed is not None and seed >= 0:
    generator = torch.Generator(device=pipe.device).manual_seed(seed)

  image = pipe(
    prompt,
    negative_prompt=n_prompt,
    width=width,
    height=height,
    generator=generator,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
  ).images[0]

  # Name the file after the current time, but never overwrite an existing image.
  stem = datetime.datetime.now().strftime(file_format)
  save_location = os.path.join(output_dir, stem + ".png")
  counter = 1
  while os.path.exists(save_location):
    save_location = os.path.join(output_dir, f"{stem}_{counter}.png")
    counter += 1
  image.save(save_location)
  return save_location


# Generate and save num_images images
for i in range(num_images):
  image_seed = seed + i if seed >= 0 else seed
  saved_path = generate_and_save_image(prompt, n_prompt, image_seed, output_dir, file_format)
  print(f"seed={image_seed}: {saved_path}")


  0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
# #@title ランタイムの接続を解除して削除 { display-mode: "form" }

# # google.colabライブラリのインポート
# import google.colab

# # ランタイムの接続を解除して削除
# google.colab.runtime.unassign()
