# GPSR - Cylinder Example

**TODO:** change `compute_barycenter` to the newer version.

In [None]:
# import jax
# jax.config.update('jax_array', False)
# # again, this only works on startup!
# from jax.config import config
# config.update("jax_enable_x64", True)

import time
from hydra_zen import instantiate, make_config, builds
import os
import sys
from tqdm import tqdm
import pickle

import jax.numpy as np
from jax import random, vmap
import jax.numpy.linalg as lin

from dataclasses import field
from typing import Callable, Tuple
import chex
from chex import assert_shape, assert_rank

import numpyro
import numpyro.distributions as dist
import numpyro.handlers as handlers
from numpyro.infer import SVI, Trace_ELBO, autoguide, init_to_value

from grassgp.inference import run_inference

from grassgp.utils import get_save_path, subspace_angle, to_dictconf, unvec, vec, kron_chol
# from grassgp.utils import safe_save_jax_array_dict as safe_save
# from grassgp.utils import load_and_convert_to_samples_dict as load_data
from grassgp.grassmann import valid_grass_point, convert_to_projs, grass_log, grass_exp, valid_grass_tangent, grass_dist
from grassgp.kernels import rbf
# from grassgp.models import GrassGP
# from grassgp.models_optimised import MatGP as MatGP_optimised
# from grassgp.models import MatGP
from grassgp.means import zero_mean
from grassgp.plot_utils import flatten_samples, plot_grass_dists, plot_AS_dir_preds

from tqdm import tqdm
import pandas as pd
import statsmodels.api as sm

import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams["figure.figsize"] = (10,6)

In [None]:
def run_svi_for_map(rng_key, model, maxiter, step_size, *args):
    """Run SVI with an ``AutoDelta`` guide and return the SVI result.

    Args:
        rng_key: PRNG key handed to ``SVI.run``.
        model: NumPyro model callable.
        maxiter: number of optimisation steps.
        step_size: step size for the Adam optimiser.
        *args: positional arguments forwarded to the model by ``SVI.run``.

    Returns:
        The object returned by ``SVI.run`` (callers in this file read its
        ``params`` and ``losses`` attributes).
    """
    t0 = time.time()
    delta_guide = autoguide.AutoDelta(model)
    adam = numpyro.optim.Adam(step_size)
    svi_results = SVI(model, delta_guide, adam, Trace_ELBO()).run(rng_key, maxiter, *args)
    # Report wall-clock time of the whole fit.
    print('\nSVI elapsed time:', time.time() - t0)
    return svi_results

In [None]:
@chex.dataclass
class MatGP:
    """Gaussian process over ``d x n`` matrices with a Kronecker covariance.

    For ``N`` input locations the vectorised outputs have covariance
    ``kron(K, Omega)``, where ``K`` is the ``N x N`` kernel matrix produced by
    ``k`` and ``Omega`` is a fixed ``(d*n) x (d*n)`` matrix.
    """
    # Dimension of a single input location (1 for scalar inputs).
    d_in: int
    # Shape ``(d, n)`` of each matrix-valued output.
    d_out: Tuple[int, int]
    # Mean function; vmapped over locations, each call must give a ``(d, n)`` array.
    mu: Callable = field(repr=False)
    # Kernel; ``k(a, b)`` must give the ``len(a) x len(b)`` kernel matrix.
    k: Callable = field(repr=False)
    # Covariance shared by all locations, shape ``(d*n, d*n)``.
    Omega: chex.ArrayDevice = field(repr=False)
    # Diagonal jitter added to the kernel matrix before it is factorised.
    cov_jitter: float = field(default=1e-8, repr=False)

    def __post_init__(self):
        """Check that ``Omega`` matches the flattened output size."""
        d, n = self.d_out
        d_n = d * n
        assert_shape(self.Omega, (d_n, d_n),
                    custom_message=f"Omega has shape {self.Omega.shape}; expected shape {(d_n, d_n)}")

    def model(self, s: chex.ArrayDevice, use_kron_chol: bool = True) -> chex.ArrayDevice:
        """NumPyro model that draws one matrix per location in ``s``.

        Registers the sample site ``"Z"`` and the deterministic sites
        ``"vec_Vs"`` and ``"Vs"``.

        Args:
            s: input locations; a 1-D array of length ``N`` when
                ``d_in == 1``, otherwise a 2-D array with one location per row.
            use_kron_chol: when true the Cholesky factor comes from
                ``kron_chol``; otherwise the dense Kronecker product is
                formed and factorised directly.

        Returns:
            The draws stacked along the first axis, one ``(d, n)`` matrix per
            location.
        """
        d, n = self.d_out
        d_n = d * n
        # The array rank is 1 for scalar inputs and 2 for vector inputs, which
        # is the layout `predict` checks (`s.shape[1] == d_in`). The previous
        # `assert_rank(s, self.d_in)` equated rank with input dimension and
        # therefore rejected valid (N, d_in) inputs whenever d_in >= 3.
        assert_rank(s, 1 if self.d_in == 1 else 2)
        N = s.shape[0]

        # compute mean matrix M = [mu(s[1]), mu(s[2]), ..., mu(s[N])]
        M = np.hstack(vmap(self.mu)(s))
        assert_shape(M, (d, n*N))

        # compute kernel matrix
        K = self.k(s, s)
        assert_shape(K, (N, N))

        # compute covariance matrix and cholesky factor
        K_jittered = K + self.cov_jitter * np.eye(N)
        if use_kron_chol:
            Chol = kron_chol(K_jittered, self.Omega)
        else:
            Chol = lin.cholesky(np.kron(K_jittered, self.Omega))

        # sample vec_Vs through a standard-normal latent (non-centred form)
        Z = numpyro.sample("Z", dist.Normal().expand([N*d_n]))
        vec_Vs = numpyro.deterministic("vec_Vs", vec(M) + Chol @ Z)

        # form Vs: cut the long vector into N equal pieces and un-vectorise each
        Vs = numpyro.deterministic("Vs", vmap(lambda params: unvec(params, d, n))(np.array(vec_Vs.split(N))))
        return Vs

    def sample(self, seed: int, s: chex.ArrayDevice) -> chex.ArrayDevice:
        """Draw from ``model`` at locations ``s`` using ``seed`` as the RNG seed."""
        seeded_model = handlers.seed(self.model, rng_seed=seed)
        return seeded_model(s)

    def predict(self, key: chex.ArrayDevice, s_test: chex.ArrayDevice, s_train: chex.ArrayDevice, Vs_train: chex.ArrayDevice, jitter: float = 1e-8) -> Tuple[chex.ArrayDevice, chex.ArrayDevice]:
        """Condition on training matrices and predict at ``s_test``.

        Args:
            key: PRNG key used to draw the predictive sample.
            s_test: test locations.
            s_train: training locations.
            Vs_train: training matrices, one per training location.
            jitter: diagonal term added to the posterior covariance before
                sampling.

        Returns:
            ``(post_mean, pred)``: the posterior mean and one posterior draw,
            each stacked along the first axis with one ``(d, n)`` matrix per
            test location.
        """
        d, n = self.d_out
        d_in = self.d_in
        N_train = s_train.shape[0]
        N_test = s_test.shape[0]
        if d_in > 1:
            assert s_train.shape[1] == d_in
            assert s_test.shape[1] == d_in

        # compute means
        M_train = np.hstack(vmap(self.mu)(s_train))
        M_test = np.hstack(vmap(self.mu)(s_test))
        assert_shape(M_train, (d, n*N_train))
        assert_shape(M_test, (d, n*N_test))

        # compute kernels between train and test locs
        K_train_train = self.k(s_train, s_train)
        assert_shape(K_train_train, (N_train, N_train))
        K_train_test = self.k(s_train, s_test)
        assert_shape(K_train_test, (N_train, N_test))
        K_test_train = K_train_test.T
        K_test_test = self.k(s_test, s_test)
        assert_shape(K_test_test, (N_test, N_test))

        # compute posterior mean and cov
        K_test_train_Omega = np.kron(K_test_train, self.Omega)
        K_train_test_Omega = np.kron(K_train_test, self.Omega)
        K_test_test_Omega = np.kron(K_test_test, self.Omega)

        # A pseudo-inverse is used for Omega so that a singular Omega is
        # tolerated; K_train_train itself is inverted exactly.
        K_train_train_inv = lin.inv(K_train_train)
        Omega_pinv = lin.pinv(self.Omega)
        K_train_train_Omega_pinv = np.kron(K_train_train_inv, Omega_pinv)
        mean_sols = K_train_train_Omega_pinv @ (vec(np.hstack(Vs_train)) - vec(M_train))
        cov_sols = K_train_train_Omega_pinv @ K_train_test_Omega

        vec_post_mean = vec(M_test) + K_test_train_Omega @ mean_sols
        assert_shape(vec_post_mean, (d*n*N_test,),
                     custom_message=f"vec_post_mean should have shape {(d*n*N_test,)}; obtained {vec_post_mean.shape}")

        post_cov = K_test_test_Omega - K_test_train_Omega @ cov_sols
        assert_shape(post_cov, (d*n*N_test, d*n*N_test),
                     custom_message=f"post_cov should have shape {(d*n*N_test,d*n*N_test)}; obtained {post_cov.shape}")

        # sample predictions; the jitter keeps the covariance factorisable
        post_cov = post_cov + jitter * np.eye(d*n*N_test)
        vec_pred = dist.MultivariateNormal(loc=vec_post_mean, covariance_matrix=post_cov).sample(key)
        assert_shape(vec_pred, (d*n*N_test,),
                     custom_message=f"vec_pred should have shape {(d*n*N_test,)}; obtained {vec_pred.shape}")

        # unvec mean and preds and return
        post_mean = vmap(lambda params: unvec(params, d, n))(np.array(vec_post_mean.split(N_test)))
        pred = vmap(lambda params: unvec(params, d, n))(np.array(vec_pred.split(N_test)))
        return post_mean, pred

In [None]:
# @chex.dataclass
# class MatGP_optimised:
#     d_in: int
#     d_out: Tuple[int, int]
#     mu: Callable = field(repr=False)
#     k: Callable = field(repr=False)
#     Omega_diag_chol: chex.ArrayDevice = field(repr=False)
#     cov_jitter: float = field(default=1e-8, repr=False)

#     def __post_init__(self):
#         d, n = self.d_out
#         d_n = d * n
#         assert_shape(self.Omega_diag_chol, (d_n,),
#                     custom_message=f"Omega_diag_chol has shape {self.Omega_diag_chol.shape}; expected shape {(d_n,)}")

#     def model(self, s: chex.ArrayDevice) -> chex.ArrayDevice:
#         d, n = self.d_out
#         d_n = d * n
#         assert_rank(s, self.d_in)
#         N = s.shape[0]

#         # compute mean matrix M = [mu(s[1]), mu(s[2]), ..., mu(s[N])]
#         M = np.hstack(vmap(self.mu)(s))
#         assert_shape(M, (d, n*N))
#         # ! TODO: check this out
#         vec_M = vec(M)

#         # compute kernel matrix
#         K = self.k(s, s)
#         assert_shape(K, (N, N))
            
#         K_chol = lin.cholesky(K + self.cov_jitter * np.eye(N))
#         # Omega_diag_chol = np.sqrt(self.Omega_diag)

#         # sample vec_Vs
#         # Z = numpyro.sample("Z", dist.MultivariateNormal(covariance_matrix=np.eye(N*d_n)))
#         Z = numpyro.sample("Z", dist.Normal().expand([N*d_n]))
#         unvec_Z = unvec(Z, d_n, N)
#         # vec_Vs = numpyro.deterministic("vec_Vs", vec(M + np.einsum('i,ij->ij', self.Omega_diag_chol, unvec_Z @ K_chol.T)))
#         vec_Vs = numpyro.deterministic("vec_Vs", vec_M + vec(np.einsum('i,ij->ij', self.Omega_diag_chol, unvec_Z @ K_chol.T)))

#         # form Vs
#         Vs = numpyro.deterministic("Vs", vmap(lambda params: unvec(params, d, n))(np.array(vec_Vs.split(N))))
#         return Vs

#     def sample(self, seed: int, s: chex.ArrayDevice) -> chex.ArrayDevice:
#         model = self.model
#         seeded_model = handlers.seed(model, rng_seed=seed)
#         return seeded_model(s)

In [None]:
# @chex.dataclass
# class GrassGP:
#     d_in: int
#     d_out: Tuple[int, int]
#     mu: Callable = field(repr=False)
#     k: Callable = field(repr=False)
#     Omega_diag_chol: chex.ArrayDevice = field(repr=False)
#     U: chex.ArrayDevice
#     cov_jitter: float = field(default=1e-4, repr=False)

#     def __post_init__(self):
#         d, n = self.d_out
#         d_n = d * n
#         assert_shape(self.Omega_diag_chol, (d_n,),
#                     custom_message=f"Omega_diag_chol has shape {self.Omega_diag_chol.shape}; expected shape {(d_n,)}")
#         assert_shape(self.U, (d, n),
#                     custom_message=f"U has shape {self.U.shape}; expected shape {(d, n)}")
#         tol = 1e-06
#         # assert valid_grass_point(self.U), f"U is not a valid point on Grassmann manifold G({d},{n}) at tolerance level {tol = }"

#     @property
#     def V(self) -> MatGP:
#         mat_gp = MatGP_optimised(d_in=self.d_in, d_out=self.d_out, mu=self.mu, k=self.k, Omega_diag_chol=self.Omega_diag_chol, cov_jitter=self.cov_jitter)
#         return mat_gp

#     def tangent_model(self, s: chex.ArrayDevice) -> chex.ArrayDevice:
#         d, n = self.d_out
#         N = s.shape[0]
#         Vs = self.V.model(s)
#         I_UUT = np.eye(d) - self.U @ self.U.T
#         Deltas = numpyro.deterministic("Deltas", np.einsum('ij,ljk->lik', I_UUT, Vs))
#         assert_shape(Deltas, (N, d, n),
#                     custom_message=f"Deltas has shape {Deltas.shape}; expected shape {(N, d, n)}")
#         return Deltas

#     def sample_tangents(self, seed: int, s: chex.ArrayDevice) -> chex.ArrayDevice:
#         tangent_model = self.tangent_model
#         seeded_model = handlers.seed(tangent_model, rng_seed=seed)
#         Deltas = seeded_model(s)
#         assert vmap(lambda Delta: valid_grass_tangent(self.U, Delta))(Deltas).all()
#         return Deltas

#     def sample_grass(self, seed: int, s: chex.ArrayDevice, reortho: bool = False) -> chex.ArrayDevice:
#         Deltas = self.sample_tangents(seed, s)
#         Ws = convert_to_projs(Deltas, self.U, reorthonormalize=reortho)
#         return Ws

#     def predict_tangents(self, key: chex.ArrayDevice, s_test: chex.ArrayDevice, s_train: chex.ArrayDevice, Vs_train: chex.ArrayDevice, jitter: float = 1e-8) -> Tuple[chex.ArrayDevice, chex.ArrayDevice]:
#         d, _ = self.d_out
#         I_UUT = np.eye(d) - self.U @ self.U.T
#         V_mu = lambda s: I_UUT @ self.mu(s)
#         V_Omega = I_UUT @ np.diag(self.Omega_diag_chol) @ I_UUT.T
#         V = MatGP(d_in=self.d_in, d_out=self.d_out, mu=V_mu, k=self.k, Omega=V_Omega, cov_jitter=self.cov_jitter)
#         Deltas_mean, Deltas_pred = V.predict(key, s_test, s_train, Vs_train, jitter=jitter)
#         return Deltas_mean, Deltas_pred

#     def predict_grass(self, key: chex.ArrayDevice, s_test: chex.ArrayDevice, s_train: chex.ArrayDevice, Vs_train: chex.ArrayDevice, jitter: float = 1e-8, reortho: bool = False) -> Tuple[chex.ArrayDevice, chex.ArrayDevice]:
#         Deltas_mean, Deltas_pred = self.predict_tangents(key, s_test, s_train, Vs_train, jitter=jitter)
#         Ws_mean = convert_to_projs(Deltas_mean, self.U, reorthonormalize=reortho)
#         Ws_pred = convert_to_projs(Deltas_pred, self.U, reorthonormalize=reortho)
#         return Ws_mean, Ws_pred

