# Diffusion mean estimation
See https://arxiv.org/abs/2105.12061

In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell is executed so that edits to imported source files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Manifold implementations. NOTE(review): later cells use jnp, jax, np, dts
# and dWs unqualified without importing them, so these star imports are
# presumably what brings them into scope -- confirm before replacing them
# with explicit imports.
from src.manifolds.S2 import *
from src.manifolds.torus import *
from src.manifolds.cylinder import *
# M is the manifold object every cell below operates on. Exactly one of the
# following constructors should be active; the commented lines are the
# alternatives.
M = S2()
# M = Torus(params=(.5,1.,[0,0,1]))
# M = Cylinder()

print(M)
# Plotting helpers used unqualified below (newfig; presumably also plt).
from src.plotting import *
# Uncomment for interactive figures:
#%matplotlib notebook

In [None]:
# Each module below is imported and then initialized on M. NOTE(review): the
# initialize calls presumably attach the corresponding operations to M (a
# later cell calls M.Expt) and probably have to run in this order -- confirm.

# Riemannian structure
from src.Riemannian import metric
metric.initialize(M)

# geodesics
from src.Riemannian import geodesic
geodesic.initialize(M)

# Logarithm map
from src.Riemannian import Log
Log.initialize(M)

# Reference point built from an all-zero coordinate vector of length M.dim.
# Later cells use x as the common start point of the sampled paths, pass it
# to M.Frechet_mean, and use x[1] as the chart in which results are expressed.
x = M.coords(jnp.zeros(M.dim))

## Sample Data

In [None]:
# Brownian motion expressed in local coordinates
from src.stochastics import Brownian_coords
Brownian_coords.initialize(M)

# Product process built from the coordinate Brownian motion SDE of M and its
# chart update; sde_product and chart_update_product are reused further down
# by the diagonal-conditioning cell.
from src.stochastics import product_sde
from src.stochastics.product_sde import tile
(product,sde_product,chart_update_product) = product_sde.initialize(
    M,M.sde_Brownian_coords,M.chart_update_Brownian_coords)

# Simulate N processes that all start at the reference point x.
N = 100
_dts = dts(T=1.)
start_points = tile(x,N)
noise = dWs(N*M.dim,_dts).reshape(-1,N,M.dim)  # one M.dim-vector per step and process
# One constant per process, forwarded to the SDE.
# NOTE(review): presumably the noise scale of each Brownian motion -- confirm.
step_scales = jnp.repeat(.3,N)
(ts,xss,chart_paths,*_) = product(start_points,_dts,noise,step_scales)

# Keep only the final state of every process: these are the data points
# (coordinates in `samples`, matching charts in `chartss`) used below.
samples = xss[-1]
chartss = chart_paths[-1]

# plot
newfig()
M.plot()
for sample_point in zip(samples,chartss):
    M.plotx(sample_point)
plt.show()

## Diffusion mean estimation, bridge sampling

In [None]:
# Initialize the diffusion-mean estimator on M; it is called as
# M.diffusion_mean below.
from src.statistics import diffusion_mean
diffusion_mean.initialize(M)

# run once to compile
# Only 2 steps here, so that the timed 30-step run in the following cell is
# not dominated by compilation.
(thetas,chart,log_likelihood,log_likelihoods,thetass) = M.diffusion_mean((samples,chartss),num_steps=2)

In [None]:
# run MLE
%time (thetas,chart,log_likelihood,log_likelihoods,thetass) = M.diffusion_mean((samples,chartss),num_steps=30)

# plot
n_steps = log_likelihoods.shape[0]
plt.plot(range(n_steps),log_likelihoods)
# plt.savefig('ML_likelihoods.pdf')
plt.show()
plt.plot(range(n_steps),[t[1] for t in thetass])
plt.show()
plt.plot(range(n_steps),[M.F((t[0],t[2])) for t in thetass])
print(M.F((thetas[0],chart)))
# plt.savefig('ML_thetas.pdf')
plt.show()

M.newfig()
M.plot()
M.plotx((thetas[0],chart),color='k',s=100) # result
M.plotx((thetass[0][0],thetass[0][2]),color='b',s=100) # initial point
M.plotx(x,color='r',s=100)
M.plot_path(list([(t[0],t[2]) for t in thetass]),color='b',linewidth=2.5)

