In [None]:
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
#

# Score matching on the Heisenberg group

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# The star imports below are presumably what bring jax, jnp, np, plt and the
# helpers dts, dWs, newfig into scope for the rest of the notebook -- none of
# them is imported explicitly anywhere in this file (TODO confirm).
from jaxgeometry.manifolds.Heisenberg import *
# Heisenberg group manifold object used by every cell below
M = Heisenberg()
print(M)
from jaxgeometry.plotting import *
from IPython.display import clear_output, Image
#%matplotlib notebook

In [None]:
# Seed of the JAX pseudo-random number generator.
# `key` is bound at module level; functions that rebind it (e.g. train)
# declare `global key` themselves.
seed = 42434154
#import os; seed = int(os.urandom(5).hex(), 16)
key = jax.random.PRNGKey(seed)

In [None]:
# base point: the origin (0,0,0) in the form returned by M.coords; other cells
# index it as X[0] / X[1], the same layout as the (x,chart) pair of get_coords
X = M.coords([0.,0.,0.])

## sub-Riemannian structure
from jaxgeometry.sR import metric
# presumably this is what provides M.sR_dim and M.D used below -- TODO confirm
metric.initialize(M)

## Brownian Motion

In [None]:
# register the sub-Riemannian Brownian motion on M; M.Brownian_sR is called
# right below (presumably added by this initialize call -- TODO confirm)
from jaxgeometry.stochastics import Brownian_sR
Brownian_sR.initialize(M)

# time increments for 1000 integration steps (end time left at the default of dts)
_dts = dts(n_steps=1000)
# one sample path started at X, driven by noise increments with M.sR_dim components
(ts,xs,charts) = M.Brownian_sR(X,_dts,dWs(M.sR_dim,_dts))

# plot
newfig()
M.plot()
M.plot_path(zip(xs,charts))
plt.show()

In [None]:
# product sde
# `product` is called in train() with batched initial points, shared time
# increments, per-sample noise and one extra value per sample
from jaxgeometry.stochastics import product_sde
from jaxgeometry.stochastics.product_sde import tile
(product,sde_product,chart_update_product) = product_sde.initialize(M,M.sde_Brownian_sR,M.chart_update_Brownian_sR)

In [None]:
# get coordinate representation from point in embedding space, i.e. recover x from F(x)
def get_coords(x):
    """Pair the point `x` with the chart `M.chart()`.

    `x` is used unchanged as the coordinate vector, so recovering x from F(x)
    only requires attaching the chart.

    Returns the tuple (x, chart).
    """
    return (x, M.chart())

# map embedding space tangent vector to the tangent bundle
def to_TM(Fx,v):
    """Map an embedding-space tangent vector to the tangent bundle.

    Fx -- base point in the embedding space (not used)
    v  -- tangent vector

    Returns `v` unchanged.
    """
    return v


In [None]:
# plotting
%matplotlib inline

def plot_field(X,xys,zs,view=(35,45)):
    """Show a 3D quiver plot of a vector field sampled on a regular grid.

    X    -- callable taking one point (x,y,z) and returning a vector with at
            least 3 components; it is evaluated on all grid points with jax.vmap
    xys  -- 1D array of grid values, used for both the x and the y axis
    zs   -- 1D array of grid values for the z axis
    view -- arguments forwarded to ax.view_init

    The figure is displayed with plt.show(); nothing is returned.
    """
    # grid points of the box, one (x,y,z) row per point
    xm, ym, zm = np.meshgrid(xys,xys,zs)
    points = jnp.stack((xm.flatten(),ym.flatten(),zm.flatten())).T

    # evaluate the field and split it into grid-shaped components
    s = jnp.squeeze(jax.vmap(X)(points))
    u = s[:,0].reshape(xm.shape)
    v = s[:,1].reshape(ym.shape)
    w = s[:,2].reshape(zm.shape)

    # arrow colours: twice the Euclidean norm of each vector through the jet colormap
    magnitude = 2 * np.linalg.norm(s,axis=1)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.quiver(xm, ym, zm, u, v, w, color=plt.cm.jet(magnitude), length=1., linewidth=1.0, normalize=False)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    # x and y share the range of xys; the z range is fixed and does not follow zs
    xy_limits = [jnp.min(xys), jnp.max(xys)]
    ax.set_xlim(xy_limits)
    ax.set_ylim(xy_limits)
    ax.set_zlim([-1.,1.])

    ax.view_init(*view)
    plt.show()



