# Guide to using CAST_panel

Here is a Jupyter Notebook that simultaneously provides a guide to using the CAST_panel package, as well as reproducing the results in the paper: [FILL ME HERE]

In [1]:
import pandas as pd
import numpy as np
import CAST

## Uninsurance rates

In [2]:
# loads the treatment
treat = pd.read_csv("sample_data/expansion.csv")
adopt_year = treat["ADOPTION"]

uninsurance_rates = pd.read_csv("sample_data/uninsurance_rates.csv", index_col = 0)

In [3]:
treat.head()

Unnamed: 0,STATE,ADOPTION,DATE
0,AL,2025,1/1/2025
1,AK,2016,9/1/2015
2,AZ,2014,1/1/2014
3,AR,2014,1/1/2014
4,CA,2014,1/1/2014


In [4]:
uninsurance_rates.head() # our observation matrix

Unnamed: 0,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018,2019,2021,2022
Alabama,0.16,0.16,0.17,0.16,0.15,0.16,0.14,0.12,0.11,0.11,0.122,0.116,0.12,0.104
Alaska,0.22,0.22,0.19,0.22,0.22,0.2,0.19,0.16,0.16,0.16,0.135,0.129,0.121,0.123
Arizona,0.21,0.2,0.2,0.2,0.2,0.2,0.16,0.13,0.12,0.12,0.127,0.134,0.128,0.126
Arkansas,0.21,0.19,0.2,0.2,0.19,0.19,0.14,0.11,0.09,0.1,0.099,0.109,0.11,0.102
California,0.2,0.2,0.21,0.2,0.2,0.19,0.14,0.1,0.08,0.08,0.082,0.09,0.081,0.075


In order to run the method, we need to provide a vector that indicates the "index of treatment", i.e. from which column in the observation matrix did treatment begin. 

The following piece of code constucts the treatment index: it constructs a pandas Series that is consistent with the rows of the observation matrix, and each entry indicates the column for which treatment began in the matrix.

In [5]:
treat_index = adopt_year - 2008
treat_index.head()

0    17
1     8
2     6
3     6
4     6
Name: ADOPTION, dtype: int64

Alabama corresponds with row 0, and 17 is larger than the number of columns indicating that it has never adopted the treatment. Alaska corresponds to row 1, and its value 8 indicates that it adopted treatment in 2008 + 8 = 2016 (more precisely, September of 2015 which we rounded to 2016).

However our observation matrix is missing 2020 as an observation, so we have to hack around it to ensure that states that enacted the expansion after it has the correct index. We also need to ensure that the treatment value is equal to the number of columns in the matrix if a unit never undergoes treatment

In [6]:
treat_index[treat_index >= 13] -= 1 
treat_index = treat_index.to_list()
# makes it such that everything that began in 2020 is pushed further by a year

The treat_index variable needs to be a list where the i^th entry indicates from which column in the observation matrix did treatment for the i^th row begin.

In [7]:
treat_index

[16,
 8,
 6,
 6,
 6,
 6,
 6,
 6,
 16,
 16,
 6,
 12,
 6,
 7,
 6,
 16,
 6,
 9,
 11,
 6,
 6,
 6,
 6,
 16,
 13,
 8,
 12,
 6,
 7,
 6,
 6,
 6,
 6,
 6,
 6,
 13,
 6,
 7,
 6,
 16,
 15,
 16,
 16,
 12,
 6,
 11,
 6,
 6,
 16,
 16]

We have implemented a helper method that computes the singular values of the control matrix to aid in rank selection. It takes in both the observation matrix and the treatment index as arguments. As a whole, our library deals with numpy arrays.

In [8]:
print(CAST.rank_selection(uninsurance_rates.values, treat_index))

{'singular_values': array([2.05367718, 0.08315907, 0.04581451, 0.03229862, 0.02982912,
       0.02389201, 0.01775155, 0.01329282]), 'broken_stick': 1}


Based on the singular values, we pick to use a rank one matrix for our method. 

