# Manifold Statistics - Examples on $\mathbb{S}^2$

In [None]:
# IPython autoreload (mode 2): reload modified modules (e.g. the src package)
# before executing code, so library edits are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Manifold object for the 2-sphere; every later cell operates on M.
from src.manifolds.S2 import *
M = S2()
print(M)
# NOTE(review): names used below without an explicit import (jnp, jax, np, plt,
# newfig, dts, dWs, partial, ...) are assumed to come from these star imports;
# the imported modules are not visible here.
from src.plotting import *
#%matplotlib notebook

In [None]:
# Riemannian structure
# Each initialize(M) call equips M with the corresponding functionality
# (e.g. M.g, M.Exp, M.Log are used in later cells).
from src.Riemannian import metric
metric.initialize(M)

# geodesics
from src.Riemannian import geodesic
geodesic.initialize(M)

# Logarithm map
from src.Riemannian import Log
Log.initialize(M)

# Base point: the zero coordinate vector. Later cells index it as x[0]
# (coordinates) and x[1] (chart), so M.coords presumably returns such a pair.
x = M.coords(jnp.zeros(M.dim))

## Sample Data

In [None]:
# coordinate form
# Brownian motion written in local coordinates (provides M.sde_Brownian_coords
# and M.chart_update_Brownian_coords used below).
from src.stochastics import Brownian_coords
Brownian_coords.initialize(M)

# product sde
# Product process: simulates N copies of the coordinate Brownian motion at once;
# tile(x,N) replicates the start point x for every copy.
from src.stochastics import product_sde
from src.stochastics.product_sde import tile
(product,sde_product,chart_update_product) = product_sde.initialize(M,M.sde_Brownian_coords,M.chart_update_Brownian_coords)

# Simulate N paths from x up to time T=.5; the noise is reshaped to one
# M.dim-dimensional increment per time step and per path.
N = 32
_dts = dts(T=.5)
(ts,xss,chartss,*_) = product(tile(x,N),_dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim),jnp.repeat(1.,N))
# Keep only the endpoints as data; note chartss is overwritten with the final charts.
samples = xss[-1]
chartss = chartss[-1]    

# plot
newfig()
M.plot()
for i in range(N):
    M.plotx((samples[i],chartss[i]))
plt.show()

## Fréchet mean

In [None]:
from src.statistics import Frechet_mean
Frechet_mean.initialize(M)

# Estimate the Fréchet mean of the samples, starting the optimization at x.
# Besides the mean m and the final loss, the call returns the optimization
# iterates and vs, where vs[i] is used below as the tangent vector at m that
# points towards samples[i].
m,loss,iterations,vs = M.Frechet_mean(zip(samples,chartss),x)
# Alternative with an explicitly supplied Log map (same four return values):
# m,loss,iterations,vs = M.Frechet_mean(zip(samples,chartss),x,Log=lambda *args: M.Log(*args))
print("loss = ", loss)
print("mean = ", m)

# plot
newfig()
M.plot()
M.plotx(m,s=100,color='g')

# Draw the geodesic from the mean towards each sample, then the sample itself.
for i,(_x,_chart) in enumerate(zip(samples,chartss)):
    try:
        (xs,charts) = M.Expt(m,vs[i])
        M.plot_path(zip(xs,charts))
    except Exception:
        # Best effort: skip geodesics that cannot be computed or plotted. Unlike
        # a bare `except:`, this still lets KeyboardInterrupt/SystemExit through.
        pass
    M.plotx((_x,_chart),linewidth = 1.5, s=50, color='r')
# Optimization trajectory of the mean estimate.
M.plot_path(iterations,color='y')
plt.show()

## Tangent PCA

In [None]:
from src.statistics.tangent_PCA import *

from src.utils import *
# NOTE(review): PCA is not referenced directly in this cell.
from sklearn.decomposition import PCA

# Presumably a PCA of the samples mapped by M.Log into the tangent space at x.
# NOTE(review): the base point is x, not the Fréchet mean m estimated in the
# previous cell -- confirm this is intended.
pca = tangent_PCA(M, lambda *args: M.Log(*args),x,zip(samples,chartss))
print(pca.get_covariance())

