# Helpers

> Collection of helper functions

In [None]:
#| default_exp Utils/Helpers

In [None]:
#| export
import jax
import numpy as np
import jax.numpy as jnp
from jax import jit

In [None]:
#| export
def make_variable_vector(variable,  # can be iterable or float or int
                         length:int  # length of the vector
                        ):  # vector
    """Turn a `variable` into a vector or check that `length` is consistent.

    Scalars -- Python numbers, NumPy scalars and 0-d arrays -- are repeated
    `length` times. Iterables with at least one dimension are converted with
    `jnp.array` after checking that they hold exactly `length` entries;
    a length mismatch raises `ValueError`.
    """
    # 0-d numpy/jax arrays define `__iter__` but cannot be passed to `len()`,
    # so they have to be treated as scalars. Objects without an `ndim`
    # attribute (lists, tuples, ...) are treated as one-dimensional here.
    is_vector = hasattr(variable, '__iter__') and getattr(variable, 'ndim', 1) > 0
    if not is_vector:
        return jnp.repeat(variable, length)
    if len(variable) != length:
        # explicit check instead of `assert`, which is removed under `python -O`
        raise ValueError(f'Wrong number given: expected {length}, got {len(variable)}')
    return jnp.array(variable)

For example, to give all 5 agents the same discount factor of 0.9, we can simply write:

In [None]:
# a scalar is repeated `length` times
make_variable_vector(0.9, 5)

Array([0.9, 0.9, 0.9, 0.9, 0.9], dtype=float32, weak_type=True)

In [None]:
#| export
@jit
def compute_stationarydistribution(Tkk:jnp.ndarray):  # Transition matrix
    """Compute stationary distribution for transition matrix `Tkk`.

    The eigen-decomposition of `Tkk.T` is computed (real parts only). Among
    the tolerances 1e-1, 1e-2, ..., 1e-15, the tightest one for which at
    least one eigenvalue lies within that distance of 1 is selected.

    Each eigenvector (column) whose eigenvalue passes the selected tolerance
    is processed in three steps:
    - it is normalized to sum to 1;
    - entries below the tolerance are set to 0;
    - it is renormalized.

    All columns belonging to other eigenvalues are filled with -10. The result
    therefore always has the shape of `Tkk`, which keeps the function
    jit-compatible. If no eigenvalue is within 0.1 of 1, every entry is -10.
    """
    eigvals, eigvecs = jnp.linalg.eig(Tkk.T)
    eigvals = eigvals.real
    eigvecs = eigvecs.real

    # one row per candidate tolerance, one column per eigenvalue
    tolerances = 0.1 ** jnp.arange(1, 16)
    hits = jnp.abs(eigvals - 1)[None, :] < tolerances[:, None]

    # index of the tightest tolerance that still matches some eigenvalue;
    # -1 (i.e. the last, tightest row) when nothing matches at all
    row_ids = jnp.arange(hits.shape[0])
    best = jnp.where(hits.any(axis=-1), row_ids, -1).max()
    selected = hits[best]
    tol = tolerances[best]

    # `selected` broadcasts over rows, so whole columns (eigenvectors) are
    # replaced by the sentinel when their eigenvalue was not selected
    sentinel = -42
    candidates = jnp.where(selected, eigvecs, sentinel)

    normalized = candidates / candidates.sum(axis=0, keepdims=True)
    pruned = jnp.where(normalized < tol, 0, normalized)
    pruned = pruned / pruned.sum(axis=0, keepdims=True)

    return jnp.where(candidates == sentinel, -10, pruned)

For example, let's create a random 4 × 4 transition matrix:

In [None]:
# unseeded: the values (and hence the results below) change on every run
Tkk = np.random.rand(4,4)

Each row of a transition matrix holds the probabilities of moving from one state to each state, so every row must sum to 1. We therefore normalize the rows:

In [None]:
# divide each row by its sum so that every row sums to 1
Tkk = Tkk / Tkk.sum(-1, keepdims=True) 

`compute_stationarydistribution` returns a 4 × 4 matrix. The column belonging to the eigenvalue 1 (here the first column) contains the stationary distribution, and all other entries are filled with the dummy value -10. This fixed output shape is what makes the function compatible with JAX just-in-time compilation.

In [None]:
# round to one decimal place for display
compute_stationarydistribution(Tkk).round(1)

Array([[  0.2, -10. , -10. , -10. ],
       [  0.2, -10. , -10. , -10. ],
       [  0.1, -10. , -10. , -10. ],
       [  0.5, -10. , -10. , -10. ]], dtype=float32)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()