In [1]:
NAME = "0806"

In [2]:
import jax, jax.nn
from jax import random
import jax.numpy as jnp
from jax.experimental import optimizers

import sys, os
sys.path.append("../../")
	
from Seismic_wave_inversion_PINN.data_utils import *
from Seismic_wave_inversion_PINN.jax_model import *

from collections import namedtuple

In [3]:
key = random.PRNGKey(0)
key, subkey = random.split(key, 2)

layers = [3] + [512]*8 + [2] # (x, z, t) -> (p)
c0 = 6.0
w0 = 10.0
w1 = 1.0
lambda_0 = 1e-12
direct_params = init_siren_params(subkey, layers, c0, w0, w1)

domain = jnp.array([[-1, -1, 0], [1., 1., 1.]])

@jax.jit
def direct_model(params, xzt):
	for w, b in params[:-1]:
		xzt = jnp.sin(jnp.dot(xzt, w) + b)
	return jnp.dot(xzt, params[-1][0]) + params[-1][1]

jacobian = jacrev_fn(direct_model)
hessian = hessian_fn(direct_model)

from jax import lax

@jax.jit
def scalar_inverse_model(x, z): # scaled to [-1, 1]
	return lax.cond(z[0] >= 0.0, z, lambda z: z*0+2.5, z, lambda z: z*0+1.5)

inverse_model = jax.jit(jax.vmap(scalar_inverse_model, in_axes = (0, 0)))

In [4]:
metaloss = mae

@jax.jit
def loss_fn_(params, batch):
	collocation, dirichlet = batch["collocation"], batch["dirichlet"]
	direct_params = params
    
	c = inverse_model(collocation.x, collocation.z)
	
	# hessian[i] = [
    #				[[dp/dxx, dp/dxz, dp/dxt],
	#                [dp/dxz, dp/dzz, dp/dzt],
	#                [dp/dxt, dp/dzt, dp/dtt]],
	#               ]
	dp_dxxzztt = hessian(direct_params, jnp.hstack([collocation.x, collocation.z, collocation.t]))
	dp_dxx_c, dp_dzz_c, dp_dtt_c = dp_dxxzztt[:, 0, 0, 0], dp_dxxzztt[:, 0, 1, 1], dp_dxxzztt[:, 0, 2, 2]
	
	p_d = direct_model(direct_params, jnp.hstack([dirichlet.x, dirichlet.z, dirichlet.t]))
	
	loss_c = mse(dp_dtt_c - c**2*(dp_dxx_c + dp_dzz_c), 0)
	loss_d = mse(p_d, dirichlet.p)
	return loss_c, loss_d

@jax.jit
def loss_fn(params, batch):
	
	loss_c, loss_d = loss_fn_(params, batch)
	return batch["weights"]["c"]*loss_c + batch["weights"]["d"]*loss_d + l2_regularization(params[0], lambda_0)

@jax.jit
def step(i, opt_state, batch):
	params = get_params(opt_state)
	grad = jax.grad(loss_fn, 0)(params, batch)
	return opt_update(i, grad, opt_state)

@jax.jit
def evaluate(params, batch):
	loss_c, loss_d = loss_fn_(params, batch)
	return batch["weights"]["c"]*loss_c + batch["weights"]["d"]*loss_d, loss_c, loss_d

In [5]:
import pickle 
with open("dataset_single_source.pkl", "rb") as file:
	x, z, t, p = pickle.load(file)
	
tzx = tensor_grid([t, z, x])
x, z, t = tzx[:, 2:3], tzx[:, 1:2], tzx[:, 0:1]
p = p.reshape((-1, 1))

dataset_Dirichlet = namedtuple("dataset_Dirichlet", ["x", "z", "t", "p"])
dirichlet = dataset_Dirichlet(*map(lambda x: jnp.array(x), [x, z, t, p]))

In [6]:
lr = 1e-3
start_iteration = 0
iterations = 100000
print_every = 100
save_every = 10000
batch_size = {"collocation": 1000, "dirichlet": 1000}
weights = {"c": 1e-4, "d": 1.0}

