In [1]:
// Basic tensor operations using Kotlin arrays

/** A dense vector of `Float` values. */
typealias Vector = FloatArray

/** A matrix stored as an array of rows, each row being a [Vector]. */
typealias Matrix = Array<Vector>

/**
 * Small numeric helpers shared by the layers in this notebook.
 *
 * Matrices are [Matrix] values (an array of rows); vectors are [Vector] values.
 */
class MathUtils {
    companion object {
        /** Logistic sigmoid `1 / (1 + e^(-x))`. */
        fun sigmoid(x: Float): Float = 1.0f / (1.0f + kotlin.math.exp(-x))

        /** Hyperbolic tangent, evaluated in double precision and narrowed back to `Float`. */
        fun tanh(x: Float): Float = kotlin.math.tanh(x.toDouble()).toFloat()

        /**
         * Dense matrix product `a · b`.
         *
         * @param a an `m × n` matrix; every row must have exactly `b.size` entries.
         * @param b an `n × p` matrix; every row must have the same length `p`.
         * @return the `m × p` product. An empty [a] yields an empty result.
         * @throws IllegalArgumentException if the inner dimensions disagree or [b] is ragged.
         */
        fun matrixMultiply(a: Matrix, b: Matrix): Matrix {
            val n = b.size
            val cols = if (n == 0) 0 else b[0].size
            // Reject mismatched shapes up front. Otherwise a longer row of `a` would be silently
            // truncated to b.size terms, and a shorter one would fail halfway through the loop.
            require(a.all { it.size == n }) {
                "matrixMultiply: every row of a must have $n entries to match b.size"
            }
            require(b.all { it.size == cols }) {
                "matrixMultiply: b must be rectangular (expected $cols columns per row)"
            }
            return Array(a.size) { i ->
                val row = a[i]
                FloatArray(cols) { j ->
                    var sum = 0f
                    for (k in 0 until n) {
                        sum += row[k] * b[k][j]
                    }
                    sum
                }
            }
        }

        /**
         * Element-wise sum of two vectors of equal length.
         *
         * @throws IllegalArgumentException if the lengths differ.
         */
        fun vectorAdd(a: Vector, b: Vector): Vector {
            require(a.size == b.size) { "vectorAdd: size mismatch (${a.size} vs ${b.size})" }
            return FloatArray(a.size) { i -> a[i] + b[i] }
        }
    }
}

/**
 * A simplified liquid-time-constant style recurrent cell.
 *
 * Each step does the following:
 * - concatenates the input `x` with the previous state `h`;
 * - applies one affine map (`weightsInput`, `biases`);
 * - blends the result with `h` through a gate that depends on the time value `t`.
 *
 * @param inputSize width of the input vector `x` passed to [forward].
 * @param hiddenSize width of the state vector `h` and of the returned vectors.
 */
class LiquidTimeConstant(
    private val inputSize: Int,
    private val hiddenSize: Int
) {
    // Weights initialized with uniform random values in [-0.5, 0.5).
    private val weightsInput = Array(inputSize + hiddenSize) {
        FloatArray(hiddenSize) { kotlin.random.Random.nextFloat() - 0.5f }
    }

    // NOTE(review): forward() never reads this. It is kept so that construction consumes the same
    // random numbers as before, which leaves every later Random draw unchanged.
    private val weightsHidden = Array(hiddenSize) {
        FloatArray(hiddenSize) { kotlin.random.Random.nextFloat() - 0.5f }
    }
    private val biases = FloatArray(hiddenSize) { 0f }

    /**
     * Advances the cell by one step.
     *
     * @param x input vector of length [inputSize].
     * @param h previous state of length [hiddenSize].
     * @param t time value that scales the gate pre-activation.
     * @return `(output, newState)`. The cell emits its new state as its output, so both are the same array.
     * @throws IllegalArgumentException if [x] or [h] has the wrong length.
     */
    fun forward(x: Vector, h: Vector, t: Float): Pair<Vector, Vector> {
        require(x.size == inputSize) {
            "LiquidTimeConstant: expected input of size $inputSize, got ${x.size}"
        }
        require(h.size == hiddenSize) {
            "LiquidTimeConstant: expected state of size $hiddenSize, got ${h.size}"
        }

        // Combine input and hidden state; its length now equals weightsInput's row count.
        val combined = x + h

        // Affine pre-activation: combined · weightsInput + biases.
        val projected = MathUtils.matrixMultiply(arrayOf(combined), weightsInput)[0]
        val features = FloatArray(hiddenSize) { i -> projected[i] + biases[i] }
        val activated = FloatArray(hiddenSize) { i -> MathUtils.tanh(features[i]) }

        // Time-dependent gating: sigmoid(-tanh(features) * t), which is exactly 0.5 at t = 0.
        val timeGate = FloatArray(hiddenSize) { i -> MathUtils.sigmoid(-activated[i] * t) }

        // State update: gate-weighted blend of the (un-squashed) features and the previous state.
        val newState = FloatArray(hiddenSize) { i ->
            timeGate[i] * features[i] + (1 - timeGate[i]) * h[i]
        }

        return Pair(newState, newState)
    }
}

