# **InfiniteTalk & MultiTalk with WAN2.1**
- This notebook is based on ComfyUI-WanVideoWrapper by Kijai: https://github.com/kijai/ComfyUI-WanVideoWrapper
- You can use this notebook for audio-driven image-to-video generation on an L4 or A100 GPU. Videos are generated at 25 fps, so a 5-second video requires 125 frames. The A100 is recommended for generations longer than 5 seconds.
- If you are using the L4, I recommend a Q4_K_M (or lower) GGUF Wan 2.1 model and an image resolution below 480p to avoid out-of-memory (OOM) errors. For reference, generating a 17-second (425-frame) 432x768 video with the Q4_K_M GGUF model almost crashed the A100. You can enable 'use_block_swap' to manage VRAM for longer generations at the cost of increased generation time.
- You can find models in these huggingface repos: (1) https://huggingface.co/Kijai/WanVideo_comfy/tree/main (2) https://huggingface.co/Comfy-Org/models (3) https://huggingface.co/MeiGen-AI/InfiniteTalk/tree/main (4) https://huggingface.co/city96/models
- Make sure the dimensions of your image are divisible by 16.
- Note that while MultiTalk uses one model for both single-person and multi-person generation, InfiniteTalk uses separate models for each case.
- GitHub projects: InfiniteTalk -> https://github.com/MeiGen-AI/InfiniteTalk, MultiTalk -> https://github.com/MeiGen-AI/MultiTalk
- Notebook source: https://github.com/Isi-dev/Google-Colab_Notebooks
- Premium notebooks I highly recommend: https://isinse.gumroad.com/






In [None]:
# Default links
# wan21_model_download_url:
# https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/resolve/main/wan2.1-i2v-14b-480p-Q4_K_M.gguf
# speed_LoRA_download_url:
# https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank64_bf16.safetensors
# infiniteTalk_url:
# https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/InfiniteTalk/Wan2_1-InfiniTetalk-Single_fp16.safetensors
# multiTalk_url:
# https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/WanVideo_2_1_Multitalk_14B_fp8_e4m3fn.safetensors
# infiniteTalk_url (for multiple speakers):
# https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/InfiniteTalk/Wan2_1-InfiniteTalk-Multi_fp16.safetensors

























# @markdown # 💥1. Prepare Environment
# !pip install --upgrade --quiet torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# !pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0
!pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0
%cd /content
from IPython.display import clear_output
# !pip install -q torchsde einops diffusers accelerate xformers==0.0.29.post2 triton==3.2.0 sageattention
!pip install -q torchsde einops diffusers accelerate xformers==0.0.32.post1 triton==3.4 sageattention
clear_output()
!pip install av spandrel albumentations onnx opencv-python color-matcher segment_anything ultralytics onnxruntime
clear_output()
!pip install onnxruntime-gpu -y
clear_output()
!git clone --branch ComfyUI_v0.3.47 https://github.com/Isi-dev/ComfyUI
clear_output()
%cd /content/ComfyUI/custom_nodes
# !git clone --branch forHidream  https://github.com/Isi-dev/ComfyUI_GGUF.git
# clear_output()
# !git clone --branch kjnv1.1.3 https://github.com/Isi-dev/ComfyUI_KJNodes.git
# clear_output()
!git clone https://github.com/Isi-dev/ComfyUI_WanVideoWrapper
!git clone https://github.com/Isi-dev/audio_separation_nodes_comfyui
# %cd /content/ComfyUI/custom_nodes/ComfyUI_GGUF
# !pip install -r requirements.txt
clear_output()
# %cd /content/ComfyUI/custom_nodes/ComfyUI_KJNodes
# !pip install -r requirements.txt
%cd /content/ComfyUI/custom_nodes/ComfyUI_WanVideoWrapper
!pip install -r requirements.txt
%cd /content/ComfyUI/custom_nodes/audio_separation_nodes_comfyui
!pip install -r requirements.txt
clear_output()


%cd /content/ComfyUI
!apt -y install -qq aria2 ffmpeg
clear_output()


import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

from pathlib import Path
import torch
import torchaudio
import numpy as np
import cv2
from PIL import Image
import gc
import sys
import random
import imageio
import subprocess
import shutil
from google.colab import files
from IPython.display import display, HTML, Image as IPImage
sys.path.insert(0, '/content/ComfyUI')

from comfy import model_management

from nodes import (
    CheckpointLoaderSimple,
    CLIPLoader,
    CLIPTextEncode,
    VAEDecode,
    VAELoader,
    KSampler,
    KSamplerAdvanced,
    UNETLoader,
    LoadImage,
    SaveImage,
    CLIPVisionLoader,
    CLIPVisionEncode,
    LoraLoaderModelOnly,
    ImageScale
)

# from custom_nodes.ComfyUI_GGUF.nodes import UnetLoaderGGUF
# from custom_nodes.ComfyUI_KJNodes.nodes.model_optimization_nodes import (
#     WanVideoTeaCacheKJ,
#     PathchSageAttentionKJ,
#     WanVideoNAG,
#     SkipLayerGuidanceWanVideo
# )

from custom_nodes.ComfyUI_WanVideoWrapper.multitalk.nodes import (
    MultiTalkModelLoader,
    MultiTalkWav2VecEmbeds,
    WanVideoImageToVideoMultiTalk
)

from custom_nodes.ComfyUI_WanVideoWrapper.nodes import (
    WanVideoSampler,
    WanVideoContextOptions,
    WanVideoTextEmbedBridge,
    WanVideoDecode,
    WanVideoClipVisionEncode
)

from custom_nodes.ComfyUI_WanVideoWrapper.nodes_model_loading import (
    WanVideoModelLoader,
    WanVideoVAELoader,
    WanVideoLoraSelect,
    WanVideoBlockSwap
)



from custom_nodes.ComfyUI_WanVideoWrapper.fantasytalking.nodes import DownloadAndLoadWav2VecModel

from comfy_extras.nodes_audio import LoadAudio
from custom_nodes.audio_separation_nodes_comfyui.src.separation import AudioSeparation
from custom_nodes.audio_separation_nodes_comfyui.src.crop import AudioCrop

# from comfy_extras.nodes_model_advanced import ModelSamplingSD3
from comfy_extras.nodes_images import SaveAnimatedWEBP
from comfy_extras.nodes_video import SaveWEBM
# from comfy_extras.nodes_wan import WanImageToVideo
# from comfy_extras.nodes_wan import WanFirstLastFrameToVideo
# from comfy_extras.nodes_upscale_model import UpscaleModelLoader


def download_with_aria2c(link, folder="/content/ComfyUI/models/loras"):
    """
    Download ``link`` into ``folder`` with aria2c and return the saved filename.

    The filename is the last path segment of the URL with any query string
    removed (the same rule ``model_download`` uses), so a link such as
    ``.../lora.safetensors?download=true`` is saved as ``lora.safetensors``.

    Args:
        link: Direct download URL.
        folder: Destination directory (created if missing).

    Returns:
        str: The filename (not the full path) the file was saved under.
    """
    import os
    import shlex

    filename = link.split("/")[-1].split("?")[0]
    # Quote every argument: URLs can contain '&', '?', or spaces that the shell
    # would otherwise interpret (e.g. '&' would background the command).
    command = (
        "aria2c --console-log-level=error -c -x 16 -s 16 -k 1M "
        f"{shlex.quote(link)} -d {shlex.quote(folder)} -o {shlex.quote(filename)}"
    )

    print("Executing download command:")
    print(command)

    os.makedirs(folder, exist_ok=True)
    get_ipython().system(command)

    return filename



