In [None]:
import Pkg; Pkg.activate(@__DIR__); Pkg.instantiate()

In [None]:
using LinearAlgebra
using PyPlot

In [None]:
# Discrete dynamics x[k+1] = A*x[k] + B*u[k] (this is how rollout applies them)
h = 0.1   # time step
A = [1.0 h;
     0.0 1.0]
B = [h^2 / 2, h]

In [None]:
n = 2     # number of states
m = 1     # number of controls
Tfinal = 10.0 # final time #try larger values
# round(Int, ...) instead of Int(...): Tfinal/h is frequently a hair away from an
# integer in floating point (e.g. 0.3/0.1), and Int(...) throws InexactError then.
N = round(Int, Tfinal/h) + 1    # number of time steps
# Build the grid from step + length so it always has exactly N samples.
thist = collect(range(0, step=h, length=N));

In [None]:
# Initial state: first component 1, second component 0
x0 = [1.0, 0.0]

In [None]:
# Quadratic cost weights (see J for how each one enters the cost)
Q = Diagonal(ones(2))     # weight on the state at every step before the last
R = 0.1                   # weight on the control
Qn = Diagonal(ones(2))    # weight on the final state

In [None]:
"""
    J(xhist, uhist)

Return the quadratic cost of a trajectory.

`xhist` holds one state per column and `uhist[k]` is the control applied at step `k`.
The cost is `0.5*x'*Qn*x` on the last column of `xhist`, plus
`0.5*x'*Q*x + 0.5*u'*R*u` for every earlier column. `Q`, `R` and `Qn` are read from
the enclosing (global) scope.
"""
function J(xhist, uhist)
    # Take the horizon from the trajectory itself so the running sum always agrees
    # with the column used for the terminal term (`xhist[:, end]`).
    nsteps = size(xhist, 2)
    xlast = xhist[:, end]
    cost = 0.5*xlast'*Qn*xlast
    for k in 1:(nsteps - 1)
        xk = xhist[:, k]
        uk = uhist[k]
        cost = cost + 0.5*xk'*Q*xk + 0.5*uk'*R*uk
    end
    return cost
end

In [None]:
"""
    rollout(xhist, uhist)

Simulate `x[k+1] = A*x[k] + B*uhist[k]` forward and return the resulting states as a
new array with the same size as `xhist` (one state per column).

Only the first column of `xhist` is read (it is the initial state); `xhist` itself is
not modified. `A` and `B` are read from the enclosing (global) scope.
"""
function rollout(xhist, uhist)
    # The number of steps comes from xhist, matching the size of the output array.
    nsteps = size(xhist, 2)
    xnew = zeros(size(xhist))
    xnew[:, 1] = xhist[:, 1]
    for k in 1:(nsteps - 1)
        xnew[:, k+1] .= A*xnew[:, k] + B*uhist[k]
    end
    return xnew
end

In [None]:
# Initial guess: zero controls; the state trajectory is obtained by simulating from x0
# (rollout only reads the first column of the state array it is given)
uhist = zeros(N - 1)
Δu = ones(N - 1)       # nonzero so the descent loop's stopping test passes on entry
λhist = zeros(n, N)
xhist = rollout(repeat(x0, 1, N), uhist)

J(xhist, uhist) # Initial cost

In [None]:
# Gradient descent on the control sequence with a backtracking line search.
b = 1e-2 #line search tolerance
α = 1.0
iter = 0
max_iters = 100_000 # safety cap: a stalled line search would otherwise repeat forever
while maximum(abs, Δu) > 1e-2 && iter < max_iters #terminate when the gradient is small
    # Explicit `global` keeps these assignments bound to the top-level variables even
    # when this cell is executed as a plain script rather than in a notebook/REPL.
    global α, iter

    #Backward pass to compute λ and Δu
    λhist[:,N] .= Qn*xhist[:,N]
    for k = N-1:-1:1
        Δu[k] = -(uhist[k]+R\B'*λhist[:,k+1])
        λhist[:,k] .= Q*xhist[:,k] + A'*λhist[:,k+1]
    end

    # Neither quantity changes while backtracking, so evaluate each once per pass.
    J0 = J(xhist, uhist)
    Δu2 = Δu'*Δu

    #Forward pass with line search to compute x
    α = 1.0
    unew = uhist + α.*Δu
    xnew = rollout(xhist, unew)
    while J(xnew, unew) > J0 - b*α*Δu2
        α = 0.5*α   # halve the step until the cost decreases enough
        unew = uhist + α.*Δu
        xnew = rollout(xhist, unew)
    end
    uhist .= unew
    xhist .= xnew
    iter += 1
end
if iter >= max_iters && maximum(abs, Δu) > 1e-2
    @warn "Gradient descent hit the iteration cap before reaching the tolerance" iter max_iters
end

In [None]:
# Number of passes the descent loop made before it stopped
iter

In [None]:
J(xhist,uhist) # cost of the trajectory left in xhist/uhist after the descent loop

In [None]:
# Plot each state component against time (row 1 -> "Position", row 2 -> "Velocity")
for (row, name) in enumerate(("Position", "Velocity"))
    plot(thist, xhist[row, :], label=name)
end
xlabel("Time")
legend()

In [None]:
# Control against time; the last time sample is dropped because uhist has one fewer
# entry than thist
plot(thist[1:lastindex(thist)-1], uhist; label="control")
xlabel("Time")
legend()