# Scatter plot of the first two transformed (PCA) coordinates.
plt.scatter(pca.transformed_Logs[:, 0], pca.transformed_Logs[:, 1])
plt.axis('equal')
plt.show()

## Sampled mean

In [None]:
# condition on diagonal of product manifold
# The N samples are evolved jointly by the product process, presumably conditioned
# to meet on the diagonal; the average of their final positions is used as a
# mean estimate.
from src.stochastics import diagonal_conditioning
diagonal_conditioning.initialize(M,sde_product,chart_update_product)

_dts = dts(n_steps=500,T=.1)
(ts,xss,_chartss) = M.diagonal((samples,chartss),
                             _dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim),x[1],jnp.repeat(1.,N))
# Re-express every final point in chart x[1], then average coordinate-wise.
mean = jnp.mean(jax.vmap(lambda _x,chart: M.update_coords((_x,chart),x[1])[0],0)(xss[-1],_chartss[-1]),0)

# plot
M.newfig()
M.plot()
colormap = plt.get_cmap('winter')
colors=[colormap(k) for k in np.linspace(0, 1, N)]
# One coloured conditioned path per sample, starting at the sample itself.
for i in range(N):
    M.plot_path(zip(xss[:,i],_chartss[:,i]),color=colors[i])
    M.plotx((samples[i],chartss[i]),s=40)
M.plotx((mean,x[1]),color='r',s=80)
ax = plt.gcf().gca(); ax.view_init(60, 45) # rotate
plt.axis('off')
# plt.savefig('diagonal-mean-N3.pdf')
plt.show()

In [None]:
# samples
# Larger data set: N=256 endpoints of Brownian paths from x (T=.5).
N = 256
_dts = dts(T=.5)
# NOTE(review): the explicit jnp.tile calls replicate x N times, like tile(x,N)
# in the first sampling cell.
(ts,xss,chartss,*_) = product((jnp.tile(x[0],(N,)+(1,)*x[0].ndim),jnp.tile(x[1],(N,)+(1,)*x[1].ndim)),
                             _dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim),jnp.repeat(1.,N))
samples = xss[-1]
charts_samples = chartss[-1]    
# plot
newfig()
M.plot()
for i in range(N):
    M.plotx((samples[i],charts_samples[i]))
ax = plt.gcf().gca(); ax.view_init(60, 45) # rotate
plt.axis('off')
# plt.savefig('diagonal-samples-N256.pdf')
plt.show()

# sample multiple means
# Repeat the diagonal-conditioning mean estimate K times with fresh noise;
# means[i] holds estimate i, in coordinates of chart x[1].
K = 32
means = np.zeros((K,M.dim))

_dts = dts(T=.2)
for i in range(K):
    (ts,xss,_chartss) = M.diagonal((samples,charts_samples),
                             _dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim),x[1],jnp.repeat(1.,N))
    means[i] = jnp.mean(jax.vmap(lambda _x,chart: M.update_coords((_x,chart),x[1])[0],0)(xss[-1],_chartss[-1]),0)
    
colormap = plt.get_cmap('winter')
colors=[colormap(k) for k in np.linspace(0, 1, K)]

# plot the K estimated means (their spread shows the estimator's variability;
# no standard deviation is drawn)
M.newfig()
M.plot()
ax = plt.gca()
for j in range(K):
    M.plotx((means[j],x[1]),color=colors[j],s=30)
ax = plt.gcf().gca(); ax.view_init(60, 45) # rotate
plt.axis('off')
# plt.savefig('diagonal-mean-N256.pdf')
plt.show()

## Diffusion mean estimation

In [None]:
# Delyon/Hu guided process
from src.stochastics.guided_process import *

# guide function
# phi(q, v, s) = (1/s) * L(q)^T . StdLog(q, F((v, q[1]))), where L(q) is the
# Cholesky factor of the metric M.g(q) and v are coordinates in the chart q[1].
phi = lambda q,v,s: jnp.tensordot((1/s)*jnp.linalg.cholesky(M.g(q)).T,M.StdLog(q,M.F((v,q[1]))).flatten(),(1,0))

