In [None]:
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
#

# LDDMM landmark dynamics

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Set up the landmark manifold.
# NOTE(review): names used throughout this notebook without an explicit import
# (jax, jnp, np, plt, vmap, dts, integrate, ...) presumably come from these star
# imports -- confirm before replacing them with explicit imports.
from jaxgeometry.manifolds.landmarks import *
M = landmarks(4)  # presumably 4 landmarks in the plane (M.N == 4) -- confirm
print(M)
from jaxgeometry.plotting import *

# random key; later cells always split it before drawing so runs are reproducible
key = jax.random.PRNGKey(42)

In [None]:
# Riemannian structure
# Riemannian structure: attach jaxgeometry's metric functionality to M.
# Must run before cells that use metric-derived functions on M
# (presumably e.g. M.flat used below -- confirm).
from jaxgeometry.Riemannian import metric
metric.initialize(M)

In [None]:
# example configuration
M.k_sigma = jnp.diag(jnp.array([.5,.5]))

# Initial landmark positions. By default they are drawn at random; set this flag to
# False to use the deterministic configuration of landmarks evenly spaced on y = 1.
RANDOM_INITIAL_CONFIGURATION = True

# Split the key unconditionally so downstream random draws do not depend on the flag.
key, subkey = jax.random.split(key)
q = M.coords(jnp.vstack((np.linspace(-.5,.5,M.N),1+np.zeros(M.N))).T.flatten())
if RANDOM_INITIAL_CONFIGURATION:
    # same shape as the deterministic configuration, entries ~ N(0, 1)
    q = M.coords(jax.random.normal(subkey, shape=q[0].shape))

# unit upward velocity (0, 1) at every landmark, flattened as (x_1, y_1, x_2, y_2, ...)
v = jnp.vstack((np.zeros(M.N),np.ones(M.N))).T.flatten()
# M.flat maps the vector v at q to a covector, used as the momentum p
p = M.flat(q,v)
print("q = ", q)
print("v = ", v)
print("p = ", p)

## Hamiltonian systems

In [None]:
# define Erlend's Hamiltonian
#@jax.jit
#def H(q,p):
#    q = q[0].reshape((-1,2)).view(complex)
#    p = p.reshape((-1,2)).view(complex)
#    return (.5*jnp.sum(p.real)**2 + .5*jnp.sum((q.conj()**3)*p).imag**2).real

#@jax.jit
#def H(z, p):
#    z = z[0].reshape((-1,2))
#    p = p.reshape((-1,2))
#    # Extract real and imaginary parts
#    z_real = z[:, 0]
#    z_imag = z[:, 1]
#    p_real = p[:, 0]
#    p_imag = p[:, 1]
#    
#    # Compute complex conjugates
#    z_conj = z_real - 1j * z_imag
#    
#    # Compute |z_j|^2
#    abs_z_squared = jnp.tanh(z_real**2 + z_imag**2)
#    
#    # Compute |z_j|^4
#    abs_z_fourth = abs_z_squared**2
#    
#    # Compute Re(z_j * p_j)
#    real_part = z_real * p_real + z_imag * p_imag
#    
#    # Compute Im(z_j * p_j)
#    imag_part = z_real * p_imag - z_imag * p_real
#    
#    # Compute each term
#    term1 = jnp.sum(abs_z_squared * real_part)**2
#    term2 = jnp.sum(abs_z_squared * imag_part)**2
#    term3 = jnp.sum(abs_z_fourth * real_part)**2
#    term4 = jnp.sum(abs_z_fourth * imag_part)**2
#    
#    # Compute the final result
#    H = (term1 + term2 + term3 + term4) / 2
#    
#    return H

@jax.jit
def H(z, p):
    """Hamiltonian on planar landmarks, written in complex notation.

    Each landmark (x_j, y_j) is read as z_j = x_j + i*y_j, and each momentum pair
    likewise as p_j. With theta_j = arg(z_j) and w_j = tanh(|z_j|^2), define
        A = sum_j w_j   * exp(-i*theta_j) * p_j
        B = sum_j w_j^2 * exp(-i*theta_j) * p_j
    and return 2 * (Re(A)^2 + Im(A)^2 + Re(B)^2 + Im(B)^2) = 2 * (|A|^2 + |B|^2).

    Args:
        z: tuple whose first entry is the flat position vector
           (x_1, y_1, x_2, y_2, ...); the remaining entries (e.g. a chart) are ignored.
        p: flat momentum vector with the same interleaved layout.

    Returns:
        Scalar value of the Hamiltonian.
    """
    # NOTE(review): the commented-out real-valued variant above divides by 2, while
    # this one multiplies by 2 -- confirm the intended normalisation.
    coords = z[0].reshape((-1, 2))
    moms = p.reshape((-1, 2))
    z_complex = coords[:, 0] + 1j * coords[:, 1]
    p_complex = moms[:, 0] + 1j * moms[:, 1]

    def weighted_terms(zj, pj):
        # rotate p_j by -arg(z_j): multiply by the unit phase exp(-i*theta_j)
        theta = jnp.angle(zj)
        rotated = (jnp.cos(-theta) + 1j * jnp.sin(-theta)) * pj
        w = jnp.tanh(jnp.abs(zj) ** 2)
        w2 = w ** 2
        re = jnp.real(rotated)
        im = jnp.imag(rotated)
        return w * re, w * im, w2 * re, w2 * im

    a_re, a_im, b_re, b_im = (jnp.sum(part) for part in vmap(weighted_terms)(z_complex, p_complex))
    return 2 * (a_re ** 2 + a_im ** 2 + b_re ** 2 + b_im ** 2)

M.H = H

# Hamiltonian dynamics
from jaxgeometry.dynamics import Hamiltonian
Hamiltonian.initialize(M)

# momentum
p = 1e-1*jnp.ones(M.N*2)

print(M.H(q,p))

# geodesic (landmark trajectory only)
(qs,charts) = M.Exp_Hamiltoniant(q,p)

M.plot()
M.plot_path(zip(qs,charts))
plt.show()

# dynamics returning both position and momentum
n_steps = 1000
_dts = dts(n_steps=n_steps)
(ts,qps,charts_dyn) = M.Hamiltonian_dynamics(q,p,_dts)
qs = qps[:,0]; ps = qps[:,1]
# Evaluate H along this trajectory using the charts returned by *this* integration.
# The `charts` from Exp_Hamiltoniant above come from a separate integration (possibly
# with a different number of steps), and zip() would silently truncate to its length.
energies = np.array([M.H((q_t,chart_t),p_t) for (q_t,p_t,chart_t) in zip(qs,ps,charts_dyn)])
print("Energy: ",energies)

## Boundary value problem

In [None]:
# Logarithm map: solve the boundary value problem q -> v by shooting with M.Exp_Hamiltonian
from jaxgeometry.Riemannian import Log
Log.initialize(M,f=M.Exp_Hamiltonian)

# random target configuration: a perturbation of the initial configuration q
key, subkey = jax.random.split(key)
v = M.coords(q[0]+.1*jax.random.normal(subkey, shape=(M.N*2,)))

tol = 1e-6  # stop restarting once the shooting error is below this tolerance
max_iterations = 10  # maximal number of random restarts

best_err = float('inf')
best_p_Log = None
best_v0 = None

# Random restarts: keep the momentum with the smallest error returned by M.Log.
n_iterations = 0
for i in range(max_iterations):
    n_iterations = i + 1  # i is 0-based; count iterations actually performed
    key, subkey = jax.random.split(key)
    # initial guess has the same shape as the flat coordinate vector (= momentum shape)
    v0 = 1e-1*jax.random.normal(subkey, shape=q[0].shape)
    p_Log,err = M.Log(q,v,v0=v0)

    print("Iteration:", i, "Error:", err)
    if err < best_err:
        best_err = err
        best_p_Log = p_Log
        best_v0 = v0

    if best_err < tol:
        break

# err < inf is False for NaN/inf errors, so every restart may have been rejected
if best_p_Log is None:
    raise RuntimeError("Log did not return a finite error for any initial guess")

p_Log, err, v0 = best_p_Log, best_err, best_v0

print("Number of iterations:", n_iterations)
print("Initial guess:", v0)
print("Final p_Log:", p_Log)
print("Final error:", err)
print("Energy:", M.H(q,p_Log))

# shoot with the best momentum and compare the endpoint with the target v
(qs,charts) = M.Exp_Hamiltoniant(q,p_Log)
print("Err (square):",1/M.dim*jnp.sum((jnp.square(qs[-1]-v[0]))))
print("Err (max landmark dist): ",jnp.max(jnp.linalg.norm((qs[-1]-v[0]).reshape((-1,2)),axis=1)))
M.plot()
M.plotx(q,color='k')
M.plotx(v,color='r')
M.plot_path(zip(qs,charts),v,linewidth=1.5)
plt.show()

## Visualization of the diffeomorphism solving the boundary value problem
(This is older code that has not been fully updated for the current problem, but it may be useful for visualizing the generated flows later on.)

In [None]:
n_landmark = M.N

# Bounding box of the initial (q) and target (v) landmark configurations, padded by 1
landmark_pts = np.vstack((np.reshape(q[0],(n_landmark,2)), np.reshape(v[0],(n_landmark,2))))
minx, miny = np.min(landmark_pts, axis=0) - 1
maxx, maxy = np.max(landmark_pts, axis=0) + 1

pts_v = 100  # grid points along x (points per horizontal grid line)
pts_h = 100  # grid points along y (points per vertical grid line)
K = pts_h*pts_v  # number of evaluation points
# meshgrid gives (pts_h, pts_v) arrays; flattening row-major stores the grid row by row
x,y = np.meshgrid(np.linspace(minx,maxx,pts_v),np.linspace(miny,maxy,pts_h))
x = x.flatten(); y = y.flatten()
xy = jnp.vstack((x,y)).T

# flow arbitrary points of the domain
def ode_Hamiltonian_advect(c,y):
    """Velocity of arbitrary domain points induced by the landmark state.

    c: integrator carry (t, x, chart); only x, an (n, M.m) array of points, is used.
    y: 1-tuple holding the landmark state qp with qp[0] = positions, qp[1] = momenta
       (presumably the per-step slice of qps supplied by integrate -- confirm).
    Returns M.K(x, q) contracted with p, reshaped to (n, M.m).
    """
    t,x,chart = c
    qp, = y
    q = qp[0]
    p = qp[1]
    return jnp.tensordot(M.K(x,q),p,(1,0)).reshape((-1,M.m))

# xs = (flat point coordinates, chart); qps = precomputed landmark trajectory
M.Hamiltonian_advect = lambda xs,qps,dts: integrate(ode_Hamiltonian_advect,
                                                    None,
                                                    xs[0].reshape((-1,M.m)),
                                                    xs[1],
                                                    dts,
                                                    qps)

# landmark flow, shot with the momentum p_Log found by the Log map above, so the
# landmarks are transported from q towards the target v shown in the plots below
n_steps = 100
_dts = dts(n_steps=n_steps)
(_,qps,charts_qp) = M.Hamiltonian_dynamics(q,p_Log,_dts)

# grid/ambient flow
_,xs = M.Hamiltonian_advect((xy.flatten(),M.chart()),qps,_dts)

print(qps.shape, xs.shape)
qs = qps[:,0]  # landmark positions over time
ms = qps[:,1]  # landmark momenta over time

                                               

In [None]:
M.newfig()

# target landmark configuration as an (n_landmark, 2) array
q_0 = v[0].reshape(n_landmark,2)


def plot_deformed_grid(time_idx, title, landmark_marker_size):
    """Plot the target shape, the flowed landmarks and the deformed grid at one time index.

    time_idx: index into the leading (time) axis of `qs` and `xs`; negative values count
        from the end, so -1 is the final integration step.
    title: figure title.
    landmark_marker_size: marker size of the flowed landmarks ('+' markers).
    """
    landmarks_t = np.reshape(qs[time_idx,:],(n_landmark,2))
    # The grid was built with np.meshgrid and flattened row-major, so reshaping to
    # (pts_h, pts_v, 2) gives rows (horizontal lines) and columns (vertical lines).
    grid_t = np.reshape(np.asarray(xs[time_idx,:]),(pts_h,pts_v,2))

    M.plot()
    plt.scatter(q_0[:,0], q_0[:,1], marker='o', c = 'red', s=50)  # target shape
    plt.scatter(landmarks_t[:,0], landmarks_t[:,1],
                c = 'green', marker='+', s=landmark_marker_size)  # flowed landmarks

    # horizontal grid lines: plt.plot draws one line per column of its 2D inputs,
    # so transpose to get one line per grid row
    plt.plot(grid_t[:,:,0].T, grid_t[:,:,1].T, color="k", linewidth=0.5, alpha=0.5)
    # vertical grid lines: one line per grid column
    plt.plot(grid_t[:,:,0], grid_t[:,:,1], color="k", linewidth=0.5, alpha=0.5)

    plt.title(title)
    plt.axis('auto')
    plt.xlim([minx-1, maxx+1])
    plt.ylim([miny-1, maxy+1])
    plt.show()


plot_deformed_grid(0, "Time 0", landmark_marker_size=120)
# Use -1 for the final state: a literal index such as 100 is past the end of a
# 100-step trajectory and JAX silently clamps out-of-bounds indices.
plot_deformed_grid(-1, "End time", landmark_marker_size=300)




## Animation of the deformation, momentum and vector fields

In [None]:
### Spatial velocity fields: on a fixed (Eulerian) grid and at the moving landmarks

# fixed evaluation grid, coarser than the deformation grid so the quiver plot stays readable
pts_v_spatial=20
pts_h_spatial=20
K_spatial = pts_h_spatial*pts_v_spatial # number of evaluation points
x_spatial,y_spatial = np.meshgrid(np.linspace(minx,maxx,pts_v_spatial),np.linspace(miny,maxy,pts_h_spatial))
x_spatial = x_spatial.flatten(); y_spatial = y_spatial.flatten()
xy_spatial = jnp.vstack((x_spatial,y_spatial)).T

def ode_Hamiltonian_advect_fixed_grid(c,y):
    """Velocity at fixed points x: M.K(x, q) contracted with p, reshaped to (n, M.m).

    c: (t, x) with x an (n, M.m) array of evaluation points; t is unused.
    y: landmark state with y[0] = positions q and y[1] = momenta p.
    """
    t,x = c
    q = y[0]
    p = y[1]
    return jnp.tensordot(M.K(x,q),p,(1,0)).reshape((-1,M.m))

def ode_Hamiltonian_advect_moving_grid(y):
    """Velocity evaluated at the landmarks themselves: M.K(q, q) contracted with p.

    y: landmark state with y[0] = positions q and y[1] = momenta p.
    """
    q = y[0]
    p = y[1]
    return jnp.tensordot(M.K(q,q),p,(1,0)).reshape((-1,M.m))

# vmap over the leading (time) axis of the landmark trajectory qps
spatial_velo_grid = jax.vmap(
    lambda qp: ode_Hamiltonian_advect_fixed_grid((None, xy_spatial), qp))(qps)
spatial_velo_landmarks = jax.vmap(ode_Hamiltonian_advect_moving_grid)(qps)

In [None]:
from matplotlib.animation import FuncAnimation
from matplotlib.lines import Line2D
from IPython import display

fig, ax = plt.subplots(figsize=(20, 20))

ax.set_aspect(0.9)
ax.set_xlim(minx-1, maxx+1)
ax.set_ylim(miny-1, maxy+1)

# shape and ambient grid
# (color 'C1' is explicit so the grid keeps the colour it had from the default colour cycle)
grid, = ax.plot([], [], '.', markersize=1, color='C1')  # deformed grid points
points_deformed, = ax.plot([], [], '+', markersize=15, color='g')  # deformed shape points
points_target, = ax.plot([], [], 'o', markersize=10, color='r', alpha=1)  # target shape points

# momentum field at the landmarks; arrows start with zero length and are updated per frame
scatter_points = np.array([np.reshape(qs[0,:],(n_landmark,2))[:,0],
                           np.reshape(qs[0,:],(n_landmark,2))[:,1]]).T

Q_momentum = ax.quiver(scatter_points[:, 0], scatter_points[:, 1], np.zeros_like(scatter_points[:, 0]),
                       np.zeros_like(scatter_points[:, 1]),
                       color='g', scale=2, scale_units='xy', angles='xy',
                       alpha=0.7, width=0.003, headwidth=3, headlength=4
                       )

# spatial velocity field along the fixed grid and at the moving landmarks
Q_spatial_grid = ax.quiver(xy_spatial[:,0], xy_spatial[:,1], spatial_velo_grid[0,:,0], spatial_velo_grid[0,:,1],
                           alpha=0.5, scale=10, scale_units='xy', angles='xy', width=0.003, headwidth=3, headlength=4,
                           color='k')
Q_spatial_landmarks = ax.quiver(qps[0,0,0::2], qps[0,0,1::2], spatial_velo_landmarks[0,:,0], spatial_velo_landmarks[0,:,1],
                                alpha=0.5, scale=2, scale_units='xy', angles='xy', width=0.003, headwidth=3, headlength=4,
                                color='b')

# legend for the three quiver fields
legend_elements = [Line2D([0], [0], color='g', lw=2, label='Momentum field on landmarks'),
                   Line2D([0], [0], color='b', lw=2, label='Velocity field on landmarks'),
                   Line2D([0], [0], color='k', lw=2, label='Velocity field on grid')]
ax.legend(handles=legend_elements, loc='upper right', fontsize=20)

ax.set_title('Flow of diffeomorphisms solving the boundary value problem', fontsize=30)

def animate(t):
    """Update every artist to integration step t and return them for blitting."""
    grid.set_data(xs[t,:,0], xs[t,:,1])

    # scatter_points is updated in place; Q_momentum uses it as arrow origins
    scatter_points[:] = np.reshape(qs[t,:],(n_landmark,2))
    points_deformed.set_data(scatter_points[:,0], scatter_points[:,1])
    points_target.set_data(q_0[:,0], q_0[:,1])

    momenta_t = np.reshape(ms[t,:],(n_landmark,2))
    Q_momentum.set_offsets(scatter_points)
    Q_momentum.set_UVC(momenta_t[:,0], momenta_t[:,1])

    Q_spatial_grid.set_UVC(spatial_velo_grid[t,:,0], spatial_velo_grid[t,:,1])

    Q_spatial_landmarks.set_offsets(qps[t,0].reshape((-1,2)))
    Q_spatial_landmarks.set_UVC(spatial_velo_landmarks[t,:,0], spatial_velo_landmarks[t,:,1])

    return Q_momentum, grid, points_deformed, points_target, Q_spatial_grid, Q_spatial_landmarks


# one frame per stored integration step (independent of the global n_steps,
# which is reassigned in several cells)
anim = FuncAnimation(fig, animate, frames=qs.shape[0], interval=100, blit=True)

video = anim.to_html5_video()
html = display.HTML(video)
display.display(html)
plt.close(fig)