def download_civitai_model(civitai_link, civitai_token, folder="/content/ComfyUI/models/loras"):
    """
    Download a Civitai model (requested as SafeTensor) into ``folder`` with wget.

    Args:
        civitai_link: A Civitai link containing ``/models/<id>``, e.g.
            https://civitai.com/api/download/models/1523247?...
        civitai_token: Civitai API token appended to the request when truthy.
        folder: Destination directory (created if missing).

    Returns:
        str | None: The generated filename (``model_<timestamp>.safetensors``)
        on success, or None when wget failed or produced an empty file.

    Raises:
        ValueError: If ``civitai_link`` does not contain ``/models/``.
    """
    import os
    import subprocess
    import time

    os.makedirs(folder, exist_ok=True)

    try:
        model_id = civitai_link.split("/models/")[1].split("?")[0]
    except IndexError:
        raise ValueError("Invalid Civitai URL format. Please use a link like: https://civitai.com/api/download/models/1523247?...") from None

    civitai_url = f"https://civitai.com/api/download/models/{model_id}?type=Model&format=SafeTensor"
    if civitai_token:
        civitai_url += f"&token={civitai_token}"

    timestamp = time.strftime("%Y%m%d_%H%M%S")
    filename = f"model_{timestamp}.safetensors"
    full_path = os.path.join(folder, filename)

    print("Downloading from Civitai...")
    # Pass an argument list (no shell) so the URL/token is never shell-interpreted.
    succeeded = False
    try:
        completed = subprocess.run(
            ["wget", "--max-redirect=10", "--show-progress", civitai_url, "-O", full_path],
            check=False,
        )
        succeeded = completed.returncode == 0
    except OSError as e:  # e.g. wget is not installed
        print(f"❌ Could not run wget: {e}")

    if succeeded and os.path.exists(full_path) and os.path.getsize(full_path) > 0:
        print(f"LoRA downloaded successfully: {full_path}")
        return filename

    print(f"❌ LoRA download failed or file is empty: {full_path}")
    # `wget -O` creates the output file even when the request fails; remove the
    # stub so it is not later mistaken for a usable model.
    if os.path.exists(full_path):
        os.remove(full_path)
    return None

def download_lora(link, folder="/content/ComfyUI/models/loras", civitai_token=None):
    """
    Fetch a LoRA/model file, routing by host.

    Civitai links (detected case-insensitively by ``civitai.com`` in the URL) go
    through ``download_civitai_model``; every other link (e.g. Hugging Face) goes
    through ``download_with_aria2c``.

    Args:
        link: Download URL (Hugging Face or Civitai).
        folder: Destination directory for the file.
        civitai_token: Token for Civitai downloads; required for Civitai links.

    Returns:
        Whatever the selected download helper returns (the saved filename).

    Raises:
        ValueError: For a Civitai link when ``civitai_token`` is empty/None.
    """
    from_civitai = "civitai.com" in link.lower()
    if not from_civitai:
        return download_with_aria2c(link, folder)

    if not civitai_token:
        raise ValueError("Civitai token is required for Civitai downloads")
    return download_civitai_model(link, civitai_token, folder)



def model_download(url: str, dest_dir: str, filename: str = None, silent: bool = True) -> "str | bool":
    """
    Colab-optimized download with aria2c.

    Args:
        url: Download URL.
        dest_dir: Target directory (created if needed).
        filename: Optional output filename (defaults to the URL's last path
            segment with any query string removed).
        silent: If True, aria2c runs quietly and only a short progress line
            plus any error is printed.

    Returns:
        str | bool: The saved filename on success, or False on any failure
        (errors are printed rather than raised).
    """
    try:
        Path(dest_dir).mkdir(parents=True, exist_ok=True)

        if filename is None:
            filename = url.split('/')[-1].split('?')[0]  # drop URL parameters

        cmd = [
            'aria2c',
            '--console-log-level=error',
            '-c', '-x', '16', '-s', '16', '-k', '1M',
            '-d', dest_dir,
            '-o', filename,
            url
        ]

        if silent:
            cmd.extend(['--summary-interval=0', '--quiet'])
            print(f"Downloading {filename}...", end=' ', flush=True)

        subprocess.run(cmd, check=True, capture_output=True, text=True)

        if silent:
            print("Done!")
        else:
            print(f"Downloaded {filename} to {dest_dir}")
        return filename

    except subprocess.CalledProcessError as e:
        # stderr may be empty (e.g. --quiet), so fall back to a generic message.
        error = (e.stderr or "").strip() or "Unknown error"
        print(f"\nError downloading {filename}: {error}")
        return False
    except Exception as e:
        # Deliberate best-effort boundary: report and signal failure via False.
        print(f"\nError: {str(e)}")
        return False


# model_quant = "Q4_K_M" # @param ["Q4_K_M", "Q5_K_M", "Q6_K", "Q8_0"]
# lightx2v_rank = "128" # @param ["32", "64", "128"]
# ---- Download configuration (Colab form fields are the lines marked `# @param`) ----
# lightx2v_rank is only read by commented-out speed-LoRA selection code in the
# visible part of this cell.
lightx2v_rank = "32"

# Wan 2.2 high/low-noise model options are disabled; in the visible code this
# flag is only read by commented-out download logic.
# use_preferred_wanModels = True # @param {type:"boolean"}
use_preferred_wanModels = False
# high_noise_model_download_url = "https://huggingface.co/bullerwins/Wan2.2-I2V-A14B-GGUF/resolve/main/wan2.2_i2v_high_noise_14B_Q4_K_M.gguf"# @param {"type":"string"}
# low_noise_model_download_url = "https://huggingface.co/bullerwins/Wan2.2-I2V-A14B-GGUF/resolve/main/wan2.2_i2v_low_noise_14B_Q4_K_M.gguf"# @param {"type":"string"}

# low_noise_model_download_url = "https://huggingface.co/Isi99999/Wan2.1BasedModels/resolve/main/wan2.1-i2v-14b-480p-Q4_K_M.gguf"

# Likewise disabled: only read by commented-out code in the visible part of this cell.
# use_preferred_speedup_LoRAs = True # @param {type:"boolean"}
use_preferred_speedup_LoRAs = False
# high_noise_speed_LoRA_download_url = "https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors"# @param {"type":"string"}
# low_noise_speed_LoRA_download_url = "https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors"# @param {"type":"string"}

# Wan 2.1 I2V diffusion model (GGUF); downloaded into models/diffusion_models.
wan21_model_download_url = "https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/resolve/main/wan2.1-i2v-14b-480p-Q4_K_M.gguf"# @param {"type":"string"}

# Speed-up (step-distilled) LoRA; downloaded into models/loras.
speed_LoRA_download_url = "https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank64_bf16.safetensors" # @param {"type":"string"}

# InfiniteTalk weights. The default is the multi-speaker file; the single-speaker
# link is listed in the "Default links" comments at the top of the notebook.
infiniteTalk_url = "https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/InfiniteTalk/Wan2_1-InfiniteTalk-Multi_fp16.safetensors" # @param {"type":"string"}

# True -> download multiTalk_url; False -> download infiniteTalk_url.
use_multiTalk_instead = True # @param {type:"boolean"}

multiTalk_url = "https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/WanVideo_2_1_Multitalk_14B_fp8_e4m3fn.safetensors" # @param {"type":"string"}

# Optional extra LoRAs: set download_loRA_N = True and put a Hugging Face or
# Civitai link in lora_N_download_url.
download_loRA_1 = False
lora_1_download_url = "Put your loRA here"
download_loRA_2 = False
lora_2_download_url = "Put your loRA here"

download_loRA_3 = False
lora_3_download_url = "https://huggingface.co/Remade-AI/Rotate/resolve/main/rotate_20_epochs.safetensors"

# Needed only for Civitai links (download_lora raises ValueError when the token
# is empty). The default value is a placeholder, not a real token.
token_if_civitai_url = "Put your civitai token here"

# The token field defaults to a placeholder, not a real token. Treat it (or a
# blank field) as "no token" so download_lora raises its explicit
# "Civitai token is required" error instead of sending the placeholder to Civitai.
_civitai_token = token_if_civitai_url.strip()
if _civitai_token in ("", "Put your civitai token here"):
    _civitai_token = None

# Accepted LoRA file extensions; a downloaded name with any other suffix is rejected.
valid_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft'}


def _fetch_optional_lora(enabled, url, index):
    """Download one optional LoRA; return its filename, or None if disabled, failed, or invalid."""
    if not enabled:
        return None
    name = download_lora(url, civitai_token=_civitai_token)
    if not name:
        # Nothing was saved, so there is nothing to validate.
        return None
    if not name.lower().endswith(tuple(valid_extensions)):
        print(f"❌ Invalid LoRA format: {name}")
        return None
    clear_output()
    print(f"loRA {index} downloaded succesfully!")
    return name


lora_1 = _fetch_optional_lora(download_loRA_1, lora_1_download_url, 1)
lora_2 = _fetch_optional_lora(download_loRA_2, lora_2_download_url, 2)
lora_3 = _fetch_optional_lora(download_loRA_3, lora_3_download_url, 3)


# if use_preferred_wanModels:
#     dit_model = model_download(high_noise_model_download_url, "/content/ComfyUI/models/diffusion_models")
#     dit_model2 = model_download(low_noise_model_download_url, "/content/ComfyUI/models/diffusion_models")

