<a href="https://colab.research.google.com/github/AmandinChyba/Project1-Fields-2022/blob/main/laplace_wob_exterior.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [37]:
import jax
from jax import lax
from jax import random as jrandom

import jax.numpy as jnp
import numpy as np

import time

import matplotlib.pyplot as plt
import pandas as pd

# Boundary Functions

In [38]:
'''
signed distance function of the domain
'''
#@jax.jit
def signDistance(p, axis=1):
  return jnp.linalg.norm(p,ord=2,axis=axis,keepdims=True) - 1 # unit sphere

In [39]:
'''
Dirchlet boundary condition
'''
#@jax.jit
def boundaryCond(p, k, axis=1):
  # exterior
  return trueSoln(p, k, axis)

  # interior
  #pointCharge = jnp.array([0,0,2])
  #pointCharge = jnp.tile(pointCharge, (jnp.size(p,0), 1)) # [samples, dimensions]
  #pDiff = jnp.linalg.norm(p - pointCharge, ord=2, axis=1, keepdims=False)
  #final = jnp.exp(-k * pDiff) / (4 * jnp.pi * pDiff)
  #return final

In [40]:
'''
the true solution
'''
#@jax.jit
def trueSoln(p, k, axis=1):
  # exterior
  p_norm = jnp.linalg.norm(p, ord=2, axis=axis, keepdims=False)
  return jnp.exp(-k * p_norm) / (4 * jnp.pi * p_norm)

# WOB algorithm

In [41]:
'''
binary search to find boundary intersection
'''
#@jax.jit
def bisection(p0, p1, root_steps, dir=1):
  current = jnp.zeros((jnp.size(p0, 0), 1)) # [dimensions, samples]

  def bisectionStep(i, current):
    next = current + jnp.power(0.5,i+1)
    dist1 = signDistance(p0 + dir * next * (p1 - p0))
    dist2 = signDistance(p0 + dir * current * (p1 - p0))
    return jnp.where(dist1 * dist2 > 0, next, current)

  current = lax.fori_loop(0, root_steps, bisectionStep, current)
  return p0 + dir * current * (p1 - p0), current # return [roots, t-values]

In [42]:
'''
run the wob and calculate the estimator
'''
#@jax.jit
def wob(p0, p1, t, root_steps, k):
  markov_chain = np.zeros((p1.shape)) # [jumps-1, samples, dimensions]
  pNext = p0

  for i in range(len(p1)):
    p_angle = p1[i] - pNext # re-center p1 at pNext
    
    pBack, tBack = bisection(pNext, p_angle, root_steps, -1)
    pFor, tFor = bisection(pNext, p_angle, root_steps, 1)
    pNext = jnp.where(tBack - tFor > 0, pBack, pFor)
    markov_chain[i] = pNext

  return markov_chain

In [43]:
#@partial(jax.jit, static_argnames=['ray'])
def rootIsolate(p0, p1, root_steps):
  # create grid
  lb = jnp.zeros((len(p0), 1))
  ub = jnp.ones((len(p0), 1))
  grid = jnp.linspace(lb, ub, 100, axis=0)

  # find the roots
  v = signDistance(p0 + grid * (p1 - p0), axis=2)
  v = jnp.where(v > 0, 1, -1)

  # find the roots
  roots = v[:-1] + v[1:]

  # check how many roots
  n_zeros = jnp.count_nonzero(roots == 0, axis=0)

  # if n_zeros < 2 then act like the line didn't hit the boundary
  n_zeros = jnp.where(n_zeros < 2, 0, 2)
  roots = jnp.where(n_zeros == 0, 99, roots)
  
  # calculate how many samples have roots so we can calculate d for Q0
  d = jnp.count_nonzero(n_zeros != 0)
  print(d, '/', len(n_zeros))

  # find the bounds for the 2 roots on each sample
  roots = jnp.where(roots != 0, 1, 0)
  all_root_indices = jnp.argsort(roots, axis=0)

  p_pos_indices = all_root_indices[0].T
  p_neg_indices = all_root_indices[1].T
  grid = grid.squeeze(axis=2)
  t1 = jnp.take_along_axis(grid, p_pos_indices, axis=0).T
  t2 = jnp.take_along_axis(grid, p_pos_indices+1, axis=0).T
  t3 = jnp.take_along_axis(grid, p_neg_indices, axis=0).T
  t4 = jnp.take_along_axis(grid, p_neg_indices+1, axis=0).T

  # use bisection method to find the roots
  p_pos, _ = bisection(p0 + t1 * (p1 - p0),
                   p0 + t2 * (p1 - p0),
                   root_steps)
  
  p_neg, _ = bisection(p0 + t3 * (p1 - p0),
                   p0 + t4 * (p1 - p0),
                   root_steps)
  
  return p_pos, p_neg, d, n_zeros