/**
 * Single-head, unscaled dot-product attention: `softmax(query · keyᵀ) · value`.
 *
 * @param hiddenSize nominal model width; only used to derive [headSize].
 * @param numHeads number of heads; must be positive.
 */
class SimpleAttention(
    private val hiddenSize: Int,
    private val numHeads: Int
) {
    init {
        require(numHeads > 0) { "SimpleAttention: numHeads must be positive, got $numHeads" }
    }

    // NOTE(review): not used yet. Attention currently runs as one head over the full feature width.
    private val headSize = hiddenSize / numHeads

    /**
     * Attends every query row over all key rows.
     *
     * @param query `nq × d` matrix.
     * @param key `nk × d` matrix.
     * @param value `nk × dv` matrix with the same number of rows as [key].
     * @return `nq × dv` matrix. Row `i` is the softmax-weighted combination of the rows of [value].
     * @throws IllegalArgumentException if the shapes are inconsistent.
     */
    fun forward(query: Matrix, key: Matrix, value: Matrix): Matrix {
        require(key.size == value.size) {
            "SimpleAttention: key and value must have the same number of rows (${key.size} vs ${value.size})"
        }
        if (query.isEmpty()) return emptyArray()
        require(key.isNotEmpty()) { "SimpleAttention: key/value must contain at least one row" }

        // scores = query · keyᵀ, so entry (i, j) compares query row i with key row j.
        val weights = Array(query.size) { i ->
            softmax(FloatArray(key.size) { j -> dot(query[i], key[j]) })
        }
        return MathUtils.matrixMultiply(weights, value)
    }

    /** Dot product of two vectors that must have equal length. */
    private fun dot(a: Vector, b: Vector): Float {
        require(a.size == b.size) {
            "SimpleAttention: query and key widths differ (${a.size} vs ${b.size})"
        }
        var sum = 0f
        for (k in a.indices) sum += a[k] * b[k]
        return sum
    }

    /** Softmax over one row; the row maximum is subtracted before exponentiating for stability. */
    private fun softmax(row: Vector): Vector {
        val maxVal = row.maxOrNull() ?: 0f
        val expVals = FloatArray(row.size) { kotlin.math.exp((row[it] - maxVal).toDouble()).toFloat() }
        val sum = expVals.sum()
        return FloatArray(row.size) { expVals[it] / sum }
    }
}

/**
 * Toy model: self-attention over a sequence, followed by a [LiquidTimeConstant] cell run
 * step by step over the attended rows.
 *
 * @param inputSize number of features per time step (row width of the input matrix).
 * @param hiddenSize width of the recurrent state and of each output row.
 * @param numHeads forwarded to [SimpleAttention].
 */
class TransformerLNN(
    private val inputSize: Int,
    private val hiddenSize: Int,
    private val numHeads: Int = 4
) {
    private val attention = SimpleAttention(hiddenSize, numHeads)

    // Attention keeps the value width (the input's feature width), so the cell receives
    // inputSize-wide rows alongside its hiddenSize-wide state.
    private val ltc = LiquidTimeConstant(inputSize, hiddenSize)

    /**
     * Runs the model over one sequence.
     *
     * @param input `seqLen × inputSize` matrix, one row per time step.
     * @param times optional time value per step, with at least `seqLen` entries. Defaults to 0, 1, 2, ...
     * @return `seqLen × hiddenSize` matrix of per-step cell outputs; empty if [input] is empty.
     * @throws IllegalArgumentException if a row has the wrong width or [times] is too short.
     */
    fun forward(input: Matrix, times: Vector? = null): Matrix {
        // Rows are time steps: attention output row t is fed to the cell at step t.
        val seqLen = input.size
        if (seqLen == 0) return emptyArray()
        require(input.all { it.size == inputSize }) {
            "TransformerLNN: every input row must have $inputSize features"
        }

        // Use default times if none provided
        val actualTimes = times ?: FloatArray(seqLen) { it.toFloat() }
        require(actualTimes.size >= seqLen) {
            "TransformerLNN: need at least $seqLen time values, got ${actualTimes.size}"
        }

        // Process with attention (self-attention: query = key = value = input)
        val attOutput = attention.forward(input, input, input)

        // Process with LTC, threading the state through the time steps in order.
        val outputs = mutableListOf<Vector>()
        var state = FloatArray(hiddenSize) { 0f }
        for (t in 0 until seqLen) {
            val (newOutput, newState) = ltc.forward(attOutput[t], state, actualTimes[t])
            outputs.add(newOutput)
            state = newState
        }

        return outputs.toTypedArray()
    }
}

