In [1]:
using LinearAlgebra, Revise, ControlSystemsBase, Plots

In [2]:
## Bootstrap Partilce Filter (additive Gaussian)
mutable struct ParticleFilter
    f::Function
    h::Function
    W::Matrix{Float64}
    V::Matrix{Float64}
    particles::Matrix{Float64} # n x L matrix of particles, where n is the state dimension and L is the number of particles
    likelihoods::Vector{Float64}
    function ParticleFilter(f, h, W, V, particles, likelihoods)
        if !isposdef(W)
            throw(ArgumentError("W must be a positive semi-definite matrix."))
        end
        if !isposdef(V)
            throw(ArgumentError("V must be a positive semi-definite matrix."))
        end
        new(f, h, W, V, particles, likelihoods)
    end
end

function time_update(PF::ParticleFilter, u::Vector{Float64})
    particles_plus = Matrix{Float64}(undef, size(PF.particles))
    Threads.@threads for i = axes(particles_plus,2)
        particles_plus[:,i] = PF.f(PF.particles[:,i], u)
    end
    return particles_plus
end

function measurement_update(PF::ParticleFilter, y::Vector{Float64})
    likelihoods = zeros(size(PF.particles,2))
    V_inv = inv(PF.V)
    Threads.@threads for i = axes(particle_filter.particles,2)
        err = y - PF.h(PF.particles[:,i])
        likelihoods[i] = exp.(-1/2 * err' * V_inv *err)
    end
    # We assume resampling is done every time step, so no need to multiply with old likelihoods
    PF.likelihoods = likelihoods ./ sum(likelihoods)
end

function resampler!(PF::ParticleFilter)
    particles_resampled = zeros(size(PF.particles))
    CDF = cumsum(PF.likelihoods)
    for i = axes(particles_resampled,2)
        particles_resampled[:,i] = PF.particles[:,findfirst(CDF .>= rand(1))]
    end
    PF.particles = particles_resampled
end

function propagate_PF!(PF::ParticleFilter, u::Vector{Float64}, y::Vector{Float64})
    time_update(PF, u)
    measurement_update(PF, y)
    resampler(PF)
end

propagate_PF! (generic function with 1 method)

In [None]:
mutable struct SSA
    PF::ParticleFilter
    K₀::Function
    N::Int # prediction horizon length
    M::Int # number of monte carlo samples
    running_cost::Function
    check_constraint_violation::Function
    α::Float64 # constraint violation threshold
end

function SSA_sample_averages(SSA::SSA)
    n = size(SSA.PF.particles, 1)
    L = size(SSA.PF.particles, 2)
    X_prime = Array{Float64}(undef, (n, SSA.N, L))
    α_t_achieved = Array{Float64}(undef, (L, SSA.N))
    cost_t_achieved = Array{Float64}(undef, (L, SSA.N))
    Threads.@threads for i = 1:L
        X_prime[:,1,i] = SSA.PF.particles[:,i]
        x_dprime_per_sample = SSA.PF.particles[:,rand(1:L, SSA.M)]
         for t = 1:SSA.N-1
            u = SSA.K₀(X_prime[:,t,i])
            X_prime[:,t+1,i] = SSA.PF.f(X_prime[:,t,i], u)
            cost_t = 0.0
            α_t = 0.0
            for j = 1:SSA.M
            x_dprime_per_sample[:,j] = SSA.PF.f.(x_dprime_per_sample[:,j], u)
            cost_t += SSA.running_cost(x_dprime_per_sample[:,j], u)
            α_t += SSA.check_constraint_violation(x_dprime_per_sample[:,j])
            end
            α_t_achieved[i,t] = α_t / SSA.M
            cost_t_achieved[i,t] = cost_t / SSA.M
        end
    end
    return X_prime[:,1,:], α_t_achieved, cost_t_achieved
end

function SSA_select(SSA::SSA, x_prime_0, α_t_achieved, cost_t_achieved)
    n = size(SSA.PF.particles, 1)
    L = size(SSA.PF.particles, 2)
    feasible_indices = falses(L)
    cost_achieved = zeros(L)
    for i = 1:L
        # Check feasibility
        if all(α_t_achieved[i,:] .<= SSA.α)
            feasible_indices[i] = true
        end
        # Check predicted cost
        cost_achieved = sum(cost_t_achieved[i,:])
    end
    feasible_costs = cost_achieved[feasible_indices]
    feasible_indices_set = findall(feasible_indices)
    if isempty(feasible_costs)
        println("No feasible state found!")
        α_achieved_sum = sum(α_t_achieved, dims=2)
        min_α, min_index = findmin(α_achieved_sum)
        return x_prime_0[:,min_index[0]]
    else
        min_cost, min_index = findmin(cost_achieved[feasible_indices])
        return x_prime_0[:, feasible_indices_set[min_index[0]]]
    end
end

SSA_select (generic function with 1 method)

In [4]:
# Define state-space example
A =  [0 0 1 0;
      0 0 0 1;
      0 0 0 0;
      0 0 0 0]
B =  [0.0 0.0;
      0.0 0.0;
      1.0 0.0;
      0.0 1.0]
C =  [1.0 0.0 0.0 0.0;
      0.0 1.0 0.0 0.0]

# Compute discrete-time A_d and B_d
#  Time step for discretization
Δt = 0.1
Ad = exp(Δt * A)  # Exponential of matrix A
Bd = Δt * Ad * B  # Euler approximation of integral equation for Bd

# Gain matrices for nominal feedback controller


# Noise matrix
W = Diagonal([0.15, 0.15, 0.15, 0.15])
V = Diagonal([0.15, 0.15])

f(x::Vector{Float64}, u::Vector{Float64}) = Ad * x + Bd * u + sqrt(W) * randn(size(W,1))

h(x::Vector{Float64}) = C * x + V * randn(size(V,1))

sys = ss(Ad,Bd,C,0,Δt)

Q = I
R = 2I

K_LQR = -lqr(sys, Q, R)

controller(x::Vector{Float64}) = K_LQR * x

controller (generic function with 1 method)

In [5]:
L = 2000
initial_particles = sqrt(I) * randn(size(W,1), L)
initial_likelihoods = ones(L) / L
pf = ParticleFilter(f, h, W, V, initial_particles, initial_likelihoods)

ParticleFilter(Main.f, Main.h, [0.15 0.0 0.0 0.0; 0.0 0.15 0.0 0.0; 0.0 0.0 0.15 0.0; 0.0 0.0 0.0 0.15], [0.15 0.0; 0.0 0.15], [-1.614732652059376 1.0140575437176969 … -1.3425763888998568 -0.3267585820804758; -0.6234456976265337 0.0027785980048414 … -0.7730950276893058 1.5087290241421083; -0.21996531004999972 -0.9711938438934449 … 0.7456351732472276 -1.1013532304714058; 2.24834914703995 -0.6874887129040155 … -1.6166810280567816 0.2459031233636174], [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005  …  0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005])

In [6]:
cost(x::Vector{Float64}, u::Vector{Float64}) = x' * Q * x + u' * R * u
constraint_violation(x::Vector{Float64}) = 0.0
N = 5
M = 100
α = 0.15
ssa = SSA(pf, controller, N, M, cost, constraint_violation, α)

SSA(ParticleFilter(Main.f, Main.h, [0.15 0.0 0.0 0.0; 0.0 0.15 0.0 0.0; 0.0 0.0 0.15 0.0; 0.0 0.0 0.0 0.15], [0.15 0.0; 0.0 0.15], [-1.614732652059376 1.0140575437176969 … -1.3425763888998568 -0.3267585820804758; -0.6234456976265337 0.0027785980048414 … -0.7730950276893058 1.5087290241421083; -0.21996531004999972 -0.9711938438934449 … 0.7456351732472276 -1.1013532304714058; 2.24834914703995 -0.6874887129040155 … -1.6166810280567816 0.2459031233636174], [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005  …  0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005]), Main.controller, 5, 100, Main.cost, Main.constraint_violation, 0.15)

In [None]:
T = 20
x_true = zeros(size(W,1))
for t = 1:T
    x_prime0, α_t_achieved, cost_t_achieved = SSA_sample_averages(ssa)
    x_prime_optimal = SSA_select(ssa, x_prime_0, α_t_achieved, cost_t_achieved)
    u = ssa.K₀(x_prime_optimal)
    y = h(x_true)
    propagate_PF!(ssa.PF, u, y)
    x_true = ssa.PF.f(x_true, u)
end

CompositeException: TaskFailedException

    nested task error: DimensionMismatch: arrays could not be broadcast to a common size: a has axes Base.OneTo(4) and b has axes Base.OneTo(2)
    Stacktrace:
      [1] _bcs1
        @ ./broadcast.jl:523 [inlined]
      [2] _bcs
        @ ./broadcast.jl:517 [inlined]
      [3] broadcast_shape
        @ ./broadcast.jl:511 [inlined]
      [4] combine_axes
        @ ./broadcast.jl:492 [inlined]
      [5] instantiate
        @ ./broadcast.jl:302 [inlined]
      [6] materialize(bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(f), Tuple{Vector{Float64}, Vector{Float64}}})
        @ Base.Broadcast ./broadcast.jl:867
      [7] macro expansion
        @ ~/Documents/GitHub/robust-less-dist-shifts/SSA-CIDA/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_W2sZmlsZQ==.jl:26 [inlined]
      [8] (::var"#101#threadsfor_fun#18"{var"#101#threadsfor_fun#17#19"{SSA, Matrix{Float64}, Matrix{Float64}, Array{Float64, 3}, Int64, UnitRange{Int64}}})(tid::Int64; onethread::Bool)
        @ Main ./threadingconstructs.jl:252
      [9] #101#threadsfor_fun
        @ ./threadingconstructs.jl:219 [inlined]
     [10] (::Base.Threads.var"#1#2"{var"#101#threadsfor_fun#18"{var"#101#threadsfor_fun#17#19"{SSA, Matrix{Float64}, Matrix{Float64}, Array{Float64, 3}, Int64, UnitRange{Int64}}}, Int64})()
        @ Base.Threads ./threadingconstructs.jl:154