In [None]:
# This file is part of Theano Geometry
#
# Copyright (C) 2017, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/theanogeometry
#
# Theano Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Theano Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Theano Geometry. If not, see <http://www.gnu.org/licenses/>.
#

# Brownian Bridge Simulation and Metric Estimation on Landmark Manifolds
arXiv:1705.10943 [cs.CV] https://arxiv.org/abs/1705.10943

Stefan Sommer, Line Kuhnel, Alexis Arnaudon, and Sarang Joshi

In [None]:
%cd ..
from src.manifolds.landmarks import *
M = landmarks(2)
print(M)

from src.plotting import *
%matplotlib inline
plt.rcParams['figure.figsize'] = 13, 10
colormap = plt.get_cmap('winter')

In [None]:
# Riemannian structure
# Initialize the Riemannian metric structure on M.
# NOTE(review): presumably this provides M.gsharp / M.flatf used below; confirm in src.Riemannian.metric.
from src.Riemannian import metric
metric.initialize(M)

In [None]:
# initialize
# Model parameters are shared variables, updated in place with set_value.
M.N.set_value(8)  # number of landmarks
M.k_alpha.set_value(.1)  # kernel parameter alpha (NOTE(review): presumably the kernel amplitude; confirm)
M.k_sigma.set_value(.5*np.diag((1.,1.)))  # 2x2 kernel width matrix
n_steps.set_value(500)  # number of time steps per path; n_steps is not defined here (star import)

In [None]:
# setup
# Initial configuration: N landmarks evenly spaced on [-.5,.5] x {0},
# flattened as (x1,y1,x2,y2,...).
q = np.column_stack((np.linspace(-.5,.5,M.N.eval()),np.zeros(M.N.eval()))).ravel()
# initial velocity: every landmark moves straight up
v = np.column_stack((np.zeros(M.N.eval()),np.ones(M.N.eval()))).ravel()
# corresponding momentum (covector)
p = M.flatf(q,v)
print("q = ", q)
print("p = ", p)

In [None]:
# Hamiltonian dynamics
from src.dynamics import Hamiltonian
Hamiltonian.initialize(M)
# Energy of the initial state, evaluated only after Hamiltonian.initialize(M).
# NOTE(review): M.Hf is assumed to be attached by initialize(); nothing earlier in this
# notebook defines it.
print(M.Hf(q,p))

# geodesic
# NOTE(review): the geodesic is started from the velocity v, whereas the dynamics below
# are started from the momentum p = M.flatf(q,v); confirm which argument is intended.
qs = M.Exp_Hamiltoniantf(q,v).T
M.plot()
M.plotx(qs,v)
plt.show()
(ts,qps) = M.Hamiltonian_dynamicsf(q,p)
# qps stacks (position, momentum) along axis 1 for every time step
ps = qps[:,1,:]
# Energy along the integrated trajectory. Positions and momenta must come from the same
# solution, so use qps[:,0,:]. The geodesic qs above was computed from different initial
# data and must not be paired with ps.
print("Energy: ",np.array([M.Hf(q_t,p_t) for (q_t,p_t) in zip(qps[:,0,:],ps)]))

In [None]:
## Visualize bridge

# Brownian motion
from src.stochastics import Brownian_coords
Brownian_coords.initialize(M)
# Delyon/Hu guided process
from src.stochastics.guided_process import *

Cholesky = T.slinalg.Cholesky()

def phi(q,v):
    """Guiding term L(q)^{-1} (v - q), with L(q) the Cholesky factor of M.gsharp(q)."""
    chol_inverse = T.nlinalg.MatrixInverse()(Cholesky(M.gsharp(q)))
    return T.tensordot(chol_inverse,-(q-v).flatten(),(1,0))

def _chol_gsharp(q):
    # matrix handed to get_sde_guided: Cholesky factor of M.gsharp(q)
    return Cholesky(M.gsharp(q))

sde_Brownian_coords_guided = get_sde_guided(M.sde_Brownian_coords,phi,_chol_gsharp)

