v0.3.4
[0.3.4] - 2026-03-14
Added
- Mixture-of-Depths (MoD) for Llama 4: Proper implementation per Raposo et al. (2024): lightweight router with `argpartition_axis` top-k, gather-before-compute on sub-batch, scatter-after, BCE auxiliary loss. Configurable capacity factor and per-layer selection
- Llama 4 RoPE: Real RoPE implementation via `pmetal_mlx::kernels::rope::apply_rope` (Metal-accelerated), replacing the placeholder stub. Correctly wired into iRoPE layer dispatch: RoPE layers get rotary embeddings, NoPE layers skip them
- Llama 4 temperature scaling: Per Meta's formula `log(floor((pos+1)/floor_scale) + 1) * attn_scale + 1.0`, applied to Q states in NoPE layers before QK matmul for long-context attention stabilization
- Llama 4 GQA: KV-head broadcast expansion for grouped-query attention; enables Scout (40 Q / 8 KV) and Maverick configs
- MoE top-k > 1: `Llama4Router` uses `argpartition_axis` for O(n) expert selection with L1-normalized weights and per-slot dispatch loop, replacing hardcoded argmax
- ANE fused kernels: `gen_dynamic_sdpa_fwd` (single-kernel attention: RMSNorm + QKV + SDPA + Wo) and `gen_dynamic_ffn_w13` (single-kernel FFN: RMSNorm + W1 + W3 + SiLU), replacing 6+ separate ANE evaluations per layer
- ANE fused backward: `gen_dynamic_ffn_bwd_w2t` and `gen_dynamic_ffn_bwd_w13t` for fused FFN backward pass
- Metal dequantization kernels: Q4_0 and IQ4_XS Metal compute shaders, verified correct per GGML spec. Bridge methods in `MlxMetalBridge` for GPU-accelerated dequantization
- Cancellation safety infrastructure: `CompletionToken::Drop` guard in `AsyncScheduler` waits for in-flight GPU commands; `retain_resource()`/`as_retained()` for Metal buffer lifetime extension
- IoSurface helpers: `write_f32_strided_at`, `write_f32_at_col_offset`, `zero_channel_range_f32` for fused backward kernel IO
- CloudBridge: Complete training state export (weights, optimizer state, RNG, dataloader position, metadata) with working Python bootstrap scripts for FSDP/DeepSpeed cluster resumption and Rust-side loader functions
- Formal verification: `cargo-kani` proofs for ring all-reduce chunk arithmetic (95 checks) and k-ary tree topology consistency (607 checks), with justfile recipes
- Reasoning templates: `MathReasoningTemplate` (GRPO + accuracy/format rewards) and `CodeReasoningTemplate` (structural code fence + test case matching)
- Reasoning dataset auto-detection: `pmetal dataset prepare` automatically detects `problem`/`thinking`/`solution` columns and formats them as `<think>`-tagged ChatML conversations
- `--columns` flag: General column remapping for `dataset prepare` (e.g., `--columns "instruction=question,output=answer"`)
- `adapter_config.json`: Saved alongside LoRA weights during training (r, alpha, target_modules, use_rslora). Loaded automatically at inference and fuse time, eliminating config guesswork
- Supply chain: `cargo-vet` initialized with Mozilla, Google, and Bytecode Alliance audit imports; 17 workspace crates covered; 5 transitive dependency exemptions with exact lockfile versions
- Tracing spans: 6 `info_span!` markers in Python trainer for phase-level observability (model_resolve, load_tokenizer, load_dataset, load_model, training_loop, save_weights)
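
For readers who want to see the temperature-scaling formula from the Llama 4 entry in concrete terms, here is a minimal scalar sketch. It is not the pmetal implementation: the function names are hypothetical, `log` is assumed to be the natural logarithm, and the `floor_scale` / `attn_scale` values used in the note below are illustrative rather than taken from any shipped config.

```rust
/// Per-position attention temperature, following the formula quoted above:
/// ln(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1.0
/// (hypothetical helper; natural log is an assumption)
fn attn_temperature_scale(pos: usize, floor_scale: f32, attn_scale: f32) -> f32 {
    let bucket = ((pos as f32 + 1.0) / floor_scale).floor();
    (bucket + 1.0).ln() * attn_scale + 1.0
}

/// Multiply each query row by the scale for its position, in place.
/// `q` holds one row per sequence position (hypothetical layout).
fn scale_queries(q: &mut [Vec<f32>], floor_scale: f32, attn_scale: f32) {
    for (pos, row) in q.iter_mut().enumerate() {
        let s = attn_temperature_scale(pos, floor_scale, attn_scale);
        for x in row.iter_mut() {
            *x *= s;
        }
    }
}
```

With `floor_scale = 8192.0` and `attn_scale = 0.1`, every position below 8191 gets a scale of exactly 1.0, so short contexts are untouched. The scale only starts growing logarithmically once `pos + 1` reaches a multiple of `floor_scale`.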
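
The top-k routing described in the MoE entry can be sketched with the standard library alone. This sketch assumes non-negative router scores and uses `select_nth_unstable_by` as a stand-in for `argpartition_axis`; `route_top_k` is a hypothetical name, not the `Llama4Router` API.

```rust
/// Pick the `k` highest-scoring experts for one token and L1-normalize their weights.
/// Returns (expert_index, weight) pairs; order within the top-k is unspecified.
fn route_top_k(scores: &[f32], k: usize) -> Vec<(usize, f32)> {
    assert!(k >= 1 && k <= scores.len(), "k must be in 1..=num_experts");
    let mut idx: Vec<usize> = (0..scores.len()).collect();
    // Partition so the k largest scores land in idx[..k]: average O(n), no full sort.
    // The comparator is descending (b vs a); flipping it would select the bottom-k.
    idx.select_nth_unstable_by(k - 1, |&a, &b| scores[b].total_cmp(&scores[a]));
    idx.truncate(k);
    // L1 norm of the selected scores; fall back to uniform weights if it is zero.
    let l1: f32 = idx.iter().map(|&i| scores[i].abs()).sum();
    idx.into_iter()
        .map(|i| (i, if l1 > 0.0 { scores[i] / l1 } else { 1.0 / k as f32 }))
        .collect()
}
```

For scores `[0.1, 0.6, 0.05, 0.3]` and `k = 2`, experts 1 and 3 are selected with weights of roughly 2/3 and 1/3. The comparator direction is the detail to watch, since it decides whether the partition yields the top-k or the bottom-k.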
Fixed
- LoRA inference garbage output: Merged LoRA weights into base model at inference time (`W += scale*B@A`), matching mlx-lm's pattern. The separate-forward path had dtype mismatch issues (BF16 base × F32 LoRA)
- Auto-chat mode regression: Removed heuristic that forced chat template on base models just because their tokenizer has `<|im_end|>`. Chat mode now requires explicit `--chat` or an instruction-tuned model
- Missing EOS in training data: Training sequences now end with the model's actual EOS token (e.g., `<|endoftext|>` for Qwen). Previously they only had the turn delimiter (`<|im_end|>`), so the model never learned to stop generating
- Fuse command wrong alpha/rank: `pmetal fuse` now reads `adapter_config.json` for correct alpha and rank instead of defaulting to `scale=1.0`. Also filters MLP LoRA weights (rank=0) when auto-detecting rank from shapes
- ANE `x2norm` backward bug: FFN weight gradients (`dW1`, `dW3`) were computed against the wrong pre-norm tensor (`xnorm` from attention block instead of `x2norm` from FFN block). Restored `x2norm` field and CPU RMSNorm recomputation for gradient correctness
- ANE `sdpa_bwd` surface dtype: Backward SDPA output surfaces were allocated as fp32 but ANE kernels produce fp16; the resulting stride mismatch corrupted dV/dQ/dK gradients. Fixed to `IoSurface::for_tensor()` (fp16)
- MoD argpartition sign: Router negated weights before `argpartition_axis`, selecting bottom-k (least important) tokens instead of top-k. Removed negation
- MLX bridge `copy_as_f32` regression: Renamed methods dropped auto dtype conversion, so callers passing the wrong dtype would panic. Restored `copy_as_f32`/`copy_as_f16` with auto-conversion
- MLX bridge `view_f32` eval: The `.eval()` call before accessing the data pointer had been removed, so unevaluated arrays returned null. Restored defensive eval
- Python API surface: Restored `ProgressCallback`, `LoggingCallback(log_every=10)`, `__version__`, and `PythonCallbackBridge` that were deleted during PyO3 migration
- TUI training completion: Reads final metrics from JSONL file on disk (immune to polling lag). Shows actual loss and step count instead of `0.0000` / sample count
- TUI Steps/min overflow: Guards against divide-by-zero when `total_ms=0`; shows `—` instead of `60000`
- Dataset prepare panic: Empty results no longer crash with index-out-of-bounds. Shows diagnostic message with format hints
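
The Steps/min guard from the TUI entry can be illustrated roughly as follows. This is a sketch of the general pattern, not the TUI code: both function names are hypothetical, and the one-decimal formatting is an assumption.

```rust
/// Steps per minute, or `None` when no time has elapsed yet.
/// (hypothetical helper illustrating the guard described above)
fn steps_per_min(steps: u64, total_ms: u64) -> Option<f64> {
    if total_ms == 0 {
        return None; // nothing to divide by yet
    }
    Some(steps as f64 * 60_000.0 / total_ms as f64)
}

/// Render the metric for display, with a dash placeholder when it is undefined.
fn format_steps_per_min(steps: u64, total_ms: u64) -> String {
    match steps_per_min(steps, total_ms) {
        Some(v) => format!("{v:.1}"),
        None => "\u{2014}".to_string(), // U+2014, the placeholder named in the entry
    }
}
```

Returning `Option` keeps the undefined case explicit at the call site instead of encoding it as a sentinel number that could be mistaken for a real rate.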
Changed
- LoRA inference uses merge: `merge_lora()` is called before generation, producing a single merged weight matrix per layer. This is equivalent to the fuse command but happens in-memory without saving
- PyO3 0.23 → 0.28: `allow_threads` → `detach`, `with_gil` → `attach`, `from_py_object` on all pyclass types, `Bound<'py, PyDict>` return types
- tokio 1.49 → 1.50
- `unsafe_code` lint: Escalated from `warn` to `deny` workspace-wide
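
As a rough illustration of the merge step (`W += scale*B@A`), the sketch below works on plain row-major `f32` matrices. The `Mat` type and both function names are hypothetical and do not reflect the signature of the real `merge_lora()`. The scale follows the usual LoRA conventions (`alpha / r`, or `alpha / sqrt(r)` with rsLoRA), which is assumed to be what `use_rslora` in `adapter_config.json` selects.

```rust
/// Row-major dense matrix, kept minimal for illustration (hypothetical type).
struct Mat {
    rows: usize,
    cols: usize,
    data: Vec<f32>,
}

/// LoRA scale: alpha / r, or alpha / sqrt(r) when rsLoRA is enabled.
fn lora_scale(alpha: f32, r: usize, use_rslora: bool) -> f32 {
    if use_rslora {
        alpha / (r as f32).sqrt()
    } else {
        alpha / r as f32
    }
}

/// In-place merge: W += scale * (B @ A), with W: out x in, B: out x r, A: r x in.
fn merge_lora_dense(w: &mut Mat, b: &Mat, a: &Mat, scale: f32) {
    assert_eq!((w.rows, w.cols), (b.rows, a.cols), "W must be out x in");
    assert_eq!(b.cols, a.rows, "B and A must share rank r");
    for i in 0..w.rows {
        for k in 0..b.cols {
            // Fold the scale into the B element once per (i, k) pair.
            let bik = scale * b.data[i * b.cols + k];
            for j in 0..w.cols {
                w.data[i * w.cols + j] += bik * a.data[k * a.cols + j];
            }
        }
    }
}
```

After the merge, generation runs against a single weight matrix per layer, so the base and adapter weights are no longer combined in separate forward passes.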
Full Changelog: v0.3.3...v0.3.4