In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import numpy as np
import scipy
import matplotlib.pyplot as plt
import pdb
import torch

In [3]:
from neurosim.models.ssr import StateSpaceRealization as SSR

In [4]:
from dca_research.kca import KalmanComponentsAnalysis as KCA
from dca_research.kca import calc_mmse_from_cross_cov_mats

In [5]:
from dca.cov_util import calc_cov_from_cross_cov_mats

In [6]:
# Draw a random state-transition matrix A and rejection-sample until its
# spectral radius is at most MAX_SPECTRAL_RADIUS, so the discrete-time
# system x_{t+1} = A x_t + noise is stable.
SEED = 0
MAX_SPECTRAL_RADIUS = 0.99
size = 20

# Seed the global NumPy RNG: later draws that use it implicitly
# (e.g. scipy.stats.ortho_group.rvs with random_state=None) become
# reproducible too.
np.random.seed(SEED)

while True:
    # Entry variance 1 / (1.2 * size) keeps the spectral radius typically
    # below 1, so few rejections are needed.
    A = np.random.normal(size=(size, size)) / np.sqrt(1.2 * size)
    if np.max(np.abs(np.linalg.eigvals(A))) <= MAX_SPECTRAL_RADIUS:
        break

# Wrap A in a state-space realization whose B and C are both the identity
# (every state is observed directly), then run the min-phase solve.
identity = np.eye(size)
B, C = identity.copy(), identity.copy()
ssr = SSR(A=A, B=B, C=C)
ssr.solve_min_phase()

In [7]:
# Random unit-norm projection: leading column(s) of a Haar-random
# orthogonal size x size matrix, kept 2-D with shape (size, n_proj_dims).
n_proj_dims = 1
V = scipy.stats.ortho_group.rvs(size)[:, :n_proj_dims]

In [8]:
# Covariances for predicting the V-projected process. The leading argument is
# presumably the number of lags (same value as ssr.autocorrelation(10) below).
n_lags = 10
c, cp, cpf = ssr.mmse_cov(n_lags, proj=V.T)

In [9]:
# Reference value: the same dynamics observed only through V^T. trace(P - Pmin)
# is compared against the covariance-based estimates in the following cells.
ssrproj = SSR(A=A, B=B, C=V.T)
ssrproj.solve_min_phase()
excess_cov = ssrproj.P - ssrproj.Pmin
np.trace(excess_cov)

63.985306496734474

In [10]:
# Same quantity from the finite-lag covariances: trace(P - cpf^T cp^{-1} cpf).
# np.linalg.solve avoids forming inv(cp) explicitly. It is cheaper and more
# accurate when cp is ill-conditioned.
np.trace(ssr.P - cpf.T @ np.linalg.solve(cp, cpf))

64.02504856619916

In [12]:
# Same comparison via the KCA library: MMSE of the V-projection computed from
# the model autocorrelation (ssr.autocorrelation(10), presumably 10 lags).
autocorr_t = torch.tensor(ssr.autocorrelation(10))
proj_t = torch.tensor(V)
mmse, cp_, cf_, cpf_ = calc_mmse_from_cross_cov_mats(autocorr_t, proj_t,
                                                     return_covs=True)

In [36]:
# Display the MMSE returned by calc_mmse_from_cross_cov_mats. The recorded value
# (~75.45) disagrees with the ~64.0 from trace(P - Pmin) and mmse_cov above.
mmse

tensor(75.4476, dtype=torch.float64)

In [123]:
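# TODO: these finite-lag MMSE estimates still disagree with the Riccati-equation value trace(P - Pmin); calc_mmse_from_cross_cov_mats (~75.4 vs ~64.0) is the furthest off.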

In [7]:
# Simulate a long trajectory from the full-state model for KCA fitting.
# burnoff=True presumably discards an initial transient.
n_samples = 50_000
x = ssr.trajectory(n_samples, burnoff=True)

In [21]:
# Check #1: does calc_mmse_from_cross_cov_mats match the trace of Pmin?

In [8]:
# Fit-data setup for KCA on the simulated trajectory.
# d=2 matches the two-column (20 x 2) projection returned by the fit below.
# The trailing ';' hides the returned model object's repr.
kcamodel = KCA(d=2, T=5, verbose=True)
kcamodel.estimate_data_statistics(x);

INFO:Model:Starting cross covariance estimate.
INFO:Model:Cross covariance estimate took 0.0 minutes.


<dca_research.kca.KalmanComponentsAnalysis at 0x7f5abe03b4d0>

In [12]:
# Optimize the KCA projection. This is a private API.
# The recorded result appears to be a (projection matrix, final score) tuple.
fit_result = kcamodel._fit_projection(record_V=True)
fit_result

INFO:Model:Loss 71.6205, PI: -71.6205 nats, reg: 0.0
 This problem is unconstrained.
INFO:Model:Loss 60.8155, PI: -51.8248 nats, reg: 8.9907
INFO:Model:Loss 50.5755, PI: -50.1468 nats, reg: 0.4286
INFO:Model:Loss 47.3611, PI: -47.3446 nats, reg: 0.0165
INFO:Model:Loss 43.7048, PI: -43.6141 nats, reg: 0.0907
INFO:Model:Loss 42.3901, PI: -42.2173 nats, reg: 0.1728
INFO:Model:Loss 41.7463, PI: -41.676 nats, reg: 0.0703
INFO:Model:Loss 41.2026, PI: -41.1525 nats, reg: 0.0501
INFO:Model:Loss 40.9157, PI: -40.8847 nats, reg: 0.031
INFO:Model:Loss 40.668, PI: -40.6524 nats, reg: 0.0156
INFO:Model:Loss 40.594, PI: -40.567 nats, reg: 0.027
INFO:Model:Loss 40.5622, PI: -40.54 nats, reg: 0.0221
INFO:Model:Loss 40.4994, PI: -40.4965 nats, reg: 0.0028
INFO:Model:Loss 40.4644, PI: -40.4618 nats, reg: 0.0026
INFO:Model:Loss 40.4065, PI: -40.3963 nats, reg: 0.0102
INFO:Model:Loss 40.3389, PI: -40.3155 nats, reg: 0.0233
INFO:Model:Loss 40.2608, PI: -40.2427 nats, reg: 0.0181
INFO:Model:Loss 40.1643, PI

RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =           40     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  7.16205D+01    |proj g|=  5.04913D+01

At iterate    1    f=  6.08155D+01    |proj g|=  3.01284D+01

At iterate    2    f=  5.05755D+01    |proj g|=  7.79067D+00

At iterate    3    f=  4.73611D+01    |proj g|=  4.17080D+00

At iterate    4    f=  4.37048D+01    |proj g|=  2.61482D+00

At iterate    5    f=  4.23901D+01    |proj g|=  3.07588D+00

At iterate    6    f=  4.17463D+01    |proj g|=  1.84844D+00

At iterate    7    f=  4.12026D+01    |proj g|=  1.79682D+00

At iterate    8    f=  4.09157D+01    |proj g|=  1.01755D+00

At iterate    9    f=  4.06680D+01    |proj g|=  1.28347D+00

At iterate   10    f=  4.05940D+01    |proj g|=  1.44418D+00

At iterate   11    f=  4.05622D+01    |proj g|=  8.63828D-01

At iterate   12    f=  4.04994D+01    |proj g|=  4.72862D-01

At iterate   13    f=  4.0

(array([[-0.21390369,  0.1097262 ],
        [ 0.02485077, -0.11152889],
        [-0.45549672, -0.15530567],
        [-0.2468783 ,  0.20347477],
        [-0.20403521,  0.13881013],
        [-0.26722186, -0.60600851],
        [ 0.25512414, -0.09519969],
        [ 0.25981572,  0.1651015 ],
        [-0.2267173 ,  0.10114024],
        [ 0.35364584, -0.25330922],
        [ 0.00327859, -0.08187837],
        [-0.09568254,  0.49418196],
        [ 0.17244414,  0.10783595],
        [-0.12669125, -0.20490008],
        [-0.25004209,  0.27956936],
        [-0.13916988, -0.07512523],
        [-0.20096656, -0.10638501],
        [-0.07320046, -0.04900695],
        [ 0.03628532,  0.10350397],
        [ 0.28142426, -0.00090147]]),
 -39.74808106147508)