In [1]:
using BenchmarkTools
using CUDA
using DataStructures
using Flux
using Flux: params
using Gym
using StatsBase: StatsBase, sample
using Zygote

In [2]:
env = GymEnv("CartPole-v1")

GymEnv("CartPole-v1", Gym.Spec("CartPole-v1", 475.0, false, 500), DiscreteS(2), BoxS(Float32[-4.8, -3.4028235f38, -0.41887903, -3.4028235f38], Float32[4.8, 3.4028235f38, 0.41887903, 3.4028235f38], (4,)), (-Inf, Inf), PyObject <TimeLimit<CartPoleEnv<CartPole-v1>>>)

In [3]:
env.action_space.n

2

In [4]:
Base.length(ds::DiscreteS) = ds.n
Base.iterate(ds::DiscreteS) = iterate(0:ds.n-1)
Base.iterate(ds::DiscreteS, state) = iterate(0:ds.n-1, state)
StatsBase.sample(ds::DiscreteS) = Gym.sample(ds)

collect(env.action_space)

2-element Vector{Any}:
 0
 1

In [5]:
struct SAR{S, A}
    s      :: S
    a      :: A
    r      :: Float32
    s′     :: S
    t      :: Int32
    failed :: Bool
    limit  :: Bool
end


In [6]:
abstract type AbstractPolicy end

function action end

function run_episode(step_f, env, policy)
    episode_reward = 0f0
    s = Gym.reset!(env)
    for t in Iterators.countfrom(1)
        a = action(policy, s, env.action_space)
        s′, r, failed, info = step!(env, a)
        episode_reward += r
        @assert t < env.gymenv._max_episode_steps
        limit = t == env.gymenv._max_episode_steps
        if limit
            failed = false
        end
        step_f(SAR(s, a, Float32(r), s′, Int32(t), failed, limit))
        if failed || limit
            break
        end
    end
    episode_reward
end

run_episode (generic function with 1 method)

In [7]:
struct Policy <: AbstractPolicy end

function action(policy::Policy, s, A)
    sample(A)
end

action (generic function with 1 method)

In [8]:
sars = SAR[]
for episode in 1:100
    run_episode(env, Policy()) do sar
        push!(sars, sar)
        # render(env)
    end
end
Gym.close!(env)
length(sars)

2181

In [9]:
function nonans(label)
    function (xs)
        @assert !any(isnan, xs) "nan at $label"
        xs
    end
end

function make_π_network()
    Chain(
        nonans("π input"),
        Dense(4, 600, relu),
        Dense(600, 200, relu),
        Dense(200, 2, identity),
        softmax,
        nonans("π output"))
end

function make_q_network()
    Chain(
        nonans("q input"),
        Dense(4, 600, relu),
        Dense(600, 200, relu),
        Dense(200, 2, identity),
        nonans("q output"))
end

make_q_network (generic function with 1 method)

In [10]:
π_d = gpu(make_π_network())