# plt.savefig('MLmean_iterations.pdf')
plt.show()

## Diffusion mean estimation, diagonal sampling

In [None]:
# Condition the product process on the diagonal of the product manifold.
# Uses the product SDE and chart update created in the sampling cell.
from src.stochastics import diagonal_conditioning
diagonal_conditioning.initialize(M,sde_product,chart_update_product)

# One conditioned run: N paths starting at the samples, 500 steps, T=.1.
_dts = dts(n_steps=500,T=.1)
(ts,xss,_chartss) = M.diagonal((samples,chartss),
                             _dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim),x[1],jnp.repeat(1.,N))
# Pass every path endpoint through M.update_coords with chart x[1], keep the
# coordinate part, and average over the N paths.
mean = jnp.mean(jax.vmap(lambda _x,chart: M.update_coords((_x,chart),x[1])[0],0)(xss[-1],_chartss[-1]),0)

# plot
M.newfig()
M.plot()
colormap = plt.get_cmap('winter')
colors=[colormap(k) for k in np.linspace(0, 1, N)]
# Draw every fifth path. Plain Python ints are used as indices: iterating a
# jnp.arange would yield JAX scalars, which are slow to iterate and only
# index the Python list `colors` through implicit conversion.
for i in range(1,N,5):
    M.plot_path(zip(xss[:,i],_chartss[:,i]),color=colors[i])
    M.plotx((samples[i],chartss[i]),s=40)
M.plotx((mean,x[1]),color='r',s=80)
ax = plt.gca(); ax.view_init(60, 45) # rotate
plt.axis('off')
plt.show()

In [None]:
# sample multiple means
# Repeat the diagonal-conditioned simulation K times with fresh noise and
# store one mean estimate per run.
K = 32
means = np.zeros((K,M.dim))

# Loop-invariant, so built once: passes each endpoint (coordinates, chart)
# through M.update_coords with chart x[1] and keeps the coordinate part.
to_reference_chart = jax.vmap(lambda _x,chart: M.update_coords((_x,chart),x[1])[0],0)

_dts = dts(T=.2)
for i in range(K):
    (ts,xss,_chartss) = M.diagonal((samples,chartss),
                             _dts,dWs(N*M.dim,_dts).reshape(-1,N,M.dim),x[1],jnp.repeat(1.,N))
    means[i] = jnp.mean(to_reference_chart(xss[-1],_chartss[-1]),0)

colormap = plt.get_cmap('winter')
colors=[colormap(k) for k in np.linspace(0, 1, K)]

# plot the K estimated means; their spread shows the sampling variability
M.newfig()
M.plot()
for j in range(K):
    M.plotx((means[j],x[1]),color=colors[j],s=30)
ax = plt.gca(); ax.view_init(60, 45) # rotate
plt.axis('off')
plt.show()

## Fréchet mean

In [None]:
# Initialize the Frechet-mean estimator on M; it is called as M.Frechet_mean.
from src.statistics import Frechet_mean
Frechet_mean.initialize(M)

# Estimate the mean of the samples, starting the iteration at x.
# NOTE(review): vs is indexed per sample below and passed to M.Expt together
# with m -- presumably the tangent vector at m pointing to sample i; confirm.
m,loss,iterations,vs = M.Frechet_mean(zip(samples,chartss),x)
# m,loss,iterations = M.Frechet_mean(zip(samples,chartss),x,Log=lambda *args: M.Log(*args))
print("loss = ", loss)
print("mean = ", m)

# plot
newfig()
M.plot()
M.plotx(m,s=100,color='g')

for i in range(len(samples)):
    # Drawing the curve from the mean is best effort: a failure for one
    # sample must not abort the figure. Only ordinary exceptions are caught
    # (so Ctrl-C still interrupts), and the failure is reported, not hidden.
    try:
        (xs,charts) = M.Expt(m,vs[i])
        M.plot_path(zip(xs,charts))
    except Exception as err:
        print("could not plot geodesic for sample", i, ":", err)
    M.plotx((samples[i],chartss[i]),linewidth = 1.5, s=50, color='r')
# Path of the optimizer iterates.
M.plot_path(iterations,color='y')
plt.show()