- 0317: update the net structure
- 0318: update the bandwidth setting
- 0322: update the tuning params
- 0325: introduce a small network to simultaniously update the tuning parameter
- 0427: add a new cov term in the loss
- 0430: remove cov term and add a new regularization term on the weights
- 0510: increase the batch size
- 0512: new x regularity term
- 0512v2: new x regularity term + weight regularity term
- 0515: scaling and standardization
- 0515: do not scale

In [1]:
import os
import itertools
from sklearn.preprocessing import StandardScaler

In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sb
import pandas as pd
from tqdm import tqdm

from datagen import *
from network import *
from scipy.stats import ks_2samp

from joblib import Parallel, delayed

In [3]:
def expand_grid(data_dict):
    """Create a dataframe from every combination of given values."""
    rows = itertools.product(*data_dict.values())
    return pd.DataFrame.from_records(rows, columns=data_dict.keys())

In [4]:
param_dict = {
    'data_path': ['./save/simu_data/scenario1/'],
    'batch_size': [512],
    'lr':[1e-3],
    'pa':[0.1],
    'val_metric':['KS'],
    'x_lambda':[0.01,0.1,1],
    'wt_lambda':[0,0.1,1],
    'num_nodes': [512],
    'num_iters': [5000],
    'num_init_iters':[500],
    'patience':[10],
    'kernel_params':[{'kernel':'rbf',
                     'gamma':1,
                     'degree':2,
                     'c':1},
                    {'kernel':'rbf',
                     'gamma':5,
                     'degree':2,
                     'c':1},
                    {'kernel':'rbf',
                     'gamma':10,
                     'degree':2,
                     'c':1}],
    'random_state': [0]
}

In [5]:
param_df = expand_grid(param_dict)

In [6]:
param_df

Unnamed: 0,data_path,batch_size,lr,pa,val_metric,x_lambda,wt_lambda,num_nodes,num_iters,num_init_iters,patience,kernel_params,random_state
0,./save/simu_data/scenario1/,512,0.001,0.1,KS,0.01,0.0,512,5000,500,10,"{'kernel': 'rbf', 'gamma': 1, 'degree': 2, 'c'...",0
1,./save/simu_data/scenario1/,512,0.001,0.1,KS,0.01,0.0,512,5000,500,10,"{'kernel': 'rbf', 'gamma': 5, 'degree': 2, 'c'...",0
2,./save/simu_data/scenario1/,512,0.001,0.1,KS,0.01,0.0,512,5000,500,10,"{'kernel': 'rbf', 'gamma': 10, 'degree': 2, 'c...",0
3,./save/simu_data/scenario1/,512,0.001,0.1,KS,0.01,0.1,512,5000,500,10,"{'kernel': 'rbf', 'gamma': 1, 'degree': 2, 'c'...",0
4,./save/simu_data/scenario1/,512,0.001,0.1,KS,0.01,0.1,512,5000,500,10,"{'kernel': 'rbf', 'gamma': 5, 'degree': 2, 'c'...",0
5,./save/simu_data/scenario1/,512,0.001,0.1,KS,0.01,0.1,512,5000,500,10,"{'kernel': 'rbf', 'gamma': 10, 'degree': 2, 'c...",0
6,./save/simu_data/scenario1/,512,0.001,0.1,KS,0.01,1.0,512,5000,500,10,"{'kernel': 'rbf', 'gamma': 1, 'degree': 2, 'c'...",0
7,./save/simu_data/scenario1/,512,0.001,0.1,KS,0.01,1.0,512,5000,500,10,"{'kernel': 'rbf', 'gamma': 5, 'degree': 2, 'c'...",0
8,./save/simu_data/scenario1/,512,0.001,0.1,KS,0.01,1.0,512,5000,500,10,"{'kernel': 'rbf', 'gamma': 10, 'degree': 2, 'c...",0
9,./save/simu_data/scenario1/,512,0.001,0.1,KS,0.1,0.0,512,5000,500,10,"{'kernel': 'rbf', 'gamma': 1, 'degree': 2, 'c'...",0


In [7]:
def allocation_test(mdist_obs,mdist_array):
    return (np.sum(mdist_array>=mdist_obs)+1)/(mdist_array.shape[0]+1)