def Brownian_coords_guided(q,v,dWt):
    """Integrate the guided Brownian SDE from q towards the target v (Ito integrator).

    NOTE(review): the two T.constant(0.) are initial values whose meaning is defined by
    integrate_sde / the guided SDE; confirm there.
    """
    return integrate_sde(sde_Brownian_coords_guided,
                         integrator_ito,
                         q,dWt,T.constant(0.),T.constant(0.),v)

# symbolic start and target configurations (dWt is not defined here; star import)
q0 = M.element()
v = M.element()
Brownian_coords_guidedf = theano.function([q0,v,dWt], Brownian_coords_guided(q0,v,dWt)[:4])
phif = theano.function([q0,v], phi(q0,v))

# derivatives
thetas = (q0,M.k_alpha,M.k_sigma) # parameters differentiated in dlog_likelihood
def dlog_likelihood(q,v,dWt):
    """Guided bridge from q towards v together with gradients of its log-likelihood.

    Returns the first four outputs of Brownian_coords_guided
    (ts, path, log_likelihood, log_varphi) followed by one gradient per parameter in the
    global `thetas`, each taken of the final log-likelihood value.
    """
    (ts,path,log_lik,log_varphi) = Brownian_coords_guided(q,v,dWt)[:4]
    gradients = tuple([T.grad(log_lik[-1],theta) for theta in thetas])
    return (ts,path,log_lik,log_varphi)+gradients
# compile the guided bridge together with its parameter gradients
dlog_likelihoodf = theano.function([q0,v,dWt],dlog_likelihood(q0,v,dWt))

# target: the initial line of landmarks shifted up to y = 1
v = np.column_stack((np.linspace(-.5,.5,M.N.eval()),np.ones(M.N.eval()))).ravel()
(ts,qs,log_likelihood,log_varphi) = Brownian_coords_guidedf(q,v,dWsf(M.dim.eval()))
print("log likelihood: ", log_likelihood[-1], ", log varphi: ", log_varphi[-1])

# bridge path (prefixed with its start q) and the target v in black
M.plot()
M.plotx(np.vstack((q,qs)),curve=True)
M.plotx(v,color='k',curve=True)
# plt.savefig('bridge.pdf')

In [None]:
## Set up for inference example
# Parameters are set on the manifold object M. Its shared variables M.k_alpha / M.k_sigma
# are the ones the compiled likelihood gradients (thetas) refer to.

# initialize
M.N.set_value(10)
M.k_alpha.set_value(.01)
n_steps.set_value(20)

# setup
# NOTE(review): ellipse comes from a star import; presumably N landmarks on an ellipse
# centred at [0,0] with radii [1,.5].
x0 = ellipse([0.,0.],[1.,.5])
q0 = x0.flatten()
print("q = ", q0)

# kernel width = average distance between consecutive landmarks
avg_landmark_dist = np.mean(np.linalg.norm(x0[:-1]-x0[1:],axis=1))
M.k_sigma.set_value(avg_landmark_dist*np.diag((1.,1.)))
print("k_alpha: ", M.k_alpha.eval(), ", k_sigma: ", M.k_sigma.eval().flatten())
# parameters used to generate the data, kept for comparison with the estimates
k_alpha_init = M.k_alpha.eval()
k_sigma_init = M.k_sigma.eval()

M.plotx(q0,curve=True,color='k')
# upward velocity on every landmark, and the corresponding momentum
v0 = np.vstack((np.zeros(M.N.eval()),np.ones(M.N.eval()))).T
p0 = M.flatf(q0,v0.flatten())
# endpoint of the geodesic from q0 with momentum p0 (time-major trajectory, last row)
M.plotx(M.Exp_Hamiltoniantf(q0,p0).T[-1],curve=True)

