In [4]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormal
from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.observational.handlers.cut import SingleStageCut
from pyro.infer import Predictive

# Turn on Pyro's `module_local_params` setting for the whole session.
# NOTE(review): presumably this keeps PyroModule parameters on the module
# instead of in the global param store — confirm against the Pyro settings docs.
pyro.settings.set(module_local_params=True)

# Project-local helper; instantiated further down to load the gdp/population features.
from cities.utils.data_grabber import DataGrabber

# Use seaborn's "white" style for figures.
sns.set_style("white")

pyro.set_rng_seed(321) # fixed seed for reproducibility

In [27]:
# Fetch the California Proposition 99 panel from the synthdid authors' GitHub
# repository (semicolon-separated CSV).
DATA_URL = "https://raw.githubusercontent.com/synth-inference/synthdid/master/data/california_prop99.csv"
data = pd.read_csv(DATA_URL, sep=";")

# Integer-encode every state and every year, and expose the outcome under the
# column name "y", which is what the model further down expects.
data = data.assign(
    unit_index=pd.factorize(data["State"].to_numpy())[0],
    time_index=pd.factorize(data["Year"].to_numpy())[0],
    y=data["PacksPerCapita"].to_numpy(copy=True),
)

# A state is in the treatment group when at least one of its rows has
# treated == 1; the flag is then set on all rows of that state (0 otherwise).
treated_units = data.loc[data["treated"] == 1, "State"].unique()
data["in_treatment_group"] = data["State"].isin(treated_units).astype("int64")

# Not the last expression of the cell, so this preview is not rendered.
data.head()

# Rendered as the cell output: the column dtypes after preparation.
data.dtypes

State                  object
Year                    int64
PacksPerCapita        float64
treated                 int64
unit_index              int64
time_index              int64
y                     float64
in_treatment_group      int64
dtype: object

In [53]:
# Build a long-format panel of GDP and population from the project's DataGrabber.
dg = DataGrabber()
dg.get_features_long(["gdp", 'population'])  # TODO: consider the standardized ("stdized") versions

# Join the two long tables on their identifying columns. Both carry a "Value"
# column, so the merge suffixes them _x / _y; rename them to meaningful names.
analysis_data = pd.merge(
    dg.long['gdp'],
    dg.long['population'],
    on=['GeoFIPS', 'GeoName', 'Year'],
).rename(columns={'Value_x': 'gdp', 'Value_y': 'population'})

# Integer ids per unit (GeoFIPS) and per year, plus the response coded as "y".
analysis_data['unit_index'] = pd.factorize(analysis_data['GeoFIPS'].values)[0]
analysis_data['time_index'] = pd.factorize(analysis_data['Year'].values)[0]
analysis_data['y'] = analysis_data['gdp'].values.copy()

# Some absurd pseudo-intervention for now: a unit counts as "treated" when its
# population exceeded the threshold in the pseudo-treatment year, which marks
# roughly 31% of the units (see the printed ratio).
# NOTE(review): the variable name and the original comment say "states", but
# the GeoName values in the preview look like counties — confirm the unit.
PSEUDO_TREATMENT_YEAR = 2015
POPULATION_THRESHOLD = 50000

# Single .loc with a combined mask instead of chained indexing with a mask
# built on a differently-indexed frame; selects exactly the same rows.
above_threshold_in_treatment_year = (
    (analysis_data['Year'] == PSEUDO_TREATMENT_YEAR)
    & (analysis_data['population'] > POPULATION_THRESHOLD)
)
states_above_50000 = analysis_data.loc[above_threshold_in_treatment_year, 'GeoFIPS']
print(len(states_above_50000.unique())/len(analysis_data['GeoFIPS'].unique()))

# NOTE(review): the flag is set only on rows from the pseudo-treatment year
# onward, whereas the California cell flags every row of a treated unit.
# Confirm which meaning the downstream model expects.
analysis_data["in_treatment_group"] = 0
analysis_data.loc[
    analysis_data['GeoFIPS'].isin(states_above_50000)
    & (analysis_data['Year'] >= PSEUDO_TREATMENT_YEAR),
    'in_treatment_group',
] = 1

display(analysis_data.head())


0.3173575129533679


Unnamed: 0,GeoFIPS,GeoName,Year,gdp,population,unit_index,time_index,y,in_treatment_group
0,1001,"Autauga, AL",2001,59.839,44889.0,0,0,59.839,0
1,1003,"Baldwin, AL",2001,73.853,144875.0,1,0,73.853,0
2,1005,"Barbour, AL",2001,113.864,28863.0,2,0,113.864,0
3,1007,"Bibb, AL",2001,80.443,21028.0,3,0,80.443,0
4,1009,"Blount, AL",2001,92.104,51845.0,4,0,92.104,0
