Skip to content

[2/5] feat: cross-tokenizer collator, Arrow dataset, and eval datasets#2348

Closed
avenkateshha wants to merge 3 commits into
NVIDIA-NeMo:mainfrom
avenkateshha:avenkateshha/xtoken-off-policy-distillation/02-collator-data
Closed

[2/5] feat: cross-tokenizer collator, Arrow dataset, and eval datasets#2348
avenkateshha wants to merge 3 commits into
NVIDIA-NeMo:mainfrom
avenkateshha:avenkateshha/xtoken-off-policy-distillation/02-collator-data

Conversation

@avenkateshha
Copy link
Copy Markdown

@avenkateshha avenkateshha commented Apr 27, 2026

Data-layer plumbing for cross-tokenizer off-policy distillation, plus in-training eval datasets. Builds on the TokenAligner package from PR 1.

  • nemo_rl/data/cross_tokenizer_collate.py: CrossTokenizerCollator and TeacherCTSpec. Runs in StatefulDataLoader worker processes — does per-teacher tokenize + DP alignment up front so the train loop only consumes pre-built per_teacher_ct_data. Lazy-imports TokenAligner so workers that don't need cross-tokenizer never touch x_token.
  • nemo_rl/data/__init__.py: add NotRequired prefetch_factor to DataConfig.
  • nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py: ArrowTextDataset with lazy packing, registered as "arrow_text" in DATASET_REGISTRY.
  • nemo_rl/data/datasets/eval_datasets/{humaneval_plus, mbpp_plus, mmlu}.py and registry entries: in-training eval datasets. mmlu.py adds an optional num_few_shot argument with a static _build_few_shot_prefixes helper; default of 0 preserves existing behavior.

What does this PR do?

Adds the data-layer plumbing (collator + Arrow dataset) consumed by the cross-tokenizer distillation training loop, plus HumanEval+/MBPP+/MMLU eval datasets used for in-training evaluation.

Issues

None linked yet.

Usage

from torchdata.stateful_dataloader import StatefulDataLoader
from nemo_rl.data.cross_tokenizer_collate import CrossTokenizerCollator, TeacherCTSpec

teacher_ct_specs = [TeacherCTSpec(
    teacher_tokenizer_name="microsoft/Phi-4-mini-instruct",
    student_tokenizer_name="meta-llama/Llama-3.2-1B",
    projection_matrix_path="cross_tokenizer_data/llama_phi-mini_proj.pt",
    use_sparse_format=False, learnable=False, max_comb_len=4,
    projection_matrix_multiplier=1.0, project_teacher_to_student=False,
    max_teacher_len=4096, dp_chunk_size=128, use_align_fast=True,
    exact_token_match_only=False,
)]

collator = CrossTokenizerCollator(
    pad_token_id=tokenizer.pad_token_id,
    make_sequence_length_divisible_by=1,
    teacher_ct_specs=teacher_ct_specs,
    fallback_student_tokenizer_name="meta-llama/Llama-3.2-1B",
)

loader = StatefulDataLoader(
    dataset, batch_size=768, collate_fn=collator,
    num_workers=8, prefetch_factor=4, persistent_workers=True,
)

Before your PR is "Ready for review"

  • Read Contributor guidelines
  • No new tests in this PR. Collator behavior is exercised by the off-policy distillation recipe in PR 5.
  • Static py_compile confirmed clean. Full functional run lands with PR 5.
  • No docs entry yet — added alongside PR 5.

Additional Information

Stacked on PR 1 (TokenAligner) — #2347. Imports TokenAligner lazily so workers without CT teachers don't pay the cost.

Full chain:

  1. TokenAligner + projection utilities — [1/5] feat: add TokenAligner and cross-tokenizer projection utilities #2347
  2. (this PR) Collator + Arrow dataset + eval datasets
  3. CT distillation loss + multi-teacher aggregator — [3/5] feat: cross-tokenizer distillation loss and multi-teacher aggregator #2349
  4. CUDA IPC for teacher logits — [4/5] feat: CUDA IPC for teacher logits transfer #2350
  5. Algorithm + worker integration — [5/5] feat: off-policy distillation algorithm and worker integration #2351

