# Multi-Action Synthetic Control Example

This Jupyter notebook is designed to be a simple, "user-friendly" tool to demonstrate the Multi-Action Synthetic Control (MA-SC) algorithm. 

The MA-SC algorithm is implemented in the $\textbf{fill_tensor}$ method below.

In Sections 1 through 3, using artificially generated data, we illustrate how to use the $\textbf{fill_tensor}$ method to generate counterfactuals for $\textit{each unit}$ under $\textit{each intervention}$ of interest (i.e., personalized interventions).

We have found MA-SC to produce accurate counterfactual estimates across a wide variety of fields, including econometric policy evaluation, web-scale A/B testing, sports, and genetics. We hope you find it useful for your problems of interest too.

In [1]:
from scontrol2 import random_rct, diagnostic, fill_tensor
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

## Section 1 - Generating Artificial Data from a Randomized Controlled Trial

### Explanation of Terms $N, I, T, T_0, r, \sigma$ 

We begin by generating artificial data for this demonstration with the function $\textbf{random_rct}$. All the data can be captured by a 3-dimensional tensor $\mathcal{M} \in \mathbb{R}^{N \times T \times I}$, indexed by unit, time period, and intervention.

$N$ denotes the number of units we perform the experiments on. 

$I$ denotes the total number of interventions. In this example, each unit $n \in \{1, \dots, N\}$ receives exactly one intervention $i \in \{1, \dots, I\}$.

$T$ is the total number of time periods (i.e., total number of measurements) we perform the experiment for. 

$T_0$ is the number of pre-intervention periods. Note $1 < T_0 < T$.

$r$ denotes the "model complexity", i.e., the rank of the tensor $\mathcal{M}$. 

$\sigma$ is the level of noise added to each measurement, i.e., the variance parameter of mean zero Gaussian noise.
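To make the roles of $N, T, I, r, \sigma$ concrete, here is a minimal NumPy sketch of one way to build such a tensor: a rank-$r$ CP-style model in which each entry is a sum of $r$ products of unit, time, and intervention factors, plus Gaussian noise. This is an illustrative assumption, not necessarily how $\textbf{random_rct}$ generates its data; the helper `low_rank_tensor` is hypothetical, and it treats `sigma` as a standard deviation.

```python
import numpy as np


def low_rank_tensor(N, T, I, r, sigma=0.0, seed=0):
    """Hypothetical helper: M[n, t, i] = sum_k U[n, k] * V[t, k] * W[i, k] + noise."""
    rng = np.random.default_rng(seed)
    U = rng.normal(size=(N, r))  # unit factors
    V = rng.normal(size=(T, r))  # time factors
    W = rng.normal(size=(I, r))  # intervention factors
    M = np.einsum("nk,tk,ik->nti", U, V, W)  # shape (N, T, I), CP rank <= r
    # Mean-zero Gaussian noise; sigma is used as a standard deviation here (assumption)
    return M + sigma * rng.normal(size=M.shape)


M = low_rank_tensor(N=100, T=100, I=4, r=2, sigma=0.0)
T0 = 40
M_pre, M_post = M[:, :T0, :], M[:, T0:, :]

print(M.shape)                            # (100, 100, 4)
print(np.linalg.matrix_rank(M[:, :, 0]))  # 2: each unit-by-time slice has rank <= r
print(M_pre.shape, M_post.shape)          # (100, 40, 4) (100, 60, 4)
```

Splitting the time axis at $T_0$ gives pre- and post-intervention blocks analogous to pre_df and post_df.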

In [2]:
# Generate artificial data
rct_data = random_rct(N=100, I=4, T=100, T0=40, rank=2, sigma=0)

### Pre-Intervention & Post-Intervention Data (pre_df, post_df)

The rct_data object returned by $\textbf{random_rct}$ consists of two dataframes: pre_df and post_df.

pre_df holds a 2-dimensional matrix, $\mathcal{M}^{\text{pre}} \in \mathbb{R}^{N \times T_0}$, alongside unit and intervention label columns. It contains the measurements of all units before any intervention takes place.

post_df holds a 2-dimensional matrix, $\mathcal{M}^{\text{post}} \in \mathbb{R}^{N \times (T-T_0)}$, alongside the same label columns. It contains the measurements of each unit $n$ under the intervention it actually received (i.e., observed in reality) in the post-intervention phase.

(Note that not every unit in pre_df needs to receive an intervention, and a unit may receive multiple interventions; $\textbf{fill_tensor}$ (the MA-SC algorithm) works as is in both cases. For simplicity, our artificial data illustrate the case where each unit $n$ in the pre-intervention phase receives exactly one intervention in the post-intervention phase.)
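Viewed through the tensor $\mathcal{M}$, each row of post_df fills in one (unit, intervention) slice, and every other intervention for that unit is missing; those missing entries are what $\textbf{fill_tensor}$ is meant to estimate. The sketch below assembles a frame in the same long layout into an $N \times (T - T_0) \times I$ array with `NaN` for unobserved entries. The toy frame and the helper `to_tensor` are hypothetical, for illustration only.

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame in the same layout as post_df:
# one row per (unit, intervention actually received), one column per period.
post_toy = pd.DataFrame({
    "unit": ["id_0", "id_1", "id_2"],
    "intervention": ["inter_1", "inter_0", "inter_1"],
    "t_2": [0.5, -1.0, 2.0],
    "t_3": [1.5, 0.0, -2.0],
})


def to_tensor(df):
    """Return (units, interventions, array of shape (N, T_post, I)); NaN marks unobserved entries."""
    time_cols = [c for c in df.columns if c not in ("unit", "intervention")]
    units = sorted(df["unit"].unique())
    inters = sorted(df["intervention"].unique())
    tensor = np.full((len(units), len(time_cols), len(inters)), np.nan)
    for _, row in df.iterrows():
        n = units.index(row["unit"])
        i = inters.index(row["intervention"])
        tensor[n, :, i] = row[time_cols].to_numpy(dtype=float)
    return units, inters, tensor


units, inters, M_post = to_tensor(post_toy)
print(M_post.shape)             # (3, 2, 2)
print(np.isnan(M_post).mean())  # 0.5
```

With one observed intervention per unit out of $I$, a fraction $(I-1)/I$ of the post-intervention tensor is missing; here $I = 2$, so half of it.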

In [3]:
# Pre- and Post- Intervention Data
pre_df, post_df = rct_data

In [4]:
pre_df.head()

| | unit | intervention | t_0 | t_1 | t_2 | t_3 | t_4 | t_5 | t_6 | t_7 | ... | t_30 | t_31 | t_32 | t_33 | t_34 | t_35 | t_36 | t_37 | t_38 | t_39 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | id_0 | inter_0 | 2.042665 | -1.588248 | 1.394762 | -1.418817 | 0.265346 | -0.985812 | 1.074614 | 0.626755 | ... | 1.42363 | 0.09 | -0.498738 | -1.532418 | -1.113478 | 0.537767 | -1.300796 | 0.21957 | 0.303231 | -1.467718 |
| 1 | id_1 | inter_0 | -1.815103 | 0.488189 | -0.470921 | 1.168816 | -0.099183 | 0.409964 | -1.91698 | -0.455953 | ... | -1.499546 | 0.553372 | 0.34982 | 1.198506 | 0.61985 | 0.385676 | 0.919076 | 0.517958 | 0.630305 | 1.145077 |
| 2 | id_2 | inter_0 | 2.593647 | -1.087507 | 0.997505 | -1.708985 | 0.199426 | -0.782654 | 2.332842 | 0.694176 | ... | 2.043681 | -0.523205 | -0.5393 | -1.781508 | -1.041828 | -0.18635 | -1.413316 | -0.438928 | -0.520606 | -1.703445 |
| 3 | id_3 | inter_0 | -2.303934 | 1.456753 | -1.294586 | 1.566963 | -0.249766 | 0.942964 | -1.560828 | -0.670315 | ... | -1.690734 | 0.128083 | 0.528687 | 1.669264 | 1.121921 | -0.29351 | 1.38133 | 0.01084 | -0.015846 | 1.597761 |
| 4 | id_4 | inter_0 | -0.767379 | 0.312287 | -0.287245 | 0.504692 | -0.057602 | 0.226781 | -0.700087 | -0.204349 | ... | -0.607068 | 0.161299 | 0.158604 | 0.525418 | 0.304452 | 0.063996 | 0.415726 | 0.137182 | 0.163264 | 0.502363 |


In [5]:
post_df.head()

| | unit | intervention | t_40 | t_41 | t_42 | t_43 | t_44 | t_45 | t_46 | t_47 | ... | t_90 | t_91 | t_92 | t_93 | t_94 | t_95 | t_96 | t_97 | t_98 | t_99 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | id_0 | inter_2 | 0.306549 | -1.328981 | -0.893788 | 2.411053 | 0.282049 | 0.354917 | -0.359026 | -2.412839 | ... | 0.154893 | -1.740342 | 0.904971 | 2.043332 | -1.438101 | -0.965589 | 3.087692 | 0.524963 | -2.005381 | 2.792084 |
| 1 | id_1 | inter_1 | -0.383018 | 2.082014 | 1.400048 | -3.733998 | -0.556799 | -0.922937 | 0.634809 | 3.574497 | ... | -0.635356 | 2.816473 | -1.572656 | -2.987209 | 2.133706 | 1.718418 | -4.821048 | -0.845178 | 3.152589 | -4.309403 |
| 2 | id_2 | inter_3 | -0.702458 | 1.622166 | 1.091585 | -3.088875 | 0.043787 | 0.805623 | 0.193949 | 3.639034 | ... | 1.136824 | 1.820367 | -0.581594 | -3.216401 | 2.158021 | 0.484078 | -3.823585 | -0.563939 | 2.410956 | -3.626649 |
| 3 | id_3 | inter_3 | -1.715337 | 0.168983 | 0.116809 | -1.05171 | 1.945744 | 6.280949 | -1.201762 | 3.850175 | ... | 6.750899 | -1.330614 | 2.555718 | -3.948165 | 2.239034 | -3.423802 | -0.672033 | 0.325611 | 0.066915 | -1.471318 |
| 4 | id_4 | inter_0 | 0.174016 | 0.225794 | 0.151428 | -0.309138 | -0.31519 | -0.913531 | 0.229243 | -0.067972 | ... | -0.939501 | 0.504997 | -0.513977 | 0.150304 | -0.032993 | 0.642399 | -0.486912 | -0.142111 | 0.366081 | -0.323815 |


## Section 2 - Diagnostic: For Which Interventions Can We Reliably Produce Counterfactuals?

In this section, we show how to use our diagnostic tool, $\textbf{diagnostic}$.

$\textbf{diagnostic}$ assesses whether the counterfactual estimates produced are reliable. Recall that, in reality, we never observe the counterfactual outcomes themselves. Hence, we need a test of whether the relationship learned in the pre-intervention phase continues to hold reliably in the post-intervention phase.

$\textbf{diagnostic}$ implements a "rank preservation" test, i.e., we check whether the singular values are significantly perturbed between the pre- and post-intervention phases. If they are not, we can safely use the counterfactual estimates produced for that intervention.
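As a rough illustration of the idea, here is a minimal sketch that summarizes each spectrum by its effective rank and flags an intervention as valid when the pre- and post-intervention effective ranks agree, mirroring the columns of the table that $\textbf{diagnostic}$ returns. Comparing effective ranks is our simplifying assumption about how "not significantly perturbed" is measured, spectral energy is taken to be the squared singular values, and `effective_rank` and `rank_preserved` are hypothetical helpers, not the scontrol2 implementation.

```python
import numpy as np


def effective_rank(M, cum_energy=0.90):
    """Fewest singular values whose squared sum reaches cum_energy of the total."""
    s = np.linalg.svd(M, compute_uv=False)  # sorted in descending order
    energy = np.cumsum(s**2) / np.sum(s**2)
    return min(int(np.searchsorted(energy, cum_energy)) + 1, len(s))


def rank_preserved(pre, post, cum_energy=0.90):
    """Compare effective ranks of the pre- and post-intervention matrices
    (rows = units that received one intervention, columns = time periods)."""
    r_pre = effective_rank(pre, cum_energy)
    r_post = effective_rank(post, cum_energy)
    return r_pre, r_post, r_pre == r_post


# Toy matrices with known spectra (hypothetical data)
pre = np.zeros((5, 4))
pre[:, 0] = 1.0                       # rank 1
post = np.zeros((5, 6))
post[0, 0] = post[1, 1] = 3.0         # two equal singular values

print(rank_preserved(pre, post))      # (1, 2, False)
```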

The inputs to $\textbf{diagnostic}$ are the two pre- and post-intervention dataframes (pre_df, post_df). There is an optional parameter, $\textit{cum_energy} \in [0, 1]$, which sets the threshold at which we cut off the spectrum (i.e., choose the effective rank). The appropriate value is application dependent.
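The threshold matters because the effective rank can change with it. A small worked example, assuming spectral energy is measured by squared singular values: for singular values $(10, 3, 1)$ the energies are $(100, 9, 1)$, so the cumulative shares are $100/110 \approx 0.909$, $109/110 \approx 0.991$, and $1$. A cutoff of 0.90 therefore gives effective rank 1, 0.95 gives 2, and 0.995 gives 3. The hypothetical helper below reproduces this arithmetic (it assumes the singular values are sorted in descending order, as an SVD returns them).

```python
import numpy as np


def effective_rank_from_spectrum(s, cum_energy):
    """Fewest leading singular values whose squared sum reaches cum_energy of the total."""
    s = np.asarray(s, dtype=float)
    energy = np.cumsum(s**2) / np.sum(s**2)
    return min(int(np.searchsorted(energy, cum_energy)) + 1, len(s))


s = np.array([10.0, 3.0, 1.0])
for threshold in (0.90, 0.95, 0.995):
    print(threshold, effective_rank_from_spectrum(s, threshold))
# 0.9 1
# 0.95 2
# 0.995 3
```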

In [6]:
diagnostic(pre_df, post_df, cum_energy=0.90)

| | intervention | Pre Intervention Rank (90.0%) | Post Intervention Rank (90.0%) | Valid (90.0%) |
|---|---|---|---|---|
| 0 | inter_0 | 1.0 | 2.0 | False |
| 1 | inter_1 | 1.0 | 1.0 | True |
| 2 | inter_2 | 1.0 | 1.0 | True |
| 3 | inter_3 | 1.0 | 2.0 | False |


## Section 3 - Producing Counterfactual Estimates: For Each Unit Under Each Intervention

In this section, we show how to use the $\textbf{fill_tensor}$ method to produce counterfactual estimates for each unit under each intervention (i.e., personalized interventions).

The inputs to $\textbf{fill_tensor}$ are the two pre- and post-intervention dataframes.

The key parameter of the method is $\textit{cum_energy} \in [0, 1]$, which decides the number of principal components to retain when performing Principal Component Regression (PCR) to learn a linear coefficient. In essence, we find the minimum number of principal components required for the share of spectral energy retained to be above the given threshold.
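A minimal sketch of the PCR step, under assumptions of our own: the donors are units that received the intervention of interest, the coefficient is learned by regressing the target unit's pre-intervention outcomes on the donors' pre-intervention outcomes, and it is then applied to the donors' post-intervention outcomes. On noiseless rank-2 toy data (generated here, not by $\textbf{random_rct}$), the estimate recovers the target's counterfactual. `pcr_coefficients` is a hypothetical helper, not the internals of $\textbf{fill_tensor}$.

```python
import numpy as np


def pcr_coefficients(X, y, cum_energy=0.90):
    """Hypothetical helper: principal component regression of y on the columns of X.

    X: (T0, n_donors) pre-intervention outcomes, one column per donor unit.
    y: (T0,) pre-intervention outcomes of the target unit.
    Keeps the fewest components whose squared singular values reach cum_energy
    of the total, then solves least squares within that subspace.
    """
    U, s, Vt = np.linalg.svd(X, full_matrices=False)
    energy = np.cumsum(s**2) / np.sum(s**2)
    k = min(int(np.searchsorted(energy, cum_energy)) + 1, len(s))
    # beta = V_k S_k^{-1} U_k^T y (pseudo-inverse of the rank-k approximation of X)
    return Vt[:k].T @ ((U[:, :k].T @ y) / s[:k])


# Noiseless rank-2 toy data (illustrative only, not produced by random_rct)
rng = np.random.default_rng(1)
T0, T_post, n_donors, r = 40, 60, 20, 2
u_donors = rng.normal(size=(n_donors, r))  # donor unit factors
u_target = rng.normal(size=r)              # target unit factor
v_pre = rng.normal(size=(T0, r))           # pre-intervention time factors
v_post = rng.normal(size=(T_post, r))      # post-intervention time factors
w_ctrl = np.array([1.0, 0.5])              # factors for the pre-intervention regime
w_inter = np.array([-0.8, 1.2])            # factors for the intervention of interest

X_pre = (v_pre * w_ctrl) @ u_donors.T      # (T0, n_donors)
y_pre = (v_pre * w_ctrl) @ u_target        # (T0,)
X_post = (v_post * w_inter) @ u_donors.T   # donors observed under the intervention
y_true = (v_post * w_inter) @ u_target     # target's counterfactual (unobserved in practice)

# A high threshold keeps both true components in this noiseless example
beta = pcr_coefficients(X_pre, y_pre, cum_energy=0.999999)
y_hat = X_post @ beta
print(np.allclose(y_hat, y_true))  # True
```

With noisy data, a lower cum_energy discards the small singular values that mostly carry noise, which is the usual motivation for truncating in PCR.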

In [8]:
df_output = fill_tensor(pre_df, post_df, cum_energy=0.90)
df_output.head()

| | unit | intervention | t_40 | t_41 | t_42 | t_43 | t_44 | t_45 | t_46 | t_47 | ... | t_90 | t_91 | t_92 | t_93 | t_94 | t_95 | t_96 | t_97 | t_98 | t_99 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | id_0 | inter_0 | 0.156236 | -1.126372 | -0.757331 | 1.997437 | 0.361488 | 0.69168 | -0.381365 | -1.826055 | ... | 0.549616 | -1.570907 | 0.932025 | 1.503923 | -1.091811 | -1.037514 | 2.599695 | 0.469173 | -1.711273 | 2.297446 |
| 1 | id_0 | inter_1 | 0.116987 | -2.368963 | -1.592405 | 4.106913 | 1.010412 | 2.253268 | -0.959541 | -3.393244 | ... | 2.010593 | -3.499797 | 2.297349 | 2.697443 | -2.036724 | -2.629764 | 5.432356 | 1.036284 | -3.622856 | 4.691033 |
| 2 | id_0 | inter_2 | 0.084366 | -1.208484 | -0.812384 | 2.106042 | 0.486261 | 1.056302 | -0.471122 | -1.783186 | ... | 0.925959 | -1.762505 | 1.13262 | 1.43037 | -1.069279 | -1.289298 | 2.775335 | 0.522864 | -1.845366 | 2.409487 |
| 3 | id_0 | inter_3 | 0.145907 | 0.869879 | 0.584372 | -1.424102 | -0.594277 | -1.540115 | 0.49288 | 0.846784 | ... | -1.501086 | 1.459963 | -1.144483 | -0.574954 | 0.516226 | 1.365215 | -1.963271 | -0.424727 | 1.351496 | -1.596774 |
| 4 | id_1 | inter_0 | -0.120697 | 0.87015 | 0.585057 | -1.54307 | -0.279258 | -0.53434 | 0.294614 | 1.410673 | ... | -0.424592 | 1.213565 | -0.720012 | -1.161818 | 0.843451 | 0.801506 | -2.008329 | -0.362448 | 1.322001 | -1.774834 |


In [None]:
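unit = 'id_1'
# Use the intervention this unit actually received in the post-intervention phase
# (hard-coding one it did not receive, e.g. 'inter_3' for id_1, selects an empty frame)
inter_obs = post_df.loc[post_df.unit == unit, 'intervention'].iloc[0]

# Observed post-intervention trajectory vs. the MA-SC estimate under the same intervention
y = post_df.loc[(post_df.unit == unit) & (post_df.intervention == inter_obs)].drop(columns=['unit', 'intervention']).values
y_hat = df_output.loc[(df_output.unit == unit) & (df_output.intervention == inter_obs)].drop(columns=['unit', 'intervention']).values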


In [None]:
plt.figure()
plt.plot(y.flatten(), label='obs')
plt.plot(y_hat.flatten(), label='pred')
plt.legend(loc='best')
plt.show()

In [None]:
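# Pre-intervention rows of the units that received inter_obs in the post-intervention phase
filter_inter = pre_df.unit.isin(post_df.loc[post_df.intervention == inter_obs, 'unit'])
pre_df[filter_inter]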