In [8]:
def parallel_unit(i,batch_size,
                  lr,pa,
                  x_lambda,
                  wt_lambda,
                  num_nodes,num_iters,
                  num_init_iters,
                  kernel_params,val_metric,
                  patience,random_state,
                  data_path):
  
    print('------------- Data:',i,'------------- ')

    # load the data
    data_full_path = data_path + 'd' + str(i) + '.npy'
    dat = np.load(data_full_path,allow_pickle=True)

    # unzip the data
    x = dat.item()['x']
    z = dat.item()['z']
    y1 = dat.item()['y1']
    y2 = dat.item()['y2']
    y3 = dat.item()['y3']
    
    # scale the covariate data
    sc = StandardScaler()
    x = sc.fit_transform(x)
    
    nt = int(z.sum())

    # only use the mean difference loss
    save_folder = './save/0515qrer_tuning/param_grid='+str(i_param)+'/'+str(i)+'/'

    if not os.path.exists(save_folder):
        print('Creat the folder.')
        os.makedirs(save_folder)

    if not os.path.exists(save_folder+'final_checkpoint.pt'):
        print('Train the model from scratch.')

        estimator = QRWG(lr=lr,
                          batch_size=batch_size,
                          patience=patience,
                          num_iters=num_iters,
                          num_init_iters=num_init_iters,
                          pa=pa,
                          x_lambda=x_lambda,
                          wt_lambda=wt_lambda,
                          num_nodes=num_nodes,
                          val_metric=val_metric,
                          save_folder=save_folder,
                          kernel_params=kernel_params,
                          verbose=False,
                          random_state=random_state)

        # train the model from scratch
        estimator.fit(x,z)

    else:
        print('Skip! The model has been trained.')
        estimator = QRWG(lr=lr,
                          batch_size=batch_size,
                          patience=patience,
                          num_iters=num_iters,
                          num_init_iters=1,
                          pa=pa,
                          x_lambda=x_lambda,
                          wt_lambda=wt_lambda,
                          num_nodes=num_nodes,
                          val_metric=val_metric,
                          save_folder=save_folder,
                          kernel_params=kernel_params,
                          verbose=False,
                          random_state=random_state)
        estimator.w = z
        estimator.nwts = int(estimator.w.shape[0])
        estimator.nt = int(z.sum())
        estimator.nc = int((1-z).sum())
        estimator._init_network()
        estimator.netG.load_state_dict(torch.load(save_folder+'final_checkpoint.pt'))
    
    # generate acceptable weights
#     wts_mat_net_fea = wts_mat_net[test_array>0.05]
#     while wts_mat_net_fea.shape[0]<1000:
#         wts_mat_net_tmp = estimator.predict(1000).numpy()
#         test_array_tmp = np.array([allocation_test(maha_dist(x,z,wts_mat_net_tmp[i]).item(),mdist_array) for i in range(1000)])
#         wts_mat_net_fea = np.concatenate([wts_mat_net_fea,wts_mat_net_tmp[test_array_tmp>0.05]],axis=0)
#     wts_mat_net_fea = wts_mat_net_fea[:1000]
    np.random.seed(i)
    torch.manual_seed(i)

    wts_mat_net = estimator.predict().numpy()

    z_rer_mat = np.array([ReR(pa,torch.Tensor(x),np.sum(z))[0].numpy() for i in range(1000)])
    z_rer = ReR(pa,torch.Tensor(x),np.sum(z))[0].numpy()
    mdist_array = np.array([ReR(pa,torch.Tensor(x),np.sum(z))[1].item() for i in range(1000)])

    wts_mat_net = estimator.predict(1000).numpy()

    mdiff_mat_net = np.array([cov_mdiff(x,z,wts_mat_net[i]) for i in range(1000)])
    mdiff_mat_rer = np.array([cov_mdiff(x,z_rer_mat[i]) for i in range(1000)])
    xmdiff_ks, xmdiff_pval = np.array([ks_2samp(mdiff_mat_net[:,i],
                                                mdiff_mat_rer[:,i]) for i in range(mdiff_mat_net.shape[1])]).mean(axis=0)

    test_array = np.array([allocation_test(maha_dist(x,z,wts_mat_net[i]).item(),mdist_array) for i in range(1000)])
    accept_ratio = np.mean(test_array>0.15)
    
    
    if not os.path.exists(save_folder+'qrwg_est.csv'):
        
        # compare different strategies
        # strategy 1:
        avg_wts = wts_mat_net.mean(axis=0)

        est1_s1 = tau_diff(y1,z,avg_wts)
        est2_s1 = tau_diff(y2,z,avg_wts)
        est3_s1 = tau_diff(y3,z,avg_wts)

        ci1_s1 = ri_ci(y1,z,est1_s1,z_rer_mat)
        ci2_s1 = ri_ci(y2,z,est2_s1,z_rer_mat)
        ci3_s1 = ri_ci(y3,z,est3_s1,z_rer_mat)

        df_est_s1 = pd.DataFrame({
          'tauhat': [est1_s1,est2_s1,est3_s1],
          "95CI_lb": [ci1_s1[0],ci2_s1[0],ci3_s1[0]],
          "95CI_ub": [ci1_s1[1],ci2_s1[1],ci3_s1[1]],
          'type': 'S1'
        })

        # strategy 2:
        wts = wts_mat_net[0]

        est1_s2 = tau_diff(y1,z,wts)
        est2_s2 = tau_diff(y2,z,wts)
        est3_s2 = tau_diff(y3,z,wts)

        ci1_s2 = ri_ci(y1,z,est1_s2,z_rer_mat)
        ci2_s2 = ri_ci(y2,z,est2_s2,z_rer_mat)
        ci3_s2 = ri_ci(y3,z,est3_s2,z_rer_mat)

        df_est_s2 = pd.DataFrame({
          'tauhat': [est1_s2,est2_s2,est3_s2],
          "95CI_lb": [ci1_s2[0],ci2_s2[0],ci3_s2[0]],
          "95CI_ub": [ci1_s2[1],ci2_s2[1],ci3_s2[1]],
          'type': 'S2'
        })
        
        # strategy 3:
        # wts = wts_mat_net[np.cumsum(test_array>0.05)==1][0]
