# 2D Acoustic Wave
---
- Model Setup: [this link](https://github.com/devitocodes/devito/blob/master/examples/seismic/tutorials/01_modelling.ipynb)

- target: $c(x, z)^2$.

- rescaling: $x' = x/1000, z' = z/1000, t' = t/1000$.

---

# Direct Problem only,  without Source

In [1]:
NAME = "0722_direct_problem_without_source"

In [2]:
import jax, jax.nn
from jax import random
import jax.numpy as jnp
from jax.experimental import optimizers

import sys, os
sys.path.append("../../")
	
from Seismic_wave_inversion_PINN.data_utils import *

from collections import namedtuple

In [3]:
def siren_layer_params(key, scale, m, n):
	w_key, b_key = random.split(key)
	return random.uniform(w_key, (m, n), jnp.float32, minval = -scale, maxval = scale), jnp.zeros((n, ), jnp.float32)

def init_siren_params(key, layers, c, w0):
	keys = random.split(key, len(layers))
	return [siren_layer_params(keys[0], w0*jnp.sqrt(c/layers[0]), layers[0], layers[1])] + \
			[siren_layer_params(k, jnp.sqrt(c/m), m, n) for m, n, k in zip(layers[1:-1], layers[2:], keys[1:])]

layers = [3, 512, 512, 512, 512, 512, 1] # (x, z, t) -> p
c = 6.0
w0 = 30.0
lambda_0 = 1e-5
direct_params = init_siren_params(random.PRNGKey(0), layers, c, w0)

inverse_NAME = "0722_pretrain_inverse_problem"
inverse_iteration = 1000000
inverse_params = np.load("models/{}/inverse_model/iteration_{}/params.npy".format(inverse_NAME, inverse_iteration), allow_pickle=True)
inverse_params = [[jnp.asarray(arr) for arr in Arr] for Arr in inverse_params]

@jax.jit
def scalar_direct_model(params, x, z, t):
	x_ = jnp.hstack([x, z, t])
	for w, b in params[:-1]:
		x_ = jnp.sin(jnp.dot(x_, w) + b)
	return jnp.sum(jnp.dot(x_, params[-1][0]) + params[-1][1])

@jax.jit
def scalar_inverse_model(params, x, z):
	x_ = jnp.hstack([x, z])
	for w, b in params[:-1]:
		x_ = jnp.sin(jnp.dot(x_, w) + b)
	return jnp.sum(jnp.dot(x_, params[-1][0]) + params[-1][1])

direct_model = jax.jit(jax.vmap(scalar_direct_model, in_axes = (None, 0, 0, 0)))
inverse_model = jax.jit(jax.vmap(scalar_inverse_model, in_axes = (None, 0, 0)))

In [4]:
@jax.jit
def mse(pred, true):
	return jnp.mean(jnp.square(pred - true))

@jax.jit
def l2_regularization(params, lambda_0):
	res = 0
	for p in params:
		res += jnp.sum(jnp.square(p[0]))
	return res*lambda_0

@jax.jit
def scalar_dp_dx(params, x, z, t):
    return jnp.sum(jax.grad(scalar_direct_model, 1)(params, x, z, t))

@jax.jit
def scalar_dp_dz(params, x, z, t):
    return jnp.sum(jax.grad(scalar_direct_model, 2)(params, x, z, t))

@jax.jit
def scalar_dp_dt(params, x, z, t):
    return jnp.sum(jax.grad(scalar_direct_model, 3)(params, x, z, t))

@jax.jit
def dp_dxx(params, x, z, t):
    return jax.grad(scalar_dp_dx, 1)(params, x, z, t)

@jax.jit
def dp_dzz(params, x, z, t):
    return jax.grad(scalar_dp_dz, 2)(params, x, z, t)

@jax.jit
def dp_dtt(params, x, z, t):
    return jax.grad(scalar_dp_dt, 3)(params, x, z, t)

@jax.jit
def loss_fn_(params, batch):
# 	direct_params, inverse_params = params
# 	collocation, dirichlet = batch["collocation"], batch["dirichlet"]
	direct_params = params
	dirichlet = batch["dirichlet"]
    
# 	c = inverse_model(inverse_params, collocation.x, collocation.z)
# 	dp_dtt_ = dp_dtt(direct_params, collocation.x, collocation.z, collocation.t)
# 	dp_dxx_ = dp_dxx(direct_params, collocation.x, collocation.z, collocation.t)
# 	dp_dzz_ = dp_dzz(direct_params, collocation.x, collocation.z, collocation.t)
	p_pred = direct_model(direct_params, dirichlet.x, dirichlet.z, dirichlet.t).reshape((-1, 1))
	
# 	loss_c = mse(dp_dtt_ - c**2*(dp_dxx_ + dp_dzz_), 0)
	loss_d = mse(p_pred, dirichlet.p)
# 	return loss_c, loss_d
	return loss_d

@jax.jit
def loss_fn(params, batch):
# 	loss_c, loss_d = loss_fn_(params, batch)
# 	return w_c*loss_c + w_d*loss_d + l2_regularization(params[0], lambda_0) + l2_regularization(params[1], lambda_0)
	loss_d = loss_fn_(params, batch)
	return w_d*loss_d + l2_regularization(params, lambda_0)

@jax.jit
def step(i, opt_state, batch):
	params = get_params(opt_state)
	grad = jax.grad(loss_fn, 0)(params, batch)
	return opt_update(i, grad, opt_state)

@jax.jit
def evaluate(params, batch):
# 	loss_c, loss_d = loss_fn_(params, batch)
# 	return w_c*loss_c + w_d*loss_d, loss_c, loss_d
	return loss_fn_(params, batch)

In [5]:
dataset_Dirichlet = namedtuple("dataset_Dirichlet", ["x", "z", "t", "p"])
dataset_Collocation = namedtuple("dataset_Collocation", ["x", "z", "t"])

x0 = z0 = t0 = 1000

domain = np.array([0.0, 1000.0]) / x0
T_max = 1000.0 / t0
x_s = 500.0 / x0
z_s = 20.0 / z0

import pickle
with open("../9_2020-07-14-devito/dataset_single_source.pkl", "rb") as file:
	[x, t, p, _, _] = pickle.load(file)
t_index = (t >= 200)
t_ = t[t_index] / t0
x /= x0
	
txz_d = tensor_grid([t_, x, [z_s]])
t_d, x_d, z_d = txz_d[:, 0:1], txz_d[:, 1:2], txz_d[:, 2:3]
p_d = p[t_index, :].reshape((-1, 1))

n_cx = n_cz = n_ct = 100000
x_c, z_c, t_c = np.linspace(*domain, n_cx).reshape((-1, 1)), np.linspace(*domain, n_cz).reshape((-1, 1)), np.linspace(0, T_max, n_ct).reshape((-1, 1))

collocation = dataset_Collocation(*(map(lambda x: jnp.array(x), [np.vstack([x_c, x_d]), np.vstack([z_c, z_d]), np.vstack([t_c, t_d])])))
dirichlet = dataset_Dirichlet(*map(lambda x: jnp.array(x), [x_d, z_d, t_d, p_d]))

class Batch_Generator:
	def __init__(self, dataset, batch_size):
		self.dataset = dataset
		self.batch_size = batch_size
		self.index = np.arange(dataset[0].shape[0])
		np.random.shuffle(self.index)
		self.pointer = 0
		
	def __iter__(self):
		return self
	
	def __next__(self):
		if self.pointer >= len(self.index):
			np.random.shuffle(self.index)
			self.pointer = 0
		self.pointer += self.batch_size
		return [d[self.pointer-self.batch_size:self.pointer, :] for d in self.dataset]

In [None]:
lr = 1e-4
iterations = 10000
print_every = 1000
save_every = 10000
batch_size_collocation = 10000
batch_size_dirichlet = 10000
w_c = 1.0
w_d = 1.0

Collocation = Batch_Generator(collocation, batch_size_collocation)
Dirichlet = Batch_Generator(dirichlet, batch_size_dirichlet)
# params = [direct_params, inverse_params]
params = direct_params

opt_init, opt_update, get_params = optimizers.adam(lr)
opt_state = opt_init(params)
for iteration in range(1, iterations+1):
	batch = {
		"dirichlet": dataset_Dirichlet(*next(Dirichlet)),
		"collocation": dataset_Collocation(*next(Collocation))
	}
	opt_state = step(iteration, opt_state, batch)
	if iteration % print_every == 0:
# 		names = ["Loss", "collocation", "dirichlet"]
		names = ["Loss"]
		params_ = get_params(opt_state)
		losses = [evaluate(params_, batch)]
		print("{}, Iteration: {}, Train".format(get_time(), iteration) + \
			  ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(names, losses)]))
	if iteration % save_every == 0:
		params_ = np.asarray(get_params(opt_state), dtype = object)
		save_path = "models/{}/iteration_{}/params.npy".format(NAME, iteration)
		if not os.path.exists(os.path.dirname(save_path)):
			os.makedirs(os.path.dirname(save_path))
		np.save(save_path, params_)