// Example usage: push one random observation through the model and report the output shape.
fun main() {
    val model = TransformerLNN(inputSize = 10, hiddenSize = 20)

    // One row of ten random features in [0, 1).
    val sample: Matrix = arrayOf(FloatArray(10) { kotlin.random.Random.nextFloat() })

    val result = model.forward(sample)
    val rowCount = result.size
    val columnCount = result[0].size
    println("Output shape: ${rowCount}x$columnCount")
}

In [1]:
%use dataframe
%use lets-plot
%useLatestDescriptors

import kotlin.math.PI
        import kotlin.math.sin
        import kotlin.random.Random
        import java.time.LocalDateTime
        import java.time.format.DateTimeFormatter

// Basic tensor operations using Kotlin arrays

/** A dense vector of `Float` values. */
typealias Vector = FloatArray

/** A matrix stored as an array of rows, each row being a [Vector]. */
typealias Matrix = Array<Vector>

/**
 * Small numeric helpers shared by the layers in this notebook.
 *
 * Matrices are [Matrix] values (an array of rows); vectors are [Vector] values.
 */
class MathUtils {
    companion object {
        /** Logistic sigmoid `1 / (1 + e^(-x))`. */
        fun sigmoid(x: Float): Float = 1.0f / (1.0f + kotlin.math.exp(-x))

        /** Hyperbolic tangent, evaluated in double precision and narrowed back to `Float`. */
        fun tanh(x: Float): Float = kotlin.math.tanh(x.toDouble()).toFloat()

        /**
         * Dense matrix product `a · b`.
         *
         * [b] is a plain [Matrix]: `b[k][j]` must be a `Float` for `a[i][k] * b[k][j]` to be defined.
         *
         * @param a an `m × n` matrix; every row must have exactly `b.size` entries.
         * @param b an `n × p` matrix; every row must have the same length `p`.
         * @return the `m × p` product. An empty [a] yields an empty result.
         * @throws IllegalArgumentException if the inner dimensions disagree or [b] is ragged.
         */
        fun matrixMultiply(a: Matrix, b: Matrix): Matrix {
            val n = b.size
            val cols = if (n == 0) 0 else b[0].size
            // Reject mismatched shapes up front. Otherwise a longer row of `a` would be silently
            // truncated to b.size terms, and a shorter one would fail halfway through the loop.
            require(a.all { it.size == n }) {
                "matrixMultiply: every row of a must have $n entries to match b.size"
            }
            require(b.all { it.size == cols }) {
                "matrixMultiply: b must be rectangular (expected $cols columns per row)"
            }
            return Array(a.size) { i ->
                val row = a[i]
                FloatArray(cols) { j ->
                    var sum = 0f
                    for (k in 0 until n) {
                        sum += row[k] * b[k][j]
                    }
                    sum
                }
            }
        }

        /**
         * Element-wise sum of two vectors of equal length.
         *
         * @throws IllegalArgumentException if the lengths differ.
         */
        fun vectorAdd(a: Vector, b: Vector): Vector {
            require(a.size == b.size) { "vectorAdd: size mismatch (${a.size} vs ${b.size})" }
            return FloatArray(a.size) { i -> a[i] + b[i] }
        }
    }
}

/**
 * A simplified liquid-time-constant style recurrent cell.
 *
 * Each step does the following:
 * - concatenates the input `x` with the previous state `h`;
 * - applies one affine map (`weightsInput`, `biases`);
 * - blends the result with `h` through a gate that depends on the time value `t`.
 *
 * @param inputSize width of the input vector `x` passed to [forward].
 * @param hiddenSize width of the state vector `h` and of the returned vectors.
 */
