1 change: 1 addition & 0 deletions fast_llm/engine/base_model/base_model.py
@@ -145,6 +145,7 @@ def preprocess_batch(
phase: PhaseType,
iteration: int,
metrics: dict | None = None,
setup_activation_storage: bool = False,
) -> list[tuple[torch.Tensor, dict]]:
# TODO Move batch splitting elsewhere, align interface with LayerBase
pass
1 change: 0 additions & 1 deletion fast_llm/engine/training/config.py
@@ -361,7 +361,6 @@ def _validate(self) -> None:
# TODO: Add support.
Assert.eq(self.model.distributed.pipeline_parallel, 1)
# TODO: Check if these work.
Assert.eq(self.model.distributed.tensor_parallel, 1)
Assert.eq(self.model.distributed.sequence_data_parallel, 1)
if self.run.experiment_dir is None:
assert not self.training.checkpoint.enabled()
3 changes: 3 additions & 0 deletions fast_llm/layers/block/config.py
@@ -37,6 +37,9 @@ class BlockKwargs:
sequence_lengths = "sequence_lengths"
# TODO: Belongs elsewhere?
grad_output = "grad_output"
activation_distillation_storage = "activation_distillation_storage"
activation_distillation_targets = "activation_distillation_targets"
activation_distillation_total = "activation_distillation_total"


@config_class(registry=True)
60 changes: 59 additions & 1 deletion fast_llm/layers/decoder/block.py
@@ -11,9 +11,12 @@
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.layers.block.block import Block
from fast_llm.layers.block.config import BlockKwargs
from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig
from fast_llm.layers.language_model.head import _format_name
from fast_llm.tensor import TensorMeta
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)

@@ -136,6 +139,9 @@ def forward(
if self._debug.enabled:
self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs)
hidden_states, bias = self.mixer(hidden_states, kwargs)

hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses)

if self._debug.enabled:
self._debug(
hidden_states if bias is None else hidden_states + bias,
@@ -166,6 +172,42 @@ def forward(
hidden_states = torch.stack((fw_input, hidden_states), dim=0)
return hidden_states

def activation_distillation_loss(self, hidden_states, bias, kwargs, losses):
"""
Maybe apply activation distillation loss and setup backward hooks
"""
mixer_output = hidden_states if bias is None else hidden_states + bias
# Teacher populates mixer activations for distillation.
activation_storage = kwargs.get(BlockKwargs.activation_distillation_storage)
if activation_storage is not None:
activation_storage[self.module_name] = mixer_output.detach()
# Student gets teacher activations and computes the activation-level loss.
activation_targets = kwargs.get(BlockKwargs.activation_distillation_targets)
if (
activation_targets is not None
and self.training
and (teacher_output := activation_targets.pop(self.module_name, None)) is not None
):
# Compare student mixer output with the teacher’s stored activation and accumulate the loss.
teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype)
Assert.eq(teacher_tensor.shape, mixer_output.shape)
# TODO: handle sequence-first?
# TODO: un-scaled loss for reporting? Average loss over layers?
# L2 loss
activation_loss_factor = self._config.activation_distillation_factor
# (batch, sequence, hidden). Take the norm over hidden dim.
# TODO: handle possible padding?
activation_loss = activation_loss_factor * torch.mean(
torch.norm(mixer_output - teacher_tensor, p=2, dim=2)
)
# Backward hooks
hidden_states = AuxiliaryLoss.apply(hidden_states, activation_loss, 1.0)
bias = AuxiliaryLoss.apply(bias, activation_loss, 1.0) if bias is not None else None
# Logging
if losses is not None and self._activation_distillation_loss_name in losses:
losses[self._activation_distillation_loss_name].append(activation_loss.detach())
return hidden_states, bias

def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
# TODO: Add marginal compute? (normalization, bias_dropout_add)
return sum(
@@ -179,5 +221,21 @@ def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None
self.mixer.preprocess(batch, kwargs)
self.mlp.preprocess(batch, kwargs)

# TODO: add layer_index
_activation_distillation_loss_name = "activation_distillation_loss"

def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count)
loss_definitions = []
if self._config.activation_distillation_factor > 0.0 and self._config.distillation_model is not None:
loss_definitions.append(
LossDef(
name=self._activation_distillation_loss_name,
formatted_name=_format_name(self._activation_distillation_loss_name),
count=count,
)
)
return (
loss_definitions
+ self.mixer.get_loss_definitions(count=count)
+ self.mlp.get_loss_definitions(count=count)
)
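
For reference, a minimal standalone sketch of the loss computed in `activation_distillation_loss` above (tensor names and shapes are illustrative; in the actual block the loss is attached to the graph through `AuxiliaryLoss.apply`):

```python
import torch

def activation_distillation_l2(student: torch.Tensor, teacher: torch.Tensor, factor: float) -> torch.Tensor:
    # student, teacher: (batch, sequence, hidden); the teacher tensor is detached
    # so no gradient flows back into the reference model.
    teacher = teacher.detach().to(device=student.device, dtype=student.dtype)
    assert teacher.shape == student.shape
    # L2 norm of the difference over the hidden dimension, averaged over batch and sequence.
    return factor * torch.norm(student - teacher, p=2, dim=2).mean()

student = torch.randn(2, 8, 16, requires_grad=True)
teacher = torch.randn(2, 8, 16)
loss = activation_distillation_l2(student, teacher, factor=0.5)
loss.backward()  # gradients reach only the student activations
```

Averaging the per-position norms rather than summing them keeps the loss magnitude roughly independent of batch and sequence length.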
16 changes: 16 additions & 0 deletions fast_llm/layers/decoder/config.py
@@ -88,6 +88,22 @@ class DecoderBlockConfig(BlockConfig):
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
distillation_model: str | None = Field(
default=None,
desc="Name of the reference model to use for activation-level distillation.",
hint=FieldHint.feature,
)
activation_distillation_factor: float = Field(
default=0.0,
desc="Factor to scale the activation-level distillation loss by.",
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)

def _validate(self) -> None:
super()._validate()
if self.activation_distillation_factor > 0.0 and self.distillation_model is None:
raise ValueError("Activation distillation requires a distillation_model.")

@property
def layer_class(self) -> "type[DecoderBlock]":
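
The two fields added above have to be set together to enable the loss; a hypothetical fragment (the exact nesting of the decoder block settings in an experiment config is an assumption, and "teacher" must match the name of a configured reference model):

```python
# Hypothetical decoder-block settings; field names come from DecoderBlockConfig above.
decoder_block_overrides = {
    "distillation_model": "teacher",         # name of the reference model (assumed name)
    "activation_distillation_factor": 0.01,  # > 0 turns the activation-level loss on
}
```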
30 changes: 4 additions & 26 deletions fast_llm/models/gpt/conversion/apriel.py
@@ -8,15 +8,10 @@
from fast_llm.layers.attention.config import AttentionConfig
from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig
from fast_llm.layers.decoder.config import DecoderBlockConfig
from fast_llm.layers.decoder.mlp.config import MLPConfig
from fast_llm.layers.ssm.config import DiscreteMamba2Config, Mamba2Config
from fast_llm.models.gpt.config import GPTModelConfig
from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat
from fast_llm.models.gpt.conversion.llama import (
LlamaMLPConverter,
get_parameter_converter,
get_weight_and_bias_converters,
)
from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters
from fast_llm.models.gpt.conversion.mistral import (
MistralBaseModelConverter,
MistralBlockConverter,
@@ -229,29 +224,12 @@ def get_converters(
]


class AprielMLPConverter(LlamaMLPConverter):
@classmethod
def import_config(cls, config: dict) -> dict:
config["mlp_bias"] = False
return super().import_config(config)

@classmethod
def export_config(cls, config: MLPConfig) -> dict:
out = super().export_config(config)
del out["mlp_bias"]
return out


class AprielBlockConverterBase(MistralBlockConverter):
mlp_converter_class: typing.ClassVar[type[AprielMLPConverter]] = AprielMLPConverter


class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase):
class AprielDiscreteMamba2BlockConverter(MistralBlockConverter):
mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter
hf_mixer_name: typing.ClassVar[str] = "mixer"


class AprielMamba2BlockConverter(AprielBlockConverterBase):
class AprielMamba2BlockConverter(MistralBlockConverter):
mixer_converter_class: typing.ClassVar[type[AprielMamba2Converter]] = AprielMamba2Converter
hf_mixer_name: typing.ClassVar[str] = "mixer"

@@ -263,7 +241,7 @@ class AprielBlockConverter:
DiscreteMamba2Config: "m2d",
}
_converter_classes = {
AttentionConfig: AprielBlockConverterBase,
AttentionConfig: MistralBlockConverter,
Mamba2Config: AprielMamba2BlockConverter,
DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter,
}
17 changes: 17 additions & 0 deletions fast_llm/models/gpt/conversion/mistral.py
@@ -2,6 +2,7 @@

from fast_llm.engine.checkpoint.config import CheckpointFormat
from fast_llm.layers.attention.config import AttentionConfig
from fast_llm.layers.decoder.mlp.config import MLPConfig
from fast_llm.models.gpt.conversion.config import MistralCheckpointFormat
from fast_llm.models.gpt.conversion.llama import (
LlamaAttentionConverter,
@@ -10,6 +11,7 @@
LlamaDecoderConverter,
LlamaHeadConverter,
LlamaHuggingfaceCheckpointHandler,
LlamaMLPConverter,
)
from fast_llm.utils import safe_merge_dicts

@@ -38,8 +40,23 @@ def _check_config(cls, config: AttentionConfig) -> None:
assert not config.add_linear_biases


class MistralMLPConverter(LlamaMLPConverter):
@classmethod
def import_config(cls, config: dict) -> dict:
config["mlp_bias"] = False
return super().import_config(config)

@classmethod
def export_config(cls, config: MLPConfig) -> dict:
assert not config.add_linear_biases
out = super().export_config(config)
del out["mlp_bias"]
return out


class MistralBlockConverter(LlamaBlockConverter):
mixer_converter_class: typing.ClassVar[type[MistralAttentionConverter]] = MistralAttentionConverter
mlp_converter_class: typing.ClassVar[type[MistralMLPConverter]] = MistralMLPConverter


class MistralDecoderConverter(LlamaDecoderConverter):
28 changes: 25 additions & 3 deletions fast_llm/models/gpt/model.py
@@ -10,7 +10,7 @@
from fast_llm.engine.inference.runner import InferenceRunner
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.layers.attention.config import AttentionKwargs
from fast_llm.layers.block.config import BlockDimNames
from fast_llm.layers.block.config import BlockDimNames, BlockKwargs
from fast_llm.layers.language_model.config import LanguageModelKwargs
from fast_llm.layers.language_model.language_model import LanguageModel
from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig
@@ -157,6 +157,7 @@ def preprocess_batch(
phase: PhaseType,
iteration: int,
metrics: dict | None = None,
setup_activation_storage: bool = False,
) -> list[tuple[torch.Tensor, dict]]:
# TODO Move batch splitting elsewhere, align interface with LayerBase
assert self._is_setup
@@ -175,21 +176,34 @@ def preprocess_batch(
non_blocking=True,
)

# TODO: decoder doesn't necessarily have a `block` attribute
distillation_model = self._config.decoder.block.distillation_model
activation_factor = self._config.decoder.block.activation_distillation_factor
reference_logits: list[dict[str, typing.Any]] | None = None
reference_logits = [{} for _ in preprocessed_meta]
for name, reference_model in self._reference_models.items():
reference_preprocessed_meta = [
(tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta
]

reference_batch = reference_model.fast_llm_model.base_model.preprocess_batch(
batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration
batch,
reference_preprocessed_meta,
phase=PhaseType.inference,
iteration=iteration,
setup_activation_storage=activation_factor > 0.0 and distillation_model == name,
)

# TODO: Do things work with >1?
Assert.eq(len(reference_batch), len(preprocessed_meta), 1)
for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch):
reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration)
reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"]
if BlockKwargs.activation_distillation_storage in reference_kwargs:
reference_logits[i][f"{name}_activations"] = reference_kwargs[
BlockKwargs.activation_distillation_storage
]
del reference_kwargs[BlockKwargs.activation_distillation_storage]

token_ids = batch.token_ids
if sequence_first:
Expand Down Expand Up @@ -255,7 +269,13 @@ def preprocess_batch(
kwargs[LanguageModelKwargs.loss_mask] = loss_mask
labels = torch.where(loss_mask, labels, -100)
kwargs[LanguageModelKwargs.labels] = labels
kwargs.update(reference_logits[i])
if reference_logits is not None:
reference_payload = reference_logits[i]
kwargs.update(reference_payload)
if distillation_model is not None and activation_factor > 0.0:
teacher_key = f"{distillation_model}_activations"
if teacher_key in reference_payload:
kwargs[BlockKwargs.activation_distillation_targets] = reference_payload.pop(teacher_key)

if batch.chosen_spans is not None:
chosen_valid_spans = []
Expand Down Expand Up @@ -288,6 +308,8 @@ def preprocess_batch(
rejected_valid_spans.append(valid_spans)
kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans

if setup_activation_storage:
kwargs.setdefault(BlockKwargs.activation_distillation_storage, {})
self.preprocess(tokens, kwargs)
preprocessed.append((tokens, kwargs))

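
Putting the pieces together, the kwargs plumbing added in `preprocess_batch` amounts to the hand-off sketched below, using plain dicts and a made-up module name purely to illustrate the protocol; the real keys are the `BlockKwargs` constants added in `fast_llm/layers/block/config.py`:

```python
import torch

# Teacher pass: preprocess_batch is called with setup_activation_storage=True,
# so each teacher block stores its detached mixer output under its module name.
teacher_kwargs = {"activation_distillation_storage": {}}
teacher_kwargs["activation_distillation_storage"]["decoder.0"] = torch.randn(2, 8, 16)

# Student pass: the stored activations are handed over as per-block targets,
# and each student block pops its own entry so a target is consumed exactly once.
student_kwargs = {
    "activation_distillation_targets": dict(teacher_kwargs["activation_distillation_storage"]),
}
target = student_kwargs["activation_distillation_targets"].pop("decoder.0", None)
assert target is not None  # this block now computes the L2 loss against `target`
```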