In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scipy.stats as stats
from math import sqrt

sys.path.insert(1, '../../data/')
from datasets import tanh_v1, tanh_v2, tanh_v3
sys.path.insert(1, '../../methods/')
from TriangularEstimators import ScaledCostOT
from metrics import mmd2

np.random.seed(42069)

## Evaluate generate samples with random trials

In [3]:
Ntrain = int(3000) ### fixed 
Ntest = int(2000) ### fixed

n_methods = ['ICNN','OTT','NN']
want_nn = True
if Ntrain <= 10000 and want_nn:
    n_methods.append('NN')
print(n_methods)
n_trials = 5 

### distribution and other parameters
dataset_options = ['tanh_v1','tanh_v2','tanh_v3']
if target == 'tanh_v1':
    pi = tanh_v1()
elif target == 'tanh_v2':
    pi = tanh_v2()
elif target == 'tanh_v3':
    pi = tanh_v3()

dim_x = dim_y = 1
betaz = [0.1, 0.08, 0.06, 0.04, 0.02, 0.01]
y_loc_mc = 50
y_locs = []

W1_dict = {method: np.zeros((n_trials, y_loc_mc, len(betaz))) for method in n_methods}
MMD_dict = {method: np.zeros((n_trials, y_loc_mc, len(betaz))) for method in n_methods}

## for NN

## parameters for methods
n_inters_ott = 500

## for FM method
for k,beta in enumerate(betaz):
    eps_ott = beta/5
    for i in range(n_trials):
        X = pi.sample_joint(Ntrain)
        X_source = np.hstack((X[:,0][:,None], np.random.randn(Ntrain,1)))
        X_target = pi.sample_joint(Ntrain)

        ## ICNN model
        if 'ICNN' in n_methods:
            ot_icnn = ScaledCostOT(dx1=dim_x,dx2=dim_y,beta=beta,estimator='ICNN', n_iters=n_iters_icnn) ## move this as a **args
            ot_icnn.fit(source=X_source, target=X_target)
            print('Done ICNN')

        ## Nearest-Neighbor estimator
        if 'NN' in n_methods: 
            ot_nn = ScaledCostOT(dx1=dim_x,dx2=dim_y,beta=beta,estimator='NN')
            ot_nn.fit(source=X_source, target=X_target)
            print('Done NN')

        ## Entropic map estimator
        if 'OTT' in n_methods:
            ot_ott = ScaledCostOT(dx1=dim_x,dx2=dim_y,beta=beta,estimator='OTT',eps=eps_ott)
            ot_ott.fit(source=X_source, target=X_target, max_iter=n_inters_ott)
            print('Done OTT')

        for j,y in enumerate(y_locs):
            data_cond_y = jnp.hstack((jnp.ones((Ntest,1))*y, jnp.array(np.random.randn(Ntest,1))))
            gen_samples_pi = pi.sample_conditional(y,Ntest)

            if 'ICNN' in n_methods:
                YX_transp_icnn = np.array(ot_icnn.evaluate(data_cond_y.copy()))
                W1_dict['ICNN'][i,j,k] = stats.wasserstein_distance(YX_transp_icnn[:,1],gen_samples_pi) ### change this to regular OT distance
                MMD_dict['ICNN'][i,j,k],_ = mmd2(YX_transp_icnn[:,1].reshape(Ntest,dim_x),gen_samples_pi.reshape(Ntest,dim_x))

            if 'NN' in n_methods:
                YX_transp_nn = np.array(ot_nn.evaluate(data_cond_y.copy()))
                W1_dict['NN'][i,j,k] = stats.wasserstein_distance(YX_transp_nn[:,1],gen_samples_pi)
                MMD_dict['NN'][i,j,k],_ = mmd2(YX_transp_nn[:,1].reshape(Ntest,dim_x),gen_samples_pi.reshape(Ntest,dim_x))

            if 'OTT' in n_methods: 
                YX_transp_ott = np.array(ot_ott.evaluate(data_cond_y.copy()))
                W1_dict['OTT'][i,j,k] = stats.wasserstein_distance(YX_transp_ott[:,1],gen_samples_pi)
                MMD_dict['OTT'][i,j,k],_ = mmd2(YX_transp_ott[:,1].reshape(Ntest,dim_x),gen_samples_pi.reshape(Ntest,dim_x))

all_info={'dataset':target,'n_iters_icnn':n_iters_icnn,
          'n_iters_ott':n_inters_ott,
          'w1_dict': W1_dict,'mmd_dict':MMD_dict,
          'ntrain':Ntrain,'ntest':Ntest}




['ICNN', 'OTT', 'SB', 'NN']


  self.log_det = 2*np.log(np.diag(self.cho_cov


  0%|          | 0/500 [00:00<?, ?it/s]

0
50
100
150
200
250
300
350
400
450
Done ICNN
Done NN
Done OTT


  self.log_det = 2*np.log(np.diag(self.cho_cov


  0%|          | 0/500 [00:00<?, ?it/s]

0
50
100
150
200
250
300
350
400
450
Done ICNN
Done NN
Done OTT


  self.log_det = 2*np.log(np.diag(self.cho_cov


  0%|          | 0/500 [00:00<?, ?it/s]

0
50
100
150
200
250
300
350
400
450
Done ICNN
Done NN
Done OTT


  self.log_det = 2*np.log(np.diag(self.cho_cov


  0%|          | 0/500 [00:00<?, ?it/s]

0
50
100
150
200
250
300
350
400
450
Done ICNN
Done NN
Done OTT


## Compute averages of the methods

In [4]:
### For W2
for method in n_methods:
    w1_vals = W1_dict[method]
    for j,y in enumerate(y_locs):
        w1_y = w1_vals[:,j]
        print(method + ' at ' + str(y) + ': ' + str(np.mean(w1_y)) + ' with std ' + str(np.std(w1_y)) )
    #print('OTT: ' + list(w2_vals.mean(axis=0)) +' with std ' + list(w2_vals.std(axis=0)) )

ICNN at -1.2: 0.4382549883988135 with std 0.07587373109793258
ICNN at 0: 0.15555593575183013 with std 0.040302177158286996
ICNN at 1.2: 0.4435870023600625 with std 0.14324935475986975
OTT at -1.2: 0.02355024022158647 with std 0.004054221137290739
OTT at 0: 0.03393047508282316 with std 0.0128942503238217
OTT at 1.2: 0.027500369528400192 with std 0.004023416661860557
SB at -1.2: 0.010819006840209179 with std 0.004410672771946524
SB at 0: 0.044644962909513114 with std 0.015677731379956338
SB at 1.2: 0.013020467500129381 with std 0.008928124783686027
NN at -1.2: 0.014045936579101944 with std 0.005820603425915202
NN at 0: 0.059680289025055296 with std 0.017489489182995675
NN at 1.2: 0.017204360536070353 with std 0.010403876517099992


In [5]:
for method in n_methods:
    mmd_vals = MMD_dict[method]
    for j,y in enumerate(y_locs):
        mmd_y = sqrt(mmd_vals[:,j])
        print(method + ' at ' + str(y) + ': ' + str(np.mean(mmd_y)) + ' with std ' + str(np.std(mmd_y)) )


TypeError: only size-1 arrays can be converted to Python scalars