<a href="https://colab.research.google.com/github/mkshing/notebooks/blob/main/stable_cascade.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
# Stable Cascade Demo
This notebook is a demo of Stable Cascade (a.k.a. Würstchen v3) from Stability AI that runs on the Colab free plan.


This was made by [mkshing](https://twitter.com/mk1stats).

Visit the following links for the details of the model.
- Blog: https://stability.ai/news/introducing-stable-cascade
- Paper: https://openreview.net/forum?id=gU58d5QeGv
- Code: https://github.com/Stability-AI/StableCascade
- HF: https://huggingface.co/stabilityai/stable-cascade
- License: [STABILITY AI NON-COMMERCIAL RESEARCH COMMUNITY LICENSE AGREEMENT](https://github.com/Stability-AI/StableCascade/blob/master/LICENSE)


*Please remember that this model was released under a non-commercial license.*


## Updates
### 2024.2.14
* Model release. Congratulations on the release to the team! (https://twitter.com/dome_271/status/1757427041563967512)

In [None]:
#@title Setup
# Report the GPU (if any) visible to this runtime.
!nvidia-smi
# Install diffusers from the `wuerstchen-v3` branch of kashif's fork; the
# StableCascade* pipelines imported in the next cell are expected to come from it.
!pip install git+https://github.com/kashif/diffusers.git@wuerstchen-v3
# Upgrade accelerate, torch and torchvision to their latest releases.
!pip install -U accelerate torch torchvision
# Pin Gradio to the 4.17.0 release used by the UI cell.
!pip install gradio==4.17.0

In [None]:
#@title Load models
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
# Every pipeline call in this notebook targets the CUDA device.
device = torch.device("cuda")

# Load both pipelines from the Hugging Face Hub. The prior is loaded in
# bfloat16 and the decoder in float16, so embeddings have to be cast to the
# decoder's dtype before they are handed over.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
)
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16,
)

# The prior is kept resident on the GPU. Alternative (disabled): offload it too
# with `prior.enable_model_cpu_offload()` instead of the `.to(device)` below.
prior.to(device)

# The decoder uses diffusers' model CPU offload rather than being moved to the
# GPU as a whole.
decoder.enable_model_cpu_offload()

# Optional (disabled): compile the denoising networks with torch.compile.
# prior.prior = torch.compile(prior.prior, mode="reduce-overhead", fullgraph=True)
# decoder.decoder = torch.compile(decoder.decoder, mode="max-autotune", fullgraph=True)


In [None]:
#@title Run!
# original code: https://huggingface.co/spaces/multimodalart/stable-cascade
import random
import gc
import numpy as np
import gradio as gr

# Largest allowed seed: the maximum value of a signed 32-bit integer. Used as
# the upper bound for random seed draws and for the "Seed" slider.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound of the "Width" and "Height" sliders.
MAX_IMAGE_SIZE = 1536


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed for a generation run.

    Args:
        seed: Seed requested by the caller.
        randomize_seed: When truthy, ``seed`` is ignored and a fresh value is
            drawn uniformly from ``[0, MAX_SEED]`` (both ends inclusive).

    Returns:
        The seed to use: either the caller's ``seed`` or the random draw.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed

def generate_prior(prompt, negative_prompt, generator, width, height, num_inference_steps, guidance_scale, num_images_per_prompt):
    """Run the prior pipeline and return its image embeddings.

    Args:
        prompt: Text prompt to condition on.
        negative_prompt: Text the result should steer away from.
        generator: Seeded ``torch.Generator`` that drives the sampling noise.
        width: Requested image width, forwarded to the prior.
        height: Requested image height, forwarded to the prior.
        num_inference_steps: Number of denoising steps for the prior.
        guidance_scale: Guidance scale for the prior.
        num_images_per_prompt: How many images to produce for the prompt.

    Returns:
        The ``image_embeddings`` attribute of the prior pipeline's output.
    """
    prior_output = prior(
        prompt=prompt,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=num_inference_steps,
        # Forward the seeded generator: without it the prior samples from the
        # global RNG and the user-selected seed has no effect on this stage,
        # so identical seeds would not reproduce the same images.
        generator=generator,
    )
    # Hand cached GPU memory back and collect garbage before the next stage.
    torch.cuda.empty_cache()
    gc.collect()
    return prior_output.image_embeddings


def generate_decoder(prior_embeds, prompt, negative_prompt, generator, num_inference_steps, guidance_scale):
    """Run the decoder pipeline on prior embeddings and return its images.

    Args:
        prior_embeds: Image embeddings produced by the prior; they are cast to
            ``device`` and to the decoder's dtype before being passed on.
        prompt: Text prompt to condition on.
        negative_prompt: Text the result should steer away from.
        generator: Seeded ``torch.Generator`` that drives the sampling noise.
        num_inference_steps: Number of denoising steps for the decoder.
        guidance_scale: Guidance scale for the decoder.

    Returns:
        The ``images`` attribute of the decoder output, requested with
        ``output_type="pil"``.
    """
    # The cast happens inline so the temporary tensor is not kept alive by a
    # local name while the CUDA cache is emptied below.
    images = decoder(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image_embeddings=prior_embeds.to(device=device, dtype=decoder.dtype),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        output_type="pil",
    ).images
    torch.cuda.empty_cache()
    gc.collect()
    return images


@torch.inference_mode()
def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    randomize_seed: bool = True,
    width: int = 1024,
    height: int = 1024,
    prior_num_inference_steps: int = 30,
    prior_guidance_scale: float = 4.0,
    decoder_num_inference_steps: int = 12,
    decoder_guidance_scale: float = 0.0,
    num_images_per_prompt: int = 2,
):
    """Generate images with Stable Cascade (prior stage, then decoder stage).

    The effective seed is resolved via ``randomize_seed_fn``, printed, and used
    to seed a single ``torch.Generator`` on ``device`` that is handed to both
    stages.

    Args:
        prompt: Text prompt to render.
        negative_prompt: Text the result should steer away from.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: Replace ``seed`` with a random one when true.
        width: Requested image width.
        height: Requested image height.
        prior_num_inference_steps: Denoising steps for the prior stage.
        prior_guidance_scale: Guidance scale for the prior stage.
        decoder_num_inference_steps: Denoising steps for the decoder stage.
        decoder_guidance_scale: Guidance scale for the decoder stage.
        num_images_per_prompt: Number of images to generate.

    Returns:
        Whatever ``generate_decoder`` returns (the decoded images).
    """
    active_seed = randomize_seed_fn(seed, randomize_seed)
    print("seed:", active_seed)

    # One generator shared by both stages; manual_seed() seeds it in place.
    rng = torch.Generator(device=device)
    rng.manual_seed(active_seed)

    image_embeddings = generate_prior(
        prompt=prompt,
        negative_prompt=negative_prompt,
        generator=rng,
        width=width,
        height=height,
        num_inference_steps=prior_num_inference_steps,
        guidance_scale=prior_guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
    )

    return generate_decoder(
        prior_embeds=image_embeddings,
        prompt=prompt,
        negative_prompt=negative_prompt,
        generator=rng,
        num_inference_steps=decoder_num_inference_steps,
        guidance_scale=decoder_guidance_scale,
    )


