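## Earth Mover's Distance (EMD) calculation with neural nets

Estimates the EMD (Wasserstein-1 distance) between two 3-D Gaussians through the Kantorovich–Rubinstein dual, with a neural-network critic standing in for the optimal Lipschitz function.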

In [1]:
# Put the parent of the current working directory (in Jupyter this is normally
# the notebook's own folder) on Python's import path -- exactly once -- so the
# local `modules` package imported below can be resolved.
import os
import sys

module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

# import our EMD calculation module(s)
from modules import krd_nn as emd

# import rest of the helper modules
import numpy as np
import tensorflow as tf

### Create samplers for our test distributions

In [2]:
# make a convenient wrapper for producing samples in form of a tensor
def gaussian_sampler(mean, cov, size, rng=None):
    """Draw samples from a multivariate normal and return them as a TensorFlow tensor.

    Args:
        mean: mean vector of the distribution (length d).
        cov: covariance matrix of the distribution (d x d).
        size: number of samples; passed straight to NumPy's
            ``multivariate_normal``, so an integer ``size`` yields a
            ``(size, d)`` result.
        rng: optional NumPy ``Generator`` (or ``RandomState``) to draw from,
            e.g. ``np.random.default_rng(seed)`` for reproducible samples.
            When omitted, NumPy's global random state is used, exactly as
            before this parameter existed.

    Returns:
        A ``tf.float64`` tensor holding the samples.
    """
    # The np.random module itself exposes multivariate_normal, so it can stand
    # in for a Generator when no explicit source of randomness is given.
    source = np.random if rng is None else rng
    samples = source.multivariate_normal(mean, cov, size)
    return tf.convert_to_tensor(samples, dtype=tf.float64)

# set up parameters for our two test distributions
SEED = 42
# Seed NumPy's global random state, which the samplers draw from, so that
# Restart & Run All reproduces the same samples.
np.random.seed(SEED)

dimension = 3
mean_1 = np.zeros(dimension)
mean_2 = mean_1 + 100.0 * np.ones(dimension)
cov_1 = np.identity(dimension)
# an independent copy, so mutating one covariance can't silently change the other
cov_2 = cov_1.copy()

# Distribution #2 is distribution #1 translated by (mean_2 - mean_1), so the
# exact EMD equals the norm of that shift under the ground metric:
# 300 for the L1 metric, 100 * sqrt(3) ~= 173.2 for the L2 metric.


# finally create the samplers for our test distributions
def sampler_1(size):
    """Draw `size` samples from distribution #1 as a float64 tensor."""
    return gaussian_sampler(mean_1, cov_1, size)


def sampler_2(size):
    """Draw `size` samples from distribution #2 as a float64 tensor."""
    return gaussian_sampler(mean_2, cov_2, size)

# sanity check: draw three samples from each distribution and print them
sampler_checks = (
    ("samples from distribution #1:\n{}", sampler_1),
    ("samples from distribution #2:\n{}", sampler_2),
)
for label_fmt, draw_fn in sampler_checks:
    print(label_fmt.format(draw_fn(3)))

samples from distribution #1:
[[-1.88306773  0.25573661 -0.13129384]
 [ 0.75592547 -0.13778148 -0.31176679]
 [-1.42285067  0.02273891 -0.4270862 ]]
samples from distribution #2:
[[100.2778101   98.86926193 101.15997845]
 [102.39656984  99.71910325  99.05671751]
 [ 97.88241582 101.52172509  98.74307171]]


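### Calculate approximate EMD with the help of Kantorovich-Rubinstein duality

By the Kantorovich–Rubinstein duality, $W_1(P, Q) = \sup_{\|f\|_L \le 1} \left( \mathbb{E}_{x \sim P}[f(x)] - \mathbb{E}_{y \sim Q}[f(y)] \right)$, where the supremum runs over 1-Lipschitz functions $f$. Here $f$ is approximated by a neural network (the critic).

**Note:** distribution #2 is distribution #1 shifted by $(100, 100, 100)$, so the exact EMD is the norm of that shift: $300$ under the $\ell_1$ ground metric ($100\sqrt{3} \approx 173.2$ under $\ell_2$). The estimate in the run shown below (≈ 0.68) is far smaller. This is presumably because `clip_value` only constrains the critic to be $K$-Lipschitz for some small, unknown $K$ rather than exactly 1-Lipschitz. Treat the result as the EMD up to an unknown scale factor.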

In [3]:
# NOTE(review): assumes `clip_value` clips the critic's weights as in WGAN
# (the implementation is not visible here). Clipping bounds the critic's
# Lipschitz constant by some unknown K rather than 1. The returned value is
# therefore an EMD estimate only up to that scale factor; compare it with the
# analytic value of the mean shift rather than reading it as-is.
CRITIC_SEED = 42
tf.random.set_seed(CRITIC_SEED)  # presumably makes the critic's initial weights reproducible

CLIP_VALUE = 0.02
N_EPOCHS = 500

emd_calc = emd.W1L1Calc(sampler_1, sampler_2, clip_value=CLIP_VALUE)
# keep the result in a named variable instead of relying on Out[...] / _
emd_estimate = emd_calc.calculate(epochs=N_EPOCHS)
emd_estimate

epoch = 1, EMD = 0.5211375675981307
epoch = 2, EMD = 0.0016709639248090108
epoch = 3, EMD = 0.004697678896276827
epoch = 4, EMD = 0.0069159184740328495
epoch = 5, EMD = 0.008619708889793637
epoch = 6, EMD = 0.010028984214473829
epoch = 7, EMD = 0.011196561856479355
epoch = 8, EMD = 0.012221130374969209
epoch = 9, EMD = 0.013100004965809955
epoch = 10, EMD = 0.013859836553318053
epoch = 11, EMD = 0.014539758675918484
epoch = 12, EMD = 0.015153520630995412
epoch = 13, EMD = 0.015747827761611145
epoch = 14, EMD = 0.016293718264077592
epoch = 15, EMD = 0.016740294554506922
epoch = 16, EMD = 0.01721189365612173
epoch = 17, EMD = 0.017687926084106985
epoch = 18, EMD = 0.018152513020873607
epoch = 19, EMD = 0.0186213968903894
epoch = 20, EMD = 0.019132710544976006
epoch = 21, EMD = 0.019663990849365116
epoch = 22, EMD = 0.02024383282434204
epoch = 23, EMD = 0.020885269100390496
epoch = 24, EMD = 0.02161952778680342
epoch = 25, EMD = 0.02237775703305087
epoch = 26, EMD = 0.023233555853829656
[... output truncated ...]

epoch = 219, EMD = 0.6836458184862495
epoch = 220, EMD = 0.6841576179399371
epoch = 221, EMD = 0.6816058229560218
epoch = 222, EMD = 0.685457803978855
epoch = 223, EMD = 0.6847501458264595
epoch = 224, EMD = 0.6849191796768225
epoch = 225, EMD = 0.6850102605684235
epoch = 226, EMD = 0.6844341075515649
epoch = 227, EMD = 0.6854303521011714
epoch = 228, EMD = 0.6838729382323214
epoch = 229, EMD = 0.681310922682216
epoch = 230, EMD = 0.6832514227289554
epoch = 231, EMD = 0.6835282816752827
epoch = 232, EMD = 0.683540364409126
epoch = 233, EMD = 0.6834226624772941
epoch = 234, EMD = 0.686864081298175
epoch = 235, EMD = 0.6853571304140279
epoch = 236, EMD = 0.6840697980632765
epoch = 237, EMD = 0.6867915155007316
epoch = 238, EMD = 0.6848018244716093
epoch = 239, EMD = 0.6868801931345917
epoch = 240, EMD = 0.6847432047630503
epoch = 241, EMD = 0.6865253853395095
epoch = 242, EMD = 0.684564481432699
epoch = 243, EMD = 0.6853659535852793
epoch = 244, EMD = 0.6842642485386313
[... output truncated ...]

epoch = 439, EMD = 0.6857537276368584
epoch = 440, EMD = 0.6854222003086767
epoch = 441, EMD = 0.6844832866420991
epoch = 442, EMD = 0.6856595537457917
epoch = 443, EMD = 0.6838780320501764
epoch = 444, EMD = 0.6838059114306437
epoch = 445, EMD = 0.6863726084335855
epoch = 446, EMD = 0.6854869285291285
epoch = 447, EMD = 0.6833893685962948
epoch = 448, EMD = 0.6882423107198855
epoch = 449, EMD = 0.6822842559594442
epoch = 450, EMD = 0.6860982390567942
epoch = 451, EMD = 0.6834989533258575
epoch = 452, EMD = 0.6861639546171903
epoch = 453, EMD = 0.6847934976512042
epoch = 454, EMD = 0.6829775121342658
epoch = 455, EMD = 0.6861439286951674
epoch = 456, EMD = 0.6834674016694007
epoch = 457, EMD = 0.685752988045804
epoch = 458, EMD = 0.6861444077659055
epoch = 459, EMD = 0.6834085590807593
epoch = 460, EMD = 0.6841900086645257
epoch = 461, EMD = 0.6847248301564537
epoch = 462, EMD = 0.6842763781258445
epoch = 463, EMD = 0.6852094001609915
epoch = 464, EMD = 0.6840344976265098
[... output truncated ...]

<tf.Tensor: shape=(), dtype=float64, numpy=0.6840065476389268>