Add KernelProvider SPI for matmul/SDPA dispatch (Scalar baseline) #553

@michalharakal

Description

Context

TensorOps defines high-level operations (matmul, scaledDotProductAttention, etc.) that backends like DefaultCpuOps implement directly. There's no dispatch layer between "the high-level op" and "the kernel that actually does the FLOPs" — every backend has to ship its own monolithic implementation of every op.

This makes it hard to:

  • Plug in a SIMD-accelerated matmul while keeping the rest of DefaultCpuOps unchanged.
  • Swap a Panama Vector kernel in/out at runtime without recompiling the backend.
  • Test a hand-written kernel against a scalar reference.
  • Ship a native (FFM) kernel as an optional artifact without baking it into the CPU backend.

A small SPI layer between TensorOps and the actual numeric kernels solves this. Backends keep their op-level wiring; performance-sensitive kernels (matmul, SDPA) become pluggable.

Proposal

A new sk.ainet.backend.api.kernel package (in skainet-backend-api) with:

public interface KernelProvider {
    public val name: String
    public val priority: Int
    public fun isAvailable(): Boolean
    public fun matmulFp32(): Fp32MatmulKernel?
    // Future: sdpaFp32(), matmulQ4K(), matmulQ8(), ...
}

public interface Fp32MatmulKernel {
    /**
     * C(m,n) = A(m,k) · B(k,n)  in row-major layout.
     * Strides are in floats (not bytes); they let callers pass sub-blocks
     * of larger arrays without copying.
     */
    public fun matmul(
        a: FloatArray, aOffset: Int, aStride: Int,
        b: FloatArray, bOffset: Int, bStride: Int,
        out: FloatArray, outOffset: Int, outStride: Int,
        m: Int, n: Int, k: Int
    )
}

Plus a simple KernelRegistry (expect/actual; JVM uses ServiceLoader, native uses manual registration — the same shape as the MultiplatformDispatcher-style registries used elsewhere in the ecosystem).
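A sketch of the common-code shape this could take — the interface signatures are copied from the proposal above, but the registry body and tie-breaking behaviour are assumptions, not a committed design. The JVM actual would additionally seed the provider list from java.util.ServiceLoader; native targets would call register() manually.

```kotlin
// Interfaces as proposed in sk.ainet.backend.api.kernel (copied from above).
interface Fp32MatmulKernel {
    fun matmul(
        a: FloatArray, aOffset: Int, aStride: Int,
        b: FloatArray, bOffset: Int, bStride: Int,
        out: FloatArray, outOffset: Int, outStride: Int,
        m: Int, n: Int, k: Int
    )
}

interface KernelProvider {
    val name: String
    val priority: Int
    fun isAvailable(): Boolean
    fun matmulFp32(): Fp32MatmulKernel?
}

// Hypothetical registry sketch; the real one is expect/actual per target.
object KernelRegistry {
    private val providers = mutableListOf<KernelProvider>()

    fun register(provider: KernelProvider) {
        providers += provider
    }

    // Among providers that report themselves available, the highest
    // priority wins; ties resolve to whichever registered first.
    fun bestAvailable(): KernelProvider? =
        providers.filter { it.isAvailable() }.maxByOrNull { it.priority }
}
```

With only ScalarKernelProvider registered, bestAvailable() trivially returns it; a later Panama or FFM provider only needs a higher priority value to win.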

A ScalarKernelProvider lives in skainet-backend-cpu. It's the always-available correctness reference: a triple-nested-loop matmul with the strides honoured. Higher-priority providers (Panama, native FFM) ship in follow-up issues/PRs.
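The triple-nested-loop reference might look like the following sketch. The object and method names follow this proposal; the body and the require check are illustrative assumptions.

```kotlin
// Hypothetical sketch of ScalarMatmulKernel: a plain i/j/p triple loop.
// Element (i, p) of A lives at a[aOffset + i * aStride + p]; strides are
// counted in floats, matching the Fp32MatmulKernel contract above.
object ScalarMatmulKernel {
    fun matmul(
        a: FloatArray, aOffset: Int, aStride: Int,
        b: FloatArray, bOffset: Int, bStride: Int,
        out: FloatArray, outOffset: Int, outStride: Int,
        m: Int, n: Int, k: Int
    ) {
        require(m >= 0 && n >= 0 && k >= 0) { "negative dimension" }
        for (i in 0 until m) {
            for (j in 0 until n) {
                var acc = 0.0f
                for (p in 0 until k) {
                    acc += a[aOffset + i * aStride + p] *
                           b[bOffset + p * bStride + j]
                }
                out[outOffset + i * outStride + j] = acc
            }
        }
    }
}
```

Because the strides are independent of m/n/k, the same loop multiplies contiguous matrices (stride equal to the row length) and sub-blocks of larger buffers without copying.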

Scope of this issue

  • In: SPI types in skainet-backend-api, ScalarKernelProvider + ScalarMatmulKernel in skainet-backend-cpu, KernelRegistry (simple expect/actual), unit tests.
  • In: registry-driven discovery on JVM (ServiceLoader) and manual registration on native — but no JVM service file ships yet because there's only one provider.
  • Out: Panama Vector matmul (separate issue / PR).
  • Out: Native FFM matmul (separate issue / PR).
  • Out: Wiring DefaultCpuOps.matmul to consult the registry — that's a downstream-of-this-issue change once at least one accelerated provider exists. For now ScalarKernelProvider is reachable but unused by the existing op layer.
  • Out: SDPA kernel API (will land alongside an SDPA-specific accelerator).
  • Out: Quantized kernels (Q4_K, Q8) — separate issue once the existing native matmul work in the codebase has a stable interface.

Acceptance

  • API compiles on all KMP targets.
  • ScalarMatmulKernelTest covers small/medium shapes and strided sub-blocks, and verifies that incompatible dimensions are rejected.
  • KernelRegistry.bestAvailable() returns ScalarKernelProvider when no other provider is registered.
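The strided sub-block case from the acceptance list could be exercised like the sketch below. It is self-contained here with its own reference loop (matmulRef is a hypothetical stand-in; the real test would call ScalarMatmulKernel), and the buffer contents are made up for illustration.

```kotlin
// Reference triple loop with the proposed signature (hypothetical name).
fun matmulRef(
    a: FloatArray, aOffset: Int, aStride: Int,
    b: FloatArray, bOffset: Int, bStride: Int,
    out: FloatArray, outOffset: Int, outStride: Int,
    m: Int, n: Int, k: Int
) {
    for (i in 0 until m) for (j in 0 until n) {
        var acc = 0.0f
        for (p in 0 until k) {
            acc += a[aOffset + i * aStride + p] * b[bOffset + p * bStride + j]
        }
        out[outOffset + i * outStride + j] = acc
    }
}

// Multiply the top-left 2x2 blocks of two 3x3 row-major matrices:
// strides stay at 3 (the parent row length) while m = n = k = 2.
fun stridedSubBlockCase() {
    val a = floatArrayOf(
        1f, 2f, 9f,
        3f, 4f, 9f,
        9f, 9f, 9f   // padding; a correct kernel never reads these 9s
    )
    val b = floatArrayOf(
        5f, 6f, 9f,
        7f, 8f, 9f,
        9f, 9f, 9f
    )
    val out = FloatArray(9)
    matmulRef(a, 0, 3, b, 0, 3, out, 0, 3, m = 2, n = 2, k = 2)
    // Top-left 2x2 of out holds the product; everything else stays zero.
    check(out[0] == 19f && out[1] == 22f && out[3] == 43f && out[4] == 50f)
    check(out[2] == 0f && out[8] == 0f)
}
```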
