# Extract the LaLonde dataset 

We want to extract the LaLonde (PSID) dataset that was used to generate the RealCause `lalonde_psid` datasets in the `realcause_datasets` folder. This is the same base dataset that we will use to train the Credence model.

In [1]:
import time
from pathlib import Path

import numpy as np
import pandas as pd
from numpy.testing import assert_approx_equal

from loading import load_from_folder, load_realcause_dataset

from data.lalonde import load_lalonde
from data.twins import load_twins
from consts import REALCAUSE_DATASETS_FOLDER, N_SAMPLE_SEEDS, N_AGG_SEEDS

In [4]:
# Load the RealCause generative model for LaLonde PSID-1 together with `args`
# (presumably its training arguments, cf. the Namespace printed below).
psid_gen_model, args = load_from_folder(dataset='lalonde_psid1')
# Load the underlying observational LaLonde PSID data as pandas objects:
# psid_w holds the covariates; psid_t / psid_y presumably hold treatment and outcome.
psid_w, psid_t, psid_y = load_lalonde(obs_version='psid', data_format='pandas')

Namespace(activation='ReLU', atoms=[0.0], batch_size=25000, comet=True, data='lalonde', dataroot='/home/mila/r/raghupas/causal-benchmark/datasets', dim_h=4, dist='SigmoidFlow', dist_args=['ndim=2', 'base_distribution=uniform'], early_stop=True, eval=True, grad_norm=inf, ignore_w=False, kernel_t='RBFKernel', kernel_y='RBFKernel', lr=0.01, model_type='tarnet', n_hidden_layers=1, num_epochs=10000, num_tasks=32, num_univariate_tests=100, overwrite_reload='', patience=None, saveroot='save', seed=123, test_prop=0.4, test_size=None, train=True, train_prop=0.5, val_prop=0.1, var_dist='MeanFieldVariationalDistribution', w_transform='Normalize', y_transform='Normalize')
2024-05-30 20:04:36.186790 / Namespace(activation='ReLU', atoms=[0.0], batch_size=25000, comet=False, data='lalonde', dataroot='./datasets', dim_h=4, dist='SigmoidFlow', dist_args=['ndim=2', 'base_distribution=uniform'], early_stop=True, eval=False, grad_norm=inf, ignore_w=False, kernel_t='RBFKernel', kernel_y='RBFKernel', lr=0.0

In [5]:
# Preview the PSID covariates (pandas truncates long frames to head/tail rows).
psid_w

Unnamed: 0,age,education,black,hispanic,married,nodegree,re74,re75
0,37.0,11.0,1.0,0.0,1.0,1.0,0.000000,0.000000
1,22.0,9.0,0.0,1.0,0.0,1.0,0.000000,0.000000
2,30.0,12.0,1.0,0.0,0.0,0.0,0.000000,0.000000
3,27.0,11.0,1.0,0.0,0.0,1.0,0.000000,0.000000
4,33.0,8.0,1.0,0.0,0.0,1.0,0.000000,0.000000
...,...,...,...,...,...,...,...,...
2485,47.0,8.0,0.0,0.0,1.0,1.0,44667.363281,33837.097656
2486,32.0,8.0,0.0,0.0,1.0,1.0,47022.402344,67137.093750
2487,47.0,10.0,0.0,0.0,1.0,1.0,48197.964844,47968.113281
2488,54.0,0.0,0.0,1.0,1.0,1.0,49228.539062,44220.968750


In [6]:
# ATE implied by the generative model; noisy=True presumably samples outcome noise
# (TODO: confirm semantics against the RealCause model API).
psid_gen_model.ate(noisy=True)

-13195.036332417525

In [7]:
# Same ATE computed with noisy=False.
# NOTE(review): this evaluates to exactly 0.0 while the noisy ATE above is about
# -13195 -- confirm this gap is expected for this model before relying on it.
psid_gen_model.ate(noisy=False)



0.0

In [8]:
# Show which model class was loaded (a TarNet, matching model_type='tarnet' in args).
psid_gen_model

<models.tarnet.TarNet at 0x7f55d53c1fd0>

This is the dataset that was used to generate the `lalonde_psid` datasets with RealCause. We now store it so that it can be reused for Credence.

In [13]:
# Store the LaLonde PSID dataset so it can be reused to train Credence.

# Join covariates (psid_w), treatment (psid_t) and outcome (psid_y) column-wise.
# pd.concat(axis=1) outer-joins on the index, so misaligned indexes would
# silently add NaN-padded rows -- guard against that before writing.
psid = pd.concat([psid_w, psid_t, psid_y], axis=1)
assert len(psid) == len(psid_w) == len(psid_t) == len(psid_y), (
    "index mismatch between psid_w, psid_t and psid_y"
)

psid_csv = Path('pba_data') / 'psid.csv'
# to_csv does not create missing parent directories.
psid_csv.parent.mkdir(parents=True, exist_ok=True)
# The index is dropped on purpose: only the data columns are persisted.
psid.to_csv(psid_csv, index=False)

In [7]:
# Testing the realcause package for basic functionality: load one of the
# pre-generated `lalonde_psid` RealCause datasets.
# NOTE(review): the second argument (1) presumably selects which generated
# dataset/sample to load -- confirm against `loading.load_realcause_dataset`.
df = load_realcause_dataset('lalonde_psid', 1)

In [8]:
# Preview the generated dataset: psid covariates plus t, y, y0, y1 and ite columns.
df

Unnamed: 0,age,education,black,hispanic,married,nodegree,re74,re75,t,y,y0,y1,ite
0,37.0,11.0,1.0,0.0,1.0,1.0,0.000,0.000,0.0,5532.554,5532.554,0.0000,-5532.554
1,22.0,9.0,0.0,1.0,0.0,1.0,0.000,0.000,0.0,0.000,0.000,4797.9730,4797.973
2,30.0,12.0,1.0,0.0,0.0,0.0,0.000,0.000,0.0,0.000,0.000,0.0000,0.000
3,27.0,11.0,1.0,0.0,0.0,1.0,0.000,0.000,0.0,1667.217,1667.217,0.0000,-1667.217
4,33.0,8.0,1.0,0.0,0.0,1.0,0.000,0.000,0.0,0.000,0.000,3085.3630,3085.363
...,...,...,...,...,...,...,...,...,...,...,...,...,...
2670,47.0,8.0,0.0,0.0,1.0,1.0,44667.363,33837.098,0.0,43386.914,43386.914,6671.3926,-36715.523
2671,32.0,8.0,0.0,0.0,1.0,1.0,47022.402,67137.090,0.0,71678.370,71678.370,0.0000,-71678.370
2672,47.0,10.0,0.0,0.0,1.0,1.0,48197.965,47968.113,0.0,50136.547,50136.547,7171.6323,-42964.914
2673,54.0,0.0,0.0,1.0,1.0,1.0,49228.540,44220.970,0.0,38690.840,38690.840,14053.9480,-24636.890


In [9]:
# (rows, columns) of the generated dataset.
df.shape

(2675, 13)

In [10]:
# Column names: the covariates from psid_w followed by t, y, y0, y1 and ite.
df.columns

Index(['age', 'education', 'black', 'hispanic', 'married', 'nodegree', 're74',
       're75', 't', 'y', 'y0', 'y1', 'ite'],
      dtype='object')