<a href="https://colab.research.google.com/github/ZINZINBIN/JAX_example/blob/main/EstimatingPI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Task 1: Estimating $\pi$



In [1]:
# Estimating PI
from random import random

import jax
import jax.numpy as jnp

def pi_python():
  npoints = 100_000
  xy = [(random(), random()) for _ in range(npoints)]

  inside = 0
  for x, y in xy:
    if x**2 + y**2 < 1: # jax doesn't cover this logic condition normally (vectorizaiton needed)
      inside += 1.0

  return inside * 4.0 / npoints


print(f"pi ~= {pi_python()}")

print("\nPure Python:")
%timeit -n 10 pi_python()

pi ~= 3.15408

Pure Python:
59.8 ms ± 6.74 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [20]:
# Solution
import jax
import jax.numpy as jnp
from typing import List, Tuple

def compute_inside(xy:List[Tuple]):
    dist_l2 = jnp.sum(jnp.array(xy) ** 2, axis = 1)
    inside = jnp.sum(dist_l2 < 1)
    return inside

def pi_python_jit():
    rng = jax.random.key(42)
    npoints = 100_000
    xy = jax.random.uniform(rng, shape = (2,npoints)) # random should be also replaced!

    # my method
    # inside = compute_inside(xy)

    # solution
    inside = jnp.sum((xy[0,...] ** 2 + xy[1,...] ** 2) < 1)

    return inside * 4.0 / npoints

print(f"pi ~= {pi_python_jit()}")

print("\nPython with JIT (with compile-time):")
%timeit -n 1 pi_python_jit().block_until_ready()

print("\nPython with JIT (without compile-time):")
%timeit -n 10 pi_python_jit().block_until_ready()

pi ~= 3.142359972000122

Python with JIT (with compile-time):
3.41 ms ± 179 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Python with JIT (without compile-time):
3.41 ms ± 166 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
