In [None]:
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
#

# Heisenberg group

In [None]:
# Automatically reload imported modules (e.g. edited files under src/) before each cell runs,
# so library changes take effect without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
from src.manifolds.Heisenberg import *
from src.plotting import *
# NOTE(review): later cells use jnp, np, jax, plt, newfig, dts, dWs, n_steps and optimize
# without importing them; presumably they arrive through the star imports above — confirm.
# For interactive 3-D figures, `%matplotlib widget` (ipympl) can be enabled here.

# the manifold used throughout the notebook
M = Heisenberg()
print(M)

In [None]:
# a point x on the manifold and a tangent vector v attached to it
x = M.coords([.5,0.,.5])
v = jnp.array([-.5,0.,0.])

for label,value in (("x = ",x),("v = ",v)):
    print(label,value)

# draw the manifold with the point and its tangent vector
newfig()
M.plot()
M.plotx(x,u=v,s=50,linewidth=1.5)
plt.show()

In [None]:
## sub-Riemannian structure
from src.sR import metric
metric.initialize(M)

# display the matrices D(x) and a(x) at the base point, in that order
for caption,matrix_at in (("D(x):\n",M.D),("a(x):\n",M.a)):
    print(caption,matrix_at(x))

# a covector p at x and the vector v = sharp(x, p) it is mapped to
p = jnp.array([-1.,0,-.5])
v = M.sharp(x,p)
print("v: ",v,", p: ",p)

## Geodesics from Hamiltonian equations

In [None]:
# Hamiltonian at the initial point x and covector p
print(M.H(x,p))

from src.dynamics import Hamiltonian
Hamiltonian.initialize(M)

# Exponential map from Hamiltonian equations (trajectory of points with their charts)
(xs,charts) = M.Exp_Hamiltoniant(x,p)

# plot
newfig()
M.plot()
M.plotx(x,u=v,linewidth = 1.5, s=50)
M.plot_path(zip(xs,charts),linewidth = 1.5, s=50)
plt.show()

# Energy along the flow. Positions (xps[:,0]), momenta (xps[:,1]) and charts (third
# return value, presumably the per-step charts — confirm) all come from the same
# Hamiltonian_dynamics integration, so H is evaluated on one consistent trajectory.
# The loop variables q/mom avoid reusing the names of the initial point x and covector p.
(ts,xps,charts_H) = M.Hamiltonian_dynamics(x,p,dts())
xs_H = xps[:,0,:]
ps = xps[:,1,:]
energy = np.array([M.H((q,chart),mom) for (q,mom,chart) in zip(xs_H,ps,charts_H)])
print("Energy: ",energy)

## Boundary value problem: the logarithm map

In [None]:
# Logarithm map: find the covector at x whose Hamiltonian geodesic reaches y
from src.Riemannian import Log
Log.initialize(M,f=M.Exp_Hamiltonian)

y = M.coords(jnp.array([0.,0,.0]))
# pinv(a(x)) applied to the coordinate difference y - x, passed as v0
# (presumably the solver's starting guess — confirm against Log.initialize)
p_start = jnp.dot(jnp.linalg.pinv(M.a(x)),y[0]-x[0])
p_Log = M.Log(x,y,v0=p_start)[0]
v_Log = M.sharp(x,p_Log)
print("v_Log: ",v_Log,", p_Log: ",p_Log)

# trajectory from x using the recovered covector
(xs,charts) = M.Exp_Hamiltoniant(x,p_Log)
newfig()
M.plot()
M.plot_path(zip(xs,charts),linewidth=1.5)
plt.show()

# NOTE(review): debugging check. The first line scales the first trajectory step by
# n_steps (a finite-difference velocity if the total time is 1); the factor 4 in the
# second line is not explained here — confirm what it is meant to match.
print((xs[1]-x[0])*n_steps)
print(v_Log*4)

## Brownian Motion

In [None]:
# coordinate form
from src.stochastics import Brownian_sR
Brownian_sR.initialize(M)

# a single sample path on a 1000-step time grid
_dts = dts(n_steps=1000)
(ts,xs,charts) = M.Brownian_sR(x,_dts,dWs(M.sR_dim,_dts))

# plot
newfig()
M.plot()
M.plot_path(zip(xs,charts))
plt.show()

# Multiple sample paths. One time grid is built up front, and both the noise
# increments and the buffer sizes are derived from it. This avoids assuming that
# the default dts() has exactly n_steps entries.
N = 5
_dts_paths = dts()
xss = np.zeros((N,_dts_paths.shape[0],M.dim))
chartss = np.zeros((N,_dts_paths.shape[0],x[1].shape[0]))
for i in range(N):
    (ts,xs,charts) = M.Brownian_sR(x,_dts_paths,dWs(M.sR_dim,_dts_paths))
    xss[i] = xs
    chartss[i] = charts

# plot all paths with a colour gradient and mark the start point in red
newfig()
M.plot()
colormap = plt.get_cmap('winter')
colors=[colormap(k) for k in np.linspace(0, 1, N)]
for i in range(N):
    M.plot_path(zip(xss[i],chartss[i]),color=colors[i])
M.plotx(x,color='r',s=50)
plt.show()

