In [None]:
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
#

# LDDMM landmark stochastic dynamics

In [None]:
# Reload modified modules (e.g. the src package) automatically before executing each cell
%load_ext autoreload
%autoreload 2

In [None]:
# Move to the parent directory so that the `src` package imported below is found
# (presumably this notebook lives one level below the repository root — confirm)
%cd ..
# star import; presumably also the source of jnp, np, jax and the other helpers used without explicit import below
from src.manifolds.landmarks import *
# landmark manifold; the argument is presumably the number of landmarks (M.N) — confirm
M = landmarks(3)
print(M)
from src.plotting import *

In [None]:
# Riemannian structure
# attach the Riemannian metric machinery to M (presumably relied on by the dynamics modules initialized below — confirm)
from src.Riemannian import metric
metric.initialize(M)

## Stochastic EPDiff / Eulerian

In [None]:
# Plotting domain (reused by later cells)
minx = -2; maxx = 2
miny = -2; maxy = 2

# Noise configuration:
#   case 0: a pts x pts grid of centres, each carrying an x- and a y-directed field
#   case 1: an x- and a y-directed field centred at the origin
#   case 2: an x- and a y-directed field at each of (-.5,0) and (.5,0)
case = 0
if case <= 0:
    # define noise field grid
    pts = 7
    X, Y = jnp.meshgrid(np.linspace(minx,maxx,pts),np.linspace(miny,maxy,pts))
    xy = jnp.vstack([X.ravel(), Y.ravel()]).T
    # repeat every centre twice so it can carry both directions given by the rows of eye(2)
    sigmas_x = jnp.hstack((xy,xy)).reshape((-1,2))
    sigmas_a = 1.*jnp.tile(np.eye(2),(sigmas_x.shape[0]//2,1))

    # noise kernels: kernel width equal to the grid spacing in each direction
    k_alpha = .75
    k_sigma = jnp.diag(jnp.array([(maxx-minx)/(pts-1),(maxy-miny)/(pts-1)]))
elif case <= 1:
    sigmas_x = np.array([[0.,0.],[0.,0.]])
    sigmas_a = np.array([[1.,0.],[0.,1.]])

    # noise kernels
    k_alpha = 1.
    k_sigma = jnp.diag(jnp.ones(M.m))
elif case <= 2:
    sigmas_x = np.array([[-.5,0.],[-.5,0.],[.5,0.],[.5,0.]])
    sigmas_a = np.array([[1.,0.],[0.,1.],[1.,0.],[0.,1.]])

    # noise kernels
    k_alpha = .5
    k_sigma = .5*jnp.diag(jnp.ones(M.m))
else:
    # fail fast instead of silently reusing sigmas_x/sigmas_a left over from an earlier run
    raise ValueError('unknown noise case: '+str(case))

# number of noise fields; field j is centred at sigmas_x[j] with direction sigmas_a[j]
J = sigmas_x.shape[0]
print(k_alpha,k_sigma)
inv_k_sigma = jnp.linalg.inv(k_sigma)
# Gaussian kernel k(x) = k_alpha*exp(-|inv_k_sigma x|^2/2), applied along the last axis of x
k = lambda x: k_alpha*jnp.exp(-.5*jnp.square(jnp.tensordot(x,inv_k_sigma,(x.ndim-1,1))).sum(x.ndim-1))
# pairwise kernel matrix between two flattened point sets (each reshaped to rows of M.m coordinates)
k_q = lambda q1,q2: k(q1.reshape((-1,M.m))[:,np.newaxis,:]-q2.reshape((-1,M.m))[np.newaxis,:,:])
# all noise fields at the points x: result[i,j] = k(x_i - sigmas_x[j]) * sigmas_a[j]
sigmas = lambda x: jnp.einsum('ij,jd->ijd',k_q(x,sigmas_x),sigmas_a)

# plot all fields on a finer grid
pts = 20
x,y = np.meshgrid(np.linspace(minx,maxx,pts),np.linspace(miny,maxy,pts))
x = x.flatten(); y = y.flatten()
xy = jnp.vstack((x,y)).T

# compute values: shape (pts*pts, J, 2)
sigmasxy = sigmas(xy)

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython import display

# Turn off matplotlib plot in Notebook
plt.ioff()

fig, ax = plt.subplots()

def animate(i):
    """Draw noise field i on a freshly cleared axes and return the quiver artist."""
    ax.clear()
    # ax.clear() resets the axis limits, so they must be re-applied on every frame
    ax.set_xlim(minx,maxx)
    ax.set_ylim(miny,maxy)
    ax.set_title('noise field '+str(i))
    return ax.quiver(x,y,sigmasxy[:,i,0],sigmasxy[:,i,1],angles='xy', scale_units='xy', scale=1)

anim = FuncAnimation(fig, animate, frames=J, interval=500, repeat=False)

video = anim.to_html5_video()
html = display.HTML(video)
display.display(html)
plt.close(fig)

In [None]:
# Stochastic Eulerian landmark dynamics driven by the noise kernel k defined above
from src.stochastics import Eulerian
Eulerian.initialize(M,k=k)

# initial state: M.N landmarks evenly spaced on [-.5,.5] x {0}, each with velocity (0,1)
landmark_x = np.linspace(-.5,.5,M.N)
q = M.coords(jnp.stack((landmark_x, np.zeros(M.N)), axis=1).flatten())
v = jnp.stack((np.zeros(M.N), np.ones(M.N)), axis=1).flatten()
# lower the index of v at q with M.flat to obtain the momentum p
p = M.flat(q,v)

# integrate; dWs is sized by J, the number of noise fields
_dts = dts(n_steps=1000)
ts, qps, charts = M.Eulerian(q,p,sigmas_x,sigmas_a,_dts,dWs(J,_dts))

# landmark paths together with the noise centres (green)
M.newfig()
M.plot()
M.plot_path(zip(qps[:,0,:],charts))
plt.plot(sigmas_x[:,0],sigmas_x[:,1],'.',color='g',markersize=15)
plt.show()

# Most probable paths for Arnaudon-Cruzeiro model

In [None]:
# Domain manifold N on which the landmark flow acts: the Euclidean plane
from src.manifolds.Euclidean import *
N = Euclidean(2)

def u(x, qp):
    """Flow field at x = (point, chart) generated by the landmark state qp.

    qp[0,:] holds the landmark positions q and qp[1,:] the momenta p;
    the velocity is K(point, q) p.
    """
    return jnp.dot(M.K(x[0], qp[0,:]), qp[1,:])

# MPP Kunita equations on N, driven by the noise fields sigmas and the flow u
from src.dynamics import MPP_Kunita
MPP_Kunita.initialize(M,N,sigmas,u)

# Curvature of N
from src.Riemannian import curvature
curvature.initialize(N)

print(N)
origin = (jnp.zeros(N.dim), N.chart())
print(N.g(origin))

In [None]:
# Surface plot of N.logAbsDet (presumably log|det g| of N's metric) over a 30 x 30 grid
log_abs_det = jax.vmap(lambda pt: N.logAbsDet((pt, N.chart())), in_axes=1)

grid_x = np.linspace(minx, maxx, 30)
grid_y = np.linspace(miny, maxy, 30)
X, Y = np.meshgrid(grid_x, grid_y)
# the grid points are the columns of the (2, 900) stack, hence in_axes=1 above
Z = log_abs_det(jnp.vstack((X.flatten(), Y.flatten()))).reshape(X.shape)

ax = plt.axes(projection='3d')
surface = ax.plot_surface(X, Y, Z, cmap='viridis')
ax.view_init(10, 35)
plt.colorbar(surface)
plt.show()

In [None]:
# Hamiltonian landmark dynamics on M (M.Hamiltonian_dynamics is used in the cells below)
from src.dynamics import Hamiltonian
Hamiltonian.initialize(M)

# flow arbitrary points of N along the landmark flow
def ode_Hamiltonian_advect(c,y):
    """Right-hand side for points x advected by the landmark flow.

    c = (t, x, chart) is the integrator state and y = (qp,) carries the
    landmark state of the current step, with qp[0] = q and qp[1] = p.
    Returns K(x, q) p reshaped to one row of M.m coordinates per point.
    """
    _, x, _ = c
    (qp,) = y
    positions, momenta = qp[0], qp[1]
    velocity = jnp.tensordot(M.K(x, positions), momenta, (1, 0))
    return velocity.reshape((-1, M.m))

def _Hamiltonian_advect(xs, qps, dts):
    """Integrate ode_Hamiltonian_advect from xs = (points, chart) along the landmark path qps."""
    # points become rows of M.m coordinates; None means no chart update is applied
    return integrate(ode_Hamiltonian_advect, None, xs[0].reshape((-1,M.m)), xs[1], dts, qps)

M.Hamiltonian_advect = jit(_Hamiltonian_advect)

In [None]:
# Checks: integrate the deterministic Hamiltonian landmark dynamics, then animate
# N.z (left) and the raised gradient of N.f (right) at the landmark state of each frame
q = M.coords(jnp.vstack((np.linspace(-.5,.5,M.N),np.zeros(M.N)-.5)).T.flatten())
v = 1.*jnp.array(jnp.vstack((np.zeros(M.N),np.ones(M.N))).T.flatten())
p = M.flat(q,v)
_dts = dts(n_steps=100)
(_,qps,charts_qp) = M.Hamiltonian_dynamics(q,p,_dts)
# time derivative of the trajectory: finite differences along axis 0 divided by the step sizes
dqps = jnp.einsum('t...,t->t...',jnp.gradient(qps,axis=0),1/_dts)

# evaluation grid for the vector fields
pts = 40
x,y = np.meshgrid(np.linspace(-1.2,1.2,pts),np.linspace(-1.2,1.2,pts))
x = x.flatten(); y = y.flatten()
xy = jnp.vstack((x,y)).T

# fields vectorized over the grid points (axis 0 of xy) for a fixed landmark state qp
zs = jax.vmap(lambda x,qp: N.z((x,N.chart()),qp),(0,None))
# gradient of N.f contracted with N.gsharp (index raising; presumably the inverse metric)
gradfs = jax.vmap(lambda x,qp: jnp.einsum('ik,i',N.gsharp((x,N.chart())),gradx(N.f)((x,N.chart()),qp)),(0,None))

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython import display

# Turn off matplotlib plot in Notebook
plt.ioff()

fig, ax = plt.subplots(1, 2, figsize=(14, 6))

skip = 5
def animate(i):
    """Draw frame i (time step i*skip) on both panels and return the axes."""
    i *= skip

    zxy = zs(xy,qps[i])
    gradfxy = gradfs(xy,qps[i])

    for axis, field, color in ((ax[0], zxy, 'b'), (ax[1], gradfxy, 'k')):
        # ax.clear() also resets the axis limits, so they are re-applied every frame
        axis.clear()
        axis.set_xlim(minx,maxx)
        axis.set_ylim(miny,maxy)
        # landmark trajectories up to and including step i, where the fields are evaluated
        for j in range(M.N):
            axis.plot(qps[:i+1,0,j*M.m],qps[:i+1,0,j*M.m+1],linewidth=5)
        axis.quiver(x,y,field[:,0],field[:,1],angles='xy', scale_units='xy', scale=1, color=color)

    return ax

anim = FuncAnimation(fig, animate, frames=_dts.shape[0]//skip, interval=100, repeat=False)

video = anim.to_html5_video()
html = display.HTML(video)
display.display(html)
plt.close(fig)

In [None]:
# Forward MPPs for K start points on the line y = -.5: MPP_AC path vs. the
# deterministic advection of the same point by the landmark flow (red)
K = 50

x0s = jnp.stack((np.linspace(minx+.3,maxx-.3,K), np.zeros(K)-.5), axis=1)
# NOTE(review): M.chart() is passed although these are points of N; the returned charts are discarded here — confirm intended
_,xs,_ = M.Hamiltonian_advect((x0s.flatten(),M.chart()),qps,_dts)

N.newfig()
N.plot()
plt.plot(sigmas_x[:,0],sigmas_x[:,1],'.',color='g',markersize=15)
for i, x0 in enumerate(x0s):
    # initial velocity: landmark flow at x0 for the initial state qps[0]
    v0 = jnp.tensordot(M.K(x0,qps[0,0]),qps[0,1],(1,0))
    (_,xx1,charts) = M.MPP_AC((x0,N.chart()),v0,qps,dqps,_dts)
    N.plot_path(zip(xx1[:,0,:],charts))
    N.plot_path(zip(xs[:,i],charts),color='r')

plt.savefig(f'MPP_AC_forward{case}.pdf')
plt.show()

In [None]:
# Plot the initial flow field u(., qps[0]) together with the noise centres (green)
pts = 40
minx, maxx = -2, 2
miny, maxy = -2, 2

grid_x, grid_y = np.meshgrid(np.linspace(minx,maxx,pts),np.linspace(miny,maxy,pts))
x, y = grid_x.flatten(), grid_y.flatten()
xy = jnp.vstack((x,y)).T

# u vectorized over the grid points for a fixed landmark state
us = jax.vmap(lambda pt,state: u((pt,N.chart()),state),(0,None))
uxy = us(xy,qps[0])

N.newfig()
N.plot()
plt.plot(sigmas_x[:,0],sigmas_x[:,1],'.',color='g',markersize=15)
plt.quiver(x,y,uxy[:,0],uxy[:,1],angles='xy', scale_units='xy', scale=1, color='b')
plt.xlim([minx,maxx])
plt.ylim([miny,maxy])

plt.savefig(f'MPP_AC_U0{case}.pdf')
plt.show()

In [None]:
# Boundary value problem: for each start point x0s[i], M.Log_MPP_AC yields an initial
# velocity vv (its first output) aimed at the endpoint xs[-1,i] of the deterministic
# advected path; M.MPP_AC is then integrated from vv and compared with that path (red)
from src.dynamics import MPP_Kunita_Log
MPP_Kunita_Log.initialize(M,N)

_dts = dts(n_steps=100)
(_,qps,charts_qp) = M.Hamiltonian_dynamics(q,p,_dts)
# time derivative of the trajectory: finite differences along axis 0 divided by the step sizes
dqps = jnp.einsum('t...,t->t...',jnp.gradient(qps,axis=0),1/_dts)

# plot
N.newfig()
N.plot()
plt.plot(sigmas_x[:,0],sigmas_x[:,1],'.',color='g',markersize=15)
for i in range(K):
    x0 = x0s[i]

    vv = M.Log_MPP_AC((x0,N.chart()),(xs[-1,i],N.chart()),qps,dqps,_dts)[0]
    (_,xx1,charts) = M.MPP_AC((x0,N.chart()),vv,qps,dqps,_dts)

    N.plot_path(zip(xx1[:,0,:],charts))
    N.plot_path(zip(xs[:,i],charts),color='r')

plt.savefig('MPP_AC_IVP'+str(case)+'.pdf')
plt.show()