class LiquidTimeConstant(
    private val inputSize: Int,
    private val hiddenSize: Int
) {
    // Uniform random weights in [-0.5, 0.5); row k maps combined[k] onto every hidden unit.
    private val weightsInput = Array(inputSize + hiddenSize) {
        FloatArray(hiddenSize) { Random.nextFloat() - 0.5f }
    }

    // NOTE(review): forward() never reads this. It is kept so that construction consumes the same
    // random numbers as before, which leaves every later Random draw unchanged.
    private val weightsHidden = Array(hiddenSize) {
        FloatArray(hiddenSize) { Random.nextFloat() - 0.5f }
    }
    private val biases = FloatArray(hiddenSize) { 0f }

    /**
     * Advances the cell by one step.
     *
     * @param x input vector of length [inputSize].
     * @param h previous state of length [hiddenSize].
     * @param t time value that scales the gate pre-activation.
     * @return `(output, newState)`. The cell emits its new state as its output, so both are the same array.
     * @throws IllegalArgumentException if [x] or [h] has the wrong length.
     */
    fun forward(x: Vector, h: Vector, t: Float): Pair<Vector, Vector> {
        require(x.size == inputSize) {
            "LiquidTimeConstant: expected input of size $inputSize, got ${x.size}"
        }
        require(h.size == hiddenSize) {
            "LiquidTimeConstant: expected state of size $hiddenSize, got ${h.size}"
        }

        val combined = x + h

        // Vector-matrix product combined · weightsInput plus bias, computed directly.
        // Only a single row is involved, so no intermediate 1 × n matrix is built.
        val features = FloatArray(hiddenSize) { j ->
            var sum = 0f
            for (k in combined.indices) {
                sum += combined[k] * weightsInput[k][j]
            }
            sum + biases[j]
        }
        val activated = FloatArray(hiddenSize) { i -> MathUtils.tanh(features[i]) }

        // Time-dependent gating: sigmoid(-tanh(features) * t), which is exactly 0.5 at t = 0.
        val timeGate = FloatArray(hiddenSize) { i -> MathUtils.sigmoid(-activated[i] * t) }

        // Gate-weighted blend of the (un-squashed) features and the previous state.
        val newState = FloatArray(hiddenSize) { i ->
            timeGate[i] * features[i] + (1 - timeGate[i]) * h[i]
        }
        return Pair(newState, newState)
    }
}

/**
 * Single-head, unscaled dot-product attention: `softmax(query · keyᵀ) · value`.
 *
 * @param hiddenSize nominal model width; only used to derive [headSize].
 * @param numHeads number of heads; must be positive.
 */
