In [None]:
import torch
import numpy as np

from lqr_solver import LQRSolver

## System parameters (DC motor)

In [None]:
DIM_X = 2
DIM_Y = 1

# Model constants.
# NOTE(review): the names follow the usual DC-motor convention (armature
# resistance / inductance, motor constant, rotor inertia, friction), but the
# units are not stated in this notebook — confirm against the data source.
Ra = 1.9
La = 0.0226
Km = 1.1902
J = 0.035
C = 0.0035

# Continuous-time dynamics  dx/dt = A_cont x + B_cont u.
# Kept under their own names so that `A` / `B` below always mean the
# discrete-time transition matrices and never change meaning inside the cell.
# NOTE(review): the row structure suggests x = [angular velocity, current] — confirm.
A_cont = torch.tensor(((-C / J, Km / J),
                       (-Km / La, -Ra / La)))
B_cont = torch.tensor((0., 1 / La))

# Forward-Euler discretisation with step dt:  x[k+1] = (I + A_cont*dt) x[k] + (B_cont*dt) u[k]
dt = 0.001
A = torch.eye(DIM_X) + A_cont * dt
B = B_cont * dt

## LQR solver settings

In [None]:
# Instantiate the solver; DIM_Y is what gets passed as the control dimension.
solver = LQRSolver(dim_x=DIM_X, dim_u=DIM_Y)

# Fold the discretised model into the solver's own matrices.
# NOTE(review): this yields exactly (A, B) only if LQRSolver initialises
# `A` to the identity and `B` to zeros — confirm in lqr_solver.py.
solver.A = solver.A @ A
solver.B = torch.add(solver.B, B.reshape(solver.B.shape))

## LQR solution reference

In [None]:
from control import dlqr

# Reference solution from python-control, used as ground truth for the
# gain learned by backpropagation further down. The solver matrices are
# handed over as NumPy arrays in the order (A, B, Q, R).
K, P, _ = dlqr(*(mat.numpy() for mat in (solver.A, solver.B, solver.Q, solver.R)))

# Back to a torch tensor so it can be compared with the learned gain.
K = torch.tensor(K)

## Solve LQR with backpropagation

In [None]:
EPOCHS = 1000
BATCH_SIZE = 100
TIME_HORIZON = 100
LOG_EVERY = 10  # print the gain error every LOG_EVERY epochs
SEED = 0

# Make the sampled initial conditions reproducible.
# (The solver's parameters were created in an earlier cell, so their
# initialisation is not affected by this seed.)
torch.manual_seed(SEED)


def _riccati_gain(lqr):
    """Recover P and the LQR feedback gain from the solver's current parameters.

    Computes P = P_sqrt @ P_sqrt.T and then solves
    (R + B^T P B) K = B^T P A for K, without tracking gradients.

    Args:
        lqr: object exposing the tensors `P_sqrt`, `A`, `B` and `R`.

    Returns:
        Tuple `(P, K)` of torch tensors.
    """
    with torch.no_grad():
        P_est = torch.matmul(lqr.P_sqrt, lqr.P_sqrt.T)
        # Local names on purpose: the notebook-level `A` / `B` hold the
        # system matrices and must not be overwritten here.
        lhs = lqr.R + torch.matmul(lqr.B.T, torch.matmul(P_est, lqr.B))
        rhs = torch.matmul(lqr.B.T, torch.matmul(P_est, lqr.A))
        return P_est, torch.linalg.solve(lhs, rhs)


# Choose optimizer
# optimizer = torch.optim.SGD([{"params": solver.parameters()}], lr=0.1)
optimizer = torch.optim.AdamW([{"params": solver.parameters()}], lr=0.1, weight_decay=0.01)

for epoch in range(EPOCHS):
    optimizer.zero_grad()

    # Random batch of initial conditions, one column per sample
    x = torch.randn(DIM_X, BATCH_SIZE)

    loss = 0.0
    for _step in range(TIME_HORIZON):
        # Compute control action via Riccati equation variable 'P'
        u = solver.forward(x)
        # Per-sample stage cost x^T Q x + u^T R u, summed over the batch.
        # The element-wise form keeps only the per-sample quadratic forms;
        # x.T @ Q @ x would build a BATCH x BATCH matrix whose off-diagonal
        # entries mix different samples and would leak into the loss.
        # Assumes u has one column per sample, like x (the matmuls with
        # solver.R and solver.B below rely on the same layout).
        state_cost = (x * torch.matmul(solver.Q, x)).sum()
        action_cost = (u * torch.matmul(solver.R, u)).sum()
        # Accumulate the mean cost per sample and per time step
        loss = loss + (state_cost + action_cost) / (TIME_HORIZON * BATCH_SIZE)
        # Get next system state
        x = torch.matmul(solver.A, x) + torch.matmul(solver.B, u)

    loss.backward()
    optimizer.step()

    if epoch % LOG_EVERY == 0:
        # Compare the gain implied by the current P_sqrt with the analytic one
        P_backprop, K_backprop = _riccati_gain(solver)
        l2_diff = torch.linalg.norm(K - K_backprop)
        print(f"# {epoch}, L2 difference: {l2_diff} \n")

# Report the gain of the *final* parameters (not the last logged epoch);
# this also keeps K_backprop defined when EPOCHS == 0.
P_backprop, K_backprop = _riccati_gain(solver)
print(f"real K : {K}, backprop K : {K_backprop}")