In [None]:
import numpy as np
from matplotlib import pyplot
import matplotlib
from mpl_toolkits.mplot3d import Axes3D

In [None]:
from path_guiding import VMFMixture, VMFFitIncremental

In [None]:
# Switch matplotlib to the interactive "notebook" (nbagg) backend, presumably so the
# 3-D sphere plots below can be rotated with the mouse.
# NOTE(review): this backend is reportedly unsupported in JupyterLab / Notebook 7;
# there `%matplotlib widget` (needs the ipympl package) is the usual replacement.
%matplotlib notebook

In [None]:
def plot_vmf_pdf(ax, vmfm, resolution=100, cmap=None):
    """Draw the unit sphere on a 3-D axes, coloured by a mixture's pdf.

    The pdf is evaluated on a ``resolution`` x ``resolution`` grid of spherical
    coordinates and divided by its maximum over that grid, so the colours span
    the colormap's [0, 1] input range with 1 at the densest grid point.

    Parameters
    ----------
    ax : 3-D matplotlib Axes
        Axes created with ``projection='3d'``; only ``plot_surface`` is used.
    vmfm : object with a ``pdf(points)`` method
        Called with an (N, 3) array of unit vectors; must return N values.
    resolution : int, optional
        Number of grid samples along each spherical angle (default 100).
    cmap : callable, optional
        Colormap mapping values in [0, 1] to RGBA. Defaults to
        ``matplotlib.cm.coolwarm``.

    Returns
    -------
    numpy.ndarray
        The normalised pdf values, shape ``(resolution, resolution)``. If the
        pdf is zero everywhere on the grid the values are returned unscaled
        (all zeros) instead of becoming NaN.
    """
    if cmap is None:
        cmap = matplotlib.cm.coolwarm

    # Grid construction follows
    # https://scipython.com/book/chapter-8-scipy/examples/visualizing-the-spherical-harmonics/
    # Here phi is the polar angle (measured from +z) and theta the azimuth.
    phi = np.linspace(0, np.pi, resolution)
    theta = np.linspace(0, 2*np.pi, resolution)
    phi, theta = np.meshgrid(phi, theta)

    # The Cartesian coordinates of the unit sphere
    x = np.sin(phi) * np.cos(theta)
    y = np.sin(phi) * np.sin(theta)
    z = np.cos(phi)

    pts = np.vstack((x.ravel(), y.ravel(), z.ravel())).T
    fcolors = np.asarray(vmfm.pdf(pts)).reshape(x.shape)
    peak = fcolors.max()
    # Guard the normalisation: 0/0 would turn every face colour into NaN.
    # Non-in-place division also works when pdf returns an integer array.
    if peak > 0:
        fcolors = fcolors / peak

    ax.plot_surface(x, y, z, rstride=1, cstride=1, facecolors=cmap(fcolors), shade=False)
    return fcolors

In [None]:
def two_modes_vmfm():
    """Build a two-lobe mixture: weight 0.2 towards +x and 0.8 towards +y.

    Every other component of the default ``VMFMixture`` is switched off by
    giving it zero weight (its mean is zeroed as well). All components share
    a concentration of 20.
    """
    mixture = VMFMixture()

    lobe_weights = mixture.weights
    lobe_weights[:] = 0
    lobe_weights[0], lobe_weights[1] = 0.2, 0.8
    mixture.weights = lobe_weights

    kappas = mixture.concentrations
    kappas[:] = 20
    mixture.concentrations = kappas

    directions = mixture.means
    directions[:, :] = 0
    directions[0, :] = [1, 0, 0]
    directions[1, :] = [0, 1, 0]
    mixture.means = directions

    return mixture

def uniform_vmfm():
    """Return a single broad lobe used as a (near-)uniform distribution.

    Only component 0 keeps any weight (1.0), and every concentration is set
    to 0.1, so the active lobe is presumably very close to uniform on the
    sphere (but not exactly, since the concentration is not 0). Mean
    directions are left at the ``VMFMixture`` defaults.
    """
    mixture = VMFMixture()

    lobe_weights = mixture.weights
    lobe_weights[:] = 0
    lobe_weights[0] = 1.
    mixture.weights = lobe_weights

    kappas = mixture.concentrations
    kappas[:] = 0.1
    mixture.concentrations = kappas

    return mixture

