
Commit d501a58

ENH: include nvidia-smi across lectures (#60)
* ENH: include nvidia-smi across lectures
* reduce words
1 parent c391cbc commit d501a58
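For context: the cell this commit adds runs `nvidia-smi` through IPython's `!` shell escape, printing the GPU model, driver version and memory usage of the machine executing the lecture. A rough plain-Python equivalent of that check, combined with the JAX backend report kept in `lectures/status.md`, is sketched below (illustrative only, not part of the commit):

```python
import subprocess

import jax

# Report which backend JAX picked up (gpu, tpu or cpu) and the devices it can see.
print(f"JAX backend: {jax.devices()[0].platform}")
print(jax.devices())

# Query the NVIDIA driver for the GPU we are running on, as `!nvidia-smi` does in
# the lecture notebooks. The tool only exists where an NVIDIA driver is installed.
try:
    result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=True)
    print(result.stdout)
except (FileNotFoundError, subprocess.CalledProcessError):
    print("nvidia-smi not available -- likely no NVIDIA GPU on this machine.")
```

Unlike `jax.devices()`, `nvidia-smi` is only present on machines with an NVIDIA driver, which is why the sketch guards the call.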

11 files changed, +59 -9 lines changed


lectures/aiyagari_jax.md

Lines changed: 6 additions & 0 deletions
@@ -57,6 +57,12 @@ import jax
 import jax.numpy as jnp
 ```
 
+Let's check the GPU we are running
+
+```{code-cell} ipython3
+!nvidia-smi
+```
+
 We will use 64 bit floats with JAX in order to increase the precision.
 
 ```{code-cell} ipython3

lectures/arellano.md

Lines changed: 9 additions & 1 deletion
@@ -79,8 +79,16 @@ import jax
 import jax.numpy as jnp
 ```
 
+Let's check the GPU we are running
+
+```{code-cell} ipython3
+!nvidia-smi
+```
+
+We will use 64 bit floats with JAX in order to increase the precision.
+
 ```{code-cell} ipython3
-jax.config.update('jax_enable_x64', True)
+jax.config.update("jax_enable_x64", True)
 ```
 
 ## Structure
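The hunk above also normalizes the quoting in `jax.config.update("jax_enable_x64", True)`, the call that switches JAX to 64-bit floats. A minimal sketch of its effect (illustrative, not part of this commit):

```python
import jax
import jax.numpy as jnp

# JAX defaults to 32-bit datatypes; enabling x64 makes newly created arrays
# float64, which is the extra precision the lecture text refers to.
jax.config.update("jax_enable_x64", True)

print(jnp.ones(3).dtype)  # float64 once the flag is set
```

Calling it before any arrays are created ensures all subsequent computations run in float64.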

lectures/ifp_egm.md

Lines changed: 5 additions & 0 deletions
@@ -48,6 +48,11 @@ from numba import njit, float64
 from numba.experimental import jitclass
 ```
 
+Let's check the GPU we are running
+
+```{code-cell} ipython3
+!nvidia-smi
+```
 
 We use 64 bit floating point numbers for extra precision.

lectures/inventory_dynamics.md

Lines changed: 2 additions & 6 deletions
@@ -48,14 +48,10 @@ from jax import random, lax
 from collections import namedtuple
 ```
 
-Lets check the backend used by JAX and the devices available
+Let's check the GPU we are running
 
 ```{code-cell} ipython3
-# Check if JAX is using GPU
-print(f"JAX backend: {jax.devices()[0].platform}")
-
-# Check the devices available for JAX
-print(jax.devices())
+!nvidia-smi
 ```
 
 ## Sample paths

lectures/kesten_processes.md

Lines changed: 6 additions & 0 deletions
@@ -62,6 +62,12 @@ import jax.numpy as jnp
 from jax import random
 ```
 
+Let's check the GPU we are running
+
+```{code-cell} ipython3
+!nvidia-smi
+```
+
 ## Kesten processes
 
 ```{index} single: Kesten processes; heavy tails

lectures/newtons_method.md

Lines changed: 6 additions & 0 deletions
@@ -35,6 +35,12 @@ import jax.numpy as jnp
 from scipy.optimize import root
 ```
 
+Let's check the GPU we are running
+
+```{code-cell} ipython3
+!nvidia-smi
+```
+
 ## The Equilibrium Problem
 
 In this section we describe the market equilibrium problem we will solve with

lectures/opt_invest.md

Lines changed: 6 additions & 1 deletion
@@ -25,7 +25,6 @@ We require the following library to be installed.
 !pip install --upgrade quantecon
 ```
 
-
 A monopolist faces inverse demand
 curve
 
@@ -66,6 +65,12 @@ import jax.numpy as jnp
 import matplotlib.pyplot as plt
 ```
 
+Let's check the GPU we are running
+
+```{code-cell} ipython3
+!nvidia-smi
+```
+
 We will use 64 bit floats with JAX in order to increase the precision.
 
 ```{code-cell} ipython3

lectures/opt_savings.md

Lines changed: 6 additions & 0 deletions
@@ -33,6 +33,12 @@ import matplotlib.pyplot as plt
 import time
 ```
 
+Let's check the GPU we are running
+
+```{code-cell} ipython3
+!nvidia-smi
+```
+
 Use 64 bit floats with JAX in order to match NumPy code
 - By default, JAX uses 32-bit datatypes.
 - By default, NumPy uses 64-bit datatypes.

lectures/short_path.md

Lines changed: 6 additions & 0 deletions
@@ -30,6 +30,12 @@ import jax.numpy as jnp
 import jax
 ```
 
+Let's check the GPU we are running
+
+```{code-cell} ipython3
+!nvidia-smi
+```
+
 ## Solving for Minimum Cost-to-Go
 
 Let $J(v)$ denote the minimum cost-to-go from node $v$,

lectures/status.md

Lines changed: 6 additions & 0 deletions
@@ -28,4 +28,10 @@ You can check the backend used by JAX using:
 import jax
 # Check if JAX is using GPU
 print(f"JAX backend: {jax.devices()[0].platform}")
+```
+
+and the hardware we are running on:
+
+```{code-cell} ipython3
+!nvidia-smi
 ```
