5 changes: 5 additions & 0 deletions src/MaxText/configs/base.yml
@@ -477,6 +477,11 @@ tokenize_train_data: True # False if the dataset is pre-tokenized
tokenize_eval_data: True # False if the dataset is pre-tokenized
add_bos: True
add_eos: True
# If False, use chunking for long sequences instead of truncation.
# Note: use_truncation=False is only available in grain's pretrain preprocessing pipeline.
# See the TokenizeAndTrim and TokenizeAndChunk classes in
# `src/MaxText/input_pipeline/_grain_tokenizer.py` for implementation details.
use_truncation: True

# Dataset
per_device_batch_size: 12.0
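For intuition, here is a toy sketch (made-up token ids, not part of this change) of what the flag selects between when max_target_length is 5:

# Illustrative only: the token ids are invented; just the slicing mirrors TokenizeAndTrim / TokenizeAndChunk.
token_ids = list(range(1, 13))   # pretend the tokenizer produced 12 tokens
max_target_length = 5

# use_truncation: True  -> TokenizeAndTrim keeps only the first max_target_length tokens
trimmed = token_ids[:max_target_length]      # [1, 2, 3, 4, 5]; tokens 6-12 are dropped

# use_truncation: False -> TokenizeAndChunk emits one example per chunk, so no tokens are lost
chunks = [token_ids[i:i + max_target_length] for i in range(0, len(token_ids), max_target_length)]
# [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12]]

The final chunk may be shorter than max_target_length; downstream packing or padding handles that, as the new tests exercise with PadOrTrimToMaxLength.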
29 changes: 22 additions & 7 deletions src/MaxText/input_pipeline/_grain_data_processing.py
@@ -95,9 +95,7 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra
dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))

assert len(data_columns) == 1
rekey_dict = {"inputs": "text", "targets": "text"}
dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict))
data_columns = ("inputs", "targets")
text_column = data_columns[0]

tokenizer_model = tokenizer.build_tokenizer(
config.tokenizer_path,
@@ -115,11 +113,28 @@
pad_id = -1

if tokenize:
dataset = dataset.map(
if config.use_truncation:
dataset = dataset.map(
_grain_tokenizer.TokenizeAndTrim(
data_columns, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model
text_column, config.max_target_length, tokenizer_model
)
)
)
else:
dataset = grain.experimental.WithOptionsIterDataset(
dataset,
options=grain.experimental.DatasetOptions()
)
dataset = grain.experimental.apply_transformations(
dataset,
_grain_tokenizer.TokenizeAndChunk(
text_column, config.max_target_length, tokenizer_model
)
)

data_columns = ("inputs", "targets")
rekey_dict = {col: text_column for col in data_columns}
dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict))

# Pack and Batch examples.
batch_size = config.global_batch_size_to_load // jax.process_count()
if config.packing:
@@ -173,7 +188,7 @@ def dpo_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_wo
if tokenize:
dataset = dataset.map(
_grain_tokenizer.TokenizeAndTrim(
data_columns, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model
data_columns, config.max_target_length, tokenizer_model
)
)

71 changes: 59 additions & 12 deletions src/MaxText/input_pipeline/_grain_tokenizer.py
@@ -24,35 +24,34 @@


@dataclasses.dataclass
class TokenizeAndTrim(grain.MapTransform):
"""Tokenize and trim features to sequence length."""
class TokenizerTransformBase:
"""Base class for tokenizer transforms with common functionality."""

# pylint: disable=attribute-defined-outside-init
feature_names: str | Sequence[str]
sequence_length: int | Sequence[int]
add_bos: bool
add_eos: bool
tokenizer: tokenizer.SentencePieceTokenizerGrain | tokenizer.HFTokenizer

def __post_init__(self):
self._processor = None
self._initialize_processor_lock = threading.Lock()
# Convert single values to lists for consistent processing
if isinstance(self.feature_names, str):
self.feature_names = [self.feature_names]
if isinstance(self.sequence_length, int):
self.sequence_length = [self.sequence_length] * len(self.feature_names)

def map(self, element: dict[str, Any]) -> dict[str, Any]:
"""Maps to each element."""
def _get_processor(self):
if self._processor is None:
with self._initialize_processor_lock:
if self._processor is None: # Ensures only one thread initializes SPP.
if self._processor is None: # Ensures only one thread initializes processor.
self._processor = self.tokenizer
for feature_name, sequence_length in zip(self.feature_names, self.sequence_length, strict=True):
text = element[feature_name]
token_ids = self._processor.encode(text)[:sequence_length]
element[feature_name] = np.asarray(token_ids, dtype=np.int32)
return element
return self._processor

def _encode(self, text: str) -> list[int]:
"""Common method to encode text using the tokenizer."""
processor = self._get_processor()
return processor.encode(text)

def __getstate__(self):
state = self.__dict__.copy()
@@ -64,3 +63,51 @@ def __setstate__(self, state):
self.__dict__.update(state)
self._processor = None
self._initialize_processor_lock = threading.Lock()


@dataclasses.dataclass
class TokenizeAndTrim(TokenizerTransformBase, grain.MapTransform):
"""Tokenize and trim features to sequence length."""

def __post_init__(self):
super().__post_init__()

def map(self, element: dict[str, Any]) -> dict[str, Any]:
"""Maps to each element."""
for feature_name, max_length in zip(self.feature_names, self.sequence_length, strict=True):
text = element[feature_name]
token_ids = self._encode(text)[:max_length]
element[feature_name] = np.asarray(token_ids, dtype=np.int32)
return element


@dataclasses.dataclass
class TokenizeAndChunk(TokenizerTransformBase, grain.experimental.FlatMapTransform):
"""Tokenize and chunk features into multiple examples of sequence length."""

max_fan_out: int = 2048

def __post_init__(self):
super().__post_init__()
# TokenizeAndChunk only supports single feature for chunking
assert len(self.feature_names) == 1, "TokenizeAndChunk only supports single feature name"
assert len(self.sequence_length) == 1, "TokenizeAndChunk only supports single sequence length"
self.feature_name = self.feature_names[0] # For backward compatibility
self.sequence_length = self.sequence_length[0] # Convert back to int for chunking

def flat_map(self, element: dict[str, Any]) -> list[dict[str, Any]]:
text = element[self.feature_name]
chunk_size = self.sequence_length

token_ids = self._encode(text)

if not token_ids:
return []

output_elements = []
for start_idx in range(0, len(token_ids), chunk_size):
chunk = np.asarray(token_ids[start_idx : start_idx + chunk_size], dtype=np.int32)
new_element = {self.feature_name: chunk}
output_elements.append(new_element)

return output_elements
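As a quick standalone illustration of the new 1:N transform (hypothetical stand-in tokenizer; not MaxText code), TokenizeAndChunk can be exercised directly:

# Hypothetical usage sketch; MockTok is a stand-in tokenizer, not part of MaxText.
from MaxText.input_pipeline import _grain_tokenizer

class MockTok:
  def encode(self, text: str) -> list[int]:
    return [int(tok) for tok in text.split()]

chunker = _grain_tokenizer.TokenizeAndChunk(
    feature_names="text", sequence_length=4, tokenizer=MockTok()
)
elements = chunker.flat_map({"text": "1 2 3 4 5 6 7 8 9"})
# Three output elements, each {"text": int32 array}: [1 2 3 4], [5 6 7 8], [9]

The end-to-end behaviour inside a grain dataset is covered by the new tests below.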
152 changes: 152 additions & 0 deletions tests/tokenizer_transform_test.py
@@ -0,0 +1,152 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Tests for tokenizer
"""

import unittest

import grain.python as grain
import numpy as np
from MaxText.input_pipeline import _grain_tokenizer
from MaxText.input_pipeline import _input_pipeline_utils
from numpy.testing import assert_array_equal


class MockTokenizer:
"""
Mocks a tokenizer by splitting on space and mapping letters to simple ints.
e.g., "a b c" -> [1, 2, 3]
"""
def encode(self, text: str) -> list[int]:
if not text:
return []
# Simple 'a'=1, 'b'=2, ... mapping
return [ord(c) - ord('a') + 1 for c in text.split(' ')]


class TokenizerTransformTest(unittest.TestCase):
"""Tests for chunking, trimming, and padding transformations."""

def setUp(self):
self.max_len = 5
self.pad_length = 7
self.pad_id = 0
self.feature_names = "text"
self.mock_tokenizer = MockTokenizer()
self.source_data = [
{"text": "a b c"},
{"text": "d e f g h i j"},
{"text": ""},
{"text": "k l m n o p q r s t"}
]
self.base_ds = grain.MapDataset.source(self.source_data).to_iter_dataset()

def test_tokenize_and_trim(self):
"""Tests the 1:1 MapTransform (truncation) logic."""
trim_op = _grain_tokenizer.TokenizeAndTrim(
feature_names=self.feature_names,
sequence_length=self.max_len,
tokenizer=self.mock_tokenizer
)
trim_ds = self.base_ds.map(trim_op)
results = list(trim_ds)
self.assertEqual(len(results), len(self.source_data))
expected_inputs = [
np.array([1, 2, 3], dtype=np.int32),
np.array([4, 5, 6, 7, 8], dtype=np.int32),
np.array([], dtype=np.int32),
np.array([11, 12, 13, 14, 15], dtype=np.int32)
]
result_inputs = [r["text"] for r in results]
self.assertEqual(len(result_inputs), len(expected_inputs))
for res, exp in zip(result_inputs, expected_inputs):
assert_array_equal(res, exp)

def test_tokenize_and_chunk(self):
"""Tests the 1:N FlatMapTransform (chunking) logic."""
chunk_op = _grain_tokenizer.TokenizeAndChunk(
feature_names=self.feature_names,
sequence_length=self.max_len,
tokenizer=self.mock_tokenizer
)
chunk_ds = self.base_ds.apply(chunk_op)
results = list(chunk_ds)
self.assertEqual(len(results), 5)
expected_inputs = [
np.array([1, 2, 3], dtype=np.int32),
np.array([4, 5, 6, 7, 8], dtype=np.int32),
np.array([9, 10], dtype=np.int32),
np.array([11, 12, 13, 14, 15], dtype=np.int32),
np.array([16, 17, 18, 19, 20], dtype=np.int32)
]
result_inputs = [r["text"] for r in results]
self.assertEqual(len(result_inputs), len(expected_inputs))
for res, exp in zip(result_inputs, expected_inputs):
assert_array_equal(res, exp)

def test_trim_and_pad_chaining(self):
"""Tests chaining TokenizeAndTrim.map() -> PadOrTrimToMaxLength.map()"""
trim_op = _grain_tokenizer.TokenizeAndTrim(
feature_names=self.feature_names,
sequence_length=self.max_len,
tokenizer=self.mock_tokenizer
)
pad_op = _input_pipeline_utils.PadOrTrimToMaxLength(
max_length=self.pad_length,
pad_id=self.pad_id
)
chained_ds = self.base_ds.map(trim_op).map(pad_op)
results = list(chained_ds)
self.assertEqual(len(results), len(self.source_data))
expected_inputs = [
np.array([1, 2, 3, 0, 0, 0, 0], dtype=np.int32),
np.array([4, 5, 6, 7, 8, 0, 0], dtype=np.int32),
np.array([0, 0, 0, 0, 0, 0, 0], dtype=np.int32),
np.array([11, 12, 13, 14, 15, 0, 0], dtype=np.int32)
]
result_inputs = [r["text"] for r in results]
self.assertEqual(len(result_inputs), len(expected_inputs))
for res, exp in zip(result_inputs, expected_inputs):
assert_array_equal(res, exp)

def test_chunk_and_pad_chaining(self):
"""Tests chaining TokenizeAndChunk.apply() -> PadOrTrimToMaxLength.map()"""
chunk_op = _grain_tokenizer.TokenizeAndChunk(
feature_names=self.feature_names,
sequence_length=self.max_len,
tokenizer=self.mock_tokenizer
)
pad_op = _input_pipeline_utils.PadOrTrimToMaxLength(
max_length=self.pad_length,
pad_id=self.pad_id
)
chained_ds = self.base_ds.apply(chunk_op).map(pad_op)
results = list(chained_ds)
self.assertEqual(len(results), 5)
expected_inputs = [
np.array([1, 2, 3, 0, 0, 0, 0], dtype=np.int32),
np.array([4, 5, 6, 7, 8, 0, 0], dtype=np.int32),
np.array([9, 10, 0, 0, 0, 0, 0], dtype=np.int32),
np.array([11, 12, 13, 14, 15, 0, 0], dtype=np.int32),
np.array([16, 17, 18, 19, 20, 0, 0], dtype=np.int32),
]
result_inputs = [r["text"] for r in results]
self.assertEqual(len(result_inputs), len(expected_inputs))
for res, exp in zip(result_inputs, expected_inputs):
assert_array_equal(res, exp)


if __name__ == "__main__":
unittest.main()