In [None]:
@chex.dataclass
class GrassGP:
    """Gaussian process on a Grassmann manifold built from a tangent-space ``MatGP``.

    Matrix draws from a ``MatGP`` are projected with ``I - U U^T`` to give
    tangent vectors at the anchor point ``U``; ``convert_to_projs`` then maps
    those tangents to points on the manifold.
    """
    # Dimension of a single input location.
    d_in: int
    # Shape ``(d, n)`` of the anchor point and of every tangent matrix.
    d_out: Tuple[int, int]
    # Mean function of the underlying MatGP.
    mu: Callable = field(repr=False)
    # Kernel of the underlying MatGP.
    k: Callable = field(repr=False)
    # Covariance of the underlying MatGP, shape ``(d*n, d*n)``.
    Omega: chex.ArrayDevice = field(repr=False)
    # Anchor point, shape ``(d, n)``.
    U: chex.ArrayDevice
    # Jitter forwarded to the underlying MatGP.
    cov_jitter: float = field(default=1e-4, repr=False)

    def __post_init__(self):
        """Validate the shapes of ``Omega`` and ``U`` against ``d_out``."""
        d, n = self.d_out
        d_n = d * n
        assert_shape(
            self.Omega,
            (d_n, d_n),
            custom_message=f"Omega has shape {self.Omega.shape}; expected shape {(d_n, d_n)}",
        )
        assert_shape(
            self.U,
            (d, n),
            custom_message=f"U has shape {self.U.shape}; expected shape {(d, n)}",
        )
        # NOTE: the `valid_grass_point(self.U)` check that used to follow
        # (tolerance 1e-06) is currently disabled.

    def _tangent_projector(self) -> chex.ArrayDevice:
        """Return the ``d x d`` matrix ``I - U U^T``."""
        return np.eye(self.d_out[0]) - self.U @ self.U.T

    @property
    def V(self) -> MatGP:
        """A fresh ``MatGP`` sharing this object's mean, kernel, ``Omega`` and jitter."""
        return MatGP(
            d_in=self.d_in,
            d_out=self.d_out,
            mu=self.mu,
            k=self.k,
            Omega=self.Omega,
            cov_jitter=self.cov_jitter,
        )

    def tangent_model(self, s: chex.ArrayDevice) -> chex.ArrayDevice:
        """NumPyro model producing tangent matrices at the locations ``s``.

        Runs the underlying ``MatGP`` model, projects every draw with
        ``I - U U^T`` and registers the result as the deterministic site
        ``"Deltas"`` of shape ``(N, d, n)``.
        """
        d, n = self.d_out
        num_locs = s.shape[0]
        raw_draws = self.V.model(s)
        projector = self._tangent_projector()
        # Left-multiply each (d, n) draw by the projector.
        Deltas = numpyro.deterministic("Deltas", np.einsum('ab,sbc->sac', projector, raw_draws))
        assert_shape(
            Deltas,
            (num_locs, d, n),
            custom_message=f"Deltas has shape {Deltas.shape}; expected shape {(num_locs, d, n)}",
        )
        return Deltas

    def sample_tangents(self, seed: int, s: chex.ArrayDevice) -> chex.ArrayDevice:
        """Draw tangents at ``s`` with RNG seed ``seed`` and assert they are valid at ``U``."""
        Deltas = handlers.seed(self.tangent_model, rng_seed=seed)(s)
        all_valid = vmap(lambda Delta: valid_grass_tangent(self.U, Delta))(Deltas).all()
        assert all_valid
        return Deltas

    def sample_grass(self, seed: int, s: chex.ArrayDevice, reortho: bool = False) -> chex.ArrayDevice:
        """Draw tangents at ``s`` and convert them with ``convert_to_projs``."""
        tangents = self.sample_tangents(seed, s)
        return convert_to_projs(tangents, self.U, reorthonormalize=reortho)

    def predict_tangents(self, key: chex.ArrayDevice, s_test: chex.ArrayDevice, s_train: chex.ArrayDevice, Vs_train: chex.ArrayDevice, jitter: float = 1e-8) -> Tuple[chex.ArrayDevice, chex.ArrayDevice]:
        """Predict tangent matrices at ``s_test`` given training tangents.

        Builds a ``MatGP`` whose mean and ``Omega`` are projected with
        ``I - U U^T`` and delegates to its ``predict``.

        Returns:
            ``(Deltas_mean, Deltas_pred)`` as returned by ``MatGP.predict``.
        """
        projector = self._tangent_projector()

        def projected_mean(s):
            return projector @ self.mu(s)

        # NOTE(review): the projector is (d, d) while Omega is asserted to be
        # (d*n, d*n), so this product only conforms when n == 1 - confirm the
        # intended form for n > 1.
        projected_Omega = projector @ self.Omega @ projector.T
        tangent_gp = MatGP(
            d_in=self.d_in,
            d_out=self.d_out,
            mu=projected_mean,
            k=self.k,
            Omega=projected_Omega,
            cov_jitter=self.cov_jitter,
        )
        Deltas_mean, Deltas_pred = tangent_gp.predict(key, s_test, s_train, Vs_train, jitter=jitter)
        return Deltas_mean, Deltas_pred

    def predict_grass(self, key: chex.ArrayDevice, s_test: chex.ArrayDevice, s_train: chex.ArrayDevice, Vs_train: chex.ArrayDevice, jitter: float = 1e-8, reortho: bool = False) -> Tuple[chex.ArrayDevice, chex.ArrayDevice]:
        """Predict tangents and convert both mean and draw with ``convert_to_projs``."""
        tangent_mean, tangent_pred = self.predict_tangents(key, s_test, s_train, Vs_train, jitter=jitter)
        Ws_mean = convert_to_projs(tangent_mean, self.U, reorthonormalize=reortho)
        Ws_pred = convert_to_projs(tangent_pred, self.U, reorthonormalize=reortho)
        return Ws_mean, Ws_pred

In [None]:
# Generate the dataset: N unit vectors (cos(k s), sin(k s)) for s in [0, 1],
# stored as (N, 2, 1) arrays.
N = 40
s_test = np.linspace(0, 1, N)
k = 2 * np.pi
x = np.cos(k * s_test)[:, None]
y = np.sin(k * s_test)[:, None]
Ws_test = np.concatenate((x, y), axis=1)[..., None]
assert vmap(valid_grass_point)(Ws_test).all()
d, n = Ws_test.shape[1:]

# Plot each component of the dataset against s, one figure per component.
for i in range(d):
    component = Ws_test[:, i, 0]
    plt.plot(s_test, component)
    plt.xlabel(r'$s$')
    plt.title(f'{i+1}th component of projection')
    plt.grid()
    plt.show()

In [None]:
# Keep every s_gap-th point as training data.
s_gap = 3
s_train = s_test[::s_gap].copy()
Ws_train = Ws_test[::s_gap].copy()
print(f"Number of training points: {s_train.shape[0]}")

# Overlay the training points (red) on the full curves, one figure per component.
for i in range(d):
    plt.plot(s_test, Ws_test[:, i, 0])
    plt.scatter(s_train, Ws_train[:, i, 0], c='r')
    plt.xlabel(r'$s$')
    plt.title(f'{i+1}th component of projection')
    plt.grid()
    plt.show()

In [None]:
def subspace_angle_to_grass_pt(theta):
    """Return the ``(2, 1)`` column ``[[cos(theta)], [sin(theta)]]``.

    If ``theta`` holds several angles only the first one is used.
    """
    cosines = np.cos(theta).reshape(-1)
    sines = np.sin(theta).reshape(-1)
    # Rows are (cos, sin) pairs; keep the first pair and make it a column.
    pairs = np.stack((cosines, sines), axis=-1)
    return pairs[0][:, None]

def loss(theta, Ws):
    """Sum of squared ``grass_dist`` values between the point at angle ``theta`` and each element of ``Ws``."""
    candidate = subspace_angle_to_grass_pt(theta)

    def squared_dist(W_other):
        return grass_dist(candidate, W_other)**2

    return vmap(squared_dist)(Ws).sum()

In [None]:
# Grid search over subspace angles: the angle with the smallest summed squared
# distance to the training points is used as the anchor point.
thetas = np.linspace(0, np.pi, 1000)
losses = vmap(loss, in_axes=(0, None))(thetas, Ws_train)

theta_argmin = thetas[np.argmin(losses)]
anchor_point = subspace_angle_to_grass_pt(theta_argmin)
assert valid_grass_point(anchor_point)

# Loss curve with the selected angle highlighted.
plt.plot(thetas, losses)
plt.scatter(theta_argmin, np.min(losses), color="red", label='loss of anchor point')
plt.xlabel("subspace angle")
plt.title("Plot of loss vs subspace angle to determine Karcher mean")
plt.grid()
plt.legend()
plt.show()

# Training points and the anchor point drawn in the plane.
fig, ax = plt.subplots(figsize=(7,7))
ax.set(xlim=(-1.25,1.25), ylim=(-1.25,1.25))
ax.scatter(Ws_train[:,0,0], Ws_train[:,1,0], color="blue", alpha=0.25, label='training points')
ax.scatter(anchor_point[0,0], anchor_point[1,0], color="red", marker="*", label="anchor point")
ax.set_title("Plot of training points with anchor point on circle rep of Gr(2,1)")
ax.grid()
ax.legend()

plt.show()

In [None]:
# Map the training data and the full data into the tangent space at the
# anchor point (the anchor is held fixed, only the data axis is vmapped).
log_Ws_train = vmap(grass_log, in_axes=(None, 0))(anchor_point, Ws_train)
log_Ws_test = vmap(grass_log, in_axes=(None, 0))(anchor_point, Ws_test)

In [None]:
# Subspace angle of every point in the full data set and in the training subset.
alphas = np.array(list(map(subspace_angle, Ws_test)))
alphas_train = np.array(list(map(subspace_angle, Ws_train)))

In [None]:
# Settings read by `model` below. For 'Omega', 'var', 'length' and 'noise' a
# value of None means "sample this quantity"; 'noise' is only consulted when
# 'require_noise' is true.
model_config = dict(
    anchor_point=anchor_point.tolist(),
    d_in=1,
    Omega=None,
    k_include_noise=True,
    var=1.0,
    length=None,
    noise=None,
    require_noise=False,
    jitter=1e-06,
    cov_jitter=1e-4,
    L_jitter=1e-8,
    reorthonormalize=False,
    # scale of the LogNormal priors on the kernel parameters
    b=0.5,
    # row scale of the MatrixNormal likelihood (previously tried: 0.0075)
    ell=0.01,
)

def model(s, log_Ws, grass_config = model_config):
    """NumPyro model for tangent-space observations ``log_Ws`` at locations ``s``.

    Args:
        s: input locations; ``s.shape[0]`` is the number of observations ``N``.
        log_Ws: observed tangent matrices of shape ``(N, d, n)``, or ``None``.
        grass_config: mapping with the keys defined in ``model_config``.

    Sample sites, in order: ``sigmas`` and ``L_factor`` (only when
    ``grass_config['Omega']`` is None), ``kernel_var`` / ``kernel_length`` /
    ``kernel_noise`` (only for entries left as None, noise only when
    ``require_noise`` is set), the sites of ``GrassGP.tangent_model``, and the
    observed site ``log_W``.
    """
    U = np.array(grass_config['anchor_point'])
    d, n = U.shape
    N = s.shape[0]
    d_n = d * n
    if log_Ws is not None:
        assert log_Ws.shape == (N, d, n)

    # get/sample Omega
    fixed_Omega = grass_config['Omega']
    if fixed_Omega is not None:
        Omega = np.array(fixed_Omega)
    else:
        sigmas = numpyro.sample('sigmas', dist.LogNormal(0.0, 1.0).expand([d_n]))
        L_factor = numpyro.sample('L_factor', dist.LKJ(d_n, 1.0))
        L = numpyro.deterministic('L', L_factor + grass_config['L_jitter'] * np.eye(d_n))
        Omega = numpyro.deterministic('Omega', np.outer(sigmas, sigmas) * L)

    def fixed_or_sampled(cfg_key, site_name):
        """Use the configured value, or sample it from LogNormal(0, b) when it is None."""
        value = grass_config[cfg_key]
        if value is None:
            value = numpyro.sample(site_name, dist.LogNormal(0.0, grass_config['b']))
        return value

    # get/sample kernel params
    var = fixed_or_sampled('var', "kernel_var")
    length = fixed_or_sampled('length', "kernel_length")
    if grass_config['require_noise']:
        noise = fixed_or_sampled('noise', "kernel_noise")
    else:
        noise = 0.0
    kernel_params = {'var': var, 'length': length, 'noise': noise}

    # kernel function
    def k(t, s):
        return rbf(t, s, kernel_params, jitter=grass_config['jitter'], include_noise=grass_config['k_include_noise'])

    # mean function
    def mu(s):
        return zero_mean(s, d, n)

    # initialize GrassGP and sample Deltas
    grass_gp = GrassGP(
        d_in=grass_config['d_in'],
        d_out=(d,n),
        mu=mu,
        k=k,
        Omega=Omega,
        U=U,
        cov_jitter=grass_config['cov_jitter'],
    )
    Deltas = grass_gp.tangent_model(s)

    # likelihood
    # NOTE: the original author left an open question about which power of
    # `ell` belongs in the row scale below.
    ell = grass_config['ell']
    likelihood = dist.continuous.MatrixNormal(
        loc=Deltas,
        scale_tril_row=ell * np.eye(d),
        scale_tril_column=np.eye(n),
    )
    with numpyro.plate("N", N):
        numpyro.sample("log_W", likelihood, obs=log_Ws)

# hydra-zen config that partially builds `model` with `grass_config` bound.
TangentSpaceModelConf = builds(model, grass_config=model_config, zen_partial=True)

In [None]:
# Settings for the SVI run whose result initialises MCMC (read as cfg.svi.* in `train`).
SVIConfig = make_config(
    seed = 123514354575,
    maxiter = 15000,
    step_size = 0.001
)

# Settings for the MCMC run (read as cfg.train.* in `train`).
TrainConfig = make_config(
    seed = 9870687,
    n_warmup = 2000,
    n_samples = 7000,
    n_chains = 1,
    n_thinning = 2
)

# Top-level config grouping the model, SVI and MCMC settings.
Config = make_config(
    model = TangentSpaceModelConf,
    svi = SVIConfig,
    train = TrainConfig
)

In [None]:
def train(cfg):
    """Fit the model: SVI first, then MCMC initialised at the SVI solution.

    Reads the module-level ``s_train`` and ``log_Ws_train`` as data.

    Args:
        cfg: config exposing ``model``, ``svi`` (``seed``, ``maxiter``,
            ``step_size``) and ``train`` (``seed``, ``n_warmup``,
            ``n_samples``, ``n_chains``, ``n_thinning``).

    Returns:
        A dict holding the MCMC samples plus one ``"<param>-initial_value"``
        entry for every value used to initialise the sampler.
    """
    # instantiate grass model
    model = instantiate(cfg.model)

    # run SVI to get the estimate used to initialise MCMC
    svi_key = random.PRNGKey(cfg.svi.seed)
    maxiter, step_size = cfg.svi.maxiter, cfg.svi.step_size
    print(f"n_train = {s_train.shape[0]}")
    print("Running SVI for MAP estimate to initialise MCMC")
    svi_results = run_svi_for_map(svi_key, model, maxiter, step_size, s_train, log_Ws_train)

    # plot svi losses
    plt.plot(svi_results.losses)
    plt.show()

    # The guide's parameters are named "<site>_auto_loc"; drop that suffix to
    # recover the site names expected by `init_to_value`.
    suffix_len = len('_auto_loc')
    init_values = {name[:-suffix_len]: val for name, val in svi_results.params.items()}

    # run HMC
    train_key = random.PRNGKey(cfg.train.seed)
    mcmc_config = dict(
        num_warmup=cfg.train.n_warmup,
        num_samples=cfg.train.n_samples,
        num_chains=cfg.train.n_chains,
        thinning=cfg.train.n_thinning,
        init_strategy=init_to_value(values=init_values),
    )
    print("HMC starting.")
    mcmc = run_inference(train_key, mcmc_config, model, s_train, log_Ws_train)

    # Bundle the samples with the initial values under distinct keys.
    inference_data = mcmc.get_samples().copy()
    inference_data.update(
        {f"{param}-initial_value": initial_val for param, initial_val in init_values.items()}
    )
    return inference_data

