In [3]:
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import dblquad
from scipy.interpolate import interp2d
from scipy.stats import multivariate_normal

# Default (width, height) in inches for figures created later in the notebook;
# a figure can still override it by passing its own figsize.
plt.rcParams['figure.figsize'] = (10, 10)

# Original distribution

In [4]:
# Parameters of the two bivariate Gaussian components: (x, y) mean and 2x2 covariance.
MEAN_0, COV_0 = [5, 22], [[3, 2], [2, 2]]
MEAN_1, COV_1 = [11, 31], [[3, -2], [-2, 2]]

# Joint 2-D components of the mixture.
d0 = multivariate_normal(mean=MEAN_0, cov=COV_0)
d1 = multivariate_normal(mean=MEAN_1, cov=COV_1)

# 1-D marginals. The marginal of a Gaussian keeps the matching mean entry and
# the matching diagonal covariance entry, so both are read from the joint
# parameters instead of being retyped (which could silently drift apart).
d0x = multivariate_normal(mean=[MEAN_0[0]], cov=COV_0[0][0])
d1x = multivariate_normal(mean=[MEAN_1[0]], cov=COV_1[0][0])
d0y = multivariate_normal(mean=[MEAN_0[1]], cov=COV_0[1][1])
d1y = multivariate_normal(mean=[MEAN_1[1]], cov=COV_1[1][1])

In [5]:
def comb(pos):
    """Return the equal-weight mixture density of `d0` and `d1` at `pos`.

    `pos` is forwarded unchanged to each component's `pdf`; the two results
    are averaged.
    """
    density_0 = d0.pdf(pos)
    density_1 = d1.pdf(pos)
    return (density_0 + density_1) / 2

def combx(pos):
    """Return the equal-weight mixture density of `d0x` and `d1x` at `pos`.

    `pos` is forwarded unchanged to each component's `pdf`; the two results
    are averaged.
    """
    density_0 = d0x.pdf(pos)
    density_1 = d1x.pdf(pos)
    return (density_0 + density_1) / 2

def comby(pos):
    """Return the equal-weight mixture density of `d0y` and `d1y` at `pos`.

    `pos` is forwarded unchanged to each component's `pdf`; the two results
    are averaged.
    """
    density_0 = d0y.pdf(pos)
    density_1 = d1y.pdf(pos)
    return (density_0 + density_1) / 2

In [6]:
# Evaluation grid: 200 evenly spaced knots along each axis.
knots_x = np.linspace(start=-1, stop=20, num=200)
knots_y = np.linspace(start=15, stop=36, num=200)

# meshgrid's default indexing makes rows follow knots_y and columns follow knots_x.
x, y = np.meshgrid(knots_x, knots_y)
# Pack the coordinates so the last axis holds the (x, y) pair expected by comb().
pos = np.stack((x, y), axis=-1)
values = comb(pos)
values.shape

(200, 200)

In [5]:
# Joint density (top-left) with its y marginal to the right and its x marginal
# underneath; the bottom-right cell of the 2x2 layout is discarded.
fig = plt.figure()

gs = fig.add_gridspec(
    2, 2,
    hspace=0, wspace=0,
    width_ratios=[1., 0.5], height_ratios=[1., 0.5],
)
axs = gs.subplots(sharex='col', sharey='row')
(ax_joint, ax_marginal_y), (ax_marginal_x, ax_unused) = axs
ax_unused.remove()

ax_joint.contourf(knots_x, knots_y, values)
# The y marginal is drawn sideways so it shares the joint plot's y axis.
ax_marginal_y.plot(comby(knots_y), knots_y)
ax_marginal_x.plot(knots_x, combx(knots_x))

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [6]:
# Inspect the x knots.
# NOTE(review): this echoes the whole array; a slice such as knots_x[:5]
# would keep the cell output short.
knots_x

