In [1]:
import os
import logging


import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import dill

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormal, AutoMultivariateNormal, AutoDelta
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.optim import Adam
from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.counterfactual.handlers import MultiWorldCounterfactual
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.interventional.handlers import do
from chirho.observational.handlers import condition


from cities.utils.data_grabber import (DataGrabber, list_available_features, list_tensed_features, list_interventions, list_outcomes)
from cities.utils.cleaning_utils import check_if_tensed, find_repo_root
from cities.modeling.model_interactions import cities_model_interactions as model
from cities.modeling.model_interactions import InteractionsModel
from cities.modeling.modeling_utils import (prep_data_for_interaction_inference, train_interactions_model)
from cities.utils.cleaning_utils import find_repo_root


In [10]:


outcome_dataset='industry_transportation_warehousing_total'
intervention_dataset='spending_transportation'
intervention_variable = dg.std_long[intervention_dataset].columns[-1]
forward_shift = 1


dg = DataGrabber()
dg.get_features_std_long([intervention_dataset, outcome_dataset])

loaded_data = {}
loaded_data[forward_shift] =  prep_data_for_interaction_inference(
        outcome_dataset=outcome_dataset,
intervention_dataset=intervention_dataset,
intervention_variable=intervention_variable,
forward_shift=forward_shift, 
    )

model_args= (
        loaded_data[forward_shift]['N_t'],
        loaded_data[forward_shift]['N_cov'],
        loaded_data[forward_shift]['N_s'],
        loaded_data[forward_shift]['N_u'],
        loaded_data[forward_shift]['N_obs'],
        loaded_data[forward_shift]['state_index_sparse'],
        loaded_data[forward_shift]['state_index'],
        loaded_data[forward_shift]['time_index'],
        loaded_data[forward_shift]['unit_index'],
    )



In [22]:
def model_testing(
    N_t,
    N_cov,
    N_s,
    N_u,
    N_obs,
    state_index_sparse,
    state_index,
    time_index,
    unit_index,
    leeway=0.2,
):
    Y_bias = pyro.sample("Y_bias", dist.Normal(0, leeway))
    T_bias = pyro.sample("T_bias", dist.Normal(0, leeway))

    weight_TY = pyro.sample("weight_TY", dist.Normal(0, leeway))

    sigma_T = pyro.sample("sigma_T", dist.Exponential(1))
    sigma_Y = pyro.sample("sigma_Y", dist.Exponential(1))
 
    counties_plate = pyro.plate("counties_plate", N_u, dim=-1)
    states_plate = pyro.plate("states_plate", N_s, dim=-2)
    covariates_plate = pyro.plate("covariates_plate", N_cov, dim=-3)
    time_plate = pyro.plate("time_plate", N_t, dim=-4)
    print("Nt", N_t)

    with covariates_plate:
        X_bias = pyro.sample("X_bias", dist.Normal(0, leeway))
        sigma_X = pyro.sample("sigma_X", dist.Exponential(1))
        print("sigma_X", sigma_X.shape)
        weight_XT = pyro.sample("weight_XT", dist.Normal(0, leeway))
        weight_XY = pyro.sample("weight_XY", dist.Normal(0, leeway))

    with states_plate:
        weight_UsT = pyro.sample("weight_UsT", dist.Normal(0, leeway))
        weight_UsY = pyro.sample("weight_UsY", dist.Normal(0, leeway))

        with covariates_plate:
            weight_UsX = pyro.sample("weight_UsX", dist.Normal(0, leeway))

    with time_plate:
        weight_UtTfull = pyro.sample("weight_UtTfull", dist.Normal(0, leeway))    
        print(weight_UtTfull.shape)
        weight_UtT = pyro.sample("weight_UtT", dist.Normal(0, leeway))
        print(" model wUtT", weight_UtT.shape)
        weight_UtY = pyro.sample("weight_UtY", dist.Normal(0, leeway))

    with states_plate:
        with covariates_plate:
            UsX_weight_selected = weight_UsX[...,state_index_sparse,:,:]
            print("X_bias", X_bias.shape)
            print("UsX_weight_selected", UsX_weight_selected.shape)
            X_means = torch.einsum("...cddd,...cudd->...cud", X_bias, UsX_weight_selected).unsqueeze(-1)
            print("X_means", X_means.shape)
            X = pyro.sample("X", dist.Normal(X_means, sigma_X)) # cudd
    print("X", X.shape)
    XT_weighted = torch.einsum("cudd, cddd -> du", X, weight_XT).unsqeeze(-2)
    XY_weighted = torch.einsum("cudd, cddd -> du", X, weight_XY).unsqueeze(-2) # dudd

    with observations_plate:
        T_mean = (
            T_bias
            + weight_UtT[time_index]
            + weight_UsT[state_index]
            + XT_weighted[unit_index]
        )

        T = pyro.sample("T", dist.Normal(T_mean, sigma_T))

        Y_mean = (
            Y_bias
            + weight_UtY[time_index]
            + weight_UsY[state_index]
            + weight_TY * T
            + XY_weighted[unit_index]
        )

        Y = pyro.sample("Y", dist.Normal(Y_mean, sigma_Y))

    return Y 

with pyro.poutine.trace() as tr:
    model_testing(*model_args)
  


Nt 11
sigma_X torch.Size([30, 1, 1, 1])
torch.Size([11, 1, 1, 1, 1])
 model wUtT torch.Size([11, 1, 1, 1, 1])
X_bias torch.Size([30, 1, 1, 1])
UsX_weight_selected torch.Size([30, 3065, 1, 1])
X_means torch.Size([30, 3065, 1, 1])


ValueError: Shape mismatch inside plate('states_plate') at site X dim -3, 51 vs 3065

In [None]:
 data = prep_data_for_interaction_inference(
            outcome_dataset=self.outcome_dataset,
            intervention_dataset=self.intervention_dataset,
            intervention_variable=self.intervention_variable,
            forward_shift=2,
        )