<a href="https://colab.research.google.com/github/AmandinChyba/Project1-Fields-2022/blob/main/Walk_on_Boundary_and_BIEM/WalkOnBoundary.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
from jax import lax
from jax import random

import jax.numpy as jnp
import numpy as np

import time

import matplotlib.pyplot as plt

from functools import partial

In [2]:
def unit_circle_boundary(x, y):
  """Signed distance from (x, y) to the unit circle.

  Negative strictly inside the circle, zero on it, positive outside.
  x and y may be scalars or same-shaped arrays (evaluated elementwise).
  """
  xs = jnp.asarray(x)
  ys = jnp.asarray(y)
  radius = jnp.sqrt(xs * xs + ys * ys)
  return radius - 1

In [3]:
def getMiddle(p0, p1):
  """Return the midpoint of the segment from p0 to p1.

  Points are indexed as p[0] = x and p[1] = y; each coordinate may be a scalar
  or a batch array. Each output coordinate is (p1[i] - p0[i]) / 2 + p0[i], and
  the two coordinates are stacked along a new leading axis.
  """
  return jnp.array([(p1[axis] - p0[axis]) / 2 + p0[axis] for axis in (0, 1)])

In [21]:
def binaryStep(epoch, bounds, distanceF):
  """One bisection step on the bracket [bounds[0], bounds[1]].

  Written as a `lax.fori_loop` body, so it takes the loop index first.

  Args:
    epoch: loop index supplied by `lax.fori_loop`; unused.
    bounds: array stacking the [lower, upper] bracket end points. Each point is
      indexed as p[0] = x, p[1] = y. NOTE(review): the bracket is assumed to
      have `lower` inside/on the domain and `upper` outside — confirm at callers.
    distanceF: callable (x, y) -> value, treated as "outside" where it is > 0.

  Returns:
    New bounds with the same shape and dtype. Where the midpoint is outside,
    the upper end moves to the midpoint; otherwise the lower end does.
  """
  lower, upper = bounds[0], bounds[1]
  mid = getMiddle(lower, upper)
  outside = distanceF(mid[0], mid[1]) > 0
  # Select instead of blending with 0/1 weights: `x * 0` is nan when x is inf,
  # whereas jnp.where simply picks the value from the chosen branch.
  new_lower = jnp.where(outside, lower, mid)
  new_upper = jnp.where(outside, mid, upper)
  return jnp.array([new_lower, new_upper])

In [22]:
def binaryRootSearch(point, angle, binaryFunc, epochs, dir, diameter=2, epsilon=0.01):
  """Find where a ray from `point` crosses the domain boundary, by bisection.

  The initial bracket is the segment from `point` to
  point + dir * (diameter + epsilon) * (cos(angle), sin(angle)). It is refined
  `epochs` times by `binaryFunc` inside `lax.fori_loop`, and the midpoint of
  the final bracket is returned.

  Args:
    point: ray origin, indexed as point[0] = x, point[1] = y (scalars or
      same-shaped batches). It is the lower end of the bracket.
    angle: ray direction in radians (scalar or batch matching `point`).
    binaryFunc: fori_loop body (epoch, bounds) -> bounds, e.g. `binaryStep`
      with `distanceF` bound via functools.partial.
    epochs: number of bisection steps.
    dir: +1 to search along `angle`, -1 to search in the opposite direction.
      (Shadows the builtin `dir`; name kept for caller compatibility.)
    diameter: upper bound on the domain's diameter. The default 2 matches the
      unit circle; pass a larger value for larger domains.
    epsilon: slack added to `diameter` so the far end of the bracket lies
      strictly outside the domain.

  Returns:
    The estimated crossing point, stacked as [x, y].

  NOTE(review): bisection is only meaningful when the lower end is inside (or
  on) the domain, the upper end is outside, and the segment crosses the
  boundary once, e.g. a convex domain of diameter <= `diameter`. If the ray
  heads straight out of a boundary point, the bracket collapses onto `point`.
  """
  reach = diameter + epsilon

  lower = point
  upper = jnp.array([point[0] + dir * reach * jnp.cos(angle),
                     point[1] + dir * reach * jnp.sin(angle)])

  bounds = jnp.array([lower, upper])
  bounds = lax.fori_loop(0, epochs, binaryFunc, bounds)

  return getMiddle(bounds[0], bounds[1])

In [7]:
def getDistance(p0, p1):
  """Squared Euclidean distance between two points (no square root is taken).

  Points are indexed as p[0] = x and p[1] = y; coordinates may be scalars or
  batch arrays. The result carries an extra leading axis of length 1: shape
  (1,) for scalar coordinates, (1, ...) for batched ones.
  """
  dx = jnp.asarray(p1[0] - p0[0])
  dy = jnp.asarray(p1[1] - p0[1])
  squared = dx * dx + dy * dy
  return squared[jnp.newaxis]

In [13]:
def wob(g, distanceF, x0, y0, t, epochs, angles):
  """Walk-on-boundary estimate built from the boundary data `g`, started at (x0, y0).

  The walk first follows angles[0] (forward direction) from (x0, y0) to the
  boundary. For each later angle it intersects the line through the current
  boundary point with the boundary in both directions and moves to whichever
  intersection is farther from the current point. With y_0, ..., y_m the
  visited boundary points (m = len(angles) - 1), the return value is

      2 * sum_{k=0}^{m-1} (-1)^k g(y_k) + (-1)^m g(y_m).

  Args:
    g: boundary data, called with a point p where p[0] = x and p[1] = y.
    distanceF: callable (x, y) -> value, treated as "outside" where it is > 0.
    x0, y0: start coordinates; scalars or same-shaped batches.
    t: unused here; the walk length comes from the first axis of `angles`.
    epochs: bisection steps per boundary intersection.
    angles: ray angles in radians, iterated along their first axis.

  Returns:
    The estimate, shaped like g's output for a (possibly batched) point.

  NOTE(review): choosing the farther intersection assumes the nearer one is
  the current point itself, which holds for a convex domain — confirm for
  other boundaries.
  """
  step = partial(binaryStep, distanceF=distanceF)

  # First ray: from the start point out to the boundary.
  current = binaryRootSearch(jnp.array([x0, y0]), angles[0], step, epochs, 1)

  total = 0
  sign = 1
  for theta in angles[1:]:
    total += sign * g(current)
    sign = -sign

    # Intersect the full line through `current` with the boundary on both sides.
    behind = binaryRootSearch(current, theta, step, epochs, -1)
    ahead = binaryRootSearch(current, theta, step, epochs, 1)
    behind_is_farther = getDistance(current, behind) - getDistance(current, ahead) > 0
    current = jnp.where(behind_is_farther, behind, ahead)

  # Every visited point except the last carries weight 2; the last carries weight 1.
  return 2 * total + sign * g(current)

In [23]:
%%time
t = 15 # number of jumps on the boundary
epochs = 25 # binary search time steps
batches = 100000

g = lambda p : p[1] # p is a point s.t. p[0]=x and p[1]=y
distanceF = unit_circle_boundary

x0 = np.array([0.8]*batches)
y0 = np.array([0.5]*batches)

key = random.PRNGKey(0)
angles = random.uniform(key, shape=(t, batches)) * 2 * jnp.pi

ans = wob(g, distanceF, x0, y0, t, epochs, angles)

mean = jnp.mean(ans)
print(mean)

0.50634384
CPU times: user 1.73 s, sys: 52.2 ms, total: 1.78 s
Wall time: 1.45 s
