In [10]:
# %% Imports
import jax
from jax import jit, numpy as jnp, lax, random
from functools import partial
from modax.linear_model import SBL

from sklearn.linear_model import ARDRegression


%load_ext autoreload 
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
# Load the test problem: a dict stored in an object .npy holding targets `y` and library `X`.
# NOTE(review): allow_pickle=True can execute arbitrary code on load — only use with trusted files.
data = jnp.load('test_data.npy', allow_pickle=True).item()
y, X = data['y'], data['X']

# Scale every column of X to unit L2 norm. All-zero columns are divided by 1 (left as zeros)
# instead of producing a NaN column.
column_norms = jnp.linalg.norm(X, axis=0)
X_normed = X / jnp.where(column_norms == 0, 1.0, column_norms)

In [49]:
# SBL fit on the unit-norm library. The result has one entry per column of X plus one
# trailing entry — presumably the noise precision, matching the layout `update` uses below.
prior = SBL.SBLregression(X_normed, y)
prior

[1.3595734e+03 5.6380401e+00 7.0157982e-02 5.3128362e+00 8.4873238e+01
 5.7919085e-02 1.2744015e+01 2.6220288e+00 9.1703812e+01 8.0625182e-01
 1.2733847e+01 4.6117731e+02 2.7395503e+05]


In [62]:
# Per-term precisions only (trailing entry dropped), shown as a column to line up with reg.lambda_.
prior[:-1].reshape(-1, 1)

DeviceArray([[1.3595734e+03],
             [5.6380401e+00],
             [7.0157982e-02],
             [5.3128362e+00],
             [8.4873238e+01],
             [5.7919085e-02],
             [1.2744015e+01],
             [2.6220288e+00],
             [9.1703812e+01],
             [8.0625182e-01],
             [1.2733847e+01],
             [4.6117731e+02]], dtype=float32)

In [64]:
# Evidence value and posterior mean of the coefficients under `prior`, with the
# (a, b) hyperprior parameters both set to 0.0.
evidence, mn = SBL.evidence(X_normed, y, prior, (0.0, 0.0))
evidence, mn

5710.3774 [[ 0.02696981]
 [-0.42118847]
 [ 3.7753148 ]
 [ 0.43370676]
 [-0.1081934 ]
 [-4.154943  ]
 [ 0.27753925]
 [-0.6172552 ]
 [ 0.10374784]
 [-1.1136498 ]
 [-0.27886343]
 [-0.04382658]]


In [53]:
# sklearn's ARD as a reference for SBL: no intercept, all Gamma hyperprior parameters at 0,
# and the objective recorded every iteration (compute_score) for comparison with `evidence`.
# NOTE(review): ARDRegression also prunes features whose precision exceeds `threshold_lambda`
# (default 1e4), while the SBL code here appears only to cap precisions (1e6 entries in its
# output). A large threshold_lambda would make the comparison more like-for-like.
reg = ARDRegression(
    fit_intercept=False,
    compute_score=True,
    tol=1e-4,
    alpha_1=0.0,
    alpha_2=0.0,
    lambda_1=0.0,
    lambda_2=0.0,
)

In [54]:
# Fit ARD on the normalized library; squeeze() turns the column-vector y into the 1-D target sklearn expects.
reg.fit(X_normed, y.squeeze())

ARDRegression(alpha_1=0.0, alpha_2=0.0, compute_score=True, fit_intercept=False,
              lambda_1=0.0, lambda_2=0.0, tol=0.0001)

In [55]:
# sklearn's estimated noise precision; compare with the last entry of the SBL `prior`.
reg.alpha_

253961.6121603202

In [56]:
# Per-feature weight precisions from the ARD fit, as a column to line up with prior[:-1, None].
reg.lambda_[:, None]

array([[3.44620966e+03],
       [5.64050771e+00],
       [6.90140592e-02],
       [5.10886248e+00],
       [6.90789256e+02],
       [5.74954487e-02],
       [7.93799507e+01],
       [3.16854132e+00],
       [1.09172455e+04],
       [9.06506097e-01],
       [2.52464024e+01],
       [1.41182674e+02]])

In [57]:
# Objective value per iteration (recorded because compute_score=True).
reg.scores_

[4063.464736735098,
 5674.823066162173,
 5682.251057953032,
 5682.014236012988,
 5682.012401103186,
 5682.012200107362,
 5682.012173293515]

In [60]:
# ARD coefficients as a column, for side-by-side comparison with SBL's posterior mean `mn`.
reg.coef_[:, None]

array([[ 0.01683335],
       [-0.42096037],
       [ 3.80653338],
       [ 0.44227556],
       [-0.03790536],
       [-4.17038708],
       [ 0.10705643],
       [-0.56129879],
       [ 0.        ],
       [-1.05015642],
       [-0.19732784],
       [-0.08252294]])

In [65]:
# SBL posterior mean returned by SBL.evidence, shown again next to reg.coef_.
mn

DeviceArray([[ 0.02696981],
             [-0.42118847],
             [ 3.7753148 ],
             [ 0.43370676],
             [-0.1081934 ],
             [-4.154943  ],
             [ 0.27753925],
             [-0.6172552 ],
             [ 0.10374784],
             [-1.1136498 ],
             [-0.27886343],
             [-0.04382658]], dtype=float32)

In [44]:
# Refit ARD on the *unnormalized* X to compare which coefficients get zeroed.
# NOTE(review): this refits `reg` in place, overwriting the X_normed fit, so every later
# `reg.*` read reflects the raw-X model. The saved execution count (In [44]) is lower than
# the ARDRegression cell's (In [53]), so the stored outputs came from an out-of-order
# session. Consider fitting a copy, e.g. sklearn.base.clone(reg), instead.
reg.fit(X, y.squeeze()).coef_[:, None]

array([[ 0.        ],
       [-0.02956439],
       [ 0.09587161],
       [ 0.        ],
       [-0.0193805 ],
       [-0.87973161],
       [ 0.        ],
       [ 0.        ],
       [ 0.0336494 ],
       [-0.08981044],
       [ 0.        ],
       [ 0.        ]])

In [47]:
# Precisions from the most recent fit of `reg` — in top-to-bottom order, the raw-X refit.
reg.lambda_[:, None]

array([[1.32819420e+06],
       [1.13433641e+03],
       [1.08796093e+02],
       [3.14849151e+05],
       [2.60501863e+03],
       [1.29192640e+00],
       [1.74882926e+05],
       [1.69096594e+05],
       [8.56825597e+02],
       [1.22728101e+02],
       [8.57527675e+04],
       [1.47666149e+06]])

In [48]:
# SBL on the *unnormalized* library, for comparison with the raw-X ARD precisions above.
# NOTE(review): this rebinds `prior`, replacing the X_normed result computed earlier.
prior = SBL.SBLregression(X, y)
prior[:, None]

[[1.00000000e+06]
 [5.76688232e+02]
 [1.02008888e+02]
 [3.80333156e+05]
 [4.80305176e+03]
 [1.72114873e+00]
 [1.24310928e+04]
 [1.20778625e+05]
 [1.80649500e+03]
 [1.45358677e+01]
 [1.33389336e+04]
 [1.00000000e+06]
 [2.73934031e+05]]


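Not sure this is correct: on the unnormalized `X`, the SBL precisions differ from sklearn's `lambda_` by more than a constant factor. It could be because terms are not removed here. That seemed unlikely at first, but note that `ARDRegression` prunes every feature whose precision exceeds `threshold_lambda` (default 1e4) and sets its coefficient to exactly 0. For example, on `X_normed` feature 8 has `lambda_` ≈ 1.09e4 and coefficient 0, whereas SBL keeps it (precision ≈ 91.7, mean ≈ 0.104). **TODO:** repeat the comparison with pruning disabled (a large `threshold_lambda`). For now, let's use SBL as is and come back to it later.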


In [77]:
def update(prior_params, X, y, hyper_prior_params, cap=1e6):
    """One fixed-point update of the SBL prior precisions.

    Parameters
    ----------
    prior_params : array, shape (n_terms + 1,)
        Current per-term precisions ``alpha`` followed by the noise precision
        ``beta`` as the last entry.
    X : array, shape (n_samples, n_terms)
        Design matrix / library.
    y : array, shape (n_samples, 1) or (n_samples,)
        Targets.
    hyper_prior_params : tuple (a, b)
        Hyperprior parameters entering only the ``beta`` update, as ``+ 2a`` in
        the numerator and ``+ 2b`` in the denominator. ``(0.0, 0.0)`` gives the
        plain evidence-maximisation update.
    cap : float, optional
        Upper bound applied to the updated ``alpha`` (default 1e6, the value that
        used to be hard-coded). It keeps terms whose posterior mean shrinks to ~0
        from getting an unbounded precision.

    Returns
    -------
    array, shape (n_terms + 1,)
        Updated ``alpha`` followed by the updated ``beta``; same layout as
        ``prior_params``.
    """
    alpha_prev, beta_prev = prior_params[:-1], prior_params[-1]
    a, b = hyper_prior_params
    n_samples = X.shape[0]

    # Posterior covariance and mean of the coefficients given the current precisions.
    Sigma = jnp.linalg.inv(beta_prev * X.T @ X + jnp.diag(alpha_prev))
    mu = beta_prev * Sigma @ X.T @ y
    # gamma_i = 1 - alpha_i * Sigma_ii: how strongly term i is determined by the data.
    gamma = 1 - alpha_prev * jnp.diag(Sigma)

    # squeeze() so a column-vector mean (from a 2-D y) lines up with the 1-D gamma.
    alpha = jnp.minimum(gamma / (mu ** 2).squeeze(), cap)
    beta = (n_samples - jnp.sum(gamma) + 2 * a) / (
        jnp.sum((y - X @ mu) ** 2) + 2 * b
    )

    return jnp.concatenate([alpha, beta[None]], axis=0)

In [78]:
# Starting point for `update`: unit precision for every library term, followed by the
# noise precision initialised to 1 / Var(y).
prior_params_init = jnp.append(jnp.ones(X.shape[1]), 1.0 / jnp.var(y))

In [79]:
# One update step from the initial guess on the normalized library, with (a, b) = (0, 0).
# NOTE(review): the bare printed array in the saved output appears to come from an earlier
# version of `update` that printed alpha; the current definition does not print.
update(prior_params_init, X_normed, y, (0.0, 0.0))

[8.1802086e+01 4.3041584e-01 8.7814875e-02 6.4752374e+00 1.4092141e+03
 9.0157993e-02 1.6357720e+00 1.5733417e+01 9.3208799e+00 1.7367724e-01
 7.7745576e+00 2.4815776e+00]


DeviceArray([8.1802086e+01, 4.3041584e-01, 8.7814875e-02, 6.4752374e+00,
             1.4092141e+03, 9.0157993e-02, 1.6357720e+00, 1.5733417e+01,
             9.3208799e+00, 1.7367724e-01, 7.7745576e+00, 2.4815776e+00,
             9.3036387e+03], dtype=float32)