In [1]:
import os

import dill

import matplotlib.pyplot as plt
import torch
import pyro
import copy

import torch
import pyro
import pyro.distributions as dist
import matplotlib.pyplot as plt
import numpy as np

import torch
import time

import pandas as pd
from torch.utils.data import DataLoader

from chirho.indexed.ops import IndexSet, gather
import seaborn as sns


import pyro
from pyro.infer import Predictive

from chirho.counterfactual.handlers import MultiWorldCounterfactual

# from cities.modeling.zoning_models.units_causal_model import UnitsCausalModel
from cities.modeling.zoning_models.distance_causal_model import DistanceCausalModel
from cities.modeling.svi_inference import run_svi_inference
from cities.utils.data_loader import select_from_data


from pyro.infer.autoguide import AutoDiagonalNormal

from cities.modeling.evaluation import (
    prep_data_for_test,
    test_performance,
)

from cities.modeling.zoning_models.tracts_model import TractsModel

from cities.modeling.svi_inference import run_svi_inference
from pyro.infer import Predictive
from chirho.observational.handlers.predictive import PredictiveModel
from chirho.interventional.handlers import do


# Run a drastically shortened configuration whenever the "CI" environment
# variable is present.
smoke_test = os.environ.get("CI") is not None

# use when testing model health
# smoke_test = True

# Both knobs shrink together in smoke-test mode; num_samples is the number of
# draws requested from Predictive further down.
if smoke_test:
    n_steps, num_samples = 10, 10
else:
    n_steps, num_samples = 1500, 1000

from cities.utils.data_grabber import find_repo_root

# Every data path in this notebook is built relative to the repository root.
root = find_repo_root()

In [2]:

census_tracts_data_path = os.path.join(root, "data/minneapolis/processed/census_tracts_dataset.pt")


ct_dataset_read = torch.load(census_tracts_data_path)

# Pull the whole dataset out as one batch. Shuffling is switched off on
# purpose: the intervention vectors used later in this notebook are handed to
# the model as plain tensors and are matched to the rows by position, so the
# row order must be identical on every run.
# NOTE(review): tracts_intervention emits values in (census_tract, year)
# groupby order - confirm the stored dataset uses that same ordering.
ct_loader = DataLoader(
    ct_dataset_read, batch_size=len(ct_dataset_read), shuffle=False
)

data = next(iter(ct_loader))

print(data["continuous"].keys())
print(data['categorical'].keys())

# Variables handed to the tracts model; 'outcome' names the predicted variable.
kwargs = {
    "categorical": [
        "year_original",
        "year",
        "census_tract"
    ],
    "continuous": {
      'housing_units',
      'total_value',
      'median_value',
      'mean_limit_original',
      'median_distance',
      'income',
      'segregation_original',
      'white_original',
    },
    'outcome': 'housing_units'
}

subset = select_from_data(data, kwargs)
print(subset["continuous"].keys())
print(subset['continuous']['housing_units'].shape)
print(subset['categorical']['census_tract'].shape)

tracts_model = TractsModel(
    **subset, categorical_levels=ct_dataset_read.categorical_levels
)


# Same inputs with the outcome blanked out, so that the model generates
# housing_units instead of conditioning on the observed values.
subset_for_preds = copy.deepcopy(subset)
subset_for_preds['continuous']['housing_units'] = None

dict_keys(['housing_units', 'housing_units_original', 'income', 'income_original', 'mean_distance', 'mean_distance_original', 'mean_limit', 'mean_limit_original', 'median_distance', 'median_distance_original', 'median_value', 'median_value_original', 'segregation', 'segregation_original', 'total_value', 'total_value_original', 'white', 'white_original'])
dict_keys(['year_original', 'year', 'census_tract'])
dict_keys(['housing_units', 'income', 'mean_limit_original', 'median_distance', 'median_value', 'segregation_original', 'total_value', 'white_original'])
torch.Size([816])
torch.Size([816])


