In [1]:
// Versions shared by all artifacts of the same library family declared below.
val kinferencerVersion = "0.2.26"
val ktorVersion = "3.0.3"

// Declares the Maven repositories and the libraries this notebook's later cells rely on.
USE {
    repositories {
        mavenCentral()
        // JetBrains-hosted repositories; presumably required to resolve the io.kinference artifacts — TODO confirm
        maven("https://packages.jetbrains.team/maven/p/ki/maven")
        maven("https://packages.jetbrains.team/maven/p/grazi/grazie-platform-public")
    }
    dependencies {
        // KInference: core tensors, ONNX Runtime backend, protobuf serializer, utilities and NDArray implementation
        implementation("io.kinference:inference-core-jvm:$kinferencerVersion")
        implementation("io.kinference:inference-ort-jvm:$kinferencerVersion")
        implementation("io.kinference:serializer-protobuf-jvm:$kinferencerVersion")
        implementation("io.kinference:utils-common-jvm:$kinferencerVersion")
        implementation("io.kinference:ndarray-core-jvm:$kinferencerVersion")

        // Ktor HTTP client (CIO engine), used to download the model file
        implementation("io.ktor:ktor-client-core-jvm:$ktorVersion")
        implementation("io.ktor:ktor-client-cio-jvm:$ktorVersion")

        // SLF4J API plus the simple binding for logging
        implementation("org.slf4j:slf4j-api:2.0.9")
        implementation("org.slf4j:slf4j-simple:2.0.9")

        // DJL HuggingFace tokenizers, used to encode and decode GPT-2 tokens
        implementation("ai.djl:api:0.28.0")
        implementation("ai.djl.huggingface:tokenizers:0.28.0")
    }
}

In [2]:
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer
import io.kinference.core.data.tensor.KITensor
import io.kinference.core.data.tensor.asTensor
import io.kinference.ndarray.arrays.FloatNDArray
import io.kinference.ndarray.arrays.FloatNDArray.Companion.invoke
import io.kinference.ort.ORTData
import io.kinference.ort.ORTEngine
import io.kinference.ort.data.tensor.ORTTensor
import io.kinference.utils.CommonDataLoader
import io.kinference.utils.inlines.InlineInt
import io.kinference.utils.toIntArray
import okio.Path.Companion.toPath
import io.kinference.core.KIONNXData
import io.kinference.ndarray.arrays.LongNDArray
import io.kinference.ndarray.arrays.NumberNDArrayCore
import io.ktor.client.HttpClient
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.request.prepareRequest
import io.ktor.client.statement.bodyAsChannel
import io.ktor.util.cio.writeChannel
import io.ktor.utils.io.copyAndClose
import java.io.File
import kotlinx.coroutines.runBlocking

In [3]:
/**
 * Directory used to store cached files.
 *
 * The path is the current working directory (the `user.dir` system property)
 * followed by a hidden `.cache` subdirectory. The value already ends with a
 * trailing `/`, so code that concatenates a file name onto it must not add
 * another separator.
 *
 * It is used to check whether a file has already been downloaded, to create
 * the directory when it does not exist yet, and to locate cached files.
 */
val cacheDirectory = System.getProperty("user.dir") + "/.cache/"

/**
 * Downloads a file from the given URL into [cacheDirectory] under the specified file name.
 *
 * The cache directory is created when it does not exist. If a file with the
 * requested name is already present, the download is skipped. Otherwise the
 * file is fetched with an HTTP client whose request timeout is [timeout].
 *
 * A response with a non-2xx status is rejected before anything is written, and
 * a file left behind by a failed or cancelled transfer is deleted, so that a
 * broken download is not mistaken for a cached file on the next call.
 *
 * @param url The URL from which to download the file.
 * @param fileName The name to use for the downloaded file inside [cacheDirectory].
 * @param timeout Optional timeout duration for the download request, in milliseconds.
 * Defaults to 600,000 milliseconds (10 minutes).
 * Increase the timeout if you are not sure that the download of a particular model will fit into the default.
 * @throws IllegalStateException if the cache directory cannot be created or the server
 * answers with a non-success HTTP status.
 */