# else:
#     # if model_quant == "Q4_K_M":
#     #     dit_model = model_download("https://huggingface.co/Isi99999/Wan2.2BasedModels/resolve/main/wan2.2_i2v_high_noise_14B_Q4_K_M.gguf", "/content/ComfyUI/models/diffusion_models")
#     #     dit_model2 = model_download("https://huggingface.co/Isi99999/Wan2.2BasedModels/resolve/main/wan2.2_i2v_low_noise_14B_Q4_K_M.gguf", "/content/ComfyUI/models/diffusion_models")
#     # elif model_quant == "Q5_K_M":
#     #     dit_model = model_download("https://huggingface.co/Isi99999/Wan2.2BasedModels/resolve/main/wan2.2_i2v_high_noise_14B_Q5_K_M.gguf", "/content/ComfyUI/models/diffusion_models")
#     #     dit_model2 = model_download("https://huggingface.co/Isi99999/Wan2.2BasedModels/resolve/main/wan2.2_i2v_low_noise_14B_Q5_K_M.gguf", "/content/ComfyUI/models/diffusion_models")
#     # elif model_quant == "Q6_K":
#     #     dit_model = model_download("https://huggingface.co/Isi99999/Wan2.2BasedModels/resolve/main/wan2.2_i2v_high_noise_14B_Q6_K.gguf", "/content/ComfyUI/models/diffusion_models")
#     #     dit_model2 = model_download("https://huggingface.co/Isi99999/Wan2.2BasedModels/resolve/main/wan2.2_i2v_low_noise_14B_Q6_K.gguf", "/content/ComfyUI/models/diffusion_models")
#     # else:
#     #     dit_model = model_download("https://huggingface.co/Isi99999/Wan2.2BasedModels/resolve/main/wan2.2_i2v_high_noise_14B_Q8_0.gguf", "/content/ComfyUI/models/diffusion_models")
#     #     dit_model2 = model_download("https://huggingface.co/Isi99999/Wan2.2BasedModels/resolve/main/wan2.2_i2v_low_noise_14B_Q8_0.gguf", "/content/ComfyUI/models/diffusion_models")

#     if model_quant == "Q4_K_M":
#         dit_model = model_download("https://huggingface.co/Isi99999/Wan2.1BasedModels/resolve/main/wan2.1-i2v-14b-480p-Q4_K_M.gguf", "/content/ComfyUI/models/diffusion_models")
#     elif model_quant == "Q5_K_M":
#         dit_model = model_download("https://huggingface.co/Isi99999/Wan2.1BasedModels/resolve/main/wan2.1-i2v-14b-480p-Q5_K_M.gguf", "/content/ComfyUI/models/diffusion_models")
#     elif model_quant == "Q6_K":
#         dit_model = model_download("https://huggingface.co/Isi99999/Wan2.1BasedModels/resolve/main/wan2.1-i2v-14b-480p-Q6_K.gguf", "/content/ComfyUI/models/diffusion_models")
#     elif model_quant == "Q4_0":
#         dit_model = model_download("https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/resolve/main/wan2.1-i2v-14b-480p-Q4_0.gguf", "/content/ComfyUI/models/diffusion_models")
#     else:
#         dit_model = model_download("https://huggingface.co/Isi99999/Wan2.1BasedModels/resolve/main/wan2.1-i2v-14b-480p-Q8_0.gguf", "/content/ComfyUI/models/diffusion_models")



clear_output()
# Fixed support weights: UMT5-XXL text encoder, Wan 2.1 VAE and CLIP-Vision-H,
# each saved into its ComfyUI models sub-folder.
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/Isi99999/Wan_Extras/resolve/main/umt5_xxl_fp8_e4m3fn_scaled.safetensors -d /content/ComfyUI/models/text_encoders -o umt5_xxl_fp8_e4m3fn_scaled.safetensors
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/Isi99999/Wan_Extras/resolve/main/wan_2.1_vae.safetensors -d /content/ComfyUI/models/vae -o wan_2.1_vae.safetensors
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/Isi99999/Wan_Extras/resolve/main/clip_vision_h.safetensors -d /content/ComfyUI/models/clip_vision -o clip_vision_h.safetensors
clear_output()

# model_download returns the saved filename on success and False on failure
# (it prints the error instead of raising), so each name below may be False.
dit_model = model_download(wan21_model_download_url, "/content/ComfyUI/models/diffusion_models")

lightx2v_lora_lowNoise = model_download(speed_LoRA_download_url, "/content/ComfyUI/models/loras")

# if use_preferred_speedup_LoRAs:
#     lightx2v_lora = model_download(high_noise_speed_LoRA_download_url, "/content/ComfyUI/models/loras")
#     lightx2v_lora_lowNoise = model_download(low_noise_speed_LoRA_download_url, "/content/ComfyUI/models/loras")
# else:
#     if lightx2v_rank == "32":
#         lightx2v_lora_lowNoise = model_download("https://huggingface.co/Isi99999/Wan2.1BasedModels/resolve/main/lightx2v_I2V_14B_480p_cfg_step_distill_rank32_bf16.safetensors", "/content/ComfyUI/models/loras")
#         # lightx2v_lora = model_download("https://huggingface.co/Isi99999/Wan2.1BasedModels/resolve/main/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank32_bf16.safetensors", "/content/ComfyUI/models/loras")
#     elif lightx2v_rank == "64":
#         lightx2v_lora_lowNoise = model_download("https://huggingface.co/Isi99999/Wan2.1BasedModels/resolve/main/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank64_bf16.safetensors", "/content/ComfyUI/models/loras")
#     else:
#         lightx2v_lora_lowNoise = model_download("https://huggingface.co/Isi99999/Wan2.1BasedModels/resolve/main/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors", "/content/ComfyUI/models/loras")


# Only one talking-head model is downloaded; use_multiTalk_instead picks which.
if use_multiTalk_instead:
    multitalkModel  = model_download(multiTalk_url, "/content/ComfyUI/models/diffusion_models")
else:
    multitalkModel  = model_download(infiniteTalk_url, "/content/ComfyUI/models/diffusion_models")

# wav2vec2 audio encoder checkpoint (chinese-wav2vec2-base, fairseq format).
w2v_model = model_download("https://huggingface.co/TencentGameMate/chinese-wav2vec2-base/resolve/main/chinese-wav2vec2-base-fairseq-ckpt.pt", "/content/ComfyUI/models/transformers")
# pyt_model = model_download("https://huggingface.co/TencentGameMate/chinese-wav2vec2-base/resolve/main/pytorch_model.bin", "/content/ComfyUI/models/transformers")

def upload_file():
    """
    Upload file(s) with the Colab widget and move them into /content/ComfyUI/input.

    Assumes the widget saved the files under /content/ComfyUI (the working
    directory set earlier in the notebook).

    Returns:
        str | None: Path of the first moved file, or None if nothing was uploaded.
    """
    inbox = '/content/ComfyUI/input'
    os.makedirs(inbox, exist_ok=True)

    first_saved = None
    for name in files.upload():
        saved_to = f'{inbox}/{name}'
        shutil.move(f'/content/ComfyUI/{name}', saved_to)
        print(f"File saved to: {saved_to}")
        if first_saved is None:
            first_saved = saved_to
    return first_saved


def upload_fileAny(target_dir: str = '/content/ComfyUI/input', file_type: str = 'any') -> str:
    """
    Upload one file with the Colab widget and move it into ``target_dir``.

    Only the first uploaded file is handled. When ``file_type`` is 'image' or
    'audio' (case-insensitive), the file's extension is checked first; any other
    value skips the check.

    Args:
        target_dir: Directory to move the upload into (created if missing).
        file_type: 'image', 'audio', or 'any'.

    Returns:
        str: Destination path, or None if nothing was uploaded, the extension
        was rejected, or the move failed.
    """
    from google.colab import files
    import os
    import shutil

    os.makedirs(target_dir, exist_ok=True)

    uploaded = files.upload()
    if not uploaded:
        print("No file was uploaded")
        return None

    name = next(iter(uploaded))
    source = os.path.join('/content/ComfyUI', name)
    destination = os.path.join(target_dir, name)

    allowed_by_kind = {
        'image': ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'],
        'audio': ['.mp3', '.wav', '.ogg', '.flac', '.aac', '.m4a'],
    }
    kind = file_type.lower()
    allowed = allowed_by_kind.get(kind)
    if allowed is not None and os.path.splitext(name)[1].lower() not in allowed:
        print(f"Error: {name} is not an {kind} file")
        return None

    try:
        shutil.move(source, destination)
    except Exception as e:
        print(f"Error moving file: {str(e)}")
        return None
    print(f"File saved to: {destination}")
    return destination


