In [1]:
%use kotlin-dl
// Loads the KotlinDL notebook integration; the image `pipeline` and `OnFlyImageDataset`
// used in later cells come from it.

// Versions shared by the dependency declarations in the USE block below.
val kinferenceVersion = "0.2.26"
val ktorVersion = "3.0.3"

// Adds the KInference and Ktor client artifacts to the notebook classpath.
USE {
    repositories {
        mavenCentral()
        // JetBrains Space repositories; presumably these host the KInference artifacts and
        // their transitive dependencies — confirm before removing either one.
        maven("https://packages.jetbrains.team/maven/p/ki/maven")
        maven("https://packages.jetbrains.team/maven/p/grazi/grazie-platform-public")
    }
    dependencies {
        // KInference artifacts: core, ONNX Runtime backend (ORTEngine/ORTTensor),
        // protobuf serializer, and NDArray (FloatNDArray used to build inputs).
        implementation("io.kinference:inference-core-jvm:$kinferenceVersion")
        implementation("io.kinference:inference-ort-jvm:$kinferenceVersion")
        implementation("io.kinference:serializer-protobuf-jvm:$kinferenceVersion")
        implementation("io.kinference:ndarray-core-jvm:$kinferenceVersion")

        // Ktor HTTP client with the CIO engine, used by downloadFile to fetch the model and labels.
        implementation("io.ktor:ktor-client-core-jvm:$ktorVersion")
        implementation("io.ktor:ktor-client-cio-jvm:$ktorVersion")
    }
}

In [2]:
import io.kinference.core.KIONNXData
import io.kinference.core.data.tensor.KITensor
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.arrays.FloatNDArray.Companion.invoke
import io.kinference.ndarray.arrays.LongNDArray
import io.kinference.ndarray.arrays.NumberNDArrayCore
import io.kinference.ort.ORTEngine
import io.kinference.ort.data.tensor.ORTTensor
import io.kinference.utils.CommonDataLoader
import io.kinference.utils.inlines.InlineInt
import io.kinference.utils.toLongArray
import io.ktor.client.HttpClient
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.request.prepareRequest
import io.ktor.client.statement.bodyAsChannel
import io.ktor.util.cio.writeChannel
import io.ktor.utils.io.copyAndClose
import io.ktor.utils.io.jvm.javaio.copyTo
import kotlinx.coroutines.runBlocking
import okio.Path.Companion.toPath
import org.jetbrains.kotlinx.dl.dataset.generator.FromFolders
import org.jetbrains.kotlinx.dl.impl.inference.imagerecognition.InputType
import java.awt.image.BufferedImage
import java.io.File
import kotlin.collections.mutableMapOf

In [3]:
/**
 * Directory used to store cached downloads (model and label files).
 *
 * It is the `.cache` subdirectory of the JVM's current working directory (`user.dir`). The value
 * always ends with a path separator, so appending a file name with `+` yields a valid path.
 * `File(parent, child)` is used to build it because it normalizes the parent path. Plain string
 * concatenation produced a doubled separator (`…//.cache/`) whenever `user.dir` already ended
 * with a slash.
 *
 * The directory is not created here; `downloadFile` creates it on first use.
 */
val cacheDirectory: String = File(System.getProperty("user.dir"), ".cache").path + File.separator


/**
 * Downloads [url] into [cacheDirectory] as [fileName], unless a file with that name is already cached.
 *
 * Creates the cache directory if it does not exist. If [fileName] already exists in it, the download
 * is skipped, so a cached file is never refreshed. For that reason, failures must never leave
 * a file under the final name:
 * - a non-2xx HTTP status aborts before anything is written (previously the error body was saved
 *   and then cached as if it were the real file);
 * - the body is streamed into `<fileName>.part` and only renamed to [fileName] once the transfer has
 *   completed; on any failure the `.part` file is deleted (best effort) and the exception is rethrown.
 *
 * @param url The URL from which to download the file.
 * @param fileName The name to use for the downloaded file inside [cacheDirectory].
 * @param timeout Request timeout in milliseconds, passed to Ktor's `HttpTimeout` as `requestTimeoutMillis`.
 * Defaults to 600_000 milliseconds (10 minutes).
 * Increase the timeout if you are not sure that the download for a particular model will fit into the default timeout.
 * @throws IllegalStateException if the cache directory cannot be created, the server answers with a
 * non-2xx status, or the completed `.part` file cannot be renamed.
 */
