# 2d poissson equation
---
$$
\begin{aligned}
&\frac{d^2 u}{dx^2} + \frac{d^2 u}{dy^2} = -2\sin(x)\sin(y), \ x, y \in [-\pi, \pi]^2, \\
&u(x, -\pi) = u(x, \pi) = u(-\pi, y) = u(\pi, y) = 0.
\end{aligned}
$$

Solution:
$$
u(x) = \sin(x)\sin(y).
$$

In [1]:
import flax, flax.nn
from flax import jax_utils, optim
from flax.training import lr_schedule

import jax, jax.nn
from jax import random
import jax.numpy as jnp
from jax.experimental import optimizers

import sys
sys.path.append("../../")
	
from Seismic_wave_inversion_PINN.tf_model_utils import *
from Seismic_wave_inversion_PINN.data_utils import *

In [2]:
def random_layer_params(m, n, key, scale=1e-2):
	w_key, b_key = random.split(key)
	return jax.nn.initializers.glorot_uniform()(w_key, (m, n)), jax.nn.initializers.zeros(b_key, (n, ))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(key, layers):
	keys = random.split(key, len(layers))
	return [random_layer_params(m, n, k) for m, n, k in zip(layers[:-1], layers[1:], keys)]

layers = [2, 128, 128, 128, 128, 1]
params = init_network_params(random.PRNGKey(0), layers)

@jax.jit
def scalar_model(params, x, y):
	x_ = jnp.hstack([x, y])
	for w, b in params[:-1]:
		x_ = jnp.tanh(jnp.dot(x_, w) + b)
	return jnp.sum(jnp.dot(x_, params[-1][0]) + params[-1][1])

model = jax.jit(jax.vmap(scalar_model, in_axes = (None, 0, 0)))

In [3]:
@jax.jit
def mse(pred, true):
	return jnp.mean(jnp.square(pred - true))

@jax.jit
def scalar_du_dx(params, x, y):
	return jnp.sum(jax.grad(scalar_model, 1)(params, x, y))

@jax.jit
def scalar_du_dy(params, x, y):
	return jnp.sum(jax.grad(scalar_model, 2)(params, x, y))

@jax.jit
def du_dxx(params, x, y):
	return jax.grad(scalar_du_dx, 1)(params, x, y)

@jax.jit
def du_dyy(params, x, y):
	return jax.grad(scalar_du_dy, 2)(params, x, y)

@jax.jit
def loss_fn_(params, batch):
	u_b = model(params, batch["b"]["x"], batch["b"]["y"])
	loss_c = mse(du_dxx(params, batch["c"]["x"], batch["c"]["y"]).reshape((-1, 1)) + du_dyy(params, batch["c"]["x"], batch["c"]["y"]).reshape((-1, 1)), -2*jnp.sin(batch["c"]["x"])*jnp.sin(batch["c"]["y"]))
	loss_br = mse(du_dxx(params, batch["b"]["x"], batch["b"]["y"]).reshape((-1, 1)) + du_dyy(params, batch["b"]["x"], batch["b"]["y"]).reshape((-1, 1)), -2*jnp.sin(batch["b"]["x"])*jnp.sin(batch["b"]["y"]))
	loss_bv = mse(u_b, batch["b"]["u"])
	return loss_c, loss_br, loss_bv

@jax.jit
def loss_fn(params, batch):
	loss_c, loss_br, loss_bv = loss_fn_(params, batch)
	return loss_c + loss_br + loss_bv

def step(i, opt_state, opt_update, get_params, batch):
	params = get_params(opt_state)
	grad = jax.grad(loss_fn, 0)(params, batch)
	return opt_update(i, grad, opt_state)

step_size = 1e-4
@jax.jit
def update(params, batch):
	grads = jax.grad(loss_fn, 0)(params, batch)
	return [(w-step_size*dw, b-step_size*db) for (w, b), (dw, db) in zip(params, grads)]

@jax.jit
def evaluate(params, batch):
	loss_c, loss_br, loss_bv = loss_fn_(params, batch)
	return loss_c + loss_br + loss_bv, loss_c, loss_br, loss_bv

In [4]:
key = random.PRNGKey(0)
key, *subkeys = random.split(key, 3)
n_c = 10000
x_c = random.uniform(subkeys[0], (n_c, 1), minval = -jnp.pi, maxval = jnp.pi)
y_c = random.uniform(subkeys[1], (n_c, 1), minval = -jnp.pi, maxval = jnp.pi)

key, *subkeys = random.split(key, 5)
n_b = 100
x_b = jnp.vstack([random.uniform(subkeys[0], (n_b, 1), minval = -jnp.pi, maxval = jnp.pi),
				  jnp.ones((n_b, 1))*jnp.pi,
				  random.uniform(subkeys[1], (n_b, 1), minval = -jnp.pi, maxval = jnp.pi),
				  jnp.ones((n_b, 1))*-jnp.pi])
y_b = jnp.vstack([jnp.ones((n_b, 1))*-jnp.pi,
				  random.uniform(subkeys[2], (n_b, 1), minval = -jnp.pi, maxval = jnp.pi),
				  jnp.ones((n_b, 1))*jnp.pi,
				  random.uniform(subkeys[3], (n_b, 1), minval = -jnp.pi, maxval = jnp.pi)])
f = lambda x, y: jnp.sin(x)*jnp.sin(y)
u_b = f(x_b, y_b)
dataset = {"c": {"x": x_c, "y": y_c},
		   "b": {"x": x_b, "y": y_b, "u": u_b}}

