In [None]:
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
#

# $\mathbb{S}^2$ Sphere Geometry

In [None]:
# Reload edited project modules (src.*) automatically before each cell runs,
# so changes to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Construct the 2-sphere manifold object used throughout this notebook.
# NOTE(review): the star imports below appear to be the only source of several
# names used later without an explicit import (jnp, np, plt, jax, itertools,
# newfig, dts, dWs, n_steps, T) — TODO confirm which module provides each.
from src.manifolds.S2 import *
M = S2()
print(M)
# plotting helpers (newfig, plt, ... presumably) — star import, TODO confirm
from src.plotting import *
#%matplotlib notebook

In [None]:
# A point on S^2 (coordinates plus chart) and a tangent vector at that point
x = M.coords([0.,0.])
v = jnp.array([1.,0.])

for label, value in (("x = ", x), ("v = ", v)):
    print(label, value)

# Embedding F, its Jacobian JF, and the inverse Jacobian evaluated at the
# embedded point F(x) in the chart of x
Fx = M.F(x)
print("F(x): ",Fx)
print("JF(x):\n",M.JF(x))
print("JF(x)^{-1}:\n",M.invJF((Fx,x[1])))

# Draw the sphere together with x and the tangent vector v
newfig()
M.plot()
M.plotx(x, u=v, linewidth=1.5, s=50)
plt.show()

In [None]:
## Riemannian structure
# Attach the metric-derived quantities to M (g, gsharp, Gamma_g, flat, sharp,
# presumably among others — these are the ones used below).
from src.Riemannian import metric
metric.initialize(M)

print("g(x):\n",M.g(x))
print("g^#(x):\n",M.gsharp(x))
# "\\G" keeps the backslash literal. A bare "\G" is an invalid escape sequence
# (DeprecationWarning, SyntaxWarning since Python 3.12) with the same output.
print("\\Gamma(x):\n",M.Gamma_g(x))

# Lower v to the covector p = v^flat. Raising p again (p^#) should give back v.
p = M.flat(x,v)
print("v: ",v,", p: ",p,", p^#: ",M.sharp(x,p))

## Riemannian Geodesics

In [None]:
# Geodesic from the second-order geodesic equation, started at x with velocity v
from src.Riemannian import geodesic
geodesic.initialize(M)

# Expt yields the path as point coordinates xs with a matching sequence of
# charts (zipped together for plot_path below)
xs, charts = M.Expt(x,v)

# Draw the sphere, the initial point/velocity and the geodesic
newfig()
M.plot()
M.plotx(x, u=v, linewidth=1.5, s=50)
M.plot_path(zip(xs, charts), linewidth=1.5, s=50)
plt.show()

### Geodesics from Hamiltonian equations

In [None]:
# Hamiltonian dynamics: set up M.H and the Hamiltonian flow before using them,
# so this cell does not depend on initialisation order on a fresh kernel.
from src.dynamics import Hamiltonian
Hamiltonian.initialize(M)

# Start the flow at x with the covector p = v^flat from the metric cell
q = x
print(M.H(q,p))

# Exponential map from Hamiltonian equations
(qs,charts) = M.Exp_Hamiltoniant(q,p)

# plot
newfig()
M.plot()
M.plotx(x,u=v,linewidth = 1.5, s=50)
M.plot_path(zip(qs,charts),linewidth = 1.5, s=50)
plt.show()

# Dynamics returning both position and momentum. Positions, momenta and charts
# are all taken from this single integration, so the energy is evaluated on
# matching (q, p, chart) triples. The Hamiltonian should be (approximately)
# conserved along the flow.
(ts,qps,qp_charts) = M.Hamiltonian_dynamics(q,p,dts())
qs_dyn = qps[:,0,:]
ps = qps[:,1,:]
print("Energy: ",np.array([M.H((_q,_chart),_p) for (_q,_p,_chart) in zip(qs_dyn,ps,qp_charts)]))

## Curvature

In [None]:
# Curvature quantities at x: Riemann curvature tensor, Ricci curvature, scalar curvature
from src.Riemannian import curvature
curvature.initialize(M)

curvature_reports = (
    ("curvature = ", M.R),
    ("Ricci curvature = ", M.Ricci_curv),
    ("Scalar curvature = ", M.S_curv),
)
for label, curv_fn in curvature_reports:
    print(label, curv_fn(x))

