In [None]:
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
#

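# $\mathbb{S}^2$ Sphere Geometry

Examples of geometry and stochastics on the 2-sphere: metric and Christoffel symbols, geodesics (from the geodesic and the Hamiltonian equations), curvature, parallel transport, Brownian motion and guided bridges, stochastic development, diffusions on product manifolds, and most probable paths.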

In [None]:
# Reload edited project modules (src/...) automatically before each cell executes
%load_ext autoreload
%autoreload 2

In [None]:
# S2 manifold; the sphere object M is used by every following cell
from src.manifolds.S2 import *
M = S2()
print(M)
# plotting helpers
# NOTE(review): later cells use jnp, np, plt, itertools, newfig, dts, dWs, n_steps and T
# without importing them, so they presumably arrive via the two star imports; confirm
from src.plotting import *
#%matplotlib notebook

In [None]:
# element and tangent vector (the covector p is formed from v in the metric cell)
# x is used elsewhere as x[0] (chart coordinates) and x[1] (chart)
x = M.coords([0.,0.])
v = jnp.array([1.,0.])

print("x = ", x)
print("v = ", v)


# F presumably maps chart coordinates into the embedding space; JF is its Jacobian and
# invJF the corresponding inverse map, evaluated at the embedded point F(x) in x's chart
print("F(x): ",M.F(x))
print("JF(x):\n",M.JF(x))
print("JF(x)^{-1}:\n",M.invJF((M.F(x),x[1])))

# plot: the sphere, the point x and the tangent vector v attached to it
newfig()
M.plot()
M.plotx(x,u=v,linewidth = 1.5, s=50)
plt.show()

In [None]:
## Riemannian structure
from src.Riemannian import metric
metric.initialize(M)

# metric g, cometric g^# and Christoffel symbols at x
print("g(x):\n",M.g(x))
print("g^#(x):\n",M.gsharp(x))
# the backslash is escaped: "\G" is an invalid escape sequence (SyntaxWarning on
# Python >= 3.12); the printed text is unchanged
print("\\Gamma(x):\n",M.Gamma_g(x))

# covector: lower the index of v with the metric (flat), then raise it again (sharp)
p = M.flat(x,v)
print("v: ",v,", p: ",p,", p^#: ",M.sharp(x,p))

## Riemannian Geodesics

In [None]:
# 2nd order geodesic equation
from src.Riemannian import geodesic
geodesic.initialize(M)

# compute geodesics: Expt returns the discretized curve (points and their charts)
# starting at x with initial velocity v
(xs,charts) = M.Expt(x,v)

# plot: initial point with velocity v, and the geodesic path
newfig()
M.plot()
M.plotx(x,u=v,linewidth = 1.5, s=50)
M.plot_path(zip(xs,charts),linewidth = 1.5, s=50)
plt.show()

### Geodesics from Hamiltonian equations

In [None]:
# # Hamiltonian dynamics
# initialize the Hamiltonian machinery before any M.H / M.Exp_Hamiltonian* call
from src.dynamics import Hamiltonian
Hamiltonian.initialize(M)

# start at x with the covector p = flat(v) from the metric cell; print the initial energy
q = x
print(M.H(q,p))

# Exponential map from Hamiltonian equations
(qs,charts) = M.Exp_Hamiltoniant(q,p)

# plot
newfig()
M.plot()
M.plotx(x,u=v,linewidth = 1.5, s=50)
M.plot_path(zip(qs,charts),linewidth = 1.5, s=50)
plt.show()

# dynamics returning both position and momentum: qps stacks them along axis 1
# (index 0 positions, index 1 momenta). Positions, momenta and charts are all taken from
# this one integration so that each triple passed to M.H belongs to the same time step.
(ts,qps,charts_H) = M.Hamiltonian_dynamics(q,p,dts())
qs_H = qps[:,0,:]
ps = qps[:,1,:]
# H is expected to stay (numerically) constant along the Hamiltonian flow
print("Energy: ",np.array([M.H((q_t,chart_t),p_t) for (q_t,p_t,chart_t) in zip(qs_H,ps,charts_H)]))

## Curvature

In [None]:
from src.Riemannian import curvature
curvature.initialize(M)

# Curvature tensor, Ricci and scalar curvature:
print("curvature = ", M.R(x))
print("Ricci curvature = ", M.Ricci_curv(x))
print("Scalar curvature = ", M.S_curv(x))

# Orthonormal basis under g: Cholesky factor of the cometric g^#(x)
# NOTE(review): nu is only used by the commented-out sectional-curvature line below
nu = jnp.linalg.cholesky(M.gsharp(x))

# # Sectional Curvature
# print("sectional curvature = ",M.sec_curv(x,nu[:,0],nu[:,1]))

## Parallel Transport

In [None]:
# Parallel transport
from src.Riemannian import parallel_transport
parallel_transport.initialize(M)

# transport along a prescribed curve that stays in a single chart
chart = M.chart()
w = np.array([-1./2,-1./2])
w = w/M.norm(x,w)  # unit-norm initial vector at x
t = np.cumsum(dts())
# curve (t^2, -sin t) in coordinates and its velocity (2t, -cos t)
xs = np.vstack([t**2,-np.sin(t)]).T
dxs = np.vstack([2*t,-np.cos(t)]).T
# one copy of the (fixed) chart per time step
pt_charts = np.tile(chart,(n_steps,1))

# compute  parallel transport
ws = M.parallel_transport(w,dts(),xs,pt_charts,dxs)
# parallel transport preserves the metric norm, so these should stay close to 1; each point
# is paired with this curve's own chart (`charts` still holds an earlier cell's path here)
print("ws norm: ",np.array([M.norm((x_t,chart_t),w_t) for (x_t,w_t,chart_t) in zip(xs,ws,pt_charts)]))

# plot result
newfig()
M.plot()
M.plot_path(zip(xs,itertools.cycle((chart,))),vs=ws,v_steps=np.arange(0,n_steps,5))
plt.show()

# along geodesic
# compute geodesic; xsdxs stacks (position, velocity) along axis 1
(ts,xsdxs,charts) = M.geodesic(x,v,dts())
xs = xsdxs[:,0,:]
dxs = xsdxs[:,1,:]
# compute  parallel transport
ws = M.parallel_transport(w,dts(),xs,charts,dxs)
print("ws norm: ",np.array([M.norm((x_t,chart_t),w_t) for (x_t,w_t,chart_t) in zip(xs,ws,charts)]))

# plot
newfig()
M.plot()
M.plot_path(zip(xs,charts),vs=ws,v_steps=np.arange(0,n_steps,5),linewidth=1.5, s=50)
plt.show()

## Brownian Motion

In [None]:
# coordinate form
from src.stochastics import Brownian_coords
Brownian_coords.initialize(M)

# single sample path with 1000 time steps
_dts = dts(n_steps=1000)
(ts,xs,charts) = M.Brownian_coords(x,_dts,dWs(M.dim,_dts))

# plot
newfig()
M.plot()
M.plot_path(zip(xs,charts))
plt.show()

# plot multiple sample paths on the default time grid. The storage arrays are sized from
# that grid, and the noise is drawn on it explicitly, so the two cannot get out of sync.
N = 5
_dts_paths = dts()
xss = np.zeros((N,_dts_paths.shape[0],M.dim))
chartss = np.zeros((N,_dts_paths.shape[0],x[1].shape[0]))
for i in range(N):
    (ts,xs,charts) = M.Brownian_coords(x,_dts_paths,dWs(M.dim,_dts_paths))
    xss[i] = xs
    chartss[i] = charts

# plot
M.newfig()
M.plot()
colormap = plt.get_cmap('winter')
colors=[colormap(k) for k in np.linspace(0, 1, N)]
for i in range(N):
    M.plot_path(zip(xss[i],chartss[i]),color=colors[i])
M.plotx(x,color='r',s=50)
plt.show()

In [None]:
# Delyon/Hu guided process
from src.stochastics.guided_process import *

# guide function
def phi(q,v,s):
    """Guiding term towards v: (1/s) chol(g(q))^T applied to StdLog(q, F((v, q[1])))."""
    metric_factor = (1./s)*jnp.linalg.cholesky(M.g(q)).T
    log_to_target = M.StdLog(q,M.F((v,q[1]))).flatten()
    return jnp.tensordot(metric_factor,log_to_target,(1,0))

def A(x,v,w,s):
    """Bilinear form s^{-2} v^T g(x) w."""
    g_w = jnp.dot(M.g(x),w)
    return (s**(-2))*jnp.dot(v,g_w)

def logdetA(x,s):
    """log|det(s^{-2} g(x))|, the log-magnitude part of slogdet."""
    _sign, log_abs_det = jnp.linalg.slogdet(s**(-2)*M.g(x))
    return log_abs_det

# plot guiding field: 0.2-scaled standard log towards the embedded point (0,0,-1)
M.newfig()
M.plot_field(lambda x: .2*M.StdLog(x,jnp.array([0,0,-1])))
plt.show()

# guided Brownian motion; the diffusion field passed here is s * chol(g^#(x))
(Brownian_coords_guided,sde_Brownian_coords_guided,chart_update_Brownian_coords_guided,_,_) = get_guided(
    M,M.sde_Brownian_coords,M.chart_update_Brownian_coords,phi,
    lambda x,s: s*jnp.linalg.cholesky(M.gsharp(x)),A,logdetA)

_dts = dts(n_steps=1000)

# target point w: exponential map from x with initial velocity (.8,-.5)
w = M.Exp(x,np.array([.8,-.5]))
# one guided bridge from x towards w; the trailing 1. is presumably the scale s seen by
# phi / A / logdetA — NOTE(review): confirm against get_guided
(ts,xs,charts,log_likelihood,log_varphi) = Brownian_coords_guided(x,w,_dts,dWs(M.dim,_dts),1.)
print("log likelihood: ", log_likelihood[-1], ", log varphi: ", log_varphi[-1])

# plot: the bridge path, start x (red) and target w (black)
newfig()
M.plot()
M.plot_path(zip(xs,charts))
M.plotx(x,color='r',s=150)
M.plotx(w,color='k',s=150)
plt.show()

# plot multiple bridges
# N independent bridges from x to w on the same time grid _dts; storage is sized from _dts
N = 5
xss = np.zeros((N,_dts.shape[0],M.dim))
chartss = np.zeros((N,_dts.shape[0],x[1].shape[0]))
for i in range(N):
    (ts,xs,charts,log_likelihood,log_varphi) = Brownian_coords_guided(x,w,_dts,dWs(M.dim,_dts),1.)
    xss[i] = xs
    chartss[i] = charts

# plot: one colour per bridge, start x (red), target w (black); also saved to
# S2_bridges.pdf in the current working directory
M.newfig()
M.plot()
colormap = plt.get_cmap('winter')
colors=[colormap(k) for k in np.linspace(0, 1, N)]
for i in range(N):
    M.plot_path(zip(xss[i],chartss[i]),color=colors[i])
M.plotx(x,color='r',s=100)
M.plotx(w,color='k',s=100)
plt.savefig('S2_bridges.pdf')
plt.show()

In [None]:
# development and Brownian motion from stochastic development
from src.framebundle import FM
from src.stochastics import stochastic_development
from src.stochastics import Brownian_development

FM.initialize(M)
stochastic_development.initialize(M)
Brownian_development.initialize(M)

# develop a curve: 50 time steps of the planar curve with velocity (2t, -cos t)
dev_dts = dts(n_steps=50)
t = np.cumsum(dev_dts)
dxs = np.column_stack((2*t, -np.cos(t)))

# initial frame at x: Cholesky factor of the cometric g^#(x)
nu = np.linalg.cholesky(M.gsharp(x))
# frame-bundle point: coordinates of x followed by the flattened frame, in x's chart
frame_coords = np.concatenate((x[0], nu.flatten()))
u = (frame_coords, x[1])
(ts,xs,charts) = M.development(u,dxs,dev_dts)
# optional orthonormality check of the developed frames:
# print("u.T*g*u: ",np.array([np.einsum('ji,jk,kl->il',u.reshape((M.dim,-1)),M.g((x,chart)),u.reshape((M.dim,-1))) for (x,u,chart) in zip(xs[:,0:M.dim],xs[:,M.dim:],charts)]))

# plot
newfig()
M.plot()
M.plot_path(zip(xs,charts))
plt.show()

# simulate Brownian Motion via stochastic development, 1000 time steps, starting at x
_dts = dts(n_steps=1000)
(ts,xs,charts) = M.Brownian_development(x,_dts,dWs(M.dim,_dts))

# plot
newfig()
M.plot()
M.plot_path(zip(xs,charts))
plt.show()

## Brownian motion on a product manifold

In [None]:
# product sde: N copies of the coordinate Brownian motion SDE, integrated jointly
from src.stochastics import product_sde
from src.stochastics.product_sde import tile
(product,sde_product,chart_update_product) = product_sde.initialize(M,M.sde_Brownian_coords,M.chart_update_Brownian_coords)

N = 5
_dts = dts(n_steps=500,T=1.)
# tile(x,N) presumably gives N copies of x; the noise is reshaped to (steps, N, M.dim) so
# each copy gets its own increments; jnp.repeat(1.,N) is a per-copy parameter
# (presumably the diffusion scale — NOTE(review): confirm in product_sde)
(ts,xss,chartss,*_) = product(tile(x,N),_dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim),jnp.repeat(1.,N))

# plot: copy i is xss[:,i] (axis 1 indexes the copies)
M.newfig()
M.plot()
colormap = plt.get_cmap('winter')
colors=[colormap(k) for k in np.linspace(0, 1, N)]
for i in range(N):
    M.plot_path(zip(xss[:,i],chartss[:,i]),color=colors[i])
M.plotx(x,color='r')
plt.show()

In [None]:
# condition on diagonal of product manifold, starting from the endpoints (xss[-1],
# chartss[-1]) of the product-manifold paths computed in the previous cell
from src.stochastics import diagonal_conditioning
diagonal_conditioning.initialize(M,sde_product,chart_update_product)

_dts = dts(n_steps=100,T=.01)
# results are stored under new names so that re-running this cell starts again from the
# product-manifold endpoints instead of from its own previous output
(ts,xss_diag,chartss_diag) = M.diagonal((xss[-1],chartss[-1]),_dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim),x[1],jnp.repeat(1.,N))

# plot: one colour per copy (axis 1 indexes the copies)
M.newfig()
M.plot()
colormap = plt.get_cmap('winter')
colors=[colormap(k) for k in np.linspace(0, 1, N)]
for i in range(N):
    M.plot_path(zip(xss_diag[:,i],chartss_diag[:,i]),color=colors[i])
plt.show()

## Most probable paths

In [None]:
# forward mpp equations, from Anisotropic covariance on manifolds and most probable paths,
# Erlend Grong and Stefan Sommer, 2021
from src.framebundle import MPP
MPP.initialize(M)

# integrate mpp
# initial frame-bundle point u: coordinates of x followed by the flattened frame nu = chol(g^#(x))
nu = jnp.linalg.cholesky(M.gsharp(x))
u = (jnp.concatenate((x[0],nu.flatten())),x[1])
# lamb presumably scales the frame directions (cf. the |v/lamb| check below); v and chi are
# the initial values of the MPP variables — NOTE(review): see Grong & Sommer for their meaning
lamb = jnp.array([2.,.25])
v = jnp.array([2.,0.])
chi = jnp.array([-6.])
print("lambda:\n",lamb,"\nnu:\n",nu,"\nv: ",v,"\nchi:\n",chi)
# timed twice: presumably the first timing includes JIT compilation, the second does not
%time (xs,vs,chis,charts) = M.MPP_forwardt(u,lamb,v,chi)
%time (xs,vs,chis,charts) = M.MPP_forwardt(u,lamb,v,chi)

# plot the base path (first M.dim coordinates of each frame-bundle point)
newfig()
M.plot()
M.plot_path(zip(xs[:,0:M.dim],charts))
plt.show()

# plot the full frame-bundle path (with frames)
newfig()
M.plot()
M.plot_path(zip(xs,charts))
plt.show()

# anti-development
# |v(t)/lamb| along the path
print("Svs norm: ",np.array([np.linalg.norm(v/lamb) for v in vs]))
# cumulative sum of v(t) dt; assumes MPP_forwardt integrates on the default grid dts(T,n_steps)
axs = np.cumsum(vs*dts(T,n_steps)[:,np.newaxis],axis=0)
plt.plot(axs[:,0],axs[:,1],'b',label='anti-development')
plt.plot(vs[:,0],vs[:,1],'r',label='v(t)')
plt.legend()
plt.show()

In [None]:
# target point: endpoint y of the geodesic from x with initial velocity (.3,.3)
v = jnp.array([.3,.3])
(xs,charts) = M.Expt(x,v)
y = (xs[-1],charts[-1])

# initial frame-bundle point at x with frame nu = chol(g^#(x))
nu = jnp.linalg.cholesky(M.gsharp(x))
u = (jnp.concatenate((x[0],nu.flatten())),x[1])

# find MPP
# M.MPP presumably solves for the initial (v, chi) whose forward MPP from u reaches y
lamb = jnp.array([.5,.5])
v,chi = M.MPP(u,lamb,y)

print("lambda:\n",lamb,"\nnu:\n",nu,"\nv: ",v,"\nchi:\n",chi)
(xs,vs,chis,charts) = M.MPP_forwardt(u,lamb,v,chi)
print("chiT:\n",chis[-1])

# plot: x with its lamb-scaled frame, target y (red) and the base path
# NOTE(review): 'i,ij->ij' scales row i of nu by lamb[i]; if the frame vectors are the
# columns of nu, 'j,ij->ij' would be the matching scaling — confirm
newfig()
M.plot()
M.plotx(x,u=np.einsum('i,ij->ij',lamb,nu),linewidth = 1.5, s=50)
M.plotx(y,linewidth = 1.5, s=50, color='r')
M.plot_path(zip(xs[:,0:M.dim],charts))
fig = plt.gcf(); ax = fig.gca(); ax.view_init(30, 30) # rotate
plt.show()

# anti-development
# |v(t)/lamb| along the path, then the cumulative sum of v(t) dt
# (assumes MPP_forwardt integrates on the default grid dts(T,n_steps))
print("Svs norm: ",np.array([np.linalg.norm(v/lamb) for v in vs]))
axs = np.cumsum(vs*dts(T,n_steps)[:,np.newaxis],axis=0)
plt.plot(axs[:,0],axs[:,1])
plt.show()


# find MPP
# same target y and frame-bundle point u as above, now with anisotropic lamb
lamb = jnp.array([.5,1.5])
v,chi = M.MPP(u,lamb,y)

print("lambda:\n",lamb,"\nnu:\n",nu,"\nv: ",v,"\nchi:\n",chi)
(xs,vs,chis,charts) = M.MPP_forwardt(u,lamb,v,chi)
print("chiT:\n",chis[-1])

# plot: x with its lamb-scaled frame (rows of nu scaled by lamb), target y (red), base path
newfig()
M.plot()
M.plotx(x,u=np.einsum('i,ij->ij',lamb,nu),linewidth = 1.5, s=50)
M.plotx(y,linewidth = 1.5, s=50, color='r')
M.plot_path(zip(xs[:,0:M.dim],charts))
fig = plt.gcf(); ax = fig.gca(); ax.view_init(30, 0) # rotate
# optional alternative view / export for the paper figure:
# fig = plt.gcf(); ax = fig.gca(); ax.view_init(90, 0) # rotate
# plt.axis('off')
# plt.savefig('S2_MPP_lambda_05_15.pdf')
plt.show()

# with frame: plot the full frame-bundle path instead of only its base coordinates
newfig()
M.plot()
M.plotx(x,linewidth = 1.5, s=50)
M.plotx(y,linewidth = 1.5, s=50, color='r')
M.plot_path(zip(xs,charts))
fig = plt.gcf(); ax = fig.gca(); ax.view_init(30, 0) # rotate
plt.show()

# anti-development
# |v(t)/lamb| along the path, then the cumulative sum of v(t) dt
# (assumes MPP_forwardt integrates on the default grid dts(T,n_steps))
print("Svs norm: ",np.array([np.linalg.norm(v/lamb) for v in vs]))
axs = np.cumsum(vs*dts(T,n_steps)[:,np.newaxis],axis=0)
plt.plot(axs[:,0],axs[:,1])
plt.show()