In [13]:
import Pkg
Pkg.activate(@__DIR__)
Pkg.instantiate()
import MathOptInterface as MOI
import Ipopt 
import FiniteDiff
import ForwardDiff as FD
import Convex as cvx 
import ECOS
import MeshCat as mc
import Distributions
import Random

using LinearAlgebra
using Plots
using Random
using JLD2
using Test
using CSV
using DataFrames


[32m[1m  Activating[22m[39m environment at `c:\Users\AiPEX-WS4\Documents\AiPEX-Projects\warmstarting_NLPs\JL_presolves\Project.toml`


In [14]:
include(joinpath(@__DIR__, "utils","fmincon.jl"))
include(joinpath(@__DIR__, "utils","cartpole_animation.jl"))

animate_cartpole (generic function with 1 method)

In [61]:
# cartpole 
"""
    dynamics(params, x, u)

Continuous-time cartpole ODE: return `ẋ = [q̇; q̈]` for the state `x = [q; q̇]`
(`q = x[1:2]`, `q̇ = x[3:4]`) and control `u` (only `u[1]` is used).

The accelerations solve `H q̈ + C q̇ + G = B u[1]`, with `H`, `C`, `G`, `B` built below.
`params` must provide `mc`, `mp` and `l`; it may optionally provide `g` (defaults to 9.81).
Works with ForwardDiff dual numbers because all arrays are built from `x`'s entries.
"""
function dynamics(params::NamedTuple, x::Vector, u)
    # physical parameters; local names avoid shadowing the `mc` (MeshCat) module alias
    m_c, m_p, l = params.mc, params.mp, params.l
    g = get(params, :g, 9.81)

    q = x[1:2]
    qd = x[3:4]

    s = sin(q[2])
    c = cos(q[2])

    H = [m_c+m_p m_p*l*c; m_p*l*c m_p*l^2]
    C = [0 -m_p*qd[2]*l*s; 0 0]
    G = [0, m_p*g*l*s]
    B = [1, 0]   # the control enters only the first generalized coordinate

    # no debug printing here: this runs many times per solver iteration
    qdd = -H \ (C*qd + G - B*u[1])
    return [qd; qdd]
end

"""
    hermite_simpson(params, x1, x2, u, dt)

Return the Hermite–Simpson collocation residual for one interval of length `dt`
between knot states `x1` and `x2`. The same control `u` is used at both knots and
at the midpoint. The residual is zero when the pair satisfies the implicit rule.
"""
function hermite_simpson(params::NamedTuple, x1::Vector, x2::Vector, u, dt::Real)::Vector
    # evaluate the knot dynamics once each; both are needed twice below
    f1 = dynamics(params, x1, u)
    f2 = dynamics(params, x2, u)

    # cubic-interpolated midpoint state and its dynamics
    x_mid = 0.5 * (x1 + x2) + (dt / 8) * (f1 - f2)
    f_mid = dynamics(params, x_mid, u)

    # Simpson quadrature of the dynamics must reproduce x2
    return x1 + (dt / 6) * (f1 + 4 * f_mid + f2) - x2
end

hermite_simpson (generic function with 1 method)

In [62]:
"""
    create_idx(nx, nu, N)

Build index ranges into the stacked decision vector `Z = [x1; u1; x2; u2; …; xN]`,
where each of the `N` states has `nx` entries and each of the `N-1` controls has `nu`.

Returns a NamedTuple with fields, in order:
- `nx`, `nu`, `N`: the inputs;
- `nz`: total length of `Z`;
- `nc`: total length of the stacked `N-1` dynamics residuals;
- `x[k]`: range of state `k` in `Z`;
- `u[k]`: range of control `k` in `Z`;
- `c[k]`: range of residual `k` in the constraint vector.
"""
function create_idx(nx, nu, N)
    blk = nx + nu   # length of one (state, control) pair inside Z

    x_ranges = [(1:nx) .+ (k - 1) * blk for k in 1:N]
    u_ranges = [((nx + 1):blk) .+ (k - 1) * blk for k in 1:(N - 1)]
    c_ranges = [(1:nx) .+ (k - 1) * nx for k in 1:(N - 1)]

    return (nx = nx, nu = nu, N = N,
            nz = N * nx + (N - 1) * nu,
            nc = (N - 1) * nx,
            x = x_ranges, u = u_ranges, c = c_ranges)
end

"""
    cartpole_cost(params, Z)

Quadratic tracking cost of the stacked trajectory `Z`:

    J = Σ_{i=1}^{N-1} ½(xᵢ-x_g)'Q(xᵢ-x_g) + ½uᵢ'Ruᵢ  +  ½(x_N-x_g)'Q_f(x_N-x_g)

`params` must provide `idx` (from `create_idx`), `N`, `xg`, `Q`, `R` and `Qf`.
"""
function cartpole_cost(params::NamedTuple, Z::Vector)::Real
    idx, N, xg = params.idx, params.N, params.xg
    Q, R, Qf = params.Q, params.R, params.Qf

    # Start from a zero of the promoted element type (a Dual under ForwardDiff) so
    # J keeps one concrete type through the loop instead of starting as an Int.
    J = zero(promote_type(eltype(Z), eltype(Q)))

    for i in 1:(N - 1)
        dx = Z[idx.x[i]] - xg
        ui = Z[idx.u[i]]
        J += 0.5 * dx' * Q * dx + 0.5 * ui' * R * ui
    end

    # terminal cost
    dxN = Z[idx.x[N]] - xg
    J += 0.5 * dxN' * Qf * dxN
    return J
end

"""
    cartpole_dynamics_constraints(params, Z)

Stack the Hermite–Simpson residuals of all `N-1` intervals into one vector of
length `params.idx.nc`. Interval `k` is written to `params.idx.c[k]`.
"""
function cartpole_dynamics_constraints(params::NamedTuple, Z::Vector)::Vector
    idx, N, dt = params.idx, params.N, params.dt

    # eltype(Z) keeps the buffer compatible with ForwardDiff dual numbers
    residuals = zeros(eltype(Z), idx.nc)

    for k in 1:(N - 1)
        x_now = Z[idx.x[k]]
        x_next = Z[idx.x[k + 1]]
        u_now = Z[idx.u[k]]
        residuals[idx.c[k]] = hermite_simpson(params, x_now, x_next, u_now, dt)
    end

    return residuals
end

"""
    cartpole_equality_constraint(params, Z)

Return all equality-constraint residuals, in order:
`[x₁ - xic; x_N - xg; Hermite–Simpson dynamics residuals]`.
"""
function cartpole_equality_constraint(params::NamedTuple, Z::Vector)::Vector
    idx = params.idx

    initial_residual = Z[first(idx.x)] - params.xic
    terminal_residual = Z[last(idx.x)] - params.xg
    dynamics_residual = cartpole_dynamics_constraints(params, Z)

    return vcat(initial_residual, terminal_residual, dynamics_residual)
end

"""
    solve_cartpole_swingup(σ, z0 = nothing; verbose = true)

Transcribe the cartpole trajectory-optimization problem with Hermite–Simpson
collocation (`dt = 0.05` over `tf = 2.0`) and solve it with `fmincon`.

# Arguments
- `σ`: only `σ[1]` and `σ[2]` are read. They form the first two entries of the
  initial state; the last two are zero.
- `z0`: optional warm start for the stacked decision vector `Z`. It must have
  `idx.nz` entries and may be any array shape, since it is flattened with `vec`.
  When omitted, the guess is `0.01 * randn(nz)`, so results depend on the global RNG.
- `verbose`: forwarded to `fmincon`.

# Returns
`(X, U, obj, solve_time_sec, term_status, t_vec, params)`. `X` holds the `N` knot
states, `U` the `N-1` controls, and `obj`, `solve_time_sec` and `term_status` come
from `fmincon`.

# Throws
`DimensionMismatch` if `z0` does not have `idx.nz` entries.
"""
function solve_cartpole_swingup(σ, z0 = nothing; verbose = true)
    # problem size
    nx = 4
    nu = 1
    dt = 0.05
    tf = 2.0
    t_vec = 0:dt:tf
    N = length(t_vec)

    # LQR cost weights
    Q = 1 * diagm(ones(nx))
    R = 0.1 * diagm(ones(nu))
    Qf = 10 * diagm(ones(nx))

    # indexing into the stacked decision vector Z
    idx = create_idx(nx, nu, N)

    # initial state from σ (remaining entries zero) and goal state
    xic = [σ[1], σ[2], 0, 0]
    xg = [0, pi, 0, 0]

    params = (Q = Q, R = R, Qf = Qf, xic = xic, xg = xg, dt = dt, N = N, idx = idx,
              mc = 1.0, mp = 0.2, l = 0.5)

    # primal bounds: states are free, every control is limited to [-10, 10]
    x_l = fill(-Inf, idx.nz)
    x_u = fill(Inf, idx.nz)
    for i in 1:(N - 1)
        x_l[idx.u[i]] .= -10
        x_u[idx.u[i]] .= 10
    end

    # no inequality constraints: empty bounds and an empty, eltype-preserving
    # (ForwardDiff-friendly) constraint function
    c_l = zeros(0)
    c_u = zeros(0)
    function inequality_constraint(params, Z)
        return zeros(eltype(Z), 0)
    end

    # initial guess: caller-supplied warm start, else a small random perturbation
    if isnothing(z0)
        z_init = 0.01 * randn(idx.nz)
    else
        length(z0) == idx.nz || throw(DimensionMismatch(
            "warm start has $(length(z0)) entries, expected idx.nz = $(idx.nz)"))
        z_init = Float64.(vec(z0))   # copy so the caller's array is never aliased
    end

    # automatic differentiation; switch to :finite if :auto does not work
    diff_type = :auto

    Z, obj, solve_time_sec, term_status = fmincon(cartpole_cost, cartpole_equality_constraint,
                inequality_constraint, x_l, x_u, c_l, c_u, z_init, params, diff_type;
                tol = 1e-6, c_tol = 1e-6, max_iters = 10_000, verbose = verbose)

    # pull the X and U solutions out of Z
    X = [Z[idx.x[i]] for i in 1:N]
    U = [Z[idx.u[i]] for i in 1:(N - 1)]

    return X, U, obj, solve_time_sec, term_status, t_vec, params
end


    


solve_cartpole_swingup (generic function with 1 method)

## Solve cartpole for a single parameter sample

In [63]:

# Only σ[1] and σ[2] are read by solve_cartpole_swingup; the last two entries
# of the initial state are always zero.
# NOTE(review): σ = [0, π] makes the initial state equal to the goal
# xg = [0, π, 0, 0], so this run starts at the goal — confirm this is intended.
σ = [0, pi]

X, U, obj, solve_time_sec, term_status, t_vec, params = solve_cartpole_swingup(σ; verbose=true)

# stack per-knot vectors column-wise; reduce(hcat, …) avoids splatting N arguments
Xm = reduce(hcat, X)
Um = reduce(hcat, U)

# --------------plotting-----------------
display(plot(t_vec, Xm', label = ["p" "θ" "ṗ" "θ̇"], xlabel = "time (s)", title = "State Trajectory"))
display(plot(t_vec[1:end-1], Um', label = "", xlabel = "time (s)", ylabel = "u", title = "Controls"))




In [57]:
# Replay the solved trajectory using the same dt the problem was transcribed with,
# rather than a duplicated literal that could drift out of sync.
animate_cartpole(X, params.dt)


┌ Info: Listening on: 127.0.0.1:8715, thread id: 1
└ @ HTTP.Servers C:\Users\AiPEX-WS4\.julia\packages\HTTP\sJD5V\src\Servers.jl:382
┌ Info: MeshCat server started. You can open the visualizer by visiting the following URL in your browser:
│ http://127.0.0.1:8715
└ @ MeshCat C:\Users\AiPEX-WS4\.julia\packages\MeshCat\0RCA3\src\visualizer.jl:64


## Solve the DIRCOL NLP for a parameter set

In [19]:
# ## Define upper and lower bounds of the parameters for the Parametric Optimal Control Problem
# # using the xic of the cartpole
# using Random, Distributions, CSV, DataFrames, ProgressMeter
# Random.seed!(123)

# N = 50 # number of samples
# σ_lower = [-1.0, 0]
# σ_upper = [1.0, 2*pi]

# # Randomly sample the iid parameters uniformly from the given bounds
# d = Product(Uniform.(σ_lower, σ_upper))
# σ_samples = rand(d, N)
# σ_samples = eachcol(σ_samples)

# # Solve the NLP for the parameter sample set
# df = DataFrame(params = Vector{Vector{Float64}}(), X=Vector{Vector{Vector{Float64}}}(), U=Vector{Vector{Vector{Float64}}}(), obj = Float64[], solve_time_sec = Float64[], term_status = MOI.TerminationStatusCode[])

# p = Progress(N, 1)
# i = 1
# for σ in σ_samples
#     next!(p)
#     # z0 = 0.001*randn(idx.nz)
#     X, U, obj, solve_time_sec, term_status, t_vec, params = solve_cartpole_swingup(σ, verbose=false)

#     # if i % 10 == 0
#     #     println("Sample: ", i)
#     #     println("------------------")
#     #     println("σ: ", σ)
#     #     println("Objective Value: ", obj)
#     #     println("Termination Status Code: ", term_status)
#     #     println("Solve Time: ", solve_time_sec, "s")
#     #     println("")
#     # end

#     push!(df, [σ, X, U, obj, solve_time_sec, term_status])
#     i += 1
# end



[32mProgress: 100%|█████████████████████████████████████████| Time: 0:01:34[39m[K


In [20]:

# CSV.write("../data/presolves/cartpole_DIRCOL_1000_pi_wrap_fix.csv", df)

# NOTE(review): in the visible notebook, `df` is only built by the sampling cell
# above, which is currently commented out. Unless `df` is defined elsewhere, this
# line throws UndefVarError in a fresh session.
df

Row,params,X,U,obj,solve_time_sec,term_status
Unnamed: 0_level_1,Array…,Array…,Array…,Float64,Float64,Terminat…
1,"[0.536895, 5.90943]","[[0.536895, 3.37297, 1.29781e-13, -4.126e-13], [0.525056, 3.35546, -0.474051, -0.704672], [0.489399, 3.30165, -0.953515, -1.46086], [0.429539, 3.20769, -1.44276, -2.31824], [0.344945, 3.06746, -1.94226, -3.31627], [0.235296, 2.87354, -2.44192, -4.46246], [0.101232, 2.6196, -2.9119, -5.70227], [-0.0545727, 2.30364, -3.30267, -6.92384], [-0.205339, 1.95407, -2.72254, -7.14077], [-0.325965, 1.58137, -2.09028, -7.85149] … [-0.698828, 0.978451, 0.652881, 9.57426], [-0.646737, 1.41853, 1.41233, 8.11869], [-0.559374, 1.79922, 2.06735, 7.19345], [-0.441136, 2.1459, 2.65431, 6.75201], [-0.313006, 2.4587, 2.4556, 5.75849], [-0.201164, 2.71481, 2.00669, 4.48245], [-0.113088, 2.90742, 1.51208, 3.23326], [-0.0501441, 3.03969, 1.00532, 2.07466], [-0.01251, 3.11647, 0.50078, 1.00969], [-3.01424e-14, 3.14159, -6.90501e-14, 2.50241e-13]]","[[-10.0], [-10.0], [-10.0], [-10.0], [-10.0], [-10.0], [-10.0], [10.0], [10.0], [10.0] … [10.0], [10.0], [10.0], [10.0], [-6.35807], [-10.0], [-10.0], [-10.0], [-10.0], [-10.0]]",1299.07,1.745,LOCALLY_INFEASIBLE
2,"[0.347917, 2.48471]","[[0.347917, 2.48471, 0.0, 0.0], [0.335215, 2.44959, -0.506704, -1.40327], [0.306937, 2.35885, -0.622486, -2.23211], [0.28188, 2.23756, -0.379024, -2.63915], [0.273618, 2.09811, 0.0492363, -2.967], [0.286951, 1.93793, 0.486512, -3.47057], [0.322633, 1.74788, 0.945845, -4.16512], [0.382213, 1.51779, 1.44607, -5.07816], [0.468331, 1.23569, 2.01171, -6.25406], [0.577193, 0.893763, 2.3526, -7.43912] … [-0.229905, 2.08268, -0.86037, 3.07803], [-0.262148, 2.2273, -0.430171, 2.73624], [-0.272935, 2.35921, -0.000830958, 2.56917], [-0.26213, 2.48716, 0.434506, 2.57904], [-0.229292, 2.62025, 0.881373, 2.7776], [-0.175196, 2.76577, 1.28503, 3.07795], [-0.109182, 2.91517, 1.35642, 2.92416], [-0.0501441, 3.03969, 1.00532, 2.07466], [-0.01251, 3.11647, 0.50078, 1.00969], [0.0, 3.14159, 0.0, 6.9533e-31]]","[[-10.0], [-1.77467], [5.94236], [10.0], [10.0], [10.0], [10.0], [10.0], [2.97412], [-9.53868] … [10.0], [10.0], [10.0], [10.0], [10.0], [8.77879], [1.74058], [-6.85525], [-10.0], [-10.0]]",699.009,0.871,LOCALLY_SOLVED
3,"[-0.373512, 4.16295]","[[-0.373512, 3.35163, 1.33444e-13, -4.14319e-13], [-0.38542, 3.33336, -0.476804, -0.735123], [-0.421285, 3.27723, -0.959094, -1.52314], [-0.481493, 3.17933, -1.45103, -2.41408], [-0.566547, 3.03347, -1.95207, -3.44508], [-0.676661, 2.83243, -2.4498, -4.61658], [-0.810935, 2.5705, -2.91114, -5.86426], [-0.950872, 2.27088, -2.68094, -6.16458], [-1.0714, 1.95755, -2.13569, -6.43884], [-1.1637, 1.61995, -1.54587, -7.13774] … [-0.735627, 0.925806, 0.852823, 9.97188], [-0.672894, 1.38409, 1.63828, 8.45231], [-0.573651, 1.78047, 2.31562, 7.49246], [-0.445188, 2.14016, 2.81291, 6.96499], [-0.313006, 2.4587, 2.4556, 5.75849], [-0.201164, 2.71481, 2.00669, 4.48245], [-0.113088, 2.90742, 1.51208, 3.23326], [-0.0501441, 3.03969, 1.00532, 2.07466], [-0.01251, 3.11647, 0.50078, 1.00969], [-6.29864e-14, 3.14159, -7.00678e-15, 2.00705e-13]]","[[-10.0], [-10.0], [-10.0], [-10.0], [-10.0], [-10.0], [3.6008], [10.0], [10.0], [10.0] … [5.49854], [10.0], [10.0], [7.54446], [-10.0], [-10.0], [-10.0], [-10.0], [-10.0], [-10.0]]",1279.31,1.416,LOCALLY_INFEASIBLE
4,"[0.172044, 0.327562]","[[0.172044, 0.327562, 0.0, 6.18897e-29], [0.178888, 0.306778, 0.273761, -0.828773], [0.198008, 0.247916, 0.490764, -1.51705], [0.22705, 0.15792, 0.669845, -2.0674], [0.264389, 0.0439419, 0.821693, -2.46974], [0.30878, -0.0861919, 0.951041, -2.70847], [0.359099, -0.223878, 1.05856, -2.76974], [0.4142, -0.360019, 1.14281, -2.64844], [0.472852, -0.485606, 1.20169, -2.35242], [0.533721, -0.59234, 1.23265, -1.90064] … [-0.112594, 2.79031, 0.20602, 1.49437], [-0.100484, 2.86058, 0.279182, 1.32917], [-0.085486, 2.92275, 0.321492, 1.16921], [-0.0690205, 2.9771, 0.337849, 1.0151], [-0.052291, 3.02391, 0.331997, 0.866185], [-0.0363451, 3.06338, 0.306423, 0.719954], [-0.0221416, 3.09551, 0.262211, 0.571379], [-0.0106267, 3.11997, 0.19877, 0.41202], [-0.00283078, 3.13591, 0.113324, 0.228501], [0.0, 3.14159, 0.0, 8.44884e-27]]","[[4.99013], [3.8469], [3.15593], [2.79297], [2.6422], [2.57617], [2.46578], [2.20967], [1.75009], [1.05085] … [2.85638], [2.00806], [1.28685], [0.672413], [0.141968], [-0.328894], [-0.767142], [-1.20488], [-1.68386], [-2.26278]]",404.52,1.354,LOCALLY_SOLVED
5,"[-0.462721, 0.684055]","[[-0.462721, 0.684055, 0.0, 0.0], [-0.452359, 0.652528, 0.415494, -1.26139], [-0.424602, 0.56376, 0.696494, -2.28491], [-0.385584, 0.430018, 0.864556, -3.04971], [-0.34051, 0.265081, 0.936123, -3.51946], [-0.29358, 0.0838721, 0.936568, -3.69095], [-0.247495, -0.0997364, 0.901815, -3.61346], [-0.203162, -0.275033, 0.867512, -3.36276], [-0.159975, -0.434878, 0.85745, -3.0021], [-0.116497, -0.574666, 0.88059, -2.56701] … [-0.162421, 2.62983, 0.242128, 2.09125], [-0.147052, 2.72891, 0.373368, 1.88967], [-0.12635, 2.81802, 0.455499, 1.69079], [-0.102634, 2.8972, 0.493937, 1.4906], [-0.077962, 2.96632, 0.493769, 1.2868], [-0.054159, 3.02513, 0.459122, 1.07653], [-0.0328795, 3.07318, 0.392735, 0.854689], [-0.015686, 3.10968, 0.295557, 0.612383], [-0.00415149, 3.13326, 0.166195, 0.335106], [0.0, 3.14159, 0.0, 4.97173e-28]]","[[7.96743], [4.88584], [2.35202], [0.468624], [-0.552236], [-0.669122], [-0.104988], [0.785978], [1.69134], [2.41344] … [4.54247], [3.33436], [2.22709], [1.23597], [0.35441], [-0.435818], [-1.16019], [-1.85106], [-2.55073], [-3.31848]]",436.644,1.299,LOCALLY_SOLVED
6,"[-0.672668, 2.97205]","[[-0.672668, 2.97205, 0.0, 0.0], [-0.685505, 2.94251, -0.513505, -1.1855], [-0.72401, 2.85278, -1.02611, -2.41327], [-0.787959, 2.70019, -1.52865, -3.69829], [-0.861835, 2.50792, -1.4254, -4.02153], [-0.921494, 2.3088, -0.962381, -3.99008], [-0.958109, 2.10439, -0.501302, -4.2302], [-0.971389, 1.88142, -0.0258792, -4.73267], [-0.960098, 1.62661, 0.485729, -5.50696], [-0.921753, 1.32563, 1.06123, -6.58746] … [-0.465189, 1.45496, -0.207558, 6.51885], [-0.461144, 1.75604, 0.35757, 5.58221], [-0.430342, 2.01858, 0.86791, 4.971], [-0.374786, 2.25823, 1.35197, 4.66627], [-0.295211, 2.49044, 1.83183, 4.67616], [-0.20018, 2.7167, 1.96774, 4.40835], [-0.113088, 2.90742, 1.51208, 3.23326], [-0.0501441, 3.03969, 1.00532, 2.07466], [-0.01251, 3.11647, 0.50078, 1.00969], [0.0, 3.14159, 0.0, 6.4826e-32]]","[[-10.0], [-10.0], [-10.0], [2.27184], [10.0], [10.0], [10.0], [10.0], [10.0], [10.0] … [10.0], [10.0], [10.0], [10.0], [10.0], [2.66789], [-9.19326], [-10.0], [-10.0], [-10.0]]",870.494,0.713,LOCALLY_SOLVED
7,"[0.730824, 3.87982]","[[0.730824, 3.37203, 5.07035e-14, -4.10411e-13], [0.718981, 3.35449, -0.474173, -0.706021], [0.683316, 3.30057, -0.953763, -1.46363], [0.62344, 3.20643, -1.44313, -2.3225], [0.538826, 3.06595, -1.94271, -3.32203], [0.429155, 2.87172, -2.4423, -4.46941], [0.29508, 2.61742, -2.91192, -5.70964], [0.139291, 2.30109, -3.30196, -6.93076], [-0.0127792, 1.94952, -2.77431, -7.20801], [-0.135879, 1.57331, -2.13722, -7.92587] … [-0.678329, 1.01321, 0.560952, 9.32593], [-0.631243, 1.44196, 1.30409, 7.91182], [-0.549646, 1.81298, 1.94555, 7.01002], [-0.437765, 2.15069, 2.52229, 6.57402], [-0.313006, 2.4587, 2.4556, 5.75849], [-0.201164, 2.71481, 2.00669, 4.48245], [-0.113088, 2.90742, 1.51208, 3.23326], [-0.0501441, 3.03969, 1.00532, 2.07466], [-0.01251, 3.11647, 0.50078, 1.00969], [2.92291e-14, 3.14159, -1.09442e-13, 2.87828e-13]]","[[-10.0], [-10.0], [-10.0], [-10.0], [-10.0], [-10.0], [-10.0], [8.74717], [10.0], [10.0] … [10.0], [10.0], [10.0], [10.0], [-3.33063], [-10.0], [-10.0], [-10.0], [-10.0], [-10.0]]",1276.37,2.158,LOCALLY_INFEASIBLE
8,"[-0.428604, 2.91444]","[[-0.428604, 2.91444, 0.0, 0.0], [-0.441509, 2.88368, -0.516086, -1.23387], [-0.48017, 2.79044, -1.0292, -2.50326], [-0.54231, 2.6361, -1.45288, -3.67793], [-0.61054, 2.44646, -1.27572, -3.93829], [-0.662808, 2.24933, -0.815818, -3.99222], [-0.692088, 2.04285, -0.353674, -4.31042], [-0.697851, 1.81397, 0.128231, -4.8888], [-0.678551, 1.54938, 0.653238, -5.74348], [-0.631331, 1.23433, 1.24991, -6.91588] … [-0.425319, 1.55692, -0.313147, 5.94607], [-0.427398, 1.83199, 0.220316, 5.10829], [-0.404012, 2.07265, 0.710079, 4.5656], [-0.3567, 2.29324, 1.18097, 4.30562], [-0.285913, 2.50808, 1.65175, 4.33811], [-0.19779, 2.72128, 1.87305, 4.22806], [-0.113088, 2.90742, 1.51208, 3.23326], [-0.0501441, 3.03969, 1.00532, 2.07466], [-0.01251, 3.11647, 0.50078, 1.00969], [0.0, 3.14159, 0.0, -2.8671e-28]]","[[-10.0], [-10.0], [-8.43347], [3.86516], [10.0], [10.0], [10.0], [10.0], [10.0], [7.77423] … [10.0], [10.0], [10.0], [10.0], [10.0], [4.58375], [-7.23323], [-10.0], [-10.0], [-10.0]]",838.77,0.789,LOCALLY_SOLVED
9,"[-0.448362, 2.80587]","[[-0.448362, 2.80587, 0.0, 0.0], [-0.461333, 2.77322, -0.518374, -1.30828], [-0.500087, 2.67467, -1.0296, -2.63761], [-0.555681, 2.52357, -1.19169, -3.41895], [-0.605527, 2.35251, -0.803261, -3.4608], [-0.63451, 2.1759, -0.35608, -3.64237], [-0.641032, 1.98445, 0.0974532, -4.05409], [-0.62439, 1.76659, 0.573708, -4.70072], [-0.582916, 1.5101, 1.09491, -5.60494], [-0.513692, 1.20106, 1.68835, -6.81147] … [-0.373127, 1.68927, -0.460653, 5.20629], [-0.383554, 1.93061, 0.0365474, 4.49241], [-0.369971, 2.14282, 0.50357, 4.03825], [-0.333367, 2.33865, 0.960205, 3.83774], [-0.273876, 2.53099, 1.4211, 3.90145], [-0.19466, 2.72729, 1.74901, 3.99159], [-0.113088, 2.90742, 1.51208, 3.23326], [-0.0501441, 3.03969, 1.00532, 2.07466], [-0.01251, 3.11647, 0.50078, 1.00969], [0.0, 3.14159, 0.0, -1.79042e-31]]","[[-10.0], [-10.0], [-3.0277], [8.62559], [10.0], [10.0], [10.0], [10.0], [10.0], [7.29255] … [10.0], [10.0], [10.0], [10.0], [10.0], [6.95502], [-4.66871], [-10.0], [-10.0], [-10.0]]",792.198,0.892,LOCALLY_SOLVED
10,"[0.164636, 1.60838]","[[0.164636, 1.60838, 0.0, 0.0], [0.172353, 1.58439, 0.309672, -0.962244], [0.196008, 1.51164, 0.639523, -1.95537], [0.236153, 1.38789, 0.971739, -3.00703], [0.291932, 1.21032, 1.26714, -4.10904], [0.360105, 0.978578, 1.46731, -5.16399], [0.434307, 0.700312, 1.5019, -5.94002], [0.50521, 0.396428, 1.32416, -6.14657], [0.563107, 0.0978747, 0.97659, -5.71101], [0.602466, -0.168909, 0.587736, -4.89821] … [-0.215, 2.30521, -0.331634, 2.6308], [-0.222698, 2.43173, 0.0241919, 2.45608], [-0.214049, 2.55098, 0.322628, 2.33798], [-0.192032, 2.66504, 0.559076, 2.2467], [-0.159872, 2.77455, 0.72839, 2.15452], [-0.121129, 2.87868, 0.822378, 2.02994], [-0.0799123, 2.97478, 0.827286, 1.83126], [-0.0412407, 3.05764, 0.720505, 1.49812], [-0.0116195, 3.11826, 0.465134, 0.937827], [0.0, 3.14159, 0.0, 9.30889e-28]]","[[7.45826], [7.65907], [7.11055], [5.28483], [1.93783], [-2.48872], [-6.52181], [-8.36924], [-7.62187], [-5.4129] … [9.47981], [8.34025], [7.00425], [5.56585], [4.03426], [2.35735], [0.426481], [-1.93675], [-5.01837], [-9.2881]]",539.585,0.933,LOCALLY_SOLVED


## Load in the Warmstarts and Solve for the Refined Trajectories

In [21]:
# using CSV

# df = DataFrame(CSV.File("data/warmstart_cartpole.csv"))
# Z_warmstart_str = df.Z_warmstart[1]
# @show Z_warmstart_str
# Z_warmstart_str = replace(Z_warmstart_str, r"\n" => "")
# Z_warmstart = eval(Meta.parse(Z_warmstart_str))
# # # Function to convert the string representation to a 2D array
# # function convert_to_2d_array(str::String)
# #     # Remove the brackets and newline characters
# #     clean_str = replace(str, r"[\[\]\n]" => "")
# #     # Split the string into individual numbers
# #     num_strs = split(clean_str)
# #     # Convert the numbers to Float64
# #     nums = parse.(Float64, num_strs)
# #     # Reshape the flat array into a 2D array
# #     num_rows = count(x -> x == '\n', str) + 1
# #     num_cols = length(nums) ÷ num_rows
# #     return reshape(nums, num_cols, num_rows)'
# # end

# # X_warmstart = convert_to_2d_array(X_warmstart_str)
# # X_warmstart = X_warmstart'
# # X_warmstart = [X_warmstart[:, i] for i in 1:size(X_warmstart, 2)]
# # @show size(X_warmstart[1])

In [22]:
# # Assuming you have a DataFrame named df

# for row in eachrow(df)
#     Z_warmstart_str = row.Z_warmstart
#     Z_warmstart_str = replace(Z_warmstart_str, r"\n" => "")
#     Z_warmstart = vec(eval(Meta.parse(Z_warmstart_str)))
#     # @show typeof(Z_warmstart)
#     param_str = row.params
#     σ = vec(eval(Meta.parse(param_str)))
#     # @show typeof(vec(param_str))

#     X, U, obj, solve_time_sec, term_status, t_vec, params = solve_cartpole_swingup(σ, Z_warmstart; verbose=false)


# end