## Hyperparameter tuning of CWGAN for Test Problem 1

In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

### Load required packages

In [2]:
import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
import matplotlib.pyplot as plt
import time
from scipy import stats
from scipy.stats import wasserstein_distance
import random
import torch
import sys

In [3]:
import wgan  # Load the wgan python file in the current directory (recommended, as it is more convenient, 
             # facilitates reproducibility, and avoids potential issues that may arise during installation), 
             # or install the package if needed.

In [4]:
import os

# Allow a second copy of the OpenMP runtime to be loaded into this process.
# NOTE(review): presumably works around the "OMP: Error #15 ... already initialized"
# abort that kills the kernel when numpy and torch bring separate OpenMP builds;
# this was observed while training CWGAN — confirm. The flag is most reliable when
# set before numpy/torch are imported.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

In [5]:
# Record the library versions used to produce the results below.
print(
    "\n".join([
        f"PyTorch version: {torch.__version__}",
        f"NumPy version: {np.__version__}",
        f"Python version: {sys.version}",
    ])
)

PyTorch version: 1.11.0+cu113
NumPy version: 1.21.6
Python version: 3.7.3 (default, Apr 24 2019, 15:29:51) [MSC v.1915 64 bit (AMD64)]


### Prepare for experiments

In [6]:
d = 3  # dimension of the covariate vector (A, B, C)
n = 10000  # number of training samples per replication
m = 300  # QRGMM quantile-grid resolution: levels are 1/m, 2/m, ..., (m-1)/m
le = 1 / m  # lowest quantile level
ue = 1 - le  # highest quantile level
quantiles = np.linspace(le, ue, m - 1)  # the m-1 quantile levels used by QRGMM

# Covariate ranges: A ~ U[x1lb, x1ub], B ~ U[x2lb, x2ub], C ~ U[x3lb, x3ub]
x1lb, x1ub = 0, 10
x2lb, x2ub = -5, 5
x3lb, x3ub = 0, 5

# Example coefficients:
#   conditional mean of F: a0 + a1*A + a2*B + a3*C
#   conditional std  of F: r0 + r1*A + r2*B + r3*C
a0, a1, a2, a3 = 5, 1, 2, 0.5
r0, r1, r2, r3 = 1, 0.1, 0.2, 0.05

# The conditional std is passed as the `scale` of a normal draw, so it must be
# non-negative on the whole covariate box. A linear function attains its minimum
# at a corner, so take the smaller endpoint contribution of each term.
assert (
    r0
    + min(r1 * x1lb, r1 * x1ub)
    + min(r2 * x2lb, r2 * x2ub)
    + min(r3 * x3lb, r3 * x3ub)
) >= 0, "conditional std r0 + r1*A + r2*B + r3*C becomes negative on the covariate box"
# example coefficients


In [7]:
def g1(x0_1, x0_2, x0_3):
    """Conditional mean of F given covariates (A, B, C) = (x0_1, x0_2, x0_3).

    Linear in the covariates with the module-level coefficients a0..a3.
    Only arithmetic operators are used, so it applies element-wise to NumPy arrays
    as well as to scalars.
    """
    return a0 + a1 * x0_1 + a2 * x0_2 + a3 * x0_3


def g2(x0_1, x0_2, x0_3):
    """Conditional standard deviation of F given covariates (A, B, C) = (x0_1, x0_2, x0_3).

    Linear in the covariates with the module-level coefficients r0..r3.
    Works element-wise on NumPy arrays as well as on scalars.
    """
    return r0 + r1 * x0_1 + r2 * x0_2 + r3 * x0_3

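### Tuning parameter: `max_epochs`

For each candidate `max_epochs`, CWGAN-GP is trained on `run` independently seeded training sets, and the Wasserstein distance between the generated and the training values of `F` is recorded.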

In [8]:
run = 10  # replications per max_epochs value; replication i is seeded with i

In [9]:
# Candidate max_epochs values: 500, 1000, ..., 5000 (the stop value 5500 is exclusive).
maxepochs = np.arange(500, 5500, 500)

In [10]:
# Show the grid of max_epochs values that will be evaluated.
maxepochs

array([ 500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000])

In [11]:
# Results table: row = replication (seed), column = position in maxepochs.
# Sized from maxepochs rather than a hard-coded 10 so the grid is edited in one place.
WassersteinDistance_CWGANGP = np.zeros((run, len(maxepochs)))

In [12]:
# Fall back to CPU instead of crashing when no CUDA device is present.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    The tuning loop calls this with the replication index before simulating data,
    so replication `seed` sees the same training set and the same initial RNG state
    for every max_epochs value (common random numbers across the grid).

    Parameters
    ----------
    seed : int
        A plain Python int. ``random.seed`` rejects NumPy integer scalars on Python >= 3.11.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def simulate_training_data(n_samples):
    """Draw one Test Problem 1 training set from the global NumPy RNG.

    Covariates are uniform: A on [x1lb, x1ub], B on [x2lb, x2ub], C on [x3lb, x3ub].
    The response is F ~ Normal(g1(A, B, C), g2(A, B, C)).

    Returns
    -------
    pandas.DataFrame
        n_samples rows with float columns "A", "B", "C", "F".
    """
    x1 = x1lb + (x1ub - x1lb) * np.random.rand(n_samples)
    x2 = x2lb + (x2ub - x2lb) * np.random.rand(n_samples)
    x3 = x3lb + (x3ub - x3lb) * np.random.rand(n_samples)
    # Vectorised draw: element i uses loc g1(...)[i] and scale g2(...)[i].
    # g1/g2 are the functions defined earlier. They are called, never rebound here.
    f = np.random.normal(g1(x1, x2, x3), g2(x1, x2, x3))
    return pd.DataFrame({"A": x1, "B": x2, "C": x3, "F": f})


for maxepochi, max_epochs in enumerate(maxepochs):
    for runi in range(run):
        seed_everything(runi)

        ############################### generate data ###############################
        df = simulate_training_data(n)  # training data

        ################################ CWGAN-GP ################################
        # Y | X: learn the conditional law of F given (A, B, C)
        continuous_vars = ["F"]
        categorical_vars = []
        context_vars = ["A", "B", "C"]

        # Initialize objects
        data_wrapper = wgan.DataWrapper(df, continuous_vars, categorical_vars, context_vars)
        spec = wgan.Specifications(
            data_wrapper,
            batch_size=4096,
            max_epochs=int(max_epochs),
            critic_lr=1e-3,
            generator_lr=1e-3,
            print_every=100,
            device=DEVICE,
        )
        generator = wgan.Generator(spec)
        critic = wgan.Critic(spec)

        # train Y | X
        y, context = data_wrapper.preprocess(df)
        wgan.train(generator, critic, y, context, spec)

        # generate data with conditional WGANs at the training covariates
        testdf_generated0 = data_wrapper.apply_generator(generator, df)

        # 1-D Wasserstein distance between the generated and training values of F.
        # NOTE(review): scored on the training set itself, not on held-out data.
        WassersteinDistance_CWGANGP[runi, maxepochi] = wasserstein_distance(
            testdf_generated0["F"], df["F"]
        )

settings: {'optimizer': <class 'torch.optim.adam.Adam'>, 'critic_d_hidden': [128, 128, 128], 'critic_dropout': 0, 'critic_steps': 15, 'critic_lr': 0.001, 'critic_gp_factor': 5, 'generator_d_hidden': [128, 128, 128], 'generator_dropout': 0.1, 'generator_lr': 0.001, 'generator_d_noise': 1, 'generator_optimizer': 'optimizer', 'max_epochs': 500, 'batch_size': 4096, 'test_set_size': 16, 'load_checkpoint': None, 'save_checkpoint': None, 'save_every': 100, 'print_every': 100, 'device': 'cuda'}
epoch 0 | step 4 | WD_test 0.05 | WD_train 0.02 | sec passed 3 |
epoch 100 | step 304 | WD_test 0.28 | WD_train 0.28 | sec passed 14 |
epoch 200 | step 604 | WD_test 0.09 | WD_train 0.09 | sec passed 14 |
epoch 300 | step 904 | WD_test 0.0 | WD_train -0.06 | sec passed 14 |
epoch 400 | step 1204 | WD_test 0.24 | WD_train 0.22 | sec passed 14 |
settings: {'optimizer': <class 'torch.optim.adam.Adam'>, 'critic_d_hidden': [128, 128, 128], 'critic_dropout': 0, 'critic_steps': 15, 'critic_lr': 0.001, 'critic_

epoch 300 | step 904 | WD_test 0.43 | WD_train 0.35 | sec passed 15 |
epoch 400 | step 1204 | WD_test -0.01 | WD_train -0.07 | sec passed 15 |
settings: {'optimizer': <class 'torch.optim.adam.Adam'>, 'critic_d_hidden': [128, 128, 128], 'critic_dropout': 0, 'critic_steps': 15, 'critic_lr': 0.001, 'critic_gp_factor': 5, 'generator_d_hidden': [128, 128, 128], 'generator_dropout': 0.1, 'generator_lr': 0.001, 'generator_d_noise': 1, 'generator_optimizer': 'optimizer', 'max_epochs': 1000, 'batch_size': 4096, 'test_set_size': 16, 'load_checkpoint': None, 'save_checkpoint': None, 'save_every': 100, 'print_every': 100, 'device': 'cuda'}
epoch 0 | step 4 | WD_test 0.05 | WD_train 0.02 | sec passed 0 |
epoch 100 | step 304 | WD_test 0.28 | WD_train 0.28 | sec passed 14 |
epoch 200 | step 604 | WD_test 0.09 | WD_train 0.09 | sec passed 14 |
epoch 300 | step 904 | WD_test 0.0 | WD_train -0.06 | sec passed 14 |
epoch 400 | step 1204 | WD_test 0.24 | WD_train 0.22 | sec passed 14 |
epoch 500 | step 1

epoch 600 | step 1804 | WD_test 0.06 | WD_train 0.01 | sec passed 14 |
epoch 700 | step 2104 | WD_test 0.19 | WD_train 0.15 | sec passed 14 |
epoch 800 | step 2404 | WD_test 0.22 | WD_train 0.17 | sec passed 14 |
epoch 900 | step 2704 | WD_test 0.11 | WD_train 0.06 | sec passed 14 |
settings: {'optimizer': <class 'torch.optim.adam.Adam'>, 'critic_d_hidden': [128, 128, 128], 'critic_dropout': 0, 'critic_steps': 15, 'critic_lr': 0.001, 'critic_gp_factor': 5, 'generator_d_hidden': [128, 128, 128], 'generator_dropout': 0.1, 'generator_lr': 0.001, 'generator_d_noise': 1, 'generator_optimizer': 'optimizer', 'max_epochs': 1000, 'batch_size': 4096, 'test_set_size': 16, 'load_checkpoint': None, 'save_checkpoint': None, 'save_every': 100, 'print_every': 100, 'device': 'cuda'}
epoch 0 | step 4 | WD_test 0.11 | WD_train 0.03 | sec passed 0 |
epoch 100 | step 304 | WD_test 0.22 | WD_train 0.25 | sec passed 14 |
epoch 200 | step 604 | WD_test 0.16 | WD_train 0.13 | sec passed 14 |
epoch 300 | step 9

epoch 1100 | step 3304 | WD_test 0.24 | WD_train 0.23 | sec passed 14 |
epoch 1200 | step 3604 | WD_test -0.0 | WD_train -0.03 | sec passed 14 |
epoch 1300 | step 3904 | WD_test 0.11 | WD_train 0.07 | sec passed 14 |
epoch 1400 | step 4204 | WD_test 0.11 | WD_train 0.05 | sec passed 14 |
settings: {'optimizer': <class 'torch.optim.adam.Adam'>, 'critic_d_hidden': [128, 128, 128], 'critic_dropout': 0, 'critic_steps': 15, 'critic_lr': 0.001, 'critic_gp_factor': 5, 'generator_d_hidden': [128, 128, 128], 'generator_dropout': 0.1, 'generator_lr': 0.001, 'generator_d_noise': 1, 'generator_optimizer': 'optimizer', 'max_epochs': 1500, 'batch_size': 4096, 'test_set_size': 16, 'load_checkpoint': None, 'save_checkpoint': None, 'save_every': 100, 'print_every': 100, 'device': 'cuda'}
epoch 0 | step 4 | WD_test 0.06 | WD_train 0.02 | sec passed 0 |
epoch 100 | step 304 | WD_test 0.25 | WD_train 0.27 | sec passed 14 |
epoch 200 | step 604 | WD_test 0.11 | WD_train 0.1 | sec passed 14 |
epoch 300 | st

epoch 100 | step 304 | WD_test 0.43 | WD_train 0.26 | sec passed 14 |
epoch 200 | step 604 | WD_test 0.27 | WD_train 0.15 | sec passed 14 |
epoch 300 | step 904 | WD_test 0.05 | WD_train 0.11 | sec passed 15 |
epoch 400 | step 1204 | WD_test 0.34 | WD_train 0.14 | sec passed 15 |
epoch 500 | step 1504 | WD_test 0.26 | WD_train 0.27 | sec passed 15 |
epoch 600 | step 1804 | WD_test 0.15 | WD_train 0.21 | sec passed 15 |
epoch 700 | step 2104 | WD_test 0.18 | WD_train 0.23 | sec passed 14 |
epoch 800 | step 2404 | WD_test 0.02 | WD_train -0.01 | sec passed 14 |
epoch 900 | step 2704 | WD_test 0.17 | WD_train 0.3 | sec passed 15 |
epoch 1000 | step 3004 | WD_test 0.52 | WD_train 0.28 | sec passed 14 |
epoch 1100 | step 3304 | WD_test 0.21 | WD_train 0.33 | sec passed 15 |
epoch 1200 | step 3604 | WD_test 0.63 | WD_train 0.39 | sec passed 14 |
epoch 1300 | step 3904 | WD_test 0.22 | WD_train 0.3 | sec passed 14 |
epoch 1400 | step 4204 | WD_test -0.09 | WD_train 0.11 | sec passed 14 |
sett

epoch 1900 | step 5704 | WD_test -0.03 | WD_train -0.13 | sec passed 14 |
settings: {'optimizer': <class 'torch.optim.adam.Adam'>, 'critic_d_hidden': [128, 128, 128], 'critic_dropout': 0, 'critic_steps': 15, 'critic_lr': 0.001, 'critic_gp_factor': 5, 'generator_d_hidden': [128, 128, 128], 'generator_dropout': 0.1, 'generator_lr': 0.001, 'generator_d_noise': 1, 'generator_optimizer': 'optimizer', 'max_epochs': 2000, 'batch_size': 4096, 'test_set_size': 16, 'load_checkpoint': None, 'save_checkpoint': None, 'save_every': 100, 'print_every': 100, 'device': 'cuda'}
epoch 0 | step 4 | WD_test 0.06 | WD_train 0.02 | sec passed 0 |
epoch 100 | step 304 | WD_test 0.25 | WD_train 0.27 | sec passed 14 |
epoch 200 | step 604 | WD_test 0.11 | WD_train 0.1 | sec passed 14 |
epoch 300 | step 904 | WD_test 0.27 | WD_train 0.2 | sec passed 14 |
epoch 400 | step 1204 | WD_test 0.23 | WD_train 0.08 | sec passed 14 |
epoch 500 | step 1504 | WD_test 0.33 | WD_train 0.33 | sec passed 14 |
epoch 600 | step 1

epoch 0 | step 4 | WD_test 0.11 | WD_train 0.03 | sec passed 0 |
epoch 100 | step 304 | WD_test 0.22 | WD_train 0.25 | sec passed 14 |
epoch 200 | step 604 | WD_test 0.16 | WD_train 0.13 | sec passed 14 |
epoch 300 | step 904 | WD_test 0.44 | WD_train 0.31 | sec passed 14 |
epoch 400 | step 1204 | WD_test 0.23 | WD_train 0.24 | sec passed 14 |
epoch 500 | step 1504 | WD_test 0.36 | WD_train 0.16 | sec passed 14 |
epoch 600 | step 1804 | WD_test 0.08 | WD_train 0.08 | sec passed 14 |
epoch 700 | step 2104 | WD_test -0.02 | WD_train -0.01 | sec passed 14 |
epoch 800 | step 2404 | WD_test 0.39 | WD_train 0.26 | sec passed 14 |
epoch 900 | step 2704 | WD_test -0.07 | WD_train -0.08 | sec passed 14 |
epoch 1000 | step 3004 | WD_test 0.21 | WD_train 0.22 | sec passed 14 |
epoch 1100 | step 3304 | WD_test 0.33 | WD_train 0.18 | sec passed 14 |
epoch 1200 | step 3604 | WD_test 0.23 | WD_train 0.21 | sec passed 14 |
epoch 1300 | step 3904 | WD_test 0.19 | WD_train 0.09 | sec passed 14 |
epoch 1

epoch 300 | step 904 | WD_test -0.02 | WD_train 0.01 | sec passed 14 |
epoch 400 | step 1204 | WD_test -0.09 | WD_train 0.0 | sec passed 14 |
epoch 500 | step 1504 | WD_test 0.15 | WD_train 0.18 | sec passed 14 |
epoch 600 | step 1804 | WD_test 0.26 | WD_train 0.13 | sec passed 14 |
epoch 700 | step 2104 | WD_test -0.19 | WD_train -0.06 | sec passed 14 |
epoch 800 | step 2404 | WD_test 0.22 | WD_train 0.12 | sec passed 14 |
epoch 900 | step 2704 | WD_test -0.12 | WD_train -0.05 | sec passed 14 |
epoch 1000 | step 3004 | WD_test 0.11 | WD_train 0.22 | sec passed 14 |
epoch 1100 | step 3304 | WD_test 0.12 | WD_train 0.14 | sec passed 14 |
epoch 1200 | step 3604 | WD_test 0.3 | WD_train 0.3 | sec passed 14 |
epoch 1300 | step 3904 | WD_test -0.01 | WD_train 0.02 | sec passed 14 |
epoch 1400 | step 4204 | WD_test 0.41 | WD_train 0.28 | sec passed 14 |
epoch 1500 | step 4504 | WD_test 0.17 | WD_train 0.21 | sec passed 14 |
epoch 1600 | step 4804 | WD_test 0.33 | WD_train 0.23 | sec passed 1

epoch 2300 | step 6904 | WD_test 0.27 | WD_train 0.34 | sec passed 14 |
epoch 2400 | step 7204 | WD_test 0.38 | WD_train 0.32 | sec passed 14 |
settings: {'optimizer': <class 'torch.optim.adam.Adam'>, 'critic_d_hidden': [128, 128, 128], 'critic_dropout': 0, 'critic_steps': 15, 'critic_lr': 0.001, 'critic_gp_factor': 5, 'generator_d_hidden': [128, 128, 128], 'generator_dropout': 0.1, 'generator_lr': 0.001, 'generator_d_noise': 1, 'generator_optimizer': 'optimizer', 'max_epochs': 2500, 'batch_size': 4096, 'test_set_size': 16, 'load_checkpoint': None, 'save_checkpoint': None, 'save_every': 100, 'print_every': 100, 'device': 'cuda'}
epoch 0 | step 4 | WD_test 0.07 | WD_train 0.02 | sec passed 0 |
epoch 100 | step 304 | WD_test 0.32 | WD_train 0.27 | sec passed 14 |
epoch 200 | step 604 | WD_test 0.15 | WD_train 0.13 | sec passed 15 |
epoch 300 | step 904 | WD_test 0.21 | WD_train 0.16 | sec passed 14 |
epoch 400 | step 1204 | WD_test 0.12 | WD_train 0.07 | sec passed 15 |
epoch 500 | step 

epoch 1100 | step 3304 | WD_test 0.21 | WD_train 0.33 | sec passed 14 |
epoch 1200 | step 3604 | WD_test 0.63 | WD_train 0.39 | sec passed 14 |
epoch 1300 | step 3904 | WD_test 0.22 | WD_train 0.3 | sec passed 15 |
epoch 1400 | step 4204 | WD_test -0.09 | WD_train 0.11 | sec passed 15 |
epoch 1500 | step 4504 | WD_test 0.13 | WD_train 0.23 | sec passed 14 |
epoch 1600 | step 4804 | WD_test 0.07 | WD_train 0.19 | sec passed 15 |
epoch 1700 | step 5104 | WD_test 0.05 | WD_train 0.19 | sec passed 15 |
epoch 1800 | step 5404 | WD_test 0.27 | WD_train 0.23 | sec passed 14 |
epoch 1900 | step 5704 | WD_test 0.04 | WD_train 0.05 | sec passed 14 |
epoch 2000 | step 6004 | WD_test -0.04 | WD_train 0.12 | sec passed 14 |
epoch 2100 | step 6304 | WD_test 0.11 | WD_train -0.01 | sec passed 14 |
epoch 2200 | step 6604 | WD_test 0.1 | WD_train 0.04 | sec passed 15 |
epoch 2300 | step 6904 | WD_test 0.18 | WD_train 0.26 | sec passed 14 |
epoch 2400 | step 7204 | WD_test -0.06 | WD_train 0.06 | sec pa

epoch 2500 | step 7504 | WD_test 0.34 | WD_train 0.19 | sec passed 14 |
epoch 2600 | step 7804 | WD_test 0.28 | WD_train 0.31 | sec passed 15 |
epoch 2700 | step 8104 | WD_test 0.3 | WD_train 0.23 | sec passed 14 |
epoch 2800 | step 8404 | WD_test 0.0 | WD_train -0.04 | sec passed 14 |
epoch 2900 | step 8704 | WD_test 0.06 | WD_train 0.11 | sec passed 14 |
settings: {'optimizer': <class 'torch.optim.adam.Adam'>, 'critic_d_hidden': [128, 128, 128], 'critic_dropout': 0, 'critic_steps': 15, 'critic_lr': 0.001, 'critic_gp_factor': 5, 'generator_d_hidden': [128, 128, 128], 'generator_dropout': 0.1, 'generator_lr': 0.001, 'generator_d_noise': 1, 'generator_optimizer': 'optimizer', 'max_epochs': 3000, 'batch_size': 4096, 'test_set_size': 16, 'load_checkpoint': None, 'save_checkpoint': None, 'save_every': 100, 'print_every': 100, 'device': 'cuda'}
epoch 0 | step 4 | WD_test 0.07 | WD_train 0.02 | sec passed 0 |
epoch 100 | step 304 | WD_test 0.38 | WD_train 0.29 | sec passed 14 |
epoch 200 | s

epoch 2900 | step 8704 | WD_test 0.19 | WD_train 0.21 | sec passed 14 |
settings: {'optimizer': <class 'torch.optim.adam.Adam'>, 'critic_d_hidden': [128, 128, 128], 'critic_dropout': 0, 'critic_steps': 15, 'critic_lr': 0.001, 'critic_gp_factor': 5, 'generator_d_hidden': [128, 128, 128], 'generator_dropout': 0.1, 'generator_lr': 0.001, 'generator_d_noise': 1, 'generator_optimizer': 'optimizer', 'max_epochs': 3000, 'batch_size': 4096, 'test_set_size': 16, 'load_checkpoint': None, 'save_checkpoint': None, 'save_every': 100, 'print_every': 100, 'device': 'cuda'}
epoch 0 | step 4 | WD_test 0.07 | WD_train 0.02 | sec passed 0 |
epoch 100 | step 304 | WD_test 0.32 | WD_train 0.27 | sec passed 14 |
epoch 200 | step 604 | WD_test 0.15 | WD_train 0.13 | sec passed 14 |
epoch 300 | step 904 | WD_test 0.21 | WD_train 0.16 | sec passed 14 |
epoch 400 | step 1204 | WD_test 0.12 | WD_train 0.07 | sec passed 14 |
epoch 500 | step 1504 | WD_test 0.23 | WD_train 0.22 | sec passed 14 |
epoch 600 | step 1

epoch 0 | step 4 | WD_test 0.08 | WD_train 0.03 | sec passed 0 |
epoch 100 | step 304 | WD_test 0.43 | WD_train 0.26 | sec passed 14 |
epoch 200 | step 604 | WD_test 0.27 | WD_train 0.15 | sec passed 14 |
epoch 300 | step 904 | WD_test 0.05 | WD_train 0.11 | sec passed 14 |
epoch 400 | step 1204 | WD_test 0.34 | WD_train 0.14 | sec passed 14 |
epoch 500 | step 1504 | WD_test 0.26 | WD_train 0.27 | sec passed 14 |
epoch 600 | step 1804 | WD_test 0.15 | WD_train 0.21 | sec passed 14 |
epoch 700 | step 2104 | WD_test 0.18 | WD_train 0.23 | sec passed 14 |
epoch 800 | step 2404 | WD_test 0.02 | WD_train -0.01 | sec passed 14 |
epoch 900 | step 2704 | WD_test 0.17 | WD_train 0.3 | sec passed 14 |
epoch 1000 | step 3004 | WD_test 0.52 | WD_train 0.28 | sec passed 14 |
epoch 1100 | step 3304 | WD_test 0.21 | WD_train 0.33 | sec passed 14 |
epoch 1200 | step 3604 | WD_test 0.63 | WD_train 0.39 | sec passed 14 |
epoch 1300 | step 3904 | WD_test 0.22 | WD_train 0.3 | sec passed 14 |
epoch 1400 |

epoch 0 | step 4 | WD_test 0.07 | WD_train 0.02 | sec passed 0 |
epoch 100 | step 304 | WD_test 0.19 | WD_train 0.27 | sec passed 14 |
epoch 200 | step 604 | WD_test 0.14 | WD_train 0.18 | sec passed 14 |
epoch 300 | step 904 | WD_test -0.02 | WD_train 0.01 | sec passed 14 |
epoch 400 | step 1204 | WD_test -0.09 | WD_train 0.0 | sec passed 14 |
epoch 500 | step 1504 | WD_test 0.15 | WD_train 0.18 | sec passed 14 |
epoch 600 | step 1804 | WD_test 0.26 | WD_train 0.13 | sec passed 14 |
epoch 700 | step 2104 | WD_test -0.19 | WD_train -0.06 | sec passed 14 |
epoch 800 | step 2404 | WD_test 0.22 | WD_train 0.12 | sec passed 14 |
epoch 900 | step 2704 | WD_test -0.12 | WD_train -0.05 | sec passed 14 |
epoch 1000 | step 3004 | WD_test 0.11 | WD_train 0.22 | sec passed 14 |
epoch 1100 | step 3304 | WD_test 0.12 | WD_train 0.14 | sec passed 14 |
epoch 1200 | step 3604 | WD_test 0.3 | WD_train 0.3 | sec passed 14 |
epoch 1300 | step 3904 | WD_test -0.01 | WD_train 0.02 | sec passed 14 |
epoch 1

epoch 3100 | step 9304 | WD_test 0.25 | WD_train 0.18 | sec passed 14 |
epoch 3200 | step 9604 | WD_test -0.03 | WD_train 0.23 | sec passed 14 |
epoch 3300 | step 9904 | WD_test 0.34 | WD_train 0.23 | sec passed 14 |
epoch 3400 | step 10204 | WD_test 0.12 | WD_train 0.2 | sec passed 14 |
settings: {'optimizer': <class 'torch.optim.adam.Adam'>, 'critic_d_hidden': [128, 128, 128], 'critic_dropout': 0, 'critic_steps': 15, 'critic_lr': 0.001, 'critic_gp_factor': 5, 'generator_d_hidden': [128, 128, 128], 'generator_dropout': 0.1, 'generator_lr': 0.001, 'generator_d_noise': 1, 'generator_optimizer': 'optimizer', 'max_epochs': 3500, 'batch_size': 4096, 'test_set_size': 16, 'load_checkpoint': None, 'save_checkpoint': None, 'save_every': 100, 'print_every': 100, 'device': 'cuda'}
epoch 0 | step 4 | WD_test 0.06 | WD_train 0.02 | sec passed 0 |
epoch 100 | step 304 | WD_test 0.22 | WD_train 0.3 | sec passed 14 |
epoch 200 | step 604 | WD_test 0.02 | WD_train 0.1 | sec passed 14 |
epoch 300 | ste

epoch 2100 | step 6304 | WD_test 0.03 | WD_train 0.01 | sec passed 15 |
epoch 2200 | step 6604 | WD_test 0.24 | WD_train 0.26 | sec passed 14 |
epoch 2300 | step 6904 | WD_test -0.06 | WD_train -0.09 | sec passed 15 |
epoch 2400 | step 7204 | WD_test -0.04 | WD_train -0.04 | sec passed 14 |
epoch 2500 | step 7504 | WD_test -0.05 | WD_train -0.02 | sec passed 15 |
epoch 2600 | step 7804 | WD_test 0.15 | WD_train 0.09 | sec passed 15 |
epoch 2700 | step 8104 | WD_test 0.2 | WD_train 0.13 | sec passed 15 |
epoch 2800 | step 8404 | WD_test -0.19 | WD_train -0.09 | sec passed 14 |
epoch 2900 | step 8704 | WD_test 0.04 | WD_train 0.04 | sec passed 14 |
epoch 3000 | step 9004 | WD_test 0.12 | WD_train 0.11 | sec passed 15 |
epoch 3100 | step 9304 | WD_test 0.27 | WD_train 0.18 | sec passed 15 |
epoch 3200 | step 9604 | WD_test 0.29 | WD_train 0.25 | sec passed 15 |
epoch 3300 | step 9904 | WD_test 0.28 | WD_train 0.27 | sec passed 15 |
epoch 3400 | step 10204 | WD_test 0.18 | WD_train 0.23 | 

epoch 1000 | step 3004 | WD_test 0.36 | WD_train 0.26 | sec passed 14 |
epoch 1100 | step 3304 | WD_test 0.25 | WD_train 0.15 | sec passed 14 |
epoch 1200 | step 3604 | WD_test 0.19 | WD_train 0.04 | sec passed 14 |
epoch 1300 | step 3904 | WD_test -0.01 | WD_train 0.2 | sec passed 15 |
epoch 1400 | step 4204 | WD_test 0.35 | WD_train 0.2 | sec passed 14 |
epoch 1500 | step 4504 | WD_test 0.03 | WD_train 0.11 | sec passed 14 |
epoch 1600 | step 4804 | WD_test 0.36 | WD_train 0.16 | sec passed 14 |
epoch 1700 | step 5104 | WD_test -0.06 | WD_train 0.02 | sec passed 14 |
epoch 1800 | step 5404 | WD_test 0.0 | WD_train 0.09 | sec passed 14 |
epoch 1900 | step 5704 | WD_test 0.22 | WD_train 0.13 | sec passed 14 |
epoch 2000 | step 6004 | WD_test -0.08 | WD_train -0.06 | sec passed 14 |
epoch 2100 | step 6304 | WD_test 0.18 | WD_train 0.17 | sec passed 15 |
epoch 2200 | step 6604 | WD_test 0.31 | WD_train 0.2 | sec passed 15 |
epoch 2300 | step 6904 | WD_test 0.04 | WD_train 0.09 | sec pass

epoch 3600 | step 10804 | WD_test 0.24 | WD_train 0.23 | sec passed 14 |
epoch 3700 | step 11104 | WD_test 0.49 | WD_train 0.39 | sec passed 14 |
epoch 3800 | step 11404 | WD_test 0.09 | WD_train -0.03 | sec passed 14 |
epoch 3900 | step 11704 | WD_test 0.21 | WD_train 0.31 | sec passed 14 |
settings: {'optimizer': <class 'torch.optim.adam.Adam'>, 'critic_d_hidden': [128, 128, 128], 'critic_dropout': 0, 'critic_steps': 15, 'critic_lr': 0.001, 'critic_gp_factor': 5, 'generator_d_hidden': [128, 128, 128], 'generator_dropout': 0.1, 'generator_lr': 0.001, 'generator_d_noise': 1, 'generator_optimizer': 'optimizer', 'max_epochs': 4000, 'batch_size': 4096, 'test_set_size': 16, 'load_checkpoint': None, 'save_checkpoint': None, 'save_every': 100, 'print_every': 100, 'device': 'cuda'}
epoch 0 | step 4 | WD_test 0.07 | WD_train 0.02 | sec passed 0 |
epoch 100 | step 304 | WD_test 0.38 | WD_train 0.29 | sec passed 14 |
epoch 200 | step 604 | WD_test 0.17 | WD_train 0.16 | sec passed 14 |
epoch 300

epoch 1000 | step 3004 | WD_test 0.18 | WD_train 0.16 | sec passed 14 |
epoch 1100 | step 3304 | WD_test 0.32 | WD_train 0.23 | sec passed 14 |
epoch 1200 | step 3604 | WD_test 0.2 | WD_train 0.27 | sec passed 14 |
epoch 1300 | step 3904 | WD_test 0.37 | WD_train 0.31 | sec passed 14 |
epoch 1400 | step 4204 | WD_test 0.23 | WD_train 0.27 | sec passed 14 |
epoch 1500 | step 4504 | WD_test 0.25 | WD_train 0.24 | sec passed 14 |
epoch 1600 | step 4804 | WD_test 0.28 | WD_train 0.35 | sec passed 14 |
epoch 1700 | step 5104 | WD_test 0.37 | WD_train 0.29 | sec passed 14 |
epoch 1800 | step 5404 | WD_test 0.26 | WD_train 0.28 | sec passed 14 |
epoch 1900 | step 5704 | WD_test 0.31 | WD_train 0.22 | sec passed 14 |
epoch 2000 | step 6004 | WD_test 0.12 | WD_train 0.1 | sec passed 14 |
epoch 2100 | step 6304 | WD_test 0.01 | WD_train 0.07 | sec passed 14 |
epoch 2200 | step 6604 | WD_test 0.32 | WD_train 0.22 | sec passed 14 |
epoch 2300 | step 6904 | WD_test 0.27 | WD_train 0.34 | sec passed

epoch 3100 | step 9304 | WD_test 0.27 | WD_train 0.18 | sec passed 14 |
epoch 3200 | step 9604 | WD_test 0.29 | WD_train 0.25 | sec passed 14 |
epoch 3300 | step 9904 | WD_test 0.28 | WD_train 0.27 | sec passed 14 |
epoch 3400 | step 10204 | WD_test 0.18 | WD_train 0.23 | sec passed 14 |
epoch 3500 | step 10504 | WD_test 0.22 | WD_train 0.21 | sec passed 14 |
epoch 3600 | step 10804 | WD_test 0.15 | WD_train 0.17 | sec passed 14 |
epoch 3700 | step 11104 | WD_test 0.28 | WD_train 0.26 | sec passed 14 |
epoch 3800 | step 11404 | WD_test 0.35 | WD_train 0.34 | sec passed 14 |
epoch 3900 | step 11704 | WD_test 0.18 | WD_train 0.21 | sec passed 14 |
settings: {'optimizer': <class 'torch.optim.adam.Adam'>, 'critic_d_hidden': [128, 128, 128], 'critic_dropout': 0, 'critic_steps': 15, 'critic_lr': 0.001, 'critic_gp_factor': 5, 'generator_d_hidden': [128, 128, 128], 'generator_dropout': 0.1, 'generator_lr': 0.001, 'generator_d_noise': 1, 'generator_optimizer': 'optimizer', 'max_epochs': 4000, '

epoch 500 | step 1504 | WD_test 0.32 | WD_train 0.36 | sec passed 14 |
epoch 600 | step 1804 | WD_test 0.43 | WD_train 0.28 | sec passed 14 |
epoch 700 | step 2104 | WD_test 0.08 | WD_train 0.13 | sec passed 14 |
epoch 800 | step 2404 | WD_test 0.24 | WD_train 0.15 | sec passed 14 |
epoch 900 | step 2704 | WD_test 0.21 | WD_train 0.31 | sec passed 14 |
epoch 1000 | step 3004 | WD_test 0.36 | WD_train 0.26 | sec passed 14 |
epoch 1100 | step 3304 | WD_test 0.25 | WD_train 0.15 | sec passed 14 |
epoch 1200 | step 3604 | WD_test 0.19 | WD_train 0.04 | sec passed 14 |
epoch 1300 | step 3904 | WD_test -0.01 | WD_train 0.2 | sec passed 14 |
epoch 1400 | step 4204 | WD_test 0.35 | WD_train 0.2 | sec passed 14 |
epoch 1500 | step 4504 | WD_test 0.03 | WD_train 0.11 | sec passed 14 |
epoch 1600 | step 4804 | WD_test 0.36 | WD_train 0.16 | sec passed 14 |
epoch 1700 | step 5104 | WD_test -0.06 | WD_train 0.02 | sec passed 14 |
epoch 1800 | step 5404 | WD_test 0.0 | WD_train 0.09 | sec passed 14 

epoch 2100 | step 6304 | WD_test 0.07 | WD_train 0.01 | sec passed 14 |
epoch 2200 | step 6604 | WD_test 0.09 | WD_train 0.1 | sec passed 14 |
epoch 2300 | step 6904 | WD_test 0.35 | WD_train 0.19 | sec passed 14 |
epoch 2400 | step 7204 | WD_test 0.03 | WD_train 0.14 | sec passed 14 |
epoch 2500 | step 7504 | WD_test 0.34 | WD_train 0.19 | sec passed 14 |
epoch 2600 | step 7804 | WD_test 0.28 | WD_train 0.31 | sec passed 14 |
epoch 2700 | step 8104 | WD_test 0.3 | WD_train 0.23 | sec passed 14 |
epoch 2800 | step 8404 | WD_test 0.0 | WD_train -0.04 | sec passed 14 |
epoch 2900 | step 8704 | WD_test 0.06 | WD_train 0.11 | sec passed 14 |
epoch 3000 | step 9004 | WD_test 0.26 | WD_train 0.26 | sec passed 14 |
epoch 3100 | step 9304 | WD_test 0.11 | WD_train 0.18 | sec passed 14 |
epoch 3200 | step 9604 | WD_test 0.36 | WD_train 0.29 | sec passed 14 |
epoch 3300 | step 9904 | WD_test 0.24 | WD_train 0.32 | sec passed 14 |
epoch 3400 | step 10204 | WD_test 0.14 | WD_train 0.15 | sec passe

epoch 3200 | step 9604 | WD_test -0.03 | WD_train 0.23 | sec passed 14 |
epoch 3300 | step 9904 | WD_test 0.34 | WD_train 0.23 | sec passed 14 |
epoch 3400 | step 10204 | WD_test 0.12 | WD_train 0.2 | sec passed 14 |
epoch 3500 | step 10504 | WD_test 0.29 | WD_train 0.25 | sec passed 14 |
epoch 3600 | step 10804 | WD_test 0.22 | WD_train 0.26 | sec passed 14 |
epoch 3700 | step 11104 | WD_test 0.24 | WD_train 0.24 | sec passed 14 |
epoch 3800 | step 11404 | WD_test -0.02 | WD_train -0.03 | sec passed 14 |
epoch 3900 | step 11704 | WD_test 0.0 | WD_train 0.02 | sec passed 14 |
epoch 4000 | step 12004 | WD_test 0.16 | WD_train 0.19 | sec passed 14 |
epoch 4100 | step 12304 | WD_test 0.29 | WD_train 0.3 | sec passed 14 |
epoch 4200 | step 12604 | WD_test 0.14 | WD_train 0.19 | sec passed 14 |
epoch 4300 | step 12904 | WD_test 0.03 | WD_train 0.02 | sec passed 14 |
epoch 4400 | step 13204 | WD_test -0.04 | WD_train 0.03 | sec passed 14 |
settings: {'optimizer': <class 'torch.optim.adam.Ada

epoch 4300 | step 12904 | WD_test 0.2 | WD_train 0.14 | sec passed 14 |
epoch 4400 | step 13204 | WD_test 0.1 | WD_train 0.07 | sec passed 14 |
settings: {'optimizer': <class 'torch.optim.adam.Adam'>, 'critic_d_hidden': [128, 128, 128], 'critic_dropout': 0, 'critic_steps': 15, 'critic_lr': 0.001, 'critic_gp_factor': 5, 'generator_d_hidden': [128, 128, 128], 'generator_dropout': 0.1, 'generator_lr': 0.001, 'generator_d_noise': 1, 'generator_optimizer': 'optimizer', 'max_epochs': 4500, 'batch_size': 4096, 'test_set_size': 16, 'load_checkpoint': None, 'save_checkpoint': None, 'save_every': 100, 'print_every': 100, 'device': 'cuda'}
epoch 0 | step 4 | WD_test 0.06 | WD_train 0.02 | sec passed 0 |
epoch 100 | step 304 | WD_test 0.14 | WD_train 0.22 | sec passed 14 |
epoch 200 | step 604 | WD_test 0.3 | WD_train 0.21 | sec passed 14 |
epoch 300 | step 904 | WD_test 0.23 | WD_train 0.13 | sec passed 14 |
epoch 400 | step 1204 | WD_test 0.03 | WD_train 0.04 | sec passed 14 |
epoch 500 | step 1

epoch 200 | step 604 | WD_test 0.27 | WD_train 0.15 | sec passed 14 |
epoch 300 | step 904 | WD_test 0.05 | WD_train 0.11 | sec passed 14 |
epoch 400 | step 1204 | WD_test 0.34 | WD_train 0.14 | sec passed 14 |
epoch 500 | step 1504 | WD_test 0.26 | WD_train 0.27 | sec passed 14 |
epoch 600 | step 1804 | WD_test 0.15 | WD_train 0.21 | sec passed 14 |
epoch 700 | step 2104 | WD_test 0.18 | WD_train 0.23 | sec passed 14 |
epoch 800 | step 2404 | WD_test 0.02 | WD_train -0.01 | sec passed 14 |
epoch 900 | step 2704 | WD_test 0.17 | WD_train 0.3 | sec passed 14 |
epoch 1000 | step 3004 | WD_test 0.52 | WD_train 0.28 | sec passed 14 |
epoch 1100 | step 3304 | WD_test 0.21 | WD_train 0.33 | sec passed 14 |
epoch 1200 | step 3604 | WD_test 0.63 | WD_train 0.39 | sec passed 14 |
epoch 1300 | step 3904 | WD_test 0.22 | WD_train 0.3 | sec passed 14 |
epoch 1400 | step 4204 | WD_test -0.09 | WD_train 0.11 | sec passed 14 |
epoch 1500 | step 4504 | WD_test 0.13 | WD_train 0.23 | sec passed 14 |
ep

epoch 1300 | step 3904 | WD_test 0.09 | WD_train 0.09 | sec passed 14 |
epoch 1400 | step 4204 | WD_test 0.16 | WD_train 0.15 | sec passed 14 |
epoch 1500 | step 4504 | WD_test 0.13 | WD_train 0.21 | sec passed 14 |
epoch 1600 | step 4804 | WD_test 0.26 | WD_train 0.28 | sec passed 14 |
epoch 1700 | step 5104 | WD_test 0.2 | WD_train 0.23 | sec passed 14 |
epoch 1800 | step 5404 | WD_test 0.05 | WD_train 0.04 | sec passed 14 |
epoch 1900 | step 5704 | WD_test 0.01 | WD_train -0.03 | sec passed 14 |
epoch 2000 | step 6004 | WD_test 0.1 | WD_train 0.14 | sec passed 14 |
epoch 2100 | step 6304 | WD_test 0.26 | WD_train 0.26 | sec passed 14 |
epoch 2200 | step 6604 | WD_test 0.26 | WD_train 0.36 | sec passed 14 |
epoch 2300 | step 6904 | WD_test 0.36 | WD_train 0.37 | sec passed 14 |
epoch 2400 | step 7204 | WD_test 0.27 | WD_train 0.24 | sec passed 14 |
epoch 2500 | step 7504 | WD_test -0.07 | WD_train 0.01 | sec passed 14 |
epoch 2600 | step 7804 | WD_test 0.03 | WD_train 0.03 | sec pass

epoch 1400 | step 4204 | WD_test 0.11 | WD_train 0.05 | sec passed 14 |
epoch 1500 | step 4504 | WD_test 0.19 | WD_train 0.22 | sec passed 14 |
epoch 1600 | step 4804 | WD_test 0.34 | WD_train 0.25 | sec passed 14 |
epoch 1700 | step 5104 | WD_test 0.29 | WD_train 0.29 | sec passed 14 |
epoch 1800 | step 5404 | WD_test 0.37 | WD_train 0.29 | sec passed 14 |
epoch 1900 | step 5704 | WD_test -0.03 | WD_train -0.13 | sec passed 14 |
epoch 2000 | step 6004 | WD_test 0.11 | WD_train 0.07 | sec passed 14 |
epoch 2100 | step 6304 | WD_test 0.02 | WD_train -0.03 | sec passed 14 |
epoch 2200 | step 6604 | WD_test 0.02 | WD_train 0.05 | sec passed 14 |
epoch 2300 | step 6904 | WD_test 0.32 | WD_train 0.2 | sec passed 14 |
epoch 2400 | step 7204 | WD_test 0.22 | WD_train 0.25 | sec passed 14 |
epoch 2500 | step 7504 | WD_test 0.29 | WD_train 0.23 | sec passed 14 |
epoch 2600 | step 7804 | WD_test 0.25 | WD_train 0.25 | sec passed 14 |
epoch 2700 | step 8104 | WD_test -0.08 | WD_train -0.07 | sec 

epoch 1400 | step 4204 | WD_test 0.23 | WD_train 0.27 | sec passed 14 |
epoch 1500 | step 4504 | WD_test 0.25 | WD_train 0.24 | sec passed 14 |
epoch 1600 | step 4804 | WD_test 0.28 | WD_train 0.35 | sec passed 14 |
epoch 1700 | step 5104 | WD_test 0.37 | WD_train 0.29 | sec passed 14 |
epoch 1800 | step 5404 | WD_test 0.26 | WD_train 0.28 | sec passed 14 |
epoch 1900 | step 5704 | WD_test 0.31 | WD_train 0.22 | sec passed 14 |
epoch 2000 | step 6004 | WD_test 0.12 | WD_train 0.1 | sec passed 14 |
epoch 2100 | step 6304 | WD_test 0.01 | WD_train 0.07 | sec passed 14 |
epoch 2200 | step 6604 | WD_test 0.32 | WD_train 0.22 | sec passed 14 |
epoch 2300 | step 6904 | WD_test 0.27 | WD_train 0.34 | sec passed 14 |
epoch 2400 | step 7204 | WD_test 0.38 | WD_train 0.32 | sec passed 14 |
epoch 2500 | step 7504 | WD_test 0.16 | WD_train 0.27 | sec passed 14 |
epoch 2600 | step 7804 | WD_test 0.21 | WD_train 0.22 | sec passed 14 |
epoch 2700 | step 8104 | WD_test 0.23 | WD_train 0.28 | sec passe

epoch 1500 | step 4504 | WD_test 0.13 | WD_train 0.09 | sec passed 13 |
epoch 1600 | step 4804 | WD_test 0.29 | WD_train 0.28 | sec passed 14 |
epoch 1700 | step 5104 | WD_test 0.45 | WD_train 0.43 | sec passed 13 |
epoch 1800 | step 5404 | WD_test 0.24 | WD_train 0.25 | sec passed 13 |
epoch 1900 | step 5704 | WD_test 0.17 | WD_train 0.15 | sec passed 13 |
epoch 2000 | step 6004 | WD_test -0.05 | WD_train -0.04 | sec passed 14 |
epoch 2100 | step 6304 | WD_test 0.03 | WD_train 0.01 | sec passed 13 |
epoch 2200 | step 6604 | WD_test 0.24 | WD_train 0.26 | sec passed 14 |
epoch 2300 | step 6904 | WD_test -0.06 | WD_train -0.09 | sec passed 13 |
epoch 2400 | step 7204 | WD_test -0.04 | WD_train -0.04 | sec passed 13 |
epoch 2500 | step 7504 | WD_test -0.05 | WD_train -0.02 | sec passed 14 |
epoch 2600 | step 7804 | WD_test 0.15 | WD_train 0.09 | sec passed 13 |
epoch 2700 | step 8104 | WD_test 0.2 | WD_train 0.13 | sec passed 13 |
epoch 2800 | step 8404 | WD_test -0.19 | WD_train -0.09 |

epoch 1500 | step 4504 | WD_test 0.13 | WD_train 0.23 | sec passed 13 |
epoch 1600 | step 4804 | WD_test 0.07 | WD_train 0.19 | sec passed 13 |
epoch 1700 | step 5104 | WD_test 0.05 | WD_train 0.19 | sec passed 13 |
epoch 1800 | step 5404 | WD_test 0.27 | WD_train 0.23 | sec passed 13 |
epoch 1900 | step 5704 | WD_test 0.04 | WD_train 0.05 | sec passed 13 |
epoch 2000 | step 6004 | WD_test -0.04 | WD_train 0.12 | sec passed 13 |
epoch 2100 | step 6304 | WD_test 0.11 | WD_train -0.01 | sec passed 13 |
epoch 2200 | step 6604 | WD_test 0.1 | WD_train 0.04 | sec passed 13 |
epoch 2300 | step 6904 | WD_test 0.18 | WD_train 0.26 | sec passed 14 |
epoch 2400 | step 7204 | WD_test -0.06 | WD_train 0.06 | sec passed 13 |
epoch 2500 | step 7504 | WD_test 0.33 | WD_train 0.11 | sec passed 13 |
epoch 2600 | step 7804 | WD_test 0.18 | WD_train 0.33 | sec passed 13 |
epoch 2700 | step 8104 | WD_test -0.05 | WD_train 0.21 | sec passed 13 |
epoch 2800 | step 8404 | WD_test 0.09 | WD_train -0.02 | sec 

In [13]:
# Wrap the results in a DataFrame whose columns are the max_epochs values
# (rows stay indexed by replication / seed).
WassersteinDistance_CWGANGP = pd.DataFrame(WassersteinDistance_CWGANGP, columns=maxepochs)

In [14]:
# Persist the raw (replication x max_epochs) table without the row index.
# Create ./data first so the export also works on a fresh checkout.
os.makedirs("./data", exist_ok=True)
WassersteinDistance_CWGANGP.to_csv("./data/WD_cwgan_TuningParameter_withseed.csv", index=False)

In [15]:
# Full table of Wasserstein distances (rows: replication/seed, columns: max_epochs).
WassersteinDistance_CWGANGP

Unnamed: 0,500,1000,1500,2000,2500,3000,3500,4000,4500,5000
0,1.224198,2.439083,1.430486,1.964579,0.275304,2.145307,1.507189,1.448433,1.297706,0.974502
1,1.741514,1.585201,1.721675,2.716508,1.538448,2.008328,1.925015,1.469689,1.498391,2.735931
2,1.833043,2.042612,1.714863,2.236112,1.622907,1.150506,1.193637,1.573209,2.181145,1.344073
3,2.325753,2.82479,1.79517,1.026342,1.50667,2.512768,1.819695,1.515152,0.776181,1.37044
4,1.206534,1.628821,2.103635,1.944554,1.887524,1.808949,0.361087,1.377977,1.211224,1.839382
5,1.754818,2.271663,1.056011,1.406072,2.548562,2.889568,1.656299,1.979775,1.549695,2.091786
6,1.582954,2.395302,1.327418,1.434709,1.849436,1.048748,1.418538,1.531695,0.765725,2.021665
7,1.069857,2.207497,2.011763,1.732724,1.877773,1.119309,1.594584,0.850435,0.590467,1.558846
8,2.056705,2.164168,2.064521,1.513348,0.79109,1.666136,1.250796,1.742591,0.991607,2.22546
9,2.695049,2.11742,0.893147,1.053456,1.752615,1.681097,1.540248,1.394782,2.024744,1.185604


In [16]:
# Mean Wasserstein distance over the replications for each max_epochs value
# (smaller = generated F closer to the training F).
WassersteinDistance_CWGANGP.mean(axis="index")

500     1.749042
1000    2.167656
1500    1.611869
2000    1.702840
2500    1.565033
3000    1.803072
3500    1.426709
4000    1.488374
4500    1.288689
5000    1.734769
dtype: float64

In [17]:
# Sample standard deviation (ddof=1) of the Wasserstein distance across replications
# for each max_epochs value.
WassersteinDistance_CWGANGP.std(axis="index", ddof=1)

500     0.515103
1000    0.367800
1500    0.421934
2000    0.528620
2500    0.628496
3000    0.609138
3500    0.437657
4000    0.287856
4500    0.534204
5000    0.543623
dtype: float64