In [None]:
# Render the model graph, tracing the instantiated model on the training data.
numpyro.render_model(instantiate(Config.model), model_args=(s_train,log_Ws_train))

In [None]:
# Run SVI followed by MCMC; the result holds the samples and the initial values.
inference_data = train(Config)

In [None]:
# Separate the MCMC samples from the "<param>-initial_value" entries that
# `train` added, and check that nothing was lost in the split.
samples = {name: val for name, val in inference_data.items() if 'initial_value' not in name}
initial_values = {name: val for name, val in inference_data.items() if 'initial_value' in name}
assert samples.keys() | initial_values.keys() == inference_data.keys()

In [None]:
# with open("results/samples_less_train_pts.pickle", 'rb') as f:
#     samples = pickle.load(f)

In [None]:
# Persist the MCMC samples. The output directory is created first so that a
# missing "results/" folder cannot make the dump fail after inference has run.
os.makedirs("results", exist_ok=True)
with open("results/samples_less_train_pts.pickle", 'wb') as f:
    pickle.dump(samples, f)

In [None]:
# Flatten the samples for the trace and ACF plots below; no sites are ignored.
my_samples = flatten_samples(samples, ignore=[])

In [None]:
# Collect the columns to trace: the kernel length plus every column whose
# name mentions Omega or sigmas.
trace_plot_vars = ['kernel_length']
for key in my_samples.keys():
    for tag in ('Omega', 'sigmas'):
        if tag in key:
            trace_plot_vars.append(key)

# One trace subplot per selected column, each with its own y-axis.
my_samples[trace_plot_vars].plot(subplots=True, figsize=(10,40), sharey=False)
plt.show()

In [None]:
# Autocorrelation plot (up to lag 100) for every traced column.
for var in trace_plot_vars:
    sm.graphics.tsa.plot_acf(my_samples[var], lags=100)
    plt.title(f"acf for {var}")
    plt.show()

In [None]:
tol=1e-5

# Convert every posterior draw of the tangents at the training locations.
def _draw_to_projs(Deltas):
    return convert_to_projs(Deltas, anchor_point, reorthonormalize=False)

samples_Ws_train = vmap(_draw_to_projs)(samples['Deltas'])

# Every converted point, in every draw, must be valid at tolerance `tol`.
assert vmap(vmap(lambda w: valid_grass_point(w, tol=tol)))(samples_Ws_train).all()

In [None]:
# For each training location, find the angle on the `thetas` grid that
# minimises the summed squared distance to the posterior samples there.
mcmc_barycenters = []
for i in tqdm(range(s_train.shape[0])):
    points = samples_Ws_train[:, i]
    mcmc_losses = vmap(loss, in_axes=(0, None))(thetas, points)
    mcmc_theta_argmin = thetas[np.argmin(mcmc_losses)]
    barycenter = subspace_angle_to_grass_pt(mcmc_theta_argmin)
    mcmc_barycenters.append(barycenter)
    # Optional diagnostic: plot `mcmc_losses` against `thetas` and mark the
    # minimiser (mcmc_theta_argmin, mcmc_losses.min()) in red, titled with i.

In [None]:
# Stack the per-location barycenters into one array.
mcmc_barycenters = np.array(mcmc_barycenters)

In [None]:
# Persist the barycenters. The output directory is created first so that a
# missing "results/" folder cannot make the dump fail.
os.makedirs("results", exist_ok=True)
with open("results/mcmc_barycenters_less_train_pts.pickle", 'wb') as f:
    pickle.dump(mcmc_barycenters, f)

In [None]:
# with open("results/mcmc_barycenters_less_train_pts.pickle", 'rb') as f:
#     mcmc_barycenters = pickle.load(f)

In [None]:
# Distance between each training point and the barycenter of its posterior samples.
in_sample_errors = vmap(grass_dist)(Ws_train, mcmc_barycenters)

In [None]:
# In-sample error as a function of the training location.
plt.plot(s_train,in_sample_errors)
plt.show()

In [None]:
# Spread of the posterior samples around their barycenter at each training
# location: root of the mean squared distance.
sd_s_train = []
for i in tqdm(range(s_train.shape[0])):
    fixed = mcmc_barycenters[i]
    dists = vmap(grass_dist, in_axes=(0, None))(samples_Ws_train[:, i], fixed)
    dists_Sq = dists**2
    sd_s_train.append(np.sqrt(np.mean(dists_Sq)))

In [None]:
# Stack the per-location spreads into one array.
sd_s_train = np.array(sd_s_train)

In [None]:
# Tabulate location, in-sample error and sample spread side by side.
pd_data = {'s': s_train, 'errors': in_sample_errors, 'sd': sd_s_train}
in_sample_errors_df = pd.DataFrame(data=pd_data)
in_sample_errors_df.head()

In [None]:
# Summary statistics of the error and spread columns (the location column is dropped).
in_sample_errors_df.drop(["s"],axis=1).describe()

In [None]:
# plot_grass_dists(samples_Ws_train, Ws_train, s_train)

In [None]:
# Subspace angle of every sampled point: one row per posterior draw, one
# column per training location.
samples_alphas_train = np.array(
    [list(map(subspace_angle, Ws_sample)) for Ws_sample in samples_Ws_train]
)

In [None]:
# Persist the sampled subspace angles. The output directory is created first
# so that a missing "results/" folder cannot make the dump fail.
os.makedirs("results", exist_ok=True)
with open("results/samples_alphas_train_less_train_pts.pickle", 'wb') as f:
    pickle.dump(samples_alphas_train, f)

In [None]:
# with open("results/samples_alphas_train_less_train_pts.pickle", 'rb') as f:
#     samples_alphas_train = pickle.load(f)

In [None]:
# Lower/upper percentiles of the sampled angles at each training location;
# `conf_level` is the width of the interval in percent.
percentile_levels = [2.5, 97.5]
conf_level = percentile_levels[-1] - percentile_levels[0]
percentiles = np.percentile(samples_alphas_train, np.array(percentile_levels), axis=0)
lower, upper = percentiles[0], percentiles[1]

In [None]:
# Subspace angles: full data (line), training data (green), mean of the
# posterior samples (red) and the percentile band at the training locations.
plt.plot(s_test, alphas, c='black', alpha=0.5, label='full data')
plt.scatter(s_train, alphas_train, label='training data', c='g')
plt.scatter(s_train, samples_alphas_train.mean(axis=0), label='mean samples', c='r')
plt.fill_between(s_train, lower, upper,  color='lightblue', alpha=0.75,label=f'{conf_level}% credible interval')
plt.xlabel(r"$s$")
plt.ylabel("subspace angle")
plt.legend()
plt.show()

In [None]:
def predict_tangents(
    key: chex.ArrayDevice,
    s_test: chex.ArrayDevice,
    s_train: chex.ArrayDevice,
    Vs_train: chex.ArrayDevice,
    dict_cfg,
    samples: dict,
    jitter: float = 1e-8
) -> Tuple[chex.ArrayDevice, chex.ArrayDevice]:
    """Predict tangent matrices at ``s_test`` once per posterior draw.

    For every retained draw a ``GrassGP`` is built from that draw's
    hyperparameters (or from the fixed values in the config) and
    ``GrassGP.predict_tangents`` is evaluated; the per-draw calls are vmapped.

    Args:
        key: PRNG key, split into one key per draw.
        s_test: test locations.
        s_train: training locations.
        Vs_train: training tangent matrices.
        dict_cfg: config exposing ``model.grass_config`` with the fields
            ``d_in``, ``anchor_point``, ``cov_jitter``, ``k_include_noise``,
            ``jitter``, ``Omega``, ``var``, ``length``, ``noise`` and
            ``require_noise``.
        samples: posterior samples; must contain ``'Deltas'`` and every
            hyperparameter site that the config leaves as ``None``.
        jitter: forwarded to ``GrassGP.predict_tangents``.

    Returns:
        ``(Deltas_means, Deltas_preds)``, each with the draw index as the
        leading axis.
    """
    grass_cfg = dict_cfg.model.grass_config
    d_in = grass_cfg.d_in
    U = np.array(grass_cfg.anchor_point)
    d, n = U.shape
    cov_jitter = grass_cfg.cov_jitter
    k_include_noise = grass_cfg.k_include_noise
    kern_jitter = grass_cfg.jitter

    # Count the retained draws from the samples themselves. Deriving it as
    # n_samples // n_thinning from the config ignored the number of chains, so
    # the old consistency assert failed for any multi-chain run.
    n_samples = samples['Deltas'].shape[0]

    def predict(
        key: chex.ArrayDevice,
        Omega: chex.ArrayDevice,
        var: float,
        length: float,
        noise: float,
    ) -> Tuple[chex.ArrayDevice, chex.ArrayDevice]:
        """Prediction for one posterior draw of the hyperparameters."""
        kernel_params = {'var': var, 'length': length, 'noise': noise}
        k = lambda t, s: rbf(t, s, kernel_params, jitter=kern_jitter, include_noise=k_include_noise)
        mu = lambda s: zero_mean(s, d, n)
        grass_gp = GrassGP(d_in=d_in, d_out=(d, n), mu=mu, k=k, Omega=Omega, U=U, cov_jitter=cov_jitter)
        return grass_gp.predict_tangents(key, s_test, s_train, Vs_train, jitter=jitter)

    def per_draw(cfg_value, site):
        """Sampled values when the config entry is None, else the fixed value repeated per draw."""
        if cfg_value is None:
            return samples[site]
        return cfg_value * np.ones(n_samples)

    if grass_cfg.Omega is None:
        Omegas = samples['Omega']
    else:
        Omegas = np.repeat(np.array(grass_cfg.Omega)[None,:,:], n_samples, axis=0)

    if grass_cfg.require_noise:
        noises = per_draw(grass_cfg.noise, 'kernel_noise')
    else:
        # the model uses zero noise when none is required
        noises = np.zeros(n_samples)

    # One entry per argument of `predict`, each with leading axis n_samples.
    vmap_args = (
        random.split(key, n_samples),
        Omegas,
        per_draw(grass_cfg.var, 'kernel_var'),
        per_draw(grass_cfg.length, 'kernel_length'),
        noises,
    )
    Deltas_means, Deltas_preds = vmap(predict)(*vmap_args)
    return Deltas_means, Deltas_preds

In [None]:
# Convert the hydra-zen config into the form the prediction helpers index
# with attribute access (dict_cfg.model.grass_config, dict_cfg.train).
config = to_dictconf(Config)

In [None]:
# Predict tangents on the full grid for every posterior draw and make sure
# neither the means nor the draws contain NaNs.
pred_key = random.PRNGKey(6578)
Deltas_means, Deltas_preds = predict_tangents(pred_key, s_test, s_train, log_Ws_train, config, samples)
assert not np.isnan(Deltas_means).any()
assert not np.isnan(Deltas_preds).any()

In [None]:
# Persist the predicted tangents. The output directory is created first so
# that a missing "results/" folder cannot make the dumps fail.
os.makedirs("results", exist_ok=True)
with open("results/Deltas_means_less_train_pts.pickle", 'wb') as f:
    pickle.dump(Deltas_means, f)

with open("results/Deltas_preds_less_train_pts.pickle", 'wb') as f:
    pickle.dump(Deltas_preds, f)

In [None]:
# with open("results/Deltas_means_less_train_pts.pickle", 'rb') as f:
#     Deltas_means = pickle.load(f)
    
# with open("results/Deltas_preds_less_train_pts.pickle", 'rb') as f:
#     Deltas_preds = pickle.load(f)

In [None]:
# Per-component plot of the tangent predictions: full data, training data,
# mean prediction averaged over draws, and a percentile band of the draws.
plt.rcParams["figure.figsize"] = (12,6)
percentile_levels = [2.5, 97.5]
conf_level = percentile_levels[-1] - percentile_levels[0]
for i in range(d):
    obs = log_Ws_train[:,i,0]
    means = Deltas_means[:,:,i,0]
    means_avg = np.mean(means, axis=0)
    preds = Deltas_preds[:,:,i,0]
    percentiles = np.percentile(preds, np.array(percentile_levels), axis=0)
    lower = percentiles[0,:]
    upper = percentiles[1,:]
    plt.plot(s_test, log_Ws_test[:,i,0], label='full data',c='black', alpha=0.75, linestyle='dashed')
    plt.scatter(s_train, log_Ws_train[:,i,0], label='training data', c='g')
    plt.plot(s_test, means_avg, label='averaged mean prediction', c='r', alpha=0.75)
    plt.fill_between(s_test, lower, upper, color='lightblue', alpha=0.75, label=f'{conf_level}% credible interval')
    plt.xlabel(r"$s$")
    plt.legend()
    # Mark the training locations with vertical lines slightly taller than the
    # band. An additive margin is used because scaling the limits by
    # 0.99 / 1.01 moves them the wrong way when a bound is negative (lower)
    # or when the upper bound is negative.
    band_low, band_high = lower.min(), upper.max()
    margin = 0.01 * (band_high - band_low)
    plt.vlines(s_train, band_low - margin, band_high + margin, colors='green', linestyles='dashed')
    plt.title(f"{i+1}th component of tangents")
    plt.show()

In [None]:
def predict_grass(
    key: chex.ArrayDevice,
    s_test: chex.ArrayDevice,
    s_train: chex.ArrayDevice,
    Vs_train: chex.ArrayDevice,
    dict_cfg,
    samples: dict,
    jitter: float = 1e-8,
    reortho: bool = False
) -> Tuple[chex.ArrayDevice, chex.ArrayDevice]:
    """Run `GrassGP.predict_grass` once per retained posterior sample.

    A `GrassGP` is rebuilt for every sample from that sample's `Omega` and
    kernel parameters (or from the fixed config values when they were not
    sampled) and the per-sample results are stacked with `vmap`.

    Args:
        key: PRNG key; split into one sub-key per sample.
        s_test: inputs at which to predict.
        s_train: training inputs.
        Vs_train: training values forwarded unchanged to
            `grass_gp.predict_grass` (this notebook passes `log_Ws_train`).
        dict_cfg: config exposing `model.grass_config` and `train` by attribute.
        samples: posterior samples; must hold 'Deltas' plus whichever of
            'Omega', 'kernel_var', 'kernel_length', 'kernel_noise' are not
            fixed in the config.
        jitter: forwarded to `grass_gp.predict_grass`.
        reortho: forwarded to `grass_gp.predict_grass`.

    Returns:
        `(Ws_means, Ws_preds)`: the two outputs of `grass_gp.predict_grass`,
        each with a leading axis of length `n_samples`.
    """
    grass_cfg = dict_cfg.model.grass_config
    anchor = np.array(grass_cfg.anchor_point)
    d, n = anchor.shape
    input_dim = grass_cfg.d_in
    cov_jitter = grass_cfg.cov_jitter
    with_noise = grass_cfg.k_include_noise
    kern_jitter = grass_cfg.jitter

    # draws kept after thinning; must agree with the stored samples
    n_samples = dict_cfg.train.n_samples // dict_cfg.train.n_thinning
    assert n_samples == samples['Deltas'].shape[0]

    def _predict_one(sub_key, Omega, var, length, noise):
        # build the GP for a single posterior draw
        params = {'var': var, 'length': length, 'noise': noise}

        def kernel(t, s):
            return rbf(t, s, params, jitter=kern_jitter, include_noise=with_noise)

        def mean_fn(s):
            return zero_mean(s, d, n)

        gp = GrassGP(d_in=input_dim, d_out=(d, n), mu=mean_fn, k=kernel, Omega=Omega, U=anchor, cov_jitter=cov_jitter)
        one_mean, one_pred = gp.predict_grass(sub_key, s_test, s_train, Vs_train, jitter=jitter, reortho=reortho)
        return one_mean, one_pred

    def _fixed_or_sampled(fixed, site):
        # posterior draws when the value was sampled, else the fixed value repeated
        if fixed is None:
            return samples[site]
        return fixed * np.ones(n_samples)

    sub_keys = random.split(key, n_samples)

    fixed_Omega = grass_cfg.Omega
    fixed_var = grass_cfg.var
    fixed_length = grass_cfg.length
    fixed_noise = grass_cfg.noise
    needs_noise = grass_cfg.require_noise

    if fixed_Omega is None:
        Omegas = samples['Omega']
    else:
        Omegas = np.repeat(np.array(fixed_Omega)[None, :, :], n_samples, axis=0)

    variances = _fixed_or_sampled(fixed_var, 'kernel_var')
    lengths = _fixed_or_sampled(fixed_length, 'kernel_length')

    if needs_noise:
        noises = _fixed_or_sampled(fixed_noise, 'kernel_noise')
    else:
        noises = np.zeros(n_samples)

    Ws_means, Ws_preds = vmap(_predict_one)(sub_keys, Omegas, variances, lengths, noises)
    return Ws_means, Ws_preds

