# Cohort Revenue Retention Analysis: A Bayesian Approach

In this notebook we extend the cohort retention model presented in the post [Cohort Retention Analysis with BART](https://juanitorduz.github.io/retention_bart/) so that we model retention **and** revenue per cohort simultaneously (we recommend reading the referenced post before this one). The idea is to keep modeling the retention using a Bayesian Additive Regression Tree (BART) model (see [`pymc-bart`](https://www.pymc.io/projects/bart/en/latest/)) and to model the revenue per cohort using a Gamma distribution with a linear predictor on the log scale. We couple the retention and revenue components in a similar way as presented in the notebook [Introduction to Bayesian A/B Testing](https://www.pymc.io/projects/examples/en/latest/case_studies/bayesian_ab_testing_introduction.html). For this example we use a synthetic data set; see the blog post [A Simple Cohort Retention Analysis in PyMC](https://juanitorduz.github.io/retention/) for more details. [Here](https://github.com/juanitorduz/website_projects/blob/master/data/retention_data.csv) you can find the data to reproduce the results.

## Prepare Notebook

In [None]:
import arviz as az
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
import pymc as pm
import pymc_bart as pmb
import pytensor.tensor as pt
import seaborn as sns
from scipy.special import logit
from sklearn.preprocessing import MaxAbsScaler, LabelEncoder


plt.style.use("bmh")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"

In [None]:
seed: int = sum(map(ord, "retention"))
rng: np.random.Generator = np.random.default_rng(seed=seed)
random_seed_int: int = rng.integers(low=0, high=100, size=1).item()

## Read Data

We start by reading the data from previous posts (see [here](https://github.com/juanitorduz/website_projects/blob/master/Python/retantion_data.py) for the code to generate the data).

In [None]:
data_df = pd.read_csv("../data/retention_data.csv", parse_dates=["cohort", "period"])

data_df.head()

The new component that we have is `revenue`, which represents the revenue per cohort and period.

In [None]:
data_df["revenue"].describe()

## Data Preprocessing

In order to understand the relation between users and revenue, let's compute the revenue per user and per *active* user. The former represents the overall cohort contribution and the latter the contribution of the active users.

In [None]:
data_df["revenue_per_users"] = data_df["revenue"] / data_df["n_users"]
data_df["revenue_per_active_users"] = data_df["revenue"] / data_df["n_active_users"]

Observe that there are certain periods in which a cohort has no active users. This induces `NaN` values in `revenue_per_active_users`.

In [None]:
data_df.query("revenue_per_active_users.isna()")

We fill these missing values with zero.

In [None]:
data_df.fillna(value={"revenue_per_active_users": 0.0}, inplace=True)
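
To see where the missing values come from, here is a minimal sketch on a hypothetical three-row frame (`toy_df` and its values are made up for illustration and are not part of the data set): dividing zero revenue by zero active users yields `NaN`, which `fillna` then replaces.

```python
import pandas as pd

# Hypothetical toy data: the last row has neither active users nor revenue.
toy_df = pd.DataFrame({"revenue": [120.0, 45.0, 0.0], "n_active_users": [10, 3, 0]})

# 0 / 0 evaluates to NaN in pandas (a non-zero numerator would give inf instead).
toy_df["revenue_per_active_users"] = toy_df["revenue"] / toy_df["n_active_users"]
n_missing = int(toy_df["revenue_per_active_users"].isna().sum())

# Replace the undefined ratios with zero, mirroring the cell above.
toy_df = toy_df.fillna(value={"revenue_per_active_users": 0.0})
```

Filling with zero is a modeling convenience: a cohort without active users contributes no revenue in that period.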

We make a data train-test split.

In [None]:
period_train_test_split = "2022-11-01"

train_data_df = data_df.query("period <= @period_train_test_split")
test_data_df = data_df.query("period > @period_train_test_split")
test_data_df = test_data_df[
    test_data_df["cohort"].isin(train_data_df["cohort"].unique())
]
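
The split logic can be checked on a small hypothetical frame (`toy_df` below is illustrative only). The sketch uses boolean masks instead of `query`, but the intent is the same: split by period, then drop test rows whose cohort never appears in the training data.

```python
import pandas as pd

# Hypothetical toy data: one cohort spans the split date, one starts after it.
toy_df = pd.DataFrame(
    {
        "cohort": pd.to_datetime(["2022-09-01", "2022-09-01", "2022-09-01", "2022-12-01"]),
        "period": pd.to_datetime(["2022-10-01", "2022-11-01", "2022-12-01", "2022-12-01"]),
    }
)
split = pd.Timestamp("2022-11-01")

toy_train = toy_df[toy_df["period"] <= split]
toy_test = toy_df[toy_df["period"] > split]
# Keep only test rows whose cohort was already seen during training.
toy_test = toy_test[toy_test["cohort"].isin(toy_train["cohort"].unique())]
```

The cohort that starts in December is removed from the test set because the model has no training observations for it.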

## EDA

For a detailed EDA of the data, please refer to the previous posts ([A Simple Cohort Retention Analysis in PyMC](https://juanitorduz.github.io/retention/) and [Cohort Retention Analysis with BART](https://juanitorduz.github.io/retention_bart/)). Here we want to focus on the relation between retention and revenue. First, let's recall what the retention matrix looks like.

In [None]:
fig, ax = plt.subplots(figsize=(18, 7))

fmt = lambda y, _: f"{y :0.0%}"

(
    train_data_df.assign(
        cohort=lambda df: df["cohort"].dt.strftime("%Y-%m"),
        period=lambda df: df["period"].dt.strftime("%Y-%m"),
    )
    .query("cohort_age != 0")
    .filter(["cohort", "period", "retention"])
    .pivot(index="cohort", columns="period", values="retention")
    .pipe(
        (sns.heatmap, "data"),
        cmap="viridis_r",
        linewidths=0.2,
        linecolor="black",
        annot=True,
        fmt="0.0%",
        cbar_kws={"format": mtick.FuncFormatter(fmt)},
        ax=ax,
    )
)

ax.set_title("Retention by Cohort and Period");
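
As a reminder of how such a matrix is built, here is a minimal sketch on hypothetical data (the frame and its numbers are invented for illustration). It assumes retention is the share of a cohort's users that are active in a given period, which is how the ratio is used later in this notebook.

```python
import pandas as pd

# Hypothetical long-format data: one row per (cohort, period) pair.
toy_df = pd.DataFrame(
    {
        "cohort": ["2022-01", "2022-01", "2022-02"],
        "period": ["2022-02", "2022-03", "2022-03"],
        "n_users": [100, 100, 80],
        "n_active_users": [40, 30, 20],
    }
)
toy_df["retention"] = toy_df["n_active_users"] / toy_df["n_users"]

# Rows are cohorts, columns are periods; unobserved pairs become NaN.
toy_matrix = toy_df.pivot(index="cohort", columns="period", values="retention")
```

The `NaN` entries correspond to periods before the cohort exists, which is why the heatmap has a triangular shape.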

The key observation is that the retention matrix has a clear seasonality pattern (in the `period`, i.e. the observation variable) and seems to change as a function of the `cohort` (i.e. the group variable). This motivates using a BART model to capture complex interactions between the `period`, `cohort` and seasonal variables. In the next figure we plot the retention rate by cohort over time (period) to further illustrate the seasonality pattern.

In [None]:
fig, ax = plt.subplots(figsize=(12, 7))
sns.lineplot(
    x="period",
    y="retention",
    hue="cohort",
    palette="viridis_r",
    alpha=0.8,
    data=train_data_df.query("cohort_age > 0").assign(
        cohort=lambda df: df["cohort"].dt.strftime("%Y-%m")
    ),
    ax=ax,
)
ax.legend(title="cohort", loc="center left", bbox_to_anchor=(1, 0.5), fontsize=7.5)
ax.set(title="Retention by Cohort and Period");

Note that the retention rate by itself hides how *big* the cohort is. In the end, one is not just interested in the retention rate but in the value of the cohort. We can start by looking at the absolute number of active users per cohort.

In [None]:
fig, ax = plt.subplots(figsize=(18, 7))

fmt = lambda y, _: f"{y :0.0f}"

(
    train_data_df.assign(
        cohort=lambda df: df["cohort"].dt.strftime("%Y-%m"),
        period=lambda df: df["period"].dt.strftime("%Y-%m"),
    )
    .query("cohort_age != 0")
    .filter(["cohort", "period", "n_active_users"])
    .pivot(index="cohort", columns="period", values="n_active_users")
    .pipe(
        (sns.heatmap, "data"),
        cmap="viridis_r",
        linewidths=0.2,
        linecolor="black",
        annot=True,
        annot_kws={"fontsize": 8},
        fmt="0.0f",
        cbar_kws={"format": mtick.FuncFormatter(fmt)},
        ax=ax,
    )
)

ax.set_title("Active Users by Cohort and Period");

The younger cohorts are much smaller than the older cohorts. Next, we plot the absolute revenue values.

In [None]:
fig, ax = plt.subplots(figsize=(18, 7))

fmt = lambda y, _: f"{y :0.0f}"

(
    train_data_df.assign(
        cohort=lambda df: df["cohort"].dt.strftime("%Y-%m"),
        period=lambda df: df["period"].dt.strftime("%Y-%m"),
    )
    .query("cohort_age != 0")
    .filter(["cohort", "period", "revenue"])
    .pivot(index="cohort", columns="period", values="revenue")
    .pipe(
        (sns.heatmap, "data"),
        cmap="viridis_r",
        linewidths=0.2,
        linecolor="black",
        annot=True,
        annot_kws={"fontsize": 6},
        fmt="0.0f",
        cbar_kws={"format": mtick.FuncFormatter(fmt)},
        ax=ax,
    )
)

ax.set_title("Revenue by Cohort and Period");

The pattern looks very similar to that of the number of active users. Hence, we expect the quotient `revenue_per_active_users` to be relatively stable across cohorts.

In [None]:
fig, ax = plt.subplots(figsize=(18, 7))

fmt = lambda y, _: f"{y :0.0f}"

(
    train_data_df.assign(
        cohort=lambda df: df["cohort"].dt.strftime("%Y-%m"),
        period=lambda df: df["period"].dt.strftime("%Y-%m"),
    )
    .query("cohort_age != 0")
    .filter(["cohort", "period", "revenue_per_active_users"])
    .pivot(index="cohort", columns="period", values="revenue_per_active_users")
    .pipe(
        (sns.heatmap, "data"),
        cmap="viridis_r",
        linewidths=0.2,
        linecolor="black",
        annot=True,
        annot_kws={"fontsize": 9},
        fmt="0.0f",
        cbar_kws={"format": mtick.FuncFormatter(fmt)},
        ax=ax,
    )
)

ax.set_title("Revenue per Active User by Cohort and Period");

Observe that this ratio does not show a strong seasonality pattern. This suggests that the revenue model does not need a separate seasonal term: the seasonality can enter through the retention rate component. In addition, note that `revenue_per_active_users` seems to decrease linearly with `cohort_age`. Similarly, it seems to increase linearly with the `age` of the cohort.

Finally, we plot `revenue_per_users`, which also includes users who are not active.

In [None]:
fig, ax = plt.subplots(figsize=(18, 7))

fmt = lambda y, _: f"{y :0.0f}"

(
    train_data_df.assign(
        cohort=lambda df: df["cohort"].dt.strftime("%Y-%m"),
        period=lambda df: df["period"].dt.strftime("%Y-%m"),
    )
    .query("cohort_age != 0")
    .filter(["cohort", "period", "revenue_per_users"])
    .pivot(index="cohort", columns="period", values="revenue_per_users")
    .pipe(
        (sns.heatmap, "data"),
        cmap="viridis_r",
        linewidths=0.2,
        linecolor="black",
        annot=True,
        annot_kws={"fontsize": 9},
        fmt="0.0f",
        cbar_kws={"format": mtick.FuncFormatter(fmt)},
        ax=ax,
    )
)

ax.set_title("Revenue per User by Cohort and Period");

It is no surprise that we observe the seasonal component again (as the cohort size is fixed).
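
The reason is an identity: revenue per user equals retention times revenue per active user. A minimal numerical sketch with hypothetical numbers (a fixed cohort size, seasonal activity and a flat revenue of 10 per active user, all invented for illustration):

```python
import numpy as np

n_users = np.array([100, 100, 100])      # fixed cohort size
n_active_users = np.array([50, 20, 40])  # hypothetical seasonal activity
revenue = 10.0 * n_active_users          # flat revenue per active user

retention = n_active_users / n_users
revenue_per_active_users = revenue / n_active_users
revenue_per_users = revenue / n_users

# revenue / n_users = (n_active / n_users) * (revenue / n_active)
identity_holds = np.allclose(revenue_per_users, retention * revenue_per_active_users)
```

Even though revenue per active user is constant here, revenue per user inherits the ups and downs of the retention rate.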

## Model

Motivated by the analysis above, we suggest the following retention-revenue model.

\begin{align*}
\text{Revenue} & \sim \text{Gamma}(N_{\text{active}}, \lambda) \\
\log(\lambda) & = \text{intercept} \\
    & \quad + \beta_{\text{cohort age}} \, \text{cohort age} \\
    & \quad + \beta_{\text{age}} \, \text{age} \\
    & \quad + \beta_{\text{cohort age} \times \text{age}} \, \text{cohort age} \times \text{age} \\
N_{\text{active}} & \sim \text{Binomial}(N_{\text{total}}, p) \\
\textrm{logit}(p) & = \text{BART}(\text{cohort age}, \text{age}, \text{month})
\end{align*}
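
One way to read the Gamma likelihood (an interpretation added here, not a claim taken from the referenced posts): if each active user's revenue were an independent Exponential variable with rate $\lambda$, the cohort revenue would follow a Gamma distribution with shape $N_{\text{active}}$ and rate $\lambda$, whose mean is $N_{\text{active}} / \lambda$. The simulation below is a sketch with hypothetical values for `n_active` and `lam`.

```python
import numpy as np

rng = np.random.default_rng(42)
n_active, lam = 25, 0.05  # hypothetical number of active users and rate
n_sim = 50_000

# Cohort revenue as a sum of per-user Exponential(rate=lam) revenues.
sum_of_exponentials = rng.exponential(scale=1 / lam, size=(n_sim, n_active)).sum(axis=1)

# Direct draws from Gamma(shape=n_active, rate=lam); NumPy takes scale = 1 / rate.
gamma_draws = rng.gamma(shape=n_active, scale=1 / lam, size=n_sim)

theoretical_mean = n_active / lam      # shape / rate
theoretical_var = n_active / lam**2    # shape / rate^2
```

Both sets of draws should agree with the theoretical mean and variance up to simulation noise, which is what makes $1 / \lambda$ interpretable as the expected revenue per active user.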

### Data Transformations

We do similar transformations as in the previous posts.

In [None]:
eps = np.finfo(float).eps
train_data_red_df = train_data_df.query("cohort_age > 0").reset_index(drop=True)
train_obs_idx = train_data_red_df.index.to_numpy()
train_n_users = train_data_red_df["n_users"].to_numpy()
train_n_active_users = train_data_red_df["n_active_users"].to_numpy()
train_retention = train_data_red_df["retention"].to_numpy()
train_retention_logit = logit(train_retention + eps)
train_data_red_df["month"] = train_data_red_df["period"].dt.strftime("%m").astype(int)
train_data_red_df["cohort_month"] = (
    train_data_red_df["cohort"].dt.strftime("%m").astype(int)
)
train_data_red_df["period_month"] = (
    train_data_red_df["period"].dt.strftime("%m").astype(int)
)
train_revenue = train_data_red_df["revenue"].to_numpy() + eps
train_revenue_per_user = train_revenue / (train_n_active_users + eps)

train_cohort = train_data_red_df["cohort"].to_numpy()
train_cohort_encoder = LabelEncoder()
train_cohort_idx = train_cohort_encoder.fit_transform(train_cohort).flatten()
train_period = train_data_red_df["period"].to_numpy()
train_period_encoder = LabelEncoder()
train_period_idx = train_period_encoder.fit_transform(train_period).flatten()

features: list[str] = ["age", "cohort_age", "month"]
x_train = train_data_red_df[features]

train_age = train_data_red_df["age"].to_numpy()
train_age_scaler = MaxAbsScaler()
train_age_scaled = train_age_scaler.fit_transform(train_age.reshape(-1, 1)).flatten()
train_cohort_age = train_data_red_df["cohort_age"].to_numpy()
train_cohort_age_scaler = MaxAbsScaler()
train_cohort_age_scaled = train_cohort_age_scaler.fit_transform(
    train_cohort_age.reshape(-1, 1)
).flatten()
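
The small `eps` shift matters for the logit transform. The sketch below re-implements the log-odds in NumPy (the local `logit` is a stand-in for `scipy.special.logit`, and `toy_retention` is hypothetical) to show that a retention of exactly zero maps to `-inf` unless it is shifted.

```python
import numpy as np


def logit(p):
    """Log-odds transform log(p / (1 - p))."""
    p = np.asarray(p, dtype=float)
    return np.log(p) - np.log1p(-p)


eps = np.finfo(float).eps
toy_retention = np.array([0.0, 0.25, 0.5])  # hypothetical retention rates

with np.errstate(divide="ignore"):
    raw = logit(toy_retention)        # first entry is -inf
shifted = logit(toy_retention + eps)  # finite everywhere
```

Note that the shift only guards the lower boundary: a retention of exactly one would still fall outside the domain of the logit after adding `eps`.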

### Model Specification

Now we are ready to specify the model in PyMC.
- For the retention component please see the details presented in the post [Cohort Retention Analysis with BART](https://juanitorduz.github.io/retention_bart/).
- The retention-revenue coupling is motivated by the model presented in the example notebook [Introduction to Bayesian A/B Testing](https://www.pymc.io/projects/examples/en/latest/case_studies/bayesian_ab_testing_introduction.html).

In [None]:
with pm.Model(coords={"feature": features}) as model:

    # --- Data ---
    model.add_coord(name="obs", values=train_obs_idx, mutable=True)
    age_scaled = pm.MutableData(name="age_scaled", value=train_age_scaled, dims="obs")
    cohort_age_scaled = pm.MutableData(
        name="cohort_age_scaled", value=train_cohort_age_scaled, dims="obs"
    )
    x = pm.MutableData(name="x", value=x_train, dims=("obs", "feature"))
    n_users = pm.MutableData(name="n_users", value=train_n_users, dims="obs")
    n_active_users = pm.MutableData(
        name="n_active_users", value=train_n_active_users, dims="obs"
    )
    revenue = pm.MutableData(name="revenue", value=train_revenue, dims="obs")

    # --- Priors ---
    intercept = pm.Normal(name="intercept", mu=0, sigma=1)
    b_age_scaled = pm.Normal(name="b_age_scaled", mu=0, sigma=1)
    b_cohort_age_scaled = pm.Normal(name="b_cohort_age_scaled", mu=0, sigma=1)
    b_age_cohort_age_interaction = pm.Normal(
        name="b_age_cohort_age_interaction", mu=0, sigma=1
    )

    # --- Parametrization ---
    # The BART component models the image of the retention rate under the
    # logit transform so that the range is not constrained to [0, 1].
    mu = pmb.BART(name="mu", X=x, Y=train_retention_logit, m=50, dims="obs")
    # We use the inverse logit transform to get the retention rate back into [0, 1].
    p = pm.Deterministic(name="p", var=pm.math.invlogit(mu), dims="obs")
    # We add a small epsilon to avoid numerical issues.
    p = pt.switch(pt.eq(p, 0), eps, p)
    p = pt.switch(pt.eq(p, 1), 1 - eps, p)

    # For the revenue component we use a Gamma distribution where we combine the number
    # of estimated active users with the average revenue per user.
    lam_log = pm.Deterministic(
        name="lam_log",
        var=intercept
        + b_age_scaled * age_scaled
        + b_cohort_age_scaled * cohort_age_scaled
        + b_age_cohort_age_interaction * age_scaled * cohort_age_scaled,
        dims="obs",
    )

    lam = pm.Deterministic(name="lam", var=pm.math.exp(lam_log), dims="obs")

    # --- Likelihood ---
    n_active_users_estimated = pm.Binomial(
        name="n_active_users_estimated",
        n=n_users,
        p=p,
        observed=n_active_users,
        dims="obs",
    )

    revenue_estimated = pm.Gamma(
        name="revenue_estimated",
        alpha=n_active_users_estimated + eps,
        beta=lam,
        observed=revenue,
        dims="obs",
    )

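    # The Gamma mean is alpha / beta = n_active / lam, so 1 / lam is the expected
    # revenue per *active* user and p / lam is the expected revenue per user.
    mean_revenue_per_active_user = pm.Deterministic(
        name="mean_revenue_per_active_user", var=(1 / lam), dims="obs"
    )
    pm.Deterministic(
        name="mean_revenue_per_user", var=p * mean_revenue_per_active_user, dims="obs"
    )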

pm.model_to_graphviz(model=model)
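
To make the revenue component concrete, the following NumPy sketch evaluates the log-linear rate for made-up coefficient values (all numbers below are hypothetical and are not posterior estimates) and derives the implied expected revenue. It relies only on the fact that a Gamma distribution with shape `alpha` and rate `beta` has mean `alpha / beta`.

```python
import numpy as np

# Hypothetical coefficient values, chosen only for illustration.
intercept, b_age, b_cohort_age, b_interaction = -3.0, 0.5, -0.2, 0.1

age_scaled = np.array([1.0, 0.5])
cohort_age_scaled = np.array([0.2, 0.8])
p = np.array([0.4, 0.25])        # hypothetical retention probabilities
n_users = np.array([1_000, 800])

# Log link: the exponential keeps the rate strictly positive.
lam = np.exp(
    intercept
    + b_age * age_scaled
    + b_cohort_age * cohort_age_scaled
    + b_interaction * age_scaled * cohort_age_scaled
)

expected_active = n_users * p              # Binomial mean
expected_revenue = expected_active / lam   # Gamma mean = shape / rate
revenue_per_active_user = 1 / lam
revenue_per_user = p / lam
```

Dividing `expected_revenue` by `n_users` recovers `revenue_per_user`, which is how the retention and revenue components are coupled.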

### Model Fitting

Now we are ready to fit the model.

In [None]:
with model:
    idata = pm.sample(draws=2_000, chains=4, random_seed=rng)
    posterior_predictive = pm.sample_posterior_predictive(trace=idata, random_seed=rng)


### Model Diagnostics

We look into the posterior predictive check:

In [None]:
ax = az.plot_ppc(
    data=posterior_predictive,
    kind="cumulative",
    observed_rug=True,
    random_seed=random_seed_int,
)
ax[0].set(
    title="Posterior Predictive Check (Retention)",
    xscale="log",
    xlabel="likelihood (n_active_users) - log scale",
)
ax[1].set(
    title="Posterior Predictive Check (Revenue)",
    xscale="log",
    xlabel="likelihood (revenue) - log scale",
    xlim=(1, None), # to avoid plotting the clipped value `eps`.
);

In [None]:
idata.sample_stats["diverging"].sum().item()

In [None]:
_ = az.plot_trace(
    data=idata,
    var_names=[
        "intercept",
        "b_age_scaled",
        "b_cohort_age_scaled",
        "b_age_cohort_age_interaction",
    ],
    compact=True,
    kind="rank_bars",
    backend_kwargs={"figsize": (12, 10), "layout": "constrained"},
)
plt.gcf().suptitle("Model Trace", fontsize=16);

The model seems to be doing a good job 🙂 ! 

### Retention Rate In-Sample Predictions

Let's see how the model performs in-sample. We plot the retention rate posterior mean predictions for the training data:

In [None]:
train_posterior_retention = (
    posterior_predictive.posterior_predictive["n_active_users_estimated"]
    / train_n_users[np.newaxis, None]
)
train_posterior_retention_mean = az.extract(
    data=train_posterior_retention, var_names=["n_active_users_estimated"]
).mean("sample")

fig, ax = plt.subplots(figsize=(10, 9))
sns.scatterplot(
    x="retention",
    y="posterior_retention_mean",
    data=train_data_red_df.assign(
        posterior_retention_mean=train_posterior_retention_mean
    ),
    hue="age",
    palette="viridis_r",
    size="n_users",
    ax=ax,
)
ax.axline(xy1=(0.3, 0.3), slope=1, color="black", linestyle="--", label="diagonal")
ax.legend()
ax.set(title="Posterior Predictive - Retention Mean");
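
The division by `train_n_users[np.newaxis, None]` relies on broadcasting over the chain and draw dimensions. A minimal sketch with hypothetical shapes and values (the `toy_` arrays are invented for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
n_chains, n_draws = 2, 5
toy_n_users = np.array([100, 50, 200])

# Hypothetical posterior predictive draws of active users: (chain, draw, obs).
toy_active = rng.binomial(
    n=toy_n_users, p=0.3, size=(n_chains, n_draws, toy_n_users.size)
)

# np.newaxis is None, so this indexing prepends two axes: shape (1, 1, obs).
toy_n_users_3d = toy_n_users[np.newaxis, None]

# Broadcasting divides every draw by the matching cohort size.
toy_retention = toy_active / toy_n_users_3d
toy_retention_mean = toy_retention.mean(axis=(0, 1))
```

Averaging over the first two axes plays the same role as `.mean("sample")` after stacking chains and draws.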

In [None]:
train_posterior_revenue_mean = az.extract(
    data=posterior_predictive,
    group="posterior_predictive",
    var_names=["revenue_estimated"],
).mean("sample")

fig, ax = plt.subplots(figsize=(10, 9))
sns.scatterplot(
    x="revenue",
    y="posterior_revenue_mean",
    data=train_data_red_df.assign(posterior_revenue_mean=train_posterior_revenue_mean),
    hue="age",
    palette="viridis_r",
    size="n_users",
    ax=ax,
)
ax.axline(xy1=(1e5, 1e5), slope=1, color="black", linestyle="--", label="diagonal")
ax.legend()
ax.set(
    title="Posterior Predictive - Revenue Mean",
    xscale="log",
    yscale="log",
    xlabel="revenue (log)",
    ylabel="posterior_revenue_mean (log)",
);

Next, we look into the uncertainty estimates for a subset of individual cohorts:

In [None]:
train_retention_hdi = az.hdi(ary=train_posterior_retention)["n_active_users_estimated"]


def plot_train_retention_hdi_cohort(cohort_index: int, ax: plt.Axes) -> plt.Axes:

    mask = train_cohort_idx == cohort_index

    ax.fill_between(
        x=train_period[mask],
        y1=train_retention_hdi[mask, :][:, 0],
        y2=train_retention_hdi[mask, :][:, 1],
        alpha=0.3,
        color="C0",
        label="94% HDI (train)",
    )
    sns.lineplot(
        x=train_period[mask],
        y=train_retention[mask],
        color="C0",
        marker="o",
        label="observed retention (train)",
        ax=ax,
    )
    cohort_name = (
        pd.to_datetime(train_cohort_encoder.classes_[cohort_index]).date().isoformat()
    )
    ax.legend(loc="upper left")
    ax.set(title=f"Retention HDI - Cohort {cohort_name}")
    return ax


cohort_index_to_plot = [0, 1, 5, 10, 15, 20, 25, 30]

fig, axes = plt.subplots(
    nrows=np.ceil(len(cohort_index_to_plot) / 2).astype(int),
    ncols=2,
    figsize=(15, 10),
    sharex=True,
    sharey=True,
    layout="constrained",
)

for cohort_index, ax in zip(cohort_index_to_plot, axes.flatten()):
    plot_train_retention_hdi_cohort(cohort_index=cohort_index, ax=ax)

As in the linear model case, we are capturing the retention rate development over time. The uncertainty estimates are also quite similar to the linear model.
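
The shaded bands are highest density intervals (HDI). As a rough illustration of the idea, here is a simplified sketch for one-dimensional, unimodal samples: it searches for the narrowest window of sorted draws that contains about 94% of them. The function `hdi_1d` is a hypothetical helper written for this note and is not ArviZ's implementation.

```python
import numpy as np


def hdi_1d(samples, prob=0.94):
    """Narrowest interval containing roughly `prob` of the samples."""
    x = np.sort(np.asarray(samples, dtype=float))
    n = x.size
    k = int(np.floor(prob * n))   # the interval spans k + 1 ordered points
    widths = x[k:] - x[: n - k]   # width of every candidate window
    i = int(np.argmin(widths))
    return x[i], x[i + k]


rng = np.random.default_rng(1)
draws = rng.gamma(shape=2.0, scale=1.0, size=20_000)  # hypothetical skewed draws

low, high = hdi_1d(draws)
coverage = np.mean((draws >= low) & (draws <= high))

# Equal-tailed interval with the same nominal mass, for comparison.
q_low, q_high = np.quantile(draws, [0.03, 0.97])
```

For skewed distributions such as this one, the HDI is narrower than the equal-tailed interval with the same mass.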

In [None]:
train_revenue_hdi = az.hdi(ary=posterior_predictive.posterior_predictive)["revenue_estimated"]


def plot_train_revenue_hdi_cohort(cohort_index: int, ax: plt.Axes) -> plt.Axes:

    mask = train_cohort_idx == cohort_index

    ax.fill_between(
        x=train_period[mask],
        y1=train_revenue_hdi[mask, :][:, 0],
        y2=train_revenue_hdi[mask, :][:, 1],
        alpha=0.3,
        color="C0",
        label="94% HDI (train)",
    )
    sns.lineplot(
        x=train_period[mask],
        y=train_revenue[mask],
        color="C0",
        marker="o",
        label="observed revenue (train)",
        ax=ax,
    )
    cohort_name = (
        pd.to_datetime(train_cohort_encoder.classes_[cohort_index]).date().isoformat()
    )
    ax.legend(loc="upper left")
    ax.set(title=f"Revenue HDI - Cohort {cohort_name}")
    return ax


cohort_index_to_plot = [0, 1, 5, 10, 15, 20, 25, 30]

fig, axes = plt.subplots(
    nrows=np.ceil(len(cohort_index_to_plot) / 2).astype(int),
    ncols=2,
    figsize=(15, 10),
    sharex=True,
    sharey=False,
    layout="constrained",
)

for cohort_index, ax in zip(cohort_index_to_plot, axes.flatten()):
    plot_train_revenue_hdi_cohort(cohort_index=cohort_index, ax=ax)

fig.suptitle("Revenue Predictions", y=1.03, fontsize=16);

## Predictions

Now we transform the test data to the same format as the training data and use the model to predict the retention rates and the revenue. Note that we are using the scalers and encoders from the training data.

### Data Transformations

In [None]:
test_data_red_df = test_data_df.query("cohort_age > 0")
test_data_red_df = test_data_red_df[
    test_data_red_df["cohort"].isin(train_data_red_df["cohort"].unique())
].reset_index(drop=True)
test_obs_idx = test_data_red_df.index.to_numpy()
test_n_users = test_data_red_df["n_users"].to_numpy()
test_n_active_users = test_data_red_df["n_active_users"].to_numpy()
test_retention = test_data_red_df["retention"].to_numpy()
test_revenue = test_data_red_df["revenue"].to_numpy()

test_cohort = test_data_red_df["cohort"].to_numpy()
test_cohort_idx = train_cohort_encoder.transform(test_cohort).flatten()

test_data_red_df["month"] = test_data_red_df["period"].dt.strftime("%m").astype(int)
test_data_red_df["cohort_month"] = test_data_red_df["cohort"].dt.strftime("%m").astype(int)
test_data_red_df["period_month"] = test_data_red_df["period"].dt.strftime("%m").astype(int)
x_test = test_data_red_df[features]

test_age = test_data_red_df["age"].to_numpy()
test_age_scaled = train_age_scaler.transform(test_age.reshape(-1, 1)).flatten()
test_cohort_age = test_data_red_df["cohort_age"].to_numpy()
test_cohort_age_scaled = train_cohort_age_scaler.transform(
    test_cohort_age.reshape(-1, 1)
).flatten()
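
Reusing the training scalers means the test features are divided by the maximum seen during training, not by their own maximum. A NumPy sketch of max-abs scaling on hypothetical cohort ages (the arrays below are illustrative; the division mirrors what `MaxAbsScaler` does for a single feature):

```python
import numpy as np

toy_train_cohort_age = np.array([10.0, 20.0, 40.0])
toy_test_cohort_age = np.array([40.0, 50.0])

# "Fit": store the largest absolute value of the training feature.
train_max_abs = np.abs(toy_train_cohort_age).max()

# "Transform": both splits are divided by the training statistic.
toy_train_scaled = toy_train_cohort_age / train_max_abs
toy_test_scaled = toy_test_cohort_age / train_max_abs
```

Because cohorts keep aging after the split date, scaled test values can exceed one. That is expected and keeps the regression coefficients on the scale they were learned on.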

### Out-of-Sample Posterior Predictions

Now we want to see out-of-sample predictions from this model. To begin, we need to compute the posterior predictive distribution for the test data.

In [None]:
with model:
    pm.set_data(
        new_data={
            "age_scaled": test_age_scaled,
            "cohort_age_scaled": test_cohort_age_scaled,
            "x": x_test,
            "n_users": test_n_users,
            "n_active_users": np.ones_like(
                test_n_active_users
            ),  # Dummy data to make coords work! We are not using this at prediction time!
            "revenue": np.ones_like(
                test_revenue
            ),  # Dummy data to make coords work! We are not using this at prediction time!
        },
        coords={"obs": test_obs_idx},
    )
    idata.extend(
        pm.sample_posterior_predictive(
            trace=idata,
            var_names=[
                "p",
                "mu",
                "n_active_users_estimated",
                "revenue_estimated",
                "mean_revenue_per_user",
                "mean_revenue_per_active_user",
            ],
            idata_kwargs={"coords": {"obs": test_obs_idx}},
            random_seed=rng,
        )
    )

### Retention Rate Out-of-Sample Predictions

Finally, we compute the posterior retention rate and revenue distributions for the test data and visualize the results.

In [None]:
test_posterior_retention = (
    idata.posterior_predictive["n_active_users_estimated"] / test_n_users[np.newaxis, None]
)

test_retention_hdi = az.hdi(ary=test_posterior_retention)["n_active_users_estimated"]
test_revenue_hdi = az.hdi(ary=idata.posterior_predictive)["revenue_estimated"]

In [None]:
def plot_test_retention_hdi_cohort(cohort_index: int, ax: plt.Axes) -> plt.Axes:
    mask = test_cohort_idx == cohort_index

    test_period_range = test_data_red_df.query(
        f"cohort == '{train_cohort_encoder.classes_[cohort_index]}'"
    )["period"]

    ax.fill_between(
        x=test_period_range,
        y1=test_retention_hdi[mask, :][:, 0],
        y2=test_retention_hdi[mask, :][:, 1],
        alpha=0.3,
        color="C1",
        label="94% HDI (test)",
    )
    sns.lineplot(
        x=test_period_range,
        y=test_retention[mask],
        color="C1",
        marker="o",
        label="observed retention (test)",
        ax=ax,
    )
    return ax

In [None]:
cohort_index_to_plot = [0, 1, 5, 10, 15, 20, 25, 30]

fig, axes = plt.subplots(
    nrows=len(cohort_index_to_plot),
    ncols=1,
    figsize=(12, 15),
    sharex=True,
    sharey=True,
    layout="constrained",
)

for cohort_index, ax in zip(cohort_index_to_plot, axes.flatten()):
    plot_train_retention_hdi_cohort(cohort_index=cohort_index, ax=ax)
    plot_test_retention_hdi_cohort(cohort_index=cohort_index, ax=ax)
    ax.axvline(
        x=pd.to_datetime(period_train_test_split),
        color="black",
        linestyle="--",
        label="train/test split",
    )
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
fig.suptitle("Retention Predictions", y=1.03, fontsize=16);

In [None]:
def plot_test_revenue_hdi_cohort(cohort_index: int, ax: plt.Axes) -> plt.Axes:
    mask = test_cohort_idx == cohort_index

    test_period_range = test_data_red_df.query(
        f"cohort == '{train_cohort_encoder.classes_[cohort_index]}'"
    )["period"]

    ax.fill_between(
        x=test_period_range,
        y1=test_revenue_hdi[mask, :][:, 0],
        y2=test_revenue_hdi[mask, :][:, 1],
        alpha=0.3,
        color="C1",
        label="94% HDI (test)",
    )
    sns.lineplot(
        x=test_period_range,
        y=test_revenue[mask],
        color="C1",
        marker="o",
        label="observed revenue (test)",
        ax=ax,
    )
    return ax

In [None]:
cohort_index_to_plot = [0, 1, 5, 10, 15, 20, 25, 30]

fig, axes = plt.subplots(
    nrows=len(cohort_index_to_plot),
    ncols=1,
    figsize=(12, 15),
    sharex=True,
    sharey=False,
    layout="constrained",
)

for cohort_index, ax in zip(cohort_index_to_plot, axes.flatten()):
    plot_train_revenue_hdi_cohort(cohort_index=cohort_index, ax=ax)
    plot_test_revenue_hdi_cohort(cohort_index=cohort_index, ax=ax)
    ax.axvline(
        x=pd.to_datetime(period_train_test_split),
        color="black",
        linestyle="--",
        label="train/test split",
    )
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
fig.suptitle("Revenue Predictions", y=1.03, fontsize=16);