def upload_fileInt():
    """
    Upload file(s) with the Colab widget and move them into /content/ComfyUI/output.

    Returns:
        str | None: Path of the first moved file, or None if nothing was uploaded.
    """
    outbox = '/content/ComfyUI/output'
    os.makedirs(outbox, exist_ok=True)

    saved = []
    for name in files.upload():
        saved_to = f'{outbox}/{name}'
        shutil.move(f'/content/ComfyUI/{name}', saved_to)
        saved.append(saved_to)
        print(f"File saved to: {saved_to}")

    if not saved:
        return None
    return saved[0]

def extract_frames(video_path, max_frames=None):
    """
    Decode a video into a batch of RGB frames.

    Args:
        video_path: Path to a video file readable by OpenCV.
        max_frames: Optional cap on the number of frames; None or 0 means no cap.

    Returns:
        tuple: ``(batch, fps)``. ``batch`` is a float tensor of shape
        (N, H, W, 3) with values in [0, 1], or None if no frame could be read.
        ``fps`` is the frame rate OpenCV reports for the file.
    """
    vidcap = cv2.VideoCapture(video_path)
    try:
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        frames = []
        # Check the cap before reading so no frame is decoded just to be discarded.
        while not max_frames or len(frames) < max_frames:
            success, frame = vidcap.read()
            if not success:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(torch.from_numpy(frame).float() / 255.0)
    finally:
        # Always release the capture handle, even if decoding raises.
        vidcap.release()

    if not frames:
        return None, fps

    return torch.stack(frames, dim=0), fps


def select_every_n_frame_tensor(
    frames_tensor: torch.Tensor,
    fps: float,
    n: int,
    skip_first: int = 0,
    max_output_frames: int = 0
):
    """
    Keep every ``n``-th frame after skipping a prefix, optionally capping the count.

    Args:
        frames_tensor: 4D frame batch (N, H, W, C).
        fps: Frame rate of ``frames_tensor``.
        n: Sampling stride (>= 1).
        skip_first: Number of leading frames to drop (>= 0).
        max_output_frames: If > 0, keep at most this many selected frames.

    Returns:
        tuple: ``(selected_frames, fps / n)``, or ``(None, 0.0)`` when
        ``skip_first`` leaves no frames.

    Raises:
        ValueError: On a non-4D/None tensor, ``n < 1``, or ``skip_first < 0``.
    """
    if frames_tensor is None or frames_tensor.ndim != 4:
        raise ValueError("frames_tensor must be a 4D tensor of shape (N, H, W, C)")
    if n < 1:
        raise ValueError("n must be >= 1")
    if skip_first < 0:
        # A negative offset would slice from the end and silently keep only
        # the last |skip_first| frames.
        raise ValueError("skip_first must be >= 0")

    total_frames = frames_tensor.shape[0]
    if skip_first >= total_frames:
        print("No frames available after skipping.")
        return None, 0.0

    # Equivalent to frames_tensor[skip_first:][::n], in one slice.
    selected_frames = frames_tensor[skip_first::n]

    if max_output_frames > 0:
        selected_frames = selected_frames[:max_output_frames]
        print(f"Frame cap: {max_output_frames} -> Final output: {selected_frames.shape[0]} frames")

    return selected_frames, fps / n

def swapT(pa, f, s):
    """Return ``s`` when ``pa`` equals ``f``; otherwise return ``pa`` unchanged."""
    return s if pa == f else pa



def image_width_height(image):
    """
    Return ``(width, height)`` of an HWC image or a BHWC image batch.

    Raises:
        ValueError: If the array/tensor is neither 3D nor 4D.
    """
    if image.ndim not in (3, 4):
        raise ValueError(f"Unsupported image shape: {image.shape}")
    # In both (H, W, C) and (B, H, W, C), height and width are the 3rd- and
    # 2nd-from-last axes.
    height = image.shape[-3]
    width = image.shape[-2]
    return width, height

