
Draft: Artifixer integration #71

Draft
riccardodelutio wants to merge 22 commits into NVIDIA:main from riccardodelutio:rdelutio/artifixer-recipe

Conversation

Collaborator

@riccardodelutio riccardodelutio commented May 13, 2026

Summary

Adds the artifixer-dmd-wan2.1-t2v-1.3b recipe under integrations/artifixer/ — a reconstruction-enhanced T2V model built on Wan 2.1 1.3B with four extensions over vanilla Wan:

  • per-block opacity + Plucker-camera-ray MLPs;
  • a third KV bank for neighbor cross-attention with PRoPE;
  • opacity-weighted latent mixing of noise with VAE-encoded reconstruction-rendered frames;
  • 4-step DMD distillation (FlowMatchScheduler(shift=5)).

A consuming driver calls ArtifixerInferencePipeline.initialize_cache with pre-encoded UMT5 prompts + VAE-encoded condition / neighbor latents and drives the AR rollout chunk by chunk via generate / finalize.

Validation

  • Static check (scripts/static_check_artifixer.sh): green (py_compile clean, all flashdreams.* imports resolve, entry-points map to artifixer.config).
  • PRoPE parity (tests/test_prope_parity.py::test_prope_matches_reference): bit-identical at fp32 vs the ArtiFixer reference (gated on ARTIFIXER_REFERENCE_REPO_ROOT).
  • State-dict transform tests: all 1095 keys load strictly into ArtifixerDiTNetwork with no missing / unexpected entries.
  • Cross-backend final_video PSNR vs the ArtiFixer reference's ArtifixerKvCachePipeline: 51.34 dB on a captured DL3DV-ours scene; per-block per-call PSNR stays >50 dB across all 30 layers after the fp32 AdaLN/norm/residual promotion and the no-op finalize_kv_cache override.

Diffs outside integrations/artifixer/

  • pyproject.toml: +1 line in [tool.pyright].extraPaths and +1 line in [tool.ty].extra-paths (standard for every integrations/<name>/).
  • .gitignore: +1 line (slurm-logs/).
  • scripts/: 4 dev helpers (smoke_test_artifixer.sh, static_check_artifixer.sh, run_prope_parity.sh, import_flashdreams_sqsh.sh).

Mirrors the integrations/self_forcing layout: pyproject.toml workspace
member registering a single ``artifixer-dmd-wan2.1-t2v-1.3b`` entry point,
artifixer/config.py with the AR / scheduler knobs that match the dreamfix
stage-3 DMD training config (frames_per_block=7 -> len_t, local_attn_size=21
-> window_size_t, sink_size=7 -> sink_size_t, num_inference_steps=4,
timestep_shift=5 -> FlowMatchSchedulerConfig.shift), and a stripped clone of
SelfForcingT2VRunner that takes a single text prompt.
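
The knob mapping listed above can be written down as a small translation table. This is a minimal sketch, assuming the training config is available as a plain dict; ``TRAIN_TO_RECIPE`` and ``translate_training_knobs`` are hypothetical names for illustration and are not part of the recipe:

```python
# Hypothetical illustration of the training-config -> recipe-knob mapping.
# Only the key names and values come from the commit message; the helper
# itself is an assumption.
TRAIN_TO_RECIPE = {
    "frames_per_block": "len_t",
    "local_attn_size": "window_size_t",
    "sink_size": "sink_size_t",
    "num_inference_steps": "num_inference_steps",
    "timestep_shift": "shift",  # lands on FlowMatchSchedulerConfig.shift
}


def translate_training_knobs(train_cfg: dict) -> dict:
    """Rename training knobs to the recipe-side names, failing on gaps."""
    missing = sorted(set(TRAIN_TO_RECIPE) - set(train_cfg))
    if missing:
        raise KeyError(f"training config is missing: {missing}")
    return {TRAIN_TO_RECIPE[name]: train_cfg[name] for name in TRAIN_TO_RECIPE}


# Values quoted in the commit message for the stage-3 DMD training config.
stage3_dmd = {
    "frames_per_block": 7,
    "local_attn_size": 21,
    "sink_size": 7,
    "num_inference_steps": 4,
    "timestep_shift": 5,
}
recipe_knobs = translate_training_knobs(stage3_dmd)
```

Keeping the mapping in one table makes it easy to spot a knob that was renamed on one side only.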

Phase 1 only: this loads vanilla Wan 2.1 1.3B base weights from HuggingFace
so the recipe end-to-end is exercisable but the output is plain T2V. Later
commits add the ArtiFixer architecture extensions (opacity/camera MLPs,
neighbor cross-attention with PRoPE, opacity-weighted latent mixing) and
the state_dict_transform that loads the merged DMD safetensors produced by
dreamfix/scripts/merge_dcp_to_safetensors.py.

Also adds three helper scripts under scripts/:
  - import_flashdreams_sqsh.sh: one-shot pull of the FlashDreams container
    from ghcr.io into a persistent sqsh under /lustre so later jobs avoid
    re-pulling and avoid the pyxis tmp-extraction failures.
  - smoke_test_artifixer.sh: slurm wrapper that runs the plugin smoke
    tests via ``uv sync`` + pytest inside the container.
  - static_check_artifixer.sh: sub-second static-only sanity check
    (py_compile + entry-point sanity + flashdreams import resolution)
    without needing torch/uv/GPU. Designed for fast Phase 1 iteration.

Workspace pyproject.toml: add integrations/artifixer to the pyright /
ty extra-paths lists.

Verified: scripts/static_check_artifixer.sh passes (py_compile clean, 11
flashdreams imports all resolve to classes in flashdreams/, entry-point
maps to RUNNER_ARTIFIXER_DMD_T2V_1PT3B in artifixer/config.py).
Phase 2.1 of the dreamfix port: each transformer block now carries the
per-block opacity and Plucker-camera-ray projection heads from
``ArtifixerTransformerBlock`` (dreamfix/model_training/net/transformer.py
L617-L767). The MLP outputs are added to the AdaLN-normed hidden states
before self-attention. Both heads are zero-initialized so the block
behaves identically to the underlying ``Block`` when ``opacity_extra`` /
``camera_extra`` are not passed (Phase 1 text-only path).

New modules under integrations/artifixer/artifixer/:
  - network/block.py      ArtifixerBlock(Block) with opacity/camera MLPs
  - network/dit.py        ArtifixerDiTNetwork(WanDiTNetwork) +
                          ArtifixerDiTNetwork1pt3BConfig, plus
                          artifixer_embedding_dims() helper deriving
                          opacity_dim=1024 and camera_dim=1536 from the
                          Wan VAE strides + patch_size=(1,2,2)
  - checkpoint.py         zero_pad_artifixer_keys() state_dict transform
                          that zero-fills the 4 new keys per block when
                          loading a vanilla Wan checkpoint so strict-mode
                          load_state_dict still succeeds
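
The two embedding widths quoted above follow from simple stride arithmetic. The sketch below is an assumption about how ``artifixer_embedding_dims()`` combines the numbers, not a copy of it: it assumes Wan VAE strides of (4, 8, 8) in (t, h, w), 6 Plucker-ray channels, and camera rays sampled at the latent frame rate. The function name is hypothetical:

```python
def embedding_dims_sketch(vae_stride=(4, 8, 8), patch_size=(1, 2, 2), ray_channels=6):
    """Per-token feature widths for the opacity and camera-ray inputs.

    Assumed layout: one token covers patch_size latent cells, and each
    latent cell covers vae_stride input samples.
    """
    vae_t, vae_h, vae_w = vae_stride
    patch_t, patch_h, patch_w = patch_size
    # Input pixels covered by one token in a single frame: 16 * 16 = 256.
    pixels_per_token = (vae_h * patch_h) * (vae_w * patch_w)
    # Opacity: one alpha value per input pixel and per input frame.
    opacity_dim = vae_t * patch_t * pixels_per_token
    # Camera rays: 6 Plucker channels per pixel, at the latent frame rate.
    camera_dim = ray_channels * patch_t * pixels_per_token
    return opacity_dim, camera_dim


opacity_dim, camera_dim = embedding_dims_sketch()
```

Under these assumptions the result is 4 x 256 = 1024 for opacity and 6 x 256 = 1536 for camera rays, matching the values in the commit message.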

config.py wiring:
  - Replaces WanDiTNetwork1pt3BConfig with ArtifixerDiTNetwork1pt3BConfig
  - Sets state_dict_transform=zero_pad_artifixer_keys(...) for the
    BASE_WAN_T2V_1PT3B_CHECKPOINT_PATH source

Smoke tests extended:
  - ArtifixerBlock instantiates with correct linear shapes and
    zero-init invariants
  - Recipe wires up ArtifixerDiTNetwork1pt3BConfig (not vanilla)
  - zero_pad_artifixer_keys() fills missing keys with bf16 zeros

Static-check (scripts/static_check_artifixer.sh) now scans every .py file
under integrations/artifixer/artifixer/ and all 9 flashdreams imports
across the recipe resolve cleanly.
Phase 2.2: ArtifixerCrossAttention(CrossAttention) adds the parameters
that drive the neighbor-frame KV bank in dreamfix
(model_training/net/transformer.py L670-L699). Names are kept identical
to the dreamfix layout so the merged ArtiFixer DMD safetensors only need
the diffusers ``attn2 -> cross_attn`` regex remap (added in Phase 5),
not a per-key rename:

  cross_attn.add_k_proj.{weight,bias}    Linear(inner_dim, inner_dim)
  cross_attn.add_v_proj.{weight,bias}    Linear(inner_dim, inner_dim)   (zero-init)
  cross_attn.norm_added_k.weight         RMSNorm(inner_dim)

Plus a separate ``attn_op_neighbor`` RingAttention op so the neighbor
branch can later carry PRoPE-transformed q/k without entangling the
text branch.

ArtifixerBlock now swaps the inherited ``self.cross_attn`` for the
ArtifixerCrossAttention variant after super().__init__(). The forward
path is unchanged in this commit (inherited CrossAttention.forward
ignores the neighbor branch when no neighbor context is supplied), so
behavior is identical to Phase 2.1 until the pipeline-level neighbor
wiring lands in Phase 3 and PRoPE in Phase 2.3.

zero_pad_artifixer_keys extended to fill all 9 ArtiFixer-only keys per
block (4 from Phase 2.1 + 5 from Phase 2.2), matching the 270 extras
identified by dreamfix/scripts/dump_artifixer_param_names.py.
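
The 270 figure is 9 keys per block over 30 blocks. A minimal sketch of the enumeration and of the zero-fill idea follows. The cross-attention suffixes come from the list above; the ``opacity_embedding`` / ``camera_embedding`` suffixes assume one weight and one bias per head, and ``zero_pad_missing`` with its ``make_zero`` callback is a hypothetical stand-in for ``zero_pad_artifixer_keys``:

```python
NUM_BLOCKS = 30  # blocks 0..29

PER_BLOCK_SUFFIXES = [
    # Phase 2.1 heads (weight + bias per head is an assumption).
    "opacity_embedding.weight",
    "opacity_embedding.bias",
    "camera_embedding.weight",
    "camera_embedding.bias",
    # Phase 2.2 neighbor KV bank parameters.
    "cross_attn.add_k_proj.weight",
    "cross_attn.add_k_proj.bias",
    "cross_attn.add_v_proj.weight",
    "cross_attn.add_v_proj.bias",
    "cross_attn.norm_added_k.weight",
]


def artifixer_only_keys(num_blocks=NUM_BLOCKS):
    """Enumerate every ArtiFixer-only parameter name."""
    return [
        f"blocks.{index}.{suffix}"
        for index in range(num_blocks)
        for suffix in PER_BLOCK_SUFFIXES
    ]


def zero_pad_missing(state_dict, make_zero):
    """Return a copy of state_dict with absent ArtiFixer keys filled in.

    make_zero(key) builds the placeholder; in the real transform this
    would be a zero tensor of the right shape and dtype.
    """
    padded = dict(state_dict)
    for key in artifixer_only_keys():
        if key not in padded:
            padded[key] = make_zero(key)
    return padded


extra_keys = artifixer_only_keys()
```

Existing entries are left untouched, so the same helper is safe to apply to a checkpoint that already carries some of the keys.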

Smoke tests now also assert the cross_attn neighbor projections have
the correct shape and add_v_proj is zero-initialized.
Phase 2.3: verbatim port of dreamfix
``model_training/net/prope.py`` (paper "Cameras as Relative Positional
Encoding", arXiv 2507.10496). The module is self-contained — the only
dependency, ``model_training/utils/pose_utils.invert_SE3``, is inlined as
the private helper ``_invert_SE3`` so there are no cross-repo runtime
imports.

PRoPE applies block-diagonal SE(3) projection-matrix + 2D RoPE transforms
to q/k/v/o on cross-attention. ``apply_to_o`` is the inverse of the
RoPE legs in ``apply_to_q`` / ``apply_to_kv``. Used twice per
ArtifixerBlock in dreamfix:

  - prope_cross_attn_src on target camera (w2cs / Ks)
  - prope_cross_attn_tgt on neighbor cameras (neighbor_w2cs / neighbor_Ks)

This commit ports the module; the wiring inside
ArtifixerCrossAttention.forward (calling ``_apply_to_q``,
``_apply_to_kv``, ``_apply_to_o`` around the RingAttention SDPA call)
lands in Phase 2.4 with the network forward changes.

Tests:
  - tests/test_prope_parity.py::test_prope_internal_consistency exercises
    the identity-camera invariant
    ``apply_to_o(apply_to_q(x)) == x`` exactly at fp64.
  - tests/test_prope_parity.py::test_prope_apply_to_q_changes_input_for_nontrivial_cameras
    sanity-checks that real cameras do produce a non-identity transform.
  - tests/test_prope_parity.py::test_prope_matches_dreamfix_reference is
    a numerical-parity test against dreamfix at fp64. Skipped when
    DREAMFIX_REPO_ROOT is not set / dreamfix is not on PYTHONPATH. The
    intent is to gate the architecture changes on bit-identical math
    the moment a dual env is available.
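
The round-trip invariant checked by the first test can be illustrated on the RoPE leg alone. This is a toy sketch in plain Python, not the ported module: it omits the SE(3) projection part entirely, and the frequency schedule and the ``rope_rotate`` helper are illustrative assumptions:

```python
import math


def rope_rotate(pairs, position, inverse=False):
    """Rotate each (a, b) feature pair by position * freq_i.

    With inverse=True the rotation angle is negated, which undoes the
    forward rotation (the role apply_to_o plays for the RoPE legs).
    """
    num_freqs = len(pairs)
    rotated = []
    for index, (a, b) in enumerate(pairs):
        freq = 1.0 / (100.0 ** (index / num_freqs))  # illustrative schedule
        angle = -position * freq if inverse else position * freq
        cos_a, sin_a = math.cos(angle), math.sin(angle)
        rotated.append((a * cos_a - b * sin_a, a * sin_a + b * cos_a))
    return rotated


x = [(1.0, 0.0), (0.5, -2.0)]
forward = rope_rotate(x, position=3.0)
restored = rope_rotate(forward, position=3.0, inverse=True)
```

Here ``restored`` matches ``x`` up to floating-point rounding, while ``forward`` differs from ``x`` because the position is nonzero.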

PRoPE module is exported from artifixer.network for convenience.
Static check passes (12 flashdreams imports resolve, py_compile clean).
…tests

Two fixes:

1. ``artifixer/network/__init__.py`` was eagerly importing every
   submodule, which transitively pulled in
   ``flashdreams.core.attention.ring`` and the rest of the flashdreams
   tree just to use PRoPE, so the PRoPE unit test could not run without
   the full flashdreams env in place. Made the __init__
   empty (mirroring dreamfix's ``model_training/net/__init__.py``) so
   ``from artifixer.network.prope import ...`` only needs torch. The
   recipe and tests now import submodules explicitly.

2. ``tests/test_prope_parity.py``:
   - Default fixture dtype switched from fp64 -> fp32. PRoPE's
     ``_rope_precompute_coeffs`` does an implicit int64 -> fp32
     promotion (``torch.arange(...) / num_freqs``) so the RoPE
     coefficients are always at fp32 precision regardless of input
     dtype. dreamfix's ``_lift_K`` also hard-codes float32 output
     (no ``dtype=``), so any cross-implementation comparison must
     use fp32 inputs to avoid an einsum dtype crash on the reference.
   - Tolerances relaxed to fp32 (atol/rtol=1e-6 for parity, 5e-6 for
     the round-trip identity check). Our port already carries the
     ``dtype=Ks.dtype`` fix in ``_lift_K``, so it is slightly more
     dtype-robust than upstream at fp64, but bit-identical at fp32.

Verified: 3/3 tests pass in the artifixer-cuda12 container with
DREAMFIX_REPO_ROOT pointing at the dreamfix checkout:
  - test_prope_internal_consistency           PASSED
  - test_prope_apply_to_q_changes_input_...   PASSED
  - test_prope_matches_dreamfix_reference     PASSED
… + block + network

Phase 2.4: complete the architecture. ArtifixerCrossAttention,
ArtifixerBlock, and ArtifixerDiTNetwork now plumb the per-block opacity
+ camera-ray MLPs, the neighbor-frame KV bank, and PRoPE q/k/v/o
transforms end-to-end. With every extra ``None`` / ``False`` the path
is a no-op extension of vanilla Wan, matching Phase 2.1 / 2.2 behavior.

artifixer/network/cross_attn.py:
  - Added ``initialize_neighbor_cache(context)`` which calls
    ``compute_kv_neighbor`` and stores a per-module
    ``neighbor_kv_cache: BlockKVCache | None``. ``context=None`` clears
    it so subsequent forward passes skip the neighbor branch.
  - Added ``forward(x, kv_cache, *, prope_src=None, prope_tgt=None,
    ignore_neighbors=False)``. When the neighbor cache is set and PRoPE
    modules are supplied, adds the PRoPE-modulated neighbor branch on
    top of the text (+ optional I2V image) branch (matches dreamfix
    transformer.py L833-L940). PRoPE math runs at fp32 with a
    ``.float() / .to(query.dtype)`` round trip.
  - ``compute_kv_neighbor`` now returns a ``BlockKVCache`` (was a
    ``(k, v)`` tuple) -- consistent with the I2V ``compute_kv_image``
    pattern and ready to be reused across denoise steps.

artifixer/network/block.py: ``ArtifixerBlock.forward`` accepts
``prope_src`` / ``prope_tgt`` / ``ignore_neighbors`` kwargs and forwards
them to ``cross_attn``. The neighbor KV cache itself lives on the
cross_attn module rather than in ``block_extra_kwargs`` so that one
``block_extra_kwargs`` dict can be passed identically to every block
by ``WanDiTNetwork.forward`` (while each block still has its own cache).

artifixer/network/dit.py: ``ArtifixerDiTNetwork.initialize_neighbor_kv_caches(context)``
walks every block and propagates the (optional) neighbor latent context
to each ``cross_attn``. Called once per rollout by the pipeline (Phase 3).

tests/test_smoke.py: heavy imports (mediapy, full flashdreams) deferred
into the test functions that need them; lighter tests
(``test_compute_kv_neighbor_and_cache_init``, the PRoPE parity ones)
now collect in torch-only environments without pulling boto3 / mediapy.
Added ``test_compute_kv_neighbor_and_cache_init`` covering the cache
toggle.

Verified inside the artifixer container with dreamfix on PYTHONPATH:
  test_prope_internal_consistency               PASSED
  test_prope_apply_to_q_changes_input_for_...   PASSED
  test_prope_matches_dreamfix_reference         PASSED  (fp32 atol 1e-6)
  test_compute_kv_neighbor_and_cache_init       PASSED

scripts/run_prope_parity.sh also installs boto3 + tyro now and includes
the new cache test in the run.
…modules

Phase 3.1: add a Wan21Transformer subclass that owns the two
PropeDotProductAttention modules (target/source camera + neighbor
camera) used by ArtifixerCrossAttention. The modules are constructed
once at transformer-init time at the correct ``head_dim`` and dtype;
their per-rollout ``_precompute_and_cache_apply_fns(...)`` calls will
be driven by the pipeline (Phase 3.2 / 3.3).

Wiring: ``artifixer.config.PIPELINE_ARTIFIXER_DMD_T2V_1PT3B`` now uses
``ArtifixerWanTransformerConfig`` in place of vanilla
``Wan21TransformerConfig``. No behavioral change yet (the per-rollout
state plumbing lands next); the PRoPE modules sit dormant.

Static check + linter clean.
Phase 3.2: port the patchification logic from dreamfix
``model_training/net/transformer.py`` L309-L341 as two stand-alone
functions:

  * ``patchify_opacity`` turns ``(B, T, H, W)`` per-pixel alpha into
    ``(B, L, opacity_embedding_dim)`` per-token features. On the first AR
    chunk (``frame_offset == 0``) the first input frame is left-padded by
    3 copies so the rearrange treats every latent frame uniformly under
    Wan VAE's ``1 + 4`` temporal layout.
  * ``patchify_camera_rays`` turns ``(B, T, H, W, 6)`` Plucker rays into
    ``(B, L, camera_embedding_dim)``. Branches on whether the temporal
    axis is already at the post-patch latent rate
    (``hidden_post_patch_t`` == camera_rays.shape[1]) or at the input
    rate, matching dreamfix exactly. The input-rate branch carries an
    extra ``vae_t`` multiplier in the per-token feature and is unused at
    inference (``kv_cache_pipeline.py`` L213 slices at the latent rate);
    we port both for parity with the dreamfix forward.

tests/test_patches.py covers shape correctness on both branches plus a
constant-field invariance check. 9/9 tests pass (3 PRoPE + 1 cache
init + 5 patches) inside the artifixer container with dreamfix on
PYTHONPATH.

scripts/run_prope_parity.sh now also runs the patches tests.
Phase 3.3: define :class:`ArtifixerCtrl` as a per-AR-chunk conditioning
payload carrying patchified opacity / camera features and the
``ignore_neighbors`` flag, and override
:meth:`ArtifixerWanTransformer.predict_flow` to repackage those into
``network_extra_kwargs`` so they reach every transformer block via
``WanDiTNetwork.forward(... **block_extra_kwargs)``.

The override also forwards ``self.prope_cross_attn_src`` and
``self.prope_cross_attn_tgt`` so ``ArtifixerCrossAttention.forward``
gets the PRoPE modules uniformly. The pipeline (Phase 3.5) owns the
per-chunk slicing, patchification, and the per-chunk
``prope_cross_attn_src._precompute_and_cache_apply_fns(...)`` update;
this commit just plumbs the kwargs.

``ArtifixerCtrl`` sets ``_is_patchified=True`` so the
``DiffusionModel.generate`` patchify-on-input dispatch (base.py L163-164)
treats it as a no-op for our tensors.

When the input is not an ``ArtifixerCtrl`` (e.g. pre-Phase-3 T2V smoke)
the override falls through to the base ``predict_flow`` unchanged --
every conditioning kwarg in :class:`ArtifixerBlock` is optional.

Static check + linter clean. No behavioral test yet -- exercising
``predict_flow`` end-to-end requires a GPU + the pipeline state from
Phase 3.4 / 3.5.
Phase 3.4: port the ``ArtifixerPipelineBase.prepare_latents`` mix from
dreamfix (pipeline_base.py L98-L122) as a stand-alone, side-effect-free
helper ``opacity_weighted_latent_mix``.

The mix replaces the base ``DiffusionModel.generate`` initial-noise draw
with a per-AR-chunk weighted blend::

    latents = condition * opacity_lat + noise * (1 - opacity_lat)

where ``opacity_lat`` is the per-pixel alpha max-pooled from the input
resolution to the VAE latent grid. The pipeline (Phase 3.5) will pull
the chunk's slice of the full-rollout opacity, sample noise via the
scheduler's RNG (so CP-broadcast / determinism stay intact), and feed
both to this helper.

The first AR chunk left-pads the input opacity by 3 copies of the first
frame to absorb the Wan VAE's ``1 + 4`` temporal layout. Later chunks
already arrive at ``vae_t * t_lat`` input frames.
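
A toy version of the mix, reduced to the temporal axis, may help. In this sketch plain Python lists stand in for tensors, each frame carries a single scalar opacity, and the temporal stride of 4 is assumed from the ``1 + 4`` layout; the function names are hypothetical and the real helper also pools spatially:

```python
VAE_T = 4  # assumed Wan VAE temporal stride


def pool_opacity_to_latent_rate(opacity_frames, is_first_chunk):
    """Max-pool per-frame opacity down to one value per latent frame."""
    frames = list(opacity_frames)
    if is_first_chunk:
        # Left-pad with 3 copies of the first frame so the 1 + 4 layout
        # becomes a whole number of groups of 4.
        frames = [frames[0]] * (VAE_T - 1) + frames
    if len(frames) % VAE_T:
        raise ValueError("frame count must be a multiple of the temporal stride")
    return [max(frames[i:i + VAE_T]) for i in range(0, len(frames), VAE_T)]


def opacity_weighted_mix(condition, noise, opacity_lat):
    """latents = condition * opacity_lat + noise * (1 - opacity_lat)."""
    if not (len(condition) == len(noise) == len(opacity_lat)):
        raise ValueError("all inputs must have one entry per latent frame")
    return [c * a + n * (1.0 - a) for c, n, a in zip(condition, noise, opacity_lat)]


# First chunk with t_lat = 2, i.e. 1 + 4 = 5 input frames.
opacity_lat = pool_opacity_to_latent_rate([1.0, 0.0, 0.0, 0.0, 0.0], is_first_chunk=True)
latents = opacity_weighted_mix([2.0, 2.0], [-1.0, -1.0], opacity_lat)
```

Where the rendered reconstruction is fully opaque the latent is taken from the condition, and where it is fully transparent the latent is pure noise.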

tests/test_latent_mix.py: 5 CPU-only invariant tests (opacity=1 -> drop
noise, opacity=0 -> drop condition, t_lat=1 edge case, max-pool semantics,
non-first-chunk shape). All 14/14 tests now pass on the slurm node:
PRoPE (3) + patches (5) + latent_mix (5) + cache init (1).
Phase 3.5a: ``ArtifixerInferencePipeline`` extends ``WanInferencePipeline``
with the ArtiFixer conditioning surface. The cache adds full-rollout
fields for the per-AR-chunk slicing in Phase 3.5b
(``condition_latent``, ``opacity``, ``camera_rays``, ``w2cs``, ``Ks``,
optional ``neighbor_w2cs`` / ``neighbor_Ks``).

``initialize_cache`` takes pre-VAE-encoded latents from the caller
(the dreamfix-side Phase 4 driver owns VAE encoding) and:

  1. Runs ``WanInferencePipeline.initialize_cache(text, image=None,
     height, width)`` for text embeddings + base text K/V build.
  2. Projects neighbor latents through ``patch_embedding`` to per-token
     ``neighbor_context`` (mirrors dreamfix transformer.py L361-362) and
     pushes it into every block's
     ``ArtifixerCrossAttention.neighbor_kv_cache`` via the
     ``ArtifixerDiTNetwork.initialize_neighbor_kv_caches`` hook.
  3. Updates the per-grid ``patches_x`` / ``patches_y`` RoPE coefficients
     on both PRoPE modules.
  4. Precomputes the neighbor-side PRoPE ``apply_fns`` once per rollout
     (neighbor cameras are static across AR steps; the source-side
     ``apply_fns`` get updated per AR chunk in Phase 3.5b).

Wires ``ArtifixerInferencePipelineConfig`` into the shipped recipe in
``config.py``. Phase 3.5b adds ``generate``: per-AR-chunk PRoPE-src
precompute, opacity-weighted latent mix, manual denoise loop with
``prepare_latents`` renoise, and decode.

Static check + linter clean.
Phase 3.5b: ``ArtifixerInferencePipeline.generate`` runs the per-AR-chunk
denoise loop and decodes a video chunk. Bypasses
``DiffusionModel.generate`` so we can renoise each step toward a fresh
``opacity_weighted_latent_mix`` of the condition + new noise (matching
``ArtifixerKvCachePipeline.generate_samples_from_batch`` L211-L264 in
dreamfix), instead of ``FlowMatchScheduler.sample``'s default renoise
toward the predicted clean.
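
The renoise-toward-a-fresh-mix loop can be sketched on scalars. This is an illustration of the update rule only, not the pipeline code: ``denoise_chunk``, the oracle flow predictor, and the sigma values are all assumptions (the sigmas are placeholders, not the schedule produced by ``FlowMatchScheduler(shift=5)``):

```python
import random


def denoise_chunk(initial_latent, sigmas, predict_flow, fresh_mix):
    """Few-step denoise that renoises toward a fresh condition/noise mix.

    predict_flow(latent, sigma) returns the predicted flow; fresh_mix()
    returns a new opacity-weighted mix of the condition and new noise.
    """
    if not sigmas:
        raise ValueError("need at least one denoise step")
    latent = initial_latent
    clean = latent
    for step, sigma in enumerate(sigmas):
        flow = predict_flow(latent, sigma)
        clean = latent - sigma * flow
        if step + 1 < len(sigmas):
            sigma_next = sigmas[step + 1]
            latent = (1.0 - sigma_next) * clean + sigma_next * fresh_mix()
    return clean


# Toy setup: an oracle that knows the clean value x0 exactly.
x0 = 0.3
rng = random.Random(0)
result = denoise_chunk(
    initial_latent=rng.gauss(0.0, 1.0),
    sigmas=[1.0, 0.75, 0.5, 0.25],  # placeholder values for 4 steps
    predict_flow=lambda latent, sigma: (latent - x0) / sigma,
    fresh_mix=lambda: rng.gauss(0.0, 1.0),
)
```

With a perfect flow predictor every step recovers ``x0``, so the choice of renoise target only matters when the prediction is imperfect, which is where matching the reference behavior counts.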

Per-AR-step orchestration:

  1. ``_chunk_frame_ranges`` translates the AR index into latent-rate
     and input-rate slice boundaries, accounting for the Wan VAE's
     ``1 + 4`` temporal layout on the first chunk.
  2. Slice the chunk's ``condition_latent``, ``opacity``, ``w2cs``,
     ``Ks``, ``camera_rays`` (auto-detects latent vs input rate) out of
     the full-rollout cache.
  3. ``transformer.prope_cross_attn_src._precompute_and_cache_apply_fns
     (chunk_w2cs, chunk_Ks)`` -- the neighbor-side cameras are static
     and were already precomputed in ``initialize_cache``.
  4. ``_build_ctrl`` patchifies the chunk's opacity / camera_rays via
     :func:`patchify_opacity` / :func:`patchify_camera_rays` into an
     :class:`ArtifixerCtrl` payload threaded through ``predict_flow``.
  5. ``transformer_cache.start(ar_idx)``.
  6. Initial latent = ``opacity_weighted_latent_mix(condition, opacity,
     randn)`` (unpatchified) -> patchify.
  7. For each of the ``num_inference_steps`` denoise steps:
     - ``flow = transformer.predict_flow(latent, t, cache, input=ctrl)``
     - ``clean = latent - sigma * flow``
     - If not the last step: sample fresh noise, re-mix, patchify,
       ``latent = (1 - sigma_next) * clean + sigma_next * fresh_mix``.
  8. ``postprocess_clean_latent`` (no-op for non-I2V).
  9. Stash a ``DiffusionModel.FinalState`` on the cache so the
     inherited ``finalize`` path closes the AR cache.
  10. ``unpatchify_and_maybe_gather_cp`` + decode -> return.
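
The slice arithmetic in step 1 can be sketched as follows. This is an assumption about what ``_chunk_frame_ranges`` computes, inferred from the ``1 + 4`` layout described in these commits; the function below is a hypothetical stand-in and uses ``len_t=7`` and a temporal stride of 4:

```python
def chunk_frame_ranges(ar_idx, len_t=7, vae_t=4):
    """Return ((lat_start, lat_end), (in_start, in_end)) for one AR chunk.

    Latent frame 0 corresponds to a single input frame; every later
    latent frame corresponds to vae_t input frames.
    """
    if ar_idx < 0:
        raise ValueError("ar_idx must be non-negative")

    def first_input_frame(lat_idx):
        return 0 if lat_idx == 0 else 1 + vae_t * (lat_idx - 1)

    lat_start = ar_idx * len_t
    lat_end = lat_start + len_t
    return (lat_start, lat_end), (first_input_frame(lat_start), first_input_frame(lat_end))


first_chunk = chunk_frame_ranges(0)
second_chunk = chunk_frame_ranges(1)
```

Under these assumptions the first chunk spans 25 input frames (1 + 4 x 6) and every later chunk spans 28 (4 x 7).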

Static check + linter clean. End-to-end testability requires the
merged DMD safetensors (Phase 5) and the dreamfix-side driver that
prepares condition_latent + neighbor_latent (Phase 4).
Phase 5a: ``artifixer_dmd_state_dict_transform`` remaps the merged
checkpoint produced by ``dreamfix/scripts/merge_dcp_to_safetensors.py``
from HF diffusers naming to the flashdreams ``WanDiTNetwork`` /
``ArtifixerDiTNetwork`` naming.

The regex mapping is a superset of
``integrations/fastvideo_causal_wan22/.../config.CHECKPOINT_KEY_MAPPING``
(Wan 2.1 / 2.2 share that layout). Three ArtiFixer-specific extra
substitutions land 150 of the 270 ArtiFixer-only keys (5 per block x 30
blocks) on the right modules:

  blocks.X.attn2.add_k_proj.*   -> blocks.X.cross_attn.add_k_proj.*
  blocks.X.attn2.add_v_proj.*   -> blocks.X.cross_attn.add_v_proj.*
  blocks.X.attn2.norm_added_k.* -> blocks.X.cross_attn.norm_added_k.*

The other 120 ArtiFixer-only keys (4 per block; ``blocks.X.opacity_embedding.*``,
``blocks.X.camera_embedding.*``) have no ``attn2`` prefix and pass
through unchanged -- they already match the
:class:`ArtifixerBlock` attribute names registered at __init__ time.
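
A minimal sketch of the three ArtiFixer-specific substitutions, using Python's ``re`` module. It covers only the rules listed above (the base diffusers-to-``WanDiTNetwork`` rules are not reproduced), and ``remap_key`` / ``transform_state_dict`` are hypothetical names rather than the shipped helper:

```python
import re

# Only the ArtiFixer-specific rules from the commit message.
ARTIFIXER_RULES = [
    (r"^blocks\.(\d+)\.attn2\.add_k_proj\.(.+)$", r"blocks.\1.cross_attn.add_k_proj.\2"),
    (r"^blocks\.(\d+)\.attn2\.add_v_proj\.(.+)$", r"blocks.\1.cross_attn.add_v_proj.\2"),
    (r"^blocks\.(\d+)\.attn2\.norm_added_k\.(.+)$", r"blocks.\1.cross_attn.norm_added_k.\2"),
]


def remap_key(key):
    """Apply the first matching rule; unmatched keys pass through."""
    for pattern, replacement in ARTIFIXER_RULES:
        new_key, count = re.subn(pattern, replacement, key)
        if count:
            return new_key
    return key


def transform_state_dict(state_dict):
    """Rename keys and refuse to silently merge two source keys."""
    renamed = {}
    for key, value in state_dict.items():
        new_key = remap_key(key)
        if new_key in renamed:
            raise ValueError(f"key collision after remap: {new_key}")
        renamed[new_key] = value
    return renamed


example = transform_state_dict({
    "blocks.3.attn2.add_k_proj.weight": "k",
    "blocks.3.attn2.norm_added_k.weight": "n",
    "blocks.3.opacity_embedding.weight": "o",
})
```

The collision check mirrors the "1095 keys in, 1095 unique keys out" assertion described in the tests.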

tests/test_state_dict_transform.py: 6 CPU-only key-naming tests that
load the ``param_audit.json`` produced by Phase 0 and assert:

  - Diffusers-named keys are renamed correctly (spot check).
  - ``opacity_embedding`` / ``camera_embedding`` pass through unchanged.
  - The full 1095-key audit transforms to 1095 unique keys (no collisions).
  - Every block (0..29) has the expected 21 attributes
    (self_attn / cross_attn / norms / ffn / modulation /
    add_k_proj / add_v_proj / norm_added_k / opacity_embedding /
    camera_embedding).
  - Network-level globals (``head.*``, embeddings) all present.
  - All regex backreferences are well-formed.

20/20 tests pass (3 PRoPE + 5 patches + 5 latent_mix + 6 state_dict +
1 cache_init).

This commit only adds the transform helper; switching the recipe's
checkpoint_path + state_dict_transform to use it (and dropping
``zero_pad_artifixer_keys``) is Phase 5b once the merged safetensors
are at a portable path.
Phase 5b: ``PIPELINE_ARTIFIXER_DMD_T2V_1PT3B`` now defaults to the
merged ArtiFixer DMD safetensors produced by
``dreamfix/scripts/merge_dcp_to_safetensors.py``, paired with
``artifixer_dmd_state_dict_transform`` (added in Phase 5a) for the
diffusers -> ``WanDiTNetwork`` regex remap and the ArtiFixer-only
``attn2.add_k_proj`` / ``attn2.add_v_proj`` / ``attn2.norm_added_k`` ->
``cross_attn.*`` substitutions.

Two env vars cover the deployment surface:

  * ``ARTIFIXER_DMD_CHECKPOINT_PATH`` -- path to the merged safetensors;
    defaults to the dreamfix repo's ``merged_checkpoints/`` layout
    (``/lustre/fsw/.../artifixer_dmd_1p3b_s3_2000/model.safetensors``).
  * ``ARTIFIXER_USE_BASE_WAN_WEIGHTS=1`` -- fall back to vanilla
    Wan 2.1 1.3B HuggingFace weights + ``zero_pad_artifixer_keys``;
    useful for wiring smoke-tests when the merged safetensors are
    unavailable.

The 20 existing CPU-only tests still pass; live load of the merged
safetensors into the network is exercised by Phase 4's end-to-end run
(driver in dreamfix).
… prompts

Phase 4a (flashdreams side): ``ArtifixerInferencePipeline.initialize_cache``
now takes either ``text: list[str]`` (raw prompts, encoded internally by
UMT5) or ``text_embeddings: Tensor`` (pre-encoded UMT5 output). The two
are mutually exclusive.
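
The "exactly one of the two" contract can be enforced with a short guard. This is a sketch under the assumption that both arguments default to ``None``; ``resolve_prompt_source`` and its return labels are hypothetical and only illustrate the check:

```python
def resolve_prompt_source(text=None, text_embeddings=None):
    """Decide which prompt path to take, rejecting ambiguous calls."""
    if (text is None) == (text_embeddings is None):
        raise ValueError("pass exactly one of `text` or `text_embeddings`")
    if text is not None:
        return "encode_with_umt5"  # raw prompts, encoded internally
    return "use_precomputed"  # pre-encoded UMT5 output, no second forward


raw_path = resolve_prompt_source(text=["a street scene"])
precomputed_path = resolve_prompt_source(text_embeddings=[[0.0, 1.0]])
```

Comparing the two ``is None`` results catches both the "neither given" and the "both given" case in one condition.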

The dreamfix-side driver (Phase 4b) feeds the pre-encoded prompt tensor
that dreamfix's data pipeline produces (each eval item carries an
``encoded_prompt`` field), avoiding a second UMT5 forward when we
already have the embeddings. The ``text_embeddings`` path bypasses
the UMT5-only call site in
:meth:`WanInferencePipeline.initialize_cache` and slots the embeddings
directly into the transformer cache via
``StreamInferencePipeline.initialize_cache(transformer_context=...)``.

CFG is intentionally disabled in this branch (no
``negative_text_embeddings`` even when ``guidance_scale > 1``) to
match the dreamfix ``ArtifixerKvCachePipeline`` contract -- the
KV-cache pipeline ignores ``negative_prompt`` (see
``kv_cache_pipeline.py`` L104-L109).

Static check + linter clean.
…npacked

Live-run fix: ``WanDiTNetwork.forward`` accepts a SINGLE
``block_extra_kwargs`` parameter (a dict), not arbitrary kwargs. The base
``Wan21Transformer._predict_flow`` does ``self._select_network(...)
(**network_extra_kwargs)`` which unpacks ``network_extra_kwargs`` into
the network forward call. Our override was flattening the extras into
``network_extra_kwargs`` directly:

    network_extra_kwargs = {
        "opacity_extra": ...,
        "camera_extra": ...,
        ...
    }

…which after unpack became
``network(opacity_extra=..., camera_extra=...)`` and crashed with
``TypeError: WanDiTNetwork.forward() got an unexpected keyword
argument 'opacity_extra'``.

Wrap them in a single ``block_extra_kwargs`` dict so the unpack lands
on ``network(block_extra_kwargs={"opacity_extra": ..., ...})``, which
``WanDiTNetwork.forward`` then forwards to every block as
``block(... **block_extra_kwargs)`` (network.py L438-449). The
:class:`ArtifixerBlock.forward` signature accepts each extra as a
keyword arg.
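
The failure and the fix can be reproduced with plain Python stand-ins. ``ToyNetwork``, ``toy_block``, and ``call_network`` below are hypothetical and only mimic the calling convention described above:

```python
def toy_block(x, *, opacity_extra=None, camera_extra=None):
    """Stand-in for a block whose extras are optional keyword args."""
    return x + (opacity_extra or 0.0) + (camera_extra or 0.0)


class ToyNetwork:
    """Stand-in for a network that accepts a single dict of block extras."""

    def __init__(self, blocks):
        self.blocks = list(blocks)

    def forward(self, x, block_extra_kwargs=None):
        extras = block_extra_kwargs or {}
        for block in self.blocks:
            x = block(x, **extras)  # same dict forwarded to every block
        return x


def call_network(network, x, network_extra_kwargs):
    # Mirrors the base behavior: the extras dict is unpacked into forward().
    return network.forward(x, **network_extra_kwargs)


extras = {"opacity_extra": 1.0, "camera_extra": 2.0}
network = ToyNetwork([toy_block, toy_block])

try:
    call_network(network, 0.0, extras)  # flattened: the buggy shape
    flattened_error = None
except TypeError as exc:
    flattened_error = str(exc)

# Wrapped: the unpack now lands on the one parameter forward() accepts.
wrapped_result = call_network(network, 0.0, {"block_extra_kwargs": extras})
```

The flattened call fails before any block runs, while the wrapped call reaches both blocks with the same extras.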

Bug surfaced on the first end-to-end run with
``--inference_backend flashdreams`` on a DL3DV scene; before the
crash, the FlashDreams pipeline ``.setup()`` loaded all 1095 keys of
the merged DMD safetensors cleanly (validating Phase 5's
state_dict_transform on the real network).
…C,H,W) before patchify

Second live-run bug fix.

dreamfix's ``encode_video_frames`` returns latents in diffusers
convention ``(B, in_dim, T_lat, Hl, Wl)`` (the same ordering ``cache.
condition_latent`` and ``opacity_weighted_latent_mix`` use), but
FlashDreams' ``WanDiTNetwork.patchify_and_maybe_split_cp`` uses the
einops pattern:

    "... (t kt) c (h kh) (w kw) -> ... (t h w) (c kt kh kw)"

i.e. it expects ``(B, T, C, H, W)`` with C *after* T. Feeding in the
diffusers-shaped latent put C=16 into the t-axis position and
T_chunk=7 into the channel position, so
the per-token feature dim came out as
``T_chunk * kt * kh * kw = 7 * 1 * 2 * 2 = 28`` instead of
``C * kt * kh * kw = 16 * 1 * 2 * 2 = 64``, and the post-patchify
linear (which fuses the Conv3d patch_embedding) failed with::

    a and b must have same reduction dim, but got [.., 28] X [64, 1536]

Permute ``(B, C, T, H, W) -> (B, T, C, H, W)`` for both the initial
mix and the per-step ``fresh_mix`` before calling
``patchify_and_maybe_split_cp``. The unpatchify side already returns
``(B, T, C, H, W)`` (the FlashDreams convention) which is what the
``WanVAEDecoder`` expects, so no symmetric permute is needed at the
decode call site.
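
The 28-versus-64 mismatch can be checked with shape arithmetic alone. The sketch below works on shape tuples rather than tensors; the helper names are hypothetical and the spatial sizes are placeholders:

```python
PATCH = (1, 2, 2)  # (kt, kh, kw)


def bcthw_to_btchw(shape):
    """Swap the channel and time axes of a 5-D shape tuple."""
    b, c, t, h, w = shape
    return (b, t, c, h, w)


def token_feature_dim(shape_btchw, patch=PATCH):
    """Feature width produced by
    "... (t kt) c (h kh) (w kw) -> ... (t h w) (c kt kh kw)".

    The pattern reads the channel count from axis 2, whatever is there.
    """
    channels = shape_btchw[2]
    kt, kh, kw = patch
    return channels * kt * kh * kw


diffusers_shape = (1, 16, 7, 60, 104)  # (B, C, T_lat, Hl, Wl); Hl, Wl are placeholders
wrong_dim = token_feature_dim(diffusers_shape)
right_dim = token_feature_dim(bcthw_to_btchw(diffusers_shape))
```

On real tensors the equivalent of ``bcthw_to_btchw`` would be a permute that swaps axes 1 and 2, for example ``latent.permute(0, 2, 1, 3, 4)`` in PyTorch.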

Static check + linter clean. Bug surfaced on the second
``--inference_backend flashdreams`` end-to-end run (the previous
predict_flow kwarg fix unblocked us to reach the patchify).
…peline

The base StreamInferencePipeline.generate installs cache.event_profiler
when this flag is True, but our custom ArtifixerInferencePipeline.generate
bypasses that path -- so finalize() asserts on a None profiler. Turn the
flag off until our generate() grows the matching EventProfiler setup; the
inline comment documents the contract so the next person knows what to
wire up before flipping it back on.

Unblocks the live DL3DV rollout end-to-end through the FlashDreams backend.
The dreamfix-native ``ArtifixerKvCachePipeline.generate_samples_from_batch``
advances its KV cache *in-place* during the regular denoise forwards
and never runs an extra forward at AR-chunk boundaries. The FlashDreams
default ``Transformer.finalize_kv_cache`` runs one more ``predict_flow``
at the context-noise timestep to advance the cache (the extra forward
discarded by ``_ = ...``). For the artifixer recipe this extra forward
writes a *different* KV state than dreamfix's last in-loop forward,
which empirically opens a ~7-9 dB cross-backend PSNR gap at the start
of every AR chunk past chunk 0:

  ``scripts/parity_harness.py`` ``transformer_call_*`` diff,
  before this change:
    chunk 0 step 0: 43.55 dB
    ...
    chunk 1 step 0: 7.53 dB  <- cliff
    chunk 1 step 1: 18.15 dB
    chunk 2 step 0: 7.57 dB  <- cliff
  after this change:
    chunk 0 step 0: 43.55 dB (unchanged, no AR boundary yet)
    chunk 1 step 0: 37.37 dB
    chunk 1 step 1: 37.20 dB
    chunk 2 step 0: 37.44 dB

End-to-end ``final_video`` PSNR jumps from 30.49 dB -> 47.72 dB on the
parity-harness item. ``cache.finalize(autoregressive_index)`` (the
bookkeeping that increments AR index) is still invoked by
``DiffusionModel.finalize`` after this no-op, so the rollout state
advances correctly; only the redundant predict is suppressed.

Mirrors the pattern in
``flashdreams/recipes/alpadreams/transformer/__init__.py::finalize_kv_cache``
which conditionally skips the same call.
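
The shape of the override can be sketched with toy classes. Everything below is a hypothetical stand-in (the class names, ``finalize_chunk``, and the dict-based cache); it only illustrates that suppressing the extra forward leaves the AR-index bookkeeping intact:

```python
class ToyBaseTransformer:
    """Stand-in whose finalize step runs one extra forward pass."""

    def __init__(self):
        self.predict_calls = 0

    def predict_flow(self, latent):
        self.predict_calls += 1
        return latent

    def finalize_kv_cache(self, latent):
        _ = self.predict_flow(latent)  # extra forward, result discarded


class ToyNoExtraForwardTransformer(ToyBaseTransformer):
    """The KV cache was already advanced in-loop, so finalize is a no-op."""

    def finalize_kv_cache(self, latent):
        return None


def finalize_chunk(transformer, cache, latent):
    transformer.finalize_kv_cache(latent)
    cache["autoregressive_index"] += 1  # bookkeeping still runs


base, patched = ToyBaseTransformer(), ToyNoExtraForwardTransformer()
base_cache = {"autoregressive_index": 0}
patched_cache = {"autoregressive_index": 0}
finalize_chunk(base, base_cache, latent=0.0)
finalize_chunk(patched, patched_cache, latent=0.0)
```

Both caches advance to the next AR index; only the base class pays for (and is affected by) the additional forward.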
…eamfix)

The dreamfix ``ArtifixerTransformerBlock.forward``
(``model_training/net/transformer.py`` L725-816) runs every per-block
AdaLN modulation, RMSNorm, and residual add in fp32, then casts the
result back to the input dtype. The base FlashDreams ``Block.forward``
keeps everything in ``x.dtype`` (typically bf16). With 30 blocks each
contributing ~1 dB of bf16 noise, the cross-backend PSNR drifts from
51 dB at block 0 down to 28 dB at block 29 -- a layer-by-layer
accumulation visible in ``scripts/parity_harness.py --capture_blocks``.

This change mirrors the six dreamfix promotions:
  - modulation chunking: ``(modulation.float() + e.float()).chunk(6)``
  - self-attn pre-norm + AdaLN
  - self-attn residual (x.float() + y * gate)
  - cross-attn pre-norm
  - FFN pre-norm + AdaLN
  - FFN residual (x.float() + ff.float() * gate)

``norm1`` / ``norm2`` are ``elementwise_affine=False`` so calling them
with an fp32 input is a free upcast (no weight/bias dtype check).
``norm3`` is ``elementwise_affine=True`` with bf16 weight/bias after
``.to(bf16)``; passing fp32 input through ``nn.LayerNorm`` raises
``expected scalar type Float but found BFloat16``, so we route through
a ``_layer_norm_fp32`` helper that promotes weight + bias to fp32 too.
This matches diffusers' ``FP32LayerNorm`` which the dreamfix
WanTransformerBlock norms use under the hood.

End-to-end ``final_video`` cross-backend PSNR jumps 47.72 dB -> 51.34 dB
on the parity-harness item. Block-by-block (call 0): block 0 went from
51.08 dB -> 55.35 dB; L1_max outliers in the middle of the stack
(blocks 17, 21) went from 60+ -> 28-45. The remaining residual is
attention-impl op-ordering noise inside the bf16 SDPA / FFN paths,
which is irreducible without forcing the same attention backend on
both implementations.

Performance cost: a handful of extra casts per block. The
fp32-promoted norms / residuals run on small tensors compared to the
attention QKV / FFN GEMMs, so wall-clock impact is negligible.
…, README)

Three small cleanups in preparation for opening the upstream MR:

1. ``ARTIFIXER_DMD_CHECKPOINT_PATH`` is now required (no committed
   ``/lustre/.../rdelutio/...`` default). Either set the env var or set
   ``ARTIFIXER_USE_BASE_WAN_WEIGHTS=1`` -- otherwise ``config.py``
   raises a clear ``RuntimeError`` at import time pointing the user to
   ``dreamfix/scripts/merge_dcp_to_safetensors.py``. The same env-var
   pattern is applied to ``tests/test_state_dict_transform.py``'s
   ``ARTIFIXER_PARAM_AUDIT_PATH``; the existing ``pytest.skip`` path
   keeps CI green without the audit JSON.

2. Drop ``Phase X.Y (this commit)`` narrative wording from docstrings
   and comments across ``block.py`` / ``transformer.py`` / ``pipeline.py``
   / ``config.py`` / ``runner.py`` / ``cross_attn.py`` / ``latent_mix.py``
   / ``checkpoint.py`` / ``tests/test_smoke.py``. The Phase numbering
   was commit-narrative bleed-through; the text now reads as plain
   past-tense prose without dating the code to its development order.

3. Rework ``integrations/artifixer/README.md``: drop the
   ``| Phase | Scope | Status |`` table (Phase 4 was still ``upcoming``
   despite the dreamfix-side adapter closing it), replace with a
   plain ``| Component | Description |`` table that describes the
   five recipe pieces. Add a cross-backend parity dB note so reviewers
   see the validation summary up front.
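
The env-var contract from item 1 can be sketched as a small resolver. This is an illustration, not the code in ``config.py``: ``resolve_checkpoint_source``, its return labels, and the precedence of ``ARTIFIXER_USE_BASE_WAN_WEIGHTS`` over the checkpoint path are assumptions:

```python
import os


def resolve_checkpoint_source(env=None):
    """Pick the weight source from the two documented env vars."""
    env = os.environ if env is None else env
    if env.get("ARTIFIXER_USE_BASE_WAN_WEIGHTS") == "1":
        # Vanilla Wan weights plus zero-padding of the ArtiFixer-only keys.
        return ("base_wan", None)
    path = env.get("ARTIFIXER_DMD_CHECKPOINT_PATH")
    if not path:
        raise RuntimeError(
            "Set ARTIFIXER_DMD_CHECKPOINT_PATH to the merged safetensors "
            "or set ARTIFIXER_USE_BASE_WAN_WEIGHTS=1."
        )
    return ("merged_dmd", path)


fallback = resolve_checkpoint_source({"ARTIFIXER_USE_BASE_WAN_WEIGHTS": "1"})
merged = resolve_checkpoint_source(
    {"ARTIFIXER_DMD_CHECKPOINT_PATH": "/tmp/model.safetensors"}  # placeholder path
)
```

Passing the environment as a mapping keeps the resolver testable without mutating ``os.environ``.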

No behavior change. Phase references in pre-existing ``.pyc`` files
under ``__pycache__/`` will refresh on next collect.
…c refs

Public-friendly cleanup for the upstream MR:

  * "dreamfix" rewritten to "the ArtiFixer reference" / "the
    reference" throughout module docstrings, comments, and tests. The
    name "dreamfix" is an internal project handle that has no place in
    a public flashdreams plugin; the model name "ArtiFixer" is what
    readers actually need.
  * Dropped every "L###-L###" line-number citation pointing into the
    reference repo's source -- those rot the moment the reference
    moves a line, and they were diagnostic rather than substantive.
  * Removed two stale references to the deleted ``scripts/parity_harness.py``
    (in block.py and transformer.py).
  * Renamed the test-side env var ``DREAMFIX_REPO_ROOT`` ->
    ``ARTIFIXER_REFERENCE_REPO_ROOT`` and the test function
    ``test_prope_matches_dreamfix_reference`` ->
    ``test_prope_matches_reference``. The companion
    ``scripts/run_prope_parity.sh`` is updated to match.
  * README.md drops the internal gitlab-master URL and reframes the
    "components" table without referring to the reference repo by name.

No behavior change. Static check (``scripts/static_check_artifixer.sh``)
still passes.

copy-pr-bot Bot commented May 13, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.
