In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import scipy 
from neurosim.models.ssr import StateSpaceRealization as SSR, gen_random_model
from dca_research.kca import calc_mmse_from_cross_cov_mats
import matplotlib.pyplot as plt

In [4]:
import sys
from tqdm import tqdm
import torch

In [5]:
sys.path.append('../..')
from utils import calc_loadings

### Goal: Implement formula for gradient of exponential of the Hamiltonian matrix

In [5]:
# First step, verify that the formula in terms of blocks of the matrix exponential matches the solution from scipy
# Then, implement in pytorch and compare gradient to explicit formula. 
# Then, see whether the gradient vanishes for eigenvectors of A

In [6]:
def riccati_solve_torch(A, B, Q, C, P, t):
    """Evaluate a Riccati-type solution at time ``t`` from the Hamiltonian matrix exponential.

    Builds the 2n x 2n matrix ``H = [[-A^T, C^T P^{-1} C], [B Q B^T, A]]``,
    computes ``Phi = exp(H * t)`` and returns
    ``(Phi21 + Phi22) @ inv(Phi11 + Phi12)``, where ``Phi_ij`` are the n x n
    blocks of ``Phi``. Everything is done with differentiable torch ops, so
    gradients can flow back to any input that has ``requires_grad=True``.

    Args:
        A: square (n, n) tensor.
        B, Q: tensors such that ``B @ Q @ B.T`` is (n, n).
        C, P: tensors such that ``C.T @ inv(P) @ C`` is (n, n).
        t: scalar multiplying ``H`` before exponentiation.

    Returns:
        (n, n) tensor ``Sigma``.
    """
    n = A.shape[0]

    # Linear solves instead of explicit inverses; `@` instead of the
    # deprecated torch.chain_matmul.
    upper_right = C.t() @ torch.linalg.solve(P, C)
    lower_left = B @ Q @ B.t()

    H = torch.cat([torch.cat([-A.t(), upper_right], dim=1),
                   torch.cat([lower_left, A], dim=1)], dim=0)
    Phi = torch.matrix_exp(H * t)

    # Partition the exponential into n x n blocks
    Phi11 = Phi[0:n, 0:n]
    Phi12 = Phi[0:n, n:]
    Phi21 = Phi[n:, 0:n]
    Phi22 = Phi[n:, n:]

    numer = Phi21 + Phi22
    denom = Phi11 + Phi12
    # Sigma = numer @ denom^{-1}  <=>  denom^T Sigma^T = numer^T
    Sigma = torch.linalg.solve(denom.t(), numer.t()).t()
    return Sigma

In [96]:
A, B, C = gen_random_model(20, 2, 20, True)
# P and Q can also be set to the identity for simplicity
P = np.eye(C.shape[0])
Q = np.eye(B.shape[1])

In [20]:
At = torch.tensor(A)
Bt = torch.tensor(B)
Ct = torch.tensor(C, requires_grad=True)
Pt = torch.tensor(P)
Qt = torch.tensor(Q)

In [21]:
St = riccati_solve_torch(At, Bt, Qt, Ct, Pt, 20)

In [16]:
torch.linalg.eigvals(St)

tensor([ 2.8637e-01+0.0000e+00j,  2.4201e-01+0.0000e+00j,
         6.5059e-03+0.0000e+00j,  5.3278e-03+0.0000e+00j,
         2.0675e-04+0.0000e+00j,  1.0879e-04+0.0000e+00j,
         2.5782e-06+0.0000e+00j,  1.0697e-06+0.0000e+00j,
         5.0332e-08+0.0000e+00j,  1.4891e-08+2.4104e-09j,
         1.4891e-08-2.4104e-09j, -2.9159e-10+8.4879e-09j,
        -2.9159e-10-8.4879e-09j,  5.7782e-09+0.0000e+00j,
        -7.1883e-09+0.0000e+00j, -4.6420e-09+3.8113e-09j,
        -4.6420e-09-3.8113e-09j,  1.8775e-10+7.8625e-10j,
         1.8775e-10-7.8625e-10j, -2.5079e-09+0.0000e+00j],
       dtype=torch.complex128)

