Skip to content

feat(ops): add TensorOps.permute(axes) for arbitrary-axis permutation#552

Merged
michalharakal merged 1 commit intodevelopfrom
feature/ISSUE-551-permute-axes
Apr 28, 2026
Merged

feat(ops): add TensorOps.permute(axes) for arbitrary-axis permutation#552
michalharakal merged 1 commit intodevelopfrom
feature/ISSUE-551-permute-axes

Conversation

@michalharakal
Copy link
Copy Markdown
Contributor

Closes #551.

Summary

  • TensorOps.transpose only swaps the last two dims. This PR adds permute(tensor, axes) for arbitrary-axis permutation (PyTorch/NumPy semantics: the i-th result axis is the axes[i]-th input axis).
  • The op is @Diff-annotated. KSP regenerates DifferentiableTensorOps.permuteBackward. DefaultGradientTape implements the adjoint as permute(upstream, inverse(axes)).
  • DefaultCpuOps ships an identity-skip + FloatArray fast path + generic dataFactory.init fallback. VoidTensorOps ships the shape-only stub.

Why

Downstream SKaiNET-transformers PR #81 had to add a copy-based swapSeqHeadDims helper inside MultiHeadAttention to fix a seqLen > 1 reshape bug (the multi-head reshape was a stride reinterpretation, not a data permute). With no permute available, that helper allocates a fresh FloatArray per attention call — ~52 MB of transient allocation on a 1024-token TinyLlama prefill. With this op landed upstream, the helper can be replaced by ctx.ops.permute(t, intArrayOf(1, 0, 2)).

Test plan

PermuteTest (8 cases, all green on :skainet-backends:skainet-backend-cpu:jvmTest):

  • identity permute (returns input unchanged)
  • dim-0/dim-1 swap on rank-3, element-by-element
  • full reverse on rank-4
  • round-trip equivalence (permute(permute(t, axes), inverse) == t)
  • rank-2 equivalence with transpose
  • rejects wrong-length axes
  • rejects out-of-range axis
  • rejects duplicate axis

Plus pre-existing test suites for :skainet-lang:skainet-lang-core, :skainet-backends:skainet-backend-cpu, :skainet-compile:skainet-compile-dag all still green.

Out of scope

  • Strided/lazy permute (no data copy). Worthwhile once a strided tensor representation lands; for now copy matches existing op semantics (transpose also copies).
  • A dedicated PermuteOperation for the recording tape — passthrough is correct for AD via the adjoint rule. Can be added if a tape consumer needs to distinguish it from raw passthrough.

🤖 Generated with Claude Code

Closes #551.

`TensorOps.transpose` only swaps the last two dimensions. Adds a generic
`permute(tensor, axes)` op for arbitrary-axis permutation, matching the
standard PyTorch / NumPy semantics: the i-th axis of the result is the
`axes[i]`-th axis of the input.

Use case (downstream): `MultiHeadAttention` in `SKaiNET-transformers`
needed to convert `[seqLen, nHeads, headDim]` to `[nHeads, seqLen,
headDim]` after Q/K/V projection. Without `permute`, that landed (PR #81)
as a manual copy-based helper allocating a fresh `FloatArray` per call —
~13M floats / forward on a 1024-token TinyLlama prefill (~52 MB). With
this op, the helper can be replaced by `ctx.ops.permute(t, intArrayOf(1,
0, 2))` without the manual copy bookkeeping.

Implementation:

* commonMain: SPI declaration on `TensorOps` (`@Diff`-annotated for KSP
  adjoint generation, mirroring `transpose`).
* `VoidTensorOps`: shape-only stub plus `validatePermuteAxes` /
  `calculatePermuteShape` helpers (internal).
* `DefaultCpuOps`: full data-moving implementation. Identity permute is
  a no-op return. FloatArrayTensorData fast path uses precomputed input/
  output strides and iterates the output flat buffer linearly. Generic
  fallback uses `dataFactory.init` with element-wise access for
  non-contiguous backends.
* `RecordingTensorOpsDecorator`: passthrough delegate. A dedicated
  PermuteOperation can be introduced later if the tape consumer needs
  to distinguish it from raw passthrough.
* `DefaultGradientTape`: `permuteBackward` returns
  `permute(upstream, inverse(axes))`, matching the well-known adjoint
  rule.
* `TestTensorOps` (compile-dag tests): identity stub for the test-only
  TensorOps mock.

Tests (`PermuteTest`, 8 cases):
* identity permute returns the input unchanged
* dim-0/dim-1 swap on rank-3, element-by-element verification
* full reverse on rank-4
* round-trip (`permute(permute(t, axes), inverse) == t`)
* equivalence with `transpose` on rank-2
* rejects wrong-length / out-of-range / duplicate axes

Out of scope:
* Strided/lazy permute (no data copy) — worthwhile once a strided tensor
  representation lands; for now copy matches existing op semantics.
* PermuteOperation tape op for replay-time DSL distinction (passthrough
  is correct for AD via the adjoint rule).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@michalharakal michalharakal merged commit b1f8b15 into develop Apr 28, 2026
9 checks passed
@github-actions
Copy link
Copy Markdown

📖 Documentation Preview

The documentation has been built successfully for this PR.

Generated Files:

  • Operator documentation: docs/modules/operators/_generated_/
  • JSON schema output: operators.json

Artifacts:

  • Download the documentation-preview-552 artifact to view the complete documentation locally.

This comment will be updated automatically when the PR is updated.

@michalharakal michalharakal deleted the feature/ISSUE-551-permute-axes branch April 28, 2026 20:57
@michalharakal michalharakal mentioned this pull request Apr 28, 2026
3 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add TensorOps.permute(axes) for arbitrary-axis permutation

1 participant