# 初始化模型

In [1]:
import torch
from typing import Tuple, Optional, List, Union
import os
import math
import traceback
import einops
import numpy as np
from PIL import Image
from diffusers import WanPipeline
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel,register_to_config
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers_helper.utils import save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, generate_timestamp
from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, unload_complete_models, load_model_as_complete
from diffusers_helper.bucket_tools import find_nearest_bucket
from diffusers.schedulers import UniPCMultistepScheduler

import torch.utils.checkpoint

from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.autoencoders.autoencoder_kl_wan import WanCausalConv3d, WanEncoder3d, WanDecoder3d, AutoencoderKLWan as OriginalAutoencoderKLWan



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Identifier of the Wan 2.1 text-to-video 1.3B checkpoint in Diffusers layout.
# NOTE(review): presumably a Hugging Face Hub repo id consumed by `from_pretrained`
# calls later in the notebook — confirm against the loading cell.
model_id = 'Wan-AI/Wan2.1-T2V-1.3B-Diffusers'
# Alternative kept for reference: a local Hugging Face cache directory.
# model_id = '/home/tippy/.cache/huggingface/models--Wan-AI--Wan2.1-T2V-1.3B-Diffusers/'

In [3]:

class AutoencoderKLWan(OriginalAutoencoderKLWan):
    r"""
    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
    Introduced in [Wan 2.1].

    This class subclasses the diffusers `AutoencoderKLWan` (imported here as `OriginalAutoencoderKLWan`) and
    re-declares the constructor, the encode/decode entry points and the spatial tiling helpers locally.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        base_dim: int = 96,
        z_dim: int = 16,
        dim_mult: Tuple[int] = [1, 2, 4, 4],
        num_res_blocks: int = 2,
        attn_scales: List[float] = [],
        temperal_downsample: List[bool] = [False, True, True],
        dropout: float = 0.0,
        latents_mean: List[float] = [
            -0.7571,
            -0.7089,
            -0.9113,
            0.1075,
            -0.1745,
            0.9653,
            -0.1517,
            1.5508,
            0.4134,
            -0.0715,
            0.5517,
            -0.3632,
            -0.1922,
            -0.9497,
            0.2503,
            -0.2921,
        ],
        latents_std: List[float] = [
            2.8184,
            1.4541,
            2.3275,
            2.6558,
            1.2196,
            1.7708,
            2.6052,
            2.0743,
            3.2687,
            2.1526,
            2.8652,
            1.5579,
            1.6382,
            1.1253,
            2.8251,
            1.9160,
        ],
    ) -> None:
        # Build the encoder, decoder and the two 1x1x1 causal convolutions, and set the slicing/tiling defaults.
        #
        # `latents_mean` / `latents_std` are accepted but never read in this constructor body.
        # NOTE(review): presumably they are only captured into the config by `@register_to_config` — confirm.

        # NOTE(review): the parent initializer is called without arguments, so whatever it builds uses its own
        # defaults; `encoder`, `quant_conv`, `post_quant_conv` and `decoder` are then (re)assigned below from this
        # constructor's arguments. Confirm that this double construction is intended.
        super().__init__()

        self.z_dim = z_dim
        self.temperal_downsample = temperal_downsample
        # The decoder receives the temporal flags in reverse order.
        self.temperal_upsample = temperal_downsample[::-1]

        # The encoder emits `z_dim * 2` channels, which `encode` wraps in a `DiagonalGaussianDistribution`.
        self.encoder = WanEncoder3d(
            base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
        )
        self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)

        self.decoder = WanDecoder3d(
            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
        )

        # Derived from the number of entries in `temperal_downsample` (2 ** 3 == 8 with the default list); used
        # below to convert tile sizes between sample (pixel) space and latent space.
        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)

        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
        # to perform decoding of a single video latent at a time.
        self.use_slicing = False

        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
        # intermediate tiles together, the memory requirement can be lowered.
        self.use_tiling = False

        # The minimal tile height and width for spatial tiling to be used
        self.tile_sample_min_height = 256
        self.tile_sample_min_width = 256

        # The stride (in sample space) between the origins of two neighbouring tiles; the overlap between
        # neighbouring tiles is therefore `tile_sample_min_* - tile_sample_stride_*`.
        self.tile_sample_stride_height = 192
        self.tile_sample_stride_width = 192

    def enable_tiling(
        self,
        tile_sample_min_height: Optional[int] = None,
        tile_sample_min_width: Optional[int] = None,
        tile_sample_stride_height: Optional[int] = None,
        tile_sample_stride_width: Optional[int] = None,
    ) -> None:
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.

        A message announcing the switch is printed to stdout. Any argument that is `None` (or otherwise falsy, e.g.
        `0`) leaves the corresponding current setting unchanged.

        Args:
            tile_sample_min_height (`int`, *optional*):
                The minimum height required for a sample to be separated into tiles across the height dimension.
            tile_sample_min_width (`int`, *optional*):
                The minimum width required for a sample to be separated into tiles across the width dimension.
            tile_sample_stride_height (`int`, *optional*):
                The stride between two consecutive vertical tiles. Must be an integer because it is used as a
                `range()` step. The overlap `tile_sample_min_height - tile_sample_stride_height` is blended to
                avoid tiling artifacts across the height dimension.
            tile_sample_stride_width (`int`, *optional*):
                The stride between two consecutive horizontal tiles. Must be an integer because it is used as a
                `range()` step. The overlap `tile_sample_min_width - tile_sample_stride_width` is blended to
                avoid tiling artifacts across the width dimension.
        """
        print("Enabling tiled VAE decoding. This will split the input tensor into tiles to compute decoding in several steps.")
        self.use_tiling = True
        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width

    def disable_tiling(self) -> None:
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step. The configured tile sizes and strides are left untouched.
        """
        self.use_tiling = False

    def enable_slicing(self) -> None:
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        The same flag also makes `encode` process one batch element at a time.
        """
        self.use_slicing = True

    def disable_slicing(self) -> None:
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Encode a 5-D video tensor and return the output of `quant_conv` (not yet wrapped in a distribution).

        If tiling is enabled and the input is wider or taller than the minimum tile size, the work is delegated to
        `tiled_encode`. Otherwise the first frame is encoded on its own and the remaining frames are encoded in
        chunks of four, with the per-chunk outputs concatenated along dim 2.

        Args:
            x (`torch.Tensor`): Input with five dimensions; dim 2 is treated as the frame axis and the last two
                as height and width.

        Returns:
            `torch.Tensor`: The encoded tensor.
        """
        _, _, num_frame, height, width = x.shape

        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
            return self.tiled_encode(x)

        # `clear_cache` and the `_enc_feat_map` / `_enc_conv_idx` attributes are not defined in this class.
        # NOTE(review): presumably `clear_cache` (inherited) resets the causal-conv feature cache — confirm.
        self.clear_cache()
        # One pass for frame 0 plus one pass per full group of 4 following frames. Frames left over when
        # `(num_frame - 1)` is not a multiple of 4 are never passed to the encoder.
        iter_ = 1 + (num_frame - 1) // 4
        for i in range(iter_):
            # The conv index is reset before every encoder call, while the feature cache is kept across calls.
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
                out = torch.cat([out, out_], 2)

        enc = self.quant_conv(out)
        self.clear_cache()
        return enc

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        r"""
        Encode a batch of videos into latents.

        When slicing is enabled and the batch holds more than one element, each element is encoded separately and
        the results are concatenated along the batch dimension.

        Args:
            x (`torch.Tensor`): Input batch of videos.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded videos. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x)
        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        r"""
        Decode a 5-D latent tensor into a video tensor.

        If tiling is enabled and the latent is wider or taller than the minimum tile size (converted to latent
        space), the work is delegated to `tiled_decode`. Otherwise `post_quant_conv` is applied once and the decoder
        is run on one latent frame at a time, with the outputs concatenated along dim 2 and clamped to `[-1, 1]`.

        Args:
            z (`torch.Tensor`): Latent with five dimensions; dim 2 is treated as the frame axis.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a `DecoderOutput` instead of a one-element tuple.

        Returns:
            `DecoderOutput` or `tuple`: The decoded sample.
        """
        _, _, num_frame, height, width = z.shape
        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

        # NOTE(review): the tiled path returns directly, so its output does not go through the clamp below.
        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
            return self.tiled_decode(z, return_dict=return_dict)

        # `clear_cache`, `_feat_map` and `_conv_idx` are not defined in this class.
        # NOTE(review): presumably provided/reset by the inherited `clear_cache` — confirm.
        self.clear_cache()
        x = self.post_quant_conv(z)
        for i in range(num_frame):
            # The conv index is reset before every decoder call, while the feature cache is kept across calls.
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)

        out = torch.clamp(out, min=-1.0, max=1.0)
        self.clear_cache()
        if not return_dict:
            return (out,)

        return DecoderOutput(sample=out)

    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        r"""
        Decode a batch of video latents.

        When slicing is enabled and the batch holds more than one element, each element is decoded separately and
        the results are concatenated along the batch dimension.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        # `_decode` is always called with its default `return_dict=True` here, hence the `.sample` access.
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        r"""
        Cross-fade the bottom rows of `a` into the top rows of `b` along dim -2.

        Row `y` of `b` becomes a linear mix of row `-blend_extent + y` of `a` (weight `1 - y / blend_extent`) and
        itself (weight `y / blend_extent`). `blend_extent` is first capped to the height of both tensors.

        `b` is modified in place and also returned. Both tensors are indexed as 5-D.
        """
        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        r"""
        Cross-fade the rightmost columns of `a` into the leftmost columns of `b` along dim -1.

        Column `x` of `b` becomes a linear mix of column `-blend_extent + x` of `a` (weight `1 - x / blend_extent`)
        and itself (weight `x / blend_extent`). `blend_extent` is first capped to the width of both tensors.

        `b` is modified in place and also returned. Both tensors are indexed as 5-D.
        """
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b

    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        r"""Encode a batch of videos using a tiled encoder.

        The input is cut into overlapping spatial tiles; each tile is encoded over time with the same frame
        chunking as `_encode` (frame 0 alone, then groups of four), passed through `quant_conv`, and the tiles are
        blended over their overlap and stitched back together.

        Args:
            x (`torch.Tensor`): Input batch of videos.

        Returns:
            `torch.Tensor`:
                The latent representation of the encoded videos.
        """
        _, _, num_frames, height, width = x.shape
        latent_height = height // self.spatial_compression_ratio
        latent_width = width // self.spatial_compression_ratio

        # Tile geometry converted from sample space to latent space.
        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        # Encoded tiles live in latent space, so the overlap to blend is measured in latent units.
        blend_height = tile_latent_min_height - tile_latent_stride_height
        blend_width = tile_latent_min_width - tile_latent_stride_width

        # Split x into overlapping tiles and encode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, self.tile_sample_stride_height):
            row = []
            for j in range(0, width, self.tile_sample_stride_width):
                # Each tile starts from a cleared cache.
                self.clear_cache()
                # Per-chunk encodings of this tile, concatenated along the frame axis afterwards.
                time = []
                frame_range = 1 + (num_frames - 1) // 4
                for k in range(frame_range):
                    self._enc_conv_idx = [0]
                    if k == 0:
                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
                    else:
                        tile = x[
                            :,
                            :,
                            1 + 4 * (k - 1) : 1 + 4 * k,
                            i : i + self.tile_sample_min_height,
                            j : j + self.tile_sample_min_width,
                        ]
                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
                    tile = self.quant_conv(tile)
                    time.append(tile)
                row.append(torch.cat(time, dim=2))
            rows.append(row)
        self.clear_cache()

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                # Keep only one stride's worth of each tile; the remainder is covered by the next tile.
                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))

        # Crop the stitched result to the expected latent size.
        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
        return enc

    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        r"""
        Decode a batch of video latents using a tiled decoder.

        The latent is cut into overlapping spatial tiles; each tile is decoded one latent frame at a time, and the
        decoded tiles are blended over their overlap and stitched back together. Unlike the non-tiled path in
        `_decode`, no clamping is applied to the result here.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        _, _, num_frames, height, width = z.shape
        sample_height = height * self.spatial_compression_ratio
        sample_width = width * self.spatial_compression_ratio

        # Tile geometry converted from sample space to latent space (the loops below walk the latent).
        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        # Decoded tiles live in sample space, so the overlap to blend is measured in sample units.
        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, tile_latent_stride_height):
            row = []
            for j in range(0, width, tile_latent_stride_width):
                # Each tile starts from a cleared cache.
                self.clear_cache()
                # Per-frame decodings of this tile, concatenated along the frame axis afterwards.
                time = []
                for k in range(num_frames):
                    self._conv_idx = [0]
                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
                    tile = self.post_quant_conv(tile)
                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
                    time.append(decoded)
                row.append(torch.cat(time, dim=2))
            rows.append(row)
        self.clear_cache()

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                # Keep only one stride's worth of each tile; the remainder is covered by the next tile.
                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))

        # Crop the stitched result to the expected sample size.
        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)

## 多尺度压缩块

In [4]:
from torch import nn
import einops

class WanPatchEmbedForCleanLatents(nn.Module):
    """Patch-embedding convolutions for clean (already denoised) latents at three compression levels.

    Each projection is a non-overlapping ``nn.Conv3d`` (stride equals kernel size) that maps
    ``in_chans`` channels to ``inner_dim`` channels:

    * ``proj``    -- kernel ``(1, 2, 2)``: the 1x level, intended for the most recent key frames.
    * ``proj_2x`` -- kernel ``(2, 4, 4)``: the 2x level, intended for medium-range context.
    * ``proj_4x`` -- kernel ``(4, 8, 8)``: the 4x level, intended for long-range context.

    Args:
        inner_dim: Number of output channels of every projection.
        in_chans: Number of input (latent) channels.
    """

    def __init__(self, inner_dim, in_chans=16):
        super().__init__()

        # 1x compression: most recent key frames (start frame, end frame, ...).
        self.proj = nn.Conv3d(in_chans, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        # 2x compression: medium-range temporal context (recently generated frames).
        self.proj_2x = nn.Conv3d(in_chans, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
        # 4x compression: long-range, global context (older history frames).
        self.proj_4x = nn.Conv3d(in_chans, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

    @staticmethod
    def _expand_kernel(weight, factor):
        """Repeat every kernel tap ``factor`` times along the three kernel axes and renormalise.

        Dividing by ``factor ** 3`` makes the enlarged kernel compute the same value as
        average-pooling the input by ``factor`` in each axis and then applying ``weight``.
        """
        expanded = weight
        for axis in (2, 3, 4):
            # Consecutive repetition: a a b b ..., matching the einops pattern "(t tk) (h hk) (w wk)".
            expanded = expanded.repeat_interleave(factor, dim=axis)
        return expanded / float(factor ** 3)

    @torch.no_grad()
    def initialize_weight_from_another_conv3d(self, another_layer):
        """Initialise all three projections from an existing 1x patch-embedding convolution.

        ``proj`` receives a copy of the source weight; ``proj_2x`` and ``proj_4x`` receive the
        source kernel repeated 2x / 4x along time, height and width and divided by 8 / 64, so
        that all three levels start out producing features in the same space. The bias is
        copied unchanged to every level.

        Args:
            another_layer: Source convolution exposing ``weight`` and ``bias``. Its weight must
                have the shape of ``proj.weight``. A source without a bias (``bias is None``)
                is treated as having an all-zero bias.

        Raises:
            RuntimeError: Propagated from ``load_state_dict`` when the source shapes do not
                match this module.
        """
        weight = another_layer.weight.detach()
        source_bias = another_layer.bias
        if source_bias is None:
            # A bias-free source contributes no offset; our own convs still need a bias tensor.
            bias = weight.new_zeros(weight.shape[0])
        else:
            bias = source_bias.detach()

        # load_state_dict copies values into the existing parameters, so sharing `bias`
        # between the three entries does not alias the resulting parameters.
        sd = {
            'proj.weight': weight,
            'proj.bias': bias,
            'proj_2x.weight': self._expand_kernel(weight, 2),
            'proj_2x.bias': bias,
            'proj_4x.weight': self._expand_kernel(weight, 4),
            'proj_4x.bias': bias,
        }

        self.load_state_dict(sd)
        return

### 为什么使用压缩后需要重复对应的参数？

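如果直接用随机权重初始化 2x/4x 压缩卷积，不同压缩比例会产生完全不同的特征表示，破坏语义一致性。把 1x 卷积核在时间、高、宽三个维度上各重复 k 次并除以 k³（2x 除以 8，4x 除以 64）之后，新卷积在初始状态下等价于“先对输入做 k×k×k 平均池化，再应用原始的 1x 卷积”，因此三种尺度的输出一开始就落在同一个特征空间里。

推理代码中应该没有直接调用这个初始化方法，它可能只在训练（或加载权重）时使用。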


In [5]:

from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.models.transformers.transformer_wan import WanRotaryPosEmbed
from diffusers.models.transformers.transformer_wan import WanTimeTextImageEmbedding
from torch import nn
from typing import Dict, Tuple, Optional, Union

class WanTransformer3DModelPacked(WanTransformer3DModel):
    """`WanTransformer3DModel` with an optional multi-scale embedding for clean latents.

    The constructor forwards every stock argument to the parent unchanged and can additionally
    attach a `WanPatchEmbedForCleanLatents` module as `clean_embedding`. The forward pass
    accepts the extra packed-latent arguments but does not use them yet.
    """

    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 40,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,

        has_clean_embedding: bool=False
    ) -> None:
        """Build the parent transformer, then optionally install the clean-latent embedding.

        All parameters except `has_clean_embedding` are handed to the parent constructor
        positionally, in the order they are declared here.

        Args:
            has_clean_embedding: When true, `install_clean_embedding` is called at the end of
                construction.
        """
        # Patch/position embedding, condition embeddings, transformer blocks and the output
        # norm/projection are all created by the parent constructor.
        super().__init__(
            patch_size,
            num_attention_heads,
            attention_head_dim,
            in_channels,
            out_channels,
            text_dim,
            freq_dim,
            ffn_dim,
            num_layers,
            cross_attn_norm,
            qk_norm,
            eps,
            image_dim,
            added_kv_proj_dim,
            rope_max_seq_len,
        )

        # Remembered so that `install_clean_embedding` can size the embedding later on.
        self.inner_dim = num_attention_heads * attention_head_dim
        self.in_channels = in_channels

        # The multi-scale embedding stays absent unless explicitly requested.
        self.clean_embedding = None
        if has_clean_embedding:
            self.install_clean_embedding()

    def install_clean_embedding(self):
        """Attach the multi-scale clean-latent embedding and flag it in the config.

        Creates a `WanPatchEmbedForCleanLatents` mapping `in_channels` latent channels to
        `inner_dim` features and sets the `has_clean_embedding` config entry to `True`.
        """
        self.clean_embedding = WanPatchEmbedForCleanLatents(self.inner_dim, self.in_channels)
        self.config['has_clean_embedding'] = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        latent_indices=None,
        clean_latents=None,
        clean_latent_indices=None,
        clean_latents_2x=None,
        clean_latent_2x_indices=None,
        clean_latents_4x=None,
        clean_latent_4x_indices=None,
        return_dict=True,
        attention_kwargs=None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Run the parent transformer's forward pass.

        TODO: feed the clean latents through the multi-scale embedding. For now
        `latent_indices` and every `clean_latent*` argument are accepted and ignored; only the
        stock arguments reach the parent.

        Returns:
            Whatever the parent `forward` returns for the given `return_dict`.
        """
        parent_args = (
            hidden_states,
            timestep,
            encoder_hidden_states,
            encoder_hidden_states_image,
            return_dict,
            attention_kwargs,
        )
        return super().forward(*parent_args)

In [6]:
from diffusers.utils import is_torch_xla_available
# Record whether torch_xla can be used; the XLA model module is only imported when it can.
if not is_torch_xla_available():
    XLA_AVAILABLE = False
else:
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True

XLA_AVAILABLE

False

In [7]:
# 完整的WanFramePackPipeline实现
from typing import Any, Callable, List, Union, Optional, Dict
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
import torch.nn.functional as F


class WanFramePackPipelineComplete(WanPipeline):
    """
    Wan text-to-video pipeline extended with FramePack-style sectioned sampling.

    A video is generated section by section. Every section is denoised while
    being conditioned on a start latent plus a fixed-size buffer of previously
    generated latents that is split into 1x / 2x / 4x context groups and handed
    to the packed transformer through the ``clean_latents*`` keyword arguments.

    When ``high_vram`` is False the pipeline relies on the ``diffusers_helper``
    memory utilities to move models on and off the GPU around each stage.
    """

    def __init__(self, tokenizer, text_encoder, transformer, vae, scheduler, high_vram=False, latent_window_size=9, cfg=1.0, rs=0.0, target_device=None):
        """
        Args:
            tokenizer, text_encoder, transformer, vae, scheduler: pipeline
                components, forwarded unchanged to ``WanPipeline.__init__``.
            high_vram: when False, ``DynamicSwapInstaller`` is installed on the
                transformer and text encoder and models are swapped on demand.
            latent_window_size: number of latent frames generated per section.
            cfg, rs: stored on the instance; not read anywhere in this class.
            target_device: device reported by ``_execution_device`` when no
                component is on CUDA. Defaults to CUDA if available, else CPU.
        """
        # These attributes are assigned before super().__init__() exactly as in
        # the original implementation, so `_execution_device` is usable as soon
        # as the base constructor has registered the components.
        self.high_vram = high_vram
        self.latent_window_size = latent_window_size
        self.cfg = cfg
        self.rs = rs
        self.total_latent_sections = None
        self.prompt_embeds = None
        self.negative_prompt_embeds = None
        # Identifies which text inputs produced the cached embeddings above, so
        # that a later call with a different prompt does not reuse stale ones.
        self._prompt_cache_key = None
        # Same default as `__call__`; guarantees the attribute exists when
        # `sample_framepack_section` is used without going through `__call__`.
        self.gpu_memory_preservation = 11.0

        self.target_device = target_device
        if self.target_device is None:
            self.target_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        super().__init__(tokenizer, text_encoder, transformer, vae, scheduler)

        if not self.high_vram:
            from diffusers_helper.memory import DynamicSwapInstaller
            DynamicSwapInstaller.install_model(self.transformer, device=self._execution_device)
            DynamicSwapInstaller.install_model(self.text_encoder, device=self._execution_device)

    @staticmethod
    def _make_prompt_cache_key(prompt, negative_prompt, do_classifier_free_guidance, num_videos_per_prompt, max_sequence_length):
        """Build a comparable key describing the text inputs of one encoding request."""
        def _freeze(text):
            # Lists are turned into tuples so keys compare by value.
            return tuple(text) if isinstance(text, (list, tuple)) else text

        return (
            _freeze(prompt),
            _freeze(negative_prompt),
            bool(do_classifier_free_guidance),
            num_videos_per_prompt,
            max_sequence_length,
        )

    def encode_prompt(self,
                      prompt: Union[str, List[str]],
                      negative_prompt: Optional[Union[str, List[str]]] = None,
                      do_classifier_free_guidance: bool = True,
                      num_videos_per_prompt: int = 1,
                      prompt_embeds: Optional[torch.Tensor] = None,
                      negative_prompt_embeds: Optional[torch.Tensor] = None,
                      max_sequence_length: int = 226,
                      device: Optional[torch.device] = None,
                      dtype: Optional[torch.dtype] = None):
        """
        Encode the prompts via ``WanPipeline.encode_prompt`` and cache the result.

        The returned embeddings are stored on ``self.prompt_embeds`` and
        ``self.negative_prompt_embeds``. When they were computed from text (no
        embeddings were passed in), a cache key is recorded so ``__call__`` can
        reuse them for an identical request.

        Args:
            device: device used for encoding; defaults to ``_execution_device``.
            Remaining arguments are forwarded to ``WanPipeline.encode_prompt``.

        Returns:
            Tuple ``(prompt_embeds, negative_prompt_embeds)``.
        """
        # The memory helpers below need a concrete device; the original passed
        # `None` through when the caller omitted it.
        if device is None:
            device = self._execution_device

        computed_from_text = prompt_embeds is None and negative_prompt_embeds is None

        if not self.high_vram:
            unload_complete_models()
            fake_diffusers_current_device(self.text_encoder, device)
            load_model_as_complete(self.text_encoder, target_device=device)

        prompt_embeds, negative_prompt_embeds = super().encode_prompt(
            prompt, negative_prompt, do_classifier_free_guidance,
            num_videos_per_prompt, prompt_embeds, negative_prompt_embeds,
            max_sequence_length, device=device, dtype=dtype)

        # Cache the encoding result.
        self.prompt_embeds = prompt_embeds
        self.negative_prompt_embeds = negative_prompt_embeds
        if computed_from_text:
            self._prompt_cache_key = self._make_prompt_cache_key(
                prompt, negative_prompt, do_classifier_free_guidance,
                num_videos_per_prompt, max_sequence_length)
        else:
            # Caller-provided embeddings are not tied to any text key.
            self._prompt_cache_key = None

        return prompt_embeds, negative_prompt_embeds

    def sample_framepack_section(self,
                                latents: torch.Tensor,
                                clean_latents: torch.Tensor,
                                clean_latent_indices: torch.Tensor,
                                clean_latents_2x: torch.Tensor,
                                clean_latent_2x_indices: torch.Tensor,
                                clean_latents_4x: torch.Tensor,
                                clean_latent_4x_indices: torch.Tensor,
                                prompt_embeds: torch.Tensor,
                                negative_prompt_embeds: Optional[torch.Tensor],
                                num_inference_steps: int,
                                guidance_scale: float,
                                generator: Optional[torch.Generator] = None,
                                attention_kwargs: Optional[Dict[str, Any]] = None,
                                callback_on_step_end: Optional[Callable] = None,
                                callback_on_step_end_tensor_inputs: List[str] = ["latents"]):
        """
        Run the full denoising loop for a single FramePack section.

        Args:
            latents: initial noise for the section.
            clean_latents, clean_latents_2x, clean_latents_4x: context latents
                passed to the transformer, with their matching ``*_indices``.
            prompt_embeds / negative_prompt_embeds: text conditioning; the
                negative embeddings are only used when classifier-free guidance
                is active (``self.do_classifier_free_guidance``).
            num_inference_steps: number of scheduler steps.
            guidance_scale: weight applied to ``cond - uncond``.
            generator: accepted for interface symmetry; not used here.
            attention_kwargs: forwarded to the transformer.
            callback_on_step_end: optional ``fn(pipe, step, timestep, kwargs)``
                that may return replacements for ``latents``, ``prompt_embeds``
                and ``negative_prompt_embeds``.
            callback_on_step_end_tensor_inputs: names of local variables handed
                to the callback.

        Returns:
            The denoised latents of this section.
        """
        device = self._execution_device

        # The scheduler is reset for every section.
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        transformer_dtype = self.transformer.dtype
        latents = latents.to(device=device, dtype=transformer_dtype)

        # Extra conditioning consumed by the packed transformer's forward().
        framepack_kwargs = {
            'clean_latents': clean_latents.to(device=device, dtype=transformer_dtype),
            'clean_latent_indices': clean_latent_indices.to(device=device),
            'clean_latents_2x': clean_latents_2x.to(device=device, dtype=transformer_dtype),
            'clean_latent_2x_indices': clean_latent_2x_indices.to(device=device),
            'clean_latents_4x': clean_latents_4x.to(device=device, dtype=transformer_dtype),
            'clean_latent_4x_indices': clean_latent_4x_indices.to(device=device),
        }

        # Denoising loop.
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            self._current_timestep = t
            latent_model_input = latents.to(transformer_dtype).to(device)
            timestep = t.expand(latents.shape[0])

            if not self.high_vram:
                unload_complete_models()
                move_model_to_device_with_memory_preservation(self.transformer, target_device=device, preserved_memory_gb=self.gpu_memory_preservation)

            # Conditional prediction (with FramePack context).
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds.to(device),
                attention_kwargs=attention_kwargs,
                return_dict=False,
                **framepack_kwargs
            )[0]

            # Classifier-free guidance.
            if self.do_classifier_free_guidance:
                noise_uncond = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=negative_prompt_embeds.to(device) if negative_prompt_embeds is not None else None,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                    **framepack_kwargs
                )[0]
                noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

            # Scheduler step: x_t -> x_{t-1}.
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            # Step-end callback; it may replace latents / embeddings.
            if callback_on_step_end is not None:
                callback_kwargs = {}
                # Kept as a plain loop: locals() must be evaluated in this
                # method's frame, not inside a comprehension scope.
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            if XLA_AVAILABLE:
                xm.mark_step()

        self._current_timestep = None
        return latents

    @property
    def _execution_device(self):
        """
        Device on which the pipeline should run.

        In high-VRAM mode this is always ``target_device``. Otherwise the device
        of the first ``torch.nn.Module`` component currently on CUDA is returned,
        falling back to ``target_device`` when every component is off-GPU (models
        may live on the CPU while execution still targets the GPU).
        """
        if self.high_vram:
            return self.target_device

        for component in self.components.values():
            if isinstance(component, torch.nn.Module) and hasattr(component, 'device'):
                if component.device.type == 'cuda':
                    return component.device

        return self.target_device

    def prepare_history_latents(self,
                                batch_size: int,
                                num_channels_latents: int,
                                height: int,
                                width: int,
                                dtype: torch.dtype,
                                device: torch.device,
                                num_frames: int = None,
                                num_latent_frames: int = None
    ):
        """
        Allocate a zero-filled history buffer of latents.

        Args:
            batch_size, num_channels_latents: leading dimensions of the buffer.
            height, width: pixel-space size; divided by
                ``vae_scale_factor_spatial`` to get the latent size.
            dtype, device: tensor dtype and device.
            num_frames: pixel-space frame count; only used when
                ``num_latent_frames`` is None.
            num_latent_frames: latent frame count; takes precedence.

        Returns:
            Zeros of shape ``(batch, channels, latent_frames, h, w)``.

        Raises:
            ValueError: if neither ``num_frames`` nor ``num_latent_frames`` is given.
        """
        if num_frames is None and num_latent_frames is None:
            raise ValueError("num_frames or num_latent_frames must be provided")
        # `is None` rather than truthiness so an explicit value is always honoured.
        if num_latent_frames is None:
            num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1

        shape = (
            batch_size,
            num_channels_latents,
            num_latent_frames,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
        )
        return torch.zeros(size=shape, dtype=dtype, device=device)

    def __call__(self,
                prompt: Union[str, List[str]] = None,
                negative_prompt: Union[str, List[str]] = None,
                height: int = 480,
                width: int = 832,
                num_inference_steps: int = 50,
                guidance_scale: float = 5.0,
                num_videos_per_prompt: Optional[int] = 1,
                generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
                latents: Optional[torch.Tensor] = None,
                prompt_embeds: Optional[torch.Tensor] = None,
                negative_prompt_embeds: Optional[torch.Tensor] = None,
                output_type: Optional[str] = "np",
                return_dict: bool = True,
                attention_kwargs: Optional[Dict[str, Any]] = None,
                callback_on_step_end: Optional[
                    Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
                ] = None,
                callback_on_step_end_tensor_inputs: List[str] = ["latents"],
                max_sequence_length: int = 512,
                seed: int = 42,
                use_teacache: bool = False,
                total_second_length: float = 5.0,
                fps: float = 30.0,
                gpu_memory_preservation: float = 11.0):
        """
        Generate a video section by section using FramePack conditioning.

        Args:
            prompt, negative_prompt: text conditioning.
            height, width: requested size; snapped with ``find_nearest_bucket``.
            num_inference_steps: scheduler steps per section.
            guidance_scale: classifier-free guidance weight.
            num_videos_per_prompt: forwarded to ``encode_prompt``.
            generator: RNG for noise; a CPU generator seeded with ``seed`` is
                created when omitted.
            latents: optional start latent handed to ``prepare_latents``.
            prompt_embeds, negative_prompt_embeds: optional precomputed
                embeddings; when given they are used instead of the cache.
            output_type: ``"latent"`` returns latents, anything else is passed
                to the video processor after VAE decoding.
            return_dict: return ``WanPipelineOutput`` instead of a tuple.
            attention_kwargs: forwarded to the transformer.
            callback_on_step_end, callback_on_step_end_tensor_inputs: per-step
                callback and the local names it receives.
            max_sequence_length: forwarded to ``encode_prompt``.
            seed: seed for the default generator.
            use_teacache: currently only prints a "not implemented" notice.
            total_second_length, fps: together determine the section count.
            gpu_memory_preservation: value passed as ``preserved_memory_gb`` to
                the low-VRAM memory helpers.

        Returns:
            ``WanPipelineOutput(frames=video)`` or ``(video,)``.
        """
        device = self._execution_device

        # 1. Validate inputs.
        self.check_inputs(
            prompt, negative_prompt, height, width,
            prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs
        )

        height, width = find_nearest_bucket(height, width, resolution=640)

        # 2. Per-call state.
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False
        self.gpu_memory_preservation = gpu_memory_preservation

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 3. Text encoding. Cached embeddings are reused only when they were
        # produced from the very same text request; previously the first
        # prompt's embeddings were reused for every later call and explicit
        # `prompt_embeds` were ignored once the cache was populated.
        transformer_dtype = self.transformer.dtype
        cache_key = self._make_prompt_cache_key(
            prompt, negative_prompt, self.do_classifier_free_guidance,
            num_videos_per_prompt, max_sequence_length)
        caller_supplied_embeds = prompt_embeds is not None or negative_prompt_embeds is not None
        cache_hit = (
            not caller_supplied_embeds
            and self.prompt_embeds is not None
            and getattr(self, "_prompt_cache_key", None) == cache_key
        )

        if cache_hit:
            prompt_embeds = self.prompt_embeds.to(transformer_dtype)
            negative_prompt_embeds = self.negative_prompt_embeds
            if negative_prompt_embeds is not None:
                negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
        else:
            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
                prompt=prompt,
                negative_prompt=negative_prompt,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                num_videos_per_prompt=num_videos_per_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=transformer_dtype
            )

        # 4. Prime the scheduler (each section resets it again before sampling).
        self.scheduler.set_timesteps(num_inference_steps, device=device)

        # 5. Start latent.
        num_channels_latents = self.transformer.config.in_channels
        if generator is None:
            generator = torch.Generator(device=cpu).manual_seed(seed)

        if not self.high_vram:
            load_model_as_complete(self.vae, target_device=device)

        # prepare_latents scales height/width down to latent resolution and is
        # given the caller's `latents` (if any) for the start frame.
        start_latent = self.prepare_latents(
            batch_size=1,
            num_channels_latents=num_channels_latents,
            height=height,
            width=width,
            num_frames=1,
            dtype=self.vae.dtype,
            device=cpu,
            generator=generator,
            latents=latents
        )

        # 6. FramePack bookkeeping. With the default window of 9 latent frames
        # each section spans 9 * 4 - 3 = 33 pixel frames.
        num_frames = self.latent_window_size * 4 - 3
        total_latent_sections = int(max(round((total_second_length * fps) / (self.latent_window_size * 4)), 1))
        self.total_latent_sections = total_latent_sections

        # History buffer holding the 1 + 2 + 16 context latents that are split
        # into the 1x / 2x / 4x groups below.
        history_latents = self.prepare_history_latents(
            batch_size=1,
            num_channels_latents=num_channels_latents,
            num_latent_frames=1 + 2 + 16,
            height=height,
            width=width,
            dtype=self.vae.dtype,
            device=cpu
        )

        history_pixels = None
        total_generated_latent_frames = 0

        # Padding schedule. In theory this is always the reversed range, but
        # for more than 4 sections repeating the middle value reportedly works
        # better than extending the sequence. To compare, use
        # `list(reversed(range(total_latent_sections)))` unconditionally.
        if total_latent_sections <= 4:
            latent_paddings = list(reversed(range(total_latent_sections)))
        else:
            latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]

        # 7. Section loop.
        for section_idx, latent_padding in enumerate(latent_paddings):
            is_last_section = latent_padding == 0
            latent_padding_size = latent_padding * self.latent_window_size

            print(f'Section {section_idx+1}/{len(latent_paddings)}: '
                  f'latent_padding_size={latent_padding_size}, is_last_section={is_last_section}')

            # Positional indices for: start frame, blank padding, the window
            # being generated, and the 1x / 2x / 4x history groups.
            split_sizes = [1, latent_padding_size, self.latent_window_size, 1, 2, 16]
            indices = torch.arange(0, sum(split_sizes)).unsqueeze(0)
            (clean_latent_indices_pre, blank_indices, latent_indices,
             clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices) = indices.split(split_sizes, dim=1)

            clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)

            # Matching context latents taken from the front of the history buffer.
            clean_latents_pre = start_latent.to(history_latents.device, history_latents.dtype)
            clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
            clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)

            # Fresh noise for this section.
            current_latents = self.prepare_latents(
                batch_size=1,
                num_channels_latents=num_channels_latents,
                height=height,
                width=width,
                num_frames=num_frames,
                dtype=self.vae.dtype,
                device=cpu,
                generator=generator,
                latents=None
            )

            # Bring the transformer onto the GPU.
            if not self.high_vram:
                unload_complete_models()
                move_model_to_device_with_memory_preservation(
                    self.transformer, target_device=device,
                    preserved_memory_gb=gpu_memory_preservation
                )

            if use_teacache:
                print("TeaCache not implemented yet")

            generated_latents = self.sample_framepack_section(
                latents=current_latents,
                clean_latents=clean_latents,
                clean_latent_indices=clean_latent_indices,
                clean_latents_2x=clean_latents_2x,
                clean_latent_2x_indices=clean_latent_2x_indices,
                clean_latents_4x=clean_latents_4x,
                clean_latent_4x_indices=clean_latent_4x_indices,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
                attention_kwargs=attention_kwargs,
                callback_on_step_end=callback_on_step_end,
                callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs
            )

            # The final section gets the start latent prepended.
            if is_last_section:
                generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)

            # New latents are prepended to the history buffer.
            total_generated_latent_frames += int(generated_latents.shape[2])
            history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)

            # Swap the transformer out and the VAE in.
            if not self.high_vram:
                offload_model_from_device_for_memory_preservation(
                    self.transformer, target_device=device, preserved_memory_gb=self.gpu_memory_preservation
                )
                load_model_as_complete(self.vae, target_device=device)

            # VAE decoding (skipped when latents are requested).
            # NOTE(review): the latents are decoded as-is; confirm whether the
            # VAE in use expects any latent rescaling before decode.
            if output_type != "latent":
                real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]

                if history_pixels is None:
                    # First section: decode everything generated so far.
                    history_pixels = self.vae.decode(
                        real_history_latents.to(self.vae.dtype).to(self.vae.device)
                    ).sample.cpu()
                else:
                    # Later sections: decode only the newest frames and blend
                    # them into the existing pixels over the overlap.
                    section_latent_frames = (self.latent_window_size * 2 + 1) if is_last_section else (self.latent_window_size * 2)
                    overlapped_frames = self.latent_window_size * 4 - 3

                    current_pixels = self.vae.decode(
                        real_history_latents[:, :, :section_latent_frames].to(self.vae.dtype).to(self.vae.device)
                    ).sample.cpu()

                    history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)

            if not self.high_vram:
                unload_complete_models()

            print(f'Section {section_idx+1} completed. Generated frames: {total_generated_latent_frames}')

            if is_last_section:
                break

        # 8. Final output.
        if output_type == "latent":
            video = history_latents[:, :, :total_generated_latent_frames, :, :]
        else:
            video = self.video_processor.postprocess_video(history_pixels, output_type=output_type)

        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return WanPipelineOutput(frames=video)


In [8]:
# Load the packed (FramePack-enabled) transformer weights in bfloat16.
transformer = WanTransformer3DModelPacked.from_pretrained(
    model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    has_clean_x_embedder=True,
)

Fetching 2 files: 100%|██████████| 2/2 [00:00<00:00, 11169.92it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.18s/it]


In [9]:
# Pick the memory mode from the VRAM currently free on `gpu`.
free_mem_gb = get_cuda_free_memory_gb(gpu)
high_vram = free_mem_gb > 60

print(f'Free VRAM {free_mem_gb} GB')
print(f'High-VRAM Mode: {high_vram}')

# Load the remaining pipeline components; the heavy models start on the CPU.
print("Loading Wan2.1-1.3B model components...")

tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder='tokenizer')
text_encoder = UMT5EncoderModel.from_pretrained(
    model_id,
    subfolder='text_encoder',
    torch_dtype=torch.bfloat16,
).cpu()

vae = AutoencoderKLWan.from_pretrained(
    model_id,
    subfolder="vae",
    torch_dtype=torch.float16,
).cpu()

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")

# The transformer was already loaded as a WanTransformer3DModelPacked.
print("Using WanTransformer3DModelPacked with FramePack features...")

# Inference only: switch every model to eval mode.
for module in (vae, text_encoder, transformer):
    module.eval()

# Low-VRAM mode: decode in slices / tiles.
if not high_vram:
    vae.enable_slicing()
    vae.enable_tiling()

# Flag set on the transformer to request fp32 output during inference.
transformer.high_quality_fp32_output_for_inference = True
print('transformer.high_quality_fp32_output_for_inference = True')

# Pin the working dtype of each model.
for module, module_dtype in (
    (transformer, torch.bfloat16),
    (vae, torch.float16),
    (text_encoder, torch.bfloat16),
):
    module.to(dtype=module_dtype)

# Freeze all weights. The transformer call stays last so that it remains the
# cell's final expression and the notebook still echoes the model repr.
# DynamicSwapInstaller is not installed here; the pipeline constructor does it
# when high_vram is False.
for module in (vae, text_encoder):
    module.requires_grad_(False)
transformer.requires_grad_(False)


Free VRAM 10.7939453125 GB
High-VRAM Mode: False
Loading Wan2.1-1.3B model components...


Downloading shards: 100%|██████████| 5/5 [00:00<00:00, 14065.41it/s]


Loading checkpoint shards: 100%|██████████| 5/5 [01:22<00:00, 16.41s/it]


Using WanTransformer3DModelPacked with FramePack features...
Enabling tiled VAE decoding. This will split the input tensor into tiles to compute decoding in several steps.
transformer.high_quality_fp32_output_for_inference = True


WanTransformer3DModelPacked(
  (rope): WanRotaryPosEmbed()
  (patch_embedding): Conv3d(16, 1536, kernel_size=(1, 2, 2), stride=(1, 2, 2))
  (condition_embedder): WanTimeTextImageEmbedding(
    (timesteps_proj): Timesteps()
    (time_embedder): TimestepEmbedding(
      (linear_1): Linear(in_features=256, out_features=1536, bias=True)
      (act): SiLU()
      (linear_2): Linear(in_features=1536, out_features=1536, bias=True)
    )
    (act_fn): SiLU()
    (time_proj): Linear(in_features=1536, out_features=9216, bias=True)
    (text_embedder): PixArtAlphaTextProjection(
      (linear_1): Linear(in_features=4096, out_features=1536, bias=True)
      (act_1): GELU(approximate='tanh')
      (linear_2): Linear(in_features=1536, out_features=1536, bias=True)
    )
  )
  (blocks): ModuleList(
    (0-29): 30 x WanTransformerBlock(
      (norm1): FP32LayerNorm((1536,), eps=1e-06, elementwise_affine=False)
      (attn1): Attention(
        (norm_q): RMSNorm()
        (norm_k): RMSNorm()
        (t

In [10]:
# Assemble the FramePack pipeline from the components loaded above.
pipeline_components = {
    "tokenizer": tokenizer,
    "text_encoder": text_encoder,
    "transformer": transformer,
    "vae": vae,
    "scheduler": scheduler,
}
pipeline = WanFramePackPipelineComplete(
    **pipeline_components,
    high_vram=high_vram,
    target_device=gpu,
)

# 暂时使用标准pipeline进行测试
# pipeline = WanPipeline.from_pretrained(
#     model_id,
#     transformer=transformer,  # 使用我们的packed transformer
#     torch_dtype=torch.bfloat16
# )

### Test Start

In [11]:
from diffusers.utils import export_to_video

# Text conditioning for the smoke test.
prompt = "A cat walks on the grass, realistic"
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

# Generate a 2-second clip and keep the first (only) video of the batch.
result = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    guidance_scale=5.0,
    total_second_length=2,
    gpu_memory_preservation=11.0,
    fps=30,
)
output = result.frames[0]

# Write the frames to an mp4 in the outputs folder.
outputs_folder = "/home/tippy/FramePack/outputs"
os.makedirs(outputs_folder, exist_ok=True)
output_path = os.path.join(outputs_folder, "test_output.mp4")
export_to_video(output, output_path, fps=30)

# Echo the frame array shape as the cell result.
output.shape

Section 1/2: latent_padding_size=9, is_last_section=False
Section 1 completed. Generated frames: 9
Section 2/2: latent_padding_size=0, is_last_section=True
Section 2 completed. Generated frames: 19


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


(73, 480, 832, 3)

In [13]:
# `VimeoVideo` embeds a clip hosted on vimeo.com by its ID, so it cannot show a
# local file. `Video` plays a local mp4; `embed=True` inlines the bytes so the
# file does not have to be reachable through the notebook server.
from IPython.display import Video

Video(output_path, embed=True)

In [None]:
# Scratch cell: a single-frame latent for a 1080x1920 frame (spatial size / 8),
# drawn from a seeded generator on the GPU.
seeded_generator = torch.Generator(gpu).manual_seed(42)
start_latent_shape = (1, 16, 1, 1080 // 8, 1920 // 8)
start_latent = torch.randn(
    start_latent_shape,
    generator=seeded_generator,
    dtype=vae.dtype,
    device=gpu,
)

In [None]:
# Scratch cell: zero buffer whose frame axis holds the 1 + 2 + 16 history slots.
history_latents = torch.zeros((1, 16, 1 + 2 + 16, 480, 480))
history_latents.shape

torch.Size([1, 16, 19, 480, 480])

In [None]:
# Scratch cell: split the frame axis (dim 2) into groups of 1, 2 and 16 frames.
# The name is rebound to the resulting tuple of three tensors.
history_latents = torch.split(history_latents, [1, 2, 16], dim=2)
history_latents[0].shape
history_latents[1].shape
history_latents[2].shape

torch.Size([1, 16, 16, 480, 480])

In [24]:
# Inspect which instance attributes the transformer currently carries.
vars(transformer).keys()

dict_keys(['_internal_dict', 'training', '_parameters', '_buffers', '_non_persistent_buffers_set', '_backward_pre_hooks', '_backward_hooks', '_is_full_backward_hook', '_forward_hooks', '_forward_hooks_with_kwargs', '_forward_hooks_always_called', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_state_dict_hooks', '_state_dict_pre_hooks', '_load_state_dict_pre_hooks', '_load_state_dict_post_hooks', '_modules', '_gradient_checkpointing_func', 'gradient_checkpointing', 'inner_dim', 'in_channels'])

In [30]:
# Dump the transformer's instance dictionary.
vars(transformer)

{'_internal_dict': FrozenDict([('patch_size', [1, 2, 2]),
             ('num_attention_heads', 12),
             ('attention_head_dim', 128),
             ('in_channels', 16),
             ('out_channels', 16),
             ('text_dim', 4096),
             ('freq_dim', 256),
             ('ffn_dim', 8960),
             ('num_layers', 30),
             ('cross_attn_norm', True),
             ('qk_norm', 'rms_norm_across_heads'),
             ('eps', 1e-06),
             ('image_dim', None),
             ('added_kv_proj_dim', None),
             ('rope_max_seq_len', 1024),
             ('_use_default_values',
              ['image_dim',
               'num_layers',
               'qk_norm',
               'cross_attn_norm',
               'num_attention_heads',
               'patch_size',
               'text_dim',
               'ffn_dim',
               'freq_dim',
               'added_kv_proj_dim',
               'attention_head_dim',
               'eps',
               'out_channe

In [31]:
from diffusers_helper.memory import DynamicSwapInstaller

# Install DynamicSwapInstaller on the transformer with `gpu` as its device, so
# the instance dictionary can be compared before and after in the next cell.
DynamicSwapInstaller.install_model(transformer, device=gpu)

In [32]:
# Dump the instance dictionary again, now that DynamicSwapInstaller is installed.
vars(transformer)

{'_internal_dict': FrozenDict([('patch_size', [1, 2, 2]),
             ('num_attention_heads', 12),
             ('attention_head_dim', 128),
             ('in_channels', 16),
             ('out_channels', 16),
             ('text_dim', 4096),
             ('freq_dim', 256),
             ('ffn_dim', 8960),
             ('num_layers', 30),
             ('cross_attn_norm', True),
             ('qk_norm', 'rms_norm_across_heads'),
             ('eps', 1e-06),
             ('image_dim', None),
             ('added_kv_proj_dim', None),
             ('rope_max_seq_len', 1024),
             ('_use_default_values',
              ['image_dim',
               'num_layers',
               'qk_norm',
               'cross_attn_norm',
               'num_attention_heads',
               'patch_size',
               'text_dim',
               'ffn_dim',
               'freq_dim',
               'added_kv_proj_dim',
               'attention_head_dim',
               'eps',
               'out_channe

In [None]:
from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoTransformer3DModelPacked