def make_half_sphere_vmfm():
    """Return the default ``VMFMixture`` restricted to the x >= 0 hemisphere.

    Components whose mean direction has a negative x coordinate get zero
    weight. The remaining weights are rescaled so they sum to one, preserving
    their relative proportions. This matches the previous "double the
    survivors" rule when exactly half of the default weight lies in x >= 0,
    and keeps the mixture normalised when it does not.

    Raises
    ------
    ValueError
        If no component with a mean in the x >= 0 hemisphere carries weight.
    """
    vmfm = VMFMixture()
    in_negative_half = vmfm.means[:, 0] < 0.
    w = vmfm.weights
    w[in_negative_half] = 0.
    total = w.sum()
    if not total > 0.:
        raise ValueError("default mixture has no weight on the x >= 0 hemisphere")
    w /= total
    vmfm.weights = w
    return vmfm

def _shuffle(xs, ws):
    n = xs.shape[0]
    idx = np.random.permutation(n)
    return xs[idx,...], ws[idx]


def make_uniform_samples(n):
    """Draw ``n`` samples from ``uniform_vmfm()`` with unit weights, shuffled.

    Returns
    -------
    (samples, weights)
        The drawn samples (row-permuted by ``_shuffle``) and a float32 array
        of ``n`` ones permuted alongside them.
    """
    drawn = uniform_vmfm().sample(n)
    unit_weights = np.ones(n, np.float32)
    return _shuffle(drawn, unit_weights)


def make_two_modes_samples(n):
    """Draw ``n`` samples from ``two_modes_vmfm()`` with unit weights, shuffled.

    Returns
    -------
    (samples, weights)
        The drawn samples (row-permuted by ``_shuffle``) and a float32 array
        of ``n`` ones permuted alongside them.
    """
    drawn = two_modes_vmfm().sample(n)
    unit_weights = np.ones(n, np.float32)
    return _shuffle(drawn, unit_weights)


def make_half_sphere_samples(n):
    """Draw ``n`` samples from ``make_half_sphere_vmfm()`` with unit weights, shuffled.

    Returns
    -------
    (samples, weights)
        The drawn samples (row-permuted by ``_shuffle``) and a float32 array
        of ``n`` ones permuted alongside them.
    """
    drawn = make_half_sphere_vmfm().sample(n)
    unit_weights = np.ones(n, np.float32)
    return _shuffle(drawn, unit_weights)


def make_two_modes_weighted(n):
    """Importance-weighted samples whose weighted distribution is ``two_modes_vmfm()``.

    Draws ``n`` directions from the broad proposal ``uniform_vmfm()`` and
    weights each by ``target_pdf / proposal_pdf``. The proposal has a small
    but non-zero concentration (0.1), so it is not exactly uniform. Weighting
    by the target pdf alone would therefore be slightly biased.

    The ratio is multiplied by the exact uniform density ``1 / (4*pi)``. For a
    perfectly uniform proposal this reduces to the raw target pdf (the
    weighting used previously), so the weights keep roughly the same scale
    for the fitter.

    Returns
    -------
    (samples, weights)
        Shuffled together by ``_shuffle``.
    """
    target = two_modes_vmfm()
    proposal = uniform_vmfm()
    samples = proposal.sample(n)
    weights = target.pdf(samples) / (4. * np.pi * proposal.pdf(samples))
    return _shuffle(samples, weights)

In [None]:
# Fit a fresh mixture to 1000 unit-weight samples of the near-uniform mixture.
prior_nu = 10.
prior_alpha = 10.
prior_tau = 10.
maximization_step_every = 100
prior_mode = VMFMixture()
incremental = VMFFitIncremental(
    prior_nu=prior_nu,
    prior_alpha=prior_alpha,
    prior_tau=prior_tau,
    prior_mode=prior_mode,
    maximization_step_every=maximization_step_every)

xs, ws = make_uniform_samples(1000)
vmfm = VMFMixture()
incremental.fit(vmfm, xs, ws)