# Guided version of the coordinate Brownian motion, with
# sqrtCov(x, s) = s * chol(M.gsharp(x)) and A(x, v, w, s) = v . (s^-2 M.g(x)) . w.
(Brownian_coords_guided,sde_Brownian_coords_guided,chart_update_Brownian_coords_guided) = get_guided(
    M,M.sde_Brownian_coords,M.chart_update_Brownian_coords,phi,
    sqrtCov=lambda x,s: s*jnp.linalg.cholesky(M.gsharp(x)),A=lambda x,v,w,s: jnp.dot(v,jnp.dot((s**(-2))*M.g(x),w)))

# product bridge sde
# (product_sde and tile were already imported in the sampling cell; re-imported
# here so the cell is self-contained)
from src.stochastics import product_sde
from src.stochastics.product_sde import tile
(product_guided,*_) = product_sde.initialize(M,sde_Brownian_coords_guided,chart_update_Brownian_coords_guided)

x = M.coords(jnp.zeros(M.dim))
w = M.Exp(x,np.array([.8,-.5])) # target
# N guided paths from x; the target w is passed in coordinates of x's chart.
# The remaining per-path arguments follow the guided-process signature, which is
# not visible here.
N = 4
_dts = dts(n_steps=1000,T=.1)
(ts,xss,chartss,*_) = product_guided(tile(x,N),_dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim),
                                     tile(0.,N),tile(0.,N),tile(jnp.sum(_dts),N),
                                     tile(M.update_coords(w,x[1])[0],N),jnp.repeat(1.,N)) # target

# plot
M.newfig()
M.plot()
colormap = plt.get_cmap('winter')
colors=[colormap(k) for k in np.linspace(0, 1, N)]
for i in range(N):
    M.plot_path(zip(xss[:,i],chartss[:,i]),color=colors[i])
# start point
M.plotx(x,color='r')
plt.show()

In [None]:
def log_p_T(guided,phi,x,v,_dts,dWs,*ys):
    """Monte Carlo approximation of the log transition density from a guided process.

    Computes
        -d/2 log(2 pi T) - |phi(x, v)|^2 / (2 T) + log mean_k exp(log_varphi_k)
    where d = x[0].shape[0], T = sum(_dts) and log_varphi_k is the final
    log-weight of guided path k.

    Args:
        guided: guided-process simulator; ``guided(x, v, _dts, dW, *ys)[4][-1]``
            is taken as the final log-weight of one path.
        phi: guide function, called as ``phi(x, v_coords, *ys)``.
        x: start point, indexed as ``x[0]`` (coordinates) and ``x[1]`` (chart).
        v: target point; it is re-expressed in the chart ``x[1]`` before being
            passed to ``phi``.
        _dts: time increments; T is their sum.
        dWs: noise increments; one Monte Carlo path per index along axis 1.
        *ys: extra arguments forwarded to both ``guided`` and ``phi``.

    Returns:
        Scalar estimate of the log transition density.
    """
    from jax.scipy.special import logsumexp

    T = jnp.sum(_dts)

    Cxv = jnp.sum(phi(x,M.update_coords(v,x[1])[0],*ys)**2)

    # sample: final log-weight of every guided path (vmap over axis 1 of dWs)
    log_varphis = jax.vmap(lambda dW: guided(x,v,_dts,dW,*ys)[4][-1],1)(dWs)

    # log(mean(exp(.))) computed stably: exponentiating the log-weights directly
    # can overflow/underflow, logsumexp factors out the maximum first.
    log_varphi = logsumexp(log_varphis)-jnp.log(log_varphis.size)
    log_density = -.5*x[0].shape[0]*jnp.log(2.*jnp.pi*T)-Cxv/(2.*T)+log_varphi
    return log_density

# Bind the guided process and guide function; afterwards the call signature is
# log_p_T(x, v, _dts, dWs, *ys).
log_p_T = partial(log_p_T,Brownian_coords_guided,phi)

_dts = dts(n_steps=100,T=1.)

