In [2]:
%load_ext autoreload
%autoreload 2

In [254]:
import numpy as np
import scipy 
from neurosim.models.ssr import StateSpaceRealization as SSR, gen_random_model
from dca_research.kca import calc_mmse_from_cross_cov_mats
import matplotlib.pyplot as plt

In [4]:
import sys
from tqdm import tqdm
import torch

In [5]:
sys.path.append('../..')
from utils import calc_loadings

### Goal: Implement the formula for the gradient of the exponential of the Hamiltonian matrix

In [5]:
# First step, verify that the formula in terms of blocks of the matrix exponential matches the solution from scipy
# Then, implement in pytorch and compare gradient to explicit formula. 
# Then, see whether the gradient vanishes for eigenvectors of A

In [6]:
def riccati_solve_torch(A, B, Q, C, P, t, Sigma0=None):
    """Solve the filtering Riccati differential equation at time ``t`` via the Hamiltonian matrix exponential.

    Integrates

        dSigma/dt = A Sigma + Sigma A^T - Sigma C^T P^{-1} C Sigma + B Q B^T

    from ``Sigma(0) = Sigma0`` by propagating ``[X; Y]' = H [X; Y]`` with
    ``X(0) = I`` and ``Y(0) = Sigma0``, where

        H = [[-A^T, C^T P^{-1} C], [B Q B^T, A]],

    and returning ``Sigma(t) = Y(t) X(t)^{-1}``. Every operation is
    differentiable, so gradients flow back to inputs with ``requires_grad=True``.

    Parameters
    ----------
    A : (n, n) tensor
    B : (n, m) tensor
    Q : (m, m) tensor, weight in the ``B Q B^T`` term
    C : (p, n) tensor
    P : (p, p) tensor, weight in the ``C^T P^{-1} C`` term
    t : float
        Integration horizon.
    Sigma0 : (n, n) tensor, optional
        Initial condition. Defaults to the identity, which was the only
        behaviour of the original implementation.

    Returns
    -------
    Sigma : (n, n) tensor
    """
    n = A.shape[0]

    # C^T P^{-1} C computed with a solve instead of an explicit inverse
    # (torch.chain_matmul is deprecated).
    H = torch.cat([
        torch.cat([-A.T, C.T @ torch.linalg.solve(P, C)], dim=1),
        torch.cat([B @ Q @ B.T, A], dim=1),
    ], dim=0)
    Phi = torch.matrix_exp(H * t)

    # Partition the propagator into n x n blocks
    Phi11, Phi12 = Phi[:n, :n], Phi[:n, n:]
    Phi21, Phi22 = Phi[n:, :n], Phi[n:, n:]

    # [X(t); Y(t)] = Phi @ [I; Sigma0]
    if Sigma0 is None:
        X = Phi11 + Phi12
        Y = Phi21 + Phi22
    else:
        X = Phi11 + Phi12 @ Sigma0
        Y = Phi21 + Phi22 @ Sigma0

    # Sigma = Y X^{-1}  <=>  X^T Sigma^T = Y^T  (avoids forming X^{-1})
    return torch.linalg.solve(X.T, Y.T).T

In [96]:
# Presumably random 20-state model from neurosim (see gen_random_model for the
# meaning of the remaining positional arguments).
# NOTE(review): no RNG seed is set in this notebook, so the printed outputs below
# will likely not reproduce exactly on a fresh kernel.
A, B, C = gen_random_model(20, 2, 20, True)
# P weights the C^T P^{-1} C term and Q the B Q B^T term in riccati_solve_torch.
# P and Q can also be set to the identity for simplicity
P = np.eye(C.shape[0])
Q = np.eye(B.shape[1])

In [20]:
# Torch copies of the model. Only C is a gradient leaf, so after
# loss.backward() the gradient w.r.t. C is available in Ct.grad.
At = torch.tensor(A)
Bt = torch.tensor(B)
Ct = torch.tensor(C, requires_grad=True)
Pt = torch.tensor(P)
Qt = torch.tensor(Q)

