In [None]:
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
#

# SO(3) group operations and dynamics

In [None]:
# IPython autoreload: re-import edited jaxgeometry modules before each cell executes,
# so library changes are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
# SO(3) is constructed as the special orthogonal group SON with N=3
from jaxgeometry.groups.SON import *
G = SON(3)
print(G)
# NOTE(review): jnp, np, plt and newfig are used in later cells without an explicit
# import; presumably they are re-exported by these star imports — confirm before
# replacing the star imports with explicit ones.
from jaxgeometry.plotting import *
# uncomment the next line to switch matplotlib to the interactive notebook backend
#%matplotlib notebook

In [None]:
# visualization of the group's identity element G.e
newfig()
G.plotg(G.e)
plt.show()

# geodesics (G.expt) in three directions: each vector v in R^3 is mapped to the
# Lie algebra with G.VtoLA and the resulting curve is plotted in its own figure.
# After the loop, gsv holds the curve for the last direction (0,0,1); the S2 cell
# below reuses it, so keep that direction last.
for v in (jnp.array([1,0,0]), jnp.array([0,1,0]), jnp.array([0,0,1])):
    xiv=G.VtoLA(v)
    (ts,gsv) = G.expt(xiv)
    newfig()
    G.plot_path(gsv)
    plt.show()

In [None]:
# plot path on S2: the SO(3) curve gsv is applied to a point on the sphere via M.acts
from jaxgeometry.manifolds.S2 import *
M = S2()
print(M)

# plot
newfig()
M.plot()
# base point obtained from chart coordinates (0,0) through M.F
# (presumably the embedding of S2 into R^3 — verify)
x = M.F(M.coords([0.,0.]))
# NOTE(review): gsv is not defined in this cell; it is whatever the previous cell
# left behind (the geodesic for direction (0,0,1)). Running cells out of order
# plots a different curve here.
M.plot_path(M.acts(gsv,x))
plt.show()

In [None]:
# setup for testing different versions of dynamics
# q: small coordinate vector; G.psi(q) is the starting group element of the geodesics below
q = jnp.array([1e-3,0.,0.])
# NOTE(review): g is not referenced by the later cells in this notebook (they call
# G.psi(q) directly); presumably kept for interactive inspection.
g = G.psi(q)
# v: vector in R^3 used as the initial-velocity argument of the Exp*t geodesic calls below
v = jnp.array([0.,1.,1.])

# initialize the invariant metric first; presumably this is what provides
# G.sharppsi / G.sharpV used on the next lines — keep this order
from jaxgeometry.group import invariant_metric
invariant_metric.initialize(G)
# p: v with its index lowered at q in psi-coordinates (used as the momentum in the
# Hamiltonian cell); mu: v lowered via G.sharpV (initial value for G.EP / G.LP)
p = G.sharppsi(q,v)
mu = G.sharpV(v)
print(p)
print(mu)

# presumably attaches the energy functions (e.g. G.l, G.Hminus) used in the
# dynamics cells below — verify against jaxgeometry.group.energy
from jaxgeometry.group import energy
energy.initialize(G)

In [None]:
# Euler-Poincare dynamics
from jaxgeometry.group import EulerPoincare
EulerPoincare.initialize(G)

# geodesic starting at G.psi(q) with initial velocity v
(ts,gsv) = G.ExpEPt(G.psi(q),v)
newfig()
G.plot_path(gsv)
plt.show()
# integrate the Euler-Poincare equation from mu; along a geodesic the printed
# energies l(xi) are expected to stay (approximately) constant
(ts,musv) = G.EP(mu)
xisv = [G.invFl(mu) for mu in musv]
print("Energy: ",np.array([G.l(xi) for xi in xisv]))
# ||g g^T - I||_inf for each group element on the path; ~0 means g stays orthogonal.
# The identity is sized from g itself so it always matches the shape of g g^T.
print("Orthogonality: ",np.array([np.linalg.norm(np.dot(g,g.T)-np.eye(g.shape[0]),np.inf) for g in gsv]))

# on S2: act with the path on the point (0,0,1); float dtype to match the other inputs
newfig()
M.plot(rotate=(30,-15))
x = jnp.array([0.,0.,1.])
M.plot_path(M.acts(gsv,x))
plt.show()

In [None]:
# Lie-Poisson dynamics
from jaxgeometry.group import LiePoisson
LiePoisson.initialize(G)

# geodesic starting at G.psi(q) with initial velocity v
(ts,gsv) = G.ExpLPt(G.psi(q),v)
newfig()
G.plot_path(gsv)
plt.show()
# integrate the Lie-Poisson equation from mu and print the energy along the flow
(ts,musv) = G.LP(mu)
print("Energy: ",np.array([G.Hminus(mu) for mu in musv]))
# ||g g^T - I||_inf for each group element on the path; ~0 means g stays orthogonal.
# The identity must match the shape of g g^T, so it is sized from g itself:
# int(np.sqrt(G.dim)) yields a 1x1 identity for SO(3) if G.dim is the group
# dimension 3, which numpy silently broadcasts and reports ~2 for orthogonal g.
print("Orthogonality: ",np.array([np.linalg.norm(np.dot(g,g.T)-np.eye(g.shape[0]),np.inf) for g in gsv]))

In [None]:
# Hamiltonian dynamics
from jaxgeometry.dynamics import Hamiltonian
Hamiltonian.initialize(G)

# evaluate the Hamiltonian at the initial point (q, p)
print(p)
print(G.H(q,p))

# geodesic: integrate in coordinates, then map each point to the group with G.psi
qsv,_ = G.Exp_Hamiltoniant((q,None),p)
gsv = np.array([G.psi(q) for q in qsv])
newfig()
G.plot_path(gsv)
plt.show()
(ts,qpsv,_) = G.Hamiltonian_dynamics((q,None),p,dts())
# assumes qpsv stacks (q, p) along axis 1: index 1 is p, index 0 presumably q — TODO confirm
qsv_H = qpsv[:,0,:]
psv = qpsv[:,1,:]
# evaluate H on positions and momenta from the SAME integration so each (q, p)
# pair belongs to the same time step (zip with another trajectory would silently
# mismatch or truncate if the time grids differ)
print("Energy: ",np.array([G.H(q,p) for (q,p) in zip(qsv_H,psv)]))