# test one sample
# Each N is evaluated twice: the untimed first call warms up (presumably absorbing
# JAX compilation), so %time measures the second call.
N=1 
log_p_T(x,w,_dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim),1.)
%time log_p_T(x,w,_dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim),1.)
N=10
log_p_T(x,w,_dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim),1.)
%time log_p_T(x,w,_dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim),1.)
N=1000
log_p_T(x,w,_dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim),1.)
%time log_p_T(x,w,_dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim),1.)

# multiple samples
# neg_log_p_Ts(x_coords, x_chart, obs, dWs, *ys): negative mean of log_p_T over
# the observations, vmapped over obs (axis 0) and their per-observation noise
# (axis 0); all other arguments are shared. N Monte Carlo paths per observation.
N=100
neg_log_p_Ts = lambda *args: -jnp.mean(jax.vmap(lambda x,chart,w,dW,*ys: log_p_T((x,chart),w,_dts,dW,*ys),(None,None,0,0,*((None,)*(len(args)-4))))(*args))
neg_log_p_Ts(*x,(samples,charts_samples),dWs(samples.shape[0]*N*M.dim,_dts).reshape(-1,_dts.shape[0],N,M.dim),1.)
%time neg_log_p_Ts(*x,(samples,charts_samples),dWs(samples.shape[0]*N*M.dim,_dts).reshape(-1,_dts.shape[0],N,M.dim),1.)

In [None]:
# run MLE

# jax.experimental.optimizers was moved to jax.example_libraries.optimizers and
# later removed from jax.experimental; prefer the new location and fall back for
# older JAX releases.
try:
    from jax.example_libraries import optimizers
except ImportError:
    from jax.experimental import optimizers

def iterative_mle(obss,neg_log_p_Ts,params,params_inds,params_update,chart,step_size=1e-1,num_steps=5):
    """Maximum-likelihood estimation of ``params`` with Adam on a Monte Carlo objective.

    Each step draws fresh noise increments, evaluates ``neg_log_p_Ts`` and its
    gradient with respect to the positional arguments ``params_inds``, takes an
    Adam step and lets ``params_update`` re-chart the position parameter.

    Args:
        obss: observations, passed as the third positional argument of
            ``neg_log_p_Ts``; ``len(obss[0])`` is the number of observations.
        neg_log_p_Ts: objective, called as
            ``neg_log_p_Ts(params[0], chart, obss, noise, *params[1:])``.
        params: tuple of initial parameters; ``params[0]`` are the position
            coordinates in ``chart``.
        params_inds: positional-argument indices of ``neg_log_p_Ts`` to differentiate.
        params_update: ``(opt_state, chart) -> (opt_state, chart)`` hook applied
            after every optimizer update (used for chart switching).
        chart: initial chart of the position parameter.
        step_size: Adam step size.
        num_steps: number of optimization steps.

    Returns:
        ``(final_params, final_chart, final_value, values, paramss)``. ``values``
        is an array with the objective value of every step, and ``paramss`` is a
        tuple of ``(*params, chart)`` recorded after every step. ``final_value`` is
        ``None`` when ``num_steps == 0``.

    Note:
        Reads the notebook globals ``N`` (Monte Carlo paths per observation),
        ``_dts``, ``dWs`` and ``M`` at call time.
    """
    opt_init, opt_update, get_params = optimizers.adam(step_size)

    def step(i, opt_state, chart):
        params = get_params(opt_state)
        # Fresh noise every step, giving a stochastic gradient of the MC objective.
        noise = dWs(len(obss[0])*N*M.dim,_dts).reshape(-1,_dts.shape[0],N,M.dim)
        value,grads = jax.value_and_grad(neg_log_p_Ts,params_inds)(params[0],chart,obss,noise,*params[1:])
        opt_state = opt_update(i, grads, opt_state)
        opt_state,chart = params_update(opt_state, chart)
        return value,opt_state,chart

    opt_state = opt_init(params)
    value = None  # stays None if no step is taken
    values = []
    paramss = []

    for i in range(num_steps):
        value, opt_state, chart = step(i, opt_state, chart)
        values.append(value)
        paramss.append((*get_params(opt_state),chart))
        print("Step {} | loss: {:0.6e} | params: {}".format(i, value, str((get_params(opt_state),chart))))
    if num_steps > 0:
        print("Final {} | loss: {:0.6e} | params: {}".format(num_steps-1, value, str(get_params(opt_state))))

    return (get_params(opt_state),chart,value,jnp.array(values),tuple(paramss))