class SimpleAttention(
    private val hiddenSize: Int,
    private val numHeads: Int
) {
    init {
        require(numHeads > 0) { "SimpleAttention: numHeads must be positive, got $numHeads" }
    }

    // NOTE(review): not used yet. Attention currently runs as one head over the full feature width.
    private val headSize = hiddenSize / numHeads

    /**
     * Attention for a single sequence.
     *
     * @param query `nq × d` matrix.
     * @param key `nk × d` matrix.
     * @param value `nk × dv` matrix with the same number of rows as [key].
     * @return `nq × dv` matrix. Row `i` is the softmax-weighted combination of the rows of [value].
     * @throws IllegalArgumentException if the shapes are inconsistent.
     */
    fun forward(query: Matrix, key: Matrix, value: Matrix): Matrix {
        require(key.size == value.size) {
            "SimpleAttention: key and value must have the same number of rows (${key.size} vs ${value.size})"
        }
        if (query.isEmpty()) return emptyArray()
        require(key.isNotEmpty()) { "SimpleAttention: key/value must contain at least one row" }
        val valueWidth = value[0].size
        require(value.all { it.size == valueWidth }) {
            "SimpleAttention: every value row must have $valueWidth entries"
        }

        return Array(query.size) { i ->
            // Weights of query row i over all key rows (one row of softmax(query · keyᵀ)).
            val weights = softmax(FloatArray(key.size) { j -> dot(query[i], key[j]) })
            FloatArray(valueWidth) { c ->
                var acc = 0f
                for (j in key.indices) acc += weights[j] * value[j][c]
                acc
            }
        }
    }

    /**
     * Batched attention. Each sample `b` is attended independently, with
     * `forward(query[b], key[b], value[b])`, and the per-sample outputs are averaged element-wise.
     *
     * NOTE(review): the declared signature reduces a whole batch to a single [Matrix]. Element-wise
     * averaging is the reduction chosen here; confirm that this is the intended behavior.
     *
     * @return `nq × dv` batch-mean attention output; empty if the batch is empty.
     * @throws IllegalArgumentException if the batches differ in size or samples yield different shapes.
     */
    fun forward(query: Array<Matrix>, key: Array<Matrix>, value: Array<Matrix>): Matrix {
        require(query.size == key.size && key.size == value.size) {
            "SimpleAttention: query, key and value batches differ in size (${query.size}, ${key.size}, ${value.size})"
        }
        if (query.isEmpty()) return emptyArray()

        val perSample = Array(query.size) { b -> forward(query[b], key[b], value[b]) }
        val reference = perSample[0]
        require(perSample.all { out -> out.size == reference.size && out.indices.all { r -> out[r].size == reference[r].size } }) {
            "SimpleAttention: all samples in a batch must produce outputs of the same shape"
        }

        val scale = 1f / perSample.size
        return Array(reference.size) { r ->
            FloatArray(reference[r].size) { c ->
                var acc = 0f
                for (out in perSample) acc += out[r][c]
                acc * scale
            }
        }
    }

    /** Dot product of two vectors that must have equal length. */
    private fun dot(a: Vector, b: Vector): Float {
        require(a.size == b.size) {
            "SimpleAttention: query and key widths differ (${a.size} vs ${b.size})"
        }
        var sum = 0f
        for (k in a.indices) sum += a[k] * b[k]
        return sum
    }

    /** Softmax over one row; the row maximum is subtracted before exponentiating for stability. */
    private fun softmax(row: Vector): Vector {
        val maxVal = row.maxOrNull() ?: 0f
        val expVals = FloatArray(row.size) { kotlin.math.exp((row[it] - maxVal).toDouble()).toFloat() }
        val sum = expVals.sum()
        return FloatArray(row.size) { expVals[it] / sum }
    }
}

/**
 * Toy model: attention over a batch of sequences, followed by a [LiquidTimeConstant] cell run
 * step by step over the attended rows.
 *
 * @param inputSize number of features per time step.
 * @param hiddenSize width of the recurrent state and of each output row.
 * @param numHeads forwarded to [SimpleAttention].
 */
class TransformerLNN(
    val inputSize: Int,
    val hiddenSize: Int,
    val numHeads: Int = 4
) {
    private val attention = SimpleAttention(hiddenSize, numHeads)

    // Attention keeps the value width (the input's feature width), so the cell receives
    // inputSize-wide rows alongside its hiddenSize-wide state.
    private val ltc = LiquidTimeConstant(inputSize, hiddenSize)

    /**
     * Runs a batch through attention and then through the recurrent cell.
     *
     * NOTE(review): the whole batch yields one output matrix. How samples are combined is decided
     * by [SimpleAttention.forward].
     *
     * @param input batch of `seqLen × inputSize` matrices, one per sample. `seqLen` is read from the first sample.
     * @param times optional time value per step, with at least `seqLen` entries. Defaults to 0, 1, 2, ...
     * @return `seqLen × hiddenSize` matrix of per-step cell outputs; empty for an empty batch.
     * @throws IllegalArgumentException if [times] is too short or attention returns too few rows.
     */
    fun forward(input: Array<Matrix>, times: Vector? = null): Matrix {
        if (input.isEmpty()) return emptyArray()
        val seqLen = input[0].size
        val actualTimes = times ?: FloatArray(seqLen) { it.toFloat() }
        require(actualTimes.size >= seqLen) {
            "TransformerLNN: need at least $seqLen time values, got ${actualTimes.size}"
        }

        val attOutput = attention.forward(input, input, input)
        require(attOutput.size >= seqLen) {
            "TransformerLNN: attention returned ${attOutput.size} rows for a sequence of length $seqLen"
        }

        // Thread the recurrent state through the time steps in order.
        val outputs = mutableListOf<Vector>()
        var state = FloatArray(hiddenSize) { 0f }
        for (t in 0 until seqLen) {
            val (newOutput, newState) = ltc.forward(attOutput[t], state, actualTimes[t])
            outputs.add(newOutput)
            state = newState
        }
        return outputs.toTypedArray()
    }

    /**
     * Self-attention weights for one sequence: row-wise softmax of `input · inputᵀ`.
     * Entry (i, j) is the weight step i places on step j, and each row sums to 1.
     *
     * @param input `seqLen × features` matrix, one row per time step.
     * @return `seqLen × seqLen` weight matrix; empty for an empty [input].
     */
    fun getAttentionWeights(input: Matrix): Matrix =
        Array(input.size) { i ->
            softmax(FloatArray(input.size) { j -> dot(input[i], input[j]) })
        }

    /** Dot product of two vectors that must have equal length. */
    private fun dot(a: Vector, b: Vector): Float {
        require(a.size == b.size) { "TransformerLNN: row widths differ (${a.size} vs ${b.size})" }
        var sum = 0f
        for (k in a.indices) sum += a[k] * b[k]
        return sum
    }

    /** Softmax over one row; the row maximum is subtracted before exponentiating for stability. */
    private fun softmax(row: Vector): Vector {
        val maxVal = row.maxOrNull() ?: 0f
        val expVals = FloatArray(row.size) { kotlin.math.exp((row[it] - maxVal).toDouble()).toFloat() }
        val sum = expVals.sum()
        return FloatArray(row.size) { expVals[it] / sum }
    }
}