suspend fun downloadFile(url: String, fileName: String, timeout: Long = 600_000) {
    // Ensure the predefined path is treated as a directory
    val directory = File(cacheDirectory)

    // Check if the directory exists, if not create it
    if (!directory.exists()) {
        println("Predefined directory doesn't exist. Creating directory at $cacheDirectory.")
        // isDirectory covers the case where another process created it concurrently
        check(directory.mkdirs() || directory.isDirectory) { "Could not create cache directory at $cacheDirectory" }
    }

    // Check if the file already exists
    val file = File(directory, fileName)
    if (file.exists()) {
        println("File already exists at ${file.absolutePath}. Skipping download.")
        return // Exit the function if the file exists
    }

    // Only a finished download is moved to the final name, so the existence check above
    // never mistakes a truncated file for a cached one.
    val partialFile = File(directory, "$fileName.part")
    try {
        // `use` closes the client even when the request or the copy throws.
        HttpClient {
            install(HttpTimeout) {
                requestTimeoutMillis = timeout
            }
        }.use { client ->
            client.prepareRequest(url).execute { response ->
                check(response.status.value in 200..299) { "Failed to download $url: HTTP ${response.status}" }
                // Suspends on the network channel and writes synchronously, so the stream is fully
                // written and closed before the rename below.
                partialFile.outputStream().use { output -> response.bodyAsChannel().copyTo(output) }
            }
        }
        check(partialFile.renameTo(file)) {
            "Could not move ${partialFile.absolutePath} to ${file.absolutePath}"
        }
    } catch (e: Throwable) {
        // Best-effort cleanup; the original exception (including cancellation) is always rethrown.
        partialFile.delete()
        throw e
    }
}

In [4]:
// Tensor names of the CaffeNet ONNX model.

/** Name assigned to every input tensor that is fed to the model. */
val INPUT_TENSOR_NAME: String = "data_0"

/** Key under which the model's class-score output is looked up in the prediction result. */
val OUTPUT_TENSOR_NAME: String = "prob_1"

In [5]:
// Preprocessing pipeline for input images using KotlinDL. Steps, in order:
//   1. resize each image to 224x224 using bilinear interpolation;
//   2. convert it to the BGR color mode;
//   3. turn it into a flat FloatArray;
//   4. apply KotlinDL's CAFFE input preprocessing.
// createInputs reads each result as a flat FloatArray and reshapes it to [1, 224, 224, 3].
// NOTE(review): confirm that InputType.CAFFE.preprocessing() (mean subtraction / channel order)
// matches what the CaffeNet ONNX model expects, given the explicit BGR conversion above.
val preprocessing = pipeline<BufferedImage>()
    .resize {
        outputWidth = 224
        outputHeight = 224
        interpolation = InterpolationType.BILINEAR
    }
    .convert { colorMode = ColorMode.BGR }
    .toFloatArray { }
    .call(InputType.CAFFE.preprocessing())

// Path to the small dataset of dogs vs cats images (100 images), provided by the KotlinDL integration.
// createInputs expects "cat" and "dog" subfolders inside it (see its FromFolders mapping).
val dogsVsCatsDatasetPath = dogsCatsSmallDatasetPath()

In [6]:
/**
 * Loads the dogs-vs-cats dataset and converts every image into an ONNX Runtime input tensor,
 * grouped by class label.
 *
 * Each image goes through `preprocessing`, which yields a flat FloatArray in channel-last
 * order. That array is viewed as shape `[1, 224, 224, 3]` and transposed with the permutation
 * `(0, 3, 1, 2)`. This moves the channel axis to position 1, giving `[1, 3, 224, 224]`; the two
 * spatial axes keep their relative order. The dataset is shuffled before conversion.
 *
 * The class name of each sample is derived from the same folder-to-label mapping that is handed to
 * [FromFolders]. Printed classes therefore always match the folder the image came from.
 *
 * @return A Map from class label ("cat" or "dog") to the list of [ORTTensor] inputs for that class.
 * Every tensor is named [INPUT_TENSOR_NAME].
 * @throws IllegalStateException if the dataset yields a label that is not in the class mapping.
 */
suspend fun createInputs(): Map<String, List<ORTTensor>> {
    // Folder name -> numeric label used by the dataset loader.
    val classMapping = mapOf("cat" to 0, "dog" to 1)
    // Reverse lookup: numeric label -> folder name.
    val classByLabel = classMapping.entries.associate { it.value to it.key }

    val dataset = OnFlyImageDataset.create(
        File(dogsVsCatsDatasetPath),
        FromFolders(mapping = classMapping),
        preprocessing
    ).shuffle()

    val tensorShape = intArrayOf(1, 224, 224, 3)  // Channel-last layout produced by the preprocessing pipeline
    val permuteAxis = intArrayOf(0, 3, 1, 2)      // Moves the channel axis to position 1 (channel-first)
    val inputTensors = mutableMapOf<String, MutableList<ORTTensor>>()

    for (i in 0 until dataset.xSize()) {
        val inputData = dataset.getX(i)
        val label = dataset.getY(i).toInt()
        val inputClass = classByLabel[label] ?: error("Unexpected label $label for sample $i")

        // Wrap the flat image data as an NDArray, then transpose it to channel-first order.
        val channelFirst = FloatNDArray(tensorShape) { index: InlineInt -> inputData[index.value] }
            .transpose(permuteAxis)
        // Copy the transposed data and shape into an ONNX Runtime tensor named for the model input.
        val inputTensor = ORTTensor(channelFirst.array.toArray(), channelFirst.shape.toLongArray(), INPUT_TENSOR_NAME)

        inputTensors.getOrPut(inputClass) { mutableListOf() }.add(inputTensor)
    }

    return inputTensors
}

