In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns #for plotting
from sklearn.ensemble import RandomForestClassifier #for the model
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz #plot tree
from sklearn.metrics import roc_curve, auc #for model evaluation
from sklearn.metrics import classification_report #for model evaluation
from sklearn.metrics import confusion_matrix #for model evaluation
from sklearn.model_selection import train_test_split #for data splitting
# Seed NumPy's global random number generator so repeated runs draw the same numbers.
np.random.seed(1)

# Turn off pandas' chained-assignment (SettingWithCopy) warning; other warnings are unaffected.
pd.set_option("mode.chained_assignment", None)

In [2]:
import pandas as pd
import numpy as np
import dame_flame
import random
import matplotlib.pyplot as plt

In [3]:
clean_ridership_path = "/Users/albertsun/Projects/data/flame_boardings.csv"
df_clean_ridership = pd.read_csv(clean_ridership_path)
df_clean_ridership = df_clean_ridership.iloc[:,1:]
df_clean_ridership.iloc[matched[matched.Agency==2].index]['SAP'].value_counts()

NameError: name 'matched' is not defined

In [None]:
df_clean_ridership.iloc[matched[matched.Agency==1].index]['SAP'].value_counts()

In [None]:
model_flame_AMT = dame_flame.matching.FLAME(repeats=False, 
                                            verbose=3, 
                                            early_stop_iterations=30, 
                                            stop_unmatched_t = True, 
                                            adaptive_weights='decisiontree', 
                                            want_pe=True)


In [None]:
model_flame_AMT.fit(holdout_data=False, 
                    treatment_column_name='SAP', 
                    outcome_column_name='total_boardings')
result_flame_AMT = model_flame_AMT.predict(df_clean_ridership)

In [None]:
matched = model_flame_AMT.df_units_and_covars_matched
matched[matched.Agency==2].index

In [None]:
def compute_cates(column_name: str) -> dict:
    """
    Average the per-unit CATE over every distinct value of one matched column.

    Uses the module-level ``matched`` frame and ``model_flame_AMT`` model.

    Returns a dictionary where
    - each key is a distinct value found in ``matched[column_name]`` and
    - each value is the mean of the CATEs of the matched units carrying that value.
    """
    column_values = matched[column_name]
    mean_cate_by_level = {}
    for level in set(column_values.tolist()):
        print(f'Calculating cates for column: {column_name}, for value: {level}')
        # One CATE per matched unit whose column value equals this level.
        unit_cates = [
            dame_flame.utils.post_processing.CATE(model_flame_AMT, unit_id)
            for unit_id in matched[column_values == level].index
        ]
        mean_cate_by_level[level] = sum(unit_cates) / len(unit_cates)
    return mean_cate_by_level

def get_cates_and_counts(column_name: str) -> pd.DataFrame:
    """
    Build a per-value summary of one matched column, ready for plotting.

    The returned DataFrame is indexed by the column's values (converted to
    strings and sorted) and has two columns:
    - 'CATE': the mean CATE from ``compute_cates`` for units with that value
    - 'Counts': how many rows of ``matched`` have that value
    """
    mean_cates = pd.Series(compute_cates(column_name))
    level_counts = matched[column_name].value_counts()

    summary = pd.DataFrame({'CATE': mean_cates, 'Counts': level_counts})
    # String labels keep the index uniformly typed before sorting.
    summary.index = summary.index.map(str)
    return summary.sort_index()

# Non-response columns to summarise, one CATE/count table per column.
cols = [
    'RaceDesc',
    'OverallJobAccess_D',
    'lowwagelaborforce_D',
    'Access30Transit_D',
    'Access30Transit_ts_D',
    'spatialmismatch_D',
    'Agency',
    'Language',
    'stage',
    'Age_bin',
]

# Filled in by export_cate_df(): column name -> summary DataFrame.
df_dic = {}

def export_cate_df(cols: list): 
    """
    Populate the module-level ``df_dic`` in place, where
    - keys are column names (e.g. 'RaceDesc')
    - values are the DataFrames returned by get_cates_and_counts() for that
    column (CATE values and counts)

    Returns nothing; read the results from ``df_dic``.
    """
    # Entries are inserted one at a time as each column is processed.
    df_dic.update((column_name, get_cates_and_counts(column_name)) for column_name in cols)