# define parameters
x = M.coords(jnp.zeros(M.dim))
# Initial guess: position = x's coordinates plus Gaussian noise (unseeded, so it
# varies between runs), and .5 for the scalar passed to neg_log_p_Ts as *ys,
# which ends up as the argument s of the guide phi.
params = (x[0]+.1*np.random.normal(size=M.dim),.5)
# Differentiate neg_log_p_Ts w.r.t. positional arguments 0 (position
# coordinates) and 4 (that scalar).
params_inds = (0,4)
# Chart-switching hook for iterative_mle: after each optimizer update, move the
# position parameter to a centered chart whenever the manifold requests it.
def params_update(state, chart):
    """Re-chart the position parameter stored in an Adam optimizer state.

    Args:
        state: optimizer state ``(packed_states, tree_def, subtree_defs)`` as produced
            by ``optimizers.adam``. The first packed entry is the Adam triple
            ``(x, m, v)`` of the position parameter; any further entries (e.g. the
            scalar parameter) are passed through unchanged.
        chart: chart in which the position coordinates ``x`` are expressed.

    Returns:
        ``(state, chart)``: an ``optimizers.OptimizerState`` with ``x`` possibly
        re-expressed in a new chart, together with the (possibly new) chart.
    """
    # Explicit unpacking replaces the earlier ValueError-driven branching, which
    # would also have caught ValueErrors raised by the M.* calls.
    states_flat, tree_def, subtree_defs = state
    (x,m,v),*rest = states_flat
    if M.do_chart_update((x,chart)):
        new_chart = M.centered_chart((x,chart))
        (x,chart) = M.update_coords((x,chart),new_chart)
    # NOTE(review): the Adam moment estimates m, v are not transformed to the new chart.
    return optimizers.OptimizerState(((x,m,v),*rest),tree_def,subtree_defs),chart

# Run 20 MLE steps on the N=256 samples, starting in chart x[1].
(thetas,chart,log_likelihood,log_likelihoods,thetass) = iterative_mle((samples,charts_samples),
                                                                      neg_log_p_Ts,
                                                                      params,params_inds,params_update,x[1],
                                                                      num_steps=20)

# plot
# Despite the name, log_likelihoods holds the per-step values of neg_log_p_Ts,
# i.e. negative mean log-likelihoods (lower is better).
n_steps = log_likelihoods.shape[0]
plt.plot(range(n_steps),log_likelihoods)
# plt.savefig('ML_likelihoods.pdf')
plt.show()
# thetass[k] = (position, scalar parameter, chart) after step k.
plt.plot(range(n_steps),[t[1] for t in thetass])
plt.show()
# Position iterates mapped through M.F.
plt.plot(range(n_steps),[M.F((t[0],t[2])) for t in thetass])
print(M.F((thetas[0],chart)))
# plt.savefig('ML_thetas.pdf')
plt.show()

M.newfig()
M.plot()
M.plotx((thetas[0],chart),color='k',s=100) # result
M.plotx((thetass[0][0],thetass[0][2]),color='b',s=100) # initial point
M.plotx(x,color='r',s=100)
M.plot_path(list([(t[0],t[2]) for t in thetass]),color='b',linewidth=2.5)

# NOTE: unlike the other savefig calls in this notebook, this one is active and
# writes a file to the working directory.
plt.savefig('MLmean_iterations.pdf')
plt.show()

## Most probable paths

In [None]:
# from Anisotropic covariance on manifolds and most probable paths,
# Erlend Grong and Stefan Sommer, 2021
# Most-probable-path routines; presumably provides M.MPP_mean and
# M.MPP_forwardt, which are used below.
from src.framebundle import MPP
MPP.initialize(M)

