In [1]:
# Load IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to the project's local .py files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2
# Ask the notebook front-end to autosave at the given interval
# (the cell output reports "Autosaving every 2 seconds").
%autosave 2

Autosaving every 2 seconds


In [2]:
import sys
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scipy.stats as stats
from math import sqrt
import pickle

# Make the project-local packages importable: each directory is put on
# sys.path immediately before the import that needs it (paths are relative to
# the notebook's working directory).
sys.path.insert(1, '../../data/')
from datasets import tanh_v1, tanh_v2, tanh_v3
sys.path.insert(1, '../../methods/')
from TriangularEstimators import ScaledCostOT
from metrics import mmd2, empirical_wasserstein_distance

# Seed NumPy's legacy global RNG, which backs the np.random.randn calls in
# this notebook.
# NOTE(review): whether the dataset samplers and the estimators also draw from
# this global RNG is not visible here -- confirm before relying on the seed for
# full reproducibility.
np.random.seed(42069)

## Evaluate generated samples over random trials

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
Ntrain = int(5000)  # training samples drawn per trial (fixed)
Ntest = int(2000)   # samples generated per conditioning value (fixed)

n_methods = ['ICNN','OTT','NN']
print(n_methods)
n_trials = 10

### distribution and other parameters
datasets = ['tanh_v1','tanh_v2','tanh_v3']
dim_x = dim_y = 1
betaz = [0.06]
y_loc_mc = 50  # number of conditioning values sampled per dataset

## parameters for methods
n_iters_icnn    = 5000
n_inters_ott    = 5000  # (sic) name kept as-is; stored under 'n_iters_ott' below
hidden_dim_icnn = 128
batch_size_icnn = 256
beta_scaling    = 5     # eps for the OTT estimator is beta / beta_scaling

# Map each dataset name to the callable that builds its target distribution.
# An explicit lookup fails loudly on an unknown name instead of silently
# reusing the `pi` left over from the previous loop iteration.
dataset_factories = {'tanh_v1': tanh_v1, 'tanh_v2': tanh_v2, 'tanh_v3': tanh_v3}

for dataset in datasets:
    print(dataset)

    # set target distribution
    if dataset not in dataset_factories:
        raise ValueError('Unknown dataset ' + repr(dataset)
                         + '; expected one of ' + str(sorted(dataset_factories)))
    pi = dataset_factories[dataset]()

    # define dictionaries to save results; every array has shape
    # (n_trials, y_loc_mc, len(betaz)) and is indexed as [trial, y location, beta]
    W2_dict  = {method: np.zeros((n_trials, y_loc_mc, len(betaz))) for method in n_methods}
    MMD_dict = {method: np.zeros((n_trials, y_loc_mc, len(betaz))) for method in n_methods}

    # sample ylocs for each target: column 0 of a joint sample is used as the
    # conditioning value
    y_locs = pi.sample_joint(y_loc_mc)[:,0]

    for k,beta in enumerate(betaz):

        # set epsilon
        eps_ott = beta/beta_scaling

        for i in range(n_trials):
            print('beta: '+str(beta)+', trial: '+str(i))

            # Source: column 0 of a fresh joint sample paired with standard
            # normal noise. Target: an independent joint sample.
            X = pi.sample_joint(Ntrain)
            X_source = np.hstack((X[:,0][:,None], np.random.randn(Ntrain,1)))
            X_target = pi.sample_joint(Ntrain)

            # Estimators fitted in this trial, keyed by method name. Insertion
            # order (ICNN, NN, OTT) is also the evaluation order below.
            fitted = {}

            ## ICNN model
            if 'ICNN' in n_methods:
                ot_icnn = ScaledCostOT(dx1=dim_x,dx2=dim_y,beta=beta,estimator='ICNN',
                                       n_iters=n_iters_icnn, hidden_dim=hidden_dim_icnn,
                                       batch_size=batch_size_icnn)
                ot_icnn.fit(source=X_source, target=X_target)
                fitted['ICNN'] = ot_icnn
                print('Done ICNN')

            ## Nearest-Neighbor estimator
            if 'NN' in n_methods:
                ot_nn = ScaledCostOT(dx1=dim_x,dx2=dim_y,beta=beta,estimator='NN')
                ot_nn.fit(source=X_source, target=X_target)
                fitted['NN'] = ot_nn
                print('Done NN')

            ## Entropic map estimator
            if 'OTT' in n_methods:
                ot_ott = ScaledCostOT(dx1=dim_x,dx2=dim_y,beta=beta,estimator='OTT',eps=eps_ott)
                ot_ott.fit(source=X_source, target=X_target, max_iter=n_inters_ott)
                fitted['OTT'] = ot_ott
                print('Done OTT')

            for j,y in enumerate(y_locs):
                # Inputs for conditioning value y: a constant column of y next
                # to fresh standard normal noise.
                data_cond_y = jnp.hstack((jnp.ones((Ntest,1))*y, jnp.array(np.random.randn(Ntest,1))))
                # Reference samples from the target conditional at y.
                gen_samples_pi = pi.sample_conditional(y,Ntest)[:,None]

                for method, estimator in fitted.items():
                    transported = np.array(estimator.evaluate(data_cond_y.copy()))
                    # Column 1 of the transported points is compared with the
                    # reference conditional samples.
                    gen_samples_method = transported[:,1][:,None]
                    # TODO (from the original notebook): change this to regular OT distance
                    W2_dict[method][i,j,k] = empirical_wasserstein_distance(gen_samples_method,gen_samples_pi)
                    MMD_dict[method][i,j,k],_ = mmd2(gen_samples_method,gen_samples_pi)

            print('Done Evaluations')

    # save results; the sweep settings are stored next to the metrics so that
    # downstream cells can read them back instead of re-declaring them
    all_info={'dataset':dataset,'n_iters_icnn':n_iters_icnn,
              'n_iters_ott':n_inters_ott,
              'W2_dict': W2_dict,'mmd_dict':MMD_dict,
              'ntrain':Ntrain,'ntest':Ntest,'yloc':y_locs,
              'betaz':list(betaz),'methods':list(n_methods),
              'n_trials':n_trials,'beta_scaling':beta_scaling}
    with open(dataset+'.pkl', 'wb') as handle:
        pickle.dump(all_info, handle, protocol=pickle.HIGHEST_PROTOCOL)

