In [None]:
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
#

# Most probable paths and development, Lie groups

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%cd ..

# SO(3): build the group object G that every later cell operates on.
# NOTE(review): names such as jnp, jax, plt, jacrev and dts are never imported
# explicitly in this notebook; presumably they arrive through these star
# imports — confirm before replacing them with explicit imports.
from src.groups.SON import *
G = SON(3,invariance='right')  # 3 -> SO(3); invariance chosen via keyword
print(G)

# plotting helpers (newfig is presumably defined here — verify)
from src.plotting import *

In [None]:
# setup for testing different versions of stochastic dynamics
# q is mapped to the group element g by G.psi; g is the start point
# used by the path computations in the cells below.
q = jnp.array([1e-3,0.,0.])
g = G.psi(q)
v = jnp.array([0.,1.,1.])

# geodesics
# v is converted with G.VtoLA and passed to G.expt, which returns (ts,gsv);
# the last entry of gsv is kept as y, the target of the boundary value
# problem solved later with G.MPP(g,y,sigma).
xiv=G.VtoLA(v)
(ts,gsv) = G.expt(xiv)
y = gsv[-1]

# G.sharppsi and G.sharpV are presumably attached to G by this initialize
# call — TODO confirm. p and mu are not read again in the visible cells.
from src.group import invariant_metric
invariant_metric.initialize(G)
p = G.sharppsi(q,v)
mu = G.sharpV(v)

from src.group import energy
energy.initialize(G)

In [None]:
# Initialize the most-probable-path dynamics on G with a diagonal Sigma and
# a drift callable `a` that takes a single argument t and returns a constant vector.
# NOTE(review): G.mpp / G.mpprec used below are presumably attached to G by
# this call — confirm. Note that the keyword `Sigma` here and the variable
# `sigma` below are different objects with different values.
from src.dynamics import MPP_group
MPP_group.initialize(G,Sigma=jnp.diag(jnp.array([.3,2.,1.])),a=lambda t: jnp.array([1.,0.,0.]))

# forward equations
# sigma and alpha are also reused (alpha is overwritten) by later cells.
sigma = jnp.diag(jnp.array([2.,1.,.5]))
alpha = jnp.array([0.,0.,-1.])

# Initial value problem: G.mpp is called with the initial alpha, and its
# output alphas is passed to G.mpprec together with the start point g.
_dts = dts()
(ts,alphas) = G.mpp(alpha,_dts,sigma)
(ts,gs) = G.mpprec(g,alphas,_dts,sigma)

# plot
# savefig is called before show so the figure is written while still current.
newfig()
G.plot_path(gs)
plt.savefig('MPP_SO3_IVP.pdf')
plt.show()

In [None]:
# mpp between g and y
# Boundary value problem: G.MPP returns the initial alpha for the pair (g,y);
# the path is then obtained with the same G.mpp / G.mpprec calls as in the
# initial value problem. Relies on g, y, sigma and _dts from earlier cells.
alpha = G.MPP(g,y,sigma)
(ts,alphas) = G.mpp(alpha,_dts,sigma)
(ts,gs) = G.mpprec(g,alphas,_dts,sigma)

# plot
# start point g in blue, target y in black, then the computed path
newfig()
G.plotg(g,color='b')
G.plotg(y,color='k')
G.plot_path(gs)
plt.savefig('MPP_SO3_BVP.pdf')
plt.show()

# Most probable paths and development, homogeneous spaces

In [None]:
# SO(3) acts on S^2
from src.manifolds.S2 import *
M = S2()
print(M)

# horz_vert_split, used in later cells, is presumably provided by this
# star import — confirm.
from src.group.quotient import *

# base point and projection
# x is the point with coordinates (0,0) as produced by M.coords; x[1] is
# indexed later, so x is a sequence with at least two entries.
# proj maps a group element g to M.act(g, M.F(x)).
x = M.coords(jnp.array([0.,0.]))
proj = lambda g: M.act(g,M.F(x))

In [None]:
# vector field and lift
def f(x):
    """Vector field on M: M.StdLog from x towards a fixed target point.

    The target is M.F(M.coords([pi/2, 0])), rebuilt on each call.
    """
    target = M.F(M.coords(jnp.array([jnp.pi/2.,0.])))
    return 1.*M.StdLog(x,target)

