# Regimes

A **regime** is a phase of life with its own utility function, states, actions, and
constraints. Every model needs at least one non-terminal regime and one terminal regime.

For example, a simple life-cycle model could use two regimes:

- **Working life**: the agent chooses labor supply and consumption
- **Retirement**: the agent consumes out of savings (terminal)

This page covers:

- Regime anatomy — what each field does
- Terminal vs non-terminal regimes
- Regime transitions — deterministic and stochastic
- Building a model from regimes
- A complete worked example

In [None]:
from pprint import pprint

import jax.numpy as jnp

from lcm import (
    AgeGrid,
    DiscreteGrid,
    LinSpacedGrid,
    LogSpacedGrid,
    MarkovRegimeTransition,
    Model,
    Regime,
    RegimeTransition,
    categorical,
)
from lcm.typing import (
    BoolND,
    ContinuousAction,
    ContinuousState,
    DiscreteAction,
    FloatND,
    ScalarInt,
)

## Regime Anatomy

A `Regime` is defined by these fields:

| Field | Type | Purpose |
|---|---|---|
| `transition` | `RegimeTransition`, `MarkovRegimeTransition`, or `None` | Next-regime transition. `None` marks a terminal regime. |
| `active` | `Callable[[float], bool]` | Age-based predicate — when the regime is active |
| `states` | `dict[str, Grid]` | State variables with grids (each grid has a `transition`) |
| `actions` | `dict[str, Grid]` | Choice variables with grids (no transitions) |
| `functions` | `dict[str, Callable]` | Must include `"utility"`; can include auxiliary functions |
| `constraints` | `dict[str, Callable]` | Feasibility constraints on state-action combinations |

Note the two different uses of "transition" here:

- The regime's `transition` is **forward-looking**: it determines which regime the
  agent enters *next* period. It lives on the source regime.
- A grid's `transition` is **backward-looking**: it defines how a state variable
  *arrived* at its current value. In multi-regime models, per-boundary mappings
  live on the *target* regime's grid. See the [grids page](grids.ipynb) for details.

### Building a regime step by step

Let's build a working-life regime for a consumption-savings model.

**Step 1: Define categorical variables.**

In [None]:
@categorical
class Work:
    no: int
    yes: int


@categorical
class RegimeId:
    working: int
    retired: int

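**Step 2: Define functions.** The `"utility"` key is required. Auxiliary functions
(like `earnings`) are available to other functions by name: any function with an
argument called `earnings` receives that function's output, as `next_wealth` and
`borrowing_constraint` do below.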

In [None]:
def utility(
    consumption: ContinuousAction,
    work: DiscreteAction,
    disutility_of_work: float,
    risk_aversion: float,
) -> FloatND:
    return consumption ** (1 - risk_aversion) / (
        1 - risk_aversion
    ) - disutility_of_work * (work == Work.yes)


def earnings(work: DiscreteAction, wage: float) -> FloatND:
    return jnp.where(work == Work.yes, wage, 0.0)
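
To see what these two functions compute, here is a plain-NumPy evaluation on a tiny grid. It is a sketch, not pylcm code: it assumes `Work.no` and `Work.yes` map to the integer codes 0 and 1 (field order of the class), and uses NumPy in place of `jax.numpy`, which behaves the same for these operations.

```python
import numpy as np

# Assumed integer codes for Work.no / Work.yes (field order of the class)
WORK_NO, WORK_YES = 0, 1


def utility(consumption, work, disutility_of_work, risk_aversion):
    # CRRA utility of consumption, minus a fixed penalty when working
    crra = consumption ** (1 - risk_aversion) / (1 - risk_aversion)
    return crra - disutility_of_work * (work == WORK_YES)


def earnings(work, wage):
    # Wage income only when working
    return np.where(work == WORK_YES, wage, 0.0)


consumption = np.array([1.0, 4.0])        # row: two consumption levels
work = np.array([[WORK_NO], [WORK_YES]])  # column: both work choices

u = utility(consumption, work, disutility_of_work=1.0, risk_aversion=1.5)
income = earnings(work, wage=20.0)
```

Because `work` is a column and `consumption` a row, `u` broadcasts to a 2-by-2 table: rows are work choices, columns are consumption levels, and the working row is lower by exactly `disutility_of_work`.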

**Step 3: Define state transitions.** Transition functions are attached directly to
grids via the `transition` parameter.

In [None]:
def next_wealth(
    wealth: ContinuousState,
    earnings: FloatND,
    consumption: ContinuousAction,
    interest_rate: float,
) -> ContinuousState:
    return (1 + interest_rate) * (wealth + earnings - consumption)
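
The `earnings` argument of `next_wealth` is neither a state nor an action; it is the output of the auxiliary function with that name. The sketch below mimics this name-based wiring with a hypothetical `evaluate` helper built on `inspect.signature`. It only illustrates the idea; pylcm's own dependency resolution is not shown here, and the integer code for `Work.yes` is assumed.

```python
import inspect

import numpy as np

WORK_YES = 1  # assumed integer code for Work.yes


def earnings(work, wage):
    return np.where(work == WORK_YES, wage, 0.0)


def next_wealth(wealth, earnings, consumption, interest_rate):
    return (1 + interest_rate) * (wealth + earnings - consumption)


def evaluate(target, functions, inputs):
    """Hypothetical resolver: each argument of `target` is either another
    registered function (evaluated first, recursively) or a plain input."""
    kwargs = {}
    for name in inspect.signature(target).parameters:
        if name in functions:
            kwargs[name] = evaluate(functions[name], functions, inputs)
        else:
            kwargs[name] = inputs[name]
    return target(**kwargs)


functions = {"earnings": earnings}
inputs = {
    "wealth": 10.0,
    "work": WORK_YES,
    "consumption": 15.0,
    "wage": 20.0,
    "interest_rate": 0.03,
}

w_next = evaluate(next_wealth, functions, inputs)  # 1.03 * (10 + 20 - 15)
w_next_idle = evaluate(next_wealth, functions, {**inputs, "work": 0})  # 1.03 * (10 - 15)
```

Without work, consuming 15 out of wealth 10 leaves negative next-period wealth (-5.15), which is the kind of choice a borrowing constraint has to rule out.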

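**Step 4: Define constraints.** Constraints filter out infeasible state-action
combinations. Each returns a boolean array that is `True` where a combination is
feasible; here, the agent cannot end the period with negative assets.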

In [None]:
def borrowing_constraint(
    wealth: ContinuousState,
    earnings: FloatND,
    consumption: ContinuousAction,
) -> BoolND:
    return wealth + earnings - consumption >= 0
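
Evaluated on a grid, the constraint yields a feasibility mask. In this sketch wealth varies along rows and consumption along columns, with `np.geomspace` standing in for `LogSpacedGrid(start=0.5, stop=50, n_points=50)` (an assumption about its spacing). Setting infeasible entries to `-inf` before maximizing is shown as one common way to use such a mask, not as a description of pylcm's internals; the log objective is a placeholder.

```python
import numpy as np


def borrowing_constraint(wealth, earnings, consumption):
    # True where end-of-period assets (wealth + earnings - consumption) are >= 0
    return wealth + earnings - consumption >= 0


consumption = np.geomspace(0.5, 50, 50)  # assumed to match LogSpacedGrid spacing
wealth = np.array([[5.0], [30.0]])       # two wealth levels, as a column
earnings = 0.0                           # not working this period

feasible = borrowing_constraint(wealth, earnings, consumption)  # shape (2, 50)

# Illustrative masking: infeasible choices can never be the maximizer
objective = np.log(consumption)          # placeholder objective for the demo
masked = np.where(feasible, objective, -np.inf)
best_consumption = consumption[masked.argmax(axis=1)]
```

With an increasing objective, the best choice is the largest feasible consumption level in each row, which never exceeds that row's wealth.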

**Step 5: Define the regime transition.** This determines which regime the agent
enters next period.

In [None]:
def next_regime(age: float, retirement_age: float) -> ScalarInt:
    return jnp.where(age >= retirement_age, RegimeId.retired, RegimeId.working)
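
Because `next_regime` is vectorized with `jnp.where`, it can be checked directly on the ages 25, 45, and 65 used later on this page. This NumPy version uses 0 and 1 as stand-ins for `RegimeId.working` and `RegimeId.retired` (assumed field-order codes).

```python
import numpy as np

WORKING, RETIRED = 0, 1  # assumed codes for RegimeId.working / RegimeId.retired


def next_regime(age, retirement_age):
    # Regime the agent will be in *next* period, given this period's age
    return np.where(age >= retirement_age, RETIRED, WORKING)


ages = np.array([25.0, 45.0, 65.0])
next_with_45 = next_regime(ages, retirement_age=45.0)
next_with_65 = next_regime(ages, retirement_age=65.0)
```

With `retirement_age=45`, a 45-year-old is routed to the retired regime at 65, the first age at which `retired` is active. With `retirement_age=65`, a 45-year-old would be sent back to `working`, which is not active at 65.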

**Step 6: Assemble the regime.**

In [None]:
RETIREMENT_AGE = 65

working = Regime(
    transition=RegimeTransition(next_regime),
    active=lambda age: age < RETIREMENT_AGE,
    states={
        "wealth": LinSpacedGrid(start=0, stop=50, n_points=25, transition=next_wealth),
    },
    actions={
        "work": DiscreteGrid(Work),
        "consumption": LogSpacedGrid(start=0.5, stop=50, n_points=50),
    },
    functions={
        "utility": utility,
        "earnings": earnings,
    },
    constraints={
        "borrowing_constraint": borrowing_constraint,
    },
)

## Terminal vs Non-Terminal Regimes

- **Terminal regime**: `transition=None`. The value function equals the utility
  function directly; there is no continuation value.
- **Non-terminal regime**: `transition` wraps a function. pylcm auto-injects an
  aggregation function `H` that adds the next-period value $V'$, discounted by
  $\beta$ (the `discount_factor` parameter), to current utility $u$:

$$
H(u, V', \beta) = u + \beta \, V'
$$

In [None]:
def utility_retired(wealth: ContinuousState, risk_aversion: float) -> FloatND:
    return wealth ** (1 - risk_aversion) / (1 - risk_aversion)


retired = Regime(
    transition=None,
    active=lambda age: age >= RETIREMENT_AGE,
    states={
        "wealth": LinSpacedGrid(start=0, stop=50, n_points=25, transition=None),
    },
    functions={"utility": utility_retired},
)

print("Terminal?", retired.terminal)
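
The recursion can be made concrete with a hand-rolled backward-induction step for this two-regime model. This is plain NumPy, not pylcm's solver: the retired value is just `utility_retired` (no continuation, as for any terminal regime), and the working value one period earlier is the maximum of $u + \beta V'$ over work and consumption, with $V'$ read off the retired grid by linear interpolation (`np.interp`, an assumption; pylcm's interpolation may differ). Parameter values follow the complete example at the end of this page, and the wealth grid starts at 1 rather than 0 to avoid $-\infty$ utility at zero wealth.

```python
import numpy as np

beta, risk_aversion, interest_rate = 0.95, 1.5, 0.03
wage, disutility_of_work = 20.0, 1.0

wealth_grid = np.linspace(1.0, 50.0, 25)  # starts at 1 to avoid -inf at 0
consumption_grid = np.geomspace(0.5, 50.0, 50)


def crra(x):
    return x ** (1 - risk_aversion) / (1 - risk_aversion)


# Terminal (retired) regime: the value function is just utility of wealth
v_retired = crra(wealth_grid)


def value_working(wealth):
    """max over (work, consumption) of H = u + beta * V' for one wealth level."""
    best = -np.inf
    for earn, penalty in [(0.0, 0.0), (wage, disutility_of_work)]:  # no / yes
        resources = wealth + earn - consumption_grid
        feasible = resources >= 0                  # borrowing constraint
        w_next = (1 + interest_rate) * resources   # next_wealth
        # Linear interpolation; np.interp clamps values outside the grid
        v_next = np.interp(w_next, wealth_grid, v_retired)
        h = crra(consumption_grid) - penalty + beta * v_next
        best = max(best, np.where(feasible, h, -np.inf).max())
    return best


v_working = np.array([value_working(w) for w in wealth_grid])
```

Every wealth level admits some feasible choice, so `v_working` is finite, and it is non-decreasing in wealth because more wealth only enlarges the feasible set.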

## Regime Transitions

The regime transition function determines which regime an agent enters in the next
period. There are two kinds:

### Deterministic: `RegimeTransition`

The function returns an integer regime ID (from the `@categorical` `RegimeId`
class). Use this for transitions that depend deterministically on state — for
example, mandatory retirement at a certain age. The `next_regime` function defined
above is wrapped in `RegimeTransition`:

In [None]:
det_transition = RegimeTransition(next_regime)

### Stochastic: `MarkovRegimeTransition`

The function returns a probability array over all regimes, in the order of the
regime ID class. Use this when the regime transition is uncertain, for example
mortality risk that determines whether the agent survives to the next period.

In [None]:
@categorical
class RegimeIdMortality:
    alive: int
    dead: int


def survival_transition(survival_prob: float) -> FloatND:
    """Return [P(alive), P(dead)]."""
    return jnp.array([survival_prob, 1 - survival_prob])


stoch_transition = MarkovRegimeTransition(survival_transition)

Internally, deterministic transitions are converted to one-hot probability arrays,
so both types end up in the same format during the solve step.
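
What that common format looks like can be sketched in NumPy: a deterministic regime ID becomes a one-hot row (via `np.eye`), a Markov transition already is a probability vector, and either one weights next-period regime values into an expected continuation value. The encoding helper and the regime values are hypothetical illustrations, not pylcm internals.

```python
import numpy as np

N_REGIMES = 2  # e.g. alive / dead


def as_probabilities(regime_id, n_regimes=N_REGIMES):
    """Hypothetical helper: one-hot encode a deterministic regime ID."""
    return np.eye(n_regimes)[regime_id]


def survival_transition(survival_prob):
    """Return [P(alive), P(dead)]."""
    return np.array([survival_prob, 1 - survival_prob])


det = as_probabilities(1)          # certain move to regime 1
stoch = survival_transition(0.98)  # 98% survive, 2% die

# Made-up next-period values for each regime, to show the weighting
v_next_by_regime = np.array([-2.0, 0.0])
expected_v = stoch @ v_next_by_regime
expected_v_det = det @ v_next_by_regime
```

A one-hot row simply selects the value of the certain next regime, so both transition types can share the same expectation step.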

## Building a Model

A `Model` assembles regimes into a solvable life-cycle problem. It requires:

- `regimes`: dict mapping names to `Regime` instances
- `ages`: an `AgeGrid` defining the lifecycle
- `regime_id_class`: a `@categorical` class whose fields match the regime names

In [None]:
age_grid = AgeGrid(start=25, stop=65, step="20Y")
print("Ages:", age_grid.values)
print("Periods:", age_grid.n_periods)
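
Assuming the grid includes both endpoints, `AgeGrid(start=25, stop=65, step="20Y")` corresponds to the ages below. Applying the two `active` predicates to them shows which regime applies in each period. This NumPy reconstruction is for illustration only and does not call `AgeGrid`.

```python
import numpy as np

RETIREMENT_AGE = 65
start, stop, step = 25, 65, 20

# Inclusive of `stop` (assumption matching the three-period example)
ages = np.arange(start, stop + step, step)
n_periods = len(ages)

# Same predicates as the `active` arguments of the two regimes
working_active = ages < RETIREMENT_AGE
retired_active = ages >= RETIREMENT_AGE
```

Here exactly one regime is active at each age: `working` at 25 and 45, `retired` at 65.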

In [None]:
model = Model(
    regimes={
        "working": working,
        "retired": retired,
    },
    ages=age_grid,
    regime_id_class=RegimeId,
)

The model validates that:

- There is at least one terminal and one non-terminal regime
- The `regime_id_class` fields match the regime names
- All state grids have explicit `transition` parameters
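
A minimal sketch of these three checks, written against hypothetical stand-in classes rather than pylcm's `Regime` and `Model`. `transition=None` counts as explicit here (as in the retired regime); only a missing value fails, modelled with a sentinel.

```python
from dataclasses import dataclass, field, fields

_UNSET = object()  # marks a state grid whose `transition` was never given


@dataclass
class RegimeSpec:
    """Hypothetical stand-in for lcm.Regime, keeping only what is checked."""
    terminal: bool
    state_transitions: dict = field(default_factory=dict)


@dataclass
class RegimeIdSpec:
    """Hypothetical stand-in for the @categorical RegimeId class."""
    working: int = 0
    retired: int = 1


def validate(regimes, regime_id_class):
    errors = []
    terminal_flags = [spec.terminal for spec in regimes.values()]
    if not any(terminal_flags):
        errors.append("no terminal regime")
    if all(terminal_flags):
        errors.append("no non-terminal regime")
    id_names = {f.name for f in fields(regime_id_class)}
    if id_names != set(regimes):
        errors.append(f"regime IDs {sorted(id_names)} != regimes {sorted(regimes)}")
    for name, spec in regimes.items():
        for state, transition in spec.state_transitions.items():
            if transition is _UNSET:
                errors.append(f"{name}.{state}: transition not given")
    return errors


good = {
    "working": RegimeSpec(False, {"wealth": "next_wealth"}),
    "retired": RegimeSpec(True, {"wealth": None}),  # None is explicit
}
bad = {"working": RegimeSpec(False, {"wealth": _UNSET})}
```

`validate(good, RegimeIdSpec)` returns no errors, while `bad` fails all three checks.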

### Parameters template

After construction, `model.params_template` shows what parameters the model
expects. Parameters shared across regimes (like `risk_aversion`) appear at the
top level.

In [None]:
pprint(dict(model.params_template))

## Complete Example

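A three-period consumption-savings model. Ages 25 and 45 are working life; age 65 is
retirement. Because `next_regime` picks the regime for the *next* period,
`retirement_age` is set to the second-to-last age on the grid (45): a 45-year-old
worker enters retirement at 65.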

In [None]:
params = {
    "discount_factor": 0.95,
    "risk_aversion": 1.5,
    "interest_rate": 0.03,
    "working": {
        "utility": {"disutility_of_work": 1.0},
        "earnings": {"wage": 20.0},
        "next_regime": {"retirement_age": age_grid.precise_values[-2]},
    },
}

In [None]:
result = model.solve_and_simulate(
    params=params,
    initial_regimes=["working"] * 50,
    initial_states={
        "age": jnp.full(50, age_grid.values[0]),
        "wealth": jnp.linspace(1, 40, 50),
    },
)

df = result.to_dataframe(additional_targets="all")
df.head(10)