In [None]:
# load basic packages
# NOTE(review): names used later but not imported in this notebook (CSV, DataFrame,
# load_mtm_solutions, draw, draw!, trials, HSL, plot!) presumably come from this
# fixtures file — confirm.
include("./analysis_fixtures.jl")

import jcontrol.Run: run_mtm
import jcontrol: toDict
# Folder the cost-function sweep writes its CSV files to, and later reads them back from.
# NOTE(review): absolute, machine-specific path — adjust before running on another machine.
savepath = "/Users/federicoclaudi/Dropbox (UCL)/Rotation_vte/Locomotion/analysis/behavior/mtm_cost_function"
import InfiniteOpt: termination_status

## Run MTM
Run the MTM problem while systematically varying the cost-function weights: first γ (with α = 0), then α (with γ = 0).

In [None]:

# Run the MTM problem while systematically varying the cost-function weights to assess
# their effect on behavior: first sweep γ with α = 0, then sweep α with γ = 0.
# Solutions whose termination status is LOCALLY_SOLVED are written to `savepath` as CSV.

# Weight values tried in both sweeps.
cost_weights = (4e-6, 8e-6, 1e-5, 5e-5)

"""
    run_and_save_mtm(destination; α, γ)

Solve the MTM problem (`:dynamics` model, supports density `2.0`) with cost-function
weights `α` and `γ`. If the solver's termination status is `LOCALLY_SOLVED`, convert
the solution with `toDict` and write it to `destination` as CSV; otherwise log the
status as a warning and write nothing. Returns `nothing`.
"""
function run_and_save_mtm(destination; α, γ)
    _, _, control_model, solution = run_mtm(
        :dynamics,  # model type
        2.0;  # supports density
        showtrials=nothing,
        showplots=false,
        quiet=true,
        α=α,
        γ=γ,
    )

    # Compared via `string` so MathOptInterface's status enum need not be imported here.
    status = termination_status(control_model)
    if string(status) == "LOCALLY_SOLVED"
        CSV.write(destination, DataFrame(toDict(solution)))
    else
        @warn "MTM not solved locally; skipping $(destination)" status
    end
    return nothing
end

for γ in cost_weights
    @info "Running γ: $(γ)"
    destination = joinpath(savepath, "cost_fn_alpha_0_gamma_$(γ).csv")
    run_and_save_mtm(destination; α=0.0, γ=γ)
end

for α in cost_weights
    @info "Running α: $(α)"
    destination = joinpath(savepath, "cost_fn_alpha_$(α)_gamma_0.csv")
    run_and_save_mtm(destination; α=α, γ=0.0)
end

## Analysis

In [None]:
# Load the saved cost-function sweep solutions from `savepath`
# (presumably those whose file name starts with "cost_fn_alpha_" — TODO confirm),
# then print each solution's name next to the last value of its `t` column.
solutions, _names = load_mtm_solutions(; folder=savepath, name="cost_fn_alpha_");

for (sol, label) in zip(solutions, _names)
    final_t = last(sol.t)
    println(label, "  ", final_t)
end

In [None]:
"""
    make_palette(x)

Return a `range` of `length(x)` colors going from `HSL(217, .64, .55)` to
`HSL(320, .73, .78)`, i.e. one color per element of `x`.
"""
function make_palette(x)
    n_colors = length(x)
    start_color = HSL(217, 0.64, 0.55)
    stop_color = HSL(320, 0.73, 0.78)
    # `stop` stays a keyword: the colorant `range` method takes it that way.
    return range(start_color; stop=stop_color, length=n_colors)
end

In [None]:
# Draw the arena, overlay (up to) the first 100 entries of `trials` as faint traces,
# then plot each loaded MTM solution's (x, y) path in its own palette color,
# labelled with the corresponding name from `_names`.
plt = draw(:arena)

# `first` tolerates fewer than 100 trials, where `trials[1:100]` would throw.
draw!.(first(trials, 100); alpha=.1)

colors = make_palette(solutions)
for (solution, color, label) in zip(solutions, colors, _names)
    plot!(solution.x, solution.y, lw=5, alpha=.8, color=color, label=label)
end
plt