# Generalized representer formula

In [2]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from Kernels import get_gaussianRBF
from KernelTools import *
from EquationModel import CholInducedRKHS

  from .autonotebook import tqdm as notebook_tqdm


## Introduction to `CholInducedRKHS`

Given

- Kernel: $K: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}$,
- Basis points: $X = \left\{x_1,\dots,x_M\right\} \subset \mathbb{R}^d$,
- Basis operators: $\phi = \left[\phi_1,\dots,\phi_Q \right]$,

we can represent an element of the corresponding RKHS $H_K$ of $K$ as

$$
\begin{align*}
u(x) &= \sum_{m=1}^M \phi_1 \left(K(x,x')\right)_{x' = x_m} \beta_{1,m} + \dots + \sum_{m=1}^M \phi_Q \left(K(x,x')\right)_{x' = x_m} \beta_{Q,m} \\
&:= \sum_{m=1}^M  K(x,\phi_{1,m}) \beta_{1,m} + \dots + \sum_{m=1}^M K(x,\phi_{Q,m}) \beta_{Q,m} \\
&:= K(x,\delta_X \circ \phi_{1})^\top \beta_{1} + \dots + K(x,\delta_X \circ \phi_{Q})^\top \beta_{Q} \\
&:= K(x,\phi)^\top \beta \\
\end{align*} 
$$
where $\{\beta_1,\dots,\beta_Q\} \subset \mathbb{R}^M$ and $\beta = \operatorname{vec}[\beta_1,\dots,\beta_Q] \in \mathbb{R}^{MQ}$. The coefficients $\beta$ are parametrized by a vector $\alpha \in \mathbb{R}^{MQ}$ through

$$
\beta = \left(L^\top\right)^{-1} \alpha
$$

where $L$ is the Cholesky factor of the kernel matrix, $K(\phi,\phi) = L L^\top$, with

$$
K(\phi,\phi) :=
\begin{bmatrix}
K(\phi_1,\phi_1) & \dots &K(\phi_1,\phi_Q) \\
\vdots & \ddots & \vdots \\
K(\phi_Q,\phi_1) & \dots &K(\phi_Q,\phi_Q) \\
\end{bmatrix}.
$$

In summary, our model given the basis points, basis operators and a kernel is given by,
$$
u(x) = K(x,\phi)^T \left(L^\top\right)^{-1} \alpha
$$

In [12]:
# Inputs of the model
# Kernel K: Gaussian RBF built with parameter 0.1
# (presumably the length scale -- confirm in Kernels.get_gaussianRBF)
K = get_gaussianRBF(0.1)
# Basis points X: M = 50 standard-normal samples in d = 2 dimensions
basis_pts = jax.random.normal(key=jax.random.PRNGKey(0), shape=(50, 2))
# Basis operators phi: Q = 3 operators applied to the kernel
basis_ops = (eval_k, dx_k, dxx_k)

We can use the class `CholInducedRKHS` to create the function $u$

In [4]:
# Build the model u from the basis points, the basis operators and the kernel
# (passed positionally, in that order).
u_model = CholInducedRKHS(basis_pts, basis_ops, K)

If we know the value of  $\alpha$

In [5]:
# Coefficient vector alpha: one entry per (basis operator, basis point) pair,
# i.e. Q * M entries, all set to one here.
alpha = jnp.ones((len(basis_ops) * len(basis_pts),))

> Notice that `u_model.num_params = len(basis_pts)*len(basis_ops)`

In order to evaluate $u$ at a new set of points $\bar X = \{x_1,\dots,x_N\} \subset \mathbb{R}^d$, we call the method `point_evaluate`

In [6]:
# N = 200 evaluation points in R^2.
# NOTE(review): this reuses PRNGKey(0), the same key used for the basis points;
# consider deriving a separate key with jax.random.split.
eval_pts = jax.random.normal(key=jax.random.PRNGKey(0), shape=(200, 2))
# Values of u at every evaluation point (shown as the cell output)
u_model.point_evaluate(eval_pts, alpha)

Array([ 1.02790385e-01,  2.23019689e-01,  1.39417505e+00,  1.53730869e+00,
        7.65631157e-13,  1.27246284e+00,  3.59051526e-01,  1.03798795e+00,
        2.25111872e-01,  1.45417929e-01,  2.90732443e-01,  1.75836647e+00,
        2.74264395e-01,  1.57710910e-01,  5.71521744e-02,  1.41322706e-03,
        4.13329840e-01,  1.07529664e+00,  6.71443966e-14,  1.45550012e-16,
        5.52786514e-02,  1.04689646e+00,  4.67338413e-02,  3.10636305e-19,
        3.33243981e-03,  3.81695614e-35,  5.02751209e-05,  2.80149698e-01,
        1.69907138e-02,  0.00000000e+00,  9.74068940e-01,  6.98249400e-01,
        1.10536337e+00,  7.99819708e-01,  1.96490530e-33,  1.04360925e-02,
        1.53633437e-07,  1.40858218e-01,  4.11508530e-02,  3.48224282e-01,
        9.64362249e-02,  2.11817653e-24,  1.32595956e-01,  3.28202329e-14,
        6.04260981e-01,  6.29478774e-04,  7.58219361e-02,  1.40270126e+00,
        4.69731271e-01,  2.50477111e-04,  4.71596956e-01,  3.65466923e-01,
        2.93266475e-01,  

If we instead want $u$ itself as a callable, we can request it with `get_eval_function`; note that the returned function is not vectorized.

In [11]:
# Fix alpha and get u back as a plain Python callable.
# Per the note above this cell the callable is not vectorized -- presumably it
# takes a single point, so wrap it (e.g. with jax.vmap) for batches; TODO confirm.
u_func = u_model.get_eval_function(alpha)
type(u_func)

function

To get the matrix $K(\phi,\phi)$

In [None]:
# Kernel matrix K(phi, phi): shape (M*Q, M*Q) = (150, 150) for M = 50, Q = 3.
u_model.kmat.shape

(150, 150)

To get the Cholesky factor $L^T$

In [None]:
# Transposed Cholesky factor L^T of K(phi, phi); same (M*Q, M*Q) shape as the kernel matrix.
u_model.cholT.shape

(150, 150)

## Fit the model

Let's call $U = \{u: u(x) = K(x,\phi)^\top\beta \text{ for some } \beta\} \subset H_K$. Say we want to fit a model $u \in U$ to the data
$$
\bar X = \{x_1,\dots,x_N\}\subset \mathbb{R}^d \text{ with corresponding values }  \{y_1,\dots,y_N\} \subset \mathbb{R}
$$

and for simplicity let $y = [y_1,\dots,y_N]^\top$. We look for the minimum-norm element of $U$ that interpolates the data:

$$
\underset{u \in U}{\min} \|u\|_{H_K}^2 \quad \text{s.t.} \quad u(\bar X) = y
$$

Notice that we can relax the problem to

$$
\underset{u \in U}{\min} \|u\|_{H_K}^2 + \|u(\bar X) - y\|_{2}^2  .
$$

In addition, since each element of $U$ can be identified with its coefficient vector $\beta$, the problem above is equivalent to the finite-dimensional problem

$$
\underset{\beta \in \mathbb{R}^{MQ}}{\min} \|K(\cdot,\phi)\beta\|_{H_K}^2 + \|K(\bar X,\phi)\beta - y\|_{2}^2,
$$
and finally, using the relation between $\alpha$ and $\beta$

$$
\begin{align*}
\|K(\cdot,\phi)\beta\|_{H_K}^2 
&= \|\sum_{j=1}^{MQ} K(\cdot,\phi_j)\beta_j\|_{H_K}^2 \\
&= \langle \sum_{j=1}^{MQ} K(\cdot,\phi_j)\beta_j, \sum_{j=1}^{MQ} K(\cdot,\phi_j)\beta_j \rangle_{H_K}\\
&= \beta^\top K(\phi,\phi) \beta\\
&= \left(\left(L^\top\right)^{-1} \alpha \right)^\top K(\phi,\phi) \left(\left(L^\top\right)^{-1} \alpha \right)\\
&= \alpha^\top L^{-1} K(\phi,\phi) \left(L^\top\right)^{-1} \alpha \\
&= \alpha^\top L^{-1} L L^\top \left(L^\top\right)^{-1} \alpha \\
&= \alpha^\top \alpha \\
&= \|\alpha\|_{2}^2. 
\end{align*}
$$

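Thus, introducing a regularization weight $\lambda > 0$ to balance the norm penalty against the data misfit, we want to solve the regularized least-squares (quadratic) problem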
$$
\underset{\alpha \in \mathbb{R}^{MQ}}{\min} \quad \lambda\|\alpha\|_{2}^2 + \|K(\bar X,\phi)\left(L^\top\right)^{-1}\alpha - y\|_{2}^2.
$$
For simplicity, let's call $R = K(\bar X,\phi)\left(L^\top\right)^{-1}$. Then the solution using the SVD of $R$ is given by

$$
\alpha =\sum_{i=1}^{r} \frac{\sigma_i\left(u_i^T y\right)}{\sigma_i^2+\lambda} v_i
$$

where $R = \sum_{i=1}^{r}\sigma_i u_i v_i^T$ is the (thin) SVD of $R$ and $r = \min\{MQ,N\}$





This vector can be computed via the method `get_fitted_params()`

In [None]:
# Training data: N = 70 inputs in R^2 with scalar targets.
# JAX keys must not be reused: drawing X_bar and y from the same key would make
# the targets depend on the inputs. Split the key so that each draw gets its
# own independent random stream.
key_X, key_y = jax.random.split(jax.random.PRNGKey(0))
X_bar = jax.random.normal(key_X, shape=(70, 2))
y = jax.random.normal(key_y, shape=(70,))
# Regularization weight lambda of the regularized least-squares problem
lmb = 1e-6
# Coefficients alpha fitted to (X_bar, y) with regularization lmb
fitted_alpha = u_model.get_fitted_params(X_bar, y, lam=lmb)