In [None]:
# sample for Brownian motion transition distribution
# Draw N_samples Brownian paths from q0 with the current parameters. The endpoints obss
# are the observations used for inference below.
N_samples = 64
obss = np.zeros((N_samples,)+q0.shape)
qsvs = np.zeros((N_samples,n_steps.eval(),)+q0.shape)
# srng.seed(422)
for i in range(N_samples):
    # NOTE(review): Brownian_coordsf and the zero-argument dWsf() are not defined in this
    # notebook (the bridge cell calls dWsf(M.dim.eval())); confirm their signatures.
    (ts,qsv) = Brownian_coordsf(q0,dWsf())
    qsvs[i] = qsv
    obss[i] = qsv[-1]

# plot samples
plot_samples = 15
colors=[colormap(k) for k in np.linspace(0, 1, plot_samples)]
for i in range(plot_samples):
    M.plotx(obss[i],curve=True,color=colors[i])
plt.axis('equal')
# plt.savefig('samples.pdf')
plt.show()

# plot samples with paths
plot_samples = 5
colors=[colormap(k) for k in np.linspace(0, 1, plot_samples)]
M.plotx(q0,color='k',curve=True)
for i in range(plot_samples):
    M.plotx(qsvs[i],color=colors[i],curve=True)
# plt.savefig('samples_paths.pdf')
plt.axis('equal')

In [None]:
from src.mle import *
import src.mle as mle
# hand the compiled likelihood/gradient function and its parameter list to the mle module
mle.dlog_likelihoodf = dlog_likelihoodf
mle.thetas = thetas

# optimisation settings
options = {
    'samples_per_obs': 1,  # bridges sampled per observation
    'epochs': 80,  # gradient-ascent iterations
    'learning_rate': .5e-3,  # alternative tried: 1.5e-3
    'varphi_update_rate': 1.,
    # initial (q0, k_alpha, k_sigma), in the order of thetas; alternative: .2,.6*np.diag((1.,1.))
    'initial': (np.zeros(M.dim.eval()),
                .12,.3*np.diag((1.,1.))),
    # identity: the bridge target is not changed between samples
    'update_v': lambda g: g,
}

In [None]:
# produce bridge plot for paper

def bridge_sampling(lg,dWsf,options,pars):
    """Sample options['samples_per_obs'] guided bridges from lg towards a target.

    Parameters
    ----------
    lg : ndarray
        Starting configuration (flattened landmark coordinates).
    dWsf : callable
        Zero-argument function returning fresh noise increments.
    options : dict
        Must contain 'samples_per_obs'. An optional 'update_v' callable maps the current
        target to the target for the next sample.
    pars : tuple
        (v, log_phi, seed): the target, an unused value, and a seed for the global
        `srng`. The seed is applied whenever it is not None (including 0).

    Returns
    -------
    (bridges, log_varphis, log_likelihoods, dlog_likelihoods, v)
        bridges has shape (samples_per_obs, n_steps) + lg.shape. log_varphis and
        log_likelihoods hold the final value of each sample. dlog_likelihoods is None
        (gradients are not collected). v is the target after the last update.
    """
    (v,_log_phi,seed) = pars
    if seed is not None:
        srng.seed(seed)
    n_samples = options['samples_per_obs']
    bridges = np.zeros((n_samples,n_steps.eval(),)+lg.shape)
    log_varphis = np.zeros((n_samples,))
    log_likelihoods = np.zeros((n_samples,))
    dlog_likelihoods = None  # gradients returned by dlog_likelihoodf are not collected
    # Look the callback up once, so a KeyError raised *inside* update_v is not swallowed.
    update_v = options.get('update_v')
    for i in range(n_samples):
        (_ts,gsv,log_likelihood,log_varphi,*_dlog_likelihood) = dlog_likelihoodf(lg,v,dWsf())
        bridges[i] = gsv
        log_varphis[i] = log_varphi[-1]
        log_likelihoods[i] = log_likelihood[-1]
        if update_v is not None:
            v = update_v(v) # update v, e.g. simulate in fiber
    return (bridges,log_varphis,log_likelihoods,dlog_likelihoods,v)
