# Initialize

In [None]:
# %load ./init.ipy
%reload_ext autoreload
%autoreload 2

# %load init.ipy
import os, sys, logging, datetime, warnings, shutil
from importlib import reload

import numpy as np
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt
from nose import tools

import kalepy as kale
import kalepy.utils
import kalepy.plot
import kalepy.sample

# The `nbshow` command runs `plt.show()` in interactive jupyter notebooks, but closes
#   figures when run from the command-line (notebooks are converted to scripts and run as tests)
from kalepy.plot import nbshow

import warnings
# warnings.simplefilter("error")   # WARNING: this is breaking jupyter at the moment (2021-02-14)

In [None]:
np.random.seed(1234567)

In [None]:
def draw_example(xx, yy, data):
    """Sample the gridded `data` with and without interpolation and plot both results.

    Two panels are drawn side by side: the left uses `interpolate=False`, the right
    `interpolate=True`.  Each panel scatters the drawn points on top of a `hist2d` of the
    same points (10 bins per axis).

    Arguments
    ---------
    xx, yy : grid coordinates along each axis, passed to `kale.sample.sample_grid` as `[xx, yy]`
    data : values on that grid, passed through unchanged

    Returns
    -------
    fig : the created figure
    vals : the samples drawn for the second (interpolated) panel only

    """
    fig, (ax_plain, ax_interp) = plt.subplots(figsize=[12, 4], ncols=2)
    grid = [xx, yy]

    # Non-interpolated first, interpolated second (same order of random draws as the panels)
    for interp, ax in zip((False, True), (ax_plain, ax_interp)):
        vals = kale.sample.sample_grid(grid, data, 2000, interpolate=interp)
        ax.set_title(f"{interp=}")
        ax.scatter(*vals, color='pink', s=10, zorder=20, alpha=0.3)
        ax.hist2d(*vals, bins=10)

    return fig, vals

# Test 1

In [None]:
xx = [0.0, 1.0, 2.0]
yy = [100.0, 200.0, 300.0]

edge = [
    [0.0, 10.0, 0.0],
    [10.0, 20.0, 10.0],
    [0.0, 10.0, 0.0]
]

fig, vals = draw_example(xx, yy, edge)

nbshow()

## Compare 1/4ths

In [None]:
# ---- Make sure quadrants are evenly sampled
# `vals` are the samples returned by `draw_example` (those of its last, interpolated panel);
# they are binned into the 2x2 cells bounded by the grid coordinates `xx` and `yy`.
quads, *_ = np.histogram2d(*vals, bins=(xx, yy))

tot = quads.sum()
# fraction in each quadrant
frac = quads / tot
print("frac = ", frac.flatten())
# expected poisson error
# (relative error 1/sqrt(N), for the N = tot/4 points expected in each quadrant)
err = 1.0 / np.sqrt(tot / 4.0)
err *= 3.0   # padding
# allowed [lower, upper] bounds on each quadrant's fraction, centered on 1/4
extr = 0.25 * np.array([1.0 - err, 1.0 + err])
print("exp = ", 1.0/4.0, " +- ", err, "==>", extr)
# Make sure each quadrant is roughly equal
if np.any((frac < extr[0]) | (frac > extr[1])):
    raise ValueError("Uneven distribution of points!")

## Compare 1/16ths

In [None]:
# Refine each axis by merging the grid coordinates with `kale.utils.midpoints` of them
# (presumably the points half-way between neighbors -- confirm against kalepy), then sorting.
# With 3 coordinates per axis this is intended to give 4 bins per axis, i.e. the 4x4 layout
# that the hard-coded indices [0, 3] and [1, 2] below rely on.
edges = [np.sort(np.concatenate([kale.utils.midpoints(ee), ee])) for ee in [xx, yy]]
# NOTE: despite the name, `octs` holds the counts for all cells of the refined 2D histogram
octs, *_ = np.histogram2d(*vals, bins=edges)

tot = octs.sum()
frac = octs / tot
# average count per cell and its sqrt; `extr` is only printed, it is not used in the checks
ave = tot / octs.size
err = np.sqrt(ave)
extr = [ave - err, ave + err]
extr = np.array(extr) / tot
print(f"{ave=}, {err=}, {extr=}")

print("frac=\n", frac)
# the four corner cells of the 4x4 layout
corners = np.ix_([0, 3], [0, 3])
corners = frac[corners].flatten()
print("corners=\n", corners)

# the inner 2x2 cells, which surround the central grid point where `edge` is largest
centers = np.ix_([1, 2], [1, 2])
centers = frac[centers].flatten()
print("centers=\n", centers)