In [19]:
np.linalg.eigvals(scipy.linalg.solve_continuous_are(A.T, C.T, B @ Q @ B.T, P))

array([ 2.86374934e-01,  2.42013261e-01,  6.50585532e-03,  5.32780943e-03,
        2.06717893e-04,  1.08777074e-04,  2.57029947e-06,  1.06513128e-06,
        4.26309777e-08,  1.63997503e-08,  6.18588015e-10,  1.38343146e-10,
        2.42175674e-12,  1.08104000e-12,  2.38413897e-14,  6.11796199e-15,
        2.93579288e-16, -1.73360130e-16, -5.72109684e-17, -6.74857990e-20])

In [22]:
loss = torch.trace(St)

In [23]:
loss.backward()

In [24]:
Ct.grad

tensor([[  8.4888,  -3.8430,   6.9102, -11.8985,  10.7013, -18.7788, -22.5299,
         -28.2243,  19.4963,   8.4663,  -1.1047,  12.4422, -17.5302, -14.5567,
         -11.7259,  13.0519, -38.0177, -16.7474,  -0.6369,  10.6166],
        [ 18.3191,  10.4250, -17.8094,   2.9947, -15.7714, -23.7243,  -2.8834,
         -31.2822, -45.9174,   2.2860,   0.8855,   0.6905,  29.1203, -33.8856,
          17.0007,  40.3078,  17.2733, -26.9402,  15.5793,  13.0783]],
       dtype=torch.float64)

In [25]:
# We should first test against the matrix exponential

In [None]:
At = torch.tensor(A)
Bt = torch.tensor(B)
Ct = torch.tensor(C, requires_grad=True)
Pt = torch.tensor(P)
Qt = torch.tensor(Q)

In [None]:
# Build the Hamiltonian from the torch tensors (At, Bt, Ct, Pt, Qt) so that
# autograd can track the dependence on Ct; the NumPy arrays A, B, C, P, Q
# cannot be passed to torch.cat / .t().
t = 20  # same horizon as the riccati_solve_torch call earlier in the notebook
H = torch.cat([
    torch.cat([-At.T, Ct.T @ torch.inverse(Pt) @ Ct], dim=1),
    torch.cat([Bt @ Qt @ Bt.T, At], dim=1),
], dim=0)
Phi = torch.matrix_exp(H * t)
loss = torch.trace(Phi)

In [None]:
# First, need dH/dC, then
# plug into the exponential formula

In [1]:
# Are the Riccati solutions inverses of each other, as Jonckheere claims?

In [101]:
P1 = scipy.linalg.solve_continuous_are(A.T, C.T, B @ B.T, np.eye(C.shape[0]))
P2 = scipy.linalg.solve_continuous_are(A, B, C.T @ C, np.eye(B.shape[1]))

In [12]:
import pdb

In [40]:
# Solve the ARE using an ordered (generalized) Schur decomposition
def solve_are(A, B, Q, R, stable=True):
    """Solve the continuous-time algebraic Riccati equation via an ordered QZ decomposition.

    Forms the Hamiltonian ``H = [[A, -B R^{-1} B^T], [-Q, -A^T]]``, reorders
    the QZ decomposition of the pencil ``(H, I)`` so that the selected
    eigenvalues come first, and returns ``Z2 @ inv(Z1)`` where ``[Z1; Z2]``
    are the leading ``n`` columns of the right Schur vectors.

    Args:
        A: (n, n) array.
        B: (n, m) array.
        Q: (n, n) array.
        R: (m, m) invertible array.
        stable: if True, select left-half-plane eigenvalues (``sort='lhp'``),
            otherwise right-half-plane eigenvalues (``sort='rhp'``).

    Returns:
        (n, n) array ``X`` built from the selected deflating subspace.
    """
    n = A.shape[0]

    # Hamiltonian matrix; a solve replaces the explicit inverse of R
    G = B @ np.linalg.solve(R, B.T)
    H = np.block([[A, -G], [-Q, -A.T]])

    region = 'lhp' if stable else 'rhp'
    # Only the right Schur vectors are needed. Unpacking into throwaway names
    # avoids clobbering the `Q` argument as the original did.
    *_unused, Z = scipy.linalg.ordqz(H, np.eye(2 * n), sort=region)

    Z1 = Z[:n, :n]
    Z2 = Z[n:, :n]

    # X = Z2 @ Z1^{-1}  <=>  Z1^T X^T = Z2^T
    return np.linalg.solve(Z1.T, Z2.T).T
    

