In [1]:
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import dblquad
from scipy.interpolate import interp2d
from scipy.stats import multivariate_normal
from mpl_toolkits import mplot3d
from matplotlib import colors, cm, animation

# Default size (inches) for every figure created in this notebook.
plt.rcParams.update({'figure.figsize': (10, 10)})

In [2]:
# List the animation writers available in this environment ('pillow' is used for the GIF writer below).
animation.writers.list()

['pillow', 'ffmpeg', 'ffmpeg_file', 'imagemagick', 'imagemagick_file', 'html']

In [3]:
# Pillow-backed GIF writer at 5 frames per second, used when exporting the 3-D rotation animation.
FRAMES_PER_SECOND = 5
Writer = animation.writers['pillow']
writer = Writer(fps=FRAMES_PER_SECOND)

# Original distribution: equal-weight mixture of two 3-D Gaussians

In [4]:
# Two 3-D Gaussian populations that make up the "true" distribution.
#
# NOTE: the covariances were originally written non-symmetric
#   d0: [[8,2,1],[2,2,1],[1,2,4]]    d1: [[1,-2,1],[-1,2,-1],[1,-1,2]]
# scipy's multivariate_normal factorises `cov` with scipy.linalg.eigh(lower=True),
# so only the lower triangle was ever used. The matrices below are exactly those
# effective (lower-triangle-mirrored) covariances. The densities, and therefore
# the saved pdf3d.npy fixture, are unchanged; the intent is just explicit now.
mean0 = np.array([5., 22., 54.])
cov0 = np.array([[8., 2., 1.],
                 [2., 2., 2.],
                 [1., 2., 4.]])
mean1 = np.array([11., 31., 108.])
cov1 = np.array([[ 1., -1.,  1.],
                 [-1.,  2., -1.],
                 [ 1., -1.,  2.]])

d0 = multivariate_normal(mean=mean0, cov=cov0)
d1 = multivariate_normal(mean=mean1, cov=cov1)

# 1-D marginals of each population: for a Gaussian, axis i is N(mean[i], cov[i, i]).
# They are derived from the same arrays so they cannot drift from d0/d1.
d0xs = [multivariate_normal(mean=m, cov=v) for m, v in zip(mean0, np.diag(cov0))]
d1xs = [multivariate_normal(mean=m, cov=v) for m, v in zip(mean1, np.diag(cov1))]

In [5]:
def comb(pos):
    """Density of the equal-weight mixture of the populations ``d0`` and ``d1``.

    pos -- evaluation points in any form accepted by ``multivariate_normal.pdf``
           (the last axis holds the coordinates).
    """
    density0, density1 = d0.pdf(pos), d1.pdf(pos)
    return 0.5 * (density0 + density1)

In [92]:
# Regular knot grid along each axis: np.linspace(start, stop, num) includes both ends.
knots_x0 = np.linspace(-1, 20, num=10)
knots_x1 = np.linspace(15, 36, num=8)
knots_x2 = np.linspace(44, 118, num=9)

# Evaluate the mixture on every node; with 'ij' indexing pos[i, j, k] = (x0_i, x1_j, x2_k).
x0, x1, x2 = np.meshgrid(knots_x0, knots_x1, knots_x2, indexing='ij')
pos = np.stack([x0, x1, x2], axis=-1)
values = comb(pos)
values.shape

(10, 8, 9)

In [93]:
def weighted_avg_and_std(values, weights):
    """Weighted mean and standard deviation of ``values``.

    values, weights -- NumPy arrays of the same shape; ``weights`` need not sum to 1.

    Returns a ``(mean, std)`` tuple. The std is the square root of the weighted
    population variance ``sum(w * (x - mean)**2) / sum(w)``.
    """
    mean = np.average(values, weights=weights)
    squared_dev = np.square(values - mean)
    std = np.sqrt(np.average(squared_dev, weights=weights))
    return mean, std

In [94]:
# Marginal weights on each axis: collapse the grid over the other two axes, then
# report the weighted mean and std of that axis' knot positions.
m0 = values.sum(axis=2).sum(axis=1)
m1 = values.sum(axis=2).sum(axis=0)
m2 = values.sum(axis=0).sum(axis=0)
for knots, marginal in ((knots_x0, m0), (knots_x1, m1), (knots_x2, m2)):
    print(weighted_avg_and_std(knots, weights=marginal))