# bridge(g0,A.eval(),options,(v.eval(),np.random.randint(1000)))[0].shape

def lbridge_sampling(thetas,*args,**kwargs):
    """Set the kernel parameters from thetas, then sample bridges from the global q0.

    thetas is (q0, k_alpha, k_sigma) as in options['initial']; only k_alpha and k_sigma
    are used. They are written to M.k_alpha / M.k_sigma, the shared variables the
    compiled dlog_likelihoodf is built on (see thetas). The remaining arguments are
    forwarded to bridge_sampling.
    """
    M.k_alpha.set_value(thetas[1])
    M.k_sigma.set_value(thetas[2])
    return bridge_sampling(q0,*args,**kwargs)

# Sample one bridge towards the first observation (paper figure) with the initial
# parameter guess. Temporarily changed settings are restored even if sampling fails.
N_samples = 1
samples_per_obs_saved = options['samples_per_obs']
k_alpha_saved = M.k_alpha.eval()
options['samples_per_obs'] = 1
M.k_alpha.set_value(.01)
try:
    v = obss[0]
    log_phis = np.zeros((N_samples,))
    mpu.openPool()
    try:
        sol = mpu.pool.imap(partial(lbridge_sampling,options['initial'],dWsf,options),\
                            mpu.inputArgs(v.reshape((1,)+v.shape),log_phis,np.random.randint(1000,size=N_samples)))
        res = list(sol)
    finally:
        mpu.closePool()
    bridges = mpu.getRes(res,0)
    log_varphis = mpu.getRes(res,1)
    log_likelihoods = mpu.getRes(res,2)
    dlog_likelihoods = mpu.getRes(res,3)

    # plot: bridges (prefixed with q0), the target in blue and q0 in black
    colors=[colormap(k) for k in np.linspace(0, 1, options['samples_per_obs'])]
    for j in range(bridges.shape[1]):
        gsv = np.vstack((q0.flatten(),bridges[0,j]))
        M.plotx(gsv,linewidth=.6,color=colors[j],curve=True)
    M.plotx(v,color='b',curve=True)
    M.plotx(q0,color='k',curve=True)
    # plt.savefig('bridges.pdf')
finally:
    options['samples_per_obs'] = samples_per_obs_saved
    M.k_alpha.set_value(k_alpha_saved)
    N_samples = obss.shape[0]

In [None]:
# Transition density

t = T.scalar()
# q must be symbolic here. The setup cell bound q to a numpy array, and theano.function
# needs symbolic inputs (q is also used symbolically by the compiled functions below).
q = M.element()
# NOTE(review): bare sde_Brownian_coords / dWs are not defined in this notebook
# (presumably star imports; M.sde_Brownian_coords is used elsewhere) — confirm.
sde_Brownian_coordsf = theano.function([t,q],sde_Brownian_coords(dWs[0],t,q),on_unused_input='ignore')

