# 01 Stable Diffusion v1.5 model exploration

In [13]:
from diffusers import StableDiffusionPipeline
import torch
from matplotlib import pyplot as plt
import cv2
import numpy as np
import torchinfo
import torchview

In [2]:
model_id = "sd-legacy/stable-diffusion-v1-5"

# Pick the best available accelerator instead of hard-coding Apple's "mps",
# so the notebook also runs on CUDA and CPU-only machines.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Half precision saves memory on accelerators; use float32 on CPU, where
# float16 kernels are slow or missing for some operations.
weights_dtype = torch.float32 if device == "cpu" else torch.float16

# NOTE(review): the run log shows a "`torch_dtype` is deprecated" warning; it
# presumably comes from a library that loads one of the sub-models. Confirm
# that this pipeline call accepts `dtype` before renaming the keyword here.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=weights_dtype)
pipe = pipe.to(device)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


In [None]:
prompt = "a photo of an astronaut riding a horse on mars"

# Seed the sampler so that Restart & Run All reproduces the same image.
# A CPU generator is used so the seed does not depend on which accelerator
# the pipeline runs on.
seed = 0
generator = torch.Generator(device="cpu").manual_seed(seed)

image_pil = pipe(prompt, generator=generator).images[0]

In [None]:
# Show the generated image on its own axes, without ticks or frame.
fig, ax = plt.subplots()
ax.imshow(image_pil)
ax.set_axis_off()
plt.show()

## Exploring the architecture

In [12]:
# Layer-by-layer parameter summary of the UNet, nested five levels deep.
unet_summary = torchinfo.summary(pipe.unet, depth=5)
unet_summary

Layer (type:depth-idx)                                            Param #
UNet2DConditionModel                                              --
├─Conv2d: 1-1                                                     11,840
├─Timesteps: 1-2                                                  --
├─TimestepEmbedding: 1-3                                          --
│    └─Linear: 2-1                                                410,880
│    └─SiLU: 2-2                                                  --
│    └─Linear: 2-3                                                1,639,680
├─ModuleList: 1-4                                                 --
│    └─CrossAttnDownBlock2D: 2-4                                  --
│    │    └─ModuleList: 3-1                                       --
│    │    │    └─Transformer2DModel: 4-1                          --
│    │    │    │    └─GroupNorm: 5-1                              640
│    │    │    │    └─Conv2d: 5-2                                 102,720
│    │ 

The down blocks contain cross-attention layers as well (see `CrossAttnDownBlock2D` → `Transformer2DModel` in the summary above).

In [20]:
# Trace the UNet once with dummy inputs and render its module graph to PNG.
# `torch` and `torchview` come from the import cell at the top of the notebook.

# 1. Derive the input shapes from the pipeline's own configuration
batch_size = 1
height = pipe.unet.config.sample_size
width = pipe.unet.config.sample_size
# Sequence length of the text embeddings = the tokenizer's maximum length
# (77 for the CLIP tokenizer used by Stable Diffusion v1.5).
prompt_length = pipe.tokenizer.model_max_length
latents_shape = (batch_size, pipe.unet.config.in_channels, height, width)
prompt_shape = (batch_size, prompt_length, pipe.unet.config.cross_attention_dim)

# 2. Create dummy inputs that match the UNet's dtype and device
unet_dtype = pipe.unet.dtype
unet_device = pipe.unet.device
dummy_latents = torch.randn(latents_shape, dtype=unet_dtype, device=unet_device)
dummy_context = torch.randn(prompt_shape, dtype=unet_dtype, device=unet_device)

# The timestep stays an integer tensor; the model embeds it internally.
dummy_timestep = torch.tensor([1], device=unet_device)

# 3. Draw the graph
# `device` is passed explicitly so torchview traces the model where it already
# lives instead of choosing a device on its own. `save_graph` is off because
# the graph is rendered exactly once by the explicit `render` call below.
graph = torchview.draw_graph(
    pipe.unet,
    input_data=(dummy_latents, dummy_timestep, dummy_context),
    expand_nested=True,
    depth=2,
    device=unet_device,
    save_graph=False,
    filename="unet_architecture",
)

graph.visual_graph.render(format="png")
print("Graph rendered successfully.")

Graph rendered successfully.


In [21]:
from graphviz_anywidget import graphviz_widget

# Hand the DOT source of the UNet graph to the notebook widget for display.
unet_dot_source = graph.visual_graph.source
graphviz_widget(unet_dot_source)

