In [1]:
using Printf
using Random
using LinearAlgebra
using Statistics
using Interact
using Plots
using Krylov
rng = MersenneTwister(18);

In [2]:
#include("/Data/Documents/PhD/2022/Fall_22/Numerical_Lineal_Algebra/homeworks/hw5/load_plot_pkg.jl")
#output = false

In [3]:
"""
    generate_data(N, M, K, sig_x_2; rng=Random.default_rng())

Draw a random sparse-recovery problem.

# Arguments
- `N`, `M`: number of rows and columns of the dictionary.
- `K`: number of entries of the support set to 1.
- `sig_x_2`: variance of the Gaussian coefficients on the support.

# Keywords
- `rng`: random number generator used for every draw. The default is the global
  generator, so calls without the keyword behave exactly as before; pass a
  seeded generator (e.g. `MersenneTwister(18)`) for reproducible experiments.

# Returns
`(dicti, coeff, supp)` where `dicti` is an `N × M` Gaussian matrix whose columns
are scaled to unit Euclidean norm, `supp` is a shuffled 0/1 vector of length `M`
with `K` ones, and `coeff` is zero wherever `supp` is zero.
"""
function generate_data(N, M, K, sig_x_2; rng=Random.default_rng())

    dicti = randn(rng, N, M) * sqrt(1/N)
    # Normalize every column (atom) to unit norm.
    dicti = mapslices(col -> col / norm(col), dicti, dims=1)

    # Mark the first K entries, then shuffle to place them at random.
    supp = zeros(Int64, M)
    supp[1:K] .= 1
    supp = shuffle(rng, supp)

    # Gaussian amplitudes masked by the support.
    coeff = randn(rng, M) * sqrt(sig_x_2) .* supp

    return dicti, coeff, supp
end


generate_data (generic function with 1 method)

## Standard

In [4]:
"""
    run_mp(signal, dicti, tol, max_steps=100000)

Matching pursuit: at each step pick the column of `dicti` with the largest
squared inner product with the residual, add that inner product to the
column's coefficient, and recompute the residual from the current estimate.

Stops as soon as the residual norm drops below `tol`, or after `max_steps`
iterations.

# Returns
`(coeffs, supp, errors)`: the coefficient vector, a 0/1 vector marking every
column that was selected at least once, and the residual norm after each
iteration.
"""
function run_mp(signal, dicti, tol, max_steps=100000)
    n_atoms = size(dicti, 2)
    supp = zeros(Int64, n_atoms)
    coeffs = zeros(n_atoms)
    res = copy(signal)
    errors = zeros(0)

    for _ in 1:max_steps
        correlations = transpose(dicti) * res
        best = argmax(correlations .^ 2)

        supp[best] = 1
        coeffs[best] += correlations[best]

        # Residual is rebuilt from scratch using only the selected columns.
        active = supp .== 1
        res = signal - dicti[:, active] * coeffs[active]

        push!(errors, norm(res))
        errors[end] < tol && break
    end

    return coeffs, supp, errors
end


run_mp (generic function with 2 methods)

