<a href="https://colab.research.google.com/github/AmandinChyba/Project1-Fields-2022/blob/main/Laplace_On_Disc/LaplaceOnDisc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax.random as random
import jax.numpy as jnp
import numpy as np
import time

In [45]:
def solve_at(g, k, r, theta, n_samples, t, key):
  '''
  solve_at : Monte Carlo estimate of the solution of the Dirichlet problem for
  Laplace's equation on a 2D disc at polar coordinates (r, theta).

  The boundary angle X is sampled from the density proportional to
  k(r, theta - phi) with a Metropolis-Hastings sampler whose proposal Q(x'|x)
  is uniform on [0, 2*pi) and independent of the current state. n_samples
  chains are advanced in parallel and the estimate is the mean of g over every
  stored state of every chain.

  Parameters
  --------------------
  g : the boundary condition as a vectorized function of the boundary angle
  k : the kernel function of the integral equation, called as k(r, angle) and
  vectorized over angle; only ratios of k are used, so it does not need to be
  normalized
  r : the radius of the polar coordinates
  theta : the angle of the polar coordinates
  n_samples : the number of parallel chains (samples of X stored per step)
  t : the number of discrete time steps of the Metropolis-Hastings algorithm
  key : a jax.random PRNG key; it is split into 2 * t + 1 subkeys internally

  Returns
  --------------------
  samples : array of shape (t + 1, n_samples); row 0 is the initial uniform
  draw and row i is the state of every chain after i steps
  estimate : scalar, the mean of g(samples), i.e. the estimate of E[g(X)]

  Notes
  --------------------
  Currently, we want to use the Poisson kernel P(theta), so pass
  poisson_kernel as k

  P(theta) = (1 - r^2) / (1 - 2*r*cos(theta) + r^2)

  No burn-in is discarded: the initial uniform draw (row 0) is included in the
  average.
  '''
  # one subkey for the initial state (the last one), then a
  # (proposal, acceptance) pair of subkeys for each of the t steps
  keys = random.split(key, 2 * t + 1)

  def uniform_angles(subkey):
    # one angle per chain, uniform on [0, 2*pi), kept as a (1, n_samples) row
    return random.uniform(subkey, shape=(1, n_samples)) * 2 * np.pi

  current_samples = uniform_angles(keys[-1])

  # collect the rows in a Python list and concatenate once at the end;
  # appending to the array inside the loop would re-copy the history each step
  history = [current_samples]
  for i in range(0, 2 * t, 2):
    proposals = uniform_angles(keys[i])
    # the proposal density is constant, so the acceptance probability reduces
    # to the ratio of target densities, capped at 1
    A = jnp.minimum(1, k(r, theta - proposals) / k(r, theta - current_samples))
    U = random.uniform(keys[i + 1], shape=(1, n_samples))
    current_samples = jnp.where(U <= A, proposals, current_samples)
    history.append(current_samples)
  samples = jnp.concatenate(history, axis=0)

  # compute E(g(X)): average along each chain first, then across chains
  G = g(samples)
  E = jnp.mean(G, axis=0)
  return samples, jnp.mean(E)

In [26]:
def poisson_kernel(r, theta):
  '''
  poisson_kernel : evaluates P(theta) = (1 - r^2) / (1 - 2*r*cos(theta) + r^2)

  Parameters
  --------------------
  r : the radius at which the kernel is evaluated
  theta : the angle (or array of angles) at which the kernel is evaluated

  Returns : the value of the kernel, broadcast over r and theta
  '''
  r_squared = jnp.power(r, 2)
  numerator = 1 - r_squared
  denominator = 1 - 2 * r * jnp.cos(theta) + r_squared
  return numerator / denominator

In [46]:
# Evaluate the solution at (r, theta) = (0.5, pi/2) with boundary data g = sin.
# The harmonic extension of sin is r * sin(theta), so the exact value here is
# 0.5, which serves as a reference for the Monte Carlo estimate.
SEED = 0  # fixed seed so re-running the notebook gives the same estimate
n_samples = 100  # number of parallel Metropolis-Hastings chains
t = 100  # Metropolis-Hastings steps per chain
r = 0.5
theta = np.pi / 2
key = random.PRNGKey(SEED)
samples, ans = solve_at(jnp.sin, poisson_kernel, r, theta, n_samples, t, key)

[[5.1333156  2.5560143  0.11147802 2.668045   5.207624   0.20570773
  3.3682234  4.9701333  4.6107073  5.3702397  3.0017009  2.9769654
  0.7965667  6.130176   5.71475    3.4504435  0.9514434  5.812756
  2.2578983  1.5008167  4.4531507  2.0621383  4.801404   3.4766836
  1.0502309  5.972744   1.7906402  3.9385939  2.1320858  3.7552533
  0.7986797  2.8966973  2.5063105  2.609651   2.8496084  6.07546
  5.7478185  0.97852546 2.9343212  1.5890925  1.5653391  6.1017046
  1.530053   4.5041356  3.0206337  2.8588054  6.1402316  6.1408
  2.1223037  4.5607624  4.8716445  5.7648997  3.7192197  4.9769287
  1.6392405  2.030574   2.0835714  5.837361   0.06792509 2.4325385
  1.3801026  3.2077465  4.186382   2.108259   4.7120175  2.0781245
  3.1608589  0.02452047 1.6766028  0.0858857  6.216814   0.56227213
  3.323985   0.90205336 1.2831922  6.2089806  4.6898556  5.3774643
  4.301125   4.4870667  3.2397997  3.4565763  2.5095372  5.8809566
  1.1818783  4.950688   3.2623339  3.5546095  2.9742906  6.2688665

In [48]:
# display the estimate of E[g(X)] returned as the second value of solve_at
ans

DeviceArray(0.49303088, dtype=float32)