In [None]:
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
#

# Most probable landmark paths


In [None]:
# Reload edited modules automatically before each cell executes, so changes under src/
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Move to the parent directory so that the `src` package is importable.
# NOTE: `%cd ..` is relative — re-running this cell moves up one more directory.
%cd ..
from src.manifolds.landmarks import *
# Landmark manifold; the first argument is presumably the initial number of landmarks and
# k_sigma=.5*I_2 the width matrix of its kernel — TODO confirm in src/manifolds/landmarks.
M = landmarks(3,k_sigma=.5*jnp.eye(2))
print(M)
from src.plotting import *

In [None]:
# Riemannian structure
# Attach the Riemannian metric machinery to M (initialize is called for its effect on M;
# presumably it adds metric-related methods — see src/Riemannian/metric).
from src.Riemannian import metric
metric.initialize(M)

## Eulerian noise fields

In [None]:
# Domain of the noise-field grid and of the plots below.
minx = -2; maxx = 2
miny = -2; maxy = 2

# Noise-field configuration:
#   0 (or less): a pts x pts grid of centres, each carrying an x- and a y-directed field
#   1: an x- and a y-directed field, both centred at the origin
#   2: x- and y-directed fields centred at (-.5,0) and at (.5,0)
case = 0
if case <= 0:
    # define noise field grid
    pts = 7
    X, Y = jnp.meshgrid(np.linspace(minx,maxx,pts),np.linspace(miny,maxy,pts))
    xy = jnp.vstack([X.ravel(), Y.ravel()]).T
    # every grid point appears twice: once for the x-directed and once for the y-directed field
    sigmas_x = jnp.hstack((xy,xy)).reshape((-1,2))
    sigmas_a = 1.*jnp.tile(np.eye(2),(sigmas_x.shape[0]//2,1))

    # noise kernels: amplitude k_alpha, width equal to the grid spacing in each direction
    k_alpha = .75
    k_sigma = jnp.diag(jnp.array([(maxx-minx)/(pts-1),(maxy-miny)/(pts-1)]))
elif case <= 1:
    sigmas_x = np.array([[0.,0.],[0.,0.]])
    sigmas_a = np.array([[1.,0.],[0.,1.]])

    # noise kernels
    k_alpha = 1.
    k_sigma = jnp.diag(jnp.ones(M.m))
elif case <= 2:
    sigmas_x = np.array([[-.5,0.],[-.5,0.],[.5,0.],[.5,0.]])
    sigmas_a = np.array([[1.,0.],[0.,1.],[1.,0.],[0.,1.]])

    # noise kernels
    k_alpha = .5
    k_sigma = .5*jnp.diag(jnp.ones(M.m))
else:
    # fail here instead of with a NameError on sigmas_x further down
    raise ValueError('unknown noise-field case: '+str(case))

J = sigmas_x.shape[0]  # number of noise fields
print(k_alpha,k_sigma)
inv_k_sigma = jnp.linalg.inv(k_sigma)
# Gaussian kernel k(x) = k_alpha*exp(-|k_sigma^{-1} x|^2/2), reduced over the last axis of x
k = lambda x: k_alpha*jnp.exp(-.5*jnp.square(jnp.tensordot(x,inv_k_sigma,(x.ndim-1,1))).sum(x.ndim-1))
# pairwise kernel matrix between the points of q1 and of q2 (both flattened to rows of M.m coords)
k_q = lambda q1,q2: k(q1.reshape((-1,M.m))[:,np.newaxis,:]-q2.reshape((-1,M.m))[np.newaxis,:,:])
# sigmas(x)[i,j] = value of noise field j at point i: k(x_i - sigmas_x[j]) * sigmas_a[j]
sigmas = lambda x: jnp.einsum('ij,jd->ijd',k_q(x,sigmas_x),sigmas_a)
# dsigmas(x)[i,j,d,k] = derivative of component d of field j w.r.t. coordinate k, at point i
dsigmas = lambda x: jnp.einsum('i...jk,jd->ijdk',jax.vmap(jacrev(k_q,0),(0,None))(x.reshape((-1,M.m)),sigmas_x),sigmas_a)

# plot all fields, evaluated on a finer grid than the noise centres
pts = 20
x,y = np.meshgrid(np.linspace(minx,maxx,pts),np.linspace(miny,maxy,pts))
x = x.flatten(); y = y.flatten()
xy = jnp.vstack((x,y)).T

# compute values: sigmasxy[i,j] is noise field j evaluated at grid point i
sigmasxy = sigmas(xy)

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython import display

# Turn off interactive plotting so the animation's figure is not also rendered as a
# static image; the previous mode is restored once the video has been embedded.
was_interactive = plt.isinteractive()
plt.ioff()

fig, ax = plt.subplots()


def animate(i):
    """Draw noise field i as a quiver plot (one animation frame)."""
    ax.clear()
    # ax.clear() resets the axis limits, so re-apply the noise domain on every frame
    ax.set_xlim(minx,maxx)
    ax.set_ylim(miny,maxy)
    return ax.quiver(x,y,sigmasxy[:,i,0],sigmasxy[:,i,1],angles='xy', scale_units='xy', scale=1)

# one frame per noise field
anim = FuncAnimation(fig, animate, frames=J, interval=500, repeat=False)

video = anim.to_html5_video()
html = display.HTML(video)
display.display(html)
plt.close(fig)
if was_interactive:
    plt.ion()

# Most probable paths

In [None]:
# define drift field: a(x,qp) = K(x,q) p for landmark state qp = (q,p), returned as rows of
# M.m coordinates (M.K presumably the landmark kernel matrix between x and q — confirm)
a = lambda x,qp: jnp.dot(M.K(x.flatten(),qp[0,:]),qp[1,:]).reshape((-1,M.m))

# Hamiltonian dynamics
from src.dynamics import Hamiltonian
Hamiltonian.initialize(M)

M.setN(3)
# initial landmarks: M.N points on the segment y=-.5, x in [-.5,.5], with velocity (0,2)
q = M.coords(jnp.vstack((np.linspace(-.5,.5,M.N),-.5+np.zeros(M.N))).T.flatten())
v = jnp.vstack((np.zeros(M.N),2*np.ones(M.N))).T.flatten()
p = M.flat(q,v)

# integrate; every later integration over qps must use this same time grid _dts
_dts = dts(n_steps=100)
(ts,qps,charts) = M.Hamiltonian_dynamics(q,p,_dts)

# flow to create a(t)
def ode_Hamiltonian_advect(c,y):
    """Right-hand side dx/dt = K(x,q(t)) p(t) for points x carried along by the landmark flow.

    c is (t, x, chart); y holds the landmark state qp = (q, p) at time t.
    """
    t,x,chart = c
    qp, = y
    q = qp[0]
    p = qp[1]

    dxt = jnp.tensordot(M.K(x,q),p,(1,0)).reshape((-1,M.m))
    return dxt

# Advect points xs = (coords, chart) along a precomputed landmark trajectory qps on the time
# steps dts (the argument is now honoured; previously the closure _dts was used regardless).
M.Hamiltonian_advect = jit(lambda xs,qps,dts: integrate(ode_Hamiltonian_advect,None,
                                                        xs[0].reshape((-1,M.m)),xs[1],dts,qps))

In [None]:
# MPP landmark equations
# Register the most-probable-path landmark equations on M, using the noise fields
# sigmas/dsigmas and the drift a defined above.
from src.dynamics import MPP_landmarks
MPP_landmarks.initialize(M,sigmas,dsigmas,a)

In [None]:
# Forward simulation: K points on the line y=-.5 are advected by the landmark flow qps, and the
# most probable paths from the advected start points are computed with zero initial lambda.
K = 10  # number of landmarks

chart = M.chart()

x0s = jnp.stack((np.linspace(minx+.3,maxx-.3,K),np.full(K,-.5)),axis=1)
_,xs_advect = M.Hamiltonian_advect((x0s.flatten(),chart),qps,_dts)
M.setN(K)
_,xs,_,charts = M.MPP_landmarks((xs_advect[0].flatten(),chart),jnp.zeros(M.dim),qps,_dts)

M.newfig()
M.plot()
# driving landmark trajectories (black), advected points, most probable paths (red)
M.plot_path(zip(qps[:,0],charts),color='k')
M.plot_path(zip(xs_advect,itertools.cycle(chart)))
M.plot_path(zip(xs,itertools.cycle(chart)),color='r')
# noise-field centres (green)
plt.plot(sigmas_x[:,0],sigmas_x[:,1],'.',color='g',markersize=15)
plt.show()

In [None]:
# Boundary value problem: find the initial lambda whose most probable path starts at the
# advected points at time index 0 and ends at the advected points at time index 75.
from src.dynamics import MPP_landmarks_Log
MPP_landmarks_Log.initialize(M)

x = (xs_advect[0].flatten(),chart)
y = (xs_advect[75].flatten(),chart)
lambd,_ = M.Log_MPP_landmarks(x,y,qps,_dts)
_,xs,lambds,charts = M.MPP_landmarks(x,lambd,qps,_dts)

# plot: advected paths, BVP solution (red), target configuration (black), noise centres (green)
M.newfig()
M.plot()
M.plot_path(zip(xs_advect,itertools.cycle(chart)))
M.plot_path(zip(xs,itertools.cycle(chart)),color='r')
M.plotx(y,color='k')
plt.plot(sigmas_x[:,0],sigmas_x[:,1],'.',color='g',markersize=15)

plt.savefig('MPP_landmarks_BVP'+str(case)+'.pdf')
plt.show()

In [None]:
# comparison to most probable transformation: solve, point by point, the most probable path
# BVP of the Kunita flow between the same start (x) and end (y) points as the landmark BVP.

# define domain manifold
from src.manifolds.Euclidean import *
N = Euclidean(2)

# MPP Kunita equations; u(x,qp) = K(x[0],q) p is the drift at a point x = (coords, chart)
from src.dynamics import MPP_Kunita
u = lambda x,qp: jnp.dot(M.K(x[0],qp[0,:]),qp[1,:])
MPP_Kunita.initialize(M,N,sigmas,u)

# Curvature
from src.Riemannian import curvature
curvature.initialize(N)

from src.dynamics import MPP_Kunita_Log
MPP_Kunita_Log.initialize(M,N)

# time derivative of the landmark trajectory: finite differences per step, divided by the
# step sizes (assumes uniform steps, since jnp.gradient is called with unit spacing)
dqps = jnp.einsum('t...,t->t...',jnp.gradient(qps,axis=0),1/_dts)

# plot
N.newfig()
N.plot()
M.plot_path(zip(xs_advect,itertools.cycle(chart)))
plt.plot(sigmas_x[:,0],sigmas_x[:,1],'.',color='g',markersize=15)
# start and target points, one row per landmark (hoisted out of the loop)
x_pts = x[0].reshape((-1,M.m))
y_pts = y[0].reshape((-1,M.m))
for i in range(K):
    x0 = x_pts[i]
    y0 = y_pts[i]

    vv = M.Log_MPP_AC((x0,N.chart()),(y0,N.chart()),qps,dqps,_dts)[0]
    (_,xx1,charts) = M.MPP_AC((x0,N.chart()),vv,qps,dqps,_dts)

    N.plot_path(zip(xx1[:,0,:],charts),color='r')
M.plotx(y,color='k')

plt.savefig('MPP_AC_BVP_comparison_'+str(case)+'.pdf')
plt.show()

# Shape

In [None]:
# MPP landmark equations
from src.dynamics import MPP_landmarks
MPP_landmarks.initialize(M,sigmas,dsigmas,a)

# number of landmarks
M.setN(128)
phis = jnp.linspace(0,2*jnp.pi,M.N)
x0 = M.coords(jnp.vstack((jnp.cos(phis),jnp.sin(phis))).T.flatten())

_,xs_advect = M.Hamiltonian_advect(x0,qps,_dts)
%time _,xs,_,charts = M.MPP_landmarks(x0,jnp.zeros(M.dim),qps,_dts)

M.newfig()
M.plot()
M.plot_path(zip(xs_advect,itertools.cycle(chart)))
M.plot_path(zip(xs,itertools.cycle(chart)),color='r')
M.plotx(x0,color='r')
plt.plot(sigmas_x[:,0],sigmas_x[:,1],'.',color='g',markersize=15)
plt.show()

In [None]:
y = (xs_advect[-1].flatten(),chart)
%time lambd,_ = M.Log_MPP_landmarks(x0,y,.5*qps,_dts)
_,xs,lambds,charts = M.MPP_landmarks(x0,lambd,.5*qps,_dts)

# plot
M.newfig()
M.plot()
M.plot_path(zip(xs,itertools.cycle(chart)),color='r')
M.plot_path(zip(xs_zero_noise,itertools.cycle(chart)),color='b')
M.plotx(x0,color='r')
M.plotx(y,color='k')
# plt.plot(sigmas_x[:,0],sigmas_x[:,1],'.',color='g',markersize=15)
    
plt.savefig('MPP_landmarks_shape_BVP.pdf')
plt.show()

_,xs_zero_noise,lambds,charts = M.MPP_landmarks(x0,jnp.zeros_like(lambd),.5*qps,_dts)

# plot
M.newfig()
M.plot()
M.plot_path(zip(xs_zero_noise,itertools.cycle(chart)),color='r')
M.plotx(x0,color='r')
M.plotx(y,color='k')
# plt.plot(sigmas_x[:,0],sigmas_x[:,1],'.',color='g',markersize=15)
plt.savefig('MPP_landmarks_shape_no_noise.pdf')
plt.show()

In [None]:
# Comparison to the most probable transformation for the shape: solve the Kunita-flow MPP BVP
# independently for every landmark of the circle x0 towards the matching point of y.
# NOTE(review): this uses qps, whereas the landmark shape BVP used .5*qps — confirm intended.

# MPP Kunita equations (re-initialized for the current M)
from src.dynamics import MPP_Kunita
u = lambda x,qp: jnp.dot(M.K(x[0],qp[0,:]),qp[1,:])
MPP_Kunita.initialize(M,N,sigmas,u)

# Curvature
from src.Riemannian import curvature
curvature.initialize(N)

from src.dynamics import MPP_Kunita_Log
MPP_Kunita_Log.initialize(M,N)

# finite-difference time derivative of the landmark trajectory
dqps = jnp.einsum('t...,t->t...',jnp.gradient(qps,axis=0),1/_dts)

# plot
N.newfig()
N.plot()
start_pts = x0[0].reshape((-1,M.m))
end_pts = y[0].reshape((-1,M.m))
for i in range(M.N):
    xi = start_pts[i]
    yi = end_pts[i]

    vv = M.Log_MPP_AC((xi,N.chart()),(yi,N.chart()),qps,dqps,_dts)[0]
    (_,xx1,charts) = M.MPP_AC((xi,N.chart()),vv,qps,dqps,_dts)

    N.plot_path(zip(xx1[:,0,:],charts),color='r')
M.plotx(y,color='k')

plt.savefig('MPP_AC_shape_BVP.pdf')
plt.show()