In [3]:
# Start from an empty parameter store so that only the saved parameters loaded
# below are present.
pyro.clear_param_store()

guides_dir = os.path.join(root, "data/minneapolis/guides")
guide_path = os.path.join(guides_dir, "tracts_model_guide.pkl")
param_path = os.path.join(guides_dir, "tracts_model_params.pth")

# The guide object was serialized with dill; its parameters live in a separate
# file that is loaded into Pyro's global parameter store.
with open(guide_path, "rb") as guide_file:
    guide = dill.load(guide_file)

pyro.get_param_store().load(param_path)

# Sampler that draws num_samples predictive samples from the tracts model under
# the loaded guide.
predictive = Predictive(
    tracts_model,
    guide=guide,
    num_samples=num_samples,
    parallel=True,
)

In [4]:
# RUN ONCE TO GENERATE BACKGROUND DATA USED FOR INTERVENTION CONSTRUCTION

# data  = copy.deepcopy(values[['limit_con', 'downtown_yn', 'year', 'distance_to_transit',
#                              'parcel', 'census_tract']])

# data.to_csv(os.path.join(root, "data/minneapolis/processed/census_tract_intervention_required.csv"))

# census_ids = pd.DataFrame({'census_tract':subset['categorical']['census_tract'].numpy(), 
#                            'year':subset['categorical']['year_original'].numpy()})
# census_ids.to_csv(os.path.join(root, "data/minneapolis/processed/census_ids.csv"))

In [5]:
# data = pd.read_csv(os.path.join(root,
#              "data/minneapolis/processed/census_tract_intervention_required.csv"))

# census_ids = pd.read_csv(os.path.join(root, "data/minneapolis/processed/census_ids.csv"))

# display(data.head())
# display(census_ids.head())

In [6]:

def values_intervention(radius_blue, limit_blue,
                radius_yellow, limit_yellow, reform_year = 2015):
    """Compute a parcel-level `intervention` column for a two-ring reform.

    Starting from each row's `limit_con` value, the column is overwritten as
    follows:

    * rows flagged `downtown_yn` get 0.0 (in every year);
    * non-downtown rows with `year >= reform_year` get
        - `limit_blue`   if `distance_to_transit <= radius_blue`,
        - `limit_yellow` if `radius_blue < distance_to_transit <= radius_yellow`,
        - 1.0            if `distance_to_transit > radius_yellow`;
    * all remaining rows keep their `limit_con` value.

    The two input CSVs are read on the first call only and cached as
    attributes of this function, so later calls are much cheaper.

    Args:
        radius_blue: outer radius of the inner ("blue") zone, in the units of
            the `distance_to_transit` column.
        limit_blue: value assigned inside the blue zone.
        radius_yellow: outer radius of the second ("yellow") zone.
        limit_yellow: value assigned inside the yellow zone.
        reform_year: first year (inclusive) in which the new values apply.

    Returns:
        A copy of the cached parcel frame with an added `intervention` column.
        The cached frame itself is never modified.
    """
    # don't want to load large data multiple times
    if not hasattr(values_intervention, "global_census_ids"):
        census_ids = pd.read_csv(os.path.join(root, "data/minneapolis/processed/census_ids.csv"))
        parcels = pd.read_csv(os.path.join(root,
            "data/minneapolis/processed/census_tract_intervention_required.csv"))

        # NOTE(review): tract and year are filtered independently here, not as
        # (tract, year) pairs; tracts_intervention applies the pairwise filter.
        parcels = parcels[
            (parcels['census_tract'].isin(census_ids['census_tract'])) &
            (parcels['year'].isin(census_ids['year']))]

        # Publish both cache attributes only after both files were read and
        # filtered. The sentinel checked above is assigned last so that a
        # failed load cannot leave a half-initialised cache behind.
        values_intervention.global_data = parcels
        values_intervention.global_census_ids = census_ids

    data = values_intervention.global_data.copy()

    downtown = data['downtown_yn']
    distance = data["distance_to_transit"]
    # All three zones share the same inclusive year cut-off, so every
    # non-downtown parcel is reassigned from the reform year onwards.
    reformed = (~downtown) & (data['year'] >= reform_year)

    new_blue = reformed & (distance <= radius_blue)
    new_yellow = reformed & (distance > radius_blue) & (distance <= radius_yellow)
    new_other = reformed & (distance > radius_yellow)

    # Work on an independent copy of the baseline values.
    intervention = data['limit_con'].copy()
    intervention[downtown] = 0.0
    intervention[new_blue] = limit_blue
    intervention[new_yellow] = limit_yellow
    intervention[new_other] = 1.0

    data['intervention'] = intervention

    return data