In [21]:
St = riccati_solve_torch(At, Bt, Qt, Ct, Pt, 20)

In [16]:
torch.linalg.eigvals(St)

tensor([ 2.8637e-01+0.0000e+00j,  2.4201e-01+0.0000e+00j,
         6.5059e-03+0.0000e+00j,  5.3278e-03+0.0000e+00j,
         2.0675e-04+0.0000e+00j,  1.0879e-04+0.0000e+00j,
         2.5782e-06+0.0000e+00j,  1.0697e-06+0.0000e+00j,
         5.0332e-08+0.0000e+00j,  1.4891e-08+2.4104e-09j,
         1.4891e-08-2.4104e-09j, -2.9159e-10+8.4879e-09j,
        -2.9159e-10-8.4879e-09j,  5.7782e-09+0.0000e+00j,
        -7.1883e-09+0.0000e+00j, -4.6420e-09+3.8113e-09j,
        -4.6420e-09-3.8113e-09j,  1.8775e-10+7.8625e-10j,
         1.8775e-10-7.8625e-10j, -2.5079e-09+0.0000e+00j],
       dtype=torch.complex128)

In [19]:
np.linalg.eigvals(scipy.linalg.solve_continuous_are(A.T, C.T, B @ Q @ B.T, P))

array([ 2.86374934e-01,  2.42013261e-01,  6.50585532e-03,  5.32780943e-03,
        2.06717893e-04,  1.08777074e-04,  2.57029947e-06,  1.06513128e-06,
        4.26309777e-08,  1.63997503e-08,  6.18588015e-10,  1.38343146e-10,
        2.42175674e-12,  1.08104000e-12,  2.38413897e-14,  6.11796199e-15,
        2.93579288e-16, -1.73360130e-16, -5.72109684e-17, -6.74857990e-20])

In [22]:
loss = torch.trace(St)

In [23]:
loss.backward()

In [24]:
Ct.grad

tensor([[  8.4888,  -3.8430,   6.9102, -11.8985,  10.7013, -18.7788, -22.5299,
         -28.2243,  19.4963,   8.4663,  -1.1047,  12.4422, -17.5302, -14.5567,
         -11.7259,  13.0519, -38.0177, -16.7474,  -0.6369,  10.6166],
        [ 18.3191,  10.4250, -17.8094,   2.9947, -15.7714, -23.7243,  -2.8834,
         -31.2822, -45.9174,   2.2860,   0.8855,   0.6905,  29.1203, -33.8856,
          17.0007,  40.3078,  17.2733, -26.9402,  15.5793,  13.0783]],
       dtype=torch.float64)

In [25]:
# We should first test against the matrix exponential

In [None]:
# Re-create the torch copies: Ct becomes a new leaf tensor with .grad = None,
# so gradients accumulated by the previous backward() do not carry over.
At = torch.tensor(A)
Bt = torch.tensor(B)
Ct = torch.tensor(C, requires_grad=True)
Pt = torch.tensor(P)
Qt = torch.tensor(Q)

In [None]:
# Gradient test through torch.matrix_exp alone. The Hamiltonian is built from the
# torch tensors of the previous cell: the numpy arrays A, B, C have no .t() method
# and would not track gradients w.r.t. Ct.
t = 20  # same horizon as the riccati_solve_torch call above
H = torch.cat([
    torch.cat([-At.T, Ct.T @ torch.linalg.solve(Pt, Ct)], dim=1),
    torch.cat([Bt @ Qt @ Bt.T, At], dim=1),
], dim=0)
Phi = torch.matrix_exp(H * t)
loss = torch.trace(Phi)

In [None]:
# First, compute dH/dC, then
# plug it into the exponential formula

In [1]:
# Are the Riccati solutions inverses of each other as Jockhertee claims?

In [101]:
# P1: stabilizing solution of the filtering CARE (a=A^T, b=C^T, q=B B^T, r=I).
# P2: stabilizing solution of the dual control CARE (a=A, b=B, q=C^T C, r=I).
P1 = scipy.linalg.solve_continuous_are(A.T, C.T, B @ B.T, np.eye(C.shape[0]))
P2 = scipy.linalg.solve_continuous_are(A, B, C.T @ C, np.eye(B.shape[1]))