# Example prompts passed to `gr.Examples` as sample inputs for the prompt box.
# These strings are sent to the model verbatim, so spelling matters.
examples = [
    "An astronaut riding a green horse",
    "A mecha robot in a favela by Tarsila do Amaral",
    "The spirit of a Tamagotchi wandering in the city of Los Angeles",
    "A delicious feijoada ramen dish",
]

with gr.Blocks() as demo:
    # First column: prompt box, run button and the collapsed advanced settings.
    with gr.Column():
        prompt = gr.Text(label="Prompt", show_label=False, placeholder="Enter your prompt")
        run_button = gr.Button("Run")

        with gr.Accordion("Advanced options", open=False):
            negative_prompt = gr.Text(label="Negative prompt", max_lines=1, placeholder="Enter a Negative Prompt")

            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            # With step=512 the only selectable sizes are 1024 and MAX_IMAGE_SIZE.
            width = gr.Slider(label="Width", minimum=1024, maximum=MAX_IMAGE_SIZE, step=512, value=1024)
            height = gr.Slider(label="Height", minimum=1024, maximum=MAX_IMAGE_SIZE, step=512, value=1024)

            num_images_per_prompt = gr.Slider(label="Number of Images", minimum=1, maximum=2, step=1, value=1)

            prior_guidance_scale = gr.Slider(label="Prior Guidance Scale", minimum=0, maximum=20, step=0.1, value=4.0)
            prior_num_inference_steps = gr.Slider(label="Prior Inference Steps", minimum=10, maximum=30, step=1, value=20)

            # minimum == maximum == 0, so this slider is effectively fixed at 0.0.
            decoder_guidance_scale = gr.Slider(label="Decoder Guidance Scale", minimum=0, maximum=0, step=0.1, value=0.0)
            decoder_num_inference_steps = gr.Slider(label="Decoder Inference Steps", minimum=4, maximum=12, step=1, value=10)

    # Second column: gallery that receives the generated images.
    with gr.Column():
        result = gr.Gallery(label="Result", show_label=False)

    gr.Examples(examples=examples, inputs=prompt, outputs=result, fn=generate)

    # Order matters: it mirrors the positional parameter order of `generate`.
    inputs = [
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        prior_num_inference_steps,
        prior_guidance_scale,
        decoder_num_inference_steps,
        decoder_guidance_scale,
        num_images_per_prompt,
    ]

    # Submitting either text box or clicking "Run" triggers the same generation.
    for trigger in (prompt.submit, negative_prompt.submit, run_button.click):
        trigger(fn=generate, inputs=inputs, outputs=result)

    # NOTE(review): launch() is called while still inside the Blocks context;
    # Gradio's own examples call it after the `with` block. Kept as-is here.
    demo.launch(share=True, debug=True, show_error=True)