# Note: the first call reads the CSV files from disk; later calls reuse the
# frames cached on values_intervention and are therefore much faster.
start = time.time()
simple_intervention = values_intervention(300, .5, 700, .7, reform_year = 2015)
end = time.time()
print("Time to run values_intervention 1: ", end - start)
# Second call: different radii/limits and an earlier reform year, served from the cache.
start2 = time.time()
simple_intervention2 = values_intervention(400, .5, 800, .6, reform_year = 2013)
end2 = time.time()

print("Time to run values_intervention 2: ", end2 - start2)


# Third call, again with different settings, to confirm the cached timing.
start3 = time.time()
simple_intervention3 = values_intervention(200, .4, 1000, .65, reform_year = 2013)
end3 = time.time()


print("Time to run values_intervention 3: ", end3 - start3)


Time to run values_intervention 1:  0.7421717643737793
Time to run values_intervention 2:  0.056528329849243164
Time to run values_intervention 3:  0.054033756256103516


In [7]:
def tracts_intervention (radius_blue, limit_blue,
                radius_yellow, limit_yellow, reform_year = 2015):
    """Aggregate the parcel-level intervention to one value per tract-year.

    Calls `values_intervention` with the same arguments, averages its
    `intervention` column within each (census_tract, year) group and keeps
    only the pairs listed in `values_intervention.global_census_ids`.

    Args:
        radius_blue, limit_blue, radius_yellow, limit_yellow, reform_year:
            forwarded unchanged to `values_intervention`.

    Returns:
        A 1-D torch tensor of mean intervention values, ordered by
        (census_tract, year) as produced by the groupby.
    """
    parcel_interventions = values_intervention(radius_blue, limit_blue,
                        radius_yellow, limit_yellow, reform_year = reform_year)

    aggregate = (
        parcel_interventions[['census_tract', 'year', 'intervention']]
        .groupby(['census_tract', 'year'])
        .mean()
        .reset_index()
    )

    # Build the set of valid (tract, year) pairs once and cache it on the
    # function. The guard has to test the attribute that is actually assigned
    # here, otherwise the set is rebuilt on every call.
    if not hasattr(tracts_intervention, "global_valid_pairs"):
        census_ids = values_intervention.global_census_ids
        tracts_intervention.global_valid_pairs = set(
            zip(census_ids['census_tract'], census_ids['year'])
        )

    valid_pairs = tracts_intervention.global_valid_pairs

    # Plain set membership over zipped columns; avoids a slow row-wise apply.
    keep = pd.Series(
        [pair in valid_pairs
         for pair in zip(aggregate['census_tract'], aggregate['year'])],
        index=aggregate.index,
        dtype=bool,
    )

    # Going through a Python list keeps torch's default floating dtype.
    return torch.tensor(list(aggregate.loc[keep, 'intervention']))

# Time two tract-level aggregations with different settings. t_intervention is
# reused below as the intervention on "limit" in the counterfactual prediction.
start = time.time()
t_intervention = tracts_intervention(300, .5, 700, .7, reform_year = 2015)
end = time.time()
print("Time to run tracts_intervention 1: ", end - start)

# Second run with different radii/limits and an earlier reform year.
start2 = time.time()
t_intervention2 = tracts_intervention(400, .5, 800, .6, reform_year = 2013)
end2 = time.time()
print("Time to run tracts_intervention 2: ", end2 - start2)