Chain(
  var"#3#4"{String}("π input"),
  Dense(4, 600, relu),                  [90m# 3_000 parameters[39m
  Dense(600, 200, relu),                [90m# 120_200 parameters[39m
  Dense(200, 2),                        [90m# 402 parameters[39m
  NNlib.softmax,
  var"#3#4"{String}("π output"),
)[90m                   # Total: 6 arrays, [39m123_602 parameters, 745 bytes.

In [11]:
x_d = gpu(rand(4))
π_d(x_d)

2-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
 0.5094007
 0.4905993

In [12]:
x_d .+= 1

4-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
 1.3621557
 1.479598
 1.2280709
 1.7560036

In [13]:
function valgrad(f, x...)
    val, back = pullback(f, x...)
    val, back(1)
end

valgrad (generic function with 1 method)

In [14]:
params(π_d)

Params([Float32[0.015638476 -0.0048991013 -0.07762728 -0.038080353; 0.09224431 0.042531956 0.038626686 -0.012881162; … ; -0.03551219 -0.0406656 -0.052360788 0.07966123; 0.020060753 0.018769579 0.037280783 -0.033801794], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.07079705 -0.06682925 … 0.03770497 -0.033084437; -0.07819713 0.063017815 … -0.018593522 -0.068303965; … ; 0.022216111 0.028466983 … 0.08305521 0.0056571476; 0.0065940553 0.03461445 … 0.0720937 -0.0667682], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.0369958 0.09166699 … 0.10131069 0.11823352; -0.1389865 0.03372632 … 0.05299462 -0.12359712], Float32[0.0, 0.0]])

In [15]:
loss, grads = valgrad(params(π_d)) do
    sum(π_d(x_d))
end
display(loss)
display(grads)

1.0f0

Grads(...)

In [16]:
println(typeof(grads.params |> first))
println(typeof(grads.grads |> keys |> first))
println(typeof(grads.grads |> values |> first))

CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}


In [17]:
π_h = make_π_network()
x_h = rand(4)
host_small_bm = @benchmark valgrad(params(π_h)) do
    sum(π_h(x_h))
end

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m186.940 μs[22m[39m … [35m  3.451 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 89.95%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m207.540 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m239.982 μs[22m[39m ± [32m261.893 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m12.99% ± 10.77%

  [39m█[34m [39m[32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁
  [39m█[34m▇[39

In [18]:
device_small_bm = @benchmark valgrad(params(π_d)) do
    sum(π_d(x_d))
end

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m258.781 μs[22m[39m … [35m40.095 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 33.57%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m268.550 μs              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m303.006 μs[22m[39m ± [32m 1.110 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m3.60% ±  0.97%

  [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m▄[39m▅[39m▅[39m▇[39m▇[39m▇[39m█[39m▆[39m▆[34m▅[39m[39m▅[39m▃[39m▃[39m▂[39m▂[39m▂[39m▂[39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▂[39m▁[39m▂[39m▁[39m▂[3

In [19]:
x_h = rand(4, 10_000)
host_large_bm = @benchmark valgrad(params(π_h)) do
    sum(π_h(x_h))
end

BenchmarkTools.Trial: 6 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m803.467 ms[22m[39m … [35m973.227 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m1.09% … 14.18%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m847.013 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m2.83%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m874.983 ms[22m[39m ± [32m 66.946 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m6.18% ±  5.96%

  [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m█[34m [39m[39m [39m [39m█[39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m 
  [39m█[39m▁[39m▁[39m

In [20]:
x_d = rand(4, 10_000) |> gpu
device_large_bm = @benchmark valgrad(params(π_d)) do
    sum(π_d(x_d))
end

BenchmarkTools.Trial: 532 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m1.767 ms[22m[39m … [35m 16.721 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 37.48%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m9.325 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m9.406 ms[22m[39m ± [32m570.810 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m1.10% ±  7.10%

  [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [34m█[39m[32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▄[39m▁[39m▁[39m▁[39m▁[39m▁[3

In [21]:
device_large_bm2 = @benchmark begin
    x_d = rand(4, 10_000) |> gpu
    valgrad(params(π_d)) do
        sum(π_d(x_d))
    end
end

BenchmarkTools.Trial: 525 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m1.958 ms[22m[39m … [35m17.392 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 35.89%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m9.407 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m9.525 ms[22m[39m ± [32m 1.012 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m1.45% ±  6.43%

  [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [34m█[39m[32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▄[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁

In [22]:
println("Device small speedup: $((mean(device_small_bm.times) / mean(host_small_bm.times))^-1)")
println("Device large speedup: $((mean(device_large_bm.times) / mean(host_large_bm.times))^-1)")
println("Device large speedup: $((mean(device_large_bm2.times) / mean(host_large_bm.times))^-1)")

Device small speedup: 0.7920047512753262
Device large speedup: 93.02562057742328
Device large speedup: 91.86003977286636
