<a href="https://colab.research.google.com/github/AmandinChyba/Project1-Fields-2022/blob/main/Laplace_On_Disc/LaplaceOnDisc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [194]:
from jax import random
from jax import lax

import jax.numpy as jnp
import numpy as np

import time
import matplotlib.pyplot as plt

from functools import partial

In [522]:
def poisson_kernel(r, theta):
  """Poisson kernel of the unit disc, without any normalising constant.

  Evaluates (1 - r^2) / (1 - 2 r cos(theta) + r^2) elementwise; ``r`` and
  ``theta`` broadcast against each other.  For 0 <= r < 1 the value is
  strictly positive.  At r == 1 the numerator is 0, so the result is 0
  wherever cos(theta) != 1 and 0/0 = NaN where cos(theta) == 1.
  """
  r_squared = jnp.power(r, 2)
  denominator = 1 - 2 * r * jnp.cos(theta) + r_squared
  return (1 - r_squared) / denominator

In [674]:
def metroStep(cur, prop, kernel, r, theta):
  """One Metropolis update of the chain's angle(s), shaped as a `lax.scan` step.

  Args:
    cur: current angle(s) of the chain (the scan carry).
    prop: pair ``(candidate, u)``: the proposed angle(s) and the uniform
      draw(s) used for the accept/reject test.
    kernel: callable ``kernel(r, angle)`` giving an unnormalised target density.
    r, theta: evaluation point(s); broadcast against ``cur``.

  Returns:
    ``(new_state, (new_state, jnp.array([0.0])))``: the new carry, and the
    per-step output that `lax.scan` stacks. The second element of that
    output is a constant placeholder.
  """
  candidate, u = prop
  density_ratio = jnp.divide(kernel(r, theta - candidate), kernel(r, theta - cur))
  # With u in [0, 1), ceil(min(1, ratio) - u) is 1 when u < min(1, ratio)
  # and 0 otherwise, so it acts as a 0/1 accept mask.
  # If both densities are 0 the ratio is 0/0 = NaN, which propagates
  # (jnp.minimum is NaN-propagating) into the returned state.
  accept = jnp.ceil(jnp.minimum(1, density_ratio) - u)
  new_state = jnp.multiply(cur, 1 - accept) + jnp.multiply(candidate, accept)
  return new_state, (new_state, jnp.array([0.0]))


def solve_at_point(g, k, r, theta, batches, t, key):
  """Monte Carlo estimate of the Dirichlet problem on the unit disc on a polar grid.

  For every grid point (r_i, theta_j), ``batches`` independent Metropolis
  chains are run over a boundary angle whose target density is proportional
  to ``k(r_i, theta_j - angle)``. The estimate is the mean of ``g`` over all
  chain states produced by the Metropolis steps.

  Args:
    g: boundary function; applied elementwise to arrays of angles.
    k: kernel ``k(r, angle)``, e.g. ``poisson_kernel``. Only ratios of its
      values are used, so it need not be normalised.
    r: a radius or 1-D sequence of radii (intended range 0 <= r <= 1).
    theta: an angle or 1-D sequence of angles.
    batches: number of independent chains per grid point.
    t: proposals per chain. Must be >= 2: the first proposal is the initial
      state, and the t - 1 states after it are averaged.
    key: JAX PRNG key.

  Returns:
    Array of shape ``(len(r), len(theta))``. Rows with ``r == 1`` lie on the
    boundary and are set to ``g(theta)`` directly. With ``poisson_kernel``
    the kernel is 0 there away from ``angle == theta``, so the Metropolis
    ratio would be 0/0 = NaN.

  Raises:
    ValueError: if ``t < 2``.
  """
  if t < 2:
    raise ValueError("t must be at least 2 (one initial state plus one step)")

  r = jnp.atleast_1d(jnp.asarray(r))
  theta = jnp.atleast_1d(jnp.asarray(theta))
  n_r, n_theta = r.shape[0], theta.shape[0]

  keys = random.split(key, 2)
  proposals = random.uniform(keys[0], shape=(t, batches, n_r, n_theta)) * 2 * jnp.pi
  U = random.uniform(keys[1], shape=(t - 1, batches, n_r, n_theta))

  # Broadcast the evaluation grid against the (batches, n_r, n_theta) chain state.
  r_grid = jnp.reshape(r, (1, n_r, 1))
  theta_grid = jnp.reshape(theta, (1, 1, n_theta))

  # Use the caller-supplied kernel (previously `k` was silently ignored).
  metroFunc = partial(metroStep, kernel=k, r=r_grid, theta=theta_grid)
  _, result = lax.scan(metroFunc, proposals[0], (proposals[1:], U))
  # result[0] stacks the chain states: shape (t - 1, batches, n_r, n_theta).
  t_mean = jnp.mean(g(result[0]), axis=0)
  batch_mean = jnp.mean(t_mean, axis=0)

  # On the boundary circle the solution equals the boundary data itself.
  on_boundary = (r == 1)[:, None]
  return jnp.where(on_boundary, g(theta)[None, :], batch_mean)

In [685]:
%%time
batches = 50
t = 10000
r = jnp.linspace(0,1,num=10)
theta = jnp.linspace(0,2*jnp.pi,num=10)
key = random.PRNGKey(0)
ans = solve_at_point(jnp.sin, poisson_kernel, r, theta, batches, t, key)
print(ans)

[[ 1.26276608e-03  6.99264696e-04 -1.33131235e-03 -8.41918401e-04
   8.13355437e-04  3.92701477e-04 -7.18974450e-04  5.72815479e-04
  -3.31015508e-05 -1.45094236e-04]
 [-1.21265685e-03  7.17817172e-02  1.09971732e-01  9.55314040e-02
   3.83657813e-02 -3.68416347e-02 -9.45786238e-02 -1.08703248e-01
  -7.11498559e-02  1.29645958e-03]
 [ 3.59527272e-04  1.43908113e-01  2.20039099e-01  1.92351416e-01
   7.57715479e-02 -7.56451637e-02 -1.90681770e-01 -2.18267187e-01
  -1.44130796e-01 -1.09241321e-03]
 [ 1.64026511e-03  2.10879251e-01  3.27797383e-01  2.85483956e-01
   1.13822594e-01 -1.11515522e-01 -2.91202456e-01 -3.26811552e-01
  -2.14217409e-01  1.29600533e-03]
 [ 8.63145804e-04  2.88767725e-01  4.38491195e-01  3.84863079e-01
   1.52301848e-01 -1.53538346e-01 -3.83761406e-01 -4.38224465e-01
  -2.87137836e-01  3.59896978e-04]
 [ 1.37113529e-04  3.59089971e-01  5.49339950e-01  4.80830282e-01
   1.89452201e-01 -1.89636633e-01 -4.82046127e-01 -5.44010103e-01
  -3.58024359e-01  1.20455246e-04