#         try:
#             wts = wts_mat_net[np.cumsum(test_array>0.05)==1][0]
#         except:
#             wts_mat_net_tmp = estimator.predict(1000).numpy()
#             test_array_tmp = np.array([allocation_test(maha_dist(x,z,wts_mat_net_tmp[i]).item(),mdist_array) for i in range(1000)])
#             while np.mean(test_array_tmp>0.05)==0:
#                 wts_mat_net_tmp = estimator.predict(1000).numpy()
#                 test_array_tmp = np.array([allocation_test(maha_dist(x,z,wts_mat_net_tmp[i]).item(),mdist_array) for i in range(1000)])
#             wts = wts_mat_net_tmp[np.cumsum(test_array_tmp>0.05)==1][0]
        
#         est1_s3 = tau_diff(y1,z,wts)
#         est2_s3 = tau_diff(y2,z,wts)
#         est3_s3 = tau_diff(y3,z,wts)

#         ci1_s3 = ri_ci(y1,z,est1_s3,z_rer_mat)
#         ci2_s3 = ri_ci(y2,z,est2_s3,z_rer_mat)
#         ci3_s3 = ri_ci(y3,z,est3_s3,z_rer_mat)

#         df_est_s3 = pd.DataFrame({
#           'tauhat': [est1_s3,est2_s3,est3_s3],
#           "95CI_lb": [ci1_s3[0],ci2_s3[0],ci3_s3[0]],
#           "95CI_ub": [ci1_s3[1],ci2_s3[1],ci3_s3[1]],
#           'type': 'S3'
#         })
        
        
#        df_est = pd.concat([df_est_s1,df_est_s2,df_est_s3],axis=0)
        df_est = pd.concat([df_est_s1,df_est_s2],axis=0)
        df_est.to_csv(save_folder+"qrwg_est.csv",index=False)
    else:
        print('Skip! QRWG has been considered')

    return pd.read_csv(save_folder+"qrwg_est.csv").values,xmdiff_ks, xmdiff_pval, accept_ratio

In [9]:
n_kernel = 20
n_data = 5
tau = 1

In [10]:
result_df_list = []

