In [1]:
// Analogous to micrograd Value
/**
 * A scalar node in a reverse-mode autodiff graph.
 *
 * Each arithmetic operator returns a new [Value]:
 * - its [children] are the operands;
 * - its [backward] closure adds the local chain-rule contribution of `out.grad` into those operands' [grad];
 * - calling [runBackProp] on an output node fills in [grad] for every node reachable from it.
 *
 * Equality and hashing are by **identity**, not by the structural contents a data class would normally compare.
 * Nodes are mutable ([grad] changes on every backward pass) and reference their whole subgraph through [children].
 * A structural `hashCode` would therefore change while the node sits inside a hash set, which breaks the `visited`
 * bookkeeping of [runBackProp] and double-counts gradients on repeated passes. It would also cost time exponential
 * in the number of reused nodes.
 *
 * @property data the forward-pass value.
 * @property grad derivative of the most recent [runBackProp] root with respect to this node.
 * @property label display name; defaults to a random 3-letter string drawn from 'H'..'Z'.
 * @property op human-readable description of the operation that produced this node ("" for leaves).
 * @property backward pushes this node's [grad] into its operands; a no-op for leaves.
 * @property children the operands this node was computed from.
 */
data class Value(
    var data: Double,
    var grad: Double = 0.0,
    var label: String = (1..3).map { ('H'..'Z').random() }.joinToString(""),
    var op: String = "",
    var backward: () -> Unit = {},
    val children: MutableSet<Value> = mutableSetOf()
) {
    /** `this + other`; both operands receive `out.grad` unchanged (∂(x+y)/∂x = ∂(x+y)/∂y = 1). */
    operator fun plus(other: Value): Value {
        val out = Value(data + other.data, op = "$label + ${other.label}")
        out.children.addAll(setOf(this, other))
        out.backward = {
            // `+=` (not `=`) so that `a + a` correctly accumulates a gradient of 2.
            this.grad += out.grad
            other.grad += out.grad
        }
        return out
    }

    /** `this * other`; ∂(x·y)/∂x = y and ∂(x·y)/∂y = x. */
    operator fun times(other: Value): Value {
        val out = Value(data * other.data, op = "$label * ${other.label}")
        out.children.addAll(setOf(this, other))
        out.backward = {
            this.grad += other.data * out.grad
            other.grad += this.data * out.grad
        }
        return out
    }

    /**
     * `this` raised to the integer power [num]; backward applies the power rule num·x^(num−1).
     *
     * Negative exponents are allowed; [div] uses `pow(-1)`. Results follow plain `Double` arithmetic,
     * so `0.0.pow(-1)` is `Infinity`.
     */
    fun pow(num: Int): Value {
        val out = Value(data.pow(num), op = "$label**${num}")
        out.children.add(this)
        out.backward = {
            this.grad += num * this.data.pow(num - 1) * out.grad
        }
        return out
    }

    /** Rectified linear unit; the gradient flows through only where the output is strictly positive. */
    fun relu(): Value {
        val out = Value(if (data < 0) 0.0 else data, op = "$label^ReLu")
        out.children.add(this)
        out.backward = {
            if (out.data > 0) {
                this.grad += out.grad
            }
        }
        return out
    }

    /**
     * Runs reverse-mode backpropagation with this node as the output.
     *
     * Steps:
     * 1. Every node reachable from `this` (including `this`) has its [grad] reset to 0.
     * 2. `this.grad` is seeded with 1.
     * 3. Each node's [backward] is invoked in reverse topological order.
     *
     * Afterwards each reachable node's [grad] holds d(this)/d(node). Nodes not reachable from `this` are untouched.
     * Calling this repeatedly on the same graph yields the same gradients.
     *
     * @param debug when true, prints each node right after its [backward] has run.
     */
    fun runBackProp(debug: Boolean = false) {
        // Topologically sort Values for backprop.
        // Uses an explicit-stack post-order DFS. It emits nodes in exactly the order the natural recursive version
        // would, but cannot overflow the call stack on deep graphs such as long running sums.
        val sorted = mutableListOf<Value>()
        val visited = mutableSetOf<Value>()
        val stack = ArrayDeque<Pair<Value, Iterator<Value>>>()
        visited.add(this)
        stack.addLast(this to children.iterator())
        while (stack.isNotEmpty()) {
            val (node, pending) = stack.last()
            if (pending.hasNext()) {
                val child = pending.next()
                if (visited.add(child)) {
                    stack.addLast(child to child.children.iterator())
                }
            } else {
                stack.removeLast()
                // Zero the node as it is emitted so gradients from a previous pass do not leak into this one.
                node.grad = 0.0
                sorted.add(node)
            }
        }

        // Reverse topological order guarantees a node's grad is complete before its backward() pushes it further
        // toward the leaves.
        this.grad = 1.0
        for (node in sorted.asReversed()) {
            node.backward()
            if (debug) {
                println(node)
            }
        }
    }

    // Convenience operators built on the primitives above.
    // Numeric operands are wrapped in fresh constant Values.
    operator fun unaryMinus(): Value = this * -1.0
    operator fun plus(otherNumber: Number): Value = this + Value(otherNumber.toDouble())
    operator fun minus(other: Value) = this + -other
    operator fun minus(otherNumber: Number) = this - Value(otherNumber.toDouble())
    operator fun times(otherNumber: Number): Value = this * Value(otherNumber.toDouble())
    operator fun div(other: Value): Value = this * other.pow(-1)
    operator fun div(otherNumber: Number): Value = this / Value(otherNumber.toDouble())

    /** Identity equality: two nodes are equal only if they are the same graph node. */
    override fun equals(other: Any?): Boolean = this === other

    /** Identity hash, stable across [grad]/[label] mutation (required for the hash sets used above). */
    override fun hashCode(): Int = System.identityHashCode(this)

    override fun toString(): String {
        // Only the numbers go through the formatter: label/op are user text and may contain '%',
        // which must not be interpreted as a format specifier.
        val suffix = if (op.isNotBlank()) " = $op" else ""
        return "Value(data=${"%.4f".format(data)},grad=${"%.4f".format(grad)}) $label$suffix"
    }
}


