In [None]:
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
#

# Manifold Statistics - Examples on $\mathbb{S}^2$

In [None]:
# Reload edited modules (e.g. the local src/ package) automatically before
# executing code, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Load and instantiate the S^2 manifold; `M` is used by every cell below.
from src.manifolds.S2 import *
M = S2()
print(M)
# Plotting helpers; presumably the source of newfig/plt used below (the S2 star
# import may also provide some of them) -- NOTE(review): confirm.
from src.plotting import *
# Optional: uncomment the magic below for interactive figures in the classic notebook UI.
#%matplotlib notebook

In [None]:
# Equip M with its Riemannian metric, then geodesics, then the Riemannian
# logarithm (M.Log, used by the statistics cells below). The initialize(M)
# calls run in exactly that order.
from src.Riemannian import metric, geodesic, Log

for riemannian_module in (metric, geodesic, Log):
    riemannian_module.initialize(M)

## Sample Data

In [None]:
# Sample endpoints of Brownian motion started at x, simulated in coordinates.
# Each endpoint is kept as a coordinate row in `samples` plus the matching
# chart row in `chartss`; later cells consume them as zip(samples, chartss).
from src.stochastics import Brownian_coords
Brownian_coords.initialize(M)

N_samples = 3  # kept small for a quick run (originally noted as 256)
x = M.coords([0.,0.])

samples = np.zeros((N_samples,M.dim))
chartss = np.zeros((N_samples,x[1].shape[0]))
for i in range(N_samples):
    _ts, xs_path, charts_path = M.Brownian_coords(x,dWs(M.dim))
    # only the final point of each simulated path is retained
    samples[i], chartss[i] = xs_path[-1], charts_path[-1]

# plot the sampled endpoints on the sphere
newfig()
M.plot()
for sample, chart in zip(samples, chartss):
    M.plotx((sample, chart))
plt.show()

## Fréchet mean

In [None]:
# itertools is used below; import it explicitly rather than relying on it
# leaking through one of the star imports (which breaks Restart & Run All).
import itertools

from src.statistics.Frechet_mean import *

# Estimate the Fréchet mean of the samples via M.Log, starting from x.
# res is unpacked as (mean, loss, optimisation iterates).
res = Frechet_mean(lambda *args: M.Log(*args), zip(samples,chartss), x)
Fm = res[0]
print("loss = ", res[1])
print("mean = ", Fm)
iterations = res[2]

# Plot the mean and the optimisation path on the sphere.
newfig()
M.plot(rotate = np.array([50,-45]))
M.plotx(Fm)
# Each iterate is paired with x's chart (presumably the iterates are coordinates
# in that chart) so plot_path receives (coords, chart) points.
M.plot_path(zip(iterations,itertools.cycle((x[1],))))
plt.show()

## Tangent PCA

In [None]:
from src.statistics.tangent_PCA import *

from src.utils import *
from sklearn.decomposition import PCA

# Tangent PCA is carried out in the tangent space at the intrinsic (Fréchet)
# mean Fm from the previous section, not at the sampling origin x. Fm has the
# same (coords, chart) form as x (it is passed directly to M.plotx above).
pca = tangent_PCA(M, lambda *args: M.Log(*args), Fm, zip(samples,chartss))
print(pca.get_covariance())

# Scatter the first two columns of the transformed tangent vectors.
plt.scatter(pca.transformed_Logs[:, 0], pca.transformed_Logs[:, 1])
plt.xlabel('first component')
plt.ylabel('second component')
plt.title('Tangent PCA at the Fréchet mean')
plt.axis('equal')
plt.show()