In [9]:
uninsurance_CAST = CAST.method(uninsurance_rates.values, treat_index, rank = 1)
uninsurance_sigeffects = uninsurance_CAST.get_significant_effects()
uninsurance_sigeffects.index = uninsurance_rates.columns[6:]
uninsurance_sigeffects

Unnamed: 0,Positive effects,Negative effects,Null effects,Number treatment
2014,1.0,23.0,2.0,26.0
2015,1.0,28.0,0.0,29.0
2016,2.0,28.0,1.0,31.0
2017,2.0,28.0,2.0,32.0
2018,1.0,30.0,1.0,32.0
2019,0.0,29.0,5.0,34.0
2021,1.0,36.0,0.0,37.0
2022,0.0,36.0,3.0,39.0


To get the treatment effects, we can use the following method:

In [10]:
uninsurance_effects = uninsurance_CAST.get_treatment_effects()
pd.DataFrame(uninsurance_effects[0], index = uninsurance_rates.index, 
                 columns = uninsurance_rates.columns).head()

Unnamed: 0,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018,2019,2021,2022
Alabama,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Alaska,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.014439,0.007597,-0.019201,-0.029877,-0.033441,-0.018964
Arizona,0.0,0.0,0.0,0.0,0.0,0.0,-0.010076,-0.017087,-0.017507,-0.023963,-0.018662,-0.015915,-0.017735,-0.007782
Arkansas,0.0,0.0,0.0,0.0,0.0,0.0,-0.025916,-0.03349,-0.044144,-0.040443,-0.0431,-0.037249,-0.032171,-0.02851
California,0.0,0.0,0.0,0.0,0.0,0.0,-0.028744,-0.045936,-0.05643,-0.062837,-0.062522,-0.058742,-0.063595,-0.057735


In [11]:
# this is a matrix indicating which entries are controls
pd.DataFrame(uninsurance_effects[1], index = uninsurance_rates.index, 
                 columns = uninsurance_rates.columns).head()

Unnamed: 0,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018,2019,2021,2022
Alabama,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
Alaska,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
Arizona,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Arkansas,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
California,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [12]:
# we can also report the standard errors per entry
pd.DataFrame(uninsurance_CAST.get_std_errors(), index = uninsurance_rates.index, 
                 columns = uninsurance_rates.columns).head()

Unnamed: 0,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018,2019,2021,2022
Alabama,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Alaska,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.003414,0.003666,0.00338,0.003653,0.003532,0.003942
Arizona,0.0,0.0,0.0,0.0,0.0,0.0,0.002083,0.001727,0.001887,0.002122,0.001591,0.001944,0.001861,0.002716
Arkansas,0.0,0.0,0.0,0.0,0.0,0.0,0.002443,0.002053,0.002143,0.002367,0.001938,0.002242,0.002156,0.002857
California,0.0,0.0,0.0,0.0,0.0,0.0,0.001543,0.001233,0.001507,0.001754,0.001051,0.001499,0.001422,0.002468


We also provide functionality for computing the average treatment effect on the treated units for each year, as well as estimates for their standard errors.

In [13]:
# first is the average values, and the second is the standard error estimates
ATET, se_ATET = uninsurance_CAST.get_average_effect()

ATET_values = {"ATET" : ATET, "se" : se_ATET}
ATET_df = pd.DataFrame(ATET_values, index = uninsurance_rates.columns[6:])
ATET_df

Unnamed: 0,ATET,se
2014,-0.0173,0.000922
2015,-0.022789,0.000688
2016,-0.024466,0.000984
2017,-0.026775,0.001188
2018,-0.028769,0.000525
2019,-0.025996,0.000944
2021,-0.028613,0.000885
2022,-0.022443,0.001797


We also allow the average effects argument to take in weights to compute any sort of weighted treatment effects, along with their standard errors.

