<a href="https://colab.research.google.com/github/koya-jp/AA-google-colab-kohya/blob/master/Diffusers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Driveに接続 { display-mode: "form" }
# Mount Google Drive at /content/drive; later cells read LoRA files from and
# write generated images to paths under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#@title ライブラリの追加, LoRAの読み込み { display-mode: "form" }

# !pip install diffusers==0.12.1
# Install diffusers (without the [torch] extra) together with its companion libraries
!pip install --upgrade diffusers==0.17.1 transformers accelerate scipy ftfy safetensors >/dev/null 2>&1

import torch
from safetensors.torch import load_file


def _resolve_lora_target(root, name_parts, key):
    """Walk ``root`` down to the submodule described by ``name_parts``.

    LoRA keys flatten the module path with underscores, but attribute names
    can contain underscores themselves (e.g. ``text_model``, ``to_q``). The
    parts are therefore matched greedily: when a lookup fails, the next part
    is appended with ``_`` and the lookup is retried on the same parent.

    Args:
        root: Module to start from (the text encoder or the UNet).
        name_parts: Underscore-split pieces of the flattened module path.
        key: Original state-dict key, used only for the error message.

    Returns:
        The resolved submodule.

    Raises:
        ValueError: If the parts are exhausted before a submodule matches.
    """
    remaining = list(name_parts)
    module = root
    candidate = remaining.pop(0)
    while True:
        try:
            # Call __getattr__ directly instead of getattr(): on a
            # torch.nn.Module it only consults registered parameters, buffers
            # and submodules, so a part such as "to" does not resolve to the
            # ``Module.to`` method and derail names like "to_q".
            module = module.__getattr__(candidate)
        except AttributeError:
            if not remaining:
                raise ValueError(
                    f"Could not resolve LoRA key {key!r}: no submodule matching "
                    f"{candidate!r} under {type(module).__name__}"
                ) from None
            next_part = remaining.pop(0)
            # An empty candidate (from a doubled underscore) is simply replaced.
            candidate = f"{candidate}_{next_part}" if candidate else next_part
            continue
        if not remaining:
            return module
        candidate = remaining.pop(0)


def load_safetensors_lora(pipeline, checkpoint_path, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.75):
    """Merge LoRA weights from a ``.safetensors`` file into ``pipeline`` in place.

    For every ``lora_up``/``lora_down`` pair in the checkpoint, the product
    ``alpha * (up @ down)`` is added to the ``weight`` of the matching layer of
    ``pipeline.text_encoder`` (keys containing ``"text"``) or ``pipeline.unet``
    (all other keys).

    Note: ``.alpha`` entries stored in the checkpoint are skipped, so only the
    ``alpha`` argument scales the merged delta.

    Args:
        pipeline: Object exposing ``text_encoder`` and ``unet`` modules.
        checkpoint_path: Path to the LoRA ``.safetensors`` file.
        LORA_PREFIX_UNET: Key prefix used for UNet layers.
        LORA_PREFIX_TEXT_ENCODER: Key prefix used for text-encoder layers.
        alpha: Multiplier applied to each merged delta.

    Returns:
        The same ``pipeline`` object, with its weights updated.

    Raises:
        ValueError: If a key does not map to a layer of the pipeline.
        KeyError: If the up/down counterpart of a key is missing.
    """
    # load LoRA weight from .safetensors
    state_dict = load_file(checkpoint_path)

    # Keys already merged as part of a pair; a set keeps the lookup O(1).
    visited = set()

    # directly update weight in diffusers model
    for key in state_dict:
        # keys usually look like
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"

        # per-module alpha entries are not applied; pairs are handled once
        if ".alpha" in key or key in visited:
            continue

        if "text" in key:
            prefix, root = LORA_PREFIX_TEXT_ENCODER, pipeline.text_encoder
        else:
            prefix, root = LORA_PREFIX_UNET, pipeline.unet
        layer_infos = key.split(".")[0].split(prefix + "_")[-1].split("_")

        # find the target layer
        target = _resolve_lora_target(root, layer_infos, key)

        # Whichever half of the pair we met first, derive the other half's key.
        if "lora_down" in key:
            up_key, down_key = key.replace("lora_down", "lora_up"), key
        else:
            up_key, down_key = key, key.replace("lora_up", "lora_down")

        weight_up = state_dict[up_key].to(torch.float32)
        weight_down = state_dict[down_key].to(torch.float32)

        if weight_up.dim() == 4:
            # 1x1 conv weights: drop the two trailing dims for the matrix
            # product, then restore them.
            delta = torch.mm(
                weight_up.squeeze(3).squeeze(2),
                weight_down.squeeze(3).squeeze(2),
            ).unsqueeze(2).unsqueeze(3)
        else:
            delta = torch.mm(weight_up, weight_down)

        # The checkpoint tensors may live on a different device than the
        # model (e.g. the pipeline was already moved to GPU).
        target.weight.data += alpha * delta.to(target.weight.device)

        visited.update((up_key, down_key))

    return pipeline


In [None]:
#@title LoRAを設定 { display-mode: "both" }

# !pip install safetensors --quiet

from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
from diffusers.models import AutoencoderKL
import torch

# Base model used for image generation. Author's notes on the candidates:
# emilianJR/majicMIX_realistic_v6 -> images lean slightly Asian (3D-like look); very detailed.
# sinkinai/Beautiful-Realistic-Asians-v5 -> specialised for Asian people (more realistic); less detailed.
# runwayml/stable-diffusion-v1-5 -> can generate realistic animal images; backgrounds are outdoors.
model_id = "runwayml/stable-diffusion-v1-5"
# VAE used for image generation
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
# Scheduler used for image generation
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")