# Cholesky factor L of g^#(x) = g^{-1}. Since L^T g L = I, the columns of L
# form a g-orthonormal basis of the tangent space at x.
nu = jnp.linalg.cholesky(M.gsharp(x))

# # Sectional Curvature
# print("sectional curvature = ",M.sec_curv(x,nu[:,0],nu[:,1]))

## Parallel Transport

In [None]:
# Parallel transport
from src.Riemannian import parallel_transport
parallel_transport.initialize(M)

chart = M.chart()
# Vector to transport, normalised with respect to g at x
w = np.array([-1./2,-1./2])
w = w/M.norm(x,w)
# A curve given in coordinates of the single chart `chart`, and its velocity
t = np.cumsum(dts())
xs = np.vstack([t**2,-np.sin(t)]).T
dxs = np.vstack([2*t,-np.cos(t)]).T
# The whole curve lives in `chart`, so use one copy of it per step. Both the
# transport and the norm check below use these charts, instead of the
# unrelated `charts` left over from an earlier cell.
curve_charts = np.tile(chart,(n_steps,1))

# compute parallel transport
ws = M.parallel_transport(w,dts(),xs,curve_charts,dxs)
# Parallel transport preserves the metric norm, so these values should stay
# approximately constant.
print("ws norm: ",np.array([M.norm((_x,_chart),_w) for (_x,_w,_chart) in zip(xs,ws,curve_charts)]))

# plot result
newfig()
M.plot()
M.plot_path(zip(xs,itertools.cycle((chart,))),vs=ws,v_steps=np.arange(0,n_steps,5))
plt.show()

# Along a geodesic: M.geodesic returns positions and velocities together, plus
# per-step charts
(ts,xsdxs,charts) = M.geodesic(x,v,dts())
xs = xsdxs[:,0,:]
dxs = xsdxs[:,1,:]
# compute parallel transport
ws = M.parallel_transport(w,dts(),xs,charts,dxs)
print("ws norm: ",np.array([M.norm((_x,_chart),_w) for (_x,_w,_chart) in zip(xs,ws,charts)]))

# plot
newfig()
M.plot()
M.plot_path(zip(xs,charts),vs=ws,v_steps=np.arange(0,n_steps,5),linewidth=1.5, s=50)
plt.show()

## Brownian Motion

In [None]:
# Brownian motion integrated directly in coordinates
from src.stochastics import Brownian_coords
Brownian_coords.initialize(M)

# One long sample path with 1000 time steps
_dts = dts(n_steps=1000)
ts, xs, charts = M.Brownian_coords(x, _dts, dWs(M.dim, _dts))

newfig()
M.plot()
M.plot_path(zip(xs, charts))
plt.show()

# N sample paths using the default time steps, stored in preallocated arrays
# (sized with n_steps, so dts() is assumed to have n_steps entries)
N = 5
xss = np.zeros((N, n_steps, M.dim))
chartss = np.zeros((N, n_steps, x[1].shape[0]))
for i in range(N):
    ts, xs, charts = M.Brownian_coords(x, dts(), dWs(M.dim))
    xss[i], chartss[i] = xs, charts

# Draw all sample paths, coloured along the 'winter' colormap, with x in red
M.newfig()
M.plot()
colormap = plt.get_cmap('winter')
colors = list(map(colormap, np.linspace(0, 1, N)))
for path_x, path_chart, color in zip(xss, chartss, colors):
    M.plot_path(zip(path_x, path_chart), color=color)
M.plotx(x, color='r', s=50)
plt.show()

In [None]:
# Frame bundle, (stochastic) development, and Brownian motion via development
from src.framebundle import FM
from src.stochastics import stochastic_development
from src.stochastics import Brownian_development

for module in (FM, stochastic_development, Brownian_development):
    module.initialize(M)


def _show_path(path_xs, path_charts):
    """Draw the sphere with a single path given by points and matching charts."""
    newfig()
    M.plot()
    M.plot_path(zip(path_xs, path_charts))
    plt.show()


# Develop a deterministic curve over 50 steps. The planar velocities dxs are
# developed from the frame u = (x, nu), where nu is the Cholesky factor of g^#(x).
n_dev = 50
t = np.cumsum(dts(n_steps=n_dev))
dxs = np.vstack([2*t,-np.cos(t)]).T
nu = np.linalg.cholesky(M.gsharp(x))
u = (np.concatenate((x[0],nu.flatten())),x[1])
(ts,xs,charts) = M.development(u,dxs,dts(n_steps=n_dev))
# print("u.T*g*u: ",np.array([np.einsum('ji,jk,kl->il',u.reshape((M.dim,-1)),M.g((x,chart)),u.reshape((M.dim,-1))) for (x,u,chart) in zip(xs[:,0:M.dim],xs[:,M.dim:],charts)]))