In [10]:
def solve_dare(A, B, Q, R, stable=True):
    """Solve the discrete-time algebraic Riccati equation via an ordered QZ decomposition.

    Forms the matrix pencil
    ``L = [[A, 0], [-Q, I]]``, ``M = [[I, B R^{-1} B^T], [0, A^T]]``,
    reorders its QZ decomposition so that the selected generalized eigenvalues
    come first, and returns ``Z2 @ inv(Z1)`` where ``[Z1; Z2]`` are the leading
    ``n`` columns of the right Schur vectors.

    Args:
        A: (n, n) array.
        B: (n, m) array.
        Q: (n, n) array.
        R: (m, m) invertible array.
        stable: if True, select eigenvalues inside the unit circle
            (``sort='iuc'``), otherwise outside (``sort='ouc'``).

    Returns:
        (n, n) array ``X`` built from the selected deflating subspace.
    """
    n = A.shape[0]
    eye_n = np.eye(n)
    zeros_n = np.zeros((n, n))

    # Matrix pencil (L, M); a solve replaces the explicit inverse of R
    G = B @ np.linalg.solve(R, B.T)
    L = np.block([[A, zeros_n], [-Q, eye_n]])
    M = np.block([[eye_n, G], [zeros_n, A.T]])

    region = 'iuc' if stable else 'ouc'
    # Only the right Schur vectors are needed. Unpacking into throwaway names
    # avoids clobbering the `Q` argument and the pencil as the original did.
    *_unused, Z = scipy.linalg.ordqz(L, M, sort=region)

    Z1 = Z[:n, :n]
    Z2 = Z[n:, :n]

    # X = Z2 @ Z1^{-1}  <=>  Z1^T X^T = Z2^T
    return np.linalg.solve(Z1.T, Z2.T).T



In [12]:
# Solve using the matrix sign function. The advantage of this is that one gets all 4 matrices of interest in one go

# See: Positive and negative solutions of dual Riccati equations by matrix sign function iteration
def sgn(H, max_iter=500, tol=1e-12):
    """Matrix sign function of ``H`` by the Newton iteration ``Z <- (Z + Z^{-1}) / 2``.

    Args:
        H: square, invertible array.
        max_iter: maximum number of Newton steps (the original fixed count
            of 500 is kept as the default).
        tol: relative Frobenius-norm change below which the iteration is
            considered converged and stops early.

    Returns:
        Array of the same shape as ``H`` holding the final iterate.
    """
    Z = H
    for _ in range(max_iter):
        Z_next = 0.5 * (Z + np.linalg.inv(Z))
        step = np.linalg.norm(Z_next - Z)
        scale = np.linalg.norm(Z_next)
        Z = Z_next
        # Stop once the update is negligible relative to the iterate, so
        # converged cases skip the remaining inversions.
        if step <= tol * scale:
            break
    return Z