# Build the pipeline
pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, vae=vae, custom_pipeline="lpw_stable_diffusion")

# Each LoRA slot below is merged into the pipeline only when its checkbox is
# enabled; the merges are applied in order and accumulate in the weights.

# Load LoRA 1
LoRA_USE = False #@param {type:"boolean"}
if LoRA_USE:
  LoRA="/content/drive/MyDrive/Lora/add_detail.safetensors"#@param {type:"string"}
  LoRA_alpha = 0.7 #@param {type:"number"}
  pipe = load_safetensors_lora(pipe, LoRA, alpha=LoRA_alpha)

# Load LoRA 2. Author's alpha notes: flat2 -1, LickingOralLoRA 0.5, koreanDollLikeness 0.8, DDpovbj_1ot
LoRA_USE_2= False #@param {type:"boolean"} Ture / False
if LoRA_USE_2:
  LoRA_2="/content/drive/MyDrive/Lora/flat2.safetensors"#@param {type:"string"}
  LoRA_alpha_2 = -1 #@param {type:"number"}
  pipe = load_safetensors_lora(pipe, LoRA_2, alpha=LoRA_alpha_2)

# Load LoRA 3
LoRA_USE_3= False #@param {type:"boolean"}
if LoRA_USE_3:
  LoRA_3="/content/drive/MyDrive/Lora/cohf.safetensors"#@param {type:"string"}
  LoRA_alpha_3 = 0.6 #@param {type:"number"}
  pipe = load_safetensors_lora(pipe, LoRA_3, alpha=LoRA_alpha_3)

# Load LoRA 4
LoRA_USE_4= True #@param {type:"boolean"}
if LoRA_USE_4:
  LoRA_4="/content/drive/MyDrive/Lora/pretty-cut-rum-sama.safetensors"#@param {type:"string"}
  LoRA_alpha_4 = 0.8 #@param {type:"number"}
  pipe = load_safetensors_lora(pipe, LoRA_4, alpha=LoRA_alpha_4)


# Move the pipeline to the GPU after the LoRA weights have been merged
pipe = pipe.to("cuda")

# Disable the NSFW safety checker by replacing it with a pass-through that
# returns the images unchanged
if pipe.safety_checker is not None:
  pipe.safety_checker = lambda images, **kwargs: (images, False)


Downloading (…)lve/main/config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

Downloading (…)ch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

Downloading (…)cheduler_config.json:   0%|          | 0.00/308 [00:00<?, ?B/s]

Downloading (…)ain/model_index.json:   0%|          | 0.00/541 [00:00<?, ?B/s]



Downloading (…)_stable_diffusion.py:   0%|          | 0.00/13.9k [00:00<?, ?B/s]

Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

Downloading (…)_checker/config.json:   0%|          | 0.00/4.72k [00:00<?, ?B/s]

Downloading (…)_encoder/config.json:   0%|          | 0.00/617 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

Downloading (…)tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

Downloading (…)7f0/unet/config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

Downloading (…)ch_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

Downloading (…)tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


In [None]:
#@title 画像を生成 { display-mode: "both" }
import datetime
import os

# Output directory for txt2img images (one folder per date)
today = datetime.date.today()
output_dir = f"/content/drive/MyDrive/txt2img_output/{today.strftime('%Y%m%d')}"
os.makedirs(output_dir, exist_ok=True)

# strftime format (date and time) used for the output file names
file_format = "%Y%m%d_%H%M%S"

# Positive prompt
prompt = "cute cat, realistic, high quality, masterpiece, HD,(adorable, white-and-gray fur), round-faced friendly"

# Negative prompt
n_prompt = "(worst quality, low quality:1.4), (zombie, sketch, interlocked fingers,comic), nsfw"

# Number of images to generate
num_images = 5

# Seed value (e.g. 11897334222); -1 denotes a random seed
seed = -1

# Generate one image and save it
def generate_and_save_image(prompt, n_prompt, seed, output_dir, file_format):
  """Generate a single 512x768 image with the global ``pipe`` and save it as PNG.

  Args:
    prompt: Positive prompt.
    n_prompt: Negative prompt.
    seed: Seed for reproducible output. ``-1`` (any negative value) or
      ``None`` leaves the pipeline randomly seeded. Calling repeatedly with
      the same non-negative seed reproduces the same image.
    output_dir: Existing directory the image is written to.
    file_format: ``strftime`` format used to build the file name.

  Returns:
    The path of the saved image.
  """
  # Local import keeps this function usable even if the cell that imports
  # torch at module level has not been run in this session.
  import torch

  # Only build a generator when a concrete seed was requested; passing
  # generator=None keeps the pipeline's default random behaviour.
  generator = None
  if seed is not None and seed >= 0:
    generator = torch.Generator(device=pipe.device).manual_seed(seed)

  # Other sizes the author noted: width=768, height=512 / width=768, height=1152
  image = pipe(prompt, negative_prompt=n_prompt, width=512, height=768, generator=generator, guidance_scale=7.5, num_inference_steps=20).images[0]

  # Build the output file name from the current time
  stem = datetime.datetime.now().strftime(file_format)
  save_location = os.path.join(output_dir, stem + ".png")

  # The timestamp may repeat (e.g. two images within the same second);
  # add a numeric suffix instead of overwriting an existing file.
  suffix = 1
  while os.path.exists(save_location):
    save_location = os.path.join(output_dir, f"{stem}_{suffix}.png")
    suffix += 1

  # Save the image
  image.save(save_location)
  return save_location

# Generate and save num_images images
for i in range(num_images):
  generate_and_save_image(prompt, n_prompt, seed, output_dir, file_format)


In [None]:
# Import the google.colab library
import google.colab

# Disconnect and delete the runtime
google.colab.runtime.unassign()