suspend fun downloadFile(url: String, fileName: String, timeout: Long = 600_000) {
    // Ensure the predefined path is treated as a directory
    val directory = File(cacheDirectory)

    // Check if the directory exists, if not create it
    if (!directory.exists()) {
        println("Predefined directory doesn't exist. Creating directory at $cacheDirectory.")
        // mkdirs() reports failure through its return value; isDirectory covers a concurrent creation
        check(directory.mkdirs() || directory.isDirectory) {
            "Failed to create cache directory at $cacheDirectory"
        }
    }

    // Check if the file already exists
    val file = File(directory, fileName)
    if (file.exists()) {
        println("File already exists at ${file.absolutePath}. Skipping download.")
        return // Exit the function if the file exists
    }

    // Create an instance of HttpClient with custom timeout settings
    val client = HttpClient {
        install(HttpTimeout) {
            requestTimeoutMillis = timeout
        }
    }

    var completed = false
    try {
        // Stream the response body straight into the target file
        client.prepareRequest(url).execute { response ->
            // Refuse error responses before the output file is opened, so an error page is never cached
            check(response.status.value in 200..299) {
                "Failed to download $url: HTTP ${response.status}"
            }
            response.bodyAsChannel().copyAndClose(file.writeChannel())
        }
        completed = true
    } finally {
        // Release the client on every exit path, not only on success
        client.close()
        if (!completed) {
            // Best-effort cleanup: a truncated file would otherwise pass the "already exists" check next time
            file.delete()
        }
    }
}

/**
 * Extracts the token ID with the highest probability from the output tensor.
 *
 * The logits tensor is addressed with four coordinates; the slice below reads them as
 * `[batch, second dimension, token position, vocabulary entry]`. The vocabulary row that
 * belongs to the last token (`tokensSize - 1`) is cut out, passed through softmax and
 * reduced to its single largest entry. The vocabulary size is taken from the tensor's
 * own shape rather than being fixed to one particular model.
 *
 * @param output A map containing the output tensors identified by their names.
 * @param tokensSize The number of tokens in the sequence that produced [output]; must be positive.
 * @param outputName The name of the tensor containing the logits.
 * @return The ID of the top token.
 * @throws IllegalArgumentException if [tokensSize] is not positive or [output] has no tensor named [outputName].
 */
suspend fun extractTopToken(output: Map<String, KIONNXData<*>>, tokensSize: Int, outputName: String): Long {
    require(tokensSize > 0) { "tokensSize must be positive, but was $tokensSize" }

    val rawLogits = requireNotNull(output[outputName]) {
        "Output tensor '$outputName' not found; available outputs: ${output.keys}"
    }
    val logits = rawLogits as KITensor

    // Size of the last (vocabulary) axis, read from the tensor instead of a model-specific constant
    val vocabularySize = logits.data.shape.last()

    val sliced = logits.data.slice(
        starts = intArrayOf(0, 0, tokensSize - 1, 0),       // First batch, first element in the second dimension, last token, first vocab entry
        ends = intArrayOf(1, 1, tokensSize, vocabularySize), // Same batch, same second dimension, one token step, whole vocab
        steps = intArrayOf(1, 1, 1, 1)                       // Step of 1 for each dimension
    ) as NumberNDArrayCore
    val softmax = sliced.softmax(axis = -1)
    val topK = softmax.topK(
        axis = -1,                                      // Apply top-k along the last dimension (vocabulary size)
        k = 1,                                          // Retrieve the top 1 element
        largest = true,                                 // We want the largest probabilities (most probable tokens)
        sorted = false                                  // Sorting is unnecessary since we are only retrieving the top 1
    )

    // topK.second holds the indices of the selected entries; the single remaining element is the token ID
    return (topK.second as LongNDArray)[intArrayOf(0, 0, 0, 0)]
}

/**
 * Re-wraps every ONNX Runtime output as a KInference core tensor.
 *
 * Each entry is cast to [ORTTensor], read out as a flat float array and copied,
 * element by element, into a [FloatNDArray] of the same shape. The resulting
 * tensor keeps the name under which the output was stored in the map.
 *
 * @param outputs Model outputs keyed by tensor name; every value is expected to be an [ORTTensor].
 * @return A map with the same keys whose values are [KITensor] copies of the outputs.
 */