def solve_are_sgn(A, B, C):
    """Compute four Riccati-type matrices from the matrix sign of a Hamiltonian.

    Builds ``H = [[A, -B B^T], [-C^T C, -A^T]]`` (the Hamiltonian associated
    with the LQR problem; its transpose is associated with the filtering
    problem), evaluates ``Z = sgn(H)`` once, and solves two least-squares
    systems from the blocks of ``Z`` and two from the blocks of ``Z.T``.

    Args:
        A: (n, n) array.
        B: array such that ``B @ B.T`` is (n, n).
        C: array such that ``C.T @ C`` is (n, n).

    Returns:
        Tuple ``(Pp, Pm, Qp, Qm)`` of (n, n) arrays: ``Pp``/``Pm`` come from
        ``sgn(H)`` using ``+I`` / ``-I`` respectively, ``Qp``/``Qm`` likewise
        from ``sgn(H).T``.
    """
    n = A.shape[0]
    eye_n = np.eye(n)

    def _solutions_from_sign(S):
        # Least-squares solve for the "+I" and "-I" systems built from the blocks of S
        S11, S12 = S[0:n, 0:n], S[0:n, n:]
        S21, S22 = S[n:, 0:n], S[n:, n:]
        plus = -1 * scipy.linalg.pinv(np.block([[S12], [S22 + eye_n]])) @ np.block([[S11 + eye_n], [S21]])
        minus = -1 * scipy.linalg.pinv(np.block([[S12], [S22 - eye_n]])) @ np.block([[S11 - eye_n], [S21]])
        return plus, minus

    # Hamiltonian matrix associated with the LQR problem. The transpose is
    # associated with the filtering problem
    H = np.block([[A, -B @ B.T], [-C.T @ C, -A.T]])

    # The sign iteration is the expensive step: run it once and reuse its
    # transpose instead of calling sgn(H) a second time.
    Z = sgn(H)

    Pp, Pm = _solutions_from_sign(Z)
    Qp, Qm = _solutions_from_sign(Z.T)

    return Pp, Pm, Qp, Qm


In [105]:
Pp, Pm, Qp, Qm = solve_are_sgn(A, B, C)

In [107]:
np.allclose(P1, Qp)

True

In [108]:
np.linalg.cond(Pp)

1.5789733726932968e+16

In [109]:
np.linalg.cond(Pm)

2.4137913499767225

In [113]:
np.allclose(np.linalg.inv(Pm), -1 * Qp)

True

In [114]:
np.allclose(np.linalg.inv(Qm), -1 * Pp)

False

In [115]:
# One gets the right answer when the condition numbers are reasonable...

In [6]:
# Discrete time testing
A, B, C = gen_random_model(20)
A = 1/2 * (A + A.T)
#A = np.diag(np.random.uniform(0, 1, size=(20,)))

In [7]:
np.linalg.eigvals(A)

array([-0.96019059, -0.6924917 , -0.6074098 ,  0.77639008,  0.7313606 ,
        0.64905692,  0.57270791,  0.4932683 , -0.44861674, -0.35492999,
       -0.32065521, -0.33719855,  0.3667119 ,  0.27059249,  0.18798358,
        0.10707903, -0.12402855, -0.07472079,  0.0079528 , -0.02754021])

In [8]:
P1 = scipy.linalg.solve_discrete_are(A, B, C.T @ C, np.eye(B.shape[1]))
# Dual solution
P2 = scipy.linalg.solve_discrete_are(A.T, C.T, B @ B.T, np.eye(C.shape[0]))

In [13]:
Q1 = solve_dare(A, B, C.T @ C, np.eye(B.shape[1]))
Q1m = solve_dare(A, B, C.T @ C, np.eye(B.shape[1]), False)

Q2 = solve_dare(A.T, C.T, B @ B.T, np.eye(C.shape[0]))
Q2m = solve_dare(A.T, C.T, B @ B.T, np.eye(C.shape[0]), False)

In [244]:
np.trace(Q1)

23.96079954242218

In [15]:
Q1 @ np.linalg.inv(Q1m)

