1 change: 1 addition & 0 deletions modelopt/torch/puzzletron/__init__.py
@@ -19,6 +19,7 @@
anymodel,
block_config,
build_library_and_stats,
bypass_distillation,
dataset,
entrypoint,
mip,
13 changes: 13 additions & 0 deletions modelopt/torch/puzzletron/anymodel/model_descriptor/base.py
@@ -169,6 +169,19 @@ def uses_autocast() -> bool:
"""
return True

@staticmethod
def pruning_mixins() -> Dict[str, Any]:
"""Return available pruning mixins for bypass distillation.

Override in subclasses to provide model-specific pruning mixins, e.g.
``{"kv_heads": KVHeadsPruningMixIn(...), "experts_removal": ExpertRemovalPruningMixIn(...)}``.

Returns an empty dict by default so that descriptors that do not need
model-specific weight-slicing (e.g. Llama with standard FFN truncation)
can rely on the generic ``create_child_state_dict`` fallback path.
"""
return {}

@staticmethod
def get_language_model_config(config):
"""Get the language model config from a PretrainedConfig.
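To illustrate how the new hook is meant to be consumed, here is a hedged sketch (not part of this PR). The resolver helper below and its signature are assumptions inferred from the docstring and from the `resolve_pruning_mixin` call mentioned in a later comment, not the actual API.

# Hedged sketch, not from this PR: one way a caller could consume the new
# ModelDescriptor.pruning_mixins() hook. The resolver name/signature and the
# fallback behaviour are illustrative assumptions.
from typing import Any, Dict, Optional


class _NoMixinDescriptor:
    """Stand-in for a descriptor that keeps the empty default (e.g. Llama)."""

    @staticmethod
    def pruning_mixins() -> Dict[str, Any]:
        return {}


def resolve_pruning_mixin(name: str, descriptor_cls: Any) -> Optional[Any]:
    """Return the mixin registered under ``name``, or None if absent."""
    return descriptor_cls.pruning_mixins().get(name)


if resolve_pruning_mixin("kv_heads", _NoMixinDescriptor) is None:
    # Nothing model-specific registered: callers would take the generic
    # ``create_child_state_dict`` fallback path described in the docstring.
    print("falling back to the generic create_child_state_dict path")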
@@ -28,6 +28,7 @@
ExpertRemovalLayerDescriptor,
ExpertRemovalPruningMixIn,
)
from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn

# Expert removal is supported for unquantized models (test models).
# Production models use MXFP4 quantized MoE with combined tensors
@@ -37,7 +38,11 @@
from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory
from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size

__all__ = ["GptOssModelDescriptor", "GptOssExpertRemovalLayerDescriptor"]
__all__ = [
"GptOssExpertRemovalLayerDescriptor",
"GptOssKVHeadsLayerDescriptor",
"GptOssModelDescriptor",
]


@ModelDescriptorFactory.register_decorator("gpt_oss")
@@ -173,7 +178,29 @@ def pruning_mixins() -> Dict[str, PruningMixIn]:
Note: Expert removal works for unquantized models (test models).
Production models use MXFP4 quantization which is not yet supported.
"""
return {"expert_removal": ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor())}
# Single instance shared between the canonical key and the legacy alias
# so resolve_pruning_mixin returns the same object regardless of which
# name a caller uses.
expert_mixin = ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor())
return {
"experts_removal": expert_mixin,
# Backward-compat alias: this key was "expert_removal" before the
# bypass branch standardised on "experts_removal" (matching the
# NemotronH descriptor). Kept so external scripts that still call
# `resolve_pruning_mixin("expert_removal", GptOssModelDescriptor)`
# continue to work. Remove after a deprecation cycle.
"expert_removal": expert_mixin,
"kv_heads": KVHeadsPruningMixIn(GptOssKVHeadsLayerDescriptor()),
}


@dataclass
class GptOssKVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
o_proj_name: str = "self_attn.o_proj"
attn_prefix_name: str = "model.layers.{layer_idx}.self_attn"
qkvo_weight_names: List[str] = field(
default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
)


@dataclass
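The alias handling above can be summarised with a short usage sketch (not part of this PR); the import path of GptOssModelDescriptor is an assumption about the package layout.

# Hedged sketch: the canonical key and the legacy alias resolve to the same
# mixin instance, so existing callers keep working during the deprecation
# window. The import path below is assumed and may need adjusting.
from modelopt.torch.puzzletron.anymodel.model_descriptor.gpt_oss import (  # assumed path
    GptOssModelDescriptor,
)

mixins = GptOssModelDescriptor.pruning_mixins()
assert mixins["experts_removal"] is mixins["expert_removal"]  # shared instance
assert set(mixins) == {"experts_removal", "expert_removal", "kv_heads"}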
@@ -29,11 +29,16 @@
ExpertRemovalLayerDescriptor,
ExpertRemovalPruningMixIn,
)
from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn
from ....pruning.pruning_mixin import PruningMixIn
from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory
from ...puzzformer.no_op import MatchingZeros, Same

__all__ = ["NemotronHExpertRemovalLayerDescriptor", "NemotronHModelDescriptor"]
__all__ = [
"NemotronHExpertRemovalLayerDescriptor",
"NemotronHKVHeadsLayerDescriptor",
"NemotronHModelDescriptor",
]


def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]:
@@ -51,6 +56,15 @@ def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]:
return matches


@dataclass
class NemotronHKVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
o_proj_name: str = "mixer.o_proj"
attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer"
qkvo_weight_names: List[str] = field(
default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
)


@dataclass
class NemotronHExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor):
target_name: str = "mixer.gate"
@@ -251,4 +265,5 @@ def build_attention_predicates() -> Dict[str, re.Pattern]:
def pruning_mixins() -> Dict[str, PruningMixIn]:
return {
"experts_removal": ExpertRemovalPruningMixIn(NemotronHExpertRemovalLayerDescriptor()),
"kv_heads": KVHeadsPruningMixIn(NemotronHKVHeadsLayerDescriptor()),
}
@@ -29,11 +29,16 @@
FFNIntermediateLayerDescriptor,
FFNIntermediatePruningMixIn,
)
from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn
from ....pruning.pruning_mixin import PruningMixIn
from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory
from ...puzzformer.no_op import MatchingZeros, Same

__all__ = ["NemotronHV2FFNIntermediateLayerDescriptor", "NemotronHV2ModelDescriptor"]
__all__ = [
"NemotronHV2FFNIntermediateLayerDescriptor",
"NemotronHV2KVHeadsLayerDescriptor",
"NemotronHV2ModelDescriptor",
]


def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]:
@@ -69,6 +74,15 @@ class NemotronHV2FFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):
linear_weight_names: List[str] = field(default_factory=lambda: ["down_proj", "up_proj"])


@dataclass
class NemotronHV2KVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
o_proj_name: str = "mixer.o_proj"
attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer"
qkvo_weight_names: List[str] = field(
default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
)


@ModelDescriptorFactory.register_decorator("nemotron_h_v2")
class NemotronHV2ModelDescriptor(ModelDescriptor):
_DECODER_LAYER_CLS: Type[nn.Module] = None
@@ -251,5 +265,6 @@ def pruning_mixins() -> Dict[str, PruningMixIn]:
"ffn_intermediate": FFNIntermediatePruningMixIn(
NemotronHV2FFNIntermediateLayerDescriptor()
),
"kv_heads": KVHeadsPruningMixIn(NemotronHV2KVHeadsLayerDescriptor()),
# TODO: Add expert removal support when ExpertRemovalPruningMixIn is migrated
}
@@ -26,9 +26,13 @@
)

from ....block_config import BlockConfig
from ....pruning.expert_removal_pruning_mixin import ExpertRemovalLayerDescriptor
from ....pruning.expert_removal_pruning_mixin import (
ExpertRemovalLayerDescriptor,
ExpertRemovalPruningMixIn,
)
from ....pruning.ffn_intermediate_pruning_mixin import FFNIntermediateLayerDescriptor
from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor
from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn
from ....pruning.pruning_mixin import PruningMixIn
from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory
from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size

@@ -56,6 +60,13 @@ def get_language_model_config(config):
"""Qwen3-VL has nested text_config for language model parameters."""
return config.text_config if hasattr(config, "text_config") else config

@staticmethod
def pruning_mixins() -> Dict[str, PruningMixIn]:
return {
"experts_removal": ExpertRemovalPruningMixIn(Qwen3VLExpertRemovalLayerDescriptor()),
"kv_heads": KVHeadsPruningMixIn(Qwen3VLKVHeadsLayerDescriptor()),
}

@staticmethod
def decoder_layer_cls():
return Qwen3VLMoeTextDecoderLayer
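The nested-config accessor above is easiest to see with a small illustration (not from this PR). The stand-in configs below are hypothetical, but the helper body is the one shown in the diff.

# Hedged illustration: Qwen3-VL style configs keep language-model fields under
# text_config, while plain text models keep them at top level; the accessor
# lets downstream code read attributes uniformly.
from types import SimpleNamespace

vl_config = SimpleNamespace(text_config=SimpleNamespace(num_key_value_heads=8))
plain_config = SimpleNamespace(num_key_value_heads=8)


def get_language_model_config(config):
    return config.text_config if hasattr(config, "text_config") else config


assert get_language_model_config(vl_config).num_key_value_heads == 8
assert get_language_model_config(plain_config).num_key_value_heads == 8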
24 changes: 24 additions & 0 deletions modelopt/torch/puzzletron/bypass_distillation/__init__.py
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Bypass distillation (blockwise local distillation) for the PUZZLE framework.

This module implements Stage 1 of the PUZZLE pipeline: training alternative transformer
block configurations using per-block knowledge distillation from a teacher model.
"""

from .training_loop import launch_bypass_distillation

__all__ = ["launch_bypass_distillation"]