From 0d96f280d9a9cdbb9417970b12c20b56c7e1083f Mon Sep 17 00:00:00 2001 From: James Huang Date: Mon, 30 Mar 2026 21:05:55 +0000 Subject: [PATCH] Implement SenCache for WAN 2.2 I2V pipeline Signed-off-by: James Huang --- src/maxdiffusion/configs/base_wan_i2v_14b.yml | 4 +- src/maxdiffusion/configs/base_wan_i2v_27b.yml | 2 + src/maxdiffusion/generate_wan.py | 1 + .../pipelines/wan/wan_pipeline_i2v_2p2.py | 129 ++++++++ src/maxdiffusion/tests/wan_sen_cache_test.py | 276 ++++++++++++++++++ 5 files changed, 411 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index e2170293..395ed0f8 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -286,8 +286,10 @@ num_frames: 81 guidance_scale: 5.0 flow_shift: 5.0 -# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only) +# Diffusion CFG cache (FasterCache-style) use_cfg_cache: False +# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) +use_sen_cache: False use_magcache: False magcache_thresh: 0.12 magcache_K: 2 diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index b7f89304..b1f6901c 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -300,6 +300,8 @@ boundary_ratio: 0.875 # Diffusion CFG cache (FasterCache-style) use_cfg_cache: False +# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) +use_sen_cache: False # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 3cbfb60e..9c973146 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -117,6 +117,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): guidance_scale_low=config.guidance_scale_low, guidance_scale_high=config.guidance_scale_high, use_cfg_cache=config.use_cfg_cache, + use_sen_cache=config.use_sen_cache, ) else: raise ValueError(f"Unsupported model_name for I2V in config: {model_key}") diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index 65e78674..ffbe1496 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -167,7 +167,11 @@ def __call__( output_type: Optional[str] = "np", rng: Optional[jax.Array] = None, use_cfg_cache: bool = False, + use_sen_cache: bool = False, ): + if use_cfg_cache and use_sen_cache: + raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.") + if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0): raise ValueError( f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 " @@ -175,6 +179,13 @@ def __call__( "CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases." ) + if use_sen_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0): + raise ValueError( + f"use_sen_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 " + f"(got {guidance_scale_low}, {guidance_scale_high}). " + "SenCache requires classifier-free guidance to be enabled for both transformer phases." + ) + height = height or self.config.height width = width or self.config.width num_frames = num_frames or self.config.num_frames @@ -264,6 +275,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): scheduler=self.scheduler, image_embeds=image_embeds, use_cfg_cache=use_cfg_cache, + use_sen_cache=use_sen_cache, height=height, ) @@ -308,11 +320,128 @@ def run_inference_2_2_i2v( scheduler: FlaxUniPCMultistepScheduler, scheduler_state, use_cfg_cache: bool = False, + use_sen_cache: bool = False, height: int = 480, ): do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 bsz = latents.shape[0] + # ── SenCache path (arXiv:2602.24208) ── + if use_sen_cache and do_classifier_free_guidance: + timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) + step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)] + + # SenCache hyperparameters + sen_epsilon = 0.1 + max_reuse = 3 + warmup_steps = 1 + nocache_start_ratio = 0.3 + nocache_end_ratio = 0.1 + alpha_x, alpha_t = 1.0, 1.0 + + nocache_start = int(num_inference_steps * nocache_start_ratio) + nocache_end_begin = int(num_inference_steps * (1.0 - nocache_end_ratio)) + num_train_timesteps = float(scheduler.config.num_train_timesteps) + + prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + if image_embeds is not None: + image_embeds_combined = jnp.concatenate([image_embeds, image_embeds], axis=0) + else: + image_embeds_combined = None + condition_doubled = jnp.concatenate([condition] * 2) + + # SenCache state + ref_noise_pred = None + ref_latent = None + ref_timestep = 0.0 + accum_dx = 0.0 + accum_dt = 0.0 + reuse_count = 0 + cache_count = 0 + + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + t_float = float(timesteps_np[step]) / num_train_timesteps + + if step_uses_high[step]: + graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest + guidance_scale = guidance_scale_high + else: + graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest + guidance_scale = guidance_scale_low + + is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1] + force_compute = ( + step < warmup_steps or step < nocache_start or step >= nocache_end_begin or is_boundary or ref_noise_pred is None + ) + + if force_compute: + latents_doubled = jnp.concatenate([latents, latents], axis=0) + latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=-1) + latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) + timestep = jnp.broadcast_to(t, bsz * 2) + noise_pred, _, _ = transformer_forward_pass_full_cfg( + graphdef, + state, + rest, + latent_model_input, + timestep, + prompt_embeds_combined, + guidance_scale=guidance_scale, + encoder_hidden_states_image=image_embeds_combined, + ) + noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) + ref_noise_pred = noise_pred + ref_latent = latents + ref_timestep = t_float + accum_dx = 0.0 + accum_dt = 0.0 + reuse_count = 0 + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + continue + + dx_norm = float(jnp.sqrt(jnp.mean((latents - ref_latent) ** 2))) + dt = abs(t_float - ref_timestep) + accum_dx += dx_norm + accum_dt += dt + + score = alpha_x * accum_dx + alpha_t * accum_dt + + if score <= sen_epsilon and reuse_count < max_reuse: + noise_pred = ref_noise_pred + reuse_count += 1 + cache_count += 1 + else: + latents_doubled = jnp.concatenate([latents, latents], axis=0) + latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=-1) + latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) + timestep = jnp.broadcast_to(t, bsz * 2) + noise_pred, _, _ = transformer_forward_pass_full_cfg( + graphdef, + state, + rest, + latent_model_input, + timestep, + prompt_embeds_combined, + guidance_scale=guidance_scale, + encoder_hidden_states_image=image_embeds_combined, + ) + noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) + ref_noise_pred = noise_pred + ref_latent = latents + ref_timestep = t_float + accum_dx = 0.0 + accum_dt = 0.0 + reuse_count = 0 + + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + + print( + f"[SenCache] Cached {cache_count}/{num_inference_steps} steps " + f"({100*cache_count/num_inference_steps:.1f}% cache ratio)" + ) + return latents + # ── CFG cache path ── if use_cfg_cache and do_classifier_free_guidance: timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) diff --git a/src/maxdiffusion/tests/wan_sen_cache_test.py b/src/maxdiffusion/tests/wan_sen_cache_test.py index 1d2fe76c..b82d4122 100644 --- a/src/maxdiffusion/tests/wan_sen_cache_test.py +++ b/src/maxdiffusion/tests/wan_sen_cache_test.py @@ -23,6 +23,7 @@ from absl.testing import absltest from maxdiffusion.pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2 +from maxdiffusion.pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2 IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -350,5 +351,280 @@ def test_sen_cache_speedup_and_fidelity(self): self.assertGreaterEqual(mean_ssim, 0.95, f"Mean SSIM={mean_ssim:.4f} < 0.95") +class Wan22I2VSenCacheValidationTest(unittest.TestCase): + """Tests that use_sen_cache validation raises correct errors for Wan 2.2 I2V.""" + + def _make_pipeline(self): + pipeline = WanPipelineI2V_2_2.__new__(WanPipelineI2V_2_2) + return pipeline + + def test_sen_cache_with_both_scales_low_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + image=None, + guidance_scale_low=1.0, + guidance_scale_high=1.0, + use_sen_cache=True, + ) + self.assertIn("use_sen_cache", str(ctx.exception)) + + def test_sen_cache_with_low_scale_low_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + image=None, + guidance_scale_low=0.5, + guidance_scale_high=4.0, + use_sen_cache=True, + ) + self.assertIn("use_sen_cache", str(ctx.exception)) + + def test_sen_cache_with_high_scale_low_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + image=None, + guidance_scale_low=3.0, + guidance_scale_high=1.0, + use_sen_cache=True, + ) + self.assertIn("use_sen_cache", str(ctx.exception)) + + def test_sen_cache_mutually_exclusive_with_cfg_cache(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + image=None, + guidance_scale_low=3.0, + guidance_scale_high=4.0, + use_cfg_cache=True, + use_sen_cache=True, + ) + self.assertIn("mutually exclusive", str(ctx.exception)) + + def test_sen_cache_with_valid_scales_no_validation_error(self): + """Both guidance_scales > 1.0 should pass validation (may fail later without model).""" + pipeline = self._make_pipeline() + try: + pipeline( + prompt=["test"], + image=None, + guidance_scale_low=3.0, + guidance_scale_high=4.0, + use_sen_cache=True, + ) + except ValueError as e: + if "use_sen_cache" in str(e): + self.fail(f"Unexpected validation error: {e}") + except Exception: + pass + + def test_no_sen_cache_with_low_scales_no_error(self): + """use_sen_cache=False should never raise our ValueError.""" + pipeline = self._make_pipeline() + try: + pipeline( + prompt=["test"], + image=None, + guidance_scale_low=0.5, + guidance_scale_high=0.5, + use_sen_cache=False, + ) + except ValueError as e: + if "use_sen_cache" in str(e): + self.fail(f"Unexpected validation error: {e}") + except Exception: + pass + + +class Wan22I2VSenCacheScheduleTest(unittest.TestCase): + """Tests the SenCache schedule logic for Wan 2.2 I2V. + + The schedule logic is identical to T2V — validates force_compute zones + and sensitivity gating constraints. + """ + + def _get_force_compute_schedule(self, num_inference_steps, boundary_ratio=0.875, num_train_timesteps=1000): + """Extract which steps are forced to compute — mirrors run_inference_2_2_i2v's SenCache logic.""" + boundary = boundary_ratio * num_train_timesteps + timesteps = np.linspace(num_train_timesteps - 1, 0, num_inference_steps, dtype=np.int32) + step_uses_high = [bool(timesteps[s] >= boundary) for s in range(num_inference_steps)] + + warmup_steps = 1 + nocache_start_ratio = 0.3 + nocache_end_ratio = 0.1 + + nocache_start = int(num_inference_steps * nocache_start_ratio) + nocache_end_begin = int(num_inference_steps * (1.0 - nocache_end_ratio)) + + force_compute = [] + for s in range(num_inference_steps): + is_boundary = s > 0 and step_uses_high[s] != step_uses_high[s - 1] + forced = s < warmup_steps or s < nocache_start or s >= nocache_end_begin or is_boundary or s == 0 + force_compute.append(forced) + + return force_compute, step_uses_high + + def test_first_step_always_forced(self): + force_compute, _ = self._get_force_compute_schedule(50) + self.assertTrue(force_compute[0]) + + def test_first_30_percent_always_forced(self): + force_compute, _ = self._get_force_compute_schedule(50) + nocache_start = int(50 * 0.3) + self.assertTrue(all(force_compute[:nocache_start])) + + def test_last_10_percent_always_forced(self): + force_compute, _ = self._get_force_compute_schedule(50) + nocache_end_begin = int(50 * 0.9) + self.assertTrue(all(force_compute[nocache_end_begin:])) + + def test_boundary_transition_forced(self): + force_compute, step_uses_high = self._get_force_compute_schedule(50) + for s in range(1, 50): + if step_uses_high[s] != step_uses_high[s - 1]: + self.assertTrue(force_compute[s], f"Boundary step {s} should be forced") + + def test_cacheable_window_exists(self): + force_compute, _ = self._get_force_compute_schedule(50) + nocache_start = int(50 * 0.3) + nocache_end_begin = int(50 * 0.9) + cacheable = [not force_compute[s] for s in range(nocache_start, nocache_end_begin)] + self.assertGreater(sum(cacheable), 0, "Should have cacheable steps in the middle window") + + def test_schedule_matches_t2v(self): + """I2V SenCache schedule should be identical to T2V SenCache schedule.""" + t2v_test = WanSenCacheScheduleTest() + for n_steps in [20, 50, 100]: + fc_i2v, high_i2v = self._get_force_compute_schedule(n_steps) + fc_t2v, high_t2v = t2v_test._get_force_compute_schedule(n_steps) + self.assertEqual(fc_i2v, fc_t2v, f"I2V and T2V schedules should match for {n_steps} steps") + self.assertEqual(high_i2v, high_t2v, f"I2V and T2V high-noise schedules should match for {n_steps} steps") + + +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Requires TPU v7-8 and model weights") +class Wan22I2VSenCacheSmokeTest(unittest.TestCase): + """End-to-end smoke test: SenCache for Wan 2.2 I2V should be faster with SSIM >= 0.95. + + Runs on TPU v7-8 (8 chips, context_parallelism=8) with WAN 2.2 I2V 27B, 720p. + Skipped in CI (GitHub Actions) — run locally with: + python -m pytest src/maxdiffusion/tests/wan_sen_cache_test.py::Wan22I2VSenCacheSmokeTest -v + """ + + @classmethod + def setUpClass(cls): + from maxdiffusion import pyconfig + from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2 import WanCheckpointerI2V_2_2 + from maxdiffusion.utils.loading_utils import load_image + + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_i2v_27b.yml"), + "num_inference_steps=50", + "height=720", + "width=1280", + "num_frames=81", + "fps=24", + "guidance_scale_low=3.0", + "guidance_scale_high=4.0", + "boundary_ratio=0.875", + "flow_shift=3.0", + "seed=11234567893", + "attention=flash", + "remat_policy=FULL", + "allow_split_physical_axes=True", + "skip_jax_distributed_system=True", + "weights_dtype=bfloat16", + "activations_dtype=bfloat16", + "per_device_batch_size=0.125", + "ici_data_parallelism=1", + "ici_fsdp_parallelism=1", + "ici_context_parallelism=8", + "ici_tensor_parallelism=1", + "flash_min_seq_length=0", + 'flash_block_sizes={"block_q": 2048, "block_kv_compute": 1024, "block_kv": 2048, "block_q_dkv": 2048, "block_kv_dkv": 2048, "block_kv_dkv_compute": 2048, "use_fused_bwd_kernel": true}', + ], + unittest=True, + ) + cls.config = pyconfig.config + checkpoint_loader = WanCheckpointerI2V_2_2(config=cls.config) + cls.pipeline, _, _ = checkpoint_loader.load_checkpoint() + + cls.prompt = [cls.config.prompt] * cls.config.global_batch_size_to_train_on + cls.negative_prompt = [cls.config.negative_prompt] * cls.config.global_batch_size_to_train_on + cls.image = load_image(cls.config.image_url) + + # Warmup both XLA code paths + for use_cache in [False, True]: + cls.pipeline( + prompt=cls.prompt, + image=cls.image, + negative_prompt=cls.negative_prompt, + height=cls.config.height, + width=cls.config.width, + num_frames=cls.config.num_frames, + num_inference_steps=cls.config.num_inference_steps, + guidance_scale_low=cls.config.guidance_scale_low, + guidance_scale_high=cls.config.guidance_scale_high, + use_sen_cache=use_cache, + ) + + def _run_pipeline(self, use_sen_cache): + t0 = time.perf_counter() + videos = self.pipeline( + prompt=self.prompt, + image=self.image, + negative_prompt=self.negative_prompt, + height=self.config.height, + width=self.config.width, + num_frames=self.config.num_frames, + num_inference_steps=self.config.num_inference_steps, + guidance_scale_low=self.config.guidance_scale_low, + guidance_scale_high=self.config.guidance_scale_high, + use_sen_cache=use_sen_cache, + ) + return videos, time.perf_counter() - t0 + + def test_sen_cache_speedup_and_fidelity(self): + """I2V SenCache must be faster than baseline with PSNR >= 30 dB and SSIM >= 0.95.""" + videos_baseline, t_baseline = self._run_pipeline(use_sen_cache=False) + videos_cached, t_cached = self._run_pipeline(use_sen_cache=True) + + # Speed check + speedup = t_baseline / t_cached + print(f"I2V Baseline: {t_baseline:.2f}s, SenCache: {t_cached:.2f}s, Speedup: {speedup:.3f}x") + self.assertGreater(speedup, 1.0, f"SenCache should be faster. Speedup={speedup:.3f}x") + + # Fidelity checks + v1 = np.array(videos_baseline[0], dtype=np.float64) + v2 = np.array(videos_cached[0], dtype=np.float64) + + # PSNR + mse = np.mean((v1 - v2) ** 2) + psnr = 10.0 * np.log10(1.0 / mse) if mse > 0 else float("inf") + print(f"I2V PSNR: {psnr:.2f} dB") + self.assertGreaterEqual(psnr, 30.0, f"PSNR={psnr:.2f} dB < 30 dB") + + # SSIM (per-frame) + C1, C2 = 0.01**2, 0.03**2 + ssim_scores = [] + for f in range(v1.shape[0]): + mu1, mu2 = np.mean(v1[f]), np.mean(v2[f]) + sigma1_sq, sigma2_sq = np.var(v1[f]), np.var(v2[f]) + sigma12 = np.mean((v1[f] - mu1) * (v2[f] - mu2)) + ssim = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ((mu1**2 + mu2**2 + C1) * (sigma1_sq + sigma2_sq + C2)) + ssim_scores.append(float(ssim)) + + mean_ssim = np.mean(ssim_scores) + print(f"I2V SSIM: mean={mean_ssim:.4f}, min={np.min(ssim_scores):.4f}") + self.assertGreaterEqual(mean_ssim, 0.95, f"Mean SSIM={mean_ssim:.4f} < 0.95") + + if __name__ == "__main__": absltest.main()