<a href="https://colab.research.google.com/github/AmandinChyba/Project1-Fields-2022/blob/main/Walk_on_Boundary_and_BIEM/WalkOnBoundary_proto.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
from jax import lax
from jax import random

import jax.numpy as jnp
import numpy as np

import time

import matplotlib.pyplot as plt

from functools import partial

In [43]:
def unit_circle_boundary(p):
  """Signed distance from point(s) ``p`` to the unit circle.

  ``p[0]`` holds the x coordinate(s) and ``p[1]`` the y coordinate(s); each may
  be a scalar or an array (one entry per point in a batch). The result is
  negative inside the circle, zero on it and positive outside.
  """
  x, y = p[0], p[1]
  radius = jnp.sqrt(jnp.square(x) + jnp.square(y))
  return radius - 1

In [11]:
def getMiddle(p0, p1):
  """Midpoint of the segment from ``p0`` to ``p1``.

  Only the first two components are used (index 0 = x, index 1 = y). Each
  component may be a scalar or an array, so a whole batch of segments is
  handled at once. The result is stacked as ``[x, y]`` along a new leading axis.
  """
  half_dx = (p1[0] - p0[0]) / 2
  half_dy = (p1[1] - p0[1]) / 2
  return jnp.array([p0[0] + half_dx, p0[1] + half_dy])

In [91]:
def binaryStep(i, bounds, distanceF):
  """One bisection step of the boundary root search (a ``lax.fori_loop`` body).

  Args:
    i: loop counter supplied by ``lax.fori_loop``; unused.
    bounds: array stacked as ``[lower, upper]``. The step keeps the bracket so
      that ``lower`` sits where ``distanceF <= 0`` and ``upper`` where
      ``distanceF > 0`` (presumably inside / outside the domain).
    distanceF: signed boundary function; ``distanceF(p) > 0`` marks ``p`` as outside.

  Returns:
    The new ``[lower, upper]`` pair, same shape as ``bounds``. The endpoint on
    the same side as the midpoint is replaced by the midpoint.
  """
  mid = getMiddle(bounds[0], bounds[1])
  outside = distanceF(mid) > 0
  # Select with jnp.where rather than a 0/1 arithmetic blend: multiplying the
  # unselected point by 0 would turn an inf in it into NaN (0 * inf = NaN).
  new_lower = jnp.where(outside, bounds[0], mid)
  new_upper = jnp.where(outside, mid, bounds[1])
  return jnp.array([new_lower, new_upper])

In [92]:
def binaryRootSearch(point, angle, binaryFunc, iterations, dir, diameter=2 + 0.01):
  """Find where a ray from ``point`` crosses the boundary, by bisection.

  The bracket starts as ``[point, point + dir * diameter * (cos(angle), sin(angle))]``.
  ``binaryFunc`` then halves it ``iterations`` times inside ``lax.fori_loop``.

  Args:
    point: start point; ``point[0]`` is x and ``point[1]`` is y (scalars, or
      equally-shaped arrays for a batch of walks).
    angle: ray direction(s) in radians, broadcastable against ``point[0]``.
    binaryFunc: ``fori_loop`` body ``(i, bounds) -> bounds``, e.g. ``binaryStep``
      with its ``distanceF`` bound.
    iterations: number of bisection steps.
    dir: ``1`` to search along ``angle``, ``-1`` to search the opposite way.
    diameter: length of the search segment. The far endpoint must land outside
      the domain, so this should be at least the domain's diameter. The default
      (2 plus a 0.01 margin) fits the unit circle.

  Returns:
    Midpoint of the final bracket, stacked as ``[x, y]``.
  """
  # Signed segment length; computed once and reused for both coordinates.
  offset = dir * diameter
  upper = jnp.array([point[0] + offset * jnp.cos(angle),
                     point[1] + offset * jnp.sin(angle)])

  bounds = jnp.array([point, upper])
  bounds = lax.fori_loop(0, iterations, binaryFunc, bounds)

  return getMiddle(bounds[0], bounds[1])

In [12]:
def getDistance(p0, p1):
  """Squared Euclidean distance between ``p0`` and ``p1`` (no square root taken).

  Uses components 0 (x) and 1 (y). The result carries an extra leading axis of
  length 1: shape ``(1,)`` for single points, ``(1, n)`` for batched coordinates.
  """
  dx = p1[0] - p0[0]
  dy = p1[1] - p0[1]
  return jnp.array([jnp.square(dx) + jnp.square(dy)])

