From 82eed2b44c30c891ef2e07c2c80c4f5fcfa1e7f1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 21 Jul 2025 17:17:26 -0400 Subject: [PATCH 01/82] TP mamba --- fast_llm/layers/common/config.py | 6 +- fast_llm/layers/ssm/config.py | 214 +++++++++---- fast_llm/layers/ssm/discrete_mamba2.py | 39 ++- fast_llm/layers/ssm/llamba_block.py | 18 +- fast_llm/layers/ssm/mamba2.py | 302 +++++++----------- fast_llm/layers/ssm/mamba_layer.py | 159 ++++----- fast_llm/layers/transformer/attention.py | 3 +- fast_llm/layers/transformer/transformer.py | 27 +- fast_llm/models/custom/model.py | 4 +- fast_llm/models/gpt/model.py | 8 +- fast_llm/models/ssm/config.py | 42 +-- .../external/llamba/modeling_mtp_llamba.py | 10 +- fast_llm/models/ssm/model.py | 34 +- fast_llm/tensor.py | 8 +- setup.cfg | 2 +- tests/test_multi_stage.py | 4 +- 16 files changed, 407 insertions(+), 473 deletions(-) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 9f32ac689..07dadbc22 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -99,7 +99,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig): ) def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.tensor import init_uniform_ + from fast_llm.tensor import init_uniform_centered_ kwargs = { "hidden_dim": hidden_dim, @@ -110,9 +110,7 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> " } if self.initialization_range: mean = 0 if self.zero_centered else 1 - kwargs["weight_init_method"] = init_uniform_( - mean - self.initialization_range, mean + self.initialization_range - ) + kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) return self.module_class(**kwargs) @property diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index c69ada389..f4c8067dd 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,28 +1,35 @@ import enum from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace +from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div class SSMDimNames: - model_dim = "model_dim" # Model dimension (D) - state_dim = "state_dim" # State dimension (N) - conv_dim = "conv_dim" # Dimension of the conv1d input in mamba layers - inner_dim = "inner_dim" # Inner dimension after expansion - dt_rank = "dt_rank" # Rank of Δ - inner_proj_mamba = "inner_proj_mamba" # Inner projection dimension for mamba - inner_proj_discrete_mamba2 = "inner_proj_discrete_mamba2" # Inner projection dimension for discrete mamba2 - inner_proj_mamba2 = "inner_proj_mamba2" # Inner projection dimension for mamba2 - x_proj_dim = "x_proj_dim" # X projection dimension - head_dim = "head_dim" # Dimension of the mamba2 head (P) - conv_kernel_size = "conv_kernel_size" # Kernel size of the conv1d in mamba layers - qk_heads = "qk_heads" # Number of QK heads - v_heads = "v_heads" # Number of V heads + # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. 
+ state = "ssm_state" # State dimension (N), aka head size / num channels + + head_groups = "ssm_head_groups" + group_heads = "ssm_group_heads" + + composite_heads = "ssm_composite_heads" + composite_heads_and_state = "ssm_composite_heads_and_state" + composite_head_groups_and_state = "ssm_composite_head_groups_and_state" + + # Inner projection total dimension. + inner_projection = "ssm_inner_projection" + composite_inner_projection = "ssm_composite_inner_projection" + + # Convolution shape in discrete mamba 2. TODO: Remove (dim too complex) + conv_dim = "ssm_conv_dim" + + dt_rank = "ssm_dt_rank" - # Mamba 2 - x_proj_dim_2 = "x_proj_dim" # d_xb + x_proj_dim = "x_proj_dim" # X projection dimension + conv_kernel = "conv_kernel" # Kernel size of the conv1d in mamba layers class SSMBlockType(enum.StrEnum): @@ -36,6 +43,16 @@ class SSMBlockType(enum.StrEnum): transformer = "t" +class DTInitType(enum.StrEnum): + constant = "constant" + random = "random" + + def get_init_method(self, scale: float): + from fast_llm.tensor import init_fill_, init_uniform_centered_ + + return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) + + @config_class() class SSMConfig(LLMBlockConfig): _abstract = False @@ -45,79 +62,87 @@ class SSMConfig(LLMBlockConfig): desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) + + # Model dimensions + # TODO: Remove (redundant default) expansion_factor: int = Field( default=2, desc="Expansion factor for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # head_size [MambaLayer, Mamba2, DiscreteMamba2] state_size: int = Field( default=16, desc="State size for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # [MambaLayer, Mamba2, DiscreteMamba2] conv_kernel_dimension: int = Field( default=4, desc="Conv kernel dimension for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - # Layer parameters - add_bias_linear: bool = Field( - default=False, - desc="Whether to use bias in SSM layers", - hint=FieldHint.architecture, - ) - + # [MambaLayer, Mamba2] dt_rank: None | int = Field( default=None, desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)", hint=FieldHint.architecture, ) - chunk_size: int = Field( - default=256, - desc="Chunk size for Mamba2 blocks.", - hint=FieldHint.architecture, - ) + # head_groups [DiscreteMamba2] n_qk_heads: int = Field( default=32, desc="Number of QK heads for Mamba2 blocks.", hint=FieldHint.architecture, ) + # heads [DiscreteMamba2]# TODO: Remove? (redundant) n_v_heads: int = Field( default=32, desc="Number of V heads for Mamba2 blocks.", hint=FieldHint.architecture, ) - activation_type: ActivationType = Field( + # c_size [MambaLayer, Mamba2, DiscreteMamba2]? + d_inner: None | int = Field( + default=None, + desc="Inner dimension for Mamba2 blocks.", + hint=FieldHint.core, + ) + # xb_size [Mamba2] + d_xb: int = Field( default=None, - desc="The MLP intermediate activation type. 
Default: SiLU for gated MLP, GeLU otherwise.", + desc="Dimension of the xB in Mamba2 blocks.", hint=FieldHint.architecture, ) - debug_ssm: bool = Field( + + # Model options + # add_bias_linear [Mamba2, DiscreteMamba2] [hard-coded to False in MambaLayer] + add_bias_linear: bool = Field( default=False, - desc="debug_ssm", - hint=FieldHint.optional, + desc="Whether to use bias in SSM layers", + hint=FieldHint.architecture, ) - dt_min: float = Field( - default=0.001, - desc="Minimum step size for discretization", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), + # activation_type [DiscreteMamba2] [hard-coded to silu in MambaLayer, Mamba2] + activation_type: ActivationType = Field( + default=None, + hint=FieldHint.architecture, ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), + # repeat_xb_before_conv [Mamba2] + repeat_kv_before_conv: bool = Field( + default=True, + desc="Whether to repeat x and B before (True) or after (False) the conv1d in Mamba2 blocks.", + hint=FieldHint.architecture, ) - - d_inner: None | int = Field( - default=None, - desc="Inner dimension for Mamba2 blocks.", - hint=FieldHint.core, + # chunk_size [DiscreteMamba2] + chunk_size: int = Field( + default=256, + desc="Chunk size for Mamba2 blocks.", + hint=FieldHint.architecture, ) + + # Learning rate + # lr_scale [MambaLayer, Mamba2, DiscreteMamba2] mamba_lr_scale: float | None = Field( default=None, desc="Learning rate scale for Mamba blocks.", @@ -125,43 +150,38 @@ class SSMConfig(LLMBlockConfig): valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - # Mamba 2 - repeat_kv_before_conv: bool = Field( - default=True, - desc="Whether to repeat the KV before the conv1d in Mamba2 blocks.", - hint=FieldHint.architecture, - ) - d_xb: int = Field( - default=None, - desc="Dimension of the xB in Mamba2 blocks.", - hint=FieldHint.architecture, - ) - dt_init: str = Field( + # Initialization + # dt_weight_initialization_method [Mamba2] + dt_init: DTInitType = Field( default="random", desc="Initialization method for dt", hint=FieldHint.core, ) - dt_max: float = Field( - default=0.1, - desc="Maximum step size for discretization", + # dt_weight_initialization_scale [Mamba2] + dt_scale: float = Field( + default=1.0, + desc="Scale for dt", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) + # dt_bias_initialization_min [MambaLayer, Mamba2] dt_min: float = Field( default=0.001, desc="Minimum step size for discretization", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", + # dt_bias_initialization_max [MambaLayer, Mamba2] + dt_max: float = Field( + default=0.1, + desc="Maximum step size for discretization", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_scale: float = Field( - default=1.0, - desc="Scale for dt", + # dt_bias_initialization_floor [MambaLayer, Mamba2] + dt_init_floor: float = Field( + default=1e-4, + desc="Minimum value for initializing dt", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) @@ -172,3 +192,59 @@ def _validate(self) -> None: self.activation_type = ActivationType.silu super()._validate() Assert.geq(self.dt_max, self.dt_min) + + def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: + tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + num_heads = div(self.d_inner, self.state_size) + # Head groups are 
configured differently depending on the block type. + if block_type == SSMBlockType.mamba: + num_head_groups = num_heads + # (head_groups, 2 * group_heads * state_dim) + inner_projection_size = self.d_inner * 2 + elif block_type == SSMBlockType.mamba2: + num_head_groups = div(self.d_xb, self.state_size) + # (head_groups, 2 * group_heads + 2, state_dim) + (dt,) + inner_projection_size: int = 2 * self.d_inner + 2 * num_head_groups * self.state_size + self.dt_rank + elif block_type == SSMBlockType.mamba2_discrete: + Assert.eq(num_heads, self.n_v_heads) + num_head_groups = self.n_qk_heads + # (head_groups, (2 * group_heads + 2) * state_dim + group_heads) + inner_projection_size = 2 * self.d_inner + 2 * num_head_groups * self.state_size + num_heads + else: + raise NotImplementedError(block_type) + + tensor_space.add_tensor_dim(state_dim := TensorDim(SSMDimNames.state, self.state_size)) + tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) + tensor_space.add_tensor_dim( + group_heads := TensorDim(SSMDimNames.group_heads, num_group_heads := div(num_heads, num_head_groups)) + ) + tensor_space.add_tensor_dim(CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads))) + tensor_space.add_tensor_dim( + CompositeTensorDim(SSMDimNames.composite_heads_and_state, (head_groups, group_heads, state_dim)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(SSMDimNames.composite_head_groups_and_state, (head_groups, state_dim)) + ) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel, self.conv_kernel_dimension)) + + # DT projection + if block_type in (SSMBlockType.mamba, SSMBlockType.mamba2): + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, self.dt_rank)) + + if block_type == SSMBlockType.mamba: + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.dt_rank + self.state_size * 2)) + inner_projection_size = 2 * num_group_heads * self.state_size + elif block_type == SSMBlockType.mamba2: + inner_projection_size = 2 * (num_group_heads + 1) * self.state_size + elif block_type == SSMBlockType.mamba2_discrete: + inner_projection_size = 2 * (num_group_heads + 1) * self.state_size + num_group_heads + # TODO: (head_groups, group_heads + 2, state_size) + tensor_space.add_tensor_dim( + TensorDim(SSMDimNames.conv_dim, self.d_inner + 2 * self.n_qk_heads * self.state_size) + ) + + tensor_space.add_tensor_dim(inner_projection := TensorDim(SSMDimNames.inner_projection, inner_projection_size)) + tensor_space.add_tensor_dim( + CompositeTensorDim(SSMDimNames.composite_inner_projection, (head_groups, inner_projection)) + ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 934cd2b5d..d06b47965 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,5 +1,6 @@ import logging import math +import typing import einops import torch @@ -7,8 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.utils import get_lr_scale logger = 
logging.getLogger(__name__) @@ -33,7 +34,7 @@ def bias_init_method(conv_weight): fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_(-bound, bound) + return init_uniform_centered_(bound) class DiscreteMamba2(torch.nn.Module): @@ -53,21 +54,20 @@ def __init__( # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} super().__init__() self.config: SSMConfig = config - bias = config.add_bias_linear self.layer_idx = layer_idx self._return_input = return_input layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) + td_inner = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + td_state = tensor_space.get_tensor_dim(SSMDimNames.state) + td_model = tensor_space.get_tensor_dim(TransformerDimNames.hidden) td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.qk_heads) - td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.v_heads) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.inner_proj_discrete_mamba2) + td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.head_groups) + td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) + td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel) + td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.composite_inner_projection) self.d_model = td_model.size self.d_inner = td_inner.size @@ -85,8 +85,8 @@ def __init__( self.in_proj = Linear( td_model, td_inner_proj, - bias=bias, - weight_init_method=kaiming_init_(td_model.size), + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(td_model.size), lr_scale=mamba_layer_lr_scale, ) self.z_bias = ( @@ -96,15 +96,13 @@ def __init__( init_method=init_zeros_, lr_scale=mamba_layer_lr_scale, ) - if not bias + if not config.add_bias_linear else 0.0 ) self.conv1d_weight = ParameterMeta.from_dims( (td_conv, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) - ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 + init_method=init_uniform_centered_((td_conv.size * td_conv_kernel.size) ** -0.5), lr_scale=mamba_layer_lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( @@ -123,12 +121,12 @@ def __init__( self.out_proj = Linear( td_inner, td_model, - bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) - def forward(self, hidden_states, kwargs): + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: """ ON variable names and pep8: keeping some variable names as in the original code for clarity. 
@@ -144,7 +142,6 @@ def forward(self, hidden_states, kwargs): raise NotImplementedError(f"Sequence-first not supported for SSMs.") assert _mamba_available - input_ = hidden_states outputs = {} # assert state is None batch, seqlen, dim = input_.shape diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index ee222d6d2..e877ff9c2 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -1,6 +1,6 @@ import typing -from fast_llm.layers.transformer.transformer import BaseBlock +from fast_llm.layers.transformer.transformer import BaseBlock, Mixer if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.tensor_space import TensorSpace @@ -14,21 +14,19 @@ class LlambaBlock(BaseBlock): """ _name = "Llamba block" - _mixer_module_name = "mixer" def __init__( self, - config_transformer: "TransformerConfig", - config_ssm: "SSMConfig", + transformer_config: "TransformerConfig", + ssm_config: "SSMConfig", tensor_space: "TensorSpace", - mixer_cls, + mixer_cls: type[Mixer], layer_index: int, return_input: bool = False, ): - self.mixer_cls = mixer_cls - self._config_ssm = config_ssm self._debug_mode = self._config_ssm.debug_ssm - super().__init__(config_transformer, tensor_space, layer_index, return_input) + super().__init__(transformer_config, tensor_space, layer_index, return_input) + self.mixer = mixer_cls(ssm_config, layer_idx=self._layer_index, tensor_space=self._tensor_space) - def _create_mixer(self): - self.mixer = self.mixer_cls(self._config_ssm, layer_idx=self._layer_index, tensor_space=self._tensor_space) + def get_mixer(self) -> Mixer: + return self.mixer diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a03509abb..011889d04 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,14 +1,15 @@ -import math -import typing - -import einops import torch from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.common.linear import Linear +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ -from fast_llm.utils import get_lr_scale +from fast_llm.layers.ssm.discrete_mamba2 import bias_init_method +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ +from fast_llm.utils import Assert, div, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -25,25 +26,7 @@ _causal_conv1d_available = False -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def bias_init_method(conv_weight): - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_(-bound, bound) - - -class Mamba2(torch.nn.Module): +class Mamba2(Mixer): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ @@ -53,207 +36,138 @@ def __init__( config: SSMConfig, layer_idx: int, tensor_space: TensorSpace, - return_input: bool = False, ): super().__init__() - self.config: SSMConfig = config - bias: bool = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input + self._config: SSMConfig = config + Assert.eq(self._config.activation_type, ActivationType.silu) layer_lr_scale: float | None = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None - mamba_layer_lr_scale: float | tuple[float | None, ...] | None = get_lr_scale( - self.config.mamba_lr_scale, layer_lr_scale - ) - - td_inner: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_dim) - td_state: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.state_dim) - td_model: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.model_dim) - tdt_rank: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) - td_xb: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.x_proj_dim_2) - td_inner_proj: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_proj_mamba2) - td_conv_kernel: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel_size) + lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - self.repeat_kv_before_conv = config.repeat_kv_before_conv + inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) + hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) - self.d_state = td_state.size - self.d_model = td_model.size - self.d_xb = td_xb.size - self.d_inner = td_inner.size - self.dt_rank = tdt_rank.size - - if self.repeat_kv_before_conv: - self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), - ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 - lr_scale=mamba_layer_lr_scale, - ) + self._head_groups = div(self._config.d_xb, self._config.state_size) + self._heads = div(self._config.d_inner, self._config.state_size) + self._group_heads = div(self._heads, self._head_groups) - self.conv1d_bias = ParameterMeta.from_dims( - (td_inner,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale - ) - else: - self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), - ), - ) - self.conv1d_bias = ParameterMeta.from_dims( - (td_xb,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale - ) - - self.activation = "silu" - - self.num_xb_head = td_xb.size // td_state.size - self.num_C_head = td_inner.size // td_state.size - self.repeat_group = self.num_C_head // self.num_xb_head - - self.in_proj = Linear( - td_model, - td_inner_proj, - bias=bias, - weight_init_method=kaiming_init_(td_model.size), - lr_scale=mamba_layer_lr_scale, + conv1d_dim = ( + inner_dim + if self._config.repeat_kv_before_conv + else tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) ) - - # Initialize special dt projection to preserve variance at initialization - dt_scale = config.dt_scale # 1.0 - dt_init_std = self.dt_rank**-0.5 * dt_scale - if config.dt_init == "constant": - dt_init = init_fill_(dt_init_std) - elif config.dt_init == "random": - dt_init = init_uniform_(-dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt_max = config.dt_max # or 0.1 - dt_min = config.dt_min # or 0.001 - dt_init_floor = config.dt_init_floor # or 1e-4 - dt = torch.exp(torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( - min=dt_init_floor + self.conv1d_weight = ParameterMeta.from_dims( + (conv1d_dim, tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel)), + init_method=init_uniform_centered_((conv1d_dim.size * self._config.conv_kernel_dimension) ** -0.5), + lr_scale=lr_scale, + ) + self.conv1d_bias = ParameterMeta.from_dims( + (conv1d_dim,), init_method=bias_init_method(self._config.conv_kernel_dimension**-0.5), lr_scale=lr_scale + ) + self.in_proj = OutputParallelLinear( + hidden_dim, + tensor_space.get_tensor_dim(name=SSMDimNames.composite_inner_projection), + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, ) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + 
torch.log(-torch.expm1(-dt)) - - def init_from_tensor_( - value: torch.Tensor, - ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.copy_(value) - - return init_ - self.dt_proj = Linear( - tdt_rank, - td_inner, + tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank), + inner_dim, bias=False, - weight_init_method=dt_init, - lr_scale=mamba_layer_lr_scale, + # Initialize special dt projection to preserve variance at initialization + weight_init_method=self._config.dt_init.get_init_method( + self._config.dt_rank**-0.5 * self._config.dt_scale + ), + lr_scale=lr_scale, ) # define bias outside the linear layer since its also used in the selective_scan_fn self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), init_method=init_from_tensor_(inv_dt), lr_scale=mamba_layer_lr_scale + (inner_dim,), + init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + lr_scale=lr_scale, ) - - A = einops.repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A).flatten() # Keep A_log in fp32 self.A_log = ParameterMeta.from_dims( - (td_inner, td_state), - init_method=init_from_tensor_(A_log), - lr_scale=mamba_layer_lr_scale, + (inner_dim, tensor_space.get_tensor_dim(name=SSMDimNames.state)), + init_method=init_A(self._config.state_size, self._config.d_inner), + lr_scale=lr_scale, weight_decay=False, ) - self.D = ParameterMeta.from_dims( - (td_inner,), + (inner_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - - self.out_proj = Linear( - td_inner, - td_model, - bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + self.out_proj = InputParallelLinear( + inner_dim, + hidden_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(self._config.d_inner), ) def forward(self, hidden_states, kwargs): - """ - hidden_states: (B, L, D) - Returns: same shape as hidden_states - """ assert _mamba_available - batch, seqlen, dim = hidden_states.shape - outputs = {} - - conv_state, ssm_state = None, None - - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) - - x = einops.rearrange(x, "b l d -> b d l") - z = einops.rearrange(z, "b l d -> b d l") - - B = einops.rearrange(B, "b l (n_group dstate) -> b n_group l dstate", dstate=self.d_state) - B = repeat_kv(B, self.repeat_group) # B, n_group, L, H - B = einops.rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() - C = einops.rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - - dt = self.dt_proj(dt) + self.dt_proj_bias # B, L, d_inner - dt = einops.rearrange(dt, "b l d -> b d l") # B, d_inner, L + assert _causal_conv1d_available + + inner_projection = self.in_proj(hidden_states) + # Standardize to (batch, sequence, inner_projection) + if kwargs[TransformerKwargs.sequence_first]: + inner_projection = inner_projection.transpose(0, 1) + sequence_length = hidden_states.size(1) + + z, x, b, c, dt = torch.split( + inner_projection, + [self._config.d_inner, self._config.d_xb, self._config.d_xb, self._config.d_inner, self._config.dt_rank], + dim=2, + ) + # z: (batch, sequence, heads * state) -> (batch, heads * state, sequence) + z = 
z.transpose(1, 2) + + # x: (batch, sequence, head_groups * state) -> (batch, heads * state, sequence) + x = x.transpose(1, 2) + if self._config.repeat_kv_before_conv: + x = ( + x.unflatten(1, (self._head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._heads) + .flatten(1, 2) + ) + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") + else: + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") + x = ( + x.unflatten(1, (self._head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._heads) + .flatten(1, 2) + ) - if self.repeat_kv_before_conv: - assert self.repeat_group > 0 - x = einops.rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + # b: (batch, sequence, head_groups * state) -> (batch, heads, state, sequence) + b = ( + b.transpose(1, 2) + .unflatten(1, (self._head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._heads) + ) - assert self.activation in ["silu", "swish"] - if _causal_conv1d_available: - x = _causal_conv1d_fn( - x=x, - weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), - bias=self.conv1d_bias, - activation=self.activation, - ) # B, L, D - else: - raise RuntimeError("Causal conv1d is not available. Please install causal_conv1d.") + # c: (batch, sequence, heads * state) -> (batch, heads, state, sequence) + c = c.transpose(1, 2).unflatten(1, (self._heads, self._config.state_size)) - if not self.repeat_kv_before_conv: - x = einops.rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + # dt: (batch, sequence, dt_rank) -> (batch, heads * state, sequence) + dt = (self.dt_proj(dt) + self.dt_proj_bias).transpose(1, 2) y = selective_scan_fn( x, dt, - A, - B, - C, + -torch.exp(self.A_log.float()), + b, + c, self.D.float(), - z=z, - delta_bias=self.dt_proj_bias.float(), # self.dt_proj.bias.float(), + z, + delta_bias=self.dt_proj_bias.float(), delta_softplus=True, - return_last_state=False, ) - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(einops.rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) - - y = einops.rearrange(y, "b d l -> b l d") - out = self.out_proj(y) - outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - return outputs["hidden_states"], None + # y: (batch, heads * state, sequence) -> out: (batch, sequence, hidden) + out = self.out_proj(y.transpose(1, 2))[:, :sequence_length] + if kwargs[TransformerKwargs.sequence_first]: + out = out.transpose(0, 1) + # TODO: Is contiguous needed? 
+ return out.contiguous(), None diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 7c824d235..fa2789b1e 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,14 +1,18 @@ +import logging import math +import typing from typing import Callable -import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ -from fast_llm.utils import get_lr_scale +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ +from fast_llm.utils import Assert, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -17,6 +21,8 @@ except (ImportError, RuntimeError): _mamba_available = False +logger = logging.getLogger(__name__) + """ Note: this is mostly adapted from https://github.com/Zyphra/Zamba2, similar code is also in https://github.com/state-spaces/mamba. For now it only supports training and not inference. @@ -26,169 +32,126 @@ def init_A(d_state, d_inner) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - # S4D real initialization # TODO: adopt this initialization to work for tensor parallel setting! - A = einops.repeat(torch.arange(1, d_state + 1, dtype=torch.float32), "n -> d n", d=d_inner).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - if tensor.shape != A_log.shape: - if tensor.numel() == A_log.numel(): - tensor_view = tensor.view(d_inner, d_state) - tensor_view.copy_(A_log) - else: - raise ValueError(f"Tensor size {tensor.numel()} doesn't match expected size {A_log.numel()}") - else: - tensor.copy_(A_log) - return tensor + if tensor.numel() != d_state * d_inner: + raise ValueError(f"_init_A requires not supported for tensor slices.") + return torch.log(torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner), out=tensor) return init_ def init_dtprojbias( - d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float, factory_kwargs: dict + dt_max: float, dt_min: float, dt_init_floor: float ) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - dt = torch.exp( - torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) - ).clamp(min=dt_init_floor) + tensor = tensor.uniform_(math.log(dt_min), math.log(dt_max)).exp_().clamp_min(dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - tensor.copy_(inv_dt) - return tensor + return tensor.add_(torch.log(-torch.expm1(-tensor))) return init_ -class MambaLayer(torch.nn.Module): +class MambaLayer(Mixer): def __init__( self, config: SSMConfig, layer_idx: int, tensor_space: TensorSpace, - return_input: bool = False, ): - factory_kwargs = {} super().__init__() - self.config: SSMConfig = config - self.layer_idx = layer_idx - - self._debug_mode = config.debug_ssm + assert 
tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" + self._config = config + # TODO: It's not silu? + Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_inner_proj = tensor_space.get_tensor_dim( - SSMDimNames.inner_proj_mamba - ) # TensorDim("D_inner_2", self.d_inner * 2) - tdt_rank = tensor_space.get_tensor_dim(SSMDimNames.dt_rank) - td_x_proj = tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - self.d_conv = td_conv_kernel.size - self.d_inner = td_inner.size - self.d_state = td_state.size - self.d_model = td_model.size - self.dt_rank = tdt_rank.size + inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None - mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - self.in_proj_weight = ParameterMeta.from_dims( - (td_inner_proj, td_model), - init_method=kaiming_init_(td_model.size), + # TODO: Backward compatibility? + # TODO: lr_scale? + self.in_proj = Linear( + hidden_dim, + tensor_space.get_tensor_dim(SSMDimNames.composite_inner_projection), + bias=False, + weight_init_method=init_kaiming_(hidden_dim.size), ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), - init_method=kaiming_init_(td_inner.size), - lr_scale=mamba_layer_lr_scale, + (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.conv_kernel)), + init_method=init_kaiming_(inner_dim.size), + lr_scale=lr_scale, ) - self.conv1d_bias = None - - self.activation = "silu" - self.act = torch.nn.SiLU() - self.x_proj = Linear( - td_inner, - td_x_proj, - weight_init_method=kaiming_init_(td_inner.size), + inner_dim, + tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim), + weight_init_method=init_kaiming_(inner_dim.size), bias=False, - lr_scale=mamba_layer_lr_scale, - **factory_kwargs, + lr_scale=lr_scale, ) self.x_proj.weight.auto_grad_accumulation = True # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( - (td_inner, tdt_rank), - init_method=kaiming_init_(tdt_rank.size), - lr_scale=mamba_layer_lr_scale, + (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.dt_rank)), + init_method=init_kaiming_(self._config.dt_rank), + lr_scale=lr_scale, ) self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), - init_method=init_dtprojbias( - self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs - ), - lr_scale=mamba_layer_lr_scale, + (inner_dim,), + init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (td_inner, td_state), + (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.state)), weight_decay=False, - init_method=init_A(self.d_state, self.d_inner), - lr_scale=mamba_layer_lr_scale, + 
init_method=init_A(self._config.state_size, inner_dim.size), + lr_scale=lr_scale, ) # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_inner,), + (inner_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) self.out_proj = Linear( - td_inner, - td_model, + inner_dim, + hidden_dim, bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. - weight_init_method=kaiming_init_(td_model.size), - lr_scale=mamba_layer_lr_scale, - **factory_kwargs, + weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, ) self.out_proj.weight.auto_grad_accumulation = True - self._return_input = return_input - def forward(self, hidden_states, kwargs): + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - batch, seqlen, dim = hidden_states.shape - - # We do matmul and transpose BLH -> HBL at the same time - xz = einops.rearrange( - self.in_proj_weight @ einops.rearrange(hidden_states, "b l d -> d (b l)"), - "d (b l) -> b d l", - l=seqlen, - ) - if self._debug_mode: - print("XZ: ", xz.shape) + in_proj = self.in_proj(input_).permute((1, 2, 0) if kwargs[TransformerKwargs.sequence_first] else (0, 2, 1)) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s out = _mamba_inner_fn( - xz, - self.conv1d_weight, - self.conv1d_bias, + in_proj, + self.conv1d_weight.unsqueeze(1), + None, self.x_proj.weight, self.dt_proj_weight, self.out_proj.weight, self.out_proj.bias, # is None here - A, + -torch.exp(self.A_log.float()), None, # input-dependent B None, # input-dependent C self.D.float(), delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) - if self._return_input: - out = torch.stack((hidden_states, out), dim=0) + if kwargs[TransformerKwargs.sequence_first]: + out = out.transpose(0, 1) return out, None diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3351c9906..76b8ed1ca 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -13,6 +13,7 @@ TransformerKwargs, TransformerSubLayerName, ) +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ from fast_llm.utils import Assert, get_lr_scale @@ -50,7 +51,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(torch.nn.Module): +class Attention(Mixer): """ A self-attention layer. """ diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 147452073..f80e903f0 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -18,13 +18,24 @@ logger = logging.getLogger(__name__) +class Mixer(torch.nn.Module, abc.ABC): + """ + Base class for mixer modules. + """ + + @abc.abstractmethod + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Mixer module forward. 
Returns the output hidden states and an optional bias, + in case its addition can be made more efficient in `_bias_dropout_add`. + """ + + class BaseBlock(Layer, abc.ABC): """ A transformer-like decoder base block with abstract mixer. """ - _mixer_module_name = "self_attn" - def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): @@ -54,7 +65,7 @@ def __init__( self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def _create_mixer(self): + def get_mixer(self) -> Mixer: pass @torch.compile @@ -115,7 +126,7 @@ def forward( hidden_states = self.norm_1(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 1", kwargs) - hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) + hidden_states, bias = self.get_mixer()(hidden_states, kwargs) if self._debug_mode: self._debug_log(hidden_states, f"{self._mixer_module_name} output", kwargs, bias=bias) with set_generator(generator): @@ -137,14 +148,14 @@ def forward( return hidden_states -class TransformerLayer(BaseBlock): +class TransformerBlock(BaseBlock): _name = "Transformer layer" - _mixer_module_name = "self_attn" def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): super().__init__(config, tensor_space, layer_index, return_input) - - def _create_mixer(self): self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + + def get_mixer(self) -> Mixer: + return self.self_attn diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index c206ef406..a9cf3bb8c 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -7,7 +7,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig from fast_llm.models.custom.head import CustomHead from fast_llm.models.gpt.config import GPTBaseModelConfig @@ -31,7 +31,7 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, layer_index=i + 1, diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 444ad72b2..a3a68e0a6 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -21,7 +21,7 @@ TransformerLossNames, ) from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -68,7 +68,7 @@ def get_output_layers(self) -> list[Layer]: for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, # TODO MTP: which index? 
@@ -91,7 +91,7 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, layer_index=i + 1, @@ -336,7 +336,7 @@ def embedding(self) -> LanguageModelEmbedding: return self.layers[0] @property - def transformer_layers(self) -> list[TransformerLayer]: + def transformer_layers(self) -> list[TransformerBlock]: return self.layers[1:-1] @property diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index cc83f11be..c294fe528 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -6,12 +6,11 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig -from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -24,7 +23,7 @@ @config_class() -class HybridSSMBaseModelConfig(LanguageModelBaseConfig): +class HybridSSMBaseModelConfig(GPTBaseModelConfig): _abstract = False ssm: SSMConfig = Field( @@ -51,38 +50,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: Some of these can be setup directly in the layer config, but keeping them here for clarity. 
""" super().setup_tensor_space(tensor_space) - d_inner: int = self.ssm.d_inner - - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.model_dim, self.transformer.hidden_size)) - # Mamba-specific dimensions - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_dim, d_inner)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.state_dim, self.ssm.state_size)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, self.ssm.dt_rank)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.ssm.dt_rank + self.ssm.state_size * 2)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel_size, self.ssm.conv_kernel_dimension)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba, d_inner * 2)) - - if SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout: - # Mamba2 specific dimensions - # as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4 - headdim = d_inner // self.ssm.n_v_heads - Assert.eq(self.ssm.n_v_heads, d_inner // headdim) - Assert.eq(d_inner % headdim, 0) - Assert.eq(self.ssm.n_v_heads % self.ssm.n_qk_heads, 0) - - conv_dim = d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size - inner_proj_dim = 2 * d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size + self.ssm.n_v_heads - - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.head_dim, headdim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.qk_heads, self.ssm.n_qk_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.v_heads, self.ssm.n_v_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_discrete_mamba2, inner_proj_dim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) - elif SSMBlockType.mamba2.value in self.hybrid_block_layout: - inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, self.ssm.d_xb)) + self.ssm.setup_tensor_space(tensor_space) def _validate(self): with self._set_implicit_default(None): diff --git a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py index 6d9746db1..8f49ded40 100644 --- a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py +++ b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py @@ -322,19 +322,21 @@ def __init__(self, config, factory_kwargs, layer_idx, **kwargs): # Mixer self.mixer = DiscreteMamba2( - d_model=self.config.d_model, + d_model=self.config._hidden_size, layer_idx=layer_idx, **config.ssm_cfg, **factory_kwargs, ) # Other components - self.input_layernorm = LlamaRMSNorm(hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs) + self.input_layernorm = LlamaRMSNorm( + hidden_size=self.config._hidden_size, eps=1e-5, factory_kwargs=factory_kwargs + ) self.post_attention_layernorm = LlamaRMSNorm( - hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs + hidden_size=self.config._hidden_size, eps=1e-5, factory_kwargs=factory_kwargs ) self.mlp = LlamaMLP( - hidden_size=self.config.d_model, + hidden_size=self.config._hidden_size, **config.mlp_cfg, factory_kwargs=factory_kwargs, ) diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 02a5ac239..3e57689b6 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -9,7 +9,7 @@ 
from fast_llm.layers.ssm.llamba_block import LlambaBlock from fast_llm.layers.ssm.mamba2 import Mamba2 from fast_llm.layers.ssm.mamba_layer import MambaLayer -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -39,14 +39,14 @@ def get_output_layers(self) -> list[Layer]: Get the output layers of the model. This includes the language model head and any additional heads specified in the configuration. """ - layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] + layers: list[Layer] = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] if self._config.prediction_heads > 1: block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] for i in range(1, self._config.prediction_heads): if block_type == SSMBlockType.transformer: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, layer_index=len(self._config.hybrid_block_layout), @@ -55,8 +55,8 @@ def get_output_layers(self) -> list[Layer]: ) elif block_type == SSMBlockType.mamba2_discrete: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=DiscreteMamba2, layer_index=len(self._config.hybrid_block_layout), tensor_space=self._tensor_space, @@ -65,8 +65,8 @@ def get_output_layers(self) -> list[Layer]: layers.append(mamba_block) elif block_type == SSMBlockType.mamba: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=MambaLayer, layer_index=len(self._config.hybrid_block_layout), tensor_space=self._tensor_space, @@ -75,8 +75,8 @@ def get_output_layers(self) -> list[Layer]: layers.append(mamba_block) elif block_type == SSMBlockType.mamba2: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=Mamba2, layer_index=len(self._config.hybrid_block_layout), tensor_space=self._tensor_space, @@ -94,14 +94,14 @@ def get_layers(self) -> list[Layer]: Create a list of layers for the model, interleaving Transformer and Mamba blocks according to the block pattern. 
""" - layers = [LanguageModelEmbedding(self._config, self._tensor_space)] + layers: list[Layer] = [LanguageModelEmbedding(self._config, self._tensor_space)] # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): if block_type == SSMBlockType.transformer: # Transformer block layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, layer_index=i + 1, @@ -112,8 +112,8 @@ def get_layers(self) -> list[Layer]: ) elif block_type == SSMBlockType.mamba2_discrete: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=DiscreteMamba2, layer_index=i + 1, tensor_space=self._tensor_space, @@ -126,8 +126,8 @@ def get_layers(self) -> list[Layer]: elif block_type == SSMBlockType.mamba: # Create Mamba block mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=MambaLayer, layer_index=i + 1, tensor_space=self._tensor_space, @@ -139,8 +139,8 @@ def get_layers(self) -> list[Layer]: elif block_type == SSMBlockType.mamba2: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=Mamba2, layer_index=i + 1, tensor_space=self._tensor_space, diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index d780e4d6d..b474fe87f 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -354,7 +354,7 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return init_ -def kaiming_init_(d_in): +def init_kaiming_(d_in): return init_normal_(0.0, math.sqrt(2.0 / d_in)) @@ -369,3 +369,9 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return tensor return init_ + + +def init_uniform_centered_( + high, max_val=None, mean=0.0 +) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: + return init_uniform_(mean - high, mean + high, min_val=mean - max_val, max_val=mean + max_val) diff --git a/setup.cfg b/setup.cfg index 2f69b8e06..c086af7d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,7 +48,7 @@ HUGGINGFACE = # Required to run SSMs # To install on cpu environment (ex. 
for IDE support): -# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation +# MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index c530a170c..e5fbc7d69 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -4,7 +4,7 @@ from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.utils import Assert from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -39,7 +39,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerLayer, LlambaBlock)) else 0 + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, LlambaBlock)) else 0 for layer in model_ref.base_model.layers ] for weight_buffer_ref, weight_buffer_frozen in zip( From 4e310c74634a70c4d8117cc025f18a040ffbd098 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 22 Jul 2025 13:04:54 -0400 Subject: [PATCH 02/82] TP mamba --- fast_llm/engine/config_utils/tensor_space.py | 174 ++++++++++++------- fast_llm/layers/common/linear.py | 8 +- fast_llm/layers/common/normalization.py | 4 +- fast_llm/layers/common/peft.py | 4 +- fast_llm/layers/ssm/config.py | 45 +++-- fast_llm/layers/ssm/discrete_mamba2.py | 2 +- fast_llm/layers/ssm/mamba2.py | 22 ++- fast_llm/layers/ssm/mamba_layer.py | 2 +- fast_llm/tensor.py | 31 ++-- 9 files changed, 184 insertions(+), 108 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 99c1bcf70..dceeb7da4 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -5,6 +5,8 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: + import torch + from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.distributed.distributed import Distributed @@ -23,7 +25,7 @@ def __repr__(self) -> str: f"name={self._name}," f" size={self._size}," f" global_size={self._global_size}," - f" parallel_dim={None if self.parallel_dim is None else self._parallel_dim}" + f" parallel_dim={self._parallel_dim}" f")" ) @@ -38,83 +40,134 @@ def name(self) -> str: def size(self) -> int: return self._size - @property - def expanded_shape(self) -> tuple[int, ...]: - return (self._size,) - - @property - def ndim(self) -> int: - return 1 - @property def global_size(self) -> int: return self._global_size @property - def global_expanded_shape(self) -> tuple[int, ...]: - return (self._size if self._parallel_dim is None else self._size * self._parallel_dim.size,) + def is_parallel(self) -> bool: + return self._parallel_dim is not None and self._parallel_dim.size > 1 @property def parallel_dim(self) -> DistributedDim | None: + # TODO: Make more flexible for derived classes? 
return self._parallel_dim - @property - def parallel_dim_index(self) -> int | None: - return None if self._parallel_dim is None else 0 - @property def parallel_group(self) -> "ProcessGroup|None": + # TODO: Make more flexible for derived classes? return None if self._parallel_dim is None else self._parallel_dim.group def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim is not None + assert self.is_parallel return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + if self.parallel_group is not None: + from fast_llm.core.ops import gather_op + + return gather_op(tensor, self.parallel_group, dim) + else: + return tensor + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + return ( + tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] + if self.parallel_dim is not None and self.parallel_dim.size > 1 + else tensor + ) + class CompositeTensorDim(TensorDim): - def __init__(self, name: str, dims: tuple[TensorDim, ...]): - # TODO: Recursive composition?? - parallel_dims = [(i, dim.parallel_dim) for i, dim in enumerate(dims) if dim.parallel_dim] - Assert.leq(len(parallel_dims), 1) + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = None + for dim, tensor_dim in enumerate(tensor_dims): + if tensor_dim.is_parallel: + # TODO: Allow more than one parallel subdim? + assert parallel_dim is None + parallel_dim = tensor_dim.parallel_dim + self._parallel_dim_index = dim super().__init__( name=name, - global_size=math.prod(dim.global_size for dim in dims), - parallel_dim=parallel_dims[0][1] if parallel_dims else None, - ) - self._dims = dims - self._parallel_dim_index = ( - sum(dim.ndim for dim in self._dims[: parallel_dims[0][0]]) - + self._dims[parallel_dims[0][0]].parallel_dim_index - if parallel_dims - else None + global_size=math.prod(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, ) + self._tensor_dims = tensor_dims - @property - def dims(self) -> tuple[TensorDim, ...]: - return self._dims + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + assert self._parallel_dim_index is not None + dims = list(self._tensor_dims) + dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim) + return CompositeTensorDim(self.name, tuple(dims)) - @property - def ndim(self) -> int: - return sum(dim.ndim for dim in self._dims) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global(tensor, dim + i) - @property - def expanded_shape(self) -> tuple[int, ...]: - return sum((dim.expanded_shape for dim in self._dims), ()) + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def global_expanded_shape(self) -> tuple[int, ...]: - return sum((dim.global_expanded_shape for dim in self._dims), ()) + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): + tensor = tensor_dim.global_to_local(tensor, dim + i) + return tensor if expand else 
tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def parallel_dim_index(self) -> int | None: - return self._parallel_dim_index + +class ConcatenatedTensorDim(TensorDim): + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = tensor_dims[0].parallel_dim + for dim, tensor_dim in enumerate(tensor_dims[1:]): + # TODO: Allow more flexibility? + Assert.is_(tensor_dim.parallel_dim, parallel_dim) + + super().__init__( + name=name, + global_size=sum(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, + ) + self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim_index is not None - dims = list(self.dims) - dims[self.parallel_dim_index] = dims[self.parallel_dim_index].replace_parallel_dim(distributed_dim) - return CompositeTensorDim(self.name, tuple(dims)) + # TODO: Implement + raise NotImplementedError() + + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + return ( + torch.concatenate( + [ + tensor_dim.local_to_global(tensor_, dim)[0] + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + if self.is_parallel and expand: + raise NotImplementedError() + return ( + torch.concatenate( + [ + tensor_dim.global_to_local(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.global_size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) class DefaultDimNames: @@ -147,21 +200,22 @@ def distributed(self) -> "Distributed": assert self._is_setup return self._distributed - def add_tensor_dim(self, dim: TensorDim) -> None: - if isinstance(dim, CompositeTensorDim): - for dim_ in dim.dims: - Assert.incl(dim_.name, self._tensor_dims) - Assert.eq(dim_, self._tensor_dims[dim_.name]) - if dim.name in self._tensor_dims: - Assert.eq(dim, self._tensor_dims[dim.name]) + def add_tensor_dim(self, tensor_dim: TensorDim) -> None: + if tensor_dim.name in self._tensor_dims: + Assert.eq(tensor_dim, self._tensor_dims[tensor_dim.name]) else: - if dim.parallel_dim is not None: - assert dim.parallel_dim.name in self._distributed_config.distributed_dims, dim.parallel_dim.name + if tensor_dim.parallel_dim is not None: + assert ( + tensor_dim.parallel_dim.name in self._distributed_config.distributed_dims + ), tensor_dim.parallel_dim.name Assert.eq( - dim.parallel_dim.__dict__, - self._distributed_config.distributed_dims[dim.parallel_dim.name].__dict__, + tensor_dim.parallel_dim.__dict__, + self._distributed_config.distributed_dims[tensor_dim.parallel_dim.name].__dict__, ) - self._tensor_dims[dim.name] = dim + self._tensor_dims[tensor_dim.name] = tensor_dim def get_tensor_dim(self, name: str) -> TensorDim: return self._tensor_dims[name] + + # TODO: Replace uses + __getitem__ = get_tensor_dim diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index cd19a47a5..7249ef569 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -94,8 +94,8 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] 
= None, ): - assert in_dim.parallel_dim is None - assert out_dim.parallel_dim is None + assert not in_dim.is_parallel + assert not out_dim.is_parallel super().__init__( in_dim, out_dim, @@ -132,7 +132,7 @@ def __init__( sequence_parallel: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert in_dim.parallel_dim is None + assert not in_dim.is_parallel self._group_size = 1 if out_dim.parallel_dim is None else out_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( @@ -176,7 +176,7 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert out_dim.parallel_dim is None + assert not out_dim.is_parallel self._group_size = 1 if in_dim.parallel_dim is None else in_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 5f30beaef..bccc1d627 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -158,7 +158,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: @@ -242,7 +242,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft.py index 3a1966e51..08f3e535b 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft.py @@ -19,12 +19,12 @@ def lora_linear( ): layer.weight.requires_grad = False in_dim = layer._in_dim + assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." if in_dim.parallel_dim is not None: - assert in_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." in_dim = TensorDim(in_dim.name, in_dim.global_size) out_dim = layer._out_dim + assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." if out_dim.parallel_dim is not None: - assert out_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." out_dim = TensorDim(out_dim.name, out_dim.global_size) if out_channel_begin is not None or out_channel_end is not None: if out_channel_begin is None: diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index f4c8067dd..ce37a9804 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,7 +1,7 @@ import enum from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig @@ -20,8 +20,7 @@ class SSMDimNames: composite_head_groups_and_state = "ssm_composite_head_groups_and_state" # Inner projection total dimension. 
- inner_projection = "ssm_inner_projection" - composite_inner_projection = "ssm_composite_inner_projection" + concatenated_inner_projection = "ssm_concatenated_inner_projection" # Convolution shape in discrete mamba 2. TODO: Remove (dim too complex) conv_dim = "ssm_conv_dim" @@ -210,7 +209,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType Assert.eq(num_heads, self.n_v_heads) num_head_groups = self.n_qk_heads # (head_groups, (2 * group_heads + 2) * state_dim + group_heads) - inner_projection_size = 2 * self.d_inner + 2 * num_head_groups * self.state_size + num_heads + 2 * self.d_inner + 2 * num_head_groups * self.state_size + num_heads else: raise NotImplementedError(block_type) @@ -219,12 +218,18 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType tensor_space.add_tensor_dim( group_heads := TensorDim(SSMDimNames.group_heads, num_group_heads := div(num_heads, num_head_groups)) ) - tensor_space.add_tensor_dim(CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads))) tensor_space.add_tensor_dim( - CompositeTensorDim(SSMDimNames.composite_heads_and_state, (head_groups, group_heads, state_dim)) + heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(SSMDimNames.composite_head_groups_and_state, (head_groups, state_dim)) + heads_and_state := CompositeTensorDim( + SSMDimNames.composite_heads_and_state, (head_groups, group_heads, state_dim) + ) + ) + tensor_space.add_tensor_dim( + head_groups_and_state := CompositeTensorDim( + SSMDimNames.composite_head_groups_and_state, (head_groups, state_dim) + ) ) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel, self.conv_kernel_dimension)) @@ -234,17 +239,27 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType if block_type == SSMBlockType.mamba: tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.dt_rank + self.state_size * 2)) - inner_projection_size = 2 * num_group_heads * self.state_size + # TODO: Use composition instead + tensor_space.add_tensor_dim( + ConcatenatedTensorDim(SSMDimNames.concatenated_inner_projection, (heads_and_state, heads_and_state)) + ) elif block_type == SSMBlockType.mamba2: - inner_projection_size = 2 * (num_group_heads + 1) * self.state_size + # TODO: Factor out state? + tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, + (heads_and_state, head_groups_and_state, head_groups_and_state, heads_and_state), + ) + ) elif block_type == SSMBlockType.mamba2_discrete: - inner_projection_size = 2 * (num_group_heads + 1) * self.state_size + num_group_heads + # TODO: Factor as (head_groups, (group_heads + 2) * state_size + group_heads)? 
+ tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, + (heads_and_state, head_groups_and_state, head_groups_and_state, heads_and_state, heads), + ) + ) # TODO: (head_groups, group_heads + 2, state_size) tensor_space.add_tensor_dim( TensorDim(SSMDimNames.conv_dim, self.d_inner + 2 * self.n_qk_heads * self.state_size) ) - - tensor_space.add_tensor_dim(inner_projection := TensorDim(SSMDimNames.inner_projection, inner_projection_size)) - tensor_space.add_tensor_dim( - CompositeTensorDim(SSMDimNames.composite_inner_projection, (head_groups, inner_projection)) - ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index d06b47965..988a09504 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -67,7 +67,7 @@ def __init__( td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.head_groups) td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.composite_inner_projection) + td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection) self.d_model = td_model.size self.d_inner = td_inner.size diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 011889d04..dff1356e6 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -45,6 +45,7 @@ def __init__( inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) + dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) self._head_groups = div(self._config.d_xb, self._config.state_size) self._heads = div(self._config.d_inner, self._config.state_size) @@ -65,13 +66,21 @@ def __init__( ) self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space.get_tensor_dim(name=SSMDimNames.composite_inner_projection), + tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), bias=config.add_bias_linear, weight_init_method=init_kaiming_(hidden_dim.size), lr_scale=lr_scale, ) - self.dt_proj = Linear( - tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank), + + self.dt_in_proj = Linear( + hidden_dim, + dt_rank_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, + ) + self.dt_proj = OutputParallelLinear( + dt_rank_dim, inner_dim, bias=False, # Initialize special dt projection to preserve variance at initialization @@ -110,16 +119,19 @@ def forward(self, hidden_states, kwargs): assert _causal_conv1d_available inner_projection = self.in_proj(hidden_states) + dt = self.dt_in_proj(hidden_states) # Standardize to (batch, sequence, inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) + dt = dt.transpose(0, 1) sequence_length = hidden_states.size(1) - z, x, b, c, dt = torch.split( + z, x, b, c = torch.split( inner_projection, - [self._config.d_inner, self._config.d_xb, self._config.d_xb, self._config.d_inner, self._config.dt_rank], + [self._config.d_inner, self._config.d_xb, self._config.d_xb, self._config.d_inner], dim=2, ) + # z: (batch, sequence, heads * state) -> (batch, heads * state, sequence) z = z.transpose(1, 2) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index fa2789b1e..0cdcb5242 100644 --- 
a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -74,7 +74,7 @@ def __init__( # TODO: lr_scale? self.in_proj = Linear( hidden_dim, - tensor_space.get_tensor_dim(SSMDimNames.composite_inner_projection), + tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection), bias=False, weight_init_method=init_kaiming_(hidden_dim.size), ) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b474fe87f..f312f1962 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -5,7 +5,7 @@ import torch from fast_llm.core.distributed import ReduceOp -from fast_llm.core.ops import gather_op, reduce_op +from fast_llm.core.ops import reduce_op from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed @@ -166,14 +166,13 @@ def local_to_global( ) -> tuple[torch.Tensor, ...]: # Tensors are always either split or duplicated in the tensor-parallel direction. # TODO: Avoid hard-coded assumptions on duplication - is_first_rank = distributed.config.tensor_rank == 0 - modified = False - for i, dim in enumerate(self.dims): - if dim.parallel_group is not None: - tensor = gather_op( - tensor.unflatten(i, dim.expanded_shape), dim.parallel_group, i + dim.parallel_dim_index - ).flatten(i, i + len(dim.expanded_shape) - 1) - is_first_rank, modified = is_first_rank and dim.parallel_group.rank() == 0, True + is_first_rank, modified = distributed.config.tensor_rank == 0, False + + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global(tensor, dim) + is_first_rank &= tensor_dim.parallel_dim.rank == 0 + modified = True for distributed_dim, op in self._reductions: if distributed_dim.group is not None: @@ -187,23 +186,19 @@ def local_to_global( def global_to_local( self, tensor: torch.Tensor | SafeTensorSlice, - # Return an expanded tensor, avoiding `flatten` which copies the data. + # Return an expanded tensor, avoiding `flatten` which copies the data. TODO: Rework. expand: bool = False, ) -> torch.Tensor: """ Recover the tensor-parallel slice of a tensor. Support lazy-loaded safetensor slices. """ # Take a trivial slice to convert safetensor slices. 
- tensor_ = tensor[:] + tensor = tensor[:] assert not self._reductions - for i, dim in reversed(list(enumerate(self.dims))): - if dim.parallel_dim is not None and dim.parallel_dim.size > 1: - tensor_ = tensor_.unflatten(i, dim.global_expanded_shape).chunk( - dim.parallel_dim.size, i + dim.parallel_dim_index - )[dim.parallel_dim.rank] - - return tensor_ if expand else tensor_.reshape(self.shape) + for dim, tensor_dim in reversed(list(enumerate(self.dims))): + tensor = tensor_dim.global_to_local(tensor, dim, expand) + return tensor @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): From 3cc41182a71d28e02918d76cd882978ca8384f73 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 22 Jul 2025 16:57:38 -0400 Subject: [PATCH 03/82] fix --- fast_llm/engine/config_utils/tensor_space.py | 6 +- fast_llm/layers/ssm/config.py | 24 +++-- fast_llm/layers/ssm/discrete_mamba2.py | 2 + fast_llm/layers/ssm/llamba_block.py | 10 +- fast_llm/layers/ssm/mamba_layer.py | 13 ++- fast_llm/layers/transformer/transformer.py | 20 ++-- fast_llm/models/ssm/config.py | 41 +++----- fast_llm/models/ssm/model.py | 99 +++++--------------- fast_llm/tensor.py | 7 +- tests/data/test_blending.py | 1 + tests/data/test_concatenate.py | 1 + tests/data/test_fim.py | 2 + tests/test_multi_stage.py | 6 +- tests/utils/model_configs.py | 43 +++++---- 14 files changed, 127 insertions(+), 148 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index dceeb7da4..d927f2e71 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -70,7 +70,7 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor else: return tensor - def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": return ( tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] if self.parallel_dim is not None and self.parallel_dim.size > 1 @@ -108,7 +108,7 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): tensor = tensor_dim.global_to_local(tensor, dim + i) @@ -150,7 +150,7 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor else tensor ) - def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": if self.is_parallel and expand: raise NotImplementedError() return ( diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index ce37a9804..aa011f75f 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -41,6 +41,22 @@ class SSMBlockType(enum.StrEnum): mamba2 = "m2" transformer = "t" + def get_mixer_class(self): + if self == SSMBlockType.mamba: + from fast_llm.layers.ssm.mamba_layer import MambaLayer + + return MambaLayer + elif self == SSMBlockType.mamba2: + 
from fast_llm.layers.ssm.mamba2 import Mamba2 + + return Mamba2 + elif self == SSMBlockType.mamba2_discrete: + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + + return DiscreteMamba2 + else: + raise NotImplementedError(self) + class DTInitType(enum.StrEnum): constant = "constant" @@ -199,17 +215,13 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType # Head groups are configured differently depending on the block type. if block_type == SSMBlockType.mamba: num_head_groups = num_heads - # (head_groups, 2 * group_heads * state_dim) - inner_projection_size = self.d_inner * 2 elif block_type == SSMBlockType.mamba2: num_head_groups = div(self.d_xb, self.state_size) - # (head_groups, 2 * group_heads + 2, state_dim) + (dt,) - inner_projection_size: int = 2 * self.d_inner + 2 * num_head_groups * self.state_size + self.dt_rank elif block_type == SSMBlockType.mamba2_discrete: Assert.eq(num_heads, self.n_v_heads) + # TODO: Fix (Du einsum crashes) + Assert.eq(self.n_qk_heads, self.n_v_heads) num_head_groups = self.n_qk_heads - # (head_groups, (2 * group_heads + 2) * state_dim + group_heads) - 2 * self.d_inner + 2 * num_head_groups * self.state_size + num_heads else: raise NotImplementedError(block_type) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 988a09504..14fb8aaed 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -216,6 +216,8 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ else: y = result + print("AHNFIUWEGIUWEI", self.D.shape, x.shape) + # TODO: h different for D and x (qk_heads, v_heads) Du = torch.einsum("h,blhp->blhp", self.D, x) y = einops.rearrange(y + Du, "b l h p -> b l (h p)") diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index e877ff9c2..774ee7303 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -8,7 +8,7 @@ from fast_llm.layers.transformer.config import TransformerConfig -class LlambaBlock(BaseBlock): +class SSMBlock(BaseBlock): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ @@ -24,9 +24,9 @@ def __init__( layer_index: int, return_input: bool = False, ): - self._debug_mode = self._config_ssm.debug_ssm + self._ssm_config = ssm_config + self._mixer_cls = mixer_cls super().__init__(transformer_config, tensor_space, layer_index, return_input) - self.mixer = mixer_cls(ssm_config, layer_idx=self._layer_index, tensor_space=self._tensor_space) - def get_mixer(self) -> Mixer: - return self.mixer + def _create_mixer(self) -> Mixer: + return self._mixer_cls(self._ssm_config, layer_idx=self._layer_index, tensor_space=self._tensor_space) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 0cdcb5242..8235f4f1a 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,7 +1,6 @@ import logging import math import typing -from typing import Callable import torch @@ -30,21 +29,25 @@ """ -def init_A(d_state, d_inner) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +def init_A(d_state, d_inner) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa # TODO: adopt this initialization to work for tensor parallel setting! 
if tensor.numel() != d_state * d_inner: raise ValueError(f"_init_A requires not supported for tensor slices.") - return torch.log(torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner), out=tensor) + return torch.log( + torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device).repeat(d_inner), out=tensor + ) return init_ def init_dtprojbias( dt_max: float, dt_min: float, dt_init_floor: float -) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - tensor = tensor.uniform_(math.log(dt_min), math.log(dt_max)).exp_().clamp_min(dt_init_floor) + tensor = ( + tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min(dt_init_floor) + ) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 return tensor.add_(torch.log(-torch.expm1(-tensor))) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index f80e903f0..a0611cd29 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,7 +8,6 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP @@ -36,6 +35,9 @@ class BaseBlock(Layer, abc.ABC): A transformer-like decoder base block with abstract mixer. """ + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "mixer" + def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): @@ -54,7 +56,8 @@ def __init__( self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) - self._create_mixer() + # The mixer needs to be created here for backward-compatible weight ordering. 
+ setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index @@ -65,7 +68,7 @@ def __init__( self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def get_mixer(self) -> Mixer: + def _create_mixer(self) -> Mixer: pass @torch.compile @@ -126,7 +129,7 @@ def forward( hidden_states = self.norm_1(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 1", kwargs) - hidden_states, bias = self.get_mixer()(hidden_states, kwargs) + hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) if self._debug_mode: self._debug_log(hidden_states, f"{self._mixer_module_name} output", kwargs, bias=bias) with set_generator(generator): @@ -150,12 +153,15 @@ def forward( class TransformerBlock(BaseBlock): _name = "Transformer layer" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "self_attn" def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): super().__init__(config, tensor_space, layer_index, return_input) - self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) - def get_mixer(self) -> Mixer: - return self.self_attn + def _create_mixer(self) -> Mixer: + from fast_llm.layers.transformer.attention import Attention + + return Attention(self._config, self._tensor_space, self._layer_index) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index c294fe528..6b9e28584 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -30,7 +30,7 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - hybrid_block_layout: list[str] | None = Field( + hybrid_block_layout: list[SSMBlockType] | None = Field( default=None, desc=f"Pattern of blocks to use in the model. Available types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, @@ -43,14 +43,16 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): use_megatron_initialization: bool = Field( default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing ) # TODO: is this needed? + # TODO: Support combination of different SSM block types. + ssm_block_type: SSMBlockType | None = Field(init=False) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: """ Setup the tensor space for the model. - Some of these can be setup directly in the layer config, but keeping them here for clarity. 
""" super().setup_tensor_space(tensor_space) - self.ssm.setup_tensor_space(tensor_space) + if self.ssm_block_type is not None: + self.ssm.setup_tensor_space(tensor_space, self.ssm_block_type) def _validate(self): with self._set_implicit_default(None): @@ -64,30 +66,21 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers if len(self.hybrid_block_layout) != self.transformer.num_layers: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: - raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - ) - num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) - logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" - ) + raise ValueError(message) + num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) + logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) - Assert.custom( - lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", - ) - Assert.custom( - lambda _: self.default_mtp_type in SSMBlockType.__members__.values() or self.default_mtp_type is None, - f"Invalid MTP type: {self.default_mtp_type}. Must be one of {SSMBlockType.__members__.values()} or None", - ) - super()._validate() + ssm_block_types = set(self.hybrid_block_layout) - {SSMBlockType.transformer} + # TODO: Support combination of different SSM block types. + Assert.leq(len(ssm_block_types), 1) + self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): @@ -162,12 +155,6 @@ def _validate(self): logger.warning( "HybridSSMModelConfig is being instantiated. This model is experimental and may not work as expected." 
) - if ( - self.base_model.sequence_first - or self.distributed.sequence_data_parallel > 1 - or self.distributed.sequence_tensor_parallel - ): - raise NotImplementedError(f"Sequence-first not supported for SSMs.") super()._validate() diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 3e57689b6..4a95891a7 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -5,10 +5,7 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.ssm.mamba2 import Mamba2 -from fast_llm.layers.ssm.mamba_layer import MambaLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -31,7 +28,6 @@ def __init__( config: HybridSSMBaseModelConfig, distributed_config: DistributedConfig, ): - self.SSM_BLOCK_CLS = LlambaBlock # TODO: extend to other block types if needed super().__init__(config, distributed_config) def get_output_layers(self) -> list[Layer]: @@ -53,38 +49,17 @@ def get_output_layers(self) -> list[Layer]: return_input=i != self._config.prediction_heads - 1, ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=Mamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. 
Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + layer_index=len(self._config.hybrid_block_layout), + tensor_space=self._tensor_space, + return_input=i != self._config.prediction_heads - 1, + ) + ) layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) return layers @@ -110,47 +85,19 @@ def get_layers(self) -> list[Layer]: ), ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba: - # Create Mamba block - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=Mamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + layer_index=i + 1, + tensor_space=self._tensor_space, + return_input=( + i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 + ), + ) + ) # Add the output layers layers += self.get_output_layers() diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index f312f1962..1111fd044 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -369,4 +369,9 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_uniform_centered_( high, max_val=None, mean=0.0 ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - return init_uniform_(mean - high, mean + high, min_val=mean - max_val, max_val=mean + max_val) + return init_uniform_( + mean - high, + mean + high, + min_val=None if max_val is None else mean - max_val, + max_val=None if max_val is None else mean + max_val, + ) diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 438782dfe..3e6c37632 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -193,6 +193,7 @@ def test_gpt_blended_mixed(): def test_gpt_blended_mixed_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index e951cc2b1..4f36cdf89 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -39,6 +39,7 @@ def test_gpt_concatenate(): def test_gpt_concatenate_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_fim.py 
b/tests/data/test_fim.py index 7472f1958..004b96289 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -58,6 +58,7 @@ def test_gpt_fim(): def test_gpt_fim_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { @@ -81,6 +82,7 @@ def test_gpt_fim_data(): def test_gpt_fim_data_legacy(): + get_test_dataset() get_test_data_and_compare_samples( { "format": "list", diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index e5fbc7d69..2f125717e 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,9 +3,10 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.ssm.llamba_block import LlambaBlock +from fast_llm.layers.ssm.llamba_block import SSMBlock from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.utils import Assert +from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -23,6 +24,7 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): + get_model_test_dataset() args = model_testing_config.config_args + ["run.tensor_logs.save=False"] model_ref = _get_trainer_from_args(args, model_testing_config.model_type)._multi_stage model_frozen = _get_trainer_from_args( @@ -39,7 +41,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, LlambaBlock)) else 0 + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, SSMBlock)) else 0 for layer in model_ref.base_model.layers ] for weight_buffer_ref, weight_buffer_frozen in zip( diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index b96a8963b..b834ed4d1 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -451,16 +451,14 @@ def _update_and_add_testing_config( ) _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. + # Tests hybrid Mamba, llamba converter. "llama", "llamba", model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m']", - "model.base_model.ssm.state_size=8", - "model.base_model.ssm.chunk_size=32", - "model.base_model.ssm.n_qk_heads=8", - "model.base_model.ssm.n_v_heads=8", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=16", ], megatron_args=None, checkpoint_format=LLambaHuggingfaceCheckpointFormat, @@ -468,26 +466,31 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.broken, # TODO: Fix and bring back to `testing_groups` + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, + ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, compare_factor=2.0, - # SSMs don't support sequence-first configurations. 
- skip_tests=("sf", "sdp", "stp", "ms"), + # Micro-sequence split not supported. + skip_tests=("sdp", "ms"), ) _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. - "llamba", + # Tests hybrid discrete Mamba 2. + "llama", "hybrid_discrete_mamba2", model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2d']", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=8", + # TODO: Set to 16 once fixed. + "model.base_model.ssm.n_qk_heads=32", + "model.base_model.ssm.n_v_heads=32", + "model.base_model.ssm.chunk_size=32", ], megatron_args=None, checkpoint_format=None, @@ -497,17 +500,23 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + # TODO: Implement + ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, + # Micro-sequence split and sequence-first not supported. + skip_tests=("sf", "stp", "sdp", "ms"), ) _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. - "llamba", + # Tests hybrid Mamba 2. + "llama", "hybrid_mamba2", model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=16", + "model.base_model.ssm.d_xb=256", ], megatron_args=None, checkpoint_format=None, @@ -517,8 +526,10 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, + # Micro-sequence split not supported. 
+ skip_tests=("sdp", "ms"), ) From 9f7f75c72f1fff36a781773c8c772441d7fa9067 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 22 Jul 2025 19:56:35 -0400 Subject: [PATCH 04/82] fix --- fast_llm/engine/config_utils/tensor_space.py | 6 +++++- fast_llm/layers/ssm/config.py | 2 -- fast_llm/layers/ssm/discrete_mamba2.py | 4 +--- fast_llm/layers/ssm/mamba2.py | 19 +++++++++++-------- fast_llm/layers/ssm/mamba_layer.py | 5 ++++- fast_llm/tensor.py | 6 ++++++ tests/utils/model_configs.py | 9 +++++---- 7 files changed, 32 insertions(+), 19 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index d927f2e71..2ca7e3e9a 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -21,7 +21,7 @@ def __init__(self, name: str, global_size: int | None, parallel_dim: Distributed def __repr__(self) -> str: return ( - f"TensorDim(" + f"{type(self).__name__}(" f"name={self._name}," f" size={self._size}," f" global_size={self._global_size}," @@ -134,6 +134,8 @@ def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: raise NotImplementedError() def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + import torch + return ( torch.concatenate( [ @@ -153,6 +155,8 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": if self.is_parallel and expand: raise NotImplementedError() + import torch + return ( torch.concatenate( [ diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index aa011f75f..7da4283ba 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -219,8 +219,6 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType num_head_groups = div(self.d_xb, self.state_size) elif block_type == SSMBlockType.mamba2_discrete: Assert.eq(num_heads, self.n_v_heads) - # TODO: Fix (Du einsum crashes) - Assert.eq(self.n_qk_heads, self.n_v_heads) num_head_groups = self.n_qk_heads else: raise NotImplementedError(block_type) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 14fb8aaed..102accb85 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -111,7 +111,7 @@ def __init__( # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_n_qk_heads,), + (td_n_v_heads,), weight_decay=False, init_method=init_ones_, lr_scale=mamba_layer_lr_scale, @@ -216,8 +216,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ else: y = result - print("AHNFIUWEGIUWEI", self.D.shape, x.shape) - # TODO: h different for D and x (qk_heads, v_heads) Du = torch.einsum("h,blhp->blhp", self.D, x) y = einops.rearrange(y + Du, "b l h p -> b l (h p)") diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index dff1356e6..11ab91e40 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -4,7 +4,6 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.discrete_mamba2 import bias_init_method from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerDimNames, 
TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer @@ -62,7 +61,9 @@ def __init__( lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (conv1d_dim,), init_method=bias_init_method(self._config.conv_kernel_dimension**-0.5), lr_scale=lr_scale + (conv1d_dim,), + init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + lr_scale=lr_scale, ) self.in_proj = OutputParallelLinear( hidden_dim, @@ -124,7 +125,7 @@ def forward(self, hidden_states, kwargs): if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) dt = dt.transpose(0, 1) - sequence_length = hidden_states.size(1) + sequence_length = inner_projection.size(1) z, x, b, c = torch.split( inner_projection, @@ -177,9 +178,11 @@ def forward(self, hidden_states, kwargs): delta_softplus=True, ) - # y: (batch, heads * state, sequence) -> out: (batch, sequence, hidden) - out = self.out_proj(y.transpose(1, 2))[:, :sequence_length] + # y: (batch, heads * state, sequence) -> (batch, sequence, heads * state) + y = y.transpose(1, 2)[:, :sequence_length] if kwargs[TransformerKwargs.sequence_first]: - out = out.transpose(0, 1) - # TODO: Is contiguous needed? - return out.contiguous(), None + # TODO: Is contiguous needed? + y = y.transpose(0, 1).contiguous() + a, b = self.out_proj(y) + Assert.eq(a.shape, hidden_states.shape) + return a, b diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 8235f4f1a..49b9e45b7 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -35,7 +35,10 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) if tensor.numel() != d_state * d_inner: raise ValueError(f"_init_A requires not supported for tensor slices.") return torch.log( - torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device).repeat(d_inner), out=tensor + torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device) + .unsqueeze(0) + .expand(d_inner, d_state), + out=tensor, ) return init_ diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 1111fd044..25ae49a31 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -164,6 +164,9 @@ def local_to_global( *, distributed: Distributed, ) -> tuple[torch.Tensor, ...]: + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) # Tensors are always either split or duplicated in the tensor-parallel direction. # TODO: Avoid hard-coded assumptions on duplication is_first_rank, modified = distributed.config.tensor_rank == 0, False @@ -195,6 +198,9 @@ def global_to_local( # Take a trivial slice to convert safetensor slices. tensor = tensor[:] assert not self._reductions + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.global_shape) for dim, tensor_dim in reversed(list(enumerate(self.dims))): tensor = tensor_dim.global_to_local(tensor, dim, expand) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index b834ed4d1..47314263b 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -487,9 +487,8 @@ def _update_and_add_testing_config( "model.base_model.hybrid_block_layout=['t','m2d']", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", - # TODO: Set to 16 once fixed. 
- "model.base_model.ssm.n_qk_heads=32", - "model.base_model.ssm.n_v_heads=32", + "model.base_model.ssm.n_qk_heads=8", + "model.base_model.ssm.n_v_heads=16", "model.base_model.ssm.chunk_size=32", ], megatron_args=None, @@ -503,6 +502,7 @@ def _update_and_add_testing_config( # TODO: Implement ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, + compare_factor=2.0, # Micro-sequence split and sequence-first not supported. skip_tests=("sf", "stp", "sdp", "ms"), ) @@ -515,7 +515,7 @@ def _update_and_add_testing_config( extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", "model.base_model.ssm.d_inner=512", - "model.base_model.ssm.state_size=16", + "model.base_model.ssm.state_size=8", "model.base_model.ssm.d_xb=256", ], megatron_args=None, @@ -528,6 +528,7 @@ def _update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, + compare_factor=2.0, # Micro-sequence split not supported. skip_tests=("sdp", "ms"), ) From 4054e047d7318c2dfd6e37712f3b6b94d3beca5b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 15:22:24 -0400 Subject: [PATCH 05/82] fixes --- fast_llm/engine/config_utils/tensor_space.py | 11 ++++-- fast_llm/engine/multi_stage/stage_base.py | 2 + fast_llm/layers/ssm/mamba2.py | 41 +++++++++++--------- fast_llm/tensor.py | 2 + 4 files changed, 34 insertions(+), 22 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 2ca7e3e9a..0d971a88a 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -1,3 +1,4 @@ +import logging import math import typing @@ -10,6 +11,8 @@ from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.distributed.distributed import Distributed +logger = logging.getLogger(__name__) + class TensorDim: def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None): @@ -130,8 +133,10 @@ def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - # TODO: Implement - raise NotImplementedError() + assert self.is_parallel + return ConcatenatedTensorDim( + self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims) + ) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": import torch @@ -139,7 +144,7 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor return ( torch.concatenate( [ - tensor_dim.local_to_global(tensor_, dim)[0] + tensor_dim.local_to_global(tensor_, dim) for tensor_, tensor_dim in zip( tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), self._tensor_dims, diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 2f18f1360..9a8ce2092 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -191,6 +191,8 @@ def initialize_weights(self) -> None: # Initialize all global weights on every gpu, then select the appropriate slice if applicable. global_param = parameter.new_empty(global_shape, device=self._distributed.device) meta.init_parameter(global_param, distributed=self._distributed) + # It happens. 
+ Assert.eq(global_param.shape, global_shape) if self._mode.on_device: parameter.copy_(fsdp.parameter_global_to_shard(global_param, meta.tensor_name)) elif self._mode.on_device: diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 11ab91e40..a285711c6 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,3 +1,5 @@ +import logging + import torch from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -24,6 +26,8 @@ except (ImportError, RuntimeError): _causal_conv1d_available = False +logger = logging.getLogger(__name__) + class Mamba2(Mixer): """ @@ -43,21 +47,20 @@ def __init__( lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) + xb_dim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) - self._head_groups = div(self._config.d_xb, self._config.state_size) - self._heads = div(self._config.d_inner, self._config.state_size) - self._group_heads = div(self._heads, self._head_groups) + self._local_heads = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads).size + self._local_head_groups = tensor_space.get_tensor_dim(name=SSMDimNames.head_groups).size + self._group_heads = div(self._local_heads, self._local_head_groups) + self._local_inner_size = inner_dim.size + self._local_xb_size = xb_dim.size - conv1d_dim = ( - inner_dim - if self._config.repeat_kv_before_conv - else tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) - ) + conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim self.conv1d_weight = ParameterMeta.from_dims( (conv1d_dim, tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel)), - init_method=init_uniform_centered_((conv1d_dim.size * self._config.conv_kernel_dimension) ** -0.5), + init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( @@ -69,7 +72,7 @@ def __init__( hidden_dim, tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), bias=config.add_bias_linear, - weight_init_method=init_kaiming_(hidden_dim.size), + weight_init_method=init_kaiming_(hidden_dim.global_size), lr_scale=lr_scale, ) @@ -77,7 +80,7 @@ def __init__( hidden_dim, dt_rank_dim, bias=config.add_bias_linear, - weight_init_method=init_kaiming_(hidden_dim.size), + weight_init_method=init_kaiming_(hidden_dim.global_size), lr_scale=lr_scale, ) self.dt_proj = OutputParallelLinear( @@ -129,7 +132,7 @@ def forward(self, hidden_states, kwargs): z, x, b, c = torch.split( inner_projection, - [self._config.d_inner, self._config.d_xb, self._config.d_xb, self._config.d_inner], + [self._local_inner_size, self._local_xb_size, self._local_xb_size, self._local_inner_size], dim=2, ) @@ -140,28 +143,28 @@ def forward(self, hidden_states, kwargs): x = x.transpose(1, 2) if self._config.repeat_kv_before_conv: x = ( - x.unflatten(1, (self._head_groups, self._config.state_size)) - .repeat_interleave(self._group_heads, 1, output_size=self._heads) + x.unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) x = _causal_conv1d_fn(x=x, 
weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") else: x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") x = ( - x.unflatten(1, (self._head_groups, self._config.state_size)) - .repeat_interleave(self._group_heads, 1, output_size=self._heads) + x.unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) # b: (batch, sequence, head_groups * state) -> (batch, heads, state, sequence) b = ( b.transpose(1, 2) - .unflatten(1, (self._head_groups, self._config.state_size)) - .repeat_interleave(self._group_heads, 1, output_size=self._heads) + .unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) ) # c: (batch, sequence, heads * state) -> (batch, heads, state, sequence) - c = c.transpose(1, 2).unflatten(1, (self._heads, self._config.state_size)) + c = c.transpose(1, 2).unflatten(1, (self._local_heads, self._config.state_size)) # dt: (batch, sequence, dt_rank) -> (batch, heads * state, sequence) dt = (self.dt_proj(dt) + self.dt_proj_bias).transpose(1, 2) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 25ae49a31..6995e9e94 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -184,6 +184,7 @@ def local_to_global( tensor = tensor.clone() tensor = reduce_op(tensor, distributed_dim.group, op=op) is_first_rank, modified = is_first_rank and distributed_dim.group.rank() == 0, True + Assert.eq(tensor.shape, self.global_shape) return tensor, is_first_rank def global_to_local( @@ -204,6 +205,7 @@ def global_to_local( for dim, tensor_dim in reversed(list(enumerate(self.dims))): tensor = tensor_dim.global_to_local(tensor, dim, expand) + Assert.eq(tensor.shape, self.shape) return tensor @classmethod From 0014cc6b3f79138e53610dc86cb654a5eaba90a0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 18:02:43 -0400 Subject: [PATCH 06/82] fix --- fast_llm/layers/ssm/discrete_mamba2.py | 27 +++----- fast_llm/layers/ssm/llamba_block.py | 11 ++- fast_llm/layers/ssm/mamba2.py | 53 ++++++++++---- fast_llm/layers/ssm/mamba_layer.py | 11 +-- fast_llm/layers/transformer/attention.py | 69 ++++--------------- .../layers/transformer/mixture_of_experts.py | 6 +- fast_llm/layers/transformer/mlp.py | 10 +-- fast_llm/layers/transformer/transformer.py | 63 ++++++++++++++--- fast_llm/models/custom/model.py | 2 +- fast_llm/models/gpt/model.py | 4 +- fast_llm/models/ssm/model.py | 8 +-- tests/utils/model_configs.py | 6 +- 12 files changed, 154 insertions(+), 116 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 102accb85..b95ff76da 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -8,7 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.utils import get_lr_scale @@ -37,28 +38,23 @@ def bias_init_method(conv_weight): return 
init_uniform_centered_(bound) -class DiscreteMamba2(torch.nn.Module): +class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" + _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): - """ - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. - Other options are all experimental and should not need to be configured. - """ - # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") + logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") td_inner = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) td_state = tensor_space.get_tensor_dim(SSMDimNames.state) @@ -223,9 +219,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - if self._return_input: - return torch.stack([input_, outputs["hidden_states"]], dim=0) - # TODO: since we do not support inference for now, we only return the hidden states for now. 
return outputs["hidden_states"], None diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index 774ee7303..986606634 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -21,12 +21,17 @@ def __init__( ssm_config: "SSMConfig", tensor_space: "TensorSpace", mixer_cls: type[Mixer], - layer_index: int, + block_index: int, return_input: bool = False, ): self._ssm_config = ssm_config self._mixer_cls = mixer_cls - super().__init__(transformer_config, tensor_space, layer_index, return_input) + super().__init__(transformer_config, tensor_space, block_index, return_input) def _create_mixer(self) -> Mixer: - return self._mixer_cls(self._ssm_config, layer_idx=self._layer_index, tensor_space=self._tensor_space) + return self._mixer_cls( + self._ssm_config, + tensor_space=self._tensor_space, + block_index=self._block_index, + transformer_config=self._config, + ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a285711c6..88fe4abc0 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,4 +1,5 @@ import logging +import typing import torch @@ -7,7 +8,7 @@ from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ from fast_llm.utils import Assert, div, get_lr_scale @@ -34,16 +35,31 @@ class Mamba2(Mixer): This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ + _mixer_name: typing.ClassVar[str] = "mamba_2" + + _XZ_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.composite_heads_and_state, + TransformerDimNames.sequence_q, + ) + _BC_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.composite_heads, + SSMDimNames.state, + TransformerDimNames.sequence_q, + ) + def __init__( self, config: SSMConfig, - layer_idx: int, tensor_space: TensorSpace, + block_index: int, + transformer_config: TransformerConfig, ): - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self._config: SSMConfig = config Assert.eq(self._config.activation_type, ActivationType.silu) - layer_lr_scale: float | None = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) @@ -72,7 +88,8 @@ def __init__( hidden_dim, tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), bias=config.add_bias_linear, - weight_init_method=init_kaiming_(hidden_dim.global_size), + weight_init_method=init_kaiming_(transformer_config.hidden_size), + sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -80,7 +97,7 @@ def __init__( hidden_dim, dt_rank_dim, bias=config.add_bias_linear, - weight_init_method=init_kaiming_(hidden_dim.global_size), + weight_init_method=init_kaiming_(transformer_config.hidden_size), lr_scale=lr_scale, ) self.dt_proj = OutputParallelLinear( @@ -91,6 +108,7 @@ def __init__( weight_init_method=self._config.dt_init.get_init_method( self._config.dt_rank**-0.5 * self._config.dt_scale ), + sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) # define bias outside the linear layer since its also used in the selective_scan_fn @@ -116,6 +134,8 @@ def __init__( hidden_dim, bias=config.add_bias_linear, weight_init_method=init_kaiming_(self._config.d_inner), + sequence_parallel=self._sequence_parallel, + # TODO: lr_scale? ) def forward(self, hidden_states, kwargs): @@ -123,11 +143,12 @@ def forward(self, hidden_states, kwargs): assert _causal_conv1d_available inner_projection = self.in_proj(hidden_states) - dt = self.dt_in_proj(hidden_states) + dt = self.dt_proj(self.dt_in_proj(hidden_states)) + self.dt_proj_bias # Standardize to (batch, sequence, inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) dt = dt.transpose(0, 1) + sequence_length = inner_projection.size(1) z, x, b, c = torch.split( @@ -166,8 +187,15 @@ def forward(self, hidden_states, kwargs): # c: (batch, sequence, heads * state) -> (batch, heads, state, sequence) c = c.transpose(1, 2).unflatten(1, (self._local_heads, self._config.state_size)) - # dt: (batch, sequence, dt_rank) -> (batch, heads * state, sequence) - dt = (self.dt_proj(dt) + self.dt_proj_bias).transpose(1, 2) + # dt: (batch, sequence, heads * state) -> (batch, heads * state, sequence) + dt = dt.transpose(1, 2) + + if self._debug_level: + self._debug_log(z, "z", self._XZ_DIMS, kwargs) + self._debug_log(x, "x", self._XZ_DIMS, kwargs) + self._debug_log(b, "b", self._BC_DIMS, kwargs) + self._debug_log(c, "c", self._BC_DIMS, kwargs) + self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) y = selective_scan_fn( x, @@ -181,11 +209,12 @@ def forward(self, hidden_states, kwargs): delta_softplus=True, ) + if self._debug_level: + self._debug_log(y, "y", self._XZ_DIMS, kwargs) + # y: (batch, heads * state, sequence) -> (batch, sequence, heads * state) y = y.transpose(1, 2)[:, :sequence_length] if kwargs[TransformerKwargs.sequence_first]: # TODO: Is contiguous needed? 
y = y.transpose(0, 1).contiguous() - a, b = self.out_proj(y) - Assert.eq(a.shape, hidden_states.shape) - return a, b + return self.out_proj(y) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 49b9e45b7..49afa910e 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -8,7 +8,7 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ from fast_llm.utils import Assert, get_lr_scale @@ -58,13 +58,16 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) class MambaLayer(Mixer): + _mixer_name: typing.ClassVar[str] = "mamba" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, + transformer_config: TransformerConfig, ): - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" self._config = config # TODO: It's not silu? @@ -73,7 +76,7 @@ def __init__( # Tensor dims: inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) # TODO: Backward compatibility? diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 76b8ed1ca..174e19588 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -14,9 +14,8 @@ TransformerSubLayerName, ) from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.logging import log_distributed_grad, log_distributed_tensor -from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.tensor import init_normal_, init_zeros_ +from fast_llm.utils import get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -56,6 +55,8 @@ class Attention(Mixer): A self-attention layer. 
""" + _mixer_name: typing.ClassVar[str] = "attn" + _QUERY_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, @@ -65,7 +66,7 @@ class Attention(Mixer): _KV_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, - TransformerDimNames.group_heads, + TransformerDimNames.head_groups, TransformerDimNames.kv_channels, ) _CONTEXT_DIMS = ( @@ -74,19 +75,9 @@ class Attention(Mixer): TransformerDimNames.composite_dense, ) - def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index, - ): - super().__init__() + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): + super().__init__(tensor_space, block_index, config.debug_transformer) self._config = config - self._tensor_space = tensor_space - # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) - self._layer_index = layer_index - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._debug_transformer = self._config.debug_transformer self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -109,7 +100,7 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) @@ -179,10 +170,10 @@ def _attn_fused( query, key, beta=0, - alpha=self._softmax_scale / self._layer_index, + alpha=self._softmax_scale / self._block_index, ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) - attn_weights = attn_weights.to(torch.float32) * self._layer_index + attn_weights = attn_weights.to(torch.float32) * self._block_index attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) @@ -201,40 +192,6 @@ def _attn_fused( .flatten(2) ) - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[TransformerKwargs.hidden_dims]} - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) - for dim_name in dim_names - ), - tensor_name=f"transformer layer {self._layer_index} attn {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_transformer, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - def _query_key_value_forward( self, input_: torch.Tensor, sequence_first: bool ) -> tuple[torch.Tensor, torch.Tensor, dict[str, typing.Any]]: @@ -301,7 +258,7 @@ def _decide_window_size(self) -> int | None: # 
https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71 # TODO: make universal per layer config window_size = self._config.window_size - if self._config.max_window_layers is not None and self._layer_index < self._config.max_window_layers: + if self._config.max_window_layers is not None and self._block_index < self._config.max_window_layers: window_size = None return window_size @@ -342,7 +299,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) self._debug_log( key, @@ -396,7 +353,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ kwargs[TransformerKwargs.attention_mask_value], ) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query", self._QUERY_DIMS, kwargs) self._debug_log( key, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index a46af1387..73f83ccf5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -40,11 +40,11 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." 
- super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory @@ -59,7 +59,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index b01eb2aa5..efe0c5cc5 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -14,10 +14,10 @@ class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): super().__init__() self._name = name - self._layer_index = layer_index + self._block_index = block_index init_method_1 = init_normal_( std=config.init_method_std_mlp_1, @@ -39,7 +39,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale lr_scale = get_lr_scale(lr_scale, layer_lr_scale) @@ -69,9 +69,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) def forward( self, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index a0611cd29..d08db9a94 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -13,6 +13,7 @@ from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -22,6 +23,15 @@ class Mixer(torch.nn.Module, abc.ABC): Base class for mixer modules. 
""" + _mixer_name: typing.ClassVar[str] + + def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): + super().__init__() + self._tensor_space = tensor_space + self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel + self._block_index = block_index + self._debug_level = debug_level + @abc.abstractmethod def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: """ @@ -29,6 +39,43 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ in case its addition can be made more efficient in `_bias_dropout_add`. """ + def _get_meta( + self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim + for dim in kwargs[TransformerKwargs.hidden_dims] + (kwargs[TransformerKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + for dim_name in dim_names + ), + tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", + dtype=input_.dtype, + ) + + def _debug_log( + self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> None: + # TODO: Local vs global + Assert.gt(self._debug_level, 0) + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + class BaseBlock(Layer, abc.ABC): """ @@ -39,7 +86,7 @@ class BaseBlock(Layer, abc.ABC): _mixer_module_name: typing.ClassVar[str] = "mixer" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): super().__init__() self._config: TransformerConfig = config @@ -48,11 +95,11 @@ def __init__( # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input - self._layer_index = layer_index + self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) # Note, layer_lr_scale does not impact the norms - # TODO: add a seperate norm_lr_scale + # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) @@ -60,7 +107,7 @@ def __init__( setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index + self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index ) # PEFT. 
@@ -81,7 +128,7 @@ def _bias_dropout_add( @property def name(self) -> str: - return f"{self._name} {self._layer_index}" + return f"{self._name} {self._block_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[TransformerKwargs.hidden_dims] @@ -157,11 +204,11 @@ class TransformerBlock(BaseBlock): _mixer_module_name: typing.ClassVar[str] = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): - super().__init__(config, tensor_space, layer_index, return_input) + super().__init__(config, tensor_space, block_index, return_input) def _create_mixer(self) -> Mixer: from fast_llm.layers.transformer.attention import Attention - return Attention(self._config, self._tensor_space, self._layer_index) + return Attention(self._config, self._tensor_space, self._block_index) diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index a9cf3bb8c..534d813ff 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -34,7 +34,7 @@ def get_layers(self) -> list[Layer]: TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, ) for i in range(self._config.transformer.num_layers) ], diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a3a68e0a6..4c1eab46f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -72,7 +72,7 @@ def get_output_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, # TODO MTP: which index? - layer_index=max(self._config.transformer.num_layers + i, 1), + block_index=max(self._config.transformer.num_layers + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -94,7 +94,7 @@ def get_layers(self) -> list[Layer]: TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. 
return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 4a95891a7..89f0cd4aa 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -45,7 +45,7 @@ def get_output_layers(self) -> list[Layer]: TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=len(self._config.hybrid_block_layout), + block_index=len(self._config.hybrid_block_layout), return_input=i != self._config.prediction_heads - 1, ) ) @@ -55,7 +55,7 @@ def get_output_layers(self) -> list[Layer]: transformer_config=self._config.transformer, ssm_config=self._config.ssm, mixer_cls=self._config.ssm_block_type.get_mixer_class(), - layer_index=len(self._config.hybrid_block_layout), + block_index=len(self._config.hybrid_block_layout), tensor_space=self._tensor_space, return_input=i != self._config.prediction_heads - 1, ) @@ -79,7 +79,7 @@ def get_layers(self) -> list[Layer]: TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 ), @@ -91,7 +91,7 @@ def get_layers(self) -> list[Layer]: transformer_config=self._config.transformer, ssm_config=self._config.ssm, mixer_cls=self._config.ssm_block_type.get_mixer_class(), - layer_index=i + 1, + block_index=i + 1, tensor_space=self._tensor_space, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 47314263b..4090e5a38 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -517,6 +517,7 @@ def _update_and_add_testing_config( "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", "model.base_model.ssm.d_xb=256", + # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}" ], megatron_args=None, checkpoint_format=None, @@ -530,7 +531,10 @@ def _update_and_add_testing_config( }, compare_factor=2.0, # Micro-sequence split not supported. 
- skip_tests=("sdp", "ms"), + skip_tests=( + "sdp", + "ms", + ), # "pp","dp", "ce","16", "bf", "df", "stp"), ) From 47ad5485454236d557570a32771c5888bbb3658e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 19:03:01 -0400 Subject: [PATCH 07/82] fixes --- Megatron-LM | 2 +- fast_llm/layers/language_model/head.py | 16 ++++++++++------ fast_llm/logging.py | 2 ++ fast_llm/tensor.py | 3 ++- tests/test_attention.py | 4 ++-- tests/utils/model_configs.py | 2 +- 6 files changed, 18 insertions(+), 11 deletions(-) diff --git a/Megatron-LM b/Megatron-LM index 511e8f5cb..75b0d9787 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 511e8f5cbe3ab8291953ac64e5beceb727a1b814 +Subproject commit 75b0d97876006c4b6b23fce302100d18dbf7db37 diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 25fc2b28d..21bf3bbd0 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -125,12 +125,16 @@ def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: if isinstance(input_, TensorMeta): - return TensorMeta.from_tensor_space( - (DefaultDimNames.scalar,), - self._tensor_space, - tensor_name="Loss", - reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa - ) + if self._is_last_head: + return TensorMeta.from_tensor_space( + (DefaultDimNames.scalar,), + self._tensor_space, + tensor_name="Loss", + reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + ) + else: + return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") + if not self._is_last_head: # MTP: split the stacked input shared_hidden, input_ = torch.unbind(input_, dim=0) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index e8334de6e..6d555a0bb 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -138,6 +138,8 @@ def log_tensor[ if level < 1: return tensor = tensor.detach() + if tensor.ndim == 0: + tensor = tensor[None] save_stats = TensorLogs.config.save shape = tuple(tensor.shape) _, dtype = str(tensor.dtype).split("torch.") diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 6995e9e94..899e70005 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -205,7 +205,8 @@ def global_to_local( for dim, tensor_dim in reversed(list(enumerate(self.dims))): tensor = tensor_dim.global_to_local(tensor, dim, expand) - Assert.eq(tensor.shape, self.shape) + if not expand: + Assert.eq(tensor.shape, self.shape) return tensor @classmethod diff --git a/tests/test_attention.py b/tests/test_attention.py index 87b0d3e59..dd36b840a 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -17,12 +17,12 @@ def test_decide_window_size(): # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 2 + attention._block_index = 2 assert attention._decide_window_size() == 512 # Arrange - Case 2: window_size is None (layer_index < max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 1 + attention._block_index = 1 assert attention._decide_window_size() is None # Arrange - Case 3: max_window_layers is None (always return window_size) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 4090e5a38..18db0d401 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -467,7 +467,7 @@ def 
_update_and_add_testing_config( ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.convert: ModelTestingGroupAction.broken, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, From 6a074fa3c72bbe16c617a11cff690c543e4c5e86 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 19:50:05 -0400 Subject: [PATCH 08/82] fixes --- fast_llm/layers/ssm/config.py | 2 +- fast_llm/models/ssm/conversion.py | 18 ++++++---- tests/utils/model_configs.py | 55 ++++++++++++++++--------------- 3 files changed, 41 insertions(+), 34 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 7da4283ba..15a6a8210 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -168,7 +168,7 @@ class SSMConfig(LLMBlockConfig): # Initialization # dt_weight_initialization_method [Mamba2] dt_init: DTInitType = Field( - default="random", + default=DTInitType.random, desc="Initialization method for dt", hint=FieldHint.core, ) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index d57300252..43e3c67e5 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,6 +3,7 @@ import pathlib import typing +from fast_llm.config import MISSING from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, @@ -19,7 +20,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig -from fast_llm.layers.ssm.config import SSMBlockType +from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, @@ -42,11 +43,11 @@ class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - block_converter = RenameParamConverter( + block_converter = MappedConfigParamConverter( fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),), - ignore_missing=True, - default_value=[cls._default_block_type], + fast_llm_value=lambda x: [cls._default_block_type] if x == MISSING else x, + export_value=lambda x: [x_.value for x_ in x], ) return super()._create_config_converters() + [block_converter] @@ -202,7 +203,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ignore_missing=True, default_value=4, ), - RenameParamConverter( + MappedConfigParamConverter( fast_llm_names=(("ssm", "dt_init"),), export_names=( ( @@ -210,8 +211,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: "dt_init", ), ), - ignore_missing=True, - default_value="random", + fast_llm_value=lambda x: DTInitType.random if x == MISSING else DTInitType(x), + export_value=lambda x: x.value, ), ] @@ -258,6 +259,9 @@ def _create_weight_converters(self) -> list[WeightConverter]: ) # ================================================ # Mamba2 specific parameters + converters += 
self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.dt_in_proj", f"model.layers.{i}.mixer.dt_in_proj", ssm_bias + ) converters += self._get_weight_and_bias_converters( f"layers.{i+1}.mixer.dt_proj", f"model.layers.{i}.mixer.dt_proj", False ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 18db0d401..3ffc3281b 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -19,7 +19,10 @@ Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.config import ( + AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, + LLambaHuggingfaceCheckpointFormat, +) from tests.utils.dataset import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE from tests.utils.distributed_configs import DistributedTestingConfig @@ -467,7 +470,7 @@ def _update_and_add_testing_config( ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, @@ -477,47 +480,49 @@ def _update_and_add_testing_config( skip_tests=("sdp", "ms"), ) - _update_and_add_testing_config( - # Tests hybrid discrete Mamba 2. + # Tests hybrid Mamba 2. "llama", - "hybrid_discrete_mamba2", + "hybrid_mamba2", model_type="hybrid_ssm", extra_args=[ - "model.base_model.hybrid_block_layout=['t','m2d']", + "model.base_model.hybrid_block_layout=['t','m2']", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", - "model.base_model.ssm.n_qk_heads=8", - "model.base_model.ssm.n_v_heads=16", - "model.base_model.ssm.chunk_size=32", + "model.base_model.ssm.d_xb=256", + # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}" ], megatron_args=None, - checkpoint_format=None, + checkpoint_format=AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - # TODO: Implement - ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, compare_factor=2.0, - # Micro-sequence split and sequence-first not supported. - skip_tests=("sf", "stp", "sdp", "ms"), + # Micro-sequence split not supported. + skip_tests=( + "sdp", + "ms", + ), # "pp","dp", "ce","16", "bf", "df", "stp"), ) + _update_and_add_testing_config( - # Tests hybrid Mamba 2. + # Tests hybrid discrete Mamba 2. 
"llama", - "hybrid_mamba2", + "hybrid_discrete_mamba2", model_type="hybrid_ssm", extra_args=[ - "model.base_model.hybrid_block_layout=['t','m2']", + "model.base_model.hybrid_block_layout=['t','m2d']", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", - "model.base_model.ssm.d_xb=256", - # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}" + "model.base_model.ssm.n_qk_heads=8", + "model.base_model.ssm.n_v_heads=16", + "model.base_model.ssm.chunk_size=32", ], megatron_args=None, checkpoint_format=None, @@ -527,14 +532,12 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + # TODO: Implement + ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, compare_factor=2.0, - # Micro-sequence split not supported. - skip_tests=( - "sdp", - "ms", - ), # "pp","dp", "ce","16", "bf", "df", "stp"), + # Micro-sequence split and sequence-first not supported. + skip_tests=("sf", "stp", "sdp", "ms"), ) From d66651f5433392794d1b45560282d9237824881d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 19:56:19 -0400 Subject: [PATCH 09/82] Update external --- .../modeling_ssm_hybrid_apriel15b.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index f8f6a0520..4fde72458 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -843,9 +843,8 @@ def __init__( self.num_C_head = self.d_inner // self.d_state self.repeat_group = self.num_C_head // self.num_xb_head - self.in_proj = nn.Linear( - self.d_model, 2 * self.d_xb + 2 * self.d_inner + self.dt_rank, bias=bias, **factory_kwargs - ) + self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs) + self.dt_in_proj = nn.Linear(self.d_model, self.dt_rank, bias=bias, **factory_kwargs) self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization @@ -933,8 +932,17 @@ def forward( outputs = {} A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split( + zxbc, + [ + self.d_inner, + self.d_xb, + self.d_xb, + self.d_inner, + ], + dim=-1, + ) x = rearrange(x, "b l d -> b d l") z = rearrange(z, "b l d -> b d l") @@ -944,7 +952,7 @@ def forward( B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) # B, L, d_inner dt = rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: From 50083ba88a0bfa58747d2bc8307814b62af1a79a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 15:14:13 -0400 Subject: [PATCH 10/82] SSM debugging --- 
Megatron-LM | 2 +- fast_llm/engine/multi_stage/stage_base.py | 2 + fast_llm/layers/language_model/head.py | 16 ++- fast_llm/layers/ssm/config.py | 34 +++--- fast_llm/layers/ssm/discrete_mamba2.py | 23 ++-- fast_llm/layers/ssm/llamba_block.py | 29 +++-- fast_llm/layers/ssm/mamba2.py | 38 ++++-- fast_llm/layers/ssm/mamba_layer.py | 36 +++--- fast_llm/layers/transformer/attention.py | 72 +++-------- .../layers/transformer/mixture_of_experts.py | 6 +- fast_llm/layers/transformer/mlp.py | 10 +- fast_llm/layers/transformer/transformer.py | 94 ++++++++++++--- fast_llm/logging.py | 2 + fast_llm/models/gpt/model.py | 12 +- fast_llm/models/ssm/config.py | 40 +++---- fast_llm/models/ssm/model.py | 113 +++++------------- setup.cfg | 7 +- tests/data/test_blending.py | 1 + tests/data/test_concatenate.py | 1 + tests/data/test_fim.py | 2 + tests/test_attention.py | 4 +- tests/test_multi_stage.py | 8 +- tests/utils/model_configs.py | 1 + 23 files changed, 271 insertions(+), 282 deletions(-) diff --git a/Megatron-LM b/Megatron-LM index 511e8f5cb..75b0d9787 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 511e8f5cbe3ab8291953ac64e5beceb727a1b814 +Subproject commit 75b0d97876006c4b6b23fce302100d18dbf7db37 diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 2f18f1360..9a8ce2092 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -191,6 +191,8 @@ def initialize_weights(self) -> None: # Initialize all global weights on every gpu, then select the appropriate slice if applicable. global_param = parameter.new_empty(global_shape, device=self._distributed.device) meta.init_parameter(global_param, distributed=self._distributed) + # It happens. + Assert.eq(global_param.shape, global_shape) if self._mode.on_device: parameter.copy_(fsdp.parameter_global_to_shard(global_param, meta.tensor_name)) elif self._mode.on_device: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 25fc2b28d..21bf3bbd0 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -125,12 +125,16 @@ def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: if isinstance(input_, TensorMeta): - return TensorMeta.from_tensor_space( - (DefaultDimNames.scalar,), - self._tensor_space, - tensor_name="Loss", - reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa - ) + if self._is_last_head: + return TensorMeta.from_tensor_space( + (DefaultDimNames.scalar,), + self._tensor_space, + tensor_name="Loss", + reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + ) + else: + return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") + if not self._is_last_head: # MTP: split the stacked input shared_hidden, input_ = torch.unbind(input_, dim=0) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 46d629aa8..a1f357de9 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -23,6 +23,7 @@ class SSMDimNames: # Mamba 2 x_proj_dim_2 = "x_proj_dim_2" # d_xb + c_heads = "c_heads" class SSMBlockType(enum.StrEnum): @@ -35,6 +36,22 @@ class SSMBlockType(enum.StrEnum): mamba2 = "m2" transformer = "t" + def get_mixer_class(self): + if self == SSMBlockType.mamba: + from fast_llm.layers.ssm.mamba_layer import MambaLayer + + return MambaLayer + elif self == SSMBlockType.mamba2: + from 
fast_llm.layers.ssm.mamba2 import Mamba2 + + return Mamba2 + elif self == SSMBlockType.mamba2_discrete: + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + + return DiscreteMamba2 + else: + raise NotImplementedError(self) + @config_class() class SSMConfig(LLMBlockConfig): @@ -95,11 +112,6 @@ class SSMConfig(LLMBlockConfig): desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", hint=FieldHint.architecture, ) - debug_ssm: bool = Field( - default=False, - desc="debug_ssm", - hint=FieldHint.optional, - ) dt_min: float = Field( default=0.001, desc="Minimum step size for discretization", @@ -147,18 +159,6 @@ class SSMConfig(LLMBlockConfig): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_min: float = Field( - default=0.001, - desc="Minimum step size for discretization", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) dt_scale: float = Field( default=1.0, desc="Scale for dt", diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 934cd2b5d..734e35b21 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,5 +1,6 @@ import logging import math +import typing import einops import torch @@ -7,7 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -36,29 +38,29 @@ def bias_init_method(conv_weight): return init_uniform_(-bound, bound) -class DiscreteMamba2(torch.nn.Module): +class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" + _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): """ See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. Other options are all experimental and should not need to be configured. 
""" # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config bias = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") + logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) @@ -226,9 +228,6 @@ def forward(self, hidden_states, kwargs): out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - if self._return_input: - return torch.stack([input_, outputs["hidden_states"]], dim=0) - # TODO: since we do not support inference for now, we only return the hidden states for now. return outputs["hidden_states"], None diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index ee222d6d2..986606634 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -1,6 +1,6 @@ import typing -from fast_llm.layers.transformer.transformer import BaseBlock +from fast_llm.layers.transformer.transformer import BaseBlock, Mixer if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.tensor_space import TensorSpace @@ -8,27 +8,30 @@ from fast_llm.layers.transformer.config import TransformerConfig -class LlambaBlock(BaseBlock): +class SSMBlock(BaseBlock): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ _name = "Llamba block" - _mixer_module_name = "mixer" def __init__( self, - config_transformer: "TransformerConfig", - config_ssm: "SSMConfig", + transformer_config: "TransformerConfig", + ssm_config: "SSMConfig", tensor_space: "TensorSpace", - mixer_cls, - layer_index: int, + mixer_cls: type[Mixer], + block_index: int, return_input: bool = False, ): - self.mixer_cls = mixer_cls - self._config_ssm = config_ssm - self._debug_mode = self._config_ssm.debug_ssm - super().__init__(config_transformer, tensor_space, layer_index, return_input) + self._ssm_config = ssm_config + self._mixer_cls = mixer_cls + super().__init__(transformer_config, tensor_space, block_index, return_input) - def _create_mixer(self): - self.mixer = self.mixer_cls(self._config_ssm, layer_idx=self._layer_index, tensor_space=self._tensor_space) + def _create_mixer(self) -> Mixer: + return self._mixer_cls( + self._ssm_config, + tensor_space=self._tensor_space, + block_index=self._block_index, + transformer_config=self._config, + ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a03509abb..ead32fa2a 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -7,6 +7,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames +from fast_llm.layers.transformer.transformer 
import Mixer from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -43,24 +45,36 @@ def bias_init_method(conv_weight): return init_uniform_(-bound, bound) -class Mamba2(torch.nn.Module): +class Mamba2(Mixer): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ + _mixer_name: typing.ClassVar[str] = "mamba_2" + + _XZ_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.inner_dim, + TransformerDimNames.sequence_q, + ) + _BC_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.c_heads, + SSMDimNames.state_dim, + TransformerDimNames.sequence_q, + ) + def __init__( self, config: SSMConfig, - layer_idx: int, tensor_space: TensorSpace, - return_input: bool = False, + block_index: int, + transformer_config: TransformerConfig, ): - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config bias: bool = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale: float | None = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale: float | tuple[float | None, ...] | None = get_lr_scale( self.config.mamba_lr_scale, layer_lr_scale ) @@ -236,6 +250,13 @@ def forward(self, hidden_states, kwargs): x = repeat_kv(x, self.repeat_group) x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + if self._debug_level: + self._debug_log(z, "z", self._XZ_DIMS, kwargs) + self._debug_log(x, "x", self._XZ_DIMS, kwargs) + self._debug_log(B, "b", self._BC_DIMS, kwargs) + self._debug_log(C, "c", self._BC_DIMS, kwargs) + self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) + y = selective_scan_fn( x, dt, @@ -249,6 +270,9 @@ def forward(self, hidden_states, kwargs): return_last_state=False, ) + if self._debug_level: + self._debug_log(y, "y", self._XZ_DIMS, kwargs) + if ssm_state is not None: y, last_state = y ssm_state.copy_(einops.rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 7c824d235..a95e94c03 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,4 +1,5 @@ import math +import typing from typing import Callable import einops @@ -7,6 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -44,12 +47,12 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_dtprojbias( - d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float, factory_kwargs: dict + d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float ) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - dt = torch.exp( - torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) - 
).clamp(min=dt_init_floor) + dt = torch.exp(torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( + min=dt_init_floor + ) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) tensor.copy_(inv_dt) @@ -58,20 +61,18 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return init_ -class MambaLayer(torch.nn.Module): +class MambaLayer(Mixer): + _mixer_name: typing.ClassVar[str] = "mamba" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): - factory_kwargs = {} - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config - self.layer_idx = layer_idx - - self._debug_mode = config.debug_ssm # Tensor dims: td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) @@ -88,7 +89,7 @@ def __init__( self.d_state = td_state.size self.d_model = td_model.size self.dt_rank = tdt_rank.size - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) self.in_proj_weight = ParameterMeta.from_dims( @@ -113,7 +114,6 @@ def __init__( weight_init_method=kaiming_init_(td_inner.size), bias=False, lr_scale=mamba_layer_lr_scale, - **factory_kwargs, ) self.x_proj.weight.auto_grad_accumulation = True @@ -127,7 +127,7 @@ def __init__( self.dt_proj_bias = ParameterMeta.from_dims( (td_inner,), init_method=init_dtprojbias( - self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs + self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor ), lr_scale=mamba_layer_lr_scale, ) @@ -153,10 +153,8 @@ def __init__( bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. 
weight_init_method=kaiming_init_(td_model.size), lr_scale=mamba_layer_lr_scale, - **factory_kwargs, ) self.out_proj.weight.auto_grad_accumulation = True - self._return_input = return_input def forward(self, hidden_states, kwargs): assert _mamba_available @@ -168,8 +166,6 @@ def forward(self, hidden_states, kwargs): "d (b l) -> b d l", l=seqlen, ) - if self._debug_mode: - print("XZ: ", xz.shape) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat @@ -189,6 +185,4 @@ def forward(self, hidden_states, kwargs): delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) - if self._return_input: - out = torch.stack((hidden_states, out), dim=0) return out, None diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3351c9906..174e19588 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -13,9 +13,9 @@ TransformerKwargs, TransformerSubLayerName, ) -from fast_llm.logging import log_distributed_grad, log_distributed_tensor -from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import init_normal_, init_zeros_ +from fast_llm.utils import get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -50,11 +50,13 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(torch.nn.Module): +class Attention(Mixer): """ A self-attention layer. """ + _mixer_name: typing.ClassVar[str] = "attn" + _QUERY_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, @@ -64,7 +66,7 @@ class Attention(torch.nn.Module): _KV_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, - TransformerDimNames.group_heads, + TransformerDimNames.head_groups, TransformerDimNames.kv_channels, ) _CONTEXT_DIMS = ( @@ -73,19 +75,9 @@ class Attention(torch.nn.Module): TransformerDimNames.composite_dense, ) - def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index, - ): - super().__init__() + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): + super().__init__(tensor_space, block_index, config.debug_transformer) self._config = config - self._tensor_space = tensor_space - # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) - self._layer_index = layer_index - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._debug_transformer = self._config.debug_transformer self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -108,7 +100,7 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
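The hunks above move `Attention` onto the shared `Mixer` interface that this patch adds in `fast_llm/layers/transformer/transformer.py` (see below): a mixer is built from the tensor space, its block index and a debug level, and its `forward` returns the output hidden states plus an optional bias whose addition can be fused in `_bias_dropout_add`. A minimal sketch of that contract, using a hypothetical pass-through mixer that is not part of the patch:

    import typing

    import torch

    from fast_llm.layers.transformer.transformer import Mixer


    class IdentityMixer(Mixer):
        # Hypothetical example: returns the hidden states unchanged.
        _mixer_name: typing.ClassVar[str] = "identity"

        def forward(
            self, input_: torch.Tensor, kwargs: dict[str, typing.Any]
        ) -> tuple[torch.Tensor, torch.Tensor | None]:
            # No bias to fuse into the residual path, so return None.
            return input_, None

Anything implementing this signature can be returned from a block's `_create_mixer`.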
@@ -178,10 +170,10 @@ def _attn_fused( query, key, beta=0, - alpha=self._softmax_scale / self._layer_index, + alpha=self._softmax_scale / self._block_index, ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) - attn_weights = attn_weights.to(torch.float32) * self._layer_index + attn_weights = attn_weights.to(torch.float32) * self._block_index attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) @@ -200,40 +192,6 @@ def _attn_fused( .flatten(2) ) - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[TransformerKwargs.hidden_dims]} - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) - for dim_name in dim_names - ), - tensor_name=f"transformer layer {self._layer_index} attn {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_transformer, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - def _query_key_value_forward( self, input_: torch.Tensor, sequence_first: bool ) -> tuple[torch.Tensor, torch.Tensor, dict[str, typing.Any]]: @@ -300,7 +258,7 @@ def _decide_window_size(self) -> int | None: # https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71 # TODO: make universal per layer config window_size = self._config.window_size - if self._config.max_window_layers is not None and self._layer_index < self._config.max_window_layers: + if self._config.max_window_layers is not None and self._block_index < self._config.max_window_layers: window_size = None return window_size @@ -341,7 +299,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) self._debug_log( key, @@ -395,7 +353,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ kwargs[TransformerKwargs.attention_mask_value], ) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query", self._QUERY_DIMS, kwargs) self._debug_log( key, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index a46af1387..73f83ccf5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -40,11 +40,11 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", 
block_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory @@ -59,7 +59,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index b01eb2aa5..efe0c5cc5 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -14,10 +14,10 @@ class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): super().__init__() self._name = name - self._layer_index = layer_index + self._block_index = block_index init_method_1 = init_normal_( std=config.init_method_std_mlp_1, @@ -39,7 +39,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale lr_scale = get_lr_scale(lr_scale, layer_lr_scale) @@ -69,9 +69,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) def forward( self, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 147452073..d08db9a94 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,25 +8,85 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import 
TensorMeta +from fast_llm.utils import Assert logger = logging.getLogger(__name__) +class Mixer(torch.nn.Module, abc.ABC): + """ + Base class for mixer modules. + """ + + _mixer_name: typing.ClassVar[str] + + def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): + super().__init__() + self._tensor_space = tensor_space + self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel + self._block_index = block_index + self._debug_level = debug_level + + @abc.abstractmethod + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Mixer module forward. Returns the output hidden states and an optional bias, + in case its addition can be made more efficient in `_bias_dropout_add`. + """ + + def _get_meta( + self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim + for dim in kwargs[TransformerKwargs.hidden_dims] + (kwargs[TransformerKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + for dim_name in dim_names + ), + tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", + dtype=input_.dtype, + ) + + def _debug_log( + self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> None: + # TODO: Local vs global + Assert.gt(self._debug_level, 0) + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + + class BaseBlock(Layer, abc.ABC): """ A transformer-like decoder base block with abstract mixer. """ - _mixer_module_name = "self_attn" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "mixer" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): super().__init__() self._config: TransformerConfig = config @@ -35,18 +95,19 @@ def __init__( # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input - self._layer_index = layer_index + self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) # Note, layer_lr_scale does not impact the norms - # TODO: add a seperate norm_lr_scale + # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) - self._create_mixer() + # The mixer needs to be created here for backward-compatible weight ordering. + setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index + self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index ) # PEFT. 
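A note on the wiring above: `BaseBlock` now instantiates its mixer through `_create_mixer` and stores it under `_mixer_module_name`, creating it before the MLP for backward-compatible weight ordering; `TransformerBlock` keeps the historical `self_attn` attribute name while `SSMBlock` relies on the new `mixer` default. A hypothetical block subclass (illustrative only, not part of the patch) would plug in the same way:

    from fast_llm.layers.transformer.attention import Attention
    from fast_llm.layers.transformer.transformer import BaseBlock, Mixer


    class MyBlock(BaseBlock):
        _name = "My block"
        # The attribute name determines the module path of the mixer's parameters.
        _mixer_module_name = "mixer"

        def _create_mixer(self) -> Mixer:
            # Any Mixer subclass works; BaseBlock assigns it to `self.mixer`.
            return Attention(self._config, self._tensor_space, self._block_index)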
@@ -54,7 +115,7 @@ def __init__( self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def _create_mixer(self): + def _create_mixer(self) -> Mixer: pass @torch.compile @@ -67,7 +128,7 @@ def _bias_dropout_add( @property def name(self) -> str: - return f"{self._name} {self._layer_index}" + return f"{self._name} {self._block_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[TransformerKwargs.hidden_dims] @@ -137,14 +198,17 @@ def forward( return hidden_states -class TransformerLayer(BaseBlock): +class TransformerBlock(BaseBlock): _name = "Transformer layer" - _mixer_module_name = "self_attn" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): - super().__init__(config, tensor_space, layer_index, return_input) + super().__init__(config, tensor_space, block_index, return_input) + + def _create_mixer(self) -> Mixer: + from fast_llm.layers.transformer.attention import Attention - def _create_mixer(self): - self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + return Attention(self._config, self._tensor_space, self._block_index) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index e8334de6e..6d555a0bb 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -138,6 +138,8 @@ def log_tensor[ if level < 1: return tensor = tensor.detach() + if tensor.ndim == 0: + tensor = tensor[None] save_stats = TensorLogs.config.save shape = tuple(tensor.shape) _, dtype = str(tensor.dtype).split("torch.") diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 444ad72b2..4c1eab46f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -21,7 +21,7 @@ TransformerLossNames, ) from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -68,11 +68,11 @@ def get_output_layers(self) -> list[Layer]: for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, # TODO MTP: which index? - layer_index=max(self._config.transformer.num_layers + i, 1), + block_index=max(self._config.transformer.num_layers + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -91,10 +91,10 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. 
return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, @@ -336,7 +336,7 @@ def embedding(self) -> LanguageModelEmbedding: return self.layers[0] @property - def transformer_layers(self) -> list[TransformerLayer]: + def transformer_layers(self) -> list[TransformerBlock]: return self.layers[1:-1] @property diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index cc83f11be..9ca0123b2 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -9,9 +9,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -24,14 +23,14 @@ @config_class() -class HybridSSMBaseModelConfig(LanguageModelBaseConfig): +class HybridSSMBaseModelConfig(GPTBaseModelConfig): _abstract = False ssm: SSMConfig = Field( desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - hybrid_block_layout: list[str] | None = Field( + hybrid_block_layout: list[SSMBlockType] | None = Field( default=None, desc=f"Pattern of blocks to use in the model. Available types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, @@ -41,9 +40,8 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): desc="Multi-token prediction mixer to use in the model. If None, will use the last block type in `hybrid_block_layout`.", hint=FieldHint.optional, ) - use_megatron_initialization: bool = Field( - default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing - ) # TODO: is this needed? + # TODO: Support combination of different SSM block types. 
+ ssm_block_type: SSMBlockType | None = Field(init=False) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: """ @@ -83,6 +81,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, self.ssm.d_xb)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.c_heads, d_inner // self.ssm.state_size)) def _validate(self): with self._set_implicit_default(None): @@ -96,30 +95,21 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers if len(self.hybrid_block_layout) != self.transformer.num_layers: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: - raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - ) - num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) - logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" - ) + raise ValueError(message) + num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) + logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) - Assert.custom( - lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", - ) - Assert.custom( - lambda _: self.default_mtp_type in SSMBlockType.__members__.values() or self.default_mtp_type is None, - f"Invalid MTP type: {self.default_mtp_type}. Must be one of {SSMBlockType.__members__.values()} or None", - ) - super()._validate() + ssm_block_types = set(self.hybrid_block_layout) - {SSMBlockType.transformer} + # TODO: Support combination of different SSM block types. 
+ Assert.leq(len(ssm_block_types), 1) + self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 02a5ac239..89f0cd4aa 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -5,11 +5,8 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.ssm.mamba2 import Mamba2 -from fast_llm.layers.ssm.mamba_layer import MambaLayer -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -31,7 +28,6 @@ def __init__( config: HybridSSMBaseModelConfig, distributed_config: DistributedConfig, ): - self.SSM_BLOCK_CLS = LlambaBlock # TODO: extend to other block types if needed super().__init__(config, distributed_config) def get_output_layers(self) -> list[Layer]: @@ -39,52 +35,31 @@ def get_output_layers(self) -> list[Layer]: Get the output layers of the model. This includes the language model head and any additional heads specified in the configuration. """ - layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] + layers: list[Layer] = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] if self._config.prediction_heads > 1: block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] for i in range(1, self._config.prediction_heads): if block_type == SSMBlockType.transformer: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=len(self._config.hybrid_block_layout), + block_index=len(self._config.hybrid_block_layout), return_input=i != self._config.prediction_heads - 1, ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=Mamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. 
Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + block_index=len(self._config.hybrid_block_layout), + tensor_space=self._tensor_space, + return_input=i != self._config.prediction_heads - 1, + ) + ) layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) return layers @@ -94,63 +69,35 @@ def get_layers(self) -> list[Layer]: Create a list of layers for the model, interleaving Transformer and Mamba blocks according to the block pattern. """ - layers = [LanguageModelEmbedding(self._config, self._tensor_space)] + layers: list[Layer] = [LanguageModelEmbedding(self._config, self._tensor_space)] # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): if block_type == SSMBlockType.transformer: # Transformer block layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 ), ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba: - # Create Mamba block - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=Mamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + block_index=i + 1, + tensor_space=self._tensor_space, + return_input=( + i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 + ), + ) + ) # Add the output layers layers += self.get_output_layers() diff --git a/setup.cfg b/setup.cfg index 843aa15ca..c086af7d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,14 +48,9 @@ HUGGINGFACE = # Required to run SSMs # To install on cpu environment (ex. 
for IDE support): -# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation +# MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 - cartesia_pytorch>=0.0.2 - -GENERATION = - lm_eval>=0.4.9 - DEV = # Pre-commit git hook diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 438782dfe..3e6c37632 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -193,6 +193,7 @@ def test_gpt_blended_mixed(): def test_gpt_blended_mixed_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index e951cc2b1..4f36cdf89 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -39,6 +39,7 @@ def test_gpt_concatenate(): def test_gpt_concatenate_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 7472f1958..004b96289 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -58,6 +58,7 @@ def test_gpt_fim(): def test_gpt_fim_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { @@ -81,6 +82,7 @@ def test_gpt_fim_data(): def test_gpt_fim_data_legacy(): + get_test_dataset() get_test_data_and_compare_samples( { "format": "list", diff --git a/tests/test_attention.py b/tests/test_attention.py index 87b0d3e59..dd36b840a 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -17,12 +17,12 @@ def test_decide_window_size(): # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 2 + attention._block_index = 2 assert attention._decide_window_size() == 512 # Arrange - Case 2: window_size is None (layer_index < max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 1 + attention._block_index = 1 assert attention._decide_window_size() is None # Arrange - Case 3: max_window_layers is None (always return window_size) diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index c530a170c..2f125717e 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,9 +3,10 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.utils import Assert +from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -23,6 +24,7 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): + get_model_test_dataset() args = model_testing_config.config_args + ["run.tensor_logs.save=False"] model_ref = _get_trainer_from_args(args, 
model_testing_config.model_type)._multi_stage model_frozen = _get_trainer_from_args( @@ -39,7 +41,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerLayer, LlambaBlock)) else 0 + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, SSMBlock)) else 0 for layer in model_ref.base_model.layers ] for weight_buffer_ref, weight_buffer_frozen in zip( diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 1eee3675d..42252c620 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -523,6 +523,7 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", + f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", ], megatron_args=None, checkpoint_format=None, From 7b32699be7c1a1fb29cc7386eb33280b0bc19a5c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 17:28:56 -0400 Subject: [PATCH 11/82] stuff --- fast_llm/layers/ssm/mamba2.py | 57 ++++++++++++++--------------------- fast_llm/models/ssm/config.py | 2 +- tests/utils/model_configs.py | 2 +- 3 files changed, 24 insertions(+), 37 deletions(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index ead32fa2a..b936ccf14 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -7,6 +7,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ @@ -97,9 +98,9 @@ def __init__( if self.repeat_kv_before_conv: self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("1", 1), td_conv_kernel), + (td_inner, td_conv_kernel), init_method=init_uniform_( - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), + -1 / math.sqrt(td_inner.size * td_conv_kernel.size), 1 / math.sqrt(td_inner.size * td_conv_kernel.size), ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 lr_scale=mamba_layer_lr_scale, @@ -110,9 +111,9 @@ def __init__( ) else: self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, TensorDim("1", 1), td_conv_kernel), + (td_xb, td_conv_kernel), init_method=init_uniform_( - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), + -1 / math.sqrt(td_xb.size * td_conv_kernel.size), 1 / math.sqrt(td_xb.size * td_conv_kernel.size), ), ) @@ -133,7 +134,13 @@ def __init__( weight_init_method=kaiming_init_(td_model.size), lr_scale=mamba_layer_lr_scale, ) - + self.dt_in_proj = Linear( + td_model, + tdt_rank, + bias=config.add_bias_linear, + weight_init_method=kaiming_init_(transformer_config.hidden_size), + lr_scale=mamba_layer_lr_scale, + ) # Initialize special dt projection to preserve variance at initialization dt_scale = config.dt_scale # 1.0 dt_init_std = self.dt_rank**-0.5 * dt_scale @@ -144,24 +151,6 @@ def __init__( else: raise NotImplementedError - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt_max = config.dt_max # or 0.1 - dt_min = config.dt_min 
# or 0.001 - dt_init_floor = config.dt_init_floor # or 1e-4 - dt = torch.exp(torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( - min=dt_init_floor - ) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - - def init_from_tensor_( - value: torch.Tensor, - ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.copy_(value) - - return init_ - self.dt_proj = Linear( tdt_rank, td_inner, @@ -171,18 +160,16 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) ) # define bias outside the linear layer since its also used in the selective_scan_fn self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), init_method=init_from_tensor_(inv_dt), lr_scale=mamba_layer_lr_scale + (td_inner,), + init_method=init_dtprojbias( + self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor + ), + lr_scale=mamba_layer_lr_scale, ) - A = einops.repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A).flatten() # Keep A_log in fp32 self.A_log = ParameterMeta.from_dims( (td_inner, td_state), - init_method=init_from_tensor_(A_log), + init_method=init_A(self.config.state_size, self.config.d_inner), lr_scale=mamba_layer_lr_scale, weight_decay=False, ) @@ -214,8 +201,8 @@ def forward(self, hidden_states, kwargs): A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split(zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1) x = einops.rearrange(x, "b l d -> b d l") z = einops.rearrange(z, "b l d -> b d l") @@ -225,7 +212,7 @@ def forward(self, hidden_states, kwargs): B = einops.rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = einops.rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) + self.dt_proj_bias # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) + self.dt_proj_bias # B, L, d_inner dt = einops.rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: @@ -238,7 +225,7 @@ def forward(self, hidden_states, kwargs): if _causal_conv1d_available: x = _causal_conv1d_fn( x=x, - weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + weight=self.conv1d_weight, bias=self.conv1d_bias, activation=self.activation, ) # B, L, D diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 9ca0123b2..b04b1f210 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -78,7 +78,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_discrete_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) elif SSMBlockType.mamba2.value in self.hybrid_block_layout: - inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank + inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner # + self.ssm.dt_rank tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, 
self.ssm.d_xb)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.c_heads, d_inner // self.ssm.state_size)) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 42252c620..4976ad2b1 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -523,7 +523,7 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", - f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", + # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", ], megatron_args=None, checkpoint_format=None, From 1feccc866c1dea2da66567476fc911a37a855038 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 17:48:23 -0400 Subject: [PATCH 12/82] stuff --- fast_llm/layers/ssm/mamba2.py | 2 +- fast_llm/layers/ssm/mamba_layer.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 88fe4abc0..fdba10beb 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -111,7 +111,7 @@ def __init__( sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - # define bias outside the linear layer since its also used in the selective_scan_fn + # define bias outside the linear layer since it's also used in the selective_scan_fn self.dt_proj_bias = ParameterMeta.from_dims( (inner_dim,), init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 49afa910e..11db37910 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -48,9 +48,7 @@ def init_dtprojbias( dt_max: float, dt_min: float, dt_init_floor: float ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - tensor = ( - tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min(dt_init_floor) - ) + tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min_(dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 return tensor.add_(torch.log(-torch.expm1(-tensor))) From e528b50ba5c5e2ea726876779db010f83fccd8ef Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:00:20 -0400 Subject: [PATCH 13/82] misc --- fast_llm/layers/ssm/discrete_mamba2.py | 4 ++-- fast_llm/layers/ssm/mamba2.py | 12 ++++++++---- fast_llm/layers/ssm/mamba_layer.py | 10 +++++++--- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index b95ff76da..fdce9bf63 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -5,7 +5,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs @@ -97,7 +97,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, TensorDim("1", 1), td_conv_kernel), + (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), 
init_method=init_uniform_centered_((td_conv.size * td_conv_kernel.size) ** -0.5), lr_scale=mamba_layer_lr_scale, ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index fdba10beb..8be9dcb9b 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,7 +3,7 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames @@ -75,7 +75,11 @@ def __init__( conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim self.conv1d_weight = ParameterMeta.from_dims( - (conv1d_dim, tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel)), + ( + conv1d_dim, + tensor_space.get_tensor_dim(DefaultDimNames.scalar), + tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel), + ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, ) @@ -168,9 +172,9 @@ def forward(self, hidden_states, kwargs): .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias.squeeze(1), activation="silu") else: - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias.squeeze(1), activation="silu") x = ( x.unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 11db37910..07eec38e6 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -4,7 +4,7 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames @@ -87,7 +87,11 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.conv_kernel)), + ( + inner_dim, + tensor_space.get_tensor_dim(DefaultDimNames.scalar), + tensor_space.get_tensor_dim(SSMDimNames.conv_kernel), + ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, ) @@ -146,7 +150,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s out = _mamba_inner_fn( in_proj, - self.conv1d_weight.unsqueeze(1), + self.conv1d_weight, None, self.x_proj.weight, self.dt_proj_weight, From b49c42febac4f32dc1be83655b242d6199a385bc Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:16:42 -0400 Subject: [PATCH 14/82] misc --- fast_llm/layers/ssm/discrete_mamba2.py | 4 ++-- fast_llm/layers/ssm/mamba2.py | 8 ++++---- fast_llm/layers/ssm/mamba_layer.py | 4 ++-- 
.../modeling_ssm_hybrid_apriel15b.py | 20 +++++++++++++------ tests/utils/model_configs.py | 1 - 5 files changed, 22 insertions(+), 15 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 734e35b21..c0ae7e781 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -5,7 +5,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs @@ -103,7 +103,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, TensorDim("1", 1), td_conv_kernel), + (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b936ccf14..74c212add 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -4,7 +4,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias @@ -98,7 +98,7 @@ def __init__( if self.repeat_kv_before_conv: self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, td_conv_kernel), + (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_inner.size * td_conv_kernel.size), 1 / math.sqrt(td_inner.size * td_conv_kernel.size), @@ -111,7 +111,7 @@ def __init__( ) else: self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, td_conv_kernel), + (td_xb, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_xb.size * td_conv_kernel.size), 1 / math.sqrt(td_xb.size * td_conv_kernel.size), @@ -225,7 +225,7 @@ def forward(self, hidden_states, kwargs): if _causal_conv1d_available: x = _causal_conv1d_fn( x=x, - weight=self.conv1d_weight, + weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), bias=self.conv1d_bias, activation=self.activation, ) # B, L, D diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index a95e94c03..4493332ce 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -5,7 +5,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig @@ -98,7 +98,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), + (td_inner, 
tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=kaiming_init_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index f8f6a0520..4fde72458 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -843,9 +843,8 @@ def __init__( self.num_C_head = self.d_inner // self.d_state self.repeat_group = self.num_C_head // self.num_xb_head - self.in_proj = nn.Linear( - self.d_model, 2 * self.d_xb + 2 * self.d_inner + self.dt_rank, bias=bias, **factory_kwargs - ) + self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs) + self.dt_in_proj = nn.Linear(self.d_model, self.dt_rank, bias=bias, **factory_kwargs) self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization @@ -933,8 +932,17 @@ def forward( outputs = {} A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split( + zxbc, + [ + self.d_inner, + self.d_xb, + self.d_xb, + self.d_inner, + ], + dim=-1, + ) x = rearrange(x, "b l d -> b d l") z = rearrange(z, "b l d -> b d l") @@ -944,7 +952,7 @@ def forward( B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) # B, L, d_inner dt = rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 4976ad2b1..1eee3675d 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -523,7 +523,6 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", - # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", ], megatron_args=None, checkpoint_format=None, From c1b7f44a10ff379a067b10b76df296f3bee4cac1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:19:08 -0400 Subject: [PATCH 15/82] misc --- .../models/ssm/external/llamba/modeling_mtp_llamba.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py index 8f49ded40..6d9746db1 100644 --- a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py +++ b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py @@ -322,21 +322,19 @@ def __init__(self, config, factory_kwargs, layer_idx, **kwargs): # Mixer self.mixer = DiscreteMamba2( - d_model=self.config._hidden_size, + d_model=self.config.d_model, layer_idx=layer_idx, **config.ssm_cfg, **factory_kwargs, ) # Other components - self.input_layernorm = LlamaRMSNorm( - hidden_size=self.config._hidden_size, eps=1e-5, factory_kwargs=factory_kwargs - ) + self.input_layernorm = LlamaRMSNorm(hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs) 
self.post_attention_layernorm = LlamaRMSNorm( - hidden_size=self.config._hidden_size, eps=1e-5, factory_kwargs=factory_kwargs + hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs ) self.mlp = LlamaMLP( - hidden_size=self.config._hidden_size, + hidden_size=self.config.d_model, **config.mlp_cfg, factory_kwargs=factory_kwargs, ) From 31f5d415ef0c7eeca54a26d415076cbf3ba33cfd Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:20:26 -0400 Subject: [PATCH 16/82] misc --- fast_llm/models/ssm/conversion.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index d57300252..43e3c67e5 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,6 +3,7 @@ import pathlib import typing +from fast_llm.config import MISSING from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, @@ -19,7 +20,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig -from fast_llm.layers.ssm.config import SSMBlockType +from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, @@ -42,11 +43,11 @@ class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - block_converter = RenameParamConverter( + block_converter = MappedConfigParamConverter( fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),), - ignore_missing=True, - default_value=[cls._default_block_type], + fast_llm_value=lambda x: [cls._default_block_type] if x == MISSING else x, + export_value=lambda x: [x_.value for x_ in x], ) return super()._create_config_converters() + [block_converter] @@ -202,7 +203,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ignore_missing=True, default_value=4, ), - RenameParamConverter( + MappedConfigParamConverter( fast_llm_names=(("ssm", "dt_init"),), export_names=( ( @@ -210,8 +211,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: "dt_init", ), ), - ignore_missing=True, - default_value="random", + fast_llm_value=lambda x: DTInitType.random if x == MISSING else DTInitType(x), + export_value=lambda x: x.value, ), ] @@ -258,6 +259,9 @@ def _create_weight_converters(self) -> list[WeightConverter]: ) # ================================================ # Mamba2 specific parameters + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.dt_in_proj", f"model.layers.{i}.mixer.dt_in_proj", ssm_bias + ) converters += self._get_weight_and_bias_converters( f"layers.{i+1}.mixer.dt_proj", f"model.layers.{i}.mixer.dt_proj", False ) From 0a9ff25f6e0a699caef881dfcaeef0b19f825764 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:22:24 -0400 Subject: [PATCH 17/82] misc --- fast_llm/models/ssm/config.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 6b9e28584..d2a69303c 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -40,9 +40,6 @@ class 
HybridSSMBaseModelConfig(GPTBaseModelConfig): desc="Multi-token prediction mixer to use in the model. If None, will use the last block type in `hybrid_block_layout`.", hint=FieldHint.optional, ) - use_megatron_initialization: bool = Field( - default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing - ) # TODO: is this needed? # TODO: Support combination of different SSM block types. ssm_block_type: SSMBlockType | None = Field(init=False) From e7d9636819ab83df7204cc2b021fd4565188e946 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 19:55:53 -0400 Subject: [PATCH 18/82] Parallel discrete mamba 2 --- fast_llm/layers/ssm/config.py | 12 +- fast_llm/layers/ssm/discrete_mamba2.py | 212 ++++++++++--------------- fast_llm/layers/ssm/mamba2.py | 6 +- 3 files changed, 95 insertions(+), 135 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 15a6a8210..7f0b3cf61 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -211,23 +211,25 @@ def _validate(self) -> None: def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - num_heads = div(self.d_inner, self.state_size) # Head groups are configured differently depending on the block type. if block_type == SSMBlockType.mamba: + num_heads = div(self.d_inner, self.state_size) num_head_groups = num_heads elif block_type == SSMBlockType.mamba2: + num_heads = div(self.d_inner, self.state_size) num_head_groups = div(self.d_xb, self.state_size) elif block_type == SSMBlockType.mamba2_discrete: - Assert.eq(num_heads, self.n_v_heads) + # TODO: Use different variables? + num_heads = self.n_v_heads num_head_groups = self.n_qk_heads + # v_heads have size `headdim` that may be different from `state_size`. 
+ Assert.multiple(self.d_inner, num_heads) else: raise NotImplementedError(block_type) tensor_space.add_tensor_dim(state_dim := TensorDim(SSMDimNames.state, self.state_size)) tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) - tensor_space.add_tensor_dim( - group_heads := TensorDim(SSMDimNames.group_heads, num_group_heads := div(num_heads, num_head_groups)) - ) + tensor_space.add_tensor_dim(group_heads := TensorDim(SSMDimNames.group_heads, div(num_heads, num_head_groups))) tensor_space.add_tensor_dim( heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index fdce9bf63..ac4fb87cc 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,12 +1,12 @@ import logging -import math import typing import einops import torch from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace -from fast_llm.layers.common.linear import Linear +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer @@ -32,12 +32,6 @@ _causal_conv1d_available = False -def bias_init_method(conv_weight): - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_centered_(bound) - - class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" @@ -51,198 +45,162 @@ def __init__( transformer_config: TransformerConfig, ): super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) - self.config: SSMConfig = config + self._config: SSMConfig = config layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") - - td_inner = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state) - td_model = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.head_groups) - td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection) - - self.d_model = td_model.size - self.d_inner = td_inner.size - self.d_state = td_state.size - self.chunk_size = config.chunk_size - self.n_qk_heads = td_n_qk_heads.size - self.n_v_heads = td_n_v_heads.size - self.conv_kernel_size = td_conv_kernel.size - - self.act = config.activation_type.activation_fn - self.activation_name = config.activation_type.name + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + + inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + hidden_dim = 
tensor_space.get_tensor_dim(TransformerDimNames.hidden) + conv1d_dim = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) + heads_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) + + self._local_heads = heads_dim.size + self._local_head_groups = tensor_space.get_tensor_dim(SSMDimNames.head_groups).size + self._local_inner_size = inner_dim.size + self._local_bc_size = tensor_space.get_tensor_dim(SSMDimNames.composite_head_groups_and_state).size # TODO: double check initializations # Projections - self.in_proj = Linear( - td_model, - td_inner_proj, + self.in_proj = OutputParallelLinear( + hidden_dim, + tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), bias=config.add_bias_linear, - weight_init_method=init_kaiming_(td_model.size), - lr_scale=mamba_layer_lr_scale, + weight_init_method=init_kaiming_(transformer_config.hidden_size), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - self.z_bias = ( - ParameterMeta.from_dims( - (td_inner,), + if not config.add_bias_linear: + self.z_bias = ParameterMeta.from_dims( + (inner_dim,), weight_decay=False, init_method=init_zeros_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - if not config.add_bias_linear - else 0.0 - ) - self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), - init_method=init_uniform_centered_((td_conv.size * td_conv_kernel.size) ** -0.5), - lr_scale=mamba_layer_lr_scale, + ( + conv1d_dim, + tensor_space.get_tensor_dim(DefaultDimNames.scalar), + tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel), + ), + init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), + lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale + (conv1d_dim,), + init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + lr_scale=lr_scale, ) - # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_n_v_heads,), + (heads_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - - # out_proj - self.out_proj = Linear( - td_inner, - td_model, + self.out_proj = InputParallelLinear( + inner_dim, + hidden_dim, bias=config.add_bias_linear, - weight_init_method=init_kaiming_(td_inner.size), - lr_scale=mamba_layer_lr_scale, + weight_init_method=init_kaiming_(self._config.d_inner), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - ON variable names and pep8: keeping some variable names as in the original code for clarity. - - Args: - u: (B, L, D), - - Returns: - outputs: dict. - outputs["hidden_states"]: (B, L, D). - outputs["state"]: inference cache. 
- """ if kwargs[TransformerKwargs.sequence_first]: raise NotImplementedError(f"Sequence-first not supported for SSMs.") assert _mamba_available - outputs = {} - # assert state is None - batch, seqlen, dim = input_.shape - - state = None - # Hacky way to initialize state during inference - chunk_size = self.chunk_size if state is None else seqlen + sequence_length = input_.size(0 if kwargs[TransformerKwargs.sequence_first] else 1) # Pad input to nearest multiple of chunklen - padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = torch.nn.functional.pad(input_, (0, 0, 0, padded_len - seqlen)) + padded_length = (1 + (sequence_length - 1) // self._config.chunk_size) * self._config.chunk_size + if padded_length != sequence_length: + assert not kwargs[TransformerKwargs.sequence_first] and not self._sequence_parallel + input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) - # Project input - xBCzA_log = self.in_proj(u) + inner_projection = self.in_proj(input_) + # Standardize to (batch, sequence, inner_projection) + if kwargs[TransformerKwargs.sequence_first]: + inner_projection = inner_projection.transpose(0, 1) - ( - xBC, - z, - A_log, - ) = torch.split( - xBCzA_log, + xBC, z, A_log = torch.split( + inner_projection, [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, + self._local_inner_size + 2 * self._local_bc_size, + self._local_inner_size, + self._local_heads, ], dim=-1, ) - if state is not None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead torch.nn.functional.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - xBC_t = einops.rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - state["conv"].copy_( - torch.nn.functional.pad(xBC_t, (self.conv_kernel_size - xBC_t.shape[-1], 0)) - ) # Update state (B D W) - # Convolutional layer - xBC = self.convolutional_forward(xBC, padded_len) + xBC = self.convolutional_forward(xBC, sequence_length) x, B, C = torch.split( xBC, [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, + self._local_inner_size, + self._local_bc_size, + self._local_bc_size, ], dim=-1, ) - x = einops.rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) - B = einops.rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) - C = einops.rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + x = einops.rearrange(x, "b l (h n) -> b l h n", h=self._local_heads) + B = einops.rearrange(B, "b l (h n) -> b l h n", h=self._local_head_groups) + C = einops.rearrange(C, "b l (h n) -> b l h n", h=self._local_head_groups) # SSM forward - result = _mamba_chunk_scan_combined( + y = _mamba_chunk_scan_combined( x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1), dt=A_log, dt_softplus=True, - A=-torch.ones(self.n_v_heads, device=A_log.device), + A=-torch.ones(self._local_heads, device=A_log.device), B=B, C=C, - chunk_size=chunk_size, - # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(state is not None), + chunk_size=self._config.chunk_size, + return_final_states=False, ) - - if state is not None: - y, ssm_state = result - state["ssm"].copy_(ssm_state) - else: - y = result - Du = torch.einsum("h,blhp->blhp", self.D, x) - y = einops.rearrange(y + Du, "b l h p -> b l (h p)") # Norm and gate - out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, 
:].contiguous() + if not self._config.add_bias_linear: + z = z + self.z_bias - # TODO: since we do not support inference for now, we only return the hidden states for now. - return outputs["hidden_states"], None + # y: (batch, sequence, heads, state) -> (batch, sequence, heads * state) + y = ((y + Du).flatten(2, 3) * torch.nn.functional.silu(z))[:, :sequence_length] + if kwargs[TransformerKwargs.sequence_first]: + # TODO: Is contiguous needed? + y = y.transpose(0, 1).contiguous() + return self.out_proj(y) def convolutional_forward(self, xBC, padded_len): """Convolutional layer forward pass for the full sequence.""" - if _causal_conv1d_available and self.activation_name in ( - "silu", + if _causal_conv1d_available and self._config.activation_type in ( + ActivationType.silu, "swish", - "identity", + ActivationType.identity, ): xBC = _causal_conv1d_fn( xBC.transpose(1, 2), einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), self.conv1d_bias, - activation=None if self.activation_name == "identity" else self.activation_name, + activation=( + None + if self._config.activation_type == ActivationType.identity + else self._config.activation_type.value + ), ).transpose(1, 2) else: - xBC = self.act( + xBC = self._config.activation_type.activation_fn( torch.nn.functional.conv1d( xBC.transpose(1, 2), self.conv1d_weight, bias=self.conv1d_bias, groups=self.conv1d_weight.shape[0], - padding=self.conv_kernel_size - 1, + padding=self._config.conv_kernel_dimension - 1, )[..., :padded_len].transpose(1, 2) ) return xBC diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 8be9dcb9b..cba28f8b8 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -142,12 +142,12 @@ def __init__( # TODO: lr_scale? ) - def forward(self, hidden_states, kwargs): + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available assert _causal_conv1d_available - inner_projection = self.in_proj(hidden_states) - dt = self.dt_proj(self.dt_in_proj(hidden_states)) + self.dt_proj_bias + inner_projection = self.in_proj(input_) + dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias # Standardize to (batch, sequence, inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) From c14b7643ae3f840f8da23404922f9482ff507284 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 25 Jul 2025 17:14:17 -0400 Subject: [PATCH 19/82] Mamba 2, misc --- fast_llm/engine/multi_stage/stage_base.py | 5 +- fast_llm/layers/ssm/config.py | 62 ++++++++++--------- fast_llm/layers/ssm/discrete_mamba2.py | 50 ++++++++++----- fast_llm/layers/ssm/mamba2.py | 22 ++++--- fast_llm/layers/ssm/mamba_layer.py | 27 ++++----- fast_llm/tensor.py | 74 +++++++++++++++-------- tests/models/test_checkpoint.py | 11 +++- tests/utils/model_configs.py | 9 +-- 8 files changed, 160 insertions(+), 100 deletions(-) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 9a8ce2092..3218a1963 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -185,8 +185,9 @@ def initialize_weights(self) -> None: # Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device) global_shape = meta.global_shape - if self._distributed_config.reproducible_init and ( - global_shape.numel() != parameter.numel() or not self._mode.on_device + if meta.requires_global_initialization 
or ( + self._distributed_config.reproducible_init + and (global_shape.numel() != parameter.numel() or not self._mode.on_device) ): # Initialize all global weights on every gpu, then select the appropriate slice if applicable. global_param = parameter.new_empty(global_shape, device=self._distributed.device) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 7f0b3cf61..c06d85148 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -5,31 +5,31 @@ from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig +from fast_llm.tensor import Initializer from fast_llm.utils import Assert, div class SSMDimNames: # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. state = "ssm_state" # State dimension (N), aka head size / num channels - + head_dim = "ssm_head_dim" head_groups = "ssm_head_groups" group_heads = "ssm_group_heads" + convolution_kernel = "ssm_convolution_kernel" # Kernel dimension of the conv1d in mamba layers + + dt_rank = "ssm_dt_rank" + + # Composite dimensions composite_heads = "ssm_composite_heads" - composite_heads_and_state = "ssm_composite_heads_and_state" + composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim" composite_head_groups_and_state = "ssm_composite_head_groups_and_state" - # Inner projection total dimension. + # Concatenated dimensions + concatenated_convolution = "ssm_concatenated_convolution" + concatenated_x_projection = "ssm_x_concatenated_x_projection" concatenated_inner_projection = "ssm_concatenated_inner_projection" - # Convolution shape in discrete mamba 2. TODO: Remove (dim too complex) - conv_dim = "ssm_conv_dim" - - dt_rank = "ssm_dt_rank" - - x_proj_dim = "x_proj_dim" # X projection dimension - conv_kernel = "conv_kernel" # Kernel size of the conv1d in mamba layers - class SSMBlockType(enum.StrEnum): """ @@ -62,7 +62,7 @@ class DTInitType(enum.StrEnum): constant = "constant" random = "random" - def get_init_method(self, scale: float): + def get_init_method(self, scale: float) -> Initializer: from fast_llm.tensor import init_fill_, init_uniform_centered_ return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) @@ -222,56 +222,64 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType # TODO: Use different variables? num_heads = self.n_v_heads num_head_groups = self.n_qk_heads - # v_heads have size `headdim` that may be different from `state_size`. 
- Assert.multiple(self.d_inner, num_heads) else: raise NotImplementedError(block_type) - tensor_space.add_tensor_dim(state_dim := TensorDim(SSMDimNames.state, self.state_size)) + tensor_space.add_tensor_dim(state := TensorDim(SSMDimNames.state, self.state_size)) + if block_type == SSMBlockType.mamba2_discrete: + tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, div(self.d_inner, num_heads))) + else: + head_dim = state + tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) tensor_space.add_tensor_dim(group_heads := TensorDim(SSMDimNames.group_heads, div(num_heads, num_head_groups))) tensor_space.add_tensor_dim( heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - heads_and_state := CompositeTensorDim( - SSMDimNames.composite_heads_and_state, (head_groups, group_heads, state_dim) + heads_and_head_dim := CompositeTensorDim( + SSMDimNames.composite_heads_and_head_dim, (head_groups, group_heads, head_dim) ) ) tensor_space.add_tensor_dim( head_groups_and_state := CompositeTensorDim( - SSMDimNames.composite_head_groups_and_state, (head_groups, state_dim) + SSMDimNames.composite_head_groups_and_state, (head_groups, state) ) ) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel, self.conv_kernel_dimension)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.convolution_kernel, self.conv_kernel_dimension)) # DT projection if block_type in (SSMBlockType.mamba, SSMBlockType.mamba2): - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, self.dt_rank)) + tensor_space.add_tensor_dim(dt_rank := TensorDim(SSMDimNames.dt_rank, self.dt_rank)) if block_type == SSMBlockType.mamba: - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.dt_rank + self.state_size * 2)) + tensor_space.add_tensor_dim( + ConcatenatedTensorDim(SSMDimNames.concatenated_x_projection, (dt_rank, state, state)) + ) # TODO: Use composition instead tensor_space.add_tensor_dim( - ConcatenatedTensorDim(SSMDimNames.concatenated_inner_projection, (heads_and_state, heads_and_state)) + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, (heads_and_head_dim, heads_and_head_dim) + ) ) elif block_type == SSMBlockType.mamba2: # TODO: Factor out state? tensor_space.add_tensor_dim( ConcatenatedTensorDim( SSMDimNames.concatenated_inner_projection, - (heads_and_state, head_groups_and_state, head_groups_and_state, heads_and_state), + (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim), ) ) elif block_type == SSMBlockType.mamba2_discrete: - # TODO: Factor as (head_groups, (group_heads + 2) * state_size + group_heads)? 
tensor_space.add_tensor_dim( ConcatenatedTensorDim( SSMDimNames.concatenated_inner_projection, - (heads_and_state, head_groups_and_state, head_groups_and_state, heads_and_state, heads), + (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim, heads), ) ) - # TODO: (head_groups, group_heads + 2, state_size) tensor_space.add_tensor_dim( - TensorDim(SSMDimNames.conv_dim, self.d_inner + 2 * self.n_qk_heads * self.state_size) + ConcatenatedTensorDim( + SSMDimNames.concatenated_convolution, + (heads_and_head_dim, head_groups_and_state, head_groups_and_state), + ) ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index ac4fb87cc..64377b93c 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -49,14 +49,18 @@ def __init__( layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - conv1d_dim = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) + conv1d_dim = tensor_space.get_tensor_dim(SSMDimNames.concatenated_convolution) heads_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) - self._local_heads = heads_dim.size + # local_head_groups = head_groups / TP self._local_head_groups = tensor_space.get_tensor_dim(SSMDimNames.head_groups).size + # local_heads = local_head_groups * group_heads + self._local_heads = heads_dim.size + # local_inner_size = local_heads * head_size self._local_inner_size = inner_dim.size + # local_bc_size = local_head_groups * state self._local_bc_size = tensor_space.get_tensor_dim(SSMDimNames.composite_head_groups_and_state).size # TODO: double check initializations @@ -80,7 +84,7 @@ def __init__( ( conv1d_dim, tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel), + tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -107,24 +111,25 @@ def __init__( ) def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - if kwargs[TransformerKwargs.sequence_first]: - raise NotImplementedError(f"Sequence-first not supported for SSMs.") - assert _mamba_available - sequence_length = input_.size(0 if kwargs[TransformerKwargs.sequence_first] else 1) + sequence_length = kwargs[TransformerKwargs.sequence_q_dim].global_size # Pad input to nearest multiple of chunklen padded_length = (1 + (sequence_length - 1) // self._config.chunk_size) * self._config.chunk_size if padded_length != sequence_length: - assert not kwargs[TransformerKwargs.sequence_first] and not self._sequence_parallel + assert not kwargs[TransformerKwargs.sequence_first] and input_.size(1) == sequence_length input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) + # inner_projection : (batch/local_or_padded_sequence, local_sequence/batch, hidden) + # -> (batch/local_or_padded_sequence, local_sequence/batch, inner_projection) + # inner_projection: (batch, local_or_padded_sequence, hidden) -> (batch, padded_sequence, local_inner_size) inner_projection = self.in_proj(input_) - 
# Standardize to (batch, sequence, inner_projection) + # Standardize to (batch, padded_sequence, inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) + print("QAIKOFNMJOWENM inner_projection", inner_projection.shape) xBC, z, A_log = torch.split( inner_projection, [ @@ -134,9 +139,13 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ ], dim=-1, ) + print("QAIKOFNMJOWENM xBC", xBC.shape, self._local_inner_size, self._local_bc_size) + print("QAIKOFNMJOWENM z", z.shape) + print("QAIKOFNMJOWENM A_log", A_log.shape) # Convolutional layer - xBC = self.convolutional_forward(xBC, sequence_length) + # xbc: (batch, padded_sequence, local_heads * head_size + 2 * local_head_groups * state) + xBC = self.convolutional_forward(xBC, padded_length) x, B, C = torch.split( xBC, @@ -148,13 +157,16 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ dim=-1, ) + # x: (batch, padded_sequence, local_heads * head_size) -> (batch, padded_sequence, local_heads, head_size) x = einops.rearrange(x, "b l (h n) -> b l h n", h=self._local_heads) + + # b,c: (batch, padded_sequence, local_head_groups * state) -> (batch, padded_sequence, local_head_groups, state) B = einops.rearrange(B, "b l (h n) -> b l h n", h=self._local_head_groups) C = einops.rearrange(C, "b l (h n) -> b l h n", h=self._local_head_groups) # SSM forward y = _mamba_chunk_scan_combined( - x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1), + x=self._apply_a_log(x, A_log), dt=A_log, dt_softplus=True, A=-torch.ones(self._local_heads, device=A_log.device), @@ -169,23 +181,31 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if not self._config.add_bias_linear: z = z + self.z_bias - # y: (batch, sequence, heads, state) -> (batch, sequence, heads * state) + # y: (batch, padded_sequence, local_heads, head_size) -> (batch, sequence, local_heads * head_size) y = ((y + Du).flatten(2, 3) * torch.nn.functional.silu(z))[:, :sequence_length] if kwargs[TransformerKwargs.sequence_first]: # TODO: Is contiguous needed? 
y = y.transpose(0, 1).contiguous() + # out_proj: (batch/sequence, sequence/batch, local_heads * head_size) + # -> (batch/local_sequence, local_sequence/batch, hidden) + a, b = self.out_proj(y) + logger.info(f"EKFBN y {y.shape}") + logger.info(f"EKFBN a {a.shape}") return self.out_proj(y) + @torch.compile + def _apply_a_log(self, x: torch.Tensor, A_log: torch.Tensor) -> torch.Tensor: + return x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1) + def convolutional_forward(self, xBC, padded_len): """Convolutional layer forward pass for the full sequence.""" if _causal_conv1d_available and self._config.activation_type in ( ActivationType.silu, - "swish", ActivationType.identity, ): xBC = _causal_conv1d_fn( xBC.transpose(1, 2), - einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + self.conv1d_weight.squeeze(1), self.conv1d_bias, activation=( None diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index cba28f8b8..1ae25e44c 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -39,7 +39,7 @@ class Mamba2(Mixer): _XZ_DIMS = ( TransformerDimNames.batch, - SSMDimNames.composite_heads_and_state, + SSMDimNames.composite_heads_and_head_dim, TransformerDimNames.sequence_q, ) _BC_DIMS = ( @@ -62,7 +62,7 @@ def __init__( layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) + inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_head_dim) xb_dim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) @@ -78,7 +78,7 @@ def __init__( ( conv1d_dim, tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel), + tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -146,6 +146,8 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ assert _mamba_available assert _causal_conv1d_available + # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) + # -> (batch/sequence, sequence/batch, inner_projection) inner_projection = self.in_proj(input_) dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias # Standardize to (batch, sequence, inner_projection) @@ -161,10 +163,10 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ dim=2, ) - # z: (batch, sequence, heads * state) -> (batch, heads * state, sequence) + # z: (batch, sequence, local_heads * state) -> (batch, local_heads * state, sequence) z = z.transpose(1, 2) - # x: (batch, sequence, head_groups * state) -> (batch, heads * state, sequence) + # x: (batch, sequence, local_head_groups * state) -> (batch, local_heads * state, sequence) x = x.transpose(1, 2) if self._config.repeat_kv_before_conv: x = ( @@ -172,16 +174,16 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, 
bias=self.conv1d_bias.squeeze(1), activation="silu") + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") else: - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias.squeeze(1), activation="silu") + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") x = ( x.unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) - # b: (batch, sequence, head_groups * state) -> (batch, heads, state, sequence) + # b: (batch, sequence, local_head_groups * state) -> (batch, local_heads, state, sequence) b = ( b.transpose(1, 2) .unflatten(1, (self._local_head_groups, self._config.state_size)) @@ -216,9 +218,11 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._debug_level: self._debug_log(y, "y", self._XZ_DIMS, kwargs) - # y: (batch, heads * state, sequence) -> (batch, sequence, heads * state) + # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) y = y.transpose(1, 2)[:, :sequence_length] if kwargs[TransformerKwargs.sequence_first]: # TODO: Is contiguous needed? y = y.transpose(0, 1).contiguous() + # (batch/sequence, sequence/batch, local_heads * state) + # -> (batch/local_sequence, local_sequence/batch, hidden) return self.out_proj(y) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 07eec38e6..64c8227fc 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -10,7 +10,7 @@ from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ +from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ from fast_llm.utils import Assert, get_lr_scale try: @@ -29,30 +29,27 @@ """ -def init_A(d_state, d_inner) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - # TODO: adopt this initialization to work for tensor parallel setting! 
+def init_A(d_state, d_inner) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa if tensor.numel() != d_state * d_inner: - raise ValueError(f"_init_A requires not supported for tensor slices.") - return torch.log( + raise ValueError("_init_A requires not supported for tensor slices.") + torch.log( torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device) .unsqueeze(0) .expand(d_inner, d_state), out=tensor, ) - return init_ + return LambdaInitializer(init_, requires_global_initialization=True) -def init_dtprojbias( - dt_max: float, dt_min: float, dt_init_floor: float -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +def init_dtprojbias(dt_max: float, dt_min: float, dt_init_floor: float) -> LambdaInitializer: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min_(dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - return tensor.add_(torch.log(-torch.expm1(-tensor))) + tensor.add_(torch.log(-torch.expm1(-tensor))) - return init_ + return LambdaInitializer(init_) class MambaLayer(Mixer): @@ -72,7 +69,7 @@ def __init__( Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) @@ -90,7 +87,7 @@ def __init__( ( inner_dim, tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(SSMDimNames.conv_kernel), + tensor_space.get_tensor_dim(SSMDimNames.convolution_kernel), ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, @@ -98,7 +95,7 @@ def __init__( self.x_proj = Linear( inner_dim, - tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim), + tensor_space.get_tensor_dim(SSMDimNames.concatenated_x_projection), weight_init_method=init_kaiming_(inner_dim.size), bias=False, lr_scale=lr_scale, diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 899e70005..b89ed4a04 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,3 +1,4 @@ +import abc import functools import math import typing @@ -241,7 +242,7 @@ def __init__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable[["ParameterMeta", torch.Tensor, torch.Generator], torch.Tensor] | None = None, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None" = None, weight_decay: bool = True, # Pass a list to split the parameter in contiguous (dim=0) chunks of equal size for optimization. lr_scale: float | None | tuple[float | None, ...] = None, @@ -251,7 +252,11 @@ def __init__( allow_no_grad: bool = False, ): super().__init__(data, tensor_name=tensor_name, dims=dims) - self.param_init_method = init_method + if init_method is not None and not isinstance(init_method, Initializer): + # Support non-wrapped callables for convenience. 
+ assert callable(init_method) + init_method = LambdaInitializer(init_method) + self.param_init_method: Initializer | None = init_method self.param_weight_decay = weight_decay self._is_param = True self.param_grad_is_zero = False @@ -276,7 +281,7 @@ def __new__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None", weight_decay: bool = True, lr_scale: float | None | tuple[float | None, ...] = None, allow_sequence_tensor_parallel: bool = True, @@ -303,6 +308,10 @@ def init_parameter(self, tensor: torch.Tensor, distributed: Distributed) -> None generator = distributed.tp_init_generator if self.is_tensor_parallel else distributed.pp_init_generator self.param_init_method(self, tensor, generator) + @property + def requires_global_initialization(self) -> bool: + return self.param_init_method.requires_global_initialization + def save(self) -> dict[str, typing.Any]: return { "name": self.tensor_name, @@ -334,11 +343,32 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None: triton_add(grad, param.grad_buffer, out=param.grad_buffer) # noqa -def init_fill_(value) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.fill_(value) +class Initializer(abc.ABC): + @abc.abstractmethod + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + pass + + requires_global_initialization = False - return init_ + +class LambdaInitializer(Initializer): + def __init__( + self, + init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None], + requires_global_initialization: bool = False, + ) -> None: + self._init_method = init_method + self.requires_global_initialization = requires_global_initialization + + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + return self._init_method(meta, tensor, generator) + + +def init_fill_(value: float) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor.fill_(value) + + return LambdaInitializer(init_) init_zeros_ = init_fill_(0.0) @@ -346,38 +376,32 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_normal_( - mean=0.0, std=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.normal_(mean, std, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) - return init_ + return LambdaInitializer(init_) -def init_kaiming_(d_in): +def init_kaiming_(d_in: float) -> LambdaInitializer: return init_normal_(0.0, math.sqrt(2.0 / d_in)) def init_uniform_( - low=0.0, high=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: 
torch.Generator): # noqa + low: float = 0.0, high: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.uniform_(low, high, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) - return init_ + return LambdaInitializer(init_) -def init_uniform_centered_( - high, max_val=None, mean=0.0 -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +def init_uniform_centered_(high: float, max_val: float | None = None, mean: float = 0.0) -> LambdaInitializer: return init_uniform_( mean - high, mean + high, diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 05acf23dc..4bda5512c 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -284,10 +284,15 @@ def test_load_pretrained( @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_huggingface_model(model_testing_config, get_convert_path): # Test that Fast-LLM's Hugging Face wrapper produces the same results as the converted Hugging Face model. + # TODO: Stress the importance of this test as the main correctness test for most models. # TODO: Review test. Move to test_generate? fast_llm_path = get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) hf_path = get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) - model_ref = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( + try: + hf_class = model_testing_config.huggingface_model_for_causal_lm_class + except NotImplementedError: + pytest.skip(f"Hugging Face wrapper not implemented for {model_testing_config.name}.") + model_ref = hf_class.from_pretrained( CheckpointLoadConfig( path=get_convert_path(), format=DistributedCheckpointFormat, @@ -298,8 +303,8 @@ def test_huggingface_model(model_testing_config, get_convert_path): 0, model_ref.config.fast_llm_config.base_model.vocab_size, size=(4, 100), dtype=torch.int64, device="cuda" ) output_ref = model_ref(test_input) - model_from_fast_llm = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained(fast_llm_path) - model_from_hf = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( + model_from_fast_llm = hf_class.from_pretrained(fast_llm_path) + model_from_hf = hf_class.from_pretrained( CheckpointLoadConfig( path=hf_path, format=model_testing_config.checkpoint_format, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 038b53c26..722d8d63a 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -20,6 +20,7 @@ Starcoder2GPTHuggingfaceCheckpointFormat, ) from fast_llm.models.ssm.config import ( + AprielSSMHHybridHuggingfaceCheckpointFormat, AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat, ) @@ -540,19 +541,19 @@ def _update_and_add_testing_config( "model.base_model.ssm.chunk_size=32", ], megatron_args=None, - checkpoint_format=None, + checkpoint_format=AprielSSMHHybridHuggingfaceCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, 
ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, # TODO: Implement - ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, compare_factor=2.0, # Micro-sequence split and sequence-first not supported. - skip_tests=("sf", "stp", "sdp", "ms"), + skip_tests=("sdp", "ms"), ) From b605bd29bcdd85379a2c43124f07a4c215f53e71 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 25 Jul 2025 18:24:53 -0400 Subject: [PATCH 20/82] doc --- docs/contributing/contributing.md | 4 ++-- docs/contributing/testing.md | 37 ++++++++++++++++++++++++++----- mkdocs.yaml | 1 + 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/docs/contributing/contributing.md b/docs/contributing/contributing.md index 6185b63fe..938fe925f 100644 --- a/docs/contributing/contributing.md +++ b/docs/contributing/contributing.md @@ -40,7 +40,7 @@ Before diving into code, [open an issue](https://github.com/ServiceNow/Fast-LLM/ Here are some tips to ensure your pull request gets reviewed and merged promptly: - **Follow our coding standards**: Stick to our [style guide and conventions](https://servicenow.github.io/Fast-LLM/developers/style-guide) to keep the code clean and consistent. -- **Write tests**: Verify your changes with unit tests for new features or bug fixes. +- **Write tests**: Verify your changes with unit tests for new features or bug fixes. See our [testing guide](https://servicenow.github.io/Fast-LLM/contributing/testing) for tips and recommendations on testing. - **Test on GPUs and real-world workloads**: Since Fast-LLM is all about training large language models, make sure your changes work smoothly in GPU environments and on typical training setups. - **Run benchmarks and performance tests**: Make sure your changes don't slow things down. If there's any impact on performance, provide benchmark results to back it up. - **Avoid introducing new issues**: Check that there are no new runtime warnings, type checker errors, linting problems, or unhandled edge cases. @@ -48,7 +48,7 @@ Here are some tips to ensure your pull request gets reviewed and merged promptly - **Keep sensitive data out**: Make sure your code or commit messages don't expose private or proprietary information. - **Use a clear and descriptive title**: The PR title should summarize the key change or feature introduced. Avoid vague titles like "Fix bug" or "Update code." Start with a keyword like `[feat]`, `[fix]`, `[docs]`, etc. to categorize the change. Reference the issue number if applicable (e.g., `[fix] resolve #123 memory leak in training loop`). This title will become the commit message for the squashed merge. - **Use the [PR template](https://github.com/ServiceNow/Fast-LLM/blob/main/.github/PULL_REQUEST_TEMPLATE.md)**: Complete the checklist to make sure everything is in order before hitting submit. -- **Make sure all tests pass before merging**: Run the tests with `pytest tests/ -v -ra -n 10`, and fix any failure before merging. If possible, please run the test in an environment with at least 4 GPUs. +- **Make sure all tests pass before merging**: Run the tests with `pytest tests/ -v -ra -n 10`, and fix any failure before merging. If possible, please run the test in an environment with at least 4 GPUs. See our [testing guide](https://servicenow.github.io/Fast-LLM/contributing/testing) for more details on testing and debugging. 
## 🆘 Seeking Help or Clarification diff --git a/docs/contributing/testing.md b/docs/contributing/testing.md index 8df93f9d0..9cce78e3c 100644 --- a/docs/contributing/testing.md +++ b/docs/contributing/testing.md @@ -1,13 +1,43 @@ --- -title: Writing tests +title: Writing and running tests --- +## Debugging with tests + +### Selecting tests + +When debugging, it is often practical to target specific tests that will run quickly. While Pytest supports targeting specific directory, files or tests, the complex parameterization and dependencies of our tests often makes explicit targeting tedious and/or impractical. We provide several options for selecting tests: + +* `--skip-slow`: This will run a subset of "fast" tests that cover the majority of our codebase. This is useful for quickly checking that changes did not break Fast-LLM too badly before running the full test suite. Note that parallel testing (`-n`) is not needed (and may be counter-productive) with this argument. +* `--run-extra-slow`: Some tests are disabled by default because they take too long to run (ex. complex integration tests) and/or are not particularly important. This argument re-enables them. +* `--models MODEL0 MODEL1 ...`: This allows targeting one or more specific models from the model tests (see below), and is particularly useful when debugging a model. For example, `pytest tests/models/test_models/test_checkpoint.py -v -ra --models llama` will test checkpoints specifically for the llama model. (Note that `-n` may not be needed here as model tests for a given model are only partly distributed dure to dependency constraints.) + +### Monitoring distributed tests + +`--no-distributed-capture` + +### Other options + +* `--show-gpu-memory N`: Our testing suite monitors GPU memory usage and reports the highest users. Use this option to adjust the number of reported tests (10 by default). Note that this option is mainly intended to make sure tests don't use too much memory (which could cause crashes with lots of parallel tests) and may not be an accurate measurement. +* `--show-skipped`: Many tests skipped for obvious reasons (ex. marked as slow or extra slow, skipped model testing groups (see below)) are removed entirely from the report to reduce clutter. This option may be used to show them explicitly. + +## Best practices + ## Testing models [Model integration tests](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/models) are the most important part of our testing suite, ensuring that Fast-LLM works and yields consistent results for a variety of models, training configurations, optimizations, etc. For each tested model, we run a series of tests divided into several groups. Much of these tests consist of running a short Fast-LLM training run, then comparing intermediate tensors (ex. parameter initialization, layer outputs and gradients, parameter gradients) against a baseline. +### What is being tested + +Coming soon. + +!!! warning "Don't forget about unit tests!" + + While adding a model is a quick and efficient way to increase coverage, it is **not a replacement for unit tests**. + The model testing suite performs intensive consistency checks, but does little to make sure those results are correct to begin with. See [functional tests](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/functional) and [test_lm_head](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/layers/test_lm_head.py) for good examples of unit tests for individual components and an entire layer. 
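As a minimal illustration of that spirit, such a unit test can be as small as the sketch below. It is hypothetical and self-contained: the `silu_gate` helper is defined inline for the example and is not part of Fast-LLM's API; the point is simply to exercise one small numerical component against an explicit reference implementation.

```python
import torch


def silu_gate(y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for a mixer's gated output path.
    return y * torch.nn.functional.silu(z)


def test_silu_gate_matches_reference():
    y, z = torch.randn(4, 32), torch.randn(4, 32)
    # silu(z) = z * sigmoid(z), spelled out as the reference.
    torch.testing.assert_close(silu_gate(y, z), y * z * torch.sigmoid(z))
```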
+ ### Adding a model When adding support for a new model that comes with additional features, the simplest option to increase coverage is to add an example configuration to the [tested modelsl](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/utils/model_configs.py). @@ -41,11 +71,6 @@ _update_and_add_testing_config( ) ``` -!!! warning "Don't forget about unit tests!" - - While adding a model is a quick and efficient way to increase coverage, it is **not a replacement for unit tests**. - The model testing suite performs intensive consistency checks, but does little to make sure those results are correct to begin with. See [functional tests](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/functional) and [test_lm_head](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/layers/test_lm_head.py) for good examples of unit tests for individual components and an entire layer. - #### Reference for groups Fast-LLM currently supports the following testing groups: diff --git a/mkdocs.yaml b/mkdocs.yaml index 85fd4bff0..00e52a011 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -189,5 +189,6 @@ nav: - Contribution Guide: contributing/contributing.md - Style Guide: contributing/style-guide.md - Development Practices: contributing/dev-practices.md + - Testing: contributing/testing.md - About Us: about-us.md - Join Us: join-us.md From 5eea938403a74bcf8ee7f0c504e3d8bb6fe118f7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 11:52:24 -0400 Subject: [PATCH 21/82] fix --- fast_llm/models/custom/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index c206ef406..534d813ff 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -7,7 +7,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig from fast_llm.models.custom.head import CustomHead from fast_llm.models.gpt.config import GPTBaseModelConfig @@ -31,10 +31,10 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, ) for i in range(self._config.transformer.num_layers) ], From 2e6d082e4b2d7fc3f043365664339a5b823713e6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 15:46:06 -0400 Subject: [PATCH 22/82] fixes --- fast_llm/engine/config_utils/tensor_space.py | 42 +++++++++++++- fast_llm/engine/multi_stage/fsdp.py | 32 +++-------- fast_llm/layers/ssm/config.py | 7 ++- fast_llm/models/gpt/megatron.py | 29 +++++----- fast_llm/tensor.py | 58 ++++++++++++++------ 5 files changed, 109 insertions(+), 59 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 0d971a88a..55d87e271 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -66,13 +66,23 @@ def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) def 
local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": - if self.parallel_group is not None: + if self.is_parallel: from fast_llm.core.ops import gather_op return gather_op(tensor, self.parallel_group, dim) else: return tensor + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + if self.is_parallel: + output = tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value) + output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim) + return output.flatten(dim, dim + 1) + else: + return tensor + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": return ( tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] @@ -111,6 +121,15 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global_partial(tensor, dim + i) + + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): @@ -157,6 +176,27 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor else tensor ) + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + import torch + + return ( + torch.concatenate( + [ + tensor_dim.local_to_global_partial(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": if self.is_parallel and expand: raise NotImplementedError() diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 5b44bf14b..be15cd37a 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -441,39 +441,21 @@ def _get_parameter_shard_indices_in_full_weight( where it is located in the shard if it exists, or -1 if it's not in the shard. Used to determine the location of each entry in a different distributed configuration. """ - - # Create an empty index for the global parameter. - index = torch.full( - parameter_meta.global_shape, - -1, - dtype=torch.int64, - device=device, - ) # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard begin, end = self._get_parameter_range_in_shard(parameter_name) - buffer_index = parameter_meta.global_to_local(index, expand=True) - # Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible. - # In that case, we work with a separate tensor to be copied back into `buffer_index`. 
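# A minimal plain-torch sketch of what the `local_to_global_partial` helpers added above
# compute for a plain 1-D TensorDim, assuming rank 1 of a 2-way parallel dim and the default
# fill_value of -1. Entries owned by other ranks keep the fill value, matching the round-trip
# property documented in the fast_llm/tensor.py hunk further down
# (`global_to_local(local_to_global_partial(x)) == x`); values and sizes are illustrative only.
import torch

rank, world_size = 1, 2
local = torch.tensor([3.0, 4.0])                           # this rank's slice
out = local.new_full((world_size, *local.shape), -1.0)     # one row per rank, pre-filled with -1
out.narrow(0, rank, 1).copy_(local.unsqueeze(0))           # drop the local slice into its row
global_partial = out.flatten(0, 1)                         # tensor([-1., -1., 3., 4.])
assert torch.equal(global_partial.chunk(world_size, 0)[rank], local)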
- try: - buffer_index_flat = buffer_index.view(-1) - is_view = True - except RuntimeError: - buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1) - is_view = False - - # Copy the shard indices at their respective positions in the flat buffer index. - buffer_index_flat[ + # Create an empty local index to hold the local shard indices. + buffer_index = torch.full_like(parameter_meta, -1, dtype=torch.int64, device=device) + + # Copy the shard indices at their respective positions in the buffer index. + buffer_index.flatten()[ self._index_buffer_to_param( self._fsdp_dim.rank * self._shard_size, parameter_name ) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name) ].copy_(torch.arange(begin, end, dtype=torch.int64, device=device)) - # If needed, copy the flat buffer index back into the index. - if not is_view: - buffer_index.copy_(buffer_index_flat.view_as(buffer_index)) - - return index + # Create a global index from the local one. + return parameter_meta.local_to_global_partial(buffer_index, -1) def copy_shard_overlaps( self, diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index c06d85148..9b0949d55 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,13 +1,16 @@ import enum +import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig -from fast_llm.tensor import Initializer from fast_llm.utils import Assert, div +if typing.TYPE_CHECKING: + from fast_llm.tensor import Initializer + class SSMDimNames: # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. @@ -62,7 +65,7 @@ class DTInitType(enum.StrEnum): constant = "constant" random = "random" - def get_init_method(self, scale: float) -> Initializer: + def get_init_method(self, scale: float) -> "Initializer": from fast_llm.tensor import init_fill_, init_uniform_centered_ return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index e7379e61e..20ed8e828 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -14,8 +14,8 @@ def get_init_megatron( meta: "ParameterMeta", config: TransformerConfig -) -> typing.Callable[["torch.Tensor", "Distributed"], "torch.Tensor"]: - def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): +) -> typing.Callable[["torch.Tensor", "Distributed"], None]: + def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: Assert.eq(distributed.config.world_size, 1) if "bias" in meta.tensor_name: # Generator unused. @@ -29,11 +29,11 @@ def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): elif config.num_experts > 1 and "mlp.layer_" in meta.tensor_name: tensor_ = _init_moe_mlp_megatron(config, meta, tensor, distributed) elif "mlp.layer_2" in meta.tensor_name: - tensor_ = _init_transposed_mlp_weight_megatron(config, meta, tensor, distributed) + tensor_ = _init_transposed_mlp_weight_megatron(meta, tensor, distributed) else: # Word embedding (override generator), layer norm (generator unused), other mlp weights. 
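# The megatron.py hunk around this point adapts the helpers to initializers that fill their
# tensor argument in place (the wrapper's return type becomes None and the initializers'
# return values are no longer used), allocating scratch tensors inline with the walrus
# operator. A generic plain-torch illustration of that pattern, not the Fast-LLM
# `param_init_method` API:
import torch

target = torch.empty(4, 4)
torch.nn.init.normal_(scratch := torch.empty_like(target))  # fill a scratch tensor in place
target.copy_(scratch.view_as(target))                       # then copy it into the real buffer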
return meta.param_init_method(meta, tensor, distributed.tp_init_generator) - return tensor.copy_(tensor_.reshape_as(tensor)) + tensor.copy_(tensor_.reshape_as(tensor)) return init_megatron @@ -58,9 +58,9 @@ def _init_attention_megatron( generator = distributed.tp_init_generator state = generator.get_state() # Initialize a mock dense layer to advance the random state - dense_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + dense_tensor_ := tensor.new_empty( config.kv_channels * config.num_attention_heads, config.hidden_size, ), @@ -68,9 +68,9 @@ def _init_attention_megatron( ) # QKV is split differently. (Assuming no tensor-parallel.) heads_per_group = div(config.num_attention_heads, config.head_groups) - qkv_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + qkv_tensor_ := tensor.new_empty( config.head_groups, heads_per_group + 2, config.kv_channels, @@ -110,18 +110,19 @@ def _init_position_embeddings_megatron( # Megatron initializes the position embeddings on cpu twice. assert meta.param_init_method is not None generator = distributed.default_cpu_generator - tensor_ = meta.param_init_method(meta, torch.empty(tensor.shape, dtype=tensor.dtype), generator) - return meta.param_init_method(meta, tensor_, generator) + meta.param_init_method(meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), generator) + meta.param_init_method(meta, tensor_, generator) + return tensor_ def _init_transposed_mlp_weight_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": import torch # Megatron never transposes the mlp layer 2 weight. assert meta.param_init_method is not None - tensor_ = meta.param_init_method(meta, torch.empty_like(tensor), distributed.tp_init_generator) + meta.param_init_method(meta, tensor_ := torch.empty_like(tensor), distributed.tp_init_generator) return tensor_.view(meta.size(1), meta.size(0)).t() @@ -132,8 +133,8 @@ def _init_moe_router_megatron( # Megatron initializes the router on cpu. assert meta.param_init_method is not None - tensor_ = meta.param_init_method( - meta, torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator + meta.param_init_method( + meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator ) return tensor_ diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b89ed4a04..0637931ee 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,5 +1,6 @@ import abc import functools +import logging import math import typing @@ -13,6 +14,8 @@ from fast_llm.functional.triton.pointwise import triton_add, triton_copy from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + class _SafeTensorSliceMeta(type): def __instancecheck__(self, instance) -> bool: @@ -159,12 +162,11 @@ def from_tensor_space( def global_shape(self) -> torch.Size: return torch.Size([dim.global_size for dim in self.dims]) - def local_to_global( - self, - tensor: torch.Tensor, - *, - distributed: Distributed, - ) -> tuple[torch.Tensor, ...]: + def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> tuple[torch.Tensor, ...]: + """ + Reconstruct a global tensor from its distributed slices. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. 
+ """ if tensor.ndim == 0: tensor = tensor[None] Assert.eq(tensor.shape, self.shape) @@ -188,14 +190,32 @@ def local_to_global( Assert.eq(tensor.shape, self.global_shape) return tensor, is_first_rank - def global_to_local( - self, - tensor: torch.Tensor | SafeTensorSlice, - # Return an expanded tensor, avoiding `flatten` which copies the data. TODO: Rework. - expand: bool = False, - ) -> torch.Tensor: + def local_to_global_partial(self, tensor: torch.Tensor, fill_value: float | int = -1) -> torch.Tensor: """ - Recover the tensor-parallel slice of a tensor. Support lazy-loaded safetensor slices. + Construct a tensor of shape `self.global_shape` that contains its local slice at the appropriate location, + i.e. for which `self.global_to_local(self.local_to_global_partial(tensor)) == tensor`. + Other entries are filled with `fill_value`. + Returns a view of the input tensor (or the input tensor itself) when possible. + """ + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) + assert not self._reductions + logger.info(f"AAAA {self.tensor_name} {self.shape} {self.global_shape} {tensor.shape}") + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global_partial(tensor, dim, fill_value) + logger.info( + f"BBBB {self.tensor_name} {self.shape} {self.global_shape} {tensor.shape} {tensor_dim.is_parallel}" + ) + + Assert.eq(tensor.shape, self.global_shape) + return tensor + + def global_to_local(self, tensor: torch.Tensor | SafeTensorSlice) -> torch.Tensor: + """ + Select the local slice of a global tensor. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. """ # Take a trivial slice to convert safetensor slices. 
tensor = tensor[:] @@ -205,9 +225,9 @@ def global_to_local( Assert.eq(tensor.shape, self.global_shape) for dim, tensor_dim in reversed(list(enumerate(self.dims))): - tensor = tensor_dim.global_to_local(tensor, dim, expand) - if not expand: - Assert.eq(tensor.shape, self.shape) + tensor = tensor_dim.global_to_local(tensor, dim) + + Assert.eq(tensor.shape, self.shape) return tensor @classmethod @@ -302,7 +322,11 @@ def __repr__(self, *, tensor_contents=()) -> str: def init_parameter(self, tensor: torch.Tensor, distributed: Distributed) -> None: assert self.param_init_method is not None - if distributed.config.tensor_parallel == 1 or distributed.config.reproducible_init: + if ( + distributed.config.tensor_parallel == 1 + or distributed.config.reproducible_init + or self.param_init_method.requires_global_initialization + ): generator = distributed.pp_init_generator else: generator = distributed.tp_init_generator if self.is_tensor_parallel else distributed.pp_init_generator From b6c86138bbdbf19099b799475f17e8d3dcca34b6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 16:08:58 -0400 Subject: [PATCH 23/82] misc --- fast_llm/engine/config_utils/tensor_space.py | 5 +---- fast_llm/layers/language_model/embedding.py | 8 ++++---- fast_llm/layers/language_model/head.py | 10 +++++----- .../layers/language_model/preprocessing.py | 4 ++-- fast_llm/layers/ssm/discrete_mamba2.py | 18 ++++++++--------- fast_llm/layers/ssm/mamba2.py | 20 +++++++++---------- fast_llm/layers/ssm/mamba_layer.py | 16 +++++++-------- fast_llm/layers/transformer/attention.py | 16 +++++++-------- .../layers/transformer/mixture_of_experts.py | 6 +++--- fast_llm/layers/transformer/mlp.py | 6 +++--- fast_llm/layers/transformer/preprocessing.py | 2 +- .../transformer/rotary/preprocessing.py | 4 ++-- fast_llm/layers/transformer/rotary/rotary.py | 4 ++-- fast_llm/layers/transformer/transformer.py | 4 ++-- fast_llm/models/gpt/model.py | 2 +- fast_llm/tensor.py | 2 +- 16 files changed, 62 insertions(+), 65 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 55d87e271..cf2974a99 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -263,8 +263,5 @@ def add_tensor_dim(self, tensor_dim: TensorDim) -> None: ) self._tensor_dims[tensor_dim.name] = tensor_dim - def get_tensor_dim(self, name: str) -> TensorDim: + def __getitem__(self, name: str) -> TensorDim: return self._tensor_dims[name] - - # TODO: Replace uses - __getitem__ = get_tensor_dim diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 7036a1e97..f6f43d199 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -46,10 +46,10 @@ def __init__( self._dropout_p = config.transformer.hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - vocab_dim = tensor_space.get_tensor_dim( + hidden_dim = tensor_space[TransformerDimNames.hidden] + vocab_dim = tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] if self._parallel_embeddings: self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size @@ -66,7 +66,7 @@ def __init__( ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( - 
(tensor_space.get_tensor_dim(LanguageModelDimNames.position_embed), hidden_dim), + (tensor_space[LanguageModelDimNames.position_embed], hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 21bf3bbd0..210cad644 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -61,7 +61,7 @@ def __init__( if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] self._loss_coefficient = ( config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 @@ -108,9 +108,9 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: if self._tie_word_embeddings or self._prediction_distance > 0: return # untie embedding weights - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] self.output_weights = ParameterMeta.from_dims( (vocab_dim, hidden_dim), init_method=init_normal_( @@ -338,9 +338,9 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._logits_scale_factor, ) if self._debug_transformer and self._cross_entropy_splits is None: - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ) + ] dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim] sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) dims[sequence_index] = ( diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index d719bef3d..c8d53a789 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -28,7 +28,7 @@ def __init__( assert config.use_absolute_position_embeddings self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: @@ -76,7 +76,7 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: return diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 64377b93c..c9d555de9 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -49,25 +49,25 @@ def __init__( layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) - hidden_dim = 
tensor_space.get_tensor_dim(TransformerDimNames.hidden) - conv1d_dim = tensor_space.get_tensor_dim(SSMDimNames.concatenated_convolution) - heads_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) + inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] + hidden_dim = tensor_space[TransformerDimNames.hidden] + conv1d_dim = tensor_space[SSMDimNames.concatenated_convolution] + heads_dim = tensor_space[SSMDimNames.composite_heads] # local_head_groups = head_groups / TP - self._local_head_groups = tensor_space.get_tensor_dim(SSMDimNames.head_groups).size + self._local_head_groups = tensor_space[SSMDimNames.head_groups].size # local_heads = local_head_groups * group_heads self._local_heads = heads_dim.size # local_inner_size = local_heads * head_size self._local_inner_size = inner_dim.size # local_bc_size = local_head_groups * state - self._local_bc_size = tensor_space.get_tensor_dim(SSMDimNames.composite_head_groups_and_state).size + self._local_bc_size = tensor_space[SSMDimNames.composite_head_groups_and_state].size # TODO: double check initializations # Projections self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), + tensor_space[SSMDimNames.concatenated_inner_projection], bias=config.add_bias_linear, weight_init_method=init_kaiming_(transformer_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -83,8 +83,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( conv1d_dim, - tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), + tensor_space[DefaultDimNames.scalar], + tensor_space[SSMDimNames.convolution_kernel], ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 1ae25e44c..77c1b3869 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -62,13 +62,13 @@ def __init__( layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_head_dim) - xb_dim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) - hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) - dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) + inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim] + xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state] + hidden_dim: TensorDim = tensor_space[TransformerDimNames.hidden] + dt_rank_dim = tensor_space[SSMDimNames.dt_rank] - self._local_heads = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads).size - self._local_head_groups = tensor_space.get_tensor_dim(name=SSMDimNames.head_groups).size + self._local_heads = tensor_space[SSMDimNames.composite_heads].size + self._local_head_groups = tensor_space[SSMDimNames.head_groups].size self._group_heads = div(self._local_heads, self._local_head_groups) self._local_inner_size = inner_dim.size self._local_xb_size = xb_dim.size @@ -77,8 +77,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( conv1d_dim, - tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), + tensor_space[DefaultDimNames.scalar], + tensor_space[SSMDimNames.convolution_kernel], ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -90,7 +90,7 @@ def __init__( ) self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), + tensor_space[SSMDimNames.concatenated_inner_projection], bias=config.add_bias_linear, weight_init_method=init_kaiming_(transformer_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -122,7 +122,7 @@ def __init__( lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space.get_tensor_dim(name=SSMDimNames.state)), + (inner_dim, tensor_space[SSMDimNames.state]), init_method=init_A(self._config.state_size, self._config.d_inner), lr_scale=lr_scale, weight_decay=False, diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 64c8227fc..9343ef1b8 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -69,8 +69,8 @@ def __init__( Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) + inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] + hidden_dim = tensor_space[TransformerDimNames.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) @@ -78,7 +78,7 @@ def __init__( # TODO: lr_scale? 
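# `init_uniform_centered_` is used above for the conv1d weights with scale
# `(conv1d_dim.global_size * conv_kernel_dimension) ** -0.5`. Judging from the
# layers/common/config.py hunk later in this series, which replaces
# `init_uniform_(mean - range, mean + range)` with `init_uniform_centered_(range, mean=mean)`,
# it amounts to a uniform init on [mean - scale, mean + scale]; a plain-torch sketch with
# illustrative sizes:
import torch

scale, mean = (256 * 4) ** -0.5, 0.0             # e.g. (channels * kernel_size) ** -0.5
conv_weight = torch.empty(256, 1, 4)             # (channels, 1, kernel_size)
conv_weight.uniform_(mean - scale, mean + scale)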
self.in_proj = Linear( hidden_dim, - tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection), + tensor_space[SSMDimNames.concatenated_inner_projection], bias=False, weight_init_method=init_kaiming_(hidden_dim.size), ) @@ -86,8 +86,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( inner_dim, - tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(SSMDimNames.convolution_kernel), + tensor_space[DefaultDimNames.scalar], + tensor_space[SSMDimNames.convolution_kernel], ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, @@ -95,7 +95,7 @@ def __init__( self.x_proj = Linear( inner_dim, - tensor_space.get_tensor_dim(SSMDimNames.concatenated_x_projection), + tensor_space[SSMDimNames.concatenated_x_projection], weight_init_method=init_kaiming_(inner_dim.size), bias=False, lr_scale=lr_scale, @@ -104,7 +104,7 @@ def __init__( # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( - (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.dt_rank)), + (inner_dim, tensor_space[SSMDimNames.dt_rank]), init_method=init_kaiming_(self._config.dt_rank), lr_scale=lr_scale, ) @@ -116,7 +116,7 @@ def __init__( ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.state)), + (inner_dim, tensor_space[SSMDimNames.state]), weight_decay=False, init_method=init_A(self._config.state_size, inner_dim.size), lr_scale=lr_scale, diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 174e19588..c59b191af 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -91,14 +91,14 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels).size - self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size - self._local_head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).size - self._local_heads_per_group = self._tensor_space.get_tensor_dim(TransformerDimNames.group_heads).size + self._kv_channels = self._tensor_space[TransformerDimNames.kv_channels].size + self._head_groups = self._tensor_space[TransformerDimNames.head_groups].global_size + self._local_head_groups = self._tensor_space[TransformerDimNames.head_groups].size + self._local_heads_per_group = self._tensor_space[TransformerDimNames.group_heads].size self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) @@ -106,7 +106,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
self.query = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query), + self._tensor_space[TransformerDimNames.composite_query], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -115,7 +115,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value), + self._tensor_space[TransformerDimNames.composite_key_value], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -129,7 +129,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # Output. self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense), + self._tensor_space[TransformerDimNames.composite_dense], hidden_dim, bias=self._config.add_attn_dense_bias, weight_init_method=init_method_std_attn_proj, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 73f83ccf5..4fd2844d5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -63,8 +63,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( - tensor_space.get_tensor_dim(TransformerDimNames.hidden), - tensor_space.get_tensor_dim(TransformerDimNames.unshared_experts), + tensor_space[TransformerDimNames.hidden], + tensor_space[TransformerDimNames.unshared_experts], bias=False, weight_init_method=init_normal_( std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max @@ -255,7 +255,7 @@ def _debug_log( def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space.get_tensor_dim(dim_name),), + kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), tensor_name=f"{self._name} {name}", dtype=tensor.dtype, ) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index efe0c5cc5..101d97ef3 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -30,8 +30,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self._intermediate_dim = tensor_space.get_tensor_dim(TransformerDimNames.composite_expert_mlp) + hidden_dim = tensor_space[TransformerDimNames.hidden] + self._intermediate_dim = tensor_space[TransformerDimNames.composite_expert_mlp] self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._recompute_level = config.mlp_recompute_level @@ -46,7 +46,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), + tensor_space[TransformerDimNames.composite_gated_expert_mlp], 
bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index dc3ddeb52..3f0e14eb7 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -28,7 +28,7 @@ def __init__( self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert not self._config.do_use_flash_attention(self._distributed_config) - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py index cc83dae02..c357411b6 100644 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ b/fast_llm/layers/transformer/rotary/preprocessing.py @@ -25,8 +25,8 @@ def __init__( self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: self._create_tensors(kwargs[TransformerKwargs.sequence_length]) diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 056b9aa4c..17b18a1ca 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -82,8 +82,8 @@ def __init__( super().__init__(config, tensor_space) self._tensor_space = tensor_space if self._tensor_space is not None: - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index d08db9a94..75d06f268 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -48,7 +48,7 @@ def _get_meta( } return TensorMeta.from_dims( tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] for dim_name in dim_names ), tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", @@ -97,7 +97,7 @@ def __init__( self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) diff --git 
a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 4c1eab46f..49a5dcbd3 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -155,7 +155,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 0637931ee..b3795b740 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -150,7 +150,7 @@ def from_tensor_space( reductions: tuple[tuple[str, ReduceOp], ...] = (), **kwargs: typing.Any, ) -> typing.Self: - dims = tuple(tensor_space.get_tensor_dim(dim_name) for dim_name in dim_names) + dims = tuple(tensor_space[dim_name] for dim_name in dim_names) if reductions: # kwarg not available for ParameterMeta, so we only provide if necessary. kwargs["reductions"] = tuple( From e536af9d935fe789b98683777e3e320eaf5d7e62 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 16:15:17 -0400 Subject: [PATCH 24/82] Concatenated dim --- fast_llm/engine/config_utils/tensor_space.py | 224 +++++++++++++----- fast_llm/engine/multi_stage/fsdp.py | 32 +-- fast_llm/engine/multi_stage/stage_base.py | 5 +- fast_llm/layers/common/config.py | 6 +- fast_llm/layers/common/linear.py | 8 +- fast_llm/layers/common/normalization.py | 4 +- fast_llm/layers/common/peft.py | 4 +- fast_llm/layers/language_model/embedding.py | 8 +- fast_llm/layers/language_model/head.py | 10 +- .../layers/language_model/preprocessing.py | 4 +- fast_llm/layers/transformer/attention.py | 16 +- .../layers/transformer/mixture_of_experts.py | 6 +- fast_llm/layers/transformer/mlp.py | 6 +- fast_llm/layers/transformer/preprocessing.py | 2 +- .../transformer/rotary/preprocessing.py | 4 +- fast_llm/layers/transformer/rotary/rotary.py | 4 +- fast_llm/layers/transformer/transformer.py | 4 +- fast_llm/models/gpt/megatron.py | 29 +-- fast_llm/models/gpt/model.py | 2 +- fast_llm/tensor.py | 169 ++++++++----- 20 files changed, 346 insertions(+), 201 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 99c1bcf70..cf2974a99 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -1,3 +1,4 @@ +import logging import math import typing @@ -5,9 +6,13 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: + import torch + from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.distributed.distributed import Distributed +logger = logging.getLogger(__name__) + class TensorDim: def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None): @@ -19,11 +24,11 @@ def __init__(self, name: str, global_size: int | None, parallel_dim: Distributed def __repr__(self) -> str: return ( - f"TensorDim(" + f"{type(self).__name__}(" f"name={self._name}," f" size={self._size}," f" global_size={self._global_size}," - f" parallel_dim={None if self.parallel_dim is None else self._parallel_dim}" + f" parallel_dim={self._parallel_dim}" f")" ) @@ -38,83 +43,180 @@ def name(self) -> str: def size(self) -> int: return self._size - @property - def expanded_shape(self) -> tuple[int, ...]: - return (self._size,) - - @property - def ndim(self) -> int: - return 1 - @property def global_size(self) -> int: return 
self._global_size @property - def global_expanded_shape(self) -> tuple[int, ...]: - return (self._size if self._parallel_dim is None else self._size * self._parallel_dim.size,) + def is_parallel(self) -> bool: + return self._parallel_dim is not None and self._parallel_dim.size > 1 @property def parallel_dim(self) -> DistributedDim | None: + # TODO: Make more flexible for derived classes? return self._parallel_dim - @property - def parallel_dim_index(self) -> int | None: - return None if self._parallel_dim is None else 0 - @property def parallel_group(self) -> "ProcessGroup|None": + # TODO: Make more flexible for derived classes? return None if self._parallel_dim is None else self._parallel_dim.group def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim is not None + assert self.is_parallel return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + if self.is_parallel: + from fast_llm.core.ops import gather_op + + return gather_op(tensor, self.parallel_group, dim) + else: + return tensor + + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + if self.is_parallel: + output = tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value) + output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim) + return output.flatten(dim, dim + 1) + else: + return tensor + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + return ( + tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] + if self.parallel_dim is not None and self.parallel_dim.size > 1 + else tensor + ) + class CompositeTensorDim(TensorDim): - def __init__(self, name: str, dims: tuple[TensorDim, ...]): - # TODO: Recursive composition?? - parallel_dims = [(i, dim.parallel_dim) for i, dim in enumerate(dims) if dim.parallel_dim] - Assert.leq(len(parallel_dims), 1) + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = None + for dim, tensor_dim in enumerate(tensor_dims): + if tensor_dim.is_parallel: + # TODO: Allow more than one parallel subdim? 
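# A shape-level sketch of the bookkeeping in `CompositeTensorDim.local_to_global` (defined
# just below): unflatten the composite axis into its sub-dims, gather along the single
# parallel sub-dim, then re-flatten. The all-gather is simulated with a local `torch.cat`
# over two fake ranks; sizes are illustrative (4 local heads of size 64, 2-way parallel).
import torch

local_heads, head_size, world_size = 4, 64, 2
per_rank = [torch.full((local_heads * head_size,), float(rank)) for rank in range(world_size)]
unflattened = [t.unflatten(0, (local_heads, head_size)) for t in per_rank]  # split composite axis
global_flat = torch.cat(unflattened, dim=0).flatten(0, 1)                   # "gather" + re-flatten
assert global_flat.shape == (world_size * local_heads * head_size,)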
+ assert parallel_dim is None + parallel_dim = tensor_dim.parallel_dim + self._parallel_dim_index = dim super().__init__( name=name, - global_size=math.prod(dim.global_size for dim in dims), - parallel_dim=parallel_dims[0][1] if parallel_dims else None, - ) - self._dims = dims - self._parallel_dim_index = ( - sum(dim.ndim for dim in self._dims[: parallel_dims[0][0]]) - + self._dims[parallel_dims[0][0]].parallel_dim_index - if parallel_dims - else None + global_size=math.prod(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, ) + self._tensor_dims = tensor_dims - @property - def dims(self) -> tuple[TensorDim, ...]: - return self._dims + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + assert self._parallel_dim_index is not None + dims = list(self._tensor_dims) + dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim) + return CompositeTensorDim(self.name, tuple(dims)) - @property - def ndim(self) -> int: - return sum(dim.ndim for dim in self._dims) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global(tensor, dim + i) - @property - def expanded_shape(self) -> tuple[int, ...]: - return sum((dim.expanded_shape for dim in self._dims), ()) + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def global_expanded_shape(self) -> tuple[int, ...]: - return sum((dim.global_expanded_shape for dim in self._dims), ()) + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global_partial(tensor, dim + i) + + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): + tensor = tensor_dim.global_to_local(tensor, dim + i) + return tensor if expand else tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def parallel_dim_index(self) -> int | None: - return self._parallel_dim_index + +class ConcatenatedTensorDim(TensorDim): + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = tensor_dims[0].parallel_dim + for dim, tensor_dim in enumerate(tensor_dims[1:]): + # TODO: Allow more flexibility? 
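# A similar sketch for `ConcatenatedTensorDim.local_to_global` (defined just below): the local
# tensor is split by each sub-dim's local size, every piece is gathered separately, and the
# gathered pieces are concatenated back together. Two fake ranks stand in for the process
# group; the sub-dim sizes are illustrative.
import torch

world_size, local_sizes = 2, [6, 2]       # local sizes of the two concatenated sub-dims
per_rank = [
    torch.cat([torch.full((size,), float(10 * rank + i)) for i, size in enumerate(local_sizes)])
    for rank in range(world_size)
]
pieces = [t.split(local_sizes) for t in per_rank]                           # split by sub-dim
global_flat = torch.cat([torch.cat([p[i] for p in pieces]) for i in range(len(local_sizes))])
assert global_flat.shape == (world_size * sum(local_sizes),)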
+ Assert.is_(tensor_dim.parallel_dim, parallel_dim) + + super().__init__( + name=name, + global_size=sum(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, + ) + self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim_index is not None - dims = list(self.dims) - dims[self.parallel_dim_index] = dims[self.parallel_dim_index].replace_parallel_dim(distributed_dim) - return CompositeTensorDim(self.name, tuple(dims)) + assert self.is_parallel + return ConcatenatedTensorDim( + self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims) + ) + + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + import torch + + return ( + torch.concatenate( + [ + tensor_dim.local_to_global(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + import torch + + return ( + torch.concatenate( + [ + tensor_dim.local_to_global_partial(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + if self.is_parallel and expand: + raise NotImplementedError() + import torch + + return ( + torch.concatenate( + [ + tensor_dim.global_to_local(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.global_size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) class DefaultDimNames: @@ -147,21 +249,19 @@ def distributed(self) -> "Distributed": assert self._is_setup return self._distributed - def add_tensor_dim(self, dim: TensorDim) -> None: - if isinstance(dim, CompositeTensorDim): - for dim_ in dim.dims: - Assert.incl(dim_.name, self._tensor_dims) - Assert.eq(dim_, self._tensor_dims[dim_.name]) - if dim.name in self._tensor_dims: - Assert.eq(dim, self._tensor_dims[dim.name]) + def add_tensor_dim(self, tensor_dim: TensorDim) -> None: + if tensor_dim.name in self._tensor_dims: + Assert.eq(tensor_dim, self._tensor_dims[tensor_dim.name]) else: - if dim.parallel_dim is not None: - assert dim.parallel_dim.name in self._distributed_config.distributed_dims, dim.parallel_dim.name + if tensor_dim.parallel_dim is not None: + assert ( + tensor_dim.parallel_dim.name in self._distributed_config.distributed_dims + ), tensor_dim.parallel_dim.name Assert.eq( - dim.parallel_dim.__dict__, - self._distributed_config.distributed_dims[dim.parallel_dim.name].__dict__, + tensor_dim.parallel_dim.__dict__, + self._distributed_config.distributed_dims[tensor_dim.parallel_dim.name].__dict__, ) - self._tensor_dims[dim.name] = dim + self._tensor_dims[tensor_dim.name] = tensor_dim - def get_tensor_dim(self, name: str) -> TensorDim: + def __getitem__(self, name: str) -> TensorDim: return self._tensor_dims[name] diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 5b44bf14b..be15cd37a 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ 
-441,39 +441,21 @@ def _get_parameter_shard_indices_in_full_weight( where it is located in the shard if it exists, or -1 if it's not in the shard. Used to determine the location of each entry in a different distributed configuration. """ - - # Create an empty index for the global parameter. - index = torch.full( - parameter_meta.global_shape, - -1, - dtype=torch.int64, - device=device, - ) # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard begin, end = self._get_parameter_range_in_shard(parameter_name) - buffer_index = parameter_meta.global_to_local(index, expand=True) - # Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible. - # In that case, we work with a separate tensor to be copied back into `buffer_index`. - try: - buffer_index_flat = buffer_index.view(-1) - is_view = True - except RuntimeError: - buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1) - is_view = False - - # Copy the shard indices at their respective positions in the flat buffer index. - buffer_index_flat[ + # Create an empty local index to hold the local shard indices. + buffer_index = torch.full_like(parameter_meta, -1, dtype=torch.int64, device=device) + + # Copy the shard indices at their respective positions in the buffer index. + buffer_index.flatten()[ self._index_buffer_to_param( self._fsdp_dim.rank * self._shard_size, parameter_name ) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name) ].copy_(torch.arange(begin, end, dtype=torch.int64, device=device)) - # If needed, copy the flat buffer index back into the index. - if not is_view: - buffer_index.copy_(buffer_index_flat.view_as(buffer_index)) - - return index + # Create a global index from the local one. + return parameter_meta.local_to_global_partial(buffer_index, -1) def copy_shard_overlaps( self, diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 9a8ce2092..3218a1963 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -185,8 +185,9 @@ def initialize_weights(self) -> None: # Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device) global_shape = meta.global_shape - if self._distributed_config.reproducible_init and ( - global_shape.numel() != parameter.numel() or not self._mode.on_device + if meta.requires_global_initialization or ( + self._distributed_config.reproducible_init + and (global_shape.numel() != parameter.numel() or not self._mode.on_device) ): # Initialize all global weights on every gpu, then select the appropriate slice if applicable. 
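# A plain-torch sketch of the "initialize globally, then select the slice" path that the
# comment above describes, assuming a 2-way split along dim 0; the seed, shapes and slicing
# rule are illustrative stand-ins, not the real parameter-meta machinery:
import torch

world_size, rank = 2, 0
torch.manual_seed(1234)                                    # identical seed on every rank
full_weight = torch.empty(8, 4)
torch.nn.init.normal_(full_weight)                         # every rank builds the full tensor
local_weight = full_weight.chunk(world_size, dim=0)[rank]  # keep only this rank's slice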
global_param = parameter.new_empty(global_shape, device=self._distributed.device) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 9f32ac689..07dadbc22 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -99,7 +99,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig): ) def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.tensor import init_uniform_ + from fast_llm.tensor import init_uniform_centered_ kwargs = { "hidden_dim": hidden_dim, @@ -110,9 +110,7 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> " } if self.initialization_range: mean = 0 if self.zero_centered else 1 - kwargs["weight_init_method"] = init_uniform_( - mean - self.initialization_range, mean + self.initialization_range - ) + kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) return self.module_class(**kwargs) @property diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index cd19a47a5..7249ef569 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -94,8 +94,8 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert in_dim.parallel_dim is None - assert out_dim.parallel_dim is None + assert not in_dim.is_parallel + assert not out_dim.is_parallel super().__init__( in_dim, out_dim, @@ -132,7 +132,7 @@ def __init__( sequence_parallel: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert in_dim.parallel_dim is None + assert not in_dim.is_parallel self._group_size = 1 if out_dim.parallel_dim is None else out_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( @@ -176,7 +176,7 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert out_dim.parallel_dim is None + assert not out_dim.is_parallel self._group_size = 1 if in_dim.parallel_dim is None else in_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 5f30beaef..bccc1d627 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -158,7 +158,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: @@ -242,7 +242,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft.py index 3a1966e51..08f3e535b 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft.py @@ -19,12 +19,12 @@ def lora_linear( ): layer.weight.requires_grad = False in_dim = layer._in_dim + assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." if in_dim.parallel_dim is not None: - assert in_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." 
in_dim = TensorDim(in_dim.name, in_dim.global_size) out_dim = layer._out_dim + assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." if out_dim.parallel_dim is not None: - assert out_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." out_dim = TensorDim(out_dim.name, out_dim.global_size) if out_channel_begin is not None or out_channel_end is not None: if out_channel_begin is None: diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 7036a1e97..f6f43d199 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -46,10 +46,10 @@ def __init__( self._dropout_p = config.transformer.hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - vocab_dim = tensor_space.get_tensor_dim( + hidden_dim = tensor_space[TransformerDimNames.hidden] + vocab_dim = tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] if self._parallel_embeddings: self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size @@ -66,7 +66,7 @@ def __init__( ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( - (tensor_space.get_tensor_dim(LanguageModelDimNames.position_embed), hidden_dim), + (tensor_space[LanguageModelDimNames.position_embed], hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 21bf3bbd0..210cad644 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -61,7 +61,7 @@ def __init__( if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] self._loss_coefficient = ( config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 @@ -108,9 +108,9 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: if self._tie_word_embeddings or self._prediction_distance > 0: return # untie embedding weights - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] self.output_weights = ParameterMeta.from_dims( (vocab_dim, hidden_dim), init_method=init_normal_( @@ -338,9 +338,9 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._logits_scale_factor, ) if self._debug_transformer and self._cross_entropy_splits is None: - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ) + ] dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim] sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) dims[sequence_index] = ( diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index d719bef3d..c8d53a789 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -28,7 +28,7 @@ def __init__( assert 
config.use_absolute_position_embeddings self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: @@ -76,7 +76,7 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: return diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 174e19588..c59b191af 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -91,14 +91,14 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels).size - self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size - self._local_head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).size - self._local_heads_per_group = self._tensor_space.get_tensor_dim(TransformerDimNames.group_heads).size + self._kv_channels = self._tensor_space[TransformerDimNames.kv_channels].size + self._head_groups = self._tensor_space[TransformerDimNames.head_groups].global_size + self._local_head_groups = self._tensor_space[TransformerDimNames.head_groups].size + self._local_heads_per_group = self._tensor_space[TransformerDimNames.group_heads].size self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) @@ -106,7 +106,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query), + self._tensor_space[TransformerDimNames.composite_query], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -115,7 +115,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value), + self._tensor_space[TransformerDimNames.composite_key_value], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -129,7 +129,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # Output. 
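Most hunks in this patch also swap `tensor_space.get_tensor_dim(name)` for subscript access. The `TensorSpace.__getitem__` it relies on is not shown here; presumably it is a thin alias, roughly:

```python
# Sketch (assumption): subscripting a TensorSpace returns the registered TensorDim by name,
# making `tensor_space[TransformerDimNames.hidden]` equivalent to the old get_tensor_dim call.
class TensorSpace:
    def __init__(self) -> None:
        self._tensor_dims: dict[str, object] = {}

    def add_tensor_dim(self, dim) -> None:
        self._tensor_dims[dim.name] = dim

    def get_tensor_dim(self, name: str):
        return self._tensor_dims[name]

    def __getitem__(self, name: str):
        # New spelling used throughout the diff; same lookup as get_tensor_dim.
        return self.get_tensor_dim(name)
```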
self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense), + self._tensor_space[TransformerDimNames.composite_dense], hidden_dim, bias=self._config.add_attn_dense_bias, weight_init_method=init_method_std_attn_proj, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 73f83ccf5..4fd2844d5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -63,8 +63,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( - tensor_space.get_tensor_dim(TransformerDimNames.hidden), - tensor_space.get_tensor_dim(TransformerDimNames.unshared_experts), + tensor_space[TransformerDimNames.hidden], + tensor_space[TransformerDimNames.unshared_experts], bias=False, weight_init_method=init_normal_( std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max @@ -255,7 +255,7 @@ def _debug_log( def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space.get_tensor_dim(dim_name),), + kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), tensor_name=f"{self._name} {name}", dtype=tensor.dtype, ) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index efe0c5cc5..101d97ef3 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -30,8 +30,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self._intermediate_dim = tensor_space.get_tensor_dim(TransformerDimNames.composite_expert_mlp) + hidden_dim = tensor_space[TransformerDimNames.hidden] + self._intermediate_dim = tensor_space[TransformerDimNames.composite_expert_mlp] self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._recompute_level = config.mlp_recompute_level @@ -46,7 +46,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), + tensor_space[TransformerDimNames.composite_gated_expert_mlp], bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index dc3ddeb52..3f0e14eb7 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -28,7 +28,7 @@ def __init__( self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert not self._config.do_use_flash_attention(self._distributed_config) - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py 
b/fast_llm/layers/transformer/rotary/preprocessing.py index cc83dae02..c357411b6 100644 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ b/fast_llm/layers/transformer/rotary/preprocessing.py @@ -25,8 +25,8 @@ def __init__( self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: self._create_tensors(kwargs[TransformerKwargs.sequence_length]) diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 056b9aa4c..17b18a1ca 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -82,8 +82,8 @@ def __init__( super().__init__(config, tensor_space) self._tensor_space = tensor_space if self._tensor_space is not None: - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index d08db9a94..75d06f268 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -48,7 +48,7 @@ def _get_meta( } return TensorMeta.from_dims( tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] for dim_name in dim_names ), tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", @@ -97,7 +97,7 @@ def __init__( self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index e7379e61e..20ed8e828 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -14,8 +14,8 @@ def get_init_megatron( meta: "ParameterMeta", config: TransformerConfig -) -> typing.Callable[["torch.Tensor", "Distributed"], "torch.Tensor"]: - def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): +) -> typing.Callable[["torch.Tensor", "Distributed"], None]: + def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: Assert.eq(distributed.config.world_size, 1) if "bias" in meta.tensor_name: # Generator unused. 
@@ -29,11 +29,11 @@ def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): elif config.num_experts > 1 and "mlp.layer_" in meta.tensor_name: tensor_ = _init_moe_mlp_megatron(config, meta, tensor, distributed) elif "mlp.layer_2" in meta.tensor_name: - tensor_ = _init_transposed_mlp_weight_megatron(config, meta, tensor, distributed) + tensor_ = _init_transposed_mlp_weight_megatron(meta, tensor, distributed) else: # Word embedding (override generator), layer norm (generator unused), other mlp weights. return meta.param_init_method(meta, tensor, distributed.tp_init_generator) - return tensor.copy_(tensor_.reshape_as(tensor)) + tensor.copy_(tensor_.reshape_as(tensor)) return init_megatron @@ -58,9 +58,9 @@ def _init_attention_megatron( generator = distributed.tp_init_generator state = generator.get_state() # Initialize a mock dense layer to advance the random state - dense_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + dense_tensor_ := tensor.new_empty( config.kv_channels * config.num_attention_heads, config.hidden_size, ), @@ -68,9 +68,9 @@ def _init_attention_megatron( ) # QKV is split differently. (Assuming no tensor-parallel.) heads_per_group = div(config.num_attention_heads, config.head_groups) - qkv_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + qkv_tensor_ := tensor.new_empty( config.head_groups, heads_per_group + 2, config.kv_channels, @@ -110,18 +110,19 @@ def _init_position_embeddings_megatron( # Megatron initializes the position embeddings on cpu twice. assert meta.param_init_method is not None generator = distributed.default_cpu_generator - tensor_ = meta.param_init_method(meta, torch.empty(tensor.shape, dtype=tensor.dtype), generator) - return meta.param_init_method(meta, tensor_, generator) + meta.param_init_method(meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), generator) + meta.param_init_method(meta, tensor_, generator) + return tensor_ def _init_transposed_mlp_weight_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": import torch # Megatron never transposes the mlp layer 2 weight. assert meta.param_init_method is not None - tensor_ = meta.param_init_method(meta, torch.empty_like(tensor), distributed.tp_init_generator) + meta.param_init_method(meta, tensor_ := torch.empty_like(tensor), distributed.tp_init_generator) return tensor_.view(meta.size(1), meta.size(0)).t() @@ -132,8 +133,8 @@ def _init_moe_router_megatron( # Megatron initializes the router on cpu. 
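These hunks change the Megatron-compatibility initializers from returning a tensor to filling one in place, with call sites binding the scratch buffer via `:=`. A small self-contained illustration of the pattern (generic names, not the Fast-LLM functions):

```python
# Illustration of the new calling convention: the init method mutates `tensor` in place and
# returns None, so the caller binds the buffer with a walrus expression and reuses it afterwards.
import torch


def param_init_method(meta, tensor: torch.Tensor, generator: torch.Generator) -> None:
    tensor.normal_(0.0, 0.02, generator=generator)  # in place, nothing returned


generator = torch.Generator().manual_seed(0)
param_init_method(None, qkv_tensor_ := torch.empty(8, 64), generator)
print(qkv_tensor_.shape)  # the walrus-bound buffer now holds the initialized values
```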
assert meta.param_init_method is not None - tensor_ = meta.param_init_method( - meta, torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator + meta.param_init_method( + meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator ) return tensor_ diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 4c1eab46f..49a5dcbd3 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -155,7 +155,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index d780e4d6d..b3795b740 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,17 +1,21 @@ +import abc import functools +import logging import math import typing import torch from fast_llm.core.distributed import ReduceOp -from fast_llm.core.ops import gather_op, reduce_op +from fast_llm.core.ops import reduce_op from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.functional.triton.pointwise import triton_add, triton_copy from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + class _SafeTensorSliceMeta(type): def __instancecheck__(self, instance) -> bool: @@ -146,7 +150,7 @@ def from_tensor_space( reductions: tuple[tuple[str, ReduceOp], ...] = (), **kwargs: typing.Any, ) -> typing.Self: - dims = tuple(tensor_space.get_tensor_dim(dim_name) for dim_name in dim_names) + dims = tuple(tensor_space[dim_name] for dim_name in dim_names) if reductions: # kwarg not available for ParameterMeta, so we only provide if necessary. kwargs["reductions"] = tuple( @@ -158,22 +162,23 @@ def from_tensor_space( def global_shape(self) -> torch.Size: return torch.Size([dim.global_size for dim in self.dims]) - def local_to_global( - self, - tensor: torch.Tensor, - *, - distributed: Distributed, - ) -> tuple[torch.Tensor, ...]: + def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> tuple[torch.Tensor, ...]: + """ + Reconstruct a global tensor from its distributed slices. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. + """ + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) # Tensors are always either split or duplicated in the tensor-parallel direction. 
# TODO: Avoid hard-coded assumptions on duplication - is_first_rank = distributed.config.tensor_rank == 0 - modified = False - for i, dim in enumerate(self.dims): - if dim.parallel_group is not None: - tensor = gather_op( - tensor.unflatten(i, dim.expanded_shape), dim.parallel_group, i + dim.parallel_dim_index - ).flatten(i, i + len(dim.expanded_shape) - 1) - is_first_rank, modified = is_first_rank and dim.parallel_group.rank() == 0, True + is_first_rank, modified = distributed.config.tensor_rank == 0, False + + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global(tensor, dim) + is_first_rank &= tensor_dim.parallel_dim.rank == 0 + modified = True for distributed_dim, op in self._reductions: if distributed_dim.group is not None: @@ -182,28 +187,48 @@ def local_to_global( tensor = tensor.clone() tensor = reduce_op(tensor, distributed_dim.group, op=op) is_first_rank, modified = is_first_rank and distributed_dim.group.rank() == 0, True + Assert.eq(tensor.shape, self.global_shape) return tensor, is_first_rank - def global_to_local( - self, - tensor: torch.Tensor | SafeTensorSlice, - # Return an expanded tensor, avoiding `flatten` which copies the data. - expand: bool = False, - ) -> torch.Tensor: + def local_to_global_partial(self, tensor: torch.Tensor, fill_value: float | int = -1) -> torch.Tensor: """ - Recover the tensor-parallel slice of a tensor. Support lazy-loaded safetensor slices. + Construct a tensor of shape `self.global_shape` that contains its local slice at the appropriate location, + i.e. for which `self.global_to_local(self.local_to_global_partial(tensor)) == tensor`. + Other entries are filled with `fill_value`. + Returns a view of the input tensor (or the input tensor itself) when possible. + """ + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) + assert not self._reductions + logger.info(f"AAAA {self.tensor_name} {self.shape} {self.global_shape} {tensor.shape}") + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global_partial(tensor, dim, fill_value) + logger.info( + f"BBBB {self.tensor_name} {self.shape} {self.global_shape} {tensor.shape} {tensor_dim.is_parallel}" + ) + + Assert.eq(tensor.shape, self.global_shape) + return tensor + + def global_to_local(self, tensor: torch.Tensor | SafeTensorSlice) -> torch.Tensor: + """ + Select the local slice of a global tensor. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. """ # Take a trivial slice to convert safetensor slices. 
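`TensorMeta.local_to_global` above now delegates the gather to each parallel `TensorDim` instead of calling `gather_op` directly (which is why that import was dropped). The per-dim helper is not part of this diff; a rough method sketch of the assumed behavior, ignoring composite/expanded dims:

```python
# Hypothetical TensorDim.local_to_global (assumption): gather the local slices along `dim`
# across the tensor-parallel group, mirroring the gather_op call removed from fast_llm/tensor.py.
def local_to_global(self, tensor, dim: int):
    if not self.is_parallel:
        return tensor
    from fast_llm.core.ops import gather_op  # same op the old implementation used

    return gather_op(tensor, self.parallel_dim.group, dim)
```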
- tensor_ = tensor[:] + tensor = tensor[:] assert not self._reductions + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.global_shape) - for i, dim in reversed(list(enumerate(self.dims))): - if dim.parallel_dim is not None and dim.parallel_dim.size > 1: - tensor_ = tensor_.unflatten(i, dim.global_expanded_shape).chunk( - dim.parallel_dim.size, i + dim.parallel_dim_index - )[dim.parallel_dim.rank] + for dim, tensor_dim in reversed(list(enumerate(self.dims))): + tensor = tensor_dim.global_to_local(tensor, dim) - return tensor_ if expand else tensor_.reshape(self.shape) + Assert.eq(tensor.shape, self.shape) + return tensor @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -237,7 +262,7 @@ def __init__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable[["ParameterMeta", torch.Tensor, torch.Generator], torch.Tensor] | None = None, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None" = None, weight_decay: bool = True, # Pass a list to split the parameter in contiguous (dim=0) chunks of equal size for optimization. lr_scale: float | None | tuple[float | None, ...] = None, @@ -247,7 +272,11 @@ def __init__( allow_no_grad: bool = False, ): super().__init__(data, tensor_name=tensor_name, dims=dims) - self.param_init_method = init_method + if init_method is not None and not isinstance(init_method, Initializer): + # Support non-wrapped callables for convenience. + assert callable(init_method) + init_method = LambdaInitializer(init_method) + self.param_init_method: Initializer | None = init_method self.param_weight_decay = weight_decay self._is_param = True self.param_grad_is_zero = False @@ -272,7 +301,7 @@ def __new__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None", weight_decay: bool = True, lr_scale: float | None | tuple[float | None, ...] 
= None, allow_sequence_tensor_parallel: bool = True, @@ -293,12 +322,20 @@ def __repr__(self, *, tensor_contents=()) -> str: def init_parameter(self, tensor: torch.Tensor, distributed: Distributed) -> None: assert self.param_init_method is not None - if distributed.config.tensor_parallel == 1 or distributed.config.reproducible_init: + if ( + distributed.config.tensor_parallel == 1 + or distributed.config.reproducible_init + or self.param_init_method.requires_global_initialization + ): generator = distributed.pp_init_generator else: generator = distributed.tp_init_generator if self.is_tensor_parallel else distributed.pp_init_generator self.param_init_method(self, tensor, generator) + @property + def requires_global_initialization(self) -> bool: + return self.param_init_method.requires_global_initialization + def save(self) -> dict[str, typing.Any]: return { "name": self.tensor_name, @@ -330,11 +367,32 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None: triton_add(grad, param.grad_buffer, out=param.grad_buffer) # noqa -def init_fill_(value) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.fill_(value) +class Initializer(abc.ABC): + @abc.abstractmethod + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + pass + + requires_global_initialization = False + + +class LambdaInitializer(Initializer): + def __init__( + self, + init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None], + requires_global_initialization: bool = False, + ) -> None: + self._init_method = init_method + self.requires_global_initialization = requires_global_initialization + + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + return self._init_method(meta, tensor, generator) + - return init_ +def init_fill_(value: float) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor.fill_(value) + + return LambdaInitializer(init_) init_zeros_ = init_fill_(0.0) @@ -342,30 +400,35 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_normal_( - mean=0.0, std=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.normal_(mean, std, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) - return init_ + return LambdaInitializer(init_) -def kaiming_init_(d_in): +def init_kaiming_(d_in: float) -> LambdaInitializer: return init_normal_(0.0, math.sqrt(2.0 / d_in)) def init_uniform_( - low=0.0, high=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + low: float = 0.0, high: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def 
init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.uniform_(low, high, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) + + return LambdaInitializer(init_) + - return init_ +def init_uniform_centered_(high: float, max_val: float | None = None, mean: float = 0.0) -> LambdaInitializer: + return init_uniform_( + mean - high, + mean + high, + min_val=None if max_val is None else mean - max_val, + max_val=None if max_val is None else mean + max_val, + ) From 017f5cc5a021d9a2ef58e5d1903f60c4917f311c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 18:09:53 -0400 Subject: [PATCH 25/82] fixes --- fast_llm/layers/ssm/discrete_mamba2.py | 24 ++++++++++----------- fast_llm/layers/ssm/mamba2.py | 26 +++++++++++----------- fast_llm/layers/ssm/mamba_layer.py | 30 ++++++++++++-------------- 3 files changed, 39 insertions(+), 41 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index c0ae7e781..6012f74a7 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -10,7 +10,7 @@ from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_, init_zeros_ from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) @@ -62,14 +62,14 @@ def __init__( mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.qk_heads) - td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.v_heads) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.inner_proj_discrete_mamba2) + td_inner = tensor_space[SSMDimNames.inner_dim] + td_state = tensor_space[SSMDimNames.state_dim] + td_model = tensor_space[SSMDimNames.model_dim] + td_conv = tensor_space[SSMDimNames.conv_dim] + td_n_qk_heads = tensor_space[SSMDimNames.qk_heads] + td_n_v_heads = tensor_space[SSMDimNames.v_heads] + td_conv_kernel = tensor_space[SSMDimNames.conv_kernel_size] + td_inner_proj = tensor_space[SSMDimNames.inner_proj_discrete_mamba2] self.d_model = td_model.size self.d_inner = td_inner.size @@ -88,7 +88,7 @@ def __init__( td_model, td_inner_proj, bias=bias, - weight_init_method=kaiming_init_(td_model.size), + weight_init_method=init_kaiming_(td_model.size), lr_scale=mamba_layer_lr_scale, ) self.z_bias = ( @@ -103,7 +103,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), + (td_conv, tensor_space[DefaultDimNames.scalar], td_conv_kernel), init_method=init_uniform_( 1 / math.sqrt(td_conv.size * 
td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 @@ -126,7 +126,7 @@ def __init__( td_inner, td_model, bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + weight_init_method=init_kaiming_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 74c212add..9dfad8462 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -10,7 +10,7 @@ from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ +from fast_llm.tensor import ParameterMeta, init_fill_, init_kaiming_, init_ones_, init_uniform_ from fast_llm.utils import get_lr_scale try: @@ -80,13 +80,13 @@ def __init__( self.config.mamba_lr_scale, layer_lr_scale ) - td_inner: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_dim) - td_state: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.state_dim) - td_model: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.model_dim) - tdt_rank: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) - td_xb: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.x_proj_dim_2) - td_inner_proj: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_proj_mamba2) - td_conv_kernel: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel_size) + td_inner: TensorDim = tensor_space[SSMDimNames.inner_dim] + td_state: TensorDim = tensor_space[SSMDimNames.state_dim] + td_model: TensorDim = tensor_space[SSMDimNames.model_dim] + tdt_rank: TensorDim = tensor_space[SSMDimNames.dt_rank] + td_xb: TensorDim = tensor_space[SSMDimNames.x_proj_dim_2] + td_inner_proj: TensorDim = tensor_space[SSMDimNames.inner_proj_mamba2] + td_conv_kernel: TensorDim = tensor_space[SSMDimNames.conv_kernel_size] self.repeat_kv_before_conv = config.repeat_kv_before_conv @@ -98,7 +98,7 @@ def __init__( if self.repeat_kv_before_conv: self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), + (td_inner, tensor_space[DefaultDimNames.scalar], td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_inner.size * td_conv_kernel.size), 1 / math.sqrt(td_inner.size * td_conv_kernel.size), @@ -111,7 +111,7 @@ def __init__( ) else: self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), + (td_xb, tensor_space[DefaultDimNames.scalar], td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_xb.size * td_conv_kernel.size), 1 / math.sqrt(td_xb.size * td_conv_kernel.size), @@ -131,14 +131,14 @@ def __init__( td_model, td_inner_proj, bias=bias, - weight_init_method=kaiming_init_(td_model.size), + weight_init_method=init_kaiming_(td_model.size), lr_scale=mamba_layer_lr_scale, ) self.dt_in_proj = Linear( td_model, tdt_rank, bias=config.add_bias_linear, - weight_init_method=kaiming_init_(transformer_config.hidden_size), + weight_init_method=init_kaiming_(transformer_config.hidden_size), lr_scale=mamba_layer_lr_scale, ) # Initialize special dt projection to preserve variance at initialization @@ -185,7 +185,7 @@ def __init__( td_inner, td_model, 
bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + weight_init_method=init_kaiming_(td_inner.size), ) def forward(self, hidden_states, kwargs): diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 4493332ce..5e0ae786e 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -10,7 +10,7 @@ from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ from fast_llm.utils import get_lr_scale try: @@ -75,15 +75,13 @@ def __init__( self.config: SSMConfig = config # Tensor dims: - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_inner_proj = tensor_space.get_tensor_dim( - SSMDimNames.inner_proj_mamba - ) # TensorDim("D_inner_2", self.d_inner * 2) - tdt_rank = tensor_space.get_tensor_dim(SSMDimNames.dt_rank) - td_x_proj = tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) + td_inner = tensor_space[SSMDimNames.inner_dim] + td_inner_proj = tensor_space[SSMDimNames.inner_proj_mamba] # TensorDim("D_inner_2", self.d_inner * 2) + tdt_rank = tensor_space[SSMDimNames.dt_rank] + td_x_proj = tensor_space[SSMDimNames.x_proj_dim] + td_state = tensor_space[SSMDimNames.state_dim] + td_model = tensor_space[SSMDimNames.model_dim] + td_conv_kernel = tensor_space[SSMDimNames.conv_kernel_size] self.d_conv = td_conv_kernel.size self.d_inner = td_inner.size self.d_state = td_state.size @@ -94,12 +92,12 @@ def __init__( self.in_proj_weight = ParameterMeta.from_dims( (td_inner_proj, td_model), - init_method=kaiming_init_(td_model.size), + init_method=init_kaiming_(td_model.size), ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), - init_method=kaiming_init_(td_inner.size), + (td_inner, tensor_space[DefaultDimNames.scalar], td_conv_kernel), + init_method=init_kaiming_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) @@ -111,7 +109,7 @@ def __init__( self.x_proj = Linear( td_inner, td_x_proj, - weight_init_method=kaiming_init_(td_inner.size), + weight_init_method=init_kaiming_(td_inner.size), bias=False, lr_scale=mamba_layer_lr_scale, ) @@ -120,7 +118,7 @@ def __init__( # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( (td_inner, tdt_rank), - init_method=kaiming_init_(tdt_rank.size), + init_method=init_kaiming_(tdt_rank.size), lr_scale=mamba_layer_lr_scale, ) @@ -151,7 +149,7 @@ def __init__( td_inner, td_model, bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. 
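A usage note for the renamed helpers in this hunk: per the `fast_llm/tensor.py` changes earlier in the series, `init_kaiming_(d_in)` returns a `LambdaInitializer` that fills a tensor in place with `normal(0, sqrt(2 / d_in))`. A self-contained check of that behavior, using simplified stand-ins rather than the Fast-LLM classes:

```python
# Simplified stand-ins to show how the returned initializer objects are used; the real
# LambdaInitializer and init_kaiming_ are defined in fast_llm/tensor.py in an earlier patch.
import math

import torch


class LambdaInitializer:
    def __init__(self, init_method):
        self._init_method = init_method

    def __call__(self, meta, tensor: torch.Tensor, generator: torch.Generator) -> None:
        self._init_method(meta, tensor, generator)


def init_kaiming_(d_in: float) -> LambdaInitializer:
    std = math.sqrt(2.0 / d_in)  # He initialization for fan-in d_in
    return LambdaInitializer(lambda meta, tensor, generator: tensor.normal_(0.0, std, generator=generator))


weight = torch.empty(512, 1024)
init_kaiming_(1024)(None, weight, torch.Generator().manual_seed(0))
print(round(weight.std().item(), 3))  # ~0.044 == sqrt(2 / 1024)
```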
- weight_init_method=kaiming_init_(td_model.size), + weight_init_method=init_kaiming_(td_model.size), lr_scale=mamba_layer_lr_scale, ) self.out_proj.weight.auto_grad_accumulation = True From c41efc21ae2f8c1a87d35834a28ae3ad852f22d4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 19:24:46 -0400 Subject: [PATCH 26/82] doc --- docs/contributing/testing.md | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/docs/contributing/testing.md b/docs/contributing/testing.md index 9cce78e3c..e04cf1b37 100644 --- a/docs/contributing/testing.md +++ b/docs/contributing/testing.md @@ -6,20 +6,22 @@ title: Writing and running tests ### Selecting tests -When debugging, it is often practical to target specific tests that will run quickly. While Pytest supports targeting specific directory, files or tests, the complex parameterization and dependencies of our tests often makes explicit targeting tedious and/or impractical. We provide several options for selecting tests: +When debugging, it is often advisable to target specific tests that can be executed efficiently. Although Pytest allows targeting specific tests or files, complex parameterization and dependencies in our suite often make explicit selection difficult. To address this, several options for test selection are available: -* `--skip-slow`: This will run a subset of "fast" tests that cover the majority of our codebase. This is useful for quickly checking that changes did not break Fast-LLM too badly before running the full test suite. Note that parallel testing (`-n`) is not needed (and may be counter-productive) with this argument. -* `--run-extra-slow`: Some tests are disabled by default because they take too long to run (ex. complex integration tests) and/or are not particularly important. This argument re-enables them. -* `--models MODEL0 MODEL1 ...`: This allows targeting one or more specific models from the model tests (see below), and is particularly useful when debugging a model. For example, `pytest tests/models/test_models/test_checkpoint.py -v -ra --models llama` will test checkpoints specifically for the llama model. (Note that `-n` may not be needed here as model tests for a given model are only partly distributed dure to dependency constraints.) +* `--skip-slow`: Executes a subset of expedited tests that encompass much of the codebase. This option is effective for quickly checking for major regressions prior to executing the comprehensive test suite. Please note, parallel testing (`-n`) is typically unnecessary—and may even be counterproductive—when using this argument. +* `--run-extra-slow`: Certain tests are disabled by default due to their lengthy execution times (e.g., complex integration tests) or limited criticality. Use this flag to re-enable them. +* `--models MODEL0 MODEL1 ...`: Enables targeting of one or more specific models within the model testing suite. This feature is particularly useful during model-specific debugging efforts. For instance, running `pytest tests/models/test_models/test_checkpoint.py -v -ra --models llama` will specifically test checkpointing functionality for the llama model. Note that parallelization (`-n`) may be unnecessary in this context, as model tests for a given model are only partially distributed due to dependency constraints. ### Monitoring distributed tests -`--no-distributed-capture` +Distributed tests are generally the slowest due to the overhead associated with starting processes and process groups. 
To mitigate this, Fast-LLM incorporates several bundled tests that execute multiple subtests within a single subprocess call. As bundled calls can generate substantial output and potentially reduce report readability, Fast-LLM captures the output from each subtest and forwards it to an associated test. If necessary, this output capture can be disabled using `--no-distributed-capture`—for instance, if a severe crash hinders output capture or to disable pytest capture entirely (`-s`). Captured logs are stored in the testing cache directory; please consult individual tests for specific locations. + +For example, `test_run_model_distributed[llama]` tries various distributed configurations for the `llama` model, each reported under an associated test such as `test_model_distributed[llama-distributed]`. Should a distributed subtest, say `tp2` (tensor-parallel), encounter a failure, `test_run_model_distributed` will log the issue, continue executing remaining subtests, and ultimately raise an error to designate the bundled test as failed. The associated test, `test_model_distributed[llama-tp2]`, will also fail and display the captured output (retrieved from `/tmp/fast_llm_tests/models/llama/tp2/`), separated by type (stdout, stderr and traceback) as would happen for a normal test (minus some advanced formating), but also by rank. ### Other options -* `--show-gpu-memory N`: Our testing suite monitors GPU memory usage and reports the highest users. Use this option to adjust the number of reported tests (10 by default). Note that this option is mainly intended to make sure tests don't use too much memory (which could cause crashes with lots of parallel tests) and may not be an accurate measurement. -* `--show-skipped`: Many tests skipped for obvious reasons (ex. marked as slow or extra slow, skipped model testing groups (see below)) are removed entirely from the report to reduce clutter. This option may be used to show them explicitly. +* `--show-gpu-memory N`: Monitors GPU memory use and reports the top N tests (default 10). Mainly helps ensure tests don't exceed memory limits, but results may not be precise. +* `--show-skipped`: Many tests skipped for obvious reasons (ex. marked as slow or extra slow, skipped model testing groups (see below)) are removed entirely from the report to reduce clutter. Use this flag to display them. ## Best practices @@ -29,15 +31,6 @@ When debugging, it is often practical to target specific tests that will run qui For each tested model, we run a series of tests divided into several groups. Much of these tests consist of running a short Fast-LLM training run, then comparing intermediate tensors (ex. parameter initialization, layer outputs and gradients, parameter gradients) against a baseline. -### What is being tested - -Coming soon. - -!!! warning "Don't forget about unit tests!" - - While adding a model is a quick and efficient way to increase coverage, it is **not a replacement for unit tests**. - The model testing suite performs intensive consistency checks, but does little to make sure those results are correct to begin with. See [functional tests](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/functional) and [test_lm_head](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/layers/test_lm_head.py) for good examples of unit tests for individual components and an entire layer. 
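Illustrative invocations combining the options documented above (examples only; the flags are the ones described in this section and in the contributing guide):

```bash
# Quick sanity pass over the fast subset; parallelism (-n) is not needed with --skip-slow.
pytest tests/ -v -ra --skip-slow

# Target a single model and keep distributed subtest output uncaptured for easier debugging.
pytest tests/models -v -ra --models llama --no-distributed-capture

# Re-enable the tests that are disabled by default because they are extra slow.
pytest tests/ -v -ra -n 10 --run-extra-slow
```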
- ### Adding a model When adding support for a new model that comes with additional features, the simplest option to increase coverage is to add an example configuration to the [tested modelsl](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/utils/model_configs.py). @@ -71,6 +64,11 @@ _update_and_add_testing_config( ) ``` +!!! warning "Don't forget about unit tests!" + + While adding a model is a quick and efficient way to increase coverage, it is **not a replacement for unit tests**. + The model testing suite performs intensive consistency checks, but does little to make sure those results are correct to begin with. See [functional tests](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/functional) and [test_lm_head](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/layers/test_lm_head.py) for good examples of unit tests for individual components and an entire layer. + #### Reference for groups Fast-LLM currently supports the following testing groups: From 0b8bd5dc7a09d73adc2fe08a1aa2924052bd01b5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 19:38:26 -0400 Subject: [PATCH 27/82] cleanup --- docs/contributing/contributing.md | 4 ++-- docs/contributing/testing.md | 25 +------------------------ mkdocs.yaml | 1 - setup.cfg | 2 +- 4 files changed, 4 insertions(+), 28 deletions(-) diff --git a/docs/contributing/contributing.md b/docs/contributing/contributing.md index 938fe925f..6185b63fe 100644 --- a/docs/contributing/contributing.md +++ b/docs/contributing/contributing.md @@ -40,7 +40,7 @@ Before diving into code, [open an issue](https://github.com/ServiceNow/Fast-LLM/ Here are some tips to ensure your pull request gets reviewed and merged promptly: - **Follow our coding standards**: Stick to our [style guide and conventions](https://servicenow.github.io/Fast-LLM/developers/style-guide) to keep the code clean and consistent. -- **Write tests**: Verify your changes with unit tests for new features or bug fixes. See our [testing guide](https://servicenow.github.io/Fast-LLM/contributing/testing) for tips and recommendations on testing. +- **Write tests**: Verify your changes with unit tests for new features or bug fixes. - **Test on GPUs and real-world workloads**: Since Fast-LLM is all about training large language models, make sure your changes work smoothly in GPU environments and on typical training setups. - **Run benchmarks and performance tests**: Make sure your changes don't slow things down. If there's any impact on performance, provide benchmark results to back it up. - **Avoid introducing new issues**: Check that there are no new runtime warnings, type checker errors, linting problems, or unhandled edge cases. @@ -48,7 +48,7 @@ Here are some tips to ensure your pull request gets reviewed and merged promptly - **Keep sensitive data out**: Make sure your code or commit messages don't expose private or proprietary information. - **Use a clear and descriptive title**: The PR title should summarize the key change or feature introduced. Avoid vague titles like "Fix bug" or "Update code." Start with a keyword like `[feat]`, `[fix]`, `[docs]`, etc. to categorize the change. Reference the issue number if applicable (e.g., `[fix] resolve #123 memory leak in training loop`). This title will become the commit message for the squashed merge. - **Use the [PR template](https://github.com/ServiceNow/Fast-LLM/blob/main/.github/PULL_REQUEST_TEMPLATE.md)**: Complete the checklist to make sure everything is in order before hitting submit. 
-- **Make sure all tests pass before merging**: Run the tests with `pytest tests/ -v -ra -n 10`, and fix any failure before merging. If possible, please run the test in an environment with at least 4 GPUs. See our [testing guide](https://servicenow.github.io/Fast-LLM/contributing/testing) for more details on testing and debugging. +- **Make sure all tests pass before merging**: Run the tests with `pytest tests/ -v -ra -n 10`, and fix any failure before merging. If possible, please run the test in an environment with at least 4 GPUs. ## 🆘 Seeking Help or Clarification diff --git a/docs/contributing/testing.md b/docs/contributing/testing.md index e04cf1b37..8df93f9d0 100644 --- a/docs/contributing/testing.md +++ b/docs/contributing/testing.md @@ -1,30 +1,7 @@ --- -title: Writing and running tests +title: Writing tests --- -## Debugging with tests - -### Selecting tests - -When debugging, it is often advisable to target specific tests that can be executed efficiently. Although Pytest allows targeting specific tests or files, complex parameterization and dependencies in our suite often make explicit selection difficult. To address this, several options for test selection are available: - -* `--skip-slow`: Executes a subset of expedited tests that encompass much of the codebase. This option is effective for quickly checking for major regressions prior to executing the comprehensive test suite. Please note, parallel testing (`-n`) is typically unnecessary—and may even be counterproductive—when using this argument. -* `--run-extra-slow`: Certain tests are disabled by default due to their lengthy execution times (e.g., complex integration tests) or limited criticality. Use this flag to re-enable them. -* `--models MODEL0 MODEL1 ...`: Enables targeting of one or more specific models within the model testing suite. This feature is particularly useful during model-specific debugging efforts. For instance, running `pytest tests/models/test_models/test_checkpoint.py -v -ra --models llama` will specifically test checkpointing functionality for the llama model. Note that parallelization (`-n`) may be unnecessary in this context, as model tests for a given model are only partially distributed due to dependency constraints. - -### Monitoring distributed tests - -Distributed tests are generally the slowest due to the overhead associated with starting processes and process groups. To mitigate this, Fast-LLM incorporates several bundled tests that execute multiple subtests within a single subprocess call. As bundled calls can generate substantial output and potentially reduce report readability, Fast-LLM captures the output from each subtest and forwards it to an associated test. If necessary, this output capture can be disabled using `--no-distributed-capture`—for instance, if a severe crash hinders output capture or to disable pytest capture entirely (`-s`). Captured logs are stored in the testing cache directory; please consult individual tests for specific locations. - -For example, `test_run_model_distributed[llama]` tries various distributed configurations for the `llama` model, each reported under an associated test such as `test_model_distributed[llama-distributed]`. Should a distributed subtest, say `tp2` (tensor-parallel), encounter a failure, `test_run_model_distributed` will log the issue, continue executing remaining subtests, and ultimately raise an error to designate the bundled test as failed. 
The associated test, `test_model_distributed[llama-tp2]`, will also fail and display the captured output (retrieved from `/tmp/fast_llm_tests/models/llama/tp2/`), separated by type (stdout, stderr and traceback) as would happen for a normal test (minus some advanced formating), but also by rank. - -### Other options - -* `--show-gpu-memory N`: Monitors GPU memory use and reports the top N tests (default 10). Mainly helps ensure tests don't exceed memory limits, but results may not be precise. -* `--show-skipped`: Many tests skipped for obvious reasons (ex. marked as slow or extra slow, skipped model testing groups (see below)) are removed entirely from the report to reduce clutter. Use this flag to display them. - -## Best practices - ## Testing models [Model integration tests](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/models) are the most important part of our testing suite, ensuring that Fast-LLM works and yields consistent results for a variety of models, training configurations, optimizations, etc. diff --git a/mkdocs.yaml b/mkdocs.yaml index 00e52a011..85fd4bff0 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -189,6 +189,5 @@ nav: - Contribution Guide: contributing/contributing.md - Style Guide: contributing/style-guide.md - Development Practices: contributing/dev-practices.md - - Testing: contributing/testing.md - About Us: about-us.md - Join Us: join-us.md diff --git a/setup.cfg b/setup.cfg index dc6d0c445..843aa15ca 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,7 +48,7 @@ HUGGINGFACE = # Required to run SSMs # To install on cpu environment (ex. for IDE support): -# MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[SSM]" --no-build-isolation +# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 cartesia_pytorch>=0.0.2 From 02f8af5e5ce9189ded97a83ed5c90b84d18a5ec3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 29 Jul 2025 13:39:47 -0400 Subject: [PATCH 28/82] Block interface --- fast_llm/layers/block/__init__.py | 0 .../transformer.py => block/block.py} | 99 +--- fast_llm/layers/block/config.py | 120 ++++ fast_llm/layers/block/mixer.py | 68 +++ fast_llm/layers/block/mlp/__init__.py | 0 fast_llm/layers/block/mlp/config.py | 171 ++++++ .../mlp}/mixture_of_experts.py | 46 +- .../layers/{transformer => block/mlp}/mlp.py | 32 +- fast_llm/layers/block/peft.py | 128 +++++ fast_llm/layers/common/config.py | 12 - fast_llm/layers/language_model/config.py | 6 +- fast_llm/layers/language_model/embedding.py | 5 +- fast_llm/layers/language_model/head.py | 37 +- .../layers/language_model/preprocessing.py | 21 +- .../layers/ssm/{llamba_block.py => block.py} | 24 +- fast_llm/layers/ssm/config.py | 14 +- fast_llm/layers/ssm/discrete_mamba2.py | 29 +- fast_llm/layers/ssm/mamba2.py | 30 +- fast_llm/layers/ssm/mamba_layer.py | 16 +- fast_llm/layers/transformer/attention.py | 10 +- fast_llm/layers/transformer/block.py | 23 + fast_llm/layers/transformer/config.py | 542 +++--------------- fast_llm/models/gpt/conversion.py | 3 +- fast_llm/models/gpt/model.py | 14 +- tests/test_mlp.py | 14 +- 25 files changed, 749 insertions(+), 715 deletions(-) create mode 100644 fast_llm/layers/block/__init__.py rename fast_llm/layers/{transformer/transformer.py => block/block.py} (60%) create mode 100644 fast_llm/layers/block/config.py create mode 100644 fast_llm/layers/block/mixer.py create 
mode 100644 fast_llm/layers/block/mlp/__init__.py create mode 100644 fast_llm/layers/block/mlp/config.py rename fast_llm/layers/{transformer => block/mlp}/mixture_of_experts.py (89%) rename fast_llm/layers/{transformer => block/mlp}/mlp.py (77%) create mode 100644 fast_llm/layers/block/peft.py rename fast_llm/layers/ssm/{llamba_block.py => block.py} (52%) create mode 100644 fast_llm/layers/transformer/block.py diff --git a/fast_llm/layers/block/__init__.py b/fast_llm/layers/block/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/block/block.py similarity index 60% rename from fast_llm/layers/transformer/transformer.py rename to fast_llm/layers/block/block.py index 75d06f268..85da61c01 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/block/block.py @@ -1,83 +1,22 @@ import abc -import logging import typing import torch +from fast_llm.config import Configurable from fast_llm.core.distributed import set_generator from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs -from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP -from fast_llm.layers.transformer.mlp import MLP +from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.block.mixer import Mixer +from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP +from fast_llm.layers.block.mlp.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert -logger = logging.getLogger(__name__) - -class Mixer(torch.nn.Module, abc.ABC): - """ - Base class for mixer modules. - """ - - _mixer_name: typing.ClassVar[str] - - def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): - super().__init__() - self._tensor_space = tensor_space - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._block_index = block_index - self._debug_level = debug_level - - @abc.abstractmethod - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Mixer module forward. Returns the output hidden states and an optional bias, - in case its addition can be made more efficient in `_bias_dropout_add`. 
- """ - - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = { - dim.name: dim - for dim in kwargs[TransformerKwargs.hidden_dims] + (kwargs[TransformerKwargs.sequence_q_dim],) - } - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] - for dim_name in dim_names - ), - tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_level, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - - -class BaseBlock(Layer, abc.ABC): +class Block[ConfigType: BlockConfig](Layer, Configurable[ConfigType]): """ A transformer-like decoder base block with abstract mixer. """ @@ -85,11 +24,9 @@ class BaseBlock(Layer, abc.ABC): # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "mixer" - def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False - ): + def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): super().__init__() - self._config: TransformerConfig = config + self._config = config self._tensor_space: TensorSpace = tensor_space self._dropout_p: float = self._config.hidden_dropout # For multi-token prediction, return a stack of shared_hidden and transformer_output. 
@@ -97,7 +34,7 @@ def __init__( self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space[TransformerDimNames.hidden] + hidden_dim = self._tensor_space[BlockDimNames.hidden] # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) @@ -131,7 +68,7 @@ def name(self) -> str: return f"{self._name} {self._block_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): - dims = kwargs[TransformerKwargs.hidden_dims] + dims = kwargs[BlockKwargs.hidden_dims] if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) @@ -196,19 +133,3 @@ def forward( if self._return_input: hidden_states = torch.stack((fw_input, hidden_states), dim=0) return hidden_states - - -class TransformerBlock(BaseBlock): - _name = "Transformer layer" - # TODO: Standardize to `mixer` - _mixer_module_name: typing.ClassVar[str] = "self_attn" - - def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False - ): - super().__init__(config, tensor_space, block_index, return_input) - - def _create_mixer(self) -> Mixer: - from fast_llm.layers.transformer.attention import Attention - - return Attention(self._config, self._tensor_space, self._block_index) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py new file mode 100644 index 000000000..2f26d8d79 --- /dev/null +++ b/fast_llm/layers/block/config.py @@ -0,0 +1,120 @@ +import enum + +from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.block.mlp.config import MLPConfig +from fast_llm.layers.block.peft_config import TransformerPeftConfig +from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.utils import Assert + + +class BlockDimNames: + # A set of common tensor dim names packed into a namespace. + # Input dimensions (variable) + # TODO: Does batch belong here? + batch = "batch" + # TODO: Distinguish micro-sequence? + sequence_q = "sequence_q" + sequence_q_tp = "sequence_q_tp" + sequence_k = "sequence_k" + hidden = "hidden" + + +class BlockKwargs: + sequence_first = "sequence_first" + hidden_dims = "hidden_dims" + sequence_q_dim = "sequence_q_dim" + sequence_k_dim = "sequence_k_dim" + sequence_length = "sequence_length" + # TODO: Belongs elsewhere? 
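The `TransformerBlock` removed just above is listed as moving to the new `fast_llm/layers/transformer/block.py` (see the create/rename list for this patch); its new form is not shown in this excerpt. Presumably it becomes a thin subclass of the generic `Block`, roughly:

```python
# Sketch (assumption): the relocated TransformerBlock under the new Block[ConfigType] base,
# mirroring the class body removed above.
from fast_llm.layers.block.block import Block
from fast_llm.layers.block.mixer import Mixer
from fast_llm.layers.transformer.config import TransformerConfig


class TransformerBlock(Block[TransformerConfig]):
    _name = "Transformer layer"
    # TODO: Standardize to `mixer`
    _mixer_module_name = "self_attn"

    def _create_mixer(self) -> Mixer:
        from fast_llm.layers.transformer.attention import Attention

        return Attention(self._config, self._tensor_space, self._block_index)
```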
+ grad_output = "grad_output" + + +@config_class() +# TODO: Use composition for MLP config +class BlockConfig(MLPConfig, BaseModelConfig): + + # TODO: Review names + normalization: NormalizationConfig = Field( + desc="Configuration for the normalization layers architecture.", + hint=FieldHint.architecture, + ) + peft: TransformerPeftConfig = Field( + desc="Configuration for the parameter-efficient fine tuning.", + hint=FieldHint.architecture, + ) + hidden_dropout: float = Field( + default=0.0, + desc="Dropout applied to the residual connections.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + full_precision_residual: bool = Field( + default=False, + desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", + hint=FieldHint.stability, + ) + debug_transformer: int = Field( + default=0, + desc="Log the output of each operation in a transformer layer.", + hint=FieldHint.logging, + valid=check_field(Assert.geq, 0), + ) + debug_transformer_memory: bool = Field( + default=False, + desc="Log the memory usage after each operation in a transformer layer..", + hint=FieldHint.logging, + ) + add_linear_biases: bool | AddLinearBiasChoices = Field( + default=True, + desc="Add biases to all, none or Q, K, V layers. Accepted values: True, False, or AddLinearBiasChoices.", + hint=FieldHint.architecture, + ) + + # TODO: Move these, not specific to a single block. + num_layers: int = Field( + default=12, + desc="Number of layers in the transformer.", + hint=FieldHint.architecture, + valid=check_field(Assert.geq, 0), + ) + hidden_size: int = Field( + default=1024, + desc="Size of the transformer's main hidden dimension, e.g., for its input and output layers.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + per_layer_lr_scale: list[float] | None = Field( + default=None, + desc="Custom learning rate scale for each layer.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + ) + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.ffn_hidden_size is None: + self.ffn_hidden_size = 4 * self.hidden_size + + super()._validate() + + @property + def add_mlp_bias(self) -> bool: + if isinstance(self.add_linear_biases, bool): + return self.add_linear_biases + if self.add_linear_biases == AddLinearBiasChoices.everywhere: + return True + return False + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + super().setup_tensor_space(tensor_space) + + # Hidden dimension + tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.hidden_size)) + + +class AddLinearBiasChoices(str, enum.Enum): + nowhere = "nowhere" + everywhere = "everywhere" + only_attn_qkv = "only_attn_qkv" diff --git a/fast_llm/layers/block/mixer.py b/fast_llm/layers/block/mixer.py new file mode 100644 index 000000000..5c811e330 --- /dev/null +++ b/fast_llm/layers/block/mixer.py @@ -0,0 +1,68 @@ +import abc +import typing + +import torch + +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.logging import log_distributed_grad, log_distributed_tensor +from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert + + +class Mixer(torch.nn.Module, abc.ABC): + """ + Base class for mixer modules. 
+ """ + + _mixer_name: typing.ClassVar[str] + + def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): + super().__init__() + self._tensor_space = tensor_space + self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel + self._block_index = block_index + self._debug_level = debug_level + + @abc.abstractmethod + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Mixer module forward. Returns the output hidden states and an optional bias, + in case its addition can be made more efficient in `_bias_dropout_add`. + """ + + def _get_meta( + self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] + for dim_name in dim_names + ), + tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", + dtype=input_.dtype, + ) + + def _debug_log( + self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> None: + # TODO: Local vs global + Assert.gt(self._debug_level, 0) + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) diff --git a/fast_llm/layers/block/mlp/__init__.py b/fast_llm/layers/block/mlp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py new file mode 100644 index 000000000..63e31219a --- /dev/null +++ b/fast_llm/layers/block/mlp/config.py @@ -0,0 +1,171 @@ +import enum + +from fast_llm.config import Config, Field, FieldHint, check_field, skip_valid_if_none +from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace +from fast_llm.engine.distributed.config import DistributedDimNames +from fast_llm.functional.config import ActivationType, MLPRecomputeLevel +from fast_llm.utils import Assert + + +class MLPDimNames: + # MLP dimensions + mlp = "mlp" + gate_and_up = "gate_and_up" + composite_gated_mlp = "composite_gated_mlp" + experts = "experts" + top_experts = "top_experts" + shared_experts = "shared_experts" + unshared_experts = "unshared_experts" + composite_expert_mlp = "composite_expert_mlp" + composite_gated_expert_mlp = "composite_gated_expert_mlp" + composite_shared_expert_mlp = "composite_shared_expert_mlp" + composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" + + +class MLPLossNames: + load_balancing_loss = "load_balancing_loss" + router_z_loss = "router_z_loss" + + +class RoutingType(str, enum.Enum): + topk = "aux_loss" + sinkhorn = "sinkhorn" + + +class MLPConfig(Config): + # TODO: Review names + _abstract = False + ffn_hidden_size: int = Field( + default=None, + desc="Hidden dimension of the MLP intermediate state. 
Default: 4 * hidden_size.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + num_experts: int = Field( + default=1, + desc="Number of MLP experts in a Mixture of Expert (MoE) model", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + num_shared_experts: int = Field( + default=0, + desc="Number of MLP experts that are shared between all tokens, i.e., always enabled.", + hint=FieldHint.architecture, + valid=check_field(Assert.geq, 0), + ) + num_unshared_experts: int = Field( + init=False, + desc="Number of MLP experts excluding shared ones", + hint=FieldHint.architecture, + valid=check_field(Assert.geq, 0), + ) + num_experts_per_token: int = Field( + default=1, + desc="Active experts for each token in a MoE model.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + expert_routing_type: RoutingType = Field( + default=RoutingType.topk, + desc="The routing method, i.e., the method used to assign experts to tokens.", + hint=FieldHint.architecture, + ) + gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.architecture) + # Default: hidden_size**-0.5 + # TODO: Allow custom initialization (InitializationConfig?) + activation_type: ActivationType = Field( + default=None, + desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", + hint=FieldHint.core, + ) + # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto + mlp_recompute_level: MLPRecomputeLevel = Field( + default=MLPRecomputeLevel.none, + desc="Set which of the MLP intermediate activations will be recomputed during the backward passes. This provides a trade-off between memory and speed.", + hint=FieldHint.performance, + ) + expert_auxiliary_loss_coefficient: float = Field( + default=0.01, + desc="Scale of the load balancing auxiliary loss for topk routing.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + expert_z_loss_coefficient: float = Field( + default=0.0, + desc="Regularize the router during training by applying Z-loss to the logits.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + moe_jitter_eps: float = Field( + default=0.0, + desc="Regularize the router during training by applying a random multiplicative noise `uniform(1-eps, 1+eps)` to the logits.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + mlp_lr_scale: float | None | list[float | None] = Field( + default=None, + desc="Custom learning rate scale for each expert.", + doc="May be used to freeze some experts by setting their scale to zero.", + hint=FieldHint.feature, + ) + router_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate for the MoE router weight.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + dropless_moe: bool = Field( + default=True, desc="Evaluate all the experts at once using dropless MoE.", hint=FieldHint.expert + ) + dropless_dynamic_shape: bool = Field( + default=False, + desc="Use a dynamic shape for dropless MLP instead of the worst-case value." + " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. 
Not recommended.", + hint=FieldHint.expert, + ) + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.activation_type is None: + self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + self.num_unshared_experts = self.num_experts - self.num_shared_experts + + super()._validate() + + Assert.leq(self.num_shared_experts, self.num_experts) + Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) + + if isinstance(self.mlp_lr_scale, list): + Assert.eq(len(self.mlp_lr_scale), self.num_experts) + for scale in self.mlp_lr_scale: + if scale is not None: + Assert.geq(scale, 0) + elif self.mlp_lr_scale is not None: + Assert.geq(self.mlp_lr_scale, 0) + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + # MLP dimensions + tensor_space.add_tensor_dim(mlp := TensorDim(MLPDimNames.mlp, self.ffn_hidden_size, tensor)) + tensor_space.add_tensor_dim(gate_and_up := TensorDim(MLPDimNames.gate_and_up, 2 if self.gated else 1)) + tensor_space.add_tensor_dim(CompositeTensorDim(MLPDimNames.composite_gated_mlp, (gate_and_up, mlp))) + tensor_space.add_tensor_dim(experts := TensorDim(MLPDimNames.experts, self.num_experts)) + tensor_space.add_tensor_dim(CompositeTensorDim(MLPDimNames.composite_expert_mlp, (experts, mlp))) + tensor_space.add_tensor_dim( + CompositeTensorDim(MLPDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + ) + tensor_space.add_tensor_dim(TensorDim(MLPDimNames.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(MLPDimNames.unshared_experts, self.num_unshared_experts)) + + # shared_experts + if self.num_shared_experts: + tensor_space.add_tensor_dim( + shared_experts := TensorDim(MLPDimNames.shared_experts, self.num_shared_experts) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(MLPDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(MLPDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp)) + ) diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py similarity index 89% rename from fast_llm/layers/transformer/mixture_of_experts.py rename to fast_llm/layers/block/mlp/mixture_of_experts.py index 4fd2844d5..8d092b6dc 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -9,16 +9,11 @@ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map +from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.block.mlp.config import MLPDimNames, MLPLossNames, RoutingType +from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear -from fast_llm.layers.transformer.config import ( - RoutingType, - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - TransformerLossNames, -) -from fast_llm.layers.transformer.mlp import MLPBase from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta, init_normal_ from fast_llm.utils import Assert, get_lr_scale @@ -26,7 +21,7 @@ logger = 
logging.getLogger(__name__) -class MixtureOfExpertMLP(MLPBase): +class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): """ MoeLayer following implementation from https://github.com/NVIDIA/Megatron-LM/blob/46ebc0e4202c980d98900000d455f754a7ff9d4b/megatron/model/transformer.py#L346 @@ -40,12 +35,11 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." super().__init__(config, tensor_space, name, block_index) - self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory @@ -63,8 +57,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( - tensor_space[TransformerDimNames.hidden], - tensor_space[TransformerDimNames.unshared_experts], + tensor_space[BlockDimNames.hidden], + tensor_space[MLPDimNames.unshared_experts], bias=False, weight_init_method=init_normal_( std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max @@ -86,7 +80,7 @@ def forward( hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) if self._debug_mode: - self._debug_log(logits, "Router logits", TransformerDimNames.experts, kwargs) + self._debug_log(logits, "Router logits", MLPDimNames.experts, kwargs) # Apply z_loss if applicable if self._z_loss_factor > 0.0: @@ -96,7 +90,7 @@ def forward( self.training, grad_scale=kwargs.get("grad_output"), losses=losses, - loss_name=TransformerLossNames.router_z_loss, + loss_name=MLPLossNames.router_z_loss, ) # Apply input_jitter if applicable: @@ -106,7 +100,7 @@ def forward( # Routing if self._routing_type == RoutingType.topk: - scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses) + scores, top_experts = self._topk_routing(logits, kwargs.get(BlockKwargs.grad_output), losses) if self._num_shared_experts > 0: scores, top_experts = self._add_shared_experts(top_experts, scores) elif self._routing_type == RoutingType.sinkhorn: @@ -116,8 +110,8 @@ def forward( if self._debug_mode: # To log all ranks set `global_=False` - self._debug_log(scores, "Router scores", TransformerDimNames.top_experts, kwargs) - self._debug_log(top_experts, "Router top experts", TransformerDimNames.top_experts, kwargs) + self._debug_log(scores, "Router scores", MLPDimNames.top_experts, kwargs) + self._debug_log(top_experts, "Router top experts", MLPDimNames.top_experts, kwargs) return self._mlp_forward(hidden_states, scores, top_experts).view_as(input_), None # noqa @@ -135,12 +129,12 @@ def _forward_dropless( None, self.layer_2.weight, None, - gated=self._gated, - activation_type=self._activation_type, + gated=self._config.gated, + activation_type=self._config.activation_type, group=self._intermediate_dim.parallel_group, sequence_parallel=self._sequence_parallel, training=self.training, - recompute_level=self._recompute_level, + recompute_level=self._config.mlp_recompute_level, transposed_layer_2_weight=True, sparse_map=sparse_map, ) @@ -155,12 +149,12 @@ def _forward_looped( self.layer_1.weight, self.layer_2.weight, self._num_experts, - 
self._gated, - self._activation_type, + self._config.gated, + self._config.activation_type, self._intermediate_dim.parallel_group, self._sequence_parallel, self.training, - self._recompute_level, + self._config.mlp_recompute_level, ) @torch.compile @@ -185,7 +179,7 @@ def _topk_routing( probs.flatten(0, -2).mean(dim=0) * mask.flatten(0, -2).mean(dim=0, dtype=torch.float32) ) if losses is not None: - losses[TransformerLossNames.load_balancing_loss].append(aux_loss.detach()) + losses[MLPLossNames.load_balancing_loss].append(aux_loss.detach()) if self.training and grad_scale is not None: scores = AuxiliaryLoss.apply( scores, aux_loss, self._num_unshared_experts * self._load_balancing_factor * grad_scale @@ -255,7 +249,7 @@ def _debug_log( def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), + kwargs[BlockKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), tensor_name=f"{self._name} {name}", dtype=tensor.dtype, ) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/block/mlp/mlp.py similarity index 77% rename from fast_llm/layers/transformer/mlp.py rename to fast_llm/layers/block/mlp/mlp.py index 101d97ef3..04b8506a4 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -1,21 +1,23 @@ import typing -from abc import ABC import torch +from fast_llm.config import Configurable from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd +from fast_llm.layers.block.config import BlockConfig, BlockDimNames +from fast_llm.layers.block.mlp.config import MLPDimNames +from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerName from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import Assert, get_lr_scale -class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): - super().__init__() +class MLPBase[ConfigType: BlockConfig](Layer, Configurable[ConfigType]): + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): + super().__init__(config) self._name = name self._block_index = block_index @@ -30,13 +32,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space[TransformerDimNames.hidden] - self._intermediate_dim = tensor_space[TransformerDimNames.composite_expert_mlp] + hidden_dim = tensor_space[BlockDimNames.hidden] + self._intermediate_dim = tensor_space[MLPDimNames.composite_expert_mlp] self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel - self._recompute_level = config.mlp_recompute_level - - self._gated = config.gated - self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None @@ -46,7 +44,7 @@ def __init__(self, config: 
TransformerConfig, tensor_space: TensorSpace, name: s # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space[TransformerDimNames.composite_gated_expert_mlp], + tensor_space[MLPDimNames.composite_gated_expert_mlp], bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, @@ -68,8 +66,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self.layer_2 = config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) -class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): +class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.eq(config.num_experts, 1) super().__init__(config, tensor_space, name, block_index) @@ -89,12 +87,12 @@ def forward( self.layer_1.bias, self.layer_2.weight, None if parallel_group else self.layer_2.bias, - gated=self._gated, - activation_type=self._activation_type, + gated=self._config.gated, + activation_type=self._config.activation_type, group=parallel_group, sequence_parallel=self._sequence_parallel, training=self.training, - recompute_level=self._recompute_level, + recompute_level=self._config.mlp_recompute_level, transposed_layer_2_weight=self.layer_2.transposed_weight, ), self.layer_2.bias if parallel_group else None, diff --git a/fast_llm/layers/block/peft.py b/fast_llm/layers/block/peft.py new file mode 100644 index 000000000..269ed0aac --- /dev/null +++ b/fast_llm/layers/block/peft.py @@ -0,0 +1,128 @@ +""" +TODO: Generalize beyond transformers. +""" + +import abc +import enum +import typing + +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.layers.common.config import LoRAConfig, NoPeftConfig, PeftConfig +from fast_llm.tensor import ParameterMeta +from fast_llm.utils import div + +if typing.TYPE_CHECKING: + import torch + + from fast_llm.layers.common.linear import LinearBase, LinearLike + + +class TransformerSubLayerName(str, enum.Enum): + # TODO: Use this to replace AddLinearBiasChoices. + query = "query" + key = "key" + value_ = "value" + key_value = "key_value" + dense = "dense" + mlp_1 = "mlp_1" + mlp_2 = "mlp_2" + + +@config_class(registry=True) +class TransformerPeftConfig(PeftConfig): + @abc.abstractmethod + def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": + pass + + @abc.abstractmethod + def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": + pass + + @abc.abstractmethod + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + pass + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is TransformerPeftConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. 
+ return TransformerNoPeftConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class(dynamic_type={TransformerPeftConfig: "none"}) +class TransformerNoPeftConfig(NoPeftConfig, TransformerPeftConfig): + _abstract = False + + def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": + return super().apply_linear(linear) + + def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": + return module + + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + return parameter + + +@config_class(dynamic_type={TransformerPeftConfig: "lora"}) +class TransformerLoRAConfig(LoRAConfig, TransformerPeftConfig): + layers: list[TransformerSubLayerName] = Field( + default=(TransformerSubLayerName.query, TransformerSubLayerName.value_), + desc="The layers on which to apply LoRA.", + hint=FieldHint.feature, + ) + freeze_others: bool = Field( + default=True, + desc="Whether to freeze other layers during training.", + ) + + def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": + if layer_type is None or self.layers is None or layer_type in self.layers: + if layer_type == TransformerSubLayerName.key: + return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) + elif layer_type == TransformerSubLayerName.value_: + return super().apply_linear(linear, out_channel_begin=div(linear._out_dim.global_size, 2)) + else: + return super().apply_linear(linear) + elif self.freeze_others: + linear.weight.requires_grad = False + return linear + + def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": + if self.freeze_others: + for parameter in module.parameters(): + parameter.requires_grad = False + return module + + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + if self.freeze_others: + parameter.requires_grad = False + return parameter + + def _validate(self) -> None: + super()._validate() + if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: + # TODO: Add MLP support. + raise NotImplementedError("LoRA not supported for MLP.") + if TransformerSubLayerName.dense in self.layers: + # TODO: Support InputParallelLinear (different output format). + raise NotImplementedError("LoRA not supported for attention dense layer.") + if ( + sum( + name in self.layers + for name in ( + TransformerSubLayerName.key_value, + TransformerSubLayerName.key, + TransformerSubLayerName.value_, + ) + ) + > 1 + ): + raise ValueError( + f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive." + ) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 07dadbc22..9d5ce3f3b 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -14,18 +14,6 @@ from fast_llm.layers.common.normalization import LayerNorm, RMSNorm -@config_class() -class LLMBlockConfig(BaseModelConfig): - _abstract = False - - per_layer_lr_scale: list[float] | None = Field( - default=None, - desc="Custom learning rate scale for each layer.", - doc="May be used to freeze some layers by setting their scale to zero.", - hint=FieldHint.feature, - ) - - class NormalizationImplementation(str, enum.Enum): """ An enum for the available implementations of layer norm. 
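A note on `TransformerLoRAConfig.apply_linear` above: when LoRA targets only the key or only the value side of a fused key-value projection, it restricts the adapter to one half of the output channels via `out_channel_begin` / `out_channel_end`, with key rows occupying the first half and value rows the second. A rough standalone sketch of that slicing follows; the dimensions and the `split_fused_kv_rows` helper are illustrative assumptions, not part of the patch.

import torch

def split_fused_kv_rows(kv_weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # A fused key-value projection stores key rows first and value rows second,
    # so the midpoint of the output dimension separates the two halves. LoRA
    # limited to the value side would only adapt the second half of these rows.
    out_features = kv_weight.shape[0]
    assert out_features % 2 == 0
    half = out_features // 2
    return kv_weight[:half], kv_weight[half:]

if __name__ == "__main__":
    hidden_size, head_size, head_groups = 32, 8, 2
    kv = torch.randn(2 * head_groups * head_size, hidden_size)  # fused K and V weight
    k_rows, v_rows = split_fused_kv_rows(kv)
    # `out_channel_begin=out_features // 2` in the config above expresses the same split.
    assert k_rows.shape[0] == v_rows.shape[0] == head_groups * head_size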
diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 8e2e97f1a..b667e5318 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,12 +5,13 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl +from fast_llm.layers.block.config import BlockDimNames, BlockKwargs from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.rotary.config import NoRotaryConfig from fast_llm.utils import Assert -class LanguageModelDimNames: +class LanguageModelDimNames(BlockDimNames): # Embedding dimensions position_embed = "position_embed" vocab = "vocab" @@ -33,7 +34,7 @@ def multi_token_prediction_loss(index: int) -> str: return f"language_model_loss_{index}" -class LanguageModelKwargs: +class LanguageModelKwargs(BlockKwargs): position_ids = "position_ids" # TODO: These are generic labels = "labels" @@ -46,6 +47,7 @@ class LanguageModelKwargs: @config_class() class LanguageModelBaseConfig(BaseModelConfig): + # TODO: block transformer: TransformerConfig = Field( desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index f6f43d199..05678a700 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -8,7 +8,6 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ from fast_llm.utils import Assert @@ -46,7 +45,7 @@ def __init__( self._dropout_p = config.transformer.hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - hidden_dim = tensor_space[TransformerDimNames.hidden] + hidden_dim = tensor_space[LanguageModelDimNames.hidden] vocab_dim = tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab ] @@ -129,7 +128,7 @@ def forward( ) -> torch.Tensor: if isinstance(input_, TensorMeta): return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims], + kwargs[LanguageModelKwargs.hidden_dims], tensor_name="Embedding output", dtype=self._residual_dtype, ) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 210cad644..bc672725c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -23,7 +23,6 @@ LanguageModelLossNames, ) from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs from fast_llm.logging import log_distributed_tensor from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ from fast_llm.utils import Assert, div, get_unique @@ -61,7 +60,7 @@ def __init__( if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings - hidden_dim = self._tensor_space[TransformerDimNames.hidden] + hidden_dim = 
self._tensor_space[LanguageModelDimNames.hidden] self._loss_coefficient = ( config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 @@ -168,20 +167,22 @@ def _forward_backward( if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]: # The last hidden layer output is returned normalized in the HF Transformers-style output, at least for LLama style models. # So, if needed, we gather the data after normalization and set it as the output of the previous layer. - dims = list(kwargs[TransformerKwargs.hidden_dims]) - sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) + dims = list(kwargs[LanguageModelKwargs.hidden_dims]) + sequence_index = 1 - int(kwargs[LanguageModelKwargs.sequence_first]) dims[sequence_index] = ( TensorDim( - TransformerDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor + LanguageModelDimNames.sequence_q_tp, + dims[sequence_index].global_size, + DistributedDimNames.tensor, ) if self._sequence_parallel_logits - else TensorDim(TransformerDimNames.sequence_q, dims[sequence_index].global_size) + else TensorDim(LanguageModelDimNames.sequence_q, dims[sequence_index].global_size) ) meta = TensorMeta.from_dims(tuple(dims), tensor_name="transformer hidden_state", dtype=ln_output.dtype) hidden_state, _ = meta.local_to_global(ln_output.detach(), distributed=self._tensor_space.distributed) kwargs["hidden_states"][len(kwargs["hidden_states"]) - 1]["tensor"] = hidden_state - grad_output = kwargs[TransformerKwargs.grad_output] / ( + grad_output = kwargs[LanguageModelKwargs.grad_output] / ( self._group_size if self._sequence_parallel_logits else 1 ) @@ -221,18 +222,18 @@ def _get_targets( if lm_target is not None: # MTP: Shift the labels lm_target_sequence_length = ( - lm_target.size(1 - kwargs[TransformerKwargs.sequence_first]) + lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._config.prediction_heads ) - if TransformerKwargs.sequence_q_dim in kwargs: - Assert.eq(lm_target_sequence_length, kwargs[TransformerKwargs.sequence_q_dim].size) + if LanguageModelKwargs.sequence_q_dim in kwargs: + Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) lm_target_slice = slice( self._prediction_distance, self._prediction_distance + lm_target_sequence_length ) lm_target = ( lm_target[lm_target_slice] - if kwargs[TransformerKwargs.sequence_first] + if kwargs[LanguageModelKwargs.sequence_first] else lm_target[:, lm_target_slice] ).flatten() else: @@ -341,23 +342,23 @@ def _logits_cross_entropy_forward_backward( vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp ] - dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim] - sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) + dims = [*kwargs[LanguageModelKwargs.hidden_dims][:-1], vocab_dim] + sequence_index = 1 - int(kwargs[LanguageModelKwargs.sequence_first]) dims[sequence_index] = ( TensorDim( - TransformerDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor + LanguageModelDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor ) if self._sequence_parallel_logits - else TensorDim(TransformerDimNames.sequence_q, dims[sequence_index].global_size) + else TensorDim(LanguageModelDimNames.sequence_q, dims[sequence_index].global_size) ) dim_names = ( - [TransformerDimNames.sequence_q_tp, LanguageModelDimNames.vocab] + 
[LanguageModelDimNames.sequence_q_tp, LanguageModelDimNames.vocab] if self._sequence_parallel_logits - else [TransformerDimNames.sequence_q, LanguageModelDimNames.vocab_tp] + else [LanguageModelDimNames.sequence_q, LanguageModelDimNames.vocab_tp] ) - dim_names.insert(int(kwargs[TransformerKwargs.sequence_first]), TransformerDimNames.batch) + dim_names.insert(int(kwargs[LanguageModelKwargs.sequence_first]), LanguageModelDimNames.batch) log_distributed_tensor( "", logits, diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index c8d53a789..f5d915855 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -6,7 +6,6 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs -from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -41,29 +40,29 @@ def _create_tensors(self, sequence_length: int) -> None: ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths)) is not None: + self._create_tensors(kwargs[LanguageModelKwargs.sequence_length]) + sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size + sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size + if (sequence_lengths := kwargs.get(LanguageModelKwargs.sequence_lengths)) is not None: position_ids = torch.stack( [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] ).to(self._tensor_space.distributed.device, dtype=torch.int64) position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[LanguageModelKwargs.sequence_first]: position_ids = position_ids.transpose(0, 1) kwargs[LanguageModelKwargs.position_ids] = position_ids else: kwargs[LanguageModelKwargs.position_ids] = self._position_ids[ sequence_k - sequence_q : sequence_k - ].unsqueeze(int(kwargs[TransformerKwargs.sequence_first])) + ].unsqueeze(int(kwargs[LanguageModelKwargs.sequence_first])) def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: # Position embeddings will be broadcast. 
- sequence_q_dim = kwargs[TransformerKwargs.sequence_q_dim] + sequence_q_dim = kwargs[LanguageModelKwargs.sequence_q_dim] kwargs[LanguageModelKwargs.position_ids] = TensorMeta.from_dims( ( (sequence_q_dim, self._scalar_dim) - if kwargs[TransformerKwargs.sequence_first] + if kwargs[LanguageModelKwargs.sequence_first] else (self._scalar_dim, sequence_q_dim) ), tensor_name=LanguageModelKwargs.position_ids, @@ -82,8 +81,8 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: return def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size + sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels if LanguageModelKwargs.chosen_spans not in kwargs or LanguageModelKwargs.rejected_spans not in kwargs: diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/block.py similarity index 52% rename from fast_llm/layers/ssm/llamba_block.py rename to fast_llm/layers/ssm/block.py index 986606634..4854900a3 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/block.py @@ -1,14 +1,12 @@ -import typing +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.block.block import Block +from fast_llm.layers.block.config import BlockConfig +from fast_llm.layers.block.mixer import Mixer +from fast_llm.layers.ssm.config import SSMConfig -from fast_llm.layers.transformer.transformer import BaseBlock, Mixer -if typing.TYPE_CHECKING: - from fast_llm.engine.config_utils.tensor_space import TensorSpace - from fast_llm.layers.ssm.config import SSMConfig - from fast_llm.layers.transformer.config import TransformerConfig - - -class SSMBlock(BaseBlock): +# TODO: Sort out configs. 
+class SSMBlock[ConfigType: BlockConfig](Block[BlockConfig]): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ @@ -17,16 +15,16 @@ class SSMBlock(BaseBlock): def __init__( self, - transformer_config: "TransformerConfig", - ssm_config: "SSMConfig", - tensor_space: "TensorSpace", + config: BlockConfig, + ssm_config: SSMConfig, + tensor_space: TensorSpace, mixer_cls: type[Mixer], block_index: int, return_input: bool = False, ): self._ssm_config = ssm_config self._mixer_cls = mixer_cls - super().__init__(transformer_config, tensor_space, block_index, return_input) + super().__init__(config, tensor_space, block_index, return_input) def _create_mixer(self) -> Mixer: return self._mixer_cls( diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 9b0949d55..efcf2d873 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,18 +1,18 @@ import enum import typing -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig +from fast_llm.layers.block.config import BlockDimNames from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: from fast_llm.tensor import Initializer -class SSMDimNames: +class SSMDimNames(BlockDimNames): # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. 
state = "ssm_state" # State dimension (N), aka head size / num channels head_dim = "ssm_head_dim" @@ -72,15 +72,9 @@ def get_init_method(self, scale: float) -> "Initializer": @config_class() -class SSMConfig(LLMBlockConfig): +class SSMConfig(Config): _abstract = False - # Normalization - normalization: NormalizationConfig = Field( - desc="Configuration for the normalization layers architecture.", - hint=FieldHint.architecture, - ) - # Model dimensions # TODO: Remove (redundant default) expansion_factor: int = Field( diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index c9d555de9..550c44d0f 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -6,10 +6,10 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs -from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.utils import get_lr_scale @@ -42,15 +42,15 @@ def __init__( config: SSMConfig, block_index: int, tensor_space: TensorSpace, - transformer_config: TransformerConfig, + block_config: BlockConfig, ): - super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) + super().__init__(tensor_space, block_index, debug_level=block_config.debug_transformer) self._config: SSMConfig = config - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] - hidden_dim = tensor_space[TransformerDimNames.hidden] + hidden_dim = tensor_space[SSMDimNames.hidden] conv1d_dim = tensor_space[SSMDimNames.concatenated_convolution] heads_dim = tensor_space[SSMDimNames.composite_heads] @@ -69,7 +69,7 @@ def __init__( hidden_dim, tensor_space[SSMDimNames.concatenated_inner_projection], bias=config.add_bias_linear, - weight_init_method=init_kaiming_(transformer_config.hidden_size), + weight_init_method=init_kaiming_(block_config.hidden_size), sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -113,12 +113,12 @@ def __init__( def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - sequence_length = kwargs[TransformerKwargs.sequence_q_dim].global_size + sequence_length = kwargs[BlockKwargs.sequence_q_dim].global_size # Pad input to nearest multiple of chunklen padded_length = (1 + (sequence_length - 1) // self._config.chunk_size) * self._config.chunk_size if padded_length != sequence_length: - assert not kwargs[TransformerKwargs.sequence_first] and input_.size(1) == sequence_length + assert not kwargs[BlockKwargs.sequence_first] and input_.size(1) == sequence_length input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) # inner_projection : (batch/local_or_padded_sequence, 
local_sequence/batch, hidden) @@ -126,10 +126,9 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # inner_projection: (batch, local_or_padded_sequence, hidden) -> (batch, padded_sequence, local_inner_size) inner_projection = self.in_proj(input_) # Standardize to (batch, padded_sequence, inner_projection) - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[BlockKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) - print("QAIKOFNMJOWENM inner_projection", inner_projection.shape) xBC, z, A_log = torch.split( inner_projection, [ @@ -139,10 +138,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ ], dim=-1, ) - print("QAIKOFNMJOWENM xBC", xBC.shape, self._local_inner_size, self._local_bc_size) - print("QAIKOFNMJOWENM z", z.shape) - print("QAIKOFNMJOWENM A_log", A_log.shape) - # Convolutional layer # xbc: (batch, padded_sequence, local_heads * head_size + 2 * local_head_groups * state) xBC = self.convolutional_forward(xBC, padded_length) @@ -183,14 +178,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # y: (batch, padded_sequence, local_heads, head_size) -> (batch, sequence, local_heads * head_size) y = ((y + Du).flatten(2, 3) * torch.nn.functional.silu(z))[:, :sequence_length] - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[BlockKwargs.sequence_first]: # TODO: Is contiguous needed? y = y.transpose(0, 1).contiguous() # out_proj: (batch/sequence, sequence/batch, local_heads * head_size) # -> (batch/local_sequence, local_sequence/batch, hidden) a, b = self.out_proj(y) - logger.info(f"EKFBN y {y.shape}") - logger.info(f"EKFBN a {a.shape}") return self.out_proj(y) @torch.compile diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 77c1b3869..712c420ee 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -5,11 +5,11 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.config import BlockConfig +from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs -from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ from fast_llm.utils import Assert, div, get_lr_scale @@ -38,15 +38,15 @@ class Mamba2(Mixer): _mixer_name: typing.ClassVar[str] = "mamba_2" _XZ_DIMS = ( - TransformerDimNames.batch, + SSMDimNames.batch, SSMDimNames.composite_heads_and_head_dim, - TransformerDimNames.sequence_q, + SSMDimNames.sequence_q, ) _BC_DIMS = ( - TransformerDimNames.batch, + SSMDimNames.batch, SSMDimNames.composite_heads, SSMDimNames.state, - TransformerDimNames.sequence_q, + SSMDimNames.sequence_q, ) def __init__( @@ -54,17 +54,19 @@ def __init__( config: SSMConfig, tensor_space: TensorSpace, block_index: int, - transformer_config: TransformerConfig, + block_config: BlockConfig, ): - super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) + super().__init__(tensor_space, block_index, debug_level=block_config.debug_transformer) self._config: SSMConfig = 
config Assert.eq(self._config.activation_type, ActivationType.silu) - layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + layer_lr_scale: float | None = ( + block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None + ) lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim] xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state] - hidden_dim: TensorDim = tensor_space[TransformerDimNames.hidden] + hidden_dim: TensorDim = tensor_space[SSMDimNames.hidden] dt_rank_dim = tensor_space[SSMDimNames.dt_rank] self._local_heads = tensor_space[SSMDimNames.composite_heads].size @@ -92,7 +94,7 @@ def __init__( hidden_dim, tensor_space[SSMDimNames.concatenated_inner_projection], bias=config.add_bias_linear, - weight_init_method=init_kaiming_(transformer_config.hidden_size), + weight_init_method=init_kaiming_(block_config.hidden_size), sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -101,7 +103,7 @@ def __init__( hidden_dim, dt_rank_dim, bias=config.add_bias_linear, - weight_init_method=init_kaiming_(transformer_config.hidden_size), + weight_init_method=init_kaiming_(block_config.hidden_size), lr_scale=lr_scale, ) self.dt_proj = OutputParallelLinear( @@ -151,7 +153,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ inner_projection = self.in_proj(input_) dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias # Standardize to (batch, sequence, inner_projection) - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[BlockKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) dt = dt.transpose(0, 1) @@ -220,7 +222,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) y = y.transpose(1, 2)[:, :sequence_length] - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[BlockKwargs.sequence_first]: # TODO: Is contiguous needed? 
y = y.transpose(0, 1).contiguous() # (batch/sequence, sequence/batch, local_heads * state) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 9343ef1b8..f5b0139cf 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -6,10 +6,10 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs -from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ from fast_llm.utils import Assert, get_lr_scale @@ -60,9 +60,9 @@ def __init__( config: SSMConfig, block_index: int, tensor_space: TensorSpace, - transformer_config: TransformerConfig, + block_config: BlockConfig, ): - super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) + super().__init__(tensor_space, block_index, debug_level=block_config.debug_transformer) assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" self._config = config # TODO: It's not silu? @@ -70,8 +70,8 @@ def __init__( # Tensor dims: inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] - hidden_dim = tensor_space[TransformerDimNames.hidden] - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + hidden_dim = tensor_space[SSMDimNames.hidden] + layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) # TODO: Backward compatibility? 
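The mixers in this patch all derive their learning-rate scale by combining an optional per-layer scale (`block_config.per_layer_lr_scale[block_index]`) with a mixer-specific scale such as `mamba_lr_scale`. Below is a rough sketch of that combination; `combine_lr_scales` is an illustrative stand-in for `fast_llm.utils.get_lr_scale`, whose exact semantics may differ.

def combine_lr_scales(*scales: float | None) -> float | None:
    # Multiply every scale that is set; return None when nothing overrides the default.
    result: float | None = None
    for scale in scales:
        if scale is not None:
            result = scale if result is None else result * scale
    return result

if __name__ == "__main__":
    # E.g. a frozen second layer and a halved third layer, combined with mamba_lr_scale=0.5.
    per_layer_lr_scale = [1.0, 0.0, 0.5]
    block_index = 2
    layer_lr_scale = per_layer_lr_scale[block_index] if per_layer_lr_scale else None
    assert combine_lr_scales(0.5, layer_lr_scale) == 0.25
    assert combine_lr_scales(None, None) is None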
@@ -141,7 +141,7 @@ def __init__( def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - in_proj = self.in_proj(input_).permute((1, 2, 0) if kwargs[TransformerKwargs.sequence_first] else (0, 2, 1)) + in_proj = self.in_proj(input_).permute((1, 2, 0) if kwargs[BlockKwargs.sequence_first] else (0, 2, 1)) # In the backward pass we write dx and dz next to each other to avoid torch.cat # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s @@ -160,6 +160,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[BlockKwargs.sequence_first]: out = out.transpose(0, 1) return out, None diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index c59b191af..b1de792e3 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -6,14 +6,10 @@ from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.autograd import wrap_forward_backward +from fast_llm.layers.block.mixer import Mixer +from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - TransformerSubLayerName, -) -from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import get_lr_scale diff --git a/fast_llm/layers/transformer/block.py b/fast_llm/layers/transformer/block.py new file mode 100644 index 000000000..4a0e818f0 --- /dev/null +++ b/fast_llm/layers/transformer/block.py @@ -0,0 +1,23 @@ +import logging +import typing + +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.block.block import Block +from fast_llm.layers.block.mixer import Mixer +from fast_llm.layers.transformer.attention import Attention +from fast_llm.layers.transformer.config import TransformerConfig + +logger = logging.getLogger(__name__) + + +class TransformerBlock[ConfigType: TransformerConfig](Block[ConfigType]): + _name = "Transformer layer" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "self_attn" + _config: TransformerConfig + + def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): + super().__init__(config, tensor_space, block_index, return_input) + + def _create_mixer(self) -> Mixer: + return Attention(self._config, self._tensor_space, self._block_index) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index f6eaf5890..1c10753a8 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -1,44 +1,25 @@ -import abc -import enum import functools import logging -import math import typing import warnings -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config 
import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.functional.config import ActivationType, MLPRecomputeLevel, TritonConfig -from fast_llm.layers.common.config import LLMBlockConfig, LoRAConfig, NoPeftConfig, NormalizationConfig, PeftConfig +from fast_llm.functional.config import TritonConfig +from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockDimNames, BlockKwargs from fast_llm.layers.transformer.rotary.config import RotaryConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: - import torch - - from fast_llm.layers.common.linear import LinearBase, LinearLike - from fast_llm.tensor import ParameterMeta + pass logger = logging.getLogger(__name__) -class RoutingType(str, enum.Enum): - topk = "aux_loss" - sinkhorn = "sinkhorn" - - -class TransformerDimNames: +class TransformerDimNames(BlockDimNames): # A set of common tensor dim names packed into a namespace. - # Input dimensions (variable) - # TODO: Does batch belong here? - batch = "batch" - # TODO: Distinguish micro-sequence? - sequence_q = "sequence_q" - sequence_q_tp = "sequence_q_tp" - sequence_k = "sequence_k" - hidden = "hidden" # Self-attention dimensions head_groups = "head_groups" group_heads = "group_heads" @@ -48,21 +29,9 @@ class TransformerDimNames: composite_query = "composite_query" composite_key_value = "composite_key_value" composite_dense = "composite_dense" - # MLP dimensions - mlp = "mlp" - gate_and_up = "gate_and_up" - composite_gated_mlp = "composite_gated_mlp" - experts = "experts" - top_experts = "top_experts" - shared_experts = "shared_experts" - unshared_experts = "unshared_experts" - composite_expert_mlp = "composite_expert_mlp" - composite_gated_expert_mlp = "composite_gated_expert_mlp" - composite_shared_expert_mlp = "composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" -class TransformerKwargs: +class TransformerKwargs(BlockKwargs): rotary_freq_q = "rotary_freq_q" rotary_freq_k = "rotary_freq_k" attention_mask = "attention_mask" @@ -75,164 +44,17 @@ class TransformerKwargs: # TODO: Review these presents = "presents" past_key_values = "past_key_values" - sequence_first = "sequence_first" - hidden_dims = "hidden_dims" - sequence_q_dim = "sequence_q_dim" - sequence_k_dim = "sequence_k_dim" - sequence_length = "sequence_length" - # TODO: Move - grad_output = "grad_output" - - -class TransformerLossNames: - load_balancing_loss = "load_balancing_loss" - router_z_loss = "router_z_loss" - - -class AddLinearBiasChoices(str, enum.Enum): - nowhere = "nowhere" - everywhere = "everywhere" - only_attn_qkv = "only_attn_qkv" - - -class TransformerSubLayerName(str, enum.Enum): - # TODO: Use this to replace AddLinearBiasChoices. 
- query = "query" - key = "key" - value_ = "value" - key_value = "key_value" - dense = "dense" - mlp_1 = "mlp_1" - mlp_2 = "mlp_2" -@config_class(registry=True) -class TransformerPeftConfig(PeftConfig): - @abc.abstractmethod - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - pass - - @abc.abstractmethod - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - pass - - @abc.abstractmethod - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - pass - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is TransformerPeftConfig and cls.get_subclass(default.get("type")) is None: - # Default subclass. - return TransformerNoPeftConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - - -@config_class(dynamic_type={TransformerPeftConfig: "none"}) -class TransformerNoPeftConfig(NoPeftConfig, TransformerPeftConfig): +class AttentionConfig(Config): + # TODO: Make mixer class dynamic. _abstract = False - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - return super().apply_linear(linear) - - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - return module - - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - return parameter - - -@config_class(dynamic_type={TransformerPeftConfig: "lora"}) -class TransformerLoRAConfig(LoRAConfig, TransformerPeftConfig): - layers: list[TransformerSubLayerName] = Field( - default=(TransformerSubLayerName.query, TransformerSubLayerName.value_), - desc="The layers on which to apply LoRA.", - hint=FieldHint.feature, - ) - freeze_others: bool = Field( - default=True, - desc="Whether to freeze other layers during training.", - ) - - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - if layer_type is None or self.layers is None or layer_type in self.layers: - if layer_type == TransformerSubLayerName.key: - return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) - elif layer_type == TransformerSubLayerName.value_: - return super().apply_linear(linear, out_channel_begin=div(linear._out_dim.global_size, 2)) - else: - return super().apply_linear(linear) - elif self.freeze_others: - linear.weight.requires_grad = False - return linear - - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - if self.freeze_others: - for parameter in module.parameters(): - parameter.requires_grad = False - return module - - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - if self.freeze_others: - parameter.requires_grad = False - return parameter - - def _validate(self) -> None: - super()._validate() - if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: - # TODO: Add MLP support. - raise NotImplementedError("LoRA not supported for MLP.") - if TransformerSubLayerName.dense in self.layers: - # TODO: Support InputParallelLinear (different output format). 
- raise NotImplementedError("LoRA not supported for attention dense layer.") - if ( - sum( - name in self.layers - for name in ( - TransformerSubLayerName.key_value, - TransformerSubLayerName.key, - TransformerSubLayerName.value_, - ) - ) - > 1 - ): - raise ValueError( - f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive." - ) - - -@config_class() -class TransformerConfig(LLMBlockConfig): - _abstract = False - normalization: NormalizationConfig = Field( - desc="Configuration for the normalization layers architecture.", - hint=FieldHint.architecture, - ) + # TODO: Review names rotary: RotaryConfig = Field( desc="Configuration for the rotary positional embeddings.", hint=FieldHint.architecture, ) - peft: TransformerPeftConfig = Field( - desc="Configuration for the parameter-efficient fine tuning.", - hint=FieldHint.architecture, - ) - num_layers: int = Field( - default=12, - desc="Number of layers in the transformer.", - hint=FieldHint.architecture, - valid=check_field(Assert.geq, 0), - ) - hidden_size: int = Field( - default=1024, - desc="Size of the transformer's main hidden dimension, e.g., for its input and output layers.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) num_attention_heads: int = Field(default=8, desc="Number of attention heads.", hint=FieldHint.architecture) head_groups: int = Field( default=1, @@ -241,60 +63,104 @@ class TransformerConfig(LLMBlockConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - add_linear_biases: bool | AddLinearBiasChoices = Field( - default=True, - desc="Add biases to all, none or Q, K, V layers. Accepted values: True, False, or AddLinearBiasChoices.", - hint=FieldHint.architecture, - ) - ffn_hidden_size: int = Field( - default=None, - desc="Hidden dimension of the MLP intermediate state. Default: 4 * hidden_size.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) kv_channels: int = Field( default=None, desc="Number of key and value channels, i.e., hidden dimension of each attention head. Default: hidden_size // num_attention_heads", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.architecture) - num_experts: int = Field( - default=1, - desc="Number of MLP experts in a Mixture of Expert (MoE) model", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - num_shared_experts: int = Field( - default=0, - desc="Number of MLP experts that are shared between all tokens, i.e., always enabled.", - hint=FieldHint.architecture, + attention_dropout: float = Field( + default=0.0, + desc="Dropout applied to the attention intermediate states.", + hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - num_unshared_experts: int = Field( - init=False, - desc="Number of MLP experts excluding shared ones", - hint=FieldHint.architecture, - valid=check_field(Assert.geq, 0), + # Use flash attention if possible (fp16 or bf16) + use_flash_attention: bool = Field( + default=True, desc="Enable Flash Attention if possible.", hint=FieldHint.optional ) - num_experts_per_token: int = Field( - default=1, - desc="Active experts for each token in a MoE model.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), + window_size: int | None = Field( + default=None, + desc="Size of the attention sliding window. 
Warning: this parameter is not part of the architecture and must be redefined when loading a pretrained model.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - expert_routing_type: RoutingType = Field( - default=RoutingType.topk, - desc="The routing method, i.e., the method used to assign experts to tokens.", - hint=FieldHint.architecture, + max_window_layers: int | None = Field( + default=None, + desc="The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.", + hint=FieldHint.optional, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - activation_type: ActivationType = Field( + attention_lr_scale: float | None = Field( default=None, - desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", - hint=FieldHint.core, + desc="Custom learning rate scale for the Attention projection weights.", + doc="Can be used in muP to scale the Attention learning rate by 1/width_factor", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + attention_softmax_scale_power: float = Field( + default=0.5, + desc="The scaling power to apply to kv_channel in the attention calculation. " + " Under Standard Parameterization (SP): default to 0.5. " + " Under muP (if scaling kv_channels size): use 1. " + " Under muP (if scaling number of heads instead of kv_channels): use 0.5.", + valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - # Default: hidden_size**-0.5 - # TODO: Allow custom initialization (InitializationConfig?) + + def _validate(self) -> None: + super()._validate() + + if not TritonConfig.TRITON_ENABLED: + warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") + + Assert.multiple(self.num_attention_heads, self.head_groups) + + @functools.cached_property + def projection_size(self): + assert self._validated + return self.num_attention_heads * self.kv_channels + + def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: + return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + # Needed for multiple inheritance. 
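
The dimension bookkeeping built up in this method can be sanity-checked with a few lines of plain Python. The sketch below uses made-up sizes; only the relationships come from the code in this hunk: heads are split into `head_groups` groups of `group_heads`, `projection_size = num_attention_heads * kv_channels`, and the key/value projection stacks a size-2 dim over the head groups.

```python
# Standalone illustration of the composite attention dims; all sizes are hypothetical.
num_attention_heads = 32
head_groups = 8        # grouped-query attention: each group of heads shares one key/value head
kv_channels = 128      # hidden size of each attention head

assert num_attention_heads % head_groups == 0
group_heads = num_attention_heads // head_groups

composite_query = head_groups * group_heads * kv_channels   # == num_attention_heads * kv_channels
composite_key_value = 2 * head_groups * kv_channels         # key_and_value (2) x head_groups x kv_channels
composite_dense = head_groups * group_heads * kv_channels   # input width of the dense (output) projection

print(composite_query, composite_key_value, composite_dense)  # 4096 2048 4096
```
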
+ + tensor_space.add_tensor_dim( + head_groups := TensorDim( + TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + ) + ) + tensor_space.add_tensor_dim( + group_heads := TensorDim( + TransformerDimNames.group_heads, + div(self.num_attention_heads, self.head_groups), + None if self.head_groups > 1 else tensor, + ) + ) + tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim( + CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + ) + + +@config_class() +# TODO: Use composition for attention config +class TransformerConfig(AttentionConfig, BlockConfig): + _abstract = False + + # TODO: Review names init_method_std: float = Field( default=None, desc="Default scale for weight initialization. Default: hidden_size**-0.5", @@ -375,125 +241,17 @@ class TransformerConfig(LLMBlockConfig): desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", hint=FieldHint.optional, ) - attention_dropout: float = Field( - default=0.0, - desc="Dropout applied to the attention intermediate states.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - hidden_dropout: float = Field( - default=0.0, - desc="Dropout applied to the residual connections.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - full_precision_residual: bool = Field( - default=False, - desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", - hint=FieldHint.stability, - ) - # Use flash attention if possible (fp16 or bf16) - use_flash_attention: bool = Field( - default=True, desc="Enable Flash Attention if possible.", hint=FieldHint.optional - ) - window_size: int | None = Field( - default=None, - desc="Size of the attention sliding window. Warning: this parameter is not part of the architecture and must be redefined when loading a pretrained model.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - max_window_layers: int | None = Field( - default=None, - desc="The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.", - hint=FieldHint.optional, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto - mlp_recompute_level: MLPRecomputeLevel = Field( - default=MLPRecomputeLevel.none, - desc="Set which of the MLP intermediate activations will be recomputed during the backward passes. 
This provides a trade-off between memory and speed.", - hint=FieldHint.performance, - ) - debug_transformer: int = Field( - default=0, - desc="Log the output of each operation in a transformer layer.", - hint=FieldHint.logging, - valid=check_field(Assert.geq, 0), - ) - debug_transformer_memory: bool = Field( - default=False, - desc="Log the memory usage after each operation in a transformer layer..", - hint=FieldHint.logging, - ) # Use random inits instead of constant values, useful for debugging. random_bias_init: bool = Field( default=False, desc="Initialize the biases using the initialization method of their respective weights instead of setting them to zero. Used to test for issues that may not be visible when the biases are zero.", hint=FieldHint.testing, ) - expert_auxiliary_loss_coefficient: float = Field( - default=0.01, - desc="Scale of the load balancing auxiliary loss for topk routing.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - expert_z_loss_coefficient: float = Field( - default=0.0, - desc="Regularize the router during training by applying Z-loss to the logits.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - moe_jitter_eps: float = Field( - default=0.0, - desc="Regularize the router during training by applying a random multiplicative noise `uniform(1-eps, 1+eps)` to the logits.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - mlp_lr_scale: float | None | list[float | None] = Field( - default=None, - desc="Custom learning rate scale for each expert.", - doc="May be used to freeze some experts by setting their scale to zero.", - hint=FieldHint.feature, - ) - router_lr_scale: float | None = Field( - default=None, - desc="Custom learning rate for the MoE router weight.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - attention_lr_scale: float | None = Field( - default=None, - desc="Custom learning rate scale for the Attention projection weights.", - doc="Can be used in muP to scale the Attention learning rate by 1/width_factor", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - attention_softmax_scale_power: float = Field( - default=0.5, - desc="The scaling power to apply to kv_channel in the attention calculation. " - " Under Standard Parameterization (SP): default to 0.5. " - " Under muP (if scaling kv_channels size): use 1. " - " Under muP (if scaling number of heads instead of kv_channels): use 0.5.", - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - dropless_moe: bool = Field( - default=True, desc="Evaluate all the experts at once using dropless MoE.", hint=FieldHint.expert - ) - dropless_dynamic_shape: bool = Field( - default=False, - desc="Use a dynamic shape for dropless MLP instead of the worst-case value." - " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. 
Not recommended.", - hint=FieldHint.expert, - ) def _validate(self) -> None: with self._set_implicit_default(): - if self.ffn_hidden_size is None: - self.ffn_hidden_size = 4 * self.hidden_size if self.kv_channels is None: self.kv_channels = div(self.hidden_size, self.num_attention_heads) - if self.activation_type is None: - self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu if self.init_method_std is None: self.init_method_std = self.hidden_size**-0.5 if self.init_method_std_qkv is None: @@ -532,40 +290,9 @@ def _validate(self) -> None: Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) - self.num_unshared_experts = self.num_experts - self.num_shared_experts super()._validate() - if not TritonConfig.TRITON_ENABLED: - warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") - - Assert.leq(self.num_shared_experts, self.num_experts) - Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) - Assert.multiple(self.num_attention_heads, self.head_groups) - Assert.geq(self.attention_dropout, 0) - Assert.geq(self.hidden_dropout, 0) - - if isinstance(self.mlp_lr_scale, list): - Assert.eq(len(self.mlp_lr_scale), self.num_experts) - for scale in self.mlp_lr_scale: - if scale is not None: - Assert.geq(scale, 0) - elif self.mlp_lr_scale is not None: - Assert.geq(self.mlp_lr_scale, 0) - - @functools.cached_property - def projection_size(self): - assert self._validated - return self.num_attention_heads * self.kv_channels - - @property - def add_mlp_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False - @property def add_attn_qkv_bias(self) -> bool: if isinstance(self.add_linear_biases, bool): @@ -581,84 +308,3 @@ def add_attn_dense_bias(self) -> bool: if self.add_linear_biases == AddLinearBiasChoices.everywhere: return True return False - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.x: Remove backward compatibility. 
- cls._handle_renamed_field( - default, - "use_rotary_embeddings", - ("rotary", "type"), - lambda x: "default" if x else "none", - ) - cls._handle_renamed_field(default, "rotary_embedding_scale", ("rotary", "theta"), lambda x: math.exp(-x)) - cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton")) - return super()._from_dict(default, strict, flat) - - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.hidden, self.hidden_size)) - - # Self-attention dimensions - tensor_space.add_tensor_dim( - head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None - ) - ) - tensor_space.add_tensor_dim( - group_heads := TensorDim( - TransformerDimNames.group_heads, - div(self.num_attention_heads, self.head_groups), - None if self.head_groups > 1 else tensor, - ) - ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) - ) - - # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(TransformerDimNames.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim(gate_and_up := TensorDim(TransformerDimNames.gate_and_up, 2 if self.gated else 1)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(TransformerDimNames.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_expert_mlp, (experts, mlp))) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) - ) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.unshared_experts, self.num_unshared_experts)) - - # shared_experts - if self.num_shared_experts: - tensor_space.add_tensor_dim( - shared_experts := TensorDim(TransformerDimNames.shared_experts, self.num_shared_experts) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim( - TransformerDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) - ) - ) - - def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: - return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index d8425786d..2dbef77f3 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -24,8 +24,9 @@ from 
fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.mlp.config import RoutingType from fast_llm.layers.common.config import LayerNormalizationConfig -from fast_llm.layers.transformer.config import RoutingType, TransformerConfig +from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig from fast_llm.layers.transformer.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.models.gpt.config import ( diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 49a5dcbd3..da647de57 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,18 +10,14 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.block.mlp.config import MLPLossNames, RoutingType from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor -from fast_llm.layers.transformer.config import ( - RoutingType, - TransformerDimNames, - TransformerKwargs, - TransformerLossNames, -) +from fast_llm.layers.transformer.block import TransformerBlock +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -374,7 +370,7 @@ def loss_defs(self) -> list[LossDef]: ): loss_defs.append( LossDef( - name=TransformerLossNames.load_balancing_loss, + name=MLPLossNames.load_balancing_loss, formatted_name="load balancing loss", count=self._config.transformer.num_layers, ) @@ -382,7 +378,7 @@ def loss_defs(self) -> list[LossDef]: if self._config.transformer.expert_z_loss_coefficient: loss_defs.append( LossDef( - name=TransformerLossNames.router_z_loss, + name=MLPLossNames.router_z_loss, formatted_name="router z loss", count=self._config.transformer.num_layers, ) diff --git a/tests/test_mlp.py b/tests/test_mlp.py index bcfbaf693..4cf1ac458 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -1,8 +1,8 @@ -from fast_llm.layers.transformer.mlp import MLP -from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.mlp import MLP +from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP +from 
fast_llm.layers.transformer.config import TransformerConfig def test_mlp_constructor(): @@ -20,11 +20,7 @@ def test_mlp_constructor(): def test_moe_mlp_constructor(): transformer_conf = TransformerConfig( - num_layers=2, - num_attention_heads=2, - hidden_size=16, - num_experts=2, - add_linear_biases=False + num_layers=2, num_attention_heads=2, hidden_size=16, num_experts=2, add_linear_biases=False ) distributed_config = DistributedConfig() tensor_space = TensorSpace(distributed_config=distributed_config) From 6bf06d6aecb9a2a0de67ad7a42690db071a812f4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 29 Jul 2025 15:51:13 -0400 Subject: [PATCH 29/82] fix --- fast_llm/tensor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b3795b740..d080e6a1e 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -201,13 +201,9 @@ def local_to_global_partial(self, tensor: torch.Tensor, fill_value: float | int tensor = tensor[None] Assert.eq(tensor.shape, self.shape) assert not self._reductions - logger.info(f"AAAA {self.tensor_name} {self.shape} {self.global_shape} {tensor.shape}") for dim, tensor_dim in enumerate(self.dims): if tensor_dim.is_parallel: tensor = tensor_dim.local_to_global_partial(tensor, dim, fill_value) - logger.info( - f"BBBB {self.tensor_name} {self.shape} {self.global_shape} {tensor.shape} {tensor_dim.is_parallel}" - ) Assert.eq(tensor.shape, self.global_shape) return tensor From 2ddc3a748817ee98785344e03809cfd67590e954 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 29 Jul 2025 16:15:10 -0400 Subject: [PATCH 30/82] fix --- fast_llm/engine/config_utils/tensor_space.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index cf2974a99..6c4b95b20 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -95,7 +95,7 @@ class CompositeTensorDim(TensorDim): def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): parallel_dim = None for dim, tensor_dim in enumerate(tensor_dims): - if tensor_dim.is_parallel: + if tensor_dim.parallel_dim is not None: # TODO: Allow more than one parallel subdim? 
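
The invariant behind this one-line fix — when tensor dims are composed, at most one sub-dim may carry a parallel (distributed) dim, and the composite inherits it — can be sketched without any Fast-LLM imports. The toy classes below are illustrative only and are not the real `TensorDim`/`CompositeTensorDim` API.

```python
# Toy sketch of the composition rule: at most one parallel sub-dim, inherited by the composite.
from dataclasses import dataclass


@dataclass
class ToyDim:
    name: str
    size: int
    parallel_dim: str | None = None  # e.g. "tensor" when the dim is sharded across that group


def compose(name: str, dims: tuple[ToyDim, ...]) -> ToyDim:
    parallel_dim = None
    size = 1
    for dim in dims:
        if dim.parallel_dim is not None:
            assert parallel_dim is None, "more than one parallel sub-dim"
            parallel_dim = dim.parallel_dim
        size *= dim.size
    return ToyDim(name, size, parallel_dim)


heads = ToyDim("heads", 8, parallel_dim="tensor")
kv_channels = ToyDim("kv_channels", 128)
print(compose("query", (heads, kv_channels)))  # size 1024, sharded along "tensor"
```
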
assert parallel_dim is None parallel_dim = tensor_dim.parallel_dim From ce70b169e55dea29383eb3f6a488125b309487ce Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 29 Jul 2025 18:13:29 -0400 Subject: [PATCH 31/82] fixes --- fast_llm/layers/block/config.py | 16 +++++++++------- fast_llm/layers/block/mlp/config.py | 3 ++- fast_llm/layers/block/mlp/mlp.py | 2 +- fast_llm/layers/block/peft.py | 2 +- fast_llm/layers/ssm/mamba2.py | 2 +- fast_llm/layers/transformer/config.py | 3 ++- fast_llm/models/custom/model.py | 2 +- fast_llm/models/ssm/model.py | 8 ++++---- tests/test_mlp.py | 2 +- tests/test_multi_stage.py | 4 ++-- 10 files changed, 24 insertions(+), 20 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 2f26d8d79..5a999fa6d 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -4,7 +4,7 @@ from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.block.mlp.config import MLPConfig -from fast_llm.layers.block.peft_config import TransformerPeftConfig +from fast_llm.layers.block.peft import TransformerPeftConfig from fast_llm.layers.common.config import NormalizationConfig from fast_llm.utils import Assert @@ -26,11 +26,19 @@ class BlockKwargs: hidden_dims = "hidden_dims" sequence_q_dim = "sequence_q_dim" sequence_k_dim = "sequence_k_dim" + # TODO: These are confusing sequence_length = "sequence_length" + sequence_lengths = "sequence_lengths" # TODO: Belongs elsewhere? grad_output = "grad_output" +class AddLinearBiasChoices(str, enum.Enum): + nowhere = "nowhere" + everywhere = "everywhere" + only_attn_qkv = "only_attn_qkv" + + @config_class() # TODO: Use composition for MLP config class BlockConfig(MLPConfig, BaseModelConfig): @@ -112,9 +120,3 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: # Hidden dimension tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.hidden_size)) - - -class AddLinearBiasChoices(str, enum.Enum): - nowhere = "nowhere" - everywhere = "everywhere" - only_attn_qkv = "only_attn_qkv" diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 63e31219a..1d125c4f7 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -1,6 +1,6 @@ import enum -from fast_llm.config import Config, Field, FieldHint, check_field, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel @@ -32,6 +32,7 @@ class RoutingType(str, enum.Enum): sinkhorn = "sinkhorn" +@config_class() class MLPConfig(Config): # TODO: Review names _abstract = False diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 04b8506a4..19349671e 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -15,7 +15,7 @@ from fast_llm.utils import Assert, get_lr_scale -class MLPBase[ConfigType: BlockConfig](Layer, Configurable[ConfigType]): +class MLPBase[ConfigType: BlockConfig](Configurable[ConfigType], Layer): def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): super().__init__(config) self._name = name diff --git 
a/fast_llm/layers/block/peft.py b/fast_llm/layers/block/peft.py index 269ed0aac..66bc675ed 100644 --- a/fast_llm/layers/block/peft.py +++ b/fast_llm/layers/block/peft.py @@ -8,13 +8,13 @@ from fast_llm.config import Field, FieldHint, config_class from fast_llm.layers.common.config import LoRAConfig, NoPeftConfig, PeftConfig -from fast_llm.tensor import ParameterMeta from fast_llm.utils import div if typing.TYPE_CHECKING: import torch from fast_llm.layers.common.linear import LinearBase, LinearLike + from fast_llm.tensor import ParameterMeta class TransformerSubLayerName(str, enum.Enum): diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 712c420ee..1c319f490 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -5,7 +5,7 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType -from fast_llm.layers.block.config import BlockConfig +from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 1c10753a8..ebb976e63 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -36,7 +36,6 @@ class TransformerKwargs(BlockKwargs): rotary_freq_k = "rotary_freq_k" attention_mask = "attention_mask" attention_mask_value = "attention_mask_value" - sequence_lengths = "sequence_lengths" cu_seqlens_q = "cu_seqlens_q" cu_seqlens_k = "cu_seqlens_k" max_seqlen_q = "max_seqlen_q" @@ -46,6 +45,7 @@ class TransformerKwargs(BlockKwargs): past_key_values = "past_key_values" +@config_class() class AttentionConfig(Config): # TODO: Make mixer class dynamic. _abstract = False @@ -126,6 +126,7 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Needed for multiple inheritance. 
+ super().setup_tensor_space(tensor_space) # Noqa tensor_space.add_tensor_dim( head_groups := TensorDim( diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index 534d813ff..3c0ad8ab4 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -7,7 +7,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.transformer import TransformerBlock +from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig from fast_llm.models.custom.head import CustomHead from fast_llm.models.gpt.config import GPTBaseModelConfig diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 3ba6b1a62..ca840911f 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -5,8 +5,8 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.llamba_block import SSMBlock -from fast_llm.layers.transformer.transformer import TransformerBlock +from fast_llm.layers.ssm.block import SSMBlock +from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -52,7 +52,7 @@ def get_output_layers(self) -> list[Layer]: else: layers.append( SSMBlock( - transformer_config=self._config.transformer, + config=self._config.transformer, ssm_config=self._config.ssm, mixer_cls=self._config.ssm_block_type.get_mixer_class(), block_index=len(self._config.hybrid_block_layout), @@ -88,7 +88,7 @@ def get_layers(self) -> list[Layer]: else: layers.append( SSMBlock( - transformer_config=self._config.transformer, + config=self._config.transformer, ssm_config=self._config.ssm, mixer_cls=self._config.ssm_block_type.get_mixer_class(), block_index=i + 1, diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 4cf1ac458..5875822ff 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -1,7 +1,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.mlp import MLP from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP +from fast_llm.layers.block.mlp.mlp import MLP from fast_llm.layers.transformer.config import TransformerConfig diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index 2f125717e..0639ec7ed 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,8 +3,8 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.ssm.llamba_block import SSMBlock -from fast_llm.layers.transformer.transformer import TransformerBlock +from fast_llm.layers.ssm.block import SSMBlock +from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.utils import Assert from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup From a9f733d121e47df360b997097abb8bf2d5ac49d1 Mon Sep 17 
00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 29 Jul 2025 18:33:05 -0400 Subject: [PATCH 32/82] fix --- fast_llm/layers/ssm/block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/ssm/block.py b/fast_llm/layers/ssm/block.py index 4854900a3..0bfa266ac 100644 --- a/fast_llm/layers/ssm/block.py +++ b/fast_llm/layers/ssm/block.py @@ -31,5 +31,5 @@ def _create_mixer(self) -> Mixer: self._ssm_config, tensor_space=self._tensor_space, block_index=self._block_index, - transformer_config=self._config, + block_config=self._config, ) From cef7c155ebe08c40a20b61a1c9f930ee223007f7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 30 Jul 2025 12:20:46 -0400 Subject: [PATCH 33/82] fix --- fast_llm/models/ssm/config.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 9427f69be..866de962f 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -4,13 +4,18 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.checkpoint.config import CheckpointHandler from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.models.gpt.config import ( + GPTBaseModelConfig, + GPTBatchConfig, + GPTHuggingfaceCheckpointFormat, + PretrainedGPTModelConfig, +) from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -79,8 +84,7 @@ def _validate(self): self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None -class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class LLambaHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "llamba" @classmethod @@ -90,8 +94,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return LLambaHuggingfaceCheckpointHandler -class AprielSSMHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielSSMHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm" @classmethod @@ -101,8 +104,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielSSMHuggingfaceCheckpointHandler -class AprielSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielSSMHHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm_hybrid" @classmethod @@ -112,8 +114,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielSSMHHybridHuggingfaceCheckpointHandler -class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" @classmethod From a5eb0767e99038e18c1bd07f7f78718634296c4c Mon Sep 17 00:00:00 2001 
From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 15:14:13 -0400 Subject: [PATCH 34/82] stuff --- docs/developer_guide/conversion.md | 30 ++- .../engine/config_utils/initialization.py | 178 ++++++++++++ fast_llm/layers/block/block.py | 132 +++++++-- fast_llm/layers/block/config.py | 160 +++++++++-- fast_llm/layers/block/mixer.py | 68 ----- fast_llm/layers/block/mlp/config.py | 79 +++++- .../layers/block/mlp/mixture_of_experts.py | 134 ++++------ fast_llm/layers/block/mlp/mlp.py | 72 ++--- fast_llm/layers/common/config.py | 2 +- fast_llm/layers/common/linear.py | 3 +- fast_llm/layers/common/normalization.py | 3 +- fast_llm/layers/language_model/config.py | 122 ++++----- fast_llm/layers/language_model/embedding.py | 48 ++-- fast_llm/layers/language_model/head.py | 109 ++++---- .../layers/language_model/preprocessing.py | 10 +- fast_llm/layers/ssm/config.py | 4 +- fast_llm/layers/ssm/discrete_mamba2.py | 4 +- fast_llm/layers/ssm/mamba2.py | 5 +- fast_llm/layers/ssm/mamba_layer.py | 7 +- fast_llm/layers/transformer/attention.py | 142 +++++----- fast_llm/layers/transformer/config.py | 253 ++++++------------ fast_llm/layers/transformer/preprocessing.py | 52 ++-- .../transformer/rotary/preprocessing.py | 26 +- fast_llm/layers/transformer/rotary/rotary.py | 30 +-- fast_llm/models/custom/model.py | 2 +- fast_llm/models/gpt/config.py | 9 +- fast_llm/models/gpt/conversion.py | 6 +- fast_llm/models/gpt/huggingface.py | 10 +- fast_llm/models/gpt/model.py | 54 ++-- fast_llm/models/ssm/config.py | 10 +- fast_llm/models/ssm/conversion.py | 6 +- fast_llm/tensor.py | 70 +---- tests/layers/test_lm_head.py | 6 +- tests/models/test_generate.py | 2 +- tests/test_attention.py | 16 +- tests/test_ssms.py | 6 +- tests/utils/model_configs.py | 2 + 37 files changed, 1015 insertions(+), 857 deletions(-) create mode 100644 fast_llm/engine/config_utils/initialization.py delete mode 100644 fast_llm/layers/block/mixer.py diff --git a/docs/developer_guide/conversion.md b/docs/developer_guide/conversion.md index 0620beaea..719757df1 100644 --- a/docs/developer_guide/conversion.md +++ b/docs/developer_guide/conversion.md @@ -230,21 +230,23 @@ Continuing our `AwesomeModel` handler example, we define: ```python def _create_weight_converters(self) -> list[WeightConverter]: + + converters = [] - # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. - num_layers = self._model.config.base_model.transformer.num_layers - - # A simple renaming example, for the word embeddings. - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - - # We usually want to loop dynamically over layers - for i in range(num_layers): - # A `SplitWeightConverter` example, splitting a weight in two. - converters.append(SplitWeightConverter( - f"layers.{i + 1}.weight", - (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), - )) - return converters +# The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. +num_layers = self._model.config.base_model.transformer.num_blocks + +# A simple renaming example, for the word embeddings. +converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + +# We usually want to loop dynamically over layers +for i in range(num_layers): + # A `SplitWeightConverter` example, splitting a weight in two. 
+ converters.append(SplitWeightConverter( + f"layers.{i + 1}.weight", + (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), + )) +return converters ``` And that's it! We're ready to use the new checkpoint format in Fast-LLM. diff --git a/fast_llm/engine/config_utils/initialization.py b/fast_llm/engine/config_utils/initialization.py new file mode 100644 index 000000000..d35c2220c --- /dev/null +++ b/fast_llm/engine/config_utils/initialization.py @@ -0,0 +1,178 @@ +import abc +import typing + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + import torch + + from fast_llm.tensor import ParameterMeta + + +@config_class(registry=True) +class InitializationConfig(Config): + _abstract = True + has_initialization: typing.ClassVar[bool] = True + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is InitializationConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. + return DefaultInitializationConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + def get_initializer(self) -> "Initializer": + raise NotImplementedError() + + +@config_class(dynamic_type={InitializationConfig: "default"}) +class DefaultInitializationConfig(InitializationConfig): + # A placeholder indicating that the class default should be used instead. + _abstract = False + has_initialization = False + + +@config_class(dynamic_type={InitializationConfig: "fill"}) +class NormalInitializationConfig(InitializationConfig): + """ + Normal initialization: normal(mean, std).clamp(min,max) + """ + + _abstract = False + + value: float = Field( + default=1, + desc="Initialization value.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + + def get_initializer(self): + return init_fill_(self.value) + + +@config_class(dynamic_type={InitializationConfig: "zeros"}) +class ZeroInitializationConfig(InitializationConfig): + def get_initializer(self): + return init_zeros_ + + +@config_class(dynamic_type={InitializationConfig: "ones"}) +class ZeroInitializationConfig(InitializationConfig): + def get_initializer(self): + return init_ones_ + + +@config_class(dynamic_type={InitializationConfig: "normal"}) +class NormalInitializationConfig(InitializationConfig): + """ + Normal initialization: normal(mean, std).clamp(min,max) + """ + + _abstract = False + + std: float = Field( + default=1, + desc="Standard deviation for normal initialization.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + mean: float = Field( + default=0, + desc="Mean for normal initialization.", + hint=FieldHint.optional, + ) + min: float | None = Field( + default=None, + desc="Min value for initialization clamping.", + hint=FieldHint.optional, + ) + max: float | None = Field( + default=None, + desc="Min value for initialization clamping.", + hint=FieldHint.optional, + ) + + def get_initializer(self): + return init_normal_(self.mean, self.std, self.min, self.max) + + +@config_class(dynamic_type={InitializationConfig: "uniform"}) +class UniformInitializationConfig(InitializationConfig): + """ + Uniform initialization: uniform(mean - scale, mean + scale).clamp(min,max) + """ + + _abstract = False + + scale: float = Field( + default=None, + desc="Initialization scale.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + mean: float = Field( + 
default=None, + desc="Initialization mean.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + + def get_initializer(self) -> "Initializer": + return init_uniform_centered_(self.scale, self.mean) + + +class Initializer(abc.ABC): + @abc.abstractmethod + def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: + pass + + requires_global_initialization = False + + +class LambdaInitializer(Initializer): + def __init__( + self, + init_method: typing.Callable[["ParameterMeta", "torch.Tensor", "torch.Generator"], None], + requires_global_initialization: bool = False, + ) -> None: + self._init_method = init_method + self.requires_global_initialization = requires_global_initialization + + def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: + return self._init_method(meta, tensor, generator) + + +def init_fill_(value: float) -> LambdaInitializer: + def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa + tensor.fill_(value) + + return LambdaInitializer(init_) + + +init_zeros_ = init_fill_(0.0) +init_ones_ = init_fill_(1.0) + + +def init_normal_( + mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa + tensor = tensor.normal_(mean, std, generator=generator) + if min_val is not None or max_val is not None: + tensor.clamp_(min=min_val, max=max_val) + + return LambdaInitializer(init_) + + +def init_uniform_centered_(scale: float, mean: float = 0.0) -> LambdaInitializer: + def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa + tensor.uniform_(mean - scale, mean + scale, generator=generator) + + return LambdaInitializer(init_) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 85da61c01..d13b09807 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -1,4 +1,5 @@ import abc +import functools import typing import torch @@ -8,23 +9,118 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs -from fast_llm.layers.block.mixer import Mixer -from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP -from fast_llm.layers.block.mlp.mlp import MLP +from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs, BlockLayerConfig from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta -class Block[ConfigType: BlockConfig](Layer, Configurable[ConfigType]): +class DebugLayer: + # TODO: Move elsewhere? 
+ def __init__(self, tensor_space: TensorSpace, name: str, debug_level: int = 0, debug_memory: bool = False): + self._tensor_space = tensor_space + self._name = name + self._debug_level = debug_level + self._debug_memory = debug_memory + + def _get_meta( + self, tensor: torch.Tensor, name: str, dims: tuple[TensorDim | str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + ( + dim + if isinstance(dim, TensorDim) + else hidden_dims[dim] if dim in hidden_dims else self._tensor_space[dim] + ) + for dim in dims + ), + tensor_name=f"{self._name} {name}", + dtype=tensor.dtype, + ) + + @functools.cached_property + def enabled(self) -> bool: + return self._debug_level > 0 or self._debug_memory + + def __call__( + self, + tensor: torch.Tensor, + name: str, + dims: tuple[TensorDim | str, ...], + kwargs: dict[str, typing.Any], + scale: float = 1.0, + global_: bool = True, + log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info, + ) -> None: + # TODO: Local vs global? + if self._debug_memory: + log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) + if self._debug_level > 0: + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dims, kwargs), + distributed=self._tensor_space.distributed, + global_=global_, + log_fn=log_fn, + scale=scale, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dims, kwargs), + distributed=self._tensor_space.distributed, + global_=global_, + log_fn=log_fn, + scale=scale, + ) + + +class BlockLayer[ConfigType: BlockLayerConfig](Configurable[ConfigType], torch.nn.Module): + """ + Base class for mixer and MLP modules. + """ + + def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): + super().__init__(config) + self._tensor_space = tensor_space + self._block_index = block_index + self._name = name + self._sequence_parallel: bool = self._tensor_space.distributed_config.sequence_tensor_parallel + self._debug = DebugLayer( + tensor_space, + f"Block {self._block_index} {self._name}", + self.config.block.debug_transformer, + self._config.block.debug_transformer_memory, + ) + + @abc.abstractmethod + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + pass + + +class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): """ A transformer-like decoder base block with abstract mixer. """ # TODO: Standardize to `mixer` - _mixer_module_name: typing.ClassVar[str] = "mixer" - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): + def __init__( + self, config: ConfigType, tensor_space: TensorSpace, block_index: int = 0, return_input: bool = False + ): super().__init__() self._config = config self._tensor_space: TensorSpace = tensor_space @@ -40,21 +136,19 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) - # The mixer needs to be created here for backward-compatible weight ordering. 
- setattr(self, self._mixer_module_name, self._create_mixer()) - - self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index + # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. + setattr( + self, + self._config.mixer.module_name, + self._config.mixer.get_layer(self._tensor_space, block_index, f"{self.name} mixer"), ) + self.mlp = self._config.mlp.get_layer(self._tensor_space, block_index, f"{self.name} mlp") + # PEFT. self.norm_1 = self._config.peft.apply_other(self.norm_1) self.norm_2 = self._config.peft.apply_other(self.norm_2) - @abc.abstractmethod - def _create_mixer(self) -> Mixer: - pass - @torch.compile def _bias_dropout_add( self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor @@ -113,13 +207,13 @@ def forward( hidden_states = self.norm_1(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 1", kwargs) - hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) + hidden_states, bias = getattr(self, self._config.mixer.module_name)(hidden_states, kwargs) if self._debug_mode: - self._debug_log(hidden_states, f"{self._mixer_module_name} output", kwargs, bias=bias) + self._debug_log(hidden_states, f"{self._config.mixer.module_name} output", kwargs, bias=bias) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) if self._debug_mode: - self._debug_log(input_, f"{self._mixer_module_name} residual", kwargs) + self._debug_log(input_, f"{self._config.mixer.module_name} residual", kwargs) hidden_states = self.norm_2(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 2", kwargs) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 5a999fa6d..87bd6d249 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,13 +1,21 @@ +import abc import enum +import functools +import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerPeftConfig from fast_llm.layers.common.config import NormalizationConfig from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.layers.block.block import Block, BlockLayer + + +# TODO: Generalize these beyond language models? (Ex. vision) + class BlockDimNames: # A set of common tensor dim names packed into a namespace. @@ -39,10 +47,76 @@ class AddLinearBiasChoices(str, enum.Enum): only_attn_qkv = "only_attn_qkv" +@config_class(registry=True) +class BlockLayerConfig(BaseModelConfig): + _abstract = True + block: "BlockConfig" = Field(init=False) + + def _validate(self) -> None: + assert hasattr(self, "block") + Assert.is_(self.block.mlp, self) + super()._validate() + + @property + def layer_class(self) -> "type[BlockLayer]": + raise NotImplementedError() + + def get_layer(self, tensor_space: TensorSpace, block_index: int, name: str) -> "BlockLayer": + return self.layer_class(self, tensor_space, block_index, name) + + +@config_class() +class MixerConfig(BlockLayerConfig): + _abstract = True + + # Needed for backward compatibility. 
+ module_name: typing.ClassVar[str] = "mixer" + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is MixerConfig and cls.get_subclass(default.get("type")) is None: + from fast_llm.layers.transformer.config import AttentionConfig + + # Default subclass. + return AttentionConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + @config_class() -# TODO: Use composition for MLP config -class BlockConfig(MLPConfig, BaseModelConfig): +class MLPBaseConfig(BlockLayerConfig): + _abstract = True + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is MLPBaseConfig and cls.get_subclass(default.get("type")) is None: + from fast_llm.layers.block.mlp.config import MLPConfig + # Default subclass. + return MLPConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class() +class BlockConfig(BaseModelConfig): + _abstract = False + mixer: MixerConfig = Field( + desc="Configuration for the mixer.", + hint=FieldHint.architecture, + ) + mlp: MLPBaseConfig = Field( + desc="Configuration for the MLP.", + hint=FieldHint.architecture, + ) # TODO: Review names normalization: NormalizationConfig = Field( desc="Configuration for the normalization layers architecture.", @@ -58,11 +132,6 @@ class BlockConfig(MLPConfig, BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - full_precision_residual: bool = Field( - default=False, - desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", - hint=FieldHint.stability, - ) debug_transformer: int = Field( default=0, desc="Log the output of each operation in a transformer layer.", @@ -80,8 +149,45 @@ class BlockConfig(MLPConfig, BaseModelConfig): hint=FieldHint.architecture, ) + block_sequence: "BlockSequenceConfig" = Field(init=False) + + def _validate(self) -> None: + assert hasattr(self, "block_sequence") + Assert.incl(self, self.block_sequence.blocks.values()) + self.mixer.block = self + self.mlp.block = self + super()._validate() + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + self.mlp.setup_tensor_space(tensor_space) + self.mixer.setup_tensor_space(tensor_space) + + # Hidden dimension + tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.block_sequence.hidden_size)) + + @abc.abstractmethod + def get_block(self) -> "Block": + pass + + +@config_class() +class BlockSequenceConfig(BaseModelConfig): + _abstract = True + + blocks: dict[str, BlockConfig] = Field() + block_pattern: tuple[str, ...] = Field( + default=None, + desc="The pattern of blocks (referred by name) to use. The sequence is repeated until reaching `num_blocks`." + " Default: cycle over `blocks` in the order they are defined.", + ) + default_block: str = Field( + default=None, + desc="The default block configuration to use when referring to the model." + " Used to set some defaults in the language model.", + ) + # TODO: Move these, not specific to a single block. 
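
How `block_pattern` is consumed is easiest to see in isolation: `get_block_config` below simply cycles over the pattern modulo its length until `num_blocks` blocks have been produced. The snippet uses hypothetical block names and reproduces only that cycling rule.

```python
# Illustration of the pattern cycling used by get_block_config; block names are hypothetical.
blocks = {"attn": "transformer block config", "mamba": "ssm block config"}
block_pattern = ("attn", "mamba", "mamba")
num_blocks = 8

layout = [block_pattern[block_index % len(block_pattern)] for block_index in range(num_blocks)]
print(layout)
# ['attn', 'mamba', 'mamba', 'attn', 'mamba', 'mamba', 'attn', 'mamba']
```
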
- num_layers: int = Field( + num_blocks: int = Field( default=12, desc="Number of layers in the transformer.", hint=FieldHint.architecture, @@ -93,30 +199,28 @@ class BlockConfig(MLPConfig, BaseModelConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - per_layer_lr_scale: list[float] | None = Field( - default=None, - desc="Custom learning rate scale for each layer.", - doc="May be used to freeze some layers by setting their scale to zero.", - hint=FieldHint.feature, + full_precision_residual: bool = Field( + default=False, + desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", + hint=FieldHint.stability, ) def _validate(self) -> None: - with self._set_implicit_default(): - if self.ffn_hidden_size is None: - self.ffn_hidden_size = 4 * self.hidden_size - + for block in self.blocks.values(): + block.validate() + if self.block_pattern is None: + self.block_pattern = tuple(self.blocks) + if self.default_block is None: + self.default_block = self.block_pattern[0] super()._validate() - @property - def add_mlp_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False + def get_block_config(self, block_index: int) -> BlockConfig: + return self.blocks[self.block_pattern[block_index % len(self.block_pattern)]] def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - super().setup_tensor_space(tensor_space) + for block in self.blocks.values(): + block.setup_tensor_space(tensor_space) - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.hidden_size)) + @functools.cached_property + def default_block_config(self) -> BlockConfig: + return self.blocks[self.default_block] diff --git a/fast_llm/layers/block/mixer.py b/fast_llm/layers/block/mixer.py deleted file mode 100644 index 5c811e330..000000000 --- a/fast_llm/layers/block/mixer.py +++ /dev/null @@ -1,68 +0,0 @@ -import abc -import typing - -import torch - -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.block.config import BlockKwargs -from fast_llm.logging import log_distributed_grad, log_distributed_tensor -from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert - - -class Mixer(torch.nn.Module, abc.ABC): - """ - Base class for mixer modules. - """ - - _mixer_name: typing.ClassVar[str] - - def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): - super().__init__() - self._tensor_space = tensor_space - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._block_index = block_index - self._debug_level = debug_level - - @abc.abstractmethod - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Mixer module forward. Returns the output hidden states and an optional bias, - in case its addition can be made more efficient in `_bias_dropout_add`. 
- """ - - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = { - dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) - } - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] - for dim_name in dim_names - ), - tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_level, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 1d125c4f7..526c513db 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -1,11 +1,18 @@ import enum +import functools +import typing -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_zeros_ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.layers.block.config import AddLinearBiasChoices, BlockLayerConfig + from fast_llm.layers.block.mlp.mlp import MLPBase + class MLPDimNames: # MLP dimensions @@ -32,9 +39,10 @@ class RoutingType(str, enum.Enum): sinkhorn = "sinkhorn" -@config_class() -class MLPConfig(Config): +@config_class(dynamic_type={BlockLayerConfig: "mlp"}) +class MLPConfig(BlockLayerConfig): # TODO: Review names + # TODO: Separate MoE? _abstract = False ffn_hidden_size: int = Field( default=None, @@ -124,11 +132,52 @@ class MLPConfig(Config): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) + layer_1_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the first mlp layer weights. Default: hidden_size**-0.5", + hint=FieldHint.feature, + ) + layer_1_bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the first mlp layer biases. Default: fill with zeros.", + hint=FieldHint.feature, + ) + layer_2_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the second mlp layer weights." + " Default: (2 * num_blocks * hidden_size)**-0.5", + hint=FieldHint.feature, + ) + layer_2_bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the second mlp layer biases. 
Default: fill with zeros.", + hint=FieldHint.feature, + ) + + @property + def layer_class(self) -> "type[MLPBase]": + if self.num_experts > 1: + from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP + + return MixtureOfExpertMLP + else: + from fast_llm.layers.block.mlp.mlp import MLP + + return MLP + + @property + def add_bias(self) -> bool: + if isinstance(self.block.add_linear_biases, bool): + return self.block.add_linear_biases + if self.block.add_linear_biases == AddLinearBiasChoices.everywhere: + return True + return False def _validate(self) -> None: + assert hasattr(self, "block") + with self._set_implicit_default(): if self.activation_type is None: self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + if self.ffn_hidden_size is None: + # TODO: hidden_size not yet validated. + self.ffn_hidden_size = 4 * self.block.block_sequence.hidden_size self.num_unshared_experts = self.num_experts - self.num_shared_experts super()._validate() @@ -144,6 +193,30 @@ def _validate(self) -> None: elif self.mlp_lr_scale is not None: Assert.geq(self.mlp_lr_scale, 0) + @functools.cached_property + def layer_1_weight_initialization_method(self) -> Initializer: + if not self.layer_1_weight_initialization.has_initialization: + return self.layer_1_weight_initialization.get_initializer() + return self.block.block_sequence.hidden_size**-0.5 + + @functools.cached_property + def layer_1_bias_initialization_method(self) -> Initializer: + if not self.layer_1_bias_initialization.has_initialization: + return self.layer_1_bias_initialization.get_initializer() + return init_zeros_ + + @functools.cached_property + def layer_2_weight_initialization_method(self) -> Initializer: + if self.layer_2_weight_initialization.has_initialization: + return self.layer_2_weight_initialization.get_initializer() + return self.block.block_sequence.hidden_size**-0.5 / max(2 * self.block.block_sequence.num_blocks, 1) + + @functools.cached_property + def layer_2_bias_initialization_method(self) -> Initializer: + if self.layer_2_bias_initialization.has_initialization: + return self.layer_2_bias_initialization.get_initializer() + return init_zeros_ + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 8d092b6dc..332d3109f 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -1,27 +1,24 @@ import logging -import typing import warnings import torch from fast_llm.core.distributed import ProcessGroup, set_generator -from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank +from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs -from fast_llm.layers.block.mlp.config import MLPDimNames, MLPLossNames, RoutingType +from fast_llm.layers.block.config import BlockDimNames, BlockKwargs +from fast_llm.layers.block.mlp.config import MLPConfig, MLPDimNames, MLPLossNames, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from 
fast_llm.layers.common.linear import Linear -from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage -from fast_llm.tensor import TensorMeta, init_normal_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) -class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): +class MixtureOfExpertMLP[ConfigType: MLPConfig](MLPBase[ConfigType]): """ MoeLayer following implementation from https://github.com/NVIDIA/Megatron-LM/blob/46ebc0e4202c980d98900000d455f754a7ff9d4b/megatron/model/transformer.py#L346 @@ -35,23 +32,10 @@ class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): _group: ProcessGroup - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): - Assert.gt(config.num_experts, 1) + def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): + super().__init__(config, tensor_space, block_index, name) # TODO: Implement? - assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, tensor_space, name, block_index) - self._tensor_space = tensor_space - self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - - self._num_experts = config.num_experts - self._experts_per_token = config.num_experts_per_token - self._num_shared_experts = config.num_shared_experts - self._num_unshared_experts = config.num_unshared_experts - - self._routing_type = config.expert_routing_type - self._load_balancing_factor = config.expert_auxiliary_loss_coefficient - self._z_loss_factor = config.expert_z_loss_coefficient - self._moe_jitter_eps = config.moe_jitter_eps + assert not self._config.add_linear_biases, "Biases not supported for MoE." 
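Note on the hunk above: the per-instance copies of the routing hyperparameters are dropped, so `MixtureOfExpertMLP` now reads them from its `MLPConfig` at call time, and the concrete MLP class is chosen by the `layer_class` property added to `MLPConfig` earlier in this patch. A minimal, self-contained sketch of that dispatch follows; the `_MLPConfigSketch`, `_MLP` and `_MoE` names are placeholders rather than Fast-LLM classes, and the real `get_layer` also takes a `TensorSpace`, block index and name.

import dataclasses


@dataclasses.dataclass
class _MLPConfigSketch:
    num_experts: int = 1

    @property
    def layer_class(self) -> type:
        # Mirrors MLPConfig.layer_class: more than one expert selects the MoE layer.
        return _MoE if self.num_experts > 1 else _MLP

    def get_layer(self):
        # Mirrors BlockLayerConfig.get_layer: instantiate whatever layer_class resolves to.
        return self.layer_class()


class _MLP: ...


class _MoE: ...


assert type(_MLPConfigSketch(num_experts=1).get_layer()) is _MLP
assert type(_MLPConfigSketch(num_experts=8).get_layer()) is _MoE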
layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) @@ -72,21 +56,20 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = " ) dropless_moe = False self._mlp_forward = self._forward_dropless if dropless_moe else self._forward_looped - self._dynamic_shape = config.dropless_dynamic_shape def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) - if self._debug_mode: - self._debug_log(logits, "Router logits", MLPDimNames.experts, kwargs) + if self._debug.enabled: + self._debug(logits, "Router logits", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.experts,), kwargs) # Apply z_loss if applicable - if self._z_loss_factor > 0.0: + if self._config.expert_z_loss_coefficient > 0.0: logits = z_loss( logits, - self._z_loss_factor, + self._config.expert_z_loss_coefficient, self.training, grad_scale=kwargs.get("grad_output"), losses=losses, @@ -94,24 +77,31 @@ def forward( ) # Apply input_jitter if applicable: - if self.training and self._moe_jitter_eps > 0.0: + if self.training and self._config.moe_jitter_eps > 0.0: with set_generator(self._tensor_space.distributed.pp_generator): logits = self._apply_input_jitter(logits) # Routing - if self._routing_type == RoutingType.topk: + if self._config.expert_routing_type == RoutingType.topk: scores, top_experts = self._topk_routing(logits, kwargs.get(BlockKwargs.grad_output), losses) - if self._num_shared_experts > 0: + if self._config.num_shared_experts > 0: scores, top_experts = self._add_shared_experts(top_experts, scores) - elif self._routing_type == RoutingType.sinkhorn: + elif self._config.expert_routing_type == RoutingType.sinkhorn: scores, top_experts = self._sinkhorn_routing(logits) else: - raise NotImplementedError(self._routing_type) + raise NotImplementedError(self._config.expert_routing_type) - if self._debug_mode: + if self._debug.enabled: # To log all ranks set `global_=False` - self._debug_log(scores, "Router scores", MLPDimNames.top_experts, kwargs) - self._debug_log(top_experts, "Router top experts", MLPDimNames.top_experts, kwargs) + self._debug( + scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), kwargs + ) + self._debug( + top_experts, + "Router top experts", + kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), + kwargs, + ) return self._mlp_forward(hidden_states, scores, top_experts).view_as(input_), None # noqa @@ -119,7 +109,9 @@ def _forward_dropless( self, hidden_states: torch.Tensor, scores: torch.Tensor, top_experts: torch.Tensor ) -> torch.Tensor: # Compute token counts and the sparse mapping (dense_row, top_index) -> sparse_row. 
- sparse_map = get_sparse_map(top_experts, self._num_experts, dynamic_shape=self._dynamic_shape) + sparse_map = get_sparse_map( + top_experts, self._config.num_experts, dynamic_shape=self._config.dropless_dynamic_shape + ) # Sparse MLP return mlp_autograd( @@ -148,7 +140,7 @@ def _forward_looped( top_experts, self.layer_1.weight, self.layer_2.weight, - self._num_experts, + self._config.num_experts, self._config.gated, self._config.activation_type, self._intermediate_dim.parallel_group, @@ -159,7 +151,9 @@ def _forward_looped( @torch.compile def _apply_input_jitter(self, logits: torch.Tensor) -> torch.Tensor: - return logits * torch.empty_like(logits).uniform_(1.0 - self._moe_jitter_eps, 1.0 + self._moe_jitter_eps) + return logits * torch.empty_like(logits).uniform_( + 1.0 - self._config.moe_jitter_eps, 1.0 + self._config.moe_jitter_eps + ) def _topk_routing( self, @@ -167,11 +161,11 @@ def _topk_routing( grad_scale: float | None = None, losses: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - top_logits, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) + top_logits, top_experts = torch.topk(logits, k=self._config.num_experts_per_token, dim=-1) scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32) if losses is not None or (self.training and grad_scale is not None): probs = torch.softmax(logits, dim=-1, dtype=torch.float32) - mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1) + mask = torch.nn.functional.one_hot(top_experts, num_classes=self._config.num_unshared_experts).sum(dim=1) # Auxiliary loss, corresponding to the sum of probabilities for the top experts. # In the optimal case (uniform distribution), loss = experts_per_token / num_experts. # In the worst case (whole distribution in the top experts), loss = 1. @@ -182,7 +176,9 @@ def _topk_routing( losses[MLPLossNames.load_balancing_loss].append(aux_loss.detach()) if self.training and grad_scale is not None: scores = AuxiliaryLoss.apply( - scores, aux_loss, self._num_unshared_experts * self._load_balancing_factor * grad_scale + scores, + aux_loss, + self._config.num_unshared_experts * self._config.expert_auxiliary_loss_coefficient * grad_scale, ) return scores, top_experts @@ -191,69 +187,33 @@ def _add_shared_experts( ) -> tuple[torch.Tensor, torch.Tensor]: # Add the shared experts (last ones) to the top experts. shared_experts = torch.arange( - self._num_unshared_experts, self._num_experts, device=top_experts.device, dtype=top_experts.dtype + self._config.num_unshared_experts, + self._config.num_experts, + device=top_experts.device, + dtype=top_experts.dtype, )[None].repeat(top_experts.size(0), 1) top_experts = torch.cat((shared_experts, top_experts), dim=1) # Add scores of 1 to scores for shared experts. 
- scores = torch.cat((scores.new_ones(scores.size(0), self._num_shared_experts), scores), dim=1) + scores = torch.cat((scores.new_ones(scores.size(0), self._config.num_shared_experts), scores), dim=1) return scores, top_experts def _sinkhorn_routing(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if self.training: - _, top_experts = torch.topk(sinkhorn(logits), k=self._experts_per_token, dim=-1) + _, top_experts = torch.topk(sinkhorn(logits), k=self._config.num_experts_per_token, dim=-1) logits = self._sinkhorn_activation(logits) scores = torch.gather(logits, -1, top_experts) else: logits = self._sinkhorn_activation(logits) - scores, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) + scores, top_experts = torch.topk(logits, k=self._config.num_experts_per_token, dim=-1) return scores, top_experts def _sinkhorn_activation(self, logits: torch.Tensor) -> torch.Tensor: return ( torch.sigmoid(logits) - if self._experts_per_token == 1 + if self._config.num_experts_per_token == 1 else torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) ) - def _debug_log( - self, - tensor: torch.Tensor | None, - name: str, - dim_name: str, - kwargs: dict[str, typing.Any], - global_: bool = True, - ) -> None: - if self._config.debug_transformer_memory: - log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) - if self._config.debug_transformer and tensor is not None: - # TODO: Local vs global - meta = self._get_meta(tensor, name, dim_name, kwargs) - log_distributed_tensor( - "", - tensor.view_as(meta), - level=self._config.debug_transformer, - meta=meta, - distributed=self._tensor_space.distributed, - global_=global_, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._config.debug_transformer, - meta=self._get_meta(tensor, name + " grad", dim_name, kwargs), - distributed=self._tensor_space.distributed, - grad_fn=lambda tensor_: tensor_.view_as(meta), - global_=global_, - ) - - def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: - return TensorMeta.from_dims( - kwargs[BlockKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), - tensor_name=f"{self._name} {name}", - dtype=tensor.dtype, - ) - def sinkhorn(cost: torch.Tensor, tolerance: float = 1e-5, eps=1e-9) -> torch.Tensor: """Sinkhorn based MoE routing function""" diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 19349671e..aba5639b5 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -2,75 +2,77 @@ import torch -from fast_llm.config import Configurable -from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd -from fast_llm.layers.block.config import BlockConfig, BlockDimNames -from fast_llm.layers.block.mlp.config import MLPDimNames +from fast_llm.layers.block.block import BlockLayer +from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.block.mlp.config import MLPConfig, MLPDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase -from fast_llm.tensor import init_normal_, init_zeros_ -from fast_llm.utils import 
Assert, get_lr_scale +from fast_llm.utils import get_lr_scale -class MLPBase[ConfigType: BlockConfig](Configurable[ConfigType], Layer): - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): - super().__init__(config) - self._name = name - self._block_index = block_index +class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): + _name: typing.ClassVar[str] = "mlp" + + def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): + super().__init__(config, tensor_space, block_index, name) init_method_1 = init_normal_( - std=config.init_method_std_mlp_1, - min_val=config.init_method_min_mlp_1, - max_val=config.init_method_max_mlp_1, + std=self._config.init_method_std_mlp_1, + min_val=self._config.init_method_min_mlp_1, + max_val=self._config.init_method_max_mlp_1, ) init_method_2 = init_normal_( - std=config.init_method_std_mlp_2, - min_val=config.init_method_min_mlp_2, - max_val=config.init_method_max_mlp_2, + std=self._config.init_method_std_mlp_2, + min_val=self._config.init_method_min_mlp_2, + max_val=self._config.init_method_max_mlp_2, ) - hidden_dim = tensor_space[BlockDimNames.hidden] - self._intermediate_dim = tensor_space[MLPDimNames.composite_expert_mlp] - self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel + hidden_dim = self._tensor_space[BlockDimNames.hidden] + self._intermediate_dim = self._tensor_space[MLPDimNames.composite_expert_mlp] self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale + layer_lr_scale = ( + self._config.block.block_sequence.per_layer_lr_scale[self._block_index] + if self._config.block.block_sequence.per_layer_lr_scale + else None + ) + lr_scale = ( + tuple(self._config.mlp_lr_scale) + if isinstance(self._config.mlp_lr_scale, list) + else self._config.mlp_lr_scale + ) lr_scale = get_lr_scale(lr_scale, layer_lr_scale) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space[MLPDimNames.composite_gated_expert_mlp], - bias=config.add_mlp_bias, + self._tensor_space[MLPDimNames.composite_gated_expert_mlp], + bias=self._config.add_bias, weight_init_method=init_method_1, - bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, + bias_init_method=init_method_1 if self._config.random_bias_init else init_zeros_, lr_scale=lr_scale, ) self.layer_2 = LinearBase( self._intermediate_dim, hidden_dim, - bias=config.add_mlp_bias, + bias=self._config.add_bias, weight_init_method=init_method_2, - bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, - auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, + bias_init_method=init_method_2 if self._config.random_bias_init else init_zeros_, + auto_bias_grad_accumulation=self._tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, lr_scale=lr_scale, ) # PEFT. 
- self.layer_1 = config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) - self.layer_2 = config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) - + self.layer_1 = self._config.block.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) + self.layer_2 = self._config.block.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) -class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): - Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, block_index) +class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]): def forward( self, input_: torch.Tensor, diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 9d5ce3f3b..2f45fdf9f 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -87,7 +87,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig): ) def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.tensor import init_uniform_centered_ + from fast_llm.engine.config_utils.initialization import init_uniform_centered_ kwargs = { "hidden_dim": hidden_dim, diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index 7249ef569..740b4847c 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -3,6 +3,7 @@ import torch +from fast_llm.engine.config_utils.initialization import init_zeros_ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.linear import ( @@ -14,7 +15,7 @@ output_parallel_linear_backward, output_parallel_linear_forward, ) -from fast_llm.tensor import ParameterMeta, init_zeros_ +from fast_llm.tensor import ParameterMeta logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index bccc1d627..d44be3297 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -1,11 +1,12 @@ import torch +from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.config import NormalizationImplementation -from fast_llm.tensor import ParameterMeta, accumulate_gradient, init_ones_, init_zeros_ +from fast_llm.tensor import ParameterMeta, accumulate_gradient from fast_llm.utils import Assert try: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index b667e5318..2e7d71963 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,13 +1,11 @@ -import typing +import functools from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl 
-from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.layers.transformer.rotary.config import NoRotaryConfig +from fast_llm.layers.block.config import BlockDimNames, BlockKwargs, BlockSequenceConfig from fast_llm.utils import Assert @@ -46,27 +44,27 @@ class LanguageModelKwargs(BlockKwargs): @config_class() -class LanguageModelBaseConfig(BaseModelConfig): - # TODO: block - transformer: TransformerConfig = Field( - desc="Configuration for the transformer architecture.", +class LanguageModelConfig(BlockSequenceConfig): + decoder: BlockSequenceConfig = Field( hint=FieldHint.architecture, ) - max_position_embeddings: int = Field( - default=2048, - desc="Number of absolute position embeddings, if applicable.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) vocab_size: int = Field( default=49152, desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - use_position_embeddings: bool = Field( + embedding_dropout: float = Field( + # TODO: backward compatibility? + default=0.0, + desc="Dropout applied to the embedding layer.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + absolute_position_embeddings: int | None = Field( + # TODO: backward compatibility? default=None, - desc="Enable absolute position embeddings. Default: Enable unless using rotary embeddings.", + desc="Number of absolute position embeddings, if applicable.", hint=FieldHint.architecture, ) tie_word_embeddings: bool = Field( @@ -80,22 +78,6 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - init_method_std_embed: float = Field( - default=None, - desc="Initialization scale for the vocabulary embedding and output weights (logits).", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - init_method_max_embed: float | None = Field( - default=None, - desc="Max value for clamping initialized weights of the vocabulary embedding and output (logits).", - hint=FieldHint.feature, - ) - init_method_min_embed: float | None = Field( - default=None, - desc="Min value for clamping initialized weights of the vocabulary embedding and output (logits).", - hint=FieldHint.feature, - ) enable_dpo: bool | None = Field( default=False, desc="Whether to enable DPO loss", @@ -203,26 +185,27 @@ class LanguageModelBaseConfig(BaseModelConfig): doc="If not provided, all heads are equally weighted.", hint=FieldHint.feature, ) + word_embedding_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for word embeddings. Default: hidden_size**-0.5", + hint=FieldHint.feature, + ) + position_embedding_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for position embeddings. Default: hidden_size**-0.5", + hint=FieldHint.feature, + ) + output_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for untied output weights. 
Default: hidden_size**-0.5", + hint=FieldHint.feature, + ) def _validate(self) -> None: - self.transformer.validate() with self._set_implicit_default(): if self.language_model_loss_factor is None: if self.distillation_model is None: self.language_model_loss_factor = 1.0 else: self.language_model_loss_factor = 0.0 - if self.use_position_embeddings is None: - self.use_position_embeddings = isinstance(self.transformer.rotary, NoRotaryConfig) - if self.init_method_std_embed is None: - self.init_method_std_embed = self.transformer.init_method_std - if self.init_method_max_embed is None: - self.init_method_max_embed = self.transformer.init_method_max - if self.init_method_min_embed is None: - self.init_method_min_embed = self.transformer.init_method_min super()._validate() - if self.init_method_max_embed is not None and self.init_method_min_embed is not None: - Assert.leq(self.init_method_min_embed, self.init_method_max_embed) if self.distillation_model is not None: if self.prediction_heads > 1: raise NotImplementedError("Multi-token prediction not supported with distillation.") @@ -230,43 +213,40 @@ def _validate(self) -> None: Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) for coeff in self.prediction_loss_coefficient: Assert.geq(coeff, 0) - if self.transformer.per_layer_lr_scale is not None: - # -1 because the first prediction head's transformer layer is accounted for in num_layers - # +1 because the layer index starts at 1 - Assert.eq( - len(self.transformer.per_layer_lr_scale), self.transformer.num_layers + self.prediction_heads - 1 + 1 - ) + + if self.output_weight_initialization.has_initialization: + assert self.use_absolute_position_embeddings + if self.output_weight_initialization.has_initialization: + assert not self.tie_word_embeddings def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - self.transformer.setup_tensor_space(tensor_space) + super().setup_tensor_space(tensor_space) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Embedding dimensions - tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.position_embed, self.max_position_embeddings)) + if self.use_absolute_position_embeddings: + tensor_space.add_tensor_dim( + TensorDim(LanguageModelDimNames.position_embed, self.absolute_position_embeddings) + ) # TODO: Need both? tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab, self.vocab_size)) tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) - @property - def num_absolute_position_embeddings(self) -> int: - # TODO: Rename from max embeddings. - return self.max_position_embeddings if self.use_absolute_position_embeddings else None + @functools.cached_property + def word_embedding_weight_initialization_method(self) -> Initializer: + if self.word_embedding_weight_initialization.has_initialization: + return self.word_embedding_weight_initialization.get_initializer() + else: + return self.hidden_size**-0.5 @property def use_absolute_position_embeddings(self) -> int: # TODO: Set through num embeddings instead instead. - return self.use_position_embeddings - - @classmethod - def from_flat_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - ) -> typing.Self: - # The backward compatibility fix in `NormalizationArchitectureConfig` - # won't work for older checkpoints saved with a flat config. 
- # TODO v0.3: Remove flat format - cls._handle_renamed_field(default, "normalization_type", "type") - cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") - cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") - return super().from_flat_dict(default, strict) + return self.absolute_position_embeddings is not None + + @functools.cached_property + def output_weight_initialization_method(self) -> Initializer: + if self.output_weight_initialization.has_initialization: + return self.output_weight_initialization.get_initializer() + else: + return self.hidden_size**-0.5 diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 05678a700..b49fef7ba 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -7,28 +7,28 @@ from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs -from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +from fast_llm.layers.language_model.config import LanguageModelConfig, LanguageModelDimNames, LanguageModelKwargs +from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" -class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer): +class LanguageModelEmbedding[ConfigType: LanguageModelConfig](Configurable[ConfigType], Layer): """ A language model embedding layer. Consists of word embeddings (tensor-parallel or sequence-tensor-parallel), together with optional absolute position embeddings and dropout. """ - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig + config_class: typing.ClassVar[type[LanguageModelConfig]] = LanguageModelConfig # Ensure the layer is on its own stage. 
layer_count: float = 1000.0 def __init__( self, - config: LanguageModelBaseConfig, + config: LanguageModelConfig, tensor_space: TensorSpace, ): super().__init__(config) @@ -36,14 +36,14 @@ def __init__( self._tensor_space = tensor_space self._residual_dtype = ( self._distributed_config.optimization_dtype - if config.transformer.full_precision_residual + if self._config.full_precision_residual else self._distributed_config.training_dtype ).torch self._group_size = self._distributed_config.tensor_parallel self._sequence_parallel = self._distributed_config.sequence_tensor_parallel - self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings - self._dropout_p = config.transformer.hidden_dropout - self._use_absolute_position_embeddings = config.use_absolute_position_embeddings + self._parallel_embeddings = ( + tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings + ) hidden_dim = tensor_space[LanguageModelDimNames.hidden] vocab_dim = tensor_space[ @@ -56,23 +56,15 @@ def __init__( self.word_embeddings_weight = ParameterMeta.from_dims( (vocab_dim, hidden_dim), - init_method=init_normal_( - std=config.init_method_std_embed, - min_val=config.init_method_min_embed, - max_val=config.init_method_max_embed, - ), - lr_scale=config.embeddings_lr_scale, + init_method=self._config.word_embedding_weight_initialization_method, + lr_scale=self._config.embeddings_lr_scale, ) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( (tensor_space[LanguageModelDimNames.position_embed], hidden_dim), - init_method=init_normal_( - std=config.init_method_std_embed, - min_val=config.init_method_min_embed, - max_val=config.init_method_max_embed, - ), - allow_sequence_tensor_parallel=not config.parallel_embeddings, - lr_scale=config.embeddings_lr_scale, + init_method=self._config.position_embedding_weight_initialization_method, + allow_sequence_tensor_parallel=not self._config.parallel_embeddings, + lr_scale=self._config.embeddings_lr_scale, ) # PEFT. 
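Note on the embedding changes above and the `LanguageModelConfig` changes earlier in this file: the old `use_position_embeddings` boolean is replaced by `absolute_position_embeddings` (a count or `None`), and `use_absolute_position_embeddings` is now derived from whether that count is set. A minimal sketch of the switch, using a placeholder dataclass rather than the real config class:

from dataclasses import dataclass


@dataclass
class _LanguageModelConfigSketch:
    hidden_size: int = 1024
    absolute_position_embeddings: int | None = None  # e.g. 2048 to enable the feature

    @property
    def use_absolute_position_embeddings(self) -> bool:
        # Mirrors LanguageModelConfig.use_absolute_position_embeddings.
        return self.absolute_position_embeddings is not None


assert _LanguageModelConfigSketch(absolute_position_embeddings=2048).use_absolute_position_embeddings
assert not _LanguageModelConfigSketch().use_absolute_position_embeddings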
@@ -84,21 +76,21 @@ def __init__( @torch.compile def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor: - Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) + Assert.eq(position_ids is not None, self._config.use_absolute_position_embeddings) group = self._tensor_space.distributed.tensor_group if self._parallel_embeddings: input_mask = (input_ >= self._vocab_start_index) * (input_ < self._vocab_end_index) masked_input = (input_ - self._vocab_start_index) * input_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) # noqa embeddings = reduce_forward(embeddings, group) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: if self._sequence_parallel: input_ = split(input_, group=group, dim=0) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: position_ids = split(position_ids, group=group, dim=0) # handle masked tokens if mask_inputs: @@ -107,7 +99,7 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) else: embeddings = torch.embedding(self.word_embeddings_weight, input_) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if mask_inputs: embeddings = embeddings * input_mask.unsqueeze(2) @@ -116,7 +108,7 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask if self._sequence_parallel else self._tensor_space.distributed.pp_generator ): - embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + embeddings = torch.dropout(embeddings, self._config.embedding_dropout, self.training) return embeddings.to(dtype=self._residual_dtype) def forward( diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index bc672725c..098b2463b 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -15,16 +15,16 @@ from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward +from fast_llm.layers.block.block import DebugLayer from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.language_model.config import ( - LanguageModelBaseConfig, + LanguageModelConfig, LanguageModelDimNames, LanguageModelKwargs, LanguageModelLossNames, ) from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT -from fast_llm.logging import log_distributed_tensor -from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert, div, get_unique logger = logging.getLogger(__name__) @@ -32,61 +32,67 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer): +class LanguageModelHead[ConfigType: LanguageModelConfig](Configurable[ConfigType], 
Layer): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). """ - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig + config_class: typing.ClassVar[type[LanguageModelConfig]] = LanguageModelConfig def __init__( self, - config: LanguageModelBaseConfig, + config: LanguageModelConfig, tensor_space: TensorSpace, prediction_distance: int, ): super().__init__(config) - self._debug_transformer = config.transformer.debug_transformer - self._tie_word_embeddings = config.tie_word_embeddings + # TODO: Avoid default_block_config? + self._debug = DebugLayer( + tensor_space, + f"Block {self._block_index} {self._name}", + self._config.default_block_config.debug_transformer, + self._config.default_block_config.debug_transformer_memory, + ) self._tensor_space = tensor_space self._group_size = tensor_space.distributed_config.tensor_parallel self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel - self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_embeddings = ( + tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings + ) self._sequence_parallel_logits = ( - tensor_space.distributed_config.sequence_tensor_parallel and not config.parallel_embeddings + tensor_space.distributed_config.sequence_tensor_parallel and not self._config.parallel_embeddings ) - self._cross_entropy_splits = config.cross_entropy_splits + self._cross_entropy_splits = self._config.cross_entropy_splits if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings hidden_dim = self._tensor_space[LanguageModelDimNames.hidden] self._loss_coefficient = ( - config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 + self._config.prediction_loss_coefficient[prediction_distance] + if self._config.prediction_loss_coefficient + else 1.0 ) self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) - self.final_norm = config.transformer.normalization.get_layer(hidden_dim) - self._logits_scale_factor = config.logits_scale_factor - self._language_model_loss_factor = config.language_model_loss_factor - self._distillation_loss_factor = config.distillation_loss_factor - self._z_loss_factor = config.logit_z_loss + # TODO: Avoid default_block_config? 
+ self.final_norm = self._config.default_block_config.normalization.get_layer(hidden_dim) + self._logits_scale_factor = self._config.logits_scale_factor + self._language_model_loss_factor = self._config.language_model_loss_factor + self._distillation_loss_factor = self._config.distillation_loss_factor + self._z_loss_factor = self._config.logit_z_loss # Distance of the target token prediction # 0: next-token prediction # >0: multi-token prediction (MTP) Assert.geq(prediction_distance, 0) self._prediction_distance = prediction_distance - self._is_last_head = self._prediction_distance == config.prediction_heads - 1 + self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1 - self._init_output_weights(hidden_dim, config) + self._init_output_weights(hidden_dim, self._config) - self._use_dpo_loss = config.enable_dpo - if self._use_dpo_loss: - self.dpo_beta = config.dpo_beta - else: - self._cross_entropy_impl = config.cross_entropy_impl - self._distillation_loss_implementation = config.distillation_loss_implementation + if not self._config.enable_dpo: + self._cross_entropy_impl = self._config.cross_entropy_impl if self._cross_entropy_impl == CrossEntropyImpl.auto: if self._parallel_embeddings: self._cross_entropy_impl = CrossEntropyImpl.fused @@ -104,7 +110,7 @@ def __init__( def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: # Only the first head defines the output weights - if self._tie_word_embeddings or self._prediction_distance > 0: + if self._config.tie_word_embeddings or self._prediction_distance > 0: return # untie embedding weights vocab_dim = self._tensor_space[ @@ -112,11 +118,7 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: ] self.output_weights = ParameterMeta.from_dims( (vocab_dim, hidden_dim), - init_method=init_normal_( - std=config.init_method_std_embed, - min_val=config.init_method_min_embed, - max_val=config.init_method_max_embed, - ), + init_method=self._config.output_weight_initialization_method, lr_scale=config.output_lr_scale, ) @@ -201,7 +203,7 @@ def _get_targets( self, kwargs: dict ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None: # Loss mask for distillation. (Labels are already masked.) 
- if self._use_dpo_loss: + if self._config.enable_dpo: dpo_target = kwargs.get(LanguageModelKwargs.labels) lm_target = None distillation_target = None @@ -251,7 +253,7 @@ def _get_targets( return targets def _get_output_weights(self, kwargs: dict) -> torch.Tensor: - if self._tie_word_embeddings: + if self._config.tie_word_embeddings: return kwargs[WORD_EMBEDDINGS_WEIGHT] if self._prediction_distance > 0: return kwargs[OUTPUT_WEIGHTS] @@ -338,35 +340,22 @@ def _logits_cross_entropy_forward_backward( LanguageModelLossNames.z_loss, logits_scale_factor=self._logits_scale_factor, ) - if self._debug_transformer and self._cross_entropy_splits is None: - vocab_dim = self._tensor_space[ + if self._debug.enabled and self._cross_entropy_splits is None: + vocab_dim = ( LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ] - dims = [*kwargs[LanguageModelKwargs.hidden_dims][:-1], vocab_dim] - sequence_index = 1 - int(kwargs[LanguageModelKwargs.sequence_first]) - dims[sequence_index] = ( - TensorDim( - LanguageModelDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor - ) - if self._sequence_parallel_logits - else TensorDim(LanguageModelDimNames.sequence_q, dims[sequence_index].global_size) ) - - dim_names = ( - [LanguageModelDimNames.sequence_q_tp, LanguageModelDimNames.vocab] + sequence_dim = ( + LanguageModelDimNames.sequence_q_tp if self._sequence_parallel_logits - else [LanguageModelDimNames.sequence_q, LanguageModelDimNames.vocab_tp] + else LanguageModelDimNames.sequence_q ) - - dim_names.insert(int(kwargs[LanguageModelKwargs.sequence_first]), LanguageModelDimNames.batch) - log_distributed_tensor( - "", - logits, - level=self._debug_transformer, - meta=TensorMeta.from_dims(tuple(dims), tensor_name="transformer logits", dtype=logits.dtype), - distributed=self._tensor_space.distributed, - scale=self._logits_scale_factor, + batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] + dims = ( + (sequence_dim, batch_dim, vocab_dim) + if kwargs[LanguageModelKwargs.sequence_first] + else (batch_dim, sequence_dim, vocab_dim) ) + self._debug(logits, "Language model logits", dims, kwargs, scale=self._logits_scale_factor) if targets is None: return logits * self._logits_scale_factor, None @@ -379,7 +368,7 @@ def _logits_cross_entropy_forward_backward( kwargs.get(f"{self._config.dpo_reference_model}_logits"), kwargs[LanguageModelKwargs.chosen_spans], kwargs[LanguageModelKwargs.rejected_spans], - self.dpo_beta, + self._config.dpo_beta, grad_output * self._loss_coefficient, ) else: @@ -401,7 +390,7 @@ def _logits_cross_entropy_forward_backward( lm_loss, lm_grad = None, None if distillation_target is not None and self._distillation_loss_factor > 0.0: - if self._distillation_loss_implementation == DistillationLossImpl.reverse_kl: + if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, @@ -414,7 +403,7 @@ def _logits_cross_entropy_forward_backward( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits ), ) - elif self._distillation_loss_implementation == DistillationLossImpl.cross_entropy: + elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( logits.flatten(0, -2), distillation_target, diff --git 
a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index f5d915855..3c9f18c8d 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -5,7 +5,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelConfig, LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -20,11 +20,11 @@ class PositionEmbeddingPreprocessor(Preprocessor): def __init__( self, - config: LanguageModelBaseConfig, + config: LanguageModelConfig, tensor_space: TensorSpace, ): self._config = config - assert config.use_absolute_position_embeddings + assert config.absolute_position_embeddings is not None self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] @@ -34,7 +34,7 @@ def _create_tensors(self, sequence_length: int) -> None: return self._tensor_cache_max_sequence_length = sequence_length - Assert.leq(sequence_length, self._config.num_absolute_position_embeddings) + Assert.leq(sequence_length, self._config.absolute_position_embeddings) self._position_ids = torch.arange( 0, sequence_length, device=self._tensor_space.distributed.device, dtype=torch.int64 ) @@ -71,7 +71,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: class PreferenceSpanPreprocessor(Preprocessor): - def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): + def __init__(self, config: LanguageModelConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index efcf2d873..00c709814 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -9,7 +9,7 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: - from fast_llm.tensor import Initializer + from fast_llm.engine.config_utils.initialization import Initializer, init_fill_, init_uniform_centered_ class SSMDimNames(BlockDimNames): @@ -66,8 +66,6 @@ class DTInitType(enum.StrEnum): random = "random" def get_init_method(self, scale: float) -> "Initializer": - from fast_llm.tensor import init_fill_, init_uniform_centered_ - return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 550c44d0f..04b27af47 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -4,13 +4,15 @@ import einops import torch +from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, 
init_zeros_ +from fast_llm.layers.ssm.mamba_layer import init_kaiming_ +from fast_llm.tensor import ParameterMeta from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 1c319f490..b02fbd401 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,14 +3,15 @@ import torch +from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias, init_kaiming_ +from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert, div, get_lr_scale try: diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index f5b0139cf..e22852fe6 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -4,13 +4,14 @@ import torch +from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ +from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert, get_lr_scale try: @@ -163,3 +164,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if kwargs[BlockKwargs.sequence_first]: out = out.transpose(0, 1) return out, None + + +def init_kaiming_(d_in: float) -> LambdaInitializer: + return init_normal_(0.0, math.sqrt(2.0 / d_in)) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index b1de792e3..2db7b0ac8 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -6,11 +6,10 @@ from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.layers.block.mixer import Mixer +from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs -from fast_llm.tensor import init_normal_, init_zeros_ +from fast_llm.layers.transformer.config import AttentionConfig, AttentionDimNames, AttentionKwargs from fast_llm.utils import get_lr_scale try: @@ -46,55 +45,52 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None 
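Note on the `mamba_layer.py` hunk above: `init_kaiming_` moves out of `fast_llm.tensor` and is re-implemented as a normal initializer with std `sqrt(2 / d_in)`, which is the usual He/Kaiming-normal scaling in fan-in mode. A quick standalone check of that equivalence; the shapes and tolerance are arbitrary choices for the example.

import math

import torch

d_in = 4096
std = math.sqrt(2.0 / d_in)  # what init_kaiming_(d_in) uses via init_normal_(0.0, std)

w = torch.empty(8192, d_in)
torch.nn.init.normal_(w, mean=0.0, std=std)

ref = torch.empty(8192, d_in)
torch.nn.init.kaiming_normal_(ref, mode="fan_in", nonlinearity="relu")  # gain sqrt(2) / sqrt(fan_in)

# Both draws have standard deviation close to sqrt(2 / 4096) ≈ 0.0221.
assert abs(w.std().item() - std) < 1e-3
assert abs(ref.std().item() - std) < 1e-3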
-class Attention(Mixer): +class Attention[ConfigType: AttentionConfig](BlockLayer[ConfigType]): """ A self-attention layer. """ - _mixer_name: typing.ClassVar[str] = "attn" - _QUERY_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_heads, - TransformerDimNames.kv_channels, + AttentionDimNames.batch, + AttentionDimNames.sequence_q, + AttentionDimNames.composite_heads, + AttentionDimNames.kv_channels, ) _KV_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.head_groups, - TransformerDimNames.kv_channels, + AttentionDimNames.batch, + AttentionDimNames.sequence_q, + AttentionDimNames.head_groups, + AttentionDimNames.kv_channels, ) _CONTEXT_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_dense, + AttentionDimNames.batch, + AttentionDimNames.sequence_q, + AttentionDimNames.composite_dense, ) - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): - super().__init__(tensor_space, block_index, config.debug_transformer) + def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): + super().__init__(config, tensor_space, block_index, name) self._config = config self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) - init_method_qkv = init_normal_( - std=self._config.init_method_std_qkv, - min_val=self._config.init_method_min_qkv, - max_val=self._config.init_method_max_qkv, - ) - init_method_std_attn_proj = init_normal_( - std=self._config.init_method_std_attn_proj, - min_val=self._config.init_method_min_attn_proj, - max_val=self._config.init_method_max_attn_proj, - ) - - self._kv_channels = self._tensor_space[TransformerDimNames.kv_channels].size - self._head_groups = self._tensor_space[TransformerDimNames.head_groups].global_size - self._local_head_groups = self._tensor_space[TransformerDimNames.head_groups].size - self._local_heads_per_group = self._tensor_space[TransformerDimNames.group_heads].size + # init_method_qkv = init_normal_( + # std=self._config.init_method_std_qkv, + # min_val=self._config.init_method_min_qkv, + # max_val=self._config.init_method_max_qkv, + # ) + # init_method_std_attn_proj = init_normal_( + # std=self._config.init_method_std_attn_proj, + # min_val=self._config.init_method_min_attn_proj, + # max_val=self._config.init_method_max_attn_proj, + # ) + self._kv_channels = self._tensor_space[AttentionDimNames.kv_channels].size + self._head_groups = self._tensor_space[AttentionDimNames.head_groups].global_size + self._local_head_groups = self._tensor_space[AttentionDimNames.head_groups].size + self._local_heads_per_group = self._tensor_space[AttentionDimNames.group_heads].size self._local_heads = self._local_head_groups * self._local_heads_per_group - self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) + self._softmax_scale: float = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space[TransformerDimNames.hidden] + hidden_dim = self._tensor_space[AttentionDimNames.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) @@ -102,19 +98,19 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # TODO: Merge the query and key-value computations? 
(harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, - self._tensor_space[TransformerDimNames.composite_query], + self._tensor_space[AttentionDimNames.composite_query], bias=self._config.add_attn_qkv_bias, - weight_init_method=init_method_qkv, - bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, + weight_init_method=self._config.qkv_weight_initialization_method, + bias_init_method=self._config.qkv_bias_initialization_method, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space[TransformerDimNames.composite_key_value], + self._tensor_space[AttentionDimNames.composite_key_value], bias=self._config.add_attn_qkv_bias, - weight_init_method=init_method_qkv, - bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, + weight_init_method=self._config.qkv_weight_initialization_method, + bias_init_method=self._config.qkv_bias_initialization_method, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) @@ -125,11 +121,11 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # Output. self.dense = InputParallelLinear( - self._tensor_space[TransformerDimNames.composite_dense], + self._tensor_space[AttentionDimNames.composite_dense], hidden_dim, bias=self._config.add_attn_dense_bias, - weight_init_method=init_method_std_attn_proj, - bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_, + weight_init_method=self._config.dense_weight_initialization_method, + bias_init_method=self._config.dense_bias_initialization_method, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) @@ -259,18 +255,24 @@ def _decide_window_size(self) -> int | None: return window_size - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - sequence_first = kwargs[TransformerKwargs.sequence_first] + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + sequence_first = kwargs[AttentionKwargs.sequence_first] query, key_value = self._query_key_value(input_, sequence_first) # TODO: Move the rest to function. - if (past_key_values := kwargs.get(TransformerKwargs.past_key_values)) is not None: + if (past_key_values := kwargs.get(AttentionKwargs.past_key_values)) is not None: assert sequence_first # Clear the lists so tensors can be de-allocated key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) - if (presents := kwargs.get(TransformerKwargs.presents)) is not None: + if (presents := kwargs.get(AttentionKwargs.presents)) is not None: # Return the presents as a leaf tensors so the gradients from later micro-sequences # don't propagate to this one. 
presents.append(present := key_value.detach().requires_grad_()) @@ -279,9 +281,9 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._tensor_space.distributed.sequence_data_group: key_value = ( - key_value[: kwargs[TransformerKwargs.sequence_k_dim].size] + key_value[: kwargs[AttentionKwargs.sequence_k_dim].size] if sequence_first - else key_value[:, : kwargs[TransformerKwargs.sequence_k_dim].size] + else key_value[:, : kwargs[AttentionKwargs.sequence_k_dim].size] ) if sequence_first: @@ -295,9 +297,9 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug_level: - self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) - self._debug_log( + if self._debug.enabled: + self._debug(query, "query_rotary_input", self._QUERY_DIMS, kwargs) + self._debug( key, "key_rotary_input", self._KV_DIMS, @@ -310,7 +312,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._use_flash_attention: assert _flash_available with set_generator(self._tensor_space.distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(TransformerKwargs.cu_seqlens_q, None)) is not None: + if (cu_seqlens_q := kwargs.get(AttentionKwargs.cu_seqlens_q, None)) is not None: out_dims = query.size() query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) @@ -320,9 +322,9 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key, value, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(TransformerKwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(TransformerKwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), + cu_seqlens_k=kwargs.get(AttentionKwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(AttentionKwargs.max_seqlen_q), + max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), causal=True, @@ -345,25 +347,15 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ query.flatten(-2), key.flatten(-2), value.flatten(-2), - kwargs[TransformerKwargs.attention_mask], - kwargs[TransformerKwargs.attention_mask_value], + kwargs[AttentionKwargs.attention_mask], + kwargs[AttentionKwargs.attention_mask_value], ) - if self._debug_level: - self._debug_log(query, "query", self._QUERY_DIMS, kwargs) - self._debug_log( - key, - "key", - self._KV_DIMS, - kwargs, - ) - self._debug_log( - value, - "value", - self._KV_DIMS, - kwargs, - ) - self._debug_log(input_, "context", self._CONTEXT_DIMS, kwargs) + if self._debug.enabled: + self._debug(query, "query", self._QUERY_DIMS, kwargs) + self._debug(key, "key", self._KV_DIMS, kwargs) + self._debug(value, "value", self._KV_DIMS, kwargs) + self._debug(input_, "context", self._CONTEXT_DIMS, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) 
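The attention rewrite above reads its variable-length attention inputs from `AttentionKwargs` (`cu_seqlens_q`, `cu_seqlens_k`, `max_seqlen_q`, `max_seqlen_k`), which the `FlashAttnVarlenPreprocessor` further down in this patch computes from per-document sequence lengths. A minimal standalone sketch of that computation for the simple case (micro-sequence equal to the full sequence) is shown below; the helper name and the plain-string keys are illustrative only and are not part of the patch.

```python
# Illustrative sketch (not part of the patch): how the varlen kwargs consumed by
# Attention.forward can be derived from per-document lengths, mirroring the simple
# branch of FlashAttnVarlenPreprocessor (sequence_q == sequence_length).
import torch


def build_varlen_kwargs(sequence_lengths: list[torch.Tensor], device="cpu") -> dict[str, torch.Tensor]:
    # Concatenate the document lengths of every sample in the micro-batch.
    seqlens = torch.cat(sequence_lengths)
    # Cumulative lengths with a leading zero, as expected by flash-attn's varlen interface.
    cu_seqlens = torch.cat(
        (
            torch.zeros(1, dtype=torch.int32, device=device),
            torch.cumsum(seqlens, dim=0, dtype=torch.int32).to(device),
        )
    )
    return {
        "cu_seqlens_q": cu_seqlens,
        "cu_seqlens_k": cu_seqlens,
        "max_seqlen_q": seqlens.max(),
        "max_seqlen_k": seqlens.max(),
    }


# Example: two samples, holding documents of lengths (3, 5) and (4, 4).
kwargs = build_varlen_kwargs([torch.tensor([3, 5]), torch.tensor([4, 4])])
# kwargs["cu_seqlens_q"] -> tensor([0, 3, 8, 12, 16], dtype=torch.int32)
```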
diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index ebb976e63..bd72bd305 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -3,22 +3,29 @@ import typing import warnings -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_zeros_ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import TritonConfig -from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.block.config import ( + AddLinearBiasChoices, + BlockDimNames, + BlockKwargs, + BlockLayerConfig, + MixerConfig, +) from fast_llm.layers.transformer.rotary.config import RotaryConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: - pass + from fast_llm.layers.transformer.attention import Attention logger = logging.getLogger(__name__) -class TransformerDimNames(BlockDimNames): +class AttentionDimNames(BlockDimNames): # A set of common tensor dim names packed into a namespace. # Self-attention dimensions head_groups = "head_groups" @@ -31,7 +38,7 @@ class TransformerDimNames(BlockDimNames): composite_dense = "composite_dense" -class TransformerKwargs(BlockKwargs): +class AttentionKwargs(BlockKwargs): rotary_freq_q = "rotary_freq_q" rotary_freq_k = "rotary_freq_k" attention_mask = "attention_mask" @@ -45,9 +52,8 @@ class TransformerKwargs(BlockKwargs): past_key_values = "past_key_values" -@config_class() -class AttentionConfig(Config): - # TODO: Make mixer class dynamic. +@config_class(dynamic_type={BlockLayerConfig: "attention"}) +class AttentionConfig(MixerConfig): _abstract = False # TODO: Review names @@ -107,7 +113,30 @@ class AttentionConfig(Config): valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) + qkv_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the query, key and value layer weights. Default: hidden_size**-0.5", + hint=FieldHint.feature, + ) + qkv_bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the query, key and value layer biases. Default: fill with zeros.", + hint=FieldHint.feature, + ) + dense_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the dense layer weight. Default: (2 * num_blocks * hidden_size)**-0.5", + hint=FieldHint.feature, + ) + dense_bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the dense layer biases. Default: fill with zeros.", + hint=FieldHint.feature, + ) + def _validate(self) -> None: + + with self._set_implicit_default(): + if self.kv_channels is None: + # TODO: hidden_size not yet validated. 
+ self.kv_channels = div(self.block.block_sequence.hidden_size, self.num_attention_heads) + super()._validate() if not TritonConfig.TRITON_ENABLED: @@ -130,182 +159,74 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim( head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + AttentionDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - TransformerDimNames.group_heads, + AttentionDimNames.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) - ) + tensor_space.add_tensor_dim(key_and_value := TensorDim(AttentionDimNames.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(AttentionDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(CompositeTensorDim(AttentionDimNames.composite_heads, (head_groups, group_heads))) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(AttentionDimNames.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim(AttentionDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(AttentionDimNames.composite_dense, (head_groups, group_heads, kv_channels)) ) + def get_block(self) -> "Attention": + pass -@config_class() -# TODO: Use composition for attention config -class TransformerConfig(AttentionConfig, BlockConfig): - _abstract = False - - # TODO: Review names - init_method_std: float = Field( - default=None, - desc="Default scale for weight initialization. Default: hidden_size**-0.5", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max: float | None = Field( - default=None, - desc="Max value for clamping initialized weights. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min: float | None = Field( - default=None, - desc="Min value for clamping initialized weights. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_qkv: float = Field( - default=None, - desc="Scale for the query, key and value weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_qkv: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for query, key and value matrices. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_qkv: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for query, key and value matrices. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_attn_proj: float = Field( - default=None, - desc="Scale for the attention projection weight initialization. 
Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_attn_proj: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for attention projection. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_attn_proj: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_mlp_1: float = Field( - default=None, - desc="Scale for the MLP first layer weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_1: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP first layer. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_mlp_1: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_mlp_2: float = Field( - default=None, - desc="Scale for the MLP second layer weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_2: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_mlp_2: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", - hint=FieldHint.optional, - ) - # Use random inits instead of constant values, useful for debugging. - random_bias_init: bool = Field( - default=False, - desc="Initialize the biases using the initialization method of their respective weights instead of setting them to zero. 
Used to test for issues that may not be visible when the biases are zero.", - hint=FieldHint.testing, - ) - - def _validate(self) -> None: - with self._set_implicit_default(): - if self.kv_channels is None: - self.kv_channels = div(self.hidden_size, self.num_attention_heads) - if self.init_method_std is None: - self.init_method_std = self.hidden_size**-0.5 - if self.init_method_std_qkv is None: - self.init_method_std_qkv = self.init_method_std - if self.init_method_std_attn_proj is None: - self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 - if self.init_method_std_mlp_1 is None: - self.init_method_std_mlp_1 = self.init_method_std - if self.init_method_std_mlp_2 is None: - self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 - if self.init_method_max_qkv is None: - self.init_method_max_qkv = self.init_method_max - if self.init_method_min_qkv is None: - self.init_method_min_qkv = self.init_method_min - if self.init_method_max_attn_proj is None: - self.init_method_max_attn_proj = self.init_method_max - if self.init_method_min_attn_proj is None: - self.init_method_min_attn_proj = self.init_method_min - if self.init_method_max_mlp_1 is None: - self.init_method_max_mlp_1 = self.init_method_max - if self.init_method_min_mlp_1 is None: - self.init_method_min_mlp_1 = self.init_method_min - if self.init_method_max_mlp_2 is None: - self.init_method_max_mlp_2 = self.init_method_max - if self.init_method_min_mlp_2 is None: - self.init_method_min_mlp_2 = self.init_method_min - if self.init_method_min is not None and self.init_method_max is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) - if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: - Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) - if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: - Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) - if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: - Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) - - super()._validate() - - @property - def add_attn_qkv_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.nowhere: + @functools.cached_property + def add_qkv_bias(self) -> bool: + if isinstance(self.block.add_linear_biases, bool): + return self.block.add_linear_biases + if self.block.add_linear_biases == AddLinearBiasChoices.nowhere: return False return True - @property - def add_attn_dense_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: + @functools.cached_property + def add_dense_bias(self) -> bool: + if isinstance(self.block.add_linear_biases, bool): + return self.block.add_linear_biases + if self.block.add_linear_biases == AddLinearBiasChoices.everywhere: return True return False + + @functools.cached_property + def qkv_weight_initialization_method(self) -> Initializer: + if self.qkv_weight_initialization.has_initialization: + return 
self.qkv_weight_initialization.get_initializer() + else: + return self.block.block_sequence.hidden_size**-0.5 + + @functools.cached_property + def qkv_bias_initialization_method(self) -> Initializer: + if self.qkv_bias_initialization.has_initialization: + assert self.add_qkv_bias + return self.qkv_bias_initialization.get_initializer() + else: + return init_zeros_ + + @functools.cached_property + def dense_weight_initialization_method(self) -> Initializer: + if self.dense_weight_initialization.has_initialization: + return self.dense_weight_initialization.get_initializer() + else: + return self.block.block_sequence.hidden_size**-0.5 / max(2 * self.block.block_sequence.num_blocks, 1) + + @functools.cached_property + def dense_bias_initialization_method(self) -> Initializer: + if self.dense_bias_initialization.has_initialization: + assert self.add_dense_bias + return self.dense_bias_initialization.get_initializer() + else: + return init_zeros_ diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 3f0e14eb7..16e5811e6 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -5,7 +5,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -21,7 +21,7 @@ class BackupAttentionPreprocessor(Preprocessor): def __init__( self, - config: TransformerConfig, + config: AttentionConfig, tensor_space: TensorSpace, ): self._config = config @@ -51,13 +51,13 @@ def _create_tensors(self, sequence_length: int) -> None: ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - kwargs[TransformerKwargs.attention_mask] = self._mask[ + self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size + kwargs[AttentionKwargs.attention_mask] = self._mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths, None)) is not None: + if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: seq_ids = torch.stack( [ torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) @@ -65,33 +65,33 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: ] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) - kwargs[TransformerKwargs.attention_mask] = ( - kwargs[TransformerKwargs.attention_mask] + kwargs[AttentionKwargs.attention_mask] = ( + kwargs[AttentionKwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] ) - kwargs[TransformerKwargs.attention_mask_value] = self._mask_value + kwargs[AttentionKwargs.attention_mask_value] = self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( + kwargs[AttentionKwargs.attention_mask] = 
TensorMeta.from_dims( ( self._scalar_dim, self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, - kwargs[TransformerKwargs.sequence_k_dim], + kwargs[AttentionKwargs.sequence_k_dim], ), - tensor_name=TransformerKwargs.attention_mask, + tensor_name=AttentionKwargs.attention_mask, dtype=torch.bool, ) - kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( + kwargs[AttentionKwargs.attention_mask_value] = TensorMeta.from_dims( (self._scalar_dim,), - tensor_name=TransformerKwargs.attention_mask_value, + tensor_name=AttentionKwargs.attention_mask_value, dtype=self._tensor_space.distributed_config.training_dtype.torch, ) class FlashAttnVarlenPreprocessor(Preprocessor): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): + def __init__(self, config: AttentionConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config @@ -107,12 +107,12 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: also contain previous tokens from the first document in micro-sequence. We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. """ - if TransformerKwargs.sequence_lengths not in kwargs: + if AttentionKwargs.sequence_lengths not in kwargs: return - sequence_lengths = kwargs[TransformerKwargs.sequence_lengths] - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - if sequence_q < kwargs[TransformerKwargs.sequence_length]: + sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size + if sequence_q < kwargs[AttentionKwargs.sequence_length]: cumsums = [torch.cumsum(x, dim=0) for x in sequence_lengths] # The first and last documents in a microsequence need to be handled separately. Include all tokens from other documents # in the microsequence. We need to consider all keys computed so far from the first sample. 
We also store the offsets @@ -146,17 +146,17 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: else: seqlens_q = torch.cat(sequence_lengths) seqlens_k = torch.cat(sequence_lengths) - kwargs[TransformerKwargs.cu_seqlens_q] = torch.cat( + kwargs[AttentionKwargs.cu_seqlens_q] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.cu_seqlens_k] = torch.cat( + kwargs[AttentionKwargs.cu_seqlens_k] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() - kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() + kwargs[AttentionKwargs.max_seqlen_q] = seqlens_q.max() + kwargs[AttentionKwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py index c357411b6..9f8732f85 100644 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ b/fast_llm/layers/transformer/rotary/preprocessing.py @@ -4,7 +4,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig from fast_llm.tensor import TensorMeta @@ -26,34 +26,34 @@ def __init__( self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] + self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k + self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k ] - kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_q, + tensor_name=AttentionKwargs.rotary_freq_q, ) - kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + 
kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_k, + tensor_name=AttentionKwargs.rotary_freq_k, ) def _create_tensors(self, sequence_length: int) -> None: diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 17b18a1ca..ebb629aa1 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -8,7 +8,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs from fast_llm.layers.transformer.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, @@ -83,44 +83,44 @@ def __init__( self._tensor_space = tensor_space if self._tensor_space is not None: self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] + self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k + self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k ] - kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None - kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_q, + tensor_name=AttentionKwargs.rotary_freq_q, ) - kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_k, + tensor_name=AttentionKwargs.rotary_freq_k, ) def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: rotary_fn = triton_rotary_autograd_ if self._config.triton else apply_rotary_embeddings - query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q]) - key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k]) + query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) + key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key def _create_tensors(self, sequence_length: int) -> None: diff 
--git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index 3c0ad8ab4..eb24ef183 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -36,7 +36,7 @@ def get_layers(self) -> list[Layer]: self._tensor_space, block_index=i + 1, ) - for i in range(self._config.transformer.num_layers) + for i in range(self._config.transformer.num_blocks) ], CustomHead(self._config, self._tensor_space), ] diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 0da16428e..a7fcad82d 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -9,7 +9,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig +from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds from fast_llm.utils import Assert, div @@ -119,7 +119,7 @@ def micro_batch_splits(self) -> int: @config_class() -class GPTBaseModelConfig(LanguageModelBaseConfig): +class GPTBaseModelConfig(LanguageModelConfig): _abstract = False # Debug, to get an exact match with megatron init. @@ -192,15 +192,12 @@ class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() def _validate(self) -> None: - if self.batch.sequence_length is None: - # TODO: Drop this. - self.batch.sequence_length = self.model.base_model.max_position_embeddings if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) super()._validate() if self.model.base_model.use_absolute_position_embeddings: - Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) + Assert.geq(self.model.base_model.absolute_position_embeddings, self.batch.sequence_length) distillation_model = self.model.base_model.distillation_model dpo_reference_model = self.model.base_model.dpo_reference_model diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 2dbef77f3..f3e57fe13 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -176,7 +176,7 @@ def _create_weight_converters( self, ) -> list[WeightConverter]: converters = [] - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks # Embeddings converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) @@ -256,7 +256,7 @@ def _create_transformer_layer_converters( return converters def _create_lm_head_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] @@ -654,7 +654,7 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig # Override base method to handle the MTP heads def _create_lm_head_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks 
prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index cf7da3872..4e3f258fc 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -9,7 +9,7 @@ from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner @@ -86,12 +86,12 @@ def forward( if past_key_values is not None: # The transformers will use the past keys and values to this list. - kwargs[TransformerKwargs.past_key_values] = past_key_values + kwargs[AttentionKwargs.past_key_values] = past_key_values # TODO: preprocess needs to know about the past. raise NotImplementedError() if use_cache: # The transformers will save the present keys and values to this list. - kwargs[TransformerKwargs.presents] = [] + kwargs[AttentionKwargs.presents] = [] if output_hidden_states: kwargs["output_hidden_states"] = True @@ -117,11 +117,11 @@ def forward( outputs = (logits,) if use_cache: - outputs += (kwargs[TransformerKwargs.presents],) + outputs += (kwargs[AttentionKwargs.presents],) return outputs return transformers.modeling_outputs.CausalLMOutputWithPast( logits=logits, hidden_states=hidden_states, - past_key_values=kwargs[TransformerKwargs.presents], + past_key_values=kwargs[AttentionKwargs.presents], ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index da647de57..30842597d 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -16,7 +16,7 @@ from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor from fast_llm.layers.transformer.block import TransformerBlock -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -68,7 +68,7 @@ def get_output_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, # TODO MTP: which index? - block_index=max(self._config.transformer.num_layers + i, 1), + block_index=max(self._config.transformer.num_blocks + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -93,9 +93,9 @@ def get_layers(self) -> list[Layer]: block_index=i + 1, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. 
- return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, + return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_blocks - 1, ) - for i in range(self._config.transformer.num_layers) + for i in range(self._config.transformer.num_blocks) ], *self.get_output_layers(), ] @@ -119,7 +119,7 @@ def preprocess_meta( truncate_documents = True batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) - batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) + batch_dim = TensorDim(AttentionDimNames.batch, micro_batch_size * batch_data.size, batch_data) if micro_sequence_length is None: micro_sequence_length = sequence_length @@ -128,13 +128,13 @@ def preprocess_meta( # TODO: Calculate hidden dims elsewhere? sequence_q_dim = TensorDim( - TransformerDimNames.sequence_q, + AttentionDimNames.sequence_q, micro_sequence_length, self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), ) hidden_sequence_q_dim = ( TensorDim( - TransformerDimNames.sequence_q_tp, + AttentionDimNames.sequence_q_tp, micro_sequence_length, self._tensor_space.distributed_config.get_distributed_dim( DistributedDimNames.tensor_and_sequence_data @@ -151,7 +151,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space[TransformerDimNames.hidden] + hidden_dim = self._tensor_space[AttentionDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first @@ -160,10 +160,10 @@ def preprocess_meta( common_kwargs = { LanguageModelKwargs.phase: phase, - TransformerKwargs.sequence_first: sequence_first, - TransformerKwargs.hidden_dims: hidden_dims, - TransformerKwargs.sequence_length: sequence_length, - TransformerKwargs.sequence_q_dim: sequence_q_dim, + AttentionKwargs.sequence_first: sequence_first, + AttentionKwargs.hidden_dims: hidden_dims, + AttentionKwargs.sequence_length: sequence_length, + AttentionKwargs.sequence_q_dim: sequence_q_dim, LanguageModelKwargs.mask_inputs: not truncate_documents, } @@ -182,7 +182,7 @@ def preprocess_meta( preprocessed_meta = [] for i, sequence_k_past in enumerate(sequence_k_pasts): sequence_k = sequence_k_past + sequence_q_dim.size - sequence_k_dim = TensorDim(TransformerDimNames.sequence_k, sequence_k) + sequence_k_dim = TensorDim(AttentionDimNames.sequence_k, sequence_k) tokens = TensorMeta.from_dims( hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 @@ -190,7 +190,7 @@ def preprocess_meta( kwargs = { **common_kwargs, - TransformerKwargs.sequence_k_dim: sequence_k_dim, + AttentionKwargs.sequence_k_dim: sequence_k_dim, } if phase != PhaseType.inference: kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims( @@ -202,10 +202,10 @@ def preprocess_meta( for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] for key in ( - TransformerKwargs.sequence_first, - TransformerKwargs.sequence_length, - TransformerKwargs.sequence_q_dim, - TransformerKwargs.sequence_k_dim, + AttentionKwargs.sequence_first, + AttentionKwargs.sequence_length, + AttentionKwargs.sequence_q_dim, + AttentionKwargs.sequence_k_dim, ): Assert.eq(reference_kwargs_[key], kwargs[key]) reference_kwargs[name] = reference_kwargs_ @@ -231,8 +231,8 @@ def preprocess( preprocessed_meta = 
self.preprocess_meta(batch.token_ids, phase) _, common_kwargs = preprocessed_meta[0] - sequence_q = common_kwargs[TransformerKwargs.sequence_q_dim].size - sequence_first = common_kwargs[TransformerKwargs.sequence_first] + sequence_q = common_kwargs[AttentionKwargs.sequence_q_dim].size + sequence_first = common_kwargs[AttentionKwargs.sequence_first] prediction_heads: int = self._config.prediction_heads batch.token_ids = batch.token_ids.to( @@ -264,14 +264,14 @@ def preprocess( preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): - sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size + sequence_k = kwargs_meta[AttentionKwargs.sequence_k_dim].size if sequence_first: tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] else: # TODO: Avoid multiple contiguous calls? tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: - kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths + kwargs_meta[AttentionKwargs.sequence_lengths] = batch.sequence_lengths if batch.chosen_spans is not None: kwargs_meta[LanguageModelKwargs.chosen_spans] = batch.chosen_spans if batch.rejected_spans is not None: @@ -283,8 +283,8 @@ def preprocess( presents = None if i == len(preprocessed_meta) - 1 else [] kwargs = { **kwargs_meta, - TransformerKwargs.past_key_values: pasts, - TransformerKwargs.presents: presents, + AttentionKwargs.past_key_values: pasts, + AttentionKwargs.presents: presents, } if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels @@ -372,7 +372,7 @@ def loss_defs(self) -> list[LossDef]: LossDef( name=MLPLossNames.load_balancing_loss, formatted_name="load balancing loss", - count=self._config.transformer.num_layers, + count=self._config.transformer.num_blocks, ) ) if self._config.transformer.expert_z_loss_coefficient: @@ -380,7 +380,7 @@ def loss_defs(self) -> list[LossDef]: LossDef( name=MLPLossNames.router_z_loss, formatted_name="router z loss", - count=self._config.transformer.num_layers, + count=self._config.transformer.num_blocks, ) ) if self._config.logit_z_loss: @@ -421,7 +421,7 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, s consumed_tokens_per_iteration = sequence_length * batch_size - num_transformer_layers = transformer_config.num_layers + self._config.base_model.prediction_heads - 1 + num_transformer_layers = transformer_config.num_blocks + self._config.base_model.prediction_heads - 1 transformer_flops_base = ( 2 * checkpoint_activations_factor * consumed_tokens_per_iteration * num_transformer_layers ) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 9427f69be..a351522ca 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -62,13 +62,13 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_blocks - if len(self.hybrid_block_layout) != self.transformer.num_layers: - message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: + if len(self.hybrid_block_layout) != self.transformer.num_blocks: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does 
not match num_layers {self.transformer.num_blocks}" + if self.transformer.num_blocks % len(self.hybrid_block_layout) != 0: raise ValueError(message) - num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) + num_repeats = self.transformer.num_blocks // len(self.hybrid_block_layout) logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 43e3c67e5..fb24c1aec 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -219,7 +219,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: converters = super()._create_weight_converters() or [] - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear for i in range(num_layers): @@ -383,7 +383,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: # not using super() because LLamba model is called backbone in the checkpoints converters = [] - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks norm_bias: bool = False ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear @@ -572,7 +572,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: converters = super()._create_weight_converters() - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks norm_bias: bool = False # Embedding and output diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index d080e6a1e..b12d12072 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,13 +1,12 @@ -import abc import functools import logging -import math import typing import torch from fast_llm.core.distributed import ReduceOp from fast_llm.core.ops import reduce_op +from fast_llm.engine.config_utils.initialization import Initializer, LambdaInitializer from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed @@ -361,70 +360,3 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None: triton_copy(grad, param.grad_buffer) # noqa else: triton_add(grad, param.grad_buffer, out=param.grad_buffer) # noqa - - -class Initializer(abc.ABC): - @abc.abstractmethod - def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: - pass - - requires_global_initialization = False - - -class LambdaInitializer(Initializer): - def __init__( - self, - init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None], - requires_global_initialization: bool = False, - ) -> None: - self._init_method = init_method - self.requires_global_initialization = requires_global_initialization - - def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: - return self._init_method(meta, tensor, generator) - - -def init_fill_(value: float) -> LambdaInitializer: - def init_(meta: ParameterMeta, tensor: 
torch.Tensor, generator: torch.Generator) -> None: # noqa - tensor.fill_(value) - - return LambdaInitializer(init_) - - -init_zeros_ = init_fill_(0.0) -init_ones_ = init_fill_(1.0) - - -def init_normal_( - mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None -) -> LambdaInitializer: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - tensor = tensor.normal_(mean, std, generator=generator) - if min_val is not None or max_val is not None: - tensor.clamp_(min=min_val, max=max_val) - - return LambdaInitializer(init_) - - -def init_kaiming_(d_in: float) -> LambdaInitializer: - return init_normal_(0.0, math.sqrt(2.0 / d_in)) - - -def init_uniform_( - low: float = 0.0, high: float = 1.0, min_val: float | None = None, max_val: float | None = None -) -> LambdaInitializer: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - tensor = tensor.uniform_(low, high, generator=generator) - if min_val is not None or max_val is not None: - tensor.clamp_(min=min_val, max=max_val) - - return LambdaInitializer(init_) - - -def init_uniform_centered_(high: float, max_val: float | None = None, mean: float = 0.0) -> LambdaInitializer: - return init_uniform_( - mean - high, - mean + high, - min_val=None if max_val is None else mean - max_val, - max_val=None if max_val is None else mean + max_val, - ) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 9a878c494..8c33aed4d 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -9,7 +9,7 @@ from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -198,8 +198,8 @@ def test_lm_head( else: loss_mask = None kwargs = { - TransformerKwargs.sequence_first: sequence_first, - TransformerKwargs.grad_output: 1.0, + AttentionKwargs.sequence_first: sequence_first, + AttentionKwargs.grad_output: 1.0, } if config.distillation_model is None: target = torch.randint( diff --git a/tests/models/test_generate.py b/tests/models/test_generate.py index 7f0b902f8..cb9c69ccb 100644 --- a/tests/models/test_generate.py +++ b/tests/models/test_generate.py @@ -354,7 +354,7 @@ def _test_forward_return_hidden_states( # hidden_states include embeddings layer assert ( - len(res_fast_llm.hidden_states) - 1 == fast_llm_model.config.fast_llm_config.base_model.transformer.num_layers + len(res_fast_llm.hidden_states) - 1 == fast_llm_model.config.fast_llm_config.base_model.transformer.num_blocks ) diff --git a/tests/test_attention.py b/tests/test_attention.py index dd36b840a..534e3800e 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -6,7 +6,7 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, 
TransformerConfig from fast_llm.layers.transformer.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.utils import Assert @@ -77,13 +77,13 @@ def test_varlen_preprocessor(): varlen_preprocessor = FlashAttnVarlenPreprocessor(transformer_cfg, tensor_space=tensor_space) for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): kwargs = { - TransformerKwargs.sequence_q_dim: TensorDim(TransformerDimNames.sequence_k, micro_sequence_length), - TransformerKwargs.sequence_k_dim: TensorDim( - TransformerDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length + AttentionKwargs.sequence_q_dim: TensorDim(AttentionDimNames.sequence_k, micro_sequence_length), + AttentionKwargs.sequence_k_dim: TensorDim( + AttentionDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length ), - TransformerKwargs.sequence_length: sequence_length, - TransformerKwargs.sequence_lengths: sequence_lengths, + AttentionKwargs.sequence_length: sequence_length, + AttentionKwargs.sequence_lengths: sequence_lengths, } varlen_preprocessor.preprocess(None, kwargs) - Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) - Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) + Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) + Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 694faa55b..6c4c7f0cb 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -9,7 +9,7 @@ from fast_llm.engine.schedule.config import ScheduleConfig from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat from fast_llm.models.ssm.model import HybridSSMModel @@ -71,8 +71,8 @@ def test_load_from_llamba_checkpoint(): schedule_runner.setup(model.distributed, optimizer=None) common_kwargs = { - TransformerKwargs.sequence_first: True, - TransformerKwargs.grad_output: False, + AttentionKwargs.sequence_first: True, + AttentionKwargs.grad_output: False, } input_data = [(x, common_kwargs)] diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 722d8d63a..4705ebb79 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -162,6 +162,7 @@ def _update_and_add_testing_config( "model.base_model.transformer.num_attention_heads=8", "model.base_model.transformer.head_groups=8", "model.base_model.transformer.init_method_std=0.022", + "model.base_model.transformer.use_position_embeddings=True", f"model.base_model.vocab_size={MODEL_TEST_VOCAB_SIZE}", f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", @@ -258,6 +259,7 @@ def _update_and_add_testing_config( extra_args=[ "model.base_model.transformer.head_groups=4", "model.base_model.transformer.rotary.type=default", + "model.base_model.transformer.use_position_embeddings=False", # Unused, but prevents issues with conversion tests. 
"model.base_model.max_position_embeddings=2048", ], From ab484ac94555915bd2d808279d5909de42541550 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 15:29:57 -0400 Subject: [PATCH 35/82] Revert "stuff" This reverts commit a5eb0767e99038e18c1bd07f7f78718634296c4c. --- docs/developer_guide/conversion.md | 30 +-- .../engine/config_utils/initialization.py | 178 ------------ fast_llm/layers/block/block.py | 132 ++------- fast_llm/layers/block/config.py | 160 ++--------- fast_llm/layers/block/mixer.py | 68 +++++ fast_llm/layers/block/mlp/config.py | 79 +----- .../layers/block/mlp/mixture_of_experts.py | 134 ++++++---- fast_llm/layers/block/mlp/mlp.py | 72 +++-- fast_llm/layers/common/config.py | 2 +- fast_llm/layers/common/linear.py | 3 +- fast_llm/layers/common/normalization.py | 3 +- fast_llm/layers/language_model/config.py | 122 +++++---- fast_llm/layers/language_model/embedding.py | 48 ++-- fast_llm/layers/language_model/head.py | 109 ++++---- .../layers/language_model/preprocessing.py | 10 +- fast_llm/layers/ssm/config.py | 4 +- fast_llm/layers/ssm/discrete_mamba2.py | 4 +- fast_llm/layers/ssm/mamba2.py | 5 +- fast_llm/layers/ssm/mamba_layer.py | 7 +- fast_llm/layers/transformer/attention.py | 142 +++++----- fast_llm/layers/transformer/config.py | 253 ++++++++++++------ fast_llm/layers/transformer/preprocessing.py | 52 ++-- .../transformer/rotary/preprocessing.py | 26 +- fast_llm/layers/transformer/rotary/rotary.py | 30 +-- fast_llm/models/custom/model.py | 2 +- fast_llm/models/gpt/config.py | 9 +- fast_llm/models/gpt/conversion.py | 6 +- fast_llm/models/gpt/huggingface.py | 10 +- fast_llm/models/gpt/model.py | 54 ++-- fast_llm/models/ssm/config.py | 10 +- fast_llm/models/ssm/conversion.py | 6 +- fast_llm/tensor.py | 70 ++++- tests/layers/test_lm_head.py | 6 +- tests/models/test_generate.py | 2 +- tests/test_attention.py | 16 +- tests/test_ssms.py | 6 +- tests/utils/model_configs.py | 2 - 37 files changed, 857 insertions(+), 1015 deletions(-) delete mode 100644 fast_llm/engine/config_utils/initialization.py create mode 100644 fast_llm/layers/block/mixer.py diff --git a/docs/developer_guide/conversion.md b/docs/developer_guide/conversion.md index 719757df1..0620beaea 100644 --- a/docs/developer_guide/conversion.md +++ b/docs/developer_guide/conversion.md @@ -230,23 +230,21 @@ Continuing our `AwesomeModel` handler example, we define: ```python def _create_weight_converters(self) -> list[WeightConverter]: - - converters = [] -# The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. -num_layers = self._model.config.base_model.transformer.num_blocks - -# A simple renaming example, for the word embeddings. -converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - -# We usually want to loop dynamically over layers -for i in range(num_layers): - # A `SplitWeightConverter` example, splitting a weight in two. - converters.append(SplitWeightConverter( - f"layers.{i + 1}.weight", - (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), - )) -return converters + # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. + num_layers = self._model.config.base_model.transformer.num_layers + + # A simple renaming example, for the word embeddings. 
+ converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + + # We usually want to loop dynamically over layers + for i in range(num_layers): + # A `SplitWeightConverter` example, splitting a weight in two. + converters.append(SplitWeightConverter( + f"layers.{i + 1}.weight", + (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), + )) + return converters ``` And that's it! We're ready to use the new checkpoint format in Fast-LLM. diff --git a/fast_llm/engine/config_utils/initialization.py b/fast_llm/engine/config_utils/initialization.py deleted file mode 100644 index d35c2220c..000000000 --- a/fast_llm/engine/config_utils/initialization.py +++ /dev/null @@ -1,178 +0,0 @@ -import abc -import typing - -from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.utils import Assert - -if typing.TYPE_CHECKING: - import torch - - from fast_llm.tensor import ParameterMeta - - -@config_class(registry=True) -class InitializationConfig(Config): - _abstract = True - has_initialization: typing.ClassVar[bool] = True - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is InitializationConfig and cls.get_subclass(default.get("type")) is None: - # Default subclass. - return DefaultInitializationConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - - def get_initializer(self) -> "Initializer": - raise NotImplementedError() - - -@config_class(dynamic_type={InitializationConfig: "default"}) -class DefaultInitializationConfig(InitializationConfig): - # A placeholder indicating that the class default should be used instead. - _abstract = False - has_initialization = False - - -@config_class(dynamic_type={InitializationConfig: "fill"}) -class NormalInitializationConfig(InitializationConfig): - """ - Normal initialization: normal(mean, std).clamp(min,max) - """ - - _abstract = False - - value: float = Field( - default=1, - desc="Initialization value.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - - def get_initializer(self): - return init_fill_(self.value) - - -@config_class(dynamic_type={InitializationConfig: "zeros"}) -class ZeroInitializationConfig(InitializationConfig): - def get_initializer(self): - return init_zeros_ - - -@config_class(dynamic_type={InitializationConfig: "ones"}) -class ZeroInitializationConfig(InitializationConfig): - def get_initializer(self): - return init_ones_ - - -@config_class(dynamic_type={InitializationConfig: "normal"}) -class NormalInitializationConfig(InitializationConfig): - """ - Normal initialization: normal(mean, std).clamp(min,max) - """ - - _abstract = False - - std: float = Field( - default=1, - desc="Standard deviation for normal initialization.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - mean: float = Field( - default=0, - desc="Mean for normal initialization.", - hint=FieldHint.optional, - ) - min: float | None = Field( - default=None, - desc="Min value for initialization clamping.", - hint=FieldHint.optional, - ) - max: float | None = Field( - default=None, - desc="Min value for initialization clamping.", - hint=FieldHint.optional, - ) - - def get_initializer(self): - return init_normal_(self.mean, self.std, self.min, self.max) - - -@config_class(dynamic_type={InitializationConfig: "uniform"}) -class UniformInitializationConfig(InitializationConfig): - """ - Uniform initialization: 
uniform(mean - scale, mean + scale).clamp(min,max) - """ - - _abstract = False - - scale: float = Field( - default=None, - desc="Initialization scale.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - mean: float = Field( - default=None, - desc="Initialization mean.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - - def get_initializer(self) -> "Initializer": - return init_uniform_centered_(self.scale, self.mean) - - -class Initializer(abc.ABC): - @abc.abstractmethod - def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: - pass - - requires_global_initialization = False - - -class LambdaInitializer(Initializer): - def __init__( - self, - init_method: typing.Callable[["ParameterMeta", "torch.Tensor", "torch.Generator"], None], - requires_global_initialization: bool = False, - ) -> None: - self._init_method = init_method - self.requires_global_initialization = requires_global_initialization - - def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: - return self._init_method(meta, tensor, generator) - - -def init_fill_(value: float) -> LambdaInitializer: - def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa - tensor.fill_(value) - - return LambdaInitializer(init_) - - -init_zeros_ = init_fill_(0.0) -init_ones_ = init_fill_(1.0) - - -def init_normal_( - mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None -) -> LambdaInitializer: - def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa - tensor = tensor.normal_(mean, std, generator=generator) - if min_val is not None or max_val is not None: - tensor.clamp_(min=min_val, max=max_val) - - return LambdaInitializer(init_) - - -def init_uniform_centered_(scale: float, mean: float = 0.0) -> LambdaInitializer: - def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa - tensor.uniform_(mean - scale, mean + scale, generator=generator) - - return LambdaInitializer(init_) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index d13b09807..85da61c01 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -1,5 +1,4 @@ import abc -import functools import typing import torch @@ -9,118 +8,23 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs, BlockLayerConfig +from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.block.mixer import Mixer +from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP +from fast_llm.layers.block.mlp.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta -class DebugLayer: - # TODO: Move elsewhere? 
- def __init__(self, tensor_space: TensorSpace, name: str, debug_level: int = 0, debug_memory: bool = False): - self._tensor_space = tensor_space - self._name = name - self._debug_level = debug_level - self._debug_memory = debug_memory - - def _get_meta( - self, tensor: torch.Tensor, name: str, dims: tuple[TensorDim | str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = { - dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) - } - return TensorMeta.from_dims( - tuple( - ( - dim - if isinstance(dim, TensorDim) - else hidden_dims[dim] if dim in hidden_dims else self._tensor_space[dim] - ) - for dim in dims - ), - tensor_name=f"{self._name} {name}", - dtype=tensor.dtype, - ) - - @functools.cached_property - def enabled(self) -> bool: - return self._debug_level > 0 or self._debug_memory - - def __call__( - self, - tensor: torch.Tensor, - name: str, - dims: tuple[TensorDim | str, ...], - kwargs: dict[str, typing.Any], - scale: float = 1.0, - global_: bool = True, - log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info, - ) -> None: - # TODO: Local vs global? - if self._debug_memory: - log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) - if self._debug_level > 0: - log_distributed_tensor( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name, dims, kwargs), - distributed=self._tensor_space.distributed, - global_=global_, - log_fn=log_fn, - scale=scale, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name + " grad", dims, kwargs), - distributed=self._tensor_space.distributed, - global_=global_, - log_fn=log_fn, - scale=scale, - ) - - -class BlockLayer[ConfigType: BlockLayerConfig](Configurable[ConfigType], torch.nn.Module): - """ - Base class for mixer and MLP modules. - """ - - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): - super().__init__(config) - self._tensor_space = tensor_space - self._block_index = block_index - self._name = name - self._sequence_parallel: bool = self._tensor_space.distributed_config.sequence_tensor_parallel - self._debug = DebugLayer( - tensor_space, - f"Block {self._block_index} {self._name}", - self.config.block.debug_transformer, - self._config.block.debug_transformer_memory, - ) - - @abc.abstractmethod - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict[str, typing.Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - pass - - -class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): +class Block[ConfigType: BlockConfig](Layer, Configurable[ConfigType]): """ A transformer-like decoder base block with abstract mixer. 
""" # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "mixer" - def __init__( - self, config: ConfigType, tensor_space: TensorSpace, block_index: int = 0, return_input: bool = False - ): + def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): super().__init__() self._config = config self._tensor_space: TensorSpace = tensor_space @@ -136,19 +40,21 @@ def __init__( self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) - # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. - setattr( - self, - self._config.mixer.module_name, - self._config.mixer.get_layer(self._tensor_space, block_index, f"{self.name} mixer"), - ) + # The mixer needs to be created here for backward-compatible weight ordering. + setattr(self, self._mixer_module_name, self._create_mixer()) - self.mlp = self._config.mlp.get_layer(self._tensor_space, block_index, f"{self.name} mlp") + self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( + self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index + ) # PEFT. self.norm_1 = self._config.peft.apply_other(self.norm_1) self.norm_2 = self._config.peft.apply_other(self.norm_2) + @abc.abstractmethod + def _create_mixer(self) -> Mixer: + pass + @torch.compile def _bias_dropout_add( self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor @@ -207,13 +113,13 @@ def forward( hidden_states = self.norm_1(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 1", kwargs) - hidden_states, bias = getattr(self, self._config.mixer.module_name)(hidden_states, kwargs) + hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) if self._debug_mode: - self._debug_log(hidden_states, f"{self._config.mixer.module_name} output", kwargs, bias=bias) + self._debug_log(hidden_states, f"{self._mixer_module_name} output", kwargs, bias=bias) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) if self._debug_mode: - self._debug_log(input_, f"{self._config.mixer.module_name} residual", kwargs) + self._debug_log(input_, f"{self._mixer_module_name} residual", kwargs) hidden_states = self.norm_2(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 2", kwargs) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 87bd6d249..5a999fa6d 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,21 +1,13 @@ -import abc import enum -import functools -import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerPeftConfig from fast_llm.layers.common.config import NormalizationConfig from fast_llm.utils import Assert -if typing.TYPE_CHECKING: - from fast_llm.layers.block.block import Block, BlockLayer - - -# TODO: Generalize these beyond language models? (Ex. vision) - class BlockDimNames: # A set of common tensor dim names packed into a namespace. 
@@ -47,76 +39,10 @@ class AddLinearBiasChoices(str, enum.Enum): only_attn_qkv = "only_attn_qkv" -@config_class(registry=True) -class BlockLayerConfig(BaseModelConfig): - _abstract = True - block: "BlockConfig" = Field(init=False) - - def _validate(self) -> None: - assert hasattr(self, "block") - Assert.is_(self.block.mlp, self) - super()._validate() - - @property - def layer_class(self) -> "type[BlockLayer]": - raise NotImplementedError() - - def get_layer(self, tensor_space: TensorSpace, block_index: int, name: str) -> "BlockLayer": - return self.layer_class(self, tensor_space, block_index, name) - - -@config_class() -class MixerConfig(BlockLayerConfig): - _abstract = True - - # Needed for backward compatibility. - module_name: typing.ClassVar[str] = "mixer" - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is MixerConfig and cls.get_subclass(default.get("type")) is None: - from fast_llm.layers.transformer.config import AttentionConfig - - # Default subclass. - return AttentionConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - - @config_class() -class MLPBaseConfig(BlockLayerConfig): - _abstract = True - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is MLPBaseConfig and cls.get_subclass(default.get("type")) is None: - from fast_llm.layers.block.mlp.config import MLPConfig +# TODO: Use composition for MLP config +class BlockConfig(MLPConfig, BaseModelConfig): - # Default subclass. - return MLPConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - - -@config_class() -class BlockConfig(BaseModelConfig): - _abstract = False - mixer: MixerConfig = Field( - desc="Configuration for the mixer.", - hint=FieldHint.architecture, - ) - mlp: MLPBaseConfig = Field( - desc="Configuration for the MLP.", - hint=FieldHint.architecture, - ) # TODO: Review names normalization: NormalizationConfig = Field( desc="Configuration for the normalization layers architecture.", @@ -132,6 +58,11 @@ class BlockConfig(BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + full_precision_residual: bool = Field( + default=False, + desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", + hint=FieldHint.stability, + ) debug_transformer: int = Field( default=0, desc="Log the output of each operation in a transformer layer.", @@ -149,45 +80,8 @@ class BlockConfig(BaseModelConfig): hint=FieldHint.architecture, ) - block_sequence: "BlockSequenceConfig" = Field(init=False) - - def _validate(self) -> None: - assert hasattr(self, "block_sequence") - Assert.incl(self, self.block_sequence.blocks.values()) - self.mixer.block = self - self.mlp.block = self - super()._validate() - - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - self.mlp.setup_tensor_space(tensor_space) - self.mixer.setup_tensor_space(tensor_space) - - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.block_sequence.hidden_size)) - - @abc.abstractmethod - def get_block(self) -> "Block": - pass - - -@config_class() -class BlockSequenceConfig(BaseModelConfig): - _abstract = True - - blocks: dict[str, BlockConfig] = Field() - block_pattern: tuple[str, ...] = Field( - default=None, - desc="The pattern of blocks (referred by name) to use. 
The sequence is repeated until reaching `num_blocks`." - " Default: cycle over `blocks` in the order they are defined.", - ) - default_block: str = Field( - default=None, - desc="The default block configuration to use when referring to the model." - " Used to set some defaults in the language model.", - ) - # TODO: Move these, not specific to a single block. - num_blocks: int = Field( + num_layers: int = Field( default=12, desc="Number of layers in the transformer.", hint=FieldHint.architecture, @@ -199,28 +93,30 @@ class BlockSequenceConfig(BaseModelConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - full_precision_residual: bool = Field( - default=False, - desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", - hint=FieldHint.stability, + per_layer_lr_scale: list[float] | None = Field( + default=None, + desc="Custom learning rate scale for each layer.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, ) def _validate(self) -> None: - for block in self.blocks.values(): - block.validate() - if self.block_pattern is None: - self.block_pattern = tuple(self.blocks) - if self.default_block is None: - self.default_block = self.block_pattern[0] + with self._set_implicit_default(): + if self.ffn_hidden_size is None: + self.ffn_hidden_size = 4 * self.hidden_size + super()._validate() - def get_block_config(self, block_index: int) -> BlockConfig: - return self.blocks[self.block_pattern[block_index % len(self.block_pattern)]] + @property + def add_mlp_bias(self) -> bool: + if isinstance(self.add_linear_biases, bool): + return self.add_linear_biases + if self.add_linear_biases == AddLinearBiasChoices.everywhere: + return True + return False def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - for block in self.blocks.values(): - block.setup_tensor_space(tensor_space) + super().setup_tensor_space(tensor_space) - @functools.cached_property - def default_block_config(self) -> BlockConfig: - return self.blocks[self.default_block] + # Hidden dimension + tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.hidden_size)) diff --git a/fast_llm/layers/block/mixer.py b/fast_llm/layers/block/mixer.py new file mode 100644 index 000000000..5c811e330 --- /dev/null +++ b/fast_llm/layers/block/mixer.py @@ -0,0 +1,68 @@ +import abc +import typing + +import torch + +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.logging import log_distributed_grad, log_distributed_tensor +from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert + + +class Mixer(torch.nn.Module, abc.ABC): + """ + Base class for mixer modules. + """ + + _mixer_name: typing.ClassVar[str] + + def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): + super().__init__() + self._tensor_space = tensor_space + self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel + self._block_index = block_index + self._debug_level = debug_level + + @abc.abstractmethod + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Mixer module forward. Returns the output hidden states and an optional bias, + in case its addition can be made more efficient in `_bias_dropout_add`. 
+ """ + + def _get_meta( + self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] + for dim_name in dim_names + ), + tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", + dtype=input_.dtype, + ) + + def _debug_log( + self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> None: + # TODO: Local vs global + Assert.gt(self._debug_level, 0) + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 526c513db..1d125c4f7 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -1,18 +1,11 @@ import enum -import functools -import typing -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_zeros_ +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.utils import Assert -if typing.TYPE_CHECKING: - from fast_llm.layers.block.config import AddLinearBiasChoices, BlockLayerConfig - from fast_llm.layers.block.mlp.mlp import MLPBase - class MLPDimNames: # MLP dimensions @@ -39,10 +32,9 @@ class RoutingType(str, enum.Enum): sinkhorn = "sinkhorn" -@config_class(dynamic_type={BlockLayerConfig: "mlp"}) -class MLPConfig(BlockLayerConfig): +@config_class() +class MLPConfig(Config): # TODO: Review names - # TODO: Separate MoE? _abstract = False ffn_hidden_size: int = Field( default=None, @@ -132,52 +124,11 @@ class MLPConfig(BlockLayerConfig): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) - layer_1_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for the first mlp layer weights. Default: hidden_size**-0.5", - hint=FieldHint.feature, - ) - layer_1_bias_initialization: InitializationConfig = Field( - desc="Initialization configuration for the first mlp layer biases. Default: fill with zeros.", - hint=FieldHint.feature, - ) - layer_2_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for the second mlp layer weights." - " Default: (2 * num_blocks * hidden_size)**-0.5", - hint=FieldHint.feature, - ) - layer_2_bias_initialization: InitializationConfig = Field( - desc="Initialization configuration for the second mlp layer biases. 
Default: fill with zeros.", - hint=FieldHint.feature, - ) - - @property - def layer_class(self) -> "type[MLPBase]": - if self.num_experts > 1: - from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP - - return MixtureOfExpertMLP - else: - from fast_llm.layers.block.mlp.mlp import MLP - - return MLP - - @property - def add_bias(self) -> bool: - if isinstance(self.block.add_linear_biases, bool): - return self.block.add_linear_biases - if self.block.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False def _validate(self) -> None: - assert hasattr(self, "block") - with self._set_implicit_default(): if self.activation_type is None: self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu - if self.ffn_hidden_size is None: - # TODO: hidden_size not yet validated. - self.ffn_hidden_size = 4 * self.block.block_sequence.hidden_size self.num_unshared_experts = self.num_experts - self.num_shared_experts super()._validate() @@ -193,30 +144,6 @@ def _validate(self) -> None: elif self.mlp_lr_scale is not None: Assert.geq(self.mlp_lr_scale, 0) - @functools.cached_property - def layer_1_weight_initialization_method(self) -> Initializer: - if not self.layer_1_weight_initialization.has_initialization: - return self.layer_1_weight_initialization.get_initializer() - return self.block.block_sequence.hidden_size**-0.5 - - @functools.cached_property - def layer_1_bias_initialization_method(self) -> Initializer: - if not self.layer_1_bias_initialization.has_initialization: - return self.layer_1_bias_initialization.get_initializer() - return init_zeros_ - - @functools.cached_property - def layer_2_weight_initialization_method(self) -> Initializer: - if self.layer_2_weight_initialization.has_initialization: - return self.layer_2_weight_initialization.get_initializer() - return self.block.block_sequence.hidden_size**-0.5 / max(2 * self.block.block_sequence.num_blocks, 1) - - @functools.cached_property - def layer_2_bias_initialization_method(self) -> Initializer: - if self.layer_2_bias_initialization.has_initialization: - return self.layer_2_bias_initialization.get_initializer() - return init_zeros_ - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 332d3109f..8d092b6dc 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -1,24 +1,27 @@ import logging +import typing import warnings import torch from fast_llm.core.distributed import ProcessGroup, set_generator -from fast_llm.engine.config_utils.initialization import init_normal_ +from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.block.mlp.config import MLPConfig, MLPDimNames, MLPLossNames, RoutingType +from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.block.mlp.config import MLPDimNames, MLPLossNames, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from 
fast_llm.layers.common.linear import Linear -from fast_llm.utils import get_lr_scale +from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage +from fast_llm.tensor import TensorMeta, init_normal_ +from fast_llm.utils import Assert, get_lr_scale logger = logging.getLogger(__name__) -class MixtureOfExpertMLP[ConfigType: MLPConfig](MLPBase[ConfigType]): +class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): """ MoeLayer following implementation from https://github.com/NVIDIA/Megatron-LM/blob/46ebc0e4202c980d98900000d455f754a7ff9d4b/megatron/model/transformer.py#L346 @@ -32,10 +35,23 @@ class MixtureOfExpertMLP[ConfigType: MLPConfig](MLPBase[ConfigType]): _group: ProcessGroup - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): - super().__init__(config, tensor_space, block_index, name) + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): + Assert.gt(config.num_experts, 1) # TODO: Implement? - assert not self._config.add_linear_biases, "Biases not supported for MoE." + assert not config.add_linear_biases, "Biases not supported for MoE." + super().__init__(config, tensor_space, name, block_index) + self._tensor_space = tensor_space + self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory + + self._num_experts = config.num_experts + self._experts_per_token = config.num_experts_per_token + self._num_shared_experts = config.num_shared_experts + self._num_unshared_experts = config.num_unshared_experts + + self._routing_type = config.expert_routing_type + self._load_balancing_factor = config.expert_auxiliary_loss_coefficient + self._z_loss_factor = config.expert_z_loss_coefficient + self._moe_jitter_eps = config.moe_jitter_eps layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) @@ -56,20 +72,21 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i ) dropless_moe = False self._mlp_forward = self._forward_dropless if dropless_moe else self._forward_looped + self._dynamic_shape = config.dropless_dynamic_shape def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) - if self._debug.enabled: - self._debug(logits, "Router logits", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.experts,), kwargs) + if self._debug_mode: + self._debug_log(logits, "Router logits", MLPDimNames.experts, kwargs) # Apply z_loss if applicable - if self._config.expert_z_loss_coefficient > 0.0: + if self._z_loss_factor > 0.0: logits = z_loss( logits, - self._config.expert_z_loss_coefficient, + self._z_loss_factor, self.training, grad_scale=kwargs.get("grad_output"), losses=losses, @@ -77,31 +94,24 @@ def forward( ) # Apply input_jitter if applicable: - if self.training and self._config.moe_jitter_eps > 0.0: + if self.training and self._moe_jitter_eps > 0.0: with set_generator(self._tensor_space.distributed.pp_generator): logits = self._apply_input_jitter(logits) # Routing - if self._config.expert_routing_type == RoutingType.topk: + if self._routing_type == RoutingType.topk: scores, top_experts = self._topk_routing(logits, kwargs.get(BlockKwargs.grad_output), losses) - if self._config.num_shared_experts > 0: + if self._num_shared_experts > 0: scores, 
top_experts = self._add_shared_experts(top_experts, scores) - elif self._config.expert_routing_type == RoutingType.sinkhorn: + elif self._routing_type == RoutingType.sinkhorn: scores, top_experts = self._sinkhorn_routing(logits) else: - raise NotImplementedError(self._config.expert_routing_type) + raise NotImplementedError(self._routing_type) - if self._debug.enabled: + if self._debug_mode: # To log all ranks set `global_=False` - self._debug( - scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), kwargs - ) - self._debug( - top_experts, - "Router top experts", - kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), - kwargs, - ) + self._debug_log(scores, "Router scores", MLPDimNames.top_experts, kwargs) + self._debug_log(top_experts, "Router top experts", MLPDimNames.top_experts, kwargs) return self._mlp_forward(hidden_states, scores, top_experts).view_as(input_), None # noqa @@ -109,9 +119,7 @@ def _forward_dropless( self, hidden_states: torch.Tensor, scores: torch.Tensor, top_experts: torch.Tensor ) -> torch.Tensor: # Compute token counts and the sparse mapping (dense_row, top_index) -> sparse_row. - sparse_map = get_sparse_map( - top_experts, self._config.num_experts, dynamic_shape=self._config.dropless_dynamic_shape - ) + sparse_map = get_sparse_map(top_experts, self._num_experts, dynamic_shape=self._dynamic_shape) # Sparse MLP return mlp_autograd( @@ -140,7 +148,7 @@ def _forward_looped( top_experts, self.layer_1.weight, self.layer_2.weight, - self._config.num_experts, + self._num_experts, self._config.gated, self._config.activation_type, self._intermediate_dim.parallel_group, @@ -151,9 +159,7 @@ def _forward_looped( @torch.compile def _apply_input_jitter(self, logits: torch.Tensor) -> torch.Tensor: - return logits * torch.empty_like(logits).uniform_( - 1.0 - self._config.moe_jitter_eps, 1.0 + self._config.moe_jitter_eps - ) + return logits * torch.empty_like(logits).uniform_(1.0 - self._moe_jitter_eps, 1.0 + self._moe_jitter_eps) def _topk_routing( self, @@ -161,11 +167,11 @@ def _topk_routing( grad_scale: float | None = None, losses: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - top_logits, top_experts = torch.topk(logits, k=self._config.num_experts_per_token, dim=-1) + top_logits, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32) if losses is not None or (self.training and grad_scale is not None): probs = torch.softmax(logits, dim=-1, dtype=torch.float32) - mask = torch.nn.functional.one_hot(top_experts, num_classes=self._config.num_unshared_experts).sum(dim=1) + mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1) # Auxiliary loss, corresponding to the sum of probabilities for the top experts. # In the optimal case (uniform distribution), loss = experts_per_token / num_experts. # In the worst case (whole distribution in the top experts), loss = 1. 
@@ -176,9 +182,7 @@ def _topk_routing( losses[MLPLossNames.load_balancing_loss].append(aux_loss.detach()) if self.training and grad_scale is not None: scores = AuxiliaryLoss.apply( - scores, - aux_loss, - self._config.num_unshared_experts * self._config.expert_auxiliary_loss_coefficient * grad_scale, + scores, aux_loss, self._num_unshared_experts * self._load_balancing_factor * grad_scale ) return scores, top_experts @@ -187,33 +191,69 @@ def _add_shared_experts( ) -> tuple[torch.Tensor, torch.Tensor]: # Add the shared experts (last ones) to the top experts. shared_experts = torch.arange( - self._config.num_unshared_experts, - self._config.num_experts, - device=top_experts.device, - dtype=top_experts.dtype, + self._num_unshared_experts, self._num_experts, device=top_experts.device, dtype=top_experts.dtype )[None].repeat(top_experts.size(0), 1) top_experts = torch.cat((shared_experts, top_experts), dim=1) # Add scores of 1 to scores for shared experts. - scores = torch.cat((scores.new_ones(scores.size(0), self._config.num_shared_experts), scores), dim=1) + scores = torch.cat((scores.new_ones(scores.size(0), self._num_shared_experts), scores), dim=1) return scores, top_experts def _sinkhorn_routing(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if self.training: - _, top_experts = torch.topk(sinkhorn(logits), k=self._config.num_experts_per_token, dim=-1) + _, top_experts = torch.topk(sinkhorn(logits), k=self._experts_per_token, dim=-1) logits = self._sinkhorn_activation(logits) scores = torch.gather(logits, -1, top_experts) else: logits = self._sinkhorn_activation(logits) - scores, top_experts = torch.topk(logits, k=self._config.num_experts_per_token, dim=-1) + scores, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) return scores, top_experts def _sinkhorn_activation(self, logits: torch.Tensor) -> torch.Tensor: return ( torch.sigmoid(logits) - if self._config.num_experts_per_token == 1 + if self._experts_per_token == 1 else torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) ) + def _debug_log( + self, + tensor: torch.Tensor | None, + name: str, + dim_name: str, + kwargs: dict[str, typing.Any], + global_: bool = True, + ) -> None: + if self._config.debug_transformer_memory: + log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) + if self._config.debug_transformer and tensor is not None: + # TODO: Local vs global + meta = self._get_meta(tensor, name, dim_name, kwargs) + log_distributed_tensor( + "", + tensor.view_as(meta), + level=self._config.debug_transformer, + meta=meta, + distributed=self._tensor_space.distributed, + global_=global_, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._config.debug_transformer, + meta=self._get_meta(tensor, name + " grad", dim_name, kwargs), + distributed=self._tensor_space.distributed, + grad_fn=lambda tensor_: tensor_.view_as(meta), + global_=global_, + ) + + def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: + return TensorMeta.from_dims( + kwargs[BlockKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), + tensor_name=f"{self._name} {name}", + dtype=tensor.dtype, + ) + def sinkhorn(cost: torch.Tensor, tolerance: float = 1e-5, eps=1e-9) -> torch.Tensor: """Sinkhorn based MoE routing function""" diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index aba5639b5..19349671e 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ 
b/fast_llm/layers/block/mlp/mlp.py @@ -2,77 +2,75 @@ import torch -from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ +from fast_llm.config import Configurable +from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd -from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.config import BlockDimNames -from fast_llm.layers.block.mlp.config import MLPConfig, MLPDimNames +from fast_llm.layers.block.config import BlockConfig, BlockDimNames +from fast_llm.layers.block.mlp.config import MLPDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase -from fast_llm.utils import get_lr_scale +from fast_llm.tensor import init_normal_, init_zeros_ +from fast_llm.utils import Assert, get_lr_scale -class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): - _name: typing.ClassVar[str] = "mlp" - - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): - super().__init__(config, tensor_space, block_index, name) +class MLPBase[ConfigType: BlockConfig](Configurable[ConfigType], Layer): + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): + super().__init__(config) + self._name = name + self._block_index = block_index init_method_1 = init_normal_( - std=self._config.init_method_std_mlp_1, - min_val=self._config.init_method_min_mlp_1, - max_val=self._config.init_method_max_mlp_1, + std=config.init_method_std_mlp_1, + min_val=config.init_method_min_mlp_1, + max_val=config.init_method_max_mlp_1, ) init_method_2 = init_normal_( - std=self._config.init_method_std_mlp_2, - min_val=self._config.init_method_min_mlp_2, - max_val=self._config.init_method_max_mlp_2, + std=config.init_method_std_mlp_2, + min_val=config.init_method_min_mlp_2, + max_val=config.init_method_max_mlp_2, ) - hidden_dim = self._tensor_space[BlockDimNames.hidden] - self._intermediate_dim = self._tensor_space[MLPDimNames.composite_expert_mlp] + hidden_dim = tensor_space[BlockDimNames.hidden] + self._intermediate_dim = tensor_space[MLPDimNames.composite_expert_mlp] + self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = ( - self._config.block.block_sequence.per_layer_lr_scale[self._block_index] - if self._config.block.block_sequence.per_layer_lr_scale - else None - ) - lr_scale = ( - tuple(self._config.mlp_lr_scale) - if isinstance(self._config.mlp_lr_scale, list) - else self._config.mlp_lr_scale - ) + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale lr_scale = get_lr_scale(lr_scale, layer_lr_scale) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - self._tensor_space[MLPDimNames.composite_gated_expert_mlp], - bias=self._config.add_bias, + tensor_space[MLPDimNames.composite_gated_expert_mlp], + bias=config.add_mlp_bias, weight_init_method=init_method_1, - bias_init_method=init_method_1 if self._config.random_bias_init else 
init_zeros_, + bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, lr_scale=lr_scale, ) self.layer_2 = LinearBase( self._intermediate_dim, hidden_dim, - bias=self._config.add_bias, + bias=config.add_mlp_bias, weight_init_method=init_method_2, - bias_init_method=init_method_2 if self._config.random_bias_init else init_zeros_, - auto_bias_grad_accumulation=self._tensor_space.distributed_config.tensor_parallel > 1, + bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, + auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, lr_scale=lr_scale, ) # PEFT. - self.layer_1 = self._config.block.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) - self.layer_2 = self._config.block.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) + self.layer_1 = config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) + self.layer_2 = config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) + +class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): + Assert.eq(config.num_experts, 1) + super().__init__(config, tensor_space, name, block_index) -class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]): def forward( self, input_: torch.Tensor, diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 2f45fdf9f..9d5ce3f3b 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -87,7 +87,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig): ) def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.engine.config_utils.initialization import init_uniform_centered_ + from fast_llm.tensor import init_uniform_centered_ kwargs = { "hidden_dim": hidden_dim, diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index 740b4847c..7249ef569 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -3,7 +3,6 @@ import torch -from fast_llm.engine.config_utils.initialization import init_zeros_ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.linear import ( @@ -15,7 +14,7 @@ output_parallel_linear_backward, output_parallel_linear_forward, ) -from fast_llm.tensor import ParameterMeta +from fast_llm.tensor import ParameterMeta, init_zeros_ logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index d44be3297..bccc1d627 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -1,12 +1,11 @@ import torch -from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.config import NormalizationImplementation -from fast_llm.tensor import ParameterMeta, accumulate_gradient +from fast_llm.tensor import ParameterMeta, accumulate_gradient, init_ones_, init_zeros_ from fast_llm.utils import Assert try: diff --git a/fast_llm/layers/language_model/config.py 
b/fast_llm/layers/language_model/config.py index 2e7d71963..b667e5318 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,11 +1,13 @@ -import functools +import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer +from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl -from fast_llm.layers.block.config import BlockDimNames, BlockKwargs, BlockSequenceConfig +from fast_llm.layers.block.config import BlockDimNames, BlockKwargs +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.rotary.config import NoRotaryConfig from fast_llm.utils import Assert @@ -44,27 +46,27 @@ class LanguageModelKwargs(BlockKwargs): @config_class() -class LanguageModelConfig(BlockSequenceConfig): - decoder: BlockSequenceConfig = Field( +class LanguageModelBaseConfig(BaseModelConfig): + # TODO: block + transformer: TransformerConfig = Field( + desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) + max_position_embeddings: int = Field( + default=2048, + desc="Number of absolute position embeddings, if applicable.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) vocab_size: int = Field( default=49152, desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - embedding_dropout: float = Field( - # TODO: backward compatibility? - default=0.0, - desc="Dropout applied to the embedding layer.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - absolute_position_embeddings: int | None = Field( - # TODO: backward compatibility? + use_position_embeddings: bool = Field( default=None, - desc="Number of absolute position embeddings, if applicable.", + desc="Enable absolute position embeddings. Default: Enable unless using rotary embeddings.", hint=FieldHint.architecture, ) tie_word_embeddings: bool = Field( @@ -78,6 +80,22 @@ class LanguageModelConfig(BlockSequenceConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + init_method_std_embed: float = Field( + default=None, + desc="Initialization scale for the vocabulary embedding and output weights (logits).", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + init_method_max_embed: float | None = Field( + default=None, + desc="Max value for clamping initialized weights of the vocabulary embedding and output (logits).", + hint=FieldHint.feature, + ) + init_method_min_embed: float | None = Field( + default=None, + desc="Min value for clamping initialized weights of the vocabulary embedding and output (logits).", + hint=FieldHint.feature, + ) enable_dpo: bool | None = Field( default=False, desc="Whether to enable DPO loss", @@ -185,27 +203,26 @@ class LanguageModelConfig(BlockSequenceConfig): doc="If not provided, all heads are equally weighted.", hint=FieldHint.feature, ) - word_embedding_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for word embeddings. 
Default: hidden_size**-0.5", - hint=FieldHint.feature, - ) - position_embedding_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for position embeddings. Default: hidden_size**-0.5", - hint=FieldHint.feature, - ) - output_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for untied output weights. Default: hidden_size**-0.5", - hint=FieldHint.feature, - ) def _validate(self) -> None: + self.transformer.validate() with self._set_implicit_default(): if self.language_model_loss_factor is None: if self.distillation_model is None: self.language_model_loss_factor = 1.0 else: self.language_model_loss_factor = 0.0 + if self.use_position_embeddings is None: + self.use_position_embeddings = isinstance(self.transformer.rotary, NoRotaryConfig) + if self.init_method_std_embed is None: + self.init_method_std_embed = self.transformer.init_method_std + if self.init_method_max_embed is None: + self.init_method_max_embed = self.transformer.init_method_max + if self.init_method_min_embed is None: + self.init_method_min_embed = self.transformer.init_method_min super()._validate() + if self.init_method_max_embed is not None and self.init_method_min_embed is not None: + Assert.leq(self.init_method_min_embed, self.init_method_max_embed) if self.distillation_model is not None: if self.prediction_heads > 1: raise NotImplementedError("Multi-token prediction not supported with distillation.") @@ -213,40 +230,43 @@ def _validate(self) -> None: Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) for coeff in self.prediction_loss_coefficient: Assert.geq(coeff, 0) - - if self.output_weight_initialization.has_initialization: - assert self.use_absolute_position_embeddings - if self.output_weight_initialization.has_initialization: - assert not self.tie_word_embeddings + if self.transformer.per_layer_lr_scale is not None: + # -1 because the first prediction head's transformer layer is accounted for in num_layers + # +1 because the layer index starts at 1 + Assert.eq( + len(self.transformer.per_layer_lr_scale), self.transformer.num_layers + self.prediction_heads - 1 + 1 + ) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - super().setup_tensor_space(tensor_space) + self.transformer.setup_tensor_space(tensor_space) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Embedding dimensions - if self.use_absolute_position_embeddings: - tensor_space.add_tensor_dim( - TensorDim(LanguageModelDimNames.position_embed, self.absolute_position_embeddings) - ) + tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.position_embed, self.max_position_embeddings)) # TODO: Need both? tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab, self.vocab_size)) tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) - @functools.cached_property - def word_embedding_weight_initialization_method(self) -> Initializer: - if self.word_embedding_weight_initialization.has_initialization: - return self.word_embedding_weight_initialization.get_initializer() - else: - return self.hidden_size**-0.5 + @property + def num_absolute_position_embeddings(self) -> int: + # TODO: Rename from max embeddings. + return self.max_position_embeddings if self.use_absolute_position_embeddings else None @property def use_absolute_position_embeddings(self) -> int: # TODO: Set through num embeddings instead instead. 
- return self.absolute_position_embeddings is not None - - @functools.cached_property - def output_weight_initialization_method(self) -> Initializer: - if self.output_weight_initialization.has_initialization: - return self.output_weight_initialization.get_initializer() - else: - return self.hidden_size**-0.5 + return self.use_position_embeddings + + @classmethod + def from_flat_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + ) -> typing.Self: + # The backward compatibility fix in `NormalizationArchitectureConfig` + # won't work for older checkpoints saved with a flat config. + # TODO v0.3: Remove flat format + cls._handle_renamed_field(default, "normalization_type", "type") + cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") + cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") + return super().from_flat_dict(default, strict) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index b49fef7ba..05678a700 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -7,28 +7,28 @@ from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.language_model.config import LanguageModelConfig, LanguageModelDimNames, LanguageModelKwargs -from fast_llm.tensor import ParameterMeta, TensorMeta +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs +from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ from fast_llm.utils import Assert WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" -class LanguageModelEmbedding[ConfigType: LanguageModelConfig](Configurable[ConfigType], Layer): +class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer): """ A language model embedding layer. Consists of word embeddings (tensor-parallel or sequence-tensor-parallel), together with optional absolute position embeddings and dropout. """ - config_class: typing.ClassVar[type[LanguageModelConfig]] = LanguageModelConfig + config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig # Ensure the layer is on its own stage. 
layer_count: float = 1000.0 def __init__( self, - config: LanguageModelConfig, + config: LanguageModelBaseConfig, tensor_space: TensorSpace, ): super().__init__(config) @@ -36,14 +36,14 @@ def __init__( self._tensor_space = tensor_space self._residual_dtype = ( self._distributed_config.optimization_dtype - if self._config.full_precision_residual + if config.transformer.full_precision_residual else self._distributed_config.training_dtype ).torch self._group_size = self._distributed_config.tensor_parallel self._sequence_parallel = self._distributed_config.sequence_tensor_parallel - self._parallel_embeddings = ( - tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings - ) + self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._dropout_p = config.transformer.hidden_dropout + self._use_absolute_position_embeddings = config.use_absolute_position_embeddings hidden_dim = tensor_space[LanguageModelDimNames.hidden] vocab_dim = tensor_space[ @@ -56,15 +56,23 @@ def __init__( self.word_embeddings_weight = ParameterMeta.from_dims( (vocab_dim, hidden_dim), - init_method=self._config.word_embedding_weight_initialization_method, - lr_scale=self._config.embeddings_lr_scale, + init_method=init_normal_( + std=config.init_method_std_embed, + min_val=config.init_method_min_embed, + max_val=config.init_method_max_embed, + ), + lr_scale=config.embeddings_lr_scale, ) - if self._config.use_absolute_position_embeddings: + if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( (tensor_space[LanguageModelDimNames.position_embed], hidden_dim), - init_method=self._config.position_embedding_weight_initialization_method, - allow_sequence_tensor_parallel=not self._config.parallel_embeddings, - lr_scale=self._config.embeddings_lr_scale, + init_method=init_normal_( + std=config.init_method_std_embed, + min_val=config.init_method_min_embed, + max_val=config.init_method_max_embed, + ), + allow_sequence_tensor_parallel=not config.parallel_embeddings, + lr_scale=config.embeddings_lr_scale, ) # PEFT. 
@@ -76,21 +84,21 @@ def __init__( @torch.compile def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor: - Assert.eq(position_ids is not None, self._config.use_absolute_position_embeddings) + Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) group = self._tensor_space.distributed.tensor_group if self._parallel_embeddings: input_mask = (input_ >= self._vocab_start_index) * (input_ < self._vocab_end_index) masked_input = (input_ - self._vocab_start_index) * input_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) # noqa embeddings = reduce_forward(embeddings, group) - if self._config.use_absolute_position_embeddings: + if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: if self._sequence_parallel: input_ = split(input_, group=group, dim=0) - if self._config.use_absolute_position_embeddings: + if self._use_absolute_position_embeddings: position_ids = split(position_ids, group=group, dim=0) # handle masked tokens if mask_inputs: @@ -99,7 +107,7 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) else: embeddings = torch.embedding(self.word_embeddings_weight, input_) - if self._config.use_absolute_position_embeddings: + if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if mask_inputs: embeddings = embeddings * input_mask.unsqueeze(2) @@ -108,7 +116,7 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask if self._sequence_parallel else self._tensor_space.distributed.pp_generator ): - embeddings = torch.dropout(embeddings, self._config.embedding_dropout, self.training) + embeddings = torch.dropout(embeddings, self._dropout_p, self.training) return embeddings.to(dtype=self._residual_dtype) def forward( diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 098b2463b..bc672725c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -15,16 +15,16 @@ from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward -from fast_llm.layers.block.block import DebugLayer from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.language_model.config import ( - LanguageModelConfig, + LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs, LanguageModelLossNames, ) from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT -from fast_llm.tensor import ParameterMeta, TensorMeta +from fast_llm.logging import log_distributed_tensor +from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ from fast_llm.utils import Assert, div, get_unique logger = logging.getLogger(__name__) @@ -32,67 +32,61 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelConfig](Configurable[ConfigType], Layer): +class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], 
Layer): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). """ - config_class: typing.ClassVar[type[LanguageModelConfig]] = LanguageModelConfig + config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig def __init__( self, - config: LanguageModelConfig, + config: LanguageModelBaseConfig, tensor_space: TensorSpace, prediction_distance: int, ): super().__init__(config) - # TODO: Avoid default_block_config? - self._debug = DebugLayer( - tensor_space, - f"Block {self._block_index} {self._name}", - self._config.default_block_config.debug_transformer, - self._config.default_block_config.debug_transformer_memory, - ) + self._debug_transformer = config.transformer.debug_transformer + self._tie_word_embeddings = config.tie_word_embeddings self._tensor_space = tensor_space self._group_size = tensor_space.distributed_config.tensor_parallel self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel - self._parallel_embeddings = ( - tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings - ) + self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings self._sequence_parallel_logits = ( - tensor_space.distributed_config.sequence_tensor_parallel and not self._config.parallel_embeddings + tensor_space.distributed_config.sequence_tensor_parallel and not config.parallel_embeddings ) - self._cross_entropy_splits = self._config.cross_entropy_splits + self._cross_entropy_splits = config.cross_entropy_splits if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings hidden_dim = self._tensor_space[LanguageModelDimNames.hidden] self._loss_coefficient = ( - self._config.prediction_loss_coefficient[prediction_distance] - if self._config.prediction_loss_coefficient - else 1.0 + config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 ) self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) - # TODO: Avoid default_block_config? 
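In the non-parallel branch above, absolute position embeddings are just an additive table lookup before dropout; schematically (illustrative names, ignoring the tensor- and sequence-parallel paths):

import torch

def embed_tokens(
    input_ids: torch.Tensor,                # (batch, sequence)
    word_weight: torch.Tensor,              # (vocab, hidden)
    position_weight: torch.Tensor | None,   # (max_positions, hidden) or None
    dropout_p: float,
    training: bool,
) -> torch.Tensor:
    embeddings = torch.embedding(word_weight, input_ids)
    if position_weight is not None:
        position_ids = torch.arange(input_ids.size(1), device=input_ids.device)
        embeddings = embeddings + torch.nn.functional.embedding(position_ids, position_weight)
    return torch.dropout(embeddings, dropout_p, training)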
- self.final_norm = self._config.default_block_config.normalization.get_layer(hidden_dim) - self._logits_scale_factor = self._config.logits_scale_factor - self._language_model_loss_factor = self._config.language_model_loss_factor - self._distillation_loss_factor = self._config.distillation_loss_factor - self._z_loss_factor = self._config.logit_z_loss + self.final_norm = config.transformer.normalization.get_layer(hidden_dim) + self._logits_scale_factor = config.logits_scale_factor + self._language_model_loss_factor = config.language_model_loss_factor + self._distillation_loss_factor = config.distillation_loss_factor + self._z_loss_factor = config.logit_z_loss # Distance of the target token prediction # 0: next-token prediction # >0: multi-token prediction (MTP) Assert.geq(prediction_distance, 0) self._prediction_distance = prediction_distance - self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1 + self._is_last_head = self._prediction_distance == config.prediction_heads - 1 - self._init_output_weights(hidden_dim, self._config) + self._init_output_weights(hidden_dim, config) - if not self._config.enable_dpo: - self._cross_entropy_impl = self._config.cross_entropy_impl + self._use_dpo_loss = config.enable_dpo + if self._use_dpo_loss: + self.dpo_beta = config.dpo_beta + else: + self._cross_entropy_impl = config.cross_entropy_impl + self._distillation_loss_implementation = config.distillation_loss_implementation if self._cross_entropy_impl == CrossEntropyImpl.auto: if self._parallel_embeddings: self._cross_entropy_impl = CrossEntropyImpl.fused @@ -110,7 +104,7 @@ def __init__( def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: # Only the first head defines the output weights - if self._config.tie_word_embeddings or self._prediction_distance > 0: + if self._tie_word_embeddings or self._prediction_distance > 0: return # untie embedding weights vocab_dim = self._tensor_space[ @@ -118,7 +112,11 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: ] self.output_weights = ParameterMeta.from_dims( (vocab_dim, hidden_dim), - init_method=self._config.output_weight_initialization_method, + init_method=init_normal_( + std=config.init_method_std_embed, + min_val=config.init_method_min_embed, + max_val=config.init_method_max_embed, + ), lr_scale=config.output_lr_scale, ) @@ -203,7 +201,7 @@ def _get_targets( self, kwargs: dict ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None: # Loss mask for distillation. (Labels are already masked.) 
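The `prediction_distance` plumbing above drives multi-token prediction: head d is trained to predict the token d + 1 positions ahead. A toy illustration of the corresponding label shift (not the actual target-building code):

import torch

def shifted_labels(token_ids: torch.Tensor, prediction_distance: int) -> torch.Tensor:
    # For distance 0 this is ordinary next-token prediction; for distance d the
    # target at position t is token t + 1 + d, so the last 1 + d positions have no target.
    return token_ids[:, 1 + prediction_distance :]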
- if self._config.enable_dpo: + if self._use_dpo_loss: dpo_target = kwargs.get(LanguageModelKwargs.labels) lm_target = None distillation_target = None @@ -253,7 +251,7 @@ def _get_targets( return targets def _get_output_weights(self, kwargs: dict) -> torch.Tensor: - if self._config.tie_word_embeddings: + if self._tie_word_embeddings: return kwargs[WORD_EMBEDDINGS_WEIGHT] if self._prediction_distance > 0: return kwargs[OUTPUT_WEIGHTS] @@ -340,22 +338,35 @@ def _logits_cross_entropy_forward_backward( LanguageModelLossNames.z_loss, logits_scale_factor=self._logits_scale_factor, ) - if self._debug.enabled and self._cross_entropy_splits is None: - vocab_dim = ( + if self._debug_transformer and self._cross_entropy_splits is None: + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp + ] + dims = [*kwargs[LanguageModelKwargs.hidden_dims][:-1], vocab_dim] + sequence_index = 1 - int(kwargs[LanguageModelKwargs.sequence_first]) + dims[sequence_index] = ( + TensorDim( + LanguageModelDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor + ) + if self._sequence_parallel_logits + else TensorDim(LanguageModelDimNames.sequence_q, dims[sequence_index].global_size) ) - sequence_dim = ( - LanguageModelDimNames.sequence_q_tp + + dim_names = ( + [LanguageModelDimNames.sequence_q_tp, LanguageModelDimNames.vocab] if self._sequence_parallel_logits - else LanguageModelDimNames.sequence_q + else [LanguageModelDimNames.sequence_q, LanguageModelDimNames.vocab_tp] ) - batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] - dims = ( - (sequence_dim, batch_dim, vocab_dim) - if kwargs[LanguageModelKwargs.sequence_first] - else (batch_dim, sequence_dim, vocab_dim) + + dim_names.insert(int(kwargs[LanguageModelKwargs.sequence_first]), LanguageModelDimNames.batch) + log_distributed_tensor( + "", + logits, + level=self._debug_transformer, + meta=TensorMeta.from_dims(tuple(dims), tensor_name="transformer logits", dtype=logits.dtype), + distributed=self._tensor_space.distributed, + scale=self._logits_scale_factor, ) - self._debug(logits, "Language model logits", dims, kwargs, scale=self._logits_scale_factor) if targets is None: return logits * self._logits_scale_factor, None @@ -368,7 +379,7 @@ def _logits_cross_entropy_forward_backward( kwargs.get(f"{self._config.dpo_reference_model}_logits"), kwargs[LanguageModelKwargs.chosen_spans], kwargs[LanguageModelKwargs.rejected_spans], - self._config.dpo_beta, + self.dpo_beta, grad_output * self._loss_coefficient, ) else: @@ -390,7 +401,7 @@ def _logits_cross_entropy_forward_backward( lm_loss, lm_grad = None, None if distillation_target is not None and self._distillation_loss_factor > 0.0: - if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: + if self._distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, @@ -403,7 +414,7 @@ def _logits_cross_entropy_forward_backward( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits ), ) - elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: + elif self._distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( logits.flatten(0, -2), distillation_target, diff --git 
a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index 3c9f18c8d..f5d915855 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -5,7 +5,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.language_model.config import LanguageModelConfig, LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -20,11 +20,11 @@ class PositionEmbeddingPreprocessor(Preprocessor): def __init__( self, - config: LanguageModelConfig, + config: LanguageModelBaseConfig, tensor_space: TensorSpace, ): self._config = config - assert config.absolute_position_embeddings is not None + assert config.use_absolute_position_embeddings self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] @@ -34,7 +34,7 @@ def _create_tensors(self, sequence_length: int) -> None: return self._tensor_cache_max_sequence_length = sequence_length - Assert.leq(sequence_length, self._config.absolute_position_embeddings) + Assert.leq(sequence_length, self._config.num_absolute_position_embeddings) self._position_ids = torch.arange( 0, sequence_length, device=self._tensor_space.distributed.device, dtype=torch.int64 ) @@ -71,7 +71,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: class PreferenceSpanPreprocessor(Preprocessor): - def __init__(self, config: LanguageModelConfig, tensor_space: TensorSpace): + def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 00c709814..efcf2d873 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -9,7 +9,7 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: - from fast_llm.engine.config_utils.initialization import Initializer, init_fill_, init_uniform_centered_ + from fast_llm.tensor import Initializer class SSMDimNames(BlockDimNames): @@ -66,6 +66,8 @@ class DTInitType(enum.StrEnum): random = "random" def get_init_method(self, scale: float) -> "Initializer": + from fast_llm.tensor import init_fill_, init_uniform_centered_ + return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 04b27af47..550c44d0f 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -4,15 +4,13 @@ import einops import torch -from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.mamba_layer import init_kaiming_ -from fast_llm.tensor import 
ParameterMeta +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b02fbd401..1c319f490 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,15 +3,14 @@ import torch -from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias, init_kaiming_ -from fast_llm.tensor import ParameterMeta +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ from fast_llm.utils import Assert, div, get_lr_scale try: diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index e22852fe6..f5b0139cf 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -4,14 +4,13 @@ import torch -from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta +from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ from fast_llm.utils import Assert, get_lr_scale try: @@ -164,7 +163,3 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if kwargs[BlockKwargs.sequence_first]: out = out.transpose(0, 1) return out, None - - -def init_kaiming_(d_in: float) -> LambdaInitializer: - return init_normal_(0.0, math.sqrt(2.0 / d_in)) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 2db7b0ac8..b1de792e3 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -6,10 +6,11 @@ from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.layers.block.block import BlockLayer +from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import AttentionConfig, AttentionDimNames, AttentionKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import get_lr_scale try: @@ -45,52 +46,55 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, 
None -class Attention[ConfigType: AttentionConfig](BlockLayer[ConfigType]): +class Attention(Mixer): """ A self-attention layer. """ + _mixer_name: typing.ClassVar[str] = "attn" + _QUERY_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.composite_heads, - AttentionDimNames.kv_channels, + TransformerDimNames.batch, + TransformerDimNames.sequence_q, + TransformerDimNames.composite_heads, + TransformerDimNames.kv_channels, ) _KV_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.head_groups, - AttentionDimNames.kv_channels, + TransformerDimNames.batch, + TransformerDimNames.sequence_q, + TransformerDimNames.head_groups, + TransformerDimNames.kv_channels, ) _CONTEXT_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.composite_dense, + TransformerDimNames.batch, + TransformerDimNames.sequence_q, + TransformerDimNames.composite_dense, ) - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): - super().__init__(config, tensor_space, block_index, name) + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): + super().__init__(tensor_space, block_index, config.debug_transformer) self._config = config self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) - # init_method_qkv = init_normal_( - # std=self._config.init_method_std_qkv, - # min_val=self._config.init_method_min_qkv, - # max_val=self._config.init_method_max_qkv, - # ) - # init_method_std_attn_proj = init_normal_( - # std=self._config.init_method_std_attn_proj, - # min_val=self._config.init_method_min_attn_proj, - # max_val=self._config.init_method_max_attn_proj, - # ) - self._kv_channels = self._tensor_space[AttentionDimNames.kv_channels].size - self._head_groups = self._tensor_space[AttentionDimNames.head_groups].global_size - self._local_head_groups = self._tensor_space[AttentionDimNames.head_groups].size - self._local_heads_per_group = self._tensor_space[AttentionDimNames.group_heads].size + init_method_qkv = init_normal_( + std=self._config.init_method_std_qkv, + min_val=self._config.init_method_min_qkv, + max_val=self._config.init_method_max_qkv, + ) + init_method_std_attn_proj = init_normal_( + std=self._config.init_method_std_attn_proj, + min_val=self._config.init_method_min_attn_proj, + max_val=self._config.init_method_max_attn_proj, + ) + + self._kv_channels = self._tensor_space[TransformerDimNames.kv_channels].size + self._head_groups = self._tensor_space[TransformerDimNames.head_groups].global_size + self._local_head_groups = self._tensor_space[TransformerDimNames.head_groups].size + self._local_heads_per_group = self._tensor_space[TransformerDimNames.group_heads].size self._local_heads = self._local_head_groups * self._local_heads_per_group - self._softmax_scale: float = self._kv_channels ** (-self._config.attention_softmax_scale_power) + self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space[AttentionDimNames.hidden] + hidden_dim = self._tensor_space[TransformerDimNames.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) @@ -98,19 +102,19 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i # TODO: Merge the query and key-value computations? 
(harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, - self._tensor_space[AttentionDimNames.composite_query], + self._tensor_space[TransformerDimNames.composite_query], bias=self._config.add_attn_qkv_bias, - weight_init_method=self._config.qkv_weight_initialization_method, - bias_init_method=self._config.qkv_bias_initialization_method, + weight_init_method=init_method_qkv, + bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space[AttentionDimNames.composite_key_value], + self._tensor_space[TransformerDimNames.composite_key_value], bias=self._config.add_attn_qkv_bias, - weight_init_method=self._config.qkv_weight_initialization_method, - bias_init_method=self._config.qkv_bias_initialization_method, + weight_init_method=init_method_qkv, + bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) @@ -121,11 +125,11 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i # Output. self.dense = InputParallelLinear( - self._tensor_space[AttentionDimNames.composite_dense], + self._tensor_space[TransformerDimNames.composite_dense], hidden_dim, bias=self._config.add_attn_dense_bias, - weight_init_method=self._config.dense_weight_initialization_method, - bias_init_method=self._config.dense_bias_initialization_method, + weight_init_method=init_method_std_attn_proj, + bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) @@ -255,24 +259,18 @@ def _decide_window_size(self) -> int | None: return window_size - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict[str, typing.Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - sequence_first = kwargs[AttentionKwargs.sequence_first] + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + sequence_first = kwargs[TransformerKwargs.sequence_first] query, key_value = self._query_key_value(input_, sequence_first) # TODO: Move the rest to function. - if (past_key_values := kwargs.get(AttentionKwargs.past_key_values)) is not None: + if (past_key_values := kwargs.get(TransformerKwargs.past_key_values)) is not None: assert sequence_first # Clear the lists so tensors can be de-allocated key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) - if (presents := kwargs.get(AttentionKwargs.presents)) is not None: + if (presents := kwargs.get(TransformerKwargs.presents)) is not None: # Return the presents as a leaf tensors so the gradients from later micro-sequences # don't propagate to this one. 
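For the grouped-query bookkeeping used above, the local head count and the pre-softmax scale reduce to simple arithmetic; a hedged restatement with illustrative names:

def attention_shapes(kv_channels: int, head_groups: int, heads_per_group: int, scale_power: float = 0.5):
    # Each key/value head group is shared by `heads_per_group` query heads.
    local_heads = head_groups * heads_per_group
    # With scale_power = 0.5 this is the usual 1/sqrt(d_head) scaling applied before the softmax.
    softmax_scale = kv_channels ** (-scale_power)
    return local_heads, softmax_scale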
presents.append(present := key_value.detach().requires_grad_()) @@ -281,9 +279,9 @@ def forward( if self._tensor_space.distributed.sequence_data_group: key_value = ( - key_value[: kwargs[AttentionKwargs.sequence_k_dim].size] + key_value[: kwargs[TransformerKwargs.sequence_k_dim].size] if sequence_first - else key_value[:, : kwargs[AttentionKwargs.sequence_k_dim].size] + else key_value[:, : kwargs[TransformerKwargs.sequence_k_dim].size] ) if sequence_first: @@ -297,9 +295,9 @@ def forward( key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug.enabled: - self._debug(query, "query_rotary_input", self._QUERY_DIMS, kwargs) - self._debug( + if self._debug_level: + self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) + self._debug_log( key, "key_rotary_input", self._KV_DIMS, @@ -312,7 +310,7 @@ def forward( if self._use_flash_attention: assert _flash_available with set_generator(self._tensor_space.distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(AttentionKwargs.cu_seqlens_q, None)) is not None: + if (cu_seqlens_q := kwargs.get(TransformerKwargs.cu_seqlens_q, None)) is not None: out_dims = query.size() query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) @@ -322,9 +320,9 @@ def forward( key, value, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(AttentionKwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(AttentionKwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), + cu_seqlens_k=kwargs.get(TransformerKwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(TransformerKwargs.max_seqlen_q), + max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), causal=True, @@ -347,15 +345,25 @@ def forward( query.flatten(-2), key.flatten(-2), value.flatten(-2), - kwargs[AttentionKwargs.attention_mask], - kwargs[AttentionKwargs.attention_mask_value], + kwargs[TransformerKwargs.attention_mask], + kwargs[TransformerKwargs.attention_mask_value], ) - if self._debug.enabled: - self._debug(query, "query", self._QUERY_DIMS, kwargs) - self._debug(key, "key", self._KV_DIMS, kwargs) - self._debug(value, "value", self._KV_DIMS, kwargs) - self._debug(input_, "context", self._CONTEXT_DIMS, kwargs) + if self._debug_level: + self._debug_log(query, "query", self._QUERY_DIMS, kwargs) + self._debug_log( + key, + "key", + self._KV_DIMS, + kwargs, + ) + self._debug_log( + value, + "value", + self._KV_DIMS, + kwargs, + ) + self._debug_log(input_, "context", self._CONTEXT_DIMS, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) 
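The `past_key_values` / `presents` handling above amounts to concatenating cached keys and values with the freshly computed ones along the sequence dimension and re-exposing the result for the next step; a minimal sketch:

import torch

def append_kv_cache(
    new_key_value: torch.Tensor,
    past_key_value: torch.Tensor | None,
    sequence_first: bool,
) -> torch.Tensor:
    # The cache grows along the sequence dimension (dim 0 when sequence-first).
    if past_key_value is None:
        return new_key_value
    dim = 0 if sequence_first else 1
    return torch.cat((past_key_value, new_key_value), dim=dim)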
diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index bd72bd305..ebb976e63 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -3,29 +3,22 @@ import typing import warnings -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_zeros_ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import TritonConfig -from fast_llm.layers.block.config import ( - AddLinearBiasChoices, - BlockDimNames, - BlockKwargs, - BlockLayerConfig, - MixerConfig, -) +from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockDimNames, BlockKwargs from fast_llm.layers.transformer.rotary.config import RotaryConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: - from fast_llm.layers.transformer.attention import Attention + pass logger = logging.getLogger(__name__) -class AttentionDimNames(BlockDimNames): +class TransformerDimNames(BlockDimNames): # A set of common tensor dim names packed into a namespace. # Self-attention dimensions head_groups = "head_groups" @@ -38,7 +31,7 @@ class AttentionDimNames(BlockDimNames): composite_dense = "composite_dense" -class AttentionKwargs(BlockKwargs): +class TransformerKwargs(BlockKwargs): rotary_freq_q = "rotary_freq_q" rotary_freq_k = "rotary_freq_k" attention_mask = "attention_mask" @@ -52,8 +45,9 @@ class AttentionKwargs(BlockKwargs): past_key_values = "past_key_values" -@config_class(dynamic_type={BlockLayerConfig: "attention"}) -class AttentionConfig(MixerConfig): +@config_class() +class AttentionConfig(Config): + # TODO: Make mixer class dynamic. _abstract = False # TODO: Review names @@ -113,30 +107,7 @@ class AttentionConfig(MixerConfig): valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - qkv_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for the query, key and value layer weights. Default: hidden_size**-0.5", - hint=FieldHint.feature, - ) - qkv_bias_initialization: InitializationConfig = Field( - desc="Initialization configuration for the query, key and value layer biases. Default: fill with zeros.", - hint=FieldHint.feature, - ) - dense_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for the dense layer weight. Default: (2 * num_blocks * hidden_size)**-0.5", - hint=FieldHint.feature, - ) - dense_bias_initialization: InitializationConfig = Field( - desc="Initialization configuration for the dense layer biases. Default: fill with zeros.", - hint=FieldHint.feature, - ) - def _validate(self) -> None: - - with self._set_implicit_default(): - if self.kv_channels is None: - # TODO: hidden_size not yet validated. 
- self.kv_channels = div(self.block.block_sequence.hidden_size, self.num_attention_heads) - super()._validate() if not TritonConfig.TRITON_ENABLED: @@ -159,74 +130,182 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim( head_groups := TensorDim( - AttentionDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - AttentionDimNames.group_heads, + TransformerDimNames.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(AttentionDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(AttentionDimNames.kv_channels, self.kv_channels)) - tensor_space.add_tensor_dim(CompositeTensorDim(AttentionDimNames.composite_heads, (head_groups, group_heads))) + tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim( + CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) + ) tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) ) - def get_block(self) -> "Attention": - pass - @functools.cached_property - def add_qkv_bias(self) -> bool: - if isinstance(self.block.add_linear_biases, bool): - return self.block.add_linear_biases - if self.block.add_linear_biases == AddLinearBiasChoices.nowhere: - return False - return True +@config_class() +# TODO: Use composition for attention config +class TransformerConfig(AttentionConfig, BlockConfig): + _abstract = False - @functools.cached_property - def add_dense_bias(self) -> bool: - if isinstance(self.block.add_linear_biases, bool): - return self.block.add_linear_biases - if self.block.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False + # TODO: Review names + init_method_std: float = Field( + default=None, + desc="Default scale for weight initialization. Default: hidden_size**-0.5", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max: float | None = Field( + default=None, + desc="Max value for clamping initialized weights. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min: float | None = Field( + default=None, + desc="Min value for clamping initialized weights. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_qkv: float = Field( + default=None, + desc="Scale for the query, key and value weight initialization. 
Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_qkv: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for query, key and value matrices. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_qkv: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for query, key and value matrices. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_attn_proj: float = Field( + default=None, + desc="Scale for the attention projection weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_attn_proj: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for attention projection. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_attn_proj: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_mlp_1: float = Field( + default=None, + desc="Scale for the MLP first layer weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_mlp_1: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for MLP first layer. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_mlp_1: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_mlp_2: float = Field( + default=None, + desc="Scale for the MLP second layer weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_mlp_2: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_mlp_2: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", + hint=FieldHint.optional, + ) + # Use random inits instead of constant values, useful for debugging. + random_bias_init: bool = Field( + default=False, + desc="Initialize the biases using the initialization method of their respective weights instead of setting them to zero. 
Used to test for issues that may not be visible when the biases are zero.", + hint=FieldHint.testing, + ) - @functools.cached_property - def qkv_weight_initialization_method(self) -> Initializer: - if self.qkv_weight_initialization.has_initialization: - return self.qkv_weight_initialization.get_initializer() - else: - return self.block.block_sequence.hidden_size**-0.5 + def _validate(self) -> None: + with self._set_implicit_default(): + if self.kv_channels is None: + self.kv_channels = div(self.hidden_size, self.num_attention_heads) + if self.init_method_std is None: + self.init_method_std = self.hidden_size**-0.5 + if self.init_method_std_qkv is None: + self.init_method_std_qkv = self.init_method_std + if self.init_method_std_attn_proj is None: + self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 + if self.init_method_std_mlp_1 is None: + self.init_method_std_mlp_1 = self.init_method_std + if self.init_method_std_mlp_2 is None: + self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 + if self.init_method_max_qkv is None: + self.init_method_max_qkv = self.init_method_max + if self.init_method_min_qkv is None: + self.init_method_min_qkv = self.init_method_min + if self.init_method_max_attn_proj is None: + self.init_method_max_attn_proj = self.init_method_max + if self.init_method_min_attn_proj is None: + self.init_method_min_attn_proj = self.init_method_min + if self.init_method_max_mlp_1 is None: + self.init_method_max_mlp_1 = self.init_method_max + if self.init_method_min_mlp_1 is None: + self.init_method_min_mlp_1 = self.init_method_min + if self.init_method_max_mlp_2 is None: + self.init_method_max_mlp_2 = self.init_method_max + if self.init_method_min_mlp_2 is None: + self.init_method_min_mlp_2 = self.init_method_min + if self.init_method_min is not None and self.init_method_max is not None: + Assert.leq(self.init_method_min, self.init_method_max) + if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + Assert.leq(self.init_method_min, self.init_method_max) + if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) + if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: + Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) + if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: + Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) + if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: + Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) - @functools.cached_property - def qkv_bias_initialization_method(self) -> Initializer: - if self.qkv_bias_initialization.has_initialization: - assert self.add_qkv_bias - return self.qkv_bias_initialization.get_initializer() - else: - return init_zeros_ + super()._validate() - @functools.cached_property - def dense_weight_initialization_method(self) -> Initializer: - if self.dense_weight_initialization.has_initialization: - return self.dense_weight_initialization.get_initializer() - else: - return self.block.block_sequence.hidden_size**-0.5 / max(2 * self.block.block_sequence.num_blocks, 1) + @property + def add_attn_qkv_bias(self) -> bool: + if isinstance(self.add_linear_biases, bool): + return self.add_linear_biases + if self.add_linear_biases == AddLinearBiasChoices.nowhere: + return False + return True - 
@functools.cached_property - def dense_bias_initialization_method(self) -> Initializer: - if self.dense_bias_initialization.has_initialization: - assert self.add_dense_bias - return self.dense_bias_initialization.get_initializer() - else: - return init_zeros_ + @property + def add_attn_dense_bias(self) -> bool: + if isinstance(self.add_linear_biases, bool): + return self.add_linear_biases + if self.add_linear_biases == AddLinearBiasChoices.everywhere: + return True + return False diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 16e5811e6..3f0e14eb7 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -5,7 +5,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -21,7 +21,7 @@ class BackupAttentionPreprocessor(Preprocessor): def __init__( self, - config: AttentionConfig, + config: TransformerConfig, tensor_space: TensorSpace, ): self._config = config @@ -51,13 +51,13 @@ def _create_tensors(self, sequence_length: int) -> None: ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[AttentionKwargs.sequence_length]) - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size - kwargs[AttentionKwargs.attention_mask] = self._mask[ + self._create_tensors(kwargs[TransformerKwargs.sequence_length]) + sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + kwargs[TransformerKwargs.attention_mask] = self._mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: + if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths, None)) is not None: seq_ids = torch.stack( [ torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) @@ -65,33 +65,33 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: ] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) - kwargs[AttentionKwargs.attention_mask] = ( - kwargs[AttentionKwargs.attention_mask] + kwargs[TransformerKwargs.attention_mask] = ( + kwargs[TransformerKwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] ) - kwargs[AttentionKwargs.attention_mask_value] = self._mask_value + kwargs[TransformerKwargs.attention_mask_value] = self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[AttentionKwargs.attention_mask] = TensorMeta.from_dims( + kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( ( self._scalar_dim, self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], + kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, - kwargs[AttentionKwargs.sequence_k_dim], + kwargs[TransformerKwargs.sequence_k_dim], ), - tensor_name=AttentionKwargs.attention_mask, + tensor_name=TransformerKwargs.attention_mask, dtype=torch.bool, ) - kwargs[AttentionKwargs.attention_mask_value] = TensorMeta.from_dims( + 
kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( (self._scalar_dim,), - tensor_name=AttentionKwargs.attention_mask_value, + tensor_name=TransformerKwargs.attention_mask_value, dtype=self._tensor_space.distributed_config.training_dtype.torch, ) class FlashAttnVarlenPreprocessor(Preprocessor): - def __init__(self, config: AttentionConfig, tensor_space: TensorSpace): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config @@ -107,12 +107,12 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: also contain previous tokens from the first document in micro-sequence. We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. """ - if AttentionKwargs.sequence_lengths not in kwargs: + if TransformerKwargs.sequence_lengths not in kwargs: return - sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size - if sequence_q < kwargs[AttentionKwargs.sequence_length]: + sequence_lengths = kwargs[TransformerKwargs.sequence_lengths] + sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + if sequence_q < kwargs[TransformerKwargs.sequence_length]: cumsums = [torch.cumsum(x, dim=0) for x in sequence_lengths] # The first and last documents in a microsequence need to be handled separately. Include all tokens from other documents # in the microsequence. We need to consider all keys computed so far from the first sample. We also store the offsets @@ -146,17 +146,17 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: else: seqlens_q = torch.cat(sequence_lengths) seqlens_k = torch.cat(sequence_lengths) - kwargs[AttentionKwargs.cu_seqlens_q] = torch.cat( + kwargs[TransformerKwargs.cu_seqlens_q] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[AttentionKwargs.cu_seqlens_k] = torch.cat( + kwargs[TransformerKwargs.cu_seqlens_k] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[AttentionKwargs.max_seqlen_q] = seqlens_q.max() - kwargs[AttentionKwargs.max_seqlen_k] = seqlens_k.max() + kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() + kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py index 9f8732f85..c357411b6 100644 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ b/fast_llm/layers/transformer/rotary/preprocessing.py @@ -4,7 +4,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig from fast_llm.tensor import TensorMeta @@ -26,34 +26,34 @@ def __init__( self._tensor_space = 
tensor_space self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[AttentionKwargs.sequence_length]) - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k + self._create_tensors(kwargs[TransformerKwargs.sequence_length]) + sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k ] - kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], + kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=AttentionKwargs.rotary_freq_q, + tensor_name=TransformerKwargs.rotary_freq_q, ) - kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], + kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=AttentionKwargs.rotary_freq_k, + tensor_name=TransformerKwargs.rotary_freq_k, ) def _create_tensors(self, sequence_length: int) -> None: diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index ebb629aa1..17b18a1ca 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -8,7 +8,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, @@ -83,44 +83,44 @@ def __init__( self._tensor_space = tensor_space if self._tensor_space is not None: self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None - self._create_tensors(kwargs[AttentionKwargs.sequence_length]) - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k + self._create_tensors(kwargs[TransformerKwargs.sequence_length]) + sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + 
kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k ] - kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None - kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], + kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=AttentionKwargs.rotary_freq_q, + tensor_name=TransformerKwargs.rotary_freq_q, ) - kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], + kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=AttentionKwargs.rotary_freq_k, + tensor_name=TransformerKwargs.rotary_freq_k, ) def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: rotary_fn = triton_rotary_autograd_ if self._config.triton else apply_rotary_embeddings - query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) - key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) + query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q]) + key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k]) return query, key def _create_tensors(self, sequence_length: int) -> None: diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index eb24ef183..3c0ad8ab4 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -36,7 +36,7 @@ def get_layers(self) -> list[Layer]: self._tensor_space, block_index=i + 1, ) - for i in range(self._config.transformer.num_blocks) + for i in range(self._config.transformer.num_layers) ], CustomHead(self._config, self._tensor_space), ] diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index a7fcad82d..0da16428e 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -9,7 +9,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelConfig +from fast_llm.layers.language_model.config import LanguageModelBaseConfig from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds from fast_llm.utils import Assert, div @@ -119,7 +119,7 @@ def micro_batch_splits(self) -> int: @config_class() -class GPTBaseModelConfig(LanguageModelConfig): +class GPTBaseModelConfig(LanguageModelBaseConfig): _abstract = False # Debug, to get an exact match with megatron init. @@ -192,12 +192,15 @@ class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() def _validate(self) -> None: + if self.batch.sequence_length is None: + # TODO: Drop this. 
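The `TransformerConfig._validate` hunk earlier in this patch fills in depth-scaled initialization defaults; restated as plain arithmetic (a sketch of the defaults only, not the validation itself):

def default_init_stds(hidden_size: int, num_layers: int) -> dict[str, float]:
    # Base scale follows hidden_size**-0.5; the projections that write into the
    # residual stream (attention output, second MLP layer) are further divided
    # by sqrt(2 * num_layers) so the residual variance stays roughly depth-invariant.
    base = hidden_size**-0.5
    residual = base / max(2 * num_layers, 1) ** 0.5
    return {
        "qkv": base,
        "attn_proj": residual,
        "mlp_1": base,
        "mlp_2": residual,
    }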
+ self.batch.sequence_length = self.model.base_model.max_position_embeddings if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) super()._validate() if self.model.base_model.use_absolute_position_embeddings: - Assert.geq(self.model.base_model.absolute_position_embeddings, self.batch.sequence_length) + Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) distillation_model = self.model.base_model.distillation_model dpo_reference_model = self.model.base_model.dpo_reference_model diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index f3e57fe13..2dbef77f3 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -176,7 +176,7 @@ def _create_weight_converters( self, ) -> list[WeightConverter]: converters = [] - num_layers = self._model.config.base_model.transformer.num_blocks + num_layers = self._model.config.base_model.transformer.num_layers # Embeddings converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) @@ -256,7 +256,7 @@ def _create_transformer_layer_converters( return converters def _create_lm_head_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_blocks + num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] @@ -654,7 +654,7 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig # Override base method to handle the MTP heads def _create_lm_head_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_blocks + num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 4e3f258fc..cf7da3872 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -9,7 +9,7 @@ from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM -from fast_llm.layers.transformer.config import AttentionKwargs +from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner @@ -86,12 +86,12 @@ def forward( if past_key_values is not None: # The transformers will use the past keys and values to this list. - kwargs[AttentionKwargs.past_key_values] = past_key_values + kwargs[TransformerKwargs.past_key_values] = past_key_values # TODO: preprocess needs to know about the past. raise NotImplementedError() if use_cache: # The transformers will save the present keys and values to this list. 
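The HuggingFace wrapper above threads the cache through the same kwargs dictionary that the attention layers read and write; schematically (an illustrative helper, not the wrapper's real signature):

def forward_with_cache(run_model, token_ids, past_key_values=None, use_cache=True):
    kwargs = {}
    if past_key_values is not None:
        kwargs["past_key_values"] = past_key_values
    if use_cache:
        # Attention layers append one (key, value) tensor per layer to this list.
        kwargs["presents"] = []
    logits = run_model(token_ids, kwargs)
    return logits, kwargs.get("presents")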
- kwargs[AttentionKwargs.presents] = [] + kwargs[TransformerKwargs.presents] = [] if output_hidden_states: kwargs["output_hidden_states"] = True @@ -117,11 +117,11 @@ def forward( outputs = (logits,) if use_cache: - outputs += (kwargs[AttentionKwargs.presents],) + outputs += (kwargs[TransformerKwargs.presents],) return outputs return transformers.modeling_outputs.CausalLMOutputWithPast( logits=logits, hidden_states=hidden_states, - past_key_values=kwargs[AttentionKwargs.presents], + past_key_values=kwargs[TransformerKwargs.presents], ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 30842597d..da647de57 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -16,7 +16,7 @@ from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor from fast_llm.layers.transformer.block import TransformerBlock -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -68,7 +68,7 @@ def get_output_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, # TODO MTP: which index? - block_index=max(self._config.transformer.num_blocks + i, 1), + block_index=max(self._config.transformer.num_layers + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -93,9 +93,9 @@ def get_layers(self) -> list[Layer]: block_index=i + 1, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. - return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_blocks - 1, + return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, ) - for i in range(self._config.transformer.num_blocks) + for i in range(self._config.transformer.num_layers) ], *self.get_output_layers(), ] @@ -119,7 +119,7 @@ def preprocess_meta( truncate_documents = True batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) - batch_dim = TensorDim(AttentionDimNames.batch, micro_batch_size * batch_data.size, batch_data) + batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) if micro_sequence_length is None: micro_sequence_length = sequence_length @@ -128,13 +128,13 @@ def preprocess_meta( # TODO: Calculate hidden dims elsewhere? 
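        # Illustrative example, assuming hypothetical sizes and no sequence-data parallelism:
        # with sequence_length=2048 and micro_sequence_length=512 there are four micro-sequences;
        # sequence_q stays 512 for each of them while sequence_k grows 512 -> 1024 -> 1536 -> 2048,
        # since sequence_k = sequence_k_past + sequence_q below.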
sequence_q_dim = TensorDim( - AttentionDimNames.sequence_q, + TransformerDimNames.sequence_q, micro_sequence_length, self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), ) hidden_sequence_q_dim = ( TensorDim( - AttentionDimNames.sequence_q_tp, + TransformerDimNames.sequence_q_tp, micro_sequence_length, self._tensor_space.distributed_config.get_distributed_dim( DistributedDimNames.tensor_and_sequence_data @@ -151,7 +151,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space[AttentionDimNames.hidden] + hidden_dim = self._tensor_space[TransformerDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first @@ -160,10 +160,10 @@ def preprocess_meta( common_kwargs = { LanguageModelKwargs.phase: phase, - AttentionKwargs.sequence_first: sequence_first, - AttentionKwargs.hidden_dims: hidden_dims, - AttentionKwargs.sequence_length: sequence_length, - AttentionKwargs.sequence_q_dim: sequence_q_dim, + TransformerKwargs.sequence_first: sequence_first, + TransformerKwargs.hidden_dims: hidden_dims, + TransformerKwargs.sequence_length: sequence_length, + TransformerKwargs.sequence_q_dim: sequence_q_dim, LanguageModelKwargs.mask_inputs: not truncate_documents, } @@ -182,7 +182,7 @@ def preprocess_meta( preprocessed_meta = [] for i, sequence_k_past in enumerate(sequence_k_pasts): sequence_k = sequence_k_past + sequence_q_dim.size - sequence_k_dim = TensorDim(AttentionDimNames.sequence_k, sequence_k) + sequence_k_dim = TensorDim(TransformerDimNames.sequence_k, sequence_k) tokens = TensorMeta.from_dims( hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 @@ -190,7 +190,7 @@ def preprocess_meta( kwargs = { **common_kwargs, - AttentionKwargs.sequence_k_dim: sequence_k_dim, + TransformerKwargs.sequence_k_dim: sequence_k_dim, } if phase != PhaseType.inference: kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims( @@ -202,10 +202,10 @@ def preprocess_meta( for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] for key in ( - AttentionKwargs.sequence_first, - AttentionKwargs.sequence_length, - AttentionKwargs.sequence_q_dim, - AttentionKwargs.sequence_k_dim, + TransformerKwargs.sequence_first, + TransformerKwargs.sequence_length, + TransformerKwargs.sequence_q_dim, + TransformerKwargs.sequence_k_dim, ): Assert.eq(reference_kwargs_[key], kwargs[key]) reference_kwargs[name] = reference_kwargs_ @@ -231,8 +231,8 @@ def preprocess( preprocessed_meta = self.preprocess_meta(batch.token_ids, phase) _, common_kwargs = preprocessed_meta[0] - sequence_q = common_kwargs[AttentionKwargs.sequence_q_dim].size - sequence_first = common_kwargs[AttentionKwargs.sequence_first] + sequence_q = common_kwargs[TransformerKwargs.sequence_q_dim].size + sequence_first = common_kwargs[TransformerKwargs.sequence_first] prediction_heads: int = self._config.prediction_heads batch.token_ids = batch.token_ids.to( @@ -264,14 +264,14 @@ def preprocess( preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): - sequence_k = kwargs_meta[AttentionKwargs.sequence_k_dim].size + sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size if sequence_first: tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] else: # TODO: Avoid multiple contiguous calls? 
tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: - kwargs_meta[AttentionKwargs.sequence_lengths] = batch.sequence_lengths + kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths if batch.chosen_spans is not None: kwargs_meta[LanguageModelKwargs.chosen_spans] = batch.chosen_spans if batch.rejected_spans is not None: @@ -283,8 +283,8 @@ def preprocess( presents = None if i == len(preprocessed_meta) - 1 else [] kwargs = { **kwargs_meta, - AttentionKwargs.past_key_values: pasts, - AttentionKwargs.presents: presents, + TransformerKwargs.past_key_values: pasts, + TransformerKwargs.presents: presents, } if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels @@ -372,7 +372,7 @@ def loss_defs(self) -> list[LossDef]: LossDef( name=MLPLossNames.load_balancing_loss, formatted_name="load balancing loss", - count=self._config.transformer.num_blocks, + count=self._config.transformer.num_layers, ) ) if self._config.transformer.expert_z_loss_coefficient: @@ -380,7 +380,7 @@ def loss_defs(self) -> list[LossDef]: LossDef( name=MLPLossNames.router_z_loss, formatted_name="router z loss", - count=self._config.transformer.num_blocks, + count=self._config.transformer.num_layers, ) ) if self._config.logit_z_loss: @@ -421,7 +421,7 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, s consumed_tokens_per_iteration = sequence_length * batch_size - num_transformer_layers = transformer_config.num_blocks + self._config.base_model.prediction_heads - 1 + num_transformer_layers = transformer_config.num_layers + self._config.base_model.prediction_heads - 1 transformer_flops_base = ( 2 * checkpoint_activations_factor * consumed_tokens_per_iteration * num_transformer_layers ) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index a351522ca..9427f69be 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -62,13 +62,13 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_blocks + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers - if len(self.hybrid_block_layout) != self.transformer.num_blocks: - message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_blocks}" - if self.transformer.num_blocks % len(self.hybrid_block_layout) != 0: + if len(self.hybrid_block_layout) != self.transformer.num_layers: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" + if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: raise ValueError(message) - num_repeats = self.transformer.num_blocks // len(self.hybrid_block_layout) + num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index fb24c1aec..43e3c67e5 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -219,7 +219,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: converters = 
super()._create_weight_converters() or [] - num_layers = self._model.config.base_model.transformer.num_blocks + num_layers = self._model.config.base_model.transformer.num_layers ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear for i in range(num_layers): @@ -383,7 +383,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: # not using super() because LLamba model is called backbone in the checkpoints converters = [] - num_layers = self._model.config.base_model.transformer.num_blocks + num_layers = self._model.config.base_model.transformer.num_layers norm_bias: bool = False ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear @@ -572,7 +572,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: converters = super()._create_weight_converters() - num_layers = self._model.config.base_model.transformer.num_blocks + num_layers = self._model.config.base_model.transformer.num_layers norm_bias: bool = False # Embedding and output diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b12d12072..d080e6a1e 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,12 +1,13 @@ +import abc import functools import logging +import math import typing import torch from fast_llm.core.distributed import ReduceOp from fast_llm.core.ops import reduce_op -from fast_llm.engine.config_utils.initialization import Initializer, LambdaInitializer from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed @@ -360,3 +361,70 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None: triton_copy(grad, param.grad_buffer) # noqa else: triton_add(grad, param.grad_buffer, out=param.grad_buffer) # noqa + + +class Initializer(abc.ABC): + @abc.abstractmethod + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + pass + + requires_global_initialization = False + + +class LambdaInitializer(Initializer): + def __init__( + self, + init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None], + requires_global_initialization: bool = False, + ) -> None: + self._init_method = init_method + self.requires_global_initialization = requires_global_initialization + + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + return self._init_method(meta, tensor, generator) + + +def init_fill_(value: float) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor.fill_(value) + + return LambdaInitializer(init_) + + +init_zeros_ = init_fill_(0.0) +init_ones_ = init_fill_(1.0) + + +def init_normal_( + mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor = tensor.normal_(mean, std, generator=generator) + if min_val is not None or max_val is not None: + tensor.clamp_(min=min_val, max=max_val) + + return LambdaInitializer(init_) + + +def init_kaiming_(d_in: float) -> LambdaInitializer: + return init_normal_(0.0, math.sqrt(2.0 / d_in)) + + +def init_uniform_( + low: float = 0.0, high: float = 1.0, min_val: float | None = None, 
max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor = tensor.uniform_(low, high, generator=generator) + if min_val is not None or max_val is not None: + tensor.clamp_(min=min_val, max=max_val) + + return LambdaInitializer(init_) + + +def init_uniform_centered_(high: float, max_val: float | None = None, mean: float = 0.0) -> LambdaInitializer: + return init_uniform_( + mean - high, + mean + high, + min_val=None if max_val is None else mean - max_val, + max_val=None if max_val is None else mean + max_val, + ) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 8c33aed4d..9a878c494 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -9,7 +9,7 @@ from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.transformer.config import AttentionKwargs +from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -198,8 +198,8 @@ def test_lm_head( else: loss_mask = None kwargs = { - AttentionKwargs.sequence_first: sequence_first, - AttentionKwargs.grad_output: 1.0, + TransformerKwargs.sequence_first: sequence_first, + TransformerKwargs.grad_output: 1.0, } if config.distillation_model is None: target = torch.randint( diff --git a/tests/models/test_generate.py b/tests/models/test_generate.py index cb9c69ccb..7f0b902f8 100644 --- a/tests/models/test_generate.py +++ b/tests/models/test_generate.py @@ -354,7 +354,7 @@ def _test_forward_return_hidden_states( # hidden_states include embeddings layer assert ( - len(res_fast_llm.hidden_states) - 1 == fast_llm_model.config.fast_llm_config.base_model.transformer.num_blocks + len(res_fast_llm.hidden_states) - 1 == fast_llm_model.config.fast_llm_config.base_model.transformer.num_layers ) diff --git a/tests/test_attention.py b/tests/test_attention.py index 534e3800e..dd36b840a 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -6,7 +6,7 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.utils import Assert @@ -77,13 +77,13 @@ def test_varlen_preprocessor(): varlen_preprocessor = FlashAttnVarlenPreprocessor(transformer_cfg, tensor_space=tensor_space) for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): kwargs = { - AttentionKwargs.sequence_q_dim: TensorDim(AttentionDimNames.sequence_k, micro_sequence_length), - AttentionKwargs.sequence_k_dim: TensorDim( - AttentionDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length + TransformerKwargs.sequence_q_dim: TensorDim(TransformerDimNames.sequence_k, micro_sequence_length), + TransformerKwargs.sequence_k_dim: TensorDim( + TransformerDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length ), - 
AttentionKwargs.sequence_length: sequence_length, - AttentionKwargs.sequence_lengths: sequence_lengths, + TransformerKwargs.sequence_length: sequence_length, + TransformerKwargs.sequence_lengths: sequence_lengths, } varlen_preprocessor.preprocess(None, kwargs) - Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) - Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) + Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) + Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 6c4c7f0cb..694faa55b 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -9,7 +9,7 @@ from fast_llm.engine.schedule.config import ScheduleConfig from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.transformer.config import AttentionKwargs +from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat from fast_llm.models.ssm.model import HybridSSMModel @@ -71,8 +71,8 @@ def test_load_from_llamba_checkpoint(): schedule_runner.setup(model.distributed, optimizer=None) common_kwargs = { - AttentionKwargs.sequence_first: True, - AttentionKwargs.grad_output: False, + TransformerKwargs.sequence_first: True, + TransformerKwargs.grad_output: False, } input_data = [(x, common_kwargs)] diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 4705ebb79..722d8d63a 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -162,7 +162,6 @@ def _update_and_add_testing_config( "model.base_model.transformer.num_attention_heads=8", "model.base_model.transformer.head_groups=8", "model.base_model.transformer.init_method_std=0.022", - "model.base_model.transformer.use_position_embeddings=True", f"model.base_model.vocab_size={MODEL_TEST_VOCAB_SIZE}", f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", @@ -259,7 +258,6 @@ def _update_and_add_testing_config( extra_args=[ "model.base_model.transformer.head_groups=4", "model.base_model.transformer.rotary.type=default", - "model.base_model.transformer.use_position_embeddings=False", # Unused, but prevents issues with conversion tests. 
"model.base_model.max_position_embeddings=2048", ], From b68d36048852f6d78c2c8506f76f1708ada1f77e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 15:58:34 -0400 Subject: [PATCH 36/82] stuff --- fast_llm/layers/block/config.py | 25 ++- fast_llm/layers/block/mlp/config.py | 54 ++++- fast_llm/layers/transformer/attention.py | 74 +++---- fast_llm/layers/transformer/config.py | 192 ++++++------------ fast_llm/layers/transformer/preprocessing.py | 48 ++--- .../transformer/rotary/preprocessing.py | 26 +-- fast_llm/layers/transformer/rotary/rotary.py | 30 +-- fast_llm/models/gpt/conversion.py | 6 +- fast_llm/models/gpt/huggingface.py | 10 +- fast_llm/models/gpt/model.py | 42 ++-- tests/layers/test_lm_head.py | 6 +- tests/test_attention.py | 16 +- tests/test_ssms.py | 6 +- 13 files changed, 266 insertions(+), 269 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 5a999fa6d..489cd4f3f 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -40,7 +40,7 @@ class AddLinearBiasChoices(str, enum.Enum): @config_class() -# TODO: Use composition for MLP config +# TODO: Use composition instead class BlockConfig(MLPConfig, BaseModelConfig): # TODO: Review names @@ -100,10 +100,33 @@ class BlockConfig(MLPConfig, BaseModelConfig): hint=FieldHint.feature, ) + # TODO: Review initialization + init_method_std: float = Field( + default=None, + desc="Default scale for weight initialization. Default: hidden_size**-0.5", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max: float | None = Field( + default=None, + desc="Max value for clamping initialized weights. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min: float | None = Field( + default=None, + desc="Min value for clamping initialized weights. Default: -float('inf')", + hint=FieldHint.optional, + ) + def _validate(self) -> None: with self._set_implicit_default(): if self.ffn_hidden_size is None: self.ffn_hidden_size = 4 * self.hidden_size + # TODO: Review initialization + if self.init_method_std is None: + self.init_method_std = self.hidden_size**-0.5 + if self.init_method_min is not None and self.init_method_max is not None: + Assert.leq(self.init_method_min, self.init_method_max) super()._validate() diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 1d125c4f7..64e234544 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -72,8 +72,6 @@ class MLPConfig(Config): hint=FieldHint.architecture, ) gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.architecture) - # Default: hidden_size**-0.5 - # TODO: Allow custom initialization (InitializationConfig?) activation_type: ActivationType = Field( default=None, desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", @@ -124,11 +122,63 @@ class MLPConfig(Config): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) + # TODO: Review initialization + init_method_std_mlp_1: float = Field( + default=None, + desc="Scale for the MLP first layer weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_mlp_1: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for MLP first layer. 
Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_mlp_1: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_mlp_2: float = Field( + default=None, + desc="Scale for the MLP second layer weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_mlp_2: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_mlp_2: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", + hint=FieldHint.optional, + ) def _validate(self) -> None: with self._set_implicit_default(): + # TODO: Make this work without inheritance. if self.activation_type is None: self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + # TODO: Review initialization + if self.init_method_std_mlp_1 is None: + self.init_method_std_mlp_1 = self.init_method_std + if self.init_method_std_mlp_2 is None: + self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 + if self.init_method_max_mlp_1 is None: + self.init_method_max_mlp_1 = self.init_method_max + if self.init_method_min_mlp_1 is None: + self.init_method_min_mlp_1 = self.init_method_min + if self.init_method_max_mlp_2 is None: + self.init_method_max_mlp_2 = self.init_method_max + if self.init_method_min_mlp_2 is None: + self.init_method_min_mlp_2 = self.init_method_min + if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: + Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) + if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: + Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) + self.num_unshared_experts = self.num_experts - self.num_shared_experts super()._validate() diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index b1de792e3..e84e92a96 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -9,7 +9,7 @@ from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import get_lr_scale @@ -54,21 +54,21 @@ class Attention(Mixer): _mixer_name: typing.ClassVar[str] = "attn" _QUERY_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_heads, - TransformerDimNames.kv_channels, + AttentionDimNames.batch, + AttentionDimNames.sequence_q, + AttentionDimNames.composite_heads, + AttentionDimNames.kv_channels, ) _KV_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.head_groups, - TransformerDimNames.kv_channels, + AttentionDimNames.batch, + AttentionDimNames.sequence_q, + AttentionDimNames.head_groups, + AttentionDimNames.kv_channels, ) _CONTEXT_DIMS = ( - TransformerDimNames.batch, - 
TransformerDimNames.sequence_q, - TransformerDimNames.composite_dense, + AttentionDimNames.batch, + AttentionDimNames.sequence_q, + AttentionDimNames.composite_dense, ) def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): @@ -87,14 +87,14 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space[TransformerDimNames.kv_channels].size - self._head_groups = self._tensor_space[TransformerDimNames.head_groups].global_size - self._local_head_groups = self._tensor_space[TransformerDimNames.head_groups].size - self._local_heads_per_group = self._tensor_space[TransformerDimNames.group_heads].size + self._kv_channels = self._tensor_space[AttentionDimNames.kv_channels].size + self._head_groups = self._tensor_space[AttentionDimNames.head_groups].global_size + self._local_head_groups = self._tensor_space[AttentionDimNames.head_groups].size + self._local_heads_per_group = self._tensor_space[AttentionDimNames.group_heads].size self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space[TransformerDimNames.hidden] + hidden_dim = self._tensor_space[AttentionDimNames.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) @@ -102,19 +102,19 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, - self._tensor_space[TransformerDimNames.composite_query], - bias=self._config.add_attn_qkv_bias, + self._tensor_space[AttentionDimNames.composite_query], + bias=self._config.add_qkv_bias, weight_init_method=init_method_qkv, - bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, + bias_init_method=init_zeros_, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space[TransformerDimNames.composite_key_value], - bias=self._config.add_attn_qkv_bias, + self._tensor_space[AttentionDimNames.composite_key_value], + bias=self._config.add_qkv_bias, weight_init_method=init_method_qkv, - bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, + bias_init_method=init_zeros_, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) @@ -125,11 +125,11 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # Output. 
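        # Rough sketch of the tensor-parallel layout: `query` and `key_value` above are
        # OutputParallelLinear layers, so their output features are split across tensor-parallel
        # ranks, while the dense projection below is an InputParallelLinear that splits its input
        # features and sums (or reduce-scatters, with sequence parallelism) the partial outputs.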
self.dense = InputParallelLinear( - self._tensor_space[TransformerDimNames.composite_dense], + self._tensor_space[AttentionDimNames.composite_dense], hidden_dim, - bias=self._config.add_attn_dense_bias, + bias=self._config.add_dense_bias, weight_init_method=init_method_std_attn_proj, - bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_, + bias_init_method=init_zeros_, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) @@ -260,17 +260,17 @@ def _decide_window_size(self) -> int | None: return window_size def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - sequence_first = kwargs[TransformerKwargs.sequence_first] + sequence_first = kwargs[AttentionKwargs.sequence_first] query, key_value = self._query_key_value(input_, sequence_first) # TODO: Move the rest to function. - if (past_key_values := kwargs.get(TransformerKwargs.past_key_values)) is not None: + if (past_key_values := kwargs.get(AttentionKwargs.past_key_values)) is not None: assert sequence_first # Clear the lists so tensors can be de-allocated key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) - if (presents := kwargs.get(TransformerKwargs.presents)) is not None: + if (presents := kwargs.get(AttentionKwargs.presents)) is not None: # Return the presents as a leaf tensors so the gradients from later micro-sequences # don't propagate to this one. presents.append(present := key_value.detach().requires_grad_()) @@ -279,9 +279,9 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._tensor_space.distributed.sequence_data_group: key_value = ( - key_value[: kwargs[TransformerKwargs.sequence_k_dim].size] + key_value[: kwargs[AttentionKwargs.sequence_k_dim].size] if sequence_first - else key_value[:, : kwargs[TransformerKwargs.sequence_k_dim].size] + else key_value[:, : kwargs[AttentionKwargs.sequence_k_dim].size] ) if sequence_first: @@ -310,7 +310,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._use_flash_attention: assert _flash_available with set_generator(self._tensor_space.distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(TransformerKwargs.cu_seqlens_q, None)) is not None: + if (cu_seqlens_q := kwargs.get(AttentionKwargs.cu_seqlens_q, None)) is not None: out_dims = query.size() query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) @@ -320,9 +320,9 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key, value, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(TransformerKwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(TransformerKwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), + cu_seqlens_k=kwargs.get(AttentionKwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(AttentionKwargs.max_seqlen_q), + max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), causal=True, @@ -345,8 +345,8 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ query.flatten(-2), key.flatten(-2), value.flatten(-2), - kwargs[TransformerKwargs.attention_mask], - kwargs[TransformerKwargs.attention_mask_value], + kwargs[AttentionKwargs.attention_mask], + kwargs[AttentionKwargs.attention_mask_value], ) if self._debug_level: diff --git 
a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index ebb976e63..a8245f7da 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) -class TransformerDimNames(BlockDimNames): +class AttentionDimNames(BlockDimNames): # A set of common tensor dim names packed into a namespace. # Self-attention dimensions head_groups = "head_groups" @@ -31,7 +31,7 @@ class TransformerDimNames(BlockDimNames): composite_dense = "composite_dense" -class TransformerKwargs(BlockKwargs): +class AttentionKwargs(BlockKwargs): rotary_freq_q = "rotary_freq_q" rotary_freq_k = "rotary_freq_k" attention_mask = "attention_mask" @@ -106,78 +106,7 @@ class AttentionConfig(Config): " Under muP (if scaling number of heads instead of kv_channels): use 0.5.", valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - - def _validate(self) -> None: - super()._validate() - - if not TritonConfig.TRITON_ENABLED: - warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") - - Assert.multiple(self.num_attention_heads, self.head_groups) - - @functools.cached_property - def projection_size(self): - assert self._validated - return self.num_attention_heads * self.kv_channels - - def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: - return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) - - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - # Needed for multiple inheritance. - super().setup_tensor_space(tensor_space) # Noqa - - tensor_space.add_tensor_dim( - head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None - ) - ) - tensor_space.add_tensor_dim( - group_heads := TensorDim( - TransformerDimNames.group_heads, - div(self.num_attention_heads, self.head_groups), - None if self.head_groups > 1 else tensor, - ) - ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) - ) - - -@config_class() -# TODO: Use composition for attention config -class TransformerConfig(AttentionConfig, BlockConfig): - _abstract = False - - # TODO: Review names - init_method_std: float = Field( - default=None, - desc="Default scale for weight initialization. Default: hidden_size**-0.5", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max: float | None = Field( - default=None, - desc="Max value for clamping initialized weights. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min: float | None = Field( - default=None, - desc="Min value for clamping initialized weights. 
Default: -float('inf')", - hint=FieldHint.optional, - ) + # TODO: Review initialization init_method_std_qkv: float = Field( default=None, desc="Scale for the query, key and value weight initialization. Default: init_method_std", @@ -210,59 +139,17 @@ class TransformerConfig(AttentionConfig, BlockConfig): desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')", hint=FieldHint.optional, ) - init_method_std_mlp_1: float = Field( - default=None, - desc="Scale for the MLP first layer weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_1: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP first layer. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_mlp_1: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_mlp_2: float = Field( - default=None, - desc="Scale for the MLP second layer weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_2: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_mlp_2: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", - hint=FieldHint.optional, - ) - # Use random inits instead of constant values, useful for debugging. - random_bias_init: bool = Field( - default=False, - desc="Initialize the biases using the initialization method of their respective weights instead of setting them to zero. Used to test for issues that may not be visible when the biases are zero.", - hint=FieldHint.testing, - ) def _validate(self) -> None: with self._set_implicit_default(): + # TODO: Make this work without inheritance. 
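        # Illustrative defaults, assuming hypothetical values hidden_size=1024,
        # num_attention_heads=16 and num_layers=24: kv_channels defaults to 1024 // 16 = 64,
        # init_method_std to 1024 ** -0.5 ~= 0.031, the qkv init reuses that value, and the
        # attention projection is scaled down to 0.031 / (2 * 24) ** 0.5 ~= 0.0045, matching
        # the residual-branch scaling applied to the second MLP layer.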
if self.kv_channels is None: self.kv_channels = div(self.hidden_size, self.num_attention_heads) - if self.init_method_std is None: - self.init_method_std = self.hidden_size**-0.5 + # TODO: Review initialization if self.init_method_std_qkv is None: self.init_method_std_qkv = self.init_method_std if self.init_method_std_attn_proj is None: self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 - if self.init_method_std_mlp_1 is None: - self.init_method_std_mlp_1 = self.init_method_std - if self.init_method_std_mlp_2 is None: - self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 if self.init_method_max_qkv is None: self.init_method_max_qkv = self.init_method_max if self.init_method_min_qkv is None: @@ -271,31 +158,61 @@ def _validate(self) -> None: self.init_method_max_attn_proj = self.init_method_max if self.init_method_min_attn_proj is None: self.init_method_min_attn_proj = self.init_method_min - if self.init_method_max_mlp_1 is None: - self.init_method_max_mlp_1 = self.init_method_max - if self.init_method_min_mlp_1 is None: - self.init_method_min_mlp_1 = self.init_method_min - if self.init_method_max_mlp_2 is None: - self.init_method_max_mlp_2 = self.init_method_max - if self.init_method_min_mlp_2 is None: - self.init_method_min_mlp_2 = self.init_method_min - if self.init_method_min is not None and self.init_method_max is not None: - Assert.leq(self.init_method_min, self.init_method_max) if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: Assert.leq(self.init_method_min, self.init_method_max) if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) - if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: - Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) - if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: - Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) super()._validate() + if not TritonConfig.TRITON_ENABLED: + warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") + + Assert.multiple(self.num_attention_heads, self.head_groups) + + @functools.cached_property + def projection_size(self): + assert self._validated + return self.num_attention_heads * self.kv_channels + + def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: + return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + # Needed for multiple inheritance. 
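        # Worked example of the composite dims registered below, assuming hypothetical values
        # num_attention_heads=8, head_groups=4, kv_channels=64 (so group_heads=2):
        #   composite_query     -> 4 * 2 * 64 = 512
        #   composite_key_value -> 2 * 4 * 64 = 512  (keys and values stacked)
        #   composite_dense     -> 4 * 2 * 64 = 512
        # Only one of head_groups / group_heads carries the tensor-parallel dim, depending on
        # whether head_groups > 1.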
+ super().setup_tensor_space(tensor_space) # Noqa + + tensor_space.add_tensor_dim( + head_groups := TensorDim( + AttentionDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + ) + ) + tensor_space.add_tensor_dim( + group_heads := TensorDim( + AttentionDimNames.group_heads, + div(self.num_attention_heads, self.head_groups), + None if self.head_groups > 1 else tensor, + ) + ) + tensor_space.add_tensor_dim(key_and_value := TensorDim(AttentionDimNames.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(AttentionDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(CompositeTensorDim(AttentionDimNames.composite_heads, (head_groups, group_heads))) + tensor_space.add_tensor_dim( + CompositeTensorDim(AttentionDimNames.composite_query, (head_groups, group_heads, kv_channels)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(AttentionDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(AttentionDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + ) + @property - def add_attn_qkv_bias(self) -> bool: + def add_qkv_bias(self) -> bool: + # TODO: Make this work without inheritance. if isinstance(self.add_linear_biases, bool): return self.add_linear_biases if self.add_linear_biases == AddLinearBiasChoices.nowhere: @@ -303,9 +220,16 @@ def add_attn_qkv_bias(self) -> bool: return True @property - def add_attn_dense_bias(self) -> bool: + def add_dense_bias(self) -> bool: + # TODO: Make this work without inheritance. if isinstance(self.add_linear_biases, bool): return self.add_linear_biases if self.add_linear_biases == AddLinearBiasChoices.everywhere: return True return False + + +@config_class() +# TODO: Use composition instead +class TransformerConfig(AttentionConfig, BlockConfig): + _abstract = False diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 3f0e14eb7..d8fa14a6d 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -5,7 +5,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionKwargs, TransformerConfig from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -51,13 +51,13 @@ def _create_tensors(self, sequence_length: int) -> None: ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - kwargs[TransformerKwargs.attention_mask] = self._mask[ + self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size + kwargs[AttentionKwargs.attention_mask] = self._mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths, None)) is not None: + if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: seq_ids = torch.stack( [ torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) @@ -65,27 +65,27 @@ def preprocess(self, 
batch, kwargs: dict[str, typing.Any]) -> None: ] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) - kwargs[TransformerKwargs.attention_mask] = ( - kwargs[TransformerKwargs.attention_mask] + kwargs[AttentionKwargs.attention_mask] = ( + kwargs[AttentionKwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] ) - kwargs[TransformerKwargs.attention_mask_value] = self._mask_value + kwargs[AttentionKwargs.attention_mask_value] = self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( + kwargs[AttentionKwargs.attention_mask] = TensorMeta.from_dims( ( self._scalar_dim, self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, - kwargs[TransformerKwargs.sequence_k_dim], + kwargs[AttentionKwargs.sequence_k_dim], ), - tensor_name=TransformerKwargs.attention_mask, + tensor_name=AttentionKwargs.attention_mask, dtype=torch.bool, ) - kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( + kwargs[AttentionKwargs.attention_mask_value] = TensorMeta.from_dims( (self._scalar_dim,), - tensor_name=TransformerKwargs.attention_mask_value, + tensor_name=AttentionKwargs.attention_mask_value, dtype=self._tensor_space.distributed_config.training_dtype.torch, ) @@ -107,12 +107,12 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: also contain previous tokens from the first document in micro-sequence. We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. """ - if TransformerKwargs.sequence_lengths not in kwargs: + if AttentionKwargs.sequence_lengths not in kwargs: return - sequence_lengths = kwargs[TransformerKwargs.sequence_lengths] - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - if sequence_q < kwargs[TransformerKwargs.sequence_length]: + sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size + if sequence_q < kwargs[AttentionKwargs.sequence_length]: cumsums = [torch.cumsum(x, dim=0) for x in sequence_lengths] # The first and last documents in a microsequence need to be handled separately. Include all tokens from other documents # in the microsequence. We need to consider all keys computed so far from the first sample. 
We also store the offsets @@ -146,17 +146,17 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: else: seqlens_q = torch.cat(sequence_lengths) seqlens_k = torch.cat(sequence_lengths) - kwargs[TransformerKwargs.cu_seqlens_q] = torch.cat( + kwargs[AttentionKwargs.cu_seqlens_q] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.cu_seqlens_k] = torch.cat( + kwargs[AttentionKwargs.cu_seqlens_k] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() - kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() + kwargs[AttentionKwargs.max_seqlen_q] = seqlens_q.max() + kwargs[AttentionKwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py index c357411b6..9f8732f85 100644 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ b/fast_llm/layers/transformer/rotary/preprocessing.py @@ -4,7 +4,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig from fast_llm.tensor import TensorMeta @@ -26,34 +26,34 @@ def __init__( self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] + self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k + self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k ] - kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_q, + tensor_name=AttentionKwargs.rotary_freq_q, ) - kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + 
kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_k, + tensor_name=AttentionKwargs.rotary_freq_k, ) def _create_tensors(self, sequence_length: int) -> None: diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 17b18a1ca..ebb629aa1 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -8,7 +8,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs from fast_llm.layers.transformer.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, @@ -83,44 +83,44 @@ def __init__( self._tensor_space = tensor_space if self._tensor_space is not None: self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] + self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k + self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k ] - kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None - kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_q, + tensor_name=AttentionKwargs.rotary_freq_q, ) - kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_k, + tensor_name=AttentionKwargs.rotary_freq_k, ) def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: rotary_fn = triton_rotary_autograd_ if self._config.triton else apply_rotary_embeddings - query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q]) - key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k]) + query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) + key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key def _create_tensors(self, sequence_length: int) -> None: diff 
--git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 2dbef77f3..6e79388b0 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -199,19 +199,19 @@ def _create_transformer_layer_converters( ( f"{fast_llm_layer_name}.self_attn.query", f"{hf_layer_name}.self_attn.q_proj", - transformer_config.add_attn_qkv_bias, + transformer_config.add_qkv_bias, QueryWeightConverter, ), ( f"{fast_llm_layer_name}.self_attn.key_value", (f"{hf_layer_name}.self_attn.k_proj", f"{hf_layer_name}.self_attn.v_proj"), - transformer_config.add_attn_qkv_bias, + transformer_config.add_qkv_bias, KeyValueWeightConverter, ), ( f"{fast_llm_layer_name}.self_attn.dense", f"{hf_layer_name}.self_attn.o_proj", - transformer_config.add_attn_dense_bias, + transformer_config.add_dense_bias, WeightConverter, ), # Norm diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index cf7da3872..4e3f258fc 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -9,7 +9,7 @@ from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner @@ -86,12 +86,12 @@ def forward( if past_key_values is not None: # The transformers will use the past keys and values to this list. - kwargs[TransformerKwargs.past_key_values] = past_key_values + kwargs[AttentionKwargs.past_key_values] = past_key_values # TODO: preprocess needs to know about the past. raise NotImplementedError() if use_cache: # The transformers will save the present keys and values to this list. 
- kwargs[TransformerKwargs.presents] = [] + kwargs[AttentionKwargs.presents] = [] if output_hidden_states: kwargs["output_hidden_states"] = True @@ -117,11 +117,11 @@ def forward( outputs = (logits,) if use_cache: - outputs += (kwargs[TransformerKwargs.presents],) + outputs += (kwargs[AttentionKwargs.presents],) return outputs return transformers.modeling_outputs.CausalLMOutputWithPast( logits=logits, hidden_states=hidden_states, - past_key_values=kwargs[TransformerKwargs.presents], + past_key_values=kwargs[AttentionKwargs.presents], ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index da647de57..187ca618d 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -16,7 +16,7 @@ from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor from fast_llm.layers.transformer.block import TransformerBlock -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -119,7 +119,7 @@ def preprocess_meta( truncate_documents = True batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) - batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) + batch_dim = TensorDim(AttentionDimNames.batch, micro_batch_size * batch_data.size, batch_data) if micro_sequence_length is None: micro_sequence_length = sequence_length @@ -128,13 +128,13 @@ def preprocess_meta( # TODO: Calculate hidden dims elsewhere? 
sequence_q_dim = TensorDim( - TransformerDimNames.sequence_q, + AttentionDimNames.sequence_q, micro_sequence_length, self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), ) hidden_sequence_q_dim = ( TensorDim( - TransformerDimNames.sequence_q_tp, + AttentionDimNames.sequence_q_tp, micro_sequence_length, self._tensor_space.distributed_config.get_distributed_dim( DistributedDimNames.tensor_and_sequence_data @@ -151,7 +151,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space[TransformerDimNames.hidden] + hidden_dim = self._tensor_space[AttentionDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first @@ -160,10 +160,10 @@ def preprocess_meta( common_kwargs = { LanguageModelKwargs.phase: phase, - TransformerKwargs.sequence_first: sequence_first, - TransformerKwargs.hidden_dims: hidden_dims, - TransformerKwargs.sequence_length: sequence_length, - TransformerKwargs.sequence_q_dim: sequence_q_dim, + AttentionKwargs.sequence_first: sequence_first, + AttentionKwargs.hidden_dims: hidden_dims, + AttentionKwargs.sequence_length: sequence_length, + AttentionKwargs.sequence_q_dim: sequence_q_dim, LanguageModelKwargs.mask_inputs: not truncate_documents, } @@ -182,7 +182,7 @@ def preprocess_meta( preprocessed_meta = [] for i, sequence_k_past in enumerate(sequence_k_pasts): sequence_k = sequence_k_past + sequence_q_dim.size - sequence_k_dim = TensorDim(TransformerDimNames.sequence_k, sequence_k) + sequence_k_dim = TensorDim(AttentionDimNames.sequence_k, sequence_k) tokens = TensorMeta.from_dims( hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 @@ -190,7 +190,7 @@ def preprocess_meta( kwargs = { **common_kwargs, - TransformerKwargs.sequence_k_dim: sequence_k_dim, + AttentionKwargs.sequence_k_dim: sequence_k_dim, } if phase != PhaseType.inference: kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims( @@ -202,10 +202,10 @@ def preprocess_meta( for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] for key in ( - TransformerKwargs.sequence_first, - TransformerKwargs.sequence_length, - TransformerKwargs.sequence_q_dim, - TransformerKwargs.sequence_k_dim, + AttentionKwargs.sequence_first, + AttentionKwargs.sequence_length, + AttentionKwargs.sequence_q_dim, + AttentionKwargs.sequence_k_dim, ): Assert.eq(reference_kwargs_[key], kwargs[key]) reference_kwargs[name] = reference_kwargs_ @@ -231,8 +231,8 @@ def preprocess( preprocessed_meta = self.preprocess_meta(batch.token_ids, phase) _, common_kwargs = preprocessed_meta[0] - sequence_q = common_kwargs[TransformerKwargs.sequence_q_dim].size - sequence_first = common_kwargs[TransformerKwargs.sequence_first] + sequence_q = common_kwargs[AttentionKwargs.sequence_q_dim].size + sequence_first = common_kwargs[AttentionKwargs.sequence_first] prediction_heads: int = self._config.prediction_heads batch.token_ids = batch.token_ids.to( @@ -264,14 +264,14 @@ def preprocess( preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): - sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size + sequence_k = kwargs_meta[AttentionKwargs.sequence_k_dim].size if sequence_first: tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] else: # TODO: Avoid multiple contiguous calls? 
tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: - kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths + kwargs_meta[AttentionKwargs.sequence_lengths] = batch.sequence_lengths if batch.chosen_spans is not None: kwargs_meta[LanguageModelKwargs.chosen_spans] = batch.chosen_spans if batch.rejected_spans is not None: @@ -283,8 +283,8 @@ def preprocess( presents = None if i == len(preprocessed_meta) - 1 else [] kwargs = { **kwargs_meta, - TransformerKwargs.past_key_values: pasts, - TransformerKwargs.presents: presents, + AttentionKwargs.past_key_values: pasts, + AttentionKwargs.presents: presents, } if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 9a878c494..8c33aed4d 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -9,7 +9,7 @@ from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -198,8 +198,8 @@ def test_lm_head( else: loss_mask = None kwargs = { - TransformerKwargs.sequence_first: sequence_first, - TransformerKwargs.grad_output: 1.0, + AttentionKwargs.sequence_first: sequence_first, + AttentionKwargs.grad_output: 1.0, } if config.distillation_model is None: target = torch.randint( diff --git a/tests/test_attention.py b/tests/test_attention.py index dd36b840a..534e3800e 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -6,7 +6,7 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig from fast_llm.layers.transformer.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.utils import Assert @@ -77,13 +77,13 @@ def test_varlen_preprocessor(): varlen_preprocessor = FlashAttnVarlenPreprocessor(transformer_cfg, tensor_space=tensor_space) for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): kwargs = { - TransformerKwargs.sequence_q_dim: TensorDim(TransformerDimNames.sequence_k, micro_sequence_length), - TransformerKwargs.sequence_k_dim: TensorDim( - TransformerDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length + AttentionKwargs.sequence_q_dim: TensorDim(AttentionDimNames.sequence_k, micro_sequence_length), + AttentionKwargs.sequence_k_dim: TensorDim( + AttentionDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length ), - TransformerKwargs.sequence_length: sequence_length, - TransformerKwargs.sequence_lengths: sequence_lengths, + AttentionKwargs.sequence_length: sequence_length, + AttentionKwargs.sequence_lengths: sequence_lengths, } varlen_preprocessor.preprocess(None, kwargs) - Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_q], 
cumulative_sequences_q[micro_seq_idx]) - Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) + Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) + Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 694faa55b..6c4c7f0cb 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -9,7 +9,7 @@ from fast_llm.engine.schedule.config import ScheduleConfig from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat from fast_llm.models.ssm.model import HybridSSMModel @@ -71,8 +71,8 @@ def test_load_from_llamba_checkpoint(): schedule_runner.setup(model.distributed, optimizer=None) common_kwargs = { - TransformerKwargs.sequence_first: True, - TransformerKwargs.grad_output: False, + AttentionKwargs.sequence_first: True, + AttentionKwargs.grad_output: False, } input_data = [(x, common_kwargs)] From 82c9dbd2d4270ea0bbe30afbe79520be4ebc7e68 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 16:29:55 -0400 Subject: [PATCH 37/82] misc --- fast_llm/layers/block/block.py | 104 ++++++++++++++++++++++- fast_llm/layers/block/mixer.py | 68 --------------- fast_llm/layers/language_model/head.py | 88 ++++++++++--------- fast_llm/layers/ssm/block.py | 7 +- fast_llm/layers/ssm/discrete_mamba2.py | 12 ++- fast_llm/layers/ssm/mamba2.py | 36 +++++--- fast_llm/layers/ssm/mamba_layer.py | 12 ++- fast_llm/layers/transformer/attention.py | 51 ++++++----- fast_llm/layers/transformer/block.py | 5 +- 9 files changed, 219 insertions(+), 164 deletions(-) delete mode 100644 fast_llm/layers/block/mixer.py diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 85da61c01..87a8f81cf 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -1,4 +1,5 @@ import abc +import functools import typing import torch @@ -9,13 +10,112 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs -from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.block.mlp.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta +class DebugLayer: + # TODO: Move elsewhere? 
+ def __init__(self, tensor_space: TensorSpace, name: str, debug_level: int = 0, debug_memory: bool = False): + self._tensor_space = tensor_space + self._name = name + self._debug_level = debug_level + self._debug_memory = debug_memory + + def _get_meta( + self, tensor: torch.Tensor, name: str, dims: tuple[TensorDim | str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + ( + dim + if isinstance(dim, TensorDim) + else hidden_dims[dim] if dim in hidden_dims else self._tensor_space[dim] + ) + for dim in dims + ), + tensor_name=f"{self._name} {name}", + dtype=tensor.dtype, + ) + + @functools.cached_property + def enabled(self) -> bool: + return self._debug_level > 0 or self._debug_memory + + def __call__( + self, + tensor: torch.Tensor, + name: str, + dims: tuple[TensorDim | str, ...], + kwargs: dict[str, typing.Any], + scale: float = 1.0, + global_: bool = True, + log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info, + ) -> None: + # TODO: Local vs global? + if self._debug_memory: + log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) + if self._debug_level > 0: + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dims, kwargs), + distributed=self._tensor_space.distributed, + global_=global_, + log_fn=log_fn, + scale=scale, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dims, kwargs), + distributed=self._tensor_space.distributed, + global_=global_, + log_fn=log_fn, + scale=scale, + ) + + +class BlockLayer(torch.nn.Module, abc.ABC): + """ + Base class for mixer and MLP modules. + """ + + def __init__(self, tensor_space: TensorSpace, block_index: int, name: str, debug_level: int, debug_memory: bool): + super().__init__() + self._tensor_space = tensor_space + self._block_index = block_index + self._name = name + self._sequence_parallel: bool = self._tensor_space.distributed_config.sequence_tensor_parallel + self._debug = DebugLayer( + tensor_space, + f"Block {self._block_index} {self._name}", + debug_level, + debug_memory, + ) + + @abc.abstractmethod + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + pass + + def _debug_log(self, tensor: torch.Tensor) -> None: + pass + + class Block[ConfigType: BlockConfig](Layer, Configurable[ConfigType]): """ A transformer-like decoder base block with abstract mixer. 
@@ -52,7 +152,7 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def _create_mixer(self) -> Mixer: + def _create_mixer(self) -> BlockLayer: pass @torch.compile diff --git a/fast_llm/layers/block/mixer.py b/fast_llm/layers/block/mixer.py deleted file mode 100644 index 5c811e330..000000000 --- a/fast_llm/layers/block/mixer.py +++ /dev/null @@ -1,68 +0,0 @@ -import abc -import typing - -import torch - -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.block.config import BlockKwargs -from fast_llm.logging import log_distributed_grad, log_distributed_tensor -from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert - - -class Mixer(torch.nn.Module, abc.ABC): - """ - Base class for mixer modules. - """ - - _mixer_name: typing.ClassVar[str] - - def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): - super().__init__() - self._tensor_space = tensor_space - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._block_index = block_index - self._debug_level = debug_level - - @abc.abstractmethod - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Mixer module forward. Returns the output hidden states and an optional bias, - in case its addition can be made more efficient in `_bias_dropout_add`. - """ - - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = { - dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) - } - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] - for dim_name in dim_names - ), - tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_level, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index bc672725c..0623ac201 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -15,6 +15,7 @@ from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward +from fast_llm.layers.block.block import DebugLayer from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.language_model.config import ( LanguageModelBaseConfig, @@ -46,47 +47,66 @@ def __init__( prediction_distance: int, ): super().__init__(config) - self._debug_transformer = config.transformer.debug_transformer - self._tie_word_embeddings = config.tie_word_embeddings + self._debug = DebugLayer( + 
tensor_space, + f"Language model head", + self._config.transformer.debug_transformer, + self._config.transformer.debug_transformer_memory, + ) self._tensor_space = tensor_space self._group_size = tensor_space.distributed_config.tensor_parallel self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel - self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_embeddings = ( + tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings + ) self._sequence_parallel_logits = ( - tensor_space.distributed_config.sequence_tensor_parallel and not config.parallel_embeddings + tensor_space.distributed_config.sequence_tensor_parallel and not self._config.parallel_embeddings ) - self._cross_entropy_splits = config.cross_entropy_splits + self._cross_entropy_splits = self._config.cross_entropy_splits if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings hidden_dim = self._tensor_space[LanguageModelDimNames.hidden] self._loss_coefficient = ( - config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 + self._config.prediction_loss_coefficient[prediction_distance] + if self._config.prediction_loss_coefficient + else 1.0 ) self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) - self.final_norm = config.transformer.normalization.get_layer(hidden_dim) - self._logits_scale_factor = config.logits_scale_factor - self._language_model_loss_factor = config.language_model_loss_factor - self._distillation_loss_factor = config.distillation_loss_factor - self._z_loss_factor = config.logit_z_loss + self.final_norm = self._config.transformer.normalization.get_layer(hidden_dim) + self._logits_scale_factor = self._config.logits_scale_factor + self._language_model_loss_factor = self._config.language_model_loss_factor + self._distillation_loss_factor = self._config.distillation_loss_factor + self._z_loss_factor = self._config.logit_z_loss # Distance of the target token prediction # 0: next-token prediction # >0: multi-token prediction (MTP) Assert.geq(prediction_distance, 0) self._prediction_distance = prediction_distance - self._is_last_head = self._prediction_distance == config.prediction_heads - 1 + self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1 - self._init_output_weights(hidden_dim, config) + # Only the first head defines the output weights + if self._prediction_distance == 0 and not self._config.tie_word_embeddings: + # untie embedding weights + vocab_dim = self._tensor_space[ + LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab + ] + self.output_weights = ParameterMeta.from_dims( + (vocab_dim, hidden_dim), + init_method=init_normal_( + std=self._config.init_method_std_embed, + min_val=self._config.init_method_min_embed, + max_val=self._config.init_method_max_embed, + ), + lr_scale=self._config.output_lr_scale, + ) - self._use_dpo_loss = config.enable_dpo - if self._use_dpo_loss: - self.dpo_beta = config.dpo_beta - else: - self._cross_entropy_impl = config.cross_entropy_impl - self._distillation_loss_implementation = config.distillation_loss_implementation + self._use_dpo_loss = self._config.enable_dpo + if not self._use_dpo_loss: + self._cross_entropy_impl = self._config.cross_entropy_impl if self._cross_entropy_impl == CrossEntropyImpl.auto: if self._parallel_embeddings: 
self._cross_entropy_impl = CrossEntropyImpl.fused @@ -102,24 +122,6 @@ def __init__( if hasattr(self, "output_weights"): self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights) - def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: - # Only the first head defines the output weights - if self._tie_word_embeddings or self._prediction_distance > 0: - return - # untie embedding weights - vocab_dim = self._tensor_space[ - LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ] - self.output_weights = ParameterMeta.from_dims( - (vocab_dim, hidden_dim), - init_method=init_normal_( - std=config.init_method_std_embed, - min_val=config.init_method_min_embed, - max_val=config.init_method_max_embed, - ), - lr_scale=config.output_lr_scale, - ) - def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: @@ -251,7 +253,7 @@ def _get_targets( return targets def _get_output_weights(self, kwargs: dict) -> torch.Tensor: - if self._tie_word_embeddings: + if self._config.tie_word_embeddings: return kwargs[WORD_EMBEDDINGS_WEIGHT] if self._prediction_distance > 0: return kwargs[OUTPUT_WEIGHTS] @@ -379,7 +381,7 @@ def _logits_cross_entropy_forward_backward( kwargs.get(f"{self._config.dpo_reference_model}_logits"), kwargs[LanguageModelKwargs.chosen_spans], kwargs[LanguageModelKwargs.rejected_spans], - self.dpo_beta, + self._config.dpo_beta, grad_output * self._loss_coefficient, ) else: @@ -401,7 +403,7 @@ def _logits_cross_entropy_forward_backward( lm_loss, lm_grad = None, None if distillation_target is not None and self._distillation_loss_factor > 0.0: - if self._distillation_loss_implementation == DistillationLossImpl.reverse_kl: + if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, @@ -414,7 +416,7 @@ def _logits_cross_entropy_forward_backward( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits ), ) - elif self._distillation_loss_implementation == DistillationLossImpl.cross_entropy: + elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( logits.flatten(0, -2), distillation_target, @@ -426,7 +428,9 @@ def _logits_cross_entropy_forward_backward( target_format=TargetFormat.logits, ) else: - raise ValueError(f"Invalid distillation loss implementation: {self._distillation_loss_implementation}") + raise ValueError( + f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" + ) distillation_loss = distillation_loss * self._distillation_loss_factor else: distillation_loss, distillation_grad = None, None diff --git a/fast_llm/layers/ssm/block.py b/fast_llm/layers/ssm/block.py index 0bfa266ac..987d5fa0d 100644 --- a/fast_llm/layers/ssm/block.py +++ b/fast_llm/layers/ssm/block.py @@ -1,7 +1,6 @@ from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.block.block import Block +from fast_llm.layers.block.block import Block, BlockLayer from fast_llm.layers.block.config import BlockConfig -from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.ssm.config import SSMConfig @@ -18,7 +17,7 @@ def __init__( config: BlockConfig, ssm_config: SSMConfig, tensor_space: TensorSpace, - mixer_cls: type[Mixer], + 
mixer_cls: type[BlockLayer], block_index: int, return_input: bool = False, ): @@ -26,7 +25,7 @@ def __init__( self._mixer_cls = mixer_cls super().__init__(config, tensor_space, block_index, return_input) - def _create_mixer(self) -> Mixer: + def _create_mixer(self) -> BlockLayer: return self._mixer_cls( self._ssm_config, tensor_space=self._tensor_space, diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 550c44d0f..e48636926 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -6,8 +6,8 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ @@ -32,7 +32,7 @@ _causal_conv1d_available = False -class DiscreteMamba2(Mixer): +class DiscreteMamba2(BlockLayer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" @@ -44,7 +44,13 @@ def __init__( tensor_space: TensorSpace, block_config: BlockConfig, ): - super().__init__(tensor_space, block_index, debug_level=block_config.debug_transformer) + super().__init__( + tensor_space, + block_index, + self._mixer_name, + debug_level=block_config.debug_transformer, + debug_memory=block_config.debug_transformer_memory, + ) self._config: SSMConfig = config layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 1c319f490..4357c0e86 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -5,8 +5,8 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) -class Mamba2(Mixer): +class Mamba2(BlockLayer): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ @@ -56,7 +56,13 @@ def __init__( block_index: int, block_config: BlockConfig, ): - super().__init__(tensor_space, block_index, debug_level=block_config.debug_transformer) + super().__init__( + tensor_space, + block_index, + self._mixer_name, + debug_level=block_config.debug_transformer, + debug_memory=block_config.debug_transformer_memory, + ) self._config: SSMConfig = config Assert.eq(self._config.activation_type, ActivationType.silu) layer_lr_scale: float | None = ( @@ -144,7 +150,13 @@ def __init__( # TODO: lr_scale? 
) - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available assert _causal_conv1d_available @@ -198,12 +210,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # dt: (batch, sequence, heads * state) -> (batch, heads * state, sequence) dt = dt.transpose(1, 2) - if self._debug_level: - self._debug_log(z, "z", self._XZ_DIMS, kwargs) - self._debug_log(x, "x", self._XZ_DIMS, kwargs) - self._debug_log(b, "b", self._BC_DIMS, kwargs) - self._debug_log(c, "c", self._BC_DIMS, kwargs) - self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) + if self._debug.enabled: + self._debug(z, "z", self._XZ_DIMS, kwargs) + self._debug(x, "x", self._XZ_DIMS, kwargs) + self._debug(b, "b", self._BC_DIMS, kwargs) + self._debug(c, "c", self._BC_DIMS, kwargs) + self._debug(dt, "dt", self._XZ_DIMS, kwargs) y = selective_scan_fn( x, @@ -217,8 +229,8 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ delta_softplus=True, ) - if self._debug_level: - self._debug_log(y, "y", self._XZ_DIMS, kwargs) + if self._debug.enabled: + self._debug(y, "y", self._XZ_DIMS, kwargs) # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) y = y.transpose(1, 2)[:, :sequence_length] diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index f5b0139cf..590edf18c 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -6,8 +6,8 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ @@ -52,7 +52,7 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return LambdaInitializer(init_) -class MambaLayer(Mixer): +class MambaLayer(BlockLayer): _mixer_name: typing.ClassVar[str] = "mamba" def __init__( @@ -62,7 +62,13 @@ def __init__( tensor_space: TensorSpace, block_config: BlockConfig, ): - super().__init__(tensor_space, block_index, debug_level=block_config.debug_transformer) + super().__init__( + tensor_space, + block_index, + self._mixer_name, + debug_level=block_config.debug_transformer, + debug_memory=block_config.debug_transformer_memory, + ) assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" self._config = config # TODO: It's not silu? 
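The initializer convention visible in the `mamba_layer.py` context above (an `init_` callback taking the parameter meta, the local tensor and a generator, wrapped in `LambdaInitializer`) also supports user-defined schemes. A minimal sketch, written against the pre-move location `fast_llm.tensor` used at this point in the series (a later patch in this series relocates these helpers to `fast_llm.engine.config_utils.initialization`); the name `init_truncated_normal_` is illustrative and not part of the codebase:

```python
import torch

from fast_llm.tensor import LambdaInitializer, ParameterMeta


def init_truncated_normal_(std: float, limit: float) -> LambdaInitializer:
    # Fill the local tensor in place with a normal draw, then clamp to +/- limit.
    def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None:  # noqa
        tensor.normal_(0.0, std, generator=generator).clamp_(min=-limit, max=limit)

    return LambdaInitializer(init_)
```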
diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index e84e92a96..6598d3a29 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -6,7 +6,7 @@ from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.layers.block.mixer import Mixer +from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig @@ -46,7 +46,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(Mixer): +class Attention(BlockLayer): """ A self-attention layer. """ @@ -72,7 +72,13 @@ class Attention(Mixer): ) def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): - super().__init__(tensor_space, block_index, config.debug_transformer) + super().__init__( + tensor_space, + block_index, + self._mixer_name, + debug_level=config.debug_transformer, + debug_memory=config.debug_transformer_memory, + ) self._config = config self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) @@ -259,7 +265,13 @@ def _decide_window_size(self) -> int | None: return window_size - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: sequence_first = kwargs[AttentionKwargs.sequence_first] query, key_value = self._query_key_value(input_, sequence_first) @@ -295,14 +307,9 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug_level: - self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) - self._debug_log( - key, - "key_rotary_input", - self._KV_DIMS, - kwargs, - ) + if self._debug.enabled: + self._debug(query, "query_rotary_input", self._QUERY_DIMS, kwargs) + self._debug(key, "key_rotary_input", self._KV_DIMS, kwargs) query, key = self._rotary(query, key, kwargs) window_size = self._decide_window_size() @@ -349,21 +356,11 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ kwargs[AttentionKwargs.attention_mask_value], ) - if self._debug_level: - self._debug_log(query, "query", self._QUERY_DIMS, kwargs) - self._debug_log( - key, - "key", - self._KV_DIMS, - kwargs, - ) - self._debug_log( - value, - "value", - self._KV_DIMS, - kwargs, - ) - self._debug_log(input_, "context", self._CONTEXT_DIMS, kwargs) + if self._debug.enabled: + self._debug(query, "query", self._QUERY_DIMS, kwargs) + self._debug(key, "key", self._KV_DIMS, kwargs) + self._debug(value, "value", self._KV_DIMS, kwargs) + self._debug(input_, "context", self._CONTEXT_DIMS, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) 
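With `Mixer` folded into `BlockLayer`, attention and the Mamba variants now share the same constructor arguments and the extended `forward` signature with optional `losses`/`metrics`. A minimal sketch of a pass-through mixer against that interface, assuming only the module paths shown in these diffs; `IdentityMixer` is illustrative and is not wired into a block or tensor space here:

```python
import typing

import torch

from fast_llm.engine.config_utils.tensor_space import TensorSpace
from fast_llm.layers.block.block import BlockLayer
from fast_llm.layers.block.config import BlockKwargs


class IdentityMixer(BlockLayer):
    """Pass-through mixer: returns the hidden states unchanged, with no bias."""

    _mixer_name: typing.ClassVar[str] = "identity"

    def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0, debug_memory: bool = False):
        super().__init__(tensor_space, block_index, self._mixer_name, debug_level, debug_memory)

    def forward(
        self,
        input_: torch.Tensor,
        kwargs: dict[str, typing.Any],
        losses: dict[str, typing.Any] | None = None,
        metrics: dict[str, typing.Any] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        if self._debug.enabled:
            # Same debug hook the converted mixers use.
            self._debug(input_, "identity output", kwargs[BlockKwargs.hidden_dims], kwargs)
        return input_, None
```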
diff --git a/fast_llm/layers/transformer/block.py b/fast_llm/layers/transformer/block.py index 4a0e818f0..89d7a2e3b 100644 --- a/fast_llm/layers/transformer/block.py +++ b/fast_llm/layers/transformer/block.py @@ -2,8 +2,7 @@ import typing from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.block.block import Block -from fast_llm.layers.block.mixer import Mixer +from fast_llm.layers.block.block import Block, BlockLayer from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig @@ -19,5 +18,5 @@ class TransformerBlock[ConfigType: TransformerConfig](Block[ConfigType]): def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): super().__init__(config, tensor_space, block_index, return_input) - def _create_mixer(self) -> Mixer: + def _create_mixer(self) -> BlockLayer: return Attention(self._config, self._tensor_space, self._block_index) From 9fbb9ff52081c7444a2a54547319ddb8ec05ad01 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 17:18:16 -0400 Subject: [PATCH 38/82] misc --- docs/developer_guide/conversion.md | 30 ++--- fast_llm/layers/block/block.py | 106 +++++++++--------- .../layers/block/mlp/mixture_of_experts.py | 9 +- fast_llm/layers/block/mlp/mlp.py | 19 ++-- tests/test_mlp.py | 4 +- 5 files changed, 88 insertions(+), 80 deletions(-) diff --git a/docs/developer_guide/conversion.md b/docs/developer_guide/conversion.md index 0620beaea..35a324db0 100644 --- a/docs/developer_guide/conversion.md +++ b/docs/developer_guide/conversion.md @@ -230,21 +230,21 @@ Continuing our `AwesomeModel` handler example, we define: ```python def _create_weight_converters(self) -> list[WeightConverter]: - converters = [] - # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. - num_layers = self._model.config.base_model.transformer.num_layers - - # A simple renaming example, for the word embeddings. - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - - # We usually want to loop dynamically over layers - for i in range(num_layers): - # A `SplitWeightConverter` example, splitting a weight in two. - converters.append(SplitWeightConverter( - f"layers.{i + 1}.weight", - (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), - )) - return converters + converters = [] + # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. + num_layers = self._model.config.base_model.transformer.num_layers + + # A simple renaming example, for the word embeddings. + converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + + # We usually want to loop dynamically over layers + for i in range(num_layers): + # A `SplitWeightConverter` example, splitting a weight in two. + converters.append(SplitWeightConverter( + f"layers.{i + 1}.weight", + (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), + )) + return converters ``` And that's it! We're ready to use the new checkpoint format in Fast-LLM. 
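The guide above describes `SplitWeightConverter` only by its effect, so a self-contained toy version of the split-on-export / concatenate-on-import idea may help; the class and method names below (`export_weight`, `import_weight`) are assumptions that only loosely mirror the real converter interface, not the actual Fast-LLM base class:

```python
import torch


class ToySplitConverter:
    """Toy illustration: one Fast-LLM weight maps to several exported weights of equal size."""

    def __init__(self, fast_llm_name: str, export_names: tuple[str, ...]):
        self.fast_llm_name = fast_llm_name
        self.export_names = export_names

    def export_weight(self, weight: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # Fast-LLM -> external format: split along the first dimension.
        return tuple(weight.chunk(len(self.export_names), dim=0))

    def import_weight(self, weights: tuple[torch.Tensor, ...]) -> torch.Tensor:
        # External format -> Fast-LLM: concatenate the pieces back.
        return torch.cat(weights, dim=0)


converter = ToySplitConverter("layers.1.weight", ("model.layers.0.weight_1", "model.layers.0.weight_2"))
halves = converter.export_weight(torch.arange(8.0).reshape(4, 2))
assert torch.equal(converter.import_weight(halves), torch.arange(8.0).reshape(4, 2))
```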
diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 87a8f81cf..84fb5f2d4 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -1,5 +1,6 @@ import abc import functools +import logging import typing import torch @@ -15,6 +16,8 @@ from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta +logger = logging.getLogger(__name__) + class DebugLayer: # TODO: Move elsewhere? @@ -47,9 +50,11 @@ def _get_meta( def enabled(self) -> bool: return self._debug_level > 0 or self._debug_memory - def __call__( + def __call__[ + T + ]( self, - tensor: torch.Tensor, + tensor: torch.Tensor | None, name: str, dims: tuple[TensorDim | str, ...], kwargs: dict[str, typing.Any], @@ -60,7 +65,7 @@ def __call__( # TODO: Local vs global? if self._debug_memory: log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) - if self._debug_level > 0: + if self._debug_level > 0 and tensor is not None: log_distributed_tensor( "", tensor, @@ -112,11 +117,8 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor | None]: pass - def _debug_log(self, tensor: torch.Tensor) -> None: - pass - -class Block[ConfigType: BlockConfig](Layer, Configurable[ConfigType]): +class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): """ A transformer-like decoder base block with abstract mixer. """ @@ -125,10 +127,15 @@ class Block[ConfigType: BlockConfig](Layer, Configurable[ConfigType]): _mixer_module_name: typing.ClassVar[str] = "mixer" def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): - super().__init__() - self._config = config + super().__init__(config) self._tensor_space: TensorSpace = tensor_space self._dropout_p: float = self._config.hidden_dropout + self._debug = DebugLayer( + tensor_space, + f"Block {self._block_index} {self._name}", + self._config.debug_transformer, + self._config.debug_transformer_memory, + ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input @@ -144,7 +151,9 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index + self._config, + self._tensor_space, + self._block_index, ) # PEFT. 
@@ -163,35 +172,9 @@ def _bias_dropout_add( input_ = input_ + bias return residual + torch.dropout(input_, self._dropout_p, self.training) - @property - def name(self) -> str: - return f"{self._name} {self._block_index}" - - def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): - dims = kwargs[BlockKwargs.hidden_dims] - if self._return_input: - dims = (TensorDim("stacked_input_output", 2),) + dims - return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) - - def _debug_log(self, tensor: torch.Tensor | None, name: str, kwargs: dict[str, typing.Any], *, bias=None) -> None: - if self._config.debug_transformer_memory: - log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self.name} {name}", str)) - if self._config.debug_transformer and tensor is not None: - # TODO: Local vs global - log_distributed_tensor( - "", - tensor if bias is None else tensor + bias, - level=self._config.debug_transformer, - meta=self._get_meta(tensor, name, kwargs), - distributed=self._tensor_space.distributed, - ) - log_distributed_grad( - "", - tensor, - level=self._config.debug_transformer, - meta=self._get_meta(tensor, name + " grad", kwargs), - distributed=self._tensor_space.distributed, - ) + # @property + # def name(self) -> str: + # return f"{self._name} {self._block_index}" def forward( self, @@ -201,35 +184,50 @@ def forward( metrics: dict[str, typing.Any] | None = None, ) -> torch.Tensor: if isinstance(input_, TensorMeta): - return self._get_meta(input_, "output", kwargs) + dims = kwargs[BlockKwargs.hidden_dims] + if self._return_input: + dims = (TensorDim("stacked_input_output", 2),) + dims + return TensorMeta.from_dims( + dims, tensor_name=f"{self._name} {self._block_index} output", dtype=input_.dtype + ) generator = ( self._tensor_space.distributed.tp_generator if self._tensor_space.distributed_config.sequence_tensor_parallel else self._tensor_space.distributed.pp_generator ) - if self._debug_mode: - self._debug_log(None, "Begin", kwargs) + if self._debug.enabled: + self._debug(None, "Begin", kwargs[BlockKwargs.hidden_dims], kwargs) fw_input = input_ hidden_states = self.norm_1(input_) - if self._debug_mode: - self._debug_log(hidden_states, "Norm 1", kwargs) + if self._debug.enabled: + self._debug(hidden_states, "Norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) - if self._debug_mode: - self._debug_log(hidden_states, f"{self._mixer_module_name} output", kwargs, bias=bias) + if self._debug.enabled: + self._debug( + hidden_states if bias is None else hidden_states + bias, + f"{self._mixer_module_name} output", + kwargs[BlockKwargs.hidden_dims], + kwargs, + ) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) - if self._debug_mode: - self._debug_log(input_, f"{self._mixer_module_name} residual", kwargs) + if self._debug.enabled: + self._debug(input_, f"{self._mixer_module_name} residual", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states = self.norm_2(input_) - if self._debug_mode: - self._debug_log(hidden_states, "Norm 2", kwargs) + if self._debug.enabled: + self._debug(hidden_states, "Norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) - if self._debug_mode: - self._debug_log(hidden_states, "MLP output", kwargs, bias=bias) + if self._debug.enabled: + self._debug( + hidden_states if bias is None else hidden_states + bias, + "MLP output", + 
kwargs[BlockKwargs.hidden_dims], + kwargs, + ) with set_generator(generator): hidden_states = self._bias_dropout_add(hidden_states, bias, input_) - if self._debug_mode: - self._debug_log(None, "MLP residual", kwargs, bias=bias) + if self._debug.enabled: + self._debug(None, "MLP residual", kwargs[BlockKwargs.hidden_dims], kwargs) if self._return_input: hidden_states = torch.stack((fw_input, hidden_states), dim=0) return hidden_states diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 8d092b6dc..88d7ecf62 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -35,11 +35,16 @@ class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): _group: ProcessGroup - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, tensor_space, name, block_index) + super().__init__( + config, + tensor_space, + block_index, + name, + ) self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 19349671e..a0980c39e 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -2,11 +2,10 @@ import torch -from fast_llm.config import Configurable -from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd +from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockDimNames from fast_llm.layers.block.mlp.config import MLPDimNames from fast_llm.layers.block.peft import TransformerSubLayerName @@ -15,9 +14,15 @@ from fast_llm.utils import Assert, get_lr_scale -class MLPBase[ConfigType: BlockConfig](Configurable[ConfigType], Layer): - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): - super().__init__(config) +class MLPBase(BlockLayer): + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): + super().__init__( + tensor_space, + block_index, + name, + debug_level=config.debug_transformer, + debug_memory=config.debug_transformer_memory, + ) self._name = name self._block_index = block_index @@ -67,9 +72,9 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = " class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, block_index) + super().__init__(config, tensor_space, block_index, name) def forward( self, diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 5875822ff..802833eb2 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -15,7 +15,7 @@ def test_mlp_constructor(): 
tensor_space = TensorSpace(distributed_config=distributed_config) transformer_conf.setup_tensor_space(tensor_space) - MLP(transformer_conf, tensor_space, "name") + MLP(transformer_conf, tensor_space, 0, "name") def test_moe_mlp_constructor(): @@ -26,4 +26,4 @@ def test_moe_mlp_constructor(): tensor_space = TensorSpace(distributed_config=distributed_config) transformer_conf.setup_tensor_space(tensor_space) - MixtureOfExpertMLP(transformer_conf, tensor_space, "name") + MixtureOfExpertMLP(transformer_conf, tensor_space, 0, "name") From 44df195a207957254fb9bd50354c70cebe63766e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 18:08:29 -0400 Subject: [PATCH 39/82] misc --- .../engine/config_utils/initialization.py | 57 +++++++++++++++ fast_llm/layers/block/block.py | 28 ++++---- .../layers/block/mlp/mixture_of_experts.py | 3 +- fast_llm/layers/block/mlp/mlp.py | 2 +- fast_llm/layers/common/config.py | 2 +- fast_llm/layers/common/linear.py | 3 +- fast_llm/layers/common/normalization.py | 3 +- fast_llm/layers/language_model/embedding.py | 3 +- fast_llm/layers/language_model/head.py | 3 +- fast_llm/layers/ssm/config.py | 4 +- fast_llm/layers/ssm/discrete_mamba2.py | 3 +- fast_llm/layers/ssm/mamba2.py | 3 +- fast_llm/layers/ssm/mamba_layer.py | 3 +- fast_llm/layers/transformer/attention.py | 2 +- fast_llm/tensor.py | 70 +------------------ 15 files changed, 91 insertions(+), 98 deletions(-) create mode 100644 fast_llm/engine/config_utils/initialization.py diff --git a/fast_llm/engine/config_utils/initialization.py b/fast_llm/engine/config_utils/initialization.py new file mode 100644 index 000000000..b60070562 --- /dev/null +++ b/fast_llm/engine/config_utils/initialization.py @@ -0,0 +1,57 @@ +import abc +import typing + +if typing.TYPE_CHECKING: + import torch + + from fast_llm.tensor import ParameterMeta + + +class Initializer(abc.ABC): + @abc.abstractmethod + def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: + pass + + requires_global_initialization = False + + +class LambdaInitializer(Initializer): + def __init__( + self, + init_method: typing.Callable[["ParameterMeta", "torch.Tensor", "torch.Generator"], None], + requires_global_initialization: bool = False, + ) -> None: + self._init_method = init_method + self.requires_global_initialization = requires_global_initialization + + def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: + return self._init_method(meta, tensor, generator) + + +def init_fill_(value: float) -> LambdaInitializer: + def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa + tensor.fill_(value) + + return LambdaInitializer(init_) + + +init_zeros_ = init_fill_(0.0) +init_ones_ = init_fill_(1.0) + + +def init_normal_( + mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa + tensor = tensor.normal_(mean, std, generator=generator) + if min_val is not None or max_val is not None: + tensor.clamp_(min=min_val, max=max_val) + + return LambdaInitializer(init_) + + +def init_uniform_centered_(scale: float, mean: float = 0.0) -> LambdaInitializer: + def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa + tensor.uniform_(mean - scale, mean + scale, generator=generator) + + return 
LambdaInitializer(init_) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 84fb5f2d4..292d2c9a4 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -102,7 +102,7 @@ def __init__(self, tensor_space: TensorSpace, block_index: int, name: str, debug self._sequence_parallel: bool = self._tensor_space.distributed_config.sequence_tensor_parallel self._debug = DebugLayer( tensor_space, - f"Block {self._block_index} {self._name}", + self._name, debug_level, debug_memory, ) @@ -128,19 +128,19 @@ class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): super().__init__(config) + # TODO: Argument? + self._name = f"Block {self._block_index}" self._tensor_space: TensorSpace = tensor_space self._dropout_p: float = self._config.hidden_dropout + # For multi-token prediction, return a stack of shared_hidden and transformer_output. + self._return_input: bool = return_input + self._block_index = block_index self._debug = DebugLayer( tensor_space, - f"Block {self._block_index} {self._name}", + self._name, self._config.debug_transformer, self._config.debug_transformer_memory, ) - # For multi-token prediction, return a stack of shared_hidden and transformer_output. - self._return_input: bool = return_input - - self._block_index = block_index - self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space[BlockDimNames.hidden] # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale @@ -187,35 +187,33 @@ def forward( dims = kwargs[BlockKwargs.hidden_dims] if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims - return TensorMeta.from_dims( - dims, tensor_name=f"{self._name} {self._block_index} output", dtype=input_.dtype - ) + return TensorMeta.from_dims(dims, tensor_name=f"{self._name} output", dtype=input_.dtype) generator = ( self._tensor_space.distributed.tp_generator if self._tensor_space.distributed_config.sequence_tensor_parallel else self._tensor_space.distributed.pp_generator ) if self._debug.enabled: - self._debug(None, "Begin", kwargs[BlockKwargs.hidden_dims], kwargs) + self._debug(None, "begin", kwargs[BlockKwargs.hidden_dims], kwargs) fw_input = input_ hidden_states = self.norm_1(input_) if self._debug.enabled: - self._debug(hidden_states, "Norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) + self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) if self._debug.enabled: self._debug( hidden_states if bias is None else hidden_states + bias, - f"{self._mixer_module_name} output", + "mixer output", kwargs[BlockKwargs.hidden_dims], kwargs, ) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) if self._debug.enabled: - self._debug(input_, f"{self._mixer_module_name} residual", kwargs[BlockKwargs.hidden_dims], kwargs) + self._debug(input_, "mixer residual", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states = self.norm_2(input_) if self._debug.enabled: - self._debug(hidden_states, "Norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) + self._debug(hidden_states, "norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) if self._debug.enabled: self._debug( diff --git 
a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 88d7ecf62..46005234c 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -5,6 +5,7 @@ import torch from fast_llm.core.distributed import ProcessGroup, set_generator +from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped @@ -15,7 +16,7 @@ from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage -from fast_llm.tensor import TensorMeta, init_normal_ +from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, get_lr_scale logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index a0980c39e..7d4643673 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -2,6 +2,7 @@ import torch +from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd @@ -10,7 +11,6 @@ from fast_llm.layers.block.mlp.config import MLPDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase -from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import Assert, get_lr_scale diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 9d5ce3f3b..2f45fdf9f 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -87,7 +87,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig): ) def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.tensor import init_uniform_centered_ + from fast_llm.engine.config_utils.initialization import init_uniform_centered_ kwargs = { "hidden_dim": hidden_dim, diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index 7249ef569..740b4847c 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -3,6 +3,7 @@ import torch +from fast_llm.engine.config_utils.initialization import init_zeros_ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.linear import ( @@ -14,7 +15,7 @@ output_parallel_linear_backward, output_parallel_linear_forward, ) -from fast_llm.tensor import ParameterMeta, init_zeros_ +from fast_llm.tensor import ParameterMeta logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index bccc1d627..d44be3297 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -1,11 +1,12 @@ import torch +from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim from 
fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.config import NormalizationImplementation -from fast_llm.tensor import ParameterMeta, accumulate_gradient, init_ones_, init_zeros_ +from fast_llm.tensor import ParameterMeta, accumulate_gradient from fast_llm.utils import Assert try: diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 05678a700..68aa4882b 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -6,9 +6,10 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs -from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 0623ac201..63d1a6b27 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -8,6 +8,7 @@ from fast_llm.config import Configurable from fast_llm.core.ops import split_op from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward @@ -25,7 +26,7 @@ ) from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.logging import log_distributed_tensor -from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert, div, get_unique logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index efcf2d873..00c709814 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -9,7 +9,7 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: - from fast_llm.tensor import Initializer + from fast_llm.engine.config_utils.initialization import Initializer, init_fill_, init_uniform_centered_ class SSMDimNames(BlockDimNames): @@ -66,8 +66,6 @@ class DTInitType(enum.StrEnum): random = "random" def get_init_method(self, scale: float) -> "Initializer": - from fast_llm.tensor import init_fill_, init_uniform_centered_ - return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index e48636926..e967ab9d1 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -4,13 +4,14 @@ import einops import torch +from fast_llm.engine.config_utils.initialization import init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType 
from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ +from fast_llm.tensor import ParameterMeta from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 4357c0e86..5d62c144f 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,6 +3,7 @@ import torch +from fast_llm.engine.config_utils.initialization import init_kaiming_, init_ones_, init_uniform_centered_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer @@ -10,7 +11,7 @@ from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ +from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert, div, get_lr_scale try: diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 590edf18c..0f3224f77 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -4,13 +4,14 @@ import torch +from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_kaiming_, init_ones_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ +from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert, get_lr_scale try: diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 6598d3a29..ba7f2bb6e 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -4,13 +4,13 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim +from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig -from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import get_lr_scale try: diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index d080e6a1e..b12d12072 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,13 +1,12 @@ -import abc import functools import logging -import math import typing import torch 
from fast_llm.core.distributed import ReduceOp from fast_llm.core.ops import reduce_op +from fast_llm.engine.config_utils.initialization import Initializer, LambdaInitializer from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed @@ -361,70 +360,3 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None: triton_copy(grad, param.grad_buffer) # noqa else: triton_add(grad, param.grad_buffer, out=param.grad_buffer) # noqa - - -class Initializer(abc.ABC): - @abc.abstractmethod - def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: - pass - - requires_global_initialization = False - - -class LambdaInitializer(Initializer): - def __init__( - self, - init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None], - requires_global_initialization: bool = False, - ) -> None: - self._init_method = init_method - self.requires_global_initialization = requires_global_initialization - - def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: - return self._init_method(meta, tensor, generator) - - -def init_fill_(value: float) -> LambdaInitializer: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - tensor.fill_(value) - - return LambdaInitializer(init_) - - -init_zeros_ = init_fill_(0.0) -init_ones_ = init_fill_(1.0) - - -def init_normal_( - mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None -) -> LambdaInitializer: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - tensor = tensor.normal_(mean, std, generator=generator) - if min_val is not None or max_val is not None: - tensor.clamp_(min=min_val, max=max_val) - - return LambdaInitializer(init_) - - -def init_kaiming_(d_in: float) -> LambdaInitializer: - return init_normal_(0.0, math.sqrt(2.0 / d_in)) - - -def init_uniform_( - low: float = 0.0, high: float = 1.0, min_val: float | None = None, max_val: float | None = None -) -> LambdaInitializer: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - tensor = tensor.uniform_(low, high, generator=generator) - if min_val is not None or max_val is not None: - tensor.clamp_(min=min_val, max=max_val) - - return LambdaInitializer(init_) - - -def init_uniform_centered_(high: float, max_val: float | None = None, mean: float = 0.0) -> LambdaInitializer: - return init_uniform_( - mean - high, - mean + high, - min_val=None if max_val is None else mean - max_val, - max_val=None if max_val is None else mean + max_val, - ) From 3bb03cb3cf4bc64ba286f4f9a5074d0ecff8c227 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 19:15:14 -0400 Subject: [PATCH 40/82] misc --- fast_llm/layers/block/config.py | 2 - fast_llm/layers/block/mlp/config.py | 4 +- .../layers/block/mlp/mixture_of_experts.py | 127 ++++++------------ fast_llm/layers/block/mlp/mlp.py | 42 +++--- 4 files changed, 63 insertions(+), 112 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 489cd4f3f..6111c7e00 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -120,8 +120,6 @@ class BlockConfig(MLPConfig, BaseModelConfig): def _validate(self) -> None: with self._set_implicit_default(): - 
if self.ffn_hidden_size is None: - self.ffn_hidden_size = 4 * self.hidden_size # TODO: Review initialization if self.init_method_std is None: self.init_method_std = self.hidden_size**-0.5 diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 64e234544..92697de44 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -158,9 +158,11 @@ class MLPConfig(Config): def _validate(self) -> None: with self._set_implicit_default(): - # TODO: Make this work without inheritance. if self.activation_type is None: self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + # TODO: Make this work without inheritance. + if self.ffn_hidden_size is None: + self.ffn_hidden_size = 4 * self.hidden_size # TODO: Review initialization if self.init_method_std_mlp_1 is None: self.init_method_std_mlp_1 = self.init_method_std diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 46005234c..3a517db20 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -1,12 +1,10 @@ import logging -import typing import warnings import torch from fast_llm.core.distributed import ProcessGroup, set_generator from fast_llm.engine.config_utils.initialization import init_normal_ -from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map @@ -15,8 +13,6 @@ from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear -from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage -from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, get_lr_scale logger = logging.getLogger(__name__) @@ -40,59 +36,44 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." 
- super().__init__( - config, - tensor_space, - block_index, - name, - ) - self._tensor_space = tensor_space - self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - - self._num_experts = config.num_experts - self._experts_per_token = config.num_experts_per_token - self._num_shared_experts = config.num_shared_experts - self._num_unshared_experts = config.num_unshared_experts - - self._routing_type = config.expert_routing_type - self._load_balancing_factor = config.expert_auxiliary_loss_coefficient - self._z_loss_factor = config.expert_z_loss_coefficient - self._moe_jitter_eps = config.moe_jitter_eps + super().__init__(config, tensor_space, block_index, name) - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) + layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None + router_lr_scale = get_lr_scale(self._config.router_lr_scale, layer_lr_scale) self.router = Linear( tensor_space[BlockDimNames.hidden], tensor_space[MLPDimNames.unshared_experts], bias=False, weight_init_method=init_normal_( - std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max + std=self._config.init_method_std, + min_val=self._config.init_method_min, + max_val=self._config.init_method_max, ), lr_scale=router_lr_scale, ) - dropless_moe = config.dropless_moe + dropless_moe = self._config.dropless_moe if dropless_moe and tensor_space.distributed_config.sequence_tensor_parallel: warnings.warn( "Dropless MoE not supported for sequence-tensor-parallel, falling back to looped implementation." ) dropless_moe = False self._mlp_forward = self._forward_dropless if dropless_moe else self._forward_looped - self._dynamic_shape = config.dropless_dynamic_shape + self._dynamic_shape = self._config.dropless_dynamic_shape def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) - if self._debug_mode: - self._debug_log(logits, "Router logits", MLPDimNames.experts, kwargs) + if self._debug.enabled: + self._debug(logits, "Router logits", MLPDimNames.experts, kwargs) # Apply z_loss if applicable - if self._z_loss_factor > 0.0: + if self._config.expert_z_loss_coefficient > 0.0: logits = z_loss( logits, - self._z_loss_factor, + self._config.expert_z_loss_coefficient, self.training, grad_scale=kwargs.get("grad_output"), losses=losses, @@ -100,24 +81,24 @@ def forward( ) # Apply input_jitter if applicable: - if self.training and self._moe_jitter_eps > 0.0: + if self.training and self._config.moe_jitter_eps > 0.0: with set_generator(self._tensor_space.distributed.pp_generator): logits = self._apply_input_jitter(logits) # Routing - if self._routing_type == RoutingType.topk: + if self._config.expert_routing_type == RoutingType.topk: scores, top_experts = self._topk_routing(logits, kwargs.get(BlockKwargs.grad_output), losses) - if self._num_shared_experts > 0: + if self._config.num_shared_experts > 0: scores, top_experts = self._add_shared_experts(top_experts, scores) - elif self._routing_type == RoutingType.sinkhorn: + elif self._config.expert_routing_type == RoutingType.sinkhorn: scores, top_experts = self._sinkhorn_routing(logits) else: - raise NotImplementedError(self._routing_type) + raise NotImplementedError(self._config.expert_routing_type) - if self._debug_mode: 
+ if self._debug.enabled: # To log all ranks set `global_=False` - self._debug_log(scores, "Router scores", MLPDimNames.top_experts, kwargs) - self._debug_log(top_experts, "Router top experts", MLPDimNames.top_experts, kwargs) + self._debug(scores, "Router scores", MLPDimNames.top_experts, kwargs) + self._debug(top_experts, "Router top experts", MLPDimNames.top_experts, kwargs) return self._mlp_forward(hidden_states, scores, top_experts).view_as(input_), None # noqa @@ -125,7 +106,7 @@ def _forward_dropless( self, hidden_states: torch.Tensor, scores: torch.Tensor, top_experts: torch.Tensor ) -> torch.Tensor: # Compute token counts and the sparse mapping (dense_row, top_index) -> sparse_row. - sparse_map = get_sparse_map(top_experts, self._num_experts, dynamic_shape=self._dynamic_shape) + sparse_map = get_sparse_map(top_experts, self._config.num_experts, dynamic_shape=self._dynamic_shape) # Sparse MLP return mlp_autograd( @@ -154,7 +135,7 @@ def _forward_looped( top_experts, self.layer_1.weight, self.layer_2.weight, - self._num_experts, + self._config.num_experts, self._config.gated, self._config.activation_type, self._intermediate_dim.parallel_group, @@ -165,7 +146,9 @@ def _forward_looped( @torch.compile def _apply_input_jitter(self, logits: torch.Tensor) -> torch.Tensor: - return logits * torch.empty_like(logits).uniform_(1.0 - self._moe_jitter_eps, 1.0 + self._moe_jitter_eps) + return logits * torch.empty_like(logits).uniform_( + 1.0 - self._config.moe_jitter_eps, 1.0 + self._config.moe_jitter_eps + ) def _topk_routing( self, @@ -173,11 +156,11 @@ def _topk_routing( grad_scale: float | None = None, losses: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - top_logits, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) + top_logits, top_experts = torch.topk(logits, k=self._config.num_experts_per_token, dim=-1) scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32) if losses is not None or (self.training and grad_scale is not None): probs = torch.softmax(logits, dim=-1, dtype=torch.float32) - mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1) + mask = torch.nn.functional.one_hot(top_experts, num_classes=self._config.num_unshared_experts).sum(dim=1) # Auxiliary loss, corresponding to the sum of probabilities for the top experts. # In the optimal case (uniform distribution), loss = experts_per_token / num_experts. # In the worst case (whole distribution in the top experts), loss = 1. @@ -188,7 +171,9 @@ def _topk_routing( losses[MLPLossNames.load_balancing_loss].append(aux_loss.detach()) if self.training and grad_scale is not None: scores = AuxiliaryLoss.apply( - scores, aux_loss, self._num_unshared_experts * self._load_balancing_factor * grad_scale + scores, + aux_loss, + self._config.num_unshared_experts * self._config.expert_auxiliary_loss_coefficient * grad_scale, ) return scores, top_experts @@ -197,69 +182,33 @@ def _add_shared_experts( ) -> tuple[torch.Tensor, torch.Tensor]: # Add the shared experts (last ones) to the top experts. shared_experts = torch.arange( - self._num_unshared_experts, self._num_experts, device=top_experts.device, dtype=top_experts.dtype + self._config.num_unshared_experts, + self._config.num_experts, + device=top_experts.device, + dtype=top_experts.dtype, )[None].repeat(top_experts.size(0), 1) top_experts = torch.cat((shared_experts, top_experts), dim=1) # Add scores of 1 to scores for shared experts. 
- scores = torch.cat((scores.new_ones(scores.size(0), self._num_shared_experts), scores), dim=1) + scores = torch.cat((scores.new_ones(scores.size(0), self._config.num_shared_experts), scores), dim=1) return scores, top_experts def _sinkhorn_routing(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if self.training: - _, top_experts = torch.topk(sinkhorn(logits), k=self._experts_per_token, dim=-1) + _, top_experts = torch.topk(sinkhorn(logits), k=self._config.num_experts_per_token, dim=-1) logits = self._sinkhorn_activation(logits) scores = torch.gather(logits, -1, top_experts) else: logits = self._sinkhorn_activation(logits) - scores, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) + scores, top_experts = torch.topk(logits, k=self._config.num_experts_per_token, dim=-1) return scores, top_experts def _sinkhorn_activation(self, logits: torch.Tensor) -> torch.Tensor: return ( torch.sigmoid(logits) - if self._experts_per_token == 1 + if self._config.num_experts_per_token == 1 else torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) ) - def _debug_log( - self, - tensor: torch.Tensor | None, - name: str, - dim_name: str, - kwargs: dict[str, typing.Any], - global_: bool = True, - ) -> None: - if self._config.debug_transformer_memory: - log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) - if self._config.debug_transformer and tensor is not None: - # TODO: Local vs global - meta = self._get_meta(tensor, name, dim_name, kwargs) - log_distributed_tensor( - "", - tensor.view_as(meta), - level=self._config.debug_transformer, - meta=meta, - distributed=self._tensor_space.distributed, - global_=global_, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._config.debug_transformer, - meta=self._get_meta(tensor, name + " grad", dim_name, kwargs), - distributed=self._tensor_space.distributed, - grad_fn=lambda tensor_: tensor_.view_as(meta), - global_=global_, - ) - - def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: - return TensorMeta.from_dims( - kwargs[BlockKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), - tensor_name=f"{self._name} {name}", - dtype=tensor.dtype, - ) - def sinkhorn(cost: torch.Tensor, tolerance: float = 1e-5, eps=1e-9) -> torch.Tensor: """Sinkhorn based MoE routing function""" diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 7d4643673..577986e3a 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -23,52 +23,54 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: debug_level=config.debug_transformer, debug_memory=config.debug_transformer_memory, ) - self._name = name - self._block_index = block_index + self._config = config init_method_1 = init_normal_( - std=config.init_method_std_mlp_1, - min_val=config.init_method_min_mlp_1, - max_val=config.init_method_max_mlp_1, + std=self._config.init_method_std_mlp_1, + min_val=self._config.init_method_min_mlp_1, + max_val=self._config.init_method_max_mlp_1, ) init_method_2 = init_normal_( - std=config.init_method_std_mlp_2, - min_val=config.init_method_min_mlp_2, - max_val=config.init_method_max_mlp_2, + std=self._config.init_method_std_mlp_2, + min_val=self._config.init_method_min_mlp_2, + max_val=self._config.init_method_max_mlp_2, ) - hidden_dim = tensor_space[BlockDimNames.hidden] - self._intermediate_dim = 
tensor_space[MLPDimNames.composite_expert_mlp] - self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel + hidden_dim = self._tensor_space[BlockDimNames.hidden] + self._intermediate_dim = self._tensor_space[MLPDimNames.composite_expert_mlp] self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale + layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None + lr_scale = ( + tuple(self._config.mlp_lr_scale) + if isinstance(self._config.mlp_lr_scale, list) + else self._config.mlp_lr_scale + ) lr_scale = get_lr_scale(lr_scale, layer_lr_scale) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space[MLPDimNames.composite_gated_expert_mlp], - bias=config.add_mlp_bias, + self._tensor_space[MLPDimNames.composite_gated_expert_mlp], + bias=self._config.add_mlp_bias, weight_init_method=init_method_1, - bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, + bias_init_method=init_zeros_, lr_scale=lr_scale, ) self.layer_2 = LinearBase( self._intermediate_dim, hidden_dim, - bias=config.add_mlp_bias, + bias=self._config.add_mlp_bias, weight_init_method=init_method_2, - bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, + bias_init_method=init_zeros_, auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, lr_scale=lr_scale, ) # PEFT. - self.layer_1 = config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) - self.layer_2 = config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) + self.layer_1 = self._config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) + self.layer_2 = self._config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): From 98bae95d7595d13077c4608b0adedc23cdce1297 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 19:41:57 -0400 Subject: [PATCH 41/82] misc --- fast_llm/layers/block/config.py | 8 ---- fast_llm/layers/block/mlp/config.py | 10 ++++ .../layers/block/mlp/mixture_of_experts.py | 18 +++++-- fast_llm/layers/block/mlp/mlp.py | 2 +- fast_llm/layers/language_model/embedding.py | 16 ++++--- fast_llm/layers/language_model/head.py | 47 +++++++------------ fast_llm/layers/ssm/discrete_mamba2.py | 11 ++++- fast_llm/layers/ssm/mamba2.py | 4 +- fast_llm/layers/ssm/mamba_layer.py | 14 +++++- fast_llm/layers/transformer/preprocessing.py | 6 +-- 10 files changed, 75 insertions(+), 61 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 6111c7e00..756e54dac 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -128,14 +128,6 @@ def _validate(self) -> None: super()._validate() - @property - def add_mlp_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: super().setup_tensor_space(tensor_space) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 
92697de44..70f05956a 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -4,6 +4,7 @@ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel +from fast_llm.layers.block.config import AddLinearBiasChoices from fast_llm.utils import Assert @@ -156,6 +157,15 @@ class MLPConfig(Config): hint=FieldHint.optional, ) + @property + def add_mlp_bias(self) -> bool: + # TODO: Make this work without inheritance. + if isinstance(self.add_linear_biases, bool): + return self.add_linear_biases + if self.add_linear_biases == AddLinearBiasChoices.everywhere: + return True + return False + def _validate(self) -> None: with self._set_implicit_default(): if self.activation_type is None: diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 3a517db20..60cee9847 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -59,7 +59,6 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: ) dropless_moe = False self._mlp_forward = self._forward_dropless if dropless_moe else self._forward_looped - self._dynamic_shape = self._config.dropless_dynamic_shape def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None @@ -67,7 +66,7 @@ def forward( hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) if self._debug.enabled: - self._debug(logits, "Router logits", MLPDimNames.experts, kwargs) + self._debug(logits, "Router logits", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.experts,), kwargs) # Apply z_loss if applicable if self._config.expert_z_loss_coefficient > 0.0: @@ -97,8 +96,15 @@ def forward( if self._debug.enabled: # To log all ranks set `global_=False` - self._debug(scores, "Router scores", MLPDimNames.top_experts, kwargs) - self._debug(top_experts, "Router top experts", MLPDimNames.top_experts, kwargs) + self._debug( + scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), kwargs + ) + self._debug( + top_experts, + "Router top experts", + kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), + kwargs, + ) return self._mlp_forward(hidden_states, scores, top_experts).view_as(input_), None # noqa @@ -106,7 +112,9 @@ def _forward_dropless( self, hidden_states: torch.Tensor, scores: torch.Tensor, top_experts: torch.Tensor ) -> torch.Tensor: # Compute token counts and the sparse mapping (dense_row, top_index) -> sparse_row. 
- sparse_map = get_sparse_map(top_experts, self._config.num_experts, dynamic_shape=self._dynamic_shape) + sparse_map = get_sparse_map( + top_experts, self._config.num_experts, dynamic_shape=self._config.dropless_dynamic_shape + ) # Sparse MLP return mlp_autograd( diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 577986e3a..6243c17bd 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -63,7 +63,7 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: bias=self._config.add_mlp_bias, weight_init_method=init_method_2, bias_init_method=init_zeros_, - auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, + auto_bias_grad_accumulation=self._tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, lr_scale=lr_scale, ) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 68aa4882b..051044ef6 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -22,19 +22,19 @@ class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[L together with optional absolute position embeddings and dropout. """ - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig + config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = ConfigType # Ensure the layer is on its own stage. layer_count: float = 1000.0 def __init__( self, - config: LanguageModelBaseConfig, + config: ConfigType, tensor_space: TensorSpace, ): super().__init__(config) - self._distributed_config = tensor_space.distributed_config self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config self._residual_dtype = ( self._distributed_config.optimization_dtype if config.transformer.full_precision_residual @@ -42,12 +42,14 @@ def __init__( ).torch self._group_size = self._distributed_config.tensor_parallel self._sequence_parallel = self._distributed_config.sequence_tensor_parallel - self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_embeddings = ( + self._tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings + ) self._dropout_p = config.transformer.hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - hidden_dim = tensor_space[LanguageModelDimNames.hidden] - vocab_dim = tensor_space[ + hidden_dim = self._tensor_space[LanguageModelDimNames.hidden] + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab ] @@ -66,7 +68,7 @@ def __init__( ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( - (tensor_space[LanguageModelDimNames.position_embed], hidden_dim), + (self._tensor_space[LanguageModelDimNames.position_embed], hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 63d1a6b27..2fa0b0f06 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -25,7 +25,6 @@ LanguageModelLossNames, ) from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT -from fast_llm.logging import log_distributed_tensor from fast_llm.tensor import 
ParameterMeta, TensorMeta from fast_llm.utils import Assert, div, get_unique @@ -34,16 +33,16 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer): +class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[ConfigType], Layer): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). """ - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig + config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = ConfigType def __init__( self, - config: LanguageModelBaseConfig, + config: ConfigType, tensor_space: TensorSpace, prediction_distance: int, ): @@ -105,8 +104,7 @@ def __init__( lr_scale=self._config.output_lr_scale, ) - self._use_dpo_loss = self._config.enable_dpo - if not self._use_dpo_loss: + if not self._config.enable_dpo: self._cross_entropy_impl = self._config.cross_entropy_impl if self._cross_entropy_impl == CrossEntropyImpl.auto: if self._parallel_embeddings: @@ -204,7 +202,7 @@ def _get_targets( self, kwargs: dict ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None: # Loss mask for distillation. (Labels are already masked.) - if self._use_dpo_loss: + if self._config.enable_dpo: dpo_target = kwargs.get(LanguageModelKwargs.labels) lm_target = None distillation_target = None @@ -341,35 +339,22 @@ def _logits_cross_entropy_forward_backward( LanguageModelLossNames.z_loss, logits_scale_factor=self._logits_scale_factor, ) - if self._debug_transformer and self._cross_entropy_splits is None: - vocab_dim = self._tensor_space[ + if self._debug.enabled and self._cross_entropy_splits is None: + vocab_dim = ( LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ] - dims = [*kwargs[LanguageModelKwargs.hidden_dims][:-1], vocab_dim] - sequence_index = 1 - int(kwargs[LanguageModelKwargs.sequence_first]) - dims[sequence_index] = ( - TensorDim( - LanguageModelDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor - ) - if self._sequence_parallel_logits - else TensorDim(LanguageModelDimNames.sequence_q, dims[sequence_index].global_size) ) - - dim_names = ( - [LanguageModelDimNames.sequence_q_tp, LanguageModelDimNames.vocab] + sequence_dim = ( + LanguageModelDimNames.sequence_q_tp if self._sequence_parallel_logits - else [LanguageModelDimNames.sequence_q, LanguageModelDimNames.vocab_tp] + else LanguageModelDimNames.sequence_q ) - - dim_names.insert(int(kwargs[LanguageModelKwargs.sequence_first]), LanguageModelDimNames.batch) - log_distributed_tensor( - "", - logits, - level=self._debug_transformer, - meta=TensorMeta.from_dims(tuple(dims), tensor_name="transformer logits", dtype=logits.dtype), - distributed=self._tensor_space.distributed, - scale=self._logits_scale_factor, + batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] + dims = ( + (sequence_dim, batch_dim, vocab_dim) + if kwargs[LanguageModelKwargs.sequence_first] + else (batch_dim, sequence_dim, vocab_dim) ) + self._debug(logits, "Language model logits", dims, kwargs, scale=self._logits_scale_factor) if targets is None: return logits * self._logits_scale_factor, None diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index e967ab9d1..61291f845 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ 
b/fast_llm/layers/ssm/discrete_mamba2.py @@ -4,13 +4,14 @@ import einops import torch -from fast_llm.engine.config_utils.initialization import init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ +from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.mamba_layer import init_kaiming_ from fast_llm.tensor import ParameterMeta from fast_llm.utils import get_lr_scale @@ -117,7 +118,13 @@ def __init__( lr_scale=lr_scale, ) - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available sequence_length = kwargs[BlockKwargs.sequence_q_dim].global_size diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 5d62c144f..b6626e893 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,14 +3,14 @@ import torch -from fast_llm.engine.config_utils.initialization import init_kaiming_, init_ones_, init_uniform_centered_ +from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias, init_kaiming_ from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert, div, get_lr_scale diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 0f3224f77..0dcc29f0b 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -4,7 +4,7 @@ import torch -from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_kaiming_, init_ones_ +from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer @@ -146,7 +146,13 @@ def __init__( ) self.out_proj.weight.auto_grad_accumulation = True - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available in_proj = self.in_proj(input_).permute((1, 2, 0) if 
kwargs[BlockKwargs.sequence_first] else (0, 2, 1)) @@ -170,3 +176,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if kwargs[BlockKwargs.sequence_first]: out = out.transpose(0, 1) return out, None + + +def init_kaiming_(d_in: float) -> LambdaInitializer: + return init_normal_(0.0, math.sqrt(2.0 / d_in)) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index d8fa14a6d..16e5811e6 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -5,7 +5,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import AttentionKwargs, TransformerConfig +from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -21,7 +21,7 @@ class BackupAttentionPreprocessor(Preprocessor): def __init__( self, - config: TransformerConfig, + config: AttentionConfig, tensor_space: TensorSpace, ): self._config = config @@ -91,7 +91,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: class FlashAttnVarlenPreprocessor(Preprocessor): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): + def __init__(self, config: AttentionConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config From fd731ef76ba1ac52610291cb17e38eee7107be71 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 21:06:12 -0400 Subject: [PATCH 42/82] fixes --- fast_llm/data/data/abstract.py | 1 + fast_llm/data/data/gpt/data.py | 2 ++ fast_llm/engine/schedule/runner.py | 1 + fast_llm/layers/block/block.py | 9 ++++++--- fast_llm/layers/block/config.py | 10 ---------- fast_llm/layers/block/mlp/config.py | 3 ++- fast_llm/layers/block/mlp/mixture_of_experts.py | 2 +- fast_llm/layers/block/mlp/mlp.py | 2 +- fast_llm/layers/language_model/embedding.py | 2 +- fast_llm/layers/language_model/head.py | 2 +- fast_llm/layers/ssm/config.py | 4 +++- fast_llm/layers/transformer/config.py | 11 +++++++++++ fast_llm/layers/transformer/rotary/rotary.py | 9 +++++++++ 13 files changed, 39 insertions(+), 19 deletions(-) diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index e24d39985..04da64a9d 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -13,6 +13,7 @@ class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC): + config_class: typing.ClassVar[type[DataConfig]] = DataConfig _distributed: "Distributed" _sampling_parameters: dict[str, SamplingParameters] _cache_directory: pathlib.Path | None diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6724afb59..37cfd9020 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -65,6 +65,8 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): TODO: Separate generic and GPT classes. 
""" + config_class: typing.ClassVar[type[GPTDataConfig]] = GPTDataConfig + _datasets: dict[str, SampledDataset] _sampling_parameters: dict[str, GPTSamplingParameters] _tokenizer: Tokenizer | None diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 8eca4559d..7fdba1832 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -64,6 +64,7 @@ def __repr__(self): class ScheduleRunner[ConfigType: ScheduleConfig](Configurable[ScheduleConfig]): + config_class: typing.ClassVar[type[ScheduleConfig]] = ScheduleConfig _is_setup: bool = False _compute_stream: torch.cuda.Stream _data_stream: torch.cuda.Stream diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 292d2c9a4..03e0df928 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -11,8 +11,6 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs -from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP -from fast_llm.layers.block.mlp.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -123,18 +121,19 @@ class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): A transformer-like decoder base block with abstract mixer. """ + config_class: typing.ClassVar[type[BlockConfig]] = BlockConfig # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "mixer" def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): super().__init__(config) # TODO: Argument? + self._block_index = block_index self._name = f"Block {self._block_index}" self._tensor_space: TensorSpace = tensor_space self._dropout_p: float = self._config.hidden_dropout # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input - self._block_index = block_index self._debug = DebugLayer( tensor_space, self._name, @@ -150,6 +149,10 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i # The mixer needs to be created here for backward-compatible weight ordering. setattr(self, self._mixer_module_name, self._create_mixer()) + # TODO: Use dynamic type. 
+ from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP + from fast_llm.layers.block.mlp.mlp import MLP + self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( self._config, self._tensor_space, diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 756e54dac..919f95b3f 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -118,16 +118,6 @@ class BlockConfig(MLPConfig, BaseModelConfig): hint=FieldHint.optional, ) - def _validate(self) -> None: - with self._set_implicit_default(): - # TODO: Review initialization - if self.init_method_std is None: - self.init_method_std = self.hidden_size**-0.5 - if self.init_method_min is not None and self.init_method_max is not None: - Assert.leq(self.init_method_min, self.init_method_max) - - super()._validate() - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: super().setup_tensor_space(tensor_space) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 70f05956a..a99debacc 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -4,7 +4,6 @@ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel -from fast_llm.layers.block.config import AddLinearBiasChoices from fast_llm.utils import Assert @@ -159,6 +158,8 @@ class MLPConfig(Config): @property def add_mlp_bias(self) -> bool: + from fast_llm.layers.block.config import AddLinearBiasChoices + # TODO: Make this work without inheritance. if isinstance(self.add_linear_biases, bool): return self.add_linear_biases diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 60cee9847..e53693460 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) -class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): +class MixtureOfExpertMLP(MLPBase): """ MoeLayer following implementation from https://github.com/NVIDIA/Megatron-LM/blob/46ebc0e4202c980d98900000d455f754a7ff9d4b/megatron/model/transformer.py#L346 diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 6243c17bd..06850c8d0 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -73,7 +73,7 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: self.layer_2 = self._config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) -class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): +class MLP(MLPBase): def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): Assert.eq(config.num_experts, 1) super().__init__(config, tensor_space, block_index, name) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 051044ef6..1ecafb344 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -22,7 +22,7 @@ class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[L together with optional absolute position embeddings and dropout. 
""" - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = ConfigType + config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig # Ensure the layer is on its own stage. layer_count: float = 1000.0 diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 2fa0b0f06..6d1fedd26 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -38,7 +38,7 @@ class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[Config A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). """ - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = ConfigType + config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig def __init__( self, diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 00c709814..dec0675b9 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -9,7 +9,7 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: - from fast_llm.engine.config_utils.initialization import Initializer, init_fill_, init_uniform_centered_ + from fast_llm.engine.config_utils.initialization import Initializer class SSMDimNames(BlockDimNames): @@ -66,6 +66,8 @@ class DTInitType(enum.StrEnum): random = "random" def get_init_method(self, scale: float) -> "Initializer": + from fast_llm.engine.config_utils.initialization import init_fill_, init_uniform_centered_ + return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index a8245f7da..f7c7fea9c 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -233,3 +233,14 @@ def add_dense_bias(self) -> bool: # TODO: Use composition instead class TransformerConfig(AttentionConfig, BlockConfig): _abstract = False + + def _validate(self) -> None: + with self._set_implicit_default(): + # Kept here for initialization order. 
+ # TODO: Review initialization + if self.init_method_std is None: + self.init_method_std = self.hidden_size**-0.5 + if self.init_method_min is not None and self.init_method_max is not None: + Assert.leq(self.init_method_min, self.init_method_max) + + super()._validate() diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index ebb629aa1..207cff7d3 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -42,6 +42,8 @@ def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor class Rotary[ConfigType: RotaryConfig](Configurable[RotaryConfig], torch.nn.Module, Preprocessor): + config_class: typing.ClassVar[type[RotaryConfig]] = RotaryConfig + def __init__( self, config: ConfigType, @@ -58,6 +60,8 @@ def forward( class NoRotary[ConfigType: NoRotaryConfig](Rotary[NoRotaryConfig]): + config_class: typing.ClassVar[type[NoRotaryConfig]] = NoRotaryConfig + def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: @@ -71,6 +75,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: class DefaultRotary[ConfigType: DefaultRotaryConfig](Rotary[DefaultRotaryConfig]): + config_class: typing.ClassVar[type[DefaultRotaryConfig]] = DefaultRotaryConfig _rotary_embedding_frequencies: torch.Tensor _tensor_cache_max_sequence_length: int = -1 @@ -154,6 +159,8 @@ def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: class Llama3Rotary[ConfigType: Llama3RotaryConfig](DefaultRotary[Llama3RotaryConfig]): + config_class: typing.ClassVar[type[Llama3RotaryConfig]] = Llama3RotaryConfig + def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: scales = super()._get_angle_scales(kv_channels, device) low_frequency_wavelength = self._config.original_context_length / self._config.low_frequency_factor @@ -180,6 +187,8 @@ class YarnRotary[ConfigType: YarnRotaryConfig](DefaultRotary[YarnRotaryConfig]): [original paper](https://arxiv.org/abs/2309.00071) """ + config_class: typing.ClassVar[type[YarnRotaryConfig]] = YarnRotaryConfig + def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> torch.Tensor: return super()._get_frequencies(sequence_length, kv_channels, device) * self._config.attention_factor From f48332139d54a7e6cbe3171b480434832d7e5a8d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 21:07:17 -0400 Subject: [PATCH 43/82] fixes --- tests/test_ssms.py | 82 ---------------------------------------------- 1 file changed, 82 deletions(-) delete mode 100644 tests/test_ssms.py diff --git a/tests/test_ssms.py b/tests/test_ssms.py deleted file mode 100644 index 6c4c7f0cb..000000000 --- a/tests/test_ssms.py +++ /dev/null @@ -1,82 +0,0 @@ -import pathlib - -import pytest -import torch - -from fast_llm.config import NoAutoValidate -from fast_llm.engine.checkpoint.config import CheckpointLoadConfig -from fast_llm.engine.distributed.config import DistributedConfig, PhaseType -from fast_llm.engine.schedule.config import ScheduleConfig -from fast_llm.engine.schedule.runner import ScheduleRunner -from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.transformer.config import AttentionKwargs -from fast_llm.models.gpt.config import GPTBatchConfig -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat -from fast_llm.models.ssm.model import HybridSSMModel - - -@pytest.mark.skip("Disabled 
due to cartesia_pytorch installation issue") -@pytest.mark.slow -def test_load_from_llamba_checkpoint(): - """ - Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. - """ - import cartesia_pytorch.Llamba.llamba - - vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json - batch_size = 2 - seq_length = 32 - - path = pathlib.Path("/mnt/checkpoints_fml/pretrained_models/Llamba-1B") - format = LLambaHuggingfaceCheckpointFormat - - x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") - - hf_model = cartesia_pytorch.Llamba.llamba.LMHeadModel.from_pretrained(path, strict=True).to("cuda") - parameter_sum_hf = sum(p.detach().sum().cpu().item() for p in hf_model.parameters()) - hf_logits = hf_model(x)["logits"].cpu() - del hf_model - torch.cuda.empty_cache() - - # Create checkpoint load config - checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) - # Initialize model - model = HybridSSMModel.from_pretrained(checkpoint_config) - param_sum = 0 - for stage in model.stages: - for fsdp in stage.fsdps: - if hasattr(fsdp, "_weight_shard"): - param_sum += torch.sum(fsdp._weight_shard).item() - assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 - - # model = GPTModel.from_pretrained(checkpoint_config) - assert model.config.base_model.vocab_size == vocab_size - schedule_config = ScheduleConfig() - with NoAutoValidate(): - batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) - batch_config.setup(DistributedConfig.from_dict({})) - batch_config.validate() - schedule_runner = ScheduleRunner( - config=schedule_config, - multi_stage=model, - distributed_config=model.distributed.config, - ) - schedule = Schedule( - multi_stage=model, - batch_config=batch_config, - schedule_config=schedule_config, - distributed_config=model.distributed.config, - phase=PhaseType.inference, - ) - schedule_runner.setup(model.distributed, optimizer=None) - - common_kwargs = { - AttentionKwargs.sequence_first: True, - AttentionKwargs.grad_output: False, - } - input_data = [(x, common_kwargs)] - - schedule_runner.run_step(iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True) - - logits = input_data[0][1]["logits"].cpu() - assert torch.allclose(logits, hf_logits, atol=1e-2) From 8abf2587028c36fa2fab29c12db57a189bfe3c0f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 Aug 2025 16:21:17 -0400 Subject: [PATCH 44/82] fixes --- fast_llm/models/gpt/config.py | 4 -- fast_llm/models/ssm/conversion.py | 61 ++++++++++++++++---------- tests/conftest.py | 20 ++------- tests/data/common.py | 2 +- tests/data/test_blending.py | 3 +- tests/data/test_concatenate.py | 3 +- tests/data/test_concatenated_memmap.py | 3 +- tests/data/test_dataset_from_file.py | 3 +- tests/data/test_fim.py | 3 +- tests/data/test_memmap.py | 3 +- tests/data/test_sampling.py | 3 +- tests/data/test_slice.py | 3 +- tests/models/test_checkpoint.py | 4 +- tests/models/test_lm_eval.py | 3 +- tests/models/test_match_megatron.py | 3 +- tests/utils/dataset.py | 26 +++++------ tests/utils/global_variables.py | 48 ++++++++++++++++++++ tests/utils/model_configs.py | 2 +- tests/utils/utils.py | 13 +----- 19 files changed, 125 insertions(+), 85 deletions(-) create mode 100644 tests/utils/global_variables.py diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 0da16428e..3ca2d71fa 100644 --- 
a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -23,7 +23,6 @@ class GPTHuggingfaceCheckpointFormat(CheckpointFormat): support_optimizer: typing.ClassVar[bool] = False - trust_remote_code: typing.ClassVar[bool] = False @classmethod def get_handler_class(cls) -> type[CheckpointHandler]: @@ -58,17 +57,14 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mtp_llama" - trust_remote_code: typing.ClassVar[bool] = True class DiffusionDreamGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "dream" - trust_remote_code: typing.ClassVar[bool] = True class DiffusionLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "diffusion_llama" - trust_remote_code: typing.ClassVar[bool] = True @config_class() diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 43e3c67e5..b5e77e0f0 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,6 +3,8 @@ import pathlib import typing +from transformers import PretrainedConfig + from fast_llm.config import MISSING from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( @@ -16,7 +18,7 @@ SplitWeightConverter, WeightConverter, ) -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig @@ -29,12 +31,14 @@ HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat, ) +from fast_llm.models.ssm.external.apriel_15b_hybrid import ( + configuration_ssm_hybrid_apriel15b, + modeling_ssm_hybrid_apriel15b, +) +from fast_llm.models.ssm.external.apriel_hybrid import configuration_ssm_hybrid_apriel, modeling_ssm_hybrid_apriel from fast_llm.models.ssm.model import HybridSSMModel from fast_llm.utils import Assert -if typing.TYPE_CHECKING: - pass - class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): _model: HybridSSMModel @@ -523,6 +527,11 @@ class AprielSSMHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandle _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHuggingfaceCheckpointFormat architecture: typing.ClassVar[str] = "AprielSSMForCausalLM" + modeling_file = modeling_ssm_hybrid_apriel15b.__file__ + configuration_file = configuration_ssm_hybrid_apriel15b.__file__ + configuration_cls: typing.ClassVar[type["PretrainedConfig"]] = ( + configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig + ) @classmethod def _create_config_converters(cls) -> list[ParamConverter]: @@ -635,6 +644,7 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class AprielSSMHHybridHuggingfaceCheckpointHandler( + CustomModelingExportMixin, HybridModelCheckpointHandler, # handles the block structure parameter CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers @@ -648,10 +658,21 @@ class AprielSSMHHybridHuggingfaceCheckpointHandler( format: typing.ClassVar[type[CheckpointFormat]] = 
AprielSSMHHybridHuggingfaceCheckpointFormat _default_block_type: str = SSMBlockType.mamba2_discrete.value architecture: typing.ClassVar[str] = "AprielSSMHybridForCausalLM" + modeling_file = modeling_ssm_hybrid_apriel.__file__ + configuration_file = configuration_ssm_hybrid_apriel.__file__ + configuration_cls: typing.ClassVar[type["PretrainedConfig"]] = modeling_ssm_hybrid_apriel.AprielSSMHybridConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_ssm_hybrid_apriel.AprielSSMHybridConfig", + "AutoModel": "modeling_ssm_hybrid_apriel.AprielSSMHybridModel", + "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel.AprielSSMHybridForCausalLM", + }, + ), RenameParamConverter( fast_llm_names=(("ssm", "d_inner"),), export_names=(("ssm_cfg", "d_inner"),), @@ -693,6 +714,7 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( + CustomModelingExportMixin, HybridModelCheckpointHandler, # handles the block structure parameter CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers @@ -707,28 +729,23 @@ class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( _default_block_type: str = SSMBlockType.mamba2_discrete.value _hf_prefix: str = "model" architecture: typing.ClassVar[str] = "AprielThinkerSSMHybridForCausalLM" - - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() - # num_layers = self._model.config.base_model.transformer.num_layers - # # Embedding and output - # if self._model.config.base_model.tie_word_embeddings: - # converters.append( - # WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") - # ) - # converters.append(IgnoreImportWeightConverter((), f"{self._hf_prefix}.lm_head.weight")) - # else: - # converters.append( - # WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") - # ) - # converters.append( - # WeightConverter(f"layers.{num_layers + 1}.output_weights", f"{self._hf_prefix}.lm_head.weight") - # ) - return converters + modeling_file = modeling_ssm_hybrid_apriel15b.__file__ + configuration_file = configuration_ssm_hybrid_apriel15b.__file__ + configuration_cls: typing.ClassVar[type["PretrainedConfig"]] = ( + configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig + ) @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig", + "AutoModel": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridModel", + "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridForCausalLM", + }, + ), RenameParamConverter( fast_llm_names=(("ssm", "d_inner"),), export_names=(("ssm_cfg", "d_inner"),), diff --git a/tests/conftest.py b/tests/conftest.py index 19bdfe5d9..86937326c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,27 +8,15 @@ import pytest import xdist.scheduler -from fast_llm.utils import get_and_reset_memory_usage_mib, set_global_variables +from fast_llm.utils import get_and_reset_memory_usage_mib from tests.utils.depends import DependencyManager +from 
tests.utils.global_variables import TEST_RESULTS_PATH, set_testing_global_variables # TODO: Is this early enough? -set_global_variables() # isort: skip - - -if worker_name := os.environ.get("PYTEST_XDIST_WORKER"): - if gpus := os.environ.get("CUDA_VISIBLE_DEVICES"): - # We set the device through "CUDA_VISIBLE_DEVICES", and this needs to happen before importing torch. - assert worker_name.startswith("gw") - worker_id = int(worker_name[2:]) - gpus = [int(i) for i in gpus.split(",")] - num_gpus = len(gpus) - gpus = [gpus[(i + worker_id) % num_gpus] for i in range(num_gpus)] - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpus) - +set_testing_global_variables() # isort: skip import torch # isort: skip - from tests.utils.save_load_configs import ( # isort: skip distributed_save_load_config, distributed_save_load_config_non_pp, @@ -44,7 +32,7 @@ ) from tests.utils.model_configs import model_testing_config, ModelTestingConfig, testing_group_enabled # isort: skip -from tests.utils.utils import result_path, TEST_RESULTS_PATH, format_resource_report, report_subtest # isort: skip +from tests.utils.utils import result_path, format_resource_report, report_subtest # isort: skip logger = logging.getLogger(__name__) diff --git a/tests/data/common.py b/tests/data/common.py index 2bb90a6b4..6614accce 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -23,7 +23,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert, div -from tests.utils.dataset import TEST_VOCAB_SIZE +from tests.utils.global_variables import TEST_VOCAB_SIZE def get_sampling_data( diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 3e6c37632..312807aad 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -11,7 +11,8 @@ get_sampling_data, get_test_data_and_compare_samples, ) -from tests.utils.dataset import DATASET_CACHE, DATASET_PREFIX, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_CACHE, DATASET_PREFIX _DATASET_PREFIX_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset" diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 4f36cdf89..6cc5d639a 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -7,7 +7,8 @@ get_test_data_and_compare_samples, ) from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS -from tests.utils.dataset import DATASET_PREFIX, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_PREFIX GPT_CONCATENATED_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], diff --git a/tests/data/test_concatenated_memmap.py b/tests/data/test_concatenated_memmap.py index 1cc22250d..35d93d9d5 100644 --- a/tests/data/test_concatenated_memmap.py +++ b/tests/data/test_concatenated_memmap.py @@ -9,7 +9,8 @@ validate_indexed_dataset_sampling, ) from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES -from tests.utils.dataset import DATASET_CACHE, get_test_concatenated_memmap_dataset +from tests.utils.dataset import get_test_concatenated_memmap_dataset +from tests.utils.global_variables import DATASET_CACHE _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP = DATASET_CACHE / "concatenated_memmap" diff --git a/tests/data/test_dataset_from_file.py b/tests/data/test_dataset_from_file.py index 3f7d1a139..c149e1395 100644 --- 
a/tests/data/test_dataset_from_file.py +++ b/tests/data/test_dataset_from_file.py @@ -1,7 +1,8 @@ from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from tests.data.common import compare_indexed_dataset, get_dataset_config from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS -from tests.utils.dataset import DATASET_PREFIX, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_PREFIX def test_dataset_from_file(): diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 004b96289..551134fd2 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -7,7 +7,8 @@ get_sampling_data, get_test_data_and_compare_samples, ) -from tests.utils.dataset import DATASET_PREFIX, TOKENIZER_PATH, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_PREFIX, TOKENIZER_PATH GPT_FIM_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index fcd7756db..1286bddd7 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -4,7 +4,8 @@ from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig from tests.data.common import compare_indexed_dataset, get_dataset_config -from tests.utils.dataset import DATASET_CACHE, DATASET_PREFIX, DATASET_SAMPLING_CACHE, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_CACHE, DATASET_PREFIX, DATASET_SAMPLING_CACHE MEMMAP_DATASET_LENGTH = 6153 MEMMAP_DATASET_TOKENS = 508327 diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 32d76fa4c..a2996aa1c 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -13,7 +13,8 @@ get_test_data_and_compare_samples, validate_indexed_dataset_sampling, ) -from tests.utils.dataset import DATASET_PREFIX, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_PREFIX try: from fast_llm.csrc.data import build_padded_token_cumsum # noqa diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index f8eedc5bc..1440614cb 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -7,7 +7,8 @@ validate_indexed_dataset_sampling, ) from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES -from tests.utils.dataset import DATASET_PREFIX, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_PREFIX GPT_SLICE_TRAINING_SAMPLES = [ [80, 268, 79, 260, 207, 3086], diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 4bda5512c..031ec6f97 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -317,9 +317,7 @@ def test_huggingface_model(model_testing_config, get_convert_path): if model_testing_config.name in ("diffusion_llama", "dream") else transformers.AutoModelForCausalLM ) - model_as_hf = auto_model.from_pretrained( - hf_path, trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code - ).cuda() + model_as_hf = auto_model.from_pretrained(hf_path, trust_remote_code=True).cuda() for name, model in zip( ("From state dict", "From Huggingface", "Native Huggingface"), (model_from_fast_llm, model_from_hf, model_as_hf), diff --git a/tests/models/test_lm_eval.py b/tests/models/test_lm_eval.py index b9e2aa8c3..8011b5bbc 100644 --- 
a/tests/models/test_lm_eval.py +++ b/tests/models/test_lm_eval.py @@ -3,8 +3,9 @@ import pytest -from tests.utils.dataset import TOKENIZER_PATH, download_santacoder_tokenizer +from tests.utils.dataset import download_santacoder_tokenizer from tests.utils.distributed_configs import DistributedTestingConfig +from tests.utils.global_variables import TOKENIZER_PATH from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 30667cd17..5ff998bfa 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -3,8 +3,9 @@ import pytest from tests.utils.compare_tensor_logs import CompareConfig -from tests.utils.dataset import MODEL_DATASET_PREFIX, get_model_test_dataset +from tests.utils.dataset import get_model_test_dataset from tests.utils.distributed_configs import DistributedTestingConfig +from tests.utils.global_variables import MODEL_DATASET_PREFIX from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index b770675d4..e4cce2935 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -1,27 +1,21 @@ import pathlib import random -import string import numpy as np import yaml from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample -from tests.utils.utils import SHARED_RESULT_PATH, TEST_RESULTS_PATH - -# TODO: Fixtures -TOKENIZER_PATH = SHARED_RESULT_PATH / "tokenizer" -TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" -DATASET_CACHE = SHARED_RESULT_PATH / "dataset" -DATASET_PREFIX = DATASET_CACHE / "common_dataset" -DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache" -TEST_VOCAB_SIZE = 8192 -# Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% -TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" -TEST_DATASET_TOKENS = 1000000 - -MODEL_DATASET_PREFIX = DATASET_CACHE / "model_dataset" -MODEL_TEST_VOCAB_SIZE = 384 +from tests.utils.global_variables import ( + DATASET_PREFIX, + MODEL_DATASET_PREFIX, + MODEL_TEST_VOCAB_SIZE, + TEST_CHARACTERS, + TEST_DATASET_TOKENS, + TEST_VOCAB_SIZE, + TOKENIZER_FILE, + TOKENIZER_PATH, +) def download_santacoder_tokenizer(): diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py new file mode 100644 index 000000000..80232bf53 --- /dev/null +++ b/tests/utils/global_variables.py @@ -0,0 +1,48 @@ +""" +This files holds global variables and settings that need to be defined before importing any third-party package. +They are kept in a separate file to prevent circular imports. +""" + +import os +import pathlib +import string + +from fast_llm.utils import set_global_variables + +# Directory for all test data and results. +# Cannot be a fixture because it's used outside testing environment (ex. distributed scripts). +TEST_RESULTS_PATH = pathlib.Path("/tmp/fast_llm_tests") + +WORKER_NAME = os.environ.get("PYTEST_XDIST_WORKER") +GPUS = os.environ.get("CUDA_VISIBLE_DEVICES") +SHARED_RESULT_PATH = TEST_RESULTS_PATH / (f"common_{WORKER_NAME}" if WORKER_NAME else "common") + + +def set_testing_global_variables(): + set_global_variables() # isort: skip + if WORKER_NAME: + if gpus := os.environ.get("CUDA_VISIBLE_DEVICES"): + # We set the device through "CUDA_VISIBLE_DEVICES", and this needs to happen before importing torch. 
+ assert WORKER_NAME.startswith("gw") + worker_id = int(WORKER_NAME[2:]) + gpus = [int(i) for i in gpus.split(",")] + num_gpus = len(gpus) + gpus = [gpus[(i + worker_id) % num_gpus] for i in range(num_gpus)] + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpus) + os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(SHARED_RESULT_PATH / "torchinductor_cache") + os.environ["TRITON_CACHE_DIR"] = str(SHARED_RESULT_PATH / "triton_cache") + + +# TODO: Fixtures +TOKENIZER_PATH = SHARED_RESULT_PATH / "tokenizer" +TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" +DATASET_CACHE = SHARED_RESULT_PATH / "dataset" +DATASET_PREFIX = DATASET_CACHE / "common_dataset" +DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache" +TEST_VOCAB_SIZE = 8192 +# Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% +TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" +TEST_DATASET_TOKENS = 1000000 + +MODEL_DATASET_PREFIX = DATASET_CACHE / "model_dataset" +MODEL_TEST_VOCAB_SIZE = 384 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 722d8d63a..e9bdeba97 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -24,8 +24,8 @@ AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat, ) -from tests.utils.dataset import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE from tests.utils.distributed_configs import DistributedTestingConfig +from tests.utils.global_variables import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE from fast_llm.engine.evaluation.evaluators import ( # isort:skip # needed for dynamic type registration EvaluatorsConfig, diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 25d5221d8..88303a0f4 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -1,7 +1,6 @@ import json import logging import math -import os import pathlib import sys import time @@ -19,22 +18,12 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig from fast_llm.engine.multi_stage.stage import Stage from fast_llm.utils import get_and_reset_memory_usage_mib, header +from tests.utils.global_variables import TEST_RESULTS_PATH logger = logging.getLogger(__name__) requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") -# Directory for all test data and results. -# Cannot be a fixture because it's used outside testing environment (ex. distributed scripts). -TEST_RESULTS_PATH = pathlib.Path("/tmp/fast_llm_tests") - -# Directory for data that is shared between independent tests and may not be parallel-safe, -# ex. generated dataset and downloaded files. 
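For reference, the per-worker GPU rotation that now lives in `set_testing_global_variables` (above) can be exercised in isolation. The helper below is hypothetical and only restates the list comprehension from the patch; pytest-xdist names its workers "gw0", "gw1", and so on.

    def _rotated_visible_devices(worker_name: str, visible_devices: str) -> str:
        # Each worker starts its GPU list at a different offset so that concurrent
        # workers prefer different devices.
        assert worker_name.startswith("gw")
        worker_id = int(worker_name[2:])
        gpus = [int(i) for i in visible_devices.split(",")]
        num_gpus = len(gpus)
        return ",".join(str(gpus[(i + worker_id) % num_gpus]) for i in range(num_gpus))

    # Worker "gw1" with CUDA_VISIBLE_DEVICES="0,1,2,3" ends up with "1,2,3,0".
    assert _rotated_visible_devices("gw1", "0,1,2,3") == "1,2,3,0"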
-if worker_name := os.environ.get("PYTEST_XDIST_WORKER"): - SHARED_RESULT_PATH = TEST_RESULTS_PATH / f"common_{worker_name}" -else: - SHARED_RESULT_PATH = TEST_RESULTS_PATH / "common" - @pytest.fixture(scope="session") def result_path(): From 07c921182b31f2a1fff16da703fcd7e82b73a7fe Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 Aug 2025 16:33:45 -0400 Subject: [PATCH 45/82] stuff --- fast_llm/config.py | 6 ++++++ fast_llm/data/data/abstract.py | 1 - fast_llm/data/data/gpt/data.py | 2 -- fast_llm/data/preparator/config.py | 2 -- fast_llm/data/preparator/gpt_memmap/prepare.py | 2 -- fast_llm/engine/base_model/base_model.py | 1 - fast_llm/engine/distributed/distributed.py | 3 --- fast_llm/engine/evaluation/evaluator.py | 4 ---- fast_llm/engine/evaluation/lm_eval/evaluator.py | 2 -- fast_llm/engine/multi_stage/fast_llm_model.py | 1 - fast_llm/engine/multi_stage/multi_stage.py | 1 - fast_llm/engine/multi_stage/stage_base.py | 1 - fast_llm/engine/schedule/runner.py | 1 - fast_llm/engine/training/trainer.py | 3 --- fast_llm/layers/block/block.py | 1 - fast_llm/layers/language_model/embedding.py | 2 -- fast_llm/layers/language_model/head.py | 3 --- fast_llm/layers/transformer/rotary/rotary.py | 9 --------- fast_llm/models/custom/model.py | 6 +----- fast_llm/models/custom/trainer.py | 4 ---- fast_llm/models/gpt/model.py | 3 --- fast_llm/models/gpt/trainer.py | 2 -- fast_llm/models/ssm/model.py | 2 -- fast_llm/models/ssm/trainer.py | 1 - 24 files changed, 7 insertions(+), 56 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index c534b11f3..099670625 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1028,6 +1028,12 @@ def __init__(self, config: ConfigType, *args, **kwargs): # Handle multiple inheritance. super().__init__(*args, **kwargs) + def __init_subclass__(cls): + # Automatically set `config_class` based on the bound type. + # Make sure `ConfigType` is bound and respects class hierarchy. + Assert.custom(issubclass, config_class := ConfigType.__bound__, cls.config_class) + cls.config_class = config_class + @property def config(self) -> ConfigType: return self._config diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index 04da64a9d..e24d39985 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -13,7 +13,6 @@ class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC): - config_class: typing.ClassVar[type[DataConfig]] = DataConfig _distributed: "Distributed" _sampling_parameters: dict[str, SamplingParameters] _cache_directory: pathlib.Path | None diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 37cfd9020..6724afb59 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -65,8 +65,6 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): TODO: Separate generic and GPT classes. 
""" - config_class: typing.ClassVar[type[GPTDataConfig]] = GPTDataConfig - _datasets: dict[str, SampledDataset] _sampling_parameters: dict[str, GPTSamplingParameters] _tokenizer: Tokenizer | None diff --git a/fast_llm/data/preparator/config.py b/fast_llm/data/preparator/config.py index 7f6376c7d..160fccafc 100644 --- a/fast_llm/data/preparator/config.py +++ b/fast_llm/data/preparator/config.py @@ -19,8 +19,6 @@ def _get_runnable(self) -> typing.Callable[[], None]: class DatasetPreparator[ConfigType: DatasetPreparatorConfig](Configurable[ConfigType], abc.ABC): - config_class: typing.ClassVar[type[DatasetPreparatorConfig]] = DatasetPreparatorConfig - @abc.abstractmethod def run(self) -> None: raise NotImplementedError diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 427309a99..33c40bf8f 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -33,8 +33,6 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): - config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig - _tokenizer: Tokenizer _data_type: DataType _text_column: str diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index df603a910..caaf94794 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -82,7 +82,6 @@ def get_layers(self) -> list[Layer]: class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], SequentialLayers, abc.ABC): - config_class: typing.ClassVar[type[BaseModelConfig]] = BaseModelConfig _is_setup: bool = False def __init__( diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index f17a8f452..dc41539c0 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -1,6 +1,5 @@ import datetime import logging -import typing import torch import torch.distributed @@ -146,8 +145,6 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]): TODO: Clarify cpu support. 
""" - config_class: typing.ClassVar[type[DistributedConfig]] = DistributedConfig - def __init__(self, config: DistributedConfig, use_cpu: bool = False): super().__init__(config) assert self._config.reference_config is None diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index 3bdc2407f..6b8f8db00 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -44,8 +44,6 @@ class EvaluatorSamplingParameters: class Evaluator[ConfigType: EvaluatorConfig](Configurable[ConfigType], abc.ABC): - config_class: typing.ClassVar[type[EvaluatorConfig]] = EvaluatorConfig - _is_setup: bool = False def __init__( @@ -96,8 +94,6 @@ def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: class LossEvaluator[ConfigType: LossEvaluatorConfig](Evaluator[ConfigType]): - config_class: typing.ClassVar[type[LossEvaluatorConfig]] = LossEvaluatorConfig - def setup( self, distributed: Distributed, diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py index 162ceaf60..9040b11b4 100644 --- a/fast_llm/engine/evaluation/lm_eval/evaluator.py +++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py @@ -25,8 +25,6 @@ class LmEvalEvaluator[ConfigType: LmEvalEvaluatorConfig](Evaluator[ConfigType]): - config_class: typing.ClassVar[type[LmEvalEvaluatorConfig]] = LmEvalEvaluatorConfig - _hf_model: "HuggingfaceBaseModelForCausalLM" = None _flm_wrapper: "FastLLMLmEvalWrapper" = None diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 56bae90fe..09ee788e6 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -14,7 +14,6 @@ class FastLLMModel[ConfigType: FastLLMModelConfig](MultiStageModel[ConfigType]): - config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig _is_loaded: bool = False def save_checkpoint( diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 1f734268b..e17bc4ff8 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -26,7 +26,6 @@ class MultiStageModel[ConfigType: FastLLMModelConfig](Configurable[ConfigType]): - config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig base_model_class: typing.ClassVar[type[BaseModel]] = BaseModel _is_setup: bool = False _flat_shard: torch.Tensor diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 3218a1963..387a53a03 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -21,7 +21,6 @@ class StageBase(Configurable[StageConfig]): - config_class: typing.ClassVar[type[StageConfig]] = StageConfig _distributed: Distributed _mode: StageMode diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 7fdba1832..8eca4559d 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -64,7 +64,6 @@ def __repr__(self): class ScheduleRunner[ConfigType: ScheduleConfig](Configurable[ScheduleConfig]): - config_class: typing.ClassVar[type[ScheduleConfig]] = ScheduleConfig _is_setup: bool = False _compute_stream: torch.cuda.Stream _data_stream: torch.cuda.Stream diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 5f5511a15..e5bd5a583 100644 --- 
a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -43,8 +43,6 @@ class TrainingEvaluator[ConfigType: TrainingEvaluatorConfig](Evaluator[ConfigType]): - config_class: typing.ClassVar[type[TrainingEvaluatorConfig]] = TrainingEvaluatorConfig - evaluator: Evaluator def __init__( @@ -114,7 +112,6 @@ def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: class Trainer[ConfigType: TrainerConfig](Configurable[ConfigType], abc.ABC): - config_class: typing.ClassVar[type[TrainerConfig]] = TrainerConfig # TODO: Generalize data, schedule, logging, etc. _is_setup: bool = False _distributed: Distributed diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 03e0df928..f06b2da45 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -121,7 +121,6 @@ class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): A transformer-like decoder base block with abstract mixer. """ - config_class: typing.ClassVar[type[BlockConfig]] = BlockConfig # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "mixer" diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1ecafb344..d90442e9f 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -22,8 +22,6 @@ class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[L together with optional absolute position embeddings and dropout. """ - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig - # Ensure the layer is on its own stage. layer_count: float = 1000.0 diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 6d1fedd26..8624612d6 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -1,5 +1,4 @@ import logging -import typing import torch from torch._C._distributed_c10d import ReduceOp # noqa @@ -38,8 +37,6 @@ class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[Config A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). 
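As a rough reminder of what such a head computes (a generic sketch only, not Fast-LLM's implementation, which also handles the parallel and optional cases): final norm, projection to vocabulary logits, then cross-entropy against the labels.

    import torch

    def lm_head_loss(hidden, norm, output_weight, labels):
        # hidden: (batch, seq, hidden); output_weight: (vocab, hidden); labels: (batch, seq)
        logits = torch.nn.functional.linear(norm(hidden), output_weight)
        return torch.nn.functional.cross_entropy(logits.flatten(0, -2), labels.flatten())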
""" - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig - def __init__( self, config: ConfigType, diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 207cff7d3..ebb629aa1 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -42,8 +42,6 @@ def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor class Rotary[ConfigType: RotaryConfig](Configurable[RotaryConfig], torch.nn.Module, Preprocessor): - config_class: typing.ClassVar[type[RotaryConfig]] = RotaryConfig - def __init__( self, config: ConfigType, @@ -60,8 +58,6 @@ def forward( class NoRotary[ConfigType: NoRotaryConfig](Rotary[NoRotaryConfig]): - config_class: typing.ClassVar[type[NoRotaryConfig]] = NoRotaryConfig - def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: @@ -75,7 +71,6 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: class DefaultRotary[ConfigType: DefaultRotaryConfig](Rotary[DefaultRotaryConfig]): - config_class: typing.ClassVar[type[DefaultRotaryConfig]] = DefaultRotaryConfig _rotary_embedding_frequencies: torch.Tensor _tensor_cache_max_sequence_length: int = -1 @@ -159,8 +154,6 @@ def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: class Llama3Rotary[ConfigType: Llama3RotaryConfig](DefaultRotary[Llama3RotaryConfig]): - config_class: typing.ClassVar[type[Llama3RotaryConfig]] = Llama3RotaryConfig - def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: scales = super()._get_angle_scales(kv_channels, device) low_frequency_wavelength = self._config.original_context_length / self._config.low_frequency_factor @@ -187,8 +180,6 @@ class YarnRotary[ConfigType: YarnRotaryConfig](DefaultRotary[YarnRotaryConfig]): [original paper](https://arxiv.org/abs/2309.00071) """ - config_class: typing.ClassVar[type[YarnRotaryConfig]] = YarnRotaryConfig - def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> torch.Tensor: return super()._get_frequencies(sequence_length, kv_channels, device) * self._config.attention_factor diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index 3c0ad8ab4..98937bdb1 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -8,16 +8,13 @@ from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.transformer.block import TransformerBlock -from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig +from fast_llm.models.custom.config import CustomBaseModelConfig from fast_llm.models.custom.head import CustomHead -from fast_llm.models.gpt.config import GPTBaseModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.tensor import TensorMeta class CustomBaseModel[ConfigType: CustomBaseModelConfig](GPTBaseModel[ConfigType]): - config_class: typing.ClassVar[type[GPTBaseModelConfig]] = GPTBaseModelConfig - def __init__( self, config: CustomBaseModelConfig, @@ -66,5 +63,4 @@ def loss_defs(self) -> list[LossDef]: class CustomModel[ConfigType: CustomBaseModelConfig](GPTModel[ConfigType]): - config_class: typing.ClassVar[type[CustomModelConfig]] = CustomModelConfig base_model_class: typing.ClassVar[type[CustomBaseModel]] = CustomBaseModel diff --git 
a/fast_llm/models/custom/trainer.py b/fast_llm/models/custom/trainer.py index eba51235e..587adad3e 100644 --- a/fast_llm/models/custom/trainer.py +++ b/fast_llm/models/custom/trainer.py @@ -1,5 +1,3 @@ -import typing - from fast_llm.models.custom.config import CustomTrainerConfig from fast_llm.models.custom.data import CustomData from fast_llm.models.gpt.trainer import GPTTrainer @@ -7,8 +5,6 @@ class CustomTrainer[ConfigType: CustomTrainerConfig](GPTTrainer[ConfigType]): # TODO: Implement changes in the training loop (or tflops computation), if any (typically none). - config_class: typing.ClassVar[type[CustomTrainerConfig]] = CustomTrainerConfig - def _get_data(self): # TODO: Adjust signature if needed. return CustomData( diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 187ca618d..47df8ba1c 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -31,8 +31,6 @@ class GPTBaseModel[ConfigType: GPTBaseModelConfig](BaseModel[ConfigType]): A transformer-based language model generalizing the GPT model architecture. """ - config_class: typing.ClassVar[type[GPTBaseModelConfig]] = GPTBaseModelConfig - def __init__( self, config: GPTBaseModelConfig, @@ -410,7 +408,6 @@ def loss_defs(self) -> list[LossDef]: class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): - config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig base_model_class: typing.ClassVar[type[GPTBaseModel]] = GPTBaseModel def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, sequence_length) -> tuple[int, int]: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 54508e8e1..7f2e83ab4 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -10,8 +10,6 @@ class GPTTrainer[ConfigType: GPTTrainerConfig](Trainer[ConfigType]): - config_class: typing.ClassVar[type[GPTTrainerConfig]] = GPTTrainerConfig - def _get_data(self) -> GPTData: return GPTData( config=self._config.data, diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index ca840911f..32fbdad9b 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -20,7 +20,6 @@ class HybridSSMBaseModel[ConfigType: HybridSSMBaseModelConfig](GPTBaseModel[Conf As for the mixer, transformer uses MHA. For the LlambaBlock we support Mamba1 and discrete mamba2. """ - config_class: typing.ClassVar[type[HybridSSMBaseModelConfig]] = HybridSSMBaseModelConfig _is_setup: bool = False def __init__( @@ -110,7 +109,6 @@ class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]): A hybrid model that combines Transformer and SSM blocks. 
""" - config_class: typing.ClassVar[type[HybridSSMModelConfig]] = HybridSSMModelConfig base_model_class: typing.ClassVar[type[HybridSSMBaseModel]] = HybridSSMBaseModel diff --git a/fast_llm/models/ssm/trainer.py b/fast_llm/models/ssm/trainer.py index efa7b704f..39f589384 100644 --- a/fast_llm/models/ssm/trainer.py +++ b/fast_llm/models/ssm/trainer.py @@ -6,5 +6,4 @@ class HybridSSMTrainer[ConfigType: HybridSSMTrainerConfig](GPTTrainer[ConfigType]): - config_class: typing.ClassVar[type[HybridSSMTrainerConfig]] = HybridSSMTrainerConfig model_class: typing.ClassVar[type[HybridSSMModel]] = HybridSSMModel From bd4ff0d03fd7f878c6b8d1551ffa682326f2d150 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 12 Aug 2025 14:21:51 -0400 Subject: [PATCH 46/82] doc --- fast_llm/engine/config_utils/tensor_space.py | 86 +++++++++++++++++++- fast_llm/tensor.py | 8 +- 2 files changed, 91 insertions(+), 3 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 6c4b95b20..66176ee0f 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -15,6 +15,16 @@ class TensorDim: + """ + Describes a simple, atomic dimension of a tensor and its size. + The dimension may be parallelized along a distributed dimension `parallel_dim`, + in which case its actual (local) `size` will differ from its `global_size`. + + TensorDim's are used to represent the metadata of tensors through `TensorMeta`. + + This class also serves as a base for more complex tensor dimensions. + """ + def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None): # TODO: Handle None for unknown sizes? self._name = name @@ -62,10 +72,25 @@ def parallel_group(self) -> "ProcessGroup|None": return None if self._parallel_dim is None else self._parallel_dim.group def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + """ + Create a copy of the tensor dimension, where the parallel dimension is replaced by `distributed_dim`, + but the local size remains the same. + + Used in`TensorMeta.replace_tensor_parallel_dim`. + """ assert self.is_parallel return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from local `tensor` slices whose dimension `dim` is described by `self`. + If the dimension is parallelized, this amounts to gathering along dimension `dim` + and parallel dimension `parallel_dim`, otherwise return the input tensor. + The method needs to be called my all members of the parallel group using their appropriate local slice. + + Used in`TensorMeta.local_to_global`, + which iterates over the tensor dimensions to fully reconstruct the global tensor. + """ if self.is_parallel: from fast_llm.core.ops import gather_op @@ -76,6 +101,14 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor def local_to_global_partial( self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 ) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from a local `tensor` whose dimension `dim` is described by `self`. + Unlike `local_to_global`, this method does not need to be called from a distributed setting. + Instead, entries from other ranks are populated with `fill_value`. 
+ + Used in`TensorMeta.local_to_global_partial`, + which iterates over the tensor dimensions to fully reconstruct the global tensor. + """ if self.is_parallel: output = tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value) output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim) @@ -84,6 +117,14 @@ def local_to_global_partial( return tensor def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + """ + Partially recover a local tensor slice from a global `tensor` whose dimension `dim` is described by `self`. + If the dimension is parallel, this amounts to taking the `rank`th chunk of size `size` along dimension `dim` + and parallel dimension `self.parallel_dim`, otherwise return the input tensor. + + Used in`TensorMeta.local_to_global`, + which iterates over the tensor dimensions to fully reconstruct the local tensor. + """ return ( tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] if self.parallel_dim is not None and self.parallel_dim.size > 1 @@ -92,11 +133,20 @@ def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = F class CompositeTensorDim(TensorDim): + """ + A composite tensor dimension that represent multiple dimensions flattened into ones. + Typically happens for flattened view or higher-dimensional tensors, or tensors that can be expanded as such. + If one of the composed dimensions -- other than the first one -- is parallelized, + this is **not** equivalent to an atomic `TensorDim` of the same size, + as the relation between local and global tensors is different. + + At most one of the sub-dimensions may be parallelized. TODO: Allow for more than one? + """ + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): parallel_dim = None for dim, tensor_dim in enumerate(tensor_dims): if tensor_dim.parallel_dim is not None: - # TODO: Allow more than one parallel subdim? assert parallel_dim is None parallel_dim = tensor_dim.parallel_dim self._parallel_dim_index = dim @@ -109,12 +159,19 @@ def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + """ + Create a copy of the tensor dimension, where the parallel dimension is replaced by `distributed_dim`, + but the local size remains the same. + """ assert self._parallel_dim_index is not None dims = list(self._tensor_dims) dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim) return CompositeTensorDim(self.name, tuple(dims)) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from local `tensor` slices whose dimension `dim` is described by `self`. + """ tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) for i, tensor_dim in enumerate(self._tensor_dims): tensor = tensor_dim.local_to_global(tensor, dim + i) @@ -124,6 +181,10 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor def local_to_global_partial( self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 ) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from a local `tensor` whose dimension `dim` is described by `self`, + populating other ranks with `fill_value`. 
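The claim above, that a composite dimension whose parallelized sub-dimension is not the first one is not equivalent to an atomic dimension of the same size, can be verified with a small worked example (illustrative sizes only):

    import torch

    # Composite dim (a=2, b=4) with `b` split across 2 ranks, versus an atomic dim of size 8.
    global_tensor = torch.arange(8)
    composite_rank0 = global_tensor.unflatten(0, (2, 4)).chunk(2, dim=1)[0].flatten()
    atomic_rank0 = global_tensor.chunk(2, dim=0)[0]
    assert composite_rank0.tolist() == [0, 1, 4, 5]  # interleaved slices of the global tensor
    assert atomic_rank0.tolist() == [0, 1, 2, 3]     # one contiguous slice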
+ """ tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) for i, tensor_dim in enumerate(self._tensor_dims): tensor = tensor_dim.local_to_global_partial(tensor, dim + i) @@ -131,6 +192,9 @@ def local_to_global_partial( return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + """ + Partially recover a local tensor slice from a global `tensor` whose dimension `dim` is described by `self`. + """ tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): tensor = tensor_dim.global_to_local(tensor, dim + i) @@ -138,6 +202,12 @@ def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = F class ConcatenatedTensorDim(TensorDim): + """ + A complex tensor dimension that results from concatenating tensors. + + All sub-dimensions should have the same `parallel_dim` (may be None). TODO: Allow for more complex scenarios? + """ + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): parallel_dim = tensor_dims[0].parallel_dim for dim, tensor_dim in enumerate(tensor_dims[1:]): @@ -152,12 +222,19 @@ def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + """ + Create a copy of the tensor dimension, where the parallel dimension is replaced by `distributed_dim`, + but the local size remains the same. + """ assert self.is_parallel return ConcatenatedTensorDim( self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims) ) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from local `tensor` slices whose dimension `dim` is described by `self`. + """ import torch return ( @@ -179,6 +256,10 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor def local_to_global_partial( self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 ) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from a local `tensor` whose dimension `dim` is described by `self`, + populating other ranks with `fill_value`. + """ import torch return ( @@ -198,6 +279,9 @@ def local_to_global_partial( ) def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + """ + Partially recover a local tensor slice from a global `tensor` whose dimension `dim` is described by `self`. + """ if self.is_parallel and expand: raise NotImplementedError() import torch diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index d080e6a1e..c17df9d0c 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -240,8 +240,12 @@ def validate(self, tensor: torch.Tensor, device: torch.device | None = None) -> return validate_tensor(tensor, self, device) def replace_tensor_parallel_dim(self, distributed_dim: DistributedDim) -> "TensorMeta": - # Replace the tensor-parallel `DistributedDim` in `meta`. - # Note: This will turn `ParameterMeta` into `TensorMeta` + """ + Replace the tensor-parallel `DistributedDim` in `meta`, preserving the local size. + Requires for advanced tensor manipulations, + ex. turn tensor-parallel slices of a tensor into slices of a different tensor-parallel size. 
+ Note: This will turn `ParameterMeta` into `TensorMeta` + """ if not self.is_tensor_parallel: return self dims = list(self.dims) From 0e2e12402e162c4ad7a378b17522f03a24288ae6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 12 Aug 2025 17:17:07 -0400 Subject: [PATCH 47/82] stuff --- tests/utils/global_variables.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py index 80232bf53..836b6b79d 100644 --- a/tests/utils/global_variables.py +++ b/tests/utils/global_variables.py @@ -29,8 +29,9 @@ def set_testing_global_variables(): num_gpus = len(gpus) gpus = [gpus[(i + worker_id) % num_gpus] for i in range(num_gpus)] os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpus) - os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(SHARED_RESULT_PATH / "torchinductor_cache") - os.environ["TRITON_CACHE_DIR"] = str(SHARED_RESULT_PATH / "triton_cache") + # TODO: This might help with some issues, but slows down testing significantly. + # os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(SHARED_RESULT_PATH / "torchinductor_cache") + # os.environ["TRITON_CACHE_DIR"] = str(SHARED_RESULT_PATH / "triton_cache") # TODO: Fixtures From 0a5e4584990165ad5ed69434fc3ca37f3e9ae856 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 Aug 2025 16:22:15 -0400 Subject: [PATCH 48/82] Remove tensor space, fixes --- Megatron-LM | 2 +- fast_llm/config.py | 18 +- fast_llm/engine/base_model/base_model.py | 57 +++--- fast_llm/engine/base_model/config.py | 7 +- .../{tensor_space.py => tensor_dim.py} | 50 +---- fast_llm/engine/multi_stage/fast_llm_model.py | 1 + fast_llm/engine/multi_stage/fsdp.py | 16 +- fast_llm/engine/multi_stage/multi_stage.py | 2 +- fast_llm/engine/multi_stage/stage.py | 13 +- fast_llm/engine/multi_stage/stage_base.py | 4 +- fast_llm/engine/schedule/runner.py | 2 +- fast_llm/layers/block/block.py | 107 +++++----- fast_llm/layers/block/config.py | 7 - fast_llm/layers/block/mlp/config.py | 44 ---- .../layers/block/mlp/mixture_of_experts.py | 51 +++-- fast_llm/layers/block/mlp/mlp.py | 65 ++++-- fast_llm/layers/common/config.py | 2 +- fast_llm/layers/common/linear.py | 2 +- fast_llm/layers/common/normalization.py | 2 +- fast_llm/layers/common/peft.py | 2 +- fast_llm/layers/language_model/config.py | 23 +-- fast_llm/layers/language_model/embedding.py | 63 +++--- fast_llm/layers/language_model/head.py | 188 ++++++++---------- .../layers/language_model/preprocessing.py | 36 ++-- fast_llm/layers/ssm/block.py | 23 ++- fast_llm/layers/ssm/config.py | 96 +-------- fast_llm/layers/ssm/discrete_mamba2.py | 78 +++++--- .../layers/ssm/{mamba_layer.py => mamba.py} | 52 +++-- fast_llm/layers/ssm/mamba2.py | 93 +++++---- fast_llm/layers/transformer/attention.py | 184 +++++++++-------- fast_llm/layers/transformer/block.py | 10 +- fast_llm/layers/transformer/config.py | 48 +---- fast_llm/layers/transformer/preprocessing.py | 51 ++--- fast_llm/layers/transformer/rotary/config.py | 6 +- .../transformer/rotary/preprocessing.py | 68 ------- fast_llm/layers/transformer/rotary/rotary.py | 57 ++---- fast_llm/logging.py | 6 +- fast_llm/models/custom/model.py | 29 +-- fast_llm/models/gpt/model.py | 105 ++++++---- fast_llm/models/ssm/config.py | 9 - fast_llm/models/ssm/model.py | 113 +++-------- fast_llm/tensor.py | 25 +-- tests/functional/test_triton_kernels.py | 4 +- tests/test_attention.py | 35 +--- tests/test_mlp.py | 29 --- tests/utils/global_variables.py | 4 +- tests/utils/utils.py | 7 +- 47 files changed, 791 insertions(+), 
1105 deletions(-) rename fast_llm/engine/config_utils/{tensor_space.py => tensor_dim.py} (81%) rename fast_llm/layers/ssm/{mamba_layer.py => mamba.py} (79%) delete mode 100644 fast_llm/layers/transformer/rotary/preprocessing.py delete mode 100644 tests/test_mlp.py diff --git a/Megatron-LM b/Megatron-LM index 75b0d9787..f02b413f7 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 75b0d97876006c4b6b23fce302100d18dbf7db37 +Subproject commit f02b413f793af05ade3893bccd8aef6d644d3edf diff --git a/fast_llm/config.py b/fast_llm/config.py index 099670625..3352f3570 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1031,7 +1031,23 @@ def __init__(self, config: ConfigType, *args, **kwargs): def __init_subclass__(cls): # Automatically set `config_class` based on the bound type. # Make sure `ConfigType` is bound and respects class hierarchy. - Assert.custom(issubclass, config_class := ConfigType.__bound__, cls.config_class) + try: + config_class = None + for base in types.get_original_bases(cls): + if hasattr(base, "__origin__") and issubclass(base.__origin__, Configurable): + for arg in base.__args__: + if arg.__name__ == "ConfigType": + if config_class is None: + config_class = arg.__bound__ + else: + assert arg.__bound__ is config_class + assert config_class is not None + except Exception as e: + raise TypeError( + f"Could not determine the configuration class for the configurable class {cls.__name__}: {e.args}. " + "Please make sure to declare in the format " + f"`class {cls.__name__}[ConfigType: ConfigClass](BaseConfigurable[ConfigType])`.] " + ) cls.config_class = config_class @property diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index caaf94794..832225803 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -7,7 +7,6 @@ from fast_llm.config import Configurable from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.tensor import ParameterMeta, TensorMeta @@ -20,11 +19,18 @@ class Module(torch.nn.Module, abc.ABC): """ """ - def forward(self, input_, kwargs): - """ - Run a forward pass for the module, with autograd support. 
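The `__init_subclass__` change above infers `config_class` from the bound of the PEP 695 type parameter, which is why the explicit `config_class` class variables are deleted throughout this patch. A simplified, self-contained sketch of the mechanism (hypothetical class names, hierarchy checks omitted):

    import types

    class Config: ...
    class GPTConfig(Config): ...

    class Configurable[ConfigType: Config]:
        config_class = Config

        def __init_subclass__(cls):
            # Read the bound of the subclass's `ConfigType` parameter from its generic base.
            for base in types.get_original_bases(cls):
                if getattr(base, "__origin__", None) and issubclass(base.__origin__, Configurable):
                    for arg in base.__args__:
                        if getattr(arg, "__name__", None) == "ConfigType":
                            cls.config_class = arg.__bound__

    class GPTModel[ConfigType: GPTConfig](Configurable[ConfigType]): ...

    assert GPTModel.config_class is GPTConfig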
- """ - raise NotImplementedError() + _is_setup: bool = False + _distributed: Distributed + + def __init__(self, distributed_config: DistributedConfig): + self._distributed_config = distributed_config + super().__init__() + + def setup(self, distributed: Distributed) -> None: + assert not self._is_setup + distributed.check_config(self._distributed_config) + self._distributed = distributed + self._is_setup = True class Layer(Module): @@ -39,9 +45,9 @@ def forward( class Sequential(Layer): - def __init__(self, layers: list[Layer]): - super().__init__() - self.layers = torch.nn.ModuleList(layers) + def __init__(self, distributed_config: DistributedConfig): + super().__init__(distributed_config) + self.layers = torch.nn.ModuleList(self.get_layers()) def __getitem__(self, item): return self.layers[item] @@ -59,6 +65,15 @@ def forward( input_ = layer(input_, kwargs, losses, metrics) return input_ + @abc.abstractmethod + def get_layers(self) -> list[Layer]: + pass + + def setup(self, distributed: Distributed) -> None: + super().setup(distributed) + for layer in self.layers: + layer.setup(distributed) + @dataclasses.dataclass() class LossDef: @@ -71,28 +86,14 @@ class LossDef: dtype: torch.dtype = torch.float32 -class SequentialLayers(Sequential, abc.ABC): - # Small class defined to fix the MRO of BaseModel.__init__ - def __init__(self): - super().__init__(self.get_layers()) - - @abc.abstractmethod - def get_layers(self) -> list[Layer]: - pass - - -class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], SequentialLayers, abc.ABC): - _is_setup: bool = False +class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], Sequential): def __init__( self, config: BaseModelConfig, distributed_config: DistributedConfig, ): - self._tensor_space: TensorSpace = TensorSpace(distributed_config) - config.setup_tensor_space(self._tensor_space) - - super().__init__(config) + super().__init__(config, distributed_config) for key, value in self.named_parameters(): Assert.custom(isinstance, value, ParameterMeta) @@ -103,12 +104,6 @@ def __init__( # TODO: Add basic handling (preprocessor) in this class. 
self._reference_models: dict[str, "InferenceRunner"] = {} - def setup(self, distributed: Distributed) -> None: - assert not self._is_setup - distributed.check_config(self._tensor_space.distributed_config) - self._tensor_space.setup(distributed) - self._is_setup = True - @abc.abstractmethod def get_layers(self) -> list[Layer]: pass diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 4be42e069..22abb021b 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -6,7 +6,7 @@ from fast_llm.utils import compare_nested, log if typing.TYPE_CHECKING: - from fast_llm.engine.config_utils.tensor_space import TensorSpace + import torch @config_class() @@ -18,9 +18,6 @@ class BaseModelConfig(Config): _abstract = True - def setup_tensor_space(self, tensor_space: "TensorSpace") -> None: - raise NotImplementedError() - def compare_architecture( self, model_config: typing.Self, @@ -64,5 +61,5 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: pass @abc.abstractmethod - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: pass diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_dim.py similarity index 81% rename from fast_llm/engine/config_utils/tensor_space.py rename to fast_llm/engine/config_utils/tensor_dim.py index 6c4b95b20..f67916a66 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_dim.py @@ -2,14 +2,13 @@ import math import typing -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim +from fast_llm.engine.distributed.config import DistributedDim from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: import torch from fast_llm.core.distributed import ProcessGroup - from fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) @@ -219,49 +218,4 @@ def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = F ) -class DefaultDimNames: - # Scalar - scalar = "scalar" - - -class TensorSpace: - _is_setup: bool = False - _distributed: "Distributed" - - def __init__(self, distributed_config: DistributedConfig): - self._distributed_config = distributed_config - self._tensor_dims: dict[str, TensorDim] = {} - self.add_tensor_dim(TensorDim(DefaultDimNames.scalar, 1)) - - def setup(self, distributed: "Distributed") -> None: - assert not self._is_setup - if distributed.config is not self._distributed_config: - distributed.config.compare(self._distributed_config, ValueError) - self._is_setup = True - self._distributed = distributed - - @property - def distributed_config(self) -> DistributedConfig: - return self._distributed_config - - @property - def distributed(self) -> "Distributed": - assert self._is_setup - return self._distributed - - def add_tensor_dim(self, tensor_dim: TensorDim) -> None: - if tensor_dim.name in self._tensor_dims: - Assert.eq(tensor_dim, self._tensor_dims[tensor_dim.name]) - else: - if tensor_dim.parallel_dim is not None: - assert ( - tensor_dim.parallel_dim.name in self._distributed_config.distributed_dims - ), tensor_dim.parallel_dim.name - Assert.eq( - tensor_dim.parallel_dim.__dict__, - self._distributed_config.distributed_dims[tensor_dim.parallel_dim.name].__dict__, - ) - self._tensor_dims[tensor_dim.name] = tensor_dim - - def __getitem__(self, name: str) -> TensorDim: - return self._tensor_dims[name] 
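With the `TensorSpace` registry removed above, dimensions are no longer registered and looked up by name; they are plain objects constructed where needed, alongside the module-level `scalar_dim` constant added just below. A hypothetical usage sketch (sizes are illustrative):

    from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim

    hidden_dim = TensorDim("hidden", 4096)  # optionally takes a DistributedDim as `parallel_dim`
    bias_dims = (scalar_dim, hidden_dim)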
+scalar_dim = TensorDim("scalar", 1) diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 09ee788e6..da4fe527e 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -51,6 +51,7 @@ def from_pretrained( use_cpu: bool = False, stage_filter: set | None = None, ) -> typing.Self: + print("IUGRGHIOERIO", cls, cls.config_class) metadata = cls.config_class.load_metadata(pretrained_config) config = cls.config_class.from_dict(metadata.config, *updates, update_type=UpdateType.update) if mode.support_training: diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index be15cd37a..cb0a02a67 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -9,7 +9,7 @@ from fast_llm.core.distributed import ProcessGroup from fast_llm.core.ops import gather_op from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, ShardName, StageMode @@ -320,27 +320,31 @@ def import_state_tensor( return end - begin def export_shard( - self, shard: torch.Tensor, distributed: Distributed, data_type: DataType | None = None + self, shard: torch.Tensor, data_type: DataType | None = None ) -> typing.Generator[tuple[str, torch.Tensor], None, None]: if data_type is not None: shard = shard.to(dtype=data_type.torch) tensors = self.split_buffer(self.reconstruct_from_shard(shard)) for name, meta in self._parameter_metas.items(): - yield name, meta.local_to_global(tensors[name], distributed=distributed)[0] + yield name, meta.local_to_global(tensors[name])[0] def log_shard(self, name, shard, *, distributed: Distributed, level, global_: bool) -> None: # if global_ is None: # global_ = self._config.debug_global_tensors parameters = self.split_buffer(self.reconstruct_from_shard(shard)) if global_ else self.split_shard(shard) for parameter_name, parameter in parameters.items(): + meta = self.get_parameter_meta(parameter_name) log_distributed_tensor( name, parameter, level=level, - distributed=distributed, global_=global_, - duplicate_groups=(distributed.data_group,), - meta=self.get_parameter_meta(parameter_name), + # Assuming all tensors are either duplicated of parallel in the TP direction. 
+ duplicate_groups=( + distributed.data_group, + distributed.tensor_group, + ), + meta=meta, ) def restore_parameters(self) -> None: diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index e17bc4ff8..d939bda2b 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -12,7 +12,7 @@ from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.run import log_main_rank, log_model_parallel_main_rank -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 87eac31c4..35547cd87 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -7,7 +7,7 @@ from fast_llm.core.distributed import check_parallel_match from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import StageMode +from fast_llm.engine.multi_stage.config import StageConfig, StageMode from fast_llm.engine.multi_stage.stage_base import StageBase from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage, log_tensor from fast_llm.tensor import ParameterMeta, TensorMeta, accumulate_gradient @@ -30,7 +30,7 @@ def hook(grad_inputs, grad_outputs): # noqa return hook -class Stage(StageBase): +class Stage[ConfigType: StageConfig](StageBase[ConfigType]): _is_restored: bool _training: bool | None = None # TODO: Handle all buffer sharing in multi_stage @@ -123,7 +123,7 @@ def forward( # Last layer does not provide output if output is not None: meta = self._meta_outputs[i] - output_global, _ = meta.local_to_global(output.detach(), distributed=self._distributed) + output_global, _ = meta.local_to_global(output.detach()) kwargs["hidden_states"][self._layer_range[i]] = { "layer_type": type(layer).__name__, "tensor": output_global, @@ -216,11 +216,13 @@ def _log_layer_forward(self, output: torch.Tensor, kwargs: dict[str, typing.Any] if (nms := kwargs.get("micro_batch_splits", 1)) > 1: name = f"{name}, ms={kwargs.get('micro_batch_split',0)}/{nms}" + # Assuming all tensors are either duplicated of parallel in the TP direction. log_distributed_tensor( name, output, level=self._config.debug_layer_outputs, - distributed=self._distributed, + # Assuming all tensors are either duplicated of parallel in the TP direction. + duplicate_groups=(self._distributed.tensor_group,), global_=self._config.debug_global_tensors, meta=self._meta_outputs[i], ) @@ -250,8 +252,9 @@ def _log_layer_backward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any name, input_, level=self._config.debug_layer_gradients, - distributed=self._distributed, grad_fn=lambda grad: grad / self._fsdp_size, + # Assuming all tensors are either duplicated of parallel in the TP direction. 
+ duplicate_groups=(self._distributed.tensor_group,), global_=self._config.debug_global_tensors, meta=self._meta_inputs[i], ) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 387a53a03..ded24e538 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) -class StageBase(Configurable[StageConfig]): +class StageBase[ConfigType: StageConfig](Configurable[ConfigType]): _distributed: Distributed _mode: StageMode @@ -314,7 +314,7 @@ def _export_shard( self, shards: tuple[torch.Tensor], data_type: DataType | None = None ) -> typing.Generator[tuple[str, torch.Tensor], None, None]: for fsdp, shard in zip(self._fsdps, shards, strict=True): - yield from fsdp.export_shard(shard, self._distributed, data_type) + yield from fsdp.export_shard(shard, data_type) def _get_parameter_metas(self) -> tuple[list[ParameterMeta], list[ParameterMeta]]: # Get all the stage parameters, diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 8eca4559d..21ecbe476 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -63,7 +63,7 @@ def __repr__(self): ) -class ScheduleRunner[ConfigType: ScheduleConfig](Configurable[ScheduleConfig]): +class ScheduleRunner[ConfigType: ScheduleConfig](Configurable[ConfigType]): _is_setup: bool = False _compute_stream: torch.cuda.Stream _data_stream: torch.cuda.Stream diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index f06b2da45..425731eb9 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -5,12 +5,14 @@ import torch -from fast_llm.config import Configurable +from fast_llm.config import Config, Configurable from fast_llm.core.distributed import set_generator -from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.base_model.base_model import Layer, Module from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -19,8 +21,7 @@ class DebugLayer: # TODO: Move elsewhere? 
- def __init__(self, tensor_space: TensorSpace, name: str, debug_level: int = 0, debug_memory: bool = False): - self._tensor_space = tensor_space + def __init__(self, name: str, debug_level: int = 0, debug_memory: bool = False): self._name = name self._debug_level = debug_level self._debug_memory = debug_memory @@ -36,9 +37,9 @@ def _get_meta( ( dim if isinstance(dim, TensorDim) - else hidden_dims[dim] if dim in hidden_dims else self._tensor_space[dim] + else hidden_dims[dim] if dim in hidden_dims else TensorDim(dim, tensor.size(i)) ) - for dim in dims + for i, dim in enumerate(dims) ), tensor_name=f"{self._name} {name}", dtype=tensor.dtype, @@ -69,7 +70,6 @@ def __call__[ tensor, level=self._debug_level, meta=self._get_meta(tensor, name, dims, kwargs), - distributed=self._tensor_space.distributed, global_=global_, log_fn=log_fn, scale=scale, @@ -80,31 +80,45 @@ def __call__[ tensor, level=self._debug_level, meta=self._get_meta(tensor, name + " grad", dims, kwargs), - distributed=self._tensor_space.distributed, global_=global_, log_fn=log_fn, scale=scale, ) -class BlockLayer(torch.nn.Module, abc.ABC): +class BlockLayerBase[ConfigType: Config](Configurable[ConfigType], Module): """ - Base class for mixer and MLP modules. + Base class for blocks, mixer and MLP modules. """ - def __init__(self, tensor_space: TensorSpace, block_index: int, name: str, debug_level: int, debug_memory: bool): - super().__init__() - self._tensor_space = tensor_space + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + # TODO: Review `hidden_dim` and `block_index` + hidden_dim: TensorDim, + block_index: int, + name: str, + debug_level: int, + debug_memory: bool, + ): + super().__init__(config, distributed_config) + self._hidden_dim = hidden_dim self._block_index = block_index self._name = name - self._sequence_parallel: bool = self._tensor_space.distributed_config.sequence_tensor_parallel + self._sequence_parallel: bool = self._distributed_config.sequence_tensor_parallel self._debug = DebugLayer( - tensor_space, self._name, debug_level, debug_memory, ) + +class BlockLayer[ConfigType: Config](BlockLayerBase[ConfigType]): + """ + Base class for mixer and MLP modules. + """ + @abc.abstractmethod def forward( self, @@ -116,7 +130,7 @@ def forward( pass -class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): +class Block[ConfigType: BlockConfig](BlockLayerBase[ConfigType], Layer): """ A transformer-like decoder base block with abstract mixer. """ @@ -124,26 +138,30 @@ class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "mixer" - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): - super().__init__(config) - # TODO: Argument? - self._block_index = block_index - self._name = f"Block {self._block_index}" - self._tensor_space: TensorSpace = tensor_space - self._dropout_p: float = self._config.hidden_dropout + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + return_input: bool = False, + ): + super().__init__( + config, + distributed_config, + hidden_dim, + block_index, + name, + config.debug_transformer, + config.debug_transformer_memory, + ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. 
self._return_input: bool = return_input - self._debug = DebugLayer( - tensor_space, - self._name, - self._config.debug_transformer, - self._config.debug_transformer_memory, - ) - hidden_dim = self._tensor_space[BlockDimNames.hidden] # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale - self.norm_1 = self._config.normalization.get_layer(hidden_dim) - self.norm_2 = self._config.normalization.get_layer(hidden_dim) + self.norm_1 = self._config.normalization.get_layer(self._hidden_dim) + self.norm_2 = self._config.normalization.get_layer(self._hidden_dim) # The mixer needs to be created here for backward-compatible weight ordering. setattr(self, self._mixer_module_name, self._create_mixer()) @@ -153,15 +171,18 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i from fast_llm.layers.block.mlp.mlp import MLP self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, - self._tensor_space, - self._block_index, + self._config, self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} MLP" ) # PEFT. self.norm_1 = self._config.peft.apply_other(self.norm_1) self.norm_2 = self._config.peft.apply_other(self.norm_2) + def setup(self, distributed: Distributed) -> None: + super().setup(distributed) + getattr(self, self._mixer_module_name).setup(distributed) + self.mlp.setup(distributed) + @abc.abstractmethod def _create_mixer(self) -> BlockLayer: pass @@ -172,11 +193,7 @@ def _bias_dropout_add( ) -> torch.Tensor: if bias is not None: input_ = input_ + bias - return residual + torch.dropout(input_, self._dropout_p, self.training) - - # @property - # def name(self) -> str: - # return f"{self._name} {self._block_index}" + return residual + torch.dropout(input_, self._config.hidden_dropout, self.training) def forward( self, @@ -190,11 +207,7 @@ def forward( if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self._name} output", dtype=input_.dtype) - generator = ( - self._tensor_space.distributed.tp_generator - if self._tensor_space.distributed_config.sequence_tensor_parallel - else self._tensor_space.distributed.pp_generator - ) + generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator if self._debug.enabled: self._debug(None, "begin", kwargs[BlockKwargs.hidden_dims], kwargs) fw_input = input_ diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 919f95b3f..0da7a0c99 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -2,7 +2,6 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerPeftConfig from fast_llm.layers.common.config import NormalizationConfig @@ -117,9 +116,3 @@ class BlockConfig(MLPConfig, BaseModelConfig): desc="Min value for clamping initialized weights. 
Default: -float('inf')", hint=FieldHint.optional, ) - - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - super().setup_tensor_space(tensor_space) - - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.hidden_size)) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index a99debacc..57f7a9e03 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -1,27 +1,10 @@ import enum from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.utils import Assert -class MLPDimNames: - # MLP dimensions - mlp = "mlp" - gate_and_up = "gate_and_up" - composite_gated_mlp = "composite_gated_mlp" - experts = "experts" - top_experts = "top_experts" - shared_experts = "shared_experts" - unshared_experts = "unshared_experts" - composite_expert_mlp = "composite_expert_mlp" - composite_gated_expert_mlp = "composite_gated_expert_mlp" - composite_shared_expert_mlp = "composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" - - class MLPLossNames: load_balancing_loss = "load_balancing_loss" router_z_loss = "router_z_loss" @@ -206,30 +189,3 @@ def _validate(self) -> None: Assert.geq(scale, 0) elif self.mlp_lr_scale is not None: Assert.geq(self.mlp_lr_scale, 0) - - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - - # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(MLPDimNames.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim(gate_and_up := TensorDim(MLPDimNames.gate_and_up, 2 if self.gated else 1)) - tensor_space.add_tensor_dim(CompositeTensorDim(MLPDimNames.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(MLPDimNames.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(MLPDimNames.composite_expert_mlp, (experts, mlp))) - tensor_space.add_tensor_dim( - CompositeTensorDim(MLPDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) - ) - tensor_space.add_tensor_dim(TensorDim(MLPDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(MLPDimNames.unshared_experts, self.num_unshared_experts)) - - # shared_experts - if self.num_shared_experts: - tensor_space.add_tensor_dim( - shared_experts := TensorDim(MLPDimNames.shared_experts, self.num_shared_experts) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(MLPDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(MLPDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp)) - ) diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index e53693460..0bc531dad 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -5,11 +5,12 @@ from fast_llm.core.distributed import ProcessGroup, set_generator from fast_llm.engine.config_utils.initialization import init_normal_ -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from 
fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs -from fast_llm.layers.block.mlp.config import MLPDimNames, MLPLossNames, RoutingType +from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.mlp.config import MLPLossNames, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear @@ -18,7 +19,7 @@ logger = logging.getLogger(__name__) -class MixtureOfExpertMLP(MLPBase): +class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): """ MoeLayer following implementation from https://github.com/NVIDIA/Megatron-LM/blob/46ebc0e4202c980d98900000d455f754a7ff9d4b/megatron/model/transformer.py#L346 @@ -32,18 +33,25 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, tensor_space, block_index, name) + super().__init__(config, distributed_config, hidden_dim, block_index, name) layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(self._config.router_lr_scale, layer_lr_scale) self.router = Linear( - tensor_space[BlockDimNames.hidden], - tensor_space[MLPDimNames.unshared_experts], + hidden_dim, + TensorDim("router_experts", self._config.num_unshared_experts), bias=False, weight_init_method=init_normal_( std=self._config.init_method_std, @@ -53,20 +61,33 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: lr_scale=router_lr_scale, ) dropless_moe = self._config.dropless_moe - if dropless_moe and tensor_space.distributed_config.sequence_tensor_parallel: + if dropless_moe and self._sequence_parallel: warnings.warn( "Dropless MoE not supported for sequence-tensor-parallel, falling back to looped implementation." 
) dropless_moe = False self._mlp_forward = self._forward_dropless if dropless_moe else self._forward_looped + if self._debug.enabled: + self._top_expert_dim = TensorDim("top_experts", self._config.num_experts_per_token) + + def _get_intermediate_dims(self) -> tuple[TensorDim, TensorDim]: + intermediate_1_dim, intermediate_2_dim = super()._get_intermediate_dims() + experts_dim = TensorDim("experts", self._config.num_experts) + return ( + CompositeTensorDim("moe_intermediate_1", (experts_dim, intermediate_1_dim)), + CompositeTensorDim("moe_intermediate_2", (experts_dim, intermediate_2_dim)), + ) + def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) if self._debug.enabled: - self._debug(logits, "Router logits", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.experts,), kwargs) + self._debug( + logits, "Router logits", kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs + ) # Apply z_loss if applicable if self._config.expert_z_loss_coefficient > 0.0: @@ -81,7 +102,7 @@ def forward( # Apply input_jitter if applicable: if self.training and self._config.moe_jitter_eps > 0.0: - with set_generator(self._tensor_space.distributed.pp_generator): + with set_generator(self._distributed.pp_generator): logits = self._apply_input_jitter(logits) # Routing @@ -97,12 +118,12 @@ def forward( if self._debug.enabled: # To log all ranks set `global_=False` self._debug( - scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), kwargs + scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs ) self._debug( top_experts, "Router top experts", - kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), + kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs, ) @@ -126,7 +147,7 @@ def _forward_dropless( None, gated=self._config.gated, activation_type=self._config.activation_type, - group=self._intermediate_dim.parallel_group, + group=self._parallel_dim.group, sequence_parallel=self._sequence_parallel, training=self.training, recompute_level=self._config.mlp_recompute_level, @@ -146,7 +167,7 @@ def _forward_looped( self._config.num_experts, self._config.gated, self._config.activation_type, - self._intermediate_dim.parallel_group, + self._parallel_dim.group, self._sequence_parallel, self.training, self._config.mlp_recompute_level, diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 06850c8d0..dc5178479 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -3,27 +3,37 @@ import torch from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_dim import ConcatenatedTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.config import BlockConfig, BlockDimNames -from fast_llm.layers.block.mlp.config import MLPDimNames +from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear 
import LinearBase from fast_llm.utils import Assert, get_lr_scale -class MLPBase(BlockLayer): - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): +class MLPBase[ConfigType: BlockConfig](BlockLayer[ConfigType]): + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ): super().__init__( - tensor_space, + config, + distributed_config, + hidden_dim, block_index, name, - debug_level=config.debug_transformer, - debug_memory=config.debug_transformer_memory, + config.debug_transformer, + config.debug_transformer_memory, ) - self._config = config + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + intermediate_1_dim, intermediate_2_dim = self._get_intermediate_dims() init_method_1 = init_normal_( std=self._config.init_method_std_mlp_1, @@ -36,8 +46,6 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: max_val=self._config.init_method_max_mlp_2, ) - hidden_dim = self._tensor_space[BlockDimNames.hidden] - self._intermediate_dim = self._tensor_space[MLPDimNames.composite_expert_mlp] self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None @@ -51,19 +59,19 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - self._tensor_space[MLPDimNames.composite_gated_expert_mlp], + intermediate_1_dim, bias=self._config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_zeros_, lr_scale=lr_scale, ) self.layer_2 = LinearBase( - self._intermediate_dim, + intermediate_2_dim, hidden_dim, bias=self._config.add_mlp_bias, weight_init_method=init_method_2, bias_init_method=init_zeros_, - auto_bias_grad_accumulation=self._tensor_space.distributed_config.tensor_parallel > 1, + auto_bias_grad_accumulation=self._distributed_config.tensor_parallel > 1, transposed_weight=True, lr_scale=lr_scale, ) @@ -72,11 +80,27 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: self.layer_1 = self._config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) self.layer_2 = self._config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) + def _get_intermediate_dims(self): + intermediate_2_dim = TensorDim("intermediate", self._config.ffn_hidden_size, self._parallel_dim) + if self._config.gated: + TensorDim("gate_and_up", 2) + intermediate_1_dim = ConcatenatedTensorDim("gate_and_up", (intermediate_2_dim, intermediate_2_dim)) + else: + intermediate_1_dim = intermediate_2_dim + return intermediate_1_dim, intermediate_2_dim -class MLP(MLPBase): - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): + +class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, block_index, name) + super().__init__(config, distributed_config, hidden_dim, block_index, name) def forward( self, @@ -85,7 +109,6 @@ def forward( losses: dict[str, typing.Any] | None = None, metrics: dict[str, 
typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - parallel_group = self._intermediate_dim.parallel_group return ( mlp_autograd( input_, @@ -93,14 +116,14 @@ def forward( self.layer_1.weight, self.layer_1.bias, self.layer_2.weight, - None if parallel_group else self.layer_2.bias, + None if self._parallel_dim.group else self.layer_2.bias, gated=self._config.gated, activation_type=self._config.activation_type, - group=parallel_group, + group=self._parallel_dim.group, sequence_parallel=self._sequence_parallel, training=self.training, recompute_level=self._config.mlp_recompute_level, transposed_layer_2_weight=self.layer_2.transposed_weight, ), - self.layer_2.bias if parallel_group else None, + self.layer_2.bias if self._parallel_dim.group else None, ) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 2f45fdf9f..f56e2a2c1 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -9,7 +9,7 @@ if typing.TYPE_CHECKING: import torch - from fast_llm.engine.config_utils.tensor_space import TensorDim + from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.layers.common.linear import LinearBase, LinearLike from fast_llm.layers.common.normalization import LayerNorm, RMSNorm diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index 740b4847c..ca807e67c 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -4,7 +4,7 @@ import torch from fast_llm.engine.config_utils.initialization import init_zeros_ -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.linear import ( input_parallel_linear_autograd, diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index d44be3297..2b928eb38 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -2,7 +2,7 @@ from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.config import NormalizationImplementation diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft.py index 08f3e535b..87991ef29 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft.py @@ -2,7 +2,7 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.common.linear import Linear, LinearBase diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index b667e5318..de3f9f196 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -2,24 +2,13 @@ from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.engine.distributed.config import 
DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl -from fast_llm.layers.block.config import BlockDimNames, BlockKwargs +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.rotary.config import NoRotaryConfig from fast_llm.utils import Assert -class LanguageModelDimNames(BlockDimNames): - # Embedding dimensions - position_embed = "position_embed" - vocab = "vocab" - vocab_tp = "vocab_tp" - # Misc - scalar = "scalar" - - class LanguageModelLossNames: language_model_loss = "language_model_loss" z_loss = "z_loss" @@ -237,16 +226,6 @@ def _validate(self) -> None: len(self.transformer.per_layer_lr_scale), self.transformer.num_layers + self.prediction_heads - 1 + 1 ) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - self.transformer.setup_tensor_space(tensor_space) - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - - # Embedding dimensions - tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.position_embed, self.max_position_embeddings)) - # TODO: Need both? - tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab, self.vocab_size)) - tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) - @property def num_absolute_position_embeddings(self) -> int: # TODO: Rename from max embeddings. diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index d90442e9f..d1b912167 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -2,20 +2,21 @@ import torch -from fast_llm.config import Configurable from fast_llm.core.distributed import set_generator from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.initialization import init_normal_ -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.block.block import BlockLayerBase +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" -class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer): +class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](BlockLayerBase[ConfigType], Layer): """ A language model embedding layer. Consists of word embeddings (tensor-parallel or sequence-tensor-parallel), @@ -28,35 +29,39 @@ class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[L def __init__( self, config: ConfigType, - tensor_space: TensorSpace, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + # TODO: Unnecessary? 
+ block_index: int, + name: str, ): - super().__init__(config) - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config + super().__init__( + config, + distributed_config, + hidden_dim, + block_index, + name, + config.transformer.debug_transformer, + config.transformer.debug_transformer_memory, + ) self._residual_dtype = ( self._distributed_config.optimization_dtype if config.transformer.full_precision_residual else self._distributed_config.training_dtype ).torch - self._group_size = self._distributed_config.tensor_parallel self._sequence_parallel = self._distributed_config.sequence_tensor_parallel - self._parallel_embeddings = ( - self._tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_embeddings = self._distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + vocab_dim = TensorDim( + "vocab", self._config.vocab_size, self._parallel_dim if self._parallel_embeddings else None ) - self._dropout_p = config.transformer.hidden_dropout - self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - - hidden_dim = self._tensor_space[LanguageModelDimNames.hidden] - vocab_dim = self._tensor_space[ - LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ] if self._parallel_embeddings: self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size self._vocab_end_index = (self._distributed_config.tensor_rank + 1) * vocab_dim.size self.word_embeddings_weight = ParameterMeta.from_dims( - (vocab_dim, hidden_dim), + (vocab_dim, self._hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, @@ -64,9 +69,9 @@ def __init__( ), lr_scale=config.embeddings_lr_scale, ) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( - (self._tensor_space[LanguageModelDimNames.position_embed], hidden_dim), + (TensorDim("position_embeddings", self._config.max_position_embeddings), self._hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, @@ -85,21 +90,21 @@ def __init__( @torch.compile def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor: - Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) - group = self._tensor_space.distributed.tensor_group + Assert.eq(position_ids is not None, self._config.use_absolute_position_embeddings) + group = self._parallel_dim.group if self._parallel_embeddings: input_mask = (input_ >= self._vocab_start_index) * (input_ < self._vocab_end_index) masked_input = (input_ - self._vocab_start_index) * input_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) # noqa embeddings = reduce_forward(embeddings, group) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: if self._sequence_parallel: input_ = split(input_, group=group, dim=0) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: position_ids = 
split(position_ids, group=group, dim=0) # handle masked tokens if mask_inputs: @@ -108,16 +113,14 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) else: embeddings = torch.embedding(self.word_embeddings_weight, input_) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if mask_inputs: embeddings = embeddings * input_mask.unsqueeze(2) with set_generator( - self._tensor_space.distributed.tp_generator - if self._sequence_parallel - else self._tensor_space.distributed.pp_generator + self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator ): - embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + embeddings = torch.dropout(embeddings, self._config.transformer.hidden_dropout, self.training) return embeddings.to(dtype=self._residual_dtype) def forward( diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 8624612d6..cc6c69262 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -4,25 +4,20 @@ from torch._C._distributed_c10d import ReduceOp # noqa from torch.distributed import all_reduce -from fast_llm.config import Configurable from fast_llm.core.ops import split_op from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.initialization import init_normal_ -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedDimNames +from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward -from fast_llm.layers.block.block import DebugLayer +from fast_llm.layers.block.block import BlockLayerBase +from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss -from fast_llm.layers.language_model.config import ( - LanguageModelBaseConfig, - LanguageModelDimNames, - LanguageModelKwargs, - LanguageModelLossNames, -) +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert, div, get_unique @@ -32,7 +27,7 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[ConfigType], Layer): +class LanguageModelHead[ConfigType: LanguageModelBaseConfig](BlockLayerBase[ConfigType], Layer): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). 
""" @@ -40,31 +35,28 @@ class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[Config def __init__( self, config: ConfigType, - tensor_space: TensorSpace, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + # TODO: Unnecessary? + block_index: int, + name: str, prediction_distance: int, ): - super().__init__(config) - self._debug = DebugLayer( - tensor_space, - f"Language model head", - self._config.transformer.debug_transformer, - self._config.transformer.debug_transformer_memory, + super().__init__( + config, + distributed_config, + hidden_dim, + block_index, + name, + config.transformer.debug_transformer, + config.transformer.debug_transformer_memory, ) - self._tensor_space = tensor_space + self._parallel_logits = self._distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) - self._group_size = tensor_space.distributed_config.tensor_parallel - self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel - self._parallel_embeddings = ( - tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings - ) - self._sequence_parallel_logits = ( - tensor_space.distributed_config.sequence_tensor_parallel and not self._config.parallel_embeddings - ) - self._cross_entropy_splits = self._config.cross_entropy_splits - if self._cross_entropy_splits is not None and self._sequence_parallel: - assert not self._parallel_embeddings - - hidden_dim = self._tensor_space[LanguageModelDimNames.hidden] + self._sequence_parallel_logits = self._sequence_parallel and not self._config.parallel_embeddings + if self._config.cross_entropy_splits is not None and self._sequence_parallel: + assert not self._parallel_logits self._loss_coefficient = ( self._config.prediction_loss_coefficient[prediction_distance] @@ -72,11 +64,6 @@ def __init__( else 1.0 ) self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) - self.final_norm = self._config.transformer.normalization.get_layer(hidden_dim) - self._logits_scale_factor = self._config.logits_scale_factor - self._language_model_loss_factor = self._config.language_model_loss_factor - self._distillation_loss_factor = self._config.distillation_loss_factor - self._z_loss_factor = self._config.logit_z_loss # Distance of the target token prediction # 0: next-token prediction @@ -85,14 +72,28 @@ def __init__( self._prediction_distance = prediction_distance self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1 + if not self._config.enable_dpo: + self._cross_entropy_impl = self._config.cross_entropy_impl + if self._cross_entropy_impl == CrossEntropyImpl.auto: + if self._parallel_logits: + self._cross_entropy_impl = CrossEntropyImpl.fused + elif TritonConfig.TRITON_ENABLED: + self._cross_entropy_impl = CrossEntropyImpl.triton + else: + self._cross_entropy_impl = CrossEntropyImpl.fused + + self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) + + self.final_norm = self._config.transformer.normalization.get_layer(hidden_dim) + + self._vocab_dim = TensorDim( + "vocab", self._config.vocab_size, self._parallel_dim if self._parallel_logits else None + ) # Only the first head defines the output weights if self._prediction_distance == 0 and not self._config.tie_word_embeddings: # untie embedding weights - vocab_dim = self._tensor_space[ - LanguageModelDimNames.vocab_tp if self._parallel_embeddings else 
LanguageModelDimNames.vocab - ] self.output_weights = ParameterMeta.from_dims( - (vocab_dim, hidden_dim), + (self._vocab_dim, hidden_dim), init_method=init_normal_( std=self._config.init_method_std_embed, min_val=self._config.init_method_min_embed, @@ -101,18 +102,6 @@ def __init__( lr_scale=self._config.output_lr_scale, ) - if not self._config.enable_dpo: - self._cross_entropy_impl = self._config.cross_entropy_impl - if self._cross_entropy_impl == CrossEntropyImpl.auto: - if self._parallel_embeddings: - self._cross_entropy_impl = CrossEntropyImpl.fused - elif TritonConfig.TRITON_ENABLED: - self._cross_entropy_impl = CrossEntropyImpl.triton - else: - self._cross_entropy_impl = CrossEntropyImpl.fused - - self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) - # PEFT. self.final_norm = self._config.transformer.peft.apply_other(self.final_norm) if hasattr(self, "output_weights"): @@ -123,11 +112,12 @@ def forward( ) -> torch.Tensor: if isinstance(input_, TensorMeta): if self._is_last_head: - return TensorMeta.from_tensor_space( - (DefaultDimNames.scalar,), - self._tensor_space, + return TensorMeta.from_dims( + (scalar_dim,), tensor_name="Loss", - reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + reductions=( + (self._distributed_config.get_distributed_dim(DistributedDimNames.data), ReduceOp.AVG), + ), # noqa ) else: return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") @@ -169,19 +159,19 @@ def _forward_backward( sequence_index = 1 - int(kwargs[LanguageModelKwargs.sequence_first]) dims[sequence_index] = ( TensorDim( - LanguageModelDimNames.sequence_q_tp, + BlockDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor, ) if self._sequence_parallel_logits - else TensorDim(LanguageModelDimNames.sequence_q, dims[sequence_index].global_size) + else TensorDim(BlockDimNames.sequence_q, dims[sequence_index].global_size) ) meta = TensorMeta.from_dims(tuple(dims), tensor_name="transformer hidden_state", dtype=ln_output.dtype) - hidden_state, _ = meta.local_to_global(ln_output.detach(), distributed=self._tensor_space.distributed) + hidden_state, _ = meta.local_to_global(ln_output.detach()) kwargs["hidden_states"][len(kwargs["hidden_states"]) - 1]["tensor"] = hidden_state grad_output = kwargs[LanguageModelKwargs.grad_output] / ( - self._group_size if self._sequence_parallel_logits else 1 + self._parallel_dim.size if self._sequence_parallel_logits else 1 ) output_weights = self._get_output_weights(kwargs) @@ -215,7 +205,7 @@ def _get_targets( if loss_mask is not None: loss_mask = loss_mask.flatten() - if self._config.distillation_model is None or self._language_model_loss_factor > 0.0: + if self._config.distillation_model is None or self._config.language_model_loss_factor > 0.0: lm_target = kwargs.get(LanguageModelKwargs.labels) if lm_target is not None: # MTP: Shift the labels @@ -239,10 +229,7 @@ def _get_targets( targets = (dpo_target, lm_target, distillation_target, loss_mask) if self._sequence_parallel_logits: - targets = [ - None if target is None else split_op(target, self._tensor_space.distributed.tensor_group, 0) - for target in targets - ] + targets = [None if target is None else split_op(target, self._parallel_dim.group, 0) for target in targets] if not any(target is not None for target in targets): # Simplify so we don't have to check every time. 
targets = None @@ -264,7 +251,7 @@ def _logits_cross_entropy_forward_backward_split( kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - if self._cross_entropy_splits is None or targets is None: + if self._config.cross_entropy_splits is None or targets is None: loss, logit_input_grad = self._logits_cross_entropy_forward_backward( input_, targets, weight, grad_output, kwargs, losses ) @@ -275,17 +262,18 @@ def _logits_cross_entropy_forward_backward_split( else: loss = None # TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length - grad_output /= self._cross_entropy_splits + grad_output /= self._config.cross_entropy_splits logit_input = input_.flatten(0, -2) if self.training: logit_input_grad = torch.empty_like(logit_input) else: logit_input_grad = None split_size = div( - get_unique(target.size(0) for target in targets if target is not None), self._cross_entropy_splits + get_unique(target.size(0) for target in targets if target is not None), + self._config.cross_entropy_splits, ) tensors_split = [ - [None] * self._cross_entropy_splits if tensor is None else tensor.split(split_size) + [None] * self._config.cross_entropy_splits if tensor is None else tensor.split(split_size) for tensor in [logit_input, *targets, logit_input_grad] ] for logit_input_, *targets_, logit_input_grad_ in zip(*tensors_split, strict=True): @@ -301,12 +289,14 @@ def _logits_cross_entropy_forward_backward_split( logit_input_grad_.copy_(grad_) loss = loss_ if loss is None else loss + loss_ del grad_, loss_ - loss_count = (self._cross_entropy_splits or 1) * (self._group_size if self._sequence_parallel_logits else 1) + loss_count = (self._config.cross_entropy_splits or 1) * ( + self._parallel_dim.size if self._sequence_parallel_logits else 1 + ) if loss_count != 1: loss.div_(loss_count) if self._sequence_parallel_logits: # TODO: Async - all_reduce(loss, group=self._tensor_space.distributed.tensor_group) + all_reduce(loss, group=self._parallel_dim.group) return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None def _logits_cross_entropy_forward_backward( @@ -318,43 +308,37 @@ def _logits_cross_entropy_forward_backward( kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: + group = self._parallel_dim.group if self._parallel_logits else None logits, context = output_parallel_linear_forward( input_=input_, weight=weight, bias=None, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - sequence_parallel=self._sequence_parallel and self._parallel_embeddings, + group=group, + sequence_parallel=self._sequence_parallel and self._parallel_logits, ) - if self._z_loss_factor > 0.0: + if self._config.logit_z_loss > 0.0: logits = z_loss( logits, - self._z_loss_factor, + self._config.logit_z_loss, self.training, grad_output, losses, LanguageModelLossNames.z_loss, - logits_scale_factor=self._logits_scale_factor, - ) - if self._debug.enabled and self._cross_entropy_splits is None: - vocab_dim = ( - LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ) - sequence_dim = ( - LanguageModelDimNames.sequence_q_tp - if self._sequence_parallel_logits - else LanguageModelDimNames.sequence_q + logits_scale_factor=self._config.logits_scale_factor, ) + if self._debug.enabled and self._config.cross_entropy_splits is None: + sequence_dim = BlockDimNames.sequence_q_tp if self._sequence_parallel_logits else 
BlockDimNames.sequence_q batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] dims = ( - (sequence_dim, batch_dim, vocab_dim) + (sequence_dim, batch_dim, self._vocab_dim) if kwargs[LanguageModelKwargs.sequence_first] - else (batch_dim, sequence_dim, vocab_dim) + else (batch_dim, sequence_dim, self._vocab_dim) ) - self._debug(logits, "Language model logits", dims, kwargs, scale=self._logits_scale_factor) + self._debug(logits, "Language model logits", dims, kwargs, scale=self._config.logits_scale_factor) if targets is None: - return logits * self._logits_scale_factor, None + return logits * self._config.logits_scale_factor, None dpo_target, lm_target, distillation_target, loss_mask = targets if dpo_target is not None: @@ -375,25 +359,25 @@ def _logits_cross_entropy_forward_backward( logits.flatten(0, -2), lm_target, None, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - grad_output=grad_output * self._loss_coefficient * self._language_model_loss_factor, + group=group, + grad_output=grad_output * self._loss_coefficient * self._config.language_model_loss_factor, implementation=self._cross_entropy_impl, - logits_scale_factor=self._logits_scale_factor, + logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.labels, ) - lm_loss = lm_loss * self._language_model_loss_factor + lm_loss = lm_loss * self._config.language_model_loss_factor else: lm_loss, lm_grad = None, None - if distillation_target is not None and self._distillation_loss_factor > 0.0: + if distillation_target is not None and self._config.distillation_loss_factor > 0.0: if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, - grad_output=grad_output * self._loss_coefficient * self._distillation_loss_factor, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - logits_scale_factor=self._logits_scale_factor, + grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, + group=group, + logits_scale_factor=self._config.logits_scale_factor, teacher_softmax_temperature=self._config.teacher_softmax_temperature, target_format=( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits @@ -404,17 +388,17 @@ def _logits_cross_entropy_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - grad_output=grad_output * self._loss_coefficient * self._distillation_loss_factor, + group=group, + grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, implementation=self._cross_entropy_impl, - logits_scale_factor=self._logits_scale_factor, + logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.logits, ) else: raise ValueError( f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" ) - distillation_loss = distillation_loss * self._distillation_loss_factor + distillation_loss = distillation_loss * self._config.distillation_loss_factor else: distillation_loss, distillation_grad = None, None diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index f5d915855..5ba31c0d0 100644 --- 
a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -4,7 +4,8 @@ import torch from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -13,40 +14,31 @@ class PositionEmbeddingPreprocessor(Preprocessor): - _scalar_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor _position_ids: torch.Tensor _tensor_cache_max_sequence_length: int = -1 - def __init__( - self, - config: LanguageModelBaseConfig, - tensor_space: TensorSpace, - ): + def __init__(self, config: LanguageModelBaseConfig, distributed_config: DistributedConfig): self._config = config assert config.use_absolute_position_embeddings - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._distributed_config = distributed_config - def _create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int, device: torch.device) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length Assert.leq(sequence_length, self._config.num_absolute_position_embeddings) - self._position_ids = torch.arange( - 0, sequence_length, device=self._tensor_space.distributed.device, dtype=torch.int64 - ) + self._position_ids = torch.arange(0, sequence_length, device=device, dtype=torch.int64) - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[LanguageModelKwargs.sequence_length]) + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[LanguageModelKwargs.sequence_length], batch.device) sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size if (sequence_lengths := kwargs.get(LanguageModelKwargs.sequence_lengths)) is not None: position_ids = torch.stack( [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] - ).to(self._tensor_space.distributed.device, dtype=torch.int64) + ).to(batch.device, dtype=torch.int64) position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] if kwargs[LanguageModelKwargs.sequence_first]: position_ids = position_ids.transpose(0, 1) @@ -61,9 +53,9 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: sequence_q_dim = kwargs[LanguageModelKwargs.sequence_q_dim] kwargs[LanguageModelKwargs.position_ids] = TensorMeta.from_dims( ( - (sequence_q_dim, self._scalar_dim) + (sequence_q_dim, scalar_dim) if kwargs[LanguageModelKwargs.sequence_first] - else (self._scalar_dim, sequence_q_dim) + else (scalar_dim, sequence_q_dim) ), tensor_name=LanguageModelKwargs.position_ids, dtype=torch.int64, @@ -71,11 +63,9 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: class PreferenceSpanPreprocessor(Preprocessor): - def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): + def __init__(self, config: LanguageModelBaseConfig, distributed_config: DistributedConfig): self._config = config - self._tensor_space = 
tensor_space - self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._distributed_config = distributed_config def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: return diff --git a/fast_llm/layers/ssm/block.py b/fast_llm/layers/ssm/block.py index 987d5fa0d..361fe9818 100644 --- a/fast_llm/layers/ssm/block.py +++ b/fast_llm/layers/ssm/block.py @@ -1,34 +1,37 @@ -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.block import Block, BlockLayer from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.ssm.config import SSMConfig # TODO: Sort out configs. -class SSMBlock[ConfigType: BlockConfig](Block[BlockConfig]): +class SSMBlock[ConfigType: BlockConfig](Block[ConfigType]): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ - _name = "Llamba block" - def __init__( self, - config: BlockConfig, + config: ConfigType, ssm_config: SSMConfig, - tensor_space: TensorSpace, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, mixer_cls: type[BlockLayer], block_index: int, + name: str, return_input: bool = False, ): self._ssm_config = ssm_config self._mixer_cls = mixer_cls - super().__init__(config, tensor_space, block_index, return_input) + super().__init__(config, distributed_config, hidden_dim, block_index, name, return_input) def _create_mixer(self) -> BlockLayer: return self._mixer_cls( self._ssm_config, - tensor_space=self._tensor_space, - block_index=self._block_index, - block_config=self._config, + self._config, + self._distributed_config, + self._hidden_dim, + self._block_index, + f"{self._name} mixer", ) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index dec0675b9..2daad1186 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -2,11 +2,9 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockDimNames -from fast_llm.utils import Assert, div +from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.initialization import Initializer @@ -46,9 +44,9 @@ class SSMBlockType(enum.StrEnum): def get_mixer_class(self): if self == SSMBlockType.mamba: - from fast_llm.layers.ssm.mamba_layer import MambaLayer + from fast_llm.layers.ssm.mamba import Mamba - return MambaLayer + return Mamba elif self == SSMBlockType.mamba2: from fast_llm.layers.ssm.mamba2 import Mamba2 @@ -79,21 +77,21 @@ class SSMConfig(Config): # TODO: Remove (redundant default) expansion_factor: int = Field( default=2, - desc="Expansion factor for Mamba blocks.", + desc="Expansion factor.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) # head_size [MambaLayer, Mamba2, DiscreteMamba2] state_size: int = Field( default=16, - desc="State size for Mamba blocks.", + desc="State size.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) # [MambaLayer, Mamba2, DiscreteMamba2] conv_kernel_dimension: int = Field( default=4, - desc="Conv kernel 
dimension for Mamba blocks.", + desc="Conv kernel dimension.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) @@ -106,19 +104,19 @@ class SSMConfig(Config): # head_groups [DiscreteMamba2] n_qk_heads: int = Field( default=32, - desc="Number of QK heads for Mamba2 blocks.", + desc="Number of QK heads.", hint=FieldHint.architecture, ) # heads [DiscreteMamba2]# TODO: Remove? (redundant) n_v_heads: int = Field( default=32, - desc="Number of V heads for Mamba2 blocks.", + desc="Number of V heads.", hint=FieldHint.architecture, ) # c_size [MambaLayer, Mamba2, DiscreteMamba2]? d_inner: None | int = Field( default=None, - desc="Inner dimension for Mamba2 blocks.", + desc="Inner dimension.", hint=FieldHint.core, ) # xb_size [Mamba2] @@ -204,79 +202,3 @@ def _validate(self) -> None: self.activation_type = ActivationType.silu super()._validate() Assert.geq(self.dt_max, self.dt_min) - - def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - - # Head groups are configured differently depending on the block type. - if block_type == SSMBlockType.mamba: - num_heads = div(self.d_inner, self.state_size) - num_head_groups = num_heads - elif block_type == SSMBlockType.mamba2: - num_heads = div(self.d_inner, self.state_size) - num_head_groups = div(self.d_xb, self.state_size) - elif block_type == SSMBlockType.mamba2_discrete: - # TODO: Use different variables? - num_heads = self.n_v_heads - num_head_groups = self.n_qk_heads - else: - raise NotImplementedError(block_type) - - tensor_space.add_tensor_dim(state := TensorDim(SSMDimNames.state, self.state_size)) - if block_type == SSMBlockType.mamba2_discrete: - tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, div(self.d_inner, num_heads))) - else: - head_dim = state - - tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) - tensor_space.add_tensor_dim(group_heads := TensorDim(SSMDimNames.group_heads, div(num_heads, num_head_groups))) - tensor_space.add_tensor_dim( - heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) - ) - tensor_space.add_tensor_dim( - heads_and_head_dim := CompositeTensorDim( - SSMDimNames.composite_heads_and_head_dim, (head_groups, group_heads, head_dim) - ) - ) - tensor_space.add_tensor_dim( - head_groups_and_state := CompositeTensorDim( - SSMDimNames.composite_head_groups_and_state, (head_groups, state) - ) - ) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.convolution_kernel, self.conv_kernel_dimension)) - - # DT projection - if block_type in (SSMBlockType.mamba, SSMBlockType.mamba2): - tensor_space.add_tensor_dim(dt_rank := TensorDim(SSMDimNames.dt_rank, self.dt_rank)) - - if block_type == SSMBlockType.mamba: - tensor_space.add_tensor_dim( - ConcatenatedTensorDim(SSMDimNames.concatenated_x_projection, (dt_rank, state, state)) - ) - # TODO: Use composition instead - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_inner_projection, (heads_and_head_dim, heads_and_head_dim) - ) - ) - elif block_type == SSMBlockType.mamba2: - # TODO: Factor out state? 
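# Editor's sketch (illustration, not part of the patch): the head-group arithmetic that the
# removed `setup_tensor_space` encoded, and that each mixer now recomputes locally, reduces to
# the integer relations below. The example values are hypothetical; only the formulas follow
# the code above.
def ssm_head_counts(block_type, d_inner, state_size, d_xb=None, n_v_heads=None, n_qk_heads=None):
    if block_type == "mamba":
        num_heads = d_inner // state_size
        num_head_groups = num_heads            # one group per head: no grouping
    elif block_type == "mamba2":
        num_heads = d_inner // state_size
        num_head_groups = d_xb // state_size   # B/C projections shared within a group
    elif block_type == "mamba2_discrete":
        num_heads, num_head_groups = n_v_heads, n_qk_heads
    else:
        raise NotImplementedError(block_type)
    assert num_heads % num_head_groups == 0    # group_heads must divide evenly
    return num_heads, num_head_groups

# e.g. mamba2 with d_inner=4096, state_size=16, d_xb=1024 -> 256 heads in 64 groups of 4.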
- tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_inner_projection, - (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim), - ) - ) - elif block_type == SSMBlockType.mamba2_discrete: - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_inner_projection, - (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim, heads), - ) - ) - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_convolution, - (heads_and_head_dim, head_groups_and_state, head_groups_and_state), - ) - ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 61291f845..7e445cca1 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -5,15 +5,16 @@ import torch from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.mamba_layer import init_kaiming_ +from fast_llm.layers.ssm.mamba import init_kaiming_ from fast_llm.tensor import ParameterMeta -from fast_llm.utils import get_lr_scale +from fast_llm.utils import div, get_lr_scale logger = logging.getLogger(__name__) @@ -34,48 +35,69 @@ _causal_conv1d_available = False -class DiscreteMamba2(BlockLayer): - """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" +class DiscreteMamba2[ConfigType: SSMConfig](BlockLayer[ConfigType]): + """ + This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py + """ _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" + _config: SSMConfig def __init__( self, - config: SSMConfig, - block_index: int, - tensor_space: TensorSpace, + config: ConfigType, block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, ): super().__init__( - tensor_space, + config, + distributed_config, + hidden_dim, block_index, - self._mixer_name, - debug_level=block_config.debug_transformer, - debug_memory=block_config.debug_transformer_memory, + name, + block_config.debug_transformer, + block_config.debug_transformer_memory, ) - self._config: SSMConfig = config - layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None - lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + state_dim = TensorDim("state", self._config.state_size) + v_head_size_dim = TensorDim(SSMDimNames.head_dim, div(self._config.d_inner, self._config.n_v_heads)) + + head_groups_dim = TensorDim( + SSMDimNames.head_groups, + self._config.n_qk_heads, + self._distributed_config.get_distributed_dim(DistributedDimNames.tensor), + ) + group_heads_dim = TensorDim(SSMDimNames.group_heads, 
div(self._config.n_v_heads, self._config.n_qk_heads)) + heads_dim = CompositeTensorDim(SSMDimNames.composite_heads, (head_groups_dim, group_heads_dim)) + inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, v_head_size_dim)) + bc_dim = CompositeTensorDim("bc", (head_groups_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) - inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] - hidden_dim = tensor_space[SSMDimNames.hidden] - conv1d_dim = tensor_space[SSMDimNames.concatenated_convolution] - heads_dim = tensor_space[SSMDimNames.composite_heads] + inner_projection_dim = ConcatenatedTensorDim( + "inner_projection", + (inner_dim, bc_dim, bc_dim, inner_dim, heads_dim), + ) + convolution_dim = ConcatenatedTensorDim("convolution", (inner_dim, bc_dim, bc_dim)) # local_head_groups = head_groups / TP - self._local_head_groups = tensor_space[SSMDimNames.head_groups].size + self._local_head_groups = head_groups_dim.size # local_heads = local_head_groups * group_heads self._local_heads = heads_dim.size # local_inner_size = local_heads * head_size self._local_inner_size = inner_dim.size # local_bc_size = local_head_groups * state - self._local_bc_size = tensor_space[SSMDimNames.composite_head_groups_and_state].size + self._local_bc_size = bc_dim.size + + layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) # TODO: double check initializations # Projections self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space[SSMDimNames.concatenated_inner_projection], + inner_projection_dim, bias=config.add_bias_linear, weight_init_method=init_kaiming_(block_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -90,15 +112,17 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( ( - conv1d_dim, - tensor_space[DefaultDimNames.scalar], - tensor_space[SSMDimNames.convolution_kernel], + convolution_dim, + scalar_dim, + convolution_kernel_dim, + ), + init_method=init_uniform_centered_( + (convolution_dim.global_size * self._config.conv_kernel_dimension) ** -0.5 ), - init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (conv1d_dim,), + (convolution_dim,), init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), lr_scale=lr_scale, ) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba.py similarity index 79% rename from fast_llm/layers/ssm/mamba_layer.py rename to fast_llm/layers/ssm/mamba.py index 0dcc29f0b..ac6576a87 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba.py @@ -5,14 +5,15 @@ import torch from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import Linear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.config import 
SSMConfig from fast_llm.tensor import ParameterMeta -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import Assert, div, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -53,31 +54,40 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return LambdaInitializer(init_) -class MambaLayer(BlockLayer): +class Mamba[ConfigType: SSMConfig](BlockLayer[ConfigType]): _mixer_name: typing.ClassVar[str] = "mamba" def __init__( self, - config: SSMConfig, - block_index: int, - tensor_space: TensorSpace, + config: ConfigType, block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, ): super().__init__( - tensor_space, + config, + distributed_config, + hidden_dim, block_index, - self._mixer_name, - debug_level=block_config.debug_transformer, - debug_memory=block_config.debug_transformer_memory, + name, + block_config.debug_transformer, + block_config.debug_transformer_memory, ) - assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" - self._config = config + assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" # TODO: It's not silu? Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] - hidden_dim = tensor_space[SSMDimNames.hidden] + heads_dim = TensorDim("heads", div(self._config.d_inner, self._config.state_size)) + state_dim = TensorDim("state", self._config.state_size) + inner_dim = CompositeTensorDim("inner", (heads_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) + dt_rank_dim = TensorDim("dt_rank", self._config.dt_rank) + inner_projection_dim = ConcatenatedTensorDim("inner_projection", (inner_dim, inner_dim)) + x_projection_dim = ConcatenatedTensorDim("x_projection", (dt_rank_dim, state_dim, state_dim)) + layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) @@ -85,7 +95,7 @@ def __init__( # TODO: lr_scale? 
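# Editor's sketch (illustration, not part of the patch): the widths of the concatenated
# projections built just above for the Mamba mixer. `inner_projection` stacks the two
# full-width branches and `x_projection` stacks the dt-rank, B and C projections.
# The config values are hypothetical; only the formulas mirror the dimension composition.
d_inner, state_size, dt_rank = 4096, 16, 256        # assumed example values
num_heads = d_inner // state_size                   # heads * state == inner
inner_projection_size = 2 * d_inner                 # Concatenated(inner, inner)
x_projection_size = dt_rank + 2 * state_size        # Concatenated(dt_rank, state, state)
assert num_heads * state_size == d_inner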
self.in_proj = Linear( hidden_dim, - tensor_space[SSMDimNames.concatenated_inner_projection], + inner_projection_dim, bias=False, weight_init_method=init_kaiming_(hidden_dim.size), ) @@ -93,8 +103,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( inner_dim, - tensor_space[DefaultDimNames.scalar], - tensor_space[SSMDimNames.convolution_kernel], + scalar_dim, + convolution_kernel_dim, ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, @@ -102,7 +112,7 @@ def __init__( self.x_proj = Linear( inner_dim, - tensor_space[SSMDimNames.concatenated_x_projection], + x_projection_dim, weight_init_method=init_kaiming_(inner_dim.size), bias=False, lr_scale=lr_scale, @@ -111,7 +121,7 @@ def __init__( # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( - (inner_dim, tensor_space[SSMDimNames.dt_rank]), + (inner_dim, dt_rank_dim), init_method=init_kaiming_(self._config.dt_rank), lr_scale=lr_scale, ) @@ -123,7 +133,7 @@ def __init__( ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space[SSMDimNames.state]), + (inner_dim, state_dim), weight_decay=False, init_method=init_A(self._config.state_size, inner_dim.size), lr_scale=lr_scale, diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b6626e893..e6ca9ea12 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -4,13 +4,14 @@ import torch from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_ -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias, init_kaiming_ +from fast_llm.layers.ssm.config import SSMConfig +from fast_llm.layers.ssm.mamba import init_A, init_dtprojbias, init_kaiming_ from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert, div, get_lr_scale @@ -31,38 +32,30 @@ logger = logging.getLogger(__name__) -class Mamba2(BlockLayer): +class Mamba2[ConfigType: SSMConfig](BlockLayer[ConfigType]): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ _mixer_name: typing.ClassVar[str] = "mamba_2" - _XZ_DIMS = ( - SSMDimNames.batch, - SSMDimNames.composite_heads_and_head_dim, - SSMDimNames.sequence_q, - ) - _BC_DIMS = ( - SSMDimNames.batch, - SSMDimNames.composite_heads, - SSMDimNames.state, - SSMDimNames.sequence_q, - ) - def __init__( self, - config: SSMConfig, - tensor_space: TensorSpace, - block_index: int, + config: ConfigType, block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, ): super().__init__( - tensor_space, + config, + distributed_config, + hidden_dim, 
block_index, - self._mixer_name, - debug_level=block_config.debug_transformer, - debug_memory=block_config.debug_transformer_memory, + name, + block_config.debug_transformer, + block_config.debug_transformer_memory, ) self._config: SSMConfig = config Assert.eq(self._config.activation_type, ActivationType.silu) @@ -71,13 +64,32 @@ def __init__( ) lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim] - xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state] - hidden_dim: TensorDim = tensor_space[SSMDimNames.hidden] - dt_rank_dim = tensor_space[SSMDimNames.dt_rank] + num_heads = div(self._config.d_inner, self._config.state_size) + num_head_groups = div(self._config.d_xb, self._config.state_size) + + state_dim = TensorDim("state", self._config.state_size) + + head_groups_dim = TensorDim( + "head_groups", num_head_groups, self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + ) + group_heads_dim = TensorDim("group_heads", div(num_heads, num_head_groups)) + + heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) - self._local_heads = tensor_space[SSMDimNames.composite_heads].size - self._local_head_groups = tensor_space[SSMDimNames.head_groups].size + inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, state_dim)) + xb_dim = CompositeTensorDim("xb", (head_groups_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) + + # DT projection + dt_rank_dim = TensorDim("dt_rank", self._config.dt_rank) + + inner_projection_dim = ConcatenatedTensorDim( + "inner_projection", + (inner_dim, xb_dim, xb_dim, inner_dim), + ) + + self._local_heads = heads_dim.size + self._local_head_groups = head_groups_dim.size self._group_heads = div(self._local_heads, self._local_head_groups) self._local_inner_size = inner_dim.size self._local_xb_size = xb_dim.size @@ -86,8 +98,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( conv1d_dim, - tensor_space[DefaultDimNames.scalar], - tensor_space[SSMDimNames.convolution_kernel], + scalar_dim, + convolution_kernel_dim, ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -99,7 +111,7 @@ def __init__( ) self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space[SSMDimNames.concatenated_inner_projection], + inner_projection_dim, bias=config.add_bias_linear, weight_init_method=init_kaiming_(block_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -131,7 +143,7 @@ def __init__( lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space[SSMDimNames.state]), + (inner_dim, state_dim), init_method=init_A(self._config.state_size, self._config.d_inner), lr_scale=lr_scale, weight_decay=False, @@ -151,6 +163,19 @@ def __init__( # TODO: lr_scale? 
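# Editor's sketch (illustration, not part of the patch): how the Mamba2 `inner_projection`
# width above decomposes, and how the head-group dimension shards under tensor parallelism.
# The even split across ranks is an assumption implied by the `local_*` comments in the
# discrete-mamba2 mixer; the example values are hypothetical.
d_inner, d_xb, state_size, tensor_parallel = 4096, 1024, 16, 2
inner_projection_size = 2 * d_inner + 2 * d_xb       # Concatenated(inner, xb, xb, inner)
head_groups = d_xb // state_size                     # sharded over the tensor-parallel dim
group_heads = (d_inner // state_size) // head_groups
assert head_groups % tensor_parallel == 0            # assumed: groups divide evenly over ranks
local_head_groups = head_groups // tensor_parallel
local_heads = local_head_groups * group_heads
local_inner_size = local_heads * state_size          # what `inner_dim.size` yields on one rank
# With these numbers: 64 groups of 4 heads; each of 2 ranks holds 32 groups -> 2048 channels.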
) + if self._debug.enabled: + _xz_dims = ( + BlockDimNames.batch, + inner_dim, + BlockDimNames.sequence_q, + ) + _bc_dims = ( + BlockDimNames.batch, + heads_dim, + state_dim, + BlockDimNames.sequence_q, + ) + def forward( self, input_: torch.Tensor, diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index ba7f2bb6e..d7a669295 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -5,13 +5,15 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.block.block import BlockLayer +from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig -from fast_llm.utils import get_lr_scale +from fast_llm.layers.transformer.config import AttentionKwargs, TransformerConfig +from fast_llm.utils import div, get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -46,41 +48,58 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(BlockLayer): +class Attention[ConfigType: TransformerConfig](BlockLayer[ConfigType]): """ A self-attention layer. 
""" - _mixer_name: typing.ClassVar[str] = "attn" - - _QUERY_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.composite_heads, - AttentionDimNames.kv_channels, - ) - _KV_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.head_groups, - AttentionDimNames.kv_channels, - ) - _CONTEXT_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.composite_dense, - ) - - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ): super().__init__( - tensor_space, + config, + distributed_config, + hidden_dim, block_index, - self._mixer_name, - debug_level=config.debug_transformer, - debug_memory=config.debug_transformer_memory, + name, + config.debug_transformer, + config.debug_transformer_memory, + ) + self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) + + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim( + DistributedDimNames.sequence_data ) - self._config = config - self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) + head_group_dim = TensorDim( + "head_groups", self._config.head_groups, self._parallel_dim if self._config.head_groups > 1 else None + ) + group_heads_dim = TensorDim( + "group_heads", + div(self._config.num_attention_heads, self._config.head_groups), + None if self._config.head_groups > 1 else self._parallel_dim, + ) + self._local_head_groups = head_group_dim.size + self._local_heads_per_group = group_heads_dim.size + self._local_heads = self._local_head_groups * self._local_heads_per_group + + kv_channels_dim = TensorDim("kv_channels", self._config.kv_channels) + query_dim = CompositeTensorDim("query", (head_group_dim, group_heads_dim, kv_channels_dim)) + key_value_dim = ConcatenatedTensorDim( + "key_value", + ( + CompositeTensorDim("key", (head_group_dim, kv_channels_dim)), + CompositeTensorDim("value", (head_group_dim, kv_channels_dim)), + ), + ) + dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, kv_channels_dim)) + + self._softmax_scale = self._config.kv_channels ** (-self._config.attention_softmax_scale_power) init_method_qkv = init_normal_( std=self._config.init_method_std_qkv, @@ -93,22 +112,13 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space[AttentionDimNames.kv_channels].size - self._head_groups = self._tensor_space[AttentionDimNames.head_groups].global_size - self._local_head_groups = self._tensor_space[AttentionDimNames.head_groups].size - self._local_heads_per_group = self._tensor_space[AttentionDimNames.group_heads].size - self._local_heads = self._local_head_groups * self._local_heads_per_group - self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - - hidden_dim = self._tensor_space[AttentionDimNames.hidden] - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
self.query = OutputParallelLinear( hidden_dim, - self._tensor_space[AttentionDimNames.composite_query], + query_dim, bias=self._config.add_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_zeros_, @@ -117,7 +127,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space[AttentionDimNames.composite_key_value], + key_value_dim, bias=self._config.add_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_zeros_, @@ -127,11 +137,11 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) # Rotary embeddings. - self._rotary = self._config.rotary.build() + self._rotary = self._config.rotary.build(kv_channels_dim) # Output. self.dense = InputParallelLinear( - self._tensor_space[AttentionDimNames.composite_dense], + dense_dim, hidden_dim, bias=self._config.add_dense_bias, weight_init_method=init_method_std_attn_proj, @@ -145,6 +155,25 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i self.key_value = self._config.peft.apply_linear(self.key_value, TransformerSubLayerName.key_value) self.dense = self._config.peft.apply_linear(self.dense, TransformerSubLayerName.dense) + if self._debug.enabled: + self._query_dims = ( + BlockDimNames.batch, + BlockDimNames.sequence_q, + CompositeTensorDim("heads", (head_group_dim, group_heads_dim)), + kv_channels_dim, + ) + self._kv_dims = ( + BlockDimNames.batch, + BlockDimNames.sequence_q, + head_group_dim, + kv_channels_dim, + ) + self._context_dims = ( + BlockDimNames.batch, + BlockDimNames.sequence_q, + dense_dim, + ) + def _attn_fused( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor ) -> torch.Tensor: @@ -153,16 +182,18 @@ def _attn_fused( sk = key.size(1) if self._local_head_groups == 1: - query = query.view(b, sq * self._local_heads, self._kv_channels) + query = query.view(b, sq * self._local_heads, self._config.kv_channels) key = key.transpose(-1, -2) else: query = ( - query.unflatten(-1, (self._local_head_groups, self._local_heads_per_group, self._kv_channels)) + query.unflatten(-1, (self._local_head_groups, self._local_heads_per_group, self._config.kv_channels)) .transpose(1, 2) - .reshape(b * self._local_head_groups, sq * self._local_heads_per_group, self._kv_channels) + .reshape(b * self._local_head_groups, sq * self._local_heads_per_group, self._config.kv_channels) + ) + key = key.unflatten(-1, (self._local_head_groups, self._config.kv_channels)).movedim(1, 3).flatten(0, 1) + value = ( + value.unflatten(-1, (self._local_head_groups, self._config.kv_channels)).transpose(1, 2).flatten(0, 1) ) - key = key.unflatten(-1, (self._local_head_groups, self._kv_channels)).movedim(1, 3).flatten(0, 1) - value = value.unflatten(-1, (self._local_head_groups, self._kv_channels)).transpose(1, 2).flatten(0, 1) attn_weights = torch.empty( (b * self._local_head_groups, sq * self._local_heads_per_group, sk), device=query.device, dtype=query.dtype @@ -179,7 +210,7 @@ def _attn_fused( attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) - with set_generator(self._tensor_space.distributed.tp_generator): + with set_generator(self._distributed.tp_generator): attn_weights = torch.dropout(attn_weights, self._config.attention_dropout, 
self.training) attn_output = torch.bmm( attn_weights.view(b * self._local_head_groups, sq * self._local_heads_per_group, sk), value @@ -189,7 +220,7 @@ def _attn_fused( return attn_output.view(b, sq, -1) else: return ( - attn_output.view(b, self._local_head_groups, sq, self._local_heads_per_group, self._kv_channels) + attn_output.view(b, self._local_head_groups, sq, self._local_heads_per_group, self._config.kv_channels) .transpose(1, 2) .flatten(2) ) @@ -201,18 +232,16 @@ def _query_key_value_forward( handle = None - if self._head_groups == 1 and self._sequence_parallel: - key_value, handle = gather_op( - key_value, group=self._tensor_space.distributed.tensor_group, dim=0, async_op=True - ) + if self._config.head_groups == 1 and self._sequence_parallel: + key_value, handle = gather_op(key_value, group=self._parallel_dim.group, dim=0, async_op=True) - if self._tensor_space.distributed.sequence_data_group: + if self._sequence_data_parallel_dim.group: if handle: # TODO: This is probably unnecessary. handle.wait() # sequence dim may not be zero, but this needs to be handled after `handle.wait()` key_value, handle = gather_op( - key_value, group=self._tensor_space.distributed.sequence_data_group, dim=0, async_op=True + key_value, group=self._sequence_data_parallel_dim.group, dim=0, async_op=True ) query, query_context = self.query.forward_only(input_) @@ -220,8 +249,8 @@ def _query_key_value_forward( if handle: handle.wait() - if self._tensor_space.distributed.sequence_data_group and not sequence_first: - key_value = swap_mult_dim(key_value, self._tensor_space.distributed_config.sequence_data_parallel, 0, 1) + if self._sequence_data_parallel_dim.group and not sequence_first: + key_value = swap_mult_dim(key_value, self._sequence_parallel, 0, 1) context = {"query": query_context, "key_value": key_value_context, "sequence_first": sequence_first} return query, key_value, context @@ -230,15 +259,12 @@ def _query_key_value_backward( self, query_grad: torch.Tensor, key_value_grad: torch.Tensor, context: dict ) -> torch.Tensor: # TODO: De-allocate qkv grads quicker. - handle = None - - if self._tensor_space.distributed.sequence_data_group: - key_value_grad, handle = reduce_scatter_op( - key_value_grad, - group=self._tensor_space.distributed.sequence_data_group, - dim=1 - context["sequence_first"], - async_op=True, - ) + key_value_grad, handle = reduce_scatter_op( + key_value_grad, + group=self._sequence_data_parallel_dim.group, + dim=1 - context["sequence_first"], + async_op=True, + ) # TODO: Overlap with both. input_grad = self.query.backward(query_grad, context.pop("query")) @@ -246,7 +272,7 @@ def _query_key_value_backward( if handle: handle.wait() - if self._head_groups == 1 and (group := self._tensor_space.distributed.tensor_group): + if self._config.head_groups == 1 and (group := self._parallel_dim.group): if self._sequence_parallel: key_value_grad = reduce_scatter_op(key_value_grad, group=group, dim=0) else: @@ -289,7 +315,7 @@ def forward( # Manually add the gradients from later micro-sequences. 
key_value = AttachGrad.apply(key_value, present) - if self._tensor_space.distributed.sequence_data_group: + if self._sequence_data_parallel_dim.group: key_value = ( key_value[: kwargs[AttentionKwargs.sequence_k_dim].size] if sequence_first @@ -301,11 +327,11 @@ def forward( query = query.transpose(0, 1).contiguous() key_value = key_value.transpose(0, 1).contiguous() - key, value = key_value.split(self._local_head_groups * self._kv_channels, dim=-1) + key, value = key_value.split(self._local_head_groups * self._config.kv_channels, dim=-1) - query = query.view(*query.shape[:2], self._local_heads, self._kv_channels) - key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) - value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) + query = query.view(*query.shape[:2], self._local_heads, self._config.kv_channels) + key = key.view(*key.shape[:2], self._local_head_groups, self._config.kv_channels) + value = value.view(*value.shape[:2], self._local_head_groups, self._config.kv_channels) if self._debug.enabled: self._debug(query, "query_rotary_input", self._QUERY_DIMS, kwargs) @@ -316,7 +342,7 @@ def forward( if self._use_flash_attention: assert _flash_available - with set_generator(self._tensor_space.distributed.tp_generator): + with set_generator(self._distributed.tp_generator): if (cu_seqlens_q := kwargs.get(AttentionKwargs.cu_seqlens_q, None)) is not None: out_dims = query.size() query = query.view(-1, query.size(-2), query.size(-1)) @@ -357,10 +383,10 @@ def forward( ) if self._debug.enabled: - self._debug(query, "query", self._QUERY_DIMS, kwargs) - self._debug(key, "key", self._KV_DIMS, kwargs) - self._debug(value, "value", self._KV_DIMS, kwargs) - self._debug(input_, "context", self._CONTEXT_DIMS, kwargs) + self._debug(query, "query", self._query_dims, kwargs) + self._debug(key, "key", self._kv_dims, kwargs) + self._debug(value, "value", self._kv_dims, kwargs) + self._debug(input_, "context", self._context_dims, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) 
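# Editor's sketch (illustration, not part of the patch): projection widths and the shape flow
# of the fused (non-flash) attention path above when several heads share one key/value group:
# query width = heads * kv_channels, key and value widths = head_groups * kv_channels each.
# Shapes only; the sizes are hypothetical, the values random, and the causal mask is omitted.
import torch

b, sq, sk, groups, heads_per_group, kv = 2, 5, 5, 4, 2, 16
query = torch.randn(b, sq, groups * heads_per_group * kv)
key = torch.randn(b, sk, groups * kv)
value = torch.randn(b, sk, groups * kv)

query = query.unflatten(-1, (groups, heads_per_group, kv)).transpose(1, 2).reshape(b * groups, sq * heads_per_group, kv)
key = key.unflatten(-1, (groups, kv)).movedim(1, 3).flatten(0, 1)        # (b * groups, kv, sk)
value = value.unflatten(-1, (groups, kv)).transpose(1, 2).flatten(0, 1)  # (b * groups, sk, kv)

scores = torch.bmm(query, key) * kv ** -0.5          # (b * groups, sq * heads_per_group, sk)
context = torch.bmm(scores.softmax(-1), value)       # (b * groups, sq * heads_per_group, kv)
context = context.view(b, groups, sq, heads_per_group, kv).transpose(1, 2).flatten(2)
assert context.shape == (b, sq, groups * heads_per_group * kv)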
diff --git a/fast_llm/layers/transformer/block.py b/fast_llm/layers/transformer/block.py index 89d7a2e3b..a5aad45a9 100644 --- a/fast_llm/layers/transformer/block.py +++ b/fast_llm/layers/transformer/block.py @@ -1,7 +1,6 @@ import logging import typing -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.block.block import Block, BlockLayer from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig @@ -10,13 +9,10 @@ class TransformerBlock[ConfigType: TransformerConfig](Block[ConfigType]): - _name = "Transformer layer" # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "self_attn" - _config: TransformerConfig - - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): - super().__init__(config, tensor_space, block_index, return_input) def _create_mixer(self) -> BlockLayer: - return Attention(self._config, self._tensor_space, self._block_index) + return Attention( + self._config, self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} attn" + ) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index f7c7fea9c..a40f676ca 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -5,10 +5,9 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig -from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockKwargs from fast_llm.layers.transformer.rotary.config import RotaryConfig from fast_llm.utils import Assert, div @@ -18,19 +17,6 @@ logger = logging.getLogger(__name__) -class AttentionDimNames(BlockDimNames): - # A set of common tensor dim names packed into a namespace. - # Self-attention dimensions - head_groups = "head_groups" - group_heads = "group_heads" - key_and_value = "key_value" - kv_channels = "kv_channels" - composite_heads = "composite_heads" - composite_query = "composite_query" - composite_key_value = "composite_key_value" - composite_dense = "composite_dense" - - class AttentionKwargs(BlockKwargs): rotary_freq_q = "rotary_freq_q" rotary_freq_k = "rotary_freq_k" @@ -180,36 +166,6 @@ def projection_size(self): def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - # Needed for multiple inheritance. 
- super().setup_tensor_space(tensor_space) # Noqa - - tensor_space.add_tensor_dim( - head_groups := TensorDim( - AttentionDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None - ) - ) - tensor_space.add_tensor_dim( - group_heads := TensorDim( - AttentionDimNames.group_heads, - div(self.num_attention_heads, self.head_groups), - None if self.head_groups > 1 else tensor, - ) - ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(AttentionDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(AttentionDimNames.kv_channels, self.kv_channels)) - tensor_space.add_tensor_dim(CompositeTensorDim(AttentionDimNames.composite_heads, (head_groups, group_heads))) - tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_query, (head_groups, group_heads, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_dense, (head_groups, group_heads, kv_channels)) - ) - @property def add_qkv_bias(self) -> bool: # TODO: Make this work without inheritance. diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 16e5811e6..769177668 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -4,7 +4,8 @@ import torch from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs from fast_llm.tensor import TensorMeta @@ -12,25 +13,18 @@ class BackupAttentionPreprocessor(Preprocessor): - _scalar_dim: TensorDim _kv_channels_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor _mask: torch.Tensor _mask_value: torch.Tensor _tensor_cache_max_sequence_length: int = -1 - def __init__( - self, - config: AttentionConfig, - tensor_space: TensorSpace, - ): + def __init__(self, config: AttentionConfig, distributed_config: DistributedConfig): self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config + self._distributed_config = distributed_config assert not self._config.do_use_flash_attention(self._distributed_config) - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - def _create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int, device: torch.device) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length @@ -38,7 +32,7 @@ def _create_tensors(self, sequence_length: int) -> None: self._mask = torch.ones( (sequence_length, sequence_length), dtype=torch.bool, - device=self._tensor_space.distributed.device, + device=device, ).tril_() if self._config.window_size is not None: @@ -47,11 +41,11 @@ def _create_tensors(self, sequence_length: int) -> None: [], torch.finfo(self._distributed_config.training_dtype.torch).min, dtype=self._distributed_config.training_dtype.torch, - device=self._tensor_space.distributed.device, + device=device, ) - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - 
self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[AttentionKwargs.sequence_length], batch.device) sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size kwargs[AttentionKwargs.attention_mask] = self._mask[ @@ -64,7 +58,7 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: for sample_lens in sequence_lengths ] ) - document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) + document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(batch.device) kwargs[AttentionKwargs.attention_mask] = ( kwargs[AttentionKwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] @@ -74,30 +68,29 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: kwargs[AttentionKwargs.attention_mask] = TensorMeta.from_dims( ( - self._scalar_dim, - self._scalar_dim, + scalar_dim, + scalar_dim, kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, + scalar_dim, kwargs[AttentionKwargs.sequence_k_dim], ), tensor_name=AttentionKwargs.attention_mask, dtype=torch.bool, ) kwargs[AttentionKwargs.attention_mask_value] = TensorMeta.from_dims( - (self._scalar_dim,), + (scalar_dim,), tensor_name=AttentionKwargs.attention_mask_value, - dtype=self._tensor_space.distributed_config.training_dtype.torch, + dtype=self._distributed_config.training_dtype.torch, ) class FlashAttnVarlenPreprocessor(Preprocessor): - def __init__(self, config: AttentionConfig, tensor_space: TensorSpace): + def __init__(self, config: AttentionConfig, distributed_config: DistributedConfig): self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config + self._distributed_config = distributed_config assert self._config.do_use_flash_attention(self._distributed_config) - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: """ Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 @@ -148,14 +141,14 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: seqlens_k = torch.cat(sequence_lengths) kwargs[AttentionKwargs.cu_seqlens_q] = torch.cat( ( - torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), - torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), + torch.zeros(1, dtype=torch.int32, device=batch.device), + torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(batch.device), ) ) kwargs[AttentionKwargs.cu_seqlens_k] = torch.cat( ( - torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), - torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), + torch.zeros(1, dtype=torch.int32, device=batch.device), + torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(batch.device), ) ) kwargs[AttentionKwargs.max_seqlen_q] = seqlens_q.max() diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/transformer/rotary/config.py index 748f2af28..f0e0079c7 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/transformer/rotary/config.py @@ -5,7 +5,7 
@@ from fast_llm.config import Field, FieldHint, config_class from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.utils import Assert @@ -29,8 +29,8 @@ def _from_dict( return NoRotaryConfig._from_dict(default, strict, flat) return super()._from_dict(default, strict=strict, flat=flat) - def build(self, tensor_space: TensorSpace | None = None) -> "Rotary": - return self._get_configurable_class()(self, tensor_space) + def build(self, kv_channels_dim: TensorDim) -> "Rotary": + return self._get_configurable_class()(self, kv_channels_dim) @classmethod @abc.abstractmethod diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py deleted file mode 100644 index 9f8732f85..000000000 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ /dev/null @@ -1,68 +0,0 @@ -import typing - -import torch - -from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig -from fast_llm.tensor import TensorMeta - - -class RotaryEmbeddingPreprocessor(Preprocessor): - _scalar_dim: TensorDim - _kv_channels_dim: TensorDim - _rotary_embedding_frequencies: torch.Tensor - _mask: torch.Tensor - _mask_value: torch.Tensor - _tensor_cache_max_sequence_length: int = -1 - - def __init__( - self, - config: DefaultRotaryConfig, - tensor_space: TensorSpace, - ): - self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] - - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[AttentionKwargs.sequence_length]) - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k - ] - kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( - ( - self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, - self._kv_channels_dim, - ), - tensor_name=AttentionKwargs.rotary_freq_q, - ) - kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( - ( - self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, - self._kv_channels_dim, - ), - tensor_name=AttentionKwargs.rotary_freq_k, - ) - - def _create_tensors(self, sequence_length: int) -> None: - if sequence_length <= self._tensor_cache_max_sequence_length: - return - self._tensor_cache_max_sequence_length = sequence_length - - self._rotary_embedding_frequencies = self._config.get_frequencies( - sequence_length, - self._kv_channels_dim.global_size, - device=self._tensor_space.distributed.device, - ) diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index ebb629aa1..bbf8b524a 100644 --- 
a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -6,9 +6,9 @@ from fast_llm.config import Configurable from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs +from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.layers.transformer.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, @@ -41,14 +41,14 @@ def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor return torch.view_as_real(complex_tensor * rope_frequencies).view_as(tensor).type_as(tensor) -class Rotary[ConfigType: RotaryConfig](Configurable[RotaryConfig], torch.nn.Module, Preprocessor): +class Rotary[ConfigType: RotaryConfig](Configurable[ConfigType], torch.nn.Module, Preprocessor): def __init__( self, config: ConfigType, - # The tensor space is only needed for preprocessing, so we make it optional. - tensor_space: TensorSpace | None = None, + kv_channels_dim: TensorDim, ): super().__init__(config) + self._kv_channels_dim = kv_channels_dim @abc.abstractmethod def forward( @@ -57,7 +57,7 @@ def forward( pass -class NoRotary[ConfigType: NoRotaryConfig](Rotary[NoRotaryConfig]): +class NoRotary[ConfigType: NoRotaryConfig](Rotary[ConfigType]): def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: @@ -70,24 +70,12 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: pass -class DefaultRotary[ConfigType: DefaultRotaryConfig](Rotary[DefaultRotaryConfig]): +class DefaultRotary[ConfigType: DefaultRotaryConfig](Rotary[ConfigType]): _rotary_embedding_frequencies: torch.Tensor _tensor_cache_max_sequence_length: int = -1 - def __init__( - self, - config: ConfigType, - tensor_space: TensorSpace | None = None, - ): - super().__init__(config, tensor_space) - self._tensor_space = tensor_space - if self._tensor_space is not None: - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] - - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - assert self._tensor_space is not None - self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[AttentionKwargs.sequence_length], batch.device) sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k @@ -95,21 +83,20 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - assert self._tensor_space is not None kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( ( - self._scalar_dim, + scalar_dim, kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, + scalar_dim, self._kv_channels_dim, ), tensor_name=AttentionKwargs.rotary_freq_q, ) kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( ( - self._scalar_dim, + scalar_dim, 
kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, + scalar_dim, self._kv_channels_dim, ), tensor_name=AttentionKwargs.rotary_freq_k, @@ -123,7 +110,7 @@ def forward( key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key - def _create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int, device: torch.device) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length @@ -131,10 +118,10 @@ def _create_tensors(self, sequence_length: int) -> None: self._rotary_embedding_frequencies = self._get_frequencies( sequence_length, self._kv_channels_dim.global_size, - device=self._tensor_space.distributed.device, + device=device, ) - def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> torch.Tensor: + def _get_frequencies(self, sequence_length: int, kv_channels: int, device: torch.device) -> torch.Tensor: # Calculate the complex frequencies (https://blog.eleuther.ai/rotary-embeddings/) # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, # `a = theta ** - (2 * (channel // 2) / kv_channels)`, @@ -149,12 +136,12 @@ def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda" ).contiguous() return frequencies - def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: + def _get_angle_scales(self, kv_channels: int, device: torch.device) -> torch.Tensor: return self._config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) -class Llama3Rotary[ConfigType: Llama3RotaryConfig](DefaultRotary[Llama3RotaryConfig]): - def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: +class Llama3Rotary[ConfigType: Llama3RotaryConfig](DefaultRotary[ConfigType]): + def _get_angle_scales(self, kv_channels: int, device: torch.device) -> torch.Tensor: scales = super()._get_angle_scales(kv_channels, device) low_frequency_wavelength = self._config.original_context_length / self._config.low_frequency_factor high_frequency_wavelength = self._config.original_context_length / self._config.high_frequency_factor @@ -173,17 +160,17 @@ def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: return torch.stack(new_scales) -class YarnRotary[ConfigType: YarnRotaryConfig](DefaultRotary[YarnRotaryConfig]): +class YarnRotary[ConfigType: YarnRotaryConfig](DefaultRotary[ConfigType]): """ Yarn scaling: https://github.com/huggingface/transformers/blob/006d9249ec0270ff6c4d3840979d23fe94bdc763/src/transformers/modeling_rope_utils.py#L163 [original paper](https://arxiv.org/abs/2309.00071) """ - def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> torch.Tensor: + def _get_frequencies(self, sequence_length: int, kv_channels: int, device: torch.device) -> torch.Tensor: return super()._get_frequencies(sequence_length, kv_channels, device) * self._config.attention_factor - def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: + def _get_angle_scales(self, kv_channels: int, device: torch.device) -> torch.Tensor: scales = super()._get_angle_scales(kv_channels, device) # TODO: max_position_embeddings or original_context_length? 
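# Editor's sketch (illustration, not part of the patch): a standalone version of the rotary
# frequency computation referenced in `_get_frequencies` / `_get_angle_scales` above, i.e.
# a = theta ** -(2 * (channel // 2) / kv_channels) and exp(i * n * a). The real layer's
# broadcasting/layout and the Llama3/Yarn rescaling of `a` are omitted; theta is an example.
import torch

def rotary_frequencies(sequence_length: int, kv_channels: int, theta: float = 10000.0) -> torch.Tensor:
    angle_scales = theta ** -torch.arange(0, 1, 2 / kv_channels, dtype=torch.float64)
    angles = torch.outer(torch.arange(sequence_length, dtype=torch.float64), angle_scales)
    return torch.polar(torch.ones_like(angles), angles)   # complex exp(i * n * a)

freqs = rotary_frequencies(8, 16)   # (sequence_length, kv_channels // 2) complex frequencies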
# see https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L304 diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 6d555a0bb..024d7d79c 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -14,7 +14,6 @@ if typing.TYPE_CHECKING: from fast_llm.core.distributed import ProcessGroup - from fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) @@ -254,7 +253,6 @@ def log_distributed_tensor[ scale: float = 1.0, level: int = 2, storage: bool = False, - distributed: "Distributed", duplicate_groups: tuple[typing.Optional["ProcessGroup"], ...] = (), global_: bool = True, log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info, @@ -263,7 +261,7 @@ def log_distributed_tensor[ if level <= 0: return if global_: - tensor, is_first_rank = meta.local_to_global(tensor, distributed=distributed) + tensor, is_first_rank = meta.local_to_global(tensor) storage = False is_first_rank = is_first_rank and all(group.rank() == 0 for group in duplicate_groups if group) if not is_first_rank: @@ -289,7 +287,6 @@ def log_distributed_grad[ scale: float = 1.0, level: int = 2, storage: bool = False, - distributed: "Distributed", duplicate_groups: tuple[typing.Optional["ProcessGroup"], ...] = (), grad_fn: typing.Callable[[torch.Tensor], torch.Tensor] | None = None, global_: bool = True, @@ -305,7 +302,6 @@ def log_distributed_grad[ scale=scale, level=level, storage=storage, - distributed=distributed, duplicate_groups=duplicate_groups, global_=global_, log_fn=log_fn, diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index 98937bdb1..3afd88ce1 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -3,11 +3,9 @@ import torch from fast_llm.data.data.gpt.data import GPTBatch -from fast_llm.engine.base_model.base_model import Layer, LossDef +from fast_llm.engine.base_model.base_model import LossDef from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig -from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.models.custom.config import CustomBaseModelConfig from fast_llm.models.custom.head import CustomHead from fast_llm.models.gpt.model import GPTBaseModel, GPTModel @@ -17,26 +15,21 @@ class CustomBaseModel[ConfigType: CustomBaseModelConfig](GPTBaseModel[ConfigType]): def __init__( self, - config: CustomBaseModelConfig, + config: ConfigType, distributed_config: DistributedConfig, ): # TODO: Implement / update. super().__init__(config, distributed_config) - def get_layers(self) -> list[Layer]: - # TODO: Adjust as needed. 
- return [ - LanguageModelEmbedding(self._config, self._tensor_space), - *[ - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=i + 1, - ) - for i in range(self._config.transformer.num_layers) - ], - CustomHead(self._config, self._tensor_space), - ] + def _get_head(self, prediction_distance): + return CustomHead( + self._config, + self._distributed_config, + self._hidden_dim, + max(self._config.transformer.num_layers + prediction_distance, 1), + f"Language model head {prediction_distance}", + prediction_distance=prediction_distance, + ) def preprocess_meta( self, batch_meta: BatchConfig | torch.Tensor, phase: PhaseType diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 47df8ba1c..41e0d607d 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -6,17 +6,18 @@ from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.base_model.base_model import BaseModel, Layer, LossDef from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.block.mlp.config import MLPLossNames, RoutingType from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor from fast_llm.layers.transformer.block import TransformerBlock -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs +from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -36,6 +37,7 @@ def __init__( config: GPTBaseModelConfig, distributed_config: DistributedConfig, ): + self._hidden_dim = TensorDim("hidden", config.transformer.hidden_size) super().__init__(config, distributed_config) self._use_flash_attention = self._config.transformer.do_use_flash_attention(distributed_config) if self._config.use_megatron_initialization: @@ -45,59 +47,81 @@ def __init__( # `self._reference_models` is not populated at this point, so we pass a mutable dict. self._preprocessors: list[Preprocessor] = [] if self._config.use_absolute_position_embeddings: - self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._tensor_space)) + self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._distributed_config)) # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. # TODO: Find a better solution. 
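# A minimal sketch of the rotary wiring after the TensorSpace removal, assuming 64 kv channels
# and CPU tensors: the preprocessor is built from a plain TensorDim (the `build` method is
# renamed to `get_layer` later in this series) and takes its device from the batch tensor.
import torch

from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.layers.block.config import BlockDimNames
from fast_llm.layers.transformer.config import AttentionKwargs
from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig

rotary = DefaultRotaryConfig(triton=False).build(TensorDim("kv_channels", 64))
kwargs = {
    AttentionKwargs.sequence_length: 32,
    AttentionKwargs.sequence_q_dim: TensorDim(BlockDimNames.sequence_q, 32),
    AttentionKwargs.sequence_k_dim: TensorDim(BlockDimNames.sequence_k, 32),
}
rotary.preprocess(torch.empty(1, device="cpu"), kwargs)  # frequencies are cached on the batch device
freq_q = kwargs[AttentionKwargs.rotary_freq_q]  # per-position rotary frequencies for the query slice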
- self._preprocessors.append(self._config.transformer.rotary.build(self._tensor_space)) + self._preprocessors.append( + self._config.transformer.rotary.build(TensorDim("kv_channels", self._config.transformer.kv_channels)) + ) if self._use_flash_attention: - self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)) + self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._distributed_config)) else: - self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) + self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._distributed_config)) if self._config.enable_dpo: # TODO better way to pass in? - self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) + self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._distributed_config)) - def get_output_layers(self) -> list[Layer]: + def _get_output_layers(self) -> list[Layer]: layers = [] for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerBlock( - self._config.transformer, - self._tensor_space, + self._get_block( # TODO MTP: which index? - block_index=max(self._config.transformer.num_layers + i, 1), + max(self._config.transformer.num_layers + i, 1), + f"MPT head {i} block", # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. - return_input=i < self._config.prediction_heads - 1, + i < self._config.prediction_heads - 1, ) ) - layers.append( - LanguageModelHead( - self._config, - self._tensor_space, - prediction_distance=i, - ) - ) + layers.append(self._get_head(i)) return layers def get_layers(self) -> list[Layer]: return [ - LanguageModelEmbedding(self._config, self._tensor_space), + self._get_embeddings(), *[ - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=i + 1, + self._get_block( + i + 1, + f"Block {i + 1}", # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. 
- return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, + self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, ) for i in range(self._config.transformer.num_layers) ], - *self.get_output_layers(), + *self._get_output_layers(), ] + def _get_block( + self, + block_index: int, + name: str, + return_input: bool = False, + ): + return TransformerBlock( + self._config.transformer, + self._distributed_config, + self._hidden_dim, + block_index, + name, + return_input, + ) + + def _get_embeddings(self): + return LanguageModelEmbedding(self._config, self._distributed_config, self._hidden_dim, 0, "Embeddings") + + def _get_head(self, prediction_distance): + return LanguageModelHead( + self._config, + self._distributed_config, + self._hidden_dim, + max(self._config.transformer.num_layers + prediction_distance, 1), + f"Language model head {prediction_distance}", + prediction_distance=prediction_distance, + ) + def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: @@ -116,8 +140,8 @@ def preprocess_meta( micro_sequence_length = sequence_length truncate_documents = True - batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) - batch_dim = TensorDim(AttentionDimNames.batch, micro_batch_size * batch_data.size, batch_data) + batch_data = self._distributed_config.get_distributed_dim(DistributedDimNames.batch_data) + batch_dim = TensorDim(BlockDimNames.batch, micro_batch_size * batch_data.size, batch_data) if micro_sequence_length is None: micro_sequence_length = sequence_length @@ -126,19 +150,17 @@ def preprocess_meta( # TODO: Calculate hidden dims elsewhere? sequence_q_dim = TensorDim( - AttentionDimNames.sequence_q, + BlockDimNames.sequence_q, micro_sequence_length, - self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), + self._distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), ) hidden_sequence_q_dim = ( TensorDim( - AttentionDimNames.sequence_q_tp, + BlockDimNames.sequence_q_tp, micro_sequence_length, - self._tensor_space.distributed_config.get_distributed_dim( - DistributedDimNames.tensor_and_sequence_data - ), + self._distributed_config.get_distributed_dim(DistributedDimNames.tensor_and_sequence_data), ) - if self._tensor_space.distributed_config.sequence_tensor_parallel + if self._distributed_config.sequence_tensor_parallel else sequence_q_dim ) @@ -149,11 +171,10 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space[AttentionDimNames.hidden] hidden_dims = ( - (hidden_sequence_q_dim, batch_dim, hidden_dim) + (hidden_sequence_q_dim, batch_dim, self._hidden_dim) if sequence_first - else (batch_dim, hidden_sequence_q_dim, hidden_dim) + else (batch_dim, hidden_sequence_q_dim, self._hidden_dim) ) common_kwargs = { @@ -166,7 +187,7 @@ def preprocess_meta( } sequence_k_pasts = range( - sequence_q_dim.size * self._tensor_space.distributed_config.sequence_data_rank, + sequence_q_dim.size * self._distributed_config.sequence_data_rank, sequence_length, micro_sequence_length, ) @@ -180,7 +201,7 @@ def preprocess_meta( preprocessed_meta = [] for i, sequence_k_past in enumerate(sequence_k_pasts): sequence_k = sequence_k_past + sequence_q_dim.size - sequence_k_dim = TensorDim(AttentionDimNames.sequence_k, sequence_k) + sequence_k_dim = 
TensorDim(BlockDimNames.sequence_k, sequence_k) tokens = TensorMeta.from_dims( hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 @@ -234,7 +255,7 @@ def preprocess( prediction_heads: int = self._config.prediction_heads batch.token_ids = batch.token_ids.to( - device=self._tensor_space.distributed.device, + device=self._distributed.device, dtype=torch.int64, non_blocking=True, ) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 866de962f..9d54675be 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -6,7 +6,6 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointHandler from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig @@ -47,14 +46,6 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): # TODO: Support combination of different SSM block types. ssm_block_type: SSMBlockType | None = Field(init=False) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - """ - Setup the tensor space for the model. - """ - super().setup_tensor_space(tensor_space) - if self.ssm_block_type is not None: - self.ssm.setup_tensor_space(tensor_space, self.ssm_block_type) - def _validate(self): with self._set_implicit_default(None): if self.ssm.dt_rank == "auto" or self.ssm.dt_rank is None: diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 32fbdad9b..7c67d7355 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -1,10 +1,6 @@ import logging import typing -from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.layers.ssm.block import SSMBlock from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel @@ -20,88 +16,39 @@ class HybridSSMBaseModel[ConfigType: HybridSSMBaseModelConfig](GPTBaseModel[Conf As for the mixer, transformer uses MHA. For the LlambaBlock we support Mamba1 and discrete mamba2. """ - _is_setup: bool = False - - def __init__( + def _get_block( self, - config: HybridSSMBaseModelConfig, - distributed_config: DistributedConfig, + block_index: int, + name: str, + return_input: bool = False, ): - super().__init__(config, distributed_config) - - def get_output_layers(self) -> list[Layer]: - """ - Get the output layers of the model. - This includes the language model head and any additional heads specified in the configuration. 
- """ - layers: list[Layer] = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] - - if self._config.prediction_heads > 1: + if block_index > self._config.transformer.num_layers: + # MTP block block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] - for i in range(1, self._config.prediction_heads): - if block_type == SSMBlockType.transformer: - layers.append( - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=len(self._config.hybrid_block_layout), - return_input=i != self._config.prediction_heads - 1, - ) - ) - else: - layers.append( - SSMBlock( - config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=self._config.ssm_block_type.get_mixer_class(), - block_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - ) - layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) - - return layers - - def get_layers(self) -> list[Layer]: - """ - Create a list of layers for the model, interleaving Transformer and Mamba blocks - according to the block pattern. - """ - layers: list[Layer] = [LanguageModelEmbedding(self._config, self._tensor_space)] - - # Create blocks according to pattern - for i, block_type in enumerate(self._config.hybrid_block_layout): - if block_type == SSMBlockType.transformer: - # Transformer block - layers.append( - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=i + 1, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - ) - else: - layers.append( - SSMBlock( - config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=self._config.ssm_block_type.get_mixer_class(), - block_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - ) - - # Add the output layers - layers += self.get_output_layers() - - return layers + else: + # Decoder block + block_type = self._config.hybrid_block_layout[block_index - 1] + + if block_type == SSMBlockType.transformer: + return TransformerBlock( + self._config.transformer, + self._distributed_config, + self._hidden_dim, + block_index, + name, + return_input, + ) + else: + return SSMBlock( + self._config.transformer, + self._config.ssm, + self._distributed_config, + self._hidden_dim, + self._config.ssm_block_type.get_mixer_class(), + block_index, + name, + return_input, + ) class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]): diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b12d12072..b6180c190 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -7,7 +7,7 @@ from fast_llm.core.distributed import ReduceOp from fast_llm.core.ops import reduce_op from fast_llm.engine.config_utils.initialization import Initializer, LambdaInitializer -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.functional.triton.pointwise import triton_add, triton_copy @@ -138,30 +138,11 @@ def from_dims( **kwargs, ) - @classmethod - def from_tensor_space( - cls, - dim_names: tuple[str, ...], - tensor_space: TensorSpace, - *, - tensor_name: str = "", - 
dtype: torch.dtype = torch.float32, - reductions: tuple[tuple[str, ReduceOp], ...] = (), - **kwargs: typing.Any, - ) -> typing.Self: - dims = tuple(tensor_space[dim_name] for dim_name in dim_names) - if reductions: - # kwarg not available for ParameterMeta, so we only provide if necessary. - kwargs["reductions"] = tuple( - (tensor_space.distributed_config.get_distributed_dim(name), op) for name, op in reductions - ) - return cls.from_dims(dims, tensor_name=tensor_name, dtype=dtype, **kwargs) - @property def global_shape(self) -> torch.Size: return torch.Size([dim.global_size for dim in self.dims]) - def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> tuple[torch.Tensor, ...]: + def local_to_global(self, tensor: torch.Tensor) -> tuple[torch.Tensor, ...]: """ Reconstruct a global tensor from its distributed slices. Support lazy-loaded safetensor slices. Returns a view of the input tensor (or the input tensor itself) when possible. @@ -171,7 +152,7 @@ def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> Assert.eq(tensor.shape, self.shape) # Tensors are always either split or duplicated in the tensor-parallel direction. # TODO: Avoid hard-coded assumptions on duplication - is_first_rank, modified = distributed.config.tensor_rank == 0, False + is_first_rank, modified = True, False for dim, tensor_dim in enumerate(self.dims): if tensor_dim.is_parallel: diff --git a/tests/functional/test_triton_kernels.py b/tests/functional/test_triton_kernels.py index e61f72244..e4ad937b7 100644 --- a/tests/functional/test_triton_kernels.py +++ b/tests/functional/test_triton_kernels.py @@ -92,7 +92,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y1 = apply_rotary_embeddings( x, DefaultRotaryConfig(triton=False) - .build() + .build(None) ._get_frequencies( sequence_length, kv_channels, @@ -103,7 +103,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y2 = convert_rotary_real_to_complex( triton_rotary_( convert_rotary_complex_to_real(x, kv_channels, 3), - DefaultRotaryConfig(triton=True).build()._get_frequencies(sequence_length, kv_channels, device="cuda"), + DefaultRotaryConfig(triton=True).build(None)._get_frequencies(sequence_length, kv_channels, device="cuda"), ), kv_channels, 3, diff --git a/tests/test_attention.py b/tests/test_attention.py index 534e3800e..7d05e0a66 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -2,11 +2,12 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig +from fast_llm.layers.transformer.config import AttentionKwargs, TransformerConfig from fast_llm.layers.transformer.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.utils import Assert @@ -30,19 +31,6 @@ def test_decide_window_size(): assert attention._decide_window_size() == 512 -def test_attention_constructor(): - transformer_conf = TransformerConfig( - num_layers=2, - num_attention_heads=2, - hidden_size=16, - ) - distributed_config = DistributedConfig() - tensor_space = TensorSpace(distributed_config=distributed_config) - 
transformer_conf.setup_tensor_space(tensor_space) - - Attention(transformer_conf, tensor_space, 1) - - def test_varlen_preprocessor(): sequence_lengths = [torch.tensor([8, 13, 4, 11], dtype=torch.int32), torch.tensor([11, 16, 9], dtype=torch.int32)] # First micro-sequence: @@ -63,27 +51,24 @@ def test_varlen_preprocessor(): ] micro_sequence_length = 12 sequence_length = 36 - transformer_cfg = TransformerConfig( + transformer_config = TransformerConfig( num_layers=2, num_attention_heads=2, hidden_size=16, use_flash_attention=True, ) - distributed_cfg = DistributedConfig(training_dtype="bfloat16") - distributed = Distributed(distributed_cfg, use_cpu=True) - tensor_space = TensorSpace(distributed_config=distributed_cfg) - tensor_space.setup(distributed) - transformer_cfg.setup_tensor_space(tensor_space) - varlen_preprocessor = FlashAttnVarlenPreprocessor(transformer_cfg, tensor_space=tensor_space) + distributed_config = DistributedConfig(training_dtype="bfloat16") + distributed = Distributed(distributed_config, use_cpu=True) + varlen_preprocessor = FlashAttnVarlenPreprocessor(transformer_config, distributed_config=distributed_config) for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): kwargs = { - AttentionKwargs.sequence_q_dim: TensorDim(AttentionDimNames.sequence_k, micro_sequence_length), + AttentionKwargs.sequence_q_dim: TensorDim(BlockDimNames.sequence_k, micro_sequence_length), AttentionKwargs.sequence_k_dim: TensorDim( - AttentionDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length + BlockDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length ), AttentionKwargs.sequence_length: sequence_length, AttentionKwargs.sequence_lengths: sequence_lengths, } - varlen_preprocessor.preprocess(None, kwargs) + varlen_preprocessor.preprocess(torch.empty(1, device="cpu"), kwargs) Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) diff --git a/tests/test_mlp.py b/tests/test_mlp.py deleted file mode 100644 index 802833eb2..000000000 --- a/tests/test_mlp.py +++ /dev/null @@ -1,29 +0,0 @@ -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP -from fast_llm.layers.block.mlp.mlp import MLP -from fast_llm.layers.transformer.config import TransformerConfig - - -def test_mlp_constructor(): - transformer_conf = TransformerConfig( - num_layers=2, - num_attention_heads=2, - hidden_size=16, - ) - distributed_config = DistributedConfig() - tensor_space = TensorSpace(distributed_config=distributed_config) - transformer_conf.setup_tensor_space(tensor_space) - - MLP(transformer_conf, tensor_space, 0, "name") - - -def test_moe_mlp_constructor(): - transformer_conf = TransformerConfig( - num_layers=2, num_attention_heads=2, hidden_size=16, num_experts=2, add_linear_biases=False - ) - distributed_config = DistributedConfig() - tensor_space = TensorSpace(distributed_config=distributed_config) - transformer_conf.setup_tensor_space(tensor_space) - - MixtureOfExpertMLP(transformer_conf, tensor_space, 0, "name") diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py index 80232bf53..42e588911 100644 --- a/tests/utils/global_variables.py +++ b/tests/utils/global_variables.py @@ -29,8 +29,8 @@ def set_testing_global_variables(): num_gpus = len(gpus) gpus = [gpus[(i + 
worker_id) % num_gpus] for i in range(num_gpus)] os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpus) - os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(SHARED_RESULT_PATH / "torchinductor_cache") - os.environ["TRITON_CACHE_DIR"] = str(SHARED_RESULT_PATH / "triton_cache") + # os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(SHARED_RESULT_PATH / "torchinductor_cache") + # os.environ["TRITON_CACHE_DIR"] = str(SHARED_RESULT_PATH / "triton_cache") # TODO: Fixtures diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 88303a0f4..0dc3462eb 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -13,7 +13,6 @@ from fast_llm.core.distributed import ProcessGroup, allreduce_scalar, safe_barrier from fast_llm.engine.base_model.base_model import BaseModel, Layer from fast_llm.engine.config_utils.logging import configure_logging -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig from fast_llm.engine.multi_stage.stage import Stage @@ -33,12 +32,8 @@ def result_path(): def get_base_model(config: FastLLMModelConfig): # Create a base model (and distributed). # Using a full model config so we have the model type and distributed config in the same argument. - distributed = Distributed(config.distributed) - tensor_space = TensorSpace(config.distributed) - config.base_model.setup_tensor_space(tensor_space) - tensor_space.setup(distributed) base_model = config.get_model_class().base_model_class(config.base_model, config.distributed) - base_model.setup(distributed) + base_model.setup(distributed := Distributed(config.distributed)) return base_model, distributed From 797bd73befdb20c641b624b147a87971df266d55 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 Aug 2025 19:49:12 -0400 Subject: [PATCH 49/82] stuff --- fast_llm/engine/multi_stage/fast_llm_model.py | 1 - fast_llm/layers/block/block.py | 15 +- fast_llm/layers/block/config.py | 41 ++++- fast_llm/layers/block/mlp/config.py | 23 ++- .../layers/block/mlp/mixture_of_experts.py | 19 ++- fast_llm/layers/block/mlp/mlp.py | 25 ++- fast_llm/layers/common/config.py | 137 +---------------- .../layers/common/normalization/__init__.py | 0 .../layers/common/normalization/config.py | 142 ++++++++++++++++++ .../{ => normalization}/normalization.py | 2 +- fast_llm/models/gpt/conversion.py | 2 +- fast_llm/models/ssm/conversion.py | 2 +- 12 files changed, 235 insertions(+), 174 deletions(-) create mode 100644 fast_llm/layers/common/normalization/__init__.py create mode 100644 fast_llm/layers/common/normalization/config.py rename fast_llm/layers/common/{ => normalization}/normalization.py (99%) diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index da4fe527e..09ee788e6 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -51,7 +51,6 @@ def from_pretrained( use_cpu: bool = False, stage_filter: set | None = None, ) -> typing.Self: - print("IUGRGHIOERIO", cls, cls.config_class) metadata = cls.config_class.load_metadata(pretrained_config) config = cls.config_class.from_dict(metadata.config, *updates, update_type=UpdateType.update) if mode.support_training: diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 425731eb9..09370e3af 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -12,7 +12,7 @@ from 
fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockLayerConfig from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -94,27 +94,27 @@ class BlockLayerBase[ConfigType: Config](Configurable[ConfigType], Module): def __init__( self, config: ConfigType, + block_config: BlockConfig, distributed_config: DistributedConfig, # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, block_index: int, name: str, - debug_level: int, - debug_memory: bool, ): super().__init__(config, distributed_config) + self._block_config = block_config self._hidden_dim = hidden_dim self._block_index = block_index self._name = name self._sequence_parallel: bool = self._distributed_config.sequence_tensor_parallel self._debug = DebugLayer( self._name, - debug_level, - debug_memory, + self._block_config.debug_transformer, + self._block_config.debug_transformer_memory, ) -class BlockLayer[ConfigType: Config](BlockLayerBase[ConfigType]): +class BlockLayer[ConfigType: BlockLayerConfig](BlockLayerBase[ConfigType]): """ Base class for mixer and MLP modules. """ @@ -148,13 +148,12 @@ def __init__( return_input: bool = False, ): super().__init__( + config, config, distributed_config, hidden_dim, block_index, name, - config.debug_transformer, - config.debug_transformer_memory, ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 0da7a0c99..3df82e24e 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,12 +1,20 @@ +import abc import enum +import functools +import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerPeftConfig -from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.layers.block.block import BlockLayer + class BlockDimNames: # A set of common tensor dim names packed into a namespace. @@ -38,6 +46,37 @@ class AddLinearBiasChoices(str, enum.Enum): only_attn_qkv = "only_attn_qkv" +@config_class() +class BlockLayerConfig(BaseModelConfig): + """ + A common class for mixers and mlps, which have the exact same interface. 
+ """ + + _abstract = True + + @functools.cached_property + @abc.abstractmethod + def layer_class(self) -> "type[BlockLayer]": + raise NotImplementedError() + + def get_layer( + self, + block_config: "BlockConfig", + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ) -> "BlockLayer": + return self.layer_class( + self, + block_config, + distributed_config, + hidden_dim, + block_index, + name, + ) + + @config_class() # TODO: Use composition instead class BlockConfig(MLPConfig, BaseModelConfig): diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 57f7a9e03..83e45f002 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -1,9 +1,15 @@ import enum +import functools +import typing -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.functional.config import ActivationType, MLPRecomputeLevel +from fast_llm.layers.block.config import BlockLayerConfig from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.layers.block.mlp.mlp import MLPBase + class MLPLossNames: load_balancing_loss = "load_balancing_loss" @@ -16,8 +22,8 @@ class RoutingType(str, enum.Enum): @config_class() -class MLPConfig(Config): - # TODO: Review names +class MLPConfig(BlockLayerConfig): + # TODO: Review names # TODO: Separate MoE? _abstract = False ffn_hidden_size: int = Field( default=None, @@ -150,6 +156,17 @@ def add_mlp_bias(self) -> bool: return True return False + @functools.cached_property + def layer_class(self) -> "type[MLPBase]": + if self.num_experts > 1: + from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP + + return MixtureOfExpertMLP + else: + from fast_llm.layers.block.mlp.mlp import MLP + + return MLP + def _validate(self) -> None: with self._set_implicit_default(): if self.activation_type is None: diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 0bc531dad..2a234ca94 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -10,7 +10,7 @@ from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.block.mlp.config import MLPLossNames, RoutingType +from fast_llm.layers.block.mlp.config import MLPConfig, MLPLossNames, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): +class MixtureOfExpertMLP[ConfigType: MLPConfig](MLPBase[ConfigType]): """ MoeLayer following implementation from https://github.com/NVIDIA/Megatron-LM/blob/46ebc0e4202c980d98900000d455f754a7ff9d4b/megatron/model/transformer.py#L346 @@ -36,6 +36,7 @@ class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): def __init__( self, config: ConfigType, + block_config: BlockConfig, distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, @@ -44,19 +45,21 @@ def __init__( Assert.gt(config.num_experts, 1) # TODO: Implement? 
assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, distributed_config, hidden_dim, block_index, name) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name) - layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None + layer_lr_scale = ( + self._block_config.per_layer_lr_scale[block_index] if self._block_config.per_layer_lr_scale else None + ) router_lr_scale = get_lr_scale(self._config.router_lr_scale, layer_lr_scale) self.router = Linear( - hidden_dim, + self._hidden_dim, TensorDim("router_experts", self._config.num_unshared_experts), bias=False, weight_init_method=init_normal_( - std=self._config.init_method_std, - min_val=self._config.init_method_min, - max_val=self._config.init_method_max, + std=self._block_config.init_method_std, + min_val=self._block_config.init_method_min, + max_val=self._block_config.init_method_max, ), lr_scale=router_lr_scale, ) diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index dc5178479..fd64713d1 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -9,29 +9,23 @@ from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig +from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase from fast_llm.utils import Assert, get_lr_scale -class MLPBase[ConfigType: BlockConfig](BlockLayer[ConfigType]): +class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): def __init__( self, config: ConfigType, + block_config: BlockConfig, distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, name: str, ): - super().__init__( - config, - distributed_config, - hidden_dim, - block_index, - name, - config.debug_transformer, - config.debug_transformer_memory, - ) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, intermediate_2_dim = self._get_intermediate_dims() @@ -48,7 +42,9 @@ def __init__( self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None + layer_lr_scale = ( + self._block_config.per_layer_lr_scale[block_index] if self._block_config.per_layer_lr_scale else None + ) lr_scale = ( tuple(self._config.mlp_lr_scale) if isinstance(self._config.mlp_lr_scale, list) @@ -77,8 +73,8 @@ def __init__( ) # PEFT. 
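# With TensorSpace gone, a constructor test along the lines of the deleted tests/test_mlp.py
# would pass the owning block config and an explicit hidden TensorDim directly. A sketch for
# this commit's 6-argument signature (the same TransformerConfig is assumed to serve as both
# the MLP config and the block config, since it derives from BlockConfig):
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.layers.block.mlp.mlp import MLP
from fast_llm.layers.transformer.config import TransformerConfig

transformer_config = TransformerConfig(num_layers=2, num_attention_heads=2, hidden_size=16)
MLP(
    transformer_config,  # MLP config
    transformer_config,  # owning block config (debug flags, per-layer lr scales, peft)
    DistributedConfig(),
    TensorDim("hidden", transformer_config.hidden_size),
    0,  # block_index
    "name",
)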
- self.layer_1 = self._config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) - self.layer_2 = self._config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) + self.layer_1 = self._block_config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) + self.layer_2 = self._block_config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) def _get_intermediate_dims(self): intermediate_2_dim = TensorDim("intermediate", self._config.ffn_hidden_size, self._parallel_dim) @@ -94,13 +90,14 @@ class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): def __init__( self, config: ConfigType, + block_config: BlockConfig, distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, name: str, ): Assert.eq(config.num_experts, 1) - super().__init__(config, distributed_config, hidden_dim, block_index, name) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name) def forward( self, diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index f56e2a2c1..b09672961 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -1,146 +1,11 @@ import abc -import enum import typing -from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.config import Field, FieldHint, config_class from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.utils import Assert if typing.TYPE_CHECKING: - import torch - - from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.layers.common.linear import LinearBase, LinearLike - from fast_llm.layers.common.normalization import LayerNorm, RMSNorm - - -class NormalizationImplementation(str, enum.Enum): - """ - An enum for the available implementations of layer norm. - """ - - auto = "auto" - torch = "torch" - fused = "fused" - fast = "fast" - triton = "triton" - - -@config_class(registry=True) -class NormalizationConfig(BaseModelConfig): - pass - - @abc.abstractmethod - def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": - pass - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is NormalizationConfig and cls.get_subclass(default.get("type")) is None: - # Default subclass. 
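# The normalization configs keep their dynamic registry after the move to
# fast_llm/layers/common/normalization/config.py: the "type" key selects the subclass and a
# missing "type" falls back to LayerNormalizationConfig. A minimal sketch, assuming a config
# dict loaded from a checkpoint or YAML:
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.layers.common.normalization.config import NormalizationConfig

norm_config = NormalizationConfig.from_dict({"type": "rms_norm", "epsilon": 1e-6})
norm = norm_config.get_layer(TensorDim("hidden", 1024))  # an RMSNorm module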
- return LayerNormalizationConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - - -@config_class(dynamic_type={NormalizationConfig: "none"}) -class NoNormalizationConfig(NormalizationConfig): - _abstract = False - - def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": - return torch.nn.Identity() - - -@config_class() -class LayerNormalizationBaseConfig(NormalizationConfig): - """ - Common configuration for layer norm and rms norm - """ - - # TODO: Rename to normalization_epsilon - epsilon: float = Field( - default=1e-5, - desc="Regularizer for the division.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - zero_centered: bool = Field( - default=False, - desc="Write the normalization weight as `w = 1 + w'`, to improve numerical accuracy when close to one.", - hint=FieldHint.architecture, - ) - implementation: NormalizationImplementation = Field( - default=NormalizationImplementation.auto, - desc="The implementation to use for the normalization layer.", - hint=FieldHint.performance, - ) - # TODO: Rename to normalization_init_range - initialization_range: float = Field( - default=0.0, - desc="Randomize the initialization with a uniform noise. Used to test for issues that may not be visible with the default initialization.", - hint=FieldHint.testing, - valid=check_field(Assert.geq, 0), - ) - - def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.engine.config_utils.initialization import init_uniform_centered_ - - kwargs = { - "hidden_dim": hidden_dim, - "eps": self.epsilon, - "implementation": self.implementation, - "zero_centered": self.zero_centered, - "lr_scale": lr_scale, - } - if self.initialization_range: - mean = 0 if self.zero_centered else 1 - kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) - return self.module_class(**kwargs) - - @property - @abc.abstractmethod - def module_class(self): - pass - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - cls._handle_renamed_field(default, "normalization_type", "type") - cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") - cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") - cls._handle_renamed_field(default, "normalization_implementation", "implementation") - cls._handle_renamed_field(default, "layer_norm_init_range", "initialization_range") - return super()._from_dict(default, strict, flat) - - -@config_class(dynamic_type={NormalizationConfig: "layer_norm"}) -class LayerNormalizationConfig(LayerNormalizationBaseConfig): - _abstract = False - - @property - def module_class(self): - from fast_llm.layers.common.normalization import LayerNorm - - return LayerNorm - - -@config_class(dynamic_type={NormalizationConfig: "rms_norm"}) -class RMSNormalizationConfig(LayerNormalizationBaseConfig): - _abstract = False - - @property - def module_class(self): - from fast_llm.layers.common.normalization import RMSNorm - - return RMSNorm @config_class() diff --git a/fast_llm/layers/common/normalization/__init__.py b/fast_llm/layers/common/normalization/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py new file mode 100644 index 000000000..658d00dfc --- /dev/null +++ b/fast_llm/layers/common/normalization/config.py 
@@ -0,0 +1,142 @@ +import abc +import enum +import typing + +from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + import torch + + from fast_llm.engine.config_utils.tensor_dim import TensorDim + from fast_llm.layers.common.normalization import LayerNorm, RMSNorm + + +class NormalizationImplementation(str, enum.Enum): + """ + An enum for the available implementations of layer norm. + """ + + auto = "auto" + torch = "torch" + fused = "fused" + fast = "fast" + triton = "triton" + + +@config_class(registry=True) +class NormalizationConfig(BaseModelConfig): + pass + + @abc.abstractmethod + def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": + pass + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is NormalizationConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. + return LayerNormalizationConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class(dynamic_type={NormalizationConfig: "none"}) +class NoNormalizationConfig(NormalizationConfig): + _abstract = False + + def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": + return torch.nn.Identity() + + +@config_class() +class LayerNormalizationBaseConfig(NormalizationConfig): + """ + Common configuration for layer norm and rms norm + """ + + # TODO: Rename to normalization_epsilon + epsilon: float = Field( + default=1e-5, + desc="Regularizer for the division.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + zero_centered: bool = Field( + default=False, + desc="Write the normalization weight as `w = 1 + w'`, to improve numerical accuracy when close to one.", + hint=FieldHint.architecture, + ) + implementation: NormalizationImplementation = Field( + default=NormalizationImplementation.auto, + desc="The implementation to use for the normalization layer.", + hint=FieldHint.performance, + ) + # TODO: Rename to normalization_init_range + initialization_range: float = Field( + default=0.0, + desc="Randomize the initialization with a uniform noise. 
Used to test for issues that may not be visible with the default initialization.", + hint=FieldHint.testing, + valid=check_field(Assert.geq, 0), + ) + + def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": + from fast_llm.engine.config_utils.initialization import init_uniform_centered_ + + kwargs = { + "hidden_dim": hidden_dim, + "eps": self.epsilon, + "implementation": self.implementation, + "zero_centered": self.zero_centered, + "lr_scale": lr_scale, + } + if self.initialization_range: + mean = 0 if self.zero_centered else 1 + kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) + return self.module_class(**kwargs) + + @property + @abc.abstractmethod + def module_class(self): + pass + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + cls._handle_renamed_field(default, "normalization_type", "type") + cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") + cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") + cls._handle_renamed_field(default, "normalization_implementation", "implementation") + cls._handle_renamed_field(default, "layer_norm_init_range", "initialization_range") + return super()._from_dict(default, strict, flat) + + +@config_class(dynamic_type={NormalizationConfig: "layer_norm"}) +class LayerNormalizationConfig(LayerNormalizationBaseConfig): + _abstract = False + + @property + def module_class(self): + from fast_llm.layers.common.normalization.normalization import LayerNorm + + return LayerNorm + + +@config_class(dynamic_type={NormalizationConfig: "rms_norm"}) +class RMSNormalizationConfig(LayerNormalizationBaseConfig): + _abstract = False + + @property + def module_class(self): + from fast_llm.layers.common.normalization.normalization import RMSNorm + + return RMSNorm diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization/normalization.py similarity index 99% rename from fast_llm/layers/common/normalization.py rename to fast_llm/layers/common/normalization/normalization.py index 2b928eb38..06ee11564 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -5,7 +5,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd -from fast_llm.layers.common.config import NormalizationImplementation +from fast_llm.layers.common.normalization.config import NormalizationImplementation from fast_llm.tensor import ParameterMeta, accumulate_gradient from fast_llm.utils import Assert diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 6e79388b0..e31c70a45 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -25,7 +25,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.block.mlp.config import RoutingType -from fast_llm.layers.common.config import LayerNormalizationConfig +from fast_llm.layers.common.normalization.config import LayerNormalizationConfig from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig from fast_llm.layers.transformer.rotary.rotary 
import convert_rotary_complex_to_real, convert_rotary_real_to_complex diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index b5e77e0f0..e9b18b848 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -21,7 +21,7 @@ from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import RMSNormalizationConfig +from fast_llm.layers.common.normalization.config import RMSNormalizationConfig from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( From c0a37827488caca54a558aa5338e16b995dcf39c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 Aug 2025 20:15:02 -0400 Subject: [PATCH 50/82] stuff --- fast_llm/layers/transformer/attention.py | 2 +- fast_llm/layers/transformer/rotary/config.py | 2 +- fast_llm/models/gpt/model.py | 2 +- tests/functional/test_triton_kernels.py | 6 ++++-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index d7a669295..9ad27534f 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -137,7 +137,7 @@ def __init__( self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) # Rotary embeddings. - self._rotary = self._config.rotary.build(kv_channels_dim) + self._rotary = self._config.rotary.get_layer(kv_channels_dim) # Output. self.dense = InputParallelLinear( diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/transformer/rotary/config.py index f0e0079c7..6cc19fce8 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/transformer/rotary/config.py @@ -29,7 +29,7 @@ def _from_dict( return NoRotaryConfig._from_dict(default, strict, flat) return super()._from_dict(default, strict=strict, flat=flat) - def build(self, kv_channels_dim: TensorDim) -> "Rotary": + def get_layer(self, kv_channels_dim: TensorDim) -> "Rotary": return self._get_configurable_class()(self, kv_channels_dim) @classmethod diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 41e0d607d..92f7b8173 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -51,7 +51,7 @@ def __init__( # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. # TODO: Find a better solution. 
self._preprocessors.append( - self._config.transformer.rotary.build(TensorDim("kv_channels", self._config.transformer.kv_channels)) + self._config.transformer.rotary.get_layer(TensorDim("kv_channels", self._config.transformer.kv_channels)) ) if self._use_flash_attention: self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._distributed_config)) diff --git a/tests/functional/test_triton_kernels.py b/tests/functional/test_triton_kernels.py index e4ad937b7..3f4446e4d 100644 --- a/tests/functional/test_triton_kernels.py +++ b/tests/functional/test_triton_kernels.py @@ -92,7 +92,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y1 = apply_rotary_embeddings( x, DefaultRotaryConfig(triton=False) - .build(None) + .get_layer(None) ._get_frequencies( sequence_length, kv_channels, @@ -103,7 +103,9 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y2 = convert_rotary_real_to_complex( triton_rotary_( convert_rotary_complex_to_real(x, kv_channels, 3), - DefaultRotaryConfig(triton=True).build(None)._get_frequencies(sequence_length, kv_channels, device="cuda"), + DefaultRotaryConfig(triton=True) + .get_layer(None) + ._get_frequencies(sequence_length, kv_channels, device="cuda"), ), kv_channels, 3, From e60ded4467c564bcb795d923758574d98e2f407b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 15 Aug 2025 13:30:30 -0400 Subject: [PATCH 51/82] stuff --- fast_llm/layers/block/block.py | 48 +++++++++++++++---- fast_llm/layers/block/config.py | 42 ++-------------- fast_llm/layers/block/mlp/config.py | 7 ++- .../layers/block/mlp/mixture_of_experts.py | 12 ++--- fast_llm/layers/block/mlp/mlp.py | 18 +++---- fast_llm/layers/ssm/block.py | 25 +++++----- fast_llm/layers/ssm/discrete_mamba2.py | 17 ++----- fast_llm/layers/ssm/mamba.py | 16 ++----- fast_llm/layers/ssm/mamba2.py | 20 ++------ fast_llm/layers/transformer/attention.py | 39 +++++++-------- fast_llm/layers/transformer/block.py | 16 ++++--- fast_llm/layers/transformer/config.py | 4 -- fast_llm/utils.py | 39 ++++++++------- 13 files changed, 133 insertions(+), 170 deletions(-) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 09370e3af..64ba31626 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -12,7 +12,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockLayerConfig +from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -100,6 +100,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): super().__init__(config, distributed_config) self._block_config = block_config @@ -112,9 +113,10 @@ def __init__( self._block_config.debug_transformer, self._block_config.debug_transformer_memory, ) + self._lr_scale = lr_scale -class BlockLayer[ConfigType: BlockLayerConfig](BlockLayerBase[ConfigType]): +class BlockLayer[ConfigType: Config](BlockLayerBase[ConfigType]): """ Base class for mixer and MLP modules. 
""" @@ -145,6 +147,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, return_input: bool = False, ): super().__init__( @@ -154,28 +157,53 @@ def __init__( hidden_dim, block_index, name, + lr_scale, ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale - self.norm_1 = self._config.normalization.get_layer(self._hidden_dim) - self.norm_2 = self._config.normalization.get_layer(self._hidden_dim) + self.norm_1 = self._config.peft.apply_other(self._config.normalization.get_layer(self._hidden_dim)) + self.norm_2 = self._config.peft.apply_other(self._config.normalization.get_layer(self._hidden_dim)) - # The mixer needs to be created here for backward-compatible weight ordering. - setattr(self, self._mixer_module_name, self._create_mixer()) + # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. + setattr( + self, + self._mixer_module_name, + self._mixer_class( + self._mixer_config, + self._config, + self._distributed_config, + self._hidden_dim, + self._block_index, + f"{self._name} mixer", + self._lr_scale, + ), + ) # TODO: Use dynamic type. from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.block.mlp.mlp import MLP self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} MLP" + self._config, + self._config, + self._distributed_config, + self._hidden_dim, + self._block_index, + f"{self._name} MLP", + lr_scale, ) - # PEFT. - self.norm_1 = self._config.peft.apply_other(self.norm_1) - self.norm_2 = self._config.peft.apply_other(self.norm_2) + @functools.cached_property + @abc.abstractmethod + def _mixer_class(self) -> type[BlockLayer]: + pass + + @property + @abc.abstractmethod + def _mixer_config(self) -> Config: + pass def setup(self, distributed: Distributed) -> None: super().setup(distributed) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 3df82e24e..63b58722b 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,19 +1,17 @@ -import abc import enum -import functools import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerPeftConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.block.block import BlockLayer + pass + +# TODO: Generalize these beyond language models? (Ex. vision) class BlockDimNames: @@ -46,37 +44,6 @@ class AddLinearBiasChoices(str, enum.Enum): only_attn_qkv = "only_attn_qkv" -@config_class() -class BlockLayerConfig(BaseModelConfig): - """ - A common class for mixers and mlps, which have the exact same interface. 
- """ - - _abstract = True - - @functools.cached_property - @abc.abstractmethod - def layer_class(self) -> "type[BlockLayer]": - raise NotImplementedError() - - def get_layer( - self, - block_config: "BlockConfig", - distributed_config: DistributedConfig, - hidden_dim: TensorDim, - block_index: int, - name: str, - ) -> "BlockLayer": - return self.layer_class( - self, - block_config, - distributed_config, - hidden_dim, - block_index, - name, - ) - - @config_class() # TODO: Use composition instead class BlockConfig(MLPConfig, BaseModelConfig): @@ -90,6 +57,7 @@ class BlockConfig(MLPConfig, BaseModelConfig): desc="Configuration for the parameter-efficient fine tuning.", hint=FieldHint.architecture, ) + # TODO: Review names hidden_dropout: float = Field( default=0.0, desc="Dropout applied to the residual connections.", @@ -121,7 +89,7 @@ class BlockConfig(MLPConfig, BaseModelConfig): # TODO: Move these, not specific to a single block. num_layers: int = Field( default=12, - desc="Number of layers in the transformer.", + desc="Number of blocks in the model.", hint=FieldHint.architecture, valid=check_field(Assert.geq, 0), ) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 83e45f002..89d423025 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -2,9 +2,8 @@ import functools import typing -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.functional.config import ActivationType, MLPRecomputeLevel -from fast_llm.layers.block.config import BlockLayerConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -22,7 +21,7 @@ class RoutingType(str, enum.Enum): @config_class() -class MLPConfig(BlockLayerConfig): +class MLPConfig(Config): # TODO: Review names # TODO: Separate MoE? _abstract = False ffn_hidden_size: int = Field( @@ -90,7 +89,7 @@ class MLPConfig(BlockLayerConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - mlp_lr_scale: float | None | list[float | None] = Field( + mlp_lr_scale: float | None | tuple[float | None] = Field( default=None, desc="Custom learning rate scale for each expert.", doc="May be used to freeze some experts by setting their scale to zero.", diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 2a234ca94..d52f5a429 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -14,7 +14,7 @@ from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import Assert, combine_lr_scales logger = logging.getLogger(__name__) @@ -41,16 +41,12 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." 
- super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name) - - layer_lr_scale = ( - self._block_config.per_layer_lr_scale[block_index] if self._block_config.per_layer_lr_scale else None - ) - router_lr_scale = get_lr_scale(self._config.router_lr_scale, layer_lr_scale) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) self.router = Linear( self._hidden_dim, @@ -61,7 +57,7 @@ def __init__( min_val=self._block_config.init_method_min, max_val=self._block_config.init_method_max, ), - lr_scale=router_lr_scale, + lr_scale=combine_lr_scales(self._config.router_lr_scale, self._lr_scale), ) dropless_moe = self._config.dropless_moe if dropless_moe and self._sequence_parallel: diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index fd64713d1..341ecf265 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -12,7 +12,7 @@ from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import Assert, combine_lr_scales class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): @@ -24,8 +24,9 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): - super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, intermediate_2_dim = self._get_intermediate_dims() @@ -42,15 +43,7 @@ def __init__( self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = ( - self._block_config.per_layer_lr_scale[block_index] if self._block_config.per_layer_lr_scale else None - ) - lr_scale = ( - tuple(self._config.mlp_lr_scale) - if isinstance(self._config.mlp_lr_scale, list) - else self._config.mlp_lr_scale - ) - lr_scale = get_lr_scale(lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales(self._lr_scale, self._config.mlp_lr_scale) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( @@ -95,9 +88,10 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): Assert.eq(config.num_experts, 1) - super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) def forward( self, diff --git a/fast_llm/layers/ssm/block.py b/fast_llm/layers/ssm/block.py index 361fe9818..408f21041 100644 --- a/fast_llm/layers/ssm/block.py +++ b/fast_llm/layers/ssm/block.py @@ -1,3 +1,5 @@ +import functools + from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.block import Block, BlockLayer @@ -17,21 +19,20 @@ def __init__( ssm_config: SSMConfig, distributed_config: DistributedConfig, hidden_dim: TensorDim, - mixer_cls: type[BlockLayer], block_index: int, + lr_scale: float | list[float] | None, name: str, + mixer_class: type[BlockLayer], return_input: bool = False, ): self._ssm_config = ssm_config - self._mixer_cls = mixer_cls - 
super().__init__(config, distributed_config, hidden_dim, block_index, name, return_input) + self._mixer_class = mixer_class + super().__init__(config, distributed_config, hidden_dim, block_index, name, lr_scale, return_input) + + @functools.cached_property + def _mixer_class(self) -> type[BlockLayer]: + return self._mixer_class - def _create_mixer(self) -> BlockLayer: - return self._mixer_cls( - self._ssm_config, - self._config, - self._distributed_config, - self._hidden_dim, - self._block_index, - f"{self._name} mixer", - ) + @property + def _mixer_config(self) -> SSMConfig: + return self._config diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 7e445cca1..fb78f09c5 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -14,7 +14,7 @@ from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba import init_kaiming_ from fast_llm.tensor import ParameterMeta -from fast_llm.utils import div, get_lr_scale +from fast_llm.utils import combine_lr_scales, div logger = logging.getLogger(__name__) @@ -41,7 +41,6 @@ class DiscreteMamba2[ConfigType: SSMConfig](BlockLayer[ConfigType]): """ _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" - _config: SSMConfig def __init__( self, @@ -51,16 +50,9 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): - super().__init__( - config, - distributed_config, - hidden_dim, - block_index, - name, - block_config.debug_transformer, - block_config.debug_transformer_memory, - ) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) state_dim = TensorDim("state", self._config.state_size) v_head_size_dim = TensorDim(SSMDimNames.head_dim, div(self._config.d_inner, self._config.n_v_heads)) @@ -90,8 +82,7 @@ def __init__( # local_bc_size = local_head_groups * state self._local_bc_size = bc_dim.size - layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None - lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) # TODO: double check initializations # Projections diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index ac6576a87..37ac20ef1 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -13,7 +13,7 @@ from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig from fast_llm.tensor import ParameterMeta -from fast_llm.utils import Assert, div, get_lr_scale +from fast_llm.utils import Assert, combine_lr_scales, div try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -65,16 +65,9 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): - super().__init__( - config, - distributed_config, - hidden_dim, - block_index, - name, - block_config.debug_transformer, - block_config.debug_transformer_memory, - ) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" # TODO: It's not silu? 
Assert.eq(self._config.activation_type, ActivationType.silu) @@ -88,8 +81,7 @@ def __init__( inner_projection_dim = ConcatenatedTensorDim("inner_projection", (inner_dim, inner_dim)) x_projection_dim = ConcatenatedTensorDim("x_projection", (dt_rank_dim, state_dim, state_dim)) - layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None - lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) # TODO: Backward compatibility? # TODO: lr_scale? diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index e6ca9ea12..bc40658e6 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -13,7 +13,7 @@ from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.ssm.mamba import init_A, init_dtprojbias, init_kaiming_ from fast_llm.tensor import ParameterMeta -from fast_llm.utils import Assert, div, get_lr_scale +from fast_llm.utils import Assert, combine_lr_scales, div try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -47,22 +47,10 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): - super().__init__( - config, - distributed_config, - hidden_dim, - block_index, - name, - block_config.debug_transformer, - block_config.debug_transformer_memory, - ) - self._config: SSMConfig = config + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) Assert.eq(self._config.activation_type, ActivationType.silu) - layer_lr_scale: float | None = ( - block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None - ) - lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) num_heads = div(self._config.d_inner, self._config.state_size) num_head_groups = div(self._config.d_xb, self._config.state_size) @@ -94,6 +82,8 @@ def __init__( self._local_inner_size = inner_dim.size self._local_xb_size = xb_dim.size + lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) + conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim self.conv1d_weight = ParameterMeta.from_dims( ( diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 9ad27534f..8abab1206 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -9,11 +9,11 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.block.config import BlockConfig, BlockDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import AttentionKwargs, TransformerConfig -from fast_llm.utils import div, get_lr_scale +from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs +from fast_llm.utils import combine_lr_scales, div try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -48,7 +48,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention[ConfigType: TransformerConfig](BlockLayer[ConfigType]): +class Attention[ConfigType: AttentionConfig](BlockLayer[ConfigType]): """ A self-attention layer. """ @@ -56,20 +56,15 @@ class Attention[ConfigType: TransformerConfig](BlockLayer[ConfigType]): def __init__( self, config: ConfigType, + block_config: BlockConfig, distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): - super().__init__( - config, - distributed_config, - hidden_dim, - block_index, - name, - config.debug_transformer, - config.debug_transformer_memory, - ) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) + self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -112,8 +107,10 @@ def __init__( max_val=self._config.init_method_max_attn_proj, ) - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales( + self._lr_scale, + self._config.attention_lr_scale, + ) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
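For reference, the `combine_lr_scales` helper used just above (added to fast_llm/utils.py further down in this patch) drops None entries, multiplies scalars, and broadcasts scalars over tuples. A small illustrative sketch with invented values:

    from fast_llm.utils import combine_lr_scales

    combine_lr_scales(None, None)        # nothing to scale: returns None
    combine_lr_scales(0.5, 0.1)          # plain scalars multiply: 0.05
    combine_lr_scales(0.5, (1.0, None))  # scalar broadcasts over the tuple: 0.5 in each slot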
self.query = OutputParallelLinear( @@ -123,7 +120,7 @@ def __init__( weight_init_method=init_method_qkv, bias_init_method=init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=attention_lr_scale, + lr_scale=lr_scale, ) self.key_value = OutputParallelLinear( hidden_dim, @@ -132,7 +129,7 @@ def __init__( weight_init_method=init_method_qkv, bias_init_method=init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=attention_lr_scale, + lr_scale=lr_scale, ) self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) @@ -147,13 +144,13 @@ def __init__( weight_init_method=init_method_std_attn_proj, bias_init_method=init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=attention_lr_scale, + lr_scale=lr_scale, ) # PEFT. - self.query = self._config.peft.apply_linear(self.query, TransformerSubLayerName.query) - self.key_value = self._config.peft.apply_linear(self.key_value, TransformerSubLayerName.key_value) - self.dense = self._config.peft.apply_linear(self.dense, TransformerSubLayerName.dense) + self.query = self._block_config.peft.apply_linear(self.query, TransformerSubLayerName.query) + self.key_value = self._block_config.peft.apply_linear(self.key_value, TransformerSubLayerName.key_value) + self.dense = self._block_config.peft.apply_linear(self.dense, TransformerSubLayerName.dense) if self._debug.enabled: self._query_dims = ( diff --git a/fast_llm/layers/transformer/block.py b/fast_llm/layers/transformer/block.py index a5aad45a9..ba593461b 100644 --- a/fast_llm/layers/transformer/block.py +++ b/fast_llm/layers/transformer/block.py @@ -1,9 +1,10 @@ +import functools import logging import typing -from fast_llm.layers.block.block import Block, BlockLayer +from fast_llm.layers.block.block import Block from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.config import AttentionConfig, TransformerConfig logger = logging.getLogger(__name__) @@ -12,7 +13,10 @@ class TransformerBlock[ConfigType: TransformerConfig](Block[ConfigType]): # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "self_attn" - def _create_mixer(self) -> BlockLayer: - return Attention( - self._config, self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} attn" - ) + @functools.cached_property + def _mixer_class(self) -> type[Attention]: + return Attention + + @property + def _mixer_config(self) -> AttentionConfig: + return self._config diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index a40f676ca..02b741723 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -1,6 +1,5 @@ import functools import logging -import typing import warnings from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none @@ -11,9 +10,6 @@ from fast_llm.layers.transformer.rotary.config import RotaryConfig from fast_llm.utils import Assert, div -if typing.TYPE_CHECKING: - pass - logger = logging.getLogger(__name__) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 58285d408..f7f5e9663 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -348,22 +348,29 @@ def check_equal_nested(config_a, config_b): raise ValueError("\n".join(errors)) -def get_lr_scale( - lr_scale: float | None | tuple[float | None, ...], layer_lr_scale: float | None -) -> float | None | tuple[float | None, ...]: - """ - 
Combine module and layer lr_scale. - If one is None, return the other. - """ - if lr_scale is None: - return layer_lr_scale - if layer_lr_scale is None: - return lr_scale - if isinstance(lr_scale, float): - return lr_scale * layer_lr_scale - if isinstance(lr_scale, tuple): - return tuple(lrs * layer_lr_scale if lrs is not None else layer_lr_scale for lrs in lr_scale) - raise ValueError(f"Invalid lr_scale: {lr_scale} (type {type(lr_scale)})") +def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): + # Remove `None` entries. + lr_scales = [lr_scale for lr_scale in lr_scales if lr_scale is not None] + if not lr_scales: + # Everything is None + return None + tuple_length = None + # Check if we have tuples, and determine the length. + for lr_scale in lr_scales: + if isinstance(lr_scale, tuple): + if tuple_length is None: + tuple_length = len(lr_scale) + else: + assert len(lr_scale) == tuple_length + if tuple_length is None: + # No tuple: simple product. + return math.prod(lr_scales) + else: + # Tuple(s): use recursion. + return [ + combine_lr_scales(*[lr_scale[i] if isinstance(lr_scale, tuple) else lr_scale for lr_scale in lr_scales]) + for i in range(tuple_length) + ] class Interrupter: From 1483bcc7cbe7bf6fa763ab12b75573ed89015207 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 15 Aug 2025 13:57:56 -0400 Subject: [PATCH 52/82] stuff --- fast_llm/layers/language_model/embedding.py | 5 +++-- fast_llm/layers/language_model/head.py | 7 ++++--- fast_llm/layers/ssm/config.py | 23 --------------------- fast_llm/layers/ssm/discrete_mamba2.py | 10 ++++----- fast_llm/layers/ssm/mamba.py | 7 ------- fast_llm/layers/ssm/mamba2.py | 3 +-- 6 files changed, 13 insertions(+), 42 deletions(-) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index d1b912167..fd4e8412e 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -37,12 +37,13 @@ def __init__( ): super().__init__( config, + config.transformer, distributed_config, hidden_dim, block_index, name, - config.transformer.debug_transformer, - config.transformer.debug_transformer_memory, + # TODO: Add lr scale? + None, ) self._residual_dtype = ( self._distributed_config.optimization_dtype diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index cc6c69262..7b1b5f6d8 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -44,12 +44,13 @@ def __init__( ): super().__init__( config, + config.transformer, distributed_config, hidden_dim, block_index, name, - config.transformer.debug_transformer, - config.transformer.debug_transformer_memory, + # TODO: Add lr scale? 
+ None, ) self._parallel_logits = self._distributed_config.tensor_parallel > 1 and config.parallel_embeddings self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -161,7 +162,7 @@ def _forward_backward( TensorDim( BlockDimNames.sequence_q_tp, dims[sequence_index].global_size, - DistributedDimNames.tensor, + self._distributed_config.get_distributed_dim(DistributedDimNames.tensor), ) if self._sequence_parallel_logits else TensorDim(BlockDimNames.sequence_q, dims[sequence_index].global_size) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 2daad1186..8917feaf6 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -3,35 +3,12 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.functional.config import ActivationType -from fast_llm.layers.block.config import BlockDimNames from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.initialization import Initializer -class SSMDimNames(BlockDimNames): - # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. - state = "ssm_state" # State dimension (N), aka head size / num channels - head_dim = "ssm_head_dim" - head_groups = "ssm_head_groups" - group_heads = "ssm_group_heads" - - convolution_kernel = "ssm_convolution_kernel" # Kernel dimension of the conv1d in mamba layers - - dt_rank = "ssm_dt_rank" - - # Composite dimensions - composite_heads = "ssm_composite_heads" - composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim" - composite_head_groups_and_state = "ssm_composite_head_groups_and_state" - - # Concatenated dimensions - concatenated_convolution = "ssm_concatenated_convolution" - concatenated_x_projection = "ssm_x_concatenated_x_projection" - concatenated_inner_projection = "ssm_concatenated_inner_projection" - - class SSMBlockType(enum.StrEnum): """ An enum for the available mamba types for the MLP layer. 
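The ssm/config.py diff above removes the shared SSMDimNames constants; the discrete_mamba2.py diff that follows builds its tensor dimensions inline with plain names instead. A rough sketch of the pattern, with invented sizes and assuming a composite dimension's size is the product of its parts, as the surrounding code relies on:

    from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim

    state = TensorDim("state", 16)             # state size (N)
    head_size = TensorDim("v_head_size", 16)   # d_inner // n_v_heads
    head_groups = TensorDim("head_groups", 4)  # n_qk_heads, sharded over tensor-parallel in the real layer
    group_heads = TensorDim("group_heads", 8)  # n_v_heads // n_qk_heads
    inner = CompositeTensorDim("inner", (head_groups, group_heads, head_size))
    bc = CompositeTensorDim("bc", (head_groups, state))
    # inner.size == 4 * 8 * 16 == 512 (d_inner); bc.size == 4 * 16 == 64.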
diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index fb78f09c5..7fea3d480 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -11,7 +11,7 @@ from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.ssm.mamba import init_kaiming_ from fast_llm.tensor import ParameterMeta from fast_llm.utils import combine_lr_scales, div @@ -54,15 +54,15 @@ def __init__( ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) state_dim = TensorDim("state", self._config.state_size) - v_head_size_dim = TensorDim(SSMDimNames.head_dim, div(self._config.d_inner, self._config.n_v_heads)) + v_head_size_dim = TensorDim("v_head_size", div(self._config.d_inner, self._config.n_v_heads)) head_groups_dim = TensorDim( - SSMDimNames.head_groups, + "head_groups", self._config.n_qk_heads, self._distributed_config.get_distributed_dim(DistributedDimNames.tensor), ) - group_heads_dim = TensorDim(SSMDimNames.group_heads, div(self._config.n_v_heads, self._config.n_qk_heads)) - heads_dim = CompositeTensorDim(SSMDimNames.composite_heads, (head_groups_dim, group_heads_dim)) + group_heads_dim = TensorDim("group_heads", div(self._config.n_v_heads, self._config.n_qk_heads)) + heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, v_head_size_dim)) bc_dim = CompositeTensorDim("bc", (head_groups_dim, state_dim)) convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 37ac20ef1..59fd03a1e 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -91,7 +91,6 @@ def __init__( bias=False, weight_init_method=init_kaiming_(hidden_dim.size), ) - self.conv1d_weight = ParameterMeta.from_dims( ( inner_dim, @@ -101,7 +100,6 @@ def __init__( init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, ) - self.x_proj = Linear( inner_dim, x_projection_dim, @@ -110,27 +108,23 @@ def __init__( lr_scale=lr_scale, ) self.x_proj.weight.auto_grad_accumulation = True - # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( (inner_dim, dt_rank_dim), init_method=init_kaiming_(self._config.dt_rank), lr_scale=lr_scale, ) - self.dt_proj_bias = ParameterMeta.from_dims( (inner_dim,), init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), lr_scale=lr_scale, ) - self.A_log = ParameterMeta.from_dims( (inner_dim, state_dim), weight_decay=False, init_method=init_A(self._config.state_size, inner_dim.size), lr_scale=lr_scale, ) - # D "skip" parameter self.D = ParameterMeta.from_dims( (inner_dim,), @@ -138,7 +132,6 @@ def __init__( init_method=init_ones_, lr_scale=lr_scale, ) - self.out_proj = Linear( inner_dim, hidden_dim, diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index bc40658e6..bf9c30521 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -81,10 +81,10 
@@ def __init__( self._group_heads = div(self._local_heads, self._local_head_groups) self._local_inner_size = inner_dim.size self._local_xb_size = xb_dim.size + conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) - conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim self.conv1d_weight = ParameterMeta.from_dims( ( conv1d_dim, @@ -107,7 +107,6 @@ def __init__( sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - self.dt_in_proj = Linear( hidden_dim, dt_rank_dim, From 4deb501748a1f725d96aff0ba88034166b4ae04f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 15 Aug 2025 14:12:57 -0400 Subject: [PATCH 53/82] misc --- fast_llm/layers/block/block.py | 4 ---- fast_llm/models/gpt/model.py | 6 ++++++ fast_llm/models/ssm/model.py | 10 ++++++++-- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 64ba31626..535ca12c5 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -210,10 +210,6 @@ def setup(self, distributed: Distributed) -> None: getattr(self, self._mixer_module_name).setup(distributed) self.mlp.setup(distributed) - @abc.abstractmethod - def _create_mixer(self) -> BlockLayer: - pass - @torch.compile def _bias_dropout_add( self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 92f7b8173..581429467 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -100,12 +100,18 @@ def _get_block( name: str, return_input: bool = False, ): + lr_scale = ( + None + if self._config.transformer.per_layer_lr_scale is None + else self._config.transformer.per_layer_lr_scale[block_index] + ) return TransformerBlock( self._config.transformer, self._distributed_config, self._hidden_dim, block_index, name, + lr_scale, return_input, ) diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 7c67d7355..9afd7dabb 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -29,6 +29,12 @@ def _get_block( # Decoder block block_type = self._config.hybrid_block_layout[block_index - 1] + lr_scale = ( + None + if self._config.transformer.per_layer_lr_scale is None + else self._config.transformer.per_layer_lr_scale[block_index] + ) + if block_type == SSMBlockType.transformer: return TransformerBlock( self._config.transformer, @@ -36,7 +42,7 @@ def _get_block( self._hidden_dim, block_index, name, - return_input, + lr_scale.return_input, ) else: return SSMBlock( @@ -44,9 +50,9 @@ def _get_block( self._config.ssm, self._distributed_config, self._hidden_dim, - self._config.ssm_block_type.get_mixer_class(), block_index, name, + lr_scale.self._config.ssm_block_type.get_mixer_class(), return_input, ) From fc809e0ed20747314ca98f08734d7eebd43cfb22 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 15 Aug 2025 15:37:38 -0400 Subject: [PATCH 54/82] Misc, tests pass --- fast_llm/layers/block/block.py | 4 ++-- fast_llm/layers/block/config.py | 2 +- fast_llm/layers/block/mlp/config.py | 18 +++--------------- .../layers/block/mlp/mixture_of_experts.py | 2 +- fast_llm/layers/block/mlp/mlp.py | 4 ++-- fast_llm/layers/ssm/block.py | 4 ++-- fast_llm/layers/ssm/discrete_mamba2.py | 2 +- fast_llm/layers/ssm/mamba.py | 2 +- fast_llm/layers/ssm/mamba2.py | 2 +- fast_llm/layers/transformer/attention.py | 2 +- 
fast_llm/models/ssm/model.py | 6 ++++-- fast_llm/utils.py | 6 +++--- 12 files changed, 22 insertions(+), 32 deletions(-) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 535ca12c5..b8aad3903 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -100,7 +100,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, distributed_config) self._block_config = block_config @@ -147,7 +147,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, return_input: bool = False, ): super().__init__( diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 63b58722b..95bcb02af 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -99,7 +99,7 @@ class BlockConfig(MLPConfig, BaseModelConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - per_layer_lr_scale: list[float] | None = Field( + per_layer_lr_scale: list[float | None] | None = Field( default=None, desc="Custom learning rate scale for each layer.", doc="May be used to freeze some layers by setting their scale to zero.", diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 89d423025..88ce4af10 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -1,5 +1,4 @@ import enum -import functools import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none @@ -7,7 +6,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.block.mlp.mlp import MLPBase + pass class MLPLossNames: @@ -89,7 +88,7 @@ class MLPConfig(Config): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - mlp_lr_scale: float | None | tuple[float | None] = Field( + mlp_lr_scale: float | None | tuple[float | None, ...] = Field( default=None, desc="Custom learning rate scale for each expert.", doc="May be used to freeze some experts by setting their scale to zero.", @@ -155,17 +154,6 @@ def add_mlp_bias(self) -> bool: return True return False - @functools.cached_property - def layer_class(self) -> "type[MLPBase]": - if self.num_experts > 1: - from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP - - return MixtureOfExpertMLP - else: - from fast_llm.layers.block.mlp.mlp import MLP - - return MLP - def _validate(self) -> None: with self._set_implicit_default(): if self.activation_type is None: @@ -198,7 +186,7 @@ def _validate(self) -> None: Assert.leq(self.num_shared_experts, self.num_experts) Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) - if isinstance(self.mlp_lr_scale, list): + if isinstance(self.mlp_lr_scale, tuple): Assert.eq(len(self.mlp_lr_scale), self.num_experts) for scale in self.mlp_lr_scale: if scale is not None: diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index d52f5a429..4f7cf2dc4 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -41,7 +41,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): Assert.gt(config.num_experts, 1) # TODO: Implement? 
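With `per_layer_lr_scale` now typed as a list of optional floats and each block receiving a single optional scalar, freezing or re-scaling individual blocks becomes a config-level choice. A hypothetical sketch mirroring the lookup the models perform in `_get_block` (indices and values invented):

    # Per-block scales for a 4-block model: freeze block 1, halve block 2, leave the others unscaled.
    per_layer_lr_scale = [None, 0.0, 0.5, None]
    block_index = 2
    lr_scale = None if per_layer_lr_scale is None else per_layer_lr_scale[block_index]  # 0.5 here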
diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 341ecf265..c3a714a42 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -24,7 +24,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -88,7 +88,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): Assert.eq(config.num_experts, 1) super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) diff --git a/fast_llm/layers/ssm/block.py b/fast_llm/layers/ssm/block.py index 408f21041..22d01a5cb 100644 --- a/fast_llm/layers/ssm/block.py +++ b/fast_llm/layers/ssm/block.py @@ -20,8 +20,8 @@ def __init__( distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, - lr_scale: float | list[float] | None, name: str, + lr_scale: float | None, mixer_class: type[BlockLayer], return_input: bool = False, ): @@ -35,4 +35,4 @@ def _mixer_class(self) -> type[BlockLayer]: @property def _mixer_config(self) -> SSMConfig: - return self._config + return self._ssm_config diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 7fea3d480..0d91fbaff 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -50,7 +50,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) state_dim = TensorDim("state", self._config.state_size) diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 59fd03a1e..79a0e5c8e 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -65,7 +65,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index bf9c30521..eec134a22 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -47,7 +47,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) Assert.eq(self._config.activation_type, ActivationType.silu) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 8abab1206..41d509512 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -61,7 +61,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 9afd7dabb..4b7785402 100644 --- 
a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -42,7 +42,8 @@ def _get_block( self._hidden_dim, block_index, name, - lr_scale.return_input, + lr_scale, + return_input, ) else: return SSMBlock( @@ -52,7 +53,8 @@ def _get_block( self._hidden_dim, block_index, name, - lr_scale.self._config.ssm_block_type.get_mixer_class(), + lr_scale, + self._config.ssm_block_type.get_mixer_class(), return_input, ) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index f7f5e9663..51249c3fa 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -350,7 +350,7 @@ def check_equal_nested(config_a, config_b): def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): # Remove `None` entries. - lr_scales = [lr_scale for lr_scale in lr_scales if lr_scale is not None] + lr_scales = tuple(lr_scale for lr_scale in lr_scales if lr_scale is not None) if not lr_scales: # Everything is None return None @@ -367,10 +367,10 @@ def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): return math.prod(lr_scales) else: # Tuple(s): use recursion. - return [ + return tuple( combine_lr_scales(*[lr_scale[i] if isinstance(lr_scale, tuple) else lr_scale for lr_scale in lr_scales]) for i in range(tuple_length) - ] + ) class Interrupter: From cdb67105cc9f70c234132ca7d248f7db0cdfef89 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 20 Aug 2025 10:50:59 -0400 Subject: [PATCH 55/82] misc --- fast_llm/layers/block/config.py | 4 - .../layers/common/normalization/config.py | 44 +++--- .../common/normalization/normalization.py | 128 ++++++++++-------- fast_llm/layers/language_model/head.py | 4 +- fast_llm/layers/transformer/attention.py | 1 - 5 files changed, 94 insertions(+), 87 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 95bcb02af..29acaadf0 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,5 +1,4 @@ import enum -import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig @@ -8,9 +7,6 @@ from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.utils import Assert -if typing.TYPE_CHECKING: - pass - # TODO: Generalize these beyond language models? (Ex. 
vision) diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index 658d00dfc..569d48b0e 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -7,10 +7,8 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - import torch - from fast_llm.engine.config_utils.tensor_dim import TensorDim - from fast_llm.layers.common.normalization import LayerNorm, RMSNorm + from fast_llm.layers.common.normalization.normalization import Normalization class NormalizationImplementation(str, enum.Enum): @@ -29,10 +27,18 @@ class NormalizationImplementation(str, enum.Enum): class NormalizationConfig(BaseModelConfig): pass + @property @abc.abstractmethod - def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": + def module_class(self) -> type["Normalization"]: pass + def get_layer( + self, + hidden_dim: "TensorDim", + lr_scale: float | None = None, + ) -> "Normalization": + return self.module_class(self, hidden_dim, lr_scale) + @classmethod def _from_dict( cls, @@ -50,8 +56,11 @@ def _from_dict( class NoNormalizationConfig(NormalizationConfig): _abstract = False - def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": - return torch.nn.Identity() + @property + def module_class(self) -> type["Normalization"]: + from fast_llm.layers.common.normalization.normalization import NoNormalization + + return NoNormalization @config_class() @@ -85,21 +94,6 @@ class LayerNormalizationBaseConfig(NormalizationConfig): valid=check_field(Assert.geq, 0), ) - def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.engine.config_utils.initialization import init_uniform_centered_ - - kwargs = { - "hidden_dim": hidden_dim, - "eps": self.epsilon, - "implementation": self.implementation, - "zero_centered": self.zero_centered, - "lr_scale": lr_scale, - } - if self.initialization_range: - mean = 0 if self.zero_centered else 1 - kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) - return self.module_class(**kwargs) - @property @abc.abstractmethod def module_class(self): @@ -126,9 +120,9 @@ class LayerNormalizationConfig(LayerNormalizationBaseConfig): @property def module_class(self): - from fast_llm.layers.common.normalization.normalization import LayerNorm + from fast_llm.layers.common.normalization.normalization import LayerNormalization - return LayerNorm + return LayerNormalization @config_class(dynamic_type={NormalizationConfig: "rms_norm"}) @@ -137,6 +131,6 @@ class RMSNormalizationConfig(LayerNormalizationBaseConfig): @property def module_class(self): - from fast_llm.layers.common.normalization.normalization import RMSNorm + from fast_llm.layers.common.normalization.normalization import RMSNormalization - return RMSNorm + return RMSNormalization diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index 06ee11564..7f7d3eb65 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -1,11 +1,20 @@ +import abc + import torch -from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ +from fast_llm.config import Configurable +from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.engine.config_utils.run import log_main_rank from 
fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd -from fast_llm.layers.common.normalization.config import NormalizationImplementation +from fast_llm.layers.common.normalization.config import ( + LayerNormalizationConfig, + NoNormalizationConfig, + NormalizationConfig, + NormalizationImplementation, + RMSNormalizationConfig, +) from fast_llm.tensor import ParameterMeta, accumulate_gradient from fast_llm.utils import Assert @@ -139,7 +148,24 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, return grad_input, None, None, None -class LayerNorm(torch.nn.Module): +class Normalization[ConfigType: NormalizationConfig](Configurable[ConfigType], torch.nn.Module): + def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + super().__init__(config) + self._hidden_dim = hidden_dim + self._lr_scale = lr_scale + assert not self._hidden_dim.is_parallel + + @abc.abstractmethod + def forward(self, input_: torch.Tensor) -> torch.Tensor: + pass + + +class NoNormalization[ConfigType: NoNormalizationConfig](Normalization[ConfigType]): + def forward(self, input_: torch.Tensor) -> torch.Tensor: + return input_ + + +class LayerNormalization[ConfigType: LayerNormalizationConfig](Normalization[ConfigType]): """ A layer normalization layer, supporting multiple implementations. Note: Converting input automatically to training dtype to match Apex behaviour, @@ -147,25 +173,17 @@ class LayerNorm(torch.nn.Module): TODO: Review this? """ - def __init__( - self, - hidden_dim: TensorDim, - *, - eps=1e-5, - implementation: NormalizationImplementation = NormalizationImplementation.auto, - weight_init_method=None, - bias_init_method=init_zeros_, - zero_centered: bool = False, - lr_scale: float | None = None, - ): - super().__init__() - assert not hidden_dim.is_parallel - self._eps = eps - self._zero_centered = zero_centered + def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + super().__init__(config, hidden_dim, lr_scale) + implementation = self._config.implementation if implementation == NormalizationImplementation.auto: - if _fast_normalization_available and hidden_dim.size in _PERSIST_LN_SIZES and not self._zero_centered: + if ( + _fast_normalization_available + and hidden_dim.size in _PERSIST_LN_SIZES + and not self._config.zero_centered + ): implementation = NormalizationImplementation.fast - elif TritonConfig.TRITON_ENABLED or self._zero_centered: + elif TritonConfig.TRITON_ENABLED or self._config.zero_centered: log_main_rank("Fast layer norm unavailable, using backup triton implementation.") implementation = NormalizationImplementation.triton elif _fused_normalization_available: @@ -174,7 +192,7 @@ def __init__( else: log_main_rank("Fast and fused layer norm unavailable, using backup pytorch implementation.") implementation = NormalizationImplementation.torch - if self._zero_centered: + if self._config.zero_centered: assert implementation == NormalizationImplementation.triton if implementation == NormalizationImplementation.triton: self._forward = self._forward_triton @@ -187,44 +205,49 @@ def __init__( else: raise NotImplementedError(implementation) - if weight_init_method is None: - weight_init_method = init_zeros_ if self._zero_centered else init_ones_ + if self.config.initialization_range: + mean = 0 if self.zero_centered else 1 + weight_init_method = 
init_uniform_centered_(self.config.initialization_range, mean=mean) + else: + weight_init_method = init_zeros_ if self._config.zero_centered else init_ones_ self.weight = ParameterMeta.from_dims( (hidden_dim,), init_method=weight_init_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, - lr_scale=lr_scale, + lr_scale=self._lr_scale, ) self.bias = ParameterMeta.from_dims( (hidden_dim,), - init_method=bias_init_method, + init_method=init_zeros_, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, - lr_scale=lr_scale, + lr_scale=self._lr_scale, ) - self.normalized_shape = self.weight.shape + self._normalized_shape = self.weight.shape def forward(self, input_: torch.Tensor) -> torch.Tensor: - return self._forward(input_.view(-1, *self.normalized_shape)).view_as(input_) + return self._forward(input_.view(-1, *self._normalized_shape)).view_as(input_) def _forward_triton(self, input_: torch.Tensor) -> torch.Tensor: return triton_normalization_autograd( - input_, self.weight, self.bias, self._eps, self.training, self._zero_centered + input_, self.weight, self.bias, self._config.epsilon, self.training, self._config.zero_centered ) def _forward_fast(self, input_: torch.Tensor) -> torch.Tensor: - return FastLayerNorm.apply(input_, self.normalized_shape, self.weight, self.bias, self._eps) + return FastLayerNorm.apply(input_, self._normalized_shape, self.weight, self.bias, self._config.epsilon) def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: - return FusedLayerNorm.apply(input_, self.normalized_shape, self.weight, self.bias, self._eps) + return FusedLayerNorm.apply(input_, self._normalized_shape, self.weight, self.bias, self._config.epsilon) def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: - return torch.layer_norm(input_.to(self.weight.dtype), self.normalized_shape, self.weight, self.bias, self._eps) + return torch.layer_norm( + input_.to(self.weight.dtype), self._normalized_shape, self.weight, self.bias, self._config.epsilon + ) -class RMSNorm(torch.nn.Module): +class RMSNormalization[ConfigType: RMSNormalizationConfig](Configurable[ConfigType], torch.nn.Module): """ A RMS normalization layer. Note: Converting input automatically to training dtype to match Apex behaviour, @@ -232,22 +255,12 @@ class RMSNorm(torch.nn.Module): TODO: Review this? 
""" - def __init__( - self, - hidden_dim: TensorDim, - *, - eps=1e-5, - implementation: NormalizationImplementation = NormalizationImplementation.auto, - weight_init_method=None, - zero_centered: bool = False, - lr_scale: float | None = None, - ): - super().__init__() + def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + super().__init__(config, hidden_dim, lr_scale) assert not hidden_dim.is_parallel - self._eps = eps - self._zero_centered = zero_centered + implementation = self._config.implementation if implementation == NormalizationImplementation.auto: - if TritonConfig.TRITON_ENABLED or self._zero_centered: + if TritonConfig.TRITON_ENABLED or self._config.zero_centered: implementation = NormalizationImplementation.triton elif _fused_normalization_available: log_main_rank("Triton RMS norm unavailable, using fused implementation.") @@ -255,7 +268,7 @@ def __init__( else: log_main_rank("Fused RMS norm unavailable, using backup implementation.") implementation = NormalizationImplementation.torch - if self._zero_centered: + if self._config.zero_centered: assert implementation == NormalizationImplementation.triton if implementation == NormalizationImplementation.triton: self._forward = self._forward_triton @@ -266,8 +279,11 @@ def __init__( else: raise NotImplementedError(implementation) - if weight_init_method is None: - weight_init_method = init_zeros_ if self._zero_centered else init_ones_ + if self.config.initialization_range: + mean = 0 if self.zero_centered else 1 + weight_init_method = init_uniform_centered_(self.config.initialization_range, mean=mean) + else: + weight_init_method = init_zeros_ if self._config.zero_centered else init_ones_ self.weight = ParameterMeta.from_dims( (hidden_dim,), @@ -276,16 +292,18 @@ def __init__( auto_grad_accumulation=True, lr_scale=lr_scale, ) - self.normalized_shape = self.weight.shape + self._normalized_shape = self.weight.shape def forward(self, input_: torch.Tensor) -> torch.Tensor: - return self._forward(input_.view(-1, *self.normalized_shape)).view_as(input_) + return self._forward(input_.view(-1, *self._normalized_shape)).view_as(input_) def _forward_triton(self, input_: torch.Tensor) -> torch.Tensor: - return triton_normalization_autograd(input_, self.weight, None, self._eps, self.training, self._zero_centered) + return triton_normalization_autograd( + input_, self.weight, None, self._config.epsilon, self.training, self._config.zero_centered + ) def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: - return FusedRMSNorm.apply(input_, self.normalized_shape, self.weight, self._eps) + return FusedRMSNorm.apply(input_, self._normalized_shape, self.weight, self._config.epsilon) def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: - return torch.rms_norm(input_.to(self.weight.dtype), self.normalized_shape, self.weight, self._eps) + return torch.rms_norm(input_.to(self.weight.dtype), self._normalized_shape, self.weight, self._config.epsilon) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 7b1b5f6d8..d0c0eb8f9 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -118,7 +118,7 @@ def forward( tensor_name="Loss", reductions=( (self._distributed_config.get_distributed_dim(DistributedDimNames.data), ReduceOp.AVG), - ), # noqa + ), ) else: return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") @@ -262,7 +262,7 @@ def _logits_cross_entropy_forward_backward_split( return None, 
None else: loss = None - # TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length + # TODO MTP: allow a cross_entropy_splits that is not a divisor of the sequence length grad_output /= self._config.cross_entropy_splits logit_input = input_.flatten(0, -2) if self.training: diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 41d509512..91fca75b8 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -64,7 +64,6 @@ def __init__( lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) - self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) From 9ce72e04ead2857adefa2c13430c9cbcb373e506 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 20 Aug 2025 12:51:26 -0400 Subject: [PATCH 56/82] Move files --- fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py | 2 +- fast_llm/layers/{transformer => attention}/__init__.py | 0 .../layers/{transformer => attention}/attention.py | 2 +- fast_llm/layers/{transformer => attention}/block.py | 4 ++-- fast_llm/layers/{transformer => attention}/config.py | 2 +- .../layers/{transformer => attention}/preprocessing.py | 2 +- .../{transformer => attention}/rotary/__init__.py | 0 .../layers/{transformer => attention}/rotary/config.py | 10 +++++----- .../layers/{transformer => attention}/rotary/rotary.py | 4 ++-- fast_llm/layers/block/peft.py | 2 +- fast_llm/layers/common/normalization/normalization.py | 2 +- fast_llm/layers/common/peft/__init__.py | 0 fast_llm/layers/common/{ => peft}/config.py | 2 +- fast_llm/layers/common/{peft.py => peft/lora.py} | 0 fast_llm/layers/language_model/config.py | 4 ++-- fast_llm/models/gpt/conversion.py | 6 +++--- fast_llm/models/gpt/huggingface.py | 2 +- fast_llm/models/gpt/megatron.py | 6 +++--- fast_llm/models/gpt/model.py | 6 +++--- fast_llm/models/ssm/model.py | 2 +- tests/functional/test_triton_kernels.py | 4 ++-- tests/layers/test_lm_head.py | 2 +- tests/test_attention.py | 6 +++--- tests/test_multi_stage.py | 2 +- 24 files changed, 36 insertions(+), 36 deletions(-) rename fast_llm/layers/{transformer => attention}/__init__.py (100%) rename fast_llm/layers/{transformer => attention}/attention.py (99%) rename fast_llm/layers/{transformer => attention}/block.py (77%) rename fast_llm/layers/{transformer => attention}/config.py (99%) rename fast_llm/layers/{transformer => attention}/preprocessing.py (98%) rename fast_llm/layers/{transformer => attention}/rotary/__init__.py (100%) rename fast_llm/layers/{transformer => attention}/rotary/config.py (92%) rename fast_llm/layers/{transformer => attention}/rotary/rotary.py (98%) create mode 100644 fast_llm/layers/common/peft/__init__.py rename fast_llm/layers/common/{ => peft}/config.py (95%) rename fast_llm/layers/common/{peft.py => peft/lora.py} (100%) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 8f4dffedf..439d1da2e 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -16,7 +16,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results from 
fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM -from fast_llm.layers.transformer.rotary.config import NoRotaryConfig +from fast_llm.layers.attention.rotary.config import NoRotaryConfig logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/transformer/__init__.py b/fast_llm/layers/attention/__init__.py similarity index 100% rename from fast_llm/layers/transformer/__init__.py rename to fast_llm/layers/attention/__init__.py diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/attention/attention.py similarity index 99% rename from fast_llm/layers/transformer/attention.py rename to fast_llm/layers/attention/attention.py index 91fca75b8..8a4c490c9 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -8,11 +8,11 @@ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward +from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs from fast_llm.utils import combine_lr_scales, div try: diff --git a/fast_llm/layers/transformer/block.py b/fast_llm/layers/attention/block.py similarity index 77% rename from fast_llm/layers/transformer/block.py rename to fast_llm/layers/attention/block.py index ba593461b..3396a2997 100644 --- a/fast_llm/layers/transformer/block.py +++ b/fast_llm/layers/attention/block.py @@ -2,9 +2,9 @@ import logging import typing +from fast_llm.layers.attention.attention import Attention +from fast_llm.layers.attention.config import AttentionConfig, TransformerConfig from fast_llm.layers.block.block import Block -from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import AttentionConfig, TransformerConfig logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/attention/config.py similarity index 99% rename from fast_llm/layers/transformer/config.py rename to fast_llm/layers/attention/config.py index 02b741723..e5c638adc 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/attention/config.py @@ -6,8 +6,8 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig +from fast_llm.layers.attention.rotary.config import RotaryConfig from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockKwargs -from fast_llm.layers.transformer.rotary.config import RotaryConfig from fast_llm.utils import Assert, div logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/attention/preprocessing.py similarity index 98% rename from fast_llm/layers/transformer/preprocessing.py rename to fast_llm/layers/attention/preprocessing.py index 769177668..24ef3397c 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/attention/preprocessing.py @@ -6,7 +6,7 @@ from fast_llm.engine.base_model.config import 
Preprocessor from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs +from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/transformer/rotary/__init__.py b/fast_llm/layers/attention/rotary/__init__.py similarity index 100% rename from fast_llm/layers/transformer/rotary/__init__.py rename to fast_llm/layers/attention/rotary/__init__.py diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/attention/rotary/config.py similarity index 92% rename from fast_llm/layers/transformer/rotary/config.py rename to fast_llm/layers/attention/rotary/config.py index 6cc19fce8..4ebd6c5dc 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -10,7 +10,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.transformer.rotary.rotary import DefaultRotary, Llama3Rotary, NoRotary, Rotary, YarnRotary + from fast_llm.layers.attention.rotary.rotary import DefaultRotary, Llama3Rotary, NoRotary, Rotary, YarnRotary @config_class(registry=True) @@ -44,7 +44,7 @@ class NoRotaryConfig(RotaryConfig): @classmethod def _get_configurable_class(self) -> "type[NoRotary]": - from fast_llm.layers.transformer.rotary.rotary import NoRotary + from fast_llm.layers.attention.rotary.rotary import NoRotary return NoRotary @@ -75,7 +75,7 @@ def _validate(self) -> None: warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") def _get_configurable_class(self) -> "type[DefaultRotary]": - from fast_llm.layers.transformer.rotary.rotary import DefaultRotary + from fast_llm.layers.attention.rotary.rotary import DefaultRotary return DefaultRotary @@ -97,7 +97,7 @@ def _validate(self) -> None: Assert.gt(self.high_frequency_factor, self.low_frequency_factor) def _get_configurable_class(self) -> "type[Llama3Rotary]": - from fast_llm.layers.transformer.rotary.rotary import Llama3Rotary + from fast_llm.layers.attention.rotary.rotary import Llama3Rotary return Llama3Rotary @@ -137,6 +137,6 @@ def _validate(self) -> None: super()._validate() def _get_configurable_class(self) -> "type[YarnRotary]": - from fast_llm.layers.transformer.rotary.rotary import YarnRotary + from fast_llm.layers.attention.rotary.rotary import YarnRotary return YarnRotary diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py similarity index 98% rename from fast_llm/layers/transformer/rotary/rotary.py rename to fast_llm/layers/attention/rotary/rotary.py index bbf8b524a..53b24c9bb 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -8,8 +8,8 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import AttentionKwargs -from fast_llm.layers.transformer.rotary.config import ( +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.attention.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, NoRotaryConfig, diff --git a/fast_llm/layers/block/peft.py b/fast_llm/layers/block/peft.py index 66bc675ed..2261a7ea1 100644 --- 
a/fast_llm/layers/block/peft.py +++ b/fast_llm/layers/block/peft.py @@ -7,7 +7,7 @@ import typing from fast_llm.config import Field, FieldHint, config_class -from fast_llm.layers.common.config import LoRAConfig, NoPeftConfig, PeftConfig +from fast_llm.layers.common.peft.config import LoRAConfig, NoPeftConfig, PeftConfig from fast_llm.utils import div if typing.TYPE_CHECKING: diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index 7f7d3eb65..a7eba72c8 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -247,7 +247,7 @@ def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: ) -class RMSNormalization[ConfigType: RMSNormalizationConfig](Configurable[ConfigType], torch.nn.Module): +class RMSNormalization[ConfigType: RMSNormalizationConfig](Normalization[ConfigType], torch.nn.Module): """ A RMS normalization layer. Note: Converting input automatically to training dtype to match Apex behaviour, diff --git a/fast_llm/layers/common/peft/__init__.py b/fast_llm/layers/common/peft/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/peft/config.py similarity index 95% rename from fast_llm/layers/common/config.py rename to fast_llm/layers/common/peft/config.py index b09672961..ae8ce3ba4 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -44,7 +44,7 @@ class LoRAConfig(PeftConfig): ) def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - from fast_llm.layers.common.peft import lora_linear + from fast_llm.layers.common.peft.lora import lora_linear # TODO: Init method? return lora_linear( diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft/lora.py similarity index 100% rename from fast_llm/layers/common/peft.py rename to fast_llm/layers/common/peft/lora.py diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index de3f9f196..df6969cfc 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -3,9 +3,9 @@ from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl +from fast_llm.layers.attention.config import TransformerConfig +from fast_llm.layers.attention.rotary.config import NoRotaryConfig from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.layers.transformer.rotary.config import NoRotaryConfig from fast_llm.utils import Assert diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index e31c70a45..36975dea1 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -24,11 +24,11 @@ from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType +from fast_llm.layers.attention.config import TransformerConfig +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig +from fast_llm.layers.attention.rotary.rotary import convert_rotary_complex_to_real, 
convert_rotary_real_to_complex from fast_llm.layers.block.mlp.config import RoutingType from fast_llm.layers.common.normalization.config import LayerNormalizationConfig -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig -from fast_llm.layers.transformer.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.models.gpt.config import ( DiffusionDreamGPTHuggingfaceCheckpointFormat, DiffusionLlamaGPTHuggingfaceCheckpointFormat, diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 4e3f258fc..2f99ae4c3 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -9,7 +9,7 @@ from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM -from fast_llm.layers.transformer.config import AttentionKwargs +from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index 20ed8e828..5d3130549 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -1,7 +1,7 @@ import typing -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig +from fast_llm.layers.attention.config import TransformerConfig +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: @@ -94,7 +94,7 @@ def _init_attention_megatron( raise NotImplementedError(meta.tensor_name) if isinstance(config.rotary, DefaultRotaryConfig) and config.rotary.complex_format: - from fast_llm.layers.transformer.rotary.config import convert_rotary_real_to_complex + from fast_llm.layers.attention.rotary.config import convert_rotary_real_to_complex # Megatron uses (2, kv_channels/2) for the complex split; we use (kv_channels/2, 2). # TODO: Avoid unnecessarily changing the value and dense tensors. 
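Not part of the patch itself: a minimal sketch of the import-path updates downstream code would make after the "Move files" commit above, with the old paths kept as comments for comparison. Module names are taken from the rename list in this commit.

```python
# Old locations (before this patch series):
# from fast_llm.layers.transformer.attention import Attention
# from fast_llm.layers.transformer.rotary.config import NoRotaryConfig
# from fast_llm.layers.common.config import LoRAConfig, NoPeftConfig, PeftConfig

# New locations (after the move):
from fast_llm.layers.attention.attention import Attention
from fast_llm.layers.attention.rotary.config import NoRotaryConfig
from fast_llm.layers.common.peft.config import LoRAConfig, NoPeftConfig, PeftConfig
```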
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 581429467..b13c77724 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,15 +10,15 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.attention.block import TransformerBlock +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.attention.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.block.mlp.config import MLPLossNames, RoutingType from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor -from fast_llm.layers.transformer.block import TransformerBlock -from fast_llm.layers.transformer.config import AttentionKwargs -from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 4b7785402..9b79e74a3 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -1,8 +1,8 @@ import logging import typing +from fast_llm.layers.attention.block import TransformerBlock from fast_llm.layers.ssm.block import SSMBlock -from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType diff --git a/tests/functional/test_triton_kernels.py b/tests/functional/test_triton_kernels.py index 3f4446e4d..5a9065454 100644 --- a/tests/functional/test_triton_kernels.py +++ b/tests/functional/test_triton_kernels.py @@ -23,8 +23,8 @@ from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill from fast_llm.functional.triton.rotary import triton_rotary_ from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig -from fast_llm.layers.transformer.rotary.rotary import ( +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig +from fast_llm.layers.attention.rotary.rotary import ( apply_rotary_embeddings, convert_rotary_complex_to_real, convert_rotary_real_to_complex, diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 8c33aed4d..380ab0550 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -6,10 +6,10 @@ from fast_llm.config import UpdateType from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl +from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import 
WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda diff --git a/tests/test_attention.py b/tests/test_attention.py index 7d05e0a66..9564a931f 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -5,10 +5,10 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.attention.attention import Attention +from fast_llm.layers.attention.config import AttentionKwargs, TransformerConfig +from fast_llm.layers.attention.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.layers.block.config import BlockDimNames -from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import AttentionKwargs, TransformerConfig -from fast_llm.layers.transformer.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.utils import Assert diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index 0639ec7ed..56356cf7a 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,8 +3,8 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer +from fast_llm.layers.attention.block import TransformerBlock from fast_llm.layers.ssm.block import SSMBlock -from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.utils import Assert from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup From 065b34fac5a44d87281c439ff173f1170126564b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 20 Aug 2025 13:12:28 -0400 Subject: [PATCH 57/82] misc --- fast_llm/layers/block/peft.py | 57 +++-------------- fast_llm/layers/common/peft/config.py | 67 +++++++++++++++----- fast_llm/layers/common/peft/peft.py | 88 +++++++++++++++++++++++++++ 3 files changed, 147 insertions(+), 65 deletions(-) create mode 100644 fast_llm/layers/common/peft/peft.py diff --git a/fast_llm/layers/block/peft.py b/fast_llm/layers/block/peft.py index 2261a7ea1..b51d352bc 100644 --- a/fast_llm/layers/block/peft.py +++ b/fast_llm/layers/block/peft.py @@ -2,7 +2,6 @@ TODO: Generalize beyond transformers. """ -import abc import enum import typing @@ -11,14 +10,10 @@ from fast_llm.utils import div if typing.TYPE_CHECKING: - import torch - from fast_llm.layers.common.linear import LinearBase, LinearLike - from fast_llm.tensor import ParameterMeta class TransformerSubLayerName(str, enum.Enum): - # TODO: Use this to replace AddLinearBiasChoices. 
query = "query" key = "key" value_ = "value" @@ -30,18 +25,6 @@ class TransformerSubLayerName(str, enum.Enum): @config_class(registry=True) class TransformerPeftConfig(PeftConfig): - @abc.abstractmethod - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - pass - - @abc.abstractmethod - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - pass - - @abc.abstractmethod - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - pass - @classmethod def _from_dict( cls, @@ -57,16 +40,7 @@ def _from_dict( @config_class(dynamic_type={TransformerPeftConfig: "none"}) class TransformerNoPeftConfig(NoPeftConfig, TransformerPeftConfig): - _abstract = False - - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - return super().apply_linear(linear) - - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - return module - - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - return parameter + pass @config_class(dynamic_type={TransformerPeftConfig: "lora"}) @@ -76,33 +50,18 @@ class TransformerLoRAConfig(LoRAConfig, TransformerPeftConfig): desc="The layers on which to apply LoRA.", hint=FieldHint.feature, ) - freeze_others: bool = Field( - default=True, - desc="Whether to freeze other layers during training.", - ) def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": + out_channel_begin, out_channel_end = None, None if layer_type is None or self.layers is None or layer_type in self.layers: + enabled = True if layer_type == TransformerSubLayerName.key: - return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) + out_channel_end = div(linear._out_dim.global_size, 2) elif layer_type == TransformerSubLayerName.value_: - return super().apply_linear(linear, out_channel_begin=div(linear._out_dim.global_size, 2)) - else: - return super().apply_linear(linear) - elif self.freeze_others: - linear.weight.requires_grad = False - return linear - - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - if self.freeze_others: - for parameter in module.parameters(): - parameter.requires_grad = False - return module - - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - if self.freeze_others: - parameter.requires_grad = False - return parameter + out_channel_begin = div(linear._out_dim.global_size, 2) + else: + enabled = False + return super().apply_linear(linear, enabled, out_channel_begin, out_channel_end) def _validate(self) -> None: super()._validate() diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py index ae8ce3ba4..4b06623ba 100644 --- a/fast_llm/layers/common/peft/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -5,23 +5,41 @@ from fast_llm.engine.base_model.config import BaseModelConfig if typing.TYPE_CHECKING: + import torch + from fast_llm.layers.common.linear import LinearBase, LinearLike + from fast_llm.layers.common.normalization.normalization import Normalization + from fast_llm.tensor import ParameterMeta @config_class() class PeftConfig(BaseModelConfig): @abc.abstractmethod - def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - pass + def apply_linear( + self, + module: "LinearBase", + enabled: bool, + out_channel_begin: int | None = None, + out_channel_end: int | None = None, + ) -> "LinearLike": + return 
self.apply_other(module) + + def apply_normalization(self, module: "Normalization") -> "Normalization": + return self.apply_other(module) + + def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": + for parameter in module.parameters(): + self.apply_weight(parameter) + return module + + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + return parameter @config_class() class NoPeftConfig(PeftConfig): _abstract = False - def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - return linear - @config_class() class LoRAConfig(PeftConfig): @@ -42,17 +60,34 @@ class LoRAConfig(PeftConfig): desc="Dropout rate for LoRA.", hint=FieldHint.stability, ) + freeze_others: bool = Field( + default=True, + desc="Whether to freeze other layers during training.", + ) + + def apply_linear( + self, + module: "LinearBase", + enabled: bool, + out_channel_begin: int | None = None, + out_channel_end: int | None = None, + ) -> "LinearLike": + if not enabled: + return self.apply_other(module) - def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - from fast_llm.layers.common.peft.lora import lora_linear + from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear + from fast_llm.layers.common.peft.peft import lora_linear + + if isinstance(module, InputParallelLinear): + # TODO: Support InputParallelLinear (different output format). + raise NotImplementedError("LoRA not supported for InputParallelLinear.") + elif isinstance(module, OutputParallelLinear): + assert out_channel_begin is None and out_channel_end is None # TODO: Init method? - return lora_linear( - linear, - linear.weight.param_init_method, - linear.weight.param_init_method, - self.rank, - self.alpha, - self.dropout, - **kwargs, - ) + return lora_linear(module, self.rank, self.alpha, self.dropout, out_channel_begin, out_channel_end) + + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + if self.freeze_others: + parameter.requires_grad = False + return parameter diff --git a/fast_llm/layers/common/peft/peft.py b/fast_llm/layers/common/peft/peft.py new file mode 100644 index 000000000..9e0ca0dd0 --- /dev/null +++ b/fast_llm/layers/common/peft/peft.py @@ -0,0 +1,88 @@ +import typing + +import torch + +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.functional.autograd import wrap_forward_backward +from fast_llm.layers.common.linear import Linear, LinearBase + + +def lora_linear( + module: LinearBase, + rank: int, + alpha: float, + dropout: float = 0.0, + out_channel_begin: int | None = None, + out_channel_end: int | None = None, +): + module.weight.requires_grad = False + in_dim = module._in_dim + assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." + if in_dim.parallel_dim is not None: + in_dim = TensorDim(in_dim.name, in_dim.global_size) + out_dim = module._out_dim + assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." + if out_dim.parallel_dim is not None: + out_dim = TensorDim(out_dim.name, out_dim.global_size) + if out_channel_begin is not None or out_channel_end is not None: + if out_channel_begin is None: + out_channel_begin = 0 + if out_channel_end is None: + out_channel_end = out_dim.global_size + # TODO: This won't work with TP. Use Composite dim structure for proper split? 
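Not part of the patch: a hedged usage sketch of the reworked PEFT entry points introduced here, assuming a `LoRAConfig` with the fields referenced above (`rank`, `alpha`, `dropout`, `freeze_others`) and pre-built Fast-LLM linear modules; the variable names are illustrative only.

```python
from fast_llm.layers.common.peft.config import LoRAConfig

# Illustrative sketch only; `query_proj` and `mlp` are assumed to be existing Fast-LLM modules.
peft = LoRAConfig(rank=8, alpha=16.0, dropout=0.05, freeze_others=True)

# A targeted projection is wrapped by lora_linear; the base weight is frozen inside the wrapper.
query_proj = peft.apply_linear(query_proj, enabled=True)

# Modules that are not adapted keep their structure; apply_other simply freezes their parameters
# (via apply_weight) when freeze_others is set.
mlp = peft.apply_other(mlp)
```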
+ out_dim = TensorDim(out_dim.name, out_channel_end - out_channel_begin) + + middle_dim = TensorDim("lora_middle", rank) + + module.lora_0 = Linear( + in_dim, + middle_dim, + bias=False, + weight_init_method=module.weight.param_init_method, + transposed_weight=module.transposed_weight, + lr_scale=module.weight.lr_scale, + ) + module.lora_1 = Linear( + middle_dim, + out_dim, + bias=False, + weight_init_method=module.weight.param_init_method, + transposed_weight=module.transposed_weight, + lr_scale=module.weight.lr_scale, + ) + # TODO: Implement proper backward pass. + module.lora_0.weight.auto_grad_accumulation = True + module.lora_1.weight.auto_grad_accumulation = True + + old_forward = module._forward + + def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + # TODO: torch compile? + input_ = input_.detach().requires_grad_() + with torch.enable_grad(): + output = old_forward(input_) + if isinstance(output, tuple): + layer_out, tp_bias = output[0] + assert tp_bias is None + lora_out = (alpha / rank) * module.lora_1( + module.lora_0(torch.dropout(input_, dropout, module.training) if dropout > 0.0 else input_) + ) + if out_channel_begin is None: + output = output + lora_out + else: + output.view(-1, layer_out.size(-1))[:, out_channel_begin:out_channel_end] += lora_out + return output.detach(), (input_, output) + + def backward( + grad_output: torch.Tensor, context: torch.Tensor + ) -> tuple[torch.Tensor, typing.Callable[[], None] | None]: + # TODO: Implement proper backward pass. + input_, output = context + output.backward(grad_output) + return input_.grad + + module._forward = wrap_forward_backward(forward_only, backward) + module.forward_only = forward_only + module.backward = backward + + return module From 4510b7b1a20aea4cb0348aeb233f704c2fcc30cf Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 20 Aug 2025 13:15:19 -0400 Subject: [PATCH 58/82] misc --- fast_llm/layers/common/peft/config.py | 2 +- fast_llm/layers/common/peft/lora.py | 44 +++++++------- fast_llm/layers/common/peft/peft.py | 88 --------------------------- 3 files changed, 22 insertions(+), 112 deletions(-) delete mode 100644 fast_llm/layers/common/peft/peft.py diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py index 4b06623ba..12e1810ff 100644 --- a/fast_llm/layers/common/peft/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -76,7 +76,7 @@ def apply_linear( return self.apply_other(module) from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear - from fast_llm.layers.common.peft.peft import lora_linear + from fast_llm.layers.common.peft.lora import lora_linear if isinstance(module, InputParallelLinear): # TODO: Support InputParallelLinear (different output format). diff --git a/fast_llm/layers/common/peft/lora.py b/fast_llm/layers/common/peft/lora.py index 87991ef29..9e0ca0dd0 100644 --- a/fast_llm/layers/common/peft/lora.py +++ b/fast_llm/layers/common/peft/lora.py @@ -8,21 +8,19 @@ def lora_linear( - layer: LinearBase, - init_method_0, - init_method_1, + module: LinearBase, rank: int, alpha: float, dropout: float = 0.0, out_channel_begin: int | None = None, out_channel_end: int | None = None, ): - layer.weight.requires_grad = False - in_dim = layer._in_dim + module.weight.requires_grad = False + in_dim = module._in_dim assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." 
if in_dim.parallel_dim is not None: in_dim = TensorDim(in_dim.name, in_dim.global_size) - out_dim = layer._out_dim + out_dim = module._out_dim assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." if out_dim.parallel_dim is not None: out_dim = TensorDim(out_dim.name, out_dim.global_size) @@ -36,27 +34,27 @@ def lora_linear( middle_dim = TensorDim("lora_middle", rank) - layer.lora_0 = Linear( + module.lora_0 = Linear( in_dim, middle_dim, bias=False, - weight_init_method=init_method_0, - transposed_weight=layer.transposed_weight, - lr_scale=layer.weight.lr_scale, + weight_init_method=module.weight.param_init_method, + transposed_weight=module.transposed_weight, + lr_scale=module.weight.lr_scale, ) - layer.lora_1 = Linear( + module.lora_1 = Linear( middle_dim, out_dim, bias=False, - weight_init_method=init_method_1, - transposed_weight=layer.transposed_weight, - lr_scale=layer.weight.lr_scale, + weight_init_method=module.weight.param_init_method, + transposed_weight=module.transposed_weight, + lr_scale=module.weight.lr_scale, ) # TODO: Implement proper backward pass. - layer.lora_0.weight.auto_grad_accumulation = True - layer.lora_1.weight.auto_grad_accumulation = True + module.lora_0.weight.auto_grad_accumulation = True + module.lora_1.weight.auto_grad_accumulation = True - old_forward = layer._forward + old_forward = module._forward def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: # TODO: torch compile? @@ -66,8 +64,8 @@ def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor if isinstance(output, tuple): layer_out, tp_bias = output[0] assert tp_bias is None - lora_out = (alpha / rank) * layer.lora_1( - layer.lora_0(torch.dropout(input_, dropout, layer.training) if dropout > 0.0 else input_) + lora_out = (alpha / rank) * module.lora_1( + module.lora_0(torch.dropout(input_, dropout, module.training) if dropout > 0.0 else input_) ) if out_channel_begin is None: output = output + lora_out @@ -83,8 +81,8 @@ def backward( output.backward(grad_output) return input_.grad - layer._forward = wrap_forward_backward(forward_only, backward) - layer.forward_only = forward_only - layer.backward = backward + module._forward = wrap_forward_backward(forward_only, backward) + module.forward_only = forward_only + module.backward = backward - return layer + return module diff --git a/fast_llm/layers/common/peft/peft.py b/fast_llm/layers/common/peft/peft.py deleted file mode 100644 index 9e0ca0dd0..000000000 --- a/fast_llm/layers/common/peft/peft.py +++ /dev/null @@ -1,88 +0,0 @@ -import typing - -import torch - -from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.layers.common.linear import Linear, LinearBase - - -def lora_linear( - module: LinearBase, - rank: int, - alpha: float, - dropout: float = 0.0, - out_channel_begin: int | None = None, - out_channel_end: int | None = None, -): - module.weight.requires_grad = False - in_dim = module._in_dim - assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." - if in_dim.parallel_dim is not None: - in_dim = TensorDim(in_dim.name, in_dim.global_size) - out_dim = module._out_dim - assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." 
- if out_dim.parallel_dim is not None: - out_dim = TensorDim(out_dim.name, out_dim.global_size) - if out_channel_begin is not None or out_channel_end is not None: - if out_channel_begin is None: - out_channel_begin = 0 - if out_channel_end is None: - out_channel_end = out_dim.global_size - # TODO: This won't work with TP. Use Composite dim structure for proper split? - out_dim = TensorDim(out_dim.name, out_channel_end - out_channel_begin) - - middle_dim = TensorDim("lora_middle", rank) - - module.lora_0 = Linear( - in_dim, - middle_dim, - bias=False, - weight_init_method=module.weight.param_init_method, - transposed_weight=module.transposed_weight, - lr_scale=module.weight.lr_scale, - ) - module.lora_1 = Linear( - middle_dim, - out_dim, - bias=False, - weight_init_method=module.weight.param_init_method, - transposed_weight=module.transposed_weight, - lr_scale=module.weight.lr_scale, - ) - # TODO: Implement proper backward pass. - module.lora_0.weight.auto_grad_accumulation = True - module.lora_1.weight.auto_grad_accumulation = True - - old_forward = module._forward - - def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - # TODO: torch compile? - input_ = input_.detach().requires_grad_() - with torch.enable_grad(): - output = old_forward(input_) - if isinstance(output, tuple): - layer_out, tp_bias = output[0] - assert tp_bias is None - lora_out = (alpha / rank) * module.lora_1( - module.lora_0(torch.dropout(input_, dropout, module.training) if dropout > 0.0 else input_) - ) - if out_channel_begin is None: - output = output + lora_out - else: - output.view(-1, layer_out.size(-1))[:, out_channel_begin:out_channel_end] += lora_out - return output.detach(), (input_, output) - - def backward( - grad_output: torch.Tensor, context: torch.Tensor - ) -> tuple[torch.Tensor, typing.Callable[[], None] | None]: - # TODO: Implement proper backward pass. 
- input_, output = context - output.backward(grad_output) - return input_.grad - - module._forward = wrap_forward_backward(forward_only, backward) - module.forward_only = forward_only - module.backward = backward - - return module From 9a2a7a27018f23f385d566bd9a94bf4affc02813 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 14:41:55 -0400 Subject: [PATCH 59/82] Pr comments --- fast_llm/layers/ssm/discrete_mamba2.py | 62 ++++++++++++++-------- fast_llm/layers/ssm/mamba2.py | 51 +++++++++++++----- fast_llm/layers/ssm/mamba_layer.py | 36 ++++++++----- fast_llm/layers/transformer/transformer.py | 2 +- 4 files changed, 104 insertions(+), 47 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index c9d555de9..b895412b5 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -4,14 +4,21 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.engine.config_utils.tensor_space import ( + CompositeTensorDim, + ConcatenatedTensorDim, + DefaultDimNames, + TensorDim, + TensorSpace, +) +from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ -from fast_llm.utils import get_lr_scale +from fast_llm.utils import div, get_lr_scale logger = logging.getLogger(__name__) @@ -49,25 +56,41 @@ def __init__( layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] hidden_dim = tensor_space[TransformerDimNames.hidden] - conv1d_dim = tensor_space[SSMDimNames.concatenated_convolution] - heads_dim = tensor_space[SSMDimNames.composite_heads] + state_dim = TensorDim("state", self._config.state_size) + v_head_size_dim = TensorDim("v_head_size", div(self._config.d_inner, self._config.n_v_heads)) + + head_groups_dim = TensorDim( + "head_groups", + self._config.n_qk_heads, + self._distributed_config.get_distributed_dim(DistributedDimNames.tensor), + ) + group_heads_dim = TensorDim("group_heads", div(self._config.n_v_heads, self._config.n_qk_heads)) + heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) + inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, v_head_size_dim)) + bc_dim = CompositeTensorDim("bc", (head_groups_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) + + inner_projection_dim = ConcatenatedTensorDim( + "inner_projection", + (inner_dim, bc_dim, bc_dim, inner_dim, heads_dim), + ) + convolution_dim = ConcatenatedTensorDim("convolution", (inner_dim, bc_dim, bc_dim)) # local_head_groups = head_groups / TP - self._local_head_groups = tensor_space[SSMDimNames.head_groups].size + self._local_head_groups = head_groups_dim.size # local_heads = local_head_groups * group_heads self._local_heads = heads_dim.size # local_inner_size 
= local_heads * head_size self._local_inner_size = inner_dim.size # local_bc_size = local_head_groups * state - self._local_bc_size = tensor_space[SSMDimNames.composite_head_groups_and_state].size + self._local_bc_size = bc_dim.size # TODO: double check initializations # Projections self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space[SSMDimNames.concatenated_inner_projection], + inner_projection_dim, bias=config.add_bias_linear, weight_init_method=init_kaiming_(transformer_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -82,15 +105,17 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( ( - conv1d_dim, + convolution_dim, tensor_space[DefaultDimNames.scalar], - tensor_space[SSMDimNames.convolution_kernel], + convolution_kernel_dim, + ), + init_method=init_uniform_centered_( + (convolution_dim.global_size * self._config.conv_kernel_dimension) ** -0.5 ), - init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (conv1d_dim,), + (convolution_dim,), init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), lr_scale=lr_scale, ) @@ -122,14 +147,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) # inner_projection : (batch/local_or_padded_sequence, local_sequence/batch, hidden) - # -> (batch/local_or_padded_sequence, local_sequence/batch, inner_projection) - # inner_projection: (batch, local_or_padded_sequence, hidden) -> (batch, padded_sequence, local_inner_size) + # -> (batch/padded_sequence, sequence/batch, local_inner_projection) inner_projection = self.in_proj(input_) - # Standardize to (batch, padded_sequence, inner_projection) + # Standardize to (batch, padded_sequence, local_inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) - print("QAIKOFNMJOWENM inner_projection", inner_projection.shape) xBC, z, A_log = torch.split( inner_projection, [ @@ -139,9 +162,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ ], dim=-1, ) - print("QAIKOFNMJOWENM xBC", xBC.shape, self._local_inner_size, self._local_bc_size) - print("QAIKOFNMJOWENM z", z.shape) - print("QAIKOFNMJOWENM A_log", A_log.shape) # Convolutional layer # xbc: (batch, padded_sequence, local_heads * head_size + 2 * local_head_groups * state) @@ -189,8 +209,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # out_proj: (batch/sequence, sequence/batch, local_heads * head_size) # -> (batch/local_sequence, local_sequence/batch, hidden) a, b = self.out_proj(y) - logger.info(f"EKFBN y {y.shape}") - logger.info(f"EKFBN a {a.shape}") return self.out_proj(y) @torch.compile diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 77c1b3869..babbe6e05 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,7 +3,14 @@ import torch -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import ( + CompositeTensorDim, + ConcatenatedTensorDim, + DefaultDimNames, + TensorDim, + TensorSpace, +) +from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import InputParallelLinear, Linear, 
OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames @@ -62,13 +69,33 @@ def __init__( layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim] - xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state] + num_heads = div(self._config.d_inner, self._config.state_size) + num_head_groups = div(self._config.d_xb, self._config.state_size) + hidden_dim: TensorDim = tensor_space[TransformerDimNames.hidden] - dt_rank_dim = tensor_space[SSMDimNames.dt_rank] + state_dim = TensorDim("state", self._config.state_size) + + head_groups_dim = TensorDim( + "head_groups", num_head_groups, self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + ) + group_heads_dim = TensorDim("group_heads", div(num_heads, num_head_groups)) + + heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) + + inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, state_dim)) + xb_dim = CompositeTensorDim("xb", (head_groups_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) - self._local_heads = tensor_space[SSMDimNames.composite_heads].size - self._local_head_groups = tensor_space[SSMDimNames.head_groups].size + # DT projection + dt_rank_dim = TensorDim("dt_rank", self._config.dt_rank) + + inner_projection_dim = ConcatenatedTensorDim( + "inner_projection", + (inner_dim, xb_dim, xb_dim, inner_dim), + ) + + self._local_heads = heads_dim.size + self._local_head_groups = head_groups_dim.size self._group_heads = div(self._local_heads, self._local_head_groups) self._local_inner_size = inner_dim.size self._local_xb_size = xb_dim.size @@ -78,7 +105,7 @@ def __init__( ( conv1d_dim, tensor_space[DefaultDimNames.scalar], - tensor_space[SSMDimNames.convolution_kernel], + convolution_kernel_dim, ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -90,7 +117,7 @@ def __init__( ) self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space[SSMDimNames.concatenated_inner_projection], + convolution_kernel_dim, bias=config.add_bias_linear, weight_init_method=init_kaiming_(transformer_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -122,7 +149,7 @@ def __init__( lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space[SSMDimNames.state]), + (inner_dim, convolution_kernel_dim), init_method=init_A(self._config.state_size, self._config.d_inner), lr_scale=lr_scale, weight_decay=False, @@ -139,7 +166,7 @@ def __init__( bias=config.add_bias_linear, weight_init_method=init_kaiming_(self._config.d_inner), sequence_parallel=self._sequence_parallel, - # TODO: lr_scale? 
+ lr_scale=lr_scale, ) def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: @@ -147,10 +174,10 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ assert _causal_conv1d_available # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) - # -> (batch/sequence, sequence/batch, inner_projection) + # -> (batch/sequence, sequence/batch, local_inner_projection) inner_projection = self.in_proj(input_) dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias - # Standardize to (batch, sequence, inner_projection) + # Standardize to (batch, sequence, local_inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) dt = dt.transpose(0, 1) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 9343ef1b8..061921b3d 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -4,14 +4,20 @@ import torch -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.engine.config_utils.tensor_space import ( + CompositeTensorDim, + ConcatenatedTensorDim, + DefaultDimNames, + TensorDim, + TensorSpace, +) from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import Assert, div, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -67,27 +73,33 @@ def __init__( self._config = config # TODO: It's not silu? Assert.eq(self._config.activation_type, ActivationType.silu) + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) # Tensor dims: - inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] hidden_dim = tensor_space[TransformerDimNames.hidden] - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + heads_dim = TensorDim("heads", div(self._config.d_inner, self._config.state_size)) + state_dim = TensorDim("state", self._config.state_size) + inner_dim = CompositeTensorDim("inner", (heads_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) + dt_rank_dim = TensorDim("dt_rank", self._config.dt_rank) + inner_projection_dim = ConcatenatedTensorDim("inner_projection", (inner_dim, inner_dim)) + x_projection_dim = ConcatenatedTensorDim("x_projection", (dt_rank_dim, state_dim, state_dim)) # TODO: Backward compatibility? - # TODO: lr_scale? 
self.in_proj = Linear( hidden_dim, - tensor_space[SSMDimNames.concatenated_inner_projection], + inner_projection_dim, bias=False, weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, ) self.conv1d_weight = ParameterMeta.from_dims( ( inner_dim, tensor_space[DefaultDimNames.scalar], - tensor_space[SSMDimNames.convolution_kernel], + convolution_kernel_dim, ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, @@ -95,7 +107,7 @@ def __init__( self.x_proj = Linear( inner_dim, - tensor_space[SSMDimNames.concatenated_x_projection], + x_projection_dim, weight_init_method=init_kaiming_(inner_dim.size), bias=False, lr_scale=lr_scale, @@ -104,7 +116,7 @@ def __init__( # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( - (inner_dim, tensor_space[SSMDimNames.dt_rank]), + (inner_dim, dt_rank_dim), init_method=init_kaiming_(self._config.dt_rank), lr_scale=lr_scale, ) @@ -116,7 +128,7 @@ def __init__( ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space[SSMDimNames.state]), + (inner_dim, state_dim), weight_decay=False, init_method=init_A(self._config.state_size, inner_dim.size), lr_scale=lr_scale, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 75d06f268..63f3aaab6 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -98,7 +98,7 @@ def __init__( self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space[TransformerDimNames.hidden] - # Note, layer_lr_scale does not impact the norms + # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) From 8c382a902b91feff0300476be0164773dc47a807 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 14:43:40 -0400 Subject: [PATCH 60/82] Cleanup --- fast_llm/layers/ssm/config.py | 81 ++--------------------------------- 1 file changed, 4 insertions(+), 77 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 9b0949d55..e6f87cf27 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -2,11 +2,10 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedDimNames +from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig -from fast_llm.utils import Assert, div +from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.tensor import Initializer @@ -212,77 +211,5 @@ def _validate(self) -> None: Assert.geq(self.dt_max, self.dt_min) def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - - # Head groups are configured differently depending on the block type. 
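Not part of the patch: a compact sketch of the dimension composition the SSM mixers now build locally (replacing the SSMDimNames / setup_tensor_space plumbing removed below), using the tensor-dim classes imported in the Mamba2 diff above; the concrete sizes are made up for illustration.

```python
from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim

# Example sizes only: state_size=16, d_xb=64 (4 head groups), d_inner=512 (32 heads).
state_dim = TensorDim("state", 16)
head_groups_dim = TensorDim("head_groups", 4)  # sharded over the tensor-parallel dim in the real layer
group_heads_dim = TensorDim("group_heads", 8)  # heads per group: 32 / 4

inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, state_dim))  # 4 * 8 * 16 = 512
xb_dim = CompositeTensorDim("xb", (head_groups_dim, state_dim))                         # 4 * 16 = 64

# The fused input projection concatenates the four chunks (d_inner, d_xb, d_xb, d_inner),
# i.e. 2 * 512 + 2 * 64 = 1152 channels, which forward() splits back apart.
inner_projection_dim = ConcatenatedTensorDim("inner_projection", (inner_dim, xb_dim, xb_dim, inner_dim))
```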
- if block_type == SSMBlockType.mamba: - num_heads = div(self.d_inner, self.state_size) - num_head_groups = num_heads - elif block_type == SSMBlockType.mamba2: - num_heads = div(self.d_inner, self.state_size) - num_head_groups = div(self.d_xb, self.state_size) - elif block_type == SSMBlockType.mamba2_discrete: - # TODO: Use different variables? - num_heads = self.n_v_heads - num_head_groups = self.n_qk_heads - else: - raise NotImplementedError(block_type) - - tensor_space.add_tensor_dim(state := TensorDim(SSMDimNames.state, self.state_size)) - if block_type == SSMBlockType.mamba2_discrete: - tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, div(self.d_inner, num_heads))) - else: - head_dim = state - - tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) - tensor_space.add_tensor_dim(group_heads := TensorDim(SSMDimNames.group_heads, div(num_heads, num_head_groups))) - tensor_space.add_tensor_dim( - heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) - ) - tensor_space.add_tensor_dim( - heads_and_head_dim := CompositeTensorDim( - SSMDimNames.composite_heads_and_head_dim, (head_groups, group_heads, head_dim) - ) - ) - tensor_space.add_tensor_dim( - head_groups_and_state := CompositeTensorDim( - SSMDimNames.composite_head_groups_and_state, (head_groups, state) - ) - ) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.convolution_kernel, self.conv_kernel_dimension)) - - # DT projection - if block_type in (SSMBlockType.mamba, SSMBlockType.mamba2): - tensor_space.add_tensor_dim(dt_rank := TensorDim(SSMDimNames.dt_rank, self.dt_rank)) - - if block_type == SSMBlockType.mamba: - tensor_space.add_tensor_dim( - ConcatenatedTensorDim(SSMDimNames.concatenated_x_projection, (dt_rank, state, state)) - ) - # TODO: Use composition instead - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_inner_projection, (heads_and_head_dim, heads_and_head_dim) - ) - ) - elif block_type == SSMBlockType.mamba2: - # TODO: Factor out state? - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_inner_projection, - (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim), - ) - ) - elif block_type == SSMBlockType.mamba2_discrete: - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_inner_projection, - (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim, heads), - ) - ) - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_convolution, - (heads_and_head_dim, head_groups_and_state, head_groups_and_state), - ) - ) + # Handled in the model. 
+ pass From 019e43dc6e95a4b2901b1ff3bd8dfacb65af961f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 14:50:13 -0400 Subject: [PATCH 61/82] Cleanup --- fast_llm/layers/ssm/mamba2.py | 38 +++++++++++++++++------------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index babbe6e05..0408479e5 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -13,7 +13,7 @@ from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer @@ -44,18 +44,6 @@ class Mamba2(Mixer): _mixer_name: typing.ClassVar[str] = "mamba_2" - _XZ_DIMS = ( - TransformerDimNames.batch, - SSMDimNames.composite_heads_and_head_dim, - TransformerDimNames.sequence_q, - ) - _BC_DIMS = ( - TransformerDimNames.batch, - SSMDimNames.composite_heads, - SSMDimNames.state, - TransformerDimNames.sequence_q, - ) - def __init__( self, config: SSMConfig, @@ -168,6 +156,18 @@ def __init__( sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) + if self._debug.enabled: + self._xz_dims = ( + TransformerDimNames.batch, + inner_dim, + TransformerDimNames.sequence_q, + ) + self._bc_dims = ( + TransformerDimNames.batch, + heads_dim, + state_dim, + TransformerDimNames.sequence_q, + ) def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available @@ -224,11 +224,11 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ dt = dt.transpose(1, 2) if self._debug_level: - self._debug_log(z, "z", self._XZ_DIMS, kwargs) - self._debug_log(x, "x", self._XZ_DIMS, kwargs) - self._debug_log(b, "b", self._BC_DIMS, kwargs) - self._debug_log(c, "c", self._BC_DIMS, kwargs) - self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) + self._debug_log(z, "z", self._xz_dims, kwargs) + self._debug_log(x, "x", self._xz_dims, kwargs) + self._debug_log(b, "b", self._bc_dims, kwargs) + self._debug_log(c, "c", self._bc_dims, kwargs) + self._debug_log(dt, "dt", self._xz_dims, kwargs) y = selective_scan_fn( x, @@ -243,7 +243,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ ) if self._debug_level: - self._debug_log(y, "y", self._XZ_DIMS, kwargs) + self._debug_log(y, "y", self._xz_dims, kwargs) # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) y = y.transpose(1, 2)[:, :sequence_length] From 3e0f3e555ab92eca7a62378d6a5ad366f5118bda Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 14:50:38 -0400 Subject: [PATCH 62/82] Cleanup --- fast_llm/layers/ssm/config.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index e6f87cf27..fb178e7d5 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -11,28 +11,6 @@ from fast_llm.tensor import Initializer -class SSMDimNames: - # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. 
- state = "ssm_state" # State dimension (N), aka head size / num channels - head_dim = "ssm_head_dim" - head_groups = "ssm_head_groups" - group_heads = "ssm_group_heads" - - convolution_kernel = "ssm_convolution_kernel" # Kernel dimension of the conv1d in mamba layers - - dt_rank = "ssm_dt_rank" - - # Composite dimensions - composite_heads = "ssm_composite_heads" - composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim" - composite_head_groups_and_state = "ssm_composite_head_groups_and_state" - - # Concatenated dimensions - concatenated_convolution = "ssm_concatenated_convolution" - concatenated_x_projection = "ssm_x_concatenated_x_projection" - concatenated_inner_projection = "ssm_concatenated_inner_projection" - - class SSMBlockType(enum.StrEnum): """ An enum for the available mamba types for the MLP layer. From 39960cee17e7a047ac4cc7f3e6d33b0fa631ae5f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 15:01:37 -0400 Subject: [PATCH 63/82] Cleanup --- fast_llm/layers/ssm/mamba.py | 4 ++-- fast_llm/layers/ssm/mamba2.py | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 79a0e5c8e..453c14af6 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -68,7 +68,7 @@ def __init__( lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) - assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" + assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for Mamba" # TODO: It's not silu? Assert.eq(self._config.activation_type, ActivationType.silu) @@ -84,12 +84,12 @@ def __init__( lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) # TODO: Backward compatibility? - # TODO: lr_scale? self.in_proj = Linear( hidden_dim, inner_projection_dim, bias=False, weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, ) self.conv1d_weight = ParameterMeta.from_dims( ( diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index eec134a22..2659e415f 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -149,16 +149,16 @@ def __init__( bias=config.add_bias_linear, weight_init_method=init_kaiming_(self._config.d_inner), sequence_parallel=self._sequence_parallel, - # TODO: lr_scale? 
+ lr_scale=lr_scale, ) if self._debug.enabled: - _xz_dims = ( + self._xz_dims = ( BlockDimNames.batch, inner_dim, BlockDimNames.sequence_q, ) - _bc_dims = ( + self._bc_dims = ( BlockDimNames.batch, heads_dim, state_dim, @@ -176,10 +176,10 @@ def forward( assert _causal_conv1d_available # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) - # -> (batch/sequence, sequence/batch, inner_projection) + # -> (batch/sequence, sequence/batch, local_inner_projection) inner_projection = self.in_proj(input_) dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias - # Standardize to (batch, sequence, inner_projection) + # Standardize to (batch, sequence, local_inner_projection) if kwargs[BlockKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) dt = dt.transpose(0, 1) @@ -226,11 +226,11 @@ def forward( dt = dt.transpose(1, 2) if self._debug.enabled: - self._debug(z, "z", self._XZ_DIMS, kwargs) - self._debug(x, "x", self._XZ_DIMS, kwargs) - self._debug(b, "b", self._BC_DIMS, kwargs) - self._debug(c, "c", self._BC_DIMS, kwargs) - self._debug(dt, "dt", self._XZ_DIMS, kwargs) + self._debug(z, "z", self._xz_dims, kwargs) + self._debug(x, "x", self._xz_dims, kwargs) + self._debug(b, "b", self._bc_dims, kwargs) + self._debug(c, "c", self._bc_dims, kwargs) + self._debug(dt, "dt", self._xz_dims, kwargs) y = selective_scan_fn( x, @@ -245,7 +245,7 @@ def forward( ) if self._debug.enabled: - self._debug(y, "y", self._XZ_DIMS, kwargs) + self._debug(y, "y", self._xz_dims, kwargs) # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) y = y.transpose(1, 2)[:, :sequence_length] From 1abdd19280f8cf7104236be375aa14ceeea235ee Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 15:09:24 -0400 Subject: [PATCH 64/82] fixes --- fast_llm/layers/common/config.py | 4 ++-- fast_llm/layers/ssm/discrete_mamba2.py | 2 +- fast_llm/layers/ssm/mamba2.py | 6 ++++-- fast_llm/layers/transformer/transformer.py | 5 +++-- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 07dadbc22..710b2668f 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -43,7 +43,7 @@ class NormalizationConfig(BaseModelConfig): pass @abc.abstractmethod - def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": + def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None) -> "torch.nn.Module": pass @classmethod @@ -63,7 +63,7 @@ def _from_dict( class NoNormalizationConfig(NormalizationConfig): _abstract = False - def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": + def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None) -> "torch.nn.Module": return torch.nn.Identity() diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index b895412b5..47a94214a 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -63,7 +63,7 @@ def __init__( head_groups_dim = TensorDim( "head_groups", self._config.n_qk_heads, - self._distributed_config.get_distributed_dim(DistributedDimNames.tensor), + self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor), ) group_heads_dim = TensorDim("group_heads", div(self._config.n_v_heads, self._config.n_qk_heads)) heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 
0408479e5..95febb1c6 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -64,7 +64,9 @@ def __init__( state_dim = TensorDim("state", self._config.state_size) head_groups_dim = TensorDim( - "head_groups", num_head_groups, self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + "head_groups", + num_head_groups, + self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor), ) group_heads_dim = TensorDim("group_heads", div(num_heads, num_head_groups)) @@ -105,7 +107,7 @@ def __init__( ) self.in_proj = OutputParallelLinear( hidden_dim, - convolution_kernel_dim, + inner_projection_dim, bias=config.add_bias_linear, weight_init_method=init_kaiming_(transformer_config.hidden_size), sequence_parallel=self._sequence_parallel, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 63f3aaab6..c7becd948 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -100,8 +100,9 @@ def __init__( hidden_dim = self._tensor_space[TransformerDimNames.hidden] # TODO: add a separate norm_lr_scale - self.norm_1 = self._config.normalization.get_layer(hidden_dim) - self.norm_2 = self._config.normalization.get_layer(hidden_dim) + lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + self.norm_1 = self._config.normalization.get_layer(hidden_dim, lr_scale) + self.norm_2 = self._config.normalization.get_layer(hidden_dim, lr_scale) # The mixer needs to be created here for backward-compatible weight ordering. setattr(self, self._mixer_module_name, self._create_mixer()) From 7c2429292e5b56a763d39d67096d2931e657098d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 15:49:49 -0400 Subject: [PATCH 65/82] fixes --- fast_llm/layers/ssm/mamba2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 95febb1c6..802d757eb 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -158,7 +158,7 @@ def __init__( sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - if self._debug.enabled: + if self._debug_level: self._xz_dims = ( TransformerDimNames.batch, inner_dim, From af2964bfee592db2a59789cde413deba3acd3d1d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 15:59:44 -0400 Subject: [PATCH 66/82] fixes --- fast_llm/layers/ssm/mamba2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 802d757eb..7151da394 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -139,7 +139,7 @@ def __init__( lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (inner_dim, convolution_kernel_dim), + (inner_dim, state_dim), init_method=init_A(self._config.state_size, self._config.d_inner), lr_scale=lr_scale, weight_decay=False, From 654aeeb4be24eb64fba6f3885d72ebcf4992d532 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 16:09:54 -0400 Subject: [PATCH 67/82] Fix merge --- fast_llm/layers/block/block.py | 10 +++++++--- fast_llm/layers/ssm/discrete_mamba2.py | 6 ++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index b8aad3903..f90fce698 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -163,8 +163,12 @@ def __init__( 
self._return_input: bool = return_input # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale - self.norm_1 = self._config.peft.apply_other(self._config.normalization.get_layer(self._hidden_dim)) - self.norm_2 = self._config.peft.apply_other(self._config.normalization.get_layer(self._hidden_dim)) + self.norm_1 = self._config.peft.apply_other( + self._config.normalization.get_layer(self._hidden_dim, self._lr_scale) + ) + self.norm_2 = self._config.peft.apply_other( + self._config.normalization.get_layer(self._hidden_dim, self._lr_scale) + ) # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. setattr( @@ -192,7 +196,7 @@ def __init__( self._hidden_dim, self._block_index, f"{self._name} MLP", - lr_scale, + self._lr_scale, ) @functools.cached_property diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 0d91fbaff..f9462a942 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -150,11 +150,9 @@ def forward( assert not kwargs[BlockKwargs.sequence_first] and input_.size(1) == sequence_length input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) - # inner_projection : (batch/local_or_padded_sequence, local_sequence/batch, hidden) - # -> (batch/local_or_padded_sequence, local_sequence/batch, inner_projection) - # inner_projection: (batch, local_or_padded_sequence, hidden) -> (batch, padded_sequence, local_inner_size) + # -> (batch/padded_sequence, sequence/batch, local_inner_projection inner_projection = self.in_proj(input_) - # Standardize to (batch, padded_sequence, inner_projection) + # Standardize to (batch, padded_sequence, local_inner_projection) if kwargs[BlockKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) From 3f4a8ba8600adba99f791d490c0588f10221068e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 Aug 2025 15:52:35 -0400 Subject: [PATCH 68/82] fix --- fast_llm/layers/common/peft/config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py index 12e1810ff..64a2ca57a 100644 --- a/fast_llm/layers/common/peft/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -1,4 +1,3 @@ -import abc import typing from fast_llm.config import Field, FieldHint, config_class @@ -14,7 +13,6 @@ @config_class() class PeftConfig(BaseModelConfig): - @abc.abstractmethod def apply_linear( self, module: "LinearBase", From 9741ba047eea2a8fef7e1bd227f3046d1d075a20 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 Aug 2025 16:39:03 -0400 Subject: [PATCH 69/82] stuff --- fast_llm/engine/config_utils/parameter.py | 67 +++++ fast_llm/engine/multi_stage/stage.py | 2 - fast_llm/layers/attention/attention.py | 25 +- fast_llm/layers/attention/config.py | 18 ++ fast_llm/layers/block/mlp/config.py | 19 +- .../layers/block/mlp/mixture_of_experts.py | 7 +- fast_llm/layers/block/mlp/mlp.py | 20 +- fast_llm/layers/block/peft.py | 2 +- fast_llm/layers/common/linear/__init__.py | 0 fast_llm/layers/common/linear/config.py | 179 ++++++++++++ fast_llm/layers/common/linear/convolution.py | 53 ++++ fast_llm/layers/common/{ => linear}/linear.py | 114 ++------ .../common/normalization/normalization.py | 3 - fast_llm/layers/common/peft/config.py | 8 +- fast_llm/layers/common/peft/lora.py | 5 +- fast_llm/layers/language_model/config.py | 17 +- fast_llm/layers/language_model/embedding.py | 14 +- 
fast_llm/layers/ssm/config.py | 262 +++++++++++++----- fast_llm/layers/ssm/discrete_mamba2.py | 91 ++---- fast_llm/layers/ssm/mamba.py | 88 +++--- fast_llm/layers/ssm/mamba2.py | 97 +++---- fast_llm/models/ssm/config.py | 15 +- fast_llm/tensor.py | 5 - tests/models/distributed_test_model.py | 1 + tests/models/test_checkpoint.py | 1 + tests/utils/model_configs.py | 3 + 26 files changed, 698 insertions(+), 418 deletions(-) create mode 100644 fast_llm/engine/config_utils/parameter.py create mode 100644 fast_llm/layers/common/linear/__init__.py create mode 100644 fast_llm/layers/common/linear/config.py create mode 100644 fast_llm/layers/common/linear/convolution.py rename fast_llm/layers/common/{ => linear}/linear.py (53%) diff --git a/fast_llm/engine/config_utils/parameter.py b/fast_llm/engine/config_utils/parameter.py new file mode 100644 index 000000000..aa84408d2 --- /dev/null +++ b/fast_llm/engine/config_utils/parameter.py @@ -0,0 +1,67 @@ +import typing + +from fast_llm.config import Config, Field, config_class +from fast_llm.engine.config_utils.initialization import Initializer +from fast_llm.engine.config_utils.tensor_dim import TensorDim + +if typing.TYPE_CHECKING: + from fast_llm.tensor import ParameterMeta + + +@config_class() +class ParameterConfig(Config): + # TODO: Initialization, lr_scale + + def _validate(self) -> None: + pass + + def get_parameter( + self, + dims: tuple[TensorDim, ...], + default_initializer: Initializer, + lr_scale: float | None, + weight_decay: bool = True, + allow_sequence_tensor_parallel: bool = True, + ) -> "ParameterMeta": + from fast_llm.tensor import ParameterMeta + + return ParameterMeta.from_dims( + dims, + init_method=default_initializer, + lr_scale=lr_scale, + weight_decay=weight_decay, + allow_sequence_tensor_parallel=allow_sequence_tensor_parallel, + ) + + +@config_class() +class OptionalParameterConfig(ParameterConfig): + enabled: bool | None = Field( + default=None, + ) + # TODO: Initialization, lr_scale + + def _validate(self) -> None: + pass + + def get_parameter( + self, + dims: tuple[TensorDim, ...], + default_initializer: Initializer, + lr_scale: float | None, + weight_decay: bool = True, + allow_sequence_tensor_parallel: bool = True, + default_enabled: bool = False, + ) -> "ParameterMeta|None": + from fast_llm.tensor import ParameterMeta + + if (self.enabled is None and default_enabled) or self.enabled: + return ParameterMeta.from_dims( + dims, + init_method=default_initializer, + lr_scale=lr_scale, + weight_decay=weight_decay, + allow_sequence_tensor_parallel=allow_sequence_tensor_parallel, + ) + else: + return None diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 35547cd87..40ef07f67 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -22,8 +22,6 @@ def _accumulate_grad_hook(buffer: torch.nn.Parameter, meta: ParameterMeta) -> typing.Callable[[tuple, tuple], None]: def hook(grad_inputs, grad_outputs): # noqa if buffer.grad is not None: - if not meta.auto_grad_accumulation: - raise RuntimeError(f"Unexpected grad for parameter {meta.tensor_name}") accumulate_gradient(buffer, buffer.grad) buffer.grad = None diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 8a4c490c9..74cfb6ed4 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -4,7 +4,7 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, 
reduce_scatter_op, swap_mult_dim -from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ +from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward @@ -12,7 +12,6 @@ from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockDimNames from fast_llm.layers.block.peft import TransformerSubLayerName -from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.utils import combine_lr_scales, div try: @@ -112,21 +111,20 @@ def __init__( ) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) - self.query = OutputParallelLinear( + self.query = self._config.query_layer.get_layer( hidden_dim, query_dim, - bias=self._config.add_qkv_bias, - weight_init_method=init_method_qkv, - bias_init_method=init_zeros_, + default_weight_initializer=init_method_qkv, + default_add_bias=self._config.add_qkv_bias, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - self.key_value = OutputParallelLinear( + # TODO: Use value config. + self.key_value = self._config.query_layer.get_layer( hidden_dim, key_value_dim, - bias=self._config.add_qkv_bias, - weight_init_method=init_method_qkv, - bias_init_method=init_zeros_, + default_weight_initializer=init_method_qkv, + default_add_bias=self._config.add_qkv_bias, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -136,12 +134,11 @@ def __init__( self._rotary = self._config.rotary.get_layer(kv_channels_dim) # Output. - self.dense = InputParallelLinear( + self.dense = self._config.dense_layer.get_layer( dense_dim, hidden_dim, - bias=self._config.add_dense_bias, - weight_init_method=init_method_std_attn_proj, - bias_init_method=init_zeros_, + default_weight_initializer=init_method_std_attn_proj, + default_add_bias=self._config.add_dense_bias, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index e5c638adc..8e4226270 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -8,6 +8,7 @@ from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockKwargs +from fast_llm.layers.common.linear.config import AffineLinearConfig from fast_llm.utils import Assert, div logger = logging.getLogger(__name__) @@ -32,6 +33,23 @@ class AttentionConfig(Config): # TODO: Make mixer class dynamic. 
_abstract = False + query_layer: AffineLinearConfig = Field( + desc="Configuration for the query layer.", + hint=FieldHint.architecture, + ) + key_layer: AffineLinearConfig = Field( + desc="Configuration for the key layer.", + hint=FieldHint.architecture, + ) + # TODO: Use + value_layer: AffineLinearConfig = Field( + desc="Configuration for the value layer.", + hint=FieldHint.architecture, + ) + dense_layer: AffineLinearConfig = Field( + desc="Initialization configuration for the dense layer.", + hint=FieldHint.feature, + ) # TODO: Review names rotary: RotaryConfig = Field( desc="Configuration for the rotary positional embeddings.", diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 88ce4af10..2a4d8e81f 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -3,6 +3,7 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.functional.config import ActivationType, MLPRecomputeLevel +from fast_llm.layers.common.linear.config import AffineLinearConfig, LinearConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -21,8 +22,24 @@ class RoutingType(str, enum.Enum): @config_class() class MLPConfig(Config): - # TODO: Review names # TODO: Separate MoE? + # TODO: Review names + # TODO: Separate MoE? _abstract = False + # TODO: Configure experts, gate/up separately? + layer_1: AffineLinearConfig = Field( + desc="Configuration for the first MLP layer.", + hint=FieldHint.architecture, + ) + # TODO: Separate gate and up + layer_2: AffineLinearConfig = Field( + desc="Configuration for the second MLP layer.", + hint=FieldHint.architecture, + ) + router: LinearConfig = Field( + # TODO: Improve default? + desc="Configuration for the MoE router.", + hint=FieldHint.feature, + ) ffn_hidden_size: int = Field( default=None, desc="Hidden dimension of the MLP intermediate state. Default: 4 * hidden_size.", diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 4f7cf2dc4..9298e872b 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -13,7 +13,6 @@ from fast_llm.layers.block.mlp.config import MLPConfig, MLPLossNames, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss -from fast_llm.layers.common.linear import Linear from fast_llm.utils import Assert, combine_lr_scales logger = logging.getLogger(__name__) @@ -47,12 +46,10 @@ def __init__( # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." 
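# Illustrative sketch only: the `router` built a few lines below produces one logit per
# (unshared) expert. A generic top-k routing step over such logits looks like the snippet
# below (softmax, keep the k best experts per token). This is the standard formulation
# implied by `RoutingType.topk`, not a copy of the Fast-LLM forward pass, which is outside
# this diff; names and sizes here are made up.
import torch


def topk_route(router_logits: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
    # router_logits: (tokens, num_experts) -> per-token expert weights and indices, both (tokens, k).
    scores = torch.softmax(router_logits, dim=-1)
    top_scores, top_experts = scores.topk(k, dim=-1)
    return top_scores, top_experts


weights, experts = topk_route(torch.randn(8, 4), k=2)  # 8 tokens, 4 experts (toy sizes)
assert weights.shape == experts.shape == (8, 2)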
super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) - - self.router = Linear( + self.router = self._config.router.get_layer( self._hidden_dim, TensorDim("router_experts", self._config.num_unshared_experts), - bias=False, - weight_init_method=init_normal_( + default_weight_initializer=init_normal_( std=self._block_config.init_method_std, min_val=self._block_config.init_method_min, max_val=self._block_config.init_method_max, diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index c3a714a42..4c79cf9de 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -2,7 +2,7 @@ import torch -from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ +from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import TritonConfig @@ -11,7 +11,6 @@ from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerSubLayerName -from fast_llm.layers.common.linear import LinearBase from fast_llm.utils import Assert, combine_lr_scales @@ -46,21 +45,20 @@ def __init__( lr_scale = combine_lr_scales(self._lr_scale, self._config.mlp_lr_scale) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) - self.layer_1 = LinearBase( + self.layer_1 = self._config.layer_1.get_layer( hidden_dim, intermediate_1_dim, - bias=self._config.add_mlp_bias, - weight_init_method=init_method_1, - bias_init_method=init_zeros_, + default_weight_initializer=init_method_1, + default_add_bias=self._config.add_mlp_bias, + sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - self.layer_2 = LinearBase( + self.layer_2 = self._config.layer_1.get_layer( intermediate_2_dim, hidden_dim, - bias=self._config.add_mlp_bias, - weight_init_method=init_method_2, - bias_init_method=init_zeros_, - auto_bias_grad_accumulation=self._distributed_config.tensor_parallel > 1, + default_weight_initializer=init_method_2, + default_add_bias=self._config.add_mlp_bias, + sequence_parallel=self._sequence_parallel, transposed_weight=True, lr_scale=lr_scale, ) diff --git a/fast_llm/layers/block/peft.py b/fast_llm/layers/block/peft.py index b51d352bc..ffa40a255 100644 --- a/fast_llm/layers/block/peft.py +++ b/fast_llm/layers/block/peft.py @@ -10,7 +10,7 @@ from fast_llm.utils import div if typing.TYPE_CHECKING: - from fast_llm.layers.common.linear import LinearBase, LinearLike + from fast_llm.layers.common.linear.linear import LinearBase, LinearLike class TransformerSubLayerName(str, enum.Enum): diff --git a/fast_llm/layers/common/linear/__init__.py b/fast_llm/layers/common/linear/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py new file mode 100644 index 000000000..776a11925 --- /dev/null +++ b/fast_llm/layers/common/linear/config.py @@ -0,0 +1,179 @@ +import typing + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.engine.config_utils.initialization import Initializer, init_uniform_centered_, init_zeros_ +from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig +from fast_llm.engine.config_utils.tensor_dim import 
TensorDim, scalar_dim +from fast_llm.functional.config import ActivationType +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + from fast_llm.layers.common.linear.convolution import CausalConv1d + from fast_llm.layers.common.linear.linear import LinearBase + + +@config_class() +class LinearBaseConfig(Config): + """ + Configuration for a linear-like layer without bias. + """ + + weight: ParameterConfig = Field( + desc="Initialization configuration for the weight.", + hint=FieldHint.feature, + ) + + +@config_class() +class AffineLinearBaseConfig(LinearBaseConfig): + """ + Configuration for a linear-like layer with optional bias. + """ + + bias: OptionalParameterConfig = Field( + desc="Use bias.", + hint=FieldHint.architecture, + ) + + +@config_class() +class LinearConfig(LinearBaseConfig): + def get_layer( + self, + in_dim: TensorDim, + out_dim: TensorDim, + *, + default_weight_initializer: Initializer, + sequence_parallel: bool = False, + transposed_weight: bool = False, + lr_scale: float | None, + ) -> "LinearBase": + from fast_llm.layers.common.linear.linear import InputParallelLinear, Linear, OutputParallelLinear + + weight = self.weight.get_parameter( + (in_dim, out_dim) if transposed_weight else (out_dim, in_dim), + default_initializer=default_weight_initializer, + lr_scale=lr_scale, + ) + if in_dim.parallel_dim is not None: + assert out_dim.parallel_dim is None + return InputParallelLinear( + weight, + None, + transposed_weight=transposed_weight, + parallel_dim=in_dim.parallel_dim, + sequence_parallel=sequence_parallel, + ) + elif out_dim.parallel_dim is not None: + return OutputParallelLinear( + weight, + None, + transposed_weight=transposed_weight, + parallel_dim=out_dim.parallel_dim, + sequence_parallel=sequence_parallel, + ) + else: + assert not sequence_parallel + return Linear(weight, None, transposed_weight=transposed_weight) + + +@config_class() +class AffineLinearConfig(AffineLinearBaseConfig, LinearConfig): + def get_layer( + self, + in_dim: TensorDim, + out_dim: TensorDim, + *, + default_weight_initializer: Initializer, + default_bias_initializer: Initializer = init_zeros_, + default_add_bias: bool = True, + sequence_parallel: bool = False, + transposed_weight: bool = False, + lr_scale: float | None, + ) -> "LinearBase": + from fast_llm.layers.common.linear.linear import InputParallelLinear, Linear, OutputParallelLinear + + weight = self.weight.get_parameter( + (in_dim, out_dim) if transposed_weight else (out_dim, in_dim), + default_initializer=default_weight_initializer, + lr_scale=lr_scale, + ) + bias = self.bias.get_parameter( + (out_dim,), + default_initializer=default_bias_initializer, + lr_scale=lr_scale, + default_enabled=default_add_bias, + ) + if in_dim.parallel_dim is not None: + assert out_dim.parallel_dim is None + return InputParallelLinear( + weight, + bias, + transposed_weight=transposed_weight, + parallel_dim=in_dim.parallel_dim, + sequence_parallel=sequence_parallel, + ) + elif out_dim.parallel_dim is not None: + return OutputParallelLinear( + weight, + bias, + transposed_weight=transposed_weight, + parallel_dim=out_dim.parallel_dim, + sequence_parallel=sequence_parallel, + ) + else: + assert not sequence_parallel + return Linear(weight, bias, transposed_weight=transposed_weight) + + +@config_class() +class CausalConv1dConfig(AffineLinearBaseConfig): + """ + Configuration for a 1d causal convolution, as used in mamba layers. 
+    """
+
+    kernel_size: int = Field(
+        default=4,
+        desc="Convolution kernel size.",
+        hint=FieldHint.architecture,
+        valid=check_field(Assert.gt, 0),
+    )
+    activation: ActivationType | None = Field(
+        default=None,
+        hint=FieldHint.architecture,
+    )
+
+    def get_layer(
+        self,
+        in_dim: TensorDim,
+        *,
+        default_weight_initializer: Initializer | None = None,
+        default_bias_initializer: Initializer | None = None,
+        default_add_bias: bool = True,
+        default_activation: ActivationType = ActivationType.identity,
+        lr_scale: float | None,
+    ) -> "CausalConv1d":
+        from fast_llm.layers.common.linear.convolution import CausalConv1d
+
+        kernel_dim = TensorDim("convolution_kernel", self.kernel_size)
+
+        if default_weight_initializer is None:
+            default_weight_initializer = init_uniform_centered_((in_dim.global_size * kernel_dim.global_size) ** -0.5)
+        if default_bias_initializer is None:
+            default_bias_initializer = init_uniform_centered_((in_dim.global_size * kernel_dim.global_size) ** -0.5)
+
+        weight = self.weight.get_parameter(
+            (in_dim, scalar_dim, kernel_dim),
+            default_initializer=default_weight_initializer,
+            lr_scale=lr_scale,
+        )
+        bias = self.bias.get_parameter(
+            (in_dim,),
+            default_initializer=default_bias_initializer,
+            lr_scale=lr_scale,
+            default_enabled=default_add_bias,
+        )
+        return CausalConv1d(
+            weight, bias, activation=default_activation if self.activation is None else self.activation
+        )
diff --git a/fast_llm/layers/common/linear/convolution.py b/fast_llm/layers/common/linear/convolution.py
new file mode 100644
index 000000000..57fdccfd5
--- /dev/null
+++ b/fast_llm/layers/common/linear/convolution.py
@@ -0,0 +1,53 @@
+import torch
+
+from fast_llm.functional.config import ActivationType
+from fast_llm.tensor import ParameterMeta
+
+try:
+    from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn  # noqa
+
+    _causal_conv1d_available = True
+except (ImportError, RuntimeError):
+    _causal_conv1d_available = False
+
+
+class CausalConv1d(torch.nn.Module):
+    """
+    TODO: Generalize to other convolutions?
+ """ + + def __init__( + self, + weight: ParameterMeta, + bias: ParameterMeta | None, + *, + activation: ActivationType = ActivationType.identity, + ): + super().__init__() + self.weight = weight + self.bias = bias + self._activation = activation + self.forward = ( + self._forward_causal_conv1d + if _causal_conv1d_available and self._activation in (ActivationType.identity, ActivationType.silu) + else self._forward_torch + ) + + def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: + return self._activation.activation_fn( + torch.nn.functional.conv1d( + input_, + self.weight, + bias=self.bias, + groups=self.weight.size(0), + padding=self.weight.size(2) - 1, + )[..., : input_.size(1)] + ) + + def _forward_causal_conv1d(self, input_: torch.Tensor) -> torch.Tensor: + return _causal_conv1d_fn( + input_, + self.weight.squeeze(1), + self.bias, + activation=(None if self._activation == ActivationType.identity else self._activation.value), + ) diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear/linear.py similarity index 53% rename from fast_llm/layers/common/linear.py rename to fast_llm/layers/common/linear/linear.py index ca807e67c..631193249 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear/linear.py @@ -3,8 +3,7 @@ import torch -from fast_llm.engine.config_utils.initialization import init_zeros_ -from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.linear import ( input_parallel_linear_autograd, @@ -42,37 +41,15 @@ class LinearBase(LinearLike): def __init__( self, - in_dim: TensorDim, - out_dim: TensorDim, + weight: ParameterMeta, + bias: ParameterMeta | None, *, - bias=True, - weight_init_method, - bias_init_method=init_zeros_, transposed_weight: bool = False, - auto_bias_grad_accumulation: bool = False, - lr_scale: float | None | tuple[float | None, ...] = None, ): super().__init__() + self.weight = weight + self.bias = bias self._transposed_weight = transposed_weight - self._in_dim = in_dim - self._out_dim = out_dim - self._weight_init_method = weight_init_method - self.weight = ParameterMeta.from_dims( - (self._in_dim, self._out_dim) if self._transposed_weight else (self._out_dim, self._in_dim), - init_method=weight_init_method, - auto_grad_accumulation=False, - lr_scale=lr_scale, - ) - if bias: - self.bias = ParameterMeta.from_dims( - (self._out_dim,), - init_method=bias_init_method, - weight_decay=False, - auto_grad_accumulation=auto_bias_grad_accumulation, - lr_scale=lr_scale, - ) - else: - self.bias = None @property def transposed_weight(self) -> bool: @@ -84,29 +61,6 @@ class Linear(LinearBase): A basic linear layer without tensor parallelism. """ - def __init__( - self, - in_dim: TensorDim, - out_dim: TensorDim, - *, - bias=True, - weight_init_method, - bias_init_method=init_zeros_, - transposed_weight: bool = False, - lr_scale: float | None | tuple[float | None, ...] 
= None, - ): - assert not in_dim.is_parallel - assert not out_dim.is_parallel - super().__init__( - in_dim, - out_dim, - bias=bias, - weight_init_method=weight_init_method, - bias_init_method=bias_init_method, - transposed_weight=transposed_weight, - lr_scale=lr_scale, - ) - def forward_only( self, input_: torch.Tensor ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]]: @@ -123,35 +77,23 @@ class OutputParallelLinear(LinearBase): def __init__( self, - in_dim: TensorDim, - out_dim: TensorDim, + weight: ParameterMeta, + bias: ParameterMeta | None, *, - bias=True, - weight_init_method, - bias_init_method=init_zeros_, transposed_weight: bool = False, + parallel_dim: DistributedDim, sequence_parallel: bool = False, - lr_scale: float | None | tuple[float | None, ...] = None, ): - assert not in_dim.is_parallel - self._group_size = 1 if out_dim.parallel_dim is None else out_dim.parallel_dim.size - self._sequence_parallel = sequence_parallel and self._group_size > 1 - super().__init__( - in_dim, - out_dim, - bias=bias, - weight_init_method=weight_init_method, - bias_init_method=bias_init_method, - transposed_weight=transposed_weight, - lr_scale=lr_scale, - ) + super().__init__(weight, bias, transposed_weight=transposed_weight) + self._parallel_dim = parallel_dim + self._sequence_parallel = sequence_parallel and self._parallel_dim.size > 1 def forward_only(self, input_) -> tuple[torch.Tensor, tuple[typing.Any, ...]]: return output_parallel_linear_forward( input_, weight=self.weight, bias=self.bias, - group=self._out_dim.parallel_group, + group=self._parallel_dim.group, sequence_parallel=self._sequence_parallel, transposed_weight=self._transposed_weight, ) @@ -167,30 +109,16 @@ class InputParallelLinear(LinearBase): def __init__( self, - in_dim: TensorDim, - out_dim: TensorDim, + weight: ParameterMeta, + bias: ParameterMeta | None, *, - bias=True, - weight_init_method, - bias_init_method=init_zeros_, - sequence_parallel: bool = False, transposed_weight: bool = False, - lr_scale: float | None | tuple[float | None, ...] = None, + parallel_dim: DistributedDim, + sequence_parallel: bool = False, ): - assert not out_dim.is_parallel - self._group_size = 1 if in_dim.parallel_dim is None else in_dim.parallel_dim.size - self._sequence_parallel = sequence_parallel and self._group_size > 1 - super().__init__( - in_dim, - out_dim, - bias=bias, - weight_init_method=weight_init_method, - bias_init_method=bias_init_method, - transposed_weight=transposed_weight, - # Tensor-parallel bias is computed in _bias_dropout_grad. - auto_bias_grad_accumulation=self._group_size > 1, - lr_scale=lr_scale, - ) + super().__init__(weight, bias, transposed_weight=transposed_weight) + self._parallel_dim = parallel_dim + self._sequence_parallel = sequence_parallel and self._parallel_dim.size > 1 def forward(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: # TODO: Use self._forward instead (broken). 
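# Illustrative sketch only: with this refactor the choice between `Linear`,
# `OutputParallelLinear` and `InputParallelLinear` is made in `LinearConfig.get_layer`
# (see the new fast_llm/layers/common/linear/config.py above), keyed on which tensor dim
# carries a parallel dim. The toy code below mirrors that dispatch rule; `Dim` and the
# returned labels are simplified stand-ins, not the Fast-LLM classes.
import dataclasses


@dataclasses.dataclass
class Dim:
    size: int
    parallel_dim: str | None = None  # e.g. "tensor" when the dim is sharded across tensor-parallel ranks


def pick_linear_flavor(in_dim: Dim, out_dim: Dim) -> str:
    if in_dim.parallel_dim is not None:
        # Input sharded: row-parallel layer ("InputParallelLinear"), output must be reduced.
        assert out_dim.parallel_dim is None
        return "input_parallel"
    if out_dim.parallel_dim is not None:
        # Output sharded: column-parallel layer ("OutputParallelLinear").
        return "output_parallel"
    # Neither sharded: plain `Linear`; sequence parallelism does not apply.
    return "plain"


assert pick_linear_flavor(Dim(1024, "tensor"), Dim(4096)) == "input_parallel"
assert pick_linear_flavor(Dim(1024), Dim(4096, "tensor")) == "output_parallel"
assert pick_linear_flavor(Dim(1024), Dim(4096)) == "plain"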
@@ -198,13 +126,13 @@ def forward(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | No input_, weight=self.weight, bias=self.bias, - group=self._in_dim.parallel_group, + group=self._parallel_dim.group, sequence_parallel=self._sequence_parallel, transposed_weight=self._transposed_weight, ) def forward_only(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None, tuple[typing.Any, ...]]: - group = self._in_dim.parallel_group + group = self._parallel_dim.group output, context = input_parallel_linear_forward( input_, weight=self.weight, diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index a7eba72c8..0dc7b9589 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -215,14 +215,12 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | (hidden_dim,), init_method=weight_init_method, weight_decay=False, - auto_grad_accumulation=implementation == NormalizationImplementation.torch, lr_scale=self._lr_scale, ) self.bias = ParameterMeta.from_dims( (hidden_dim,), init_method=init_zeros_, weight_decay=False, - auto_grad_accumulation=implementation == NormalizationImplementation.torch, lr_scale=self._lr_scale, ) self._normalized_shape = self.weight.shape @@ -289,7 +287,6 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | (hidden_dim,), init_method=weight_init_method, weight_decay=False, - auto_grad_accumulation=True, lr_scale=lr_scale, ) self._normalized_shape = self.weight.shape diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py index 64a2ca57a..a09cf4a29 100644 --- a/fast_llm/layers/common/peft/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -4,11 +4,7 @@ from fast_llm.engine.base_model.config import BaseModelConfig if typing.TYPE_CHECKING: - import torch - - from fast_llm.layers.common.linear import LinearBase, LinearLike - from fast_llm.layers.common.normalization.normalization import Normalization - from fast_llm.tensor import ParameterMeta + from fast_llm.layers.common.linear.linear import LinearBase, LinearLike @config_class() @@ -73,7 +69,7 @@ def apply_linear( if not enabled: return self.apply_other(module) - from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear + from fast_llm.layers.common.linear.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.common.peft.lora import lora_linear if isinstance(module, InputParallelLinear): diff --git a/fast_llm/layers/common/peft/lora.py b/fast_llm/layers/common/peft/lora.py index 9e0ca0dd0..f84967cab 100644 --- a/fast_llm/layers/common/peft/lora.py +++ b/fast_llm/layers/common/peft/lora.py @@ -4,7 +4,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.layers.common.linear import Linear, LinearBase +from fast_llm.layers.common.linear.linear import Linear, LinearBase def lora_linear( @@ -50,9 +50,6 @@ def lora_linear( transposed_weight=module.transposed_weight, lr_scale=module.weight.lr_scale, ) - # TODO: Implement proper backward pass. 
- module.lora_0.weight.auto_grad_accumulation = True - module.lora_1.weight.auto_grad_accumulation = True old_forward = module._forward diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index df6969cfc..45bcd8300 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -2,6 +2,7 @@ from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.parameter import ParameterConfig from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl from fast_llm.layers.attention.config import TransformerConfig from fast_llm.layers.attention.rotary.config import NoRotaryConfig @@ -41,23 +42,37 @@ class LanguageModelBaseConfig(BaseModelConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) + word_embeddings_layer: ParameterConfig = Field( + desc="Configuration for the word embedding (weight).", + hint=FieldHint.architecture, + ) + position_embeddings_layer: ParameterConfig = Field( + desc="Configuration for the word embedding (weight).", + hint=FieldHint.architecture, + ) + output_layer: ParameterConfig = Field( + desc="Configuration for the LM output layer (weight). Ignored for tied embeddings", + hint=FieldHint.architecture, + ) max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # TODO: Move to `word_embeddings_layer`/`output_layer`? vocab_size: int = Field( default=49152, desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # TODO: Move to `position_embeddings_layer.enabled`? use_position_embeddings: bool = Field( default=None, desc="Enable absolute position embeddings. Default: Enable unless using rotary embeddings.", hint=FieldHint.architecture, - ) + ) # TODO: Move to `output_layer`? (dynamic type?) 
tie_word_embeddings: bool = Field( default=True, desc="Tie the output weights (logits) with the vocabulary embedding.", diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index fd4e8412e..270f2630b 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -10,7 +10,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.layers.block.block import BlockLayerBase from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs -from fast_llm.tensor import ParameterMeta, TensorMeta +from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" @@ -61,25 +61,25 @@ def __init__( self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size self._vocab_end_index = (self._distributed_config.tensor_rank + 1) * vocab_dim.size - self.word_embeddings_weight = ParameterMeta.from_dims( + self.word_embeddings_weight = self._config.word_embeddings_layer.get_parameter( (vocab_dim, self._hidden_dim), - init_method=init_normal_( + default_initializer=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), - lr_scale=config.embeddings_lr_scale, + lr_scale=self._config.embeddings_lr_scale, ) if self._config.use_absolute_position_embeddings: - self.position_embeddings_weight = ParameterMeta.from_dims( + self.position_embeddings_weight = self._config.position_embeddings_layer.get_parameter( (TensorDim("position_embeddings", self._config.max_position_embeddings), self._hidden_dim), - init_method=init_normal_( + default_initializer=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), allow_sequence_tensor_parallel=not config.parallel_embeddings, - lr_scale=config.embeddings_lr_scale, + lr_scale=self._config.embeddings_lr_scale, ) # PEFT. diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 8917feaf6..2b4644e3f 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -2,7 +2,8 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.functional.config import ActivationType +from fast_llm.engine.config_utils.parameter import ParameterConfig +from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -46,87 +47,67 @@ def get_init_method(self, scale: float) -> "Initializer": return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) -@config_class() -class SSMConfig(Config): - _abstract = False +@config_class(registry=True) +class MixerConfig(Config): + """ + Base config class for all mixers. 
+ TODO: Generalize to include Attention + """ - # Model dimensions - # TODO: Remove (redundant default) - expansion_factor: int = Field( - default=2, - desc="Expansion factor.", + _abstract = True + + +@config_class() +class SSMConfig(MixerConfig): + # Layers + # [Mamba, Mamba2, DiscreteMamba2] + z_layer: AffineLinearConfig = Field( + desc="Configuration for the z layer.", hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), ) - # head_size [MambaLayer, Mamba2, DiscreteMamba2] - state_size: int = Field( - default=16, - desc="State size.", + # [Mamba, Mamba2, DiscreteMamba2] + x_layer: AffineLinearConfig = Field( + desc="Configuration for the x layer.", hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), ) - # [MambaLayer, Mamba2, DiscreteMamba2] - conv_kernel_dimension: int = Field( - default=4, - desc="Conv kernel dimension.", + # [Mamba, Mamba2, DiscreteMamba2] + convolution_layer: CausalConv1dConfig = Field( + desc="Configuration for the convolution layer.", hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), ) - # [MambaLayer, Mamba2] - dt_rank: None | int = Field( - default=None, - desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)", + # [Mamba, Mamba2, DiscreteMamba2] + d_weight: ParameterConfig = Field( + desc='Configuration for the D "skip" weight.', hint=FieldHint.architecture, ) - # head_groups [DiscreteMamba2] - n_qk_heads: int = Field( - default=32, - desc="Number of QK heads.", + # [Mamba, Mamba2, DiscreteMamba2] + output_layer: AffineLinearConfig = Field( + desc="Configuration for the output layer.", hint=FieldHint.architecture, ) - # heads [DiscreteMamba2]# TODO: Remove? (redundant) - n_v_heads: int = Field( - default=32, - desc="Number of V heads.", + + # Model dimensions + # head_size [Mamba, Mamba2, DiscreteMamba2] + state_size: int = Field( + default=16, + desc="State size.", hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), ) - # c_size [MambaLayer, Mamba2, DiscreteMamba2]? - d_inner: None | int = Field( + # [Mamba, Mamba2, DiscreteMamba2] + # c_size [Mamba, Mamba2, DiscreteMamba2]? 
+ d_inner: int = Field( default=None, desc="Inner dimension.", hint=FieldHint.core, ) - # xb_size [Mamba2] - d_xb: int = Field( - default=None, - desc="Dimension of the xB in Mamba2 blocks.", - hint=FieldHint.architecture, - ) - # Model options # add_bias_linear [Mamba2, DiscreteMamba2] [hard-coded to False in MambaLayer] add_bias_linear: bool = Field( default=False, desc="Whether to use bias in SSM layers", hint=FieldHint.architecture, ) - # activation_type [DiscreteMamba2] [hard-coded to silu in MambaLayer, Mamba2] - activation_type: ActivationType = Field( - default=None, - hint=FieldHint.architecture, - ) - # repeat_xb_before_conv [Mamba2] - repeat_kv_before_conv: bool = Field( - default=True, - desc="Whether to repeat x and B before (True) or after (False) the conv1d in Mamba2 blocks.", - hint=FieldHint.architecture, - ) - # chunk_size [DiscreteMamba2] - chunk_size: int = Field( - default=256, - desc="Chunk size for Mamba2 blocks.", - hint=FieldHint.architecture, - ) # Learning rate # lr_scale [MambaLayer, Mamba2, DiscreteMamba2] @@ -137,35 +118,49 @@ class SSMConfig(Config): valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - # Initialization - # dt_weight_initialization_method [Mamba2] - dt_init: DTInitType = Field( - default=DTInitType.random, - desc="Initialization method for dt", - hint=FieldHint.core, + +@config_class() +class MambaBaseConfig(SSMConfig): + """ + Common configuration for Mamba and Mamba2. + """ + + _abstract = False + + # Layers + dt_layer: AffineLinearConfig = Field( + desc="Configuration for the dt layer.", + hint=FieldHint.architecture, ) - # dt_weight_initialization_scale [Mamba2] - dt_scale: float = Field( - default=1.0, - desc="Scale for dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), + a_log_weight: ParameterConfig = Field( + desc="Configuration for the a_log layer weight.", + hint=FieldHint.architecture, ) - # dt_bias_initialization_min [MambaLayer, Mamba2] + + # Model dimensions + # [Mamba, Mamba2] + dt_rank: int = Field( + default=None, + desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)", + hint=FieldHint.architecture, + ) + + # Initialization + # dt_bias_initialization_min [Mamba, Mamba2] dt_min: float = Field( default=0.001, desc="Minimum step size for discretization", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - # dt_bias_initialization_max [MambaLayer, Mamba2] + # dt_bias_initialization_max [Mamba, Mamba2] dt_max: float = Field( default=0.1, desc="Maximum step size for discretization", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - # dt_bias_initialization_floor [MambaLayer, Mamba2] + # dt_bias_initialization_floor [Mamba, Mamba2] dt_init_floor: float = Field( default=1e-4, desc="Minimum value for initializing dt", @@ -174,8 +169,123 @@ class SSMConfig(Config): ) def _validate(self) -> None: - with self._set_implicit_default(): - if self.activation_type is None: - self.activation_type = ActivationType.silu super()._validate() Assert.geq(self.dt_max, self.dt_min) + + +@config_class(dynamic_type={MixerConfig: "mamba"}) +class MambaConfig(MambaBaseConfig): + """ + Configuration for Mamba. 
+ """ + + # Layers + # TODO: Can be confused with `x_layer` + x_projection_layer: LinearConfig = Field( + desc="Configuration for the x projection layer.", + hint=FieldHint.architecture, + ) + + def _validate(self) -> None: + super()._validate() + Assert.none(self.convolution_layer.activation) + # TODO: (Oleksiy) If bias is used there is a problem in the MambaInnerFn.backward for the bias grads. + # I think this bias is not used in other mamba repos. + assert not self.output_layer.bias.enabled + + +@config_class(dynamic_type={MixerConfig: "mamba_2"}) +class Mamba2Config(MambaBaseConfig): + """ + Configuration for Mamba2. + TODO: Actually a variation of Mamba 2. + """ + + _abstract = False + + # Layers + # [Mamba2, DiscreteMamba2] + b_layer: AffineLinearConfig = Field( + desc="Configuration for the b layer.", + hint=FieldHint.architecture, + ) + # [Mamba2, DiscreteMamba2] + c_layer: AffineLinearConfig = Field( + desc="Configuration for the c layer.", + hint=FieldHint.architecture, + ) + dt_input_layer: AffineLinearConfig = Field( + desc="Configuration for the dt input projection layer.", + hint=FieldHint.architecture, + ) + + # Model dimensions + # xb_size [Mamba2] + d_xb: int = Field( + default=None, + desc="Dimension of the xB in Mamba2 blocks.", + hint=FieldHint.architecture, + ) + + # Model options + # repeat_xb_before_conv [Mamba2] + repeat_kv_before_conv: bool = Field( + default=True, + desc="Whether to repeat x and B before (True) or after (False) the conv1d in Mamba2 blocks.", + hint=FieldHint.architecture, + ) + + # Initialization + # dt_weight_initialization_method [Mamba2] + dt_init: DTInitType = Field( + default=DTInitType.random, + desc="Initialization method for dt", + hint=FieldHint.core, + ) + # dt_weight_initialization_scale [Mamba2] + dt_scale: float = Field( + default=1.0, + desc="Scale for dt", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + + +@config_class(dynamic_type={MixerConfig: "discrete_mamba_2"}) +class DiscreteMamba2Config(SSMConfig): + """ + Configuration for DiscreteMamba2. 
+ """ + + _abstract = False + # Layers + # [Mamba2, DiscreteMamba2] + b_layer: AffineLinearConfig = Field( + desc="Configuration for the b layer.", + hint=FieldHint.architecture, + ) + # [Mamba2, DiscreteMamba2] + c_layer: AffineLinearConfig = Field( + desc="Configuration for the c layer.", + hint=FieldHint.architecture, + ) + + # Model dimensions + # head_groups [DiscreteMamba2] + n_qk_heads: int = Field( + default=32, + desc="Number of QK heads.", + hint=FieldHint.architecture, + ) + # heads [DiscreteMamba2] + n_v_heads: int = Field( + default=32, + desc="Number of V heads.", + hint=FieldHint.architecture, + ) + # chunk_size [DiscreteMamba2] + chunk_size: int = Field( + default=256, + desc="Chunk size for Mamba2 blocks.", + hint=FieldHint.architecture, + ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index f9462a942..6947be646 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -4,15 +4,13 @@ import einops import torch -from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ -from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.config_utils.initialization import init_normal_, init_ones_, init_zeros_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.ssm.config import SSMConfig -from fast_llm.layers.ssm.mamba import init_kaiming_ +from fast_llm.layers.ssm.config import DiscreteMamba2Config from fast_llm.tensor import ParameterMeta from fast_llm.utils import combine_lr_scales, div @@ -27,15 +25,7 @@ _mamba_available = False -try: - from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn # noqa - - _causal_conv1d_available = True -except (ImportError, RuntimeError): - _causal_conv1d_available = False - - -class DiscreteMamba2[ConfigType: SSMConfig](BlockLayer[ConfigType]): +class DiscreteMamba2[ConfigType: DiscreteMamba2Config](BlockLayer[ConfigType]): """ This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py """ @@ -65,7 +55,6 @@ def __init__( heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, v_head_size_dim)) bc_dim = CompositeTensorDim("bc", (head_groups_dim, state_dim)) - convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) inner_projection_dim = ConcatenatedTensorDim( "inner_projection", @@ -86,49 +75,42 @@ def __init__( # TODO: double check initializations # Projections - self.in_proj = OutputParallelLinear( + + # TODO: Use x_layer, b_layer, c_layer, a_log_layer + self.in_proj = self._config.z_layer.get_layer( hidden_dim, inner_projection_dim, - bias=config.add_bias_linear, - weight_init_method=init_kaiming_(block_config.hidden_size), + default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), + default_add_bias=config.add_bias_linear, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) 
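# Illustrative sketch only: the `in_proj` created just above packs several streams into one
# fused matmul over `inner_projection_dim` (a ConcatenatedTensorDim), and the forward pass
# recovers them with `torch.split`. The toy below shows that pack-then-split pattern; the
# component set, order and widths are whatever `inner_projection_dim` concatenates, not the
# made-up sizes used here.
import torch

hidden, d_inner, d_bc = 16, 32, 8  # toy sizes only
fused = torch.nn.Linear(hidden, d_inner + d_inner + 2 * d_bc, bias=False)

hidden_states = torch.randn(2, 5, hidden)  # (batch, sequence, hidden)
z, x, b, c = torch.split(fused(hidden_states), [d_inner, d_inner, d_bc, d_bc], dim=-1)
assert z.shape == x.shape == (2, 5, d_inner) and b.shape == c.shape == (2, 5, d_bc)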
- if not config.add_bias_linear: + if self.in_proj.bias is None: + # TODO: Integrate to z_layer config? self.z_bias = ParameterMeta.from_dims( (inner_dim,), weight_decay=False, init_method=init_zeros_, lr_scale=lr_scale, ) - self.conv1d_weight = ParameterMeta.from_dims( - ( - convolution_dim, - scalar_dim, - convolution_kernel_dim, - ), - init_method=init_uniform_centered_( - (convolution_dim.global_size * self._config.conv_kernel_dimension) ** -0.5 - ), - lr_scale=lr_scale, - ) - self.conv1d_bias = ParameterMeta.from_dims( - (convolution_dim,), - init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + + self.convolution = self._config.convolution_layer.get_layer( + convolution_dim, + default_activation=ActivationType.silu, lr_scale=lr_scale, ) # D "skip" parameter - self.D = ParameterMeta.from_dims( + self.D = self._config.d_weight.get_parameter( (heads_dim,), - weight_decay=False, - init_method=init_ones_, + default_initializer=init_ones_, lr_scale=lr_scale, + weight_decay=False, ) - self.out_proj = InputParallelLinear( + self.out_proj = self._config.output_layer.get_layer( inner_dim, hidden_dim, - bias=config.add_bias_linear, - weight_init_method=init_kaiming_(self._config.d_inner), + default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), + default_add_bias=config.add_bias_linear, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -167,7 +149,7 @@ def forward( ) # Convolutional layer # xbc: (batch, padded_sequence, local_heads * head_size + 2 * local_head_groups * state) - xBC = self.convolutional_forward(xBC, padded_length) + xBC = self.convolution(xBC.transpose(1, 2)).transpose(1, 2) x, B, C = torch.split( xBC, @@ -210,37 +192,8 @@ def forward( y = y.transpose(0, 1).contiguous() # out_proj: (batch/sequence, sequence/batch, local_heads * head_size) # -> (batch/local_sequence, local_sequence/batch, hidden) - a, b = self.out_proj(y) return self.out_proj(y) @torch.compile def _apply_a_log(self, x: torch.Tensor, A_log: torch.Tensor) -> torch.Tensor: return x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1) - - def convolutional_forward(self, xBC, padded_len): - """Convolutional layer forward pass for the full sequence.""" - if _causal_conv1d_available and self._config.activation_type in ( - ActivationType.silu, - ActivationType.identity, - ): - xBC = _causal_conv1d_fn( - xBC.transpose(1, 2), - self.conv1d_weight.squeeze(1), - self.conv1d_bias, - activation=( - None - if self._config.activation_type == ActivationType.identity - else self._config.activation_type.value - ), - ).transpose(1, 2) - else: - xBC = self._config.activation_type.activation_fn( - torch.nn.functional.conv1d( - xBC.transpose(1, 2), - self.conv1d_weight, - bias=self.conv1d_bias, - groups=self.conv1d_weight.shape[0], - padding=self._config.conv_kernel_dimension - 1, - )[..., :padded_len].transpose(1, 2) - ) - return xBC diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 453c14af6..37962e1b6 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -5,13 +5,12 @@ import torch from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ -from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import 
ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.common.linear import Linear -from fast_llm.layers.ssm.config import SSMConfig +from fast_llm.layers.ssm.config import MambaConfig from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert, combine_lr_scales, div @@ -33,8 +32,7 @@ def init_A(d_state, d_inner) -> LambdaInitializer: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - if tensor.numel() != d_state * d_inner: - raise ValueError("_init_A requires not supported for tensor slices.") + Assert.eq(tensor.numel(), d_state * d_inner) torch.log( torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device) .unsqueeze(0) @@ -54,7 +52,7 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return LambdaInitializer(init_) -class Mamba[ConfigType: SSMConfig](BlockLayer[ConfigType]): +class Mamba[ConfigType: MambaConfig](BlockLayer[ConfigType]): _mixer_name: typing.ClassVar[str] = "mamba" def __init__( @@ -69,77 +67,69 @@ def __init__( ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for Mamba" - # TODO: It's not silu? - Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: heads_dim = TensorDim("heads", div(self._config.d_inner, self._config.state_size)) state_dim = TensorDim("state", self._config.state_size) inner_dim = CompositeTensorDim("inner", (heads_dim, state_dim)) - convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) dt_rank_dim = TensorDim("dt_rank", self._config.dt_rank) inner_projection_dim = ConcatenatedTensorDim("inner_projection", (inner_dim, inner_dim)) x_projection_dim = ConcatenatedTensorDim("x_projection", (dt_rank_dim, state_dim, state_dim)) lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) - # TODO: Backward compatibility? 
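# Note on the A_log initialization used above: init_A fills every row of A_log with
# [log(1), log(2), ..., log(d_state)] (the standard Mamba / S4D-real scheme), and the
# forward pass recovers A = -exp(A_log). A minimal sketch of the resulting values
# (sizes are illustrative only, matching the defaults used elsewhere in this series):
import torch

d_state, d_inner = 16, 512
a_log = torch.log(torch.arange(1, d_state + 1, dtype=torch.float32)).expand(d_inner, d_state)
a = -torch.exp(a_log)  # (d_inner, d_state); each row is -1, -2, ..., -d_state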
- self.in_proj = Linear( + # TODO: Use x_layer + self.in_proj = self._config.z_layer.get_layer( hidden_dim, inner_projection_dim, - bias=False, - weight_init_method=init_kaiming_(hidden_dim.size), + default_weight_initializer=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), + default_add_bias=config.add_bias_linear, lr_scale=lr_scale, ) - self.conv1d_weight = ParameterMeta.from_dims( - ( - inner_dim, - scalar_dim, - convolution_kernel_dim, - ), - init_method=init_kaiming_(inner_dim.size), + self.convolution = self._config.convolution_layer.get_layer( + inner_dim, + default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), + default_add_bias=False, + default_activation=ActivationType.silu, lr_scale=lr_scale, ) - self.x_proj = Linear( + self.x_proj = self._config.x_projection_layer.get_layer( inner_dim, x_projection_dim, - weight_init_method=init_kaiming_(inner_dim.size), - bias=False, + default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), lr_scale=lr_scale, ) - self.x_proj.weight.auto_grad_accumulation = True + # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 - self.dt_proj_weight = ParameterMeta.from_dims( - (inner_dim, dt_rank_dim), - init_method=init_kaiming_(self._config.dt_rank), - lr_scale=lr_scale, - ) - self.dt_proj_bias = ParameterMeta.from_dims( - (inner_dim,), - init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + self.dt_proj = self._config.dt_layer.get_layer( + dt_rank_dim, + inner_dim, + default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), + default_bias_initializer=init_dtprojbias( + self._config.dt_max, self._config.dt_min, self._config.dt_init_floor + ), lr_scale=lr_scale, ) - self.A_log = ParameterMeta.from_dims( + self.A_log = self._config.a_log_weight.get_parameter( (inner_dim, state_dim), - weight_decay=False, - init_method=init_A(self._config.state_size, inner_dim.size), + default_initializer=init_A(self._config.state_size, self._config.d_inner), lr_scale=lr_scale, + weight_decay=False, ) # D "skip" parameter - self.D = ParameterMeta.from_dims( + self.D = self._config.d_weight.get_parameter( (inner_dim,), - weight_decay=False, - init_method=init_ones_, + default_initializer=init_ones_, lr_scale=lr_scale, + weight_decay=False, ) - self.out_proj = Linear( + self.out_proj = self._config.output_layer.get_layer( inner_dim, hidden_dim, - bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. 
- weight_init_method=init_kaiming_(hidden_dim.size), + default_weight_initializer=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), + default_add_bias=False, lr_scale=lr_scale, ) - self.out_proj.weight.auto_grad_accumulation = True def forward( self, @@ -152,26 +142,22 @@ def forward( in_proj = self.in_proj(input_).permute((1, 2, 0) if kwargs[BlockKwargs.sequence_first] else (0, 2, 1)) # In the backward pass we write dx and dz next to each other to avoid torch.cat - # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s + # If we wanbt to support inference, we would need to implement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s out = _mamba_inner_fn( in_proj, - self.conv1d_weight, - None, + self.convolution.weight, + self.convolution.bias, self.x_proj.weight, - self.dt_proj_weight, + self.dt_proj.weight, self.out_proj.weight, self.out_proj.bias, # is None here -torch.exp(self.A_log.float()), None, # input-dependent B None, # input-dependent C self.D.float(), - delta_bias=self.dt_proj_bias.float(), + delta_bias=None if self.dt_proj.bias is None else self.dt_proj.bias.float(), delta_softplus=True, ) if kwargs[BlockKwargs.sequence_first]: out = out.transpose(0, 1) return out, None - - -def init_kaiming_(d_in: float) -> LambdaInitializer: - return init_normal_(0.0, math.sqrt(2.0 / d_in)) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 2659e415f..c664dc073 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,17 +3,15 @@ import torch -from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_ -from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.config_utils.initialization import init_normal_, init_ones_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs -from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear -from fast_llm.layers.ssm.config import SSMConfig -from fast_llm.layers.ssm.mamba import init_A, init_dtprojbias, init_kaiming_ -from fast_llm.tensor import ParameterMeta -from fast_llm.utils import Assert, combine_lr_scales, div +from fast_llm.layers.ssm.config import Mamba2Config +from fast_llm.layers.ssm.mamba import init_A, init_dtprojbias +from fast_llm.utils import combine_lr_scales, div try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -22,17 +20,10 @@ except (ImportError, RuntimeError): _mamba_available = False -try: - from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn # noqa - - _causal_conv1d_available = True -except (ImportError, RuntimeError): - _causal_conv1d_available = False - logger = logging.getLogger(__name__) -class Mamba2[ConfigType: SSMConfig](BlockLayer[ConfigType]): +class Mamba2[ConfigType: Mamba2Config](BlockLayer[ConfigType]): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ @@ -50,7 +41,6 
@@ def __init__( lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) - Assert.eq(self._config.activation_type, ActivationType.silu) num_heads = div(self._config.d_inner, self._config.state_size) num_head_groups = div(self._config.d_xb, self._config.state_size) @@ -66,7 +56,6 @@ def __init__( inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, state_dim)) xb_dim = CompositeTensorDim("xb", (head_groups_dim, state_dim)) - convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) # DT projection dt_rank_dim = TensorDim("dt_rank", self._config.dt_rank) @@ -81,73 +70,61 @@ def __init__( self._group_heads = div(self._local_heads, self._local_head_groups) self._local_inner_size = inner_dim.size self._local_xb_size = xb_dim.size - conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim + convolution_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) - self.conv1d_weight = ParameterMeta.from_dims( - ( - conv1d_dim, - scalar_dim, - convolution_kernel_dim, - ), - init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), - lr_scale=lr_scale, - ) - self.conv1d_bias = ParameterMeta.from_dims( - (conv1d_dim,), - init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + self.convolution = self._config.convolution_layer.get_layer( + convolution_dim, + default_activation=ActivationType.silu, lr_scale=lr_scale, ) - self.in_proj = OutputParallelLinear( + # TODO: Use x_layer, b_layer, c_layer + self.in_proj = self._config.z_layer.get_layer( hidden_dim, inner_projection_dim, - bias=config.add_bias_linear, - weight_init_method=init_kaiming_(block_config.hidden_size), + default_weight_initializer=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), + default_add_bias=config.add_bias_linear, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - self.dt_in_proj = Linear( + self.dt_in_proj = self._config.dt_input_layer.get_layer( hidden_dim, dt_rank_dim, - bias=config.add_bias_linear, - weight_init_method=init_kaiming_(block_config.hidden_size), + default_weight_initializer=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), + default_add_bias=config.add_bias_linear, lr_scale=lr_scale, ) - self.dt_proj = OutputParallelLinear( + self.dt_proj = self._config.dt_layer.get_layer( dt_rank_dim, inner_dim, - bias=False, - # Initialize special dt projection to preserve variance at initialization - weight_init_method=self._config.dt_init.get_init_method( + default_weight_initializer=self._config.dt_init.get_init_method( self._config.dt_rank**-0.5 * self._config.dt_scale ), + default_bias_initializer=init_dtprojbias( + self._config.dt_max, self._config.dt_min, self._config.dt_init_floor + ), sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - # define bias outside the linear layer since it's also used in the selective_scan_fn - self.dt_proj_bias = ParameterMeta.from_dims( - (inner_dim,), - init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), - lr_scale=lr_scale, - ) - self.A_log = ParameterMeta.from_dims( + self.A_log = self._config.a_log_weight.get_parameter( (inner_dim, state_dim), - init_method=init_A(self._config.state_size, self._config.d_inner), + default_initializer=init_A(self._config.state_size, self._config.d_inner), lr_scale=lr_scale, 
weight_decay=False, ) - self.D = ParameterMeta.from_dims( + # D "skip" parameter + self.D = self._config.d_weight.get_parameter( (inner_dim,), - weight_decay=False, - init_method=init_ones_, + default_initializer=init_ones_, lr_scale=lr_scale, + weight_decay=False, ) - self.out_proj = InputParallelLinear( + self.out_proj = self._config.output_layer.get_layer( inner_dim, hidden_dim, - bias=config.add_bias_linear, - weight_init_method=init_kaiming_(self._config.d_inner), + default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), + default_add_bias=config.add_bias_linear, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -173,12 +150,11 @@ def forward( metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - assert _causal_conv1d_available # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) # -> (batch/sequence, sequence/batch, local_inner_projection) inner_projection = self.in_proj(input_) - dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias + dt = self.dt_proj(self.dt_in_proj(input_)) # Standardize to (batch, sequence, local_inner_projection) if kwargs[BlockKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) @@ -198,16 +174,15 @@ def forward( # x: (batch, sequence, local_head_groups * state) -> (batch, local_heads * state, sequence) x = x.transpose(1, 2) if self._config.repeat_kv_before_conv: - x = ( + x = self.convolution( x.unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") else: - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") x = ( - x.unflatten(1, (self._local_head_groups, self._config.state_size)) + self.convolution(x) + .unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) @@ -240,7 +215,7 @@ def forward( c, self.D.float(), z, - delta_bias=self.dt_proj_bias.float(), + delta_bias=self.dt_proj.bias.float(), delta_softplus=True, ) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 9d54675be..32d2d23d0 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -8,7 +8,7 @@ from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig +from fast_llm.layers.ssm.config import MixerConfig, SSMBlockType from fast_llm.models.gpt.config import ( GPTBaseModelConfig, GPTBatchConfig, @@ -29,7 +29,7 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): _abstract = False - ssm: SSMConfig = Field( + ssm: MixerConfig = Field( desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) @@ -47,14 +47,13 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): ssm_block_type: SSMBlockType | None = Field(init=False) def _validate(self): - with self._set_implicit_default(None): - if self.ssm.dt_rank == "auto" or self.ssm.dt_rank is None: - self.ssm.dt_rank = math.ceil(self.transformer.hidden_size / 16) with self._set_implicit_default(): - if self.ssm.d_xb is None: + if getattr(self.ssm, 
"dt_rank", ...) is None: + self.ssm.dt_rank = math.ceil(self.transformer.hidden_size / 16) + if getattr(self.ssm, "d_xb", ...) is None: self.ssm.d_xb = self.transformer.hidden_size - if self.ssm.d_inner is None: - self.ssm.d_inner = int(self.ssm.expansion_factor * self.transformer.hidden_size) + if getattr(self.ssm, "d_inner", ...) is None: + self.ssm.d_inner = int(2 * self.transformer.hidden_size) if self.hybrid_block_layout is None: with self._set_implicit_default(): diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b6180c190..f56834c8a 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -244,7 +244,6 @@ def __init__( lr_scale: float | None | tuple[float | None, ...] = None, requires_grad: bool = True, allow_sequence_tensor_parallel: bool = True, - auto_grad_accumulation: bool = True, allow_no_grad: bool = False, ): super().__init__(data, tensor_name=tensor_name, dims=dims) @@ -259,9 +258,6 @@ def __init__( # Almost all parameters are either tensor-parallel or process tensor-sequence-parallel inputs. # Except for position embedding weights self.sequence_tensor_parallel = allow_sequence_tensor_parallel and not self.is_tensor_parallel - # If true, grad accumulation is handled automatically by copying or adding to the grad_buffer. - # Can be disabled to allow for a more efficient implementation that accumulates directly to it. - self.auto_grad_accumulation = auto_grad_accumulation # Disable the check that gradients have been computed for this parameter before the gradient reduction, # to support cases where gradients may not always be computed (ex. MOE layers). self.allow_no_grad = allow_no_grad @@ -281,7 +277,6 @@ def __new__( weight_decay: bool = True, lr_scale: float | None | tuple[float | None, ...] = None, allow_sequence_tensor_parallel: bool = True, - auto_grad_accumulation: bool = True, allow_no_grad: bool = False, ): return super().__new__( diff --git a/tests/models/distributed_test_model.py b/tests/models/distributed_test_model.py index 564920bd5..890a75077 100644 --- a/tests/models/distributed_test_model.py +++ b/tests/models/distributed_test_model.py @@ -25,6 +25,7 @@ def main(args: list[str] | None = None) -> None: world_size = DistributedConfig.default_world_size rank = DistributedConfig.default_rank group = pool.get_process_group(range(world_size), rank) + safe_barrier(group, "start") for name, config in DISTRIBUTED_TESTING_CONFIGS.items(): if model_testing_config.should_skip(config): diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 031ec6f97..6f4631320 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -80,6 +80,7 @@ def test_resume(run_test_script_for_all_models, compare_results_for_all_models, @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.checkpoint) def test_resume_frozen(run_test_script_for_all_models, prepare_resume): + # TODO: No more frozen weights? 
distributed_testing_config = DistributedTestingConfig( name="resume_frozen", compare="checkpoint_and_eval", config_args=_CHECKPOINT_AND_EVAL_ARGS ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index e9bdeba97..83f6b50b2 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -475,6 +475,7 @@ def _update_and_add_testing_config( "llamba", model_type="hybrid_ssm", extra_args=[ + "model.base_model.ssm.type=mamba", "model.base_model.hybrid_block_layout=['t','m']", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=16", @@ -503,6 +504,7 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", + "model.base_model.ssm.type=mamba_2", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", "model.base_model.ssm.d_xb=256", @@ -534,6 +536,7 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2d']", + "model.base_model.ssm.type=discrete_mamba_2", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", "model.base_model.ssm.n_qk_heads=8", From be6967735da2bbde27ce9f0dc913bbb7768ee307 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 Aug 2025 17:56:17 -0400 Subject: [PATCH 70/82] fixes --- fast_llm/layers/common/linear/config.py | 1 - fast_llm/layers/common/peft/config.py | 4 ++++ fast_llm/layers/language_model/config.py | 3 ++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py index 776a11925..e9dbe9229 100644 --- a/fast_llm/layers/common/linear/config.py +++ b/fast_llm/layers/common/linear/config.py @@ -173,7 +173,6 @@ def get_layer( lr_scale=lr_scale, default_enabled=default_add_bias, ) - print("OIFEHIUWB", default_add_bias, self.bias.enabled, bias is None) return CausalConv1d( weight, bias, activation=default_activation if self.activation is None else self.activation ) diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py index a09cf4a29..7c7834cbd 100644 --- a/fast_llm/layers/common/peft/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -4,7 +4,11 @@ from fast_llm.engine.base_model.config import BaseModelConfig if typing.TYPE_CHECKING: + import torch + from fast_llm.layers.common.linear.linear import LinearBase, LinearLike + from fast_llm.layers.common.normalization.normalization import Normalization + from fast_llm.tensor import ParameterMeta @config_class() diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 45bcd8300..bfb240107 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -72,7 +72,8 @@ class LanguageModelBaseConfig(BaseModelConfig): default=None, desc="Enable absolute position embeddings. Default: Enable unless using rotary embeddings.", hint=FieldHint.architecture, - ) # TODO: Move to `output_layer`? (dynamic type?) + ) + # TODO: Move to `output_layer`? (dynamic type?) 
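# With the mixer type now a dynamic config field, the hybrid-SSM test configs above must
# name the mixer explicitly. A minimal set of overrides, copied from the 'm2' test config
# above (shown as an example of the new required `ssm.type` key, not a recommendation):
extra_args = [
    "model.base_model.hybrid_block_layout=['t','m2']",
    "model.base_model.ssm.type=mamba_2",
    "model.base_model.ssm.d_inner=512",
    "model.base_model.ssm.state_size=8",
    "model.base_model.ssm.d_xb=256",
]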
tie_word_embeddings: bool = Field( default=True, desc="Tie the output weights (logits) with the vocabulary embedding.", From 82a70aa60aeaa10a06ad8c4773245bea74b10fc7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 Aug 2025 18:37:17 -0400 Subject: [PATCH 71/82] Simplify bias options --- fast_llm/layers/attention/attention.py | 6 ++-- fast_llm/layers/attention/config.py | 20 +----------- fast_llm/layers/block/config.py | 13 ++------ fast_llm/layers/block/mlp/config.py | 11 ------- fast_llm/layers/block/mlp/mlp.py | 4 +-- fast_llm/layers/ssm/config.py | 7 ----- fast_llm/layers/ssm/discrete_mamba2.py | 4 +-- fast_llm/layers/ssm/mamba.py | 2 +- fast_llm/layers/ssm/mamba2.py | 6 ++-- fast_llm/models/gpt/conversion.py | 42 ++++++++++++++++++-------- fast_llm/models/ssm/conversion.py | 4 +-- 11 files changed, 45 insertions(+), 74 deletions(-) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 74cfb6ed4..dde6bbf94 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -115,7 +115,7 @@ def __init__( hidden_dim, query_dim, default_weight_initializer=init_method_qkv, - default_add_bias=self._config.add_qkv_bias, + default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -124,7 +124,7 @@ def __init__( hidden_dim, key_value_dim, default_weight_initializer=init_method_qkv, - default_add_bias=self._config.add_qkv_bias, + default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -138,7 +138,7 @@ def __init__( dense_dim, hidden_dim, default_weight_initializer=init_method_std_attn_proj, - default_add_bias=self._config.add_dense_bias, + default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 8e4226270..2c6d4f966 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -7,7 +7,7 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig -from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockKwargs +from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear.config import AffineLinearConfig from fast_llm.utils import Assert, div @@ -180,24 +180,6 @@ def projection_size(self): def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) - @property - def add_qkv_bias(self) -> bool: - # TODO: Make this work without inheritance. - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.nowhere: - return False - return True - - @property - def add_dense_bias(self) -> bool: - # TODO: Make this work without inheritance. 
- if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False - @config_class() # TODO: Use composition instead diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 29acaadf0..4d7c9ef7c 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,5 +1,3 @@ -import enum - from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.layers.block.mlp.config import MLPConfig @@ -34,12 +32,6 @@ class BlockKwargs: grad_output = "grad_output" -class AddLinearBiasChoices(str, enum.Enum): - nowhere = "nowhere" - everywhere = "everywhere" - only_attn_qkv = "only_attn_qkv" - - @config_class() # TODO: Use composition instead class BlockConfig(MLPConfig, BaseModelConfig): @@ -76,12 +68,11 @@ class BlockConfig(MLPConfig, BaseModelConfig): desc="Log the memory usage after each operation in a transformer layer..", hint=FieldHint.logging, ) - add_linear_biases: bool | AddLinearBiasChoices = Field( + add_linear_biases: bool = Field( default=True, - desc="Add biases to all, none or Q, K, V layers. Accepted values: True, False, or AddLinearBiasChoices.", + desc="Add biases to linear layers. May be overridden for individual layers.", hint=FieldHint.architecture, ) - # TODO: Move these, not specific to a single block. num_layers: int = Field( default=12, diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 2a4d8e81f..186c3007d 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -160,17 +160,6 @@ class MLPConfig(Config): hint=FieldHint.optional, ) - @property - def add_mlp_bias(self) -> bool: - from fast_llm.layers.block.config import AddLinearBiasChoices - - # TODO: Make this work without inheritance. 
- if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False - def _validate(self) -> None: with self._set_implicit_default(): if self.activation_type is None: diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 4c79cf9de..ba7c45c31 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -49,7 +49,7 @@ def __init__( hidden_dim, intermediate_1_dim, default_weight_initializer=init_method_1, - default_add_bias=self._config.add_mlp_bias, + default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -57,7 +57,7 @@ def __init__( intermediate_2_dim, hidden_dim, default_weight_initializer=init_method_2, - default_add_bias=self._config.add_mlp_bias, + default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, transposed_weight=True, lr_scale=lr_scale, diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 2b4644e3f..fd86f47cf 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -102,13 +102,6 @@ class SSMConfig(MixerConfig): hint=FieldHint.core, ) - # add_bias_linear [Mamba2, DiscreteMamba2] [hard-coded to False in MambaLayer] - add_bias_linear: bool = Field( - default=False, - desc="Whether to use bias in SSM layers", - hint=FieldHint.architecture, - ) - # Learning rate # lr_scale [MambaLayer, Mamba2, DiscreteMamba2] mamba_lr_scale: float | None = Field( diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 6947be646..46f26ed4d 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -81,7 +81,7 @@ def __init__( hidden_dim, inner_projection_dim, default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), - default_add_bias=config.add_bias_linear, + default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -110,7 +110,7 @@ def __init__( inner_dim, hidden_dim, default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), - default_add_bias=config.add_bias_linear, + default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 37962e1b6..e7bd7674b 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -83,7 +83,7 @@ def __init__( hidden_dim, inner_projection_dim, default_weight_initializer=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), - default_add_bias=config.add_bias_linear, + default_add_bias=self._block_config.add_linear_biases, lr_scale=lr_scale, ) self.convolution = self._config.convolution_layer.get_layer( diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index c664dc073..90fdb343a 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -84,7 +84,7 @@ def __init__( hidden_dim, inner_projection_dim, default_weight_initializer=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), - default_add_bias=config.add_bias_linear, + default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -92,7 +92,7 @@ def __init__( hidden_dim, dt_rank_dim, default_weight_initializer=init_normal_(0, (2 / hidden_dim.global_size) 
** 0.5), - default_add_bias=config.add_bias_linear, + default_add_bias=self._block_config.add_linear_biases, lr_scale=lr_scale, ) self.dt_proj = self._config.dt_layer.get_layer( @@ -124,7 +124,7 @@ def __init__( inner_dim, hidden_dim, default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), - default_add_bias=config.add_bias_linear, + default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 36975dea1..789201acc 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -199,19 +199,22 @@ def _create_transformer_layer_converters( ( f"{fast_llm_layer_name}.self_attn.query", f"{hf_layer_name}.self_attn.q_proj", - transformer_config.add_qkv_bias, + # TODO: Fix + transformer_config.add_linear_biases, QueryWeightConverter, ), ( f"{fast_llm_layer_name}.self_attn.key_value", (f"{hf_layer_name}.self_attn.k_proj", f"{hf_layer_name}.self_attn.v_proj"), - transformer_config.add_qkv_bias, + # TODO: Fix + transformer_config.add_linear_biases, KeyValueWeightConverter, ), ( f"{fast_llm_layer_name}.self_attn.dense", f"{hf_layer_name}.self_attn.o_proj", - transformer_config.add_dense_bias, + # TODO: Fix + transformer_config.add_linear_biases, WeightConverter, ), # Norm @@ -241,13 +244,15 @@ def _create_transformer_layer_converters( converters += self._get_weight_and_bias_converters( f"{fast_llm_layer_name}.mlp.layer_1", (), - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, cls=IgnoreExportWeightConverter, ) converters += self._get_weight_and_bias_converters( f"{fast_llm_layer_name}.mlp.layer_2", (), - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, cls=IgnoreExportWeightConverter, ) converters += [IgnoreExportWeightConverter(f"{fast_llm_layer_name}.mlp.router.weight", ())] @@ -344,12 +349,17 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig transformer_config: TransformerConfig = self._model.config.base_model.transformer return [ *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", transformer_config.add_mlp_bias + f"{fast_llm_prefix}.mlp.layer_1", + f"{hf_prefix}.mlp.c_fc", + # TODO: Fix + transformer_config.add_linear_biases, + Ω, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.c_proj", - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, MLPLayer2Converter, ), ] @@ -463,13 +473,15 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, MLPLayer2Converter, ), ] @@ -531,13 +543,15 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, SplitWeightConverter, ), 
*self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, MLPLayer2Converter, ), ] @@ -641,13 +655,15 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, MLPLayer2Converter, ), ] diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index e9b18b848..5e05364a4 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -224,7 +224,7 @@ def _create_weight_converters(self) -> list[WeightConverter]: converters = super()._create_weight_converters() or [] num_layers = self._model.config.base_model.transformer.num_layers - ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear + ssm_bias: bool = self._model.config.base_model.transformer.add_linear_biases for i in range(num_layers): # SSM @@ -389,7 +389,7 @@ def _create_weight_converters(self) -> list[WeightConverter]: converters = [] num_layers = self._model.config.base_model.transformer.num_layers norm_bias: bool = False - ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear + ssm_bias: bool = self._model.config.base_model.transformer.add_linear_biases # Embedding and output if self._model.config.base_model.tie_word_embeddings: From 680980a3e72e700d0f1174b05aa01eddd39ec7b6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 29 Aug 2025 16:44:24 -0400 Subject: [PATCH 72/82] stuff --- examples/mistral.yaml | 25 +++--- .../evaluation/lm_eval/fast_llm_wrapper.py | 2 +- fast_llm/layers/attention/attention.py | 19 ++--- fast_llm/layers/attention/block.py | 2 +- fast_llm/layers/attention/config.py | 85 +++---------------- fast_llm/layers/attention/preprocessing.py | 4 +- fast_llm/layers/block/block.py | 4 +- fast_llm/layers/block/config.py | 52 +++++++++++- fast_llm/layers/block/mlp/config.py | 58 ++----------- .../layers/block/mlp/mixture_of_experts.py | 2 +- fast_llm/layers/block/mlp/mlp.py | 23 ++--- fast_llm/layers/language_model/config.py | 2 +- fast_llm/layers/ssm/config.py | 30 ++++--- fast_llm/layers/ssm/discrete_mamba2.py | 2 +- fast_llm/models/gpt/conversion.py | 62 ++++++-------- fast_llm/models/gpt/megatron.py | 24 +++--- fast_llm/models/gpt/model.py | 38 +++++---- fast_llm/models/ssm/config.py | 13 +-- tests/test_attention.py | 19 ++--- tests/test_config.py | 38 ++++++--- tests/test_multi_stage.py | 4 +- tests/utils/distributed_configs.py | 16 ++-- tests/utils/model_configs.py | 40 ++++----- 23 files changed, 246 insertions(+), 318 deletions(-) diff --git a/examples/mistral.yaml b/examples/mistral.yaml index 10aa54b7f..0754d74b0 100644 --- a/examples/mistral.yaml +++ b/examples/mistral.yaml @@ -28,24 +28,27 @@ optimizer: model: base_model: transformer: + mixer: + type: attention + rotary: + type: default + theta: 10000 + num_attention_heads: 32 + head_groups: 8 + kv_channels: 128 + window_size: 4096 + attention_dropout: 0.0 + mlp: + ffn_hidden_size: 14336 + gated: true + activation_type: silu normalization: type: rms_norm epsilon: 1.0e-05 - rotary: - 
type: default - theta: 10000 num_layers: 32 hidden_size: 4096 - ffn_hidden_size: 14336 - num_attention_heads: 32 - head_groups: 8 add_linear_biases: false - gated: true - activation_type: silu - kv_channels: 128 - window_size: 4096 init_method_std: 0.009021 - attention_dropout: 0.0 hidden_dropout: 0.0 vocab_size: 32000 tie_word_embeddings: false diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 439d1da2e..3a606b41d 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -104,7 +104,7 @@ def max_length(self): # check if it is absolute positional encoding and return max_position_embeddings if hasattr(self._config.fast_llm_config.base_model, "transformer"): # NOTE: will need to extend if more relative encoding types will be added - if isinstance(self._config.fast_llm_config.base_model.transformer.rotary, NoRotaryConfig): + if isinstance(self._config.fast_llm_config.base_model.transformer.mixer.rotary, NoRotaryConfig): return self._config.fast_llm_config.base_model.max_position_embeddings # check if tokenizer holds model sequence leigh info diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index dde6bbf94..d7bfae29b 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -94,17 +94,6 @@ def __init__( self._softmax_scale = self._config.kv_channels ** (-self._config.attention_softmax_scale_power) - init_method_qkv = init_normal_( - std=self._config.init_method_std_qkv, - min_val=self._config.init_method_min_qkv, - max_val=self._config.init_method_max_qkv, - ) - init_method_std_attn_proj = init_normal_( - std=self._config.init_method_std_attn_proj, - min_val=self._config.init_method_min_attn_proj, - max_val=self._config.init_method_max_attn_proj, - ) - lr_scale = combine_lr_scales( self._lr_scale, self._config.attention_lr_scale, @@ -114,7 +103,7 @@ def __init__( self.query = self._config.query_layer.get_layer( hidden_dim, query_dim, - default_weight_initializer=init_method_qkv, + default_weight_initializer=init_normal_(std=self._block_config.init_method_std), default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, @@ -123,7 +112,7 @@ def __init__( self.key_value = self._config.query_layer.get_layer( hidden_dim, key_value_dim, - default_weight_initializer=init_method_qkv, + default_weight_initializer=init_normal_(std=self._block_config.init_method_std), default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, @@ -137,7 +126,9 @@ def __init__( self.dense = self._config.dense_layer.get_layer( dense_dim, hidden_dim, - default_weight_initializer=init_method_std_attn_proj, + default_weight_initializer=init_normal_( + std=self._block_config.init_method_std / max(2 * self._block_config.num_layers, 1) ** 0.5, + ), default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, diff --git a/fast_llm/layers/attention/block.py b/fast_llm/layers/attention/block.py index 3396a2997..d9fa09cb4 100644 --- a/fast_llm/layers/attention/block.py +++ b/fast_llm/layers/attention/block.py @@ -19,4 +19,4 @@ def _mixer_class(self) -> type[Attention]: @property def _mixer_config(self) -> AttentionConfig: - return self._config + return self._config.mixer diff --git a/fast_llm/layers/attention/config.py 
b/fast_llm/layers/attention/config.py index 2c6d4f966..0a1fbeafa 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -2,12 +2,12 @@ import logging import warnings -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig -from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.config import BlockConfig, BlockKwargs, MixerConfig from fast_llm.layers.common.linear.config import AffineLinearConfig from fast_llm.utils import Assert, div @@ -28,8 +28,8 @@ class AttentionKwargs(BlockKwargs): past_key_values = "past_key_values" -@config_class() -class AttentionConfig(Config): +@config_class(dynamic_type={MixerConfig: "attention"}) +class AttentionConfig(MixerConfig): # TODO: Make mixer class dynamic. _abstract = False @@ -106,65 +106,13 @@ class AttentionConfig(Config): " Under muP (if scaling number of heads instead of kv_channels): use 0.5.", valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - # TODO: Review initialization - init_method_std_qkv: float = Field( - default=None, - desc="Scale for the query, key and value weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_qkv: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for query, key and value matrices. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_qkv: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for query, key and value matrices. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_attn_proj: float = Field( - default=None, - desc="Scale for the attention projection weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_attn_proj: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for attention projection. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_attn_proj: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')", - hint=FieldHint.optional, - ) - def _validate(self) -> None: - with self._set_implicit_default(): - # TODO: Make this work without inheritance. 
- if self.kv_channels is None: - self.kv_channels = div(self.hidden_size, self.num_attention_heads) - # TODO: Review initialization - if self.init_method_std_qkv is None: - self.init_method_std_qkv = self.init_method_std - if self.init_method_std_attn_proj is None: - self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 - if self.init_method_max_qkv is None: - self.init_method_max_qkv = self.init_method_max - if self.init_method_min_qkv is None: - self.init_method_min_qkv = self.init_method_min - if self.init_method_max_attn_proj is None: - self.init_method_max_attn_proj = self.init_method_max - if self.init_method_min_attn_proj is None: - self.init_method_min_attn_proj = self.init_method_min - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) - if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: - Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) + def set_defaults(self, hidden_size: int): + if self.kv_channels is None: + with self._set_implicit_default(): + self.kv_channels = div(hidden_size, self.num_attention_heads) + def _validate(self) -> None: super()._validate() if not TritonConfig.TRITON_ENABLED: @@ -183,16 +131,7 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: @config_class() # TODO: Use composition instead -class TransformerConfig(AttentionConfig, BlockConfig): +class TransformerConfig(BlockConfig): _abstract = False - - def _validate(self) -> None: - with self._set_implicit_default(): - # Kept here for initialization order. 
- # TODO: Review initialization - if self.init_method_std is None: - self.init_method_std = self.hidden_size**-0.5 - if self.init_method_min is not None and self.init_method_max is not None: - Assert.leq(self.init_method_min, self.init_method_max) - - super()._validate() + # TODO: Make this unnecessary + mixer: AttentionConfig = FieldUpdate() diff --git a/fast_llm/layers/attention/preprocessing.py b/fast_llm/layers/attention/preprocessing.py index 24ef3397c..8bb923455 100644 --- a/fast_llm/layers/attention/preprocessing.py +++ b/fast_llm/layers/attention/preprocessing.py @@ -86,9 +86,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: class FlashAttnVarlenPreprocessor(Preprocessor): def __init__(self, config: AttentionConfig, distributed_config: DistributedConfig): - self._config = config - self._distributed_config = distributed_config - assert self._config.do_use_flash_attention(self._distributed_config) + assert config.do_use_flash_attention(distributed_config) def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: """ diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index f90fce698..7dc0e6c76 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -189,8 +189,8 @@ def __init__( from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.block.mlp.mlp import MLP - self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, + self.mlp = (MixtureOfExpertMLP if self._config.mlp.num_experts > 1 else MLP)( + self._config.mlp, self._config, self._distributed_config, self._hidden_dim, diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 4d7c9ef7c..f4dca9e6e 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,4 +1,7 @@ -from fast_llm.config import Field, FieldHint, check_field, config_class +import abc +import typing + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerPeftConfig @@ -32,10 +35,39 @@ class BlockKwargs: grad_output = "grad_output" +@config_class(registry=True) +class MixerConfig(Config): + """ + Base config class for all mixers. + TODO: Generalize to include Attention + """ + + _abstract = True + + @abc.abstractmethod + def set_defaults(self, hidden_size: int): + pass + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is MixerConfig and cls.get_subclass(default.get("type")) is None: + from fast_llm.layers.attention.config import AttentionConfig + + # Default subclass. + return AttentionConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + @config_class() # TODO: Use composition instead -class BlockConfig(MLPConfig, BaseModelConfig): - +class BlockConfig(BaseModelConfig): + mixer: MixerConfig = Field() + mlp: MLPConfig = Field() # TODO: Review names normalization: NormalizationConfig = Field( desc="Configuration for the normalization layers architecture.", @@ -110,3 +142,17 @@ class BlockConfig(MLPConfig, BaseModelConfig): desc="Min value for clamping initialized weights. 
Default: -float('inf')", hint=FieldHint.optional, ) + + def _validate(self) -> None: + with self._set_implicit_default(): + # Kept here for initialization order. + # TODO: Review initialization + if self.init_method_std is None: + self.init_method_std = self.hidden_size**-0.5 + if self.init_method_min is not None and self.init_method_max is not None: + Assert.leq(self.init_method_min, self.init_method_max) + + self.mixer.set_defaults(self.hidden_size) + self.mlp.set_defaults(self.hidden_size) + + super()._validate() diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 186c3007d..6d67224ed 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -126,64 +126,16 @@ class MLPConfig(Config): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) - # TODO: Review initialization - init_method_std_mlp_1: float = Field( - default=None, - desc="Scale for the MLP first layer weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_1: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP first layer. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_mlp_1: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_mlp_2: float = Field( - default=None, - desc="Scale for the MLP second layer weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_2: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_mlp_2: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", - hint=FieldHint.optional, - ) + + def set_defaults(self, hidden_size: int): + if self.ffn_hidden_size is None: + with self._set_implicit_default(): + self.ffn_hidden_size = 4 * hidden_size def _validate(self) -> None: with self._set_implicit_default(): if self.activation_type is None: self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu - # TODO: Make this work without inheritance. 
- if self.ffn_hidden_size is None: - self.ffn_hidden_size = 4 * self.hidden_size - # TODO: Review initialization - if self.init_method_std_mlp_1 is None: - self.init_method_std_mlp_1 = self.init_method_std - if self.init_method_std_mlp_2 is None: - self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 - if self.init_method_max_mlp_1 is None: - self.init_method_max_mlp_1 = self.init_method_max - if self.init_method_min_mlp_1 is None: - self.init_method_min_mlp_1 = self.init_method_min - if self.init_method_max_mlp_2 is None: - self.init_method_max_mlp_2 = self.init_method_max - if self.init_method_min_mlp_2 is None: - self.init_method_min_mlp_2 = self.init_method_min - if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: - Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) - if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: - Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) self.num_unshared_experts = self.num_experts - self.num_shared_experts diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 9298e872b..2d0343830 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -44,7 +44,7 @@ def __init__( ): Assert.gt(config.num_experts, 1) # TODO: Implement? - assert not config.add_linear_biases, "Biases not supported for MoE." + assert not block_config.add_linear_biases, "Biases not supported for MoE." super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) self.router = self._config.router.get_layer( self._hidden_dim, diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index ba7c45c31..c1f684619 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -15,6 +15,8 @@ class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): + _config: MLPConfig + def __init__( self, config: ConfigType, @@ -29,17 +31,6 @@ def __init__( self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, intermediate_2_dim = self._get_intermediate_dims() - init_method_1 = init_normal_( - std=self._config.init_method_std_mlp_1, - min_val=self._config.init_method_min_mlp_1, - max_val=self._config.init_method_max_mlp_1, - ) - init_method_2 = init_normal_( - std=self._config.init_method_std_mlp_2, - min_val=self._config.init_method_min_mlp_2, - max_val=self._config.init_method_max_mlp_2, - ) - self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation lr_scale = combine_lr_scales(self._lr_scale, self._config.mlp_lr_scale) @@ -48,7 +39,7 @@ def __init__( self.layer_1 = self._config.layer_1.get_layer( hidden_dim, intermediate_1_dim, - default_weight_initializer=init_method_1, + default_weight_initializer=init_normal_(std=self._block_config.init_method_std), default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, @@ -56,7 +47,9 @@ def __init__( self.layer_2 = self._config.layer_1.get_layer( intermediate_2_dim, hidden_dim, - default_weight_initializer=init_method_2, + default_weight_initializer=init_normal_( + std=self._block_config.init_method_std / max(2 * self._block_config.num_layers, 1) ** 0.5 + ), default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, 
transposed_weight=True, @@ -77,7 +70,9 @@ def _get_intermediate_dims(self): return intermediate_1_dim, intermediate_2_dim -class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): +class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]): + _config: MLPConfig + def __init__( self, config: ConfigType, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index bfb240107..abf2f53df 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -218,7 +218,7 @@ def _validate(self) -> None: else: self.language_model_loss_factor = 0.0 if self.use_position_embeddings is None: - self.use_position_embeddings = isinstance(self.transformer.rotary, NoRotaryConfig) + self.use_position_embeddings = isinstance(self.transformer.mixer.rotary, NoRotaryConfig) if self.init_method_std_embed is None: self.init_method_std_embed = self.transformer.init_method_std if self.init_method_max_embed is None: diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index fd86f47cf..24be8f3eb 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,13 +1,15 @@ import enum +import math import typing -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.parameter import ParameterConfig +from fast_llm.layers.block.config import MixerConfig from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.config_utils.initialization import Initializer + pass class SSMBlockType(enum.StrEnum): @@ -47,16 +49,6 @@ def get_init_method(self, scale: float) -> "Initializer": return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) -@config_class(registry=True) -class MixerConfig(Config): - """ - Base config class for all mixers. 
- TODO: Generalize to include Attention - """ - - _abstract = True - - @config_class() class SSMConfig(MixerConfig): # Layers @@ -111,6 +103,10 @@ class SSMConfig(MixerConfig): valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) + def set_defaults(self, hidden_size: int): + if self.d_inner is None: + self.d_inner = 2 * hidden_size + @config_class() class MambaBaseConfig(SSMConfig): @@ -161,6 +157,11 @@ class MambaBaseConfig(SSMConfig): valid=check_field(Assert.gt, 0), ) + def set_defaults(self, hidden_size: int): + super().set_defaults(hidden_size) + if self.dt_rank is None: + self.dt_rank = math.ceil(hidden_size / 16) + def _validate(self) -> None: super()._validate() Assert.geq(self.dt_max, self.dt_min) @@ -243,6 +244,11 @@ class Mamba2Config(MambaBaseConfig): valid=check_field(Assert.gt, 0), ) + def set_defaults(self, hidden_size: int): + super().set_defaults(hidden_size) + if self.d_xb is None: + self.d_xb = hidden_size + @config_class(dynamic_type={MixerConfig: "discrete_mamba_2"}) class DiscreteMamba2Config(SSMConfig): diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 46f26ed4d..0c91b34f8 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -182,7 +182,7 @@ def forward( Du = torch.einsum("h,blhp->blhp", self.D, x) # Norm and gate - if not self._config.add_bias_linear: + if hasattr(self, "z_bias"): z = z + self.z_bias # y: (batch, padded_sequence, local_heads, head_size) -> (batch, sequence, local_heads * head_size) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 789201acc..915bbced5 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -62,16 +62,16 @@ def export_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (query,) = weight - if self._config.transformer.rotary.complex_format: - query = convert_rotary_complex_to_real(query[:], self._config.transformer.kv_channels, 0) + if self._config.transformer.mixer.rotary.complex_format: + query = convert_rotary_complex_to_real(query[:], self._config.transformer.mixer.kv_channels, 0) return (query,) def import_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (query,) = weight - if self._config.transformer.rotary.complex_format: - query = convert_rotary_real_to_complex(query[:], self._config.transformer.kv_channels, 0) + if self._config.transformer.mixer.rotary.complex_format: + query = convert_rotary_real_to_complex(query[:], self._config.transformer.mixer.kv_channels, 0) return (query,) @@ -84,16 +84,16 @@ def export_weight( ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (key_value,) = weight key, value = key_value[:].chunk(2) - if self._config.transformer.rotary.complex_format: - key = convert_rotary_complex_to_real(key, self._config.transformer.kv_channels, 0) + if self._config.transformer.mixer.rotary.complex_format: + key = convert_rotary_complex_to_real(key, self._config.transformer.mixer.kv_channels, 0) return key, value def import_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] 
) -> tuple[torch.Tensor | SafeTensorSlice, ...]: key, value = weight - if self._config.transformer.rotary.complex_format: - key = convert_rotary_real_to_complex(key[:], self._config.transformer.kv_channels, 0) + if self._config.transformer.mixer.rotary.complex_format: + key = convert_rotary_real_to_complex(key[:], self._config.transformer.mixer.kv_channels, 0) key_value = torch.cat([key[:], value[:]]) return (key_value,) @@ -130,10 +130,10 @@ def _create_config_converters(cls) -> list[ParamConverter]: ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]), ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), RenameParamConverter( - fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) + fast_llm_names=(("transformer", "mixer", "rotary", "theta"),), export_names=(("rope_theta",),) ), MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), + fast_llm_names=(("transformer", "mlp", "activation_type"),), export_names=(("hidden_act",),), fast_llm_value=ActivationType.from_hf_name, export_value=lambda activation_type: activation_type.hf_name, @@ -147,15 +147,15 @@ def _create_config_converters(cls) -> list[ParamConverter]: export_names=(("hidden_size",),), ), RenameParamConverter( - fast_llm_names=(("transformer", "num_attention_heads"),), + fast_llm_names=(("transformer", "mixer", "num_attention_heads"),), export_names=(("num_attention_heads",),), ), RenameParamConverter( - fast_llm_names=(("transformer", "head_groups"),), + fast_llm_names=(("transformer", "mixer", "head_groups"),), export_names=(("num_key_value_heads",),), ), RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), + fast_llm_names=(("transformer", "mlp", "ffn_hidden_size"),), export_names=(("intermediate_size",),), ), RenameParamConverter( @@ -331,7 +331,7 @@ class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler) def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantImportParamConverter( - fast_llm_names=(("transformer", "rotary", "type"),), + fast_llm_names=(("transformer", "mixer", "rotary", "type"),), fast_llm_value=DefaultRotaryConfig.dynamic_type_name, ), ConstantImportParamConverter( @@ -341,7 +341,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: RenameParamConverter( fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=False), + ConstantImportParamConverter(fast_llm_names=(("transformer", "mlp", "gated"),), fast_llm_value=False), ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=True), ] @@ -353,7 +353,6 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig f"{hf_prefix}.mlp.c_fc", # TODO: Fix transformer_config.add_linear_biases, - Ω, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", @@ -377,13 +376,13 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) ), RenameParamConverter( - fast_llm_names=(("transformer", "kv_channels"),), + fast_llm_names=(("transformer", "mixer", "kv_channels"),), export_names=(("head_dim",),), ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), 
fast_llm_value=True), + ConstantImportParamConverter(fast_llm_names=(("transformer", "mlp", "gated"),), fast_llm_value=True), ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), LLamaRotaryParamConverter( - fast_llm_names=(("transformer", "rotary"),), + fast_llm_names=(("transformer", "mixer", "rotary"),), export_names=( ("rope_theta",), ("rope_scaling",), @@ -459,14 +458,6 @@ class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler) format: typing.ClassVar[type[CheckpointFormat]] = LlamaGPTHuggingfaceCheckpointFormat architecture: typing.ClassVar[str] = "LlamaForCausalLM" - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - # TODO: Llama supports biases - ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), - ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), - ] - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: transformer_config: TransformerConfig = self._model.config.base_model.transformer return [ @@ -523,12 +514,13 @@ def _create_config_converters(cls) -> list[ParamConverter]: RenameParamConverter( fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter(fast_llm_names=(("transformer", "mlp", "gated"),), fast_llm_value=True), + # TODO: Fix ConstantImportParamConverter( fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv" ), LLamaRotaryParamConverter( - fast_llm_names=(("transformer", "rotary"),), + fast_llm_names=(("transformer", "mixer", "rotary"),), export_names=( ("rope_theta",), ("rope_scaling",), @@ -589,19 +581,20 @@ class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandle def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantImportParamConverter( - fast_llm_names=(("transformer", "expert_routing_type"),), fast_llm_value=RoutingType.topk + fast_llm_names=(("transformer", "mlp", "expert_routing_type"),), fast_llm_value=RoutingType.topk ), RenameParamConverter( - fast_llm_names=(("transformer", "num_experts"),), export_names=(("num_local_experts",),) + fast_llm_names=(("transformer", "mlp", "num_experts"),), export_names=(("num_local_experts",),) ), RenameParamConverter( - fast_llm_names=(("transformer", "num_experts_per_token"),), export_names=(("num_experts_per_tok",),) + fast_llm_names=(("transformer", "mlp", "num_experts_per_token"),), + export_names=(("num_experts_per_tok",),), ), IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - num_experts = self._model.config.base_model.transformer.num_experts + num_experts = self._model.config.base_model.transformer.mlp.num_experts return [ WeightConverter(f"{fast_llm_prefix}.mlp.router.weight", f"{hf_prefix}.block_sparse_moe.gate.weight"), SplitWeightConverter( @@ -640,9 +633,6 @@ def _create_config_converters(cls) -> list[ParamConverter]: "AutoModelForCausalLM": "modeling_mtp_llama.MTPLlamaForCausalLM", }, ), - # TODO: Llama supports biases - ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), - 
ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), RenameParamConverter( fast_llm_names=(("prediction_heads",),), export_names=(("prediction_heads",),), diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index 5d3130549..562873675 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -26,7 +26,7 @@ def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: tensor_ = _init_position_embeddings_megatron(meta, tensor, distributed) elif "mlp.router.weight" in meta.tensor_name: tensor_ = _init_moe_router_megatron(meta, tensor, distributed) - elif config.num_experts > 1 and "mlp.layer_" in meta.tensor_name: + elif config.mlp.num_experts > 1 and "mlp.layer_" in meta.tensor_name: tensor_ = _init_moe_mlp_megatron(config, meta, tensor, distributed) elif "mlp.layer_2" in meta.tensor_name: tensor_ = _init_transposed_mlp_weight_megatron(meta, tensor, distributed) @@ -61,19 +61,19 @@ def _init_attention_megatron( meta.param_init_method( meta, dense_tensor_ := tensor.new_empty( - config.kv_channels * config.num_attention_heads, + config.mixer.kv_channels * config.mixer.num_attention_heads, config.hidden_size, ), generator, ) # QKV is split differently. (Assuming no tensor-parallel.) - heads_per_group = div(config.num_attention_heads, config.head_groups) + heads_per_group = div(config.mixer.num_attention_heads, config.mixer.head_groups) meta.param_init_method( meta, qkv_tensor_ := tensor.new_empty( - config.head_groups, + config.mixer.head_groups, heads_per_group + 2, - config.kv_channels, + config.mixer.kv_channels, config.hidden_size, ), generator, @@ -93,12 +93,12 @@ def _init_attention_megatron( else: raise NotImplementedError(meta.tensor_name) - if isinstance(config.rotary, DefaultRotaryConfig) and config.rotary.complex_format: + if isinstance(config.mixer.rotary, DefaultRotaryConfig) and config.mixer.rotary.complex_format: from fast_llm.layers.attention.rotary.config import convert_rotary_real_to_complex # Megatron uses (2, kv_channels/2) for the complex split; we use (kv_channels/2, 2). # TODO: Avoid unnecessarily changing the value and dense tensors. 
- tensor_ = convert_rotary_real_to_complex(tensor_.view_as(meta), config.kv_channels, kv_dim) + tensor_ = convert_rotary_real_to_complex(tensor_.view_as(meta), config.mixer.kv_channels, kv_dim) return tensor_ @@ -146,11 +146,13 @@ def _init_moe_mlp_megatron( generator = distributed.tp_init_generator if meta.is_tensor_parallel else distributed.pp_init_generator # self.param_init_method(self, tensor, generator) state = generator.get_state() - weight_1 = tensor.new_empty(config.num_experts * (1 + config.gated) * config.ffn_hidden_size, config.hidden_size) - weight_2 = tensor.new_empty(config.num_experts * config.ffn_hidden_size, config.hidden_size) - for chunk_1, chunk_2 in zip(weight_1.chunk(config.num_experts), weight_2.chunk(config.num_experts)): + weight_1 = tensor.new_empty( + config.mlp.num_experts * (1 + config.mlp.gated) * config.mlp.ffn_hidden_size, config.hidden_size + ) + weight_2 = tensor.new_empty(config.mlp.num_experts * config.mlp.ffn_hidden_size, config.hidden_size) + for chunk_1, chunk_2 in zip(weight_1.chunk(config.mlp.num_experts), weight_2.chunk(config.mlp.num_experts)): meta.param_init_method(meta, chunk_1, generator) - chunk_2_ = chunk_2.new_empty(config.hidden_size, config.ffn_hidden_size) + chunk_2_ = chunk_2.new_empty(config.hidden_size, config.mlp.ffn_hidden_size) meta.param_init_method(meta, chunk_2_, generator) chunk_2.copy_(chunk_2_.t()) if "layer_1.weight" in meta.tensor_name: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index b13c77724..1ac16c230 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -39,7 +39,7 @@ def __init__( ): self._hidden_dim = TensorDim("hidden", config.transformer.hidden_size) super().__init__(config, distributed_config) - self._use_flash_attention = self._config.transformer.do_use_flash_attention(distributed_config) + self._use_flash_attention = self._config.transformer.mixer.do_use_flash_attention(distributed_config) if self._config.use_megatron_initialization: for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) @@ -51,12 +51,18 @@ def __init__( # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. # TODO: Find a better solution. self._preprocessors.append( - self._config.transformer.rotary.get_layer(TensorDim("kv_channels", self._config.transformer.kv_channels)) + self._config.transformer.mixer.rotary.get_layer( + TensorDim("kv_channels", self._config.transformer.mixer.kv_channels) + ) ) if self._use_flash_attention: - self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._distributed_config)) + self._preprocessors.append( + FlashAttnVarlenPreprocessor(self._config.transformer.mixer, self._distributed_config) + ) else: - self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._distributed_config)) + self._preprocessors.append( + BackupAttentionPreprocessor(self._config.transformer.mixer, self._distributed_config) + ) if self._config.enable_dpo: # TODO better way to pass in? 
self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._distributed_config)) @@ -390,8 +396,8 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: def loss_defs(self) -> list[LossDef]: loss_defs = [] if ( - self._config.transformer.num_experts > 1 - and self._config.transformer.expert_routing_type == RoutingType.topk + self._config.transformer.mlp.num_experts > 1 + and self._config.transformer.mlp.expert_routing_type == RoutingType.topk ): loss_defs.append( LossDef( @@ -400,7 +406,7 @@ def loss_defs(self) -> list[LossDef]: count=self._config.transformer.num_layers, ) ) - if self._config.transformer.expert_z_loss_coefficient: + if self._config.transformer.mlp.expert_z_loss_coefficient: loss_defs.append( LossDef( name=MLPLossNames.router_z_loss, @@ -453,16 +459,16 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, s # Query, key, value, dense. flops_per_iteration = ( 2 - * (transformer_config.num_attention_heads + transformer_config.head_groups) - * transformer_config.kv_channels + * (transformer_config.mixer.num_attention_heads + transformer_config.mixer.head_groups) + * transformer_config.mixer.kv_channels * dense_flops_base ) # MLP flops_per_iteration += ( - (2 + transformer_config.gated) - * transformer_config.ffn_hidden_size + (2 + transformer_config.mlp.gated) + * transformer_config.mlp.ffn_hidden_size * dense_flops_base - * transformer_config.num_experts_per_token + * transformer_config.mlp.num_experts_per_token ) # LM-head @@ -475,8 +481,8 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, s ) # Attention-matrix computation - attn_flops_base = transformer_flops_base * transformer_config.projection_size - if transformer_config.window_size is None: + attn_flops_base = transformer_flops_base * transformer_config.mixer.projection_size + if transformer_config.mixer.window_size is None: # Ignore masked values (s**2/2) attn_flops = attn_flops_base * sequence_length model_tflops = flops_per_iteration + attn_flops @@ -485,8 +491,8 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, s attn_flops = ( 2 * attn_flops_base - * transformer_config.window_size - * (1 - transformer_config.window_size / 2 / sequence_length) + * transformer_config.mixer.window_size + * (1 - transformer_config.mixer.window_size / 2 / sequence_length) ) model_tflops = flops_per_iteration + attn_flops diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 32d2d23d0..97526ec5b 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -1,5 +1,4 @@ import logging -import math import typing from fast_llm.config import Field, FieldHint, FieldUpdate, config_class @@ -8,7 +7,7 @@ from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.ssm.config import MixerConfig, SSMBlockType +from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig from fast_llm.models.gpt.config import ( GPTBaseModelConfig, GPTBatchConfig, @@ -29,7 +28,7 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): _abstract = False - ssm: MixerConfig = Field( + ssm: SSMConfig = Field( desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) @@ -47,13 +46,7 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): ssm_block_type: SSMBlockType | None = 
Field(init=False) def _validate(self): - with self._set_implicit_default(): - if getattr(self.ssm, "dt_rank", ...) is None: - self.ssm.dt_rank = math.ceil(self.transformer.hidden_size / 16) - if getattr(self.ssm, "d_xb", ...) is None: - self.ssm.d_xb = self.transformer.hidden_size - if getattr(self.ssm, "d_inner", ...) is None: - self.ssm.d_inner = int(2 * self.transformer.hidden_size) + self.ssm.set_defaults(self.transformer.hidden_size) if self.hybrid_block_layout is None: with self._set_implicit_default(): diff --git a/tests/test_attention.py b/tests/test_attention.py index 9564a931f..37514acd5 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -4,9 +4,8 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.attention.attention import Attention -from fast_llm.layers.attention.config import AttentionKwargs, TransformerConfig +from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.layers.attention.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.layers.block.config import BlockDimNames from fast_llm.utils import Assert @@ -17,17 +16,17 @@ def test_decide_window_size(): attention._decide_window_size = Attention._decide_window_size.__get__(attention) # Attach real method # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) - attention._config = TransformerConfig(window_size=512, max_window_layers=2) + attention._config = AttentionConfig(kv_channels=64, window_size=512, max_window_layers=2) attention._block_index = 2 assert attention._decide_window_size() == 512 # Arrange - Case 2: window_size is None (layer_index < max_window_layers) - attention._config = TransformerConfig(window_size=512, max_window_layers=2) + attention._config = AttentionConfig(kv_channels=64, window_size=512, max_window_layers=2) attention._block_index = 1 assert attention._decide_window_size() is None # Arrange - Case 3: max_window_layers is None (always return window_size) - attention._config = TransformerConfig(window_size=512, max_window_layers=None) + attention._config = AttentionConfig(kv_channels=64, window_size=512, max_window_layers=None) assert attention._decide_window_size() == 512 @@ -51,15 +50,9 @@ def test_varlen_preprocessor(): ] micro_sequence_length = 12 sequence_length = 36 - transformer_config = TransformerConfig( - num_layers=2, - num_attention_heads=2, - hidden_size=16, - use_flash_attention=True, + varlen_preprocessor = FlashAttnVarlenPreprocessor( + AttentionConfig(kv_channels=64), DistributedConfig(training_dtype="bfloat16") ) - distributed_config = DistributedConfig(training_dtype="bfloat16") - distributed = Distributed(distributed_config, use_cpu=True) - varlen_preprocessor = FlashAttnVarlenPreprocessor(transformer_config, distributed_config=distributed_config) for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): kwargs = { AttentionKwargs.sequence_q_dim: TensorDim(BlockDimNames.sequence_k, micro_sequence_length), diff --git a/tests/test_config.py b/tests/test_config.py index b6a9a9854..507f8b56a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -75,14 +75,18 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): { "base_model": { "transformer": { + "mixer": { + "rotary": {"type": "default"}, + "window_size": 32, + "head_groups": 4, + }, + "mlp": { + "ffn_hidden_size": 4096, # Implicit 
default, default value + "activation_type": "silu", # Implicit default, non-default value + }, "normalization": {"type": "rms_norm"}, # Nested - "rotary": {"type": "default"}, "num_layers": 12, # Default "hidden_size": 1024, # Default - "window_size": 32, - "ffn_hidden_size": 4096, # Implicit default, default value - "activation_type": "silu", # Implicit default, non-default value - "head_groups": 4, }, "tie_word_embeddings": False, }, @@ -98,11 +102,13 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): base_model_update = { "transformer": { + "mixer": { + "head_groups": 1, # Override to default + }, # rotary: Don't override nested. "normalization": {"implementation": "triton"}, # Update non-default nested "peft": {"type": "lora", "freeze_others": False}, # Update default nested, change type "hidden_size": 512, # Override, affects derived value (kv channels) - "head_groups": 1, # Override to default }, "vocab_size": 1000, } @@ -115,7 +121,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "pretrained": {"format": "fast_llm", "path": config_path, "load_config": load_config}, } ) - Assert.eq(pretrained_config.model.base_model.transformer.kv_channels, 64) + Assert.eq(pretrained_config.model.base_model.transformer.mixer.kv_channels, 64) serialized_config = pretrained_config.model.to_dict() expected_config = {"type": "gpt", "distributed": DistributedConfig().to_dict()} @@ -125,15 +131,20 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): if load_config in (ModelConfigType.fast_llm, ModelConfigType.model): expected_config["base_model"] = { "transformer": { + "mixer": { + "type": "attention", + "rotary": {"type": "default"}, + "window_size": 32, + "head_groups": 1, + }, + "mlp": { + "ffn_hidden_size": 4096, # Implicit default, default value + "activation_type": "silu", # Implicit default, non-default value + }, "normalization": {"type": "rms_norm", "implementation": "triton"}, - "rotary": {"type": "default"}, "peft": {"type": "lora", "freeze_others": False, "layers": ["query", "value"]}, "num_layers": 12, "hidden_size": 512, - "ffn_hidden_size": 4096, - "activation_type": "silu", - "head_groups": 1, - "window_size": 32, }, "tie_word_embeddings": False, "vocab_size": 1000, @@ -145,7 +156,8 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "layers": ["query", "value"], } base_model_update["transformer"]["normalization"]["type"] = "layer_norm" - base_model_update["transformer"]["rotary"] = {"type": "none"} + base_model_update["transformer"]["mixer"]["type"] = "attention" + base_model_update["transformer"]["mixer"]["rotary"] = {"type": "none"} expected_config["base_model"] = base_model_update check_equal_nested(serialized_config, expected_config) diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index 56356cf7a..014dee61e 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -22,6 +22,7 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @requires_cuda +@pytest.mark.skip # TODO: mlp.lr_scale @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): get_model_test_dataset() @@ -30,8 +31,7 @@ def test_frozen_weights(model_testing_config): model_frozen = _get_trainer_from_args( args + [ - f"model.base_model.transformer.mlp_lr_scale={[0]*model_ref.config.base_model.transformer.num_experts}", - f"model.base_model.transformer.router_lr_scale=0", + f"model.base_model.transformer.mlp.lr_scale=0", ], 
model_testing_config.model_type, )._multi_stage diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index c3064d987..61efe14ea 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -217,7 +217,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.transformer.dropless_moe=False", + "model.base_model.transformer.mlp.dropless_moe=False", ], num_gpus=2, compare_config=_compare_layer_match, @@ -229,7 +229,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.transformer.dropless_moe=False", + "model.base_model.transformer.mlp.dropless_moe=False", "model.base_model.parallel_embeddings=False", "model.base_model.cross_entropy_splits=4", ], @@ -244,7 +244,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.transformer.dropless_moe=False", + "model.base_model.transformer.mlp.dropless_moe=False", ], num_gpus=4, compare_config=_compare_layer_match, @@ -268,7 +268,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon "model.distributed.sequence_data_parallel=2", "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.transformer.dropless_moe=False", + "model.base_model.transformer.mlp.dropless_moe=False", "batch.breadth_first_micro_batches=4", ], num_gpus=4, @@ -282,7 +282,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon "model.distributed.sequence_data_parallel=2", "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.transformer.dropless_moe=False", + "model.base_model.transformer.mlp.dropless_moe=False", ], num_gpus=4, compare_config=_compare_layer_match, @@ -345,7 +345,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.transformer.dropless_moe=False", + "model.base_model.transformer.mlp.dropless_moe=False", "model.distributed.pipeline_parallel=2", "model.multi_stage.layers_per_stage=2", "batch.breadth_first_micro_batches=4", @@ -361,7 +361,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.transformer.dropless_moe=False", + "model.base_model.transformer.mlp.dropless_moe=False", "model.distributed.pipeline_parallel=2", "model.multi_stage.layers_per_stage=2", "batch.breadth_first_micro_batches=4", @@ -391,7 +391,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon "model.distributed.sequence_data_parallel=2", "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.transformer.dropless_moe=False", + "model.base_model.transformer.mlp.dropless_moe=False", "model.distributed.pipeline_parallel=2", "model.multi_stage.layers_per_stage=2", "batch.micro_sequence_length=256", diff --git 
a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 83f6b50b2..07502c85c 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -159,8 +159,8 @@ def _update_and_add_testing_config( "model.base_model.max_position_embeddings=512", "model.base_model.transformer.num_layers=2", "model.base_model.transformer.hidden_size=256", - "model.base_model.transformer.num_attention_heads=8", - "model.base_model.transformer.head_groups=8", + "model.base_model.transformer.mixer.num_attention_heads=8", + "model.base_model.transformer.mixer.head_groups=8", "model.base_model.transformer.init_method_std=0.022", f"model.base_model.vocab_size={MODEL_TEST_VOCAB_SIZE}", f"model.multi_stage.debug_param_init={_LOG_LEVEL}", @@ -238,7 +238,7 @@ def _update_and_add_testing_config( # Tests MQA. "gpt2", "starcoder", - extra_args=["model.base_model.transformer.head_groups=1"], + extra_args=["model.base_model.transformer.mixer.head_groups=1"], megatron_args=["--group-query-attention"], checkpoint_format=None, groups={ @@ -256,8 +256,8 @@ def _update_and_add_testing_config( "gpt2", "starcoder2", extra_args=[ - "model.base_model.transformer.head_groups=4", - "model.base_model.transformer.rotary.type=default", + "model.base_model.transformer.mixer.head_groups=4", + "model.base_model.transformer.mixer.rotary.type=default", # Unused, but prevents issues with conversion tests. "model.base_model.max_position_embeddings=2048", ], @@ -284,11 +284,11 @@ def _update_and_add_testing_config( "starcoder2", "llama", extra_args=[ - "model.base_model.transformer.gated=True", - "model.base_model.transformer.activation_type=silu", + "model.base_model.transformer.mlp.gated=True", + "model.base_model.transformer.mlp.activation_type=silu", "model.base_model.transformer.add_linear_biases=False", "model.base_model.transformer.normalization.type=rms_norm", - "model.base_model.transformer.ffn_hidden_size=1024", + "model.base_model.transformer.mlp.ffn_hidden_size=1024", "model.base_model.tie_word_embeddings=False", ], megatron_args=[ @@ -314,7 +314,7 @@ def _update_and_add_testing_config( # Tests llama3-style rotary embeddings. "llama", "llama3", - extra_args=["model.base_model.transformer.rotary.type=llama3"], + extra_args=["model.base_model.transformer.mixer.rotary.type=llama3"], # Megatron doesn't support Llama3-style Rotary Embeddings megatron_args=None, checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, @@ -332,7 +332,7 @@ def _update_and_add_testing_config( # Tests yarn-style rotary embeddings. "llama", "llama_yarn", - extra_args=["model.base_model.transformer.rotary.type=yarn"], + extra_args=["model.base_model.transformer.mixer.rotary.type=yarn"], # Megatron doesn't support Yarn-style Rotary Embeddings megatron_args=None, checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, @@ -390,15 +390,16 @@ def _update_and_add_testing_config( # Tests partial linear biases, Qwen2 converter. "llama", "qwen2", + # TODO: replace extra_args=["model.base_model.transformer.add_linear_biases=only_attn_qkv"], # Megatron doesn't support per sub layer biases. megatron_args=None, checkpoint_format=Qwen2GPTHuggingfaceCheckpointFormat, # TODO: Add back generate as `normal` when stable. 
groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.normal, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.basic: ModelTestingGroupAction.broken, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.broken, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, @@ -409,6 +410,7 @@ def _update_and_add_testing_config( # Tests diffusion dream converter. "qwen2", "dream", + # TODO: replace only_attn_qkv extra_args=[], # Megatron doesn't support per sub layer biases. megatron_args=None, @@ -417,7 +419,7 @@ def _update_and_add_testing_config( # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.unimportant, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.broken, ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, @@ -429,7 +431,7 @@ def _update_and_add_testing_config( # Tests sliding window attention, mistral converter. "llama", "mistral", - extra_args=["model.base_model.transformer.window_size=128"], + extra_args=["model.base_model.transformer.mixer.window_size=128"], # Megatron doesn't support sliding windows. megatron_args=None, checkpoint_format=MistralGPTHuggingfaceCheckpointFormat, @@ -449,8 +451,8 @@ def _update_and_add_testing_config( "llama", "mixtral", extra_args=[ - "model.base_model.transformer.num_experts=4", - "model.base_model.transformer.num_experts_per_token=4", + "model.base_model.transformer.mlp.num_experts=4", + "model.base_model.transformer.mlp.num_experts_per_token=4", ], megatron_args=[ "--num-experts=4", @@ -515,7 +517,7 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, @@ -548,7 +550,7 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, # TODO: Implement From 3ef7860858b0022248fefd13dc2ed8d48d804aa3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 29 Aug 2025 19:05:22 -0400 Subject: [PATCH 73/82] Dynamic mlp and block layer creation --- fast_llm/layers/attention/block.py | 6 - fast_llm/layers/attention/config.py | 10 ++ fast_llm/layers/block/block.py | 19 +-- fast_llm/layers/block/config.py | 70 ++++++++-- fast_llm/layers/block/mlp/config.py | 120 +++++++++--------- .../layers/block/mlp/mixture_of_experts.py | 8 +- fast_llm/layers/block/mlp/mlp.py | 7 +- fast_llm/layers/ssm/block.py | 6 - 
fast_llm/layers/ssm/config.py | 22 +++- fast_llm/models/gpt/conversion.py | 1 + fast_llm/models/gpt/megatron.py | 3 +- fast_llm/models/gpt/model.py | 7 +- fast_llm/models/ssm/model.py | 1 - tests/test_config.py | 2 + tests/test_multi_stage.py | 6 +- tests/utils/distributed_configs.py | 8 -- tests/utils/model_configs.py | 1 + 17 files changed, 177 insertions(+), 120 deletions(-) diff --git a/fast_llm/layers/attention/block.py b/fast_llm/layers/attention/block.py index d9fa09cb4..2c2da2014 100644 --- a/fast_llm/layers/attention/block.py +++ b/fast_llm/layers/attention/block.py @@ -1,8 +1,6 @@ -import functools import logging import typing -from fast_llm.layers.attention.attention import Attention from fast_llm.layers.attention.config import AttentionConfig, TransformerConfig from fast_llm.layers.block.block import Block @@ -13,10 +11,6 @@ class TransformerBlock[ConfigType: TransformerConfig](Block[ConfigType]): # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "self_attn" - @functools.cached_property - def _mixer_class(self) -> type[Attention]: - return Attention - @property def _mixer_config(self) -> AttentionConfig: return self._config.mixer diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 0a1fbeafa..ac35bc0af 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -1,5 +1,6 @@ import functools import logging +import typing import warnings from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none @@ -11,6 +12,9 @@ from fast_llm.layers.common.linear.config import AffineLinearConfig from fast_llm.utils import Assert, div +if typing.TYPE_CHECKING: + from fast_llm.layers.attention.attention import Attention + logger = logging.getLogger(__name__) @@ -120,6 +124,12 @@ def _validate(self) -> None: Assert.multiple(self.num_attention_heads, self.head_groups) + @property + def layer_class(self) -> "type[Attention]": + from fast_llm.layers.attention.attention import Attention + + return Attention + @functools.cached_property def projection_size(self): assert self._validated diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 7dc0e6c76..72df8a63e 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -12,7 +12,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.config import BlockConfig, BlockKwargs, MixerConfig from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -174,8 +174,7 @@ def __init__( setattr( self, self._mixer_module_name, - self._mixer_class( - self._mixer_config, + self._mixer_config.get_layer( self._config, self._distributed_config, self._hidden_dim, @@ -185,12 +184,7 @@ def __init__( ), ) - # TODO: Use dynamic type. 
- from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP - from fast_llm.layers.block.mlp.mlp import MLP - - self.mlp = (MixtureOfExpertMLP if self._config.mlp.num_experts > 1 else MLP)( - self._config.mlp, + self.mlp = self._config.mlp.get_layer( self._config, self._distributed_config, self._hidden_dim, @@ -199,14 +193,9 @@ def __init__( self._lr_scale, ) - @functools.cached_property - @abc.abstractmethod - def _mixer_class(self) -> type[BlockLayer]: - pass - @property @abc.abstractmethod - def _mixer_config(self) -> Config: + def _mixer_config(self) -> MixerConfig: pass def setup(self, distributed: Distributed) -> None: diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index f4dca9e6e..b06276297 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,13 +1,16 @@ -import abc import typing -from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.layers.block.mlp.config import MLPConfig +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.peft import TransformerPeftConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.layers.block.block import BlockLayer + # TODO: Generalize these beyond language models? (Ex. vision) @@ -35,19 +38,68 @@ class BlockKwargs: grad_output = "grad_output" -@config_class(registry=True) -class MixerConfig(Config): +@config_class() +class BlockLayerConfig(BaseModelConfig): """ - Base config class for all mixers. - TODO: Generalize to include Attention + A common class for mixers and mlps, which have the same interface. """ _abstract = True + block: "BlockConfig" = Field(init=False) + + @property + def layer_class(self) -> "type[BlockLayer]": + raise NotImplementedError() - @abc.abstractmethod def set_defaults(self, hidden_size: int): + # Opportunity to set defaults that depend on the hidden size. pass + def get_layer( + self, + block_config: "BlockConfig", + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + lr_scale: float | None, + ) -> "BlockLayer": + return self.layer_class( + self, + block_config, + distributed_config, + hidden_dim, + block_index, + name, + lr_scale, + ) + + +@config_class(registry=True) +class MLPBaseConfig(BlockLayerConfig): + _abstract = True + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is MLPBaseConfig and cls.get_subclass(default.get("type")) is None: + from fast_llm.layers.block.mlp.config import MLPConfig + + # Default subclass. + return MLPConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class(registry=True) +class MixerConfig(BlockLayerConfig): + """ + Base config class for mixers. 
+ """ + @classmethod def _from_dict( cls, @@ -67,7 +119,7 @@ def _from_dict( # TODO: Use composition instead class BlockConfig(BaseModelConfig): mixer: MixerConfig = Field() - mlp: MLPConfig = Field() + mlp: MLPBaseConfig = Field() # TODO: Review names normalization: NormalizationConfig = Field( desc="Configuration for the normalization layers architecture.", diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 6d67224ed..7b8d7b8c7 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -1,13 +1,16 @@ import enum +import functools import typing -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.functional.config import ActivationType, MLPRecomputeLevel +from fast_llm.layers.block.config import MLPBaseConfig from fast_llm.layers.common.linear.config import AffineLinearConfig, LinearConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - pass + from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP + from fast_llm.layers.block.mlp.mlp import MLP class MLPLossNames: @@ -20,8 +23,8 @@ class RoutingType(str, enum.Enum): sinkhorn = "sinkhorn" -@config_class() -class MLPConfig(Config): +@config_class(dynamic_type={MLPBaseConfig: "mlp"}) +class MLPConfig(MLPBaseConfig): # TODO: Review names # TODO: Separate MoE? _abstract = False @@ -35,17 +38,60 @@ class MLPConfig(Config): desc="Configuration for the second MLP layer.", hint=FieldHint.architecture, ) - router: LinearConfig = Field( - # TODO: Improve default? - desc="Configuration for the MoE router.", - hint=FieldHint.feature, - ) ffn_hidden_size: int = Field( default=None, desc="Hidden dimension of the MLP intermediate state. Default: 4 * hidden_size.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.architecture) + activation_type: ActivationType = Field( + default=None, + desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", + hint=FieldHint.core, + ) + # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto + mlp_recompute_level: MLPRecomputeLevel = Field( + default=MLPRecomputeLevel.none, + desc="Set which of the MLP intermediate activations will be recomputed during the backward passes. This provides a trade-off between memory and speed.", + hint=FieldHint.performance, + ) + lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for each expert.", + doc="May be used to freeze some experts by setting their scale to zero.", + hint=FieldHint.feature, + ) + + def set_defaults(self, hidden_size: int): + if self.ffn_hidden_size is None: + with self._set_implicit_default(): + self.ffn_hidden_size = 4 * hidden_size + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.activation_type is None: + self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + + super()._validate() + + if self.lr_scale is not None: + Assert.geq(self.lr_scale, 0) + + @property + def layer_class(self) -> "type[MLP]": + from fast_llm.layers.block.mlp.mlp import MLP + + return MLP + + +@config_class(dynamic_type={MLPBaseConfig: "moe"}) +class MoEMLPConfig(MLPConfig): + router: LinearConfig = Field( + # TODO: Improve default? 
+ desc="Configuration for the MoE router.", + hint=FieldHint.feature, + ) num_experts: int = Field( default=1, desc="Number of MLP experts in a Mixture of Expert (MoE) model", @@ -58,12 +104,6 @@ class MLPConfig(Config): hint=FieldHint.architecture, valid=check_field(Assert.geq, 0), ) - num_unshared_experts: int = Field( - init=False, - desc="Number of MLP experts excluding shared ones", - hint=FieldHint.architecture, - valid=check_field(Assert.geq, 0), - ) num_experts_per_token: int = Field( default=1, desc="Active experts for each token in a MoE model.", @@ -75,18 +115,6 @@ class MLPConfig(Config): desc="The routing method, i.e., the method used to assign experts to tokens.", hint=FieldHint.architecture, ) - gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.architecture) - activation_type: ActivationType = Field( - default=None, - desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", - hint=FieldHint.core, - ) - # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto - mlp_recompute_level: MLPRecomputeLevel = Field( - default=MLPRecomputeLevel.none, - desc="Set which of the MLP intermediate activations will be recomputed during the backward passes. This provides a trade-off between memory and speed.", - hint=FieldHint.performance, - ) expert_auxiliary_loss_coefficient: float = Field( default=0.01, desc="Scale of the load balancing auxiliary loss for topk routing.", @@ -105,18 +133,6 @@ class MLPConfig(Config): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - mlp_lr_scale: float | None | tuple[float | None, ...] = Field( - default=None, - desc="Custom learning rate scale for each expert.", - doc="May be used to freeze some experts by setting their scale to zero.", - hint=FieldHint.feature, - ) - router_lr_scale: float | None = Field( - default=None, - desc="Custom learning rate for the MoE router weight.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) dropless_moe: bool = Field( default=True, desc="Evaluate all the experts at once using dropless MoE.", hint=FieldHint.expert ) @@ -127,27 +143,17 @@ class MLPConfig(Config): hint=FieldHint.expert, ) - def set_defaults(self, hidden_size: int): - if self.ffn_hidden_size is None: - with self._set_implicit_default(): - self.ffn_hidden_size = 4 * hidden_size + @property + def layer_class(self) -> "type[MixtureOfExpertMLP]": + from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP - def _validate(self) -> None: - with self._set_implicit_default(): - if self.activation_type is None: - self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + return MixtureOfExpertMLP - self.num_unshared_experts = self.num_experts - self.num_shared_experts + @functools.cached_property + def num_unshared_experts(self): + return self.num_experts - self.num_shared_experts + def _validate(self) -> None: super()._validate() - Assert.leq(self.num_shared_experts, self.num_experts) Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) - - if isinstance(self.mlp_lr_scale, tuple): - Assert.eq(len(self.mlp_lr_scale), self.num_experts) - for scale in self.mlp_lr_scale: - if scale is not None: - Assert.geq(scale, 0) - elif self.mlp_lr_scale is not None: - Assert.geq(self.mlp_lr_scale, 0) diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 2d0343830..49ab34a75 100644 --- 
a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -10,7 +10,7 @@ from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.block.mlp.config import MLPConfig, MLPLossNames, RoutingType +from fast_llm.layers.block.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.utils import Assert, combine_lr_scales @@ -18,18 +18,18 @@ logger = logging.getLogger(__name__) -class MixtureOfExpertMLP[ConfigType: MLPConfig](MLPBase[ConfigType]): +class MixtureOfExpertMLP[ConfigType: MoEMLPConfig](MLPBase[ConfigType]): """ MoeLayer following implementation from https://github.com/NVIDIA/Megatron-LM/blob/46ebc0e4202c980d98900000d455f754a7ff9d4b/megatron/model/transformer.py#L346 With custom routing implementation supporting both topk and sinkhorn routing - TODO: Merge with MLP? TODO: Bias TODO: Sequence-tensor-parallel TODO: Expert parallel """ + _config: ConfigType _group: ProcessGroup def __init__( @@ -54,7 +54,7 @@ def __init__( min_val=self._block_config.init_method_min, max_val=self._block_config.init_method_max, ), - lr_scale=combine_lr_scales(self._config.router_lr_scale, self._lr_scale), + lr_scale=combine_lr_scales(self._lr_scale, self._config.lr_scale), ) dropless_moe = self._config.dropless_moe if dropless_moe and self._sequence_parallel: diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index c1f684619..8366c8cb5 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -11,11 +11,11 @@ from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerSubLayerName -from fast_llm.utils import Assert, combine_lr_scales +from fast_llm.utils import combine_lr_scales class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): - _config: MLPConfig + _config: ConfigType def __init__( self, @@ -33,7 +33,7 @@ def __init__( self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - lr_scale = combine_lr_scales(self._lr_scale, self._config.mlp_lr_scale) + lr_scale = combine_lr_scales(self._lr_scale, self._config.lr_scale) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = self._config.layer_1.get_layer( @@ -83,7 +83,6 @@ def __init__( name: str, lr_scale: float | None, ): - Assert.eq(config.num_experts, 1) super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) def forward( diff --git a/fast_llm/layers/ssm/block.py b/fast_llm/layers/ssm/block.py index 22d01a5cb..10513102f 100644 --- a/fast_llm/layers/ssm/block.py +++ b/fast_llm/layers/ssm/block.py @@ -1,5 +1,3 @@ -import functools - from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.block import Block, BlockLayer @@ -29,10 +27,6 @@ def __init__( self._mixer_class = mixer_class super().__init__(config, distributed_config, hidden_dim, block_index, name, lr_scale, return_input) - @functools.cached_property - def _mixer_class(self) -> type[BlockLayer]: - return self._mixer_class - @property def 
_mixer_config(self) -> SSMConfig: return self._ssm_config diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 24be8f3eb..9f93b9b5d 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -9,7 +9,9 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - pass + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + from fast_llm.layers.ssm.mamba import Mamba + from fast_llm.layers.ssm.mamba2 import Mamba2 class SSMBlockType(enum.StrEnum): @@ -187,6 +189,12 @@ def _validate(self) -> None: # I think this bias is not used in other mamba repos. assert not self.output_layer.bias.enabled + @property + def layer_class(self) -> "type[Mamba]": + from fast_llm.layers.ssm.mamba import Mamba + + return Mamba + @config_class(dynamic_type={MixerConfig: "mamba_2"}) class Mamba2Config(MambaBaseConfig): @@ -249,6 +257,12 @@ def set_defaults(self, hidden_size: int): if self.d_xb is None: self.d_xb = hidden_size + @property + def layer_class(self) -> "type[Mamba2]": + from fast_llm.layers.ssm.mamba2 import Mamba2 + + return Mamba2 + @config_class(dynamic_type={MixerConfig: "discrete_mamba_2"}) class DiscreteMamba2Config(SSMConfig): @@ -288,3 +302,9 @@ class DiscreteMamba2Config(SSMConfig): desc="Chunk size for Mamba2 blocks.", hint=FieldHint.architecture, ) + + @property + def layer_class(self) -> "type[DiscreteMamba2]": + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + + return DiscreteMamba2 diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 915bbced5..42fe849b3 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -580,6 +580,7 @@ class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandle @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantImportParamConverter(fast_llm_names=(("transformer", "mlp", "type"),), fast_llm_value="moe"), ConstantImportParamConverter( fast_llm_names=(("transformer", "mlp", "expert_routing_type"),), fast_llm_value=RoutingType.topk ), diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index 562873675..b4a6b6feb 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -2,6 +2,7 @@ from fast_llm.layers.attention.config import TransformerConfig from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig +from fast_llm.layers.block.mlp.config import MoEMLPConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: @@ -26,7 +27,7 @@ def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: tensor_ = _init_position_embeddings_megatron(meta, tensor, distributed) elif "mlp.router.weight" in meta.tensor_name: tensor_ = _init_moe_router_megatron(meta, tensor, distributed) - elif config.mlp.num_experts > 1 and "mlp.layer_" in meta.tensor_name: + elif isinstance(config.mlp, MoEMLPConfig) and config.mlp.num_experts > 1 and "mlp.layer_" in meta.tensor_name: tensor_ = _init_moe_mlp_megatron(config, meta, tensor, distributed) elif "mlp.layer_2" in meta.tensor_name: tensor_ = _init_transposed_mlp_weight_megatron(meta, tensor, distributed) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 1ac16c230..f2f31ddf2 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -14,7 +14,7 @@ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.attention.preprocessing import 
BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.layers.block.config import BlockDimNames -from fast_llm.layers.block.mlp.config import MLPLossNames, RoutingType +from fast_llm.layers.block.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead @@ -396,7 +396,8 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: def loss_defs(self) -> list[LossDef]: loss_defs = [] if ( - self._config.transformer.mlp.num_experts > 1 + isinstance(self._config.transformer.mlp, MoEMLPConfig) + and self._config.transformer.mlp.num_experts > 1 and self._config.transformer.mlp.expert_routing_type == RoutingType.topk ): loss_defs.append( @@ -468,7 +469,7 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, s (2 + transformer_config.mlp.gated) * transformer_config.mlp.ffn_hidden_size * dense_flops_base - * transformer_config.mlp.num_experts_per_token + * (transformer_config.mlp.num_experts_per_token if isinstance(transformer_config.mlp, MoEMLPConfig) else 1) ) # LM-head diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 9b79e74a3..aebfa6ef4 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -54,7 +54,6 @@ def _get_block( block_index, name, lr_scale, - self._config.ssm_block_type.get_mixer_class(), return_input, ) diff --git a/tests/test_config.py b/tests/test_config.py index 507f8b56a..b1be46b60 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -138,6 +138,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "head_groups": 1, }, "mlp": { + "type": "mlp", "ffn_hidden_size": 4096, # Implicit default, default value "activation_type": "silu", # Implicit default, non-default value }, @@ -158,6 +159,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): base_model_update["transformer"]["normalization"]["type"] = "layer_norm" base_model_update["transformer"]["mixer"]["type"] = "attention" base_model_update["transformer"]["mixer"]["rotary"] = {"type": "none"} + base_model_update["transformer"]["mlp"] = {"type": "mlp"} expected_config["base_model"] = base_model_update check_equal_nested(serialized_config, expected_config) diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index 014dee61e..b1989eb95 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -22,17 +22,13 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @requires_cuda -@pytest.mark.skip # TODO: mlp.lr_scale @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): get_model_test_dataset() args = model_testing_config.config_args + ["run.tensor_logs.save=False"] model_ref = _get_trainer_from_args(args, model_testing_config.model_type)._multi_stage model_frozen = _get_trainer_from_args( - args - + [ - f"model.base_model.transformer.mlp.lr_scale=0", - ], + args + [f"model.base_model.transformer.mlp.lr_scale=0"], model_testing_config.model_type, )._multi_stage diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 61efe14ea..93d7c35cd 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -217,7 +217,6 @@ def 
get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.transformer.mlp.dropless_moe=False", ], num_gpus=2, compare_config=_compare_layer_match, @@ -229,7 +228,6 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.transformer.mlp.dropless_moe=False", "model.base_model.parallel_embeddings=False", "model.base_model.cross_entropy_splits=4", ], @@ -244,7 +242,6 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.transformer.mlp.dropless_moe=False", ], num_gpus=4, compare_config=_compare_layer_match, @@ -268,7 +265,6 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon "model.distributed.sequence_data_parallel=2", "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.transformer.mlp.dropless_moe=False", "batch.breadth_first_micro_batches=4", ], num_gpus=4, @@ -282,7 +278,6 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon "model.distributed.sequence_data_parallel=2", "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.transformer.mlp.dropless_moe=False", ], num_gpus=4, compare_config=_compare_layer_match, @@ -345,7 +340,6 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.transformer.mlp.dropless_moe=False", "model.distributed.pipeline_parallel=2", "model.multi_stage.layers_per_stage=2", "batch.breadth_first_micro_batches=4", @@ -361,7 +355,6 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.transformer.mlp.dropless_moe=False", "model.distributed.pipeline_parallel=2", "model.multi_stage.layers_per_stage=2", "batch.breadth_first_micro_batches=4", @@ -391,7 +384,6 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon "model.distributed.sequence_data_parallel=2", "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.transformer.mlp.dropless_moe=False", "model.distributed.pipeline_parallel=2", "model.multi_stage.layers_per_stage=2", "batch.micro_sequence_length=256", diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 07502c85c..3a13a78c2 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -451,6 +451,7 @@ def _update_and_add_testing_config( "llama", "mixtral", extra_args=[ + "model.base_model.transformer.mlp.type=moe", "model.base_model.transformer.mlp.num_experts=4", "model.base_model.transformer.mlp.num_experts_per_token=4", ], From ecad96b5908f58b94d5121d1002833570f203f84 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 3 Sep 2025 17:31:14 -0400 Subject: [PATCH 74/82] stuff --- examples/mistral.yaml | 6 +- .../engine/config_utils/initialization.py | 124 +++++++- fast_llm/engine/config_utils/parameter.py | 66 
++++- fast_llm/layers/attention/attention.py | 53 ++-- fast_llm/layers/attention/block.py | 16 -- fast_llm/layers/attention/config.py | 31 +- fast_llm/layers/block/block.py | 36 +-- fast_llm/layers/block/config.py | 80 ++++-- fast_llm/layers/block/mlp/config.py | 8 +- .../layers/block/mlp/mixture_of_experts.py | 26 +- fast_llm/layers/block/mlp/mlp.py | 45 ++- fast_llm/layers/block/peft.py | 87 ------ fast_llm/layers/common/linear/config.py | 85 ++++-- .../layers/common/normalization/config.py | 16 +- fast_llm/layers/common/peft/config.py | 25 +- fast_llm/layers/common/peft/lora.py | 23 +- fast_llm/layers/language_model/config.py | 264 ++++++++++-------- fast_llm/layers/language_model/embedding.py | 84 +++--- fast_llm/layers/language_model/head.py | 77 ++--- .../layers/language_model/preprocessing.py | 12 +- fast_llm/layers/ssm/block.py | 32 --- fast_llm/layers/ssm/config.py | 11 +- fast_llm/layers/ssm/discrete_mamba2.py | 41 ++- fast_llm/layers/ssm/mamba.py | 56 ++-- fast_llm/layers/ssm/mamba2.py | 54 ++-- fast_llm/models/gpt/config.py | 26 +- fast_llm/models/gpt/conversion.py | 64 +++-- fast_llm/models/gpt/megatron.py | 8 +- fast_llm/models/gpt/model.py | 109 ++++---- fast_llm/models/gpt/trainer.py | 6 +- fast_llm/models/ssm/config.py | 20 +- fast_llm/models/ssm/model.py | 39 +-- fast_llm/tensor.py | 8 +- fast_llm/utils.py | 26 -- tests/layers/test_lm_head.py | 62 ++-- tests/models/test_checkpoint.py | 6 +- tests/models/test_match_megatron.py | 6 +- tests/test_config.py | 18 +- tests/test_multi_stage.py | 5 +- tests/utils/distributed_configs.py | 6 +- tests/utils/model_configs.py | 12 +- 41 files changed, 986 insertions(+), 793 deletions(-) delete mode 100644 fast_llm/layers/attention/block.py delete mode 100644 fast_llm/layers/block/peft.py delete mode 100644 fast_llm/layers/ssm/block.py diff --git a/examples/mistral.yaml b/examples/mistral.yaml index 0754d74b0..6f4a60143 100644 --- a/examples/mistral.yaml +++ b/examples/mistral.yaml @@ -50,8 +50,10 @@ model: add_linear_biases: false init_method_std: 0.009021 hidden_dropout: 0.0 - vocab_size: 32000 - tie_word_embeddings: false + embeddings_layer: + vocab_size: 32000 + output_layer: + tied_weight: false multi_stage: zero_stage: 2 distributed: diff --git a/fast_llm/engine/config_utils/initialization.py b/fast_llm/engine/config_utils/initialization.py index b60070562..7fefda4b0 100644 --- a/fast_llm/engine/config_utils/initialization.py +++ b/fast_llm/engine/config_utils/initialization.py @@ -1,17 +1,139 @@ import abc import typing +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.utils import Assert + if typing.TYPE_CHECKING: import torch from fast_llm.tensor import ParameterMeta -class Initializer(abc.ABC): +class Initialization(abc.ABC): + """ + A common base class for initializations and initialization configs so both may be used interchangeably. + """ + + @abc.abstractmethod + def get_initializer(self) -> "Initializer": + pass + + +@config_class(registry=True) +class InitializationConfig(Config, Initialization): + _abstract = True + is_default: typing.ClassVar[bool] = False + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is InitializationConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. 
+ return DefaultInitializationConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class() +class DefaultInitializationConfig(InitializationConfig): + # A placeholder indicating that the class default should be used instead. + _abstract = False + is_default = True + + def get_initializer(self) -> "Initializer": + raise NotImplementedError() + + +@config_class(dynamic_type={InitializationConfig: "fill"}) +class FillInitializationConfig(InitializationConfig): + """ + Normal initialization: normal(mean, std).clamp(min,max) + """ + + _abstract = False + + value: float = Field( + default=1, + desc="Initialization value.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + + def get_initializer(self) -> "Initializer": + return init_fill_(self.value) + + +@config_class(dynamic_type={InitializationConfig: "normal"}) +class NormalInitializationConfig(InitializationConfig): + """ + Normal initialization: normal(mean, std).clamp(min,max) + """ + + _abstract = False + + std: float = Field( + default=1, + desc="Standard deviation for normal initialization.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + mean: float = Field( + default=0, + desc="Mean for normal initialization.", + hint=FieldHint.optional, + ) + min: float | None = Field( + default=None, + desc="Min value for initialization clamping.", + hint=FieldHint.optional, + ) + max: float | None = Field( + default=None, + desc="Min value for initialization clamping.", + hint=FieldHint.optional, + ) + + def get_initializer(self) -> "Initializer": + return init_normal_(self.mean, self.std, self.min, self.max) + + +@config_class(dynamic_type={InitializationConfig: "uniform"}) +class UniformInitializationConfig(InitializationConfig): + """ + Uniform initialization: uniform(mean - scale, mean + scale).clamp(min,max) + """ + + _abstract = False + + scale: float = Field( + default=None, + desc="Initialization scale.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + mean: float = Field( + default=None, + desc="Initialization mean.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + + def get_initializer(self) -> "Initializer": + return init_uniform_centered_(self.scale, self.mean) + + +class Initializer(Initialization): @abc.abstractmethod def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: pass + def get_initializer(self) -> "Initializer": + return self + requires_global_initialization = False diff --git a/fast_llm/engine/config_utils/parameter.py b/fast_llm/engine/config_utils/parameter.py index aa84408d2..76416d365 100644 --- a/fast_llm/engine/config_utils/parameter.py +++ b/fast_llm/engine/config_utils/parameter.py @@ -1,15 +1,52 @@ +import math import typing -from fast_llm.config import Config, Field, config_class -from fast_llm.engine.config_utils.initialization import Initializer +from fast_llm.config import Config, Field, FieldHint, config_class +from fast_llm.engine.config_utils.initialization import Initialization, InitializationConfig from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.layers.common.peft.config import PeftConfig if typing.TYPE_CHECKING: from fast_llm.tensor import ParameterMeta +def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): + # Remove `None` entries. 
+ lr_scales = tuple(lr_scale for lr_scale in lr_scales if lr_scale is not None) + if not lr_scales: + # Everything is None + return None + tuple_length = None + # Check if we have tuples, and determine the length. + for lr_scale in lr_scales: + if isinstance(lr_scale, tuple): + if tuple_length is None: + tuple_length = len(lr_scale) + else: + assert len(lr_scale) == tuple_length + if tuple_length is None: + # No tuple: simple product. + return math.prod(lr_scales) + else: + # Tuple(s): use recursion. + return tuple( + combine_lr_scales(*[lr_scale[i] if isinstance(lr_scale, tuple) else lr_scale for lr_scale in lr_scales]) + for i in range(tuple_length) + ) + + @config_class() class ParameterConfig(Config): + initialization: InitializationConfig = Field( + desc="If provided, override the default initialization method set by the parent layer.", + hint=FieldHint.feature, + ) + lr_scale: float | None = Field( + default=None, + desc="Scaling factor for the parameter learning rate." + " Combines multiplicatively with the scale set by the parent layer, if applicable.", + hint=FieldHint.feature, + ) # TODO: Initialization, lr_scale def _validate(self) -> None: @@ -18,20 +55,25 @@ def _validate(self) -> None: def get_parameter( self, dims: tuple[TensorDim, ...], - default_initializer: Initializer, + *, + default_initialization: Initialization, lr_scale: float | None, weight_decay: bool = True, allow_sequence_tensor_parallel: bool = True, + peft: PeftConfig | None, ) -> "ParameterMeta": from fast_llm.tensor import ParameterMeta - return ParameterMeta.from_dims( + out = ParameterMeta.from_dims( dims, - init_method=default_initializer, - lr_scale=lr_scale, + init_method=default_initialization if self.initialization.is_default else self.initialization, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), weight_decay=weight_decay, allow_sequence_tensor_parallel=allow_sequence_tensor_parallel, ) + if peft is not None: + out = peft.apply_weight(out) + return out @config_class() @@ -39,7 +81,6 @@ class OptionalParameterConfig(ParameterConfig): enabled: bool | None = Field( default=None, ) - # TODO: Initialization, lr_scale def _validate(self) -> None: pass @@ -47,21 +88,24 @@ def _validate(self) -> None: def get_parameter( self, dims: tuple[TensorDim, ...], - default_initializer: Initializer, + *, + default_initialization: Initialization, lr_scale: float | None, weight_decay: bool = True, allow_sequence_tensor_parallel: bool = True, default_enabled: bool = False, + peft: PeftConfig | None, ) -> "ParameterMeta|None": - from fast_llm.tensor import ParameterMeta + pass if (self.enabled is None and default_enabled) or self.enabled: - return ParameterMeta.from_dims( + return super().get_parameter( dims, - init_method=default_initializer, + default_initialization=default_initialization, lr_scale=lr_scale, weight_decay=weight_decay, allow_sequence_tensor_parallel=allow_sequence_tensor_parallel, + peft=peft, ) else: return None diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index d7bfae29b..8740ae490 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -11,8 +11,8 @@ from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockDimNames -from fast_llm.layers.block.peft import TransformerSubLayerName -from fast_llm.utils import combine_lr_scales, div +from fast_llm.layers.common.peft.config import 
PeftConfig +from fast_llm.utils import div try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -57,12 +57,24 @@ def __init__( config: ConfigType, block_config: BlockConfig, distributed_config: DistributedConfig, + *, + # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, block_index: int, name: str, lr_scale: float | None, + peft: PeftConfig | None, ): - super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) + super().__init__( + config, + block_config, + distributed_config, + hidden_dim=hidden_dim, + block_index=block_index, + name=name, + lr_scale=lr_scale, + peft=peft, + ) self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -94,29 +106,34 @@ def __init__( self._softmax_scale = self._config.kv_channels ** (-self._config.attention_softmax_scale_power) - lr_scale = combine_lr_scales( - self._lr_scale, - self._config.attention_lr_scale, - ) - # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = self._config.query_layer.get_layer( hidden_dim, query_dim, - default_weight_initializer=init_normal_(std=self._block_config.init_method_std), + default_weight_initialization=init_normal_(std=self._block_config.init_method_std), default_add_bias=self._block_config.add_linear_biases, + default_apply_peft=True, sequence_parallel=self._sequence_parallel, - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) # TODO: Use value config. - self.key_value = self._config.query_layer.get_layer( + self.key_value = self._config.key_layer.get_layer( hidden_dim, key_value_dim, - default_weight_initializer=init_normal_(std=self._block_config.init_method_std), + default_weight_initialization=init_normal_(std=self._block_config.init_method_std), default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=None if self._config.key_layer.apply_peft is None else self._peft, ) + if self._peft is not None and self._config.key_layer.apply_peft is None: + # Default: Apply to value only. + # TODO: Avoid this hack. + self.key_value = self._peft.apply_linear( + self.key_value, True, out_channel_begin=div(key_value_dim.global_size, 2) + ) + self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) # Rotary embeddings. @@ -126,19 +143,15 @@ def __init__( self.dense = self._config.dense_layer.get_layer( dense_dim, hidden_dim, - default_weight_initializer=init_normal_( + default_weight_initialization=init_normal_( std=self._block_config.init_method_std / max(2 * self._block_config.num_layers, 1) ** 0.5, ), default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) - # PEFT. 
- self.query = self._block_config.peft.apply_linear(self.query, TransformerSubLayerName.query) - self.key_value = self._block_config.peft.apply_linear(self.key_value, TransformerSubLayerName.key_value) - self.dense = self._block_config.peft.apply_linear(self.dense, TransformerSubLayerName.dense) - if self._debug.enabled: self._query_dims = ( BlockDimNames.batch, diff --git a/fast_llm/layers/attention/block.py b/fast_llm/layers/attention/block.py deleted file mode 100644 index 2c2da2014..000000000 --- a/fast_llm/layers/attention/block.py +++ /dev/null @@ -1,16 +0,0 @@ -import logging -import typing - -from fast_llm.layers.attention.config import AttentionConfig, TransformerConfig -from fast_llm.layers.block.block import Block - -logger = logging.getLogger(__name__) - - -class TransformerBlock[ConfigType: TransformerConfig](Block[ConfigType]): - # TODO: Standardize to `mixer` - _mixer_module_name: typing.ClassVar[str] = "self_attn" - - @property - def _mixer_config(self) -> AttentionConfig: - return self._config.mixer diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index ac35bc0af..47aa9deea 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -3,12 +3,14 @@ import typing import warnings -from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig -from fast_llm.layers.block.config import BlockConfig, BlockKwargs, MixerConfig +from fast_llm.layers.block.config import BlockKwargs, MixerConfig from fast_llm.layers.common.linear.config import AffineLinearConfig from fast_llm.utils import Assert, div @@ -95,13 +97,6 @@ class AttentionConfig(MixerConfig): hint=FieldHint.optional, valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - attention_lr_scale: float | None = Field( - default=None, - desc="Custom learning rate scale for the Attention projection weights.", - doc="Can be used in muP to scale the Attention learning rate by 1/width_factor", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) attention_softmax_scale_power: float = Field( default=0.5, desc="The scaling power to apply to kv_channel in the attention calculation. " @@ -138,10 +133,16 @@ def projection_size(self): def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: + # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. + # TODO: Find a better solution. 
+ preprocessors: list[Preprocessor] = [self.rotary.get_layer(TensorDim("kv_channels", self.kv_channels))] + if self.do_use_flash_attention(distributed_config): + from fast_llm.layers.attention.preprocessing import FlashAttnVarlenPreprocessor -@config_class() -# TODO: Use composition instead -class TransformerConfig(BlockConfig): - _abstract = False - # TODO: Make this unnecessary - mixer: AttentionConfig = FieldUpdate() + preprocessors.append(FlashAttnVarlenPreprocessor(self, distributed_config)) + else: + from fast_llm.layers.attention.preprocessing import BackupAttentionPreprocessor + + preprocessors.append(BackupAttentionPreprocessor(self, distributed_config)) + return preprocessors diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 72df8a63e..10acd67e0 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -12,7 +12,8 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.block.config import BlockConfig, BlockKwargs, MixerConfig +from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -88,7 +89,7 @@ def __call__[ class BlockLayerBase[ConfigType: Config](Configurable[ConfigType], Module): """ - Base class for blocks, mixer and MLP modules. + Base class for blocks, mixers, MLPs, etc. """ def __init__( @@ -96,11 +97,13 @@ def __init__( config: ConfigType, block_config: BlockConfig, distributed_config: DistributedConfig, + *, # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, block_index: int, name: str, lr_scale: float | None, + peft: PeftConfig | None, ): super().__init__(config, distributed_config) self._block_config = block_config @@ -114,6 +117,7 @@ def __init__( self._block_config.debug_transformer_memory, ) self._lr_scale = lr_scale + self._peft = peft class BlockLayer[ConfigType: Config](BlockLayerBase[ConfigType]): @@ -144,43 +148,43 @@ def __init__( self, config: ConfigType, distributed_config: DistributedConfig, + *, hidden_dim: TensorDim, block_index: int, name: str, lr_scale: float | None, + peft: PeftConfig | None, return_input: bool = False, ): super().__init__( config, config, distributed_config, - hidden_dim, - block_index, - name, - lr_scale, + hidden_dim=hidden_dim, + block_index=block_index, + name=name, + lr_scale=lr_scale, + peft=peft, ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale - self.norm_1 = self._config.peft.apply_other( - self._config.normalization.get_layer(self._hidden_dim, self._lr_scale) - ) - self.norm_2 = self._config.peft.apply_other( - self._config.normalization.get_layer(self._hidden_dim, self._lr_scale) - ) + self.norm_1 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + self.norm_2 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. 
setattr( self, self._mixer_module_name, - self._mixer_config.get_layer( + self._config.mixer.get_layer( self._config, self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} mixer", self._lr_scale, + peft=peft, ), ) @@ -191,13 +195,9 @@ def __init__( self._block_index, f"{self._name} MLP", self._lr_scale, + peft=peft, ) - @property - @abc.abstractmethod - def _mixer_config(self) -> MixerConfig: - pass - def setup(self, distributed: Distributed) -> None: super().setup(distributed) getattr(self, self._mixer_module_name).setup(distributed) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index b06276297..7602dfabe 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,16 +1,18 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.base_model.config import BaseModelConfig, Preprocessor +from fast_llm.engine.config_utils.parameter import combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.peft import TransformerPeftConfig from fast_llm.layers.common.normalization.config import NormalizationConfig +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.layers.block.block import BlockLayer + # TODO: Generalize these beyond language models? (Ex. vision) @@ -47,6 +49,13 @@ class BlockLayerConfig(BaseModelConfig): _abstract = True block: "BlockConfig" = Field(init=False) + lr_scale: float | None = Field( + default=None, + desc="Scaling factor for the layer learning rate." + " Combines multiplicatively with the scale set by the parent and child layers, if applicable.", + hint=FieldHint.feature, + ) + @property def layer_class(self) -> "type[BlockLayer]": raise NotImplementedError() @@ -63,17 +72,22 @@ def get_layer( block_index: int, name: str, lr_scale: float | None, + peft: PeftConfig | None, ) -> "BlockLayer": return self.layer_class( self, block_config, distributed_config, - hidden_dim, - block_index, - name, - lr_scale, + hidden_dim=hidden_dim, + block_index=block_index, + name=name, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), + peft=peft, ) + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: + return [] + @config_class(registry=True) class MLPBaseConfig(BlockLayerConfig): @@ -118,6 +132,7 @@ def _from_dict( @config_class() # TODO: Use composition instead class BlockConfig(BaseModelConfig): + _abstract = False mixer: MixerConfig = Field() mlp: MLPBaseConfig = Field() # TODO: Review names @@ -125,9 +140,11 @@ class BlockConfig(BaseModelConfig): desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) - peft: TransformerPeftConfig = Field( - desc="Configuration for the parameter-efficient fine tuning.", - hint=FieldHint.architecture, + lr_scale: float | None = Field( + default=None, + desc="Scaling factor for the layer learning rate." 
+ " Combines multiplicatively with the scale set by the parent and child layers, if applicable.", + hint=FieldHint.feature, ) # TODO: Review names hidden_dropout: float = Field( @@ -170,13 +187,6 @@ class BlockConfig(BaseModelConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - per_layer_lr_scale: list[float | None] | None = Field( - default=None, - desc="Custom learning rate scale for each layer.", - doc="May be used to freeze some layers by setting their scale to zero.", - hint=FieldHint.feature, - ) - # TODO: Review initialization init_method_std: float = Field( default=None, @@ -184,16 +194,6 @@ class BlockConfig(BaseModelConfig): hint=FieldHint.optional, valid=check_field(Assert.geq, 0), ) - init_method_max: float | None = Field( - default=None, - desc="Max value for clamping initialized weights. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min: float | None = Field( - default=None, - desc="Min value for clamping initialized weights. Default: -float('inf')", - hint=FieldHint.optional, - ) def _validate(self) -> None: with self._set_implicit_default(): @@ -201,10 +201,34 @@ def _validate(self) -> None: # TODO: Review initialization if self.init_method_std is None: self.init_method_std = self.hidden_size**-0.5 - if self.init_method_min is not None and self.init_method_max is not None: - Assert.leq(self.init_method_min, self.init_method_max) self.mixer.set_defaults(self.hidden_size) self.mlp.set_defaults(self.hidden_size) super()._validate() + + def get_layer( + self, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + lr_scale: float | None, + peft: PeftConfig | None = None, + return_input: bool = False, + ): + from fast_llm.layers.block.block import Block + + return Block( + self, + distributed_config, + hidden_dim=hidden_dim, + block_index=block_index, + name=name, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), + peft=peft, + return_input=return_input, + ) + + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: + return self.mixer.get_preprocessors(distributed_config) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 7b8d7b8c7..3e7d96736 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -56,12 +56,6 @@ class MLPConfig(MLPBaseConfig): desc="Set which of the MLP intermediate activations will be recomputed during the backward passes. 
This provides a trade-off between memory and speed.", hint=FieldHint.performance, ) - lr_scale: float | None = Field( - default=None, - desc="Custom learning rate scale for each expert.", - doc="May be used to freeze some experts by setting their scale to zero.", - hint=FieldHint.feature, - ) def set_defaults(self, hidden_size: int): if self.ffn_hidden_size is None: @@ -150,7 +144,7 @@ def layer_class(self) -> "type[MixtureOfExpertMLP]": return MixtureOfExpertMLP @functools.cached_property - def num_unshared_experts(self): + def num_unshared_experts(self) -> int: return self.num_experts - self.num_shared_experts def _validate(self) -> None: diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 49ab34a75..d0d94d88c 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -13,7 +13,8 @@ from fast_llm.layers.block.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss -from fast_llm.utils import Assert, combine_lr_scales +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -37,24 +38,33 @@ def __init__( config: ConfigType, block_config: BlockConfig, distributed_config: DistributedConfig, + *, + # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, block_index: int, name: str, lr_scale: float | None, + peft: PeftConfig | None, ): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not block_config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) + super().__init__( + config, + block_config, + distributed_config, + hidden_dim=hidden_dim, + block_index=block_index, + name=name, + lr_scale=lr_scale, + peft=peft, + ) self.router = self._config.router.get_layer( self._hidden_dim, TensorDim("router_experts", self._config.num_unshared_experts), - default_weight_initializer=init_normal_( - std=self._block_config.init_method_std, - min_val=self._block_config.init_method_min, - max_val=self._block_config.init_method_max, - ), - lr_scale=combine_lr_scales(self._lr_scale, self._config.lr_scale), + default_weight_initialization=init_normal_(std=self._block_config.init_method_std), + lr_scale=self._lr_scale, + peft=self._peft, ) dropless_moe = self._config.dropless_moe if dropless_moe and self._sequence_parallel: diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 8366c8cb5..8b6ede2d8 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -10,8 +10,7 @@ from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.block.mlp.config import MLPConfig -from fast_llm.layers.block.peft import TransformerSubLayerName -from fast_llm.utils import combine_lr_scales +from fast_llm.layers.common.peft.config import PeftConfig class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): @@ -22,44 +21,52 @@ def __init__( config: ConfigType, block_config: BlockConfig, distributed_config: DistributedConfig, + *, + # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, block_index: int, name: str, lr_scale: float | None, + peft: PeftConfig | None, ): - super().__init__(config, block_config, distributed_config, hidden_dim, block_index, 
name, lr_scale) + super().__init__( + config, + block_config, + distributed_config, + hidden_dim=hidden_dim, + block_index=block_index, + name=name, + lr_scale=lr_scale, + peft=peft, + ) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, intermediate_2_dim = self._get_intermediate_dims() self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - lr_scale = combine_lr_scales(self._lr_scale, self._config.lr_scale) - # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = self._config.layer_1.get_layer( hidden_dim, intermediate_1_dim, - default_weight_initializer=init_normal_(std=self._block_config.init_method_std), + default_weight_initialization=init_normal_(std=self._block_config.init_method_std), default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) self.layer_2 = self._config.layer_1.get_layer( intermediate_2_dim, hidden_dim, - default_weight_initializer=init_normal_( + default_weight_initialization=init_normal_( std=self._block_config.init_method_std / max(2 * self._block_config.num_layers, 1) ** 0.5 ), default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, transposed_weight=True, - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) - # PEFT. - self.layer_1 = self._block_config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) - self.layer_2 = self._block_config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) - def _get_intermediate_dims(self): intermediate_2_dim = TensorDim("intermediate", self._config.ffn_hidden_size, self._parallel_dim) if self._config.gated: @@ -73,18 +80,6 @@ def _get_intermediate_dims(self): class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]): _config: MLPConfig - def __init__( - self, - config: ConfigType, - block_config: BlockConfig, - distributed_config: DistributedConfig, - hidden_dim: TensorDim, - block_index: int, - name: str, - lr_scale: float | None, - ): - super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) - def forward( self, input_: torch.Tensor, diff --git a/fast_llm/layers/block/peft.py b/fast_llm/layers/block/peft.py deleted file mode 100644 index ffa40a255..000000000 --- a/fast_llm/layers/block/peft.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -TODO: Generalize beyond transformers. -""" - -import enum -import typing - -from fast_llm.config import Field, FieldHint, config_class -from fast_llm.layers.common.peft.config import LoRAConfig, NoPeftConfig, PeftConfig -from fast_llm.utils import div - -if typing.TYPE_CHECKING: - from fast_llm.layers.common.linear.linear import LinearBase, LinearLike - - -class TransformerSubLayerName(str, enum.Enum): - query = "query" - key = "key" - value_ = "value" - key_value = "key_value" - dense = "dense" - mlp_1 = "mlp_1" - mlp_2 = "mlp_2" - - -@config_class(registry=True) -class TransformerPeftConfig(PeftConfig): - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is TransformerPeftConfig and cls.get_subclass(default.get("type")) is None: - # Default subclass. 
- return TransformerNoPeftConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - - -@config_class(dynamic_type={TransformerPeftConfig: "none"}) -class TransformerNoPeftConfig(NoPeftConfig, TransformerPeftConfig): - pass - - -@config_class(dynamic_type={TransformerPeftConfig: "lora"}) -class TransformerLoRAConfig(LoRAConfig, TransformerPeftConfig): - layers: list[TransformerSubLayerName] = Field( - default=(TransformerSubLayerName.query, TransformerSubLayerName.value_), - desc="The layers on which to apply LoRA.", - hint=FieldHint.feature, - ) - - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - out_channel_begin, out_channel_end = None, None - if layer_type is None or self.layers is None or layer_type in self.layers: - enabled = True - if layer_type == TransformerSubLayerName.key: - out_channel_end = div(linear._out_dim.global_size, 2) - elif layer_type == TransformerSubLayerName.value_: - out_channel_begin = div(linear._out_dim.global_size, 2) - else: - enabled = False - return super().apply_linear(linear, enabled, out_channel_begin, out_channel_end) - - def _validate(self) -> None: - super()._validate() - if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: - # TODO: Add MLP support. - raise NotImplementedError("LoRA not supported for MLP.") - if TransformerSubLayerName.dense in self.layers: - # TODO: Support InputParallelLinear (different output format). - raise NotImplementedError("LoRA not supported for attention dense layer.") - if ( - sum( - name in self.layers - for name in ( - TransformerSubLayerName.key_value, - TransformerSubLayerName.key, - TransformerSubLayerName.value_, - ) - ) - > 1 - ): - raise ValueError( - f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive." - ) diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py index e9dbe9229..2ed97ae66 100644 --- a/fast_llm/layers/common/linear/config.py +++ b/fast_llm/layers/common/linear/config.py @@ -1,10 +1,11 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.engine.config_utils.initialization import Initializer, init_uniform_centered_, init_zeros_ -from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig +from fast_llm.engine.config_utils.initialization import Initialization, init_uniform_centered_, init_zeros_ +from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -18,6 +19,12 @@ class LinearBaseConfig(Config): Configuration for a linear-like layer without bias. """ + lr_scale: float | None = Field( + default=None, + desc="Scaling factor for the layer learning rate." 
+ " Combines multiplicatively with the scale set by the parent layer and individual parameters, if applicable.", + hint=FieldHint.feature, + ) weight: ParameterConfig = Field( desc="Initialization configuration for the weight.", hint=FieldHint.feature, @@ -38,26 +45,39 @@ class AffineLinearBaseConfig(LinearBaseConfig): @config_class() class LinearConfig(LinearBaseConfig): + apply_peft: bool | None = Field( + default=None, + desc="Wrap this layer ." + " Otherwise, treat the layer as a non-peft layer (may be frozen)." + " If not provided, the default set by the parent layer will be used.", + hint=FieldHint.feature, + ) + def get_layer( self, in_dim: TensorDim, out_dim: TensorDim, *, - default_weight_initializer: Initializer, + default_weight_initialization: Initialization, + default_apply_peft: bool = False, sequence_parallel: bool = False, transposed_weight: bool = False, lr_scale: float | None, + peft: PeftConfig | None, ) -> "LinearBase": from fast_llm.layers.common.linear.linear import InputParallelLinear, Linear, OutputParallelLinear + lr_scale = combine_lr_scales(lr_scale, self.lr_scale) weight = self.weight.get_parameter( (in_dim, out_dim) if transposed_weight else (out_dim, in_dim), - default_initializer=default_weight_initializer, + default_initialization=default_weight_initialization, lr_scale=lr_scale, + peft=None, ) + if in_dim.parallel_dim is not None: assert out_dim.parallel_dim is None - return InputParallelLinear( + out = InputParallelLinear( weight, None, transposed_weight=transposed_weight, @@ -65,7 +85,7 @@ def get_layer( sequence_parallel=sequence_parallel, ) elif out_dim.parallel_dim is not None: - return OutputParallelLinear( + out = OutputParallelLinear( weight, None, transposed_weight=transposed_weight, @@ -74,7 +94,12 @@ def get_layer( ) else: assert not sequence_parallel - return Linear(weight, None, transposed_weight=transposed_weight) + out = Linear(weight, None, transposed_weight=transposed_weight) + + if peft is not None: + out = peft.apply_linear(out, default_apply_peft if self.apply_peft is None else self.apply_peft) + + return out @config_class() @@ -84,29 +109,34 @@ def get_layer( in_dim: TensorDim, out_dim: TensorDim, *, - default_weight_initializer: Initializer, - default_bias_initializer: Initializer = init_zeros_, + default_weight_initialization: Initialization, + default_bias_initialization: Initialization = init_zeros_, default_add_bias: bool = True, + default_apply_peft: bool = False, sequence_parallel: bool = False, transposed_weight: bool = False, lr_scale: float | None, + peft: PeftConfig | None, ) -> "LinearBase": from fast_llm.layers.common.linear.linear import InputParallelLinear, Linear, OutputParallelLinear + lr_scale = combine_lr_scales(lr_scale, self.lr_scale) weight = self.weight.get_parameter( (in_dim, out_dim) if transposed_weight else (out_dim, in_dim), - default_initializer=default_weight_initializer, + default_initialization=default_weight_initialization, lr_scale=lr_scale, + peft=None, ) bias = self.bias.get_parameter( (out_dim,), - default_initializer=default_bias_initializer, + default_initialization=default_bias_initialization, lr_scale=lr_scale, default_enabled=default_add_bias, + peft=None, ) if in_dim.parallel_dim is not None: assert out_dim.parallel_dim is None - return InputParallelLinear( + out = InputParallelLinear( weight, bias, transposed_weight=transposed_weight, @@ -114,7 +144,7 @@ def get_layer( sequence_parallel=sequence_parallel, ) elif out_dim.parallel_dim is not None: - return OutputParallelLinear( + out = 
OutputParallelLinear( weight, bias, transposed_weight=transposed_weight, @@ -123,7 +153,12 @@ def get_layer( ) else: assert not sequence_parallel - return Linear(weight, bias, transposed_weight=transposed_weight) + out = Linear(weight, bias, transposed_weight=transposed_weight) + + if peft is not None: + out = peft.apply_linear(out, default_apply_peft if self.apply_peft is None else self.apply_peft) + + return out @config_class() @@ -147,31 +182,37 @@ def get_layer( self, in_dim: TensorDim, *, - default_weight_initializer: Initializer | None = None, - default_bias_initializer: Initializer | None = None, + default_weight_initialization: Initialization | None = None, + default_bias_initialization: Initialization | None = None, default_add_bias: bool = True, default_activation: ActivationType = ActivationType.identity, lr_scale: float | None, + peft: PeftConfig | None, ) -> "CausalConv1d": from fast_llm.layers.common.linear.convolution import CausalConv1d kernel_dim = TensorDim("convolution_kernel", self.kernel_size) - if default_weight_initializer is None: - default_weight_initializer = init_uniform_centered_((in_dim.global_size * kernel_dim.global_size) ** -0.5) - if default_bias_initializer is None: - default_bias_initializer = init_uniform_centered_((in_dim.global_size * kernel_dim.global_size) ** -0.5) + if default_weight_initialization is None: + default_weight_initialization = init_uniform_centered_( + (in_dim.global_size * kernel_dim.global_size) ** -0.5 + ) + if default_bias_initialization is None: + default_bias_initialization = init_uniform_centered_((in_dim.global_size * kernel_dim.global_size) ** -0.5) + lr_scale = (combine_lr_scales(lr_scale, self.lr_scale),) weight = self.weight.get_parameter( (in_dim, scalar_dim, kernel_dim), - default_initializer=default_weight_initializer, + default_initialization=default_weight_initialization, lr_scale=lr_scale, + peft=peft, ) bias = self.bias.get_parameter( (in_dim,), - default_initializer=default_bias_initializer, + default_initialization=default_bias_initialization, lr_scale=lr_scale, default_enabled=default_add_bias, + peft=peft, ) return CausalConv1d( weight, bias, activation=default_activation if self.activation is None else self.activation diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index 569d48b0e..3401e61be 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -4,6 +4,8 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.parameter import combine_lr_scales +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -25,7 +27,12 @@ class NormalizationImplementation(str, enum.Enum): @config_class(registry=True) class NormalizationConfig(BaseModelConfig): - pass + lr_scale: float | None = Field( + default=None, + desc="Scaling factor for the layer learning rate." 
+ " Combines multiplicatively with the scale set by the parent layer and individual parameters, if applicable.", + hint=FieldHint.feature, + ) @property @abc.abstractmethod @@ -35,9 +42,14 @@ def module_class(self) -> type["Normalization"]: def get_layer( self, hidden_dim: "TensorDim", + *, lr_scale: float | None = None, + peft: PeftConfig | None, ) -> "Normalization": - return self.module_class(self, hidden_dim, lr_scale) + out = self.module_class(self, hidden_dim, combine_lr_scales(lr_scale, self.lr_scale)) + if peft is not None: + out = peft.apply_normalization(out) + return out @classmethod def _from_dict( diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py index 7c7834cbd..d0af61cee 100644 --- a/fast_llm/layers/common/peft/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -1,7 +1,6 @@ import typing -from fast_llm.config import Field, FieldHint, config_class -from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.config import Config, Field, FieldHint, config_class if typing.TYPE_CHECKING: import torch @@ -11,8 +10,22 @@ from fast_llm.tensor import ParameterMeta -@config_class() -class PeftConfig(BaseModelConfig): +@config_class(registry=True) +class PeftConfig(Config): + _abstract = True + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is PeftConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. + return NoPeftConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + def apply_linear( self, module: "LinearBase", @@ -34,12 +47,12 @@ def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": return parameter -@config_class() +@config_class(dynamic_type={PeftConfig: "none"}) class NoPeftConfig(PeftConfig): _abstract = False -@config_class() +@config_class(dynamic_type={PeftConfig: "lora"}) class LoRAConfig(PeftConfig): _abstract = False diff --git a/fast_llm/layers/common/peft/lora.py b/fast_llm/layers/common/peft/lora.py index f84967cab..fcff5d496 100644 --- a/fast_llm/layers/common/peft/lora.py +++ b/fast_llm/layers/common/peft/lora.py @@ -5,6 +5,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.common.linear.linear import Linear, LinearBase +from fast_llm.tensor import ParameterMeta def lora_linear( @@ -35,20 +36,22 @@ def lora_linear( middle_dim = TensorDim("lora_middle", rank) module.lora_0 = Linear( - in_dim, - middle_dim, - bias=False, - weight_init_method=module.weight.param_init_method, + ParameterMeta.from_dims( + (in_dim, middle_dim), + init_method=module.weight.param_init_method, + lr_scale=module.weight.lr_scale, + ), + None, transposed_weight=module.transposed_weight, - lr_scale=module.weight.lr_scale, ) module.lora_1 = Linear( - middle_dim, - out_dim, - bias=False, - weight_init_method=module.weight.param_init_method, + ParameterMeta.from_dims( + (middle_dim, out_dim), + init_method=module.weight.param_init_method, + lr_scale=module.weight.lr_scale, + ), + None, transposed_weight=module.transposed_weight, - lr_scale=module.weight.lr_scale, ) old_forward = module._forward diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index abf2f53df..79772bf82 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,14 +1,19 @@ import typing 
from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.parameter import ParameterConfig +from fast_llm.engine.base_model.config import BaseModelConfig, Preprocessor +from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl -from fast_llm.layers.attention.config import TransformerConfig -from fast_llm.layers.attention.rotary.config import NoRotaryConfig -from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockLayerConfig +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.layers.language_model.embedding import LanguageModelEmbedding + from fast_llm.layers.language_model.head import LanguageModelHead + class LanguageModelLossNames: language_model_loss = "language_model_loss" @@ -36,45 +41,64 @@ class LanguageModelKwargs(BlockKwargs): @config_class() -class LanguageModelBaseConfig(BaseModelConfig): - # TODO: block - transformer: TransformerConfig = Field( - desc="Configuration for the transformer architecture.", - hint=FieldHint.architecture, - ) - word_embeddings_layer: ParameterConfig = Field( +class LanguageModelEmbeddingsConfig(BlockLayerConfig): + _abstract = False + word_embeddings: ParameterConfig = Field( desc="Configuration for the word embedding (weight).", hint=FieldHint.architecture, ) - position_embeddings_layer: ParameterConfig = Field( + position_embeddings: OptionalParameterConfig = Field( desc="Configuration for the word embedding (weight).", hint=FieldHint.architecture, ) - output_layer: ParameterConfig = Field( - desc="Configuration for the LM output layer (weight). Ignored for tied embeddings", + vocab_size: int = Field( + default=49152, + desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.", hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), ) - max_position_embeddings: int = Field( + num_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - # TODO: Move to `word_embeddings_layer`/`output_layer`? - vocab_size: int = Field( - default=49152, - desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), + # Tensor-parallel word embeddings + # (Default init std is different, dropout won't match, needs seq_first = False.) + # (disable to allow for sequence-parallel embeddings and logits, better for larger models) + vocab_parallel: bool = Field( + default=True, + desc="Allow for tensor-parallel vocabulary embeddings and output weights.", + doc="Disable to allow for sequence-tensor-parallel input tokens, logits and cross-entropy computation." + " The sequence-tensor-parallel version typically runs faster, but may incur a small memory cost." + " Affects RNG for initialization and dropout.", + hint=FieldHint.performance, ) - # TODO: Move to `position_embeddings_layer.enabled`? - use_position_embeddings: bool = Field( - default=None, - desc="Enable absolute position embeddings. 
Default: Enable unless using rotary embeddings.", + + @property + def layer_class(self) -> "type[LanguageModelEmbedding]": + from fast_llm.layers.language_model.embedding import LanguageModelEmbedding + + return LanguageModelEmbedding + + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: + preprocessors = [] + if self.position_embeddings.enabled: + from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor + + preprocessors.append(PositionEmbeddingPreprocessor(self, distributed_config)) + return preprocessors + + +@config_class() +class LanguageModelHeadConfig(BlockLayerConfig): + _abstract = False + # TODO: Cleanup + output_weight: ParameterConfig = Field( + desc="Configuration for the LM output layer (weight). Ignored for tied embeddings", hint=FieldHint.architecture, ) - # TODO: Move to `output_layer`? (dynamic type?) - tie_word_embeddings: bool = Field( + tied_weight: bool = Field( default=True, desc="Tie the output weights (logits) with the vocabulary embedding.", hint=FieldHint.architecture, @@ -85,38 +109,7 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - init_method_std_embed: float = Field( - default=None, - desc="Initialization scale for the vocabulary embedding and output weights (logits).", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - init_method_max_embed: float | None = Field( - default=None, - desc="Max value for clamping initialized weights of the vocabulary embedding and output (logits).", - hint=FieldHint.feature, - ) - init_method_min_embed: float | None = Field( - default=None, - desc="Min value for clamping initialized weights of the vocabulary embedding and output (logits).", - hint=FieldHint.feature, - ) - enable_dpo: bool | None = Field( - default=False, - desc="Whether to enable DPO loss", - hint=FieldHint.feature, - ) - dpo_beta: float | None = Field( - default=1.0, - desc="Beta value for DPO loss.", - hint=FieldHint.feature, - ) - dpo_reference_model: str | None = Field( - default=None, - desc="Name of the reference model to use for dpo.", - hint=FieldHint.feature, - ) - cross_entropy_impl: CrossEntropyImpl = Field( + cross_entropy_implementation: CrossEntropyImpl = Field( default=CrossEntropyImpl.auto, desc="Implementation for the cross-entropy computation.", hint=FieldHint.performance, @@ -132,31 +125,6 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) - distillation_model: str | None = Field( - default=None, - desc="Name of the reference model to use for knowledge distillation." - "If provided, replace the loss with a distillation loss.", - hint=FieldHint.feature, - ) - # Tensor-parallel word embeddings - # (Default init std is different, dropout won't match, needs seq_first = False.) - # (disable to allow for sequence-parallel embeddings and logits, better for larger models) - parallel_embeddings: bool = Field( - default=True, - desc="Allow for tensor-parallel vocabulary embeddings and output weights.", - doc="Disable to allow for sequence-tensor-parallel input tokens, logits and cross-entropy computation." - " The sequence-tensor-parallel version typically runs faster, but may incur a small memory cost." 
- " Affects RNG for initialization and dropout.", - hint=FieldHint.performance, - ) - sequence_first: bool | None = Field( - default=None, - desc="Override the default dimension ordering", - doc="By default, the hidden states are stored with dimensions (batch, sequence, ...), as it makes attention more efficient." - " However, some settings such as sequence-tensor/data/pipelineo-parallel instead require the ordering (sequence, batch, ...)." - " Setting this parameter overrides the default choice. Note that setting to `False` will either do nothing or raise an error.", - hint=FieldHint.testing, - ) logit_z_loss: float = Field( default=0.0, desc="Regularize the logits with Z-loss.", @@ -182,6 +150,12 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + prediction_loss_coefficient: list[float] | None = Field( + default=None, + desc="Loss coefficient for each prediction head.", + doc="If not provided, all heads are equally weighted.", + hint=FieldHint.feature, + ) teacher_softmax_temperature: float = Field( default=1.0, desc="Divides distillation target logits by this factor.", @@ -189,45 +163,42 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - embeddings_lr_scale: float | None = Field( - default=None, - desc="Learning rate scale for the word embeddings.", - doc="May be used to freeze some layers by setting their scale to zero.", + enable_dpo: bool | None = Field( + default=False, + desc="Whether to enable DPO loss", + hint=FieldHint.feature, + ) + dpo_beta: float | None = Field( + default=1.0, + desc="Beta value for DPO loss.", hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - output_lr_scale: float | None = Field( + dpo_reference_model: str | None = Field( default=None, - desc="Custom learning rate scale for the output weights.", - doc="May be used to freeze the output weights by setting their scale to zero.", + desc="Name of the reference model to use for dpo.", hint=FieldHint.feature, ) - prediction_loss_coefficient: list[float] | None = Field( + distillation_model: str | None = Field( default=None, - desc="Loss coefficient for each prediction head.", - doc="If not provided, all heads are equally weighted.", + desc="Name of the reference model to use for knowledge distillation." 
+ "If provided, replace the loss with a distillation loss.", hint=FieldHint.feature, ) + @property + def layer_class(self) -> "type[LanguageModelHead]": + from fast_llm.layers.language_model.head import LanguageModelHead + + return LanguageModelHead + def _validate(self) -> None: - self.transformer.validate() with self._set_implicit_default(): if self.language_model_loss_factor is None: if self.distillation_model is None: self.language_model_loss_factor = 1.0 else: self.language_model_loss_factor = 0.0 - if self.use_position_embeddings is None: - self.use_position_embeddings = isinstance(self.transformer.mixer.rotary, NoRotaryConfig) - if self.init_method_std_embed is None: - self.init_method_std_embed = self.transformer.init_method_std - if self.init_method_max_embed is None: - self.init_method_max_embed = self.transformer.init_method_max - if self.init_method_min_embed is None: - self.init_method_min_embed = self.transformer.init_method_min super()._validate() - if self.init_method_max_embed is not None and self.init_method_min_embed is not None: - Assert.leq(self.init_method_min_embed, self.init_method_max_embed) if self.distillation_model is not None: if self.prediction_heads > 1: raise NotImplementedError("Multi-token prediction not supported with distillation.") @@ -235,22 +206,66 @@ def _validate(self) -> None: Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) for coeff in self.prediction_loss_coefficient: Assert.geq(coeff, 0) - if self.transformer.per_layer_lr_scale is not None: - # -1 because the first prediction head's transformer layer is accounted for in num_layers - # +1 because the layer index starts at 1 - Assert.eq( - len(self.transformer.per_layer_lr_scale), self.transformer.num_layers + self.prediction_heads - 1 + 1 - ) - @property - def num_absolute_position_embeddings(self) -> int: - # TODO: Rename from max embeddings. - return self.max_position_embeddings if self.use_absolute_position_embeddings else None + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: + preprocessors: list[Preprocessor] = [] - @property - def use_absolute_position_embeddings(self) -> int: - # TODO: Set through num embeddings instead instead. - return self.use_position_embeddings + if self.enable_dpo: # TODO better way to pass in? + from fast_llm.layers.language_model.preprocessing import PreferenceSpanPreprocessor + + preprocessors.append(PreferenceSpanPreprocessor()) + + return preprocessors + + def get_layer( + self, + block_config: "BlockConfig", + distributed_config: DistributedConfig, + embeddings_config: LanguageModelEmbeddingsConfig, + *, + hidden_dim: TensorDim, + block_index: int, + name: str, + lr_scale: float | None, + peft: PeftConfig | None, + prediction_distance: int = 0, + ): + return self.layer_class( + self, + block_config, + distributed_config, + embeddings_config, + hidden_dim=hidden_dim, + block_index=block_index, + name=name, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), + peft=peft, + prediction_distance=prediction_distance, + ) + + +@config_class() +class LanguageModelBaseConfig(BaseModelConfig): + # TODO: block + transformer: BlockConfig = Field( + desc="Configuration for the transformer architecture.", + hint=FieldHint.architecture, + ) + embeddings_layer: LanguageModelEmbeddingsConfig = Field() + output_layer: LanguageModelHeadConfig = Field() + # TODO: Allow overriding in sub-models? 
+ peft: PeftConfig = Field( + desc="Configuration for parameter-efficient fine tuning.", + hint=FieldHint.architecture, + ) + sequence_first: bool | None = Field( + default=None, + desc="Override the default dimension ordering", + doc="By default, the hidden states are stored with dimensions (batch, sequence, ...), as it makes attention more efficient." + " However, some settings such as sequence-tensor/data/pipelineo-parallel instead require the ordering (sequence, batch, ...)." + " Setting this parameter overrides the default choice. Note that setting to `False` will either do nothing or raise an error.", + hint=FieldHint.testing, + ) @classmethod def from_flat_dict( @@ -265,3 +280,10 @@ def from_flat_dict( cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") return super().from_flat_dict(default, strict) + + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: + return ( + self.embeddings_layer.get_preprocessors(distributed_config) + + self.transformer.get_preprocessors(distributed_config) + + self.output_layer.get_preprocessors(distributed_config) + ) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 270f2630b..98904c5e5 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -9,14 +9,16 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.layers.block.block import BlockLayerBase -from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs +from fast_llm.layers.block.config import BlockConfig +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" -class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](BlockLayerBase[ConfigType], Layer): +class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](BlockLayerBase[ConfigType], Layer): """ A language model embedding layer. Consists of word embeddings (tensor-parallel or sequence-tensor-parallel), @@ -25,87 +27,77 @@ class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](BlockLayerBase # Ensure the layer is on its own stage. layer_count: float = 1000.0 + _config: ConfigType def __init__( self, config: ConfigType, + # TODO: Doesn't make much sense. + block_config: BlockConfig, distributed_config: DistributedConfig, + *, + # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, - # TODO: Unnecessary? block_index: int, name: str, + lr_scale: float | None, + peft: PeftConfig | None, ): super().__init__( config, - config.transformer, + block_config, distributed_config, - hidden_dim, - block_index, - name, - # TODO: Add lr scale? 
- None, + hidden_dim=hidden_dim, + block_index=block_index, + name=name, + lr_scale=lr_scale, + peft=peft, ) self._residual_dtype = ( self._distributed_config.optimization_dtype - if config.transformer.full_precision_residual + if self._block_config.full_precision_residual else self._distributed_config.training_dtype ).torch self._sequence_parallel = self._distributed_config.sequence_tensor_parallel - self._parallel_embeddings = self._distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and self._config.vocab_parallel self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) - vocab_dim = TensorDim( - "vocab", self._config.vocab_size, self._parallel_dim if self._parallel_embeddings else None - ) + vocab_dim = TensorDim("vocab", self._config.vocab_size, self._parallel_dim if self._vocab_parallel else None) - if self._parallel_embeddings: + if self._vocab_parallel: self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size self._vocab_end_index = (self._distributed_config.tensor_rank + 1) * vocab_dim.size - self.word_embeddings_weight = self._config.word_embeddings_layer.get_parameter( + self.word_embeddings_weight = self._config.word_embeddings.get_parameter( (vocab_dim, self._hidden_dim), - default_initializer=init_normal_( - std=config.init_method_std_embed, - min_val=config.init_method_min_embed, - max_val=config.init_method_max_embed, - ), - lr_scale=self._config.embeddings_lr_scale, + default_initialization=init_normal_(std=self._block_config.init_method_std), + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.position_embeddings_weight = self._config.position_embeddings.get_parameter( + (TensorDim("position_embeddings", self._config.num_position_embeddings), self._hidden_dim), + default_initialization=init_normal_(std=self._block_config.init_method_std), + allow_sequence_tensor_parallel=not self._vocab_parallel, + lr_scale=self._lr_scale, + peft=self._peft, ) - if self._config.use_absolute_position_embeddings: - self.position_embeddings_weight = self._config.position_embeddings_layer.get_parameter( - (TensorDim("position_embeddings", self._config.max_position_embeddings), self._hidden_dim), - default_initializer=init_normal_( - std=config.init_method_std_embed, - min_val=config.init_method_min_embed, - max_val=config.init_method_max_embed, - ), - allow_sequence_tensor_parallel=not config.parallel_embeddings, - lr_scale=self._config.embeddings_lr_scale, - ) - - # PEFT. 
- self.word_embeddings_weight = self._config.transformer.peft.apply_weight(self.word_embeddings_weight) - if hasattr(self, "position_embeddings_weight"): - self.position_embeddings_weight = self._config.transformer.peft.apply_weight( - self.position_embeddings_weight - ) @torch.compile def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor: - Assert.eq(position_ids is not None, self._config.use_absolute_position_embeddings) + Assert.eq(position_ids is None, self.position_embeddings_weight is None) group = self._parallel_dim.group - if self._parallel_embeddings: + if self._vocab_parallel: input_mask = (input_ >= self._vocab_start_index) * (input_ < self._vocab_end_index) masked_input = (input_ - self._vocab_start_index) * input_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) # noqa embeddings = reduce_forward(embeddings, group) - if self._config.use_absolute_position_embeddings: + if self.position_embeddings_weight is not None: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: if self._sequence_parallel: input_ = split(input_, group=group, dim=0) - if self._config.use_absolute_position_embeddings: + if self.position_embeddings_weight is not None: position_ids = split(position_ids, group=group, dim=0) # handle masked tokens if mask_inputs: @@ -114,14 +106,14 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) else: embeddings = torch.embedding(self.word_embeddings_weight, input_) - if self._config.use_absolute_position_embeddings: + if self.position_embeddings_weight is not None: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if mask_inputs: embeddings = embeddings * input_mask.unsqueeze(2) with set_generator( self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator ): - embeddings = torch.dropout(embeddings, self._config.transformer.hidden_dropout, self.training) + embeddings = torch.dropout(embeddings, self._block_config.hidden_dropout, self.training) return embeddings.to(dtype=self._residual_dtype) def forward( diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index d0c0eb8f9..326bfe313 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -15,11 +15,17 @@ from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import BlockLayerBase -from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.block.config import BlockConfig, BlockDimNames from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss -from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs, LanguageModelLossNames +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.language_model.config import ( + LanguageModelEmbeddingsConfig, + LanguageModelHeadConfig, + LanguageModelKwargs, + LanguageModelLossNames, +) from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT -from fast_llm.tensor import ParameterMeta, TensorMeta +from fast_llm.tensor import TensorMeta 
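# A self-contained, single-process sketch of the vocab-parallel lookup in
# `LanguageModelEmbedding._forward` above: ids outside the local vocabulary shard are shifted and
# masked to 0, embedded, zeroed out, and the partial results are summed. The real layer performs
# the sum with `reduce_forward` over the tensor-parallel group; here it is emulated with a loop.
import torch

def vocab_parallel_embed(input_ids: torch.Tensor, weight_shards: list[torch.Tensor]) -> torch.Tensor:
    vocab_per_shard = weight_shards[0].shape[0]
    out = torch.zeros(*input_ids.shape, weight_shards[0].shape[1])
    for rank, shard in enumerate(weight_shards):
        start = rank * vocab_per_shard
        mask = (input_ids >= start) & (input_ids < start + vocab_per_shard)
        # Out-of-shard ids are mapped to 0 so the lookup stays in range, then zeroed by the mask.
        out = out + torch.embedding(shard, (input_ids - start) * mask) * mask.unsqueeze(-1)
    return out

# Two "ranks", vocabulary of 8, hidden size 4: matches a plain (non-parallel) lookup.
weight = torch.randn(8, 4)
ids = torch.tensor([[1, 6, 3]])
assert torch.allclose(vocab_parallel_embed(ids, list(weight.chunk(2, dim=0))), torch.embedding(weight, ids))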
from fast_llm.utils import Assert, div, get_unique logger = logging.getLogger(__name__) @@ -27,37 +33,46 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelBaseConfig](BlockLayerBase[ConfigType], Layer): +class LanguageModelHead[ConfigType: LanguageModelHeadConfig](BlockLayerBase[ConfigType], Layer): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). + TODO: Cleanup (dynamic type? composition?) """ + _config: ConfigType + def __init__( self, config: ConfigType, + # TODO: Doesn't make much sense. + block_config: BlockConfig, distributed_config: DistributedConfig, + embeddings_config: LanguageModelEmbeddingsConfig, + *, + # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, - # TODO: Unnecessary? block_index: int, name: str, + lr_scale: float | None, + peft: PeftConfig | None, prediction_distance: int, ): super().__init__( config, - config.transformer, + block_config, distributed_config, - hidden_dim, - block_index, - name, - # TODO: Add lr scale? - None, + hidden_dim=hidden_dim, + block_index=block_index, + name=name, + lr_scale=lr_scale, + peft=peft, ) - self._parallel_logits = self._distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and embeddings_config.vocab_parallel self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) - self._sequence_parallel_logits = self._sequence_parallel and not self._config.parallel_embeddings + self._sequence_parallel_logits = self._sequence_parallel and not self._vocab_parallel if self._config.cross_entropy_splits is not None and self._sequence_parallel: - assert not self._parallel_logits + assert not self._vocab_parallel self._loss_coefficient = ( self._config.prediction_loss_coefficient[prediction_distance] @@ -74,9 +89,9 @@ def __init__( self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1 if not self._config.enable_dpo: - self._cross_entropy_impl = self._config.cross_entropy_impl + self._cross_entropy_impl = self._config.cross_entropy_implementation if self._cross_entropy_impl == CrossEntropyImpl.auto: - if self._parallel_logits: + if self._vocab_parallel: self._cross_entropy_impl = CrossEntropyImpl.fused elif TritonConfig.TRITON_ENABLED: self._cross_entropy_impl = CrossEntropyImpl.triton @@ -85,29 +100,25 @@ def __init__( self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) - self.final_norm = self._config.transformer.normalization.get_layer(hidden_dim) + self.final_norm = self._block_config.normalization.get_layer( + hidden_dim, lr_scale=self._lr_scale, peft=self._peft + ) self._vocab_dim = TensorDim( - "vocab", self._config.vocab_size, self._parallel_dim if self._parallel_logits else None + "vocab", embeddings_config.vocab_size, self._parallel_dim if self._vocab_parallel else None ) # Only the first head defines the output weights - if self._prediction_distance == 0 and not self._config.tie_word_embeddings: + if self._prediction_distance == 0 and not self._config.tied_weight: # untie embedding weights - self.output_weights = ParameterMeta.from_dims( + self.output_weights = self._config.output_weight.get_parameter( (self._vocab_dim, hidden_dim), - init_method=init_normal_( - std=self._config.init_method_std_embed, - min_val=self._config.init_method_min_embed, - max_val=self._config.init_method_max_embed, + default_initialization=init_normal_( + 
std=self._block_config.init_method_std, ), - lr_scale=self._config.output_lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) - # PEFT. - self.final_norm = self._config.transformer.peft.apply_other(self.final_norm) - if hasattr(self, "output_weights"): - self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights) - def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: @@ -237,7 +248,7 @@ def _get_targets( return targets def _get_output_weights(self, kwargs: dict) -> torch.Tensor: - if self._config.tie_word_embeddings: + if self._config.tied_weight: return kwargs[WORD_EMBEDDINGS_WEIGHT] if self._prediction_distance > 0: return kwargs[OUTPUT_WEIGHTS] @@ -309,13 +320,13 @@ def _logits_cross_entropy_forward_backward( kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - group = self._parallel_dim.group if self._parallel_logits else None + group = self._parallel_dim.group if self._vocab_parallel else None logits, context = output_parallel_linear_forward( input_=input_, weight=weight, bias=None, group=group, - sequence_parallel=self._sequence_parallel and self._parallel_logits, + sequence_parallel=self._sequence_parallel and self._vocab_parallel, ) if self._config.logit_z_loss > 0.0: diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index 5ba31c0d0..fc1dac299 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -6,7 +6,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_dim import scalar_dim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -18,9 +18,9 @@ class PositionEmbeddingPreprocessor(Preprocessor): _position_ids: torch.Tensor _tensor_cache_max_sequence_length: int = -1 - def __init__(self, config: LanguageModelBaseConfig, distributed_config: DistributedConfig): + def __init__(self, config: LanguageModelEmbeddingsConfig, distributed_config: DistributedConfig): self._config = config - assert config.use_absolute_position_embeddings + assert config.position_embeddings.enabled self._distributed_config = distributed_config def _create_tensors(self, sequence_length: int, device: torch.device) -> None: @@ -28,7 +28,7 @@ def _create_tensors(self, sequence_length: int, device: torch.device) -> None: return self._tensor_cache_max_sequence_length = sequence_length - Assert.leq(sequence_length, self._config.num_absolute_position_embeddings) + Assert.leq(sequence_length, self._config.num_position_embeddings) self._position_ids = torch.arange(0, sequence_length, device=device, dtype=torch.int64) def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: @@ -63,10 +63,6 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: class PreferenceSpanPreprocessor(Preprocessor): - def __init__(self, config: LanguageModelBaseConfig, distributed_config: DistributedConfig): - self._config = config - self._distributed_config = distributed_config - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: return diff --git a/fast_llm/layers/ssm/block.py 
b/fast_llm/layers/ssm/block.py deleted file mode 100644 index 10513102f..000000000 --- a/fast_llm/layers/ssm/block.py +++ /dev/null @@ -1,32 +0,0 @@ -from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.block import Block, BlockLayer -from fast_llm.layers.block.config import BlockConfig -from fast_llm.layers.ssm.config import SSMConfig - - -# TODO: Sort out configs. -class SSMBlock[ConfigType: BlockConfig](Block[ConfigType]): - """ - A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 - """ - - def __init__( - self, - config: ConfigType, - ssm_config: SSMConfig, - distributed_config: DistributedConfig, - hidden_dim: TensorDim, - block_index: int, - name: str, - lr_scale: float | None, - mixer_class: type[BlockLayer], - return_input: bool = False, - ): - self._ssm_config = ssm_config - self._mixer_class = mixer_class - super().__init__(config, distributed_config, hidden_dim, block_index, name, lr_scale, return_input) - - @property - def _mixer_config(self) -> SSMConfig: - return self._ssm_config diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 9f93b9b5d..53e4cf475 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -2,7 +2,7 @@ import math import typing -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.config_utils.parameter import ParameterConfig from fast_llm.layers.block.config import MixerConfig from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig @@ -96,15 +96,6 @@ class SSMConfig(MixerConfig): hint=FieldHint.core, ) - # Learning rate - # lr_scale [MambaLayer, Mamba2, DiscreteMamba2] - mamba_lr_scale: float | None = Field( - default=None, - desc="Learning rate scale for Mamba blocks.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - def set_defaults(self, hidden_size: int): if self.d_inner is None: self.d_inner = 2 * hidden_size diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 0c91b34f8..83e02c7ac 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -10,9 +10,10 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.ssm.config import DiscreteMamba2Config from fast_llm.tensor import ParameterMeta -from fast_llm.utils import combine_lr_scales, div +from fast_llm.utils import div logger = logging.getLogger(__name__) @@ -37,12 +38,24 @@ def __init__( config: ConfigType, block_config: BlockConfig, distributed_config: DistributedConfig, + *, + # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, block_index: int, name: str, lr_scale: float | None, + peft: PeftConfig | None, ): - super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) + super().__init__( + config, + block_config, + distributed_config, + hidden_dim=hidden_dim, + block_index=block_index, + name=name, + lr_scale=lr_scale, + peft=peft, + ) state_dim = TensorDim("state", self._config.state_size) v_head_size_dim = TensorDim("v_head_size", div(self._config.d_inner, 
self._config.n_v_heads)) @@ -71,8 +84,6 @@ def __init__( # local_bc_size = local_head_groups * state self._local_bc_size = bc_dim.size - lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) - # TODO: double check initializations # Projections @@ -80,10 +91,11 @@ def __init__( self.in_proj = self._config.z_layer.get_layer( hidden_dim, inner_projection_dim, - default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), + default_weight_initialization=init_normal_(0, (2 / self._config.d_inner) ** 0.5), default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) if self.in_proj.bias is None: # TODO: Integrate to z_layer config? @@ -91,28 +103,33 @@ def __init__( (inner_dim,), weight_decay=False, init_method=init_zeros_, - lr_scale=lr_scale, + lr_scale=self._lr_scale, ) + if self._peft is not None: + self.z_bias = self._peft.apply_weight(self.z_bias) self.convolution = self._config.convolution_layer.get_layer( convolution_dim, default_activation=ActivationType.silu, - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) # D "skip" parameter self.D = self._config.d_weight.get_parameter( (heads_dim,), - default_initializer=init_ones_, - lr_scale=lr_scale, + default_initialization=init_ones_, weight_decay=False, + lr_scale=self._lr_scale, + peft=self._peft, ) self.out_proj = self._config.output_layer.get_layer( inner_dim, hidden_dim, - default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), + default_weight_initialization=init_normal_(0, (2 / self._config.d_inner) ** 0.5), default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) def forward( diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index e7bd7674b..4bc67c650 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -10,9 +10,10 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.ssm.config import MambaConfig from fast_llm.tensor import ParameterMeta -from fast_llm.utils import Assert, combine_lr_scales, div +from fast_llm.utils import Assert, div try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -60,12 +61,24 @@ def __init__( config: ConfigType, block_config: BlockConfig, distributed_config: DistributedConfig, + *, + # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, block_index: int, name: str, lr_scale: float | None, + peft: PeftConfig | None, ): - super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) + super().__init__( + config, + block_config, + distributed_config, + hidden_dim=hidden_dim, + block_index=block_index, + name=name, + lr_scale=lr_scale, + peft=peft, + ) assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for Mamba" # Tensor dims: @@ -76,59 +89,64 @@ def __init__( inner_projection_dim = ConcatenatedTensorDim("inner_projection", (inner_dim, inner_dim)) x_projection_dim = ConcatenatedTensorDim("x_projection", (dt_rank_dim, state_dim, state_dim)) - lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) - # TODO: Use x_layer 
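# Illustrative shape bookkeeping for the dims defined above (example values, not repo defaults,
# apart from d_inner = 2 * hidden_size, which matches set_defaults in ssm/config.py earlier in this patch):
hidden_size, state_size, dt_rank = 4096, 16, 256
d_inner = 2 * hidden_size
inner_projection = d_inner + d_inner        # ConcatenatedTensorDim(inner_dim, inner_dim): z and x
x_projection = dt_rank + 2 * state_size     # ConcatenatedTensorDim(dt_rank, state, state): dt, B and C
assert (inner_projection, x_projection) == (16384, 288)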
self.in_proj = self._config.z_layer.get_layer( hidden_dim, inner_projection_dim, - default_weight_initializer=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), + default_weight_initialization=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), default_add_bias=self._block_config.add_linear_biases, - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) self.convolution = self._config.convolution_layer.get_layer( inner_dim, - default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), + default_weight_initialization=init_normal_(0, (2 / self._config.d_inner) ** 0.5), default_add_bias=False, default_activation=ActivationType.silu, - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) self.x_proj = self._config.x_projection_layer.get_layer( inner_dim, x_projection_dim, - default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), - lr_scale=lr_scale, + default_weight_initialization=init_normal_(0, (2 / self._config.d_inner) ** 0.5), + lr_scale=self._lr_scale, + peft=self._peft, ) # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj = self._config.dt_layer.get_layer( dt_rank_dim, inner_dim, - default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), - default_bias_initializer=init_dtprojbias( + default_weight_initialization=init_normal_(0, (2 / self._config.d_inner) ** 0.5), + default_bias_initialization=init_dtprojbias( self._config.dt_max, self._config.dt_min, self._config.dt_init_floor ), - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) self.A_log = self._config.a_log_weight.get_parameter( (inner_dim, state_dim), - default_initializer=init_A(self._config.state_size, self._config.d_inner), - lr_scale=lr_scale, + default_initialization=init_A(self._config.state_size, self._config.d_inner), weight_decay=False, + lr_scale=self._lr_scale, + peft=self._peft, ) # D "skip" parameter self.D = self._config.d_weight.get_parameter( (inner_dim,), - default_initializer=init_ones_, - lr_scale=lr_scale, + default_initialization=init_ones_, weight_decay=False, + lr_scale=self._lr_scale, + peft=self._peft, ) self.out_proj = self._config.output_layer.get_layer( inner_dim, hidden_dim, - default_weight_initializer=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), + default_weight_initialization=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), default_add_bias=False, - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) def forward( diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 90fdb343a..e386aa712 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -9,9 +9,10 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.ssm.config import Mamba2Config from fast_llm.layers.ssm.mamba import init_A, init_dtprojbias -from fast_llm.utils import combine_lr_scales, div +from fast_llm.utils import div try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -35,12 +36,24 @@ def __init__( config: ConfigType, block_config: BlockConfig, distributed_config: DistributedConfig, + *, + # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, 
block_index: int, name: str, lr_scale: float | None, + peft: PeftConfig | None, ): - super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) + super().__init__( + config, + block_config, + distributed_config, + hidden_dim=hidden_dim, + block_index=block_index, + name=name, + lr_scale=lr_scale, + peft=peft, + ) num_heads = div(self._config.d_inner, self._config.state_size) num_head_groups = div(self._config.d_xb, self._config.state_size) @@ -72,61 +85,66 @@ def __init__( self._local_xb_size = xb_dim.size convolution_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim - lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) - self.convolution = self._config.convolution_layer.get_layer( convolution_dim, default_activation=ActivationType.silu, - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) # TODO: Use x_layer, b_layer, c_layer self.in_proj = self._config.z_layer.get_layer( hidden_dim, inner_projection_dim, - default_weight_initializer=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), + default_weight_initialization=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) self.dt_in_proj = self._config.dt_input_layer.get_layer( hidden_dim, dt_rank_dim, - default_weight_initializer=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), + default_weight_initialization=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), default_add_bias=self._block_config.add_linear_biases, - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) self.dt_proj = self._config.dt_layer.get_layer( dt_rank_dim, inner_dim, - default_weight_initializer=self._config.dt_init.get_init_method( + default_weight_initialization=self._config.dt_init.get_init_method( self._config.dt_rank**-0.5 * self._config.dt_scale ), - default_bias_initializer=init_dtprojbias( + default_bias_initialization=init_dtprojbias( self._config.dt_max, self._config.dt_min, self._config.dt_init_floor ), sequence_parallel=self._sequence_parallel, - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) self.A_log = self._config.a_log_weight.get_parameter( (inner_dim, state_dim), - default_initializer=init_A(self._config.state_size, self._config.d_inner), - lr_scale=lr_scale, + default_initialization=init_A(self._config.state_size, self._config.d_inner), weight_decay=False, + lr_scale=self._lr_scale, + peft=self._peft, ) # D "skip" parameter self.D = self._config.d_weight.get_parameter( (inner_dim,), - default_initializer=init_ones_, - lr_scale=lr_scale, + default_initialization=init_ones_, weight_decay=False, + lr_scale=self._lr_scale, + peft=self._peft, ) self.out_proj = self._config.output_layer.get_layer( inner_dim, hidden_dim, - default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), + default_weight_initialization=init_normal_(0, (2 / self._config.d_inner) ** 0.5), default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, - lr_scale=lr_scale, + lr_scale=self._lr_scale, + peft=self._peft, ) if self._debug.enabled: diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 3ca2d71fa..370ae4d90 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -190,18 +190,18 @@ class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): def _validate(self) -> 
None: if self.batch.sequence_length is None: # TODO: Drop this. - self.batch.sequence_length = self.model.base_model.max_position_embeddings + self.batch.sequence_length = self.model.base_model.embeddings_layer.num_position_embeddings if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) super()._validate() - if self.model.base_model.use_absolute_position_embeddings: - Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) + if self.model.base_model.embeddings_layer.position_embeddings.enabled: + Assert.geq(self.model.base_model.embeddings_layer.num_position_embeddings, self.batch.sequence_length) - distillation_model = self.model.base_model.distillation_model - dpo_reference_model = self.model.base_model.dpo_reference_model + distillation_model = self.model.base_model.output_layer.distillation_model + dpo_reference_model = self.model.base_model.output_layer.dpo_reference_model - if self.model.base_model.enable_dpo: + if self.model.base_model.output_layer.enable_dpo: assert dpo_reference_model is not None Assert.none(distillation_model) else: @@ -215,12 +215,16 @@ def _validate(self) -> None: Assert.eq(self.reference_models.keys(), expected_names) for reference_model in self.reference_models.values(): - Assert.none(reference_model.model.base_model.distillation_model) - Assert.none(reference_model.model.base_model.dpo_reference_model) + output_layer = reference_model.model.base_model.output_layer + Assert.none(output_layer.distillation_model) + Assert.none(output_layer.dpo_reference_model) # TODO: Support more LM head features. - Assert.none(reference_model.model.base_model.cross_entropy_splits) - Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) - Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) + Assert.none(output_layer.cross_entropy_splits) + Assert.eq( + reference_model.model.base_model.embeddings_layer.vocab_parallel, + self.model.base_model.embeddings_layer.vocab_parallel, + ) + Assert.geq(output_layer.prediction_heads, output_layer.prediction_heads) @classmethod def _from_dict( diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 42fe849b3..5e8b94354 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -24,9 +24,9 @@ from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType -from fast_llm.layers.attention.config import TransformerConfig from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig from fast_llm.layers.attention.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex +from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.block.mlp.config import RoutingType from fast_llm.layers.common.normalization.config import LayerNormalizationConfig from fast_llm.models.gpt.config import ( @@ -128,7 +128,16 @@ class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]), - 
ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), + ConstantImportParamConverter( + fast_llm_names=( + ( + "embeddings_layer", + "position_embeddings", + "enabled", + ), + ), + fast_llm_value=False, + ), RenameParamConverter( fast_llm_names=(("transformer", "mixer", "rotary", "theta"),), export_names=(("rope_theta",),) ), @@ -159,11 +168,21 @@ def _create_config_converters(cls) -> list[ParamConverter]: export_names=(("intermediate_size",),), ), RenameParamConverter( - fast_llm_names=(("vocab_size",),), + fast_llm_names=( + ( + "embeddings_layer", + "vocab_size", + ), + ), export_names=(("vocab_size",),), ), RenameParamConverter( - fast_llm_names=(("tie_word_embeddings",),), + fast_llm_names=( + ( + "output_layer", + "tied_weight", + ), + ), export_names=(("tie_word_embeddings",),), ), ] @@ -191,28 +210,28 @@ def _create_weight_converters( def _create_transformer_layer_converters( self, fast_llm_layer_name: str, hf_layer_name: str, ignore_export: bool = False ) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + transformer_config: BlockConfig = self._model.config.base_model.transformer norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] names_bias_cls = [ # Self-attn ( - f"{fast_llm_layer_name}.self_attn.query", - f"{hf_layer_name}.self_attn.q_proj", + f"{fast_llm_layer_name}.mixer.query", + f"{hf_layer_name}.mixer.q_proj", # TODO: Fix transformer_config.add_linear_biases, QueryWeightConverter, ), ( - f"{fast_llm_layer_name}.self_attn.key_value", - (f"{hf_layer_name}.self_attn.k_proj", f"{hf_layer_name}.self_attn.v_proj"), + f"{fast_llm_layer_name}.mixer.key_value", + (f"{hf_layer_name}.mixer.k_proj", f"{hf_layer_name}.mixer.v_proj"), # TODO: Fix transformer_config.add_linear_biases, KeyValueWeightConverter, ), ( - f"{fast_llm_layer_name}.self_attn.dense", - f"{hf_layer_name}.self_attn.o_proj", + f"{fast_llm_layer_name}.mixer.dense", + f"{hf_layer_name}.mixer.o_proj", # TODO: Fix transformer_config.add_linear_biases, WeightConverter, @@ -262,7 +281,7 @@ def _create_transformer_layer_converters( def _create_lm_head_converters(self) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers - prediction_heads = self._model.config.base_model.prediction_heads + prediction_heads = self._model.config.base_model.output_layer.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] @@ -272,7 +291,7 @@ def _create_lm_head_converters(self) -> list[WeightConverter]: f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias ) # Output weights - if self._model.config.base_model.tie_word_embeddings: + if self._model.config.base_model.output_layer.tied_weight: converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) else: converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) @@ -346,7 +365,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + transformer_config: BlockConfig = self._model.config.base_model.transformer return [ *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", @@ -459,7 +478,7 @@ class 
LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler) architecture: typing.ClassVar[str] = "LlamaForCausalLM" def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + transformer_config: BlockConfig = self._model.config.base_model.transformer return [ *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", @@ -530,7 +549,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + transformer_config: BlockConfig = self._model.config.base_model.transformer return [ *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", @@ -635,13 +654,18 @@ def _create_config_converters(cls) -> list[ParamConverter]: }, ), RenameParamConverter( - fast_llm_names=(("prediction_heads",),), + fast_llm_names=( + ( + "output_layer", + "prediction_heads", + ), + ), export_names=(("prediction_heads",),), ), ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + transformer_config: BlockConfig = self._model.config.base_model.transformer return [ *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", @@ -662,7 +686,7 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig # Override base method to handle the MTP heads def _create_lm_head_converters(self) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers - prediction_heads = self._model.config.base_model.prediction_heads + prediction_heads = self._model.config.base_model.output_layer.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] @@ -687,7 +711,7 @@ def _create_lm_head_converters(self) -> list[WeightConverter]: norm_bias, ) # Output weights - if self._model.config.base_model.tie_word_embeddings: + if self._model.config.base_model.output_layer.tied_weight: converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) else: converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index b4a6b6feb..5c73dbb23 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -1,7 +1,7 @@ import typing -from fast_llm.layers.attention.config import TransformerConfig from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig +from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.block.mlp.config import MoEMLPConfig from fast_llm.utils import Assert, div @@ -14,7 +14,7 @@ def get_init_megatron( - meta: "ParameterMeta", config: TransformerConfig + meta: "ParameterMeta", config: BlockConfig ) -> typing.Callable[["torch.Tensor", "Distributed"], None]: def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: Assert.eq(distributed.config.world_size, 1) @@ -51,7 +51,7 @@ def set_megatron_distributed_seeds(config: "DistributedConfig") -> None: def _init_attention_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + config: BlockConfig, meta: "ParameterMeta", 
tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": # Megatron combines q and kv and inverts the initialization order of qkv and dense layers. # It also always treats the tensors as tensor-parallel and uses a different rotary embedding format. @@ -141,7 +141,7 @@ def _init_moe_router_megatron( def _init_moe_mlp_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + config: BlockConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": assert meta.param_init_method is not None generator = distributed.tp_init_generator if meta.is_tensor_parallel else distributed.pp_init_generator diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index f2f31ddf2..cfd9ae546 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,15 +10,13 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel -from fast_llm.layers.attention.block import TransformerBlock from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.attention.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor +from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.block.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -32,6 +30,8 @@ class GPTBaseModel[ConfigType: GPTBaseModelConfig](BaseModel[ConfigType]): A transformer-based language model generalizing the GPT model architecture. """ + _config: ConfigType + def __init__( self, config: GPTBaseModelConfig, @@ -39,37 +39,16 @@ def __init__( ): self._hidden_dim = TensorDim("hidden", config.transformer.hidden_size) super().__init__(config, distributed_config) - self._use_flash_attention = self._config.transformer.mixer.do_use_flash_attention(distributed_config) if self._config.use_megatron_initialization: for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) param.init_parameter = get_init_megatron(param, self._config.transformer) # Noqa # `self._reference_models` is not populated at this point, so we pass a mutable dict. - self._preprocessors: list[Preprocessor] = [] - if self._config.use_absolute_position_embeddings: - self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._distributed_config)) - # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. - # TODO: Find a better solution. 
- self._preprocessors.append( - self._config.transformer.mixer.rotary.get_layer( - TensorDim("kv_channels", self._config.transformer.mixer.kv_channels) - ) - ) - if self._use_flash_attention: - self._preprocessors.append( - FlashAttnVarlenPreprocessor(self._config.transformer.mixer, self._distributed_config) - ) - else: - self._preprocessors.append( - BackupAttentionPreprocessor(self._config.transformer.mixer, self._distributed_config) - ) - - if self._config.enable_dpo: # TODO better way to pass in? - self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._distributed_config)) + self._preprocessors: list[Preprocessor] = self._config.get_preprocessors(distributed_config) def _get_output_layers(self) -> list[Layer]: layers = [] - for i in range(self._config.prediction_heads): + for i in range(self._config.output_layer.prediction_heads): if i > 0: layers.append( self._get_block( @@ -78,7 +57,7 @@ def _get_output_layers(self) -> list[Layer]: f"MPT head {i} block", # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. - i < self._config.prediction_heads - 1, + i < self._config.output_layer.prediction_heads - 1, ) ) layers.append(self._get_head(i)) @@ -93,7 +72,7 @@ def get_layers(self) -> list[Layer]: f"Block {i + 1}", # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. - self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, + self._config.output_layer.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, ) for i in range(self._config.transformer.num_layers) ], @@ -106,31 +85,37 @@ def _get_block( name: str, return_input: bool = False, ): - lr_scale = ( - None - if self._config.transformer.per_layer_lr_scale is None - else self._config.transformer.per_layer_lr_scale[block_index] - ) - return TransformerBlock( - self._config.transformer, + return self._config.transformer.get_layer( self._distributed_config, - self._hidden_dim, - block_index, - name, - lr_scale, - return_input, + hidden_dim=self._hidden_dim, + block_index=block_index, + name=name, + lr_scale=None, + peft=self._config.peft, + return_input=return_input, ) def _get_embeddings(self): - return LanguageModelEmbedding(self._config, self._distributed_config, self._hidden_dim, 0, "Embeddings") + return self._config.embeddings_layer.get_layer( + self._config.transformer, + self._distributed_config, + hidden_dim=self._hidden_dim, + block_index=0, + name="Embeddings", + lr_scale=None, + peft=self._config.peft, + ) def _get_head(self, prediction_distance): - return LanguageModelHead( - self._config, + return self._config.output_layer.get_layer( + self._config.transformer, self._distributed_config, - self._hidden_dim, - max(self._config.transformer.num_layers + prediction_distance, 1), - f"Language model head {prediction_distance}", + self._config.embeddings_layer, + hidden_dim=self._hidden_dim, + block_index=max(self._config.transformer.num_layers + prediction_distance, 1), + name=f"Language model head {prediction_distance}", + lr_scale=None, + peft=self._config.peft, prediction_distance=prediction_distance, ) @@ -148,7 +133,7 @@ def preprocess_meta( else: micro_batch_size, sequence_length = batch_meta.shape if phase != PhaseType.inference: - sequence_length -= self._config.prediction_heads + sequence_length -= self._config.output_layer.prediction_heads micro_sequence_length = sequence_length truncate_documents = True @@ 
-264,7 +249,7 @@ def preprocess( _, common_kwargs = preprocessed_meta[0] sequence_q = common_kwargs[AttentionKwargs.sequence_q_dim].size sequence_first = common_kwargs[AttentionKwargs.sequence_first] - prediction_heads: int = self._config.prediction_heads + prediction_heads: int = self._config.output_layer.prediction_heads batch.token_ids = batch.token_ids.to( device=self._distributed.device, @@ -346,7 +331,7 @@ def preprocess( loss_mask[start : end + 1, idx] = False else: loss_mask[idx, start : end + 1] = False - if self._config.distillation_model is not None: + if self._config.output_layer.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) kwargs[LanguageModelKwargs.labels] = labels @@ -363,7 +348,7 @@ def embedding(self) -> LanguageModelEmbedding: return self.layers[0] @property - def transformer_layers(self) -> list[TransformerBlock]: + def transformer_layers(self) -> list[Block]: return self.layers[1:-1] @property @@ -372,17 +357,17 @@ def model_head(self) -> LanguageModelHead: @property def model_head_indices(self) -> list[int]: - return sorted([len(self) - 1 - 2 * i for i in range(self._config.prediction_heads)]) + return sorted([len(self) - 1 - 2 * i for i in range(self._config.output_layer.prediction_heads)]) def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: - if self._config.tie_word_embeddings: + if self._config.output_layer.tied_weight: return { WORD_EMBEDDINGS_WEIGHT: ( self.embedding.word_embeddings_weight, (0, *self.model_head_indices), ) } - elif self._config.prediction_heads > 1: + elif self._config.output_layer.prediction_heads > 1: return { OUTPUT_WEIGHTS: ( self.model_head.output_weights, @@ -415,22 +400,22 @@ def loss_defs(self) -> list[LossDef]: count=self._config.transformer.num_layers, ) ) - if self._config.logit_z_loss: + if self._config.output_layer.logit_z_loss: LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1) - if self._config.enable_dpo: + if self._config.output_layer.enable_dpo: loss_defs.append(LossDef(name=LanguageModelLossNames.dpo_loss, formatted_name="dpo loss", count=1)) - if self._config.distillation_model is not None: + if self._config.output_layer.distillation_model is not None: loss_defs.append( LossDef(name=LanguageModelLossNames.distillation_loss, formatted_name="distillation loss", count=1) ) - if self._config.language_model_loss_factor > 0.0: + if self._config.output_layer.language_model_loss_factor > 0.0: loss_defs.append( LossDef(name=LanguageModelLossNames.distil_lm_loss, formatted_name="distillation lm loss", count=1) ) - for i in range(self._config.prediction_heads): + for i in range(self._config.output_layer.prediction_heads): loss_defs.append( LossDef( name=LanguageModelLossNames.multi_token_prediction_loss(i), @@ -452,7 +437,9 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, s consumed_tokens_per_iteration = sequence_length * batch_size - num_transformer_layers = transformer_config.num_layers + self._config.base_model.prediction_heads - 1 + num_transformer_layers = ( + transformer_config.num_layers + self._config.base_model.output_layer.prediction_heads - 1 + ) transformer_flops_base = ( 2 * checkpoint_activations_factor * consumed_tokens_per_iteration * num_transformer_layers ) @@ -477,8 +464,8 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, s 6 * consumed_tokens_per_iteration * transformer_config.hidden_size - * 
self._config.base_model.vocab_size - * self._config.base_model.prediction_heads + * self._config.base_model.embeddings_layer.vocab_size + * self._config.base_model.output_layer.prediction_heads ) # Attention-matrix computation diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 7f2e83ab4..4dbbfbb1c 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -22,13 +22,13 @@ def _get_sampling_parameters( parameters = super()._get_sampling_parameters(parameters, _return_dict=True) parameters.update( { - "vocab_size": self._config.model.base_model.vocab_size, + "vocab_size": self._config.model.base_model.embeddings_layer.vocab_size, "sequence_length": self._config.batch.sequence_length, "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, - "use_preference_loss_spans": self._config.model.base_model.enable_dpo, + "use_preference_loss_spans": self._config.model.base_model.output_layer.enable_dpo, "cross_document_attention": self._config.batch.cross_document_attention, "truncate_documents": self._config.batch.truncate_documents, - "extra_tokens": self._config.model.base_model.prediction_heads, + "extra_tokens": self._config.model.base_model.output_layer.prediction_heads, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 97526ec5b..38839276f 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -158,21 +158,27 @@ def get_trainer_class(cls) -> type["HybridSSMTrainer"]: def _validate(self) -> None: super()._validate() - if (name := self.model.base_model.distillation_model) is None: + if (name := self.model.base_model.output_layer.distillation_model) is None: Assert.empty(self.reference_models) else: Assert.eq(self.reference_models.keys(), {name}) - if self.model.base_model.use_absolute_position_embeddings: - Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) + if self.model.base_model.embeddings_layer.position_embeddings.enabled: + Assert.geq(self.model.base_model.embeddings_layer.num_position_embeddings, self.batch.sequence_length) # if self.model.base_model.distillation_model is not None: # # TODO: Support loss masking for distillation? # assert not self.batch.use_loss_masking_spans for reference_model in self.reference_models.values(): - Assert.none(reference_model.model.base_model.distillation_model) + Assert.none(reference_model.model.base_model.output_layer.distillation_model) # TODO: Support more LM head features. 
- Assert.none(reference_model.model.base_model.cross_entropy_splits) - Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) - Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) + Assert.none(reference_model.model.base_model.output_layer.cross_entropy_splits) + Assert.eq( + reference_model.model.base_model.embeddings_layer.vocab_parallel, + self.model.base_model.embeddings_layer.vocab_parallel, + ) + Assert.geq( + reference_model.model.base_model.output_layer.prediction_heads, + self.model.base_model.output_layer.prediction_heads, + ) @classmethod def get_inference_runner_class(cls) -> type["HybridSSMInferenceRunner"]: diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index aebfa6ef4..3f6238c45 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -1,8 +1,6 @@ import logging import typing -from fast_llm.layers.attention.block import TransformerBlock -from fast_llm.layers.ssm.block import SSMBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -29,33 +27,20 @@ def _get_block( # Decoder block block_type = self._config.hybrid_block_layout[block_index - 1] - lr_scale = ( - None - if self._config.transformer.per_layer_lr_scale is None - else self._config.transformer.per_layer_lr_scale[block_index] - ) - if block_type == SSMBlockType.transformer: - return TransformerBlock( - self._config.transformer, - self._distributed_config, - self._hidden_dim, - block_index, - name, - lr_scale, - return_input, - ) + block_config = self._config.transformer else: - return SSMBlock( - self._config.transformer, - self._config.ssm, - self._distributed_config, - self._hidden_dim, - block_index, - name, - lr_scale, - return_input, - ) + block_config = self._config.transformer.from_dict(self._config.transformer, {"mixer": self._config.ssm}) + + return block_config.get_layer( + self._distributed_config, + hidden_dim=self._hidden_dim, + block_index=block_index, + name=name, + lr_scale=None, + peft=self._config.peft, + return_input=return_input, + ) class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]): diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index f56834c8a..4323efe3f 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -6,7 +6,7 @@ from fast_llm.core.distributed import ReduceOp from fast_llm.core.ops import reduce_op -from fast_llm.engine.config_utils.initialization import Initializer, LambdaInitializer +from fast_llm.engine.config_utils.initialization import Initialization, Initializer, LambdaInitializer from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed @@ -238,7 +238,7 @@ def __init__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None" = None, + init_method: "Initialization | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None" = None, weight_decay: bool = True, # Pass a list to split the parameter in contiguous (dim=0) chunks of equal size for optimization. lr_scale: float | None | tuple[float | None, ...] 
= None, @@ -247,7 +247,9 @@ def __init__( allow_no_grad: bool = False, ): super().__init__(data, tensor_name=tensor_name, dims=dims) - if init_method is not None and not isinstance(init_method, Initializer): + if isinstance(init_method, Initialization): + init_method = init_method.get_initializer() + elif init_method is not None: # Support non-wrapped callables for convenience. assert callable(init_method) init_method = LambdaInitializer(init_method) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 51249c3fa..d13ecaf65 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -10,7 +10,6 @@ if typing.TYPE_CHECKING: import numpy as np - import numpy.typing as npt import torch logger = logging.getLogger(__name__) @@ -348,31 +347,6 @@ def check_equal_nested(config_a, config_b): raise ValueError("\n".join(errors)) -def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): - # Remove `None` entries. - lr_scales = tuple(lr_scale for lr_scale in lr_scales if lr_scale is not None) - if not lr_scales: - # Everything is None - return None - tuple_length = None - # Check if we have tuples, and determine the length. - for lr_scale in lr_scales: - if isinstance(lr_scale, tuple): - if tuple_length is None: - tuple_length = len(lr_scale) - else: - assert len(lr_scale) == tuple_length - if tuple_length is None: - # No tuple: simple product. - return math.prod(lr_scales) - else: - # Tuple(s): use recursion. - return tuple( - combine_lr_scales(*[lr_scale[i] if isinstance(lr_scale, tuple) else lr_scale for lr_scale in lr_scales]) - for i in range(tuple_length) - ) - - class Interrupter: def __init__(self, enabled: bool = True, signals: typing.Sequence[int] = (signal.SIGINT, signal.SIGTERM)): self._enabled = enabled diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 380ab0550..755d143e9 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -107,39 +107,47 @@ def _lm_head( ({}, {"training_dtype": DataType.bfloat16}, False), ({"transformer": {"full_precision_residual": True}}, {"training_dtype": DataType.bfloat16}, False), ({"sequence_first": True}, {}, False), - ({"logit_z_loss": 1e-3}, {}, False), - ({"logits_scale_factor": 5.0}, {}, False), - ({"tie_word_embeddings": False}, {}, False), - ({"prediction_heads": 2}, {}, False), + ({"output_layer": {"logit_z_loss": 1e-3}}, {}, False), + ({"output_layer": {"logits_scale_factor": 5.0}}, {}, False), + ({"output_layer": {"tied_weight": False}}, {}, False), + ({"output_layer": {"prediction_heads": 2}}, {}, False), ({}, {}, True), ( { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.cross_entropy, + "output_layer": { + "distillation_model": "distillation", + "distillation_loss_implementation": DistillationLossImpl.cross_entropy, + } }, {}, False, ), ( { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + "output_layer": { + "distillation_model": "distillation", + "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + } }, {}, False, ), ( { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.cross_entropy, + "output_layer": { + "distillation_model": "distillation", + "distillation_loss_implementation": DistillationLossImpl.cross_entropy, + } }, {}, True, ), ( { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + "output_layer": { + "distillation_model": 
"distillation", + "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + } }, {}, True, @@ -159,8 +167,8 @@ def test_lm_head( "hidden_size": HIDDEN_SIZE, "num_layers": 0, }, - "vocab_size": VOCAB_SIZE, - "cross_entropy_impl": cross_entropy_impl, + "embeddings_layer": {"vocab_size": VOCAB_SIZE}, + "output_layer": {"cross_entropy_implementation": cross_entropy_impl}, }, config_dict, update_type=UpdateType.update, @@ -176,7 +184,7 @@ def test_lm_head( ) sequence_first = config.sequence_first or ( - config.cross_entropy_splits is not None and config.cross_entropy_splits > 1 + config.output_layer.cross_entropy_splits is not None and config.output_layer.cross_entropy_splits > 1 ) input_ = torch.randn( (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), @@ -189,9 +197,9 @@ def test_lm_head( requires_grad=True, ) label_shape = ( - (SEQUENCE_LENGTH + config.prediction_heads - 1, BATCH_SIZE) + (SEQUENCE_LENGTH + config.output_layer.prediction_heads - 1, BATCH_SIZE) if sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH + config.prediction_heads - 1) + else (BATCH_SIZE, SEQUENCE_LENGTH + config.output_layer.prediction_heads - 1) ) if loss_masking: loss_mask = torch.randint(0, 2, label_shape, dtype=torch.bool, device=distributed.device) @@ -201,7 +209,7 @@ def test_lm_head( AttentionKwargs.sequence_first: sequence_first, AttentionKwargs.grad_output: 1.0, } - if config.distillation_model is None: + if config.output_layer.distillation_model is None: target = torch.randint( 0, VOCAB_SIZE, @@ -214,17 +222,17 @@ def test_lm_head( kwargs[LanguageModelKwargs.labels] = target else: - assert config.prediction_heads == 1 + assert config.output_layer.prediction_heads == 1 target = torch.randn( input_.shape[:-1] + (VOCAB_SIZE,), dtype=input_.dtype, device=distributed.device, ) - kwargs[f"{config.distillation_model}_logits"] = target + kwargs[f"{config.output_layer.distillation_model}_logits"] = target if loss_mask is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask - if config.tie_word_embeddings or config.prediction_heads > 1: + if config.output_layer.tied_weight or config.output_layer.prediction_heads > 1: logit_weight = ( torch.empty( VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.training_dtype.torch, device=distributed.device @@ -232,7 +240,7 @@ def test_lm_head( .normal_(config.transformer.init_method_std) .requires_grad_(True) ) - kwargs[WORD_EMBEDDINGS_WEIGHT if config.tie_word_embeddings else OUTPUT_WEIGHTS] = logit_weight + kwargs[WORD_EMBEDDINGS_WEIGHT if config.output_layer.tied_weight else OUTPUT_WEIGHTS] = logit_weight else: logit_weight = None @@ -264,9 +272,9 @@ def test_lm_head( loss_mask, rms_weight=ref_rms_weight, logit_weight=ref_logit_weight, - logit_scale_factor=config.logits_scale_factor, - logit_z_loss=config.logit_z_loss, - distillation_loss_implementation=config.distillation_loss_implementation, + logit_scale_factor=config.output_layer.logits_scale_factor, + logit_z_loss=config.output_layer.logit_z_loss, + distillation_loss_implementation=config.output_layer.distillation_loss_implementation, ) # Prepare LM head inputs @@ -283,7 +291,7 @@ def test_lm_head( loss_keys = {loss_name} if ref_z_loss is not None: loss_keys.add("z_loss") - if config.distillation_model is not None: + if config.output_layer.distillation_model is not None: loss_keys.add("distillation_loss") loss_keys.add("distil_lm_loss") losses = {key: [] for key in loss_keys} @@ -293,7 +301,7 @@ def test_lm_head( threshold = 1e-5 if 
distributed.config.training_dtype == DataType.float32 else 5e-3 min_threshold = ( 1e-5 if distributed.config.training_dtype == DataType.float32 else 1e-4 - ) * config.logits_scale_factor + ) * config.output_layer.logits_scale_factor Assert.eq(losses.keys(), loss_keys) Assert.eq(len(losses[loss_name]), 1) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 6f4631320..ed911fc8a 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -301,7 +301,11 @@ def test_huggingface_model(model_testing_config, get_convert_path): ) ) test_input = torch.randint( - 0, model_ref.config.fast_llm_config.base_model.vocab_size, size=(4, 100), dtype=torch.int64, device="cuda" + 0, + model_ref.config.fast_llm_config.base_model.embeddings_layer.vocab_size, + size=(4, 100), + dtype=torch.int64, + device="cuda", ) output_ref = model_ref(test_input) model_from_fast_llm = hf_class.from_pretrained(fast_llm_path) diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 5ff998bfa..fdb908b0d 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -38,9 +38,9 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co assert model_testing_config.megatron_args is not None ignore_tensors = ( - ".self_attn.query_key_value.", - ".self_attn.query.", - ".self_attn.key_value.", + ".mixer.query_key_value.", + ".mixer.query.", + ".mixer.key_value.", ".mlp.layer_2.weight", ".mlp.experts.", ) diff --git a/tests/test_config.py b/tests/test_config.py index b1be46b60..8954114f7 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -88,7 +88,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "num_layers": 12, # Default "hidden_size": 1024, # Default }, - "tie_word_embeddings": False, + "output_layer": {"tied_weight": False}, }, "multi_stage": {"zero_stage": 3}, "distributed": {"training_dtype": "bfloat16"}, @@ -107,10 +107,10 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, # rotary: Don't override nested. 
"normalization": {"implementation": "triton"}, # Update non-default nested - "peft": {"type": "lora", "freeze_others": False}, # Update default nested, change type "hidden_size": 512, # Override, affects derived value (kv channels) }, - "vocab_size": 1000, + "embeddings_layer": {"vocab_size": 1000}, + "peft": {"type": "lora", "freeze_others": False}, # Update default nested, change type } pretrained_config = PretrainedGPTModelConfig.from_dict( { @@ -143,23 +143,19 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "activation_type": "silu", # Implicit default, non-default value }, "normalization": {"type": "rms_norm", "implementation": "triton"}, - "peft": {"type": "lora", "freeze_others": False, "layers": ["query", "value"]}, "num_layers": 12, "hidden_size": 512, }, - "tie_word_embeddings": False, - "vocab_size": 1000, + "embeddings_layer": {"vocab_size": 1000}, + "output_layer": {"tied_weight": False}, + "peft": {"type": "lora", "freeze_others": False}, } else: - base_model_update["transformer"]["peft"] = { - "type": "lora", - "freeze_others": False, - "layers": ["query", "value"], - } base_model_update["transformer"]["normalization"]["type"] = "layer_norm" base_model_update["transformer"]["mixer"]["type"] = "attention" base_model_update["transformer"]["mixer"]["rotary"] = {"type": "none"} base_model_update["transformer"]["mlp"] = {"type": "mlp"} + base_model_update["peft"] = {"type": "lora", "freeze_others": False} expected_config["base_model"] = base_model_update check_equal_nested(serialized_config, expected_config) diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index b1989eb95..1b49dcfcc 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,8 +3,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.attention.block import TransformerBlock -from fast_llm.layers.ssm.block import SSMBlock +from fast_llm.layers.block.block import Block from fast_llm.utils import Assert from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup @@ -37,7 +36,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, SSMBlock)) else 0 + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, Block) else 0 for layer in model_ref.base_model.layers ] for weight_buffer_ref, weight_buffer_frozen in zip( diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 93d7c35cd..5c3ecd8a2 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -110,7 +110,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon DistributedTestingConfig( name="ce4", compare="simple", - config_args=["model.base_model.cross_entropy_splits=4"], + config_args=["model.base_model.output_layer.cross_entropy_splits=4"], num_gpus=1, compare_config=_compare_layer_mismatch, ), @@ -228,8 +228,8 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.parallel_embeddings=False", - "model.base_model.cross_entropy_splits=4", + "model.base_model.embeddings_layer.vocab_parallel=False", + 
"model.base_model.output_layer.cross_entropy_splits=4", ], num_gpus=2, compare_config=_compare_layer_match, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 3a13a78c2..9ba266f12 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -156,13 +156,14 @@ def _update_and_add_testing_config( "training.logs.interval=1", "run.tensor_logs.save=True", "run.tensor_logs.show=False", - "model.base_model.max_position_embeddings=512", + "model.base_model.embeddings_layer.position_embeddings.enabled=True", + "model.base_model.embeddings_layer.num_position_embeddings=512", + f"model.base_model.embeddings_layer.vocab_size={MODEL_TEST_VOCAB_SIZE}", "model.base_model.transformer.num_layers=2", "model.base_model.transformer.hidden_size=256", "model.base_model.transformer.mixer.num_attention_heads=8", "model.base_model.transformer.mixer.head_groups=8", "model.base_model.transformer.init_method_std=0.022", - f"model.base_model.vocab_size={MODEL_TEST_VOCAB_SIZE}", f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", f"model.multi_stage.debug_layer_gradients={_LOG_LEVEL}", @@ -258,8 +259,7 @@ def _update_and_add_testing_config( extra_args=[ "model.base_model.transformer.mixer.head_groups=4", "model.base_model.transformer.mixer.rotary.type=default", - # Unused, but prevents issues with conversion tests. - "model.base_model.max_position_embeddings=2048", + "model.base_model.embeddings_layer.position_embeddings.enabled=False", ], megatron_args=[ "--group-query-attention", @@ -289,7 +289,7 @@ def _update_and_add_testing_config( "model.base_model.transformer.add_linear_biases=False", "model.base_model.transformer.normalization.type=rms_norm", "model.base_model.transformer.mlp.ffn_hidden_size=1024", - "model.base_model.tie_word_embeddings=False", + "model.base_model.output_layer.tied_weight=False", ], megatron_args=[ "--swiglu", @@ -370,7 +370,7 @@ def _update_and_add_testing_config( # Tests multi-token prediction, custom HF model and converter. "llama", "llama_mtp", - extra_args=["model.base_model.prediction_heads=4"], + extra_args=["model.base_model.output_layer.prediction_heads=4"], # Megatron doesn't support multi-token prediction. megatron_args=None, checkpoint_format=MTPLlamaGPTHuggingfaceCheckpointFormat, From 3fd092c0d586816b1503f93afc631bab02a0798d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 3 Sep 2025 18:34:18 -0400 Subject: [PATCH 75/82] fix --- fast_llm/engine/multi_stage/stage.py | 6 +++++- fast_llm/models/gpt/conversion.py | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 40ef07f67..0bbc86f18 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -121,7 +121,11 @@ def forward( # Last layer does not provide output if output is not None: meta = self._meta_outputs[i] - output_global, _ = meta.local_to_global(output.detach()) + if output.shape == meta.shape: + output_global, _ = meta.local_to_global(output.detach()) + else: + # TODO: Handle variable shape. 
+ output_global = output kwargs["hidden_states"][self._layer_range[i]] = { "layer_type": type(layer).__name__, "tensor": output_global, diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 5e8b94354..365f84d52 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -217,21 +217,21 @@ def _create_transformer_layer_converters( # Self-attn ( f"{fast_llm_layer_name}.mixer.query", - f"{hf_layer_name}.mixer.q_proj", + f"{hf_layer_name}.self_attn.q_proj", # TODO: Fix transformer_config.add_linear_biases, QueryWeightConverter, ), ( f"{fast_llm_layer_name}.mixer.key_value", - (f"{hf_layer_name}.mixer.k_proj", f"{hf_layer_name}.mixer.v_proj"), + (f"{hf_layer_name}.self_attn.k_proj", f"{hf_layer_name}.self_attn.v_proj"), # TODO: Fix transformer_config.add_linear_biases, KeyValueWeightConverter, ), ( f"{fast_llm_layer_name}.mixer.dense", - f"{hf_layer_name}.mixer.o_proj", + f"{hf_layer_name}.self_attn.o_proj", # TODO: Fix transformer_config.add_linear_biases, WeightConverter, From 1a3497cc0ee0794f092087c545d2b4e4478adef4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 3 Sep 2025 18:57:44 -0400 Subject: [PATCH 76/82] stuff --- Megatron-LM | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Megatron-LM b/Megatron-LM index f02b413f7..89f391e30 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit f02b413f793af05ade3893bccd8aef6d644d3edf +Subproject commit 89f391e300e10a5361f5ebf4c40ac7fa69c16562 From b6e7fce0214522063852f01cffdb43283a8698a5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 3 Sep 2025 20:43:30 -0400 Subject: [PATCH 77/82] stuff --- fast_llm/layers/ssm/config.py | 142 ++++++++++++++++++++---------- fast_llm/layers/ssm/mamba.py | 36 ++------ fast_llm/layers/ssm/mamba2.py | 15 ++-- fast_llm/models/ssm/conversion.py | 13 +-- 4 files changed, 105 insertions(+), 101 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 53e4cf475..a81f29833 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -3,15 +3,18 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, LambdaInitializer from fast_llm.engine.config_utils.parameter import ParameterConfig from fast_llm.layers.block.config import MixerConfig from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: + import torch + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 from fast_llm.layers.ssm.mamba import Mamba - from fast_llm.layers.ssm.mamba2 import Mamba2 + from fast_llm.tensor import ParameterMeta class SSMBlockType(enum.StrEnum): @@ -45,11 +48,6 @@ class DTInitType(enum.StrEnum): constant = "constant" random = "random" - def get_init_method(self, scale: float) -> "Initializer": - from fast_llm.engine.config_utils.initialization import init_fill_, init_uniform_centered_ - - return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) - @config_class() class SSMConfig(MixerConfig): @@ -127,38 +125,11 @@ class MambaBaseConfig(SSMConfig): hint=FieldHint.architecture, ) - # Initialization - # dt_bias_initialization_min [Mamba, Mamba2] - dt_min: float = Field( - default=0.001, - desc="Minimum step size for discretization", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) - # 
dt_bias_initialization_max [Mamba, Mamba2] - dt_max: float = Field( - default=0.1, - desc="Maximum step size for discretization", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) - # dt_bias_initialization_floor [Mamba, Mamba2] - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) - def set_defaults(self, hidden_size: int): super().set_defaults(hidden_size) if self.dt_rank is None: self.dt_rank = math.ceil(hidden_size / 16) - def _validate(self) -> None: - super()._validate() - Assert.geq(self.dt_max, self.dt_min) - @config_class(dynamic_type={MixerConfig: "mamba"}) class MambaConfig(MambaBaseConfig): @@ -228,21 +199,6 @@ class Mamba2Config(MambaBaseConfig): hint=FieldHint.architecture, ) - # Initialization - # dt_weight_initialization_method [Mamba2] - dt_init: DTInitType = Field( - default=DTInitType.random, - desc="Initialization method for dt", - hint=FieldHint.core, - ) - # dt_weight_initialization_scale [Mamba2] - dt_scale: float = Field( - default=1.0, - desc="Scale for dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) - def set_defaults(self, hidden_size: int): super().set_defaults(hidden_size) if self.d_xb is None: @@ -299,3 +255,93 @@ def layer_class(self) -> "type[DiscreteMamba2]": from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 return DiscreteMamba2 + + +@config_class(dynamic_type={InitializationConfig: "mamba_dt_bias"}) +class MambaDTBiasInitializationConfig(InitializationConfig): + """ + Configuration for the common Mamba DT bias initialization scheme. + """ + + _abstract = False + # dt_bias_initialization_min [Mamba, Mamba2] + min_step_size: float = Field( + default=0.001, + desc="Minimum step size for discretization", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + # dt_bias_initialization_max [Mamba, Mamba2] + max_step_size: float = Field( + default=0.1, + desc="Maximum step size for discretization", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + # dt_bias_initialization_floor [Mamba, Mamba2] + floor: float = Field( + default=1e-4, + desc="Minimum value for initializing dt", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + + def _validate(self) -> None: + super()._validate() + Assert.geq(self.max_step_size, self.min_step_size) + + def get_initializer(self) -> Initializer: + return init_dtprojbias(self.min_step_size, self.max_step_size, self.floor) + + +@config_class(dynamic_type={InitializationConfig: "mamba_a"}) +class MambaAInitializationConfig(InitializationConfig): + """ + Initialization configuration for Mamba A parameter. + Not particularly useful outside the default A initialization, but still made available for convenience. + """ + + _abstract = False + # dt_bias_initialization_min [Mamba, Mamba2] + state_size: int = Field( + desc="State size. Needs to be repeated here so the initializer knows about it.", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + # dt_bias_initialization_max [Mamba, Mamba2] + d_inner: int = Field( + desc="Inner dimension. 
Needs to be repeated here so the initializer knows about it.", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + + def get_initializer(self) -> Initializer: + return init_a(self.state_size, self.d_inner) + + +def init_dtprojbias( + min_step_size: float = 0.001, max_step_size: float = 0.1, floor: float = 1e-4 +) -> LambdaInitializer: + def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator"): # noqa + import torch + + tensor.uniform_(math.log(min_step_size), math.log(max_step_size), generator=generator).exp_().clamp_min_(floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + tensor.add_(torch.log(-torch.expm1(-tensor))) + + return LambdaInitializer(init_) + + +def init_a(d_state, d_inner) -> LambdaInitializer: + def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa + import torch + + Assert.eq(tensor.numel(), d_state * d_inner) + torch.log( + torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device) + .unsqueeze(0) + .expand(d_inner, d_state), + out=tensor, + ) + + return LambdaInitializer(init_, requires_global_initialization=True) diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 4bc67c650..e98201c67 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -1,19 +1,17 @@ import logging -import math import typing import torch -from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ +from fast_llm.engine.config_utils.initialization import init_normal_, init_ones_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.ssm.config import MambaConfig -from fast_llm.tensor import ParameterMeta -from fast_llm.utils import Assert, div +from fast_llm.layers.ssm.config import MambaConfig, init_a, init_dtprojbias +from fast_llm.utils import div try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -31,28 +29,6 @@ """ -def init_A(d_state, d_inner) -> LambdaInitializer: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - Assert.eq(tensor.numel(), d_state * d_inner) - torch.log( - torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device) - .unsqueeze(0) - .expand(d_inner, d_state), - out=tensor, - ) - - return LambdaInitializer(init_, requires_global_initialization=True) - - -def init_dtprojbias(dt_max: float, dt_min: float, dt_init_floor: float) -> LambdaInitializer: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min_(dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - tensor.add_(torch.log(-torch.expm1(-tensor))) - - return LambdaInitializer(init_) - - class Mamba[ConfigType: MambaConfig](BlockLayer[ConfigType]): _mixer_name: typing.ClassVar[str] = "mamba" @@ -119,15 +95,13 @@ def __init__( dt_rank_dim, inner_dim, default_weight_initialization=init_normal_(0, (2 / self._config.d_inner) ** 0.5), - default_bias_initialization=init_dtprojbias( - 
self._config.dt_max, self._config.dt_min, self._config.dt_init_floor - ), + default_bias_initialization=init_dtprojbias(), lr_scale=self._lr_scale, peft=self._peft, ) self.A_log = self._config.a_log_weight.get_parameter( (inner_dim, state_dim), - default_initialization=init_A(self._config.state_size, self._config.d_inner), + default_initialization=init_a(self._config.state_size, self._config.d_inner), weight_decay=False, lr_scale=self._lr_scale, peft=self._peft, diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index e386aa712..9c7c2e97c 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,15 +3,14 @@ import torch -from fast_llm.engine.config_utils.initialization import init_normal_, init_ones_ +from fast_llm.engine.config_utils.initialization import init_normal_, init_ones_, init_uniform_centered_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.ssm.config import Mamba2Config -from fast_llm.layers.ssm.mamba import init_A, init_dtprojbias +from fast_llm.layers.ssm.config import Mamba2Config, init_a, init_dtprojbias from fast_llm.utils import div try: @@ -112,19 +111,15 @@ def __init__( self.dt_proj = self._config.dt_layer.get_layer( dt_rank_dim, inner_dim, - default_weight_initialization=self._config.dt_init.get_init_method( - self._config.dt_rank**-0.5 * self._config.dt_scale - ), - default_bias_initialization=init_dtprojbias( - self._config.dt_max, self._config.dt_min, self._config.dt_init_floor - ), + default_weight_initialization=init_uniform_centered_(self._config.dt_rank**-0.5), + default_bias_initialization=init_dtprojbias(), sequence_parallel=self._sequence_parallel, lr_scale=self._lr_scale, peft=self._peft, ) self.A_log = self._config.a_log_weight.get_parameter( (inner_dim, state_dim), - default_initialization=init_A(self._config.state_size, self._config.d_inner), + default_initialization=init_a(self._config.state_size, self._config.d_inner), weight_decay=False, lr_scale=self._lr_scale, peft=self._peft, diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 5e05364a4..999974ea3 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -22,7 +22,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.normalization.config import RMSNormalizationConfig -from fast_llm.layers.ssm.config import DTInitType, SSMBlockType +from fast_llm.layers.ssm.config import SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, @@ -207,17 +207,6 @@ def _create_config_converters(cls) -> list[ParamConverter]: ignore_missing=True, default_value=4, ), - MappedConfigParamConverter( - fast_llm_names=(("ssm", "dt_init"),), - export_names=( - ( - "ssm_cfg", - "dt_init", - ), - ), - fast_llm_value=lambda x: DTInitType.random if x == MISSING else DTInitType(x), - export_value=lambda x: x.value, - ), ] def _create_weight_converters(self) -> 
list[WeightConverter]: From 4dfe2a4f135a9020e8c1a0e673f8ae8e6bc4d169 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 8 Sep 2025 21:05:29 -0400 Subject: [PATCH 78/82] stuff --- Megatron-LM | 2 +- examples/mistral.yaml | 20 +- fast_llm/engine/base_model/base_model.py | 8 +- fast_llm/engine/base_model/config.py | 12 + fast_llm/engine/config_utils/run.py | 16 +- fast_llm/engine/evaluation/evaluator.py | 18 +- fast_llm/engine/multi_stage/multi_stage.py | 8 +- fast_llm/engine/multi_stage/stage.py | 13 + fast_llm/engine/schedule/schedule.py | 57 +++- fast_llm/engine/training/trainer.py | 12 +- fast_llm/functional/triton/rotary.py | 4 +- fast_llm/layers/attention/attention.py | 184 ++++++----- fast_llm/layers/attention/config.py | 48 ++- fast_llm/layers/attention/preprocessing.py | 2 +- fast_llm/layers/attention/rotary/config.py | 4 +- fast_llm/layers/attention/rotary/rotary.py | 48 +-- fast_llm/layers/block/block.py | 89 +++-- fast_llm/layers/block/config.py | 53 +-- fast_llm/layers/block/mlp/config.py | 48 +-- .../layers/block/mlp/mixture_of_experts.py | 95 +++--- fast_llm/layers/block/mlp/mlp.py | 60 ++-- fast_llm/layers/common/linear/convolution.py | 6 +- fast_llm/layers/common/linear/linear.py | 20 +- fast_llm/layers/language_model/config.py | 23 +- fast_llm/layers/language_model/embedding.py | 22 +- fast_llm/layers/language_model/head.py | 31 +- fast_llm/layers/ssm/config.py | 7 + fast_llm/layers/ssm/discrete_mamba2.py | 20 +- fast_llm/layers/ssm/mamba.py | 17 +- fast_llm/layers/ssm/mamba2.py | 45 ++- fast_llm/logging.py | 12 + fast_llm/models/gpt/conversion.py | 128 ++++++-- fast_llm/models/gpt/megatron.py | 20 +- fast_llm/models/gpt/model.py | 74 +---- fast_llm/models/ssm/model.py | 2 - tests/functional/test_functional.py | 22 +- tests/functional/test_triton_kernels.py | 22 +- tests/layers/test_lm_head.py | 12 +- tests/test_attention.py | 24 +- tests/test_config.py | 13 +- tests/utils/model_configs.py | 304 +++++++++++------- 41 files changed, 905 insertions(+), 720 deletions(-) diff --git a/Megatron-LM b/Megatron-LM index 89f391e30..30e7aeccd 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 89f391e300e10a5361f5ebf4c40ac7fa69c16562 +Subproject commit 30e7aeccd87ec22e424f35c6e61f05ceb878a8df diff --git a/examples/mistral.yaml b/examples/mistral.yaml index 6f4a60143..924bfba51 100644 --- a/examples/mistral.yaml +++ b/examples/mistral.yaml @@ -33,27 +33,31 @@ model: rotary: type: default theta: 10000 - num_attention_heads: 32 + heads: 32 head_groups: 8 - kv_channels: 128 + head_size: 128 + add_linear_biases: false window_size: 4096 - attention_dropout: 0.0 + dropout: 0.0 mlp: - ffn_hidden_size: 14336 + intermediate_size: 14336 + add_linear_biases: false gated: true - activation_type: silu + activation: silu normalization: type: rms_norm epsilon: 1.0e-05 num_layers: 32 hidden_size: 4096 - add_linear_biases: false - init_method_std: 0.009021 - hidden_dropout: 0.0 + dropout: 0.0 embeddings_layer: vocab_size: 32000 + dropout: 0.0 output_layer: tied_weight: false + normalization: + type: rms_norm + epsilon: 1.0e-05 multi_stage: zero_stage: 2 distributed: diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 832225803..9de5ac2cc 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -6,7 +6,7 @@ import torch.nn from fast_llm.config import Configurable -from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.base_model.config import 
BaseModelConfig, ResourceUsageConfig from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.tensor import ParameterMeta, TensorMeta @@ -43,6 +43,9 @@ def forward( ) -> torch.Tensor: pass + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + raise NotImplementedError() + class Sequential(Layer): def __init__(self, distributed_config: DistributedConfig): @@ -94,7 +97,8 @@ def __init__( distributed_config: DistributedConfig, ): super().__init__(config, distributed_config) - + for key, value in self.named_modules(): + value.module_name = key for key, value in self.named_parameters(): Assert.custom(isinstance, value, ParameterMeta) # Rename to the parameter full name diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 22abb021b..2b55d782e 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -63,3 +63,15 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: @abc.abstractmethod def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: pass + + +@dataclasses.dataclass +class ResourceUsageConfig: + # Disable to get usage for current GPU only + global_: bool = True + # Enable to get hardware compute, i.e. include redundant computations. + hardware: bool = False + # Number of backward passes. Typically 1, may be 2 with full activation recomputation. + forward: int = 1 + # Number of backward passes. Typically 1 for training, 0 for inference. + backward: int = 1 diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 7ab5b8e41..1fc0c626d 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -5,11 +5,11 @@ import yaml -from fast_llm.config import Config, Field, FieldHint, FieldVerboseLevel, config_class +from fast_llm.config import Config, Field, FieldHint, FieldVerboseLevel, check_field, config_class from fast_llm.engine.config_utils.logging import TensorLogs, TensorLogsConfig, configure_logging from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.utils import log, set_global_variables +from fast_llm.utils import Assert, log, set_global_variables if typing.TYPE_CHECKING: from fast_llm.engine.distributed.distributed import Distributed @@ -58,6 +58,12 @@ class RunConfig(Config): desc="Global switch to use triton kernels for linear layers. These may be slightly slower than the defaults.", hint=FieldHint.performance, ) + model_debug_level: int = Field( + default=0, + desc="Debugging level for the model, ex. 
for printing intermediate model states.", + hint=FieldHint.logging, + valid=check_field(Assert.geq, 0), + ) def _validate(self): if self.experiment_dir is None: @@ -204,15 +210,21 @@ def open_artifact(self, name: str, mode: str | None = "w", verbose=True) -> path return path if mode is None else path.open(mode) def __enter__(self): + from fast_llm.logging import set_model_debug_level + assert not self._is_running global _run _run = self TensorLogs.reset(self._config.tensor_logs) + set_model_debug_level(self._config.model_debug_level) def __exit__(self, exc_type, exc_val: OSError, exc_tb): + from fast_llm.logging import set_model_debug_level + assert self._is_running global _run self.save_logged_tensors("none") + set_model_debug_level(0) _run = None diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index 6b8f8db00..33e4d654f 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -1,6 +1,8 @@ import abc import dataclasses +import functools import logging +import math import time import typing @@ -203,12 +205,10 @@ def _evaluate_loss( ) end_time = time.perf_counter() time_per_iteration = (end_time - begin_time) / num_iters - model_tflops, hardware_tflops = self._multi_stage.get_tflops( - phase, - time_per_iteration, - self._batch_config.batch_size, - self._batch_config.sequence_length, - ) + + model_compute, hardware_compute = self._schedule.compute_usage + model_tflops = math.nan if model_compute is None else model_compute / time_per_iteration + hardware_tflops = math.nan if hardware_compute is None else hardware_compute / time_per_iteration # TODO add other relevant eval metrics metrics = { "batch_size": self._batch_config.batch_size, @@ -218,7 +218,7 @@ def _evaluate_loss( "hardware_tflops": hardware_tflops, "tokens_per_sec_per_gpu": ( (self._batch_config.sequence_length * self._batch_config.batch_size) - / self._schedule._distributed.world_size + / self._schedule._distributed_config.world_size / time_per_iteration ), **get_and_reset_memory_usage_mib(), @@ -240,6 +240,10 @@ def _get_data_iterator( prefetch_factor=prefetch_factor, ) + @functools.cached_property + def compute_usage(self) -> tuple[int | None, int | None]: + return self._schedule.get_compute_usage(hardware=False), self._schedule.get_compute_usage(hardware=True) + # NOTE: This is not a standalone runnable; it's a submodule of Trainer used for code encapsulation. 
class EvaluatorRunner: diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index d939bda2b..b38056adb 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -1,4 +1,3 @@ -import abc import dataclasses import logging import typing @@ -13,7 +12,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.run import log_main_rank, log_model_parallel_main_rank from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames, PhaseType +from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode from fast_llm.engine.multi_stage.fsdp import FSDP @@ -252,11 +251,6 @@ def setup(self, distributed: Distributed | None = None, mode: StageMode = StageM self.train(self._mode.support_backward) - @abc.abstractmethod - def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, sequence_length) -> tuple[int, int]: - # TODO: Do in model, automate/generalize, get other stats - pass - def _allocate_buffers( self, buffer_meta: TensorMeta, sizes: list[int], name: str ) -> tuple[tuple[torch.Tensor, ...], int]: diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 0bbc86f18..7829c243b 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -5,6 +5,7 @@ import torch from fast_llm.core.distributed import check_parallel_match +from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import StageConfig, StageMode @@ -81,6 +82,7 @@ def setup( # noqa def forward_meta(self, input_: TensorMeta, kwargs: dict) -> TensorMeta: # Store the meta inputs and outputs, for debugging only. + # TODO: Varies if there are multiple schedules. 
self._meta_inputs, self._meta_outputs = [], [] # TODO: use layer.forward_meta for layer in self._layers: @@ -93,6 +95,17 @@ def forward_meta(self, input_: TensorMeta, kwargs: dict) -> TensorMeta: self._meta_outputs.append(input_) return input_ + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + total = 0 + for layer in self._layers: + total += layer.get_compute_usage(input_, kwargs, config) + input_ = layer( + input_, + kwargs, + losses={}, + ) + return total + def forward( self, input_: torch.Tensor, diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index 91ce0d892..18ca44b78 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -1,5 +1,6 @@ import abc import dataclasses +import functools import logging import typing import warnings @@ -9,6 +10,7 @@ import torch.utils import torch.utils.data +from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.multi_stage.multi_stage import MultiStageModel from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig, StepType @@ -127,12 +129,12 @@ def __init__( self._multi_stage = multi_stage self._batch_config = batch_config self._schedule_config = schedule_config - self._distributed = distributed_config + self._distributed_config = distributed_config self._num_stages = len(self._multi_stage.stages) self._phase = phase self._is_training = self._phase.is_training - if self._batch_config.num_inputs < self._distributed.pipeline_parallel: + if self._batch_config.num_inputs < self._distributed_config.pipeline_parallel: warnings.warn("Not enough input to achieve true pipeline parallelism.") # Setup the activation metas. @@ -172,7 +174,7 @@ def iterate(self, pipeline_rank: int | None = None) -> typing.Iterator[Step]: return iter(self._steps if pipeline_rank is None else self._device_steps[pipeline_rank]) def __iter__(self) -> typing.Iterator[Step]: - return self.iterate(self._distributed.pipeline_rank) + return self.iterate(self._distributed_config.pipeline_rank) def __repr__(self) -> str: return "Schedule with steps:\n" + "\n".join( @@ -191,7 +193,7 @@ def get_step( return self._step_map[(type_, stage, data_index)] def _create_index(self) -> None: - self._device_steps: list[list[Step]] = [[] for _ in range(self._distributed.pipeline_parallel)] + self._device_steps: list[list[Step]] = [[] for _ in range(self._distributed_config.pipeline_parallel)] self._step_map = {} for i, step in enumerate(self._steps): Assert.in_range(step.stage, 0, self._num_stages) @@ -204,7 +206,7 @@ def _create_index(self) -> None: step.global_index = i # TODO: More configurable placement? 
- step.pipeline_rank = step.stage % self._distributed.pipeline_parallel + step.pipeline_rank = step.stage % self._distributed_config.pipeline_parallel step.local_index = len(self._device_steps[step.pipeline_rank]) self._device_steps[step.pipeline_rank].append(step) Assert.not_incl(map_index := step.map_index, self._step_map) @@ -272,7 +274,7 @@ def _create_index(self) -> None: def _setup_restore_steps(self, weight_buffer_indices: dict[int, int]) -> None: for rank, device_steps in enumerate(self._device_steps): - if rank != self._distributed.pipeline_rank: + if rank != self._distributed_config.pipeline_rank: # TODO: Make restore schedule for all ranks (need all buffer indices) continue buffer_contents, buffer_last_used = {}, {} @@ -292,7 +294,7 @@ def _setup_reduce_steps(self, grad_buffer_indices: dict[int, int]) -> None: if not self._is_training: return for rank, device_steps in enumerate(self._device_steps): - if rank != self._distributed.pipeline_rank: + if rank != self._distributed_config.pipeline_rank: # TODO: Make restore schedule for all ranks (need all buffer indices) continue buffer_last_steps = {} @@ -314,12 +316,12 @@ def _setup_reduce_steps(self, grad_buffer_indices: dict[int, int]) -> None: for stage, count in enumerate(reduction_count): assert (count > 0) == ( stage >= self._first_grad_stage - and (stage % self._distributed.pipeline_parallel == self._distributed.pipeline_rank) + and (stage % self._distributed_config.pipeline_parallel == self._distributed_config.pipeline_rank) ) def _setup_timeline(self) -> None: # TODO: Include network time - idx = [0] * self._distributed.pipeline_parallel + idx = [0] * self._distributed_config.pipeline_parallel done = False while not done: done = True @@ -380,11 +382,11 @@ def _setup_send_recv_steps(self) -> None: recv_step.recv_event = torch.cuda.Event() def _validate_send_recv_steps(self) -> None: - times = [0.0] * self._distributed.pipeline_parallel - idx = [0] * self._distributed.pipeline_parallel - recv_idx = [0] * self._distributed.pipeline_parallel - statuses = ["Ok"] * self._distributed.pipeline_parallel - recv_queues: list[list[Step | None]] = [[] for _ in range(self._distributed.pipeline_parallel)] + times = [0.0] * self._distributed_config.pipeline_parallel + idx = [0] * self._distributed_config.pipeline_parallel + recv_idx = [0] * self._distributed_config.pipeline_parallel + statuses = ["Ok"] * self._distributed_config.pipeline_parallel + recv_queues: list[list[Step | None]] = [[] for _ in range(self._distributed_config.pipeline_parallel)] done = False while not done: done = True @@ -519,3 +521,30 @@ def _create_steps(self) -> tuple[list[Step], int]: ) ) return steps, first_grad_stage + + def get_compute_usage( + self, + global_: bool = True, + hardware: bool = False, + ) -> int | None: + total = 0 + try: + for step in self._steps if global_ else self._device_steps[self._distributed_config.pipeline_rank]: + if step.type_ == StepType.forward: + total += self._multi_stage.stages[step.stage].get_compute_usage( + step.meta_input, + step.meta_kwargs, + ResourceUsageConfig( + global_=global_, + hardware=hardware, + forward=1, + backward=int(self._is_training), + ), + ) + return total + except NotImplementedError: + return None + + @functools.cached_property + def compute_usage(self) -> tuple[int | None, int | None]: + return self.get_compute_usage(True, False), self.get_compute_usage(True, True) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index e5bd5a583..b500a1fda 100644 --- 
a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -397,12 +397,14 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: remaining_time = average_time_per_iteration * ( self._config.training.train_iters - self._completed_steps ) - model_tflops, hardware_tflops = self._multi_stage.get_tflops( - PhaseType.training, - time_per_iteration, - self._config.batch.batch_size, - self._config.batch.sequence_length, + model_compute, hardware_compute = self._schedule[PhaseType.training][ + PhaseType.training.value.lower() + ].compute_usage + model_tflops = math.nan if model_compute is None else model_compute / time_per_iteration + hardware_tflops = ( + math.nan if hardware_compute is None else hardware_compute / time_per_iteration ) + metrics_key = PhaseType.training.value metrics[metrics_key] = { "train_iters": self._config.training.train_iters, diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index a05fe57d3..2cea4c6d0 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -69,8 +69,8 @@ def triton_rotary_( # TODO: Make a transposed version to avoid contiguous call in key backward. # TODO: Improve block size heuristics. assert input_.stride(-1) == 1, f"{input_.shape} {input_.stride()}" - batch_size, seq_len, num_heads, kv_channels = input_.shape - rotary_dim = div(kv_channels, 2) + batch_size, seq_len, num_heads, head_size = input_.shape + rotary_dim = div(head_size, 2) rotary_block_size = triton.next_power_of_2(rotary_dim) head_block_size = triton.cdiv(TritonConfig.POINTWISE_BLOCK_SIZE, rotary_block_size) if head_block_size > num_heads: diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 8740ae490..bbd70ede4 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -4,14 +4,16 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim +from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.config import BlockConfig, BlockDimNames +from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.tensor import TensorMeta from fast_llm.utils import div try: @@ -52,26 +54,21 @@ class Attention[ConfigType: AttentionConfig](BlockLayer[ConfigType]): A self-attention layer. 
""" + _config: ConfigType + def __init__( self, config: ConfigType, - block_config: BlockConfig, distributed_config: DistributedConfig, *, - # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, - block_index: int, - name: str, lr_scale: float | None, peft: PeftConfig | None, ): super().__init__( config, - block_config, distributed_config, hidden_dim=hidden_dim, - block_index=block_index, - name=name, lr_scale=lr_scale, peft=peft, ) @@ -86,32 +83,32 @@ def __init__( ) group_heads_dim = TensorDim( "group_heads", - div(self._config.num_attention_heads, self._config.head_groups), + div(self._config.heads, self._config.head_groups), None if self._config.head_groups > 1 else self._parallel_dim, ) self._local_head_groups = head_group_dim.size self._local_heads_per_group = group_heads_dim.size self._local_heads = self._local_head_groups * self._local_heads_per_group - kv_channels_dim = TensorDim("kv_channels", self._config.kv_channels) - query_dim = CompositeTensorDim("query", (head_group_dim, group_heads_dim, kv_channels_dim)) + head_size_dim = TensorDim("head_size", self._config.head_size) + query_dim = CompositeTensorDim("query", (head_group_dim, group_heads_dim, head_size_dim)) key_value_dim = ConcatenatedTensorDim( "key_value", ( - CompositeTensorDim("key", (head_group_dim, kv_channels_dim)), - CompositeTensorDim("value", (head_group_dim, kv_channels_dim)), + CompositeTensorDim("key", (head_group_dim, head_size_dim)), + CompositeTensorDim("value", (head_group_dim, head_size_dim)), ), ) - dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, kv_channels_dim)) + dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, head_size_dim)) - self._softmax_scale = self._config.kv_channels ** (-self._config.attention_softmax_scale_power) + self._softmax_scale = self._config.head_size ** (-self._config.softmax_scale_power) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = self._config.query_layer.get_layer( hidden_dim, query_dim, - default_weight_initialization=init_normal_(std=self._block_config.init_method_std), - default_add_bias=self._block_config.add_linear_biases, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=self._config.add_linear_biases, default_apply_peft=True, sequence_parallel=self._sequence_parallel, lr_scale=self._lr_scale, @@ -121,8 +118,8 @@ def __init__( self.key_value = self._config.key_layer.get_layer( hidden_dim, key_value_dim, - default_weight_initialization=init_normal_(std=self._block_config.init_method_std), - default_add_bias=self._block_config.add_linear_biases, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=self._config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=self._lr_scale, peft=None if self._config.key_layer.apply_peft is None else self._peft, @@ -137,39 +134,37 @@ def __init__( self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) # Rotary embeddings. - self._rotary = self._config.rotary.get_layer(kv_channels_dim) + self._rotary = self._config.rotary.get_layer(head_size_dim) # Output. 
self.dense = self._config.dense_layer.get_layer( dense_dim, hidden_dim, - default_weight_initialization=init_normal_( - std=self._block_config.init_method_std / max(2 * self._block_config.num_layers, 1) ** 0.5, - ), - default_add_bias=self._block_config.add_linear_biases, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=self._config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=self._lr_scale, peft=self._peft, ) - if self._debug.enabled: - self._query_dims = ( - BlockDimNames.batch, - BlockDimNames.sequence_q, - CompositeTensorDim("heads", (head_group_dim, group_heads_dim)), - kv_channels_dim, - ) - self._kv_dims = ( - BlockDimNames.batch, - BlockDimNames.sequence_q, - head_group_dim, - kv_channels_dim, - ) - self._context_dims = ( - BlockDimNames.batch, - BlockDimNames.sequence_q, - dense_dim, - ) + # Debug dims + self._query_dims = ( + BlockDimNames.batch, + BlockDimNames.sequence_q, + CompositeTensorDim("heads", (head_group_dim, group_heads_dim)), + head_size_dim, + ) + self._kv_dims = ( + BlockDimNames.batch, + BlockDimNames.sequence_q, + head_group_dim, + head_size_dim, + ) + self._context_dims = ( + BlockDimNames.batch, + BlockDimNames.sequence_q, + dense_dim, + ) def _attn_fused( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor @@ -179,17 +174,17 @@ def _attn_fused( sk = key.size(1) if self._local_head_groups == 1: - query = query.view(b, sq * self._local_heads, self._config.kv_channels) + query = query.view(b, sq * self._local_heads, self._config.head_size) key = key.transpose(-1, -2) else: query = ( - query.unflatten(-1, (self._local_head_groups, self._local_heads_per_group, self._config.kv_channels)) + query.unflatten(-1, (self._local_head_groups, self._local_heads_per_group, self._config.head_size)) .transpose(1, 2) - .reshape(b * self._local_head_groups, sq * self._local_heads_per_group, self._config.kv_channels) + .reshape(b * self._local_head_groups, sq * self._local_heads_per_group, self._config.head_size) ) - key = key.unflatten(-1, (self._local_head_groups, self._config.kv_channels)).movedim(1, 3).flatten(0, 1) + key = key.unflatten(-1, (self._local_head_groups, self._config.head_size)).movedim(1, 3).flatten(0, 1) value = ( - value.unflatten(-1, (self._local_head_groups, self._config.kv_channels)).transpose(1, 2).flatten(0, 1) + value.unflatten(-1, (self._local_head_groups, self._config.head_size)).transpose(1, 2).flatten(0, 1) ) attn_weights = torch.empty( @@ -200,15 +195,15 @@ def _attn_fused( query, key, beta=0, - alpha=self._softmax_scale / self._block_index, + alpha=self._softmax_scale, ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) - attn_weights = attn_weights.to(torch.float32) * self._block_index + attn_weights = attn_weights.to(torch.float32) attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) with set_generator(self._distributed.tp_generator): - attn_weights = torch.dropout(attn_weights, self._config.attention_dropout, self.training) + attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) attn_output = torch.bmm( attn_weights.view(b * self._local_head_groups, sq * self._local_heads_per_group, sk), value ) @@ -217,7 +212,7 @@ def _attn_fused( return attn_output.view(b, sq, -1) else: return ( - attn_output.view(b, self._local_head_groups, sq, self._local_heads_per_group, self._config.kv_channels) + 
attn_output.view(b, self._local_head_groups, sq, self._local_heads_per_group, self._config.head_size) .transpose(1, 2) .flatten(2) ) @@ -278,16 +273,6 @@ def _query_key_value_backward( input_grad.add_(self.key_value.backward(key_value_grad, context.pop("key_value"))) return input_grad - def _decide_window_size(self) -> int | None: - # NOTE: This is a temporal solution for qwen 2.X - # https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71 - # TODO: make universal per layer config - window_size = self._config.window_size - if self._config.max_window_layers is not None and self._block_index < self._config.max_window_layers: - window_size = None - - return window_size - def forward( self, input_: torch.Tensor, @@ -324,18 +309,18 @@ def forward( query = query.transpose(0, 1).contiguous() key_value = key_value.transpose(0, 1).contiguous() - key, value = key_value.split(self._local_head_groups * self._config.kv_channels, dim=-1) + key, value = key_value.split(self._local_head_groups * self._config.head_size, dim=-1) - query = query.view(*query.shape[:2], self._local_heads, self._config.kv_channels) - key = key.view(*key.shape[:2], self._local_head_groups, self._config.kv_channels) - value = value.view(*value.shape[:2], self._local_head_groups, self._config.kv_channels) + query = query.view(*query.shape[:2], self._local_heads, self._config.head_size) + key = key.view(*key.shape[:2], self._local_head_groups, self._config.head_size) + value = value.view(*value.shape[:2], self._local_head_groups, self._config.head_size) if self._debug.enabled: - self._debug(query, "query_rotary_input", self._QUERY_DIMS, kwargs) - self._debug(key, "key_rotary_input", self._KV_DIMS, kwargs) + self._debug(query, "query_rotary_input", self._query_dims, kwargs) + self._debug(key, "key_rotary_input", self._kv_dims, kwargs) query, key = self._rotary(query, key, kwargs) - window_size = self._decide_window_size() + window_size = (-1, -1) if self._config.window_size is None else (self._config.window_size - 1, 0) if self._use_flash_attention: assert _flash_available @@ -353,8 +338,8 @@ def forward( cu_seqlens_k=kwargs.get(AttentionKwargs.cu_seqlens_k), max_seqlen_q=kwargs.get(AttentionKwargs.max_seqlen_q), max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), - dropout_p=self._config.attention_dropout if self.training else 0.0, - window_size=(-1, -1) if window_size is None else (window_size - 1, 0), + dropout_p=self._config.dropout if self.training else 0.0, + window_size=window_size, causal=True, softmax_scale=self._softmax_scale, ).view(*out_dims) @@ -363,8 +348,8 @@ def forward( query, key, value, - window_size=(-1, -1) if window_size is None else (window_size - 1, 0), - dropout_p=self._config.attention_dropout if self.training else 0.0, + window_size=window_size, + dropout_p=self._config.dropout if self.training else 0.0, causal=True, softmax_scale=self._softmax_scale, ) @@ -389,3 +374,58 @@ def forward( # TODO: Optimize (is contiguous avoidable? Transpose dense output?) input_ = input_.transpose(0, 1).contiguous() return self.dense(input_) + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + batch_dim: TensorDim = kwargs[AttentionKwargs.hidden_dims][1 if kwargs[AttentionKwargs.sequence_first] else 0] + + # Using this one since `hidden_dims` may be sequence-tensor-parallel, and attention is not. 
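For reference, a minimal standalone sketch of the pair-counting idea behind the compute-usage accounting that follows. The names are illustrative rather than Fast-LLM API; only the forward matmul cost is counted (2 FLOPs per multiply-accumulate, two matmuls Q @ K^T and P @ V per attended pair and head channel), and the forward/backward multipliers used by `get_compute_usage` are omitted for brevity:

def attended_pairs(sequence_q: int, sequence_k: int, window_size: int | None = None) -> int:
    # Count the (query, key) positions a causal, optionally sliding-window attention touches.
    pairs = 0
    context = sequence_k - sequence_q  # pre-existing (cached) keys
    for q in range(sequence_q):
        last = context + q  # causal mask: keys 0..last are visible
        first = 0 if window_size is None else max(last - window_size + 1, 0)
        pairs += last - first + 1
    return pairs


def attention_forward_flops(
    sequence_q: int, sequence_k: int, heads: int, head_size: int, window_size: int | None = None
) -> int:
    return 4 * heads * head_size * attended_pairs(sequence_q, sequence_k, window_size)


# 1024 new tokens with no cache, 32 heads of size 128, with and without a 256-token window.
print(attention_forward_flops(1024, 1024, 32, 128))
print(attention_forward_flops(1024, 1024, 32, 128, window_size=256))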
+ sequence_q_dim: TensorDim = kwargs[AttentionKwargs.sequence_q_dim] + sequence_k_dim: TensorDim = kwargs[AttentionKwargs.sequence_k_dim] + + if config.global_: + batch_size, sequence_q = batch_dim.global_size, sequence_q_dim.global_size + # In case of sequence-data-parallel, we need to undo the shift in k-sequence-length. + sequence_k = sequence_k_dim.global_size - sequence_q_dim.size * ( + sequence_q_dim.parallel_dim.size - sequence_q_dim.parallel_dim.rank - 1 + ) + else: + batch_size, sequence_q = batch_dim.size, sequence_q_dim.size + sequence_k = sequence_k_dim.size + + # 2 for multiply and accumulate, 2 operations (Q * K, attn * V), double for backward + Q * K recomputation. + attn_compute_base = ( + 2 + * (2 * config.forward + (5 if config.hardware else 4) * config.backward) + * self._config.heads + * self._config.head_size + ) + + if self._config.window_size is not None: + # Remove the part of the past that lies completely outside the window, if applicable. + sequence_k -= max(sequence_k - sequence_q - self._config.window_size, 0) + + attention_compute = sequence_q * sequence_k * attn_compute_base + + if (not config.hardware) or self._use_flash_attention: + # Remove non-causal part. (TODO: Support non-causal) + attention_compute -= (sequence_q * (sequence_q - 1) * attn_compute_base) // 2 + + if self._config.window_size is not None: + # Remove the part of the past that lies completely outside the window, if applicable. + fully_out_of_window = max(sequence_k - sequence_q - self._config.window_size, 0) + attention_compute -= fully_out_of_window * sequence_q * attn_compute_base + # Remove the part of the past that lies partially outside the window, if applicable. + partly_out_of_window = max(sequence_k - fully_out_of_window - self._config.window_size, 0) + attention_compute -= (partly_out_of_window * (partly_out_of_window + 1) * attn_compute_base) // 2 + + dense_input = TensorMeta.from_dims((batch_dim, sequence_q_dim, self._context_dims[-1])) + + # TODO: Add marginal compute? (ex. 
softmax) + return sum( + ( + self.query.get_compute_usage(input_, config), + self.key_value.get_compute_usage(input_, config), + attention_compute, + self.dense.get_compute_usage(dense_input, config), + ) + ) diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 47aa9deea..868d6ba77 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -1,4 +1,3 @@ -import functools import logging import typing import warnings @@ -12,7 +11,7 @@ from fast_llm.layers.attention.rotary.config import RotaryConfig from fast_llm.layers.block.config import BlockKwargs, MixerConfig from fast_llm.layers.common.linear.config import AffineLinearConfig -from fast_llm.utils import Assert, div +from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.layers.attention.attention import Attention @@ -61,7 +60,7 @@ class AttentionConfig(MixerConfig): desc="Configuration for the rotary positional embeddings.", hint=FieldHint.architecture, ) - num_attention_heads: int = Field(default=8, desc="Number of attention heads.", hint=FieldHint.architecture) + heads: int = Field(default=8, desc="Number of attention heads.", hint=FieldHint.architecture) head_groups: int = Field( default=1, desc="Number of head group for grouped query attention.", @@ -69,13 +68,18 @@ class AttentionConfig(MixerConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - kv_channels: int = Field( - default=None, - desc="Number of key and value channels, i.e., hidden dimension of each attention head. Default: hidden_size // num_attention_heads", + head_size: int = Field( + default=128, + desc="Number of key and value channels, i.e., hidden dimension of each attention head.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - attention_dropout: float = Field( + add_linear_biases: bool = Field( + default=True, + desc="Add biases to linear layers. May be overridden for individual layers.", + hint=FieldHint.architecture, + ) + dropout: float = Field( default=0.0, desc="Dropout applied to the attention intermediate states.", hint=FieldHint.feature, @@ -91,33 +95,22 @@ class AttentionConfig(MixerConfig): hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - max_window_layers: int | None = Field( - default=None, - desc="The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.", - hint=FieldHint.optional, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - attention_softmax_scale_power: float = Field( + softmax_scale_power: float = Field( default=0.5, - desc="The scaling power to apply to kv_channel in the attention calculation. " + desc="The scaling power to apply to head_size in the attention calculation. " " Under Standard Parameterization (SP): default to 0.5. " - " Under muP (if scaling kv_channels size): use 1. " - " Under muP (if scaling number of heads instead of kv_channels): use 0.5.", + " Under muP (if scaling head_size size): use 1. 
" + " Under muP (if scaling number of heads instead of head_size): use 0.5.", valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - def set_defaults(self, hidden_size: int): - if self.kv_channels is None: - with self._set_implicit_default(): - self.kv_channels = div(hidden_size, self.num_attention_heads) - def _validate(self) -> None: super()._validate() if not TritonConfig.TRITON_ENABLED: warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") - Assert.multiple(self.num_attention_heads, self.head_groups) + Assert.multiple(self.heads, self.head_groups) @property def layer_class(self) -> "type[Attention]": @@ -125,18 +118,15 @@ def layer_class(self) -> "type[Attention]": return Attention - @functools.cached_property - def projection_size(self): - assert self._validated - return self.num_attention_heads * self.kv_channels - def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. # TODO: Find a better solution. - preprocessors: list[Preprocessor] = [self.rotary.get_layer(TensorDim("kv_channels", self.kv_channels))] + preprocessors: list[Preprocessor] = [ + self.rotary.get_layer(TensorDim("head_size", self.head_size)), + ] if self.do_use_flash_attention(distributed_config): from fast_llm.layers.attention.preprocessing import FlashAttnVarlenPreprocessor diff --git a/fast_llm/layers/attention/preprocessing.py b/fast_llm/layers/attention/preprocessing.py index 8bb923455..2326b1bf7 100644 --- a/fast_llm/layers/attention/preprocessing.py +++ b/fast_llm/layers/attention/preprocessing.py @@ -13,7 +13,7 @@ class BackupAttentionPreprocessor(Preprocessor): - _kv_channels_dim: TensorDim + _head_size_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor _mask: torch.Tensor _mask_value: torch.Tensor diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 4ebd6c5dc..43bae8c54 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -29,8 +29,8 @@ def _from_dict( return NoRotaryConfig._from_dict(default, strict, flat) return super()._from_dict(default, strict=strict, flat=flat) - def get_layer(self, kv_channels_dim: TensorDim) -> "Rotary": - return self._get_configurable_class()(self, kv_channels_dim) + def get_layer(self, head_size_dim: TensorDim) -> "Rotary": + return self._get_configurable_class()(self, head_size_dim) @classmethod @abc.abstractmethod diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index 53b24c9bb..889711839 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -20,12 +20,12 @@ from fast_llm.utils import div -def convert_rotary_complex_to_real(tensor: torch.Tensor, kv_channels: int, dim: int) -> torch.Tensor: - return tensor.unflatten(dim, (-1, div(kv_channels, 2), 2)).movedim(dim + 1, dim + 2).flatten(dim, dim + 2) +def convert_rotary_complex_to_real(tensor: torch.Tensor, head_size: int, dim: int) -> torch.Tensor: + return tensor.unflatten(dim, (-1, div(head_size, 2), 2)).movedim(dim + 1, dim + 2).flatten(dim, dim + 2) -def convert_rotary_real_to_complex(tensor: torch.Tensor, kv_channels: int, dim: int) -> torch.Tensor: - return 
tensor.unflatten(dim, (-1, 2, div(kv_channels, 2))).movedim(dim + 1, dim + 2).flatten(dim, dim + 2) +def convert_rotary_real_to_complex(tensor: torch.Tensor, head_size: int, dim: int) -> torch.Tensor: + return tensor.unflatten(dim, (-1, 2, div(head_size, 2))).movedim(dim + 1, dim + 2).flatten(dim, dim + 2) def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor) -> torch.Tensor: @@ -45,10 +45,10 @@ class Rotary[ConfigType: RotaryConfig](Configurable[ConfigType], torch.nn.Module def __init__( self, config: ConfigType, - kv_channels_dim: TensorDim, + head_size_dim: TensorDim, ): super().__init__(config) - self._kv_channels_dim = kv_channels_dim + self._head_size_dim = head_size_dim @abc.abstractmethod def forward( @@ -88,7 +88,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: scalar_dim, kwargs[AttentionKwargs.sequence_q_dim], scalar_dim, - self._kv_channels_dim, + self._head_size_dim, ), tensor_name=AttentionKwargs.rotary_freq_q, ) @@ -97,7 +97,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: scalar_dim, kwargs[AttentionKwargs.sequence_q_dim], scalar_dim, - self._kv_channels_dim, + self._head_size_dim, ), tensor_name=AttentionKwargs.rotary_freq_k, ) @@ -117,32 +117,32 @@ def _create_tensors(self, sequence_length: int, device: torch.device) -> None: self._rotary_embedding_frequencies = self._get_frequencies( sequence_length, - self._kv_channels_dim.global_size, + self._head_size_dim.global_size, device=device, ) - def _get_frequencies(self, sequence_length: int, kv_channels: int, device: torch.device) -> torch.Tensor: + def _get_frequencies(self, sequence_length: int, head_size: int, device: torch.device) -> torch.Tensor: # Calculate the complex frequencies (https://blog.eleuther.ai/rotary-embeddings/) # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, - # `a = theta ** - (2 * (channel // 2) / kv_channels)`, + # `a = theta ** - (2 * (channel // 2) / head_size)`, # where n is the position in the sequence. # We preform the calculation in high precision because it matters for rotary embeddings. 
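As a compact illustration of the rotary math described in the comments here (interleaved channel pairs treated as complex numbers; the actual Fast-LLM real/complex layouts and tensor shapes differ in detail):

import torch


def rotary_frequencies(sequence_length: int, head_size: int, theta: float = 10000.0) -> torch.Tensor:
    # Angle scale per channel pair: theta ** -(2 * (channel // 2) / head_size).
    scales = theta ** -torch.arange(0, 1, 2 / head_size, dtype=torch.float64)
    angles = torch.outer(torch.arange(sequence_length, dtype=torch.float64), scales)
    # exp(i * n * a) = cos(n * a) + i * sin(n * a), computed in high precision.
    return torch.polar(torch.ones_like(angles), angles).to(torch.complex64)  # (sequence, head_size // 2)


def apply_rotary(x: torch.Tensor, frequencies: torch.Tensor) -> torch.Tensor:
    # x: (sequence, head_size), adjacent channels forming (real, imaginary) pairs.
    paired = torch.view_as_complex(x.float().unflatten(-1, (-1, 2)))
    return torch.view_as_real(paired * frequencies).flatten(-2).type_as(x)


query = torch.randn(16, 64)
rotated_query = apply_rotary(query, rotary_frequencies(16, 64))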
positions = torch.arange(sequence_length, device=device, dtype=torch.float64) - angles = torch.outer(positions, self._get_angle_scales(kv_channels, device)) + angles = torch.outer(positions, self._get_angle_scales(head_size, device)) frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) if not self._config.complex_format: frequencies = convert_rotary_complex_to_real( - torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 + torch.view_as_real(frequencies).flatten(-2), head_size, 3 ).contiguous() return frequencies - def _get_angle_scales(self, kv_channels: int, device: torch.device) -> torch.Tensor: - return self._config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) + def _get_angle_scales(self, head_size: int, device: torch.device) -> torch.Tensor: + return self._config.theta ** -torch.arange(0, 1, 2 / head_size, device=device, dtype=torch.float64) class Llama3Rotary[ConfigType: Llama3RotaryConfig](DefaultRotary[ConfigType]): - def _get_angle_scales(self, kv_channels: int, device: torch.device) -> torch.Tensor: - scales = super()._get_angle_scales(kv_channels, device) + def _get_angle_scales(self, head_size: int, device: torch.device) -> torch.Tensor: + scales = super()._get_angle_scales(head_size, device) low_frequency_wavelength = self._config.original_context_length / self._config.low_frequency_factor high_frequency_wavelength = self._config.original_context_length / self._config.high_frequency_factor new_scales = [] @@ -167,21 +167,21 @@ class YarnRotary[ConfigType: YarnRotaryConfig](DefaultRotary[ConfigType]): [original paper](https://arxiv.org/abs/2309.00071) """ - def _get_frequencies(self, sequence_length: int, kv_channels: int, device: torch.device) -> torch.Tensor: - return super()._get_frequencies(sequence_length, kv_channels, device) * self._config.attention_factor + def _get_frequencies(self, sequence_length: int, head_size: int, device: torch.device) -> torch.Tensor: + return super()._get_frequencies(sequence_length, head_size, device) * self._config.attention_factor - def _get_angle_scales(self, kv_channels: int, device: torch.device) -> torch.Tensor: - scales = super()._get_angle_scales(kv_channels, device) + def _get_angle_scales(self, head_size: int, device: torch.device) -> torch.Tensor: + scales = super()._get_angle_scales(head_size, device) # TODO: max_position_embeddings or original_context_length? 
# see https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L304 - low = max(math.floor(self._get_correction(self._config.beta_fast, kv_channels)), 0) - high = min(math.ceil(self._get_correction(self._config.beta_slow, kv_channels)), kv_channels - 1) + low = max(math.floor(self._get_correction(self._config.beta_fast, head_size)), 0) + high = min(math.ceil(self._get_correction(self._config.beta_slow, head_size)), head_size - 1) if low == high: high += 0.001 # Prevent singularity # Get n-dimensional rotational scaling corrected for extrapolation extrapolation_factor = torch.clamp( - (torch.arange(kv_channels // 2, dtype=torch.float32, device=scales.device) - low) / (high - low), 0, 1 + (torch.arange(head_size // 2, dtype=torch.float32, device=scales.device) - low) / (high - low), 0, 1 ) return scales / self._config.scale_factor * extrapolation_factor + scales * (1 - extrapolation_factor) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 10acd67e0..5187ebfdc 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -8,13 +8,14 @@ from fast_llm.config import Config, Configurable from fast_llm.core.distributed import set_generator from fast_llm.engine.base_model.base_model import Layer, Module +from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage +from fast_llm.logging import get_model_debug_level, log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -22,10 +23,8 @@ class DebugLayer: # TODO: Move elsewhere? - def __init__(self, name: str, debug_level: int = 0, debug_memory: bool = False): - self._name = name - self._debug_level = debug_level - self._debug_memory = debug_memory + def __init__(self, module: torch.nn.Module): + self._module = module def _get_meta( self, tensor: torch.Tensor, name: str, dims: tuple[TensorDim | str, ...], kwargs: dict[str, typing.Any] @@ -47,8 +46,13 @@ def _get_meta( ) @functools.cached_property + def _name(self): + # Should be called after `module_name` is set in `BaseModel` + return getattr(self._module, "module_name", "unknown") + + @property def enabled(self) -> bool: - return self._debug_level > 0 or self._debug_memory + return get_model_debug_level() > 0 def __call__[ T @@ -62,14 +66,15 @@ def __call__[ global_: bool = True, log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info, ) -> None: - # TODO: Local vs global? 
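Returning to the YaRN angle-scale correction shown above, the blend can be written as a self-contained function; the `low` and `high` correction indices are taken as given here instead of being derived from `beta_fast` / `beta_slow`:

import torch


def yarn_blend(scales: torch.Tensor, scale_factor: float, low: float, high: float) -> torch.Tensor:
    # Ramp from 0 (channels below `low`, left untouched) to 1 (channels above `high`,
    # divided by `scale_factor`), mirroring the clamped linear interpolation above.
    channels = torch.arange(scales.numel(), dtype=torch.float32, device=scales.device)
    extrapolation_factor = torch.clamp((channels - low) / (high - low), 0, 1)
    return scales / scale_factor * extrapolation_factor + scales * (1 - extrapolation_factor)


base_scales = 10000.0 ** -torch.arange(0, 1, 2 / 128, dtype=torch.float64)
blended_scales = yarn_blend(base_scales, scale_factor=8.0, low=4.0, high=40.0)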
- if self._debug_memory: + if (level := get_model_debug_level()) == 0: + return + if level > 1: log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) - if self._debug_level > 0 and tensor is not None: + if tensor is not None: log_distributed_tensor( "", tensor, - level=self._debug_level, + level=level, meta=self._get_meta(tensor, name, dims, kwargs), global_=global_, log_fn=log_fn, @@ -79,7 +84,7 @@ def __call__[ log_distributed_grad( "", tensor, - level=self._debug_level, + level=level, meta=self._get_meta(tensor, name + " grad", dims, kwargs), global_=global_, log_fn=log_fn, @@ -95,30 +100,23 @@ class BlockLayerBase[ConfigType: Config](Configurable[ConfigType], Module): def __init__( self, config: ConfigType, - block_config: BlockConfig, distributed_config: DistributedConfig, *, - # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, - block_index: int, - name: str, lr_scale: float | None, peft: PeftConfig | None, ): super().__init__(config, distributed_config) - self._block_config = block_config self._hidden_dim = hidden_dim - self._block_index = block_index - self._name = name + self._hidden_size = self._hidden_dim.global_size self._sequence_parallel: bool = self._distributed_config.sequence_tensor_parallel - self._debug = DebugLayer( - self._name, - self._block_config.debug_transformer, - self._block_config.debug_transformer_memory, - ) + self._debug = DebugLayer(self) self._lr_scale = lr_scale self._peft = peft + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + raise NotImplementedError() + class BlockLayer[ConfigType: Config](BlockLayerBase[ConfigType]): """ @@ -141,28 +139,20 @@ class Block[ConfigType: BlockConfig](BlockLayerBase[ConfigType], Layer): A transformer-like decoder base block with abstract mixer. """ - # TODO: Standardize to `mixer` - _mixer_module_name: typing.ClassVar[str] = "mixer" - def __init__( self, config: ConfigType, distributed_config: DistributedConfig, *, hidden_dim: TensorDim, - block_index: int, - name: str, lr_scale: float | None, peft: PeftConfig | None, return_input: bool = False, ): super().__init__( - config, config, distributed_config, hidden_dim=hidden_dim, - block_index=block_index, - name=name, lr_scale=lr_scale, peft=peft, ) @@ -174,33 +164,23 @@ def __init__( self.norm_2 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. 
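The block refactor below keeps the standard pre-norm residual pattern (norm, mixer, dropout-add, norm, MLP, dropout-add). Reduced to a self-contained sketch with placeholder submodules rather than Fast-LLM classes:

import torch


class TinyBlock(torch.nn.Module):
    def __init__(self, hidden_size: int, dropout: float = 0.0):
        super().__init__()
        self.norm_1 = torch.nn.LayerNorm(hidden_size)
        self.norm_2 = torch.nn.LayerNorm(hidden_size)
        self.mixer = torch.nn.Linear(hidden_size, hidden_size)  # stand-in for attention or an SSM mixer
        self.mlp = torch.nn.Linear(hidden_size, hidden_size)  # stand-in for the MLP
        self._dropout = dropout

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        residual = input_
        hidden = self.mixer(self.norm_1(input_))
        residual = residual + torch.dropout(hidden, self._dropout, self.training)
        hidden = self.mlp(self.norm_2(residual))
        return residual + torch.dropout(hidden, self._dropout, self.training)


hidden_states = TinyBlock(32)(torch.randn(4, 32))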
- setattr( - self, - self._mixer_module_name, - self._config.mixer.get_layer( - self._config, - self._distributed_config, - self._hidden_dim, - self._block_index, - f"{self._name} mixer", - self._lr_scale, - peft=peft, - ), + self.mixer = self._config.mixer.get_layer( + self._distributed_config, + self._hidden_dim, + self._lr_scale, + peft=peft, ) self.mlp = self._config.mlp.get_layer( - self._config, self._distributed_config, self._hidden_dim, - self._block_index, - f"{self._name} MLP", self._lr_scale, peft=peft, ) def setup(self, distributed: Distributed) -> None: super().setup(distributed) - getattr(self, self._mixer_module_name).setup(distributed) + self.mixer.setup(distributed) self.mlp.setup(distributed) @torch.compile @@ -209,7 +189,7 @@ def _bias_dropout_add( ) -> torch.Tensor: if bias is not None: input_ = input_ + bias - return residual + torch.dropout(input_, self._config.hidden_dropout, self.training) + return residual + torch.dropout(input_, self._config.dropout, self.training) def forward( self, @@ -222,7 +202,7 @@ def forward( dims = kwargs[BlockKwargs.hidden_dims] if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims - return TensorMeta.from_dims(dims, tensor_name=f"{self._name} output", dtype=input_.dtype) + return TensorMeta.from_dims(dims, tensor_name=f"{self.module_name} output", dtype=input_.dtype) generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator if self._debug.enabled: self._debug(None, "begin", kwargs[BlockKwargs.hidden_dims], kwargs) @@ -230,7 +210,7 @@ def forward( hidden_states = self.norm_1(input_) if self._debug.enabled: self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) - hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) + hidden_states, bias = self.mixer(hidden_states, kwargs) if self._debug.enabled: self._debug( hidden_states if bias is None else hidden_states + bias, @@ -260,3 +240,12 @@ def forward( if self._return_input: hidden_states = torch.stack((fw_input, hidden_states), dim=0) return hidden_states + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + # TODO: Add marginal compute? (normalization, bias_dropout_add) + return sum( + ( + self.mixer.get_compute_usage(input_, kwargs, config), + self.mlp.get_compute_usage(input_, kwargs, config), + ) + ) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 7602dfabe..b4772d50e 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -66,26 +66,21 @@ def set_defaults(self, hidden_size: int): def get_layer( self, - block_config: "BlockConfig", distributed_config: DistributedConfig, hidden_dim: TensorDim, - block_index: int, - name: str, lr_scale: float | None, peft: PeftConfig | None, ) -> "BlockLayer": return self.layer_class( self, - block_config, distributed_config, hidden_dim=hidden_dim, - block_index=block_index, - name=name, lr_scale=combine_lr_scales(lr_scale, self.lr_scale), peft=peft, ) def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: + # TODO: Move to actual layers? 
return [] @@ -130,14 +125,13 @@ def _from_dict( @config_class() -# TODO: Use composition instead class BlockConfig(BaseModelConfig): _abstract = False mixer: MixerConfig = Field() mlp: MLPBaseConfig = Field() # TODO: Review names normalization: NormalizationConfig = Field( - desc="Configuration for the normalization layers architecture.", + desc="Configuration for the block normalization layers.", hint=FieldHint.architecture, ) lr_scale: float | None = Field( @@ -147,33 +141,12 @@ class BlockConfig(BaseModelConfig): hint=FieldHint.feature, ) # TODO: Review names - hidden_dropout: float = Field( + dropout: float = Field( default=0.0, desc="Dropout applied to the residual connections.", hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - full_precision_residual: bool = Field( - default=False, - desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", - hint=FieldHint.stability, - ) - debug_transformer: int = Field( - default=0, - desc="Log the output of each operation in a transformer layer.", - hint=FieldHint.logging, - valid=check_field(Assert.geq, 0), - ) - debug_transformer_memory: bool = Field( - default=False, - desc="Log the memory usage after each operation in a transformer layer..", - hint=FieldHint.logging, - ) - add_linear_biases: bool = Field( - default=True, - desc="Add biases to linear layers. May be overridden for individual layers.", - hint=FieldHint.architecture, - ) # TODO: Move these, not specific to a single block. num_layers: int = Field( default=12, @@ -187,21 +160,8 @@ class BlockConfig(BaseModelConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - # TODO: Review initialization - init_method_std: float = Field( - default=None, - desc="Default scale for weight initialization. Default: hidden_size**-0.5", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) def _validate(self) -> None: - with self._set_implicit_default(): - # Kept here for initialization order. - # TODO: Review initialization - if self.init_method_std is None: - self.init_method_std = self.hidden_size**-0.5 - self.mixer.set_defaults(self.hidden_size) self.mlp.set_defaults(self.hidden_size) @@ -211,8 +171,6 @@ def get_layer( self, distributed_config: DistributedConfig, hidden_dim: TensorDim, - block_index: int, - name: str, lr_scale: float | None, peft: PeftConfig | None = None, return_input: bool = False, @@ -223,12 +181,11 @@ def get_layer( self, distributed_config, hidden_dim=hidden_dim, - block_index=block_index, - name=name, lr_scale=combine_lr_scales(lr_scale, self.lr_scale), peft=peft, return_input=return_input, ) def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - return self.mixer.get_preprocessors(distributed_config) + # TODO: Move to actual layers? + return self.mixer.get_preprocessors(distributed_config) + self.mlp.get_preprocessors(distributed_config) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 3e7d96736..3d8a9c2bf 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -38,34 +38,34 @@ class MLPConfig(MLPBaseConfig): desc="Configuration for the second MLP layer.", hint=FieldHint.architecture, ) - ffn_hidden_size: int = Field( - default=None, - desc="Hidden dimension of the MLP intermediate state. 
Default: 4 * hidden_size.", + intermediate_size: int = Field( + default=4096, + desc="Hidden dimension of the MLP intermediate state.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + add_linear_biases: bool = Field( + default=True, + desc="Add biases to linear layers. May be overridden for individual layers.", + hint=FieldHint.architecture, + ) gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.architecture) - activation_type: ActivationType = Field( + activation: ActivationType = Field( default=None, desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", hint=FieldHint.core, ) # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto - mlp_recompute_level: MLPRecomputeLevel = Field( + recompute_level: MLPRecomputeLevel = Field( default=MLPRecomputeLevel.none, desc="Set which of the MLP intermediate activations will be recomputed during the backward passes. This provides a trade-off between memory and speed.", hint=FieldHint.performance, ) - def set_defaults(self, hidden_size: int): - if self.ffn_hidden_size is None: - with self._set_implicit_default(): - self.ffn_hidden_size = 4 * hidden_size - def _validate(self) -> None: with self._set_implicit_default(): - if self.activation_type is None: - self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + if self.activation is None: + self.activation = ActivationType.silu if self.gated else ActivationType.gelu super()._validate() @@ -86,48 +86,48 @@ class MoEMLPConfig(MLPConfig): desc="Configuration for the MoE router.", hint=FieldHint.feature, ) - num_experts: int = Field( + experts: int = Field( default=1, desc="Number of MLP experts in a Mixture of Expert (MoE) model", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - num_shared_experts: int = Field( + shared_experts: int = Field( default=0, desc="Number of MLP experts that are shared between all tokens, i.e., always enabled.", hint=FieldHint.architecture, valid=check_field(Assert.geq, 0), ) - num_experts_per_token: int = Field( + experts_per_token: int = Field( default=1, desc="Active experts for each token in a MoE model.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - expert_routing_type: RoutingType = Field( + routing: RoutingType = Field( default=RoutingType.topk, desc="The routing method, i.e., the method used to assign experts to tokens.", hint=FieldHint.architecture, ) - expert_auxiliary_loss_coefficient: float = Field( + auxiliary_loss_coefficient: float = Field( default=0.01, desc="Scale of the load balancing auxiliary loss for topk routing.", hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - expert_z_loss_coefficient: float = Field( + z_loss_coefficient: float = Field( default=0.0, desc="Regularize the router during training by applying Z-loss to the logits.", hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - moe_jitter_eps: float = Field( + jitter_eps: float = Field( default=0.0, desc="Regularize the router during training by applying a random multiplicative noise `uniform(1-eps, 1+eps)` to the logits.", hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - dropless_moe: bool = Field( + dropless: bool = Field( default=True, desc="Evaluate all the experts at once using dropless MoE.", hint=FieldHint.expert ) dropless_dynamic_shape: bool = Field( @@ -144,10 +144,10 @@ def layer_class(self) -> "type[MixtureOfExpertMLP]": return MixtureOfExpertMLP @functools.cached_property - 
def num_unshared_experts(self) -> int: - return self.num_experts - self.num_shared_experts + def unshared_experts(self) -> int: + return self.experts - self.shared_experts def _validate(self) -> None: super()._validate() - Assert.leq(self.num_shared_experts, self.num_experts) - Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) + Assert.leq(self.shared_experts, self.experts) + Assert.leq(self.shared_experts + self.experts_per_token, self.experts) diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index d0d94d88c..9478dc51c 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -1,19 +1,23 @@ import logging +import typing import warnings import torch from fast_llm.core.distributed import ProcessGroup, set_generator +from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.block.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -36,37 +40,31 @@ class MixtureOfExpertMLP[ConfigType: MoEMLPConfig](MLPBase[ConfigType]): def __init__( self, config: ConfigType, - block_config: BlockConfig, distributed_config: DistributedConfig, *, # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, - block_index: int, - name: str, lr_scale: float | None, peft: PeftConfig | None, ): - Assert.gt(config.num_experts, 1) + Assert.gt(config.experts, 1) # TODO: Implement? - assert not block_config.add_linear_biases, "Biases not supported for MoE." + assert not config.add_linear_biases, "Biases not supported for MoE." super().__init__( config, - block_config, distributed_config, hidden_dim=hidden_dim, - block_index=block_index, - name=name, lr_scale=lr_scale, peft=peft, ) self.router = self._config.router.get_layer( self._hidden_dim, - TensorDim("router_experts", self._config.num_unshared_experts), - default_weight_initialization=init_normal_(std=self._block_config.init_method_std), + TensorDim("router_experts", self._config.unshared_experts), + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), lr_scale=self._lr_scale, peft=self._peft, ) - dropless_moe = self._config.dropless_moe + dropless_moe = self._config.dropless if dropless_moe and self._sequence_parallel: warnings.warn( "Dropless MoE not supported for sequence-tensor-parallel, falling back to looped implementation." 
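The router is regularized in two optional ways in `forward` below: a z-loss on the logits and a multiplicative input jitter. A hedged standalone sketch; the z-loss here follows the usual squared-logsumexp definition, and the in-repo `z_loss` / `AuxiliaryLoss` helpers may differ in scaling and gradient handling:

import torch


def router_z_loss(logits: torch.Tensor, coefficient: float) -> torch.Tensor:
    # Penalize large router logits: coefficient * mean(logsumexp(logits)^2).
    return coefficient * torch.logsumexp(logits, dim=-1).square().mean()


def apply_input_jitter(logits: torch.Tensor, eps: float) -> torch.Tensor:
    # Multiplicative noise uniform(1 - eps, 1 + eps) on the router logits.
    return logits * torch.empty_like(logits).uniform_(1.0 - eps, 1.0 + eps)


router_logits = torch.randn(8, 16)  # (tokens, experts)
loss = router_z_loss(router_logits, coefficient=1e-3)
noisy_logits = apply_input_jitter(router_logits, eps=0.01)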
@@ -75,11 +73,11 @@ def __init__( self._mlp_forward = self._forward_dropless if dropless_moe else self._forward_looped if self._debug.enabled: - self._top_expert_dim = TensorDim("top_experts", self._config.num_experts_per_token) + self._top_expert_dim = TensorDim("top_experts", self._config.experts_per_token) def _get_intermediate_dims(self) -> tuple[TensorDim, TensorDim]: intermediate_1_dim, intermediate_2_dim = super()._get_intermediate_dims() - experts_dim = TensorDim("experts", self._config.num_experts) + experts_dim = TensorDim("experts", self._config.experts) return ( CompositeTensorDim("moe_intermediate_1", (experts_dim, intermediate_1_dim)), CompositeTensorDim("moe_intermediate_2", (experts_dim, intermediate_2_dim)), @@ -96,10 +94,10 @@ def forward( ) # Apply z_loss if applicable - if self._config.expert_z_loss_coefficient > 0.0: + if self._config.z_loss_coefficient > 0.0: logits = z_loss( logits, - self._config.expert_z_loss_coefficient, + self._config.z_loss_coefficient, self.training, grad_scale=kwargs.get("grad_output"), losses=losses, @@ -107,19 +105,19 @@ def forward( ) # Apply input_jitter if applicable: - if self.training and self._config.moe_jitter_eps > 0.0: + if self.training and self._config.jitter_eps > 0.0: with set_generator(self._distributed.pp_generator): logits = self._apply_input_jitter(logits) # Routing - if self._config.expert_routing_type == RoutingType.topk: + if self._config.routing == RoutingType.topk: scores, top_experts = self._topk_routing(logits, kwargs.get(BlockKwargs.grad_output), losses) - if self._config.num_shared_experts > 0: + if self._config.shared_experts > 0: scores, top_experts = self._add_shared_experts(top_experts, scores) - elif self._config.expert_routing_type == RoutingType.sinkhorn: + elif self._config.routing == RoutingType.sinkhorn: scores, top_experts = self._sinkhorn_routing(logits) else: - raise NotImplementedError(self._config.expert_routing_type) + raise NotImplementedError(self._config.routing) if self._debug.enabled: # To log all ranks set `global_=False` @@ -140,7 +138,7 @@ def _forward_dropless( ) -> torch.Tensor: # Compute token counts and the sparse mapping (dense_row, top_index) -> sparse_row. 
sparse_map = get_sparse_map( - top_experts, self._config.num_experts, dynamic_shape=self._config.dropless_dynamic_shape + top_experts, self._config.experts, dynamic_shape=self._config.dropless_dynamic_shape ) # Sparse MLP @@ -152,11 +150,11 @@ def _forward_dropless( self.layer_2.weight, None, gated=self._config.gated, - activation_type=self._config.activation_type, + activation_type=self._config.activation, group=self._parallel_dim.group, sequence_parallel=self._sequence_parallel, training=self.training, - recompute_level=self._config.mlp_recompute_level, + recompute_level=self._config.recompute_level, transposed_layer_2_weight=True, sparse_map=sparse_map, ) @@ -170,20 +168,18 @@ def _forward_looped( top_experts, self.layer_1.weight, self.layer_2.weight, - self._config.num_experts, + self._config.experts, self._config.gated, - self._config.activation_type, + self._config.activation, self._parallel_dim.group, self._sequence_parallel, self.training, - self._config.mlp_recompute_level, + self._config.recompute_level, ) @torch.compile def _apply_input_jitter(self, logits: torch.Tensor) -> torch.Tensor: - return logits * torch.empty_like(logits).uniform_( - 1.0 - self._config.moe_jitter_eps, 1.0 + self._config.moe_jitter_eps - ) + return logits * torch.empty_like(logits).uniform_(1.0 - self._config.jitter_eps, 1.0 + self._config.jitter_eps) def _topk_routing( self, @@ -191,11 +187,11 @@ def _topk_routing( grad_scale: float | None = None, losses: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - top_logits, top_experts = torch.topk(logits, k=self._config.num_experts_per_token, dim=-1) + top_logits, top_experts = torch.topk(logits, k=self._config.experts_per_token, dim=-1) scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32) if losses is not None or (self.training and grad_scale is not None): probs = torch.softmax(logits, dim=-1, dtype=torch.float32) - mask = torch.nn.functional.one_hot(top_experts, num_classes=self._config.num_unshared_experts).sum(dim=1) + mask = torch.nn.functional.one_hot(top_experts, num_classes=self._config.unshared_experts).sum(dim=1) # Auxiliary loss, corresponding to the sum of probabilities for the top experts. # In the optimal case (uniform distribution), loss = experts_per_token / num_experts. # In the worst case (whole distribution in the top experts), loss = 1. @@ -208,7 +204,7 @@ def _topk_routing( scores = AuxiliaryLoss.apply( scores, aux_loss, - self._config.num_unshared_experts * self._config.expert_auxiliary_loss_coefficient * grad_scale, + self._config.unshared_experts * self._config.auxiliary_loss_coefficient * grad_scale, ) return scores, top_experts @@ -217,33 +213,54 @@ def _add_shared_experts( ) -> tuple[torch.Tensor, torch.Tensor]: # Add the shared experts (last ones) to the top experts. shared_experts = torch.arange( - self._config.num_unshared_experts, - self._config.num_experts, + self._config.unshared_experts, + self._config.experts, device=top_experts.device, dtype=top_experts.dtype, )[None].repeat(top_experts.size(0), 1) top_experts = torch.cat((shared_experts, top_experts), dim=1) # Add scores of 1 to scores for shared experts. 
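The top-k routing and load-balancing auxiliary loss above can be condensed into a standalone sketch. The auxiliary loss follows the documented behaviour (sum of full-softmax probabilities landing on the selected experts: experts_per_token / experts under a uniform distribution, 1 in the worst case); the `AuxiliaryLoss.apply` gradient plumbing and shared-expert handling are left out:

import torch


def topk_routing(logits: torch.Tensor, experts_per_token: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    top_logits, top_experts = torch.topk(logits, k=experts_per_token, dim=-1)
    scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32)
    probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
    # One-hot mask of the selected experts, summed over the k routing slots.
    mask = torch.nn.functional.one_hot(top_experts, num_classes=logits.size(-1)).sum(dim=1)
    aux_loss = (probs * mask).sum(dim=-1).mean()
    return scores, top_experts, aux_loss


router_logits = torch.randn(8, 16)  # (tokens, experts)
scores, top_experts, aux_loss = topk_routing(router_logits, experts_per_token=2)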
- scores = torch.cat((scores.new_ones(scores.size(0), self._config.num_shared_experts), scores), dim=1) + scores = torch.cat((scores.new_ones(scores.size(0), self._config.shared_experts), scores), dim=1) return scores, top_experts def _sinkhorn_routing(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if self.training: - _, top_experts = torch.topk(sinkhorn(logits), k=self._config.num_experts_per_token, dim=-1) + _, top_experts = torch.topk(sinkhorn(logits), k=self._config.experts_per_token, dim=-1) logits = self._sinkhorn_activation(logits) scores = torch.gather(logits, -1, top_experts) else: logits = self._sinkhorn_activation(logits) - scores, top_experts = torch.topk(logits, k=self._config.num_experts_per_token, dim=-1) + scores, top_experts = torch.topk(logits, k=self._config.experts_per_token, dim=-1) return scores, top_experts def _sinkhorn_activation(self, logits: torch.Tensor) -> torch.Tensor: return ( torch.sigmoid(logits) - if self._config.num_experts_per_token == 1 + if self._config.experts_per_token == 1 else torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) ) + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + if kwargs[AttentionKwargs.sequence_first]: + sequence_dim, batch_dim, hidden_dim = input_.dims + else: + batch_dim, sequence_dim, hidden_dim = input_.dims + + # Applying the tokens per expert on the batch dim so the super() call works as intended. + moe_batch_dim = TensorDim( + f"moe_{batch_dim.name}", batch_dim.global_size * self._config.experts_per_token, batch_dim.parallel_dim + ) + + if kwargs[AttentionKwargs.sequence_first]: + dims = sequence_dim, moe_batch_dim, hidden_dim + else: + dims = moe_batch_dim, sequence_dim, hidden_dim + + # Also adjust the dtype in case of full-precision residual + moe_input = TensorMeta.from_dims(dims, tensor_name=f"moe_{input_.tensor_name}", dtype=input_.dtype) + + return super().get_compute_usage(moe_input, kwargs, config) + self.router.get_compute_usage(input_, config) + def sinkhorn(cost: torch.Tensor, tolerance: float = 1e-5, eps=1e-9) -> torch.Tensor: """Sinkhorn based MoE routing function""" diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 8b6ede2d8..c88f766b0 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -1,16 +1,19 @@ +import dataclasses import typing import torch +from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd +from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.tensor import TensorMeta class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): @@ -19,28 +22,22 @@ class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): def __init__( self, config: ConfigType, - block_config: BlockConfig, distributed_config: DistributedConfig, *, # TODO: Review `hidden_dim` and `block_index` hidden_dim: 
TensorDim, - block_index: int, - name: str, lr_scale: float | None, peft: PeftConfig | None, ): super().__init__( config, - block_config, distributed_config, hidden_dim=hidden_dim, - block_index=block_index, - name=name, lr_scale=lr_scale, peft=peft, ) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) - intermediate_1_dim, intermediate_2_dim = self._get_intermediate_dims() + intermediate_1_dim, self._intermediate_2_dim = self._get_intermediate_dims() self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation @@ -48,19 +45,17 @@ def __init__( self.layer_1 = self._config.layer_1.get_layer( hidden_dim, intermediate_1_dim, - default_weight_initialization=init_normal_(std=self._block_config.init_method_std), - default_add_bias=self._block_config.add_linear_biases, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=self._config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=self._lr_scale, peft=self._peft, ) - self.layer_2 = self._config.layer_1.get_layer( - intermediate_2_dim, + self.layer_2 = self._config.layer_2.get_layer( + self._intermediate_2_dim, hidden_dim, - default_weight_initialization=init_normal_( - std=self._block_config.init_method_std / max(2 * self._block_config.num_layers, 1) ** 0.5 - ), - default_add_bias=self._block_config.add_linear_biases, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=self._config.add_linear_biases, sequence_parallel=self._sequence_parallel, transposed_weight=True, lr_scale=self._lr_scale, @@ -68,7 +63,7 @@ def __init__( ) def _get_intermediate_dims(self): - intermediate_2_dim = TensorDim("intermediate", self._config.ffn_hidden_size, self._parallel_dim) + intermediate_2_dim = TensorDim("intermediate", self._config.intermediate_size, self._parallel_dim) if self._config.gated: TensorDim("gate_and_up", 2) intermediate_1_dim = ConcatenatedTensorDim("gate_and_up", (intermediate_2_dim, intermediate_2_dim)) @@ -76,6 +71,33 @@ def _get_intermediate_dims(self): intermediate_1_dim = intermediate_2_dim return intermediate_1_dim, intermediate_2_dim + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + # TODO: Generalize? + layer_1_config = ( + dataclasses.replace(config, forward=config.forward + config.backward) + if config.hardware and self._config.recompute_level.recompute_layer_1 + else config + ) + + # Get the layer 2 input dims, accounting for ordering and possible sequence-parallelism. + # TODO: Don't rely on kwargs dimensions. + if kwargs[AttentionKwargs.sequence_first]: + dims = (kwargs[AttentionKwargs.sequence_q_dim], input_.dims[1], self._intermediate_2_dim) + else: + dims = (input_.dims[0], kwargs[AttentionKwargs.sequence_q_dim], self._intermediate_2_dim) + # Also adjust the dtype in case of full-precision residual + layer_2_input = TensorMeta.from_dims( + dims, tensor_name="intermediate_1", dtype=self._distributed_config.training_dtype.torch + ) + + # TODO: Add marginal compute? (ex. 
activation, gate + up) + return sum( + ( + self.layer_1.get_compute_usage(input_, layer_1_config), + self.layer_2.get_compute_usage(layer_2_input, config), + ) + ) + class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]): _config: MLPConfig @@ -96,11 +118,11 @@ def forward( self.layer_2.weight, None if self._parallel_dim.group else self.layer_2.bias, gated=self._config.gated, - activation_type=self._config.activation_type, + activation_type=self._config.activation, group=self._parallel_dim.group, sequence_parallel=self._sequence_parallel, training=self.training, - recompute_level=self._config.mlp_recompute_level, + recompute_level=self._config.recompute_level, transposed_layer_2_weight=self.layer_2.transposed_weight, ), self.layer_2.bias if self._parallel_dim.group else None, diff --git a/fast_llm/layers/common/linear/convolution.py b/fast_llm/layers/common/linear/convolution.py index 57fdccfd5..b88b7b2e6 100644 --- a/fast_llm/layers/common/linear/convolution.py +++ b/fast_llm/layers/common/linear/convolution.py @@ -1,7 +1,8 @@ import torch +from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.functional.config import ActivationType -from fast_llm.tensor import ParameterMeta +from fast_llm.tensor import ParameterMeta, TensorMeta try: from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn # noqa @@ -51,3 +52,6 @@ def _forward_causal_conv1d(self, input_: torch.Tensor) -> torch.Tensor: self.bias, activation=(None if self._activation == ActivationType.identity else self._activation.value), ) + + def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int: + raise NotImplementedError() diff --git a/fast_llm/layers/common/linear/linear.py b/fast_llm/layers/common/linear/linear.py index 631193249..3028fd1e9 100644 --- a/fast_llm/layers/common/linear/linear.py +++ b/fast_llm/layers/common/linear/linear.py @@ -3,6 +3,7 @@ import torch +from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.distributed.config import DistributedDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.linear import ( @@ -14,7 +15,8 @@ output_parallel_linear_backward, output_parallel_linear_forward, ) -from fast_llm.tensor import ParameterMeta +from fast_llm.tensor import ParameterMeta, TensorMeta +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -33,6 +35,9 @@ def forward_only(self, input_: torch.Tensor): def backward(self, grad_output: torch.Tensor, context: typing.Any) -> torch.Tensor: raise NotImplementedError() + def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int: + raise NotImplementedError() + class LinearBase(LinearLike): """ @@ -50,11 +55,24 @@ def __init__( self.weight = weight self.bias = bias self._transposed_weight = transposed_weight + if self._transposed_weight: + self._input_dim, self._output_dim = self.weight.dims + else: + self._output_dim, self._input_dim = self.weight.dims @property def transposed_weight(self) -> bool: return self._transposed_weight + def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int: + Assert.eq(input_.size(-1), self._input_dim.size) + return ( + 2 + * (config.forward + 2 * config.backward) + * (input_.global_shape if config.global_ else input_).numel() + * (self._output_dim.global_size if config.global_ else self._output_dim.size) + ) + class Linear(LinearBase): """ diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 
79772bf82..c15515fb5 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -7,6 +7,7 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockLayerConfig +from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert @@ -63,6 +64,18 @@ class LanguageModelEmbeddingsConfig(BlockLayerConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + dropout: float = Field( + default=0.0, + desc="Dropout applied to the embedding layer.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + full_precision_residual: bool = Field( + default=False, + desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", + hint=FieldHint.stability, + ) + # Tensor-parallel word embeddings # (Default init std is different, dropout won't match, needs seq_first = False.) # (disable to allow for sequence-parallel embeddings and logits, better for larger models) @@ -93,6 +106,10 @@ def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Prepr @config_class() class LanguageModelHeadConfig(BlockLayerConfig): _abstract = False + normalization: NormalizationConfig = Field( + desc="Configuration for the final normalization layer.", + hint=FieldHint.architecture, + ) # TODO: Cleanup output_weight: ParameterConfig = Field( desc="Configuration for the LM output layer (weight). Ignored for tied embeddings", @@ -219,25 +236,19 @@ def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Prepr def get_layer( self, - block_config: "BlockConfig", distributed_config: DistributedConfig, embeddings_config: LanguageModelEmbeddingsConfig, *, hidden_dim: TensorDim, - block_index: int, - name: str, lr_scale: float | None, peft: PeftConfig | None, prediction_distance: int = 0, ): return self.layer_class( self, - block_config, distributed_config, embeddings_config, hidden_dim=hidden_dim, - block_index=block_index, - name=name, lr_scale=combine_lr_scales(lr_scale, self.lr_scale), peft=peft, prediction_distance=prediction_distance, diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 98904c5e5..b7a780a33 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -5,11 +5,11 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.layers.block.block import BlockLayerBase -from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs from fast_llm.tensor import TensorMeta @@ -32,30 +32,22 @@ class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](BlockLay def __init__( self, config: ConfigType, - # TODO: Doesn't make much sense. 
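For the FLOP estimates above (`LinearBase.get_compute_usage`) and in the language-model head further down, the convention is 2 FLOPs per multiply-accumulate, with one matmul in the forward pass and two in the backward pass. Written out as a plain function with illustrative names:

def dense_flops(input_elements: int, output_features: int, forward_passes: int = 1, backward_passes: int = 1) -> int:
    # input_elements = batch * sequence * input_features.
    return 2 * (forward_passes + 2 * backward_passes) * input_elements * output_features


# A 4096 -> 11008 projection over a (batch=8, sequence=2048) input during training.
flops = dense_flops(8 * 2048 * 4096, 11008)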
- block_config: BlockConfig, distributed_config: DistributedConfig, *, - # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, - block_index: int, - name: str, lr_scale: float | None, peft: PeftConfig | None, ): super().__init__( config, - block_config, distributed_config, hidden_dim=hidden_dim, - block_index=block_index, - name=name, lr_scale=lr_scale, peft=peft, ) self._residual_dtype = ( self._distributed_config.optimization_dtype - if self._block_config.full_precision_residual + if self._config.full_precision_residual else self._distributed_config.training_dtype ).torch self._sequence_parallel = self._distributed_config.sequence_tensor_parallel @@ -69,13 +61,13 @@ def __init__( self.word_embeddings_weight = self._config.word_embeddings.get_parameter( (vocab_dim, self._hidden_dim), - default_initialization=init_normal_(std=self._block_config.init_method_std), + default_initialization=init_normal_(std=self._hidden_size**-0.5), lr_scale=self._lr_scale, peft=self._peft, ) self.position_embeddings_weight = self._config.position_embeddings.get_parameter( (TensorDim("position_embeddings", self._config.num_position_embeddings), self._hidden_dim), - default_initialization=init_normal_(std=self._block_config.init_method_std), + default_initialization=init_normal_(std=self._hidden_size**-0.5), allow_sequence_tensor_parallel=not self._vocab_parallel, lr_scale=self._lr_scale, peft=self._peft, @@ -113,7 +105,7 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask with set_generator( self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator ): - embeddings = torch.dropout(embeddings, self._block_config.hidden_dropout, self.training) + embeddings = torch.dropout(embeddings, self._config.dropout, self.training) return embeddings.to(dtype=self._residual_dtype) def forward( @@ -132,3 +124,7 @@ def forward( return self._forward( input_, kwargs.get(LanguageModelKwargs.position_ids), kwargs.get(LanguageModelKwargs.mask_inputs) ) + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + # TODO: Add marginal compute? 
(embeddings) + return 0 diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 326bfe313..e71512915 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -1,4 +1,5 @@ import logging +import typing import torch from torch._C._distributed_c10d import ReduceOp # noqa @@ -6,6 +7,7 @@ from fast_llm.core.ops import split_op from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames @@ -15,7 +17,7 @@ from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import BlockLayerBase -from fast_llm.layers.block.config import BlockConfig, BlockDimNames +from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( @@ -44,26 +46,18 @@ class LanguageModelHead[ConfigType: LanguageModelHeadConfig](BlockLayerBase[Conf def __init__( self, config: ConfigType, - # TODO: Doesn't make much sense. - block_config: BlockConfig, distributed_config: DistributedConfig, embeddings_config: LanguageModelEmbeddingsConfig, *, - # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, - block_index: int, - name: str, lr_scale: float | None, peft: PeftConfig | None, prediction_distance: int, ): super().__init__( config, - block_config, distributed_config, hidden_dim=hidden_dim, - block_index=block_index, - name=name, lr_scale=lr_scale, peft=peft, ) @@ -100,8 +94,8 @@ def __init__( self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) - self.final_norm = self._block_config.normalization.get_layer( - hidden_dim, lr_scale=self._lr_scale, peft=self._peft + self.final_norm = self._config.normalization.get_layer( + self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft ) self._vocab_dim = TensorDim( @@ -111,10 +105,8 @@ def __init__( if self._prediction_distance == 0 and not self._config.tied_weight: # untie embedding weights self.output_weights = self._config.output_weight.get_parameter( - (self._vocab_dim, hidden_dim), - default_initialization=init_normal_( - std=self._block_config.init_method_std, - ), + (self._vocab_dim, self._hidden_dim), + default_initialization=init_normal_(std=self._hidden_size**-0.5), lr_scale=self._lr_scale, peft=self._peft, ) @@ -156,6 +148,15 @@ def forward( # MTP: Return shared_hidden to be used by the next head. return shared_hidden + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + # TODO: Add marginal compute? 
(loss) + return ( + 2 + * (config.forward + 2 * config.backward) + * (input_.global_shape if config.global_ else input_).numel() + * (self._vocab_dim.global_size if config.global_ else self._vocab_dim.size) + ) + def _forward_backward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None ) -> tuple[torch.Tensor, torch.Tensor | None]: diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index a81f29833..862a6cd1a 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -94,6 +94,13 @@ class SSMConfig(MixerConfig): hint=FieldHint.core, ) + # Model options + add_linear_biases: bool = Field( + default=True, + desc="Add biases to linear layers. May be overridden for individual layers.", + hint=FieldHint.architecture, + ) + def set_defaults(self, hidden_size: int): if self.d_inner is None: self.d_inner = 2 * hidden_size diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 83e02c7ac..a7d059781 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -4,15 +4,16 @@ import einops import torch +from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_, init_ones_, init_zeros_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.ssm.config import DiscreteMamba2Config -from fast_llm.tensor import ParameterMeta +from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import div logger = logging.getLogger(__name__) @@ -36,23 +37,16 @@ class DiscreteMamba2[ConfigType: DiscreteMamba2Config](BlockLayer[ConfigType]): def __init__( self, config: ConfigType, - block_config: BlockConfig, distributed_config: DistributedConfig, *, - # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, - block_index: int, - name: str, lr_scale: float | None, peft: PeftConfig | None, ): super().__init__( config, - block_config, distributed_config, hidden_dim=hidden_dim, - block_index=block_index, - name=name, lr_scale=lr_scale, peft=peft, ) @@ -92,7 +86,7 @@ def __init__( hidden_dim, inner_projection_dim, default_weight_initialization=init_normal_(0, (2 / self._config.d_inner) ** 0.5), - default_add_bias=self._block_config.add_linear_biases, + default_add_bias=self._config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=self._lr_scale, peft=self._peft, @@ -126,7 +120,7 @@ def __init__( inner_dim, hidden_dim, default_weight_initialization=init_normal_(0, (2 / self._config.d_inner) ** 0.5), - default_add_bias=self._block_config.add_linear_biases, + default_add_bias=self._config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=self._lr_scale, peft=self._peft, @@ -214,3 +208,7 @@ def forward( @torch.compile def _apply_a_log(self, x: torch.Tensor, A_log: torch.Tensor) -> torch.Tensor: return x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1) + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + # TODO: Implement. 
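Note: the SSM mixers in this patch leave get_compute_usage unimplemented (they raise NotImplementedError). A rough estimate could follow the same dense-matmul convention as LanguageModelHead.get_compute_usage above (2 FLOPs per multiply-accumulate, with the backward pass counted twice), restricted to the input and output projections. The sketch below is illustrative only, is not the Fast-LLM implementation, and ignores the convolution and the state-space scan; d_inner and the ResourceUsageConfig fields are the ones already used in this file.

def approximate_mixer_compute_usage(input_, d_inner: int, config) -> int:
    # Illustrative sketch only. Approximates the in/out projections of the mixer as
    # two (hidden x d_inner) matmuls per token, using the head's FLOP convention.
    tokens_times_hidden = (input_.global_shape if config.global_ else input_).numel()
    return 2 * (config.forward + 2 * config.backward) * tokens_times_hidden * 2 * d_inner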
+ raise NotImplementedError() diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index e98201c67..5caa1a97c 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -3,14 +3,16 @@ import torch +from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_, init_ones_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.ssm.config import MambaConfig, init_a, init_dtprojbias +from fast_llm.tensor import TensorMeta from fast_llm.utils import div try: @@ -35,23 +37,16 @@ class Mamba[ConfigType: MambaConfig](BlockLayer[ConfigType]): def __init__( self, config: ConfigType, - block_config: BlockConfig, distributed_config: DistributedConfig, *, - # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, - block_index: int, - name: str, lr_scale: float | None, peft: PeftConfig | None, ): super().__init__( config, - block_config, distributed_config, hidden_dim=hidden_dim, - block_index=block_index, - name=name, lr_scale=lr_scale, peft=peft, ) @@ -70,7 +65,7 @@ def __init__( hidden_dim, inner_projection_dim, default_weight_initialization=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), - default_add_bias=self._block_config.add_linear_biases, + default_add_bias=self._config.add_linear_biases, lr_scale=self._lr_scale, peft=self._peft, ) @@ -153,3 +148,7 @@ def forward( if kwargs[BlockKwargs.sequence_first]: out = out.transpose(0, 1) return out, None + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + # TODO: Implement. 
+ raise NotImplementedError() diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 9c7c2e97c..b48f100db 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,14 +3,16 @@ import torch +from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_, init_ones_, init_uniform_centered_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.block.config import BlockDimNames, BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.ssm.config import Mamba2Config, init_a, init_dtprojbias +from fast_llm.tensor import TensorMeta from fast_llm.utils import div try: @@ -33,23 +35,16 @@ class Mamba2[ConfigType: Mamba2Config](BlockLayer[ConfigType]): def __init__( self, config: ConfigType, - block_config: BlockConfig, distributed_config: DistributedConfig, *, - # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, - block_index: int, - name: str, lr_scale: float | None, peft: PeftConfig | None, ): super().__init__( config, - block_config, distributed_config, hidden_dim=hidden_dim, - block_index=block_index, - name=name, lr_scale=lr_scale, peft=peft, ) @@ -95,7 +90,7 @@ def __init__( hidden_dim, inner_projection_dim, default_weight_initialization=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), - default_add_bias=self._block_config.add_linear_biases, + default_add_bias=self._config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=self._lr_scale, peft=self._peft, @@ -104,7 +99,7 @@ def __init__( hidden_dim, dt_rank_dim, default_weight_initialization=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), - default_add_bias=self._block_config.add_linear_biases, + default_add_bias=self._config.add_linear_biases, lr_scale=self._lr_scale, peft=self._peft, ) @@ -136,24 +131,24 @@ def __init__( inner_dim, hidden_dim, default_weight_initialization=init_normal_(0, (2 / self._config.d_inner) ** 0.5), - default_add_bias=self._block_config.add_linear_biases, + default_add_bias=self._config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=self._lr_scale, peft=self._peft, ) - if self._debug.enabled: - self._xz_dims = ( - BlockDimNames.batch, - inner_dim, - BlockDimNames.sequence_q, - ) - self._bc_dims = ( - BlockDimNames.batch, - heads_dim, - state_dim, - BlockDimNames.sequence_q, - ) + # Debug dims + self._xz_dims = ( + BlockDimNames.batch, + inner_dim, + BlockDimNames.sequence_q, + ) + self._bc_dims = ( + BlockDimNames.batch, + heads_dim, + state_dim, + BlockDimNames.sequence_q, + ) def forward( self, @@ -243,3 +238,7 @@ def forward( # (batch/sequence, sequence/batch, local_heads * state) # -> (batch/local_sequence, local_sequence/batch, hidden) return self.out_proj(y) + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + # TODO: Implement. 
+ raise NotImplementedError() diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 024d7d79c..1bc30aeab 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -343,3 +343,15 @@ def log_memory_usage[ if header is not None: formatted = f"{header}: {formatted}" return log(formatted, log_fn=log_fn) + + +_model_debug_level = 0 + + +def get_model_debug_level() -> int: + return _model_debug_level + + +def set_model_debug_level(level: int) -> None: + global _model_debug_level + _model_debug_level = level diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 365f84d52..3cb954e1d 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -46,7 +46,7 @@ from fast_llm.models.gpt.external.mtp_llama.configuration_mtp_llama import MTPLlamaConfig from fast_llm.models.gpt.model import GPTModel from fast_llm.tensor import SafeTensorSlice -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: pass @@ -54,6 +54,26 @@ logger = logging.getLogger(__name__) +@dataclasses.dataclass +class HiddenSizeParamConverter(ParamConverter): + """ + Some HF models don't have a `head_dim` parameter, and instead use hidden_size // heads + """ + + def __post_init__(self): + Assert.eq(len(self.fast_llm_names), 3) + Assert.eq(len(self.export_names), 2) + + def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + hidden_size, heads, head_size = fast_llm_values + Assert.eq(head_size * heads, hidden_size) + return hidden_size, heads + + def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + hidden_size, heads = export_values + return hidden_size, heads, div(hidden_size, heads) + + class QueryWeightConverter(WeightConverter): # Hf uses the real format for rotary embeddings. 
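For the HiddenSizeParamConverter introduced just above, the export/import round trip reduces to dropping and re-deriving head_size. A small usage sketch with hypothetical sizes (hidden_size 4096, 32 heads), following the constructor pattern used later in this file:

converter = HiddenSizeParamConverter(
    fast_llm_names=(
        ("transformer", "hidden_size"),
        ("transformer", "mixer", "heads"),
        ("transformer", "mixer", "head_size"),
    ),
    export_names=(("hidden_size",), ("num_attention_heads",)),
)
assert converter.export_params((4096, 32, 128)) == (4096, 32)  # checks 32 * 128 == 4096, drops head_size
assert converter.import_params((4096, 32)) == (4096, 32, 128)  # head_size = hidden_size // heads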
_config: GPTBaseModelConfig @@ -63,7 +83,7 @@ def export_weight( ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (query,) = weight if self._config.transformer.mixer.rotary.complex_format: - query = convert_rotary_complex_to_real(query[:], self._config.transformer.mixer.kv_channels, 0) + query = convert_rotary_complex_to_real(query[:], self._config.transformer.mixer.head_size, 0) return (query,) def import_weight( @@ -71,7 +91,7 @@ def import_weight( ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (query,) = weight if self._config.transformer.mixer.rotary.complex_format: - query = convert_rotary_real_to_complex(query[:], self._config.transformer.mixer.kv_channels, 0) + query = convert_rotary_real_to_complex(query[:], self._config.transformer.mixer.head_size, 0) return (query,) @@ -85,7 +105,7 @@ def export_weight( (key_value,) = weight key, value = key_value[:].chunk(2) if self._config.transformer.mixer.rotary.complex_format: - key = convert_rotary_complex_to_real(key, self._config.transformer.mixer.kv_channels, 0) + key = convert_rotary_complex_to_real(key, self._config.transformer.mixer.head_size, 0) return key, value def import_weight( @@ -93,7 +113,7 @@ def import_weight( ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: key, value = weight if self._config.transformer.mixer.rotary.complex_format: - key = convert_rotary_real_to_complex(key[:], self._config.transformer.mixer.kv_channels, 0) + key = convert_rotary_real_to_complex(key[:], self._config.transformer.mixer.head_size, 0) key_value = torch.cat([key[:], value[:]]) return (key_value,) @@ -142,7 +162,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("transformer", "mixer", "rotary", "theta"),), export_names=(("rope_theta",),) ), MappedConfigParamConverter( - fast_llm_names=(("transformer", "mlp", "activation_type"),), + fast_llm_names=(("transformer", "mlp", "activation"),), export_names=(("hidden_act",),), fast_llm_value=ActivationType.from_hf_name, export_value=lambda activation_type: activation_type.hf_name, @@ -151,20 +171,12 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("transformer", "num_layers"),), export_names=(("num_hidden_layers",),), ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("hidden_size",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "mixer", "num_attention_heads"),), - export_names=(("num_attention_heads",),), - ), RenameParamConverter( fast_llm_names=(("transformer", "mixer", "head_groups"),), export_names=(("num_key_value_heads",),), ), RenameParamConverter( - fast_llm_names=(("transformer", "mlp", "ffn_hidden_size"),), + fast_llm_names=(("transformer", "mlp", "intermediate_size"),), export_names=(("intermediate_size",),), ), RenameParamConverter( @@ -219,21 +231,21 @@ def _create_transformer_layer_converters( f"{fast_llm_layer_name}.mixer.query", f"{hf_layer_name}.self_attn.q_proj", # TODO: Fix - transformer_config.add_linear_biases, + transformer_config.mixer.add_linear_biases, QueryWeightConverter, ), ( f"{fast_llm_layer_name}.mixer.key_value", (f"{hf_layer_name}.self_attn.k_proj", f"{hf_layer_name}.self_attn.v_proj"), # TODO: Fix - transformer_config.add_linear_biases, + transformer_config.mixer.add_linear_biases, KeyValueWeightConverter, ), ( f"{fast_llm_layer_name}.mixer.dense", f"{hf_layer_name}.self_attn.o_proj", # TODO: Fix - transformer_config.add_linear_biases, + transformer_config.mixer.add_linear_biases, WeightConverter, ), # Norm @@ -264,14 +276,14 @@ def 
_create_transformer_layer_converters( f"{fast_llm_layer_name}.mlp.layer_1", (), # TODO: Fix - transformer_config.add_linear_biases, + transformer_config.mlp.add_linear_biases, cls=IgnoreExportWeightConverter, ) converters += self._get_weight_and_bias_converters( f"{fast_llm_layer_name}.mlp.layer_2", (), # TODO: Fix - transformer_config.add_linear_biases, + transformer_config.mlp.add_linear_biases, cls=IgnoreExportWeightConverter, ) converters += [IgnoreExportWeightConverter(f"{fast_llm_layer_name}.mlp.router.weight", ())] @@ -349,10 +361,21 @@ class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler) @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + HiddenSizeParamConverter( + fast_llm_names=( + ("transformer", "hidden_size"), + ("transformer", "mixer", "heads"), + ("transformer", "mixer", "head_size"), + ), + export_names=(("hidden_size",), ("num_attention_heads",)), + ), ConstantImportParamConverter( fast_llm_names=(("transformer", "mixer", "rotary", "type"),), fast_llm_value=DefaultRotaryConfig.dynamic_type_name, ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "mixer", "add_linear_biases"),), fast_llm_value=True + ), ConstantImportParamConverter( fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value="layer_norm", @@ -361,7 +384,13 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) ), ConstantImportParamConverter(fast_llm_names=(("transformer", "mlp", "gated"),), fast_llm_value=False), - ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=True), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "mlp", "add_linear_biases"),), fast_llm_value=True + ), + ConstantImportParamConverter( + fast_llm_names=(("output_layer", "normalization", "type"),), + fast_llm_value="layer_norm", + ), ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: @@ -371,13 +400,13 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", # TODO: Fix - transformer_config.add_linear_biases, + transformer_config.mlp.add_linear_biases, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.c_proj", # TODO: Fix - transformer_config.add_linear_biases, + transformer_config.mlp.add_linear_biases, MLPLayer2Converter, ), ] @@ -395,11 +424,24 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) ), RenameParamConverter( - fast_llm_names=(("transformer", "mixer", "kv_channels"),), + fast_llm_names=(("transformer", "hidden_size"),), + export_names=(("hidden_size",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "mixer", "heads"),), + export_names=(("num_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "mixer", "head_size"),), export_names=(("head_dim",),), ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "mixer", "add_linear_biases"),), fast_llm_value=False + ), ConstantImportParamConverter(fast_llm_names=(("transformer", "mlp", "gated"),), fast_llm_value=True), - ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), + 
ConstantImportParamConverter( + fast_llm_names=(("transformer", "mlp", "add_linear_biases"),), fast_llm_value=False + ), LLamaRotaryParamConverter( fast_llm_names=(("transformer", "mixer", "rotary"),), export_names=( @@ -407,6 +449,10 @@ def _create_config_converters(cls) -> list[ParamConverter]: ("rope_scaling",), ), ), + ConstantImportParamConverter( + fast_llm_names=(("output_layer", "normalization", "type"),), + fast_llm_value="rms_norm", + ), ] @@ -484,14 +530,14 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), # TODO: Fix - transformer_config.add_linear_biases, + transformer_config.mlp.add_linear_biases, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", # TODO: Fix - transformer_config.add_linear_biases, + transformer_config.mlp.add_linear_biases, MLPLayer2Converter, ), ] @@ -533,6 +579,14 @@ def _create_config_converters(cls) -> list[ParamConverter]: RenameParamConverter( fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) ), + HiddenSizeParamConverter( + fast_llm_names=( + ("transformer", "hidden_size"), + ("transformer", "mixer", "heads"), + ("transformer", "mixer", "head_size"), + ), + export_names=(("hidden_size",), ("num_attention_heads",)), + ), ConstantImportParamConverter(fast_llm_names=(("transformer", "mlp", "gated"),), fast_llm_value=True), # TODO: Fix ConstantImportParamConverter( @@ -545,6 +599,10 @@ def _create_config_converters(cls) -> list[ParamConverter]: ("rope_scaling",), ), ), + ConstantImportParamConverter( + fast_llm_names=(("output_layer", "normalization", "type"),), + fast_llm_value="rms_norm", + ), IgnoreImportQwen2SlidingWindowParamsConverter(), ] @@ -555,14 +613,14 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), # TODO: Fix - transformer_config.add_linear_biases, + transformer_config.mlp.add_linear_biases, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", # TODO: Fix - transformer_config.add_linear_biases, + transformer_config.mlp.add_linear_biases, MLPLayer2Converter, ), ] @@ -601,20 +659,20 @@ def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantImportParamConverter(fast_llm_names=(("transformer", "mlp", "type"),), fast_llm_value="moe"), ConstantImportParamConverter( - fast_llm_names=(("transformer", "mlp", "expert_routing_type"),), fast_llm_value=RoutingType.topk + fast_llm_names=(("transformer", "mlp", "routing"),), fast_llm_value=RoutingType.topk ), RenameParamConverter( - fast_llm_names=(("transformer", "mlp", "num_experts"),), export_names=(("num_local_experts",),) + fast_llm_names=(("transformer", "mlp", "experts"),), export_names=(("num_local_experts",),) ), RenameParamConverter( - fast_llm_names=(("transformer", "mlp", "num_experts_per_token"),), + fast_llm_names=(("transformer", "mlp", "experts_per_token"),), export_names=(("num_experts_per_tok",),), ), IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - num_experts = self._model.config.base_model.transformer.mlp.num_experts + num_experts = 
self._model.config.base_model.transformer.mlp.experts return [ WeightConverter(f"{fast_llm_prefix}.mlp.router.weight", f"{hf_prefix}.block_sparse_moe.gate.weight"), SplitWeightConverter( @@ -671,14 +729,14 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), # TODO: Fix - transformer_config.add_linear_biases, + transformer_config.mlp.add_linear_biases, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", # TODO: Fix - transformer_config.add_linear_biases, + transformer_config.mlp.add_linear_biases, MLPLayer2Converter, ), ] diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index 5c73dbb23..bbe7ae43f 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -27,7 +27,7 @@ def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: tensor_ = _init_position_embeddings_megatron(meta, tensor, distributed) elif "mlp.router.weight" in meta.tensor_name: tensor_ = _init_moe_router_megatron(meta, tensor, distributed) - elif isinstance(config.mlp, MoEMLPConfig) and config.mlp.num_experts > 1 and "mlp.layer_" in meta.tensor_name: + elif isinstance(config.mlp, MoEMLPConfig) and config.mlp.experts > 1 and "mlp.layer_" in meta.tensor_name: tensor_ = _init_moe_mlp_megatron(config, meta, tensor, distributed) elif "mlp.layer_2" in meta.tensor_name: tensor_ = _init_transposed_mlp_weight_megatron(meta, tensor, distributed) @@ -62,19 +62,19 @@ def _init_attention_megatron( meta.param_init_method( meta, dense_tensor_ := tensor.new_empty( - config.mixer.kv_channels * config.mixer.num_attention_heads, + config.mixer.head_size * config.mixer.heads, config.hidden_size, ), generator, ) # QKV is split differently. (Assuming no tensor-parallel.) - heads_per_group = div(config.mixer.num_attention_heads, config.mixer.head_groups) + heads_per_group = div(config.mixer.heads, config.mixer.head_groups) meta.param_init_method( meta, qkv_tensor_ := tensor.new_empty( config.mixer.head_groups, heads_per_group + 2, - config.mixer.kv_channels, + config.mixer.head_size, config.hidden_size, ), generator, @@ -97,9 +97,9 @@ def _init_attention_megatron( if isinstance(config.mixer.rotary, DefaultRotaryConfig) and config.mixer.rotary.complex_format: from fast_llm.layers.attention.rotary.config import convert_rotary_real_to_complex - # Megatron uses (2, kv_channels/2) for the complex split; we use (kv_channels/2, 2). + # Megatron uses (2, head_size/2) for the complex split; we use (head_size/2, 2). # TODO: Avoid unnecessarily changing the value and dense tensors. 
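As the comments above note, Megatron splits the rotary dimension as (2, head_size / 2) while Fast-LLM uses (head_size / 2, 2); converting between the two layouts is a view plus transpose of that dimension. A minimal sketch, not the actual convert_rotary_real_to_complex implementation:

import torch

def swap_rotary_split(tensor: torch.Tensor, head_size: int, dim: int) -> torch.Tensor:
    # View the rotary dimension as (2, head_size // 2), swap the two factors,
    # then flatten back; the inverse conversion views it as (head_size // 2, 2) instead.
    shape = tensor.shape
    split = tensor.view(*shape[:dim], 2, head_size // 2, *shape[dim + 1 :])
    return split.transpose(dim, dim + 1).reshape(shape)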
- tensor_ = convert_rotary_real_to_complex(tensor_.view_as(meta), config.mixer.kv_channels, kv_dim) + tensor_ = convert_rotary_real_to_complex(tensor_.view_as(meta), config.mixer.head_size, kv_dim) return tensor_ @@ -148,12 +148,12 @@ def _init_moe_mlp_megatron( # self.param_init_method(self, tensor, generator) state = generator.get_state() weight_1 = tensor.new_empty( - config.mlp.num_experts * (1 + config.mlp.gated) * config.mlp.ffn_hidden_size, config.hidden_size + config.mlp.experts * (1 + config.mlp.gated) * config.mlp.intermediate_size, config.hidden_size ) - weight_2 = tensor.new_empty(config.mlp.num_experts * config.mlp.ffn_hidden_size, config.hidden_size) - for chunk_1, chunk_2 in zip(weight_1.chunk(config.mlp.num_experts), weight_2.chunk(config.mlp.num_experts)): + weight_2 = tensor.new_empty(config.mlp.experts * config.mlp.intermediate_size, config.hidden_size) + for chunk_1, chunk_2 in zip(weight_1.chunk(config.mlp.experts), weight_2.chunk(config.mlp.experts)): meta.param_init_method(meta, chunk_1, generator) - chunk_2_ = chunk_2.new_empty(config.hidden_size, config.mlp.ffn_hidden_size) + chunk_2_ = chunk_2.new_empty(config.hidden_size, config.mlp.intermediate_size) meta.param_init_method(meta, chunk_2_, generator) chunk_2.copy_(chunk_2_.t()) if "layer_1.weight" in meta.tensor_name: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index cfd9ae546..8b2947837 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -88,8 +88,6 @@ def _get_block( return self._config.transformer.get_layer( self._distributed_config, hidden_dim=self._hidden_dim, - block_index=block_index, - name=name, lr_scale=None, peft=self._config.peft, return_input=return_input, @@ -97,23 +95,17 @@ def _get_block( def _get_embeddings(self): return self._config.embeddings_layer.get_layer( - self._config.transformer, self._distributed_config, hidden_dim=self._hidden_dim, - block_index=0, - name="Embeddings", lr_scale=None, peft=self._config.peft, ) def _get_head(self, prediction_distance): return self._config.output_layer.get_layer( - self._config.transformer, self._distributed_config, self._config.embeddings_layer, hidden_dim=self._hidden_dim, - block_index=max(self._config.transformer.num_layers + prediction_distance, 1), - name=f"Language model head {prediction_distance}", lr_scale=None, peft=self._config.peft, prediction_distance=prediction_distance, @@ -382,8 +374,8 @@ def loss_defs(self) -> list[LossDef]: loss_defs = [] if ( isinstance(self._config.transformer.mlp, MoEMLPConfig) - and self._config.transformer.mlp.num_experts > 1 - and self._config.transformer.mlp.expert_routing_type == RoutingType.topk + and self._config.transformer.mlp.experts > 1 + and self._config.transformer.mlp.routing == RoutingType.topk ): loss_defs.append( LossDef( @@ -392,7 +384,7 @@ def loss_defs(self) -> list[LossDef]: count=self._config.transformer.num_layers, ) ) - if self._config.transformer.mlp.expert_z_loss_coefficient: + if self._config.transformer.mlp.z_loss_coefficient: loss_defs.append( LossDef( name=MLPLossNames.router_z_loss, @@ -429,66 +421,6 @@ def loss_defs(self) -> list[LossDef]: class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): base_model_class: typing.ClassVar[type[GPTBaseModel]] = GPTBaseModel - def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, sequence_length) -> tuple[int, int]: - # TODO: Do in model, automate/generalize, get other stats - """Get tflop/s/GPU from global-batch-size and elapsed-time""" - 
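The hand-written get_tflops estimate removed here appears to be superseded by the per-layer get_compute_usage accounting added earlier in this patch. As a quick sanity check of that convention with hypothetical sizes:

# 2 FLOPs per multiply-accumulate, forward pass only, logit matmul of the head.
tokens, hidden_size, vocab_size = 8192, 4096, 32768
logit_flops = 2 * tokens * hidden_size * vocab_size  # ~2.2e12 FLOPs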
checkpoint_activations_factor = 3 if phase == PhaseType.training else 1 - transformer_config = self._config.base_model.transformer - - consumed_tokens_per_iteration = sequence_length * batch_size - - num_transformer_layers = ( - transformer_config.num_layers + self._config.base_model.output_layer.prediction_heads - 1 - ) - transformer_flops_base = ( - 2 * checkpoint_activations_factor * consumed_tokens_per_iteration * num_transformer_layers - ) - dense_flops_base = transformer_flops_base * transformer_config.hidden_size - # Query, key, value, dense. - flops_per_iteration = ( - 2 - * (transformer_config.mixer.num_attention_heads + transformer_config.mixer.head_groups) - * transformer_config.mixer.kv_channels - * dense_flops_base - ) - # MLP - flops_per_iteration += ( - (2 + transformer_config.mlp.gated) - * transformer_config.mlp.ffn_hidden_size - * dense_flops_base - * (transformer_config.mlp.num_experts_per_token if isinstance(transformer_config.mlp, MoEMLPConfig) else 1) - ) - - # LM-head - flops_per_iteration += ( - 6 - * consumed_tokens_per_iteration - * transformer_config.hidden_size - * self._config.base_model.embeddings_layer.vocab_size - * self._config.base_model.output_layer.prediction_heads - ) - - # Attention-matrix computation - attn_flops_base = transformer_flops_base * transformer_config.mixer.projection_size - if transformer_config.mixer.window_size is None: - # Ignore masked values (s**2/2) - attn_flops = attn_flops_base * sequence_length - model_tflops = flops_per_iteration + attn_flops - else: - # s*w - w**2/2 - attn_flops = ( - 2 - * attn_flops_base - * transformer_config.mixer.window_size - * (1 - transformer_config.mixer.window_size / 2 / sequence_length) - ) - model_tflops = flops_per_iteration + attn_flops - - # Partial recomputation (normal is 2 ops * ckpt_factor = 6, adding 1 for recomputing Q x K) - hardware_flops = flops_per_iteration + 7 / 6 * attn_flops - ratio = elapsed_time_per_iteration * self._config.distributed.world_size * 1e12 - return model_tflops / ratio, hardware_flops / ratio - class GPTInferenceRunner(InferenceRunner): model_class: typing.ClassVar[type[GPTModel]] = GPTModel diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 3f6238c45..0382462b5 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -35,8 +35,6 @@ def _get_block( return block_config.get_layer( self._distributed_config, hidden_dim=self._hidden_dim, - block_index=block_index, - name=name, lr_scale=None, peft=self._config.peft, return_input=return_input, diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 3ddd5d4fe..3fae970f8 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -167,18 +167,18 @@ def test_dpo_loss(): @requires_cuda @pytest.mark.parametrize("gated", [True, False]) @pytest.mark.parametrize( - "activation_type", [ActivationType.gelu, ActivationType.silu, ActivationType.relu, ActivationType.squared_relu] + "activation", [ActivationType.gelu, ActivationType.silu, ActivationType.relu, ActivationType.squared_relu] ) -def test_mlp_recomputation(gated, activation_type): +def test_mlp_recomputation(gated, activation): tokens = 1024 hidden_size = 2048 - ffn_hidden_size = 4096 + intermediate_size = 4096 std = 1 / 64 input_ = torch.randn(tokens, hidden_size, device="cuda", requires_grad=True) output_grad = torch.randn(tokens, hidden_size, device="cuda", requires_grad=True) - weight_1 = torch.normal(0, std, (ffn_hidden_size * (gated + 1), 
hidden_size), device="cuda", requires_grad=True) - bias_1 = torch.normal(0, std, (ffn_hidden_size * (gated + 1),), device="cuda", requires_grad=True) - weight_2 = torch.normal(0, std, (ffn_hidden_size, hidden_size), device="cuda", requires_grad=True) + weight_1 = torch.normal(0, std, (intermediate_size * (gated + 1), hidden_size), device="cuda", requires_grad=True) + bias_1 = torch.normal(0, std, (intermediate_size * (gated + 1),), device="cuda", requires_grad=True) + weight_2 = torch.normal(0, std, (intermediate_size, hidden_size), device="cuda", requires_grad=True) bias_2 = torch.normal(0, std, (hidden_size,), device="cuda", requires_grad=True) params = (weight_1, bias_1, weight_2, bias_2) @@ -186,7 +186,7 @@ def test_mlp_recomputation(gated, activation_type): torch_mlp_activation( (torch.nn.functional.linear(input_, weight_1, bias_1)), gated, - activation_type, + activation, ), weight_2.t(), bias_2, @@ -202,7 +202,7 @@ def test_mlp_recomputation(gated, activation_type): param.grad = None param.grad_buffer = torch.empty_like(param) param.param_grad_is_zero = True - output = mlp_autograd(input_, None, *params, gated, activation_type, None, False, True, recompute_level, True) + output = mlp_autograd(input_, None, *params, gated, activation, None, False, True, recompute_level, True) output.backward(output_grad) if i == 0: Assert.rms_close(output, output_ref, 1e-5) @@ -228,7 +228,7 @@ def test_dropless_mlp(): experts_per_token = 4 tokens = 256 hidden_size = 512 - ffn_hidden_size = 1024 + intermediate_size = 1024 std = 1 / 64 input_ = torch.randn(tokens, hidden_size, device="cuda", requires_grad=True) router_weight = torch.normal(0, std, (num_experts, hidden_size), device="cuda") @@ -240,9 +240,9 @@ def test_dropless_mlp(): output_grad = torch.randn(tokens, hidden_size, device="cuda", requires_grad=True) weight_1 = torch.normal( - 0, std, (ffn_hidden_size * 2 * num_experts, hidden_size), device="cuda", requires_grad=True + 0, std, (intermediate_size * 2 * num_experts, hidden_size), device="cuda", requires_grad=True ) - weight_2 = torch.normal(0, std, (ffn_hidden_size * num_experts, hidden_size), device="cuda", requires_grad=True) + weight_2 = torch.normal(0, std, (intermediate_size * num_experts, hidden_size), device="cuda", requires_grad=True) params = (weight_1, weight_2) for param in params: diff --git a/tests/functional/test_triton_kernels.py b/tests/functional/test_triton_kernels.py index 5a9065454..b5d88e0ac 100644 --- a/tests/functional/test_triton_kernels.py +++ b/tests/functional/test_triton_kernels.py @@ -82,12 +82,12 @@ def test_triton_add(): @requires_cuda @pytest.mark.parametrize( - ("batch_size", "sequence_length", "num_heads", "kv_channels"), + ("batch_size", "sequence_length", "num_heads", "head_size"), [(4, 1024, 8, 128), (1, 32, 1, 16), (2, 2048, 2, 192), (3, 519, 7, 134)], ) -def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): +def test_triton_rotary(batch_size, sequence_length, num_heads, head_size): assert TritonConfig.TRITON_ENABLED - x = torch.randn(batch_size, sequence_length, num_heads, kv_channels, dtype=torch.bfloat16, device="cuda") + x = torch.randn(batch_size, sequence_length, num_heads, head_size, dtype=torch.bfloat16, device="cuda") y1 = apply_rotary_embeddings( x, @@ -95,19 +95,19 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): .get_layer(None) ._get_frequencies( sequence_length, - kv_channels, + head_size, device="cuda", ), ) y2 = convert_rotary_real_to_complex( triton_rotary_( - 
convert_rotary_complex_to_real(x, kv_channels, 3), + convert_rotary_complex_to_real(x, head_size, 3), DefaultRotaryConfig(triton=True) .get_layer(None) - ._get_frequencies(sequence_length, kv_channels, device="cuda"), + ._get_frequencies(sequence_length, head_size, device="cuda"), ), - kv_channels, + head_size, 3, ) Assert.rms_close(y1, y2, 1e-3) @@ -166,7 +166,7 @@ def test_triton_normalization(has_bias, zero_centered): @requires_cuda @pytest.mark.parametrize("gated", [True, False]) @pytest.mark.parametrize( - "activation_type", + "activation", [ ActivationType.gelu, ActivationType.silu, @@ -176,15 +176,15 @@ def test_triton_normalization(has_bias, zero_centered): ], ) @pytest.mark.parametrize("recompute", [True, False]) -def test_triton_mlp_activation(gated, activation_type, recompute): +def test_triton_mlp_activation(gated, activation, recompute): assert TritonConfig.TRITON_ENABLED input_ = torch.randn(1024, 4096 * (2 if gated else 1), device="cuda", requires_grad=True) output_grad = torch.randn(1024, 4096, device="cuda") - output1, context = triton_mlp_activation_forward(input_, gated, activation_type) + output1, context = triton_mlp_activation_forward(input_, gated, activation) input_grad1, output3 = triton_mlp_activation_backward(output_grad, context, recompute) - output2 = torch_mlp_activation(input_, gated, activation_type) + output2 = torch_mlp_activation(input_, gated, activation) output2.backward(output_grad) Assert.rms_close(output1, output2, 1e-5) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 755d143e9..e402659b0 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -105,7 +105,7 @@ def _lm_head( ( ({}, {}, False), ({}, {"training_dtype": DataType.bfloat16}, False), - ({"transformer": {"full_precision_residual": True}}, {"training_dtype": DataType.bfloat16}, False), + ({"embeddings_layer": {"full_precision_residual": True}}, {"training_dtype": DataType.bfloat16}, False), ({"sequence_first": True}, {}, False), ({"output_layer": {"logit_z_loss": 1e-3}}, {}, False), ({"output_layer": {"logits_scale_factor": 5.0}}, {}, False), @@ -163,12 +163,14 @@ def test_lm_head( config = GPTBaseModelConfig.from_dict( { "transformer": { - "normalization": {"type": "rms_norm"}, "hidden_size": HIDDEN_SIZE, "num_layers": 0, }, "embeddings_layer": {"vocab_size": VOCAB_SIZE}, - "output_layer": {"cross_entropy_implementation": cross_entropy_impl}, + "output_layer": { + "cross_entropy_implementation": cross_entropy_impl, + "normalization": {"type": "rms_norm"}, + }, }, config_dict, update_type=UpdateType.update, @@ -190,7 +192,7 @@ def test_lm_head( (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=( distributed.config.optimization_dtype.torch - if config.transformer.full_precision_residual + if config.embeddings_layer.full_precision_residual else distributed.config.training_dtype.torch ), device=distributed.device, @@ -237,7 +239,7 @@ def test_lm_head( torch.empty( VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.training_dtype.torch, device=distributed.device ) - .normal_(config.transformer.init_method_std) + .normal_(config.transformer.hidden_size**-0.5) .requires_grad_(True) ) kwargs[WORD_EMBEDDINGS_WEIGHT if config.output_layer.tied_weight else OUTPUT_WEIGHTS] = logit_weight diff --git a/tests/test_attention.py b/tests/test_attention.py index 37514acd5..62c34d3c0 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -1,35 +1,13 @@ -import unittest.mock - 
import torch from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.attention.attention import Attention from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.layers.attention.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.layers.block.config import BlockDimNames from fast_llm.utils import Assert -def test_decide_window_size(): - attention = unittest.mock.Mock(spec=Attention) - attention._decide_window_size = Attention._decide_window_size.__get__(attention) # Attach real method - - # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) - attention._config = AttentionConfig(kv_channels=64, window_size=512, max_window_layers=2) - attention._block_index = 2 - assert attention._decide_window_size() == 512 - - # Arrange - Case 2: window_size is None (layer_index < max_window_layers) - attention._config = AttentionConfig(kv_channels=64, window_size=512, max_window_layers=2) - attention._block_index = 1 - assert attention._decide_window_size() is None - - # Arrange - Case 3: max_window_layers is None (always return window_size) - attention._config = AttentionConfig(kv_channels=64, window_size=512, max_window_layers=None) - assert attention._decide_window_size() == 512 - - def test_varlen_preprocessor(): sequence_lengths = [torch.tensor([8, 13, 4, 11], dtype=torch.int32), torch.tensor([11, 16, 9], dtype=torch.int32)] # First micro-sequence: @@ -51,7 +29,7 @@ def test_varlen_preprocessor(): micro_sequence_length = 12 sequence_length = 36 varlen_preprocessor = FlashAttnVarlenPreprocessor( - AttentionConfig(kv_channels=64), DistributedConfig(training_dtype="bfloat16") + AttentionConfig(head_size=64), DistributedConfig(training_dtype="bfloat16") ) for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): kwargs = { diff --git a/tests/test_config.py b/tests/test_config.py index 8954114f7..03d535520 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -81,8 +81,8 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "head_groups": 4, }, "mlp": { - "ffn_hidden_size": 4096, # Implicit default, default value - "activation_type": "silu", # Implicit default, non-default value + "intermediate_size": 4096, # Implicit default, default value + "activation": "silu", # Implicit default, non-default value }, "normalization": {"type": "rms_norm"}, # Nested "num_layers": 12, # Default @@ -121,7 +121,6 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "pretrained": {"format": "fast_llm", "path": config_path, "load_config": load_config}, } ) - Assert.eq(pretrained_config.model.base_model.transformer.mixer.kv_channels, 64) serialized_config = pretrained_config.model.to_dict() expected_config = {"type": "gpt", "distributed": DistributedConfig().to_dict()} @@ -139,15 +138,15 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "mlp": { "type": "mlp", - "ffn_hidden_size": 4096, # Implicit default, default value - "activation_type": "silu", # Implicit default, non-default value + "intermediate_size": 4096, # Implicit default, default value + "activation": "silu", # Implicit default, non-default value }, "normalization": {"type": "rms_norm", "implementation": "triton"}, "num_layers": 12, "hidden_size": 512, }, "embeddings_layer": {"vocab_size": 1000}, - "output_layer": {"tied_weight": False}, + "output_layer": {"tied_weight": False, "normalization": {"type": 
"layer_norm"}}, "peft": {"type": "lora", "freeze_others": False}, } else: @@ -155,9 +154,11 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): base_model_update["transformer"]["mixer"]["type"] = "attention" base_model_update["transformer"]["mixer"]["rotary"] = {"type": "none"} base_model_update["transformer"]["mlp"] = {"type": "mlp"} + base_model_update["output_layer"] = {"normalization": {"type": "layer_norm"}} base_model_update["peft"] = {"type": "lora", "freeze_others": False} expected_config["base_model"] = base_model_update + print("IKEUFGH", serialized_config, expected_config) check_equal_nested(serialized_config, expected_config) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 9ba266f12..ba17beca5 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -1,3 +1,4 @@ +import copy import dataclasses import enum import functools @@ -6,6 +7,7 @@ import pytest +from fast_llm.config import set_nested_dict_value from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig @@ -56,11 +58,24 @@ class ModelTestingGroupAction(enum.StrEnum): not_implemented = "not_implemented" +def _config_dict_to_args(config_dict: dict[str, typing.Any], keys=()): + """ + Converts a config dict to cli arguments. Not generic but good enough for the tests. + """ + args = [] + for key, value in config_dict.items(): + if isinstance(value, dict): + args += _config_dict_to_args(value, (*keys, key)) + else: + args.append(f"{'.'.join((*keys, key))}={value}") + return args + + @dataclasses.dataclass(kw_only=True, frozen=True) class ModelTestingConfig: name: str = None model_type: str - config_args: list[str] + config_dict: dict[str, typing.Any] megatron_args: list[str] | None checkpoint_format: type[CheckpointFormat] | None groups: dict[ModelTestingGroup, ModelTestingGroupAction] @@ -69,6 +84,10 @@ class ModelTestingConfig: # Option to skip specific distributed configuration with name containing any of the provided strings. 
skip_tests: tuple[str] = () + @functools.cached_property + def config_args(self): + return _config_dict_to_args(self.config_dict) + @functools.cached_property def trainer_config_class(self) -> type[TrainerConfig]: return TrainerConfig.get_subclass(self.model_type) @@ -76,7 +95,7 @@ def trainer_config_class(self) -> type[TrainerConfig]: @functools.cached_property def trainer_config(self) -> TrainerConfig: # See `RunnableConfig._from_parsed_args` - return self.trainer_config_class.from_dict(self.trainer_config_class._parse_updates(self.config_args)) + return self.trainer_config_class.from_dict(self.config_dict) @functools.cached_property def evaluators_config_class(self) -> type[EvaluatorsConfig]: @@ -87,7 +106,7 @@ def evaluators_config_class(self) -> type[EvaluatorsConfig]: @functools.cached_property def evaluators_config(self) -> EvaluatorsConfig: # See `RunnableConfig._from_parsed_args` - return self.evaluators_config_class.from_dict(self.evaluators_config_class._parse_updates(self.config_args)) + return self.evaluators_config_class.from_dict(self.config_dict) @functools.cached_property def model_config_class(self) -> type[FastLLMModelConfig]: @@ -119,79 +138,120 @@ def _update_and_add_testing_config( new_name: str, *, model_type: str | None = None, - extra_args: list[str] | None = None, + updates: dict[str | tuple[str, ...], typing.Any] | None = None, megatron_args: list[str] | None = ..., groups: dict[ModelTestingGroup, ModelTestingGroupAction], **kwargs, -): +) -> ModelTestingConfig: + config = MODEL_CONFIGS[old_name] - updates: dict[str, typing.Any] = { - "name": new_name, - "groups": groups, - } - if model_type is not None: - updates["model_type"] = model_type - if extra_args is not None: - updates["config_args"] = config.config_args + extra_args + config_dict = copy.deepcopy(config.config_dict) + if updates is not None: + for keys, update in updates.items(): + set_nested_dict_value(config_dict, keys, update) if megatron_args is not ...: if megatron_args is None: - updates["megatron_args"] = None - elif config.megatron_args is None: - updates["megatron_args"] = megatron_args - else: - updates["megatron_args"] = config.megatron_args + megatron_args - updates.update(kwargs) - - MODEL_CONFIGS[new_name] = dataclasses.replace(config, **updates) + megatron_args = None + elif config.megatron_args is not None: + megatron_args = config.megatron_args + megatron_args + new_config = dataclasses.replace( + config, + name=new_name, + model_type=config.model_type if model_type is None else model_type, + groups=groups, + config_dict=config_dict, + megatron_args=megatron_args, + **kwargs, + ) + MODEL_CONFIGS[new_name] = new_config + return new_config MODEL_CONFIGS: dict[str, ModelTestingConfig] = {} +# We use a smaller initialization scheme than the default to lower variance in layer outputs during comparisons. +# This is as if we had a hidden size of 2048 +init_1 = {"initialization": {"type": "normal", "std": 2**-5.5}} +# Needed to match Megatron (init_1 / (2 * num_layers) ** 0.5) +init_2 = {"initialization": {"type": "normal", "std": 2**-6.5}} MODEL_CONFIGS["gpt2"] = ModelTestingConfig( # Tests gpt2 features (absolute embeddings, layer norm, relu activation, tied embeddings, MHA, linear biases). 
name="gpt2", model_type="gpt", - config_args=[ - "training.logs.interval=1", - "run.tensor_logs.save=True", - "run.tensor_logs.show=False", - "model.base_model.embeddings_layer.position_embeddings.enabled=True", - "model.base_model.embeddings_layer.num_position_embeddings=512", - f"model.base_model.embeddings_layer.vocab_size={MODEL_TEST_VOCAB_SIZE}", - "model.base_model.transformer.num_layers=2", - "model.base_model.transformer.hidden_size=256", - "model.base_model.transformer.mixer.num_attention_heads=8", - "model.base_model.transformer.mixer.head_groups=8", - "model.base_model.transformer.init_method_std=0.022", - f"model.multi_stage.debug_param_init={_LOG_LEVEL}", - f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", - f"model.multi_stage.debug_layer_gradients={_LOG_LEVEL}", - f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", - "model.multi_stage.debug_tensor_parallel=True", - "model.distributed.reproducible_init=True", - "model.distributed.timeout=20", - "training.train_iters=2", - "training.num_workers=0", - "training.timeout=30", - "batch.batch_size=8", - "batch.sequence_length=512", - "data.datasets.training.type=slice", - "data.datasets.training.end=0.969", - "data.datasets.training.dataset.type=memmap", - f"data.datasets.training.dataset.path={MODEL_DATASET_PREFIX}", - "data.datasets.validation.type=slice", - "data.datasets.validation.begin=0.969", - "data.datasets.validation.end=0.999", - "data.datasets.validation.dataset.type=memmap", - f"data.datasets.validation.dataset.path={MODEL_DATASET_PREFIX}", - "data.datasets.test.type=slice", - "data.datasets.test.begin=0.999", - "data.datasets.test.end=1", - "data.datasets.test.dataset.type=memmap", - f"data.datasets.test.dataset.path={MODEL_DATASET_PREFIX}", - "optimizer.learning_rate.base=0.0001", - ], + config_dict={ + "run": { + "tensor_logs": { + "save": True, + "show": False, + }, + }, + "training": { + "logs": {"interval": 1}, + "train_iters": 2, + "num_workers": 0, + "timeout": 30, + }, + "model": { + "base_model": { + "embeddings_layer": { + "word_embeddings": init_1, + "position_embeddings": {"enabled": True, **init_1}, + "num_position_embeddings": 512, + "vocab_size": MODEL_TEST_VOCAB_SIZE, + }, + "transformer": { + "mixer": { + "query_layer": {"weight": init_1}, + "key_layer": {"weight": init_1}, + "value_layer": {"weight": init_1}, + "dense_layer": {"weight": init_2}, + "heads": 8, + "head_groups": 8, + "head_size": 32, + }, + "mlp": {"layer_1": {"weight": init_1}, "layer_2": {"weight": init_2}, "intermediate_size": 1024}, + "num_layers": 2, + "hidden_size": 256, + }, + "output_layer": {"output_weight": init_1}, + }, + "multi_stage": { + "debug_param_init": _LOG_LEVEL, + "debug_layer_outputs": _LOG_LEVEL, + "debug_layer_gradients": _LOG_LEVEL, + "debug_all_param_gradients": _LOG_LEVEL, + "debug_tensor_parallel": True, + }, + "distributed": { + "reproducible_init": True, + "timeout": 20, + }, + }, + "batch": {"batch_size": 8, "sequence_length": 512}, + "data": { + "datasets": { + "training": { + "dataset": {"type": "memmap", "path": MODEL_DATASET_PREFIX}, + "type": "slice", + "end": 0.969, + }, + "validation": { + "dataset": {"type": "memmap", "path": MODEL_DATASET_PREFIX}, + "type": "slice", + "begin": 0.969, + "end": 0.999, + }, + "test": { + "dataset": {"type": "memmap", "path": MODEL_DATASET_PREFIX}, + "type": "slice", + "begin": 0.999, + "end": 1, + }, + } + }, + "optimizer": {"learning_rate": {"base": 0.0001}}, + }, megatron_args=[ "--num-layers=2", "--hidden-size=256", @@ -210,7 +270,7 @@ def 
_update_and_add_testing_config( "--micro-batch-size=8", "--max-position-embeddings=512", "--seq-length=512", - "--init-method-std=0.022", + f"--init-method-std={2**-5.5}", "--lr=0.0001", "--num-workers=0", "--valid-num-workers=0", @@ -239,7 +299,9 @@ def _update_and_add_testing_config( # Tests MQA. "gpt2", "starcoder", - extra_args=["model.base_model.transformer.mixer.head_groups=1"], + updates={ + ("model", "base_model", "transformer", "mixer", "head_groups"): 1, + }, megatron_args=["--group-query-attention"], checkpoint_format=None, groups={ @@ -256,11 +318,11 @@ def _update_and_add_testing_config( # Tests intermediate between gpt2 and llama, closest converter to gpt2. "gpt2", "starcoder2", - extra_args=[ - "model.base_model.transformer.mixer.head_groups=4", - "model.base_model.transformer.mixer.rotary.type=default", - "model.base_model.embeddings_layer.position_embeddings.enabled=False", - ], + updates={ + ("model", "base_model", "transformer", "mixer", "head_groups"): 4, + ("model", "base_model", "transformer", "mixer", "rotary", "type"): "default", + ("model", "base_model", "embeddings_layer", "position_embeddings", "enabled"): False, + }, megatron_args=[ "--group-query-attention", "--num-query-groups=4", @@ -283,14 +345,15 @@ def _update_and_add_testing_config( # Main tested model. "starcoder2", "llama", - extra_args=[ - "model.base_model.transformer.mlp.gated=True", - "model.base_model.transformer.mlp.activation_type=silu", - "model.base_model.transformer.add_linear_biases=False", - "model.base_model.transformer.normalization.type=rms_norm", - "model.base_model.transformer.mlp.ffn_hidden_size=1024", - "model.base_model.output_layer.tied_weight=False", - ], + updates={ + ("model", "base_model", "transformer", "mixer", "add_linear_biases"): False, + ("model", "base_model", "transformer", "mlp", "gated"): True, + ("model", "base_model", "transformer", "mlp", "activation"): "silu", + ("model", "base_model", "transformer", "mlp", "add_linear_biases"): False, + ("model", "base_model", "transformer", "normalization", "type"): "rms_norm", + ("model", "base_model", "output_layer", "normalization", "type"): "rms_norm", + ("model", "base_model", "output_layer", "tied_weight"): False, + }, megatron_args=[ "--swiglu", "--disable-bias-linear", @@ -314,7 +377,9 @@ def _update_and_add_testing_config( # Tests llama3-style rotary embeddings. "llama", "llama3", - extra_args=["model.base_model.transformer.mixer.rotary.type=llama3"], + updates={ + ("model", "base_model", "transformer", "mixer", "rotary", "type"): "llama3", + }, # Megatron doesn't support Llama3-style Rotary Embeddings megatron_args=None, checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, @@ -332,7 +397,9 @@ def _update_and_add_testing_config( # Tests yarn-style rotary embeddings. "llama", "llama_yarn", - extra_args=["model.base_model.transformer.mixer.rotary.type=yarn"], + updates={ + ("model", "base_model", "transformer", "mixer", "rotary", "type"): "yarn", + }, # Megatron doesn't support Yarn-style Rotary Embeddings megatron_args=None, checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, @@ -350,7 +417,7 @@ def _update_and_add_testing_config( # Tests diffusion llama converter. "llama_yarn", "diffusion_llama", - extra_args=[], + updates={}, # Megatron doesn't support Yarn-style Rotary Embeddings megatron_args=None, checkpoint_format=DiffusionLlamaGPTHuggingfaceCheckpointFormat, @@ -370,7 +437,9 @@ def _update_and_add_testing_config( # Tests multi-token prediction, custom HF model and converter. 
"llama", "llama_mtp", - extra_args=["model.base_model.output_layer.prediction_heads=4"], + updates={ + ("model", "base_model", "output_layer", "prediction_heads"): 4, + }, # Megatron doesn't support multi-token prediction. megatron_args=None, checkpoint_format=MTPLlamaGPTHuggingfaceCheckpointFormat, @@ -391,7 +460,9 @@ def _update_and_add_testing_config( "llama", "qwen2", # TODO: replace - extra_args=["model.base_model.transformer.add_linear_biases=only_attn_qkv"], + updates={ + ("model", "base_model", "transformer", "add_linear_biases"): "only_attn_qkv", + }, # Megatron doesn't support per sub layer biases. megatron_args=None, checkpoint_format=Qwen2GPTHuggingfaceCheckpointFormat, @@ -411,7 +482,7 @@ def _update_and_add_testing_config( "qwen2", "dream", # TODO: replace only_attn_qkv - extra_args=[], + updates={}, # Megatron doesn't support per sub layer biases. megatron_args=None, checkpoint_format=DiffusionDreamGPTHuggingfaceCheckpointFormat, @@ -431,7 +502,9 @@ def _update_and_add_testing_config( # Tests sliding window attention, mistral converter. "llama", "mistral", - extra_args=["model.base_model.transformer.mixer.window_size=128"], + updates={ + ("model", "base_model", "transformer", "mixer", "window_size"): 128, + }, # Megatron doesn't support sliding windows. megatron_args=None, checkpoint_format=MistralGPTHuggingfaceCheckpointFormat, @@ -450,11 +523,12 @@ def _update_and_add_testing_config( # Tests mixture of experts, mixtral converter. "llama", "mixtral", - extra_args=[ - "model.base_model.transformer.mlp.type=moe", - "model.base_model.transformer.mlp.num_experts=4", - "model.base_model.transformer.mlp.num_experts_per_token=4", - ], + updates={ + ("model", "base_model", "transformer", "mlp", "type"): "moe", + ("model", "base_model", "transformer", "mlp", "router", "weight"): init_1, + ("model", "base_model", "transformer", "mlp", "experts"): 4, + ("model", "base_model", "transformer", "mlp", "experts_per_token"): 4, + }, megatron_args=[ "--num-experts=4", "--moe-router-topk=4", @@ -477,12 +551,15 @@ def _update_and_add_testing_config( "llama", "llamba", model_type="hybrid_ssm", - extra_args=[ - "model.base_model.ssm.type=mamba", - "model.base_model.hybrid_block_layout=['t','m']", - "model.base_model.ssm.d_inner=512", - "model.base_model.ssm.state_size=16", - ], + updates={ + ("model", "base_model", "ssm"): { + "type": "mamba", + "d_inner": 512, + "state_size": 16, + "add_linear_biases": False, + }, + ("model", "base_model", "hybrid_block_layout"): "['t','m']", + }, megatron_args=None, checkpoint_format=LLambaHuggingfaceCheckpointFormat, # TODO: Add back generate as `normal` when stable. 
@@ -505,14 +582,16 @@ def _update_and_add_testing_config( "llama", "hybrid_mamba2", model_type="hybrid_ssm", - extra_args=[ - "model.base_model.hybrid_block_layout=['t','m2']", - "model.base_model.ssm.type=mamba_2", - "model.base_model.ssm.d_inner=512", - "model.base_model.ssm.state_size=8", - "model.base_model.ssm.d_xb=256", - # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}" - ], + updates={ + ("model", "base_model", "ssm"): { + "type": "mamba_2", + "d_inner": 512, + "state_size": 8, + "d_xb": 256, + "add_linear_biases": False, + }, + ("model", "base_model", "hybrid_block_layout"): "['t','m2']", + }, megatron_args=None, checkpoint_format=AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, groups={ @@ -537,15 +616,18 @@ def _update_and_add_testing_config( "llama", "hybrid_discrete_mamba2", model_type="hybrid_ssm", - extra_args=[ - "model.base_model.hybrid_block_layout=['t','m2d']", - "model.base_model.ssm.type=discrete_mamba_2", - "model.base_model.ssm.d_inner=512", - "model.base_model.ssm.state_size=8", - "model.base_model.ssm.n_qk_heads=8", - "model.base_model.ssm.n_v_heads=16", - "model.base_model.ssm.chunk_size=32", - ], + updates={ + ("model", "base_model", "ssm"): { + "type": "discrete_mamba_2", + "d_inner": 512, + "state_size": 8, + "n_qk_heads": 8, + "n_v_heads": 16, + "chunk_size": 32, + "add_linear_biases": False, + }, + ("model", "base_model", "hybrid_block_layout"): "['t','m2d']", + }, megatron_args=None, checkpoint_format=AprielSSMHHybridHuggingfaceCheckpointFormat, groups={ From 4185741caf5b3cc4ac3259590d6c4dbd24c4c60c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 9 Sep 2025 14:22:11 -0400 Subject: [PATCH 79/82] misc --- fast_llm/layers/block/config.py | 10 ---------- fast_llm/layers/ssm/config.py | 22 ++++------------------ fast_llm/models/ssm/config.py | 2 -- tests/utils/model_configs.py | 2 ++ 4 files changed, 6 insertions(+), 30 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index b4772d50e..fd42bccf9 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -60,10 +60,6 @@ class BlockLayerConfig(BaseModelConfig): def layer_class(self) -> "type[BlockLayer]": raise NotImplementedError() - def set_defaults(self, hidden_size: int): - # Opportunity to set defaults that depend on the hidden size. - pass - def get_layer( self, distributed_config: DistributedConfig, @@ -161,12 +157,6 @@ class BlockConfig(BaseModelConfig): valid=check_field(Assert.gt, 0), ) - def _validate(self) -> None: - self.mixer.set_defaults(self.hidden_size) - self.mlp.set_defaults(self.hidden_size) - - super()._validate() - def get_layer( self, distributed_config: DistributedConfig, diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 862a6cd1a..9b89b28cd 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -89,7 +89,7 @@ class SSMConfig(MixerConfig): # [Mamba, Mamba2, DiscreteMamba2] # c_size [Mamba, Mamba2, DiscreteMamba2]? d_inner: int = Field( - default=None, + default=2048, desc="Inner dimension.", hint=FieldHint.core, ) @@ -101,10 +101,6 @@ class SSMConfig(MixerConfig): hint=FieldHint.architecture, ) - def set_defaults(self, hidden_size: int): - if self.d_inner is None: - self.d_inner = 2 * hidden_size - @config_class() class MambaBaseConfig(SSMConfig): @@ -127,16 +123,11 @@ class MambaBaseConfig(SSMConfig): # Model dimensions # [Mamba, Mamba2] dt_rank: int = Field( - default=None, - desc="Rank of the Δ projection matrix. 
If 'None', will be set to ceil(hidden_size/16)", + default=64, + desc="Rank of the Δ projection matrix.", hint=FieldHint.architecture, ) - def set_defaults(self, hidden_size: int): - super().set_defaults(hidden_size) - if self.dt_rank is None: - self.dt_rank = math.ceil(hidden_size / 16) - @config_class(dynamic_type={MixerConfig: "mamba"}) class MambaConfig(MambaBaseConfig): @@ -193,7 +184,7 @@ class Mamba2Config(MambaBaseConfig): # Model dimensions # xb_size [Mamba2] d_xb: int = Field( - default=None, + default=1024, desc="Dimension of the xB in Mamba2 blocks.", hint=FieldHint.architecture, ) @@ -206,11 +197,6 @@ class Mamba2Config(MambaBaseConfig): hint=FieldHint.architecture, ) - def set_defaults(self, hidden_size: int): - super().set_defaults(hidden_size) - if self.d_xb is None: - self.d_xb = hidden_size - @property def layer_class(self) -> "type[Mamba2]": from fast_llm.layers.ssm.mamba2 import Mamba2 diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 38839276f..da44e547f 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -46,8 +46,6 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): ssm_block_type: SSMBlockType | None = Field(init=False) def _validate(self): - self.ssm.set_defaults(self.transformer.hidden_size) - if self.hybrid_block_layout is None: with self._set_implicit_default(): self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index ba17beca5..abd3d4bad 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -556,6 +556,7 @@ def _update_and_add_testing_config( "type": "mamba", "d_inner": 512, "state_size": 16, + "dt_rank": 16, "add_linear_biases": False, }, ("model", "base_model", "hybrid_block_layout"): "['t','m']", @@ -587,6 +588,7 @@ def _update_and_add_testing_config( "type": "mamba_2", "d_inner": 512, "state_size": 8, + "dt_rank": 16, "d_xb": 256, "add_linear_biases": False, }, From 7763296bf88202783c12253b2ba85527e7a4a2b1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 16 Sep 2025 21:48:41 -0400 Subject: [PATCH 80/82] stuff --- .dockerignore | 1 + Dockerfile | 2 + docs/developer_guide/conversion.md | 2 +- docs/recipes/generate.md | 4 +- examples/mistral.yaml | 47 +- fast_llm/config.py | 12 +- fast_llm/engine/base_model/base_model.py | 17 - fast_llm/engine/base_model/config.py | 12 + fast_llm/engine/checkpoint/external.py | 184 +- fast_llm/engine/checkpoint/huggingface.py | 133 +- fast_llm/engine/evaluation/evaluator.py | 2 +- fast_llm/engine/multi_stage/multi_stage.py | 4 +- fast_llm/engine/schedule/runner.py | 13 +- fast_llm/engine/training/trainer.py | 2 +- fast_llm/layers/attention/attention.py | 4 +- fast_llm/layers/attention/config.py | 3 +- fast_llm/layers/block/block.py | 135 +- fast_llm/layers/block/config.py | 206 ++- .../block/{mlp/__init__.py => sequence.py} | 0 fast_llm/layers/common/linear/config.py | 8 +- .../layers/common/normalization/config.py | 18 +- .../common/normalization/normalization.py | 36 +- .../custom => layers/decoder}/__init__.py | 0 fast_llm/layers/decoder/block.py | 152 ++ fast_llm/layers/decoder/config.py | 111 ++ fast_llm/layers/decoder/mlp/__init__.py | 0 .../layers/{block => decoder}/mlp/config.py | 35 +- .../mlp/mixture_of_experts.py | 4 +- fast_llm/layers/{block => decoder}/mlp/mlp.py | 6 +- fast_llm/layers/language_model/config.py | 142 +- fast_llm/layers/language_model/embedding.py | 9 +- fast_llm/layers/language_model/head.py 
| 11 +- fast_llm/layers/ssm/config.py | 35 +- fast_llm/layers/ssm/discrete_mamba2.py | 6 +- fast_llm/layers/ssm/mamba.py | 10 +- fast_llm/layers/ssm/mamba2.py | 9 +- fast_llm/logging.py | 2 +- fast_llm/models/auto.py | 4 +- fast_llm/models/custom/config.py | 62 - fast_llm/models/custom/data.py | 48 - fast_llm/models/custom/head.py | 6 - fast_llm/models/custom/huggingface.py | 18 - fast_llm/models/custom/model.py | 59 - fast_llm/models/custom/readme.md | 38 - fast_llm/models/custom/trainer.py | 15 - fast_llm/models/gpt/config.py | 75 +- fast_llm/models/gpt/conversion.py | 856 --------- fast_llm/models/gpt/conversion/__init__.py | 0 fast_llm/models/gpt/conversion/apriel.py | 374 ++++ fast_llm/models/gpt/conversion/auto.py | 38 + fast_llm/models/gpt/conversion/config.py | 49 + .../models/gpt/conversion/diffusion_dream.py | 44 + .../models/gpt/conversion/diffusion_llama.py | 42 + fast_llm/models/gpt/conversion/llama.py | 575 ++++++ fast_llm/models/gpt/conversion/mistral.py | 61 + fast_llm/models/gpt/conversion/mixtral.py | 88 + fast_llm/models/gpt/conversion/mtp_llama.py | 95 + fast_llm/models/gpt/conversion/qwen2.py | 62 + fast_llm/models/gpt/megatron.py | 32 +- fast_llm/models/gpt/model.py | 128 +- fast_llm/models/ssm/config.py | 189 -- fast_llm/models/ssm/conversion.py | 774 -------- .../configuration_ssm_hybrid_apriel.py | 448 ----- .../modeling_ssm_hybrid_apriel.py | 1576 ----------------- .../apriel_ssm/configuration_ssm_apriel.py | 103 -- .../apriel_ssm/modeling_ssm_apriel.py | 743 -------- .../llamba/configuration_mtp_llamba.py | 94 - .../external/llamba/modeling_mtp_llamba.py | 389 ---- fast_llm/models/ssm/huggingface.py | 23 - fast_llm/models/ssm/model.py | 53 - fast_llm/models/ssm/trainer.py | 9 - fast_llm/tensor.py | 4 +- fast_llm/utils.py | 15 +- fast_llm_external_models/__init__.py | 0 .../configuration_apriel_hybrid_ssm.py | 6 +- .../modeling_apriel_hybrid_ssm.py | 35 +- .../diffusion_dream/configuration_dream.py | 0 .../diffusion_dream/generation_config.json | 0 .../diffusion_dream/generation_utils.py | 0 .../diffusion_dream/modeling_dream.py | 121 +- .../configuration_diffusion_llama.py | 0 .../diffusion_llama/generation_utils.py | 0 .../modeling_diffusion_llama.py | 18 +- .../eval/apriel_eval_wrapper.py | 14 +- .../eval/run_evalchemy.py | 3 +- .../eval/run_lm_eval.py | 2 +- ...brid_checkpoint_with_importance_15b_mil.py | 4 +- .../mtp_llama/configuration_mtp_llama.py | 0 .../mtp_llama/modeling_mtp_llama.py | 0 setup.cfg | 4 +- tests/conftest.py | 4 + tests/layers/test_lm_head.py | 10 +- tests/models/distributed_test_checkpoint.py | 2 + tests/models/test_checkpoint.py | 139 +- tests/models/test_generate.py | 27 +- tests/models/test_model.py | 7 +- tests/test_config.py | 92 +- tests/test_multi_stage.py | 74 +- tests/utils/distributed_configs.py | 2 +- tests/utils/model_configs.py | 233 +-- tests/utils/save_load_configs.py | 33 +- 101 files changed, 2727 insertions(+), 6669 deletions(-) rename fast_llm/layers/block/{mlp/__init__.py => sequence.py} (100%) rename fast_llm/{models/custom => layers/decoder}/__init__.py (100%) create mode 100644 fast_llm/layers/decoder/block.py create mode 100644 fast_llm/layers/decoder/config.py create mode 100644 fast_llm/layers/decoder/mlp/__init__.py rename fast_llm/layers/{block => decoder}/mlp/config.py (81%) rename fast_llm/layers/{block => decoder}/mlp/mixture_of_experts.py (98%) rename fast_llm/layers/{block => decoder}/mlp/mlp.py (96%) delete mode 100644 fast_llm/models/custom/config.py delete mode 100644 fast_llm/models/custom/data.py 
delete mode 100644 fast_llm/models/custom/head.py delete mode 100644 fast_llm/models/custom/huggingface.py delete mode 100644 fast_llm/models/custom/model.py delete mode 100644 fast_llm/models/custom/readme.md delete mode 100644 fast_llm/models/custom/trainer.py delete mode 100644 fast_llm/models/gpt/conversion.py create mode 100644 fast_llm/models/gpt/conversion/__init__.py create mode 100644 fast_llm/models/gpt/conversion/apriel.py create mode 100644 fast_llm/models/gpt/conversion/auto.py create mode 100644 fast_llm/models/gpt/conversion/config.py create mode 100644 fast_llm/models/gpt/conversion/diffusion_dream.py create mode 100644 fast_llm/models/gpt/conversion/diffusion_llama.py create mode 100644 fast_llm/models/gpt/conversion/llama.py create mode 100644 fast_llm/models/gpt/conversion/mistral.py create mode 100644 fast_llm/models/gpt/conversion/mixtral.py create mode 100644 fast_llm/models/gpt/conversion/mtp_llama.py create mode 100644 fast_llm/models/gpt/conversion/qwen2.py delete mode 100644 fast_llm/models/ssm/config.py delete mode 100644 fast_llm/models/ssm/conversion.py delete mode 100644 fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py delete mode 100644 fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py delete mode 100644 fast_llm/models/ssm/external/apriel_ssm/configuration_ssm_apriel.py delete mode 100644 fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py delete mode 100644 fast_llm/models/ssm/external/llamba/configuration_mtp_llamba.py delete mode 100644 fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py delete mode 100644 fast_llm/models/ssm/huggingface.py delete mode 100644 fast_llm/models/ssm/model.py delete mode 100644 fast_llm/models/ssm/trainer.py create mode 100644 fast_llm_external_models/__init__.py rename fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py => fast_llm_external_models/apriel_hybrid_ssm/configuration_apriel_hybrid_ssm.py (89%) rename fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py => fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py (98%) rename {fast_llm/models/gpt/external => fast_llm_external_models}/diffusion_dream/configuration_dream.py (100%) rename {fast_llm/models/gpt/external => fast_llm_external_models}/diffusion_dream/generation_config.json (100%) rename {fast_llm/models/gpt/external => fast_llm_external_models}/diffusion_dream/generation_utils.py (100%) rename {fast_llm/models/gpt/external => fast_llm_external_models}/diffusion_dream/modeling_dream.py (95%) rename {fast_llm/models/gpt/external => fast_llm_external_models}/diffusion_llama/configuration_diffusion_llama.py (100%) rename {fast_llm/models/gpt/external => fast_llm_external_models}/diffusion_llama/generation_utils.py (100%) rename {fast_llm/models/gpt/external => fast_llm_external_models}/diffusion_llama/modeling_diffusion_llama.py (99%) rename {fast_llm/models/ssm/external => fast_llm_external_models}/eval/apriel_eval_wrapper.py (93%) rename {fast_llm/models/ssm/external => fast_llm_external_models}/eval/run_evalchemy.py (66%) rename {fast_llm/models/ssm/external => fast_llm_external_models}/eval/run_lm_eval.py (67%) rename {fast_llm/models/ssm/external => fast_llm_external_models}/make_hybrid_checkpoint_with_importance_15b_mil.py (96%) rename {fast_llm/models/gpt/external => fast_llm_external_models}/mtp_llama/configuration_mtp_llama.py (100%) rename {fast_llm/models/gpt/external => 
fast_llm_external_models}/mtp_llama/modeling_mtp_llama.py (100%) diff --git a/.dockerignore b/.dockerignore index 0ed5480a2..500fbe11c 100644 --- a/.dockerignore +++ b/.dockerignore @@ -6,6 +6,7 @@ !setup.cfg !Megatron-LM !fast_llm +!fast_llm_external_models !examples !tools !tests diff --git a/Dockerfile b/Dockerfile index 71f59fffe..526026fa4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -33,6 +33,7 @@ RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://gith RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" # Copy dependency files with universal write permissions for all users. COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ +COPY --chmod=777 ./fast_llm_external_models/__init__.py fast_llm_external_models/ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ @@ -44,4 +45,5 @@ COPY --chmod=777 ./Megatron-LM Megatron-LM COPY --chmod=777 ./examples examples COPY --chmod=777 ./tests tests COPY --chmod=777 ./tools tools +COPY --chmod=777 ./fast_llm_external_models fast_llm_external_models COPY --chmod=777 --exclude=./fast_llm/csrc/ ./fast_llm/ fast_llm/ diff --git a/docs/developer_guide/conversion.md b/docs/developer_guide/conversion.md index 35a324db0..6f42d8b6a 100644 --- a/docs/developer_guide/conversion.md +++ b/docs/developer_guide/conversion.md @@ -232,7 +232,7 @@ Continuing our `AwesomeModel` handler example, we define: def _create_weight_converters(self) -> list[WeightConverter]: converters = [] # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = len(self._model.config.base_model.decoder) # A simple renaming example, for the word embeddings. 
converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) diff --git a/docs/recipes/generate.md b/docs/recipes/generate.md index e6bda8031..655fa29c0 100644 --- a/docs/recipes/generate.md +++ b/docs/recipes/generate.md @@ -21,12 +21,12 @@ Below is a step-by-step example of how to generate text using a Fast-LLM model c import huggingface_hub from transformers import AutoTokenizer from fast_llm.engine.checkpoint.config import CheckpointLoadConfig -from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat +from fast_llm.models.gpt.conversion.config import LlamaCheckpointFormat from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM # Specify model and configuration model = "HuggingFaceTB/SmolLM2-135M-Instruct" -checkpoint_format = LlamaGPTHuggingfaceCheckpointFormat +checkpoint_format = LlamaCheckpointFormat max_new_tokens = 50 # Download model checkpoint from the Hugging Face Hub to a local directory diff --git a/examples/mistral.yaml b/examples/mistral.yaml index 924bfba51..4b7fdd968 100644 --- a/examples/mistral.yaml +++ b/examples/mistral.yaml @@ -27,32 +27,33 @@ optimizer: beta_2: 0.95 model: base_model: - transformer: - mixer: - type: attention - rotary: - type: default - theta: 10000 - heads: 32 - head_groups: 8 - head_size: 128 - add_linear_biases: false - window_size: 4096 - dropout: 0.0 - mlp: - intermediate_size: 14336 - add_linear_biases: false - gated: true - activation: silu - normalization: - type: rms_norm - epsilon: 1.0e-05 - num_layers: 32 - hidden_size: 4096 - dropout: 0.0 embeddings_layer: + hidden_size: 4096 vocab_size: 32000 dropout: 0.0 + decoder: + block: + mixer: + type: attention + rotary: + type: default + theta: 10000 + heads: 32 + head_groups: 8 + head_size: 128 + add_linear_biases: false + window_size: 4096 + dropout: 0.0 + mlp: + intermediate_size: 14336 + add_linear_biases: false + gated: true + activation: silu + normalization: + type: rms_norm + epsilon: 1.0e-05 + dropout: 0.0 + num_blocks: 32 output_layer: tied_weight: false normalization: diff --git a/fast_llm/config.py b/fast_llm/config.py index 3352f3570..5284d8bee 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -783,11 +783,12 @@ def _from_dict( try: actual_cls = cls.get_subclass(default.get("type")) - if actual_cls is not None and actual_cls is not cls: - return actual_cls._from_dict(default, strict=strict, flat=flat) except KeyError: - # Postpone error to validation. - pass + # Try to postpone error to validation. + actual_cls = cls + + if actual_cls is not None and actual_cls is not cls: + return actual_cls._from_dict(default, strict=strict, flat=flat) # Do not validate yet in case the root class sets cross-dependencies in validation. with NoAutoValidate(): @@ -1121,3 +1122,6 @@ def pop_nested_dict_value[ return d.pop(keys[-1]) else: return d.pop(keys) + + +i = 0 diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 9de5ac2cc..0a3f8d1ce 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -1,5 +1,4 @@ import abc -import dataclasses import typing import torch @@ -78,17 +77,6 @@ def setup(self, distributed: Distributed) -> None: layer.setup(distributed) -@dataclasses.dataclass() -class LossDef: - # A name for the loss - name: str - formatted_name: str - # The number of times this loss is evaluated by the model for each micro-batch. Used as a denominator for averaging. - # TODO: Allow variable count? 
Would need a reduction across PP devices. - count: int = 1 - dtype: torch.dtype = torch.float32 - - class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], Sequential): def __init__( @@ -135,11 +123,6 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: # The name (dict key) is used to insert the weight in the kwargs of the forward pass. return {} - @property - @abc.abstractmethod - def loss_defs(self) -> list[LossDef]: - pass - def add_reference_model(self, name: str, inference_runner: "InferenceRunner") -> None: assert name not in self._reference_models assert not self._is_setup diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 2b55d782e..78fafea34 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -3,6 +3,7 @@ import typing from fast_llm.config import MISSING, Config, Field, FieldHint, FieldVerboseLevel, config_class +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import compare_nested, log if typing.TYPE_CHECKING: @@ -75,3 +76,14 @@ class ResourceUsageConfig: forward: int = 1 # Number of backward passes. Typically 1 for training, 0 for inference. backward: int = 1 + + +@dataclasses.dataclass() +class LossDef: + # A name for the loss + name: str + formatted_name: str + # The number of times this loss is evaluated by the model for each micro-batch. Used as a denominator for averaging. + # TODO: Allow variable count? Would need a reduction across PP devices. + count: int = 1 + dtype: DataType = DataType.float32 diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 72db80f6a..886c706c1 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -1,5 +1,4 @@ import abc -import dataclasses import logging import pathlib import typing @@ -7,7 +6,7 @@ import torch from fast_llm import __version__ -from fast_llm.config import MISSING, get_nested_dict_value, set_nested_dict_value +from fast_llm.config import Config from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointLoadMetadataConfig from fast_llm.engine.checkpoint.state_dict import StateDictCheckpointHandler @@ -19,124 +18,12 @@ logger = logging.getLogger(__name__) -@dataclasses.dataclass(kw_only=True) -class ParamConverter(abc.ABC): - fast_llm_names: tuple[tuple[str, ...], ...] = () # Array of fast-llm names, in nested (tuple) format. - export_names: tuple[tuple[str, ...], ...] = () # Array of export names, in nested (tuple) format. 
- - @abc.abstractmethod - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - pass - - @abc.abstractmethod - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - pass - - -@dataclasses.dataclass(kw_only=True) -class RenameParamConverter(ParamConverter): - ignore_missing: bool = False - default_value: typing.Any = None - - def __post_init__(self) -> None: - Assert.eq(len(self.fast_llm_names), 1) - Assert.eq(len(self.export_names), 1) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - return fast_llm_values - - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - if self.ignore_missing: - if export_values[0] == MISSING: - logger.warning( - "The configuration parameter `%s=%s` is ignored during conversion as it is not present in the checkpoint.", - self.export_names[0], - export_values[0], - ) - return (self.default_value,) - return export_values - - -# def __repr__(self): -# return f"RenameParamConverter({'.'.join(self.fast_llm_names[0])} <--> {'.'.join(self.export_names[0])})" - - -@dataclasses.dataclass(kw_only=True) -class ConstantImportParamConverter(ParamConverter): - fast_llm_value: typing.Any = MISSING - - def __post_init__(self): - Assert.eq(len(self.fast_llm_names), 1) - Assert.eq(len(self.export_names), 0) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - Assert.eq(fast_llm_values[0], self.fast_llm_value) - return () - - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - return (self.fast_llm_value,) - - -@dataclasses.dataclass(kw_only=True) -class ConstantExportParamConverter(ParamConverter): - export_value: typing.Any = MISSING - - def __post_init__(self): - Assert.eq(len(self.fast_llm_names), 0) - Assert.eq(len(self.export_names), 1) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - return (self.export_value,) - - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - Assert.eq(export_values[0], self.export_value) - return () - - -@dataclasses.dataclass(kw_only=True) -class IgnoreImportParamConverter(ParamConverter): - ignore_export_value: typing.Any = MISSING - - def __post_init__(self): - Assert.eq(len(self.fast_llm_names), 0) - Assert.eq(len(self.export_names), 1) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - return (MISSING,) - - def import_params(self, export_values): - if export_values[0] not in (self.ignore_export_value, MISSING): - logger.warning( - "The configuration parameter `%s=%s` is ignored during conversion." 
- " If you intend to use it in Fast-LLM, make sure to set it explicitly in the model configuration.", - self.export_names[0], - export_values[0], - ) - return () - - -@dataclasses.dataclass(kw_only=True) -class MappedConfigParamConverter(ParamConverter): - fast_llm_value: typing.Callable[[typing.Any], typing.Any] = lambda x: x - export_value: typing.Callable[[typing.Any], typing.Any] = lambda x: x - - def __post_init__(self) -> None: - Assert.eq(len(self.fast_llm_names), 1) - Assert.eq(len(self.export_names), 1) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - return (self.export_value(fast_llm_values[0]),) - - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - return (self.fast_llm_value(export_values[0]),) - - class WeightConverter: def __init__( self, fast_llm_name: str | tuple[str, ...], export_name: str | tuple[str, ...], - config: BaseModelConfig | None = None, + config: Config | None = None, ): self.fast_llm_name: tuple[str, ...] = (fast_llm_name,) if isinstance(fast_llm_name, str) else fast_llm_name self.export_name: tuple[str, ...] = (export_name,) if isinstance(export_name, str) else export_name @@ -216,7 +103,6 @@ def import_weight( class ExternalStateDictCheckpointHandler(StateDictCheckpointHandler): _model_class: typing.ClassVar[FastLLMModelConfig] - _config_converters: list[ParamConverter] def __init__(self, model: "FastLLMModel"): super().__init__(model) @@ -239,20 +125,14 @@ def __init__(self, model: "FastLLMModel"): @classmethod def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: - imported_model_config = cls._import_config(cls._load_config(config.path)) return CheckpointMetadata( fast_llm_version=__version__, model=cls._model_class, format=config.format, - config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), + config=cls._import_config(cls._load_config(config.path)), shards=["weights"], ) - @classmethod - @abc.abstractmethod - def _create_config_converters(cls) -> list[ParamConverter]: - pass - @abc.abstractmethod def _create_weight_converters(self) -> list[WeightConverter]: pass @@ -263,51 +143,15 @@ def _load_config(cls, directory: pathlib.Path | str) -> dict: pass @classmethod - def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: - # TODO v0.3: not used in this class - exported_config = {} - for converter in cls._get_config_converters(): - try: - values = converter.export_params( - tuple( - cls._get_fast_llm_attribute(config, fast_llm_name) - for fast_llm_name in converter.fast_llm_names - ) - ) - for export_name, value in zip(converter.export_names, values, strict=True): - if value is not MISSING: - set_nested_dict_value(exported_config, export_name, value) - except Exception as e: - raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) - - return exported_config # Noqa + @abc.abstractmethod + def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: + # TODO: not used in this class + pass @classmethod - def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: # noqa - kwargs = {} - for converter in cls._get_config_converters(): - try: - values = () - for export_name in converter.export_names: - try: - value = get_nested_dict_value(config, export_name) - except KeyError: - value = MISSING - values = values + (value,) - values = converter.import_params(values) - for fast_llm_name, value in zip(converter.fast_llm_names, values, 
strict=True): - if value is MISSING: - # Missing values need to be handled in dedicated converters, - # because implicit / default values may not match. - # TODO: Different behavior from other uses of MISSING. Use different tag? - raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}") - if fast_llm_name in kwargs: - raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}") - kwargs[fast_llm_name] = value - except Exception as e: - raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) - - return cls._model_class.get_base_model_config_class().from_dict({}, kwargs) + @abc.abstractmethod + def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: + pass def _convert_state_dict( self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool @@ -343,12 +187,6 @@ def _convert_state_dict( return out_state_dict - @classmethod - def _get_config_converters(cls) -> list[ParamConverter]: - if not hasattr(cls, "_config_converters"): - cls._config_converters = cls._create_config_converters() - return cls._config_converters - @staticmethod def _get_fast_llm_attribute(config: BaseModelConfig, name: str | tuple[str, ...]) -> typing.Any: if isinstance(name, str): @@ -374,6 +212,6 @@ def get_handler_class(cls, format: str) -> type[ExternalStateDictCheckpointHandl # TODO: load_metadata??? @classmethod - def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: + def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: # TODO: ??? return cls.handler_map[config["model_type"]]._import_config(config) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 16b3e005f..e5d14711d 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -6,21 +6,47 @@ import safetensors import torch -from transformers.configuration_utils import PretrainedConfig +from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, CheckpointSaveMetadataConfig -from fast_llm.engine.checkpoint.external import ( - ConstantExportParamConverter, - ExternalStateDictCheckpointHandler, - ParamConverter, - logger, -) -from fast_llm.engine.multi_stage.config import CheckpointMetadata +from fast_llm.engine.checkpoint.external import ExternalStateDictCheckpointHandler, WeightConverter, logger +from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.tensor import SafeTensorSlice -from fast_llm.utils import Assert +from fast_llm.utils import Assert, safe_merge_dicts + +if typing.TYPE_CHECKING: + import transformers + + +class HuggingFaceBaseModelConverter: + @classmethod + @abc.abstractmethod + def import_config(cls, config: dict) -> dict: + pass + + @classmethod + @abc.abstractmethod + def export_config(cls, config: BaseModelConfig) -> dict: + pass + + @classmethod + @abc.abstractmethod + def get_converters(cls, config: BaseModelConfig) -> list[WeightConverter]: + pass class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, abc.ABC): + architecture: typing.ClassVar[str] + base_model_converter_class: typing.ClassVar[type[HuggingFaceBaseModelConverter]] + + @classmethod + @abc.abstractmethod + def get_transformers_configuration_class(cls) -> type["transformers.PretrainedConfig"]: + pass + + @classmethod + def get_model_files(cls) -> tuple[str | None, str 
| None, str | None]: + return None, None, None @classmethod def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None: @@ -35,7 +61,7 @@ def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadat ) def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata) -> dict: - huggingface_config = self._export_config(self._model.config.base_model) + huggingface_config = self._export_config(self._model.config) self._save_config(config.path, huggingface_config) return { "fast_llm_metadata": metadata.to_dict(), @@ -49,6 +75,20 @@ def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning) super().load(config) + def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: + super().save(config, metadata) + # Copy the modeling files to the output directory + modeling_file, configuration_file, generation_utils_file = self.get_model_files() + if configuration_file is not None: + shutil.copy(configuration_file, config.path) + if modeling_file is not None: + shutil.copy(modeling_file, config.path) + if generation_utils_file is not None: + shutil.copy(generation_utils_file, config.path) + gen_config = pathlib.Path(generation_utils_file).parent / "generation_config.json" + if gen_config.exists(): + shutil.copy(gen_config, config.path) + @classmethod def get_huggingface_model_type(self) -> str: # We assume the two names match, but derived classes can make it different. @@ -59,28 +99,37 @@ def _get_key(cls, parameter_name: str, shard_name: str) -> str: Assert.eq(shard_name, "weights") return parameter_name - @classmethod - @abc.abstractmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return [ - ConstantExportParamConverter( - export_names=(("model_type",),), export_value=cls.get_huggingface_model_type() - ) - ] - + # Use custom config instead of relying on the transformers library @classmethod def _load_config(cls, directory: pathlib.Path | str) -> dict: - import transformers - - config = transformers.AutoConfig.from_pretrained(directory).to_dict() + config = cls.get_transformers_configuration_class().from_pretrained(directory).to_dict() Assert.eq(config["model_type"], cls.get_huggingface_model_type()) return config @classmethod def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - import transformers + cls.get_transformers_configuration_class().from_dict(config).save_pretrained(directory) + + @classmethod + def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: + return safe_merge_dicts( + cls.base_model_converter_class.export_config(config.base_model), + { + "model_type": cls.get_huggingface_model_type(), + "architecture": cls.architecture, + }, + ) + + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: + Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + Assert.eq(config["architecture"], cls.architecture) + return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)}) - transformers.CONFIG_MAPPING[config["model_type"]].from_dict(config).save_pretrained(directory) + def _create_weight_converters( + self, + ) -> list[WeightConverter]: + return self.base_model_converter_class.get_converters(self._model.config.base_model) def _load_weights( self, config: CheckpointLoadConfig, device @@ 
-123,39 +172,3 @@ def _load_weights( yield from torch.load(path) else: raise NotImplementedError(f"Unknown file format for {path}") - - -class CustomModelingExportMixin: - """ - Mixin class for HuggingfaceStateDictCheckpointHandler to handle custom modeling files. - """ - - modeling_file: typing.ClassVar[str] - configuration_file: typing.ClassVar[str] - configuration_cls: typing.ClassVar[type[PretrainedConfig]] - generation_utils_file: str | None = None - - # Use custom config instead of relying on the transformers library - @classmethod - def _load_config(cls, directory: pathlib.Path | str) -> dict: - config = cls.configuration_cls.from_pretrained(directory).to_dict() - Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - return config - - @classmethod - def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - cls.configuration_cls.from_dict(config).save_pretrained(directory) - - def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: - super().save(config, metadata) - self._copy_modeling_files(config) - - def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None: - # Copy the modeling files to the output directory - shutil.copy(self.modeling_file, config.path) - shutil.copy(self.configuration_file, config.path) - if self.generation_utils_file: - shutil.copy(self.generation_utils_file, config.path) - gen_config = pathlib.Path(self.generation_utils_file).parent / "generation_config.json" - if gen_config.exists(): - shutil.copy(gen_config, config.path) diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index 33e4d654f..d5202a90f 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -116,7 +116,7 @@ def setup( phase=PhaseType.validation, ) - self._loss_defs = self._multi_stage.base_model.loss_defs + self._loss_defs = self._multi_stage.base_model.config.get_loss_definitions() self._evaluation_iterator = None self._is_setup = True diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index b38056adb..e48fdb88b 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -468,9 +468,7 @@ def get_state_tensor_iterator( ) -> typing.Generator[tuple[str, str, torch.Tensor], None, None]: for shard_name in shard_names: shard_split = self._shards[shard_name].split(self._stage_weight_shard_sizes, 0) - for shard_index, (stage, shard) in enumerate( - zip(self._stages_on_device.values(), shard_split, strict=True) - ): + for shard_index, (stage, shard) in enumerate(zip(self._stages_owned.values(), shard_split, strict=True)): for name, tensor in stage._export_shard( shard.split(self._fsdp_weight_shard_sizes[shard_index]), data_type=data_type ): # noqa diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 21ecbe476..dbdd035a4 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -93,7 +93,10 @@ def __init__( self._stages: list[Stage] = self._multi_stage.stages self._tied_parameters = self._multi_stage.tied_parameters self._num_stages = len(self._stages) - self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.loss_defs} + self._loss_definitions = { + loss_definition.name: loss_definition + for loss_definition in self._multi_stage.base_model.config.get_loss_definitions() + } def setup(self, distributed: Distributed, optimizer: 
Optimizer | None = None) -> None: assert not self._is_setup @@ -148,7 +151,7 @@ def run_step( context = BatchContext( iteration=iteration, schedule=schedule, - losses={loss_def: [] for loss_def in self._loss_defs}, + losses={loss_def: [] for loss_def in self._loss_definitions}, metrics=metrics, ) context.data_iterator = self._preprocess_data(context, data_iterator, preprocessed) @@ -280,11 +283,13 @@ def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: for name, losses in context.losses.items(): if losses or self._distributed.pipeline_group: if losses: - reduced_loss = torch.stack(losses).sum() / num_inputs / self._loss_defs[name].count + reduced_loss = torch.stack(losses).sum() / num_inputs / self._loss_definitions[name].count if self._distributed.data_group: all_reduce(reduced_loss, group=self._distributed.data_group) else: - reduced_loss = torch.zeros([1], dtype=self._loss_defs[name].dtype, device=self._distributed.device) + reduced_loss = torch.zeros( + [1], dtype=self._loss_definitions[name].dtype.torch, device=self._distributed.device + ) if self._distributed.pipeline_group: all_reduce(reduced_loss, group=self._distributed.pipeline_group) else: diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index b500a1fda..32f73fc43 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -149,7 +149,7 @@ def __init__(self, config: TrainerConfig): multi_stage=self._multi_stage, distributed_config=self._config.model.distributed, ) - self._loss_defs = self._multi_stage.base_model.loss_defs + self._loss_defs = self._multi_stage.base_model.config.get_loss_definitions() if not self._is_evaluation_only: steps_per_split = { diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index bbd70ede4..9a940f4cb 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -10,9 +10,9 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs -from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -49,7 +49,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention[ConfigType: AttentionConfig](BlockLayer[ConfigType]): +class Attention[ConfigType: AttentionConfig](BlockWithBias[ConfigType]): """ A self-attention layer. 
""" diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 868d6ba77..214bb7729 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -9,8 +9,9 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig -from fast_llm.layers.block.config import BlockKwargs, MixerConfig +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.linear.config import AffineLinearConfig +from fast_llm.layers.decoder.config import MixerConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 5187ebfdc..773cce87e 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -1,4 +1,3 @@ -import abc import functools import logging import typing @@ -6,14 +5,12 @@ import torch from fast_llm.config import Config, Configurable -from fast_llm.core.distributed import set_generator from fast_llm.engine.base_model.base_model import Layer, Module from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.logging import get_model_debug_level, log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -22,6 +19,10 @@ class DebugLayer: + """ + A debugging utility for blocks. + """ + # TODO: Move elsewhere? def __init__(self, module: torch.nn.Module): self._module = module @@ -92,9 +93,9 @@ def __call__[ ) -class BlockLayerBase[ConfigType: Config](Configurable[ConfigType], Module): +class BaseBlock[ConfigType: Config](Configurable[ConfigType], Module): """ - Base class for blocks, mixers, MLPs, etc. + Base class for blocks and block-like layers (mlp, mixers, etc.). """ def __init__( @@ -118,25 +119,9 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c raise NotImplementedError() -class BlockLayer[ConfigType: Config](BlockLayerBase[ConfigType]): +class Block[ConfigType: Config](BaseBlock[ConfigType], Layer): """ - Base class for mixer and MLP modules. - """ - - @abc.abstractmethod - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict[str, typing.Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - pass - - -class Block[ConfigType: BlockConfig](BlockLayerBase[ConfigType], Layer): - """ - A transformer-like decoder base block with abstract mixer. + Base class for actual blocks, i.e., base blocks that are also `Layers`. """ def __init__( @@ -149,103 +134,5 @@ def __init__( peft: PeftConfig | None, return_input: bool = False, ): - super().__init__( - config, - distributed_config, - hidden_dim=hidden_dim, - lr_scale=lr_scale, - peft=peft, - ) - # For multi-token prediction, return a stack of shared_hidden and transformer_output. 
- self._return_input: bool = return_input - # Note, layer_lr_scale does not impact the norms - # TODO: add a separate norm_lr_scale - self.norm_1 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) - self.norm_2 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) - - # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. - self.mixer = self._config.mixer.get_layer( - self._distributed_config, - self._hidden_dim, - self._lr_scale, - peft=peft, - ) - - self.mlp = self._config.mlp.get_layer( - self._distributed_config, - self._hidden_dim, - self._lr_scale, - peft=peft, - ) - - def setup(self, distributed: Distributed) -> None: - super().setup(distributed) - self.mixer.setup(distributed) - self.mlp.setup(distributed) - - @torch.compile - def _bias_dropout_add( - self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor - ) -> torch.Tensor: - if bias is not None: - input_ = input_ + bias - return residual + torch.dropout(input_, self._config.dropout, self.training) - - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict[str, typing.Any] | None = None, - ) -> torch.Tensor: - if isinstance(input_, TensorMeta): - dims = kwargs[BlockKwargs.hidden_dims] - if self._return_input: - dims = (TensorDim("stacked_input_output", 2),) + dims - return TensorMeta.from_dims(dims, tensor_name=f"{self.module_name} output", dtype=input_.dtype) - generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator - if self._debug.enabled: - self._debug(None, "begin", kwargs[BlockKwargs.hidden_dims], kwargs) - fw_input = input_ - hidden_states = self.norm_1(input_) - if self._debug.enabled: - self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) - hidden_states, bias = self.mixer(hidden_states, kwargs) - if self._debug.enabled: - self._debug( - hidden_states if bias is None else hidden_states + bias, - "mixer output", - kwargs[BlockKwargs.hidden_dims], - kwargs, - ) - with set_generator(generator): - input_ = self._bias_dropout_add(hidden_states, bias, input_) - if self._debug.enabled: - self._debug(input_, "mixer residual", kwargs[BlockKwargs.hidden_dims], kwargs) - hidden_states = self.norm_2(input_) - if self._debug.enabled: - self._debug(hidden_states, "norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) - hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) - if self._debug.enabled: - self._debug( - hidden_states if bias is None else hidden_states + bias, - "MLP output", - kwargs[BlockKwargs.hidden_dims], - kwargs, - ) - with set_generator(generator): - hidden_states = self._bias_dropout_add(hidden_states, bias, input_) - if self._debug.enabled: - self._debug(None, "MLP residual", kwargs[BlockKwargs.hidden_dims], kwargs) - if self._return_input: - hidden_states = torch.stack((fw_input, hidden_states), dim=0) - return hidden_states - - def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - # TODO: Add marginal compute? 
(normalization, bias_dropout_add) - return sum( - ( - self.mixer.get_compute_usage(input_, kwargs, config), - self.mlp.get_compute_usage(input_, kwargs, config), - ) - ) + super().__init__(config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft) + self._return_input = return_input diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index fd42bccf9..7df2705fa 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,19 +1,19 @@ +import abc +import collections +import functools import typing +import warnings from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import BaseModelConfig, Preprocessor +from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, Preprocessor from fast_llm.engine.config_utils.parameter import combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.block.block import BlockLayer - - -# TODO: Generalize these beyond language models? (Ex. vision) + from fast_llm.layers.block.block import Block class BlockDimNames: @@ -41,13 +41,12 @@ class BlockKwargs: @config_class() -class BlockLayerConfig(BaseModelConfig): +class BaseBlockConfig(BaseModelConfig): """ - A common class for mixers and mlps, which have the same interface. + Base configuration class for blocks and block-like layers (mlp, mixers, etc.). """ _abstract = True - block: "BlockConfig" = Field(init=False) lr_scale: float | None = Field( default=None, @@ -56,34 +55,57 @@ class BlockLayerConfig(BaseModelConfig): hint=FieldHint.feature, ) + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: + return [] + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return [] + + +@config_class(registry=True) +class BlockConfig(BaseBlockConfig): + """ + Base configuration class for actual blocks, i.e., base blocks that are also `Layers`. + """ + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is BlockConfig and cls.get_subclass(default.get("type")) is None: + from fast_llm.layers.decoder.config import DecoderBlockConfig + + # Default subclass. + return DecoderBlockConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + @property - def layer_class(self) -> "type[BlockLayer]": + def layer_class(self) -> "type[Block]": raise NotImplementedError() - def get_layer( + def get_block( self, distributed_config: DistributedConfig, hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - ) -> "BlockLayer": + return_input: bool = False, + ) -> "Block": return self.layer_class( self, distributed_config, hidden_dim=hidden_dim, lr_scale=combine_lr_scales(lr_scale, self.lr_scale), peft=peft, + return_input=return_input, ) - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - # TODO: Move to actual layers? 
- return [] - @config_class(registry=True) -class MLPBaseConfig(BlockLayerConfig): - _abstract = True - +class BlockSequenceConfig(BaseModelConfig): @classmethod def _from_dict( cls, @@ -91,91 +113,105 @@ def _from_dict( strict: bool = True, flat: bool = False, ) -> typing.Self: - if cls is MLPBaseConfig and cls.get_subclass(default.get("type")) is None: - from fast_llm.layers.block.mlp.config import MLPConfig - + if cls is BlockSequenceConfig and cls.get_subclass(default.get("type")) is None: # Default subclass. - return MLPConfig._from_dict(default, strict, flat) + return FixedBlockSequenceConfig._from_dict(default, strict, flat) return super()._from_dict(default, strict=strict, flat=flat) + @abc.abstractmethod + def __len__(self) -> int: + pass -@config_class(registry=True) -class MixerConfig(BlockLayerConfig): - """ - Base config class for mixers. - """ + @abc.abstractmethod + def __getitem__(self, index: int) -> BlockConfig: + pass - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is MixerConfig and cls.get_subclass(default.get("type")) is None: - from fast_llm.layers.attention.config import AttentionConfig + @abc.abstractmethod + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: + pass - # Default subclass. - return AttentionConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return [] -@config_class() -class BlockConfig(BaseModelConfig): +@config_class(dynamic_type={BlockSequenceConfig: "fixed"}) +class FixedBlockSequenceConfig(BlockSequenceConfig): _abstract = False - mixer: MixerConfig = Field() - mlp: MLPBaseConfig = Field() - # TODO: Review names - normalization: NormalizationConfig = Field( - desc="Configuration for the block normalization layers.", + block: BlockConfig = Field( + desc="Common configuration for all the blocks.", hint=FieldHint.architecture, ) - lr_scale: float | None = Field( - default=None, - desc="Scaling factor for the layer learning rate." - " Combines multiplicatively with the scale set by the parent and child layers, if applicable.", - hint=FieldHint.feature, - ) - # TODO: Review names - dropout: float = Field( - default=0.0, - desc="Dropout applied to the residual connections.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - # TODO: Move these, not specific to a single block. - num_layers: int = Field( + num_blocks: int = Field( default=12, desc="Number of blocks in the model.", hint=FieldHint.architecture, valid=check_field(Assert.geq, 0), ) - hidden_size: int = Field( - default=1024, - desc="Size of the transformer's main hidden dimension, e.g., for its input and output layers.", + + def __len__(self) -> int: + return self.num_blocks + + def __getitem__(self, index: int) -> BlockConfig: + return self.block + + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: + # TODO: Prevent name conflicts in preprocessed kwargs. 
+ return self.block.get_preprocessors(distributed_config) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return self.block.get_loss_definitions(count=count * self.num_blocks) + + +@config_class(dynamic_type={BlockSequenceConfig: "pattern"}) +class PatternBlockSequenceConfig(BlockSequenceConfig): + _abstract = False + blocks: dict[str, BlockConfig] = Field() + pattern: list[str] = Field( + default=None, + desc="The name of each block (key in `blocks`) in the repeated pattern.", + hint=FieldHint.architecture, + ) + num_blocks: int = Field( + default=12, + desc="Number of blocks in the model.", hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), + valid=check_field(Assert.geq, 0), ) - def get_layer( - self, - distributed_config: DistributedConfig, - hidden_dim: TensorDim, - lr_scale: float | None, - peft: PeftConfig | None = None, - return_input: bool = False, - ): - from fast_llm.layers.block.block import Block + def _validate(self): + if not self.blocks: + raise ValueError("No block configuration provided") + if not self.pattern: + raise ValueError("No block pattern provided") + used_blocks = set(self.pattern) + available_blocks = set(self.blocks) + if missing := used_blocks - available_blocks: + raise ValueError(f"The following blocks are present in the pattern but undefined: {missing}") + if extra := available_blocks - used_blocks: + raise warnings.warn(f"The following blocks are defined but unused: {extra}") - return Block( - self, - distributed_config, - hidden_dim=hidden_dim, - lr_scale=combine_lr_scales(lr_scale, self.lr_scale), - peft=peft, - return_input=return_input, - ) + super()._validate() + + def __len__(self) -> int: + return self.num_blocks + + def __getitem__(self, index: int) -> BlockConfig: + return self.blocks[self.expanded_pattern[index]] + + @functools.cached_property + def expanded_pattern(self) -> list[str]: + return (self.pattern * (self.num_blocks // len(self.pattern) + 1))[: self.num_blocks] def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - # TODO: Move to actual layers? - return self.mixer.get_preprocessors(distributed_config) + self.mlp.get_preprocessors(distributed_config) + # TODO: Prevent name conflicts in preprocessed kwargs. + return sum((block.get_preprocessors(distributed_config) for block in self.blocks.values()), []) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + # TODO: Prevent name conflicts. + return sum( + ( + self.blocks[name].get_loss_definitions(count=count * count_) + for name, count_ in collections.Counter(self.expanded_pattern).items() + ), + [], + ) diff --git a/fast_llm/layers/block/mlp/__init__.py b/fast_llm/layers/block/sequence.py similarity index 100% rename from fast_llm/layers/block/mlp/__init__.py rename to fast_llm/layers/block/sequence.py diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py index 2ed97ae66..e7c6d9e92 100644 --- a/fast_llm/layers/common/linear/config.py +++ b/fast_llm/layers/common/linear/config.py @@ -19,16 +19,16 @@ class LinearBaseConfig(Config): Configuration for a linear-like layer without bias. """ + weight: ParameterConfig = Field( + desc="Configuration for the weight.", + hint=FieldHint.architecture, + ) lr_scale: float | None = Field( default=None, desc="Scaling factor for the layer learning rate." 
" Combines multiplicatively with the scale set by the parent layer and individual parameters, if applicable.", hint=FieldHint.feature, ) - weight: ParameterConfig = Field( - desc="Initialization configuration for the weight.", - hint=FieldHint.feature, - ) @config_class() diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index 3401e61be..33cbd9768 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -4,7 +4,7 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.parameter import combine_lr_scales +from fast_llm.engine.config_utils.parameter import ParameterConfig, combine_lr_scales from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert @@ -81,7 +81,10 @@ class LayerNormalizationBaseConfig(NormalizationConfig): Common configuration for layer norm and rms norm """ - # TODO: Rename to normalization_epsilon + weight: ParameterConfig = Field( + desc="Configuration for the weight.", + hint=FieldHint.architecture, + ) epsilon: float = Field( default=1e-5, desc="Regularizer for the division.", @@ -98,13 +101,6 @@ class LayerNormalizationBaseConfig(NormalizationConfig): desc="The implementation to use for the normalization layer.", hint=FieldHint.performance, ) - # TODO: Rename to normalization_init_range - initialization_range: float = Field( - default=0.0, - desc="Randomize the initialization with a uniform noise. Used to test for issues that may not be visible with the default initialization.", - hint=FieldHint.testing, - valid=check_field(Assert.geq, 0), - ) @property @abc.abstractmethod @@ -128,6 +124,10 @@ def _from_dict( @config_class(dynamic_type={NormalizationConfig: "layer_norm"}) class LayerNormalizationConfig(LayerNormalizationBaseConfig): + bias: ParameterConfig = Field( + desc="Configuration for the weight.", + hint=FieldHint.architecture, + ) _abstract = False @property diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index 0dc7b9589..d0a5ab151 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -3,7 +3,7 @@ import torch from fast_llm.config import Configurable -from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ +from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig @@ -15,7 +15,7 @@ NormalizationImplementation, RMSNormalizationConfig, ) -from fast_llm.tensor import ParameterMeta, accumulate_gradient +from fast_llm.tensor import accumulate_gradient from fast_llm.utils import Assert try: @@ -205,23 +205,17 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | else: raise NotImplementedError(implementation) - if self.config.initialization_range: - mean = 0 if self.zero_centered else 1 - weight_init_method = init_uniform_centered_(self.config.initialization_range, mean=mean) - else: - weight_init_method = init_zeros_ if self._config.zero_centered else init_ones_ - - self.weight = ParameterMeta.from_dims( + self.weight = self._config.weight.get_parameter( (hidden_dim,), - 
init_method=weight_init_method, - weight_decay=False, + default_initialization=init_zeros_ if self._config.zero_centered else init_ones_, lr_scale=self._lr_scale, + peft=None, ) - self.bias = ParameterMeta.from_dims( + self.bias = self._config.bias.get_parameter( (hidden_dim,), - init_method=init_zeros_, - weight_decay=False, + default_initialization=init_zeros_, lr_scale=self._lr_scale, + peft=None, ) self._normalized_shape = self.weight.shape @@ -277,17 +271,11 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | else: raise NotImplementedError(implementation) - if self.config.initialization_range: - mean = 0 if self.zero_centered else 1 - weight_init_method = init_uniform_centered_(self.config.initialization_range, mean=mean) - else: - weight_init_method = init_zeros_ if self._config.zero_centered else init_ones_ - - self.weight = ParameterMeta.from_dims( + self.weight = self._config.weight.get_parameter( (hidden_dim,), - init_method=weight_init_method, - weight_decay=False, - lr_scale=lr_scale, + default_initialization=init_zeros_ if self._config.zero_centered else init_ones_, + lr_scale=self._lr_scale, + peft=None, ) self._normalized_shape = self.weight.shape diff --git a/fast_llm/models/custom/__init__.py b/fast_llm/layers/decoder/__init__.py similarity index 100% rename from fast_llm/models/custom/__init__.py rename to fast_llm/layers/decoder/__init__.py diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py new file mode 100644 index 000000000..ba4c370c2 --- /dev/null +++ b/fast_llm/layers/decoder/block.py @@ -0,0 +1,152 @@ +import abc +import logging +import typing + +import torch + +from fast_llm.config import Config +from fast_llm.core.distributed import set_generator +from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.block import BaseBlock, Block +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.tensor import TensorMeta + +logger = logging.getLogger(__name__) + + +class BlockWithBias[ConfigType: Config](BaseBlock[ConfigType]): + """ + Base class for mixer and MLP modules. + """ + + @abc.abstractmethod + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + pass + + +class DecoderBlock[ConfigType: DecoderBlockConfig](Block[ConfigType]): + """ + A transformer-like decoder base block with abstract mixer. + """ + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_input: bool = False, + ): + super().__init__( + config, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + ) + # For multi-token prediction, return a stack of shared_hidden and transformer_output. 
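+        # (Set by `get_blocks` in `fast_llm/layers/language_model/config.py` when multiple
+        # prediction heads need this block's output.)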
+ self._return_input: bool = return_input + # Note, layer_lr_scale does not impact the norms + # TODO: add a separate norm_lr_scale + self.norm_1 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + self.norm_2 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + + # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. + self.mixer = self._config.mixer.get_layer( + self._distributed_config, + self._hidden_dim, + self._lr_scale, + peft=peft, + ) + + self.mlp = self._config.mlp.get_layer( + self._distributed_config, + self._hidden_dim, + self._lr_scale, + peft=peft, + ) + + def setup(self, distributed: Distributed) -> None: + super().setup(distributed) + self.mixer.setup(distributed) + self.mlp.setup(distributed) + + @torch.compile + def _bias_dropout_add( + self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor + ) -> torch.Tensor: + if bias is not None: + input_ = input_ + bias + return residual + torch.dropout(input_, self._config.dropout, self.training) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + dims = kwargs[BlockKwargs.hidden_dims] + if self._return_input: + dims = (TensorDim("stacked_input_output", 2),) + dims + return TensorMeta.from_dims(dims, tensor_name=f"{self.module_name} output", dtype=input_.dtype) + generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator + if self._debug.enabled: + self._debug(None, "begin", kwargs[BlockKwargs.hidden_dims], kwargs) + fw_input = input_ + hidden_states = self.norm_1(input_) + if self._debug.enabled: + self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) + hidden_states, bias = self.mixer(hidden_states, kwargs) + if self._debug.enabled: + self._debug( + hidden_states if bias is None else hidden_states + bias, + "mixer output", + kwargs[BlockKwargs.hidden_dims], + kwargs, + ) + with set_generator(generator): + input_ = self._bias_dropout_add(hidden_states, bias, input_) + if self._debug.enabled: + self._debug(input_, "mixer residual", kwargs[BlockKwargs.hidden_dims], kwargs) + hidden_states = self.norm_2(input_) + if self._debug.enabled: + self._debug(hidden_states, "norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) + hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) + if self._debug.enabled: + self._debug( + hidden_states if bias is None else hidden_states + bias, + "MLP output", + kwargs[BlockKwargs.hidden_dims], + kwargs, + ) + with set_generator(generator): + hidden_states = self._bias_dropout_add(hidden_states, bias, input_) + if self._debug.enabled: + self._debug(None, "MLP residual", kwargs[BlockKwargs.hidden_dims], kwargs) + if self._return_input: + hidden_states = torch.stack((fw_input, hidden_states), dim=0) + return hidden_states + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + # TODO: Add marginal compute? 
(normalization, bias_dropout_add) + return sum( + ( + self.mixer.get_compute_usage(input_, kwargs, config), + self.mlp.get_compute_usage(input_, kwargs, config), + ) + ) diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py new file mode 100644 index 000000000..2d8cc71fd --- /dev/null +++ b/fast_llm/layers/decoder/config.py @@ -0,0 +1,111 @@ +import typing + +from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.engine.base_model.config import LossDef, Preprocessor +from fast_llm.engine.config_utils.parameter import combine_lr_scales +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.config import BaseBlockConfig, BlockConfig +from fast_llm.layers.common.normalization.config import NormalizationConfig +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + from fast_llm.layers.decoder.block import BlockWithBias, DecoderBlock + + +@config_class() +class BlockWithBiasConfig(BaseBlockConfig): + """ + A common interface for various blocks and block layers. + """ + + @property + def layer_class(self) -> "type[BlockWithBias]": + raise NotImplementedError() + + def get_layer( + self, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ) -> "BlockWithBias": + return self.layer_class( + self, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), + peft=peft, + ) + + +@config_class(registry=True) +class MLPBaseConfig(BlockWithBiasConfig): + _abstract = True + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is MLPBaseConfig and cls.get_subclass(default.get("type")) is None: + from fast_llm.layers.decoder.mlp.config import MLPConfig + + # Default subclass. + return MLPConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class(registry=True) +class MixerConfig(BlockWithBiasConfig): + """ + Base config class for mixers. + """ + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is MixerConfig and cls.get_subclass(default.get("type")) is None: + from fast_llm.layers.attention.config import AttentionConfig + + # Default subclass. 
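+            # A missing or unregistered `type` falls back to attention, mirroring `MLPBaseConfig` above.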
+ return AttentionConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class(dynamic_type={BlockConfig: "decoder"}) +class DecoderBlockConfig(BlockConfig): + _abstract = False + mixer: MixerConfig = Field() + mlp: MLPBaseConfig = Field() + # TODO: Review names + normalization: NormalizationConfig = Field( + desc="Configuration for the block normalization layers.", + hint=FieldHint.architecture, + ) + # TODO: Review names + dropout: float = Field( + default=0.0, + desc="Dropout applied to the residual connections.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + + @property + def layer_class(self) -> "type[DecoderBlock]": + from fast_llm.layers.decoder.block import DecoderBlock + + return DecoderBlock + + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: + return self.mixer.get_preprocessors(distributed_config) + self.mlp.get_preprocessors(distributed_config) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count) diff --git a/fast_llm/layers/decoder/mlp/__init__.py b/fast_llm/layers/decoder/mlp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/decoder/mlp/config.py similarity index 81% rename from fast_llm/layers/block/mlp/config.py rename to fast_llm/layers/decoder/mlp/config.py index 3d8a9c2bf..100f53740 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/decoder/mlp/config.py @@ -3,14 +3,15 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.engine.base_model.config import LossDef from fast_llm.functional.config import ActivationType, MLPRecomputeLevel -from fast_llm.layers.block.config import MLPBaseConfig from fast_llm.layers.common.linear.config import AffineLinearConfig, LinearConfig +from fast_llm.layers.decoder.config import MLPBaseConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP - from fast_llm.layers.block.mlp.mlp import MLP + from fast_llm.layers.decoder.mlp.mixture_of_experts import MixtureOfExpertMLP + from fast_llm.layers.decoder.mlp.mlp import MLP class MLPLossNames: @@ -74,7 +75,7 @@ def _validate(self) -> None: @property def layer_class(self) -> "type[MLP]": - from fast_llm.layers.block.mlp.mlp import MLP + from fast_llm.layers.decoder.mlp.mlp import MLP return MLP @@ -87,10 +88,10 @@ class MoEMLPConfig(MLPConfig): hint=FieldHint.feature, ) experts: int = Field( - default=1, + default=2, desc="Number of MLP experts in a Mixture of Expert (MoE) model", hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), + valid=check_field(Assert.gt, 1), ) shared_experts: int = Field( default=0, @@ -139,7 +140,7 @@ class MoEMLPConfig(MLPConfig): @property def layer_class(self) -> "type[MixtureOfExpertMLP]": - from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP + from fast_llm.layers.decoder.mlp.mixture_of_experts import MixtureOfExpertMLP return MixtureOfExpertMLP @@ -151,3 +152,23 @@ def _validate(self) -> None: super()._validate() Assert.leq(self.shared_experts, self.experts) Assert.leq(self.shared_experts + self.experts_per_token, self.experts) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + loss_definitions = [] + if self.routing == RoutingType.topk: + 
loss_definitions.append( + LossDef( + name=MLPLossNames.load_balancing_loss, + formatted_name="load balancing loss", + count=1, + ) + ) + if self.z_loss_coefficient: + loss_definitions.append( + LossDef( + name=MLPLossNames.router_z_loss, + formatted_name="router z loss", + count=1, + ) + ) + return loss_definitions diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py similarity index 98% rename from fast_llm/layers/block/mlp/mixture_of_experts.py rename to fast_llm/layers/decoder/mlp/mixture_of_experts.py index 9478dc51c..089fa2dc7 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -13,10 +13,10 @@ from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.block.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType -from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType +from fast_llm.layers.decoder.mlp.mlp import MLPBase from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py similarity index 96% rename from fast_llm/layers/block/mlp/mlp.py rename to fast_llm/layers/decoder/mlp/mlp.py index c88f766b0..fe4879e73 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -10,13 +10,13 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias +from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.tensor import TensorMeta -class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): +class MLPBase[ConfigType: MLPConfig](BlockWithBias[ConfigType]): _config: ConfigType def __init__( diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index c15515fb5..849e09aa9 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,12 +1,12 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.base_model.config import BaseModelConfig, Preprocessor +from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, Preprocessor from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl -from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockLayerConfig +from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import 
PeftConfig from fast_llm.utils import Assert @@ -42,7 +42,7 @@ class LanguageModelKwargs(BlockKwargs): @config_class() -class LanguageModelEmbeddingsConfig(BlockLayerConfig): +class LanguageModelEmbeddingsConfig(BlockConfig): _abstract = False word_embeddings: ParameterConfig = Field( desc="Configuration for the word embedding (weight).", @@ -52,6 +52,12 @@ class LanguageModelEmbeddingsConfig(BlockLayerConfig): desc="Configuration for the word embedding (weight).", hint=FieldHint.architecture, ) + hidden_size: int = Field( + default=1024, + desc="Size of the model's main hidden dimension, e.g., for its input and output layers.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) vocab_size: int = Field( default=49152, desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.", @@ -72,7 +78,7 @@ class LanguageModelEmbeddingsConfig(BlockLayerConfig): ) full_precision_residual: bool = Field( default=False, - desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", + desc="Store the residuals for the model in full precision (`optimization_dtype`).", hint=FieldHint.stability, ) @@ -104,7 +110,7 @@ def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Prepr @config_class() -class LanguageModelHeadConfig(BlockLayerConfig): +class LanguageModelHeadConfig(BlockConfig): _abstract = False normalization: NormalizationConfig = Field( desc="Configuration for the final normalization layer.", @@ -234,7 +240,36 @@ def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Prepr return preprocessors - def get_layer( + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + loss_defs = [] + if self.logit_z_loss: + LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=count) + + if self.enable_dpo: + loss_defs.append(LossDef(name=LanguageModelLossNames.dpo_loss, formatted_name="dpo loss", count=count)) + + if self.distillation_model is not None: + loss_defs.append( + LossDef(name=LanguageModelLossNames.distillation_loss, formatted_name="distillation loss", count=count) + ) + if self.language_model_loss_factor > 0.0: + loss_defs.append( + LossDef( + name=LanguageModelLossNames.distil_lm_loss, formatted_name="distillation lm loss", count=count + ) + ) + + for i in range(self.prediction_heads): + loss_defs.append( + LossDef( + name=LanguageModelLossNames.multi_token_prediction_loss(i), + formatted_name=f"language model loss {i}", + count=count, + ) + ) + return loss_defs + + def get_block( self, distributed_config: DistributedConfig, embeddings_config: LanguageModelEmbeddingsConfig, @@ -254,12 +289,49 @@ def get_layer( prediction_distance=prediction_distance, ) + def get_blocks( + self, + distributed_config: DistributedConfig, + embeddings_config: LanguageModelEmbeddingsConfig, + mtp_block_config: BlockConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + blocks = [] + for i in range(self.prediction_heads): + if i > 0: + blocks.append( + mtp_block_config.get_block( + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + # The last block only returns the model output. + # The previous blocks return a stack of shared_hidden and transformer_output. 
+ return_input=i < self.prediction_heads - 1, + ) + ) + blocks.append( + self.get_block( + distributed_config, + embeddings_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + prediction_distance=i, + ) + ) + return blocks + +# TODO: `BlockSequenceConfig`? (interface not fully compatible) @config_class() class LanguageModelBaseConfig(BaseModelConfig): # TODO: block - transformer: BlockConfig = Field( - desc="Configuration for the transformer architecture.", + decoder: BlockSequenceConfig = Field( + desc="Configuration for the language model decoder.", hint=FieldHint.architecture, ) embeddings_layer: LanguageModelEmbeddingsConfig = Field() @@ -292,9 +364,61 @@ def from_flat_dict( cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") return super().from_flat_dict(default, strict) + def __len__(self) -> int: + return len(self.decoder) + 2 * self.output_layer.prediction_heads + + def __getitem__(self, index: int) -> BlockConfig: + if index <= 0: + Assert.eq(index, 0) + return self.embeddings_layer + elif index <= len(self.decoder): + return self.decoder[index - 1] + else: + # Start at the last decoder layer so all MTP heads are treated similarly. + index - len(self.decoder) + return self.embeddings_layer + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: return ( self.embeddings_layer.get_preprocessors(distributed_config) - + self.transformer.get_preprocessors(distributed_config) + + self.decoder.get_preprocessors(distributed_config) + self.output_layer.get_preprocessors(distributed_config) ) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return ( + self.embeddings_layer.get_loss_definitions(count) + + self.decoder.get_loss_definitions(count) + + self.output_layer.get_loss_definitions(count) + ) + + def get_blocks(self, distributed_config: DistributedConfig): + hidden_dim = TensorDim("hidden", self.embeddings_layer.hidden_size) + return [ + self.embeddings_layer.get_block( + distributed_config, + hidden_dim=hidden_dim, + lr_scale=None, + peft=self.peft, + ), + *[ + self.decoder[i].get_block( + distributed_config, + hidden_dim, + lr_scale=None, + peft=self.peft, + # The last layer only returns the transformer output. + # The previous layers return a stack of shared_hidden and transformer_output. 
+ return_input=self.output_layer.prediction_heads > 1 and i == len(self.decoder) - 1, + ) + for i in range(len(self.decoder)) + ], + *self.output_layer.get_blocks( + distributed_config, + self.embeddings_layer, + self.decoder[len(self.decoder) - 1], + hidden_dim=hidden_dim, + lr_scale=None, + peft=self.peft, + ), + ] diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index b7a780a33..e0661cfa2 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -4,12 +4,11 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import reduce_forward, split -from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.layers.block.block import BlockLayerBase +from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs from fast_llm.tensor import TensorMeta @@ -18,7 +17,7 @@ WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" -class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](BlockLayerBase[ConfigType], Layer): +class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](Block[ConfigType]): """ A language model embedding layer. Consists of word embeddings (tensor-parallel or sequence-tensor-parallel), @@ -37,13 +36,17 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_input: bool = False, ): + if return_input: + raise NotImplementedError() super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_input=return_input, ) self._residual_dtype = ( self._distributed_config.optimization_dtype diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index e71512915..ade1144d2 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -6,7 +6,6 @@ from torch.distributed import all_reduce from fast_llm.core.ops import split_op -from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim @@ -16,7 +15,7 @@ from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward -from fast_llm.layers.block.block import BlockLayerBase +from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.peft.config import PeftConfig @@ -35,7 +34,7 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelHeadConfig](BlockLayerBase[ConfigType], Layer): +class LanguageModelHead[ConfigType: LanguageModelHeadConfig](Block[ConfigType]): """ A language model head (GPT), which combines the final layer norm, logits and 
cross-entropy (if applicable). TODO: Cleanup (dynamic type? composition?) @@ -53,13 +52,17 @@ def __init__( lr_scale: float | None, peft: PeftConfig | None, prediction_distance: int, + return_input: bool = False, ): + if return_input: + raise NotImplementedError() super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_input=return_input, ) self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and embeddings_config.vocab_parallel self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -179,7 +182,7 @@ def _forward_backward( if self._sequence_parallel_logits else TensorDim(BlockDimNames.sequence_q, dims[sequence_index].global_size) ) - meta = TensorMeta.from_dims(tuple(dims), tensor_name="transformer hidden_state", dtype=ln_output.dtype) + meta = TensorMeta.from_dims(tuple(dims), tensor_name="hidden_state", dtype=ln_output.dtype) hidden_state, _ = meta.local_to_global(ln_output.detach()) kwargs["hidden_states"][len(kwargs["hidden_states"]) - 1]["tensor"] = hidden_state diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 9b89b28cd..e541341e5 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,12 +1,11 @@ -import enum import math import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, LambdaInitializer from fast_llm.engine.config_utils.parameter import ParameterConfig -from fast_llm.layers.block.config import MixerConfig from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig +from fast_llm.layers.decoder.config import MixerConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -17,38 +16,6 @@ from fast_llm.tensor import ParameterMeta -class SSMBlockType(enum.StrEnum): - """ - An enum for the available mamba types for the MLP layer. 
- """ - - mamba = "m" - mamba2_discrete = "m2d" - mamba2 = "m2" - transformer = "t" - - def get_mixer_class(self): - if self == SSMBlockType.mamba: - from fast_llm.layers.ssm.mamba import Mamba - - return Mamba - elif self == SSMBlockType.mamba2: - from fast_llm.layers.ssm.mamba2 import Mamba2 - - return Mamba2 - elif self == SSMBlockType.mamba2_discrete: - from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 - - return DiscreteMamba2 - else: - raise NotImplementedError(self) - - -class DTInitType(enum.StrEnum): - constant = "constant" - random = "random" - - @config_class() class SSMConfig(MixerConfig): # Layers diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index a7d059781..f014012b2 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -9,9 +9,9 @@ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType -from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.layers.ssm.config import DiscreteMamba2Config from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import div @@ -27,12 +27,13 @@ _mamba_available = False -class DiscreteMamba2[ConfigType: DiscreteMamba2Config](BlockLayer[ConfigType]): +class DiscreteMamba2[ConfigType: DiscreteMamba2Config](BlockWithBias[ConfigType]): """ This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py """ _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" + _config: DiscreteMamba2Config def __init__( self, @@ -104,6 +105,7 @@ def __init__( self.convolution = self._config.convolution_layer.get_layer( convolution_dim, + default_add_bias=self._config.add_linear_biases, default_activation=ActivationType.silu, lr_scale=self._lr_scale, peft=self._peft, diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 5caa1a97c..e77a4468b 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -8,9 +8,9 @@ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import ActivationType -from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.layers.ssm.config import MambaConfig, init_a, init_dtprojbias from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -31,8 +31,9 @@ """ -class Mamba[ConfigType: MambaConfig](BlockLayer[ConfigType]): +class Mamba[ConfigType: MambaConfig](BlockWithBias[ConfigType]): _mixer_name: typing.ClassVar[str] = "mamba" + _config: MambaConfig def __init__( self, @@ -72,7 +73,7 @@ def __init__( self.convolution = self._config.convolution_layer.get_layer( inner_dim, default_weight_initialization=init_normal_(0, (2 / self._config.d_inner) ** 0.5), - default_add_bias=False, + default_add_bias=self._config.add_linear_biases, default_activation=ActivationType.silu, lr_scale=self._lr_scale, peft=self._peft, @@ 
-91,6 +92,7 @@ def __init__( inner_dim, default_weight_initialization=init_normal_(0, (2 / self._config.d_inner) ** 0.5), default_bias_initialization=init_dtprojbias(), + default_add_bias=self._config.add_linear_biases, lr_scale=self._lr_scale, peft=self._peft, ) @@ -113,7 +115,7 @@ def __init__( inner_dim, hidden_dim, default_weight_initialization=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), - default_add_bias=False, + default_add_bias=self._config.add_linear_biases, lr_scale=self._lr_scale, peft=self._peft, ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b48f100db..b0657313d 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -8,9 +8,9 @@ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType -from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockDimNames, BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.layers.ssm.config import Mamba2Config, init_a, init_dtprojbias from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -25,12 +25,13 @@ logger = logging.getLogger(__name__) -class Mamba2[ConfigType: Mamba2Config](BlockLayer[ConfigType]): +class Mamba2[ConfigType: Mamba2Config](BlockWithBias[ConfigType]): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ _mixer_name: typing.ClassVar[str] = "mamba_2" + _config: Mamba2Config def __init__( self, @@ -81,6 +82,7 @@ def __init__( self.convolution = self._config.convolution_layer.get_layer( convolution_dim, + default_add_bias=self._config.add_linear_biases, default_activation=ActivationType.silu, lr_scale=self._lr_scale, peft=self._peft, @@ -108,6 +110,7 @@ def __init__( inner_dim, default_weight_initialization=init_uniform_centered_(self._config.dt_rank**-0.5), default_bias_initialization=init_dtprojbias(), + default_add_bias=self._config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=self._lr_scale, peft=self._peft, @@ -223,7 +226,7 @@ def forward( c, self.D.float(), z, - delta_bias=self.dt_proj.bias.float(), + delta_bias=None if self.dt_proj.bias is None else self.dt_proj.bias.float(), delta_softplus=True, ) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 1bc30aeab..931c7f644 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -6,7 +6,7 @@ import torch import torch._dynamo # noqa -from fast_llm.engine.base_model.base_model import LossDef +from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.logging import TensorLogs from fast_llm.engine.distributed.config import PhaseType from fast_llm.tensor import TensorMeta diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index a5860096e..322932664 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -2,8 +2,6 @@ Import these submodules to ensure classes are added to the dynamic class registry. 
""" -from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig # isort: skip +from fast_llm.layers.ssm.config import MambaConfig, Mamba2Config, DiscreteMamba2Config # isort: skip from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip -from fast_llm.models.ssm.config import HybridSSMModelConfig, HybridSSMTrainerConfig # isort: skip - from fast_llm.engine.evaluation.evaluators import EvaluatorsConfig # isort: skip diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py deleted file mode 100644 index aa304d396..000000000 --- a/fast_llm/models/custom/config.py +++ /dev/null @@ -1,62 +0,0 @@ -import typing - -from fast_llm.config import FieldUpdate, config_class -from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.multi_stage.config import FastLLMModelConfig -from fast_llm.engine.training.config import TrainerConfig -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig - -if typing.TYPE_CHECKING: - from fast_llm.models.custom.huggingface import HuggingfaceCustomModelForCausalLM - from fast_llm.models.custom.model import CustomModel - from fast_llm.models.custom.trainer import CustomTrainer - - -@config_class() -class CustomDataConfig(GPTDataConfig): - # TODO: If needed, inherit from AbstractDataConfig instead and re-implement everything. - pass - - -@config_class() -class CustomBaseModelConfig(GPTBaseModelConfig): - # TODO: Add custom other base model config parameters, if any. - pass - - -@config_class(dynamic_type={FastLLMModelConfig: "gpt_custom"}) -class CustomModelConfig(GPTModelConfig): - # TODO: Add custom model config parameters, if any (typically none). - model_name: typing.ClassVar[str] = "gpt_custom" - base_model: CustomBaseModelConfig = FieldUpdate() - - @classmethod - def get_model_class(cls) -> type["CustomModel"]: - from fast_llm.models.custom.model import CustomModel - - return CustomModel - - @classmethod - def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceCustomModelForCausalLM"]: - from fast_llm.models.custom.huggingface import HuggingfaceCustomModelForCausalLM - - return HuggingfaceCustomModelForCausalLM - - -@config_class() -class PretrainedCustomModelConfig(PretrainedGPTModelConfig): - model: CustomModelConfig = FieldUpdate() - - -@config_class(dynamic_type={RunnableConfig: "train_gpt_custom", TrainerConfig: "gpt_custom"}) -class CustomTrainerConfig(PretrainedCustomModelConfig, GPTTrainerConfig): - # TODO: Add custom trainer config parameters, if any (typically none). 
- data: CustomDataConfig = FieldUpdate() - reference_models: dict[str, PretrainedCustomModelConfig] = FieldUpdate() - - @classmethod - def get_trainer_class(cls) -> type["CustomTrainer"]: - from fast_llm.models.custom.trainer import CustomTrainer - - return CustomTrainer diff --git a/fast_llm/models/custom/data.py b/fast_llm/models/custom/data.py deleted file mode 100644 index 45ffd9edb..000000000 --- a/fast_llm/models/custom/data.py +++ /dev/null @@ -1,48 +0,0 @@ -import pathlib -import typing - -from fast_llm.data.data.gpt.data import GPTData -from fast_llm.engine.distributed.config import DistributedConfig, PhaseType -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.schedule.config import BatchConfig -from fast_llm.models.custom.config import CustomDataConfig - - -class CustomData(GPTData): - # TODO: If needed, inherit from AbstractData instead and re-implement everything. - def __init__( - self, - config: CustomDataConfig, - distributed_config: DistributedConfig, - vocab_size: int, - max_sequence_length: int, - ): - # TODO: Adjust or reimplement. - super().__init__(config, distributed_config, vocab_size, max_sequence_length) - - def setup( - self, - distributed: Distributed, - samples_per_phase: dict[PhaseType, int], - cache_directory: pathlib.Path, - ): - # TODO: Adjust or reimplement. - return super().setup(distributed, samples_per_phase, cache_directory) - - def get_iterator( - self, - batch_config: BatchConfig, - phase: PhaseType, - *, - consumed_samples: int, - num_workers: int, - prefetch_factor: int | None = None, - ) -> typing.Iterator[typing.Any]: - # TODO: Adjust or reimplement. - return super().get_iterator( - batch_config, - phase, - consumed_samples=consumed_samples, - num_workers=num_workers, - prefetch_factor=prefetch_factor, - ) diff --git a/fast_llm/models/custom/head.py b/fast_llm/models/custom/head.py deleted file mode 100644 index 786e36929..000000000 --- a/fast_llm/models/custom/head.py +++ /dev/null @@ -1,6 +0,0 @@ -from fast_llm.layers.language_model.head import LanguageModelHead - - -class CustomHead(LanguageModelHead): - # TODO: Implement custom parts - pass diff --git a/fast_llm/models/custom/huggingface.py b/fast_llm/models/custom/huggingface.py deleted file mode 100644 index 7db4e73f8..000000000 --- a/fast_llm/models/custom/huggingface.py +++ /dev/null @@ -1,18 +0,0 @@ -from fast_llm.models.custom.config import CustomModelConfig -from fast_llm.models.custom.model import CustomModel -from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelConfig, HuggingfaceGPTModelForCausalLM - - -class HuggingfaceCustomModelConfig(HuggingfaceGPTModelConfig): - model_type = "fast_llm_gpt_custom" - model_config_class = CustomModelConfig - fast_llm_config: CustomModelConfig - - -class HuggingfaceCustomModelForCausalLM(HuggingfaceGPTModelForCausalLM): - # TODO: Implement changes in huggingface interface, if any. - # Ex.: Return predictions instead of logits. 
- config_class = HuggingfaceCustomModelConfig - config: HuggingfaceCustomModelConfig - model_class = CustomModel - _fast_llm_model: CustomModel diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py deleted file mode 100644 index 3afd88ce1..000000000 --- a/fast_llm/models/custom/model.py +++ /dev/null @@ -1,59 +0,0 @@ -import typing - -import torch - -from fast_llm.data.data.gpt.data import GPTBatch -from fast_llm.engine.base_model.base_model import LossDef -from fast_llm.engine.distributed.config import DistributedConfig, PhaseType -from fast_llm.engine.schedule.config import BatchConfig -from fast_llm.models.custom.config import CustomBaseModelConfig -from fast_llm.models.custom.head import CustomHead -from fast_llm.models.gpt.model import GPTBaseModel, GPTModel -from fast_llm.tensor import TensorMeta - - -class CustomBaseModel[ConfigType: CustomBaseModelConfig](GPTBaseModel[ConfigType]): - def __init__( - self, - config: ConfigType, - distributed_config: DistributedConfig, - ): - # TODO: Implement / update. - super().__init__(config, distributed_config) - - def _get_head(self, prediction_distance): - return CustomHead( - self._config, - self._distributed_config, - self._hidden_dim, - max(self._config.transformer.num_layers + prediction_distance, 1), - f"Language model head {prediction_distance}", - prediction_distance=prediction_distance, - ) - - def preprocess_meta( - self, batch_meta: BatchConfig | torch.Tensor, phase: PhaseType - ) -> list[tuple[TensorMeta, dict]]: - # TODO: Adjust or reimplement. - return super().preprocess_meta(batch_meta, phase) - - def preprocess( - self, - batch: GPTBatch, - preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, - *, - phase: PhaseType, - iteration: int, - metrics: dict | None = None, - ) -> list[tuple[torch.Tensor, dict]]: - # TODO: Adjust or reimplement. - return super().preprocess(batch, preprocessed_meta, phase=phase, iteration=iteration, metrics=metrics) - - @property - def loss_defs(self) -> list[LossDef]: - # TODO: Adjust or reimplement. - return super().loss_defs - - -class CustomModel[ConfigType: CustomBaseModelConfig](GPTModel[ConfigType]): - base_model_class: typing.ClassVar[type[CustomBaseModel]] = CustomBaseModel diff --git a/fast_llm/models/custom/readme.md b/fast_llm/models/custom/readme.md deleted file mode 100644 index ca0059084..000000000 --- a/fast_llm/models/custom/readme.md +++ /dev/null @@ -1,38 +0,0 @@ -# Custom model template - -The "custom" model is a template for customized training of a GPT-style model, -for example to fine-tune it for a particular class. -This is typically done as follows: - -1. Create a copy of the `custom` model, and rename it appropriately, ex. `my_model`, `MyModelTrainer`, etc. -2. If necessary, adjust the base classes to inherit from more abstract classes or another model. -ex. `MyModelData(AbstractData)` to re-implement data processing from scratch. -3. Add custom configuration fields in `config.py`. -4. Adapt or re-implement the data loading scheme in `MyModelData`. -5. Adapt or re-implement the preprocessing scheme in `MyModelBaseModel`. -6. Adapt or re-implement the model head, ex. change the task and/or add a custom loss. -7. If needed, adapt the huggingface interface to return outputs for the desired task. -8. Apply other changes as needed. -9. Add the new model to the registry (`models.auto.py`) so it can be used through the cli. -10. Run training with the new model, ex. `fast-llm train my_model [...]`. 
- -## Preprocessing variables and kwargs - -To pass additional parameters to the model during preprocessing, ex. a target for the loss or a runtime parameter, -simply add them to the returned `kwargs`. -Those kwargs will be passed directly to the `forward` method of each layer and can be used as needed. - -In some cases, it may be desirable to modify the `kwargs` inside a layer, -for example to pass additional data to other layers or to the backward pass. -This possible with certain caveats: - -* There is no direct support for autograd. Detaching tensors is recommended to prevent memory losses. -* Such modifications may be incompatible with pipeline parallelism, -as the data will not be transferred to pipeline-parallel devices. - -## Disclaimer - -Model customization is a work in progress. -Some abstractions may be missing or poorly implemented, -and some methods and variables may be hard-coded or very difficult to override. -We intend to address these issues in the future, but it will most likely incur some breaking changes in the interface. diff --git a/fast_llm/models/custom/trainer.py b/fast_llm/models/custom/trainer.py deleted file mode 100644 index 587adad3e..000000000 --- a/fast_llm/models/custom/trainer.py +++ /dev/null @@ -1,15 +0,0 @@ -from fast_llm.models.custom.config import CustomTrainerConfig -from fast_llm.models.custom.data import CustomData -from fast_llm.models.gpt.trainer import GPTTrainer - - -class CustomTrainer[ConfigType: CustomTrainerConfig](GPTTrainer[ConfigType]): - # TODO: Implement changes in the training loop (or tflops computation), if any (typically none). - def _get_data(self): - # TODO: Adjust signature if needed. - return CustomData( - config=self._config.data, - distributed_config=self._config.model.distributed, - vocab_size=self._config.model.base_model.vocab_size, - max_sequence_length=self._config.batch.sequence_length, - ) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 370ae4d90..702db413b 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -4,12 +4,23 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.language_model.config import LanguageModelBaseConfig +from fast_llm.models.gpt.conversion.config import ( + AprielHybridSSMCheckpointFormat, + AutoGPTHuggingfaceCheckpointFormat, + DiffusionDreamCheckpointFormat, + DiffusionLlamaCheckpointFormat, + LlamaCheckpointFormat, + MistralCheckpointFormat, + MixtralCheckpointFormat, + MTPLlamaCheckpointFormat, + Qwen2CheckpointFormat, +) from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds from fast_llm.utils import Assert, div @@ -21,52 +32,6 @@ logger = logging.getLogger(__name__) -class GPTHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False - - @classmethod - def get_handler_class(cls) -> type[CheckpointHandler]: - from fast_llm.models.gpt.conversion import AutoGPTHuggingfaceCheckpointHandler - - return 
AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.name) - - -class AutoGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "auto" - - -class Starcoder2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "starcoder2" - - -class LlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "llama" - - -class Qwen2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "qwen2" - - -class MistralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "mistral" - - -class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "mixtral" - - -class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "mtp_llama" - - -class DiffusionDreamGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "dream" - - -class DiffusionLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "diffusion_llama" - - @config_class() class GPTBatchConfig(BatchConfig): sequence_length: int = Field( @@ -151,14 +116,14 @@ class GPTModelConfig(FastLLMModelConfig): base_model: GPTBaseModelConfig = FieldUpdate() checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + ( AutoGPTHuggingfaceCheckpointFormat, - Starcoder2GPTHuggingfaceCheckpointFormat, - LlamaGPTHuggingfaceCheckpointFormat, - Qwen2GPTHuggingfaceCheckpointFormat, - MistralGPTHuggingfaceCheckpointFormat, - MixtralGPTHuggingfaceCheckpointFormat, - MTPLlamaGPTHuggingfaceCheckpointFormat, - DiffusionDreamGPTHuggingfaceCheckpointFormat, - DiffusionLlamaGPTHuggingfaceCheckpointFormat, + LlamaCheckpointFormat, + Qwen2CheckpointFormat, + MistralCheckpointFormat, + MixtralCheckpointFormat, + MTPLlamaCheckpointFormat, + DiffusionDreamCheckpointFormat, + DiffusionLlamaCheckpointFormat, + AprielHybridSSMCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py deleted file mode 100644 index 3cb954e1d..000000000 --- a/fast_llm/models/gpt/conversion.py +++ /dev/null @@ -1,856 +0,0 @@ -import abc -import dataclasses -import logging -import typing - -import torch -from transformers.configuration_utils import PretrainedConfig - -from fast_llm.config import DEFAULT, MISSING -from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import ( - AutoStateDictCheckpointHandler, - ConstantExportParamConverter, - ConstantImportParamConverter, - IgnoreExportWeightConverter, - IgnoreImportParamConverter, - IgnoreImportWeightConverter, - MappedConfigParamConverter, - ParamConverter, - RenameParamConverter, - SplitWeightConverter, - WeightConverter, -) -from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler -from fast_llm.engine.multi_stage.config import FastLLMModelConfig -from fast_llm.functional.config import ActivationType -from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig -from fast_llm.layers.attention.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex -from fast_llm.layers.block.config import BlockConfig -from fast_llm.layers.block.mlp.config import RoutingType -from 
fast_llm.layers.common.normalization.config import LayerNormalizationConfig -from fast_llm.models.gpt.config import ( - DiffusionDreamGPTHuggingfaceCheckpointFormat, - DiffusionLlamaGPTHuggingfaceCheckpointFormat, - GPTBaseModelConfig, - GPTModelConfig, - LlamaGPTHuggingfaceCheckpointFormat, - MistralGPTHuggingfaceCheckpointFormat, - MixtralGPTHuggingfaceCheckpointFormat, - MTPLlamaGPTHuggingfaceCheckpointFormat, - Qwen2GPTHuggingfaceCheckpointFormat, - Starcoder2GPTHuggingfaceCheckpointFormat, -) -from fast_llm.models.gpt.external.diffusion_dream.configuration_dream import DreamConfig -from fast_llm.models.gpt.external.diffusion_llama.configuration_diffusion_llama import DiffusionLlamaConfig -from fast_llm.models.gpt.external.mtp_llama.configuration_mtp_llama import MTPLlamaConfig -from fast_llm.models.gpt.model import GPTModel -from fast_llm.tensor import SafeTensorSlice -from fast_llm.utils import Assert, div - -if typing.TYPE_CHECKING: - pass - -logger = logging.getLogger(__name__) - - -@dataclasses.dataclass -class HiddenSizeParamConverter(ParamConverter): - """ - Some HF models don't have a `head_dim` parameter, and instead use hidden_size // heads - """ - - def __post_init__(self): - Assert.eq(len(self.fast_llm_names), 3) - Assert.eq(len(self.export_names), 2) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - hidden_size, heads, head_size = fast_llm_values - Assert.eq(head_size * heads, hidden_size) - return hidden_size, heads - - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - hidden_size, heads = export_values - return hidden_size, heads, div(hidden_size, heads) - - -class QueryWeightConverter(WeightConverter): - # Hf uses the real format for rotary embeddings. - _config: GPTBaseModelConfig - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (query,) = weight - if self._config.transformer.mixer.rotary.complex_format: - query = convert_rotary_complex_to_real(query[:], self._config.transformer.mixer.head_size, 0) - return (query,) - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (query,) = weight - if self._config.transformer.mixer.rotary.complex_format: - query = convert_rotary_real_to_complex(query[:], self._config.transformer.mixer.head_size, 0) - return (query,) - - -class KeyValueWeightConverter(WeightConverter): - # Hf uses the real format for rotary embeddings, and keeps the key and value separate. - _config: GPTBaseModelConfig - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (key_value,) = weight - key, value = key_value[:].chunk(2) - if self._config.transformer.mixer.rotary.complex_format: - key = convert_rotary_complex_to_real(key, self._config.transformer.mixer.head_size, 0) - return key, value - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - key, value = weight - if self._config.transformer.mixer.rotary.complex_format: - key = convert_rotary_real_to_complex(key[:], self._config.transformer.mixer.head_size, 0) - key_value = torch.cat([key[:], value[:]]) - return (key_value,) - - -class MLPLayer2Converter(WeightConverter): - # Similar to SplitWeightConverter, but handles the optional MLP transpose. 
- # Still ok for non-gated (trivial split) and biases (trivial 1d transpose) - _config: GPTBaseModelConfig - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (merged_weight,) = weight - return tuple(t.contiguous() for t in merged_weight[:].t().chunk(len(self.export_name), dim=-1)) - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - merged_weight = torch.cat([weight_[:] for weight_ in weight], dim=-1) - return (merged_weight.t().contiguous(),) - - -class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): - _model: GPTModel - _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig - architecture: typing.ClassVar[str] - """ - Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral) - """ - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]), - ConstantImportParamConverter( - fast_llm_names=( - ( - "embeddings_layer", - "position_embeddings", - "enabled", - ), - ), - fast_llm_value=False, - ), - RenameParamConverter( - fast_llm_names=(("transformer", "mixer", "rotary", "theta"),), export_names=(("rope_theta",),) - ), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "mlp", "activation"),), - export_names=(("hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), - export_names=(("num_hidden_layers",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "mixer", "head_groups"),), - export_names=(("num_key_value_heads",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "mlp", "intermediate_size"),), - export_names=(("intermediate_size",),), - ), - RenameParamConverter( - fast_llm_names=( - ( - "embeddings_layer", - "vocab_size", - ), - ), - export_names=(("vocab_size",),), - ), - RenameParamConverter( - fast_llm_names=( - ( - "output_layer", - "tied_weight", - ), - ), - export_names=(("tie_word_embeddings",),), - ), - ] - - @abc.abstractmethod - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - pass - - def _create_weight_converters( - self, - ) -> list[WeightConverter]: - converters = [] - num_layers = self._model.config.base_model.transformer.num_layers - - # Embeddings - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - - converters += self._create_lm_head_converters() - - for i in range(num_layers): - converters += self._create_transformer_layer_converters(f"layers.{i+1}", f"model.layers.{i}") - - return converters - - def _create_transformer_layer_converters( - self, fast_llm_layer_name: str, hf_layer_name: str, ignore_export: bool = False - ) -> list[WeightConverter]: - transformer_config: BlockConfig = self._model.config.base_model.transformer - norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) - converters = [] - names_bias_cls = [ - # Self-attn - ( - f"{fast_llm_layer_name}.mixer.query", - f"{hf_layer_name}.self_attn.q_proj", - # TODO: Fix - transformer_config.mixer.add_linear_biases, - QueryWeightConverter, - ), - ( - 
f"{fast_llm_layer_name}.mixer.key_value", - (f"{hf_layer_name}.self_attn.k_proj", f"{hf_layer_name}.self_attn.v_proj"), - # TODO: Fix - transformer_config.mixer.add_linear_biases, - KeyValueWeightConverter, - ), - ( - f"{fast_llm_layer_name}.mixer.dense", - f"{hf_layer_name}.self_attn.o_proj", - # TODO: Fix - transformer_config.mixer.add_linear_biases, - WeightConverter, - ), - # Norm - ( - f"{fast_llm_layer_name}.norm_1", - f"{hf_layer_name}.input_layernorm", - norm_bias, - WeightConverter, - ), - ( - f"{fast_llm_layer_name}.norm_2", - f"{hf_layer_name}.post_attention_layernorm", - norm_bias, - WeightConverter, - ), - ] - for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: - converters += self._get_weight_and_bias_converters( - fast_llm_prefix, - () if ignore_export else hf_prefix, - use_bias, - cls=IgnoreExportWeightConverter if ignore_export else cls, - ) - - # MLP - if ignore_export: - converters += self._get_weight_and_bias_converters( - f"{fast_llm_layer_name}.mlp.layer_1", - (), - # TODO: Fix - transformer_config.mlp.add_linear_biases, - cls=IgnoreExportWeightConverter, - ) - converters += self._get_weight_and_bias_converters( - f"{fast_llm_layer_name}.mlp.layer_2", - (), - # TODO: Fix - transformer_config.mlp.add_linear_biases, - cls=IgnoreExportWeightConverter, - ) - converters += [IgnoreExportWeightConverter(f"{fast_llm_layer_name}.mlp.router.weight", ())] - else: - converters += self._get_mlp_converters(f"{fast_llm_layer_name}", f"{hf_layer_name}") - return converters - - def _create_lm_head_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers - prediction_heads = self._model.config.base_model.output_layer.prediction_heads - norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) - converters = [] - - # Next-token prediction head - # Final norm - converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias - ) - # Output weights - if self._model.config.base_model.output_layer.tied_weight: - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) - else: - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) - - # MTP-heads > 0 are thrown away - for i in range(1, prediction_heads): - logger.warning( - f"The model weights for the multi-token prediction head {i} are discarded during conversion." 
- ) - mtp_transformer_layer_index = num_layers - 1 + 2 * i - # MTP transformer layer - converters += self._create_transformer_layer_converters( - f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True - ) - # MTP output norm - converters += self._get_weight_and_bias_converters( - f"layers.{mtp_transformer_layer_index + 2}.final_norm", (), norm_bias, IgnoreExportWeightConverter - ) - - return converters - - def _get_weight_and_bias_converters( - self, - fast_llm_prefix: str | tuple[str, ...], - hf_prefix: str | tuple[str, ...], - use_bias: bool, - cls=WeightConverter, - ) -> list[WeightConverter]: - if isinstance(fast_llm_prefix, str): - fast_llm_prefix = (fast_llm_prefix,) - if isinstance(hf_prefix, str): - hf_prefix = (hf_prefix,) - converters = [ - cls( - tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), - tuple(f"{prefix}.weight" for prefix in hf_prefix), - self._model.config.base_model, - ) - ] - if use_bias: - converters.append( - cls( - tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), - tuple(f"{prefix}.bias" for prefix in hf_prefix), - self._model.config.base_model, - ) - ) - return converters - - -class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): - format: typing.ClassVar[type[CheckpointFormat]] = Starcoder2GPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "Starcoder2ForCausalLM" - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - HiddenSizeParamConverter( - fast_llm_names=( - ("transformer", "hidden_size"), - ("transformer", "mixer", "heads"), - ("transformer", "mixer", "head_size"), - ), - export_names=(("hidden_size",), ("num_attention_heads",)), - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "mixer", "rotary", "type"),), - fast_llm_value=DefaultRotaryConfig.dynamic_type_name, - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "mixer", "add_linear_biases"),), fast_llm_value=True - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), - fast_llm_value="layer_norm", - ), - RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) - ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "mlp", "gated"),), fast_llm_value=False), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "mlp", "add_linear_biases"),), fast_llm_value=True - ), - ConstantImportParamConverter( - fast_llm_names=(("output_layer", "normalization", "type"),), - fast_llm_value="layer_norm", - ), - ] - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: BlockConfig = self._model.config.base_model.transformer - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - f"{hf_prefix}.mlp.c_fc", - # TODO: Fix - transformer_config.mlp.add_linear_biases, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.c_proj", - # TODO: Fix - transformer_config.mlp.add_linear_biases, - MLPLayer2Converter, - ), - ] - - -class CommonLlamaHuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler, abc.ABC): - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), - fast_llm_value="rms_norm", - ), - 
RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) - ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("hidden_size",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "mixer", "heads"),), - export_names=(("num_attention_heads",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "mixer", "head_size"),), - export_names=(("head_dim",),), - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "mixer", "add_linear_biases"),), fast_llm_value=False - ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "mlp", "gated"),), fast_llm_value=True), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "mlp", "add_linear_biases"),), fast_llm_value=False - ), - LLamaRotaryParamConverter( - fast_llm_names=(("transformer", "mixer", "rotary"),), - export_names=( - ("rope_theta",), - ("rope_scaling",), - ), - ), - ConstantImportParamConverter( - fast_llm_names=(("output_layer", "normalization", "type"),), - fast_llm_value="rms_norm", - ), - ] - - -@dataclasses.dataclass -class LLamaRotaryParamConverter(ParamConverter): - def __post_init__(self): - Assert.eq(len(self.fast_llm_names), 1) - Assert.eq(len(self.export_names), 2) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - (rotary_config,) = fast_llm_values - if type(rotary_config) is DefaultRotaryConfig: - rotary_scaling = { - "rope_type": "default", - } - elif type(rotary_config) is Llama3RotaryConfig: - rotary_scaling = { - "rope_type": "llama3", - "factor": rotary_config.scale_factor, - "low_freq_factor": rotary_config.low_frequency_factor, - "high_freq_factor": rotary_config.high_frequency_factor, - "original_max_position_embeddings": rotary_config.original_context_length, - } - elif type(rotary_config) is YarnRotaryConfig: - rotary_scaling = { - "rope_type": "yarn", - "attention_factor": rotary_config.attention_factor, - "beta_fast": rotary_config.beta_fast, - "beta_slow": rotary_config.beta_slow, - "original_max_position_embeddings": rotary_config.original_context_length, - } - else: - raise ValueError(f"Unsupported rotary type: {type(rotary_config).__name__}") - - return rotary_config.theta, rotary_scaling - - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - rotary_theta, rope_scaling = export_values - rotary_type = "default" if rope_scaling in (None, MISSING) else rope_scaling.get("rope_type", "default") - rotary_config = { - "type": rotary_type, - "theta": rotary_theta, - } - if rotary_type == "default": - pass - elif rotary_type == "llama3": - rotary_config.update( - { - "scale_factor": rope_scaling.get("factor", DEFAULT), - "low_frequency_factor": rope_scaling.get("low_freq_factor", DEFAULT), - "high_frequency_factor": rope_scaling.get("high_freq_factor", DEFAULT), - "original_context_length": rope_scaling.get("original_max_position_embeddings", DEFAULT), - } - ) - elif rotary_type == "yarn": - rotary_config.update( - { - "attention_factor": rope_scaling.get("attention_factor", DEFAULT), - "beta_fast": rope_scaling.get("beta_fast", DEFAULT), - "beta_slow": rope_scaling.get("beta_slow", DEFAULT), - "original_context_length": rope_scaling.get("original_max_position_embeddings", DEFAULT), - } - ) - return (rotary_config,) # RotaryConfig.from_dict(rotary_config) - - -class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): - format: 
typing.ClassVar[type[CheckpointFormat]] = LlamaGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "LlamaForCausalLM" - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: BlockConfig = self._model.config.base_model.transformer - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - # TODO: Fix - transformer_config.mlp.add_linear_biases, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - # TODO: Fix - transformer_config.mlp.add_linear_biases, - MLPLayer2Converter, - ), - ] - - -@dataclasses.dataclass -class IgnoreImportQwen2SlidingWindowParamsConverter(ParamConverter): - def __post_init__(self): - Assert.eq(len(self.fast_llm_names), 0) - Assert.eq(len(self.export_names), 0) - self.export_names = (("use_sliding_window",), ("sliding_window",), ("max_window_layers",)) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - return (MISSING, MISSING, MISSING) - - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - # Default value for use_sliding_window in Qwen2 HF config is False - if export_values[0] != MISSING and export_values[0] == True: - logger.warning( - f"The configuration parameters `{self.export_names[0]}={export_values[0]}`," - f" `{self.export_names[1]}={export_values[1]}`, `{self.export_names[2]}={export_values[2]}`" - f" are ignored during conversion." - f" If you intend to use them in Fast-LLM, make sure to set them explicitly in the model configuration." - ) - return () - - -class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): - format: typing.ClassVar[type[CheckpointFormat]] = Qwen2GPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "Qwen2ForCausalLM" - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), - fast_llm_value="rms_norm", - ), - RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) - ), - HiddenSizeParamConverter( - fast_llm_names=( - ("transformer", "hidden_size"), - ("transformer", "mixer", "heads"), - ("transformer", "mixer", "head_size"), - ), - export_names=(("hidden_size",), ("num_attention_heads",)), - ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "mlp", "gated"),), fast_llm_value=True), - # TODO: Fix - ConstantImportParamConverter( - fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv" - ), - LLamaRotaryParamConverter( - fast_llm_names=(("transformer", "mixer", "rotary"),), - export_names=( - ("rope_theta",), - ("rope_scaling",), - ), - ), - ConstantImportParamConverter( - fast_llm_names=(("output_layer", "normalization", "type"),), - fast_llm_value="rms_norm", - ), - IgnoreImportQwen2SlidingWindowParamsConverter(), - ] - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: BlockConfig = self._model.config.base_model.transformer - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - # TODO: Fix - 
transformer_config.mlp.add_linear_biases, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - # TODO: Fix - transformer_config.mlp.add_linear_biases, - MLPLayer2Converter, - ), - ] - - -class MistralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): - format: typing.ClassVar[type[CheckpointFormat]] = MistralGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "MistralForCausalLM" - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), - ] - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - return [ - SplitWeightConverter( - f"{fast_llm_prefix}.mlp.layer_1.weight", - (f"{hf_prefix}.mlp.gate_proj.weight", f"{hf_prefix}.mlp.up_proj.weight"), - ), - MLPLayer2Converter( - f"{fast_llm_prefix}.mlp.layer_2.weight", - f"{hf_prefix}.mlp.down_proj.weight", - self._model.config.base_model, - ), - ] - - -class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): - format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "MixtralForCausalLM" - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantImportParamConverter(fast_llm_names=(("transformer", "mlp", "type"),), fast_llm_value="moe"), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "mlp", "routing"),), fast_llm_value=RoutingType.topk - ), - RenameParamConverter( - fast_llm_names=(("transformer", "mlp", "experts"),), export_names=(("num_local_experts",),) - ), - RenameParamConverter( - fast_llm_names=(("transformer", "mlp", "experts_per_token"),), - export_names=(("num_experts_per_tok",),), - ), - IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), - ] - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - num_experts = self._model.config.base_model.transformer.mlp.experts - return [ - WeightConverter(f"{fast_llm_prefix}.mlp.router.weight", f"{hf_prefix}.block_sparse_moe.gate.weight"), - SplitWeightConverter( - f"{fast_llm_prefix}.mlp.layer_1.weight", - tuple( - f"{hf_prefix}.block_sparse_moe.experts.{i}.{w}.weight" - for i in range(num_experts) - for w in ("w1", "w3") - ), - ), - MLPLayer2Converter( - f"{fast_llm_prefix}.mlp.layer_2.weight", - tuple(f"{hf_prefix}.block_sparse_moe.experts.{i}.w2.weight" for i in range(num_experts)), - self._model.config.base_model, - ), - ] - - -class MTPLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlamaHuggingfaceCheckpointHandler): - from fast_llm.models.gpt.external.mtp_llama import configuration_mtp_llama, modeling_mtp_llama - - format: typing.ClassVar[type[CheckpointFormat]] = MTPLlamaGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "MTPLlamaForCausalLM" - modeling_file = modeling_mtp_llama.__file__ - configuration_file = configuration_mtp_llama.__file__ - configuration_cls: typing.ClassVar[type[PretrainedConfig]] = MTPLlamaConfig - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantExportParamConverter( - export_names=(("auto_map",),), - export_value={ - "AutoConfig": 
"configuration_mtp_llama.MTPLlamaConfig", - "AutoModel": "modeling_mtp_llama.MTPLlamaModel", - "AutoModelForCausalLM": "modeling_mtp_llama.MTPLlamaForCausalLM", - }, - ), - RenameParamConverter( - fast_llm_names=( - ( - "output_layer", - "prediction_heads", - ), - ), - export_names=(("prediction_heads",),), - ), - ] - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: BlockConfig = self._model.config.base_model.transformer - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - # TODO: Fix - transformer_config.mlp.add_linear_biases, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - # TODO: Fix - transformer_config.mlp.add_linear_biases, - MLPLayer2Converter, - ), - ] - - # Override base method to handle the MTP heads - def _create_lm_head_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers - prediction_heads = self._model.config.base_model.output_layer.prediction_heads - norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) - converters = [] - - # Next-token prediction head - # Transformer layer is already handled in the transformer layer converters - # Final norm - converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "model.mtp_norms.0", norm_bias - ) - # Multi-token prediction head - for i in range(1, prediction_heads): - mtp_transformer_layer_index = num_layers - 1 + 2 * i - # MTP transformer layer - converters += self._create_transformer_layer_converters( - f"layers.{mtp_transformer_layer_index + 1}", - f"model.mtp_heads.{i - 1}", - ) - # MTP output norm - converters += self._get_weight_and_bias_converters( - f"layers.{mtp_transformer_layer_index + 2}.final_norm", - f"model.mtp_norms.{i}", - norm_bias, - ) - # Output weights - if self._model.config.base_model.output_layer.tied_weight: - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) - else: - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) - - return converters - - -class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen2HuggingfaceCheckpointHandler): - """ - Handler for DiffusionDream Huggingface checkpoints. - Inherits from Qwen2HuggingfaceCheckpointHandler (and CustomModelingExportMixin), - but overrides _create_config_converters to update architectures and auto_map. 
- """ - - from fast_llm.models.gpt.external.diffusion_dream import configuration_dream, generation_utils, modeling_dream - - format: typing.ClassVar[type[CheckpointFormat]] = DiffusionDreamGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "DreamModel" - modeling_file = modeling_dream.__file__ - configuration_file = configuration_dream.__file__ - generation_utils_file = generation_utils.__file__ - configuration_cls: typing.ClassVar[type[PretrainedConfig]] = DreamConfig - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantExportParamConverter( - export_names=(("auto_map",),), - export_value={ - "AutoConfig": "configuration_dream.DreamConfig", - "AutoModel": "modeling_dream.DreamModel", - }, - ), - ] - - -class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, LlamaHuggingfaceCheckpointHandler): - - from fast_llm.models.gpt.external.diffusion_llama import ( - configuration_diffusion_llama, - generation_utils, - modeling_diffusion_llama, - ) - - format: typing.ClassVar[type[CheckpointFormat]] = DiffusionLlamaGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "DiffusionLlamaModel" - modeling_file = modeling_diffusion_llama.__file__ - configuration_file = configuration_diffusion_llama.__file__ - generation_utils_file = generation_utils.__file__ - configuration_cls: typing.ClassVar[type[PretrainedConfig]] = DiffusionLlamaConfig - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantExportParamConverter( - export_names=(("auto_map",),), - export_value={ - "AutoConfig": "configuration_diffusion_llama.DiffusionLlamaConfig", - "AutoModel": "modeling_diffusion_llama.DiffusionLlamaModel", - }, - ), - # TODO: include when the mask diffusion training is implemented; - # since the imported model (llama) for CPT doesn't have it but the exported model (diffusion llama) does need to have this token. 
- # RenameParamConverter( - # fast_llm_names=(("mask_token_id",),), - # export_names=(("mask_token_id",),), - # ), - ] - - -class AutoGPTHuggingfaceCheckpointHandler( - AutoStateDictCheckpointHandler, HuggingfaceStateDictCheckpointHandler, abc.ABC -): - - handler_map = { - Starcoder2GPTHuggingfaceCheckpointFormat.name: Starcoder2HuggingfaceCheckpointHandler, - LlamaGPTHuggingfaceCheckpointFormat.name: LlamaHuggingfaceCheckpointHandler, - Qwen2GPTHuggingfaceCheckpointFormat.name: Qwen2HuggingfaceCheckpointHandler, - MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler, - MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, - MTPLlamaGPTHuggingfaceCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler, - DiffusionDreamGPTHuggingfaceCheckpointFormat.name: DiffusionDreamHuggingfaceCheckpointHandler, - DiffusionLlamaGPTHuggingfaceCheckpointFormat.name: DiffusionLlamaHuggingfaceCheckpointHandler, - } diff --git a/fast_llm/models/gpt/conversion/__init__.py b/fast_llm/models/gpt/conversion/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py new file mode 100644 index 000000000..5b32c481d --- /dev/null +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -0,0 +1,374 @@ +import math +import typing + +from transformers import PretrainedConfig + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.ssm.config import DiscreteMamba2Config, Mamba2Config +from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat +from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters +from fast_llm.models.gpt.conversion.mistral import ( + MistralBaseModelConverter, + MistralBlockConverter, + MistralDecoderConverter, + MistralHeadConverter, + MistralHuggingfaceCheckpointHandler, +) +from fast_llm.utils import Assert, safe_merge_dicts + + +class AprielDiscreteMamba2Converter: + @classmethod + def import_config(cls, config: dict, hidden_size: int) -> dict: + return { + "type": "discrete_mamba_2", + "state_size": config["ssm_cfg"]["d_state"], + "d_inner": config["ssm_cfg"].get("d_inner") or hidden_size * config["ssm_cfg"].get("expand", 1), + "add_linear_biases": config["ssm_cfg"]["bias"], + "convolution_layer": {"bias": {"enabled": config["ssm_cfg"].get("conv_bias", True)}}, + "n_qk_heads": config["ssm_cfg"]["n_qk_heads"], + "n_v_heads": config["ssm_cfg"]["n_v_heads"], + "chunk_size": config["ssm_cfg"]["chunk_size"], + } + + @classmethod + def export_config(cls, config: DiscreteMamba2Config) -> dict: + cls._check_config(config) + return { + "ssm_cfg": { + "d_state": config.state_size, + "d_inner": config.d_inner, + "bias": config.add_linear_biases, + "conv_bias": ( + config.add_linear_biases + if config.convolution_layer.bias.enabled is None + else config.convolution_layer.bias.enabled + ), + "n_qk_heads": config.n_qk_heads, + "n_v_heads": config.n_v_heads, + "chunk_size": config.chunk_size, + } + } + + @classmethod + def _check_config(cls, config: DiscreteMamba2Config) -> None: + # Opportunity to make 
derived classes less constrained. + Assert.is_(type(config), DiscreteMamba2Config) + Assert.incl(config.z_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.x_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.b_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.c_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.output_layer.bias.enabled, (None, config.add_linear_biases)) + + @classmethod + def get_converters( + cls, + config: DiscreteMamba2Config, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.in_proj", + f"{hf_prefix}.in_proj", + config.add_linear_biases, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.convolution", + f"{hf_prefix}.conv1d", + ( + config.add_linear_biases + if config.convolution_layer.bias.enabled is None + else config.convolution_layer.bias.enabled + ), + drop_on_export=drop_on_export, + ), + *( + [] + if config.add_linear_biases + else [ + get_parameter_converter( + f"{fast_llm_prefix}.z_bias", + f"{hf_prefix}.z_bias", + drop_on_export=drop_on_export, + ) + ] + ), + get_parameter_converter( + f"{fast_llm_prefix}.D", + f"{hf_prefix}.D", + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.out_proj", + f"{hf_prefix}.out_proj", + config.add_linear_biases, + drop_on_export=drop_on_export, + ), + ] + + +class AprielMamba2Converter: + @classmethod + def import_config(cls, config: dict, hidden_size: int) -> dict: + return { + "type": "mamba_2", + "state_size": config["ssm_cfg"]["d_state"], + "d_inner": config["ssm_cfg"].get("d_inner") or hidden_size * config["ssm_cfg"].get("expand", 1), + "add_linear_biases": config["ssm_cfg"]["bias"], + "convolution_layer": {"bias": {"enabled": config["ssm_cfg"].get("conv_bias", True)}}, + "d_xb": config["ssm_cfg"].get("d_xb") or hidden_size, + "dt_layer": {"bias": {"enabled": config["ssm_cfg"].get("dt_proj_bias", True)}}, + "dt_rank": ( + math.ceil(hidden_size) + if config["ssm_cfg"].get("dt_rank", "auto") == "auto" + else config["ssm_cfg"]["dt_rank"] + ), + "repeat_kv_before_conv": config["ssm_cfg"].get("repeat_kv_before_conv", True), + } + + @classmethod + def export_config(cls, config: Mamba2Config) -> dict: + cls._check_config(config) + return { + "ssm_cfg": { + "d_state": config.state_size, + "d_inner": config.d_inner, + "bias": config.add_linear_biases, + "conv_bias": ( + config.add_linear_biases + if config.convolution_layer.bias.enabled is None + else config.convolution_layer.bias.enabled + ), + "d_xb": config.d_xb, + "dt_proj_bias": ( + config.add_linear_biases if config.dt_layer.bias.enabled is None else config.dt_layer.bias.enabled + ), + "dt_rank": config.dt_rank, + "repeat_kv_before_conv": config.repeat_kv_before_conv, + } + } + + @classmethod + def _check_config(cls, config: Mamba2Config) -> None: + # Opportunity to make derived classes less constrained. 
+ Assert.is_(type(config), Mamba2Config) + Assert.incl(config.z_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.x_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.b_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.c_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.dt_input_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.output_layer.bias.enabled, (None, config.add_linear_biases)) + + @classmethod + def get_converters( + cls, + config: Mamba2Config, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + # TODO: Conv + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.in_proj", + f"{hf_prefix}.in_proj", + config.add_linear_biases, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.dt_in_proj", + f"{hf_prefix}.dt_in_proj", + config.add_linear_biases, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.dt_proj", + f"{hf_prefix}.dt_proj", + config.add_linear_biases if config.dt_layer.bias.enabled is None else config.dt_layer.bias.enabled, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.convolution", + f"{hf_prefix}.conv1d", + ( + config.add_linear_biases + if config.convolution_layer.bias.enabled is None + else config.convolution_layer.bias.enabled + ), + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.A_log", + f"{hf_prefix}.A_log", + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.D", + f"{hf_prefix}.D", + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.out_proj", + f"{hf_prefix}.out_proj", + config.add_linear_biases, + drop_on_export=drop_on_export, + ), + ] + + +class AprielDiscreteMamba2BlockConverter(MistralBlockConverter): + mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter + + +class AprielMamba2BlockConverter(MistralBlockConverter): + mixer_converter_class: typing.ClassVar[type[AprielMamba2Converter]] = AprielMamba2Converter + + +class AprielBlockConverter: + layout_names = { + AttentionConfig: "t", + Mamba2Config: "m2", + DiscreteMamba2Config: "m2d", + } + _converter_classes = { + AttentionConfig: MistralBlockConverter, + Mamba2Config: AprielMamba2BlockConverter, + DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter, + } + _config_classes = {value: key for key, value in layout_names.items()} + + @classmethod + def import_config(cls, config: dict, hidden_size: int, layout_name: str = "t") -> dict: + return cls._converter_classes[cls._config_classes[layout_name]].import_config(config, hidden_size) + + @classmethod + def export_config(cls, config) -> dict: + return cls._converter_classes[type(config.mixer)].export_config(config) + + @classmethod + def get_converters( + cls, + config: DecoderBlockConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return cls._converter_classes[type(config.mixer)].get_converters( + config, fast_llm_prefix, hf_prefix, drop_on_export=drop_on_export + ) + + +class AprielDecoderConverter(MistralDecoderConverter): + block_converter_class: typing.ClassVar[type[AprielBlockConverter]] = AprielBlockConverter + + @classmethod + def import_config(cls, config: dict, hidden_size: int) -> dict: + layout 
= config["hybrid_block_layout"] + if len(layout) == 1: + return { + "block": cls.block_converter_class.import_config(config, hidden_size, layout[0]), + "num_blocks": config["num_hidden_layers"], + } + else: + return { + "type": "pattern", + "blocks": { + layout_name: cls.block_converter_class.import_config(config, hidden_size, layout_name) + for layout_name in set(layout) + }, + "pattern": layout, + "num_blocks": config["num_hidden_layers"], + } + + @classmethod + def export_config(cls, config: BlockSequenceConfig) -> dict: + if type(config) is FixedBlockSequenceConfig: + block_configs = [config.block] + pattern_block_configs = [config.block] + elif type(config) is PatternBlockSequenceConfig: + block_configs = config.blocks.values() + pattern_block_configs = [config.blocks[block_name] for block_name in config.pattern] + else: + raise NotImplementedError() + # There may be all sorts of blocks, but `safe_merge_dicts` ensures they are compatible. + return safe_merge_dicts( + *[cls.block_converter_class.export_config(block_config) for block_config in block_configs], + { + "num_hidden_layers": config.num_blocks, + "hybrid_block_layout": [ + cls.block_converter_class.layout_names[type(block_config.mixer)] + for block_config in pattern_block_configs + ], + }, + ) + + @classmethod + def get_converters( + cls, + config: PatternBlockSequenceConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + fast_llm_layer_start: int = 1, + ) -> list[WeightConverter]: + converters = [] + for block_index in range(config.num_blocks): + block_config = config.blocks[config.pattern[block_index % len(config.pattern)]] + converters += cls.block_converter_class.get_converters( + block_config, + f"{fast_llm_prefix}.{block_index+fast_llm_layer_start}", + f"{hf_prefix}.{block_index}", + drop_on_export, + ) + return converters + + +class AprielHeadConverter(MistralHeadConverter): + block_converter_class: typing.ClassVar[type[AprielBlockConverter]] = AprielBlockConverter + + +class AprielBaseModelConverter(MistralBaseModelConverter): + decoder_converter_class: typing.ClassVar[type[AprielDecoderConverter]] = AprielDecoderConverter + head_converter_class: typing.ClassVar[type[AprielHeadConverter]] = AprielHeadConverter + + +class AprielHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = AprielHybridSSMCheckpointFormat + architecture: typing.ClassVar[str] = "AprielHybridSSMForCausalLM" + base_model_converter_class: typing.ClassVar[type[AprielBaseModelConverter]] = AprielBaseModelConverter + + @classmethod + def get_transformers_configuration_class(cls) -> type[PretrainedConfig]: + from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig + + return AprielHybridSSMConfig + + @classmethod + def get_model_files(cls) -> tuple[str, str, str | None]: + from fast_llm_external_models.apriel_hybrid_ssm import ( + configuration_apriel_hybrid_ssm, + modeling_apriel_hybrid_ssm, + ) + + return configuration_apriel_hybrid_ssm.__file__, modeling_apriel_hybrid_ssm.__file__, None + + @classmethod + def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: + return safe_merge_dicts( + super()._export_config(config), + { + "auto_map": { + "AutoConfig": "configuration_apriel_hybrid_ssm.AprielHybridSSMConfig", + "AutoModel": "modeling_apriel_hybrid_ssm.AprielHybridSSMModel", + "AutoModelForCausalLM": "modeling_apriel_hybrid_ssm.AprielHybridSSMForCausalLM", + }, + }, + ) diff --git 
a/fast_llm/models/gpt/conversion/auto.py b/fast_llm/models/gpt/conversion/auto.py new file mode 100644 index 000000000..659d1f12c --- /dev/null +++ b/fast_llm/models/gpt/conversion/auto.py @@ -0,0 +1,38 @@ +import abc + +from fast_llm.engine.checkpoint.external import AutoStateDictCheckpointHandler +from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.models.gpt.conversion.apriel import AprielHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.config import ( + AprielHybridSSMCheckpointFormat, + DiffusionDreamCheckpointFormat, + DiffusionLlamaCheckpointFormat, + LlamaCheckpointFormat, + MistralCheckpointFormat, + MixtralCheckpointFormat, + MTPLlamaCheckpointFormat, + Qwen2CheckpointFormat, +) +from fast_llm.models.gpt.conversion.diffusion_dream import DiffusionDreamHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.diffusion_llama import DiffusionLlamaHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.llama import LlamaHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.mistral import MistralHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.mixtral import MixtralHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.mtp_llama import MTPLlamaHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.qwen2 import Qwen2HuggingfaceCheckpointHandler + + +class AutoGPTHuggingfaceCheckpointHandler( + AutoStateDictCheckpointHandler, HuggingfaceStateDictCheckpointHandler, abc.ABC +): + + handler_map = { + LlamaCheckpointFormat.name: LlamaHuggingfaceCheckpointHandler, + Qwen2CheckpointFormat.name: Qwen2HuggingfaceCheckpointHandler, + MistralCheckpointFormat.name: MistralHuggingfaceCheckpointHandler, + MixtralCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, + MTPLlamaCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler, + DiffusionDreamCheckpointFormat.name: DiffusionDreamHuggingfaceCheckpointHandler, + DiffusionLlamaCheckpointFormat.name: DiffusionLlamaHuggingfaceCheckpointHandler, + AprielHybridSSMCheckpointFormat.name: AprielHuggingfaceCheckpointHandler, + } diff --git a/fast_llm/models/gpt/conversion/config.py b/fast_llm/models/gpt/conversion/config.py new file mode 100644 index 000000000..7c06906ad --- /dev/null +++ b/fast_llm/models/gpt/conversion/config.py @@ -0,0 +1,49 @@ +import typing + +from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler + + +class GPTHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.gpt.conversion.auto import AutoGPTHuggingfaceCheckpointHandler + + return AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.name) + + +class AutoGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "auto" + + +class LlamaCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "llama" + + +class Qwen2CheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "qwen2" + + +class MistralCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "mistral" + + +class MixtralCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "mixtral" + + +class MTPLlamaCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "mtp_llama" + + +class DiffusionDreamCheckpointFormat(GPTHuggingfaceCheckpointFormat): + 
name: typing.ClassVar[str] = "dream" + + +class DiffusionLlamaCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "diffusion_llama" + + +class AprielHybridSSMCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "apriel_hybrid_ssm" diff --git a/fast_llm/models/gpt/conversion/diffusion_dream.py b/fast_llm/models/gpt/conversion/diffusion_dream.py new file mode 100644 index 000000000..43742dd68 --- /dev/null +++ b/fast_llm/models/gpt/conversion/diffusion_dream.py @@ -0,0 +1,44 @@ +import typing + +from transformers import PretrainedConfig + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.models.gpt.conversion.config import DiffusionDreamCheckpointFormat +from fast_llm.models.gpt.conversion.qwen2 import Qwen2HuggingfaceCheckpointHandler +from fast_llm.utils import safe_merge_dicts + + +class DiffusionDreamHuggingfaceCheckpointHandler(Qwen2HuggingfaceCheckpointHandler): + """ + Handler for DiffusionDream Huggingface checkpoints. + Inherits from Qwen2HuggingfaceCheckpointHandler (and CustomModelingExportMixin), + but overrides _create_config_converters to update architectures and auto_map. + """ + + format: typing.ClassVar[type[CheckpointFormat]] = DiffusionDreamCheckpointFormat + architecture: typing.ClassVar[str] = "DreamModel" + + @classmethod + def get_transformers_configuration_class(cls) -> type[PretrainedConfig]: + from fast_llm_external_models.diffusion_dream.configuration_dream import DreamConfig + + return DreamConfig + + @classmethod + def get_model_files(cls) -> tuple[str, str, str | None]: + from fast_llm_external_models.diffusion_dream import configuration_dream, generation_utils, modeling_dream + + return configuration_dream.__file__, modeling_dream.__file__, generation_utils.__file__ + + @classmethod + def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: + return safe_merge_dicts( + super()._export_config(config), + { + "auto_map": { + "AutoConfig": "configuration_dream.DreamConfig", + "AutoModel": "modeling_dream.DreamModel", + }, + }, + ) diff --git a/fast_llm/models/gpt/conversion/diffusion_llama.py b/fast_llm/models/gpt/conversion/diffusion_llama.py new file mode 100644 index 000000000..3343e5f1e --- /dev/null +++ b/fast_llm/models/gpt/conversion/diffusion_llama.py @@ -0,0 +1,42 @@ +import typing + +from transformers import PretrainedConfig + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.models.gpt.conversion.config import DiffusionLlamaCheckpointFormat +from fast_llm.models.gpt.conversion.llama import LlamaHuggingfaceCheckpointHandler +from fast_llm.utils import safe_merge_dicts + + +class DiffusionLlamaHuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = DiffusionLlamaCheckpointFormat + architecture: typing.ClassVar[str] = "DiffusionLlamaModel" + + @classmethod + def get_transformers_configuration_class(cls) -> type[PretrainedConfig]: + from fast_llm_external_models.diffusion_llama.configuration_diffusion_llama import DiffusionLlamaConfig + + return DiffusionLlamaConfig + + @classmethod + def get_model_files(cls) -> tuple[str, str, str | None]: + from fast_llm_external_models.diffusion_llama import ( + configuration_diffusion_llama, + generation_utils, + modeling_diffusion_llama, + ) + + return configuration_diffusion_llama.__file__, 
modeling_diffusion_llama.__file__, generation_utils.__file__ + + @classmethod + def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: + return safe_merge_dicts( + super()._export_config(config), + { + "auto_map": { + "AutoConfig": "configuration_diffusion_llama.DiffusionLlamaConfig", + "AutoModel": "modeling_diffusion_llama.DiffusionLlamaModel", + }, + }, + ) diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py new file mode 100644 index 000000000..1162db4de --- /dev/null +++ b/fast_llm/models/gpt/conversion/llama.py @@ -0,0 +1,575 @@ +import logging +import typing + +import torch + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import ( + IgnoreExportWeightConverter, + IgnoreImportWeightConverter, + SplitWeightConverter, + WeightConverter, +) +from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.functional.config import ActivationType +from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig +from fast_llm.layers.attention.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex +from fast_llm.layers.block.config import FixedBlockSequenceConfig +from fast_llm.layers.common.normalization.config import RMSNormalizationConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.decoder.mlp.config import MLPConfig +from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelHeadConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm.models.gpt.conversion.config import LlamaCheckpointFormat +from fast_llm.models.gpt.model import GPTModel +from fast_llm.tensor import SafeTensorSlice +from fast_llm.utils import Assert, div, safe_merge_dicts + +logger = logging.getLogger(__name__) + + +def get_parameter_converter( + fast_llm_name: str | tuple[str, ...], + hf_name: str | tuple[str, ...], + cls=WeightConverter, + config=None, + drop_on_export: bool = False, + drop_on_import: bool = False, +) -> WeightConverter: + if isinstance(fast_llm_name, str): + fast_llm_name = (fast_llm_name,) + if isinstance(hf_name, str): + hf_name = (hf_name,) + if drop_on_export: + cls = IgnoreExportWeightConverter + if drop_on_import: + cls = IgnoreImportWeightConverter + return cls( + () if drop_on_import else fast_llm_name, + () if drop_on_export else hf_name, + config, + ) + + +def get_weight_and_bias_converters( + fast_llm_prefix: str | tuple[str, ...], + hf_prefix: str | tuple[str, ...], + use_bias: bool, + cls=WeightConverter, + config=None, + drop_on_export: bool = False, + drop_on_import: bool = False, +) -> list[WeightConverter]: + if isinstance(fast_llm_prefix, str): + fast_llm_prefix = (fast_llm_prefix,) + if isinstance(hf_prefix, str): + hf_prefix = (hf_prefix,) + converters = [ + get_parameter_converter( + () if drop_on_import else tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), + () if drop_on_export else tuple(f"{prefix}.weight" for prefix in hf_prefix), + cls, + config, + drop_on_export, + drop_on_import, + ) + ] + if use_bias: + converters.append( + get_parameter_converter( + () if drop_on_import else tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), + () if drop_on_export else tuple(f"{prefix}.bias" for prefix 
in hf_prefix), + cls, + config, + drop_on_export, + drop_on_import, + ) + ) + return converters + + +class LlamaNormalizationConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + return {"type": "rms_norm", "epsilon": config["rms_norm_eps"]} + + @classmethod + def export_config(cls, config: RMSNormalizationConfig) -> dict: + Assert.custom(isinstance, config, RMSNormalizationConfig) + assert not config.zero_centered + return {"rms_norm_eps": config.epsilon} + + @classmethod + def get_converters( + cls, + config: RMSNormalizationConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return get_weight_and_bias_converters( + fast_llm_prefix, + () if drop_on_export else hf_prefix, + False, + IgnoreExportWeightConverter if drop_on_export else WeightConverter, + ) + + +class LlamaMLPConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "intermediate_size": config["intermediate_size"], + "add_linear_biases": config["mlp_bias"], + "activation": ActivationType.from_hf_name(config["hidden_act"]), + "gated": True, + } + + @classmethod + def export_config(cls, config: MLPConfig) -> dict: + Assert.custom(isinstance, config, MLPConfig) + Assert.incl(config.layer_1.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) + assert config.gated + return { + "intermediate_size": config.intermediate_size, + "mlp_bias": config.add_linear_biases, + "hidden_act": config.activation.hf_name, + } + + @classmethod + def get_converters( + cls, + config: MLPConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_1", + (f"{hf_prefix}.gate_proj", f"{hf_prefix}.up_proj"), + config.add_linear_biases, + SplitWeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_2", + f"{hf_prefix}.down_proj", + config.add_linear_biases, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ] + + +class MLPLayer2Converter(WeightConverter): + # Similar to SplitWeightConverter, but handles the optional MLP transpose. + # Still ok for non-gated (trivial split) and biases (trivial 1d transpose) + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (merged_weight,) = weight + return tuple(t.contiguous() for t in merged_weight[:].t().chunk(len(self.export_name), dim=-1)) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] 
+ ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + merged_weight = torch.cat([weight_[:] for weight_ in weight], dim=-1) + return (merged_weight.t().contiguous(),) + + +class LlamaAttentionConverter: + @classmethod + def import_config(cls, config: dict, hidden_size: int) -> dict: + try: + rope_type = config["rope_scaling"]["rope_type"] + except (KeyError, TypeError): + rope_type = "default" + rotary_config = { + "type": rope_type, + "theta": config["rope_theta"], + } + if rope_type == "default": + pass + elif rope_type == "llama3": + rotary_config.update( + { + "scale_factor": config["factor"], + "low_frequency_factor": config["low_freq_factor"], + "high_frequency_factor": config["high_freq_factor"], + "original_context_length": config["original_max_position_embeddings"], + } + ) + elif rope_type == "yarn": + rotary_config.update( + { + "attention_factor": config["attention_factor"], + "beta_fast": config["beta_fast"], + "beta_slow": config["beta_slow"], + "original_context_length": config["original_max_position_embeddings"], + } + ) + else: + raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") + out = { + "rotary": rotary_config, + "heads": config["num_attention_heads"], + "head_groups": config["num_key_value_heads"], + "head_size": config.get("head_dim"), + "add_linear_biases": config["attention_bias"], + "dropout": config["attention_dropout"], + } + if out["head_size"] is None: + out["head_size"] = div(hidden_size, out["heads"]) + + return out + + @classmethod + def export_config(cls, config: AttentionConfig) -> dict: + cls._check_config(config) + Assert.eq(config.softmax_scale_power, 0.5) + out = { + "num_attention_heads": config.heads, + "num_key_value_heads": config.head_groups, + "head_dim": config.head_size, + "attention_bias": config.add_linear_biases, + "attention_dropout": config.dropout, + "rope_theta": config.rotary.theta, + } + if type(config.rotary) is DefaultRotaryConfig: + pass + elif type(config.rotary) is Llama3RotaryConfig: + out["rope_scaling"] = { + "rope_type": "llama3", + "factor": config.rotary.scale_factor, + "low_freq_factor": config.rotary.low_frequency_factor, + "high_freq_factor": config.rotary.high_frequency_factor, + "original_max_position_embeddings": config.rotary.original_context_length, + } + elif type(config.rotary) is YarnRotaryConfig: + out["rope_scaling"] = { + "rope_type": "yarn", + "attention_factor": config.rotary.attention_factor, + "beta_fast": config.rotary.beta_fast, + "beta_slow": config.rotary.beta_slow, + "original_max_position_embeddings": config.rotary.original_context_length, + } + else: + raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") + + return out + + @classmethod + def _check_config(cls, config: AttentionConfig) -> None: + # Opportunity to make derived classes less constrained. 
+ Assert.is_(type(config), AttentionConfig) + Assert.incl(config.query_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.key_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.value_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.dense_layer.bias.enabled, (None, config.add_linear_biases)) + + @classmethod + def get_converters( + cls, + config: AttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.query", + f"{hf_prefix}.q_proj", + config.add_linear_biases, + QueryWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.key_value", + (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), + config.add_linear_biases, + KeyValueWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.dense", + f"{hf_prefix}.o_proj", + config.add_linear_biases, + drop_on_export=drop_on_export, + ), + ] + + +class QueryWeightConverter(WeightConverter): + # Hf uses the real format for rotary embeddings. + _config: AttentionConfig + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (query,) = weight + if self._config.rotary.complex_format: + query = convert_rotary_complex_to_real(query[:], self._config.head_size, 0) + return (query,) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (query,) = weight + if self._config.rotary.complex_format: + query = convert_rotary_real_to_complex(query[:], self._config.head_size, 0) + return (query,) + + +class KeyValueWeightConverter(WeightConverter): + # Hf uses the real format for rotary embeddings, and keeps the key and value separate. + _config: AttentionConfig + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (key_value,) = weight + key, value = key_value[:].chunk(2) + if self._config.rotary.complex_format: + key = convert_rotary_complex_to_real(key, self._config.head_size, 0) + return key, value + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] 
+ ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + key, value = weight + if self._config.rotary.complex_format: + key = convert_rotary_real_to_complex(key[:], self._config.head_size, 0) + key_value = torch.cat([key[:], value[:]]) + return (key_value,) + + +class LlamaBlockConverter: + mixer_converter_class: typing.ClassVar[type[LlamaAttentionConverter]] = LlamaAttentionConverter + mlp_converter_class: typing.ClassVar[type[LlamaMLPConverter]] = LlamaMLPConverter + normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter + + @classmethod + def import_config(cls, config: dict, hidden_size: int) -> dict: + return { + "mixer": cls.mixer_converter_class.import_config(config, hidden_size), + "mlp": cls.mlp_converter_class.import_config(config), + "normalization": cls.normalization_converter_class.import_config(config), + } + + @classmethod + def export_config(cls, config: DecoderBlockConfig) -> dict: + Assert.custom(isinstance, config, DecoderBlockConfig) + return safe_merge_dicts( + cls.mixer_converter_class.export_config(config.mixer), + cls.mlp_converter_class.export_config(config.mlp), + cls.normalization_converter_class.export_config(config.normalization), + ) + + @classmethod + def get_converters( + cls, config: DecoderBlockConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False + ) -> list[WeightConverter]: + return [ + *cls.mixer_converter_class.get_converters( + config.mixer, + f"{fast_llm_prefix}.mixer", + f"{hf_prefix}.self_attn", + drop_on_export, + ), + *cls.mlp_converter_class.get_converters( + config.mlp, + f"{fast_llm_prefix}.mlp", + f"{hf_prefix}.mlp", + drop_on_export, + ), + *cls.normalization_converter_class.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm_1", + f"{hf_prefix}.input_layernorm", + drop_on_export, + ), + *cls.normalization_converter_class.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm_2", + f"{hf_prefix}.post_attention_layernorm", + drop_on_export, + ), + ] + + +class LlamaDecoderConverter: + block_converter_class: typing.ClassVar[type[LlamaBlockConverter]] = LlamaBlockConverter + + @classmethod + def import_config(cls, config: dict, hidden_size: int) -> dict: + return { + "block": cls.block_converter_class.import_config(config, hidden_size), + "num_blocks": config["num_hidden_layers"], + } + + @classmethod + def export_config(cls, config: FixedBlockSequenceConfig) -> dict: + # TODO: Support PatternBlockSequenceConfig with compatible configs. 
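+        # Only a uniform (fixed) block sequence can be represented in the flat Hugging Face Llama config.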
+ Assert.custom(isinstance, config, FixedBlockSequenceConfig) + return safe_merge_dicts( + cls.block_converter_class.export_config(config.block), + {"num_hidden_layers": config.num_blocks}, + ) + + @classmethod + def get_converters( + cls, + config: FixedBlockSequenceConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + fast_llm_layer_start: int = 1, + ) -> list[WeightConverter]: + converters = [] + for block_index in range(config.num_blocks): + converters += cls.block_converter_class.get_converters( + config.block, + f"{fast_llm_prefix}.{block_index+fast_llm_layer_start}", + f"{hf_prefix}.{block_index}", + drop_on_export, + ) + return converters + + +class LlamaEmbeddingsConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "vocab_size": config["vocab_size"], + "hidden_size": config["hidden_size"], + } + + @classmethod + def export_config(cls, config: LanguageModelEmbeddingsConfig) -> dict: + Assert.custom(isinstance, config, LanguageModelEmbeddingsConfig) + assert not config.position_embeddings.enabled + return { + "vocab_size": config.vocab_size, + "hidden_size": config.hidden_size, + } + + @classmethod + def get_converters( + cls, config: LanguageModelEmbeddingsConfig, fast_llm_prefix: str, hf_prefix: str + ) -> list[WeightConverter]: + return [WeightConverter(f"{fast_llm_prefix}.word_embeddings_weight", f"{hf_prefix}.embed_tokens.weight")] + + +class LlamaHeadConverter: + normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter + block_converter_class: typing.ClassVar[type[LlamaBlockConverter]] = LlamaBlockConverter + + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "tied_weight": config["tie_word_embeddings"], + "normalization": cls.normalization_converter_class.import_config(config), + } + + @classmethod + def export_config(cls, config: LanguageModelHeadConfig) -> dict: + Assert.custom(isinstance, config, LanguageModelHeadConfig) + return safe_merge_dicts( + cls.normalization_converter_class.export_config(config.normalization), + {"tie_word_embeddings": config.tied_weight}, + ) + + @classmethod + def get_converters( + cls, config: LanguageModelHeadConfig, block_config: DecoderBlockConfig, fast_llm_prefix: str, start_index: int + ) -> list[WeightConverter]: + converters = [] + for prediction_distance in range(config.prediction_heads): + if prediction_distance > 0: + converters += cls.block_converter_class.get_converters( + block_config, + f"{fast_llm_prefix}.{start_index+2*prediction_distance-1}", + "", + drop_on_export=True, + ) + converters += cls.normalization_converter_class.get_converters( + config.normalization, + f"{fast_llm_prefix}.{start_index+2*prediction_distance}.final_norm", + f"model.norm", + drop_on_export=prediction_distance > 0, + ) + converters.append( + get_parameter_converter( + f"{fast_llm_prefix}.{start_index}.output_weights", + "lm_head.weight", + drop_on_import=config.tied_weight, + ) + ) + + return converters + + +class LlamaBaseModelConverter: + # TODO: Peft? 
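+    # Fast-LLM layer layout assumed here: embeddings at `layers.0`, decoder blocks at
+    # `layers.1` ... `layers.N`, and the language-model head (plus any extra MTP blocks)
+    # starting at `layers.N + 1`.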
+ decoder_converter_class: typing.ClassVar[type[LlamaDecoderConverter]] = LlamaDecoderConverter + embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter + head_converter_class: typing.ClassVar[type[LlamaHeadConverter]] = LlamaHeadConverter + + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "embeddings_layer": cls.embeddings_converter_class.import_config(config), + "decoder": cls.decoder_converter_class.import_config(config, config["hidden_size"]), + "output_layer": cls.head_converter_class.import_config(config), + } + + @classmethod + def export_config(cls, config: GPTBaseModelConfig) -> dict: + Assert.custom(isinstance, config, GPTBaseModelConfig) + return safe_merge_dicts( + cls.embeddings_converter_class.export_config(config.embeddings_layer), + cls.decoder_converter_class.export_config(config.decoder), + cls.head_converter_class.export_config(config.output_layer), + ) + + @classmethod + def get_converters(cls, config: GPTBaseModelConfig) -> list[WeightConverter]: + return [ + *cls.embeddings_converter_class.get_converters(config.embeddings_layer, "layers.0", "model"), + *cls.decoder_converter_class.get_converters(config.decoder, "layers", "model.layers"), + *cls.head_converter_class.get_converters( + config.output_layer, config.decoder[len(config.decoder) - 1], "layers", len(config.decoder) + 1 + ), + ] + + def _create_weight_converters( + self, + ) -> list[WeightConverter]: + base_model_config = self._model.config.base_model + self.embeddings_converter_class.get_converters(base_model_config.embeddings_layer, "layers.0", "model") + converters = self.decoder_converter_class.get_converters(base_model_config.decoder, "layers", "model.layers") + self.head_converter_class.get_converters( + base_model_config.decoder, base_model_config.decoder.block, "layers", len(base_model_config.decoder) + 1 + ) + return converters + + +class LlamaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + _model: GPTModel + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = LlamaCheckpointFormat + architecture: typing.ClassVar[str] = "LlamaForCausalLM" + base_model_converter_class: typing.ClassVar[type[LlamaBaseModelConverter]] = LlamaBaseModelConverter + + @classmethod + def get_transformers_configuration_class(cls): + import transformers + + return transformers.LlamaConfig diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py new file mode 100644 index 000000000..4673f5b2c --- /dev/null +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -0,0 +1,61 @@ +import typing + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.models.gpt.conversion.config import MistralCheckpointFormat +from fast_llm.models.gpt.conversion.llama import ( + LlamaAttentionConverter, + LlamaBaseModelConverter, + LlamaBlockConverter, + LlamaDecoderConverter, + LlamaHeadConverter, + LlamaHuggingfaceCheckpointHandler, +) +from fast_llm.utils import safe_merge_dicts + + +class MistralAttentionConverter(LlamaAttentionConverter): + @classmethod + def import_config(cls, config: dict, hidden_size: int) -> dict: + return safe_merge_dicts(super().import_config(config, hidden_size), {"window_size": config["sliding_window"]}) + + @classmethod + def export_config(cls, config: AttentionConfig) -> dict: + return safe_merge_dicts( + super().export_config(config), + 
{"sliding_window": config.window_size}, + ) + + @classmethod + def _check_config(cls, config: AttentionConfig) -> None: + # Mistral doesn't support biases. + assert not config.add_linear_biases + + +class MistralBlockConverter(LlamaBlockConverter): + mixer_converter_class: typing.ClassVar[type[MistralAttentionConverter]] = MistralAttentionConverter + + +class MistralDecoderConverter(LlamaDecoderConverter): + block_converter_class: typing.ClassVar[type[MistralBlockConverter]] = MistralBlockConverter + + +class MistralHeadConverter(LlamaHeadConverter): + block_converter_class: typing.ClassVar[type[MistralBlockConverter]] = MistralBlockConverter + + +class MistralBaseModelConverter(LlamaBaseModelConverter): + decoder_converter_class: typing.ClassVar[type[MistralDecoderConverter]] = MistralDecoderConverter + head_converter_class: typing.ClassVar[type[MistralHeadConverter]] = MistralHeadConverter + + +class MistralHuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = MistralCheckpointFormat + architecture: typing.ClassVar[str] = "MistralForCausalLM" + base_model_converter_class: typing.ClassVar[type[MistralBaseModelConverter]] = MistralBaseModelConverter + + @classmethod + def get_transformers_configuration_class(cls): + import transformers + + return transformers.MistralConfig diff --git a/fast_llm/models/gpt/conversion/mixtral.py b/fast_llm/models/gpt/conversion/mixtral.py new file mode 100644 index 000000000..428c2d3a3 --- /dev/null +++ b/fast_llm/models/gpt/conversion/mixtral.py @@ -0,0 +1,88 @@ +import typing + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import SplitWeightConverter, WeightConverter +from fast_llm.layers.decoder.mlp.config import MoEMLPConfig +from fast_llm.models.gpt.conversion.config import MixtralCheckpointFormat +from fast_llm.models.gpt.conversion.llama import LlamaMLPConverter, get_weight_and_bias_converters +from fast_llm.models.gpt.conversion.mistral import ( + MistralBaseModelConverter, + MistralBlockConverter, + MistralDecoderConverter, + MistralHeadConverter, + MistralHuggingfaceCheckpointHandler, +) +from fast_llm.utils import Assert, safe_merge_dicts + + +class MixtralMLPConverter(LlamaMLPConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + return safe_merge_dicts( + super().import_config(config), + { + "type": "moe", + "experts": config["num_local_experts"], + "experts_per_token": config["num_experts_per_tok"], + }, + ) + + @classmethod + def export_config(cls, config: MoEMLPConfig) -> dict: + Assert.custom(isinstance, config, MoEMLPConfig) + assert not config.add_linear_biases + return safe_merge_dicts( + super().export_config(config), + { + "num_local_experts": config.experts, + "num_experts_per_tok": config.experts_per_token, + }, + ) + + @classmethod + def get_converters( + cls, + config: MoEMLPConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.router", + () if drop_on_export else (f"{hf_prefix}.router",), + config.add_linear_biases, + SplitWeightConverter, + drop_on_export=drop_on_export, + ), + *super().get_converters(config, fast_llm_prefix, hf_prefix, drop_on_export=drop_on_export), + ] + + +class MixtralBlockConverter(MistralBlockConverter): + mlp_converter_class: typing.ClassVar[type[MixtralMLPConverter]] = MixtralMLPConverter + + +class 
MixtralDecoderConverter(MistralDecoderConverter): + block_converter_class: typing.ClassVar[type[MixtralBlockConverter]] = MixtralBlockConverter + + +class MixtralHeadConverter(MistralHeadConverter): + block_converter_class: typing.ClassVar[type[MixtralBlockConverter]] = MixtralBlockConverter + + +class MixtralBaseModelConverter(MistralBaseModelConverter): + decoder_converter_class: typing.ClassVar[type[MixtralDecoderConverter]] = MixtralDecoderConverter + head_converter_class: typing.ClassVar[type[MixtralHeadConverter]] = MixtralHeadConverter + + +class MixtralHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = MixtralCheckpointFormat + architecture: typing.ClassVar[str] = "MixtralForCausalLM" + base_model_converter_class: typing.ClassVar[type[MixtralBaseModelConverter]] = MixtralBaseModelConverter + + @classmethod + def get_transformers_configuration_class(cls): + import transformers + + return transformers.MixtralConfig diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py new file mode 100644 index 000000000..194c263f9 --- /dev/null +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -0,0 +1,95 @@ +import typing + +from transformers import PretrainedConfig + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.language_model.config import LanguageModelHeadConfig +from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.models.gpt.conversion.config import MTPLlamaCheckpointFormat +from fast_llm.models.gpt.conversion.llama import ( + LlamaBaseModelConverter, + LlamaHeadConverter, + LlamaHuggingfaceCheckpointHandler, + get_parameter_converter, +) +from fast_llm.utils import safe_merge_dicts + + +class MTPLlamaHeadConverter(LlamaHeadConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + return safe_merge_dicts( + super().import_config(config), + {"prediction_heads": config["prediction_heads"]}, + ) + + @classmethod + def export_config(cls, config: LanguageModelHeadConfig) -> dict: + return safe_merge_dicts( + super().export_config(config), + {"prediction_heads": config.prediction_heads}, + ) + + @classmethod + def get_converters( + cls, config: LanguageModelHeadConfig, block_config: DecoderBlockConfig, fast_llm_prefix: str, start_index: int + ) -> list[WeightConverter]: + converters = [] + for prediction_distance in range(config.prediction_heads): + if prediction_distance > 0: + converters += cls.block_converter_class.get_converters( + block_config, + f"{fast_llm_prefix}.{start_index+2*prediction_distance-1}", + f"model.mtp_heads.{prediction_distance - 1}", + ) + converters += cls.normalization_converter_class.get_converters( + config.normalization, + f"{fast_llm_prefix}.{start_index+2*prediction_distance}.final_norm", + f"model.mtp_norms.{prediction_distance}", + ) + converters.append( + get_parameter_converter( + f"{fast_llm_prefix}.{start_index}.output_weights", + "lm_head.weight", + drop_on_import=config.tied_weight, + ) + ) + + return converters + + +class MTPLlamaBaseModelConverter(LlamaBaseModelConverter): + head_converter_class: typing.ClassVar[type[MTPLlamaHeadConverter]] = MTPLlamaHeadConverter + + +class MTPLlamaHuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = MTPLlamaCheckpointFormat + architecture: 
typing.ClassVar[str] = "MTPLlamaForCausalLM" + base_model_converter_class: typing.ClassVar[type[MTPLlamaBaseModelConverter]] = MTPLlamaBaseModelConverter + + @classmethod + def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: + return safe_merge_dicts( + super()._export_config(config), + { + "auto_map": { + "AutoConfig": "configuration_mtp_llama.MTPLlamaConfig", + "AutoModel": "modeling_mtp_llama.MTPLlamaModel", + "AutoModelForCausalLM": "modeling_mtp_llama.MTPLlamaForCausalLM", + }, + }, + ) + + @classmethod + def get_transformers_configuration_class(cls) -> type[PretrainedConfig]: + from fast_llm_external_models.mtp_llama.configuration_mtp_llama import MTPLlamaConfig + + return MTPLlamaConfig + + @classmethod + def get_model_files(cls) -> tuple[str, str, str | None]: + from fast_llm_external_models.mtp_llama import configuration_mtp_llama, modeling_mtp_llama + + return configuration_mtp_llama.__file__, modeling_mtp_llama.__file__, None diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py new file mode 100644 index 000000000..a8bc33454 --- /dev/null +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -0,0 +1,62 @@ +import typing + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat +from fast_llm.models.gpt.conversion.llama import ( + LlamaAttentionConverter, + LlamaBaseModelConverter, + LlamaBlockConverter, + LlamaDecoderConverter, + LlamaHeadConverter, + LlamaHuggingfaceCheckpointHandler, +) +from fast_llm.utils import Assert + + +class Qwen2AttentionConverter(LlamaAttentionConverter): + # TODO: Support sliding window with max_window_layers (need 2 kinds of block?) 
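+    # Qwen2 checkpoints put biases on the query/key/value projections but not on the output
+    # projection; `_check_config` below accepts both ways of expressing that pattern in the
+    # Fast-LLM config (a global `add_linear_biases` with a dense-layer override, or explicit
+    # per-layer overrides on query/key/value).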
+ + @classmethod + def _check_config(cls, config: AttentionConfig) -> None: + Assert.is_(type(config), AttentionConfig) + # There are multiple ways to enable biases on QKV only + if config.add_linear_biases: + Assert.incl(config.query_layer.bias.enabled, (None, True)) + Assert.incl(config.key_layer.bias.enabled, (None, True)) + Assert.incl(config.value_layer.bias.enabled, (None, True)) + Assert.is_(config.dense_layer.bias.enabled, False) + else: + Assert.is_(config.query_layer.bias.enabled, True) + Assert.is_(config.key_layer.bias.enabled, True) + Assert.is_(config.value_layer.bias.enabled, True) + Assert.incl(config.dense_layer.bias.enabled, (None, False)) + + +class Qwen2BlockConverter(LlamaBlockConverter): + mixer_converter_class: typing.ClassVar[type[Qwen2AttentionConverter]] = Qwen2AttentionConverter + + +class Qwen2DecoderConverter(LlamaDecoderConverter): + block_converter_class: typing.ClassVar[type[Qwen2BlockConverter]] = Qwen2BlockConverter + + +class Qwen2HeadConverter(LlamaHeadConverter): + block_converter_class: typing.ClassVar[type[Qwen2BlockConverter]] = Qwen2BlockConverter + + +class Qwen2BaseModelConverter(LlamaBaseModelConverter): + decoder_converter_class: typing.ClassVar[type[Qwen2DecoderConverter]] = Qwen2DecoderConverter + head_converter_class: typing.ClassVar[type[Qwen2HeadConverter]] = Qwen2HeadConverter + + +class Qwen2HuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = Qwen2CheckpointFormat + architecture: typing.ClassVar[str] = "Qwen2ForCausalLM" + base_model_converter_class: typing.ClassVar[type[Qwen2BaseModelConverter]] = Qwen2BaseModelConverter + + @classmethod + def get_transformers_configuration_class(cls): + import transformers + + return transformers.Qwen2Config diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index bbe7ae43f..f63bd76f8 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -1,8 +1,8 @@ import typing from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig -from fast_llm.layers.block.config import BlockConfig -from fast_llm.layers.block.mlp.config import MoEMLPConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.decoder.mlp.config import MoEMLPConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: @@ -14,7 +14,7 @@ def get_init_megatron( - meta: "ParameterMeta", config: BlockConfig + meta: "ParameterMeta", config: DecoderBlockConfig, hidden_size: int ) -> typing.Callable[["torch.Tensor", "Distributed"], None]: def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: Assert.eq(distributed.config.world_size, 1) @@ -22,13 +22,13 @@ def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: # Generator unused. 
return meta.param_init_method(meta, tensor, distributed.tp_init_generator) if "query" in meta.tensor_name or "key_value" in meta.tensor_name or "dense" in meta.tensor_name: - tensor_ = _init_attention_megatron(config, meta, tensor, distributed) + tensor_ = _init_attention_megatron(config, meta, tensor, distributed, hidden_size) elif "position_embeddings" in meta.tensor_name: tensor_ = _init_position_embeddings_megatron(meta, tensor, distributed) elif "mlp.router.weight" in meta.tensor_name: tensor_ = _init_moe_router_megatron(meta, tensor, distributed) elif isinstance(config.mlp, MoEMLPConfig) and config.mlp.experts > 1 and "mlp.layer_" in meta.tensor_name: - tensor_ = _init_moe_mlp_megatron(config, meta, tensor, distributed) + tensor_ = _init_moe_mlp_megatron(config, meta, tensor, distributed, hidden_size) elif "mlp.layer_2" in meta.tensor_name: tensor_ = _init_transposed_mlp_weight_megatron(meta, tensor, distributed) else: @@ -51,7 +51,11 @@ def set_megatron_distributed_seeds(config: "DistributedConfig") -> None: def _init_attention_megatron( - config: BlockConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + config: DecoderBlockConfig, + meta: "ParameterMeta", + tensor: "torch.Tensor", + distributed: "Distributed", + hidden_size: int, ) -> "torch.Tensor": # Megatron combines q and kv and inverts the initialization order of qkv and dense layers. # It also always treats the tensors as tensor-parallel and uses a different rotary embedding format. @@ -63,7 +67,7 @@ def _init_attention_megatron( meta, dense_tensor_ := tensor.new_empty( config.mixer.head_size * config.mixer.heads, - config.hidden_size, + hidden_size, ), generator, ) @@ -75,7 +79,7 @@ def _init_attention_megatron( config.mixer.head_groups, heads_per_group + 2, config.mixer.head_size, - config.hidden_size, + hidden_size, ), generator, ) @@ -141,19 +145,23 @@ def _init_moe_router_megatron( def _init_moe_mlp_megatron( - config: BlockConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + config: DecoderBlockConfig, + meta: "ParameterMeta", + tensor: "torch.Tensor", + distributed: "Distributed", + hidden_size: int, ) -> "torch.Tensor": assert meta.param_init_method is not None generator = distributed.tp_init_generator if meta.is_tensor_parallel else distributed.pp_init_generator # self.param_init_method(self, tensor, generator) state = generator.get_state() weight_1 = tensor.new_empty( - config.mlp.experts * (1 + config.mlp.gated) * config.mlp.intermediate_size, config.hidden_size + config.mlp.experts * (1 + config.mlp.gated) * config.mlp.intermediate_size, hidden_size ) - weight_2 = tensor.new_empty(config.mlp.experts * config.mlp.intermediate_size, config.hidden_size) + weight_2 = tensor.new_empty(config.mlp.experts * config.mlp.intermediate_size, hidden_size) for chunk_1, chunk_2 in zip(weight_1.chunk(config.mlp.experts), weight_2.chunk(config.mlp.experts)): meta.param_init_method(meta, chunk_1, generator) - chunk_2_ = chunk_2.new_empty(config.hidden_size, config.mlp.intermediate_size) + chunk_2_ = chunk_2.new_empty(hidden_size, config.mlp.intermediate_size) meta.param_init_method(meta, chunk_2_, generator) chunk_2.copy_(chunk_2_.t()) if "layer_1.weight" in meta.tensor_name: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 8b2947837..b7d751a61 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -4,17 +4,15 @@ import torch from fast_llm.data.data.gpt.data import GPTBatch -from fast_llm.engine.base_model.base_model 
import BaseModel, Layer, LossDef +from fast_llm.engine.base_model.base_model import BaseModel, Layer from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockDimNames -from fast_llm.layers.block.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType -from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames +from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig @@ -37,79 +35,19 @@ def __init__( config: GPTBaseModelConfig, distributed_config: DistributedConfig, ): - self._hidden_dim = TensorDim("hidden", config.transformer.hidden_size) + self._hidden_dim = TensorDim("hidden", config.embeddings_layer.hidden_size) super().__init__(config, distributed_config) if self._config.use_megatron_initialization: for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) - param.init_parameter = get_init_megatron(param, self._config.transformer) # Noqa + param.init_parameter = get_init_megatron( + param, self._config.decoder.block, config.embeddings_layer.hidden_size + ) # Noqa # `self._reference_models` is not populated at this point, so we pass a mutable dict. self._preprocessors: list[Preprocessor] = self._config.get_preprocessors(distributed_config) - def _get_output_layers(self) -> list[Layer]: - layers = [] - for i in range(self._config.output_layer.prediction_heads): - if i > 0: - layers.append( - self._get_block( - # TODO MTP: which index? - max(self._config.transformer.num_layers + i, 1), - f"MPT head {i} block", - # The last layer only returns the transformer output. - # The previous layers return a stack of shared_hidden and transformer_output. - i < self._config.output_layer.prediction_heads - 1, - ) - ) - layers.append(self._get_head(i)) - return layers - def get_layers(self) -> list[Layer]: - return [ - self._get_embeddings(), - *[ - self._get_block( - i + 1, - f"Block {i + 1}", - # The last layer only returns the transformer output. - # The previous layers return a stack of shared_hidden and transformer_output. 
- self._config.output_layer.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, - ) - for i in range(self._config.transformer.num_layers) - ], - *self._get_output_layers(), - ] - - def _get_block( - self, - block_index: int, - name: str, - return_input: bool = False, - ): - return self._config.transformer.get_layer( - self._distributed_config, - hidden_dim=self._hidden_dim, - lr_scale=None, - peft=self._config.peft, - return_input=return_input, - ) - - def _get_embeddings(self): - return self._config.embeddings_layer.get_layer( - self._distributed_config, - hidden_dim=self._hidden_dim, - lr_scale=None, - peft=self._config.peft, - ) - - def _get_head(self, prediction_distance): - return self._config.output_layer.get_layer( - self._distributed_config, - self._config.embeddings_layer, - hidden_dim=self._hidden_dim, - lr_scale=None, - peft=self._config.peft, - prediction_distance=prediction_distance, - ) + return self._config.get_blocks(self._distributed_config) def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType @@ -339,10 +277,6 @@ def preprocess( def embedding(self) -> LanguageModelEmbedding: return self.layers[0] - @property - def transformer_layers(self) -> list[Block]: - return self.layers[1:-1] - @property def model_head(self) -> LanguageModelHead: return self.layers[self.model_head_indices[0]] @@ -369,54 +303,6 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: else: return {} - @property - def loss_defs(self) -> list[LossDef]: - loss_defs = [] - if ( - isinstance(self._config.transformer.mlp, MoEMLPConfig) - and self._config.transformer.mlp.experts > 1 - and self._config.transformer.mlp.routing == RoutingType.topk - ): - loss_defs.append( - LossDef( - name=MLPLossNames.load_balancing_loss, - formatted_name="load balancing loss", - count=self._config.transformer.num_layers, - ) - ) - if self._config.transformer.mlp.z_loss_coefficient: - loss_defs.append( - LossDef( - name=MLPLossNames.router_z_loss, - formatted_name="router z loss", - count=self._config.transformer.num_layers, - ) - ) - if self._config.output_layer.logit_z_loss: - LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1) - - if self._config.output_layer.enable_dpo: - loss_defs.append(LossDef(name=LanguageModelLossNames.dpo_loss, formatted_name="dpo loss", count=1)) - - if self._config.output_layer.distillation_model is not None: - loss_defs.append( - LossDef(name=LanguageModelLossNames.distillation_loss, formatted_name="distillation loss", count=1) - ) - if self._config.output_layer.language_model_loss_factor > 0.0: - loss_defs.append( - LossDef(name=LanguageModelLossNames.distil_lm_loss, formatted_name="distillation lm loss", count=1) - ) - - for i in range(self._config.output_layer.prediction_heads): - loss_defs.append( - LossDef( - name=LanguageModelLossNames.multi_token_prediction_loss(i), - formatted_name=f"language model loss {i}", - count=1, - ) - ) - return loss_defs - class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): base_model_class: typing.ClassVar[type[GPTBaseModel]] = GPTBaseModel diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py deleted file mode 100644 index da44e547f..000000000 --- a/fast_llm/models/ssm/config.py +++ /dev/null @@ -1,189 +0,0 @@ -import logging -import typing - -from fast_llm.config import Field, FieldHint, FieldUpdate, config_class -from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.checkpoint.config 
import CheckpointHandler -from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig -from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig -from fast_llm.models.gpt.config import ( - GPTBaseModelConfig, - GPTBatchConfig, - GPTHuggingfaceCheckpointFormat, - PretrainedGPTModelConfig, -) -from fast_llm.utils import Assert - -if typing.TYPE_CHECKING: - from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM - from fast_llm.models.ssm.model import HybridSSMInferenceRunner, HybridSSMModel - from fast_llm.models.ssm.trainer import HybridSSMTrainer - -logger = logging.getLogger(__name__) - - -@config_class() -class HybridSSMBaseModelConfig(GPTBaseModelConfig): - _abstract = False - - ssm: SSMConfig = Field( - desc="Configuration for the transformer architecture.", - hint=FieldHint.architecture, - ) - hybrid_block_layout: list[SSMBlockType] | None = Field( - default=None, - desc=f"Pattern of blocks to use in the model. Available types: {SSMBlockType.__members__.values()}", - hint=FieldHint.core, - ) - default_mtp_type: SSMBlockType | None = Field( - default=None, - desc="Multi-token prediction mixer to use in the model. If None, will use the last block type in `hybrid_block_layout`.", - hint=FieldHint.optional, - ) - # TODO: Support combination of different SSM block types. - ssm_block_type: SSMBlockType | None = Field(init=False) - - def _validate(self): - if self.hybrid_block_layout is None: - with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers - - if len(self.hybrid_block_layout) != self.transformer.num_layers: - message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: - raise ValueError(message) - num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) - logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") - self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - - super()._validate() - ssm_block_types = set(self.hybrid_block_layout) - {SSMBlockType.transformer} - # TODO: Support combination of different SSM block types. 
- Assert.leq(len(ssm_block_types), 1) - self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None - - -class LLambaHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "llamba" - - @classmethod - def get_handler_class(cls) -> type[CheckpointHandler]: - from fast_llm.models.ssm.conversion import LLambaHuggingfaceCheckpointHandler - - return LLambaHuggingfaceCheckpointHandler - - -class AprielSSMHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "apriel_ssm" - - @classmethod - def get_handler_class(cls) -> type[CheckpointHandler]: - from fast_llm.models.ssm.conversion import AprielSSMHuggingfaceCheckpointHandler - - return AprielSSMHuggingfaceCheckpointHandler - - -class AprielSSMHHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "apriel_ssm_hybrid" - - @classmethod - def get_handler_class(cls) -> type[CheckpointHandler]: - from fast_llm.models.ssm.conversion import AprielSSMHHybridHuggingfaceCheckpointHandler - - return AprielSSMHHybridHuggingfaceCheckpointHandler - - -class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" - - @classmethod - def get_handler_class(cls) -> type[CheckpointHandler]: - from fast_llm.models.ssm.conversion import AprielThinkerSSMHHybridHuggingfaceCheckpointHandler - - return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler - - -@config_class(dynamic_type={FastLLMModelConfig: "hybrid_ssm"}) -class HybridSSMModelConfig(FastLLMModelConfig): - _abstract = False - model_name: typing.ClassVar[str] = "hybrid_ssm" - base_model: HybridSSMBaseModelConfig = FieldUpdate() - checkpoint_formats = FastLLMModelConfig.checkpoint_formats + ( - LLambaHuggingfaceCheckpointFormat, - AprielSSMHuggingfaceCheckpointFormat, - AprielSSMHHybridHuggingfaceCheckpointFormat, - AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, - ) - - @classmethod - def get_model_class(cls) -> type["HybridSSMModel"]: - from fast_llm.models.ssm.model import HybridSSMModel - - return HybridSSMModel - - @classmethod - def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: - from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM - - return HuggingfaceHybridSSMModelForCausalLM - - def _validate(self): - logger.warning( - "HybridSSMModelConfig is being instantiated. This model is experimental and may not work as expected." 
- ) - super()._validate() - - -@config_class() -class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): - _abstract = False - model: HybridSSMModelConfig = FieldUpdate() - - -@config_class(dynamic_type={RunnableConfig: "train_hybrid_ssm", TrainerConfig: "hybrid_ssm"}) -class HybridSSMTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate() - batch: GPTBatchConfig = FieldUpdate() - reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() - - @classmethod - def get_trainer_class(cls) -> type["HybridSSMTrainer"]: - from fast_llm.models.ssm.trainer import HybridSSMTrainer - - return HybridSSMTrainer - - def _validate(self) -> None: - super()._validate() - if (name := self.model.base_model.output_layer.distillation_model) is None: - Assert.empty(self.reference_models) - else: - Assert.eq(self.reference_models.keys(), {name}) - if self.model.base_model.embeddings_layer.position_embeddings.enabled: - Assert.geq(self.model.base_model.embeddings_layer.num_position_embeddings, self.batch.sequence_length) - # if self.model.base_model.distillation_model is not None: - # # TODO: Support loss masking for distillation? - # assert not self.batch.use_loss_masking_spans - for reference_model in self.reference_models.values(): - Assert.none(reference_model.model.base_model.output_layer.distillation_model) - # TODO: Support more LM head features. - Assert.none(reference_model.model.base_model.output_layer.cross_entropy_splits) - Assert.eq( - reference_model.model.base_model.embeddings_layer.vocab_parallel, - self.model.base_model.embeddings_layer.vocab_parallel, - ) - Assert.geq( - reference_model.model.base_model.output_layer.prediction_heads, - self.model.base_model.output_layer.prediction_heads, - ) - - @classmethod - def get_inference_runner_class(cls) -> type["HybridSSMInferenceRunner"]: - from fast_llm.models.ssm.model import HybridSSMInferenceRunner - - logger.warning( - "HybridSSMInferenceRunner only supports training-style forward pass. Use generate with cache disabled." 
- ) - - return HybridSSMInferenceRunner diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py deleted file mode 100644 index 999974ea3..000000000 --- a/fast_llm/models/ssm/conversion.py +++ /dev/null @@ -1,774 +0,0 @@ -import json -import os -import pathlib -import typing - -from transformers import PretrainedConfig - -from fast_llm.config import MISSING -from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import ( - ConstantExportParamConverter, - ConstantImportParamConverter, - IgnoreImportParamConverter, - IgnoreImportWeightConverter, - MappedConfigParamConverter, - ParamConverter, - RenameParamConverter, - SplitWeightConverter, - WeightConverter, -) -from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler -from fast_llm.engine.multi_stage.config import FastLLMModelConfig -from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.normalization.config import RMSNormalizationConfig -from fast_llm.layers.ssm.config import SSMBlockType -from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter -from fast_llm.models.ssm.config import ( - AprielSSMHHybridHuggingfaceCheckpointFormat, - AprielSSMHuggingfaceCheckpointFormat, - AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, - HybridSSMModelConfig, - LLambaHuggingfaceCheckpointFormat, -) -from fast_llm.models.ssm.external.apriel_15b_hybrid import ( - configuration_ssm_hybrid_apriel15b, - modeling_ssm_hybrid_apriel15b, -) -from fast_llm.models.ssm.external.apriel_hybrid import configuration_ssm_hybrid_apriel, modeling_ssm_hybrid_apriel -from fast_llm.models.ssm.model import HybridSSMModel -from fast_llm.utils import Assert - - -class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - _default_block_type: str = SSMBlockType.mamba2_discrete.value - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - block_converter = MappedConfigParamConverter( - fast_llm_names=(("hybrid_block_layout",),), - export_names=(("hybrid_block_layout",),), - fast_llm_value=lambda x: [cls._default_block_type] if x == MISSING else x, - export_value=lambda x: [x_.value for x_ in x], - ) - return super()._create_config_converters() + [block_converter] - - -class CommonSSMHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - RenameParamConverter( - fast_llm_names=(("ssm", "state_size"),), - export_names=( - ( - "ssm_cfg", - "d_state", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "n_v_heads"),), - export_names=( - ( - "ssm_cfg", - "n_v_heads", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "n_qk_heads"),), - export_names=( - ( - "ssm_cfg", - "n_qk_heads", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "expansion_factor"),), - export_names=( - ( - "ssm_cfg", - "expand", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "chunk_size"),), - export_names=( - ( - "ssm_cfg", - "chunk_size", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "add_bias_linear"),), - export_names=( - ( - "ssm_cfg", - "bias", 
- ), - ), - ), - MappedConfigParamConverter( - fast_llm_names=(("ssm", "activation_type"),), - export_names=( - ( - "ssm_cfg", - "activation", - ), - ), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - # ================================================ - # Mamba2 specific parameters: they dont exist in old checkpoints exported for discrete Mamba2, hence need backward compatibility - RenameParamConverter( - fast_llm_names=(("ssm", "dt_rank"),), - export_names=( - ( - "ssm_cfg", - "dt_rank", - ), - ), - ignore_missing=True, - default_value=None, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "dt_min"),), - export_names=( - ( - "ssm_cfg", - "dt_min", - ), - ), - ignore_missing=True, - default_value=0.001, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "dt_max"),), - export_names=( - ( - "ssm_cfg", - "dt_max", - ), - ), - ignore_missing=True, - default_value=0.1, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "dt_init_floor"),), - export_names=( - ( - "ssm_cfg", - "dt_init_floor", - ), - ), - ignore_missing=True, - default_value=1e-4, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "dt_scale"),), - export_names=( - ( - "ssm_cfg", - "dt_scale", - ), - ), - ignore_missing=True, - default_value=1.0, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "d_xb"),), - export_names=( - ( - "ssm_cfg", - "d_xb", - ), - ), - ignore_missing=True, - default_value=None, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "conv_kernel_dimension"),), - export_names=( - ( - "ssm_cfg", - "d_conv", - ), - ), - ignore_missing=True, - default_value=4, - ), - ] - - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() or [] - - num_layers = self._model.config.base_model.transformer.num_layers - ssm_bias: bool = self._model.config.base_model.transformer.add_linear_biases - - for i in range(num_layers): - # SSM - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.in_proj", f"model.layers.{i}.mixer.in_proj", ssm_bias - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.out_proj", f"model.layers.{i}.mixer.out_proj", ssm_bias - ) - converters.append( - WeightConverter(f"layers.{i+1}.mixer.D", f"model.layers.{i}.mixer.D", self._model.config.base_model) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.conv1d_weight", - f"model.layers.{i}.mixer.conv1d.weight", - self._model.config.base_model, - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.conv1d_bias", - f"model.layers.{i}.mixer.conv1d.bias", - self._model.config.base_model, - ) - ) - # ================================================ - # Mamba2 specific parameters - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.dt_in_proj", f"model.layers.{i}.mixer.dt_in_proj", ssm_bias - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.dt_proj", f"model.layers.{i}.mixer.dt_proj", False - ) - # bias is treated separately in Mamba2 and must always exist (https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py) - converters.append( - WeightConverter( - 
f"layers.{i+1}.mixer.dt_proj_bias", - f"model.layers.{i}.mixer.dt_proj.bias", - self._model.config.base_model, - ) - ) - - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.A_log", f"model.layers.{i}.mixer.A_log", self._model.config.base_model - ) - ) - - return converters - - def _get_weight_and_bias_converters( - self, - fast_llm_prefix: str | tuple[str, ...], - hf_prefix: str | tuple[str, ...], - use_bias: bool, - cls=WeightConverter, - ) -> list[WeightConverter]: - if isinstance(fast_llm_prefix, str): - fast_llm_prefix = (fast_llm_prefix,) - if isinstance(hf_prefix, str): - hf_prefix = (hf_prefix,) - converters = [ - cls( - tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), - tuple(f"{prefix}.weight" for prefix in hf_prefix), - self._model.config.base_model, - ) - ] - if use_bias: - converters.append( - cls( - tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), - tuple(f"{prefix}.bias" for prefix in hf_prefix), - self._model.config.base_model, - ) - ) - return converters - - -class LLambaHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat - _hf_prefix: str = "backbone" - architecture: typing.ClassVar[str] = "LlambaForCausalLM" - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - """ - Create config converters for the model, see args under https://huggingface.co/cartesia-ai/Llamba-8B/blob/main/config.json - """ - return super()._create_config_converters() + [ - RenameParamConverter( - fast_llm_names=(("vocab_size",),), - export_names=(("vocab_size",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) - ), - ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), - RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), - export_names=(("n_layer",),), - ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), - fast_llm_value=RMSNormalizationConfig.dynamic_type_name, - ), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=( - ( - "mlp_cfg", - "act_fn", - ), - ), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - RenameParamConverter( - fast_llm_names=(("transformer", "add_linear_biases"),), - export_names=( - ( - "mlp_cfg", - "bias", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), - export_names=( - ( - "mlp_cfg", - "intermediate_size", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("d_model",),), - ), - RenameParamConverter( - fast_llm_names=(("tie_word_embeddings",),), - export_names=(("tie_embeddings",),), - ), - ] - - def _create_weight_converters(self) -> list[WeightConverter]: - # not using super() because LLamba model is called backbone in the checkpoints - converters = [] - num_layers = self._model.config.base_model.transformer.num_layers - norm_bias: bool = False - ssm_bias: bool = self._model.config.base_model.transformer.add_linear_biases - - # Embedding and output - if self._model.config.base_model.tie_word_embeddings: - 
converters.append( - WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") - ) - converters.append(IgnoreImportWeightConverter((), f"{self._hf_prefix}.lm_head.weight")) - else: - converters.append( - WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") - ) - converters.append( - WeightConverter(f"layers.{num_layers + 1}.output_weights", f"{self._hf_prefix}.lm_head.weight") - ) - - # Final norm - converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", f"{self._hf_prefix}.final_layernorm", norm_bias - ) - - for i in range(num_layers): - # SSM - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.in_proj", f"{self._hf_prefix}.layers.{i}.mixer.in_proj", ssm_bias - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.out_proj", f"{self._hf_prefix}.layers.{i}.mixer.out_proj", ssm_bias - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.D", f"{self._hf_prefix}.layers.{i}.mixer.D", self._model.config.base_model - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.z_bias", - f"{self._hf_prefix}.layers.{i}.mixer.z_bias", - self._model.config.base_model, - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.conv1d_weight", - f"{self._hf_prefix}.layers.{i}.mixer.conv1d.weight", - self._model.config.base_model, - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.conv1d_bias", - f"{self._hf_prefix}.layers.{i}.mixer.conv1d.bias", - self._model.config.base_model, - ) - ) - - # Norm - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_1", f"{self._hf_prefix}.layers.{i}.input_layernorm", norm_bias - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_2", f"{self._hf_prefix}.layers.{i}.post_attention_layernorm", norm_bias - ) - - # MLP - converters += self._get_mlp_converters(f"layers.{i+1}", f"{self._hf_prefix}.layers.{i}") - - return converters - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - linear_bias, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - linear_bias, - MLPLayer2Converter, - ), - ] - - def _get_weight_and_bias_converters( - self, - fast_llm_prefix: str | tuple[str, ...], - hf_prefix: str | tuple[str, ...], - use_bias: bool, - cls=WeightConverter, - ) -> list[WeightConverter]: - if isinstance(fast_llm_prefix, str): - fast_llm_prefix = (fast_llm_prefix,) - if isinstance(hf_prefix, str): - hf_prefix = (hf_prefix,) - converters = [ - cls( - tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), - tuple(f"{prefix}.weight" for prefix in hf_prefix), - self._model.config.base_model, - ) - ] - if use_bias: - converters.append( - cls( - tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), - tuple(f"{prefix}.bias" for prefix in hf_prefix), - self._model.config.base_model, - ) - ) - return converters - - @classmethod - def _load_config(cls, directory: pathlib.Path | str) -> dict: - if not os.path.exists(directory / "config.json"): - raise FileNotFoundError(f"config.json not found in {directory}") - with open(directory / "config.json") as f: - config = 
json.load(f) - Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - return config - - @classmethod - def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - with open(directory / "config.json", "w") as f: - json.dump(config, f) - - -class AprielSSMHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): - """ - Lamba-like configs, pure SSM models. - """ - - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "AprielSSMForCausalLM" - modeling_file = modeling_ssm_hybrid_apriel15b.__file__ - configuration_file = configuration_ssm_hybrid_apriel15b.__file__ - configuration_cls: typing.ClassVar[type["PretrainedConfig"]] = ( - configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig - ) - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - RenameParamConverter( - fast_llm_names=(("vocab_size",),), - export_names=(("vocab_size",),), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "d_inner"),), - export_names=(("ssm_cfg", "d_inner"),), - ), - ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), - ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=(("hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), - export_names=(("num_hidden_layers",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("hidden_size",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), - export_names=(("intermediate_size",),), - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), - fast_llm_value=RMSNormalizationConfig.dynamic_type_name, - ), - RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) - ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), - ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), - RenameParamConverter( - fast_llm_names=(("tie_word_embeddings",),), - export_names=(("tie_word_embeddings",),), - ), - ] - - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() - num_layers = self._model.config.base_model.transformer.num_layers - norm_bias: bool = False - - # Embedding and output - if self._model.config.base_model.tie_word_embeddings: - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) - else: - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) - - # Final norm - converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias - ) - - for i in range(num_layers): - # Norm - converters += 
self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_1", f"model.layers.{i}.input_layernorm", norm_bias - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_2", f"model.layers.{i}.post_attention_layernorm", norm_bias - ) - - # MLP - converters += self._get_mlp_converters(f"layers.{i+1}", f"model.layers.{i}") - - return converters - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - linear_bias, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - linear_bias, - MLPLayer2Converter, - ), - ] - - @classmethod - def _load_config(cls, directory: pathlib.Path | str) -> dict: - if not os.path.exists(directory / "config.json"): - raise FileNotFoundError(f"config.json not found in {directory}") - with open(directory / "config.json") as f: - config = json.load(f) - Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - return config - - @classmethod - def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - with open(directory / "config.json", "w") as f: - json.dump(config, f) - - -class AprielSSMHHybridHuggingfaceCheckpointHandler( - CustomModelingExportMixin, - HybridModelCheckpointHandler, # handles the block structure parameter - CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers - CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers -): - """ - Lamba-like configs, models that interleave LLama like layers with LLamba-like SSM layers. 
- """ - - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHHybridHuggingfaceCheckpointFormat - _default_block_type: str = SSMBlockType.mamba2_discrete.value - architecture: typing.ClassVar[str] = "AprielSSMHybridForCausalLM" - modeling_file = modeling_ssm_hybrid_apriel.__file__ - configuration_file = configuration_ssm_hybrid_apriel.__file__ - configuration_cls: typing.ClassVar[type["PretrainedConfig"]] = modeling_ssm_hybrid_apriel.AprielSSMHybridConfig - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantExportParamConverter( - export_names=(("auto_map",),), - export_value={ - "AutoConfig": "configuration_ssm_hybrid_apriel.AprielSSMHybridConfig", - "AutoModel": "modeling_ssm_hybrid_apriel.AprielSSMHybridModel", - "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel.AprielSSMHybridForCausalLM", - }, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "d_inner"),), - export_names=(("ssm_cfg", "d_inner"),), - ), - ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), - ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), - ] - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - linear_bias, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - linear_bias, - MLPLayer2Converter, - ), - ] - - @classmethod - def _load_config(cls, directory: pathlib.Path | str) -> dict: - if not os.path.exists(directory / "config.json"): - raise FileNotFoundError(f"config.json not found in {directory}") - with open(directory / "config.json") as f: - config = json.load(f) - Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - return config - - @classmethod - def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - with open(directory / "config.json", "w") as f: - json.dump(config, f) - - -class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( - CustomModelingExportMixin, - HybridModelCheckpointHandler, # handles the block structure parameter - CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers - CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers -): - """ - Lamba-like configs, models that interleave LLama like layers with LLamba-like SSM layers. 
- """ - - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - format: typing.ClassVar[type[CheckpointFormat]] = AprielThinkerSSMHHybridHuggingfaceCheckpointFormat - _default_block_type: str = SSMBlockType.mamba2_discrete.value - _hf_prefix: str = "model" - architecture: typing.ClassVar[str] = "AprielThinkerSSMHybridForCausalLM" - modeling_file = modeling_ssm_hybrid_apriel15b.__file__ - configuration_file = configuration_ssm_hybrid_apriel15b.__file__ - configuration_cls: typing.ClassVar[type["PretrainedConfig"]] = ( - configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig - ) - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantExportParamConverter( - export_names=(("auto_map",),), - export_value={ - "AutoConfig": "configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig", - "AutoModel": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridModel", - "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridForCausalLM", - }, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "d_inner"),), - export_names=(("ssm_cfg", "d_inner"),), - ), - IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), - ] - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - linear_bias, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - linear_bias, - MLPLayer2Converter, - ), - ] - - @classmethod - def _load_config(cls, directory: pathlib.Path | str) -> dict: - if not os.path.exists(directory / "config.json"): - raise FileNotFoundError(f"config.json not found in {directory}") - with open(directory / "config.json") as f: - config = json.load(f) - Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - return config - - @classmethod - def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - with open(directory / "config.json", "w") as f: - json.dump(config, f) diff --git a/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py deleted file mode 100644 index 1d230bb67..000000000 --- a/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py +++ /dev/null @@ -1,448 +0,0 @@ -import math -from typing import Optional - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import is_torch_available, logging - -logger = logging.get_logger(__name__) - -if is_torch_available(): - import torch - - -def _compute_default_rope_parameters( - config: Optional[PretrainedConfig] = None, - device: Optional["torch.device"] = None, - seq_len: Optional[int] = None, - **rope_kwargs, -) -> tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies according to the original RoPE implementation - Args: - config ([`~transformers.PretrainedConfig`]): - The model configuration. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. 
- rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). - """ - if config is not None and len(rope_kwargs) > 0: - raise ValueError( - "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " - f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" - ) - if len(rope_kwargs) > 0: - base = rope_kwargs["base"] - dim = rope_kwargs["dim"] - elif config is not None: - base = config.rope_theta - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - - attention_factor = 1.0 # Unused in this type of RoPE - - # Compute the inverse frequencies - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) - return inv_freq, attention_factor - - -def _compute_yarn_parameters( - config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs -) -> tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies with NTK scaling. Please refer to the - [original paper](https://arxiv.org/abs/2309.00071) - Args: - config ([`~transformers.PretrainedConfig`]): - The model configuration. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin. - """ - # No need to keep BC with yarn, unreleased when this new pattern was created. 
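As a point of reference for the YaRN variant below, the default computation above reduces to inv_freq[i] = 1 / base**(2*i / dim). A minimal standalone evaluation of that formula (toy values, not taken from any Apriel config):

import torch

# Toy evaluation of the default RoPE inverse frequencies (illustrative values only).
base, dim = 10000.0, 8
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
# -> tensor([1.0000, 0.1000, 0.0100, 0.0010]): each channel pair rotates 10x slower than the previous one.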
- if len(rope_kwargs) > 0: - raise ValueError( - f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}" - ) - - base = config.rope_theta - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - - # Apriel: Use original max_position_embeddings instead of max_position_embeddings - max_position_embeddings = config.rope_scaling.get( - "original_max_position_embeddings", config.max_position_embeddings - ) - factor = config.rope_scaling["factor"] - - # Sets the attention factor as suggested in the paper - attention_factor = config.rope_scaling.get("attention_factor") - if attention_factor is None: - attention_factor = 0.1 * math.log(factor) + 1.0 - - # Optional config options - # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) - beta_fast = config.rope_scaling.get("beta_fast") or 32 - beta_slow = config.rope_scaling.get("beta_slow") or 1 - - # Compute the inverse frequencies - def find_correction_dim(num_rotations, dim, base, max_position_embeddings): - """Inverse dimension formula to find the dimension based on the number of rotations""" - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): - """Find dimension range bounds based on rotations""" - low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) - - def linear_ramp_factor(min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs - # to expand the possible context length. In other words, interpolation = apply scaling factor. 
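As a standalone illustration of the extrapolation/interpolation blend described in the comment above, the sketch below mixes the two frequency sets with a linear ramp (toy dimensions and a hand-picked correction range, not values from any real model):

import torch

# Toy sketch of the YaRN frequency blend (illustrative values only).
base, dim, factor = 10000.0, 16, 4.0
pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
inv_freq_extrapolation = 1.0 / pos_freqs              # original RoPE frequencies
inv_freq_interpolation = 1.0 / (factor * pos_freqs)   # frequencies rescaled for a longer context

# Linear ramp over the dim // 2 frequency bands: 0 below `low`, 1 above `high`.
low, high = 2.0, 5.0
ramp = torch.clamp((torch.arange(dim // 2, dtype=torch.float32) - low) / (high - low), 0, 1)
extrapolation_factor = 1 - ramp  # 1 for fast (high-frequency) bands, 0 for slow (low-frequency) bands

# Fast bands keep their original frequencies; slow bands are interpolated (divided by `factor`).
inv_freq = inv_freq_interpolation * (1 - extrapolation_factor) + inv_freq_extrapolation * extrapolation_factor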
- pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (factor * pos_freqs) - - low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) - - # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device) - inv_freq = ( - inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) - + inv_freq_extrapolation * inv_freq_extrapolation_factor - ) - - return inv_freq, attention_factor - - -def _check_received_keys( - rope_type: str, - received_keys: set, - required_keys: set, - optional_keys: Optional[set] = None, - ignore_keys: Optional[set] = None, -): - """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" - # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present - if "type" in received_keys: - received_keys -= {"type"} - required_keys.add("rope_type") - - # Some models need to store model-specific keys, and we don't want to throw warning at them - if ignore_keys is not None: - received_keys -= ignore_keys - - missing_keys = required_keys - received_keys - if missing_keys: - raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}") - - if optional_keys is not None: - unused_keys = received_keys - required_keys - optional_keys - else: - unused_keys = received_keys - required_keys - if unused_keys: - logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") - - -def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" - required_keys = {"rope_type"} - received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) - - -def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" - required_keys = {"rope_type", "factor", "original_max_position_embeddings"} - optional_keys = {"attention_factor", "beta_fast", "beta_slow"} - received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) - - factor = rope_scaling["factor"] - if factor is None or not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") - - attention_factor = rope_scaling.get("attention_factor") - if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): - logger.warning( - f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" - ) - beta_fast = rope_scaling.get("beta_fast") - if beta_fast is not None and not isinstance(beta_fast, float): - logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") - beta_slow = rope_scaling.get("beta_slow") - if beta_slow is not None and not isinstance(beta_slow, float): - logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") - - if (beta_fast or 32) < (beta_slow or 1): - 
logger.warning( - f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " - f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" - ) - - -# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters -# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE -# parameterizations, as long as the callable has the same signature. -ROPE_INIT_FUNCTIONS = { - "default": _compute_default_rope_parameters, - "yarn": _compute_yarn_parameters, -} - -# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types. -ROPE_VALIDATION_FUNCTIONS = { - "default": _validate_default_rope_parameters, - "yarn": _validate_yarn_parameters, -} - - -def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None): - """ - Validate the RoPE config arguments, given a `PretrainedConfig` object - """ - rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig` - if rope_scaling is None: - return - - # BC: "rope_type" was originally "type" - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) - validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) - if validation_fn is not None: - validation_fn(config, ignore_keys=ignore_keys) - else: - logger.warning( - f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" - ) - - -class AprielSSMHybridConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`AprielModel`]. It is used to instantiate an Apriel - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Apriel-5B-Base. - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the Apriel model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`AprielModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. 
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. Apriel-5B-Base supports up to 16384 tokens. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*): - Padding token id. - bos_token_id (`int`, *optional*, defaults to 1): - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 2): - End of stream token id. - pretraining_tp (`int`, *optional*, defaults to 1): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to - understand more about it. This value is necessary to ensure exact reproducibility of the pretraining - results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type - and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value - accordingly. - Expected contents: - `rope_type` (`str`): - The sub-variant of RoPE to use. Can be one of ['default', 'yarn'], with 'default' being the original RoPE implementation. - `factor` (`float`, *optional*): - Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In - most scaling types, a `factor` of x will enable the model to handle sequences of length x * - original maximum pre-trained length. - `original_max_position_embeddings` (`int`, *optional*): - Used with 'yarn', 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during - pretraining. - `attention_factor` (`float`, *optional*): - Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention - computation. If unspecified, it defaults to value recommended by the implementation, using the - `factor` field to infer the suggested value. - `beta_fast` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear - ramp function. If unspecified, it defaults to 32. - `beta_slow` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear - ramp function. If unspecified, it defaults to 1. - `short_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to short contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `long_factor` (`List[float]`, *optional*): - Only used with 'longrope'. 
The scaling factor to be applied to long contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `low_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE - `high_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE - attention_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - mlp_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. - head_dim (`int`, *optional*): - The attention head dimension. If None, it will default to hidden_size // num_attention_heads - ```python - >>> from transformers import AprielModel, AprielConfig - >>> # Initializing an Apriel Apriel-5B-Base style configuration - >>> configuration = AprielConfig() - >>> # Initializing a model from the Apriel-5B-Base style configuration - >>> model = AprielModel(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "apriel_ssm_hybrid" - keys_to_ignore_at_inference = ["past_key_values"] - # Default tensor parallel plan for base model `AprielModel` - base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", - "layers.*.mlp.up_proj": "colwise", - "layers.*.mlp.down_proj": "rowwise", - } - base_model_pp_plan = { - "embed_tokens": (["input_ids"], ["inputs_embeds"]), - "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), - "norm": (["hidden_states"], ["hidden_states"]), - } - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=None, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_theta=10000.0, - rope_scaling=None, - attention_bias=False, - attention_dropout=0.0, - mlp_bias=False, - head_dim=None, - hybrid_block_layout=["m2d"], - ssm_cfg=None, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.rope_theta = rope_theta - self.rope_scaling = rope_scaling - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.mlp_bias = mlp_bias - self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads - self.hybrid_block_layout = 
hybrid_block_layout - if len(hybrid_block_layout) == 1: - self.hybrid_block_layout = [hybrid_block_layout[0]] * self.num_hidden_layers - assert len(self.hybrid_block_layout) == self.num_hidden_layers - # Validate the correctness of rotary position embeddings parameters - # BC: if there is a 'type' field, copy it it to 'rope_type'. - if self.rope_scaling is not None and "type" in self.rope_scaling: - self.rope_scaling["rope_type"] = self.rope_scaling["type"] - rope_config_validation(self) - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - ssm_defaults = { - "d_state": 64, - "n_v_heads": 24, - "n_qk_heads": 24, - "expand": 1, - "chunk_size": 128, - "activation": "identity", - "bias": False, - "d_conv": 4, - "d_inner": 24 * self.head_dim, # num_heads * head_dim - } - self.ssm_cfg = ssm_cfg or ssm_defaults - for k, v in ssm_defaults.items(): - if k not in self.ssm_cfg: - self.ssm_cfg[k] = v diff --git a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py deleted file mode 100644 index 771f81a7d..000000000 --- a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py +++ /dev/null @@ -1,1576 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Callable, Optional, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from einops import rearrange, repeat -from mamba_ssm.ops.triton.selective_state_update import selective_state_update -from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined -from torch import nn -from transformers import GenerationMixin -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache, StaticCache -from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.modeling_flash_attention_utils import FlashAttentionKwargs -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from transformers.processing_utils import Unpack -from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS -from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging -from transformers.utils.generic import ModelOutput - -from fast_llm.models.ssm.external.apriel_hybrid.configuration_ssm_hybrid_apriel import ( - ROPE_INIT_FUNCTIONS, - AprielSSMHybridConfig, -) - -logger = logging.get_logger(__name__) - -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) - - -def apply_mask_to_padding_states(hidden_states, attention_mask): - """ - Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 - """ - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: - dtype = hidden_states.dtype - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - - return hidden_states - - -class HybridMambaAttentionStaticCache(Cache): - def __init__(self, config: AprielSSMHybridConfig, batch_size, max_length, dtype=torch.float16, device=None): - super().__init__() # config, batch_size, max_length, device, dtype) - self.dtype = dtype - self.hybrid_override_pattern = config.hybrid_block_layout 
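For reference, the layout broadcast and the defaults merge performed by AprielSSMHybridConfig above behave as in this small standalone sketch (plain Python, illustrative values only):

# A single-entry layout is repeated for every layer.
num_hidden_layers = 6
hybrid_block_layout = ["m2d"]
if len(hybrid_block_layout) == 1:
    hybrid_block_layout = hybrid_block_layout * num_hidden_layers
assert len(hybrid_block_layout) == num_hidden_layers  # -> ["m2d"] * 6: every block is discrete Mamba 2

# User-provided ssm_cfg keys win; anything missing falls back to the defaults.
ssm_defaults = {"d_state": 64, "n_v_heads": 24, "n_qk_heads": 24, "d_conv": 4}
ssm_cfg = {"d_state": 16}
for k, v in ssm_defaults.items():
    ssm_cfg.setdefault(k, v)
# -> {"d_state": 16, "n_v_heads": 24, "n_qk_heads": 24, "d_conv": 4}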
- self.has_previous_state = False # only used by mamba - intermediate_size = config.ssm_cfg["d_inner"] - ssm_state_size = config.ssm_cfg["d_state"] - conv_kernel_size = config.ssm_cfg["d_conv"] - self.n_qk_heads = config.ssm_cfg["n_qk_heads"] - assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" - self.head_d = intermediate_size // self.n_qk_heads - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - self.batch_size = batch_size - self.head_dim = ( - config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - ) - self.max_cache_len = config.max_position_embeddings if max_length is None else max_length - - self.num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - cache_shape = (self.batch_size, self.num_key_value_heads, max_length, self.head_dim) - - for i in range(config.num_hidden_layers): - if self.hybrid_override_pattern[i] == "m2d": - # Mamba layer - new_layer_conv_state = torch.zeros( - batch_size, - conv_kernel_size, - intermediate_size + 2 * self.n_qk_heads * ssm_state_size, - device=device, - dtype=dtype, - ).transpose(1, 2) - - new_layer_ssm_state = torch.zeros( - batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype - ) - new_layer_key_cache = None # torch.zeros((0,), dtype=dtype, device=device) - new_layer_value_cache = None # torch.zeros((0,), dtype=dtype, device=device) - else: - # Attention or MLP layer - new_layer_conv_state = None # torch.tensor((0,), dtype=dtype, device=device) - new_layer_ssm_state = None # torch.tensor((0,), dtype=dtype, device=device) - new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device=device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device=device) - self.transformer_layers.append(i) - - # if not is_torchdynamo_compiling(): - # self.register_buffer(f"key_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) - # self.register_buffer(f"value_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) - # new_layer_key_cache = getattr(self, f"key_cache_{i}") - # new_layer_value_cache = getattr(self, f"value_cache_{i}") - # torch._dynamo.mark_static_address(new_layer_key_cache) - # torch._dynamo.mark_static_address(new_layer_value_cache) - # self.register_buffer(f"conv_states_{i}", new_layer_conv_state) - # self.register_buffer(f"ssm_states_{i}", new_layer_ssm_state) - # torch._dynamo.mark_static_address(new_layer_conv_state) - # torch._dynamo.mark_static_address(new_layer_ssm_state) - # new_layer_ssm_state = getattr(self, f"ssm_states_{i}") - # new_layer_conv_state = getattr(self, f"conv_states_{i}") - - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - self.conv_states.append(new_layer_conv_state) - self.ssm_states.append(new_layer_ssm_state) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - It is VERY important to index using a tensor, otherwise you introduce a copy to the device. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. 
- value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input - to know how where to write in the cache. - - Return: - A tuple containing the updated key and value states. - """ - - cache_position = cache_kwargs.get("cache_position") - - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - - if cache_position is None: - k_out.copy_(key_states) - v_out.copy_(value_states) - else: - # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to - # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place - # operation, that avoids copies and uses less memory. - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - return k_out, v_out - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: Optional[int] = None) -> int: - """Returns the sequence length of the cached states that were seen by the model.""" - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. 
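The in-place write that `update` performs above can be reproduced on plain tensors; a minimal sketch (toy shapes, independent of the Apriel cache classes):

import torch

# Toy in-place KV-cache write along the sequence dimension (illustrative shapes only).
batch, heads, max_len, head_dim = 1, 2, 8, 4
k_cache = torch.zeros(batch, heads, max_len, head_dim)
new_keys = torch.randn(batch, heads, 3, head_dim)   # three new positions
cache_position = torch.arange(2, 5)                 # write them into slots 2..4

# index_copy_ is the compile-friendly, in-place equivalent of `k_cache[:, :, cache_position] = new_keys`.
k_cache.index_copy_(2, cache_position, new_keys)
assert torch.equal(k_cache[:, :, 2:5], new_keys)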
- # TODO: deprecate this function in favor of `cache_position` - if layer_idx is None: - if len(self.transformer_layers) > 0: - layer_idx = self.transformer_layers[0] - else: - return 0 - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len - - # Copied from modeling_mamba2.py - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - if cache_init: - self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) - else: - self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) - return self.ssm_states[layer_idx] - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - -# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py -class HybridMambaAttentionDynamicCache(DynamicCache): - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
- """ - - def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float16, device=None): - super().__init__() - self.dtype = dtype - self.hybrid_override_pattern = config.hybrid_block_layout - self.has_previous_state = False # only used by mamba - intermediate_size = config.ssm_cfg["d_inner"] - ssm_state_size = config.ssm_cfg["d_state"] - conv_kernel_size = config.ssm_cfg["d_conv"] - self.n_qk_heads = config.ssm_cfg["n_qk_heads"] - assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" - self.head_d = intermediate_size // self.n_qk_heads - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - if self.hybrid_override_pattern[i] == "m2d": - # Mamba layer - self.conv_states += [ - torch.zeros( - batch_size, - conv_kernel_size, - intermediate_size + 2 * self.n_qk_heads * ssm_state_size, - device=device, - dtype=dtype, - ).transpose(1, 2) - ] - self.ssm_states += [ - torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype) - ] - else: - # Attention or MLP layer - self.conv_states += [torch.tensor([[]] * batch_size, device=device)] - self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - - @classmethod - def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - - # Copied from modeling_mamba2.py - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - if cache_init: - self.conv_states[layer_idx] = 
new_conv_state.to(self.conv_states.device) - else: - self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) - return self.ssm_states[layer_idx] - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - -@dataclass -class AprielHybridCausalOutput(ModelOutput): - """Custom output class for MambaLMHeadModel.""" - - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None - last_hidden_state: Optional[torch.FloatTensor] = None - attention_weights: Optional[torch.FloatTensor] = None - past_key_values: Optional[HybridMambaAttentionDynamicCache] = None - - -class AprielRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None, **kwargs): - """ - AprielRMSNorm is equivalent to T5LayerNorm - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -ALL_LAYERNORM_LAYERS.append(AprielRMSNorm) - - -class AprielMLP(nn.Module): - def __init__(self, config, device=None, dtype=None, **kwargs): - super().__init__(**kwargs) - factory_kwargs = {"device": device, "dtype": dtype} - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, **factory_kwargs) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -class AprielRotaryEmbedding(nn.Module): - def __init__(self, config: AprielSSMHybridConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - 
the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - self.original_inv_freq = self.original_inv_freq.to(device) - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -class AprielAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: AprielSSMHybridConfig, layer_idx: int): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, 
key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - -def segsum(x): - """More stable segment sum calculation.""" - # [1, 2, 3] - T = x.size(-1) - x = repeat(x, "... d -> ... d e", e=T) - # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] - mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) - x = x.masked_fill(~mask, 0) - # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] - x_segsum = torch.cumsum(x, dim=-2) - # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] - mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) - x_segsum = x_segsum.masked_fill(~mask, -torch.inf) - return x_segsum - - -def materialize_mixer(A_log, B, C, D): - """ - Since the transfer matrix will be equated to the attention matrix, - we need to support the form: torch.matmul(attn_weights, value_states). - Thus, y = torch.matmul(T, X) - Arguments: - A_log: (batch, length, n_heads) - B: (batch, length, n_heads, d_state) - C: (batch, length, n_heads, d_state) - Return: - T: (batch, n_heads, length, length) - """ - batch_size, length, n_heads, d_state = B.shape - assert A_log.shape == (batch_size, length, n_heads) - assert B.shape == C.shape == (batch_size, length, n_heads, d_state) - - # Compute: - A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") - powers = torch.exp(segsum(A_log)) - T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) - - # Add D: - if D is not None: - T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) - - T = rearrange(T, "b h z l -> b h l z") - return T - - -class DiscreteMamba2(nn.Module): - def __init__( - self, - d_model, - d_state=64, - n_qk_heads=32, - n_v_heads=32, - d_conv=4, - expand=1, - activation="identity", - bias=False, - conv_bias=True, - chunk_size=128, - layer_idx=None, - device=None, - dtype=None, - d_inner=None, - **kwargs, # Absorb kwarg for general module - ): - """ - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. 
- Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" - - Other options are all experimental and should not need to be configured - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = self.expand * self.d_model if d_inner is None else d_inner - self.n_qk_heads = n_qk_heads - self.n_v_heads = n_v_heads - self.headdim = self.d_inner // self.n_v_heads - assert self.n_v_heads == self.d_inner // self.headdim - assert self.d_inner % self.headdim == 0 - assert self.n_v_heads % self.n_qk_heads == 0 - self.activation = activation - self.chunk_size = chunk_size - self.layer_idx = layer_idx - self.bias = bias - self.kwargs = kwargs - - # Projections - self.in_proj = nn.Linear( - self.d_model, - 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, - bias=bias, - **factory_kwargs, - ) - self.z_bias = ( - nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 - ) # make sure z_bias always exists - - # Convolutional layer - conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state - self.conv_bias = conv_bias - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - bias=conv_bias, - kernel_size=d_conv, - groups=conv_dim, - padding=d_conv - 1, - **factory_kwargs, - ) - - # Activation after conv - if self.activation == "identity": - self.act = nn.Identity() - elif self.activation in ["silu", "swish"]: - self.act = nn.SiLU() - else: - raise ValueError(f"Unknown activation {self.activation}") - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) - self.D._optim = {"weight_decay": 0.0} - - # out_proj - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - # In __init__, pre-allocate these tensors - self.zeros_buffer = torch.zeros((self.n_v_heads, self.headdim), device=device, dtype=dtype) - self.ones_buffer = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=device, dtype=dtype) - - @property - def d_output(self): - return self.d_model - - @property - def state_to_tensor(self): - return self.layer.state_to_tensor - - def forward( - self, - u, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - attention_mask: Optional[torch.Tensor] = None, - return_mixer_matrix=False, - **kwargs, - ): - """ - u: (B, L, D) - Returns: same shape as u - For later refference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bamba/modeling_bamba.py - """ - assert is_fast_path_available and "cuda" in self.in_proj.weight.device.type, "Only support fast path on cuda" - cache_position = kwargs.get("cache_position", None) - batch, seqlen, dim = u.shape - u = apply_mask_to_padding_states(u, attention_mask) - ssm_state, conv_state = None, None - - use_precomputed_states = ( - past_key_value is not None - and past_key_value.has_previous_state - and seqlen == 1 - and past_key_value.conv_states[self.layer_idx].shape[0] - == past_key_value.ssm_states[self.layer_idx].shape[0] - == batch - and cache_position is not None - and cache_position[0] > 0 - ) - ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) - if use_precomputed_states: - u = u.squeeze(1) if len(u.shape) == 3 else u - out, _, _ = self.step(u, ssm_state, conv_state) - out = out.unsqueeze(1) if len(u.shape) == 2 else out - return {"hidden_states": out} - - outputs = {} - # Hacky 
way to initialize state during inference - chunk_size = self.chunk_size if ssm_state is None else seqlen - - # Pad input to nearest multiple of chunklen - padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = F.pad(u, (0, 0, 0, padded_len - seqlen)) - - # Project input - xBCzA_log = self.in_proj(u) - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - if ssm_state is not None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) - - # Convolutional layer - xBC = self.convolutional_forward(xBC, padded_len) - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) - B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) - C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) - - # SSM forward - result = mamba_chunk_scan_combined( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=A_log, - dt_softplus=True, - A=-torch.ones(self.n_v_heads, device=A_log.device), - B=B, - C=C, - chunk_size=chunk_size, - # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(ssm_state is not None), - ) - - if ssm_state is not None: - y, ssm_state_update = result - ssm_state.copy_(ssm_state_update) - else: - y = result - - Du = torch.einsum("h,blhp->blhp", self.D, x) - y = rearrange(y + Du, "b l h p -> b l (h p)") - - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, :] - - if return_mixer_matrix: - outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] - return outputs - - def step(self, u, ssm_state, conv_state, **kwargs): - """ - u: (B D) - state: dict of states - Returns: same shape as u - """ - - # Project input - xBCzA_log = self.in_proj(u) - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - xBC, conv_state_new = self.convolutional_step(xBC, conv_state) - if conv_state_new is not None: - raise NotImplementedError("Should not end up here snce only support fast path.") - # conv_state.copy_(conv_state_new) # update state in place, only for slow pass - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) - B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) - C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) - - ssm_state = ssm_state.to(x.dtype) - zeros = self.zeros_buffer.to(A_log.device).to(x.dtype) # Just cast, don't allocate - ones = self.ones_buffer.to(A_log.device).to(x.dtype) - y = selective_state_update( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=repeat(A_log, "b h -> b h p", p=self.headdim), - dt_softplus=True, - A=-ones, - B=B, - C=C, - state=ssm_state, # will be updated in place - dt_bias=zeros, - D=zeros, - ) - - y = y + self.D[:, None] * x - y = rearrange(y, "b h p -> b (h p)") - - # Norm and gate 
- out = self.out_proj(y * F.silu(z + self.z_bias)) - - return out, ssm_state, conv_state - - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - """ - conv_state: (batch, d_conv, conv1d.weight.shape[0]) - ssm_state: (batch, n_qk_heads, headdim, d_state) - """ - assert self.layer_idx is not None - # Allocate memory if not exists - # if self.layer_idx not in inference_params.ssm_states: - # inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( - # batch_size, inference_params.max_seqlen, dtype=torch.float32 - # ) - # Get states - ssm_states = inference_params.ssm_states[self.layer_idx] - conv_states = inference_params.conv_states[self.layer_idx] - if initialize_states: - ssm_states.zero_() - conv_states.zero_() - return ssm_states, conv_states - - def convolutional_forward(self, xBC, padded_len): - if causal_conv1d_fn is None or self.activation not in [ - "silu", - "swish", - "identity", - ]: - xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) - else: - xBC = causal_conv1d_fn( - xBC.transpose(1, 2), - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - activation=None if self.activation == "identity" else self.activation, - ).transpose(1, 2) - return xBC - - def convolutional_step(self, xBC, conv_state): - # Convolutional layer - conv_state = conv_state.to(xBC.dtype) - if causal_conv1d_update: - xBC = causal_conv1d_update( - xBC, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation if self.activation != "identity" else None, - ) - return xBC, None - else: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = xBC - xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv_bias: - xBC = xBC + self.conv1d.bias - xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype - - return xBC, conv_state - - -class AprielDecoderLayer(nn.Module): - def __init__(self, config: AprielSSMHybridConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = AprielAttention(config=config, layer_idx=layer_idx) - - self.mlp = AprielMLP(config) - self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = 
hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - - -class AprielSSMDecoderLayer(nn.Module): - def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): - super().__init__(**kwargs) - factory_kwargs = {"device": device, "dtype": dtype} - self.hidden_size = config.hidden_size - - self.mixer = DiscreteMamba2( - d_model=config.hidden_size, - layer_idx=layer_idx, - **config.ssm_cfg, - **factory_kwargs, - ) - - self.mlp = AprielMLP(config, **factory_kwargs) - self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) - self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) - - def forward( - self, hidden_states: torch.Tensor, **kwargs - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - - outputs = {} - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - mixer_outputs = self.mixer( - hidden_states, - **kwargs, - ) - - hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - # outputs["hidden_states"] = hidden_states - outputs = (hidden_states,) - - return outputs - - # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - # """Allocate inference cache for the model.""" - # if getattr(self.mixer, "allocate_inference_cache", None) is None: - # return - # return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - - -APRIEL_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - Parameters: - config ([`AprielSSMHybridConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare Apriel Model outputting raw hidden-states without any specific head on top.", - APRIEL_START_DOCSTRING, -) -class AprielSSMPreTrainedModel(PreTrainedModel): - config_class = AprielSSMHybridConfig - base_model_prefix = "model" - _no_split_modules = ["AprielDecoderLayer", "AprielSSMDecoderLayer"] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - _supports_attention_backend = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - # def allocate_inference_cache(self, *args, **kwargs): - # """Allocate inference cache for the model.""" - # return getattr(self, self.base_model_prefix).allocate_inference_cache(*args, **kwargs) - - -APRIEL_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - [What are attention masks?](../glossary#attention-mask) - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. 
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - "The bare Apriel Model outputting raw hidden-states without any specific head on top.", - APRIEL_START_DOCSTRING, -) -class AprielSSMHybridModel(AprielSSMPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] - Args: - config: AprielSSMHybridConfig - """ - - def __init__(self, config: AprielSSMHybridConfig, device=None, dtype=None, **kwargs): - super().__init__(config, device=device, dtype=dtype, **kwargs) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - factory_kwargs = {"device": device, "dtype": dtype} - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) - blocks = [] - logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") - for layer_idx, type in enumerate(config.hybrid_block_layout): - if type == "m2d": - blocks.append(AprielSSMDecoderLayer(config, layer_idx, **factory_kwargs)) - elif type == "t": - blocks.append(AprielDecoderLayer(config, layer_idx)) - else: - raise ValueError(f"Invalid block type: {type}") - self.layers = nn.ModuleList(blocks) - self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) - self.gradient_checkpointing = False - self.rotary_emb = AprielRotaryEmbedding(config=config) - self.has_transformer_layers = any(type == "t" for type in config.hybrid_block_layout) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # def allocate_inference_cache(self, *args, **kwargs): - # """Allocate inference cache for the model.""" - # cache = {} - # for i, layer in enumerate(self.layers): - # if isinstance(layer, AprielSSMDecoderLayer): - # cache[i] = layer.allocate_inference_cache(*args, **kwargs) - # return cache - - @add_start_docstrings_to_model_forward(APRIEL_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - inference_params=None, - **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if use_cache and past_key_values is None: - # past_key_values = HybridMambaAttentionDynamicCache() - logger.warning_once( - "Hybrid Apriel requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " - "provided, so no cache will be returned." 
- ) - - if cache_position is None and self.has_transformer_layers: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None and self.has_transformer_layers: - position_ids = cache_position.unsqueeze(0) - - causal_mask = ( - self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions) - if self.has_transformer_layers - else None - ) - - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) if self.has_transformer_layers else None - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - if output_hidden_states: - all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - inference_params=inference_params, - **flash_attn_kwargs, - ) - - hidden_states = layer_outputs[0] - - if output_attentions and isinstance(decoder_layer, AprielDecoderLayer): - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - output = BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - return output if return_dict else output.to_tuple() - - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) or isinstance( - past_key_values, HybridMambaAttentionStaticCache - ) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - sequence_length = input_tensor.shape[1] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... - - -class AprielSSMHybridForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - self.model = AprielSSMHybridModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - output_router_logits=False, - cache_position=None, - position_ids=None, - use_cache=True, - **kwargs, - ): - # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` - - empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) - - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. 
- # (we can't check exception 3 while compiling) - if not empty_past_kv: - if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - else: - past_key_values = HybridMambaAttentionDynamicCache( - self.config, input_ids.shape[0], self.dtype, device=self.device - ) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if not empty_past_kv: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and empty_past_kv: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - "output_router_logits": output_router_logits, - # "logits_to_keep": self.config.num_logits_to_keep, - "cache_position": cache_position, - } - ) - return model_inputs - - def forward( - self, - input_ids: torch.LongTensor = None, - position_ids=None, - return_hidden_states=False, - return_logits=True, - num_last_tokens=0, - past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, - **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[tuple, CausalLMOutputWithPast]: - - # past_key_values is None if prepare_inputs_for_generation is not called, which is the case when we evaluate without calling generate (non-generation tasks) - # Its generally ok if cache is nto instantiated in this case, since we do single pass per sample anyways, a warning will be triggered in the model - outputs: BaseModelOutputWithPast = self.model( - input_ids, - return_hidden_states=return_hidden_states, - position_ids=position_ids, - past_key_values=past_key_values, - **kwargs, - ) - - if outputs["last_hidden_state"] is not None and return_logits: - logits = self.lm_head(outputs["last_hidden_state"]).float() - outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] - else: - outputs["logits"] = None - - return AprielHybridCausalOutput( - loss=None, - logits=outputs["logits"], - all_hidden_states=outputs.hidden_states, - last_hidden_state=outputs.last_hidden_state, - past_key_values=outputs.past_key_values, - ) - - -__all__ = [ - "AprielSSMHybridForCausalLM", - "AprielSSMHybridModel", - "AprielSSMPreTrainedModel", -] diff --git a/fast_llm/models/ssm/external/apriel_ssm/configuration_ssm_apriel.py b/fast_llm/models/ssm/external/apriel_ssm/configuration_ssm_apriel.py deleted file mode 100644 index 6943a3124..000000000 --- a/fast_llm/models/ssm/external/apriel_ssm/configuration_ssm_apriel.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Apriel SSM model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import is_torch_available, logging - -logger = logging.get_logger(__name__) - -if is_torch_available(): - pass - - -class AprielSSMConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`AprielModel`]. It is used to instantiate an Apriel - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Apriel-5B-Base. - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - .... - ```""" - - model_type = "apriel_ssm" - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - hidden_act="silu", - initializer_range=0.02, - use_cache=True, - pad_token_id=None, - bos_token_id=1, - eos_token_id=2, - tie_word_embeddings=False, - mlp_bias=False, - rms_norm_eps=1e-5, - ssm_cfg: dict = None, - head_dim: int = 128, - **kwargs, - ): - self.vocab_size = vocab_size - # self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - # self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - # self.rope_theta = rope_theta - self.mlp_bias = mlp_bias - self.head_dim = head_dim - # Validate the correctness of rotary position embeddings parameters - # BC: if there is a 'type' field, copy it it to 'rope_type'. 
- # if self.rope_scaling is not None and "type" in self.rope_scaling: - # self.rope_scaling["rope_type"] = self.rope_scaling["type"] - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - self.ssm_cfg = ssm_cfg or { - "d_state": 64, - "n_v_heads": 24, - "n_qk_heads": 24, - "expand": 1, - "chunk_size": 128, - "activation": "identity", - "bias": False, - "d_inner": 24 * self.head_dim, # num_heads * head_dim - } - if self.head_dim != self.ssm_cfg["d_inner"] // self.ssm_cfg["n_qk_heads"]: - logger.warning("Head dim is not equal to d_inner // n_qk_heads.") - - -__all__ = ["AprielConfig"] diff --git a/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py b/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py deleted file mode 100644 index 09dc8259c..000000000 --- a/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py +++ /dev/null @@ -1,743 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from einops import rearrange, repeat -from mamba_ssm.ops.triton.selective_state_update import selective_state_update -from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined -from mamba_ssm.utils.generation import GenerationMixin -from torch import nn -from transformers.activations import ACT2FN -from transformers.modeling_flash_attention_utils import FlashAttentionKwargs -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.processing_utils import Unpack -from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS -from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging -from transformers.utils.generic import ModelOutput - -from fast_llm.models.ssm.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig - -logger = logging.get_logger(__name__) - - -@dataclass -class CustomMambaCausalLMOutput(ModelOutput): - """Custom output class for MambaLMHeadModel.""" - - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None - last_hidden_state: Optional[torch.FloatTensor] = None - - -class AprielRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None, **kwargs): - """ - AprielRMSNorm is equivalent to T5LayerNorm - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -ALL_LAYERNORM_LAYERS.append(AprielRMSNorm) - - -class AprielMLP(nn.Module): - def __init__(self, config, device=None, dtype=None, **kwargs): - super().__init__(**kwargs) - factory_kwargs = {"device": device, "dtype": dtype} - self.config = config - self.hidden_size = config.hidden_size - 
self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, **factory_kwargs) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -def segsum(x): - """More stable segment sum calculation.""" - # [1, 2, 3] - T = x.size(-1) - x = repeat(x, "... d -> ... d e", e=T) - # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] - mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) - x = x.masked_fill(~mask, 0) - # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] - x_segsum = torch.cumsum(x, dim=-2) - # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] - mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) - x_segsum = x_segsum.masked_fill(~mask, -torch.inf) - return x_segsum - - -def materialize_mixer(A_log, B, C, D): - """ - Since the transfer matrix will be equated to the attention matrix, - we need to support the form: torch.matmul(attn_weights, value_states). - Thus, y = torch.matmul(T, X) - Arguments: - A_log: (batch, length, n_heads) - B: (batch, length, n_heads, d_state) - C: (batch, length, n_heads, d_state) - Return: - T: (batch, n_heads, length, length) - """ - batch_size, length, n_heads, d_state = B.shape - assert A_log.shape == (batch_size, length, n_heads) - assert B.shape == C.shape == (batch_size, length, n_heads, d_state) - - # Compute: - A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") - powers = torch.exp(segsum(A_log)) - T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) - - # Add D: - if D is not None: - T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) - - T = rearrange(T, "b h z l -> b h l z") - return T - - -class DiscreteMamba2(nn.Module): - def __init__( - self, - d_model, - d_state=64, - n_qk_heads=32, - n_v_heads=32, - d_conv=4, - expand=1, - activation="identity", - bias=False, - conv_bias=True, - chunk_size=128, - layer_idx=None, - device=None, - dtype=None, - d_inner=None, - **kwargs, # Absorb kwarg for general module - ): - """ - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. 
- Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" - - Other options are all experimental and should not need to be configured - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = self.expand * self.d_model if d_inner is None else d_inner - self.n_qk_heads = n_qk_heads - self.n_v_heads = n_v_heads - self.headdim = self.d_inner // self.n_v_heads - assert self.n_v_heads == self.d_inner // self.headdim - assert self.d_inner % self.headdim == 0 - assert self.n_v_heads % self.n_qk_heads == 0 - self.activation = activation - self.chunk_size = chunk_size - self.layer_idx = layer_idx - self.bias = bias - self.kwargs = kwargs - - # Projections - self.in_proj = nn.Linear( - self.d_model, - 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, - bias=bias, - **factory_kwargs, - ) - self.z_bias = ( - nn.Parameter(torch.zeros(self.d_inner, **factory_kwargs)) if not bias else 0 - ) # make sure z_bias always exists - - # Convolutional layer - conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state - self.conv_bias = conv_bias - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - bias=conv_bias, - kernel_size=d_conv, - groups=conv_dim, - padding=d_conv - 1, - **factory_kwargs, - ) - - # Activation after conv - if self.activation == "identity": - self.act = nn.Identity() - elif self.activation in ["silu", "swish"]: - self.act = nn.SiLU() - else: - raise ValueError(f"Unknown activation {self.activation}") - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.n_v_heads, **factory_kwargs)) - self.D._optim = {"weight_decay": 0.0} - - # out_proj - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - - @property - def d_output(self): - return self.d_model - - @property - def state_to_tensor(self): - return self.layer.state_to_tensor - - def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs): - """ - u: (B, L, D) - Returns: same shape as u - """ - outputs = {} - # assert state is None - batch, seqlen, dim = u.shape - - state = None - if inference_params is not None: - state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: - # States are updated inplace - u = u.squeeze(1) if len(u.shape) == 3 else u - out, _ = self.step(u, state) - out = out.unsqueeze(1) if len(u.shape) == 2 else out - return {"hidden_states": out} - - # Hacky way to initialize state during inference - chunk_size = self.chunk_size if state is None else seqlen - - # Pad input to nearest multiple of chunklen - padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = F.pad(u, (0, 0, 0, padded_len - seqlen)) - - # Project input - xBCzA_log = self.in_proj(u) - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - if state is not None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 
- xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - state["conv"].copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) - - # Convolutional layer - xBC = self.convolutional_forward(xBC, padded_len) - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) - B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) - C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) - - # SSM forward - result = mamba_chunk_scan_combined( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=A_log, - dt_softplus=True, - A=-torch.ones(self.n_v_heads, device=A_log.device), - B=B, - C=C, - chunk_size=chunk_size, - # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(state is not None), - ) - - if state is not None: - y, ssm_state = result - state["ssm"].copy_(ssm_state) - else: - y = result - - Du = torch.einsum("h,blhp->blhp", self.D, x) - y = rearrange(y + Du, "b l h p -> b l (h p)") - - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, :] - - if return_mixer_matrix: - outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] - return outputs - - def step(self, u, state, **kwargs): - """ - u: (B D) - state: dict of states - Returns: same shape as u - """ - - # Project input - xBCzA_log = self.in_proj(u.squeeze(1)) - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - xBC, conv_state = self.convolutional_step(xBC, state["conv"]) - state["conv"].copy_(conv_state) # update state in place - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) - B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) - C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) - - state["ssm"] = state["ssm"].to(x.dtype) - zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) - ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) - y = selective_state_update( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=repeat(A_log, "b h -> b h p", p=self.headdim), - dt_softplus=True, - A=-ones, - B=B, - C=C, - state=state["ssm"], # will be updated in place - dt_bias=zeros, - D=zeros, - ) - - y = y + self.D[:, None] * x - y = rearrange(y, "b h p -> b (h p)") - - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - - return out, state - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - device = self.in_proj.weight.device - # conv_state: - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, - self.d_conv, - self.conv1d.weight.shape[0], - device=device, - dtype=conv_dtype, - ).transpose(1, 2) - # ssm_state: - ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype - ssm_state = torch.zeros( - batch_size, - self.n_v_heads, - self.headdim, - self.d_state, - device=device, - dtype=ssm_dtype, - ) - return {"conv": conv_state, "ssm": ssm_state} - - def _get_states_from_cache(self, inference_params, batch_size, 
initialize_states=False): - """ - conv_state: (batch, d_conv, conv1d.weight.shape[0]) - ssm_state: (batch, n_qk_heads, headdim, d_state) - """ - assert self.layer_idx is not None - # Allocate memory if not exists - if self.layer_idx not in inference_params.key_value_memory_dict: - inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( - batch_size, inference_params.max_seqlen, dtype=torch.float32 - ) - # Get states - states = inference_params.key_value_memory_dict[self.layer_idx] - if initialize_states: - states["conv"].zero_() - states["ssm"].zero_() - return states - - def convolutional_forward(self, xBC, padded_len): - if causal_conv1d_fn is None or self.activation not in [ - "silu", - "swish", - "identity", - ]: - xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) - else: - xBC = causal_conv1d_fn( - xBC.transpose(1, 2), - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - activation=None if self.activation == "identity" else self.activation, - ).transpose(1, 2) - return xBC - - def convolutional_step(self, xBC, conv_state): - # Convolutional layer - conv_state = conv_state.to(xBC.dtype) - if causal_conv1d_update: - xBC = causal_conv1d_update( - xBC, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation if self.activation != "identity" else None, - ) - else: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = xBC - xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv_bias: - xBC = xBC + self.conv1d.bias - xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype - - return xBC, conv_state - - -class AprielDecoderLayer(nn.Module): - def __init__(self, config: AprielSSMConfig, layer_idx: int, device=None, dtype=None, **kwargs): - super().__init__(**kwargs) - factory_kwargs = {"device": device, "dtype": dtype} - self.hidden_size = config.hidden_size - - self.mixer = DiscreteMamba2( - d_model=config.hidden_size, - layer_idx=layer_idx, - **config.ssm_cfg, - **factory_kwargs, - ) - - self.mlp = AprielMLP(config, **factory_kwargs) - self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) - self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) - - def forward( - self, hidden_states: torch.Tensor, inference_params=None, **kwargs - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - - outputs = {} - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - mixer_outputs = self.mixer( - hidden_states, - inference_params=inference_params, - ) - - hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs["hidden_states"] = hidden_states - - return outputs - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - """Allocate inference cache for the model.""" - if getattr(self.mixer, "allocate_inference_cache", None) is None: - return - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - - -APRIEL_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - Parameters: - config ([`AprielSSMConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Apriel Model outputting raw hidden-states without any specific head on top.", - APRIEL_START_DOCSTRING, -) -class AprielSSMPreTrainedModel(PreTrainedModel): - config_class = AprielSSMConfig - base_model_prefix = "model" - _no_split_modules = ["AprielDecoderLayer"] - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def allocate_inference_cache(self, *args, **kwargs): - """Allocate inference cache for the model.""" - return getattr(self, self.base_model_prefix).allocate_inference_cache(*args, **kwargs) - - -APRIEL_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - [What are attention masks?](../glossary#attention-mask) - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. 
This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - "The bare Apriel Model outputting raw hidden-states without any specific head on top.", - APRIEL_START_DOCSTRING, -) -class AprielSSMModel(AprielSSMPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`AprielDecoderLayer`] - Args: - config: AprielSSMConfig - """ - - def __init__(self, config: AprielSSMConfig, device=None, dtype=None, **kwargs): - super().__init__(config, device=device, dtype=dtype, **kwargs) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - factory_kwargs = {"device": device, "dtype": dtype} - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) - self.layers = nn.ModuleList( - [AprielDecoderLayer(config, layer_idx, **factory_kwargs) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def allocate_inference_cache(self, *args, **kwargs): - """Allocate inference cache for the model.""" - return {i: layer.allocate_inference_cache(*args, **kwargs) for i, layer in enumerate(self.layers)} - - @add_start_docstrings_to_model_forward(APRIEL_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - return_hidden_states=False, - inference_params=None, - position_ids=None, - ) -> Union[tuple, BaseModelOutputWithPast]: - - hidden_states = self.embed_tokens(input_ids) - - # decoder layers - outputs = { - "last_hidden_state": None, - "all_hidden_states": (hidden_states,) if return_hidden_states else (), - } - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - - layer_outputs = decoder_layer( - hidden_states, - inference_params=inference_params, - position_ids=position_ids, - ) - # Record outputs - hidden_states = layer_outputs["hidden_states"] - if return_hidden_states: - outputs["all_hidden_states"] += (hidden_states,) - - outputs["last_hidden_state"] = self.norm(hidden_states) - return outputs - - -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
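
The segsum() helper defined earlier in this removed modeling_ssm_apriel.py builds the lower-triangular matrix of segment sums that materialize_mixer() exponentiates into the decay factors of the transfer matrix: out[i, j] = x[j+1] + ... + x[i] for j <= i, and -inf above the diagonal. Below is a minimal standalone check of that behaviour, with the helper copied verbatim so it runs without the mamba_ssm / causal_conv1d dependencies of the full module; it only reproduces the worked values already shown in the source comments.

import torch
from einops import repeat


def segsum(x):
    # Copied from the removed modeling_ssm_apriel.py: stable segment-sum calculation.
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
    return x_segsum.masked_fill(~mask, -torch.inf)


out = segsum(torch.tensor([1.0, 2.0, 3.0]))
# Expected (matches the inline comments in the source):
# tensor([[0., -inf, -inf],
#         [2., 0., -inf],
#         [5., 3., 0.]])
print(out)
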
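
For reference, a minimal sketch of how the removed AprielSSMConfig / AprielSSMForCausalLM pair (the class defined just below) was meant to be driven. The import paths are the pre-removal ones deleted by this patch; the toy sizes, device, and dtype are illustrative assumptions only, and running it requires mamba_ssm, causal_conv1d, and a CUDA device. This is an untested sketch, not part of the patch.

import torch

from fast_llm.models.ssm.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig
from fast_llm.models.ssm.external.apriel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM

# Toy configuration (illustrative sizes only). The default ssm_cfg sets
# d_inner = 24 * head_dim, so head_dim=16 gives d_inner=384 with 24 V-heads.
config = AprielSSMConfig(
    vocab_size=1000,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    head_dim=16,
)
model = AprielSSMForCausalLM(config, device="cuda", dtype=torch.bfloat16)

input_ids = torch.randint(0, config.vocab_size, (1, 8), device="cuda")
with torch.no_grad():
    out = model(input_ids)
# Returns a CustomMambaCausalLMOutput; with num_last_tokens=0 the logits
# have shape (1, 8, vocab_size).
print(out.logits.shape)
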
- - -class AprielSSMForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - - def __init__(self, config, device=None, dtype=None, **kwargs): - super().__init__(config, device=device, dtype=dtype, **kwargs) - self.model = AprielSSMModel(config, device=device, dtype=dtype) - self.vocab_size = config.vocab_size - factory_kwargs = {"device": device, "dtype": dtype} - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, **factory_kwargs) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - position_ids=None, - return_hidden_states=False, - return_logits=True, - inference_params=None, - num_last_tokens=0, - **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[tuple, CausalLMOutputWithPast]: - - outputs = self.model( - input_ids, - return_hidden_states=return_hidden_states, - inference_params=inference_params, - position_ids=position_ids, - ) - - if outputs["last_hidden_state"] is not None and return_logits: - logits = self.lm_head(outputs["last_hidden_state"]).float() - outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] - else: - outputs["logits"] = None - - return CustomMambaCausalLMOutput( - loss=None, - logits=outputs["logits"], - all_hidden_states=outputs["all_hidden_states"], - last_hidden_state=outputs["last_hidden_state"], - ) - - def generate(self, *args, **kwargs): - """ - This is a wrapper to make sure we comply with the HF generation interface for eval harness - """ - return super().generate(*args, **kwargs) - - -__all__ = [ - "AprielSSMForCausalLM", - "AprielModel", - "AprielSSMPreTrainedModel", -] diff --git a/fast_llm/models/ssm/external/llamba/configuration_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/configuration_mtp_llamba.py deleted file mode 100644 index b8173b733..000000000 --- a/fast_llm/models/ssm/external/llamba/configuration_mtp_llamba.py +++ /dev/null @@ -1,94 +0,0 @@ -from enum import Enum - -from transformers.configuration_utils import PretrainedConfig - - -class StateUpdateKernel(Enum): - ssu_verification = "ssu_verification" # selective scan for multi-token verification, not implemented yet - cs = "chunk_scan" # see https://proceedings.mlr.press/v262/wu24a.html - ssu = "standard" # usual one token per time-step inference using selective-scan update, no verification - - -class MTPLlambaConfig(PretrainedConfig): - r"""Configuration class for the CustomMamba model. - - This configuration is used to instantiate the CustomMamba model according to the specified arguments, - defining the model architecture. - - Args: - vocab_size (`int`, *optional*, defaults to 128256): - Vocabulary size of the model. - tie_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - pad_vocab_size_multiple (`int`, *optional*, defaults to 8): - Pad the vocabulary size up to the next multiple of this value. 
- lm_head_bias (`bool`, *optional*, defaults to `False`): - Whether the LM head includes a bias term. - d_model (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - lm_head_prenorm (`str`, *optional*, defaults to "rms"): - Normalization type for LM head. - n_layer (`int`, *optional*, defaults to 32): - Number of layers in the model. - resid_dropout (`float`, *optional*, defaults to 0.0): - Dropout rate for residual connections. - norm_epsilon (`float`, *optional*, defaults to 1e-5): - Epsilon value used for normalization layers. - mlp_cfg (`dict`, *optional*): - Configuration for the MLP (Multi-Layer Perceptron) layer, including intermediate size, activation function, and whether to use bias. - ssm_cfg (`dict`, *optional*): - Configuration for the SSM (State Space Model) layer, including d_state, number of heads, expansion, and other parameters. - - """ - - model_type = "llamba" - - def __init__( - self, - vocab_size: int, - d_model: int, - tie_embeddings: bool = False, - pad_vocab_size_multiple: int = 8, - lm_head_bias: bool = False, - n_layer: int = 32, - resid_dropout: float = 0.0, - norm_epsilon: float = 1e-5, - mlp_cfg: dict = None, - ssm_cfg: dict = None, - prediction_heads=1, - state_update_kernel: StateUpdateKernel = StateUpdateKernel.cs, - **kwargs, - ): - super().__init__(**kwargs) - - self.vocab_size = vocab_size - self.tie_embeddings = tie_embeddings - self.pad_vocab_size_multiple = pad_vocab_size_multiple - self.lm_head_bias = lm_head_bias - self.d_model = d_model - self.n_layer = n_layer - self.resid_dropout = resid_dropout - self.norm_epsilon = norm_epsilon - self.prediction_heads = prediction_heads - assert ( - state_update_kernel != StateUpdateKernel.ssu_verification - ), "Only chunk scan and standard modes are supported for now" - self.state_update_kernel = state_update_kernel - - # MLP (Multi-Layer Perceptron) Config - self.mlp_cfg = mlp_cfg or { - "intermediate_size": 14336, - "bias": False, - "act_fn": "silu", - } - - # SSM (State Space Model) Config - self.ssm_cfg = ssm_cfg or { - "d_state": 64, - "n_v_heads": 32, - "n_qk_heads": 32, - "expand": 1, - "chunk_size": 128, - "activation": "identity", - "bias": False, - } diff --git a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py deleted file mode 100644 index 6d9746db1..000000000 --- a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py +++ /dev/null @@ -1,389 +0,0 @@ -# Copyright (c) 2024, Kevin Li, Aviv Bick. - -import json -import os -from dataclasses import dataclass -from typing import Optional - -import torch -import torch.nn as nn -from huggingface_hub import PyTorchModelHubMixin -from mamba_ssm.utils.generation import GenerationMixin -from torch import Tensor, nn -from transformers.activations import ACT2FN -from transformers.utils.generic import ModelOutput - -from .configuration_mtp_llamba import MTPLlambaConfig as LlambaConfig -from .discrete_mamba2 import DiscreteMamba2 - - -class LlamaRMSNorm(nn.Module): - """LlamaRMSNorm (taken from transformers.models.llama.modeling_llama.LlamaRMSNorm).""" - - def __init__(self, hidden_size, eps=1e-6, factory_kwargs=None): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - """ - Args: - hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size). - - Returns: - torch.Tensor of shape (batch_size, seq_len, hidden_size). 
- """ - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - """Set the extra representation of the module.""" - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -class LlamaMLP(nn.Module): - """LlamaMLP (taken from transformers.models.llama.modeling_llama.LlamaMLP).""" - - def __init__(self, hidden_size, intermediate_size, bias, act_fn, factory_kwargs=None): - super().__init__() - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias, **factory_kwargs) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias, **factory_kwargs) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias, **factory_kwargs) - self.act_fn = ACT2FN[act_fn] - - def forward(self, x): - """ - Args: - x: torch.Tensor of shape (batch_size, seq_len, hidden_size). - - Returns: - torch.Tensor of shape (batch_size, seq_len, hidden_size). - """ - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -@dataclass -class CustomMambaCausalLMOutput(ModelOutput): - """Custom output class for MambaLMHeadModel.""" - - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None - last_hidden_state: Optional[torch.FloatTensor] = None - - -class MTPLlambaLMHeadModel(nn.Module, GenerationMixin, PyTorchModelHubMixin): - """MambaLM model with a language modeling head on top (linear layer).""" - - def __init__(self, config, initializer_cfg=None, device=None, dtype=None, **kwargs) -> None: - super().__init__() - - # Load config - if not isinstance(config, LlambaConfig): - config = LlambaConfig(**config) - self.config = config - - # Factory kwargs - factory_kwargs = {"device": device, "dtype": dtype} - - # Pad vocab size to be a multiple of pad_vocab_size_multiple - vocab_size = config.vocab_size - pad_vocab_size_multiple = config.pad_vocab_size_multiple - if vocab_size % pad_vocab_size_multiple != 0: - vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) - self.config.vocab_size = vocab_size - - # Mixer model - self.backbone = MixerModel( - input_size=vocab_size, - config=self.config, - initializer_cfg=initializer_cfg, - **factory_kwargs, - ) - - # MTP heads - self.mtp_heads = nn.ModuleList( - [ - Block( - config=config, - factory_kwargs=factory_kwargs, - layer_idx=layer_idx, - ).to(device) - for layer_idx in range(config.n_layer, config.n_layer + config.prediction_heads - 1) - ] - ) - - self.mtp_norms = nn.ModuleList( - [ - LlamaRMSNorm(config.d_model, eps=config.norm_epsilon, factory_kwargs=factory_kwargs) - for _ in range(config.prediction_heads - 1) - ] - ) - # LM head - if not self.config.tie_embeddings: - self.lm_head = nn.Linear( - in_features=self.config.d_model, - out_features=self.config.vocab_size, - bias=self.config.lm_head_bias, - **factory_kwargs, - ) - else: - self.lm_head = lambda x: x @ self.backbone.embedding.weight.t() - - def allocate_inference_cache(self, *args, **kwargs): - """Allocate inference cache for the model.""" - - mtps = { - i + self.config.n_layer: layer.allocate_inference_cache(*args, **kwargs) - for i, layer in enumerate(self.mtp_heads) - } - return 
{**self.backbone.allocate_inference_cache(*args, **kwargs), **mtps} - - def forward( - self, - input_ids, - position_ids=None, - return_hidden_states=False, - return_logits=True, - inference_params=None, - num_last_tokens=0, - ): - """ - Args: - input_ids: torch.Tensor of shape (batch_size, seq_len), - position_ids: torch.Tensor of shape (batch_size, seq_len), optional, not used (just for compatibility), - return_hidden_states: bool, optional, - return_logits: bool, optional, whether to compute the logits with the LM head, - inference_params: dict, optional, the model's inference cache, - num_last_tokens: int, optional. If > 0, only return the logits for the last n tokens. - - Returns: - CustomMambaCausalLMOutput. - - """ - outputs = self.backbone( - input_ids, - return_hidden_states=return_hidden_states, - inference_params=inference_params, - position_ids=position_ids, - ) - - # MTP heads processing - latents = [] - hidden_states = outputs["last_hidden_state"] - hidden_states_before_last = outputs["hidden_state_before_last"] - - # last layer already has layer norm applied - latents.append(hidden_states) - - # Process through MTP heads - for i, mtp_head in enumerate(self.mtp_heads): - mtp_outputs = mtp_head( - hidden_states_before_last, - inference_params=inference_params, - position_ids=position_ids, - ) - mtp_hidden_states = mtp_outputs["hidden_states"] - latents.append(self.mtp_norms[i](mtp_hidden_states)) - - # Stack the latents to get (batch_size, seq_len, num_prediction_heads, hidden_size) - stacked_latents = torch.stack(latents, dim=-2) - - if return_logits: - if isinstance(self.lm_head, nn.Linear): - # Apply lm_head to each prediction head's output - logits = self.lm_head(stacked_latents).float() - else: - # Using the tied embedding weights - logits = self.lm_head(stacked_latents) - - outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] - else: - outputs["logits"] = None - - return CustomMambaCausalLMOutput( - loss=None, - logits=outputs["logits"], - all_hidden_states=outputs["all_hidden_states"], - last_hidden_state=stacked_latents, - ) - - def save_pretrained(self, save_directory): - """ - Minimal implementation of save_pretrained for MambaLMHeadModel. - Save the model and its configuration file to a directory. 
- """ - # Ensure save_directory exists - if not os.path.exists(save_directory): - os.makedirs(save_directory) - - # Save the model's state_dict - model_path = os.path.join(save_directory, "pytorch_model.bin") - torch.save(self.state_dict(), model_path) - - # Save the configuration of the model - config_path = os.path.join(save_directory, "config.json") - with open(config_path, "w") as f: - json.dump(self.config.to_dict(), f) - - -class MixerModel(nn.Module): - """Mixer model with a stack of Mixer layers.""" - - def __init__(self, input_size, config=None, device=None, dtype=None, **kwargs) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.config = config - self.embedding = nn.Embedding(input_size, self.config.d_model, **factory_kwargs) - - self.layers = nn.ModuleList( - [ - Block( - config=config, - factory_kwargs=factory_kwargs, - layer_idx=i, - ).to(device) - for i in range(self.config.n_layer) - ] - ) - - self.final_layernorm = LlamaRMSNorm( - hidden_size=self.config.d_model, - eps=self.config.norm_epsilon, - factory_kwargs=factory_kwargs, - ) - - return - - def allocate_inference_cache(self, *args, **kwargs): - """Allocate inference cache for the model.""" - return {i: layer.allocate_inference_cache(*args, **kwargs) for i, layer in enumerate(self.layers)} - - def forward( - self, - input_ids, - return_hidden_states=False, - inference_params=None, - position_ids=None, - ): - """Run the model.""" - # Start running the layers - hidden_states = self.embedding(input_ids) - - # Initialize outputs - outputs = { - "last_hidden_state": None, - "hidden_state_before_last": None, - "all_hidden_states": (hidden_states,) if return_hidden_states else (), - } - - # Run the layers - for layer in self.layers: - layer_outputs = layer( - hidden_states, - inference_params=inference_params, - position_ids=position_ids, - ) - if layer == self.layers[-1]: - outputs["hidden_state_before_last"] = hidden_states - # Record outputs - hidden_states = layer_outputs["hidden_states"] - if return_hidden_states: - outputs["all_hidden_states"] += (hidden_states,) - - # Last layer, apply layer norm - outputs["last_hidden_state"] = self.final_layernorm(hidden_states) - return outputs - - -class Block(nn.Module): - """ - Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection. - - This Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA/MLP -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Add -> LN -> Mixer, returning both - the hidden_states (output of the mixer) and the residual. - This is purely for performance reasons, as we can fuse add and LayerNorm. - The residual needs to be provided (except for the very first block). 
- """ - - def __init__(self, config, factory_kwargs, layer_idx, **kwargs): - super().__init__() - self.config = config - self.layer_idx = layer_idx - - # Mixer - self.mixer = DiscreteMamba2( - d_model=self.config.d_model, - layer_idx=layer_idx, - **config.ssm_cfg, - **factory_kwargs, - ) - - # Other components - self.input_layernorm = LlamaRMSNorm(hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs) - self.post_attention_layernorm = LlamaRMSNorm( - hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs - ) - self.mlp = LlamaMLP( - hidden_size=self.config.d_model, - **config.mlp_cfg, - factory_kwargs=factory_kwargs, - ) - - def forward( - self, - hidden_states: Tensor, - inference_params=None, - **kwargs, - ): - """ - Pass the input through the encoder layer. - - Args: - hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), - inference_params: dict, optional, - - Returns: - dict with keys: - hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), - mamba_hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), - transfer_matrix: torch.Tensor of shape (batch_size, seq_len, seq_len). - """ - outputs = {} - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Apply Mixer - mixer_outputs = self.mixer( - hidden_states, - inference_params=inference_params, - ) - - hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs["hidden_states"] = hidden_states - - return outputs - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - """Allocate inference cache for the model.""" - if getattr(self.mixer, "allocate_inference_cache", None) is None: - return - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) diff --git a/fast_llm/models/ssm/huggingface.py b/fast_llm/models/ssm/huggingface.py deleted file mode 100644 index 24005ee9f..000000000 --- a/fast_llm/models/ssm/huggingface.py +++ /dev/null @@ -1,23 +0,0 @@ -import logging -import typing - -from fast_llm.engine.inference.config import HuggingfaceModelConfig -from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM -from fast_llm.models.ssm.config import HybridSSMModelConfig -from fast_llm.models.ssm.model import HybridSSMInferenceRunner, HybridSSMModel - -logger = logging.getLogger(__name__) - - -class HuggingfaceSSMModelConfig(HuggingfaceModelConfig): - model_type = "fast_llm_ssm" - model_config_class = HybridSSMModelConfig - fast_llm_config: HybridSSMModelConfig - - -class HuggingfaceHybridSSMModelForCausalLM(HuggingfaceGPTModelForCausalLM): - config_class = HuggingfaceSSMModelConfig - config: HuggingfaceSSMModelConfig - model_class = HybridSSMModel - runner_class: typing.ClassVar[type[HybridSSMInferenceRunner]] = HybridSSMInferenceRunner - _fast_llm_model: HybridSSMModel diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py deleted file mode 100644 index 0382462b5..000000000 --- a/fast_llm/models/ssm/model.py +++ /dev/null @@ -1,53 +0,0 @@ -import logging -import typing - -from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel -from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType - -logger = logging.getLogger(__name__) - 
- -class HybridSSMBaseModel[ConfigType: HybridSSMBaseModelConfig](GPTBaseModel[ConfigType]): - """ - A hybrid model that interleaves Transformer and Mamba blocks. - Right now only LlambaBlock is supported. - As for the mixer, transformer uses MHA. For the LlambaBlock we support Mamba1 and discrete mamba2. - """ - - def _get_block( - self, - block_index: int, - name: str, - return_input: bool = False, - ): - if block_index > self._config.transformer.num_layers: - # MTP block - block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] - else: - # Decoder block - block_type = self._config.hybrid_block_layout[block_index - 1] - - if block_type == SSMBlockType.transformer: - block_config = self._config.transformer - else: - block_config = self._config.transformer.from_dict(self._config.transformer, {"mixer": self._config.ssm}) - - return block_config.get_layer( - self._distributed_config, - hidden_dim=self._hidden_dim, - lr_scale=None, - peft=self._config.peft, - return_input=return_input, - ) - - -class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]): - """ - A hybrid model that combines Transformer and SSM blocks. - """ - - base_model_class: typing.ClassVar[type[HybridSSMBaseModel]] = HybridSSMBaseModel - - -class HybridSSMInferenceRunner(GPTInferenceRunner): - model_class: typing.ClassVar[type[HybridSSMModel]] = HybridSSMModel diff --git a/fast_llm/models/ssm/trainer.py b/fast_llm/models/ssm/trainer.py deleted file mode 100644 index 39f589384..000000000 --- a/fast_llm/models/ssm/trainer.py +++ /dev/null @@ -1,9 +0,0 @@ -import typing - -from fast_llm.models.gpt.trainer import GPTTrainer -from fast_llm.models.ssm.config import HybridSSMTrainerConfig -from fast_llm.models.ssm.model import HybridSSMModel - - -class HybridSSMTrainer[ConfigType: HybridSSMTrainerConfig](GPTTrainer[ConfigType]): - model_class: typing.ClassVar[type[HybridSSMModel]] = HybridSSMModel diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 4323efe3f..b709ea835 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -198,12 +198,12 @@ def global_to_local(self, tensor: torch.Tensor | SafeTensorSlice) -> torch.Tenso assert not self._reductions if tensor.ndim == 0: tensor = tensor[None] - Assert.eq(tensor.shape, self.global_shape) + Assert.eq(tensor.shape, self.global_shape, msg=self) for dim, tensor_dim in reversed(list(enumerate(self.dims))): tensor = tensor_dim.global_to_local(tensor, dim) - Assert.eq(tensor.shape, self.shape) + Assert.eq(tensor.shape, self.shape, msg=self) return tensor @classmethod diff --git a/fast_llm/utils.py b/fast_llm/utils.py index d13ecaf65..1f9feceb4 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -93,8 +93,9 @@ class Assert: @staticmethod def eq(x, *args, msg=None): + assert args for arg in args: - assert x == arg, f"{x} != {arg} " + (f"| {msg}" if msg else "") + assert x == arg, f"{x} != {arg} " + ("" if msg is None else f"| {msg}") @staticmethod def is_(x, y): @@ -457,3 +458,15 @@ def get_and_reset_memory_usage_mib( _global_max_reserved = max(max_reserved, _global_max_reserved) return report + + +def safe_merge_dicts(*dicts) -> dict: + out = {} + for dict_ in dicts: + for key, value in dict_.items(): + if key in out: + if isinstance(value, dict) and isinstance(out[key], dict): + out[key] = safe_merge_dicts(value, out[key]) + Assert.eq(value, out[key]) + out[key] = value + return out diff --git a/fast_llm_external_models/__init__.py b/fast_llm_external_models/__init__.py new file mode 100644 index 000000000..e69de29bb diff 
--git a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py b/fast_llm_external_models/apriel_hybrid_ssm/configuration_apriel_hybrid_ssm.py similarity index 89% rename from fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py rename to fast_llm_external_models/apriel_hybrid_ssm/configuration_apriel_hybrid_ssm.py index 98d2fc28d..12ee343ef 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/configuration_apriel_hybrid_ssm.py @@ -23,11 +23,13 @@ "dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, + "dt_proj_bias": True, + "repeat_kv_before_conv": True, } -class AprielSSMHybridConfig(MistralConfig): - model_type = "apriel_ssm_thinker_hybrid" +class AprielHybridSSMConfig(MistralConfig): + model_type = "apriel_hybrid_ssm" def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): super().__init__(**kwargs) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py similarity index 98% rename from fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py rename to fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py index 4fde72458..5c0a2216c 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py @@ -21,7 +21,7 @@ from transformers.utils import LossKwargs, logging from transformers.utils.generic import ModelOutput -from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig +from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig # from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_scan_fn as varlen_selective_scan_fn # from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn as varlen_causal_conv1d_fn @@ -46,7 +46,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class HybridMambaAttentionStaticCache(Cache): - def __init__(self, config: AprielSSMHybridConfig, batch_size, max_length, dtype=torch.float16, device=None): + def __init__(self, config: AprielHybridSSMConfig, batch_size, max_length, dtype=torch.float16, device=None): super().__init__() # config, batch_size, max_length, device, dtype) self.dtype = dtype self.hybrid_override_pattern = config.hybrid_block_layout @@ -231,7 +231,7 @@ class HybridMambaAttentionDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
""" - def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float16, device=None): + def __init__(self, config: AprielHybridSSMConfig, batch_size, dtype=torch.float16, device=None): super().__init__() self.dtype = dtype self.hybrid_override_pattern = config.hybrid_block_layout @@ -564,8 +564,7 @@ def forward( else: seqlen_offset = kwargs.get("seqlen_offset", cache_position[0]) if cache_position is not None else 0 use_precomputed_states = ( - past_key_value is not None - and past_key_value.has_previous_state + getattr(past_key_value, "has_previous_state", False) and seqlen == 1 and past_key_value.conv_states[self.layer_idx].shape[0] == past_key_value.ssm_states[self.layer_idx].shape[0] @@ -1130,7 +1129,7 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states class AprielSSMDecoderLayer(nn.Module): _mixer_class = DiscreteMamba2 - def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): + def __init__(self, config: AprielHybridSSMConfig, layer_idx: int, device=None, dtype=None, **kwargs): super().__init__(**kwargs) factory_kwargs = {"device": device, "dtype": dtype} self.hidden_size = config.hidden_size @@ -1179,7 +1178,7 @@ class AprielSSMM2DecoderLayer(AprielSSMDecoderLayer): class AprielHybridIdentity(nn.Module): - def __init__(self, config: AprielSSMHybridConfig): + def __init__(self, config: AprielHybridSSMConfig): super().__init__() self.config = config @@ -1187,14 +1186,14 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): return (hidden_states,) -class AprielThinkerSSMHybridModel(MistralModel): +class AprielHybridSSMModel(MistralModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] Args: - config: AprielSSMHybridConfig + config: AprielHybridSSMConfig """ - def __init__(self, config: AprielSSMHybridConfig, **kwargs): + def __init__(self, config: AprielHybridSSMConfig, **kwargs): config_copy = copy.deepcopy(config) config_copy.num_hidden_layers = 0 super().__init__(config_copy, **kwargs) @@ -1221,8 +1220,8 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
-class AprielThinkerSSMHybridPreTrainedModel(PreTrainedModel): - config_class = AprielSSMHybridConfig +class AprielHybridSSMPreTrainedModel(PreTrainedModel): + config_class = AprielHybridSSMConfig base_model_prefix = "model" _no_split_modules = ["MistralDecoderLayer", "AprielSSMDecoderLayer", "AprielSSMM2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] @@ -1248,13 +1247,13 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) -class AprielThinkerSSMHybridForCausalLM(AprielThinkerSSMHybridPreTrainedModel, GenerationMixin): +class AprielHybridSSMForCausalLM(AprielHybridSSMPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - def __init__(self, config: AprielSSMHybridConfig, **kwargs): + def __init__(self, config: AprielHybridSSMConfig, **kwargs): super().__init__(config, **kwargs) - self.model = AprielThinkerSSMHybridModel(config) + self.model = AprielHybridSSMModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -1419,7 +1418,7 @@ def forward( __all__ = [ - "AprielThinkerSSMHybridForCausalLM", - "AprielThinkerSSMHybridModel", - "AprielThinkerSSMHybridPreTrainedModel", + "AprielHybridSSMForCausalLM", + "AprielHybridSSMModel", + "AprielHybridSSMPreTrainedModel", ] diff --git a/fast_llm/models/gpt/external/diffusion_dream/configuration_dream.py b/fast_llm_external_models/diffusion_dream/configuration_dream.py similarity index 100% rename from fast_llm/models/gpt/external/diffusion_dream/configuration_dream.py rename to fast_llm_external_models/diffusion_dream/configuration_dream.py diff --git a/fast_llm/models/gpt/external/diffusion_dream/generation_config.json b/fast_llm_external_models/diffusion_dream/generation_config.json similarity index 100% rename from fast_llm/models/gpt/external/diffusion_dream/generation_config.json rename to fast_llm_external_models/diffusion_dream/generation_config.json diff --git a/fast_llm/models/gpt/external/diffusion_dream/generation_utils.py b/fast_llm_external_models/diffusion_dream/generation_utils.py similarity index 100% rename from fast_llm/models/gpt/external/diffusion_dream/generation_utils.py rename to fast_llm_external_models/diffusion_dream/generation_utils.py diff --git a/fast_llm/models/gpt/external/diffusion_dream/modeling_dream.py b/fast_llm_external_models/diffusion_dream/modeling_dream.py similarity index 95% rename from fast_llm/models/gpt/external/diffusion_dream/modeling_dream.py rename to fast_llm_external_models/diffusion_dream/modeling_dream.py index e041d6189..714576eeb 100644 --- a/fast_llm/models/gpt/external/diffusion_dream/modeling_dream.py +++ b/fast_llm_external_models/diffusion_dream/modeling_dream.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved. 
# # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX @@ -20,37 +19,26 @@ """PyTorch Dream model.""" import math -from typing import List, Optional, Tuple, Union import os +from dataclasses import dataclass +from typing import Optional, Union + import torch import torch.utils.checkpoint from torch import nn -from dataclasses import dataclass - +from transformers import PretrainedConfig from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPast, - MaskedLMOutput, -) -from transformers.utils import ModelOutput +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, MaskedLMOutput from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, -) -from transformers import PretrainedConfig +from transformers.utils import ModelOutput, is_flash_attn_2_available, logging + from .configuration_dream import DreamConfig -from .generation_utils import DreamGenerationMixin, DreamGenerationConfig +from .generation_utils import DreamGenerationConfig, DreamGenerationMixin if is_flash_attn_2_available(): - from transformers.modeling_flash_attention_utils import _flash_attention_forward - from flash_attn import flash_attn_with_kvcache, flash_attn_func + from flash_attn import flash_attn_with_kvcache logger = logging.get_logger(__name__) @@ -131,7 +119,6 @@ def reset_parameters(self): inv_freq, self.attention_scaling = self.rope_init_fn(self.config, self.inv_freq.device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq - def _dynamic_frequency_update(self, position_ids, device): """ @@ -287,8 +274,8 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() # Luke: Computing K and Vs for all tokens upto now q_len w/o using cache ? @@ -345,7 +332,7 @@ def forward( attn_weights = None return attn_output, attn_weights, past_key_value - + class DreamSdpaAttention(DreamAttention): """ @@ -364,9 +351,9 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 is_causal: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
logger.warning_once( @@ -384,7 +371,6 @@ def forward( bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) @@ -405,18 +391,16 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings - + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and attention_mask is not None: @@ -435,9 +419,9 @@ def forward( value_states, attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None, dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) @@ -445,6 +429,7 @@ def forward( return attn_output, None, past_key_value + class DreamFlashAttention(DreamAttention): """ Dream attention module using Flash attention 2. @@ -460,9 +445,9 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 is_causal: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
logger.warning_once( @@ -478,7 +463,7 @@ def forward( ) bsz, q_len, _ = hidden_states.size() - + query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) @@ -489,7 +474,7 @@ def forward( # print(f"hidden_states: {hidden_states.shape} query_states {query_states.shape} key_states {key_states.shape} value_states {value_states.shape}") # print(f"position_ids {position_ids} {position_ids.shape}") - + if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " @@ -500,18 +485,16 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings - + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - + key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - # if query_states.device.type == "cuda" and attention_mask is not None: # query_states = query_states.contiguous() # key_states = key_states.contiguous() @@ -529,19 +512,19 @@ def forward( # attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None, # dropout_p=self.attention_dropout if self.training else 0.0, # is_causal=False, # hard coded - # ) - + # ) + # print(f"query_states {query_states.shape} key_states {key_states.shape} value_states {value_states.shape}") - + # replacing with flash attention attn_output = flash_attn_with_kvcache( # q dim (batch_size, seqlen, nheads, headdim) q=query_states.transpose(1, 2).contiguous(), k_cache=key_states.transpose(1, 2).contiguous(), v_cache=value_states.transpose(1, 2).contiguous(), - causal=is_causal, # hard coded + causal=is_causal, # hard coded softmax_scale=1.0 / math.sqrt(self.head_dim), - ) + ) attn_output = attn_output.view(bsz, q_len, self.hidden_size) @@ -549,6 +532,7 @@ def forward( return attn_output, None, past_key_value + class DreamDecoderLayer(nn.Module): def __init__(self, config: DreamConfig, layer_idx: int): super().__init__() @@ -559,7 +543,7 @@ def __init__(self, config: DreamConfig, layer_idx: int): f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." 
) - + # self.self_attn = Dream_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) if config._attn_implementation == "flash_attention_2": self.self_attn = DreamFlashAttention(config, layer_idx) @@ -575,13 +559,13 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -605,9 +589,9 @@ def forward( """ # print(f"DreamDecoderLayer: past_key_value {past_key_value} use_cache {use_cache}") - + is_casual = kwargs.get("is_casual", False) - + residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -643,9 +627,10 @@ def forward( # When use_cache is True, outputs will have length: # - 2 if output_attentions is False (hidden_states, present_key_value) # - 3 if output_attentions is True (hidden_states, self_attn_weights, present_key_value) - # print(f"DreamDecoderLayer: outputs {len(outputs)}") + # print(f"DreamDecoderLayer: outputs {len(outputs)}") return outputs + class DreamPreTrainedModel(PreTrainedModel): config_class = DreamConfig base_model_prefix = "model" @@ -700,7 +685,7 @@ def from_pretrained( **kwargs, ) # NOTE(Lin): we need to override the generation config - # because the generation config loaded in `from_pretrained` + # because the generation config loaded in `from_pretrained` # does not include all the attributes of DreamGenerationConfig resume_download = kwargs.get("resume_download", None) proxies = kwargs.get("proxies", None) @@ -722,6 +707,7 @@ def from_pretrained( ) return _model + class DreamBaseModel(DreamPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`DreamDecoderLayer`] @@ -758,7 +744,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -766,13 +752,13 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, is_casual: Optional[bool] = False, - ) -> Union[Tuple, BaseModelOutput]: + ) -> Union[tuple, BaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - + # print("DreamBaseModel: past_key_values", past_key_values, "use_cache", use_cache,) return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -789,7 +775,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - + if use_cache and past_key_values is None: past_key_values = DynamicCache() @@ -867,6 +853,8 @@ def forward( hidden_states=all_hidden_states, attentions=all_self_attns, ) + + @dataclass class MaskedLMOutputWithPast(ModelOutput): """ @@ -891,16 +879,17 @@ class MaskedLMOutputWithPast(ModelOutput): past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
""" loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - past_key_values: Optional[Tuple[Cache]] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + past_key_values: Optional[tuple[Cache]] = None + class DreamModel(DreamGenerationMixin, DreamPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] @@ -942,7 +931,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -952,7 +941,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, **loss_kwargs, - ) -> Union[Tuple, MaskedLMOutput]: + ) -> Union[tuple, MaskedLMOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -993,4 +982,4 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, past_key_values=outputs.past_key_values, - ) \ No newline at end of file + ) diff --git a/fast_llm/models/gpt/external/diffusion_llama/configuration_diffusion_llama.py b/fast_llm_external_models/diffusion_llama/configuration_diffusion_llama.py similarity index 100% rename from fast_llm/models/gpt/external/diffusion_llama/configuration_diffusion_llama.py rename to fast_llm_external_models/diffusion_llama/configuration_diffusion_llama.py diff --git a/fast_llm/models/gpt/external/diffusion_llama/generation_utils.py b/fast_llm_external_models/diffusion_llama/generation_utils.py similarity index 100% rename from fast_llm/models/gpt/external/diffusion_llama/generation_utils.py rename to fast_llm_external_models/diffusion_llama/generation_utils.py diff --git a/fast_llm/models/gpt/external/diffusion_llama/modeling_diffusion_llama.py b/fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py similarity index 99% rename from fast_llm/models/gpt/external/diffusion_llama/modeling_diffusion_llama.py rename to fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py index 5e613093e..c8723af5d 100644 --- a/fast_llm/models/gpt/external/diffusion_llama/modeling_diffusion_llama.py +++ b/fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py @@ -1,7 +1,3 @@ -import math -import os -from dataclasses import dataclass - # Copyright 2022 ServiceNow. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX @@ -20,6 +16,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +import math +import os +from dataclasses import dataclass from typing import Callable, Optional, Union import torch @@ -30,20 +30,12 @@ from transformers.cache_utils import Cache, DynamicCache from transformers.integrations import use_kernel_forward_from_hub from transformers.modeling_flash_attention_utils import FlashAttentionKwargs - -# from transformers.modeling_layers import GradientCheckpointingLayer # Update transformer from transformers.modeling_outputs import BaseModelOutputWithPast, MaskedLMOutput from transformers.modeling_rope_utils import dynamic_rope_update from transformers.modeling_utils import PreTrainedModel from transformers.processing_utils import Unpack from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS -from transformers.utils import ( # auto_docstring - LossKwargs, - ModelOutput, - can_return_tuple, - is_torch_flex_attn_available, - logging, -) +from transformers.utils import ModelOutput, can_return_tuple, is_torch_flex_attn_available, logging from .configuration_diffusion_llama import ROPE_INIT_FUNCTIONS, DiffusionLlamaConfig from .generation_utils import SLAMGenerationConfig, SLAMGenerationMixin diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm_external_models/eval/apriel_eval_wrapper.py similarity index 93% rename from fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py rename to fast_llm_external_models/eval/apriel_eval_wrapper.py index ee2c83e03..2405175b6 100644 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ b/fast_llm_external_models/eval/apriel_eval_wrapper.py @@ -54,13 +54,13 @@ def __init__(self, pretrained, **kwargs) -> None: def _get_config(self, pretrained: str, **kwargs) -> None: """Get the model configuration.""" - from fast_llm.models.ssm.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig + from fast_llm_external_models.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig self._config = AprielSSMConfig.from_pretrained(pretrained) def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: """Create the model.""" - from fast_llm.models.ssm.external.apriel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM + from fast_llm_external_models.apriel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM # Ensure we're using the correct device device = _get_device() @@ -121,13 +121,13 @@ def __init__(self, pretrained, **kwargs) -> None: def _get_config(self, pretrained: str, **kwargs) -> None: """Get the model configuration.""" - from fast_llm.models.ssm.external.apriel_hybrid.configuration_ssm_hybrid_apriel import AprielSSMHybridConfig + from fast_llm_external_models.apriel_hybrid.configuration_ssm_hybrid_apriel import AprielSSMHybridConfig self._config = AprielSSMHybridConfig.from_pretrained(pretrained, trust_remote_code=True) def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: """Create the model.""" - from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM + from fast_llm_external_models.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM # Ensure we're using the correct device device = _get_device() @@ -194,15 +194,13 @@ def __init__(self, pretrained, **kwargs) -> None: def _get_config(self, pretrained: str, **kwargs) -> None: """Get the model configuration.""" - from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import ( - 
AprielSSMHybridConfig, - ) + from fast_llm_external_models.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig self._config = AprielSSMHybridConfig.from_pretrained(pretrained, trust_remote_code=True) def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: """Create the model.""" - from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + from fast_llm_external_models.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( AprielThinkerSSMHybridForCausalLM, ) diff --git a/fast_llm/models/ssm/external/eval/run_evalchemy.py b/fast_llm_external_models/eval/run_evalchemy.py similarity index 66% rename from fast_llm/models/ssm/external/eval/run_evalchemy.py rename to fast_llm_external_models/eval/run_evalchemy.py index 1cbb5b4da..2758c9ee1 100644 --- a/fast_llm/models/ssm/external/eval/run_evalchemy.py +++ b/fast_llm_external_models/eval/run_evalchemy.py @@ -1,5 +1,6 @@ from eval.eval import cli_evaluate -from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import ( # noqa: F401 + +from fast_llm_external_models.eval.apriel_eval_wrapper import ( # noqa: F401 AprielHybrid15bSSMWrapper, AprielHybridSSMWrapper, AprielSSMWrapper, diff --git a/fast_llm/models/ssm/external/eval/run_lm_eval.py b/fast_llm_external_models/eval/run_lm_eval.py similarity index 67% rename from fast_llm/models/ssm/external/eval/run_lm_eval.py rename to fast_llm_external_models/eval/run_lm_eval.py index 53c0febab..8d37584c4 100644 --- a/fast_llm/models/ssm/external/eval/run_lm_eval.py +++ b/fast_llm_external_models/eval/run_lm_eval.py @@ -1,6 +1,6 @@ from lm_eval.__main__ import cli_evaluate -from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import ( # noqa: F401 +from fast_llm_external_models.eval.apriel_eval_wrapper import ( # noqa: F401 AprielHybrid15bSSMWrapper, AprielHybridSSMWrapper, AprielSSMWrapper, diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_importance_15b_mil.py b/fast_llm_external_models/make_hybrid_checkpoint_with_importance_15b_mil.py similarity index 96% rename from fast_llm/models/ssm/external/make_hybrid_checkpoint_with_importance_15b_mil.py rename to fast_llm_external_models/make_hybrid_checkpoint_with_importance_15b_mil.py index dde11cfbc..f5d09da61 100644 --- a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_importance_15b_mil.py +++ b/fast_llm_external_models/make_hybrid_checkpoint_with_importance_15b_mil.py @@ -3,8 +3,8 @@ import transformers from transformers import AutoConfig, AutoModelForCausalLM -from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig -from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( +from fast_llm_external_models.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig +from fast_llm_external_models.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( AprielSSMM2DecoderLayer, AprielThinkerSSMHybridForCausalLM, ) diff --git a/fast_llm/models/gpt/external/mtp_llama/configuration_mtp_llama.py b/fast_llm_external_models/mtp_llama/configuration_mtp_llama.py similarity index 100% rename from fast_llm/models/gpt/external/mtp_llama/configuration_mtp_llama.py rename to fast_llm_external_models/mtp_llama/configuration_mtp_llama.py diff --git a/fast_llm/models/gpt/external/mtp_llama/modeling_mtp_llama.py b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py similarity index 100% rename from 
fast_llm/models/gpt/external/mtp_llama/modeling_mtp_llama.py rename to fast_llm_external_models/mtp_llama/modeling_mtp_llama.py diff --git a/setup.cfg b/setup.cfg index 843aa15ca..77073ab55 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,9 @@ name = fast_llm [options] -packages = find_namespace: +packages = + fast_llm + fast_llm_external_models include_package_data = True python_requires = >=3.12 install_requires = diff --git a/tests/conftest.py b/tests/conftest.py index 86937326c..58301919f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,6 +34,10 @@ from tests.utils.model_configs import model_testing_config, ModelTestingConfig, testing_group_enabled # isort: skip from tests.utils.utils import result_path, format_resource_report, report_subtest # isort: skip +# Import all dynamic classes. +import fast_llm.cli # isort: skip + + logger = logging.getLogger(__name__) manager: DependencyManager | None = None diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index e402659b0..d52564cc0 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -162,11 +162,13 @@ def test_lm_head( ): config = GPTBaseModelConfig.from_dict( { - "transformer": { + "decoder": { + "num_blocks": 0, + }, + "embeddings_layer": { + "vocab_size": VOCAB_SIZE, "hidden_size": HIDDEN_SIZE, - "num_layers": 0, }, - "embeddings_layer": {"vocab_size": VOCAB_SIZE}, "output_layer": { "cross_entropy_implementation": cross_entropy_impl, "normalization": {"type": "rms_norm"}, @@ -239,7 +241,7 @@ def test_lm_head( torch.empty( VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.training_dtype.torch, device=distributed.device ) - .normal_(config.transformer.hidden_size**-0.5) + .normal_(config.embeddings_layer.hidden_size**-0.5) .requires_grad_(True) ) kwargs[WORD_EMBEDDINGS_WEIGHT if config.output_layer.tied_weight else OUTPUT_WEIGHTS] = logit_weight diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py index 51687c6d8..217ecd0e1 100644 --- a/tests/models/distributed_test_checkpoint.py +++ b/tests/models/distributed_test_checkpoint.py @@ -63,6 +63,8 @@ def main(args: list[str] | None = None) -> None: group = pool.get_process_group(range(world_size), rank) for config in DISTRIBUTED_SAVE_LOAD_CONFIGS.values(): + if config.load_format == "{checkpoint_format}" and model_testing_config.checkpoint_format is None: + continue config = config.resolve(base_path, model_testing_config) Assert.eq(world_size, config.num_gpus) with DistributedSubtestContext(base_path, config.name, group, world_size, enabled=do_capture) as subtest: diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index ed911fc8a..714abc130 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -119,31 +119,38 @@ def test_conversion(model_testing_config, run_conversion, get_convert_path): DistributedCheckpointFormat, FastLLMCheckpointFormat, ) - run_conversion( - get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat), - FastLLMCheckpointFormat, - model_testing_config.checkpoint_format, - ) - run_conversion( - get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat), - model_testing_config.checkpoint_format, - DistributedCheckpointFormat, - ) - run_conversion( - get_convert_path(), - DistributedCheckpointFormat, - model_testing_config.checkpoint_format, - ) - run_conversion( - get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat), - 
model_testing_config.checkpoint_format, - FastLLMCheckpointFormat, - ) - run_conversion( - get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format), - FastLLMCheckpointFormat, - DistributedCheckpointFormat, - ) + if model_testing_config.checkpoint_format is None: + run_conversion( + get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat), + FastLLMCheckpointFormat, + DistributedCheckpointFormat, + ) + else: + run_conversion( + get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat), + FastLLMCheckpointFormat, + model_testing_config.checkpoint_format, + ) + run_conversion( + get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat), + model_testing_config.checkpoint_format, + DistributedCheckpointFormat, + ) + run_conversion( + get_convert_path(), + DistributedCheckpointFormat, + model_testing_config.checkpoint_format, + ) + run_conversion( + get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat), + model_testing_config.checkpoint_format, + FastLLMCheckpointFormat, + ) + run_conversion( + get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format), + FastLLMCheckpointFormat, + DistributedCheckpointFormat, + ) def _compare_safetensor_files( @@ -170,20 +177,29 @@ def _compare_safetensor_files( @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_converted_round_trip(model_testing_config, get_convert_path): # Test that the various possible conversion paths yield identical results. - _compare_safetensor_files( - get_convert_path() / "rank_0.safetensors", - get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat) / "rank_0.safetensors", - get_convert_path(DistributedCheckpointFormat, model_testing_config.checkpoint_format) / "rank_0.safetensors", - expected_keys={_WEIGHT_SHARD_SAVE_NAME}, - ) - _compare_safetensor_files( - get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) / "model_0.safetensors", - get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format) / "model_0.safetensors", - ) - _compare_safetensor_files( - get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) / "model_0.safetensors", - get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat) / "model_0.safetensors", - ) + if model_testing_config.checkpoint_format is None: + _compare_safetensor_files( + get_convert_path() / "rank_0.safetensors", + get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat) / "rank_0.safetensors", + expected_keys={_WEIGHT_SHARD_SAVE_NAME}, + ) + else: + _compare_safetensor_files( + get_convert_path() / "rank_0.safetensors", + get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat) / "rank_0.safetensors", + get_convert_path(DistributedCheckpointFormat, model_testing_config.checkpoint_format) + / "rank_0.safetensors", + expected_keys={_WEIGHT_SHARD_SAVE_NAME}, + ) + _compare_safetensor_files( + get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) / "model_0.safetensors", + get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format) / "model_0.safetensors", + ) + _compare_safetensor_files( + get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) + / "model_0.safetensors", + get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat) / "model_0.safetensors", + ) def _compare_model_configs(config_ref: FastLLMModelConfig, config_test: 
FastLLMModelConfig): @@ -223,6 +239,24 @@ def test_load_pretrained( reference_config = model_testing_config.model_config_class.from_dict( yaml.safe_load(get_convert_path().parents[1].joinpath("config.yaml").open("r"))["model"] ) + reference_shard = safetensors.torch.load_file(get_convert_path() / "rank_0.safetensors", device="cuda")[ + _WEIGHT_SHARD_SAVE_NAME + ] + load_and_compare_checkpoints( + FastLLMCheckpointFormat, + get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat), + reference_config, + reference_shard, + ) + if model_testing_config.checkpoint_format is None: + load_and_compare_checkpoints( + DistributedCheckpointFormat, + get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat), + reference_config, + reference_shard, + ) + return + reference_config_from_hf = model_testing_config.model_config_class.from_dict( { "base_model": yaml.safe_load( @@ -234,10 +268,6 @@ def test_load_pretrained( ) _compare_architectures(reference_config, reference_config_from_hf) - reference_shard = safetensors.torch.load_file(get_convert_path() / "rank_0.safetensors", device="cuda")[ - _WEIGHT_SHARD_SAVE_NAME - ] - load_and_compare_checkpoints(DistributedCheckpointFormat, get_convert_path(), reference_config, reference_shard) load_and_compare_checkpoints( @@ -253,12 +283,6 @@ def test_load_pretrained( reference_shard, ) - load_and_compare_checkpoints( - FastLLMCheckpointFormat, - get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat), - reference_config, - reference_shard, - ) load_and_compare_checkpoints( FastLLMCheckpointFormat, get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format), @@ -284,6 +308,8 @@ def test_load_pretrained( @pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_huggingface_model(model_testing_config, get_convert_path): + if model_testing_config.checkpoint_format is None: + return # Test that Fast-LLM's Hugging Face wrapper produces the same results as the converted Hugging Face model. # TODO: Stress the importance of this test as the main correctness test for most models. # TODO: Review test. Move to test_generate? @@ -354,7 +380,8 @@ def test_save_and_load_in_parallel(run_distributed_script, run_test_script_base_ import tests.models.distributed_test_checkpoint script = [ - tests.models.distributed_test_checkpoint.__file__, + "-m", + tests.models.distributed_test_checkpoint.__name__, str(run_test_script_base_path), model_testing_config.name, ] @@ -388,6 +415,11 @@ def test_load_parallel_checkpoint_in_single_gpu( reference_distributed_shard, report_subtest, ): + if ( + model_testing_config.checkpoint_format is None + and distributed_save_load_config.load_format == "{checkpoint_format}" + ): + return # This should only happen when test is skipped (failed dependency). 
assert reference_distributed_shard is not None distributed_save_load_config = distributed_save_load_config.resolve( @@ -416,11 +448,8 @@ def test_parallel_checkpoint_consistency(model_testing_config, run_test_script_b .resolve(base_path=run_test_script_base_path, model_testing_config=model_testing_config) .save_path / f"{DistributedCheckpointFormat.name}/rank_{rank}.safetensors" - for format_ in ( - DistributedCheckpointFormat.name, - FastLLMCheckpointFormat.name, - "{checkpoint_format}", - ) + for format_ in (DistributedCheckpointFormat.name, FastLLMCheckpointFormat.name) + + (() if model_testing_config.checkpoint_format is None else ("{checkpoint_format}",)) ] ) diff --git a/tests/models/test_generate.py b/tests/models/test_generate.py index 7f0b902f8..ad0de47e6 100644 --- a/tests/models/test_generate.py +++ b/tests/models/test_generate.py @@ -7,7 +7,8 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.schedule.config import ScheduleConfig from fast_llm.engine.schedule.runner import ScheduleRunner -from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat, PretrainedGPTModelConfig +from fast_llm.models.gpt.config import PretrainedGPTModelConfig +from fast_llm.models.gpt.conversion.config import LlamaCheckpointFormat from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -55,14 +56,14 @@ def _get_hf_model(model_path: str, use_flash_attention: bool, use_bf16: bool): def _get_fast_llm_model( - model_path: str, use_flash_attention: bool, use_bf16: bool, checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat + model_path: str, use_flash_attention: bool, use_bf16: bool, checkpoint_format=LlamaCheckpointFormat ): updates = {} if use_flash_attention: - updates[("base_model", "transformer", "use_flash_attention")] = True + updates[("base_model", "decoder", "block", "mixer", "use_flash_attention")] = True updates[("distributed", "training_dtype")] = "bf16" else: - updates[("base_model", "transformer", "use_flash_attention")] = False + updates[("base_model", "decoder", "block", "mixer", "use_flash_attention")] = False if use_bf16: updates[("distributed", "training_dtype")] = "bf16" return HuggingfaceGPTModelForCausalLM.from_pretrained( @@ -76,7 +77,7 @@ def _get_fast_llm_model( def _get_fast_llm_model_from_model( - model_path: str, use_flash_attention: bool, use_bf16: bool, checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat + model_path: str, use_flash_attention: bool, use_bf16: bool, checkpoint_format=LlamaCheckpointFormat ): updates = { ("pretrained", "path"): model_path, @@ -85,10 +86,10 @@ def _get_fast_llm_model_from_model( } if use_flash_attention: - updates[("model", "base_model", "transformer", "use_flash_attention")] = True + updates[("model", "base_model", "decoder", "block", "mixer", "use_flash_attention")] = True updates[("model", "distributed", "training_dtype")] = "bf16" else: - updates[("model", "base_model", "transformer", "use_flash_attention")] = False + updates[("model", "base_model", "decoder", "block", "mixer", "use_flash_attention")] = False if use_bf16: updates[("model", "distributed", "training_dtype")] = "bf16" @@ -227,7 +228,7 @@ def test_generate( ): _test_generate( model_path, - LlamaGPTHuggingfaceCheckpointFormat, + LlamaCheckpointFormat, use_flash_attention, use_bf16, max_new_tokens, @@ -311,9 +312,7 @@ def _test_generate_from_model(model_path, tokenizer, fast_llm_checkpoint_format) def 
test_generate_from_model( model_path, ): - _test_generate_from_model( - model_path, AutoTokenizer.from_pretrained(model_path), LlamaGPTHuggingfaceCheckpointFormat - ) + _test_generate_from_model(model_path, AutoTokenizer.from_pretrained(model_path), LlamaCheckpointFormat) @requires_cuda @@ -353,16 +352,14 @@ def _test_forward_return_hidden_states( ) # hidden_states include embeddings layer - assert ( - len(res_fast_llm.hidden_states) - 1 == fast_llm_model.config.fast_llm_config.base_model.transformer.num_layers - ) + assert len(res_fast_llm.hidden_states) - 1 == len(fast_llm_model.config.fast_llm_config.base_model.decoder) @pytest.mark.extra_slow @requires_cuda def test_forward_return_hidden_states(model_path): _test_forward_return_hidden_states( - model_path, LlamaGPTHuggingfaceCheckpointFormat, AutoTokenizer.from_pretrained(model_path).vocab_size + model_path, LlamaCheckpointFormat, AutoTokenizer.from_pretrained(model_path).vocab_size ) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 5c4897646..d14721142 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -57,7 +57,12 @@ def test_and_compare_model( def test_run_model_distributed(run_distributed_script, model_testing_config, run_test_script_base_path, request): import tests.models.distributed_test_model - script = [tests.models.distributed_test_model.__file__, str(run_test_script_base_path), model_testing_config.name] + script = [ + "-m", + tests.models.distributed_test_model.__name__, + str(run_test_script_base_path), + model_testing_config.name, + ] if request.config.getoption("distributed_capture"): logger.warning( "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." diff --git a/tests/test_config.py b/tests/test_config.py index 03d535520..4e73569b3 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -74,20 +74,24 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): pretrained_model_config = GPTModelConfig.from_dict( { "base_model": { - "transformer": { - "mixer": { - "rotary": {"type": "default"}, - "window_size": 32, - "head_groups": 4, - }, - "mlp": { - "intermediate_size": 4096, # Implicit default, default value - "activation": "silu", # Implicit default, non-default value - }, - "normalization": {"type": "rms_norm"}, # Nested - "num_layers": 12, # Default + "embeddings_layer": { "hidden_size": 1024, # Default }, + "decoder": { + "block": { + "mixer": { + "rotary": {"type": "default"}, + "window_size": 32, + "head_groups": 4, + }, + "mlp": { + "intermediate_size": 4096, # Implicit default, default value + "activation": "silu", # Implicit default, non-default value + }, + "normalization": {"type": "rms_norm"}, # Nested + }, + "num_blocks": 12, # Default + }, "output_layer": {"tied_weight": False}, }, "multi_stage": {"zero_stage": 3}, @@ -101,15 +105,16 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): pretrained_model_config.save_metadata(save_config) base_model_update = { - "transformer": { - "mixer": { - "head_groups": 1, # Override to default + "embeddings_layer": {"hidden_size": 512, "vocab_size": 1000}, + "decoder": { + "block": { + "mixer": { + "head_groups": 1, # Override to default + }, + # rotary: Don't override nested. + "normalization": {"implementation": "triton"}, # Update non-default nested }, - # rotary: Don't override nested. 
- "normalization": {"implementation": "triton"}, # Update non-default nested - "hidden_size": 512, # Override, affects derived value (kv channels) }, - "embeddings_layer": {"vocab_size": 1000}, "peft": {"type": "lora", "freeze_others": False}, # Update default nested, change type } pretrained_config = PretrainedGPTModelConfig.from_dict( @@ -129,36 +134,43 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): expected_config["distributed"].update({"seed": 1234, "training_dtype": "float16"}) if load_config in (ModelConfigType.fast_llm, ModelConfigType.model): expected_config["base_model"] = { - "transformer": { - "mixer": { - "type": "attention", - "rotary": {"type": "default"}, - "window_size": 32, - "head_groups": 1, - }, - "mlp": { - "type": "mlp", - "intermediate_size": 4096, # Implicit default, default value - "activation": "silu", # Implicit default, non-default value - }, - "normalization": {"type": "rms_norm", "implementation": "triton"}, - "num_layers": 12, + "embeddings_layer": { "hidden_size": 512, + "vocab_size": 1000, + }, + "decoder": { + "type": "fixed", + "block": { + "type": "decoder", + "mixer": { + "type": "attention", + "rotary": {"type": "default"}, + "window_size": 32, + "head_groups": 1, + }, + "mlp": { + "type": "mlp", + "intermediate_size": 4096, # Implicit default, default value + "activation": "silu", # Implicit default, non-default value + }, + "normalization": {"type": "rms_norm", "implementation": "triton"}, + }, + "num_blocks": 12, }, - "embeddings_layer": {"vocab_size": 1000}, "output_layer": {"tied_weight": False, "normalization": {"type": "layer_norm"}}, "peft": {"type": "lora", "freeze_others": False}, } else: - base_model_update["transformer"]["normalization"]["type"] = "layer_norm" - base_model_update["transformer"]["mixer"]["type"] = "attention" - base_model_update["transformer"]["mixer"]["rotary"] = {"type": "none"} - base_model_update["transformer"]["mlp"] = {"type": "mlp"} + base_model_update["decoder"]["type"] = "fixed" + base_model_update["decoder"]["block"]["type"] = "decoder" + base_model_update["decoder"]["block"]["normalization"]["type"] = "layer_norm" + base_model_update["decoder"]["block"]["mixer"]["type"] = "attention" + base_model_update["decoder"]["block"]["mixer"]["rotary"] = {"type": "none"} + base_model_update["decoder"]["block"]["mlp"] = {"type": "mlp"} base_model_update["output_layer"] = {"normalization": {"type": "layer_norm"}} base_model_update["peft"] = {"type": "lora", "freeze_others": False} expected_config["base_model"] = base_model_update - print("IKEUFGH", serialized_config, expected_config) check_equal_nested(serialized_config, expected_config) @@ -276,6 +288,6 @@ def test_distributed_global_ranks(bdp: int, sdp: int, tp: int, pp: int, pipeline # Check that the global ranks are partitioned into disjoint groups for each distributed dimension, # and indirectly that `DistributedDim.global_ranks` is consistent between ranks. 
Assert.eq(sum(len(global_ranks) for global_ranks in global_ranks_set), world_size) - Assert.eq(len({global_rank for global_ranks in global_ranks_set for global_rank in global_ranks})) + Assert.eq(len({global_rank for global_ranks in global_ranks_set for global_rank in global_ranks}), world_size) Assert.eq(len(rank_breakdowns), world_size) diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index 1b49dcfcc..cc5a60a8a 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -1,59 +1,79 @@ +import copy + import pytest from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.training.config import TrainerConfig -from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.block.block import Block +from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.decoder.block import DecoderBlock from fast_llm.utils import Assert from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda -def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: - cls = TrainerConfig.get_subclass(model_type) - parsed, unparsed = cls._get_parser().parse_known_args(args) - config: TrainerConfig = cls._from_parsed_args(parsed, unparsed) - distributed = Distributed(config.model.distributed) - trainer = config.get_trainer_class()(config=config) - trainer.setup(distributed, config.get_run(distributed)) - return trainer +def _get_model(config_dict: dict, model_type: str = "gpt") -> FastLLMModel: + cls = FastLLMModelConfig.get_subclass(model_type) + config: FastLLMModelConfig = cls.from_dict(config_dict) + model = config.get_model_class()(config) + model.setup(Distributed(config.distributed)) + return model @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): get_model_test_dataset() - args = model_testing_config.config_args + ["run.tensor_logs.save=False"] - model_ref = _get_trainer_from_args(args, model_testing_config.model_type)._multi_stage - model_frozen = _get_trainer_from_args( - args + [f"model.base_model.transformer.mlp.lr_scale=0"], - model_testing_config.model_type, - )._multi_stage + frozen_config_dict = copy.deepcopy(model_testing_config.config_dict) + decoder_config = frozen_config_dict["model"]["base_model"]["decoder"] + if (decoder_type := decoder_config.get("type", "fixed")) == "fixed": + decoder_config["block"]["mlp"]["lr_scale"] = 0 + elif decoder_type == "pattern": + for block_config in decoder_config["blocks"].values(): + block_config["mlp"]["lr_scale"] = 0 + else: + raise NotImplementedError(decoder_type) + + model_ref = _get_model(model_testing_config.config_dict["model"], model_testing_config.model_type) + model_frozen = _get_model(frozen_config_dict["model"], model_testing_config.model_type) Assert.eq( model_ref._num_stages, model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, Block) else 0 + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, DecoderBlock) else 0 for layer in model_ref.base_model.layers ] - for weight_buffer_ref, weight_buffer_frozen in zip( - model_ref._weight_buffers, model_frozen._weight_buffers, strict=True - ): - Assert.eq(weight_buffer_ref.numel() == weight_buffer_frozen.numel()) - for grad_buffer_ref, grad_buffer_frozen, 
frozen_parameter_count in zip( - model_ref._grad_buffers, model_frozen._grad_buffers, frozen_parameter_counts, strict=True - ): - Assert.eq(grad_buffer_ref.numel() - grad_buffer_frozen.numel() == frozen_parameter_count) + # Make sure each layer has its own buffer so the check below works. + Assert.eq( + num_stages := len(model_ref.base_model.layers), + len(model_frozen.base_model.layers), + len(model_ref.stages), + len(model_frozen.stages), + ) + for stage_index in range(num_stages): + # Weight buffers are the same. + Assert.eq( + model_ref._weight_buffers[model_ref._weight_buffer_indices[stage_index]].numel(), + model_frozen._weight_buffers[model_frozen._weight_buffer_indices[stage_index]].numel(), + ) + # Weight buffers exclude frozen weights. + Assert.eq( + model_ref._grad_buffers[model_ref._grad_buffer_indices[stage_index]].numel() + - model_frozen._grad_buffers[model_frozen._grad_buffer_indices[stage_index]].numel(), + frozen_parameter_counts[stage_index], + ) for shard_name, shard_frozen_count in zip( model_ref._shard_names, [0] + [sum(frozen_parameter_counts)] * (len(model_ref._all_shard_names) - 1), strict=True, ): + # Same with shards. Assert.eq( - model_ref.get_shard(shard_name).numel() - model_frozen.get_shard(shard_name).numel(), shard_frozen_count + model_ref.get_shard(shard_name).numel() - model_frozen.get_shard(shard_name).numel(), + shard_frozen_count, + msg=shard_name, ) diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 5c3ecd8a2..306beadf8 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -67,7 +67,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon sub_configs={ ("init", None): get_config(), # Saved gradient include the gradient scaling by 2**16 (default initial value) - (None, "fw"): get_config(1e-3, 3e-4), + (None, "fw"): get_config(1.2e-3, 3e-4), (None, "bw"): get_config(3e-3, 1e-5, scale=2**16), (None, "bias"): get_config(3e-3, 1e-4, scale=2**16), (None, "gradient"): get_config(3e-3, 5e-5, scale=2**16), diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index abd3d4bad..55ac4ae74 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -11,20 +11,15 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.models.gpt.config import ( - DiffusionDreamGPTHuggingfaceCheckpointFormat, - DiffusionLlamaGPTHuggingfaceCheckpointFormat, - LlamaGPTHuggingfaceCheckpointFormat, - MistralGPTHuggingfaceCheckpointFormat, - MixtralGPTHuggingfaceCheckpointFormat, - MTPLlamaGPTHuggingfaceCheckpointFormat, - Qwen2GPTHuggingfaceCheckpointFormat, - Starcoder2GPTHuggingfaceCheckpointFormat, -) -from fast_llm.models.ssm.config import ( - AprielSSMHHybridHuggingfaceCheckpointFormat, - AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, - LLambaHuggingfaceCheckpointFormat, +from fast_llm.models.gpt.conversion.config import ( + AprielHybridSSMCheckpointFormat, + DiffusionDreamCheckpointFormat, + DiffusionLlamaCheckpointFormat, + LlamaCheckpointFormat, + MistralCheckpointFormat, + MixtralCheckpointFormat, + MTPLlamaCheckpointFormat, + Qwen2CheckpointFormat, ) from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.global_variables import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE @@ -175,9 +170,9 @@ def _update_and_add_testing_config( # Needed to match Megatron 
(init_1 / (2 * num_layers) ** 0.5) init_2 = {"initialization": {"type": "normal", "std": 2**-6.5}} -MODEL_CONFIGS["gpt2"] = ModelTestingConfig( +MODEL_CONFIGS["gpt_2"] = ModelTestingConfig( # Tests gpt2 features (absolute embeddings, layer norm, relu activation, tied embeddings, MHA, linear biases). - name="gpt2", + name="gpt_2", model_type="gpt", config_dict={ "run": { @@ -197,22 +192,28 @@ def _update_and_add_testing_config( "embeddings_layer": { "word_embeddings": init_1, "position_embeddings": {"enabled": True, **init_1}, + "hidden_size": 256, "num_position_embeddings": 512, "vocab_size": MODEL_TEST_VOCAB_SIZE, }, - "transformer": { - "mixer": { - "query_layer": {"weight": init_1}, - "key_layer": {"weight": init_1}, - "value_layer": {"weight": init_1}, - "dense_layer": {"weight": init_2}, - "heads": 8, - "head_groups": 8, - "head_size": 32, + "decoder": { + "block": { + "mixer": { + "query_layer": {"weight": init_1}, + "key_layer": {"weight": init_1}, + "value_layer": {"weight": init_1}, + "dense_layer": {"weight": init_2}, + "heads": 8, + "head_groups": 8, + "head_size": 32, + }, + "mlp": { + "layer_1": {"weight": init_1}, + "layer_2": {"weight": init_2}, + "intermediate_size": 1024, + }, }, - "mlp": {"layer_1": {"weight": init_1}, "layer_2": {"weight": init_2}, "intermediate_size": 1024}, - "num_layers": 2, - "hidden_size": 256, + "num_blocks": 2, }, "output_layer": {"output_weight": init_1}, }, @@ -288,7 +289,8 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.main, ModelTestingGroup.checkpoint: ModelTestingGroupAction.main, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + # TODO: PP checkpoint failing for tied weights. + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.normal, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, @@ -297,10 +299,10 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests MQA. - "gpt2", + "gpt_2", "starcoder", updates={ - ("model", "base_model", "transformer", "mixer", "head_groups"): 1, + ("model", "base_model", "decoder", "block", "mixer", "head_groups"): 1, }, megatron_args=["--group-query-attention"], checkpoint_format=None, @@ -316,11 +318,11 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests intermediate between gpt2 and llama, closest converter to gpt2. - "gpt2", - "starcoder2", + "gpt_2", + "starcoder_2", updates={ - ("model", "base_model", "transformer", "mixer", "head_groups"): 4, - ("model", "base_model", "transformer", "mixer", "rotary", "type"): "default", + ("model", "base_model", "decoder", "block", "mixer", "head_groups"): 4, + ("model", "base_model", "decoder", "block", "mixer", "rotary", "type"): "default", ("model", "base_model", "embeddings_layer", "position_embeddings", "enabled"): False, }, megatron_args=[ @@ -329,7 +331,7 @@ def _update_and_add_testing_config( "--use-rotary-position-embeddings", "--no-position-embedding", ], - checkpoint_format=Starcoder2GPTHuggingfaceCheckpointFormat, + checkpoint_format=None, # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, @@ -343,14 +345,14 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Main tested model. 
- "starcoder2", + "starcoder_2", "llama", updates={ - ("model", "base_model", "transformer", "mixer", "add_linear_biases"): False, - ("model", "base_model", "transformer", "mlp", "gated"): True, - ("model", "base_model", "transformer", "mlp", "activation"): "silu", - ("model", "base_model", "transformer", "mlp", "add_linear_biases"): False, - ("model", "base_model", "transformer", "normalization", "type"): "rms_norm", + ("model", "base_model", "decoder", "block", "mixer", "add_linear_biases"): False, + ("model", "base_model", "decoder", "block", "mlp", "gated"): True, + ("model", "base_model", "decoder", "block", "mlp", "activation"): "silu", + ("model", "base_model", "decoder", "block", "mlp", "add_linear_biases"): False, + ("model", "base_model", "decoder", "block", "normalization", "type"): "rms_norm", ("model", "base_model", "output_layer", "normalization", "type"): "rms_norm", ("model", "base_model", "output_layer", "tied_weight"): False, }, @@ -361,7 +363,7 @@ def _update_and_add_testing_config( "--ffn-hidden-size=1024", "--untie-embeddings-and-output-weights", ], - checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, + checkpoint_format=LlamaCheckpointFormat, # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.main, @@ -376,13 +378,13 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests llama3-style rotary embeddings. "llama", - "llama3", + "llama_3", updates={ - ("model", "base_model", "transformer", "mixer", "rotary", "type"): "llama3", + ("model", "base_model", "decoder", "block", "mixer", "rotary", "type"): "llama3", }, # Megatron doesn't support Llama3-style Rotary Embeddings megatron_args=None, - checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, + checkpoint_format=LlamaCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, @@ -398,11 +400,11 @@ def _update_and_add_testing_config( "llama", "llama_yarn", updates={ - ("model", "base_model", "transformer", "mixer", "rotary", "type"): "yarn", + ("model", "base_model", "decoder", "block", "mixer", "rotary", "type"): "yarn", }, # Megatron doesn't support Yarn-style Rotary Embeddings megatron_args=None, - checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, + checkpoint_format=LlamaCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, @@ -420,7 +422,7 @@ def _update_and_add_testing_config( updates={}, # Megatron doesn't support Yarn-style Rotary Embeddings megatron_args=None, - checkpoint_format=DiffusionLlamaGPTHuggingfaceCheckpointFormat, + checkpoint_format=DiffusionLlamaCheckpointFormat, # TODO: Conversion is broken. # TODO: Add back generate as `normal` when stable. groups={ @@ -436,13 +438,13 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests multi-token prediction, custom HF model and converter. "llama", - "llama_mtp", + "mtp_llama", updates={ - ("model", "base_model", "output_layer", "prediction_heads"): 4, + ("model", "base_model", "output_layer", "prediction_heads"): 2, }, # Megatron doesn't support multi-token prediction. megatron_args=None, - checkpoint_format=MTPLlamaGPTHuggingfaceCheckpointFormat, + checkpoint_format=MTPLlamaCheckpointFormat, # TODO: Add back generate as `normal` when stable. 
groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, @@ -458,14 +460,14 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests partial linear biases, Qwen2 converter. "llama", - "qwen2", + "qwen_2", # TODO: replace updates={ - ("model", "base_model", "transformer", "add_linear_biases"): "only_attn_qkv", + ("model", "base_model", "decoder", "block", "add_linear_biases"): "only_attn_qkv", }, # Megatron doesn't support per sub layer biases. megatron_args=None, - checkpoint_format=Qwen2GPTHuggingfaceCheckpointFormat, + checkpoint_format=Qwen2CheckpointFormat, # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.broken, @@ -479,13 +481,13 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests diffusion dream converter. - "qwen2", + "qwen_2", "dream", # TODO: replace only_attn_qkv updates={}, # Megatron doesn't support per sub layer biases. megatron_args=None, - checkpoint_format=DiffusionDreamGPTHuggingfaceCheckpointFormat, + checkpoint_format=DiffusionDreamCheckpointFormat, # TODO: Conversion is broken. # TODO: Add back generate as `normal` when stable. groups={ @@ -503,11 +505,11 @@ def _update_and_add_testing_config( "llama", "mistral", updates={ - ("model", "base_model", "transformer", "mixer", "window_size"): 128, + ("model", "base_model", "decoder", "block", "mixer", "window_size"): 128, }, # Megatron doesn't support sliding windows. megatron_args=None, - checkpoint_format=MistralGPTHuggingfaceCheckpointFormat, + checkpoint_format=MistralCheckpointFormat, # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, @@ -524,16 +526,16 @@ def _update_and_add_testing_config( "llama", "mixtral", updates={ - ("model", "base_model", "transformer", "mlp", "type"): "moe", - ("model", "base_model", "transformer", "mlp", "router", "weight"): init_1, - ("model", "base_model", "transformer", "mlp", "experts"): 4, - ("model", "base_model", "transformer", "mlp", "experts_per_token"): 4, + ("model", "base_model", "decoder", "block", "mlp", "type"): "moe", + ("model", "base_model", "decoder", "block", "mlp", "router", "weight"): init_1, + ("model", "base_model", "decoder", "block", "mlp", "experts"): 4, + ("model", "base_model", "decoder", "block", "mlp", "experts_per_token"): 4, }, megatron_args=[ "--num-experts=4", "--moe-router-topk=4", ], - checkpoint_format=MixtralGPTHuggingfaceCheckpointFormat, + checkpoint_format=MixtralCheckpointFormat, # TODO: New base image broke mixtral groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, @@ -546,29 +548,41 @@ def _update_and_add_testing_config( compare_factor=2.0, ) +_llama_block = MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["decoder"]["block"] + + _update_and_add_testing_config( # Tests hybrid Mamba, llamba converter. 
"llama", - "llamba", - model_type="hybrid_ssm", + "hybrid_mamba", updates={ - ("model", "base_model", "ssm"): { - "type": "mamba", - "d_inner": 512, - "state_size": 16, - "dt_rank": 16, - "add_linear_biases": False, + ("model", "base_model", "decoder"): { + "type": "pattern", + "blocks": { + "t": copy.deepcopy(_llama_block), + "m": { + **copy.deepcopy(_llama_block), + "mixer": { + "type": "mamba", + "d_inner": 512, + "state_size": 16, + "dt_rank": 16, + "add_linear_biases": False, + }, + }, + }, + "num_blocks": 2, + "pattern": ["t", "m"], }, - ("model", "base_model", "hybrid_block_layout"): "['t','m']", }, megatron_args=None, - checkpoint_format=LLambaHuggingfaceCheckpointFormat, + checkpoint_format=AprielHybridSSMCheckpointFormat, # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.convert: ModelTestingGroupAction.broken, + ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, @@ -581,25 +595,35 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests hybrid Mamba 2. "llama", - "hybrid_mamba2", - model_type="hybrid_ssm", + "hybrid_mamba_2", updates={ - ("model", "base_model", "ssm"): { - "type": "mamba_2", - "d_inner": 512, - "state_size": 8, - "dt_rank": 16, - "d_xb": 256, - "add_linear_biases": False, + ("model", "base_model", "decoder"): { + "type": "pattern", + "blocks": { + "t": copy.deepcopy(_llama_block), + "m2": { + **copy.deepcopy(_llama_block), + "mixer": { + "type": "mamba_2", + "dt_layer": {"bias": {"enabled": True}}, + "d_inner": 512, + "state_size": 8, + "dt_rank": 16, + "d_xb": 256, + "add_linear_biases": False, + }, + }, + }, + "num_blocks": 2, + "pattern": ["t", "m2"], }, - ("model", "base_model", "hybrid_block_layout"): "['t','m2']", }, megatron_args=None, - checkpoint_format=AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, + checkpoint_format=AprielHybridSSMCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.broken, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, @@ -616,26 +640,35 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests hybrid discrete Mamba 2. 
"llama", - "hybrid_discrete_mamba2", - model_type="hybrid_ssm", + "hybrid_discrete_mamba_2", updates={ - ("model", "base_model", "ssm"): { - "type": "discrete_mamba_2", - "d_inner": 512, - "state_size": 8, - "n_qk_heads": 8, - "n_v_heads": 16, - "chunk_size": 32, - "add_linear_biases": False, + ("model", "base_model", "decoder"): { + "type": "pattern", + "blocks": { + "t": copy.deepcopy(_llama_block), + "m2d": { + **copy.deepcopy(_llama_block), + "mixer": { + "type": "discrete_mamba_2", + "d_inner": 512, + "state_size": 8, + "n_qk_heads": 8, + "n_v_heads": 16, + "chunk_size": 32, + "add_linear_biases": False, + }, + }, + }, + "num_blocks": 2, + "pattern": ["t", "m2d"], }, - ("model", "base_model", "hybrid_block_layout"): "['t','m2d']", }, megatron_args=None, - checkpoint_format=AprielSSMHHybridHuggingfaceCheckpointFormat, + checkpoint_format=AprielHybridSSMCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.broken, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, # TODO: Implement diff --git a/tests/utils/save_load_configs.py b/tests/utils/save_load_configs.py index f5a15020e..3e7cbf10f 100644 --- a/tests/utils/save_load_configs.py +++ b/tests/utils/save_load_configs.py @@ -18,13 +18,32 @@ class DistributedSaveLoadConfig: num_gpus: int = 2 def resolve(self, base_path: pathlib.Path, model_testing_config: ModelTestingConfig) -> typing.Self: + if model_testing_config.checkpoint_format is None: + format = { + "distributed": do_get_convert_path( + DistributedCheckpointFormat.name, FastLLMCheckpointFormat.name, base_path=pathlib.Path() + ), + "fast_llm": do_get_convert_path( + FastLLMCheckpointFormat.name, DistributedCheckpointFormat.name, base_path=pathlib.Path() + ), + } + else: + format = { + "checkpoint_format": model_testing_config.checkpoint_format.name, + "distributed": do_get_convert_path( + DistributedCheckpointFormat.name, + model_testing_config.checkpoint_format.name, + base_path=pathlib.Path(), + ), + "fast_llm": do_get_convert_path( + FastLLMCheckpointFormat.name, model_testing_config.checkpoint_format.name, base_path=pathlib.Path() + ), + } return dataclasses.replace( self, - load_path=base_path - / str(self.load_path).format(checkpoint_format=model_testing_config.checkpoint_format.name), - load_format=self.load_format.format(checkpoint_format=model_testing_config.checkpoint_format.name), - save_path=base_path - / str(self.save_path).format(checkpoint_format=model_testing_config.checkpoint_format.name), + load_path=base_path / str(self.load_path).format(**format), + load_format=self.load_format.format(**format), + save_path=base_path / str(self.save_path).format(**format), ) @property @@ -58,11 +77,11 @@ def get_convert_path(run_test_script_base_path): for pretrained_format, pretrained_path in ( ( DistributedCheckpointFormat.name, - do_get_convert_path(DistributedCheckpointFormat.name, "{checkpoint_format}", base_path=pathlib.Path()), + pathlib.Path("{distributed}"), ), ( FastLLMCheckpointFormat.name, - do_get_convert_path(FastLLMCheckpointFormat.name, "{checkpoint_format}", base_path=pathlib.Path()), + pathlib.Path("{fast_llm}"), ), ( "{checkpoint_format}", From 8249f8a15e102f06d1b07ce957528e8a5ea7589f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 16 Sep 2025 22:52:23 -0400 
Subject: [PATCH 81/82] fix --- fast_llm/config.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 5284d8bee..4d3858fd7 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1122,6 +1122,3 @@ def pop_nested_dict_value[ return d.pop(keys[-1]) else: return d.pop(keys) - - -i = 0 From 870afd3739b5eb29991c8207880815726f1fa45d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 18 Sep 2025 17:00:22 -0400 Subject: [PATCH 82/82] v0.3.0 --- examples/mistral.yaml | 2 +- fast_llm/__init__.py | 2 +- fast_llm/config.py | 50 ++--- fast_llm/data/data/gpt/config.py | 40 +--- fast_llm/data/dataset/config.py | 19 +- fast_llm/data/dataset/gpt/config.py | 172 +----------------- fast_llm/data/dataset/gpt/indexed.py | 10 +- fast_llm/data/dataset/gpt/sampled.py | 159 +--------------- fast_llm/engine/checkpoint/config.py | 31 ---- fast_llm/engine/checkpoint/distributed.py | 28 +-- .../engine/config_utils/initialization.py | 11 +- fast_llm/engine/config_utils/run.py | 2 +- fast_llm/engine/distributed/config.py | 20 +- fast_llm/engine/distributed/distributed.py | 2 +- fast_llm/engine/evaluation/__init__.py | 0 fast_llm/engine/evaluation/config.py | 15 +- .../engine/evaluation/lm_eval/__init__.py | 0 fast_llm/engine/inference/config.py | 4 +- fast_llm/engine/multi_stage/config.py | 42 ++--- fast_llm/engine/multi_stage/fast_llm_model.py | 12 +- fast_llm/engine/multi_stage/fsdp.py | 8 +- fast_llm/engine/optimizer/optimizer.py | 2 +- fast_llm/engine/training/config.py | 28 +-- fast_llm/engine/training/trainer.py | 4 +- fast_llm/layers/attention/config.py | 2 +- fast_llm/layers/attention/preprocessing.py | 6 +- fast_llm/layers/attention/rotary/config.py | 11 +- fast_llm/layers/block/config.py | 22 +-- .../layers/common/normalization/config.py | 25 +-- fast_llm/layers/common/peft/config.py | 11 +- fast_llm/layers/decoder/config.py | 22 +-- fast_llm/layers/decoder/mlp/mlp.py | 2 +- fast_llm/layers/language_model/config.py | 14 -- fast_llm/layers/language_model/embedding.py | 2 +- fast_llm/models/gpt/config.py | 43 ----- tests/data/common.py | 9 +- tests/data/test_blending.py | 31 +--- tests/data/test_concatenate.py | 2 +- tests/data/test_concatenated_memmap.py | 78 -------- tests/data/test_fim.py | 31 +--- tests/data/test_random.py | 12 +- tests/data/test_sampling.py | 22 +-- tests/data/test_slice.py | 31 ---- tests/layers/test_lm_head.py | 12 +- tests/models/test_generate.py | 8 +- tests/models/test_match_megatron.py | 101 +++++++++- tests/test_attention.py | 2 +- tests/test_config.py | 6 +- tests/utils/dataset.py | 22 --- tests/utils/distributed_configs.py | 4 +- tests/utils/model_configs.py | 1 + 51 files changed, 218 insertions(+), 977 deletions(-) create mode 100644 fast_llm/engine/evaluation/__init__.py create mode 100644 fast_llm/engine/evaluation/lm_eval/__init__.py delete mode 100644 tests/data/test_concatenated_memmap.py diff --git a/examples/mistral.yaml b/examples/mistral.yaml index 4b7fdd968..88655954f 100644 --- a/examples/mistral.yaml +++ b/examples/mistral.yaml @@ -62,7 +62,7 @@ model: multi_stage: zero_stage: 2 distributed: - training_dtype: bf16 + compute_dtype: bf16 seed: 984059 run: experiment_dir: mistral_example diff --git a/fast_llm/__init__.py b/fast_llm/__init__.py index d3ec452c3..493f7415d 100644 --- a/fast_llm/__init__.py +++ b/fast_llm/__init__.py @@ -1 +1 @@ -__version__ = "0.2.0" +__version__ = "0.3.0" diff --git a/fast_llm/config.py b/fast_llm/config.py index 4d3858fd7..9644df9c1 100644 --- a/fast_llm/config.py +++ 
b/fast_llm/config.py @@ -759,28 +759,8 @@ def from_dict( return cls._from_dict(default, strict) @classmethod - def from_flat_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - ) -> typing.Self: - # TODO v0.3: Remove flat format - return cls._from_dict(default, strict, True) - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.3: Remove flat format + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: out_arg_dict = {"_from_dict_check": True} - - # TODO v0.3: Remove backward compatibility fix - if "__class__" in default: - del default["__class__"] - try: actual_cls = cls.get_subclass(default.get("type")) except KeyError: @@ -788,29 +768,23 @@ def _from_dict( actual_cls = cls if actual_cls is not None and actual_cls is not cls: - return actual_cls._from_dict(default, strict=strict, flat=flat) + return actual_cls._from_dict(default, strict=strict) # Do not validate yet in case the root class sets cross-dependencies in validation. with NoAutoValidate(): for name, field in cls.fields(): if not field.init or field._field_type != dataclasses._FIELD: # noqa continue - if flat: - if isinstance(field.type, type) and issubclass(field.type, Config): - out_arg_dict[name] = field.type._from_dict(default, False, True) - elif name in default: - out_arg_dict[name] = default.pop(name) - else: - # Check for nested configs to instantiate. - try: - value = cls._from_dict_nested(default.pop(name, MISSING), field.type, strict) - if value is not MISSING: - out_arg_dict[name] = value - except FieldTypeError as e: - raise FieldTypeError( - f"Invalid field type `{get_type_name(field.type)}` in class {cls._get_class_name()}: " - + ", ".join(e.args) - ) + # Check for nested configs to instantiate. + try: + value = cls._from_dict_nested(default.pop(name, MISSING), field.type, strict) + if value is not MISSING: + out_arg_dict[name] = value + except FieldTypeError as e: + raise FieldTypeError( + f"Invalid field type `{get_type_name(field.type)}` in class {cls._get_class_name()}: " + + ", ".join(e.args) + ) out = cls(**out_arg_dict) # noqa if strict and default: out._unknown_fields = default.copy() diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 405d1c672..efee46959 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -1,23 +1,16 @@ import logging -import typing from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.config import MultiprocessingContext, TokenizerConfig from fast_llm.data.data.config import DataConfig -from fast_llm.data.dataset.gpt.config import ( - GPTLegacyConfig, - GPTLegacyDatasetConfig, - GPTSampledDatasetConfig, - GPTSamplingConfig, -) -from fast_llm.engine.distributed.config import PhaseType +from fast_llm.data.dataset.gpt.config import GPTSampledDatasetConfig, GPTSamplingConfig from fast_llm.utils import Assert logger = logging.getLogger(__name__) @config_class() -class GPTDataConfig(DataConfig, GPTLegacyConfig): +class GPTDataConfig(DataConfig): """ Configuration for the dataset(s), split and sampling. Currently hard-coded to a GPT dataset. @@ -48,32 +41,3 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Multiprocessing context. Do not touch.", hint=FieldHint.expert, ) - - def _validate(self) -> None: - if not self.datasets: - logger.warning( - "Using the legacy dataset definition format." 
" Specify it through `data.datasets` instead." - ) - self.datasets = { - phase.value.lower(): GPTLegacyDatasetConfig.from_dict(self, strict=False) - for phase in (PhaseType.training, PhaseType.validation, PhaseType.test) - } - super()._validate() - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.x: Remove backward compatibility. - if "datasets" in default: - for phase in PhaseType: - if phase.value in default["datasets"]: - rename = phase.value.lower() - logger.warning(f"Renaming dataset {phase.value} to {rename}") - assert rename not in default["datasets"] - default["datasets"][rename] = default["datasets"].pop(phase.value) - - return super()._from_dict(default, strict, flat) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 5e3ced8a4..0c1b0cd09 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -204,11 +204,6 @@ class BlendedDatasetConfig(SampledDatasetConfig): desc="The blending weight of each dataset.", hint=FieldHint.core, ) - legacy: bool = Field( - default=False, - desc="Use the legacy formulas for sub-dataset seeds and sample sizes.", - hint=FieldHint.deprecated, - ) def _validate(self) -> None: self.weights = normalize_probabilities(self.weights) @@ -231,20 +226,10 @@ def build_and_sample( sampling, parameters=dataclasses.replace( sampling.parameters, - num_samples=( - math.ceil( - weight - * ( - sampling.parameters.num_samples - + 5 * (sampling.parameters.num_samples * (1 - weight)) ** 0.5 - ) - ) - if self.legacy - else math.ceil(weight * sampling.parameters.num_samples) + 1 - ), + num_samples=math.ceil(weight * sampling.parameters.num_samples) + 1, ), # TODO: Seed may not be unique for nested blended datasets. - config=sampling.config.to_copy({"seed": sampling.config.seed + i * (0 if self.legacy else 697)}), + config=sampling.config.to_copy({"seed": sampling.config.seed + i * 697}), ), ) for i, (dataset, weight) in enumerate(zip(self.datasets, self.weights, strict=True)) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index ef2efedc9..656cd7d24 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -1,10 +1,8 @@ import dataclasses import enum -import json import pathlib import time import typing -import warnings import yaml @@ -22,8 +20,7 @@ SamplingData, SamplingParameters, ) -from fast_llm.engine.distributed.config import PhaseType -from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum +from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset @@ -41,7 +38,6 @@ class ShufflingType(str, enum.Enum): skip_first_epoch = "skip_first_epoch" # Disable shuffling entirely. disabled = "disabled" - legacy = "legacy" @config_class() @@ -222,45 +218,6 @@ def _convert_paths(self, config): return config -# Add user-friendly names for the configs. -@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated_memmap"}) -class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig): - # TODO v0.3: Remove. - _abstract: typing.ClassVar[bool] = False - path: pathlib.Path = Field( - default=None, - desc="The path to a dataset directory.", - hint=FieldHint.core, - ) - - def _validate(self) -> None: - warnings.warn("`concatenated_memmap` dataset is deprecated. 
Use `file` instead.", DeprecationWarning) - super()._validate() - - def build(self) -> "GPTConcatenatedDataset": - - assert self.path.is_dir() - index_path = self.path / "index.txt" - - if index_path.is_file(): - prefixes = [self.path / line.strip() for line in index_path.open("r").readlines()] - else: - warnings.warn( - f"The dataset path {self.path} points to a directory." - " The dataset will be indexed automatically, which may be unsafe." - " We recommend using an index file instead." - ) - prefixes = [ - path.with_suffix("") - for path in self.path.iterdir() - if path.suffix == ".idx" and path.is_file() and path.with_suffix(".bin").is_file() - ] - dataset_config = GPTConcatenatedDatasetConfig.from_dict( - {"datasets": [{"type": "memmap", "path": prefix} for prefix in prefixes]} - ) - return dataset_config.build() - - @config_class() class FimConfig(Config): """ @@ -268,7 +225,7 @@ class FimConfig(Config): """ rate: float = Field( - # TODO: Use meaningful default now that fim is a wrapper? (bad for legacy config) + # TODO: Use meaningful default now that fim is a wrapper? default=0.0, desc="FIM rate for each sample.", hint=FieldHint.core, @@ -352,131 +309,6 @@ def build_and_sample( return GPTFimDataset(self, self.dataset.build_and_sample(sampling), sampling) -class LegacyDatasetSource(str, enum.Enum): - """ - An enum for the different ways to load datasets. - """ - - list = "list" - file = "file" - random = "random" - - -def _validate_split(value: list[int]) -> list[int]: - Assert.leq(len(value), 3) - return value + [0] * (len(value) - 3) - - -def _validate_path(value: str | list[str]) -> list[str]: - return [value] if isinstance(value, str) else value - - -@config_class() -class GPTLegacyConfig(Config): - split: list[float] = Field( - default_factory=lambda: [969, 30, 1], - desc="Split ratio for train, valid and test datasets.", - hint=FieldHint.deprecated, - valid=_validate_split, - ) - format: LegacyDatasetSource = Field( - default=LegacyDatasetSource.list, - desc="Format for the dataset definition.", - hint=FieldHint.deprecated, - ) - path: list[str] = Field( - default_factory=list, - desc="Path or list of paths and weights.", - hint=FieldHint.deprecated, - valid=_validate_path, - ) - fim: FimConfig = Field( - desc="Configuration for Fill In the Middle (FIM).", - hint=FieldHint.feature, - ) - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "legacy"}) -class GPTLegacyDatasetConfig(GPTSampledDatasetConfig, GPTLegacyConfig): - _abstract: typing.ClassVar[bool] = False - - def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: - - if self.format == LegacyDatasetSource.random: - Assert.eq(len(self.path), 0) - dataset_config = GPTRandomDatasetConfig() - else: - if self.format == LegacyDatasetSource.file: - Assert.eq(len(self.path), 1) - data_path = pathlib.Path(self.path[0]) - dataset_defs = json.load(data_path.open("r")) - data_base_path = data_path.parent - dataset_prefixes = [ - (data_base_path / dataset_def["prefix"]).resolve() for dataset_def in dataset_defs["datasets"] - ] - dataset_weights = normalize_probabilities( - [dataset_def["weight"] for dataset_def in dataset_defs["datasets"]] - ) - elif self.format == LegacyDatasetSource.list: - Assert.geq(len(self.path), 1) - if len(self.path) == 1: - dataset_prefixes, dataset_weights = [self.path[0].strip()], [1.0] - else: - Assert.custom(lambda x: x % 2 == 0, len(self.path)) - dataset_prefixes = [pathlib.Path(x.strip()).resolve() for x in self.path[1::2]] - assert len(dataset_prefixes) == 
len(set(dataset_prefixes)) - dataset_weights = normalize_probabilities([float(x) for x in self.path[::2]]) - else: - raise NotImplementedError(self.format) - - phase_splits = padded_cumsum(normalize_probabilities(self.split)) - - phase_index = { - PhaseType.training.value.lower(): 0, - PhaseType.validation.value.lower(): 1, - PhaseType.test.value.lower(): 2, - }[sampling.dataset_name] - - dataset_configs = [ - { - "type": "slice", - # TODO: this duplicates memmap datasets for each phase. - "dataset": {"type": "memmap", "path": prefix}, - "begin": float(phase_splits[phase_index]), - "end": float(phase_splits[phase_index + 1]), - } - for prefix in dataset_prefixes - ] - dataset_config = ( - { - "type": "blended", - "name": "blended", - "datasets": dataset_configs, - "weights": dataset_weights, - "legacy": True, - } - if len(dataset_configs) > 1 - else dataset_configs[0] - ) - if self.fim.rate > 0: - dataset_config = { - "type": "fim", - "dataset": dataset_config, - **self.fim.to_dict(), - } - # Legacy sampling config - dataset_config = { - "type": "sampled", - "dataset": dataset_config, - "sampling": { - "seed": sampling.distributed.config.seed, - "shuffle": "legacy", - }, - } - - return GPTSampledDatasetConfig.from_dict(dataset_config).build_and_sample(sampling) - - @config_class(dynamic_type={GPTSampledDatasetConfig: "test_slow"}) class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): """ diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 688ea6a70..896229772 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -3,7 +3,7 @@ import numpy as np -from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType +from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset if typing.TYPE_CHECKING: @@ -26,13 +26,9 @@ def get_document_size(self, index: int) -> int: """ def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": - from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset, LegacyGPTSampledIndexedDataset + from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset - return ( - LegacyGPTSampledIndexedDataset(self, sampling) - if sampling.config.shuffle == ShufflingType.legacy - else GPTSampledIndexedDataset(self, sampling) - ) + return GPTSampledIndexedDataset(self, sampling) class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset): diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 6a06002cb..95006f18e 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -17,7 +17,7 @@ from fast_llm.utils import Assert try: - from fast_llm.csrc.data import build_padded_token_cumsum, build_sample_idx # noqa + from fast_llm.csrc.data import build_padded_token_cumsum # noqa _extension_available = True except ImportError: @@ -531,160 +531,3 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._unshuffled_tokens = data["unshuffled_tokens"] self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch - - -class LegacyGPTSampledIndexedDataset(SampledDataset): - """ - A GPT dataset augmented with a sampling, i.e., - a pre-computed, shuffled list of samples to be indexed sequentially (as-is) during training. - The sampling exactly matches Megatron-LM with matching parameters. 
- Supports optional post-processing with FIM. - """ - - def __init__( - self, - indexed_dataset: GPTIndexedDataset, - sampling: GPTSamplingData, - ): - assert isinstance(sampling, GPTSamplingData) - self._indexed_dataset = indexed_dataset - if not sampling.parameters.truncate_documents: - raise NotImplementedError( - "Legacy sampling only supports document truncation. Please use the latest dataset format." - ) - self._config = sampling.config - self._parameters = sampling.parameters - if self._parameters.use_preference_loss_spans: - raise NotImplementedError("Legacy sampling does not support preference loss masking.") - - if sampling.cache_directory is None: - log_main_rank( - " > No dataset cache directory provided, building the index map on all ranks." - "This may be very inefficient...", - log_fn=logger.warning, - ) - base_path = None - else: - base_path = ( - sampling.cache_directory - / f"{self.name}_ns_{self._parameters.num_samples}_sl_{self._parameters.sequence_length}" - f"_s_{self._config.seed}" - ) - - self._doc_idx = MemmapArray( - None if base_path is None else base_path.with_name(base_path.name + "_doc_idx.npy") - ) - self._sample_idx = MemmapArray( - None if base_path is None else base_path.with_name(base_path.name + "_sample_idx.npy") - ) - self._shuffle_idx = MemmapArray( - None if base_path is None else base_path.with_name(base_path.name + "_shuffle_idx.npy") - ) - - # Build the indexed mapping if it doesn't exist. - if base_path is None or ( - sampling.distributed.config.rank == sampling.get_next_rank() - and not (self._doc_idx.exists() and self._sample_idx.exists() and self._shuffle_idx.exists()) - ): - self._sample() - - def _sample(self) -> None: - """ - Create a `GPTSampledDataset` with the requested parameters. - """ - logger.info(f" > Sampling dataset {self._indexed_dataset.name} ...") - document_sizes = self._indexed_dataset.get_document_sizes() - num_documents = len(document_sizes) - num_tokens = document_sizes.sum() - np_rng = np.random.RandomState(seed=self._config.seed) - - num_epochs = math.ceil((self._parameters.sequence_length * self._parameters.num_samples + 1) / num_tokens) - main_epochs_samples = ((num_epochs - 1) * num_tokens - 1) // self._parameters.sequence_length - last_epoch_samples = self._parameters.num_samples - main_epochs_samples - samples_per_epoch = (num_tokens - 1) // self._parameters.sequence_length - separate_last_epoch = num_epochs > 1 and last_epoch_samples < 0.8 * samples_per_epoch - - doc_idx = np.tile(np.arange(num_documents, dtype=np.int32), num_epochs) - if separate_last_epoch: - np_rng.shuffle(doc_idx[:-num_documents]) - np_rng.shuffle(doc_idx[-num_documents:]) - else: - np_rng.shuffle(doc_idx) - - assert _extension_available, ( - "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." 
- ) - - sample_idx = build_sample_idx( - document_sizes, - doc_idx, - self._parameters.sequence_length, - num_epochs, - num_tokens, - True, - ) - - total_size = sample_idx.shape[0] - 1 - shuffle_idx = np.arange( - 0, total_size, dtype=np.int64 if total_size >= (np.iinfo(np.uint32).max - 1) else np.uint32 - ) - if separate_last_epoch: - np_rng.shuffle(shuffle_idx[:main_epochs_samples]) - np_rng.shuffle(shuffle_idx[main_epochs_samples:]) - else: - np_rng.shuffle(shuffle_idx) - - Assert.geq(len(shuffle_idx), self._parameters.num_samples) - self._doc_idx.save(doc_idx) - self._sample_idx.save(sample_idx) - self._shuffle_idx.save(shuffle_idx[: self._parameters.num_samples]) - - def __len__(self) -> int: - return self._parameters.num_samples - - def __getitem__(self, idx: int) -> typing.Any: - """ - Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents) - with the requested sampling index. - The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`). - """ - # Get the shuffled index. - shuffled_idx = self._shuffle_idx[idx] - # Start and end documents and offsets. - doc_f, offset_f = self._sample_idx[shuffled_idx] - doc_l, offset_l = self._sample_idx[shuffled_idx + 1] - sample_list = [ - self._indexed_dataset.get( - self._doc_idx[doc].item(), - offset=(doc == doc_f) * offset_f, - length=offset_l + 1 - (doc == doc_f) * offset_f if doc == doc_l else None, - use_loss_masking_spans=self._parameters.use_loss_masking_spans, - ) - for doc in range(doc_f, doc_l + 1) - ] - token_ids = np.concatenate([sample.token_ids for sample in sample_list], dtype=np.int64) - Assert.eq(len(token_ids), self._parameters.sequence_length + 1) - - if self._parameters.use_loss_masking_spans: - spans = [] - offset = 0 - for sample in sample_list: - for span in sample.loss_masking_spans: - spans.append(span + offset) - offset += len(sample.token_ids) - spans = np.stack(spans, dtype=np.int32) if spans else np.array([]) - else: - spans = None - sequence_lengths = ( - np.array( - [sample.token_ids.size - (idx == len(sample_list) - 1) for idx, sample in enumerate(sample_list)], - dtype=np.int32, - ) - if not self._parameters.cross_document_attention - else None - ) - return GPTSample(token_ids=token_ids, loss_masking_spans=spans, sequence_lengths=sequence_lengths) - - @property - def name(self) -> str: - return self._indexed_dataset.name diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index c878cec0a..3f1970538 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -4,7 +4,6 @@ import logging import pathlib import typing -import warnings import yaml @@ -58,9 +57,7 @@ def __fast_llm_serialize__(cls) -> str: class DistributedCheckpointFormat(CheckpointFormat): - # TODO v0.3: Add `enforce_version_match` name: typing.ClassVar[str] = "distributed" - enforce_architecture_match: typing.ClassVar[bool] = True @classmethod def get_handler_class(cls) -> type["DistributedCheckpointHandler"]: @@ -125,17 +122,6 @@ class CheckpointStateConfigBase(CheckpointConfigBase): model_weights: bool = Field(default=True, hint=FieldHint.feature) optimizer_state: bool = Field(default=None, hint=FieldHint.feature) - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - cls._handle_renamed_field(default, "load_weights", "model_weights") - cls._handle_renamed_field(default, "load_optimizer", "optimizer_state") - 
return super()._from_dict(default, strict, flat) - @config_class() class CheckpointSaveConfigBase(CheckpointConfigBase): @@ -204,23 +190,6 @@ class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): hint=FieldHint.core, ) - def _validate(self) -> None: - if self.load_config == "architecture": - raise NotImplementedError("load_config==`architecture` is no longer supported.") - super()._validate() - if ( - self.format in (DistributedCheckpointFormat, FastLLMCheckpointFormat) - and "load_config" not in self._explicit_fields - ): - warnings.warn( - "The default behaviour for model configuration loading has changed (May 2025)." - "All model parameters are now loaded, not just the architecture parameters." - "Please make sure this doesn't lead to unexpected breaking changes." - "Suppress this warning by setting `load_config = model` explicitly.", - ) - if self.format.enforce_architecture_match: - assert self.load_config.load_base_model - @config_class() class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase): diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 7faf599f7..c2f4d8cdd 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -71,18 +71,9 @@ def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: framework="pt", device=str(self._model.distributed.device), ) as f: - if "state_shard" in f.keys(): - # Old format `state_shard` with shape `(num_shards, shard_size) - # TODO v0.3: Use checkpoint version? Drop support? - log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning) - for shard_name in shard_names: - self._model.get_shard(shard_name).copy_( - f.get_slice("state_shard")[loaded_metadata.shards.index(shard_name)] - ) - else: - # TODO: Does this copy twice? - for shard_name in shard_names: - self._model.get_shard(shard_name).copy_(f.get_tensor(f"{shard_name}_shard")) + # TODO: Does this copy twice? + for shard_name in shard_names: + self._model.get_shard(shard_name).copy_(f.get_tensor(f"{shard_name}_shard")) else: log_main_rank("Checkpoint format doesn't match, using safe load", log_fn=logger.info) @@ -105,18 +96,7 @@ def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: # TODO: Lazy loading? with safetensors.safe_open(path, framework="pt", device=str(self._model.distributed.device)) as f: # TODO: Use self_shard - if "state_shard" in f.keys(): - # Old format `state_shard` with shape `(num_shards, shard_size) - # TODO v0.3: Use checkpoint version? Drop support? 
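# A minimal, hedged sketch of the shard layout the simplified distributed
# checkpoint loader expects: one flat tensor per shard, stored under the key
# "<shard_name>_shard" in a safetensors file. The path and shard names below are
# illustrative placeholders, not values taken from the repository.
import safetensors


def read_shards_sketch(path: str, shard_names: list[str], device: str = "cpu") -> dict:
    with safetensors.safe_open(path, framework="pt", device=device) as f:
        # Legacy checkpoints packed everything into a single "state_shard" tensor
        # of shape (num_shards, shard_size); that fallback is gone, so every shard
        # is read directly by name.
        return {shard_name: f.get_tensor(f"{shard_name}_shard") for shard_name in shard_names}


# Hypothetical usage:
# shards = read_shards_sketch("rank_0.safetensors", ["weights", "grads"])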
- log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning) - loaded_shards = { - shard_name: f.get_slice("state_shard")[loaded_metadata.shards.index(shard_name)] - for shard_name in shard_names - } - else: - loaded_shards = { - shard_name: f.get_tensor(f"{shard_name}_shard") for shard_name in shard_names - } + loaded_shards = {shard_name: f.get_tensor(f"{shard_name}_shard") for shard_name in shard_names} self._copy_shard_overlaps(loaded_model, loaded_shards, context) diff --git a/fast_llm/engine/config_utils/initialization.py b/fast_llm/engine/config_utils/initialization.py index 7fefda4b0..2f12a45d2 100644 --- a/fast_llm/engine/config_utils/initialization.py +++ b/fast_llm/engine/config_utils/initialization.py @@ -26,16 +26,11 @@ class InitializationConfig(Config, Initialization): is_default: typing.ClassVar[bool] = False @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is InitializationConfig and cls.get_subclass(default.get("type")) is None: # Default subclass. - return DefaultInitializationConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + return DefaultInitializationConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) @config_class() diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 1fc0c626d..1737f4308 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -20,7 +20,7 @@ @config_class() class RunConfig(Config): tensor_logs: TensorLogsConfig = Field(desc="Configuration for debug tensor logs.", hint=FieldHint.logging) - # TODO v0.3: Adjust (now only affects logging to file). + # TODO: Adjust (now only affects logging to file). structured_logs: bool = Field( default=True, desc="Configure logging to the Fast-LLM format.", hint=FieldHint.logging ) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 9ec63517c..602c44a4e 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -104,7 +104,7 @@ class DistributedConfig(Config): """ Configuration for the distributed setup. Also include variables for global settings such as data types, random seeds, initialization parameters. - TODO v0.3: Move these unrelated variables elsewhere. + TODO: Move these unrelated variables elsewhere. TODO: Avoid hard-coding distributed dims (use derived class?) TODO: Separate distributed space from config? 
""" @@ -181,19 +181,19 @@ class DistributedConfig(Config): valid=check_field(Assert.gt, 0), ) seed: int = Field(default=1234, desc="A seed for training.", hint=FieldHint.optional) - # TODO v0.3: Rename to compute_dtype (not just for training), move elsewhere - training_dtype: DataType = Field( + # TODO: Rename to compute_dtype (not just for training), move elsewhere + compute_dtype: DataType = Field( default=DataType.float32, desc="The data type used for the forward and backward passes.", hint=FieldHint.core, ) - # TODO v0.3: move elsewhere + # TODO : move elsewhere optimization_dtype: DataType = Field( default=DataType.float32, desc="The data type used for the optimizer.", hint=FieldHint.expert, ) - # TODO v0.3: move random state elsewhere + # TODO: move random state elsewhere # Extra seed parameters (can usually be left alone) dp_seed_shift: int = Field( default=_BIG_PRIMES[0], desc="Seed shift for extra randomness.", hint=FieldHint.optional @@ -378,13 +378,3 @@ def _log_on_rank[ def log_first_rank[T](self, *message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info): return self._log_on_rank(*message, rank=0, log_fn=log_fn) - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - cls._handle_renamed_field(default, "distributed_timeout", "timeout") - return super()._from_dict(default, strict, flat) diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index dc41539c0..2e2f9d401 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -238,7 +238,7 @@ def check_config(self, config: DistributedConfig) -> None: def set_step(self, step: int, phase: PhaseType) -> None: """ Reseed pytorch for a given training step. - TODO v0.3: Move unrelated content elsewhere. + TODO: Move unrelated content elsewhere. """ seed_shift = step * self._config.sample_seed_shift + self._phase_seeds_shifts[phase] self.pp_generator.manual_seed((self._pp_seed + seed_shift) % MAX_SEED) diff --git a/fast_llm/engine/evaluation/__init__.py b/fast_llm/engine/evaluation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index 04e4227f1..4eb5d71df 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -27,16 +27,11 @@ class EvaluatorConfig(EvaluatorConfigBase): _abstract: typing.ClassVar[bool] = True @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.x: Remove backward compatibility. - if not "type" in default: - default["type"] = "loss" - return super()._from_dict(default, strict, flat) + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if cls is EvaluatorConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. 
+ return LossEvaluatorConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) @config_class(dynamic_type={EvaluatorConfig: "loss"}) diff --git a/fast_llm/engine/evaluation/lm_eval/__init__.py b/fast_llm/engine/evaluation/lm_eval/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/engine/inference/config.py b/fast_llm/engine/inference/config.py index b414323e4..d19e2478d 100644 --- a/fast_llm/engine/inference/config.py +++ b/fast_llm/engine/inference/config.py @@ -25,7 +25,7 @@ def __init__(self, fast_llm_config: FastLLMModelConfig | None = None, **kwargs): self.use_cache = kwargs.pop("use_cache", True) super().__init__(**kwargs) if self.torch_dtype is not None: - assert self.torch_dtype == self.fast_llm_config.distributed.training_dtype.torch + assert self.torch_dtype == self.fast_llm_config.distributed.compute_dtype.torch def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs) -> None: # Hack the method to save at the right place. @@ -90,7 +90,7 @@ def _get_config_dict( updates = {} torch_dtype = kwargs.pop("torch_dtype", None) if torch_dtype is not None: - updates[("distributed", "training_dtype")] = torch_dtype + updates[("distributed", "compute_dtype")] = torch_dtype fast_llm_config = cls.model_config_class.from_metadata( pretrained, metadata, default=kwargs.pop("fast_llm_config", None), updates=updates ) diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 719088057..aa18f5052 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -351,48 +351,28 @@ class CheckpointMetadata(Config): def _validate(self) -> None: if isinstance(self.fast_llm_version, str): self.fast_llm_version = packaging.version.Version(self.fast_llm_version) - + code_version = packaging.version.Version(__version__) self.format = self.model.get_checkpoint_format(self.format) super()._validate() - if self.fast_llm_version.major != 0 or self.fast_llm_version.minor not in (0, 1, 2): - raise ValueError(f"Invalid checkpoint version: {self.fast_llm_version}") + if self.fast_llm_version > code_version: + raise ValueError(f"Unknown checkpoint version: {self.fast_llm_version}") + if self.fast_llm_version < packaging.version.Version("0.3.0"): + raise ValueError( + f"Checkpoint version {self.fast_llm_version} is no longer supported." + " If you really need this checkpoint," + " please convert it to an external model first using a compatible Fast-LLM version." + ) Assert.eq(self.config.__class__, self.model) @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.3: Remove backward compatibility. 
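# A hedged, standalone sketch of the checkpoint version gate introduced in
# CheckpointMetadata._validate: versions newer than the running code are rejected
# as unknown, and anything older than 0.3.0 must be converted with an older
# Fast-LLM release first. The fallback code version string here is illustrative.
import packaging.version


def check_checkpoint_version_sketch(checkpoint_version: str, code_version: str = "0.3.0") -> None:
    version = packaging.version.Version(checkpoint_version)
    if version > packaging.version.Version(code_version):
        raise ValueError(f"Unknown checkpoint version: {version}")
    if version < packaging.version.Version("0.3.0"):
        raise ValueError(
            f"Checkpoint version {version} is no longer supported."
            " Please convert it to an external model first using a compatible Fast-LLM version."
        )


# check_checkpoint_version_sketch("0.2.0")  # raises ValueError: no longer supported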
- cls._handle_renamed_field(default, "checkpoint_type", "format") - cls._handle_renamed_field(default, "checkpoint_version", "fast_llm_version") - cls._handle_renamed_field(default, "fast_llm_config", "config") - cls._handle_renamed_field(default, "state_shard_names", "shards") - if "model" not in default: - default["model"] = "gpt" - if "format" not in default: - default["format"] = DistributedCheckpointFormat - if "fast_llm_version" not in default: - default["fast_llm_version"] = "0" - + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: model_config_class = default["model"] if isinstance(model_config_class, str): model_config_class = FastLLMModelConfig.get_subclass(default["model"]) default["model"] = model_config_class - # TODO v0.3: Remove backward compatibility. - if "config" not in default: - default["config"] = { - "base_model": model_config_class.get_base_model_config_class().from_flat_dict( - default.pop("model_config", {}) - ), - "multi_stage": default.pop("multi_stage_config", {}), - "distributed": default.pop("distributed_config", {}), - } # Instantiate the config with the appropriate class config = default.get("config", {}) if isinstance(config, dict): default["config"] = model_config_class.from_dict(config) - return super()._from_dict(default, strict, flat) + return super()._from_dict(default, strict) diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 09ee788e6..6a6223cb7 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -54,14 +54,10 @@ def from_pretrained( metadata = cls.config_class.load_metadata(pretrained_config) config = cls.config_class.from_dict(metadata.config, *updates, update_type=UpdateType.update) if mode.support_training: - # TODO v0.3: Make metadata.shards mandatory? - if metadata.shards: - if optimizer_state_names is None: - optimizer_state_names = metadata.shards[1:] - else: - Assert.eq(optimizer_state_names, metadata.shards[1:]) - elif optimizer_state_names is None: - raise ValueError("`optimizer_state_names` is required") + if optimizer_state_names is None: + optimizer_state_names = metadata.shards[1:] + else: + Assert.eq(optimizer_state_names, metadata.shards[1:]) else: assert optimizer_state_names is None optimizer_state_names = () diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index cb0a02a67..868cc2db4 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -84,7 +84,7 @@ def __init__( dtype=( self._distributed_config.optimization_dtype if full_precision_shards - else self._distributed_config.training_dtype + else self._distributed_config.compute_dtype ).torch, ) # TODO: Distinguish grad and optimizer shard? 
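# A small sketch of the dtype-selection rule that shows up repeatedly after the
# training_dtype -> compute_dtype rename: full-precision shards and gradient
# buffers fall back to the optimization dtype, everything else uses the compute
# dtype. The default dtypes below are illustrative, not the framework's.
import torch


def pick_buffer_dtype_sketch(
    full_precision: bool,
    optimization_dtype: torch.dtype = torch.float32,
    compute_dtype: torch.dtype = torch.bfloat16,
) -> torch.dtype:
    return optimization_dtype if full_precision else compute_dtype


# pick_buffer_dtype_sketch(True) -> torch.float32; pick_buffer_dtype_sketch(False) -> torch.bfloat16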
@@ -94,13 +94,13 @@ def __init__( dtype=( self._distributed_config.optimization_dtype if full_precision_shards - else self._distributed_config.training_dtype + else self._distributed_config.compute_dtype ).torch, ) self._weight_buffer_meta = TensorMeta.from_dims( (TensorDim("weight_buffer", weight_shard_dim.size * self._fsdp_dim.size),), tensor_name=f"{self._name}_weight_buffer", - dtype=self._distributed_config.training_dtype.torch, + dtype=self._distributed_config.compute_dtype.torch, ) self._grad_buffer_meta = TensorMeta.from_dims( (TensorDim("grad_buffer", weight_shard_dim.size * self._fsdp_dim.size if self._requires_grad else 0),), @@ -108,7 +108,7 @@ def __init__( dtype=( self._distributed_config.optimization_dtype if full_precision_gradient_buffer - else self._distributed_config.training_dtype + else self._distributed_config.compute_dtype ).torch, ) diff --git a/fast_llm/engine/optimizer/optimizer.py b/fast_llm/engine/optimizer/optimizer.py index e72901e6e..0dd094390 100644 --- a/fast_llm/engine/optimizer/optimizer.py +++ b/fast_llm/engine/optimizer/optimizer.py @@ -19,7 +19,7 @@ def get_grad_scaler(config: GradientScalerConfig, distributed: Distributed) -> " initial_scale=config.constant, distributed=distributed, ) - elif distributed.config.training_dtype == DataType.float16: + elif distributed.config.compute_dtype == DataType.float16: return DynamicGradScaler( initial_scale=config.initial, min_scale=config.minimum, diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 8c9e035d9..531bc206b 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -170,17 +170,6 @@ def get_evaluator( return TrainingEvaluator(name, self, batch_config, data_load_num_proc, train_iters) - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.x: Remove backward compatibility. - cls._handle_renamed_field(default, "iterations", ("evaluator", "iterations")) - return super()._from_dict(default, strict, flat) - @config_class() class TrainingCheckpointBaseConfig(IntervalConfig): @@ -234,10 +223,7 @@ class TrainingCheckpointConfig(TrainingCheckpointBaseConfig): keep: int | None = FieldUpdate(default=5) def get_save_directory(self, experiment_directory: pathlib.Path) -> pathlib.Path: - # TODO v0.3: Remove backward compatibility. - old_path = experiment_directory / "checkpoints" - new_path = experiment_directory / "checkpoint" - return old_path if old_path.is_dir() and not new_path.is_dir() else new_path + return experiment_directory / "checkpoint" def get_save_config(self, path: pathlib.Path, timeout: float | None) -> CheckpointSaveConfig: return CheckpointSaveConfig( @@ -329,18 +315,6 @@ class TrainingConfig(Config): valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.x: Remove backward compatibility. 
- cls._handle_renamed_field(default, "validation", ("evaluators", "validation")) - cls._handle_renamed_field(default, "evaluations", ("evaluators")) - return super()._from_dict(default, strict, flat) - def _validate(self) -> None: super()._validate() self.shutdown.assert_sub_interval(self.checkpoint) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 7db9b1fc3..a752bec28 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -541,7 +541,7 @@ def _prepare_training_state(self) -> None: def _save_checkpoint( self, config: TrainingCheckpointBaseConfig, metrics: dict[str, dict[str, float | int]] | None ) -> None: - # TODO v0.3: Move barrier, ok file to FastLLMModel + # TODO: Move barrier, ok file to FastLLMModel checkpoint_base_directory = config.get_save_directory(self._run.experiment_directory) checkpoint_directory = checkpoint_base_directory / str(self._completed_steps) @@ -600,7 +600,7 @@ def _load_checkpoint(self, config: TrainingCheckpointConfig, iteration: int) -> self._completed_steps = metadata["schedules"][PhaseType.training.value]["completed_steps"] else: self._completed_steps = metadata["completed_steps"] - # TODO v0.3: Move barrier, ok file to FastLLMModel + # TODO: Move barrier, ok file to FastLLMModel safe_barrier( self._distributed.world_group, f"load {config.save_name} {iteration} exit", diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 214bb7729..2910c7c76 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -120,7 +120,7 @@ def layer_class(self) -> "type[Attention]": return Attention def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: - return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) + return self.use_flash_attention and distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16) def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. diff --git a/fast_llm/layers/attention/preprocessing.py b/fast_llm/layers/attention/preprocessing.py index 2326b1bf7..204c08ad2 100644 --- a/fast_llm/layers/attention/preprocessing.py +++ b/fast_llm/layers/attention/preprocessing.py @@ -39,8 +39,8 @@ def _create_tensors(self, sequence_length: int, device: torch.device) -> None: self._mask.triu_(-self._config.window_size + 1) self._mask_value = torch.full( [], - torch.finfo(self._distributed_config.training_dtype.torch).min, - dtype=self._distributed_config.training_dtype.torch, + torch.finfo(self._distributed_config.compute_dtype.torch).min, + dtype=self._distributed_config.compute_dtype.torch, device=device, ) @@ -80,7 +80,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: kwargs[AttentionKwargs.attention_mask_value] = TensorMeta.from_dims( (scalar_dim,), tensor_name=AttentionKwargs.attention_mask_value, - dtype=self._distributed_config.training_dtype.torch, + dtype=self._distributed_config.compute_dtype.torch, ) diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 43bae8c54..5bd7a9b87 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -18,16 +18,11 @@ class RotaryConfig(BaseModelConfig): # TODO: Move rotary to its own submodule. 
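# A hedged, standalone illustration of the `_from_dict` pattern this patch applies
# to RotaryConfig, BlockConfig, NormalizationConfig, PeftConfig and others: the
# `flat` argument is dropped, and deserializing the base class without a known
# "type" dispatches to a default subclass. All class names here are placeholders.
class BaseConfigSketch:
    registry: dict = {}

    @classmethod
    def from_dict_sketch(cls, default: dict, strict: bool = True) -> "BaseConfigSketch":
        if cls is BaseConfigSketch and default.get("type") not in cls.registry:
            # Default subclass.
            return DefaultConfigSketch.from_dict_sketch(default, strict)
        return cls()


class DefaultConfigSketch(BaseConfigSketch):
    pass


BaseConfigSketch.registry["default"] = DefaultConfigSketch
# BaseConfigSketch.from_dict_sketch({}) -> DefaultConfigSketch instance
# DefaultConfigSketch.from_dict_sketch({"type": "default"}) -> DefaultConfigSketch instance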
@classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is RotaryConfig and cls.get_subclass(default.get("type")) is None: # Default subclass. - return NoRotaryConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + return NoRotaryConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) def get_layer(self, head_size_dim: TensorDim) -> "Rotary": return self._get_configurable_class()(self, head_size_dim) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 7df2705fa..df5bd8181 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -69,18 +69,13 @@ class BlockConfig(BaseBlockConfig): """ @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is BlockConfig and cls.get_subclass(default.get("type")) is None: from fast_llm.layers.decoder.config import DecoderBlockConfig # Default subclass. - return DecoderBlockConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + return DecoderBlockConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) @property def layer_class(self) -> "type[Block]": @@ -107,16 +102,11 @@ def get_block( @config_class(registry=True) class BlockSequenceConfig(BaseModelConfig): @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is BlockSequenceConfig and cls.get_subclass(default.get("type")) is None: # Default subclass. - return FixedBlockSequenceConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + return FixedBlockSequenceConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) @abc.abstractmethod def __len__(self) -> int: diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index 33cbd9768..c1ced10df 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -52,16 +52,11 @@ def get_layer( return out @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is NormalizationConfig and cls.get_subclass(default.get("type")) is None: # Default subclass. 
- return LayerNormalizationConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + return LayerNormalizationConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) @config_class(dynamic_type={NormalizationConfig: "none"}) @@ -107,20 +102,6 @@ class LayerNormalizationBaseConfig(NormalizationConfig): def module_class(self): pass - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - cls._handle_renamed_field(default, "normalization_type", "type") - cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") - cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") - cls._handle_renamed_field(default, "normalization_implementation", "implementation") - cls._handle_renamed_field(default, "layer_norm_init_range", "initialization_range") - return super()._from_dict(default, strict, flat) - @config_class(dynamic_type={NormalizationConfig: "layer_norm"}) class LayerNormalizationConfig(LayerNormalizationBaseConfig): diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py index d0af61cee..6c7656839 100644 --- a/fast_llm/layers/common/peft/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -15,16 +15,11 @@ class PeftConfig(Config): _abstract = True @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is PeftConfig and cls.get_subclass(default.get("type")) is None: # Default subclass. - return NoPeftConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + return NoPeftConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) def apply_linear( self, diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 2d8cc71fd..5f8131b5c 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -45,18 +45,13 @@ class MLPBaseConfig(BlockWithBiasConfig): _abstract = True @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is MLPBaseConfig and cls.get_subclass(default.get("type")) is None: from fast_llm.layers.decoder.mlp.config import MLPConfig # Default subclass. - return MLPConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + return MLPConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) @config_class(registry=True) @@ -66,18 +61,13 @@ class MixerConfig(BlockWithBiasConfig): """ @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is MixerConfig and cls.get_subclass(default.get("type")) is None: from fast_llm.layers.attention.config import AttentionConfig # Default subclass. 
- return AttentionConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + return AttentionConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) @config_class(dynamic_type={BlockConfig: "decoder"}) diff --git a/fast_llm/layers/decoder/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py index fe4879e73..9dd17d698 100644 --- a/fast_llm/layers/decoder/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -87,7 +87,7 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c dims = (input_.dims[0], kwargs[AttentionKwargs.sequence_q_dim], self._intermediate_2_dim) # Also adjust the dtype in case of full-precision residual layer_2_input = TensorMeta.from_dims( - dims, tensor_name="intermediate_1", dtype=self._distributed_config.training_dtype.torch + dims, tensor_name="intermediate_1", dtype=self._distributed_config.compute_dtype.torch ) # TODO: Add marginal compute? (ex. activation, gate + up) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 849e09aa9..f59b4cffd 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -350,20 +350,6 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.testing, ) - @classmethod - def from_flat_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - ) -> typing.Self: - # The backward compatibility fix in `NormalizationArchitectureConfig` - # won't work for older checkpoints saved with a flat config. - # TODO v0.3: Remove flat format - cls._handle_renamed_field(default, "normalization_type", "type") - cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") - cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") - return super().from_flat_dict(default, strict) - def __len__(self) -> int: return len(self.decoder) + 2 * self.output_layer.prediction_heads diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index e0661cfa2..1d1e13a5b 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -51,7 +51,7 @@ def __init__( self._residual_dtype = ( self._distributed_config.optimization_dtype if self._config.full_precision_residual - else self._distributed_config.training_dtype + else self._distributed_config.compute_dtype ).torch self._sequence_parallel = self._distributed_config.sequence_tensor_parallel self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and self._config.vocab_parallel diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 9cd77ff37..8fbb99cad 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -88,26 +88,6 @@ class GPTBaseModelConfig(LanguageModelBaseConfig): default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing ) - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.3: Remove backward compatibility fix - if "transposed_mlp_weight" in default: - assert default.pop("transposed_mlp_weight") - if "match_megatron" in default: - assert "use_megatron_initialization" not in default - default["use_megatron_initialization"] = default.pop("match_megatron") - if "layer_norm_impl" in default: - assert "normalization_implementation" not in default - default["normalization_implementation"] = 
default.pop("layer_norm_impl") - if "fused_mlp" in default: - del default["fused_mlp"] - return super()._from_dict(default, strict, flat) - @config_class(dynamic_type={FastLLMModelConfig: "gpt"}) class GPTModelConfig(FastLLMModelConfig): @@ -197,29 +177,6 @@ def _validate(self) -> None: ) Assert.geq(output_layer.prediction_heads, output_layer.prediction_heads) - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.x: Remove backward compatibility. - cls._handle_renamed_field( - default, ("data", "sampling", "use_loss_masking_spans"), ("batch", "use_loss_masking_spans") - ) - if "truncate_documents" in default.get("data", {}): - # Backward compatibility for the legacy truncate_documents field. - # TODO v0.x: Remove backward compatibility. - logger.warning( - "`data.truncate_documents` field is deprecated. " "Please use `batch.truncate_documents` instead." - ) - assert "truncate_documents" not in default.get("batch", {}) - if "batch" not in default: - default["batch"] = {} - default["batch"]["truncate_documents"] = default["data"].pop("truncate_documents") - return super()._from_dict(default, strict, flat) - @classmethod def get_trainer_class(cls) -> type["GPTTrainer"]: from fast_llm.models.gpt.trainer import GPTTrainer diff --git a/tests/data/common.py b/tests/data/common.py index 6614accce..d8cc6fff2 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -77,9 +77,8 @@ def get_test_data_and_compare_samples( sequence_length: int = 512, vocab_size=TEST_VOCAB_SIZE, expected_samples: dict[str, list[list[int]]] | list[list[int]], - legacy: bool = False, ) -> GPTData: - distributed_config = DistributedConfig(seed=seed if legacy else 87522) + distributed_config = DistributedConfig(seed=87522) distributed = Distributed(distributed_config, use_cpu=True) if isinstance(samples_per_dataset, int): samples_per_dataset = {PhaseType.training.value.lower(): samples_per_dataset} @@ -97,11 +96,7 @@ def get_test_data_and_compare_samples( expected_samples = {PhaseType.training.value.lower(): expected_samples} assert "sampling" not in config - config["sampling"] = GPTSamplingConfig( - seed=87522 if legacy else seed, - gpu=gpu, - shuffle=shuffle, - ) + config["sampling"] = GPTSamplingConfig(seed=seed, gpu=gpu, shuffle=shuffle) data = GPTData(GPTDataConfig.from_dict(config), distributed_config) data.setup(distributed, sampling_parameters, cache_directory) with NoAutoValidate(): diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 312807aad..e64b47020 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -46,17 +46,6 @@ def _get_blending_alt(probs: list[float], num_samples: int) -> tuple[np.ndarray, [3036, 253, 207, 2968, 4536, 1178], ] -GPT_BLENDED_LEGACY_SAMPLES = [ - [1725, 74, 207, 1635, 4440, 2774], - [359, 489, 4266, 2052, 5351, 80], - [328, 80, 263, 890, 1797, 88], - [374, 7534, 87, 1073, 79, 480], - [8008, 498, 71, 727, 80, 315], - [2210, 8179, 73, 2582, 897, 1178], - [1852, 71, 776, 7878, 7390, 80], - [409, 5091, 328, 1378, 5483, 88], -] - GPT_BLENDED_MIXED_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], [916, 6683, 7685, 1277, 5106, 378], @@ -144,7 +133,7 @@ def test_gpt_blended_data(): get_test_data_and_compare_samples( { "datasets": { - "Training": { + "training": { "type": "blended", "datasets": [ {"type": "memmap", "path": DATASET_PREFIX}, @@ -160,22 +149,6 @@ def test_gpt_blended_data(): ) -def test_gpt_blended_data_legacy(): - get_test_dataset() - 
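# A hedged example of the non-legacy data configuration shape the updated tests
# rely on: dataset entries are keyed by the lowercase phase name ("training") and
# blending is spelled out explicitly instead of the removed "format: list" with
# weight/path pairs. The paths are hypothetical stand-ins for the test prefixes,
# and the weights simply mirror the deleted legacy example.
blended_training_config_sketch = {
    "datasets": {
        "training": {
            "type": "blended",
            "datasets": [
                {"type": "memmap", "path": "path/to/dataset"},
                {"type": "memmap", "path": "path/to/dataset_mix_1"},
            ],
            "weights": [0.75, 0.25],
        }
    }
}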
_get_test_dataset_mix_1() - get_test_data_and_compare_samples( - { - "format": "list", - "path": ["0.75", str(DATASET_PREFIX), "0.25", str(_DATASET_PREFIX_MIX_1)], - "split": [1, 0, 0], - }, - 8, - sequence_length=5, - expected_samples=GPT_BLENDED_LEGACY_SAMPLES, - legacy=True, - ) - - def test_gpt_blended_mixed(): # Make sure dataset blending works and check for unintended changes in behavior. get_test_dataset() @@ -198,7 +171,7 @@ def test_gpt_blended_mixed_data(): get_test_data_and_compare_samples( { "datasets": { - "Training": { + "training": { "type": "blended", "datasets": [{"type": "memmap", "path": DATASET_PREFIX}, {"type": "random"}], "weights": [0.6, 0.4], diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 6cc5d639a..2c025cbaf 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -44,7 +44,7 @@ def test_gpt_concatenate_data(): get_test_data_and_compare_samples( { "datasets": { - "Training": { + "training": { "type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)], } diff --git a/tests/data/test_concatenated_memmap.py b/tests/data/test_concatenated_memmap.py deleted file mode 100644 index 35d93d9d5..000000000 --- a/tests/data/test_concatenated_memmap.py +++ /dev/null @@ -1,78 +0,0 @@ -import pytest - -from fast_llm.data.dataset.gpt.config import GPTConcatenatedMemmapConfig -from tests.data.common import ( - compare_indexed_dataset, - get_dataset_config, - get_sampling_data, - get_test_data_and_compare_samples, - validate_indexed_dataset_sampling, -) -from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES -from tests.utils.dataset import get_test_concatenated_memmap_dataset -from tests.utils.global_variables import DATASET_CACHE - -_DATASET_PREFIX_MIX_CONCATENATED_MEMMAP = DATASET_CACHE / "concatenated_memmap" - - -def _get_test_dataset_concatenated_memmap(): - return get_test_concatenated_memmap_dataset(_DATASET_PREFIX_MIX_CONCATENATED_MEMMAP, 4) - - -CONCATENATED_MEMMAP_DATASET_LENGTH = 24806 -CONCATENATED_MEMMAP_DATASET_TOKENS = 2033639 -CONCATENATED_MEMMAP_DATASET_SAMPLES = { - **MEMMAP_DATASET_SAMPLES, - 6930: [65, 2327], - 11962: [7078, 2713, 1431], - 15958: [207], - 19362: [69], - 24098: [555, 668, 70], -} -CONCATENATED_MEMMAP_SAMPLES = [ - [7554, 80, 5970, 87, 477, 4119], - [4119, 6506, 74, 447, 87, 277], - [277, 320, 2597, 4117, 301, 727], - [727, 330, 3067, 2740, 81, 417], - [417, 1486, 542, 248, 540, 1364], - [1364, 7072, 2516, 2455, 79, 207], - [207, 727, 2204, 2379, 540, 1322], - [1322, 365, 2009, 72, 489, 1886], -] - - -def test_gpt_concatenated_memmap(): - # Make sure dataset splitting works and check for unintended changes in behavior. 
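# A hedged sketch of how the removed `concatenated_memmap` directory + index.txt
# mechanism maps onto the remaining dataset types: an explicit "concatenated"
# dataset over memmap prefixes (the same structure the deprecated config built
# internally). Paths are hypothetical stand-ins for what an index file listed.
concatenated_training_config_sketch = {
    "datasets": {
        "training": {
            "type": "concatenated",
            "datasets": [{"type": "memmap", "path": f"path/to/dataset_{i}"} for i in range(4)],
        }
    }
}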
- _get_test_dataset_concatenated_memmap() - # samples[9:18] - with pytest.warns(DeprecationWarning): - dataset = get_dataset_config( - {"type": "concatenated_memmap", "path": _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP}, - GPTConcatenatedMemmapConfig, - ).build() - compare_indexed_dataset( - dataset, - CONCATENATED_MEMMAP_DATASET_LENGTH, - CONCATENATED_MEMMAP_DATASET_TOKENS, - CONCATENATED_MEMMAP_DATASET_SAMPLES, - ) - sampled = dataset.sample(get_sampling_data(8, sequence_length=5)) - validate_indexed_dataset_sampling(sampled, CONCATENATED_MEMMAP_SAMPLES) - - -def test_gpt_concatenated_memmap_data(): - _get_test_dataset_concatenated_memmap() - with pytest.warns(DeprecationWarning): - get_test_data_and_compare_samples( - { - "datasets": { - "Training": { - "type": "concatenated_memmap", - "path": _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP, - } - } - }, - 8, - sequence_length=5, - expected_samples=CONCATENATED_MEMMAP_SAMPLES, - ) diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 551134fd2..c9212d6e3 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -21,17 +21,6 @@ [86, 49152, 89, 542, 395, 89], ] -GPT_FIM_SAMPLES_LEGACY = [ - [1725, 74, 207, 1635, 4440, 2774], - [359, 489, 4266, 2052, 5351, 80], - [86, 49152, 89, 22255, 1073, 79], - [8008, 498, 71, 727, 80, 315], - [2210, 8179, 73, 2582, 897, 1178], - [86, 89, 88, 49152, 87, 49152], - [86, 49152, 83, 744, 89, 64], - [86, 89, 1461, 49152, 87, 49152], -] - def test_gpt_fim(): # Make sure the FIM wrapper works in a simple case and check for unintended changes in behavior. @@ -63,7 +52,7 @@ def test_gpt_fim_data(): get_test_data_and_compare_samples( { "datasets": { - "Training": { + "training": { "type": "fim", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, "rate": 0.5, @@ -80,21 +69,3 @@ def test_gpt_fim_data(): expected_samples=GPT_FIM_SAMPLES, vocab_size=49157, ) - - -def test_gpt_fim_data_legacy(): - get_test_dataset() - get_test_data_and_compare_samples( - { - "format": "list", - "path": [str(DATASET_PREFIX)], - "fim": {"rate": 0.5, "prefix_token": "w", "middle_token": "x", "pad_token": "y", "suffix_token": "z"}, - "tokenizer": {"path": TOKENIZER_PATH}, - "split": [1, 0, 0], - }, - 8, - sequence_length=5, - expected_samples=GPT_FIM_SAMPLES_LEGACY, - legacy=True, - vocab_size=49157, - ) diff --git a/tests/data/test_random.py b/tests/data/test_random.py index 72a6080a7..8e5c61904 100644 --- a/tests/data/test_random.py +++ b/tests/data/test_random.py @@ -26,7 +26,7 @@ def test_gpt_random_data(): get_test_data_and_compare_samples( { "datasets": { - "Training": { + "training": { "type": "random", } } @@ -35,13 +35,3 @@ def test_gpt_random_data(): sequence_length=7, expected_samples=RANDOM_DATASET_EXPECTED_SAMPLES, ) - - -def test_gpt_random_data_legacy(): - get_test_data_and_compare_samples( - {"format": "random"}, - 4, - sequence_length=7, - expected_samples=RANDOM_DATASET_EXPECTED_SAMPLES, - legacy=True, - ) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index a2996aa1c..6a2be3dcc 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -34,16 +34,6 @@ [1178, 3291, 317, 277, 2679, 89], [89, 542, 395, 583, 684, 554], ] -GPT_MEMMAP_SAMPLES_LEGACY = [ - [1725, 74, 207, 1635, 4440, 2774], - [359, 489, 4266, 2052, 5351, 80], - [374, 7534, 87, 1073, 79, 480], - [8008, 498, 71, 727, 80, 315], - [2210, 8179, 73, 2582, 897, 1178], - [409, 5091, 328, 1378, 5483, 88], - [83, 4457, 3316, 333, 489, 317], - [330, 155, 2449, 1136, 1106, 5370], -] def test_gpt_sampled(): @@ 
-60,7 +50,7 @@ def test_gpt_sampled_data(): get_test_data_and_compare_samples( { "datasets": { - "Training": { + "training": { "type": "memmap", "path": DATASET_PREFIX, } @@ -72,16 +62,6 @@ def test_gpt_sampled_data(): ) -def test_gpt_sampled_data_legacy(): - get_test_data_and_compare_samples( - {"format": "list", "path": [str(DATASET_PREFIX)], "split": [1, 0, 0]}, - 8, - sequence_length=5, - expected_samples=GPT_MEMMAP_SAMPLES_LEGACY, - legacy=True, - ) - - class SimpleGPTIndexedDataset(GPTIndexedDataset): # TODO: worth adding to the main codebase? def __init__(self, samples): diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 1440614cb..1fc8df1eb 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -27,23 +27,6 @@ [3712, 86, 476, 80, 2547, 7390], ] -GPT_SLICE_TRAINING_SAMPLES_LEGACY = [ - [2625, 76, 2625, 2639, 74, 243], - [207, 481, 5546, 74, 414, 498], - [74, 333, 1963, 310, 5337, 3628], - [79, 2361, 80, 2012, 84, 480], -] -GPT_SLICE_VALIDATION_SAMPLES_LEGACY = [ - [2352, 3687, 2311, 4900, 542, 3732], - [2551, 5283, 900, 3140, 328, 68], - [7979, 2283, 329, 727, 2740, 2818], - [4117, 8056, 79, 1798, 243, 498], - [243, 542, 387, 6476, 6686, 785], - [95, 6641, 207, 279, 2304, 602], - [89, 4446, 947, 293, 947, 1544], - [243, 3712, 86, 476, 80, 2547], -] - def test_gpt_slice(): # Make sure dataset splitting works and check for unintended changes in behavior. @@ -89,17 +72,3 @@ def test_gpt_slice_data(): "validation": GPT_SLICE_VALIDATION_SAMPLES, }, ) - - -def test_gpt_slice_data_legacy(): - get_test_dataset() - get_test_data_and_compare_samples( - {"format": "list", "path": [str(DATASET_PREFIX)], "split": [0.0015, 0.0015, 0.997]}, - {"training": 4, "validation": 8, "test": 5}, - sequence_length=5, - expected_samples={ - "training": GPT_SLICE_TRAINING_SAMPLES_LEGACY, - "validation": GPT_SLICE_VALIDATION_SAMPLES_LEGACY, - }, - legacy=True, - ) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index d52564cc0..f14f028e1 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -104,8 +104,8 @@ def _lm_head( ("config_dict", "distributed_config_dict", "loss_masking"), ( ({}, {}, False), - ({}, {"training_dtype": DataType.bfloat16}, False), - ({"embeddings_layer": {"full_precision_residual": True}}, {"training_dtype": DataType.bfloat16}, False), + ({}, {"compute_dtype": DataType.bfloat16}, False), + ({"embeddings_layer": {"full_precision_residual": True}}, {"compute_dtype": DataType.bfloat16}, False), ({"sequence_first": True}, {}, False), ({"output_layer": {"logit_z_loss": 1e-3}}, {}, False), ({"output_layer": {"logits_scale_factor": 5.0}}, {}, False), @@ -195,7 +195,7 @@ def test_lm_head( dtype=( distributed.config.optimization_dtype.torch if config.embeddings_layer.full_precision_residual - else distributed.config.training_dtype.torch + else distributed.config.compute_dtype.torch ), device=distributed.device, requires_grad=True, @@ -239,7 +239,7 @@ def test_lm_head( if config.output_layer.tied_weight or config.output_layer.prediction_heads > 1: logit_weight = ( torch.empty( - VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.training_dtype.torch, device=distributed.device + VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.compute_dtype.torch, device=distributed.device ) .normal_(config.embeddings_layer.hidden_size**-0.5) .requires_grad_(True) @@ -302,9 +302,9 @@ def test_lm_head( output, context = stage.forward(head_input, kwargs, losses) stage.backward(output_grad, context) - threshold = 1e-5 if 
distributed.config.training_dtype == DataType.float32 else 5e-3 + threshold = 1e-5 if distributed.config.compute_dtype == DataType.float32 else 5e-3 min_threshold = ( - 1e-5 if distributed.config.training_dtype == DataType.float32 else 1e-4 + 1e-5 if distributed.config.compute_dtype == DataType.float32 else 1e-4 ) * config.output_layer.logits_scale_factor Assert.eq(losses.keys(), loss_keys) diff --git a/tests/models/test_generate.py b/tests/models/test_generate.py index ad0de47e6..bce77d4f2 100644 --- a/tests/models/test_generate.py +++ b/tests/models/test_generate.py @@ -61,11 +61,11 @@ def _get_fast_llm_model( updates = {} if use_flash_attention: updates[("base_model", "decoder", "block", "mixer", "use_flash_attention")] = True - updates[("distributed", "training_dtype")] = "bf16" + updates[("distributed", "compute_dtype")] = "bf16" else: updates[("base_model", "decoder", "block", "mixer", "use_flash_attention")] = False if use_bf16: - updates[("distributed", "training_dtype")] = "bf16" + updates[("distributed", "compute_dtype")] = "bf16" return HuggingfaceGPTModelForCausalLM.from_pretrained( CheckpointLoadConfig( path=model_path, @@ -87,11 +87,11 @@ def _get_fast_llm_model_from_model( if use_flash_attention: updates[("model", "base_model", "decoder", "block", "mixer", "use_flash_attention")] = True - updates[("model", "distributed", "training_dtype")] = "bf16" + updates[("model", "distributed", "compute_dtype")] = "bf16" else: updates[("model", "base_model", "decoder", "block", "mixer", "use_flash_attention")] = False if use_bf16: - updates[("model", "distributed", "training_dtype")] = "bf16" + updates[("model", "distributed", "compute_dtype")] = "bf16" config = PretrainedGPTModelConfig.from_dict({}, updates) multi_stage = config.model.get_model_class()(config.model) diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index fdb908b0d..6aa541b8c 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -1,7 +1,15 @@ import os +import typing +import numpy as np import pytest +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.data.dataset.abstract import SampledDataset +from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSampledDatasetConfig, GPTSamplingData +from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.gpt.sampled import GPTSample, logger +from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.dataset import get_model_test_dataset from tests.utils.distributed_configs import DistributedTestingConfig @@ -9,6 +17,13 @@ from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda +try: + from fast_llm.csrc.data import build_sample_idx # noqa + + _extension_available = True +except ImportError: + _extension_available = False + @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.megatron) @@ -51,9 +66,9 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co name="match_megatron", compare="megatron", config_args=[ - "model.distributed.training_dtype=fp32", - "data.datasets={}", - f"data.path={MODEL_DATASET_PREFIX}", + "model.distributed.compute_dtype=fp32", + f'data.datasets.training={{"type":"megatron","path":{MODEL_DATASET_PREFIX}}}', + "data.sampling.seed=1234", "model.base_model.use_megatron_initialization=True", ], num_gpus=1, @@ -62,3 +77,83 @@ def test_match_megatron(run_test_script_for_all_models, 
model_testing_config, co run_test_script_for_all_models(distributed_testing_config) compare_results_for_all_models(distributed_testing_config) + + +@config_class(dynamic_type={GPTSampledDatasetConfig: "megatron"}) +class GPTMegatronDatasetConfig(GPTMemmapDatasetConfig): + _abstract: typing.ClassVar[bool] = False + path: str = Field( + desc="Dataset path (prefix).", + hint=FieldHint.core, + ) + + def build(self) -> "GPTMemmapDataset": + return GPTMegatronMemmapDataset( + str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens + ) + + +class GPTMegatronMemmapDataset(GPTMemmapDataset): + def sample(self, sampling: GPTSamplingData) -> "MegatronGPTSampledIndexedDataset": + return MegatronGPTSampledIndexedDataset(self, sampling) + + +class MegatronGPTSampledIndexedDataset(SampledDataset): + """ + A GPT sampled dataset that exactly matches Megatron-LM, for testing purposes. + Minimalistic implementation, implements only the required features. + """ + + def __init__( + self, + indexed_dataset: GPTMegatronMemmapDataset, + sampling: GPTSamplingData, + ): + assert isinstance(sampling, GPTSamplingData) + self._indexed_dataset = indexed_dataset + self._num_samples = sampling.parameters.num_samples + self._sequence_length = sampling.parameters.sequence_length + + logger.info(f" > Sampling dataset {self._indexed_dataset.name} ...") + document_sizes = self._indexed_dataset.get_document_sizes() + num_documents = len(document_sizes) + num_tokens = document_sizes.sum() + np_rng = np.random.RandomState(seed=sampling.config.seed) + + # Assume less than one epoch. + Assert.lt(self._sequence_length * self._num_samples, num_tokens) + + self._doc_idx = np.arange(num_documents, dtype=np.int32) + np_rng.shuffle(self._doc_idx) + + assert _extension_available, ( + "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." 
+ ) + + self._sample_idx = build_sample_idx(document_sizes, self._doc_idx, self._sequence_length, 1, num_tokens, True) + self._shuffle_idx = np.arange(0, self._sample_idx.shape[0] - 1, dtype=np.uint32) + np_rng.shuffle(self._shuffle_idx) + + def __len__(self) -> int: + return self._num_samples + + def __getitem__(self, idx: int) -> typing.Any: + shuffled_idx = self._shuffle_idx[idx] + doc_f, offset_f = self._sample_idx[shuffled_idx] + doc_l, offset_l = self._sample_idx[shuffled_idx + 1] + sample_list = [ + self._indexed_dataset.get( + self._doc_idx[doc].item(), + offset=(doc == doc_f) * offset_f, + length=offset_l + 1 - (doc == doc_f) * offset_f if doc == doc_l else None, + ) + for doc in range(doc_f, doc_l + 1) + ] + token_ids = np.concatenate([sample.token_ids for sample in sample_list], dtype=np.int64) + Assert.eq(len(token_ids), self._sequence_length + 1) + + return GPTSample(token_ids=token_ids) + + @property + def name(self) -> str: + return self._indexed_dataset.name diff --git a/tests/test_attention.py b/tests/test_attention.py index 62c34d3c0..dceaa8282 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -29,7 +29,7 @@ def test_varlen_preprocessor(): micro_sequence_length = 12 sequence_length = 36 varlen_preprocessor = FlashAttnVarlenPreprocessor( - AttentionConfig(head_size=64), DistributedConfig(training_dtype="bfloat16") + AttentionConfig(head_size=64), DistributedConfig(compute_dtype="bfloat16") ) for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): kwargs = { diff --git a/tests/test_config.py b/tests/test_config.py index 4e73569b3..6d2583ba3 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -95,7 +95,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "output_layer": {"tied_weight": False}, }, "multi_stage": {"zero_stage": 3}, - "distributed": {"training_dtype": "bfloat16"}, + "distributed": {"compute_dtype": "bfloat16"}, } ) with NoAutoValidate(): @@ -121,7 +121,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): { "model": { "base_model": base_model_update, - "distributed": {"seed": 1234, "training_dtype": "float16"}, + "distributed": {"seed": 1234, "compute_dtype": "float16"}, }, "pretrained": {"format": "fast_llm", "path": config_path, "load_config": load_config}, } @@ -131,7 +131,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): if load_config == ModelConfigType.fast_llm: expected_config["multi_stage"] = {"zero_stage": 3} - expected_config["distributed"].update({"seed": 1234, "training_dtype": "float16"}) + expected_config["distributed"].update({"seed": 1234, "compute_dtype": "float16"}) if load_config in (ModelConfigType.fast_llm, ModelConfigType.model): expected_config["base_model"] = { "embeddings_layer": { diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index e4cce2935..680faa931 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -66,25 +66,3 @@ def get_model_test_dataset( vocab_size: int = MODEL_TEST_VOCAB_SIZE, ): return get_test_dataset(prefix=prefix, vocab_size=vocab_size) - - -def get_test_concatenated_memmap_dataset( - path: pathlib.Path, - num_files: int, - seed: int = 1234, - num_tokens: int = TEST_DATASET_TOKENS, - characters: str = TEST_CHARACTERS, - vocab_size: int = TEST_VOCAB_SIZE, - seed_shift: int = 55, -): - index_file = path / "index.txt" - if not index_file.is_file(): - for i in range(num_files): - get_test_dataset( - prefix=path / f"dataset_{i}", - seed=seed + i * seed_shift, - 
num_tokens=num_tokens, - characters=characters, - vocab_size=vocab_size, - ) - index_file.open("w").writelines([str(path / f"dataset_{i}") + "\n" for i in range(num_files)]) diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 306beadf8..863be2cae 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -87,14 +87,14 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon DistributedTestingConfig( name="bf16", compare="simple", - config_args=["model.distributed.training_dtype=bf16"], + config_args=["model.distributed.compute_dtype=bf16"], num_gpus=1, compare_config=_bf16_compare, ), DistributedTestingConfig( name="fp16", compare="simple", - config_args=["model.distributed.training_dtype=fp16"], + config_args=["model.distributed.compute_dtype=fp16"], num_gpus=1, compare_config=_fp16_compare, ), diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 55ac4ae74..aa8100126 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -279,6 +279,7 @@ def _update_and_add_testing_config( # Megatron messes with the vocab size, so we have to subtract 1. f"--vocab-size={MODEL_TEST_VOCAB_SIZE - 1}", f"--data-path={MODEL_DATASET_PREFIX}", + "--split=1,0,0", "--lr-decay-style=constant", # Initialization is set up to match MCore models (MCore inverts self-attn qkv and dense layers compared to original Megatron) "--use-mcore-models",