feat(input_pipeline): Add support for chunking long sequences instead of truncation #2354
Conversation
Thanks for the feature! And great unit tests! Just some minor comments.
Thanks for the great feedback! I've pushed the changes addressing all your points:
to: @aireenmei
Looks like the GitHub Actions tests need to be triggered by a maintainer. Please take a look at the test failures. You can also run them locally.
- Added comment that TokenizeAndChunk removes all columns except the text_column
- Modified _grain_tokenizer.py with latest changes
- Added note that use_truncation=False is only available in grain's pretrain preprocessing pipeline
- Moved feature_names, sequence_length, add_bos, add_eos, and tokenizer to TokenizerTransformBase
- Consolidated initialization logic in the base class __post_init__
- Simplified TokenizeAndTrim and TokenizeAndChunk by removing duplicate parameters
- Added a common _encode method to eliminate code duplication
- Maintained backward compatibility and specialized behavior for each class (a rough sketch of this structure follows below)
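(For illustration, a rough, hypothetical sketch of the consolidated base class described above, assuming a dataclass-style layout; only `TokenizerTransformBase`, `_encode`, `feature_names`, `sequence_length`, and `tokenizer` are names taken from this thread, everything else is illustrative, and `add_bos`/`add_eos` are omitted since a later commit moves them into the tokenizer itself.)

```python
# Sketch only; not the PR's exact code.
import dataclasses
from typing import Any


@dataclasses.dataclass
class TokenizerTransformBase:
  """Fields and helpers shared by TokenizeAndTrim and TokenizeAndChunk."""
  feature_names: list[str]
  sequence_length: list[int]
  tokenizer: Any  # anything exposing .encode(text) -> list[int]

  def __post_init__(self):
    # Consolidated initialization: pair each feature with its max length.
    self._max_len = dict(zip(self.feature_names, self.sequence_length))

  def _encode(self, text: str) -> list[int]:
    # Common tokenization path; the subclasses differ only in whether they
    # trim the ids to one example (1:1) or chunk them into several (1:N).
    return list(self.tokenizer.encode(text))
```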
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks! Pls make sure the tests are passing before merging
@aireenmei Yeah, I think it all looks good for now. Thanks for the detailed reviews!
I think
The GitHub runner uses requirements_with_jax_ai_image.txt. Could you change it in the PR and see if that fixes the test?
Hi @bzantium, can you rebase your PR and run the tests again? There were some changes to the testing infra last week that require a rebase for successful image builds in unit tests.
…os since they are used by the tokenizer itself, not the tokenizer transform
@aireenmei @Rohan-Bierneni @SurbhiJainUSC
Thanks for the update! If grain==0.2.12 doesn't work, let me know if we need to request a new release from the grain team.
@aireenmei Thanks for the fast reply! I've checked that the current implementation works fine with grain==0.2.12 on TPU v6e, but if it's possible for the grain team to release a new version, that would be nice both for this PR (it would make the code as neat as before) and for other upcoming features I want to add soon.
This PR introduces "chunking" as an alternative to "truncation" in the Grain input pipeline.
Previously, the `TokenizeAndTrim` operation (a `MapTransform`) would truncate any document longer than `max_target_length`, discarding all subsequent tokens. This change introduces a new `TokenizeAndChunk` operation (a `FlatMapTransform`) that splits a single long document into multiple training examples, each no longer than `max_target_length`. This new behavior is controlled by a new configuration flag, `use_truncation`.
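To illustrate the intended 1:N splitting with a minimal, hypothetical helper (plain Python; `chunk_ids` is not a function in this PR):

```python
def chunk_ids(token_ids: list[int], max_target_length: int) -> list[list[int]]:
  """Split one token sequence into consecutive chunks of at most max_target_length tokens."""
  return [token_ids[i : i + max_target_length]
          for i in range(0, len(token_ids), max_target_length)]

# Truncation would keep only the first chunk; chunking keeps all of them.
print(chunk_ids(list(range(1, 8)), 5))  # [[1, 2, 3, 4, 5], [6, 7]]
```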
Why is this change being made?
The default truncation behavior is highly data-inefficient for corpora with many long documents (like C4). It wastes significant amounts of data, compute, and storage, and may bias the model by only ever training on the beginning of documents.
The problem being solved and any relevant context:
This PR solves the problem of data loss during tokenization for long sequences. By using a 1:N `FlatMapTransform`, we can map one long input document to a list of multiple, valid training chunks, ensuring 100% of the tokenized data is used.
Why this is a good solution:
This solution is efficient and flexible. It utilizes the `FlatMapTransform` provided by Grain, which is designed for this 1:N mapping. It is also fully backwards-compatible, as the new chunking behavior is "opt-in" by setting `use_truncation = False` in the config. The default behavior remains truncation.
Some information about the specific implementation:
- `_grain_tokenizer.py`: A new `TokenizeAndChunk` class has been added. It inherits from `grain.experimental.FlatMapTransform` and implements the `flat_map` method to split a list of token IDs into multiple chunks.
- `_grain_data_processing.py`: The `pretrain_preprocessing_pipeline` function has been updated with a conditional check for `config.use_truncation` (see the sketch after this list):
  - If `True`, it uses the existing `dataset.map(TokenizeAndTrim(...))`.
  - If `False`, it uses `dataset.apply(TokenizeAndChunk(...))`.
- Note: the `dataset.apply()` method and support for `FlatMapTransform` are recent features in Grain. This PR requires a version of Grain installed directly from the main branch.
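A rough sketch of that branch (argument lists elided; the module and class names follow the descriptions above, not the exact diff):

```python
# Sketch of the use_truncation branch inside pretrain_preprocessing_pipeline; not the PR's exact code.
if config.use_truncation:
  # 1:1 mapping: each document yields exactly one (possibly truncated) example.
  dataset = dataset.map(_grain_tokenizer.TokenizeAndTrim(...))
else:
  # 1:N mapping: each document may yield several chunks, each <= max_target_length.
  dataset = dataset.apply(_grain_tokenizer.TokenizeAndChunk(...))
```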
Shortcomings of the solution and possible future improvements.
The `max_fan_out` attribute in `TokenizeAndChunk` is set with a class-level default (`2048`). If a document is exceptionally long and produces more chunks than this, it will error. This could be exposed as a configuration option in the future if needed.
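For context, each document produces roughly `ceil(num_tokens / max_target_length)` chunks, so with, say, `max_target_length = 2048`, the default `max_fan_out = 2048` would only be exceeded by documents of roughly four million tokens or more.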
Tests
This change is tested with a new, self-contained unit test file: `tests/tokenizer_transform_test.py`.
- Uses a `MockTokenizer` to provide known, deterministic tokenization (`"a b c" -> [1, 2, 3]`).
- Uses `grain.MapDataset.source` with a small, known dataset to test edge cases (short text, long text, and multi-chunk text).
- `test_tokenize_and_trim`: Verifies the original 1:1 truncation logic is correct.
- `test_tokenize_and_chunk`: Verifies the new 1:N chunking logic (e.g., an input with 7 tokens and `max_len=5` correctly produces two new examples with 5 and 2 tokens).
- `test_trim_and_pad_chaining`: Verifies that the output of `TokenizeAndTrim` can be correctly chained into a subsequent `PadToMaxLength` transform.
- `test_chunk_and_pad_chaining`: Verifies that all outputs from `TokenizeAndChunk` are correctly chained into `PadToMaxLength` (e.g., both the 5-token chunk and the 2-token chunk are correctly padded).
To reproduce, you can run the new test file directly:
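For example, assuming a standard pytest setup:

```
python3 -m pytest tests/tokenizer_transform_test.py
```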
Fixes: #2344
Checklist
Before submitting this PR, please make sure (put X in square brackets):