In [1]:
# Reload edited Python modules (e.g. the project-local datasets / TriangularEstimators /
# metrics modules imported below) automatically before each cell executes.
%load_ext autoreload
%autoreload 2


In [2]:
import sys
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scipy.stats as stats
from math import sqrt
import pickle

# Project-local modules live outside this notebook's directory; put their
# folders on sys.path so the imports below resolve.
sys.path.insert(1, '../../data/')
from datasets import Banana, tanh_v1, tanh_v2, tanh_v3
sys.path.insert(1, '../../methods/')
from TriangularEstimators import ScaledCostOT
sys.path.insert(1,'../../.') 
# NOTE(review): only `mmd2` from metrics is used in this notebook; an explicit
# `from metrics import mmd2` would make the dependency visible.
from metrics import *

# Seed NumPy's global RNG: the np.random.randn noise drawn in the experiment
# loop comes from it (the datasets module presumably also uses it -- confirm).
np.random.seed(42069)

## Evaluate generated samples over random trials

In [3]:
## Run sizes (fixed for every dataset and beta)
Ntrain = 2000
Ntest = 2000

# Estimators to benchmark. Options: 'ICNN', 'OTT', 'NN', 'SB'
# ('SB' is evaluated through the fitted 'OTT' model's `evaluate_bridge`).
n_methods = ['OTT']
want_nn = True
# The nearest-neighbour estimator is only added for training sets of at most 10k points.
if Ntrain <= 10000 and want_nn:
    n_methods.append('NN')
print(n_methods)
n_trials = 10  # independent fit/evaluate repetitions per (dataset, beta)

### distribution and other parameters
dataset_options = ['Banana', 'tanh_v1', 'tanh_v2']  # 'tanh_v3' is also available

dim_x = dim_y = 1
betaz = [0.1, 0.08, 0.06, 0.04, 0.02, 0.01]  # values passed as `beta` to ScaledCostOT
y_loc_mc = 50  # number of conditioning values y drawn per dataset

# Result arrays, one per method, indexed as [trial, y location, beta index].
W1_dict = {method: np.zeros((n_trials, y_loc_mc, len(betaz))) for method in n_methods}
MMD_dict = {method: np.zeros((n_trials, y_loc_mc, len(betaz))) for method in n_methods}

## for ICNN
n_iters_icnn = 5000
hidden_dim_icnn = 128
batch_size_icnn = 256
## for OTT method (entropic map estimator)
n_inters_ott = 5000  # max_iter for the OTT fit; saved under the key 'n_iters_ott'
beta_scaling = 5  # the OTT estimator uses eps = beta / beta_scaling
## for SB -- defined unconditionally because the saved results dict records them
tau_sb = 0.99
Nsteps_sb = 50

# Fixed evaluation order: models are fitted and evaluated in this sequence so
# that draws from the global NumPy RNG happen in the same order as before.
_METHOD_ORDER = ('ICNN', 'NN', 'OTT', 'SB')

# Dataset name -> constructor. An unknown name raises instead of silently
# reusing the distribution left over from a previous loop iteration.
_DATASET_FACTORIES = {
    'Banana': lambda: Banana(reverse=False),
    'tanh_v1': tanh_v1,
    'tanh_v2': tanh_v2,
    'tanh_v3': tanh_v3,
}


def _make_dataset(name):
    """Return a fresh instance of the benchmark distribution called `name`.

    Raises:
        ValueError: if `name` is not a key of `_DATASET_FACTORIES`.
    """
    if name not in _DATASET_FACTORIES:
        raise ValueError('Unknown dataset {!r}; expected one of {}'.format(
            name, sorted(_DATASET_FACTORIES)))
    return _DATASET_FACTORIES[name]()


def _transport(method, models, data_cond_y):
    """Push `data_cond_y` through the fitted map for `method`.

    'SB' has no model of its own: it uses the 'OTT' model's `evaluate_bridge`
    with the notebook-level `tau_sb` / `Nsteps_sb`. Returns column 1 (the
    transported coordinate) of the result as a NumPy array.
    """
    if method == 'SB':
        out = models['OTT'].evaluate_bridge(data_cond_y.copy(), tau=tau_sb, Nsteps=Nsteps_sb)
    else:
        out = models[method].evaluate(data_cond_y.copy())
    return np.array(out)[:, 1]


def _score(generated, reference):
    """Return (W1, MMD) between generated samples and reference samples.

    W1 is scipy's 1-D Wasserstein distance. MMD is the first output of `mmd2`
    (presumably squared MMD); both inputs are reshaped to (Ntest, dim_y) for it
    and the second output of `mmd2` is discarded.
    """
    w1 = stats.wasserstein_distance(generated, reference)
    mmd, _ = mmd2(generated.reshape(Ntest, dim_y), reference.reshape(Ntest, dim_y))
    return w1, mmd