(7.3790898193707015, 3.882799178472711)
(25.792886155799668, 5.0025838581389195)
(79.43038557938095, 27.705095527138454)


In [95]:
# Coarse view for the voxel plot: every 3rd knot along each axis, (10, 8, 9) -> (4, 3, 3).
subset = values[::3,::3,::3]

In [96]:
# RGBA colours for the voxel plot: map densities through a symmetric-log norm
# (linear inside (-1e-6, 1e-6)) and reuse the normalised value as the alpha channel,
# so low-density cells are nearly transparent. clip=True keeps every value in [0, 1].
norm = colors.SymLogNorm(linthresh=1e-6, vmin=subset.min(), vmax=subset.max(), clip=True)
normed = norm(subset.ravel()).reshape(subset.shape)
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` is still available.
c = plt.get_cmap('jet')(normed)
c[:, :, :, 3] = normed

In [97]:
# 3-D voxel view of the coarse `subset` grid, coloured with `c` (RGBA, alpha = normalised
# density). Voxels are drawn in four bands of the normalised density, highest band first.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

cuts = np.linspace(1., 0., 5)
for a, b in zip(cuts[:-1], cuts[1:]):
    # voxels() takes a boolean "filled" array: select cells with b < normed <= a.
    mask = (normed > b) & (normed <= a)
    ax.voxels(mask, facecolors=c)

# Voxel i corresponds to subset index i, i.e. knot index i * subset_step of the full grid
# (subset = values[::3, ::3, ::3]). Put a tick on every `label_every`-th voxel and label
# it with the matching knot value.
subset_step, label_every = 3, 2
for axis, knots, label in ((ax.xaxis, knots_x0, 'x0'),
                           (ax.yaxis, knots_x1, 'x1'),
                           (ax.zaxis, knots_x2, 'x2')):
    axis.set_ticks(np.arange(len(knots[::subset_step]), step=label_every))
    axis.set_ticklabels([f'{k:.2f}' for k in knots[::subset_step * label_every]])
    axis.set_label_text(label)

plt.show()

# Optional: export a rotating GIF of this view with the Pillow `writer` defined earlier.
SAVE_ROTATION_GIF = False
if SAVE_ROTATION_GIF:
    import os

    def rotate(angle):
        """Set the 3-D view azimuth to `angle` degrees (one animation frame)."""
        ax.view_init(azim=angle)

    ani = animation.FuncAnimation(fig, rotate, frames=np.arange(0, 360, 10), interval=50)
    ani.save(os.path.expanduser('~/Downloads/truth_3d.gif'), writer=writer, dpi=72)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [98]:
# Grid parameters along x0: (start, spacing, number of knots).
knots_x0.min(), np.diff(knots_x0)[0], knots_x0.shape

(-1.0, 2.3333333333333335, (10,))

In [106]:
# Hand cross-check of the x0 spacing: (stop - start) / (num - 1) for linspace(-1, 20, 10).
(20- -1)/9

2.3333333333333335

In [107]:
# Grid parameters along x1: (start, spacing, number of knots).
knots_x1.min(), np.diff(knots_x1)[0], knots_x1.shape

(15.0, 3.0, (8,))

In [109]:
# Hand cross-check of the x1 spacing: (stop - start) / (num - 1) for linspace(15, 36, 8).
(36-15)/7

3.0

In [110]:
# Grid parameters along x2: (start, spacing, number of knots).
knots_x2.min(), np.diff(knots_x2)[0], knots_x2.shape

(44.0, 9.25, (9,))

In [112]:
# Hand cross-check of the x2 spacing: (stop - start) / (num - 1) for linspace(44, 118, 9).
(118-44)/8

9.25

In [113]:
# Export the gridded mixture density (shape = (len(knots_x0), len(knots_x1), len(knots_x2)))
# into the Alexandria source tree, presumably as a fixture for the MathUtils PDF tests (confirm).
# The path is built from the user's home directory instead of a hard-coded /home/<user> prefix.
import os

pdf3d_path = os.path.expanduser('~/Work/Projects/Alexandria/Alexandria/MathUtils/tests/src/PDF/pdf3d.npy')
np.save(pdf3d_path, values)

In [114]:
# Shape of the saved grid: (len(knots_x0), len(knots_x1), len(knots_x2)).
values.shape

(10, 8, 9)

# Compare a 3-D sample against the original distribution

In [133]:
# Load the 3-D sample to compare against. Everything up to and including the first '>'
# is discarded (presumably a header/prefix written by the producer — confirm). The rest
# is parsed as comma-separated numbers and grouped into (x0, x1, x2) rows.
with open('/tmp/sample3d.txt') as sample_file:  # closes the handle deterministically
    test3_str = sample_file.read()
test3_str = test3_str[test3_str.index('>') + 1:]
test3 = np.fromstring(test3_str, sep=',').reshape(-1, 3)
len(test3)

100000

In [134]:
# Mid-points between consecutive knots along each axis.
centers_x0, centers_x1, centers_x2 = [0.5 * (k[:-1] + k[1:]) for k in (knots_x0, knots_x1, knots_x2)]

In [135]:
# 3-D histogram of the sample: 40 equal-width bins per axis spanning the data range.
# density=True normalises H so it integrates to 1; e0, e1, e2 are the bin edges per axis.
H, (e0, e1, e2) = np.histogramdd(test3, bins=40, density=True)

In [136]:
# RGBA colours for the sample histogram H, using the symmetric-log scheme of the grid plot
# (linear inside (-1e-3, 1e-3)). The normalised value doubles as the alpha channel.
# vmin is the smallest *non-zero* density, so empty bins (H == 0) lie below vmin. Without
# clip=True they normalise to slightly negative values, which are invalid as alpha; with
# clipping they map to 0 (fully transparent).
norm = colors.SymLogNorm(linthresh=1e-3, vmin=H[H > 0].min(), vmax=H.max(), clip=True)
normed = norm(H.ravel()).reshape(H.shape)
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` is still available.
c = plt.get_cmap('jet')(normed)
c[:, :, :, 3] = normed

