<a href="https://colab.research.google.com/github/DmitryPodyachev/bremen/blob/main/3xvideo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# @title Default title text
%cd /content
!git clone -b dev https://github.com/camenduru/generative-models
!apt install mc aria2
!export PYTORCH_ALLOC_CONF=expandable_segments:True
!pip install einops fairscale fire fsspec invisible-watermark kornia ninja omegaconf open-clip-torch opencv-python pandas pillow pytorch-lightning pyyaml scipy timm tokenizers torch torchaudio torchdata torchmetrics torchvision tqdm transformers triton xformers urllib3 gradio
!pip install -q -e generative-models
!pip install -q -e git+https://github.com/Stability-AI/datapipelines@main#egg=sdata
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/vdo/stable-video-diffusion-img2vid-xt/resolve/main/svd_xt.safetensors?download=true -d /content/checkpoints -o svd_xt.safetensors
!mkdir -p /content/scripts/util/detection
!ln -s /content/generative-models/scripts/util/detection/p_head_v1.npz /content/scripts/util/detection/p_head_v1.npz
!ln -s /content/generative-models/scripts/util/detection/w_head_v1.npz /content/scripts/util/detection/w_head_v1.npz


/content
Cloning into 'generative-models'...
remote: Enumerating objects: 850, done.[K
remote: Counting objects: 100% (361/361), done.[K
remote: Compressing objects: 100% (89/89), done.[K
remote: Total 850 (delta 286), reused 272 (delta 272), pack-reused 489 (from 1)[K
Receiving objects: 100% (850/850), 42.65 MiB | 25.86 MiB/s, done.
Resolving deltas: 100% (437/437), done.
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  libaria2-0 libc-ares2 mc-data
Suggested packages:
  arj catdvi | texlive-binaries dbview djvulibre-bin epub-utils genisoimage gv
  imagemagick libaspell-dev links | w3m | lynx odt2txt poppler-utils python
  python-boto python-tz unar wimtools xpdf | pdf-viewer
The following NEW packages will be installed:
  aria2 libaria2-0 libc-ares2 mc mc-data
0 upgraded, 5 newly installed, 0 to remove and 37 not upgraded.
Need to get 3,487 kB of archives.
After this operation, 

In [9]:
!nvidia-smi

Sun Feb 22 22:07:26 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.82.07              Driver Version: 580.82.07      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   50C    P0             28W /   70W |    3205MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [1]:
import sys
sys.path.append("generative-models")

import os, math, torch, cv2
from omegaconf import OmegaConf
from glob import glob
from pathlib import Path
from typing import Optional, List
import numpy as np
from einops import rearrange, repeat

from PIL import Image
from torchvision.transforms import ToTensor
from torchvision.transforms import functional as TF
from sgm.util import instantiate_from_config

def load_model(config: str, device: str, num_frames: int, num_steps: int):
    config = OmegaConf.load(config)
    # Set init_device for open_clip_embedding_config to 'cpu' to avoid early GPU allocation
    config.model.params.conditioner_config.params.emb_models[0].params.open_clip_embedding_config.params.init_device = "cpu"
    config.model.params.sampler_config.params.num_steps = num_steps
    config.model.params.sampler_config.params.guider_config.params.num_frames = (num_frames)

    # Instantiate the model on CPU first
    model = instantiate_from_config(config.model).eval().requires_grad_(False)

    # Move model.model (the UNet) to device in float16
    model.model.to(device=device, dtype=torch.float16)

    # Keep conditioner and first_stage_model on CPU initially, as they are moved to CPU later in the sample function
    model.conditioner.to("cpu")
    model.first_stage_model.to("cpu")

    return model

num_frames = 16 # Changed from 25 to 16 to fix einops error
num_steps = 10
model_config = "generative-models/scripts/sampling/configs/svd_xt.yaml"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Call torch.cuda.empty_cache() before loading the model to ensure maximum available memory
torch.cuda.empty_cache()

model = load_model(model_config, device, num_frames, num_steps)

# These lines are redundant now as conditioner and first_stage_model are already on CPU
# model.conditioner.cpu()
# model.first_stage_model.cpu()

# This line is handled inside load_model now
# model.model.to(dtype=torch.float16)

torch.cuda.empty_cache()
model = model.requires_grad_(False)

def get_unique_embedder_keys_from_conditioner(conditioner):
    return list(set([x.input_key for x in conditioner.embedders]))

def get_batch(keys, value_dict, N, T, device, dtype=None):
    batch = {}
    batch_uc = {}
    for key in keys:
        if key == "fps_id":
            batch[key] = (
                torch.tensor([value_dict["fps_id"]])
                .to(device, dtype=dtype)
                .repeat(int(math.prod(N)))
            )
        elif key == "motion_bucket_id":
            batch[key] = (
                torch.tensor([value_dict["motion_bucket_id"]])
                .to(device, dtype=dtype)
                .repeat(int(math.prod(N)))
            )
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to(device, dtype=dtype),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames":
            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(
                value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
            )
        else:
            batch[key] = value_dict[key]
    if T is not None:
        batch["num_video_frames"] = T
    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc

def sample(
    input_path: str = "/content/test_image.png",
    resize_image: bool = False,
    num_frames: Optional[int] = None,
    num_steps: Optional[int] = None,
    fps_id: int = 6,
    motion_bucket_id: int = 127,
    cond_aug: float = 0.02,
    seed: int = 23,
    decoding_t: int = 14,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    device: str = "cuda",
    output_folder: Optional[str] = "/content/outputs",
):
    """
    Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
    image file in folder `input_path`.
    """
    torch.manual_seed(seed)

    path = Path(input_path)
    all_img_paths = []
    # Check if input_path is a URL
    if input_path.startswith("http://") or input_path.startswith("https://"):
        import requests
        from io import BytesIO
        response = requests.get(input_path)
        image = Image.open(BytesIO(response.content))
        # Save the image temporarily if the sample function expects a filepath, or pass the PIL Image object
        # For now, let's assume `Image.open` can handle BytesIO and subsequent processing handles PIL Image.
        # If it needs a path, we'd save it to a temp file here.
        # The current `sample` function's internal logic expects input_path to be a path to open, so we need to save it.
        temp_image_path = "/tmp/temp_image.png"
        image.save(temp_image_path)
        all_img_paths = [temp_image_path]
    elif path.is_file():
        if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
            all_img_paths = [input_path]
        else:
            raise ValueError("Path is not valid image file.")
    elif path.is_dir():
        all_img_paths = sorted(
            [
                f
                for f in path.iterdir()
                if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
            ]
        )
        if len(all_img_paths) == 0:
            raise ValueError("Folder does not contain any images.")
    else:
        raise ValueError

    all_out_paths = []
    for input_img_path in all_img_paths:
        with Image.open(input_img_path) as image:
            if image.mode == "RGBA":
                image = image.convert("RGB")
            if resize_image and image.size != (1024, 576):
                print(f"Resizing {image.size} to (1024, 576)")
                image = TF.resize(TF.resize(image, 1024), (576, 1024))
            w, h = image.size
            if h % 64 != 0 or w % 64 != 0:
                width, height = map(lambda x: x - x % 64, (w, h))
                image = image.resize((width, height))
                print(
                    f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
                )
            image = ToTensor()(image)
            image = image * 2.0 - 1.0

        image = image.unsqueeze(0).to(device)
        H, W = image.shape[2:]
        assert image.shape[1] == 3
        F = 8
        C = 4
        shape = (num_frames, C, H // F, W // F)
        if (H, W) != (576, 1024):
            print(
                "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
            )
        if motion_bucket_id > 255:
            print(
                "WARNING: High motion bucket! This may lead to suboptimal performance."
            )
        if fps_id < 5:
            print("WARNING: Small fps value! This may lead to suboptimal performance.")
        if fps_id > 30:
            print("WARNING: Large fps value! This may lead to suboptimal performance.")

        value_dict = {}
        value_dict["motion_bucket_id"] = motion_bucket_id
        value_dict["fps_id"] = fps_id
        value_dict["cond_aug"] = cond_aug
        value_dict["cond_frames_without_noise"] = image
        value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
        value_dict["cond_aug"] = cond_aug
        # low vram mode
        model.conditioner.cpu()
        model.first_stage_model.cpu()
        torch.cuda.empty_cache()
        model.sampler.verbose = True

        with torch.no_grad():
            with torch.autocast(device):
                model.conditioner.to(device)
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    [1, num_frames],
                    T=num_frames,
                    device=device,
                )
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=[
                        "cond_frames",
                        "cond_frames_without_noise",
                    ],
                )
                model.conditioner.cpu()
                torch.cuda.empty_cache()

                # from here, dtype is fp16
                for k in ["crossattn", "concat"]:
                    uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                    uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                    c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                    c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
                for k in uc.keys():
                    uc[k] = uc[k].to(dtype=torch.float16)
                    c[k] = c[k].to(dtype=torch.float16)

                randn = torch.randn(shape, device=device, dtype=torch.float16)
                additional_model_inputs = {}
                additional_model_inputs["image_only_indicator"] = torch.zeros(2, num_frames).to(device)
                additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

                for k in additional_model_inputs:
                    if isinstance(additional_model_inputs[k], torch.Tensor):
                        additional_model_inputs[k] = additional_model_inputs[k].to(dtype=torch.float16)

                def denoiser(input, sigma, c):
                    return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)

                samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
                samples_z.to(dtype=model.first_stage_model.dtype)
                model.en_and_decode_n_samples_a_time = decoding_t
                model.first_stage_model.to(device)
                samples_x = model.decode_first_stage(samples_z)
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
                model.first_stage_model.cpu()
                torch.cuda.empty_cache()

                os.makedirs(output_folder, exist_ok=True)
                base_count = len(glob(os.path.join(output_folder, "*.mp4")))
                video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
                writer = cv2.VideoWriter(
                    video_path,
                    cv2.VideoWriter_fourcc(*"MP4V"),
                    fps_id + 1,
                    (samples.shape[-1], samples.shape[-2]),
                )
                vid = (
                    (rearrange(samples, "t c h w -> t h w c") * 255)
                    .cpu()
                    .numpy()
                    .astype(np.uint8)
                )
                for frame in vid:
                    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                    writer.write(frame)
                writer.release()
                all_out_paths.append(video_path)
    return all_out_paths

import gradio as gr
import random

def infer(input_path: str, resize_image: bool, n_steps: int, decoding_t: int) -> List[str]:
  generated_videos = []
  for i in range(2): # Generate 2 videos
    seed = random.randint(0, 2**32) # Use a new random seed for each generation
    print(f"Generating video {i+1} with seed: {seed}")
    output_paths = sample(
      input_path=input_path,
      resize_image=resize_image,
      num_frames=num_frames, # Use global num_frames
      num_steps=n_steps,
      fps_id=6,
      motion_bucket_id=127,
      cond_aug=0.02,
      seed=seed,
      decoding_t=decoding_t,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
      device=device,
    )
    generated_videos.append(output_paths[0])
  return generated_videos

with gr.Blocks() as demo:
  with gr.Column():
    image = gr.Image(label="input image", type="filepath")
    resize_image = gr.Checkbox(label="resize to optimal size", value=True)
    btn = gr.Button("Run")
    with gr.Accordion(label="Advanced options", open=False):
      # Removed n_frames input as it cannot be changed dynamically
      n_steps = gr.Number(precision=0, label="number of steps", value=num_steps)
      # Removed seed input as it's now handled internally for multiple generations
      decoding_t = gr.Number(precision=0, label="number of frames decoded at a time", value=2)
  with gr.Column():
    gr.Markdown("### Generated Videos:")
    video_out1 = gr.Video(label="Generated Video 1")
    video_out2 = gr.Video(label="Generated Video 2")

  examples = [["https://user-images.githubusercontent.com/33302880/284758167-367a25d8-8d7b-42d3-8391-6d82813c7b0f.png"]]
  inputs = [image, resize_image, n_steps, decoding_t]
  outputs = [video_out1, video_out2]
  btn.click(infer, inputs=inputs, outputs=outputs)
  gr.Examples(examples=examples, inputs=inputs, outputs=outputs, fn=infer)
  demo.queue().launch(debug=True, share=True, inline=False, show_error=True)

VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
Initialized embedder #0: FrozenOpenCLIPImagePredictionEmbedder with 683800065 params. Trainable: False
Initialized embedder #1: ConcatTimestepEmbedderND with 0 params. Trainable: False
Initialized embedder #2: ConcatTimestepEmbedderND with 0 params. Trainable: False
Initialized e

NameError: name 'video_out3' is not defined

Чтобы очистить память GPU от всех загруженных моделей и тензоров, нужно явно удалить переменные, которые на них ссылаются, а затем вызвать `torch.cuda.empty_cache()`.

Следующий код поможет вам это сделать. Он ищет в глобальном пространстве имен все объекты, которые являются экземплярами `torch.nn.Module` (т.е. ваши модели) или `torch.Tensor`, и если они находятся на CUDA, удаляет их.

**Важно:** Убедитесь, что вы действительно больше не нуждаетесь в этих моделях/тензорах, прежде чем запускать этот код, так как он безвозвратно удалит их из памяти.

In [8]:
import gc
import torch

def clear_all_gpu_memory():
    print("Attempting to clear all GPU memory...")
    # Удалить все ссылки на модели и тензоры из глобального пространства имен
    for name in dir():
        if not name.startswith('_'): # Избегаем внутренних переменных Python
            var = globals()[name]
            if isinstance(var, (torch.nn.Module, torch.Tensor)):
                if hasattr(var, 'is_cuda') and var.is_cuda:
                    print(f"Deleting GPU object: {name} (Type: {type(var)}) ")
                    del globals()[name]
            elif isinstance(var, (list, tuple, dict)):
                # Для коллекций, попытаться очистить их содержимое, если оно на GPU
                # Это упрощенный подход, для глубокой очистки нужен рекурсивный обход
                pass

    # Запустить сборщик мусора Python
    gc.collect()
    # Очистить кэш памяти GPU
    torch.cuda.empty_cache()
    print("GPU memory cleared. Free memory: ", torch.cuda.memory_reserved() / 1024**3, "GB")

# Вызов функции для очистки памяти
clear_all_gpu_memory()

# Пример: если 'model' была определена как глобальная переменная, она будет удалена.
# Если у вас есть другие глобальные переменные, такие как 'samples_z', 'c', 'uc', 'randn' из функции sample,
# вы можете явно удалить и их:
# del samples_z
# del c
# del uc
# del randn
# gc.collect()
# torch.cuda.empty_cache()

# Если вы хотите очистить модель 'model' из предыдущей ячейки:
if 'model' in globals():
    print("Deleting global 'model' variable.")
    del model
gc.collect()
torch.cuda.empty_cache()


Attempting to clear all GPU memory...
GPU memory cleared. Free memory:  2.99609375 GB


In [7]:
import sys
sys.path.append("generative-models")

import os, math, torch, cv2
from omegaconf import OmegaConf
from glob import glob
from pathlib import Path
from typing import Optional, List
import numpy as np
from einops import rearrange, repeat

from PIL import Image
from torchvision.transforms import ToTensor
from torchvision.transforms import functional as TF
from sgm.util import instantiate_from_config

def load_model(config: str, device: str, num_frames: int, num_steps: int):
    config = OmegaConf.load(config)
    # Set init_device for open_clip_embedding_config to 'cpu' to avoid early GPU allocation
    config.model.params.conditioner_config.params.emb_models[0].params.open_clip_embedding_config.params.init_device = "cpu"
    config.model.params.sampler_config.params.num_steps = num_steps
    config.model.params.sampler_config.params.guider_config.params.num_frames = (num_frames)

    # Instantiate the model on CPU first
    model = instantiate_from_config(config.model).eval().requires_grad_(False)

    # Move model.model (the UNet) to device in float16
    model.model.to(device=device, dtype=torch.float16)

    # Keep conditioner and first_stage_model on CPU initially, as they are moved to CPU later in the sample function
    model.conditioner.to("cpu")
    model.first_stage_model.to("cpu")

    return model

num_frames = 48 # Changed from 25 to 16 to fix einops error
num_steps = 30
model_config = "generative-models/scripts/sampling/configs/svd_xt.yaml"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Call torch.cuda.empty_cache() before loading the model to ensure maximum available memory
torch.cuda.empty_cache()

model = load_model(model_config, device, num_frames, num_steps)

# These lines are redundant now as conditioner and first_stage_model are already on CPU
# model.conditioner.cpu()
# model.first_stage_model.cpu()

# This line is handled inside load_model now
# model.model.to(dtype=torch.float16)

torch.cuda.empty_cache()
model = model.requires_grad_(False)

def get_unique_embedder_keys_from_conditioner(conditioner):
    return list(set([x.input_key for x in conditioner.embedders]))

def get_batch(keys, value_dict, N, T, device, dtype=None):
    batch = {}
    batch_uc = {}
    for key in keys:
        if key == "fps_id":
            batch[key] = (
                torch.tensor([value_dict["fps_id"]])
                .to(device, dtype=dtype)
                .repeat(int(math.prod(N)))
            )
        elif key == "motion_bucket_id":
            batch[key] = (
                torch.tensor([value_dict["motion_bucket_id"]])
                .to(device, dtype=dtype)
                .repeat(int(math.prod(N)))
            )
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to(device, dtype=dtype),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames":
            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(
                value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
            )
        else:
            batch[key] = value_dict[key]
    if T is not None:
        batch["num_video_frames"] = T
    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc

def sample(
    input_path: str = "/content/test_image.png",
    resize_image: bool = False,
    num_frames: Optional[int] = None,
    num_steps: Optional[int] = None,
    fps_id: int = 6,
    motion_bucket_id: int = 127,
    cond_aug: float = 0.02,
    seed: int = 23,
    decoding_t: int = 14,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    device: str = "cuda",
    output_folder: Optional[str] = "/content/outputs",
):
    """
    Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
    image file in folder `input_path`.
    """
    torch.manual_seed(seed)

    path = Path(input_path)
    all_img_paths = []
    # Check if input_path is a URL
    if input_path.startswith("http://") or input_path.startswith("https://"):
        import requests
        from io import BytesIO
        response = requests.get(input_path)
        image = Image.open(BytesIO(response.content))
        # Save the image temporarily if the sample function expects a filepath, or pass the PIL Image object
        # For now, let's assume `Image.open` can handle BytesIO and subsequent processing handles PIL Image.
        # If it needs a path, we'd save it to a temp file here.
        # The current `sample` function's internal logic expects input_path to be a path to open, so we need to save it.
        temp_image_path = "/tmp/temp_image.png"
        image.save(temp_image_path)
        all_img_paths = [temp_image_path]
    elif path.is_file():
        if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
            all_img_paths = [input_path]
        else:
            raise ValueError("Path is not valid image file.")
    elif path.is_dir():
        all_img_paths = sorted(
            [
                f
                for f in path.iterdir()
                if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
            ]
        )
        if len(all_img_paths) == 0:
            raise ValueError("Folder does not contain any images.")
    else:
        raise ValueError

    all_out_paths = []
    for input_img_path in all_img_paths:
        with Image.open(input_img_path) as image:
            if image.mode == "RGBA":
                image = image.convert("RGB")
            if resize_image and image.size != (1024, 576):
                print(f"Resizing {image.size} to (1024, 576)")
                image = TF.resize(TF.resize(image, 1024), (576, 1024))
            w, h = image.size
            if h % 64 != 0 or w % 64 != 0:
                width, height = map(lambda x: x - x % 64, (w, h))
                image = image.resize((width, height))
                print(
                    f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
                )
            image = ToTensor()(image)
            image = image * 2.0 - 1.0

        image = image.unsqueeze(0).to(device)
        H, W = image.shape[2:]
        assert image.shape[1] == 3
        F = 8
        C = 4
        shape = (num_frames, C, H // F, W // F)
        if (H, W) != (576, 1024):
            print(
                "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
            )
        if motion_bucket_id > 255:
            print(
                "WARNING: High motion bucket! This may lead to suboptimal performance."
            )
        if fps_id < 5:
            print("WARNING: Small fps value! This may lead to suboptimal performance.")
        if fps_id > 30:
            print("WARNING: Large fps value! This may lead to suboptimal performance.")

        value_dict = {}
        value_dict["motion_bucket_id"] = motion_bucket_id
        value_dict["fps_id"] = fps_id
        value_dict["cond_aug"] = cond_aug
        value_dict["cond_frames_without_noise"] = image
        value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
        value_dict["cond_aug"] = cond_aug
        # low vram mode
        model.conditioner.cpu()
        model.first_stage_model.cpu()
        torch.cuda.empty_cache()
        model.sampler.verbose = True

        with torch.no_grad():
            with torch.autocast(device):
                model.conditioner.to(device)
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    [1, num_frames],
                    T=num_frames,
                    device=device,
                )
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=[
                        "cond_frames",
                        "cond_frames_without_noise",
                    ],
                )
                model.conditioner.cpu()
                torch.cuda.empty_cache()

                # from here, dtype is fp16
                for k in ["crossattn", "concat"]:
                    uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                    uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                    c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                    c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
                for k in uc.keys():
                    uc[k] = uc[k].to(dtype=torch.float16)
                    c[k] = c[k].to(dtype=torch.float16)

                randn = torch.randn(shape, device=device, dtype=torch.float16)
                additional_model_inputs = {}
                additional_model_inputs["image_only_indicator"] = torch.zeros(2, num_frames).to(device)
                additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

                for k in additional_model_inputs:
                    if isinstance(additional_model_inputs[k], torch.Tensor):
                        additional_model_inputs[k] = additional_model_inputs[k].to(dtype=torch.float16)

                def denoiser(input, sigma, c):
                    return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)

                samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
                samples_z.to(dtype=model.first_stage_model.dtype)
                model.en_and_decode_n_samples_a_time = decoding_t
                model.first_stage_model.to(device)
                samples_x = model.decode_first_stage(samples_z)
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
                model.first_stage_model.cpu()
                torch.cuda.empty_cache()

                os.makedirs(output_folder, exist_ok=True)
                base_count = len(glob(os.path.join(output_folder, "*.mp4")))
                video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
                writer = cv2.VideoWriter(
                    video_path,
                    cv2.VideoWriter_fourcc(*"MP4V"),
                    fps_id + 1,
                    (samples.shape[-1], samples.shape[-2]),
                )
                vid = (
                    (rearrange(samples, "t c h w -> t h w c") * 255)
                    .cpu()
                    .numpy()
                    .astype(np.uint8)
                )
                for frame in vid:
                    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                    writer.write(frame)
                writer.release()
                all_out_paths.append(video_path)
    return all_out_paths