Time to run tracts_intervention 1:  0.14096927642822266
Time to run tracts_intervention 2:  0.1297438144683838


In [8]:
print(t_intervention.shape)


# Draw predictive samples while intervening on "limit" inside a multi-world
# counterfactual handler.
# NOTE(review): the printed shapes carry an extra dimension of size 2,
# presumably the factual and the intervened world - confirm.
with MultiWorldCounterfactual() as mwc, do(actions={"limit": t_intervention}):
    samples = predictive(**subset_for_preds)


for site in ("limit", "housing_units"):
    print(samples[site].shape)


torch.Size([816])
torch.Size([1000, 2, 1, 1, 1, 816])
torch.Size([1000, 2, 1, 1, 1, 816])


In [63]:

def generate_intervention_settings(
    distance_interval=50, distance_max=2500, limit_interval=0.2
):
    """Enumerate a grid of (d_blue, l_blue, d_yellow, l_yellow) settings.

    Distances run over `np.arange(0, distance_max, distance_interval)` and
    limits over `np.arange(0, 1.01, limit_interval)` rounded to one decimal.
    Only combinations with `d_yellow > d_blue` are produced.

    Returns:
        A dict mapping each setting tuple to a running integer index, assigned
        in nested-loop order (d_blue, then l_blue, then d_yellow, then l_yellow).
    """
    # Rounding guards against small floating point artefacts from arange.
    radii = np.arange(0, distance_max, distance_interval).round()
    limits = np.arange(0, 1.01, limit_interval).round(decimals=1)

    settings = (
        (blue_radius, blue_limit, yellow_radius, yellow_limit)
        for blue_radius in radii
        for blue_limit in limits
        # built-in restriction: the yellow ring must reach further than the blue one
        for yellow_radius in radii[radii > blue_radius]
        for yellow_limit in limits
    )

    return {setting: position for position, setting in enumerate(settings)}

# A coarse grid (600-unit distance steps below 2000, limit steps of 0.2) keeps
# the number of settings small.
interventions_settings_dict = generate_intervention_settings(
    distance_interval=600,
    distance_max=2000,
    limit_interval=0.2,
)

print(len(interventions_settings_dict))

print(interventions_settings_dict)