key = random.PRNGKey(1)
Dirichlet = Batch_Generator(key, dirichlet, batch_size["dirichlet"])
params = direct_params

opt_init, opt_update, get_params = optimizers.adam(lr)
opt_state = opt_init(params)

for iteration in range(start_iteration+1, start_iteration+iterations+1):
	diri = dataset_Dirichlet(*next(Dirichlet))
	collo = diri
	batch = {
		"dirichlet": diri,
		"collocation": collo,
		"weights": weights,
	}
	opt_state = step(iteration, opt_state, batch)
	if (iteration-start_iteration) % print_every == 0:
		names = ["Loss", "c", "d"]
		params_ = get_params(opt_state)
		losses = evaluate(params_, batch)
		print("{}, Iteration: {}, Train".format(get_time(), iteration) + \
			  ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(names, losses)]))
	if (iteration-start_iteration) % save_every == 0:
		params_ = np.asarray(get_params(opt_state), dtype = object)
		save_path = "models/{}/iteration_{}/params.npy".format(NAME, iteration)
		if not os.path.exists(os.path.dirname(save_path)):
			os.makedirs(os.path.dirname(save_path))
		np.save(save_path, params_)

2020/08/06, 15:24:05, Iteration: 100, Train Loss: 3.4821e+01, c: 3.3642e+05, d: 1.1786e+00
2020/08/06, 15:24:07, Iteration: 200, Train Loss: 1.0939e+01, c: 9.4121e+04, d: 1.5268e+00
2020/08/06, 15:24:09, Iteration: 300, Train Loss: 6.0872e+00, c: 4.6581e+04, d: 1.4291e+00
2020/08/06, 15:24:11, Iteration: 400, Train Loss: 4.3323e+00, c: 2.9573e+04, d: 1.3750e+00
2020/08/06, 15:24:13, Iteration: 500, Train Loss: 3.6118e+00, c: 1.8055e+04, d: 1.8062e+00
2020/08/06, 15:24:16, Iteration: 600, Train Loss: 2.8147e+00, c: 1.3564e+04, d: 1.4584e+00
2020/08/06, 15:24:19, Iteration: 700, Train Loss: 2.3297e+00, c: 9.9801e+03, d: 1.3317e+00
2020/08/06, 15:24:21, Iteration: 800, Train Loss: 2.1995e+00, c: 8.1102e+03, d: 1.3885e+00
2020/08/06, 15:24:24, Iteration: 900, Train Loss: 2.0210e+00, c: 7.1084e+03, d: 1.3101e+00
2020/08/06, 15:24:27, Iteration: 1000, Train Loss: 1.7730e+00, c: 5.8831e+03, d: 1.1847e+00
2020/08/06, 15:24:30, Iteration: 1100, Train Loss: 1.8542e+00, c: 4.3545e+03, d: 1.4187e+

2020/08/06, 15:27:42, Iteration: 9100, Train Loss: 7.2006e-01, c: 4.8382e+01, d: 7.1522e-01
2020/08/06, 15:27:44, Iteration: 9200, Train Loss: 7.0487e-01, c: 6.3244e+01, d: 6.9854e-01
2020/08/06, 15:27:46, Iteration: 9300, Train Loss: 7.2537e-01, c: 4.5353e+01, d: 7.2083e-01
2020/08/06, 15:27:48, Iteration: 9400, Train Loss: 8.2255e-01, c: 5.6630e+01, d: 8.1689e-01
2020/08/06, 15:27:50, Iteration: 9500, Train Loss: 6.2183e-01, c: 5.6195e+01, d: 6.1621e-01
2020/08/06, 15:27:53, Iteration: 9600, Train Loss: 6.8525e-01, c: 6.9922e+01, d: 6.7826e-01
2020/08/06, 15:27:55, Iteration: 9700, Train Loss: 7.2744e-01, c: 4.4128e+01, d: 7.2303e-01
2020/08/06, 15:27:57, Iteration: 9800, Train Loss: 7.0512e-01, c: 4.0025e+01, d: 7.0112e-01
2020/08/06, 15:27:59, Iteration: 9900, Train Loss: 7.9326e-01, c: 6.2555e+01, d: 7.8701e-01
2020/08/06, 15:28:01, Iteration: 10000, Train Loss: 5.8221e-01, c: 4.8914e+01, d: 5.7732e-01
2020/08/06, 15:28:04, Iteration: 10100, Train Loss: 5.1642e-01, c: 5.1601e+01, 

