In [1]:
#!/usr/bin/env python
# coding: utf-8

import sys
import pandas as pd 
import numpy as np
import torch
import time

In [2]:
import os
import sys

# Root of the generative_quantile project, which provides the `_nets` and
# `_data` packages imported below. Override with the GQ_PROJECT_ROOT
# environment variable on other machines; the default is the original
# author's checkout (directory name spelled "qunatile" as in the original path).
PROJECT_ROOT = os.environ.get(
    "GQ_PROJECT_ROOT",
    "/home/kim2712/Desktop/research/generative_quantile/generative_qunatile",
)
# Guard so re-running this cell does not push duplicate entries onto sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from _nets import wgan2
from _data.gaussian_conjugate import forward_sampler

In [3]:
# Simulation sizes.
# NOTE(review): the original line read `1000#00`, i.e. 100_000 with the last
# two zeros commented out — presumably a quick-run setting; raise for full runs.
nABC = 1000     # number of simulated (theta, X) pairs (batch_size of forward_sampler)
N_y = 2         # passed as `n` to forward_sampler; also the number of X columns
theta_dim = 2   # number of theta columns

# Seed numpy and torch so Restart & Run All reproduces the same simulated data
# (whichever of the two RNGs forward_sampler and training draw from).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# Hyper-parameters forwarded to forward_sampler as `h_param`.
HPARAM = {"nu": 25, "sigma0_sq": 1, "mu0": 0, "kappa": 2}
theta_seq, X_seq = forward_sampler(
    n=N_y,
    batch_size=nABC,
    device="cuda",
    h_param=HPARAM,
)

In [4]:
# Column names: theta1..theta{theta_dim} for parameters, X1..X{N_y} for data.
theta_names = [f"theta{i}" for i in range(1, theta_dim + 1)]
X_names = [f"X{i}" for i in range(1, N_y + 1)]

# Fail loudly if the sampler's output widths disagree with the declared
# dimensions. pd.DataFrame only checks the *total* column count, so e.g.
# (1 theta col + 3 X cols) would otherwise be silently mislabelled.
assert theta_seq.shape[-1] == theta_dim, (tuple(theta_seq.shape), theta_dim)
assert X_seq.shape[-1] == N_y, (tuple(X_seq.shape), N_y)

df = pd.DataFrame(
    data=np.concatenate((theta_seq, X_seq), axis=-1),
    columns=theta_names + X_names,
)

In [5]:
import os

# NOTE(review): not used by any cell shown here; kept in case later cells
# read it (presumably a checkpoint to resume from).
checkpoint_path = None
checkpoint_dir = "./"
# os.path.join avoids the doubled separator the f-string produced
# ('.//checkpoint.ck', as seen in the printed Specifications settings).
save_checkpoint = os.path.join(checkpoint_dir, "checkpoint.ck")

In [6]:
import pathlib

# Create the checkpoint directory (and any missing parents) up front;
# an already-existing directory is not an error.
ckpt_dir = pathlib.Path(checkpoint_dir)
ckpt_dir.mkdir(exist_ok=True, parents=True)


In [8]:
# theta columns go in as continuous_vars (the columns apply_generator later
# fills in) and X columns as context_vars (presumably the conditioning inputs).
data_wrapper = wgan2.DataWrapper(
    df,
    continuous_vars=theta_names,
    context_vars=X_names,
)

# Training configuration. load_checkpoint is left at its default
# (None, per the printed settings), so training starts from scratch.
spec = wgan2.Specifications(
    data_wrapper,
    batch_size=1280,
    max_epochs=2020,
    critic_lr=1e-3,
    generator_lr=1e-3,
    print_every=10,
    device="cuda",
    save_checkpoint=save_checkpoint,
    save_every=10,
)

generator = wgan2.Generator(spec)
critic = wgan2.Critic(spec)

# Split the frame into the generated-variable part and the context part
# consumed by wgan2.train.
thetas, Xs = data_wrapper.preprocess(df)

settings: {'optimizer': <class 'torch.optim.adam.Adam'>, 'activation': 'relu', 'critic_d_hidden': [128, 128, 128], 'critic_dropout': 0, 'critic_steps': 15, 'critic_lr': 0.001, 'critic_gp_factor': 5, 'discriminator_d_hidden': [128, 128, 128], 'discriminator_dropout': 0.1, 'discriminator_steps': 1, 'discriminator_lr': 0.0001, 'generator_d_hidden': [128, 128, 128], 'generator_dropout': 0.1, 'generator_lr': 0.001, 'generator_d_noise': 2, 'generator_optimizer': 'optimizer', 'big_Z': False, 'max_epochs': 2020, 'batch_size': 1280, 'test_set_size': 16, 'load_checkpoint': None, 'save_checkpoint': './/checkpoint.ck', 'save_every': 10, 'print_every': 10, 'history_path': None, 'device': 'cuda'}


In [9]:
# Train the generator/critic pair on the simulated (theta, X) data.
# Progress is printed every `print_every` epochs (log below). Checkpoints are
# presumably written to `save_checkpoint` every `save_every` epochs — the next
# cell loads that file. NOTE(review): this is the long-running cell.
wgan2.train(generator, critic, thetas, Xs, spec)

