# Explanations of dispatchers

In this notebook we show how the internal functions `vmap_1d`, `productmap` and
`spacemap` work and how `lcm` uses them internally.

In [1]:
import jax.numpy as jnp
from jax import vmap
from lcm.dispatchers import productmap, spacemap, vmap_1d

# `vmap_1d`

Let's vectorize the function `f` over its argument `a`.

In [2]:
def f(a, b):
    return a + b

In [3]:
a = jnp.linspace(0, 1, 10)


# in_axes = (0, None) means that the first argument is mapped over, and the second
# argument is kept constant
f_vmapped = vmap(f, in_axes=(0, None))

f_vmapped(a, 1)

Array([1.       , 1.1111112, 1.2222222, 1.3333334, 1.4444444, 1.5555556,
       1.6666667, 1.7777778, 1.8888888, 2.       ], dtype=float32)

However, notice that we can call `f` with keyword arguments, but not `f_vmapped`:

In [4]:
f(a=a, b=1)

Array([1.       , 1.1111112, 1.2222222, 1.3333334, 1.4444444, 1.5555556,
       1.6666667, 1.7777778, 1.8888888, 2.       ], dtype=float32)

In [5]:
f_vmapped(a=a, b=1)

ValueError: vmap in_axes must be an int, None, or a tuple of entries corresponding to the positional arguments passed to the function, but got len(in_axes)=2, len(args)=0

To vectorize over a named argument and still be able to call the result with keyword
arguments, `lcm` provides the function `vmap_1d`.

In [6]:
f_vmapped_1d = vmap_1d(f, variables=["a"])

In [7]:
f_vmapped_1d(a=a, b=1)

Array([1.       , 1.1111112, 1.2222222, 1.3333334, 1.4444444, 1.5555556,
       1.6666667, 1.7777778, 1.8888888, 2.       ], dtype=float32)
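The mechanics can be sketched in a few lines. This is a minimal sketch, not `lcm`'s implementation: the name `vmap_1d_sketch` is hypothetical, and we assume that `vmap_1d` maps all listed variables jointly along their leading axis. The wrapper uses `inspect.signature` to bind keyword arguments to their positions before calling `jax.vmap`, which only maps positional arguments.

```python
import inspect

import jax.numpy as jnp
from jax import vmap


def vmap_1d_sketch(func, variables):
    """Hypothetical stand-in for `vmap_1d`: map `variables` jointly over axis 0."""
    signature = inspect.signature(func)
    # 0 for every argument we map over, None for arguments held constant.
    in_axes = tuple(0 if name in variables else None for name in signature.parameters)
    vmapped = vmap(func, in_axes=in_axes)

    def wrapper(*args, **kwargs):
        # Bind keyword arguments to their positions; `vmap` only sees positional args.
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        return vmapped(*bound.args)

    return wrapper


def f(a, b):
    return a + b


a = jnp.linspace(0, 1, 10)
f_vmapped_1d_sketch = vmap_1d_sketch(f, variables=["a"])
f_vmapped_1d_sketch(a=a, b=1)  # same values as `f_vmapped_1d(a=a, b=1)` above
```

The sketch assumes `func` has only plain positional-or-keyword parameters, so that the number of bound positional arguments always matches `in_axes`.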

# `productmap`

Let's vectorize the function `g` over the Cartesian product of several of its arguments.
For this, `lcm` provides the `productmap` function. Each mapped argument contributes one
axis to the output.

In [8]:
def g(a, b, c, d):
    return a + b + c + d

In [9]:
a = jnp.arange(2)
b = jnp.arange(3)
c = jnp.arange(4)
d = -1

In [10]:
g_mapped = productmap(g, variables=["a", "b", "c"])

In [11]:
res = g_mapped(a=a, b=b, c=c, d=d)
res

Array([[[-1,  0,  1,  2],
        [ 0,  1,  2,  3],
        [ 1,  2,  3,  4]],

       [[ 0,  1,  2,  3],
        [ 1,  2,  3,  4],
        [ 2,  3,  4,  5]]], dtype=int32)

In [12]:
res.shape

(2, 3, 4)
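One way to obtain this layout is to nest one `vmap` per variable, so that each nesting level adds one output axis. The sketch below assumes that this is roughly what `productmap` does; `productmap_sketch` is a hypothetical name, not part of `lcm`. Wrapping the last variable first makes the first entry of `variables` the leading output axis, which matches the shape `(2, 3, 4)` above.

```python
import inspect

import jax.numpy as jnp
from jax import vmap


def productmap_sketch(func, variables):
    """Hypothetical stand-in for `productmap`: Cartesian product over `variables`."""
    signature = inspect.signature(func)
    mapped = func
    # Innermost vmap = last variable = last output axis.
    for var in reversed(variables):
        in_axes = tuple(0 if name == var else None for name in signature.parameters)
        mapped = vmap(mapped, in_axes=in_axes)

    def wrapper(**kwargs):
        # vmap only maps positional arguments, so bind keywords to positions first.
        return mapped(*signature.bind(**kwargs).args)

    return wrapper


def g(a, b, c, d):
    return a + b + c + d


a, b, c, d = jnp.arange(2), jnp.arange(3), jnp.arange(4), -1
res_sketch = productmap_sketch(g, variables=["a", "b", "c"])(a=a, b=b, c=c, d=d)

# The same array via explicit broadcasting: one axis per mapped variable.
res_broadcast = a[:, None, None] + b[None, :, None] + c[None, None, :] + d
```

The broadcasting line shows the result the nested `vmap`s are meant to reproduce: entry `[i, j, k]` equals `g(a[i], b[j], c[k], d)`.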

# `spacemap`

In `lcm` we often have to evaluate functions not only on a regular Cartesian product of
grids, but on a mix of such a product with an irregular ("awkward") set of combinations.
This happens because the state-choice space can be filtered for some variables, in which
case only specific combinations of choices and states belong to the state-choice space.

If the valid values of one variable depend on the value of some other variable, the
variable is called _sparse_.

The `spacemap` function provides a way to map a function over an entire state-choice
space. To illustrate this, let us first show an example of such a state-choice space.


Consider a simplified version of our deterministic test model, with

- **Choice variables:**

  - `retirement` in {0, 1}

  - `consumption` in [1, 2]

- **State variables:**

  - `lagged_retirement` in {0, 1}

  - `wealth` in [1, 2, 3, 4]

We also impose an absorbing retirement _filter_: if `lagged_retirement` is 1, `retirement`
can never be 0. This leads to an awkward state-choice space.

We use braces {} to denote discrete variables and square brackets [] to denote continuous
variables. For this example we choose a very coarse continuous grid, but in practice it
could be much finer.

In [13]:
def utility(consumption, retirement, lagged_retirement, wealth):  # noqa: ARG001
    working = 1 - retirement
    bequest_utility = wealth
    return jnp.log(consumption) - 0.5 * working + bequest_utility


def absorbing_retirement_filter(retirement, lagged_retirement):
    return jnp.logical_or(retirement == 1, lagged_retirement == 0)


MODEL_CONFIG = {
    "functions": {
        "utility": utility,
        "next_lagged_retirement": lambda retirement: retirement,
        "next_wealth": lambda wealth, consumption: wealth - consumption,
        "absorbing_retirement_filter": absorbing_retirement_filter,
    },
    "choices": {
        "retirement": {"options": [0, 1]},
        "consumption": {
            "grid_type": "linspace",
            "start": 1,
            "stop": 2,
            "n_points": 2,
        },
    },
    "states": {
        "lagged_retirement": {"options": [0, 1]},
        "wealth": {
            "grid_type": "linspace",
            "start": 1,
            "stop": 4,
            "n_points": 4,
        },
    },
    "n_periods": 1,
}

In [14]:
from lcm.process_model import process_model
from lcm.state_space import create_state_choice_space

model = process_model(user_model=MODEL_CONFIG)

sc_space, space_info, state_indexer, segments = create_state_choice_space(
    model,
    period=2,
    is_last_period=False,
    jit_filter=False,
)

The state-choice space contains all sparse and dense state and choice variables except
for the dense continuous choices (here `consumption`), since `lcm` handles these
differently.

So we expect the state-choice space to contain the dense state variable `wealth` and
some representation of the sparse combinations of `retirement` and `lagged_retirement`.

In [15]:
sc_space.dense_vars

{'wealth': Array([1., 2., 3., 4.], dtype=float32)}

In [16]:
sc_space.sparse_vars

{'lagged_retirement': Array([0, 0, 1], dtype=int32),
 'retirement': Array([0, 1, 1], dtype=int32)}

In [17]:
import pandas as pd

pd.DataFrame(sc_space.sparse_vars)

   lagged_retirement  retirement
0                  0           0
1                  0           1
2                  1           1


Notice that for the dense variables, the state-choice space contains the whole grid of
possible values. For the sparse variables, however, it contains one-dimensional arrays
that can be thought of as columns of a dataframe, where each row represents one valid
combination.

At the beginning, we said that we disallow `lagged_retirement` being 1 together with
`retirement` being 0, and this is exactly the combination that is missing from this
dataframe.
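To make this concrete, the sparse combinations can be reproduced by evaluating the filter on the full Cartesian product of the discrete grids and keeping only the rows where it returns `True`. This is a minimal sketch of the idea, not necessarily how `create_state_choice_space` builds the space. It also assumes the rows are ordered as in a row-major (`indexing="ij"`) product of `lagged_retirement` and `retirement`, which happens to match the output above.

```python
import jax.numpy as jnp


def absorbing_retirement_filter(retirement, lagged_retirement):
    return jnp.logical_or(retirement == 1, lagged_retirement == 0)


lagged_grid = jnp.array([0, 1])
retirement_grid = jnp.array([0, 1])

# Full Cartesian product of the discrete grids, flattened to rows.
lagged_all, retirement_all = jnp.meshgrid(lagged_grid, retirement_grid, indexing="ij")
lagged_all, retirement_all = lagged_all.ravel(), retirement_all.ravel()

# Keep only the rows that pass the filter (boolean masks work outside of `jit`).
keep = absorbing_retirement_filter(
    retirement=retirement_all, lagged_retirement=lagged_all
)
sparse_vars = {
    "lagged_retirement": lagged_all[keep],
    "retirement": retirement_all[keep],
}
```

Of the four rows in the full product, only (`lagged_retirement`=1, `retirement`=0) fails the filter, leaving the three rows shown in the dataframe.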

It is also worth noting the connection between the sparse variable representation
and the `segments`. 

In [18]:
segments

{'segment_ids': Array([0, 0, 1], dtype=int32), 'num_segments': 2}

These choice segments partition the rows of the dataframe above into groups within which
a choice has to be made.

In our example, the first choice segment consists of the first two rows: if
`lagged_retirement` is 0, `retirement` can be either 0 or 1. If `lagged_retirement` is 1,
however, the choice segment contains only the single choice `retirement` equal to 1.
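Assuming that segments group rows by the values of the sparse state variables (here only `lagged_retirement`), the segment ids can be sketched as the index of each row's state among the unique states, for example via `jnp.unique(..., return_inverse=True)`. A natural use of such ids, shown here for illustration only, is a segment-wise reduction like `jax.ops.segment_max`, which picks the best value among the choices available in each state. The utility values are those that `spacemap` produces further below for `consumption=1`.

```python
import jax.numpy as jnp
from jax.ops import segment_max

lagged_retirement = jnp.array([0, 0, 1])  # sparse state column from above

# Rows sharing the same state get the same segment id.
states, segment_ids = jnp.unique(lagged_retirement, return_inverse=True)
num_segments = len(states)

# Utility per sparse row (axis 0) and wealth grid point (axis 1).
utility_values = jnp.array(
    [
        [0.5, 1.5, 2.5, 3.5],
        [1.0, 2.0, 3.0, 4.0],
        [1.0, 2.0, 3.0, 4.0],
    ]
)

# Best value over the choices available in each state: shape (num_segments, 4).
best = segment_max(utility_values, segment_ids, num_segments=num_segments)
```

In the first segment, retiring (row 1) dominates working (row 0) at every wealth level; the second segment has only one row, so its maximum is that row.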

Now, we can map a function over the entire state-choice space using the `spacemap`
function.

In [19]:
spacemapped = spacemap(
    func=utility,
    dense_vars=list(sc_space.dense_vars),
    sparse_vars=list(sc_space.sparse_vars),
    put_dense_first=False,
)

In [20]:
res = spacemapped(
    **sc_space.dense_vars,
    **sc_space.sparse_vars,
    consumption=1,
)
res

Array([[0.5, 1.5, 2.5, 3.5],
       [1. , 2. , 3. , 4. ],
       [1. , 2. , 3. , 4. ]], dtype=float32)

In [21]:
res.shape

(3, 4)

Notice that we do not vectorize over `consumption` yet. Since `put_dense_first` is
`False`, the leading axis corresponds to the sparse variables, that is, to the rows of
the dataframe above. The remaining axes correspond to the dense variables; here that is
only `wealth`, whose grid has 4 points.
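This axis layout follows from composing two ideas: a Cartesian product over the dense variables, and a single joint `vmap` over the sparse variables, whose columns all have one entry per valid combination. The sketch below is a hypothetical stand-in (`spacemap_sketch` and `_vmap_over` are not part of `lcm`), assuming this composition is what `spacemap` does and that `put_dense_first=True` simply swaps the order. Whichever mapping is applied last contributes the leading axes.

```python
import inspect

import jax.numpy as jnp
from jax import vmap


def _vmap_over(func, parameter_names, mapped_names):
    # Map the leading axis of all `mapped_names` jointly; hold the rest constant.
    in_axes = tuple(0 if name in mapped_names else None for name in parameter_names)
    return vmap(func, in_axes=in_axes)


def spacemap_sketch(func, dense_vars, sparse_vars, put_dense_first=False):
    """Hypothetical stand-in for `spacemap`."""
    signature = inspect.signature(func)
    names = list(signature.parameters)

    def map_dense(f):
        # One vmap per dense variable: Cartesian product of the dense grids.
        for var in reversed(dense_vars):
            f = _vmap_over(f, names, {var})
        return f

    def map_sparse(f):
        # One joint vmap over all sparse columns: a single axis for the valid rows.
        return _vmap_over(f, names, set(sparse_vars))

    # The mapping applied last (outermost) provides the leading output axes.
    if put_dense_first:
        mapped = map_dense(map_sparse(func))
    else:
        mapped = map_sparse(map_dense(func))

    def wrapper(**kwargs):
        return mapped(*signature.bind(**kwargs).args)

    return wrapper


def utility(consumption, retirement, lagged_retirement, wealth):  # noqa: ARG001
    working = 1 - retirement
    bequest_utility = wealth
    return jnp.log(consumption) - 0.5 * working + bequest_utility


dense = {"wealth": jnp.array([1.0, 2.0, 3.0, 4.0])}
sparse = {
    "lagged_retirement": jnp.array([0, 0, 1]),
    "retirement": jnp.array([0, 1, 1]),
}
res_sketch = spacemap_sketch(utility, list(dense), list(sparse))(
    **dense, **sparse, consumption=1
)
```

Mapping the sparse columns jointly, rather than as a product, is what keeps the disallowed combination out of the result: the sparse axis has length 3, not 2 × 2.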

To also vectorize over `consumption`, we wrap `spacemapped` in an additional `productmap`.
The `consumption` axis then becomes the leading axis of the result.

In [22]:
mapped = productmap(spacemapped, variables=["consumption"])

In [23]:
res = mapped(
    **sc_space.dense_vars,
    **sc_space.sparse_vars,
    consumption=jnp.linspace(1, 400, 2),
)
res

Array([[[0.5      , 1.5      , 2.5      , 3.5      ],
        [1.       , 2.       , 3.       , 4.       ],
        [1.       , 2.       , 3.       , 4.       ]],

       [[6.4914646, 7.4914646, 8.491465 , 9.491465 ],
        [6.9914646, 7.9914646, 8.991465 , 9.991465 ],
        [6.9914646, 7.9914646, 8.991465 , 9.991465 ]]], dtype=float32)

In [24]:
res.shape

(2, 3, 4)
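As a sanity check on the axis order, the same array can be written with explicit broadcasting: `consumption` on axis 0, the sparse rows on axis 1 and `wealth` on axis 2. This only restates the `utility` formula from above with broadcasting; it is not how `lcm` evaluates the function.

```python
import jax.numpy as jnp

consumption = jnp.linspace(1, 400, 2)
retirement = jnp.array([0, 1, 1])  # sparse column, one entry per valid row
wealth = jnp.array([1.0, 2.0, 3.0, 4.0])  # dense grid

# utility = log(consumption) - 0.5 * working + wealth, with working = 1 - retirement
working = 1 - retirement
res_broadcast = (
    jnp.log(consumption)[:, None, None]
    - 0.5 * working[None, :, None]
    + wealth[None, None, :]
)
```

The second consumption slice is the first one shifted by log(400) ≈ 5.99, which is exactly the difference between the two blocks of the output above.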