In [231]:
def wobStep(pNext, angle, binaryFunc, rootSteps):
    """One walk-on-boundary jump, shaped as a ``lax.scan`` body.

    Draws the line through the current boundary point ``pNext`` at ``angle`` and
    bisects for a boundary crossing in both directions along it. It then moves to
    whichever crossing is farther from ``pNext``. The nearer one is presumably
    ``pNext`` itself, up to bisection error.

    Returns:
      ``(carry, output)``: both are the new point, so ``lax.scan`` continues from
      it and also records it.
    """
    forward = binaryRootSearch(pNext, angle, binaryFunc, rootSteps, 1)
    backward = binaryRootSearch(pNext, angle, binaryFunc, rootSteps, -1)
    # getDistance returns a leading length-1 axis, which broadcasts against the
    # stacked [x, y] points in jnp.where.
    forward_is_farther = getDistance(pNext, forward) - getDistance(pNext, backward) > 0
    jumped = jnp.where(forward_is_farther, forward, backward)
    return jumped, jumped

In [218]:
def wob(g, distanceF, x0, y0, t, rootSteps, angles):
  """Walk-on-boundary Monte Carlo estimate at the start point(s) ``(x0, y0)``.

  Each walk works as follows:
  - The first angle row casts a ray from ``(x0, y0)`` to the boundary.
  - Every later row performs one ``wobStep`` jump along the boundary.
  - ``g`` is evaluated at the ``t`` visited points y_0 .. y_{t-1}, and these
    values are combined as

        2 * sum_{i=0}^{t-2} (-1)^i g(y_i)  +  (-1)^{t-1} g(y_{t-1})

  Finally the per-walk values are averaged. This is presumably the
  alternating-series walk-on-boundary estimator for the interior Dirichlet
  problem; TODO confirm against the reference used for this notebook.

  Args:
    g: boundary data; called on a length-2 ``[x, y]`` point, returns a scalar.
    distanceF: signed boundary function (> 0 outside), e.g. ``unit_circle_boundary``.
    x0, y0: start coordinates, scalars or one entry per walk. They are stacked
      as ``[x0, y0]`` and must broadcast against each row of ``angles``.
    t: number of boundary points per walk; must equal ``angles.shape[0]``.
    rootSteps: bisection iterations per boundary crossing.
    angles: ray angles in radians; row k drives step k. Extra trailing axes
      (e.g. ``(t, batches)``) run independent walks in parallel.

  Returns:
    Scalar mean of the per-walk estimates.

  Raises:
    ValueError: if ``t`` does not match ``angles.shape[0]``.
  """
  n_rows = jnp.shape(angles)[0]
  if n_rows != t:
    raise ValueError(f"t={t} but angles has {n_rows} rows; they must match")

  binaryFunc = partial(binaryStep, distanceF=distanceF)
  wobFunc = partial(wobStep, binaryFunc=binaryFunc, rootSteps=rootSteps)

  # First boundary point: cast a ray from the start point.
  pInit = jnp.array([x0, y0])
  pFirst = binaryRootSearch(pInit, angles[0], binaryFunc, rootSteps, 1)
  # Remaining t-1 points: one boundary jump per remaining angle row.
  _, steps = lax.scan(wobFunc, pFirst, angles[1:])
  path = jnp.concatenate((pFirst[jnp.newaxis], steps), axis=0)  # (t, 2, ...)

  # Weights 2, -2, 2, ... with the final term weighted +1/-1 instead of +2/-2.
  weights = np.resize(np.array([1, -1]), t)
  weights[:-1] *= 2

  gVal = jnp.apply_along_axis(g, 1, path)  # g on each [x, y] point -> (t, ...)
  # Trailing singleton axes let the weights broadcast over any batch axes of
  # gVal, whatever the number of walks.
  weights = jnp.asarray(weights).reshape((t,) + (1,) * (gVal.ndim - 1))

  est = jnp.sum(weights * gVal, axis=0)
  return jnp.mean(est)

In [None]:
%%time
t = 50  # boundary points per walk (first ray + t-1 jumps); must equal angles.shape[0]
rootSteps = 25  # bisection steps per boundary crossing
batches = 100000  # number of independent walks

g = lambda p: p[1]  # boundary data g(x, y) = y; p[0] = x, p[1] = y
distanceF = unit_circle_boundary

# Start point (0, -0.3) for every walk; float arrays so no int->float promotion is needed.
x0 = np.zeros(batches)
y0 = np.full(batches, -0.3)

key = random.PRNGKey(1)
angles = random.uniform(key, shape=(t, batches)) * 2 * jnp.pi

# wob already averages over all walks, so its result is the final scalar estimate.
ans = wob(g, distanceF, x0, y0, t, rootSteps, angles)
mean = ans
print(mean)

# g(x, y) = y is harmonic, so (assuming wob estimates the interior Dirichlet
# solution) the exact value at the start point is u(x0, y0) = y0.
exact = float(y0[0])
print(f"exact = {exact:.6f}, abs error = {abs(float(mean) - exact):.6f}")