In [1]:
import os

In [2]:
os.chdir("/content/drive/MyDrive/Github/QRWG")
print(os.getcwd())

FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/Github/QRWG'

In [3]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sb
import pandas as pd

from datagen import *
from network import *

In [4]:
# data path
data_path = './save/simu_data/'

# network training 
batch_size = 256
lr = 1e-3
pa = 0.05#np.inf
init_method = 'uniform'
num_nodes = 128
num_iters = 1500
patience = 5
random_state = 0

In [5]:
i = 0

In [6]:
# load the data
data_full_path = data_path + 'd' + str(i) + '.npy'
dat = np.load(data_full_path,allow_pickle=True)

In [7]:
# unzip the data
x = dat.item()['x']
z = dat.item()['z']
y1 = dat.item()['y1']
y2 = dat.item()['y2']
y3 = dat.item()['y3']

In [8]:
save_folder = './save/qrwg_pa='+str(pa)+'/'+str(i)+'/'

In [9]:
if not os.path.exists(save_folder):
    print('Creat the folder.')
    os.makedirs(save_folder)
    
if not os.path.exists(save_folder+'final_checkpoint.pt'):
    print('Train the model from scratch.')
    
    estimator = QRWG(lr=lr,
                     batch_size=batch_size,
                     patience=patience,
                     num_iters=num_iters,
                     pa=pa,
                     init_method=init_method,
                     num_nodes=num_nodes,
                     save_folder=save_folder,
                     verbose=True,
                     random_state=random_state)
    
    # train the model from scratch
    estimator.fit(x,z)
    
    # save the loss and other plots for diagnosis

    # loss
    plt.figure()
    plt.plot(estimator.losses[100:],label='Total Loss')
    plt.plot(estimator.qqlosses[100:],label='QQLoss')
    plt.plot(estimator.mdifflosses[100:],label='MdiffLoss')
    plt.legend()
    plt.savefig(save_folder+"loss_trace.png")
    plt.show()

    # mean difference
    plt.figure()
    plt.plot(estimator.x_mdiff_list)
    plt.savefig(save_folder+"mdiff_trace.png")
    plt.show()

    # ks difference
    plt.figure()
    plt.plot(estimator.ks_list)
    plt.savefig(save_folder+"ks_trace.png")
    plt.show()

    # val_ks 
    plt.figure()
    plt.plot(estimator.val_ks_list)
    plt.savefig(save_folder+"val_ks_trace.png")
    plt.show()
    
    # save covariance balance
    wts_mat_net = estimator.predict().numpy()
    z_mat_rer = np.array([ReR(pa,torch.Tensor(x))[0].numpy() for i in range(1000)])

    mdiff_mat_net = np.array([cov_mdiff(x,z,wts_mat_net[i]) for i in range(wts_mat_net.shape[0])])
    mdiff_mat_rer = np.array([cov_mdiff(x,z_mat_rer[i]) for i in range(z_mat_rer.shape[0])])

    df_mdiff_net = pd.DataFrame(mdiff_mat_net,columns=['X'+str(i+1) for i in range(mdiff_mat_net.shape[1])])
    df_mdiff_net['Method'] = 'QRWG'
    df_mdiff_rer = pd.DataFrame(mdiff_mat_rer,columns=['X'+str(i+1) for i in range(mdiff_mat_rer.shape[1])])
    df_mdiff_rer['Method'] = 'ReR'

    df_mdiff = pd.concat([df_mdiff_net,df_mdiff_rer],axis=0)
    df_mdiff = pd.melt(df_mdiff,id_vars=['Method'],var_name=['Variable'],value_name='Value')

    plt.figure(figsize=(10,6))
    sb.boxplot(x='Variable',y='Value',
               hue="Method",data=df_mdiff)
    plt.ylabel('$\\bar{x}_T-\\bar{x}_C$',fontsize=12)
    plt.xlabel('Covariates',fontsize=12)
    plt.savefig(save_folder+"covbalance.pdf")
    plt.show()
    
    # save inferences results
    tau1_vec_net = np.array([tau_diff(y1,z,wts_mat_net[i]) for i in range(wts_mat_net.shape[0])])
    tau2_vec_net = np.array([tau_diff(y2,z,wts_mat_net[i]) for i in range(wts_mat_net.shape[0])])
    tau3_vec_net = np.array([tau_diff(y3,z,wts_mat_net[i]) for i in range(wts_mat_net.shape[0])])

    df_est = pd.DataFrame({
        'tauhat': [tau1_vec_net.mean(),tau2_vec_net.mean(),tau3_vec_net.mean()],
        '95CI_lb': [np.quantile(tau1_vec_net,0.025),np.quantile(tau2_vec_net,0.025),np.quantile(tau3_vec_net,0.025)],
        '95CI_ub': [np.quantile(tau1_vec_net,0.975),np.quantile(tau2_vec_net,0.975),np.quantile(tau3_vec_net,0.975)]
    })

    df_est.to_csv(save_folder+"tau_est.csv",index=False)
    
else:
    
    print('Skip! The model has been trained.')
    
#         # load from existing model
#         estimator.self._init_data()
#         estimator.self._init_network()
#         estimator.netG.load_state_dict(torch.load(save_folder+'best_checkpoint.pt'))

Train the model from scratch.


KeyboardInterrupt: 