In [137]:
# Sample mean along x2, to compare with the grid-weighted estimate for knots_x2 printed earlier.
np.mean(test3[:,2])

79.361777432

In [138]:
# One panel per axis: sample histogram vs. each population's half-weighted 1-D marginal
# and their sum (the mixture marginal).
fig, axes = plt.subplots(ncols=3, figsize=(12, 4))
for i, (ax, marginal0, marginal1) in enumerate(zip(axes, d0xs, d1xs)):
    _, edges, _ = ax.hist(test3[:, i], density=True, bins=150)
    xs = np.linspace(edges[0], edges[-1])
    half0 = marginal0.pdf(xs) / 2
    half1 = marginal1.pdf(xs) / 2
    ax.plot(xs, half0, label='Population 1')
    ax.plot(xs, half1, label='Population 2')
    ax.plot(xs, half0 + half1, color='red', label='Total', linestyle=':')
    ax.set_title(f'Axis {i}')
axes[0].legend()
fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [139]:
# Zoom into x0 in [knots_x0[1], knots_x0[5]) and compare the sample histogram with the
# analytic mixture marginal; knots inside the window are marked with dotted lines.
a, b = knots_x0[1], knots_x0[5]
mask = (test3[:, 0] >= a) & (test3[:, 0] < b)
zoom_fig, zoom_ax = plt.subplots()
count, edges = np.histogram(test3[mask, 0], bins=200)
centers = (edges[1:] + edges[:-1]) / 2
w = np.diff(centers)[0]
# Normalise by the *total* sample size (not mask.sum()), so bar heights are on the same
# density scale as the full mixture pdf drawn on top.
zoom_ax.bar(centers, (count / len(test3)) / w, width=w)
xs = np.linspace(a, b)
zoom_ax.plot(xs, d0xs[0].pdf(xs) / 2 + d1xs[0].pdf(xs) / 2, linestyle='--', color='red')
# axvline spans the full height, unlike a fixed data-coordinate ymax.
for knot in knots_x0[(knots_x0 >= a) & (knots_x0 < b)]:
    zoom_ax.axvline(knot, linestyle=':', color='black', alpha=0.8)
zoom_ax.set_xlabel('x0')
zoom_ax.set_ylabel('density')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[]