@@ -0,0 +1,37 @@
package sk.ainet.backend.api.kernel

/**
* JVM-only sibling of [KernelProvider] for kernels whose interface
* surface depends on `java.lang.foreign.MemorySegment`. Kept separate
* because [KernelProvider] lives in `commonMain` — adding
* `MemorySegment` accessors there would break Kotlin/Native, JS, and
* Wasm targets.
*
* Providers that ship MemSeg-input kernels declare both interfaces:
*
* ```kotlin
* public object MyProvider : KernelProvider, MemSegKernelProvider { ... }
* ```
*
* Lookup pattern at the call site:
*
* ```kotlin
* val kernel = (KernelRegistry.bestAvailable() as? MemSegKernelProvider)
*     ?.matmulQ4KMemSeg()
*     ?: fallbackHeapPath()
* ```
*
* No automatic registry lookup helper for now — the safe cast is
* sufficient and avoids a second registry. If a third MemSeg surface
* lands (FP32 matmul-MemSeg, Q6_K matmul-MemSeg, ...) it joins this
* interface as another `null`-defaulting accessor.
*/
public interface MemSegKernelProvider {
    /**
     * F32 × Q4_K matmul-MemSeg kernel exposed by this provider, or
     * `null` if this provider does not specialize the MemSeg path.
     * Default returns `null` so providers that pre-date the MemSeg SPI
     * keep compiling.
     */
    public fun matmulQ4KMemSeg(): Q4KMemSegMatmulKernel? = null
}
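
To make the opt-in concrete: a minimal sketch of a provider that declares both interfaces. `ExampleProvider` and `ExampleMemSegKernel` are hypothetical names used for illustration only, and the remaining `KernelProvider` members are elided.

```kotlin
// Hypothetical sketch: ExampleProvider and ExampleMemSegKernel are not
// part of this PR.
public object ExampleProvider : KernelProvider, MemSegKernelProvider {
    override val name: String = "example"
    override val priority: Int = 10
    // ... remaining KernelProvider members (matmulFp32, matmulQ4K, ...) elided ...

    // Opting in means overriding the null-defaulting accessor;
    // providers that never override it keep compiling unchanged.
    override fun matmulQ4KMemSeg(): Q4KMemSegMatmulKernel? = ExampleMemSegKernel
}
```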
@@ -0,0 +1,52 @@
package sk.ainet.backend.api.kernel

import java.lang.foreign.MemorySegment

/**
* F32 input × Q4_K-packed weights matrix-vector multiply where the
* **weight tensor is supplied as a `java.lang.foreign.MemorySegment`**
* rather than a heap [ByteArray]. JVM-only sibling of [Q4KMatmulKernel].
*
* Use this kernel when the Q4_K weight bytes already live in an
* off-heap segment — typically because they were `mmap`'d from a
* `.gguf` / `.safetensors` file, or because they were materialized
* into an `Arena.ofShared` segment at load time. Letting a backend
* read those bytes directly avoids the staging copy that
* [Q4KMatmulKernel.matmul] does on every call (heap `ByteArray` →
* temporary off-heap segment → native).
*
* The block layout, scale-pair packing, and lazy-`dmin` math are
* identical to [Q4KMatmulKernel] (canonical ggml super-block, 256
* elements, 144 bytes/block; see that kernel's kdoc for the byte
* map). Implementations MUST NOT mutate `input` or `weight`, MUST
* fully write `outputDim` floats starting at `output[outputOffset]`,
* and MAY assume no aliasing between the inputs and the output.
*
* Lifetime contract: the caller owns the [weight] segment's [Arena].
* The kernel must not retain pointers past the [matmul] call return —
* no asynchronous reads, no caching of dereferenced addresses across
* calls. Callers in turn must keep the segment's arena alive for the
* duration of the call.
*/
public interface Q4KMemSegMatmulKernel {
    /**
     * @param input FP32 input vector (single row), heap array.
     * @param inputOffset element offset into [input] where the row starts.
     * @param weight off-heap `MemorySegment` holding the packed Q4_K
     *   weights for the full `outputDim × inputDim` tensor in canonical
     *   block-major layout: block `(blockIdx, o)` starts at byte
     *   `(blockIdx * outputDim + o) * 144`.
     * @param weightByteOffset byte offset into [weight] where block
     *   `(0, 0)` starts.
     * @param inputDim contraction dimension; must be a multiple of 256.
     * @param outputDim number of output cells.
     * @param output FP32 output vector, heap array.
     * @param outputOffset element offset into [output] where the row
     *   starts.
     */
    public fun matmul(
        input: FloatArray, inputOffset: Int,
        weight: MemorySegment, weightByteOffset: Long,
        inputDim: Int, outputDim: Int,
        output: FloatArray, outputOffset: Int,
    )
}
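
A call-site sketch of the lifetime contract, assuming the weights were mmap'd from a model file. The path, dimensions, and tensor offset below are illustrative, not taken from this PR.

```kotlin
import java.lang.foreign.Arena
import java.nio.channels.FileChannel
import java.nio.file.Path
import java.nio.file.StandardOpenOption

// Illustrative values only: a real caller derives the tensor offset
// from the file's metadata, not a hard-coded constant.
fun projectRow(kernel: Q4KMemSegMatmulKernel, input: FloatArray): FloatArray {
    val inputDim = 4096            // multiple of 256, per the contract
    val outputDim = 4096
    val tensorByteOffset = 0x1000L // hypothetical position of block (0, 0)
    val output = FloatArray(outputDim)
    Arena.ofShared().use { arena ->
        FileChannel.open(Path.of("model.gguf"), StandardOpenOption.READ).use { ch ->
            // mmap the whole file into an off-heap segment owned by `arena`.
            val weights = ch.map(FileChannel.MapMode.READ_ONLY, 0L, ch.size(), arena)
            // The arena outlives the call, satisfying the lifetime
            // contract; the kernel reads the mapped bytes with no
            // staging copy.
            kernel.matmul(
                input, 0,
                weights, tensorByteOffset,
                inputDim, outputDim,
                output, 0,
            )
        }
    }
    return output
}
```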
@@ -2,23 +2,34 @@ package sk.ainet.exec.kernel

import sk.ainet.backend.api.kernel.Fp32MatmulKernel
import sk.ainet.backend.api.kernel.KernelProvider
import sk.ainet.backend.api.kernel.MemSegKernelProvider
import sk.ainet.backend.api.kernel.Q4KMatmulKernel
import sk.ainet.backend.api.kernel.Q4KMemSegMatmulKernel

/**
* Native (FFM) [KernelProvider] / [MemSegKernelProvider]. Sits at
* priority `100`, above [PanamaVectorKernelProvider] (`50`) and the
* scalar reference (`0`).
*
* Availability is gated on [NativeQ4KMatmulKernel.isAvailable] — the
* bundled `libskainet_kernels` shared library has to load AND
* `skainet_q4k_matmul` has to resolve via FFM. When either fails
* (missing arch, sandbox, JDK without FFM, kill-switch),
* `KernelRegistry.bestAvailable()` cleanly cascades to
* [PanamaVectorKernelProvider] at priority 50.
*
* The MemSeg surface ([matmulQ4KMemSeg]) is the JVM-only zero-copy
* path for mmap'd Q4_K weights — intended for inference loops that
* project against pre-loaded `MemorySegment`-backed tensors. Heap
* callers stick with [matmulQ4K]; both wrap the same C symbol so
* outputs are bit-for-bit identical.
*
* Staged rollout cursor (see `native-ffm-plan` asciidoc):
* - PR 2: real Q4_K matmul wired into the heap SPI.
* - PR 3 (this commit): MemSeg-input zero-copy sibling.
* - Later: native `matmulFp32`, `matmulQ6K`, `matmulQ8_0`.
*/
public object NativeKernelProvider : KernelProvider, MemSegKernelProvider {
    override val name: String = "native-ffm"
    override val priority: Int = 100

@@ -28,4 +39,7 @@ public object NativeKernelProvider : KernelProvider {

    override fun matmulQ4K(): Q4KMatmulKernel? =
        if (NativeQ4KMatmulKernel.isAvailable()) NativeQ4KMatmulKernel else null

    override fun matmulQ4KMemSeg(): Q4KMemSegMatmulKernel? =
        if (NativeQ4KMemSegMatmulKernel.isAvailable()) NativeQ4KMemSegMatmulKernel else null
}
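
Since both surfaces dispatch to the same `skainet_q4k_matmul` symbol, the bit-for-bit claim is easy to spot-check. A test sketch, assuming the heap kernel's `matmul` mirrors the MemSeg signature with a `ByteArray` weight (that signature is not shown in this diff):

```kotlin
import java.lang.foreign.Arena
import java.lang.foreign.MemorySegment
import java.lang.foreign.ValueLayout
import sk.ainet.backend.api.kernel.Q4KMatmulKernel
import sk.ainet.backend.api.kernel.Q4KMemSegMatmulKernel

// Assumed heap-kernel signature; adjust to the real Q4KMatmulKernel API.
fun assertPathsAgree(
    heap: Q4KMatmulKernel,
    memSeg: Q4KMemSegMatmulKernel,
    weightBytes: ByteArray,
    input: FloatArray,
    inputDim: Int,
    outputDim: Int,
) {
    val fromHeap = FloatArray(outputDim)
    val fromMemSeg = FloatArray(outputDim)
    heap.matmul(input, 0, weightBytes, 0, inputDim, outputDim, fromHeap, 0)
    Arena.ofConfined().use { arena ->
        // Copy the same bytes into an off-heap segment for the MemSeg path.
        val seg = arena.allocate(weightBytes.size.toLong())
        MemorySegment.copy(weightBytes, 0, seg, ValueLayout.JAVA_BYTE, 0L, weightBytes.size)
        memSeg.matmul(input, 0, seg, 0L, inputDim, outputDim, fromMemSeg, 0)
    }
    // Same C symbol on both paths, so outputs should match bit-for-bit.
    check(fromHeap.contentEquals(fromMemSeg))
}
```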
@@ -1,16 +1,33 @@
package sk.ainet.exec.kernel

import sk.ainet.backend.api.kernel.KernelProvider
import sk.ainet.backend.api.kernel.MemSegKernelProvider

/**
* `ServiceLoader`-friendly wrapper around [NativeKernelProvider]. The
* platform `ServiceLoader` machinery requires a public no-arg
* constructor, which a Kotlin `object` does not expose; this factory
* delegates every [KernelProvider] / [MemSegKernelProvider] member
* back to the singleton.
*
* Implementing both interfaces here matters for the MemSeg lookup
* pattern at the call site:
*
* ```kotlin
* val provider = KernelRegistry.bestAvailable() // KernelProvider
* val memSeg = (provider as? MemSegKernelProvider) // safe cast
*     ?.matmulQ4KMemSeg()
* ```
*
* Without the second `by`, the factory instance the registry hands out
* wouldn't satisfy the cast even though the underlying singleton
* implements both interfaces.
*
* Listed in
* `META-INF/services/sk.ainet.backend.api.kernel.KernelProvider` so
* `KernelServiceLoader.installAll()` discovers the provider on JVM
* startup.
*/
public class NativeKernelProviderFactory :
    KernelProvider by NativeKernelProvider,
    MemSegKernelProvider by NativeKernelProvider
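
For context, a sketch of what discovery looks like from the consumer side. `installAll()`'s internals are not shown in this diff, so the snippet below is an assumption about the general `ServiceLoader` pattern:

```kotlin
import java.util.ServiceLoader
import sk.ainet.backend.api.kernel.KernelProvider
import sk.ainet.backend.api.kernel.MemSegKernelProvider

// ServiceLoader instantiates each class listed in
// META-INF/services/sk.ainet.backend.api.kernel.KernelProvider via its
// public no-arg constructor, which is the constructor this factory
// exists to supply.
val providers = ServiceLoader.load(KernelProvider::class.java).toList()
val native = providers.firstOrNull { it.name == "native-ffm" }

// Because the factory delegates both interfaces, the cast succeeds on
// the instance ServiceLoader handed back, not just on the singleton.
val memSegKernel = (native as? MemSegKernelProvider)?.matmulQ4KMemSeg()
```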
@@ -0,0 +1,107 @@
package sk.ainet.exec.kernel

import java.lang.foreign.Arena
import java.lang.foreign.FunctionDescriptor
import java.lang.foreign.Linker
import java.lang.foreign.MemorySegment
import java.lang.foreign.ValueLayout
import java.lang.invoke.MethodHandle
import sk.ainet.backend.api.kernel.Q4KMemSegMatmulKernel

/**
* Zero-copy native [Q4KMemSegMatmulKernel] implementation.
*
* Reuses the same `skainet_q4k_matmul` C symbol as
* [NativeQ4KMatmulKernel] — the C side just sees `const uint8_t*` and
* doesn't care whether the Kotlin caller backed those bytes by a
* staged copy of a `ByteArray` or by an mmap'd off-heap segment. The
* win on this path is that the weight bytes (which dominate the
* payload — typical LLM Q4_K tensor: tens to hundreds of MB per layer)
* never round-trip through the heap.
*
* Per-call cost vs [NativeQ4KMatmulKernel]:
* - skips `MemorySegment.copy(weight, ...)` of `inputDim/256 * outputDim
* * 144` bytes (e.g. 9 MB at 4096² shape).
* - still copies `inputDim * 4` bytes for the input vector and
* `outputDim * 4` bytes for the output — the input/output are
* typically heap arrays produced/consumed by the surrounding
* forward pass.
*
* PR 3 of the staged native-FFM rollout — see the `native-ffm-plan`
* asciidoc.
*/
internal object NativeQ4KMemSegMatmulKernel : Q4KMemSegMatmulKernel {

    private const val BLOCK_SIZE = 256

    fun isAvailable(): Boolean = handle != null

    override fun matmul(
        input: FloatArray, inputOffset: Int,
        weight: MemorySegment, weightByteOffset: Long,
        inputDim: Int, outputDim: Int,
        output: FloatArray, outputOffset: Int,
    ) {
        require(inputDim % BLOCK_SIZE == 0) {
            "NativeQ4KMemSegMatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim"
        }
        require(weightByteOffset >= 0) {
            "NativeQ4KMemSegMatmulKernel: weightByteOffset must be non-negative; got $weightByteOffset"
        }
        require(weightByteOffset <= Int.MAX_VALUE) {
            "NativeQ4KMemSegMatmulKernel: weightByteOffset $weightByteOffset exceeds Int range — " +
                "the C kernel takes int32_t today; slice the segment first or wait for the int64_t overload"
        }
        if (outputDim == 0 || inputDim == 0) return
        val mh = handle
            ?: error("NativeQ4KMemSegMatmulKernel.matmul invoked while native library unavailable")

        // The C kernel reads weight bytes from weightByteOffset up to
        // weightByteOffset + weightBytesUsed, so require that the
        // caller's segment is large enough. This catches scope/aliasing
        // bugs early; without it, an undersized segment would crash the
        // JVM with SIGSEGV from native code.
        val weightBytesUsed = ((inputDim / BLOCK_SIZE).toLong() * outputDim) * 144L
        require(weightByteOffset + weightBytesUsed <= weight.byteSize()) {
            "NativeQ4KMemSegMatmulKernel: weight segment too small — needs " +
                "$weightBytesUsed bytes from offset $weightByteOffset, " +
                "segment is ${weight.byteSize()} bytes"
        }

        Arena.ofConfined().use { arena ->
            val inSeg = arena.allocate(
                inputDim.toLong() * java.lang.Float.BYTES,
                ValueLayout.JAVA_FLOAT.byteAlignment(),
            )
            val outSeg = arena.allocate(
                outputDim.toLong() * java.lang.Float.BYTES,
                ValueLayout.JAVA_FLOAT.byteAlignment(),
            )
            MemorySegment.copy(input, inputOffset, inSeg, ValueLayout.JAVA_FLOAT, 0L, inputDim)

            mh.invoke(
                inSeg, 0,
                weight, weightByteOffset.toInt(),
                inputDim, outputDim,
                outSeg, 0,
            )

            MemorySegment.copy(outSeg, ValueLayout.JAVA_FLOAT, 0L, output, outputOffset, outputDim)
        }
    }

    private val handle: MethodHandle? by lazy {
        val lookup = NativeLibraryLoader.lookup() ?: return@lazy null
        val symbol = lookup.find("skainet_q4k_matmul").orElse(null) ?: return@lazy null
        val descriptor = FunctionDescriptor.ofVoid(
            ValueLayout.ADDRESS,  // input
            ValueLayout.JAVA_INT, // input_offset
            ValueLayout.ADDRESS,  // weight (passed straight through from caller)
            ValueLayout.JAVA_INT, // weight_byte_offset
            ValueLayout.JAVA_INT, // input_dim
            ValueLayout.JAVA_INT, // output_dim
            ValueLayout.ADDRESS,  // output
            ValueLayout.JAVA_INT, // output_offset
        )
        runCatching { Linker.nativeLinker().downcallHandle(symbol, descriptor) }.getOrNull()
    }
}
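
A worked check of the "9 MB at 4096² shape" figure from the per-call cost notes in the kdoc; the dimensions are just the example shape:

```kotlin
// Blocks per output row = inputDim / 256; 144 bytes per Q4_K super-block.
val inputDim = 4096
val outputDim = 4096
val skippedBytes = (inputDim / 256).toLong() * outputDim * 144L
// = 16 * 4096 * 144 = 9_437_184 bytes = exactly 9 MiB skipped per call.
```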