In [4]:
%load_ext autoreload
%autoreload 2
import torch
import wandb

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from misc.cross_frame_attn import CrossFrameAttnProcessor

device = torch.device('cuda')
torch_dtype = torch.float16
sd_repo = "runwayml/stable-diffusion-v1-5"
controlnet_repo = "lllyasviel/sd-controlnet-depth"

controlnet = ControlNetModel.from_pretrained(
    controlnet_repo,
    torch_dtype=torch_dtype,
    device=device
)

pipe = StableDiffusionControlNetPipeline.from_pretrained(
    sd_repo,
    controlnet=controlnet,
    device=device,
    torch_dtype=torch_dtype,
    safety_checker=None
)
pipe.enable_model_cpu_offload(gpu_id=0)

cross_frame_processor = CrossFrameAttnProcessor(unet_chunk_size=2)
pipe.unet.set_attn_processor(processor=cross_frame_processor)
pipe.controlnet.set_attn_processor(processor=cross_frame_processor)

  from .autonotebook import tqdm as notebook_tqdm
  warn(
Keyword arguments {'device': device(type='cuda')} are not expected by StableDiffusionControlNetPipeline and will be ignored.
Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00,  6.48it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


In [3]:
from text3d2video.artifacts.animation_artifact import AnimationArtifact
from text3d2video.util import front_camera
import wandb

api = wandb.Api()
animation = 'backflip:latest'
animation = api.artifact(f'romeu/diffusion-3D-features/{animation}')
animation = AnimationArtifact.from_wandb_artifact(animation)

batch_size = 7
frames = animation.load_frames(range(1, batch_size + 1))
camera = front_camera()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Downloading large artifact backflip:latest, 121.09MB. 61 files... 
[34m[1mwandb[0m:   61 of 61 files downloaded.  
Done. 0:0:0.4


In [4]:
from text3d2video.rendering import make_rasterizer, normalize_depth_map
import torchvision.transforms.functional as TF

rasterizer = make_rasterizer(camera, 512)
fragments = rasterizer(frames)
depth_map_normalized = normalize_depth_map(fragments.zbuf)
depth_ims = [TF.to_pil_image(depth_map_normalized[i, :, :, 0].cpu()) for i in range(batch_size)]

In [34]:
from scripts.generate_video import generate_video
from text3d2video.artifacts.video_artifact import VideoArtifact
from text3d2video.ipython_utils import display_ims

wandb.init(project='diffusion-3D-features', job_type='generate_video')

name = 'spiderman'
prompt = "Spiderman doing a backflip"

wandb.log({'prompt': prompt})
wandb.use_artifact(animation.wandb_artifact)

ims = generate_video(pipe, prompt, depth_ims)

video = VideoArtifact.create_wandb_artifact(f'{name}-cross-frame', frames=ims, fps=30)
wandb.log_artifact(video)

0,1
prompt,Stormtrooper doing a...


100%|██████████| 50/50 [00:31<00:00,  1.58it/s]


Moviepy - Building video /tmp/tmpa71hikgk/video.mp4.
Moviepy - Writing video /tmp/tmpa71hikgk/video.mp4



[34m[1mwandb[0m: Adding directory to artifact (/tmp/tmpa71hikgk)... Done. 0.0s


Moviepy - Done !
Moviepy - video ready /tmp/tmpa71hikgk/video.mp4


<Artifact spiderman-cross-frame>

In [24]:
from text3d2video.artifacts.video_artifact import VideoArtifact
from text3d2video.wandb_util import api_artifact

artifact_tag = 'stormtrooper-cross-frame:latest'

video = api_artifact(artifact_tag)
video = VideoArtifact.from_wandb_artifact(video)
video.ipy_display()

[34m[1mwandb[0m:   1 of 1 files downloaded.  
