# Explanations of Dispatchers

In this notebook, we showcase how the internal functions `vmap_1d`, `productmap` and
`spacemap` are used by `lcm`.

In [None]:
import jax.numpy as jnp
import pytest
from jax import vmap

from lcm.dispatchers import productmap, spacemap, vmap_1d

# `vmap_1d`

Let's start by vectorizing the function `f` over its argument `a` using JAX's `vmap` function.

In [None]:
def f(a, b):
    return a + b

In [None]:
a = jnp.linspace(0, 1, 5)


# in_axes = (0, None) means that the first argument is mapped over, and the second
# argument is kept constant
f_vmapped = vmap(f, in_axes=(0, None))

f_vmapped(a, 1)

Array([1.  , 1.25, 1.5 , 1.75, 2.  ], dtype=float32)

However, note that we can call `f` with keyword arguments, but not `f_vmapped`:

In [None]:
f(a=a, b=1)

Array([1.  , 1.25, 1.5 , 1.75, 2.  ], dtype=float32)

In [None]:
with pytest.raises(
    ValueError,
    match="vmap in_axes must be an int, None, or a tuple of entries corresponding to",
):
    f_vmapped(a=a, b=1)

To make it safe and convenient to call vmapped functions with keyword arguments, `lcm` provides the function `vmap_1d`.

In [None]:
f_vmapped_1d = vmap_1d(f, variables=["a"])

In [None]:
f_vmapped_1d(a=a, b=1)

Array([1.  , 1.25, 1.5 , 1.75, 2.  ], dtype=float32)
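To make the idea concrete, here is a minimal sketch of how a keyword-friendly wrapper around `vmap` could be built. It is not `lcm`'s actual implementation of `vmap_1d`: the name `vmap_1d_sketch` is hypothetical, and the sketch assumes that every argument is passed by keyword.

```python
import inspect

import jax.numpy as jnp
from jax import vmap


def vmap_1d_sketch(func, variables):
    """Hypothetical stand-in for `vmap_1d`; not lcm's implementation."""
    # Positional order of the arguments, read from the function signature.
    parameters = list(inspect.signature(func).parameters)
    # Map over axis 0 of the listed variables; hold all other arguments fixed.
    in_axes = tuple(0 if name in variables else None for name in parameters)
    vmapped = vmap(func, in_axes=in_axes)

    def wrapper(**kwargs):
        # Reorder the keyword arguments into positional ones for vmap.
        return vmapped(*(kwargs[name] for name in parameters))

    return wrapper


def f(a, b):
    return a + b


a = jnp.linspace(0, 1, 5)
# The keyword order at the call site does not matter.
out = vmap_1d_sketch(f, variables=["a"])(b=1, a=a)
```

The wrapper translates keyword arguments into the positional form that `in_axes` expects, which is the step that fails when `f_vmapped` is called with keywords directly.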

# `productmap`

Let's vectorize the function `g` over a Cartesian product of its variables.
For this, `lcm` provides the `productmap` function.

In [None]:
def g(a, b, c, d):
    return a + b + c + d

In [None]:
a = jnp.arange(2)
b = jnp.arange(3)
c = jnp.arange(4)
d = -1

In [None]:
g_mapped = productmap(g, variables=["a", "b", "c"])

In [None]:
res = g_mapped(a=a, b=b, c=c, d=d)
res

Array([[[-1,  0,  1,  2],
        [ 0,  1,  2,  3],
        [ 1,  2,  3,  4]],

       [[ 0,  1,  2,  3],
        [ 1,  2,  3,  4],
        [ 2,  3,  4,  5]]], dtype=int32)

In [None]:
res.shape

(2, 3, 4)
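One way to read this result is as the outcome of nesting one `vmap` per mapped variable. The sketch below reproduces the same grid with plain JAX; it illustrates the concept and makes no claim about how `productmap` is implemented inside `lcm`. The names `g_c`, `g_bc` and `g_abc` are introduced only for this illustration.

```python
import jax.numpy as jnp
from jax import vmap


def g(a, b, c, d):
    return a + b + c + d


a = jnp.arange(2)
b = jnp.arange(3)
c = jnp.arange(4)
d = -1

# Vectorize the innermost variable first, so the output axes are ordered (a, b, c).
g_c = vmap(g, in_axes=(None, None, 0, None))
g_bc = vmap(g_c, in_axes=(None, 0, None, None))
g_abc = vmap(g_bc, in_axes=(0, None, None, None))

nested = g_abc(a, b, c, d)

# For this simple sum, explicit broadcasting yields the same grid.
broadcast = a[:, None, None] + b[None, :, None] + c[None, None, :] + d
```

Each additional `vmap` adds one leading axis, so the variable that is wrapped last ends up on the first axis of the output.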

# `spacemap`

The `spacemap` function combines `productmap` and `vmap_1d` in a way that is often
needed in `lcm`.

A variable whose valid values depend on the value of another variable in the state-choice space is termed a _sparse_ variable; all other variables are _dense_. To dispatch a function across an entire state-choice space, we must vectorize over both dense and sparse variables. Since, by definition, all values of dense variables are valid, we can simply perform a `productmap` over the Cartesian grid of their values. The valid combinations of sparse variables are stored as a collection of 1D arrays (see below for an example). Over these, we can vectorize jointly with a call to `vmap_1d`.

Consider a simplified version of our deterministic test model. Curly brackets {} denote discrete variables; square brackets [] represent continuous variables.

- **Choice variables:**

  - _retirement_ $\in \{0, 1\}$

  - _consumption_ $\in [1, 2]$

- **State variables:**

  - _lagged_retirement_ $\in \{0, 1\}$

  - _wealth_ $\in [1, 2, 3, 4]$

- **Filter:**
  - Absorbing retirement filter: If _lagged_retirement_ is 1, then the choice
    _retirement_ can never be 0.

In [None]:
from lcm import DiscreteGrid, LinspaceGrid, Model


def utility(consumption, retirement, lagged_retirement, wealth):
    working = 1 - retirement
    retirement_habit = lagged_retirement * wealth
    return jnp.log(consumption) - 0.5 * working + retirement_habit


def absorbing_retirement_filter(retirement, lagged_retirement):
    return jnp.logical_or(retirement == 1, lagged_retirement == 0)


model = Model(
    functions={
        "utility": utility,
        "next_lagged_retirement": lambda retirement: retirement,
        "next_wealth": lambda wealth, consumption: wealth - consumption,
        "absorbing_retirement_filter": absorbing_retirement_filter,
    },
    n_periods=1,
    choices={
        "retirement": DiscreteGrid([0, 1]),
        "consumption": LinspaceGrid(start=1, stop=2, n_points=2),
    },
    states={
        "lagged_retirement": DiscreteGrid([0, 1]),
        "wealth": LinspaceGrid(start=1, stop=4, n_points=4),
    },
)

In [None]:
from lcm.process_model import process_model
from lcm.state_space import create_state_choice_space

processed_model = process_model(model)

sc_space, space_info, state_indexer, segments = create_state_choice_space(
    processed_model,
    period=2,
    is_last_period=False,
    jit_filter=False,
)

Now, the state-choice space includes all sparse and dense states and choices, except for the dense continuous choices, as these are managed differently in `lcm`.

Therefore, we anticipate the state-choice space to encompass the dense state variable _wealth_ and a representation of the sparse combination of _retirement_ and _lagged_retirement_.

In [None]:
sc_space.dense_vars

{'wealth': Array([1., 2., 3., 4.], dtype=float32)}

In [None]:
sc_space.sparse_vars

{'lagged_retirement': Array([0, 0, 1], dtype=int32),
 'retirement': Array([0, 1, 1], dtype=int32)}

In [None]:
import pandas as pd

pd.DataFrame(sc_space.sparse_vars)

|   | lagged_retirement | retirement |
|---|-------------------|------------|
| 0 | 0                 | 0          |
| 1 | 0                 | 1          |
| 2 | 1                 | 1          |


Notice that for the dense variables, the state-choice space contains the whole grid of
possible values. For the sparse variables, however, the state-choice space contains
one-dimensional arrays that can be thought of as columns in a dataframe, where each
row of that dataframe represents a valid combination.

Initially, we mentioned that combinations of _lagged_retirement_ being 1 and _retirement_ being 0 are disallowed. This specific combination is absent from the dataframe.

---
### Remark on memory usage and computational efficiency

**Dense variables**

- Require a 1D array with grid values for each variable (low memory usage)

- Apply function on product (high computational load)

- Store results (high memory usage)

$\Rightarrow$ Computational load and memory usage depend on product of dimensions of
  dense variables


**Sparse variables**

- Need to store one row for each valid state-choice combination (high memory usage)

- Apply function along first axis (low computational load, unless many rows)

- Store results (lower memory usage, unless many rows)

$\Rightarrow$ Computational load and memory usage depend on number of valid state-choice
  combinations

---
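As a rough illustration of this trade-off, the following sketch counts function evaluations for the example model under both treatments. It uses plain Python only, restates the filter with Python's `or`, and introduces the names `n_dense_only`, `valid_rows` and `n_mixed` purely for this illustration.

```python
from itertools import product

retirement = [0, 1]
lagged_retirement = [0, 1]
wealth = [1.0, 2.0, 3.0, 4.0]


def absorbing_retirement_filter(retirement, lagged_retirement):
    return retirement == 1 or lagged_retirement == 0


# Treating every variable as dense: evaluate on the full Cartesian product.
n_dense_only = len(retirement) * len(lagged_retirement) * len(wealth)

# Treating the filtered pair as sparse: keep only the valid rows.
valid_rows = [
    (lag, ret)
    for lag, ret in product(lagged_retirement, retirement)
    if absorbing_retirement_filter(ret, lag)
]
n_mixed = len(valid_rows) * len(wealth)

print(n_dense_only, n_mixed)  # 16 12
```

The saving is small here because the filter removes only one of four combinations; it grows with the share of combinations that a filter rules out.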

It is also worth noting the connection between the sparse variable representation
and the `segments`. 

In [None]:
segments

{'segment_ids': Array([0, 0, 1], dtype=int32), 'num_segments': 2}

These choice segments partition the rows of the above dataframe into groups; within
each group, a choice has to be made among the rows.

In our example, the first choice segment consists of the first two rows: if
_lagged_retirement_ is 0, the choice of _retirement_ can be either 0 or 1. The second
choice segment contains only the third row: if _lagged_retirement_ is 1, the single
valid choice is _retirement_ equal to 1.
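A plausible use of such segment information is to reduce over the choices within each segment, for example by taking a maximum. The sketch below does this with `jax.ops.segment_max`. The per-row `values` are made up for illustration, and the notebook does not show whether `lcm` performs the reduction in exactly this way.

```python
import jax
import jax.numpy as jnp

segment_ids = jnp.array([0, 0, 1])
num_segments = 2

# Hypothetical values, one per row of the sparse dataframe.
values = jnp.array([-0.5, 0.0, 1.0])

# Maximum over the rows that share a segment ID.
best = jax.ops.segment_max(values, segment_ids, num_segments=num_segments)
```

The result has one entry per segment: the first is the maximum over the first two rows, the second is the value of the third row.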

Now, we can map a function over the entire state-choice space using the `spacemap`
function.

In [None]:
spacemapped = spacemap(
    func=utility,
    dense_vars=list(sc_space.dense_vars),
    sparse_vars=list(sc_space.sparse_vars),
    put_dense_first=False,
)

In [None]:
sc_space.dense_vars

{'wealth': Array([1., 2., 3., 4.], dtype=float32)}

In [None]:
sc_space.sparse_vars

{'lagged_retirement': Array([0, 0, 1], dtype=int32),
 'retirement': Array([0, 1, 1], dtype=int32)}

In [None]:
res = spacemapped(
    **sc_space.dense_vars,
    **sc_space.sparse_vars,
    consumption=1,
)
res

Array([[-0.5, -0.5, -0.5, -0.5],
       [ 0. ,  0. ,  0. ,  0. ],
       [ 1. ,  2. ,  3. ,  4. ]], dtype=float32)

In [None]:
res.shape

(3, 4)

Let's try to reproduce this result by looping over the grids and calling `utility` directly:

In [None]:
_res = jnp.empty((3, 4))

# loop over valid combinations of sparse variables (first axis)
for i, (lagged_retirement, retirement) in enumerate(
    zip(
        sc_space.sparse_vars["lagged_retirement"],
        sc_space.sparse_vars["retirement"],
        strict=False,
    ),
):
    # loop over product of dense variables
    for j, wealth in enumerate(sc_space.dense_vars["wealth"]):
        u = utility(
            wealth=wealth,
            retirement=retirement,
            lagged_retirement=lagged_retirement,
            consumption=1,
        )
        _res = _res.at[i, j].set(u)  # JAX arrays are immutable

_res

Array([[-0.5, -0.5, -0.5, -0.5],
       [ 0. ,  0. ,  0. ,  0. ],
       [ 1. ,  2. ,  3. ,  4. ]], dtype=float32)

If `put_dense_first` were True, the order of the loops would need to be switched, leading to an output shape of (4, 3).

---

### Explanation of Results

The outputs align with the utility function: The rows represent sparse combinations of _lagged_retirement_ and _retirement_, while the columns represent values of _wealth_. Counting columns from 1, column $j$ corresponds to _wealth_ equal to $j$.

Consider the first row, corresponding to _lagged_retirement_ being 0 and _retirement_ also being 0. In this scenario, the agent is working, which incurs a utility cost of 0.5. Because the agent was not retired in the previous period, there is no utility from a retirement habit, resulting in a utility of $\log(1) - 0.5 = -0.5$ for all _wealth_ values.

The second row corresponds to _lagged_retirement_ being 0 and _retirement_ being 1. Here, the agent is retired, thus avoiding work-related costs. Being newly retired, the agent receives no utility through a retirement habit in this model, leading to a utility of $\log(1) = 0$ across all _wealth_ values.

The final row represents _lagged_retirement_ being 1 and _retirement_ also being 1. In this case, the agent, already retired, incurs no work-related costs. Additionally, having been retired for one period, the agent gains utility from a retirement habit that increases linearly with wealth, making the utility $\log(1) + \text{wealth} = \text{wealth}$.
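The three cases can also be checked without any dispatcher by broadcasting the sparse rows against the wealth grid. This is a minimal sketch that hard-codes the arrays shown above; the name `check` is introduced only for this illustration.

```python
import jax.numpy as jnp


def utility(consumption, retirement, lagged_retirement, wealth):
    working = 1 - retirement
    retirement_habit = lagged_retirement * wealth
    return jnp.log(consumption) - 0.5 * working + retirement_habit


lagged_retirement = jnp.array([0, 0, 1])
retirement = jnp.array([0, 1, 1])
wealth = jnp.array([1.0, 2.0, 3.0, 4.0])

# Rows index the sparse combinations, columns index the wealth grid.
check = utility(
    consumption=1.0,
    retirement=retirement[:, None],
    lagged_retirement=lagged_retirement[:, None],
    wealth=wealth[None, :],
)
```

The array `check` has shape (3, 4), with rows equal to -0.5, 0 and the wealth grid, respectively.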

---

The variable _consumption_ belongs to the special class of `continuous` and `dense` choice variables. When computing the maximum of the value function over the agents' choices, we solve the continuous problem for each combination of state and discrete/sparse choices. Therefore, the vectorization over continuous-dense choices is performed independently from the vectorization over the rest of the state-choice space.

To vectorize over _consumption_, we must use an additional `productmap`.

In [None]:
mapped = productmap(spacemapped, variables=["consumption"])

In [None]:
res = mapped(
    **sc_space.dense_vars,
    **sc_space.sparse_vars,
    consumption=jnp.linspace(1, 400, 2),
)
res

Array([[[-0.5      , -0.5      , -0.5      , -0.5      ],
        [ 0.       ,  0.       ,  0.       ,  0.       ],
        [ 1.       ,  2.       ,  3.       ,  4.       ]],

       [[ 5.4914646,  5.4914646,  5.4914646,  5.4914646],
        [ 5.9914646,  5.9914646,  5.9914646,  5.9914646],
        [ 6.9914646,  7.9914646,  8.991465 ,  9.991465 ]]], dtype=float32)

In [None]:
res.shape

(2, 3, 4)
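With _consumption_ on the leading axis, maximizing over the continuous choice amounts to a reduction over that axis, followed by a reduction over the discrete choices within each segment. The sketch below illustrates this order of operations on the utility values alone. It rebuilds the array via broadcasting instead of calling `lcm`, and it is not meant to reproduce how `lcm` solves the model.

```python
import jax
import jax.numpy as jnp


def utility(consumption, retirement, lagged_retirement, wealth):
    working = 1 - retirement
    retirement_habit = lagged_retirement * wealth
    return jnp.log(consumption) - 0.5 * working + retirement_habit


consumption = jnp.array([1.0, 400.0])
lagged_retirement = jnp.array([0, 0, 1])
retirement = jnp.array([0, 1, 1])
wealth = jnp.array([1.0, 2.0, 3.0, 4.0])
segment_ids = jnp.array([0, 0, 1])

# Axes: (consumption, sparse combination, wealth).
values = utility(
    consumption=consumption[:, None, None],
    retirement=retirement[None, :, None],
    lagged_retirement=lagged_retirement[None, :, None],
    wealth=wealth[None, None, :],
)

# Step 1: maximize over the continuous choice (axis 0).
best_over_consumption = values.max(axis=0)

# Step 2: maximize over the discrete choices within each segment.
best_per_state = jax.ops.segment_max(
    best_over_consumption, segment_ids, num_segments=2
)
```

Here `best_over_consumption` has shape (3, 4) and `best_per_state` has shape (2, 4), with one row per value of _lagged_retirement_.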