/**
 * One recorded snapshot of training progress.
 *
 * @property epoch epoch index supplied by the caller.
 * @property batchIdx batch index within [epoch].
 * @property loss loss value for that batch.
 * @property accuracy accuracy score for that batch, as computed by the caller.
 * @property timestamp wall-clock time of the recording, already formatted as text.
 */
data class TrainingMetric(
    val epoch: Int,
    val batchIdx: Int,
    val loss: Float,
    val accuracy: Float,
    val timestamp: String,
)

/**
 * Records training metrics and model snapshots and renders them with Lets-Plot.
 *
 * @param saveIntervals how often, in batches, the caller should record metrics; must be positive.
 * @param plotIntervals how often, in batches, the caller should plot; must be positive.
 */
class TrainingVisualizer(
    val saveIntervals: Int = 10,
    val plotIntervals: Int = 5
) {
    init {
        // Callers use these as modulo divisors (e.g. `batchIdx % saveIntervals`), so zero would throw there.
        require(saveIntervals > 0) { "saveIntervals must be positive, got $saveIntervals" }
        require(plotIntervals > 0) { "plotIntervals must be positive, got $plotIntervals" }
    }

    private val metrics = mutableListOf<TrainingMetric>()

    // NOTE(review): recorded but not plotted yet.
    private val ltcStates = mutableListOf<Vector>()
    private val attnWeights = mutableListOf<Matrix>()

    /**
     * Appends one metric row, timestamped as HH:mm:ss, together with copies of the state and
     * attention snapshots.
     */
    fun updateMetrics(
        epoch: Int,
        batchIdx: Int,
        loss: Float,
        accuracy: Float,
        ltcState: Vector,
        attnWeights: Matrix
    ) {
        metrics.add(TrainingMetric(
            epoch = epoch,
            batchIdx = batchIdx,
            loss = loss,
            accuracy = accuracy,
            timestamp = LocalDateTime.now().format(DateTimeFormatter.ofPattern("HH:mm:ss"))
        ))
        // Copy, so that later mutation of the caller's arrays cannot rewrite the recorded history.
        ltcStates.add(ltcState.clone())
        this.attnWeights.add(attnWeights.map { it.clone() }.toTypedArray())
    }

    /** Shows the loss and accuracy curves and, if any snapshot was recorded, the latest attention heatmap. */
    fun plotTrainingProgress() {
        val data = mapOf(
            "Step" to metrics.indices.toList(),
            "Loss" to metrics.map { it.loss },
            "Accuracy" to metrics.map { it.accuracy },
            "Epoch" to metrics.map { it.epoch }
        )

        val lossPlot = letsPlot(data) {
            x = "Step"
            y = "Loss"
            color = "Epoch"
        } +
                geomLine(size = 1.5) +
                geomPoint(size = 3.0) +
                themeLight() +
                labs(
                    title = "Training Loss Over Time",
                    x = "Training Step",
                    y = "Loss"
                )

        val accuracyPlot = letsPlot(data) {
            x = "Step"
            y = "Accuracy"
            color = "Epoch"
        } +
                geomLine(size = 1.5) +
                geomPoint(size = 3.0) +
                themeLight() +
                labs(
                    title = "Training Accuracy Over Time",
                    x = "Training Step",
                    y = "Accuracy (%)"
                )

        lossPlot.show()
        accuracyPlot.show()

        if (attnWeights.isNotEmpty()) {
            plotAttentionHeatmap(attnWeights.last())
        }
    }

    /**
     * Heatmap of one attention matrix, long-format: one (Row, Column, Value) record per cell.
     * Each row is one query step's distribution over key steps, so columns are key positions and
     * rows are query positions.
     */
    private fun plotAttentionHeatmap(weights: Matrix) {
        val rows = mutableListOf<Int>()
        val cols = mutableListOf<Int>()
        val values = mutableListOf<Float>()
        for (i in weights.indices) {
            for (j in weights[i].indices) {
                rows.add(i)
                cols.add(j)
                values.add(weights[i][j])
            }
        }

        val attnData = mutableMapOf<String, List<Any>>()
        attnData["Row"] = rows
        attnData["Column"] = cols
        attnData["Value"] = values

        val attnPlot = letsPlot(attnData) {
            x = "Column"
            y = "Row"
            fill = "Value"
        } +
                geomTile() +
                scaleFillGradient(low = "#FFFFFF", high = "#0000FF") +
                themeLight() +
                labs(
                    title = "Attention Weights Heatmap",
                    x = "Key Position",
                    y = "Query Position"
                )

        attnPlot.show()
    }
}

