In [1]:
import torch
import econml
import pandas as pd
import numpy as np 
from functools import partial
import matplotlib.pyplot as plt
import os
# Permit more than one OpenMP runtime in the process. The variable name must be
# spelled exactly KMP_DUPLICATE_LIB_OK; a misspelled key is silently ignored.
# NOTE(review): this is usually set before importing the libraries that load
# OpenMP (torch / numpy) - consider moving it to the top of the cell.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

from causal_cocycle.model import cocycle_model,flow_model,cocycle_outcome_model
from causal_cocycle.optimise import *
from causal_cocycle.loss_functions import Loss
from causal_cocycle.conditioners import Empty_Conditioner,Constant_Conditioner,Lin_Conditioner,NN_RELU_Conditioner
from causal_cocycle.transformers import Transformer,Shift_layer,Scale_layer,RQS_layer,Inverse_layer
from causal_cocycle.helper_functions import likelihood_loss,mmd,propensity_score
from causal_cocycle.kernels import *
from causal_cocycle.kde import *

In [2]:
# Importing data
def _read_stacked(paths):
    """Read header-less, two-space-separated text files and stack them row-wise.

    paths : iterable of file names, read in the given order.
    Returns a single DataFrame with a fresh 0..n-1 index.
    """
    # A separator longer than one character is interpreted as a regular
    # expression, which only the python parser supports. Naming the engine
    # explicitly keeps the parsing identical while avoiding the ParserWarning
    # pandas emits when it has to fall back from the C engine on its own.
    frames = [pd.read_csv(path, sep = "  ", header = None, engine = "python")
              for path in paths]
    return pd.concat(frames, ignore_index=True)

psid = _read_stacked(['psid_controls.txt',
                      'psid2_controls.txt'])
nsw = _read_stacked(['nsw_control.txt',
                     'nsw_treated.txt'])

# The psid files carry one more column (re74) than the nsw files.
psid_columns = ["treat","age","educ","black","hispan","married","nodegree","re74","re75","re78"]
nsw_columns = ["treat","age","educ","black","hispan","married","nodegree","re75","re78"]