# symbolic lift to fiber not supported yet
def log_p_T(g,v,dWs,bridge_sde,phi,options,sigma=None,sde=None):
    """Symbolic Monte Carlo estimate of the log transition density log p_T(v | g).

        log p_T = -d/2 log(2 pi Tend) - log|det sigma| - |phi(g,v)|^2 / (2 Tend)
                  + log mean_i exp(log_varphi_i)

    with d = sigma.shape[0] and log_varphi_i the final log varphi of the i-th of
    options['samples_per_obs'] guided bridges from g towards v.

    sigma is taken from the third output of sde(dWs, Tend, v) when sde is given;
    otherwise it must be passed explicitly.
    """
    if sde is not None:
        (_,_,XT) = sde(dWs,Tend,v) # starting point of SDE, we need diffusion field X at t=0
        sigma = XT
    assert(sigma is not None)
    Cgv = T.sum(phi(g,v)**2)

    # sample noise
    # NOTE(review): fn ignores its argument and returns the same symbolic dWs at every
    # step, and the scan updates are discarded. All bridges therefore appear to share
    # identical noise unless dWs is re-drawn per step; confirm this is intended.
    (cout, updates) = theano.scan(fn=lambda x: dWs,
            outputs_info=[T.zeros_like(dWs)],
            n_steps=options['samples_per_obs'])
    dWsi = cout

    # bridges
    def bridge_p_T(dWs,gsv,log_varphi):
        (ts,gsv,log_likelihood,log_varphi,_) = bridge_sde(g,v,dWs)
        return (gsv,log_varphi[-1])

    (cout, updates) = theano.scan(fn=bridge_p_T,
            outputs_info=[T.zeros((n_steps,M.dim)),T.constant(0.)],
            sequences=[dWsi])
    log_varphis = cout[1]

    # log(mean(exp(x))) evaluated as max(x) + log(mean(exp(x - max(x)))). This is
    # mathematically identical but does not overflow/underflow for large |log_varphis|.
    max_log_varphi = T.max(log_varphis)
    log_mean_varphi = max_log_varphi+T.log(T.mean(T.exp(log_varphis-max_log_varphi)))
    return -.5*sigma.shape[0]*T.log(2.*np.pi*Tend)-linalg.LogAbsDet()(sigma)-Cgv/(2.*Tend)+log_mean_varphi
# symbolic target configuration (replaces the numeric v from earlier cells)
v = T.vector()
# compiled Monte Carlo estimate of the log transition density log p_T(v | q)
log_p_Tf = theano.function([q,v],log_p_T(q,v,dWs,Brownian_coords_guided,phi,options,sde=sde_Brownian_coords))
def dlog_p_T(*args,**kwargs):
    """Return log_p_T(*args, **kwargs) followed by its gradient w.r.t. each entry of `thetas`."""
    value = log_p_T(*args,**kwargs)
    gradients = [T.grad(value,theta) for theta in thetas]
    return tuple([value]+gradients)
# value and parameter gradients of the log transition density, compiled together
dlog_p_Tf = theano.function([q,v],dlog_p_T(q,v,dWs,Brownian_coords_guided,phi,options,sde=sde_Brownian_coords))
# transition density itself (exp of the log density)
p_Tf = theano.function([q,v],T.exp(log_p_T(q,v,dWs,Brownian_coords_guided,phi,options,sde=sde_Brownian_coords)))
# NOTE(review): Expf is not defined in this notebook (presumably a star import giving the
# geodesic endpoint from q0 with momentum p0); confirm.
v = Expf(q0,p0)
print(log_p_Tf(q0,v))
print(p_Tf(q0,v))
print(dlog_p_Tf(q0,v))