In [14]:
# loads the pouplation data
pop_df = pd.read_csv("sample_data/population.csv", index_col = 0)
# making sure its aligned with the uninsurance_rates observations
pop_df = pop_df[uninsurance_rates.columns].reindex(uninsurance_rates.index)
pop_df.head()

Unnamed: 0,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018,2019,2021,2022
Alabama,4718206.0,4757938.0,4785298,4799642,4816632,4831586,4843737,4854803,4866824,4877989,4891628,4907965,5031864,5050380
Alaska,687455.0,698895.0,713985,722349,730810,737626,737075,738430,742575,740983,736624,733603,732964,734923
Arizona,6280362.0,6343154.0,6413737,6473416,6556344,6634690,6732873,6832810,6944767,7048088,7164228,7291843,7186683,7272487
Arkansas,2874554.0,2896843.0,2921606,2941038,2952876,2960459,2968759,2979732,2991815,3003855,3012161,3020985,3014348,3028443
California,36604337.0,36961229.0,37349363,37636311,37944551,38253768,38586706,38904296,39149186,39337785,39437463,39437610,39503200,39145060


In [15]:
# the weights must be a numpy array, the num_treated_renorm argument indicates to
# normalize all the weights such that they sum up to one for the treated units,
# by default it is true
pop_ATET, se_pop_ATET = uninsurance_CAST.get_average_effect(weights = pop_df.values, 
                                                                num_treated_renorm = False)

pop_ATET_values = {"ATET" : pop_ATET, "se" : se_pop_ATET}
pop_ATET_df = pd.DataFrame(pop_ATET_values, index = uninsurance_rates.columns[6:])
pop_ATET_df

Unnamed: 0,ATET,se
2014,-3080816.0,174132.574724
2015,-4954703.0,144598.100664
2016,-5849858.0,205123.75382
2017,-6633449.0,254322.147545
2018,-6848940.0,110989.119241
2019,-6651022.0,212763.572543
2021,-7504618.0,207468.916971
2022,-6493026.0,438264.196987


# Health expenditures

In [18]:
expenditures = pd.read_csv("sample_data/stateexp_percapita.csv", index_col = 0)
expenditures.head()

Unnamed: 0_level_0,2000,2001,2002,2003,2004,2005,2006,2007,2008,2009,...,2013,2014,2015,2016,2017,2018,2019,2020,2021,2022
STATE,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Alabama,2347.617669,2377.656719,2191.630568,2183.21742,2223.792242,4801.517789,5040.202152,5660.369283,2871.642315,2625.296925,...,3114.298286,3227.260274,3254.920951,3313.86547,3443.017194,3578.972072,3715.185418,3953.037388,3834.960563,4247.601171
Alaska,0.0,0.0,5749.318504,6441.872014,6898.371875,8721.845547,10047.35659,12550.345436,13589.253115,13700.19817,...,12788.323622,11794.864837,14208.029628,9097.128236,8055.515444,9063.104107,10095.514877,9742.080371,9300.593208,10605.192653
Arizona,2290.534447,2389.505065,2590.500264,2283.651316,2630.935439,2650.418893,2834.068734,2610.300371,2774.840049,2678.320596,...,2344.088119,3330.3762,3824.341669,3922.23382,2913.272365,3094.54138,3149.272413,7771.705828,7034.687477,8254.228052
Arkansas,2831.173738,2972.613392,3181.564395,3242.061002,3432.912319,3613.681939,3792.312673,4043.669809,4206.91349,4407.556778,...,5194.87012,5477.069712,5607.752644,5555.011009,5761.926115,5837.373352,5944.094977,6030.314249,6211.956947,6532.399652
California,2496.882942,2795.345565,2845.275485,3028.920047,2929.70463,3003.075002,3320.599907,3585.293379,3771.820809,3311.199419,...,3685.937553,3701.015578,4120.23906,4092.268994,4216.81175,4496.129885,5153.531439,5285.754184,5735.976225,6915.15826


In [48]:
treat_index = adopt_year - 2000
treat_index = treat_index.to_list()