In [12]:
import pdb

In [44]:
# Solve the continuous-time ARE using an ordered generalized Schur (QZ) decomposition of the Hamiltonian
def solve_are(A, B, Q, R, stable=True):
    """Solve the continuous-time algebraic Riccati equation from an invariant subspace of the Hamiltonian.

    Solves  A^T X + X A - X B R^{-1} B^T X + Q = 0  using an ordered QZ
    decomposition of the pencil (H, I), where

        H = [[A, -B R^{-1} B^T], [-Q, -A^T]].

    If the leading n columns [Z1; Z2] of the right Schur vectors span the
    selected invariant subspace of H, the solution is X = Z2 Z1^{-1}.

    Parameters
    ----------
    A : (n, n) array
    B : (n, m) array
    Q : (n, n) array
    R : (m, m) array
    stable : bool, default True
        True selects the eigenvalues of H in the open left half-plane
        (stabilizing solution); False selects the right half-plane
        (anti-stabilizing solution).

    Returns
    -------
    X : (n, n) array
    """
    n = A.shape[0]

    # Hamiltonian matrix; B R^{-1} B^T via a solve rather than an explicit inverse
    H = np.block([[A, -B @ np.linalg.solve(R, B.T)], [-Q, -A.T]])

    # QZ of (H, I): because the second matrix is I, the leading columns of Z span
    # an invariant subspace of H. The outputs are not bound to Q so the
    # parameter of the same name is not shadowed.
    _, _, _, _, _, Z = scipy.linalg.ordqz(H, np.eye(2 * n), sort='lhp' if stable else 'rhp')

    Z1 = Z[:n, :n]
    Z2 = Z[n:, :n]

    # X = Z2 Z1^{-1}  <=>  Z1^T X^T = Z2^T
    return np.linalg.solve(Z1.T, Z2.T).T
    

In [172]:
def solve_dare(A, B, Q, R, stable=True):
    """Solve the discrete-time algebraic Riccati equation from a deflating subspace of the symplectic pencil.

    Solves  A^T X A - X - A^T X B (R + B^T X B)^{-1} B^T X A + Q = 0  using an
    ordered QZ decomposition of the pencil  L - lambda M  with

        L = [[A, 0], [-Q, I]],    M = [[I, B R^{-1} B^T], [0, A^T]].

    If the leading n columns [Z1; Z2] of the right Schur vectors span the
    selected deflating subspace, the solution is X = Z2 Z1^{-1}.

    Parameters
    ----------
    A : (n, n) array
    B : (n, m) array
    Q : (n, n) array
    R : (m, m) array
    stable : bool, default True
        True selects generalized eigenvalues inside the unit circle
        (stabilizing solution); False selects those outside it.

    Returns
    -------
    X : (n, n) array
    """
    n = A.shape[0]
    zeros = np.zeros((n, n))
    eye = np.eye(n)

    # Matrix pencil; B R^{-1} B^T via a solve rather than an explicit inverse
    L = np.block([[A, zeros], [-Q, eye]])
    M = np.block([[eye, B @ np.linalg.solve(R, B.T)], [zeros, A.T]])

    # Outputs are not bound to Q so the parameter of the same name is not shadowed
    _, _, _, _, _, Z = scipy.linalg.ordqz(L, M, sort='iuc' if stable else 'ouc')

    Z1 = Z[:n, :n]
    Z2 = Z[n:, :n]

    # X = Z2 Z1^{-1}  <=>  Z1^T X^T = Z2^T
    return np.linalg.solve(Z1.T, Z2.T).T



In [116]:
# Solve using the matrix sign function. The advantage of this is that one gets all 4 matrices of interest in one go

