# Grids

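Grids define the finite set of values that each state and action variable can take
when pylcm solves a model by dynamic programming. For continuous variables the grid is
a discretization, and its choice affects both solution accuracy and computation speed:
finer grids improve accuracy but increase cost.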

This page covers:

- **Continuous grids**: `LinSpacedGrid`, `LogSpacedGrid`, `IrregSpacedGrid`
- **Piecewise grids**: `PiecewiseLinSpacedGrid`, `PiecewiseLogSpacedGrid`
- **Discrete grids**: `DiscreteGrid` and the `@categorical` decorator
- **Discrete Markov grids**: `DiscreteMarkovGrid` for stochastic transitions
- **State transitions**: how grids connect periods via the `transition` parameter
- **Shock grids**: pointer to the [shocks page](shocks.md)
- **Grid selection guide**: summary table

In [None]:
import jax.numpy as jnp

from lcm import (
    DiscreteGrid,
    DiscreteMarkovGrid,
    IrregSpacedGrid,
    LinSpacedGrid,
    LogSpacedGrid,
    Piece,
    PiecewiseLinSpacedGrid,
    PiecewiseLogSpacedGrid,
    categorical,
)
from lcm.typing import ContinuousState, DiscreteState, FloatND

## Continuous Grids

Continuous grids represent variables that can take any value within a range (e.g.,
wealth, consumption). All continuous grids share three parameters:

- `start`: lower bound
- `stop`: upper bound
- `n_points`: number of grid points

The grids differ in how they space points between `start` and `stop`.

### `LinSpacedGrid`

Points are equally spaced. Use this as the default for variables with roughly
uniform density across the range.

In [None]:
lin = LinSpacedGrid(start=1, stop=100, n_points=5)
print("LinSpacedGrid:", lin.to_jax())
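Equal spacing means neighboring points differ by `(stop - start) / (n_points - 1)`. A
minimal NumPy sketch of that rule, assuming `LinSpacedGrid` includes both endpoints as
`np.linspace` does, gives the points 1, 25.75, 50.5, 75.25, 100 for the settings above.
The helper name `lin_spaced` is illustrative, not part of pylcm.

```python
import numpy as np


def lin_spaced(start: float, stop: float, n_points: int) -> np.ndarray:
    """Illustrative helper: equally spaced points, both endpoints included."""
    step = (stop - start) / (n_points - 1)
    return start + step * np.arange(n_points)


points = lin_spaced(1, 100, 5)
print(points)  # values: 1, 25.75, 50.5, 75.25, 100
```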

### `LogSpacedGrid`

Points are logarithmically spaced — denser at lower values, sparser at higher
values. Use this for variables like wealth where the value function has high
curvature near zero.

In [None]:
log = LogSpacedGrid(start=1, stop=100, n_points=5)
print("LogSpacedGrid:", log.to_jax())

Notice how the log grid places three of its five points at or below 10 (1, 3.16, and
10), while the linear grid spaces them evenly at 1, 25.75, 50.5, 75.25, and 100.
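Logarithmic spacing is equal spacing in `log(x)`, so neighboring points share a
constant ratio. The sketch below assumes `LogSpacedGrid` includes both endpoints, like
NumPy's `np.geomspace`; the helper `log_spaced` is illustrative. Because `log(0)` is
undefined, this construction needs a strictly positive `start`.

```python
import numpy as np


def log_spaced(start: float, stop: float, n_points: int) -> np.ndarray:
    """Illustrative helper: points equally spaced in log(x); needs 0 < start < stop."""
    return np.exp(np.linspace(np.log(start), np.log(stop), n_points))


points = log_spaced(1, 100, 5)
ratios = points[1:] / points[:-1]  # constant ratio 100 ** (1 / 4), about 3.16

# Small tolerance because exp(log(10)) need not equal 10 exactly in floating point.
n_at_or_below_10 = int(np.sum(points <= 10 + 1e-9))
print(points.round(3), "at or below 10:", n_at_or_below_10)
```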

### `IrregSpacedGrid`

Points are placed at user-specified locations. Use this when you need full control
over point placement — for example, Gauss-Hermite quadrature nodes.

In [None]:
# Probabilists' Gauss-Hermite nodes for 4 points, rounded to two decimals
irreg = IrregSpacedGrid(points=[-2.33, -0.74, 0.74, 2.33])
print("IrregSpacedGrid:", irreg.to_jax())
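Rather than typing quadrature nodes by hand, you can compute them. NumPy's
`numpy.polynomial.hermite_e.hermegauss` returns Gauss-Hermite nodes and weights for
the weight function `exp(-x**2 / 2)` (the "probabilists'" convention), and rescaling
by a mean and standard deviation gives nodes for a normal variable. The values of `mu`
and `sigma` below are illustrative, and passing `nodes.tolist()` as `points` assumes
`IrregSpacedGrid` accepts a plain list of floats, as in the cell above.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

n_points = 4
x, w = hermegauss(n_points)  # nodes and weights for weight exp(-x**2 / 2)
probs = w / w.sum()  # normalize so the weights form a probability distribution

# Illustrative parameters of a normal variable z ~ N(mu, sigma**2).
mu, sigma = 0.0, 1.0
nodes = mu + sigma * x

# An n-point rule integrates polynomials up to degree 2n - 1 exactly,
# so E[z**2] = mu**2 + sigma**2 is reproduced.
second_moment = np.sum(probs * nodes**2)
print("nodes:", nodes.round(4))
print("E[z^2]:", second_moment)
```

The resulting list could then be used as, for example,
`IrregSpacedGrid(points=nodes.tolist())`.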

You can also defer the point specification to runtime by providing only `n_points`.
The actual points are then supplied via the model parameters.

In [None]:
irreg_deferred = IrregSpacedGrid(n_points=4)
print("n_points:", irreg_deferred.n_points)

## Piecewise Grids

Piecewise grids combine multiple segments with different densities. They are useful
when you need a breakpoint at a specific value — for example, an eligibility
threshold for a means-tested program.

Each segment is defined by a `Piece` with an interval and a number of points.

### `PiecewiseLinSpacedGrid`

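Each segment is linearly spaced. Adjacent pieces must share a boundary, and standard
interval notation controls whether a piece contains its endpoints: a square bracket
(`[` or `]`) includes the endpoint, a parenthesis (`(` or `)`) excludes it. In the
example below, 50 is excluded from the first piece and included in the second, so the
breakpoint belongs to exactly one piece.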

In [None]:
pw_lin = PiecewiseLinSpacedGrid(
    pieces=(
        Piece(interval="[0, 50)", n_points=30),
        Piece(interval="[50, 500]", n_points=20),
    )
)
points = pw_lin.to_jax()
print(f"Total points: {pw_lin.n_points}")
print(f"First 5:  {points[:5]}")
print(f"Around 50: {points[28:32]}")
print(f"Last 5:   {points[-5:]}")

The breakpoint at 50 is guaranteed to be a grid point. This prevents interpolation
across a potential discontinuity in the value function.
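One way to build such a grid is to lay out each piece with `np.linspace` and drop the
upper endpoint of every half-open piece, so the shared boundary appears only once. The
helper `piecewise_lin_spaced` and its `(lo, hi, n_points, include_hi)` tuples are
hypothetical; this is a sketch of the idea, and how pylcm actually distributes the
points of a half-open piece may differ.

```python
import numpy as np


def piecewise_lin_spaced(pieces):
    """Hypothetical helper: concatenate linearly spaced pieces.

    Each piece is (lo, hi, n_points, include_hi). A half-open piece [lo, hi)
    leaves out hi, which the next piece [hi, ...] then contains.
    """
    segments = [
        np.linspace(lo, hi, n, endpoint=include_hi) for lo, hi, n, include_hi in pieces
    ]
    return np.concatenate(segments)


points = piecewise_lin_spaced([(0, 50, 30, False), (50, 500, 20, True)])
print(len(points), points[28:32])  # the breakpoint 50 sits at index 30
```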

### `PiecewiseLogSpacedGrid`

Each segment is logarithmically spaced. Use this for wealth-like variables where you
want dense coverage at low values and a specific breakpoint.

In [None]:
pw_log = PiecewiseLogSpacedGrid(
    pieces=(
        Piece(interval="[0.1, 10)", n_points=50),
        Piece(interval="[10, 1000]", n_points=30),
    )
)
points = pw_log.to_jax()
print(f"Total points: {pw_log.n_points}")
print(f"First 5:  {points[:5]}")
print(f"Last 5:   {points[-5:]}")
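A log-spaced piecewise grid can be sketched the same way with `np.geomspace`, which
keeps a constant ratio between neighboring points within each piece and thereby
concentrates points at low values. The helper `piecewise_log_spaced` is hypothetical,
all bounds must be strictly positive, and pylcm's own point placement may differ.

```python
import numpy as np


def piecewise_log_spaced(pieces):
    """Hypothetical helper: concatenate log-spaced pieces (all bounds > 0).

    Each piece is (lo, hi, n_points, include_hi); half-open pieces leave out hi.
    """
    return np.concatenate(
        [np.geomspace(lo, hi, n, endpoint=include_hi) for lo, hi, n, include_hi in pieces]
    )


points = piecewise_log_spaced([(0.1, 10, 50, False), (10, 1000, 30, True)])

# Within the first piece, every step multiplies by the same factor.
first_ratios = points[1:50] / points[:49]
print(len(points), points[0], points[50], points[-1])
```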

## Discrete Grids

Discrete grids represent categorical variables with a finite set of values (e.g.,
employment status, education level). The `@categorical` decorator creates a class
that maps labels to consecutive integer codes starting from 0.

### The `@categorical` decorator

In [None]:
@categorical
class EducationLevel:
    high_school: int
    college: int
    graduate: int


print("high_school:", EducationLevel.high_school)
print("college:    ", EducationLevel.college)
print("graduate:   ", EducationLevel.graduate)
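The code assignment can be mimicked in a few lines. The sketch below uses a
hypothetical `toy_categorical` decorator that numbers the annotated fields 0, 1, 2, ...
in declaration order; it is not pylcm's implementation, and the class is named
`ToyEducationLevel` so it does not shadow `EducationLevel` in later cells.

```python
def toy_categorical(cls):
    """Hypothetical stand-in for @categorical: field i gets integer code i."""
    for code, name in enumerate(cls.__annotations__):
        setattr(cls, name, code)
    return cls


@toy_categorical
class ToyEducationLevel:
    high_school: int
    college: int
    graduate: int


print(ToyEducationLevel.high_school, ToyEducationLevel.college, ToyEducationLevel.graduate)
# prints: 0 1 2
```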

### `DiscreteGrid`

Wraps a categorical class into a grid. Use this for deterministic discrete
variables — either fixed states or states with deterministic transitions.

In [None]:
edu_grid = DiscreteGrid(EducationLevel)

print("Categories:", edu_grid.categories)
print("Codes:     ", edu_grid.codes)
print("JAX array: ", edu_grid.to_jax())
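Because the codes run from 0 to the number of categories minus one, they can index
per-category arrays directly. The sketch below uses plain NumPy; the wage premia are
made-up numbers, and the codes are written out by hand on the assumption that they
follow the declaration order of `EducationLevel`.

```python
import numpy as np

# Codes assumed to follow the declaration order of EducationLevel.
HIGH_SCHOOL, COLLEGE, GRADUATE = 0, 1, 2

# Hypothetical wage premium per education level, indexed by code.
wage_premium = np.array([1.0, 1.4, 1.7])

# A vector of education codes selects each individual's premium.
education = np.array([COLLEGE, HIGH_SCHOOL, GRADUATE, COLLEGE])
premium_by_person = wage_premium[education]
print(premium_by_person)
```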

## Discrete Markov Grids

`DiscreteMarkovGrid` extends `DiscreteGrid` with stochastic (Markov) transitions.
Instead of a deterministic transition function that returns the next state, the
transition function returns a **probability array** over all categories.

Use this for states like health status where the next-period value is drawn from a
probability distribution that depends on the current state.

In [None]:
@categorical
class HealthStatus:
    bad: int
    good: int


def health_transition(health: DiscreteState) -> FloatND:
    """Markov transition for health.

    Returns probability array [P(bad), P(good)] given current health.
    """
    return jnp.where(
        health == HealthStatus.good,
        jnp.array([0.1, 0.9]),  # good → 90% stay good
        jnp.array([0.6, 0.4]),  # bad → 40% recover
    )


health_grid = DiscreteMarkovGrid(HealthStatus, transition=health_transition)
print("Categories:", health_grid.categories)
print("Codes:     ", health_grid.codes)

During the solve step, pylcm computes the probability-weighted expectation over
next-period states. During simulation, it draws from the transition probabilities.
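Both operations can be sketched with NumPy, using the same probabilities as
`health_transition`. The next-period values `v_next` are made-up numbers standing in
for a value function evaluated at each health category, and the sampling uses NumPy's
random generator; pylcm's internal implementation may differ.

```python
import numpy as np

BAD, GOOD = 0, 1  # codes in HealthStatus declaration order

# Row i holds next-period probabilities [P(bad), P(good)] given current health code i.
transition_matrix = np.array(
    [
        [0.6, 0.4],  # from bad
        [0.1, 0.9],  # from good
    ]
)
assert np.allclose(transition_matrix.sum(axis=1), 1.0)

# Hypothetical next-period values at bad and good health.
v_next = np.array([10.0, 20.0])

# Solve step: probability-weighted expectation, one entry per current state.
expected_v = transition_matrix @ v_next  # 14.0 from bad, 19.0 from good

# Simulation step: draw next-period health for many agents currently in good health.
rng = np.random.default_rng(seed=0)
draws = rng.choice(2, size=100_000, p=transition_matrix[GOOD])
share_good = (draws == GOOD).mean()  # close to 0.9
print(expected_v, share_good)
```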

## State Transitions on Grids

Every state grid in a regime, except shock grids (whose transitions are intrinsic),
must specify a `transition` parameter. It defines how the state variable arrives at its
current value, given the previous period's states, actions, and parameters. The
`transition` parameter takes one of three forms:

1. **Callable**: a function that computes the state from last period's variables
2. **`None`**: the state is fixed (time-invariant)
3. **Mapping**: per-boundary transitions for multi-regime models

State transitions are **backward-looking**: they live on the grid that *receives*
the state value. In multi-regime models, a per-boundary mapping is placed on the
**target** regime's grid to describe how to arrive from a different regime. This
contrasts with the regime's own `transition`, which is forward-looking (see the
[regimes page](regimes.ipynb)).

Action grids never have transitions — they are choice variables, not state
variables.

### Callable transition

The most common form. The function computes the current state value from last
period's variables. Its argument names are resolved from the model's namespace
(states, actions, auxiliary functions, parameters).

In [None]:
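from lcm.typing import ContinuousAction


def next_wealth(
    wealth: ContinuousState,
    consumption: ContinuousAction,  # consumption is an action, not a state
    interest_rate: float,
) -> ContinuousState:
    """Wealth this period, given last period's wealth, consumption, and interest rate."""
    return (1 + interest_rate) * (wealth - consumption)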


wealth_grid = LinSpacedGrid(start=0, stop=100, n_points=50, transition=next_wealth)
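Because a transition is an ordinary function of arrays, it can be checked outside
pylcm. The sketch below re-implements the same law of motion in NumPy (the name
`wealth_from_last_period` and the numbers are illustrative), assumes the wealth grid's
points match `np.linspace(0, 100, 50)`, and evaluates the function by broadcasting.
Almost none of the results land on a grid point, which is why a grid-based solver has
to interpolate the continuation value at off-grid wealth levels.

```python
import numpy as np


def wealth_from_last_period(wealth, consumption, interest_rate):
    """Same formula as next_wealth: (1 + r) * (wealth - consumption)."""
    return (1 + interest_rate) * (wealth - consumption)


wealth_points = np.linspace(0, 100, 50)  # assumed to match wealth_grid
consumption_levels = np.array([0.0, 10.0, 25.0])

# Rows are consumption levels, columns are last period's wealth points.
new_wealth = wealth_from_last_period(
    wealth_points[np.newaxis, :], consumption_levels[:, np.newaxis], 0.05
)

# Which results coincide with a grid point?
on_grid = np.isclose(new_wealth[..., np.newaxis], wealth_points).any(axis=-1)
print(new_wealth.shape, "share on grid:", on_grid.mean())
```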

### Fixed state (`None`)

For states that don't change over time (e.g., a fixed education level). An identity
transition is auto-generated internally.

In [None]:
edu_fixed = DiscreteGrid(EducationLevel, transition=None)

### Per-boundary mapping

When discrete categories differ across regimes, or a state needs a custom mapping at
a regime boundary, use a mapping keyed by `(source_regime, target_regime)` pairs.
This is placed on the **target** regime's grid.

In [None]:
# Working regime has 3 health categories
@categorical
class HealthWorking:
    disabled: int
    bad: int
    good: int


# Retirement regime has 2 health categories (no disability)
@categorical
class HealthRetired:
    bad: int
    good: int


def map_working_to_retired(health: DiscreteState) -> DiscreteState:
    """Map 3-category working health to 2-category retired health."""
    # disabled and bad both map to bad in retirement
    return jnp.where(
        health == HealthWorking.good, HealthRetired.good, HealthRetired.bad
    )


# On the target (retired) regime's grid:
health_retired_grid = DiscreteGrid(
    HealthRetired,
    transition={
        ("working", "retired"): map_working_to_retired,
    },
)
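The mapping can be checked on every working-regime code at once. The sketch below
swaps `jnp.where` for `np.where`, which behaves the same way here, and writes the codes
as plain integers on the assumption that `@categorical` numbers the fields of
`HealthWorking` and `HealthRetired` in declaration order.

```python
import numpy as np

# Codes assumed to follow declaration order, as @categorical assigns them.
WORKING_DISABLED, WORKING_BAD, WORKING_GOOD = 0, 1, 2
RETIRED_BAD, RETIRED_GOOD = 0, 1


def map_working_to_retired_np(health):
    """NumPy version of map_working_to_retired: disabled and bad both become bad."""
    return np.where(health == WORKING_GOOD, RETIRED_GOOD, RETIRED_BAD)


all_working_codes = np.array([WORKING_DISABLED, WORKING_BAD, WORKING_GOOD])
retired_codes = map_working_to_retired_np(all_working_codes)
print(retired_codes)  # [0 0 1]
```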

## Shock Grids

Shock grids discretize continuous stochastic processes. They have **intrinsic
transitions** (probability weights computed from the distribution) and do not
accept a `transition` parameter.

pylcm provides:

- **IID shocks**: `lcm.shocks.iid.Normal`, `lcm.shocks.iid.Uniform`,
  `lcm.shocks.iid.LogNormal`
- **AR(1) shocks**: `lcm.shocks.ar1.Rouwenhorst`, `lcm.shocks.ar1.Tauchen`

Import them as modules:

In [None]:
import lcm.shocks.ar1
import lcm.shocks.iid

income_shock = lcm.shocks.iid.Normal(
    n_points=5, gauss_hermite=False, mu=0.0, sigma=0.1, n_std=2.5
)
print("Normal shock grid:", income_shock.to_jax())

ar1_shock = lcm.shocks.ar1.Rouwenhorst(n_points=5, rho=0.9, sigma=0.1, mu=0.0)
print("Rouwenhorst grid: ", ar1_shock.to_jax())

See the [shocks page](shocks.md) for details on each shock type and its parameters.
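To illustrate what an intrinsic transition looks like for an AR(1) shock, here is the
textbook Rouwenhorst construction of a grid and transition matrix for
`y' = rho * y + eps` with `eps ~ N(0, sigma**2)`, written in NumPy. It is a sketch for
a zero-mean process and does not claim that `lcm.shocks.ar1.Rouwenhorst` uses the same
grid width or treats `mu` in any particular way.

```python
import numpy as np


def rouwenhorst(n_points: int, rho: float, sigma: float):
    """Textbook Rouwenhorst discretization of y' = rho * y + eps, eps ~ N(0, sigma**2).

    Returns the grid, shape (n_points,), and transition matrix P, whose row i
    holds next-period probabilities given current grid point i.
    """
    if n_points < 2:
        raise ValueError("n_points must be at least 2")
    p = (1 + rho) / 2
    P = np.array([[p, 1 - p], [1 - p, p]])
    for n in range(3, n_points + 1):
        Z = np.zeros((n, n))
        Z[:-1, :-1] += p * P
        Z[:-1, 1:] += (1 - p) * P
        Z[1:, :-1] += (1 - p) * P
        Z[1:, 1:] += p * P
        Z[1:-1, :] /= 2  # interior rows received two contributions
        P = Z
    sigma_y = sigma / np.sqrt(1 - rho**2)  # unconditional standard deviation
    half_width = np.sqrt(n_points - 1) * sigma_y
    grid = np.linspace(-half_width, half_width, n_points)
    return grid, P


grid, P = rouwenhorst(n_points=5, rho=0.9, sigma=0.1)
print("grid:", grid.round(4))
```

With this construction the conditional mean is matched exactly (`P @ grid` equals
`rho * grid`), and the stationary distribution is binomial over the grid points.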

## Grid Selection Guide

| Variable type | Recommended grid | Example |
|---|---|---|
| Continuous, uniform density | `LinSpacedGrid` | Age-independent income |
| Continuous, high curvature at low values | `LogSpacedGrid` | Wealth with CRRA utility |
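| Continuous, specific breakpoint needed | `PiecewiseLinSpacedGrid` | Wealth with eligibility threshold |
| Continuous, high curvature and a breakpoint | `PiecewiseLogSpacedGrid` | Wealth near a means-test threshold |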
| Continuous, custom point placement | `IrregSpacedGrid` | Quadrature nodes |
| Categorical, deterministic | `DiscreteGrid` | Education level, employment status |
| Categorical, stochastic | `DiscreteMarkovGrid` | Health status with Markov transitions |
| IID continuous shock | `lcm.shocks.iid.Normal` / `.Uniform` / `.LogNormal` | Income shock |
| Persistent AR(1) shock | `lcm.shocks.ar1.Rouwenhorst` / `.Tauchen` | Productivity process |