array([[-1.42745780e+00, -3.04567103e-02, -7.68949693e-03,
         5.80212970e-02, -4.01651083e-02,  1.03878578e-01,
         9.34401591e-02,  1.29243807e-04, -3.97534861e-02,
        -1.73586496e-02,  4.01350340e-03, -1.39432679e-01,
         1.79311585e-01,  5.54353834e-02, -3.74970455e-02,
        -3.39647321e-02, -1.04107520e-01,  7.79831441e-02,
        -1.94677134e-02, -1.81056539e-01],
       [-3.04567103e-02, -1.32778868e+00,  2.39289449e-02,
         1.47664498e-01, -5.37147493e-02, -1.04300433e-01,
         8.23898424e-02, -4.46758957e-02, -3.18291185e-02,
         3.96018785e-02,  1.52111741e-02, -8.67200057e-03,
         6.49942986e-02, -8.89126055e-02, -5.09733906e-02,
        -4.30826995e-02,  7.78529368e-02,  1.12699715e-01,
        -7.04759628e-02,  4.59843115e-02],
       [-7.68949693e-03,  2.39289449e-02, -1.30626055e+00,
        -7.25909128e-02,  2.84914333e-02,  1.14476327e-01,
         7.50409796e-02, -1.77503293e-02, -4.74116922e-02,
        -2.69873828e-02,  1.7

In [246]:
np.trace(np.linalg.inv(Q1m))

-23.96079954242218

In [16]:
np.allclose(np.linalg.inv(Q1m), -1*P2)

True

In [319]:
A, B, C = gen_random_model(20)

In [320]:
ssr = SSR(A, B, C)

In [330]:
ccm = ssr.autocorrelation(5)

In [322]:
from dca.cov_util import calc_cov_from_cross_cov_mats

In [335]:
calc_mmse_from_cross_cov_mats(torch.tensor(ccm).float(), proj=torch.eye(ccm.shape[1]).float())

tensor(20.)

In [336]:
ssr.solve_min_phase()

In [290]:
X = np.block([[covf.numpy(), covpf.numpy().T], [covpf.numpy(), covp.numpy()]])

In [276]:
np.linalg.eigvals(covf.numpy())

array([2.32110202, 1.58520366])

In [277]:
np.linalg.eigvals(covp.numpy())

array([4.7620994 , 4.73554397, 3.95362946, 3.90252817, 3.21494589,
       3.16970859, 2.69407875, 2.64382498, 1.85947067, 2.39583849,
       2.34771529, 2.32758591, 2.32135343, 2.06495108, 2.2485509 ,
       2.20993624, 2.14318295, 2.1417027 , 1.72135523, 1.50840663,
       1.3169601 , 1.16886422, 1.14774686, 1.14580719, 1.12512732,
       1.13924207, 1.13645881, 1.13444905, 1.09242932, 1.08197266,
       1.05989906, 1.0527829 , 1.03637703, 1.03723791, 1.04083439,
       1.04137276, 1.04781372, 1.04802378])

In [279]:
torch.trace(covf - torch.chain_matmul(covpf.t(), torch.inverse(covp), covpf))

tensor(-1.8538, dtype=torch.float64)

In [287]:
np.linalg.cond(covp.numpy())

4.5949488271275865

In [286]:
covpf.t() @ torch.inverse(covp) @ covpf

tensor([[ 3.0336, -0.5099],
        [-0.5099,  2.7265]], dtype=torch.float64)

In [281]:
torch.trace(covf - covpf.t() @ torch.inverse(covp) @ covpf)

tensor(-1.8538, dtype=torch.float64)

In [284]:
np.trace(covf.numpy() - covpf.numpy().T @ np.linalg.inv(covp.numpy()) @ covpf.numpy())

-1.8537728993631153

In [299]:
# Does forward Riccati equation converge to MMSE
ccm = ssr.autocorrelation(20)
cov = calc_cov_from_cross_cov_mats(ccm)

In [311]:
np.concatenate([c[np.newaxis, :] for c in ccm]).shape

(20, 2, 2)

In [304]:
# MMSE forward
def mmse_forward(ccm, proj=None):
    """Prediction-error covariance of the last time block given all earlier blocks.

    The stacked covariance is built with ``calc_cov_from_cross_cov_mats`` and
    the Schur complement ``covf - covfp @ inv(covp) @ covpf`` is returned,
    where ``f`` is the last ``N`` rows/columns and ``p`` the first ``T * N``.

    Args:
        ccm: array of cross-covariance matrices, indexed by lag along axis 0
            (``T + 1`` lags, each ``N x N``).
        proj: optional projection matrix. When given, every lag matrix is
            replaced by ``proj.T @ c @ proj`` before the covariance is built.
            NOTE(review): the original cell computed a projected covariance
            but never used it; confirm that projecting both past and future
            is the intended behaviour.

    Returns:
        ``N x N`` array (``N`` is the projected dimension when ``proj`` is given).
    """
    if proj is not None:
        ccm = np.stack([proj.T @ c @ proj for c in ccm])

    T = ccm.shape[0] - 1
    N = ccm.shape[-1]
    cov = calc_cov_from_cross_cov_mats(ccm)

    covf = cov[-N:, -N:]
    covp = cov[:T*N, :T*N]
    covpf = cov[:T*N, -N:]
    covfp = cov[-N:, :T*N]

    # Schur complement, using a solve rather than forming inv(covp)
    return covf - covfp @ np.linalg.solve(covp, covpf)


def mmse_reverse(ccm):
    # Unimplemented stub: the body is a bare `pass`, so this returns None.
    # Presumably intended as the reverse-time counterpart of mmse_forward - TODO confirm.
    pass    

In [305]:
mmse_forward(ccm)

array([[ 1.88057021, -0.38751228],
       [-0.38751228,  1.63696645]])

In [307]:
ssr.solve_min_phase()

In [17]:
# Does the antistabilizing solution to the Riccati equation coincide with the solution of the Riccati equation obtained from a backwards
# Markovian realization of the process?

# One immediate consequence of the Kailath formula is that the forward and reverse time Kalman filter parameters should coincide.

In [157]:
A, B, C = gen_random_model(20, 10, cont=True)
# B = C.T

In [158]:
Pi = scipy.linalg.solve_continuous_lyapunov(A, -B @ B.T)

In [159]:
Pp, Pm, Qp, Qm = solve_are_sgn(A, B, C)

In [160]:
np.allclose(Qm, -1*np.linalg.inv(Pp))

True

In [161]:
np.linalg.cond(Pm)

2.213196988879976

In [162]:
np.allclose(-A, -A - B @ B.T @ np.linalg.inv(Pi))

False

In [171]:
Q1 = solve_are(A.T, C.T, B @ B.T, np.eye(C.shape[0]), stable=True)

In [176]:
Q2 = solve_are((-A - B @ B.T @ np.linalg.inv(Pi)).T, C.T, B @ B.T, np.eye(C.shape[0]), stable=True)

In [172]:
np.linalg.eigvals(Q1)

array([0.37810058, 0.36822253, 0.33790802, 0.32270506, 0.300653  ,
       0.29239621, 0.1708391 , 0.28015369, 0.27312249, 0.17790005,
       0.18521241, 0.19358801, 0.20305031, 0.25809415, 0.21077485,
       0.2451703 , 0.23787772, 0.23383626, 0.22500763, 0.22561057])

In [177]:
np.linalg.eigvals(Q2)

array([0.37678043, 0.36530915, 0.33857327, 0.32329659, 0.17082345,
       0.30140423, 0.17730941, 0.29075257, 0.28078575, 0.27382403,
       0.18554753, 0.1930167 , 0.20329504, 0.25765317, 0.21120329,
       0.24535055, 0.23822446, 0.23321985, 0.22534598, 0.22686252])

In [180]:
np.linalg.eigvals(Pp)

array([3.07868345e-01, 2.99673619e-01, 2.62684288e-01, 2.59082437e-01,
       2.30970964e-01, 2.26881340e-01, 2.17022759e-01, 1.84176358e-01,
       2.00077744e-01, 1.96047785e-01, 9.06777542e-03, 6.29964175e-03,
       5.00721777e-03, 4.29488689e-03, 2.75048337e-03, 1.66799693e-03,
       8.47813473e-04, 4.25302854e-05, 1.04944233e-04, 2.07173131e-04])

In [181]:
# Conclusion is that we need to normalize the reverse time parameterization to obtain the adjoint state system
# Numerically verify 2 things:
# (1) The acausal Kalman filter Riccati solution coincides with the empirical MMSE
# (2) The acausal filtering problem for the adjoint state coincides with the solution of the forward time regulator riccati equation

In [182]:
# First task: Does discrete time MMSE converge to continuous time riccati solution as we make the timestep increasingly smaller?

In [183]:
# Back up: Does our implementation of mmse_from_cross_cov_mats work in the discrete time case?

In [215]:
from dca_research.cov_util import calc_mmse_from_cross_cov_mats

In [192]:
A, B, C = gen_random_model(20)

In [198]:
ssr = SSR(A, B, C)
ssr.solve_min_phase()
ccm = ssr.autocorrelation(10)

In [194]:
calc_mmse_from_cross_cov_mats(torch.tensor(ccm))

tensor(20.0000, dtype=torch.float64)

In [199]:
np.trace(ssr.P - ssr.Pmin)

20.000000000000007

In [200]:
# Now do projected version
A, B, C = gen_random_model(20, 2)

In [210]:
ssr = SSR(A, B, C)
ssr_ambient = SSR(A, B, C=np.eye(A.shape[0]))
ssr.solve_min_phase()
ccm = ssr_ambient.autocorrelation(10)

In [211]:
calc_mmse_from_cross_cov_mats(torch.tensor(ccm), proj=torch.tensor(C.T))

> [0;32m/home/akumar/nse/DCA_research/dca_research/kca.py[0m(46)[0;36mcalc_mmse_from_cross_cov_mats[0;34m()[0m
[0;32m     44 [0;31m        [0mcovf[0m [0;34m=[0m [0mcross_cov_mats[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     45 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 46 [0;31m        [0mcovpf[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mcat[0m[0;34m([0m[0mccm_proj2[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     47 [0;31m[0;34m[0m[0m
[0m[0;32m     48 [0;31m    [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


tensor(34.3742, dtype=torch.float64)

In [212]:
np.trace((ssr.P - ssr.Pmin))

34.373397726939075

In [216]:
# Next task: Does the (discrete time) acausal Riccati equation induced by the backwards model coincide with the empirical mmse?
A, B, C = gen_random_model(20)
# Construct acausal parameters:
ssr_fwd = SSR(A, B, C)
Ar = ssr_fwd.P @ A @ np.linalg.inv(ssr_fwd.P)
Br = ssr_fwd.P - Ar @ ssr_fwd.P @ Ar.T

# Projection
V = scipy.stats.ortho_group.rvs(20)[:, 0:2].T

ssr_bkwd_ambient = SSR(Ar, Br, C)
ccm = ssr_bkwd_ambient.autocorrelation(10)
ssr_bkwd = SSR(Ar, Br, V)
ssr_bkwd.solve_min_phase()

In [217]:
calc_mmse_from_cross_cov_mats(torch.tensor(ccm), proj=torch.tensor(V.T))

tensor(133.7017, dtype=torch.float64)

In [218]:
np.trace(ssr_bkwd.P - ssr_bkwd.Pmin)

133.70133716047124

In [219]:
# Next: Does discrete time MMSE converge to the solution of the continuous time Riccati equation as we let delta t -> 0?
deltat = np.logspace(-3, 0, 10)[::-1]
nt = np.array([10, 25, 50, 100])
diff = np.zeros((deltat.size, nt.size))

A, B, C = gen_random_model(20, cont=True)
Pcont = scipy.linalg.solve_continuous_lyapunov(A, -B @ B.T)
V = scipy.stats.ortho_group.rvs(20)[:, 0:2].T

Q = scipy.linalg.solve_continuous_are(A.T, V.T, B @ B.T, np.eye(V.shape[0]))

for i, dt in enumerate(deltat):
    for j, n in enumerate(nt):
        # Cross-covariances sampled at lags k * dt; `k` is used so the
        # comprehension does not reuse the loop index `j`.
        ccm = np.array([scipy.linalg.expm(A * k * dt) @ Pcont for k in range(n)])
        m1 = calc_mmse_from_cross_cov_mats(torch.tensor(ccm), proj=torch.tensor(V.T))
        # m1 is a scalar (already a trace) in the earlier cells of this
        # notebook, so compare it with trace(Q) rather than subtracting it
        # from every entry of Q.
        # NOTE(review): the ValueError recorded for this cell may originate
        # elsewhere (e.g. inside calc_mmse_from_cross_cov_mats) - re-run to confirm.
        diff[i, j] = np.trace(Q) - float(m1)

ValueError: matmul: Input operand 1 does not have enough dimensions (has 0, gufunc core with signature (n?,k),(k,m?)->(n?,m?) requires 1)

In [None]:
# Then, code up the modification to LQGCA in which we normalize by the state variance