In [1]:
using LinearAlgebra

In [None]:
## Bootstrap Particle Filter (additive Gaussian noise)
"""
    ParticleFilter(f, h, W, V, particles = zeros(size(W, 1), 0))

Bootstrap particle filter state for a model with additive Gaussian noise.

# Fields
- `f`: dynamics, called as `f(x, u)` on a single particle `x` (one column of `particles`).
- `h`: measurement model, called as `h(x)` on a single particle.
- `W`: process-noise covariance of the additive Gaussian model.
- `V`: measurement-noise covariance of the additive Gaussian model.
- `particles`: `n × L` matrix; column `i` is particle `i`, `n` is the state dimension.
- `likelihoods`: normalized weight of each particle (length `L`); initialized uniform.

# Throws
- `ArgumentError` if `W` or `V` is not positive definite. `isposdef` attempts a Cholesky
  factorization, so singular (merely positive semi-definite) matrices are rejected too.
- `DimensionMismatch` if `particles` does not have `size(W, 1)` rows.
"""
mutable struct ParticleFilter
    f::Function
    h::Function
    W::Matrix{Float64}
    V::Matrix{Float64}
    particles::Matrix{Float64} # n x L matrix of particles, where n is the state dimension and L is the number of particles
    likelihoods::Vector{Float64}
    function ParticleFilter(f, h, W, V, particles=zeros(size(W, 1), 0))
        isposdef(W) || throw(ArgumentError("W must be a positive definite matrix."))
        isposdef(V) || throw(ArgumentError("V must be a positive definite matrix."))
        size(particles, 1) == size(W, 1) || throw(DimensionMismatch(
            "particles has $(size(particles, 1)) rows but W is $(size(W, 1))×$(size(W, 2))"))
        L = size(particles, 2)
        # All particles start equally weighted; fill(…, 0) yields an empty vector when L == 0.
        return new(f, h, W, V, particles, fill(1 / L, L))
    end
end

"""
    time_update(PF::ParticleFilter, u::Vector{Float64})

Bootstrap prediction step: push every particle through the dynamics `PF.f(x, u)` and add
an independent process-noise draw `w ~ N(0, PF.W)`.

Returns a new `n × L` matrix of predicted particles; `PF` itself is not modified.
"""
function time_update(PF::ParticleFilter, u::Vector{Float64})
    particles_plus = similar(PF.particles)
    n = size(particles_plus, 1)
    # W = S * S' with S lower-triangular, so S * randn(n) is a sample from N(0, W).
    # Without this noise the bootstrap proposal is deterministic and, after resampling,
    # duplicated particles never spread out again.
    S = cholesky(Symmetric(PF.W)).L
    Threads.@threads for i in axes(particles_plus, 2)
        particles_plus[:, i] = PF.f(PF.particles[:, i], u) + S * randn(n)
    end
    return particles_plus
end

"""
    measurement_update(PF::ParticleFilter, y::Vector{Float64})

Weight every particle by the Gaussian likelihood `exp(-½ eᵀ V⁻¹ e)` of the innovation
`e = y - PF.h(x)`, normalize the weights to sum to one, and store them in
`PF.likelihoods` (this mutates `PF`).

Returns the normalized weight vector.
"""
function measurement_update(PF::ParticleFilter, y::Vector{Float64})
    L = size(PF.particles, 2)
    if L == 0
        PF.likelihoods = Float64[]
        return PF.likelihoods
    end
    V_chol = cholesky(Symmetric(PF.V))
    log_w = Vector{Float64}(undef, L)
    Threads.@threads for i in 1:L
        err = y - PF.h(PF.particles[:, i])
        # Solve with the factorization instead of forming inv(V).
        log_w[i] = -0.5 * dot(err, V_chol \ err)
    end
    # Work in the log domain and shift by the maximum: for large innovations every
    # exp(-½ eᵀV⁻¹e) underflows to 0 and the normalization would produce NaN weights.
    # The shift cancels in the normalization, so the result is mathematically unchanged.
    w = exp.(log_w .- maximum(log_w))
    # We assume resampling is done every time step, so no need to multiply with old likelihoods
    PF.likelihoods = w ./ sum(w)
    return PF.likelihoods
end

"""
    resampler!(PF::ParticleFilter)

Multinomial resampling: draw `L` particles with replacement, each with probability equal to
its weight in `PF.likelihoods`. Replace `PF.particles` with the draw and reset
`PF.likelihoods` to uniform, since resampled particles are equally weighted.

Returns the resampled `n × L` particle matrix.

# Throws
- `DimensionMismatch` if the number of weights differs from the number of particles.
"""
function resampler!(PF::ParticleFilter)
    L = size(PF.particles, 2)
    length(PF.likelihoods) == L || throw(DimensionMismatch(
        "have $(length(PF.likelihoods)) weights for $L particles"))
    CDF = cumsum(PF.likelihoods)
    particles_resampled = similar(PF.particles)
    for i in 1:L
        # Scale the uniform draw by the CDF total so that rounding (CDF[end] slightly
        # below 1) can never leave a draw with no matching index; min() is a final guard.
        k = searchsortedfirst(CDF, rand() * CDF[end])
        particles_resampled[:, i] = PF.particles[:, min(k, L)]
    end
    PF.particles = particles_resampled
    PF.likelihoods = fill(1 / L, L)
    return PF.particles
end

"""
    propagate_PF!(PF::ParticleFilter, u::Vector{Float64}, y::Vector{Float64})

Run one full filter cycle in place: predict with input `u`, weight by measurement `y`,
then resample.

Returns the resampled particle matrix (also stored in `PF.particles`).
"""
function propagate_PF!(PF::ParticleFilter, u::Vector{Float64}, y::Vector{Float64})
    # time_update returns the predicted particles without storing them, so keep them here.
    PF.particles = time_update(PF, u)
    measurement_update(PF, y)
    return resampler!(PF)