epoch 0 | step 2 | WD_test 0.06 | WD_train 0.01 | sec passed 3 |
epoch 10 | step 12 | WD_test 0.62 | WD_train 0.42 | sec passed 0 |
epoch 20 | step 22 | WD_test 1.04 | WD_train 0.9 | sec passed 0 |
epoch 30 | step 32 | WD_test 1.06 | WD_train 0.95 | sec passed 0 |
epoch 40 | step 42 | WD_test 1.11 | WD_train 1.03 | sec passed 0 |
epoch 50 | step 52 | WD_test 1.13 | WD_train 1.04 | sec passed 0 |
epoch 60 | step 62 | WD_test 1.12 | WD_train 1.06 | sec passed 0 |
epoch 70 | step 72 | WD_test 1.15 | WD_train 1.07 | sec passed 0 |
epoch 80 | step 82 | WD_test 1.12 | WD_train 1.06 | sec passed 0 |
epoch 90 | step 92 | WD_test 1.08 | WD_train 1.03 | sec passed 0 |
epoch 100 | step 102 | WD_test 1.11 | WD_train 1.04 | sec passed 0 |
epoch 110 | step 112 | WD_test 1.07 | WD_train 1.02 | sec passed 0 |
epoch 120 | step 122 | WD_test 1.03 | WD_train 0.97 | sec passed 0 |
epoch 130 | step 132 | WD_test 0.97 | WD_train 0.98 | sec passed 0 |
epoch 140 | step 142 | WD_test 1.0 | WD_train 0.93 | sec 

epoch 1200 | step 1202 | WD_test 0.5 | WD_train 0.16 | sec passed 0 |
epoch 1210 | step 1212 | WD_test 0.42 | WD_train 0.12 | sec passed 0 |
epoch 1220 | step 1222 | WD_test 0.39 | WD_train 0.08 | sec passed 0 |
epoch 1230 | step 1232 | WD_test 0.46 | WD_train 0.14 | sec passed 0 |
epoch 1240 | step 1242 | WD_test 0.33 | WD_train 0.17 | sec passed 0 |
epoch 1250 | step 1252 | WD_test 0.35 | WD_train 0.18 | sec passed 0 |
epoch 1260 | step 1262 | WD_test 0.2 | WD_train 0.2 | sec passed 0 |
epoch 1270 | step 1272 | WD_test 0.22 | WD_train 0.18 | sec passed 0 |
epoch 1280 | step 1282 | WD_test 0.41 | WD_train 0.16 | sec passed 0 |
epoch 1290 | step 1292 | WD_test 0.05 | WD_train 0.14 | sec passed 0 |
epoch 1300 | step 1302 | WD_test 0.49 | WD_train 0.17 | sec passed 0 |
epoch 1310 | step 1312 | WD_test 0.07 | WD_train 0.12 | sec passed 0 |
epoch 1320 | step 1322 | WD_test 0.28 | WD_train 0.12 | sec passed 0 |
epoch 1330 | step 1332 | WD_test 0.07 | WD_train 0.15 | sec passed 0 |
epoch 134

In [10]:
# Restore generator weights from the checkpoint written during training.
# map_location="cpu" lets the file load on machines without a GPU;
# load_state_dict then copies the tensors onto whatever device `generator`
# already lives on.
# NOTE: torch.load unpickles the file — only load checkpoints you produced.
ck = torch.load(save_checkpoint, map_location="cpu")
print(ck['epoch'])
# Only the generator is restored: the sampling cells below never use the critic.
generator.load_state_dict(ck["generator_state_dict"])

2010


<All keys matched successfully>

In [11]:
# Build an input frame asking the generator for n_repeat theta samples
# conditioned on one fixed observation X0 = (M, ..., M).
n_repeat = 50000   # number of theta samples to generate
M = 2.2            # value used for every coordinate of the observation X0

# One observation of width N_y (previously hard-coded as [M, M], i.e. N_y == 2).
X0 = np.full((1, N_y), M)
X0_tile = np.repeat(X0, n_repeat, axis=0)

# Standard-normal values fill the theta columns, one per theta dimension
# (previously hard-coded to 2). Seeded so reruns are reproducible.
# NOTE(review): confirm whether apply_generator uses these values or replaces
# them with its own noise; torch is seeded too in case it draws internally.
SAMPLE_SEED = 0
rng = np.random.default_rng(SAMPLE_SEED)
torch.manual_seed(SAMPLE_SEED)
Z0 = rng.normal(size=(n_repeat, theta_dim))

df0 = pd.DataFrame(
    data=np.concatenate((Z0, X0_tile), axis=-1),
    columns=theta_names + X_names,
)

In [12]:
# Run the trained generator over df0 and keep just the theta columns;
# the frame is left as the cell's last expression so Jupyter renders it.
theta0_hat = data_wrapper.apply_generator(generator, df0)
newtheta0_seq = theta0_hat.loc[:, theta_names]
newtheta0_seq

Unnamed: 0,theta1,theta2
0,1.078532,1.192933
1,1.018022,1.076689
2,1.192513,0.748294
3,1.230789,1.207675
4,0.616087,1.492202
...,...,...
49995,0.789278,1.243984
49996,0.248442,1.364525
49997,0.842347,1.146536
49998,0.565942,0.767835