2020/07/24, 13:41:01, Iteration: 1000, Train Loss: 2.1913e-04
2020/07/24, 13:41:08, Iteration: 2000, Train Loss: 4.7827e-04
2020/07/24, 13:41:15, Iteration: 3000, Train Loss: 3.6159e-04
2020/07/24, 13:41:21, Iteration: 4000, Train Loss: 5.5528e-04
2020/07/24, 13:41:28, Iteration: 5000, Train Loss: 1.4201e-04
2020/07/24, 13:41:35, Iteration: 6000, Train Loss: 2.9216e-04
2020/07/24, 13:41:42, Iteration: 7000, Train Loss: 3.5833e-04


In [None]:
p_pred = direct_model(get_params(opt_state), dirichlet.x, dirichlet.z, dirichlet.t).reshape((len(t_), len(x)))
p_true = dirichlet.p.reshape((len(t_), len(x)))

In [None]:
import matplotlib as mpl
from matplotlib.cm import cool
from matplotlib.colors import Normalize

from mpl_toolkits.axes_grid1 import make_axes_locatable

from matplotlib import pyplot as plt
plt.rcParams.update(plt.rcParamsDefault)

cmap = cool
norm = Normalize(vmin=-5.0, vmax=8.0)

X, T = np.meshgrid(x, t_)
fig, ax = plt.subplots(1, 3, figsize = (15, 5))
fig.subplots_adjust(right = 1.0)

