Ecosystem compatibility: GradientCheckpointingLayer, FlexAttn, TP/PP, hub kernel dispatch#2
Open
t-timms wants to merge 1 commit into
Open
Conversation
… TP/PP, hub kernels, aux loss config - ZayaDecoderATTLayer: extend GradientCheckpointingLayer instead of nn.Module - Fix _no_split_modules typo (ZayaDecoderLayer -> ZayaDecoderATTLayer) - Add _supports_flex_attn=True, _can_record_outputs metadata - Add @use_kernel_func_from_hub to apply_rotary_pos_emb - Add @use_kernel_forward_from_hub to ZayaRMSNorm - Add _tp_plan/_pp_plan to ZayaForCausalLM - Add router_aux_loss_coef to ZayaConfig
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
These changes were identified while fine-tuning ZAYA1-8B with TRL SFTTrainer + GRPOTrainer for agentic multi-turn tool calling. The missing metadata flags prevent standard HuggingFace ecosystem features from activating, even though the model architecture already supports them. Every other major MoE model in transformers main (DeepSeek-V3, Qwen3-MoE) already declares these flags.
What changed
GradientCheckpointingLayer base class —
ZayaDecoderATTLayernow extendsGradientCheckpointingLayerinstead ofnn.Module. Enables automatic gradient checkpointing viamodel.gradient_checkpointing_enable(), reducing activation memory ~40-60% during training.Fix _no_split_modules typo —
ZayaDecoderLayer→ZayaDecoderATTLayer. The old name did not match any actual module class, so_no_split_moduleshad no effect on gradient checkpointing._supports_flex_attn = True — Enables PyTorch FlexAttention backend (
attn_implementation=flex_attention). FlexAttention supports the custom attention mask patterns CCA uses._can_record_outputs metadata — Enables TRL trainers to capture hidden states and attentions during training. Required for MoE aux loss computation and attention pattern debugging.
Hub-loaded RoPE kernel —
@use_kernel_func_from_hub(rotary_pos_emb)onapply_rotary_pos_emb. Dispatches to Triton-optimized kernel when Liger Kernel is installed. PyTorch fallback is functionally identical.Hub-loaded RMSNorm —
@use_kernel_forward_from_hub(RMSNorm)onZayaRMSNorm. Same pattern — Triton kernel dispatch when available._tp_plan / _pp_plan — Enables tensor parallelism and pipeline parallelism for multi-GPU inference (transformers >= 5.4.0).
router_aux_loss_coef config — Adds standard MoE load balancing coefficient (
default=0.001) for future aux loss implementation.What was deliberately NOT changed
CCA attention, MOD skip expert, EDA routing, fused bias+SwiGLU, FP32 residual accumulation, dual time-stream values, and ZayaDynamicCache are untouched. These are ZAYA-specific architectural innovations and should not be modified.
After merging
Run
make fixupfrom the repo root to regeneratemodeling_zaya.pyfrom the modular file.Testing
All 8 changes are class-attribute assignments or base-class changes. No forward-pass logic modified. Outputs are bit-identical with existing model weights.