In [1]:
import pandas as pd

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormal
from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.observational.handlers.cut import SingleStageCut

from pyro.infer import Predictive

In [2]:
from cities.utils.data_grabber import DataGrabber

dg = DataGrabber()

outcome_dataset = 'unemployment_rate'
intervention_dataset = 'spending_HHS'
intervention_variable = 'total_obligated_amount'

dg.get_features_std_long([outcome_dataset, intervention_dataset]) 
dg.get_features_std_wide([outcome_dataset, intervention_dataset]) 

year_min = max(dg.std_long[intervention_dataset]['Year'].min(), dg.std_long[outcome_dataset]['Year'].min())
year_max = min(dg.std_long[intervention_dataset]['Year'].max(), dg.std_long[outcome_dataset]['Year'].max())

outcome = dg.std_long[outcome_dataset][(dg.std_long[outcome_dataset]['Year'] >= year_min) & (dg.std_long[outcome_dataset]['Year'] <= year_max)]
intervention = dg.std_long[intervention_dataset][(dg.std_long[intervention_dataset]['Year'] >= year_min) & (dg.std_long[intervention_dataset]['Year'] <= year_max)]

# available time units for both intervention and outcome
T_outcome = year_max - year_min + 1

#TODO add covariates
data = pd.merge(outcome, intervention, on=['GeoFIPS', 'Year'])

if 'GeoName_x' in data.columns:
    data.rename(columns={'GeoName_x': "GeoName"}, inplace=True)    
    columns_to_drop = data.filter(regex=r'^GeoName_[a-zA-Z]$')
    data.drop(columns=columns_to_drop.columns, inplace=True)


data.rename(columns={'Value': outcome_dataset}, inplace=True)

data['state'] = [code // 1000 for code in data['GeoFIPS']]

N_states = len(data['state'].unique())

data['unit_index']= pd.factorize(data['GeoFIPS'].values)[0]
data['state_index']= pd.factorize(data['state'].values)[0]
data['time_index']= pd.factorize(data['Year'].values)[0]
data['y'] = data[outcome_dataset].values.copy()

display(data.head())

y = torch.tensor(data['y'])
unit_index = torch.tensor(data['unit_index'])
state_index = torch.tensor(data['state_index'])
time_index = torch.tensor(data['time_index'])
intervention = torch.tensor(data[intervention_variable])

print(data['time_index'].unique())

beta = pyro.sample('beta', dist.Normal(0, 1).expand([T_outcome]))

print(beta[time_index])

Unnamed: 0,GeoFIPS,GeoName,Year,unemployment_rate,total_obligated_amount,state,unit_index,state_index,time_index,y
0,1003,"Baldwin County, AL",2010,0.025,-0.984055,1,0,0,0,0.025
1,1015,"Calhoun County, AL",2010,0.09,-0.993019,1,1,0,0,0.09
2,1031,"Coffee County, AL",2010,-0.109589,-0.999981,1,2,0,0,-0.109589
3,1035,"Conecuh County, AL",2010,0.375,-0.993819,1,3,0,0,0.375
4,1039,"Covington County, AL",2010,0.04,-0.999552,1,4,0,0,0.04


[ 0  1  2  3  4  5  6  7  8  9 10 11]
tensor([0.2104, 0.2104, 0.2104,  ..., 1.2072, 1.2072, 1.2072])


In [8]:
def cities_model(T_outcome, N_states, state_index, time_index, intervention,
                  y = None):
        # Intercept 
        mu = pyro.sample("mu", dist.Normal(0, 1))

        # time effects
        beta = pyro.sample("beta", dist.Normal(0, 1).expand((T_outcome,)).to_event(1))

        print("beta", len(beta))
        print("time_index", len(time_index), time_index.unique())
        # state latent confounders
        alpha = pyro.sample( "alpha", dist.Normal(0, 1).expand((N_states,)).to_event(1))

        #TODO consider adding unit level effects?

        #treatment effect
        tau = pyro.sample("tau", dist.Normal(0, 1))

        # sd 
        sigma = pyro.sample("sigma", dist.Exponential(1))
        
        #with pyro.plate("data", len(y)):
        
        #        y_mean = mu  + beta[time_index]  #+ alpha[state_index] + tau * intervention
        #         pyro.sample("y", dist.Normal(y_mean, sigma))

        return  #y_mean, mu, beta, tau, alpha
print(
cities_model(T_outcome, N_states, y, state_index, time_index, intervention)
)

beta 12
time_index 21833 tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50])
None
