<a href="https://colab.research.google.com/github/AmandinChyba/Project1-Fields-2022/blob/main/Walk_on_Boundary_and_BIEM/WalkOnBoundary.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
from jax import lax
from jax import random

import jax.numpy as jnp
import numpy as np

import time

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d

from functools import partial

In [2]:
def unit_circle_boundary(x, y):
  return jnp.sqrt(jnp.power(x, 2) + jnp.power(y, 2)) - 1

In [3]:
def initial_point(x_0, y_0, angle, boundary, tolerance):
  x = x_0
  y = y_0
  error = boundary(x, y)
  while jnp.abs(error) > tolerance:
    x += error * jnp.cos(angle)
    y += error * jnp.sin(angle)
    error = boundary(x, y)
  return x[0], y[0]

In [4]:
def get_point(x_0, y_0, angle, boundary, tolerance):
  x = x_0
  y = y_0
  error = tolerance * -2
  theta = angle
  if boundary(x + error * jnp.cos(theta), y + error * jnp.sin(theta)) > 0:
    theta -= jnp.pi
  while jnp.abs(error) > tolerance:
    x += error * jnp.cos(theta)
    y += error * jnp.sin(theta)
    error = boundary(x, y)
  return x[0], y[0]

In [5]:
def wob(g, boundary, x, y, batches, t, tolerance, key):
  # generate keys
  keys = random.split(key, 2)

  # compute angles
  angles = random.uniform(keys[0], shape=(t, 1)) * 2 * jnp.pi

  # get initial point on boundary
  cur_x, cur_y = initial_point(x, y, angles[0], boundary, tolerance)

  points = jnp.array([[cur_x, cur_y]])
  # run markov chain
  for angle in angles[1:]:
    next_x, next_y = get_point(cur_x, cur_y, angle, boundary, tolerance)
    points = jnp.append(points, jnp.array([[next_x, next_y]]), axis=0)
    cur_x = next_x
    cur_y = next_y

  x_values = points[:,0]
  y_values = points[:,1]
  g_values = g(x_values, y_values)

  pos_g = jnp.sum(g_values[0:t-1:2]) * 2
  neg_g = jnp.sum(g_values[1:t-1:2]) * -2
  last = g_values[-1]

  return pos_g + neg_g + last

In [6]:
%%time
g = lambda x, y : y
boundary = unit_circle_boundary
x = 0
y = 0.5
batches = 1000
t = 10
tolerance = 0.0001
key = random.PRNGKey(int(time.time()))
keys = random.split(key, batches)

ans = jnp.array([])

for i in range(batches):
  value = wob(g, boundary, x, y, batches, t, tolerance, keys[i])
  ans = jnp.append(ans, jnp.array([value]))

mean = jnp.mean(ans)
print(mean)



0.49623543
CPU times: user 2min 45s, sys: 727 ms, total: 2min 46s
Wall time: 2min 47s
