<a href="https://colab.research.google.com/github/AmandinChyba/Project1-Fields-2022/blob/main/Walk_on_Boundary_and_BIEM/WalkOnBoundary.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import jax
from jax import lax
from jax import random

import jax.numpy as jnp
import numpy as np

import time

import matplotlib.pyplot as plt

from functools import partial

In [62]:
@jax.jit
def unit_circle_distance(p):
  """Signed distance from point(s) p to the unit circle (unit sphere in d dims).

  Negative inside the disk, zero on the circle, positive outside. Coordinates
  run along axis 0, so a (d, batches) array yields one distance per column.
  """
  squared_radius = jnp.sum(p * p, axis=0)
  return jnp.sqrt(squared_radius) - 1

In [64]:
@jax.jit
def boundaryCond(p):
  """Dirichlet boundary data g(x, y) = y.

  Selects index 1 along the leading (coordinate) axis, so a (d, batches)
  array of boundary points yields one value per column.
  """
  second_coord = p[1, ...]
  return second_coord

In [63]:
@partial(jax.jit, static_argnames=("sdf",))
def binaryRootSearch(p0, p1, epochs, dir, sdf=None):
  """Bisection search for where the ray p0 + dir * s * p1 leaves the domain.

  Args:
    p0: ray origin(s), shape (d, batches).
    p1: ray direction(s) relative to p0, same shape as p0. Only s in [0, 1)
        is searched, so |p1| should be at least the domain diameter.
    epochs: number of bisection steps; s is resolved to 2**-epochs.
    dir: +1 or -1, orientation of the ray along p1.
    sdf: signed distance function (<= 0 inside, > 0 outside) mapping a
        (d, batches) array to one value per column. Defaults to
        `unit_circle_distance`. Static under jit, so it must be hashable.

  Returns:
    (roots, s): roots = p0 + dir * s * p1 with shape (d, batches), and the
    ray parameters s with shape (1, batches).

  Each step tentatively sets the next binary digit of s and keeps it only if
  the point is still inside (sdf <= 0). This gives the exit point when the
  inside part of the ray is one segment starting at s = 0, e.g. a convex domain
  with p0 inside or on it. NOTE(review): not guaranteed for non-convex domains.
  """
  if sdf is None:
    sdf = unit_circle_distance

  # Use the direction directly; the former ((p0 + p1) - p0) only added
  # floating-point round-off. Multiplying by dir (+/-1) is exact.
  step = dir * p1
  s_init = jnp.zeros((1, jnp.shape(p0)[1]))

  def bisect(i, s):
    candidate = s + jnp.power(0.5, i + 1)
    outside = sdf(p0 + candidate * step) > 0
    return jnp.where(outside, s, candidate)

  s = lax.fori_loop(0, epochs, bisect, s_init)
  return p0 + s * step, s  # roots, t-values

In [144]:
@jax.jit
def wob(p0, p1, t, epochs):
  """Walk-on-boundary estimates of the solution at p0, one per batch column.

  Args:
    p0: start point(s), shape (d, batches); presumably inside the domain.
    p1: jump directions, shape (n, d, batches); p1[k] is the direction of
        jump k (scaled by the caller so one ray spans the domain).
    t: number of jumps. Unused here: the walk length is p1.shape[0]. Kept
       so existing callers continue to work.
    epochs: bisection steps per boundary-intersection search.

  Returns:
    Per-walk estimates with shape (batches,):
        2 * sum_{k=1}^{n-1} (-1)^(k-1) g(y_k) + (-1)^(n-1) g(y_n)
    where y_1..y_n are the successive boundary points and g = boundaryCond.
  """
  # First boundary hit: ray from the start point along p1[0].
  pNext, _ = binaryRootSearch(p0, p1[0], epochs, 1)

  def wobStep(carry, direction):
    pCur, est, sign = carry
    est = est + sign * boundaryCond(pCur)
    sign = -sign
    # From a boundary point the chord lies on one side only: search both
    # orientations and keep the farther intersection.
    pBack, tBack = binaryRootSearch(pCur, direction, epochs, -1)
    pFor, tFor = binaryRootSearch(pCur, direction, epochs, 1)
    pCur = jnp.where(tBack - tFor > 0, pBack, pFor)
    return (pCur, est, sign), None

  # lax.scan compiles the jump body once instead of unrolling one copy per
  # jump into the jitted graph (as a Python for-loop would).
  g0 = boundaryCond(pNext)
  init = (pNext, jnp.zeros_like(g0), jnp.array(1, dtype=g0.dtype))
  (pLast, runningEst, runningSign), _ = lax.scan(wobStep, init, p1[1:])

  return 2 * runningEst + runningSign * boundaryCond(pLast)

In [154]:
def computeSoln(t, epochs, batches, p, key, diameter=2):
  """Monte Carlo walk-on-boundary estimate of the solution at point p.

  Args:
    t: number of boundary jumps per walk.
    epochs: bisection steps for each boundary-intersection search.
    batches: number of independent walks to average.
    p: evaluation point, a column vector of shape (d, 1); a flat (d,) vector
       is also accepted.
    key: JAX PRNG key used to draw the jump directions.
    diameter: length given to each random direction. The ray search only
       covers the parameter range [0, 1) along it, so this should be at least
       the domain's diameter (2 for the unit circle, the default).

  Returns:
    Scalar mean of the per-walk estimates returned by `wob`.
  """
  p = jnp.reshape(jnp.asarray(p), (-1, 1))  # accept (d,) or (d, 1)
  dim = p.shape[0]
  p0 = jnp.broadcast_to(p, (dim, batches))

  # Uniform random directions in d dimensions: normalize Gaussian samples
  # along the coordinate axis, then scale to `diameter`.
  p1 = random.normal(key, shape=(t, dim, batches))
  p1 = diameter * (p1 / jnp.linalg.norm(p1, ord=2, axis=1, keepdims=True))

  ans = wob(p0, p1, t, epochs)
  return jnp.mean(ans)

In [155]:
%%time
key = random.PRNGKey(int(time.time()))
x = jnp.array([[0.3,0.8]]).transpose()
print(computeSoln(5, 15, 100000, x, key))

TypeError: ignored

In [66]:
%%time
t = 2 # number of jumps on the boundary
epochs = 10 # binary search time steps
batches = np.power(10,np.arange(1,7))
x = 0.5
y = 0.3

key = random.PRNGKey(int(time.time()))
keys = random.split(key, len(batches))

values = np.zeros(len(batches))
for i in range(len(batches)):
  values[i] = computeSoln(t, epochs, batches[i], x, y, keys[i])

error = np.abs(values - y)
plt.plot(batches, error)
plt.xscale('log')
plt.yscale('log')
plt.xlabel('number of batches')
plt.ylabel('absolute error')
plt.title('Error over batches for (0.5, 0.3)')

TypeError: ignored