def f_emb(x):
    """Same field as f, but computed with M.StdLogEmb instead of M.StdLog."""
    target = M.F(M.coords(jnp.array([jnp.pi/2.,0.])))
    return 1.*M.StdLogEmb(x,target)
def hf_LA(g): # lift of f, translated to Lie algebra
    """Least-squares solution of dproj @ a = f_emb(x) at the group element g.

    dproj is the Jacobian of proj at g contracted with the frame returned
    as the first output of horz_vert_split(g, proj, identity, G, M).

    NOTE(review): f_emb is evaluated at the notebook-level base point x,
    not at the point that g projects to — confirm this is intended.
    """
    # only the first of the five outputs (the frame) is needed here
    frame,_,_,_,_ = horz_vert_split(g,proj,jnp.eye(G.dim),G,M)
    jac_proj = jacrev(proj)(g)
    dproj = jnp.einsum('...ij,ijk->...k',jac_proj,frame)
    # lstsq returns a tuple; its first entry is the solution
    solution = jnp.linalg.lstsq(dproj,f_emb(x))
    return solution[0]

# plot field
# Draw the vector field f on M and write the figure to a PDF before showing it.
M.newfig()
M.plot_field(f,scale=.25)
plt.savefig('MPP_S2_field.pdf')
plt.show()

In [None]:
# Re-initialize the MPP dynamics with identity Sigma and a drift callable
# that takes (t,g) and returns hf_LA(g) wrapped in stop_gradient.
# NOTE(review): the earlier initialize call passed a one-argument `a`; here
# `a` takes two arguments — presumably G.mpp_drift expects the (t,g) form.
MPP_group.initialize(G,Sigma=jnp.diag(jnp.array([1.,1.,1.])),a=lambda t,g: jax.lax.stop_gradient(hf_LA(g)))

# forward equations
sigma = jnp.diag(jnp.array([2.,1.,.5]))
alpha = jnp.array([0.,0.,-1.])
# third output of horz_vert_split at G.e is applied to alpha as a matrix
# (presumably the horizontal projection, going by the variable name)
proj_horz = horz_vert_split(G.e,proj,jnp.eye(G.dim),G,M)[2]
alpha = jnp.dot(proj_horz,alpha)
print(proj_horz,alpha)

_dts = dts()
(ts,alphags) = G.mpp_drift(alpha,g,_dts,sigma)
# each row of alphags is split at column G.dim; the remainder is reshaped
# into a stack of G.dim x G.dim matrices (the first G.dim columns
# presumably hold alpha — verify against G.mpp_drift)
gs = alphags[:,G.dim:].reshape((-1,G.dim,G.dim))

# plot
# path in the group
newfig()
G.plot_path(gs)
plt.savefig('MPP_S2_IVP_lift.pdf')
plt.show()

# plot
# path on M obtained by applying M.acts with gs to M.F(x); start point in blue
newfig()
M.plot()
M.plot_path(M.acts(gs,M.F(x)))
M.plotx(proj(g),color='b')
plt.savefig('MPP_S2_IVP.pdf')
plt.show()

In [None]:
# mpp between g and y
# Target y is built from the endpoint of the previous cell's path: proj(gs[-1])
# is passed through M.invF together with x[1], and paired with x[1] again
# (x[1] is presumably the chart of the base point — confirm).
# This overwrites the group-valued y defined in the setup cell.
# An alternative fixed target is kept here for reference:
#   M.coords(jnp.array([jnp.pi/8,-.4]))
y = (M.invF((proj(gs[-1]),x[1])),x[1])
alpha = G.MPP_drift(g,y,proj,M,sigma)
print(alpha)
(ts,alphags) = G.mpp_drift(alpha,g,_dts,sigma)
# same column split / reshape as in the initial value problem cell
gs = alphags[:,G.dim:].reshape((-1,G.dim,G.dim))

# plot
# path in the group
newfig()
G.plot_path(gs)
plt.savefig('MPP_S2_BVP_lift.pdf')
plt.show()

# plot
# path on M, with start point in blue and target y in black
newfig()
M.plot()
M.plot_path(M.acts(gs,M.F(x)))
M.plotx(proj(g),color='b')
M.plotx(y,color='k')
plt.savefig('MPP_S2_BVP.pdf')
plt.show()