From 9fa4c46d367bbe63ebb1b641813a2f0e56fc8021 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 12 Nov 2025 21:15:33 +0000 Subject: [PATCH 01/14] activation distillation: first draft --- fast_llm/layers/block/config.py | 5 +++ fast_llm/layers/decoder/block.py | 33 +++++++++++++++---- fast_llm/layers/language_model/config.py | 8 +++++ fast_llm/layers/language_model/head.py | 33 +++++++++++++++++-- .../layers/language_model/language_model.py | 3 ++ fast_llm/models/gpt/model.py | 23 +++++++++++-- 6 files changed, 95 insertions(+), 10 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index f3e93ede..2fb2fedf 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -37,6 +37,11 @@ class BlockKwargs: sequence_lengths = "sequence_lengths" # TODO: Belongs elsewhere? grad_output = "grad_output" + root = "_root_kwargs" + activation_distillation_storage = "activation_distillation_storage" + activation_distillation_targets = "activation_distillation_targets" + activation_distillation_total = "activation_distillation_total" + activation_distillation_count = "activation_distillation_count" @config_class(registry=True) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 8b19db66..5081457b 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -3,6 +3,7 @@ import typing import torch +import torch.nn.functional as F from fast_llm.core.distributed import set_generator from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig @@ -14,6 +15,7 @@ from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -136,13 +138,32 @@ def forward( if self._debug.enabled: self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = self.mixer(hidden_states, kwargs) - if self._debug.enabled: - self._debug( - hidden_states if bias is None else hidden_states + bias, - "mixer output", - kwargs[BlockKwargs.hidden_dims], - kwargs, + mixer_output = hidden_states if bias is None else hidden_states + bias + root_kwargs = kwargs.get(BlockKwargs.root, kwargs) + # Teacher populates mixer activations for distillation. + activation_storage = root_kwargs.get(BlockKwargs.activation_distillation_storage) + if activation_storage is not None: + activation_storage[self.module_name] = mixer_output.detach() + # Student gets teacher activations and computes the activation-level loss. + activation_targets = root_kwargs.get(BlockKwargs.activation_distillation_targets) + if ( + activation_targets is not None + and self.training + and (teacher_output := activation_targets.pop(self.module_name, None)) is not None + ): + # Compare student mixer output with the teacher’s stored activation and accumulate the loss. 
+ teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype) + Assert.eq(teacher_tensor.shape, mixer_output.shape) + activation_loss = F.mse_loss(mixer_output, teacher_tensor) + activation_total = root_kwargs.get(BlockKwargs.activation_distillation_total) + root_kwargs[BlockKwargs.activation_distillation_total] = ( + activation_loss if activation_total is None else activation_total + activation_loss + ) + root_kwargs[BlockKwargs.activation_distillation_count] = ( + root_kwargs.get(BlockKwargs.activation_distillation_count, 0) + 1 ) + if self._debug.enabled: + self._debug(mixer_output, "mixer output", kwargs[BlockKwargs.hidden_dims], kwargs) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) if self._debug.enabled: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 25fa2d91..6ede28b9 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -162,6 +162,12 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Factor to scale the distillation loss by when using distillation.", hint=FieldHint.feature, ) + activation_distillation_factor: float = Field( + default=0.0, + desc="Factor to scale the activation-level distillation loss by when using distillation.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) logits_scale_factor: float = Field( default=1.0, desc="Multiply output logits by scale factor.", @@ -233,6 +239,8 @@ def _validate(self) -> None: self.language_model_loss_factor = 0.0 super()._validate() assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both + if self.activation_distillation_factor > 0.0 and self.distillation_model is None: + raise ValueError("Activation distillation requires a distillation_model to be configured.") @property def max_prediction_distance(self) -> int: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 4b0e3d10..a6ca0545 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -18,7 +18,7 @@ from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block -from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.block.config import BlockDimNames, BlockKwargs from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( @@ -419,6 +419,19 @@ def _logits_cross_entropy_forward_backward( else: distillation_loss, distillation_grad = None, None + activation_loss = None + root_kwargs = kwargs.get(BlockKwargs.root, kwargs) + activation_total = root_kwargs.get(BlockKwargs.activation_distillation_total) + activation_count = root_kwargs.get(BlockKwargs.activation_distillation_count, 0) + if activation_total is not None and activation_count and self._config.activation_distillation_factor > 0.0: + activation_loss = (activation_total / activation_count) * self._config.activation_distillation_factor + if losses is not None and self._activation_distillation_loss_name in losses: + losses[self._activation_distillation_loss_name].append(activation_loss.detach()) + # Activation targets are no longer needed past this point. 
+ root_kwargs.pop(BlockKwargs.activation_distillation_targets, None) + root_kwargs.pop(BlockKwargs.activation_distillation_total, None) + root_kwargs.pop(BlockKwargs.activation_distillation_count, None) + # TODO: de-allocate earlier. del logits @@ -426,7 +439,7 @@ def _logits_cross_entropy_forward_backward( grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) # TODO: Return individual losses? - loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) + loss = _add_tensors(dpo_loss, lm_loss, distillation_loss, activation_loss) if self.training and losses is not None: if dpo_loss is not None: losses[self._dpo_loss_name].append(dpo_loss.detach()) @@ -472,6 +485,13 @@ def _distillation_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name + @functools.cached_property + def _activation_distillation_loss_name(self) -> str: + name = "activation_distillation_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] if self._config.logit_z_loss: @@ -500,6 +520,15 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: ) ) + if self._config.activation_distillation_factor > 0.0 and self._config.distillation_model is not None: + loss_defs.append( + LossDef( + name=self._activation_distillation_loss_name, + formatted_name=_format_name(self._activation_distillation_loss_name), + count=count, + ) + ) + return loss_defs diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py index 2e46bb57..4bfe34e4 100644 --- a/fast_llm/layers/language_model/language_model.py +++ b/fast_llm/layers/language_model/language_model.py @@ -8,6 +8,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.block import BlockBase +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding @@ -59,6 +60,8 @@ def get_layers(self) -> list[Layer]: return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers() def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + # Seed a shared root pointer so nested layers (including namespaced ones) can exchange activation distillation state. + kwargs.setdefault(BlockKwargs.root, kwargs) # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? 
self.embeddings.preprocess(batch, kwargs) self.decoder.preprocess(batch, kwargs) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index efa348ec..8df83c2a 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,7 +10,7 @@ from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.block.config import BlockDimNames, BlockKwargs from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig @@ -176,6 +176,8 @@ def preprocess_batch( ) reference_logits = [{} for _ in preprocessed_meta] + distillation_model = getattr(self._config.head, "distillation_model", None) + activation_distillation_factor = getattr(self._config.head, "activation_distillation_factor", 0.0) for name, reference_model in self._reference_models.items(): reference_preprocessed_meta = [ (tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta @@ -188,8 +190,19 @@ def preprocess_batch( # TODO: Do things work with >1? Assert.eq(len(reference_batch), len(preprocessed_meta), 1) for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch): + if ( + phase != PhaseType.inference + and name == distillation_model + and activation_distillation_factor > 0.0 + ): + reference_kwargs[BlockKwargs.activation_distillation_storage] = {} reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] + if BlockKwargs.activation_distillation_storage in reference_kwargs: + reference_logits[i][f"{name}_activations"] = reference_kwargs[ + BlockKwargs.activation_distillation_storage + ] + del reference_kwargs[BlockKwargs.activation_distillation_storage] token_ids = batch.token_ids if sequence_first: @@ -255,7 +268,13 @@ def preprocess_batch( kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) kwargs[LanguageModelKwargs.labels] = labels - kwargs.update(reference_logits[i]) + reference_payload = reference_logits[i] + kwargs.update(reference_payload) + + if distillation_model is not None and activation_distillation_factor > 0.0: + teacher_key = f"{distillation_model}_activations" + if teacher_key in reference_payload: + kwargs[BlockKwargs.activation_distillation_targets] = reference_payload.pop(teacher_key) if batch.chosen_spans is not None: chosen_valid_spans = [] From 11708ff5713bda914cb32b6cb458a02b895cbc41 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 12 Nov 2025 23:38:48 +0000 Subject: [PATCH 02/14] fix kwargs --- fast_llm/engine/base_model/base_model.py | 1 + .../layers/language_model/language_model.py | 21 ++++++++++-- fast_llm/models/gpt/model.py | 34 ++++++++++--------- 3 files changed, 38 insertions(+), 18 deletions(-) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 5df59d4c..106bea21 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -145,6 +145,7 @@ def preprocess_batch( phase: PhaseType, iteration: int, metrics: dict | None = None, + setup_activation_storage: bool = False, ) -> list[tuple[torch.Tensor, dict]]: # TODO Move 
batch splitting elsewhere, align interface with LayerBase pass diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py index 4bfe34e4..579eb531 100644 --- a/fast_llm/layers/language_model/language_model.py +++ b/fast_llm/layers/language_model/language_model.py @@ -60,8 +60,25 @@ def get_layers(self) -> list[Layer]: return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers() def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - # Seed a shared root pointer so nested layers (including namespaced ones) can exchange activation distillation state. - kwargs.setdefault(BlockKwargs.root, kwargs) + # TODO: remove root_kwargs + activation_factor = getattr(self.head._config, "activation_distillation_factor", 0.0) + if ( + activation_factor > 0.0 + or BlockKwargs.activation_distillation_targets in kwargs + or BlockKwargs.activation_distillation_storage in kwargs + ): + root_state = kwargs.get(BlockKwargs.root) + if root_state is None or root_state is kwargs: + root_state = {} + kwargs[BlockKwargs.root] = root_state + if BlockKwargs.activation_distillation_targets in kwargs: + root_state[BlockKwargs.activation_distillation_targets] = kwargs[ + BlockKwargs.activation_distillation_targets + ] + if BlockKwargs.activation_distillation_storage in kwargs: + root_state[BlockKwargs.activation_distillation_storage] = kwargs[ + BlockKwargs.activation_distillation_storage + ] # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? self.embeddings.preprocess(batch, kwargs) self.decoder.preprocess(batch, kwargs) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 8df83c2a..25affc49 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -157,6 +157,7 @@ def preprocess_batch( phase: PhaseType, iteration: int, metrics: dict | None = None, + setup_activation_storage: bool = False, ) -> list[tuple[torch.Tensor, dict]]: # TODO Move batch splitting elsewhere, align interface with LayerBase assert self._is_setup @@ -175,27 +176,26 @@ def preprocess_batch( non_blocking=True, ) - reference_logits = [{} for _ in preprocessed_meta] distillation_model = getattr(self._config.head, "distillation_model", None) - activation_distillation_factor = getattr(self._config.head, "activation_distillation_factor", 0.0) + activation_factor = getattr(self._config.head, "activation_distillation_factor", 0.0) + reference_logits: list[dict[str, typing.Any]] | None = None + reference_logits = [{} for _ in preprocessed_meta] for name, reference_model in self._reference_models.items(): reference_preprocessed_meta = [ (tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta ] reference_batch = reference_model.fast_llm_model.base_model.preprocess_batch( - batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration + batch, + reference_preprocessed_meta, + phase=PhaseType.inference, + iteration=iteration, + setup_activation_storage=activation_factor > 0.0, ) # TODO: Do things work with >1? 
Assert.eq(len(reference_batch), len(preprocessed_meta), 1) for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch): - if ( - phase != PhaseType.inference - and name == distillation_model - and activation_distillation_factor > 0.0 - ): - reference_kwargs[BlockKwargs.activation_distillation_storage] = {} reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] if BlockKwargs.activation_distillation_storage in reference_kwargs: @@ -268,13 +268,13 @@ def preprocess_batch( kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) kwargs[LanguageModelKwargs.labels] = labels - reference_payload = reference_logits[i] - kwargs.update(reference_payload) - - if distillation_model is not None and activation_distillation_factor > 0.0: - teacher_key = f"{distillation_model}_activations" - if teacher_key in reference_payload: - kwargs[BlockKwargs.activation_distillation_targets] = reference_payload.pop(teacher_key) + if reference_logits is not None: + reference_payload = reference_logits[i] + kwargs.update(reference_payload) + if distillation_model is not None and activation_factor > 0.0: + teacher_key = f"{distillation_model}_activations" + if teacher_key in reference_payload: + kwargs[BlockKwargs.activation_distillation_targets] = reference_payload.pop(teacher_key) if batch.chosen_spans is not None: chosen_valid_spans = [] @@ -307,6 +307,8 @@ def preprocess_batch( rejected_valid_spans.append(valid_spans) kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans + if setup_activation_storage: + kwargs.setdefault(BlockKwargs.activation_distillation_storage, {}) self.preprocess(tokens, kwargs) preprocessed.append((tokens, kwargs)) From 943731090bb670406123c97a980ea4f27225e3ad Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 13 Nov 2025 19:33:22 +0000 Subject: [PATCH 03/14] remove count, add auxiliaryLoss hook --- fast_llm/layers/block/config.py | 1 - fast_llm/layers/decoder/block.py | 15 ++++++++------- fast_llm/layers/language_model/head.py | 6 ++---- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 2fb2fedf..45dbe495 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -41,7 +41,6 @@ class BlockKwargs: activation_distillation_storage = "activation_distillation_storage" activation_distillation_targets = "activation_distillation_targets" activation_distillation_total = "activation_distillation_total" - activation_distillation_count = "activation_distillation_count" @config_class(registry=True) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 5081457b..d5d87cdb 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -3,7 +3,6 @@ import typing import torch -import torch.nn.functional as F from fast_llm.core.distributed import set_generator from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig @@ -12,6 +11,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig from fast_llm.tensor import TensorMeta @@ -139,6 +139,7 @@ def forward( 
self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = self.mixer(hidden_states, kwargs) mixer_output = hidden_states if bias is None else hidden_states + bias + # TODO: wrap in method: activation_distillation root_kwargs = kwargs.get(BlockKwargs.root, kwargs) # Teacher populates mixer activations for distillation. activation_storage = root_kwargs.get(BlockKwargs.activation_distillation_storage) @@ -154,14 +155,14 @@ def forward( # Compare student mixer output with the teacher’s stored activation and accumulate the loss. teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype) Assert.eq(teacher_tensor.shape, mixer_output.shape) - activation_loss = F.mse_loss(mixer_output, teacher_tensor) + # TODO: handle sequence-first? + activation_loss = torch.mean(torch.norm(mixer_output - teacher_tensor, p=2, dim=(1, 2))) + mixer_output = AuxiliaryLoss.apply(mixer_output, activation_loss, 1.0) activation_total = root_kwargs.get(BlockKwargs.activation_distillation_total) - root_kwargs[BlockKwargs.activation_distillation_total] = ( - activation_loss if activation_total is None else activation_total + activation_loss - ) - root_kwargs[BlockKwargs.activation_distillation_count] = ( - root_kwargs.get(BlockKwargs.activation_distillation_count, 0) + 1 + activation_total = ( + activation_loss.detach() if activation_total is None else activation_total + activation_loss.detach() ) + root_kwargs[BlockKwargs.activation_distillation_total] = activation_total if self._debug.enabled: self._debug(mixer_output, "mixer output", kwargs[BlockKwargs.hidden_dims], kwargs) with set_generator(generator): diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index a6ca0545..290f5014 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -422,15 +422,13 @@ def _logits_cross_entropy_forward_backward( activation_loss = None root_kwargs = kwargs.get(BlockKwargs.root, kwargs) activation_total = root_kwargs.get(BlockKwargs.activation_distillation_total) - activation_count = root_kwargs.get(BlockKwargs.activation_distillation_count, 0) - if activation_total is not None and activation_count and self._config.activation_distillation_factor > 0.0: - activation_loss = (activation_total / activation_count) * self._config.activation_distillation_factor + if activation_total is not None and self._config.activation_distillation_factor > 0.0: + activation_loss = activation_total * self._config.activation_distillation_factor if losses is not None and self._activation_distillation_loss_name in losses: losses[self._activation_distillation_loss_name].append(activation_loss.detach()) # Activation targets are no longer needed past this point. root_kwargs.pop(BlockKwargs.activation_distillation_targets, None) root_kwargs.pop(BlockKwargs.activation_distillation_total, None) - root_kwargs.pop(BlockKwargs.activation_distillation_count, None) # TODO: de-allocate earlier. 
del logits From d3ac9646e75ea5754c9b1d8328275b28e966c203 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 13 Nov 2025 19:40:33 +0000 Subject: [PATCH 04/14] fix auxiliary loss --- fast_llm/layers/decoder/block.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index d5d87cdb..92f63872 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -157,7 +157,10 @@ def forward( Assert.eq(teacher_tensor.shape, mixer_output.shape) # TODO: handle sequence-first? activation_loss = torch.mean(torch.norm(mixer_output - teacher_tensor, p=2, dim=(1, 2))) - mixer_output = AuxiliaryLoss.apply(mixer_output, activation_loss, 1.0) + # Backward hooks + hidden_states = AuxiliaryLoss.apply(hidden_states, activation_loss, 1.0) + bias = AuxiliaryLoss.apply(bias, activation_loss, 1.0) if bias is not None else None + # Logging activation_total = root_kwargs.get(BlockKwargs.activation_distillation_total) activation_total = ( activation_loss.detach() if activation_total is None else activation_total + activation_loss.detach() From 56fc8db9d3caac25e497ea8ef3e4b03ebac598b7 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 13 Nov 2025 19:45:26 +0000 Subject: [PATCH 05/14] wrap in method --- fast_llm/layers/decoder/block.py | 58 +++++++++++++++++++------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 92f63872..6fe00006 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -138,6 +138,40 @@ def forward( if self._debug.enabled: self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = self.mixer(hidden_states, kwargs) + + self.activation_distillation_loss(hidden_states, bias, kwargs) + + if self._debug.enabled: + self._debug( + hidden_states if bias is None else hidden_states + bias, + "mixer output", + kwargs[BlockKwargs.hidden_dims], + kwargs, + ) + with set_generator(generator): + input_ = self._bias_dropout_add(hidden_states, bias, input_) + if self._debug.enabled: + self._debug(input_, "mixer residual", kwargs[BlockKwargs.hidden_dims], kwargs) + hidden_states = self.norm_2(input_) + if self._debug.enabled: + self._debug(hidden_states, "norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) + hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) + if self._debug.enabled: + self._debug( + hidden_states if bias is None else hidden_states + bias, + "MLP output", + kwargs[BlockKwargs.hidden_dims], + kwargs, + ) + with set_generator(generator): + hidden_states = self._bias_dropout_add(hidden_states, bias, input_) + if self._debug.enabled: + self._debug(None, "MLP residual", kwargs[BlockKwargs.hidden_dims], kwargs) + if self._return_input: + hidden_states = torch.stack((fw_input, hidden_states), dim=0) + return hidden_states + + def activation_distillation_loss(self, hidden_states, bias, kwargs): mixer_output = hidden_states if bias is None else hidden_states + bias # TODO: wrap in method: activation_distillation root_kwargs = kwargs.get(BlockKwargs.root, kwargs) @@ -166,30 +200,6 @@ def forward( activation_loss.detach() if activation_total is None else activation_total + activation_loss.detach() ) root_kwargs[BlockKwargs.activation_distillation_total] = activation_total - if self._debug.enabled: - self._debug(mixer_output, "mixer output", kwargs[BlockKwargs.hidden_dims], kwargs) - with set_generator(generator): - 
input_ = self._bias_dropout_add(hidden_states, bias, input_) - if self._debug.enabled: - self._debug(input_, "mixer residual", kwargs[BlockKwargs.hidden_dims], kwargs) - hidden_states = self.norm_2(input_) - if self._debug.enabled: - self._debug(hidden_states, "norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) - hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) - if self._debug.enabled: - self._debug( - hidden_states if bias is None else hidden_states + bias, - "MLP output", - kwargs[BlockKwargs.hidden_dims], - kwargs, - ) - with set_generator(generator): - hidden_states = self._bias_dropout_add(hidden_states, bias, input_) - if self._debug.enabled: - self._debug(None, "MLP residual", kwargs[BlockKwargs.hidden_dims], kwargs) - if self._return_input: - hidden_states = torch.stack((fw_input, hidden_states), dim=0) - return hidden_states def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Add marginal compute? (normalization, bias_dropout_add) From 5d75f01469bce67da86f42b87695ee2365da6ce3 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 13 Nov 2025 19:47:33 +0000 Subject: [PATCH 06/14] fixes --- fast_llm/layers/decoder/block.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 6fe00006..7b058f5f 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -139,7 +139,7 @@ def forward( self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = self.mixer(hidden_states, kwargs) - self.activation_distillation_loss(hidden_states, bias, kwargs) + hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs) if self._debug.enabled: self._debug( @@ -172,8 +172,10 @@ def forward( return hidden_states def activation_distillation_loss(self, hidden_states, bias, kwargs): + """ + Maybe apply activation distillation loss and setup backward hooks + """ mixer_output = hidden_states if bias is None else hidden_states + bias - # TODO: wrap in method: activation_distillation root_kwargs = kwargs.get(BlockKwargs.root, kwargs) # Teacher populates mixer activations for distillation. activation_storage = root_kwargs.get(BlockKwargs.activation_distillation_storage) @@ -200,6 +202,7 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs): activation_loss.detach() if activation_total is None else activation_total + activation_loss.detach() ) root_kwargs[BlockKwargs.activation_distillation_total] = activation_total + return hidden_states, bias def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Add marginal compute? 
(normalization, bias_dropout_add) From f1bfca967743ca53a6530f918bab4bd57a6a1b0c Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 14 Nov 2025 21:04:57 +0000 Subject: [PATCH 07/14] move activation distillation loss reporting to decoder block --- fast_llm/layers/decoder/block.py | 33 ++++++++++++++++--- fast_llm/layers/decoder/config.py | 16 +++++++++ fast_llm/layers/language_model/config.py | 8 ----- fast_llm/layers/language_model/head.py | 31 ++--------------- .../layers/language_model/language_model.py | 2 +- 5 files changed, 48 insertions(+), 42 deletions(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 7b058f5f..86f1e8c0 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -14,6 +14,7 @@ from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig +from fast_llm.layers.language_model.head import _format_name from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -139,7 +140,7 @@ def forward( self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = self.mixer(hidden_states, kwargs) - hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs) + hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses) if self._debug.enabled: self._debug( @@ -171,7 +172,7 @@ def forward( hidden_states = torch.stack((fw_input, hidden_states), dim=0) return hidden_states - def activation_distillation_loss(self, hidden_states, bias, kwargs): + def activation_distillation_loss(self, hidden_states, bias, kwargs, losses): """ Maybe apply activation distillation loss and setup backward hooks """ @@ -192,7 +193,12 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs): teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype) Assert.eq(teacher_tensor.shape, mixer_output.shape) # TODO: handle sequence-first? - activation_loss = torch.mean(torch.norm(mixer_output - teacher_tensor, p=2, dim=(1, 2))) + # TODO: un-scaled loss for reporting? Average loss over layers? 
+ # L2 loss + activation_loss_factor = self._config.activation_distillation_factor + activation_loss = activation_loss_factor * torch.mean( + torch.norm(mixer_output - teacher_tensor, p=2, dim=(2)) + ) # Backward hooks hidden_states = AuxiliaryLoss.apply(hidden_states, activation_loss, 1.0) bias = AuxiliaryLoss.apply(bias, activation_loss, 1.0) if bias is not None else None @@ -202,6 +208,9 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs): activation_loss.detach() if activation_total is None else activation_total + activation_loss.detach() ) root_kwargs[BlockKwargs.activation_distillation_total] = activation_total + + if losses is not None and self._activation_distillation_loss_name in losses: + losses[self._activation_distillation_loss_name].append(activation_total.detach()) return hidden_states, bias def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: @@ -217,5 +226,21 @@ def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None self.mixer.preprocess(batch, kwargs) self.mlp.preprocess(batch, kwargs) + # TODO: add layer_index + _activation_distillation_loss_name = "activation_distillation_loss" + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count) + loss_definitions = [] + if self._config.activation_distillation_factor > 0.0 and self._config.distillation_model is not None: + loss_definitions.append( + LossDef( + name=self._activation_distillation_loss_name, + formatted_name=_format_name(self._activation_distillation_loss_name), + count=count, + ) + ) + return ( + loss_definitions + + self.mixer.get_loss_definitions(count=count) + + self.mlp.get_loss_definitions(count=count) + ) diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 403b204c..99331ee7 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -88,6 +88,22 @@ class DecoderBlockConfig(BlockConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + distillation_model: str | None = Field( + default=None, + desc="Name of the reference model to use for activation-level distillation.", + hint=FieldHint.feature, + ) + activation_distillation_factor: float = Field( + default=0.0, + desc="Factor to scale the activation-level distillation loss by.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + + def _validate(self) -> None: + super()._validate() + if self.activation_distillation_factor > 0.0 and self.distillation_model is None: + raise ValueError("Activation distillation requires a distillation_model.") @property def layer_class(self) -> "type[DecoderBlock]": diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 6ede28b9..25fa2d91 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -162,12 +162,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Factor to scale the distillation loss by when using distillation.", hint=FieldHint.feature, ) - activation_distillation_factor: float = Field( - default=0.0, - desc="Factor to scale the activation-level distillation loss by when using distillation.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) logits_scale_factor: float = Field( default=1.0, desc="Multiply output logits by scale factor.", @@ -239,8 +233,6 @@ def _validate(self) -> None: 
self.language_model_loss_factor = 0.0 super()._validate() assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both - if self.activation_distillation_factor > 0.0 and self.distillation_model is None: - raise ValueError("Activation distillation requires a distillation_model to be configured.") @property def max_prediction_distance(self) -> int: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 290f5014..4b0e3d10 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -18,7 +18,7 @@ from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block -from fast_llm.layers.block.config import BlockDimNames, BlockKwargs +from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( @@ -419,17 +419,6 @@ def _logits_cross_entropy_forward_backward( else: distillation_loss, distillation_grad = None, None - activation_loss = None - root_kwargs = kwargs.get(BlockKwargs.root, kwargs) - activation_total = root_kwargs.get(BlockKwargs.activation_distillation_total) - if activation_total is not None and self._config.activation_distillation_factor > 0.0: - activation_loss = activation_total * self._config.activation_distillation_factor - if losses is not None and self._activation_distillation_loss_name in losses: - losses[self._activation_distillation_loss_name].append(activation_loss.detach()) - # Activation targets are no longer needed past this point. - root_kwargs.pop(BlockKwargs.activation_distillation_targets, None) - root_kwargs.pop(BlockKwargs.activation_distillation_total, None) - # TODO: de-allocate earlier. del logits @@ -437,7 +426,7 @@ def _logits_cross_entropy_forward_backward( grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) # TODO: Return individual losses? 
- loss = _add_tensors(dpo_loss, lm_loss, distillation_loss, activation_loss) + loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) if self.training and losses is not None: if dpo_loss is not None: losses[self._dpo_loss_name].append(dpo_loss.detach()) @@ -483,13 +472,6 @@ def _distillation_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name - @functools.cached_property - def _activation_distillation_loss_name(self) -> str: - name = "activation_distillation_loss" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] if self._config.logit_z_loss: @@ -518,15 +500,6 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: ) ) - if self._config.activation_distillation_factor > 0.0 and self._config.distillation_model is not None: - loss_defs.append( - LossDef( - name=self._activation_distillation_loss_name, - formatted_name=_format_name(self._activation_distillation_loss_name), - count=count, - ) - ) - return loss_defs diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py index 579eb531..0b8157d7 100644 --- a/fast_llm/layers/language_model/language_model.py +++ b/fast_llm/layers/language_model/language_model.py @@ -61,7 +61,7 @@ def get_layers(self) -> list[Layer]: def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: # TODO: remove root_kwargs - activation_factor = getattr(self.head._config, "activation_distillation_factor", 0.0) + activation_factor = getattr(self.decoder.config.block, "activation_distillation_factor", 0.0) if ( activation_factor > 0.0 or BlockKwargs.activation_distillation_targets in kwargs From 8b1675203dbcbbf7a0a9e6b2141ec99669166fd0 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 14 Nov 2025 21:25:02 +0000 Subject: [PATCH 08/14] fix logging --- fast_llm/layers/decoder/block.py | 10 +++------- fast_llm/models/gpt/model.py | 7 ++++--- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 86f1e8c0..9867019c 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -196,6 +196,8 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses): # TODO: un-scaled loss for reporting? Average loss over layers? # L2 loss activation_loss_factor = self._config.activation_distillation_factor + # (batch, sequence, hidden). Take the norm over hidden dim. + # TODO: handle possible padding? 
activation_loss = activation_loss_factor * torch.mean( torch.norm(mixer_output - teacher_tensor, p=2, dim=(2)) ) @@ -203,14 +205,8 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses): hidden_states = AuxiliaryLoss.apply(hidden_states, activation_loss, 1.0) bias = AuxiliaryLoss.apply(bias, activation_loss, 1.0) if bias is not None else None # Logging - activation_total = root_kwargs.get(BlockKwargs.activation_distillation_total) - activation_total = ( - activation_loss.detach() if activation_total is None else activation_total + activation_loss.detach() - ) - root_kwargs[BlockKwargs.activation_distillation_total] = activation_total - if losses is not None and self._activation_distillation_loss_name in losses: - losses[self._activation_distillation_loss_name].append(activation_total.detach()) + losses[self._activation_distillation_loss_name].append(activation_loss.detach()) return hidden_states, bias def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 25affc49..17187d0b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -176,8 +176,9 @@ def preprocess_batch( non_blocking=True, ) - distillation_model = getattr(self._config.head, "distillation_model", None) - activation_factor = getattr(self._config.head, "activation_distillation_factor", 0.0) + # TODO: decoder doesn't necessarily have a `block` attribute + distillation_model = self._config.decoder.block.distillation_model + activation_factor = self._config.decoder.block.activation_distillation_factor reference_logits: list[dict[str, typing.Any]] | None = None reference_logits = [{} for _ in preprocessed_meta] for name, reference_model in self._reference_models.items(): @@ -190,7 +191,7 @@ def preprocess_batch( reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration, - setup_activation_storage=activation_factor > 0.0, + setup_activation_storage=activation_factor > 0.0 and distillation_model == name, ) # TODO: Do things work with >1? From efa8cf0d613d865b13c8b11412ca9ae04d1ada5f Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 14 Nov 2025 22:01:07 +0000 Subject: [PATCH 09/14] remove root kwargs --- fast_llm/layers/block/config.py | 1 - fast_llm/layers/decoder/block.py | 5 ++--- .../layers/language_model/language_model.py | 20 ------------------- 3 files changed, 2 insertions(+), 24 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 45dbe495..dfc80a47 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -37,7 +37,6 @@ class BlockKwargs: sequence_lengths = "sequence_lengths" # TODO: Belongs elsewhere? 
grad_output = "grad_output" - root = "_root_kwargs" activation_distillation_storage = "activation_distillation_storage" activation_distillation_targets = "activation_distillation_targets" activation_distillation_total = "activation_distillation_total" diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 9867019c..05df6e67 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -177,13 +177,12 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses): Maybe apply activation distillation loss and setup backward hooks """ mixer_output = hidden_states if bias is None else hidden_states + bias - root_kwargs = kwargs.get(BlockKwargs.root, kwargs) # Teacher populates mixer activations for distillation. - activation_storage = root_kwargs.get(BlockKwargs.activation_distillation_storage) + activation_storage = kwargs.get(BlockKwargs.activation_distillation_storage) if activation_storage is not None: activation_storage[self.module_name] = mixer_output.detach() # Student gets teacher activations and computes the activation-level loss. - activation_targets = root_kwargs.get(BlockKwargs.activation_distillation_targets) + activation_targets = kwargs.get(BlockKwargs.activation_distillation_targets) if ( activation_targets is not None and self.training diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py index 0b8157d7..2e46bb57 100644 --- a/fast_llm/layers/language_model/language_model.py +++ b/fast_llm/layers/language_model/language_model.py @@ -8,7 +8,6 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.block import BlockBase -from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding @@ -60,25 +59,6 @@ def get_layers(self) -> list[Layer]: return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers() def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - # TODO: remove root_kwargs - activation_factor = getattr(self.decoder.config.block, "activation_distillation_factor", 0.0) - if ( - activation_factor > 0.0 - or BlockKwargs.activation_distillation_targets in kwargs - or BlockKwargs.activation_distillation_storage in kwargs - ): - root_state = kwargs.get(BlockKwargs.root) - if root_state is None or root_state is kwargs: - root_state = {} - kwargs[BlockKwargs.root] = root_state - if BlockKwargs.activation_distillation_targets in kwargs: - root_state[BlockKwargs.activation_distillation_targets] = kwargs[ - BlockKwargs.activation_distillation_targets - ] - if BlockKwargs.activation_distillation_storage in kwargs: - root_state[BlockKwargs.activation_distillation_storage] = kwargs[ - BlockKwargs.activation_distillation_storage - ] # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? 
self.embeddings.preprocess(batch, kwargs) self.decoder.preprocess(batch, kwargs) From 4cda56d69af6f226f091480584eccaa4e851e638 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 17 Nov 2025 20:43:48 +0000 Subject: [PATCH 10/14] fix mistral mlp conversion --- fast_llm/models/gpt/conversion/mistral.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index b5db3fa0..28941bc8 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -2,6 +2,7 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.models.gpt.conversion.config import MistralCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaAttentionConverter, @@ -10,6 +11,7 @@ LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, + LlamaMLPConverter, ) from fast_llm.utils import safe_merge_dicts @@ -38,8 +40,26 @@ def _check_config(cls, config: AttentionConfig) -> None: assert not config.add_linear_biases +class MistrallMLPConverter(LlamaMLPConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + config["mlp_bias"] = False + return super().import_config(config) + + @classmethod + def export_config(cls, config: MLPConfig) -> dict: + out = super().export_config(config) + del out["mlp_bias"] + return out + + @classmethod + def _check_config(cls, config: MLPConfig) -> None: + assert not config.add_linear_biases + + class MistralBlockConverter(LlamaBlockConverter): mixer_converter_class: typing.ClassVar[type[MistralAttentionConverter]] = MistralAttentionConverter + mlp_converter_class: typing.ClassVar[type[MistrallMLPConverter]] = MistrallMLPConverter class MistralDecoderConverter(LlamaDecoderConverter): From 41692e9fa2c8890231730abed464cad6e171e3bf Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 17 Nov 2025 20:49:05 +0000 Subject: [PATCH 11/14] remove duplicate from apriel conversion --- fast_llm/models/gpt/conversion/apriel.py | 25 +----------------------- 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 7550df04..ffd2522c 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -8,18 +8,12 @@ from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.layers.ssm.config import DiscreteMamba2Config, Mamba2Config from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat -from fast_llm.models.gpt.conversion.llama import ( - LlamaMLPConverter, - get_parameter_converter, - get_weight_and_bias_converters, -) +from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters from fast_llm.models.gpt.conversion.mistral import ( MistralBaseModelConverter, - MistralBlockConverter, MistralDecoderConverter, MistralHeadConverter, MistralHuggingfaceCheckpointHandler, @@ -229,23 +223,6 @@ def get_converters( ] -class AprielMLPConverter(LlamaMLPConverter): - @classmethod - def 
import_config(cls, config: dict) -> dict: - config["mlp_bias"] = False - return super().import_config(config) - - @classmethod - def export_config(cls, config: MLPConfig) -> dict: - out = super().export_config(config) - del out["mlp_bias"] - return out - - -class AprielBlockConverterBase(MistralBlockConverter): - mlp_converter_class: typing.ClassVar[type[AprielMLPConverter]] = AprielMLPConverter - - class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase): mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter hf_mixer_name: typing.ClassVar[str] = "mixer" From 99c42c06d8af13df9bfeaece1e9425de4f9932fb Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 17 Nov 2025 20:56:35 +0000 Subject: [PATCH 12/14] fix --- fast_llm/models/gpt/conversion/apriel.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index ffd2522c..e16eac4d 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -14,6 +14,7 @@ from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters from fast_llm.models.gpt.conversion.mistral import ( MistralBaseModelConverter, + MistralBlockConverter, MistralDecoderConverter, MistralHeadConverter, MistralHuggingfaceCheckpointHandler, @@ -223,12 +224,12 @@ def get_converters( ] -class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase): +class AprielDiscreteMamba2BlockConverter(MistralBlockConverter): mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter hf_mixer_name: typing.ClassVar[str] = "mixer" -class AprielMamba2BlockConverter(AprielBlockConverterBase): +class AprielMamba2BlockConverter(MistralBlockConverter): mixer_converter_class: typing.ClassVar[type[AprielMamba2Converter]] = AprielMamba2Converter hf_mixer_name: typing.ClassVar[str] = "mixer" @@ -240,7 +241,7 @@ class AprielBlockConverter: DiscreteMamba2Config: "m2d", } _converter_classes = { - AttentionConfig: AprielBlockConverterBase, + AttentionConfig: MistralBlockConverter, Mamba2Config: AprielMamba2BlockConverter, DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter, } From d3df7a567e5a8a02400d79bc4aee735ff10a0942 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 17 Nov 2025 21:12:37 +0000 Subject: [PATCH 13/14] move assert --- fast_llm/models/gpt/conversion/mistral.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index 28941bc8..a9a0909e 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -48,14 +48,11 @@ def import_config(cls, config: dict) -> dict: @classmethod def export_config(cls, config: MLPConfig) -> dict: + assert not config.add_linear_biases out = super().export_config(config) del out["mlp_bias"] return out - @classmethod - def _check_config(cls, config: MLPConfig) -> None: - assert not config.add_linear_biases - class MistralBlockConverter(LlamaBlockConverter): mixer_converter_class: typing.ClassVar[type[MistralAttentionConverter]] = MistralAttentionConverter From 8e04abaaa32e3a84387738e211e25bc96824502e Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 17 Nov 2025 21:58:00 +0000 Subject: [PATCH 14/14] remove tp-1 check for reference models --- fast_llm/engine/training/config.py | 1 - 1 file changed, 1 deletion(-) 
diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py
index 531bc206..867cca98 100644
--- a/fast_llm/engine/training/config.py
+++ b/fast_llm/engine/training/config.py
@@ -361,7 +361,6 @@ def _validate(self) -> None:
             # TODO: Add support.
             Assert.eq(self.model.distributed.pipeline_parallel, 1)
             # TODO: Check if these work.
-            Assert.eq(self.model.distributed.tensor_parallel, 1)
             Assert.eq(self.model.distributed.sequence_data_parallel, 1)
         if self.run.experiment_dir is None:
             assert not self.training.checkpoint.enabled()
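
Note: the sketch below is illustrative and not part of the patch series. It condenses the activation-distillation loss added by these commits into standalone PyTorch with placeholder tensor names; in Fast-LLM the teacher activations travel through the activation_distillation_storage / activation_distillation_targets kwargs, and the gradient is attached to the mixer output via AuxiliaryLoss.apply rather than added to the head loss.

import torch


def activation_distillation_loss(
    student_mixer_out: torch.Tensor,  # (batch, sequence, hidden) student mixer output, requires grad
    teacher_mixer_out: torch.Tensor,  # same shape, teacher activation (treated as a constant target)
    factor: float,  # corresponds to activation_distillation_factor
) -> torch.Tensor:
    teacher = teacher_mixer_out.detach().to(device=student_mixer_out.device, dtype=student_mixer_out.dtype)
    assert teacher.shape == student_mixer_out.shape
    # L2 norm over the hidden dimension, averaged over batch and sequence, then scaled,
    # mirroring the decoder-block loss introduced in this series.
    return factor * torch.norm(student_mixer_out - teacher, p=2, dim=2).mean()


if __name__ == "__main__":
    student = torch.randn(2, 8, 16, requires_grad=True)
    teacher = torch.randn(2, 8, 16)
    loss = activation_distillation_loss(student, teacher, factor=0.1)
    loss.backward()
    print(f"{loss.item():.4f}", tuple(student.grad.shape))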