In [1]:
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

In [2]:
def load_latent(folder: Path, step: int, pattern: str) -> torch.Tensor:
    path = folder / pattern.format(step=step)
    if not path.exists():
        raise FileNotFoundError(f"Missing latent file: {path}")
    latent = torch.load(path, map_location='cpu').float()
    if latent.dim() != 4:
        raise ValueError(
            f"Latent tensor must have shape [C, T, H, W], got {tuple(latent.shape)} for {path}"
        )
    return latent


def build_mask_from_two(
    z_a: torch.Tensor,
    z_b: torch.Tensor,
    percentile: float,
    mode: str = "global"  # global or framewise
):
    diff = (z_a - z_b).abs().mean(dim=0)  # [T, H, W]

    if mode == "global":
        flat = diff.flatten().numpy()
        thresh = np.percentile(flat, percentile)
        mask = (diff.numpy() > thresh).astype(np.uint8)

    elif mode == "framewise":
        diff_np = diff.numpy()
        T = diff_np.shape[0]
        mask = np.zeros_like(diff_np, dtype=np.uint8)
        thresh = np.zeros(T, dtype=np.float32)
        for t in range(T):
            flat_t = diff_np[t].flatten()
            thresh_t = np.percentile(flat_t, percentile)
            mask[t] = (diff_np[t] > thresh_t).astype(np.uint8)
            thresh[t] = thresh_t
    else:
        raise ValueError(f"Unknown mode: {mode}. Choose 'global' or 'framewise'.")

    return diff, mask, thresh


def save_quicklooks(step: int, diff: torch.Tensor, mask: np.ndarray, outdir: Path):
    outdir.mkdir(parents=True, exist_ok=True)

    plt.figure(figsize=(6, 4))
    plt.hist(diff.flatten().numpy(), bins=100, log=True)
    plt.title(f"Latent abs diff histogram (step {step})")
    plt.xlabel("abs diff"); plt.ylabel("count (log)")
    plt.tight_layout()
    plt.savefig(outdir / f"hist_step{step:03d}.png")
    plt.close()

    heatmap = diff.mean(dim=0).numpy()
    plt.figure(figsize=(6, 6))
    plt.imshow(heatmap, cmap="viridis")
    plt.colorbar(label="mean abs diff")
    plt.title(f"Mean abs diff heatmap (step {step})")
    plt.tight_layout()
    plt.savefig(outdir / f"heatmap_step{step:03d}.png")
    plt.close()

    tmid = diff.shape[0] // 2
    plt.figure(figsize=(6, 6))
    plt.imshow(mask[tmid], cmap="hot")
    plt.colorbar(label="mask (1=changed)")
    plt.title(f"Mask (step {step}, frame {tmid})")
    plt.tight_layout()
    plt.savefig(outdir / f"mask_step{step:03d}.png")
    plt.close()

In [6]:
latent_dir_a = Path("/data02/henry/wan_cache/latents")  # Prompt A latents
latent_dir_b = Path("/data02/henry/wan_cache/baseline/latents")  # Prompt B latents
output_dir = Path("/data02/henry/wan_cache/masks/")
steps = [10, 15, 20, 25, 30, 35, 40, 45]
percentile = 70.0  # mask = diff > Pth percentile (higher = smaller masked area)
filename_pattern = "latent_step{step:03d}.pt"

In [9]:
output_dir.mkdir(parents=True, exist_ok=True)
print(f"[INFO] Saving outputs to {output_dir}")

for step in steps:
    print(f"[BUILD] step {step:03d}")
    z_a = load_latent(latent_dir_a, step, filename_pattern)
    z_b = load_latent(latent_dir_b, step, filename_pattern)
    diff, mask, thresh = build_mask_from_two(z_a, z_b, percentile, mode="framewise")

    masked_frac = float(mask.mean())
    stats_path = output_dir / f"stats_step{step:03d}.txt"
    with open(stats_path, "w") as fh:
        fh.write(f"mean_abs_diff={float(diff.mean()):.6f}\n")
        fh.write(f"max_abs_diff={float(diff.max()):.6f}\n")
        fh.write(f"masked_frac={masked_frac:.6f}\n")
        fh.write(f"unmasked_frac={float(1.0 - masked_frac):.6f}\n")
        # fh.write(f"threshold_percentile={thresh:.6f}\n")

    torch.save(torch.from_numpy(mask.astype(np.uint8)), output_dir / f"mask_step{step:03d}.pt")
    torch.save(diff, output_dir / f"diff_step{step:03d}.pt")
    save_quicklooks(step, diff, mask, output_dir)

print("[DONE] Mask generation complete.")

[INFO] Saving outputs to /data02/henry/wan_cache/masks
[BUILD] step 010


  latent = torch.load(path, map_location='cpu').float()


[BUILD] step 015
[BUILD] step 020
[BUILD] step 025
[BUILD] step 030
[BUILD] step 035
[BUILD] step 040
[BUILD] step 045
[DONE] Mask generation complete.


In [10]:
mask_fhw = torch.load("/data02/henry/wan_cache/masks/mask_step045.pt")
print("Mask shape:", mask_fhw.shape)

if isinstance(mask_fhw, torch.Tensor):
    mask_fhw = mask_fhw.cpu().numpy() 

mask_fhw_mp4 = mask_fhw.astype(np.uint8) * 255

F, H, W = mask_fhw_mp4.shape
print(f"{F=}, {H=}, {W=}")

out_path = "/data02/henry/wan_cache/mask_fhw_compare.mp4"
fps = 30  # 你可以改，比如 16
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
writer = cv2.VideoWriter(out_path, fourcc, fps, (W, H), isColor=False)

# 6. 写入每帧
for i in range(F):
    writer.write(mask_fhw_mp4[i])

writer.release()
print(f"✅ 黑白mask视频已保存到: {out_path}")

Mask shape: torch.Size([21, 104, 60])
F=21, H=104, W=60
✅ 黑白mask视频已保存到: /data02/henry/wan_cache/mask_fhw_compare.mp4


  mask_fhw = torch.load("/data02/henry/wan_cache/masks/mask_step045.pt")