# Fail fast on a typo in n_methods instead of silently leaving zeros in the results.
_unknown_methods = set(n_methods) - set(_METHOD_ORDER)
if _unknown_methods:
    raise ValueError('Unsupported methods {}; choose from {}'.format(
        sorted(_unknown_methods), _METHOD_ORDER))

for dset in dataset_options:
    pi = _make_dataset(dset)

    # Fresh result arrays per dataset (previously one set of arrays was shared
    # and overwritten across datasets). Indexed as [trial, y location, beta index].
    W1_dict = {method: np.zeros((n_trials, y_loc_mc, len(betaz))) for method in n_methods}
    MMD_dict = {method: np.zeros((n_trials, y_loc_mc, len(betaz))) for method in n_methods}

    # NOTE(review): conditioning values are taken from column 1 of a joint
    # sample, whereas training keeps column 0 (X[:, 0]) as the fixed source
    # coordinate. Confirm which coordinate `sample_conditional` conditions on.
    y_locz = pi.sample_joint(y_loc_mc)[:, 1]
    for k, beta in enumerate(betaz):
        eps_ott = beta / beta_scaling  # regularisation passed to the entropic (OTT) estimator
        for i in range(n_trials):
            # Source: first coordinate of a joint sample with the second replaced by
            # standard-normal noise. Target: an independent joint sample.
            X = pi.sample_joint(Ntrain)
            X_source = np.hstack((X[:, 0][:, None], np.random.randn(Ntrain, 1)))
            X_target = pi.sample_joint(Ntrain)

            models = {}
            ## ICNN model
            if 'ICNN' in n_methods:
                ot_icnn = ScaledCostOT(dx1=dim_x, dx2=dim_y, beta=beta, estimator='ICNN',
                                       n_iters=n_iters_icnn, hidden_dim=hidden_dim_icnn,
                                       batch_size=batch_size_icnn)
                ot_icnn.fit(source=X_source, target=X_target)
                models['ICNN'] = ot_icnn
                print('Done ICNN')
            ## Nearest-Neighbor estimator
            if 'NN' in n_methods:
                ot_nn = ScaledCostOT(dx1=dim_x, dx2=dim_y, beta=beta, estimator='NN')
                ot_nn.fit(source=X_source, target=X_target)
                models['NN'] = ot_nn
                print('Done NN')
            ## Entropic map estimator (also backs 'SB', so fit it when either is requested)
            if 'OTT' in n_methods or 'SB' in n_methods:
                ot_ott = ScaledCostOT(dx1=dim_x, dx2=dim_y, beta=beta, estimator='OTT', eps=eps_ott)
                ot_ott.fit(source=X_source, target=X_target, max_iter=n_inters_ott)
                models['OTT'] = ot_ott
                print('Done OTT')

            for j, y in enumerate(y_locz):
                # Condition on y with fresh standard-normal noise in the second coordinate.
                data_cond_y = jnp.hstack((jnp.ones((Ntest, 1)) * y, jnp.array(np.random.randn(Ntest, 1))))
                gen_samples_pi = pi.sample_conditional(y, Ntest)
                for method in _METHOD_ORDER:
                    if method not in n_methods:
                        continue
                    transported = _transport(method, models, data_cond_y)
                    W1_dict[method][i, j, k], MMD_dict[method][i, j, k] = _score(transported, gen_samples_pi)

    # SB hyper-parameters are only read when SB actually ran, so the cell no
    # longer raises NameError when their definitions are commented out.
    uses_sb = 'SB' in n_methods
    all_info = {'dataset': dset, 'n_iters_ott': n_inters_ott,
                'tau_sb': tau_sb if uses_sb else None,
                'Nsteps_sb': Nsteps_sb if uses_sb else None,
                'w1_dict': W1_dict, 'betaz': betaz,
                'mmd_dict': MMD_dict, 'ntrain': Ntrain, 'ntest': Ntest,
                'y_locs': y_locz, 'beta_scaling': beta_scaling}

    with open('all_info_{}.pkl'.format(dset), 'wb') as handle:
        pickle.dump(all_info, handle, protocol=pickle.HIGHEST_PROTOCOL)

['ICNN']


  self.log_det = 2*np.log(np.diag(self.cho_cov


  0%|          | 0/5000 [00:00<?, ?it/s]


KeyboardInterrupt



In [None]:
# for dset in dataset_options:
#     with open('all_info_{}.pkl'.format(dset), 'rb') as f:
#         kms = pickle.load(f)
#     print(kms)