216
{(0, 0.0, 600, 0.0): 0, (0, 0.0, 600, 0.2): 1, (0, 0.0, 600, 0.4): 2, (0, 0.0, 600, 0.6): 3, (0, 0.0, 600, 0.8): 4, (0, 0.0, 600, 1.0): 5, (0, 0.0, 1200, 0.0): 6, (0, 0.0, 1200, 0.2): 7, (0, 0.0, 1200, 0.4): 8, (0, 0.0, 1200, 0.6): 9, (0, 0.0, 1200, 0.8): 10, (0, 0.0, 1200, 1.0): 11, (0, 0.0, 1800, 0.0): 12, (0, 0.0, 1800, 0.2): 13, (0, 0.0, 1800, 0.4): 14, (0, 0.0, 1800, 0.6): 15, (0, 0.0, 1800, 0.8): 16, (0, 0.0, 1800, 1.0): 17, (0, 0.2, 600, 0.0): 18, (0, 0.2, 600, 0.2): 19, (0, 0.2, 600, 0.4): 20, (0, 0.2, 600, 0.6): 21, (0, 0.2, 600, 0.8): 22, (0, 0.2, 600, 1.0): 23, (0, 0.2, 1200, 0.0): 24, (0, 0.2, 1200, 0.2): 25, (0, 0.2, 1200, 0.4): 26, (0, 0.2, 1200, 0.6): 27, (0, 0.2, 1200, 0.8): 28, (0, 0.2, 1200, 1.0): 29, (0, 0.2, 1800, 0.0): 30, (0, 0.2, 1800, 0.2): 31, (0, 0.2, 1800, 0.4): 32, (0, 0.2, 1800, 0.6): 33, (0, 0.2, 1800, 0.8): 34, (0, 0.2, 1800, 1.0): 35, (0, 0.4, 600, 0.0): 36, (0, 0.4, 600, 0.2): 37, (0, 0.4, 600, 0.4): 38, (0, 0.4, 600, 0.6): 39, (0, 0.4, 600, 0.8): 4

In [57]:
def generate_interventions(interventions_settings_dict):
    """Compute (or load from cache) one intervention vector per setting.

    An on-disk dill cache keyed by the setting tuple
    `(radius_blue, limit_blue, radius_yellow, limit_yellow)` is loaded if it
    exists. Every key of `interventions_settings_dict` that is missing from
    the cache is computed with `tracts_intervention` and added. The cache is
    written back after every 100 new vectors and once more at the end if
    anything is still unsaved.

    Args:
        interventions_settings_dict: dict whose keys are setting tuples; the
            values are not used here.

    Returns:
        The full cache dict (setting tuple -> intervention vector). It may
        hold entries from earlier runs beyond the requested settings.
    """
    interventions_path = os.path.join(root, "data/minneapolis/processed/tract_interventions_tuple.pkl")

    def _save(vectors):
        # Write to a temporary file and rename it over the cache, so that an
        # interrupted run cannot leave a truncated pickle behind.
        os.makedirs(os.path.dirname(interventions_path), exist_ok=True)
        tmp_path = interventions_path + ".tmp"
        with open(tmp_path, "wb") as file:
            dill.dump(vectors, file)
        os.replace(tmp_path, interventions_path)

    def _report(vectors):
        have, want = len(vectors), len(interventions_settings_dict)
        # An empty settings dict must not crash the progress report.
        ratio = have / want if want else float("nan")
        print(f"{have} exist out of {want} desired (ratio: {ratio})")

    start = time.time()
    if os.path.exists(interventions_path):
        # NOTE: dill.load can execute arbitrary code; only load cache files
        # produced by this notebook.
        with open(interventions_path, "rb") as file:
            intervention_vectors_dict = dill.load(file)
        print(f"Loaded existing dictionary with {len(intervention_vectors_dict)} entries")
    else:
        print("No existing dictionary found, creating new one")
        intervention_vectors_dict = {}

    counter = 0  # vectors computed during this call
    unsaved = 0  # vectors computed since the last write
    for key in interventions_settings_dict:
        if key in intervention_vectors_dict:
            continue

        intervention_vectors_dict[key] = tracts_intervention(
            radius_blue=key[0],
            limit_blue=key[1],
            radius_yellow=key[2],
            limit_yellow=key[3],
        )
        counter += 1
        unsaved += 1

        # Periodic checkpoint so a long run can be resumed.
        if counter % 100 == 0:
            print(f"Saving {key} at step {counter}")
            _save(intervention_vectors_dict)
            unsaved = 0
            _report(intervention_vectors_dict)

    # Skip the final write when the file on disk is already up to date.
    if unsaved or not os.path.exists(interventions_path):
        _save(intervention_vectors_dict)

    end = time.time()
    print("Time of last run: ", end - start, ", ", counter, "vectors computed.")
    _report(intervention_vectors_dict)

    return intervention_vectors_dict

# Build, or load from the on-disk cache, one intervention vector per setting tuple.
intervention_vectors_dict = generate_interventions(interventions_settings_dict)

Loaded existing dictionary with 216 entries
Time of last run:  0.06448507308959961 ,  0 vectors computed.
216 exist out of 216 desired (ratio: 1.0)


In [112]:
def generate_intervened_preds(intervention_vectors_dict,
                              subset_for_preds, predictive, batch_size=100):
    """Predict mean housing units under every intervention, with disk caching.

    Interventions are processed in batches of `batch_size`. Each batch is
    passed as a tuple of alternative values for "limit" inside a
    `MultiWorldCounterfactual` handler. The sample mean of `housing_units` in
    each intervened world is stored under its setting key, and world 0 is
    stored under "factual_preds". Progress is written to disk after each
    batch, and batches whose keys are all cached are skipped.

    Args:
        intervention_vectors_dict: setting tuple -> intervention vector.
        subset_for_preds: keyword arguments for `predictive`.
        predictive: callable returning a dict of samples with a
            "housing_units" entry.
        batch_size: number of interventions evaluated per predictive call.

    Returns:
        Dict mapping "factual_preds" and every setting key to a numpy array of
        per-unit posterior-mean predictions.
    """
    tracts_prediction_path = os.path.join(root, "data/minneapolis/processed/tract_intervened_predictions.pkl")

    if os.path.exists(tracts_prediction_path):
        # NOTE: dill.load can execute arbitrary code; only load cache files
        # produced by this notebook.
        with open(tracts_prediction_path, "rb") as file:
            all_preds = dill.load(file)
        print(f"Loaded existing dictionary with {len(all_preds)} entries")

    else:
        print("No existing dictionary found, creating new one")
        all_preds = {}

    def _save(preds):
        # Write to a temporary file and rename it over the cache, so that an
        # interrupted run cannot leave a truncated pickle behind.
        tmp_path = tracts_prediction_path + ".tmp"
        with open(tmp_path, "wb") as file:
            dill.dump(preds, file)
        os.replace(tmp_path, tracts_prediction_path)

    def _world_mean(value, world):
        # Select one world along the "limit" intervention index and average
        # over the sample dimension.
        return (
            gather(value, IndexSet(**{"limit": {world}}), event_dims=0)
            .squeeze()
            .detach()
            .mean(axis=0)
            .numpy()
        )

    keys = list(intervention_vectors_dict.keys())
    # Ceiling division: no trailing empty batch when len(keys) is an exact
    # multiple of batch_size.
    total_batches = -(-len(keys) // batch_size)

    for batch_idx in range(total_batches):
        batch_keys = keys[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        # The factual world is available in every batch, so it is filled in
        # from the first batch that runs while it is still missing.
        need_factual = 'factual_preds' not in all_preds

        if not need_factual and all(key in all_preds for key in batch_keys):
            print(f"Skipping batch {batch_idx + 1} out of {total_batches} as computed.")
            continue

        interventions_tuple = tuple(intervention_vectors_dict[key] for key in batch_keys)
        print(len(interventions_tuple))

        with MultiWorldCounterfactual() as mwc:
            with do(actions={"limit": interventions_tuple}):
                batch_samples = predictive(**subset_for_preds)

        value = batch_samples['housing_units']

        # Re-enter the handler so that gather can resolve the "limit" index.
        with mwc:
            if need_factual:
                all_preds['factual_preds'] = _world_mean(value, 0)

            # World 0 is the factual one; intervention i of the tuple is world i + 1.
            for idx, key in enumerate(batch_keys):
                all_preds[key] = _world_mean(value, idx + 1)

        # Drop the large sample tensors before the next batch is drawn, so
        # memory use does not grow with the number of batches.
        del batch_samples, value

        print(f"Batch {batch_idx + 1} out of {total_batches} done, saving progress")
        _save(all_preds)

    return all_preds

                    
all_preds = generate_intervened_preds(intervention_vectors_dict, subset_for_preds, predictive, batch_size=100)


# Sanity checks on the prediction cache: one entry per intervention plus the
# factual one, each holding one value per row of the model input. The expected
# length is read from the data instead of being hard-coded.
n_units = subset['categorical']['census_tract'].shape[0]

assert len(all_preds) == len(intervention_vectors_dict) + 1
assert 'factual_preds' in all_preds
assert all(all_preds[key].shape == (n_units,) for key in intervention_vectors_dict.keys()), \
    f"Not all entries in all_preds have the shape ({n_units},)"


Loaded existing dictionary with 217 entries
Skipping batch 1 out of 3 as computed.
Skipping batch 2 out of 3 as computed.
Skipping batch 3 out of 3 as computed.
