# Interpolation and Extrapolation

In dynamic programming, the value function $V(x)$ is computed on a discrete grid but
must be evaluated at arbitrary points during the solution process. State transitions
$x' = g(x, a)$ can produce next-period states that fall between grid points (requiring
interpolation) or outside the grid (requiring extrapolation).

pylcm does not use SciPy's `RegularGridInterpolator` or similar pre-built tools because it
needs to support grids whose points are only known at runtime, for example shock
grids whose locations depend on distributional parameters supplied via `params`.

This notebook explains how pylcm handles interpolation and extrapolation, using a CRRA
utility function $u(w) = \frac{w^{1-\gamma}}{1-\gamma}$ on a coarse wealth grid as a
running example.

## pylcm's two-step design

pylcm evaluates functions on arbitrary points in two steps:

1. **Coordinate finder**: Convert a physical value (e.g., wealth = 150) to a
   *generalized coordinate* — a fractional index into the grid. Values inside the grid
   produce coordinates in $[0, n-1]$; values outside produce coordinates outside this
   range.

2. **`map_coordinates`**: Take the generalized coordinates and the pre-computed array
   of function values. Perform linear interpolation (for coordinates inside $[0, n-1]$)
   or linear extrapolation (for coordinates outside).

Each grid type provides its own coordinate finder, optimized for its spacing pattern.
The `map_coordinates` function is the same for all grid types.
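The two steps can be sketched in NumPy for a 1D grid. This is an illustration under simplifying assumptions, not pylcm's code: `linspace_coordinate`, `interp_1d`, and `evaluate` are hypothetical names. Step 2 follows the clipped-index, unclipped-weight rule described in the `map_coordinates` internals section.

```python
import numpy as np


def linspace_coordinate(value, start, stop, n_points):
    """Step 1 (hypothetical helper): physical values -> fractional grid indices."""
    step_length = (stop - start) / (n_points - 1)
    return (np.asarray(value, dtype=float) - start) / step_length


def interp_1d(values, coords):
    """Step 2 (hypothetical helper): linear interpolation/extrapolation at coordinates."""
    values = np.asarray(values, dtype=float)
    coords = np.asarray(coords, dtype=float)
    # Clip the lower index so that both neighbors are valid array positions...
    lower = np.clip(np.floor(coords), 0, len(values) - 2).astype(int)
    # ...but keep the weight unclipped, so out-of-range coordinates extrapolate.
    upper_weight = coords - lower
    return (1 - upper_weight) * values[lower] + upper_weight * values[lower + 1]


def evaluate(values, find_coordinate, x):
    """Compose a grid-specific coordinate finder with the shared step 2."""
    return interp_1d(values, find_coordinate(x))


grid = np.linspace(1.0, 400.0, 10)
values = 2.0 * grid + 3.0  # linear, so interpolation and extrapolation are exact


def finder(x):
    return linspace_coordinate(x, 1.0, 400.0, 10)


x = np.array([0.5, 150.0, 450.0])
print(evaluate(values, finder, x))  # approximately [4, 303, 903]
```

Because `evaluate` only needs a function from physical values to coordinates, supporting another grid type means swapping the finder; the interpolation step stays the same.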

## `LinSpacedGrid`

For a linearly spaced grid with `start`, `stop`, and `n_points`, the coordinate finder
uses the O(1) formula:

$$
\text{coordinate} = \frac{\text{value} - \text{start}}{\text{step\_length}},
\quad \text{step\_length} = \frac{\text{stop} - \text{start}}{n_\text{points} - 1}
$$

Values inside the grid produce coordinates in $[0, n_\text{points}-1]$. Values below
`start` produce negative coordinates; values above `stop` produce coordinates above
$n_\text{points}-1$. Both cases lead to linear extrapolation in `map_coordinates`.
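As a worked example, the 10-point grid on $[1, 400]$ used below gives

$$
\text{step\_length} = \frac{400 - 1}{10 - 1} \approx 44.33,
\qquad
\text{coordinate}(150) = \frac{150 - 1}{44.33} \approx 3.36,
$$

so wealth 150 lies about a third of the way from grid point 3 ($\approx 134.0$) to grid point 4 ($\approx 178.3$). Likewise, wealth 450 maps to $449 / 44.33 \approx 10.13 > 9$, which triggers extrapolation.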

In [None]:
import jax.numpy as jnp
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from lcm import LinSpacedGrid, LogSpacedGrid, Piece, PiecewiseLinSpacedGrid
from lcm.ndimage import map_coordinates

blue, orange, green = "#4C78A8", "#F58518", "#54A24B"


def crra(wealth, gamma=1.5):
    """CRRA utility: u(w) = w^(1-gamma) / (1-gamma)."""
    return wealth ** (1 - gamma) / (1 - gamma)


# Coarse linearly spaced wealth grid (10 points)
lin_grid = LinSpacedGrid(start=1, stop=400, n_points=10)
grid_points = lin_grid.to_jax()

# CRRA values on the grid
V = crra(grid_points)

print("Grid points:", grid_points)
print("V on grid:  ", V)

In [None]:
# Query points: some inside the grid, some outside
query_inside = jnp.array([25.0, 100.0, 250.0])
query_outside = jnp.array([0.5, 450.0])

coords_inside = lin_grid.get_coordinate(query_inside)
coords_outside = lin_grid.get_coordinate(query_outside)

print("Inside grid:")
for w, c in zip(query_inside, coords_inside, strict=True):
    print(f"  wealth = {w:.1f}  →  coordinate = {c:.4f}  (in [0, 9])")

print("\nOutside grid:")
for w, c in zip(query_outside, coords_outside, strict=True):
    in_range = "< 0" if c < 0 else "> 9"
    print(f"  wealth = {w:.1f}  →  coordinate = {c:.4f}  ({in_range})")

In [None]:
# Interpolate and extrapolate using map_coordinates
query_all = jnp.concatenate([query_inside, query_outside])
coords_all = lin_grid.get_coordinate(query_all)

V_approx = map_coordinates(input=V, coordinates=[coords_all])
V_true = crra(query_all)

print(
    f"{'Wealth':>8}  {'Coordinate':>11}  "
    f"{'Approximated':>12}  {'True':>12}  {'Error':>10}"
)
print("-" * 60)
for w, c, va, vt in zip(query_all, coords_all, V_approx, V_true, strict=True):
    print(f"{w:8.1f}  {c:11.4f}  {va:12.6f}  {vt:12.6f}  {va - vt:10.6f}")

In [None]:
# Dense points for smooth curves (extending beyond the grid for extrapolation)
x_dense = jnp.linspace(0.5, 450, 500)
coords_dense = lin_grid.get_coordinate(x_dense)
V_interp_dense = map_coordinates(input=V, coordinates=[coords_dense])
V_true_dense = crra(x_dense)

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=x_dense,
        y=V_true_dense,
        mode="lines",
        line={"color": "gray", "width": 1},
        name="True CRRA",
    )
)
fig.add_trace(
    go.Scatter(
        x=x_dense,
        y=V_interp_dense,
        mode="lines",
        line={"color": orange, "width": 2},
        name="Interpolation / Extrapolation",
    )
)
fig.add_trace(
    go.Scatter(
        x=grid_points,
        y=V,
        mode="markers",
        marker={"color": blue, "size": 8},
        name="Grid points",
    )
)
fig.add_vline(x=1, line={"color": "gray", "dash": "dot", "width": 1})
fig.add_vline(x=400, line={"color": "gray", "dash": "dot", "width": 1})
fig.update_layout(
    title="LinSpacedGrid: Interpolation and Extrapolation",
    xaxis_title="Wealth",
    yaxis_title="u(w)",
)
fig.show()

## `LogSpacedGrid`

For a logarithmically spaced grid, points are denser at lower values, which suits
functions with high curvature near zero (like CRRA utility).

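The coordinate finder `get_logspace_coordinate` works by:

1. Transforming `value`, `start`, and `stop` to log space, where the grid points are
   evenly spaced
2. Finding the two grid points that bound `value`, via their ranks (indices) in log space
3. Computing the fractional position of `value` between those two points in the original
   (physical) space and adding it to the lower rank

This gives a coordinate that accounts for the non-uniform spacing. Because step 3
measures the position in physical space, the coordinate is linear in `value` within each
segment, so `map_coordinates` interpolates linearly in wealth between neighboring grid
points. With the same number of grid points, a log-spaced grid captures the curvature of
CRRA much better than a linearly spaced one, because it places more points where the
curvature is highest.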
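A NumPy sketch of the three steps, assuming the grid points are `np.geomspace(start, stop, n_points)` (evenly spaced in log space). The helper name `logspace_coordinate` is hypothetical; this illustrates the described algorithm rather than reproducing pylcm's source.

```python
import numpy as np


def logspace_coordinate(value, start, stop, n_points):
    """Fractional index of `value` in a log-spaced grid (hypothetical helper)."""
    value = np.asarray(value, dtype=float)
    # Step 1: in log space the grid is evenly spaced.
    log_start = np.log(start)
    log_step = (np.log(stop) - log_start) / (n_points - 1)
    # Step 2: rank of the lower bounding grid point, then both bounding points.
    lower_rank = np.floor((np.log(value) - log_start) / log_step)
    lower_point = np.exp(log_start + log_step * lower_rank)
    upper_point = np.exp(log_start + log_step * (lower_rank + 1))
    # Step 3: fractional position between the bounding points, in physical space.
    return lower_rank + (value - lower_point) / (upper_point - lower_point)


grid = np.geomspace(1.0, 400.0, 10)
print(logspace_coordinate(grid, 1.0, 400.0, 10))  # approximately 0, 1, ..., 9
# The geometric midpoint of grid[2] and grid[3] is halfway in log space,
# but its coordinate is measured in physical space:
print(logspace_coordinate(np.sqrt(grid[2] * grid[3]), 1.0, 400.0, 10))  # about 2.42
```

The second result is below 2.5 because step 3 measures the position in physical space rather than in log space.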

In [None]:
log_grid = LogSpacedGrid(start=1, stop=400, n_points=10)
log_points = log_grid.to_jax()
V_log = crra(log_points)

print("LinSpaced grid:", jnp.round(grid_points, 1))
print("LogSpaced grid:", jnp.round(log_points, 1))

In [None]:
x_eval = jnp.linspace(1, 400, 500)
V_true_eval = crra(x_eval)

# LinSpaced interpolation
lin_coords = lin_grid.get_coordinate(x_eval)
V_lin_interp = map_coordinates(input=V, coordinates=[lin_coords])

# LogSpaced interpolation
log_coords = log_grid.get_coordinate(x_eval)
V_log_interp = map_coordinates(input=V_log, coordinates=[log_coords])

fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=(
        "Interpolated value functions",
        "Interpolation error",
    ),
)

# Left: Interpolated curves
fig.add_trace(
    go.Scatter(
        x=x_eval,
        y=V_true_eval,
        mode="lines",
        line={"color": "gray", "width": 1},
        name="True CRRA",
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=x_eval,
        y=V_lin_interp,
        mode="lines",
        line={"color": blue, "width": 2},
        name="LinSpaced (10 pts)",
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=x_eval,
        y=V_log_interp,
        mode="lines",
        line={"color": orange, "width": 2},
        name="LogSpaced (10 pts)",
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=grid_points,
        y=V,
        mode="markers",
        marker={"color": blue, "size": 6},
        name="LinSpaced grid",
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=log_points,
        y=V_log,
        mode="markers",
        marker={"color": orange, "size": 6},
        name="LogSpaced grid",
    ),
    row=1,
    col=1,
)

# Right: Absolute errors
fig.add_trace(
    go.Scatter(
        x=x_eval,
        y=jnp.abs(V_lin_interp - V_true_eval),
        mode="lines",
        line={"color": blue},
        name="LinSpaced error",
        showlegend=False,
    ),
    row=1,
    col=2,
)
fig.add_trace(
    go.Scatter(
        x=x_eval,
        y=jnp.abs(V_log_interp - V_true_eval),
        mode="lines",
        line={"color": orange},
        name="LogSpaced error",
        showlegend=False,
    ),
    row=1,
    col=2,
)

fig.update_xaxes(title_text="Wealth", row=1, col=1)
fig.update_xaxes(title_text="Wealth", row=1, col=2)
fig.update_yaxes(title_text="u(w)", row=1, col=1)
fig.update_yaxes(title_text="|Error|", row=1, col=2)
fig.update_layout(height=450, width=900)
fig.show()

## `ShockGrid` (Normal)

ShockGrids discretize continuous distributions. Their grid points depend on
distributional parameters (e.g., `mu` and `sigma` for a Normal shock), which may only be
known at runtime when supplied via `params`. Because the points are determined
dynamically and, in general, need not be evenly spaced, ShockGrids do not use the O(1)
linspace formula; instead, they use `get_irreg_coordinate` internally.

`get_irreg_coordinate` handles arbitrary sorted point sequences:

1. Use `jnp.searchsorted` to find the bounding grid points (O(log n))
2. Linearly interpolate between the bounding points to get the fractional coordinate
3. For values outside the grid, extrapolate using the slope of the nearest segment
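The three steps can be sketched in NumPy, using `np.searchsorted` in place of `jnp.searchsorted`. The helper name is hypothetical, and choosing `side="right"` with a clipped index to pick the bounding segment is an assumption of this sketch, not taken from pylcm's source.

```python
import numpy as np


def irreg_coordinate(value, points):
    """Fractional index of `value` in a sorted, arbitrarily spaced grid (hypothetical helper)."""
    points = np.asarray(points, dtype=float)
    value = np.asarray(value, dtype=float)
    # Step 1: binary search for the upper bounding point; clipping makes
    # out-of-range values reuse the first or last segment.
    upper = np.clip(np.searchsorted(points, value, side="right"), 1, len(points) - 1)
    lower = upper - 1
    # Steps 2-3: fractional position within that segment. Outside the grid the
    # fraction falls below 0 or above 1, extending the segment's slope.
    frac = (value - points[lower]) / (points[upper] - points[lower])
    return lower + frac


points = np.array([0.0, 1.0, 3.0, 7.0])  # irregular spacing
print(irreg_coordinate([2.0, 5.0, -1.0, 9.0], points))  # coordinates 1.5, 2.5, -1.0, 3.5
```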

In [None]:
import lcm.shocks.iid

shock = lcm.shocks.iid.Normal(mu=0.0, sigma=1.0, n_std=2.5, n_points=7)
shock_points = shock.to_jax()

print("Shock grid points:", shock_points)

# CRRA of (base wealth + shock)
base_wealth = 100.0
V_shock = crra(base_wealth + shock_points)

# Query points inside and outside the shock grid
shock_query = jnp.array([-3.0, -1.0, 0.5, 2.0, 3.0])
shock_coords = shock.get_coordinate(shock_query)

n_last = len(shock_points) - 1

print("\nShock query points and coordinates:")
for q, c in zip(shock_query, shock_coords, strict=True):
    in_out = "inside" if 0 <= c <= n_last else "outside"
    print(f"  ε = {q:5.1f}  →  coordinate = {c:6.3f}  ({in_out})")

In [None]:
eps_dense = jnp.linspace(-3.5, 3.5, 300)
shock_coords_dense = shock.get_coordinate(eps_dense)
V_shock_interp = map_coordinates(input=V_shock, coordinates=[shock_coords_dense])
V_shock_true = crra(base_wealth + eps_dense)

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=eps_dense,
        y=V_shock_true,
        mode="lines",
        line={"color": "gray", "width": 1},
        name="True CRRA(100 + ε)",
    )
)
fig.add_trace(
    go.Scatter(
        x=eps_dense,
        y=V_shock_interp,
        mode="lines",
        line={"color": orange, "width": 2},
        name="Interpolation / Extrapolation",
    )
)
fig.add_trace(
    go.Scatter(
        x=shock_points,
        y=V_shock,
        mode="markers",
        marker={"color": blue, "size": 8},
        name="Grid points",
    )
)
fig.add_vline(
    x=float(shock_points[0]),
    line={"color": "gray", "dash": "dot", "width": 1},
)
fig.add_vline(
    x=float(shock_points[-1]),
    line={"color": "gray", "dash": "dot", "width": 1},
)
fig.update_layout(
    title="Normal ShockGrid: Interpolation and Extrapolation",
    xaxis_title="Shock (ε)",
    yaxis_title="u(100 + ε)",
)
fig.show()

```{note}
This example is not very illustrative because `Normal` produces linearly spaced
grid points — making `get_irreg_coordinate` equivalent to `get_linspace_coordinate`
here. The difference will become visible once ShockGrids support Gauss-Hermite
quadrature points ([#248](https://github.com/OpenSourceEconomics/pylcm/issues/248)),
which are genuinely irregularly spaced.
```

## `PiecewiseLinSpacedGrid`

Some models feature eligibility thresholds, e.g., a means-tested program that applies
only below a wealth cutoff. The value function may then jump at the threshold (eligible
households receive a transfer, ineligible ones do not). `PiecewiseLinSpacedGrid` lets
you place a breakpoint at the threshold so that the threshold itself is a grid point,
which limits how far interpolation can smear the jump.

For a grid with pieces $[a, b)$ and $[b, c]$:

1. **Piece selection**: `jnp.searchsorted` on breakpoints determines which piece a
   value belongs to
2. **Local coordinate**: `get_linspace_coordinate` within the piece
3. **Global coordinate**: offset by the cumulative number of points in preceding pieces

The breakpoint $b$ is the first point of the second piece, so it is guaranteed to be a
grid point. Because `map_coordinates` interpolates linearly between adjacent grid points,
every query at or above $b$ is computed only from points of the second piece, so values
on that side never pick up values from below the threshold. The jump is not resolved
exactly, however: $b$ and the last point of the first piece are adjacent in the array, so
queries strictly between them blend the values on both sides of the discontinuity. The
error is confined to that single segment.
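A NumPy sketch of the three steps under a hypothetical layout: each piece is described by `(start, step_length, n_points)` with evenly spaced points, and the first point of each later piece is a breakpoint. This is not pylcm's `Piece` API or its rule for placing points in half-open intervals. The last lines interpolate a CRRA function with a transfer below 50, showing where blending across the jump occurs.

```python
import numpy as np

# Hypothetical layout: (start, step_length, n_points) per piece.
pieces = [(1.0, 12.25, 4), (50.0, 50.0, 8)]  # 1, 13.25, 25.5, 37.75 | 50, 100, ..., 400


def piecewise_coordinate(value, pieces):
    """Global fractional index in a piecewise evenly spaced grid (hypothetical helper)."""
    starts = np.array([start for start, _, _ in pieces])
    steps = np.array([step for _, step, _ in pieces])
    # Number of grid points in all preceding pieces.
    offsets = np.concatenate([[0], np.cumsum([n for _, _, n in pieces])[:-1]])
    value = np.asarray(value, dtype=float)
    # Step 1: choose the piece by binary search on the breakpoints.
    k = np.searchsorted(starts[1:], value, side="right")
    # Step 2: local coordinate from the linspace formula within that piece.
    local = (value - starts[k]) / steps[k]
    # Step 3: shift by the number of points in preceding pieces.
    return offsets[k] + local


grid = np.concatenate([start + step * np.arange(n) for start, step, n in pieces])
# CRRA utility (gamma = 1.5) plus a transfer of 0.5 below the threshold of 50.
values = -2.0 / np.sqrt(grid) + 0.5 * (grid < 50.0)

coords = piecewise_coordinate([45.0, 50.0, 60.0], pieces)
lower = np.clip(np.floor(coords), 0, len(grid) - 2).astype(int)
upper_weight = coords - lower
interpolated = (1 - upper_weight) * values[lower] + upper_weight * values[lower + 1]
print(coords)  # approximately [3.59, 4.0, 4.2]
print(interpolated)  # the value at 45 mixes grid[3] (transfer) and grid[4] (none)
```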

In [None]:
pw_grid = PiecewiseLinSpacedGrid(
    pieces=(
        Piece(interval="[1, 50)", n_points=5),
        Piece(interval="[50, 400]", n_points=7),
    )
)
pw_points = pw_grid.to_jax()

# Value function with a jump: means-tested transfer of 0.5 for wealth below threshold
transfer = 0.5
threshold = 50.0

V_pw = crra(pw_points) + transfer * jnp.where(pw_points < threshold, 1.0, 0.0)

print(f"Total grid points: {pw_grid.n_points}")
print(f"Grid: {jnp.round(pw_points, 1)}")
print(f"\nBreakpoint at wealth = 50 is at index 5: grid[5] = {pw_points[5]:.1f}")
print(f"\nV just below threshold (grid[4]): {V_pw[4]:.4f}")
print(f"V at threshold        (grid[5]): {V_pw[5]:.4f}")

In [None]:
x_dense_pw = jnp.linspace(1, 400, 500)
pw_coords_dense = pw_grid.get_coordinate(x_dense_pw)
V_pw_interp = map_coordinates(input=V_pw, coordinates=[pw_coords_dense])

# Split traces at the threshold to avoid connecting lines across the jump
mask_below = x_dense_pw < threshold
x_below = x_dense_pw[mask_below]
x_above = x_dense_pw[~mask_below]

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=x_below,
        y=crra(x_below) + transfer,
        mode="lines",
        line={"color": "gray", "width": 1},
        name="True function",
    )
)
fig.add_trace(
    go.Scatter(
        x=x_above,
        y=crra(x_above),
        mode="lines",
        line={"color": "gray", "width": 1},
        name="True function",
        showlegend=False,
    )
)
fig.add_trace(
    go.Scatter(
        x=x_below,
        y=V_pw_interp[mask_below],
        mode="lines",
        line={"color": orange, "width": 2},
        name="Piecewise interpolation",
    )
)
fig.add_trace(
    go.Scatter(
        x=x_above,
        y=V_pw_interp[~mask_below],
        mode="lines",
        line={"color": orange, "width": 2},
        name="Piecewise interpolation",
        showlegend=False,
    )
)
fig.add_trace(
    go.Scatter(
        x=pw_points,
        y=V_pw,
        mode="markers",
        marker={"color": blue, "size": 8},
        name="Grid points",
    )
)
fig.update_layout(
    title="PiecewiseLinSpacedGrid: capturing a discontinuity",
    xaxis_title="Wealth",
    yaxis_title="V(w)",
)
fig.show()

## `map_coordinates` internals

pylcm's `map_coordinates` (`src/lcm/ndimage.py`) is a modified version of JAX's
`jax.scipy.ndimage.map_coordinates`. The key difference is in how it handles values
outside the grid.

The extrapolation behavior comes from the helper `_compute_indices_and_weights`:

```python
lower_index = jnp.clip(jnp.floor(coordinate), 0, input_size - 2)
upper_weight = coordinate - lower_index
lower_weight = 1 - upper_weight
```

- The lower index is clipped to $[0, n-2]$, so both `lower_index` and `lower_index + 1` are valid array positions
- But the **weight is not clipped** — it is simply `coordinate - lower_index`

For coordinates inside $[0, n-1]$, the upper weight falls in $[0, 1]$, giving standard
linear interpolation. For coordinates outside this range:

| Coordinate $c$ | Lower index | Upper weight | Result |
|----------------|-------------|--------------|--------|
| $0 \le c \le n-1$ | $\min(\lfloor c \rfloor, n-2)$ | $c - \text{lower index} \in [0, 1]$ | Linear interpolation |
| $c < 0$ | $0$ | $c < 0$ | Linear extrapolation using the first segment |
| $c > n-1$ | $n-2$ | $c - (n-2) > 1$ | Linear extrapolation using the last segment |

JAX's version, by contrast, handles out-of-range coordinates according to its `mode`
argument (for example, `constant` fills with `cval` and `nearest` repeats the edge
value), none of which gives linear extrapolation.
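To see why leaving the weight unclipped matters, the NumPy sketch below compares it with a variant that also clips the weight to $[0, 1]$. The helper `interpolate` and its `clip_weight` flag are hypothetical. Clipping holds the result at the edge value outside the grid, similar in effect to a `nearest` boundary mode, whereas the unclipped weight continues the first or last segment.

```python
import numpy as np

values = np.array([10.0, 20.0, 30.0, 40.0, 50.0])  # linear in the index
n = len(values)


def interpolate(coordinate, clip_weight):
    """1D linear interpolation at a fractional coordinate (hypothetical helper)."""
    lower = int(np.clip(np.floor(coordinate), 0, n - 2))
    upper_weight = coordinate - lower
    if clip_weight:
        # Clipping the weight pins out-of-range results to the edge values.
        upper_weight = np.clip(upper_weight, 0.0, 1.0)
    return (1 - upper_weight) * values[lower] + upper_weight * values[lower + 1]


for c in (-0.5, 1.7, 4.5):
    print(c, interpolate(c, clip_weight=False), interpolate(c, clip_weight=True))
# unclipped: about 5, 27, 55; clipped: about 10, 27, 50
```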

In [None]:
from lcm.ndimage import _compute_indices_and_weights

# A simple grid with 5 points
n = 5
V_demo = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0])

# Three cases: below grid, inside grid, above grid
cases = [
    (-0.5, "below grid"),
    (1.7, "inside grid"),
    (4.5, "above grid"),
]

for coord, label in cases:
    coord_arr = jnp.array(coord)
    [(lo_idx, lo_wt), (hi_idx, hi_wt)] = _compute_indices_and_weights(coord_arr, n)
    result = float(lo_wt * V_demo[lo_idx] + hi_wt * V_demo[hi_idx])
    print(f"coordinate = {coord:5.1f} ({label})")
    print(f"  lower_index = {int(lo_idx)}, upper_index = {int(hi_idx)}")
    print(f"  lower_weight = {float(lo_wt):.2f}, upper_weight = {float(hi_wt):.2f}")
    print(
        f"  result = {float(lo_wt):.2f} \u00d7 V[{int(lo_idx)}]"
        f" + {float(hi_wt):.2f} \u00d7 V[{int(hi_idx)}]"
    )
    print(
        f"         = {float(lo_wt):.2f} \u00d7 {float(V_demo[lo_idx]):.0f}"
        f" + {float(hi_wt):.2f} \u00d7 {float(V_demo[hi_idx]):.0f}"
    )
    print(f"         = {result:.1f}\n")

## Multi-dimensional interpolation

When the value function depends on multiple continuous states, each dimension gets its
own coordinate finder. The `map_coordinates` function then performs multi-linear
interpolation (or extrapolation) by combining the per-dimension coordinates.

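For a 2D case (e.g., wealth $\times$ income shock), this is bilinear interpolation: inside
the grid, the function value at a query point is a weighted average of the 4 corners of
the grid cell containing it, where each corner's weight is the product of its
per-dimension weights. Outside the grid the same formula applies with weights that can be
negative or exceed 1, which yields bilinear extrapolation.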
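A NumPy sketch of the 2D case, assuming the same clipped-index, unclipped-weight rule on each axis (the helpers `weights_1d` and `bilinear` are hypothetical). A function that is linear in each axis separately, such as $f(a, b) = ab + a + 2b$, is reproduced exactly, both inside and outside the grid.

```python
import numpy as np


def weights_1d(coordinate, size):
    """Lower index and (lower, upper) weights along one axis, extrapolating linearly."""
    lower = int(np.clip(np.floor(coordinate), 0, size - 2))
    upper_weight = coordinate - lower
    return lower, 1.0 - upper_weight, upper_weight


def bilinear(values, coord_0, coord_1):
    """Bilinear interpolation/extrapolation on a 2D array (sketch)."""
    i, wi_lo, wi_hi = weights_1d(coord_0, values.shape[0])
    j, wj_lo, wj_hi = weights_1d(coord_1, values.shape[1])
    # Each corner's weight is the product of its per-axis weights.
    return (
        wi_lo * wj_lo * values[i, j]
        + wi_lo * wj_hi * values[i, j + 1]
        + wi_hi * wj_lo * values[i + 1, j]
        + wi_hi * wj_hi * values[i + 1, j + 1]
    )


# Grid coordinates equal the axis values here: a = 0..3, b = 0..2.
A, B = np.meshgrid(np.arange(4.0), np.arange(3.0), indexing="ij")
values = A * B + A + 2 * B
print(bilinear(values, 1.25, 0.5))  # 2.875
```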

In [None]:
# 2D value function: V(wealth, shock) = CRRA(wealth + shock)
wealth_grid = LinSpacedGrid(start=10, stop=400, n_points=8)
w_points = wealth_grid.to_jax()

shock_2d = lcm.shocks.iid.Normal(mu=0.0, sigma=1.0, n_std=2.0, n_points=5)
s_points = shock_2d.to_jax()

# Evaluate on the 2D grid (shape: 8 x 5)
W, S = jnp.meshgrid(w_points, s_points, indexing="ij")
V_2d = crra(W + S)

print(f"Wealth grid: {jnp.round(w_points, 1)}")
print(f"Shock grid:  {jnp.round(s_points, 2)}")
print(f"V shape:     {V_2d.shape} (wealth \u00d7 shock)")

# Query point
w_query = jnp.array(150.0)
s_query = jnp.array(0.3)

# Per-dimension coordinates
w_coord = wealth_grid.get_coordinate(w_query)
s_coord = shock_2d.get_coordinate(s_query)

# 2D interpolation
V_interp_2d = map_coordinates(input=V_2d, coordinates=[w_coord, s_coord])
V_true_2d = crra(w_query + s_query)

print(f"\nQuery: wealth = {float(w_query)}, shock = {float(s_query)}")
print(f"Wealth coordinate: {float(w_coord):.4f}")
print(f"Shock coordinate:  {float(s_coord):.4f}")
print(f"Interpolated V:    {float(V_interp_2d):.6f}")
print(f"True V:            {float(V_true_2d):.6f}")
print(f"Error:             {float(V_interp_2d - V_true_2d):.6f}")

## Summary

| Grid type | Coordinate finder | Complexity | When to use |
|-----------|------------------|------------|-------------|
| `LinSpacedGrid` | `get_linspace_coordinate` | O(1) | Uniformly spaced state variables |
| `LogSpacedGrid` | `get_logspace_coordinate` | O(1) | States with high curvature at low values (e.g., wealth with CRRA) |
| `PiecewiseLinSpacedGrid` | searchsorted + `get_linspace_coordinate` | O(log k) + O(1), for k pieces | States with breakpoints (e.g., eligibility thresholds) |
| `IrregSpacedGrid` | `get_irreg_coordinate` | O(log n) | Arbitrary point placement |
| ShockGrids | `get_irreg_coordinate` | O(log n) | Stochastic shocks with runtime-determined points |

All coordinate finders produce generalized coordinates that `map_coordinates` uses for
linear interpolation (inside the grid) or linear extrapolation (outside the grid). The
two-step design keeps the interpolation logic generic while letting each grid type
optimize its coordinate mapping.