In [None]:
# Predict at `s_test` with `predict_grass`, passing the training tangents
# `log_Ws_train`, then make sure neither output contains a NaN.
pred_key_grass = random.PRNGKey(7695)
Ws_means, Ws_preds = predict_grass(
    pred_key_grass,
    s_test,
    s_train,
    log_Ws_train,
    config,
    samples,
)
assert not np.isnan(Ws_means).any()
assert not np.isnan(Ws_preds).any()

In [None]:
# Persist the outputs of `predict_grass` (means first, then sampled predictions).
with open("results/Ws_means_less_train_pts.pickle", mode='wb') as f:
    pickle.dump(Ws_means, f)

with open("results/Ws_preds_less_train_pts.pickle", mode='wb') as f:
    pickle.dump(Ws_preds, f)

In [None]:
# with open("results/Ws_means_less_train_pts.pickle", 'rb') as f:
#     Ws_means = pickle.load(f)
    
# with open("results/Ws_preds_less_train_pts.pickle", 'rb') as f:
#     Ws_preds = pickle.load(f)

In [None]:
# For each of the d components (first column only) plot the averaged mean
# prediction and a percentile band of the sampled predictions against the
# full and the training data.
plt.rcParams["figure.figsize"] = (12,6)
percentile_levels = [2.5, 97.5]
conf_level = percentile_levels[-1] - percentile_levels[0]
for i in range(d):
    obs = Ws_train[:,i,0]
    # axis 0 of the predictions indexes posterior samples, axis 1 test locations
    means = Ws_means[:,:,i,0]
    means_avg = np.mean(means, axis=0)
    preds = Ws_preds[:,:,i,0]
    percentiles = np.percentile(preds, np.array(percentile_levels), axis=0)
    lower = percentiles[0,:]
    upper = percentiles[1,:]
    # Pad the training-location markers by 1% of the band's span. Scaling the
    # endpoints by 0.99 / 1.01 instead shrinks the range whenever an endpoint
    # is negative.
    vline_pad = 0.01 * (upper.max() - lower.min())
    plt.plot(s_test, Ws_test[:,i,0], label='full data',c='black', alpha=0.75, linestyle='dashed')
    plt.scatter(s_train, Ws_train[:,i,0], label='training data', c='g')
    plt.plot(s_test, means_avg, label='averaged mean prediction', c='r', alpha=0.75)
    plt.fill_between(s_test, lower, upper, color='lightblue', alpha=0.75, label=f'{conf_level}% credible interval')
    plt.xlabel(r"$s$")
    plt.legend()
    plt.vlines(s_train, lower.min() - vline_pad, upper.max() + vline_pad, colors='green', linestyles='dashed')
    plt.title(f"{i+1}th component of projections")
    plt.show()

In [None]:
# Subspace angle of every predicted point; outer axis = posterior sample,
# inner axis = test location.
alphas_means = np.array([list(map(subspace_angle, Ws)) for Ws in Ws_means])
alphas_preds = np.array([list(map(subspace_angle, Ws)) for Ws in Ws_preds])

In [None]:
# Persist the subspace angles of the mean and of the sampled predictions.
with open("results/alphas_means_less_train_pts.pickle", mode='wb') as f:
    pickle.dump(alphas_means, f)

with open("results/alphas_preds_less_train_pts.pickle", mode='wb') as f:
    pickle.dump(alphas_preds, f)

In [None]:
# with open("results/alphas_means_less_train_pts.pickle", 'rb') as f:
#     alphas_means = pickle.load(f)
    
# with open("results/alphas_preds_less_train_pts.pickle", 'rb') as f:
#     alphas_preds = pickle.load(f)

In [None]:
# Subspace-angle view of the predictions: averaged mean prediction plus a
# percentile band of the sampled predictions, with training locations marked.
plt.rcParams["figure.figsize"] = (12,6)
percentile_levels = [2.5, 97.5]
conf_level = percentile_levels[-1] - percentile_levels[0]
alphas_means_avg = alphas_means.mean(axis=0)
percentiles = np.percentile(alphas_preds, np.array(percentile_levels), axis=0)
lower, upper = percentiles[0], percentiles[1]

plt.plot(s_test, alphas, c='black', linestyle='dashed', alpha=0.75, label='full data')
plt.scatter(s_train, alphas_train, c='g', label='training data')
plt.plot(s_test, alphas_means_avg, c='r', alpha=0.75, label='averaged mean prediction')
plt.fill_between(s_test, lower, upper, label=f'{conf_level}% credible interval', color='lightblue', alpha=0.75)
plt.vlines(s_train, 0, np.pi, linestyles='dashed', colors='green')
plt.xlabel(r"$s$")
plt.ylabel("subspace angle")
plt.title("predictions for subspace angles")
plt.legend()
plt.show()

In [None]:
# Display the shape of the stacked mean predictions.
Ws_means.shape

In [None]:
# For every test location, evaluate `loss` over the candidate angles in
# `thetas` against the per-sample mean predictions and keep the point of the
# minimising angle as that location's barycenter.
test_means_mcmc_barycenters = []
for i in tqdm(range(s_test.shape[0])):
    points = Ws_means[:, i]
    test_means_mcmc_losses = vmap(loss, in_axes=(0, None))(thetas, points)
    test_means_mcmc_theta_argmin = thetas[np.argmin(test_means_mcmc_losses)]
    barycenter = subspace_angle_to_grass_pt(test_means_mcmc_theta_argmin)
    test_means_mcmc_barycenters.append(barycenter)
    # (the loss curve over `thetas` can be plotted here when debugging)

In [None]:
# Same search as for the mean predictions, but over the sampled predictions:
# keep, per test location, the point of the angle in `thetas` minimising `loss`.
test_preds_mcmc_barycenters = []
for i in tqdm(range(s_test.shape[0])):
    points = Ws_preds[:, i]
    test_preds_mcmc_losses = vmap(loss, in_axes=(0, None))(thetas, points)
    test_preds_mcmc_theta_argmin = thetas[np.argmin(test_preds_mcmc_losses)]
    barycenter = subspace_angle_to_grass_pt(test_preds_mcmc_theta_argmin)
    test_preds_mcmc_barycenters.append(barycenter)
    # (the loss curve over `thetas` can be plotted here when debugging)

In [None]:
# Stack the list of per-location barycenters into a single array.
test_means_mcmc_barycenters = np.array(test_means_mcmc_barycenters)

In [None]:
# Stack the list of per-location barycenters into a single array.
test_preds_mcmc_barycenters = np.array(test_preds_mcmc_barycenters)

In [None]:
# Persist both sets of test-location barycenters.
with open("results/test_means_mcmc_barycenters_less_train_pts.pickle", mode='wb') as f:
    pickle.dump(test_means_mcmc_barycenters, f)

with open("results/test_preds_mcmc_barycenters_less_train_pts.pickle", mode='wb') as f:
    pickle.dump(test_preds_mcmc_barycenters, f)

In [None]:
# with open("results/test_means_mcmc_barycenters_less_train_pts.pickle", 'rb') as f:
#     test_means_mcmc_barycenters = pickle.load(f)

# with open("results/test_preds_mcmc_barycenters_less_train_pts.pickle", 'rb') as f:
#     test_preds_mcmc_barycenters = pickle.load(f)

In [None]:
# `grass_dist` between each test point and the barycenter at the same index,
# once for the mean-based and once for the prediction-based barycenters.
out_sample_mean_errors = vmap(grass_dist)(Ws_test, test_means_mcmc_barycenters)
out_sample_pred_errors = vmap(grass_dist)(Ws_test, test_preds_mcmc_barycenters)

In [None]:
# Out-of-sample error curves, with the training locations marked by dashed lines.
plt.plot(s_test, out_sample_mean_errors, label='error using means')
plt.plot(s_test, out_sample_pred_errors, label='error using preds')
plt.vlines(
    s_train,
    0,
    1,
    linestyles="dashed",
    colors="green",
)
plt.legend()
plt.show()

In [None]:
# Spread of the sampled predictions around their barycenter at each test
# location: root of the mean squared `grass_dist`.
sd_s_test = []
for i in tqdm(range(s_test.shape[0])):
    fixed = test_preds_mcmc_barycenters[i]
    dists = vmap(grass_dist, in_axes=(0, None))(Ws_preds[:, i], fixed)
    dists_Sq = np.square(dists)
    sd_s_test.append(np.sqrt(np.mean(dists_Sq)))

In [None]:
# Turn the list of per-location spreads into an array.
sd_s_test = np.array(sd_s_test)

In [None]:
# Collect the out-of-sample diagnostics in one table, one row per test location.
test_pd_data = dict(
    s=s_test,
    errors_mean=out_sample_mean_errors,
    errors_pred=out_sample_pred_errors,
    sd=sd_s_test,
)
out_sample_errors_df = pd.DataFrame(test_pd_data)
out_sample_errors_df.head()

In [None]:
# Summary statistics of the error and spread columns (the input column is left out).
out_sample_errors_df.drop(columns=['s']).describe()

In [None]:
# plot_AS_dir_preds(Ws_preds, Ws_test, s_test, s_train)

In [None]:
# plot_grass_dists(Ws_preds, Ws_test, s_test)

# Increase training points - keep old anchor point

In [None]:
# Keep every `s_gap`-th point of the full data as the training set.
s_gap = 2
s_train = s_test[::s_gap].copy()
print(f"Number of training points: {s_train.shape[0]}")
Ws_train = Ws_test[::s_gap].copy()

# Show, per component, the full curve with the selected training points on top.
for i in range(d):
    plt.plot(s_test, Ws_test[:, i, 0])
    plt.scatter(s_train, Ws_train[:, i, 0], c='r')
    plt.xlabel(r'$s$')
    plt.title(f'{i+1}th component of projection')
    plt.grid()
    plt.show()

In [None]:
# Display the anchor point carried over from the earlier part of the notebook.
anchor_point

In [None]:
# Apply `grass_log` at the anchor point to every training point and to every
# point of the full data set.
log_Ws_train = vmap(grass_log, in_axes=(None, 0))(anchor_point, Ws_train)
log_Ws_test = vmap(grass_log, in_axes=(None, 0))(anchor_point, Ws_test)

In [None]:
# Subspace angle of every point in the full data and in the training subset.
alphas = np.array(list(map(subspace_angle, Ws_test)))
alphas_train = np.array(list(map(subspace_angle, Ws_train)))

In [None]:
# Settings consumed by `model`; entries left as None are sampled there.
model_config = dict(
    anchor_point=anchor_point.tolist(),
    d_in=1,
    Omega=None,
    k_include_noise=True,
    var=1.0,
    length=None,
    noise=None,
    require_noise=False,
    jitter=1e-06,
    cov_jitter=1e-4,
    L_jitter=1e-8,
    reorthonormalize=False,
    b=0.5,
    # alternative value left by the author: 0.0075
    ell=0.01,
)

def model(s, log_Ws, grass_config = model_config):
    """NumPyro model for tangent vectors at a fixed anchor point.

    Builds a `GrassGP` from the (fixed or sampled) `Omega` and kernel
    parameters, obtains `Deltas = grass_gp.tangent_model(s)` and scores
    `log_Ws` under a `MatrixNormal` centred on `Deltas`.

    Args:
        s: inputs; `s.shape[0]` is taken as the number of points `N`.
        log_Ws: observations of shape `(N, d, n)`, or `None`, in which case the
            shape check is skipped and `None` is passed as `obs` to "log_W".
        grass_config: dict of settings. Among 'Omega', 'var' and 'length', each
            entry that is `None` is sampled; 'noise' is sampled only when
            'require_noise' is truthy and 'noise' is `None`. The default is
            bound to the `model_config` that exists when this function is
            defined.
    """
    U = np.array(grass_config['anchor_point'])
    d, n = U.shape
    N = s.shape[0]
    d_n = d * n
    # N_params = N * d_n
    if log_Ws is not None:
        assert log_Ws.shape == (N, d, n)
    
    # get/sample Omega
    if grass_config['Omega'] is None:
        # Omega = outer(sigmas, sigmas) * L (elementwise), where L is the LKJ
        # draw plus `L_jitter` on the diagonal
        sigmas = numpyro.sample('sigmas', dist.LogNormal(0.0, 1.0).expand([d_n]))
        L_factor = numpyro.sample('L_factor', dist.LKJ(d_n, 1.0))
        L = numpyro.deterministic('L', L_factor + grass_config['L_jitter'] * np.eye(d_n))
        Omega = numpyro.deterministic('Omega', np.outer(sigmas, sigmas) * L)
    else:
        Omega = np.array(grass_config['Omega'])
        
    # get/sample kernel params (each sampled from LogNormal(0, b) when not fixed)
    if grass_config['var'] is None:
        # sample var
        var = numpyro.sample("kernel_var", dist.LogNormal(0.0, grass_config['b']))
    else:
        var = grass_config['var']

    if grass_config['length'] is None:
        # sample length
        length = numpyro.sample("kernel_length", dist.LogNormal(0.0, grass_config['b']))
    else:
        length = grass_config['length']

    # noise is only a free/fixed parameter when required; otherwise it is 0.0
    if grass_config['require_noise']:
        if grass_config['noise'] is None:
            # sample noise
            noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, grass_config['b']))
        else:
            noise = grass_config['noise']
    else:
        noise = 0.0
    
    kernel_params = {'var': var, 'length': length, 'noise': noise}
    # create kernel function
    k = lambda t, s: rbf(t, s, kernel_params, jitter=grass_config['jitter'], include_noise=grass_config['k_include_noise'])
    # create mean function
    mu = lambda s: zero_mean(s, d, n)

    # initialize GrassGP
    grass_gp = GrassGP(d_in=grass_config['d_in'], d_out=(d,n), mu=mu, k=k, Omega=Omega, U=U, cov_jitter=grass_config['cov_jitter'])
    
    # sample Deltas
    Deltas = grass_gp.tangent_model(s)
    
    # NOTE(review): the author flagged that the power of `ell` used in the
    # scale below should be double-checked.
    # likelihood: MatrixNormal around Deltas with row scale ell * I_d and
    # column scale I_n, one observation per point in the plate
    ell = grass_config['ell']
    with numpyro.plate("N", N):
        numpyro.sample("log_W", dist.continuous.MatrixNormal(loc=Deltas, scale_tril_row=ell * np.eye(d),scale_tril_column=np.eye(n)), obs=log_Ws)

# hydra-zen config for `model` with `grass_config=model_config` pre-filled;
# `train` instantiates it and calls the result with the data only.
TangentSpaceModelConf = builds(model, grass_config=model_config, zen_partial=True)

In [None]:
# Settings for the SVI run that provides the MAP initialisation.
# NOTE(review): this seed does not fit in 32 bits; confirm how
# `random.PRNGKey` treats it when 64-bit mode is not enabled.
SVIConfig = make_config(
    seed = 123514354575,
    maxiter = 15000,
    step_size = 0.001
)

# Settings for the MCMC run; the prediction helpers expect
# n_samples // n_thinning retained draws.
TrainConfig = make_config(
    seed = 9870687,
    n_warmup = 2000,
    n_samples = 7000,
    n_chains = 1,
    n_thinning = 2
)

