<a href="https://colab.research.google.com/github/AmandinChyba/Project1-Fields-2022/blob/main/Walk_on_Boundary_and_BIEM/adamstepWOB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
from jax import lax
from jax import random

import jax.numpy as jnp
import numpy as np

import time

import matplotlib.pyplot as plt
import pandas as pd

In [2]:
@jax.jit
def signDistance(p):
  '''
  Signed distance function of the domain (the unit sphere).

  The Euclidean norm is taken along axis 0 of `p`, so axis 0 holds the
  coordinates and any remaining axes index independent points.  The result
  is negative inside the sphere, zero on it and positive outside.
  '''
  radius = jnp.linalg.norm(p, ord=2, axis=0)
  return radius - 1  # unit sphere

In [3]:
@jax.jit
def box():
  '''
  Bounding box diameter of the domain.

  Returns the constant 2.
  '''
  diameter = 2
  return diameter

In [4]:
@jax.jit
def boundaryCond(p):
  '''
  Dirichlet boundary condition.

  Evaluates to 1.0 where the second coordinate `p[1]` is strictly positive
  and to 0.0 everywhere else (including `p[1] == 0`).
  '''
  upperHalf = p[1] > 0.0
  return jnp.where(upperHalf, 1.0, 0.0)

In [5]:
'''
the true solution
'''
def trueSoln(p):
  return 0.5 + np.arctan(2*p[1]/(1-np.power(p[0],2)-np.power(p[1],2)))/np.pi

In [6]:
# NOTE: jit is left disabled for this function (the decorator was commented out).
def binaryRootSearch(p0, p1, rootSteps, dir):
  '''
  Bracketing search to find a boundary intersection along a ray.

  Points on the ray are p0 + dir * s * ((p0 + p1) - p0) for a parameter s
  kept inside [0, 1].  The bracket starts at (0, 1), is narrowed once at the
  midpoint s = 0.5, and is then refined `rootSteps` times with a
  false-position (secant) update driven by `signDistance`.

  p0        : start points; axes 1 and 2 give the shape of the s-values
  p1        : offsets from p0 to the far end of each ray
  rootSteps : number of false-position refinements after the midpoint step
  dir       : scalar multiplier applied to the ray direction
              (the caller in this file passes 1 or -1)

  Returns (points, s): the ray evaluated at the lower bracket end, and that
  lower bracket end itself.
  '''
  farEnd = p0 + p1  # need to center p1 at p0
  chord = farEnd - p0

  def pointAt(s):
    # ray point for parameter s (same operation order as dir * s * chord)
    return p0 + dir * s * chord

  gridShape = (jnp.size(p0, 1), jnp.size(p0, 2))
  lo = jnp.zeros(gridShape)
  hi = jnp.ones(gridShape)

  # one midpoint step before the false-position iterations
  mid = lo + 0.5
  midDist = signDistance(pointAt(mid))
  lo, hi = jnp.where(midDist > 0, lo, mid), jnp.where(midDist < 0, hi, mid)

  for _ in range(rootSteps):
    distLo = signDistance(pointAt(lo))
    distHi = signDistance(pointAt(hi))
    candidate = (lo * distHi - hi * distLo) / (distHi - distLo)
    candidateDist = signDistance(pointAt(candidate))

    movedLo = jnp.where(candidateDist > 0, lo, candidate)
    movedHi = jnp.where(candidateDist < 0, hi, candidate)
    # a NaN candidate (e.g. distHi == distLo) must not overwrite the bracket
    lo, hi = (jnp.where(jnp.isnan(movedLo), lo, movedLo),
              jnp.where(jnp.isnan(movedHi), hi, movedHi))

  return pointAt(lo), lo # return the roots, t-values

In [7]:
# NOTE: jit is left disabled for this function (the decorator was commented out).
def wob(p0, p1, t, rootSteps):
  '''
  Run the wob and calculate the estimator.

  Starting from `p0`, the first point is found with `binaryRootSearch`
  along `p1[0]`.  For every further entry of `p1` the current point
  contributes `boundaryCond` with an alternating sign, and the next point is
  chosen from a backward (dir = -1) and a forward (dir = 1) search, taking
  the backward one wherever its t-value is strictly larger.  The alternating
  sum is doubled and the last point is added once with the current sign.

  p0        : start points
  p1        : sequence of ray offsets, one entry per step
  t         : not used by this implementation
  rootSteps : forwarded to `binaryRootSearch`

  Side effects: prints the number of entries where the backward and forward
  t-values were equal, and the shape of the final points.

  Returns the estimator array.
  '''
  estimate = jnp.zeros((jnp.size(p0,1),jnp.size(p0,2)))
  sign = 1
  bad_ps = 0

  pNext, _ = binaryRootSearch(p0, p1[0], rootSteps, 1)

  for direction in p1[1:]:
    estimate = estimate + sign * boundaryCond(pNext)
    sign = -sign
    pBack, tBack = binaryRootSearch(pNext, direction, rootSteps, -1)
    pFor, tFor = binaryRootSearch(pNext, direction, rootSteps, 1)
    takeBack = tBack - tFor > 0
    pNext = jnp.where(takeBack, pBack, pFor)
    # ties between the two searches are counted for diagnostics only
    ties = jnp.where(tBack == tFor, 1, 0)
    bad_ps = bad_ps + jnp.sum(ties)

  print(bad_ps)

  estimate = estimate * 2
  estimate = estimate + sign * boundaryCond(pNext)
  print(pNext.shape)
  return estimate

In [8]:
def computeSoln(t, rootSteps, batches, p, key):
  '''
  Set up the wob and average its estimator.

  t         : number of random directions drawn per walk
  rootSteps : forwarded to `wob`
  batches   : number of copies of each point (last axis of the arrays)
  p         : 2-D array of evaluation points; it is transposed first, so
              rows are points and columns are coordinates
  key       : jax.random PRNG key used for the directions

  Returns the mean of the `wob` estimator over axis 1, with that axis kept.
  '''
  pts = p.T
  nCoords, nPoints = jnp.size(pts, 0), jnp.size(pts, 1)
  pts = jnp.reshape(pts, (nCoords, nPoints, 1))
  p0 = jnp.tile(pts, (1, 1, batches))

  # generate random directions in n-dimensions
  directions = random.normal(key, shape=(t, nCoords, nPoints, batches))
  lengths = jnp.linalg.norm(directions, ord=2, axis=1, keepdims=True)
  p1 = box() * (directions / lengths)

  estimates = wob(p0, p1, t, rootSteps)
  return jnp.mean(estimates, axis=1, keepdims=True)

In [11]:
%%time
'''
testing code
'''
# Smoke test: estimate the solution at one point and print it next to the
# closed-form value from trueSoln.
# The key is seeded from the wall clock, so every run uses different samples.
key = random.PRNGKey(int(time.time()))
x = jnp.array([[0.5,0.3]])  # a single evaluation point, one row per point

# positional arguments are (t, rootSteps, batches, p, key)
y = computeSoln(10, 15, 100000, x, key)
print('solution:', y)
print('exact:', trueSoln(x[0]))

231795
(2, 1, 100000)
solution: [[0.70818996]]
exact: 0.7348538347994082
CPU times: user 907 ms, sys: 387 ms, total: 1.29 s
Wall time: 887 ms
