In [1]:
import os
import random
import cProfile
import timeit

from cities.queries.causal_insight import CausalInsight
from cities.utils.data_grabber import (DataGrabber, list_interventions,
                                       list_outcomes)

# When the "CI" environment variable is present, treat this as a smoke test
# and draw only a handful of samples; otherwise use the full sample count.
smoke_test = "CI" in os.environ
num_samples = 1000
if smoke_test:
    num_samples = 10

In [2]:
interventions = list_interventions()
outcomes = list_outcomes()
shifts = [1, 2, 3]

# Query settings read (as notebook globals) by basic_run() and slim_run().
outcome = "unemployment_rate"
intervention = "spending_commerce"
shift = 2
intervened_value = 0.7

data = DataGrabber()
data.get_features_wide(["gdp"])
gdp = data.wide["gdp"]
values = [round(i * 0.1, 1) for i in range(1, 10)]

# County to predict for: the GeoFIPS code in the 6th row of the wide GDP table.
# `.iloc` makes the lookup explicitly positional (plain `[5]` on a Series is
# label-based for an integer index). An earlier note recorded this as 1011 —
# verify if the GDP data changes.
fips = gdp["GeoFIPS"].iloc[5]


In [3]:
def basic_run():
    """Run the full pipeline once: load the guide, sample, predict, and plot.

    Reads the notebook-level settings ``outcome``, ``intervention``,
    ``num_samples``, ``shift``, ``intervened_value`` and ``fips``.
    Returns None.
    """
    insight_kwargs = dict(
        outcome_dataset=outcome,
        intervention_dataset=intervention,
        num_samples=num_samples,
    )
    insight = CausalInsight(**insight_kwargs)

    insight.load_guide(forward_shift=shift)
    # Per the profile output below, this step accounts for most of the runtime.
    insight.generate_tensed_samples()
    insight.get_fips_predictions(intervened_value=intervened_value, fips=fips)
    insight.plot_predictions(range_multiplier=1)



In [4]:
# Profile the full pipeline. Using Profile as a context manager guarantees the
# profiler is disabled even if basic_run() raises; otherwise it would stay
# enabled and keep collecting stats for whatever runs next in the kernel.
with cProfile.Profile() as profiler_basic:
    basic_run()

# Confirms that the time-consuming steps involve the model
# (getting samples, etc.).
profiler_basic.print_stats(sort="cumulative")

# Note: restricting sites to tau drops time from 8s to 5s.


         946660 function calls (916815 primitive calls) in 7.875 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000    7.877    3.938 interactiveshell.py:3514(run_code)
    792/2    0.002    0.000    7.877    3.938 {built-in method builtins.exec}
        1    0.000    0.000    7.876    7.876 433866942.py:1(<module>)
        1    0.000    0.000    7.876    7.876 2494475649.py:1(basic_run)
        1    0.000    0.000    7.264    7.264 causal_insight.py:103(generate_tensed_samples)
        3    0.002    0.001    7.227    2.409 causal_insight.py:62(generate_samples)
     12/3    0.000    0.000    6.328    2.109 module.py:1514(_wrapped_call_impl)
     12/3    0.001    0.000    6.328    2.109 module.py:1520(_call_impl)
        3    0.192    0.064    6.327    2.109 predictive.py:246(forward)
        6    0.002    0.000    6.135    1.022 predictive.py:67(_predictive)
       18    0.000    0.000    6.128  

In [5]:
# # Prep pipeline only: construct CausalInsight, load the guide and generate samples (predictions and plotting are skipped).

# ci = CausalInsight(
#         outcome_dataset=outcome,
#         intervention_dataset=intervention,
#         num_samples=num_samples,
#     )

# ci.load_guide(forward_shift=shift)
# ci.generate_tensed_samples()
# # ci.get_fips_predictions(intervened_value=intervened_value, fips=fips)
# # ci.plot_predictions(range_multiplier=1)

In [7]:
# this is the slim execution 
def slim_run():
    ci = CausalInsight(
        outcome_dataset=outcome,
        intervention_dataset=intervention,
        num_samples=num_samples,
    )

    ci.get_tau_samples()
    ci.get_fips_predictions(intervened_value=intervened_value, fips=fips)
    ci.plot_predictions(range_multiplier=1)

# Profile the slim pipeline. The context manager disables the profiler even if
# slim_run() raises, so it never keeps profiling later kernel work.
with cProfile.Profile() as profiler_slim:
    slim_run()

profiler_slim.print_stats(sort="cumulative")


         249257 function calls (235951 primitive calls) in 0.509 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000    0.509    0.255 interactiveshell.py:3514(run_code)
        2    0.000    0.000    0.509    0.255 {built-in method builtins.exec}
        1    0.001    0.001    0.509    0.509 207891617.py:2(slim_run)
       50    0.001    0.000    0.416    0.008 readers.py:848(read_csv)
       50    0.002    0.000    0.415    0.008 readers.py:574(_read)
        1    0.001    0.001    0.408    0.408 causal_insight.py:127(get_fips_predictions)
       50    0.001    0.000    0.367    0.007 readers.py:1732(read)
        1    0.002    0.002    0.309    0.309 modeling_utils.py:20(prep_wide_data_for_inference)
       50    0.001    0.000    0.301    0.006 c_parser_wrapper.py:222(read)
       50    0.282    0.006    0.285    0.006 {method 'read_low_memory' of 'pandas._libs.parsers.TextReader' objects}
    