diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py
index 7550df044..e16eac4de 100644
--- a/fast_llm/models/gpt/conversion/apriel.py
+++ b/fast_llm/models/gpt/conversion/apriel.py
@@ -8,15 +8,10 @@
 from fast_llm.layers.attention.config import AttentionConfig
 from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig
 from fast_llm.layers.decoder.config import DecoderBlockConfig
-from fast_llm.layers.decoder.mlp.config import MLPConfig
 from fast_llm.layers.ssm.config import DiscreteMamba2Config, Mamba2Config
 from fast_llm.models.gpt.config import GPTModelConfig
 from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat
-from fast_llm.models.gpt.conversion.llama import (
-    LlamaMLPConverter,
-    get_parameter_converter,
-    get_weight_and_bias_converters,
-)
+from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters
 from fast_llm.models.gpt.conversion.mistral import (
     MistralBaseModelConverter,
     MistralBlockConverter,
@@ -229,29 +224,12 @@ def get_converters(
         ]
 
 
-class AprielMLPConverter(LlamaMLPConverter):
-    @classmethod
-    def import_config(cls, config: dict) -> dict:
-        config["mlp_bias"] = False
-        return super().import_config(config)
-
-    @classmethod
-    def export_config(cls, config: MLPConfig) -> dict:
-        out = super().export_config(config)
-        del out["mlp_bias"]
-        return out
-
-
-class AprielBlockConverterBase(MistralBlockConverter):
-    mlp_converter_class: typing.ClassVar[type[AprielMLPConverter]] = AprielMLPConverter
-
-
-class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase):
+class AprielDiscreteMamba2BlockConverter(MistralBlockConverter):
     mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter
     hf_mixer_name: typing.ClassVar[str] = "mixer"
 
 
-class AprielMamba2BlockConverter(AprielBlockConverterBase):
+class AprielMamba2BlockConverter(MistralBlockConverter):
     mixer_converter_class: typing.ClassVar[type[AprielMamba2Converter]] = AprielMamba2Converter
     hf_mixer_name: typing.ClassVar[str] = "mixer"
 
@@ -263,7 +241,7 @@ class AprielBlockConverter:
         DiscreteMamba2Config: "m2d",
     }
     _converter_classes = {
-        AttentionConfig: AprielBlockConverterBase,
+        AttentionConfig: MistralBlockConverter,
         Mamba2Config: AprielMamba2BlockConverter,
         DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter,
     }
diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py
index b5db3fa06..a9a0909ec 100644
--- a/fast_llm/models/gpt/conversion/mistral.py
+++ b/fast_llm/models/gpt/conversion/mistral.py
@@ -2,6 +2,7 @@
 
 from fast_llm.engine.checkpoint.config import CheckpointFormat
 from fast_llm.layers.attention.config import AttentionConfig
+from fast_llm.layers.decoder.mlp.config import MLPConfig
 from fast_llm.models.gpt.conversion.config import MistralCheckpointFormat
 from fast_llm.models.gpt.conversion.llama import (
     LlamaAttentionConverter,
@@ -10,6 +11,7 @@
     LlamaDecoderConverter,
     LlamaHeadConverter,
     LlamaHuggingfaceCheckpointHandler,
+    LlamaMLPConverter,
 )
 from fast_llm.utils import safe_merge_dicts
 
@@ -38,8 +40,23 @@ def _check_config(cls, config: AttentionConfig) -> None:
         assert not config.add_linear_biases
 
 
+class MistralMLPConverter(LlamaMLPConverter):
+    @classmethod
+    def import_config(cls, config: dict) -> dict:
+        config["mlp_bias"] = False
+        return super().import_config(config)
+
+    @classmethod
+    def export_config(cls, config: MLPConfig) -> dict:
+        assert not config.add_linear_biases
+        out = super().export_config(config)
+        del out["mlp_bias"]
+        return out
+
+
 class MistralBlockConverter(LlamaBlockConverter):
     mixer_converter_class: typing.ClassVar[type[MistralAttentionConverter]] = MistralAttentionConverter
+    mlp_converter_class: typing.ClassVar[type[MistralMLPConverter]] = MistralMLPConverter
 
 
 class MistralDecoderConverter(LlamaDecoderConverter):
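
Reviewer note: below is a minimal, self-contained sketch of the import/export asymmetry that the new MistralMLPConverter encodes in this diff. StubBase and StubMLPConfig are hypothetical stand-ins for LlamaMLPConverter and MLPConfig (the real classes do more); only the mlp_bias handling mirrors the change above.

# Illustrative only; stand-in classes, not the real fast_llm converters.
from dataclasses import dataclass


@dataclass
class StubMLPConfig:
    add_linear_biases: bool = False


class StubBase:
    # Pretend the Llama-level converter simply maps mlp_bias <-> add_linear_biases.
    @classmethod
    def import_config(cls, config: dict) -> dict:
        return {"add_linear_biases": config["mlp_bias"]}

    @classmethod
    def export_config(cls, config: StubMLPConfig) -> dict:
        return {"mlp_bias": config.add_linear_biases}


class MistralLikeMLPConverter(StubBase):
    @classmethod
    def import_config(cls, config: dict) -> dict:
        # Mistral-style checkpoints carry no mlp_bias key, so force it off before delegating.
        config["mlp_bias"] = False
        return super().import_config(config)

    @classmethod
    def export_config(cls, config: StubMLPConfig) -> dict:
        # Biases are unsupported on export: refuse them, then drop the key the base emits.
        assert not config.add_linear_biases
        out = super().export_config(config)
        del out["mlp_bias"]
        return out


if __name__ == "__main__":
    assert MistralLikeMLPConverter.import_config({}) == {"add_linear_biases": False}
    assert MistralLikeMLPConverter.export_config(StubMLPConfig()) == {}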