A notebook to investigate the behaviour of the MMD-based three-sample test of Bounliphone et al.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'

import kmod
import kgof
import kgof.goftest as gof
# submodules
from kmod import data, density, kernel, util, plot
from kmod import mctest as mct
import matplotlib
import matplotlib.pyplot as plt
import autograd.numpy as np
import scipy.stats as stats

In [None]:
# Apply the plotting defaults provided by kmod.plot.
# NOTE(review): the exact rcParams this call sets are not visible from this notebook.
plot.set_default_matplotlib_options()

# Manual font/line settings, currently disabled (presumably superseded by the
# call above -- confirm before deleting):
# font = {
#     #'family' : 'normal',
#     #'weight' : 'bold',
#     'size'   : 18
# }
#
# plt.rc('font', **font)
# plt.rc('lines', linewidth=2)
# matplotlib.rcParams['pdf.fonttype'] = 42
# matplotlib.rcParams['ps.fonttype'] = 42

## 1D Gaussian mean shift

$$p = \mathcal{N}(\mu_p, 1)$$
$$q = \mathcal{N}(\mu_q, 1)$$
$$r = \mathcal{N}(0, 1)$$

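Assume that $\mu_p\neq\mu_q \neq 0$. Assume that a Gaussian kernel $k(x,y) = \exp(-(x-y)^2/(2\nu^2))$ is used. Then the exact form of $\mathrm{MMD}^2$ is known (Garreau 2017). The expression below is stated for $p$; the one for $q$ is obtained by replacing $\mu_p$ with $\mu_q$.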

$$\mathrm{MMD}^2(p, r) = \frac{2\nu}{\sqrt{\nu^2 + 2}} \big( 1-\exp\big[ -\frac{\mu_p^2}{2(\nu^2+2)} \big] \big)$$

In [None]:
def mmd2_gauss(mu, gwidth):
    """
    Closed-form squared MMD between N(mu, 1) and N(0, 1) under a Gaussian
    kernel k(x, y) = exp(-(x-y)^2 / (2*gwidth^2)).

    mu: mean of the model
    gwidth: Gaussian width NOT squared

    Return the squared MMD. Array-valued arguments are handled elementwise.
    """
    # nu^2 + 2 appears both in the prefactor and inside the exponential.
    width2_plus_2 = gwidth**2 + 2.0
    prefactor = 2.0*gwidth/np.sqrt(width2_plus_2)
    decay = np.exp(-mu**2/(2.0*width2_plus_2))
    return prefactor*(1.0 - decay)

def stat_3sample(mup, muq, gwidth):
    """
    Closed-form three-sample statistic: the squared MMD computed by
    mmd2_gauss for mean mup minus the one computed for mean muq.

    mup: mean of the model p
    muq: mean of the model q
    gwidth: Gaussian width NOT squared
    """
    mmd2_p = mmd2_gauss(mup, gwidth)
    mmd2_q = mmd2_gauss(muq, gwidth)
    return mmd2_p - mmd2_q

In [None]:
def plot_stat_vs_width(mup, muq):
    """
    Plot the value of stat_3sample(mup, muq, width) against the Gaussian
    kernel width, on a grid of 200 widths between 0.01 and 5.

    mup: mean of the model p
    muq: mean of the model q
    """
    widths = np.linspace(1e-2, 5, 200)
    stat_values = stat_3sample(mup, muq, widths)

    _, ax = plt.subplots(figsize=(8, 5))
    ax.plot(widths, stat_values, 'r-')
    ax.set_xlabel('Gaussian width')
    ax.set_ylabel('stat')


In [None]:
# Plot the closed-form squared MMD against the kernel width for a few values of mu.
mus = [0, 1, 2]
dom = np.linspace(1e-2, 5, 200)
fig, ax = plt.subplots(figsize=(8, 5))
# The loop index was never used, so iterate over the means directly.
for mu in mus:
    mmd2s = mmd2_gauss(mu, dom)
    ax.plot(dom, mmd2s, label=r'$\mu={}$'.format(mu))
ax.set_xlabel('Gaussian width')
ax.set_ylabel('Squared MMD')
# Trailing semicolon keeps the Legend object's repr out of the cell output.
ax.legend(fontsize=22);

In [None]:
import ipywidgets
from ipywidgets import interact, interactive, fixed
from IPython.display import display
import ipywidgets as widgets

# Sliders for the means of the two models p and q.
mup_slide = ipywidgets.FloatSlider(value=1, min=-3, max=3, step=0.5)
muq_slide = ipywidgets.FloatSlider(value=0.5, min=-3, max=3.0, step=0.5)
# interact() renders the widget itself and returns the wrapped function, so
# passing its return value to display() only printed the function's repr.
# `vs` is kept in case other cells refer to it.
vs = interact(plot_stat_vs_width, mup=mup_slide, muq=muq_slide)