In [1]:
import os
import logging

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import dill
import plotly.graph_objects as go
import random

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormal, AutoMultivariateNormal, AutoDelta
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.optim import Adam
from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.counterfactual.handlers import MultiWorldCounterfactual
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.interventional.handlers import do
from chirho.observational.handlers import condition


from cities.utils.data_grabber import (DataGrabber, list_available_features, list_tensed_features, list_interventions, list_outcomes)
from cities.utils.cleaning_utils import check_if_tensed, find_repo_root
from cities.queries.causal_insight import CausalInsight
from cities.modeling.model_interactions import model_cities_interaction
from cities.modeling.model_interactions import InteractionsModel
from cities.modeling.modeling_utils import (prep_wide_data_for_inference, train_interactions_model)
from cities.utils.cleaning_utils import find_repo_root


In [2]:
# --- Query configuration ----------------------------------------------------
# Everything a reader might want to tune for this spot check lives in this cell.

# Catalogues returned by the project helpers, plus candidate shifts. They are
# not used by the cells visible below; kept so any later cell can still read them.
interventions = list_interventions()
outcomes = list_outcomes()
shifts = [1,2,3]

# The outcome/intervention pair handed to CausalInsight, and the value passed
# to `get_fips_predictions` as `intervened_value`.
outcome = 'unemployment_rate'
intervention='spending_commerce'
shift = 2
intervened_value = 0.7
# To randomise the intervened value as well, draw it from
# [round(i * 0.1, 1) for i in range(1, 10)] using `fips_rng` below.

# Seed for the county draw so that Restart & Run All inspects the same county.
# Set to None to get a different county on every run.
FIPS_SEED = 0

# Load the wide-format GDP table; it is used here only as a source of GeoFIPS codes.
data = DataGrabber()
data.get_features_wide(["gdp"])
gdp = data.wide["gdp"]

# Pick one county at random.
# `random.choice` evaluates `seq[i]`, which on a pandas Series is a *label*
# lookup and therefore only works when the index is a gap-free 0..n-1 range
# (some CPython builds also truth-test the sequence, which a Series rejects).
# Drawing from a plain list is positional and safe. A dedicated Random instance
# keeps the draw reproducible without reseeding the global `random` state.
# To inspect a specific county instead, assign its code directly, e.g. fips = 1005.
fips_rng = random.Random(FIPS_SEED)
fips = fips_rng.choice(gdp['GeoFIPS'].tolist())

In [None]:
# Build the causal query for the outcome/intervention pair chosen in the
# configuration cell, requesting 1000 samples.
ci = CausalInsight(outcome_dataset=outcome, intervention_dataset=intervention, num_samples=1_000)

# Populate the samples on `ci` before asking for predictions.
# NOTE(review): presumably this is the expensive step (sampling or loading
# cached samples) — confirm against CausalInsight and consider timing it.
ci.generate_tensed_samples()

In [None]:
# Echo the query settings so the table below can be read on its own.
print(outcome, intervened_value, sep="\n")
print("intervention", intervention)

# Compute predictions for the selected county under the intervened value;
# the result is stored on `ci.predictions`, which is then rendered as a table.
ci.get_fips_predictions(fips=fips, intervened_value=intervened_value)
display(ci.predictions)

unemployment_rate
0.7
intervention spending_commerce
2018    2018
2019    2019
2020    2020
2021    2021
Name: year, dtype: object


Unnamed: 0,year,observed,mean,low,high
2018,2018,-0.080362,-0.080362,-0.080362,-0.080362
2019,2019,-0.0864,-0.122012,-0.135381,-0.107826
2020,2020,0.112762,0.081241,0.067418,0.095687
2021,2021,0.187243,0.155595,0.140547,0.170648


In [None]:
# Render the `predictions_original` table stored on `ci`.
# NOTE(review): presumably the same predictions expressed on the outcome's
# original scale rather than the transformed one — confirm against CausalInsight.
display(ci.predictions_original)

Unnamed: 0,observed,mean,low,high,year
0,3.4,3.4,3.4,3.4,2018
1,3.2,2.9,2.8,3.0,2019
2,8.3,7.9,7.7,8.1,2020
3,6.6,6.3,6.1,6.4,2021


In [None]:
# Plot the county's predictions with the "transformed" scaling option.
ci.plot_predictions(scaling="transformed")

In [None]:
# Plot the same predictions with the "original" scaling option.
ci.plot_predictions(scaling="original")