28 changes: 14 additions & 14 deletions lectures/jax_intro.md
@@ -15,15 +15,15 @@ kernelspec:

This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).

-JAX is a high-performance scientific computing library that provides
+JAX is a high-performance scientific computing library that provides

-* a NumPy-like interface that can automatically parallize across CPUs and GPUs,
+* a [NumPy](https://en.wikipedia.org/wiki/NumPy)-like interface that can automatically parallelize across CPUs and GPUs,
* a just-in-time compiler for accelerating a large range of numerical
operations, and
-* automatic differentiation.
+* [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).

-Increasingly, JAX also maintains and provides more specialized scientific
-computing routines, such as those originally found in SciPy.
+Increasingly, JAX also maintains and provides [more specialized scientific
+computing routines](https://docs.jax.dev/en/latest/jax.scipy.html), such as those originally found in [SciPy](https://en.wikipedia.org/wiki/SciPy).
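For a quick taste of all three features together, here is a minimal sketch; the function `f` and its inputs are illustrative only, not taken from the lecture:

```python
import jax
import jax.numpy as jnp

@jax.jit                              # just-in-time compile for CPU/GPU
def f(x):
    return jnp.sum(jnp.cos(x) ** 2)   # NumPy-like array operations

x = jnp.linspace(0.0, 10.0, 1_000)
y = f(x)                              # compiled on first call, fast afterwards
df = jax.grad(f)(x)                   # automatic differentiation
```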

In addition to what's in Anaconda, this lecture will need the following libraries:

@@ -36,7 +36,7 @@ In addition to what's in Anaconda, this lecture will need the following libraries
```{admonition} GPU
:class: warning

-This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU and target JAX for GPU programming.
+This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU and targets JAX for GPU programming.

Free GPUs are available on Google Colab.
To use this option, click on the play icon at the top right, select Colab, and set the runtime environment to include a GPU.
@@ -50,7 +50,7 @@ If you would like to install JAX running on the `cpu` only you can use `pip inst
One of the attractive features of JAX is that, whenever possible, its array
processing operations conform to the NumPy API.

-This means that, in many cases, we can use JAX is as a drop-in NumPy replacement.
+This means that, in many cases, we can use JAX as a drop-in NumPy replacement.

Let's look at the similarities and differences between JAX and NumPy.
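As a first illustration of the drop-in idea, the same NumPy-style calls work on both libraries; the values below are illustrative only:

```python
import numpy as np
import jax.numpy as jnp

a_np = np.array([1.0, 2.0, 3.0])
a_jx = jnp.array([1.0, 2.0, 3.0])

np.mean(a_np), jnp.mean(a_jx)     # same API, same result
jnp.dot(a_jx, a_jx)               # runs on the GPU when one is available
```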

@@ -199,7 +199,7 @@ a, a_new
```

The designers of JAX chose to make arrays immutable because JAX uses a
-*functional programming style*.
+[functional programming](https://en.wikipedia.org/wiki/Functional_programming) style.

This design choice has important implications, which we explore next!
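Here is a minimal sketch of the functional update syntax that immutability requires; the array values are illustrative only:

```python
import jax.numpy as jnp

a = jnp.zeros(3)
# a[0] = 1.0 would fail: JAX arrays are immutable
a_new = a.at[0].set(1.0)          # functional update returns a NEW array
a, a_new                          # the original array is unchanged
```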

@@ -241,19 +241,19 @@ In other words, JAX assumes a functional programming style.

The major implication is that JAX functions should be pure.

-**Pure functions** have the following characteristics:
+[Pure functions](https://en.wikipedia.org/wiki/Pure_function) have the following characteristics:

1. *Deterministic*
2. *No side effects*

-**Deterministic** means
+[Deterministic](https://en.wikipedia.org/wiki/Deterministic_algorithm) means

* Same input $\implies$ same output
* Outputs do not depend on global state

In particular, pure functions will always return the same result if invoked with the same inputs.

-**No side effects** means that the function
+[No side effects](https://en.wikipedia.org/wiki/Side_effect_(computer_science)) means that the function

* Won't change global state
* Won't modify data passed to the function (immutable data)
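To make the distinction concrete, here is a small sketch; the function names and the global `scale` are hypothetical:

```python
import jax.numpy as jnp

scale = 2.0

def impure_f(x):
    return scale * x              # reads global state, so the output can
                                  # change even when x does not

def pure_f(x, scale):
    return scale * x              # everything enters through the arguments

pure_f(jnp.arange(3.0), 2.0)
```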
@@ -307,7 +307,7 @@ At first you might find the syntax rather verbose.
But you will soon realize that the syntax and semantics are necessary in order
to maintain the functional programming style we just discussed.

-Moreover, full control of random state
+Moreover, full control of random state is
essential for parallel programming, such as when we want to run independent experiments along multiple threads.
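A minimal sketch of this explicit control of random state, using JAX's key-splitting API; the seed and shapes are illustrative only:

```python
import jax

key = jax.random.key(42)               # random state is passed around explicitly
key1, key2 = jax.random.split(key)     # independent streams, e.g. one per thread
x = jax.random.normal(key1, (3,))
y = jax.random.normal(key2, (3,))      # the same key always gives the same draws
```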


@@ -793,8 +793,8 @@ We defer further exploration of automatic differentiation with JAX until {doc}`j
:label: jax_intro_ex2
```

-In the Exercise section of {doc}`our lecture on Numba <numba>`, we used Monte
-Carlo to price a European call option.
+In the Exercise section of {doc}`our lecture on Numba <numba>`, we {ref}`used Monte
+Carlo to price a European call option <numba_ex4>`.

The code was accelerated by Numba-based multithreading.
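For comparison, a hedged sketch of the same Monte Carlo idea in JAX, where the option parameters and sample size are illustrative only and not taken from the lecture:

```python
import jax
import jax.numpy as jnp

@jax.jit
def price_call(key, S0=100.0, K=100.0, T=1.0, r=0.05, sigma=0.2, n=1_000_000):
    # Simulate terminal prices under geometric Brownian motion
    Z = jax.random.normal(key, (n,))
    ST = S0 * jnp.exp((r - 0.5 * sigma**2) * T + sigma * jnp.sqrt(T) * Z)
    # Discounted expected payoff of the call
    return jnp.exp(-r * T) * jnp.mean(jnp.maximum(ST - K, 0.0))

price_call(jax.random.key(0))
```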