psid.columns = psid_columns
nsw.columns = nsw_columns

  psid = pd.concat(
  psid = pd.concat(
  nsw = pd.concat(
  nsw = pd.concat(


In [3]:
# Getting empirical ATE
X = torch.tensor(nsw['treat'].values)
Y = torch.tensor(nsw['re78'].values)

# Outcomes split by treatment indicator
Y_treated = Y[X==1]
Y_control = Y[X==0]

# Difference in group means
ATE = Y_treated.mean()-Y_control.mean()
# Share of (treated, control) pairs in which the treated outcome is at least as large
Prob_increase = (Y_treated[:,None] >= Y_control[None]).float().mean()

In [4]:
# Constructing data to train on
SHUFFLE_SEED = 0  # fixed so the shuffled row order is the same on every run
feature_columns = nsw_columns[:-1]  # every nsw column except the last one ('re78')

def _stack_column(name):
    """Return column `name` of psid followed by that of nsw as one 1-D tensor."""
    return torch.cat((torch.tensor(psid[name].values),
                      torch.tensor(nsw[name].values)))

N = len(psid)+len(nsw)
D = len(feature_columns)
# One tensor column per feature, psid rows first, cast to float32
Xtrain = torch.stack([_stack_column(name) for name in feature_columns], dim=1).float()
Ytrain = _stack_column('re78')[:,None].float()

# Shuffling data
# A dedicated, seeded generator makes the permutation reproducible without
# touching the global torch RNG state used elsewhere.
shuffle_rng = torch.Generator().manual_seed(SHUFFLE_SEED)
shuffled_inds = torch.randperm(N, generator=shuffle_rng)
Xtrain = Xtrain[shuffled_inds]
Ytrain = Ytrain[shuffled_inds]

In [5]:
# Method + opt set up

# Loss / validation choices
cocycle_loss = "CMMD_U"
validation_method = "CV"
choose_best_model = "overall"
train_val_split = 0.8

# Network shape
layers = 2
width = 64
RQS_bins = 8

# Optimiser settings
batch_size = 64
val_batch_size = N
learn_rate = [1e-3]
scheduler = True
maxiter = 10000
miniter = 10000
weight_decay = 1e-3

# Setting training optimiser args
# Names and values are kept in one mapping so they cannot drift out of step;
# the two parallel lists are derived from it in matching order.
optimiser_settings = {
    "learn_rate": learn_rate,
    "scheduler": scheduler,
    "batch_size": batch_size,
    "maxiter": maxiter,
    "miniter": miniter,
    "weight_decay": weight_decay,
    "print_": True,
    "val_batch_size": val_batch_size,
}
opt_args = list(optimiser_settings.keys())
opt_argvals = list(optimiser_settings.values())

# No extra hyperparameters are searched over
hyper = []
hyper_val = []

#Shorthand function calls
def NN(i,o=1,width=128,layers=2):
    """Shorthand constructor for NN_RELU_Conditioner with bias enabled.

    i      : forwarded as input_dims
    o      : forwarded as output_dims (default 1)
    width  : forwarded as width (default 128)
    layers : forwarded as layers (default 2)
    """
    conditioner_kwargs = dict(input_dims = i,
                              output_dims = o,
                              width = width,
                              layers = layers,
                              bias = True)
    return NN_RELU_Conditioner(**conditioner_kwargs)


In [6]:
# Specifying models for cross-validation
# One entry per candidate model. A neural-net candidate could be added as
# [NN(D,1,width,layers)] paired with a second Transformer([Shift_layer()]).
conditioners_list = [
    [Lin_Conditioner(D,1)],
]
transformers_list = [
    Transformer([Shift_layer()]),
]

# Pair each conditioner list with its transformer
models_validation = [
    cocycle_model(conditioner, transformer)
    for conditioner, transformer in zip(conditioners_list, transformers_list)
]

# Every candidate gets the same hyperparameter names / values
n_candidates = len(conditioners_list)
hyper_args = [hyper]*n_candidates
hyper_argvals = [hyper_val]*n_candidates

In [7]:
# Scaling data
# Standardise with the overall mean / std for Y and per-column mean / std for X.
Ymu,Ysc =  Ytrain.mean(),Ytrain.var()**0.5
Xmu,Xsc =  Xtrain.mean(0),Xtrain.var(0)**0.5
Yscale = (Ytrain - Ymu)/Ysc
Xscale = (Xtrain - Xmu)/Xsc

def _kernel_pair():
    """Return two independent gaussian_kernel instances.

    `[kernel]*2` would place the *same* object in both slots, so any state set
    on one slot would silently apply to the other.
    NOTE(review): presumably median_heuristic stores a separate lengthscale on
    each kernel - confirm against the Loss implementation.
    """
    return [gaussian_kernel(torch.ones(1),1) for _ in range(2)]

# Getting loss function (using CMMD_V as scalable for validation)
loss_fn =  Loss(loss_fn = cocycle_loss,kernel = _kernel_pair())
loss_fn_val =  Loss(loss_fn = "CMMD_V",kernel = _kernel_pair())
loss_fn.median_heuristic(Xscale,Yscale,subsamples = 10**3)
loss_fn_val.median_heuristic(Xscale,Yscale,subsamples = 10**3)

  batch_inds = torch.tensor([np.random.choice(ind_list,subsamples)]).long().view(subsamples,)


In [8]:
# Cross-validation
# Positional arguments for validate(), in the order it is called with.
validation_inputs = (
    models_validation,   # candidate models
    loss_fn,             # training loss
    Xscale,              # standardised inputs
    Yscale,              # standardised outputs
    loss_fn_val,         # validation loss
    validation_method,
    train_val_split,
    opt_args,            # optimiser argument names ...
    opt_argvals,         # ... and their values
    hyper_args,          # hyperparameter names per model ...
    hyper_argvals,       # ... and their values
    choose_best_model,
)
final_models,val_losses = validate(*validation_inputs)

Training loss last 10 avg is : tensor(-0.6778)
99.9  % completion
Currently optimising model  0 , for fold  4


In [12]:
# Inspect the fitted parameters of the first conditioner of the first returned model
fitted_conditioner = final_models[0].conditioner[0]
fitted_conditioner.state_dict()

OrderedDict([('stack.0.weight',
              tensor([[-0.0118, -0.0275,  0.1165, -0.0555, -0.0024,  0.0742,  0.0048,  0.6574]])),
             ('stack.0.bias', tensor([4.6041e-08]))])

In [10]:
# Two copies of the covariates that differ only in column 0,
# which is set to 1 in one copy and 0 in the other.
X1 = Xtrain*1
X0 = Xtrain*1
X1[:,0] = 1
X0[:,0] = 0

# Apply the same standardisation as the training inputs
X1scale = (X1 - Xmu)/Xsc
X0scale = (X0 - Xmu)/Xsc

# Push an all-zero outcome through the cocycle and measure how far it moves
zero_outcome = Ytrain*0
effect = final_models[0].cocycle(X1scale,X0scale,zero_outcome)-zero_outcome

# Multiply by Ysc to undo the outcome scaling before averaging
print((effect*Ysc).mean())

tensor(-651.1846, grad_fn=<MeanBackward0>)


In [45]:
# Work on the first 100 scaled outcomes only
Ysub = Yscale[:100]
# All pairwise differences, with a trailing singleton axis appended
Dists = (Ysub - Ysub.T).unsqueeze(-1)
# The same outcomes with a singleton axis inserted in the middle
Dists_pred = Ysub[:,None,:]

In [63]:
kernel = gaussian_kernel(torch.ones(1),1)

# Reference value computed on the full arrays in one go
K1 = kernel.get_gram(Dists,Dists).mean()
K2 = -2*kernel.get_gram(Dists_pred,Dists).mean()
print(K1+K2)

# Same quantity accumulated over consecutive row-batches, for comparison
n_sub = 100  # number of rows in Dists / Dists_pred
batchsize = max(1,min(n_sub,int(10**5/n_sub**2)))
print(batchsize)
nbatch = int(n_sub/batchsize)
K = 0
for start in range(0, nbatch*batchsize, batchsize):
    rows = slice(start, start+batchsize)
    K += kernel.get_gram(Dists[rows],Dists[rows]).sum()/n_sub**3
    K += -2*kernel.get_gram(Dists_pred[rows],Dists[rows]).sum()/n_sub**2
print(K)

tensor(-0.8380)
10
tensor(-0.8380)