end

propagate_PF (generic function with 1 method)

In [None]:
"""
    SSA(PF, K₀, N, M, running_cost, check_constraint_violation, α)

Settings for sample-average evaluation of candidate states taken from a particle filter.

# Fields
- `PF`: particle filter whose particles serve as candidates and as the sampling pool.
- `K₀`: control policy, called as `K₀(x)` on a state vector.
- `N`: prediction horizon length (must be ≥ 1).
- `M`: number of Monte Carlo samples per candidate (must be ≥ 1).
- `running_cost`: stage cost, called as `running_cost(x, u)`.
- `check_constraint_violation`: called as `check_constraint_violation(x)`; presumably
  returns `true`/`1` on violation, so its sample average is a violation rate — confirm.
- `α`: upper bound on the per-step averaged constraint violation for a candidate to be
  considered feasible.

# Throws
- `ArgumentError` if `N < 1` or `M < 1`.
"""
mutable struct SSA
    PF::ParticleFilter
    K₀::Function
    N::Int # prediction horizon length
    M::Int # number of monte carlo samples
    running_cost::Function
    check_constraint_violation::Function
    α::Float64 # constraint violation threshold
    function SSA(PF, K₀, N, M, running_cost, check_constraint_violation, α)
        # N < 1 leaves no horizon to evaluate; M < 1 would make every average 0/0 = NaN.
        N >= 1 || throw(ArgumentError("N (prediction horizon) must be at least 1, got $N"))
        M >= 1 || throw(ArgumentError("M (Monte Carlo samples) must be at least 1, got $M"))
        return new(PF, K₀, N, M, running_cost, check_constraint_violation, α)
    end
end

"""
    SSA_sample_averages(SSA::SSA)

For every particle `i` of `SSA.PF`, treat it as a candidate initial state.
- Roll out a nominal trajectory `x′` under the policy `SSA.K₀`.
- Apply the same control sequence to `SSA.M` particles drawn with replacement from the
  filter.
- Average the running cost and the constraint-violation indicator over those samples at
  each of the `SSA.N` steps.

# Returns
- `x_prime_0`: `n × L` matrix of candidate initial states (a copy of the particles).
- `α_t_achieved`: `L × N`; entry `(i, t)` is the averaged violation at step `t`.
- `cost_t_achieved`: `L × N`; entry `(i, t)` is the averaged cost at step `t`.

Cost and violation at step `t` are evaluated on the sample states after applying `u_t`.
"""
function SSA_sample_averages(SSA::SSA)
    n, L = size(SSA.PF.particles)
    N, M = SSA.N, SSA.M
    X_prime = Array{Float64}(undef, (n, N, L))
    α_t_achieved = Array{Float64}(undef, (L, N))
    cost_t_achieved = Array{Float64}(undef, (L, N))
    Threads.@threads for i = 1:L
        X_prime[:, 1, i] = SSA.PF.particles[:, i]
        x_dprime_per_sample = SSA.PF.particles[:, rand(1:L, M)]
        # Run all N steps so every column of the result arrays is written; the nominal
        # state is only advanced while there is room for it in X_prime.
        for t = 1:N
            u = SSA.K₀(X_prime[:, t, i])
            if t < N
                X_prime[:, t+1, i] = SSA.PF.f(X_prime[:, t, i], u)
            end
            cost_t = 0.0
            α_t = 0.0
            for j = 1:M
                # NOTE(review): samples are propagated without process noise from PF.W;
                # confirm whether noise draws are intended here.
                # f is applied to the whole state vector (not broadcast elementwise).
                x_dprime_per_sample[:, j] = SSA.PF.f(x_dprime_per_sample[:, j], u)
                cost_t += SSA.running_cost(x_dprime_per_sample[:, j], u)
                α_t += SSA.check_constraint_violation(x_dprime_per_sample[:, j])
            end
            α_t_achieved[i, t] = α_t / M
            cost_t_achieved[i, t] = cost_t / M
        end
    end
    return X_prime[:, 1, :], α_t_achieved, cost_t_achieved
end

"""
    SSA_select(SSA::SSA, x_prime_0, α_t_achieved, cost_t_achieved)

Choose one candidate state (a column of `x_prime_0`):
- If any candidate is feasible, return the feasible candidate with the lowest total
  averaged cost. A candidate is feasible when `α_t_achieved[i, t] ≤ SSA.α` at every step.
- If no candidate is feasible, emit a warning and return the candidate with the smallest
  total averaged violation.

# Throws
- `ArgumentError` if there are no candidates (the filter holds zero particles).
"""
function SSA_select(SSA::SSA, x_prime_0, α_t_achieved, cost_t_achieved)
    L = size(SSA.PF.particles, 2)
    L > 0 || throw(ArgumentError("SSA_select needs at least one candidate particle."))
    # Feasible: the violation rate stays within α at every step of the horizon.
    is_feasible = [all(≤(SSA.α), view(α_t_achieved, i, :)) for i in 1:L]
    total_cost = [sum(view(cost_t_achieved, i, :)) for i in 1:L]
    feasible_idx = findall(is_feasible)
    if isempty(feasible_idx)
        @warn "No feasible state found!"
        # Fall back to the least-violating candidate; vec() gives a plain index.
        total_violation = vec(sum(α_t_achieved; dims=2))
        return x_prime_0[:, argmin(total_violation)]
    end
    best = feasible_idx[argmin(total_cost[feasible_idx])]
    return x_prime_0[:, best]
end

SSA_select (generic function with 1 method)