array([-1.        , -0.89447236, -0.78894472, -0.68341709, -0.57788945,
       -0.47236181, -0.36683417, -0.26130653, -0.15577889, -0.05025126,
        0.05527638,  0.16080402,  0.26633166,  0.3718593 ,  0.47738693,
        0.58291457,  0.68844221,  0.79396985,  0.89949749,  1.00502513,
        1.11055276,  1.2160804 ,  1.32160804,  1.42713568,  1.53266332,
        1.63819095,  1.74371859,  1.84924623,  1.95477387,  2.06030151,
        2.16582915,  2.27135678,  2.37688442,  2.48241206,  2.5879397 ,
        2.69346734,  2.79899497,  2.90452261,  3.01005025,  3.11557789,
        3.22110553,  3.32663317,  3.4321608 ,  3.53768844,  3.64321608,
        3.74874372,  3.85427136,  3.95979899,  4.06532663,  4.17085427,
        4.27638191,  4.38190955,  4.48743719,  4.59296482,  4.69849246,
        4.8040201 ,  4.90954774,  5.01507538,  5.12060302,  5.22613065,
        5.33165829,  5.43718593,  5.54271357,  5.64824121,  5.75376884,
        5.85929648,  5.96482412,  6.07035176,  6.1758794 ,  6.28

In [7]:
# Inspect the y knots.
# NOTE(review): this echoes the whole array; a slice such as knots_y[:5]
# would keep the cell output short.
knots_y

array([15.        , 15.10552764, 15.21105528, 15.31658291, 15.42211055,
       15.52763819, 15.63316583, 15.73869347, 15.84422111, 15.94974874,
       16.05527638, 16.16080402, 16.26633166, 16.3718593 , 16.47738693,
       16.58291457, 16.68844221, 16.79396985, 16.89949749, 17.00502513,
       17.11055276, 17.2160804 , 17.32160804, 17.42713568, 17.53266332,
       17.63819095, 17.74371859, 17.84924623, 17.95477387, 18.06030151,
       18.16582915, 18.27135678, 18.37688442, 18.48241206, 18.5879397 ,
       18.69346734, 18.79899497, 18.90452261, 19.01005025, 19.11557789,
       19.22110553, 19.32663317, 19.4321608 , 19.53768844, 19.64321608,
       19.74874372, 19.85427136, 19.95979899, 20.06532663, 20.17085427,
       20.27638191, 20.38190955, 20.48743719, 20.59296482, 20.69849246,
       20.8040201 , 20.90954774, 21.01507538, 21.12060302, 21.22613065,
       21.33165829, 21.43718593, 21.54271357, 21.64824121, 21.75376884,
       21.85929648, 21.96482412, 22.07035176, 22.1758794 , 22.28

In [8]:
#values.ravel()

In [9]:
# Re-check the shape of the evaluated density grid (rows follow knots_y,
# columns follow knots_x, because of meshgrid's default indexing).
values.shape

(200, 200)

# Compare 2D samples with the original distribution

In [7]:
# Load the externally generated 2-D sample. A context manager is used so the
# file handle is closed deterministically instead of being left to the GC.
with open('/tmp/sample2d.txt') as sample_file:
    test2_str = sample_file.read()
# Discard everything up to and including the first '>' so only the
# comma-separated numbers remain (str.index raises ValueError if '>' is absent).
test2_str = test2_str[test2_str.index('>')+1:]
# Parse the flat list of numbers and group it into rows of two values.
test2 = np.fromstring(test2_str, sep=',').reshape(-1, 2)
len(test2)

20000

In [8]:
# Midpoints between each pair of neighbouring knots (one fewer than the knots).
centers_x = 0.5 * (knots_x[1:] + knots_x[:-1])
centers_y = 0.5 * (knots_y[1:] + knots_y[:-1])

In [16]:
# Sampled data (histograms) overlaid with the analytic density (contours and
# marginal curves), using the same joint/marginal layout as the first figure.
fig = plt.figure(figsize=(12, 12))
gs = fig.add_gridspec(
    2, 2,
    hspace=0, wspace=0,
    width_ratios=[1., 0.5], height_ratios=[1., 0.5],
)
axs = gs.subplots(sharex='col', sharey='row')
(ax_joint, ax_marginal_y), (ax_marginal_x, ax_unused) = axs
ax_unused.remove()

# Sample column 1 is drawn along x and column 0 along y.
sample_y = test2[:,0]
sample_x = test2[:,1]

# Evenly spaced points spanning each knot range, used to draw the true marginals.
curve_x = np.linspace(knots_x.min(), knots_x.max())
curve_y = np.linspace(knots_y.min(), knots_y.max())

ax_joint.hist2d(sample_x, sample_y, density=True, bins=100)
ax_joint.contour(knots_x, knots_y, values, cmap='hot')

ax_marginal_y.hist(sample_y, orientation='horizontal', density=True, bins=100, histtype='step')
ax_marginal_y.plot(comby(curve_y), curve_y)
ax_marginal_y.set_yticks(knots_y[::10])

ax_marginal_x.hist(sample_x, density=True, bins=100, histtype='step')
ax_marginal_x.plot(curve_x, combx(curve_x), label='True distribution')
ax_marginal_x.set_xticks(knots_x[::10])
ax_marginal_x.legend()

plt.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [17]:
# (row, column) position of the largest entry of the density grid; argmax
# returns a flat index, which unravel_index converts back to 2-D.
np.unravel_index(values.argmax(), values.shape)

(152, 113)

In [18]:
# Mean of the first sample column, the one plotted against the y marginal
# (comby) in the comparison figure.
test2[:,0].mean()

26.488938745000002

In [19]:
# Look up the y knot at index 17.
knots_y[17]

16.79396984924623