In [None]:
# sample data
# Frame bundle and stochastic development (M.stochastic_development).
from src.framebundle import FM
FM.initialize(M)
from src.stochastics import stochastic_development
stochastic_development.initialize(M)

# simulate Brownian Motion
# %time _,xss,chartss=jax.vmap(lambda dWs: M.Brownian_coords(x,dWs))(dWs(M.dim,n_steps=1000,num=16))
# obss = xss[:,-1]
# obs_charts = chartss[:,-1]

# simulate anisotropic Brownian Motion
# Initial frame nu: rows of the Cholesky factor of M.gsharp(x), row i scaled by lamb[i].
lamb = jnp.array([.6,.25])
nu = jnp.einsum('i,ij->ij',lamb,np.linalg.cholesky(M.gsharp(x)))
# Frame-bundle point: base-point coordinates followed by the flattened frame, plus chart.
u = (np.concatenate((x[0],nu.flatten())),x[1])
(ts,us,charts) = M.stochastic_development(u,dts(),dWs(M.dim))
xs = us[:,0:M.dim]

# plot
newfig()
M.plot()
M.plot_path(zip(us,charts))
plt.show()

# 64 independent developments; keep the base-point coordinates and charts of
# each endpoint as observations.
%time _,uss,chartss=jax.vmap(lambda dWs: M.stochastic_development(u,dts(),dWs))(dWs(M.dim,num=64))
obss = uss[:,-1,0:M.dim]
obs_charts = chartss[:,-1]

# plot
newfig()
M.plot()
for (_x,_chart) in zip(obss,obs_charts):
    M.plotx((_x,_chart))
# NOTE(review): the file name encodes lamb = (.5, .05), but lamb = [.6, .25]
# above -- the name looks stale; confirm before relying on it.
plt.savefig('S2_samples_lamb_05_005.pdf')
plt.show()

In [None]:
from src.Riemannian import curvature
curvature.initialize(M)

# Estimate the mean _x and the anisotropy _lamb from the observations. vs[i] and
# chis[i] are used in the next cell to integrate the most probable path towards
# observation i.
ys = list(zip(obss,obs_charts))
chart = x[1]
(_x,_lamb,vs,chis) = M.MPP_mean(x,chart,ys)

# NOTE(review): the commented code below references f, get_params45 and
# opt_state45, which are not defined in this notebook as shown; it is stale.
# # compute variance
# var = 1/(N*M.dim)*jnp.sum(jnp.array([f(chart,_x,_lamb,v,chi) for (v,chi) in get_params45(opt_state45)]))
# print(_lamb,var)
# # _lamb = _lamb*var
# # print(_lamb)

# print(lamb/jnp.sqrt(jnp.prod(lamb)))
# print(_lamb/jnp.sqrt(jnp.prod(_lamb)))

In [None]:
# Frame at the estimated mean: Cholesky factor of M.gsharp there.
_nu = jnp.linalg.cholesky(M.gsharp((_x,chart)))
print("x: ",(_x,chart),"\nlambda:\n",_lamb,"\nnu:\n",_nu)
# Frame-bundle point: coordinates followed by the flattened frame, in `chart`.
_u = (jnp.hstack((_x,_nu.flatten())),chart)

# plot
newfig()
M.plot()
# Estimated mean, drawn with its frame rows scaled by _lamb.
M.plotx((_x,chart),u=np.einsum('i,ij->ij',_lamb,_nu),linewidth = 1.5, s=50)

for i in range(len(vs)):
    v = vs[i]
    chi = chis[i]

    # Most probable path from the estimated mean for observation i.
    (xs,_,_chis,charts) = M.MPP_forwardt(_u,_lamb,v,chi)
    print("v: ",v,", chi: ",chi, ", chiT: ",_chis[-1])

    
    M.plotx((obss[i],obs_charts[i]),linewidth = 1.5, s=50, color='r')
    M.plot_path(zip(xs[:,0:M.dim],charts))
    
plt.axis('off')
# plt.savefig('S2_estimation_lambda_06_025.pdf')
plt.show()