In [1]:
%load_ext autoreload
%autoreload 2

# load pipeline
from diffusers.models import ControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
import torch

from text3d2video.generative_rendering.configs import GenerativeRenderingConfig
from dataclasses import replace
from text3d2video.generative_rendering.generative_rendering_pipeline import (
    GenerativeRenderingPipeline,
)
from text3d2video.generative_rendering.generative_rendering_attn import (
    GenerativeRenderingAttn,
)
from text3d2video.generative_rendering.generative_rendering_attn import GrAttentionMode

# fp16 weights on CUDA; later cells also assume a GPU is available.
device = torch.device("cuda")
dtype = torch.float16

sd_repo = "runwayml/stable-diffusion-v1-5"
controlnet_repo = "lllyasviel/control_v11f1p_sd15_depth"

controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
controlnet = controlnet.to(device)

# attn1 modules targeted by Generative Rendering, in this order:
#   down_blocks 0..2 x attentions 0..1  (6 paths)
#   mid_block attentions 0              (1 path)
#   up_blocks 1..3 x attentions 0..2    (9 paths)
module_paths = (
    [
        f"down_blocks.{block}.attentions.{idx}.transformer_blocks.0.attn1"
        for block in range(3)
        for idx in range(2)
    ]
    + ["mid_block.attentions.0.transformer_blocks.0.attn1"]
    + [
        f"up_blocks.{block}.attentions.{idx}.transformer_blocks.0.attn1"
        for block in range(1, 4)
        for idx in range(3)
    ]
)

gr_config = GenerativeRenderingConfig(
    seed=0,
    resolution=512,
    do_pre_attn_injection=True,
    do_post_attn_injection=False,
    feature_blend_alpha=0.5,
    attend_to_self_kv=False,
    mean_features_weight=0.5,
    chunk_size=5,
    num_keyframes=4,
    num_inference_steps=10,
    guidance_scale=7.5,
    controlnet_conditioning_scale=1.0,
    module_paths=module_paths,
)

pipe: GenerativeRenderingPipeline = GenerativeRenderingPipeline.from_pretrained(
    sd_repo,
    controlnet=controlnet,
    torch_dtype=dtype,
    gr_config=gr_config,
)
pipe = pipe.to(device)

# Replace the pipeline's scheduler with a multistep DPM-Solver built from the
# existing scheduler's config.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

In [2]:
# The single attention layer exercised by the checks below. The "enabled layer"
# check expects features to be recorded for it, and the "disabled layer" check
# shows nothing is recorded for layers outside module_paths. So this path must
# be one of the configured module paths.
attn_path = "down_blocks.0.attentions.0.transformer_blocks.0.attn1"
assert attn_path in gr_config.module_paths, f"{attn_path} not in gr_config.module_paths"
attn = pipe.unet.get_submodule(attn_path)

In [14]:
# feature extraction, enabled layer
# `attn` is in gr_config.module_paths, so one forward pass in extraction mode
# should record exactly one pre-attention and one post-attention feature.

H = 64  # side of the square token grid (T = H * H); post features are (.., H, H)
d = attn.to_q.in_features
T = H * H  # tokens per frame
B = 2  # leading batch dim of extracted features; presumably CFG uncond/cond — confirm
F = 5  # frames; x stacks B * F rows

processor = GenerativeRenderingAttn(pipe.unet, gr_config)
processor.gr_attn_mode = GrAttentionMode.FEATURE_EXTRACTION
attn.processor = processor

# Seeded for reproducibility and allocated directly on `device` in `dtype`
# (avoids an fp32 CPU tensor followed by a copy + cast).
generator = torch.Generator(device=device).manual_seed(gr_config.seed)
x = torch.randn(B * F, T, d, generator=generator, device=device, dtype=dtype)

with torch.no_grad():  # inference-only check; no autograd graph needed
    y = attn(x)

# Check counts before indexing so a missing feature fails with a clear message
# instead of an IndexError.
assert len(processor.pre_attn_features) == 1, list(processor.pre_attn_features)
assert len(processor.post_attn_features) == 1, list(processor.post_attn_features)

pre_attn_features = next(iter(processor.pre_attn_features.values()))
post_attn_features = next(iter(processor.post_attn_features.values()))

assert pre_attn_features.shape == (B, F * T, d), pre_attn_features.shape
assert post_attn_features.shape == (B, F, d, H, H), post_attn_features.shape

In [34]:
# feature extraction, disabled layer
# With no module paths configured, extraction should record nothing for `attn`.
# Reuses the input `x` built in the enabled-layer cell above.

# Build the modified copy in one step rather than copying and then mutating it
# (this form also works if the config dataclass is frozen).
disabled_config = replace(gr_config, module_paths=[])

processor = GenerativeRenderingAttn(pipe.unet, disabled_config)
processor.gr_attn_mode = GrAttentionMode.FEATURE_EXTRACTION
attn.processor = processor

with torch.no_grad():  # inference-only check; no autograd graph needed
    y = attn(x)

assert processor.pre_attn_features == {}, list(processor.pre_attn_features)
assert processor.post_attn_features == {}, list(processor.post_attn_features)

{}
{}


In [37]:
# feature injection
# Run `attn` in injection mode and check the output keeps the input's shape.
# NOTE(review): this processor is freshly constructed and has not extracted any
# features; confirm the injection path actually does work here rather than
# falling back to plain attention.
processor = GenerativeRenderingAttn(pipe.unet, gr_config)
processor.gr_attn_mode = GrAttentionMode.FEATURE_INJECTION
attn.processor = processor

with torch.no_grad():  # inference-only check; no autograd graph needed
    y = attn(x)

assert y.shape == x.shape, (y.shape, x.shape)
y.shape

torch.Size([10, 4096, 320])

True