<a href="https://colab.research.google.com/github/ailab-nda/ML/blob/main/Diffusers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install Stable Diffusion tooling: diffusers (with its torch extra) and transformers, upgrading any existing install
!pip install --upgrade diffusers[torch] transformers

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

In [None]:
!pip install safetensors

import torch
from safetensors.torch import load_file


def _find_lora_target(root, name_parts):
    """Return the child of ``root`` addressed by underscore-split name fragments.

    kohya-style LoRA keys flatten a dotted module path into one
    underscore-joined string (e.g. ``down_blocks_0_attentions_0_proj_in``),
    so a fragment such as ``"down"`` may be a whole attribute name or only
    the first piece of one (``down_blocks``). Fragments are joined greedily
    until a registered child is found.

    Args:
        root: module to start the walk from.
        name_parts: non-empty sequence of name fragments.

    Returns:
        The resolved child module.

    Raises:
        AttributeError: if the fragments cannot be matched to a child.
    """
    parts = list(name_parts)
    name = parts.pop(0)
    current = root
    while True:
        try:
            # Call nn.Module.__getattr__ directly instead of getattr(): it only
            # consults registered submodules/parameters/buffers, so "to" in
            # "to_k" does not resolve to the Module.to method.
            current = current.__getattr__(name)
        except AttributeError:
            if not parts:
                raise AttributeError(
                    f"cannot resolve LoRA target {'_'.join(name_parts)!r}: no child named {name!r}"
                ) from None
            # Not a child on its own: glue on the next fragment and retry.
            # An empty name (from "__" in the key) is simply replaced.
            next_part = parts.pop(0)
            name = f"{name}_{next_part}" if name else next_part
            continue
        if not parts:
            return current
        name = parts.pop(0)


def load_safetensors_lora(pipeline, checkpoint_path, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.75, *, scale_by_network_alpha=False):
    """Merge a kohya-style LoRA into a diffusers pipeline's weights in place.

    For every ``lora_up``/``lora_down`` pair, ``alpha * (up @ down)`` is added
    to the ``weight`` of the matching module in ``pipeline.text_encoder``
    (keys starting with ``LORA_PREFIX_TEXT_ENCODER``) or ``pipeline.unet``
    (all other keys). The merge is permanent: calling this twice applies the
    LoRA twice, and a negative ``alpha`` subtracts it.

    Args:
        pipeline: object exposing ``unet`` and ``text_encoder`` modules.
        checkpoint_path: path to a ``.safetensors`` file, or an already
            loaded mapping of tensor name -> tensor.
        LORA_PREFIX_UNET: key prefix of UNet entries.
        LORA_PREFIX_TEXT_ENCODER: key prefix of text-encoder entries.
        alpha: multiplier applied to every LoRA delta.
        scale_by_network_alpha: if True, additionally scale each delta by
            ``network_alpha / rank`` using the checkpoint's ``<module>.alpha``
            entry when present. Defaults to False, which ignores those
            entries (the previous behavior).

    Returns:
        ``pipeline`` itself, modified in place.

    Raises:
        KeyError: if the partner of a ``lora_up``/``lora_down`` tensor is missing.
        AttributeError: if a key names a module the pipeline does not have.
    """
    import warnings
    from collections.abc import Mapping

    if isinstance(checkpoint_path, Mapping):
        state_dict = checkpoint_path
    else:
        state_dict = load_file(checkpoint_path)

    visited = set()  # keys already merged as one half of an up/down pair
    skipped = []  # tensors that are neither lora_up nor lora_down

    for key in state_dict:
        # Per-module ".alpha" scalars are only read on demand below.
        if ".alpha" in key or key in visited:
            continue

        if "lora_down" in key:
            up_key, down_key = key.replace("lora_down", "lora_up"), key
        elif "lora_up" in key:
            up_key, down_key = key, key.replace("lora_up", "lora_down")
        else:
            # e.g. LoHa/LoKr tensors: pairing such a key with itself would
            # multiply the tensor by itself and corrupt the weight.
            skipped.append(key)
            continue

        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
        # -> module key "lora_te_text_model_encoder_layers_0_self_attn_k_proj"
        module_key = key.split(".")[0]
        if module_key.startswith(LORA_PREFIX_TEXT_ENCODER):
            prefix, root = LORA_PREFIX_TEXT_ENCODER, pipeline.text_encoder
        else:
            prefix, root = LORA_PREFIX_UNET, pipeline.unet
        target = _find_lora_target(root, module_key.split(prefix + "_")[-1].split("_"))

        # Compute the delta in float32 regardless of the checkpoint's dtype.
        weight_up = state_dict[up_key].to(torch.float32)
        weight_down = state_dict[down_key].to(torch.float32)
        if weight_up.dim() == 4:
            # Convolution LoRA: drop the trailing size-1 spatial dims, multiply,
            # then restore them (kernels larger than 1x1 are not supported).
            delta = torch.mm(weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2))
            delta = delta.unsqueeze(2).unsqueeze(3)
        else:
            delta = torch.mm(weight_up, weight_down)

        scale = alpha
        alpha_key = module_key + ".alpha"
        if scale_by_network_alpha and alpha_key in state_dict:
            # The rank is the number of rows of lora_down.
            scale = alpha * float(state_dict[alpha_key]) / weight_down.shape[0]

        with torch.no_grad():
            # Put the delta on the weight's device so merging also works
            # after the pipeline has been moved to the GPU.
            target.weight.add_(scale * delta.to(target.weight.device))

        visited.update((up_key, down_key))

    if skipped:
        warnings.warn(
            f"load_safetensors_lora: skipped {len(skipped)} unsupported tensor(s), e.g. {skipped[0]!r}"
        )

    return pipeline