export_cate_df(cols)


In [None]:
# Where figures are written, and the response label used in their file names.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
artifacts_filepath = '/Users/albertsun/Projects/artifacts/'
response_var = 'ridership'

import seaborn as sns
# Set style and font scale in ONE call: sns.set() is an alias of sns.set_theme(),
# so a separate sns.set(font_scale=2) would reset the style to the default
# "darkgrid" and discard style="white".
sns.set_theme(style="white", font_scale=2)

In [None]:
# Default size (width, height in inches) for figures created after this cell.
plt.rc("figure", figsize=(8, 11))

In [None]:
def plot_column_CATE(column, title=None, xaxistitle=None, rename_dic=None) -> pd.DataFrame:
    """
    Plot the CATE and counts for every value of one column (e.g. "RaceDesc")
    and save the figure as a PNG.

    Reads the summary table from the module-level ``df_dic``, so
    ``export_cate_df`` must have been run first. One bar subplot is drawn per
    column of that table, and the figure is written to
    ``artifacts_filepath + 'Ridership_CATE_' + response_var + '_' + column + '.png'``.

    Takes:
    - column - column name (a key of ``df_dic``)
    - title - plot title; defaults to "<column> CATE on Ridership"
    - xaxistitle - x axis title, passed to ``plt.xlabel``
    - rename_dic - optional mapping from index labels to display labels

    Returns:
    - the DataFrame that was plotted: labels renamed, sorted by 'CATE' descending
    """
    if title is None:
        title = column + " CATE on Ridership"

    # The '*' label is always shown as 'Unmatched'; caller-supplied renames apply afterwards.
    cate_summary = df_dic[column].rename(index={'*': 'Unmatched'})
    if rename_dic is not None:
        cate_summary = cate_summary.rename(index=rename_dic)
    cate_summary = cate_summary.sort_values('CATE', ascending=False)

    # Apply the style BEFORE the axes are created so it affects this figure.
    # set_style() changes only the style, leaving the configured font scale alone
    # (set_theme() here would reset it to the default).
    sns.set_style("white")
    cate_summary.plot.bar(rot=0, subplots=True, fontsize=20)

    plt.xlabel(xaxistitle)
    plt.suptitle(title, fontweight='bold')
    plt.xticks(rotation=90)  # rotate the x-axis tick labels by 90 degrees
    plt.savefig(artifacts_filepath + 'Ridership_CATE_' + response_var + '_' + column + '.png')

    return cate_summary

In [None]:
# Display names for the plot index labels. Keys are strings because
# get_cates_and_counts() converts the index to str before plotting.
dic_race_rename = {
    '0.0': 'White',
    '1.0': 'Black',
    '2.0': 'Asian',
    '3.0': 'Hispanic',
    '4.0': 'Not Specified',
    '5.0': 'Other',
    '6.0': 'American Indian',
    '7.0': 'Multi-Racial',
    '8.0': 'Pacific Islander',
    '*': 'Unmatched',  # was misspelled 'Umatched'
}

dic_language_rename = {
    "0.0": "English",
    "1.0": "Chinese",
    "2.0": "Spanish",
    "3.0": "Dari",
    "4.0": "Vietnamese",
    "5.0": 'Other',
}

dic_agency_rename = {
    "0.0": "King County Public Health",
    "1.0": "DSHS - ORCA LIFT (remote enrollment)",
    "2.0": "DSHS - subsidized annual pass (remote enrollment)",
    "3.0": "CCS",
    "4.0": "KCMCCS",
    "5.0": 'Other',
}



In [None]:
plot_column_CATE(column='RaceDesc', rename_dic=dic_race_rename)

In [None]:
plot_column_CATE(column='Access30Transit_D')

In [None]:
plot_column_CATE(column='OverallJobAccess_D')

In [None]:
plot_column_CATE(column='Agency', rename_dic=dic_agency_rename)

In [None]:
plot_column_CATE(column='Language', rename_dic=dic_language_rename)



In [None]:
plot_column_CATE(column='stage')



In [None]:
plot_column_CATE(column='Age_bin')

In [None]:
plot_column_CATE(column='Access30Transit_ts_D')



In [None]:
plot_column_CATE(column='spatialmismatch_D')

