In [None]:
import Pkg; Pkg.activate(@__DIR__); Pkg.instantiate()

In [None]:
using LinearAlgebra
using PyPlot
using ForwardDiff
using RobotZoo
using RobotDynamics
using MatrixCalculus
using JLD2

In [None]:
#Cartpole Dynamics
# Cart-pole model from RobotZoo; read as the global `a` by dynamics_rk4.
a = RobotZoo.Cartpole()
# Discretization step: RK4 step length in dynamics_rk4 and spacing of the time grid thist.
h = 0.05

In [None]:
"""
    dynamics_rk4(x, u)

Advance state `x` by one step of length `h` using the classical fourth-order Runge–Kutta
scheme, holding the control `u` constant over the step (zero-order hold).
Reads the globals `a` (RobotZoo cart-pole model) and `h` (step length).
"""
function dynamics_rk4(x,u)
    half_step = 0.5*h
    k1 = RobotZoo.dynamics(a, x, u)
    k2 = RobotZoo.dynamics(a, x + half_step*k1, u)
    k3 = RobotZoo.dynamics(a, x + half_step*k2, u)
    k4 = RobotZoo.dynamics(a, x + h*k3, u)
    weight = h/6.0
    return x + weight*(k1 + 2*k2 + 2*k3 + k4)
end

In [None]:
"""
    dfdx(x, u)

Jacobian A = ∂f/∂x of the RK4-discretized dynamics at `(x, u)`, via forward-mode AD.
"""
dfdx(x,u) = ForwardDiff.jacobian(ξ -> dynamics_rk4(ξ, u), x)

"""
    dfdu(x, u)

Derivative B = ∂f/∂u of the RK4-discretized dynamics with respect to the scalar control
`u` (a vector, since `ForwardDiff.derivative` differentiates w.r.t. a scalar).
"""
dfdu(x,u) = ForwardDiff.derivative(ν -> dynamics_rk4(x, ν), u)

In [None]:
# Second derivatives of the discrete dynamics by nested forward-mode AD, used by the
# (optional) full-Newton DDP terms. A-derivatives act on vec(A) = vec(∂f/∂x);
# B-derivatives act on the vector B = ∂f/∂u.
dAdx(x,u) = ForwardDiff.jacobian(ξ -> vec(dfdx(ξ, u)), x)

dBdx(x,u) = ForwardDiff.jacobian(ξ -> dfdu(ξ, u), x)

dAdu(x,u) = ForwardDiff.derivative(ν -> vec(dfdx(x, ν)), u)

dBdu(x,u) = ForwardDiff.derivative(ν -> dfdu(x, ν), u)

In [None]:
Nx = 4     # number of states
Nu = 1     # number of controls
Tfinal = 5.0 # final time
# Knot points t = 0, h, ..., Tfinal (Nt-1 control intervals). `round` rather than `Int`
# so float noise in Tfinal/h (e.g. 0.3/0.1 == 2.9999999999999996) does not raise an
# InexactError; the check below still rejects a horizon that is not a multiple of h.
Nt = round(Int, Tfinal/h) + 1
(Nt - 1)*h ≈ Tfinal || throw(ArgumentError("Tfinal = $Tfinal is not an integer multiple of h = $h"))
# Build by length so thist always has exactly Nt entries (a float stop can drop the last one).
thist = collect(range(0.0, step=h, length=Nt));

In [None]:
# Cost weights
# Stage state weight: first two and last two state components weighted separately.
Q = Diagonal(vcat(fill(1.0, 2), fill(1.0, 2)));
# Stage control weight (scalar control).
R = 0.1;
# Terminal state weight, as a dense Nx×Nx matrix.
Qn = Matrix(100.0I, Nx, Nx);

