feat(ops): add TensorOps.permute(axes) for arbitrary-axis permutation #552

Merged

michalharakal merged 1 commit into develop on Apr 28, 2026
Conversation
Closes #551. `TensorOps.transpose` only swaps the last two dimensions. This PR adds a generic `permute(tensor, axes)` op for arbitrary-axis permutation, matching the standard PyTorch / NumPy semantics: the i-th axis of the result is the `axes[i]`-th axis of the input.

Use case (downstream): `MultiHeadAttention` in `SKaiNET-transformers` needed to convert `[seqLen, nHeads, headDim]` to `[nHeads, seqLen, headDim]` after the Q/K/V projection. Without `permute`, that landed (PR #81) as a manual copy-based helper allocating a fresh `FloatArray` per call — ~13M floats per forward pass on a 1024-token TinyLlama prefill (~52 MB). With this op, the helper can be replaced by `ctx.ops.permute(t, intArrayOf(1, 0, 2))` without the manual copy bookkeeping.

Implementation:

* commonMain: SPI declaration on `TensorOps` (`@Diff`-annotated for KSP adjoint generation, mirroring `transpose`).
* `VoidTensorOps`: shape-only stub plus internal `validatePermuteAxes` / `calculatePermuteShape` helpers.
* `DefaultCpuOps`: full data-moving implementation. Identity permute is a no-op return. The `FloatArrayTensorData` fast path uses precomputed input/output strides and iterates the output flat buffer linearly. The generic fallback uses `dataFactory.init` with element-wise access for non-contiguous backends.
* `RecordingTensorOpsDecorator`: passthrough delegate. A dedicated `PermuteOperation` can be introduced later if the tape consumer needs to distinguish it from raw passthrough.
* `DefaultGradientTape`: `permuteBackward` returns `permute(upstream, inverse(axes))`, matching the well-known adjoint rule.
* `TestTensorOps` (compile-dag tests): identity stub for the test-only `TensorOps` mock.
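For reviewers unfamiliar with the stride trick, the FloatArray fast path can be sketched as a standalone function. This is a hypothetical illustration of the approach, not the in-tree `DefaultCpuOps` code; the names `rowMajorStrides` and `permuteFlat` are made up for the example:

```kotlin
// Row-major strides: stride[i] = product of all dimension sizes after i.
fun rowMajorStrides(shape: IntArray): IntArray {
    val strides = IntArray(shape.size)
    var acc = 1
    for (i in shape.indices.reversed()) {
        strides[i] = acc
        acc *= shape[i]
    }
    return strides
}

// Copy-based permute over a flat row-major buffer (illustrative sketch).
// Output axis i takes its size and source stride from input axis axes[i].
fun permuteFlat(data: FloatArray, shape: IntArray, axes: IntArray): Pair<FloatArray, IntArray> {
    val outShape = IntArray(axes.size) { shape[axes[it]] }
    val inStrides = rowMajorStrides(shape)
    val outStrides = rowMajorStrides(outShape)
    // Input stride of the axis that feeds output axis i.
    val srcStrides = IntArray(axes.size) { inStrides[axes[it]] }
    val out = FloatArray(data.size)
    for (flat in out.indices) {
        // Decompose the output flat index into coordinates and
        // re-linearize them against the input strides.
        var rem = flat
        var src = 0
        for (i in outStrides.indices) {
            src += (rem / outStrides[i]) * srcStrides[i]
            rem %= outStrides[i]
        }
        out[flat] = data[src]
    }
    return out to outShape
}
```

Iterating the output buffer linearly (rather than the input) keeps writes sequential, which is the cache-friendlier direction for a copy-based permute.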
Tests (`PermuteTest`, 8 cases):

* identity permute returns the input unchanged
* dim-0/dim-1 swap on rank-3, element-by-element verification
* full reverse on rank-4
* round-trip (`permute(permute(t, axes), inverse) == t`)
* equivalence with `transpose` on rank-2
* rejects wrong-length / out-of-range / duplicate axes

Out of scope:

* Strided/lazy permute (no data copy) — worthwhile once a strided tensor representation lands; for now, copying matches existing op semantics.
* A `PermuteOperation` tape op for replay-time DSL distinction (passthrough is correct for AD via the adjoint rule).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
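The inverse-axes rule behind both the round-trip test and `permuteBackward` is small enough to sketch. Assuming the i-th result axis is the `axes[i]`-th input axis, the inverse permutation sends each input axis back to its original position (`inverseAxes` is an illustrative name, not necessarily the in-tree helper):

```kotlin
// If permute maps input axis axes[i] to output axis i, the inverse
// permutation maps it back: inv[axes[i]] = i.
fun inverseAxes(axes: IntArray): IntArray {
    val inv = IntArray(axes.size)
    for (i in axes.indices) inv[axes[i]] = i
    return inv
}

fun main() {
    val shape = intArrayOf(4, 8, 16)
    val axes = intArrayOf(2, 0, 1)
    val permuted = IntArray(3) { shape[axes[it]] }   // [16, 4, 8]
    val inv = inverseAxes(axes)                      // [1, 2, 0]
    val restored = IntArray(3) { permuted[inv[it]] } // back to [4, 8, 16]
    check(restored.contentEquals(shape))
}
```

Because a permute only reorders elements, applying the inverse permutation to the upstream gradient routes each gradient entry back to the input position it came from, which is why the passthrough tape plus this adjoint is sufficient for AD.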
Closes #551.
Summary
`TensorOps.transpose` only swaps the last two dims. This PR adds `permute(tensor, axes)` for arbitrary-axis permutation (PyTorch/NumPy semantics: the i-th result axis is the `axes[i]`-th input axis).

* The SPI declaration is `@Diff`-annotated; KSP regenerates `DifferentiableTensorOps.permuteBackward`.
* `DefaultGradientTape` implements the adjoint as `permute(upstream, inverse(axes))`.
* `DefaultCpuOps` ships an identity-skip, a FloatArray fast path, and a generic `dataFactory.init` fallback.
* `VoidTensorOps` ships the shape-only stub.

Why
Downstream `SKaiNET-transformers` PR #81 had to add a copy-based `swapSeqHeadDims` helper inside `MultiHeadAttention` to fix a `seqLen > 1` reshape bug (the multi-head reshape was a stride reinterpretation, not a data permute). With no `permute` available, that helper allocates a fresh `FloatArray` per attention call — ~52 MB of transient allocation on a 1024-token TinyLlama prefill. With this op landed upstream, the helper can be replaced by `ctx.ops.permute(t, intArrayOf(1, 0, 2))`.

Test plan
`PermuteTest` (8 cases, all green on `:skainet-backends:skainet-backend-cpu:jvmTest`):

* identity permute returns the input unchanged
* dim-0/dim-1 swap on rank-3, element-by-element verification
* full reverse on rank-4
* round-trip (`permute(permute(t, axes), inverse) == t`)
* equivalence with `transpose` on rank-2
* rejects wrong-length / out-of-range / duplicate axes
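The rejection cases pin down what counts as valid `axes`: a permutation of `0 until rank`. A minimal sketch of that check, assuming the in-tree `validatePermuteAxes` helper does something equivalent (the name `validateAxes` and its signature here are illustrative):

```kotlin
// Reject axes arrays that are not a permutation of 0 until rank:
// wrong length, out-of-range entries, or duplicates.
fun validateAxes(rank: Int, axes: IntArray) {
    require(axes.size == rank) { "expected $rank axes, got ${axes.size}" }
    val seen = BooleanArray(rank)
    for (a in axes) {
        require(a in 0 until rank) { "axis $a out of range for rank $rank" }
        require(!seen[a]) { "duplicate axis $a" }
        seen[a] = true
    }
}
```

Validating eagerly at the op boundary keeps the stride arithmetic in the kernel free of bounds checks.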
Pre-existing test suites for `:skainet-lang:skainet-lang-core`, `:skainet-backends:skainet-backend-cpu`, and `:skainet-compile:skainet-compile-dag` are all still green.

Out of scope
* Strided/lazy permute (no data copy) — for now, copying matches existing op semantics (`transpose` also copies).
* A `PermuteOperation` for the recording tape — passthrough is correct for AD via the adjoint rule. Can be added if a tape consumer needs to distinguish it from raw passthrough.

🤖 Generated with Claude Code