In [None]:
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
#

# Manifold Statistics - Examples on $\mathbb{S}^2$

In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to them take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# NOTE(review): the star imports in this notebook supply the unqualified names
# used by the cells below (presumably jnp, np, jax, plt, itertools, dts, dWs,
# newfig) — confirm against the modules before narrowing to explicit imports.
from src.manifolds.S2 import *
# Manifold object used by the cells below.
M = S2()
print(M)
from src.plotting import *
# Uncomment to switch matplotlib to the interactive notebook backend.
#%matplotlib notebook

In [None]:
# Attach the geometric machinery to M. The modules are initialized in the
# order metric -> geodesic -> Log (presumably later ones build on earlier
# ones — keep this order).
from src.Riemannian import metric, geodesic, Log

metric.initialize(M)    # Riemannian structure
geodesic.initialize(M)  # geodesics
Log.initialize(M)       # Logarithm map

## Sample Data

In [None]:
# coordinate form of the Brownian motion
from src.stochastics import Brownian_coords
Brownian_coords.initialize(M)

# product sde, driven by the coordinate Brownian motion and its chart update
from src.stochastics import sde_product
sde_product.initialize(M, M.sde_Brownian_coords, M.chart_update_Brownian_coords)

# common starting point (coordinates, chart) of all sample paths
x = M.coords(jnp.zeros(M.dim))

N = 3
_dts = dts(T=.5)

# replicate the starting point N times along a new leading axis
x0_coords = jnp.tile(x[0], (N,) + (1,) * x[0].ndim)
x0_charts = jnp.tile(x[1], (N,) + (1,) * x[1].ndim)
# noise increments, reshaped so the two trailing axes are (sample, dimension)
noise = dWs(N * M.dim, _dts).reshape(-1, N, M.dim)

ts, xss, chartss = M.product((x0_coords, x0_charts), _dts, noise)
# keep the last entry along the leading axis (presumably the final time step)
samples = xss[-1]
chartss = chartss[-1]

# plot
newfig()
M.plot()
for sample, chart in zip(samples, chartss):
    M.plotx((sample, chart))
plt.show()

## Fréchet mean

In [None]:
from src.statistics.Frechet_mean import *

# Frechet mean of the samples; x is passed as the third argument
# (presumably the initial guess — confirm against Frechet_mean).
res = Frechet_mean(lambda *args: M.Log(*args), zip(samples, chartss), x)
Fm, Fm_loss, iterations = res[0], res[1], res[2]
print("loss = ", Fm_loss)
print("mean = ", Fm)

# show the result together with the iterates, each paired with the chart x[1]
newfig()
M.plot(rotate=np.array([50, -45]))
M.plotx(Fm)
M.plot_path(zip(iterations, itertools.repeat(x[1])))
plt.show()

## Tangent PCA

In [None]:
from src.statistics.tangent_PCA import *

from src.utils import *
from sklearn.decomposition import PCA

# tangent PCA of the samples, using M.Log and the point x
pca = tangent_PCA(M, lambda *args: M.Log(*args), x, zip(samples, chartss))
print(pca.get_covariance())

# scatter plot of the first two columns of the transformed Logs
pc_scores = pca.transformed_Logs
plt.scatter(pc_scores[:, 0], pc_scores[:, 1])
plt.axis('equal')
plt.show()

## Sampled mean

In [None]:
# condition on diagonal of product manifold
from src.stochastics import diagonal_conditioning
diagonal_conditioning.initialize(M, M.sde_product, M.chart_update_product)

_dts = dts(n_steps=500, T=.1)
noise = dWs(N * M.dim, _dts).reshape(-1, N, M.dim)
ts, xss, _chartss = M.diagonal((samples, chartss), _dts, noise, x[1])


def _in_base_chart(coords, chart):
    """Return the coordinates M.update_coords gives for (coords, chart) in chart x[1]."""
    return M.update_coords((coords, chart), x[1])[0]


# average the last entries of all N paths after moving them to the chart x[1]
mean = jnp.mean(jax.vmap(_in_base_chart, 0)(xss[-1], _chartss[-1]), 0)

# plot
M.newfig()
M.plot()
colormap = plt.get_cmap('winter')
colors = list(map(colormap, np.linspace(0, 1, N)))
for idx, color in enumerate(colors):
    # path of the idx-th component and the sample it started from
    M.plot_path(zip(xss[:, idx], _chartss[:, idx]), color=color)
    M.plotx((samples[idx], chartss[idx]), s=40)
M.plotx((mean, x[1]), color='r', s=80)
ax = plt.gcf().gca()
ax.view_init(60, 45)  # rotate
plt.axis('off')
# plt.savefig('diagonal-mean-N3.pdf')
plt.show()

In [None]:
# samples: draw a larger data set (N = 256) from the product process started at x
N = 256
_dts = dts(T=.5)
(ts,xss,chartss) = M.product((jnp.tile(x[0],(N,)+(1,)*x[0].ndim),jnp.tile(x[1],(N,)+(1,)*x[1].ndim)),
                             _dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim))
# keep the last entry along the leading axis as the data set
samples = xss[-1]
chartss = chartss[-1]
# plot
newfig()
M.plot()
for i in range(N):
    M.plotx((samples[i],chartss[i]))
ax = plt.gcf().gca(); ax.view_init(60, 45) # rotate
plt.axis('off')
# plt.savefig('diagonal-samples-N256.pdf')
plt.show()

# sample multiple means: repeat the diagonal-conditioned simulation K times
K = 32
means = np.zeros((K,M.dim))

# Batched conversion of points to the chart x[1]. It does not depend on the
# loop variable, so it is built once instead of once per repetition.
to_base_chart = jax.vmap(lambda _x,chart: M.update_coords((_x,chart),x[1])[0],0)

_dts = dts(T=.2)
for i in range(K):
    (ts,xss,_chartss) = M.diagonal((samples,chartss),
                             _dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim),x[1])
    # one mean estimate: average of the N last entries, expressed in chart x[1]
    means[i] = jnp.mean(to_base_chart(xss[-1],_chartss[-1]),0)

colormap = plt.get_cmap('winter')
colors=[colormap(k) for k in np.linspace(0, 1, K)]

# plot the K estimated means, one colour per repetition
M.newfig()
M.plot()
for j in range(K):
    M.plotx((means[j],x[1]),color=colors[j],s=30)
ax = plt.gcf().gca(); ax.view_init(60, 45) # rotate
plt.axis('off')
# plt.savefig('diagonal-mean-N256.pdf')
plt.show()