v0.3.4

@github-actions released this 14 Mar 20:08
· 642 commits to main since this release

[0.3.4] - 2026-03-14

Added

  • Mixture-of-Depths (MoD) for Llama 4: Proper implementation per Raposo et al. (2024) — lightweight router with argpartition_axis top-k, gather-before-compute on sub-batch, scatter-after, BCE auxiliary loss. Configurable capacity factor and per-layer selection
  • Llama 4 RoPE: Real RoPE implementation via pmetal_mlx::kernels::rope::apply_rope (Metal-accelerated), replacing the placeholder stub. Correctly wired into iRoPE layer dispatch — RoPE layers get rotary embeddings, NoPE layers skip them
  • Llama 4 temperature scaling: Per Meta's formula log(floor((pos+1)/floor_scale) + 1) * attn_scale + 1.0, applied to Q states in NoPE layers before QK matmul for long-context attention stabilization
  • Llama 4 GQA: KV-head broadcast expansion for grouped-query attention — enables Scout (40 Q / 8 KV) and Maverick configs
  • MoE top-k > 1: Llama4Router uses argpartition_axis for O(n) expert selection with L1-normalized weights and per-slot dispatch loop, replacing hardcoded argmax
  • ANE fused kernels: gen_dynamic_sdpa_fwd (single-kernel attention: RMSNorm + QKV + SDPA + Wo) and gen_dynamic_ffn_w13 (single-kernel FFN: RMSNorm + W1 + W3 + SiLU), replacing 6+ separate ANE evaluations per layer
  • ANE fused backward: gen_dynamic_ffn_bwd_w2t and gen_dynamic_ffn_bwd_w13t for fused FFN backward pass
  • Metal dequantization kernels: Q4_0 and IQ4_XS Metal compute shaders, verified correct per GGML spec. Bridge methods in MlxMetalBridge for GPU-accelerated dequantization
  • Cancellation safety infrastructure: CompletionToken::Drop guard in AsyncScheduler waits for in-flight GPU commands; retain_resource() / as_retained() for Metal buffer lifetime extension
  • IoSurface helpers: write_f32_strided_at, write_f32_at_col_offset, zero_channel_range_f32 for fused backward kernel IO
  • CloudBridge: Complete training state export (weights, optimizer state, RNG, dataloader position, metadata) with working Python bootstrap scripts for FSDP/DeepSpeed cluster resumption and Rust-side loader functions
  • Formal verification: cargo-kani proofs for ring all-reduce chunk arithmetic (95 checks) and k-ary tree topology consistency (607 checks), with justfile recipes
  • Reasoning templates: MathReasoningTemplate (GRPO + accuracy/format rewards) and CodeReasoningTemplate (structural code fence + test case matching)
  • Reasoning dataset auto-detection: pmetal dataset prepare automatically detects problem/thinking/solution columns and formats them as <think> tagged ChatML conversations
  • --columns flag: General column remapping for dataset prepare (e.g., --columns "instruction=question,output=answer")
  • adapter_config.json: Saved alongside LoRA weights during training (r, alpha, target_modules, use_rslora). Loaded automatically at inference and fuse time — eliminates config guesswork
  • Supply chain: cargo-vet initialized with Mozilla, Google, and Bytecode Alliance audit imports; 17 workspace crates covered; 5 transitive dependency exemptions with exact lockfile versions
  • Tracing spans: 6 info_span! markers in Python trainer for phase-level observability (model_resolve, load_tokenizer, load_dataset, load_model, training_loop, save_weights)
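
The temperature-scaling formula in the Llama 4 entry above can be sketched on the CPU as follows. This is a minimal illustration of the arithmetic only, not the project's Metal code path; the function name attn_temperature_scale is hypothetical, and the floor_scale and attn_scale values mentioned afterwards are illustrative assumptions rather than values taken from this release.

```rust
/// Per-position attention temperature scale, following the formula quoted in
/// the release notes:
///   log(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1.0
///
/// Hypothetical helper for illustration; `floor_scale` and `attn_scale`
/// would normally come from the model configuration.
fn attn_temperature_scale(pos: usize, floor_scale: f32, attn_scale: f32) -> f32 {
    // Positions are grouped into buckets of width `floor_scale`.
    let bucket = ((pos as f32 + 1.0) / floor_scale).floor();
    // Natural log of (bucket + 1), so the first bucket yields a scale of 1.0.
    (bucket + 1.0).ln() * attn_scale + 1.0
}

/// Scale every element of one query vector by the factor for its position.
fn scale_query(q: &mut [f32], pos: usize, floor_scale: f32, attn_scale: f32) {
    let s = attn_temperature_scale(pos, floor_scale, attn_scale);
    for v in q.iter_mut() {
        *v *= s;
    }
}
```

With an assumed floor_scale of 8192 and attn_scale of 0.1, every position in the first bucket (pos + 1 < 8192) keeps a scale of exactly 1.0, so short contexts are unaffected and the factor only grows logarithmically with distance.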

Fixed

  • LoRA inference garbage output: LoRA weights are now merged into the base model at inference time (W += scale*B@A), matching mlx-lm's pattern. The separate-forward path had dtype mismatch issues (BF16 base × F32 LoRA)
  • Auto-chat mode regression: Removed heuristic that forced chat template on base models just because their tokenizer has <|im_end|>. Chat mode now requires explicit --chat or an instruction-tuned model
  • Missing EOS in training data: Training sequences now end with the model's actual EOS token (e.g., <|endoftext|> for Qwen). Previously they ended only with the turn delimiter (<|im_end|>), so the model never learned to stop generating
  • Fuse command wrong alpha/rank: pmetal fuse now reads adapter_config.json for correct alpha and rank instead of defaulting to scale=1.0. Also filters MLP LoRA weights (rank=0) when auto-detecting rank from shapes
  • ANE x2norm backward bug: FFN weight gradients (dW1, dW3) were computed against the wrong pre-norm tensor (xnorm from attention block instead of x2norm from FFN block). Restored x2norm field and CPU RMSNorm recomputation for gradient correctness
  • ANE sdpa_bwd surface dtype: Backward SDPA output surfaces were allocated as fp32 but ANE kernels produce fp16 — stride mismatch corrupted dV/dQ/dK gradients. Fixed to IoSurface::for_tensor() (fp16)
  • MoD argpartition sign: Router negated weights before argpartition_axis, selecting bottom-k (least important) tokens instead of top-k. Removed negation
  • MLX bridge copy_as_f32 regression: Renamed methods dropped auto dtype conversion — callers passing wrong dtype would panic. Restored copy_as_f32 / copy_as_f16 with auto-conversion
  • MLX bridge view_f32 eval: A regression had removed the .eval() call before accessing the data pointer, so unevaluated arrays returned null. Restored the defensive eval
  • Python API surface: Restored ProgressCallback, LoggingCallback(log_every=10), __version__, and PythonCallbackBridge that were deleted during PyO3 migration
  • TUI training completion: Reads final metrics from JSONL file on disk (immune to polling lag). Shows actual loss and step count instead of 0.0000 / sample count
  • TUI Steps/min overflow: Guards against divide-by-zero when total_ms=0, showing a placeholder instead of 60000
  • Dataset prepare panic: Empty results no longer crash with index-out-of-bounds. Shows diagnostic message with format hints
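
The sign error described in the MoD argpartition entry above is easy to reproduce with a small CPU sketch. The helper top_k_indices below is hypothetical and uses the standard library's partition-based selection as a stand-in for argpartition_axis; it is not the project's router code.

```rust
use std::cmp::Ordering;

/// Return the indices of the `k` largest scores, in no particular order.
/// `select_nth_unstable_by` partitions the slice in linear time, which is the
/// same idea as an argpartition-based top-k (no full sort is needed).
fn top_k_indices(scores: &[f32], k: usize) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..scores.len()).collect();
    if k == 0 || idx.is_empty() {
        return Vec::new();
    }
    if k >= idx.len() {
        return idx;
    }
    // Descending comparator: indices with larger scores are ordered first,
    // so after partitioning the first `k` slots hold the top-k indices.
    idx.select_nth_unstable_by(k - 1, |&a, &b| {
        scores[b].partial_cmp(&scores[a]).unwrap_or(Ordering::Equal)
    });
    idx.truncate(k);
    idx
}
```

Calling the helper on [0.1, 0.9, 0.4, 0.7, 0.2] with k = 2 selects indices 1 and 3. Negating the scores first selects indices 0 and 4, the two least important tokens, which is the failure mode the fix removes.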

Changed

  • LoRA inference uses merge: merge_lora() is called before generation, producing a single merged weight matrix per layer. This is equivalent to the fuse command but happens in-memory without saving
  • PyO3 0.23 → 0.28: allow_threads → detach, with_gil → attach, from_py_object on all pyclass types, Bound<'py, PyDict> return types
  • tokio 1.49 → 1.50
  • unsafe_code lint: Escalated from warn to deny workspace-wide
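
As a rough picture of what an in-memory merge involves, the update W += scale*B@A can be written out for dense row-major matrices. This is a sketch under stated assumptions: merge_lora_dense and lora_scale are hypothetical names (not the project's merge_lora()), and the scale conventions shown (alpha / r for standard LoRA, alpha / sqrt(r) for rsLoRA) are the commonly used ones, assumed here rather than confirmed by the release notes.

```rust
/// Scaling factor derived from adapter settings (assumed conventions):
/// standard LoRA uses alpha / r, rank-stabilized LoRA uses alpha / sqrt(r).
fn lora_scale(alpha: f32, r: usize, use_rslora: bool) -> f32 {
    if use_rslora {
        alpha / (r as f32).sqrt()
    } else {
        alpha / r as f32
    }
}

/// Merge a LoRA update into a dense weight in place: W += scale * (B @ A).
/// All matrices are row-major.
/// Shapes: W is [out, inp], B is [out, r], A is [r, inp].
fn merge_lora_dense(
    w: &mut [f32],
    b: &[f32],
    a: &[f32],
    out: usize,
    inp: usize,
    r: usize,
    scale: f32,
) {
    assert_eq!(w.len(), out * inp);
    assert_eq!(b.len(), out * r);
    assert_eq!(a.len(), r * inp);
    for i in 0..out {
        for k in 0..r {
            // Fold the scale into the B element once per (i, k) pair.
            let bik = scale * b[i * r + k];
            for j in 0..inp {
                w[i * inp + j] += bik * a[k * inp + j];
            }
        }
    }
}
```

Doing the accumulation in a single dtype before generation sidesteps the mixed-precision forward path (BF16 base × F32 LoRA) that the Fixed section identifies as the source of the garbage output.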

Full Changelog: v0.3.3...v0.3.4