In [45]:
%load_ext autoreload
%autoreload 2

# load pipeline
from diffusers.models import ControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
import torch

from text3d2video.generative_rendering.configs import GenerativeRenderingConfig
from text3d2video.generative_rendering.generative_rendering_pipeline import (
    GenerativeRenderingPipeline,
)
from text3d2video.generative_rendering.generative_rendering_attn import (
    GenerativeRenderingAttn,
)

# Everything in this notebook is placed on the CUDA device in half precision.
dtype = torch.float16
device = torch.device("cuda")

# Repository ids handed to from_pretrained for the base model and the ControlNet.
sd_repo = "runwayml/stable-diffusion-v1-5"
controlnet_repo = "lllyasviel/control_v11f1p_sd15_depth"

# Load the ControlNet weights in the chosen dtype, then move them to the device.
controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
controlnet = controlnet.to(device)

# Dotted paths of the `attn1` sub-module of the first transformer block in each
# attention layer. They are passed to GenerativeRenderingConfig below and later
# resolved against pipe.unet.
# Order matters: down blocks 0-2 (2 attention layers each), the mid block,
# then up blocks 1-3 (3 attention layers each).
module_paths = [
    *(
        f"down_blocks.{block}.attentions.{layer}.transformer_blocks.0.attn1"
        for block in range(3)
        for layer in range(2)
    ),
    "mid_block.attentions.0.transformer_blocks.0.attn1",
    *(
        f"up_blocks.{block}.attentions.{layer}.transformer_blocks.0.attn1"
        for block in range(1, 4)
        for layer in range(3)
    ),
]

# Settings object for the generative-rendering run. It is handed to the
# attention processor in a later cell; the layer list built above is passed
# through unchanged as `module_paths`.
gr_config = GenerativeRenderingConfig(
    seed=0,
    resolution=512,
    do_pre_attn_injection=True,
    do_post_attn_injection=True,
    feature_blend_alpha=0.5,
    attend_to_self_kv=False,
    mean_features_weight=0.5,
    chunk_size=5,
    num_keyframes=4,
    num_inference_steps=10,
    guidance_scale=7.5,
    controlnet_conditioning_scale=1.0,
    module_paths=module_paths,
)

# Build the pipeline from the base checkpoint with the ControlNet attached,
# in the notebook-wide dtype, then move it to the target device.
pipe: GenerativeRenderingPipeline = GenerativeRenderingPipeline.from_pretrained(
    sd_repo,
    controlnet=controlnet,
    torch_dtype=dtype,
)
pipe = pipe.to(device)

# Replace the default scheduler with DPM-Solver multistep, reusing the
# configuration of the scheduler that was loaded with the pipeline.
scheduler_config = pipe.scheduler.config
pipe.scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

In [47]:
from text3d2video.sd_feature_extraction import get_module_from_path

# Probe a single layer: take the first configured path and look the module up
# on the pipeline's UNet.
unet = pipe.unet
module_path = module_paths[0]
attn = get_module_from_path(unet, module_path)

In [48]:
# feature extraction, enabled layer

from text3d2video.artifacts.gr_data import GrDataArtifact
from text3d2video.generative_rendering.configs import GrSaveConfig
from text3d2video.generative_rendering.generative_rendering_attn import GrAttnMode

# create input feature map
H = 64  # T below is H * H
d = attn.to_q.in_features  # input width of the layer's query projection
T = H * H
# presumably batch size (B) and frames per batch (F) — TODO confirm
B = 2
F = 5

# Seeded, device-local generator so the random input is the same on every run.
# The tensor is created directly on `device` in `dtype`, matching how the
# models above are placed, instead of on the CPU followed by .cuda().to(dtype).
generator = torch.Generator(device=device).manual_seed(0)
x = torch.randn(B * F, T, d, generator=generator, device=device, dtype=dtype)

# Only queries are requested for saving, and only for the probed layer.
gr_save_config = GrSaveConfig(
    enabled=True,
    save_latents=False,
    save_q=True,
    save_k=False,
    save_v=False,
    save_features=False,
    save_features_3d=False,
    n_frames=5,
    n_timesteps=5,
    out_artifact="test",
    module_paths=[module_path],
)

# create GrData object
gr_data = GrDataArtifact.init_from_cfg(gr_save_config)
gr_data.begin_recording()

# Everything between begin_recording() and end_recording() runs inside
# try/finally so the recording is always closed, even if the processor setup
# or the forward pass raises.
try:
    # Create GR processor
    processor = GenerativeRenderingAttn(
        pipe.unet,
        gr_config,
    )
    processor.gr_config = gr_config

    # set mode to feature extraction
    processor.mode = GrAttnMode.FEATURE_EXTRACTION
    # This replaces the layer's processor on pipe.unet itself, so the change
    # persists for later cells.
    attn.processor = processor

    # NOTE(review): gr_data is never handed to the processor in this cell, and
    # the recorded output of print_datasets() lists only empty `frame_indices`
    # and `timesteps` datasets even though save_q=True. Confirm how the
    # processor is expected to reach the artifact.

    # forward pass; no gradients are needed for this check, so skip building
    # the autograd graph
    with torch.no_grad():
        y = attn(x)
finally:
    gr_data.end_recording()

In [49]:
# Print the datasets held by the artifact after the recording was closed.
gr_data.print_datasets()

frame_indices (0,)
timesteps (0,)
