In [5]:
using ReinforcementLearning
using Flux
using Statistics
using Plots

In [10]:
env = CartPoleEnv(;T=Float32, seed=123)

CartPoleEnv{Float32}(gravity=9.8,masscart=1.0,masspole=0.1,totalmass=1.1,halflength=0.5,polemasslength=0.05,forcemag=10.0,tau=0.02,thetathreshold=0.20943952,xthreshold=2.4,max_steps=200)

In [7]:
mutable struct DQN
    TRAIN::Bool
    CHANGE::Bool
    ϵ::Float64
    ϵ_DECAY::Float64
    ϵ_MIN::Float64
    BATCH_SIZE::Int64
    MEMORY
    MEM_SIZE::Int64
    STATE_SIZE::Int64
    ACTION_SIZE::Int64
    γ::Float64
    C_UPDATE::Int64
    model1
    model2
    
    function DQN(ϵ_DECAY::Float64, ϵ_MIN::Float64,BATCH_SIZE::Int64,MEM_SIZE::Int64,STATE_SIZE::Int64,
            
                 ACTION_SIZE::Int64,γ::Float64,C_UPDATE::Int64, model1)
        
        new(true, true, 1.0f0, ϵ_DECAY, ϵ_MIN, BATCH_SIZE, [], MEM_SIZE, STATE_SIZE, ACTION_SIZE,γ,C_UPDATE
            
            model1, deepcopy(model2))
    end
    
end


In [17]:
#Fonctions élémentaires

function action(dqn::DQN,state::Array{Float32,1})

  if rand() <= dqn.ϵ && dqn.TRAIN
        
    return rand(1:dqn.ACTION_SIZE)

  end



  act_values = dqn.model1(state)

  return Flux.argmax(act_values)
    
end

function update_ϵ!(dqn::DQN)
    
    x = dqn.ϵ*dqn.ϵ_DECAY
    
    if x < dqn.ϵ_MIN && dqn.CHANGE
        
        dqn.ϵ=dqn.ϵ_MIN
        
        dqn.CHANGE=false
        
        
    elseif dqn.CHANGE
        
        dqn.ϵ = x
        
    end
    
    
end
        
function act(action::Int64, env)
    
    env(action)
    
    
    obs=observe(env)
    
    get_state(obs), get_reward(obs), get_terminal(obs)
    
end

    
    
function remember!(dqn::DQN,state::Array{Float32,1}, action::Int64, reward::Int64, next_state::Array{Float32,1}, done::Bool)

  if length(dqn.memory) == dqn.MEM_SIZE

    deleteat!(dqn.memory, 1)

  end

  push!(dqn.memory, (state, action, reward, next_state, done))

end


function replay!(dqn::DQN)

  batch_size = min(dqn.BATCH_SIZE, length(dqn.memory))

  minibatch = sample(dqn.memory, batch_size, replace = false)

  

  x = Matrix{Float32}(dqn.STATE_SIZE, batch_size)

  y = Matrix{Float32}(dqn.ACTION_SIZE, batch_size)

  for (iter, (state, action, reward, next_state, done)) in enumerate(minibatch)

    target = reward

    if !done

      target += dqn.γ * maximum(dqn.model2(next_state))

    end



    target_f = dqn.model1(state)

    target_f[action] = target

    

    x[:, iter] .= state

    y[:, iter] .= target_f

  end


  Flux.train!(loss, [(x, y)], opt)
    

end


function copy(iter::Int64, dqn::DQN)
    
    if iter%dqn.C_UPDATE==0
        
        dqn.model2 = deepcopy(dqn.model1)
        
    end
    
end
    
    
#Run 1 épisode

function episode!(dqn::DQN, env)
    
    obs_ini=observe(env)
    
    current_state=get_state(obs_ini)
    
    
    total_reward=0
    
    i = 0
    
    while true
        
        current_action=action(dqn,current_state)
        
        current_next_state,current_reward,current_done=act(current_action,env)
        
        
        total_reward+=current_reward
        
        
        remember!(dqn, current_state, current_action, current_reward, current_next_state, current_done)
        
        current_state=current_next_state
        
        replay!(dqn)
        
        update_ϵ!(dqn)
        
        i+=1
        
        copy(i,dqn)
        
        if done
            
            break
            
        end
        
    end
    
    total_reward
    
end
        

#Run DQN algorithm

function main_DQN(dqn::DQN, env)
    
    e = 1

    scores = []

    while true

      reset!(env)

      total_reward = episode!(dqn,env)

      push!(scores, total_reward)

      print("Episode: $e | Score: $total_reward ")

      if e > 100

        last_100_mean = mean(scores[end-99:end])

        print("Last 100 episodes mean score: $last_100_mean")

        if last_100_mean > 195

          println("\nProblem solved!")

          break

        end
        
        if e > 200
                
          println("\nProblem unsolved!")
                
          break
        
        end

      end

      println()

      e += 1

    end
    
    e, scores
    
end
        
        
        
        
        
        

LoadError: syntax: incomplete: premature end of input