import gradio as gr
import random

def infer(input_path: str, resize_image: bool, n_steps: int, decoding_t: int) -> List[str]:
  generated_videos = []
  for i in range(3): # Generate 3 videos
    seed = random.randint(0, 2**32) # Use a new random seed for each generation
    print(f"Generating video {i+1} with seed: {seed}")
    output_paths = sample(
      input_path=input_path,
      resize_image=resize_image,
      num_frames=num_frames, # Use global num_frames
      num_steps=n_steps,
      fps_id=6,
      motion_bucket_id=127,
      cond_aug=0.02,
      seed=seed,
      decoding_t=decoding_t,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
      device=device,
    )
    generated_videos.append(output_paths[0])
  return generated_videos

with gr.Blocks() as demo:
  with gr.Column():
    image = gr.Image(label="input image", type="filepath")
    resize_image = gr.Checkbox(label="resize to optimal size", value=True)
    btn = gr.Button("Run")
    with gr.Accordion(label="Advanced options", open=False):
      # Removed n_frames input as it cannot be changed dynamically
      n_steps = gr.Number(precision=0, label="number of steps", value=num_steps)
      # Removed seed input as it's now handled internally for multiple generations
      decoding_t = gr.Number(precision=0, label="number of frames decoded at a time", value=2)
  with gr.Column():
    gr.Markdown("### Generated Videos:")
    video_out1 = gr.Video(label="Generated Video 1")
    video_out2 = gr.Video(label="Generated Video 2")


  examples = [["https://user-images.githubusercontent.com/33302880/284758167-367a25d8-8d7b-42d3-8391-6d82813c7b0f.png"]]
  inputs = [image, resize_image, n_steps, decoding_t]
  outputs = [video_out1, video_out2]
  btn.click(infer, inputs=inputs, outputs=outputs)
  gr.Examples(examples=examples, inputs=inputs, outputs=outputs, fn=infer)
  demo.queue().launch(debug=True, share=True, inline=False, show_error=True)

VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing


KeyboardInterrupt: 