# See: Positive and negative solutions of dual Riccati equations by matrix sign function iteration
def sgn(H, max_iter=500, tol=1e-12):
    """Matrix sign function of ``H`` by the (unscaled) Newton iteration.

    Iterates  Z <- (Z + Z^{-1}) / 2  starting from ``Z = H``. The iteration
    converges when ``H`` has no eigenvalues on the imaginary axis. Near an
    imaginary-axis eigenvalue ``Z`` can become singular or converge very slowly.

    Parameters
    ----------
    H : (n, n) array
    max_iter : int, default 500
        Maximum number of Newton steps. 500 was the original fixed count.
    tol : float or None, default 1e-12
        Stop once the relative change ||Z_{k+1} - Z_k||_1 / ||Z_{k+1}||_1 is at
        most ``tol``. Newton converges quadratically near the fixed point, so the
        returned iterate is then at rounding-level accuracy and further steps
        only cost matrix inversions. ``None`` always runs ``max_iter`` steps.

    Returns
    -------
    Z : (n, n) array approximating sign(H)
    """
    Z = np.asarray(H)
    for _ in range(max_iter):
        Z_next = 1/2 * (Z + np.linalg.inv(Z))
        converged = tol is not None and (
            np.linalg.norm(Z_next - Z, 1) <= tol * np.linalg.norm(Z_next, 1)
        )
        Z = Z_next
        if converged:
            break
    return Z

def solve_are_sgn(A, B, C):
    """Riccati solutions of the LQR problem and its dual from a single matrix sign computation.

    Forms the Hamiltonian  H = [[A, -B B^T], [-C^T C, -A^T]]  and evaluates
    S = sgn(H) once. Each returned matrix is the least-squares (pseudo-inverse)
    solution X of

        [Z12; Z22 +/- I] X = -[Z11 +/- I; Z21],

    with Z = S for (Pp, Pm) and Z = S^T = sign(H^T) for (Qp, Qm). The "+"
    variants use "+ I" and the "-" variants "- I". H^T is associated with the
    filtering problem. Method: "Positive and negative solutions of dual Riccati
    equations by matrix sign function iteration".

    Parameters
    ----------
    A : (n, n) array
    B : (n, m) array
    C : (p, n) array

    Returns
    -------
    Pp, Pm, Qp, Qm : (n, n) arrays
    """
    n = A.shape[0]
    eye = np.eye(n)

    # Hamiltonian matrix associated with the LQR problem. The transpose is
    # associated with the filtering problem
    H = np.block([[A, -B @ B.T], [-C.T @ C, -A.T]])
    S = sgn(H)

    def _plus_minus(Z):
        # Partition Z and solve the two stacked least-squares systems (+I and -I).
        Z11, Z12 = Z[:n, :n], Z[:n, n:]
        Z21, Z22 = Z[n:, :n], Z[n:, n:]
        plus = -1 * scipy.linalg.pinv(np.block([[Z12], [Z22 + eye]])) @ np.block([[Z11 + eye], [Z21]])
        minus = -1 * scipy.linalg.pinv(np.block([[Z12], [Z22 - eye]])) @ np.block([[Z11 - eye], [Z21]])
        return plus, minus

    Pp, Pm = _plus_minus(S)
    # sign(H^T) = sign(H)^T, so the expensive sign iteration is not repeated.
    Qp, Qm = _plus_minus(S.T)

    return Pp, Pm, Qp, Qm


In [105]:
Pp, Pm, Qp, Qm = solve_are_sgn(A, B, C)

In [107]:
np.allclose(P1, Qp)

True

In [108]:
np.linalg.cond(Pp)

1.5789733726932968e+16

In [109]:
np.linalg.cond(Pm)

2.4137913499767225

In [113]:
np.allclose(np.linalg.inv(Pm), -1 * Qp)

True

In [114]:
np.allclose(np.linalg.inv(Qm), -1 * Pp)

False

In [115]:
# One gets the right answer when the condition numbers are reasonable...

In [248]:
# Discrete time testing
A, B, C = gen_random_model(20)
# Replace A by its symmetric part; a real symmetric matrix has a real spectrum
# (see the eigenvalues printed below).
A = 1/2 * (A + A.T)
# Alternative: diagonal A with entries drawn uniformly from [0, 1)
#A = np.diag(np.random.uniform(0, 1, size=(20,)))

