Skip to content

feat: Optimize SFT dataloader with slice-based tokenization and caching#1695

Merged
SumanthRH merged 7 commits into
NovaSky-AI:mainfrom
jlee-lila:optimize-sft-tokenization
May 22, 2026
Merged

feat: Optimize SFT dataloader with slice-based tokenization and caching#1695
SumanthRH merged 7 commits into
NovaSky-AI:mainfrom
jlee-lila:optimize-sft-tokenization

Conversation

@jlee-lila
Copy link
Copy Markdown
Contributor

feat: Optimize SFT dataloader with slice-based tokenization and caching

Summary

Optimizes the SFT dataloader with two major improvements:

  1. Slice-based parallel tokenization - eliminates pickle overhead by having workers load their own data slices
  2. Tokenized dataset caching - persistent NFS-safe caching for reuse across training runs

Performance Improvements

Slice-Based Parallel Tokenization

Scale Before After Speedup
100K examples 383.7s (sequential) 43.2s (16 workers) 8.89x
~940K examples 1711s (old manual MP) 1647.7s (slice-based) 1.04x

Key benefits:

  • 8.89x speedup at 100K scale
  • 2.27x speedup at 1M scale vs sequential
  • Memory efficient: workers load data directly from HuggingFace
  • No pickle serialization overhead

Tokenized Dataset Caching

Operation Time (10K examples) Speedup
First run (tokenize + save) 21.4s 1.00x
Cache hit (load from disk) 0.4s 56.7x

Key benefits:

  • 56.7x speedup on cache hits
  • NFS-safe for multi-node training
  • Automatic cache key based on dataset + tokenization params
  • Configurable via cache_dir, force_recache, disable_cache

Changes

1. Slice-Based Tokenization (skyrl/train/sft_trainer.py)

  • Added _tokenize_chat_slice_worker() - worker for chat format with slice loading
  • Added _tokenize_alpaca_slice_worker() - worker for Alpaca format with slice loading
  • Added _parse_dataset_split() - parses split strings like "train[:100000]" into base split + indices
  • Updated _load_and_tokenize() - uses slice-based loading when num_workers > 0

How it works:

# Each worker loads its own slice directly from HuggingFace
dataset_slice = load_dataset(dataset_name, split=f"{base_split}[{start_idx}:{end_idx}]")

2. Dataset Caching (skyrl/train/sft_trainer.py, skyrl/train/config/sft_config.py)

  • Added _compute_cache_key() - deterministic hash of dataset + tokenization params
  • Added _get_cache_path(), _load_from_cache(), _save_to_cache() - cache I/O
  • Updated _load_and_tokenize() - checks cache before tokenizing
  • Added config fields: cache_dir, force_recache, disable_cache

Cache key includes:

  • Dataset name + split
  • Model path (tokenizer identity)
  • Max length
  • Messages key, tools key, system key
  • Train-on-what mode

NFS-safe atomic writes:

# Write to temp file, then atomic rename
temp_path = cache_path + ".tmp"
with open(temp_path, "wb") as f:
    pickle.dump(tokenized, f)
os.rename(temp_path, cache_path)  # Atomic on NFS

Configuration

Slice-Based Tokenization

cfg = SFTConfig(
    num_workers=16,  # Number of parallel workers (0 = sequential)
    ...
)

Caching

cfg = SFTConfig(
    cache_dir="/mnt/nfs/cache/skyrl",  # NFS path for multi-node
    # cache_dir="",  # Default: ~/.cache/skyrl/tokenized_datasets
    force_recache=False,  # Set True to ignore cache
    disable_cache=False,  # Set True to disable caching
    ...
)

Use Cases

  1. Hyperparameter sweeps - tokenize once, reuse across all runs
  2. Multi-node training - share cache via NFS across nodes
  3. Development - instant dataset loading during iteration
  4. Large datasets - amortize tokenization cost across many runs

Testing