# the remaining eight cells along the outer edges (not corners), called "diags" here
diags_h = np.ix_([1, 2], [0, 3])
diags_h = frac[diags_h]
diags_v = np.ix_([0, 3], [1, 2])
diags_v = frac[diags_v]
diags = np.concatenate([diags_h, diags_v]).flatten()
print("diags=\n", diags)

# Ordering checks compare every pair of cells: corners < diags < centers
# NOTE: `err` is reused below to hold the error message string
if not np.all(diags[:, np.newaxis] < centers[np.newaxis, :]):
    err = "diagonal values are not all smaller than center values!"
    raise ValueError(err)
    
if not np.all(corners[:, np.newaxis] < diags[np.newaxis, :]):
    err = "corner values are not all smaller than diagonal values!"
    raise ValueError(err)
    
# Absolute checks: hard-coded acceptance ranges for the fraction of samples in each cell type
if np.any((diags > 0.07) | (diags < 0.03)):
    raise ValueError("diagonals are out of bounds!")
    
if np.any((centers > 0.16) | (0.11 > centers)):
    raise ValueError("centers are out of bounds!")
    
if np.any((corners > 0.025) | (corners < 0.008)):
    raise ValueError("corners are out of bounds!")

# Test 2

In [None]:
xx = [0.0, 1.0, 2.0]
yy = [100.0, 200.0, 300.0]

edge = [
    [10.0, 0.0, 10.0],
    [ 0.0, 0.0,  0.0],
    [10.0, 0.0, 10.0]
]

fig, vals = draw_example(xx, yy, edge)
nbshow()

## Compare 1/4ths

In [None]:
# ---- Make sure quadrants are evenly sampled
quads, *_ = np.histogram2d(*vals, bins=(xx, yy))

tot = quads.sum()
# fraction in each quadrant
frac = quads / tot
print("frac = ", frac.flatten())
# expected poisson error
err = 1.0 / np.sqrt(tot / 4.0)
err *= 2.0   # padding
extr = 0.25 * np.array([1.0 - err, 1.0 + err])
print("exp = ", 1.0/4.0, " +- ", err, "==>", extr)
# Make sure each quadrant is roughly equal
if np.any((frac < extr[0]) | (frac > extr[1])):
    raise ValueError("Uneven distribution of points!")

## Compare 1/16ths

In [None]:
edges = [np.sort(np.concatenate([kale.utils.midpoints(ee), ee])) for ee in [xx, yy]]
octs, *_ = np.histogram2d(*vals, bins=edges)

tot = octs.sum()
frac = octs / tot
ave = tot / octs.size
err = np.sqrt(ave)
extr = [ave - err, ave + err]
extr = np.array(extr) / tot
print(f"{ave=}, {err=}, {extr=}")

print("frac=\n", frac)
corners = np.ix_([0, 3], [0, 3])
corners = frac[corners].flatten()
print("corners=\n", corners)

centers = np.ix_([1, 2], [1, 2])
centers = frac[centers].flatten()
print("centers=\n", centers)

diags_h = np.ix_([1, 2], [0, 3])
diags_h = frac[diags_h]
diags_v = np.ix_([0, 3], [1, 2])
diags_v = frac[diags_v]
diags = np.concatenate([diags_h, diags_v]).flatten()
print("diags=\n", diags)

if not np.all(corners[:, np.newaxis] > diags[np.newaxis, :]):
    err = "corner values are not all larger than diagonal values!"
    raise ValueError(err)
    
if not np.all(diags[:, np.newaxis] > centers[np.newaxis, :]):
    err = "diagonal values are not all larger than center values!"
    raise ValueError(err)
    
if np.any((diags > 0.06) | (diags < 0.03)):
    raise ValueError("diagonals are out of bounds!")
    
test = (corners > 0.16) | (corners < 0.11)
if np.any(test):
    raise ValueError("corners are out of bounds!")
    
if np.any((centers > 0.025) | (centers < 0.008)):
    raise ValueError("centers are out of bounds!")

# Test 3

In [None]:
xx = [0.0, 1.0, 2.0, 3.0]
yy = [100.0, 200.0, 300.0]

edge = [
    [10.0, 0.0, 10.0],
    [10.0, 0.0, 10.0],
    [10.0, 0.0, 10.0],
    [10.0, 0.0, 10.0],
]

fig, vals = draw_example(xx, yy, edge)
nbshow()

In [None]:
hist, *_ = np.histogram2d(*vals, bins=(3, 3))
tot = hist.sum()
frac = hist / tot
ave = tot / hist.size
err = np.sqrt(ave)
print(f"{ave=} {err=}")
print("frac=\n", frac)

rows = np.mean(frac, axis=0)
print("rows=", rows)

