<a href="https://colab.research.google.com/github/AmandinChyba/Project1-Fields-2022/blob/main/Walk_on_Boundary_and_BIEM/WalkOnBoundary.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [664]:
import jax
from jax import lax
from jax import random

import jax.numpy as jnp
import numpy as np

import time

import matplotlib.pyplot as plt
import pandas as pd

In [665]:
'''
signed distance function of the domain
'''
@jax.jit
def signDistance(p):
  return jnp.linalg.norm(p,ord=2,axis=1,keepdims=True) - 1 # unit sphere

In [666]:
'''
bounding box diameter
'''
@jax.jit
def box():
  return 2

In [667]:
'''
Dirchlet boundary condition
'''
@jax.jit
def boundaryCond(p):
  #return 1.0*(p[1] > 0.0)
  k = 1
  pointCharge = jnp.array([[0,0,1.1]])
  pointCharge = jnp.tile(pointCharge, (1, jnp.size(p,0)))
  pointCharge = pointCharge.reshape(jnp.size(p,0), jnp.size(p,1), 1) # [solutions, dimensions, samples]
  pDiff = jnp.linalg.norm(p - pointCharge, ord=2, axis=1)
  return jnp.exp(-k * pDiff) / (4 * jnp.pi * pDiff)

In [668]:
'''
the true solution
'''
def trueSoln(p):
  #return 0.5 + np.arctan(2*p[1]/(1-np.power(p[0],2)-np.power(p[1],2)))/np.pi
  k = 1
  pointCharge = jnp.array(jnp.size(p,0)*[0,0,1.1]).reshape(jnp.size(p,0),jnp.size(p,1)) # [solutions, dimensions]
  pDiff = jnp.linalg.norm(p - pointCharge, ord=2, axis=1)
  return jnp.exp(-k * pDiff) / (4 * jnp.pi * pDiff)

In [669]:
'''
binary search to find boundary intersection
'''
@jax.jit
def binaryRootSearch(p0, p1, rootSteps, dir):
  p1 = p0 + p1 # need to center p1 at p0

  current = jnp.zeros((jnp.size(p0, 0), 1, jnp.size(p0, 2))) # [solutions, dimensions, samples]

  def binaryStep(i, current):
    next = current + jnp.power(0.5,i+1)
    dist = signDistance(p0 + dir * next * (p1 - p0))
    return jnp.where(dist > 0, current, next)

  current = lax.fori_loop(0, rootSteps, binaryStep, current)
  
  
  return p0 + dir * current * (p1 - p0), current # return the roots, t-values

In [670]:
'''
run the wob and calculate the estimator
'''
@jax.jit
def wob(p0, p1, t, rootSteps):
  k = 1
  runningEst = jnp.zeros((jnp.size(p0,0),jnp.size(p0,2))) # [solutions, samples]
  runningSign = 1
  
  pNext, tFirst = binaryRootSearch(p0, p1[0], rootSteps, 1)
  pDiff = jnp.linalg.norm(p0 - pNext, ord=2, axis=1)
  qNext = jnp.exp(-k * pDiff) + k * pDiff
  pBefore = pNext

  for p in p1[1:]:
    runningEst += runningSign * boundaryCond(pNext) * qNext
    runningSign *= -1
    
    pBack, tBack = binaryRootSearch(pNext, p, rootSteps, -1)
    pFor, tFor = binaryRootSearch(pNext, p, rootSteps, 1)
    pNext = jnp.where(tBack - tFor > 0, pBack, pFor)
    
    pDiff = jnp.linalg.norm(pBefore - pNext, ord=2, axis=1)
    qNext = qNext * jnp.exp(-k * pDiff) + k * pDiff
    pBefore = pNext

  runningEst *= 2
  runningEst += runningSign * boundaryCond(pNext) * qNext
  return runningEst

In [671]:
'''
setup the wob
'''
def computeSoln(t, rootSteps, samples, p, key):
  p = jnp.reshape(p, (jnp.size(p,0), jnp.size(p,1), 1))
  p0 = jnp.tile(p, (1,1,samples)) # [solutions, dimensions, samples]

  # generate random directions in n-dimensions
  p1 = random.normal(key, shape=(t, jnp.size(p, 0), jnp.size(p, 1), samples))
  normalize = jnp.linalg.norm(p1, ord=2, axis=1, keepdims=True)
  p1 = box() * jnp.divide(p1, normalize) # [jumps, solutions, dimensions, samples]

  ans = wob(p0, p1, t, rootSteps)
  return jnp.mean(ans, 1, keepdims=True), jnp.std(ans, 1, keepdims=True)

In [685]:
%%time
'''
testing code
'''
key = random.PRNGKey(int(time.time()))
x = jnp.array([[0,0,0.5], [0,0,0.3]])
#x = jnp.array([[0,0,0.5]])

y, sd = computeSoln(5, 15, 100000, x, key)
print('solution:', y)
print('exact:',trueSoln(x))

solution: [[0.07236434]
 [0.04121904]]
exact: [0.07278839 0.04469558]
CPU times: user 295 ms, sys: 17.1 ms, total: 312 ms
Wall time: 196 ms


In [None]:
'''
generate dataset with throw-out algorithm
WARNING: this algorithm gets worse in higher dimensions
'''
def generateData():
  key = random.PRNGKey(0)
  dimensions = 2
  samples = 10000
  x = random.uniform(key, minval=-box()/2, maxval=box()-box()/2, shape=(samples,dimensions))
  indices = jnp.where(signDistance(x.transpose()) < 0)
  print('kept', len(x[indices]), 'out of', len(x))
  x = x[indices]

  batchSize = 100
  y, std = computeSoln(10, 15, 100000, x[:batchSize], key)
  for i in range(batchSize, len(x), batchSize):
    mean, std = computeSoln(5, 10, 100000, x[i:i+batchSize], key)
    y = np.concatenate((y, mean), axis=0)
  
  return x, y

x, y = generateData()

In [None]:
'''
save the dataset
'''
def saveData(x,y):
  np.savez('wob_piecwise_data', x=x, y=y)
  #data = jnp.concatenate((y,x), axis=1)
  #df = pd.DataFrame(data)
  #print(df)
  #df.to_pickle('wobDataset.pkl')

saveData(x,y)

In [None]:
'''
download the dataset
'''
def downloadData():
  data = np.load('wob_piecwise_data.npz')
  #df = pd.read_pickle('wobDataset.pkl')
  #y = df.iloc[:,0].to_numpy() # [u0, u1, ..., un]
  #x = df.iloc[:,1:].to_numpy() # [p0, p1, ..., pn]
  return data['x'], data['y']

In [None]:
'''
plot the data
'''
t = 10
rootSteps = 25
rootSteps = np.arange(25)
#batches = np.power(10,np.arange(1,7))
batches = 100000
y = jnp.array([[0.5,0.3]])

key = random.PRNGKey(int(time.time()))
keys = random.split(key, len(rootSteps))

values = np.zeros(len(rootSteps))
sdValues = np.zeros(len(rootSteps))
for i in range(len(rootSteps)):
  values[i], sdValues[i] = computeSoln(t, rootSteps[i], batches, y, keys[i])


upperC = values + 1.96 * sdValues / np.sqrt(batches)
lowerC = values - 1.96 * sdValues / np.sqrt(batches)
#error = np.abs(values - 0.3)
plt.plot(rootSteps, values-trueSoln(y[0]))
#plt.xscale('log')
#plt.yscale('log')
plt.fill_between(rootSteps, lowerC-trueSoln(y[0]), upperC-trueSoln(y[0]), color='b', alpha=.1)