# Top-level config bundling the model, SVI and MCMC settings.
Config = make_config(
    model = TangentSpaceModelConf,
    svi = SVIConfig,
    train = TrainConfig
)

In [None]:
def train(cfg):
    """Fit the model: SVI for a MAP estimate, then `run_inference` started from it.

    The training data are read from the notebook-level `s_train` and
    `log_Ws_train`. Nothing is written to disk.

    Args:
        cfg: config with `model`, `svi` (seed, maxiter, step_size) and `train`
            (seed, n_warmup, n_samples, n_chains, n_thinning) entries.

    Returns:
        dict with the posterior samples plus, for every site initialised from
        the SVI result, an extra entry named "<site>-initial_value".
    """
    model = instantiate(cfg.model)

    # --- stage 1: SVI to obtain a MAP estimate ---
    svi_args = (random.PRNGKey(cfg.svi.seed), model, cfg.svi.maxiter, cfg.svi.step_size)
    print(f"n_train = {s_train.shape[0]}")
    print("Running SVI for MAP estimate to initialise MCMC")
    svi_results = run_svi_for_map(*svi_args, s_train, log_Ws_train)

    # loss curve of the SVI run
    plt.plot(svi_results.losses)
    plt.show()

    # SVI parameter names carry a trailing '_auto_loc'; drop that many
    # characters to recover the site names used for initialisation
    suffix_len = len('_auto_loc')
    init_values = {}
    for name, value in svi_results.params.items():
        init_values[name[:-suffix_len]] = value

    # --- stage 2: sampling, started at the MAP estimate ---
    train_key = random.PRNGKey(cfg.train.seed)
    mcmc_config = {
        'num_warmup': cfg.train.n_warmup,
        'num_samples': cfg.train.n_samples,
        'num_chains': cfg.train.n_chains,
        'thinning': cfg.train.n_thinning,
        'init_strategy': init_to_value(values=init_values),
    }
    print("HMC starting.")
    mcmc = run_inference(train_key, mcmc_config, model, s_train, log_Ws_train)
    # NOTE: `mcmc.print_summary()` is not called, so no summary is logged.

    # samples plus the initial values they were started from
    inference_data = mcmc.get_samples().copy()
    inference_data.update(
        {f"{param}-initial_value": initial_val for param, initial_val in init_values.items()}
    )
    return inference_data

In [None]:
# Render the model's graph for the current training data.
numpyro.render_model(instantiate(Config.model), model_args=(s_train,log_Ws_train))

In [None]:
# Run SVI followed by sampling; the result holds samples and initial values.
inference_data = train(Config)

In [None]:
# Separate the posterior samples from the "<site>-initial_value" entries and
# check that together they account for every key.
samples = {name: value for name, value in inference_data.items() if 'initial_value' not in name}
initial_values = {name: value for name, value in inference_data.items() if 'initial_value' in name}
assert set(samples) | set(initial_values) == set(inference_data)

In [None]:
# Persist the posterior samples of this run.
with open("results/samples_more_train_pts_same_anchor.pickle", mode='wb') as f:
    pickle.dump(samples, f)

In [None]:
# with open("results/samples_more_train_pts_same_anchor.pickle", 'rb') as f:
#     samples = pickle.load(f)

In [None]:
# Flatten the samples with `flatten_samples`, ignoring no sites; the result is
# indexed by column name in the trace and ACF plots.
my_samples = flatten_samples(samples, ignore=[])

In [None]:
# Trace plots for the kernel length and every flattened column whose name
# mentions 'Omega' or 'sigmas'.
trace_plot_vars = ['kernel_length']
for key in my_samples.keys():
    for _tag in ('Omega', 'sigmas'):
        if _tag in key:
            trace_plot_vars.append(key)

my_samples[trace_plot_vars].plot(
    subplots=True,
    sharey=False,
    figsize=(10,40),
)
plt.show()

In [None]:
# Autocorrelation plot (up to lag 100) for every traced variable.
for var in trace_plot_vars:
    sm.graphics.tsa.plot_acf(my_samples[var], lags=100)
    plt.title(f"acf for {var}")
    plt.show()

In [None]:
# Convert every sampled set of tangents with `convert_to_projs` at the anchor
# point, then check each converted point with `valid_grass_point`.
tol = 1e-5

samples_Ws_train = vmap(
    lambda Deltas: convert_to_projs(Deltas, anchor_point, reorthonormalize=False)
)(samples['Deltas'])

for ws in samples_Ws_train:
    assert np.all(vmap(lambda w: valid_grass_point(w, tol=tol))(ws))

In [None]:
# For every training location, evaluate `loss` over the candidate angles in
# `thetas` against the sampled points and keep the point of the minimising
# angle as that location's barycenter.
mcmc_barycenters = []
for i in tqdm(range(s_train.shape[0])):
    points = samples_Ws_train[:, i]
    mcmc_losses = vmap(loss, in_axes=(0, None))(thetas, points)
    mcmc_theta_argmin = thetas[np.argmin(mcmc_losses)]
    barycenter = subspace_angle_to_grass_pt(mcmc_theta_argmin)
    mcmc_barycenters.append(barycenter)
    # (the loss curve over `thetas` can be plotted here when debugging)

In [None]:
# Stack the list of per-location barycenters into a single array.
mcmc_barycenters = np.array(mcmc_barycenters)

In [None]:
# Persist the training-location barycenters.
with open("results/mcmc_barycenters_more_train_pts_same_anchor.pickle", mode='wb') as f:
    pickle.dump(mcmc_barycenters, f)

In [None]:
# with open("results/mcmc_barycenters_more_train_pts_same_anchor.pickle", 'rb') as f:
#     mcmc_barycenters = pickle.load(f)

In [None]:
# `grass_dist` between each training point and the barycenter at the same index.
in_sample_errors = vmap(grass_dist)(Ws_train, mcmc_barycenters)

In [None]:
# In-sample error at every training location.
plt.plot(s_train,in_sample_errors)
plt.show()

In [None]:
# Spread of the sampled points around their barycenter at each training
# location: root of the mean squared `grass_dist`.
sd_s_train = []
for i in tqdm(range(s_train.shape[0])):
    fixed = mcmc_barycenters[i]
    dists = vmap(grass_dist, in_axes=(0, None))(samples_Ws_train[:, i], fixed)
    dists_Sq = np.square(dists)
    sd_s_train.append(np.sqrt(np.mean(dists_Sq)))

In [None]:
# Turn the list of per-location spreads into an array.
sd_s_train = np.array(sd_s_train)

In [None]:
# Collect the in-sample diagnostics in one table, one row per training location.
pd_data = dict(
    s=s_train,
    errors=in_sample_errors,
    sd=sd_s_train,
)
in_sample_errors_df = pd.DataFrame(pd_data)
in_sample_errors_df.head()

In [None]:
# Summary statistics of the error and spread columns (the input column is left out).
in_sample_errors_df.drop(columns=["s"]).describe()

In [None]:
# plot_grass_dists(samples_Ws_train, Ws_train, s_train)

In [None]:
# Subspace angle of every sampled training point; outer axis = posterior
# sample, inner axis = training location.
samples_alphas_train = np.array(
    [list(map(subspace_angle, Ws_sample)) for Ws_sample in samples_Ws_train]
)

In [None]:
# Persist the sampled subspace angles at the training locations.
with open("results/samples_alphas_train_more_train_pts_same_anchor.pickle", mode='wb') as f:
    pickle.dump(samples_alphas_train, f)

In [None]:
# with open("results/samples_alphas_train_more_train_pts_same_anchor.pickle", 'rb') as f:
#     samples_alphas_train = pickle.load(f)

In [None]:
# Percentile band of the sampled subspace angles, taken across samples.
percentile_levels = [2.5, 97.5]
conf_level = percentile_levels[-1] - percentile_levels[0]
percentiles = np.percentile(
    samples_alphas_train,
    np.array(percentile_levels),
    axis=0,
)
lower, upper = percentiles[0], percentiles[1]

In [None]:
# Subspace angles at the training locations: data, sample mean and band.
plt.plot(s_test, alphas, label='full data', c='black', alpha=0.5)
plt.scatter(s_train, alphas_train, c='g', label='training data')
plt.scatter(s_train, np.mean(samples_alphas_train, axis=0), c='r', label='mean samples')
plt.fill_between(
    s_train,
    lower,
    upper,
    label=f'{conf_level}% credible interval',
    color='lightblue',
    alpha=0.75,
)
plt.xlabel(r"$s$")
plt.ylabel("subspace angle")
plt.legend()
plt.show()

In [None]:
def predict_tangents(
    key: chex.ArrayDevice,
    s_test: chex.ArrayDevice,
    s_train: chex.ArrayDevice,
    Vs_train: chex.ArrayDevice,
    dict_cfg,
    samples: dict,
    jitter: float = 1e-8
) -> Tuple[chex.ArrayDevice, chex.ArrayDevice]:
    """Run `GrassGP.predict_tangents` once per retained posterior sample.

    A `GrassGP` is rebuilt for every sample from that sample's `Omega` and
    kernel parameters (or from the fixed config values when they were not
    sampled) and the per-sample results are stacked with `vmap`.

    Args:
        key: PRNG key; split into one sub-key per sample.
        s_test: inputs at which to predict.
        s_train: training inputs.
        Vs_train: training values forwarded unchanged to
            `grass_gp.predict_tangents` (this notebook passes `log_Ws_train`).
        dict_cfg: config exposing `model.grass_config` and `train` by attribute.
        samples: posterior samples; must hold 'Deltas' plus whichever of
            'Omega', 'kernel_var', 'kernel_length', 'kernel_noise' are not
            fixed in the config.
        jitter: forwarded to `grass_gp.predict_tangents`.

    Returns:
        `(Deltas_means, Deltas_preds)`: the two outputs of
        `grass_gp.predict_tangents`, each with a leading axis of length
        `n_samples`.
    """
    grass_cfg = dict_cfg.model.grass_config
    anchor = np.array(grass_cfg.anchor_point)
    d, n = anchor.shape
    input_dim = grass_cfg.d_in
    cov_jitter = grass_cfg.cov_jitter
    with_noise = grass_cfg.k_include_noise
    kern_jitter = grass_cfg.jitter

    # draws kept after thinning; must agree with the stored samples
    n_samples = dict_cfg.train.n_samples // dict_cfg.train.n_thinning
    assert n_samples == samples['Deltas'].shape[0]

    def _predict_one(sub_key, Omega, var, length, noise):
        # build the GP for a single posterior draw
        params = {'var': var, 'length': length, 'noise': noise}

        def kernel(t, s):
            return rbf(t, s, params, jitter=kern_jitter, include_noise=with_noise)

        def mean_fn(s):
            return zero_mean(s, d, n)

        gp = GrassGP(d_in=input_dim, d_out=(d, n), mu=mean_fn, k=kernel, Omega=Omega, U=anchor, cov_jitter=cov_jitter)
        one_mean, one_pred = gp.predict_tangents(sub_key, s_test, s_train, Vs_train, jitter=jitter)
        return one_mean, one_pred

    def _fixed_or_sampled(fixed, site):
        # posterior draws when the value was sampled, else the fixed value repeated
        if fixed is None:
            return samples[site]
        return fixed * np.ones(n_samples)

    sub_keys = random.split(key, n_samples)

    fixed_Omega = grass_cfg.Omega
    fixed_var = grass_cfg.var
    fixed_length = grass_cfg.length
    fixed_noise = grass_cfg.noise
    needs_noise = grass_cfg.require_noise

    if fixed_Omega is None:
        Omegas = samples['Omega']
    else:
        Omegas = np.repeat(np.array(fixed_Omega)[None, :, :], n_samples, axis=0)

    variances = _fixed_or_sampled(fixed_var, 'kernel_var')
    lengths = _fixed_or_sampled(fixed_length, 'kernel_length')

    if needs_noise:
        noises = _fixed_or_sampled(fixed_noise, 'kernel_noise')
    else:
        noises = np.zeros(n_samples)

    Deltas_means, Deltas_preds = vmap(_predict_one)(sub_keys, Omegas, variances, lengths, noises)
    return Deltas_means, Deltas_preds

In [None]:
# Convert the current `Config` with `to_dictconf`; the result is handed to the
# prediction helpers as `dict_cfg`.
config = to_dictconf(Config)

In [None]:
# Predict tangents at `s_test` from the training tangents and the MCMC samples,
# then make sure neither output contains a NaN.
pred_key = random.PRNGKey(6578)
Deltas_means, Deltas_preds = predict_tangents(
    pred_key,
    s_test,
    s_train,
    log_Ws_train,
    config,
    samples,
)
assert not np.isnan(Deltas_means).any()
assert not np.isnan(Deltas_preds).any()

In [None]:
# Persist the tangent predictions (means first, then sampled predictions).
with open("results/Deltas_means_more_train_pts_same_anchor.pickle", mode='wb') as f:
    pickle.dump(Deltas_means, f)

with open("results/Deltas_preds_more_train_pts_same_anchor.pickle", mode='wb') as f:
    pickle.dump(Deltas_preds, f)

In [None]:
# with open("results/Deltas_means_more_train_pts_same_anchor.pickle", 'rb') as f:
#     Deltas_means = pickle.load(f)
    
# with open("results/Deltas_preds_more_train_pts_same_anchor.pickle", 'rb') as f:
#     Deltas_preds = pickle.load(f)

In [None]:
# For each of the d tangent components (first column only) plot the averaged
# mean prediction and a percentile band of the sampled predictions against the
# full and the training tangents.
plt.rcParams["figure.figsize"] = (12,6)
percentile_levels = [2.5, 97.5]
conf_level = percentile_levels[-1] - percentile_levels[0]
for i in range(d):
    obs = log_Ws_train[:,i,0]
    # axis 0 of the predictions indexes posterior samples, axis 1 test locations
    means = Deltas_means[:,:,i,0]
    means_avg = np.mean(means, axis=0)
    preds = Deltas_preds[:,:,i,0]
    percentiles = np.percentile(preds, np.array(percentile_levels), axis=0)
    lower = percentiles[0,:]
    upper = percentiles[1,:]
    # Pad the training-location markers by 1% of the band's span. Scaling the
    # endpoints by 0.99 / 1.01 instead shrinks the range whenever an endpoint
    # is negative, which is common for tangent components.
    vline_pad = 0.01 * (upper.max() - lower.min())
    plt.plot(s_test, log_Ws_test[:,i,0], label='full data',c='black', alpha=0.75, linestyle='dashed')
    plt.scatter(s_train, log_Ws_train[:,i,0], label='training data', c='g')
    plt.plot(s_test, means_avg, label='averaged mean prediction', c='r', alpha=0.75)
    plt.fill_between(s_test, lower, upper, color='lightblue', alpha=0.75, label=f'{conf_level}% credible interval')
    plt.xlabel(r"$s$")
    plt.legend()
    plt.vlines(s_train, lower.min() - vline_pad, upper.max() + vline_pad, colors='green', linestyles='dashed')
    plt.title(f"{i+1}th component of tangents")
    plt.show()