def log_p_T_numeric(lg,v,dWsf,bridge_sdef,phif,options,sigma=None,sdef=None,x0=None):
    """Numeric Monte Carlo estimate of the log transition density log p_T(v | lg).

        log p_T = -d/2 log(2 pi T) - log|det sigma|
                  + log mean_i( exp(-|phi_i|^2 / (2T)) * exp(log_varphi_i) )

    with T = Tend.eval(), d = sigma.shape[0], and one term per sampled bridge.

    Parameters
    ----------
    lg : start configuration.
    v : target; lifted to the fiber with lift_to_fiber when x0 is given.
    dWsf : zero-argument callable returning noise increments.
    bridge_sdef : callable (lg, v, dW) -> (ts, path, log_likelihood, log_varphi).
    phif : callable (lg, v) -> vector whose squared norm enters the density.
    options : dict with a positive 'samples_per_obs' and an optional 'update_v'
        callable applied to v after each sample.
    sigma : symbolic diffusion matrix (evaluated with .eval()); ignored if sdef is given.
    sdef : callable (t, v) whose third output is used as sigma.
    x0 : optional base point for lifting v.

    Raises
    ------
    ValueError
        If neither sigma nor sdef is given, if samples_per_obs < 1, or if a bridge
        cannot be sampled (after printing diagnostics).
    """
    vorg = v # original target, kept for diagnostics
    if x0 is not None: # if lv point on manifold, lift target to fiber
        v = lift_to_fiber(v,x0)[0]
    if sdef is not None:
        (_,_,XT) = sdef(Tend.eval(),v) # starting point of SDE, we need diffusion field X at t=0
        sigma = XT
    elif sigma is not None:
        sigma = sigma.eval()
    if sigma is None:
        raise ValueError('either sigma or sdef must be given')
    n_samples = options['samples_per_obs']
    if n_samples < 1:
        raise ValueError('samples_per_obs must be at least 1')
    T_end = Tend.eval()
    # log of exp(-|phi|^2/(2T)) * varphi for every sample
    log_weights = np.zeros((n_samples,))
    # looked up once so that a KeyError raised inside update_v is not swallowed
    update_v = options.get('update_v')
    for i in range(n_samples):
        try:
            (_,_,_,log_varphi) = bridge_sdef(lg,v,dWsf())
        except ValueError:
            print('Bridge sampling error:')
            print(v)
            print(vorg)
            if x0 is not None: # lifting is only meaningful when a base point was given
                print(lift_to_fiber(vorg,x0))
            print(phif(lg,v))
            raise
        Cgv = np.linalg.norm(phif(lg,v))**2
        log_weights[i] = log_varphi[-1]-Cgv/(2.*T_end)
        if update_v is not None:
            v = update_v(v) # update v, e.g. simulate in fiber
    # log(mean(exp(w))) computed as max(w) + log(mean(exp(w - max(w)))) to avoid
    # underflow to -inf for strongly negative weights; non-finite max is returned as is
    max_log_weight = np.max(log_weights)
    if np.isfinite(max_log_weight):
        log_mean_weight = max_log_weight+np.log(np.mean(np.exp(log_weights-max_log_weight)))
    else:
        log_mean_weight = max_log_weight
    (_,log_abs_det_sigma) = np.linalg.slogdet(sigma)
    return -.5*sigma.shape[0]*np.log(2.*np.pi*T_end)-log_abs_det_sigma+log_mean_weight
def p_T_numeric(*pars, **kwargs):
    """Transition density: exp of log_p_T_numeric called with the same arguments."""
    return np.exp(log_p_T_numeric(*pars,**kwargs))

# single-bridge estimates of the log density and the density at v
print(log_p_T_numeric(q0,v,dWsf,Brownian_coords_guidedf,phif,{'samples_per_obs': 1},sdef=sde_Brownian_coordsf))
print(p_T_numeric(q0,v,dWsf,Brownian_coords_guidedf,phif,{'samples_per_obs': 1},sdef=sde_Brownian_coordsf))

In [None]:
%%time

vs = obss
try:
    def llog_p_T(thetas,pars):
        (v,seed) = pars
        if seed:
            srng.seed(seed)
        k_alpha.set_value(thetas[1])
        k_sigma.set_value(thetas[2])
        return dlog_p_Tf(q0,v)
    
    log_likelihoods = np.zeros(options['epochs'])
    
    # initial thetas
    q0 = np.mean(obss,axis=0)
    k_alpha.set_value(options['initial'][1])    