2020/08/06, 15:30:56, Iteration: 18000, Train Loss: 2.8444e+07, c: 2.8444e+11, d: 1.2768e+00
2020/08/06, 15:30:58, Iteration: 18100, Train Loss: 1.6960e+07, c: 1.6960e+11, d: 1.4276e+00
2020/08/06, 15:31:00, Iteration: 18200, Train Loss: 3.3186e+06, c: 3.3186e+10, d: 1.3720e+00
2020/08/06, 15:31:02, Iteration: 18300, Train Loss: 1.1694e+06, c: 1.1694e+10, d: 1.4802e+00
2020/08/06, 15:31:04, Iteration: 18400, Train Loss: 7.2906e+05, c: 7.2905e+09, d: 1.3813e+00
2020/08/06, 15:31:06, Iteration: 18500, Train Loss: 4.2242e+05, c: 4.2242e+09, d: 1.2540e+00
2020/08/06, 15:31:09, Iteration: 18600, Train Loss: 5.1013e+05, c: 5.1013e+09, d: 1.5007e+00
2020/08/06, 15:31:11, Iteration: 18700, Train Loss: 2.8997e+05, c: 2.8997e+09, d: 1.5035e+00
2020/08/06, 15:31:13, Iteration: 18800, Train Loss: 9.7683e+04, c: 9.7681e+08, d: 1.4297e+00
2020/08/06, 15:31:15, Iteration: 18900, Train Loss: 8.5455e+04, c: 8.5453e+08, d: 1.2595e+00
2020/08/06, 15:31:17, Iteration: 19000, Train Loss: 1.4460e+05, c: 1.4

2020/08/06, 15:34:06, Iteration: 26900, Train Loss: 1.3778e+14, c: 1.3778e+18, d: 1.3895e+00
2020/08/06, 15:34:08, Iteration: 27000, Train Loss: 1.4559e+14, c: 1.4559e+18, d: 1.3890e+00
2020/08/06, 15:34:10, Iteration: 27100, Train Loss: 3.6615e+14, c: 3.6615e+18, d: 1.1344e+00
2020/08/06, 15:34:12, Iteration: 27200, Train Loss: 3.7673e+14, c: 3.7673e+18, d: 1.3272e+00
2020/08/06, 15:34:14, Iteration: 27300, Train Loss: 1.0507e+15, c: 1.0507e+19, d: 1.3269e+00
2020/08/06, 15:34:16, Iteration: 27400, Train Loss: 6.4204e+14, c: 6.4204e+18, d: 1.3618e+00
2020/08/06, 15:34:18, Iteration: 27500, Train Loss: 1.5227e+15, c: 1.5227e+19, d: 1.3910e+00
2020/08/06, 15:34:21, Iteration: 27600, Train Loss: 7.3925e+14, c: 7.3925e+18, d: 1.4952e+00
2020/08/06, 15:34:23, Iteration: 27700, Train Loss: 2.1993e+14, c: 2.1993e+18, d: 1.3466e+00
2020/08/06, 15:34:25, Iteration: 27800, Train Loss: 2.9316e+14, c: 2.9316e+18, d: 1.4416e+00
2020/08/06, 15:34:27, Iteration: 27900, Train Loss: 1.0626e+14, c: 1.0

2020/08/06, 15:37:13, Iteration: 35800, Train Loss: 5.3099e+15, c: 5.3099e+19, d: 1.3925e+00
2020/08/06, 15:37:15, Iteration: 35900, Train Loss: 3.6499e+15, c: 3.6499e+19, d: 1.1079e+00
2020/08/06, 15:37:17, Iteration: 36000, Train Loss: 4.4623e+15, c: 4.4623e+19, d: 1.2823e+00
2020/08/06, 15:37:20, Iteration: 36100, Train Loss: 1.9463e+15, c: 1.9463e+19, d: 1.2707e+00
2020/08/06, 15:37:22, Iteration: 36200, Train Loss: 2.9481e+15, c: 2.9481e+19, d: 1.2547e+00
2020/08/06, 15:37:24, Iteration: 36300, Train Loss: 2.0822e+15, c: 2.0822e+19, d: 1.5282e+00
2020/08/06, 15:37:26, Iteration: 36400, Train Loss: 3.6476e+15, c: 3.6476e+19, d: 1.4122e+00
2020/08/06, 15:37:28, Iteration: 36500, Train Loss: 3.8599e+16, c: 3.8599e+20, d: 1.5057e+00
2020/08/06, 15:37:30, Iteration: 36600, Train Loss: 1.7734e+16, c: 1.7734e+20, d: 1.3159e+00
2020/08/06, 15:37:32, Iteration: 36700, Train Loss: 1.5401e+16, c: 1.5401e+20, d: 1.2221e+00
2020/08/06, 15:37:34, Iteration: 36800, Train Loss: 2.9992e+16, c: 2.9

2020/08/06, 15:40:20, Iteration: 44700, Train Loss: 3.8290e+16, c: 3.8290e+20, d: 1.3948e+00
2020/08/06, 15:40:22, Iteration: 44800, Train Loss: 7.0226e+16, c: 7.0226e+20, d: 1.4878e+00
2020/08/06, 15:40:24, Iteration: 44900, Train Loss: 6.3977e+16, c: 6.3977e+20, d: 1.4717e+00
2020/08/06, 15:40:26, Iteration: 45000, Train Loss: 8.3193e+16, c: 8.3193e+20, d: 1.5979e+00
2020/08/06, 15:40:28, Iteration: 45100, Train Loss: 1.1406e+17, c: 1.1406e+21, d: 1.4104e+00
2020/08/06, 15:40:31, Iteration: 45200, Train Loss: 1.0390e+17, c: 1.0390e+21, d: 1.4251e+00
2020/08/06, 15:40:33, Iteration: 45300, Train Loss: 4.0636e+16, c: 4.0636e+20, d: 1.4472e+00
2020/08/06, 15:40:35, Iteration: 45400, Train Loss: 1.6458e+16, c: 1.6458e+20, d: 1.3510e+00
2020/08/06, 15:40:37, Iteration: 45500, Train Loss: 2.7643e+16, c: 2.7643e+20, d: 1.3912e+00
2020/08/06, 15:40:39, Iteration: 45600, Train Loss: 2.5709e+16, c: 2.5709e+20, d: 1.2235e+00
2020/08/06, 15:40:41, Iteration: 45700, Train Loss: 3.6644e+16, c: 3.6

2020/08/06, 15:43:27, Iteration: 53600, Train Loss: 5.4094e+17, c: 5.4094e+21, d: 1.7491e+00
2020/08/06, 15:43:29, Iteration: 53700, Train Loss: 2.0636e+17, c: 2.0636e+21, d: 1.6666e+00
2020/08/06, 15:43:31, Iteration: 53800, Train Loss: 2.1500e+17, c: 2.1500e+21, d: 1.3826e+00
2020/08/06, 15:43:33, Iteration: 53900, Train Loss: 4.4535e+17, c: 4.4535e+21, d: 1.6611e+00
2020/08/06, 15:43:35, Iteration: 54000, Train Loss: 2.8614e+17, c: 2.8614e+21, d: 1.3734e+00
2020/08/06, 15:43:37, Iteration: 54100, Train Loss: 2.4645e+17, c: 2.4645e+21, d: 1.6443e+00
2020/08/06, 15:43:39, Iteration: 54200, Train Loss: 4.1716e+17, c: 4.1716e+21, d: 1.4877e+00
2020/08/06, 15:43:41, Iteration: 54300, Train Loss: 2.2604e+17, c: 2.2604e+21, d: 1.4211e+00
2020/08/06, 15:43:44, Iteration: 54400, Train Loss: 2.4300e+17, c: 2.4300e+21, d: 1.3701e+00
2020/08/06, 15:43:46, Iteration: 54500, Train Loss: 4.2109e+17, c: 4.2109e+21, d: 1.4374e+00
2020/08/06, 15:43:48, Iteration: 54600, Train Loss: 3.6892e+17, c: 3.6