In [None]:
def predict_grass(
    key: chex.ArrayDevice,
    s_test: chex.ArrayDevice,
    s_train: chex.ArrayDevice,
    Vs_train: chex.ArrayDevice,
    dict_cfg,
    samples: dict,
    jitter: float = 1e-8,
    reortho: bool = False
) -> Tuple[chex.ArrayDevice, chex.ArrayDevice]:
    """Run `GrassGP.predict_grass` once per retained posterior sample.

    A `GrassGP` is rebuilt for every sample from that sample's `Omega` and
    kernel parameters (or from the fixed config values when they were not
    sampled) and the per-sample results are stacked with `vmap`.

    Args:
        key: PRNG key; split into one sub-key per sample.
        s_test: inputs at which to predict.
        s_train: training inputs.
        Vs_train: training values forwarded unchanged to
            `grass_gp.predict_grass` (this notebook passes `log_Ws_train`).
        dict_cfg: config exposing `model.grass_config` and `train` by attribute.
        samples: posterior samples; must hold 'Deltas' plus whichever of
            'Omega', 'kernel_var', 'kernel_length', 'kernel_noise' are not
            fixed in the config.
        jitter: forwarded to `grass_gp.predict_grass`.
        reortho: forwarded to `grass_gp.predict_grass`.

    Returns:
        `(Ws_means, Ws_preds)`: the two outputs of `grass_gp.predict_grass`,
        each with a leading axis of length `n_samples`.
    """
    grass_cfg = dict_cfg.model.grass_config
    anchor = np.array(grass_cfg.anchor_point)
    d, n = anchor.shape
    input_dim = grass_cfg.d_in
    cov_jitter = grass_cfg.cov_jitter
    with_noise = grass_cfg.k_include_noise
    kern_jitter = grass_cfg.jitter

    # draws kept after thinning; must agree with the stored samples
    n_samples = dict_cfg.train.n_samples // dict_cfg.train.n_thinning
    assert n_samples == samples['Deltas'].shape[0]

    def _predict_one(sub_key, Omega, var, length, noise):
        # build the GP for a single posterior draw
        params = {'var': var, 'length': length, 'noise': noise}

        def kernel(t, s):
            return rbf(t, s, params, jitter=kern_jitter, include_noise=with_noise)

        def mean_fn(s):
            return zero_mean(s, d, n)

        gp = GrassGP(d_in=input_dim, d_out=(d, n), mu=mean_fn, k=kernel, Omega=Omega, U=anchor, cov_jitter=cov_jitter)
        one_mean, one_pred = gp.predict_grass(sub_key, s_test, s_train, Vs_train, jitter=jitter, reortho=reortho)
        return one_mean, one_pred

    def _fixed_or_sampled(fixed, site):
        # posterior draws when the value was sampled, else the fixed value repeated
        if fixed is None:
            return samples[site]
        return fixed * np.ones(n_samples)

    sub_keys = random.split(key, n_samples)

    fixed_Omega = grass_cfg.Omega
    fixed_var = grass_cfg.var
    fixed_length = grass_cfg.length
    fixed_noise = grass_cfg.noise
    needs_noise = grass_cfg.require_noise

    if fixed_Omega is None:
        Omegas = samples['Omega']
    else:
        Omegas = np.repeat(np.array(fixed_Omega)[None, :, :], n_samples, axis=0)

    variances = _fixed_or_sampled(fixed_var, 'kernel_var')
    lengths = _fixed_or_sampled(fixed_length, 'kernel_length')

    if needs_noise:
        noises = _fixed_or_sampled(fixed_noise, 'kernel_noise')
    else:
        noises = np.zeros(n_samples)

    Ws_means, Ws_preds = vmap(_predict_one)(sub_keys, Omegas, variances, lengths, noises)
    return Ws_means, Ws_preds

In [None]:
# Predict at `s_test` with `predict_grass`, passing the training tangents
# `log_Ws_train`, then make sure neither output contains a NaN.
pred_key_grass = random.PRNGKey(7695)
Ws_means, Ws_preds = predict_grass(
    pred_key_grass,
    s_test,
    s_train,
    log_Ws_train,
    config,
    samples,
)
assert not np.isnan(Ws_means).any()
assert not np.isnan(Ws_preds).any()

In [None]:
# Persist the outputs of `predict_grass` (means first, then sampled predictions).
with open("results/Ws_means_more_train_pts_same_anchor.pickle", mode='wb') as f:
    pickle.dump(Ws_means, f)

with open("results/Ws_preds_more_train_pts_same_anchor.pickle", mode='wb') as f:
    pickle.dump(Ws_preds, f)

In [None]:
# with open("results/Ws_means_more_train_pts_same_anchor.pickle", 'rb') as f:
#     Ws_means = pickle.load(f)
    
# with open("results/Ws_preds_more_train_pts_same_anchor.pickle", 'rb') as f:
#     Ws_preds = pickle.load(f)

In [None]:
# For each of the d components (first column only) plot the averaged mean
# prediction and a percentile band of the sampled predictions against the
# full and the training data.
plt.rcParams["figure.figsize"] = (12,6)
percentile_levels = [2.5, 97.5]
conf_level = percentile_levels[-1] - percentile_levels[0]
for i in range(d):
    obs = Ws_train[:,i,0]
    # axis 0 of the predictions indexes posterior samples, axis 1 test locations
    means = Ws_means[:,:,i,0]
    means_avg = np.mean(means, axis=0)
    preds = Ws_preds[:,:,i,0]
    percentiles = np.percentile(preds, np.array(percentile_levels), axis=0)
    lower = percentiles[0,:]
    upper = percentiles[1,:]
    # Pad the training-location markers by 1% of the band's span. Scaling the
    # endpoints by 0.99 / 1.01 instead shrinks the range whenever an endpoint
    # is negative.
    vline_pad = 0.01 * (upper.max() - lower.min())
    plt.plot(s_test, Ws_test[:,i,0], label='full data',c='black', alpha=0.75, linestyle='dashed')
    plt.scatter(s_train, Ws_train[:,i,0], label='training data', c='g')
    plt.plot(s_test, means_avg, label='averaged mean prediction', c='r', alpha=0.75)
    plt.fill_between(s_test, lower, upper, color='lightblue', alpha=0.75, label=f'{conf_level}% credible interval')
    plt.xlabel(r"$s$")
    plt.legend()
    plt.vlines(s_train, lower.min() - vline_pad, upper.max() + vline_pad, colors='green', linestyles='dashed')
    plt.title(f"{i+1}th component of projections")
    plt.show()

In [None]:
# Subspace angle of every predicted point; outer axis = posterior sample,
# inner axis = test location.
alphas_means = np.array([list(map(subspace_angle, Ws)) for Ws in Ws_means])
alphas_preds = np.array([list(map(subspace_angle, Ws)) for Ws in Ws_preds])

In [None]:
with open("results/alphas_means_more_train_pts.pickle_same_anchor", 'wb') as f:
    pickle.dump(alphas_means, f)
    
with open("results/alphas_preds_more_train_pts.pickle_same_anchor", 'wb') as f:
    pickle.dump(alphas_preds, f)

In [None]:
# with open("results/alphas_means_more_train_pts.pickle_same_anchor", 'rb') as f:
#     alphas_means = pickle.load(f)
    
# with open("results/alphas_preds_more_train_pts.pickle_same_anchor", 'rb') as f:
#     alphas_preds = pickle.load(f)

In [None]:
# Subspace-angle view of the predictions: averaged mean prediction plus a
# percentile band of the sampled predictions, with training locations marked.
plt.rcParams["figure.figsize"] = (12,6)
percentile_levels = [2.5, 97.5]
conf_level = percentile_levels[-1] - percentile_levels[0]
alphas_means_avg = alphas_means.mean(axis=0)
percentiles = np.percentile(alphas_preds, np.array(percentile_levels), axis=0)
lower, upper = percentiles[0], percentiles[1]

plt.plot(s_test, alphas, c='black', linestyle='dashed', alpha=0.75, label='full data')
plt.scatter(s_train, alphas_train, c='g', label='training data')
plt.plot(s_test, alphas_means_avg, c='r', alpha=0.75, label='averaged mean prediction')
plt.fill_between(s_test, lower, upper, label=f'{conf_level}% credible interval', color='lightblue', alpha=0.75)
plt.vlines(s_train, 0, np.pi, linestyles='dashed', colors='green')
plt.xlabel(r"$s$")
plt.ylabel("subspace angle")
plt.title("predictions for subspace angles")
plt.legend()
plt.show()

In [None]:
# Display the shape of the stacked mean predictions.
Ws_means.shape

In [None]:
# For every test location, evaluate `loss` over the candidate angles in
# `thetas` against the per-sample mean predictions and keep the point of the
# minimising angle as that location's barycenter.
test_means_mcmc_barycenters = []
for i in tqdm(range(s_test.shape[0])):
    points = Ws_means[:, i]
    test_means_mcmc_losses = vmap(loss, in_axes=(0, None))(thetas, points)
    test_means_mcmc_theta_argmin = thetas[np.argmin(test_means_mcmc_losses)]
    barycenter = subspace_angle_to_grass_pt(test_means_mcmc_theta_argmin)
    test_means_mcmc_barycenters.append(barycenter)
    # (the loss curve over `thetas` can be plotted here when debugging)

In [None]:
# Same search as for the mean predictions, but over the sampled predictions:
# keep, per test location, the point of the angle in `thetas` minimising `loss`.
test_preds_mcmc_barycenters = []
for i in tqdm(range(s_test.shape[0])):
    points = Ws_preds[:, i]
    test_preds_mcmc_losses = vmap(loss, in_axes=(0, None))(thetas, points)
    test_preds_mcmc_theta_argmin = thetas[np.argmin(test_preds_mcmc_losses)]
    barycenter = subspace_angle_to_grass_pt(test_preds_mcmc_theta_argmin)
    test_preds_mcmc_barycenters.append(barycenter)
    # (the loss curve over `thetas` can be plotted here when debugging)

In [None]:
# Stack the list of per-location barycenters into a single array.
test_means_mcmc_barycenters = np.array(test_means_mcmc_barycenters)

In [None]:
# Stack the list of per-location barycenters into a single array.
test_preds_mcmc_barycenters = np.array(test_preds_mcmc_barycenters)

In [None]:
# Persist both sets of test-location barycenters.
with open("results/test_means_mcmc_barycenters_more_train_pts_same_anchor.pickle", mode='wb') as f:
    pickle.dump(test_means_mcmc_barycenters, f)

with open("results/test_preds_mcmc_barycenters_more_train_pts_same_anchor.pickle", mode='wb') as f:
    pickle.dump(test_preds_mcmc_barycenters, f)

In [None]:
# with open("results/test_means_mcmc_barycenters_more_train_pts_same_anchor.pickle", 'rb') as f:
#     test_means_mcmc_barycenters = pickle.load(f)

# with open("results/test_preds_mcmc_barycenters_more_train_pts_same_anchor.pickle", 'rb') as f:
#     test_preds_mcmc_barycenters = pickle.load(f)

In [None]:
# `grass_dist` between each test point and the barycenter at the same index,
# once for the mean-based and once for the prediction-based barycenters.
out_sample_mean_errors = vmap(grass_dist)(Ws_test, test_means_mcmc_barycenters)
out_sample_pred_errors = vmap(grass_dist)(Ws_test, test_preds_mcmc_barycenters)

In [None]:
# Out-of-sample error curves, with the training locations marked by dashed lines.
plt.plot(s_test, out_sample_mean_errors, label='error using means')
plt.plot(s_test, out_sample_pred_errors, label='error using preds')
plt.vlines(
    s_train,
    0,
    1,
    linestyles="dashed",
    colors="green",
)
plt.legend()
plt.show()

In [None]:
# Spread of the sampled predictions around their barycenter at each test
# location: root of the mean squared `grass_dist`.
sd_s_test = []
for i in tqdm(range(s_test.shape[0])):
    fixed = test_preds_mcmc_barycenters[i]
    dists = vmap(grass_dist, in_axes=(0, None))(Ws_preds[:, i], fixed)
    dists_Sq = np.square(dists)
    sd_s_test.append(np.sqrt(np.mean(dists_Sq)))

In [None]:
# Turn the list of per-location spreads into an array.
sd_s_test = np.array(sd_s_test)

In [None]:
# Collect the out-of-sample diagnostics in one table, one row per test location.
test_pd_data = dict(
    s=s_test,
    errors_mean=out_sample_mean_errors,
    errors_pred=out_sample_pred_errors,
    sd=sd_s_test,
)
out_sample_errors_df = pd.DataFrame(test_pd_data)
out_sample_errors_df.head()

In [None]:
# Summary statistics of the error and spread columns (the input column is left out).
out_sample_errors_df.drop(columns=['s']).describe()

# Increase training points - new anchor point

In [None]:
# Keep every `s_gap`-th point of the full data as the training set.
s_gap = 2
s_train = s_test[::s_gap].copy()
print(f"Number of training points: {s_train.shape[0]}")
Ws_train = Ws_test[::s_gap].copy()

# Show, per component, the full curve with the selected training points on top.
for i in range(d):
    plt.plot(s_test, Ws_test[:, i, 0])
    plt.scatter(s_train, Ws_train[:, i, 0], c='r')
    plt.xlabel(r'$s$')
    plt.title(f'{i+1}th component of projection')
    plt.grid()
    plt.show()

In [None]:
# Choose the anchor point by evaluating `loss` against the training points on
# 1000 evenly spaced subspace angles in [0, pi] and taking the minimiser.
thetas = np.linspace(0, np.pi, 1000)
losses = vmap(loss, in_axes=(0, None))(thetas, Ws_train)

theta_argmin = thetas[np.argmin(losses)]
anchor_point = subspace_angle_to_grass_pt(theta_argmin)
assert valid_grass_point(anchor_point)

# loss curve with the selected angle highlighted
plt.plot(thetas, losses)
plt.scatter(theta_argmin, np.min(losses), label='loss of anchor point', color="red")
plt.xlabel("subspace angle")
plt.title("Plot of loss vs subspace angle to determine Karcher mean")
plt.grid()
plt.legend()
plt.show()

# first-column coordinates of the training points and of the anchor point
fig, ax = plt.subplots(figsize=(7,7))
ax.set_xlim((-1.25,1.25))
ax.set_ylim((-1.25,1.25))
ax.scatter(Ws_train[:,0,0], Ws_train[:,1,0], label='training points', color="blue", alpha=0.25)
ax.scatter(anchor_point[0,0], anchor_point[1,0], label="anchor point", color="red", marker="*")
ax.set_title("Plot of training points with anchor point on circle rep of Gr(2,1)")
ax.grid()
ax.legend()

plt.show()

In [None]:
# Apply `grass_log` at the new anchor point to every training point and to
# every point of the full data set.
log_Ws_train = vmap(grass_log, in_axes=(None, 0))(anchor_point, Ws_train)
log_Ws_test = vmap(grass_log, in_axes=(None, 0))(anchor_point, Ws_test)

In [None]:
# Subspace angle of every point in the full data and in the training subset.
alphas = np.array(list(map(subspace_angle, Ws_test)))
alphas_train = np.array(list(map(subspace_angle, Ws_train)))

In [None]:
# Settings consumed by `model`; entries left as None are sampled there.
model_config = dict(
    anchor_point=anchor_point.tolist(),
    d_in=1,
    Omega=None,
    k_include_noise=True,
    var=1.0,
    length=None,
    noise=None,
    require_noise=False,
    jitter=1e-06,
    cov_jitter=1e-4,
    L_jitter=1e-8,
    reorthonormalize=False,
    b=0.5,
    # alternative value left by the author: 0.0075
    ell=0.01,
)

