5 changes: 4 additions & 1 deletion fast_llm/data/dataset/gpt/config.py
@@ -74,6 +74,9 @@ class GPTSamplingParameters(SamplingParameters):
vocab_size: int
use_loss_masking_spans: bool = False
cross_document_attention: bool = True
# How many extra tokens to add to the sequence length.
# This is used to provide labels even for the last tokens in the sequence.
extra_tokens: int = 1


@dataclasses.dataclass(kw_only=True)
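For intuition, a throwaway sketch (not from this PR) of what the extra tokens buy: a sampled window of `sequence_length + extra_tokens` tokens yields full-length inputs plus one shifted label sequence per prediction head; tying `extra_tokens` to the number of heads is an assumption made here for illustration.

sequence_length, extra_tokens = 8, 2   # extra_tokens = number of prediction heads (assumption)
window = list(range(sequence_length + extra_tokens))
inputs = window[:sequence_length]
# Head `d` predicts token t + 1 + d, so its labels are the window shifted by d + 1.
labels_per_head = [window[d + 1 : d + 1 + sequence_length] for d in range(extra_tokens)]
assert all(len(labels) == sequence_length for labels in labels_per_head)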
@@ -258,7 +261,7 @@ def build(self) -> SamplableDataset:
return config.build()

def _load_config(self):
assert self.path.is_file()
assert self.path.is_file(), f"File {self.path} does not exist."
return GPTSampledDatasetConfig.from_dict(self._convert_paths(yaml.safe_load(self.path.open("r"))))

def _convert_paths(self, config):
37 changes: 23 additions & 14 deletions fast_llm/data/dataset/gpt/sampled.py
@@ -145,17 +145,19 @@ def _sample(self) -> None:
raise RuntimeError(
f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}."
)
# TODO MTP: Produce more labels to provide labels for the multi-token prediction heads?
# We produce sequences of length `self._sequence_length + 1` so the last token has a label,
# but in case of truncations we also include that last label in the following sample,
# so we need `sequence_length * num_samples + 1` tokens in total.
num_epochs = math.ceil(
(
(self._parameters.sequence_length + 1 - self._truncate_documents) * self._parameters.num_samples
+ 1 * self._truncate_documents
# We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads,
# but in case of truncations we also include those last labels in the following sample,
# so we need `sequence_length * num_samples + extra_tokens` tokens in total.
if self._truncate_documents:
num_epochs = math.ceil(
(self._parameters.sequence_length * self._parameters.num_samples + self._parameters.extra_tokens)
/ tokens_per_epoch
)
else:
num_epochs = math.ceil(
((self._parameters.sequence_length + self._parameters.extra_tokens) * self._parameters.num_samples)
/ tokens_per_epoch
)
/ tokens_per_epoch
)

# Prepare for shuffling.
generator = torch.Generator(device=self._device)
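A quick standalone check of the accounting above, with made-up numbers (`extra_tokens = 4` stands in for four prediction heads, and the dataset size is invented): with truncation, samples tile the token stream and only one trailing block of labels is needed overall; without truncation, every sample carries its own `extra_tokens` tail.

import math

sequence_length, num_samples, extra_tokens = 1024, 100, 4
tokens_per_epoch = 34_200  # assumed dataset size, just for the arithmetic

# Packing with truncation: consecutive samples share their boundary tokens.
epochs_truncate = math.ceil((sequence_length * num_samples + extra_tokens) / tokens_per_epoch)

# Packing without truncation: each sample consumes its own full window.
epochs_no_truncate = math.ceil(((sequence_length + extra_tokens) * num_samples) / tokens_per_epoch)

print(epochs_truncate, epochs_no_truncate)  # 3 4 (102404 vs 102800 tokens needed)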
@@ -349,8 +351,13 @@ def __getitem__(self, index: int) -> typing.Any:
self._lazy_load()
# tokens at the boundary are included in only one sample when we pack without truncations
# in case of packing with truncations, the last token from the previous sample is also the first token of the next sample
token_start = index * (self._parameters.sequence_length + 1 - self._truncate_documents)
token_end = token_start + self._parameters.sequence_length + 1
sample_length = (
self._parameters.sequence_length
if self._truncate_documents
else self._parameters.sequence_length + self._parameters.extra_tokens
)
token_start = index * sample_length
token_end = token_start + self._parameters.sequence_length + self._parameters.extra_tokens

if token_start < self._unshuffled_tokens:
token_start_array = self._token_cumsum_unshuffled.array
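For concreteness, a small sketch of the stride logic above with assumed values: truncation reuses the trailing labels of one sample as the start of the next, while packing without truncation keeps the windows disjoint.

sequence_length, extra_tokens = 8, 1

def window(index: int, truncate_documents: bool) -> tuple[int, int]:
    # Stride is `sequence_length` when truncating (windows overlap by `extra_tokens`),
    # and the full window size when packing without truncation (windows are disjoint).
    stride = sequence_length if truncate_documents else sequence_length + extra_tokens
    start = index * stride
    return start, start + sequence_length + extra_tokens

print(window(0, True), window(1, True))    # (0, 9) (8, 17): token 8 is shared
print(window(0, False), window(1, False))  # (0, 9) (9, 18): no overlap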
@@ -410,7 +417,9 @@ def __getitem__(self, index: int) -> typing.Any:
if self._parameters.use_loss_masking_spans:
for loss_masking_span in sample.loss_masking_spans:
span = np.clip(
loss_masking_span + token_count - token_start, 0, self._parameters.sequence_length + 1
loss_masking_span + token_count - token_start,
0,
self._parameters.sequence_length + self._parameters.extra_tokens,
)
if span[1] > span[0]:
loss_masking_spans.append(span)
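And a tiny numeric check of the clipping step, with assumed values: a loss-masking span is shifted from document coordinates into sample coordinates and clipped to the sample window; spans that land entirely outside collapse to zero length and are skipped by the `span[1] > span[0]` test.

import numpy as np

sequence_length, extra_tokens = 8, 1
token_start, token_count = 16, 20     # sample window starts at token 16; this document starts at 20
loss_masking_span = np.array([1, 6])  # span boundaries in document coordinates

span = np.clip(loss_masking_span + token_count - token_start, 0, sequence_length + extra_tokens)
print(span)  # [5 9]: kept, because span[1] > span[0]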
@@ -430,7 +439,7 @@ def __getitem__(self, index: int) -> typing.Any:
if self._parameters.use_loss_masking_spans
else None
)
Assert.eq(len(token_ids), self._parameters.sequence_length + 1)
Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens)

return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths)

34 changes: 33 additions & 1 deletion fast_llm/engine/checkpoint/huggingface.py
@@ -1,12 +1,14 @@
import abc
import json
import pathlib
import shutil
import typing

import safetensors
import torch
from transformers.configuration_utils import PretrainedConfig

from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveMetadataConfig
from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, CheckpointSaveMetadataConfig
from fast_llm.engine.checkpoint.external import (
ConstantExportParamConverter,
ExternalStateDictCheckpointHandler,
@@ -118,3 +120,33 @@ def _load_weights(
yield from torch.load(path)
else:
raise NotImplementedError(f"Unknown file format for {path}")


class CustomModelingExportMixin:
"""
Mixin class for HuggingfaceStateDictCheckpointHandler to handle custom modeling files.
"""

modeling_file: typing.ClassVar[str]
configuration_file: typing.ClassVar[str]
configuration_cls: typing.ClassVar[type[PretrainedConfig]]

# Use custom config instead of relying on the transformers library
@classmethod
def _load_config(cls, directory: pathlib.Path | str) -> dict:
config = cls.configuration_cls.from_pretrained(directory).to_dict()
Assert.eq(config["model_type"], cls.get_huggingface_model_type())
return config

@classmethod
def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None:
cls.configuration_cls.from_dict(config).save_pretrained(directory)

def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None:
super().save(config, metadata)
self._copy_modeling_files(config)

def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None:
# Copy the modeling files to the output directory
shutil.copy(self.modeling_file, config.path)
shutil.copy(self.configuration_file, config.path)
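As a usage sketch only (the class name, file paths, and config class below are placeholders, not part of this diff), a format-specific handler would mix `CustomModelingExportMixin` into the Hugging Face handler and point it at the custom modeling files:

# Hypothetical handler; the name, paths, and config class are placeholders.
class MyMTPLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler):
    modeling_file: typing.ClassVar[str] = "external/mtp_llama/modeling_mtp_llama.py"            # placeholder path
    configuration_file: typing.ClassVar[str] = "external/mtp_llama/configuration_mtp_llama.py"  # placeholder path
    configuration_cls: typing.ClassVar[type[PretrainedConfig]] = MTPLlamaConfig                 # placeholder config class

On `save()`, the mixin first writes the checkpoint as usual, then copies both files next to it, so the exported directory can be loaded as remote code.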
30 changes: 15 additions & 15 deletions fast_llm/layers/language_model/head.py
@@ -70,11 +70,6 @@ def __init__(
Assert.geq(prediction_distance, 0)
self._prediction_distance = prediction_distance
self.is_last_head = self._prediction_distance == config.prediction_heads - 1
if self._prediction_distance > 0:
assert (
not self._sequence_parallel_logits
), "Sequence parallel logits not supported for multi-token prediction."
assert not self._cross_entropy_splits, "Cross-entropy splits not supported for multi-token prediction."

self._init_output_weights(hidden_dim, config)

@@ -137,8 +132,9 @@ def forward(
# Last head should return the loss for backward.
return language_model_loss
else:
# Backward hook to compute the gradient of the loss
shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0)
if self.training:
# Backward hook to compute the gradient of the loss
shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0)
# MTP: Return shared_hidden to be used by the next head.
return shared_hidden

@@ -147,18 +143,22 @@ def _forward_backward(
) -> tuple[torch.Tensor, torch.Tensor | None]:
labels = kwargs[LanguageModelKwargs.labels] if LanguageModelKwargs.labels in kwargs else None
# MTP: Shift the labels
labels = labels[:, self._prediction_distance :].flatten() if labels is not None else None
if labels is not None:
labels = (
labels[self._prediction_distance : self._prediction_distance + input_.size(0),]
if kwargs[TransformerKwargs.sequence_first]
else labels[
:,
self._prediction_distance : self._prediction_distance + input_.size(1),
]
)
labels = labels.flatten()
if self._sequence_parallel_logits:
labels = split_op(labels, self._tensor_space.distributed.tensor_group, 0)
do_grad = labels is not None and self.training
input_ = input_.detach().requires_grad_(do_grad)
with torch.enable_grad():
# MTP: truncate the input
if self._prediction_distance > 0:
truncated_input = input_[:, : -self._prediction_distance, :].contiguous()
else:
truncated_input = input_
ln_output = self.final_norm(truncated_input)
ln_output = self.final_norm(input_)

grad_output = kwargs[TransformerKwargs.grad_output] / (
self._group_size if self._sequence_parallel_logits else 1
@@ -197,7 +197,7 @@ def _logits_cross_entropy_forward_backward_split(
)
if labels is None:
# TODO: Make a proper way of returning the model output.
kwargs["logits"] = loss
kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = loss
return None, None
else:
loss = None
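To make the label slicing in `_forward_backward` concrete, a self-contained sketch with assumed shapes (batch-first layout, and a label tensor laid out upstream with `sequence_length + extra_tokens - 1` positions, which is an assumption): head `d` is trained to predict token `t + 1 + d`, so its labels simply start `d` positions later than the first head's and cover the same number of positions as the input.

import torch

batch, sequence_length, extra_tokens, hidden = 2, 6, 2, 4   # assumed shapes
# Assumed label layout: one label per position, with room for every head.
labels = torch.arange(batch * (sequence_length + extra_tokens - 1)).reshape(batch, -1)
input_ = torch.zeros(batch, sequence_length, hidden)        # batch-first hidden states

for prediction_distance in range(extra_tokens):
    # Head `d` predicts token t + 1 + d, so its labels start `d` positions later
    # and span exactly input_.size(1) positions.
    head_labels = labels[:, prediction_distance : prediction_distance + input_.size(1)]
    assert head_labels.shape == (batch, sequence_length)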
7 changes: 7 additions & 0 deletions fast_llm/models/gpt/config.py
@@ -19,6 +19,7 @@

class GPTHuggingfaceCheckpointFormat(CheckpointFormat):
support_optimizer: typing.ClassVar[bool] = False
trust_remote_code: typing.ClassVar[bool] = False

@classmethod
def get_handler_class(cls) -> type[CheckpointHandler]:
@@ -51,6 +52,11 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
name: typing.ClassVar[str] = "mixtral"


class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
name: typing.ClassVar[str] = "mtp_llama"
trust_remote_code: typing.ClassVar[bool] = True


@config_class()
class GPTArchitectureConfig(LanguageModelArchitectureConfig):
_abstract = False
@@ -145,6 +151,7 @@ class GPTModelConfig(FastLLMModelConfig):
Qwen2GPTHuggingfaceCheckpointFormat,
MistralGPTHuggingfaceCheckpointFormat,
MixtralGPTHuggingfaceCheckpointFormat,
MTPLlamaGPTHuggingfaceCheckpointFormat,
)

@classmethod
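Downstream of the `trust_remote_code` flag, a checkpoint exported in the `mtp_llama` format ships its own modeling and configuration files (copied by `CustomModelingExportMixin` above), so loading it back through transformers needs remote code enabled. A minimal sketch, with a placeholder path:

from transformers import AutoModelForCausalLM

# Placeholder path to an exported checkpoint directory that contains the copied
# modeling_*.py / configuration_*.py files alongside the weights.
model = AutoModelForCausalLM.from_pretrained(
    "/path/to/exported/mtp_llama_checkpoint",
    trust_remote_code=True,  # the model type is not built into transformers
)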