def clear_memory():
    """
    Run Python garbage collection, then release cached CUDA memory if a GPU is present.

    Objects are collected first so tensors that just became unreachable are
    freed before the CUDA caching allocator is asked to release its unused
    blocks. (Globals are not touched: `del` on a loop variable only unbinds that
    local name and never frees a module-level reference.)
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

# def save_as_mp4(images, filename_prefix, fps, output_dir="/content/ComfyUI/output"):
#     os.makedirs(output_dir, exist_ok=True)
#     output_path = f"{output_dir}/{filename_prefix}.mp4"

#     frames = [(img.cpu().numpy() * 255).astype(np.uint8) for img in images]

#     with imageio.get_writer(output_path, fps=fps) as writer:
#         for frame in frames:
#             writer.append_data(frame)

#     return output_path

def save_as_mp4(images, filename_prefix, fps=25, audio_path=None, output_dir="/content/ComfyUI/output"):
    """
    Save images as an MP4 video, optionally muxing in an audio track.

    Args:
        images: Iterable of frames with values expected in [0, 1]; torch tensors
            are moved to CPU, anything else is treated as a NumPy array.
        filename_prefix: Output filename without extension.
        fps: Frames per second.
        audio_path: Optional path to an audio file to add with ffmpeg.
        output_dir: Output directory (created if missing).

    Returns:
        str: Path to the generated MP4. If adding audio fails, the silent video
        is returned at the same path.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Clamp to [0, 1] before scaling: out-of-range floats do not cast cleanly to
    # uint8 (a slightly negative value can wrap to a bright pixel).
    frames = [
        (np.clip(img.cpu().numpy() if hasattr(img, 'cpu') else img, 0.0, 1.0) * 255).astype(np.uint8)
        for img in images
    ]

    temp_video_path = f"{output_dir}/{filename_prefix}_temp.mp4"
    final_video_path = f"{output_dir}/{filename_prefix}.mp4"

    # Write the silent video first; audio (if any) is muxed in afterwards.
    with imageio.get_writer(temp_video_path, fps=fps) as writer:
        for frame in frames:
            writer.append_data(frame)

    if audio_path and os.path.exists(audio_path):
        cmd = [
            'ffmpeg',
            '-y',  # Overwrite without asking
            '-i', temp_video_path,
            '-i', audio_path,
            '-c:v', 'copy',  # Copy video stream without re-encoding
            '-c:a', 'aac',   # Encode audio to AAC
            '-shortest',     # Match duration of the shorter input
            final_video_path
        ]
        try:
            subprocess.run(cmd, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            print(f"Error adding audio: {e.stderr.decode(errors='replace')}")
        except OSError as e:
            # e.g. the ffmpeg binary is missing; fall back to the silent video.
            print(f"Error adding audio: {e}")
        else:
            os.remove(temp_video_path)
            print(f"Video with audio saved to: {final_video_path}")
            return final_video_path
        # Muxing failed: keep the silent video so the caller still gets a file.
        os.replace(temp_video_path, final_video_path)
        return final_video_path

    os.replace(temp_video_path, final_video_path)
    print(f"Video saved to: {final_video_path}")
    return final_video_path

def save_as_mp4U(images, filename_prefix, fps, output_dir="/content/ComfyUI/output"):
    """
    Save frames of mixed layout as an MP4, normalizing each to HWC uint8 first.

    Per frame:
    - Tensors are moved to CPU/NumPy.
    - Data whose max is <= 1.0 is scaled by 255; otherwise it is cast as-is.
    - A leading batch axis is dropped (first image kept).
    - CHW is transposed to HWC, and extra channels beyond 4 are cut to 3.
    - 2D frames gain a channel axis.

    Errors are printed and re-raised.

    Returns:
        str: Path to the written MP4.

    Raises:
        ValueError: If a frame cannot be brought into an (H, W, 1|3|4) layout.
    """
    def _to_hwc_uint8(frame):
        if isinstance(frame, torch.Tensor):
            frame = frame.cpu().numpy()

        if frame.max() <= 1.0:
            frame = (frame * 255).astype(np.uint8)
        else:
            frame = frame.astype(np.uint8)

        if frame.ndim == 4:  # batch axis: keep the first image
            frame = frame[0]

        if frame.ndim == 3:
            if frame.shape[0] in (1, 3, 4):  # CHW -> HWC
                frame = np.transpose(frame, (1, 2, 0))
            elif frame.shape[2] > 4:  # too many channels
                frame = frame[:, :, :3]
        elif frame.ndim == 2:
            frame = np.expand_dims(frame, axis=-1)

        if frame.ndim != 3 or frame.shape[2] not in (1, 3, 4):
            raise ValueError(f"Invalid frame shape after processing: {frame.shape}")
        return frame

    os.makedirs(output_dir, exist_ok=True)
    output_path = f"{output_dir}/{filename_prefix}.mp4"

    prepared = []
    for index, raw in enumerate(images):
        try:
            prepared.append(_to_hwc_uint8(raw))
        except Exception as e:
            print(f"Error processing frame {index}: {str(e)}")
            raise

    try:
        with imageio.get_writer(output_path, fps=fps) as writer:
            for frame in prepared:
                writer.append_data(frame)
    except Exception as e:
        print(f"Error writing video: {str(e)}")
        raise

    return output_path

def save_as_webp(images, filename_prefix, fps, quality=90, lossless=False, method=4, output_dir="/content/ComfyUI/output"):
    """
    Save images as an animated WEBP using imageio.

    Args:
        images: Iterable of image tensors with values expected in [0, 1]; each
            is moved to CPU and converted to uint8.
        filename_prefix: Output filename without extension.
        fps: Frame rate (cast to int).
        quality: Passed to the WEBP writer as ``quality`` (cast to int).
        lossless: Passed to the WEBP writer as ``lossless``.
        method: Passed to the WEBP writer as ``method`` (cast to int).
        output_dir: Output directory (created if missing).

    Returns:
        str: Path to the written .webp file.
    """
    os.makedirs(output_dir, exist_ok=True)
    output_path = f"{output_dir}/{filename_prefix}.webp"

    # Clamp before scaling: out-of-range floats do not cast cleanly to uint8 and
    # can wrap around into wrong pixel values.
    frames = [(np.clip(img.cpu().numpy(), 0.0, 1.0) * 255).astype(np.uint8) for img in images]

    kwargs = {
        'fps': int(fps),
        'quality': int(quality),
        'lossless': bool(lossless),
        'method': int(method)
    }

    with imageio.get_writer(
        output_path,
        format='WEBP',
        mode='I',
        **kwargs
    ) as writer:
        for frame in frames:
            writer.append_data(frame)

    return output_path

def save_as_webm(images, filename_prefix, fps, codec="vp9", quality=32, output_dir="/content/ComfyUI/output"):
    """
    Save images as a WEBM video through imageio's FFMPEG writer.

    Args:
        images: Iterable of image tensors with values expected in [0, 1]; each
            is moved to CPU and converted to uint8.
        filename_prefix: Output filename without extension.
        fps: Frame rate (cast to int).
        codec: "vp9" / "vp8" (mapped to FFmpeg's libvpx encoders) or any other
            FFmpeg encoder name, passed through unchanged.
        quality: Constant rate factor passed to FFmpeg as ``-crf`` (lower means
            higher quality).
        output_dir: Output directory (created if missing).

    Returns:
        str: Path to the written .webm file.
    """
    os.makedirs(output_dir, exist_ok=True)
    output_path = f"{output_dir}/{filename_prefix}.webm"

    # Clamp before scaling: out-of-range floats do not cast cleanly to uint8.
    frames = [(np.clip(img.cpu().numpy(), 0.0, 1.0) * 255).astype(np.uint8) for img in images]

    # FFmpeg has no encoders literally named "vp9"/"vp8"; WebM uses libvpx.
    encoder = {"vp9": "libvpx-vp9", "vp8": "libvpx"}.get(str(codec).lower(), str(codec))
    output_params = ['-crf', str(int(quality))]
    if encoder == "libvpx-vp9":
        # libvpx-vp9 only uses pure constant-quality mode when the bitrate is 0.
        output_params += ['-b:v', '0']

    with imageio.get_writer(
        output_path,
        format='FFMPEG',
        mode='I',
        fps=int(fps),
        codec=encoder,
        # imageio's own `quality` uses a 0-10 scale and adds its own rate flags;
        # a CRF-style value like 32 is outside that range. Disable it so the
        # -crf above is the only rate control.
        quality=None,
        output_params=output_params,
    ) as writer:
        for frame in frames:
            writer.append_data(frame)

    return output_path

def save_as_image(image, filename_prefix, output_dir="/content/ComfyUI/output"):
    """
    Save a single frame tensor as a PNG image and return its path.

    Args:
        image: Frame tensor with values expected in [0, 1]; presumably already
            in an (H, W) or (H, W, C) layout accepted by PIL — not reshaped here.
        filename_prefix: Output filename without extension.
        output_dir: Output directory (created if missing).

    Returns:
        str: Path to the written .png file.
    """
    os.makedirs(output_dir, exist_ok=True)
    output_path = f"{output_dir}/{filename_prefix}.png"

    # Clamp before scaling: out-of-range floats do not cast cleanly to uint8.
    frame = (np.clip(image.cpu().numpy(), 0.0, 1.0) * 255).astype(np.uint8)

    Image.fromarray(frame).save(output_path)

    return output_path

def save_as_image2(image, filename_prefix, output_dir="/content/ComfyUI/output"):
    """
    Save a single frame as a PNG, accepting a tensor or array, batched or CHW input.

    Args:
        image: Tensor or NumPy array with values expected in [0, 1]. A 4D input
            keeps only its first item; a leading axis of size 3 is treated as CHW
            and transposed to HWC.
        filename_prefix: Output filename without extension.
        output_dir: Output directory (created if missing).

    Returns:
        str: Path to the written .png file.
    """
    os.makedirs(output_dir, exist_ok=True)
    output_path = f"{output_dir}/{filename_prefix}.png"

    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()
    if image.ndim == 4:  # Batch dimension
        image = image[0]
    if image.shape[0] == 3:  # CHW to HWC
        image = np.transpose(image, (1, 2, 0))
    # Clamp before scaling: out-of-range floats do not cast cleanly to uint8.
    image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)

    Image.fromarray(image).save(output_path)
    return output_path


def upload_image():
    """
    Upload via the Colab widget and move the first file into /content/ComfyUI/input/.

    Only the first uploaded file is moved; any others are left where the widget
    saved them.

    Returns:
        str | None: New path of the moved file, or None if nothing was uploaded.
    """
    from google.colab import files
    import os
    import shutil

    os.makedirs('/content/ComfyUI/input', exist_ok=True)

    uploaded = files.upload()
    first = next(iter(uploaded.keys()), None)
    if first is None:
        return None

    dest_path = f'/content/ComfyUI/input/{first}'
    shutil.move(f'/content/ComfyUI/{first}', dest_path)
    return dest_path



# Module-level output path, initialized empty (presumably assigned by the
# generation code later in the notebook — not visible here).
output_path = ""

# file_uploaded = None
# file_uploaded2 = None

def _free_vram():
    """Release PyTorch's cached CUDA memory, then run Python's garbage collector."""
    torch.cuda.empty_cache()
    gc.collect()


def _scale_image(image_scaler, image, width, height):
    """Resize an image (or a batch of video frames) to width x height.

    Uses Lanczos resampling with cropping disabled, matching every resize
    performed by generate_video.
    """
    return image_scaler.upscale(image, "lanczos", width, height, "disabled")[0]


def _load_wan_model(wan_model_loader, model_name, load_device, attention_mode,
                    block_swap_args, lora, multitalk_model):
    """Load a WanVideo DiT with fp16 precision, no quantization and no compile."""
    return wan_model_loader.loadmodel(
        model=model_name,
        base_precision="fp16",
        load_device=load_device,
        quantization="disabled",
        compile_args=None,
        attention_mode=attention_mode,
        block_swap_args=block_swap_args,
        lora=lora,
        vram_management_args=None,
        vace_model=None,
        fantasytalking_model=None,
        multitalk_model=multitalk_model,
        fantasyportrait_model=None
    )[0]


