# 2D Poisson equation
---
$$
\begin{aligned}
&\frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} = -2\sin(x)\sin(y), \quad (x, y) \in [-\pi, \pi]^2, \\
&u(x, -\pi) = u(x, \pi) = u(-\pi, y) = u(\pi, y) = 0.
\end{aligned}
$$

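Exact solution:
$$
u(x, y) = \sin(x)\sin(y).
$$

It satisfies the PDE because $\partial^2 u/\partial x^2 = \partial^2 u/\partial y^2 = -\sin(x)\sin(y)$, and it satisfies the boundary conditions because $\sin(\pm\pi) = 0$.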

In [1]:
import flax, flax.nn
from flax import jax_utils, optim
from flax.training import lr_schedule

import jax, jax.nn
from jax import random
import jax.numpy as jnp

import sys
sys.path.append("../../")
	
from Seismic_wave_inversion_PINN.tf_model_utils import *
from Seismic_wave_inversion_PINN.data_utils import *

In [5]:
@jax.jit
def scalar_fn(x, y):
    """Toy linear test function: the sum over all rows of x + y.

    x and y are stacked side by side and each row is projected onto an
    all-ones weight column, so every row contributes x_i + y_i. The rows are
    then summed to a scalar so the function can be passed to ``jax.grad``.

    Args:
        x, y: arrays with a matching leading dimension. These are presumably
            (n, 1) column vectors, or the (1,) per-sample slices that
            ``jax.vmap`` produces from them; shapes are not validated.

    Returns:
        A 0-d JAX array. The function is linear, so its second derivatives
        are identically zero.
    """
    weights = jnp.array([[1.0], [1.0]])
    stacked = jnp.hstack([x, y])
    return (stacked @ weights).sum()

@jax.jit
def scalar_du_dx(x, y):
    """Scalar sum of du/dx over every entry of x.

    ``jax.grad`` of ``scalar_fn`` with respect to its first argument is
    reduced to a single number. That makes this function differentiable
    again by ``jax.grad`` (see ``du_dxx``).
    """
    grad_wrt_x = jax.grad(scalar_fn, argnums=0)
    return grad_wrt_x(x, y).sum()

@jax.jit
def scalar_du_dy(x, y):
    """Scalar sum of du/dy over every entry of y.

    ``jax.grad`` of ``scalar_fn`` with respect to its second argument is
    reduced to a single number. That makes this function differentiable
    again by ``jax.grad`` (see ``du_dyy``).
    """
    grad_wrt_y = jax.grad(scalar_fn, argnums=1)
    return grad_wrt_y(x, y).sum()

@jax.jit
def du_dxx(x, y):
    """Second derivative d^2u/dx^2 at every sample, with the same shape as x.

    This differentiates the batch-summed first derivative ``scalar_du_dx``
    with respect to x. Each row of ``scalar_fn`` depends only on that row's
    x and y, so cross-sample terms vanish. Entry i of the gradient is
    therefore sample i's own second derivative.
    """
    second_derivative = jax.grad(scalar_du_dx, argnums=0)
    return second_derivative(x, y)

@jax.jit
def du_dyy(x, y):
    """Second derivative d^2u/dy^2 at every sample, with the same shape as y.

    This differentiates the batch-summed first derivative ``scalar_du_dy``
    with respect to y. Rows of ``scalar_fn`` do not couple, so entry i of
    the gradient is sample i's own second derivative.
    """
    second_derivative = jax.grad(scalar_du_dy, argnums=1)
    return second_derivative(x, y)

# Deterministic sample points for the benchmarks below (fixed seed 0).
# One root key is split three ways: the first piece is kept as `key` for later
# use, and the other two seed the x and y samples.
key, *subkeys = random.split(random.PRNGKey(0), 3)
x_key, y_key = subkeys
# 100 samples each, as (100, 1) columns. jax.random.uniform defaults to [0, 1),
# not the problem domain [-pi, pi]^2; that is fine for timing purposes.
x = random.uniform(x_key, shape=(100, 1))
y = random.uniform(y_key, shape=(100, 1))

In [6]:
fn = jax.vmap(scalar_fn, in_axes = (0, 0))
%timeit fn(x, y).block_until_ready()

du_dx = jax.vmap(scalar_du_dx, in_axes = (0, 0))
%timeit du_dx(x, y).block_until_ready()

%timeit du_dxx(x, y).block_until_ready()

183 µs ± 188 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
343 µs ± 22.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


AttributeError: 'numpy.ndarray' object has no attribute 'block_until_ready'

In [None]:
@jax.jit
def scalar_fn(x, y):
    """Toy linear test function: the sum over all rows of x + y.

    This is the same body as the earlier cell's ``scalar_fn`` and shadows that
    definition once this cell runs.

    Rows of ``hstack([x, y])`` are projected onto an all-ones column and then
    summed to a 0-d array, so the result can be differentiated by ``jax.grad``.
    x and y are presumably (n, 1) columns or (1,) per-sample slices; shapes
    are not validated.
    """
    ones_column = jnp.array([[1.0], [1.0]])
    xy = jnp.hstack([x, y])
    projected = jnp.matmul(xy, ones_column)
    return projected.sum()

@jax.jit
def du_dx(x, y):
    """Per-sample du/dx: grad of ``scalar_fn`` w.r.t. x, mapped over axis 0.

    The output has x's shape. This replaces any earlier ``du_dx`` bound in
    the notebook.
    """
    per_sample_grad = jax.grad(scalar_fn, argnums=0)
    batched_grad = jax.vmap(per_sample_grad)
    return batched_grad(x, y)

@jax.jit
def du_dy(x, y):
    """Per-sample du/dy: grad of ``scalar_fn`` w.r.t. y, mapped over axis 0.

    The output has y's shape.
    """
    per_sample_grad = jax.grad(scalar_fn, argnums=1)
    batched_grad = jax.vmap(per_sample_grad)
    return batched_grad(x, y)

@jax.jit
def scalar_du_dx(x, y):
    """Sum over samples of the per-sample du/dx, giving a scalar for ``jax.grad``.

    This redefines the earlier ``scalar_du_dx``. Here the first derivative is
    taken per sample via ``jax.vmap`` before the reduction.
    """
    per_sample = jax.vmap(jax.grad(scalar_fn, argnums=0))(x, y)
    return per_sample.sum()

@jax.jit
def scalar_du_dy(x, y):
    """Sum over samples of the per-sample du/dy, giving a scalar for ``jax.grad``.

    This redefines the earlier ``scalar_du_dy``. Here the first derivative is
    taken per sample via ``jax.vmap`` before the reduction.
    """
    per_sample = jax.vmap(jax.grad(scalar_fn, argnums=1))(x, y)
    return per_sample.sum()

@jax.jit
def du_dxx(x, y):
    """Per-sample d^2u/dx^2, with the same shape as x.

    This is the gradient w.r.t. x of the batch-summed first derivative. Rows
    do not couple in ``scalar_fn``, so each entry is that sample's own second
    derivative.
    """
    grad_of_summed = jax.grad(scalar_du_dx, argnums=0)
    return grad_of_summed(x, y)

@jax.jit
def du_dyy(x, y):
    """Per-sample d^2u/dy^2, with the same shape as y.

    This is the gradient w.r.t. y of the batch-summed first derivative. Rows
    do not couple in ``scalar_fn``, so each entry is that sample's own second
    derivative.
    """
    grad_of_summed = jax.grad(scalar_du_dy, argnums=1)
    return grad_of_summed(x, y)

In [None]:
%timeit fn(x, y).block_until_ready()

%timeit du_dx(x, y).block_until_ready()

%timeit du_dxx(x, y).block_until_ready()