## Guided Brownian bridge to the origin

In [None]:
def guide(x,v):
    """ Drift guiding a Heisenberg-group point towards 0.

    x: point as a (coords, chart) pair; only the coordinates x[0] are read,
       x[0][0:2] being the planar part and x[0][2] the vertical part.
    v: ignored (kept so the signature matches what get_guided calls).

    Returns a length-2 array b, with gamma the polar angle of the planar part:
      * if |x[0][2]| < 1e-4: b = -|x[0][0:2]| (cos gamma, sin gamma);
      * otherwise alpha minimizes, by BFGS started at pi,
            (8 sin^2(alpha/2) |x[0][2]| - |x[0][0:2]|^2 (alpha - sin alpha))^2,
        r = |x[0][0:2]| / (2 sin(alpha/2)), and
        b = -r alpha (cos phi, sin phi) with phi = gamma + sign(x[0][2]) alpha/2.
    """
    coords = x[0]
    planar = coords[0:2]
    height = coords[2]
    gamma = jnp.arctan2(planar[1],planar[0])
    rho = jnp.linalg.norm(planar)

    def residual(a):
        # squared mismatch; its minimizer is the angle alpha used in the arc branch
        return (8*jnp.sin(a[0]/2)**2*jnp.abs(height)-jnp.sum(planar**2)*(a[0]-jnp.sin(a[0])))**2

    # NOTE(review): when the planar part is 0 the residual's derivative vanishes at pi,
    # so BFGS stays at its starting point, r becomes 0 and the drift is zero — confirm intended
    alpha = optimize.minimize(residual,jnp.array([jnp.pi]),method='BFGS').x[0]
    r = rho/(2*jnp.sin(alpha/2))

    epsilon = 1e-4  # below this |height| the straight-line branch is used

    def straight(_):
        return jnp.array([-rho*jnp.cos(gamma),
                          -rho*jnp.sin(gamma)])

    def arc(_):
        phi = gamma+jnp.sign(height)*alpha/2
        return jnp.array([-r*alpha*jnp.cos(phi),
                          -r*alpha*jnp.sin(phi)])

    return jax.lax.cond(jnp.abs(height)<epsilon,straight,arc,None)

# example: guide direction at x and D(x) applied to it
b_example = guide(x,None)
print(x[0],b_example,jnp.dot(M.D(x),b_example))

# coordinate form
from src.stochastics.guided_process import *

# guided bridge; sqrtCov is the Cholesky factor of D(x)^T D(x)
(Brownian_sR_guided,sde_Brownian_sR_guided,*_) = get_guided(
    M,M.sde_Brownian_sR,M.chart_update_Brownian_sR,guide,
    sqrtCov=lambda x: jnp.linalg.cholesky(jnp.tensordot(M.D(x),M.D(x),(0,0))))

# one bridge sample with target coordinates 0 (passed as a (coords, chart) pair)
_dts = dts(n_steps=500)
(ts,xs,charts,log_likelihood,log_varphi) = Brownian_sR_guided(x,(jnp.zeros_like(x[0]),x[1]),_dts,dWs(M.sR_dim,_dts))
print(xs[-1])

# plot
newfig()
M.plot()
M.plot_path(zip(xs,charts))
plt.show()

# Planar norm |(x1,x2)| (red) and vertical |x3| (blue) along the path. Coordinate
# vectors have three entries (indices 0-2), so the vertical coordinate is index 2;
# index 3 would be out of range and only "work" through JAX's index clamping.
horizontal_norm = jax.vmap(lambda q: jnp.linalg.norm(q[0:2]))
vertical_abs = jax.vmap(lambda q: jnp.abs(q[2]))
t_grid = jnp.cumsum(_dts)
plt.plot(t_grid,horizontal_norm(xs),'r',label='planar |(x1,x2)|')
plt.plot(t_grid,vertical_abs(xs),'b',label='vertical |x3|')
plt.xlabel('t')
plt.legend()
plt.show()

In [None]:
# multiple guided sample paths: planar (red) and vertical (blue) coordinates over time
N = 10
_dts = dts(n_steps=500)
t_grid = jnp.cumsum(_dts)
# coordinate vectors have three entries (indices 0-2); index 2 is the vertical coordinate
horizontal_norm = jax.vmap(lambda q: jnp.linalg.norm(q[0:2]))
vertical_abs = jax.vmap(lambda q: jnp.abs(q[2]))
xss = np.zeros((N,_dts.shape[0],M.dim))
chartss = np.zeros((N,_dts.shape[0],x[1].shape[0]))
for i in range(N):
    # NOTE(review): the single-path example passes the target as a (coords, chart) pair,
    # while here only the coordinates are passed — confirm which form get_guided expects
    (ts,xs,charts,_,_) = Brownian_sR_guided(x,jnp.zeros_like(x[0]),_dts,dWs(M.sR_dim,_dts))
    xss[i] = xs
    chartss[i] = charts

    plt.plot(t_grid,horizontal_norm(xs),'r',label='planar |(x1,x2)|' if i == 0 else None)
    plt.plot(t_grid,vertical_abs(xs),'b',label='vertical |x3|' if i == 0 else None)
plt.xlabel('t')
plt.legend()
plt.show()