def _run_wan_sampler(wan_sampler, model, image_embeds, text_embeds, multitalk_embeds,
                     steps, cfg, seed, scheduler, samples, start_step, end_step):
    """Run one WanVideoSampler pass with the settings shared by every pass.

    Returns the sampled latent (the node's first output).
    """
    # NOTE(review): shift is fixed at 8. generate_video's shift/shift2 and
    # enable_flow_shift* arguments are not applied here.
    return wan_sampler.process(
        model=model,
        image_embeds=image_embeds,
        shift=8,
        steps=steps,
        cfg=cfg,
        seed=seed,
        scheduler=scheduler,
        riflex_freq_index=0,
        text_embeds=text_embeds,
        force_offload=True,
        samples=samples,
        feta_args=None,
        denoise_strength=1.0,
        context_options=None,
        cache_args=None,
        teacache_args=None,
        flowedit_args=None,
        batched_cfg=False,
        slg_args=None,
        rope_function="default",
        loop_args=None,
        experimental_args=None,
        sigmas=None,
        unianimate_poses=None,
        fantasytalking_embeds=None,
        uni3c_embeds=None,
        multitalk_embeds=multitalk_embeds,
        freeinit_args=None,
        start_step=start_step,
        end_step=end_step,
        add_noise_to_samples=False
    )[0]


