<a href="https://colab.research.google.com/github/AmandinChyba/Project1-Fields-2022/blob/main/Walk_on_Boundary_and_BIEM/WalkOnBoundary.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [127]:
import jax
from jax import lax
from jax import random

import jax.numpy as jnp
import numpy as np

import time

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d

from functools import partial

In [128]:
def unit_circle_boundary(x, y):
  return jnp.sqrt(jnp.power(x, 2) + jnp.power(y, 2)) - 1

In [129]:
def initial_point(x_0, y_0, angle, boundary, tolerance):
  x = x_0
  y = y_0
  error = boundary(x, y)
  while jnp.abs(error) > tolerance:
    x += error * jnp.cos(angle)
    y += error * jnp.sin(angle)
    error = boundary(x, y)
  return x[0], y[0]

In [130]:
def get_point(x_0, y_0, angle, boundary, tolerance):
  x = x_0
  y = y_0
  error = tolerance * -2
  theta = angle
  if boundary(x + error * jnp.cos(theta), y + error * jnp.sin(theta)) > 0:
    theta -= jnp.pi
  while jnp.abs(error) > tolerance:
    x += error * jnp.cos(theta)
    y += error * jnp.sin(theta)
    error = boundary(x, y)
  return x[0], y[0]

In [131]:
def wob(g, boundary, x, y, batches, t, tolerance, key):
  # generate keys
  keys = random.split(key, 2)

  # compute angles
  angles = random.uniform(keys[0], shape=(t, batches)) * 2 * jnp.pi

  # get initial point on boundary
  cur_x, cur_y = initial_point(x, y, angles[0], boundary, tolerance)

  points = jnp.array([[cur_x, cur_y]])
  # run markov chain
  for angle in angles[1:]:
    next_x, next_y = get_point(cur_x, cur_y, angle, boundary, tolerance)
    points = jnp.append(points, jnp.array([[next_x, next_y]]), axis=0)
    cur_x = next_x
    cur_y = next_y

  return points, angles

In [135]:
g = lambda x, y : y
boundary = unit_circle_boundary
x = 0
y = 0.5
batches = 1
t = 100
tolerance = 0.001
key = random.PRNGKey(1)

points, angles = wob(g, boundary, x, y, batches, t, tolerance, key)
points

DeviceArray([[ 0.99956834, -0.02280234],
             [ 0.39300376, -0.9190035 ],
             [-0.26050088,  0.96545607],
             [-0.25851825,  0.9651931 ],
             [-0.99711436,  0.06331345],
             [ 0.729059  , -0.6834972 ],
             [-0.83860207, -0.54431856],
             [-0.19697239,  0.9799177 ],
             [-0.37238237, -0.9276621 ],
             [ 0.93510115, -0.35196978],
             [ 0.93436795, -0.35383055],
             [-0.7827891 ,  0.6219497 ],
             [-0.7812288 ,  0.6232008 ],
             [ 0.47903126, -0.8777743 ],
             [ 0.88311046,  0.46786618],
             [-0.976038  ,  0.21725668],
             [-0.97555524,  0.21919754],
             [ 0.3600103 , -0.9325249 ],
             [-0.7753906 ,  0.6311333 ],
             [ 0.98587155,  0.1629362 ],
             [-0.7560247 ,  0.65387434],
             [-0.7651382 , -0.6429936 ],
             [ 0.82002956, -0.5706142 ],
             [ 0.5909743 ,  0.8058046 ],
             [ 0