In [1]:
using BenchmarkTools
using CUDA
using DataStructures
using Flux
using Flux: params
using Gym
using StatsBase: StatsBase, sample
using Zygote

In [2]:
env = GymEnv("CartPole-v1")

GymEnv("CartPole-v1", Gym.Spec("CartPole-v1", 475.0, false, 500), DiscreteS(2), BoxS(Float32[-4.8, -3.4028235f38, -0.41887903, -3.4028235f38], Float32[4.8, 3.4028235f38, 0.41887903, 3.4028235f38], (4,)), (-Inf, Inf), PyObject <TimeLimit<CartPoleEnv<CartPole-v1>>>)

In [3]:
env.action_space.n

2

In [4]:
Base.length(ds::DiscreteS) = ds.n
Base.iterate(ds::DiscreteS) = iterate(0:ds.n-1)
Base.iterate(ds::DiscreteS, state) = iterate(0:ds.n-1, state)
StatsBase.sample(ds::DiscreteS) = Gym.sample(ds)

collect(env.action_space)

2-element Vector{Any}:
 0
 1

In [5]:
struct SAR{S, A}
    s      :: S
    a      :: A
    r      :: Float32
    s′     :: S
    t      :: Int32
    failed :: Bool
    limit  :: Bool
end


In [6]:
abstract type AbstractPolicy end

function action end

function run_episode(step_f, env, policy)
    episode_reward = 0f0
    s = Gym.reset!(env)
    for t in Iterators.countfrom(1)
        a = action(policy, s, env.action_space)
        s′, r, failed, info = step!(env, a)
        episode_reward += r
        @assert t < env.gymenv._max_episode_steps
        limit = t == env.gymenv._max_episode_steps
        if limit
            failed = false
        end
        step_f(SAR(s, a, Float32(r), s′, Int32(t), failed, limit))
        if failed || limit
            break
        end
    end
    episode_reward
end

run_episode (generic function with 1 method)

In [7]:
struct Policy <: AbstractPolicy end

function action(policy::Policy, s, A)
    sample(A)
end

action (generic function with 1 method)

In [8]:
sars = SAR[]
for episode in 1:100
    run_episode(env, Policy()) do sar
        push!(sars, sar)
        # render(env)
    end
end
Gym.close!(env)
length(sars)

2193

In [9]:
function nonans(label)
    function (xs)
        @assert !any(isnan, xs) "nan at $label"
        xs
    end
end

function make_π_network()
    Chain(
        nonans("π input"),
        Dense(4, 600, relu),
        Dense(600, 200, relu),
        Dense(200, 2, identity),
        softmax,
        nonans("π output"))
end

function make_q_network()
    Chain(
        nonans("q input"),
        Dense(4, 600, relu),
        Dense(600, 200, relu),
        Dense(200, 2, identity),
        nonans("q output"))
end

make_q_network (generic function with 1 method)

In [10]:
π_d = gpu(make_π_network())

Chain(
  var"#3#4"{String}("π input"),
  Dense(4, 600, relu),                  [90m# 3_000 parameters[39m
  Dense(600, 200, relu),                [90m# 120_200 parameters[39m
  Dense(200, 2),                        [90m# 402 parameters[39m
  NNlib.softmax,
  var"#3#4"{String}("π output"),
)[90m                   # Total: 6 arrays, [39m123_602 parameters, 745 bytes.

In [11]:
x_d = gpu(rand(4))
π_d(x_d)

2-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
 0.5155876
 0.4844124

In [12]:
x_d .+= 1

4-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
 1.3463959
 1.8578858
 1.266954
 1.6604097

In [13]:
function valgrad(f, x...)
    val, back = pullback(f, x...)
    val, back(1)
end

valgrad (generic function with 1 method)

In [14]:
params(π_d)

Params([Float32[0.02055074 0.017598977 0.03929853 -0.009701583; 0.097948015 0.045566842 0.0261656 -0.09887467; … ; 0.06911378 -0.01995021 0.025440695 -0.028289732; -0.008783246 -0.051764224 0.0485605 0.02541781], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.038315438 -0.04047407 … -0.04325533 0.018109128; 0.0056203534 0.021703616 … 0.078905135 0.06412338; … ; -0.050137274 -0.049155746 … -0.0534402 -0.053650968; -0.03902287 -0.081143565 … 0.005548933 0.065123774], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.16034214 -0.023040686 … 0.12092592 -0.11126434; -0.07888829 -0.05053302 … 0.11698182 -0.117985904], Float32[0.0, 0.0]])

In [None]:
loss, grads = valgrad(params(π_d)) do
    sum(π_d(x_d))
end
display(loss)
display(grads)

In [None]:
println(typeof(grads.params |> first))
println(typeof(grads.grads |> keys |> first))
println(typeof(grads.grads |> values |> first))

In [None]:
π_h = make_π_network()
x_h = rand(4)
host_small_bm = @benchmark valgrad(params(π_h)) do
    sum(π_h(x_h))
end

In [None]:
device_small_bm = @benchmark valgrad(params(π_d)) do
    sum(π_d(x_d))
end

In [None]:
x_h = rand(4, 10_000)
host_large_bm = @benchmark valgrad(params(π_h)) do
    sum(π_h(x_h))
end

In [None]:
x_d = rand(4, 10_000) |> gpu
device_large_bm = @benchmark valgrad(params(π_d)) do
    sum(π_d(x_d))
end

In [None]:
device_large_bm2 = @benchmark begin
    x_d = rand(4, 10_000) |> gpu
    valgrad(params(π_d)) do
        sum(π_d(x_d))
    end
end

In [None]:
println("Device small speedup: $((mean(device_small_bm.times) / mean(host_small_bm.times))^-1)")
println("Device large speedup: $((mean(device_large_bm.times) / mean(host_large_bm.times))^-1)")
println("Device large speedup: $((mean(device_large_bm2.times) / mean(host_large_bm.times))^-1)")