def generate_video(
    image_path: str = None,
    image_path2: str = None,
    audio_path: str = None,
    audio_path2: str = None,
    audio_path3: str = None,
    audio_path4: str = None,
    LoRA_Strength: float = 1.00,
    rel_l1_thresh: float = 0.275,
    start_percent: float = 0.1,
    end_percent: float = 1.0,
    positive_prompt: str = "a cute anime girl with massive fennec ears and a big fluffy tail wearing a maid outfit turning around",
    prompt_assist: str = "walking to viewers",
    negative_prompt: str = "色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走",
    width: int = 832,
    height: int = 480,
    custom_audio_duration: int = 30,
    seed: int = 82628696717253,
    steps: int = 20,
    cfg_scale: float = 1.0,
    sampler_name: str = "uni_pc",
    scheduler: str = "simple",
    fps: int = 16,
    output_format: str = "mp4",
    overwrite: bool = False,
    use_lora: bool = True,
    use_lora2: bool = True,
    LoRA_Strength2: float = 1.00,
    use_lora3: bool = True,
    LoRA_Strength3: float = 1.00,
    use_lightx2v: bool = False,
    lightx2v_Strength: float = 0.80,
    lightx2v_steps: int = 4,
    use_pusa: bool = False,
    pusa_Strength: float = 1.2,
    pusa_steps: int = 6,
    use_sage_attention: bool = True,
    enable_flow_shift: bool = True,
    shift: float = 8.0,
    enable_flow_shift2: bool = True,
    shift2: float = 8.0,
    end_step1: int = 10,
    use_one_model: bool = True,
    use_block_swap: bool = True,
    blocks_to_swap: int = 20
):
    """Generate an audio-driven talking video (InfiniteTalk / MultiTalk on Wan 2.1).

    Pipeline:
      1. Encode the prompts.
      2. Load and scale the reference image (or video frames) and encode it
         with CLIP vision.
      3. Load up to four audio tracks and derive the frame count from the
         audio duration.
      4. Compute wav2vec embeddings and save the processed audio.
      5. Build the image embeddings, load the DiT with the MultiTalk /
         InfiniteTalk module, sample, and decode with the VAE.
      6. Save and display the result.

    Args:
        image_path: Reference .png/.jpg/.jpeg image. Any other extension is
            read as a video via extract_frames. If None, prompts for an upload.
        image_path2: Optional second image. It is loaded and scaled (and may
            set `height`) but is not fed to any later node.
        audio_path, audio_path2, audio_path3, audio_path4: Speaker audio files.
            audio_path is required (prompts for upload when None).
        custom_audio_duration: Upper bound in seconds on the audio duration used
            to compute the frame count.
        width, height: Output resolution. height == 0 derives it from the
            image's aspect ratio.
        sampler_name: Forwarded to WanVideoSampler as its `scheduler` argument.
        steps: Sampling steps, replaced by lightx2v_steps when use_lightx2v
            is True.
        use_lightx2v: Attach the speed LoRA (module-level
            `lightx2v_lora_lowNoise`) at strength `pusa_Strength`.
        use_pusa: Two-model mode only. Applies `lightx2v_lora_lowNoise` to the
            low-noise model.
        use_sage_attention: Use "sageattn" attention, otherwise "sdpa".
        end_step1: Two-model mode only. Step at which the high-noise pass stops
            and the low-noise pass resumes.
        use_one_model: Run the full schedule with one model. Otherwise a second
            model (module-level `dit_model2`) finishes the schedule.
        use_block_swap, blocks_to_swap: Enable WanVideoBlockSwap offloading for
            the first model.
        output_format: "mp4" (with audio) or "webm".
        overwrite: When False, a timestamp is appended to the output file name.

    The following are accepted for interface compatibility but currently have
    no effect: LoRA_Strength, rel_l1_thresh, start_percent, end_percent,
    prompt_assist, scheduler, use_lora*, LoRA_Strength2/3, lightx2v_Strength,
    pusa_steps, enable_flow_shift, shift, enable_flow_shift2, shift2.

    Also reads the module-level `use_multiTalk_instead`, `multitalkModel`,
    `dit_model`, `dit_model2` and `lightx2v_lora_lowNoise`, presumably set
    by an earlier notebook cell.

    Side effects: writes /content/cropped_audio.wav and the output file, and
    rebinds the module-level `output_path`.

    Raises:
        ValueError: If no image or audio is provided, or output_format is
            unsupported.
    """
    global output_path

    with torch.inference_mode():

        # ComfyUI / WanVideoWrapper nodes used below.
        load_audio = LoadAudio()
        load_wav2vec = DownloadAndLoadWav2VecModel()
        wan_model_loader = WanVideoModelLoader()
        wan_lora_select = WanVideoLoraSelect()
        wan_vae_loader = WanVideoVAELoader()
        wan_vae_decoder = WanVideoDecode()
        wan_text_embed_bridge = WanVideoTextEmbedBridge()
        wan_clip_vision = WanVideoClipVisionEncode()
        multitalk_loader = MultiTalkModelLoader()
        multitalk_wav2vec = MultiTalkWav2VecEmbeds()
        multitalk_img2vid = WanVideoImageToVideoMultiTalk()
        wan_sampler = WanVideoSampler()
        block_swapper = WanVideoBlockSwap()
        clip_loader = CLIPLoader()
        clip_encode_positive = CLIPTextEncode()
        clip_encode_negative = CLIPTextEncode()
        clip_vision_loader = CLIPVisionLoader()
        load_image = LoadImage()
        load_pusa_lora = LoraLoaderModelOnly()
        image_scaler = ImageScale()

        attention_mode = "sageattn" if use_sage_attention else "sdpa"

        # --- Text encoding (the text encoder is freed right after) -------------
        print("Loading Text_Encoder...")
        clip = clip_loader.load_clip("umt5_xxl_fp8_e4m3fn_scaled.safetensors", "wan", "default")[0]

        positive = clip_encode_positive.encode(clip, positive_prompt)[0]
        negative = clip_encode_negative.encode(clip, negative_prompt)[0]

        del clip
        _free_vram()

        text_embeds = wan_text_embed_bridge.process(positive, negative)[0]

        # --- Reference image / video --------------------------------------------
        print("Loading vision_Encoder...")
        clip_vision = clip_vision_loader.load_clip("clip_vision_h.safetensors")[0]

        if image_path is None:
            print("Please upload an image file to avoid errors:")
            image_path = upload_image()
        if image_path is None:
            raise ValueError("No image uploaded!")

        if image_path.lower().endswith(('.png', '.jpg', '.jpeg')):
            loaded_image = load_image.load_image(image_path)[0]

            width_int, height_int = image_width_height(loaded_image)

            if height == 0:
                height = int(width * height_int / width_int)

            print(f"Image resolution is {width_int}x{height_int}")
            print(f"Scaling image to {width}x{height}...")
            loaded_image = _scale_image(image_scaler, loaded_image, width, height)

        else:
            # Any other extension is treated as a video: all frames are scaled.
            video_frames, _video_fps = extract_frames(image_path)
            width_int, height_int = image_width_height(video_frames[0])
            print(f"Video resolution is {width_int}x{height_int}")
            print(f"Scaling video to {width}x{height}...")
            loaded_image = _scale_image(image_scaler, video_frames, width, height)

        print("Processing with clip vision...")
        clip_vision_output = wan_clip_vision.process(
            clip_vision=clip_vision,
            image_1=loaded_image,
            strength_1=1.0,
            strength_2=1.0,
            force_offload=False,
            crop="disabled",
            combine_embeds="average",
            image_2=None,
            negative_image=None,
            tiles=0,
            ratio=0.5
        )[0]

        # --- Audio ------------------------------------------------------------------
        if audio_path is None:
            print("Please upload an audio file:")
            audio_path = upload_fileAny(file_type='audio')
        if audio_path is None:
            raise ValueError("No audio uploaded!")

        audio_paths = [audio_path, audio_path2, audio_path3, audio_path4]
        # Keep one slot per speaker (None for missing) for the wav2vec node.
        loaded_audios = [load_audio.load(p)[0] if p is not None else None for p in audio_paths]
        non_none_audios = [a for a in loaded_audios if a is not None]
        # NOTE(review): only 3+ tracks are concatenated for the duration estimate;
        # with exactly two tracks, the first track's duration is used. Confirm
        # this matches the "para" (parallel) mode used below.
        if len(non_none_audios) > 2:
            combined_waveform = torch.cat([a["waveform"] for a in non_none_audios], dim=-1)
            sample_rate = non_none_audios[0]["sample_rate"]
            combined_audio = {"waveform": combined_waveform, "sample_rate": sample_rate}
        else:
            combined_audio = non_none_audios[0]

        audio_duration = combined_audio["waveform"].shape[-1] / combined_audio["sample_rate"]

        print(f"Input audio duration is {audio_duration} seconds, chosen audio duration is: {custom_audio_duration} seconds.")

        # Only the duration matters here: the audio that is written to disk and
        # muxed into the video is returned by MultiTalkWav2VecEmbeds below.
        if custom_audio_duration < audio_duration:
            print("Reducing audio length...")
            audio_duration = int(custom_audio_duration)

        if image_path2 is not None:
            # NOTE(review): the scaled second image is not passed to any later
            # node; only the prints and the height recomputation have an effect.
            loaded_image2 = load_image.load_image(image_path2)[0]

            width_int, height_int = image_width_height(loaded_image2)

            if height == 0:
                height = int(width * height_int / width_int)

            print(f"Second Image resolution is {width_int}x{height_int}")
            print(f"Scaling Second image to {width}x{height}...")
            loaded_image2 = _scale_image(image_scaler, loaded_image2, width, height)

        # At least one second of video; the frame count is forced to be odd.
        frames = max(1*fps, int(audio_duration * fps))
        if frames % 2 == 0:
            frames += 1

        print(f"Audio duration is now {audio_duration} seconds, and frames is: {frames}.")

        print("Loading wav2vec...")
        wav2vec_model = load_wav2vec.loadmodel("TencentGameMate/chinese-wav2vec2-base", "fp16", "main_device")[0]

        print("Embedding wav2vec...")
        wav2vec_embeds, combined_audio, actual_num_frames = multitalk_wav2vec.process(
            wav2vec_model,
            True,
            fps,
            frames,
            loaded_audios[0],
            1,
            1,
            "para",
            loaded_audios[1],
            loaded_audios[2],
            loaded_audios[3]
        )

        output_audio_path = "/content/cropped_audio.wav"
        waveform = combined_audio["waveform"].squeeze(0)  # [channels, samples]
        torchaudio.save(output_audio_path, waveform, combined_audio["sample_rate"])

        del wav2vec_model
        _free_vram()

        # --- Conditioning and model loading ---------------------------------------
        print("Loading VAE...")
        vae = wan_vae_loader.loadmodel("wan_2.1_vae.safetensors", "fp16")[0]

        image_embeds = multitalk_img2vid.process(vae, width, height, frames, fps, False, 'mkl', loaded_image, False, clip_vision_output)[0]

        usedSteps = steps

        if use_multiTalk_instead:
            print("Loading multiTalk Model...")
        else:
            print("Loading infiniteTalk Model...")

        multitalk_model = multitalk_loader.loadmodel(multitalkModel, "fp16")[0]

        # Default to "no LoRA" so the model loader never sees an unbound name
        # when the speed LoRA is disabled.
        wan_speed_lora = None
        if use_lightx2v:
            print("Loading speed LoRA...")
            wan_speed_lora = wan_lora_select.getlorapath(lightx2v_lora_lowNoise, pusa_Strength, None, {}, None, False, False)[0]
            usedSteps = lightx2v_steps

        # One model runs the whole schedule (-1 = to the end). In two-model mode
        # the high-noise pass stops at end_step1 and the low-noise pass resumes
        # from there.
        end_step1in = -1 if use_one_model else end_step1

        clear_output()

        print("Loading Model...")

        if use_block_swap:
            block_swap_args = block_swapper.setargs(
                blocks_to_swap=blocks_to_swap,
                offload_img_emb=False,
                offload_txt_emb=False,
                use_non_blocking=True,
                vace_blocks_to_swap=0,
                prefetch_blocks=0,
                block_swap_debug=False
            )[0]
        else:
            block_swap_args = None

        model = _load_wan_model(
            wan_model_loader, dit_model, "offload_device", attention_mode,
            block_swap_args, wan_speed_lora, multitalk_model
        )

        del multitalk_model
        _free_vram()

        clear_output()

        # --- Sampling ---------------------------------------------------------------
        if use_one_model:
            print("Generating video...")
        else:
            print("Generating video with high noise model...")

        sampled = _run_wan_sampler(
            wan_sampler, model, image_embeds, text_embeds, wav2vec_embeds,
            usedSteps, cfg_scale, seed, sampler_name,
            samples=None, start_step=0, end_step=end_step1in
        )

        del model
        _free_vram()

        if not use_one_model:
            multitalk_model = multitalk_loader.loadmodel(multitalkModel, "fp16")[0]

            print("Loading low noise Model...")
            model = _load_wan_model(
                wan_model_loader, dit_model2, "main_device", attention_mode,
                None, None, multitalk_model
            )

            del multitalk_model
            _free_vram()

            if use_pusa:
                print("Loading low noise LoRA...")
                model = load_pusa_lora.load_lora_model_only(model, lightx2v_lora_lowNoise, pusa_Strength)[0]
                usedSteps = lightx2v_steps

            print("Generating video with low noise model...")
            sampled = _run_wan_sampler(
                wan_sampler, model, image_embeds, text_embeds, wav2vec_embeds,
                usedSteps, cfg_scale, seed, sampler_name,
                samples=sampled, start_step=end_step1, end_step=-1
            )

            del model
            _free_vram()

        # --- Decode, save, display ----------------------------------------------------
        try:
            print("Decoding latents...")
            decoded = wan_vae_decoder.decode(
                vae=vae,
                samples=sampled,
                enable_vae_tiling=False,
                tile_x=272,
                tile_y=272,
                tile_stride_x=144,
                tile_stride_y=128,
                normalization="default")[0]

            del vae
            _free_vram()

            import datetime
            base_name = "ComfyUI"
            if not overwrite:
                timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
                base_name += f"_{timestamp}"
            if frames == 1:
                print("Single frame detected - saving as PNG image...")
                # Use base_name so the overwrite flag also applies to stills.
                output_path = save_as_image(decoded[0], base_name)

                display(IPImage(filename=output_path))
            else:
                if output_format.lower() == "webm":
                    print("Saving as WEBM...")
                    output_path = save_as_webm(
                        decoded,
                        base_name,
                        fps=fps,
                        codec="vp9",
                        quality=10
                    )
                elif output_format.lower() == "mp4":
                    print("Saving as MP4...")
                    output_path = save_as_mp4(
                        images=decoded,
                        filename_prefix=base_name,
                        fps=fps,
                        audio_path=output_audio_path
                    )

                    # Additional fixed-name copy without an audio track.
                    output_path2 = save_as_mp4(decoded, "ComfyUI", fps)
                else:
                    raise ValueError(f"Unsupported output format: {output_format}")

                display_video(output_path)

        except Exception as e:
            print(f"Error during decoding/saving: {str(e)}")
            raise
        finally:
            clear_memory()



