Skip to content

Ecosystem compatibility: GradientCheckpointingLayer, FlexAttn, TP/PP, hub kernel dispatch#2

Open
t-timms wants to merge 1 commit into
Zyphra:zaya1from
t-timms:feat/ecosystem-compatibility-patches
Open

Ecosystem compatibility: GradientCheckpointingLayer, FlexAttn, TP/PP, hub kernel dispatch#2
t-timms wants to merge 1 commit into
Zyphra:zaya1from
t-timms:feat/ecosystem-compatibility-patches

Conversation

@t-timms
Copy link
Copy Markdown

@t-timms t-timms commented May 11, 2026

Motivation

These changes were identified while fine-tuning ZAYA1-8B with TRL SFTTrainer + GRPOTrainer for agentic multi-turn tool calling. The missing metadata flags prevent standard HuggingFace ecosystem features from activating, even though the model architecture already supports them. Every other major MoE model in transformers main (DeepSeek-V3, Qwen3-MoE) already declares these flags.

What changed

  1. GradientCheckpointingLayer base classZayaDecoderATTLayer now extends GradientCheckpointingLayer instead of nn.Module. Enables automatic gradient checkpointing via model.gradient_checkpointing_enable(), reducing activation memory ~40-60% during training.

  2. Fix _no_split_modules typoZayaDecoderLayerZayaDecoderATTLayer. The old name did not match any actual module class, so _no_split_modules had no effect on gradient checkpointing.

  3. _supports_flex_attn = True — Enables PyTorch FlexAttention backend (attn_implementation=flex_attention). FlexAttention supports the custom attention mask patterns CCA uses.

  4. _can_record_outputs metadata — Enables TRL trainers to capture hidden states and attentions during training. Required for MoE aux loss computation and attention pattern debugging.

  5. Hub-loaded RoPE kernel@use_kernel_func_from_hub(rotary_pos_emb) on apply_rotary_pos_emb. Dispatches to Triton-optimized kernel when Liger Kernel is installed. PyTorch fallback is functionally identical.

  6. Hub-loaded RMSNorm@use_kernel_forward_from_hub(RMSNorm) on ZayaRMSNorm. Same pattern — Triton kernel dispatch when available.

  7. _tp_plan / _pp_plan — Enables tensor parallelism and pipeline parallelism for multi-GPU inference (transformers >= 5.4.0).

  8. router_aux_loss_coef config — Adds standard MoE load balancing coefficient (default=0.001) for future aux loss implementation.

What was deliberately NOT changed

CCA attention, MOD skip expert, EDA routing, fused bias+SwiGLU, FP32 residual accumulation, dual time-stream values, and ZayaDynamicCache are untouched. These are ZAYA-specific architectural innovations and should not be modified.

After merging

Run make fixup from the repo root to regenerate modeling_zaya.py from the modular file.

Testing

All 8 changes are class-attribute assignments or base-class changes. No forward-pass logic modified. Outputs are bit-identical with existing model weights.

… TP/PP, hub kernels, aux loss config

- ZayaDecoderATTLayer: extend GradientCheckpointingLayer instead of nn.Module
- Fix _no_split_modules typo (ZayaDecoderLayer -> ZayaDecoderATTLayer)
- Add _supports_flex_attn=True, _can_record_outputs metadata
- Add @use_kernel_func_from_hub to apply_rotary_pos_emb
- Add @use_kernel_forward_from_hub to ZayaRMSNorm
- Add _tp_plan/_pp_plan to ZayaForCausalLM
- Add router_aux_loss_coef to ZayaConfig
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant