<a href="https://colab.research.google.com/github/AmandinChyba/Project1-Fields-2022/blob/main/Laplace_On_Disc/LaplaceOnDisc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax.random as random
import jax.numpy as jnp
import numpy as np
import time

In [11]:
'''
solve_at : Solves the Dirichlet problem for Laplace's equation on a 2D disc at
polar coordinates (r, theta). Also takes the boundary condition as a
parameter.

Parameters
--------------------
g : the boundary condition as a vectorized function
k : the kernel function of the integral equation as a vectorized function
r : the radius of the polar coordinates
theta : the angle of the polar coordinates
n_samples : the number of samples of E[g(X)] where X ~ P(theta - phi)
t : the number of discrete time steps of the Metropolis-Hastings algorithm to
sample from X

Returns : the solution to the Dirichlet problem for Laplace's equation on a 2D
disc at the given point with the given boundary condition

Notes
--------------------
Currently, we want to use the Poisson kernel P(theta), so set k as 
poisson_kernel(r)

P(theta) = (1 - r^2) / (1 - 2*r*cos(theta) + r^2)
'''
def solve_at(g, k, r, theta, n_samples, t, key):
  # sample from X ~ P(theta - phi) with Metropolis-Hastings
  # Q(x'|x) is a uniform distribution over the domain
  current_samples = random.uniform(key, shape=(1, n_samples)) * 2 * np.pi
  samples = jnp.tile(current_samples, (1, 1))
  for i in range(t):
    proposals = random.uniform(key, shape=(1, n_samples)) * 2 * np.pi
    A = (jnp.ones(n_samples) * 
         k(theta - proposals) / 
         k(theta - current_samples))
    A = jnp.minimum(1, A)
    U = random.uniform(key, shape=(1, n_samples))
    new_samples = jnp.where(U <= A, proposals, A)
    samples = jnp.append(samples, new_samples, axis=1)
    current_samples = new_samples

  # compute E(g(X))
  G = g(samples)
  E = jnp.mean(G, axis=0)
  return samples, jnp.mean(E)

In [4]:
def poisson_kernel(r):
  def poisson_kernel_at_r(theta):
    return ((1 - jnp.power(r, 2)) / 
            (1 - 2 * r * jnp.cos(theta) + jnp.power(r, 2)))
    
  return poisson_kernel_at_r

In [None]:
n_samples = 100
t = 1000
r = 0.5
key = random.PRNGKey(int(time.time()))
theta = np.pi / 2
samples, ans = solve_at(jnp.sin, poisson_kernel(r), r, theta, n_samples, t, key)

DeviceArray(-0.00484148, dtype=float32)