In [None]:
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
#

# GLN and SPDN dynamics

In [None]:
# Automatically reload edited modules (e.g. jaxgeometry sources) before each
# cell executes, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# NOTE(review): the star imports below also appear to supply the names used later
# in the notebook (np, jnp, plt, newfig, dts, dWs, ...) - confirm which module
# provides each. Their order matters: a later star import may shadow names
# brought in by an earlier one, so do not reorder them casually.
from jaxgeometry.groups.GLN import *
G = GLN(3)  # the general linear group GL(3)
print(G)

from jaxgeometry.manifolds.SPDN import *
M = SPDN(3)  # SPD(3), the manifold the group elements act on below
print(M)

from jaxgeometry.plotting import *
figsize = 12,12  # default figure size for the notebook
plt.rcParams['figure.figsize'] = figsize

In [None]:
# Initial data: a vector v in the G.dim-dimensional coordinates of the Lie
# algebra, the corresponding algebra element xiv, and the group element
# x = exp(xiv).
# A fixed seed makes the tiny perturbation - and hence every downstream plot
# and printed value - reproducible under Restart & Run All.
SEED = 42
rng = np.random.default_rng(SEED)
v = np.zeros(G.dim)
v[0] = .5
# small random perturbation: must be non-singular for Expm derivative
v += 1e-6*rng.normal(size=G.dim)
xiv=G.VtoLA(v)
x = G.exp(xiv)

In [None]:
# visualization of the group element x = exp(xiv)
newfig()
G.plotg(x)
plt.show()

# One-parameter subgroup t -> exp(t*xiv).
# dts() presumably returns integration step sizes (it is what dWs/Brownian_inv
# consume further down). The curve must therefore be sampled at the cumulative
# times t_i = dt_0 + ... + dt_i. Using the raw step dt_i would, for uniform
# steps, produce the same group element at every sample.
_dts = dts()
_ts = jnp.cumsum(_dts)
gsv = np.zeros((_dts.shape[0],) + tuple(x.shape))
for i in range(_dts.shape[0]):
    gsv[i] = G.exp(_ts[i]*xiv)
newfig()
G.plot_path(gsv)
plt.show()

# on SPD(3): the path acting on the identity matrix
newfig()
M.plot()
x0 = np.eye(M.N).flatten()
M.plot_path(M.acts(gsv,x0))
plt.show()

# ellipsoids; rc_context restores the previous figure size even if plotting raises
with plt.rc_context({'figure.figsize': (23, 10)}):
    M.plot_path(M.acts(gsv,x0),ellipsoid={'alpha': .2, 'step': _dts.shape[0]/4, 'subplot': True})
    plt.show()

In [None]:
# Equip GL(N) with an invariant metric, then initialize the energy module for G.
from jaxgeometry.group import invariant_metric, energy
invariant_metric.initialize(G)
energy.initialize(G)

In [None]:
# Euler-Poincare dynamics
from jaxgeometry.group import EulerPoincare
EulerPoincare.initialize(G)

# geodesic on the group, starting at the identity G.e with initial data v
(ts,gsv) = G.ExpEPt(G.e,v)
newfig()
G.plot_path(gsv)
plt.show()
# Trajectory of the Euler-Poincare flow: musv (presumably momenta) are mapped
# back to the Lie algebra with invFl, and the energy l(xi) is printed at each
# step - presumably to check it is conserved up to integration error.
(ts,musv) = G.EP(v)
xisv = [G.invFl(mu) for mu in musv]
print("Energy: ",np.array([G.l(xi) for xi in xisv]))

# on SPD(3): the geodesic acting on the identity matrix
newfig()
M.plot()
x0 = np.eye(M.N).flatten()
M.plot_path(M.acts(gsv,x0))
plt.show()

# ellipsoids; rc_context restores the previous figure size even if plotting raises
with plt.rc_context({'figure.figsize': (23, 10)}):
    M.plot_path(M.acts(gsv,x0),ellipsoid={'alpha': .2, 'step': dts().shape[0]/4, 'subplot': True})
    plt.show()

In [None]:
# Lie-Poisson dynamics
from jaxgeometry.group import LiePoisson
LiePoisson.initialize(G)

# group geodesic from the identity G.e with initial data v
ts, gsv = G.ExpLPt(G.e, v)
newfig()
G.plot_path(gsv)
plt.show()

# trajectory of the Lie-Poisson flow itself; Hminus is evaluated at every
# step and printed as the energy
ts, musv = G.LP(v)
print("Energy: ", np.array(list(map(G.Hminus, musv))))

In [None]:
# Brownian motion on the group, simulated with the Brownian_inv process
from jaxgeometry.stochastics import Brownian_inv
Brownian_inv.initialize(G)

_dts = dts(n_steps=100)
# NOTE(review): the last argument, sqrt(.1)*I of size G.emb_dim, presumably
# scales the driving noise - confirm against Brownian_inv's signature.
(ts,gs,_) = G.Brownian_inv(G.e,_dts,dWs(G.dim,_dts),jnp.sqrt(.1)*jnp.eye(G.emb_dim))

# on SPD(3): the sample path acting on the identity matrix
newfig()
M.plot()
x0 = np.eye(M.N).flatten()
M.plot_path(M.acts(gs,x0))
plt.show()

# Ellipsoids along the same Brownian sample path gs. Plotting gsv here would
# show the previous cell's geodesic, and `step` is sized from this cell's _dts.
# rc_context restores the previous figure size even if plotting raises.
with plt.rc_context({'figure.figsize': (23, 10)}):
    M.plot_path(M.acts(gs,x0),ellipsoid={'alpha': .2, 'step': _dts.shape[0]/8, 'subplot': True})
    # plt.savefig('SPD3-path.pdf')  # uncomment to export the figure
    plt.show()