# Score matching

In [None]:
import wandb

import flax
from flax import linen as nn
import optax
import tensorflow_datasets as tfds
import tensorflow as tf

# wandb sweep definition: grid search over the listed values, minimizing the
# logged "loss". Only "nodes" has more than one candidate value.
sweep_configuration = {
 "name": "sR-Heisenberg",
 "metric": {"name": "loss", "goal": "minimize"},
 "method": "grid",
 "parameters": {
    "learning_rate": {"values": [.01]}, # Learning rate for optimizer
    "epochs": {"values": [2500]}, # Number of passes over the dataset
    "batches_per_epoch": {"values": [2**3]}, # batches per epoch
    "batch_size": {"values": [2**6]}, # usually multiple of samples_per_x0
    "nodes": {"values": [(10,15,15),(15,15,15),(10,10,15),(10,10,20)]}, # nodes per layer
    "activation": {"values": ['relu']}, # name of the activation function, looked up in jax.nn
    "T": {"values": [1.]}, # max diffusion time
    "n_steps": {"values": [1000]}, # time discretization
    },
}

def train(run=False,wandbmode='online'):
    """Train the score network s1 with the denoising loss `loss_s1`.

    Hyperparameters (learning_rate, epochs, batch_size, batches_per_epoch,
    nodes, activation, T, n_steps) are read from `wandb.config`.

    run       -- an existing wandb run; when left at False a new run is
                 started with wandb.init(mode=wandbmode)
    wandbmode -- wandb mode, used only when a new run is started here

    Returns the one-element tuple ((params_s1, s1),) with the trained
    parameters and the network module.

    Side effects: rebinds the module-level names `x0s` and `key`, prints
    progress, and logs the loss through wandb.log.
    """
    if run is False:
        run = wandb.init(mode=wandbmode)
    print(wandb.config)
    # hyperparameters of this run
    learning_rate = wandb.config.learning_rate
    epochs = wandb.config.epochs
    batch_size = wandb.config.batch_size
    batches_per_epoch = wandb.config.batches_per_epoch
    nodes = wandb.config.nodes
    activation = wandb.config.activation
    T = wandb.config.T
    n_steps = wandb.config.n_steps
    # time grid: increments, first increment (used as step size) and cumulative times
    _dts = dts(T=T,n_steps=n_steps)
    dt = _dts[0]
    _ts = jnp.cumsum(_dts)
    
    # all sample paths start at the origin; x0s repeats it batch_size times
    x = M.coords([0.,0.,0.])
    global x0s
    x0s = (jnp.tile(x[0],(batch_size,1)),jnp.tile(x[1],(batch_size,1)))

    def generator():
        # Endless stream of batches. Each batch concatenates, along the last
        # axis, the embedded sample paths M.F((x,chart)) with the noise
        # increments that drove them.
        while True:
            """Generates batches of samples."""
            N = batch_size
            _dWs = dWs(N*M.sR_dim,_dts).reshape(-1,N,M.sR_dim)
            global x0s
            (ts,xss,chartss,*_) = product(x0s,_dts,_dWs,jnp.repeat(1.,N))
            #x0s = (xss[-1],chartss[-1])
            yield jnp.concatenate((jax.vmap(lambda x,chart: M.F((x,chart)))(xss,chartss),_dWs),-1)

    # each dataset element has shape [n_steps, batch_size, M.dim+M.sR_dim]
    ds = tf.data.Dataset.from_generator(generator,output_types=tf.float32,
                                        output_shapes=([n_steps,batch_size,M.dim+M.sR_dim]))

    print(ds)
    print(ds.element_spec)
    ds = iter(tfds.as_numpy(ds))
    ## plot sample
    #Fxs = next(ds)
    ## plot
    #newfig()
    #M.plot()
    #for i in range(batch_size):
    #     M.plot_path(Fxs[:,i,0:M.dim])
    #     M.plotx(Fxs[-1,i,0:M.dim],color='r')
    #plt.show()

    class Net(nn.Module):
        # Fully connected network: one Dense layer plus activation per entry of
        # `nodes`, followed by a Dense layer with M.sR_dim outputs.
        @nn.compact
        def __call__(self, x):
            """Create model."""
            model = nn.Sequential(
                sum([(nn.Dense(node),getattr(jax.nn,activation)) for node in nodes],())+(nn.Dense(M.sR_dim),)
            )
            # the input is the stacked vector (t, Fx); the network sees all of
            # it and its output is divided by t
            # NOTE(review): Fx is not used below
            t = x[...,0]; Fx = x[...,1:]
            return model(x)/t

    # initialize net and parameters
    s1 = Net() # score
    s2 = Net() # diagonal of gradient of score
    s1_div = Net() # score with divergence loss
    global key
    key,subkey = jax.random.split(key)
    # NOTE(review): the same subkey is passed to all three init calls
    params_s1 = s1.init(subkey, jnp.zeros((1+M.dim)))
    params_s2 = s2.init(subkey, jnp.zeros((1+M.dim)))
    params_s1_div = s1_div.init(subkey, jnp.zeros((1+M.dim)))

    # In every loss below the inner vmap runs over the time axis (axis 0 of _ts
    # and of each path) and the outer vmap over the batch axis (axis 1 of
    # data); the result is the mean over both.

    def loss_s1(params, data):
        """ compute denoising loss """
        def f(t,xnoise): 
            # xnoise = (embedded point, noise increment) at one time step
            Fx = xnoise[0:M.dim]
            delta = sigma2 = dt; sigma = jnp.sqrt(sigma2) 
            # NOTE(review): z and sigma are only needed by the commented-out alternative below
            noise = xnoise[M.dim:]; z = noise/sigma
            _s1 = s1.apply(params,jnp.hstack((t,Fx)))
            return jnp.sum(jnp.dot(_s1,delta*_s1+2*noise)) # (5.1) in Grong, Habermann, Sommer 2024
        #    return jnp.sum(jnp.square(-z/sigma-_s1))
        v = jnp.mean(jax.vmap(
                        jax.vmap(
                            f,
                        (0,0)),
                    (None,1))(_ts,data))
        return v
    def loss_s1_Heisenberg(params, data):
        """ compute denoising loss """
        def f(t,Fxnoise,tildeFx): 
            # tildeFx is the point one time step before Fx (see tildeFx below)
            Fx = Fxnoise[0:M.dim]
            delta = sigma2 = dt; sigma = jnp.sqrt(sigma2) 
            _s1 = s1.apply(params,jnp.hstack((t,Fx)))
            hatS = -jnp.array(
                [(Fx[0]-tildeFx[0])-jnp.pi*(Fx[1]-tildeFx[1])*jnp.sign(Fx[2]-tildeFx[2]-(tildeFx[0]*Fx[1]-Fx[0]*tildeFx[1])/2)/2,
                 (Fx[1]-tildeFx[1])+jnp.pi*(Fx[0]-tildeFx[0])*jnp.sign(Fx[2]-tildeFx[2]-(tildeFx[0]*Fx[1]-Fx[0]*tildeFx[1])/2)/2]
                 )
            return jnp.sum(jnp.dot(_s1,delta*_s1-2*hatS)) # 
            # left to implement (5.5) in Grong, Habermann, Sommer 2024
            #hatS = lambda q: -jnp.array(
            #    [q[0]-jnp.pi*jnp.sign(q[2])*q[1],
            #     q[1]+jnp.pi*jnp.sign(q[2])*q[0]]
            #)
            #return jnp.sum(jnp.dot(_s1,delta*_s1-2*hatS(Delta_hatX))) # (5.5) in Grong, Habermann, Sommer 2024
        # paths shifted by one step: the initial coordinates x0s[0] followed by
        # all but the last time step of data
        tildeFx = jnp.concatenate((jnp.expand_dims(x0s[0],0),data[:-1,:,0:M.dim]),axis=0)
        v = jnp.mean(jax.vmap(
                        jax.vmap(
                            f,
                        (0,0,0)),
                    (None,1,1))(_ts,data,tildeFx))
        return v
    def loss_s2(params, data):
        """ compute loss."""
        def f(t,xnoise): 
            Fx = xnoise[0:M.dim]
            sigma2 = dt; sigma = jnp.sqrt(sigma2) 
            noise = xnoise[M.dim:]; z = noise/sigma
            # NOTE(review): both networks are applied with the same `params`;
            # this loss is not used by the training loop below
            _s1 = s1.apply(params,jnp.hstack((t,Fx)))
            _s2 = s2.apply(params,jnp.hstack((t,Fx)))
            return jnp.sum(jnp.square(_s2+jnp.square(_s1)+(jnp.ones(M.sR_dim)-jnp.square(z))/sigma2))
        v = jnp.mean(jax.vmap(
                        jax.vmap(
                            f,
                        (0,0)),
                    (None,1))(_ts,data))
        return v
    def loss_s1_div(params, data):
        """ compute loss using divergence """
        def f(t,xnoise): 
            # only the point part of xnoise is read; the noise part is ignored
            Fx = xnoise[0:M.dim]
            x,chart = get_coords(Fx)
            _s1 = s1_div.apply(params,jnp.hstack((t,Fx)))
            norm2 = jnp.sum(jnp.square(_s1))
            div = M.div((x,chart),lambda x: s1_div.apply(params,jnp.hstack((t,M.F(x)))))
            return norm2+2*div # (4.3) in Grong, Habermann, Sommer 2024
        v = jnp.mean(jax.vmap(
                        jax.vmap(
                            f,
                        (0,0)),
                    (None,1))(_ts,data))
        return v

    # Initialize solver.
    # NOTE(review): opt_state_s1_div, loss_grad_s1_Heisenberg and
    # loss_grad_s1_div are created here but only referenced from commented-out code
    opt = optax.adam(learning_rate)
    opt_state_s1 = opt.init(params_s1)
    opt_state_s1_div = opt.init(params_s1_div)
    loss_grad_s1 = jax.jit(jax.value_and_grad(loss_s1,argnums=0))
    loss_grad_s1_Heisenberg = jax.jit(jax.value_and_grad(loss_s1_Heisenberg,argnums=0))
    loss_grad_s1_div = jax.jit(jax.value_and_grad(loss_s1_div,argnums=0))

    # run training loop, s1
    # one optimizer step per batch; the loss is logged every 10 steps and
    # printed at the first step of every epoch
    print("Training s1")
    for i in range(epochs*batches_per_epoch):
      #loss_val, grads = loss_grad_s1(params_s1, next(ds))
      loss_val, grads = loss_grad_s1(params_s1, next(ds))
      updates, opt_state_s1 = opt.update(grads, opt_state_s1)
      params_s1 = optax.apply_updates(params_s1, updates)
      if i % 10 == 0:
        wandb.log({'loss': loss_val, 'epoch': i//batches_per_epoch})
      if i % batches_per_epoch == 0:
        print(f"[Step {i}], epoch {i//batches_per_epoch}, training loss: {loss_val:.3e}.")

    ## run training loop, s1 divergence loss
    #print("Training s1_div")
    #for i in range(epochs*batches_per_epoch):
    #  loss_val, grads = loss_grad_s1_div(params_s1_div, next(ds))
    #  updates, opt_state_s1_div = opt.update(grads, opt_state_s1_div)
    #  params_s1_div = optax.apply_updates(params_s1_div, updates)
    #  if i % 10 == 0:
    #    wandb.log({'loss': loss_val, 'epoch': i//batches_per_epoch})
    #  if i % batches_per_epoch == 0:
    #    print(f"[Step {i}], epoch {i//batches_per_epoch}, training loss: {loss_val:.3e}.")

    #return (params_s1, s1), (params_s2, s2), (params_s1_div, s1_div)
    #return (params_s1, s1), (params_s1_div, s1_div)
    # the trailing comma makes this a one-element tuple; the caller unpacks it
    return (params_s1, s1),

## run sweep
#sweep_id = wandb.sweep(sweep=sweep_configuration, project="sR")
#wandb.agent(sweep_id, function=train)

# single training run with fixed hyperparameters; wandb logging is disabled
config = dict(
    learning_rate=1e-2,
    epochs=2500,
    batches_per_epoch=2**3,
    batch_size=2**6,
    nodes=(15,15,15,),
    activation='relu',
    T=1.,
    n_steps=1000,
)
wandbmode = "disabled"
with wandb.init(project='sR',config=config,mode=wandbmode) as run:
    #(params_s1, s1), (params_s2, s2), (params_s1_div, s1_div) = train(run,wandbmode=wandbmode)
    #(params_s1, s1), (params_s1_div, s1_div) = train(run,wandbmode=wandbmode)
    # train returns a one-element tuple holding the (parameters, network) pair
    [(params_s1, s1)] = train(run,wandbmode=wandbmode)

In [None]:
# plot xy-slices of the result for different t
# The z grid jnp.arange(.2,1,1) holds the single value 0.2, so every figure
# shows one slice; view=(90,0) is forwarded to ax.view_init by plot_field.
for t in jnp.linspace(.2,1,5+1):
    print("t:",t)
    # learned score s1 at time t, mapped through M.D and scaled by .2*t for display
    plot_field(lambda y: .2*t*jnp.dot(M.D(get_coords(y)),s1.apply(params_s1,jnp.hstack((t,y)))),jnp.arange(-1,1,.1),jnp.arange(.2,1,1),view=(90,0))
    #plot_field(lambda y: .2*t*jnp.dot(M.D(get_coords(y)),s1_div.apply(params_s1_div,jnp.hstack((t,y)))),jnp.arange(-1,1,.1),jnp.arange(.2,1,1),view=(90,0))


In [None]:
# plot full 3D view
# Generate a vector field on a regular 15x15x15 grid over [-1,1]^3.
# The grid arrays have their own names so that the notebook-level base point
# `X` is not overwritten: a later cell reads X[0] and X[1] to build the bridge
# target, which breaks if X is a plotting grid.
grid_x, grid_y, grid_z = np.mgrid[-1:1:15j, -1:1:15j, -1:1:15j]
t = .2
# learned score at time t, mapped through M.D and scaled by t; one vector with
# M.dim components per grid point
_s1 = jax.vmap(lambda x,y,z: t*jnp.dot(M.D(get_coords(jnp.hstack((x,y,z)))),s1.apply(params_s1,jnp.hstack((t,x,y,z)))),(0,0,0))(grid_x.flatten(),grid_y.flatten(),grid_z.flatten()).reshape(grid_x.shape+(M.dim,))
u = _s1[...,0]; v = _s1[...,1]; w = _s1[...,2]

# Calculate the magnitude of vectors
magnitude = np.linalg.norm(_s1,axis=3)
#magnitude = magnitude / magnitude.max()
magnitude = magnitude / 4

# plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.quiver(grid_x, grid_y, grid_z, u, v, w, color=plt.cm.jet(magnitude).reshape((-1,4)), length=.1)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

ax.view_init(35, 30)
plt.show()


# Bridges using learned score

In [None]:
# guide towards 0 in Heisenberg group
x = M.coords([.5,0.,.8]) # starting point
# Target: the origin. It is built from freshly created coordinates instead of
# the notebook-level `X`, because `X` is rebound by the 3D plotting cell and
# would then no longer be a (coordinates, chart) pair.
x_target = M.coords([0.,0.,0.])
v = (jnp.zeros_like(x_target[0]),x_target[1]) # target

In [None]:
# guided process with explicit guide
def guide(x,v,*_):
    """Explicit guiding term steering the process towards 0 in the Heisenberg group.

    x -- current point as a (coordinates, chart) pair; only x[0] is read
    v -- target; not used, the origin is hard-coded as target

    Returns an array with the two guide coefficients.
    """
    xy = x[0][0:2]
    z = x[0][2]
    rho = jnp.linalg.norm(xy)
    gamma = jnp.arctan2(x[0][1],x[0][0])

    # alpha minimizes the squared residual of
    #   8 sin^2(alpha/2) |z| = |xy|^2 (alpha - sin(alpha)),
    # found with BFGS started at pi
    def squared_residual(alpha):
        return (8*jnp.sin(alpha[0]/2)**2*jnp.abs(z)-jnp.sum(xy**2)*(alpha[0]-jnp.sin(alpha[0])))**2
    alpha = optimize.minimize(squared_residual,jnp.array([jnp.pi]),method='BFGS').x[0]

    r = rho/(2*jnp.sin(alpha/2))
    heading = gamma+jnp.sign(z)*alpha/2

    epsilon = 1e-4

    def near_plane(_):
        # |z| below epsilon: point straight back at the origin in the xy-plane
        return jnp.array([-rho*jnp.cos(gamma),
                          -rho*jnp.sin(gamma)])

    def off_plane(_):
        return jnp.array([-r*alpha*jnp.cos(heading),
                          -r*alpha*jnp.sin(heading)])

    return jax.lax.cond(jnp.abs(z)<epsilon, near_plane, off_plane, None)

# coordinate form
from jaxgeometry.stochastics.guided_process import *

# guided process built from the sR Brownian motion and the explicit guide; the
# last argument returns the Cholesky factor of D^T D at the current point
(Brownian_sR_guided_explicit,_,*_) = get_guided(
    M,M.sde_Brownian_sR,M.chart_update_Brownian_sR,guide,
    lambda x,*_: jnp.linalg.cholesky(jnp.tensordot(M.D(x),M.D(x),(0,0))))

# plot with zero noise
_dts = dts(n_steps=1000)
(ts,xs,charts,log_likelihood,log_varphi) = Brownian_sR_guided_explicit(x,v,_dts,0*dWs(M.sR_dim,_dts),1.)
# keep the noise-free path for comparison in later cells
xs_explicit,charts_explicit = xs,charts
print(xs[-1])

# plot
newfig()
M.plot()
M.plot_path(zip(xs,charts))
plt.show()

# norm of the xy-part (red) and absolute value of the z-coordinate (blue) over time.
# The z-coordinate is index 2: points have three coordinates, so index 3 is
# out of range.
plt.plot(jnp.cumsum(_dts),jax.vmap(lambda x: jnp.linalg.norm(x[0:2]),0)(xs),'r')
plt.plot(jnp.cumsum(_dts),jax.vmap(lambda x: jnp.linalg.norm(x[2]),0)(xs),'b')
plt.show()

In [None]:
# bridge with score

# NOTE: this rebinds `guide`; from here on the name refers to the learned score
# s1 evaluated at (t, x[0]) and no longer to the explicit guide function
# defined in an earlier cell.
#guide = lambda x,t: jnp.dot(M.D(x).T,net.apply(params,jnp.hstack((t,x[0]))))
guide = lambda x,t: s1.apply(params_s1,jnp.hstack((t,x[0])))

# unguided sde and chart update, read as globals by sde_guided and
# chart_update_guided
sde = M.sde_Brownian_sR
chart_update = M.chart_update_Brownian_sR
def sde_guided(c,y):
    """One step of the score-guided sde.

    c -- state tuple (t, x, chart, log_likelihood, log_varphi, T, v, *cy)
    y -- tuple (dt, dW) of time and noise increment

    The unguided step `sde` is evaluated first and `guide`, called with the
    remaining time T-t, is added to its deterministic part through X. The
    guide is switched off once t >= T-dt/2 and the stochastic part is zeroed
    once t >= T-3*dt/2.

    Returns (det + X.h, sto, X, 0., 0., zeros_like(T), zeros_like(v), *dcy);
    the likelihood and correction-factor terms are always zero.
    """
    t,x,chart,log_likelihood,log_varphi,T,v,*cy = c
    dt,dW = y
    point = (x,chart)

    # unguided step
    (det,sto,X,*dcy) = sde((t,x,chart,*cy),y)

    # guiding term, zero at the very end of the time interval
    h = jax.lax.cond(
        t<T-dt/2,
        lambda _: guide(point,T-t),
        lambda _: jnp.zeros_like(guide(point,T-t)),
        None)

    # stochastic term, zero on the last steps (for Ito as well?)
    sto_out = jax.lax.cond(
        t < T-3*dt/2,
        lambda _: sto,
        lambda _: jnp.zeros_like(sto),
        None)

    # likelihood and correction factor are not tracked by this process
    d_log_likelihood = 0.
    d_log_varphi = 0.

    return (det+jnp.dot(X,h),sto_out,X,d_log_likelihood,d_log_varphi,jnp.zeros_like(T),jnp.zeros_like(v),*dcy)

def chart_update_guided(x,chart,log_likelihood,log_varphi,T,v,*ys):
    """Chart update for the state of the guided process.

    When a `chart_update` is set, it is applied to (x, chart, *ys) and the
    target v is moved from the old chart to the new one with M.update_coords.
    log_likelihood, log_varphi and T are passed through unchanged.

    Returns (x, chart, log_likelihood, log_varphi, T, v, *ys) with the updated
    entries.
    """
    if chart_update is not None:
        chart_old = chart
        (x, chart, *ys) = chart_update(x,chart_old,*ys)
        # express the target in the new chart
        v = M.update_coords((v,chart_old),chart)[0]
    return (x,chart,log_likelihood,log_varphi,T,v,*ys)

def guided(x,v,dts,dWs,*ys):
    """Integrate the score-guided sde from x towards v.

    x   -- start point as a (coordinates, chart) pair
    v   -- target; moved to the chart of x when a chart update is in use
    dts -- time increments; their sum is passed on as the end time T
    dWs -- noise increments
    ys  -- extra values appended to the integrator state

    Returns the first five outputs of integrate_sde.
    """
    v_start = M.update_coords(v,x[1])[0] if chart_update else v
    return integrate_sde(sde_guided,integrator_stratonovich,chart_update_guided,x[0],x[1],dts,dWs,0.,0.,jnp.sum(dts),v_start,*ys)[0:5]

#(Brownian_sR_guided,sde_Brownian_sR_guided,*_) = get_guided(
#    M,M.sde_Brownian_sR,M.chart_update_Brownian_sR,guide,
#    lambda x: jnp.linalg.cholesky(jnp.tensordot(M.D(x),M.D(x),(0,0))))
Brownian_sR_guided = guided

# one guided sample path over the time interval [0,T]
T = .1
_dts = dts(n_steps=1000,T=T)
(ts,xs,charts,log_likelihood,log_varphi) = Brownian_sR_guided(x,v,_dts,dWs(M.sR_dim,_dts),1.)
print(xs[-1])

# plot the guided path together with the noise-free explicit path (red)
newfig()
M.plot()
M.plot_path(zip(xs,charts))
M.plot_path(zip(xs_explicit,charts_explicit),color='r')
plt.savefig(f'guided_{T}.png')
plt.show()

# norm of the xy-part (red) and absolute value of the z-coordinate (blue) over time
plt.plot(jnp.cumsum(_dts),jax.vmap(lambda x: jnp.linalg.norm(x[0:2]),0)(xs),'r')
plt.plot(jnp.cumsum(_dts),jax.vmap(lambda x: jnp.linalg.norm(x[2]),0)(xs),'b')
plt.savefig(f'guided_norms_{T}.png')
plt.show()

In [None]:
# statistics
# For each end time T, simulate N guided bridges and plot the median and the
# interquartile band of the xy-norm (red line) and of |z| (blue line) over time.
N = 100
for T in [.1,.2,.5,1.]:
    _dts = dts(n_steps=1000,T=T)
    # N independent noise realisations, one guided path per realisation
    # (the lambda parameter is named `noise` so it does not shadow the dWs helper)
    xss = jax.vmap(lambda noise: Brownian_sR_guided(x,v,_dts,noise,1.)[1])(dWs(N*M.sR_dim,_dts).reshape((N,-1,M.sR_dim)))

    xy = jax.vmap(jax.vmap(lambda x: jnp.linalg.norm(x[0:2]),0))(xss)
    Z = jax.vmap(jax.vmap(lambda x: jnp.linalg.norm(x[2]),0))(xss)
    # despite the names, these hold the medians over the N paths
    mean_xy = jnp.median(xy,0)
    mean_z = jnp.median(Z,0)
    quartiles_xy = np.percentile(xy, [25, 50, 75], axis=0)
    quartiles_z = np.percentile(Z, [25, 50, 75], axis=0)

    # plot xy
    plt.plot(jnp.cumsum(_dts),mean_xy,'r')
    plt.fill_between(jnp.cumsum(_dts), quartiles_xy[0], quartiles_xy[2], alpha=0.3)
    # plot z
    plt.plot(jnp.cumsum(_dts),mean_z,'b')
    plt.fill_between(jnp.cumsum(_dts), quartiles_z[0], quartiles_z[2], alpha=0.3)

    # same vertical range for every T so the figures are comparable; assigning
    # to an attribute of the pyplot module would not change the axes
    plt.ylim(0.,1.3)
    plt.savefig(f'guided_statistics_{T}.png')

    plt.show()