<a href="https://colab.research.google.com/github/AmandinChyba/Project1-Fields-2022/blob/main/Walk_on_Boundary_and_BIEM/WalkOnBoundary.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
import jax
from jax import lax
from jax import random

import jax.numpy as jnp
import numpy as np

import time

import matplotlib.pyplot as plt

from functools import partial

In [15]:
@jax.jit
def unit_circle_distance(p):
  """Signed distance from each point in `p` to the origin-centred unit sphere.

  Coordinates run along axis 0 of `p`, so this works in any dimension (circle
  in 2-D, sphere in 3-D, ...).  The result has axis 0 removed and is negative
  inside, zero on the boundary and positive outside.
  """
  radius = jnp.linalg.norm(p, ord=2, axis=0)
  return radius - 1

In [491]:
@jax.jit
def boundaryCond(p):
  """Boundary data g(p): the second coordinate (y) of each point.

  `p` carries coordinates along axis 0; the returned array drops that axis.
  g = y is linear, so the exact interior solution is also u = y, which makes
  this a convenient sanity check for the estimator.
  """
  y_coord = p[1, ...]
  return y_coord

In [488]:
@partial(jax.jit, static_argnames=("sdf",))
def binaryRootSearch(p0, p1, epochs, dir, sdf=None):
  """Bisect along a ray from `p0` for the last parameter t still inside the domain.

  Candidate points are ``p0 + dir * t * p1`` for t in [0, 1).  The search builds
  t one binary digit at a time: step i tries adding 2**-(i+1) and keeps it only
  if the resulting point has sdf <= 0 (inside or on the boundary).  This finds
  the boundary crossing when the domain is convex and `p0` is inside it or on
  its boundary.

  Args:
    p0: start points, coordinates on axis 0, shape (dim, n_points, batches).
    p1: ray vectors relative to `p0` (broadcastable to `p0`).  Their length
      should be at least the domain diameter so the crossing lies at t < 1.
    epochs: number of bisection steps; t is resolved to within 2**-epochs.
    dir: +1 to search along `p1`, -1 to search in the opposite direction.
    sdf: signed distance function (negative inside) taking arrays with
      coordinates on axis 0.  It is a static jit argument, so any hashable
      Python callable works.  Defaults to `unit_circle_distance`.

  Returns:
    (roots, t): the boundary points (shape like `p0`) and the t-values,
    shape (n_points, batches).
  """
  if sdf is None:
    # Looked up at trace time so callers that omit `sdf` keep the old behavior.
    sdf = unit_circle_distance

  # `p1` is already the displacement from `p0`; the former (p0 + p1) - p0
  # round trip only introduced rounding error.
  step = dir * p1
  t_init = jnp.zeros((jnp.size(p0, 1), jnp.size(p0, 2)))

  def bisect(i, t):
    candidate = t + jnp.power(0.5, i + 1)
    dist = sdf(p0 + candidate * step)
    # Outside the domain -> reject this binary digit; otherwise keep it.
    return jnp.where(dist > 0, t, candidate)

  t = lax.fori_loop(0, epochs, bisect, t_init)
  return p0 + t * step, t  # return the roots, t-values

In [489]:
@jax.jit
def wob(p0, p1, t, epochs):
  """Walk-on-boundary Monte Carlo estimate built from `boundaryCond` values.

  The first ray in `p1` is shot from the start points `p0`.  Each later ray is
  treated as a chord through the current boundary point: it is searched in
  both directions, and the walk keeps the end found with the larger t-value.
  With boundary hits y_0 .. y_{m-1} (m = p1.shape[0]) the returned value is

      2 * sum_{k < m-1} (-1)**k * g(y_k)  +  (-1)**(m-1) * g(y_{m-1}).

  `t` is not read here; the walk length comes from the leading axis of `p1`.
  `epochs` is forwarded to `binaryRootSearch` as the bisection depth.

  Returns an array of shape (p0.shape[1], p0.shape[2]).
  """
  estimate = jnp.zeros((p0.shape[1], p0.shape[2]))
  sign = 1

  # First hit: march from the start points along the first ray.
  hit, _ = binaryRootSearch(p0, p1[0], epochs, 1)

  for chord in p1[1:]:
    estimate = estimate + sign * boundaryCond(hit)
    sign = -sign
    # Search the chord both ways from the current hit; keep the larger-t end.
    back_hit, back_t = binaryRootSearch(hit, chord, epochs, -1)
    fwd_hit, fwd_t = binaryRootSearch(hit, chord, epochs, 1)
    hit = jnp.where(back_t - fwd_t > 0, back_hit, fwd_hit)

  return 2 * estimate + sign * boundaryCond(hit)

In [480]:
def computeSoln(t, epochs, batches, p, key, diameter=2):
  """Average `batches` walk-on-boundary estimates at each query point.

  Args:
    t: number of boundary hits per walk (leading axis of the sampled rays).
    epochs: bisection steps used to locate each boundary hit.
    batches: number of independent walks averaged per point.
    p: query points with coordinates on axis 0, shape (dim, n_points, 1);
      a 2-D (dim, n_points) array is also accepted.
    key: `jax.random` PRNG key used to sample the ray directions.
    diameter: length of every sampled ray.  The root search only covers
      t in [0, 1) along a ray, so this must be at least the domain diameter;
      2 matches the unit circle/sphere.

  Returns:
    Array of shape (n_points,) with the averaged estimates.
  """
  if p.ndim == 2:
    # jnp.tile would otherwise prepend an axis and fold batches into n_points.
    p = p[..., None]
  starts = jnp.tile(p, (1, 1, batches))  # (dim, n_points, batches)

  # Random directions in `dim` dimensions: normalise Gaussian samples over the
  # coordinate axis, then scale them to the ray length.
  rays = random.normal(key, shape=(t, jnp.size(p, 0), jnp.size(p, 1), batches))
  norms = jnp.linalg.norm(rays, ord=2, axis=1, keepdims=True)
  rays = diameter * (rays / norms)

  samples = wob(starts, rays, t, epochs)  # (n_points, batches)
  return jnp.mean(samples, axis=1)

In [503]:
# Fixed seed so Restart & Run All reproduces the numbers; change it for a new sample.
SEED = 0
key = random.PRNGKey(SEED)

# Query points are written one per row, then transposed so coordinates run
# along axis 0: here (0.5, 0, 0.2) and (0, 0.15, 0.7) in 3-D.
# For a single 2-D point use jnp.array([[0.5, 0]]).transpose() instead.
x = jnp.array([[0.5, 0, 0.2], [0, 0.15, 0.7]]).transpose()
x = jnp.reshape(x, (jnp.size(x, 0), jnp.size(x, 1), 1))  # (dim, n_points, 1)

In [508]:
%%time
print(computeSoln(2, 15, 100000, x, key))

[-0.00573833  0.16755617]
CPU times: user 99.6 ms, sys: 9.36 ms, total: 109 ms
Wall time: 85.9 ms