2020/08/06, 15:46:33, Iteration: 62500, Train Loss: 9.3793e+17, c: 9.3793e+21, d: 1.2931e+00
2020/08/06, 15:46:35, Iteration: 62600, Train Loss: 8.0646e+17, c: 8.0646e+21, d: 1.5336e+00
2020/08/06, 15:46:38, Iteration: 62700, Train Loss: 6.5091e+17, c: 6.5091e+21, d: 1.3755e+00
2020/08/06, 15:46:40, Iteration: 62800, Train Loss: 1.0502e+18, c: 1.0502e+22, d: 1.3898e+00
2020/08/06, 15:46:42, Iteration: 62900, Train Loss: 6.4976e+17, c: 6.4976e+21, d: 1.3637e+00
2020/08/06, 15:46:44, Iteration: 63000, Train Loss: 6.4560e+17, c: 6.4560e+21, d: 1.3857e+00
2020/08/06, 15:46:46, Iteration: 63100, Train Loss: 5.4213e+17, c: 5.4213e+21, d: 1.6032e+00
2020/08/06, 15:46:48, Iteration: 63200, Train Loss: 1.1442e+18, c: 1.1442e+22, d: 1.2220e+00
2020/08/06, 15:46:50, Iteration: 63300, Train Loss: 1.0770e+18, c: 1.0770e+22, d: 1.3765e+00
2020/08/06, 15:46:52, Iteration: 63400, Train Loss: 1.9335e+18, c: 1.9335e+22, d: 1.5115e+00
2020/08/06, 15:46:54, Iteration: 63500, Train Loss: 1.7632e+18, c: 1.7

2020/08/06, 15:49:40, Iteration: 71400, Train Loss: 6.4832e+17, c: 6.4832e+21, d: 1.7763e+00
2020/08/06, 15:49:42, Iteration: 71500, Train Loss: 4.6523e+17, c: 4.6523e+21, d: 1.6029e+00
2020/08/06, 15:49:44, Iteration: 71600, Train Loss: 1.0809e+18, c: 1.0809e+22, d: 1.3620e+00
2020/08/06, 15:49:46, Iteration: 71700, Train Loss: 9.2040e+17, c: 9.2040e+21, d: 1.3362e+00
2020/08/06, 15:49:48, Iteration: 71800, Train Loss: 1.4536e+18, c: 1.4536e+22, d: 1.5914e+00
2020/08/06, 15:49:50, Iteration: 71900, Train Loss: 6.9279e+17, c: 6.9279e+21, d: 1.5720e+00
2020/08/06, 15:49:52, Iteration: 72000, Train Loss: 1.3423e+18, c: 1.3423e+22, d: 1.4296e+00
2020/08/06, 15:49:55, Iteration: 72100, Train Loss: 5.6774e+17, c: 5.6774e+21, d: 1.3771e+00
2020/08/06, 15:49:57, Iteration: 72200, Train Loss: 9.1756e+17, c: 9.1756e+21, d: 1.3997e+00
2020/08/06, 15:49:59, Iteration: 72300, Train Loss: 1.3166e+18, c: 1.3166e+22, d: 1.4706e+00
2020/08/06, 15:50:01, Iteration: 72400, Train Loss: 1.7782e+18, c: 1.7