In [None]:
"""
    stage_cost(x, u)

Quadratic running cost ½ (x - xgoal)' Q (x - xgoal) + ½ R u² (globals `xgoal`, `Q`, `R`).
"""
function stage_cost(x,u)
    e = x - xgoal
    state_term = 0.5*(e'*Q*e)
    control_term = 0.5*R*u*u
    return state_term + control_term
end

In [None]:
"""
    terminal_cost(x)

Quadratic final-state penalty ½ (x - xgoal)' Qn (x - xgoal) (globals `xgoal`, `Qn`).
"""
function terminal_cost(x)
    e = x - xgoal
    return 0.5*e'*Qn*e
end

In [None]:
"""
    cost(xtraj, utraj)

Total trajectory cost: `stage_cost` at knot points 1 … Nt-1 plus `terminal_cost` at Nt.
`xtraj` holds one state per column; `utraj` one scalar control per interval.
"""
function cost(xtraj,utraj)
    # Left fold from 0.0 so the summation order is a plain sequential accumulation.
    running = foldl(1:(Nt-1); init=0.0) do acc, k
        acc + stage_cost(xtraj[:,k], utraj[k])
    end
    return running + terminal_cost(xtraj[:,Nt])
end

In [None]:
# Backward pass of iLQR (optionally full DDP) about the current nominal trajectory.
#
# Reads the globals xtraj, utraj (nominal trajectory), xgoal, Q, R, Qn and Nt, and
# overwrites its arguments in place:
#   p[:,k]   - gradient of the cost-to-go at knot point k
#   P[:,:,k] - Hessian of the cost-to-go at knot point k
#   d[k]     - feedforward control correction (one scalar per interval; Nu = 1)
#   K[:,:,k] - feedback gain (Nu×Nx)
# The forward rollout applies u = utraj - α*d - K*(x - xtraj) with these terms.
# Returns ΔJ = Σ_k gu'*d[k], the predicted decrease used by the line-search test.
function backward_pass!(p,P,d,K)
    
    ΔJ = 0.0
    # Boundary condition from terminal_cost: gradient Qn*(x_N - xgoal), Hessian Qn.
    p[:,Nt] .= Qn*(xtraj[:,Nt]-xgoal)
    P[:,:,Nt] .= Qn
    
    for k = (Nt-1):-1:1
        #Calculate derivatives
        # Gradients of stage_cost with respect to x and u.
        q = Q*(xtraj[:,k]-xgoal)
        r = R*utraj[k]
    
        # Linearization of the discrete dynamics about (xtraj[:,k], utraj[k]);
        # B is a vector because u is scalar.
        A = dfdx(xtraj[:,k], utraj[k])
        B = dfdu(xtraj[:,k], utraj[k])
    
        # Gradient of the stage action-value function w.r.t. x and u.
        gx = q + A'*p[:,k+1]
        gu = r + B'*p[:,k+1]
    
        #iLQR (Gauss-Newton) version
        # Hessian blocks that drop the second derivatives of the dynamics; Guu is scalar here.
        Gxx = Q + A'*P[:,:,k+1]*A
        Guu = R + B'*P[:,:,k+1]*B
        Gxu = A'*P[:,:,k+1]*B
        Gux = B'*P[:,:,k+1]*A
        
        #DDP (full Newton) version
        # Adds the second-order dynamics terms: p[:,k+1] contracted with dAdx/dBdx/dAdu/dBdu.
        # NOTE(review): `comm` presumably is the commutation matrix from MatrixCalculus — confirm.
        #Ax = dAdx(xtraj[:,k], utraj[k])
        #Bx = dBdx(xtraj[:,k], utraj[k])
        #Au = dAdu(xtraj[:,k], utraj[k])
        #Bu = dBdu(xtraj[:,k], utraj[k])
        #Gxx = Q + A'*P[:,:,k+1]*A + kron(p[:,k+1]',I(Nx))*comm(Nx,Nx)*Ax
        #Guu = R + B'*P[:,:,k+1]*B + (kron(p[:,k+1]',I(Nu))*comm(Nx,Nu)*Bu)[1]
        #Gxu = A'*P[:,:,k+1]*B + kron(p[:,k+1]',I(Nx))*comm(Nx,Nx)*Au
        #Gux = B'*P[:,:,k+1]*A + kron(p[:,k+1]',I(Nu))*comm(Nx,Nu)*Bx
        
        # Optional regularization for the DDP version: inflate Gxx and Guu, doubling β,
        # until the joint Hessian [Gxx Gxu; Gux Guu] is positive definite.
        #β = 0.1
        #while !isposdef(Symmetric([Gxx Gxu; Gux Guu]))
        #    Gxx += β*I
        #    Guu += β*I
        #    β = 2*β
        #    display("regularizing G")
        #    display(β)
        #end
        
        # Newton step in u: feedforward d = Guu⁻¹ gu and feedback K = Guu⁻¹ Gux.
        d[k] = Guu\gu
        K[:,:,k] .= Guu\Gux
    
        # Cost-to-go update after substituting the policy δu = -d - K δx.
        # NOTE(review): K[:,:,k]' is a matrix, so the right-hand side below appears to be
        # Nx×1 rather than a vector — presumably why `=` is used instead of `.=` (compare
        # the commented dropdims variant).
        #p[:,k] .= dropdims(gx - K[:,:,k]'*gu + K[:,:,k]'*Guu*d[k] - Gxu*d[k], dims=2)
        p[:,k] = gx - K[:,:,k]'*gu + K[:,:,k]'*Guu*d[k] - Gxu*d[k]
        P[:,:,k] .= Gxx + K[:,:,k]'*Guu*K[:,:,k] - Gxu*K[:,:,k] - K[:,:,k]'*Gux
    
        # Accumulate the predicted cost decrease for the line search.
        ΔJ += gu'*d[k]
    end
    
    return ΔJ
end

In [None]:
#Initial guess
x0 = zeros(Int, 4)                   # initial state: all zeros
xgoal = [0.0, float(pi), 0.0, 0.0]   # target: x[2] = π, all other components zero
xtraj = repeat(float(x0), 1, Nt)     # Nx×Nt trajectory, every column initialized to x0
utraj = randn(Nt-1);                 # random initial controls, one per interval

In [None]:
#Initial Rollout: simulate the initial controls forward from x0 to fill xtraj,
#then record the starting cost.
for k in 1:Nt-1
    xnext = dynamics_rk4(xtraj[:,k], utraj[k])
    xtraj[:,k+1] .= xnext
end
J = cost(xtraj,utraj)

In [None]:
#DDP Algorithm
using Printf

"""
    forward_rollout!(xn, un, xtraj, utraj, d, K, α)

Simulate the updated policy `u = utraj - α*d - K*(x - xtraj)` from the fixed initial
state `xn[:,1]`, overwriting `xn[:,2:end]` and `un` in place, and return `cost(xn, un)`.
`xtraj`/`utraj` are the nominal trajectory, `d`/`K` the feedforward/feedback terms from
`backward_pass!`, and `α` the line-search step size.
"""
function forward_rollout!(xn, un, xtraj, utraj, d, K, α)
    for k in eachindex(un)
        un[k] = utraj[k] - α*d[k] - dot(K[:,:,k], xn[:,k] - xtraj[:,k])
        xn[:,k+1] .= dynamics_rk4(xn[:,k], un[k])
    end
    return cost(xn, un)
end

# Line-search rejection test: a NaN cost (diverged rollout) or insufficient decrease
# relative to the predicted reduction ΔJ returned by the backward pass.
rejected_step(Jn, J, ΔJ, α) = isnan(Jn) || Jn > (J - 1e-2*α*ΔJ)

p = ones(Nx,Nt)
P = zeros(Nx,Nx,Nt)
d = ones(Nt-1)          # nonzero start so the convergence test below is entered
K = zeros(Nu,Nx,Nt-1)
ΔJ = 0.0

xn = zeros(Nx,Nt)
un = zeros(Nt-1)

dtol = 1e-3             # converged once every |d[k]| is at or below this
max_iters = 1000        # safety cap: a stalled solve would otherwise loop forever
α_min = 1e-8            # give up on the line search once α drops below this

iter = 0
while maximum(abs, d) > dtol
    if iter >= max_iters
        @warn "DDP stopped at the iteration limit without converging" max_iters dmax=maximum(abs, d)
        break
    end
    iter += 1

    #Backward Pass
    ΔJ = backward_pass!(p,P,d,K)

    #Forward rollout with backtracking line search: halve α until sufficient decrease
    xn[:,1] .= xtraj[:,1]
    α = 1.0
    Jn = forward_rollout!(xn, un, xtraj, utraj, d, K, α)
    while rejected_step(Jn, J, ΔJ, α) && α > α_min
        α = 0.5*α
        Jn = forward_rollout!(xn, un, xtraj, utraj, d, K, α)
    end
    if rejected_step(Jn, J, ΔJ, α)
        # No acceptable step: keep the last accepted trajectory instead of a worse/NaN one,
        # and stop rather than repeating a backward pass that would make no progress.
        @warn "Line search failed; stopping with the last accepted trajectory" iter α
        break
    end

    # logging
    if rem(iter - 1, 100) == 0
        @printf "iter     J           ΔJ        |d|         α       \n"
        @printf "---------------------------------------------------\n"
    end
    if rem(iter - 1, 10) == 0
        @printf("%3d   %10.3e  %9.2e  %9.2e  %6.4f  \n",
              iter, J, ΔJ, maximum(abs, d), α)
    end

    J = Jn
    xtraj .= xn
    utraj .= un
end

In [None]:
# Number of outer DDP iterations the optimization loop performed
iter

In [None]:
# Optimized state trajectory: cart position x[1] and pole angle x[2] (goal x[2] = π).
# Labeled so the two overlaid curves can be told apart.
plot(thist, xtraj[1,:], label="cart position x[1]")
plot(thist, xtraj[2,:], label="pole angle x[2]")
xlabel("time")
legend()

In [None]:
# Optimized controls: one value per interval, so plot against the first Nt-1 knot times.
tcontrol = thist[1:Nt-1]
plot(tcontrol, utraj)

In [None]:
import MeshCat as mc 

"""
    rotx(θ)

3×3 rotation matrix about the x-axis by angle `θ` (radians).
"""
function rotx(θ)
    sn, cs = sincos(θ)
    return [1 0   0
            0 cs -sn
            0 sn  cs]
end
"""
    create_cartpole!(vis)

Add the cart-pole geometry to the MeshCat visualizer `vis`: a box for the cart (`:cart`),
a cylinder of length 1.5 for the pole (`:pole`), and two small spheres (`:a`, `:b`) that
`update_cartpole_transform!` places at the pole's two ends.
"""
function create_cartpole!(vis)
    cart_box = mc.HyperRectangle(mc.Vec(-.25,-1.0,-.15), mc.Vec(0.5,2,0.3))
    pole_rod = mc.Cylinder(mc.Point(0,0,-.75), mc.Point(0,0,.75), 0.05)
    end_ball = mc.HyperSphere(mc.Point(0,0,0.0), 0.1)
    mc.setobject!(vis[:cart], cart_box)
    mc.setobject!(vis[:pole], pole_rod)
    mc.setobject!(vis[:a], end_ball)
    return mc.setobject!(vis[:b], end_ball)
end
"""
    update_cartpole_transform!(vis, x)

Pose the cart-pole drawing in `vis` for state `x`, using `x[1]` as the cart position and
`x[2]` as the pole angle. At angle 0 the pole hangs in -z; at π it points in +z.
"""
function update_cartpole_transform!(vis,x)
    # Offset of the pole in front of the cart box (the box spans x ∈ [-0.25, 0.25]).
    pole_offset = 0.3
    cart_pos, θ = x[1], x[2]
    pivot = [pole_offset, cart_pos, 0]
    tip = pivot + 1.5*[0, sin(θ), -cos(θ)]
    mc.settransform!(vis[:cart], mc.Translation([0, cart_pos, 0.0]))
    mc.settransform!(vis[:a], mc.Translation(pivot))
    mc.settransform!(vis[:b], mc.Translation(tip))
    # The cylinder is centered at its origin, so move it to the pole midpoint, then rotate.
    midpoint = 0.5*(pivot + tip)
    return mc.settransform!(vis[:pole], mc.Translation(midpoint) ∘ mc.LinearMap(rotx(θ)))
end

"""
    animate_cartpole(X, dt)

Build a MeshCat animation of the cart-pole from the state sequence `X` (one state per
frame, `dt` apart) and return the rendered visualizer.
"""
function animate_cartpole(X, dt)
    vis = mc.Visualizer()
    create_cartpole!(vis)
    frame_rate = floor(Int, 1/dt)
    anim = mc.Animation(vis; fps=frame_rate)
    for (frame, state) in enumerate(X)
        mc.atframe(anim, frame) do
            update_cartpole_transform!(vis, state)
        end
    end
    mc.setanimation!(vis, anim)
    return mc.render(vis)
end

In [None]:
# One plain state vector per knot point, as expected by animate_cartpole.
X1 = [collect(col) for col in eachcol(xtraj)];
animate_cartpole(X1, h)