In [1]:
# Start from a clean figure output directory: delete img/ with everything in it and recreate it empty.
!rm -rf img/ && mkdir -p img/

In [2]:
import numpy as np

import scipy
from scipy.stats import norm as gauss

import matplotlib.pyplot as plt
import matplotlib.cm
import nbfigtulz as ftl

import jax
from jax import numpy as jnp
from jax import grad, jit
import jax.scipy.optimize

from iminuit import Minuit

# Setup

In [3]:
SEED = 42

# JAX: compute in double precision and run all computations on the CPU.
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_platform_name', 'cpu')

# Seed NumPy's global (legacy) RNG; later np.random.* calls and scipy.stats
# sampling without an explicit random_state draw from it.
np.random.seed(SEED)

# Degeneration

In [4]:
# Parameter grids for the surface plot: 101 evenly spaced values in [1, 10] each.
beta, kappa = np.linspace(1, 10, num=101), np.linspace(1, 10, num=101)

In [5]:
@ftl.with_context
def make_fig(beta, kappa, filename):
    """Plot the surface z = beta * (1 + kappa) / kappa over the (beta, kappa) grid.

    The surface is drawn semi-transparently together with its contour lines
    projected onto the z = 0 plane, and the figure is saved via ``ftl.save_fig``.

    Args:
        beta: 1D array of beta values (x axis).
        kappa: 1D array of kappa values (y axis).
        filename: Base name of the saved figure.

    Returns:
        Whatever ``ftl.save_fig`` returns.
    """
    fig = plt.figure()
    # `Figure.gca(projection=...)` was deprecated in Matplotlib 3.4 and removed
    # in 3.6; `add_subplot(projection='3d')` is the supported replacement.
    ax = fig.add_subplot(projection='3d')

    x, y = np.meshgrid(beta, kappa)
    z = x * (1. + y) / y

    ax.plot_surface(x, y, z, alpha=.3)
    # Project the level sets onto the bottom (z = 0) plane.
    ax.contour(x, y, z, zdir='z', offset=0, cmap=matplotlib.cm.coolwarm)

    # Keep the Greek axis labels upright instead of rotating them along the axes.
    ax.xaxis.set_rotate_label(False)
    ax.yaxis.set_rotate_label(False)

    ax.set_xlabel(r'$\beta$')
    ax.set_ylabel(r'$\kappa$')
    ax.set_zlabel(r'$\frac{\beta (1 + \kappa)}{\kappa}$')

    ax.grid(False)

    return ftl.save_fig(fig, filename)


make_fig(beta, kappa, 'degeneration')

img/degeneration.png
img/degeneration.pgf


degeneration.png

In [6]:
# Post-process the 3D figure in place with ImageMagick's `convert`: flatten it onto its
# background, trim the (near-)uniform border (1% colour tolerance) and reset the page geometry.
!convert img/degeneration.png -flatten -fuzz 1% -trim +repage img/degeneration.png

# Gaussian Regression

In [7]:
n = 5

# Toy regression data, one row per point: column 0 is x = 0, ..., n - 1,
# column 1 is y ~ N(x, 1) and column 2 (z) stays 0.
data = np.zeros((n, 3))
data[:, 0] = np.arange(n)

for i in range(n):
    data[i, 1] = np.random.normal(loc=i)

In [8]:
@ftl.with_context
def make_fig(data, mu, sigma, filename):
    """3D sketch of a Gaussian regression model.

    Draws the data points (black crosses), the model means ``mu`` (dashed
    line), one Gaussian PDF N(mu[i], sigma[i]) along the y direction at each
    point's x position, and marks each point's likelihood, i.e. the model PDF
    evaluated at its y, with a dot.

    Args:
        data: (n, 3) array of (x, y, z) points.
        mu: Length-n array of model means.
        sigma: Length-n array of model standard deviations.
        filename: Base name of the saved figure.

    Returns:
        Whatever ``ftl.save_fig`` returns.
    """
    fig = plt.figure()
    # `Figure.gca(projection=...)` was removed in Matplotlib 3.6.
    ax = fig.add_subplot(projection='3d')

    n = data.shape[0]

    ax.plot(data[:, 0], data[:, 1], data[:, 2], 'x', color='black')
    ax.plot(data[:, 0], mu, data[:, 2], '--', color='black', alpha=.5)

    # One Gaussian per data point, drawn along y at that point's x position
    # (consistent with the markers and the dashed mean line above).
    y = np.linspace(-2, n + 1, 501)
    for x_i, mu_i, sigma_i in zip(data[:, 0], mu, sigma):
        z = gauss.pdf(y, loc=mu_i, scale=sigma_i)
        ax.plot(np.full_like(y, x_i), y, z, color='C1', alpha=.7)

    # Likelihood of each point: the model PDF evaluated at the observed y.
    z = gauss.pdf(data[:, 1], loc=mu, scale=sigma)
    ax.plot(data[:, 0], data[:, 1], z, 'o', color='C0', alpha=.7)

    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
    ax.set_zlabel('$z$')

    ax.grid(False)

    return ftl.save_fig(fig, filename)

In [9]:
# Unit-width Gaussians centred on the line y = x.
n = len(data)
make_fig(data, np.arange(0, n), np.full(n, 1.), 'greg')

img/greg.png
img/greg.pgf


greg.png

In [10]:
# Narrow Gaussians (sigma = 0.3) centred on the sampled points themselves.
n = len(data)
make_fig(data, data[:, 1], np.full(n, .3), 'greg2')

img/greg2.png
img/greg2.pgf


greg2.png

# Fit Student's $t$-distributions

In this section we study fits of Student's $t$-distribution functions to synthetic data samples.

First, we define the probability density function (PDF) of Student's $t$-distribution and its logarithm. To speed up computation, we JIT-compile both using `jax.jit`.

In [11]:
@jit
def log_studentst(x, nu, mu, sigma):
    """Log-PDF of Student's t-distribution with location ``mu`` and scale ``sigma``.

    log f(x) = log Gamma((nu + 1) / 2) - log Gamma(nu / 2) - 1/2 log(pi nu)
               - log sigma - (nu + 1) / 2 * log(1 + ((x - mu) / sigma)^2 / nu)

    All arguments broadcast against each other (jax.numpy semantics).

    Args:
        x: Points at which the log-PDF is evaluated.
        nu: Degrees of freedom (nu > 0).
        mu: Location.
        sigma: Scale (sigma > 0).

    Returns:
        log f(x) with the broadcast shape of the inputs.
    """
    norm = (jax.scipy.special.gammaln((nu + 1.) / 2.)
            - jax.scipy.special.gammaln(nu / 2.)
            - .5 * jnp.log(jnp.pi * nu)
            - jnp.log(sigma))
    arg2 = ((x - mu) / sigma) ** 2
    # log1p keeps full precision when arg2 / nu << 1 (large nu or x close to mu).
    return norm - .5 * (nu + 1.) * jnp.log1p(arg2 / nu)


@jit
def studentst(x, nu, mu, sigma):
    """PDF of Student's t-distribution; parameters as for ``log_studentst``."""
    return jnp.exp(log_studentst(x, nu, mu, sigma))

We then define a function computing the negative log-likelihood (NLL), which we will use later.

In [17]:
def nll(data, nu, mu, sigma):
    """Negative log-likelihood of ``data`` under a t-distribution (nu, mu, sigma).

    Returns the scalar -sum_i log f(data_i | nu, mu, sigma).
    """
    log_likelihood = jnp.sum(log_studentst(data, nu, mu, sigma))
    return -log_likelihood

Then, we cross-check whether our implementation produces the same results as `scipy.stats.t.pdf`...

In [12]:
# Random test parameters (drawn in this order): nu in [1, 5), mu and sigma in [0, 1).
x = np.linspace(-5, 5, 500)
nu = 1 + 4 * np.random.rand()
mu, sigma = np.random.rand(), np.random.rand()

assert np.allclose(
    studentst(x, nu, mu, sigma),
    scipy.stats.t.pdf(x, nu, loc=mu, scale=sigma),
)

In [13]:
@ftl.with_context
def make_fig(x, show_title=False):
    """Compare Student's t PDFs (nu = 1, 2, 5; mu = 0, sigma = 1) with the standard normal.

    Args:
        x: Points at which the PDFs are evaluated and plotted.
        show_title: Whether to add a title to the axes.

    Returns:
        Whatever ``ftl.save_fig`` returns for the figure saved as 'students'.
    """
    fig, ax = plt.subplots()

    # The nu -> infinity limit of the t-distribution is the standard normal.
    ax.plot(x, scipy.stats.norm.pdf(x, loc=0., scale=1.), '-', color='C1',
            label=r'$\nu \to \infty$')

    # (degrees of freedom, line style); plotted from the largest nu downwards.
    curves = [
        (1, 'dotted'),
        (2, '-.'),
        (5, '--'),
    ]
    for nu, linestyle in reversed(curves):
        ax.plot(x, studentst(x, nu=nu, mu=0., sigma=1.),
                linestyle=linestyle, label=r'$\nu=' + f'{nu}$', color='black')

    if show_title:
        ax.set_title("Student's $t$ distribution with $\\nu$ DoF")

    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\mathrm{PDF}$')
    ax.legend()

    return ftl.save_fig(fig, 'students')


make_fig(np.linspace(-4, 4, 500))

img/students.png
img/students.pgf


students.png

We can now _fit_ a $t$-distribution $f(x|\nu, \sigma)$ (for brevity, we drop the _Student's_ prefix from here on) with $\nu=5$ DoF to a $t$-distribution with $\nu=2$ DoF by tuning $\sigma$ to minimize the squared difference of the two PDFs on the interval $x \in [-3, +3]$.

In [15]:
@jit
def diff(sigma):
    """Squared distance between t_2(mu=0, sigma=1) and t_5(mu=0, sigma) on [-3, 3].

    The distance is the sum of squared PDF differences on 500 evenly spaced
    points; minimising it over ``sigma`` matches the nu = 5 shape to nu = 2.

    Args:
        sigma: Scale of the nu = 5 distribution; a scalar or a length-1 array
            (the latter is what ``jax.scipy.optimize.minimize`` passes).

    Returns:
        Scalar sum of squared differences.
    """
    x = np.linspace(-3., 3., 500)  # fixed evaluation grid (a trace-time constant)
    y_target = studentst(x, 2., 0., 1.)
    y_fit = studentst(x, 5., 0., sigma)
    return jnp.sum((y_target - y_fit) ** 2)


@ftl.with_context
def make_fig(x):
    """Plot t_2(sigma=1) together with the t_5 curve whose scale best matches it.

    The scale of the nu = 5 curve is found by minimising ``diff`` (squared PDF
    difference on [-3, 3]) with BFGS, starting from sigma = 1.

    Args:
        x: Points at which both PDFs are drawn.

    Returns:
        Whatever ``ftl.save_fig`` returns for the figure saved as 'dfit'.
    """
    fig, ax = plt.subplots()

    ax.plot(x, studentst(x, 2., 0., 1.), 'k--', label=r'$t_2(\sigma=1)$')

    # `minimize` works on 1D parameter vectors, hence the length-1 start array.
    result = jax.scipy.optimize.minimize(diff, np.array([1.]), method='BFGS')
    sigma = float(result.x[0])
    ax.plot(x, studentst(x, 5., 0., sigma), color='C1', label=r'$t_5(\sigma=' + f'{sigma:.2f})$')

    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$\mathrm{PDF}$')
    ax.legend()

    return ftl.save_fig(fig, 'dfit')


# Draw on the same interval [-3, 3] on which the squared difference is minimised.
make_fig(np.linspace(-3, 3, 500))

img/dfit.png
img/dfit.pgf


dfit.png

We test our fitting procedure by fitting our PDF (with $\nu$, $\mu$ and $\sigma$ free) to $t$-distributed data with $\nu = 5$ and $\sigma = 1$, using the first 200 of $10\,000$ generated points...

In [16]:
# Ground-truth sample: 10 000 draws from a t-distribution with nu_gt degrees of
# freedom, mu = 0 and sigma = 1, taken from NumPy's global RNG.
nu_gt = 5
data = scipy.stats.t(nu_gt, loc=0., scale=1.).rvs(size=10_000)

In [18]:
@jit
def fcn(par):
    """NLL of the first 200 data points as a function of par = (nu, mu, sigma)."""
    nu, mu, sigma = par
    return nll(data[:200], nu, mu, sigma)


# Start at the ground truth; limits: nu >= 0, mu in [-1, 1], sigma in [0, 5].
x0 = (nu_gt, 0., 1.)
m = Minuit(fcn, x0, grad=jit(jax.grad(fcn)), name=('nu', 'mu', 'sigma'))
m.errordef = Minuit.LIKELIHOOD
m.limits = ((0, None), (-1., 1.), (0., 5.))
# migrad/hesse/minos each return the Minuit object, so the chain displays the final state.
m.migrad().hesse().minos()

0,1,2,3,4
FCN = 333.5,FCN = 333.5,Nfcn = 197,Nfcn = 197,Nfcn = 197
EDM = 1.02e-06 (Goal: 0.0001),EDM = 1.02e-06 (Goal: 0.0001),Ngrad = 5,Ngrad = 5,Ngrad = 5
Valid Minimum,Valid Parameters,No Parameters at limit,No Parameters at limit,No Parameters at limit
Below EDM threshold (goal x 10),Below EDM threshold (goal x 10),Below call limit,Below call limit,Below call limit
Covariance,Hesse ok,Accurate,Pos. def.,Not forced

0,1,2,3,4,5,6,7,8
,Name,Value,Hesse Error,Minos Error-,Minos Error+,Limit-,Limit+,Fixed
0.0,nu,3.3,0.8,-0.7,1.0,0,,
1.0,mu,0.12,0.08,-0.08,0.08,-1,1,
2.0,sigma,0.93,0.08,-0.08,0.09,0,5,

0,1,2,3,4,5,6
,nu,nu,mu,mu,sigma,sigma
Error,-0.7,1.0,-0.08,0.08,-0.08,0.09
Valid,True,True,True,True,True,True
At Limit,False,False,False,False,False,False
Max FCN,False,False,False,False,False,False
New Min,False,False,False,False,False,False

0,1,2,3
,nu,mu,sigma
nu,0.679,0.00279 (0.043),0.0457 (0.654)
mu,0.00279 (0.043),0.00631,0.00027 (0.040)
sigma,0.0457 (0.654),0.00027 (0.040),0.00721


...the fit converges and reveals a strong correlation between the parameters $\nu$ and $\sigma$ (correlation coefficient $\approx 0.65$).

Let's now fit $10\,000$ data points:

In [19]:
@jit
def _fit_nll(par, sample):
    """NLL of ``sample`` under a t-distribution with mu = 0 and par = (nu, sigma)."""
    return nll(sample, par[0], 0., par[1])


# Gradient with respect to ``par`` only (argument 0); compiled once per sample size.
_fit_nll_grad = jit(jax.grad(_fit_nll))


def fit(size=None, verbose=False):
    """Fit nu and sigma (mu fixed to 0) to a random subsample of ``data`` with Minuit.

    Start values are drawn uniformly from nu in [nu_gt - 0.5, nu_gt + 0.5) and
    sigma in [0.5, 1.5); the fit is limited to nu in [0, 10], sigma in [0, 2].

    Args:
        size: Number of randomly selected data points; a falsy value (None or 0)
            uses all of ``data``.
        verbose: Print the start values and the NLL at the start point, the fit
            result and the ground truth.

    Returns:
        ``((nu, sigma), rnd_idx)`` if MIGRAD converged with no parameter at a
        limit, otherwise ``(None, rnd_idx)``; ``rnd_idx`` holds the indices of
        the data points used.
    """
    rnd_idx = np.random.permutation(data.shape[0])
    if size:
        rnd_idx = rnd_idx[:size]
    sample = jnp.asarray(data[rnd_idx])

    # Thin wrappers around the module-level jitted functions: defining a fresh
    # @jit closure here would trigger a new compilation on every call of fit().
    def fcn(par):
        return _fit_nll(jnp.asarray(par), sample)

    def fcn_grad(par):
        return _fit_nll_grad(jnp.asarray(par), sample)

    nu = np.random.rand() + nu_gt - .5
    sigma = np.random.rand() + .5

    x0 = (nu, sigma)
    if verbose:
        print('x0:', x0)

    m = Minuit(fcn, x0, grad=fcn_grad, name=('nu', 'sigma'))
    m.errordef = Minuit.LIKELIHOOD
    m.limits = ((0, 10), (0., 2.))
    m.migrad()

    if not m.fmin.is_valid or m.fmin.has_parameters_at_limit:
        return None, rnd_idx

    nu = m.values['nu']
    sigma = m.values['sigma']
    if verbose:
        print(' x0 fcn', fcn(x0))
        print('Fit fcn', fcn([nu, sigma]))
        print(' GT fcn', fcn([nu_gt, 1.]))

    return (nu, sigma), rnd_idx

In [20]:
@ftl.with_context
def make_fig():
    """Fit all data points once and compare the fitted t-distribution with the ground truth.

    Shows a normalised histogram of ``data``, the fitted PDF and the
    ground-truth PDF (nu = nu_gt, sigma = 1), annotates the parameter values
    and saves the figure as 'highstatfit'.

    Returns:
        Whatever ``ftl.save_fig`` returns.

    Raises:
        RuntimeError: If the fit did not converge or hit a parameter limit.
    """
    # Fit before creating the figure so that a failure leaves no open figure behind.
    fit_res = fit(verbose=True)[0]
    if fit_res is None:
        raise RuntimeError('Fit to the full data set did not converge or hit a parameter limit.')
    nu, sigma = fit_res

    fig, ax = plt.subplots()
    ax.hist(data, np.linspace(-5, 5, 21), alpha=.5, rwidth=.9, density=True, label='Data')

    x = np.linspace(-5, 5, 501)
    ax.plot(x, studentst(x, nu, 0., sigma), label='Fit')
    ax.plot(x, studentst(x, nu_gt, 0., 1.), 'k--', label='GT')

    # Every line needs balanced $...$ delimiters to be rendered as math.
    ax.text(.05, .95, '\n'.join([
        r'$\nu_\mathrm{Fit} = ' + f'{nu:.2f}$',
        r'$\nu_\mathrm{GT} = ' + f'{nu_gt}$',
        r'$\sigma_\mathrm{Fit} = ' + f'{sigma:.2f}$',
        r'$\sigma_\mathrm{GT} = 1$',
    ]), transform=ax.transAxes, verticalalignment='top')

    ax.set_xlabel('$x$')
    ax.set_ylabel('PDF')

    ax.legend()

    return ftl.save_fig(fig, 'highstatfit')


make_fig()

x0: (5.451644348194464, 1.0312212458482315)
 x0 fcn 16267.428936146327
Fit fcn 16265.606678248825
 GT fcn 16266.047446873165
img/highstatfit.png
img/highstatfit.pgf


highstatfit.png

Again, the fit converges nicely and shows good agreement with the ground truth (GT). Note the rather large numerical deviation in $\nu$: values $\nu = 5 + \mathcal{O}(0.1)$ are almost indistinguishable even with large statistics.

Finally, we run multiple fits for different sample sizes. Some of the fits will fail to converge or have parameters reaching their limits; these fits are discarded. This is particularly problematic for $\nu$, since values $\nu \gg 5$ are almost indistinguishable for the fit. To avoid biasing our result, we print the full range of $\nu$ over all converged fits. If its upper end comes too close to our cut-off $\nu_\text{cut-off} = 10$, we have to increase the cut-off.

In [21]:
def _central_range(values, label, verbose=False):
    """Return (min, max) of ``values`` within their central 68% (16%-84% quantiles).

    With ``verbose`` the full and the central range are printed. Empty input
    yields ``(nan, nan)``; if no value lies between the two quantiles (possible
    only for very few values) the quantiles themselves are returned.
    """
    values = np.asarray(values)
    if values.size == 0:
        if verbose:
            print(f' * no converged fits, {label} range undefined')
        return np.nan, np.nan

    q16, q84 = np.quantile(values, [.16, .84])
    central = values[(values >= q16) & (values <= q84)]
    lo, hi = (np.min(central), np.max(central)) if central.size else (q16, q84)
    if verbose:
        print(f' * 100% {label}: [{np.min(values):.2f} .. {np.max(values):.2f}]')
        print(f'    68% {label}: [{lo:.2f} .. {hi:.2f}]')
    return lo, hi


def run_fits(n, n_iter=200, verbose=False):
    """Repeat ``fit`` for several sample sizes and collect central 68% parameter ranges.

    Args:
        n: 1D array of sample sizes.
        n_iter: Number of fits per sample size (each with a fresh random
            subsample and start point).
        verbose: Print convergence statistics and parameter ranges per size.

    Returns:
        ``(nu, sigma)``: two ``(len(n), 2)`` arrays whose rows hold the
        (lower, upper) end of the central 68% range of the converged fit values
        for the corresponding sample size; NaN if no fit converged.
    """
    nu = np.empty((n.shape[0], 2))
    sigma = np.empty((n.shape[0], 2))
    for i, k in enumerate(n):
        if verbose:
            print(f'n = {k}')

        converged = []
        for _ in range(n_iter):
            fit_res, _idx = fit(k)
            if fit_res is not None:
                converged.append(fit_res)

        if verbose:
            print(f' * {len(converged)} / {n_iter} fits converged')

        nu[i] = _central_range([res[0] for res in converged], 'nu', verbose)
        sigma[i] = _central_range([res[1] for res in converged], 'sigma', verbose)
        if verbose:
            print('')

    return nu, sigma


n = np.array([50, 75, 100, 125, 150, 175, 200])
nu, sigma = run_fits(n, n_iter=200, verbose=True)

n = 50
 * 91 / 200 fits converged
 * 100% nu: [1.67 .. 6.28]
    68% nu: [2.65 .. 5.06]
 * 100% sigma: [0.65 .. 1.25]
    68% sigma: [0.79 .. 1.04]

n = 75
 * 112 / 200 fits converged
 * 100% nu: [2.11 .. 6.71]
    68% nu: [3.33 .. 5.47]
 * 100% sigma: [0.72 .. 1.39]
    68% sigma: [0.85 .. 1.09]

n = 100
 * 112 / 200 fits converged
 * 100% nu: [1.93 .. 7.18]
    68% nu: [3.37 .. 5.58]
 * 100% sigma: [0.71 .. 1.28]
    68% sigma: [0.85 .. 1.08]

n = 125
 * 132 / 200 fits converged
 * 100% nu: [2.58 .. 7.65]
    68% nu: [3.64 .. 6.28]
 * 100% sigma: [0.76 .. 1.30]
    68% sigma: [0.88 .. 1.07]

n = 150
 * 135 / 200 fits converged
 * 100% nu: [2.52 .. 7.60]
    68% nu: [3.78 .. 6.40]
 * 100% sigma: [0.80 .. 1.24]
    68% sigma: [0.91 .. 1.07]

n = 175
 * 148 / 200 fits converged
 * 100% nu: [2.55 .. 7.73]
    68% nu: [3.90 .. 6.51]
 * 100% sigma: [0.78 .. 1.20]
    68% sigma: [0.91 .. 1.07]

n = 200
 * 155 / 200 fits converged
 * 100% nu: [2.76 .. 7.95]
    68% nu: [3.97 .. 6.59]
 * 100%

All that's left is to visualize the deviations... 

In [22]:
@ftl.with_context
def make_fig(x, nu, sigma):
    """Plot the fitted-minus-true parameter ranges against the sample size.

    Each point is the midpoint of a (lower, upper) range with asymmetric error
    bars reaching both ends. The nu points are shifted by +3 and the sigma
    points by -3 along x so the error bars do not overlap; sigma residuals are
    multiplied by ``sigma_scale`` to be visible next to the nu residuals.

    Args:
        x: 1D array of sample sizes (x axis).
        nu: ``(len(x), 2)`` array of (lower, upper) ends of the nu ranges.
        sigma: ``(len(x), 2)`` array of (lower, upper) ends of the sigma ranges.

    Returns:
        Whatever ``ftl.save_fig`` returns for the figure saved as 'res'.
    """
    fig, ax = plt.subplots()

    sigma_gt = 1.
    sigma_scale = 10.

    def _center_and_errors(bounds):
        """Midpoint of each (lower, upper) row and its (2, N) distances to both ends."""
        center = (bounds[:, 0] + bounds[:, 1]) / 2.
        return center, np.array([center - bounds[:, 0], bounds[:, 1] - center])

    nu_mean, nu_error = _center_and_errors(nu)
    ax.errorbar(x + 3,
                nu_mean - nu_gt,
                yerr=nu_error,
                fmt='o',
                label=r'$\nu_\mathrm{Fit} - \nu_\mathrm{GT}$ @ 68\% CL')

    sigma_mean, sigma_error = _center_and_errors(sigma)
    ax.errorbar(x - 3,
                sigma_scale * (sigma_mean - sigma_gt),
                yerr=sigma_scale * sigma_error,
                fmt='x',
                # Derive the factor from sigma_scale so label and data cannot diverge.
                label=rf'${sigma_scale:g} \times (\sigma_\mathrm{{Fit}} - \sigma_\mathrm{{GT}})$ @ 68\% CL')

    ax.set_xlabel('Sample size')
    ax.set_ylabel(r'Residual')
    ax.set_title(rf'$\nu_\mathrm{{GT}}={nu_gt}$, $\sigma_\mathrm{{GT}}={sigma_gt:g}$')

    # Symmetric y range so that positive and negative biases are comparable.
    y_lim = np.max(np.abs(ax.get_ylim()))
    ax.set_ylim(-y_lim, y_lim)

    ax.legend()
    ax.grid()

    return ftl.save_fig(fig, 'res')


make_fig(n, nu, sigma)

img/res.png
img/res.pgf


res.png

...revealing a slight bias for $\nu$ and $\sigma$...