2020/08/06, 15:52:46, Iteration: 80300, Train Loss: 1.4070e+18, c: 1.4070e+22, d: 1.4820e+00
2020/08/06, 15:52:48, Iteration: 80400, Train Loss: 1.5399e+18, c: 1.5399e+22, d: 1.3968e+00
2020/08/06, 15:52:51, Iteration: 80500, Train Loss: 2.0685e+17, c: 2.0685e+21, d: 1.2283e+00
2020/08/06, 15:52:53, Iteration: 80600, Train Loss: 1.4493e+18, c: 1.4493e+22, d: 1.3941e+00
2020/08/06, 15:52:55, Iteration: 80700, Train Loss: 1.8781e+18, c: 1.8781e+22, d: 1.7069e+00
2020/08/06, 15:52:57, Iteration: 80800, Train Loss: 1.4198e+18, c: 1.4198e+22, d: 1.3641e+00
2020/08/06, 15:52:59, Iteration: 80900, Train Loss: 1.0016e+18, c: 1.0016e+22, d: 1.3619e+00
2020/08/06, 15:53:01, Iteration: 81000, Train Loss: 1.0097e+18, c: 1.0097e+22, d: 1.2431e+00
2020/08/06, 15:53:03, Iteration: 81100, Train Loss: 1.4498e+18, c: 1.4498e+22, d: 1.4162e+00
2020/08/06, 15:53:05, Iteration: 81200, Train Loss: 1.1667e+18, c: 1.1667e+22, d: 1.2276e+00
2020/08/06, 15:53:07, Iteration: 81300, Train Loss: 1.2408e+18, c: 1.2

2020/08/06, 15:55:53, Iteration: 89200, Train Loss: 1.5835e+18, c: 1.5835e+22, d: 1.3575e+00
2020/08/06, 15:55:55, Iteration: 89300, Train Loss: 1.5737e+18, c: 1.5737e+22, d: 1.2398e+00
2020/08/06, 15:55:57, Iteration: 89400, Train Loss: 3.6317e+18, c: 3.6317e+22, d: 1.4522e+00
2020/08/06, 15:55:59, Iteration: 89500, Train Loss: 1.0209e+18, c: 1.0209e+22, d: 1.5786e+00
2020/08/06, 15:56:01, Iteration: 89600, Train Loss: 1.1944e+18, c: 1.1944e+22, d: 1.5069e+00
2020/08/06, 15:56:03, Iteration: 89700, Train Loss: 2.2509e+18, c: 2.2509e+22, d: 1.3478e+00
2020/08/06, 15:56:05, Iteration: 89800, Train Loss: 4.0135e+18, c: 4.0135e+22, d: 1.6831e+00
2020/08/06, 15:56:07, Iteration: 89900, Train Loss: 2.2048e+18, c: 2.2048e+22, d: 1.3890e+00
2020/08/06, 15:56:10, Iteration: 90000, Train Loss: 3.8351e+18, c: 3.8351e+22, d: 1.2893e+00
2020/08/06, 15:56:12, Iteration: 90100, Train Loss: 3.1575e+18, c: 3.1575e+22, d: 1.4589e+00
2020/08/06, 15:56:14, Iteration: 90200, Train Loss: 1.5133e+18, c: 1.5

2020/08/06, 15:58:59, Iteration: 98100, Train Loss: 4.5115e+18, c: 4.5115e+22, d: 1.3574e+00
2020/08/06, 15:59:01, Iteration: 98200, Train Loss: 3.7089e+18, c: 3.7089e+22, d: 1.2738e+00
2020/08/06, 15:59:03, Iteration: 98300, Train Loss: 3.5218e+18, c: 3.5218e+22, d: 1.4617e+00
2020/08/06, 15:59:06, Iteration: 98400, Train Loss: 5.3014e+18, c: 5.3014e+22, d: 1.2753e+00
2020/08/06, 15:59:08, Iteration: 98500, Train Loss: 2.4368e+18, c: 2.4368e+22, d: 1.6868e+00
2020/08/06, 15:59:10, Iteration: 98600, Train Loss: 2.0137e+18, c: 2.0137e+22, d: 1.5161e+00
2020/08/06, 15:59:12, Iteration: 98700, Train Loss: 2.8768e+18, c: 2.8768e+22, d: 1.2942e+00
2020/08/06, 15:59:14, Iteration: 98800, Train Loss: 1.3046e+18, c: 1.3046e+22, d: 1.1059e+00
2020/08/06, 15:59:16, Iteration: 98900, Train Loss: 4.9534e+18, c: 4.9534e+22, d: 1.5916e+00
2020/08/06, 15:59:18, Iteration: 99000, Train Loss: 2.2387e+18, c: 2.2387e+22, d: 1.5031e+00
2020/08/06, 15:59:20, Iteration: 99100, Train Loss: 2.8591e+18, c: 2.8