df = np.fabs(rows[-1] - rows[0])
print("df=", df)
if df > 0.015:
    raise ValueError("Edge rows do not match!")

if (rows[0] < 0.14) or (rows[0] > 0.16):
    raise ValueError("edge rows are out of bounds!")
    
if (rows[1] < 0.03) or (rows[1] > 0.05):
    raise ValueError("middle row is out of bounds!")

    
cols = np.mean(frac, axis=1)
print("cols=", cols)
df = np.fabs(cols[:, np.newaxis] - cols[np.newaxis, :]).max()
print("max diff=", df)
if df > 0.02:
    raise ValueError("columns are inconsistent!")

if (cols[0] < 0.095) or (cols[0] > 0.125):
    raise ValueError("columns are out of bounds!")

# Visual : test 3D function

In [None]:
def func(x, y, z):
    """Return `3 * exp(-q / 2)` where `q = (x + y - 2z)^2 + x^2 + y^2 + z^2`.

    Works elementwise, so the arguments may be scalars or mutually broadcastable arrays.
    """
    coupled = x + y - 2*z
    quad = coupled**2 + x**2 + y**2 + z**2
    return 3 * np.exp(-quad / 2.0)

xx = np.linspace(-4, 4, 200)
yy = np.linspace(-4, 4, 300)
zz = np.linspace(-1, 1, 5)

data = func(xx[:, np.newaxis, np.newaxis], yy[np.newaxis, :, np.newaxis], zz[np.newaxis, np.newaxis, :]) 
vals = kale.sample.sample_grid([xx, yy, zz], data, 2000)

# smap = zplot.smap(data, scale='linear')

# for ii, _ in enumerate(zz):
#     fig, ax = plt.subplots()
#     aa = data[:, :, ii]
#     plt.pcolormesh(xx, yy, aa.T, cmap=smap.cmap, norm=smap.norm)
    
# nbshow()
    

In [None]:
# smap = zplot.smap(data_edge, scale='lin')
# One color normalization built from `kale.utils.minmax(data)` (presumably the (min, max) of the
# whole array -- confirm against kalepy), so that every z-slice is drawn on the same scale
extr = kale.utils.minmax(data)
norm = mpl.colors.Normalize(*extr)

# add an offset so that the last 'bin' is also matched
# (`zz` is evenly spaced, so shifting each sample's z by half a step before the lookup
#  assigns it the index of its nearest grid value)
off = np.diff(zz)[0]*0.5
keys = np.searchsorted(zz, vals[2] + off) - 1
# One figure per z grid value: the `data` slice, overplotted with the samples assigned to it
for ii in range(zz.size):
    idx = (keys == ii)
        
    fig, ax = plt.subplots()
    plt.pcolormesh(xx, yy, data[:, :, ii].T, norm=norm, shading='auto')
    uu = vals[0][idx]
    vv = vals[1][idx]
    plt.scatter(uu, vv, facecolor='pink', edgecolor='0.25', alpha=0.7, s=20, zorder=100)

nbshow()

# Visual : test 3D random data

In [None]:
data = kale.utils._random_data_3d_01()
data = data[:, ::10]

In [None]:
kde = kale.KDE(data)
edges, values = kde.density()
print([len(ee) for ee in edges], np.shape(values))

In [None]:
nsamp = 1000
samples = kale.sample.sample_grid(edges, values, nsamp)

In [None]:
corner = kale.Corner(kde)
corner.plot_kde()
corner.plot_data(samples, color='r')

nbshow()

# Scalar Values

In [None]:
def func(x, y):
    """Return `3 * exp(-q / 2) + 1` where `q = (x + y)^2 + x^2 + y^2`.

    Works elementwise, so the arguments may be scalars or mutually broadcastable arrays.
    The constant offset keeps the result at or above 1 everywhere.
    """
    total = x + y
    quad = total**2 + x**2 + y**2
    return 3 * np.exp(-quad / 2.0) + 1.0

xx = np.linspace(-4, 4, 22)
yy = np.linspace(-4, 4, 23)

data = func(xx[:, np.newaxis], yy[np.newaxis, :]) 
vals = kale.sample.sample_grid([xx, yy], data, 20000)

In [None]:
fig, axes = plt.subplots(figsize=[15, 6], ncols=2)

ax = axes[0]
ax.pcolormesh(xx, yy, data.T, shading='auto')

ax = axes[1]
hist, *_ = np.histogram2d(*vals, bins=(xx, yy))
print(hist.shape)
ax.pcolormesh(xx, yy, hist.T, shading='auto')

plt.show()

In [None]:
# data = func(xx[:, np.newaxis], yy[np.newaxis, :])
rr = xx[:, np.newaxis]**2 + yy[np.newaxis, :]**2 + 1.0
# ww = 4 * rr
# ww = zmath.rescale(rr)
ww = np.sqrt(rr)
# vals, scal = kale.sample.sample_grid([xx, yy], ww, 200000, scalar=data/ww, interpolate=True)
# zmath.stats_str(scal)

vals, scal = kale.sample.sample_grid_proportional([xx, yy], data, ww, 200000)

In [None]:
fig, axes = plt.subplots(figsize=[15, 6], ncols=3)

ax = axes[0]
ax.pcolormesh(xx, yy, data.T, shading='auto')

ax = axes[1]
hist, *_ = np.histogram2d(*vals, bins=(xx, yy))
ax.pcolormesh(xx, yy, hist.T, shading='auto')

ax = axes[2]
hist, *_ = np.histogram2d(*vals, bins=(xx, yy), weights=scal)
ax.pcolormesh(xx, yy, hist.T, shading='auto')

plt.show()

# Outlier Sampling

## Example 1

In [None]:
NTOT = 1e4

xx = np.logspace(-4, 4, 31)
yy = np.logspace(-2, 2, 21)
edges = (xx, yy)

xg, yg = np.meshgrid(*edges, indexing='ij')


aa = (xg/10 + (yg/10)**2) / np.power(xg * yg, 0.2)
pdf = np.power(aa, 2.0) / (1 + aa)**4
pdf = np.power(pdf, 0.25)

pdf *= NTOT / pdf.sum()

pmf = kale.utils.trapz_dens_to_mass(pdf, edges)
print(f"pdf = {kale.utils.stats_str(pdf)}")
print(f"pmf = {kale.utils.stats_str(pmf)}")

fig, ax = plt.subplots(figsize=[12, 8])
ax.set(xscale='log', yscale='log')
ax.pcolormesh(xg, yg, pdf, shading='nearest')

plt.show()

In [None]:
thresh = 10.0
sampler = kale.Sample_Outliers(edges, pdf, threshold=thresh)

num, vv, ww = sampler.sample()
print(num, vv.size, ww.size, ww.sum(), pmf.sum(), ww.sum()/pmf.sum())

In [None]:
# ---- Compare the expected mass per bin (`pmf`) against one weighted outlier sample (`vv`, `ww`)
fig, axes = plt.subplots(figsize=[20, 7], ncols=3)
for ax in axes:
    ax.set(xscale='log', yscale='log')

# Left panel: expected mass per bin.
# The color scale is derived from `pmf` itself (previously it was taken from `zz`, a leftover
# variable from the 3D example that is unrelated to this data), and the same `norm` is reused
# for the sampled histogram so that the two panels are directly comparable.
ax = axes[0]
extr = np.fabs(pmf).max() * 1.1
norm = mpl.colors.Normalize(0.0, extr)

pcm = ax.pcolormesh(xg, yg, pmf, norm=norm)
plt.colorbar(pcm, ax=ax)

# Middle panel: weighted histogram of the samples on the same bin edges, drawn at bin centers.
# NOTE: `xmid`, `ymid` are kept as notebook-level variables; later cells may rely on them.
ax = axes[1]
hist, *_ = np.histogram2d(*vv, weights=ww, bins=(xx, yy))
xmid = kale.utils.midpoints(xx, log=True)
ymid = kale.utils.midpoints(yy, log=True)
pcm = ax.pcolormesh(xmid, ymid, hist.T, shading='nearest', norm=norm)
plt.colorbar(pcm, ax=ax)

# Right panel: fractional error of the sampled mass, on a symmetric diverging color scale
ax = axes[2]
err = (hist - pmf) / pmf
extr = np.fabs(err).max()
norm = mpl.colors.Normalize(-extr, extr)
# norm = mpl.colors.Normalize(-0.1, 0.1)

pcm = ax.pcolormesh(xg, yg, err, cmap='bwr', norm=norm)
plt.colorbar(pcm, ax=ax)

# `nbshow` (imported at the top of the file) shows figures in a notebook, but closes them when
# the notebook is run as a script
nbshow()


# Bins with mass above the threshold must be reproduced to within 1e-10 fractional error
idx = (pmf > thresh)
print(f"fraction above threshold: {np.count_nonzero(idx)/idx.size:.4e}")
check = err[idx]
print(f"errors above thresh     : {kale.utils.stats_str(check, format=':.4e')}")
if np.any(np.fabs(check) > 1e-10):
    raise ValueError("Error too large for interior region!")


# Bins below the threshold are only required to agree with `pmf` to within 4x a Poisson-like
# error, sqrt(max(pmf, 1))
idx = (pmf < thresh)
print(f"fraction below threshold: {np.count_nonzero(idx)/idx.size:.4e}")
check = np.fabs(hist - pmf)[idx]
allow = 4 * np.sqrt(np.maximum(pmf[idx], 1.0))
print(f"errors below thresh     : {kale.utils.stats_str(check, format=':.4e')}")
print(f" compared to 4x Poisson : {kale.utils.stats_str(check/allow, format=':.4e')}")
if np.any(check > allow):
    raise ValueError("Error too large for outlier region!")


## Example 2

In [None]:
# Target total mass; `pdf` is rescaled below so that `pmf.sum()` equals this value
NTOT = 3e5

xx = np.logspace(-4, 4, 31)
yy = np.logspace(-2, 2, 21)

xg, yg = np.meshgrid(xx, yy, indexing='ij')

# Power-law in both coordinates, evaluated on the 2D grid.  The meshgrid `yg` is used for the
# y-dependence (rather than the 1D `yy`, which only gave the same result through broadcasting
# along the trailing axis).
pdf = np.power(xg, -5.5) * np.power(yg, -1.35)
pdf = np.power(pdf, 0.25)

# Masses are computed with the edges expressed in log10 of the coordinates
ledges = [np.log10(xx), np.log10(yy)]
# First pass only measures the current total mass, used to normalize `pdf` to `NTOT` ...
pmf = kale.utils.trapz_dens_to_mass(pdf, ledges)
pdf *= NTOT / pmf.sum()
# ... second pass gives the per-bin masses of the normalized density
pmf = kale.utils.trapz_dens_to_mass(pdf, ledges)

print(f"pdf = {kale.utils.stats_str(pdf)}")
print(f"pmf = {kale.utils.stats_str(pmf)}")

fig, ax = plt.subplots(figsize=[16, 10])
ax.set(xscale='log', yscale='log')
pcm = ax.pcolormesh(xg, yg, np.log10(pdf), shading='nearest')
plt.colorbar(pcm, ax=ax)

# `nbshow` (imported at the top of the file) shows figures in a notebook, but closes them when
# the notebook is run as a script
nbshow()

In [None]:
def test_loc(test, loc):
    """Check one stored first-axis value of `test._coms_ins` against a direct calculation.

    The reference value is the density-weighted mean of the two first-axis edges bounding the
    bin, using the densities at `loc` and at its neighbor one step up along the first axis:
    ``(x1*y1 + x2*y2) / (y1 + y2)``.  Intermediate values are printed for inspection.

    Arguments
    ---------
    test : object providing `_coms_ins`, `_edges` and `_dens`
        (here a `kale.Sample_Outliers` instance; `_coms_ins[0]` is presumably the per-bin
        center-of-mass along the first axis -- confirm against kalepy)
    loc : (2,) sequence of int
        Bin index `(i, j)` to check.

    Raises
    ------
    RuntimeError : if the stored value is not close (`np.isclose`) to the reference value.

    """
    # Force a tuple so that indexing selects a single element even if a list is passed
    loc = tuple(loc)
    x = test._coms_ins[0]
    loc_up = (loc[0] + 1, loc[1])
    print(x[loc])
    x1 = test._edges[0][loc[0]]
    x2 = test._edges[0][loc_up[0]]
    print(x1, x2)
    y1 = test._dens[loc]
    y2 = test._dens[loc_up]
    print(y1, y2)
    check = (x1*y1 + x2*y2) / (y1 + y2)
    print(check)

    if not np.isclose(x[loc], check):
        raise RuntimeError(
            f"center-of-mass mismatch at {loc}: stored {x[loc]}, expected {check}"
        )


# This is a helper that needs arguments, not a test case.  Name-based test collectors
# (nose, pytest) would otherwise pick it up by its `test_` prefix and fail to call it.
test_loc.__test__ = False


test_samp = kale.Sample_Outliers(ledges, pdf, threshold=0.0)

N = 10
xi = np.random.randint(0, test_samp._shape_bins[0], N)
yi = np.random.randint(0, test_samp._shape_bins[1], N)
for ii, jj in zip(xi, yi):
    print(f"\n{ii}, {jj}")
    test_loc(test_samp, (ii, jj))
        

In [None]:
thresh = 10.0
sampler = kale.Sample_Outliers(ledges, pdf, threshold=thresh)

num, vv, ww = sampler.sample()
vv = np.power(10.0, vv)
print(num, vv.size, ww.size, ww.sum(), pmf.sum(), ww.sum()/pmf.sum())

In [None]:
REALS = 100
sh = (xx.size - 1, yy.size - 1, REALS)
hist = np.zeros(sh)
for rr in range(REALS):
    nn, vv, ww = sampler.sample()
    vv = np.power(10.0, vv)
    hist[..., rr], *_ = np.histogram2d(*vv, weights=ww, bins=(xx, yy))
    


In [None]:
def plot_sample_comparison(xx, yy, pdf, pmf, hist, log_flag=True, diff_extr=None):
    """Plot a target density, a sampled histogram, and the fractional error between them.

    Three log-log panels are drawn: `pdf` at the grid points, `hist` at the bin centers, and
    ``(hist - pmf) / pmf`` per bin.  All plotting coordinates are derived from `xx` and `yy`,
    so the function does not depend on variables left over from other notebook cells.

    Arguments
    ---------
    xx, yy : 1D arrays
        Grid edges along each axis.
    pdf : 2D array
        Density at the grid points; must match the shape of `np.meshgrid(xx, yy, indexing='ij')`.
    pmf : 2D array
        Expected mass per bin (one fewer element than the edges along each axis).
    hist : 2D array
        Sampled mass per bin, same shape as `pmf`.
    log_flag : bool
        If True, the first two panels show log10 of the values, with non-positive entries
        drawn as zero.
    diff_extr : None or scalar
        Symmetric color limit for the error panel; if None, the largest absolute error is used.

    Returns
    -------
    err : 2D array
        Fractional error ``(hist - pmf) / pmf``.

    """
    fig, axes = plt.subplots(figsize=[20, 7], ncols=3)
    for ax in axes:
        ax.set(xscale='log', yscale='log')

    # Grid-point and bin-center coordinates for the panels below
    xg, yg = np.meshgrid(xx, yy, indexing='ij')
    xmid = kale.utils.midpoints(xx, log=True)
    ymid = kale.utils.midpoints(yy, log=True)

    def get_temp(vals):
        # Optionally convert to log10, leaving zero wherever the value is not positive
        if not log_flag:
            return vals
        temp = np.zeros_like(vals)
        pos = (vals > 0.0)
        temp[pos] = np.log10(vals[pos])
        return temp

    # The color scale is set by `pdf` and shared with the histogram panel
    temp = get_temp(pdf)
    extr = np.fabs(temp).max() * 1.1
    norm = mpl.colors.Normalize(0.0, extr)

    ax = axes[0]
    pcm = ax.pcolormesh(xg, yg, temp, shading='nearest', norm=norm)
    plt.colorbar(pcm, ax=ax)

    ax = axes[1]
    temp = get_temp(hist.T)
    pcm = ax.pcolormesh(xmid, ymid, temp, shading='nearest', norm=norm)
    plt.colorbar(pcm, ax=ax)

    ax = axes[2]
    err = (hist - pmf) / pmf
    extr = np.fabs(err).max()
    if diff_extr is not None:
        extr = diff_extr

    norm = mpl.colors.Normalize(-extr, extr)

    pcm = ax.pcolormesh(xg, yg, err, cmap='bwr', norm=norm)
    plt.colorbar(pcm, ax=ax)
    return err
    
    
# NOTE: the first assignment is immediately overridden; switch between the two by hand.
# `None` lets the error panel scale to the largest error, a number fixes its color limit.
diff_extr = None
diff_extr = 10

# Compare against the histogram averaged over all realizations (last axis of `hist`)
err = plot_sample_comparison(xx, yy, pdf, pmf, hist.mean(axis=-1), log_flag=True, diff_extr=diff_extr)
plt.show()


# Bins with mass above the threshold must be reproduced to within 1e-10 fractional error
idx = (pmf > thresh)
print(f"fraction above threshold: {np.count_nonzero(idx)/idx.size:.4e}")
check = err[idx]
print(f"errors above thresh     : {kale.utils.stats_str(check, format=':.4e')}")
if np.any(np.fabs(check) > 1e-10):
    raise ValueError(f"Error too large for interior region!")



# Bins below the threshold: the realization-averaged mass must agree with `pmf` to within
# 4x a Poisson-like error, sqrt(max(pmf, 1))
idx = (pmf < thresh)
print(f"fraction below threshold: {np.count_nonzero(idx)/idx.size:.4e}")
check = np.fabs(hist.mean(axis=-1) - pmf)[idx]
allow = 4 * np.sqrt(np.maximum(pmf[idx], 1.0))
print(f"errors below thresh     : {kale.utils.stats_str(check, format=':.4e')}")
print(f" compared to 4x Poisson : {kale.utils.stats_str(check/allow, format=':.4e')}")
if np.any(check > allow):
    raise ValueError(f"Error too large for outlier region!")


## Example 3

In [None]:
NTOT = 1e5

xx = np.logspace(-4, 4, 31)
yy = np.logspace(-2, 2, 21)
zz = np.linspace(0.0, 4.0, 11)

ledges = [np.log10(xx), np.log10(yy), zz]
xg, yg, zg = np.meshgrid(xx, yy, zz, indexing='ij')

aa = (xg/10 + (yg/10)**2) / np.power(xg * yg, 0.2)
pdf = np.power(aa, 2.0) / (1 + aa)**4
pdf = np.power(pdf, 0.25)

pmf = kale.utils.trapz_dens_to_mass(pdf, ledges)
pdf *= NTOT / pmf.sum()
pmf = kale.utils.trapz_dens_to_mass(pdf, ledges)

print(f"pdf = {kale.utils.stats_str(pdf)}")
print(f"pmf = {kale.utils.stats_str(pmf)}")


fig, axes = plt.subplots(figsize=[20, 7], ncols=3)
    
grids = [xg, yg, zg]
cut = tuple([slice(None), slice(None), 0])
for ii, ax in enumerate(axes):
    jj = (ii + 1) % 3
    kk = (ii + 2) % 3
    cc = [slice(None) for ii in range(3)]
    cc[ii] = 0
    cc = tuple(cc)
    aa = grids[jj]
    bb = grids[kk]
    
    xsc = 'linear' if jj == 2 else 'log'
    ysc = 'linear' if kk == 2 else 'log'
    ax.set(xscale=xsc, xlim=kale.utils.minmax(aa), yscale=ysc, ylim=kale.utils.minmax(bb))

    ax.pcolormesh(aa[cc], bb[cc], pmf[cc])

plt.show()

In [None]:
thresh = 20.0
sampler = kale.Sample_Outliers(ledges, pdf, threshold=thresh)

num, vv, ww = sampler.sample()
vv[0] = np.power(10.0, vv[0])
vv[1] = np.power(10.0, vv[1])
print(num, vv.size, ww.size, ww.sum(), pmf.sum(), ww.sum()/pmf.sum())

In [None]:
fig, axes = plt.subplots(figsize=[20, 7], ncols=3)
for ax in axes:
    ax.set(xscale='log', yscale='log')
    
dd = np.sum(pmf, axis=-1)
extr = np.fabs(dd).max() * 1.1
norm = mpl.colors.Normalize(0.0, extr)

ax = axes[0]
cut = tuple([slice(None), slice(None), 0])
pcm = ax.pcolormesh(xg[cut], yg[cut], dd, norm=norm)
plt.colorbar(pcm, ax=ax)

ax = axes[1]
hist, *_ = np.histogram2d(vv[0], vv[1], weights=ww, bins=(xx, yy))
# xmid = kale.utils.midpoints(xx, log=True)
# ymid = kale.utils.midpoints(yy, log=True)
pcm = ax.pcolormesh(xx, yy, hist.T, norm=norm)
plt.colorbar(pcm, ax=ax)

ax = axes[2]
# dd = kale.utils.midpoints(dd, log=True, axis=None)
err = (hist - dd) / dd
extr = np.fabs(err).max()
# extr = 0.3
# extr = 1.0
norm = mpl.colors.Normalize(-extr, extr)
# norm = mpl.colors.Normalize(-0.1, 0.1)

pcm = ax.pcolormesh(xg[cut], yg[cut], err, cmap='bwr', norm=norm)
plt.colorbar(pcm, ax=ax)

plt.show()


idx = np.all(pmf > thresh, axis=-1)
print(f"fraction above threshold: {np.count_nonzero(idx)/idx.size:.4e}")
check = err[idx]
print(f"errors above thresh     : {kale.utils.stats_str(check, format=':.4e')}")
if np.any(np.fabs(check) > 1e-10):
    raise ValueError(f"Error too large for interior region!")



idx = ~idx
# idx = slice(None)
# print(f"fraction below threshold: {np.count_nonzero(idx)/idx.size:.4e}")
check = np.fabs(hist - dd)[idx]
# allow = 3 * np.sqrt(np.maximum(dd[idx], 1.0))
allow = 2 * np.fabs(dd[idx] - sp.stats.poisson.ppf(1.0 - 1.0/check.size, dd[idx]))
print(f"errors below thresh     : {kale.utils.stats_str(check, format=':.4e')}")
print(f"    compared to allowed : {kale.utils.stats_str(check/allow, format=':.4e')}")
if np.any(check > allow):
    
    bad = np.where(check > allow)
    print(f"{bad=}, {hist[idx][bad]=}, {dd[idx][bad]=}")
    
    raise ValueError(f"Error too large for outlier region!")


## Example 4

In [None]:
xx = np.logspace(-4, 4, 31)
yy = np.logspace(-2, 2, 21)
zz = np.linspace(0.0, 4.0, 11)
# zz = np.logspace(-3, 3, 11)

ledges = [np.log10(xx), np.log10(yy), zz]

xg, yg, zg = np.meshgrid(xx, yy, zz, indexing='ij')

pdf = 2.0 * np.ones_like(xg)
idx = ((1e-2 < xg) & (xg < 1e2)) & ((1e-1 < yg) & (yg < 1e1))

pdf[idx] = 30.0

pmf = kale.utils.trapz_dens_to_mass(pdf, ledges)
mid = tuple([ss.size//2 for ss in [xx, yy, zz]])
print(pdf[mid], pmf[mid])
pdf *= 30.0 / pmf[mid]
pmf = kale.utils.trapz_dens_to_mass(pdf, ledges)

print(f"pdf = {kale.utils.stats_str(pdf)}")
print(f"pmf = {kale.utils.stats_str(pmf)}")


fig, axes = plt.subplots(figsize=[20, 7], ncols=3)
# for ax in axes:
#     ax.set(xscale='log', yscale='log')
    
grids = [xg, yg, zg]
cut = tuple([slice(None), slice(None), 0])
for ii, ax in enumerate(axes):
    jj = (ii + 1) % 3
    kk = (ii + 2) % 3
    cc = [slice(None) for ii in range(3)]
    cc[ii] = ledges[ii].size//2
    cc = tuple(cc)
    aa = grids[jj]
    bb = grids[kk]

    xsc = 'linear' if jj == 2 else 'log'
    ysc = 'linear' if kk == 2 else 'log'
    
    pcm = ax.pcolormesh(aa[cc], bb[cc], pmf[cc])
    ax.set(xscale=xsc, xlim=kale.utils.minmax(aa), yscale=ysc, ylim=kale.utils.minmax(bb))
    plt.colorbar(pcm, ax=ax)

# ax = axes[0]
# ax.pcolormesh(xg[cut], yg[cut], dist[cut], shading='nearest')

plt.show()


In [None]:
thresh = 10.0
sampler = kale.Sample_Outliers(ledges, pdf, threshold=thresh)

num, vv, ww = sampler.sample()
vv[0] = np.power(10.0, vv[0])
vv[1] = np.power(10.0, vv[1])
print(num, vv.shape, ww.size, ww.sum(), pmf.sum(), ww.sum()/pmf.sum())

In [None]:
fig, axes = plt.subplots(figsize=[20, 7], ncols=3)
for ax in axes:
    ax.set(xscale='log', yscale='log')
    

ax = axes[0]
cut = tuple([slice(None), slice(None), 0])
dd = np.sum(pmf, axis=-1)

extr = np.fabs(dd).max() * 1.1
norm = mpl.colors.Normalize(0.0, extr)

pcm = ax.pcolormesh(xg[cut], yg[cut], dd, norm=norm)
plt.colorbar(pcm, ax=ax)

ax = axes[1]
hist, *_ = np.histogram2d(vv[0], vv[1], weights=ww, bins=(xx, yy))
xmid = kale.utils.midpoints(xx, log=True)
ymid = kale.utils.midpoints(yy, log=True)
pcm = ax.pcolormesh(xmid, ymid, hist.T, shading='nearest', norm=norm)
plt.colorbar(pcm, ax=ax)

ax = axes[2]
# dd = kale.utils.midpoints(dd, log=False, axis=None)
# dd /= (zz.size / (zz.size - 1))
err = (hist - dd) / dd
extr = np.fabs(err).max()
# extr = 0.01
# extr = 1.0
norm = mpl.colors.Normalize(-extr, extr)
# norm = mpl.colors.Normalize(-0.1, 0.1)

pcm = ax.pcolormesh(xg[cut], yg[cut], err, cmap='bwr', norm=norm)
plt.colorbar(pcm, ax=ax)

plt.show()


idx = np.all(pmf > thresh, axis=-1)
print(f"fraction above threshold: {np.count_nonzero(idx)/idx.size:.4e}")
check = err[idx]
print(f"errors above thresh     : {kale.utils.stats_str(check, format=':.4e')}")
if np.any(np.fabs(check) > 1e-10):
    raise ValueError(f"Error too large for interior region!")



idx = ~idx
print(f"fraction below threshold: {np.count_nonzero(idx)/idx.size:.4e}")
check = np.fabs(hist - dd)[idx]
# allow = 3 * np.sqrt(np.maximum(dd[idx], 1.0))
allow = 2 * np.fabs(dd[idx] - sp.stats.poisson.ppf(1.0 - 1.0/check.size, dd[idx]))
print(f"errors below thresh     : {kale.utils.stats_str(check, format=':.4e')}")
print(f"    compared to allowed : {kale.utils.stats_str(check/allow, format=':.4e')}")
if np.any(check > allow):
    raise ValueError(f"Error too large for outlier region!")