Tested with:

  • 10K examples (cache validation)
  • 100K examples (slice-based validation)
  • ~940K examples (1M scale validation)

Test scripts available in commit history.

Breaking Changes

None - all changes are backward compatible. Default behavior unchanged.

Notes

  • Cache uses pickle format for fast serialization
  • Cache key is deterministic and collision-resistant (SHA256 hash)
  • Workers spawn cleanly without Ray fork conflicts
  • Atomic writes prevent corruption on NFS

Related Issues

Addresses performance concerns with SFT dataset loading at scale.

@SumanthRH SumanthRH force-pushed the optimize-sft-tokenization branch 2 times, most recently from a25e19d to c937e47 Compare May 20, 2026 20:52
Adds multiprocessing-based parallel tokenization with slice-based HF loading
to eliminate pickle overhead. Includes tokenized dataset caching (pickle) with
NFS support for multi-node training.

New config options: num_workers, cache_dir, force_recache, disable_cache.

Co-Authored-By: SumanthRH <sumanthrh@anyscale.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: jlee-lila <jlee@lila.ai>
@SumanthRH SumanthRH force-pushed the optimize-sft-tokenization branch from c937e47 to dc25e14 Compare May 20, 2026 20:53
SumanthRH added 4 commits May 20, 2026 21:05
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
x
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
@SumanthRH SumanthRH marked this pull request as ready for review May 22, 2026 00:57
@SumanthRH
Copy link
Copy Markdown
Member

I've made some improvements:

  1. Cleaned up unused code
  2. Simplied default for cache dir
  3. Changed the save format to be HF dataset since the file can get quite large for 1M rows. Note that the load format with SkyRL is still a list, which means we are loading it in memory. This will be fixed soon in a follow up PR

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements parallel tokenization and on-disk caching for SFT datasets to optimize training startup. It introduces configuration settings for worker counts and cache paths, implements multiprocessing worker functions for data slicing, and adds logic to store tokenized datasets in an arrow-backed format. The PR also includes comprehensive tests for the new parallel processing and caching functionality. Feedback points out concurrency risks in the cache-saving process on shared filesystems and suggests a more robust approach for generating cache keys to prevent potential collisions.

Comment on lines +264 to +273
temp_path = cache_path + ".tmp"
# Clean up any stale temp dir from an interrupted prior run.
if os.path.isdir(temp_path):
shutil.rmtree(temp_path)
dataset.save_to_disk(temp_path)
# If a previous cache exists at the final path, drop it before
# rename so the swap is the only visible state change.
if os.path.isdir(cache_path):
shutil.rmtree(cache_path)
os.rename(temp_path, cache_path)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current cache saving logic is not safe for multi-node training on shared filesystems (NFS). \n\n1. temp_path is not unique across processes or nodes, which can lead to data corruption if multiple workers attempt to write to the same cache simultaneously.\n2. The shutil.rmtree(temp_path) at the start of the write process can delete a directory that another process is currently using.\n3. The shutil.rmtree(cache_path) followed by os.rename is not atomic for directories on many systems, creating a race condition where one process might delete the successful output of another.\n\nA safer approach is to use a unique temporary directory and an atomic os.rename, ensuring that concurrent writes do not interfere with each other.

        temp_path = f"{cache_path}.tmp.{random.getrandbits(64):x}"\n        try:\n            dataset.save_to_disk(temp_path)\n            if os.path.isdir(cache_path):\n                shutil.rmtree(cache_path, ignore_errors=True)\n            os.rename(temp_path, cache_path)\n        finally:\n            if os.path.isdir(temp_path):\n                shutil.rmtree(temp_path, ignore_errors=True)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a rare scenario with concurrent training runs for the same model and dataset name (i.e same cache key).

Comment thread skyrl/train/sft_trainer.py Outdated
x
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
x
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
@SumanthRH SumanthRH merged commit ccc181e into NovaSky-AI:main May 22, 2026
4 of 5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants