## Projecting Cohort Fertility Rates

This notebook uses PyMC to model and predict fertility rates in the United States, using preprocessed data from the Current Population Survey (CPS) conducted by the US Census Bureau.

Background article on fertility rates: "[Why the total fertility rate doesn't necessarily tell us the number of births women eventually have](https://ourworldindata.org/total-fertility-rate-births-per-woman)" 

**Note:** Run `process_cps.ipynb` first to generate the preprocessed data file.

[Click here to run this notebook on Colab](https://colab.research.google.com/github/AllenDowney/BayesFertility/blob/main/fertility_cps.ipynb)

In [None]:
from os.path import basename, exists


def download(url):
    filename = basename(url)
    if not exists(filename):
        from urllib.request import urlretrieve

        local, _ = urlretrieve(url, filename)
        print("Downloaded " + local)


download("https://github.com/AllenDowney/BayesFertility/raw/main/utils.py")

In [None]:
import os
import warnings

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az
import pymc as pm

from utils import decorate, value_counts, save_baseline_results, load_baseline_results

## Load Preprocessed Data

Load the preprocessed data from the preprocessing notebook.

In [None]:
filename = "../data/fertility_cps_preprocessed.h5"

# Load the aggregated data
sum_df = pd.read_hdf(filename, key='sum_df')
count_df = pd.read_hdf(filename, key='count_df')
cfr_cps = pd.read_hdf(filename, key='cfr_cps')

# Load labels
age_labels = pd.read_hdf(filename, key='age_labels').values
cohort_labels = pd.read_hdf(filename, key='cohort_labels').values

# Load metadata
metadata = pd.read_hdf(filename, key='metadata')
cutoff_year = int(metadata['cutoff_year'])

print(f"Loaded preprocessed data from {filename}")
print(f"Cutoff year: {cutoff_year}")
print(f"Shape: {sum_df.shape}")

## The Model

For the most recent generations, we have limited data.
To project what future fertility rates will look like, we'll use a model to estimate cohort and age effects, then use the model to generate predictions.

The following PyMC model is based on a log-linear model of age and cohort effects: each cohort has a latent value, $\alpha$, that indicates its overall proclivity to have children, and each age group has a latent value, $\beta$, that indicates the tendency of women to have children at that age.

The average number of children born to each woman in cohort $i$ during age group $j$ (a 3-year span) is $\lambda_{ij} = \exp(\alpha_i + \beta_j)$.
So the average parity of a cohort at a given age is the cumulative sum of these rates up to that age, $\sum_{k \le j} \lambda_{ik}$.
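To make the rate-to-parity step concrete, here's a small NumPy sketch that mirrors the `alpha[:, None] + beta[None, :]` and `cumsum` steps in `make_model`. The effect values are made up for illustration; they are not estimates from the data.

```python
import numpy as np

# Hypothetical effects for 3 cohorts and 4 age groups (illustrative values only)
alpha = np.array([0.2, 0.0, -0.2])         # cohort effects
beta = np.array([-2.0, -1.0, -1.2, -2.5])  # age effects

# Broadcasting builds the cohort-by-age grid: lam[i, j] = exp(alpha[i] + beta[j])
lam = np.exp(alpha[:, None] + beta[None, :])

# Expected parity at each age is the running total of the rates up to that age
parity = np.cumsum(lam, axis=1)

print(lam.shape)      # (3, 4)
print(parity[:, -1])  # expected parity after the last age group, one per cohort
```

Because the rates are positive, expected parity can only increase with age, which matches the meaning of parity as a count of children born so far.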

Both $\alpha$ and $\beta$ are modeled with a Gaussian random walk, which means we expect the differences between successive cohorts, and between successive age groups, to follow a Gaussian distribution.
As a result, each value is estimated relative to its predecessor, so in the absence of enough data to infer a change, each value is presumed to be unchanged from its predecessor.
We'll see the consequences of this structure in the results.
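A Gaussian random walk is easy to simulate: start from an initial value and add independent Gaussian steps. This sketch uses NumPy rather than `pm.GaussianRandomWalk`; the starting distribution `Normal(0, 0.4)` matches `init_dist` in `make_model`, and the step scale `0.05` is an assumed value for illustration.

```python
import numpy as np

rng = np.random.default_rng(17)

# Assumed settings for illustration: start ~ Normal(0, 0.4), step scale 0.05
n_cohorts = 10
step_sigma = 0.05
start = rng.normal(0, 0.4)
steps = rng.normal(0, step_sigma, size=n_cohorts - 1)

# Each value is its predecessor plus a Gaussian step
walk = np.concatenate([[start], start + np.cumsum(steps)])

# Differences between successive values recover the Gaussian steps
print(np.allclose(np.diff(walk), steps))  # True
```

Since each step has mean 0, the best guess for the next value, absent other information, is the current value.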

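The values of $\alpha$ are softly constrained so their mean is close to 0: a `pm.Potential` adds a very narrow Normal penalty on the mean of $\alpha$.
This is useful because the location of the $\alpha$ and $\beta$ coefficients is arbitrary: adding a constant to every $\alpha$ and subtracting it from every $\beta$ leaves every $\lambda$ unchanged, so the constraint establishes a zero point.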
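To see why the location is arbitrary, here's a quick NumPy check with made-up effects. Shifting all cohort effects up and all age effects down by the same amount produces identical rates, and centering $\alpha$ just picks one member of that family. The helper `rates` is hypothetical, defined here for the example.

```python
import numpy as np


def rates(alpha, beta):
    """Cohort-by-age rates under the log-linear model."""
    return np.exp(alpha[:, None] + beta[None, :])


# Hypothetical effects (illustrative values only)
alpha = np.array([0.5, 0.1, -0.3])
beta = np.array([-1.5, -0.8, -2.0])

# Shift every cohort effect up by c and every age effect down by c
c = 0.7
print(np.allclose(rates(alpha, beta), rates(alpha + c, beta - c)))  # True

# Centering alpha picks one member of this family; beta absorbs the shift
shift = alpha.mean()
alpha_c, beta_c = alpha - shift, beta + shift
print(np.isclose(alpha_c.mean(), 0.0))                          # True
print(np.allclose(rates(alpha, beta), rates(alpha_c, beta_c)))  # True
```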

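Finally, given the number of women in a cohort-age group and their expected parity, we expect the total number of children they report to follow a Poisson distribution whose mean is the number of women times the expected parity.
Cohort-age groups with no respondents are left out of the likelihood.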
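The likelihood can be sketched without PyMC. If each woman's parity is Poisson with mean equal to the expected parity, then the total reported by $n$ women is Poisson with mean $n$ times the expected parity, because a sum of independent Poisson variables is Poisson. The grid below is made up, and `poisson_logpmf` is a hypothetical helper written with the standard library, not PyMC's implementation.

```python
import math

import numpy as np

# Hypothetical grid: 2 cohorts x 3 age groups
count = np.array([[100, 80, 0],   # 0 women sampled: unobserved cell
                  [120, 0, 0]])
total_births = np.array([[40, 95, 0],
                         [50, 0, 0]])
expected_parity = np.array([[0.40, 1.1, 1.8],
                            [0.45, 1.0, 1.6]])


def poisson_logpmf(k, mu):
    """log P(K = k) for K ~ Poisson(mu)."""
    return k * math.log(mu) - mu - math.lgamma(k + 1)


# Only cells with respondents contribute, mirroring mask = count_array != 0
mask = count != 0
mu = (count * expected_parity)[mask]
k = total_births[mask]
log_lik = sum(poisson_logpmf(int(ki), float(mi)) for ki, mi in zip(k, mu))
print(int(mask.sum()), log_lik)
```

Dropping the masked cells, rather than treating them as zero births, keeps unobserved cohort-age groups from pulling the estimates toward zero.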

In [None]:
def make_model(sum_array, count_array, random_walk_sigma=0.1):
    with pm.Model() as model:
        n_cohorts, n_ages = sum_array.shape

        # Random walk prior for cohort effects with mean constraint
        sigma_alpha = pm.HalfNormal("sigma_alpha", sigma=random_walk_sigma)
        alpha = pm.GaussianRandomWalk(
            "alpha",
            sigma=sigma_alpha,
            shape=n_cohorts,
            init_dist=pm.Normal.dist(mu=0, sigma=0.4),
        )

        # Soft constraint to enforce mean zero without hard subtraction
        pm.Potential(
            "zero_mean_constraint", pm.logp(pm.Normal.dist(0, 0.001), 
                                            pm.math.mean(alpha))
        )

        # Random walk prior for age effects
        sigma_beta = pm.HalfNormal("sigma_beta", sigma=0.1)
        beta = pm.GaussianRandomWalk(
            "beta",
            sigma=sigma_beta, 
            shape=n_ages,
            init_dist=pm.Normal.dist(mu=0, sigma=0.4)
        )

        # Log-linear model for the age-specific birth rates (ASBRs)
        log_lambda = alpha[:, None] + beta[None, :]
        lambda_ = pm.Deterministic("lambda", pm.math.exp(log_lambda))

        # Observed parity depends on the cumulative sum of ASBRs
        cumulative_lambda = pm.math.cumsum(lambda_, axis=1)

        # Likelihood, ignoring unobserved cohort-age pairs
        mask = count_array != 0
        y_obs = pm.Poisson(
            "y_obs", mu=(count_array * cumulative_lambda)[mask], observed=sum_array[mask]
        )
        
        return model

In [None]:
count_array = count_df.to_numpy()
sum_array = sum_df.to_numpy()

To see whether the prior distributions make sense, we'll look at the prior predictive distribution of $\lambda$.

In [None]:
# TODO: Make the model estimate this
random_walk_sigma = 0.05

In [None]:
model = make_model(sum_array, count_array, random_walk_sigma)

In [None]:
with model:
    prior_predictive = pm.sample_prior_predictive(1000)

lambda_prior_samples = (
    prior_predictive.prior["lambda"].stack(sample=("chain", "draw")).values
)

In [None]:
pd.Series(lambda_prior_samples.flatten()).describe()
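As a cross-check on the prior predictive summary, here's a rough NumPy re-creation of the priors in `make_model`. It's a sketch under stated assumptions: it ignores the soft mean-zero constraint on $\alpha$, the grid size is hypothetical (the real one comes from `sum_df.shape`), and `simulate_prior_lambda` is a helper defined only for this example.

```python
import numpy as np


def simulate_prior_lambda(n_cohorts, n_ages, random_walk_sigma=0.05,
                          n_draws=1000, seed=17):
    """Draw lambda from priors resembling make_model (mean-zero constraint ignored)."""
    rng = np.random.default_rng(seed)
    draws = np.empty((n_draws, n_cohorts, n_ages))
    for d in range(n_draws):
        # HalfNormal scales for the random-walk steps: |Normal(0, sigma)|
        sigma_alpha = abs(rng.normal(0, random_walk_sigma))
        sigma_beta = abs(rng.normal(0, 0.1))
        # Random walks: Normal(0, 0.4) starting value plus cumulative Gaussian steps
        alpha = rng.normal(0, 0.4) + np.concatenate(
            [[0.0], np.cumsum(rng.normal(0, sigma_alpha, n_cohorts - 1))]
        )
        beta = rng.normal(0, 0.4) + np.concatenate(
            [[0.0], np.cumsum(rng.normal(0, sigma_beta, n_ages - 1))]
        )
        draws[d] = np.exp(alpha[:, None] + beta[None, :])
    return draws


# Hypothetical grid size for illustration
lam = simulate_prior_lambda(n_cohorts=25, n_ages=8)
print(np.percentile(lam, [5, 50, 95]))
```

If the percentiles look very different from the PyMC summary, that suggests the sketch and the model disagree about the priors.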

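Now let's sample the posterior distribution.
In general, the model samples well; the summary below shows the posterior means of the random-walk scales along with the convergence diagnostics `r_hat`, `ess_bulk`, and `ess_tail`.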

In [None]:
from utils import load_idata_or_sample

idata_filename = f"nc/fertility_cps_idata_{cutoff_year}_{random_walk_sigma}.nc"

idata = load_idata_or_sample(
    model, idata_filename, nuts_sampler="nutpie", force_run=True, random_seed=17
)

In [None]:
pm.summary(idata, var_names=["sigma_alpha", "sigma_beta"])

### Cohort effects

Here is a summary of the cohort effects.

In [None]:
alpha_summary = pm.summary(idata, var_names=["alpha"])
alpha_summary

In [None]:
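def forest_plot(summary, labels, **options):
    """Plot posterior means with 94% HDI error bars from a `pm.summary` table.

    Extra keyword arguments (e.g. rotation) are passed to plt.xticks.
    """
    means = summary["mean"].to_numpy()
    hdi_lower = summary["hdi_3%"].to_numpy()
    hdi_upper = summary["hdi_97%"].to_numpy()

    # One x position per cohort or age group
    n_groups = len(means)
    x_positions = np.arange(n_groups)
    plt.xticks(x_positions, labels, **options)

    # Asymmetric error bars: distance from the mean to each HDI bound
    plt.errorbar(
        x_positions,
        means,
        yerr=[means - hdi_lower, hdi_upper - means],
        fmt="o",
        markersize=4,
        capsize=2,
        color="C0",
    )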

And here's what they look like:

In [None]:
forest_plot(alpha_summary, cohort_labels, rotation=45)
plt.ylabel("Effect Size")
plt.xlabel("Cohorts")
plt.title("Cohort Effects")
plt.tight_layout()

Qualitatively, the trend is what we expect to see.
The proclivity to have children is highest in the 1937 and 1940 cohorts, who had their babies at the tail end of the baby boom.
It's mostly unchanged from 1943 to 1982, declines slowly until 1991, and then declines more quickly.

The error bars are wider in the most recent cohorts, where we have less data.
In the most recent cohorts, we see a useful property of the random walk: if the data are not sufficient to show a change, the estimate stays close to the value for the previous cohort.
For example, the central estimates for the 2006 and 2009 cohorts are about the same, because there is not enough data in the 2009 cohort to provide strong evidence of a difference.

This is different from what we would see in a hierarchical model, where, in the absence of sufficient data, estimates are shrunk toward the overall mean rather than toward the most recent value.

The model is conservative in the sense that it does not extrapolate trends: although this proclivity seems to have declined consistently over the last 30 years, the model does not assume that the trend will continue.
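Here's a toy comparison of the two behaviors, treating earlier cohort estimates as if they were known exactly (a simplification; the real model propagates their uncertainty). The values and the step size `0.05` are hypothetical.

```python
import numpy as np

# Hypothetical posterior means for the cohorts that have data
observed = np.array([0.10, 0.05, -0.10, -0.30])
step_sigma = 0.05  # assumed random-walk step size

# Random walk: alpha_next = alpha_last + step with E[step] = 0,
# so with no new data the central estimate stays at the last value (no trend)
rw_center = observed[-1]

# Each additional cohort without data adds one step of variance,
# so the spread grows like step_sigma * sqrt(k) for k cohorts past the data
k = np.arange(1, 4)
rw_sd = step_sigma * np.sqrt(k)

# Hierarchical model: alpha_i ~ Normal(mu, tau) independently,
# so with no data the central estimate is pulled to the population mean
hier_center = observed.mean()

print(rw_center, hier_center, rw_sd)
```

The growing `rw_sd` is one way to see why the error bars widen in the most recent cohorts.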

### Age effects

Here is a summary of the age effects.

In [None]:
beta_summary = pm.summary(idata, var_names=["beta"])
beta_summary

And here's what they look like.

In [None]:
forest_plot(beta_summary, age_labels)

plt.ylabel("Effect Size")
plt.xlabel("Ages")
plt.title("Age Effects")
plt.tight_layout()

Qualitatively, the trends here are what we expect: women are most likely to have children when they are 21-33 years old.

In the age groups where we have less data, the error bars are wider.
And again, the model does not extrapolate trends, so the estimates level off in the oldest groups.

## Prediction

Now we can use the model to generate predictive distributions for each cohort-age group, including retrodictions for the groups where we have data, and predictions for the groups where we have little or none.

In [None]:
with model:
    # lambda is a Deterministic, so this gives its value for each posterior draw
    posterior_predictive = pm.sample_posterior_predictive(
        idata, var_names=["lambda"], random_seed=42
    )

# Shape: (chains, draws, n_cohorts, n_ages)
lambda_samples = posterior_predictive.posterior_predictive["lambda"].values

# Expected parity: cumulative sum of the rates along the age axis
cumulative_lambda_pred = np.cumsum(lambda_samples, axis=-1)

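The result we want is the `cumulative_lambda_pred` array, which contains the expected parity for each cohort-age group in each posterior draw.
To compute the mean, we average over the first two axes, which index the chains and the draws within each chain; `pm.hdi` returns the lower and upper bounds of the 94% highest density interval (its default) for each cohort-age group.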

In [None]:
mean = cumulative_lambda_pred.mean(axis=(0, 1))
hdi = pm.hdi(cumulative_lambda_pred)
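For readers curious about what the HDI computation does, here's a minimal sketch in plain NumPy, in the spirit of the usual approach: sort the samples and find the narrowest window containing the requested fraction. The helpers `hdi_1d` and `summarize_draws` are hypothetical, and the draws are simulated, so this does not reproduce the notebook's results.

```python
import numpy as np


def hdi_1d(samples, prob=0.94):
    """Narrowest interval containing a fraction `prob` of the samples."""
    x = np.sort(np.asarray(samples))
    n = len(x)
    k = int(np.floor(prob * n))  # index offset spanned by each candidate interval
    widths = x[k:] - x[: n - k]  # width of every candidate interval
    i = int(np.argmin(widths))
    return x[i], x[i + k]


def summarize_draws(arr, prob=0.94):
    """Mean and HDI per cell for an array shaped (chain, draw, *shape)."""
    flat = arr.reshape(-1, *arr.shape[2:])  # pool chains and draws
    cell_mean = flat.mean(axis=0)
    bounds = np.apply_along_axis(hdi_1d, 0, flat, prob)  # shape (2, *shape)
    return cell_mean, np.moveaxis(bounds, 0, -1)         # shape (*shape, 2)


# Hypothetical draws: 4 chains x 500 draws for a 3-cohort x 2-age grid
rng = np.random.default_rng(0)
toy_draws = rng.normal(loc=1.5, scale=0.2, size=(4, 500, 3, 2))
toy_mean, toy_hdi = summarize_draws(toy_draws)
print(toy_mean.shape, toy_hdi.shape)  # (3, 2) (3, 2, 2)
```

The output layout, with lower and upper bounds in the last axis, matches how `hdi[:, :, 0]` and `hdi[:, :, 1]` are used below.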

I'll put the results in a `DataFrame` so we can see the cohort and age labels.
Reading across the bottom row, we can see the predictions for women born in 2009.

In [None]:
mean_cumulative_rate = pd.DataFrame(mean, index=cohort_labels, columns=age_labels)
mean_cumulative_rate.tail()

Now let's see what the error bounds look like.

In [None]:
low = pd.DataFrame(hdi[:, :, 0], index=cohort_labels, columns=age_labels)
high = pd.DataFrame(hdi[:, :, 1], index=cohort_labels, columns=age_labels)

The following figures show the completed cohort fertility rate (CFR), which is the average number of children born to women from a specific birth cohort by the end of their reproductive years.
Here we take that to be age 42, the midpoint of the 40-44 age range used by the CPS.

In [None]:
# Define the age for CFR calculation
cfr_age = 42

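We'll use age `cfr_age` for the CFR calculations.
The first figure plots the predicted CFR by birth cohort; the second shifts each cohort by `cfr_age` years, so each prediction appears at the year the cohort reaches that age, and compares it with the CFR observed in the CPS for women aged 40-44.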

In [None]:
cfr_cohort = mean_cumulative_rate[cfr_age]
plt.fill_between(low.index, low[cfr_age], high[cfr_age], alpha=0.1)
cfr_cohort.plot(label="predicted CFR")
decorate(xlabel="Cohort", ylabel="Completed cohort fertility rate", ylim=[0, 3.7])

In [None]:
cfr_shifted = pd.Series(cfr_cohort.values, cfr_cohort.index + cfr_age, copy=True)
plt.fill_between(cfr_shifted.index, low[cfr_age], high[cfr_age], alpha=0.1)
cfr_shifted.plot(alpha=0.6, label=f"Predicted CFR (age {cfr_age})")

cfr_cps.plot(alpha=0.8, label="Actual CFR (ages 40-44)")
decorate(ylabel="Completed cohort fertility rate (CFR)", ylim=[0, 3.5])

The model predicts that CFR will decline steeply, starting with women born around 1980, and reaching very low levels for the most recent cohorts.

The error bounds show 94% credible intervals that account for uncertainty in the estimated coefficients.
We should not take them too seriously, though, because they don't account for a far larger source of uncertainty: any number of future events that could affect these outcomes.

I suggest we think of these results as projections rather than predictions -- that is, they show us what we would expect in the future if the structure of the model is appropriate, the estimated parameters are accurate, and nothing changes in the future that substantially affects the outcome.

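Considering the structure of the model, the most obvious omission is the possibility of a "rebound" effect, where a cohort that has fewer children when they are young goes on to have more children when they are older, closing the CFR gap with previous generations.
Because the model is additive on the log scale, every cohort has the same age profile, scaled by $\exp(\alpha)$, so a cohort with low fertility at young ages is projected to have proportionally low fertility at older ages.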

Of course, that could happen. Possible extensions to the model could either add an interaction term that estimates the rebound in previous generations or add a "what if" parameter that lets us explore the effect of different levels of rebound on future CFR.
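Here's a small NumPy illustration of why the log-linear structure rules out catch-up: the ratio of two cohorts' rates is the same at every age. It then applies a hypothetical "what if" rebound multiplier to the younger cohort's rates at older ages. The effects, the multiplier `1.5`, and the choice of which ages count as "older" are all assumptions for the example, not estimates or the author's proposed extension.

```python
import numpy as np

alpha = np.array([0.0, -0.5])              # an older and a younger cohort (hypothetical)
beta = np.array([-1.5, -0.5, -0.7, -2.0])  # hypothetical age effects, young to old

lam = np.exp(alpha[:, None] + beta[None, :])

# Under the log-linear model the younger cohort's rate is a constant
# fraction of the older cohort's rate at every age: exp(alpha[1] - alpha[0])
ratio = lam[1] / lam[0]
print(np.allclose(ratio, np.exp(-0.5)))  # True: no catch-up at older ages

# Hypothetical "what if" rebound: boost the younger cohort's rates at older ages
rebound = 1.5                               # assumed multiplier, not estimated
older_ages = np.array([False, False, True, True])
lam_rebound = lam.copy()
lam_rebound[1, older_ages] *= rebound

# Completed parity (sum over all age groups) with and without the rebound
print(lam.sum(axis=1), lam_rebound.sum(axis=1))
```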

The results suggest that, unless there is a substantial shift toward higher fertility starting soon, we should expect a large decline in CFR over the next 30 years, comparable in speed and magnitude to the decline at the end of the baby boom.

Of course, predictions 30 years into the future are unlikely to be precise, but if we take the results of the model at face value, it is plausible that CFR will drop below 1.0 between 2040 and 2050, comparable to the current total fertility rate in South Korea.
At that level, barring large-scale immigration, the population of the United States would decline quickly.

We can come back to the question of a rebound, but first let's consider whether the model has adequately captured the structure of the data.

## Compare model to data

The following function plots observed average parity in each cohort-age group along with the estimates from the model.
Where we have data, we can see if the retrodictions fit it.
And where we are missing data, we can see if the predictions seem plausible.
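Observed average parity is the ratio of summed parity to the number of women. Assuming, consistent with the `count_array != 0` mask in `make_model`, that unobserved cells have a count of 0, the division gives 0/0, which pandas turns into NaN rather than raising an error, so those cells are left out of the plot. The toy frames below are hypothetical.

```python
import numpy as np
import pandas as pd

# Hypothetical sums of parity and counts of women (cohorts x ages)
toy_sum = pd.DataFrame({21: [30, 12], 24: [70, 0]}, index=[2003, 2006])
toy_count = pd.DataFrame({21: [100, 60], 24: [90, 0]}, index=[2003, 2006])

# 0 / 0 in an unobserved cell becomes NaN instead of raising an error
toy_mean_parity = toy_sum / toy_count
print(toy_mean_parity)
print(int(toy_mean_parity.isna().sum().sum()))  # 1
```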

In [None]:
def plot_cohorts(mean_parity_df, start, end):
    cohorts = mean_parity_df.index[start:end]
    palette = list(sns.color_palette("nipy_spectral", len(cohorts)))

    for i, cohort in enumerate(cohorts):
        mean_parity_df.loc[cohort].plot(
            style="o", color=palette[i], alpha=0.8, label=cohort
        )
        mean_cumulative_rate.loc[cohort].plot(
            style=":", color=palette[i], alpha=0.6, label=""
        )

    decorate(xlabel="Age")

Here are the results for women born in the 1990s and 2000s.

In [None]:
mean_parity_df = sum_df / count_df
plot_cohorts(mean_parity_df, -7, None)
decorate(ylabel="Number of live births", title="1990s and 2000s Cohorts")

Looking at the top line, we can see that the retrodictions for the 1991 cohort fit the data well.

In [None]:
# Example: Check parity at age 33 and predicted CFR
# mean_parity_df.loc[1991, 33], mean_cumulative_rate.loc[1991, cfr_age]

As expected, predicted CFRs for successive cohorts are lower.
The projections for the last two cohorts overlap because the estimates for their cohort effects are almost the same.

Here are the results for women born in the 1970s and 1980s.

In [None]:
plot_cohorts(mean_parity_df, -14, -7)
decorate(ylim=[0, 3], ylabel="Number of live births", title="1970s and 1980s Cohorts")

The model seems to capture the structure of the data well.
Here are the results for women born in the 1950s and 1960s.

In [None]:
plot_cohorts(mean_parity_df, -20, -14)
decorate(ylim=[0, 3], ylabel="Number of live births", title="1950s and 1960s Cohorts")

The estimated CFRs for these cohorts are not very different.

Finally, here are the results for women born in the 1940s and the late 1930s.

In [None]:
plot_cohorts(mean_parity_df, -27, -20)
decorate(ylabel="Number of live births", title="1930s and 1940s Cohorts")

Here we see the decline in completed fertility that characterized the end of the baby boom.

Overall, the model seems to capture the structure of the data, and its projections are plausible in the sense that they are based on the assumption that the future will be like the past.

## Regression testing

In [None]:
# Save baseline results for regression testing,
# using the same cfr_age as above (42, the midpoint of the CPS range 40-44)

# Prepare CFR data with HDI bounds
cfr_df = pd.DataFrame({
    'cohort': cohort_labels,
    'mean': mean_cumulative_rate[cfr_age].values,
    'low': low[cfr_age].values,
    'high': high[cfr_age].values
})

# Save all three outputs as v2.0
model_version = "v2.0"
save_baseline_results(
    version=model_version,
    alpha_summary=alpha_summary,
    beta_summary=beta_summary,
    cfr_df=cfr_df,
    cohort_labels=cohort_labels,
    age_labels=age_labels
)

In [None]:
# Load v1.0 results and compare with v2.0
baseline = load_baseline_results("v1.0")

In [None]:
# Compare cohort effects
plt.figure(figsize=(8, 5))
plt.plot(baseline['alpha']['cohort'], baseline['alpha']['mean'], 'o-', label='v1.0', markersize=4)
plt.plot(cohort_labels, alpha_summary['mean'], 's-', label='v2.0', markersize=4)
plt.legend()
plt.xlabel('Cohort')
plt.ylabel('Cohort Effect')
plt.title('Cohort Effects Comparison')
plt.grid(True, alpha=0.3)
decorate()

In [None]:
# Compare age effects
plt.figure(figsize=(8, 5))
plt.plot(baseline['beta']['age'], baseline['beta']['mean'], 'o-', label='v1.0', markersize=4)
plt.plot(age_labels, beta_summary['mean'], 's-', label='v2.0', markersize=4)
plt.legend()
plt.xlabel('Age')
plt.ylabel('Age Effect')
plt.title('Age Effects Comparison')
plt.grid(True, alpha=0.3)
decorate()

In [None]:
# Compare CFR predictions
plt.figure(figsize=(8, 5))
plt.plot(baseline['cfr_df']['cohort'], baseline['cfr_df']['mean'], 'o-', label='v1.0', markersize=4)
plt.plot(cohort_labels, mean_cumulative_rate[cfr_age], 's-', label='v2.0', markersize=4)
plt.legend()
plt.xlabel('Cohort')
plt.ylabel(f'CFR at Age {cfr_age}')
plt.title('CFR Predictions Comparison')
plt.grid(True, alpha=0.3)
decorate()
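The comparisons above are visual. A numeric check could complement them; here's a minimal sketch that aligns two CFR tables on `cohort` and reports the largest change in `mean`. The helper `compare_cfr` and the tolerance `atol=0.05` are hypothetical choices, not part of `utils`; it assumes both tables have `cohort` and `mean` columns, as `cfr_df` and `baseline['cfr_df']` do above.

```python
import pandas as pd


def compare_cfr(old, new, atol=0.05):
    """Align two CFR tables on cohort; return the largest change and a pass flag."""
    merged = old.merge(new, on="cohort", suffixes=("_old", "_new"))
    diff = (merged["mean_new"] - merged["mean_old"]).abs()
    return float(diff.max()), bool((diff <= atol).all())


# Hypothetical tables with the same columns as cfr_df
old_cfr = pd.DataFrame({"cohort": [1970, 1973, 1976], "mean": [2.00, 1.80, 1.60],
                        "low": [1.9, 1.7, 1.4], "high": [2.1, 1.9, 1.8]})
new_cfr = pd.DataFrame({"cohort": [1970, 1973, 1976], "mean": [2.01, 1.79, 1.70],
                        "low": [1.9, 1.7, 1.5], "high": [2.1, 1.9, 1.9]})

max_change, within_tol = compare_cfr(old_cfr, new_cfr)
print(round(max_change, 3), within_tol)  # 0.1 False
```

In the notebook, the same call could be made as `compare_cfr(baseline['cfr_df'], cfr_df)`.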

In [None]:
from utils import beep

beep()