_show_path(xs, charts)

# Brownian motion from stochastic development, 1000 time steps
_dts = dts(n_steps=1000)
(ts,xs,charts) = M.Brownian_development(x,_dts,dWs(M.dim,_dts))

_show_path(xs, charts)

## Most probable paths

In [None]:
# forward mpp equations, from Anisotropic covariance on manifolds and most probable paths,
# Erlend Grong and Stefan Sommer, 2021
from src.framebundle import MPP
MPP.initialize(M)

# integrate mpp
nu = jnp.linalg.cholesky(M.gsharp(x))
u = (jnp.concatenate((x[0],nu.flatten())),x[1])
lamb = jnp.array([2.,.25])
v = jnp.array([2.,0.])
chi = jnp.array([-6.])
print("lambda:\n",lamb,"\nnu:\n",nu,"\nv: ",v,"\nchi:\n",chi)
%time (xs,vs,chis,charts) = M.MPP_forwardt(u,lamb,v,chi)
%time (xs,vs,chis,charts) = M.MPP_forwardt(u,lamb,v,chi)

# plot
newfig()
M.plot()
M.plot_path(zip(xs[:,0:M.dim],charts))
plt.show()

# plot
newfig()
M.plot()
M.plot_path(zip(xs,charts))
plt.show()

# anti-development
print("Svs norm: ",np.array([np.linalg.norm(v/lamb) for v in vs]))
axs = np.cumsum(vs*dts(T,n_steps)[:,np.newaxis],axis=0)
plt.plot(axs[:,0],axs[:,1],'b',label='anti-development')
plt.plot(vs[:,0],vs[:,1],'r',label='v(t)')
plt.legend()
plt.show()

In [None]:
# Target point y: the endpoint of the geodesic from x with initial velocity (.3, .3)
v = jnp.array([.3,.3])
(xs,charts) = M.Expt(x,v)
y = (xs[-1],charts[-1])

# Initial frame u = (x, nu), with nu the Cholesky factor of g^#(x)
nu = jnp.linalg.cholesky(M.gsharp(x))
u = (jnp.concatenate((x[0],nu.flatten())),x[1])


def _solve_mpp(lamb):
    """Find the MPP from the frame u to y for weights lamb, then integrate it.

    Prints the initial conditions returned by M.MPP and the final chi after
    forward integration. Returns (v, chi, xs, vs, chis, charts).
    """
    v,chi = M.MPP(u,lamb,y)
    print("lambda:\n",lamb,"\nnu:\n",nu,"\nv: ",v,"\nchi:\n",chi)
    xs,vs,chis,charts = M.MPP_forwardt(u,lamb,v,chi)
    print("chiT:\n",chis[-1])
    return v,chi,xs,vs,chis,charts


def _show_antidevelopment(vs,lamb):
    """Print |v/lamb| along the path and plot the anti-development sum of v dt."""
    print("Svs norm: ",np.array([np.linalg.norm(_v/lamb) for _v in vs]))
    axs = np.cumsum(vs*dts(T,n_steps)[:,np.newaxis],axis=0)
    plt.plot(axs[:,0],axs[:,1])
    plt.show()


# --- equal weights ---
lamb = jnp.array([.5,.5])
v,chi,xs,vs,chis,charts = _solve_mpp(lamb)

newfig()
M.plot()
M.plotx(x,u=np.einsum('i,ij->ij',lamb,nu),linewidth = 1.5, s=50)
M.plotx(y,linewidth = 1.5, s=50, color='r')
M.plot_path(zip(xs[:,0:M.dim],charts))
fig = plt.gcf()
ax = fig.gca()
ax.view_init(30, 30) # rotate
plt.show()

_show_antidevelopment(vs,lamb)

# --- unequal weights ---
lamb = jnp.array([.5,1.5])
v,chi,xs,vs,chis,charts = _solve_mpp(lamb)