print(CAST.rank_selection(expenditures.values, treat_index))

{'singular_values': array([59929.07160213,  5525.18226946,  3986.05229358,  1691.5044982 ,
        1267.62880648,  1175.57179881,   879.08639218,   608.08689983,
         525.57974705]), 'broken_stick': 1}


In [111]:
expend_CAST = CAST.method(expenditures.values, treat_index, rank = 3)
expend_sigeffects = expend_CAST.get_significant_effects()
expend_sigeffects.index = expenditures.columns[-9:]
expend_sigeffects

Unnamed: 0,Positive effects,Negative effects,Null effects,Number treatment
2014,10.0,6.0,10.0,26.0
2015,6.0,4.0,19.0,29.0
2016,7.0,7.0,17.0,31.0
2017,6.0,5.0,21.0,32.0
2018,9.0,6.0,17.0,32.0
2019,9.0,10.0,15.0,34.0
2020,14.0,7.0,15.0,36.0
2021,19.0,8.0,10.0,37.0
2022,16.0,11.0,12.0,39.0


In [112]:
# first is the average values, and the second is the standard error estimates
ATET, se_ATET = expend_CAST.get_average_effect()

ATET_values = {"ATET" : ATET, "se" : se_ATET}
ATET_df = pd.DataFrame(ATET_values, index = expenditures.columns[-9:])
ATET_df

Unnamed: 0,ATET,se
2014,98.921934,53.88728
2015,139.471423,55.165686
2016,78.783908,57.178851
2017,-92.445694,67.341482
2018,29.156311,66.383097
2019,-120.132579,89.748709
2020,175.251414,65.171699
2021,264.683816,48.63049
2022,128.028141,60.76892


# Infant mortality rates

In [115]:
treat = adopt_year - 2001
treat_index = treat.to_list()

mortality_rates_infant = pd.read_csv("sample_data/mortality_rates_infant_wind3.csv", index_col = 0)

In [116]:
inf_mort_CAST = CAST.method(mortality_rates_infant, treat_index, rank = 3)

print(inf_mort_CAST.get_significant_effects())

InvalidIndexError: (array([ 0, 42, 41, 40, 39, 35, 26, 48, 23, 15, 24, 49,  8,  9, 43, 11, 18,
       45, 17, 25,  1, 37, 13, 28, 38,  4,  3, 16, 36, 12, 44, 46, 47,  2,
        5, 32, 33, 14, 31, 10, 29, 27,  6,  7, 22, 21, 20, 19, 34, 30]), slice(None, None, None))

In [36]:
# loads the treatment
treat = pd.read_csv("sample_data/expansion.csv")
adopt_year = treat["ADOPTION"]

mortality_rates_infant = np.genfromtxt("sample_data/mortality_rates_infant_wind3.csv", delimiter=",")
mortality_rates_infant = mortality_rates_infant[1:,1:]


pop_df = pd.read_csv("sample_data/births_1999-2020_wind3.csv")

if pop_df.shape[0] == 51:
    pop_df = pop_df.drop([8])

years = np.arange(2001, 2021)
treat = adopt_year - 2001
treat_index = treat.to_list()



print(CAST.rank_selection(mortality_rates_infant, treat_index))

inf_mort_CAST = CAST.method(mortality_rates_infant, treat_index, rank = 3)


print(inf_mort_CAST.get_significant_effects())


{'singular_values': array([9275.94659706,  269.30981323,  166.71897289,  115.1514331 ,
         62.48948766,   51.08276807,   32.08624455]), 'broken_stick': 1}
   Positive effects  Negative effects  Null effects  Number treatment
0               5.0              12.0           9.0              26.0
1               6.0              17.0           6.0              29.0
2               5.0              15.0          11.0              31.0
3               5.0              16.0          11.0              32.0
4               4.0              18.0          10.0              32.0
5               3.0              22.0           9.0              34.0
6               6.0              20.0          10.0              36.0


In [38]:
mortality_rates_infant