/** Generates synthetic multi-frequency sine sequences for next-step regression. */
object DataGenerator {
    /**
     * Builds `numSamples` sequences of `seqLength` steps, each with `inputDim` features.
     *
     * Feature `d` at time `t` is `sin(2π·0.5(d+1)·t) + 0.5·sin(2π·0.25(d+1)·t)`, on the time grid
     * `t_i = i·10/seqLength`. Uniform noise in `[-noiseLevel, noiseLevel)` is optionally added.
     * The target at step `i` is `1.5·x[(i+1) mod seqLength] + 0.5`, so the last step's target
     * wraps around to the first input step.
     *
     * @param noiseLevel amplitude of the per-value noise. `0` (the default) reproduces the
     *   noise-free data, in which every sample is identical.
     * @param random noise source; only consulted when `noiseLevel != 0`.
     * @return `(inputs, targets)`, each shaped `numSamples × seqLength × inputDim`.
     * @throws IllegalArgumentException if a size argument is negative.
     */
    fun generateSyntheticData(
        numSamples: Int = 1000,
        seqLength: Int = 50,
        inputDim: Int = 10,
        noiseLevel: Float = 0f,
        random: Random = Random.Default
    ): Pair<Array<Matrix>, Array<Matrix>> {
        require(numSamples >= 0) { "numSamples must be non-negative, got $numSamples" }
        require(seqLength >= 0) { "seqLength must be non-negative, got $seqLength" }
        require(inputDim >= 0) { "inputDim must be non-negative, got $inputDim" }

        // Primitive array avoids boxing every time value.
        val t = FloatArray(seqLength) { it * 10.0f / seqLength }

        val x = Array(numSamples) {
            Array(seqLength) { timeIdx ->
                FloatArray(inputDim) { dimIdx ->
                    val freq1 = (dimIdx + 1) * 0.5f
                    val freq2 = (dimIdx + 1) * 0.25f
                    val clean = sin(2 * PI * freq1 * t[timeIdx]).toFloat() +
                            0.5f * sin(2 * PI * freq2 * t[timeIdx]).toFloat()
                    // Skip the RNG entirely when noise is off, so the default output and the
                    // random stream are left untouched.
                    if (noiseLevel == 0f) clean else clean + noiseLevel * (2f * random.nextFloat() - 1f)
                }
            }
        }

        val y = Array(numSamples) { sampleIdx ->
            Array(seqLength) { timeIdx ->
                FloatArray(inputDim) { dimIdx ->
                    val shiftedIdx = (timeIdx + 1) % seqLength
                    1.5f * x[sampleIdx][shiftedIdx][dimIdx] + 0.5f
                }
            }
        }

        return Pair(x, y)
    }
}

/**
 * Runs a [TransformerLNN] over mini-batches and reports MSE-based metrics to a [TrainingVisualizer].
 *
 * NOTE(review): no gradients are computed. [updateModelWeights] is a placeholder and
 * [learningRate] is not read anywhere yet, so the model's weights never change.
 */
