In [57]:
# Automatically re-import edited modules (e.g. the neurosim / dca / dca_research
# code imported below) before each cell executes, so library edits take effect
# without restarting the kernel. Re-running this cell only prints an
# "already loaded" notice.
%load_ext autoreload 
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [58]:
import pdb

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.stats
import torch

In [59]:
from neurosim.models.ssr import StateSpaceRealization as SSR

In [60]:
from dca_research.kca import KalmanComponentsAnalysis as KCA
#from dca_research.kca import ObserverControllerComponentsAnalysis as OCCA

In [61]:
from dca.cov_util import calc_cov_from_cross_cov_mats

In [62]:
# Draw a random size x size dynamics matrix with i.i.d. N(0, 1 / (1.2 * size))
# entries, resampling until its spectral radius is at most MAX_SPECTRAL_RADIUS
# (eigenvalues inside the unit circle, i.e. a stable discrete-time system --
# presumably what SSR expects).
# Seeding numpy's legacy global RNG makes this draw reproducible, and also later
# draws that fall back on the global RandomState (scipy.stats' `rvs` with no
# `random_state`; presumably SSR.trajectory too -- TODO confirm).
SEED = 0
MAX_SPECTRAL_RADIUS = 0.99
np.random.seed(SEED)

size = 20
scale = 1 / np.sqrt(1.2 * size)
while True:
    A = scale * np.random.normal(size=(size, size))
    if np.max(np.abs(np.linalg.eigvals(A))) <= MAX_SPECTRAL_RADIUS:
        break

# Identity input (B) and output (C) matrices, so the observation is the full
# state vector; build the realization and solve its minimum-phase problem.
B, C = np.eye(size), np.eye(size)
ssr = SSR(A=A, B=B, C=C)
ssr.solve_min_phase()

In [63]:
# Random 1-dimensional projection: the first column of a Haar-random orthogonal
# size x size matrix. Slicing with `:1` (not `0`) keeps V two-dimensional,
# shape (size, 1), so V.T is a (1, size) projection matrix.
# `dim=` is passed by keyword because `rvs` also has a `size` parameter (the
# number of samples); here `size` is the matrix dimension.
V = scipy.stats.ortho_group.rvs(dim=size)[:, :1]

In [64]:
# Causal (forward-time) and acausal (reverse-time) MMSE covariance blocks over a
# 5-step window, with projection V.T.
# Both calls return a `c` as their first element and the cells below use a
# single `c` for both directions, so check explicitly that the two agree
# (presumably both are the zero-lag covariance -- the assert guards that).
c_fwd, cp, cpf = ssr.mmse_cov(5, proj=V.T)
c, cf, cfp = ssr.acausal_mmse_cov(5, proj=V.T)
assert np.allclose(c_fwd, c), "mmse_cov and acausal_mmse_cov disagree on c"

# Same dynamics observed only through V.T; solve both the minimum-phase problem
# (used in the forward-time checks) and the maximum-phase problem (used in the
# reverse-time checks).
ssrproj = SSR(A=A, B=B, C=V.T)
ssrproj.solve_min_phase()
ssrproj.solve_max_phase()

#### Forward time

In [65]:
# Forward time, window-based: tr(Cpf^T Cp^{-1} Cpf), the part of c accounted for
# by the projected 5-step past. np.linalg.solve avoids forming an explicit inverse.
np.trace(cpf.T @ np.linalg.solve(cp, cpf))

95.51115338985589

In [66]:
# Model-based counterpart from the projected realization: trace of Pmin, to
# compare with the window-based value in the previous cell.
np.trace(ssrproj.Pmin)

100.81942177896529

In [67]:
# Forward-time MMSE error, window-based: tr(c - Cpf^T Cp^{-1} Cpf), using a
# linear solve instead of an explicit inverse.
np.trace(c - cpf.T @ np.linalg.solve(cp, cpf))

77.52397525305304

In [68]:
# Model-based forward-time error: tr(ssr.P - ssrproj.Pmin), to compare with the
# window-based error in the previous cell.
np.trace(ssr.P - ssrproj.Pmin)

72.21570686394365

#### Reverse time

Note that here we can only measure the error in the projected space: the reverse-time state space is
observationally inaccessible.

In [69]:
# Reverse time, window-based, in the projected space:
# tr(V^T Cfp^T Cf^{-1} Cfp V), using a linear solve instead of an explicit inverse.
np.trace(V.T @ cfp.T @ np.linalg.solve(cf, cfp @ V))

4.351784102581286

In [70]:
# Model-based counterpart: Pmax mapped into the projected space through Cbar,
# tr(Cbar Pmax Cbar^T), to compare with the previous cell.
np.trace(ssrproj.Cbar @ ssrproj.Pmax @ ssrproj.Cbar.T)

4.550879012500852

In [71]:
# Model-based reverse-time error in the projected space:
# tr(V^T P V - Cbar Pmax Cbar^T).
np.trace(V.T @ ssr.P @ V - ssrproj.Cbar @ ssrproj.Pmax @ ssrproj.Cbar.T)

3.6267538859094994

In [72]:
# Reverse-time MMSE error, window-based, in the projected space:
# tr(V^T (c - Cfp^T Cf^{-1} Cfp) V), using a linear solve instead of an
# explicit inverse.
np.trace(V.T @ (c - cfp.T @ np.linalg.solve(cf, cfp)) @ V)

3.8258487958290655

In [64]:
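# The window-based (T=5) and model-based values above still differ (~0.2 in trace for reverse time,
# ~5 for forward time). In both directions the window-based error is the larger one, which is what a
# finite window would predict -- TODO confirm by increasing T.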

### Test implementation of MMSE losses

In [33]:
from dca_research.kca import KalmanComponentsAnalysis as KCA
#from dca_research.kca import ObserverControllerComponentsAnalysis as OCCA
from dca_research.kca import calc_mmse_from_cross_cov_mats

In [46]:
# Draw a random size x size dynamics matrix with i.i.d. N(0, 1 / (1.2 * size))
# entries, resampling until its spectral radius is at most MAX_SPECTRAL_RADIUS
# (eigenvalues inside the unit circle -- presumably the stability SSR expects).
# Seeding numpy's legacy global RNG makes this draw reproducible, and also later
# draws that fall back on the global RandomState (scipy.stats' `rvs` with no
# `random_state`; presumably SSR.trajectory too -- TODO confirm).
SEED = 0
MAX_SPECTRAL_RADIUS = 0.99
np.random.seed(SEED)

size = 20
scale = 1 / np.sqrt(1.2 * size)
while True:
    A = scale * np.random.normal(size=(size, size))
    if np.max(np.abs(np.linalg.eigvals(A))) <= MAX_SPECTRAL_RADIUS:
        break

# Identity input (B) and output (C) matrices, so the observation is the full
# state vector; build the realization and solve its minimum-phase problem.
B, C = np.eye(size), np.eye(size)
ssr = SSR(A=A, B=B, C=C)
ssr.solve_min_phase()

In [47]:
# Random 4-dimensional projection: the first four columns of a Haar-random
# orthogonal size x size matrix, so V has shape (size, 4) with orthonormal
# columns. `dim=` is passed by keyword because `rvs` also has a `size`
# parameter (the number of samples); here `size` is the matrix dimension.
V = scipy.stats.ortho_group.rvs(dim=size)[:, :4]

In [48]:
# Simulate 100,000 time steps; burnoff=True presumably discards an initial
# transient (TODO confirm against SSR.trajectory).
x = ssr.trajectory(100_000, burnoff=True)

In [49]:
# Causal (forward-time) and acausal (reverse-time) MMSE covariance blocks over a
# 5-step window with the 4-dimensional projection V.T. Both calls return a `c`;
# later cells use a single `c` for both directions, so check that they agree
# (presumably both are the zero-lag covariance -- the assert guards that).
c_fwd, cp, cpf = ssr.mmse_cov(5, proj=V.T)
c, cf, cfp = ssr.acausal_mmse_cov(5, proj=V.T)
assert np.allclose(c_fwd, c), "mmse_cov and acausal_mmse_cov disagree on c"

#### Forward time KCA

In [50]:
# Estimate the data statistics (presumably cross-covariances up to lag T=5) from
# the simulated trajectory. The checks below only read kcamodel.cross_covs and
# pass V explicitly, so d=1 is not used by them.
kcamodel = KCA(d=1, T=5)
kcamodel.estimate_data_statistics(x)

<dca_research.kca.KalmanComponentsAnalysis at 0x7f1cb42def90>

In [51]:
# Forward-time MMSE from the data-estimated cross-covariances with projection V;
# compare with the model-based value computed two cells below.
calc_mmse_from_cross_cov_mats(kcamodel.cross_covs, proj=torch.tensor(V))

tensor(47.1619, dtype=torch.float64)

In [52]:
# Recompute the forward-time blocks: the earlier acausal_mmse_cov call rebound
# `c`, and the next cell needs the forward-time `c`.
c, cp, cpf = ssr.mmse_cov(5, proj=V.T)

In [53]:
# Model-based forward-time MMSE error tr(c - Cpf^T Cp^{-1} Cpf), using a linear
# solve instead of an explicit inverse; compare with the KCA value above.
np.trace(c - cpf.T @ np.linalg.solve(cp, cpf))

47.08889385601768

#### Reverse time KCA

In [54]:
# Reverse-time MMSE from the same statistics. Swapping axes 1 and 2 of the
# cross-covariance stack transposes each lag's matrix, presumably giving the
# cross-covariances of the time-reversed process (assumes cross_covs is
# (T, n, n) -- TODO confirm).
calc_mmse_from_cross_cov_mats(torch.transpose(kcamodel.cross_covs, 1, 2), proj=torch.tensor(V))

tensor(44.1031, dtype=torch.float64)

In [55]:
# Recompute the reverse-time blocks: the forward-time recompute above rebound
# `c`, and the next cell needs the acausal `c`.
c, cf, cfp = ssr.acausal_mmse_cov(5, proj=V.T)

In [56]:
# Model-based reverse-time MMSE error tr(c - Cfp^T Cf^{-1} Cfp), using a linear
# solve instead of an explicit inverse; compare with the KCA value above.
np.trace(c - cfp.T @ np.linalg.solve(cf, cfp))

44.034182993761284

#### OCCA

In [38]:
# NOTE(review): the OCCA import is commented out in both import cells, so `OCCA`
# is undefined on a fresh kernel and this cell raises NameError. Restore
# `from dca_research.kca import ObserverControllerComponentsAnalysis as OCCA`
# (if the class still exists) or drop the OCCA cells.
# Build an OCCA model with T=5 and estimate its data statistics from x;
# project_mmse=True presumably evaluates the MMSE terms in the projected space.
occamodel = OCCA(d=1, T=5, project_mmse=True)
occamodel.estimate_data_statistics(x)

<dca_research.kca.ObserverControllerComponentsAnalysis at 0x7ff8705c49d0>

In [39]:
# Evaluate the OCCA objective at the fixed projection V (presumably without
# fitting); compare with the model-based value in the next cell.
occamodel.score(coef=V)

array(22.51536198)

In [40]:
# Model-based check of the OCCA score: tr(V^T E_f V V^T E_r V), where
# E_f = c - Cpf^T Cp^{-1} Cpf and E_r = c - Cfp^T Cf^{-1} Cfp are the forward- and
# reverse-time MMSE error covariances (linear solves instead of explicit inverses).
np.trace(
    V.T
    @ (c - cpf.T @ np.linalg.solve(cp, cpf))
    @ V @ V.T
    @ (c - cfp.T @ np.linalg.solve(cf, cfp))
    @ V
)

22.548206418091414

In [42]:
# Same check for the 'sum' loss, presumably weighting the forward/reverse MMSE
# terms by loss_reg_vals=[0.4, 0.6] (the next cell assumes exactly that).
# NOTE(review): relies on `OCCA`, whose import is commented out above, so this
# raises NameError on a fresh kernel.
occamodel = OCCA(d=1, T=5, project_mmse=True, loss_type='sum', loss_reg_vals=[0.4, 0.6])
occamodel.estimate_data_statistics(x)
occamodel.score(coef=V)

  mmse_rev = calc_mmse_from_cross_cov_mats(torch.transpose(torch.tensor(cross_covs), 1, 2), torch.tensor(coef),


9.732918801261896

In [45]:
# Model-based check of the 'sum' loss:
# 0.4 * tr(V^T E_f V) + 0.6 * tr(V^T E_r V), with weights matching
# loss_reg_vals=[0.4, 0.6] above and linear solves instead of explicit inverses.
np.trace(
    0.4 * V.T @ (c - cpf.T @ np.linalg.solve(cp, cpf)) @ V
    + 0.6 * V.T @ (c - cfp.T @ np.linalg.solve(cf, cfp)) @ V
)

9.737837747362015

### Verify that the different objective functions can be optimized

In [9]:
# Simulate 50,000 time steps; burnoff=True presumably discards an initial
# transient (TODO confirm against SSR.trajectory).
x = ssr.trajectory(50_000, burnoff=True)

In [10]:
# Check 1: does calc_mmse_from_cross_cov_mats match the trace of Pmin?
# TODO: the cells below only fit OCCA; this comparison is not performed yet.

In [11]:
# NOTE(review): despite its name, `kcamodel` holds an OCCA
# (ObserverControllerComponentsAnalysis) instance, not a KCA. Also, `OCCA`'s
# import is commented out above, so this raises NameError on a fresh kernel.
# Build a 3-dimensional OCCA model (T=5, verbose logging) and estimate its data
# statistics from x.
kcamodel = OCCA(d=3, T=5, verbose=True)
kcamodel.estimate_data_statistics(x)

INFO:Model:Starting cross covariance estimate.
INFO:Model:Cross covariance estimate took 0.0 minutes.


<dca_research.kca.ObserverControllerComponentsAnalysis at 0x7fe2992b4490>

In [12]:
# Run the optimizer through the private _fit_projection method. Per the recorded
# output it returns a tuple of a (20, 3) projection matrix and a scalar;
# record_V=True presumably also stores intermediate projections (TODO confirm).
kcamodel._fit_projection(record_V=True)

  return _VF.chain_matmul(matrices)  # type: ignore[attr-defined]
INFO:Model:Loss 18.7601, PI: -18.7601 nats, reg: 0.0
 This problem is unconstrained.
INFO:Model:Loss 12.5371, PI: -1.8612 nats, reg: 10.676
INFO:Model:Loss 8.6863, PI: -3.5803 nats, reg: 5.106
INFO:Model:Loss 6.253, PI: -5.227 nats, reg: 1.0259
INFO:Model:Loss 5.3494, PI: -5.0353 nats, reg: 0.3141
INFO:Model:Loss 4.5864, PI: -4.1693 nats, reg: 0.4171
INFO:Model:Loss 4.1935, PI: -3.4537 nats, reg: 0.7398
INFO:Model:Loss 3.8607, PI: -3.342 nats, reg: 0.5187
INFO:Model:Loss 3.6693, PI: -3.4385 nats, reg: 0.2308
INFO:Model:Loss 3.4498, PI: -2.9786 nats, reg: 0.4711
INFO:Model:Loss 3.3183, PI: -2.9189 nats, reg: 0.3994
INFO:Model:Loss 3.2567, PI: -2.8803 nats, reg: 0.3764
INFO:Model:Loss 3.212, PI: -2.8466 nats, reg: 0.3655
INFO:Model:Loss 3.1864, PI: -2.9171 nats, reg: 0.2693
INFO:Model:Loss 3.1416, PI: -2.8611 nats, reg: 0.2805
INFO:Model:Loss 3.1112, PI: -2.8255 nats, reg: 0.2857
INFO:Model:Loss 3.0699, PI: -2.8017 nats, r

RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =           60     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  1.87601D+01    |proj g|=  1.44758D+01

At iterate    1    f=  1.25371D+01    |proj g|=  7.94052D+00

At iterate    2    f=  8.68630D+00    |proj g|=  5.12165D+00

At iterate    3    f=  6.25296D+00    |proj g|=  3.20399D+00

At iterate    4    f=  5.34935D+00    |proj g|=  3.03560D+00

At iterate    5    f=  4.58637D+00    |proj g|=  1.67427D+00

At iterate    6    f=  4.19353D+00    |proj g|=  9.69315D-01

At iterate    7    f=  3.86071D+00    |proj g|=  9.96731D-01

At iterate    8    f=  3.66931D+00    |proj g|=  1.86937D+00

At iterate    9    f=  3.44976D+00    |proj g|=  1.20710D+00

At iterate   10    f=  3.31830D+00    |proj g|=  8.93396D-01

At iterate   11    f=  3.25670D+00    |proj g|=  8.53800D-01

At iterate   12    f=  3.21201D+00    |proj g|=  8.09844D-01

At iterate   13    f=  3.1

INFO:Model:Loss 2.7357, PI: -2.4808 nats, reg: 0.2549
INFO:Model:Loss 2.7355, PI: -2.4844 nats, reg: 0.2511
INFO:Model:Loss 2.7353, PI: -2.4843 nats, reg: 0.2509
INFO:Model:Loss 2.735, PI: -2.487 nats, reg: 0.248
INFO:Model:Loss 2.7347, PI: -2.4826 nats, reg: 0.2521
INFO:Model:Loss 2.7343, PI: -2.4852 nats, reg: 0.2491
INFO:Model:Loss 2.7341, PI: -2.4858 nats, reg: 0.2482
INFO:Model:Loss 2.7339, PI: -2.4851 nats, reg: 0.2488
INFO:Model:Loss 2.7337, PI: -2.4874 nats, reg: 0.2463
INFO:Model:Loss 2.7336, PI: -2.4819 nats, reg: 0.2517
INFO:Model:Loss 2.7335, PI: -2.4828 nats, reg: 0.2507
INFO:Model:Loss 2.7334, PI: -2.4838 nats, reg: 0.2496
INFO:Model:Loss 2.7334, PI: -2.4847 nats, reg: 0.2487
INFO:Model:Loss 2.7332, PI: -2.4836 nats, reg: 0.2496
INFO:Model:Loss 2.7331, PI: -2.4864 nats, reg: 0.2467
INFO:Model:Loss 2.7331, PI: -2.4848 nats, reg: 0.2483
INFO:Model:Loss 2.733, PI: -2.4829 nats, reg: 0.2501
INFO:Model:Loss 2.7329, PI: -2.4824 nats, reg: 0.2506
INFO:Model:Loss 2.7328, PI: -2.4


At iterate   61    f=  2.73571D+00    |proj g|=  6.93416D-02

At iterate   62    f=  2.73545D+00    |proj g|=  3.53489D-02

At iterate   63    f=  2.73526D+00    |proj g|=  2.84581D-02

At iterate   64    f=  2.73498D+00    |proj g|=  4.39290D-02

At iterate   65    f=  2.73469D+00    |proj g|=  7.81285D-02

At iterate   66    f=  2.73435D+00    |proj g|=  3.78337D-02

At iterate   67    f=  2.73406D+00    |proj g|=  2.51410D-02

At iterate   68    f=  2.73388D+00    |proj g|=  3.55306D-02

At iterate   69    f=  2.73371D+00    |proj g|=  5.82883D-02

At iterate   70    f=  2.73358D+00    |proj g|=  3.81101D-02

At iterate   71    f=  2.73350D+00    |proj g|=  1.60421D-02

At iterate   72    f=  2.73343D+00    |proj g|=  2.39639D-02

At iterate   73    f=  2.73335D+00    |proj g|=  3.47156D-02

At iterate   74    f=  2.73320D+00    |proj g|=  4.21606D-02

At iterate   75    f=  2.73313D+00    |proj g|=  2.79844D-02

At iterate   76    f=  2.73306D+00    |proj g|=  1.32362D-02

At iter

INFO:Model:Loss 2.7316, PI: -2.4821 nats, reg: 0.2495
INFO:Model:Loss 2.7316, PI: -2.4828 nats, reg: 0.2488
INFO:Model:Loss 2.7316, PI: -2.4826 nats, reg: 0.249
INFO:Model:Loss 2.7316, PI: -2.4824 nats, reg: 0.2491
INFO:Model:Loss 2.7316, PI: -2.4821 nats, reg: 0.2495
INFO:Model:Loss 2.7316, PI: -2.4818 nats, reg: 0.2498
INFO:Model:Loss 2.7316, PI: -2.4819 nats, reg: 0.2496
INFO:Model:Loss 2.7315, PI: -2.4822 nats, reg: 0.2494
INFO:Model:Loss 2.7315, PI: -2.4822 nats, reg: 0.2494
INFO:Model:Loss 2.7315, PI: -2.483 nats, reg: 0.2486
INFO:Model:Loss 2.7315, PI: -2.4826 nats, reg: 0.249
INFO:Model:Loss 2.7315, PI: -2.4819 nats, reg: 0.2496
INFO:Model:Loss 2.7315, PI: -2.4818 nats, reg: 0.2498
INFO:Model:Loss 2.7315, PI: -2.4812 nats, reg: 0.2503
INFO:Model:Loss 2.7315, PI: -2.482 nats, reg: 0.2495
INFO:Model:Loss 2.7314, PI: -2.4819 nats, reg: 0.2495
INFO:Model:Loss 2.7314, PI: -2.4834 nats, reg: 0.248
INFO:Model:Loss 2.7314, PI: -2.482 nats, reg: 0.2494
INFO:Model:Loss 2.7314, PI: -2.483


At iterate  122    f=  2.73159D+00    |proj g|=  5.55125D-03

At iterate  123    f=  2.73159D+00    |proj g|=  1.02927D-02

At iterate  124    f=  2.73158D+00    |proj g|=  7.56396D-03

At iterate  125    f=  2.73158D+00    |proj g|=  6.56273D-03

At iterate  126    f=  2.73157D+00    |proj g|=  4.95202D-03

At iterate  127    f=  2.73156D+00    |proj g|=  1.01601D-02

At iterate  128    f=  2.73155D+00    |proj g|=  5.79497D-03

At iterate  129    f=  2.73155D+00    |proj g|=  3.84902D-03

At iterate  130    f=  2.73154D+00    |proj g|=  6.82543D-03

At iterate  131    f=  2.73154D+00    |proj g|=  1.14694D-02

At iterate  132    f=  2.73153D+00    |proj g|=  8.91210D-03

At iterate  133    f=  2.73152D+00    |proj g|=  7.78691D-03

At iterate  134    f=  2.73151D+00    |proj g|=  9.32235D-03

At iterate  135    f=  2.73148D+00    |proj g|=  1.47726D-02

At iterate  136    f=  2.73147D+00    |proj g|=  1.89478D-02

At iterate  137    f=  2.73145D+00    |proj g|=  1.27020D-02

At iter

INFO:Model:Loss 2.7291, PI: -2.4807 nats, reg: 0.2483
INFO:Model:Loss 2.729, PI: -2.4824 nats, reg: 0.2466
INFO:Model:Loss 2.7289, PI: -2.4829 nats, reg: 0.246
INFO:Model:Loss 2.7288, PI: -2.4835 nats, reg: 0.2453
INFO:Model:Loss 2.7287, PI: -2.4819 nats, reg: 0.2469
INFO:Model:Loss 2.7287, PI: -2.4793 nats, reg: 0.2494
INFO:Model:Loss 2.7287, PI: -2.4792 nats, reg: 0.2494
INFO:Model:Loss 2.7286, PI: -2.4801 nats, reg: 0.2485
INFO:Model:Loss 2.7286, PI: -2.4774 nats, reg: 0.2511
INFO:Model:Loss 2.7285, PI: -2.4801 nats, reg: 0.2484
INFO:Model:Loss 2.7285, PI: -2.4798 nats, reg: 0.2487
INFO:Model:Loss 2.7284, PI: -2.4806 nats, reg: 0.2478
INFO:Model:Loss 2.7283, PI: -2.4792 nats, reg: 0.2491
INFO:Model:Loss 2.7282, PI: -2.4814 nats, reg: 0.2468
INFO:Model:Loss 2.7282, PI: -2.4803 nats, reg: 0.2479
INFO:Model:Loss 2.7281, PI: -2.4804 nats, reg: 0.2477
INFO:Model:Loss 2.7281, PI: -2.4781 nats, reg: 0.2499
INFO:Model:Loss 2.728, PI: -2.4798 nats, reg: 0.2482
INFO:Model:Loss 2.728, PI: -2.4


At iterate  179    f=  2.72905D+00    |proj g|=  1.07523D-02

At iterate  180    f=  2.72898D+00    |proj g|=  1.40732D-02

At iterate  181    f=  2.72890D+00    |proj g|=  3.36812D-02

At iterate  182    f=  2.72880D+00    |proj g|=  2.06109D-02

At iterate  183    f=  2.72874D+00    |proj g|=  3.44375D-02

At iterate  184    f=  2.72868D+00    |proj g|=  2.28102D-02

At iterate  185    f=  2.72866D+00    |proj g|=  1.21866D-02

At iterate  186    f=  2.72860D+00    |proj g|=  2.21246D-02

At iterate  187    f=  2.72856D+00    |proj g|=  2.36048D-02

At iterate  188    f=  2.72852D+00    |proj g|=  2.24575D-02

At iterate  189    f=  2.72847D+00    |proj g|=  1.39685D-02

At iterate  190    f=  2.72839D+00    |proj g|=  1.51750D-02

At iterate  191    f=  2.72830D+00    |proj g|=  2.59364D-02

At iterate  192    f=  2.72821D+00    |proj g|=  1.71124D-02

At iterate  193    f=  2.72815D+00    |proj g|=  1.51986D-02

At iterate  194    f=  2.72809D+00    |proj g|=  2.42883D-02

At iter

INFO:Model:Loss 2.7271, PI: -2.4802 nats, reg: 0.2469



At iterate  240    f=  2.72713D+00    |proj g|=  2.60155D-02

           * * *

Tit   = total number of iterations
Tnf   = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip  = number of BFGS updates skipped
Nact  = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F     = final function value

           * * *

   N    Tit     Tnf  Tnint  Skip  Nact     Projg        F
   60    240    258      1     0     0   2.602D-02   2.727D+00
  F =   2.7271329720580906     

CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH             


  _, mmse_covf = calc_mmse_from_cross_cov_mats(cross_covs, torch.tensor(coef), project_mmse=self.project_mmse,
  _, mmse_covr = calc_mmse_from_cross_cov_mats(torch.transpose(torch.tensor(cross_covs), 1, 2), torch.tensor(coef),


(array([[-0.1950761 ,  0.03143931,  0.14152738],
        [-0.33140969, -0.09086171,  0.00367122],
        [ 0.24594925,  0.09106845,  0.14109807],
        [-0.03247709, -0.0273465 , -0.19348932],
        [ 0.16886083, -0.34532501,  0.01554556],
        [-0.24535984, -0.14561671,  0.17328133],
        [-0.11065045, -0.38244838,  0.23242596],
        [-0.05990538, -0.08416205,  0.06627354],
        [-0.24061015, -0.01425001,  0.43219262],
        [-0.17462439,  0.12594472,  0.25003033],
        [-0.08311169, -0.29367468, -0.09974349],
        [-0.35252028,  0.21404957,  0.32584736],
        [ 0.45633657,  0.3226987 ,  0.15794856],
        [-0.17400557, -0.07253291,  0.08667217],
        [-0.28960933,  0.03484542, -0.42356837],
        [-0.05864577,  0.07443087,  0.09561937],
        [-0.01548609, -0.30488717, -0.31464442],
        [ 0.29870288, -0.56889702,  0.29010993],
        [-0.2097272 , -0.0584882 , -0.24542589],
        [-0.08468705, -0.0601094 ,  0.06391368]]),
 -2.99996764114255

In [9]:
# 3-dimensional OCCA with the weighted 'sum' loss (loss_reg_vals=[0.4, 0.6]) and
# project_mmse=True; statistics are estimated after the new system is simulated.
# NOTE(review): `OCCA`'s import is commented out above, so this raises NameError
# on a fresh kernel; `kcamodel` again holds an OCCA, not a KCA.
kcamodel = OCCA(d=3, T=5, verbose=True, loss_type='sum', loss_reg_vals=[0.4, 0.6], project_mmse=True)

In [14]:
# Draw a random size x size dynamics matrix with i.i.d. N(0, 1 / (1.2 * size))
# entries, resampling until its spectral radius is at most MAX_SPECTRAL_RADIUS
# (eigenvalues inside the unit circle -- presumably the stability SSR expects).
# Seeding numpy's legacy global RNG makes this draw reproducible, and also later
# draws that fall back on the global RandomState (presumably including
# SSR.trajectory -- TODO confirm).
SEED = 0
MAX_SPECTRAL_RADIUS = 0.99
np.random.seed(SEED)

size = 20
scale = 1 / np.sqrt(1.2 * size)
while True:
    A = scale * np.random.normal(size=(size, size))
    if np.max(np.abs(np.linalg.eigvals(A))) <= MAX_SPECTRAL_RADIUS:
        break

# Identity input (B) and output (C) matrices, so the observation is the full
# state vector; build the realization, solve its minimum-phase problem, then
# simulate 50,000 time steps.
B, C = np.eye(size), np.eye(size)
ssr = SSR(A=A, B=B, C=C)
ssr.solve_min_phase()
x = ssr.trajectory(50_000, burnoff=True)

In [15]:
# Estimate statistics on the freshly simulated trajectory, then run the optimizer
# via the private _fit_projection. The recorded output shows the optimizer
# reporting successful termination and returning a (20, 3) projection and a scalar.
kcamodel.estimate_data_statistics(x)
kcamodel._fit_projection(record_V=True)

INFO:Model:Starting cross covariance estimate.
INFO:Model:Cross covariance estimate took 0.0 minutes.
INFO:Model:Loss 8.7656, PI: -8.7656 nats, reg: 0.0
INFO:Model:Loss 7.5686, PI: -6.355 nats, reg: 1.2136
INFO:Model:Loss 6.2125, PI: -5.9491 nats, reg: 0.2633
INFO:Model:Loss 5.8223, PI: -5.7229 nats, reg: 0.0994
INFO:Model:Loss 5.4076, PI: -5.0028 nats, reg: 0.4048
INFO:Model:Loss 5.0782, PI: -4.9058 nats, reg: 0.1724
INFO:Model:Loss 4.7601, PI: -4.7017 nats, reg: 0.0584
INFO:Model:Loss 4.5152, PI: -4.402 nats, reg: 0.1132
INFO:Model:Loss 4.2165, PI: -3.9553 nats, reg: 0.2613
INFO:Model:Loss 4.0782, PI: -3.9416 nats, reg: 0.1366
INFO:Model:Loss 3.9637, PI: -3.9169 nats, reg: 0.0469
INFO:Model:Loss 3.8526, PI: -3.7788 nats, reg: 0.0739
INFO:Model:Loss 3.779, PI: -3.5929 nats, reg: 0.186
INFO:Model:Loss 3.7288, PI: -3.5196 nats, reg: 0.2092
INFO:Model:Loss 3.6769, PI: -3.5926 nats, reg: 0.0842
INFO:Model:Loss 3.6358, PI: -3.5577 nats, reg: 0.0782
INFO:Model:Loss 3.5638, PI: -3.4675 nats,

INFO:Model:Loss 2.9465, PI: -2.8699 nats, reg: 0.0766
INFO:Model:Loss 2.9465, PI: -2.8697 nats, reg: 0.0768
INFO:Model:Loss 2.9464, PI: -2.8697 nats, reg: 0.0767
INFO:Model:Loss 2.9464, PI: -2.8699 nats, reg: 0.0765
INFO:Model:Loss 2.9464, PI: -2.8703 nats, reg: 0.0761
INFO:Model:Loss 2.9464, PI: -2.8706 nats, reg: 0.0758
INFO:Model:Loss 2.9464, PI: -2.8705 nats, reg: 0.0759
INFO:Model:Loss 2.9463, PI: -2.8698 nats, reg: 0.0765
INFO:Model:Loss 2.9463, PI: -2.8697 nats, reg: 0.0766
INFO:Model:Loss 2.9463, PI: -2.8699 nats, reg: 0.0764
INFO:Model:Loss 2.9463, PI: -2.8701 nats, reg: 0.0762
INFO:Model:Loss 2.9463, PI: -2.8703 nats, reg: 0.076
INFO:Model:Loss 2.9463, PI: -2.8703 nats, reg: 0.076
INFO:Model:Loss 2.9463, PI: -2.8701 nats, reg: 0.0761
INFO:Model:Loss 2.9462, PI: -2.87 nats, reg: 0.0762
INFO:Model:Loss 2.9462, PI: -2.87 nats, reg: 0.0762
INFO:Model:Loss 2.9462, PI: -2.87 nats, reg: 0.0762
INFO:Model:Loss 2.9462, PI: -2.87 nats, reg: 0.0762
INFO:Model:Loss 2.9462, PI: -2.8701 na

INFO:Model:Loss 2.9452, PI: -2.8691 nats, reg: 0.0761
INFO:Model:Loss 2.9452, PI: -2.869 nats, reg: 0.0762
INFO:Model:Loss 2.9452, PI: -2.869 nats, reg: 0.0762
INFO:Model:Loss 2.9452, PI: -2.8691 nats, reg: 0.0761
INFO:Model:Loss 2.9452, PI: -2.8692 nats, reg: 0.076
INFO:Model:Loss 2.9452, PI: -2.8692 nats, reg: 0.0759
INFO:Model:Loss 2.9452, PI: -2.8691 nats, reg: 0.076
INFO:Model:Loss 2.9452, PI: -2.869 nats, reg: 0.0762
INFO:Model:Loss 2.9452, PI: -2.8689 nats, reg: 0.0762
INFO:Model:Loss 2.9452, PI: -2.869 nats, reg: 0.0761
INFO:Model:Loss 2.9452, PI: -2.8691 nats, reg: 0.0761
INFO:Model:Loss 2.9451, PI: -2.8691 nats, reg: 0.076
INFO:Model:Loss 2.9451, PI: -2.8691 nats, reg: 0.0761
INFO:Model:Loss 2.9451, PI: -2.8688 nats, reg: 0.0763
INFO:Model:Loss 2.9451, PI: -2.8688 nats, reg: 0.0764
INFO:Model:Loss 2.9451, PI: -2.8691 nats, reg: 0.076
INFO:Model:Loss 2.9451, PI: -2.8693 nats, reg: 0.0758
INFO:Model:Loss 2.9451, PI: -2.8696 nats, reg: 0.0755
INFO:Model:Loss 2.9451, PI: -2.8696 

INFO:Model:Loss 2.9447, PI: -2.8686 nats, reg: 0.076
INFO:Model:Loss 2.9447, PI: -2.8688 nats, reg: 0.0759
INFO:Model:Loss 2.9447, PI: -2.8689 nats, reg: 0.0758
INFO:Model:Loss 2.9447, PI: -2.8688 nats, reg: 0.0759
INFO:Model:Loss 2.9447, PI: -2.8687 nats, reg: 0.0759
INFO:Model:Loss 2.9447, PI: -2.8685 nats, reg: 0.0762
INFO:Model:Loss 2.9447, PI: -2.8682 nats, reg: 0.0764
INFO:Model:Loss 2.9447, PI: -2.8681 nats, reg: 0.0765
INFO:Model:Loss 2.9447, PI: -2.8682 nats, reg: 0.0765
INFO:Model:Loss 2.9447, PI: -2.8687 nats, reg: 0.076
INFO:Model:Loss 2.9447, PI: -2.8689 nats, reg: 0.0758
INFO:Model:Loss 2.9447, PI: -2.8688 nats, reg: 0.0759
INFO:Model:Loss 2.9447, PI: -2.8685 nats, reg: 0.0762
INFO:Model:Loss 2.9447, PI: -2.8684 nats, reg: 0.0763
INFO:Model:Loss 2.9447, PI: -2.8684 nats, reg: 0.0763
INFO:Model:Loss 2.9446, PI: -2.8684 nats, reg: 0.0762
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.076
INFO:Model:Loss 2.9446, PI: -2.

INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: 

INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: 

INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: -2.8686 nats, reg: 0.0761
INFO:Model:Loss 2.9446, PI: 

Optimization terminated successfully.
         Current function value: 2.944629
         Iterations: 991
         Function evaluations: 1583
         Gradient evaluations: 1583


(array([[-0.47521161,  0.24675635,  0.05394384],
        [ 0.00469634,  0.21518144, -0.34787006],
        [ 0.03562444,  0.1526713 , -0.25776778],
        [ 0.29280775, -0.1309613 , -0.15325819],
        [ 0.03445893,  0.09496863, -0.06297155],
        [-0.10014459,  0.38229097,  0.12706669],
        [ 0.08541727,  0.16190476, -0.01472211],
        [-0.17899528,  0.31033493,  0.2505881 ],
        [-0.40294565,  0.07756954, -0.13029911],
        [-0.21147585, -0.43964001,  0.24256519],
        [-0.37603466, -0.35233679, -0.06149652],
        [ 0.00293277,  0.23941007,  0.25263579],
        [ 0.12467206,  0.16182654,  0.04154164],
        [-0.09283411, -0.25555634,  0.18949846],
        [ 0.00808357, -0.17910535,  0.04595092],
        [-0.26167742, -0.13782187, -0.58435522],
        [ 0.25158835, -0.09551212,  0.28786729],
        [ 0.31133901, -0.15601687, -0.16576373],
        [-0.1384585 , -0.09958326,  0.1013061 ],
        [-0.12647361, -0.08917753,  0.24036077]]),
 -3.02069027372522