#     avg_landmark_dist = np.mean(np.linalg.norm(q0.reshape((-1,m.eval()))[:-1]-q0.reshape((-1,m.eval()))[1:],axis=1))
#     k_sigma.set_value(avg_landmark_dist*np.diag((1.,1.)))
    k_sigma.set_value(options['initial'][2])
    print("initial thetas",
          "\n\tq0: ", q0, ", \n\tk_alpha: ", k_alpha.eval(), ", \n\tk_sigma: ", k_sigma.eval().flatten())    
    
    # for plotting iterations
    q0s = np.zeros((options['epochs'],)+q0.shape)    
    k_alphas = np.zeros((options['epochs'],)+k_alpha.eval().shape)    
    k_sigmas = np.zeros((options['epochs'],)+k_sigma.eval().shape)    
    mpu.openPool()
    for i in range(options['epochs']):
        sol = mpu.pool.imap(partial(llog_p_T,(q0,k_alpha.eval(),k_sigma.eval())),\
                            mpu.inputArgs(vs,np.random.randint(1000,size=N_samples)))
        res = list(sol)
        log_likelihood = np.mean(mpu.getRes(res,0),axis=0)
        dqlog_likelihood = np.mean(mpu.getRes(res,1),axis=0)
        dk_alphalog_likelihood = np.mean(mpu.getRes(res,2),axis=0)
        dk_sigmalog_likelihood = np.mean(mpu.getRes(res,3),axis=0)
        
        log_likelihoods[i] = log_likelihood # total log likelihood        

        # step, update parameters and varphis                            
        q0 = q0+options['learning_rate']*np.dot(gMsharpf(q0),dqlog_likelihood) # use Riemannian g-gradient
        q0s[i] = q0
        k_alpha.set_value(k_alpha.eval()+options['learning_rate']/d.eval()*dk_alphalog_likelihood)
        k_alphas[i] = k_alpha.eval()
        k_sigma.set_value(k_sigma.eval()+options['learning_rate']*dk_sigmalog_likelihood)
        k_sigmas[i] = k_sigma.eval()
        
        print("iteration: ", i, ", log-likelihood: ", log_likelihood, 
              ", new thetas: \n\tq0: ", q0, ", \n\tk_alpha: ", k_alpha.eval(), ", \n\tk_sigma: ", k_sigma.eval().flatten())        
except:
    mpu.closePool()
    raise
else:
    mpu.closePool()

## plot
epochs = range(options['epochs'])
plt.plot(epochs,log_likelihoods)
# plt.savefig('likelihood.pdf')
plt.show()
plt.plot(epochs,q0s.reshape((q0s.shape[0],-1)))
# plt.savefig('q0s.pdf')
plt.show()
# estimated k_alpha (blue) vs the value used to generate the data (red)
plt.plot(epochs,k_alphas,color='b')
plt.hlines(k_alpha_init,plt.xlim()[0],plt.xlim()[1],color='r')
# plt.savefig('k_alpha.pdf')
plt.show()
plt.plot(epochs,k_sigmas.reshape((k_sigmas.shape[0],-1)),color='b')
plt.hlines(k_sigma_init.flatten(),plt.xlim()[0],plt.xlim()[1],color='r')
plt.ylabel(r'$\sigma$', fontsize=30)
# plt.savefig('k_sigma.pdf')
plt.show()
# true initial shape (black) vs estimated q0 (blue)
M.plotx(x0.flatten(),color='k',curve=True)
M.plotx(q0,color='b',curve=True)
# plt.savefig('est_q0.pdf')
plt.show()
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
ax1.plot(epochs,log_likelihoods,'g--')
ax1.set_ylabel(r'$\mathcal{L}_\theta$', fontsize=30)
ax2.plot(epochs,k_alphas,color='b')
# reference line spans ax2's own x-range instead of pyplot's implicit current axes
ax2.hlines(k_alpha_init,*ax2.get_xlim(),color='r')
ax2.set_ylabel(r'$\alpha$', fontsize=30)
# plt.savefig('likelihood-k_alpha.pdf')
plt.show()
None

In [None]:
# sample with estimated parameters
# Paths go into a new array. qsvs still holds the paths sampled with the true parameters,
# and overwriting it would silently change what earlier plots were drawn from.
obss_new = np.zeros((N_samples,)+q0.shape)
qsvs_new = np.zeros((N_samples,n_steps.eval(),)+q0.shape)
for i in range(N_samples):
    # NOTE(review): Brownian_coordsf / dWsf() come from star imports; confirm signatures
    (ts,qsv) = Brownian_coordsf(q0,dWsf())
    qsvs_new[i] = qsv
    obss_new[i] = qsv[-1]

