In [1]:
# Automatically reload edited project modules (e.g. text3d2video) before each cell runs,
# so changes to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Imports: diffusers model/scheduler classes, torch, and the project's
# generative-rendering pipeline (the pipeline itself is loaded further below).
from diffusers.models import ControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
import torch

from text3d2video.generative_rendering.generative_rendering_pipeline import (
    GenerativeRenderingPipeline,
)

# Dotted UNet module paths of the first transformer block's "attn1" module in each
# attention stage. The list covers:
#   - down_blocks 0-2, attentions 0-1 each
#   - mid_block, attention 0
#   - up_blocks 1-3, attentions 0-2 each
# (up_blocks.0 is not listed; presumably it has no attention modules in this UNet — confirm.)
# Order is down -> mid -> up, block index outer, attention index inner.
module_paths = (
    [
        f"down_blocks.{block}.attentions.{attn}.transformer_blocks.0.attn1"
        for block in range(3)
        for attn in range(2)
    ]
    + ["mid_block.attentions.0.transformer_blocks.0.attn1"]
    + [
        f"up_blocks.{block}.attentions.{attn}.transformer_blocks.0.attn1"
        for block in range(1, 4)
        for attn in range(3)
    ]
)

# Run everything on the GPU in half precision.
dtype = torch.float16
device = torch.device("cuda")

# Hugging Face Hub repositories for the base model and the ControlNet
# (a depth ControlNet, judging by the repo name).
sd_repo = "runwayml/stable-diffusion-v1-5"
controlnet_repo = "lllyasviel/control_v11f1p_sd15_depth"

# Load the ControlNet weights in `dtype` and move them to `device`.
controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
controlnet = controlnet.to(device)

# Build the project's pipeline from the base model, attaching the ControlNet.
pipe: GenerativeRenderingPipeline = GenerativeRenderingPipeline.from_pretrained(
    sd_repo,
    controlnet=controlnet,
    torch_dtype=dtype,
)
pipe = pipe.to(device)

# Replace the pipeline's default scheduler with DPMSolverMultistepScheduler,
# constructed from the existing scheduler's config.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config
)

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]



In [2]:
from text3d2video.sd_feature_extraction import AttnLayerId

# Round-trip sanity check: parsing a module path into an AttnLayerId and rendering
# it back with module_path() should reproduce the original path exactly.
# Uses the first down-block path from `module_paths`.
module = next(m for m in module_paths if m.startswith("down"))

mod_id = AttnLayerId.parse_module_path(module)
print(module)
print(mod_id.module_path())

# Make the check self-verifying rather than relying on eyeballing the two lines above.
assert mod_id.module_path() == module, (module, mod_id.module_path())

down_blocks.0.attentions.0.transformer_blocks.0.attn1
down_blocks.0.attentions.0.transformer_blocks.0.attn1


In [3]:
# For each hooked attention module, print three values:
#   - its block type (down_blocks / mid_block / up_blocks)
#   - its UNet-wide index (0..15 in the output below)
#   - its index within its own path (restarts at 0 for mid_block and for up_blocks)
for module in module_paths:
    layer = AttnLayerId.parse_module_path(module)
    print(
        layer.block_type.value,
        layer.unet_absolute_index(),
        layer.unet_path_index(),
    )


down_blocks 0 0
down_blocks 1 1
down_blocks 2 2
down_blocks 3 3
down_blocks 4 4
down_blocks 5 5
mid_block 6 0
up_blocks 7 0
up_blocks 8 1
up_blocks 9 2
up_blocks 10 3
up_blocks 11 4
up_blocks 12 5
up_blocks 13 6
up_blocks 14 7
up_blocks 15 8