In [239]:
np.linalg.eigvals(A)

array([-0.81726449, -0.64745509, -0.52170145, -0.42822448, -0.38401076,
       -0.32162278,  0.74599143,  0.6755151 ,  0.60637773,  0.53613787,
       -0.15385582, -0.09038145, -0.05293723,  0.03307939,  0.13363644,
        0.43073073,  0.19308675,  0.31883503,  0.34010362,  0.34728985])

In [220]:
# P1: stabilizing DARE solution of the control problem (a=A, b=B, q=C^T C, r=I)
P1 = scipy.linalg.solve_discrete_are(A, B, C.T @ C, np.eye(B.shape[1]))
# Dual solution: filtering problem (a=A^T, b=C^T, q=B B^T, r=I)
P2 = scipy.linalg.solve_discrete_are(A.T, C.T, B @ B.T, np.eye(C.shape[0]))

In [249]:
# Same two DAREs solved with solve_dare. NOTE: Q1/Q2 here are Riccati solutions,
# unrelated to the noise weight Q defined earlier.
# The *m variants pass stable=False, i.e. use the outside-unit-circle subspace.
Q1 = solve_dare(A, B, C.T @ C, np.eye(B.shape[1]))
Q1m = solve_dare(A, B, C.T @ C, np.eye(B.shape[1]), False)

Q2 = solve_dare(A.T, C.T, B @ B.T, np.eye(C.shape[0]))
Q2m = solve_dare(A.T, C.T, B @ B.T, np.eye(C.shape[0]), False)

In [244]:
np.trace(Q1)

23.96079954242218

In [250]:
Q1 @ np.linalg.inv(Q1)

