# Explanation of the Function Representation

In this notebook, we showcase how the function representation of a set of values
pre-calculated on a grid is used in `lcm`, and how it works. Before we dive into the
details, let us consider what it does at a high level.

## Motivation

Consider the last period of a finite-horizon dynamic programming problem. The value
function array for this period corresponds to the solution of a classic static utility
maximization problem. That is, it contains the maximum of current-period utility in each
state, where the maximum is taken over actions.

If the state space is discretized into states $(x_1, \ldots, x_p)$, the value function
array (in the last period) $V^\text{arr}_T$ is an array with $p$ entries, where the $i$-th
entry $V^\text{arr}_{T, i} = V^\text{arr}_T(x_i)$ is the maximal utility the agent can
achieve in state $x_i$.

Consider now the Bellman equation for the second-to-last period:

$$
V_{T-1}(x) =
    \max_{a} \left\{u(x, a) + \beta \, \mathbb{E}_{T-1}\left[V_T(x') \mid x, a\right] \right\},
$$

where $a$ denotes the action, $x$ and $x'$ denote the current and next-period state,
respectively, and $\beta$ is the discount factor.

For most solution algorithms, we need to evaluate $V_T$ at points other than the
pre-calculated grid points in $V^\text{arr}_T$, since the next-period states $x'$ reached
from grid states generally do not lie on the grid.

Ideally, we would like a function in the code that we can use just like $V_T$ in the
equation above: a function that can be evaluated at any valid state $x$, regardless of
the discretization underlying $V^\text{arr}_T$. This is precisely what the function
representation provides.
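To see why off-grid evaluations arise, here is a minimal NumPy sketch of one backward-induction step for a simplified, hypothetical version of the example model introduced below. It assumes that the last-period value function equals $\log(x)$ on the wealth grid, that income is fixed at 1, and that there is no retirement choice. `np.interp` stands in for the function representation. Note that `np.interp` holds values constant outside the grid, which is not necessarily what `lcm` does.

```python
import numpy as np

# Hypothetical last-period value function array V_T_arr = log(wealth), known only
# on the wealth grid (not computed from the lcm model).
wealth_grid = np.linspace(1, 400, 10)
consumption_grid = np.linspace(1, 400, 20)
V_T_arr = np.log(wealth_grid)

beta, interest_rate, income = 0.95, 0.05, 1.0


def V_T(x):
    """Evaluate V_T at arbitrary wealth by linear interpolation on the grid."""
    return np.interp(x, wealth_grid, V_T_arr)


# Next-period wealth for every (wealth, consumption) pair, shape (10, 20)
x = wealth_grid[:, None]
c = consumption_grid[None, :]
x_next = (1 + interest_rate) * (x - c) + income

# Bellman step: maximize over feasible consumption levels (c <= x). Most x_next
# values are not grid points, so V_T has to be evaluated off the grid.
feasible = c <= x
objective = np.where(feasible, np.log(c) + beta * V_T(x_next), -np.inf)
V_T_minus_1_arr = objective.max(axis=1)
```

Without an interpolating `V_T`, the term `V_T(x_next)` could not be computed, because `x_next` is only a grid point by coincidence.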


### General Steps

To get a function representation of pre-calculated values on a grid (i.e., an array),
we need to take care of the following things:

1. The function will be called with named arguments. Hence, we need to know which
   argument name corresponds to which array dimension.

2. The function will be called with values (levels) of each variable (e.g., health
   taking on a value of 3). However, array elements are retrieved through indexing
   (e.g., an index of 1 may correspond to a value of 3 for health). Hence, we require a
   mapping from levels to indices.

3. Continuous variables will take on values that do not occur in the grid. This requires
   interpolation of the function values found on that grid.

Combining the above allows us to create a function representation of pre-calculated
values on a grid, which behaves like an analytical function.
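The three steps can be illustrated with a small, self-contained NumPy sketch. The grid, the labels, the values, and the function `represent` are hypothetical and chosen only for illustration. Health takes the labels 2 and 3, which deliberately differ from the array indices 0 and 1.

```python
import numpy as np

# Hypothetical pre-calculated values on a 2-D grid with axes (health, wealth)
axis_names = ("health", "wealth")          # step 1: argument name -> axis position
health_labels = np.array([2, 3])           # discrete levels, not indices
wealth_grid = np.array([0.0, 10.0, 20.0])  # continuous grid
values = np.array([[0.0, 1.0, 4.0],
                   [1.0, 2.0, 3.0]])       # shape (2, 3)


def represent(**kwargs):
    """Evaluate the grid values like a function of named arguments."""
    # Step 1: the order of axis_names tells us which axis each argument refers to
    health, wealth = (kwargs[name] for name in axis_names)
    # Step 2: translate the discrete level into an array index
    health_index = int(np.flatnonzero(health_labels == health)[0])
    # Step 3: linearly interpolate along the continuous axis
    return np.interp(wealth, wealth_grid, values[health_index])


print(represent(health=2, wealth=5.0))   # 0.5
print(represent(wealth=15.0, health=3))  # 2.5
```

Because arguments are looked up by name, the call order does not matter, which is exactly what step 1 is about.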
                                                        

### Example

As an example, consider a stripped-down version of the deterministic model from
Iskhakov et al. (2017), which removes the absorbing retirement constraint and the lagged
retirement state compared to the original model (this version can be found in the
`tests/test_models/deterministic.py` module). Here we also use a coarser grid to
showcase the behavior of the function representation.

In [None]:
from dataclasses import dataclass

import jax.numpy as jnp

from lcm import DiscreteGrid, LinspaceGrid, Model


@dataclass
class RetirementStatus:
    working: int = 0
    retired: int = 1


def utility(consumption, working, disutility_of_work):
    return jnp.log(consumption) - disutility_of_work * working


def labor_income(working, wage):
    return working * wage


def working(retirement):
    return 1 - retirement


def wage(age):
    return 1 + 0.1 * age


def age(_period):
    return _period + 18


def next_wealth(wealth, consumption, labor_income, interest_rate):
    return (1 + interest_rate) * (wealth - consumption) + labor_income


def consumption_constraint(consumption, wealth):
    return consumption <= wealth


model = Model(
    description=(
        "Starts from Iskhakov et al. (2017), removes the absorbing retirement "
        "constraint and the lagged_retirement state, and adds a wage function that "
        "depends on age."
    ),
    n_periods=2,
    functions={
        "utility": utility,
        "next_wealth": next_wealth,
        "consumption_constraint": consumption_constraint,
        "labor_income": labor_income,
        "working": working,
        "wage": wage,
        "age": age,
    },
    actions={
        "retirement": DiscreteGrid(RetirementStatus),
        "consumption": LinspaceGrid(start=1, stop=400, n_points=20),
    },
    states={
        "wealth": LinspaceGrid(start=1, stop=400, n_points=10),
    },
)


params = {
    "beta": 0.95,
    "utility": {"disutility_of_work": 0.25},
    "next_wealth": {
        "interest_rate": 0.05,
    },
}

To generate the correct JAX grids from this model specification, we process the
model manually. This is normally done under the hood in `lcm`.

In [None]:
from lcm.input_processing import process_model

processed_model = process_model(model)



#### Last period value function array

To compute the value function array in the last period, we first generate a function
that returns the utility and the feasibility of a combination of state and action
variables (given the parameters), and then take the maximum utility over all feasible
actions.

In [None]:
from lcm.model_functions import get_current_u_and_f

u_and_f = get_current_u_and_f(processed_model)

u_and_f.__signature__

<Signature (consumption, params, retirement, wealth)>

We can then evaluate `u_and_f` on scalar values. Note that in the example below, the
action is not feasible, since the consumption constraint forbids consumption levels
larger than wealth.

In [None]:
_u, _f = u_and_f(
    consumption=100,
    retirement=0,
    wealth=50,
    params=params,
)

print(f"Utility: {_u}, feasible: {_f}")

Utility: 4.355170249938965, feasible: False
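As a sanity check, the printed numbers can be reproduced by hand in plain Python. With `retirement=0` the agent works, so utility is $\log(100) - 0.25 \cdot 1$. The small difference in the trailing digits relative to the output above is expected, because JAX computes in single precision (float32) by default.

```python
import math

consumption, wealth, retirement, disutility_of_work = 100, 50, 0, 0.25
working = 1 - retirement  # retirement=0 means the agent works
utility = math.log(consumption) - disutility_of_work * working
feasible = consumption <= wealth  # consumption constraint

print(f"Utility: {utility}, feasible: {feasible}")
```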


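To evaluate `u_and_f` on the whole state-action space, we use `productmap` from
`lcm.dispatchers`, which evaluates the function on the Cartesian product of the grids
passed for the listed variables. The output has one axis per variable, in the order
given by `variables`.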
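Conceptually, evaluating a function on the Cartesian product of grids yields one output axis per variable. The NumPy sketch below mimics this with `np.ix_` and broadcasting. `product_evaluate`, `utility`, and `feasibility` are hypothetical helpers that reuse the grids of the example model, and the sketch is not necessarily how `productmap` is implemented.

```python
import numpy as np


def product_evaluate(func, grids, variables):
    """Evaluate func on the Cartesian product of the grids named in `variables`.

    The result has one axis per variable, in the order given by `variables`.
    """
    open_mesh = np.ix_(*(grids[name] for name in variables))
    shape = tuple(len(grids[name]) for name in variables)
    # Broadcasting the open mesh evaluates func on every grid combination
    return np.broadcast_to(func(**dict(zip(variables, open_mesh))), shape)


grids = {
    "wealth": np.linspace(1, 400, 10),
    "consumption": np.linspace(1, 400, 20),
    "retirement": np.array([0, 1]),
}


def utility(wealth, consumption, retirement):
    return np.log(consumption) - 0.25 * (1 - retirement)


def feasibility(wealth, consumption, retirement):
    return consumption <= wealth


variables = ["wealth", "consumption", "retirement"]
u = product_evaluate(utility, grids, variables)
f = product_evaluate(feasibility, grids, variables)
print(u.shape, f.shape)  # (10, 20, 2) (10, 20, 2)
```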

In [None]:
processed_model.grids.keys()

dict_keys(['retirement', 'wealth', 'consumption'])

In [None]:
from lcm.dispatchers import productmap

u_and_f_mapped = productmap(u_and_f, variables=["wealth", "consumption", "retirement"])

u, f = u_and_f_mapped(**processed_model.grids, params=params)

print(f"Length of (wealth, consumption, retirement) grids: {u.shape}")

Length of (wealth, consumption, retirement) grids: (10, 20, 2)


Now we can compute the value function array by taking the maximal utility over all
feasible actions (axes 1 and 2 correspond to the consumption and retirement dimensions,
respectively).
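The masked maximum is easiest to see on a toy array. In the NumPy sketch below, infeasible entries are ignored, and `initial=-np.inf` determines the result for states without any feasible action. NumPy actually requires `initial` whenever `where` is passed to `np.max`, since the maximum has no identity element. The arrays are made up for illustration.

```python
import numpy as np

# Toy utilities over (wealth, consumption, retirement), shape (2, 2, 2)
u = np.array([[[1.0, 2.0], [3.0, 4.0]],
              [[5.0, 6.0], [7.0, 8.0]]])
# Feasibility mask: in the first wealth state, only the first consumption level is allowed
f = np.array([[[True, True], [False, False]],
              [[True, True], [True, True]]])

V = np.max(u, axis=(1, 2), where=f, initial=-np.inf)
print(V)  # [2. 8.]

# If no action is feasible, the result falls back to `initial`
no_feasible = np.max(u, axis=(1, 2), where=np.zeros_like(f), initial=-np.inf)
print(no_feasible)  # [-inf -inf]
```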

In [None]:
V_arr = jnp.max(u, axis=(1, 2), where=f, initial=-jnp.inf)
V_arr.shape

(10,)

In [None]:
wealth_grid = processed_model.grids["wealth"]

In [None]:
import plotly.graph_objects as go
import plotly.io as pio

pio.renderers.default = "notebook_connected"

blue, orange = "#4C78A8", "#F58518"


fig = go.Figure(
    data=[
        go.Scatter(
            x=wealth_grid,
            y=V_arr,
            mode="markers",
            marker={"color": blue, "size": 10},
            name="Pre-calculated values",
        ),
    ],
)

fig.update_layout(
    xaxis_title="Wealth (x)",
    yaxis_title="V(x)",
    template="plotly_white",
)

fig.show()

#### Interpolation

What happens if we want to know the value of $V_T$ at 25 or 75, which are not grid
points? We need to perform some kind of interpolation. This is where the function
representation comes into play: it returns the pre-calculated value when evaluated at
a grid point, and a linearly interpolated value otherwise.
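For a single continuous state such as wealth, with neighboring grid points
$x_i \le x \le x_{i+1}$, linear interpolation on the value function array reads

$$
V_T(x) \approx V^\text{arr}_{T, i} + \frac{x - x_i}{x_{i+1} - x_i}
    \left(V^\text{arr}_{T, i+1} - V^\text{arr}_{T, i}\right).
$$

At a grid point $x = x_i$ the fraction is zero, so the pre-calculated value
$V^\text{arr}_{T, i}$ is returned exactly. Between grid points, the fraction lies in
$(0, 1)$ and weights the two neighboring values.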

To optimally utilize the structure of the grids when interpolating, the function
representation requires information on the state space.

In [None]:
from lcm.state_space import create_state_action_space

# the space info object contains information on the grid structure etc.
space_info = create_state_action_space(
    model=processed_model,
    is_last_period=True,
)[1]

#### Setting up the function representation

The first step is to generate a function that can interpolate on the value function
array. The resulting function can be called with scalar arguments (here this means we
can only pass scalar levels of wealth and no grids). It also requires the data on
which it interpolates as an argument. The name of this argument can be set using the
`name_of_values_on_grid` argument. Below we use `name_of_values_on_grid="V_arr"`, which
implies that the resulting function gets an additional argument `V_arr` that can be
used to pass in the pre-calculated value function array.
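The role of `name_of_values_on_grid` can be mimicked in plain Python. Because the argument name is only known at runtime, the generated function advertises it via `__signature__`, which `inspect.signature` then reports. The factory `make_scalar_representation` below is a hypothetical sketch (one continuous state called `wealth`, interpolation via `np.interp`), not `lcm`'s implementation. It also assumes that tools such as `dags` discover a function's inputs through `inspect.signature`.

```python
import inspect

import numpy as np


def make_scalar_representation(name_of_values_on_grid, grid):
    """Return a function of (<name_of_values_on_grid>, wealth) that interpolates."""

    def func(**kwargs):
        values = kwargs[name_of_values_on_grid]
        return np.interp(kwargs["wealth"], grid, values)

    # Advertise the runtime-chosen argument names to signature-inspecting tools
    func.__signature__ = inspect.Signature(
        [
            inspect.Parameter(name_of_values_on_grid, inspect.Parameter.KEYWORD_ONLY),
            inspect.Parameter("wealth", inspect.Parameter.KEYWORD_ONLY),
        ]
    )
    return func


grid = np.array([1.0, 2.0, 3.0])
f = make_scalar_representation("V_arr", grid)
print(inspect.signature(f))  # (*, V_arr, wealth)
print(f(V_arr=np.array([0.0, 2.0, 6.0]), wealth=2.5))  # 4.0
```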

In [None]:
from lcm.function_representation import get_function_representation

scalar_value_function = get_function_representation(
    space_info=space_info,
    name_of_values_on_grid="V_arr",
)
scalar_value_function.__signature__

<Signature (V_arr, wealth)>

We then apply the productmap decorator, which allows us to evaluate the function on a
grid of state variables (in this case, just wealth).

In [None]:
value_function = productmap(scalar_value_function, variables=["wealth"])

#### Visualizing the results

In addition to the pre-calculated values at the grid points, we now plot the values
obtained by evaluating the function representation both at the original grid points
and at additional points between them. We expect the values at the grid points to
coincide with the pre-calculated ones, and the values at the additional points to be
interpolated.

In [None]:
wealth_grid = processed_model.grids["wealth"]
wealth_points_new = jnp.array([10, 25, 75, 210, 300])

wealth_grid_concatenated = jnp.concatenate([wealth_grid, wealth_points_new])

V_via_func = value_function(wealth=wealth_grid_concatenated, V_arr=V_arr)

In [None]:
import plotly.graph_objects as go

fig = go.Figure(
    data=[
        go.Scatter(
            x=wealth_grid,
            y=V_arr,
            mode="markers",
            marker={"color": blue, "size": 10},
            name="Pre-calculated values",
        ),
        go.Scatter(
            x=wealth_grid_concatenated,
            y=V_via_func,
            mode="markers",
            marker={"color": orange, "size": 7},
            name="Evaluated Points",
        ),
    ],
)

fig.update_layout(
    xaxis_title="Wealth (x)",
    yaxis_title="V(x)",
    template="plotly_white",
)

fig.show()

If we now connect the pre-calculated values at the grid points with straight lines,
that is, perform a linear interpolation on the value function array, we see that the
values generated by the function representation lie on these lines.

In other words, the function representation can be thought of as an analytical function
that coincides with this piecewise linear interpolation.

In [None]:
fig = go.Figure(
    data=[
        go.Scatter(
            x=wealth_grid,
            y=V_arr,
            mode="markers",
            marker={"color": blue, "size": 10},
            name="Pre-calculated values",
        ),
        go.Scatter(
            x=wealth_grid,
            y=V_arr,
            mode="lines",
            line={"color": blue},
            name="Linear interpolation",
        ),
        go.Scatter(
            x=wealth_grid_concatenated,
            y=V_via_func,
            mode="markers",
            marker={"color": orange, "size": 7},
            name="Evaluated Points",
        ),
    ],
)

fig.update_layout(
    xaxis_title="Wealth (x)",
    yaxis_title="V(x)",
    template="plotly_white",
)

fig.show()

## Technical Details

In the following, we will discuss the building blocks that are used to implement the
function representation.

### Label Translator

The label translator maps the labels of dense discrete grids to their corresponding
indices in the grid. Currently, `lcm` works under the assumption that internal discrete
grids always correspond to their indices. That is, a grid like [2, 3] is not allowed and
would have to be represented as [0, 1] to be valid.

Once issue [#82](https://github.com/OpenSourceEconomics/lcm/issues/82) is tackled, `lcm`
will convert discrete grids into an internal grid that is directly usable as an index.
In both cases, the label translator is simply the identity function.
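If discrete labels did not coincide with indices (like the grid [2, 3] above), the translator would need to map labels to positions. The minimal NumPy sketch below does this with `np.searchsorted`. It assumes sorted labels and inputs that are valid labels, and `translate_health` is a hypothetical helper. For index-valued labels [0, 1], the same mapping reduces to the identity.

```python
import numpy as np

# Hypothetical discrete grid with labels that are not indices (currently not allowed)
health_labels = np.array([2, 3])


def translate_health(health):
    """Map health labels to their positions in the sorted array health_labels."""
    return np.searchsorted(health_labels, health)


print(translate_health(np.array([2, 3, 3])))  # [0 1 1]

# With index-valued labels [0, 1], the same mapping is the identity
index_labels = np.array([0, 1])
print(np.searchsorted(index_labels, np.array([0, 1, 1])))  # [0 1 1]
```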

In [None]:
from lcm.function_representation import _get_label_translator

translator = _get_label_translator(in_name="health")
translator.__signature__

<Signature (health)>

In [None]:
translator(health=3)

3

### Lookup Function

The lookup function emulates indexing into an array via named axes.

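> Note. These helper functions are important because we use `dags.concatenate_functions`
> to combine all auxiliary functions into the final function representation. `dags`
> connects functions by matching their argument names to the names of other functions,
> so each building block must be a function with suitably named arguments.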
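The idea behind a lookup function can be sketched with a small factory that turns positional indexing into keyword arguments. `make_lookup` is hypothetical and only illustrates the concept. It mirrors the `array_name` and `axis_names` arguments used below, but it is not `lcm`'s `_get_lookup_function`.

```python
import numpy as np


def make_lookup(array_name, axis_names):
    """Return a function that indexes `array_name` with named index arguments."""

    def lookup(**kwargs):
        array = kwargs[array_name]
        # Order the named indices according to the axes of the array
        index = tuple(kwargs[name] for name in axis_names)
        return array[index]

    return lookup


values = np.arange(12).reshape(3, 4)  # axes: (health_index, wealth_index)
lookup = make_lookup("values", ["health_index", "wealth_index"])
result = lookup(wealth_index=np.array([0, 3]), health_index=np.array([1, 2]), values=values)
print(result)  # [ 4 11]
```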

In [None]:
# We want a function that allows us to perform a lookup like this:
V_arr[jnp.array([0, 2, 5])]

Array([0.       , 4.4426513, 5.351858 ], dtype=float32)

In [None]:
from lcm.function_representation import _get_lookup_function

lookup = _get_lookup_function(array_name="V_arr", axis_names=["wealth_index"])
lookup.__signature__

<Signature (wealth_index, V_arr)>

In [None]:
lookup(wealth_index=jnp.array([0, 2, 5]), V_arr=V_arr)

Array([0.       , 4.4426513, 5.351858 ], dtype=float32)

### Coordinate Finder

For continuous grids (linearly and logarithmically spaced), the coordinate finder
returns the *general* index corresponding to the given value. As an example, consider a
linearly spaced grid [1, 2, 3]. The general coordinate value given the value 1.5 is, in
this case, 0.5, because 1.5 is exactly in the middle between 1 (index = 0) and 2 (index =
1). The output of the coordinate finder can then be used by
`jax.scipy.ndimage.map_coordinates` for the interpolation.
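For a linearly spaced grid, the general index follows from inverting the grid formula $x_i = \text{start} + i \cdot \text{step}$ with $\text{step} = (\text{stop} - \text{start}) / (n - 1)$. The NumPy sketch below implements this. `linspace_coordinate` and `logspace_coordinate` are hypothetical helpers, and the log-spaced variant assumes that the grid points are equally spaced in $\log(x)$.

```python
import numpy as np


def linspace_coordinate(value, start, stop, n_points):
    """General (fractional) index of `value` on np.linspace(start, stop, n_points)."""
    step = (stop - start) / (n_points - 1)
    return (value - start) / step


def logspace_coordinate(value, start, stop, n_points):
    """Same idea for a grid assumed to be equally spaced in log(value)."""
    return linspace_coordinate(np.log(value), np.log(start), np.log(stop), n_points)


print(linspace_coordinate(1.5, start=1, stop=3, n_points=3))  # 0.5
print(linspace_coordinate(np.array([1.0, 390.0]), start=1, stop=400, n_points=10))
```

The second call uses the wealth grid of the example model, where one step equals $399 / 9 \approx 44.33$.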

In [None]:
wealth_grid = LinspaceGrid(start=1, stop=400, n_points=10)

wealth_grid.to_jax()

Array([  1.      ,  45.333336,  89.66667 , 134.00002 , 178.33334 ,
       222.66667 , 267.00003 , 311.33334 , 355.6667  , 400.      ],      dtype=float32)

In [None]:
from lcm.function_representation import _get_coordinate_finder

wealth_coordinate_finder = _get_coordinate_finder(
    in_name="wealth",
    grid=wealth_grid,
)
wealth_coordinate_finder.__signature__

<Signature (wealth)>

To showcase the behavior of the coordinate finder, and how the *general* indices work,
consider the following wealth values:

- **1:** This value is the first value in the original grid, therefore the general index
  must be 0.
- **(1 + 45.333336) / 2:** This value is exactly in the middle between the first and
  second value in the grid, therefore the general index is (0 + 1) / 2 = 0.5.
- **390:** This value is close to the last grid point (400), so the general index will
  be close to, but below, 9.

In [None]:
wealth_values = jnp.array([1, (1 + 45.333336) / 2, 390])

wealth_coordinate_finder(wealth=wealth_values)

Array([0.        , 0.50000006, 8.774436  ], dtype=float32)

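### Interpolator

The interpolator takes the general indices returned by the coordinate finder, together
with the pre-calculated values, and returns the corresponding linearly interpolated
values.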

In [None]:
from lcm.function_representation import _get_interpolator

value_function_interpolator = _get_interpolator(
    name_of_values_on_grid="V_arr",
    axis_names=["wealth_index"],
)

value_function_interpolator.__signature__

<Signature (V_arr, wealth_index)>

In [None]:
wealth_indices = wealth_coordinate_finder(wealth=wealth_values)

V_interpolations = value_function_interpolator(
    wealth_index=wealth_indices,
    V_arr=V_arr,
)
V_interpolations

Array([0.       , 1.8806003, 5.952807 ], dtype=float32)
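The linear interpolation at fractional indices that the interpolator relies on can be written out explicitly. The NumPy sketch below uses a hypothetical helper, `interpolate_at_coordinates`, that handles a single continuous axis with at least two grid points. It splits each coordinate into an integer base index and a weight, and blends the two neighboring values. Outside the grid it extrapolates linearly from the nearest edge segment. That is a choice made for this sketch, not a statement about `lcm`.

```python
import numpy as np


def interpolate_at_coordinates(values, coordinates):
    """Linear interpolation of a 1-D array at fractional indices."""
    coordinates = np.asarray(coordinates, dtype=float)
    # Base index of the segment containing each coordinate, clipped to a valid segment
    lower = np.clip(np.floor(coordinates).astype(int), 0, len(values) - 2)
    weight = coordinates - lower
    return (1 - weight) * values[lower] + weight * values[lower + 1]


values = np.array([0.0, 4.0, 6.0])
result = interpolate_at_coordinates(values, [0.0, 0.5, 1.25, 2.0])
print(result)  # values 0.0, 2.0, 4.5, 6.0
```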

In [None]:
fig = go.Figure(
    data=[
        go.Scatter(
            x=wealth_grid.to_jax(),
            y=V_arr,
            mode="markers",
            marker={"color": blue, "size": 10},
            name="Pre-calculated values",
        ),
        go.Scatter(
            x=wealth_values,
            y=V_interpolations,
            mode="markers",
            marker={"color": orange, "size": 7},
            name="Evaluated Points",
        ),
    ],
)

fig.update_layout(
    xaxis_title="Wealth (x)",
    yaxis_title="V(x)",
    template="plotly_white",
)

fig.show()

## Re-implementation of the function representation for the example model

Next, we outline and implement by hand the steps needed to build the function
representation for the example model specified above. This is intended to help with
understanding how the internals of the function representation work.

### The Steps

We start by listing the required steps. The general idea is to generate functions for
the array lookup, interpolation, and so on, with the correct signature signaling their
dependence structure. These can then be combined into a single function that performs
all necessary steps using `dags.concatenate_functions`.
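The mechanism that makes this work is dependency resolution by argument names. The sketch below is a hypothetical, heavily simplified stand-in for `dags.concatenate_functions`, with no cycle detection and no generated signature. It is applied to toy versions of the functions used in the implementation that follows. The grid and the values `V_arr = log(wealth)` are made up for illustration.

```python
import inspect

import numpy as np


def concatenate(functions, target):
    """Combine `functions` into a single function that returns `target`.

    An argument that is the name of another function is computed by calling that
    function; every other argument must be passed to the combined function.
    """

    def evaluate(name, inputs, cache):
        if name not in functions:
            return inputs[name]
        if name not in cache:
            params = inspect.signature(functions[name]).parameters
            kwargs = {param: evaluate(param, inputs, cache) for param in params}
            cache[name] = functions[name](**kwargs)
        return cache[name]

    def combined(**inputs):
        return evaluate(target, inputs, cache={})

    return combined


# Toy building blocks for a single continuous state on np.linspace(1, 400, 10)
def discrete_lookup(V_arr):
    return V_arr


def wealth_coordinate_finder(wealth):
    return (wealth - 1) / ((400 - 1) / (10 - 1))


def interpolator(__interpolation_data__, __wealth_coord__):
    # Linear interpolation in index space (valid inside the grid)
    index_grid = np.arange(len(__interpolation_data__))
    return np.interp(__wealth_coord__, index_grid, __interpolation_data__)


funcs = {
    "__interpolation_data__": discrete_lookup,
    "__wealth_coord__": wealth_coordinate_finder,
    "__fval__": interpolator,
}
value_function = concatenate(funcs, target="__fval__")

wealth_grid = np.linspace(1, 400, 10)
V_arr = np.log(wealth_grid)  # made-up values on the grid
print(value_function(V_arr=V_arr, wealth=np.array([25.0, 75.0])))
```

Only `V_arr` and `wealth` have to be supplied, because every other argument name matches one of the functions in `funcs`.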


1. Add functions to look up positions of discrete state variables given their labels

   In the above example model there are no discrete state variables, so we can skip this
   step. If there are discrete variables, the lookup functions will coincide with the
   identity function, as the variables themselves are indices.


1. Create the lookup function for the discrete part

   In this step, a function is generated that allows one to index into the
   pre-calculated value function array using the labels of the discrete
   state variables. In the above example model, there are no discrete state variables,
   so this function returns the value function array untouched.


1. Create interpolation functions for the continuous state variables

   If the model contains (dense) continuous state variables, interpolation functions
   are required.

   1. Add a coordinate finder for each continuous state variable

      This allows us to map values of the continuous variable into their corresponding
      (general) indices, as required by the interpolator.

   1. Add an interpolator

      The interpolator uses the general indices from the previous step to interpolate
      the pre-calculated function values at the surrounding grid points.


1. Throw everything into dags

   The last step is to throw everything into `dags.concatenate_functions`. The resulting
   function is a value function that behaves like an analytical function.

### The Implementation

In [None]:
# Create the functions dictionary that will be passed to `dags.concatenate_functions`
funcs = {}


# Step 1: Since there are no discrete state variables the lookup info is empty, and we
# do not require any label translator
space_info.lookup_info

{}

In [None]:
# Step 2: Since there are no discrete state variables in the model, the discrete
# lookup coincides with the identity function. Since there are continuous state
# variables in the model, we must interpolate and the data that is returned here is
# used as interpolation data.


def discrete_lookup(V_arr):
    return V_arr


# If there were no interpolation, this function would have to be stored under the key
# '__fval__' in the funcs dictionary.
funcs["__interpolation_data__"] = discrete_lookup

In [None]:
# Step 3: (1) First we need to add a coordinate finder for the wealth state variable
from lcm.grid_helpers import get_linspace_coordinate


def wealth_coordinate_finder(wealth):
    return get_linspace_coordinate(
        wealth,
        start=1,
        stop=400,
        n_points=10,
    )


funcs["__wealth_coord__"] = wealth_coordinate_finder

In [None]:
# Step 3: (2) And second, we need to add an interpolator for the value function that
# uses the wealth coordinate finder as an input.

from lcm.ndimage import map_coordinates


def interpolator(__interpolation_data__, __wealth_coord__):
    coordinates = jnp.array([__wealth_coord__])
    return map_coordinates(
        __interpolation_data__,
        coordinates=coordinates,
    )


funcs["__fval__"] = interpolator

In [None]:
# Step 4: Throwing everything into dags
from dags import concatenate_functions

value_function = concatenate_functions(
    functions=funcs,
    targets="__fval__",
)
value_function.__signature__

<Signature (V_arr, wealth)>

In [None]:
V_evaluated = value_function(wealth=wealth_grid.to_jax(), V_arr=V_arr)

In [None]:
fig = go.Figure(
    data=[
        go.Scatter(
            x=wealth_grid.to_jax(),
            y=V_arr,
            mode="markers",
            marker={"color": blue, "size": 10},
            name="Pre-calculated values",
        ),
        go.Scatter(
            x=wealth_grid.to_jax(),
            y=V_evaluated,
            mode="markers",
            marker={"color": orange, "size": 7},
            name="Evaluated Points",
        ),
    ],
)

fig.update_layout(
    xaxis_title="Wealth (x)",
    yaxis_title="V(x)",
    template="plotly_white",
)

fig.show()