# The Transition System

pylcm models involve two independent layers of transitions that compose during
backward induction:

1. **State transitions** — how a state variable arrived at its current value
   (attached to grids)
2. **Regime transitions** — which regime the agent enters next period (attached to
   the regime itself)

The two layers have opposite orientations:

- **Regime transitions are forward-looking.** The `transition` field on a `Regime`
  answers "where does the agent go next?" It lives on the *source* regime and
  points toward the future.
- **State transitions are backward-looking.** The `transition` parameter on a grid
  answers "how did this state variable reach its current value?" It lives on the
  grid that *receives* the value. In multi-regime models, per-boundary mappings
  are placed on the *target* regime's grid.

This notebook explains how each layer works, how per-boundary transitions resolve
cross-regime state mismatches, and how everything composes in the Bellman equation.

In [None]:
import jax
import jax.numpy as jnp

from lcm import (
    DiscreteGrid,
    DiscreteMarkovGrid,
    LinSpacedGrid,
    Regime,
    RegimeTransition,
    categorical,
)
from lcm.typing import (
    ContinuousState,
    DiscreteState,
    FloatND,
    ScalarInt,
)

## State Transition Mechanics

State transitions are attached directly to grid objects via the `transition`
parameter. There are four cases:

| Grid configuration | Behavior |
|---|---|
| `transition=some_func` | Deterministic: $s' = g(s, a, \ldots)$ |
| `transition=None` | Fixed: $s' = s$ (identity auto-generated) |
| `DiscreteMarkovGrid(transition=func)` | Stochastic: probability-weighted expectation |
| Shock grids (`lcm.shocks.*`) | Intrinsic transitions with interpolated weights |

### Deterministic state transitions

A state transition function defines how the current state value was determined from
last period's states, actions, and parameters. The function's argument names are
resolved from the regime's namespace.
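The sketch below makes the name-based lookup concrete in plain Python. It assumes the regime's namespace can be modeled as a flat dictionary of values. The helper `call_with_namespace` and the function `toy_next_wealth` are hypothetical and are not pylcm's internal implementation. They show how a function's argument names can select values by matching its signature, while entries the function does not ask for are ignored.

```python
import inspect


def call_with_namespace(func, namespace):
    """Call func with the namespace entries whose keys match its argument names.

    Hypothetical helper for illustration; not pylcm's internal code.
    """
    arg_names = inspect.signature(func).parameters  # ordered mapping of names
    return func(**{name: namespace[name] for name in arg_names})


def toy_next_wealth(wealth, consumption, interest_rate):
    return (1 + interest_rate) * (wealth - consumption)


# A toy namespace holding a state, an action, a parameter, and an unused entry.
namespace = {"wealth": 10.0, "consumption": 4.0, "interest_rate": 0.05, "age": 40}
new_wealth = call_with_namespace(toy_next_wealth, namespace)  # 1.05 * 6.0
```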

In [None]:
from lcm.typing import ContinuousAction


def next_wealth(
    wealth: ContinuousState,
    consumption: ContinuousAction,  # consumption is an action, not a state
    interest_rate: float,
) -> ContinuousState:
    return (1 + interest_rate) * (wealth - consumption)


# Attached to the grid:
wealth_grid = LinSpacedGrid(start=0, stop=100, n_points=50, transition=next_wealth)

### Fixed states and identity transitions

When `transition=None`, pylcm auto-generates an `_IdentityTransition` internally.
This shows up in `regime.get_all_functions()` under the key `"next_<state_name>"`.
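Conceptually, the auto-generated identity is a callable that returns the current value of its state and carries a marker flag. The sketch below is hypothetical. `IdentityTransitionSketch` and `build_transitions` are illustrative names, not pylcm's `_IdentityTransition`, whose actual signature may differ. The sketch only mirrors the behavior described here: `None` becomes an identity stored under `next_<state_name>`.

```python
class IdentityTransitionSketch:
    """Hypothetical stand-in for an auto-generated identity transition."""

    _is_auto_identity = True  # lets validation tell it apart from user functions

    def __init__(self, state_name):
        self.state_name = state_name

    def __call__(self, **kwargs):
        # Return the current value of the state unchanged: s' = s.
        return kwargs[self.state_name]


def build_transitions(state_transitions):
    """Key each state's transition as next_<state>, filling None with identities."""
    return {
        f"next_{name}": func if func is not None else IdentityTransitionSketch(name)
        for name, func in state_transitions.items()
    }


transitions = build_transitions({"education": None, "wealth": lambda wealth: 2 * wealth})
```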

In [None]:
@categorical
class EducationLevel:
    low: int
    high: int


@categorical
class RegimeId:
    working: int
    retired: int


def utility(wealth: ContinuousState) -> FloatND:
    return jnp.log(wealth + 1.0)


def next_regime() -> ScalarInt:
    return RegimeId.retired


regime = Regime(
    transition=RegimeTransition(next_regime),
    states={
        "education": DiscreteGrid(EducationLevel, transition=None),
        "wealth": LinSpacedGrid(start=0, stop=50, n_points=10, transition=None),
    },
    functions={"utility": utility},
)

all_funcs = regime.get_all_functions()
print("Function keys:", list(all_funcs.keys()))
print("next_education type:", type(all_funcs["next_education"]).__name__)
print("next_wealth type:   ", type(all_funcs["next_wealth"]).__name__)

Both fixed states produce `_IdentityTransition` objects. These are marked with
`_is_auto_identity = True` so that validation can distinguish them from
user-provided transitions.

### Stochastic state transitions (DiscreteMarkovGrid)

For `DiscreteMarkovGrid`, the transition function returns a probability array over
the categories. During the solve step, pylcm computes a probability-weighted
expectation over next-period states:

$$
\mathbb{E}[V(s')] = \sum_{s'} P(s' \mid s) \, V(s')
$$
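As a small worked example of this sum, take the probabilities for a `good` health state from the cell below, $P(\cdot \mid \text{good}) = [0.1, 0.9]$ over `[bad, good]`. Pair them with made-up next-period values $V = [1.0, 2.0]$. The expectation is then $0.1 \cdot 1.0 + 0.9 \cdot 2.0 = 1.9$:

```python
import jax.numpy as jnp

p_next_given_good = jnp.array([0.1, 0.9])  # P(s' | s = good) over [bad, good]
v_next = jnp.array([1.0, 2.0])  # made-up V(s'), not model output

expected_v = jnp.dot(p_next_given_good, v_next)  # 0.1 * 1.0 + 0.9 * 2.0
```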

In [None]:
@categorical
class Health:
    bad: int
    good: int


def health_transition(health: DiscreteState) -> FloatND:
    return jnp.where(
        health == Health.good,
        jnp.array([0.1, 0.9]),  # good → 90% stay good
        jnp.array([0.6, 0.4]),  # bad  → 40% recover
    )


health_grid = DiscreteMarkovGrid(Health, transition=health_transition)

# Inspect the transition probabilities
for state_name, code in [("bad", Health.bad), ("good", Health.good)]:
    probs = health_transition(jnp.array(code))
    print(f"P(next | {state_name}) = {probs}")

### Shock grids

Shock grids (from `lcm.shocks.iid` and `lcm.shocks.ar1`) have intrinsic
transitions computed from the distribution. For IID shocks, the transition
probabilities are the same regardless of the current value. For AR(1) shocks,
probabilities depend on the current state.

Shock grids do not accept a `transition` parameter — their transitions are
built-in.
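A transition matrix over a fixed grid shows the difference between the IID and AR(1) cases. The sketch below uses a Tauchen-style discretization of $x' = \rho x + \varepsilon$ with $\varepsilon \sim N(0, \sigma^2)$. This standard textbook method is chosen for illustration only and is not necessarily how `lcm.shocks` computes its weights. The helper `tauchen_probs` is hypothetical. Setting $\rho = 0$ gives the IID case, where every row is the same. With $\rho > 0$, the rows shift with the current value.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def tauchen_probs(grid, rho, sigma):
    """Transition matrix for x' = rho * x + eps with eps ~ N(0, sigma**2).

    Hypothetical Tauchen-style helper: row i holds P(x' = grid[j] | x = grid[i])
    on an evenly spaced grid, with the tails assigned to the end points.
    """
    half_step = (grid[1] - grid[0]) / 2
    cond_mean = rho * grid[:, None]  # conditional mean for each current point
    upper = norm.cdf((grid[None, :] + half_step - cond_mean) / sigma)
    lower = norm.cdf((grid[None, :] - half_step - cond_mean) / sigma)
    probs = upper - lower
    probs = probs.at[:, 0].set(upper[:, 0])  # all mass below the first midpoint
    probs = probs.at[:, -1].set(1.0 - lower[:, -1])  # all mass above the last one
    return probs


grid = jnp.linspace(-2.5, 2.5, 5)
iid = tauchen_probs(grid, rho=0.0, sigma=1.0)  # rows identical
ar1 = tauchen_probs(grid, rho=0.9, sigma=1.0)  # rows depend on the current value
```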

In [None]:
import lcm.shocks.iid

shock = lcm.shocks.iid.Normal(
    n_points=5, gauss_hermite=False, mu=0.0, sigma=1.0, n_std=2.5
)
print("Grid points:", shock.to_jax())

## Per-Boundary State Transitions

When a discrete state has different categories across regimes, a simple callable
transition is not enough — you need to map from one category set to another at
the regime boundary.

The solution: a **mapping transition** keyed by `(source_regime, target_regime)`
pairs, placed on the **target** regime's grid.

### Example: different health categories

Suppose working life has three health states (disabled, bad, good) but retirement
only has two (bad, good). The transition from working to retired needs an explicit
mapping.

In [None]:
@categorical
class HealthWorking:
    disabled: int
    bad: int
    good: int


@categorical
class HealthRetired:
    bad: int
    good: int


def map_working_to_retired(health: DiscreteState) -> DiscreteState:
    """Map 3-category working health to 2-category retired health."""
    return jnp.where(
        health == HealthWorking.good,
        HealthRetired.good,
        HealthRetired.bad,
    )


# Verify the mapping
for name, code in [
    ("disabled", HealthWorking.disabled),
    ("bad", HealthWorking.bad),
    ("good", HealthWorking.good),
]:
    result = map_working_to_retired(jnp.array(code))
    print(f"working {name} ({code}) → retired code {int(result)}")

The mapping is placed on the target regime's grid:

In [None]:
health_retired_grid = DiscreteGrid(
    HealthRetired,
    transition={
        ("working", "retired"): map_working_to_retired,
    },
)

### Resolution priority

When resolving which transition function to use at a regime boundary
`(source, target)`, pylcm checks (in order):

1. Target grid mapping for `(source, target)`
2. Source grid mapping for `(source, target)`
3. Source grid's callable transition
4. Target grid's callable transition
5. Auto-generated identity (if categories match)

If the categories differ across regimes and no explicit mapping is found,
`ModelInitializationError` is raised.
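The priority list can be read as a small decision procedure. The sketch below is one hypothetical reading of it. It models each grid as a dictionary with an optional `"transition"` entry. `resolve_boundary_transition` and the local `ModelInitializationError` class are stand-ins, not pylcm code. Following the sentence above, the sketch requires an explicit mapping whenever the categories differ, and only then falls through to callables and the identity.

```python
class ModelInitializationError(Exception):
    """Local stand-in for pylcm's exception of the same name (illustration only)."""


def resolve_boundary_transition(source, target, source_grid, target_grid, categories_match):
    """Hypothetical resolver for the boundary (source, target).

    Each grid is modeled as a dict whose optional "transition" entry is a
    callable, a mapping keyed by (source, target) pairs, or None.
    """
    key = (source, target)
    source_t = source_grid.get("transition")
    target_t = target_grid.get("transition")

    # 1. and 2.: explicit per-boundary mappings, target grid first.
    for candidate in (target_t, source_t):
        if isinstance(candidate, dict) and key in candidate:
            return candidate[key]

    # Without an explicit mapping, differing categories cannot be resolved.
    if not categories_match:
        raise ModelInitializationError(f"No mapping for boundary {key}")

    # 3. and 4.: plain callables, source grid first.
    for candidate in (source_t, target_t):
        if callable(candidate):
            return candidate

    # 5.: fall back to the identity.
    return lambda value: value
```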

### Parameterized per-boundary transitions

Per-boundary mapping functions can take **parameters** beyond the state variable
itself. A common use case is a continuous state whose transition law differs
across regime boundaries — for example, wealth that grows at a rate specific to
the target regime.

When pylcm resolves a per-boundary transition from the **target** grid's mapping
(priority 1 above), any parameters in that function are looked up in the
**target** regime's parameter template. This means the user specifies the
parameter value under the target regime in the `params` dict, and pylcm
automatically routes it to the transition function at the boundary.

The rule is simple: **whoever owns the mapping owns the parameters**. Since
per-boundary mappings live on the target regime's grid, their parameters come
from the target regime.

### Example: regime-specific growth rate

Consider a two-regime model (phase 1 → phase 2) where wealth grows at a rate
that is specific to phase 2. The transition function on phase 2's wealth grid
takes a `growth_rate` parameter:

In [None]:
def next_wealth_at_boundary(
    wealth: ContinuousState,
    growth_rate: float,
) -> ContinuousState:
    """Wealth transition at the phase1 → phase2 boundary.

    The growth_rate parameter is resolved from phase2's params template,
    because the mapping lives on phase2's grid.
    """
    return (1 + growth_rate) * wealth


# Phase 2's wealth grid declares the per-boundary mapping:
phase2_wealth_grid = LinSpacedGrid(
    start=0,
    stop=100,
    n_points=20,
    transition={
        ("phase1", "phase2"): next_wealth_at_boundary,
    },
)

Because the mapping `{("phase1", "phase2"): next_wealth_at_boundary}` lives on
phase 2's grid, the `growth_rate` parameter appears in phase 2's parameter
template. The user supplies it under `"phase2"` in the params dict:

```python
params = {
    "phase1": {
        # ... phase 1 parameters (omitted)
    },
    "phase2": {
        "next_wealth": {"growth_rate": 0.05},
        # ... other phase 2 parameters (omitted)
    },
}
```

Internally, pylcm detects that the transition was resolved from the target
grid's mapping and renames the parameter to a cross-boundary qualified name
(e.g., `phase2__next_wealth__growth_rate`). At solve and simulation time, the
value is looked up from `internal_params["phase2"]` — not from `"phase1"` —
even though the transition is evaluated as part of phase 1's backward induction
step.
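Plain dictionaries are enough to sketch this routing. The illustration is hypothetical. `qualified_name`, `flatten_params`, `toy_boundary_transition`, and the flat layout of `internal` are assumptions that only mimic the naming scheme of the example `phase2__next_wealth__growth_rate`. They do not reproduce pylcm's actual internal parameter structure. The key point is that the lookup uses the target regime's entry even though the evaluation belongs to phase 1's step.

```python
def qualified_name(regime, func_name, param):
    """Join names into a qualified key such as phase2__next_wealth__growth_rate."""
    return f"{regime}__{func_name}__{param}"


def flatten_params(params):
    """Turn {regime: {func: {param: value}}} into qualified keys per regime."""
    return {
        regime: {
            qualified_name(regime, func_name, param): value
            for func_name, func_params in regime_params.items()
            for param, value in func_params.items()
        }
        for regime, regime_params in params.items()
    }


def toy_boundary_transition(wealth, growth_rate):
    return (1 + growth_rate) * wealth


user_params = {"phase1": {}, "phase2": {"next_wealth": {"growth_rate": 0.05}}}
internal = flatten_params(user_params)

# Evaluated as part of phase 1's step, but the parameter comes from the
# target regime ("phase2") because the mapping lives on phase 2's grid.
target = "phase2"
growth_rate = internal[target][qualified_name(target, "next_wealth", "growth_rate")]
wealth_next = toy_boundary_transition(wealth=100.0, growth_rate=growth_rate)
```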

## Regime Transition Mechanics

Regime transitions determine which regime the agent enters next period. Internally,
both deterministic and stochastic transitions are converted to a uniform probability
array format.

### Deterministic transitions → one-hot encoding

A `RegimeTransition` wraps a function that returns an integer regime ID.
Internally, `_wrap_deterministic_regime_transition` converts this to a one-hot
probability array using `jax.nn.one_hot`:

In [None]:
@categorical
class RegimeIdExample:
    working: int
    retired: int
    dead: int


# Deterministic: retire once age reaches retirement_age (65 in the call below)
def next_regime_det(age: float, retirement_age: float) -> ScalarInt:
    return jnp.where(
        age >= retirement_age, RegimeIdExample.retired, RegimeIdExample.working
    )


# What pylcm does internally:
regime_idx = next_regime_det(age=50.0, retirement_age=65.0)
one_hot = jax.nn.one_hot(regime_idx, num_classes=3)
print(f"Regime index: {int(regime_idx)}")
print(f"One-hot:      {one_hot}  (= [P(working), P(retired), P(dead)])")

### Stochastic transitions → probability array

A `MarkovRegimeTransition` wraps a function that directly returns a probability
array, so no one-hot encoding is needed:

In [None]:
def next_regime_stoch(survival_prob: float) -> FloatND:
    """Alive → [P(working), P(retired), P(dead)]."""
    return jnp.array([survival_prob, 0.0, 1 - survival_prob])


probs = next_regime_stoch(survival_prob=0.98)
print(f"Probabilities: {probs}  (= [P(working), P(retired), P(dead)])")

After wrapping, the probability array is further converted to a dictionary keyed by
regime name (via `_wrap_regime_transition_probs`), giving a uniform internal
representation regardless of whether the original transition was deterministic or
stochastic.
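The sketch below shows this last step, keying a probability array by regime name. It assumes that position `i` of the array corresponds to regime ID `i`. `probs_to_dict` is a hypothetical helper and not the actual `_wrap_regime_transition_probs`. A one-hot array and a stochastic array end up in the same dictionary form.

```python
import jax
import jax.numpy as jnp

# Assumed order: position i of the array corresponds to regime ID i.
regime_names = ("working", "retired", "dead")


def probs_to_dict(probs, names):
    """Key a probability array by regime name (sketch only)."""
    return {name: probs[idx] for idx, name in enumerate(names)}


deterministic = probs_to_dict(jax.nn.one_hot(jnp.array(1), num_classes=3), regime_names)
stochastic = probs_to_dict(jnp.array([0.98, 0.0, 0.02]), regime_names)
```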

## How Transitions Compose in the Bellman Equation

The value function computation depends on the regime type:

### Terminal regimes

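No continuation value. The value function is the utility, maximized over any
actions available in the regime:

$$
V_T(s) = \max_a U(s, a)
$$

If the regime has no actions, this reduces to $V_T(s) = U(s)$.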

### Non-terminal with deterministic regime transition

The continuation value comes from a single next-period regime:

$$
V_t(s) = \max_a \left\{ U(s, a) + \beta \, V_{t+1}^{r'}(s') \right\}
$$

where $r'$ is the deterministically chosen next regime and $s' = g(s, a)$ is the
next-period state.

### Non-terminal with stochastic regime transition

The continuation value is an expectation over possible next regimes:

$$
V_t(s) = \max_a \left\{ U(s, a) + \beta \sum_r p_r \, V_{t+1}^{r}(s') \right\}
$$

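where $p_r$ is the probability of transitioning to regime $r$ and $s'$ is the
next-period state produced by the transition that applies at the boundary into
$r$ (see the resolution priority above).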

### Adding stochastic state transitions

When a state has a Markov transition (`DiscreteMarkovGrid`) or shock grid, an
additional layer of expectation is added inside the max:

$$
V_t(s) = \max_a \left\{ U(s, a) + \beta \sum_r p_r
    \sum_{s'} P(s' \mid s) \, V_{t+1}^{r}(s') \right\}
$$

The inner sum handles the stochastic state transition; the outer sum handles the
stochastic regime transition. When either is deterministic, its corresponding sum
collapses to a single term.
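The nested expectation can be checked numerically for a single state and action. All inputs are illustrative. The regime probabilities follow the survival example, the health probabilities follow the `good` row of the health example, and the continuation values are arbitrary. The code computes the term inside the max twice: once as an inner then outer matrix product, and once as a single `jnp.einsum`.

```python
import jax.numpy as jnp

# Illustrative inputs for one state s and one action a.
utility_now = 1.0  # U(s, a)
beta = 0.95
regime_probs = jnp.array([0.98, 0.0, 0.02])  # p_r over [working, retired, dead]
state_probs = jnp.array([0.1, 0.9])  # P(s' | s) over [bad, good]
v_next = jnp.array(
    [
        [2.0, 3.0],  # V_{t+1}^{working}(s')
        [1.5, 2.5],  # V_{t+1}^{retired}(s')
        [0.0, 0.0],  # V_{t+1}^{dead}(s')
    ]
)

inner = v_next @ state_probs  # inner sum over s', one entry per regime
continuation = regime_probs @ inner  # outer sum over r
q_value = utility_now + beta * continuation  # the term inside the max over a

# The same double sum written in one step.
continuation_einsum = jnp.einsum("r,s,rs->", regime_probs, state_probs, v_next)
```

Taking the max over $a$ would then compare such values across the action grid.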

### Summary

| Component | Deterministic | Stochastic |
|---|---|---|
| State transition | $s' = g(s, a)$ | $\sum_{s'} P(s' \mid s) \, V(s')$ |
| Regime transition | One-hot $\rightarrow$ single $V^{r'}$ | $\sum_r p_r \, V^r$ |
| Internal regime format | One-hot probability array (via `jax.nn.one_hot`) | Probability array, used as returned |

The uniform probability format means the backward induction algorithm treats
every regime transition the same way: a deterministic transition is just the
special case where one probability is 1 and the rest are 0.
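This special case can be verified directly. Weighting continuation values with a one-hot vector returns exactly the value of the chosen regime. The values below are illustrative.

```python
import jax
import jax.numpy as jnp

v_next_by_regime = jnp.array([4.0, 2.5, 0.0])  # illustrative V_{t+1}^r(s') for r = 0, 1, 2
chosen = jnp.array(1)  # deterministically chosen next regime r'

weights = jax.nn.one_hot(chosen, num_classes=3)  # [0., 1., 0.]
weighted = jnp.dot(weights, v_next_by_regime)  # sum_r p_r V^r
direct = v_next_by_regime[chosen]  # V^{r'}
```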