In [2]:
// Extensions for Numerical types
// These let a plain number appear on the LEFT of a Value operator (e.g. `1 + c`, `3 * d`, `10.0 / f`).
// Each one forwards to a Value member operator, so all gradient bookkeeping stays inside Value.

/** `number + value`: addition commutes, so this is `value + number`. */
operator fun Number.plus(other: Value): Value = other.plus(this)

/** `number - value`, computed as `(-value) + number`. */
operator fun Number.minus(other: Value): Value = other.unaryMinus().plus(this)

/** `number * value`: multiplication commutes, so this is `value * number`. */
operator fun Number.times(other: Value): Value = other.times(this)

/** `number / value`, computed as `value^-1 * number`. */
operator fun Number.div(other: Value): Value = other.pow(-1).times(this)

In [3]:
// Worked example: L = (a * b + c) * f, i.e. (2 * -3 + 10) * -2 = -8.
// Expected gradients after backprop: dL/da = 6, dL/db = -4, dL/dc = -2, dL/de = -2, dL/dd = -2, dL/df = 4.
var a = Value(2.0, label = "a")
var b = Value(-3.0, label = "b")
var c = Value(10.0, label = "c")
var e = (a * b).apply { label = "e" }
var d = (e + c).apply { label = "d" }
var f = Value(-2.0, label = "f")
var L = (d * f).apply { label = "L" }
L.runBackProp()

In [4]:
// Larger expression mixing +, -, *, /, pow and relu, with numbers on both sides of the operators.
// Value defines no `plusAssign`, so each `x += y` below is written out as the `x = x + y` it compiles to.
var a = Value(-4.0, label = "a")
var b = Value(2.0, label = "b")
var c = (a + b).apply { label = "c" }
var d = (a * b + b.pow(3)).apply { label = "d" }
c = c + (c + 1)
c = c + (1 + c + (-a))
d = d + (d * 2 + (b + a).relu())
d = d + (3 * d + (b - a).relu())
var e = (c - d).apply { label = "e" }
var f = e.pow(2).apply { label = "f" }
var g = (f / 2.0).apply { label = "g" }
g = g + 10.0 / f
println("%.4f".format(g.data))  // prints 24.7041, the outcome of this forward pass
g.runBackProp()
// Prints 138.8338 (dg/da) and then 645.5773 (dg/db).
for (partial in listOf(a.grad, b.grad)) {
    println("%.4f".format(partial))
}

24.7041
138.8338
645.5773