['ICNN', 'OTT', 'NN']
tanh_v1
beta: 0.06, trial: 0


  self.log_det = 2*np.log(np.diag(self.cho_cov


0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950
2000
2050
2100
2150
2200
2250
2300
2350
2400
2450
2500
2550
2600
2650
2700
2750
2800
2850
2900
2950
3000
3050
3100
3150
3200
3250
3300
3350
3400
3450
3500
3550
3600
3650
3700
3750
3800
3850
3900
3950
4000
4050
4100
4150
4200
4250
4300
4350
4400
4450
4500
4550
4600
4650
4700
4750
4800
4850
4900
4950
Done ICNN
Done NN
Done OTT
Done Evaluations
beta: 0.06, trial: 1


  self.log_det = 2*np.log(np.diag(self.cho_cov


0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950
2000
2050
2100
2150
2200
2250
2300
2350
2400
2450
2500
2550
2600
2650
2700
2750
2800
2850
2900
2950
3000
3050
3100
3150
3200
3250
3300
3350
3400
3450
3500
3550
3600
3650
3700
3750
3800
3850
3900
3950
4000
4050
4100
4150
4200
4250
4300
4350
4400
4450
4500
4550
4600
4650
4700
4750
4800
4850
4900
4950
Done ICNN
Done NN
Done OTT
Done Evaluations
beta: 0.06, trial: 2


  self.log_det = 2*np.log(np.diag(self.cho_cov


0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950
2000
2050
2100
2150
2200
2250
2300
2350
2400
2450
2500
2550
2600
2650
2700
2750
2800
2850
2900
2950
3000
3050
3100
3150
3200
3250
3300
3350
3400
3450
3500
3550
3600
3650
3700
3750
3800
3850
3900
3950
4000
4050
4100
4150
4200
4250
4300
4350
4400
4450
4500
4550
4600
4650
4700
4750
4800
4850
4900
4950
Done ICNN
Done NN
Done OTT
Done Evaluations
beta: 0.06, trial: 3


  self.log_det = 2*np.log(np.diag(self.cho_cov


0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950
2000
2050
2100
2150
2200
2250
2300
2350
2400
2450
2500
2550
2600
2650
2700
2750
2800
2850
2900
2950
3000
3050
3100
3150
3200
3250
3300
3350
3400
3450
3500
3550
3600
3650
3700
3750
3800
3850
3900
3950
4000
4050
4100
4150
4200
4250
4300
4350
4400
4450
4500
4550
4600
4650
4700
4750
4800
4850
4900
4950
Done ICNN
Done NN
Done OTT
Done Evaluations
beta: 0.06, trial: 4


  self.log_det = 2*np.log(np.diag(self.cho_cov


0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950
2000
2050
2100
2150
2200
2250
2300
2350
2400
2450
2500
2550
2600
2650
2700
2750
2800
2850
2900
2950
3000
3050
3100
3150
3200
3250
3300
3350
3400
3450
3500
3550
3600
3650
3700
3750
3800
3850
3900
3950
4000
4050
4100
4150
4200
4250
4300
4350
4400
4450
4500
4550
4600
4650
4700
4750
4800
4850
4900
4950
Done ICNN
Done NN
Done OTT
Done Evaluations
beta: 0.06, trial: 5


  self.log_det = 2*np.log(np.diag(self.cho_cov


0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950
2000
2050
2100
2150
2200
2250
2300
2350
2400
2450
2500
2550
2600
2650
2700
2750
2800
2850
2900
2950
3000
3050
3100
3150
3200
3250
3300
3350
3400
3450
3500
3550
3600
3650
3700
3750
3800
3850
3900
3950
4000
4050
4100
4150
4200
4250
4300
4350
4400
4450
4500
4550
4600
4650
4700
4750
4800
4850
4900
4950
Done ICNN
Done NN
Done OTT
Done Evaluations
beta: 0.06, trial: 6


  self.log_det = 2*np.log(np.diag(self.cho_cov


0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950
2000
2050
2100
2150
2200
2250
2300
2350
2400
2450
2500
2550
2600
2650
2700
2750
2800
2850
2900
2950
3000
3050
3100
3150
3200
3250
3300
3350
3400
3450
3500
3550
3600
3650
3700
3750
3800
3850
3900
3950
4000
4050
4100
4150
4200
4250
4300
4350
4400
4450
4500
4550
4600
4650
4700
4750
4800
4850
4900
4950
Done ICNN
Done NN
Done OTT
Done Evaluations
beta: 0.06, trial: 7


  self.log_det = 2*np.log(np.diag(self.cho_cov


0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950
2000
2050
2100
2150
2200
2250
2300
2350
2400
2450
2500
2550
2600
2650
2700
2750
2800
2850
2900
2950
3000
3050
3100
3150
3200
3250
3300
3350
3400
3450
3500
3550
3600
3650
3700
3750
3800
3850
3900
3950
4000
4050
4100
4150
4200
4250
4300
4350
4400
4450
4500
4550
4600
4650
4700
4750
4800
4850
4900
4950
Done ICNN
Done NN
Done OTT
Done Evaluations
beta: 0.06, trial: 8


  self.log_det = 2*np.log(np.diag(self.cho_cov


0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950
2000
2050
2100
2150
2200
2250
2300
2350
2400
2450
2500
2550
2600
2650
2700
2750
2800
2850
2900
2950
3000
3050
3100
3150
3200
3250
3300
3350
3400
3450
3500
3550
3600
3650
3700
3750
3800
3850
3900
3950
4000
4050
4100
4150
4200
4250
4300
4350
4400
4450
4500
4550
4600
4650
4700
4750
4800
4850
4900
4950
Done ICNN
Done NN
Done OTT
Done Evaluations
beta: 0.06, trial: 9


  self.log_det = 2*np.log(np.diag(self.cho_cov


0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950
2000
2050
2100
2150
2200
2250
2300
2350
2400
2450
2500
2550
2600
2650
2700
2750
2800
2850
2900
2950
3000
3050
3100
3150
3200
3250
3300
3350
3400
3450
3500
3550
3600
3650
3700
3750
3800
3850
3900
3950
4000
4050
4100
4150
4200
4250
4300
4350
4400
4450
4500
4550
4600
4650
4700
4750
4800
4850
4900
4950
Done ICNN
Done NN
Done OTT
Done Evaluations
tanh_v2
beta: 0.06, trial: 0


  self.log_det = 2*np.log(np.diag(self.cho_cov


0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950
2000
2050
2100
2150
2200
2250
2300
2350
2400
2450
2500
2550
2600
2650
2700
2750
2800
2850
2900
2950
3000
3050
3100
3150
3200
3250
3300
3350
3400
3450
3500
3550
3600
3650
3700
3750
3800
3850
3900
3950
4000
4050
4100
4150
4200
4250
4300
4350
4400
4450
4500
4550
4600
4650
4700
4750
4800
4850
4900
4950
Done ICNN
Done NN
Done OTT
Done Evaluations
beta: 0.06, trial: 1


  self.log_det = 2*np.log(np.diag(self.cho_cov


0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950
2000
2050
2100
2150
2200
2250
2300
2350
2400
2450
2500
2550
2600
2650
2700
2750
2800
2850
2900
2950
3000
3050
3100
3150
3200
3250
3300
3350
3400
3450
3500
3550
3600
3650
3700
3750
3800
3850
3900
3950
4000
4050
4100
4150
4200
4250
4300
4350
4400
4450
4500
4550
4600
4650
4700
4750
4800
4850
4900
4950
Done ICNN
Done NN
Done OTT
Done Evaluations
beta: 0.06, trial: 2


  self.log_det = 2*np.log(np.diag(self.cho_cov


0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950
2000
2050
2100
2150
2200
2250
2300
2350
2400
2450
2500
2550
2600
2650
2700
2750
2800
2850
2900
2950
3000
3050
3100
3150
3200
3250
3300
3350
3400
3450
3500
3550
3600
3650
3700
3750
3800
3850
3900
3950
4000
4050
4100
4150
4200
4250
4300
4350
4400
4450
4500
4550
4600
4650
4700
4750
4800
4850
4900
4950
Done ICNN
Done NN
Done OTT
Done Evaluations
beta: 0.06, trial: 3


  self.log_det = 2*np.log(np.diag(self.cho_cov


0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950
2000
2050
2100
2150
2200
2250
2300
2350
2400
2450
2500
2550
2600
2650
2700
2750
2800
2850
2900
2950
3000
3050
3100
3150
3200
3250
3300
3350
3400
3450
3500
3550
3600
3650
3700
3750
3800
3850
3900
3950
4000
4050
4100
4150
4200
4250
4300
4350
4400
4450


## Compute averages of the methods

In [6]:
datasets = ['tanh_v1','tanh_v2','tanh_v3']
n_methods = ['ICNN','OTT','NN']
betaz = [0.06]

# Each entry pairs the suffix printed after the dataset name with the key of
# the matching results dictionary inside the saved pickle. W2 is reported for
# every dataset first, then MMD.
metric_views = ((' (W2)', 'W2_dict'), (' (MMD)', 'mmd_dict'))

for suffix, result_key in metric_views:
    for dataset in datasets:
        print(dataset + suffix)

        # load the results saved for this dataset
        with open(dataset+'.pkl', 'rb') as handle:
            saved_data = pickle.load(handle)

        # report, per method and beta, the mean and standard deviation across
        # trials of the per-trial average over conditioning locations
        for method in n_methods:
            metric_vals = saved_data[result_key][method]

            for k, beta in enumerate(betaz):
                trial_means = metric_vals[:, :, k].mean(axis=1)
                summary_mean = trial_means.mean()
                summary_std = trial_means.std()
                print(method + ' beta: ' + str(beta) + ' = ' + str(summary_mean)
                      + ' with std ' + str(summary_std))

tanh_v1 (W2)
ICNN beta: 0.06 = 0.009220261444736182 with std 0.0023568022316583126
OTT beta: 0.06 = 0.00444927872913199 with std 0.0009819901982683647
NN beta: 0.06 = 0.006048213204406266 with std 0.0009819725409280534
tanh_v2 (W2)
ICNN beta: 0.06 = 0.20518553971959202 with std 0.07437512221019801
OTT beta: 0.06 = 0.000996972180550402 with std 0.00034864531721306817
NN beta: 0.06 = 0.0012726141393230998 with std 0.0004164064795025003
tanh_v3 (W2)
ICNN beta: 0.06 = 0.010406350360353369 with std 0.001649834798887412
OTT beta: 0.06 = 0.0036517988276162537 with std 0.00036546967083527247
NN beta: 0.06 = 0.0034295095532638174 with std 0.00044815891234992075
tanh_v1 (MMD)
ICNN beta: 0.06 = 0.019649992727241793 with std 0.006162869600226513
OTT beta: 0.06 = 0.016524141867409976 with std 0.002608574034961806
NN beta: 0.06 = 0.013689045488685847 with std 0.00502898534510591
tanh_v2 (MMD)
ICNN beta: 0.06 = 0.4656833156589421 with std 0.038706327424482254
OTT beta: 0.06 = 0.30197849571666924 with