# A Tiny Example

A three-period consumption-savings model with two regimes:

- **Working life** (ages 25 and 45): The agent chooses whether to work and how much
  to consume. A simple tax-and-transfer system guarantees a consumption floor.
  Savings earn interest.
- **Retirement** (age 65): Terminal regime. The agent consumes out of remaining
  wealth.

## Model

An agent lives for three periods (ages 25, 45, and 65). In the first two periods
(working life), the agent chooses whether to work $d_t \in \{0, 1\}$ and how
much to consume $c_t$. In the final period (retirement), the agent consumes out
of remaining wealth.

**Working life** (ages 25 and 45):

$$
V_t(w_t) = \max_{d_t,\, c_t} \left\{
    \frac{c_t^{1-\sigma}}{1-\sigma} - \phi \, d_t
    + \beta \, V_{t+1}(w_{t+1})
\right\}
$$

subject to

$$
\begin{align}
e_t &= d_t \cdot \bar{w} \\[4pt]
\tau(e_t, w_t) &= \begin{cases}
    \theta\,(e_t - \underline{c})
        & \text{if } e_t \geq \underline{c} \\
    \min(0,\; w_t + e_t - \underline{c})
        & \text{otherwise}
\end{cases} \\[4pt]
a_t &= w_t + e_t - \tau(e_t, w_t) - c_t \\[4pt]
w_{t+1} &= (1 + r)\, a_t \\[4pt]
a_t &\geq 0
\end{align}
$$

where $w_t$ is wealth, $e_t$ earnings, $\bar{w}$ the wage, $\underline{c}$ a
consumption floor guaranteed by transfers, $\theta$ the tax rate, and $a_t$
end-of-period wealth. $\tau$ is the net payment to the government: a positive
value is a tax, a negative value is a transfer. The transfer only kicks in when
earnings are below the consumption floor and the agent's resources
($w_t + e_t$) fall short of it.

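**Retirement** (age 65, terminal): utility is strictly increasing in
consumption, so the agent consumes all remaining wealth ($c_2 = w_2$) and the
value function is simply utility evaluated at wealth: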

$$
V_2(w_2) = \max_{c_2 \leq w_2}
    \frac{c_2^{1-\sigma}}{1-\sigma}
$$

In [None]:
from pprint import pprint

import jax.numpy as jnp
import pandas as pd
import plotly.express as px

from lcm import (
    AgeGrid,
    DiscreteGrid,
    LinSpacedGrid,
    LogSpacedGrid,
    Model,
    Regime,
    RegimeTransition,
    categorical,
)
from lcm.typing import (
    BoolND,
    ContinuousAction,
    ContinuousState,
    DiscreteAction,
    FloatND,
    ScalarInt,
)

## Categorical Variables

In [None]:
# Labor-supply choice available during working life: not working or working.
# NOTE(review): `categorical` is imported from lcm; presumably it turns the
# annotated fields into integer category codes -- confirm against the lcm docs.
@categorical
class Work:
    no: int
    yes: int


# Identifiers of the model's two regimes. The field names match the keys of the
# `regimes` dict handed to `Model`, and the class itself is passed to `Model` as
# `regime_id_class`.
@categorical
class RegimeId:
    working_life: int
    retirement: int

## Model Functions

In [None]:
# Utility


def utility(
    consumption: ContinuousAction,
    work: DiscreteAction,
    disutility_of_work: float,
    risk_aversion: float,
) -> FloatND:
    """Return the flow utility of the working-life regime.

    The result is CRRA utility of ``consumption`` minus ``disutility_of_work``
    whenever the agent works (``work == Work.yes``). The expression divides by
    ``1 - risk_aversion``, so ``risk_aversion`` must differ from 1.
    """
    curvature = 1 - risk_aversion
    consumption_utility = consumption**curvature / curvature
    works = work == Work.yes
    return consumption_utility - disutility_of_work * works


def utility_retirement(wealth: ContinuousState, risk_aversion: float) -> FloatND:
    """Return the terminal-period utility, i.e. CRRA utility evaluated at wealth.

    The expression divides by ``1 - risk_aversion``, so ``risk_aversion`` must
    differ from 1.
    """
    curvature = 1 - risk_aversion
    return wealth**curvature / curvature


# Auxiliary functions


def earnings(work: DiscreteAction, wage: float) -> FloatND:
    """Return labor earnings: ``wage`` if the agent works, zero otherwise."""
    works = work == Work.yes
    return jnp.where(works, wage, 0.0)


def taxes_transfers(
    earnings: FloatND,
    wealth: ContinuousState,
    consumption_floor: float,
    tax_rate: float,
) -> FloatND:
    """Return the net payment to the government (positive: tax, negative: transfer).

    If earnings reach the consumption floor, the part of earnings above the floor
    is taxed at ``tax_rate``. Otherwise the result is the (non-positive) gap
    between resources ``wealth + earnings`` and the floor, which is zero when
    resources already cover the floor.
    """
    tax = tax_rate * (earnings - consumption_floor)
    resources_above_floor = wealth + earnings - consumption_floor
    transfer = jnp.minimum(0.0, resources_above_floor)
    earns_at_least_floor = earnings >= consumption_floor
    return jnp.where(earns_at_least_floor, tax, transfer)


def end_of_period_wealth(
    wealth: ContinuousState,
    earnings: FloatND,
    taxes_transfers: FloatND,
    consumption: ContinuousAction,
) -> FloatND:
    """Return what is left of resources after taxes/transfers and consumption."""
    gross_resources = wealth + earnings
    net_resources = gross_resources - taxes_transfers
    return net_resources - consumption


# State transition


def next_wealth(end_of_period_wealth: FloatND, interest_rate: float) -> ContinuousState:
    """Return next period's wealth: end-of-period wealth plus interest."""
    gross_return = 1 + interest_rate
    return gross_return * end_of_period_wealth


# Constraints


def borrowing_constraint_working(end_of_period_wealth: FloatND) -> BoolND:
    """Return whether end-of-period wealth is non-negative (no borrowing)."""
    is_feasible = end_of_period_wealth >= 0
    return is_feasible


# Regime transition


def next_regime(age: float, last_working_age: float) -> ScalarInt:
    """Return the regime of the next period.

    From ``last_working_age`` onwards the agent moves to retirement; before
    that, the agent stays in working life.
    """
    working_life_is_over = age >= last_working_age
    return jnp.where(
        working_life_is_over, RegimeId.retirement, RegimeId.working_life
    )

## Regimes and Model

In [None]:
# Age grid of the model: from 25 to 65 in 20-year steps (the three periods
# described in the model section).
age_grid = AgeGrid(start=25, stop=65, step="20Y")
# The last age on the grid is the terminal retirement period.
retirement_age = age_grid.precise_values[-1]

# Working life: the agent chooses work and consumption; wealth evolves via
# `next_wealth` and the regime transition is given by `next_regime`.
working_life = Regime(
    transition=RegimeTransition(next_regime),
    # Active at every age strictly before retirement.
    active=lambda age: age < retirement_age,
    states={
        "wealth": LinSpacedGrid(start=0, stop=50, n_points=25, transition=next_wealth),
    },
    actions={
        "work": DiscreteGrid(Work),
        "consumption": LogSpacedGrid(start=4, stop=50, n_points=100),
    },
    functions={
        "utility": utility,
        "earnings": earnings,
        "taxes_transfers": taxes_transfers,
        "end_of_period_wealth": end_of_period_wealth,
    },
    constraints={
        "borrowing_constraint_working": borrowing_constraint_working,
    },
)

# Retirement: terminal regime (no regime transition, no state transition, no
# actions); utility is evaluated directly at wealth.
retirement = Regime(
    transition=None,
    # Active from the retirement age onwards.
    active=lambda age: age >= retirement_age,
    states={
        # Same wealth grid as in working life, but without a transition.
        "wealth": LinSpacedGrid(start=0, stop=50, n_points=25, transition=None),
    },
    functions={"utility": utility_retirement},
)

# Combine both regimes into one model. The keys of `regimes` match the field
# names of `RegimeId`.
model = Model(
    regimes={
        "working_life": working_life,
        "retirement": retirement,
    },
    ages=age_grid,
    regime_id_class=RegimeId,
    description="A tiny three-period consumption-savings model.",
)

## Parameters

Use `model.params_template` to see what parameters the model expects, organized
by regime and function.

In [None]:
# Pretty-print the parameter template as a plain dict.
pprint(dict(model.params_template))

Parameters shared across regimes (`risk_aversion`, `discount_factor`,
`interest_rate`) can be specified at the model level. Parameters unique to one
regime go under the regime name.

In [None]:
params = {
    # Model-level entries: not nested under a regime or function name.
    "discount_factor": 0.95,
    "risk_aversion": 1.5,
    "interest_rate": 0.03,
    # Regime-level entries: regime name -> function name -> argument name.
    "working_life": {
        "utility": {"disutility_of_work": 1.0},
        "earnings": {"wage": 20.0},
        "taxes_transfers": {"consumption_floor": 2.0, "tax_rate": 0.2},
        # Last age spent in working life: the second-to-last age on the grid.
        "next_regime": {"last_working_age": age_grid.precise_values[-2]},
    },
}

## Solve and Simulate

In [None]:
# Number of simulated agents.
n_agents = 100

# Solve the model and simulate all agents in a single call. Every agent starts
# in the working-life regime; initial wealth is evenly spaced between 1 and 40.
result = model.solve_and_simulate(
    params=params,
    initial_regimes=["working_life"] * n_agents,
    initial_states={"wealth": jnp.linspace(1, 40, n_agents)},
)

In [None]:
# Collect the simulation output in a DataFrame.
# NOTE(review): `additional_targets="all"` presumably adds the auxiliary
# functions (earnings, taxes_transfers, end_of_period_wealth) as columns, since
# they are selected below -- confirm against the lcm docs.
df = result.to_dataframe(additional_targets="all")
# Integer ages for nicer display and for use as pivot column labels later on.
df["age"] = df["age"].astype(int)
# The retirement regime has no consumption action; the agent consumes all
# remaining wealth there, so fill the consumption column with wealth.
df.loc[df["age"] == retirement_age, "consumption"] = df.loc[
    df["age"] == retirement_age, "wealth"
]
# Columns to display, in order.
columns = [
    "regime",
    "work",
    "consumption",
    "wealth",
    "earnings",
    "taxes_transfers",
    "end_of_period_wealth",
    "value",
]
# Show the first 20 (subject, age) rows; missing entries are rendered as blanks.
df.set_index(["subject_id", "age"])[columns].head(20).style.format(
    precision=1,
    na_rep="",
)

In [None]:
# Classify agents by work pattern across the two working-life periods
first_working_age = age_grid.precise_values[0]
last_working_age = age_grid.precise_values[-2]

# One row per agent, one column per working-life age, holding the work decision.
df_working = df[df["regime"] == "working_life"]
work_by_age = df_working.pivot_table(
    index="subject_id",
    columns="age",
    values="work",
    aggfunc="first",
)
# Work pattern as a string such as "yes, no" (first period, second period).
work_pattern = (
    work_by_age[first_working_age].astype(str)
    + ", "
    + work_by_age[last_working_age].astype(str)
)

label_map = {
    "yes, no": "low",  # work early, not later
    "no, yes": "medium",  # coast early, work later
    "no, no": "high",  # never work
}

# Fail loudly on any pattern without a label (e.g. "yes, yes"). Such agents
# would otherwise be mapped to NaN and silently dropped from the group
# statistics and plots. An explicit raise is used because `assert` statements
# are skipped when Python runs with -O.
unexpected_patterns = sorted(set(work_pattern) - set(label_map))
if unexpected_patterns:
    raise ValueError(
        "Plotting assumes that no agent works in both periods of working life."
        f" Unexpected work patterns: {unexpected_patterns}"
    )
groups = work_pattern.map(label_map).rename("initial_wealth")

# Combined descriptives and work decisions table
initial_wealth = df[df["age"] == first_working_age].set_index("subject_id")["wealth"]
group_desc = initial_wealth.groupby(groups).agg(["min", "max"]).round(1)

# Group means by initial-wealth group and age; also used for the plots.
df_groups = df.copy()
df_groups["initial_wealth"] = df_groups["subject_id"].map(groups)
df_mean = df_groups.groupby(["initial_wealth", "age"], as_index=False).mean(
    numeric_only=True,
)
# A group works at a given age iff its mean earnings at that age are positive.
work_table = df_mean[df_mean["age"] < retirement_age].pivot_table(
    index="initial_wealth",
    columns="age",
    values="earnings",
)
work_table = (work_table > 0).astype(int)
work_table.columns = [f"works {c}" for c in work_table.columns]

summary = pd.concat([group_desc, work_table], axis=1)
summary.index.name = "initial_wealth"
# `reindex` fixes the row order and, unlike `.loc`, does not raise a KeyError
# when a group is empty; such a row is rendered blank via `na_rep`.
summary.reindex(["low", "medium", "high"]).style.format(precision=1, na_rep="")

In [None]:
# Mean consumption by age, one line per initial-wealth group.
fig = px.line(
    df_mean,
    x="age",
    y="consumption",
    color="initial_wealth",
    title="Consumption by Age",
    template="plotly_dark",
)
fig.show()

In [None]:
# Mean wealth by age, one line per initial-wealth group.
fig = px.line(
    df_mean,
    x="age",
    y="wealth",
    color="initial_wealth",
    title="Wealth by Age",
    template="plotly_dark",
)
fig.show()