In [1]:
using Flux
using Flux.Tracker: gradient, update!  # Tracker is Flux's AD backend up to v0.9; Flux 0.10+ uses Zygote instead
using Printf

# `gradient`

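`gradient(f, args...)` returns the "rate of change" (derivative) of `f` with respect to each argument, evaluated at the given values. The result is a tuple with one entry per argument, which is why the cells below index it with `[1]`.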
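
To make "rate of change" concrete, here is a minimal sketch in plain Julia (no Flux) that approximates a derivative with a central finite difference, (f(x + h) - f(x - h)) / 2h. The helper `central_diff`, the polynomial name `p` and the step size `h = 1e-6` are illustrative choices, not part of Flux.

```julia
# Hypothetical helper: approximate f'(x) with a central finite difference.
# This is NOT Flux's `gradient`; it only illustrates what "rate of change" means.
central_diff(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

p(x) = 3x^2 + 2x + 1   # same polynomial as the next cell; exact derivative is 6x + 2

for x in 1.0:3.0
    println(x, " => ", round(central_diff(p, x); digits=4))
end
```

For a quadratic the central difference is exact apart from floating-point rounding, so the printed values should agree with `6x + 2`.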

In [2]:
f(x) = 3x^2 + 2x + 1
df(x) = gradient(f, x; nest=true)[1] # 6x + 2 (nest=true lets df itself be differentiated again)
d2f(x) = gradient(df, x)[1]          # 6

d2f (generic function with 1 method)

In [3]:
for i in 1:10
    @printf("%4d %20s\n", i, df(i))
end

   1        8.0 (tracked)
   2       14.0 (tracked)
   3       20.0 (tracked)
   4       26.0 (tracked)
   5       32.0 (tracked)
   6       38.0 (tracked)
   7       44.0 (tracked)
   8       50.0 (tracked)
   9       56.0 (tracked)
  10       62.0 (tracked)


In [4]:
for i in 1:10
    @printf("%4d %20s\n", i, d2f(i))
end

   1        6.0 (tracked)
   2        6.0 (tracked)
   3        6.0 (tracked)
   4        6.0 (tracked)
   5        6.0 (tracked)
   6        6.0 (tracked)
   7        6.0 (tracked)
   8        6.0 (tracked)
   9        6.0 (tracked)
  10        6.0 (tracked)


In [5]:
f(a, b) = a*b
# ∂(a*b)/∂a = b and ∂(a*b)/∂b = a, so the gradient looks like the arguments swapped: (7, 3).
gradient(f, 3, 7)

(7.0 (tracked), 3.0 (tracked))
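
The same finite-difference idea extends to several arguments: nudge one argument at a time while holding the others fixed. The sketch below is plain Julia with a hypothetical `partials` helper and a stand-in `mul` function; it should approximately recover the `(7.0, 3.0)` result above for `a*b` at `(3, 7)`.

```julia
# Hypothetical helper: approximate each partial derivative of f at `args`
# by perturbing one argument at a time (central differences).
function partials(f, args...; h=1e-6)
    map(1:length(args)) do i
        up   = ntuple(j -> j == i ? args[j] + h : args[j], length(args))
        down = ntuple(j -> j == i ? args[j] - h : args[j], length(args))
        (f(up...) - f(down...)) / (2h)
    end
end

mul(a, b) = a * b
partials(mul, 3.0, 7.0)   # approximately [7.0, 3.0]: each partial is the *other* argument
```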

# Linear Regression Example

Here we train a "neural network" with 5 inputs and 2 outputs. With no hidden layers and no activation function, it is really just a single linear layer, i.e. multivariate linear regression, hence the heading.

Each of the two outputs is a weighted combination of the inputs plus a bias.
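
Written out, with $W$ a $2 \times 5$ matrix and $b$ a length-2 vector, the model, the sum-of-squares loss computed by `loss`, and its gradients (by the chain rule) are:

$$
\hat{y}_i = \sum_{j=1}^{5} W_{ij}\, x_j + b_i, \qquad i = 1, 2
$$

$$
L = \sum_{i=1}^{2} \left(y_i - \hat{y}_i\right)^2
$$

$$
\frac{\partial L}{\partial W_{ij}} = -2\left(y_i - \hat{y}_i\right) x_j, \qquad \frac{\partial L}{\partial b_i} = -2\left(y_i - \hat{y}_i\right)
$$

These are the quantities that `gradient` computes automatically for `W` and `b` in the training loop.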

We train on a single input `x` with 5 features and a single target `y` with 2 values. The model has 12 trainable parameters (10 weights and 2 biases) but only 2 target values to match, so a perfect fit exists and the loss should get very close to zero.
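
One way to see that a perfect fit is possible: starting from any `W` and `b`, adding the rank-one correction $(y - \hat{y})\,x^\top / (x^\top x)$ to `W` maps `x` to exactly $y - \hat{y}$, so the new prediction equals `y`. This closed-form construction is chosen only for illustration (it is not what gradient descent does), and the names `exact_fit`, `W0`, `b0`, `x0`, `y0`, `W1` are hypothetical; it uses only Julia's `LinearAlgebra` standard library.

```julia
using LinearAlgebra: dot

# Hypothetical helper: given W, b and one sample (x, y), return a weight matrix
# that fits the sample exactly by adding a rank-one correction to W (b is unchanged).
function exact_fit(W, b, x, y)
    ŷ = W * x .+ b
    return W .+ (y .- ŷ) * x' ./ dot(x, x)
end

W0, b0 = rand(2, 5), rand(2)
x0, y0 = rand(5), rand(2)
W1 = exact_fit(W0, b0, x0, y0)
sum((y0 .- (W1 * x0 .+ b0)).^2)   # zero up to floating-point rounding
```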

In [6]:
W = rand(2, 5)  # weights: one row per output, one column per input
b = rand(2)     # one bias per output

predict(x) = W*x .+ b

function loss(x, y)
    ŷ = predict(x)
    sum((y .- ŷ).^2)
end

x, y = rand(5), rand(2)

loss(x, y)

4.160387594487831

In [7]:
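# Wrap W and b as tracked parameters so Tracker can compute gradients with respect to them.
# `predict` reads the globals W and b, so it picks up the tracked versions automatically.
W = param(W)
b = param(b)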

Tracked 2-element Array{Float64,1}:
 0.7052962440416521
 0.8164580317072307

In [8]:
for epoch in 1:10
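    # Gradients of the loss with respect to every array in params(W, b)
    gradients = gradient(() -> loss(x, y), params(W, b))

    # Gradient-descent step with learning rate 0.1 (update! adds its second argument to the parameter in place)
    update!(W, -0.1*gradients[W])
    update!(b, -0.1*gradients[b])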

    @printf("%.9f\n", loss(x, y))
end

1.547934579
0.575932268
0.214284235
0.079727662
0.029663872
0.011036888
0.004106440
0.001527862
0.000568464
0.000211506
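
For comparison, here is a minimal sketch of the same training loop without Flux or Tracker, plugging in the hand-derived gradients ∂L/∂W = -2(y - ŷ)xᵀ and ∂L/∂b = -2(y - ŷ). The function name `train!` and its keyword arguments are hypothetical; the learning rate 0.1 and 10 epochs mirror the cell above. Because the data are random, the printed losses will differ from the run above.

```julia
using Printf

# Hypothetical helper: plain gradient descent for ŷ = W*x .+ b with loss sum((y .- ŷ).^2).
# Gradients are derived by hand (no automatic differentiation); W and b are updated in place.
function train!(W, b, x, y; η=0.1, epochs=10)
    losses = Float64[]
    for _ in 1:epochs
        r = y .- (W * x .+ b)      # residual y - ŷ
        gW = -2 .* r * x'          # ∂L/∂W = -2(y - ŷ)xᵀ, a 2×5 matrix
        gb = -2 .* r               # ∂L/∂b = -2(y - ŷ)
        W .-= η .* gW              # gradient-descent step on the weights
        b .-= η .* gb              # and on the biases
        push!(losses, sum((y .- (W * x .+ b)).^2))
        @printf("%.9f\n", losses[end])
    end
    return losses
end

train!(rand(2, 5), rand(2), rand(5), rand(2))
```

Each step changes ŷ by 0.2(1 + xᵀx)(y - ŷ), so the residual is multiplied by 1 - 0.2(1 + xᵀx) every epoch. For `x` drawn from `rand(5)`, xᵀx lies between 0 and 5, so this factor stays strictly between -1 and 1 and the loss shrinks by a constant ratio per epoch, consistent with the roughly 0.37× drop between successive losses printed above.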