newfig()
M.plot()
M.plotx(x,u=np.einsum('i,ij->ij',lamb,nu),linewidth = 1.5, s=50)
M.plotx(y,linewidth = 1.5, s=50, color='r')
M.plot_path(zip(xs[:,0:M.dim],charts))
fig = plt.gcf()
ax = fig.gca()
ax.view_init(30, 0) # rotate
# fig = plt.gcf(); ax = fig.gca(); ax.view_init(90, 0) # rotate
# plt.axis('off')
# plt.savefig('S2_MPP_lambda_05_15.pdf')
plt.show()

# Same path, drawn with the full state (point + frame)
newfig()
M.plot()
M.plotx(x,linewidth = 1.5, s=50)
M.plotx(y,linewidth = 1.5, s=50, color='r')
M.plot_path(zip(xs,charts))
fig = plt.gcf()
ax = fig.gca()
ax.view_init(30, 0) # rotate
plt.show()

_show_antidevelopment(vs,lamb)

In [None]:
# sample data

# simulate Brownian Motion
# %time _,xss,chartss=jax.vmap(lambda dWs: M.Brownian_coords(x,dWs))(dWs(M.dim,n_steps=1000,num=16))
# obss = xss[:,-1]
# obs_charts = chartss[:,-1]

# simulate anisotropic Brownian Motion
lamb = jnp.array([.6,.25])
nu = jnp.einsum('i,ij->ij',lamb,np.linalg.cholesky(M.gsharp(x)))
u = (np.concatenate((x[0],nu.flatten())),x[1])
(ts,us,charts) = M.stochastic_development(u,dts(),dWs(M.dim))
xs = us[:,0:M.dim]

# plot
newfig()
M.plot()
M.plot_path(zip(us,charts))
plt.show()

%time _,uss,chartss=jax.vmap(lambda dWs: M.stochastic_development(u,dts(),dWs))(dWs(M.dim,num=64))
obss = uss[:,-1,0:M.dim]
obs_charts = chartss[:,-1]

# plot
newfig()
M.plot()
for (_x,_chart) in zip(obss,obs_charts):
    M.plotx((_x,_chart))
plt.savefig('S2_samples_lamb_05_005.pdf')
plt.show()

In [None]:
# Estimate the mean point and anisotropy weights from the sampled endpoints.
# from src.framebundle import MPP
# MPP.initialize(M)

# Observations as (coordinates, chart) pairs
ys = list(zip(obss,obs_charts))
# NOTE(review): `chart` is the global M.chart() assigned in the parallel-transport
# cell. This cell fails on a fresh kernel unless that cell has run.
# The returns appear to be the estimated base point _x, weights _lamb, and one
# (v, chi) per observation (the next cell pairs vs[i] with obss[i]) — TODO confirm.
(_x,_lamb,vs,chis) = M.MPP_mean(x,chart,ys)

# NOTE(review): the disabled variance code below references f, get_params45,
# opt_state45 and N. Of these, only N is defined in this notebook.
# # compute variance
# var = 1/(N*M.dim)*jnp.sum(jnp.array([f(chart,_x,_lamb,v,chi) for (v,chi) in get_params45(opt_state45)]))
# print(_lamb,var)
# # _lamb = _lamb*var
# # print(_lamb)

# print(lamb/jnp.sqrt(jnp.prod(lamb)))
# print(_lamb/jnp.sqrt(jnp.prod(_lamb)))

In [None]:
# Frame at the estimated mean point, scaled for display by the estimated weights
_nu = jnp.linalg.cholesky(M.gsharp((_x,chart)))
print("x: ",(_x,chart),"\nlambda:\n",_lamb,"\nnu:\n",_nu)
_u = (jnp.hstack((_x,_nu.flatten())),chart)

newfig()
M.plot()
M.plotx((_x,chart),u=np.einsum('i,ij->ij',_lamb,_nu),linewidth = 1.5, s=50)

# Integrate the fitted MPP towards each observation, and draw the observation
# (red) together with the path's base-point trajectory
for v, chi, obs_x, obs_chart in zip(vs, chis, obss, obs_charts):
    (xs,_,_chis,charts) = M.MPP_forwardt(_u,_lamb,v,chi)
    print("v: ",v,", chi: ",chi, ", chiT: ",_chis[-1])
    M.plotx((obs_x,obs_chart),linewidth = 1.5, s=50, color='r')
    M.plot_path(zip(xs[:,0:M.dim],charts))

plt.axis('off')
# plt.savefig('S2_estimation_lambda_06_025.pdf')
plt.show()