fig = pyplot.figure()
ax = fig.add_subplot(111, projection='3d')
v = plot_vmf_pdf(ax, vmfm)
ax.set_title('Fit to uniform samples')

print(vmfm.concentrations)

# plot_vmf_pdf applies coolwarm directly to pdf values normalised to [0, 1].
# Pin the colorbar to that same range; otherwise it autoscales to
# [v.min(), v.max()] and its colours no longer match the surface.
m = matplotlib.cm.ScalarMappable(cmap=matplotlib.cm.coolwarm)
m.set_array(v)
m.set_clim(0., 1.)
# Pass ax explicitly: newer matplotlib releases refuse to guess which Axes a
# colorbar for a mappable that is not attached to any Axes should steal space from.
fig.colorbar(m, ax=ax, label='pdf / max pdf')

pyplot.show()

In [None]:
# Fit a fresh mixture to 1000 unit-weight samples of the default mixture with
# its x < 0 components switched off (see make_half_sphere_vmfm).
prior_nu = 10.
prior_alpha = 10.
prior_tau = 10.
maximization_step_every = 100
prior_mode = VMFMixture()
incremental = VMFFitIncremental(
    prior_nu=prior_nu,
    prior_alpha=prior_alpha,
    prior_tau=prior_tau,
    prior_mode=prior_mode,
    maximization_step_every=maximization_step_every)

xs, ws = make_half_sphere_samples(1000)
vmfm = VMFMixture()
incremental.fit(vmfm, xs, ws)

fig = pyplot.figure()
ax = fig.add_subplot(111, projection='3d')
v = plot_vmf_pdf(ax, vmfm)
ax.set_title('Fit to half-sphere samples')

print(vmfm.concentrations)

# plot_vmf_pdf applies coolwarm directly to pdf values normalised to [0, 1].
# Pin the colorbar to that same range; otherwise it autoscales to
# [v.min(), v.max()] and its colours no longer match the surface.
m = matplotlib.cm.ScalarMappable(cmap=matplotlib.cm.coolwarm)
m.set_array(v)
m.set_clim(0., 1.)
# Pass ax explicitly: newer matplotlib releases refuse to guess which Axes a
# colorbar for a mappable that is not attached to any Axes should steal space from.
fig.colorbar(m, ax=ax, label='pdf / max pdf')

pyplot.show()

In [None]:
# Fit a fresh mixture to 1000 unit-weight samples of the two-lobe mixture,
# using a weaker prior_tau (1.) than the cells above.
prior_nu, prior_alpha, prior_tau = 10., 10., 1.
maximization_step_every = 100
prior_mode = VMFMixture()
incremental = VMFFitIncremental(prior_nu=prior_nu,
                                prior_alpha=prior_alpha,
                                prior_tau=prior_tau,
                                prior_mode=prior_mode,
                                maximization_step_every=maximization_step_every)

xs, ws = make_two_modes_samples(1000)
vmfm = VMFMixture()
incremental.fit(vmfm, xs, ws)

fig, ax = pyplot.subplots(subplot_kw={'projection': '3d'})
plot_vmf_pdf(ax, vmfm)
pyplot.show()

In [None]:
# Same fit as the two-lobe case, but from importance-weighted samples drawn
# from the broad proposal; fitted mean directions are marked with red crosses.
prior_nu, prior_alpha, prior_tau = 10., 10., 1.
maximization_step_every = 100
prior_mode = VMFMixture()
incremental = VMFFitIncremental(prior_nu=prior_nu,
                                prior_alpha=prior_alpha,
                                prior_tau=prior_tau,
                                prior_mode=prior_mode,
                                maximization_step_every=maximization_step_every)

xs, ws = make_two_modes_weighted(1000)
vmfm = VMFMixture()
incremental.fit(vmfm, xs, ws)

fig, ax = pyplot.subplots(subplot_kw={'projection': '3d'})
plot_vmf_pdf(ax, vmfm)
# Scale the means by 1.1 so the markers sit just outside the unit sphere.
ax.scatter(*(1.1 * vmfm.means).T, marker='x', c='r', s=30.)
pyplot.show()