In [5]:
# opt_init, opt_update, get_params = optimizers.adam(1e-4)
# opt_state = opt_init(params)
for iteration in range(1, 10001):
# 	params = get_params(opt_state)
	grad = jax.grad(loss_fn, 0)(params, dataset)
# 	opt_state = opt_update(iteration, grad, opt_state)
# 	opt_state = step(iteration, opt_state, opt_update, get_params, dataset)
	params = update(params, dataset)
	if iteration % 1000 == 0:
		names = ("Loss", "c", "br", "bv")
# 		params = get_params(opt_state)
		print("{}, Iteration: {},".format(get_time(), iteration) + \
			  ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(names, evaluate(params, dataset))]))

2020/07/22, 00:55:05, Iteration: 1000, Loss: 9.9058e-01, c: 9.8558e-01, br: 1.2665e-03, bv: 3.7312e-03
2020/07/22, 00:55:23, Iteration: 2000, Loss: 9.8671e-01, c: 9.8277e-01, br: 9.9113e-04, bv: 2.9429e-03
2020/07/22, 00:55:39, Iteration: 3000, Loss: 9.8332e-01, c: 9.8011e-01, br: 7.9855e-04, bv: 2.4036e-03
2020/07/22, 00:55:56, Iteration: 4000, Loss: 9.8020e-01, c: 9.7749e-01, br: 6.6332e-04, bv: 2.0448e-03
2020/07/22, 00:56:12, Iteration: 5000, Loss: 9.7722e-01, c: 9.7482e-01, br: 5.6909e-04, bv: 1.8230e-03
2020/07/22, 00:56:30, Iteration: 6000, Loss: 9.7425e-01, c: 9.7203e-01, br: 5.0541e-04, bv: 1.7112e-03
2020/07/22, 00:56:47, Iteration: 7000, Loss: 9.7120e-01, c: 9.6904e-01, br: 4.6562e-04, bv: 1.6934e-03
2020/07/22, 00:57:05, Iteration: 8000, Loss: 9.6798e-01, c: 9.6577e-01, br: 4.4566e-04, bv: 1.7621e-03
2020/07/22, 00:57:23, Iteration: 9000, Loss: 9.6450e-01, c: 9.6214e-01, br: 4.4340e-04, bv: 1.9167e-03
2020/07/22, 00:57:40, Iteration: 10000, Loss: 9.6067e-01, c: 9.5805e-01, 

In [6]:
%timeit model(params, x_b, y_b).block_until_ready()
%timeit model(params, x_c, y_c).block_until_ready()

du_dx = jax.vmap(scalar_du_dx, in_axes = (None, 0, 0))
%timeit du_dx(params, x_b, y_b).block_until_ready()
%timeit du_dxx(params, x_b, y_b).block_until_ready()

%timeit du_dx(params, x_c, y_c).block_until_ready()
%timeit du_dxx(params, x_c, y_c).block_until_ready()

198 µs ± 393 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
348 µs ± 1.54 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
373 µs ± 664 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
689 µs ± 102 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
657 µs ± 253 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.43 ms ± 142 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
x_test = jnp.linspace(-jnp.pi, jnp.pi, 100).reshape((-1, 1))
y_test = x_test
xy_test = tensor_grid([x_test, y_test])

u_pred = optimizer.target(xy_test)
u_test = f(xy_test[:, 0:1], xy_test[:, 1:2])


import matplotlib as mpl
from matplotlib.cm import cool
from matplotlib.colors import Normalize

from mpl_toolkits.axes_grid1 import make_axes_locatable

from matplotlib import pyplot as plt
plt.rcParams.update(plt.rcParamsDefault)

cmap = cool
norm = Normalize(vmin=-1.0, vmax=1.0)

X, Y = np.meshgrid(x_test, y_test)
fig, ax = plt.subplots(1, 3, figsize = (15, 5))
fig.subplots_adjust(right = 1.0)

im0 = ax[0].contourf(X, Y, u_test.reshape((len(y_test), len(x_test))), cmap = cmap, norm = norm)
ax[0].set_title("true")
divider = make_axes_locatable(ax[0])
cax = divider.append_axes('right', size='5%', pad=0.05)
mpl.colorbar.ColorbarBase(cax, cmap = cmap, norm = norm, orientation='vertical')

norm = mpl.colors.Normalize(vmin=-1.0, vmax=1.0)
im1 = ax[1].contourf(X, Y, u_pred.reshape((len(y_test), len(x_test))), cmap = cmap, norm = norm)
ax[1].set_title("pred")
ax[1].set_xlabel("")
divider = make_axes_locatable(ax[1])
cax = divider.append_axes('right', size='5%', pad=0.05)
mpl.colorbar.ColorbarBase(cax, cmap = cmap, norm = norm, orientation='vertical')

norm = mpl.colors.Normalize(vmin=-1e-2, vmax=1e-2)
im2 = ax[2].contourf(X, Y, u_test.reshape((len(y_test), len(x_test)))-u_pred.reshape((len(y_test), len(x_test))), cmap = cmap, norm = norm)
ax[2].set_title("MSE: {}".format(np.mean(np.square(u_test - u_pred))))
divider = make_axes_locatable(ax[2])
cax = divider.append_axes('right', size='5%', pad=0.05)
mpl.colorbar.ColorbarBase(cax, cmap = cmap, norm = norm, orientation='vertical')

plt.show()

NameError: name 'optimizer' is not defined