In [147]:
import numpy as np

In [148]:
# state transition: x_{t+1} = A_t x_t + B_t u_t
def calc_xt1(xt, ut, At, Bt):
    """Return the next state x_{t+1} = A_t x_t + B_t u_t of the linear dynamics."""
    drift = At @ xt
    actuation = Bt @ ut
    return drift + actuation

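# "cleanup" matrices F_t, G_t, H_t: the state/control blocks of the quadratic in (x_t, u_t) built from K_{t+1}, plus the update for K_t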

def calc_Ft(Qt, At, Kt1):
    """State-state block F_t = Q_t + A_t^T K_{t+1} A_t."""
    propagated = At.T @ Kt1
    return Qt + propagated @ At

def calc_Gt(Rt, Bt, Kt1):
    """Control-control block G_t = R_t + B_t^T K_{t+1} B_t.

    Rt may also be a plain scalar (as in the single-input example below);
    it then broadcasts onto the (m, m) product.
    """
    projected = Bt.T @ Kt1
    return Rt + projected @ Bt

def calc_Ht(At, Kt1, Bt):
    """Cross (state-control) block H_t = A_t^T K_{t+1} B_t."""
    left = At.T @ Kt1
    return left @ Bt

def calc_Kt(At, Kt1, Bt, Gt, Qt):
    """One backward Riccati step: value matrix K_t from its successor K_{t+1}.

        K_t = A_t^T (K_{t+1} - K_{t+1} B_t G_t^{-1} B_t^T K_{t+1}) A_t + Q_t

    which equals F_t - H_t G_t^{-1} H_t^T for the cleanup matrices above.

    Parameters
    ----------
    At : (n, n) array -- state-transition matrix.
    Kt1 : (n, n) array -- value matrix K_{t+1} of the following step.
    Bt : (n, m) array -- control matrix.
    Gt : (m, m) array -- expected to be R_t + B_t^T K_{t+1} B_t (see calc_Gt).
    Qt : (n, n) array -- state-cost matrix.

    Returns
    -------
    (n, n) array, the value matrix K_t.

    Raises
    ------
    numpy.linalg.LinAlgError if Gt is singular.
    """
    # Solve G_t Y = B_t^T K_{t+1} rather than forming inv(G_t) explicitly:
    # mathematically identical, but better conditioned and cheaper.
    gain_term = np.linalg.solve(Gt, Bt.T @ Kt1)
    return At.T @ (Kt1 - Kt1 @ Bt @ gain_term) @ At + Qt

def cleanup_wrapper(Qt, At, Bt, Rt, Kt1):
    """Compute the cleanup matrices (F_t, G_t, H_t) for one backward step.

    All three are built from the successor value matrix Kt1 = K_{t+1}.
    Side effect: prints the shape of each matrix as a diagnostic.

    Returns
    -------
    tuple (F_t, G_t, H_t)
    """
    matrices = (
        calc_Ft(Qt, At, Kt1),
        calc_Gt(Rt, Bt, Kt1),
        calc_Ht(At, Kt1, Bt),
    )
    captions = ("F_t.shape =", "G_t.shape =", "H_t.shape =")
    for caption, mat in zip(captions, matrices):
        print(caption, mat.shape)
    return matrices

# optimal policy structure: linear feedback gain and quadratic value function

def mut(Gt, Ht, xt=None):
    """Optimal linear feedback policy mu_t(x) = -G_t^{-1} H_t^T x.

    Parameters
    ----------
    Gt : (m, m) array, e.g. from calc_Gt.
    Ht : (n, m) array, e.g. from calc_Ht.
    xt : optional state. When given, the control u_t = gain @ xt is returned.

    Returns
    -------
    The (m, n) gain matrix -G_t^{-1} H_t^T when xt is None, otherwise gain @ xt.

    Raises
    ------
    numpy.linalg.LinAlgError if Gt is singular.
    """
    # solve() instead of an explicit inverse: same gain, better numerics;
    # the gain is computed once for both call forms.
    gain = -np.linalg.solve(Gt, Ht.T)
    return gain if xt is None else gain @ xt

def Vt(x, Kt, alphat=0):
    """Quadratic value function V_t(x) = x^T K_t x + alpha_t.

    For a column-vector x the result is a (1, 1) array; for a 1-D x it is a scalar.
    """
    quadratic = x.T @ Kt @ x
    return quadratic + alphat

def iterate_K(Qt, At, Bt, Rt, KT, T=3):
    """Run the backward Riccati recursion and return all T value matrices.

    Parameters
    ----------
    Qt, At, Bt, Rt : cost / dynamics data, held fixed across all steps.
    KT : terminal value matrix; becomes the last entry of the result.
    T : number of value matrices to return (must be >= 1).

    Returns
    -------
    list of length T. ``K_arr[-1]`` is ``KT``; each earlier entry
    ``K_arr[t]`` is computed from its successor ``K_arr[t + 1]`` via calc_Kt.

    Raises
    ------
    ValueError if T < 1 (previously this surfaced as an obscure IndexError).

    Notes
    -----
    cleanup_wrapper prints the F/G/H shapes on each of the T - 1 steps.
    """
    if T < 1:
        raise ValueError(f"horizon T must be >= 1, got {T}")

    K_arr = [None] * T
    K_arr[-1] = KT

    # Walk backwards from index T-2 down to 0.
    for t in range(T - 2, -1, -1):
        K_next = K_arr[t + 1]
        # Only G_t is needed by the Riccati update; F_t and H_t are discarded.
        _, G_t, _ = cleanup_wrapper(Qt, At, Bt, Rt, K_next)
        K_arr[t] = calc_Kt(At, K_next, Bt, G_t, Qt)

    return K_arr
    

In [149]:
# Problem data: 2-D state, single control input.
A = np.array([[1, 2],
              [0, 1]])        # state-transition matrix

B = np.array([[0],
              [1]])           # control matrix as a (2, 1) column vector

Q = np.diag([1, 0])           # state cost: only the first component is penalized

R = 1                         # scalar control cost (broadcasts inside calc_Gt)

K3 = np.zeros((2, 2))         # terminal value matrix passed as KT to iterate_K

In [150]:
K_arr = iterate_K(Q, A, B, R, K3, T=3)
# One section per value matrix; K_arr[idx] is shown as K_{idx+1}.
for idx, K in enumerate(K_arr):
    print("-"*40, f"K_{idx+1}:", K, sep="\n")

F_t.shape = (2, 2)
G_t.shape = (1, 1)
H_t.shape = (2, 1)
F_t.shape = (2, 2)
G_t.shape = (1, 1)
H_t.shape = (2, 1)
----------------------------------------
K_1:
[[2. 2.]
 [2. 4.]]
----------------------------------------
K_2:
[[1. 0.]
 [0. 0.]]
----------------------------------------
K_3:
[[0. 0.]
 [0. 0.]]


In [151]:
mu_coeffs_arr = []
for idx, K_next in enumerate(K_arr):
    # mu_idx is built from K_arr[idx] (printed above as K_{idx+1}),
    # i.e. the value matrix of the step that follows time idx.
    mu_coeffs = mut(calc_Gt(R, B, K_next), calc_Ht(A, K_next, B))
    mu_coeffs_arr.append(mu_coeffs)
    print(f"mu_{idx}:", mu_coeffs)

mu_0: [[-0.4 -1.6]]
mu_1: [[0. 0.]]
mu_2: [[0. 0.]]


In [152]:
x0 = np.array([[1], [1]])
# Roll the closed loop forward for three steps:
# u_t = mu_t x_t, then x_{t+1} = A x_t + B u_t.
states, controls = [x0], []
for t, label in enumerate(("x1:", "x2:", "x3:")):
    u_t = mu_coeffs_arr[t] @ states[t]
    controls.append(u_t)
    states.append(calc_xt1(states[t], u_t, A, B))
    print(label, states[-1])

# Keep the per-step names bound, as the unrolled version did.
u0, u1, u2 = controls
x1, x2, x3 = states[1:]

x1: [[ 3.]
 [-1.]]
x2: [[ 1.]
 [-1.]]
x3: [[-1.]
 [-1.]]