im0 = ax[0].contourf(X, T, p_true, cmap = cmap, norm = norm, levels = 1000)
ax[0].set_title("true")
divider = make_axes_locatable(ax[0])
cax = divider.append_axes('right', size='5%', pad=0.05)
mpl.colorbar.ColorbarBase(cax, cmap = cmap, norm = norm, orientation='vertical')

# norm = mpl.colors.Normalize(vmin=1.0, vmax=3.0)
im1 = ax[1].contourf(X, T, p_pred, cmap = cmap, norm = norm, levels = 1000)
ax[1].set_title("pred")
ax[1].set_xlabel("")
divider = make_axes_locatable(ax[1])
cax = divider.append_axes('right', size='5%', pad=0.05)
mpl.colorbar.ColorbarBase(cax, cmap = cmap, norm = norm, orientation='vertical')

norm = mpl.colors.Normalize(vmin=-1e-1, vmax=1e-1)
im2 = ax[2].contourf(X, T, p_true - p_pred, cmap = cmap, norm = norm, levels = 1000)
ax[2].set_title("MSE: {}".format(np.mean(np.square(p_true - p_pred))))
divider = make_axes_locatable(ax[2])
cax = divider.append_axes('right', size='5%', pad=0.05)
mpl.colorbar.ColorbarBase(cax, cmap = cmap, norm = norm, orientation='vertical')

plt.show()

In [None]:
# domain = np.array([[0.0, 0.0], [1.0, 1.0]])
# def c_fn(x, z):
# 	return np.piecewise(z, [z >= 0.5, z < 0.5], [2.5, 1.5])

# x_test = np.linspace(domain[0, 0], domain[1, 0], 100).reshape((-1, 1))
# z_test = np.linspace(domain[0, 1], domain[1, 1], 100).reshape((-1, 1))
# xz_test = tensor_grid([x_test, z_test])
# c_test = c_fn(xz_test[:, 0:1], xz_test[:, 1:2])



In [None]:
# c_pred = inverse_model(get_params(opt_state)[1], xz_test[:, 0:1], xz_test[:, 1:2]).reshape((-1, 1))

import matplotlib as mpl
from matplotlib.cm import cool
from matplotlib.colors import Normalize

from mpl_toolkits.axes_grid1 import make_axes_locatable

from matplotlib import pyplot as plt
plt.rcParams.update(plt.rcParamsDefault)

cmap = cool
norm = Normalize(vmin=1.0, vmax=3.0)

X, Z = np.meshgrid(x_test, z_test)
fig, ax = plt.subplots(1, 3, figsize = (15, 5))
fig.subplots_adjust(right = 1.0)

im0 = ax[0].contourf(X, Z, c_test.reshape((len(z_test), len(x_test))), cmap = cmap, norm = norm, levels = 1000)
ax[0].set_title("true")
divider = make_axes_locatable(ax[0])
cax = divider.append_axes('right', size='5%', pad=0.05)
mpl.colorbar.ColorbarBase(cax, cmap = cmap, norm = norm, orientation='vertical')

norm = mpl.colors.Normalize(vmin=1.0, vmax=3.0)
im1 = ax[1].contourf(X, Z, c_pred.reshape((len(z_test), len(x_test))), cmap = cmap, norm = norm, levels = 1000)
ax[1].set_title("pred")
ax[1].set_xlabel("")
divider = make_axes_locatable(ax[1])
cax = divider.append_axes('right', size='5%', pad=0.05)
mpl.colorbar.ColorbarBase(cax, cmap = cmap, norm = norm, orientation='vertical')

norm = mpl.colors.Normalize(vmin=-1e-1, vmax=1e-1)
im2 = ax[2].contourf(X, Z, c_test.reshape((len(z_test), len(x_test)))-c_pred.reshape((len(z_test), len(x_test))), cmap = cmap, norm = norm, levels = 1000)
ax[2].set_title("MSE: {}".format(np.mean(np.square(c_test - c_pred))))
divider = make_axes_locatable(ax[2])
cax = divider.append_axes('right', size='5%', pad=0.05)
mpl.colorbar.ColorbarBase(cax, cmap = cmap, norm = norm, orientation='vertical')

plt.show()