In [None]:
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
#

# LDDMM landmark dynamics

In [None]:
# Reload edited modules (e.g. the src.* packages imported below) automatically before each cell runs.
# Comments are kept on their own lines: text after a line magic is passed to it as arguments.
%load_ext autoreload
%autoreload 2

In [None]:
# Star import of the landmark manifold module; it provides `landmarks` and presumably also the
# numerical aliases (jnp, np, dts, dWs, ...) used in later cells — TODO confirm against the module.
from src.manifolds.landmarks import *
# Landmark manifold instance; M.N (used below for the number of points) is presumably 3 here.
M = landmarks(3)
print(M)
# Plotting helpers (presumably plt, cm, mpl, newfig, ... used later). This is imported after the
# landmarks module, so any names defined by both star imports resolve to the plotting versions.
from src.plotting import *

In [None]:
# Riemannian structure
# Attach the metric machinery to M. Later cells call M.flat and M.gsharp, which are presumably
# installed by this initializer — verify in src.Riemannian.metric.
from src.Riemannian import metric
metric.initialize(M)

In [None]:
# example configuration: M.N landmarks evenly spaced on [-.5, .5] along the first axis, each
# with velocity (0, 1); coordinates are flattened landmark by landmark as x0,y0,x1,y1,...
# NOTE(review): k_sigma is assigned after metric.initialize(M); confirm the kernel reads it lazily.
M.k_sigma = jnp.diag(jnp.array([.5,.5]))

q = M.coords(jnp.stack((np.linspace(-.5, .5, M.N), np.zeros(M.N)), axis=1).flatten())
v = jnp.stack((np.zeros(M.N), np.ones(M.N)), axis=1).flatten()
# momentum (covector) corresponding to the velocity v at q
p = M.flat(q, v)
print("q = ", q)
print("p = ", p)

## Geodesics

In [None]:
# 2nd order geodesic equation
# M.Expt gives the geodesic from q with initial velocity v as a trajectory of (point, chart)
# pairs, which is drawn on top of the landmark configuration.
from src.Riemannian import geodesic
geodesic.initialize(M)

qs, charts = M.Expt(q, v)
M.plot()
M.plot_path(zip(qs, charts), v, linewidth=1.5)
plt.show()

In [None]:
# Hamiltonian dynamics
# Integrate the same initial state (q, p) in Hamiltonian (position, momentum) form.
from src.dynamics import Hamiltonian
Hamiltonian.initialize(M)

# value of the Hamiltonian at the initial state
print(M.H(q,p))

# geodesic: positions and charts along the Hamiltonian flow from (q, p)
(qs,charts) = M.Exp_Hamiltoniant(q,p)

M.plot()
M.plot_path(zip(qs,charts),v)
plt.show()

# dynamics returning both position and momentum; qps[:,1,:] holds the momenta
# (qps[:,0,:] presumably the positions). The third output gets a real name instead of `_`:
# assigning `_` in IPython stops the automatic update of the `_`/`__`/`___` output history.
(ts,qps,charts_dyn) = M.Hamiltonian_dynamics(q,p,dts())
ps = qps[:,1,:]
# The Hamiltonian is expected to stay (approximately) constant along the flow. The comprehension
# uses its own names so it does not read as a rebinding of the global initial state q, p.
# NOTE(review): positions/charts come from Exp_Hamiltoniant and momenta from Hamiltonian_dynamics;
# this pairing assumes both integrations use the same time grid.
print("Energy: ",np.array([M.H((q_t,chart_t),p_t) for (q_t,p_t,chart_t) in zip(qs,ps,charts)]))

## Boundary value problem

In [None]:
# Logarithm map
# Recover the initial momentum that connects q to the endpoint of the trajectory qs/charts
# computed by the Hamiltonian cell above; M.Exp_Hamiltonian is passed as the forward map f,
# which M.Log presumably inverts.
from src.Riemannian import Log
Log.initialize(M,f=M.Exp_Hamiltonian)

# first element of M.Log's result is used as the recovered momentum
p_Log = M.Log(q,(qs[-1],charts[-1]))[0]
# compare against the true initial momentum p
print(p_Log)
print(p)

