# ContinuumWorld
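This notebook demonstrates some of the global policy features of AdaStress. It trains a Soft Actor-Critic (SAC) policy that chooses disturbances pushing a pawn in a continuous 2-D world toward a circular failure zone. It then analyzes the learned policy network with `PolicyValueVerification`.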

In [None]:
using AdaStress
import AdaStress.GrayBox
using BSON
using Distributions
using Parameters
using Random

## Pawn

In [None]:
# Distributions from which a pawn's starting (x, y) position is drawn
# (see `initialize(::Pawn, ::Initialization)`); both default to Uniform(0, 10).
@with_kw mutable struct Initialization
    x::Distribution = Uniform(0, 10)  # starting x-coordinate
    y::Distribution = Uniform(0, 10)  # starting y-coordinate
end

# Mutable 2-D position of the pawn moved around by the disturbances.
@with_kw mutable struct Pawn
    x::Float64 = 0.0
    y::Float64 = 0.0
end

# State vector `[x, y]` of a pawn.
# Extends `Base.vec` for the locally owned `Pawn` type instead of creating a new
# `Main.vec`: a fresh `Main.vec` would shadow `Base.vec`, so `vec(::Array)` would
# stop working in this session (and the definition itself errors if `vec` had
# already been referenced).
Base.vec(pawn::Pawn) = [pawn.x, pawn.y]

# Normalized observation of the pawn: its position divided by 10, so the default
# initial range [0, 10] maps onto [0, 1].
function observation(pawn::Pawn)
    state = vec(pawn)
    return state ./ 10.0
end

# Place `pawn` at a random starting position sampled from `init`.
# The x-coordinate is drawn before the y-coordinate.
function initialize(pawn::Pawn, init::Initialization)
    x0 = rand(init.x)
    y0 = rand(init.y)
    pawn.x = x0
    pawn.y = y0
end

# Translate `pawn` by the displacement (Δx, Δy).
# Accepts any `Real` displacement (e.g. `Float32` or integer samples), not only
# `Float64`. The sums are stored in the pawn's `Float64` fields.
function update(pawn::Pawn, Δx::Real, Δy::Real)
    pawn.x += Δx
    pawn.y += Δy
end

## Disturbance and failure models

In [None]:
# Distributions of the per-step x and y displacements applied to the pawn.
# `initialize(::Simulator)` registers them in the AdaStress environment.
# Both default to Normal(0.0, 0.25).
@with_kw mutable struct Disturbance
    x::Distribution = Normal(0.0, 0.25)
    y::Distribution = Normal(0.0, 0.25)
end

# Circular failure region with center (x, y) and radius r.
@with_kw mutable struct FailureZone
    x::Float64 = 0.0
    y::Float64 = 0.0
    r::Float64 = 0.0
end

# `pawn in zone`: true when the pawn's squared distance to the zone center is at
# most r², so points exactly on the boundary count as inside.
function Base.in(pawn::Pawn, zone::FailureZone)
    dx = pawn.x - zone.x
    dy = pawn.y - zone.y
    return dx^2 + dy^2 <= zone.r^2
end

# Euclidean distance from the pawn to the edge of the zone, clamped at 0.0 so any
# position inside the zone reports zero.
function distance(pawn::Pawn, zone::FailureZone)
    dx = pawn.x - zone.x
    dy = pawn.y - zone.y
    gap = sqrt(dx^2 + dy^2) - zone.r
    return max(gap, 0.0)
end

## Metrics

In [None]:
# Current safety metrics: distance `d` to the failure zone (0 inside it) and
# whether the pawn is currently inside the zone.
@with_kw mutable struct Metrics
    d::Float64 = 0.0
    in_zone::Bool = false
end

# Simulation log: maps a quantity name (e.g. "t", "pawn", "d", "in_zone") to the
# vector of values recorded for it.
const Log = Dict{String, Any}

# Metrics need no separate setup: initializing them is just a first update.
initialize(m::Metrics, pawn::Pawn, zone::FailureZone) = update(m, pawn, zone)

# Recompute the metrics from the pawn's current position relative to `zone`.
function update(m::Metrics, pawn::Pawn, zone::FailureZone)
    m.d = distance(pawn, zone)
    m.in_zone = in(pawn, zone)
end

## Simulator

In [None]:
# AdaStress gray-box simulator: a pawn in a continuous 2-D world, moved each step
# by sampled disturbances, with a circular failure zone.
@with_kw mutable struct Simulator <: AdaStress.GrayBox
    t::Float64 = 0.0                                # current time
    t_max::Float64 = 50.0                           # episode ends once t >= t_max
    pawn::Pawn = Pawn()
    init::Initialization = Initialization()         # starting-position distributions
    disturbance::Disturbance = Disturbance()        # per-step displacement distributions
    zone::FailureZone = FailureZone(7, 3, 1)        # center (7, 3), radius 1
    metrics::Metrics = Metrics()
    env::AdaStress.Environment = AdaStress.Environment()
    log::Log = Log()
    logging::Bool = false                           # record history into `log` when true
    rand_time::Bool = true                          # start episodes at a random time
end

# Reset the simulator for a new episode.
# With `rand_time` set, the clock starts at a uniformly random time in
# [0, t_max); otherwise it starts at 0. Random draws happen in this order:
# start time, then pawn position.
function initialize(sim::Simulator)
    if sim.rand_time
        sim.t = rand() * sim.t_max
    else
        sim.t = 0.0
    end
    initialize(sim.pawn, sim.init)
    initialize(sim.metrics, sim.pawn, sim.zone)
    # Expose the disturbance distributions to AdaStress under the keys that
    # `update(::Simulator, ...)` reads back.
    sim.env[:Δx] = sim.disturbance.x
    sim.env[:Δy] = sim.disturbance.y
    initialize(sim.log, sim)
    return
end

# Advance the simulation by one time unit: move the pawn by the sampled
# disturbances, then refresh the metrics and (if enabled) the log.
function update(sim::Simulator, value::AdaStress.EnvironmentValue)
    sim.t += 1.0
    Δx = value[:Δx]
    Δy = value[:Δy]
    update(sim.pawn, Δx, Δy)
    update(sim.metrics, sim.pawn, sim.zone)
    update(sim.log, sim)
    return
end

## Logging

In [None]:
# Start fresh histories holding the current state as their first entry.
# Does nothing (returns `nothing`) unless `sim.logging` is enabled.
function initialize(log::Log, sim::Simulator)
    if sim.logging
        log["t"] = [sim.t]
        log["pawn"] = [vec(sim.pawn)]
        log["d"] = [sim.metrics.d]
        log["in_zone"] = [sim.metrics.in_zone]
    end
end

# Append the current state to each history.
# Does nothing (returns `nothing`) unless `sim.logging` is enabled.
function update(log::Log, sim::Simulator)
    if sim.logging
        push!(log["t"], sim.t)
        push!(log["pawn"], vec(sim.pawn))
        push!(log["d"], sim.metrics.d)
        push!(log["in_zone"], sim.metrics.in_zone)
    end
end

# Write `log` to the BSON file "<filename>.bson"; the ".bson" extension is
# always appended.
function save(log::Log, filename::String)
    path = filename * ".bson"
    BSON.@save path log
end

## Interface setup

In [None]:
# AdaStress gray-box interface methods for `Simulator`.

# Begin a new episode.
Interface.reset!(sim::Simulator) = initialize(sim)

# Environment holding the disturbance distributions (filled in `initialize`).
Interface.environment(sim::Simulator) = sim.env

# Observation vector: normalized pawn position followed by normalized time.
function Interface.observe(sim::Simulator)
    time_frac = sim.t / sim.t_max
    return [observation(sim.pawn); time_frac]
end

# Apply one step of sampled disturbance values.
Interface.step!(sim::Simulator, x::AdaStress.EnvironmentValue) = update(sim, x)

# The episode ends once the clock reaches the horizon.
Interface.isterminal(sim::Simulator) = sim.t >= sim.t_max

# Failure event: the pawn is inside the failure zone.
Interface.isevent(sim::Simulator) = sim.metrics.in_zone

# Distance from the pawn to the failure zone (0 when inside it).
Interface.distance(sim::Simulator) = sim.metrics.d

## Solver

In [None]:
using AdaStress.SoftActorCritic

In [None]:
# Wrap a fresh `Simulator` in an AdaStress AST MDP. Keyword arguments are passed
# through to the `Simulator` constructor; the MDP uses `reward_bonus=100.0`.
function mdp_env(; kwargs...)
    sim = Simulator(; kwargs...)
    return Interface.ASTMDP(sim; reward_bonus=100.0)
end

In [None]:
# Train a Soft Actor-Critic adversary on the AST MDP (seeded for reproducibility).
Random.seed!(0)
sac = SAC(;
    env_fn = () -> mdp_env(),
    obs_dim = 3,                 # normalized (x, y) plus normalized time
    act_dim = 2,                 # one action per disturbance: Δx and Δy
    gamma = 1.0,
    act_mins = fill(-3.0, 2),
    act_maxs = fill(3.0, 2),
    hidden_sizes = fill(30, 3),
    num_q = 3,
    max_buffer_size = 1000000,
    batch_size = 1024,
    epochs = 11,                 # low value for testing only; set to >25 to see learning
    steps_per_epoch = 1000,
    start_steps = 10000,
    max_ep_len = 50,
    update_after = 10000,
    update_every = 1000,
    num_test_episodes = 100,
    displays = [(:fails, mdp -> mdp.sim.metrics.in_zone)],
)

# Clear the progress display between updates when running under IJulia.
SoftActorCritic.ProgressMeter.ijulia_behavior(:clear)
ac, info = SoftActorCritic.solve(sac);

## Analysis

In [None]:
using AdaStress.PolicyValueVerification

In [None]:
# Deterministic (mean) policy network, using the same ±3 action bounds as the SAC solver.
network = mean_network(ac; act_mins=fill(-3.0, 2), act_maxs=fill(3.0, 2))

In [None]:
# NOTE(review): presumably `:x1`/`:x2` mark the two free inputs and 0.9 fixes the
# third input (normalized time t / t_max) — confirm against `CrossSection` docs.
cs = CrossSection([:x1, :x2, 0.9])
# Presumably (lower, upper) corners of the region in normalized observation space.
limits = ([0.0, 0.0], [1.0, 1.0])
p = PolicyValueVerification.visualize(network, cs, limits)

In [None]:
# NOTE(review): presumably restricts `network` to the 2-D cross-section `cs`.
nnet = cross_section(network, cs, limits)
# `mean(limits)` is the elementwise midpoint of the two corner vectors; the trailing
# `[]` extracts the single output value.
midpoint_value = AdaStress.Analysis.PolicyValueVerification.compute_output(nnet, mean(limits))[]

In [None]:
# NOTE(review): presumably refines a partition of `limits` according to whether the
# network output is above/below `midpoint_value`, down to tolerance 0.01 — confirm.
r = BinaryRefinery(network=nnet, val=midpoint_value, tol=0.01)
root = get_root(limits)
@time refine!(root, r)  # `@time` reports the run time and allocations of the refinement

In [None]:
# Summary of the refined partition.
# NOTE(review): the meaning of the `true`/`false` argument to `coverage` is not
# visible here — confirm in PolicyValueVerification docs.
@show num_leaves(root)
@show coverage(root)
@show coverage(root, true)
@show coverage(root, false)

In [None]:
# Overlay the refined partition on a copy of the base plot so `p` stays reusable
# (`visualize!` presumably mutates its plot argument).
visualize!(deepcopy(p), root)

In [None]:
# Same overlay as the previous cell, with `fill=true` (presumably filled cells) — again on a copy of `p`.
visualize!(deepcopy(p), root; fill=true)

In [None]:
# NOTE(review): final cell just evaluates to `true`; presumably a sentinel so an
# automated notebook run can confirm it reached the end — confirm intent.
true