suspend fun convertToKITensorMap(outputs: Map<String, ORTData<*>>): Map<String, KITensor> =
    outputs.mapValues { (tensorName, rawOutput) ->
        val source = rawOutput as ORTTensor
        val values = source.toFloatArray()
        val dims = source.shape.toIntArray()
        // The initializer is called with the flat element index, which maps directly onto the flat array
        FloatNDArray(dims) { idx: InlineInt -> values[idx.value] }.asTensor(tensorName)
    }

In [4]:
// Constants for input and output tensor names used in the GPT-2 model.
// INPUT_TENSOR_NAME is the name given to the token tensor that is fed to the model;
// OUTPUT_TENSOR_NAME is the key under which the logits are looked up in the model outputs.
val INPUT_TENSOR_NAME = "input1"
val OUTPUT_TENSOR_NAME = "output1" // We use only logits tensor

In [5]:
// Location of the GPT-2 ONNX model file in the onnx/models repository on GitHub
val modelUrl = "https://github.com/onnx/models/raw/main/validated/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx"
// Base name (without extension) under which the model is stored in the cache directory
val modelName = "gpt2-lm-head-10"


In [6]:
// Downloads GPT-2 (if it is not cached yet), loads it with the ORT engine and greedily
// generates `predictionLength` tokens that continue `testString`, printing them as they appear.
runBlocking {
    println("Downloading model from: $modelUrl")
    downloadFile(modelUrl, "$modelName.onnx") //GPT-2 from model zoo is around 650 Mb, adjust your timeout if needed

    println("Loading model...")
    // Resolve the file with File(parent, child) so the trailing "/" of cacheDirectory does not produce "//"
    val modelFile = File(cacheDirectory, "$modelName.onnx")
    val model = ORTEngine.loadModel(modelFile.path.toPath())

    val tokenizer = HuggingFaceTokenizer.newInstance("gpt2", mapOf("modelMaxLength" to "1024"))
    val testString = "Neurogenesis is most active during embryonic development and is responsible for producing " +
            "all the various types of neurons of the organism, but it continues throughout adult life " +
            "in a variety of organisms. Once born, neurons do not divide (see mitosis), and many will " +
            "live the lifespan of the animal, except under extraordinary and usually pathogenic circumstances."
    val encoded = tokenizer.encode(testString)
    val tokens = encoded.ids
    val tokensSize = tokens.size

    val predictionLength = 34
    val outputTokens = LongArray(predictionLength) { 0 }

    // Build the named input directly (same constructor as in the loop) instead of cloning an unnamed tensor
    var currentContext = ORTTensor(tokens, longArrayOf(1, 1, tokensSize.toLong()), INPUT_TENSOR_NAME)

    print("Here goes the test text for generation:\n$testString")

    for (idx in 0 until predictionLength) {
        val output = model.predict(listOf(currentContext))

        // Only the logits are read below, so skip converting the remaining model outputs
        val logits = convertToKITensorMap(output.filterKeys { it == OUTPUT_TENSOR_NAME })
        outputTokens[idx] = extractTopToken(logits, tokensSize + idx, OUTPUT_TENSOR_NAME)

        // Next context = prompt tokens followed by everything generated so far (idx + 1 tokens)
        val newTokenArray = tokens + outputTokens.copyOfRange(0, idx + 1)
        currentContext = ORTTensor(newTokenArray, longArrayOf(1, 1, tokensSize + idx + 1L), INPUT_TENSOR_NAME)
        print(tokenizer.decode(longArrayOf(outputTokens[idx])))
    }
    println("\n\nDone")
}

Downloading model from: https://github.com/onnx/models/raw/main/validated/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx
Loading model...
Here goes the test text for generation:
Neurogenesis is most active during embryonic development and is responsible for producing all the various types of neurons of the organism, but it continues throughout adult life in a variety of organisms. Once born, neurons do not divide (see mitosis), and many will live the lifespan of the animal, except under extraordinary and usually pathogenic circumstances.

The most common type of neurogenesis is the development of the hippocampus, which is the area of the brain that contains the hippocampus's electrical and chemical signals.

Done
