In [1]:
using RATiLQR
using LinearAlgebra
using Random
using Distributions

In [2]:
"""
    RATCEMSolver(μ_init_array::Vector{Vector{Float64}},
    Σ_init_array::Vector{Matrix{Float64}}; kwargs...)

RAT CEM Solver initialized with `μ_init_array = [μ_0,...,μ_{N-1}]` and
`Σ_init_array = [Σ_0,...,Σ_{N-1}]`, where the initial control distribution at time
`k` is a Gaussian distribution `Distributions.MvNormal(μ_k, Σ_k)`.

# Optional Keyword Arguments
- `num_control_samples::Int64` -- number of Monte Carlo samples for the control
  trajectory. Default: `10`.
- `deterministic_dynamics::Bool` -- determines whether to use deterministic prediction
  for the dynamics. If `true`, `num_trajectory_samples` must be 1. Default: `false`.
- `num_trajectory_samples::Int64` -- number of Monte Carlo samples for the state
  trajectory. Default: `10`.
- `μ_θ_init::Float64` -- initial mean parameter `μ_θ` for the risk-sensitivity. 
  Default: `1.0`.
- `σ_θ_init::Float64` -- initial covariance parameter `σ_θ` for the risk-sensitivity. 
  Default: `2.0`.
- `num_risk_samples::Int64` -- number of Monte Carlo samples for the risk-sensitivity. 
  Default: `10`.
- `num_elite::Int64` -- number of elite samples. Default: `10`.
- `iter_max::Int64` -- maximum iteration number. Default: `5`.
- `smoothing_factor::Float64` -- smoothing factor in (0, 1), used to update the
  mean and the variance of the Cross Entropy distribution for the next iteration.
  If `smoothing_factor` is `0.0`, the updated distribution is independent of the
  previous iteration. If it is `1.0`, the updated distribution is the same as the
  previous iteration. Default: `0.1`.
- `mean_carry_over::Bool` -- save `μ_array` & `μ_θ` of the last iteration and use it to
  initialize `μ_array` & `μ_θ` in the next call to `solve!`. Default: `false`.
"""
mutable struct RATCEMSolver
    # CE solver parameters
    num_control_samples::Int64     # number of Monte Carlo samples for the control trajectory
    deterministic_dynamics::Bool   # when true, the constructor asserts num_trajectory_samples == 1
    num_trajectory_samples::Int64  # number of Monte Carlo samples for the state trajectory
    num_risk_samples::Int64        # number of Monte Carlo samples for the risk-sensitivity
    num_elite::Int64               # number of elite samples
    iter_max::Int64                # maximum iteration number
    smoothing_factor::Float64      # weight used when updating the CE distribution between iterations
    mean_carry_over::Bool          # keep μ_array & μ_θ of the last iteration for the next solve! call
    
    # action distributions
    # `*_init_array` hold the initial Gaussian parameters, one entry per time step;
    # `μ_array` / `Σ_array` are the working values, reset from the initial ones by `initialize!`.
    μ_init_array::Vector{Vector{Float64}}
    Σ_init_array::Vector{Matrix{Float64}}
    μ_array::Vector{Vector{Float64}}
    Σ_array::Vector{Matrix{Float64}}
    # risk_param distributions
    # `μ_θ_init` / `σ_θ_init` are the initial parameters; `μ_θ` / `σ_θ` are the working values.
    μ_θ_init::Float64
    σ_θ_init::Float64
    μ_θ::Float64
    σ_θ::Float64
    N::Int64 # control sequence length > 0 (must be the same as N in FiniteHorizonGenerativeProblem)
    iter_current::Int64 # set to 0 at construction and by `initialize!`; presumably the CE iteration counter -- confirm against the solve loop
end;

In [3]:
# Keyword constructor for `RATCEMSolver`.
#
# `μ_init_array[k]` and `Σ_init_array[k]` are the mean and covariance of the initial
# Gaussian control distribution at time step `k`; both arrays must have the same length,
# which becomes the control sequence length `N`. The keyword arguments are stored as-is
# in the fields of the same name. The working distribution (`μ_array`, `Σ_array`,
# `μ_θ`, `σ_θ`) starts as a copy of the initial one and `iter_current` starts at 0.
#
# Throws an `AssertionError` if the two arrays differ in length, or if
# `deterministic_dynamics` is `true` while `num_trajectory_samples != 1`.
function RATCEMSolver(μ_init_array::Vector{Vector{Float64}},
                      Σ_init_array::Vector{Matrix{Float64}};
                      num_control_samples=10,
                      deterministic_dynamics=false,
                      num_trajectory_samples=10,
                      μ_θ_init=1.0,
                      σ_θ_init=2.0,
                      num_risk_samples=10,
                      num_elite=10,
                      iter_max=5,
                      smoothing_factor=0.1,
                      mean_carry_over=false)

    # Exactly one covariance is required per mean (one pair per time step).
    @assert length(μ_init_array) == length(Σ_init_array)
    if deterministic_dynamics
        @assert num_trajectory_samples == 1 "num_trajectory_samples must to be 1"
    end

    # Copy element-wise: `copy` on a vector of arrays is shallow, which would leave
    # `μ_array[k]` and `μ_init_array[k]` pointing at the same storage, so an in-place
    # update of the working distribution would silently alter the initial one.
    # The typed comprehensions keep the element type fixed even for empty input.
    μ_array = Vector{Float64}[copy(μ) for μ in μ_init_array]
    Σ_array = Matrix{Float64}[copy(Σ) for Σ in Σ_init_array]
    μ_θ, σ_θ = μ_θ_init, σ_θ_init
    N = length(μ_init_array)
    iter_current = 0

    # Argument order must match the field order of the struct.
    return RATCEMSolver(num_control_samples, deterministic_dynamics,
                        num_trajectory_samples, num_risk_samples,
                        num_elite, iter_max, smoothing_factor,
                        mean_carry_over,
                        μ_init_array, Σ_init_array,
                        μ_array, Σ_array,
                        μ_θ_init, σ_θ_init, μ_θ, σ_θ,
                        N, iter_current)
end;

In [5]:
"""
    initialize!(rat_cem_solver::RATCEMSolver)

Reset `rat_cem_solver` to its initial state: set `iter_current` to `0` and restore the
working distribution (`μ_array`, `Σ_array`, `μ_θ`, `σ_θ`) from the stored initial values
(`μ_init_array`, `Σ_init_array`, `μ_θ_init`, `σ_θ_init`).

The per-time-step means and covariances are copied element by element, so the working
arrays never share storage with the initial ones.

The function has no explicit `return`; it evaluates to the value of its last assignment
(`σ_θ_init`), which callers should not rely on.
"""
function initialize!(rat_cem_solver::RATCEMSolver)
    rat_cem_solver.iter_current = 0
    # `copy` on a vector of arrays is shallow; copy each element so that later in-place
    # updates of the working distribution cannot alter the stored initial distribution.
    rat_cem_solver.μ_array = Vector{Float64}[copy(μ) for μ in rat_cem_solver.μ_init_array]
    rat_cem_solver.Σ_array = Matrix{Float64}[copy(Σ) for Σ in rat_cem_solver.Σ_init_array]
    rat_cem_solver.μ_θ = rat_cem_solver.μ_θ_init
    # Kept as the final statement so the implicit return value stays unchanged.
    rat_cem_solver.σ_θ = rat_cem_solver.σ_θ_init
end

initialize! (generic function with 1 method)