diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index c347a5c70..ed9128c6e 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -74,6 +74,9 @@ class GPTSamplingParameters(SamplingParameters): vocab_size: int use_loss_masking_spans: bool = False cross_document_attention: bool = True + # How many extra tokens to add to the sequence length. + # This is used to provide labels even for the last tokens in the sequence. + extra_tokens: int = 1 @dataclasses.dataclass(kw_only=True) @@ -258,7 +261,7 @@ def build(self) -> SamplableDataset: return config.build() def _load_config(self): - assert self.path.is_file() + assert self.path.is_file(), f"File {self.path} does not exist." return GPTSampledDatasetConfig.from_dict(self._convert_paths(yaml.safe_load(self.path.open("r")))) def _convert_paths(self, config): diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 0dac725b2..f3633a76a 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -145,17 +145,19 @@ def _sample(self) -> None: raise RuntimeError( f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." ) - # TODO MTP: Produce more labels to provide labels for the multi-token prediction heads? - # We produce sequences of length `self._sequence_length + 1` so the last token has a label, - # but in case of truncations we also include that last label in the following sample, - # so we need `sequence_length * num_samples + 1` tokens in total. - num_epochs = math.ceil( - ( - (self._parameters.sequence_length + 1 - self._truncate_documents) * self._parameters.num_samples - + 1 * self._truncate_documents + # We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads, + # but in case of truncations we also include those last labels in the following sample, + # so we need `sequence_length * num_samples + extra_tokens` tokens in total. + if self._truncate_documents: + num_epochs = math.ceil( + (self._parameters.sequence_length * self._parameters.num_samples + self._parameters.extra_tokens) + / tokens_per_epoch + ) + else: + num_epochs = math.ceil( + ((self._parameters.sequence_length + self._parameters.extra_tokens) * self._parameters.num_samples) + / tokens_per_epoch ) - / tokens_per_epoch - ) # Prepare for shuffling. generator = torch.Generator(device=self._device) @@ -349,8 +351,13 @@ def __getitem__(self, index: int) -> typing.Any: self._lazy_load() # tokens at the boundary are included in only one sample when we pack without truncations # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample - token_start = index * (self._parameters.sequence_length + 1 - self._truncate_documents) - token_end = token_start + self._parameters.sequence_length + 1 + sample_length = ( + self._parameters.sequence_length + if self._truncate_documents + else self._parameters.sequence_length + self._parameters.extra_tokens + ) + token_start = index * sample_length + token_end = token_start + self._parameters.sequence_length + self._parameters.extra_tokens if token_start < self._unshuffled_tokens: token_start_array = self._token_cumsum_unshuffled.array @@ -410,7 +417,9 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: span = np.clip( - loss_masking_span + token_count - token_start, 0, self._parameters.sequence_length + 1 + loss_masking_span + token_count - token_start, + 0, + self._parameters.sequence_length + self._parameters.extra_tokens, ) if span[1] > span[0]: loss_masking_spans.append(span) @@ -430,7 +439,7 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.use_loss_masking_spans else None ) - Assert.eq(len(token_ids), self._parameters.sequence_length + 1) + Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index f335015a6..4de3ab3eb 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -1,12 +1,14 @@ import abc import json import pathlib +import shutil import typing import safetensors import torch +from transformers.configuration_utils import PretrainedConfig -from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveMetadataConfig +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, CheckpointSaveMetadataConfig from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, ExternalStateDictCheckpointHandler, @@ -118,3 +120,33 @@ def _load_weights( yield from torch.load(path) else: raise NotImplementedError(f"Unknown file format for {path}") + + +class CustomModelingExportMixin: + """ + Mixin class for HuggingfaceStateDictCheckpointHandler to handle custom modeling files. + """ + + modeling_file: typing.ClassVar[str] + configuration_file: typing.ClassVar[str] + configuration_cls: typing.ClassVar[type[PretrainedConfig]] + + # Use custom config instead of relying on the transformers library + @classmethod + def _load_config(cls, directory: pathlib.Path | str) -> dict: + config = cls.configuration_cls.from_pretrained(directory).to_dict() + Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + return config + + @classmethod + def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + cls.configuration_cls.from_dict(config).save_pretrained(directory) + + def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: + super().save(config, metadata) + self._copy_modeling_files(config) + + def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None: + # Copy the modeling files to the output directory + shutil.copy(self.modeling_file, config.path) + shutil.copy(self.configuration_file, config.path) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 1286121c3..3cc348d06 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -70,11 +70,6 @@ def __init__( Assert.geq(prediction_distance, 0) self._prediction_distance = prediction_distance self.is_last_head = self._prediction_distance == config.prediction_heads - 1 - if self._prediction_distance > 0: - assert ( - not self._sequence_parallel_logits - ), "Sequence parallel logits not supported for multi-token prediction." - assert not self._cross_entropy_splits, "Cross-entropy splits not supported for multi-token prediction." self._init_output_weights(hidden_dim, config) @@ -137,8 +132,9 @@ def forward( # Last head should return the loss for backward. return language_model_loss else: - # Backward hook to compute the gradient of the loss - shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0) + if self.training: + # Backward hook to compute the gradient of the loss + shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0) # MTP: Return shared_hidden to be used by the next head. return shared_hidden @@ -147,18 +143,22 @@ def _forward_backward( ) -> tuple[torch.Tensor, torch.Tensor | None]: labels = kwargs[LanguageModelKwargs.labels] if LanguageModelKwargs.labels in kwargs else None # MTP: Shift the labels - labels = labels[:, self._prediction_distance :].flatten() if labels is not None else None + if labels is not None: + labels = ( + labels[self._prediction_distance : self._prediction_distance + input_.size(0),] + if kwargs[TransformerKwargs.sequence_first] + else labels[ + :, + self._prediction_distance : self._prediction_distance + input_.size(1), + ] + ) + labels = labels.flatten() if self._sequence_parallel_logits: labels = split_op(labels, self._tensor_space.distributed.tensor_group, 0) do_grad = labels is not None and self.training input_ = input_.detach().requires_grad_(do_grad) with torch.enable_grad(): - # MTP: truncate the input - if self._prediction_distance > 0: - truncated_input = input_[:, : -self._prediction_distance, :].contiguous() - else: - truncated_input = input_ - ln_output = self.final_norm(truncated_input) + ln_output = self.final_norm(input_) grad_output = kwargs[TransformerKwargs.grad_output] / ( self._group_size if self._sequence_parallel_logits else 1 @@ -197,7 +197,7 @@ def _logits_cross_entropy_forward_backward_split( ) if labels is None: # TODO: Make a proper way of returning the model output. - kwargs["logits"] = loss + kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = loss return None, None else: loss = None diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index b78c3311b..14e598be0 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -19,6 +19,7 @@ class GPTHuggingfaceCheckpointFormat(CheckpointFormat): support_optimizer: typing.ClassVar[bool] = False + trust_remote_code: typing.ClassVar[bool] = False @classmethod def get_handler_class(cls) -> type[CheckpointHandler]: @@ -51,6 +52,11 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mixtral" +class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "mtp_llama" + trust_remote_code: typing.ClassVar[bool] = True + + @config_class() class GPTArchitectureConfig(LanguageModelArchitectureConfig): _abstract = False @@ -145,6 +151,7 @@ class GPTModelConfig(FastLLMModelConfig): Qwen2GPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, + MTPLlamaGPTHuggingfaceCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 30ae80416..bc8bea266 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -4,6 +4,7 @@ import typing import torch +from transformers.configuration_utils import PretrainedConfig from fast_llm.config import DEFAULT, MISSING from fast_llm.engine.checkpoint.config import CheckpointFormat @@ -20,7 +21,7 @@ SplitWeightConverter, WeightConverter, ) -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex @@ -32,9 +33,11 @@ LlamaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, + MTPLlamaGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) +from fast_llm.models.gpt.external.mtp_llama.configuration_mtp_llama import MTPLlamaConfig from fast_llm.models.gpt.model import GPTModel from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert @@ -173,44 +176,46 @@ def _create_weight_converters( converters += self._create_lm_head_converters() for i in range(num_layers): - converters += self._create_transformer_layer_converters(i) + converters += self._create_transformer_layer_converters(f"layers.{i+1}", f"model.layers.{i}") return converters - def _create_transformer_layer_converters(self, i: int, ignore_export: bool = False) -> list[WeightConverter]: + def _create_transformer_layer_converters( + self, fast_llm_layer_name: str, hf_layer_name: str, ignore_export: bool = False + ) -> list[WeightConverter]: transformer_config: TransformerConfig = self._model.config.base_model.transformer norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm converters = [] names_bias_cls = [ # Self-attn ( - f"layers.{i+1}.self_attn.query", - f"model.layers.{i}.self_attn.q_proj", + f"{fast_llm_layer_name}.self_attn.query", + f"{hf_layer_name}.self_attn.q_proj", transformer_config.add_attn_qkv_bias, QueryWeightConverter, ), ( - f"layers.{i+1}.self_attn.key_value", - (f"model.layers.{i}.self_attn.k_proj", f"model.layers.{i}.self_attn.v_proj"), + f"{fast_llm_layer_name}.self_attn.key_value", + (f"{hf_layer_name}.self_attn.k_proj", f"{hf_layer_name}.self_attn.v_proj"), transformer_config.add_attn_qkv_bias, KeyValueWeightConverter, ), ( - f"layers.{i+1}.self_attn.dense", - f"model.layers.{i}.self_attn.o_proj", + f"{fast_llm_layer_name}.self_attn.dense", + f"{hf_layer_name}.self_attn.o_proj", transformer_config.add_attn_dense_bias, WeightConverter, ), # Norm ( - f"layers.{i+1}.norm_1", - f"model.layers.{i}.input_layernorm", + f"{fast_llm_layer_name}.norm_1", + f"{hf_layer_name}.input_layernorm", norm_bias, WeightConverter, ), ( - f"layers.{i+1}.norm_2", - f"model.layers.{i}.post_attention_layernorm", + f"{fast_llm_layer_name}.norm_2", + f"{hf_layer_name}.post_attention_layernorm", norm_bias, WeightConverter, ), @@ -226,14 +231,20 @@ def _create_transformer_layer_converters(self, i: int, ignore_export: bool = Fal # MLP if ignore_export: converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mlp.layer_1", (), transformer_config.add_mlp_bias, cls=IgnoreExportWeightConverter + f"{fast_llm_layer_name}.mlp.layer_1", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mlp.layer_2", (), transformer_config.add_mlp_bias, cls=IgnoreExportWeightConverter + f"{fast_llm_layer_name}.mlp.layer_2", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, ) - converters += [IgnoreExportWeightConverter(f"layers.{i+1}.mlp.router.weight", ())] + converters += [IgnoreExportWeightConverter(f"{fast_llm_layer_name}.mlp.router.weight", ())] else: - converters += self._get_mlp_converters(f"layers.{i+1}", f"model.layers.{i}") + converters += self._get_mlp_converters(f"{fast_llm_layer_name}", f"{hf_layer_name}") return converters def _create_lm_head_converters(self) -> list[WeightConverter]: @@ -260,7 +271,9 @@ def _create_lm_head_converters(self) -> list[WeightConverter]: ) mtp_transformer_layer_index = num_layers - 1 + 2 * i # MTP transformer layer - converters += self._create_transformer_layer_converters(mtp_transformer_layer_index, ignore_export=True) + converters += self._create_transformer_layer_converters( + f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True + ) # MTP output norm converters += self._get_weight_and_bias_converters( f"layers.{mtp_transformer_layer_index + 2}.final_norm", (), norm_bias, IgnoreExportWeightConverter @@ -570,6 +583,88 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] +class MTPLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlamaHuggingfaceCheckpointHandler): + from fast_llm.models.gpt.external.mtp_llama import configuration_mtp_llama, modeling_mtp_llama + + format: typing.ClassVar[type[CheckpointFormat]] = MTPLlamaGPTHuggingfaceCheckpointFormat + modeling_file = modeling_mtp_llama.__file__ + configuration_file = configuration_mtp_llama.__file__ + configuration_cls: typing.ClassVar[type[PretrainedConfig]] = MTPLlamaConfig + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter(export_names=(("architectures",),), export_value=["MTPLlamaForCausalLM"]), + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_mtp_llama.MTPLlamaConfig", + "AutoModel": "modeling_mtp_llama.MTPLlamaModel", + "AutoModelForCausalLM": "modeling_mtp_llama.MTPLlamaForCausalLM", + }, + ), + # TODO: Llama supports biases + ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), + ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), + RenameParamConverter( + fast_llm_names=(("prediction_heads",),), + export_names=(("prediction_heads",),), + ), + ] + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + transformer_config: TransformerConfig = self._model.config.base_model.transformer + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + transformer_config.add_mlp_bias, + SplitWeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + transformer_config.add_mlp_bias, + MLPLayer2Converter, + ), + ] + + # Override base method to handle the MTP heads + def _create_lm_head_converters(self) -> list[WeightConverter]: + num_layers = self._model.config.base_model.transformer.num_layers + prediction_heads = self._model.config.base_model.prediction_heads + norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm + converters = [] + + # Next-token prediction head + # Transformer layer is already handled in the transformer layer converters + # Final norm + converters += self._get_weight_and_bias_converters( + f"layers.{num_layers + 1}.final_norm", "model.mtp_norms.0", norm_bias + ) + # Multi-token prediction head + for i in range(1, prediction_heads): + mtp_transformer_layer_index = num_layers - 1 + 2 * i + # MTP transformer layer + converters += self._create_transformer_layer_converters( + f"layers.{mtp_transformer_layer_index + 1}", + f"model.mtp_heads.{i - 1}", + ) + # MTP output norm + converters += self._get_weight_and_bias_converters( + f"layers.{mtp_transformer_layer_index + 2}.final_norm", + f"model.mtp_norms.{i}", + norm_bias, + ) + # Output weights + if self._model.config.base_model.tie_word_embeddings: + converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) + else: + converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) + + return converters + + class AutoGPTHuggingfaceCheckpointHandler( AutoStateDictCheckpointHandler, HuggingfaceStateDictCheckpointHandler, abc.ABC ): @@ -580,4 +675,5 @@ class AutoGPTHuggingfaceCheckpointHandler( Qwen2GPTHuggingfaceCheckpointFormat.name: Qwen2HuggingfaceCheckpointHandler, MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler, MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, + MTPLlamaGPTHuggingfaceCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/external/mtp_llama/configuration_mtp_llama.py b/fast_llm/models/gpt/external/mtp_llama/configuration_mtp_llama.py new file mode 100644 index 000000000..5b23f4053 --- /dev/null +++ b/fast_llm/models/gpt/external/mtp_llama/configuration_mtp_llama.py @@ -0,0 +1,202 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation + + +class MTPLlamaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MTPLlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MTPLlamaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, + Llama 2 up to 4096, CodeLlama up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to + understand more about it. This value is necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + head_dim (`int`, *optional*): + The attention head dimension. If None, it will default to hidden_size // num_attention_heads + + ```python + >>> from transformers import LlamaModel, LlamaConfig + + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = LlamaConfig() + + >>> # Initializing a model from the llama-7b style configuration + >>> model = LlamaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mtp_llama" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `LlamaModel` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + head_dim=None, + prediction_heads=1, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.prediction_heads = prediction_heads + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fast_llm/models/gpt/external/mtp_llama/modeling_mtp_llama.py b/fast_llm/models/gpt/external/mtp_llama/modeling_mtp_llama.py new file mode 100644 index 000000000..5ad99ff96 --- /dev/null +++ b/fast_llm/models/gpt/external/mtp_llama/modeling_mtp_llama.py @@ -0,0 +1,963 @@ +from functools import partial +from typing import Callable, Optional, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, QuestionAnsweringModelOutput +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + LossKwargs, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg + +from .configuration_mtp_llama import MTPLlamaConfig + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MTPLlamaConfig" + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, config: MTPLlamaConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MTPLlamaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: MTPLlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = MTPLlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class MTPLlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: MTPLlamaConfig + """ + + def __init__(self, config: MTPLlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + # MTP heads + self.mtp_heads = nn.ModuleList( + [ + LlamaDecoderLayer(config, layer_idx) + for layer_idx in range( + config.num_hidden_layers, config.num_hidden_layers + config.prediction_heads - 1 + ) + ] + ) + + self.mtp_norms = nn.ModuleList( + [LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(config.prediction_heads)] + ) + # LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + return_all_prediction_heads: bool = False, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + # MTP: The last layer is not part of the shared trunk + for decoder_layer in self.layers[:-1]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + partial(decoder_layer.__call__, **flash_attn_kwargs), + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # MTP heads + if return_all_prediction_heads: + layers_to_run = [self.layers[-1]] + list(self.mtp_heads) + else: + layers_to_run = [self.layers[-1]] + latents = [] + for i, decoder_layer in enumerate(layers_to_run): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + mtp_hidden_states = layer_outputs[0] + latents.append(self.mtp_norms[i](mtp_hidden_states)) + + if return_all_prediction_heads: + # (batch, seq, len(layers_to_run), hidden_size) + hidden_states = torch.stack(latents, dim=-2) + else: + # (batch, seq, hidden_size) + assert len(latents) == 1 + hidden_states = latents[0] + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class MTPLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = MTPLlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + # (*, num_prediction_heads, hidden_size) + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = MTPLlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> QuestionAnsweringModelOutput: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + + outputs: BaseModelOutputWithPast = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs.last_hidden_state + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 7d9e59a4d..a7ec58d67 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -231,6 +231,7 @@ def preprocess( _, common_kwargs = preprocessed_meta[0] sequence_q = common_kwargs[TransformerKwargs.sequence_q_dim].size sequence_first = common_kwargs[TransformerKwargs.sequence_first] + prediction_heads: int = self._config.prediction_heads batch.token_ids = batch.token_ids.to( device=self._tensor_space.distributed.device, @@ -265,20 +266,22 @@ def preprocess( if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 if sequence_first: - labels = batch.token_ids[sequence_offset : sequence_k + 1] + labels = batch.token_ids[sequence_offset : sequence_k + prediction_heads] else: # TODO: Avoid multiple contiguous calls? - labels = batch.token_ids[:, sequence_k - sequence_q + 1 : sequence_k + 1].contiguous() + labels = batch.token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config if batch.loss_masking_spans is not None: for i, spans in enumerate(batch.loss_masking_spans): if not spans.numel(): continue - valid_spans = spans[(spans[:, 0] <= sequence_k) & (spans[:, 1] >= sequence_offset)] + valid_spans = spans[ + (spans[:, 0] <= sequence_k + prediction_heads - 1) & (spans[:, 1] >= sequence_offset) + ] if valid_spans.numel(): valid_spans[:, 0].clamp_(min=sequence_offset) - valid_spans[:, 1].clamp_(max=sequence_k) + valid_spans[:, 1].clamp_(max=sequence_k + prediction_heads - 1) valid_spans -= sequence_offset for start, end in valid_spans: if sequence_first: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index a269f5a63..a1c0c8bb7 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -49,6 +49,7 @@ def _get_sampling_parameters( "sequence_length": self._config.batch.sequence_length, "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, "cross_document_attention": self._config.batch.cross_document_attention, + "extra_tokens": self._config.model.base_model.prediction_heads, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) @@ -61,7 +62,8 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, sequence_length = self._config.batch.sequence_length tokens = self._config.batch.batch_size * sequence_length - transformer_flops_base = 2 * checkpoint_activations_factor * tokens * transformer_config.num_layers + num_transformer_layers = transformer_config.num_layers + self._config.model.base_model.prediction_heads - 1 + transformer_flops_base = 2 * checkpoint_activations_factor * tokens * num_transformer_layers dense_flops_base = transformer_flops_base * transformer_config.hidden_size # Query, key, value, dense. flops_per_iteration = ( @@ -79,7 +81,13 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, ) # LM-head - flops_per_iteration += 6 * tokens * transformer_config.hidden_size * self._config.model.base_model.vocab_size + flops_per_iteration += ( + 6 + * tokens + * transformer_config.hidden_size + * self._config.model.base_model.vocab_size + * self._config.model.base_model.prediction_heads + ) # Attention-matrix computation attn_flops_base = transformer_flops_base * transformer_config.projection_size diff --git a/tests/common.py b/tests/common.py index 211be004a..dfdee9642 100644 --- a/tests/common.py +++ b/tests/common.py @@ -17,6 +17,7 @@ LlamaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, + MTPLlamaGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) @@ -264,7 +265,7 @@ CONFIG_LLAMA_MTP_FAST_LLM, CONFIG_LLAMA_MTP_MEGATRON, CONFIG_LLAMA_MTP_COMMON, - LlamaGPTHuggingfaceCheckpointFormat, + MTPLlamaGPTHuggingfaceCheckpointFormat, ), } diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index d446f4142..c5b350192 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -354,7 +354,9 @@ def test_run_converted_model(): ) errors = [] compare = CompareConfig() - model_as_hf = transformers.AutoModelForCausalLM.from_pretrained(_CONVERT_PATH / "huggingface_0").cuda() + model_as_hf = transformers.AutoModelForCausalLM.from_pretrained( + _CONVERT_PATH / "huggingface_0", trust_remote_code=HUGGINGFACE_CHECKPOINT_FORMAT.trust_remote_code + ).cuda() for name, model in zip( ("From state dict", "From Huggingface", "Native Huggingface"), (model_from_fast_llm, model_from_hf, model_as_hf),