In [None]:
# plot
def estimate_qq(data_q):
    """Empirical mean and covariance of landmark configurations.

    data_q : array of shape (n_samples, N, m).
    Returns [mean, cov]: mean has shape (N, m); cov has shape (N, m, N, m), with
    cov[i,a,j,b] the sample average of (x_ia - mean_ia) * (x_jb - mean_jb).
    """
    n_samples = data_q.shape[0]
    mean = data_q.sum(0)/n_samples
    centered = data_q-mean
    outer = centered[:,:,:,np.newaxis,np.newaxis]*centered[:,np.newaxis,np.newaxis,:,:]
    return [mean,outer.sum(0)/n_samples]
# Empirical landmark mean/covariance of the observations and of the samples drawn with the
# estimated parameters.
# NOTE(review): N and m are not defined in this notebook (M.N is used elsewhere; m is
# presumably the landmark dimension) — confirm.
qq = estimate_qq(obss.reshape((-1,N.eval(),m.eval())))
qq_new = estimate_qq(obss_new.reshape((-1,N.eval(),m.eval())))

#plot density distribution of landmarks
def plot_distribution(xss):
    """Plot a normalised 2D histogram of all landmark positions in xss.

    xss : array of shape (n_samples, N, m) with m >= 2; only the first two coordinates
        are used. The landmark count is taken from xss itself, not from a global.
    Draws with plt.imshow; the histogram is scaled so that its maximum is 1.
    """
    # all landmarks of all samples, sample-major order
    xTx = xss[:,:,0].ravel()
    xTy = xss[:,:,1].ravel()
    hist,histy,histx = np.histogram2d(xTy,xTx,bins=25)
    extent = [histx[0],histx[-1],histy[0],histy[-1]]
    plt.imshow(hist/np.max(hist),extent=extent,interpolation='bicubic',origin='lower',cmap='Greys')

# plot variance
def plot_final_ellipses(q,QQ,coeff=1.,c='m',ls='-',lw=1):
    """Draw one covariance ellipse per landmark on the current axes.

    q : (N, 2) landmark positions (ellipse centres); the number of ellipses is len(q).
    QQ : (N, 2, N, 2) covariance as returned by estimate_qq; QQ[i,:,i,:] is used for
        landmark i.
    coeff : scale applied to the square roots of the eigenvalues (width / height).
    c, ls, lw : edge colour, line style and line width.
    """
    from matplotlib.patches import Ellipse
    ax = plt.gca()
    for i in range(len(q)):
        # symmetric block -> eigh: real eigenvalues with matching eigenvector columns
        qq_eig,qq_vec = np.linalg.eigh(QQ[i,:,i,:])
        # clip tiny negative round-off before taking square roots
        qq_eig = np.sqrt(np.clip(qq_eig,0.,None))
        # orientation of the first principal axis; arctan2 avoids dividing by zero
        # (the 180-degree ambiguity does not matter for an ellipse)
        theta = np.degrees(np.arctan2(qq_vec[1,0],qq_vec[0,0]))

        ell = Ellipse(xy=q[i],width=coeff*qq_eig[0],height=coeff*qq_eig[1],angle=theta,ls=ls,lw=lw)
        ax.add_artist(ell)
        ell.set_alpha(1.)
        ell.set_facecolor('None')
        ell.set_edgecolor(c)
        
# True initial shape (black) with ellipses from the observations, and the estimated q0
# (blue, dashed ellipses) from samples drawn with the estimated parameters, over the
# density of those new samples.
# NOTE(review): plotx, N and m are not defined in this notebook (M.plotx / M.N are used
# elsewhere) — confirm.
plotx(x0.flatten(),color='k',curve=True)
plotx(q0,color='b',curve=True)
plot_final_ellipses(qq[0],qq[1],coeff=3.,c='k',ls='-',lw=2)
plot_final_ellipses(q0.reshape((-1,m.eval())),qq_new[1],coeff=3.,c='b',ls='--',lw=2)
plot_distribution(obss_new.reshape((-1,N.eval(),m.eval())))
# plt.savefig('ellipse_inf.pdf')