Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 4 additions & 26 deletions fast_llm/models/gpt/conversion/apriel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,10 @@
from fast_llm.layers.attention.config import AttentionConfig
from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig
from fast_llm.layers.decoder.config import DecoderBlockConfig
from fast_llm.layers.decoder.mlp.config import MLPConfig
from fast_llm.layers.ssm.config import DiscreteMamba2Config, Mamba2Config
from fast_llm.models.gpt.config import GPTModelConfig
from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat
from fast_llm.models.gpt.conversion.llama import (
LlamaMLPConverter,
get_parameter_converter,
get_weight_and_bias_converters,
)
from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters
from fast_llm.models.gpt.conversion.mistral import (
MistralBaseModelConverter,
MistralBlockConverter,
Expand Down Expand Up @@ -229,29 +224,12 @@ def get_converters(
]


class AprielMLPConverter(LlamaMLPConverter):
    """Llama MLP converter adjusted for Apriel checkpoints, whose HF
    configs carry no ``mlp_bias`` entry."""

    @classmethod
    def import_config(cls, config: dict) -> dict:
        # The HF config omits "mlp_bias"; Apriel MLPs are bias-free, so
        # supply the implied value before delegating to the Llama importer.
        config["mlp_bias"] = False
        return super().import_config(config)

    @classmethod
    def export_config(cls, config: MLPConfig) -> dict:
        exported = super().export_config(config)
        # Strip the key the Llama exporter emits; Apriel HF configs omit it.
        del exported["mlp_bias"]
        return exported


class AprielBlockConverterBase(MistralBlockConverter):
    # Mistral block converter specialized with the Apriel MLP converter
    # (which strips the "mlp_bias" key on export).
    mlp_converter_class: typing.ClassVar[type[AprielMLPConverter]] = AprielMLPConverter


class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase):
class AprielDiscreteMamba2BlockConverter(MistralBlockConverter):
    """Block converter for Apriel decoder blocks using a discrete Mamba-2 mixer."""

    mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter
    # The HF checkpoint stores the mixer weights under the "mixer" submodule name.
    hf_mixer_name: typing.ClassVar[str] = "mixer"


class AprielMamba2BlockConverter(AprielBlockConverterBase):
class AprielMamba2BlockConverter(MistralBlockConverter):
    """Block converter for Apriel decoder blocks using a Mamba-2 mixer."""

    mixer_converter_class: typing.ClassVar[type[AprielMamba2Converter]] = AprielMamba2Converter
    # The HF checkpoint stores the mixer weights under the "mixer" submodule name.
    hf_mixer_name: typing.ClassVar[str] = "mixer"

Expand All @@ -263,7 +241,7 @@ class AprielBlockConverter:
DiscreteMamba2Config: "m2d",
}
_converter_classes = {
AttentionConfig: AprielBlockConverterBase,
AttentionConfig: MistralBlockConverter,
Mamba2Config: AprielMamba2BlockConverter,
DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter,
}
Expand Down
17 changes: 17 additions & 0 deletions fast_llm/models/gpt/conversion/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from fast_llm.engine.checkpoint.config import CheckpointFormat
from fast_llm.layers.attention.config import AttentionConfig
from fast_llm.layers.decoder.mlp.config import MLPConfig
from fast_llm.models.gpt.conversion.config import MistralCheckpointFormat
from fast_llm.models.gpt.conversion.llama import (
LlamaAttentionConverter,
Expand All @@ -10,6 +11,7 @@
LlamaDecoderConverter,
LlamaHeadConverter,
LlamaHuggingfaceCheckpointHandler,
LlamaMLPConverter,
)
from fast_llm.utils import safe_merge_dicts

Expand Down Expand Up @@ -38,8 +40,23 @@ def _check_config(cls, config: AttentionConfig) -> None:
assert not config.add_linear_biases


class MistrallMLPConverter(LlamaMLPConverter):
    """Llama MLP converter adjusted for Mistral checkpoints, whose HF
    configs carry no ``mlp_bias`` entry.

    NOTE(review): the class name looks like a typo ("Mistrall" -> "Mistral");
    renaming would require touching every reference, so it is only flagged here.
    """

    @classmethod
    def import_config(cls, config: dict) -> dict:
        # The HF config omits "mlp_bias"; Mistral MLPs are bias-free, so
        # supply the implied value before delegating to the Llama importer.
        config["mlp_bias"] = False
        return super().import_config(config)

    @classmethod
    def export_config(cls, config: MLPConfig) -> dict:
        # Mistral never uses MLP biases — fail loudly on a config that does.
        assert not config.add_linear_biases
        exported = super().export_config(config)
        # Strip the key the Llama exporter emits; Mistral HF configs omit it.
        del exported["mlp_bias"]
        return exported


class MistralBlockConverter(LlamaBlockConverter):
    # Llama block converter specialized with the Mistral attention and MLP
    # converters (both of which reject/strip linear-bias settings).
    mixer_converter_class: typing.ClassVar[type[MistralAttentionConverter]] = MistralAttentionConverter
    mlp_converter_class: typing.ClassVar[type[MistrallMLPConverter]] = MistrallMLPConverter


class MistralDecoderConverter(LlamaDecoderConverter):
Expand Down