array([[ 1.00000000e+00, -7.28672907e-19,  2.85582777e-19,
         2.07835593e-18,  3.30855785e-18,  1.15044810e-18,
        -1.76984303e-19,  7.96621250e-19,  4.27226569e-18,
        -7.32195994e-20, -2.76888781e-18, -1.11229585e-18,
        -1.62630326e-19,  6.93889390e-18, -3.79470760e-19,
         6.50521303e-19, -2.71050543e-19,  4.66206934e-18,
        -1.18313562e-17,  0.00000000e+00],
       [-1.19596865e-19,  1.00000000e+00, -1.84662694e-18,
         4.58713837e-18, -5.30180699e-18, -1.20060910e-17,
         4.02776303e-19,  2.52014070e-18,  3.26424481e-19,
        -3.18476876e-18, -1.93057936e-18,  3.53306939e-19,
         1.78893358e-18,  9.10729825e-18,  2.22261445e-18,
        -5.42101086e-19,  2.00577402e-18,  4.33680869e-19,
        -1.69067776e-17,  3.46944695e-18],
       [ 1.56082944e-17,  5.27078295e-19,  1.00000000e+00,
        -2.17699562e-18, -1.11974447e-18, -3.53527319e-18,
         9.81620720e-18,  2.05426797e-18,  2.71747945e-18,
         4.17497864e-19, -8.0

In [246]:
np.trace(np.linalg.inv(Q1m))

-23.96079954242218

In [231]:
np.allclose(np.linalg.inv(Q1m), -1*P2)

True

In [319]:
A, B, C = gen_random_model(20)

In [320]:
ssr = SSR(A, B, C)

In [330]:
ccm = ssr.autocorrelation(5)

In [322]:
from dca.cov_util import calc_cov_from_cross_cov_mats

In [335]:
calc_mmse_from_cross_cov_mats(torch.tensor(ccm).float(), proj=torch.eye(ccm.shape[1]).float())

tensor(20.)

In [336]:
ssr.solve_min_phase()

In [290]:
X = np.block([[covf.numpy(), covpf.numpy().T], [covpf.numpy(), covp.numpy()]])

In [276]:
np.linalg.eigvals(covf.numpy())

array([2.32110202, 1.58520366])

In [277]:
np.linalg.eigvals(covp.numpy())

array([4.7620994 , 4.73554397, 3.95362946, 3.90252817, 3.21494589,
       3.16970859, 2.69407875, 2.64382498, 1.85947067, 2.39583849,
       2.34771529, 2.32758591, 2.32135343, 2.06495108, 2.2485509 ,
       2.20993624, 2.14318295, 2.1417027 , 1.72135523, 1.50840663,
       1.3169601 , 1.16886422, 1.14774686, 1.14580719, 1.12512732,
       1.13924207, 1.13645881, 1.13444905, 1.09242932, 1.08197266,
       1.05989906, 1.0527829 , 1.03637703, 1.03723791, 1.04083439,
       1.04137276, 1.04781372, 1.04802378])

In [279]:
torch.trace(covf - torch.chain_matmul(covpf.t(), torch.inverse(covp), covpf))

tensor(-1.8538, dtype=torch.float64)

In [287]:
np.linalg.cond(covp.numpy())

4.5949488271275865

In [286]:
covpf.t() @ torch.inverse(covp) @ covpf

tensor([[ 3.0336, -0.5099],
        [-0.5099,  2.7265]], dtype=torch.float64)

In [281]:
torch.trace(covf - covpf.t() @ torch.inverse(covp) @ covpf)

tensor(-1.8538, dtype=torch.float64)

In [284]:
np.trace(covf.numpy() - covpf.numpy().T @ np.linalg.inv(covp.numpy()) @ covpf.numpy())

-1.8537728993631153

In [299]:
# Does forward Riccati equation converge to MMSE
# 20 cross-covariance matrices of the model output (shape (20, 2, 2), see below)
ccm = ssr.autocorrelation(20)
# Block covariance built from them; presumably of 20 consecutive samples — confirm
# against dca.cov_util.calc_cov_from_cross_cov_mats
cov = calc_cov_from_cross_cov_mats(ccm)

In [311]:
np.concatenate([c[np.newaxis, :] for c in ccm]).shape

(20, 2, 2)

In [304]:
# MMSE forward
def mmse_forward(ccm, proj=None):
    """Error covariance of the linear MMSE prediction of the last block from the preceding ones.

    Builds the block covariance of the T+1 blocks described by ``ccm`` with
    ``calc_cov_from_cross_cov_mats`` and returns the Schur complement

        cov_ff - cov_fp cov_pp^{-1} cov_pf,

    where "f" is the last N x N block (presumably the future sample) and "p"
    the first T blocks (presumably the past) — confirm the ordering against
    ``calc_cov_from_cross_cov_mats``.

    Parameters
    ----------
    ccm : array of shape (T+1, N, N)
        Cross-covariance matrices.
    proj : array of shape (N, d), optional
        If given, the MMSE is computed for the projected process, whose
        cross-covariance matrices are ``proj.T @ ccm[k] @ proj``.

    Returns
    -------
    array of shape (N, N), or (d, d) when ``proj`` is given
    """
    if proj is not None:
        ccm = np.stack([proj.T @ c @ proj for c in ccm])

    T = ccm.shape[0] - 1
    N = ccm.shape[-1]
    cov = calc_cov_from_cross_cov_mats(ccm)

    covf = cov[-N:, -N:]
    if T == 0:
        # Nothing to condition on: the error covariance is the marginal one.
        return covf

    covp = cov[:T * N, :T * N]
    covpf = cov[:T * N, -N:]
    covfp = cov[-N:, :T * N]

    # covfp covp^{-1} covpf via a solve instead of an explicit inverse
    return covf - covfp @ np.linalg.solve(covp, covpf)


def mmse_reverse(ccm):
    """Placeholder, presumably for the reverse-time counterpart of ``mmse_forward``.

    Not implemented yet. It raises instead of silently returning ``None``, so a
    caller cannot mistake the missing result for a computed value.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("mmse_reverse is not implemented yet")

In [305]:
mmse_forward(ccm)

array([[ 1.88057021, -0.38751228],
       [-0.38751228,  1.63696645]])

In [307]:
ssr.solve_min_phase()