In [1]:
import jax
import numpy as np

In [5]:
def func_a(x, y):
    return x + y

jax.vmap(func_a, in_axes=(0, None))(np.array([2]), 3)

Array([5], dtype=int32)

In [6]:
isinstance(np.array(2), np.ndarray)

True

In [10]:
# Check if array is a scalar
np.array(2).dtype == int

True

In [8]:
from jax import vmap, jit
import pickle
import pandas as pd
import io
import yaml
from functools import partial
import jax.numpy as jnp
import numpy as np
from tests.utils.markov_simulator import markov_simulator

ModuleNotFoundError: No module named 'tests'

In [4]:
def f(x, y):
    return x**2 + y


def f_aux(x, y):
    x_squ = x * +2
    return x_squ + y, x_squ


def g(func, x, y):
    func_val = func(x, y)
    if isinstance(func_val, tuple):
        if len(func_val) == 2:
            if isinstance(func_val[1], dict):
                return func_val
            else:
                raise ValueError(
                    "The second output of budget equation must be a dictionary."
                )
    else:
        return func_val

In [11]:
test_a = jnp.array(1.0)
test_b = jnp.array(2.0)

In [6]:
len(jnp.array([1.0, 2.0]).shape)

1

In [8]:
jnp.array([1.0]).ndim

1

In [12]:
jit_f = jit(f)
jit_f(test_a, test_b)

Array(3., dtype=float32, weak_type=True)

In [13]:
%timeit jit_f(test_a, test_b)

9.77 μs ± 1.02 μs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [14]:
jit_g = jit(lambda x, y: g(f, x, y))
jit_g_aux = jit(lambda x, y: g(f_aux, x, y))
jit_g(test_a, test_b)
jit_g_aux(test_a, test_b)

TypeError: len() of unsized object

In [23]:
isinstance(f_aux(test_a, test_b), tuple)

True

In [9]:
%timeit jit_g(test_a, test_b)

14.2 μs ± 3.65 μs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [10]:
%timeit jit_g_aux(test_a, test_b)

9.58 μs ± 3.92 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [34]:
n_periods = 10
init_dist = np.array([0.5, 0.5])
trans_mat = np.array([[0.8, 0.2], [0.1, 0.9]])

markov_simulator(n_periods, init_dist, trans_mat)

array([[0.5       , 0.5       ],
       [0.45      , 0.55      ],
       [0.415     , 0.585     ],
       [0.3905    , 0.6095    ],
       [0.37335   , 0.62665   ],
       [0.361345  , 0.638655  ],
       [0.3529415 , 0.6470585 ],
       [0.34705905, 0.65294095],
       [0.34294134, 0.65705866],
       [0.34005893, 0.65994107]])

In [47]:
n_agents = 100_000
current_agents_in_states = (np.ones(2) * n_agents / 2).astype(int)
for period in range(n_periods):
    print(current_agents_in_states / n_agents)
    next_period_agents_states = np.zeros(2, dtype=int)
    for state in range(2):
        agents_in_state = current_agents_in_states[state]
        transition_draws = np.random.choice(
            a=[0, 1], size=agents_in_state, p=trans_mat[state, :]
        )
        next_period_agents_states[1] += transition_draws.sum()
        next_period_agents_states[0] += agents_in_state - transition_draws.sum()
    current_agents_in_states = next_period_agents_states

[0.5 0.5]
[0.4502 0.5498]
[0.4189 0.5811]
[0.39164 0.60836]
[0.37405 0.62595]
[0.35994 0.64006]
[0.35166 0.64834]
[0.34544 0.65456]
[0.34263 0.65737]
[0.34015 0.65985]


In [29]:
trans_mat[0, :]

Array([0.8, 0.2], dtype=float32)

In [2]:
def squared_num(a):
    return a**2


def square_and_select(ind, array):
    return array[ind] ** 2


def ex_ante_select(indices, array_to_select):
    selected_array = jnp.take(array_to_select, indices)
    return vmap(squared_num, in_axes=(0))(selected_array)


def select_on_the_go(indices, array_to_select):
    return vmap(square_and_select, in_axes=(0, None))(indices, array_to_select)


def select_dict_with_loop(input_array):
    out = {}
    for i in range(input_array.shape[0]):
        out[i] = jnp.sum(input_array[i, :])
    return out


def append_array(input_array):
    out_array = jnp.array([])
    for i in range(input_array.shape[0]):
        out_array = jnp.append(out_array, jnp.sum(input_array[i, :]))
    return out_array

In [9]:
np.random.gumbel(size=100000).mean()

0.5794452868115461

In [7]:
np.euler_gamma

0.5772156649015329

In [3]:
size_ind = 100_000_000
to_select_test = jnp.array([2, 4])
indices_test = jnp.append(jnp.zeros(size_ind), jnp.ones(size_ind)).astype(int)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [4]:
jit_exante = jit(ex_ante_select)
jit_expost = jit(select_on_the_go)

In [5]:
jit_exante(indices_test, to_select_test)
%timeit jit_exante(indices_test, to_select_test)

224 ms ± 10.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
jit_expost(indices_test, to_select_test)
%timeit jit_expost(indices_test, to_select_test)

221 ms ± 14.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
size_loop = 50
size_sum = 10_000
multi_dim = jnp.empty((size_loop, size_sum))

jit_dict_select = jit(select_dict_with_loop)
jit_append = jit(append_array)

In [8]:
jit_dict_select(multi_dim)
%timeit jit_dict_select(multi_dim)

275 µs ± 4.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [9]:
jit_append(multi_dim)
%timeit jit_append(multi_dim)

215 µs ± 7.21 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [14]:
jnp.take(multi_dim, jnp.array([10, 20, -99]), axis=0)

Array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [nan, nan, nan, ..., nan, nan, nan]], dtype=float32)