diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py
index 5df59d4c..106bea21 100644
--- a/fast_llm/engine/base_model/base_model.py
+++ b/fast_llm/engine/base_model/base_model.py
@@ -145,6 +145,7 @@ def preprocess_batch(
         phase: PhaseType,
         iteration: int,
         metrics: dict | None = None,
+        setup_activation_storage: bool = False,
     ) -> list[tuple[torch.Tensor, dict]]:
         # TODO Move batch splitting elsewhere, align interface with LayerBase
         pass
diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py
index 531bc206..867cca98 100644
--- a/fast_llm/engine/training/config.py
+++ b/fast_llm/engine/training/config.py
@@ -361,7 +361,6 @@ def _validate(self) -> None:
             # TODO: Add support.
             Assert.eq(self.model.distributed.pipeline_parallel, 1)
             # TODO: Check if these work.
-            Assert.eq(self.model.distributed.tensor_parallel, 1)
             Assert.eq(self.model.distributed.sequence_data_parallel, 1)
         if self.run.experiment_dir is None:
             assert not self.training.checkpoint.enabled()
diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py
index f3e93ede..dfc80a47 100644
--- a/fast_llm/layers/block/config.py
+++ b/fast_llm/layers/block/config.py
@@ -37,6 +37,9 @@ class BlockKwargs:
     sequence_lengths = "sequence_lengths"
     # TODO: Belongs elsewhere?
     grad_output = "grad_output"
+    activation_distillation_storage = "activation_distillation_storage"
+    activation_distillation_targets = "activation_distillation_targets"
+    activation_distillation_total = "activation_distillation_total"


 @config_class(registry=True)
diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py
index 8b19db66..05df6e67 100644
--- a/fast_llm/layers/decoder/block.py
+++ b/fast_llm/layers/decoder/block.py
@@ -11,9 +11,12 @@
 from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.layers.block.block import Block
 from fast_llm.layers.block.config import BlockKwargs
+from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss
 from fast_llm.layers.common.peft.config import PeftConfig
 from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig
+from fast_llm.layers.language_model.head import _format_name
 from fast_llm.tensor import TensorMeta
+from fast_llm.utils import Assert

 logger = logging.getLogger(__name__)

@@ -136,6 +139,9 @@ def forward(
         if self._debug.enabled:
             self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs)
         hidden_states, bias = self.mixer(hidden_states, kwargs)
+
+        hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses)
+
         if self._debug.enabled:
             self._debug(
                 hidden_states if bias is None else hidden_states + bias,
@@ -166,6 +172,42 @@ def forward(
             hidden_states = torch.stack((fw_input, hidden_states), dim=0)
         return hidden_states

+    def activation_distillation_loss(self, hidden_states, bias, kwargs, losses):
+        """
+        Apply the activation-level distillation loss, if configured, and set up the backward hooks.
+        """
+        mixer_output = hidden_states if bias is None else hidden_states + bias
+        # Teacher populates mixer activations for distillation.
+        activation_storage = kwargs.get(BlockKwargs.activation_distillation_storage)
+        if activation_storage is not None:
+            activation_storage[self.module_name] = mixer_output.detach()
+        # Student gets teacher activations and computes the activation-level loss.
+        activation_targets = kwargs.get(BlockKwargs.activation_distillation_targets)
+        if (
+            activation_targets is not None
+            and self.training
+            and (teacher_output := activation_targets.pop(self.module_name, None)) is not None
+        ):
+            # Compare student mixer output with the teacher’s stored activation and accumulate the loss.
+            teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype)
+            Assert.eq(teacher_tensor.shape, mixer_output.shape)
+            # TODO: handle sequence-first?
+            # TODO: un-scaled loss for reporting? Average loss over layers?
+            # L2 loss
+            activation_loss_factor = self._config.activation_distillation_factor
+            # (batch, sequence, hidden). Take the norm over the hidden dim.
+            # TODO: handle possible padding?
+            activation_loss = activation_loss_factor * torch.mean(
+                torch.norm(mixer_output - teacher_tensor, p=2, dim=2)
+            )
+            # Backward hooks
+            hidden_states = AuxiliaryLoss.apply(hidden_states, activation_loss, 1.0)
+            bias = AuxiliaryLoss.apply(bias, activation_loss, 1.0) if bias is not None else None
+            # Logging
+            if losses is not None and self._activation_distillation_loss_name in losses:
+                losses[self._activation_distillation_loss_name].append(activation_loss.detach())
+        return hidden_states, bias
+
     def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
         # TODO: Add marginal compute? (normalization, bias_dropout_add)
         return sum(
@@ -179,5 +221,21 @@ def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None
         self.mixer.preprocess(batch, kwargs)
         self.mlp.preprocess(batch, kwargs)

+    # TODO: add layer_index
+    _activation_distillation_loss_name = "activation_distillation_loss"
+
     def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
-        return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count)
+        loss_definitions = []
+        if self._config.activation_distillation_factor > 0.0 and self._config.distillation_model is not None:
+            loss_definitions.append(
+                LossDef(
+                    name=self._activation_distillation_loss_name,
+                    formatted_name=_format_name(self._activation_distillation_loss_name),
+                    count=count,
+                )
+            )
+        return (
+            loss_definitions
+            + self.mixer.get_loss_definitions(count=count)
+            + self.mlp.get_loss_definitions(count=count)
+        )
diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py
index 403b204c..99331ee7 100644
--- a/fast_llm/layers/decoder/config.py
+++ b/fast_llm/layers/decoder/config.py
@@ -88,6 +88,22 @@ class DecoderBlockConfig(BlockConfig):
         hint=FieldHint.feature,
         valid=check_field(Assert.geq, 0),
     )
+    distillation_model: str | None = Field(
+        default=None,
+        desc="Name of the reference model to use for activation-level distillation.",
+        hint=FieldHint.feature,
+    )
+    activation_distillation_factor: float = Field(
+        default=0.0,
+        desc="Factor to scale the activation-level distillation loss by.",
+        hint=FieldHint.feature,
+        valid=check_field(Assert.geq, 0),
+    )
+
+    def _validate(self) -> None:
+        super()._validate()
+        if self.activation_distillation_factor > 0.0 and self.distillation_model is None:
+            raise ValueError("Activation distillation requires a distillation_model.")

     @property
     def layer_class(self) -> "type[DecoderBlock]":
diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py
index 7550df04..e16eac4d 100644
--- a/fast_llm/models/gpt/conversion/apriel.py
+++ b/fast_llm/models/gpt/conversion/apriel.py
@@ -8,15 +8,10 @@
 from fast_llm.layers.attention.config import AttentionConfig
 from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig
 from fast_llm.layers.decoder.config import DecoderBlockConfig
-from fast_llm.layers.decoder.mlp.config import MLPConfig
 from fast_llm.layers.ssm.config import DiscreteMamba2Config, Mamba2Config
 from fast_llm.models.gpt.config import GPTModelConfig
 from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat
-from fast_llm.models.gpt.conversion.llama import (
-    LlamaMLPConverter,
-    get_parameter_converter,
-    get_weight_and_bias_converters,
-)
+from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters
 from fast_llm.models.gpt.conversion.mistral import (
     MistralBaseModelConverter,
     MistralBlockConverter,
@@ -229,29 +224,12 @@ def get_converters(
         ]


-class AprielMLPConverter(LlamaMLPConverter):
-    @classmethod
-    def import_config(cls, config: dict) -> dict:
-        config["mlp_bias"] = False
-        return super().import_config(config)
-
-    @classmethod
-    def export_config(cls, config: MLPConfig) -> dict:
-        out = super().export_config(config)
-        del out["mlp_bias"]
-        return out
-
-
-class AprielBlockConverterBase(MistralBlockConverter):
-    mlp_converter_class: typing.ClassVar[type[AprielMLPConverter]] = AprielMLPConverter
-
-
-class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase):
+class AprielDiscreteMamba2BlockConverter(MistralBlockConverter):
     mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter
     hf_mixer_name: typing.ClassVar[str] = "mixer"


-class AprielMamba2BlockConverter(AprielBlockConverterBase):
+class AprielMamba2BlockConverter(MistralBlockConverter):
     mixer_converter_class: typing.ClassVar[type[AprielMamba2Converter]] = AprielMamba2Converter
     hf_mixer_name: typing.ClassVar[str] = "mixer"

@@ -263,7 +241,7 @@ class AprielBlockConverter:
         DiscreteMamba2Config: "m2d",
     }
     _converter_classes = {
-        AttentionConfig: AprielBlockConverterBase,
+        AttentionConfig: MistralBlockConverter,
         Mamba2Config: AprielMamba2BlockConverter,
         DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter,
     }
diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py
index b5db3fa0..a9a0909e 100644
--- a/fast_llm/models/gpt/conversion/mistral.py
+++ b/fast_llm/models/gpt/conversion/mistral.py
@@ -2,6 +2,7 @@
 from fast_llm.engine.checkpoint.config import CheckpointFormat
 from fast_llm.layers.attention.config import AttentionConfig
+from fast_llm.layers.decoder.mlp.config import MLPConfig
 from fast_llm.models.gpt.conversion.config import MistralCheckpointFormat
 from fast_llm.models.gpt.conversion.llama import (
     LlamaAttentionConverter,
@@ -10,6 +11,7 @@
     LlamaDecoderConverter,
     LlamaHeadConverter,
     LlamaHuggingfaceCheckpointHandler,
+    LlamaMLPConverter,
 )
 from fast_llm.utils import safe_merge_dicts

@@ -38,8 +40,23 @@ def _check_config(cls, config: AttentionConfig) -> None:
         assert not config.add_linear_biases


+class MistralMLPConverter(LlamaMLPConverter):
+    @classmethod
+    def import_config(cls, config: dict) -> dict:
+        config["mlp_bias"] = False
+        return super().import_config(config)
+
+    @classmethod
+    def export_config(cls, config: MLPConfig) -> dict:
+        assert not config.add_linear_biases
+        out = super().export_config(config)
+        del out["mlp_bias"]
+        return out
+
+
 class MistralBlockConverter(LlamaBlockConverter):
     mixer_converter_class: typing.ClassVar[type[MistralAttentionConverter]] = MistralAttentionConverter
+    mlp_converter_class: typing.ClassVar[type[MistralMLPConverter]] = MistralMLPConverter


 class MistralDecoderConverter(LlamaDecoderConverter):
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index efa348ec..17187d0b 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -10,7 +10,7 @@
 from fast_llm.engine.inference.runner import InferenceRunner
 from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
 from fast_llm.layers.attention.config import AttentionKwargs
-from fast_llm.layers.block.config import BlockDimNames
+from fast_llm.layers.block.config import BlockDimNames, BlockKwargs
 from fast_llm.layers.language_model.config import LanguageModelKwargs
 from fast_llm.layers.language_model.language_model import LanguageModel
 from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig
@@ -157,6 +157,7 @@ def preprocess_batch(
         phase: PhaseType,
         iteration: int,
         metrics: dict | None = None,
+        setup_activation_storage: bool = False,
     ) -> list[tuple[torch.Tensor, dict]]:
         # TODO Move batch splitting elsewhere, align interface with LayerBase
         assert self._is_setup
@@ -175,6 +176,10 @@ def preprocess_batch(
                 non_blocking=True,
             )

+        # TODO: decoder doesn't necessarily have a `block` attribute
+        distillation_model = self._config.decoder.block.distillation_model
+        activation_factor = self._config.decoder.block.activation_distillation_factor
+        reference_logits: list[dict[str, typing.Any]] | None = None
         reference_logits = [{} for _ in preprocessed_meta]
         for name, reference_model in self._reference_models.items():
             reference_preprocessed_meta = [
@@ -182,7 +187,11 @@ def preprocess_batch(
             ]

             reference_batch = reference_model.fast_llm_model.base_model.preprocess_batch(
-                batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration
+                batch,
+                reference_preprocessed_meta,
+                phase=PhaseType.inference,
+                iteration=iteration,
+                setup_activation_storage=activation_factor > 0.0 and distillation_model == name,
             )

             # TODO: Do things work with >1?
@@ -190,6 +199,11 @@ def preprocess_batch(
             for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch):
                 reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration)
                 reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"]
+                if BlockKwargs.activation_distillation_storage in reference_kwargs:
+                    reference_logits[i][f"{name}_activations"] = reference_kwargs[
+                        BlockKwargs.activation_distillation_storage
+                    ]
+                    del reference_kwargs[BlockKwargs.activation_distillation_storage]

         token_ids = batch.token_ids
         if sequence_first:
@@ -255,7 +269,13 @@ def preprocess_batch(
                 kwargs[LanguageModelKwargs.loss_mask] = loss_mask
                 labels = torch.where(loss_mask, labels, -100)
                 kwargs[LanguageModelKwargs.labels] = labels
-                kwargs.update(reference_logits[i])
+                if reference_logits is not None:
+                    reference_payload = reference_logits[i]
+                    kwargs.update(reference_payload)
+                    if distillation_model is not None and activation_factor > 0.0:
+                        teacher_key = f"{distillation_model}_activations"
+                        if teacher_key in reference_payload:
+                            kwargs[BlockKwargs.activation_distillation_targets] = reference_payload.pop(teacher_key)

             if batch.chosen_spans is not None:
                 chosen_valid_spans = []
@@ -288,6 +308,8 @@ def preprocess_batch(
                     rejected_valid_spans.append(valid_spans)
                 kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans

+            if setup_activation_storage:
+                kwargs.setdefault(BlockKwargs.activation_distillation_storage, {})
             self.preprocess(tokens, kwargs)
             preprocessed.append((tokens, kwargs))
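
Reviewer note on the mechanics introduced above: the teacher (reference) model is preprocessed with setup_activation_storage=True, its decoder blocks fill the activation_distillation_storage dict keyed by module_name, and preprocess_batch then hands that dict to the student as activation_distillation_targets, where each DecoderBlock pops its own entry. The loss term computed in DecoderBlock.activation_distillation_loss reduces to the standalone computation below: the per-token L2 norm of the student/teacher mixer-output difference over the hidden dimension, averaged over batch and sequence and scaled by activation_distillation_factor. This is a minimal sketch assuming the (batch, sequence, hidden) layout used in the PR; the function name and the random tensors are illustrative only, not part of the change.

import torch


def activation_distillation_l2(student: torch.Tensor, teacher: torch.Tensor, factor: float) -> torch.Tensor:
    # Mixer outputs are (batch, sequence, hidden); detach the teacher so gradients
    # only flow through the student activations.
    teacher = teacher.detach().to(device=student.device, dtype=student.dtype)
    # Per-token L2 norm over the hidden dimension, averaged over batch and sequence,
    # then scaled by the configured activation_distillation_factor.
    return factor * torch.norm(student - teacher, p=2, dim=2).mean()


# Toy example: the scalar loss backpropagates into the student activations only.
student = torch.randn(2, 8, 16, requires_grad=True)
teacher = torch.randn(2, 8, 16)
activation_distillation_l2(student, teacher, factor=0.1).backward()

In the block itself this scalar is attached to hidden_states (and bias) through AuxiliaryLoss.apply, presumably the usual auxiliary-loss pattern of leaving the forward value unchanged and injecting the extra gradient during backward, so the term is optimized alongside the main language-modeling and logit-distillation losses.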