In [11]:
for i_param in range(param_df.shape[0]):
#for i_param in range(2):
    kwargs = dict(param_df.iloc[i_param,:])
    print('----------------- [%d/%d] -----------------\n'%(i_param+1,param_df.shape[0]))
    results = Parallel(n_jobs=n_kernel)(delayed(parallel_unit)(i=i,**kwargs) for i in tqdm(range(n_data)))
    
    dat_array = np.array([results[i][0] for i in range(n_data)])
    bias = dat_array[:,:,0].mean(axis=0)-tau
    rmse = np.sqrt(np.mean((dat_array[:,:,0]-tau)**2,axis=0).astype(float))
    covarage = ((dat_array[:,:,1]<=tau)*(dat_array[:,:,2]>=tau)).mean(axis=0)
    width = (dat_array[:,:,2] - dat_array[:,:,1]).mean(axis=0)

    ks_array = np.array([results[i][1] for i in range(n_data)])
    kspval_array = np.array([results[i][2] for i in range(n_data)])
    accept_array = np.array([results[i][3] for i in range(n_data)])

    result_dict = {#'val_metric':[param_df.iloc[i_param,4]],
                  #'patience':[param_df.iloc[i_param,:]['patience']],
                  'gamma':[param_df.iloc[i_param,:]['kernel_params']['gamma']],
                  #'batch_size':[param_df.iloc[i_param,:]['batch_size']],
                  #'num_init_iters':[param_df.iloc[i_param,:]['num_init_iters']],
                  'x_lambda':[param_df.iloc[i_param,:]['x_lambda']],
                  'wt_lambda':[param_df.iloc[i_param,:]['wt_lambda']],
                  'bias':[bias.mean()],
                  'rmse':[rmse.mean()],
                  'covarage':[covarage.mean()],
                  'width':[width.mean()],
                  'ks':[np.median(ks_array)],#[ks_array.mean()],
                  'pval':[np.median(kspval_array)],#[kspval_array.mean()],
                  'accept':[accept_array.mean()]}
    result_df = pd.DataFrame(result_dict)
    result_df_list.append(result_df)

100%|██████████| 5/5 [00:00<00:00, 455.34it/s]

----------------- [1/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3499.92it/s]

----------------- [2/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 2464.05it/s]

----------------- [3/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3169.34it/s]

----------------- [4/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3205.18it/s]

----------------- [5/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 2724.28it/s]

----------------- [6/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 2228.41it/s]

----------------- [7/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 2085.06it/s]

----------------- [8/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3281.41it/s]

----------------- [9/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3029.69it/s]

----------------- [10/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3368.38it/s]

----------------- [11/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3815.78it/s]

----------------- [12/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3632.69it/s]

----------------- [13/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3220.44it/s]

----------------- [14/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3694.77it/s]

----------------- [15/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 5500.01it/s]

----------------- [16/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3610.18it/s]

----------------- [17/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 4401.16it/s]

----------------- [18/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3344.74it/s]

----------------- [19/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3640.89it/s]

----------------- [20/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3230.86it/s]

----------------- [21/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3750.27it/s]

----------------- [22/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3249.89it/s]

----------------- [23/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3298.45it/s]

----------------- [24/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3293.27it/s]

----------------- [25/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3629.55it/s]

----------------- [26/27] -----------------




100%|██████████| 5/5 [00:00<00:00, 3242.85it/s]

----------------- [27/27] -----------------






In [12]:
sum_mat = pd.concat(result_df_list,axis=0)
sum_mat

Unnamed: 0,gamma,x_lambda,wt_lambda,bias,rmse,covarage,width,ks,pval,accept
0,1,0.01,0.0,-0.303748,0.601719,0.933333,2.821667,0.103125,0.02430152,0.6998
0,5,0.01,0.0,-0.289439,0.54023,1.0,2.818667,0.089375,0.1380911,0.6022
0,10,0.01,0.0,-0.056413,0.482371,1.0,2.816333,0.089625,0.1086702,0.589
0,1,0.01,0.1,-0.220064,0.590878,0.933333,2.819,0.101,0.02092593,0.7256
0,5,0.01,0.1,-0.213526,0.491302,1.0,2.817667,0.0955,0.07111794,0.594
0,10,0.01,0.1,-0.169284,0.541586,0.966667,2.818333,0.088625,0.07527584,0.5696
0,1,0.01,1.0,-0.087251,0.536448,1.0,2.815333,0.1285,0.02068785,0.8064
0,5,0.01,1.0,-0.085987,0.547295,0.966667,2.817333,0.092375,0.1131856,0.6124
0,10,0.01,1.0,-0.017788,0.535872,1.0,2.816,0.08,0.1871091,0.5948
0,1,0.1,0.0,-0.269067,0.605245,0.933333,2.817667,0.108125,0.02722999,0.7776


In [13]:
sum_mat.to_csv('./save/tuning0515.csv')