
Commit fe23f83

revert kesten processes
1 parent bbbd6b0 commit fe23f83

1 file changed (+11 -160 lines)


lectures/kesten_processes.md

Lines changed: 11 additions & 160 deletions
@@ -19,16 +19,6 @@ kernelspec:
 
 # Kesten Processes and Firm Dynamics
 
-```{admonition} GPU
-:class: warning
-
-This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU and JAX for GPU programming.
-
-Free GPUs are available on Google Colab. To use this option, please click on the play icon top right, select Colab, and set the runtime environment to include a GPU.
-
-Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support. If you would like to install jax running on the `cpu` only you can use `pip install jax[cpu]`
-```
-
 ```{index} single: Linear State Space Models
 ```
 
@@ -683,22 +673,15 @@ s_init = 1.0 # initial condition for each firm
 :class: dropdown
 ```
 
-Here's one solution in [JAX](https://python-programming.quantecon.org/jax_intro.html).
-
-First let's import the necessary modules and check the backend for JAX
+Here's one solution.
+First we generate the observations:
 
 ```{code-cell} ipython3
-import jax
-import jax.numpy as jnp
-from jax import random
-
-# Check if JAX is using GPU
-print(f"jax backend: {jax.devices()[0].platform}")
-```
+from numba import njit, prange
+from numpy.random import randn
 
-Now we can generate the observations:
 
-```{code-cell} ipython3
+@njit(parallel=True)
 def generate_draws(μ_a=-0.5,
                    σ_a=0.1,
                    μ_b=0.0,
@@ -708,139 +691,7 @@ def generate_draws(μ_a=-0.5,
                    s_bar=1.0,
                    T=500,
                    M=1_000_000,
-                   s_init=1.0,
-                   seed=123):
-
-    key = random.PRNGKey(seed)
-    keys = random.split(key, 3)
-
-    # Initialize the array of s values with the initial value
-    s = jnp.full((M, ), s_init)
-
-    @jax.jit
-    def update_s(s, keys):
-        a_random = μ_a + σ_a * random.normal(keys[0], (M, ))
-        b_random = μ_b + σ_b * random.normal(keys[1], (M, ))
-        e_random = μ_e + σ_e * random.normal(keys[2], (M, ))
-
-        exp_a = jnp.exp(a_random)
-        exp_b = jnp.exp(b_random)
-        exp_e = jnp.exp(e_random)
-
-        s = jnp.where(s < s_bar,
-                      exp_e,
-                      exp_a * s + exp_b)
-
-        return s, keys[-1]
-
-    # Perform updates on s for time t
-    for t in range(T):
-        s, key = update_s(s, keys)
-        keys = random.split(key, 3)
-
-    return s
-
-%time data = generate_draws().block_until_ready()
-```
-
-As JIT-compiled `for` loops will lead to very slow compilation, we used `jax.jit` on the function `update_s` instead of the whole function.
-
-Let's produce the rank-size plot and check the distribution:
-
-```{code-cell} ipython3
-fig, ax = plt.subplots()
-
-rank_data, size_data = qe.rank_size(data, c=0.01)
-ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5)
-ax.set_xlabel("log rank")
-ax.set_ylabel("log size")
-
-plt.show()
-```
-
-The plot produces a straight line, consistent with a Pareto tail.
-
-It is possible to further speed up our code by replacing the `for` loop with [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)
-to reduce the loop overhead in the compilation of the jitted function
-
-```{code-cell} ipython3
-from jax import lax
-
-@jax.jit
-def generate_draws_lax(μ_a=-0.5,
-                       σ_a=0.1,
-                       μ_b=0.0,
-                       σ_b=0.5,
-                       μ_e=0.0,
-                       σ_e=0.5,
-                       s_bar=1.0,
-                       T=500,
-                       M=1_000_000,
-                       s_init=1.0,
-                       seed=123):
-
-    key = random.PRNGKey(seed)
-    keys = random.split(key, T)
-
-    # Generate random draws and initial values
-    a_random = μ_a + σ_a * random.normal(keys[0], (T, M))
-    b_random = μ_b + σ_b * random.normal(keys[1], (T, M))
-    e_random = μ_e + σ_e * random.normal(keys[2], (T, M))
-    s = jnp.full((M, ), s_init)
-
-    # Define the function for each update
-    def update_s(s, a_b_e_draws):
-        a, b, e = a_b_e_draws
-        s = jnp.where(s < s_bar,
-                      jnp.exp(e),
-                      jnp.exp(a) * s + jnp.exp(b))
-        return s, None
-
-    # Use lax.scan to perform the calculations on all states
-    s_final, _ = lax.scan(update_s, s, (a_random, b_random, e_random))
-    return s_final
-
-%time data = generate_draws_lax().block_until_ready()
-```
-
-Since we used `jax.jit` on the entire function, the compiled function is even faster
-
-```{code-cell} ipython3
-%time data = generate_draws_lax().block_until_ready()
-```
-
-Here we produce the same rank-size plot:
-
-```{code-cell} ipython3
-fig, ax = plt.subplots()
-
-rank_data, size_data = qe.rank_size(data, c=0.01)
-ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5)
-ax.set_xlabel("log rank")
-ax.set_ylabel("log size")
-
-plt.show()
-```
-
-We can also use Numba with `for` loops to generate the observations (replicating the results we obtained with JAX).
-
-The results will be slightly different since the pseudo random number generation is implemented [differently in JAX](https://www.kaggle.com/code/aakashnain/tf-jax-tutorials-part-6-prng-in-jax/notebook)
-
-```{code-cell} ipython3
-from numba import njit, prange
-from numpy.random import randn
-
-@njit(parallel=True)
-def generate_draws_numba(μ_a=-0.5,
-                         σ_a=0.1,
-                         μ_b=0.0,
-                         σ_b=0.5,
-                         μ_e=0.0,
-                         σ_e=0.5,
-                         s_bar=1.0,
-                         T=500,
-                         M=1_000_000,
-                         s_init=1.0):
+                   s_init=1.0):
 
     draws = np.empty(M)
     for m in prange(M):
@@ -857,12 +708,10 @@ def generate_draws_numba(μ_a=-0.5,
 
     return draws
 
-%time data = generate_draws_numba()
+data = generate_draws()
 ```
 
-We can see that JAX and vectorization of the code have sped up the computation significantly compared to the Numba version.
-
-We produce the rank-size plot again using the data, and it shows the same pattern we saw before:
+Now we produce the rank-size plot:
 
 ```{code-cell} ipython3
 fig, ax = plt.subplots()
@@ -875,5 +724,7 @@ ax.set_ylabel("log size")
 plt.show()
 ```
 
+The plot produces a straight line, consistent with a Pareto tail.
+
 ```{solution-end}
-```
+```
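
Note that this diff only shows fragments of the Numba solution the lecture reverts to: the body of the `for m in prange(M)` loop falls between the hunks and is not part of the commit shown here. Below is a minimal sketch of what the reverted `generate_draws` plausibly looks like, assembled from the visible context lines and the Kesten update rule that appears in the removed JAX code (restart from `exp(e')` below the threshold `s_bar`, otherwise `s' = exp(a') s + exp(b')`). The inner loop body is an assumption, not text from this commit.

```python
# Sketch only: reconstructed from the context lines of this diff plus the
# update rule visible in the removed JAX code; the inner loop body is assumed.
import numpy as np
from numba import njit, prange
from numpy.random import randn

@njit(parallel=True)
def generate_draws(μ_a=-0.5,
                   σ_a=0.1,
                   μ_b=0.0,
                   σ_b=0.5,
                   μ_e=0.0,
                   σ_e=0.5,
                   s_bar=1.0,
                   T=500,
                   M=1_000_000,
                   s_init=1.0):

    draws = np.empty(M)
    for m in prange(M):          # one independent firm per parallel iteration
        s = s_init
        for t in range(T):
            # Kesten update: small firms exit and are replaced by an entrant draw,
            # larger firms grow multiplicatively with an additive shock
            if s < s_bar:
                s = np.exp(μ_e + σ_e * randn())
            else:
                a = np.exp(μ_a + σ_a * randn())
                b = np.exp(μ_b + σ_b * randn())
                s = a * s + b
        draws[m] = s

    return draws

data = generate_draws()
```

The rank-size plotting cell shown as context in the final hunk can then be run on `data` unchanged.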
