
Conversation

@bzantium (Contributor) commented Sep 16, 2025

This PR introduces "chunking" as an alternative to "truncation" in the Grain input pipeline.

Previously, the TokenizeAndTrim operation (MapTransform) would truncate any document longer than max_target_length, discarding all subsequent tokens. This change introduces a new TokenizeAndChunk operation (FlatMapTransform) that splits a single long document into multiple training examples, each no longer than max_target_length.

This new behavior is controlled by a new configuration flag, use_truncation.

  • Why is this change being made?
    The default truncation behavior is highly data-inefficient for corpora with many long documents (like C4). It wastes significant amounts of data, compute, and storage, and may bias the model by only ever training on the beginning of documents.

  • The problem being solved and any relevant context:
    This PR solves the problem of data loss during tokenization for long sequences. By using a 1:N FlatMapTransform, we can map one long input document to a list of multiple, valid training chunks, ensuring 100% of the tokenized data is used.

  • Why this is a good solution:
    This solution is efficient and flexible. It utilizes the FlatMapTransform provided by Grain, which is designed for this 1:N mapping. It is also fully backwards-compatible, as the new chunking behavior is "opt-in" by setting use_truncation = False in the config. The default behavior remains truncation.

  • Some information about the specific implementation:

    1. _grain_tokenizer.py: A new TokenizeAndChunk class has been added. It inherits from grain.experimental.FlatMapTransform and implements the flat_map method to split a list of token IDs into multiple chunks (a standalone sketch of this splitting logic follows this list).
    2. _grain_data_processing.py: The pretrain_preprocessing_pipeline function has been updated with a conditional check for config.use_truncation:
      • If True, it uses the existing dataset.map(TokenizeAndTrim(...)).
      • If False, it uses dataset.apply(TokenizeAndChunk(...)).
    3. Requirement: The dataset.apply() method and support for FlatMapTransform are recent features in Grain. This PR requires a version of Grain installed directly from the main branch.
      pip install git+https://github.com/google/grain.git
  • Shortcomings of the solution and possible future improvements.
    The max_fan_out attribute in TokenizeAndChunk is set with a class-level default (2048). If a document is exceptionally long and produces more chunks than this, it will error. This could be exposed as a configuration option in the future if needed.
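To make the difference concrete, here is a minimal, dependency-free sketch of the 1:N splitting logic. Class and variable names here are illustrative only; the PR's actual TokenizeAndChunk subclasses grain.experimental.FlatMapTransform and runs the configured tokenizer, whereas this sketch assumes the token IDs are already computed.

from dataclasses import dataclass

@dataclass
class ChunkSketch:
  # Illustrative stand-in for TokenizeAndChunk; operates on pre-tokenized IDs.
  max_target_length: int
  max_fan_out: int = 2048  # upper bound on chunks produced from one document

  def flat_map(self, token_ids):
    """Split one token sequence into consecutive chunks of at most max_target_length."""
    chunks = [
        token_ids[i : i + self.max_target_length]
        for i in range(0, len(token_ids), self.max_target_length)
    ]
    if len(chunks) > self.max_fan_out:
      raise ValueError(f"{len(chunks)} chunks exceed max_fan_out={self.max_fan_out}")
    return chunks

# A 7-token document becomes two examples (5 and 2 tokens); truncation would
# keep only the first 5 tokens and discard the rest.
print(ChunkSketch(max_target_length=5).flat_map([1, 2, 3, 4, 5, 6, 7]))
# -> [[1, 2, 3, 4, 5], [6, 7]]

In the pipeline this reduces to the conditional described in item 2 above: dataset.map(TokenizeAndTrim(...)) when use_truncation is True, and dataset.apply(TokenizeAndChunk(...)) otherwise.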

Tests

This change is tested with a new, self-contained unit test file: tests/tokenizer_transform_test.py.

  • This test does not require real data (like C4) or JAX/TPU.
  • It uses a MockTokenizer to provide known, deterministic tokenization ("a b c" -> [1, 2, 3]).
  • It uses an in-memory grain.MapDataset.source with a small, known dataset to test edge cases (short text, long text, and multi-chunk text).
  • Four separate test cases were added to verify the logic:
    1. test_tokenize_and_trim: Verifies the original 1:1 truncation logic is correct.
    2. test_tokenize_and_chunk: Verifies the new 1:N chunking logic (e.g., an input with 7 tokens and max_len=5 correctly produces two new examples with 5 and 2 tokens; a minimal sketch of this expectation follows this list).
    3. test_trim_and_pad_chaining: Verifies that the output of TokenizeAndTrim can be correctly chained into a subsequent PadToMaxLength transform.
    4. test_chunk_and_pad_chaining: Verifies that all outputs from TokenizeAndChunk are correctly chained into PadToMaxLength (e.g., both the 5-token chunk and the 2-token chunk are correctly padded).
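For illustration, the core expectation of test case 2 can be written as a self-contained unit test. The real tests/tokenizer_transform_test.py drives TokenizeAndChunk with a MockTokenizer on an in-memory grain.MapDataset; this sketch checks only the chunking arithmetic in isolation, so the helper and test names are illustrative.

import unittest

def chunk(token_ids, max_len):
  """Split a token sequence into consecutive pieces of at most max_len tokens."""
  return [token_ids[i : i + max_len] for i in range(0, len(token_ids), max_len)]

class ChunkingSketchTest(unittest.TestCase):

  def test_seven_tokens_become_two_examples(self):
    # Mock tokenization: "a b c d e f g" -> [1, 2, 3, 4, 5, 6, 7]
    self.assertEqual(chunk([1, 2, 3, 4, 5, 6, 7], max_len=5),
                     [[1, 2, 3, 4, 5], [6, 7]])

  def test_short_text_is_a_single_chunk(self):
    self.assertEqual(chunk([1, 2, 3], max_len=5), [[1, 2, 3]])

if __name__ == "__main__":
  unittest.main()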

To reproduce, you can run the new test file directly:

python -m unittest tests/tokenizer_transform_test.py

Fixes: #2344

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@aireenmei (Collaborator) left a comment


Thanks for the feature! And great unit tests! Just some minor comments.

@bzantium (Contributor, Author)

Thanks for the great feedback!

I've pushed the changes addressing all your points:

  • base.yml comment: Added the comment to the use_truncation flag to point to the implementation classes as you suggested.
  • Tokenizer Refactoring:
    • I've simplified both TokenizeAndTrim and TokenizeAndChunk to operate on a single text_column.
    • The Rekey transform has been moved to execute after the tokenization step for both code paths, which cleans up the pipeline logic nicely.
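As a rough, hypothetical illustration of that ordering (the actual Rekey transform, column names, and pipeline wiring in MaxText may differ):

def rekey(example, key_map):
  """Return a copy of example with keys renamed according to key_map."""
  return {new_key: example[old_key] for old_key, new_key in key_map.items()}

def tokenize_then_rekey(raw_example, encode, text_column="text"):
  # Tokenize the single text_column first...
  tokenized = {text_column: encode(raw_example[text_column])}
  # ...then run the Rekey-style renaming on the tokenized output.
  return rekey(tokenized, {text_column: "inputs"})

print(tokenize_then_rekey({"text": "a b c"}, encode=lambda s: [1, 2, 3]))
# -> {'inputs': [1, 2, 3]}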

to: @aireenmei

@aireenmei (Collaborator)

Looks like the GitHub Actions tests need to be triggered by a maintainer. Please take a look at the test failures; you can also run them locally.

- Added comment that TokenizeAndChunk removes all columns except the text_column
- Modified _grain_tokenizer.py with latest changes
- Added note that use_truncation=False is only available in grain's pretrain preprocessing pipeline
- Move feature_names, sequence_length, add_bos, add_eos, and tokenizer to TokenizerTransformBase
- Consolidate initialization logic in base class __post_init__
- Simplify TokenizeAndTrim and TokenizeAndChunk by removing duplicate parameters
- Add common _encode method to eliminate code duplication
- Maintain backward compatibility and specialized behavior for each class
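A rough sketch of what the base-class refactor described in these commit notes could look like. Field names and method bodies are reconstructed from the notes and the discussion above, not taken from the actual _grain_tokenizer.py; the real classes also plug into Grain's MapTransform / FlatMapTransform interfaces, and add_bos/add_eos are omitted here since a later comment removes them.

import dataclasses
from typing import Any

@dataclasses.dataclass
class TokenizerTransformBase:
  # Shared configuration consolidated in the base class, per the commit notes.
  text_column: str
  sequence_length: int
  tokenizer: Any  # anything exposing .encode(str) -> list of token IDs

  def __post_init__(self):
    # Shared initialization/validation instead of duplicating it in each subclass.
    if self.sequence_length <= 0:
      raise ValueError("sequence_length must be positive")

  def _encode(self, text):
    # Common tokenization helper used by both subclasses.
    return self.tokenizer.encode(text)

@dataclasses.dataclass
class TokenizeAndTrim(TokenizerTransformBase):
  def map(self, example):
    # 1:1 -- keep only the first sequence_length tokens.
    example[self.text_column] = self._encode(example[self.text_column])[: self.sequence_length]
    return example

@dataclasses.dataclass
class TokenizeAndChunk(TokenizerTransformBase):
  max_fan_out: int = 2048

  def flat_map(self, example):
    # 1:N -- split into consecutive chunks of at most sequence_length tokens.
    ids = self._encode(example[self.text_column])
    return [
        {self.text_column: ids[i : i + self.sequence_length]}
        for i in range(0, len(ids), self.sequence_length)
    ]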
@aireenmei (Collaborator) left a comment


LGTM, thanks! Pls make sure the tests are passing before merging

@bzantium (Contributor, Author) commented Sep 22, 2025

@aireenmei yeah, I think it all looks good for now. Thanks for the detailed reviews!

@bzantium (Contributor, Author)

I think dataset.apply requires the latest Grain, so we need to change requirements.txt to something like: grain[parquet] @ https://github.com/google/grain/archive/bab58eabf0c94d16002b5d82dda8d8320edd3c7b.zip

@aireenmei (Collaborator)

The GitHub runner uses requirements_with_jax_ai_image.txt. Could you change that file in this PR and see if that fixes the test?

@bzantium requested a review from parambole as a code owner, September 24, 2025 23:22
@Rohan-Bierneni (Collaborator)

Hi @bzantium, can you rebase your PR and run the tests again? There were some changes to the testing infra last week that require a rebase for successful image builds in the unit tests.

@bzantium requested a review from suexu1025 as a code owner, October 1, 2025 23:25
@bzantium (Contributor, Author) commented Oct 2, 2025

@aireenmei @Rohan-Bierneni @SurbhiJainUSC
I revised the code to use WithOptionsIterDataset and apply_transformations instead of dataset.apply. Grain needs a build step that generates Python pb2 files from its protos; installing directly from the latest commit skips that step, while the PyPI releases include the generated files, so I decided to use grain==0.2.12.
Moreover, I found that add_bos/add_eos are applied when the tokenizer is built, not in the transformation, so I removed those arguments. I've checked that the current code works fine on my TPU cluster.

@aireenmei (Collaborator)

Thanks for the update! If grain==0.2.12 doesn't work, let me know if we need to request a new release from the grain team.

@bzantium (Contributor, Author) commented Oct 2, 2025

@aireenmei thanks for the fast reply! I've checked that the current implementation works fine with grain==0.2.12 on tpu-v6e, but if the grain team could release a new version, that would be nice for this PR (the code could be made neater, as before) and for other upcoming features I want to add soon.

Successfully merging this pull request may close these issues.

[feature request] Support chunking (splitting) long sequences instead of truncation during tokenization