class Trainer(
    private val model: TransformerLNN,
    private val learningRate: Float = 0.001f
) {
    private val visualizer = TrainingVisualizer()

    /**
     * Runs one epoch over consecutive, non-overlapping batches of `batchSize` samples.
     * Samples left over after the last full batch are skipped.
     *
     * @param trainData inputs and targets, one matrix per sample, in matching order.
     * @param batchSize samples per batch; must be positive and no larger than the number of samples.
     * @param epoch epoch index, recorded with the metrics.
     * @return mean batch loss over this epoch.
     * @throws IllegalArgumentException on an invalid batch size or mismatched inputs/targets.
     */
    fun trainEpoch(
        trainData: Pair<Array<Matrix>, Array<Matrix>>,
        batchSize: Int,
        epoch: Int
    ): Float {
        val (x, y) = trainData
        require(batchSize > 0) { "batchSize must be positive, got $batchSize" }
        require(x.size == y.size) { "inputs and targets differ in count (${x.size} vs ${y.size})" }
        val numBatches = x.size / batchSize
        // With no full batch, the final division would be 0f / 0, a silent NaN.
        require(numBatches > 0) { "need at least $batchSize samples for one batch, got ${x.size}" }

        var totalLoss = 0f
        for (batchIdx in 0 until numBatches) {
            val startIdx = batchIdx * batchSize
            val endIdx = startIdx + batchSize

            val batchX = x.copyOfRange(startIdx, endIdx)
            val batchY = y.copyOfRange(startIdx, endIdx)

            val yPred = model.forward(batchX)
            val loss = computeMSELoss(yPred, batchY)
            totalLoss += loss

            updateModelWeights(loss)

            if (batchIdx % visualizer.saveIntervals == 0) {
                val accuracy = computeAccuracy(yPred, batchY)
                // NOTE(review): placeholder. The recurrent state stays local to the model's forward
                // pass, so zeros are recorded here.
                val ltcState = FloatArray(model.hiddenSize) { 0f }
                val attnWeights = model.getAttentionWeights(batchX[0])

                visualizer.updateMetrics(
                    epoch = epoch,
                    batchIdx = batchIdx,
                    loss = loss,
                    accuracy = accuracy,
                    ltcState = ltcState,
                    attnWeights = attnWeights
                )

                if (batchIdx % visualizer.plotIntervals == 0) {
                    visualizer.plotTrainingProgress()
                }
            }
        }

        return totalLoss / numBatches
    }

    /**
     * Mean squared error between the single prediction matrix produced for a batch and every
     * target sample in that batch.
     *
     * @return the mean over all compared entries, or `0f` if there is nothing to compare.
     * @throws IllegalArgumentException if a target's shape differs from [pred].
     */
    private fun computeMSELoss(pred: Matrix, target: Array<Matrix>): Float {
        var sumSq = 0f
        var count = 0
        for (sample in target) {
            require(sample.size == pred.size) {
                "prediction has ${pred.size} rows but a target has ${sample.size}"
            }
            for (i in pred.indices) {
                require(sample[i].size == pred[i].size) {
                    "prediction row $i has ${pred[i].size} features but the target has ${sample[i].size}"
                }
                for (j in pred[i].indices) {
                    val diff = pred[i][j] - sample[i][j]
                    sumSq += diff * diff
                }
                count += pred[i].size
            }
        }
        return if (count == 0) 0f else sumSq / count
    }

    /** Maps the MSE to a score in [0, 100]: `100 · (1 − min(mse, 1))`. */
    private fun computeAccuracy(pred: Matrix, target: Array<Matrix>): Float {
        val mse = computeMSELoss(pred, target)
        return 100f * (1f - kotlin.math.min(mse, 1f))
    }

    private fun updateModelWeights(loss: Float) {
        // Simplified weight update
        // In a real implementation, you'd want proper backpropagation
    }
}

// Example usage
// The model has no read-out layer: each output row is the hiddenSize-wide cell state. That width
// must equal the target feature width (inputDim) for the MSE loss to be defined.
val inputDim = 10
val model = TransformerLNN(inputSize = inputDim, hiddenSize = inputDim)
val trainer = Trainer(model)

val (trainX, trainY) = DataGenerator.generateSyntheticData(
    numSamples = 1000,
    seqLength = 50,
    inputDim = inputDim
)

// Train for several epochs
val numEpochs = 5
for (epoch in 0 until numEpochs) {
    val avgLoss = trainer.trainEpoch(
        trainData = Pair(trainX, trainY),
        batchSize = 32,
        epoch = epoch
    )
    println("Epoch $epoch completed with average loss: $avgLoss")
}

org.jetbrains.kotlinx.jupyter.exceptions.ReplLibraryException: The problem is found in one of the loaded libraries: check library imports, dependencies and repositories