In [None]:
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
from diffusers.models import AutoencoderKL
import torch

# Model used for image generation (Hugging Face repo id)
model_id = "sinkinai/Beautiful-Realistic-Asians-v5"#@param {type:"string"}
# VAE used for image generation (Hugging Face repo id)
vae_id = "stabilityai/sd-vae-ft-ema"#@param {type:"string"}
vae = AutoencoderKL.from_pretrained(vae_id)

# Scheduler used for sampling, loaded from the model repo's "scheduler" subfolder
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")

# Build the pipeline via the diffusers community pipeline "lpw_stable_diffusion"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    scheduler=scheduler,
    vae=vae,
    custom_pipeline="lpw_stable_diffusion",
)

# Load a LoRA (disabled; requires the Google Drive mount cell to be enabled)
#LoRA_USE = False #@param {type:"boolean"}
#if LoRA_USE == True:
#  LoRA="/content/drive/MyDrive/StableDiffusion/Lora/flat2.safetensors"#@param {type:"string"}
#  LoRA_alpha = -1#@param {type:"number"}
#  pipe = load_safetensors_lora(pipe, LoRA, alpha=LoRA_alpha)

pipe = pipe.to("cuda")

# Disable the NSFW safety checker. The stub must return one flag per image:
# recent diffusers pipelines iterate over the returned flags, so a bare
# False raises "TypeError: 'bool' object is not iterable".
if pipe.safety_checker is not None:
    pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))

In [None]:
import datetime
import os
import random
import IPython
from IPython.display import Image

# txt2img generation: render `num_images` images and save each one under `save_path`.

# strftime pattern for the date/time part of each output file name
file_format = "%Y%m%d_%H%M%S"

# Positive prompt
prompt = "japanese, girl, Tokyo street, night, cityscape, city lights, upper body, close-up, 8k, RAW photo, best quality, masterpiece, realistic, photo-realistic"#@param {type:"string"}
# Negative prompt
n_prompt = "EasyNegativeV2,negative_hand-neg, (worst quality:2), (low quality:2), (normal quality:2), lowres, bad anatomy, bad hands, normal quality, ((monochrome)), ((grayscale))"#@param {type:"string"}
# CFG scale
CFG_scale = 7#@param {type:"number"}
# Number of inference steps
Steps = 20#@param {type:"number"}
# Seed; -1 picks a random base seed
seed=-1#@param {type:"number"}

# Number of images to generate
num_images = 3#@param {type:"number"}
# Output image width
width = 768#@param {type:"number"}
# Output image height
height = 512#@param {type:"number"}
# Folder the output images are saved to
save_path = "/content/txt2img_output"#@param {type:"string"}

# Create the folder actually configured above, so editing save_path
# does not make image.save() fail on a missing directory.
os.makedirs(save_path, exist_ok=True)

# Resolve the base seed once: random when unset, otherwise the user's value.
if seed is None or seed == -1:
    base_seed = random.randint(0, 2147483647)
else:
    base_seed = int(seed)

# Japan Standard Time (UTC+9), used for the file-name timestamp
jst = datetime.timezone(datetime.timedelta(hours=9))

for i in range(int(num_images)):
    # Consecutive seeds: a fixed seed gives a reproducible batch of distinct
    # images (and distinct file names) instead of N copies of one image.
    image_seed = base_seed + i
    generator = torch.Generator(device="cuda").manual_seed(image_seed)

    # Generate one image
    image = pipe(
        prompt,
        negative_prompt=n_prompt,
        width=width,
        height=height,
        generator=generator,
        guidance_scale=CFG_scale,
        num_inference_steps=Steps,
    ).images[0]

    # File name: "<JST timestamp>_<seed>.png"
    file_name = datetime.datetime.now(jst).strftime(file_format) + "_" + str(image_seed)
    save_location = os.path.join(save_path, file_name + ".png")

    # Save the image and show it inline
    image.save(save_location)
    IPython.display.display(IPython.display.Image(save_location))

In [None]:
# Download the generated images as a single zip file
from google.colab import files
import shutil

# Zip the output folder, following the generation cell's save_path when it has
# been run (falls back to the default folder otherwise).
output_dir = globals().get("save_path", "/content/txt2img_output")
archive_path = shutil.make_archive("txt2img_output", "zip", output_dir)
# make_archive returns the path of the archive it wrote; download exactly that file
files.download(archive_path)