In [3]:
import Flux: params
using ReinforcementLearning
using ReinforcementLearningExperiments
using Flux
import Flux.params
using Flux.Losses
using IntervalSets
using StableRNGs
using Flux.Optimise
using Flux: crossentropy, ADAM
using LinearAlgebra
import Base.rand
import Random: Sampler, SamplerSimple, Repetition, eltype
import Random: seed!
using CUDA
include("../model/FurutaPendulums/ATRXR/src/FurutaPendulums.jl")
using .FurutaPendulums


SimulatedFurutaPendulum{Float64, Random._GLOBAL_RNG}()

## Configuring the Environment

In [4]:
mutable struct FPendulum <: AbstractEnv
    reward::Float32 = 0.0 # step reward
    state::Vector{Float64} = measure(furuta) # Measures the state of the pendulum
    action
    action_space
    state_space
    t::Int = 0 # Indicates the number of steps taken
    dt
    furuta
end
Main.FPendulum

FPendulum

In [5]:
RLBase.action_space(env::FPendulum) = Space(repeat([ClosedInterval{Float32}(-5.0,5.0)],length(1))) # The input voltage
RLBase.reward(env::FPendulum) = env.reward
RLBase.state(env::FPendulum) = env.x
A = Space(repeat([ClosedInterval{Float64}(-π, π)],length(1)))
B = Space(repeat([ClosedInterval{Float64}(-Inf, Inf)],length(1)))
RLBase.state_space(env::FPendulum) = [A, B, A, B]
RLBase.is_terminated(env::FPendulum) = env.time >= 250 # The episode is terminated after 250 steps
function RLBase.reset!(env::FPendulum) # Reset the environment
    furuta = SimulatedFurutaPendulum()
    control(furuta, 0.0)
    env.reward = 0.0
    env.t = 0.0
end

In [6]:
function (x::FPendulum)(action::Float32)
    control(furuta, action) # Apply the action
    env.x = measure(furuta) # Measure the state
    env.reward = -abs(env.x[3]) - abs(env.x[4]) # Calculate the reward
    env.t += 1
end

In [7]:
env = FPendulum()

# FPendulum

## Traits

| Trait Type        |                  Value |
|:----------------- | ----------------------:|
| NumAgentStyle     |          SingleAgent() |
| DynamicStyle      |           Sequential() |
| InformationStyle  | ImperfectInformation() |
| ChanceStyle       |           Stochastic() |
| RewardStyle       |           StepReward() |
| UtilityStyle      |           GeneralSum() |
| ActionStyle       |     MinimalActionSet() |
| StateStyle        |     Observation{Any}() |
| DefaultStateStyle |     Observation{Any}() |

## Is Environment Terminated?

No

## State Space

`Space{Vector{ClosedInterval{Float64}}}[Space{Vector{ClosedInterval{Float64}}}(ClosedInterval{Float64}[-3.141592653589793..3.141592653589793]), Space{Vector{ClosedInterval{Float64}}}(ClosedInterval{Float64}[-Inf..Inf]), Space{Vector{ClosedInterval{Float64}}}(ClosedInterval{Float64}[-3.141592653589793..3.141592653589793]), Space{Vector{ClosedInterval{Float64}}}(ClosedInterval{Float64}[-Inf..Inf])]`

## Action Space

`Space{Vector{ClosedInterval{Float32}}}(ClosedInterval{Float32}[-5.0..5.0])`

## Current State

```
[-0.008489333402506015, -0.005777222460162207, 3.1433747083208754, -0.00044036602123035876]
```


Here we can use furuta.x to to measure the current state.

In [8]:
measure(furuta)
@show furuta.x

furuta.x = [0.0, 0.0, 3.141592653589793, 0.0]


4-element Vector{Float64}:
 0.0
 0.0
 3.141592653589793
 0.0

## TD3 Model Implementation

In [9]:
function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:TD3},
    ::Val{:Pendulum},
    ::Nothing;
    seed = 123,
)
    # params = furuta.params
    env = FPendulum()
    rng = StableRNG(seed)
    init = glorot_uniform(rng)
    ns = 4

    # Choosing the actor networks architecture
    create_actor() = Chain(
        Dense(ns, 30, relu; init = init),
        Dense(30, 30, relu; init = init),
        Dense(30, 1, tanh; init = init),
    ) |> cpu # Change to gpu if you have a gpu

    # Choosing the critic networks architecture
    create_critic_model() = Chain(
        Dense(ns + 1, 30, relu; init = init),
        Dense(30, 30, relu; init = init),
        Dense(30, 1; init = init),
    ) |> cpu # Change to gpu if you have a gpu

    create_critic() = TD3Critic(create_critic_model(), create_critic_model())

    agent = Agent(
        policy = TD3Policy(
            behavior_actor = NeuralNetworkApproximator(
                model = create_actor(),
                optimizer = ADAM(),
            ),
            behavior_critic = NeuralNetworkApproximator(
                model = create_critic(),
                optimizer = ADAM(),
            ),
            target_actor = NeuralNetworkApproximator(
                model = create_actor(),
                optimizer = ADAM(),
            ),
            target_critic = NeuralNetworkApproximator(
                model = create_critic(),
                optimizer = ADAM(),
            ),
            γ = 0.99f0,
            ρ = 0.99f0,
            batch_size = 64,
            start_steps = 1000,
            start_policy = RandomPolicy(-5.0..5.0; rng = rng),
            update_after = 1000,
            update_freq = 1,
            policy_freq = 2,
            target_act_limit = 1.0,
            target_act_noise = 0.1,
            act_limit = 1.0,
            act_noise = 0.1,
            rng = rng,
        ),
        trajectory = CircularArraySARTTrajectory(
            capacity = 10_000,
            state = Vector{Float32} => (ns,),
            action = Float32 => (),
        ),
    )

    stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
    hook = TotalRewardPerEpisode()
    Experiment(agent, env, stop_condition, hook, "# Play Pendulum with TD3")
end

In [10]:
# Train the model
using Plots
ex = E`JuliaRL_TD3_Pendulum`
run(ex)
@show plot(ex.hook.rewards)

# Play Pendulum with TD3


LoadError: method not implemented



## Testing the environment

In [None]:
rng = StableRNG(123)
RLBase.test_runnable!(env)

In [None]:
env |> action_space |> rand |> env

In [None]:
run(RandomPolicy(action_space(env)), env, StopAfterEpisode(1_000))
EmptyHook()

In [None]:
policy = RandomPolicy([action_space(env)])
ts, rs, xs = Int[], Float32[], Float64[]
for i = 1:2000
    env |> policy |> env
    # write your own logic here
    # like saving parameters, recording loss function, evaluating policy, etc.
    push!(ts, env.time)
    push!(rs, env.reward)
    push!(xs, env.x[1])
#    is_terminated(env) && reset!(env)
end
plot(ts, xs, xlabel = "time", ylabel = "x", legend = false)