# 2d poissson equation
---
$$
\begin{aligned}
&\frac{d^2 u}{dx^2} + \frac{d^2 u}{dy^2} = -2\sin(x)\sin(y), \ x, y \in [-\pi, \pi]^2, \\
&u(x, -\pi) = u(x, \pi) = u(-\pi, y) = u(\pi, y) = 0.
\end{aligned}
$$

Solution:
$$
u(x) = \sin(x)\sin(y).
$$

In [1]:
import flax, flax.nn
from flax import jax_utils, optim
from flax.training import lr_schedule

import jax, jax.nn
from jax import random
import jax.numpy as jnp
from jax.experimental import optimizers

import sys
sys.path.append("../../")
	
from Seismic_wave_inversion_PINN.tf_model_utils import *
from Seismic_wave_inversion_PINN.data_utils import *

In [2]:
def random_layer_params(m, n, key, scale=1e-2):
	w_key, b_key = random.split(key)
	return jax.nn.initializers.glorot_uniform()(w_key, (n, m)), jax.nn.initializers.zeros(b_key, (n, ))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(key, layers):
	keys = random.split(key, len(layers))
	return [random_layer_params(m, n, k) for m, n, k in zip(layers[:-1], layers[1:], keys)]

layers = [2, 128, 128, 128, 128, 1]
params = init_network_params(random.PRNGKey(0), layers)

@jax.jit
def scalar_model(params, x, y):
	x_ = jnp.hstack([x, y])
	for w, b in params[:-1]:
		x_ = jnp.tanh(jnp.dot(w, x_) + b)
	return jnp.sum(jnp.dot(params[-1][0], x_) + params[-1][1])

model = jax.jit(jax.vmap(scalar_model, in_axes = (None, 0, 0)))

In [3]:
@jax.jit
def mse(pred, true):
	return jnp.mean(jnp.square(pred - true))

@jax.jit
def loss_fn(params, batch):

	@jax.jit
	def scalar_du_dx(x, y):
		return jnp.sum(jax.grad(scalar_model, 1)(params, x, y))

	@jax.jit
	def scalar_du_dy(x, y):
		return jnp.sum(jax.grad(scalar_model, 2)(params, x, y))

	@jax.jit
	def scalar_du_dxx(x, y):
		return jnp.sum(jax.grad(scalar_du_dx, 0)(x, y))

	@jax.jit
	def scalar_du_dyy(x, y):
		return jnp.sum(jax.grad(scalar_du_dy, 1)(x, y))

	du_dxx = jax.vmap(scalar_du_dxx, in_axes = (0, 0))
	du_dyy = jax.vmap(scalar_du_dyy, in_axes = (0, 0))

	u_b = model(params, batch["b"]["x"], batch["b"]["y"])
	loss_c = mse(du_dxx(batch["c"]["x"], batch["c"]["y"]).reshape((-1, 1)) + du_dyy(batch["c"]["x"], batch["c"]["y"]).reshape((-1, 1)), -2*jnp.sin(batch["c"]["x"])*jnp.sin(batch["c"]["y"]))
	loss_br = mse(du_dxx(batch["b"]["x"], batch["b"]["y"]).reshape((-1, 1)) + du_dyy(batch["b"]["x"], batch["b"]["y"]).reshape((-1, 1)), -2*jnp.sin(batch["b"]["x"])*jnp.sin(batch["b"]["y"]))
	loss_bv = mse(u_b, batch["b"]["u"])
	loss = loss_c + loss_br + loss_bv
	return loss

# @jax.jit
# def step(i, opt_state, opt_update, get_params, batch):
# 	params = get_params(opt_state)
# 	grad = jax.grad(loss_fn, 0)(params, batch)
# 	return opt_update(i, grad, opt_state)

step_size = 1e-4
@jax.jit
def update(params, batch):
	grads = jax.grad(loss_fn, 0)(params, batch)
	return [(w-step_size*dw, b-step_size*db) for (w, b), (dw, db) in zip(params, grads)]