# Plot the geodesic generated by p_Log. New names keep the global qs/charts (the target
# trajectory) intact; overwriting them would make a re-run of this cell target the endpoint
# of the recovered geodesic instead of the original one.
(qs_Log,charts_Log) = M.Exp_Hamiltoniant(q,p_Log)
M.plot()
M.plot_path(zip(qs_Log,charts_Log),v,linewidth=1.5)
plt.show()

## Curvature

In [None]:
# Curvature of the landmark metric, evaluated at the example configuration q.
from src.Riemannian import curvature
curvature.initialize(M)

# full curvature tensor: only its shape is reported, since the tensor can be large
print("curvature shape= ", M.R(q).shape)
# contractions of the curvature tensor
print("Ricci curvature = ", M.Ricci_curv(q))
print("Scalar curvature = ", M.S_curv(q))

In [None]:
# Plot the minimal eigenvalue of gsharp . Ricci (gsharp presumably the inverse metric) for two
# landmarks: the first fixed at x1=(0,0), the second ranging over a square grid.
# Only runs for two-landmark manifolds (M.N == 2); for any other M this cell does nothing.
if M.N == 2:
    x1 = jnp.array([0.,0.])

    # grids
    pts = 40 # even number to avoid (0,0), high value implies nicer plot but extended computation time
    border = .2
    minx = -border
    maxx = +border
    miny = -border
    maxy = +border
    X, Y = np.meshgrid(np.linspace(minx,maxx,pts),np.linspace(miny,maxy,pts))
    xy = np.vstack([X.ravel(), Y.ravel()]).T

    def min_ricci_eigenvalue(x2):
        """Smallest real eigenvalue of gsharp . Ricci at the configuration (x1, x2)."""
        # build the configuration once and reuse it for both the metric and the Ricci tensor
        qx = M.coords(np.concatenate((x1,x2)))
        return np.min(np.real(np.linalg.eigvals(np.dot(M.gsharp(qx), M.Ricci_curv(qx)))))

    # plot
    newfig()
    cmap = cm.jet
    alpha = 1
    ax = plt.gca()
    fs = np.array([min_ricci_eigenvalue(x) for x in xy])
    norm = mpl.colors.Normalize(vmin=np.min(fs),vmax=np.max(fs))
    colors = cmap(norm(fs)).reshape(X.shape+(4,))
    surf = ax.plot_surface(X, Y, fs.reshape(X.shape), rstride=1, cstride=1, cmap=cmap, facecolors = colors,  linewidth=0., antialiased=True, alpha=alpha, edgecolor=(0,0,0,0), shade=False)
    # Free-standing mappable for the colorbar: it carries the scalar data fs (not the RGBA colors).
    m = cm.ScalarMappable(cmap=surf.cmap,norm=norm)
    m.set_array(fs)
    # m is not attached to any Axes, so name the Axes to take colorbar space from
    # (newer Matplotlib raises otherwise).
    plt.colorbar(m, ax=ax, shrink=0.7)
    ax.set_xlim3d(minx,maxx)
    ax.set_ylim3d(miny,maxy)
    ax.set_zlim3d(np.min(fs)-1,np.max(fs)+1)

## Brownian Motion

In [None]:
# coordinate form
from src.stochastics import Brownian_coords
Brownian_coords.initialize(M)

# Number of Brownian increments per sample path. Every path below uses this same value, so the
# loop no longer depends on xs, which is rebound on each iteration.
n_steps_brownian = 1000
(ts,xs,charts) = M.Brownian_coords(q,dWs(M.dim,n_steps=n_steps_brownian))

# plot
M.newfig()
M.plot()
M.plot_path(zip(xs,charts))
plt.show()

# plot multiple sample paths
N = 5
# Buffers are sized from the first path, so they match however many points the integrator
# returns for n_steps_brownian increments.
xss = np.zeros((N,xs.shape[0],M.dim))
chartss = np.zeros((N,xs.shape[0],q[1].shape[0]))
for i in range(N):
    (ts,xs,charts) = M.Brownian_coords(q,dWs(M.dim,n_steps=n_steps_brownian))
    xss[i] = xs
    chartss[i] = charts

# plot
M.newfig()
M.plot()
colormap = plt.get_cmap('winter')
colors=[colormap(k) for k in np.linspace(0, 1, N)]
for i in range(N):
    M.plot_path(zip(xss[i],chartss[i]),color=colors[i])
# mark the starting configuration q
M.plotx(q,color='r')
plt.show()