Skip to content

feat(safetensors): loadTensorStorageMapped on the sharded reader#582

Merged
michalharakal merged 1 commit intodevelopfrom
feature/sharded-loadtensorstoragemapped
Apr 30, 2026
Merged

feat(safetensors): loadTensorStorageMapped on the sharded reader#582
michalharakal merged 1 commit intodevelopfrom
feature/sharded-loadtensorstoragemapped

Conversation

@michalharakal
Copy link
Copy Markdown
Contributor

Summary

Adds loadTensorStorageMapped to StreamingShardedSafeTensorsReader, mirroring the existing single-file StreamingSafeTensorsReader.loadTensorStorageMapped(tensor, filePath). Two overloads: by ShardedTensorInfo and by tensor name. Both return a TensorStorage whose BufferHandle.FileBacked references the resolved shard file's tensor byte range — enabling zero-copy / memory-mapped reads of tensors that exceed the 2 GB JVM ByteArray limit.

The new methods delegate internally to the per-shard reader; the caller doesn't need to know which physical shard contains a given tensor.

Motivation: SKaiNET-transformers Gemma4SafeTensorsMappedPle currently opens a FileChannel and computes the byte range itself to mmap the Gemma 4 PLE token-embedding table (~4.7 GB BF16 on E2B, well past the 2 GB ByteArray cap). Once this lands and a release ships, that downstream code drops ~30 lines of JVM mmap glue and consumes the upstream TensorStorage directly.

Test plan

  • ./gradlew :skainet-io:skainet-io-safetensors:jvmTest — passes in 6 s including new tests.
  • New StreamingShardedSafeTensorsReaderJvmTest covers:
    • by-name and by-ShardedTensorInfo overloads against a real single-shard SafeTensors fixture (built via SafeTensorsWriter)
    • TensorStorage.shape, isFileBacked, BufferHandle.FileBacked.path, and sizeInBytes correctness
    • IllegalArgumentException on unknown tensor name

🤖 Generated with Claude Code

Mirrors the existing single-file
`StreamingSafeTensorsReader.loadTensorStorageMapped(tensor, filePath)`
on the sharded reader, removing the need for callers to know which
physical shard contains a given tensor.

Adds two overloads to `StreamingShardedSafeTensorsReader`:

  - `loadTensorStorageMapped(tensor: ShardedTensorInfo): TensorStorage`
  - `loadTensorStorageMapped(name: String): TensorStorage`

Both return a `TensorStorage` whose `BufferHandle.FileBacked` references
the resolved shard file's tensor byte range — enabling zero-copy /
memory-mapped reads of tensors that exceed the 2 GB JVM `ByteArray`
limit (used by the Gemma 4 PLE token-embedding table; ~4.7 GB BF16 on
E2B). Internally the new methods delegate to the per-shard reader's
existing `loadTensorStorageMapped(streamingTensor, filePath)`.

Adds end-to-end coverage in the new
`StreamingShardedSafeTensorsReaderJvmTest`:

  - Build a real single-shard SafeTensors file via `SafeTensorsWriter`,
    hand-craft a `model.safetensors.index.json` referencing it, open
    via the sharded reader, assert `loadTensorStorageMapped` returns
    the expected shape and a `BufferHandle.FileBacked` pointing at the
    shard with the right size in bytes.
  - Confirm the by-name overload errors with `IllegalArgumentException`
    for an unknown tensor.

Motivation: SKaiNET-transformers `Gemma4SafeTensorsMappedPle` currently
opens a `FileChannel` and computes the byte range itself; this upstream
API will let it drop ~30 lines of mmap glue once a release lands.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@michalharakal michalharakal merged commit aa77b4d into develop Apr 30, 2026
6 checks passed
@michalharakal michalharakal deleted the feature/sharded-loadtensorstoragemapped branch April 30, 2026 15:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant