In [None]:
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
#

# LDDMM landmark stochastic dynamics

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Landmark manifold, constructed with argument 3 (compare M.setN(4) further down,
# which presumably changes the landmark count -- TODO confirm).
# NOTE(review): the star imports are presumably what bring jnp, np, jax, plt, dts,
# dWs, ... into scope for the rest of the notebook; verify against the modules.
from src.manifolds.landmarks import *
M = landmarks(3)
print(M)
from src.plotting import *

In [None]:
# Riemannian structure
# metric.initialize(M) is called purely for its effect on M; later cells use
# M.flat, which is presumably attached here -- TODO confirm.
from src.Riemannian import metric
metric.initialize(M)

## Brownian Motion

In [None]:
# coordinate form
from src.stochastics import Brownian_coords
Brownian_coords.initialize(M)

# starting configuration: M.N points evenly spaced on [-.5,.5] x {0},
# flattened as (x1,y1,x2,y2,...) before being handed to M.coords
landmark_xs = np.linspace(-.5,.5,M.N)
landmark_ys = np.zeros(M.N)
q = M.coords(jnp.vstack((landmark_xs,landmark_ys)).T.flatten())

# one sample path over 1000 time steps
_dts = dts(n_steps=1000)
noise = dWs(M.dim,_dts)
(ts,xs,charts) = M.Brownian_coords(q,_dts,noise)

# plot: manifold background, start configuration in red, then the sample path
M.newfig()
M.plot()
M.plotx(q,color='r')
M.plot_path(zip(xs,charts))
plt.show()

# plot multiple sample paths
# The sample count has its own name: `N` is bound to the Euclidean domain
# manifold later in this notebook, so reusing it here would silently replace
# that manifold whenever this cell is re-run.
n_samples = 5
xss = np.zeros((n_samples,xs.shape[0],M.dim))
chartss = np.zeros((n_samples,xs.shape[0],q[1].shape[0]))
for i in range(n_samples):
    # every iteration draws fresh noise increments for an independent path
    (ts,xs,charts) = M.Brownian_coords(q,_dts,dWs(M.dim,_dts))
    xss[i] = xs
    chartss[i] = charts

# plot: one colour per sample path, start configuration in red on top
M.newfig()
M.plot()
colormap = plt.get_cmap('winter')
colors=[colormap(k) for k in np.linspace(0, 1, n_samples)]
for i in range(n_samples):
    M.plot_path(zip(xss[i],chartss[i]),color=colors[i])
M.plotx(q,color='r')
plt.show()

# Langevin equations
See https://arxiv.org/abs/1605.09276 for the Langevin landmark model used below.

In [None]:
from src.stochastics import Langevin
Langevin.initialize(M)

# switch to 4 landmarks, evenly spaced on [-.5,.5] x {0}
M.setN(4)
q = M.coords(jnp.vstack((np.linspace(-.5,.5,M.N),np.zeros(M.N))).T.flatten())
# velocity (0,1) for every landmark, flattened as (0,1,0,1,...)
v = jnp.array(jnp.vstack((np.zeros(M.N),np.ones(M.N))).T.flatten())

# covector corresponding to v, obtained through M.flat
p = M.flat(q,v)
print("q = ", q)
print("p = ", p)

_dts = dts(n_steps=1000)
# third and fourth arguments of M.Langevin
# (presumably dissipation and noise amplitude -- TODO confirm against src.stochastics.Langevin)
langevin_params = (.5,.25)
noise = dWs(M.dim,_dts)
(ts,qps,charts) = M.Langevin(q,p,*langevin_params,_dts,noise)

# plot the position component (index 0 on the second axis of qps)
M.newfig()
M.plot()
positions = qps[:,0,:]
M.plot_path(zip(positions,charts))
plt.show()

## Stochastic EPDiff / Eulerian

In [None]:
# extent of the square on which the noise fields are laid out and plotted
minx = -2; maxx = 2
miny = -2; maxy = 2

# Noise field configuration. Each row of sigmas_x is the centre of one noise
# field and the matching row of sigmas_a is its direction/amplitude vector.
#   case 0: regular pts x pts grid, two fields (directions e1 and e2) per grid point
#   case 1: two fields at the origin
#   case 2: two fields at (-.5,0) and two at (.5,0)
case = 0
if case <= 0:
    # define noise field grid
    pts = 7
    X, Y = jnp.meshgrid(np.linspace(minx,maxx,pts),np.linspace(miny,maxy,pts))
    xy = jnp.vstack([X.ravel(), Y.ravel()]).T
    # duplicate every grid point so that it carries both unit directions
    sigmas_x = jnp.hstack((xy,xy)).reshape((-1,2))
    sigmas_a = 1.*jnp.tile(np.eye(2),(sigmas_x.shape[0]//2,1))
    #sigmas_x = np.array([[0.,0.]])
    #sigmas_a = np.array([[.1,0.]])

    # noise kernels: width equal to the grid spacing in each direction
    k_alpha = .5
    k_sigma = jnp.diag(jnp.array([(maxx-minx)/(pts-1),(maxy-miny)/(pts-1)]))
elif case <= 1:
    sigmas_x = np.array([[0.,0.],[0.,0.]])
    sigmas_a = np.array([[1.,0.],[0.,1.]])

    # noise kernels
    k_alpha = 1.
    k_sigma = jnp.diag(jnp.ones(M.m))
elif case <= 2:
    sigmas_x = np.array([[-.5,0.],[-.5,0.],[.5,0.],[.5,0.]])
    sigmas_a = np.array([[1.,0.],[0.,1.],[1.,0.],[0.,1.]])

    # noise kernels
    k_alpha = .5
    k_sigma = .5*jnp.diag(jnp.ones(M.m))
else:
    # fail here with a clear message instead of a NameError further down
    raise ValueError("unknown noise configuration: case = %r (expected a value <= 2)" % (case,))

# number of noise fields
J = sigmas_x.shape[0]
print(k_alpha,k_sigma)
inv_k_sigma = jnp.linalg.inv(k_sigma)
# Gaussian kernel k(x) = k_alpha*exp(-|k_sigma^{-1} x|^2/2), applied along the last axis of x
k = lambda x: k_alpha*jnp.exp(-.5*jnp.square(jnp.tensordot(x,inv_k_sigma,(x.ndim-1,1))).sum(x.ndim-1))
# kernel matrix between two flattened point sets, one row per point of q1
k_q = lambda q1,q2: k(q1.reshape((-1,M.m))[:,np.newaxis,:]-q2.reshape((-1,M.m))[np.newaxis,:,:])
# all noise fields evaluated at the points x: result indexed (point, field, component)
sigmas = lambda x: jnp.einsum('ij,jd->ijd',k_q(x,sigmas_x),sigmas_a)

# plot all fields
# evaluation grid for the quiver plots (pts x pts points on the plotting square)
pts = 20
x,y = np.meshgrid(np.linspace(minx,maxx,pts),np.linspace(miny,maxy,pts))
x = x.flatten(); y = y.flatten()
xy = jnp.vstack((x,y)).T

# compute values: sigmasxy is indexed (grid point, noise field, component)
sigmasxy = sigmas(xy)

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython import display

# Turn off matplotlib plot in Notebook
plt.ioff()

fig, ax = plt.subplots()


def animate(i):
    """Draw noise field number i as a quiver plot on the evaluation grid."""
    ax.clear()
    # ax.clear() also resets the view limits, so they have to be applied
    # on every frame rather than once before the animation starts
    ax.set_xlim(minx,maxx)
    ax.set_ylim(miny,maxy)
    return ax.quiver(x,y,sigmasxy[:,i,0],sigmasxy[:,i,1],angles='xy', scale_units='xy', scale=1)

# one frame per noise field
anim = FuncAnimation(fig, animate, frames=J, interval=500, repeat=False)

video = anim.to_html5_video()
html = display.HTML(video)
display.display(html)
# close this figure explicitly so the static frame is not rendered as well
plt.close(fig)

In [None]:
# initialize with specified noise kernel
from src.stochastics import Eulerian
Eulerian.initialize(M,k=k)

# landmarks on [-.5,.5] x {0}, each with velocity (0,1); p is the covector from M.flat
landmark_xs = np.linspace(-.5,.5,M.N)
q = M.coords(jnp.vstack((landmark_xs,np.zeros(M.N))).T.flatten())
v = jnp.array(jnp.vstack((np.zeros(M.N),np.ones(M.N))).T.flatten())
p = M.flat(q,v)

# integrate: one noise increment per noise field (J of them) and time step
_dts = dts(n_steps=1000)
noise = dWs(J,_dts)
(ts,qps,charts) = M.Eulerian(q,p,sigmas_x,sigmas_a,_dts,noise)

# plot the position component of the path, with the noise field centres as crosses
M.newfig()
M.plot()
positions = qps[:,0,:]
M.plot_path(zip(positions,charts))
plt.plot(sigmas_x[:,0],sigmas_x[:,1],'x')
plt.show()

# Most probable paths for Arnaudon-Cruzeiro model

In [None]:
# define domain manifold
from src.manifolds.Euclidean import *
N = Euclidean(2)

# Cometric assembled from the noise fields evaluated at the point x[0]:
# the einsum contracts the noise field index r (and sums the point indices p,q)
# and leaves a matrix indexed by the two spatial components i,j.
N.gsharp = lambda x: jnp.einsum('pri,qrj->ij',sigmas(x[0]),sigmas(x[0]))
# Remove the existing attribute g before metric.initialize runs
# (presumably so that g is rederived from the gsharp set above -- TODO confirm).
delattr(N,'g')

# Riemannian structure
from src.Riemannian import metric
metric.initialize(N)
# Curvature
from src.Riemannian import curvature
curvature.initialize(N)

print(N)
# metric matrix at the origin of the chart returned by N.chart()
print(N.g((jnp.zeros(N.dim),N.chart())))

In [None]:
# N.logAbsDet evaluated column-wise on a (dim, n_points) array of coordinates.
# Deliberately not called `f`: that name is used for the Onsager-Machlup term
# defined in a later cell and read from the notebook globals by the MPP
# integrator, so rebinding it here would break that cell when this one is re-run.
logabsdet_on_grid = jax.vmap(lambda x: N.logAbsDet((x,N.chart())),1)

# 30 x 30 evaluation grid on the plotting square
x = np.linspace(minx, maxx, 30)
y = np.linspace(miny, maxy, 30)

X, Y = np.meshgrid(x, y)
Z = logabsdet_on_grid(jnp.vstack((X.flatten(),Y.flatten()))).reshape(X.shape)

# surface plot of the values over the grid
ax = plt.axes(projection='3d')
s = ax.plot_surface(X,Y,Z,cmap='viridis')


ax.view_init(10, 35)
plt.colorbar(s)
plt.show()

In [None]:
# Hamiltonian dynamics
from src.dynamics import Hamiltonian
Hamiltonian.initialize(M)

# u_t as \sum_{i=1}^N K(\cdot,q_i)p_i
# x is a (coordinates, chart) pair on N; qp stacks positions (row 0) and
# momenta (row 1) of the landmark configuration.
u = lambda x,qp: jnp.dot(M.K(x[0],qp[0,:]),qp[1,:])

# scalar part of elliptic operator L = 1/2 \Delta_g + z
# NOTE(review): z as written is vector-valued (every term carries the free
# index j), i.e. it looks like the first-order term of L -- confirm the wording.
# It adds to u a term built from N.gsharp and the gradient of N.logAbsDetsharp,
# minus twice a contraction of the noise fields with their Jacobian.
z = lambda x,qp: u(x,qp)+(jnp.einsum('ij,i->j',N.gsharp(x),gradx(N.logAbsDetsharp)(x))
                          -2*jnp.einsum('...rj,...rii->j',sigmas(x[0]),jax.jacrev(sigmas)(x[0])) )

# Onsager-Machlup deviation from geodesic energy
# (the commented version spells the divergence out with Christoffel symbols;
#  the active version delegates it to N.divsharp)
# f = lambda x,qp: .5*jnp.einsum('rs,sr->',N.gsharp(x),
#                                    jacrevx(z)(x,qp)+jnp.einsum('k,srk->sr',z(x,qp),N.Gamma_g(x)))-1/12*N.S_curv(x)
f = lambda x,qp: .5*N.divsharp(x,lambda x: z(x,qp))-1/12*N.S_curv(x)

def initialize(M):
    """ Most probable paths for Arnaudon-Cruzeiro models

    Attaches M.MPP_AC(x,v,qps,dqps,dts) to M: a jitted integrator of the
    second-order ODE defined by ode_MPP_AC, started at the point x (a
    (coordinates, chart) pair) with initial derivative v and driven by the
    landmark path qps and its time derivative dqps.

    The ODE itself is evaluated on the notebook-global manifold N and uses the
    notebook-global functions z and f.
    """

    def ode_MPP_AC(c,y):
        """ Right-hand side of the MPP ODE in first-order form.

        c = (t, xx1, chart), where xx1 stacks the point and its derivative;
        y = (qp, dqp), the landmark state and its time derivative at this step.
        Returns the stacked time derivatives of (point, derivative).
        """
        t,xx1,chart = c
        qp,dqp = y
        x = xx1[0] # point
        x1 = xx1[1] # derivative
        
        # metric, cometric and Christoffel symbols of N at the current point
        g = N.g((x,chart))
        gsharp = N.gsharp((x,chart))
        Gamma = N.Gamma_g((x,chart))
        
        # z at the current point, its Jacobian in x, and its derivative along dqp
        # (Jacobian of z in the qp argument contracted with dqp)
        zx = z((x,chart),qp)
        gradz = jacrevx(z)((x,chart),qp)
        dz = jnp.einsum('...ij,ij',jax.jacrev(z,argnums=1)((x,chart),qp),dqp)
        
        # acceleration: dz, the quadratic Christoffel term in x1, two terms built
        # from gradz plus a Gamma-z contraction (presumably the covariant
        # derivative of z -- TODO confirm), and the gsharp-raised gradient of f
        dx2 = (dz-jnp.einsum('i,j,kij->k',x1,x1,Gamma)
               +jnp.einsum('i,ki->k',x1,gradz+jnp.einsum('kij,j->ki',Gamma,zx))
               -jnp.einsum('rs,ri,s,ik->k',g,gradz+jnp.einsum('j,rij->ri',zx,Gamma),x1-zx,gsharp)
               +jnp.einsum('ik,i',gsharp,gradx(f)((x,chart),qp))
            )
        dx1 = x1
        return jnp.stack((dx1,dx2))

    def chart_update_MPP_AC(xv,chart,y):
        """ Optionally move the state (point, derivative) to a new chart.

        Returns (xv,chart) unchanged when M.do_chart_update is None; otherwise
        selects between the old and the re-expressed state with jnp.where.
        """
        if M.do_chart_update is None:
            return (xv,chart)
    
        v = xv[1]
        x = (xv[0],chart)

        # NOTE(review): the chart helpers are taken from M although the ODE
        # state is evaluated on N -- confirm this is intended.
        update = M.do_chart_update(x)
        new_chart = M.centered_chart(x)
        new_x = M.update_coords(x,new_chart)[0]
    
        return (jnp.where(update,
                                jnp.stack((new_x,M.update_vector(x,new_x,new_chart,v))),
                                xv),
                jnp.where(update,
                                new_chart,
                                chart))
    
    # state = stack of (initial coordinates, initial derivative); x[1] is the chart
    M.MPP_AC = jit(lambda x,v,qps,dqps,dts: integrate(ode_MPP_AC,chart_update_MPP_AC,jnp.stack((x[0],v)),x[1],dts,qps,dqps))

# attach the integrator to the landmark manifold
initialize(M)

In [None]:
# checks
# landmarks on [-.5,.5] x {-.5}, each with velocity (0,1)
q = M.coords(jnp.vstack((np.linspace(-.5,.5,M.N),np.zeros(M.N)-.5)).T.flatten())
v = 1.*jnp.array(jnp.vstack((np.zeros(M.N),np.ones(M.N))).T.flatten())
p = M.flat(q,v)
_dts = dts(n_steps=100)
(_,qps,charts_qp) = M.Hamiltonian_dynamics(q,p,_dts)
# finite-difference time derivative of the landmark path: gradient along time, divided by dt
qps_increments = jnp.gradient(qps,axis=0)
dqps = jnp.einsum('t...,t->t...',qps_increments,1/_dts)

pts = 40
# x,y = np.meshgrid(np.linspace(minx,maxx,pts),np.linspace(miny,maxy,pts))
grid_1d = np.linspace(-1.2,1.2,pts)
x,y = np.meshgrid(grid_1d,np.linspace(-1.2,1.2,pts))
x = x.flatten()
y = y.flatten()
xy = jnp.vstack((x,y)).T

# u, z and the gsharp-raised gradient of f, each vectorised over the grid points
# (first argument mapped, landmark state shared)
us = jax.vmap(lambda x,qp: u((x,N.chart()),qp),(0,None))
zs = jax.vmap(lambda x,qp: z((x,N.chart()),qp),(0,None))
gradfs = jax.vmap(
    lambda x,qp: jnp.einsum('ik,i',N.gsharp((x,N.chart())),gradx(f)((x,N.chart()),qp)),
    (0,None))

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython import display

# Turn off matplotlib plot in Notebook
plt.ioff()

# left panel: field z; right panel: gsharp-raised gradient of f
fig, ax = plt.subplots(1, 2, figsize=(14, 6))

# only every skip-th time step is rendered
skip = 5
def animate(i):
    """Render frame i: landmark trajectories up to time step i*skip on both
    panels, with z (left) and the raised gradient of f (right) as quiver plots."""
    i *= skip
    
    for panel in ax:
        panel.clear()
        # clear() resets the view limits, so they must be re-applied for every
        # frame; setting them once before the animation has no lasting effect
        panel.set_xlim(minx,maxx)
        panel.set_ylim(miny,maxy)
    
#     uxy = us(xy,qps[i])
    zxy = zs(xy,qps[i])
    gradfxy = gradfs(xy,qps[i])
    
    for j in range(M.N):
        # trajectory of landmark j up to (excluding) the current step
        ax[0].plot(qps[0:i,0,j*M.m],qps[0:i,0,j*M.m+1],linewidth=5)
        ax[1].plot(qps[0:i,0,j*M.m],qps[0:i,0,j*M.m+1],linewidth=5)
    ax[0].quiver(x,y,zxy[:,0],zxy[:,1],angles='xy', scale_units='xy', scale=1, color='b')
    ax[1].quiver(x,y,gradfxy[:,0],gradfxy[:,1],angles='xy', scale_units='xy', scale=1, color='k')
    
    return ax

anim = FuncAnimation(fig, animate, frames=_dts.shape[0]//skip, interval=100, repeat=False)

video = anim.to_html5_video()
html = display.HTML(video)
display.display(html)
# close this figure explicitly so the static frame is not rendered as well
plt.close(fig)

In [None]:
# integrate and plot MPPs
N.newfig()
N.plot()
plt.plot(sigmas_x[:,0],sigmas_x[:,1],'x')
for i in range(M.N):
    # entries of landmark i inside the flattened landmark vectors
    lm = slice(i*N.dim,(i+1)*N.dim)
    # start the MPP at landmark i with that landmark's initial velocity
    x0 = (q[0][lm],N.chart())
    (_,xx1,charts) = M.MPP_AC(x0,v[lm],qps,dqps,_dts)

    # MPP (default colour) against the trajectory of landmark i (red)
    mpp_points = xx1[:,0,:]
    N.plot_path(zip(mpp_points,charts))
    N.plot_path(zip(qps[:,0,lm],charts_qp),color='r')

plt.show()

In [None]:
# MPP log

method='BFGS'

def loss(x,v,y,qps,dqps,_dts):
    """Mean squared distance between the endpoint of the MPP started at x with
    initial derivative v and the target y, compared in the endpoint's chart."""
    (_,xx1,charts) = M.MPP_AC(x,v,qps,dqps,_dts)
    x_end = xx1[-1,0]
    chart_end = charts[-1]
    # express the target in the chart of the endpoint before comparing
    y_end = M.update_coords(y,chart_end)
    return 1./N.dim*jnp.sum(jnp.square(x_end - y_end[0]))
# autodiff gradient with respect to v; replaced right below by finite differences
dloss = jax.grad(loss,1)
from scipy.optimize import approx_fprime
def dloss(x,v,y,qps,dqps,_dts):
    """Forward-difference approximation (step 1e-4) of the gradient of loss in v."""
    return approx_fprime(v,lambda w: loss(x,w,y,qps,dqps,_dts),1e-4)

from scipy.optimize import minimize,fmin_bfgs,fmin_cg
def shoot(x,y,qps,v0=None):
    """Shooting method for the MPP logarithm.

    Finds the initial derivative v for which the most probable path started at
    x ends at y, by minimising loss over v with scipy.optimize.minimize.

    x:   start point, a (coordinates, chart) pair
    y:   target point, a (coordinates, chart) pair
    qps: landmark path driving the MPP equation, one entry per time step
    v0:  initial guess for v (defaults to zero)

    Returns (v, loss value at v).
    """

    if v0 is None:
        v0 = jnp.zeros(N.dim)

    # Use the landmark path that was passed in rather than recomputing one
    # from the notebook globals q and p. The time grid has one step per entry
    # of qps, matching how qps is produced from dts(n_steps=...) elsewhere.
    _dts = dts(n_steps=qps.shape[0])
    # finite-difference time derivative of the landmark path
    dqps = jnp.einsum('t...,t->t...',jnp.gradient(qps,axis=0),1/_dts)
    
    # the objective returns (value, gradient), hence jac=True
    res = minimize(lambda w: (loss(x,w,y,qps,dqps,_dts),dloss(x,w,y,qps,dqps,_dts)), v0, method=method, jac=True, options={'disp': False, 'maxiter': 100})
#     res = minimize(lambda w: loss(x,w,y,qps,dqps,_dts), v0, method=method, jac=False, options={'disp': False, 'maxiter': 100})

    print(res)
    
    return (res.x,res.fun)

M.Log_MPP_AC = shoot

In [None]:
# landmark path and its finite-difference time derivative on a 100-step grid
_dts = dts(n_steps=100)
(_,qps,charts_qp) = M.Hamiltonian_dynamics(q,p,_dts)
qps_increments = jnp.gradient(qps,axis=0)
dqps = jnp.einsum('t...,t->t...',qps_increments,1/_dts)

# plot
N.newfig()
N.plot()
plt.plot(sigmas_x[:,0],sigmas_x[:,1],'x')
for i in range(M.N):
    # entries of landmark i inside the flattened landmark vectors
    lm = slice(i*N.dim,(i+1)*N.dim)
    start = (q[0][lm],N.chart())
    target = (qps[-1,0,lm],N.chart())
    # initial derivative that shoots the MPP from the landmark's start to its end position
    vv = M.Log_MPP_AC(start,target,qps)[0]
    
    (_,xx1,charts) = M.MPP_AC((q[0][lm],N.chart()),vv,qps,dqps,_dts)

    # MPP (default colour) against the trajectory of landmark i (red)
    N.plot_path(zip(xx1[:,0,:],charts))
    N.plot_path(zip(qps[:,0,lm],charts_qp),color='r')

plt.show()

# High-dimensional system

In [None]:
from scipy import misc
import jax.scipy as jsp

# generate stochastic images
global key
keys = jax.random.split(key)
key = keys[0]
subkeys = keys[1:]
%time images = jnp.sqrt(T/n_steps)*random.normal(subkeys[0],(n_steps,64,64,M.m))
image = images[0,:,:]

print("Size of noise basis: ", images.shape[1]*images.shape[2])

# Smooth the noisy image with a 2D Gaussian smoothing kernel.
scale = 1
x = jnp.linspace(-3, 3, 17)
window = jsp.stats.norm.pdf(x,0,scale) * jsp.stats.norm.pdf(x[:, None],0,scale)
%time smooth_image = jax.vmap(lambda im: jsp.signal.convolve(im, window, mode='same'),2,2)(image)

# plot
fig, ax = plt.subplots(M.m, 2, figsize=(12, 10))
for i in range(M.m):
    ax[i,0].imshow(image[:,:,i], cmap='binary_r')
    ax[i,0].set_title('original')
    ax[i,1].imshow(smooth_image[:,:,i], cmap='binary_r')
    ax[i,1].set_title('convolved');


In [None]:
# resample the smoothed image at a 35 x 35 grid of pixel coordinates in [0,63]^2
# using linear interpolation (order=1), one channel at a time
X, Y = np.meshgrid(jnp.linspace(0,63,35),jnp.linspace(0,63,35))
new_image = jax.vmap(lambda im: jsp.ndimage.map_coordinates(im.T,jnp.vstack((X.flatten(),Y.flatten())),order=1),
                     2,1)(smooth_image).reshape(X.shape+(M.m,))
# one row per channel: smoothed image on the left, resampled image on the right
fig, ax = plt.subplots(M.m, 2, figsize=(12, 10))
for i in range(M.m):
    ax[i,0].imshow(smooth_image[:,:,i], cmap='binary_r')
    ax[i,0].set_title('convolved')
    ax[i,1].imshow(new_image[:,:,i], cmap='binary_r')
    ax[i,1].set_title('interpolated');

In [None]:
from matplotlib import cm
# surface plot of the smoothing kernel `window` over its sampling grid
(_,ax) = newfig3d()
X, Y = np.meshgrid(x,x)
surf = ax.plot_surface(X,Y,window,cmap=cm.coolwarm)
# Customize the z axis.
ax.set_zlim(0, .2)
# Attach the colorbar to the figure that owns this axes. The global `fig`
# still refers to the figure created in an earlier cell, so using it here
# would put the colorbar on the wrong figure.
ax.figure.colorbar(surf, shrink=0.5, aspect=5, ax=ax)
plt.show()

In [None]:
# smoothing kernel for Q^{1/2}
# outer product of a standard normal density sampled at kernel_dim points on [-3,3]
kernel_dim = 17
x = jnp.linspace(-3, 3, kernel_dim)
window = jsp.stats.norm.pdf(x) * jsp.stats.norm.pdf(x[:, None],0)
# smooth every channel (axis 2) of a noise image with the kernel
convolve = jax.vmap(lambda dW: jsp.signal.convolve(dW, window, mode='same'),2,2)
# linearly interpolate every channel of the smoothed field at the rows of q
interpolate = jax.vmap(lambda Q12dW,q: jsp.ndimage.map_coordinates(Q12dW.T,q.T,order=1),(2,None),1)

def sde_Eulerian_infdim_noise(q,dW):
    """Noise increment at the landmark positions q: dW multiplied on Q^{1/2}
    (convolution with the kernel), then evaluated at each landmark."""
    landmark_points = q.reshape((M.N,M.m))
    # multiply noise on Q^{1/2}, then evaluate at the landmark positions
    return interpolate(convolve(dW),landmark_points)

# generate noise for all t
global key
keys = jax.random.split(key)
key = keys[0]
subkeys = keys[1:]
%time dW = jnp.sqrt(T/n_steps)*random.normal(subkeys[0],(n_steps,64,64,M.m))

# evaluate sde function on position q and noise dW
sde_Eulerian_infdim_noise(q[0],dW[0])
%time sde_Eulerian_infdim_noise(q[0],dW[0])
None

In [None]:
# 64 landmarks on the unit circle.
# Use M.setN, as the Langevin cell above does, rather than assigning M.N
# directly, so that whatever setN keeps consistent with N is updated as well
# (presumably M.dim -- TODO confirm).
M.setN(64)
# endpoint=False: with the endpoint included, the angles 0 and 2*pi would
# place two landmarks on the same point of the circle
phis = jnp.linspace(0,2*jnp.pi,M.N,endpoint=False)
q = M.coords(jnp.vstack((jnp.cos(phis),jnp.sin(phis))).T.flatten())

# # plot
# M.newfig()
# M.plot()
# M.plotx(q)
# plt.show()

In [None]:
# generate noise for all t
# split the notebook-level PRNG key: keep the first half, draw from the rest
keys = jax.random.split(key)
key, subkeys = keys[0], keys[1:]
dW = jnp.sqrt(T/n_steps)*random.normal(subkeys[0],(n_steps,64,64,M.m))

def coords_to_pixels(q):
    """Map landmark coordinates to pixel coordinates: scale by 24 and shift by (32,32)."""
    return (24*q.reshape((-1,M.m))+jnp.array([32,32])[np.newaxis,:])

# smoothing kernel for Q^{1/2}
kernel_dim = 17
x = jnp.linspace(-3, 3, kernel_dim)
def window(amp,scale):
    """2D kernel on the kernel_dim x kernel_dim grid: outer product of two normal
    densities with standard deviation `scale`, multiplied by amp*scale*sqrt(2*pi)."""
    return amp*scale*jnp.sqrt(2*jnp.pi)*jsp.stats.norm.pdf(x,0,scale) * jsp.stats.norm.pdf(x[:, None],0,scale)
# inner map: smooth every channel (axis 2) of one noise image ...
_convolve_image = jax.vmap(lambda amp,scale,dW: jsp.signal.convolve(dW, window(amp,scale), mode='same'),(None,None,2),2)
# ... outer map: do so for every time step (axis 0)
convolve = jax.vmap(_convolve_image,(None,None,0),0)
# linearly interpolate every channel of the smoothed field at the landmarks' pixel coordinates
interpolate = jax.vmap(lambda Q12dW,q: jsp.ndimage.map_coordinates(Q12dW.T,coords_to_pixels(q).T,order=1),(2,None),1)
    
def sde_Eulerian_infdim_noise(c,y):
    t,q,_ = c
    dt,sqrtQdW = y

    X = None # to be implemented
    det = jnp.zeros_like(q)
    # evaluate at x
    sto = interpolate(sqrtQdW,q.reshape((M.N,M.m))).flatten()
    return (det,sto,X)

Eulerian_q = lambda q,dts,dW: integrate_sde(sde_Eulerian_infdim_noise,integrator_ito,None,q[0],q[1],dts,dW)
Eulerian = lambda q,amp,scale,dts,dW: Eulerian_q(q,dts,convolve(amp,scale,dW))[0:3]
    
# integrate
amp = .1
scale = 10
Eulerian(q,amp,scale,dts(),dW)
%time (ts,qs,charts) = Eulerian(q,amp,scale,dts(),dW)
print(qs.shape)

# # plot
# M.newfig()
# M.plot()
# M.plot_path(zip(qs,charts))
# plt.show()

In [None]:
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython import display

# Turn off matplotlib plot in Notebook
plt.ioff()

fig, ax = plt.subplots()

# single artist that is updated in place for every frame
(line,) = ax.plot([],'*')

ax.set_xlim(-1.2, 1.2)
ax.set_ylim(-1.2, 1.2)


def animate(i):
    """Move the markers to the landmark positions at time step i."""
    frame_points = qs[i].reshape((-1,M.m))
    line.set_data(frame_points[:,0],frame_points[:,1])
    return line

# animate(0)
# plt.show()

anim = FuncAnimation(fig, animate, frames=n_steps, interval=100)

video = anim.to_html5_video()
html = display.HTML(video)
display.display(html)
plt.close()

In [None]:
# # checks
# N.newfig()
# N.plot()
# qs = qps[:,0,:].reshape((-1,M.N,M.m))
# qsint = (qps[0][jnp.newaxis]+jnp.cumsum(np.einsum('t...,t->t...',dqps,_dts),axis=0))[:,0,:].reshape((-1,M.N,M.m))
# for i in range(M.N):
#     plt.plot(qs[:,i,0],qs[:,i,1])
#     plt.plot(qsint[:,i,0],qsint[:,i,1],'*')    
# plt.show()

# N.newfig()
# N.plot()

# qs = qps[:,0,:].reshape((-1,M.N,M.m))
# qsint = (qps[0][jnp.newaxis]+jnp.cumsum(np.einsum('t...,t->t...',dqps,_dts),axis=0))[:,0,:].reshape((-1,M.N,M.m))
# for i in range(M.N):
#     (_,xx1,charts) = M.MPP_AC((q[0][i*N.dim:(i+1)*N.dim],N.chart()),v[i*N.dim:(i+1)*N.dim],qps,dqps,_dts)
#     xs = xx1[:,0]
#     x1s = xx1[:,1]
#     dxs = jnp.einsum('t...,t->t...',jnp.gradient(xs,axis=0),1/_dts)
    
#     xsint1 = xs[0][jnp.newaxis]+jnp.cumsum(np.einsum('t...,t->t...',dxs,_dts),axis=0)
#     xsint2 = xs[0][jnp.newaxis]+jnp.cumsum(np.einsum('t...,t->t...',x1s,_dts),axis=0)
    
#     plt.plot(xs[:,0],xs[:,1])
#     plt.plot(xsint1[:,0],xsint1[:,1],'x')    
#     plt.plot(xsint2[:,0],xsint2[:,1],'*')    
# plt.show()