In [7]:
/**
 * Prints the [topK] highest-scoring classes for one model output, next to the sample's actual class.
 *
 * Scores are read from [predictions] as a flat float array. Position `i` in that array is paired with
 * `classLabels[i]`, and positions without a corresponding label are printed as "Unknown". Each score
 * is printed as `score * 100` with two decimals and a '%' sign. Formatting uses
 * [java.util.Locale.ROOT], so the decimal separator is always '.' regardless of the JVM's default
 * locale. Previously some locales printed e.g. "12,60%".
 *
 * @param predictions The model output tensor holding one score per class.
 * @param classLabels Class labels, indexed by position in the flattened [predictions].
 * @param originalClass The actual class label of the instance being predicted.
 * @param topK How many of the highest scores to print; defaults to 5.
 */
fun displayTopPredictions(
    predictions: ORTTensor,
    classLabels: List<String>,
    originalClass: String,
    topK: Int = 5,
) {
    val topScores = predictions.toFloatArray()
        .withIndex()
        .sortedByDescending { it.value }
        .take(topK)

    println("\nOriginal class: $originalClass")
    println("Top $topK predictions:")
    for ((index, score) in topScores) {
        val predictedClassLabel = classLabels.getOrElse(index) { "Unknown" }
        println("$predictedClassLabel: ${"%.2f".format(java.util.Locale.ROOT, score * 100)}%")
    }
}

In [8]:
/** Download location of the CaffeNet ONNX model file (`caffenet-12.onnx`). */
val modelUrl: String =
    "https://github.com/onnx/models/raw/main/validated/vision/classification/caffenet/model/caffenet-12.onnx"

/** Download location of the class-label file: one label per line, matched to model outputs by position. */
val synsetUrl: String = "https://s3.amazonaws.com/onnx-model-zoo/synset.txt"

/** Base name of the cached model file, which is stored as `<modelName>.onnx`. */
val modelName: String = "CaffeNet"

In [9]:
// Downloads the model and labels (cached in cacheDirectory), builds the inputs, and prints the
// top predictions for every image, grouped by its actual class.
runBlocking {
    println("Downloading model from: $modelUrl")
    downloadFile(modelUrl, "$modelName.onnx")
    println("Downloading synset from: $synsetUrl")
    downloadFile(synsetUrl, "synset.txt")

    // File(parent, child) joins the paths without doubling the separator that cacheDirectory
    // may already end with.
    val classLabels = File(cacheDirectory, "synset.txt").readLines()

    println("Loading model...")
    val model = ORTEngine.loadModel(File(cacheDirectory, "$modelName.onnx").path.toPath())
    println("Creating inputs...")
    val inputTensors = createInputs()

    println("Starting inference...")
    for ((originalClass, tensors) in inputTensors) {
        for (tensor in tensors) {
            val actualOutputs = model.predict(listOf(tensor))
            // Fail with an explicit message instead of an opaque NPE / ClassCastException.
            val predictions = actualOutputs[OUTPUT_TENSOR_NAME] as? ORTTensor
                ?: error("Model output '$OUTPUT_TENSOR_NAME' is missing or is not an ORTTensor")
            displayTopPredictions(predictions, classLabels, originalClass)
        }
    }
}

Downloading model from: https://github.com/onnx/models/raw/main/validated/vision/classification/caffenet/model/caffenet-12.onnx
Predefined directory doesn't exist. Creating directory at /Users/pavel.gorgulov/Projects/Kotlin/Kotlin-AI-Examples/notebooks/kinference//.cache/.
Downloading synset from: https://s3.amazonaws.com/onnx-model-zoo/synset.txt
Loading model...
Creating inputs...
Starting inference...

Original class: cat
Top 5 predictions:
n03825788 nipple: 12,60%
n02097298 Scotch terrier, Scottish terrier, Scottie: 9,55%
n02094433 Yorkshire terrier: 4,42%
n03944341 pinwheel: 4,40%
n07615774 ice lolly, lolly, lollipop, popsicle: 4,40%

Original class: cat
Top 5 predictions:
n02124075 Egyptian cat: 42,12%
n02123045 tabby, tabby cat: 29,15%
n02123159 tiger cat: 17,08%
n02127052 lynx, catamount: 4,28%
n02123394 Persian cat: 3,40%

Original class: cat
Top 5 predictions:
n02104365 schipperke: 9,79%
n02107312 miniature pinscher: 9,33%
n02105056 groenendael: 4,72%
n02110627 affenpinscher,