In [5]:
"""
    run_omp(signal, dicti, tol, max_steps=1000000)

Orthogonal matching pursuit: at each step add the column of `dicti` most
correlated (in absolute value) with the residual to the support, then refit
all coefficients on the support by least squares against `signal`.

Stops when the residual norm drops below `tol`, after `max_steps` iterations,
or when the selected column is already in the support. In that last case the
support, the least-squares solution and the residual would all be unchanged,
so further iterations could never make progress.

# Returns
`(coeffs, supp, errors)`: the coefficient vector, the 0/1 support vector, and
the residual norm after each completed iteration.
"""
function run_omp(signal, dicti, tol, max_steps=1000000)

    n_atoms = size(dicti, 2)
    supp = zeros(Int, n_atoms)
    coeffs = zeros(n_atoms)
    res = copy(signal)
    errors = zeros(0)

    for _ in 1:max_steps

        max_index = argmax(abs.(dicti' * res))

        # Stagnation guard: re-selecting an active column is a fixed point of
        # the iteration, so stop instead of repeating the same solve.
        supp[max_index] == 1 && break

        supp[max_index] = 1
        active = supp .== 1

        # Least-squares refit on the whole support, then fresh residual.
        coeffs[active] = dicti[:, active] \ signal
        res = signal - dicti[:, active] * coeffs[active]

        push!(errors, norm(res))

        if errors[end] < tol
            break
        end
    end

    return coeffs, supp, errors
end

run_omp (generic function with 2 methods)

## Bayesian

In [6]:
"""
    log_joint_p(res, coeffs, supp, sig_w_2, sig_x_2, pj)

Return the log of the joint density used as objective by the Bayesian pursuit
routines, as the sum of three terms:

- Gaussian log-likelihood of the residual `res` with variance `sig_w_2`;
- Gaussian log-prior of `coeffs` with variance `sig_x_2`;
- Bernoulli log-prior of the 0/1 vector `supp` with activation probability
  `pj`, i.e. `s * log(pj) + (1 - s) * log(1 - pj)` summed over entries.
"""
function log_joint_p(res, coeffs, supp, sig_w_2, sig_x_2, pj)

    log_p_y_xs = (-0.5 / sig_w_2) * norm(res)^2 - 0.5 * size(res, 1) * log(2*pi * sig_w_2)
    log_p_x = -0.5 / sig_x_2 * norm(coeffs)^2 - 0.5 * size(coeffs, 1) * log(2*pi * sig_x_2)
    # Inactive entries contribute +log(1 - pj); the factor must be (1 - s),
    # not (s - 1), otherwise the sign of that contribution is flipped.
    log_p_s = sum(supp .* log(pj) .+ (1 .- supp) .* log(1 - pj))

    return log_p_y_xs + log_p_x + log_p_s
end

log_joint_p (generic function with 1 method)

In [7]:
"""
    run_bmp(signal, dicti, sig_w_2, sig_x_2, pj)

Bayesian matching pursuit. Every iteration proposes, for each column, a new
activation flag (`s_tilde`) and a shrunk coefficient (`x_tilde`), scores each
single-column change, applies the best one, and recomputes the residual.

Iteration stops when the value of `log_joint_p` no longer strictly increases
(the state after that last update is the one returned), or after 100000
iterations.

# Arguments
- `sig_w_2`: noise variance used in the threshold, shrinkage and objective.
- `sig_x_2`: coefficient prior variance.
- `pj`: prior activation probability of each column.

# Returns
`(coeffs, supp, errors)`. Note that `errors` holds the `log_joint_p` value
after each iteration, not residual norms.
"""
function run_bmp(signal, dicti, sig_w_2, sig_x_2, pj)
    n_atoms = size(dicti, 2)
    supp = zeros(Int64, n_atoms)
    coeffs = zeros(n_atoms)
    res = copy(signal)
    errors = zeros(0)

    lambda = sig_w_2 * log((1 - pj) / pj)
    epsilon = sig_w_2 / sig_x_2
    thresh = 2 * (sig_x_2 + sig_w_2) * lambda / sig_x_2
    old_logpos = -Inf

    for _ in 1:100000
        dot_dicti_res = transpose(dicti) * res

        # Proposed activation and shrunk coefficient for every column.
        s_tilde = Int64.((supp .* coeffs .+ dot_dicti_res) .^ 2 .> thresh)
        x_tilde = s_tilde .* (sig_x_2 / (sig_x_2 + sig_w_2)) .* (coeffs .* supp .+ dot_dicti_res)

        # Column j of `trial_res` is the residual if only column j were updated.
        trial_res = res .+ dicti .* (coeffs .- x_tilde)'
        f1 = -0.5 .* map(norm, eachcol(trial_res)) .^ 2
        f2 = -0.5 .* epsilon .* x_tilde .^ 2 .- lambda .* s_tilde

        i_ast = argmax(f1 + f2)
        supp[i_ast] = s_tilde[i_ast]
        coeffs[i_ast] = x_tilde[i_ast]

        active = supp .== 1
        res = signal - dicti[:, active] * coeffs[active]
        curr_logpos = log_joint_p(res, coeffs, supp, sig_w_2, sig_x_2, pj)
        push!(errors, curr_logpos)

        if curr_logpos <= old_logpos
            return coeffs, supp, errors
        end
        old_logpos = curr_logpos
    end

    return coeffs, supp, errors
end

run_bmp (generic function with 1 method)

In [8]:
"""
    run_bomp(signal, dicti, sig_w_2, sig_x_2, pj)

Bayesian orthogonal matching pursuit. Each iteration scores a single-column
change for every column (same proposal and score as `run_bmp`), applies the
best activation change, and then refits all active coefficients with
`cgls(...; λ = sig_w_2 / sig_x_2)`.

Coefficients of inactive columns are reset to zero after every support
change. Without this, a column removed from the support keeps its stale
coefficient, which then enters the next proposal score, `log_joint_p`, and
the returned `coeffs`.

Iteration stops when the value of `log_joint_p` no longer strictly increases,
or after 100000 iterations.

# Returns
`(coeffs, supp, errors)`. `errors` holds the `log_joint_p` value after each
iteration, not residual norms.
"""
function run_bomp(signal, dicti, sig_w_2, sig_x_2, pj)
    supp = zeros(Int64, size(dicti, 2))
    coeffs = zeros(size(dicti, 2))
    res = copy(signal)
    errors = zeros(0)

    lambda = sig_w_2 * log((1 - pj) / pj)
    epsilon = sig_w_2 / sig_x_2
    thresh = 2 * (sig_x_2 + sig_w_2) * lambda / sig_x_2
    old_logpos = -Inf

    for i = 1:100000

        dot_dicti_res = transpose(dicti) * res

        # Proposed activation and shrunk coefficient for every column.
        s_tilde = Int64.(((supp .* coeffs .+ dot_dicti_res).^2) .>  thresh)
        x_tilde = s_tilde .* (sig_x_2 / (sig_x_2 + sig_w_2)) .* (coeffs .* supp .+ dot_dicti_res)

        # Score of changing one column at a time.
        f1 = -0.5 .* map(norm, eachcol(res .+ dicti .* (coeffs .- x_tilde)')).^2
        f2 = -0.5 .* epsilon .* x_tilde.^2 .- lambda .* s_tilde

        i_ast = argmax(f1 + f2)

        supp[i_ast] = s_tilde[i_ast]
        active = supp .== 1

        # Regularized least-squares refit on the current support.
        coeffs[active] = cgls(dicti[:, active], signal, λ=epsilon)[1]
        # Clear coefficients of columns that are not (or no longer) active.
        coeffs[.!active] .= 0

        res = signal - dicti[:, active] * coeffs[active]
        curr_logpos = log_joint_p(res, coeffs, supp, sig_w_2, sig_x_2, pj)

        append!(errors, curr_logpos)

        if (curr_logpos <= old_logpos)
            return coeffs, supp, errors
        else
            old_logpos = curr_logpos
        end
    end

    return coeffs, supp, errors
end

run_bomp (generic function with 1 method)

In [9]:
"""
    run_bstomp(signal, dicti, sig_w_2, sig_x_2, pj)

Bayesian stagewise orthogonal matching pursuit. Each iteration re-decides the
whole support at once by thresholding `(supp .* coeffs .+ dicti' * res).^2`,
where the threshold is recomputed from the current residual energy per sample
(`norm(res)^2 / length of signal`). Active coefficients are then refit with
`cgls(...; λ = sig_w_2 / sig_x_2)` and inactive ones are set to zero.

Iteration stops when the value of `log_joint_p` no longer strictly increases,
or after 100000 iterations.

# Returns
`(coeffs, supp, errors)`. `errors` holds the `log_joint_p` value after each
iteration, not residual norms.
"""
function run_bstomp(signal, dicti, sig_w_2, sig_x_2, pj)
    n_atoms = size(dicti, 2)
    supp = zeros(Int64, n_atoms)
    coeffs = zeros(n_atoms)
    res = copy(signal)
    errors = zeros(0)

    log_odds = log((1 - pj) / pj)
    ridge = sig_w_2 / sig_x_2
    old_logpos = -Inf

    for _ in 1:100000
        dot_dicti_res = transpose(dicti) * res

        # Residual energy per sample stands in for the noise variance here.
        norm_res_2_N = norm(res)^2 / size(signal, 1)
        thresh = 2 * norm_res_2_N * (sig_x_2 + norm_res_2_N) / sig_x_2 * log_odds

        # Whole support is replaced in one shot.
        supp = Int64.((supp .* coeffs .+ dot_dicti_res) .^ 2 .> thresh)
        active = supp .== 1
        coeffs[active] = cgls(dicti[:, active], signal, λ=ridge)[1]
        coeffs[supp .== 0] .= 0

        res = signal - dicti[:, active] * coeffs[active]
        curr_logpos = log_joint_p(res, coeffs, supp, sig_w_2, sig_x_2, pj)
        push!(errors, curr_logpos)

        # NOTE(review): the state returned here is the one *after* the
        # non-improving update — confirm that is intended.
        if curr_logpos <= old_logpos
            return coeffs, supp, errors
        end
        old_logpos = curr_logpos
    end

    return coeffs, supp, errors
end

run_bstomp (generic function with 1 method)

In [10]:
# Problem size and variances shared by the experiments below.
N = 128          # rows of the dictionary / length of the signal
M = 256          # columns of the dictionary
K = 15           # number of nonzero coefficients
sig_w_2 = 1e-5   # noise variance
sig_x_2 = 10     # variance of the nonzero coefficients

# Settings looked up by key in run_trial and run_time_test.
config = Dict(
    "N" => N,
    "M" => M,
    "sig_w_2" => sig_w_2,
    "sig_x_2" => sig_x_2,
);

In [11]:
"""
    run_trial(config, k)

Generate one random problem with `k` active coefficients, add Gaussian noise of
variance `config["sig_w_2"]`, and run all five solvers on the same signal.

MP and OMP use the tolerance `sqrt(N * sig_w_2)`. The Bayesian solvers receive
the noise variance, the prior variance `config["sig_x_2"] * 100`, and the
activation probability `k / M`.

# Returns
A `Dict` with key `"Ref"` mapping to `(true_coeffs, true_supp)` and keys
`"MP"`, `"OMP"`, `"BMP"`, `"BOMP"`, `"BSTOMP"` each mapping to that solver's
`(coeffs, supp, errors)` tuple.
"""
function run_trial(config, k)
    n_rows = config["N"]
    noise_var = config["sig_w_2"]

    dicti, true_coeffs, true_supp = generate_data(n_rows, config["M"], k, config["sig_x_2"])
    signal = dicti * true_coeffs + randn(n_rows) .* sqrt(noise_var)

    # Greedy solvers stop once the residual norm drops below this tolerance.
    tol = sqrt(n_rows * noise_var)
    mp = run_mp(signal, dicti, tol)
    omp = run_omp(signal, dicti, tol)

    # NOTE(review): the prior variance handed to the Bayesian solvers is 100x
    # the variance used to generate the data — confirm this is deliberate.
    prior_var = config["sig_x_2"]*100
    prior_p = k/config["M"]
    bmp = run_bmp(signal, dicti, noise_var, prior_var, prior_p)
    bomp = run_bomp(signal, dicti, noise_var, prior_var, prior_p)
    bstomp = run_bstomp(signal, dicti, noise_var, prior_var, prior_p)

    output = Dict([
        ("Ref", (true_coeffs, true_supp)),
        ("MP", mp),
        ("OMP", omp),
        ("BMP", bmp),
        ("BOMP", bomp),
        ("BSTOMP", bstomp)
    ])

    return output
end

run_trial (generic function with 1 method)

In [24]:
"""
    compute_stats(res_trial)

Summarize each solver's result in `res_trial` against the `"Ref"` entry.

For every key other than `"Ref"` the returned `Dict` holds the tuple
`(max_error, mean_error, supp_eq, n_iter)`:

- `max_error`: infinity norm of the coefficient difference;
- `mean_error`: Euclidean norm of the difference divided by the norm of the
  reference coefficients;
- `supp_eq`: `1` if the support equals the reference support exactly, else `0`;
- `n_iter`: length of the solver's per-iteration history.
"""
function compute_stats(res_trial)
    true_coeffs, true_supp = res_trial["Ref"]
    stats = Dict()

    for (name, result) in res_trial
        name == "Ref" && continue

        est_coeffs, est_supp, history = result
        deviation = est_coeffs - true_coeffs

        stats[name] = (
            norm(deviation, Inf),
            norm(deviation) / norm(true_coeffs),
            Int(est_supp == true_supp),
            size(history, 1),
        )
    end

    return stats
end


compute_stats (generic function with 1 method)

In [136]:
# Single illustrative trial with 10 active coefficients; draws a 2x3 grid
# comparing each solver's recovered coefficients with the reference and saves
# it as Figure_Recon.pdf.
res_trial = run_trial(config, 10)

# Coefficient indices 1..M used as the x axis.
x = Array(range(1, M, length=M))

# Empty 2x3 canvas; every plot!/scatter! below draws into one subplot of it.
plot(
    layout=(2, 3), size=(1000, 400), left_margin = 5Plots.mm,
    bottom_margin = 5Plots.mm
)

# Subplot 1: reference coefficients, with markers on the true support.
plot!(x, res_trial["Ref"][1], label="", subplot=1, lc=:black, ylabel="Coef. Value")
scatter!(x[res_trial["Ref"][2].==1], res_trial["Ref"][1][res_trial["Ref"][2].==1], label="Ref", subplot=1)

# Subplots 2-6: reference as a line, solver estimate as markers on the
# solver's own support.
plot!(x, res_trial["Ref"][1], label="Ref", subplot=2, lc=:black)
scatter!(x[res_trial["MP"][2].==1], res_trial["MP"][1][res_trial["MP"][2].==1], label="MP", subplot=2, mc=:blue)

plot!(x, res_trial["Ref"][1], label="Ref", subplot=3, lc=:black)
scatter!(x[res_trial["BMP"][2].==1], res_trial["BMP"][1][res_trial["BMP"][2].==1], label="BMP", subplot=3, mc=:blue)

plot!(x, res_trial["Ref"][1], label="Ref", subplot=4, lc=:black, ylabel="Coef. Value")
scatter!(x[res_trial["OMP"][2].==1], res_trial["OMP"][1][res_trial["OMP"][2].==1], label="OMP", subplot=4, mc=:red)

plot!(x, res_trial["Ref"][1], label="Ref", subplot=5, lc=:black)
scatter!(x[res_trial["BOMP"][2].==1], res_trial["BOMP"][1][res_trial["BOMP"][2].==1], label="BOMP", subplot=5, mc=:red, xlabel="Index")

plot!(x, res_trial["Ref"][1], label="Ref", subplot=6, lc=:black)
scatter!(x[res_trial["BSTOMP"][2].==1], res_trial["BSTOMP"][1][res_trial["BSTOMP"][2].==1], label="BSTOMP", subplot=6, mc=:green)


savefig("Figure_Recon.pdf")

"/Data/Documents/PhD/Spring23/Sparse_Analysis/project/Figure_Recon.pdf"

In [30]:
# Sparsity sweep: for each of 13 sparsity levels between 5 and 100, run
# n_trials independent trials and accumulate per-solver statistics.
# Note: this rebinds K (previously a scalar) to the vector of sparsity levels.
K = Int.(round.(LinRange(5, 100, 13)))
n_trials = 100

# Per solver: (relative error for every (level, trial) pair,
#              support-recovery rate per level,
#              mean number of iterations per level).
results = Dict([
    ("MP", (zeros(size(K)[1], n_trials), zeros(size(K)[1]), zeros(size(K)[1]))),
    ("OMP", (zeros(size(K)[1], n_trials), zeros(size(K)[1]), zeros(size(K)[1]))),
    ("BMP", (zeros(size(K)[1], n_trials), zeros(size(K)[1]), zeros(size(K)[1]))),
    ("BOMP", (zeros(size(K)[1], n_trials), zeros(size(K)[1]), zeros(size(K)[1]))),
    ("BSTOMP", (zeros(size(K)[1], n_trials), zeros(size(K)[1]), zeros(size(K)[1])))
])


for i = 1:size(K)[1]

    k = K[i]
    # Progress indicator: index of the sparsity level being processed.
    println(i)
    for j = 1:n_trials

        res_trial = run_trial(config, k)
        stats_trial = compute_stats(res_trial)

        # stats tuple layout: (max_error, mean_error, supp_eq, n_iter).
        for (key, values) in stats_trial
            results[key][1][i, j] = stats_trial[key][2] # Relative l2 error of this trial
            results[key][2][i] += stats_trial[key][3] # Count of exact support recoveries
            results[key][3][i] += stats_trial[key][4] # Total number of iterations
        end
    end

    # Turn the accumulated sums into per-level averages.
    for (key, values) in results
        results[key][2][i] /= n_trials
        results[key][3][i] /= n_trials
    end
end

1


2


3


4


5


6


7


8


9


10


11


12


13


In [113]:
# Two-panel figure over the sparsity levels K, saved as Figure_Error.pdf:
#   subplot 1 - relative error averaged over trials (log scale),
#   subplot 2 - fraction of trials in which the support was not recovered exactly.
plot(
    layout=(1, 2), size=(800, 400), left_margin = 5Plots.mm,
    bottom_margin = 5Plots.mm, xlabel="Sparsity (K)"
)

# Subplot 1: mean over the trial dimension (dims=2) of the stored relative errors.
# Solid lines are the standard solvers, dashed lines their Bayesian variants.
p_errors = plot!(K, mean(results["MP"][1], dims=2), lc=:red, label="MP",
    marker=:^, mc=:red, ms=:5, yscale=:log10, legend=:false, subplot=1,
    ylabel="Mean relative error"
)
p_errors = plot!(K, mean(results["BMP"][1], dims=2), 
    lc=:red, label="BMP", marker=:^, mc=:red, ms=:5, ls=:dash, yscale=:log10, subplot=1
)
p_errors = plot!(K, mean(results["OMP"][1], dims=2), 
    lc=:blue, label="OMP", marker=:cross, mc=:blue, ms=:5, yscale=:log10, subplot=1
)
p_errors = plot!(K, mean(results["BOMP"][1], dims=2), 
    lc=:blue, label="BOMP", marker=:cross, mc=:blue, ms=:5, ls=:dash, yscale=:log10, subplot=1 
)
p_errors = plot!(K, mean(results["BSTOMP"][1], dims=2), 
    lc=:green, label="BSTOMP", marker=:circle, mc=:green, ms=:5, ls=:dash, yscale=:log10, subplot=1
)
# Horizontal reference line at a relative error of 1.
p_errors = hline!([1], yscale=:log10, lc=:black, label="", subplot=1)

# Subplot 2: one minus the exact-support-recovery rate.
p_supp = plot!(
    K, 1.0 .- results["MP"][2], lc=:red, mc=:red, marker=:^, legend=:outerright,
    ms=:5, label="MP", subplot=2, ylabel="Freq. of error in support"
)
p_supp = plot!(K, 1.0 .- results["BMP"][2], lc=:red, ls=:dash, marker=:^, mc=:red, ms=:5, label="BMP", subplot=2)
p_supp = plot!(K, 1.0 .- results["OMP"][2], lc=:blue, mc=:blue, marker=:cross, label="OMP", subplot=2)
p_supp = plot!(K, 1.0 .- results["BOMP"][2], lc=:blue, ls=:dash, mc=:blue, marker=:cross, label="BOMP", subplot=2)
p_supp = plot!(K, 1.0 .- results["BSTOMP"][2], lc=:green, ls=:dash, mc=:green, marker=:circle, label="BSTOMP", subplot=2)


savefig("Figure_Error.pdf")

"/Data/Documents/PhD/Spring23/Sparse_Analysis/project/Figure_Error.pdf"

In [17]:
"""
    run_time_test(config, K, n_trials)

Measure the average wall-clock time of each solver for every sparsity level in
`K`. For each level, `n_trials` fresh problems are generated and all five
solvers are timed with `@elapsed` on the same signal.

The Bayesian solvers are called with `config["sig_x_2"]` as prior variance and
`k / config["M"]` as activation probability.

# Returns
A `Dict` mapping `"MP"`, `"OMP"`, `"BMP"`, `"BOMP"`, `"BSTOMP"` to a vector of
mean elapsed seconds, one entry per element of `K`.
"""
function run_time_test(config, K, n_trials)
    n_levels = size(K, 1)
    solver_names = ("MP", "OMP", "BMP", "BOMP", "BSTOMP")
    results = Dict(name => zeros(n_levels) for name in solver_names)

    for i in 1:n_levels
        k = K[i]

        for _ in 1:n_trials
            dicti, true_coeffs, true_supp = generate_data(config["N"], config["M"], k, config["sig_x_2"])
            signal = dicti * true_coeffs + randn(config["N"]) .* sqrt(config["sig_w_2"])

            tol = sqrt(config["N"] * config["sig_w_2"])
            prior_p = k/config["M"]

            # Only the solver call itself is inside each timed region.
            results["MP"][i] += @elapsed(run_mp(signal, dicti, tol))
            results["OMP"][i] += @elapsed(run_omp(signal, dicti, tol))
            results["BMP"][i] += @elapsed(run_bmp(signal, dicti, config["sig_w_2"], config["sig_x_2"], prior_p))
            results["BOMP"][i] += @elapsed(run_bomp(signal, dicti, config["sig_w_2"], config["sig_x_2"], prior_p))
            results["BSTOMP"][i] += @elapsed(run_bstomp(signal, dicti, config["sig_w_2"], config["sig_x_2"], prior_p))
        end

        # Convert the accumulated totals for this level into means.
        for timings in values(results)
            timings[i] /= n_trials
        end
    end

    return results
end

run_time_test (generic function with 1 method)

In [18]:
# Average solver run times over 100 trials per sparsity level. The timing plot
# further down reads `time_results`, so this assignment has to be executed.
time_results = run_time_test(config, K, 100)

Dict{String, Vector{Float64}} with 5 entries:
  "BSTOMP" => [0.000186123, 0.000711811, 0.000773277, 0.00129143, 0.0017939, 0.…
  "MP"     => [0.000220962, 0.000859539, 0.00233407, 0.0148123, 0.0405853, 0.05…
  "BOMP"   => [0.00128929, 0.00362858, 0.00430349, 0.00649816, 0.0086633, 0.011…
  "OMP"    => [0.000226975, 0.00091908, 0.00188237, 0.00307399, 0.00539146, 0.0…
  "BMP"    => [0.00134914, 0.00539959, 0.0111277, 0.0348527, 0.0756673, 0.10252…

In [117]:
# Two-panel figure over the sparsity levels K, saved as Figure_Time.pdf:
#   subplot 1 - mean computing time per solver (log scale),
#   subplot 2 - mean number of iterations per solver (log scale).
# Requires `time_results` (as returned by run_time_test) and `results` from the
# sparsity sweep to be defined.
plot(
    layout=(1, 2), size=(1000, 400), left_margin=5Plots.mm,
    bottom_margin=5Plots.mm, xlabel="Sparsity (K)", legend=false
)

# Computing time
plot!(K, time_results["MP"], lc=:red, mc=:red, marker=:^, ms=:5, label="MP", yscale=:log10, subplot=1, ylabel="Computing Time")
plot!(K, time_results["BMP"], lc=:red, ls=:dash, marker=:^, mc=:red, ms=:5, label="BMP", yscale=:log10, subplot=1)

plot!(K, time_results["OMP"], lc=:blue, mc=:blue, marker=:cross, label="OMP", yscale=:log10, subplot=1)
plot!(K, time_results["BOMP"], lc=:blue, ls=:dash, mc=:blue, marker=:cross, label="BOMP", yscale=:log10, subplot=1)
plot!(K, time_results["BSTOMP"], lc=:green, ls=:dash, mc=:green, marker=:circle, label="BSTOMP", yscale=:log10, subplot=1)

# Number of iterations (third element of each `results` entry, already
# averaged over trials)
plot!(K, results["MP"][3], lc=:red, label="MP", marker=:^, ylabel="Number of iterations",
    mc=:red, ms=:5, yscale=:log10, legend=:bottomright, subplot=2
)

plot!(K, results["BMP"][3], lc=:red, label="BMP",
    marker=:^, mc=:red, ms=:5, ls=:dash, subplot=2
)

plot!(K, results["OMP"][3], lc=:blue, label="OMP",
    marker=:cross, mc=:blue, ms=:5, subplot=2
)

plot!(K, results["BOMP"][3], lc=:blue, label="BOMP",
    marker=:cross, mc=:blue, ms=:5, ls=:dash, subplot=2
)

plot!(K, results["BSTOMP"][3], lc=:green, label="BSTOMP",
    marker=:circle, mc=:green, ms=:5, ls=:dash, subplot=2, legend=:outerright
)

savefig("Figure_Time.pdf")

"/Data/Documents/PhD/Spring23/Sparse_Analysis/project/Figure_Time.pdf"