In [1]:
import torch
import time
import traceback
from diffusers import DiffusionPipeline
from diffusers import EulerAncestralDiscreteScheduler

In [2]:
# Show the name of the current CUDA device together with the CUDA and cuDNN
# versions reported by this PyTorch build.
(
    torch.cuda.get_device_name(torch.cuda.current_device()),
    torch.version.cuda,
    torch.backends.cudnn.version(),
)

('NVIDIA L4', '11.8', 8700)

In [3]:
# Evaluating this attribute displays the function object; an AttributeError
# here would mean the installed torch build does not provide
# scaled_dot_product_attention.
torch.nn.functional.scaled_dot_product_attention

<function torch._C._nn.scaled_dot_product_attention>

In [4]:
# Load Stable Diffusion v1.5 with float16 weights, replace its default
# scheduler with Euler Ancestral (built from the default scheduler's config),
# then move the pipeline to the GPU.
model_id = "runwayml/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)

euler_ancestral = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.scheduler = euler_ancestral

pipe = pipe.to("cuda")
# Attention slicing is intentionally left disabled:
# pipe.enable_attention_slicing()

In [5]:
# Turn on xFormers memory-efficient attention for the pipeline.
# NOTE(review): presumably requires the `xformers` package to be installed — confirm.
pipe.enable_xformers_memory_efficient_attention()

In [14]:
# Print the help text (constructor signature and docstring) of the loaded pipeline object.
# NOTE(review): exploratory output only; no visible cell depends on it.
help(pipe)

Help on StableDiffusionPipeline in module diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion object:

class StableDiffusionPipeline(diffusers.pipelines.pipeline_utils.DiffusionPipeline, diffusers.loaders.TextualInversionLoaderMixin, diffusers.loaders.LoraLoaderMixin, diffusers.loaders.FromCkptMixin)
 |  StableDiffusionPipeline(vae: diffusers.models.autoencoder_kl.AutoencoderKL, text_encoder: transformers.models.clip.modeling_clip.CLIPTextModel, tokenizer: transformers.models.clip.tokenization_clip.CLIPTokenizer, unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, scheduler: diffusers.schedulers.scheduling_utils.KarrasDiffusionSchedulers, safety_checker: diffusers.pipelines.stable_diffusion.safety_checker.StableDiffusionSafetyChecker, feature_extractor: transformers.models.clip.image_processing_clip.CLIPImageProcessor, requires_safety_checker: bool = True)
 |  
 |  Pipeline for text-to-image generation using Stable Diffusion.
 |  
 |  This model inherits from [`Di

In [20]:
# Benchmark settings shared by the prewarm and benchmark cells.

# Powers of two from 1 to 128.
batch_size_list = [1 << exponent for exponent in range(8)]
steps = 50  # passed to the pipeline as num_inference_steps
cfg_scale = 15  # passed to the pipeline as guidance_scale

prompt = (
    "postapocalyptic steampunk city, exploration, cinematic, realistic, "
    "hyper detailed, photorealistic maximum detail, volumetric light, "
    "(((focus))), wide-angle, (((brightly lit))), (((vegetation))), "
    "lightning, vines, destruction, devastation, wartorn, ruins"
)
# Alternative prompt, kept for reference:
# prompt = "detailed portrait beautiful Neon Operator Girl, cyberpunk futuristic neon, reflective puffy coat, decorated with traditional Japanese ornaments by Ismail inceoglu dragan bibin hans thoma greg rutkowski Alexandros Pyromallis Nekro Rene Maritte Illustrated, Perfect face, fine details, realistic shaded, fine-face, pretty face"
negative_prompt = (
    "(((blurry))), ((foggy)), (((dark))), ((monochrome)), "
    "sun, (((depth of field)))"
)

In [9]:
# Replace the pipeline's UNet with a torch.compile-wrapped version of itself.
# NOTE(review): this reads and reassigns pipe.unet, so re-running the cell
# passes the already-compiled module back into torch.compile — run it once
# per kernel.
# NOTE(review): the execution counter shows this cell ran as In [9], before
# the config cell above it (In [20]); confirm the notebook survives
# Restart & Run All in top-to-bottom order.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

In [16]:
# prewarm
# Run the pipeline once before benchmarking so one-off first-call cost is not
# attributed to the measured runs.
# NOTE(review): the recorded benchmark output still shows slow first runs at
# batch sizes other than this one, so this appears to warm only batch_size 4.
batch_size = 4
t0 = time.perf_counter()  # monotonic clock, suited to measuring intervals
_ = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    # Use the shared `steps` setting so the number of steps actually run
    # always matches the value used in the throughput formula below.
    num_inference_steps=steps,
    num_images_per_prompt=batch_size,
    guidance_scale=cfg_scale,
    height=512,
    width=512,
    ).images
t1 = time.perf_counter()
# Throughput is reported as (denoising steps x images) per second.
its = steps * batch_size / (t1 - t0)
print("batch_size {}, it/s: {}, time: {}".format(batch_size, round(its, 2), round((t1 - t0), 2)))

  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 4, it/s: 18.83, time: 10.62


In [23]:
# Every batch size is listed twice, back to back; the benchmark cell keeps
# only the second run of each pair (result[1::2]).
batch_size_list = [size for size in (1, 2, 4, 8, 16, 32, 64, 128) for _run in range(2)]

In [24]:
# Benchmark: run the pipeline once per entry of batch_size_list and record the
# throughput, measured as (denoising steps x images) per second, of each run.
result = []
for batch_size in batch_size_list:
    # Default to 0 so that a failed run is recorded as "no throughput" rather
    # than re-using the previous iteration's value (or raising NameError when
    # the very first run fails).
    its = 0
    try:
        t0 = time.perf_counter()  # monotonic clock, suited to measuring intervals
        images = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            num_images_per_prompt=batch_size,
            guidance_scale=cfg_scale,
            height=512,
            width=512,
            ).images
        t1 = time.perf_counter()
        its = steps * batch_size / (t1 - t0)
        print("batch_size {}, it/s: {}, time: {}".format(batch_size, round(its, 2), round((t1 - t0), 2)))
    except torch.cuda.OutOfMemoryError:
        # Out of GPU memory at this batch size: report it and keep its == 0.
        print("batch_size {}, OOM".format(batch_size))
    except Exception:
        # print_exc() writes the traceback itself and returns None, so it must
        # not be wrapped in print() (that would emit a stray "None").
        traceback.print_exc()
    result.append(round(its, 2))
# Keep the odd-indexed entries: the second run of each pair when
# batch_size_list lists every batch size twice.
result_jit = result[1::2]

  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 1, it/s: 15.61, time: 3.2


  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 1, it/s: 15.63, time: 3.2


  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 2, it/s: 0.78, time: 128.48


  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 2, it/s: 17.4, time: 5.75


  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 4, it/s: 17.65, time: 11.33


  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 4, it/s: 17.61, time: 11.36


  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 8, it/s: 3.97, time: 100.66


  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 8, it/s: 17.26, time: 23.17


  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 16, it/s: 6.98, time: 114.54


  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 16, it/s: 17.18, time: 46.56


  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 32, it/s: 10.69, time: 149.61


  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 32, it/s: 15.53, time: 103.0


  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 64, OOM


  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 64, OOM


  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 128, OOM


  0%|          | 0/50 [00:00<?, ?it/s]

batch_size 128, OOM


In [25]:
# Display the throughput kept for the second run at each batch size
# (0 where the run was recorded as OOM).
result_jit

[15.63, 17.4, 17.61, 17.26, 17.18, 15.53, 0, 0]

In [None]:
# Preview the first image in `images`, which holds the output of the most
# recent pipeline call in the benchmark loop that completed successfully.
images[0]