array([[ 986.86666667,  947.96666667,  910.        ,  894.76666667,
         908.66666667,  923.43333333,  971.2       ,  972.26666667,
         938.23333333,  888.86666667,  842.36666667,  851.5       ,
         846.96666667,  871.1       ,  867.46666667,  888.9       ,
         841.43333333,  794.96666667,  750.8       ,  740.4       ],
       [ 721.96666667,  706.        ,  704.63333333,  660.8       ,
         674.36666667,  681.6       ,  676.03333333,  679.86666667,
         663.6       ,  570.6       ,  492.33333333,  431.63333333,
         483.83333333,  579.66666667,  637.03333333,  635.3       ,
         588.43333333,  547.56666667,  525.1       ,  516.16666667],
       [ 715.3       ,  699.13333333,  679.5       ,  682.2       ,
         701.43333333,  704.73333333,  715.53333333,  698.26666667,
         682.2       ,  634.5       ,  607.16666667,  587.3       ,
         565.63333333,  578.76666667,  565.2       ,  561.43333333,
         536.8       ,  535.86666667,  537.1  

In [37]:
treat_index

[24,
 15,
 13,
 13,
 13,
 13,
 13,
 13,
 24,
 24,
 13,
 19,
 13,
 14,
 13,
 24,
 13,
 16,
 18,
 13,
 13,
 13,
 13,
 24,
 21,
 15,
 20,
 13,
 14,
 13,
 13,
 13,
 13,
 13,
 13,
 21,
 13,
 14,
 13,
 24,
 23,
 24,
 24,
 19,
 13,
 18,
 13,
 13,
 24,
 24]

In [17]:
mortality_rates_infant

array([[ 986.86666667,  947.96666667,  910.        ,  894.76666667,
         908.66666667,  923.43333333,  971.2       ,  972.26666667,
         938.23333333,  888.86666667,  842.36666667,  851.5       ,
         846.96666667,  871.1       ,  867.46666667,  888.9       ,
         841.43333333,  794.96666667,  750.8       ,  740.4       ],
       [ 721.96666667,  706.        ,  704.63333333,  660.8       ,
         674.36666667,  681.6       ,  676.03333333,  679.86666667,
         663.6       ,  570.6       ,  492.33333333,  431.63333333,
         483.83333333,  579.66666667,  637.03333333,  635.3       ,
         588.43333333,  547.56666667,  525.1       ,  516.16666667],
       [ 715.3       ,  699.13333333,  679.5       ,  682.2       ,
         701.43333333,  704.73333333,  715.53333333,  698.26666667,
         682.2       ,  634.5       ,  607.16666667,  587.3       ,
         565.63333333,  578.76666667,  565.2       ,  561.43333333,
         536.8       ,  535.86666667,  537.1  

In [15]:
treat_index

[24,
 15,
 13,
 13,
 13,
 13,
 13,
 13,
 24,
 24,
 13,
 19,
 13,
 14,
 13,
 24,
 13,
 16,
 18,
 13,
 13,
 13,
 13,
 24,
 21,
 15,
 20,
 13,
 14,
 13,
 13,
 13,
 13,
 13,
 13,
 21,
 13,
 14,
 13,
 24,
 23,
 24,
 24,
 19,
 13,
 18,
 13,
 13,
 24,
 24]

In [14]:
adopt_year

0     2025
1     2016
2     2014
3     2014
4     2014
5     2014
6     2014
7     2014
8     2025
9     2025
10    2014
11    2020
12    2014
13    2015
14    2014
15    2025
16    2014
17    2017
18    2019
19    2014
20    2014
21    2014
22    2014
23    2025
24    2022
25    2016
26    2021
27    2014
28    2015
29    2014
30    2014
31    2014
32    2014
33    2014
34    2014
35    2022
36    2014
37    2015
38    2014
39    2025
40    2024
41    2025
42    2025
43    2020
44    2014
45    2019
46    2014
47    2014
48    2025
49    2025
Name: ADOPTION, dtype: int64