update reinforce cartpole

CarloLucibello committed Oct 23, 2018
1 parent c904734 commit 7b3505b
Showing 3 changed files with 19 additions and 23 deletions.
.travis.yml (2 changes: 1 addition & 1 deletion)
@@ -5,7 +5,7 @@ os:
 # - osx
 
 julia:
-  - 0.6
+  - 1.0
   - nightly
 
 matrix:
REQUIRE (3 changes: 2 additions & 1 deletion)
@@ -1,2 +1,3 @@
-julia 0.6
+julia 0.7
 Knet
+Gym
examples/reinforce_cartpole.jl (37 changes: 16 additions & 21 deletions)
@@ -1,6 +1,8 @@
 using Knet
-using Gym
-import AutoGrad: getval
+import Gym
+using AutoGrad: getval
+import Random
+using Statistics
 
 mutable struct History
     nS::Int
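
The new imports track the Julia 0.7/1.0 renames applied throughout this file. A minimal runnable sketch of the old-to-new mapping for the names touched in this diff (not an exhaustive list):

    import Random
    using Statistics

    Random.seed!(17)        # was srand(17) on Julia 0.6
    A = rand(3, 2)
    cumsum(A, dims=1)       # was cumsum(A, 1): positional dims became a keyword
    argmax([0.1, 0.9])      # was indmax([0.1, 0.9])
    mean(A)                 # was in Base; now needs `using Statistics`
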
@@ -44,18 +46,11 @@ function predict(w, x)
     return probs
 end
 
-function softmax(x)
-    y = maximum(x, 1)
-    y = x .- y
-    y = exp.(y)
-    return y ./ sum(y, 1)
-end
-
 function sample_action(probs)
     @assert size(probs, 2) == 1
-    cprobs = cumsum(probs, 1)
+    cprobs = cumsum(probs, dims=1)
     sampled = cprobs .> rand()
-    return mapslices(indmax, sampled, 1)[1]
+    return mapslices(argmax, sampled, dims=1)[1]
 end
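
The hand-rolled softmax is deleted here (its positional-dim calls no longer work on Julia 1.0); the `softmax(p)` call later in main presumably resolves to the one Knet provides. sample_action itself draws from the categorical distribution in probs by the inverse-CDF trick: cumsum builds the CDF, a single rand() picks a point in [0,1), and the first index whose cumulative mass exceeds it is the sample. A standalone sketch of the same idea, with a hypothetical helper name:

    # Draw an index from a discrete distribution `p` (probabilities summing to 1)
    # via the inverse-CDF method.
    function sample_categorical(p::AbstractVector)
        c = cumsum(p)                    # CDF: c[i] = p[1] + ... + p[i]
        r = rand()                       # uniform draw in [0, 1)
        return findfirst(x -> x > r, c)  # first index where the CDF passes r
    end

    sample_categorical([0.2, 0.8])       # returns 2 about 80% of the time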

function loss(w, history)
@@ -66,7 +61,7 @@ function loss(w, history)

     p = predict(w, states)
     inds = history.actions + nA*(0:M-1)
-    lp = logp(p, 1)[inds] # lp is a vector
+    lp = logp(p, dims=1)[inds] # lp is a vector
 
     return -mean(lp .* R)
 end
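
This is the REINFORCE objective: the negative mean log-probability of each action actually taken, weighted by its (presumably discounted) return R, so minimizing it with grad ascends the policy gradient. The inds line converts (action, timestep) pairs into linear indices of the nA×M log-probability matrix. A small illustration of that indexing trick, with made-up values for nA = 2 and M = 3:

    lp_all  = reshape(1.0:6.0, 2, 3)   # stand-in for logp(p, dims=1), size nA×M
    actions = [1, 2, 2]                # action taken in each of the M columns
    inds    = actions + 2*(0:2)        # [1, 4, 6]: linear index of (action, t)
    lp_all[inds]                       # [1.0, 4.0, 6.0] = entries (1,1), (2,2), (2,3)
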
@@ -76,41 +71,41 @@ function main(;
         lr = 1e-2,          # learning rate
         γ = 0.99,           # discount rate
         episodes = 500,
-        rendered = true,
+        render = true,
         seed = -1,
         infotime = 50)
 
-    env = GymEnv("CartPole-v1")
-    seed > 0 && (srand(seed); srand(env, seed))
+    env = Gym.GymEnv("CartPole-v1")
+    seed > 0 && (Random.seed!(seed); Gym.seed!(env, seed))
     nS, nA = 4, 2
     w = initweights(hidden, nS, nA)
     opt = [Adam(lr=lr) for _=1:length(w)]
 
     avgreward = 0
     for episode=1:episodes
-        state = reset!(env)
+        state = Gym.reset!(env)
         episode_rewards = 0
         history = History(nS, nA, γ)
         for t=1:10000
             p = predict(w, state)
             p = softmax(p)
             action = sample_action(p)
 
-            next_state, reward, done, _ = step!(env, action_space(env)[action])
+            next_state, reward, done, info = Gym.step!(env, action-1)
             append!(history.states, state)
             push!(history.actions, action)
             push!(history.rewards, reward)
             state = next_state
             episode_rewards += reward
 
-            episode % infotime == 0 && rendered && render(env)
+            episode % infotime == 0 && render && Gym.render(env)
             done && break
         end
 
-        avgreward = 0.02 * episode_rewards + avgreward * 0.98
+        avgreward = 0.1 * episode_rewards + avgreward * 0.9
         if episode % infotime == 0
             println("(episode:$episode, avgreward:$avgreward)")
-            rendered && render(env, close=true)
+            Gym.close!(env)
         end
 
         dw = grad(loss)(w, history)
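
For reference, a hypothetical way to run the updated script on Julia 1.0 (assuming the Gym.jl package and its Python gym backend are installed):

    julia> include("examples/reinforce_cartpole.jl")
    julia> main(episodes=500, render=false, seed=17)   # keyword names as in the diff above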
