Draft: Artifixer integration#71
Draft
riccardodelutio wants to merge 22 commits into
Draft
Conversation
Mirrors the integrations/self_forcing layout: pyproject.toml workspace
member registering a single ``artifixer-dmd-wan2.1-t2v-1.3b`` entry point,
artifixer/config.py with the AR / scheduler knobs that match the dreamfix
stage-3 DMD training config (frames_per_block=7 -> len_t, local_attn_size=21
-> window_size_t, sink_size=7 -> sink_size_t, num_inference_steps=4,
timestep_shift=5 -> FlowMatchSchedulerConfig.shift), and a stripped clone of
SelfForcingT2VRunner that takes a single text prompt.
Phase 1 only: this loads vanilla Wan 2.1 1.3B base weights from HuggingFace
so the recipe end-to-end is exercisable but the output is plain T2V. Later
commits add the ArtiFixer architecture extensions (opacity/camera MLPs,
neighbor cross-attention with PRoPE, opacity-weighted latent mixing) and
the state_dict_transform that loads the merged DMD safetensors produced by
dreamfix/scripts/merge_dcp_to_safetensors.py.
Also adds three helper scripts under scripts/:
- import_flashdreams_sqsh.sh: one-shot pull of the FlashDreams container
from ghcr.io into a persistent sqsh under /lustre so later jobs avoid
re-pulling and avoid the pyxis tmp-extraction failures.
- smoke_test_artifixer.sh: slurm wrapper that runs the plugin smoke
tests via ``uv sync`` + pytest inside the container.
- static_check_artifixer.sh: sub-second static-only sanity check
(py_compile + entry-point sanity + flashdreams import resolution)
without needing torch/uv/GPU. Designed for fast Phase 1 iteration.
Workspace pyproject.toml: add integrations/artifixer to the pyright /
ty extra-paths lists.
Verified: scripts/static_check_artifixer.sh passes (py_compile clean, 11
flashdreams imports all resolve to classes in flashdreams/, entry-point
maps to RUNNER_ARTIFIXER_DMD_T2V_1PT3B in artifixer/config.py).
Phase 2.1 of the dreamfix port: each transformer block now carries the
per-block opacity and Plucker-camera-ray projection heads from
``ArtifixerTransformerBlock`` (dreamfix/model_training/net/transformer.py
L617-L767). The MLP outputs are added to the AdaLN-normed hidden states
before self-attention. Both heads are zero-initialized so the block
behaves identically to the underlying ``Block`` when ``opacity_extra`` /
``camera_extra`` are not passed (Phase 1 text-only path).
New modules under integrations/artifixer/artifixer/:
- network/block.py ArtifixerBlock(Block) with opacity/camera MLPs
- network/dit.py ArtifixerDiTNetwork(WanDiTNetwork) +
ArtifixerDiTNetwork1pt3BConfig, plus
artifixer_embedding_dims() helper deriving
opacity_dim=1024 and camera_dim=1536 from the
Wan VAE strides + patch_size=(1,2,2)
- checkpoint.py zero_pad_artifixer_keys() state_dict transform
that zero-fills the 4 new keys per block when
loading a vanilla Wan checkpoint so strict-mode
load_state_dict still succeeds
config.py wiring:
- Replaces WanDiTNetwork1pt3BConfig with ArtifixerDiTNetwork1pt3BConfig
- Sets state_dict_transform=zero_pad_artifixer_keys(...) for the
BASE_WAN_T2V_1PT3B_CHECKPOINT_PATH source
Smoke tests extended:
- ArtifixerBlock instantiates with correct linear shapes and
zero-init invariants
- Recipe wires up ArtifixerDiTNetwork1pt3BConfig (not vanilla)
- zero_pad_artifixer_keys() fills missing keys with bf16 zeros
Static-check (scripts/static_check_artifixer.sh) now scans every .py file
under integrations/artifixer/artifixer/ and all 9 flashdreams imports
across the recipe resolve cleanly.
Phase 2.2: ArtifixerCrossAttention(CrossAttention) adds the parameters
that drive the neighbor-frame KV bank in dreamfix
(model_training/net/transformer.py L670-L699). Names are kept identical
to the dreamfix layout so the merged ArtiFixer DMD safetensors only need
the diffusers ``attn2 -> cross_attn`` regex remap (added in Phase 5) at
Phase 5, not a per-key rename:
cross_attn.add_k_proj.{weight,bias} Linear(inner_dim, inner_dim)
cross_attn.add_v_proj.{weight,bias} Linear(inner_dim, inner_dim) (zero-init)
cross_attn.norm_added_k.weight RMSNorm(inner_dim)
Plus a separate ``attn_op_neighbor`` RingAttention op so the neighbor
branch can later carry PRoPE-transformed q/k without entangling the
text branch.
ArtifixerBlock now swaps the inherited ``self.cross_attn`` for the
ArtifixerCrossAttention variant after super().__init__(). The forward
path is unchanged in this commit (inherited CrossAttention.forward
ignores the neighbor branch when no neighbor context is supplied), so
behavior is identical to Phase 2.1 until the pipeline-level neighbor
wiring lands in Phase 3 and PRoPE in Phase 2.3.
zero_pad_artifixer_keys extended to fill all 9 ArtiFixer-only keys per
block (4 from Phase 2.1 + 5 from Phase 2.2), matching the 270 extras
identified by dreamfix/scripts/dump_artifixer_param_names.py.
Smoke tests now also assert the cross_attn neighbor projections have
the correct shape and add_v_proj is zero-initialized.
Phase 2.3: verbatim port of dreamfix
``model_training/net/prope.py`` (paper "Cameras as Relative Positional
Encoding", arXiv 2507.10496). The module is self-contained — the only
dependency, ``model_training/utils/pose_utils.invert_SE3``, is inlined as
the private helper ``_invert_SE3`` so there are no cross-repo runtime
imports.
PRoPE applies block-diagonal SE(3) projection-matrix + 2D RoPE transforms
to q/k/v/o on cross-attention. ``apply_to_o`` is the inverse of the
RoPE legs in ``apply_to_q`` / ``apply_to_kv``. Used twice per
ArtifixerBlock in dreamfix:
- prope_cross_attn_src on target camera (w2cs / Ks)
- prope_cross_attn_tgt on neighbor cameras (neighbor_w2cs / neighbor_Ks)
This commit ports the module; the wiring inside
ArtifixerCrossAttention.forward (calling ``_apply_to_q``,
``_apply_to_kv``, ``_apply_to_o`` around the RingAttention SDPA call)
lands in Phase 2.4 with the network forward changes.
Tests:
- tests/test_prope_parity.py:test_prope_internal_consistency exercises
the identity-camera invariant
``apply_to_o(apply_to_q(x)) == x`` exactly at fp64.
- tests/test_prope_parity.py:test_prope_apply_to_q_changes_input_for_nontrivial_cameras
sanity-checks that real cameras do produce a non-identity transform.
- tests/test_prope_parity.py:test_prope_matches_dreamfix_reference is
a numerical-parity test against dreamfix at fp64. Skipped when
DREAMFIX_REPO_ROOT is not set / dreamfix is not on PYTHONPATH. The
intent is to gate the architecture changes on bit-identical math
the moment a dual env is available.
PRoPE module is exported from artifixer.network for convenience.
Static check passes (12 flashdreams imports resolve, py_compile clean).
…tests
Two fixes:
1. ``artifixer/network/__init__.py`` was eagerly importing every
submodule, which transitively pulled in
``flashdreams.core.attention.ring`` and the rest of the flashdreams
tree just to use PRoPE. The PRoPE unit test cannot run without the
full flashdreams env in place because of this. Made the __init__
empty (mirroring dreamfix's ``model_training/net/__init__.py``) so
``from artifixer.network.prope import ...`` only needs torch. The
recipe and tests now import submodules explicitly.
2. ``tests/test_prope_parity.py``:
- Default fixture dtype switched from fp64 -> fp32. PRoPE's
``_rope_precompute_coeffs`` does an implicit int64 -> fp32
promotion (``torch.arange(...) / num_freqs``) so the RoPE
coefficients are always at fp32 precision regardless of input
dtype. dreamfix's ``_lift_K`` also hard-codes float32 output
(no ``dtype=``), so any cross-implementation comparison must
use fp32 inputs to avoid an einsum dtype crash on the reference.
- Tolerances relaxed to fp32 (atol/rtol=1e-6 for parity, 5e-6 for
the round-trip identity check). Our port already carries the
``dtype=Ks.dtype`` fix in ``_lift_K``, so it is slightly more
dtype-robust than upstream at fp64, but bit-identical at fp32.
Verified: 3/3 tests pass in the artifixer-cuda12 container with
DREAMFIX_REPO_ROOT pointing at the dreamfix checkout:
- test_prope_internal_consistency PASSED
- test_prope_apply_to_q_changes_input_... PASSED
- test_prope_matches_dreamfix_reference PASSED
… + block + network
Phase 2.4: complete the architecture. ArtifixerCrossAttention,
ArtifixerBlock, and ArtifixerDiTNetwork now plumb the per-block opacity
+ camera-ray MLPs, the neighbor-frame KV bank, and PRoPE q/k/v/o
transforms end-to-end. With every extra ``None`` / ``False`` the path
is a no-op extension of vanilla Wan, matching Phase 2.1 / 2.2 behavior.
artifixer/network/cross_attn.py:
- Added ``initialize_neighbor_cache(context)`` which calls
``compute_kv_neighbor`` and stores a per-module
``neighbor_kv_cache: BlockKVCache | None``. ``context=None`` clears
it so subsequent forward passes skip the neighbor branch.
- Added ``forward(x, kv_cache, *, prope_src=None, prope_tgt=None,
ignore_neighbors=False)``. When the neighbor cache is set and PRoPE
modules are supplied, adds the PRoPE-modulated neighbor branch on
top of the text (+ optional I2V image) branch (matches dreamfix
transformer.py L833-L940). PRoPE math runs at fp32 with a
``.float() / .to(query.dtype)`` round trip.
- ``compute_kv_neighbor`` now returns a ``BlockKVCache`` (was a
``(k, v)`` tuple) -- consistent with the I2V ``compute_kv_image``
pattern and ready to be reused across denoise steps.
artifixer/network/block.py: ``ArtifixerBlock.forward`` accepts
``prope_src`` / ``prope_tgt`` / ``ignore_neighbors`` kwargs and forwards
them to ``cross_attn``. The neighbor KV cache itself lives on the
cross_attn module rather than in ``block_extra_kwargs`` so that one
``block_extra_kwargs`` dict can be passed identically to every block
by ``WanDiTNetwork.forward`` (while each block still has its own cache).
artifixer/network/dit.py: ``ArtifixerDiTNetwork.initialize_neighbor_kv_caches(context)``
walks every block and propagates the (optional) neighbor latent context
to each ``cross_attn``. Called once per rollout by the pipeline (Phase 3).
tests/test_smoke.py: heavy imports (mediapy, full flashdreams) deferred
into the test functions that need them; lighter tests
(``test_compute_kv_neighbor_and_cache_init``, the PRoPE parity ones)
now collect in torch-only environments without pulling boto3 / mediapy.
Added ``test_compute_kv_neighbor_and_cache_init`` covering the cache
toggle.
Verified inside the artifixer container with dreamfix on PYTHONPATH:
test_prope_internal_consistency PASSED
test_prope_apply_to_q_changes_input_for_... PASSED
test_prope_matches_dreamfix_reference PASSED (fp32 atol 1e-6)
test_compute_kv_neighbor_and_cache_init PASSED
scripts/run_prope_parity.sh also installs boto3 + tyro now and includes
the new cache test in the run.
…modules Phase 3.1: add a Wan21Transformer subclass that owns the two PropeDotProductAttention modules (target/source camera + neighbor camera) used by ArtifixerCrossAttention. The modules are constructed once at transformer-init time at the correct ``head_dim`` and dtype; their per-rollout ``_precompute_and_cache_apply_fns(...)`` calls will be driven by the pipeline (Phase 3.2 / 3.3). Wiring: ``artifixer.config.PIPELINE_ARTIFIXER_DMD_T2V_1PT3B`` now uses ``ArtifixerWanTransformerConfig`` in place of vanilla ``Wan21TransformerConfig``. No behavioral change yet (the per-rollout state plumbing lands next); the PRoPE modules sit dormant. Static check + linter clean.
Phase 3.2: port the patchification logic from dreamfix
``model_training/net/transformer.py`` L309-L341 as two stand-alone
functions:
* ``patchify_opacity`` turns ``(B, T, H, W)`` per-pixel alpha into
``(B, L, opacity_embedding_dim)`` per-token features. On the first AR
chunk (``frame_offset == 0``) the first input frame is left-padded by
3 copies so the rearrange treats every latent frame uniformly under
Wan VAE's ``1 + 4`` temporal layout.
* ``patchify_camera_rays`` turns ``(B, T, H, W, 6)`` Plucker rays into
``(B, L, camera_embedding_dim)``. Branches on whether the temporal
axis is already at the post-patch latent rate
(``hidden_post_patch_t`` == camera_rays.shape[1]) or at the input
rate, matching dreamfix exactly. The input-rate branch carries an
extra ``vae_t`` multiplier in the per-token feature and is unused at
inference (``kv_cache_pipeline.py`` L213 slices at the latent rate);
we port both for parity with the dreamfix forward.
tests/test_patches.py covers shape correctness on both branches plus a
constant-field invariance check. 9/9 tests pass (4 PRoPE + 5 patches)
inside the artifixer container with dreamfix on PYTHONPATH.
scripts/run_prope_parity.sh now also runs the patches tests.
Phase 3.3: define :class:`ArtifixerCtrl` as a per-AR-chunk conditioning payload carrying patchified opacity / camera features and the ``ignore_neighbors`` flag, and override :meth:`ArtifixerWanTransformer.predict_flow` to repackage those into ``network_extra_kwargs`` so they reach every transformer block via ``WanDiTNetwork.forward(... **block_extra_kwargs)``. The override also forwards ``self.prope_cross_attn_src`` and ``self.prope_cross_attn_tgt`` so ``ArtifixerCrossAttention.forward`` gets the PRoPE modules uniformly. The pipeline (Phase 3.5) owns the per-chunk slicing, patchification, and the per-chunk ``prope_cross_attn_src._precompute_and_cache_apply_fns(...)`` update; this commit just plumbs the kwargs. ``ArtifixerCtrl`` sets ``_is_patchified=True`` so the ``DiffusionModel.generate`` patchify-on-input dispatch (base.py L163-164) treats it as a no-op for our tensors. When the input is not an ``ArtifixerCtrl`` (e.g. pre-Phase-3 T2V smoke) the override falls through to the base ``predict_flow`` unchanged -- every conditioning kwarg in :class:`ArtifixerBlock` is optional. Static check + linter clean. No behavioral test yet -- exercising ``predict_flow`` end-to-end requires a GPU + the pipeline state from Phase 3.4 / 3.5.
Phase 3.4: port the ``ArtifixerPipelineBase.prepare_latents`` mix from
dreamfix (pipeline_base.py L98-L122) as a stand-alone, side-effect-free
helper ``opacity_weighted_latent_mix``.
The mix replaces the base ``DiffusionModel.generate`` initial-noise draw
with a per-AR-chunk weighted blend::
latents = condition * opacity_lat + noise * (1 - opacity_lat)
where ``opacity_lat`` is the per-pixel alpha max-pooled from the input
resolution to the VAE latent grid. The pipeline (Phase 3.5) will pull
the chunk's slice of the full-rollout opacity, sample noise via the
scheduler's RNG (so CP-broadcast / determinism stay intact), and feed
both to this helper.
The first AR chunk left-pads the input opacity by 3 copies of the first
frame to absorb the Wan VAE's ``1 + 4`` temporal layout. Later chunks
already arrive at ``vae_t * t_lat`` input frames.
tests/test_latent_mix.py: 5 CPU-only invariant tests (opacity=1 -> drop
noise, opacity=0 -> drop condition, t_lat=1 edge case, max-pool semantics,
non-first-chunk shape). All 14/14 tests now pass on the slurm node:
PRoPE (3) + patches (5) + latent_mix (5) + cache init (1).
Phase 3.5a: ``ArtifixerInferencePipeline`` extends ``WanInferencePipeline``
with the ArtiFixer conditioning surface. The cache adds full-rollout
fields for the per-AR-chunk slicing in Phase 3.5b
(``condition_latent``, ``opacity``, ``camera_rays``, ``w2cs``, ``Ks``,
optional ``neighbor_w2cs`` / ``neighbor_Ks``).
``initialize_cache`` takes pre-VAE-encoded latents from the caller
(the dreamfix-side Phase 4 driver owns VAE encoding) and:
1. Runs ``WanInferencePipeline.initialize_cache(text, image=None,
height, width)`` for text embeddings + base text K/V build.
2. Projects neighbor latents through ``patch_embedding`` to per-token
``neighbor_context`` (mirrors dreamfix transformer.py L361-362) and
pushes it into every block's
``ArtifixerCrossAttention.neighbor_kv_cache`` via the
``ArtifixerDiTNetwork.initialize_neighbor_kv_caches`` hook.
3. Updates the per-grid ``patches_x`` / ``patches_y`` RoPE coefficients
on both PRoPE modules.
4. Precomputes the neighbor-side PRoPE ``apply_fns`` once per rollout
(neighbor cameras are static across AR steps; the source-side
``apply_fns`` get updated per AR chunk in Phase 3.5b).
Wires ``ArtifixerInferencePipelineConfig`` into the shipped recipe in
``config.py``. Phase 3.5b adds ``generate``: per-AR-chunk PRoPE-src
precompute, opacity-weighted latent mix, manual denoise loop with
``prepare_latents`` renoise, and decode.
Static check + linter clean.
Phase 3.5b: ``ArtifixerInferencePipeline.generate`` runs the per-AR-chunk
denoise loop and decodes a video chunk. Bypasses
``DiffusionModel.generate`` so we can renoise each step toward a fresh
``opacity_weighted_latent_mix`` of the condition + new noise (matching
``ArtifixerKvCachePipeline.generate_samples_from_batch`` L211-L264 in
dreamfix), instead of ``FlowMatchScheduler.sample``'s default renoise
toward the predicted clean.
Per-AR-step orchestration:
1. ``_chunk_frame_ranges`` translates the AR index into latent-rate
and input-rate slice boundaries, accounting for the Wan VAE's
``1 + 4`` temporal layout on the first chunk.
2. Slice the chunk's ``condition_latent``, ``opacity``, ``w2cs``,
``Ks``, ``camera_rays`` (auto-detects latent vs input rate) out of
the full-rollout cache.
3. ``transformer.prope_cross_attn_src._precompute_and_cache_apply_fns
(chunk_w2cs, chunk_Ks)`` -- the neighbor-side cameras are static
and were already precomputed in ``initialize_cache``.
4. ``_build_ctrl`` patchifies the chunk's opacity / camera_rays via
:func:`patchify_opacity` / :func:`patchify_camera_rays` into an
:class:`ArtifixerCtrl` payload threaded through ``predict_flow``.
5. ``transformer_cache.start(ar_idx)``.
6. Initial latent = ``opacity_weighted_latent_mix(condition, opacity,
randn)`` (unpatchified) -> patchify.
7. For each of the ``num_inference_steps`` denoise steps:
- ``flow = transformer.predict_flow(latent, t, cache, input=ctrl)``
- ``clean = latent - sigma * flow``
- If not the last step: sample fresh noise, re-mix, patchify,
``latent = (1 - sigma_next) * clean + sigma_next * fresh_mix``.
8. ``postprocess_clean_latent`` (no-op for non-I2V).
9. Stash a ``DiffusionModel.FinalState`` on the cache so the
inherited ``finalize`` path closes the AR cache.
10. ``unpatchify_and_maybe_gather_cp`` + decode -> return.
Static check + linter clean. End-to-end testability requires the
merged DMD safetensors (Phase 5) and the dreamfix-side driver that
prepares condition_latent + neighbor_latent (Phase 4).
Phase 5a: ``artifixer_dmd_state_dict_transform`` remaps the merged
checkpoint produced by ``dreamfix/scripts/merge_dcp_to_safetensors.py``
from HF diffusers naming to the flashdreams ``WanDiTNetwork`` /
``ArtifixerDiTNetwork`` naming.
The regex mapping is a superset of
``integrations/fastvideo_causal_wan22/.../config.CHECKPOINT_KEY_MAPPING``
(Wan 2.1 / 2.2 share that layout). Three ArtiFixer-specific extra
substitutions land the 270 ArtiFixer-only keys on the right modules:
blocks.X.attn2.add_k_proj.* -> blocks.X.cross_attn.add_k_proj.*
blocks.X.attn2.add_v_proj.* -> blocks.X.cross_attn.add_v_proj.*
blocks.X.attn2.norm_added_k.* -> blocks.X.cross_attn.norm_added_k.*
The other 60 ArtiFixer-only keys (``blocks.X.opacity_embedding.*``,
``blocks.X.camera_embedding.*``) have no ``attn2`` prefix and pass
through unchanged -- they already match the
:class:`ArtifixerBlock` attribute names registered at __init__ time.
tests/test_state_dict_transform.py: 6 CPU-only key-naming tests that
load the ``param_audit.json`` produced by Phase 0 and assert:
- Diffusers-named keys are renamed correctly (spot check).
- ``opacity_embedding`` / ``camera_embedding`` pass through unchanged.
- The full 1095-key audit transforms to 1095 unique keys (no collisions).
- Every block (0..29) has the expected 21 attributes
(self_attn / cross_attn / norms / ffn / modulation /
add_k_proj / add_v_proj / norm_added_k / opacity_embedding /
camera_embedding).
- Network-level globals (``head.*``, embeddings) all present.
- All regex backreferences are well-formed.
20/20 tests pass (3 PRoPE + 5 patches + 5 latent_mix + 6 state_dict +
1 cache_init).
This commit only adds the transform helper; switching the recipe's
checkpoint_path + state_dict_transform to use it (and dropping
``zero_pad_artifixer_keys``) is Phase 5b once the merged safetensors
are at a portable path.
Phase 5b: ``PIPELINE_ARTIFIXER_DMD_T2V_1PT3B`` now defaults to the
merged ArtiFixer DMD safetensors produced by
``dreamfix/scripts/merge_dcp_to_safetensors.py``, paired with
``artifixer_dmd_state_dict_transform`` (added in Phase 5a) for the
diffusers -> ``WanDiTNetwork`` regex remap and the ArtiFixer-only
``attn2.add_k_proj`` / ``attn2.add_v_proj`` / ``attn2.norm_added_k`` ->
``cross_attn.*`` substitutions.
Two env vars cover the deployment surface:
* ``ARTIFIXER_DMD_CHECKPOINT_PATH`` -- path to the merged safetensors;
defaults to the dreamfix repo's ``merged_checkpoints/`` layout
(``/lustre/fsw/.../artifixer_dmd_1p3b_s3_2000/model.safetensors``).
* ``ARTIFIXER_USE_BASE_WAN_WEIGHTS=1`` -- fall back to vanilla
Wan 2.1 1.3B HuggingFace weights + ``zero_pad_artifixer_keys``;
useful for wiring smoke-tests when the merged safetensors are
unavailable.
The 20 existing CPU-only tests still pass; live load of the merged
safetensors into the network is exercised by Phase 4's end-to-end run
(driver in dreamfix).
… prompts Phase 4a (flashdreams side): ``ArtifixerInferencePipeline.initialize_cache`` now takes either ``text: list[str]`` (raw prompts, encoded internally by UMT5) or ``text_embeddings: Tensor`` (pre-encoded UMT5 output). The two are mutually exclusive. The dreamfix-side driver (Phase 4b) feeds the pre-encoded prompt tensor that dreamfix's data pipeline produces (each eval item carries an ``encoded_prompt`` field), avoiding a second UMT5 forward when we already have the embeddings. The ``text_embeddings`` path bypasses the UMT5-only call site in :meth:`WanInferencePipeline.initialize_cache` and slots the embeddings directly into the transformer cache via ``StreamInferencePipeline.initialize_cache(transformer_context=...)``. CFG is intentionally disabled in this branch (no ``negative_text_embeddings`` even when ``guidance_scale > 1``) to match the dreamfix ``ArtifixerKvCachePipeline`` contract -- the KV-cache pipeline ignores ``negative_prompt`` (see ``kv_cache_pipeline.py`` L104-L109). Static check + linter clean.
…npacked
Live-run fix: ``WanDiTNetwork.forward`` accepts a SINGLE
``block_extra_kwargs`` parameter (a dict), not arbitrary kwargs. The base
``Wan21Transformer._predict_flow`` does ``self._select_network(...)
(**network_extra_kwargs)`` which unpacks ``network_extra_kwargs`` into
the network forward call. Our override was flattening the extras into
``network_extra_kwargs`` directly:
network_extra_kwargs = {
"opacity_extra": ...,
"camera_extra": ...,
...
}
…which after unpack became
``network(opacity_extra=..., camera_extra=...)`` and crashed with
``TypeError: WanDiTNetwork.forward() got an unexpected keyword
argument 'opacity_extra'``.
Wrap them in a single ``block_extra_kwargs`` dict so the unpack lands
on ``network(block_extra_kwargs={"opacity_extra": ..., ...})``, which
``WanDiTNetwork.forward`` then forwards to every block as
``block(... **block_extra_kwargs)`` (network.py L438-449). The
:class:`ArtifixerBlock.forward` signature accepts each extra as a
keyword arg.
Bug surfaced on the first end-to-end run with
``--inference_backend flashdreams`` on a DL3DV scene; before the
crash, the FlashDreams pipeline ``.setup()`` loaded all 1095 keys of
the merged DMD safetensors cleanly (validating Phase 5's
state_dict_transform on the real network).
…C,H,W) before patchify
Second live-run bug fix.
dreamfix's ``encode_video_frames`` returns latents in diffusers
convention ``(B, in_dim, T_lat, Hl, Wl)`` (the same ordering ``cache.
condition_latent`` and ``opacity_weighted_latent_mix`` use), but
FlashDreams' ``WanDiTNetwork.patchify_and_maybe_split_cp`` uses the
einops pattern:
"... (t kt) c (h kh) (w kw) -> ... (t h w) (c kt kh kw)"
i.e. it expects ``(B, T, C, H, W)`` with C *after* T. Feeding the
diffusers-shaped latent in put C=16 into the t-axis position, so
the per-token feature dim came out as
``T_chunk * kt * kh * kw = 7 * 1 * 2 * 2 = 28`` instead of
``C * kt * kh * kw = 16 * 1 * 2 * 2 = 64``, and the post-patchify
linear (which fuses the Conv3d patch_embedding) failed with::
a and b must have same reduction dim, but got [.., 28] X [64, 1536]
Permute ``(B, C, T, H, W) -> (B, T, C, H, W)`` for both the initial
mix and the per-step ``fresh_mix`` before calling
``patchify_and_maybe_split_cp``. The unpatchify side already returns
``(B, T, C, H, W)`` (the FlashDreams convention) which is what the
``WanVAEDecoder`` expects, so no symmetric permute is needed at the
decode call site.
Static check + linter clean. Bug surfaced on the second
``--inference_backend flashdreams`` end-to-end run (the previous
predict_flow kwarg fix unblocked us to reach the patchify).
…peline The base StreamInferencePipeline.generate installs cache.event_profiler when this flag is True, but our custom ArtifixerInferencePipeline.generate bypasses that path -- so finalize() asserts on a None profiler. Turn the flag off until our generate() grows the matching EventProfiler setup; the inline comment documents the contract so the next person knows what to wire up before flipping it back on. Unblocks the live DL3DV rollout end-to-end through the FlashDreams backend.
The dreamfix-native ``ArtifixerKvCachePipeline.generate_samples_from_batch``
advances its KV cache *in-place* during the regular denoise forwards
and never runs an extra forward at AR-chunk boundaries. The FlashDreams
default ``Transformer.finalize_kv_cache`` runs one more ``predict_flow``
at the context-noise timestep to advance the cache (the extra forward
discarded by ``_ = ...``). For the artifixer recipe this extra forward
writes a *different* KV state than dreamfix's last in-loop forward,
which empirically opens a ~7-9 dB cross-backend PSNR gap at the start
of every AR chunk past chunk 0:
``scripts/parity_harness.py`` ``transformer_call_*`` diff,
before this change:
chunk 0 step 0: 43.55 dB
...
chunk 1 step 0: 7.53 dB <- cliff
chunk 1 step 1: 18.15 dB
chunk 2 step 0: 7.57 dB <- cliff
after this change:
chunk 0 step 0: 43.55 dB (unchanged, no AR boundary yet)
chunk 1 step 0: 37.37 dB
chunk 1 step 1: 37.20 dB
chunk 2 step 0: 37.44 dB
End-to-end ``final_video`` PSNR jumps from 30.49 dB -> 47.72 dB on the
parity-harness item. ``cache.finalize(autoregressive_index)`` (the
bookkeeping that increments AR index) is still invoked by
``DiffusionModel.finalize`` after this no-op, so the rollout state
advances correctly; only the redundant predict is suppressed.
Mirrors the pattern in
``flashdreams/recipes/alpadreams/transformer/__init__.py::finalize_kv_cache``
which conditionally skips the same call.
…eamfix) The dreamfix ``ArtifixerTransformerBlock.forward`` (``model_training/net/transformer.py`` L725-816) runs every per-block AdaLN modulation, RMSNorm, and residual add in fp32, then casts the result back to the input dtype. The base FlashDreams ``Block.forward`` keeps everything in ``x.dtype`` (typically bf16). With 30 blocks each contributing ~1 dB of bf16 noise, the cross-backend PSNR drifts from 51 dB at block 0 down to 28 dB at block 29 -- a layer-by-layer accumulation visible in ``scripts/parity_harness.py --capture_blocks``. This change mirrors the six dreamfix promotions: - modulation chunking: ``(modulation.float() + e.float()).chunk(6)`` - self-attn pre-norm + AdaLN - self-attn residual (x.float() + y * gate) - cross-attn pre-norm - FFN pre-norm + AdaLN - FFN residual (x.float() + ff.float() * gate) ``norm1`` / ``norm2`` are ``elementwise_affine=False`` so calling them with an fp32 input is a free upcast (no weight/bias dtype check). ``norm3`` is ``elementwise_affine=True`` with bf16 weight/bias after ``.to(bf16)``; passing fp32 input through ``nn.LayerNorm`` raises ``expected scalar type Float but found BFloat16``, so we route through a ``_layer_norm_fp32`` helper that promotes weight + bias to fp32 too. This matches diffusers' ``FP32LayerNorm`` which the dreamfix WanTransformerBlock norms use under the hood. End-to-end ``final_video`` cross-backend PSNR jumps 47.72 dB -> 51.34 dB on the parity-harness item. Block-by-block (call 0): block 0 went from 51.08 dB -> 55.35 dB; L1_max outliers in the middle of the stack (blocks 17, 21) went from 60+ -> 28-45. The remaining residual is attention-impl op-ordering noise inside the bf16 SDPA / FFN paths, which is irreducible without forcing the same attention backend on both implementations. Performance cost: a handful of extra casts per block. The fp32-promoted norms / residuals run on small tensors compared to the attention QKV / FFN GEMMs, so wall-clock impact is negligible.
…, README) Three small cleanups in preparation for opening the upstream MR: 1. ``ARTIFIXER_DMD_CHECKPOINT_PATH`` is now required (no committed ``/lustre/.../rdelutio/...`` default). Either set the env var or set ``ARTIFIXER_USE_BASE_WAN_WEIGHTS=1`` -- otherwise ``config.py`` raises a clear ``RuntimeError`` at import time pointing the user to ``dreamfix/scripts/merge_dcp_to_safetensors.py``. The same env-var pattern is applied to ``tests/test_state_dict_transform.py``'s ``ARTIFIXER_PARAM_AUDIT_PATH``; the existing ``pytest.skip`` path keeps CI green without the audit JSON. 2. Drop ``Phase X.Y (this commit)`` narrative wording from docstrings and comments across ``block.py`` / ``transformer.py`` / ``pipeline.py`` / ``config.py`` / ``runner.py`` / ``cross_attn.py`` / ``latent_mix.py`` / ``checkpoint.py`` / ``tests/test_smoke.py``. The Phase numbering was commit-narrative bleed-through; now reads as plain past-tense prose without dating the code to its development order. 3. Rework ``integrations/artifixer/README.md``: drop the ``| Phase | Scope | Status |`` table (Phase 4 was still ``upcoming`` despite the dreamfix-side adapter closing it), replace with a plain ``| Component | Description |`` table that describes the five recipe pieces. Add a cross-backend parity dB note so reviewers see the validation summary up front. No behavior change. Phase references in pre-existing ``.pyc`` files under ``__pycache__/`` will refresh on next collect.
…c refs
Public-friendly cleanup for the upstream MR:
* "dreamfix" rewritten to "the ArtiFixer reference" / "the
reference" throughout module docstrings, comments, and tests. The
name "dreamfix" is an internal project handle that has no place in
a public flashdreams plugin; the model name "ArtiFixer" is what
readers actually need.
* Dropped every "L###-L###" line-number citation pointing into the
reference repo's source -- those rot the moment the reference
moves a line, and they were diagnostic rather than substantive.
* Removed two stale references to the deleted ``scripts/parity_harness.py``
(in block.py and transformer.py).
* Renamed the test-side env var ``DREAMFIX_REPO_ROOT`` ->
``ARTIFIXER_REFERENCE_REPO_ROOT`` and the test function
``test_prope_matches_dreamfix_reference`` ->
``test_prope_matches_reference``. The companion
``scripts/run_prope_parity.sh`` is updated to match.
* README.md drops the internal gitlab-master URL and reframes the
"components" table without referring to the reference repo by name.
No behaviour change. Static check (``scripts/static_check_artifixer.sh``)
still passes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds the
artifixer-dmd-wan2.1-t2v-1.3brecipe underintegrations/artifixer/— a reconstruction-enhanced T2V model built on Wan 2.1 1.3B with four extensions over vanilla Wan:FlowMatchScheduler(shift=5)).A consuming driver calls
ArtifixerInferencePipeline.initialize_cachewith pre-encoded UMT5 prompts + VAE-encoded condition / neighbor latents and drives the AR rollout chunk by chunk viagenerate/finalize.Validation
scripts/static_check_artifixer.sh): green (py_compileclean, allflashdreams.*imports resolve, entry-points map toartifixer.config).tests/test_prope_parity.py::test_prope_matches_reference): bit-identical at fp32 vs the ArtiFixer reference (gated onARTIFIXER_REFERENCE_REPO_ROOT).ArtifixerDiTNetworkwith no missing / unexpected entries.final_videoPSNR vs the ArtiFixer reference'sArtifixerKvCachePipeline: 51.34 dB on a captured DL3DV-ours scene; per-block per-call PSNR stays>50 dBacross all 30 layers after the fp32 AdaLN/norm/residual promotion and the no-opfinalize_kv_cacheoverride.Diffs outside
integrations/artifixer/pyproject.toml: +1 line in[tool.pyright].extraPathsand +1 line in[tool.ty].extra-paths(standard for everyintegrations/<name>/)..gitignore: +1 line (slurm-logs/).scripts/: 4 dev helpers (smoke_test_artifixer.sh,static_check_artifixer.sh,run_prope_parity.sh,import_flashdreams_sqsh.sh).