4 changes: 3 additions & 1 deletion pipeline/causal_inference.py
@@ -44,6 +44,7 @@ def __init__(
self.args = args
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
self.local_attn_size = args.model_kwargs.local_attn_size
self.global_sink = getattr(args, "global_sink", False)

# Normalize to list if sequence-like (e.g., OmegaConf ListConfig)

@@ -260,7 +261,8 @@ def _initialize_kv_cache(self, batch_size, dtype, device, kv_cache_size_override
"k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
"v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
"local_end_index": torch.tensor([0], dtype=torch.long, device=device)
"local_end_index": torch.tensor([0], dtype=torch.long, device=device),
"global_sink": self.global_sink,
})

self.kv_cache1 = kv_cache1 # always store the clean cache
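
Note: both `_initialize_kv_cache` variants touched in this PR append the same `global_sink` field next to the K/V buffers and index pointers, so downstream attention code can read the sink policy straight from the cache dict (see the `kv_cache.get("global_sink", False)` checks in `causal_model.py` below). A minimal sketch of one per-block cache entry; the helper name and example sizes are illustrative, only the dict layout follows the diff:

```python
import torch

def init_block_cache(batch_size: int, kv_cache_size: int, num_heads: int = 12,
                     head_dim: int = 128, dtype=torch.bfloat16, device="cpu",
                     global_sink: bool = False) -> dict:
    """One entry per transformer block: K/V buffers, rolling end pointers, sink policy."""
    return {
        "k": torch.zeros([batch_size, kv_cache_size, num_heads, head_dim], dtype=dtype, device=device),
        "v": torch.zeros([batch_size, kv_cache_size, num_heads, head_dim], dtype=dtype, device=device),
        "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
        "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
        # Consumers call kv_cache.get("global_sink", False), so caches built without the key stay valid.
        "global_sink": global_sink,
    }

cache = init_block_cache(batch_size=1, kv_cache_size=1024, global_sink=True)
assert cache["global_sink"] and cache["k"].shape == (1, 1024, 12, 128)
```
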
13 changes: 1 addition & 12 deletions pipeline/interactive_causal_inference.py
@@ -28,19 +28,9 @@ def __init__(
vae: WanVAEWrapper | None = None,
):
super().__init__(args, device, generator=generator, text_encoder=text_encoder, vae=vae)
self.global_sink = getattr(args, "global_sink", False)

# Internal helpers
def _recache_after_switch(self, output, current_start_frame, new_conditional_dict):
if not self.global_sink:
# reset kv cache
for block_idx in range(self.num_transformer_blocks):
cache = self.kv_cache1[block_idx]
cache["k"].zero_()
cache["v"].zero_()
# cache["global_end_index"].zero_()
# cache["local_end_index"].zero_()

# reset cross-attention cache
for blk in self.crossattn_cache:
blk["k"].zero_()
@@ -75,8 +65,6 @@ def _recache_after_switch(self, output, current_start_frame, new_conditional_dic
context_timestep = torch.ones([batch_size, num_recache_frames],
device=device, dtype=torch.int64) * self.args.context_noise

self.generator.model.block_mask = block_mask

# recache
with torch.no_grad():
self.generator(
Expand All @@ -86,6 +74,7 @@ def _recache_after_switch(self, output, current_start_frame, new_conditional_dic
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=recache_start_frame * self.frame_seq_length,
block_mask=block_mask,
)

# reset cross-attention cache
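
The recache path no longer assigns `self.generator.model.block_mask` before the forward pass; the mask now travels with the call itself. A sketch of the call shape under that convention; `recache_forward` is a hypothetical helper, and the argument names simply mirror the generator call above:

```python
import torch

def recache_forward(generator, frames_to_recache, conditional_dict, context_timestep,
                    kv_cache, crossattn_cache, start_token, block_mask):
    """Re-encode context frames after a prompt switch. The block mask is scoped to this
    call only, so ordinary decoding steps can never pick up a stale recache mask."""
    with torch.no_grad():
        return generator(
            noisy_image_or_video=frames_to_recache,
            conditional_dict=conditional_dict,
            timestep=context_timestep,
            kv_cache=kv_cache,
            crossattn_cache=crossattn_cache,
            current_start=start_token,
            block_mask=block_mask,
        )
```
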
47 changes: 17 additions & 30 deletions pipeline/streaming_switch_training.py
@@ -24,14 +24,6 @@ class StreamingSwitchTrainingPipeline(StreamingTrainingPipeline):
remaining frames.
"""

def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.global_sink = getattr(args, "global_sink", False)

def generate_chunk_with_cache(
self,
noise: torch.Tensor,
@@ -242,42 +234,37 @@ def generate_chunk_with_cache(
return output, denoised_timestep_from, denoised_timestep_to

def _recache_after_switch(self, output, current_start_frame, new_conditional_dict, local_start_frame=None, switch_recache_frames=None):
if not self.global_sink:
# reset kv cache
for block_idx in range(self.num_transformer_blocks):
cache = self.kv_cache1[block_idx]
cache["k"].zero_()
cache["v"].zero_()
# cache["global_end_index"].zero_()
# cache["local_end_index"].zero_()

# reset cross-attention cache
for blk in self.crossattn_cache:
blk["k"].zero_()
blk["v"].zero_()
blk["is_init"] = False

if current_start_frame == 0:
return
assert current_start_frame > 0, "recache should happen only after frames have been generated"

if switch_recache_frames is not None:
frames_to_recache = torch.cat([switch_recache_frames, output], dim=1)[:, -21:, ...]
frames_to_recache = torch.cat([switch_recache_frames, output], dim=1)[:, -self.local_attn_size:, ...]
num_recache_frames = frames_to_recache.shape[1]
if DEBUG and (not dist.is_initialized() or dist.get_rank() == 0):
print(f"[SeqTrain-DMDSwitch] Using external switch_recache_frames (previous_frames): {frames_to_recache.shape}")
else:
# Determine how to fetch frames based on whether local_start_frame is provided
if local_start_frame is not None:
# Chunk mode: output is the current chunk's output; use relative coordinates
num_recache_frames = min(local_start_frame, 21)
num_recache_frames = min(local_start_frame, self.local_attn_size)
frames_to_recache = output[:, -num_recache_frames:]
else:
# Full sequence mode: output is the complete sequence; use absolute coordinates
num_recache_frames = min(current_start_frame, 21)
num_recache_frames = min(current_start_frame, self.local_attn_size)
frames_to_recache = output[:, -num_recache_frames:]

batch_size, num_recache_frames, c, h, w = frames_to_recache.shape

for block_idx in range(self.num_transformer_blocks):
cache = self.kv_cache1[block_idx]
# update local end index pointer so that we rebuild the cache from the beginning
cache["local_end_index"].fill_((num_recache_frames - current_start_frame) * self.frame_seq_length + cache["global_end_index"].item())

# reset cross-attention cache
for blk in self.crossattn_cache:
blk["k"].zero_()
blk["v"].zero_()
blk["is_init"] = False

if (not dist.is_initialized() or dist.get_rank() == 0) and DEBUG:
print(f"num_recache_frames: {num_recache_frames}, current_start_frame: {current_start_frame}, local_start_frame: {local_start_frame}")

@@ -290,15 +277,14 @@ def _recache_after_switch(self, output, current_start_frame, new_conditional_dic
num_frames=num_recache_frames,
frame_seqlen=self.frame_seq_length,
num_frame_per_block=self.num_frame_per_block,
local_attn_size=21
local_attn_size=self.local_attn_size,
)

# Prepare time steps
context_timestep = torch.ones([batch_size, num_recache_frames],
device=device, dtype=torch.int64) * self.context_noise

# Set the new block_mask
self.generator.model.block_mask = block_mask
if DEBUG and (not dist.is_initialized() or dist.get_rank() == 0):
print(f"current_start_frame: {current_start_frame}, num_recache_frames: {num_recache_frames}")
with torch.no_grad():
@@ -309,6 +295,7 @@ def _recache_after_switch(self, output, current_start_frame, new_conditional_dic
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=(current_start_frame - num_recache_frames) * self.frame_seq_length,
block_mask=block_mask,
)

# reset cross-attention cache
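
The hard-coded window of 21 frames in the recache logic is replaced by `self.local_attn_size` throughout (frame selection, `num_recache_frames`, and the mask construction). A small self-contained check of the frame-count rule, mirroring the `min(..., self.local_attn_size)` calls above:

```python
def num_frames_to_recache(frames_available: int, local_attn_size: int) -> int:
    # Never recache more frames than the local attention window can hold.
    return min(frames_available, local_attn_size)

# With the previous hard-coded window of 21 the behaviour is unchanged ...
assert num_frames_to_recache(30, 21) == 21
# ... but a model configured with a different window now recaches accordingly.
assert num_frames_to_recache(30, 12) == 12
assert num_frames_to_recache(5, 21) == 5
```
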
4 changes: 3 additions & 1 deletion pipeline/streaming_training.py
@@ -45,6 +45,7 @@ def __init__(self,
self.last_step_only = last_step_only

self.local_attn_size = kwargs.get("local_attn_size", -1)
self.global_sink = kwargs.get("global_sink", False)

slice_last_frames: int = int(kwargs.get("slice_last_frames", 21))
self.kv_cache_size = (self.local_attn_size + slice_last_frames) * self.frame_seq_length
@@ -268,7 +269,8 @@ def _initialize_kv_cache(self, batch_size, dtype, device):
"k": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
"v": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
"local_end_index": torch.tensor([0], dtype=torch.long, device=device)
"local_end_index": torch.tensor([0], dtype=torch.long, device=device),
"global_sink": self.global_sink,
})

self.kv_cache1 = kv_cache1 # always store the clean cache
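
With `global_sink` forwarded from kwargs, the training-side cache layout matches the inference pipeline above. The sizing line shown in the untouched context still budgets the local window plus the `slice_last_frames` tail. A worked example with illustrative numbers (the per-frame token count is an assumption, not taken from the config):

```python
def kv_cache_tokens(local_attn_size: int, slice_last_frames: int, frame_seq_length: int) -> int:
    # Matches self.kv_cache_size = (local_attn_size + slice_last_frames) * frame_seq_length.
    return (local_attn_size + slice_last_frames) * frame_seq_length

# E.g. a 21-frame window, the default 21 sliced frames, and a hypothetical 1560 tokens per latent frame:
print(kv_cache_tokens(21, 21, 1560))  # 65520 K/V slots per transformer block
```
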
23 changes: 6 additions & 17 deletions pipeline/switch_causal_inference.py
@@ -29,19 +29,9 @@ def __init__(
vae: WanVAEWrapper | None = None,
):
super().__init__(args, device, generator=generator, text_encoder=text_encoder, vae=vae)
self.global_sink = getattr(args, "global_sink", False)

# Internal helpers
def _recache_after_switch(self, output, current_start_frame, new_conditional_dict):
if not self.global_sink:
# reset kv cache
for block_idx in range(self.num_transformer_blocks):
cache = self.kv_cache1[block_idx]
cache["k"].zero_()
cache["v"].zero_()
# cache["global_end_index"].zero_()
# cache["local_end_index"].zero_()

# reset cross-attention cache
for blk in self.crossattn_cache:
blk["k"].zero_()
@@ -73,19 +63,18 @@ def _recache_after_switch(self, output, current_start_frame, new_conditional_dic
local_attn_size=self.local_attn_size
)

context_timestep = torch.ones([batch_size, recompute_frames],
context_timestep = torch.ones([batch_size, num_recache_frames],
device=device, dtype=torch.int64) * self.args.context_noise

self.generator.model.block_mask = block_mask


with torch.no_grad():
self.generator(
noisy_image_or_video=frames_to_recompute,
noisy_image_or_video=frames_to_recache,
conditional_dict=new_conditional_dict,
timestep=context_timestep,
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=recompute_start_frame * self.frame_seq_length,
current_start=recache_start_frame * self.frame_seq_length,
block_mask=block_mask,
)

# reset cross-attention cache
@@ -177,7 +166,7 @@ def inference(
else:
cond_in_use = cond_second if using_second else cond_first

noisy_input = noise[:, current_start_frame - num_input_frames : current_start_frame + current_num_frames - num_input_frames]
noisy_input = noise[:, current_start_frame - (1 if initial_latent is not None else 0) : current_start_frame + current_num_frames - (1 if initial_latent is not None else 0)]

# Spatial denoising loop (same as parent but uses cond_in_use)
for index, current_timestep in enumerate(self.denoising_step_list):
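
Besides aligning the recache variable names with the rest of the PR (`frames_to_recache`, `num_recache_frames`, `recache_start_frame`), the inference loop now offsets the noise slice by a single frame only when an `initial_latent` is present, instead of by `num_input_frames`. A sketch of that slice with toy shapes (the helper name is illustrative):

```python
import torch

def slice_noisy_input(noise: torch.Tensor, current_start_frame: int,
                      current_num_frames: int, has_initial_latent: bool) -> torch.Tensor:
    # Offset the slice by the single conditioning latent (if any), mirroring the updated line above.
    offset = 1 if has_initial_latent else 0
    return noise[:, current_start_frame - offset:
                    current_start_frame + current_num_frames - offset]

noise = torch.randn(1, 24, 16, 8, 8)   # (batch, frames, channels, h, w), toy sizes
chunk = slice_noisy_input(noise, current_start_frame=7, current_num_frames=3, has_initial_latent=True)
assert chunk.shape[1] == 3             # noise frames 6..8 feed this chunk
```
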
7 changes: 5 additions & 2 deletions utils/wan_wrapper.py
@@ -4,6 +4,7 @@
from typing import List, Optional
import torch
from torch import nn
from torch.nn.attention.flex_attention import BlockMask

from utils.scheduler import SchedulerInterface, FlowMatchScheduler
from wan.modules.tokenizers import HuggingfaceTokenizer
@@ -231,7 +232,8 @@ def forward(
concat_time_embeddings: Optional[bool] = False,
clean_x: Optional[torch.Tensor] = None,
aug_t: Optional[torch.Tensor] = None,
cache_start: Optional[int] = None
cache_start: Optional[int] = None,
block_mask: Optional[BlockMask] = None
) -> torch.Tensor:
prompt_embeds = conditional_dict["prompt_embeds"]

@@ -251,7 +253,8 @@ def forward(
kv_cache=kv_cache,
crossattn_cache=crossattn_cache,
current_start=current_start,
cache_start=cache_start
cache_start=cache_start,
block_mask=block_mask
).permute(0, 2, 1, 3, 4)
else:
if clean_x is not None:
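
The wrapper change is purely plumbing: `forward` gains an optional `BlockMask` and hands it to the underlying model untouched. A toy analogue of that relay, with made-up class names; it assumes a PyTorch build that ships `torch.nn.attention.flex_attention`:

```python
from typing import Optional
import torch
from torch import nn
from torch.nn.attention.flex_attention import BlockMask

class InnerModel(nn.Module):
    # Stand-in for the causal model; a real model would route the mask into its attention layers.
    def forward(self, x: torch.Tensor, block_mask: Optional[BlockMask] = None) -> torch.Tensor:
        return x

class MaskRelayWrapper(nn.Module):
    """Toy analogue of the wrapper's forward: accept an optional BlockMask, pass it through."""
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor, block_mask: Optional[BlockMask] = None) -> torch.Tensor:
        return self.model(x, block_mask=block_mask)

wrapper = MaskRelayWrapper(InnerModel())
_ = wrapper(torch.randn(1, 4, 8))  # default path: no mask, behaviour unchanged
```
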
67 changes: 41 additions & 26 deletions wan/modules/causal_model.py
@@ -1,5 +1,6 @@
# Adopted from https://github.com/guandeh17/Self-Forcing
# SPDX-License-Identifier: CC-BY-NC-SA-4.0
from typing import Optional
from wan.modules.attention import attention
from wan.modules.model import (
WanRMSNorm,
@@ -103,7 +104,7 @@ def forward(
block_mask,
kv_cache=None,
current_start=0,
cache_start=None
cache_start=None,
):
r"""
Args:
@@ -256,8 +257,8 @@ def qkv_fn(x):
temp_v[:, sink_tokens + num_evicted_tokens:sink_tokens + num_evicted_tokens + num_rolled_tokens].clone()

# Insert new key/value into the temporary cache
# Protect sink_tokens only during recomputation; regular forward generation allows writing into the initial sink region
write_start_index = max(local_start_index, sink_tokens) if is_recompute else local_start_index
# Protect sink_tokens only during recaching; regular forward generation allows writing into the initial sink region
write_start_index = max(local_start_index, sink_tokens) if ((block_mask is not None) and kv_cache.get("global_sink", False)) else local_start_index
roped_offset = max(0, write_start_index - local_start_index)
write_len = max(0, local_end_index - write_start_index)
if write_len > 0:
@@ -290,8 +291,8 @@ def qkv_fn(x):
# Construct full k, v for attention computation (without modifying the original cache)
temp_k = kv_cache["k"].clone()
temp_v = kv_cache["v"].clone()
# Protect sink_tokens only during recomputation; regular forward generation allows writing into the initial sink region
write_start_index = max(local_start_index, sink_tokens) if is_recompute else local_start_index
# Protect sink_tokens only during recaching; regular forward generation allows writing into the initial sink region
write_start_index = max(local_start_index, sink_tokens) if ((block_mask is not None) and kv_cache.get("global_sink", False)) else local_start_index
roped_offset = max(0, write_start_index - local_start_index)
write_len = max(0, local_end_index - write_start_index)
if write_len > 0:
@@ -331,18 +332,34 @@ def qkv_fn(x):
else:
k_cat = k_sink
v_cat = v_sink
x = attention(
roped_query,
k_cat,
v_cat
)
if block_mask is not None:
x = flex_attention(
query=roped_query.transpose(2, 1),
key=k_cat.transpose(2, 1),
value=v_cat.transpose(2, 1),
block_mask=block_mask
).transpose(2, 1)
else:
x = attention(
roped_query,
k_cat,
v_cat
)
else:
window_start = max(0, local_end_index - self.max_attention_size)
x = attention(
roped_query,
temp_k[:, window_start:local_end_index],
temp_v[:, window_start:local_end_index]
)
if block_mask is not None:
x = flex_attention(
query=roped_query.transpose(2, 1),
key=temp_k[:, window_start:local_end_index].transpose(2, 1),
value=temp_v[:, window_start:local_end_index].transpose(2, 1),
block_mask=block_mask
).transpose(2, 1)
else:
x = attention(
roped_query,
temp_k[:, window_start:local_end_index],
temp_v[:, window_start:local_end_index]
)

# output
x = x.flatten(2)
@@ -638,18 +655,15 @@ def _prepare_blockwise_causal_attn_mask(
[1 latent frame] [1 latent frame] ... [1 latent frame]
We use flexattention to construct the attention mask
"""
total_length = num_frames * frame_seqlen
total_q_length = num_frames * frame_seqlen
total_kv_length = (local_attn_size * frame_seqlen) if local_attn_size != -1 else total_q_length

# we do right padding to get to a multiple of 128
padded_length = math.ceil(total_length / 128) * 128 - total_length

ends = torch.zeros(total_length + padded_length,
device=device, dtype=torch.long)
ends = torch.zeros(total_q_length, device=device, dtype=torch.long)

# Block-wise causal mask will attend to all elements that are before the end of the current chunk
frame_indices = torch.arange(
start=0,
end=total_length,
end=total_q_length,
step=frame_seqlen * num_frame_per_block,
device=device
)
@@ -665,8 +679,8 @@ def attention_mask(b, h, q_idx, kv_idx):
return ((kv_idx < ends[q_idx]) & (kv_idx >= (ends[q_idx] - local_attn_size * frame_seqlen))) | (q_idx == kv_idx)
# return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask

block_mask = create_block_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length,
KV_LEN=total_length + padded_length, _compile=False, device=device)
block_mask = create_block_mask(attention_mask, B=None, H=None, Q_LEN=total_q_length,
KV_LEN=total_kv_length, _compile=False, device=device)

import torch.distributed as dist
if (not dist.is_initialized() or dist.get_rank() == 0) and DEBUG:
@@ -895,7 +909,8 @@ def _forward_inference(
kv_cache: dict = None,
crossattn_cache: dict = None,
current_start: int = 0,
cache_start: int = 0
cache_start: int = 0,
block_mask: Optional[BlockMask] = None
):
r"""
Run the diffusion model with kv caching.
Expand Down Expand Up @@ -979,7 +994,7 @@ def _forward_inference(
freqs=self.freqs,
context=context,
context_lens=context_lens,
block_mask=self.block_mask
block_mask=block_mask
)
# print("kwargs done")
def create_custom_forward(module):
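
Two things change in the attention path. First, when a `block_mask` is supplied (the recache case), the sink-aware K/V are run through `flex_attention` with the tensors transposed to `(batch, heads, seq, dim)`. Second, `_prepare_blockwise_causal_attn_mask` now builds the mask with `Q_LEN` covering the recached frames while `KV_LEN` covers only the local attention window, dropping the old right-padding to a multiple of 128. A reduced sketch of that mask shape and call, with toy sizes and a simplified mask rule rather than the repo's `ends`-based rule; depending on the PyTorch build, `flex_attention` may require CUDA and block-aligned lengths:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

frame_seqlen, num_q_frames, window_frames = 128, 4, 2   # toy sizes, block-aligned
Q_LEN = num_q_frames * frame_seqlen                     # queries: the frames being recached
KV_LEN = window_frames * frame_seqlen                   # keys/values: the local window only

def mask_mod(b, h, q_idx, kv_idx):
    # Simplified block-causal rule: queries in frame f see the window filled up to frame f + 1,
    # clamped to the window size so every query row has at least one visible key.
    visible = torch.clamp(q_idx // frame_seqlen + 1, max=window_frames) * frame_seqlen
    return kv_idx < visible

block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=Q_LEN, KV_LEN=KV_LEN,
                               device="cpu", _compile=False)

# The model stores (batch, seq, heads, dim); flex_attention expects (batch, heads, seq, dim),
# hence the transpose(2, 1) pairs around the call in the diff.
q = torch.randn(1, Q_LEN, 2, 16).transpose(2, 1)
k = torch.randn(1, KV_LEN, 2, 16).transpose(2, 1)
v = torch.randn(1, KV_LEN, 2, 16).transpose(2, 1)
out = flex_attention(q, k, v, block_mask=block_mask).transpose(2, 1)
print(out.shape)  # torch.Size([1, 512, 2, 16]) — back in (batch, seq, heads, dim)
```
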