def exterior_jump(p0, p1, root_steps, k):
  p1 = p1 - p0 # re-center the angles
  # make it long based on how far away we are from boundary (use signDistance and diameter?)

  # find both roots
  p_pos, p_neg, d, n_zeros = rootIsolate(p0, p1, root_steps)
  
  # calculate weights
  Q0 = 2 * d / len(p1)

  return p_pos, p_neg, Q0, n_zeros

In [44]:
'''
setup the wob
'''
def computeSoln(key, p, t=5, rootSteps=10, samples=1_000, diameter=1, k=1):
  p0 = jnp.tile(p, (samples,1)) # [samples, dimensions]

  # generate random directions in n-dimensions
  p1 = jrandom.normal(key, shape=(t, samples, jnp.size(p0, 1)))
  normalize = jnp.linalg.norm(p1, ord=2, axis=2, keepdims=True)
  p1 = diameter * jnp.divide(p1, normalize) # [jumps, dimensions, samples]

  # do the first exterior jump
  p_pos, p_neg, Q0, n_zeros = exterior_jump(p0, p1[0], rootSteps, k)

  # create markov chain at p0
  mc_pos = wob(p_pos, p1[1:], t, rootSteps, k)
  markov_chain_pos = np.concatenate((p_pos[np.newaxis,:], mc_pos), axis=0)

  # create markov chain at second p0 (is it okay to use the same angles?)
  mc_neg = wob(p_neg, p1[1:], t, rootSteps, k)
  markov_chain_neg = np.concatenate((p_neg[np.newaxis,:], mc_neg), axis=0)

  # remove from markov chain using n_zeros (should delete earlier?)
  markov_chain_pos = np.delete(markov_chain_pos, n_zeros, axis=1)
  markov_chain_neg = np.delete(markov_chain_neg, n_zeros, axis=1)

  # calculate alpha
  alpha = np.mean(boundaryCond(markov_chain_pos, k, axis=2), axis=0) / np.mean(np.linalg.norm(markov_chain_pos,ord=2,axis=2), axis=0)
  
  # apply estimator
  est1 = alpha * np.sum(1 / np.linalg.norm(markov_chain_pos,ord=2,axis=2) - 1 / np.linalg.norm(markov_chain_neg,ord=2,axis=2), axis=0)
  est2 = np.sum(boundaryCond(markov_chain_pos, k, axis=2) - boundaryCond(markov_chain_neg, k, axis=2), axis=0)
  ans = Q0 * (est1 - est2)

  print(Q0)
  
  return jnp.mean(ans), jnp.std(ans)

# Testing

In [45]:
%%time
'''
testing code
'''
key = jrandom.PRNGKey(1)
x = jnp.array([0.0,1.1,0.0])

y, sd = computeSoln(key, x, t=10, rootSteps=10, samples=5000, diameter=3, k=0)
print('solution:', y)
print('exterior exact:', trueSoln(jnp.array([x]), k=0))

3199 / 5000
1.2796
solution: 0.46481478
exterior exact: [0.07234315]
CPU times: user 4.52 s, sys: 102 ms, total: 4.62 s
Wall time: 4.94 s
