In [1]:
import torch

from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def custom_repr(self):
    return f'{{Tensor:{tuple(self.shape)}}} {original_repr(self)}'

torch.set_printoptions(sci_mode=False)
original_repr = torch.Tensor.__repr__
torch.Tensor.__repr__ = custom_repr

In [10]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda:3" if torch.cuda.is_available() else "cpu")

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt" #, torch_dtype=torch.float16, variant="fp16"
).to(device)

Loading pipeline components...: 100%|██████████| 5/5 [00:00<00:00,  6.66it/s]


In [11]:
pipe.enable_model_cpu_offload()

In [13]:
pipe

StableVideoDiffusionPipeline {
  "_class_name": "StableVideoDiffusionPipeline",
  "_diffusers_version": "0.31.0.dev0",
  "_name_or_path": "stabilityai/stable-video-diffusion-img2vid-xt",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "image_encoder": [
    "transformers",
    "CLIPVisionModelWithProjection"
  ],
  "scheduler": [
    "diffusers",
    "EulerDiscreteScheduler"
  ],
  "unet": [
    "diffusers",
    "UNetSpatioTemporalConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKLTemporalDecoder"
  ]
}

In [14]:

# Load the conditioning image
image = load_image("../images/fire-car-11-700.png")
image = image.resize((1024, 576))
image = image.resize((576, 320))    # convience for previewing
print(1024/576, 512/288, 384/256, 576/320)

1.7777777777777777 1.7777777777777777 1.5 1.8


In [15]:
generator = torch.manual_seed(42)
frames = pipe(image, decode_chunk_size=8, generator=generator, height=320, width=576).frames[0]

  0%|          | 0/25 [00:00<?, ?it/s]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

  4%|▍         | 1/25 [00:05<02:17,  5.72s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

  8%|▊         | 2/25 [00:08<01:33,  4.06s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 12%|█▏        | 3/25 [00:11<01:14,  3.38s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 16%|█▌        | 4/25 [00:14<01:07,  3.19s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 20%|██        | 5/25 [00:16<00:59,  2.97s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 24%|██▍       | 6/25 [00:19<00:55,  2.90s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 28%|██▊       | 7/25 [00:21<00:50,  2.79s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 32%|███▏      | 8/25 [00:24<00:47,  2.77s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 36%|███▌      | 9/25 [00:27<00:43,  2.71s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 40%|████      | 10/25 [00:29<00:40,  2.71s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 44%|████▍     | 11/25 [00:32<00:37,  2.67s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 48%|████▊     | 12/25 [00:35<00:34,  2.68s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 52%|█████▏    | 13/25 [00:37<00:31,  2.65s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 56%|█████▌    | 14/25 [00:40<00:29,  2.67s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 60%|██████    | 15/25 [00:43<00:26,  2.64s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 64%|██████▍   | 16/25 [00:45<00:23,  2.66s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 68%|██████▊   | 17/25 [00:48<00:21,  2.64s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 72%|███████▏  | 18/25 [00:51<00:18,  2.66s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 76%|███████▌  | 19/25 [00:53<00:15,  2.64s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 80%|████████  | 20/25 [00:56<00:13,  2.67s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 84%|████████▍ | 21/25 [00:59<00:10,  2.65s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 88%|████████▊ | 22/25 [01:01<00:07,  2.66s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 92%|█████████▏| 23/25 [01:04<00:05,  2.64s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

 96%|█████████▌| 24/25 [01:07<00:02,  2.64s/it]

debug: unet.forward(): sample: torch.Size([2, 25, 8, 40, 72])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 1280, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 10, 18]), res_hidden_states.shape=torch.Size([50, 640, 10, 18])
hidden_states.shape=torch.Size([50, 1280, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 640, 20, 36])
hidden_states.shape=torch.Size([50, 640, 20, 36]), res_hidden_states.shape=torch.Size([50, 320, 20, 36])
hidden_states.shape=torch.Size([50, 640, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 320, 40, 72])
hidden_states.shape=torch.Size([50, 320, 40, 72]), res_hidden_states.shape=torch.Size([50, 3

100%|██████████| 25/25 [01:09<00:00,  2.78s/it]


In [16]:
print(len(frames), frames)

25 [<PIL.Image.Image image mode=RGB size=576x320 at 0x7F96B831B220>, <PIL.Image.Image image mode=RGB size=576x320 at 0x7F96B831ADD0>, <PIL.Image.Image image mode=RGB size=576x320 at 0x7F96B831B490>, <PIL.Image.Image image mode=RGB size=576x320 at 0x7F96B831B340>, <PIL.Image.Image image mode=RGB size=576x320 at 0x7F96B831B0A0>, <PIL.Image.Image image mode=RGB size=576x320 at 0x7F96B831B010>, <PIL.Image.Image image mode=RGB size=576x320 at 0x7F96B831AC80>, <PIL.Image.Image image mode=RGB size=576x320 at 0x7F96B831B6A0>, <PIL.Image.Image image mode=RGB size=576x320 at 0x7F96B831B3A0>, <PIL.Image.Image image mode=RGB size=576x320 at 0x7F96B831AE30>, <PIL.Image.Image image mode=RGB size=576x320 at 0x7F96B831B4C0>, <PIL.Image.Image image mode=RGB size=576x320 at 0x7F96B831AFE0>, <PIL.Image.Image image mode=RGB size=576x320 at 0x7F96B831B460>, <PIL.Image.Image image mode=RGB size=576x320 at 0x7F96B831B610>, <PIL.Image.Image image mode=RGB size=576x320 at 0x7F96B831B520>, <PIL.Image.Image imag

In [17]:
export_to_video(frames, "../results/fire-car-11-700-32bit-small.mp4", fps=7)

'../results/fire-car-11-700-32bit-small.mp4'