@avenkateshha avenkateshha requested review from a team as code owners April 27, 2026 03:56
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Apr 27, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Foundational library code for cross-tokenizer distillation. No algorithm
or training-loop integration yet — those follow in subsequent PRs.

- nemo_rl/algorithms/x_token/tokenalign.py: TokenAligner(nn.Module) with
  Numba-accelerated DP alignment, projection-matrix loading
  (dense and sparse COO), and the project_token_likelihoods_instance
  forward path used by the cross-tokenizer loss.
- nemo_rl/algorithms/x_token/__init__.py: package init.
- nemo_rl/utils/x_token/{minimal_projection_generator,
  minimal_projection_via_multitoken,reapply_exact_map,
  sort_and_cut_projection_matrix}.py: standalone CLI scripts
  (argparse-driven, __main__ entrypoints) for one-time projection-matrix
  preparation. Not on the training import path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
@avenkateshha avenkateshha force-pushed the avenkateshha/xtoken-off-policy-distillation/02-collator-data branch from 7e19a1c to 66eb138 Compare April 27, 2026 10:21
Data-layer plumbing for cross-tokenizer off-policy distillation, plus
in-training eval datasets. Builds on the TokenAligner package from the
prior PR.

- nemo_rl/data/cross_tokenizer_collate.py: CrossTokenizerCollator and
  TeacherCTSpec. Runs in StatefulDataLoader worker processes — does
  per-teacher tokenize + DP alignment up front so the train loop only
  consumes pre-built per_teacher_ct_data. Lazy-imports TokenAligner so
  workers that don't need cross-tokenizer never touch x_token.
- nemo_rl/data/__init__.py: add NotRequired prefetch_factor to DataConfig.
- nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py:
  ArrowTextDataset with lazy packing, registered as "arrow_text" in
  DATASET_REGISTRY.
- nemo_rl/data/datasets/eval_datasets/{humaneval_plus,mbpp_plus,mmlu}.py
  and registry entries: in-training eval datasets. mmlu.py adds an
  optional num_few_shot argument with a static _build_few_shot_prefixes
  helper; default of 0 preserves existing behavior.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
@avenkateshha avenkateshha force-pushed the avenkateshha/xtoken-off-policy-distillation/02-collator-data branch from 66eb138 to 5934699 Compare April 27, 2026 10:25
@avenkateshha avenkateshha requested a review from a team as a code owner April 27, 2026 10:25
Follow-up to the TokenAligner refactor on 01-tokenaligner that dropped
align_fast() / precompute_canonical_maps(). The collator was the only
non-trivial caller of those APIs.

- Dropped TeacherCTSpec.use_align_fast (typed-dict field).
- Removed the precompute_canonical_maps() call in _lazy_init — that
  helper no longer exists on TokenAligner; align() does the full
  sequence-level canonicalization on each call.
- Replaced the use_align_fast if/else branch with the unconditional
  aligner.align(s_t, t_t, chunk_size=dp_chunk_size) DP path.

Behavior: alignment for sequences containing encoding-artifact or byte
tokens may differ slightly because align() runs
_merge_encoding_artifacts / _merge_consecutive_bytes (the cached
per-token canonical maps used by align_fast skipped these). For pairs
without those tokens, results are identical. Small per-batch CPU bump
in the collator workers since canonical strings are no longer cached
per id.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-maintainers Waiting on maintainers to respond label Apr 29, 2026
@avenkateshha avenkateshha deleted the avenkateshha/xtoken-off-policy-distillation/02-collator-data branch May 16, 2026 01:47
@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the waiting-on-maintainers Waiting on maintainers to respond label May 16, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants