In [None]:
# %load ./init.ipy
%reload_ext autoreload
%autoreload 2

# %load init.ipy
import os, sys, logging, datetime, warnings, shutil
from importlib import reload

import numpy as np
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt
from nose import tools

import kalepy as kale
import kalepy.utils
import kalepy.plot
import kalepy.sample

# The `nbshow` command runs `plt.show()` in interactive jupyter notebooks, but closes
#   figures when run from the command-line (notebooks are converted to scripts and run as tests)
from kalepy.plot import nbshow

import warnings
# warnings.simplefilter("error")   # WARNING: this is breaking jupyter at the moment (2021-02-14)

In [None]:
def func(x, y):
    """Smooth 2D test density: a correlated Gaussian-shaped bump of height 3 on a floor of 1.

    Returns ``3 * exp(-q / 2) + 1`` with ``q = (x + y)**2 + x**2 + y**2``; `x` and `y`
    broadcast against each other, so values lie in (1, 4].
    """
    quad_form = (x + y)**2 + x**2 + y**2
    bump = np.exp(-quad_form / 2.0)
    return 3 * bump + 1.0

xx = np.linspace(-4, 4, 22)
yy = np.linspace(-4, 4, 23)

# Evaluate `func` on the (22, 23) grid by broadcasting: axis 0 <-> xx, axis 1 <-> yy.
data = func(xx[:, np.newaxis], yy[np.newaxis, :])

rr = xx[:, np.newaxis]**2 + yy[np.newaxis, :]**2 + 1.0
ww = np.sqrt(rr)

# Exaggerate the dynamic range of `data`: values in (1, 4] map to (0, 999].
extreme = 10**(data - 1) - 1

edges = [xx, yy]
data_edge = extreme
scalar_edge = ww

# Pass `edges` as a list with one coordinate array per dimension.  Concatenating them would
# give a single non-monotonic 1D array of len(xx) + len(yy) points, which does not describe
# the 2D grid of `data_edge`.
# NOTE(review): assumes `sample_grid` returns (samples, scalars) when `scalar` is given,
# like the `vals, scal` pair returned by `sample_grid_proportional` -- confirm.
# NOTE(review): this cell duplicates the "Scalar Values" cells at the end of the notebook
# and runs before `np.random.seed(...)`, so its draws are not reproducible.
data_out, scalar_out = kale.sample.sample_grid(edges, data_edge, 1000, scalar=scalar_edge)

In [None]:
# Seed numpy's global RNG so the statistical checks below are reproducible
# (assuming kalepy draws from the global numpy RNG).
np.random.seed(seed=1234567)

In [None]:
def draw_example(xx, yy, data):
    """Sample `data` on the grid ``(xx, yy)`` with and without interpolation, and plot both.

    Each of the two panels shows 2000 samples from ``kale.sample.sample_grid`` as a scatter
    over a 10x10 ``hist2d``.  The left panel uses ``interpolate=False`` and the right panel
    uses ``interpolate=True``.

    Parameters
    ----------
    xx, yy : array_like
        Grid coordinates along the first and second dimension.
    data : array_like
        Grid values, passed to ``sample_grid`` as the density; expected shape is
        ``(len(xx), len(yy))``.

    Returns
    -------
    fig : matplotlib Figure
        Figure holding both panels.
    vals : array
        Samples from the *last* (``interpolate=True``) panel only.
    """
    fig, axes = plt.subplots(figsize=[12, 4], ncols=2)

    for ax, interp in zip(axes, (False, True)):
        vals = kale.sample.sample_grid([xx, yy], data, 2000, interpolate=interp)
        ax.set_title(f"{interp=}")
        ax.scatter(*vals, color='pink', s=10, zorder=20, alpha=0.3)
        ax.hist2d(*vals, bins=10)

    return fig, vals

# Test 1: density peaked at the grid center

In [None]:
# Test 1 grid: the density peaks at the central grid point and vanishes at the four corners.
xx = [0.0, 1.0, 2.0]
yy = [100.0, 200.0, 300.0]
edge = [
    [ 0.0, 10.0,  0.0],
    [10.0, 20.0, 10.0],
    [ 0.0, 10.0,  0.0],
]

fig, vals = draw_example(xx, yy, edge)
nbshow()

## Compare quarters (2x2 grid cells)

In [None]:
# ---- Make sure quadrants are evenly sampled
quads, *_ = np.histogram2d(*vals, bins=(xx, yy))

tot = quads.sum()
# fraction in each quadrant
frac = quads / tot
print("frac = ", frac.flatten())
# expected poisson error
err = 1.0 / np.sqrt(tot / 4.0)
err *= 3.0   # padding
extr = 0.25 * np.array([1.0 - err, 1.0 + err])
print("exp = ", 1.0/4.0, " +- ", err, "==>", extr)
# Make sure each quadrant is roughly equal
if np.any((frac < extr[0]) | (frac > extr[1])):
    raise ValueError("Uneven distribution of points!")

## Compare sixteenths (4x4 sub-cells)

In [None]:
# Refine the grid by inserting midpoints, so each grid cell is split in two along each
# axis: 4x4 sub-cells.
edges = [np.sort(np.concatenate([kale.utils.midpoints(ee), ee])) for ee in [xx, yy]]
octs, *_ = np.histogram2d(*vals, bins=edges)

tot = octs.sum()
frac = octs / tot
# Uniform-expectation reference; printed for information only (the checks below use fixed bounds).
ave = tot / octs.size
err = np.sqrt(ave)
extr = np.array([ave - err, ave + err]) / tot
print(f"{ave=}, {err=}, {extr=}")

print("frac=\n", frac)
# Corner sub-cells: adjacent to the zero-density grid corners.
corners = frac[np.ix_([0, 3], [0, 3])].flatten()
print("corners=\n", corners)

# Central sub-cells: adjacent to the peak at the middle grid point.
centers = frac[np.ix_([1, 2], [1, 2])].flatten()
print("centers=\n", centers)

# Remaining 8 sub-cells along the outer edges, between corners and centers.
diags_h = frac[np.ix_([1, 2], [0, 3])]
diags_v = frac[np.ix_([0, 3], [1, 2])]
diags = np.concatenate([diags_h, diags_v]).flatten()
print("diags=\n", diags)

# Expected ordering: corners < diags < centers.  Comparing extremes is equivalent to
# comparing every pair (and also fails if any value is NaN).
if not (diags.max() < centers.min()):
    err = "diagonal values are not all smaller than center values!"
    raise ValueError(err)

if not (corners.max() < diags.min()):
    err = "corner values are not all smaller than diagonal values!"
    raise ValueError(err)

if np.logical_or(diags > 0.07, diags < 0.03).any():
    raise ValueError("diagonals are out of bounds!")

if np.logical_or(centers > 0.16, centers < 0.11).any():
    raise ValueError("centers are out of bounds!")

if np.logical_or(corners > 0.025, corners < 0.008).any():
    raise ValueError("corners are out of bounds!")

# Test 2: density peaked at the grid corners

In [None]:
# Test 2 grid: the density is 10 at the four corners and 0 at every other grid point.
xx = [0.0, 1.0, 2.0]
yy = [100.0, 200.0, 300.0]
edge = [
    [10.0, 0.0, 10.0],
    [ 0.0, 0.0,  0.0],
    [10.0, 0.0, 10.0],
]

fig, vals = draw_example(xx, yy, edge)
nbshow()

## Compare quarters (2x2 grid cells)

In [None]:
# ---- Make sure quadrants are evenly sampled
quads, *_ = np.histogram2d(*vals, bins=(xx, yy))

tot = quads.sum()
# fraction in each quadrant
frac = quads / tot
print("frac = ", frac.flatten())
# expected poisson error
err = 1.0 / np.sqrt(tot / 4.0)
err *= 2.0   # padding
extr = 0.25 * np.array([1.0 - err, 1.0 + err])
print("exp = ", 1.0/4.0, " +- ", err, "==>", extr)
# Make sure each quadrant is roughly equal
if np.any((frac < extr[0]) | (frac > extr[1])):
    raise ValueError("Uneven distribution of points!")

## Compare sixteenths (4x4 sub-cells)

In [None]:
# Refine the grid by inserting midpoints, so each grid cell is split in two along each
# axis: 4x4 sub-cells.
edges = [np.sort(np.concatenate([kale.utils.midpoints(ee), ee])) for ee in [xx, yy]]
octs, *_ = np.histogram2d(*vals, bins=edges)

tot = octs.sum()
frac = octs / tot
# Uniform-expectation reference; printed for information only (the checks below use fixed bounds).
ave = tot / octs.size
err = np.sqrt(ave)
extr = np.array([ave - err, ave + err]) / tot
print(f"{ave=}, {err=}, {extr=}")

print("frac=\n", frac)
# Corner sub-cells: adjacent to the high-density grid corners.
corners = frac[np.ix_([0, 3], [0, 3])].flatten()
print("corners=\n", corners)

# Central sub-cells: adjacent to the zero-density middle grid point.
centers = frac[np.ix_([1, 2], [1, 2])].flatten()
print("centers=\n", centers)

# Remaining 8 sub-cells along the outer edges, between corners and centers.
diags_h = frac[np.ix_([1, 2], [0, 3])]
diags_v = frac[np.ix_([0, 3], [1, 2])]
diags = np.concatenate([diags_h, diags_v]).flatten()
print("diags=\n", diags)

# Expected ordering: corners > diags > centers.  Comparing extremes is equivalent to
# comparing every pair (and also fails if any value is NaN).
if not (corners.min() > diags.max()):
    err = "corner values are not all larger than diagonal values!"
    raise ValueError(err)

if not (diags.min() > centers.max()):
    err = "diagonal values are not all larger than center values!"
    raise ValueError(err)

if np.logical_or(diags > 0.06, diags < 0.03).any():
    raise ValueError("diagonals are out of bounds!")

test = np.logical_or(corners > 0.16, corners < 0.11)
if test.any():
    raise ValueError("corners are out of bounds!")

if np.logical_or(centers > 0.025, centers < 0.008).any():
    raise ValueError("centers are out of bounds!")

# Test 3: density peaked along both y-edges, uniform in x

In [None]:
# Test 3 grid (4x3): for every x, the density is 10 at both y-edges and 0 on the middle y line.
xx = [0.0, 1.0, 2.0, 3.0]
yy = [100.0, 200.0, 300.0]
edge = [[10.0, 0.0, 10.0] for _ in xx]

fig, vals = draw_example(xx, yy, edge)
nbshow()

In [None]:
# Coarse 3x3 histogram over the sampled range (axis 0 <-> x, axis 1 <-> y).
hist, *_ = np.histogram2d(*vals, bins=(3, 3))
tot = hist.sum()
frac = hist / tot
ave = tot / hist.size
err = np.sqrt(ave)
print(f"{ave=} {err=}")
print("frac=\n", frac)

# Average over x -> profile along y.  Both y-edges carry the same density and the middle
# carries none, so the outer entries should agree and exceed the middle one.
rows = frac.mean(axis=0)
print("rows=", rows)

df = np.fabs(rows[-1] - rows[0])
print("df=", df)
if df > 0.015:
    raise ValueError("Edge rows do not match!")

if rows[0] < 0.14 or rows[0] > 0.16:
    raise ValueError("edge rows are out of bounds!")

if rows[1] < 0.03 or rows[1] > 0.05:
    raise ValueError("middle row is out of bounds!")

# Average over y -> profile along x, which should be flat because every x has the same density.
cols = frac.mean(axis=1)
print("cols=", cols)
# The largest pairwise |difference| between entries is simply max - min.
df = cols.max() - cols.min()
print("max diff=", df)
if df > 0.02:
    raise ValueError("columns are inconsistent!")

if cols[0] < 0.095 or cols[0] > 0.125:
    raise ValueError("columns are out of bounds!")

# Visual: test 3D function

In [None]:
def func(x, y, z):
    """3D test density ``3 * exp(-q / 2)`` with ``q = (x + y - 2z)**2 + x**2 + y**2 + z**2``.

    The arguments broadcast against each other.  Note: this shadows the 2D `func` above.
    """
    quad_form = (x + y - 2*z)**2 + x**2 + y**2 + z**2
    return 3 * np.exp(-quad_form / 2.0)

xx = np.linspace(-4, 4, 200)
yy = np.linspace(-4, 4, 300)
zz = np.linspace(-1, 1, 5)

# Evaluate `func` on the full (200, 300, 5) grid.  A sparse 'ij' meshgrid yields arrays of
# shape (200, 1, 1), (1, 300, 1) and (1, 1, 5) that broadcast against each other.
data = func(*np.meshgrid(xx, yy, zz, indexing='ij', sparse=True))
vals = kale.sample.sample_grid([xx, yy, zz], data, 2000)

# smap = zplot.smap(data, scale='linear')

# for ii, _ in enumerate(zz):
#     fig, ax = plt.subplots()
#     aa = data[:, :, ii]
#     plt.pcolormesh(xx, yy, aa.T, cmap=smap.cmap, norm=smap.norm)
    
# nbshow()
    

In [None]:
# One shared color normalization so all z-slices are directly comparable.
extr = kale.utils.minmax(data)
norm = mpl.colors.Normalize(*extr)

# Assign each sample to its nearest z grid point.  Shifting by half a grid spacing before
# `searchsorted` turns the grid points into bin centers (assumes `zz` is uniformly spaced).
off = np.diff(zz)[0]*0.5
keys = np.searchsorted(zz, vals[2] + off) - 1
for ii in range(zz.size):
    idx = (keys == ii)

    # Explicit-axes interface; the title identifies which z-slice is shown.
    fig, ax = plt.subplots()
    ax.set_title(f"z = {zz[ii]:.2f}")
    ax.pcolormesh(xx, yy, data[:, :, ii].T, norm=norm, shading='auto')
    uu = vals[0][idx]
    vv = vals[1][idx]
    ax.scatter(uu, vv, facecolor='pink', edgecolor='0.25', alpha=0.7, s=20, zorder=100)

nbshow()

# Visual: test 3D random data

In [None]:
# Example 3D dataset from a private kalepy helper, thinned to every 10th point along axis 1.
data = kale.utils._random_data_3d_01()[:, ::10]

In [None]:
kde = kale.KDE(data)
# Evaluate the KDE on its grid, then report the number of grid points along each
# dimension together with the shape of the density array.
edges, values = kde.density()
print(list(map(len, edges)), np.shape(values))

In [None]:
# Draw samples from the gridded KDE density.
nsamp = 1_000
samples = kale.sample.sample_grid(edges, values, nsamp)

In [None]:
# Corner plot of the KDE, with the grid samples overplotted in red for comparison.
corner = kale.Corner(kde)
corner.plot_kde()
corner.plot_data(samples, color='r')
nbshow()

# Scalar Values

In [None]:
def func(x, y):
    """2D test density ``3 * exp(-q / 2) + 1`` with ``q = (x + y)**2 + x**2 + y**2``.

    Re-defined here because the 3D `func` earlier in the notebook shadows the 2D version.
    Inputs broadcast; values lie in (1, 4].
    """
    quad_form = (x + y)**2 + x**2 + y**2
    return 3 * np.exp(-quad_form / 2.0) + 1.0

xx = np.linspace(-4, 4, 22)
yy = np.linspace(-4, 4, 23)

# Evaluate `func` on the (22, 23) grid by broadcasting: axis 0 <-> xx, axis 1 <-> yy.
data = func(xx[:, np.newaxis], yy[np.newaxis, :])
# Without `scalar`, `sample_grid` returns only the (D, N) sample array.  Unpacking it as
# `vals, scal` would bind `vals` to the x-samples alone and break `np.histogram2d(*vals, ...)`.
vals = kale.sample.sample_grid([xx, yy], data, 20000)

In [None]:
# Compare the input density (left) with a histogram of the drawn samples (right).
fig, axes = plt.subplots(figsize=[15, 6], ncols=2)

ax = axes[0]
ax.pcolormesh(xx, yy, data.T, shading='auto')
ax.set_title("input density")

ax = axes[1]
hist, *_ = np.histogram2d(*vals, bins=(xx, yy))
print(hist.shape)
ax.pcolormesh(xx, yy, hist.T, shading='auto')
ax.set_title("sample counts")

# Use `nbshow` like the rest of the notebook: it shows figures interactively but closes
# them when the notebook runs as a script/test.
nbshow()

In [None]:
# Weights on the grid: ww = sqrt(x**2 + y**2 + 1).
rr = np.add.outer(xx**2, yy**2) + 1.0
ww = np.sqrt(rr)

# Sample with `sample_grid_proportional`.  It returns the sample locations `vals` and a
# per-sample scalar `scal`, which the next cell uses as histogram weights.
# (An earlier experiment used
#  `kale.sample.sample_grid([xx, yy], ww, 200000, scalar=data/ww, interpolate=True)`.)
vals, scal = kale.sample.sample_grid_proportional([xx, yy], data, ww, 200000)

In [None]:
# Input density (left), raw sample counts (middle), and counts weighted by `scal` (right).
fig, axes = plt.subplots(figsize=[15, 6], ncols=3)

ax = axes[0]
ax.pcolormesh(xx, yy, data.T, shading='auto')
ax.set_title("input density")

ax = axes[1]
hist, *_ = np.histogram2d(*vals, bins=(xx, yy))
ax.pcolormesh(xx, yy, hist.T, shading='auto')
ax.set_title("sample counts")

ax = axes[2]
hist, *_ = np.histogram2d(*vals, bins=(xx, yy), weights=scal)
ax.pcolormesh(xx, yy, hist.T, shading='auto')
ax.set_title("sample counts weighted by scal")

# Use `nbshow` like the rest of the notebook: it shows figures interactively but closes
# them when the notebook runs as a script/test.
nbshow()

In [None]:
# Exaggerate the dynamic range of `data` (the 2D `func`, values in (1, 4]) to (0, 999].
extreme = np.power(10.0, data - 1.0) - 1.0
extreme

In [None]:
# Number of grid points along each axis (matches `extreme.shape`).
len(xx), len(yy)

In [None]:
edges = [xx, yy]
data_edge = extreme
scalar_edge = ww

# Pass `edges` as a list with one coordinate array per dimension.  Concatenating them would
# give a single non-monotonic 1D array of len(xx) + len(yy) points, which does not describe
# the 2D grid of `data_edge`.
# NOTE(review): assumes `sample_grid` returns (samples, scalars) when `scalar` is given,
# like the `vals, scal` pair returned by `sample_grid_proportional` -- confirm.
data_out, scalar_out = kale.sample.sample_grid(edges, data_edge, 1000, scalar=scalar_edge)

In [None]:
true_scalar_cent = kale.utils.midpoints(scalar_edge, axis=None)