def display_video(video_path):
    """Embed a local video or animation in the notebook output as a base64 data URL.

    .mp4 and .webm files are shown in a looping, autoplaying <video> element.
    Animated .webp files are shown in an <img> element, because browsers do not
    play WEBP inside <video>. Unknown extensions are treated as MP4.

    Args:
        video_path: Path to the file to display.
    """
    from IPython.display import HTML
    from base64 import b64encode

    # Context manager so the file handle is closed promptly.
    with open(video_path, 'rb') as fh:
        video_data = fh.read()

    # Determine MIME type based on file extension
    lower_path = video_path.lower()
    if lower_path.endswith('.mp4'):
        mime_type = "video/mp4"
    elif lower_path.endswith('.webm'):
        mime_type = "video/webm"
    elif lower_path.endswith('.webp'):
        mime_type = "image/webp"
    else:
        mime_type = "video/mp4"  # default

    data_url = f"data:{mime_type};base64," + b64encode(video_data).decode()

    if mime_type == "image/webp":
        # Animated WEBP is an image format: render it as <img> so it animates.
        display(HTML(f'<img width=512 src="{data_url}">'))
        return

    display(HTML(f"""
    <video width=512 controls autoplay loop>
        <source src="{data_url}" type="{mime_type}">
    </video>
    """))

# Clear the setup cell's accumulated output so only the completion message remains.
clear_output()

print("✅ Environment Setup Complete!")


# @markdown ---


In [None]:

# @markdown # 💥2. Upload Image (png, jpg, jpeg)/Video
file_uploaded = upload_file()
display_upload = False # @param {type:"boolean"}
# Preview only when requested and a file was actually uploaded. The uploader
# may return None when nothing was chosen (as upload_image does), and
# None.lower() would raise.
if display_upload and file_uploaded:
    if file_uploaded.lower().endswith(('.png', '.jpg', '.jpeg')):
        display(IPImage(filename=file_uploaded))
    else:
        # Anything that is not png/jpg/jpeg is shown through the video player.
        print("If Image, then Image format cannot be displayed.")
        display_video(file_uploaded)
# @markdown ---

In [None]:

# @markdown # 💥3. Upload Audio
# First (required) speaker audio. The generation cell reads it via
# globals().get("audio_uploaded") and passes it as audio_path.
audio_uploaded = upload_fileAny(file_type='audio')


In [None]:
# @markdown # Upload Audio 2 (Optional)
# Optional second speaker audio. The generation cell reads it via
# globals().get("audio_uploaded2") and passes it as audio_path2.
audio_uploaded2 = upload_fileAny(file_type='audio')
# @markdown ---

In [None]:

# @markdown # 💥4. Generate Video
import time
start_time = time.time()
# @markdown ### Video Settings
# use_image1_as_first_last = False # @param {type:"boolean"}
# disable_image2 = False # @param {type:"boolean"}
positive_prompt = "The woman and man take turns talking to each other" # @param {"type":"string"}
prompt_assist = "none"
# prompt_assist = swapT(prompt_assist, "walking to camera", "walking to viewers")
# prompt_assist = swapT(prompt_assist, "walking from camera", "walking from behind")
# prompt_assist = swapT(prompt_assist, "swaying", "b3ll13-d8nc3r")
# positive_prompt = f"{positive_prompt} {prompt_assist}." if prompt_assist != "none" else positive_prompt
# positive_prompt = f"{positive_prompt} Turn this image into {prompt_assist} style." if prompt_assist != "none" else positive_prompt

negative_prompt = "bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" # @param {"type":"string"}
width = 400 # @param {"type":"number"}
height = 704 # @param {"type":"number"}
reduce_audio_duration_to = 20 # @param {"type":"number"}
seed = 2335434353 # @param {"type":"integer"}
high_noise_steps = 2
steps = 4 # @param {"type":"integer", "min":1, "max":50}
cfg_scale = 1 # @param {"type":"number", "min":1, "max":20}
scheduler = "flowmatch_distill" # @param ["flowmatch_distill","uni_pc", "uni_pc_bh2", "ddim","euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral","lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu","dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm","ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp","gradient_estimation", "er_sde", "seeds_2", "seeds_3"]
sampler_name = "simple"

frames = 81
# fps = 16 # @param {"type":"integer", "min":1, "max":60}
fps = 25
# output_format = "mp4" # @param ["mp4", "webm"]
output_format = "mp4"
overwrite_previous_video = False # @param {type:"boolean"}

# @markdown ### Model Configuration
use_block_swap = False # @param {type:"boolean"}
blocks_to_swap = 20 # @param {"type":"number"}
use_sage_attention = True # @param {type:"boolean"}
# use_sage_attention = True
use_flow_shift = True
flow_shift = 11
flow_shift2 = 11
use_one_model = True

# use_causvid = False # @param {type:"boolean"}
# causvid_Strength = 0.8 # @param {"type":"slider","min":-100,"max":100,"step":0.01}
# causvid_steps = 4 # @param {"type":"integer", "min":1, "max":20}
use_high_noise_speed_LoRA = True
high_noise_speed_LoRA_Strength = 1
# lightx2v_steps = 4 # @param {"type":"integer", "min":1, "max":20}
use_speed_LoRA = True # @param {type:"boolean"}
speed_LoRA_Strength = 1 # @param {"type":"slider","min":-10,"max":10,"step":0.01}
# pusav1_steps = 6 # @param {"type":"integer", "min":1, "max":20}

use_lora = False
LoRA_Strength = 1.5
use_lora2 = False
LoRA_Strength2 = 1.0
use_lora3 = False #
LoRA_Strength3 = 1.0

rel_l1_thresh = 0
start_percent = 0.2
end_percent = 1.0

# @markdown ---

import random

# A seed of 0 means "pick one at random". Any other value is used as-is, so a
# run can be repeated by entering the seed printed below.
seed = seed if seed != 0 else random.randint(0, 2**32 - 1)
print(f"Using seed: {seed}")

# No second reference image is used by this cell.
second_image = None

# Optional audio tracks. Each is read from a global that an earlier upload cell
# may have set; a missing global becomes None.
audio_driver  = globals().get("audio_uploaded", None)
audio_driver2 = globals().get("audio_uploaded2", None)
audio_driver3 = globals().get("audio_uploaded3", None)
audio_driver4 = globals().get("audio_uploaded4", None)

# generate_video can raise, e.g. an out-of-memory error on long or
# high-resolution runs. Without try/finally, a failure skipped clear_memory()
# entirely. Now the cleanup always runs, and timing is reported only on success.
try:
    # with torch.inference_mode():
    generate_video(
        image_path=file_uploaded,
        image_path2=second_image,
        audio_path = audio_driver,
        audio_path2 = audio_driver2,
        audio_path3 = audio_driver3,
        audio_path4 = audio_driver4,
        LoRA_Strength=LoRA_Strength,
        rel_l1_thresh=rel_l1_thresh,
        start_percent=start_percent,
        end_percent = end_percent,
        positive_prompt=positive_prompt,
        prompt_assist=prompt_assist,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        custom_audio_duration=reduce_audio_duration_to,
        seed=seed,
        steps=steps,
        cfg_scale=cfg_scale,
        # Deliberately crossed: the form variable `scheduler` holds the sampler
        # choice and `sampler_name` holds the scheduler ("simple"). Confirm
        # against generate_video's signature before "fixing" this.
        sampler_name=scheduler,
        scheduler=sampler_name,
        # frames=frames,
        fps=fps,
        output_format=output_format,
        overwrite=overwrite_previous_video,
        use_lora = use_lora,
        use_lora2=use_lora2,
        LoRA_Strength2=LoRA_Strength2,
        use_lora3=use_lora3,
        LoRA_Strength3=LoRA_Strength3,
        # use_causvid=use_causvid,
        # causvid_Strength=causvid_Strength,
        # causvid_steps=causvid_steps,
        # Both speed LoRAs run for the same number of sampling steps.
        use_lightx2v=use_high_noise_speed_LoRA,
        lightx2v_Strength=high_noise_speed_LoRA_Strength,
        lightx2v_steps=steps,
        use_pusa=use_speed_LoRA,
        pusa_Strength=speed_LoRA_Strength,
        pusa_steps=steps,
        use_sage_attention = use_sage_attention,
        # A single flag (use_flow_shift) enables both flow shifts.
        enable_flow_shift = use_flow_shift,
        shift = flow_shift,
        enable_flow_shift2 = use_flow_shift,
        shift2 = flow_shift2,
        end_step1 = high_noise_steps,
        use_one_model = use_one_model,
        use_block_swap = use_block_swap,
        blocks_to_swap = blocks_to_swap
    )

    # start_time must already be defined before this point (it is not set in
    # this part of the cell).
    end_time = time.time()
    duration = end_time - start_time
    mins, secs = divmod(duration, 60)
    print(f"Seed: {seed}")
    # print(f"prompt: {positive_prompt}")
    print(f"✅ Generation completed in {int(mins)} min {secs:.2f} sec")
finally:
    clear_memory()