Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/maxdiffusion/configs/base_wan_i2v_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,10 @@ num_frames: 81
guidance_scale: 5.0
flow_shift: 5.0

# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
# Diffusion CFG cache (FasterCache-style)
use_cfg_cache: False
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
use_sen_cache: False
use_magcache: False
magcache_thresh: 0.12
magcache_K: 2
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_wan_i2v_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ boundary_ratio: 0.875

# Diffusion CFG cache (FasterCache-style)
use_cfg_cache: False
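# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208).
# In the WAN 2.2 I2V pipeline this is mutually exclusive with use_cfg_cache and
# requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0.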
use_sen_cache: False

# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
guidance_scale_low=config.guidance_scale_low,
guidance_scale_high=config.guidance_scale_high,
use_cfg_cache=config.use_cfg_cache,
use_sen_cache=config.use_sen_cache,
)
else:
raise ValueError(f"Unsupported model_name for I2V in config: {model_key}")
Expand Down
129 changes: 129 additions & 0 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,25 @@ def __call__(
output_type: Optional[str] = "np",
rng: Optional[jax.Array] = None,
use_cfg_cache: bool = False,
use_sen_cache: bool = False,
):
if use_cfg_cache and use_sen_cache:
raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.")

if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
raise ValueError(
f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
f"(got {guidance_scale_low}, {guidance_scale_high}). "
"CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
)

if use_sen_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
raise ValueError(
f"use_sen_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
f"(got {guidance_scale_low}, {guidance_scale_high}). "
"SenCache requires classifier-free guidance to be enabled for both transformer phases."
)

height = height or self.config.height
width = width or self.config.width
num_frames = num_frames or self.config.num_frames
Expand Down Expand Up @@ -264,6 +275,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
scheduler=self.scheduler,
image_embeds=image_embeds,
use_cfg_cache=use_cfg_cache,
use_sen_cache=use_sen_cache,
height=height,
)

Expand Down Expand Up @@ -308,11 +320,128 @@ def run_inference_2_2_i2v(
scheduler: FlaxUniPCMultistepScheduler,
scheduler_state,
use_cfg_cache: bool = False,
use_sen_cache: bool = False,
height: int = 480,
):
do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
bsz = latents.shape[0]

# ── SenCache path (arXiv:2602.24208) ──
if use_sen_cache and do_classifier_free_guidance:
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]

# SenCache hyperparameters
sen_epsilon = 0.1
max_reuse = 3
warmup_steps = 1
nocache_start_ratio = 0.3
nocache_end_ratio = 0.1
alpha_x, alpha_t = 1.0, 1.0

nocache_start = int(num_inference_steps * nocache_start_ratio)
nocache_end_begin = int(num_inference_steps * (1.0 - nocache_end_ratio))
num_train_timesteps = float(scheduler.config.num_train_timesteps)

prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
if image_embeds is not None:
image_embeds_combined = jnp.concatenate([image_embeds, image_embeds], axis=0)
else:
image_embeds_combined = None
condition_doubled = jnp.concatenate([condition] * 2)

# SenCache state
ref_noise_pred = None
ref_latent = None
ref_timestep = 0.0
accum_dx = 0.0
accum_dt = 0.0
reuse_count = 0
cache_count = 0

for step in range(num_inference_steps):
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
t_float = float(timesteps_np[step]) / num_train_timesteps

if step_uses_high[step]:
graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
guidance_scale = guidance_scale_high
else:
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
guidance_scale = guidance_scale_low

is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1]
force_compute = (
step < warmup_steps or step < nocache_start or step >= nocache_end_begin or is_boundary or ref_noise_pred is None
)

if force_compute:
latents_doubled = jnp.concatenate([latents, latents], axis=0)
latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=-1)
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
timestep = jnp.broadcast_to(t, bsz * 2)
noise_pred, _, _ = transformer_forward_pass_full_cfg(
graphdef,
state,
rest,
latent_model_input,
timestep,
prompt_embeds_combined,
guidance_scale=guidance_scale,
encoder_hidden_states_image=image_embeds_combined,
)
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
ref_noise_pred = noise_pred
ref_latent = latents
ref_timestep = t_float
accum_dx = 0.0
accum_dt = 0.0
reuse_count = 0
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
continue

dx_norm = float(jnp.sqrt(jnp.mean((latents - ref_latent) ** 2)))
dt = abs(t_float - ref_timestep)
accum_dx += dx_norm
accum_dt += dt

score = alpha_x * accum_dx + alpha_t * accum_dt

if score <= sen_epsilon and reuse_count < max_reuse:
noise_pred = ref_noise_pred
reuse_count += 1
cache_count += 1
else:
latents_doubled = jnp.concatenate([latents, latents], axis=0)
latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=-1)
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
timestep = jnp.broadcast_to(t, bsz * 2)
noise_pred, _, _ = transformer_forward_pass_full_cfg(
graphdef,
state,
rest,
latent_model_input,
timestep,
prompt_embeds_combined,
guidance_scale=guidance_scale,
encoder_hidden_states_image=image_embeds_combined,
)
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
ref_noise_pred = noise_pred
ref_latent = latents
ref_timestep = t_float
accum_dx = 0.0
accum_dt = 0.0
reuse_count = 0

latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()

print(
f"[SenCache] Cached {cache_count}/{num_inference_steps} steps "
f"({100*cache_count/num_inference_steps:.1f}% cache ratio)"
)
return latents

# ── CFG cache path ──
if use_cfg_cache and do_classifier_free_guidance:
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
Expand Down
Loading
Loading