<a href="https://colab.research.google.com/github/AmandinChyba/Project1-Fields-2022/blob/main/Walk_on_Boundary_and_BIEM/WalkOnBoundary.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [125]:
import jax
from jax import lax
from jax import random

import jax.numpy as jnp
import numpy as np

import time

import matplotlib.pyplot as plt

from functools import partial

In [236]:
def unit_circle_boundary(p): # signed distance function for 2d unit circle
  return jnp.linalg.norm(p,ord=2,axis=0) - 1

In [247]:
'''
def binaryStep(i, bounds, distanceF):
  mid = getMiddle(bounds[0],bounds[1])
  dist = jnp.where(distanceF(mid[0],mid[1]) > 0, 1, 0)
  return jnp.array([mid*(1-dist) + bounds[0]*dist, mid*dist + bounds[1]*(1-dist)])
'''

'\ndef binaryStep(i, bounds, distanceF):\n  mid = getMiddle(bounds[0],bounds[1])\n  dist = jnp.where(distanceF(mid[0],mid[1]) > 0, 1, 0)\n  return jnp.array([mid*(1-dist) + bounds[0]*dist, mid*dist + bounds[1]*(1-dist)])\n'

In [248]:
def binaryRootSearch(p0, p1, epochs, dir):
  
  # need to center p1 at p0
  p1 = p0 + p1

  current = jnp.zeros((1, jnp.size(p0, axis=1)))

  for i in range(epochs):
    next = current + jnp.power(0.5,i+1)
    dist = unit_circle_boundary(p0 + dir * next * (p1 - p0))
    current = jnp.where(dist > 0, current, next)
  
  return p0 + dir * current * (p1 - p0), current # return the roots, t-values
  '''
  lower = p
  upper = jnp.array([p[0] + dir * diameter * jnp.cos(angle), 
                     p[1] + dir * diameter * jnp.sin(angle)])
  
  bounds = jnp.array([lower,upper])
  bounds = lax.fori_loop(0, epochs, binaryFunc, bounds)

  root = getMiddle(bounds[0],bounds[1])
  return root
  '''

In [8]:
def getDistance(p0, p1):
  return jnp.array([jnp.power(p1[0]-p0[0],2) + jnp.power(p1[1]-p0[1],2)]) # distance squared

In [259]:
def wob(g, distanceF, p0, p1, t, epochs):

  runningEst = 0
  runningSign = 1

  pNext, tFirst = binaryRootSearch(p0, p1[0], epochs, 1) 

  for p in p1[1:]:
    runningEst += runningSign * g(pNext)
    runningSign *= -1

    pBack, tBack = binaryRootSearch(pNext, p, epochs, -1)
    pFor, tFor = binaryRootSearch(pNext, p, epochs, 1)
    pNext = jnp.where(tBack - tFor > 0, pBack, pFor)

  runningEst *= 2
  runningEst += runningSign * g(pNext)

  return runningEst

In [218]:
def computeSoln(t, epochs, batches, p, key):
  g = lambda p : p[1] # p is a point s.t. p[0]=x and p[1]=y
  distanceF = unit_circle_boundary

  p0 = jnp.full((len(p), batches),p)

  diameter = 2
  p1 = random.normal(key, shape=(t, len(p), batches))
  normalize = jnp.linalg.norm(p1, ord=2, axis=1).reshape((t, 1, batches))
  p1 = diameter * jnp.divide(p1, normalize)

  ans = wob(g, distanceF, p0, p1, t, epochs)

  return jnp.mean(ans)

In [264]:
%%time
key = random.PRNGKey(1)
x = jnp.array([[0.3,0.8]]).transpose()
print(computeSoln(3, 10, 1000000, x, key))

0.79831785
CPU times: user 915 ms, sys: 60.2 ms, total: 976 ms
Wall time: 2.09 s


In [189]:
%%time
t = 2 # number of jumps on the boundary
epochs = 10 # binary search time steps
batches = np.power(10,np.arange(1,7))
x = 0.5
y = 0.3

key = random.PRNGKey(int(time.time()))
keys = random.split(key, len(batches))

values = np.zeros(len(batches))
for i in range(len(batches)):
  values[i] = computeSoln(t, epochs, batches[i], x, y, keys[i])

error = np.abs(values - y)
plt.plot(batches, error)
plt.xscale('log')
plt.yscale('log')
plt.xlabel('number of batches')
plt.ylabel('absolute error')
plt.title('Error over batches for (0.5, 0.3)')

TypeError: ignored