Task 1: Estimating $\pi$



In [1]:
# Estimating PI
from random import random

import jax
import jax.numpy as jnp


import numpy as np

def pi_python():
  npoints = 100_000
  xy = [(random(), random()) for _ in range(npoints)]

  inside = 0
  for x, y in xy:
    if x**2 + y**2 < 1:
      inside += 1.0

  return inside * 4.0 / npoints

def pi_np():
  npoints = 100_000
  xy = np.random.uniform(size=(npoints, 2))

  inside = np.sum(np.sum(xy**2, axis=1) < 1)

  return inside * 4.0 / npoints

def pi_jax():
  npoints = 100_000
  rng = jax.random.key(42)
  x_key, y_key = jax.random.split(rng)

  x = jax.random.uniform(x_key, shape=npoints, minval=0, maxval=1)
  y = jax.random.uniform(y_key, shape=npoints, minval=0, maxval=1)

  return jax.sum(x**2 + y**2 < 1) * (4.0 / npoints)

print(f"pi ~= {pi_python()}")
print(f"pi ~= {pi_np()}")
print(f"pi ~= {pi_jax()}")

print("\nPure Python:")
%timeit -n 10 pi_python()

print("\nNumpy Python:")
%timeit -n 10 pi_np()

print("\nJax Python:")
%timeit -n 10 pi_jax()

pi ~= 3.14032
pi ~= 3.14292

Pure Python:
63.3 ms ± 10.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Numpy Python:
7.45 ms ± 4.19 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# Solution