# Synthetic Control in PyMC

In this notebook we show how to implement a synthetic control model in PyMC to answer a counterfactual, "what would have happened if ...?" type of question in the context of causal inference. We reproduce the results of the example provided in the great book [Causal Inference for The Brave and True](https://matheusfacure.github.io/python-causality-handbook/landing-page.html) by [Matheus Facure](https://matheusfacure.github.io/). Specifically, we look into the problem of estimating the *effect of cigarette taxation on cigarette consumption* presented in Chapter [15 - Synthetic Control](https://matheusfacure.github.io/python-causality-handbook/15-Synthetic-Control.html): the treated unit is California, where Proposition 99 (marked at 1988 in the plots below) raised the cigarette tax, and the remaining states are used as the donor pool to build the synthetic control.

## Prepare Notebook

In [None]:
import aesara.tensor as at
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pymc.sampling_jax
from pymc.distributions.continuous import Exponential
from sklearn.preprocessing import StandardScaler
import seaborn as sns


# Global matplotlib defaults shared by every figure in this notebook.
plt.style.use("bmh")
plt.rcParams["figure.figsize"] = [10, 6]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

# IPython extensions. `%autoreload 2` re-imports edited modules before each
# cell is executed, so local code changes are picked up without a restart.
%load_ext rich
%load_ext autoreload
%autoreload 2
# Ask the inline backend for high-resolution ("retina") figure output.
%config InlineBackend.figure_format = "retina"

## Read Data

In [None]:
# Smoking data set hosted in the python-causality-handbook repository.
# NOTE(review): the URL points at the `master` branch, so the file is not
# pinned to a specific revision.
data_path = "https://raw.githubusercontent.com/matheusfacure/python-causality-handbook/master/causal-inference-for-the-brave-and-true/data/smoking.csv"

# Downloaded over the network every time this cell runs.
raw_data_df = pd.read_csv(data_path)

raw_data_df.head()

In [None]:
# Keep only the columns used in the rest of the notebook. `drop` already
# returns a new frame, so `raw_data_df` stays untouched without an explicit
# `.copy()`, and the former argument-less `.assign()` was a no-op.
df = raw_data_df.drop(columns=["lnincome", "beer", "age15to24"])

df.info()

In [None]:
# Yearly average of cigarette sales, split by the `california` flag, with the
# treatment year marked by a vertical line.
fig, ax = plt.subplots()

(
    df.groupby(["year", "california"], as_index=False)
    # Use the string alias: passing the NumPy callable `np.mean` to `agg`
    # triggers a FutureWarning in recent pandas releases.
    .agg({"cigsale": "mean"})
    .pipe(
        (sns.lineplot, "data"),
        x="year",
        y="cigsale",
        hue="california",
        marker="o",
        ax=ax,
    )
)
ax.axvline(
    x=1988,
    linestyle=":",
    lw=2,
    color="C2",
    label="Proposition 99",
)
ax.legend(loc="upper right")
ax.set(
    title="Gap in per-capita cigarette sales (in packs)", ylabel="Cigarette Sales Trend"
);


In [None]:
features = ["cigsale", "retprice"]


def _features_by_state(frame: pd.DataFrame) -> pd.DataFrame:
    """Reshape long data to one column per state and one row per (feature, year)."""
    wide = frame.pivot(index="state", columns="year", values=features)
    return wide.T


# Same reshape applied to the observations before and after the treatment.
pre_df = _features_by_state(df.query("~ after_treatment"))
post_df = _features_by_state(df.query("after_treatment"))

In [None]:
# Column label (state id) of the treated unit in `pre_df` / `post_df`.
# NOTE(review): presumably California -- confirm against the `california` flag.
idx = 3

# Pre-treatment period: target series and donor design matrix
# (rows: (feature, year) pairs, columns: donor states).
y_pre = pre_df[idx].to_numpy()
x_pre = pre_df.drop(columns=idx).to_numpy()
pre_years = pre_df.index.get_level_values("year").unique().to_numpy()
n_pre = pre_years.size

# Post-treatment period, same layout.
y_post = post_df[idx].to_numpy()
x_post = post_df.drop(columns=idx).to_numpy()
post_years = post_df.index.get_level_values("year").unique().to_numpy()
n_post = post_years.size

# Number of donor weights = number of columns of the design matrix.
# `pre_df.shape[0]` counts (feature, year) rows instead of donors, so it only
# gives the right length when those two counts happen to coincide.
k = x_pre.shape[1]

In [None]:
# Synthetic control model: the treated unit is regressed on the donor columns
# with weights drawn from a Dirichlet distribution.
with pm.Model() as model:
    # Registered as mutable data so that other arrays can be swapped in later
    # through `pm.set_data` without rebuilding the model.
    x = pm.MutableData(name="x", value=x_pre)
    y = pm.MutableData(name="y", value=y_pre)

    # Dirichlet weights are non-negative and sum to one.
    # NOTE(review): the concentration 1/k is presumably chosen to favour
    # sparse weights -- confirm.
    beta = pm.Dirichlet(name="beta", a=(1 / k) * np.ones(k))
    sigma = pm.HalfNormal(name="sigma", sigma=5)
    # Weighted combination of the donor columns.
    mu = pm.Deterministic(name="mu", var=pm.math.dot(x, beta))
    likelihood = pm.Normal(name="likelihood", mu=mu, sigma=sigma, observed=y)

pm.model_to_graphviz(model)

In [None]:
# Fit on the pre-treatment data and draw the in-sample posterior predictive.
# Fixed seeds keep the results reproducible across "Restart & Run All".
with model:
    idata = pm.sampling_jax.sample_numpyro_nuts(draws=4000, chains=4, random_seed=42)
    posterior_predictive_pre = pm.sample_posterior_predictive(
        trace=idata, random_seed=42
    )

In [None]:
# Summary table of the sampled variables stored in `idata`.
az.summary(data=idata)

In [None]:
# Forest plot of the `beta` weights with all chains combined.
az.plot_forest(data=idata, combined=True, var_names=["beta"]);

In [None]:
# Sanity check on the weights: for every posterior draw, sum `beta` over its
# components and show the distinct deviations of that sum from one.
beta_samples = idata.posterior["beta"].stack(samples=("chain", "draw"))
weight_totals = beta_samples.sum(axis=0).to_numpy()
np.unique(weight_totals - 1)

In [None]:
# Swap the post-treatment arrays into the model and predict with the weights
# learned on the pre-treatment period.
# NOTE: `pm.set_data` mutates `model` in place, so after this cell the model
# holds the post-treatment data; re-run the model cell before sampling again.
with model:
    pm.set_data(new_data={"x": x_post, "y": y_post})
    # Fixed seed keeps the predictive draws reproducible across re-runs.
    posterior_predictive_post = pm.sample_posterior_predictive(
        trace=idata, var_names=["likelihood"], random_seed=42
    )


In [None]:
# Posterior predictive mean per observation. Only the first `n_pre` / `n_post`
# observations are kept; the draws are stacked into a single `samples`
# dimension and averaged over it.
# NOTE(review): the slice presumably selects the "cigsale" rows, which come
# first when `features` starts with "cigsale" -- confirm.
pre_posterior_mean = (
    posterior_predictive_pre
    .posterior_predictive
    ["likelihood"]
    [:, :, : n_pre]
    .stack(samples=("chain", "draw"))
    .mean(axis=1)
)

post_posterior_mean = (
    posterior_predictive_post
    .posterior_predictive
    ["likelihood"]
    [:, :, : n_post]
    .stack(samples=("chain", "draw"))
    .mean(axis=1)
)


fig, ax = plt.subplots()

# Observed yearly averages by group, drawn semi-transparent as a backdrop.
(
    df.groupby(["year", "california"], as_index=False)
    # Use the string alias: passing the NumPy callable `np.mean` to `agg`
    # triggers a FutureWarning in recent pandas releases.
    .agg({"cigsale": "mean"})
    .assign(california=lambda x: x.california.map(
        {True: "is_california", False: "is_not_california"})
    )
    .pipe(
        (sns.lineplot, "data"),
        x="year",
        y="cigsale",
        hue="california",
        alpha=0.5,
        ax=ax,
    )
)
ax.axvline(
    x=1988,
    linestyle=":",
    lw=2,
    color="C2",
    label="Proposition 99",
)
# Posterior predictive means for both periods.
sns.lineplot(
    x=pre_years,
    y=pre_posterior_mean,
    color="C1",
    marker="o",
    label="pre-treatment posterior predictive mean",
    ax=ax
)
sns.lineplot(
    x=post_years,
    y=post_posterior_mean,
    color="C2",
    marker="o",
    label="post-treatment posterior predictive mean",
    ax=ax
)
# HDI bands computed from the posterior predictive draws.
az.plot_hdi(
    x=pre_years,
    y=posterior_predictive_pre.posterior_predictive["likelihood"][:, :, : n_pre],
    smooth=True,
    color="C1",
    fill_kwargs={"label": "pre-treatment posterior predictive (94% HDI)"},
    ax=ax
)
az.plot_hdi(
    x=post_years,
    y=posterior_predictive_post.posterior_predictive["likelihood"][:, :, : n_post],
    smooth=True,
    color="C2",
    fill_kwargs={"label": "post-treatment posterior predictive (94% HDI)"},
    ax=ax
)
ax.legend(loc="lower left")
ax.set(
    title="Gap in per-capita cigarette sales (in packs)", ylabel="Cigarette Sales Trend"
);


In [None]:
# One entry per period: everything needed to draw the effect
# (observed minus posterior predictive) for that period.
effect_specs = [
    {
        "period": "pre",
        "years": pre_years,
        "observed": y_pre[:n_pre],
        "mean": pre_posterior_mean,
        "draws": posterior_predictive_pre.posterior_predictive["likelihood"][:, :, :n_pre],
        "color": "C1",
    },
    {
        "period": "post",
        "years": post_years,
        "observed": y_post[:n_post],
        "mean": post_posterior_mean,
        "draws": posterior_predictive_post.posterior_predictive["likelihood"][:, :, :n_post],
        "color": "C2",
    },
]

fig, ax = plt.subplots()

ax.axvline(x=1988, linestyle=":", lw=2, color="C2", label="Proposition 99")

# Mean effect lines first, HDI bands second, so the artists are added in the
# same order as before (this also fixes the legend order).
for spec in effect_specs:
    sns.lineplot(
        x=spec["years"],
        y=spec["observed"] - spec["mean"],
        color=spec["color"],
        marker="o",
        label=f"{spec['period']}-treatment posterior predictive effect mean",
        ax=ax,
    )

for spec in effect_specs:
    az.plot_hdi(
        x=spec["years"],
        y=spec["observed"] - spec["draws"],
        smooth=True,
        color=spec["color"],
        fill_kwargs={
            "label": f"{spec['period']}-treatment posterior predictive effect (94% HDI)"
        },
        ax=ax,
    )

ax.axhline(y=0.0, color="black", linestyle="--", label="zero")
ax.legend(loc="lower left")
ax.set(
    title="Gap in per-capita cigarette sales (in packs) - Effect",
    ylabel="Cigarette Sales Trend",
);

In [None]:
# Observed minus posterior predictive draws over the post-treatment period.
post_effect_draws = (
    y_post[:n_post]
    - posterior_predictive_post.posterior_predictive["likelihood"][:, :, :n_post]
)
# Keep the last observation only and pool chains and draws into one sample axis.
last_year_effect = post_effect_draws[:, :, -1].stack(samples=("chain", "draw"))

g = sns.displot(data=last_year_effect, kde=True)
g.set(title="reduced the sales in cigarettes in 2000");

---

In [None]:
def run_synthetic_control(
    pre_df: pd.DataFrame, post_df: pd.DataFrame, idx: int, random_seed: int = 42
) -> tuple:
    """Fit a synthetic control for unit ``idx`` and return its prediction errors.

    The model is fitted on the pre-treatment data only; the post-treatment
    data is then swapped in to obtain the out-of-sample posterior predictive.

    Parameters
    ----------
    pre_df, post_df
        Wide frames for the pre- and post-treatment periods with one column
        per unit and an index that has a ``"year"`` level.
    idx
        Column label of the target unit. All remaining columns act as donors.
    random_seed
        Integer seed passed to the sampler and to both posterior predictive
        calls so that repeated runs give the same draws.

    Returns
    -------
    tuple
        ``(error_pre, error_post)``: observed values minus posterior predictive
        draws of ``"likelihood"``, restricted to the first ``n_pre`` and
        ``n_post`` observations respectively, where ``n_pre`` / ``n_post`` are
        the number of distinct years in each frame.
        NOTE(review): this assumes the outcome series occupies the first block
        of rows of ``pre_df`` / ``post_df`` -- confirm against how the frames
        are built.
    """
    y_pre = pre_df[idx].to_numpy()
    x_pre = pre_df.drop(columns=idx).to_numpy()
    n_pre = pre_df.index.get_level_values("year").nunique()

    y_post = post_df[idx].to_numpy()
    x_post = post_df.drop(columns=idx).to_numpy()
    n_post = post_df.index.get_level_values("year").nunique()

    # One weight per donor column. `pre_df.shape[0]` would count
    # (feature, year) rows instead and only match by coincidence.
    k = x_pre.shape[1]

    with pm.Model() as model:
        x = pm.MutableData(name="x", value=x_pre)
        y = pm.MutableData(name="y", value=y_pre)

        beta = pm.Dirichlet(name="beta", a=(1 / k) * np.ones(k))
        sigma = pm.HalfNormal(name="sigma", sigma=5)
        mu = pm.Deterministic(name="mu", var=pm.math.dot(x, beta))
        likelihood = pm.Normal(name="likelihood", mu=mu, sigma=sigma, observed=y)

        idata = pm.sampling_jax.sample_numpyro_nuts(
            draws=4000, chains=4, random_seed=random_seed
        )
        posterior_predictive_pre = pm.sample_posterior_predictive(
            trace=idata, random_seed=random_seed
        )

        # Swap in the post-treatment data and predict with the fitted weights.
        pm.set_data(new_data={"x": x_post, "y": y_post})
        posterior_predictive_post = pm.sample_posterior_predictive(
            trace=idata, var_names=["likelihood"], random_seed=random_seed
        )

    # Plain array arithmetic: no model context is needed from here on.
    error_pre = (
        y_pre[:n_pre]
        - posterior_predictive_pre.posterior_predictive["likelihood"][:, :, :n_pre]
    )
    error_post = (
        y_post[:n_post]
        - posterior_predictive_post.posterior_predictive["likelihood"][:, :, :n_post]
    )

    return error_pre, error_post

In [None]:
# Run the synthetic control for the first three state ids; every entry refits
# the model with that state as the target unit.
# Fixes: a stray non-ASCII character before `for` (SyntaxError) and the
# pre-treatment frame being passed as `post_df`.
results = {
    state_id: run_synthetic_control(pre_df=pre_df, post_df=post_df, idx=state_id)
    for state_id in df["state"].unique()[:3]
}
