<a href="https://colab.research.google.com/github/Stability-AI/model-demo-notebooks/blob/main/japanese_stable_diffusion_xl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Japanese Stable Diffusion XL Demo
This is a demo for [Japanese Stable Diffusion XL](https://huggingface.co/stabilityai/japanese-stable-diffusion-xl) from [Stability AI](https://stability.ai/).

- Blog: https://ja.stability.ai/blog/japanese-stable-diffusion-xl
- Twitter: https://twitter.com/StabilityAI_JP
- Discord: https://discord.com/invite/StableJP


In [None]:
#@title Setup
# Show the attached GPU; later cells move the pipeline to "cuda".
!nvidia-smi
# Install diffusers/transformers/gradio used below, plus sentencepiece and
# accelerate (presumably needed by the tokenizer / model loading - confirm).
!pip install 'diffusers>=0.23.0' transformers sentencepiece gradio accelerate

In [None]:
# @title Login HuggingFace
# Interactive Hugging Face Hub login (prompts for an access token).
# NOTE(review): presumably only required if the model repo needs
# authentication - confirm before making it mandatory.
!huggingface-cli login

In [None]:
#@title Load JSDXL

# copied from https://huggingface.co/stabilityai/japanese-stable-diffusion-xl/blob/main/modeling_clipnull.py
from dataclasses import dataclass

import torch
from torch import nn

from diffusers.configuration_utils import register_to_config, ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput


@dataclass
class CLIPNullTextOutput(BaseOutput):
    """Output of ``CLIPNullTextModel.forward`` when ``return_dict=True``.

    The field names appear to mirror a CLIP text-encoder output so this can
    stand in for one (TODO confirm against the consumer).
    """

    # Pooled embedding expanded to (bsz, hidden_size); None when the model was
    # built with always_return_pooled=False.
    text_embeds: torch.FloatTensor
    # Sequence embedding expanded to (bsz, max_length, hidden_size).
    last_hidden_state: torch.FloatTensor


class CLIPNullTextModel(ModelMixin, ConfigMixin):
    """Prompt-independent "text encoder" returning learned constant embeddings.

    It owns two parameters, ``z`` (sequence embedding) and ``pooled`` (pooled
    embedding), both zero-initialised and meant to be loaded from a
    checkpoint. ``forward`` ignores any text and simply broadcasts them to the
    requested batch size.
    """

    @register_to_config
    def __init__(
        self,
        hidden_size: int = 1280,
        always_return_pooled: bool = True,
        max_length: int = 77,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.return_pooled = always_return_pooled
        self.max_length = max_length

        # Parameter names must stay "z" and "pooled" to match checkpoint keys.
        seq_shape = (1, max_length, hidden_size)
        self.register_parameter("z", nn.Parameter(torch.zeros(seq_shape)))

        pooled_init = torch.zeros((1, hidden_size)) if always_return_pooled else None
        self.register_parameter("pooled", nn.Parameter(pooled_init))

    def forward(self, bsz: int = 1, return_dict: bool = True):
        """Broadcast the stored embeddings to a batch of ``bsz``.

        Returns ``(pooled, sequence)`` when ``return_dict`` is false, otherwise
        a ``CLIPNullTextOutput``. ``pooled`` is None when the model was built
        with ``always_return_pooled=False``.
        """
        hidden = self.z.expand(bsz, -1, -1)
        pooled = self.pooled.expand(bsz, -1) if self.return_pooled else None

        if return_dict:
            return CLIPNullTextOutput(
                text_embeds=pooled,
                last_hidden_state=hidden,
            )
        return (pooled, hidden)

# img2img pipeline
from typing import Optional

import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import (
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_invisible_watermark_available,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)
from transformers import AutoTokenizer, CLIPTextModelWithProjection
# from .modeling_clipnull import CLIPNullTextModel

if is_invisible_watermark_available():
    from diffusers.pipelines.stable_diffusion_xl.watermark import (
        StableDiffusionXLWatermarker,
    )


class JapaneseStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
    """SDXL img2img pipeline adapted to Japanese Stable Diffusion XL (JSDXL).

    JSDXL has a single text encoder (``text_encoder`` + ``tokenizer``). The
    slot of SDXL's second text encoder is filled by ``null_encoder``, a
    ``CLIPNullTextModel`` whose prompt-independent embeddings are concatenated
    to the real encoder's hidden states along the feature dimension.
    """

    model_cpu_offload_seq = "text_encoder->null_encoder->unet->vae"
    _optional_components = ["tokenizer", "text_encoder", "null_encoder"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: AutoTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        null_encoder: Optional[CLIPNullTextModel] = None,
        requires_aesthetics_score: bool = False,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        """Register the JSDXL components.

        Args:
            vae, text_encoder, tokenizer, unet, scheduler: Standard SDXL
                components (a single text encoder / tokenizer pair).
            null_encoder: Stand-in for the second text encoder. When ``None``,
                a zero-initialised ``CLIPNullTextModel`` sized to
                ``tokenizer.model_max_length`` is created.
            requires_aesthetics_score: Stored in the pipeline config.
            force_zeros_for_empty_prompt: When true and no negative prompt is
                given, ``encode_prompt`` uses all-zero negative embeddings.
            add_watermarker: Whether to apply the invisible watermark; defaults
                to ``is_invisible_watermark_available()``.
        """
        if null_encoder is None:
            null_encoder = CLIPNullTextModel(
                hidden_size=1280,
                max_length=tokenizer.model_max_length,
                always_return_pooled=True,
            )
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            null_encoder=null_encoder,
        )
        self.register_to_config(
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
            requires_aesthetics_score=requires_aesthetics_score,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.default_sample_size = self.unet.config.sample_size

        if add_watermarker is None:
            add_watermarker = is_invisible_watermark_available()
        self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None

        # JSDXL has no second text encoder; presumably the parent SDXL pipeline
        # reads this attribute, so it is set explicitly.
        self.text_encoder_2 = None

    def encode_prompt(
        self,
        prompt: str,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[str] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
        **kwargs,
    ):
        """Encode the prompt (and negative prompt) into UNet conditioning.

        Extra SDXL arguments such as ``prompt_2`` arrive via ``**kwargs`` and
        are ignored because JSDXL has only one text encoder.

        Args:
            prompt: Prompt string or list of strings; may be ``None`` when
                ``prompt_embeds`` is supplied.
            device: Target device; defaults to ``self._execution_device``.
            num_images_per_prompt: Embeddings are repeated this many times.
            do_classifier_free_guidance: Whether to build negative embeddings.
            negative_prompt: Negative prompt string or list of strings.
            prompt_embeds / negative_prompt_embeds: Precomputed sequence
                embeddings that skip the corresponding encoding step.
            pooled_prompt_embeds / negative_pooled_prompt_embeds: Precomputed
                pooled embeddings. NOTE(review): when ``prompt_embeds`` is
                supplied, ``pooled_prompt_embeds`` must be supplied too, or the
                pooled repeat below fails on ``None``.
            lora_scale: Scale applied to the text encoder's LoRA layers.
            clip_skip: Extra encoder layers to skip (positive prompt only);
                ``None`` uses the penultimate layer.

        Returns:
            ``(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
            negative_pooled_prompt_embeds)``. The negative entries stay as
            passed in (default ``None``) when guidance is disabled.
        """
        device = device or self._execution_device

        # Set the LoRA scale so the monkey-patched LoRA function of the text
        # encoder can access it.
        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
            self._lora_scale = lora_scale
            if self.text_encoder is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                else:
                    scale_lora_layers(self.text_encoder, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]

        if prompt_embeds is None:
            # Textual inversion: expand multi-vector tokens if necessary.
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_input_ids = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids
            untruncated_ids = self.tokenizer(
                prompt, padding="longest", return_tensors="pt"
            ).input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                # No module-level `logger` exists in this notebook; fetch the
                # diffusers logger here instead of raising NameError.
                logging.get_logger(__name__).warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            encoder_output = self.text_encoder(
                text_input_ids.to(device), output_hidden_states=True
            )
            # "2" because SDXL always indexes from the penultimate layer.
            layer = -2 if clip_skip is None else -(clip_skip + 2)
            prompt_embeds = encoder_output.hidden_states[layer]

            # The null encoder returns prompt-independent (pooled, sequence)
            # embeddings in place of SDXL's second text encoder.
            pooled_prompt_embeds, null_embeds = self.null_encoder(
                prompt_embeds.size(0), return_dict=False
            )
            # Match both device and dtype of the real encoder output so the
            # concatenation does not mix dtypes.
            null_embeds = null_embeds.to(prompt_embeds.device, prompt_embeds.dtype)
            pooled_prompt_embeds = pooled_prompt_embeds.to(
                prompt_embeds.device, prompt_embeds.dtype
            )
            prompt_embeds = torch.concat([prompt_embeds, null_embeds], dim=-1)

        # Unconditional embeddings for classifier-free guidance.
        zero_out_negative_prompt = (
            negative_prompt is None and self.config.force_zeros_for_empty_prompt
        )
        if (
            do_classifier_free_guidance
            and negative_prompt_embeds is None
            and zero_out_negative_prompt
        ):
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
        elif do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            # Normalize a single string to a per-prompt list.
            if isinstance(negative_prompt, str):
                negative_prompt = batch_size * [negative_prompt]

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            if isinstance(self, TextualInversionLoaderMixin):
                negative_prompt = self.maybe_convert_prompt(
                    negative_prompt, self.tokenizer
                )

            uncond_input_ids = self.tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=prompt_embeds.shape[1],
                truncation=True,
                return_tensors="pt",
            ).input_ids
            negative_prompt_embeds = self.text_encoder(
                uncond_input_ids.to(device),
                output_hidden_states=True,
            ).hidden_states[-2]

            # Zeros stand in for the null encoder on the negative side. They
            # are created on the encoder output's device and dtype, like the
            # positive pooled tensor. `self.device` may differ from the
            # execution device (e.g. under CPU offload), so it is not used.
            neg_bsz, neg_len = negative_prompt_embeds.shape[:2]
            negative_null_embeds = torch.zeros(
                (neg_bsz, neg_len, self.null_encoder.hidden_size),
                device=negative_prompt_embeds.device,
                dtype=negative_prompt_embeds.dtype,
            )
            negative_pooled_prompt_embeds = torch.zeros(
                (neg_bsz, self.null_encoder.hidden_size),
                device=negative_prompt_embeds.device,
                dtype=negative_prompt_embeds.dtype,
            )
            negative_prompt_embeds = torch.concat(
                [negative_prompt_embeds, negative_null_embeds], dim=-1
            )

        # Duplicate embeddings for each image per prompt (mps-friendly
        # repeat + view instead of repeat_interleave).
        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
            bs_embed * num_images_per_prompt, seq_len, -1
        )
        pooled_prompt_embeds = pooled_prompt_embeds.repeat(
            1, num_images_per_prompt
        ).view(bs_embed * num_images_per_prompt, -1)

        if do_classifier_free_guidance:
            neg_seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.to(
                dtype=self.unet.dtype, device=device
            )
            negative_prompt_embeds = negative_prompt_embeds.repeat(
                1, num_images_per_prompt, 1
            ).view(batch_size * num_images_per_prompt, neg_seq_len, -1)
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
                1, num_images_per_prompt
            ).view(bs_embed * num_images_per_prompt, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Restore the original scale by scaling back the LoRA layers.
                unscale_lora_layers(self.text_encoder, lora_scale)

        return (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

    def check_inputs(
        self,
        prompt,
        prompt_2,
        strength,
        num_inference_steps,
        callback_steps,
        negative_prompt=None,
        negative_prompt_2=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        """Validate call arguments, rejecting second-encoder prompts.

        Raises:
            ValueError: if ``prompt_2`` or ``negative_prompt_2`` is given
                (JSDXL has a single text encoder), in addition to whatever the
                parent ``check_inputs`` raises.
        """
        # Explicit raises instead of `assert`, which is stripped under -O.
        if prompt_2 is not None:
            raise ValueError(
                "Japanese Stable Diffusion XL doesn't support `prompt_2` because there's only one text encoder."
            )
        if negative_prompt_2 is not None:
            raise ValueError(
                "Japanese Stable Diffusion XL doesn't support `negative_prompt_2` because there's only one text encoder."
            )
        return super().check_inputs(
            prompt,
            None,
            strength,
            num_inference_steps,
            callback_steps,
            negative_prompt,
            None,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )


# start loading pipeline
from diffusers import DiffusionPipeline


pipeline_type = "txt2img" # @param ["txt2img", "img2img"]
pipeline_id = "stabilityai/japanese-stable-diffusion-xl"

# Load the base (txt2img) pipeline in fp16. trust_remote_code=True allows the
# custom pipeline code hosted in the model repo to be executed.
pipe = DiffusionPipeline.from_pretrained(
    pipeline_id,
    trust_remote_code=True,
    torch_dtype=torch.float16,
)

if pipeline_type == "img2img":
    # Re-wrap the already-loaded components in the img2img pipeline rather
    # than loading the weights a second time.
    pipe = JapaneseStableDiffusionXLImg2ImgPipeline(
        vae=pipe.vae,
        text_encoder=pipe.text_encoder,
        tokenizer=pipe.tokenizer,
        unet=pipe.unet,
        scheduler=pipe.scheduler,
        null_encoder=pipe.null_encoder,
    )

# If using torch < 2.0, enable memory-efficient attention via xformers:
# pipe.enable_xformers_memory_efficient_attention()
pipe.to("cuda")

In [None]:
# @title Launch the demo
import random
import gc
import gradio as gr
from diffusers.utils import make_image_grid


def infer_func(
    prompt,
    scale=7.5,
    steps=40,
    W=1024,
    H=1024,
    n_samples=1,
    seed="random",
    negative_prompt="",
    image=None,
    strength=0.7,
):
    """Run the global ``pipe`` with values coming from the Gradio form.

    Numeric form values may arrive as floats, so they are coerced here.
    ``W``/``H`` are used only in txt2img mode; ``image``/``strength`` only in
    img2img mode (selected by the global ``pipeline_type``).

    Args:
        seed: An integer (or integer string) or ``"random"`` (case- and
            whitespace-insensitive) to draw a fresh 32-bit seed.
        negative_prompt: Empty string or ``None`` means no negative prompt.

    Returns:
        ``(grid, images, info)``: a one-row grid of all images, the list of
        images, and ``{"seed": seed}`` so a run can be reproduced.

    Raises:
        ValueError: in img2img mode when no input image is given, or when
            ``seed`` is not an integer or ``"random"``.
    """
    scale = float(scale)
    steps = int(steps)
    W = int(W)
    H = int(H)
    n_samples = int(n_samples)
    strength = float(strength)

    seed = str(seed).strip()
    if seed.lower() == "random":
        # randint is inclusive on both ends; keep the seed within 32 bits.
        seed = random.randint(0, 2**32 - 1)
    seed = int(seed)

    kwargs = {}
    if pipeline_type == "img2img":
        if image is None:
            raise ValueError("img2img mode requires an input image.")
        kwargs["image"] = image
        kwargs["strength"] = strength
    else:
        kwargs["height"] = H
        kwargs["width"] = W

    try:
        images = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt or None,
            guidance_scale=scale,
            generator=torch.Generator(device="cuda").manual_seed(seed),
            num_images_per_prompt=n_samples,
            num_inference_steps=steps,
            **kwargs,
        ).images
    finally:
        # Release GPU memory even when generation fails (e.g. out of memory).
        gc.collect()
        torch.cuda.empty_cache()

    grid = make_image_grid(images, 1, len(images))
    return grid, images, {"seed": seed}


is_img2img = pipeline_type == "img2img"

with gr.Blocks() as demo:
    gr.Markdown("# Japanese Stable Diffusion XL Demo")
    # The Markdown body must be flush-left: lines indented 4+ spaces inside
    # the string cannot start list items and would merge into the paragraph.
    gr.Markdown(
        "[Japanese Stable Diffusion XL](https://huggingface.co/stabilityai/japanese-stable-diffusion-xl)"
        " is a Japanese-specific SDXL by [Stability AI](https://ja.stability.ai/).\n"
        "\n"
        "- Blog: https://ja.stability.ai/blog/japanese-stable-diffusion-xl\n"
        "- Twitter: https://twitter.com/StabilityAI_JP\n"
        "- Discord: https://discord.com/invite/StableJP"
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="prompt", max_lines=1, value="カラフルなペンギン、アート")
            # Mode-specific inputs: image/strength for img2img, width/height for txt2img.
            image = gr.Image(label="input_image", visible=is_img2img, value=None, type="pil")
            scale = gr.Number(value=7.5, label="cfg_scale")
            strength = gr.Number(value=0.5, label="strength", visible=is_img2img)
            steps = gr.Number(value=40, label="steps")
            width = gr.Number(value=1024, label="width", visible=not is_img2img)
            height = gr.Number(value=1024, label="height", visible=not is_img2img)
            n_samples = gr.Number(value=1, label="n_samples", precision=0, maximum=5)
            seed = gr.Text(value="42", label="seed (integer or 'random')")
            negative_prompt = gr.Textbox(label="negative prompt", value="")
            btn = gr.Button("Run")
        with gr.Column():
            out = gr.Image(label="grid")
            gallery = gr.Gallery(label="Generated images", show_label=False)
            info = gr.JSON(label="sampling_info")
    # Order must match infer_func's positional parameters.
    inputs = [
        prompt,
        scale,
        steps,
        width,
        height,
        n_samples,
        seed,
        negative_prompt,
        image,
        strength,
    ]
    outputs = [out, gallery, info]
    prompt.submit(infer_func, inputs=inputs, outputs=outputs)
    btn.click(infer_func, inputs=inputs, outputs=outputs)

demo.launch(debug=True, share=True, show_error=True)
