In [None]:
using Pkg; Pkg.activate("/home/dhairyagandhi96/temp/model-zoo/script/.."); Pkg.status();

using Flux, Trebuchet
using Zygote: forwarddiff
using Statistics: mean
using Random

"""
    lerp(x, lo, hi)

Linearly map `x` from the unit interval onto `[lo, hi]`: `x = 0` gives `lo`, `x = 1` gives `hi`.
"""
lerp(x, lo, hi) = lo + (hi - lo) * x

"""
    shoot(wind, angle, weight)

Run the Trebuchet.jl simulation for the given `wind`, launch `angle` and counterweight
`weight`, and return element 2 of its result.

`angle` is converted with `Trebuchet.deg2rad`, so it is presumably given in degrees.
The returned element is presumably the distance travelled, since `loss` compares it
against the target distance — confirm against Trebuchet.jl.
"""
function shoot(wind, angle, weight)
  radians = Trebuchet.deg2rad(angle)
  result = Trebuchet.shoot((wind, radians, weight))
  return result[2]
end

"""
    shoot(ps)

Variant taking a single collection `ps = [wind, angle, weight]`, which is splatted into
`shoot(wind, angle, weight)`. The call is wrapped in `Zygote.forwarddiff`, which asks
Zygote to differentiate it in forward mode.
"""
function shoot(ps)
  return forwarddiff(ps) do p
    shoot(p...)
  end
end

Random.seed!(0)

# Small MLP: `aim` feeds it the vector [wind, target] and reads back two raw outputs,
# which it turns into a launch angle and a counterweight. `f64` converts it to Float64.
model = f64(Chain(
  Dense(2, 16, σ),
  Dense(16, 64, σ),
  Dense(64, 16, σ),
  Dense(16, 2),
))

# Trainable parameters handed to `Flux.train!`.
θ = params(model)

"""
    aim(wind, target)

Ask `model` for a launch setting given `wind` and the `target` distance.

Return `(angle, weight)`:
- `angle` is the first raw output squashed by `σ` and scaled to the open interval (0, 90).
- `weight` is the second raw output shifted by +200.
"""
function aim(wind, target)
  raw_angle, raw_weight = model([wind, target])
  return σ(raw_angle) * 90, raw_weight + 200
end

"""
    distance(wind, target)

Let the network `aim` for `target` under `wind`, fire the simulation with
`[wind, angle, weight]`, and return whatever `shoot` returns.
"""
function distance(wind, target)
  angle, weight = aim(wind, target)
  # NOTE(review): `Tracker.collect` presumably turns the vector of tracked scalars into a
  # single tracked array so gradients flow through `shoot(ps)` — confirm.
  return shoot(Tracker.collect([wind, angle, weight]))
end

"""
    loss(wind, target)

Return the squared error between the distance reached with the network's aim and `target`.

Failures inside the simulation are scored as a zero loss, so one bad sample does not
abort training; Roots.jl sometimes fails to converge. Interrupts (Ctrl-C) and obvious
programming errors are rethrown rather than being silently scored as zero.
"""
function loss(wind, target)
  try
    return (distance(wind, target) - target)^2
  catch e
    # Let Ctrl-C stop `train!`, and surface bugs (undefined names, wrong method
    # signatures) instead of turning every sample into a silent zero loss.
    e isa Union{InterruptException, UndefVarError, MethodError} && rethrow()
    # Roots.jl sometimes gives convergence errors; ignore them.
    return param(0)
  end
end

# (min, max) range of target distances; `target()` samples uniformly within it.
const DIST  = (20, 100)
# Scale (standard deviation) of the normally distributed wind; not a hard maximum.
const SPEED = 5

"""
    target()

Draw one random training example `(wind, distance)`:
- `wind` is `randn() * SPEED`.
- `distance` is uniform in `DIST`.
"""
target() = (randn() * SPEED, lerp(rand(), DIST...))

"""
    meanloss(n = 100)

Average `sqrt(loss(...))` over `n` freshly drawn targets. This is the mean absolute miss
distance; samples where `loss` fell back to zero count as zero. `n` should be positive.
"""
meanloss(n = 100) = mean(sqrt(loss(target()...)) for _ in 1:n)

opt = ADAM()

# 100_000 random (wind, target) samples, drawn lazily as training iterates.
dataset = (target() for _ in 1:100_000)

# Print the current `meanloss()` at most once every 10 seconds during training.
cb = Flux.throttle(10) do
  @show(meanloss())
end

Flux.train!(loss, θ, dataset, opt; cb = cb)

[32m[1m    Status[22m[39m `~/temp/model-zoo/Project.toml`
 [90m [1520ce14][39m[37m   AbstractTrees v0.2.1[39m
 [90m [fbb218c0][39m[93m ↑ BSON v0.2.3 ⇒ v0.2.4[39m
 [90m [54eefc05][39m[37m   Cascadia v0.4.0[39m
 [90m [8f4d0f93][39m[37m   Conda v1.3.0[39m
 [90m [864edb3b][39m[93m ↑ DataStructures v0.17.0 ⇒ v0.17.5[39m
 [90m [31c24e10][39m[93m ↑ Distributions v0.21.3 ⇒ v0.21.5[39m
 [90m [587475ba][39m[37m   Flux v0.9.0[39m
 [90m [708ec375][39m[37m   Gumbo v0.5.1[39m
 [90m [b0807396][39m[37m   Gym v1.1.3[39m
 [90m [cd3eb016][39m[93m ↑ HTTP v0.8.6 ⇒ v0.8.7[39m
 [90m [6218d12a][39m[37m   ImageMagick v0.7.5[39m
 [90m [916415d5][39m[37m   Images v0.18.0[39m
 [90m [e5e0dc1b][39m[37m   Juno v0.7.2[39m
 [90m [ca7b5df7][39m[37m   MFCC v0.3.1[39m
 [90m [dbeba491][39m[92m + Metalhead v0.4.0 #c4d1eba (https://github.com/FluxML/Metalhead.jl.git)[39m
 [90m [91a5bcdd][39m[93m ↑ Plots v0.26.3 ⇒ v0.27.0[39m
 [90m [2913bbd2][39m[37m   St

┌ Info: Recompiling stale cache file /home/dhairyagandhi96/.julia/compiled/v1.1/Flux/QdkVy.ji for Flux [587475ba-b771-5e3f-ad9e-33799f191a9c]
└ @ Base loading.jl:1184
