# Deterministic Linear Transport equation
---
Consider the equation
$$
\left\{
\begin{aligned}
&\frac{\partial r}{\partial t} + v\frac{\partial j}{\partial x} = \frac{\sigma(x, z)}{\epsilon^2}(\hat{r} - r), \\
&\frac{\partial j}{\partial t} + \frac{v}{\epsilon^2}\frac{\partial r}{\partial x} = - \frac{\sigma(x, z)}{\epsilon^2}j, 
\end{aligned}
\right.
$$
where $\epsilon$ is a small parameter, $v \in [0, 1]$, and $\hat{r} = \int_{0}^{1} r \, dv$.

It appears that $v$ is uniformly distributed on $[0, 1]$, since $\hat{r}$ is computed as an unweighted average of $r$ over $v$ using Gauss–Legendre quadrature.

We let $\sigma(x, z) \equiv 1$.

The initial data are
$$
\begin{aligned}
&r = 1, \quad x < 0.5, \\
&r = 0, \quad x \ge 0.5, \\
&j = 0.
\end{aligned}
$$

Boundary conditions (BC): for all $v, x, t$,
$$
\sigma j = -v r_{x},
$$
and for all $v$,
$$
\left.\left(r - \frac{\epsilon}{\sigma} v r_x\right)\right|_{x = 0} = 1, \quad \left.\left(r + \frac{\epsilon}{\sigma} v r_x\right)\right|_{x = 1} = 0.
$$

The spatiotemporal domain is 
$$
(x, t, v) \in [0, 1]\times [0, 0.01] \times [0, 1].
$$

In [1]:
# Run name; used as the sub-directory under models/ where parameter checkpoints are saved.
NAME = "2_1_new_ic"

In [2]:
import jax, jax.nn
from jax import random
import jax.numpy as jnp
from jax.experimental import optimizers
from jax.ops import index, index_add, index_update


import sys, os
sys.path.append("../../")
	
from Seismic_wave_inversion_PINN.data_utils import *
from Seismic_wave_inversion_PINN.jax_model import *

from collections import namedtuple
# from jax.config import config; config.update("jax_enable_x64", True)
# dtype = jnp.float64
# Floating-point dtype handed to init_siren_params for the network parameters.
# For double precision, enable x64 with the commented line above and use jnp.float64.
dtype = jnp.float32

In [3]:
# Network and problem setup.
key = random.PRNGKey(1)
key, subkey = random.split(key, 2)

layers = [3] + [32]*4 + [2] # (x, t, v) -> (r, j)
# c0, w0, w1 are passed straight to init_siren_params (jax_model); their exact meaning
# (presumably SIREN initialisation/frequency scales per input and output) -- TODO confirm.
c0 = 1.0
w0 = jnp.array([[1.0, 1.0, 1.0]]).T
w1 = jnp.array([[1.0, 1.0]]) # (w_r, w_j)
direct_params = init_siren_params(subkey, layers, c0, w0, w1, dtype)

# Row 0: lower bounds, row 1: upper bounds of (x, t, v). model() uses these to map
# each input column to [-1, 1].
domain = jnp.array([[0., 0., 0.0], [1., 0.01, 1.0]])

# Scattering coefficient sigma(x, z), taken constant.
sigma = 1.0
# NOTE(review): epsilon is 1.0 here, but the reference solution loaded for comparison
# is "4_snapshots_epsilon_1e-8.mat" -- confirm the intended epsilon.
epsilon = 1.0

@jax.jit
def model(params, xtv):
	"""Evaluate the sine-activated network at the points ``xtv`` = (x, t, v).

	Each input column is first mapped affinely from its ``domain`` interval to
	[-1, 1]. Every hidden layer applies sin(h @ w + b) and the final layer is
	linear. Column 0 of the output is r and column 1 is j.
	"""
	lower, upper = domain[0, :], domain[1, :]
	h = (2*xtv - (lower + upper))/(upper - lower)
	for w, b in params[:-1]:
		h = jnp.sin(jnp.dot(h, w) + b)
	w_out, b_out = params[-1][0], params[-1][1]
	return jnp.dot(h, w_out) + b_out

# @jax.jit
# def model_(params, xt): # for derivatives
# 	for w, b in params[:-1]:
# 		xt = jnp.sin(jnp.dot(xt, w) + b)
# 	return jnp.dot(xt, params[-1][0]) + params[-1][1]

# Batched derivative helpers (from jax_model) built on `model`. Judging by the
# indexing in loss_fn_, jacobian(params, xtv)[i, k, a] = d(output k)/d(input a) at
# row i, with outputs (r, j) and inputs (x, t, v).
jacobian = jacrev_fn(model)
# Presumably hessian(params, xtv)[i, k, a, b] = d^2(output k)/d(input a)d(input b)
# (as in the commented-out diagnostics cell); not used by the active loss -- TODO confirm.
hessian = hessian_fn(model)

In [4]:
# Misfit function applied to every loss term as metaloss(prediction, target).
# `mae` comes from the star imports (presumably mean absolute error -- confirm).
metaloss = mae

import functools

def static_jit(i):
	"""Return a decorator that jit-compiles a function with ``static_argnums=i``.

	``functools.partial`` is used directly: ``jax.partial`` was only an alias
	for it and has been removed from newer JAX releases.
	"""
	return functools.partial(jax.jit, static_argnums = i)

@jax.jit
def quadrature(params, x, t, v, w):
	"""Weighted sum of the predicted r over quadrature nodes in v.

	Every (x, t) row is paired with every node in ``v``. The network's r
	output is then reduced with the weights ``w``. With the Gauss-Legendre
	nodes/weights mapped to [0, 1], this approximates r_hat(x, t), the
	integral of r over v in [0, 1].

	Args:
		params: network parameters for ``model``.
		x, t: column arrays with one row per evaluation point.
		v, w: quadrature nodes and weights, each of shape (q, 1).

	Returns:
		Array of shape (n_points, 1) holding the weighted sums.
	"""
	n_pts, n_nodes = x.shape[0], w.shape[0]
	# Row k*n_nodes + m pairs point k with node m.
	xt_rep = jnp.repeat(jnp.hstack([x, t]), n_nodes, axis = 0)
	v_rep = jnp.tile(v, (n_pts, 1))
	r_flat = model(params, jnp.hstack([xt_rep, v_rep]))[:, 0:1]
	return jnp.dot(r_flat.reshape((n_pts, n_nodes)), w)

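# jacobian[i] = [[dr/dx, dr/dt, dr/dv],
#                [dj/dx, dj/dt, dj/dv]]
# i: index of the evaluation point (row of the xtv input)

# hessian[i] = [
#               [[dr/dxx, dr/dxt, dr/dxv],
#                [dr/dtx, dr/dtt, dr/dtv],
#                [dr/dvx, dr/dvt, dr/dvv]],
#               [[dj/dxx, dj/dxt, dj/dxv],
#                [dj/dtx, dj/dtt, dj/dtv],
#                [dj/dvx, dj/dvt, dj/dvv]]
#              ]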
@jax.jit
def loss_fn_(params, batch):
	"""Return the unweighted PINN loss terms for one batch.

	Args:
		params: network parameters, evaluated through ``model``/``jacobian``.
		batch: dict with keys "collocation", "dirichlet", "bc" and "quad"
			holding dataset_Collocation, dataset_Dirichlet, dataset_BC and
			dataset_Quad tuples. A group is skipped when the first field of
			its tuple is None.

	Returns:
		Tuple ``(loss_c1, loss_c2, loss_d, loss_b1, loss_b2)``:
		  * loss_c1: residual of eps^2 (r_t + v j_x) = sigma (r_hat - r);
		  * loss_c2: residual of eps^2 j_t + v r_x = -sigma j;
		  * loss_d: initial-condition misfit in r plus misfit in j;
		  * loss_b1: sigma j = -v r_x at x = 0 and at x = 1;
		  * loss_b2: r - (eps/sigma) v r_x = 1 at x = 0 and
		    r + (eps/sigma) v r_x = 0 at x = 1.
		Terms of skipped groups are 0.0.
	"""
	collocation, dirichlet, bc, quad = batch["collocation"], batch["dirichlet"], batch["bc"], batch["quad"]
	direct_params = params

	if collocation[0] is not None:
		xtv_c = jnp.hstack([collocation.x, collocation.t, collocation.v])
		rj_c = model(direct_params, xtv_c)
		r_c, j_c = rj_c[:, 0:1], rj_c[:, 1:2]
		# drj_dxtv_c[i, k, a] = d(output k)/d(input a); outputs (r, j), inputs (x, t, v)
		drj_dxtv_c = jacobian(direct_params, xtv_c)
		dr_dt_c, dj_dt_c = drj_dxtv_c[:, 0:1, 1], drj_dxtv_c[:, 1:2, 1]
		dr_dx_c, dj_dx_c = drj_dxtv_c[:, 0:1, 0], drj_dxtv_c[:, 1:2, 0]

		# Velocity average r_hat(x, t); quad.w and quad.v are [q, 1]
		r_hat_c = quadrature(direct_params, collocation.x, collocation.t, quad.v, quad.w)

		# Both equations are multiplied through by eps^2 so nothing is divided by a small eps
		loss_c1 = metaloss(epsilon**2*(dr_dt_c + collocation.v*dj_dx_c), sigma*(r_hat_c - r_c))
		loss_c2 = metaloss(epsilon**2*dj_dt_c + collocation.v*dr_dx_c, -sigma*j_c)
	else:
		loss_c1 = loss_c2 = 0.0

	if dirichlet[0] is not None:
		rj_d = model(direct_params, jnp.hstack([dirichlet.x, dirichlet.t, dirichlet.v]))
		r_d, j_d = rj_d[:, 0:1], rj_d[:, 1:2]
		loss_d1 = metaloss(r_d, dirichlet.r)
		loss_d2 = metaloss(j_d, dirichlet.j)
		loss_d = loss_d1 + loss_d2
	else:
		loss_d = 0.0

	if bc[0] is not None:
		# bc.l / bc.r hold the x-coordinates of the left / right boundary points
		xtv_bl = jnp.hstack([bc.l, bc.t, bc.v])
		xtv_br = jnp.hstack([bc.r, bc.t, bc.v])
		rj_bl = model(direct_params, xtv_bl)
		rj_br = model(direct_params, xtv_br)
		r_bl, j_bl = rj_bl[:, 0:1], rj_bl[:, 1:2]
		r_br, j_br = rj_br[:, 0:1], rj_br[:, 1:2]
		dr_dx_bl = jacobian(direct_params, xtv_bl)[:, 0:1, 0]
		dr_dx_br = jacobian(direct_params, xtv_br)[:, 0:1, 0]

		loss_b1 = metaloss(sigma*j_bl, -bc.v*dr_dx_bl) + metaloss(sigma*j_br, -bc.v*dr_dx_br)
		loss_b2 = metaloss(r_bl - epsilon/sigma*bc.v*dr_dx_bl, 1.0) + metaloss(r_br + epsilon/sigma*bc.v*dr_dx_br, 0.0)
	else:
		# No boundary batch: contribute nothing, like the other skipped groups
		loss_b1 = loss_b2 = 0.0

	return loss_c1, loss_c2, loss_d, loss_b1, loss_b2

@jax.jit
def loss_fn(params, batch):
	"""Scalar training objective: weighted loss terms plus L1/L2 penalties.

	``batch["weights"]`` supplies the multipliers "c1", "c2", "d", "b1", "b2"
	for the terms of ``loss_fn_`` and the "l1"/"l2" regularisation strengths.
	"""
	weights = batch["weights"]
	terms = loss_fn_(params, batch)
	weighted = sum(weights[name]*term for name, term in zip(("c1", "c2", "d", "b1", "b2"), terms))
	l1_penalty = l1_regularization(params, weights["l1"])
	l2_penalty = l2_regularization(params, weights["l2"])
	return weighted + l1_penalty + l2_penalty

@jax.jit
def step(i, opt_state, batch):
	"""Take one optimizer step on ``loss_fn`` and return the new optimizer state.

	Uses the module-level ``get_params``/``opt_update`` created by
	``optimizers.adam`` in the training cell.
	"""
	current = get_params(opt_state)
	grads = jax.grad(loss_fn)(current, batch)
	return opt_update(i, grads, opt_state)

@jax.jit
def evaluate(params, batch):
	"""Report the loss breakdown used in the training log.

	Returns an 8-tuple: the weighted sum of the five loss terms (without
	regularisation), the five unweighted terms (c1, c2, d, b1, b2), and the
	L1 and L2 norms of ``params`` evaluated with unit strength.
	"""
	weights = batch["weights"]
	terms = loss_fn_(params, batch)
	total = sum(weights[name]*term for name, term in zip(("c1", "c2", "d", "b1", "b2"), terms))
	return (total, *terms, l1_regularization(params, 1.0), l2_regularization(params, 1.0))

In [5]:
# r0_fn = lambda x, t, v: jnp.zeros_like(x)
# r0_fn = lambda x, t, v: jnp.select([jnp.isclose(x, domain[0, 0]), x > 0], [1.0, 0.0])
def r0_fn(x, t, v):
	"""Initial density: 1 where x < 0.5 and 0 where x >= 0.5 (t and v unused)."""
	return jnp.where(x < 0.5, 1.0, 0.0)


def j0_fn(x, t, v):
	"""Initial flux: identically zero, with the same shape as x."""
	return jnp.zeros_like(x)

# r0_fn_ = lambda xtv: jnp.zeros_like(xtv[0])
# j0_fn_ = lambda xtv: jnp.zeros_like(xtv[0])

# dr0_dt_fn = lambda xtv: jax.vmap(jax.jacfwd(r0_fn_), in_axes = 0)(xtv)[:, 1:2]
# dj0_dt_fn = lambda xtv: jax.vmap(jax.jacfwd(j0_fn_), in_axes = 0)(xtv)[:, 1:2]

# dr0_dx_fn = lambda xtv: jax.vmap(jax.jacfwd(r0_fn_), in_axes = 0)(xtv)[:, 0:1]
# dj0_dx_fn = lambda xtv: jax.vmap(jax.jacfwd(j0_fn_), in_axes = 0)(xtv)[:, 0:1]

key, *subkeys = random.split(key, 3)  # advances `key`; these subkeys are not used in this cell

# Gauss-Legendre rule for r_hat, the integral of r over v in [0, 1]: nodes are mapped
# from [-1, 1] to [0, 1] and weights halved, so the weights sum to 1. The file-level
# `dtype` is used so the nodes/weights match the network precision.
n_quad = 16
v_quad, w_quad = np.polynomial.legendre.leggauss(n_quad)
v_quad = jnp.array(0.5*(v_quad+1), dtype = dtype).reshape((-1, 1))
w_quad = jnp.array(0.5*w_quad, dtype = dtype).reshape((-1, 1))

# Initial-condition points (t = 0): each x of a uniform grid on [0, 1] paired with
# each quadrature node in v. tensor_grid presumably returns the Cartesian product,
# one row per (x, v) pair -- confirm in data_utils.
n_i = 200
x_i = jnp.linspace(*domain[:, 0], n_i)
# v_i = jnp.linspace(*domain[:, 2], n_i)
v_i = v_quad
xv_i = tensor_grid([x_i, v_i])
x_i, v_i = xv_i[:, 0:1], xv_i[:, 1:2]
t_i = jnp.zeros_like(x_i)
# Targets: step profile for r and zero flux j.
r_i = r0_fn(x_i, t_i, v_i)
j_i = j0_fn(x_i, t_i, v_i)

# Boundary points: each t of a uniform grid on the time interval paired with each
# quadrature node (presumably via tensor_grid's Cartesian product), placed at both
# x = domain lower bound (x_bl) and x = domain upper bound (x_br).
n_b = 100
t_b = jnp.linspace(*domain[:, 1], n_b)
# v_b = jnp.linspace(*domain[:, 2], n_b)
v_b = v_quad
tv_b = tensor_grid([t_b, v_b])
t_b, v_b = tv_b[:, 0:1], tv_b[:, 1:2]
x_bl = jnp.ones_like(t_b)*domain[0, 0]
x_br = jnp.ones_like(t_b)*domain[1, 0]

# Collocation points for the PDE residual: n_cx x-values and n_ct t-values on uniform
# grids (endpoints included), combined with the quadrature nodes in v. tensor_grid
# presumably forms the Cartesian product -- confirm in data_utils.
n_cx = 201
n_ct = 100
x_c = jnp.linspace(*domain[:, 0], n_cx).reshape((-1, 1))
t_c = jnp.linspace(*domain[:, 1], n_ct).reshape((-1, 1))
# v_c = jnp.linspace(*domain[:, 2], n_cv).reshape((-1, 1))
v_c = v_quad
xtv_c = tensor_grid([x_c, t_c, v_c])

# Containers for each training-data group; loss_fn_ reads these field names.
# Initial-condition points with targets r and j.
dataset_Dirichlet = namedtuple("dataset_Dirichlet", ["x", "t", "v", "r", "j"])
# Points where the PDE residual is evaluated.
dataset_Collocation = namedtuple("dataset_Collocation", ["x", "t", "v"])
# Quadrature weights and nodes in v.
dataset_Quad = namedtuple("dataset_Quad", ["w", "v"])
# l / r: x-coordinates of the left / right boundary points, sharing t and v.
dataset_BC = namedtuple("dataset_BC", ["l", "r", "t", "v"])

dirichlet = dataset_Dirichlet(x_i, t_i, v_i, r_i, j_i)
collocation = dataset_Collocation(xtv_c[:, 0:1], xtv_c[:, 1:2], xtv_c[:, 2:3])
quad = dataset_Quad(w_quad, v_quad)
bc = dataset_BC(x_bl, x_br, t_b, v_b)

In [6]:
# class Time_Marching_Generator:
# 	def __init__(self, key, spatial_points, temporal_domain, batch_size, iterations, update_every, count1 = 0):
# 		self.key = key
# 		self.spatial_points = spatial_points
# 		self.domain = temporal_domain
# 		self.batch_size = batch_size
# 		self.iterations = iterations
# 		self._count1 = count1
# 		self._count2 = update_every
# 		if count1 < iterations:
# 			self._update(self.domain[0])
# 		else:
# 			self._update(self.domain[1])
# 		self.update_every = update_every
		
# 	def _update(self, tmax):
# 		self.key, subkey = random.split(self.key)
# 		self._t = random.uniform(key, (self.batch_size, 1), jnp.float32, self.domain[0], tmax)
		
# 	def __iter__(self):
# 		return self
	
# 	def __next__(self):
# 		if self._count2 == self.update_every:
# 			self._count1 = max(self.iterations, self._count1 + 1)
# 			tmax = self.domain[0] + (self.domain[1]-self.domain[0])*self._count1/self.iterations
# 			self._update(tmax)
# 			self._count2 = 0
# 		else:
# 			self._count2 += 1
# 		return self.spatial_points, self._t

In [7]:
# Adam with a fixed learning rate; opt_update/get_params are used by step().
lr = 1e-3
params = direct_params
opt_init, opt_update, get_params = optimizers.adam(lr)
opt_state = opt_init(params)
hist = {"iter": [], "loss": []}

# Mini-batch samplers for each data group (Batch_Generator comes from data_utils).
batch_size = {"dirichlet": 3200, "collocation": 20100, "bc": 1600}
key, *subkeys = random.split(key, 5)
Dirichlet = Batch_Generator(subkeys[0], dirichlet, batch_size["dirichlet"])
Collocation = Batch_Generator(subkeys[1], collocation, batch_size["collocation"])
BC = Batch_Generator(subkeys[2], bc, batch_size["bc"])

start_iteration = 0
iterations = 100000
print_every = 200
save_every = 10000
# Multipliers for the loss terms (c1, c2: PDE residuals; d: initial condition;
# b1, b2: boundary conditions) and L1/L2 regularisation strengths.
weights = {"c1": 1.0, "c2": 1, "d": 1, "b1": 1, "b2": 1, "l1": 1e-8, "l2": 1e-8}

for iteration in range(start_iteration, start_iteration+iterations+1):
	d = next(Dirichlet)
	b = next(BC)
	c = next(Collocation)
	# The PDE residual is also enforced at the sampled initial and boundary points:
	# their (x, t, v) are stacked onto the collocation batch (b[0]/b[1] are the left/right x).
	batch = {
		"dirichlet": dataset_Dirichlet(*d),
		"bc": dataset_BC(*b),
		"collocation": dataset_Collocation(jnp.vstack([d[0], c[0], b[0], b[1]]), jnp.vstack([d[1], c[1], b[2], b[2]]), jnp.vstack([d[2], c[2], b[3], b[3]])),
		"quad": quad,
		"weights": weights,
	}
	opt_state = step(iteration, opt_state, batch)
	if (iteration-start_iteration) % print_every == 0:
		names = ["Loss", "c1", "c2", "d", "b1", "b2", "l1_reg", "l2_reg"]
		params_ = get_params(opt_state)
		# Evaluate on the generators' full IC/BC datasets (presumably `.dataset` holds
		# the whole set -- confirm), reusing this iteration's collocation batch.
		batch = {
			"dirichlet": dataset_Dirichlet(*Dirichlet.dataset),
			"bc": dataset_BC(*BC.dataset),
			"collocation": batch["collocation"],
			"quad": quad,
			"weights": weights
		}
		losses = evaluate(params_, batch)
		print("{}, Iteration: {}, Train".format(get_time(), iteration) + \
			  ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(names, losses)]))
		hist["iter"].append(iteration)
		hist["loss"].append(losses)
	if (iteration-start_iteration) % save_every == 0:
		# Checkpoint the parameter list as a NumPy object array under models/<NAME>/.
		params_ = np.asarray(get_params(opt_state), dtype = object)
		save_path = "models/{}/iteration_{}/params.npy".format(NAME, iteration)
		if not os.path.exists(os.path.dirname(save_path)):
			os.makedirs(os.path.dirname(save_path))
		np.save(save_path, params_)

2020/08/29, 17:22:32, Iteration: 0, Train Loss: 1.4738e+01, c1: 1.1173e+00, c2: 1.1860e+01, d: 5.6257e-01, b1: 2.1710e-01, b2: 9.8149e-01, l1_reg: 3.0210e+02, l2_reg: 4.2696e+01
2020/08/29, 17:22:36, Iteration: 200, Train Loss: 1.4294e+00, c1: 1.3255e-01, c2: 1.5845e-01, d: 4.7444e-01, b1: 1.4908e-01, b2: 5.1484e-01, l1_reg: 3.0043e+02, l2_reg: 4.2867e+01
2020/08/29, 17:22:40, Iteration: 400, Train Loss: 1.0946e+00, c1: 1.3146e-01, c2: 6.5188e-02, d: 4.5306e-01, b1: 2.0744e-01, b2: 2.3741e-01, l1_reg: 3.0209e+02, l2_reg: 4.4335e+01
2020/08/29, 17:22:44, Iteration: 600, Train Loss: 9.5176e-01, c1: 1.0088e-01, c2: 4.3405e-02, d: 4.5657e-01, b1: 1.4800e-01, b2: 2.0290e-01, l1_reg: 3.0374e+02, l2_reg: 4.5808e+01
2020/08/29, 17:22:48, Iteration: 800, Train Loss: 8.5050e-01, c1: 1.6289e-01, c2: 3.8510e-02, d: 4.4862e-01, b1: 5.4904e-02, b2: 1.4558e-01, l1_reg: 3.0660e+02, l2_reg: 4.7978e+01
2020/08/29, 17:22:52, Iteration: 1000, Train Loss: 5.8266e-01, c1: 7.9407e-02, c2: 4.8963e-02, d: 3.70

2020/08/29, 17:25:37, Iteration: 9200, Train Loss: 4.9848e-02, c1: 1.2087e-02, c2: 6.2957e-03, d: 2.8062e-02, b1: 1.9943e-03, b2: 1.4091e-03, l1_reg: 3.3804e+02, l2_reg: 7.9553e+01
2020/08/29, 17:25:41, Iteration: 9400, Train Loss: 5.6919e-02, c1: 1.2401e-02, c2: 7.3299e-03, d: 2.8708e-02, b1: 2.6528e-03, b2: 5.8276e-03, l1_reg: 3.3844e+02, l2_reg: 8.0150e+01
2020/08/29, 17:25:44, Iteration: 9600, Train Loss: 5.5338e-02, c1: 1.0793e-02, c2: 8.7882e-03, d: 2.7847e-02, b1: 3.8493e-03, b2: 4.0599e-03, l1_reg: 3.3886e+02, l2_reg: 8.0730e+01
2020/08/29, 17:25:49, Iteration: 9800, Train Loss: 5.0904e-02, c1: 1.4896e-02, c2: 6.7982e-03, d: 2.5131e-02, b1: 3.1429e-03, b2: 9.3599e-04, l1_reg: 3.3923e+02, l2_reg: 8.1247e+01
2020/08/29, 17:25:53, Iteration: 10000, Train Loss: 5.5574e-02, c1: 1.3683e-02, c2: 8.8416e-03, d: 2.5548e-02, b1: 1.7110e-03, b2: 5.7902e-03, l1_reg: 3.3960e+02, l2_reg: 8.1836e+01
2020/08/29, 17:25:57, Iteration: 10200, Train Loss: 4.6861e-02, c1: 1.3658e-02, c2: 6.2000e-03

2020/08/29, 17:28:43, Iteration: 18400, Train Loss: 3.3714e-02, c1: 1.3458e-02, c2: 4.6043e-03, d: 1.3142e-02, b1: 7.5621e-04, b2: 1.7531e-03, l1_reg: 3.4955e+02, l2_reg: 1.0330e+02
2020/08/29, 17:28:47, Iteration: 18600, Train Loss: 3.3975e-02, c1: 1.2694e-02, c2: 4.6518e-03, d: 1.3353e-02, b1: 1.9526e-03, b2: 1.3238e-03, l1_reg: 3.4986e+02, l2_reg: 1.0379e+02
2020/08/29, 17:28:51, Iteration: 18800, Train Loss: 3.0677e-02, c1: 9.2749e-03, c2: 4.9492e-03, d: 1.3129e-02, b1: 1.7850e-03, b2: 1.5381e-03, l1_reg: 3.5007e+02, l2_reg: 1.0420e+02
2020/08/29, 17:28:55, Iteration: 19000, Train Loss: 2.8293e-02, c1: 7.6673e-03, c2: 4.7815e-03, d: 1.2988e-02, b1: 1.7607e-03, b2: 1.0953e-03, l1_reg: 3.5037e+02, l2_reg: 1.0472e+02
2020/08/29, 17:28:59, Iteration: 19200, Train Loss: 2.7583e-02, c1: 8.4985e-03, c2: 4.1629e-03, d: 1.2741e-02, b1: 8.7991e-04, b2: 1.3014e-03, l1_reg: 3.5060e+02, l2_reg: 1.0523e+02
2020/08/29, 17:29:03, Iteration: 19400, Train Loss: 4.6425e-02, c1: 2.4192e-02, c2: 5.6525

2020/08/29, 17:31:49, Iteration: 27600, Train Loss: 3.0148e-02, c1: 1.3590e-02, c2: 4.8490e-03, d: 9.2247e-03, b1: 1.8612e-03, b2: 6.2259e-04, l1_reg: 3.5956e+02, l2_reg: 1.2448e+02
2020/08/29, 17:31:53, Iteration: 27800, Train Loss: 2.9526e-02, c1: 9.4069e-03, c2: 4.6861e-03, d: 1.0592e-02, b1: 1.6079e-03, b2: 3.2333e-03, l1_reg: 3.5972e+02, l2_reg: 1.2492e+02
2020/08/29, 17:31:57, Iteration: 28000, Train Loss: 3.0122e-02, c1: 1.3845e-02, c2: 4.7344e-03, d: 9.2177e-03, b1: 1.4487e-03, b2: 8.7660e-04, l1_reg: 3.5989e+02, l2_reg: 1.2551e+02
2020/08/29, 17:32:01, Iteration: 28200, Train Loss: 2.3223e-02, c1: 7.9166e-03, c2: 3.4679e-03, d: 9.0582e-03, b1: 1.3021e-03, b2: 1.4779e-03, l1_reg: 3.5998e+02, l2_reg: 1.2591e+02
2020/08/29, 17:32:05, Iteration: 28400, Train Loss: 2.2125e-02, c1: 8.1569e-03, c2: 3.7124e-03, d: 8.6806e-03, b1: 6.7503e-04, b2: 9.0038e-04, l1_reg: 3.6027e+02, l2_reg: 1.2639e+02
2020/08/29, 17:32:09, Iteration: 28600, Train Loss: 2.2298e-02, c1: 5.7113e-03, c2: 3.6957

2020/08/29, 17:34:54, Iteration: 36800, Train Loss: 2.1903e-02, c1: 7.5439e-03, c2: 4.2122e-03, d: 7.7446e-03, b1: 1.9481e-03, b2: 4.5379e-04, l1_reg: 3.5861e+02, l2_reg: 1.3953e+02
2020/08/29, 17:34:58, Iteration: 37000, Train Loss: 2.8747e-02, c1: 1.2481e-02, c2: 4.4091e-03, d: 8.2721e-03, b1: 1.9339e-03, b2: 1.6500e-03, l1_reg: 3.5865e+02, l2_reg: 1.3998e+02
2020/08/29, 17:35:02, Iteration: 37200, Train Loss: 1.5194e-02, c1: 2.6254e-03, c2: 2.8630e-03, d: 7.5526e-03, b1: 5.1410e-04, b2: 1.6386e-03, l1_reg: 3.5852e+02, l2_reg: 1.4034e+02
2020/08/29, 17:35:06, Iteration: 37400, Train Loss: 1.9662e-02, c1: 6.9322e-03, c2: 2.8933e-03, d: 7.6529e-03, b1: 7.1237e-04, b2: 1.4716e-03, l1_reg: 3.5837e+02, l2_reg: 1.4071e+02
2020/08/29, 17:35:10, Iteration: 37600, Train Loss: 2.2956e-02, c1: 7.6314e-03, c2: 4.1729e-03, d: 7.9433e-03, b1: 1.6957e-03, b2: 1.5131e-03, l1_reg: 3.5843e+02, l2_reg: 1.4115e+02
2020/08/29, 17:35:14, Iteration: 37800, Train Loss: 2.2069e-02, c1: 5.1852e-03, c2: 4.3057

2020/08/29, 17:38:00, Iteration: 46000, Train Loss: 1.8260e-02, c1: 6.1820e-03, c2: 3.1224e-03, d: 6.3207e-03, b1: 1.4241e-03, b2: 1.2104e-03, l1_reg: 3.5581e+02, l2_reg: 1.5882e+02
2020/08/29, 17:38:04, Iteration: 46200, Train Loss: 1.8110e-02, c1: 3.4849e-03, c2: 5.1121e-03, d: 6.4622e-03, b1: 1.3293e-03, b2: 1.7216e-03, l1_reg: 3.5530e+02, l2_reg: 1.5906e+02
2020/08/29, 17:38:08, Iteration: 46400, Train Loss: 2.0366e-02, c1: 4.9926e-03, c2: 5.9206e-03, d: 6.4799e-03, b1: 1.5919e-03, b2: 1.3808e-03, l1_reg: 3.5543e+02, l2_reg: 1.5946e+02
2020/08/29, 17:38:12, Iteration: 46600, Train Loss: 1.8439e-02, c1: 4.5165e-03, c2: 5.5061e-03, d: 6.0609e-03, b1: 8.9025e-04, b2: 1.4653e-03, l1_reg: 3.5543e+02, l2_reg: 1.5977e+02
2020/08/29, 17:38:16, Iteration: 46800, Train Loss: 1.2562e-02, c1: 3.9280e-03, c2: 1.9404e-03, d: 5.5728e-03, b1: 8.5904e-04, b2: 2.6115e-04, l1_reg: 3.5531e+02, l2_reg: 1.6001e+02
2020/08/29, 17:38:20, Iteration: 47000, Train Loss: 1.2647e-02, c1: 4.4661e-03, c2: 2.0192

2020/08/29, 17:41:05, Iteration: 55200, Train Loss: 1.5867e-02, c1: 7.0374e-03, c2: 2.2393e-03, d: 4.9751e-03, b1: 1.2471e-03, b2: 3.6774e-04, l1_reg: 3.5137e+02, l2_reg: 1.7042e+02
2020/08/29, 17:41:09, Iteration: 55400, Train Loss: 1.2423e-02, c1: 3.1106e-03, c2: 2.3391e-03, d: 5.0366e-03, b1: 5.0260e-04, b2: 1.4344e-03, l1_reg: 3.5144e+02, l2_reg: 1.7068e+02
2020/08/29, 17:41:13, Iteration: 55600, Train Loss: 1.2021e-02, c1: 3.8454e-03, c2: 2.2701e-03, d: 4.7093e-03, b1: 5.9810e-04, b2: 5.9766e-04, l1_reg: 3.5146e+02, l2_reg: 1.7089e+02
2020/08/29, 17:41:17, Iteration: 55800, Train Loss: 2.3981e-02, c1: 1.1352e-02, c2: 4.5054e-03, d: 5.3480e-03, b1: 1.2348e-03, b2: 1.5416e-03, l1_reg: 3.5144e+02, l2_reg: 1.7119e+02
2020/08/29, 17:41:21, Iteration: 56000, Train Loss: 1.3439e-02, c1: 3.7779e-03, c2: 2.5333e-03, d: 4.9881e-03, b1: 1.0935e-03, b2: 1.0462e-03, l1_reg: 3.5107e+02, l2_reg: 1.7130e+02
2020/08/29, 17:41:25, Iteration: 56200, Train Loss: 1.6498e-02, c1: 4.7802e-03, c2: 4.9901

2020/08/29, 17:44:10, Iteration: 64400, Train Loss: 1.5613e-02, c1: 5.5400e-03, c2: 2.6690e-03, d: 4.6210e-03, b1: 1.4079e-03, b2: 1.3754e-03, l1_reg: 3.4426e+02, l2_reg: 1.8137e+02
2020/08/29, 17:44:14, Iteration: 64600, Train Loss: 1.1698e-02, c1: 2.9624e-03, c2: 2.0601e-03, d: 4.5501e-03, b1: 6.4938e-04, b2: 1.4756e-03, l1_reg: 3.4405e+02, l2_reg: 1.8155e+02
2020/08/29, 17:44:18, Iteration: 64800, Train Loss: 1.3512e-02, c1: 2.9111e-03, c2: 2.7164e-03, d: 5.0696e-03, b1: 1.1884e-03, b2: 1.6269e-03, l1_reg: 3.4402e+02, l2_reg: 1.8180e+02
2020/08/29, 17:44:22, Iteration: 65000, Train Loss: 1.1532e-02, c1: 3.6821e-03, c2: 2.2266e-03, d: 4.2431e-03, b1: 5.9722e-04, b2: 7.8331e-04, l1_reg: 3.4385e+02, l2_reg: 1.8206e+02
2020/08/29, 17:44:26, Iteration: 65200, Train Loss: 1.2458e-02, c1: 3.9530e-03, c2: 2.7750e-03, d: 4.1738e-03, b1: 5.4559e-04, b2: 1.0106e-03, l1_reg: 3.4365e+02, l2_reg: 1.8213e+02
2020/08/29, 17:44:30, Iteration: 65400, Train Loss: 2.0322e-02, c1: 8.8606e-03, c2: 4.3850

2020/08/29, 17:47:15, Iteration: 73600, Train Loss: 1.5599e-02, c1: 5.7902e-03, c2: 2.0782e-03, d: 4.6088e-03, b1: 1.8468e-03, b2: 1.2753e-03, l1_reg: 3.3752e+02, l2_reg: 1.9082e+02
2020/08/29, 17:47:19, Iteration: 73800, Train Loss: 8.2391e-03, c1: 2.3214e-03, c2: 1.6431e-03, d: 3.5515e-03, b1: 3.9798e-04, b2: 3.2515e-04, l1_reg: 3.3737e+02, l2_reg: 1.9086e+02
2020/08/29, 17:47:23, Iteration: 74000, Train Loss: 7.7528e-03, c1: 1.3141e-03, c2: 2.3405e-03, d: 3.4191e-03, b1: 3.5820e-04, b2: 3.2088e-04, l1_reg: 3.3758e+02, l2_reg: 1.9113e+02
2020/08/29, 17:47:27, Iteration: 74200, Train Loss: 1.0662e-02, c1: 3.4123e-03, c2: 2.4629e-03, d: 3.7244e-03, b1: 3.5686e-04, b2: 7.0549e-04, l1_reg: 3.3717e+02, l2_reg: 1.9122e+02
2020/08/29, 17:47:31, Iteration: 74400, Train Loss: 9.3502e-03, c1: 3.5024e-03, c2: 1.7163e-03, d: 3.4628e-03, b1: 2.5347e-04, b2: 4.1512e-04, l1_reg: 3.3676e+02, l2_reg: 1.9122e+02
2020/08/29, 17:47:35, Iteration: 74600, Train Loss: 1.0459e-02, c1: 3.8239e-03, c2: 2.1463

2020/08/29, 17:50:20, Iteration: 82800, Train Loss: 1.4528e-02, c1: 6.0727e-03, c2: 2.4011e-03, d: 3.9891e-03, b1: 1.8035e-03, b2: 2.6183e-04, l1_reg: 3.3347e+02, l2_reg: 1.9915e+02
2020/08/29, 17:50:24, Iteration: 83000, Train Loss: 9.6575e-03, c1: 3.1840e-03, c2: 1.7713e-03, d: 3.5198e-03, b1: 5.8480e-04, b2: 5.9760e-04, l1_reg: 3.3300e+02, l2_reg: 1.9918e+02
2020/08/29, 17:50:28, Iteration: 83200, Train Loss: 1.0563e-02, c1: 4.0849e-03, c2: 2.5335e-03, d: 3.2721e-03, b1: 3.0035e-04, b2: 3.7188e-04, l1_reg: 3.3289e+02, l2_reg: 1.9935e+02
2020/08/29, 17:50:32, Iteration: 83400, Train Loss: 1.7498e-02, c1: 9.0646e-03, c2: 3.1700e-03, d: 3.7679e-03, b1: 6.7798e-04, b2: 8.1711e-04, l1_reg: 3.3308e+02, l2_reg: 1.9953e+02
2020/08/29, 17:50:36, Iteration: 83600, Train Loss: 7.8777e-03, c1: 1.7775e-03, c2: 1.2910e-03, d: 3.5831e-03, b1: 4.6251e-04, b2: 7.6355e-04, l1_reg: 3.3274e+02, l2_reg: 1.9972e+02
2020/08/29, 17:50:40, Iteration: 83800, Train Loss: 1.3547e-02, c1: 4.2730e-03, c2: 2.9328

2020/08/29, 17:53:24, Iteration: 92000, Train Loss: 1.1200e-02, c1: 3.9954e-03, c2: 2.3692e-03, d: 3.4548e-03, b1: 6.9135e-04, b2: 6.8919e-04, l1_reg: 3.3153e+02, l2_reg: 2.0807e+02
2020/08/29, 17:53:28, Iteration: 92200, Train Loss: 1.1045e-02, c1: 4.2067e-03, c2: 2.4185e-03, d: 3.3236e-03, b1: 2.2474e-04, b2: 8.7180e-04, l1_reg: 3.3128e+02, l2_reg: 2.0825e+02
2020/08/29, 17:53:33, Iteration: 92400, Train Loss: 9.7492e-03, c1: 3.5945e-03, c2: 1.8470e-03, d: 3.3474e-03, b1: 5.4722e-04, b2: 4.1319e-04, l1_reg: 3.3182e+02, l2_reg: 2.0858e+02
2020/08/29, 17:53:37, Iteration: 92600, Train Loss: 8.7428e-03, c1: 2.0711e-03, c2: 1.4389e-03, d: 3.5018e-03, b1: 9.6722e-04, b2: 7.6384e-04, l1_reg: 3.3140e+02, l2_reg: 2.0878e+02
2020/08/29, 17:53:41, Iteration: 92800, Train Loss: 1.1681e-02, c1: 4.3242e-03, c2: 2.2229e-03, d: 3.4923e-03, b1: 5.9986e-04, b2: 1.0415e-03, l1_reg: 3.3152e+02, l2_reg: 2.0900e+02
2020/08/29, 17:53:45, Iteration: 93000, Train Loss: 9.5748e-03, c1: 3.3094e-03, c2: 1.6199

In [8]:
# Compare the trained model against reference snapshots of r_hat loaded from a .mat file.
from scipy.io import loadmat
params_ = get_params(opt_state)

# NOTE(review): the reference file is for epsilon = 1e-8 while the model was trained
# with the module-level epsilon (1.0 in the setup cell) -- confirm these should match.
data_true = loadmat("4_snapshots_epsilon_1e-8.mat")
x_test = data_true["x"]
t_test = data_true["times"][0]
# Transposed so that r_hat_trues[i] is presumably the snapshot at t_test[i] (as used by
# the animation cell) -- confirm the .mat layout.
r_hat_trues = data_true["rhos"].T

# x_test = jnp.linspace(*domain[:, 0], 200)
# t_test = jnp.linspace(*domain[:, 1], 2000)
xt_tests = [tensor_grid([x_test, ti]) for ti in t_test]

# Predicted velocity average r_hat at every x for each snapshot time.
r_hat_preds = [quadrature(params_, xt[:, 0:1], xt[:, 1:2], quad.v, quad.w) for xt in xt_tests]
# uv_preds = [model(params_, xt_test) for xt_test in xt_tests]
# u_preds, v_preds = [uv_pred[:, 0:1] for uv_pred in uv_preds], [uv_pred[:, 1:2] for uv_pred in uv_preds]
# duv_dxt_preds = [jacobian(params_, xt_test) for xt_test in xt_tests]
# du_dx_preds, dv_dx_preds = [duv_dxt_pred[:, 0:1, 0] for duv_dxt_pred in duv_dxt_preds], [duv_dxt_pred[:, 1:2, 0] for duv_dxt_pred in duv_dxt_preds]
# du_dt_preds, dv_dt_preds = [duv_dxt_pred[:, 0:1, 1] for duv_dxt_pred in duv_dxt_preds], [duv_dxt_pred[:, 1:2, 1] for duv_dxt_pred in duv_dxt_preds]

In [10]:
# Animate predicted vs. reference r_hat over the snapshot times.
from matplotlib import animation
%matplotlib notebook

fig, ax = plt.subplots(1, 1, figsize = (5, 5))
line1, = ax.plot([], [], lw = 1.5, label = "pred")
line2, = ax.plot([], [], lw = 1.5, label = "true")
lines = [line1, line2]
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.legend()
ax.grid()
    
# NOTE(review): init() is not passed as init_func= to FuncAnimation below, so it is
# never called by this cell.
def init():
	for line in lines:
		line.set_data([], [])
	return lines

# Frame i: draw snapshot i of the prediction and the reference.
def animate(i):
	r_hat_pred = r_hat_preds[i]
	r_hat_true = r_hat_trues[i]
	lines[0].set_data(x_test, r_hat_pred)
	lines[1].set_data(x_test, r_hat_true)
	ax.set_title("r, t = {:.4f}".format(t_test[i]))
	return lines

anim = animation.FuncAnimation(fig, animate, frames = len(t_test), interval = 100, blit = True)
plt.show()

<IPython.core.display.Javascript object>

In [None]:
# Side-by-side plot of the final predicted r_hat and a reference solution.
from scipy.io import loadmat
data = loadmat("epsilon_1e-16.mat")
x_true_, u_true = data["x"], data["u"]
# Build an x grid the same shape as u_true: endpoints 0 and 1 plus the stored x values.
# Assumes the stored x excludes the endpoints, i.e. len(u_true) == len(x) + 2 -- confirm.
x_true = np.zeros_like(u_true)
x_true[0] = 0; x_true[-1] = 1
x_true[1:-1] = x_true_

f, ax = plt.subplots(1, 2, figsize = (10, 5))
ax[0].plot(x_test, r_hat_preds[-1], label = "pred")
ax[1].plot(x_true, u_true, label = "true")
for i in range(2):
	ax[i].legend()
	ax[i].grid()
plt.show()

In [None]:
# from matplotlib import animation
# %matplotlib notebook

# fig, ax = plt.subplots(2, 3, figsize = (15, 10))
# lines = []
# for i in range(2):
# 	for j in range(3):
# 		line, = ax[i][j].plot([], [], lw = 1.5, label = "pred")
# 		lines.append(line)
# 		ax[i][j].set_xlim([-1, 1])
# 		ax[i][j].set_ylim([-5, 5])
# 		ax[i][j].legend()
# 		ax[i][j].grid()
# ax[0][0].set_ylim([0.9, 2.1])
# ax[1][0].set_ylim([-0.1, 1.0])
    
# def init():
# 	for line in lines:
# 		line.set_data([], [])
# 	return lines

# def animate(i):
# 	u_pred, v_pred = u_preds[i], v_preds[i]
# 	du_dx_pred, dv_dx_pred = du_dx_preds[i], dv_dx_preds[i]
# 	du_dt_pred, dv_dt_pred = du_dt_preds[i], dv_dt_preds[i]
# # 	u_true, v_true = np.real(uv_true[i, :]), np.imag(uv_true[i, :])
	
# 	lines[0].set_data(x_test, u_pred)
# 	ax[0][0].set_title("u, t = {:.4f}".format(t_test[i]))
# 	lines[1].set_data(x_test, du_dx_pred)
# 	ax[0][1].set_title("du_dx, t = {:.4f}".format(t_test[i]))
# 	lines[2].set_data(x_test, du_dt_pred)
# 	ax[0][2].set_title("du_dt, t = {:.4f}".format(t_test[i]))
    
# 	lines[3].set_data(x_test, v_pred)
# 	ax[1][0].set_title("v, t = {:.4f}".format(t_test[i]))
# 	lines[4].set_data(x_test, dv_dx_pred)
# 	ax[1][1].set_title("dv_dx, t = {:.4f}".format(t_test[i]))
# 	lines[5].set_data(x_test, dv_dt_pred)
# 	ax[1][2].set_title("dv_dt, t = {:.4f}".format(t_test[i]))

# 	return lines

# anim = animation.FuncAnimation(fig, animate, frames = len(t_test), interval = 50, blit = True)
# plt.show()

In [None]:
# from scipy.io import loadmat
# data = loadmat("epsilon_0.49.mat")
# x_true, u_true, v_true = data["x"], data["u"], data["v"]

# f, ax = plt.subplots(1, 2, figsize = (12, 5))
# ax[0].plot(x_test, u_preds[-1], label = "pred")
# ax[0].plot(x_true, u_true, label = "true")
# ax[0].set_title("u, t = {:.2f}".format(t_test[-1]))
# ax[1].plot(x_test, v_preds[-1], label = "pred")
# ax[1].plot(x_true, v_true, label = "true")
# ax[1].set_title("v, t = {:.2f}".format(t_test[-1]))
# for i in range(2):
# 	ax[i].legend()
# 	ax[i].grid()
# plt.show()

In [None]:
# x = jnp.linspace(*domain[:, 0], 10000).reshape((-1, 1))
# t = jnp.zeros_like(x)
# xt = jnp.hstack([x, t])


# # direct_params_ = direct_params
# direct_params_ = get_params(opt_state)
# duv_dxt = jacobian(direct_params_, xt)
# du_dt, dv_dt = duv_dxt[:, 0:1, 1], duv_dxt[:, 1:2, 1]
# du_dx, dv_dx = duv_dxt[:, 0:1, 0], duv_dxt[:, 1:2, 0]
# duv_dxxtt = hessian(direct_params_, xt)
# du_dxx, dv_dxx = duv_dxxtt[:, 0:1, 0, 0], duv_dxxtt[:, 1:2, 0, 0] 
# uv = model(direct_params_, xt)
# u, v = uv[:, 0:1], uv[:, 1:2]
# loss_c1 = epsilon*du_dt + 0.5*epsilon**2*dv_dxx - V*v
# loss_c2 = epsilon*dv_dt - 0.5*epsilon**2*du_dxx + V*u

# du0_dx, dv0_dx = du0_dx_fn(xt), dv0_dx_fn(xt)
# du0_dxx, dv0_dxx = du0_dxx_fn(xt), dv0_dxx_fn(xt)
# u0, v0 = u0_fn(xt[:, 0:1], xt[:, 1:2]), v0_fn(xt[:, 0:1], xt[:, 1:2])
# du0_dt = 1.0/epsilon*(V*v0 - epsilon**2/2*dv0_dxx)
# dv0_dt = 1.0/epsilon*(epsilon**2/2*du0_dxx - V*u0)
# loss_c10 = epsilon*du0_dt + 0.5*epsilon**2*dv0_dxx - V*v0
# loss_c20 = epsilon*dv0_dt - 0.5*epsilon**2*du0_dxx + V*u0

# %matplotlib inline
# plt.rcParams.update(plt.rcParamsDefault)
# plt.rcParams["text.usetex"] = True

# f, ax = plt.subplots(2, 5, figsize = (20, 10))
# i, j = 0, 0
# ax[i][j].plot(x, du_dt, label = "pred")
# ax[i][j].plot(x, du0_dt, label = "true")
# ax[i][j].set_title(r"$\frac{\partial u}{\partial t}$")
# i = 1
# ax[i][j].plot(x, dv_dt, label = "pred")
# ax[i][j].plot(x, dv0_dt, label = "true")
# ax[i][j].set_title(r"$\frac{\partial v}{\partial t}$")
# i, j = 0, j+1
# ax[i][j].plot(x, du_dx, label = "pred")
# ax[i][j].plot(x, du0_dx, label = "true")
# ax[i][j].set_title(r"$\frac{\partial u}{\partial x}$")
# i = 1
# ax[i][j].plot(x, dv_dx, label = "pred")
# ax[i][j].plot(x, dv0_dx, label = "true")
# ax[i][j].set_title(r"$\frac{\partial v}{\partial x}$")
# i, j = 0, j+1
# ax[i][j].plot(x, du_dxx, label = "pred")
# ax[i][j].plot(x, du0_dxx, label = "true")
# ax[i][j].set_title(r"$\frac{\partial^2 u}{\partial x^2}$")
# i = 1
# ax[i][j].plot(x, dv_dxx, label = "pred")
# ax[i][j].plot(x, dv0_dxx, label = "true")
# ax[i][j].set_title(r"$\frac{\partial^2 v}{\partial x^2}$")
# i, j = 0, j+1
# ax[i][j].plot(x, u, label = "pred")
# ax[i][j].plot(x, u0, label = "true")
# ax[i][j].set_title(r"$u$")
# i = 1
# ax[i][j].plot(x, v, label = "pred")
# ax[i][j].plot(x, v0, label = "true")
# ax[i][j].set_title(r"$v$")
# i, j = 0, j+1
# ax[i][j].plot(x, loss_c1, label = "pred")
# ax[i][j].plot(x, loss_c10, label = "true")
# ax[i][j].set_title(r"loss c1")
# i = 1
# ax[i][j].plot(x, loss_c2, label = "pred")
# ax[i][j].plot(x, loss_c20, label = "true")
# ax[i][j].set_title(r"loss c2")

# for i in range(2):
# 	for j in range(5):
# 		ax[i][j].legend()
# 		ax[i][j].grid()
# plt.show()

In [None]:
# Display the current trained parameters.
get_params(opt_state)

In [None]:
# NOTE(review): du_dxx is only assigned in the commented-out diagnostics cell, so this
# raises NameError unless it is defined elsewhere.
du_dxx

In [None]:
# NOTE(review): du0_dx_fn is not defined in any active cell here (only dr0_dx_fn appears,
# commented out), so this raises NameError unless it is defined elsewhere.
du0_dx_fn(jnp.array([[-0.25, 0.0], [0.25, 0.0]]))