def model(s, log_Ws, grass_config = model_config):
    """NumPyro model for tangent vectors at a fixed anchor point.

    A ``GrassGP`` built from the anchor point, a fixed or sampled ``Omega``
    and an RBF kernel produces tangents ``Deltas`` at the inputs ``s``;
    ``log_Ws`` is observed through a matrix-normal likelihood centred on them.

    Args:
        s: inputs; ``s.shape[0]`` gives the number of points ``N`` and ``s``
            is then handed to ``GrassGP.tangent_model``.
        log_Ws: observations of shape ``(N, d, n)``, or ``None`` to leave the
            ``"log_W"`` site unobserved.
        grass_config: dict of settings. Entries ``'Omega'``, ``'var'``,
            ``'length'`` and ``'noise'`` are sampled when ``None`` and used
            as given otherwise.
    """
    cfg = grass_config
    U = np.array(cfg['anchor_point'])
    d, n = U.shape
    N = s.shape[0]
    d_n = d * n
    if log_Ws is not None:
        assert log_Ws.shape == (N, d, n)

    # Omega: taken from the config when provided, otherwise assembled from
    # sampled scales and an LKJ draw (with a small multiple of the identity).
    if cfg['Omega'] is not None:
        Omega = np.array(cfg['Omega'])
    else:
        sigmas = numpyro.sample('sigmas', dist.LogNormal(0.0, 1.0).expand([d_n]))
        L_factor = numpyro.sample('L_factor', dist.LKJ(d_n, 1.0))
        L = numpyro.deterministic('L', L_factor + cfg['L_jitter'] * np.eye(d_n))
        Omega = numpyro.deterministic('Omega', np.outer(sigmas, sigmas) * L)

    def fixed_or_sampled(cfg_key, site_name):
        # Use the configured value, or draw from LogNormal(0, b) when it is None.
        value = cfg[cfg_key]
        if value is None:
            value = numpyro.sample(site_name, dist.LogNormal(0.0, cfg['b']))
        return value

    # Kernel hyperparameters, in the fixed order var -> length -> noise.
    var = fixed_or_sampled('var', "kernel_var")
    length = fixed_or_sampled('length', "kernel_length")
    noise = fixed_or_sampled('noise', "kernel_noise") if cfg['require_noise'] else 0.0
    kernel_params = {'var': var, 'length': length, 'noise': noise}

    def k(t, s):
        # kernel function
        return rbf(t, s, kernel_params, jitter=cfg['jitter'], include_noise=cfg['k_include_noise'])

    def mu(s):
        # mean function
        return zero_mean(s, d, n)

    gp = GrassGP(d_in=cfg['d_in'], d_out=(d,n), mu=mu, k=k, Omega=Omega, U=U, cov_jitter=cfg['cov_jitter'])

    # tangents at the inputs
    Deltas = gp.tangent_model(s)

    # TODO (carried over from the original): check what power `ell` should
    # enter the likelihood with.
    ell = cfg['ell']
    row_scale = ell * np.eye(d)
    col_scale = np.eye(n)
    with numpyro.plate("N", N):
        numpyro.sample("log_W", dist.continuous.MatrixNormal(loc=Deltas, scale_tril_row=row_scale, scale_tril_column=col_scale), obs=log_Ws)

# hydra-zen config for `model` with `grass_config` pre-bound; because of
# `zen_partial=True`, instantiating it yields a callable that still takes
# `s` and `log_Ws`.
TangentSpaceModelConf = builds(model, grass_config=model_config, zen_partial=True)

In [None]:
# Settings for the SVI run that produces the MAP initialisation (read in `train`).
SVIConfig = make_config(
    seed = 123514354575,
    maxiter = 15000,
    step_size = 0.001
)

# Settings for the MCMC run (read in `train`). The prediction helpers expect
# n_samples // n_thinning retained draws.
TrainConfig = make_config(
    seed = 9870687,
    n_warmup = 2000,
    n_samples = 7000,
    n_chains = 1,
    n_thinning = 2
)

# Top-level config bundling the model, SVI and MCMC settings.
Config = make_config(
    model = TangentSpaceModelConf,
    svi = SVIConfig,
    train = TrainConfig
)

In [None]:
def train(cfg):
    """Fit the model: SVI for a MAP estimate, then MCMC started from it.

    The notebook-level ``s_train`` and ``log_Ws_train`` are used as data.

    Args:
        cfg: config with ``model``, ``svi`` (``seed``, ``maxiter``,
            ``step_size``) and ``train`` (``seed``, ``n_warmup``,
            ``n_samples``, ``n_chains``, ``n_thinning``) entries.

    Returns:
        A dict holding the MCMC samples plus one ``"<param>-initial_value"``
        entry for every value used to initialise the sampler.
    """
    # build the model callable from its config
    tangent_model = instantiate(cfg.model)
    train_data = (s_train, log_Ws_train)

    # Stage 1: SVI gives a MAP estimate used to initialise MCMC
    svi_key = random.PRNGKey(cfg.svi.seed)
    svi_steps, svi_step_size = cfg.svi.maxiter, cfg.svi.step_size
    print(f"n_train = {s_train.shape[0]}")
    print("Running SVI for MAP estimate to initialise MCMC")
    svi_results = run_svi_for_map(svi_key, tangent_model, svi_steps, svi_step_size, *train_data)

    # SVI loss curve
    plt.plot(svi_results.losses)
    plt.show()

    # The fitted parameter names carry a trailing '_auto_loc'; drop it so the
    # keys can be used as initial values for the sampler.
    suffix_len = len('_auto_loc')
    init_values = {
        name[:-suffix_len]: value
        for name, value in svi_results.params.items()
    }

    # Stage 2: HMC
    train_key = random.PRNGKey(cfg.train.seed)
    mcmc_config = {
        'num_warmup' : cfg.train.n_warmup,
        'num_samples' : cfg.train.n_samples,
        'num_chains' : cfg.train.n_chains,
        'thinning' : cfg.train.n_thinning,
        'init_strategy' : init_to_value(values=init_values),
    }
    print("HMC starting.")
    mcmc = run_inference(train_key, mcmc_config, tangent_model, *train_data)
    # original_stdout = sys.stdout
    # with open('hmc_log.txt', 'w') as f:
    #     sys.stdout = f
    #     mcmc.print_summary()
    #     sys.stdout = original_stdout

    # samples plus the values the sampler was started from
    inference_data = mcmc.get_samples().copy()
    inference_data.update(
        {f"{param}-initial_value": initial_val for param, initial_val in init_values.items()}
    )

    # head = os.getcwd()
    # main_name = "inference_data"
    # path = get_save_path(head, main_name)
    # try:
    #     safe_save(path, inference_data)
    # except FileExistsError:
    #     print("File exists so not saving.")
    return inference_data

In [None]:
# Render the structure of the instantiated model, evaluated on the training data.
numpyro.render_model(instantiate(Config.model), model_args=(s_train,log_Ws_train))

In [None]:
# Fit the model (SVI for the MAP initialisation, then HMC).
inference_data = train(Config)

In [None]:
# Separate the posterior samples from the "<param>-initial_value" entries.
samples = {
    name: value
    for name, value in inference_data.items()
    if 'initial_value' not in name
}
initial_values = {
    name: value
    for name, value in inference_data.items()
    if 'initial_value' in name
}
# the two parts together must account for every key
assert samples.keys() | initial_values.keys() == inference_data.keys()

In [None]:
# Persist the posterior samples; the `results/` directory must already exist.
with open("results/samples_more_train_pts_new_anchor.pickle", 'wb') as f:
    pickle.dump(samples, f)

In [None]:
# with open("results/samples_more_train_pts_new_anchor.pickle", 'rb') as f:
#     samples = pickle.load(f)

In [None]:
# Flatten the sample dict into a table that can be indexed by a list of column
# names (presumably a pandas DataFrame; confirm in grassgp.plot_utils).
my_samples = flatten_samples(samples, ignore=[])

In [None]:
# Columns to trace: the kernel length-scale plus every column whose name
# mentions 'Omega' or 'sigmas'.
trace_plot_vars = ['kernel_length']
for key in my_samples.keys():
    trace_plot_vars.extend(key for tag in ('Omega', 'sigmas') if tag in key)

my_samples[trace_plot_vars].plot(subplots=True, figsize=(10,40), sharey=False)
plt.show()

In [None]:
# Autocorrelation plot (up to lag 100) for each traced column.
for var in trace_plot_vars:
    sm.graphics.tsa.plot_acf(my_samples[var], lags=100)
    # overrides the title on the figure that plot_acf just created
    plt.title(f"acf for {var}")
    plt.show()

In [None]:
tol=1e-5


def _to_projs(Deltas):
    # map one posterior draw of tangents through `convert_to_projs` at the anchor point
    return convert_to_projs(Deltas, anchor_point, reorthonormalize=False)


samples_Ws_train = vmap(_to_projs)(samples['Deltas'])

# every converted point of every draw must pass `valid_grass_point` at tolerance `tol`
_all_valid = vmap(lambda w: valid_grass_point(w, tol=tol))
for ws in samples_Ws_train:
    assert _all_valid(ws).all()

In [None]:
# For each training input, grid-search the angle minimising `loss` over the
# posterior samples at that input and record the corresponding point.
mcmc_barycenters = []
for i in tqdm(range(s_train.shape[0])):
    points = samples_Ws_train[:, i]
    mcmc_losses = vmap(loss, in_axes=(0, None))(thetas, points)
    mcmc_theta_argmin = thetas[np.argmin(mcmc_losses)]
    barycenter = subspace_angle_to_grass_pt(mcmc_theta_argmin)
    mcmc_barycenters.append(barycenter)
    # plt.plot(thetas, mcmc_losses)
    # plt.scatter(mcmc_theta_argmin, mcmc_losses.min(),color="red")
    # plt.grid()
    # plt.title(f"{i}")
    # plt.show()

In [None]:
# Stack the per-input barycenters into a single array.
mcmc_barycenters = np.array(mcmc_barycenters)

In [None]:
# Persist the in-sample barycenters; the `results/` directory must already exist.
with open("results/mcmc_barycenters_more_train_pts_new_anchor.pickle", 'wb') as f:
    pickle.dump(mcmc_barycenters, f)

In [None]:
# with open("results/mcmc_barycenters_more_train_pts_new_anchor.pickle", 'rb') as f:
#     mcmc_barycenters = pickle.load(f)

In [None]:
# `grass_dist` between each true training point and the barycenter of its
# posterior samples, paired index-by-index.
in_sample_errors = vmap(grass_dist)(Ws_train, mcmc_barycenters)

In [None]:
# In-sample error as a function of the training input.
plt.plot(s_train,in_sample_errors)
plt.show()

In [None]:
# Spread of the posterior samples at each training input: root-mean-square
# `grass_dist` between the samples and their barycenter.
sd_s_train = []
for i in tqdm(range(s_train.shape[0])):
    fixed = mcmc_barycenters[i]
    dists = vmap(grass_dist, in_axes=(0, None))(samples_Ws_train[:, i], fixed)
    sd_s_train.append(np.sqrt(np.mean(dists ** 2)))

In [None]:
# Collect the per-input spreads into one array.
sd_s_train = np.array(sd_s_train)

In [None]:
# One row per training input: input value, in-sample error and sample spread.
pd_data = {'s': s_train, 'errors': in_sample_errors, 'sd': sd_s_train}
in_sample_errors_df = pd.DataFrame(data=pd_data)
in_sample_errors_df.head()

In [None]:
# Summary statistics of the error and spread columns (input column excluded).
in_sample_errors_df.drop(["s"],axis=1).describe()

In [None]:
# plot_grass_dists(samples_Ws_train, Ws_train, s_train)

In [None]:
# Subspace angle of every posterior sample at every training input
# (outer axis: posterior draw, inner axis: training input).
samples_alphas_train = np.array([
    list(map(subspace_angle, Ws_sample))
    for Ws_sample in samples_Ws_train
])

In [None]:
# Persist the sampled subspace angles; the `results/` directory must already exist.
with open("results/samples_alphas_train_more_train_pts_new_anchor.pickle", 'wb') as f:
    pickle.dump(samples_alphas_train, f)

In [None]:
# with open("results/samples_alphas_train_more_train_pts_new_anchor.pickle", 'rb') as f:
#     samples_alphas_train = pickle.load(f)

In [None]:
# Pointwise 2.5% / 97.5% percentiles of the sampled angles across draws.
percentile_levels = [2.5, 97.5]
low_level, high_level = percentile_levels[0], percentile_levels[-1]
conf_level = high_level - low_level
percentiles = np.percentile(samples_alphas_train, np.array(percentile_levels), axis=0)
lower, upper = percentiles[0], percentiles[1]

In [None]:
# Subspace angles at the training inputs: true curve, training data, mean of
# the posterior samples, and the percentile band computed above.
plt.plot(s_test, alphas, c='black', alpha=0.5, label='full data')
plt.scatter(s_train, alphas_train, label='training data', c='g')
plt.scatter(s_train, samples_alphas_train.mean(axis=0), label='mean samples', c='r')
plt.fill_between(s_train, lower, upper,  color='lightblue', alpha=0.75,label=f'{conf_level}% credible interval')
plt.xlabel(r"$s$")
plt.ylabel("subspace angle")
plt.legend()
plt.show()