@jax.jit
def evaluate(params, batch):
	@jax.jit
	def scalar_du_dx(x, y):
		return jnp.sum(jax.grad(scalar_model, 1)(params, x, y))

	@jax.jit
	def scalar_du_dy(x, y):
		return jnp.sum(jax.grad(scalar_model, 2)(params, x, y))

	@jax.jit
	def scalar_du_dxx(x, y):
		return jnp.sum(jax.grad(scalar_du_dx, 0)(x, y))

	@jax.jit
	def scalar_du_dyy(x, y):
		return jnp.sum(jax.grad(scalar_du_dy, 1)(x, y))

	du_dxx = jax.vmap(scalar_du_dxx, in_axes = (0, 0))
	du_dyy = jax.vmap(scalar_du_dyy, in_axes = (0, 0))

	u_b = model(params, batch["b"]["x"], batch["b"]["y"])
	loss_c = mse(du_dxx(batch["c"]["x"], batch["c"]["y"]).reshape((-1, 1)) + du_dyy(batch["c"]["x"], batch["c"]["y"]).reshape((-1, 1)), -2*jnp.sin(batch["c"]["x"])*jnp.sin(batch["c"]["y"]))
	loss_br = mse(du_dxx(batch["b"]["x"], batch["b"]["y"]).reshape((-1, 1)) + du_dyy(batch["b"]["x"], batch["b"]["y"]).reshape((-1, 1)), -2*jnp.sin(batch["b"]["x"])*jnp.sin(batch["b"]["y"]))
	loss_bv = mse(u_b, batch["b"]["u"])
	loss = loss_c + loss_br + loss_bv
	return loss, loss_c, loss_br, loss_bv

In [4]:
key = random.PRNGKey(0)
key, *subkeys = random.split(key, 3)
n_c = 10000
x_c = random.uniform(subkeys[0], (n_c, 1), minval = -jnp.pi, maxval = jnp.pi)
y_c = random.uniform(subkeys[1], (n_c, 1), minval = -jnp.pi, maxval = jnp.pi)

key, *subkeys = random.split(key, 5)
n_b = 100
x_b = jnp.vstack([random.uniform(subkeys[0], (n_b, 1), minval = -jnp.pi, maxval = jnp.pi),
				  jnp.ones((n_b, 1))*jnp.pi,
				  random.uniform(subkeys[1], (n_b, 1), minval = -jnp.pi, maxval = jnp.pi),
				  jnp.ones((n_b, 1))*-jnp.pi])
y_b = jnp.vstack([jnp.ones((n_b, 1))*-jnp.pi,
				  random.uniform(subkeys[2], (n_b, 1), minval = -jnp.pi, maxval = jnp.pi),
				  jnp.ones((n_b, 1))*jnp.pi,
				  random.uniform(subkeys[3], (n_b, 1), minval = -jnp.pi, maxval = jnp.pi)])
f = lambda x, y: jnp.sin(x)*jnp.sin(y)
u_b = f(x_b, y_b)
dataset = {"c": {"x": x_c, "y": y_c},
		   "b": {"x": x_b, "y": y_b, "u": u_b}}

In [5]:
# opt_init, opt_update, get_params = optimizers.adam(1e-4)
# opt_state = opt_init(initial_params)
for iteration in range(1, 10001):
# 	params = get_params(opt_state)
# 	grad = jax.grad(loss_fn, 0)(params, dataset)
# 	opt_state = opt_update(iteration, grad, opt_state)
	params = update(params, dataset)
	if iteration % 1000 == 0:
		names = ("Loss", "c", "br", "bv")
# 		params = get_params(opt_state)
		print("{}, Iteration: {},".format(get_time(), iteration) + \
			  ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(names, evaluate(params, dataset))]))

2020/07/21, 23:31:34, Iteration: 1000, Loss: 9.8784e-01, c: 9.8472e-01, br: 8.9890e-04, bv: 2.2270e-03
2020/07/21, 23:31:41, Iteration: 2000, Loss: 9.8421e-01, c: 9.8184e-01, br: 6.8618e-04, bv: 1.6816e-03
2020/07/21, 23:31:47, Iteration: 3000, Loss: 9.8094e-01, c: 9.7904e-01, br: 5.4876e-04, bv: 1.3502e-03
2020/07/21, 23:31:54, Iteration: 4000, Loss: 9.7783e-01, c: 9.7619e-01, br: 4.6265e-04, bv: 1.1752e-03
2020/07/21, 23:32:00, Iteration: 5000, Loss: 9.7472e-01, c: 9.7318e-01, br: 4.1353e-04, bv: 1.1231e-03
2020/07/21, 23:32:07, Iteration: 6000, Loss: 9.7149e-01, c: 9.6992e-01, br: 3.9297e-04, bv: 1.1762e-03
2020/07/21, 23:32:14, Iteration: 7000, Loss: 9.6803e-01, c: 9.6630e-01, br: 3.9655e-04, bv: 1.3294e-03
2020/07/21, 23:32:20, Iteration: 8000, Loss: 9.6421e-01, c: 9.6220e-01, br: 4.2281e-04, bv: 1.5878e-03
2020/07/21, 23:32:27, Iteration: 9000, Loss: 9.5992e-01, c: 9.5748e-01, br: 4.7267e-04, bv: 1.9667e-03
2020/07/21, 23:32:34, Iteration: 10000, Loss: 9.5500e-01, c: 9.5196e-01, 

In [6]:
x_test = jnp.linspace(-jnp.pi, jnp.pi, 100).reshape((-1, 1))
y_test = x_test
xy_test = tensor_grid([x_test, y_test])

u_pred = optimizer.target(xy_test)
u_test = f(xy_test[:, 0:1], xy_test[:, 1:2])


import matplotlib as mpl
from matplotlib.cm import cool
from matplotlib.colors import Normalize

from mpl_toolkits.axes_grid1 import make_axes_locatable

from matplotlib import pyplot as plt
plt.rcParams.update(plt.rcParamsDefault)

cmap = cool
norm = Normalize(vmin=-1.0, vmax=1.0)

X, Y = np.meshgrid(x_test, y_test)
fig, ax = plt.subplots(1, 3, figsize = (15, 5))
fig.subplots_adjust(right = 1.0)

im0 = ax[0].contourf(X, Y, u_test.reshape((len(y_test), len(x_test))), cmap = cmap, norm = norm)
ax[0].set_title("true")
divider = make_axes_locatable(ax[0])
cax = divider.append_axes('right', size='5%', pad=0.05)
mpl.colorbar.ColorbarBase(cax, cmap = cmap, norm = norm, orientation='vertical')

norm = mpl.colors.Normalize(vmin=-1.0, vmax=1.0)
im1 = ax[1].contourf(X, Y, u_pred.reshape((len(y_test), len(x_test))), cmap = cmap, norm = norm)
ax[1].set_title("pred")
ax[1].set_xlabel("")
divider = make_axes_locatable(ax[1])
cax = divider.append_axes('right', size='5%', pad=0.05)
mpl.colorbar.ColorbarBase(cax, cmap = cmap, norm = norm, orientation='vertical')

norm = mpl.colors.Normalize(vmin=-1e-2, vmax=1e-2)
im2 = ax[2].contourf(X, Y, u_test.reshape((len(y_test), len(x_test)))-u_pred.reshape((len(y_test), len(x_test))), cmap = cmap, norm = norm)
ax[2].set_title("MSE: {}".format(np.mean(np.square(u_test - u_pred))))
divider = make_axes_locatable(ax[2])
cax = divider.append_axes('right', size='5%', pad=0.05)
mpl.colorbar.ColorbarBase(cax, cmap = cmap, norm = norm, orientation='vertical')

plt.show()

NameError: name 'optimizer' is not defined