Merged
Commits (20)
73948a9 Support transformers 4.x and 5.x simultaneously (jlamypoirier, Apr 1, 2026)
1c2ceb0 Add transformers 5.x support to external models while preserving 4.x … (jlamypoirier, Apr 2, 2026)
9265577 Fix MTP Llama converter bugs and hidden state collection (jlamypoirier, Apr 9, 2026)
284b787 Switch testing tokenizer from santacoder to gpt2 (jlamypoirier, Apr 9, 2026)
4fd678a Simplify transformers v5 compat; rename _TRANSFORMERS_V5 → _TRANSFORM… (jlamypoirier, Apr 9, 2026)
93b3485 Replace W-object path chaining with explicit W() calls in plan.py (jlamypoirier, Apr 9, 2026)
912b71c Restore loop structure in plan.py; use prefix tuples only at layer init (jlamypoirier, Apr 9, 2026)
563ed77 fix (jlamypoirier, Apr 9, 2026)
c51812c Fix mtp llama test (jlamypoirier, Apr 9, 2026)
41273d4 Fix transformers v5 compat in apriel2 and mtp_llama external models (jlamypoirier, Apr 22, 2026)
a091518 misc (jlamypoirier, Apr 22, 2026)
94ecb21 Revert "misc" (jlamypoirier, Apr 22, 2026)
1b91a67 Fix transformers v5 compat in apriel2 external model tests (jlamypoirier, Apr 23, 2026)
7e2cc2d Delete intermediate roundtrip checkpoints eagerly to reduce peak disk… (jlamypoirier, Apr 23, 2026)
58dfa6f Add HF roundtrip tests; fix Mixtral/Qwen2/Apriel2 converter bugs (jlamypoirier, Apr 24, 2026)
5e449ba Fix LLaVA plan.py: checkpoint keys are same for transformers 4.x and 5.x (jlamypoirier, Apr 24, 2026)
ccba6f5 Fix TestGDNEquivalence for both transformers 4.x and 5.x (jlamypoirier, Apr 24, 2026)
995aa59 Pin transformers to v5 in Dockerfile (jlamypoirier, Apr 24, 2026)
5933638 Fix preparator test expected values broken by bad merge with main (jlamypoirier, Apr 24, 2026)
6630192 Merge branch 'main' into jlp_transformers_v5 (jlamypoirier, Apr 24, 2026)
2 changes: 1 addition & 1 deletion Dockerfile
@@ -39,7 +39,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/
COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/

# Install dependencies within the virtual environment.
RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,STREAMING,DEV]" triton==3.5.1
RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,STREAMING,DEV]" triton==3.5.1 "transformers>=5.0.0,<6.0.0"

# Copy the remaining source code with universal write permissions.
COPY --chmod=777 ./Megatron-LM Megatron-LM

35 changes: 24 additions & 11 deletions fast_llm/engine/inference/config.py
@@ -1,4 +1,5 @@
import copy
import dataclasses
import logging
import os
import pathlib
@@ -12,20 +13,32 @@

logger = logging.getLogger(__name__)

_TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig)


class HuggingfaceModelConfig(transformers.PretrainedConfig):
model_type = "fast_llm"
model_config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig
fast_llm_config: FastLLMModelConfig | None = None
use_cache: bool = True

def __init__(self, fast_llm_config: FastLLMModelConfig | None = None, **kwargs):
def __post_init__(self, **kwargs):
# Needed for `to_diff_dict` (`__repr__`)
if fast_llm_config is None:
fast_llm_config = self.model_config_class()
self.fast_llm_config = fast_llm_config
self.use_cache = kwargs.pop("use_cache", True)
super().__init__(**kwargs)
if self.torch_dtype is not None:
assert self.torch_dtype == self.fast_llm_config.distributed.compute_dtype.torch
if self.fast_llm_config is None:
self.fast_llm_config = self.model_config_class()
super().__post_init__(**kwargs)
if self.dtype is not None:
assert self.dtype == self.fast_llm_config.distributed.compute_dtype.torch

if _TRANSFORMERS_V4:

def __init__(self, fast_llm_config: FastLLMModelConfig | None = None, **kwargs):
# Needed for `to_diff_dict` (`__repr__`)
self.fast_llm_config = fast_llm_config if fast_llm_config is not None else self.model_config_class()
self.use_cache = kwargs.pop("use_cache", True)
super().__init__(**kwargs)
if self.torch_dtype is not None:
assert self.torch_dtype == self.fast_llm_config.distributed.compute_dtype.torch

def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs) -> None:
# Hack the method to save at the right place.
@@ -88,9 +101,9 @@ def _get_config_dict(
)
metadata = cls.model_config_class.load_metadata(pretrained)
updates = {}
torch_dtype = kwargs.pop("torch_dtype", None)
if torch_dtype is not None:
updates[("distributed", "compute_dtype")] = torch_dtype
dtype = kwargs.pop("dtype", None) or kwargs.pop("torch_dtype", None) # torch_dtype: transformers v4
if dtype is not None:
updates[("distributed", "compute_dtype")] = dtype
fast_llm_config = cls.model_config_class.from_metadata(
pretrained, metadata, default=kwargs.pop("fast_llm_config", None), updates=updates
)
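
For context, a minimal standalone sketch of the version handling introduced above. It restates the probe and the kwarg rename from the diff in isolation; the names mirror the ones in config.py and nothing here is new API:

import dataclasses

import transformers

# transformers 5.x turns PretrainedConfig into a dataclass, so checking for dataclass-ness
# distinguishes the two majors without parsing version strings (same probe as _TRANSFORMERS_V4 above).
_TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig)

# transformers 4.x passes the compute dtype as torch_dtype=..., 5.x renames it to dtype=...;
# _get_config_dict above pops whichever key is present and routes it to ("distributed", "compute_dtype").
dtype_kwarg = "torch_dtype" if _TRANSFORMERS_V4 else "dtype"
print(f"pass the compute dtype as {dtype_kwarg}=... when loading a pretrained config")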

12 changes: 9 additions & 3 deletions fast_llm/engine/inference/huggingface.py
@@ -11,13 +11,19 @@
from fast_llm.core.distributed import broadcast, broadcast_object, safe_barrier
from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.inference.config import HuggingfaceModelConfig
from fast_llm.engine.inference.config import _TRANSFORMERS_V4, HuggingfaceModelConfig
from fast_llm.engine.inference.runner import InferenceRunner
from fast_llm.engine.multi_stage.config import StageMode
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.engine.schedule.runner import ScheduleRunner
from fast_llm.utils import Assert

if _TRANSFORMERS_V4:
from transformers.modeling_utils import no_init_weights as transformers_no_init_weights
else:
from transformers.initialization import no_init_weights as transformers_no_init_weights


logger = logging.getLogger(__name__)


@@ -38,7 +44,7 @@ def __init__(
**kwargs,
):
if config is None:
config = self.config_class(fast_llm_model.config)
config = self.config_class(fast_llm_config=fast_llm_model.config)

assert self.runner_class.model_class.config_class is config.model_config_class
assert config.fast_llm_config is fast_llm_model.config
@@ -70,7 +76,7 @@ def __init__(
# Transformers needs to be able to inspect the base model.
self.fast_llm_base_model = fast_llm_model.base_model

with transformers.modeling_utils.no_init_weights():
with transformers_no_init_weights():
self.post_init()

if fast_llm_model.config.multi_stage.zero_stage == 3:

17 changes: 11 additions & 6 deletions fast_llm/models/gpt/conversion/apriel2.py
@@ -8,7 +8,7 @@
from fast_llm.engine.checkpoint.external import WeightConverter
from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler
from fast_llm.layers.attention.config import AttentionConfig
from fast_llm.layers.decoder.config import DecoderBlockConfig, StochasticMixerConfig
from fast_llm.layers.decoder.config import DecoderBlockConfig, StochasticMixerConfig, StochasticMixerSamplingStrategy
from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig
from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig
from fast_llm.models.gpt.conversion.config import Apriel2TextCheckpointFormat
@@ -79,8 +79,9 @@ def export_config(cls, config: AttentionConfig) -> dict:
"type": rotary_type,
"theta": config.rotary.theta,
},
"window_size": config.window_size,
}
if config.window_size is not None:
result["window_size"] = config.window_size
# Export per-layer bias configuration
# Only include if explicitly set (not None)
if config.query_layer.bias.enabled is not None:
@@ -91,8 +92,10 @@ def import_config(cls, config: dict) -> dict:
result["value_layer"] = {"bias": {"enabled": config.value_layer.bias.enabled}}
if config.dense_layer.bias.enabled is not None:
result["dense_layer"] = {"bias": {"enabled": config.dense_layer.bias.enabled}}
# add_linear_biases as fallback default
result["add_linear_biases"] = config.add_linear_biases
# add_linear_biases as fallback default; omit when True (the Fast-LLM default) to avoid
# round-trip inflation on configs that don't set it explicitly.
if not config.add_linear_biases:
result["add_linear_biases"] = config.add_linear_biases
return result

@classmethod
@@ -491,12 +494,14 @@ def export_config(cls, config: StochasticMixerConfig) -> dict:
else:
raise ValueError(f"Unknown sub-mixer type: {mixer_type}")

return {
result = {
"type": "stochastic",
"mixers": mixers,
"main_mixer_name": config.main_mixer_name,
"sampling_strategy": config.sampling_strategy.value,
}
if config.sampling_strategy != StochasticMixerSamplingStrategy.uniform:
result["sampling_strategy"] = config.sampling_strategy.value
return result

@classmethod
def get_converters(
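
A rough sketch of the round-trip intent behind the omissions above (window_size, add_linear_biases, and sampling_strategy are only exported when they differ from the defaults). The converter and config objects below are placeholders, not names from this PR:

# Hypothetical illustration, not the project's actual test.
exported = converter.export_config(attention_config)   # attention config left at defaults
assert "window_size" not in exported                    # only emitted when explicitly set
assert "add_linear_biases" not in exported              # omitted when True, the Fast-LLM default
reimported = converter.import_config(exported)
# reimported should describe the same attention settings, with no spurious overrides
# introduced purely by the export/import round trip.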

98 changes: 58 additions & 40 deletions fast_llm/models/gpt/conversion/llama.py
@@ -1,7 +1,9 @@
import dataclasses
import logging
import typing

import torch
import transformers

from fast_llm.engine.checkpoint.config import CheckpointFormat
from fast_llm.engine.checkpoint.external import (
@@ -30,6 +32,8 @@
from fast_llm.tensor import SafeTensorSlice
from fast_llm.utils import Assert, div, safe_merge_dicts

_TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig)

logger = logging.getLogger(__name__)


@@ -188,32 +192,37 @@ def import_weight(
class LlamaAttentionConverter:
@classmethod
def import_config(cls, config: dict) -> dict:
try:
rope_type = config["rope_scaling"]["rope_type"]
except (KeyError, TypeError):
rope_type = "default"
rotary_config = {
"type": rope_type,
"theta": config["rope_theta"],
}
# Normalize rope params to a single dict before dispatching on rope_type.
# transformers v5 consolidates rope_theta + rope_scaling into rope_parameters.
# transformers v4: rope_theta at top level, rope_scaling dict for non-default types.
# Note: detection is on checkpoint format, not transformers version — old checkpoints
# remain loadable with v5 transformers.
if "rope_parameters" in config: # transformers v5
rope_params = config["rope_parameters"]
rope_theta = rope_params["rope_theta"]
else: # transformers v4
rope_params = config.get("rope_scaling") or {}
rope_theta = config["rope_theta"]
rope_type = rope_params.get("rope_type", "default")
rotary_config = {"type": rope_type, "theta": rope_theta}
if rope_type == "default":
pass
elif rope_type == "llama3":
rotary_config.update(
{
"scale_factor": config["rope_scaling"]["factor"],
"low_frequency_factor": config["rope_scaling"]["low_freq_factor"],
"high_frequency_factor": config["rope_scaling"]["high_freq_factor"],
"original_context_length": config["rope_scaling"]["original_max_position_embeddings"],
"scale_factor": rope_params["factor"],
"low_frequency_factor": rope_params["low_freq_factor"],
"high_frequency_factor": rope_params["high_freq_factor"],
"original_context_length": rope_params["original_max_position_embeddings"],
}
)
elif rope_type == "yarn":
rotary_config.update(
{
"attention_factor": config["rope_scaling"]["attention_factor"],
"beta_fast": config["rope_scaling"]["beta_fast"],
"beta_slow": config["rope_scaling"]["beta_slow"],
"original_context_length": config["rope_scaling"]["original_max_position_embeddings"],
"attention_factor": rope_params["attention_factor"],
"beta_fast": rope_params["beta_fast"],
"beta_slow": rope_params["beta_slow"],
"original_context_length": rope_params["original_max_position_embeddings"],
}
)
else:
@@ -235,36 +244,45 @@ def import_config(cls, config: dict) -> dict:
def export_config(cls, config: AttentionConfig) -> dict:
cls._check_config(config)
Assert.eq(config.softmax_scale_power, 0.5)
out = {
"num_attention_heads": config.heads,
"num_key_value_heads": config.head_groups,
"head_dim": config.head_size,
"attention_bias": config.add_linear_biases,
"attention_dropout": config.dropout,
"rope_theta": config.rotary.theta,
}
rope_parameters = {"rope_theta": config.rotary.theta}
if type(config.rotary) is DefaultRotaryConfig:
pass
rope_parameters["rope_type"] = "default"
elif type(config.rotary) is Llama3RotaryConfig:
out["rope_scaling"] = {
"rope_type": "llama3",
"factor": config.rotary.scale_factor,
"low_freq_factor": config.rotary.low_frequency_factor,
"high_freq_factor": config.rotary.high_frequency_factor,
"original_max_position_embeddings": config.rotary.original_context_length,
}
rope_parameters.update(
{
"rope_type": "llama3",
"factor": config.rotary.scale_factor,
"low_freq_factor": config.rotary.low_frequency_factor,
"high_freq_factor": config.rotary.high_frequency_factor,
"original_max_position_embeddings": config.rotary.original_context_length,
}
)
elif type(config.rotary) is YarnRotaryConfig:
out["rope_scaling"] = {
"rope_type": "yarn",
"attention_factor": config.rotary.attention_factor,
"beta_fast": config.rotary.beta_fast,
"beta_slow": config.rotary.beta_slow,
"original_max_position_embeddings": config.rotary.original_context_length,
}
rope_parameters.update(
{
"rope_type": "yarn",
"attention_factor": config.rotary.attention_factor,
"beta_fast": config.rotary.beta_fast,
"beta_slow": config.rotary.beta_slow,
"original_max_position_embeddings": config.rotary.original_context_length,
}
)
else:
raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}")

return out
common = {
"num_attention_heads": config.heads,
"num_key_value_heads": config.head_groups,
"head_dim": config.head_size,
"attention_bias": config.add_linear_biases,
"attention_dropout": config.dropout,
}
if _TRANSFORMERS_V4:
out = {**common, "rope_theta": rope_parameters["rope_theta"]}
if type(config.rotary) is not DefaultRotaryConfig:
out["rope_scaling"] = {k: v for k, v in rope_parameters.items() if k != "rope_theta"}
return out
return {**common, "rope_parameters": rope_parameters}

@classmethod
def _check_config(cls, config: AttentionConfig) -> None:
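
For reference, a sketch of the two checkpoint shapes the importer above normalizes. The field names follow the diff; the numeric values are placeholders only:

# transformers 4.x style: rope_theta at the top level, rope_scaling only for non-default types.
config_v4 = {
    "rope_theta": 500000.0,
    "rope_scaling": {
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
}

# transformers 5.x style: everything consolidated under rope_parameters.
config_v5 = {
    "rope_parameters": {
        "rope_type": "llama3",
        "rope_theta": 500000.0,
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
}

# Either shape should import to the same Fast-LLM rotary config:
# {"type": "llama3", "theta": 500000.0, "scale_factor": 8.0, "low_frequency_factor": 1.0,
#  "high_frequency_factor": 4.0, "original_context_length": 8192}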

5 changes: 4 additions & 1 deletion fast_llm/models/gpt/conversion/mixtral.py
@@ -18,6 +18,7 @@
class MixtralMLPConverter(LlamaMLPConverter):
@classmethod
def import_config(cls, config: dict) -> dict:
config["mlp_bias"] = False
return safe_merge_dicts(
super().import_config(config),
{
@@ -31,8 +32,10 @@ def import_config(cls, config: dict) -> dict:
def export_config(cls, config: MoEMLPConfig) -> dict:
Assert.custom(isinstance, config, MoEMLPConfig)
assert not config.add_linear_biases
out = super().export_config(config)
del out["mlp_bias"]
return safe_merge_dicts(
super().export_config(config),
out,
{
"num_local_experts": config.experts,
"num_experts_per_tok": config.experts_per_token,

6 changes: 3 additions & 3 deletions fast_llm/models/gpt/conversion/mtp_llama.py
@@ -58,7 +58,7 @@ def get_converters(
converters += cls.block_converter_class.get_converters(
config.decoder.last_block_config,
f"multi_token_prediction.blocks.{prediction_distance-2}",
f"model.mtp_heads.{prediction_distance - 1}",
f"model.mtp_heads.{prediction_distance - 2}",
)
converters += cls.normalization_converter_class.get_converters(
config.head.normalization,
@@ -73,7 +73,7 @@ class MTPLlamaDecoderConverter(LlamaDecoderConverter):
def import_config(cls, config: dict) -> dict:
return {
"block": cls.block_converter_class.import_config(config),
"num_blocks": config["num_hidden_layers"] - 1,
"num_blocks": config["num_hidden_layers"],
}

@classmethod
@@ -82,7 +82,7 @@ def export_config(cls, config: FixedBlockSequenceConfig) -> dict:
Assert.custom(isinstance, config, FixedBlockSequenceConfig)
return safe_merge_dicts(
cls.block_converter_class.export_config(config.block),
{"num_hidden_layers": config.num_blocks + 1},
{"num_hidden_layers": config.num_blocks},
)



25 changes: 25 additions & 0 deletions fast_llm/models/gpt/conversion/qwen2.py
@@ -3,7 +3,9 @@
from fast_llm.engine.checkpoint.config import CheckpointFormat
from fast_llm.engine.checkpoint.external import WeightConverter
from fast_llm.layers.attention.config import AttentionConfig
from fast_llm.layers.block.config import FixedBlockSequenceConfig
from fast_llm.layers.decoder.mlp.config import MLPConfig
from fast_llm.models.gpt.config import GPTBaseModelConfig
from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat
from fast_llm.models.gpt.conversion.llama import (
KeyValueWeightConverter,
@@ -37,6 +39,9 @@
def export_config(cls, config: AttentionConfig) -> dict:
out = super().export_config(config)
del out["attention_bias"]
# Qwen2Config does not have head_dim as a standard field; it is always
# derivable as hidden_size // num_attention_heads.
del out["head_dim"]
return out

@classmethod
@@ -118,6 +123,26 @@ class Qwen2BaseModelConverter(LlamaBaseModelConverter):
decoder_converter_class: typing.ClassVar[type[Qwen2DecoderConverter]] = Qwen2DecoderConverter
head_converter_class: typing.ClassVar[type[Qwen2HeadConverter]] = Qwen2HeadConverter

@classmethod
def import_config(cls, config: dict) -> dict:
assert config.get("use_mrope") is not True, "MRoPE (use_mrope=True) is not supported by the Qwen2 converter"
return super().import_config(config)

@classmethod
def export_config(cls, config: GPTBaseModelConfig) -> dict:
block = (
config.decoder.block
if isinstance(config.decoder, FixedBlockSequenceConfig)
else next(iter(config.decoder.blocks.values()))
)
if isinstance(block.mixer, AttentionConfig):
Assert.eq(
block.mixer.heads * block.mixer.head_size,
config.hidden_size,
msg="Qwen2 format omits head_dim; requires heads * head_size == hidden_size",
)
return super().export_config(config)


class Qwen2HuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler):
format: typing.ClassVar[type[CheckpointFormat]] = Qwen2CheckpointFormat
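
A quick worked example of the constraint enforced above, with illustrative numbers only: per the comment in the diff, Qwen2 configs carry no standard head_dim field, so the exporter can only represent geometries where the head size equals hidden_size // num_attention_heads.

# Illustrative numbers, not taken from any real checkpoint.
hidden_size = 4096
heads = 32

head_size = 128
assert heads * head_size == hidden_size   # exportable: the implied head_dim is 4096 // 32 == 128

head_size = 64
# heads * head_size == 2048 != hidden_size, so export_config fails the Assert.eq above,
# because the Qwen2 format has nowhere to record the non-standard head size.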