In [None]:
def predict_tangents(
    key: chex.ArrayDevice,
    s_test: chex.ArrayDevice,
    s_train: chex.ArrayDevice,
    Vs_train: chex.ArrayDevice,
    dict_cfg,
    samples: dict,
    jitter: float = 1e-8
) -> Tuple[chex.ArrayDevice, chex.ArrayDevice]:
    """Predict tangents at ``s_test`` once per posterior draw.

    For every retained posterior draw a ``GrassGP`` is rebuilt from that
    draw's hyperparameters (or from the fixed values in the config) and its
    ``predict_tangents`` method is called; the calls are batched with ``vmap``.

    Args:
        key: PRNG key, split into one sub-key per posterior draw.
        s_test: inputs at which to predict.
        s_train: training inputs.
        Vs_train: training tangents, passed through to
            ``GrassGP.predict_tangents``.
        dict_cfg: config with attribute access to ``model.grass_config`` and
            ``train``.
        samples: posterior samples; must contain ``'Deltas'`` and every
            hyperparameter that is ``None`` in the config.
        jitter: forwarded to ``GrassGP.predict_tangents``.

    Returns:
        ``(Deltas_means, Deltas_preds)`` as produced by
        ``GrassGP.predict_tangents``, with a leading axis over posterior draws.

    Raises:
        ValueError: if the number of draws in ``samples['Deltas']`` differs
            from the number implied by the training config.
    """
    grass_cfg = dict_cfg.model.grass_config
    train_cfg = dict_cfg.train

    d_in = grass_cfg.d_in
    U = np.array(grass_cfg.anchor_point)
    d, n = U.shape
    cov_jitter = grass_cfg.cov_jitter
    k_include_noise = grass_cfg.k_include_noise
    kern_jitter = grass_cfg.jitter

    # Each chain retains n_samples // n_thinning draws and the chains are
    # concatenated in the sample dict, so account for `n_chains` as well
    # (treated as 1 when the config does not define it).
    n_chains = getattr(train_cfg, 'n_chains', None) or 1
    n_samples = n_chains * (train_cfg.n_samples // train_cfg.n_thinning)
    n_available = samples['Deltas'].shape[0]
    if n_samples != n_available:
        raise ValueError(
            f"config implies {n_samples} posterior draws but samples['Deltas'] holds {n_available}"
        )

    def predict(
        key: chex.ArrayDevice,
        Omega: chex.ArrayDevice,
        var: float,
        length: float,
        noise: float,
    ) -> Tuple[chex.ArrayDevice, chex.ArrayDevice]:
        # initialize GrassGP for this draw
        kernel_params = {'var': var, 'length': length, 'noise': noise}
        k = lambda t, s: rbf(t, s, kernel_params, jitter=kern_jitter, include_noise=k_include_noise)
        mu = lambda s: zero_mean(s, d, n)
        grass_gp = GrassGP(d_in=d_in, d_out=(d, n), mu=mu, k=k, Omega=Omega, U=U, cov_jitter=cov_jitter)

        # predict
        Deltas_mean, Deltas_pred = grass_gp.predict_tangents(key, s_test, s_train, Vs_train, jitter=jitter)
        return Deltas_mean, Deltas_pred

    def per_draw(cfg_value, sample_name):
        # Sampled hyperparameter -> its draws; fixed one -> the constant repeated per draw.
        if cfg_value is None:
            return samples[sample_name]
        return cfg_value * np.ones(n_samples)

    cfg_Omega = grass_cfg.Omega
    if cfg_Omega is None:
        Omegas = samples['Omega']
    else:
        Omegas = np.repeat(np.array(cfg_Omega)[None,:,:], n_samples, axis=0)

    if grass_cfg.require_noise:
        noises = per_draw(grass_cfg.noise, 'kernel_noise')
    else:
        # the model uses zero kernel noise when it is not required
        noises = np.zeros(n_samples)

    vmap_args = (
        random.split(key, n_samples),
        Omegas,
        per_draw(grass_cfg.var, 'kernel_var'),
        per_draw(grass_cfg.length, 'kernel_length'),
        noises,
    )

    Deltas_means, Deltas_preds = vmap(predict)(*vmap_args)
    return Deltas_means, Deltas_preds

In [None]:
# Convert the config into the attribute-access form the prediction helpers read.
config = to_dictconf(Config)

In [None]:
# Tangent predictions on the full grid, one per posterior draw.
pred_key = random.PRNGKey(6578)
Deltas_means, Deltas_preds = predict_tangents(
    pred_key, s_test, s_train, log_Ws_train, config, samples
)
# neither output may contain NaNs
assert not np.isnan(Deltas_means).any()
assert not np.isnan(Deltas_preds).any()

In [None]:
# Persist the tangent predictions; the `results/` directory must already exist.
with open("results/Deltas_means_more_train_pts_new_anchor.pickle", 'wb') as f:
    pickle.dump(Deltas_means, f)
    
with open("results/Deltas_preds_more_train_pts_new_anchor.pickle", 'wb') as f:
    pickle.dump(Deltas_preds, f)

In [None]:
# with open("results/Deltas_means_more_train_pts_new_anchor.pickle", 'rb') as f:
#     Deltas_means = pickle.load(f)
    
# with open("results/Deltas_preds_more_train_pts_new_anchor.pickle", 'rb') as f:
#     Deltas_preds = pickle.load(f)

In [None]:
# Per-component plot of the predicted tangents: full data, training data,
# mean prediction averaged over posterior draws, and a percentile band.
plt.rcParams["figure.figsize"] = (12,6)
percentile_levels = [2.5, 97.5]
conf_level = percentile_levels[-1] - percentile_levels[0]
for i in range(d):
    means = Deltas_means[:,:,i,0]
    means_avg = np.mean(means, axis=0)
    preds = Deltas_preds[:,:,i,0]
    percentiles = np.percentile(preds, np.array(percentile_levels), axis=0)
    lower = percentiles[0,:]
    upper = percentiles[1,:]
    plt.plot(s_test, log_Ws_test[:,i,0], label='full data',c='black', alpha=0.75, linestyle='dashed')
    plt.scatter(s_train, log_Ws_train[:,i,0], label='training data', c='g')
    plt.plot(s_test, means_avg, label='averaged mean prediction', c='r', alpha=0.75)
    plt.fill_between(s_test, lower, upper, color='lightblue', alpha=0.75, label=f'{conf_level}% credible interval')
    plt.xlabel(r"$s$")
    plt.legend()
    # Mark the training inputs with lines spanning the band plus a 1% margin.
    # The margin is additive: scaling the bounds by 0.99 / 1.01 would shrink
    # the lines whenever a bound is negative, as tangent components can be.
    y_lo, y_hi = lower.min(), upper.max()
    pad = 0.01 * (y_hi - y_lo)
    plt.vlines(s_train, y_lo - pad, y_hi + pad, colors='green', linestyles='dashed')
    plt.title(f"{i+1}th component of tangents")
    plt.show()

In [None]:
def predict_grass(
    key: chex.ArrayDevice,
    s_test: chex.ArrayDevice,
    s_train: chex.ArrayDevice,
    Vs_train: chex.ArrayDevice,
    dict_cfg,
    samples: dict,
    jitter: float = 1e-8,
    reortho: bool = False
) -> Tuple[chex.ArrayDevice, chex.ArrayDevice]:
    """Predict points at ``s_test`` once per posterior draw.

    For every retained posterior draw a ``GrassGP`` is rebuilt from that
    draw's hyperparameters (or from the fixed values in the config) and its
    ``predict_grass`` method is called; the calls are batched with ``vmap``.

    Args:
        key: PRNG key, split into one sub-key per posterior draw.
        s_test: inputs at which to predict.
        s_train: training inputs.
        Vs_train: training tangents, passed through to
            ``GrassGP.predict_grass``.
        dict_cfg: config with attribute access to ``model.grass_config`` and
            ``train``.
        samples: posterior samples; must contain ``'Deltas'`` and every
            hyperparameter that is ``None`` in the config.
        jitter: forwarded to ``GrassGP.predict_grass``.
        reortho: forwarded to ``GrassGP.predict_grass``.

    Returns:
        ``(Ws_means, Ws_preds)`` as produced by ``GrassGP.predict_grass``,
        with a leading axis over posterior draws.

    Raises:
        ValueError: if the number of draws in ``samples['Deltas']`` differs
            from the number implied by the training config.
    """
    grass_cfg = dict_cfg.model.grass_config
    train_cfg = dict_cfg.train

    d_in = grass_cfg.d_in
    U = np.array(grass_cfg.anchor_point)
    d, n = U.shape
    cov_jitter = grass_cfg.cov_jitter
    k_include_noise = grass_cfg.k_include_noise
    kern_jitter = grass_cfg.jitter

    # Each chain retains n_samples // n_thinning draws and the chains are
    # concatenated in the sample dict, so account for `n_chains` as well
    # (treated as 1 when the config does not define it).
    n_chains = getattr(train_cfg, 'n_chains', None) or 1
    n_samples = n_chains * (train_cfg.n_samples // train_cfg.n_thinning)
    n_available = samples['Deltas'].shape[0]
    if n_samples != n_available:
        raise ValueError(
            f"config implies {n_samples} posterior draws but samples['Deltas'] holds {n_available}"
        )

    def predict(
        key: chex.ArrayDevice,
        Omega: chex.ArrayDevice,
        var: float,
        length: float,
        noise: float,
    ) -> Tuple[chex.ArrayDevice, chex.ArrayDevice]:
        # initialize GrassGP for this draw
        kernel_params = {'var': var, 'length': length, 'noise': noise}
        k = lambda t, s: rbf(t, s, kernel_params, jitter=kern_jitter, include_noise=k_include_noise)
        mu = lambda s: zero_mean(s, d, n)
        grass_gp = GrassGP(d_in=d_in, d_out=(d, n), mu=mu, k=k, Omega=Omega, U=U, cov_jitter=cov_jitter)

        # predict
        Ws_mean, Ws_pred = grass_gp.predict_grass(key, s_test, s_train, Vs_train, jitter=jitter, reortho=reortho)
        return Ws_mean, Ws_pred

    def per_draw(cfg_value, sample_name):
        # Sampled hyperparameter -> its draws; fixed one -> the constant repeated per draw.
        if cfg_value is None:
            return samples[sample_name]
        return cfg_value * np.ones(n_samples)

    cfg_Omega = grass_cfg.Omega
    if cfg_Omega is None:
        Omegas = samples['Omega']
    else:
        Omegas = np.repeat(np.array(cfg_Omega)[None,:,:], n_samples, axis=0)

    if grass_cfg.require_noise:
        noises = per_draw(grass_cfg.noise, 'kernel_noise')
    else:
        # the model uses zero kernel noise when it is not required
        noises = np.zeros(n_samples)

    vmap_args = (
        random.split(key, n_samples),
        Omegas,
        per_draw(grass_cfg.var, 'kernel_var'),
        per_draw(grass_cfg.length, 'kernel_length'),
        noises,
    )

    Ws_means, Ws_preds = vmap(predict)(*vmap_args)
    return Ws_means, Ws_preds

In [None]:
# Predictions of the points themselves on the full grid, one per posterior draw.
pred_key_grass = random.PRNGKey(7695)
Ws_means, Ws_preds = predict_grass(
    pred_key_grass, s_test, s_train, log_Ws_train, config, samples
)
# neither output may contain NaNs
assert not np.isnan(Ws_means).any()
assert not np.isnan(Ws_preds).any()

In [None]:
# Persist the point predictions; the `results/` directory must already exist.
with open("results/Ws_means_more_train_pts_new_anchor.pickle", 'wb') as f:
    pickle.dump(Ws_means, f)
    
with open("results/Ws_preds_more_train_pts_new_anchor.pickle", 'wb') as f:
    pickle.dump(Ws_preds, f)

In [None]:
# with open("results/Ws_means_more_train_pts_new_anchor.pickle", 'rb') as f:
#     Ws_means = pickle.load(f)
    
# with open("results/Ws_preds_more_train_pts_new_anchor.pickle", 'rb') as f:
#     Ws_preds = pickle.load(f)

In [None]:
# Per-component plot of the predicted points: full data, training data,
# mean prediction averaged over posterior draws, and a percentile band.
plt.rcParams["figure.figsize"] = (12,6)
percentile_levels = [2.5, 97.5]
conf_level = percentile_levels[-1] - percentile_levels[0]
for i in range(d):
    means = Ws_means[:,:,i,0]
    means_avg = np.mean(means, axis=0)
    preds = Ws_preds[:,:,i,0]
    percentiles = np.percentile(preds, np.array(percentile_levels), axis=0)
    lower = percentiles[0,:]
    upper = percentiles[1,:]
    plt.plot(s_test, Ws_test[:,i,0], label='full data',c='black', alpha=0.75, linestyle='dashed')
    plt.scatter(s_train, Ws_train[:,i,0], label='training data', c='g')
    plt.plot(s_test, means_avg, label='averaged mean prediction', c='r', alpha=0.75)
    plt.fill_between(s_test, lower, upper, color='lightblue', alpha=0.75, label=f'{conf_level}% credible interval')
    plt.xlabel(r"$s$")
    plt.legend()
    # Mark the training inputs with lines spanning the band plus a 1% margin.
    # The margin is additive: scaling the bounds by 0.99 / 1.01 would shrink
    # the lines whenever a bound is negative.
    y_lo, y_hi = lower.min(), upper.max()
    pad = 0.01 * (y_hi - y_lo)
    plt.vlines(s_train, y_lo - pad, y_hi + pad, colors='green', linestyles='dashed')
    plt.title(f"{i+1}th component of projections")
    plt.show()

In [None]:
def _subspace_angles(batches):
    # subspace angle of every point in every batch, as a 2-D array
    return np.array([[subspace_angle(w) for w in batch] for batch in batches])


# Subspace angles of the predicted means and predictive draws.
alphas_means = _subspace_angles(Ws_means)
alphas_preds = _subspace_angles(Ws_preds)

In [None]:
with open("results/alphas_means_more_train_pts.pickle_new_anchor", 'wb') as f:
    pickle.dump(alphas_means, f)
    
with open("results/alphas_preds_more_train_pts.pickle_new_anchor", 'wb') as f:
    pickle.dump(alphas_preds, f)

In [None]:
# with open("results/alphas_means_more_train_pts.pickle_new_anchor", 'rb') as f:
#     alphas_means = pickle.load(f)
    
# with open("results/alphas_preds_more_train_pts.pickle_new_anchor", 'rb') as f:
#     alphas_preds = pickle.load(f)

In [None]:
# Predicted subspace angles on the full grid: true curve, training data, mean
# prediction averaged over posterior draws, and a percentile band.
plt.rcParams["figure.figsize"] = (12,6)
percentile_levels = [2.5, 97.5]
conf_level = percentile_levels[-1] - percentile_levels[0]
alphas_means_avg = alphas_means.mean(axis=0)
percentiles = np.percentile(alphas_preds, np.array(percentile_levels), axis=0)
lower, upper = percentiles[0], percentiles[1]
plt.plot(s_test, alphas, label='full data',c='black', alpha=0.75, linestyle='dashed')
plt.scatter(s_train, alphas_train, label='training data', c='g')
plt.plot(s_test, alphas_means_avg, label='averaged mean prediction', c='r', alpha=0.75)
plt.fill_between(s_test, lower, upper, color='lightblue', alpha=0.75, label=f'{conf_level}% credible interval')
plt.xlabel(r"$s$")
plt.ylabel("subspace angle")
plt.legend()
# mark the training inputs over the whole [0, pi] range
plt.vlines(s_train, 0, np.pi, colors='green', linestyles='dashed')
plt.title("predictions for subspace angles")
plt.show()

In [None]:
# Sanity check: the cells below index this as (posterior draw, test input, :, :).
Ws_means.shape

In [None]:
# For each test input, grid-search the angle minimising `loss` over the
# per-draw mean predictions at that input and record the corresponding point.
test_means_mcmc_barycenters = []
for i in tqdm(range(s_test.shape[0])):
    points = Ws_means[:, i]
    test_means_mcmc_losses = vmap(loss, in_axes=(0, None))(thetas, points)
    test_means_mcmc_theta_argmin = thetas[np.argmin(test_means_mcmc_losses)]
    barycenter = subspace_angle_to_grass_pt(test_means_mcmc_theta_argmin)
    test_means_mcmc_barycenters.append(barycenter)
    # plt.plot(thetas, test_means_mcmc_losses)
    # plt.scatter(test_means_mcmc_theta_argmin, test_means_mcmc_losses.min(),color="red")
    # plt.grid()
    # plt.title(f"{i}")
    # plt.show()

In [None]:
# For each test input, grid-search the angle minimising `loss` over the
# per-draw predictive samples at that input and record the corresponding point.
test_preds_mcmc_barycenters = []
for i in tqdm(range(s_test.shape[0])):
    points = Ws_preds[:, i]
    test_preds_mcmc_losses = vmap(loss, in_axes=(0, None))(thetas, points)
    test_preds_mcmc_theta_argmin = thetas[np.argmin(test_preds_mcmc_losses)]
    barycenter = subspace_angle_to_grass_pt(test_preds_mcmc_theta_argmin)
    test_preds_mcmc_barycenters.append(barycenter)
    # plt.plot(thetas, test_preds_mcmc_losses)
    # plt.scatter(test_preds_mcmc_theta_argmin, test_preds_mcmc_losses.min(),color="red")
    # plt.grid()
    # plt.title(f"{i}")
    # plt.show()

In [None]:
# Stack the per-input barycenters of the mean predictions into one array.
test_means_mcmc_barycenters = np.array(test_means_mcmc_barycenters)

In [None]:
# Stack the per-input barycenters of the predictive samples into one array.
test_preds_mcmc_barycenters = np.array(test_preds_mcmc_barycenters)

In [None]:
# Persist both sets of out-of-sample barycenters; the `results/` directory
# must already exist.
with open("results/test_means_mcmc_barycenters_more_train_pts_new_anchor.pickle", 'wb') as f:
    pickle.dump(test_means_mcmc_barycenters, f)

with open("results/test_preds_mcmc_barycenters_more_train_pts_new_anchor.pickle", 'wb') as f:
    pickle.dump(test_preds_mcmc_barycenters, f)

In [None]:
# with open("results/test_means_mcmc_barycenters_more_train_pts_new_anchor.pickle", 'rb') as f:
#     test_means_mcmc_barycenters = pickle.load(f)

# with open("results/test_preds_mcmc_barycenters_more_train_pts_new_anchor.pickle", 'rb') as f:
#     test_preds_mcmc_barycenters = pickle.load(f)

In [None]:
# `grass_dist` between each true test point and the corresponding barycenter,
# once for the mean predictions and once for the predictive samples.
out_sample_mean_errors = vmap(grass_dist)(Ws_test, test_means_mcmc_barycenters)
out_sample_pred_errors = vmap(grass_dist)(Ws_test, test_preds_mcmc_barycenters)

In [None]:
# Out-of-sample errors along the test grid, with the training inputs marked.
plt.plot(s_test,out_sample_mean_errors, label='error using means')
plt.plot(s_test,out_sample_pred_errors, label='error using preds')
# NOTE(review): the marker lines span a fixed [0, 1] range regardless of the
# error magnitudes; confirm this suits the data.
plt.vlines(s_train, 0, 1, colors="green", linestyles="dashed")
plt.legend()
plt.show()

In [None]:
# Spread of the predictive samples at each test input: root-mean-square
# `grass_dist` between the samples and their barycenter.
sd_s_test = []
for i in tqdm(range(s_test.shape[0])):
    fixed = test_preds_mcmc_barycenters[i]
    dists = vmap(grass_dist, in_axes=(0, None))(Ws_preds[:, i], fixed)
    sd_s_test.append(np.sqrt(np.mean(dists ** 2)))

In [None]:
# Collect the per-input spreads into one array.
sd_s_test = np.array(sd_s_test)

In [None]:
# One row per test input: input value, both error measures and sample spread.
test_pd_data = {'s': s_test, 'errors_mean': out_sample_mean_errors, 'errors_pred': out_sample_pred_errors, 'sd': sd_s_test}
out_sample_errors_df = pd.DataFrame(data=test_pd_data)
out_sample_errors_df.head()

In [None]:
# Summary statistics of the error and spread columns (input column excluded).
out_sample_errors_df.drop(['s'], axis=1).describe()