# Deterministic Linear Transport equation
---
Consider the equation
$$
\left\{
\begin{aligned}
&\frac{\partial r}{\partial t} + v\frac{\partial j}{\partial x} = \frac{\sigma(x, z)}{\epsilon^2}(\hat{r} - r), \\
&\frac{\partial j}{\partial t} + \frac{v}{\epsilon^2}\frac{\partial r}{\partial x} = - \frac{\sigma(x, z)}{\epsilon^2}j, 
\end{aligned}
\right.
$$
where $\epsilon$ is a small number, $v \in [-1, 1]$, and $\hat{r} = \int_{0}^{1} rdv$. 

It seems to me that $v$ follows a uniform distribution (as Gauss-Legendre quadrature is used to compute the integral).

We let $\sigma(x, z) \equiv 1$.

The initial data are
$$
r(x, v, z, t = 0) = j(x, v, z, t = 0) = 0.
$$

BC: for all $v, x, t$,
$$
\sigma j = -v r_{x}, 
$$
and for all $v$, 

---
# (3.12)
$$
r + \epsilon j{\Large \bracevert}_{x = 0} = 1, \quad r - \epsilon j{\Large \bracevert}_{x = 1} = 0.
$$

The spatiotemporal domain is 
$$
(x, t, v) \in [0, 1]\times [0, 0.01] \times [0, 1].
$$

In [1]:
NAME = "2_1e-8_2"

In [2]:
import jax, jax.nn
from jax import random
import jax.numpy as jnp
from jax.experimental import optimizers
from jax.ops import index, index_add, index_update


import sys, os
sys.path.append("../../")
	
from Seismic_wave_inversion_PINN.data_utils import *
from Seismic_wave_inversion_PINN.jax_model import *

from collections import namedtuple
# from jax.config import config; config.update("jax_enable_x64", True)
# dtype = jnp.float64
dtype = jnp.float32

In [3]:
key = random.PRNGKey(1)
key, subkey = random.split(key, 2)

layers = [3] + [32]*4 + [2] # (x, t, v) -> (r, j)
c0 = 1.0
w0 = jnp.array([[1.0, 1.0, 1.0]]).T
w1 = jnp.array([[1.0, 1.0]]) # (w_r, w_j)
direct_params = init_siren_params(subkey, layers, c0, w0, w1, dtype)

domain = jnp.array([[0., 0., 0.], [1., 0.01, 1.0]])

sigma = 1.0
epsilon = 1.0

@jax.jit
def model(params, xtv): # for predictions
	# linear scaling
	xtv = (2*xtv - (domain[0, :]+domain[1, :]))/(domain[1, :] - domain[0, :])
	for w, b in params[:-1]:
		xtv = jnp.sin(jnp.dot(xtv, w) + b)
	return jnp.dot(xtv, params[-1][0]) + params[-1][1]

# @jax.jit
# def model_(params, xt): # for derivatives
# 	for w, b in params[:-1]:
# 		xt = jnp.sin(jnp.dot(xt, w) + b)
# 	return jnp.dot(xt, params[-1][0]) + params[-1][1]

jacobian = jacrev_fn(model)
hessian = hessian_fn(model)

In [4]:
metaloss = mae

static_jit = lambda i: jax.partial(jax.jit, static_argnums = i)

@jax.jit
def quadrature(params, x, t, v, w):
	xt_ = jnp.repeat(jnp.hstack([x, t]), w.shape[0], axis = 0)
	v_ = jnp.tile(v, (x.shape[0], 1))
	rj = model(params, jnp.hstack([xt_, v_]))
	r = rj[:, 0:1].reshape((x.shape[0], w.shape[0]))
	return jnp.dot(r, w)

# jacobian[i] = [[dr/dx, dr/dt, dr/dv],
#                [dj/dx, dj/dt, dj/dv]]
# i: the i^th input

# hessian[i] = [
#				[[du/dxx, du/dxy],
#                [du/dxy, du/dyy]],
#               [[dv/dxx, dv/dxy],
#                [dv/dxy, dv/dyy]]
#              ]
@jax.jit
def loss_fn_(params, batch):
	collocation, dirichlet, bc, quad = batch["collocation"], batch["dirichlet"], batch["bc"], batch["quad"]
	direct_params = params
	
	if collocation[0] is not None:
		rj_c = model(direct_params, jnp.hstack([collocation.x, collocation.t, collocation.v]))
		r_c, j_c = rj_c[:, 0:1], rj_c[:, 1:2]
		drj_dxtv_c = jacobian(direct_params, jnp.hstack([collocation.x, collocation.t, collocation.v]))
		dr_dt_c, dj_dt_c = drj_dxtv_c[:, 0:1, 1], drj_dxtv_c[:, 1:2, 1]
		dr_dx_c, dj_dx_c = drj_dxtv_c[:, 0:1, 0], drj_dxtv_c[:, 1:2, 0]
		
		# quad.w: [q, 1]
		# quad.v: [q, 1]
		r_hat_c = quadrature(direct_params, collocation.x, collocation.t, quad.v, quad.w)
		
		loss_c1 = metaloss(epsilon**2*(dr_dt_c + collocation.v*dj_dx_c), sigma*(r_hat_c - r_c))
		loss_c2 = metaloss(epsilon**2*dj_dt_c + collocation.v*dr_dx_c, -sigma*j_c)
	else:
		loss_c1 = loss_c2 = 0
        
	if dirichlet[0] is not None:
		rj_d = model(direct_params, jnp.hstack([dirichlet.x, dirichlet.t, dirichlet.v]))
		r_d, j_d = rj_d[:, 0:1], rj_d[:, 1:2]
		loss_d1 = metaloss(r_d, dirichlet.r)
		loss_d2 = metaloss(j_d, dirichlet.j)
		loss_d = loss_d1 + loss_d2
	else:
		loss_d = 0.0
		
	if bc[0] is not None:
		rj_bl = model(direct_params, jnp.hstack([bc.l, bc.t, bc.v]))
		rj_br = model(direct_params, jnp.hstack([bc.r, bc.t, bc.v]))
		r_bl, j_bl = rj_bl[:, 0:1], rj_bl[:, 1:2]
		r_br, j_br = rj_br[:, 0:1], rj_br[:, 1:2]
		loss_b1 = 0
		loss_b2 = metaloss(r_bl + epsilon*j_bl, 1) + metaloss(r_br + epsilon*j_br, 0)

	return loss_c1, loss_c2, loss_d, loss_b1, loss_b2

@jax.jit
def loss_fn(params, batch):
	w = batch["weights"]
	loss_c1, loss_c2, loss_d, loss_b1, loss_b2 = loss_fn_(params, batch)
	return w["c1"]*loss_c1 + w["c2"]*loss_c2 + w["d"]*loss_d + w["b1"]*loss_b1 + w["b2"]*loss_b2 + \
			l1_regularization(params, w["l1"]) + l2_regularization(params, w["l2"])

@jax.jit
def step(i, opt_state, batch):
	params = get_params(opt_state)
	grad = jax.grad(loss_fn, 0)(params, batch)
	return opt_update(i, grad, opt_state)

@jax.jit
def evaluate(params, batch):
	w = batch["weights"]
	loss_c1, loss_c2, loss_d, loss_b1, loss_b2 = loss_fn_(params, batch)
	l1 = l1_regularization(params, 1.0)
	l2 = l2_regularization(params, 1.0)
	return w["c1"]*loss_c1 + w["c2"]*loss_c2 + w["d"]*loss_d + w["b1"]*loss_b1 + w["b2"]*loss_b2, \
			loss_c1, loss_c2, loss_d, loss_b1, loss_b2, l1, l2

In [5]:
# r0_fn = lambda x, t, v: jnp.zeros_like(x)
r0_fn = lambda x, t, v: jnp.select([jnp.isclose(x, domain[0, 0]), x > 0], [1.0, 0.0])
j0_fn = lambda x, t, v: jnp.zeros_like(x)

# r0_fn_ = lambda xtv: jnp.zeros_like(xtv[0])
# j0_fn_ = lambda xtv: jnp.zeros_like(xtv[0])

# dr0_dt_fn = lambda xtv: jax.vmap(jax.jacfwd(r0_fn_), in_axes = 0)(xtv)[:, 1:2]
# dj0_dt_fn = lambda xtv: jax.vmap(jax.jacfwd(j0_fn_), in_axes = 0)(xtv)[:, 1:2]

# dr0_dx_fn = lambda xtv: jax.vmap(jax.jacfwd(r0_fn_), in_axes = 0)(xtv)[:, 0:1]
dj0_dx_fn = lambda xtv: jax.vmap(jax.jacfwd(j0_fn_), in_axes = 0)(xtv)[:, 0:1]

key, *subkeys = random.split(key, 3)

n_quad = 16
v_quad, w_quad = np.polynomial.legendre.leggauss(n_quad)
v_quad = jnp.array(0.5*(v_quad+1), dtype = jnp.float32).reshape((-1, 1))
w_quad = jnp.array(0.5*w_quad, dtype = jnp.float32).reshape((-1, 1))

n_i = 200
x_i = jnp.linspace(*domain[:, 0], n_i)
# x_i = jnp.linspace(domain[0, 0]+(domain[1, 0]-domain[0, 0])/n_i, domain[1, 0]-(domain[1, 0]-domain[0, 0])/n_i, n_i)
# v_i = jnp.linspace(*domain[:, 2], n_i)
v_i = v_quad
xv_i = tensor_grid([x_i, v_i])
x_i, v_i = xv_i[:, 0:1], xv_i[:, 1:2]
t_i = jnp.zeros_like(x_i)
r_i = r0_fn(x_i, t_i, v_i)
j_i = j0_fn(x_i, t_i, v_i)

n_b = 100
t_b = jnp.linspace(*domain[:, 1], n_b)
# t_b = jnp.linspace(0.001, domain[1, 1], n_b)
# v_b = jnp.linspace(*domain[:, 2], n_b)
v_b = v_quad
tv_b = tensor_grid([t_b, v_b])
t_b, v_b = tv_b[:, 0:1], tv_b[:, 1:2]
x_bl = jnp.ones_like(t_b)*domain[0, 0]
x_br = jnp.ones_like(t_b)*domain[1, 0]

n_cx = 201
n_ct = 100
# n_cv = 20
x_c = jnp.linspace(*domain[:, 0], n_cx).reshape((-1, 1))
t_c = jnp.linspace(*domain[:, 1], n_ct).reshape((-1, 1))
# v_c = jnp.linspace(*domain[:, 2], n_cv).reshape((-1, 1))
v_c = v_quad
xtv_c = tensor_grid([x_c, t_c, v_c])

dataset_Dirichlet = namedtuple("dataset_Dirichlet", ["x", "t", "v", "r", "j"])
dataset_Collocation = namedtuple("dataset_Collocation", ["x", "t", "v"])
dataset_Quad = namedtuple("dataset_Quad", ["w", "v"])
dataset_BC = namedtuple("dataset_BC", ["l", "r", "t", "v"])

dirichlet = dataset_Dirichlet(x_i, t_i, v_i, r_i, j_i)
collocation = dataset_Collocation(xtv_c[:, 0:1], xtv_c[:, 1:2], xtv_c[:, 2:3])
quad = dataset_Quad(w_quad, v_quad)
bc = dataset_BC(x_bl, x_br, t_b, v_b)

In [6]:
# class Time_Marching_Generator:
# 	def __init__(self, key, spatial_points, temporal_domain, batch_size, iterations, update_every, count1 = 0):
# 		self.key = key
# 		self.spatial_points = spatial_points
# 		self.domain = temporal_domain
# 		self.batch_size = batch_size
# 		self.iterations = iterations
# 		self._count1 = count1
# 		self._count2 = update_every
# 		if count1 < iterations:
# 			self._update(self.domain[0])
# 		else:
# 			self._update(self.domain[1])
# 		self.update_every = update_every
		
# 	def _update(self, tmax):
# 		self.key, subkey = random.split(self.key)
# 		self._t = random.uniform(key, (self.batch_size, 1), jnp.float32, self.domain[0], tmax)
		
# 	def __iter__(self):
# 		return self
	
# 	def __next__(self):
# 		if self._count2 == self.update_every:
# 			self._count1 = max(self.iterations, self._count1 + 1)
# 			tmax = self.domain[0] + (self.domain[1]-self.domain[0])*self._count1/self.iterations
# 			self._update(tmax)
# 			self._count2 = 0
# 		else:
# 			self._count2 += 1
# 		return self.spatial_points, self._t

In [7]:
lr = 1e-3
params = direct_params
opt_init, opt_update, get_params = optimizers.adam(lr)
opt_state = opt_init(params)
hist = {"iter": [], "loss": []}

batch_size = {"dirichlet": 3200, "collocation": 20100, "bc": 1600}
key, *subkeys = random.split(key, 5)
Dirichlet = Batch_Generator(subkeys[0], dirichlet, batch_size["dirichlet"])
Collocation = Batch_Generator(subkeys[1], collocation, batch_size["collocation"])
BC = Batch_Generator(subkeys[2], bc, batch_size["bc"])

start_iteration = 0
iterations = 10000
print_every = 200
save_every = 10000
weights = {"c1": 1.0, "c2": 10, "d": 100, "b1": 100, "b2": 100, "l1": 1e-8, "l2": 1e-8}

for iteration in range(start_iteration, start_iteration+iterations+1):
	d = next(Dirichlet)
	b = next(BC)
	c = next(Collocation)
	batch = {
		"dirichlet": dataset_Dirichlet(*d),
		"bc": dataset_BC(*b),
		"collocation": dataset_Collocation(jnp.vstack([d[0], c[0], b[0], b[1]]), jnp.vstack([d[1], c[1], b[2], b[2]]), jnp.vstack([d[2], c[2], b[3], b[3]])),
		"quad": quad,
		"weights": weights,
	}
	opt_state = step(iteration, opt_state, batch)
	if (iteration-start_iteration) % print_every == 0:
		names = ["Loss", "c1", "c2", "d", "b1", "b2", "l1_reg", "l2_reg"]
		params_ = get_params(opt_state)
		batch = {
			"dirichlet": dataset_Dirichlet(*Dirichlet.dataset),
			"bc": dataset_BC(*BC.dataset),
			"collocation": batch["collocation"],
			"quad": quad,
			"weights": weights
		}
		losses = evaluate(params_, batch)
		print("{}, Iteration: {}, Train".format(get_time(), iteration) + \
			  ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(names, losses)]))
		hist["iter"].append(iteration)
		hist["loss"].append(losses)
	if (iteration-start_iteration) % save_every == 0:
		params_ = np.asarray(get_params(opt_state), dtype = object)
		save_path = "models/{}/iteration_{}/params.npy".format(NAME, iteration)
		if not os.path.exists(os.path.dirname(save_path)):
			os.makedirs(os.path.dirname(save_path))
		np.save(save_path, params_)

2020/08/27, 17:06:39, Iteration: 0, Train Loss: 2.1554e+02, c1: 4.0824e+00, c2: 1.0070e+01, d: 1.2824e-01, b1: 0.0000e+00, b2: 9.7936e-01, l1_reg: 3.0213e+02, l2_reg: 4.2694e+01
2020/08/27, 17:06:43, Iteration: 200, Train Loss: 5.4635e+01, c1: 1.6626e+00, c2: 2.6109e-01, d: 4.8070e-01, b1: 0.0000e+00, b2: 2.2919e-02, l1_reg: 3.0382e+02, l2_reg: 4.3103e+01
2020/08/27, 17:06:47, Iteration: 400, Train Loss: 4.9350e+01, c1: 4.2662e-01, c2: 6.8214e-02, d: 4.6977e-01, b1: 0.0000e+00, b2: 1.2638e-02, l1_reg: 3.0407e+02, l2_reg: 4.3300e+01
2020/08/27, 17:06:51, Iteration: 600, Train Loss: 3.6225e+01, c1: 1.1627e+00, c2: 1.5698e-01, d: 3.1773e-01, b1: 0.0000e+00, b2: 1.7199e-02, l1_reg: 3.0789e+02, l2_reg: 4.4610e+01
2020/08/27, 17:06:56, Iteration: 800, Train Loss: 2.8898e+01, c1: 9.2891e-01, c2: 7.2902e-02, d: 2.6177e-01, b1: 0.0000e+00, b2: 1.0624e-02, l1_reg: 3.0897e+02, l2_reg: 4.4744e+01
2020/08/27, 17:07:00, Iteration: 1000, Train Loss: 2.6374e+01, c1: 5.3438e-01, c2: 1.1299e-01, d: 2.39

2020/08/27, 17:09:53, Iteration: 9200, Train Loss: 7.5235e+00, c1: 5.7138e-01, c2: 6.0765e-02, d: 5.1461e-02, b1: 0.0000e+00, b2: 1.1984e-02, l1_reg: 3.4182e+02, l2_reg: 5.9407e+01
2020/08/27, 17:09:57, Iteration: 9400, Train Loss: 6.8162e+00, c1: 5.7445e-01, c2: 6.6108e-02, d: 4.9939e-02, b1: 0.0000e+00, b2: 5.8679e-03, l1_reg: 3.4233e+02, l2_reg: 5.9644e+01
2020/08/27, 17:10:02, Iteration: 9600, Train Loss: 7.2018e+00, c1: 5.9922e-01, c2: 5.7615e-02, d: 4.9067e-02, b1: 0.0000e+00, b2: 1.1197e-02, l1_reg: 3.4285e+02, l2_reg: 5.9891e+01
2020/08/27, 17:10:06, Iteration: 9800, Train Loss: 9.1683e+00, c1: 6.0804e-01, c2: 5.7112e-02, d: 5.7089e-02, b1: 0.0000e+00, b2: 2.2802e-02, l1_reg: 3.4326e+02, l2_reg: 6.0130e+01
2020/08/27, 17:10:10, Iteration: 10000, Train Loss: 7.4246e+00, c1: 5.7446e-01, c2: 5.7611e-02, d: 4.9444e-02, b1: 0.0000e+00, b2: 1.3296e-02, l1_reg: 3.4366e+02, l2_reg: 6.0366e+01


In [11]:
start_iteration = 10000
iterations = 20000
print_every = 200
save_every = 10000
weights = {"c1": 10.0, "c2": 10, "d": 100, "b1": 100, "b2": 100, "l1": 1e-8, "l2": 1e-8}

for iteration in range(start_iteration, start_iteration+iterations+1):
	d = next(Dirichlet)
	b = next(BC)
	c = next(Collocation)
	batch = {
		"dirichlet": dataset_Dirichlet(*d),
		"bc": dataset_BC(*b),
		"collocation": dataset_Collocation(jnp.vstack([d[0], c[0], b[0], b[1]]), jnp.vstack([d[1], c[1], b[2], b[2]]), jnp.vstack([d[2], c[2], b[3], b[3]])),
		"quad": quad,
		"weights": weights,
	}
	opt_state = step(iteration, opt_state, batch)
	if (iteration-start_iteration) % print_every == 0:
		names = ["Loss", "c1", "c2", "d", "b1", "b2", "l1_reg", "l2_reg"]
		params_ = get_params(opt_state)
		batch = {
			"dirichlet": dataset_Dirichlet(*Dirichlet.dataset),
			"bc": dataset_BC(*BC.dataset),
			"collocation": batch["collocation"],
			"quad": quad,
			"weights": weights
		}
		losses = evaluate(params_, batch)
		print("{}, Iteration: {}, Train".format(get_time(), iteration) + \
			  ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(names, losses)]))
		hist["iter"].append(iteration)
		hist["loss"].append(losses)
	if (iteration-start_iteration) % save_every == 0:
		params_ = np.asarray(get_params(opt_state), dtype = object)
		save_path = "models/{}/iteration_{}/params.npy".format(NAME, iteration)
		if not os.path.exists(os.path.dirname(save_path)):
			os.makedirs(os.path.dirname(save_path))
		np.save(save_path, params_)

2020/08/27, 17:18:28, Iteration: 10000, Train Loss: 1.3587e+01, c1: 6.6099e-01, c2: 7.2962e-02, d: 5.0553e-02, b1: 0.0000e+00, b2: 1.1919e-02, l1_reg: 3.4365e+02, l2_reg: 6.0365e+01
2020/08/27, 17:18:32, Iteration: 10200, Train Loss: 9.5100e+00, c1: 2.1481e-01, c2: 5.2352e-02, d: 4.7535e-02, b1: 0.0000e+00, b2: 2.0850e-02, l1_reg: 3.4275e+02, l2_reg: 5.9993e+01
2020/08/27, 17:18:35, Iteration: 10400, Train Loss: 9.3099e+00, c1: 2.1003e-01, c2: 6.6598e-02, d: 4.6277e-02, b1: 0.0000e+00, b2: 1.9160e-02, l1_reg: 3.4258e+02, l2_reg: 5.9938e+01
2020/08/27, 17:18:39, Iteration: 10600, Train Loss: 8.6057e+00, c1: 1.5398e-01, c2: 5.8733e-02, d: 4.5366e-02, b1: 0.0000e+00, b2: 1.9420e-02, l1_reg: 3.4250e+02, l2_reg: 5.9936e+01
2020/08/27, 17:18:43, Iteration: 10800, Train Loss: 8.1970e+00, c1: 1.2197e-01, c2: 5.2177e-02, d: 4.4211e-02, b1: 0.0000e+00, b2: 2.0344e-02, l1_reg: 3.4245e+02, l2_reg: 5.9965e+01
2020/08/27, 17:18:47, Iteration: 11000, Train Loss: 7.9695e+00, c1: 9.9778e-02, c2: 4.8099

2020/08/27, 17:21:38, Iteration: 19200, Train Loss: 3.8003e+00, c1: 1.4776e-02, c2: 1.4256e-02, d: 3.1727e-02, b1: 0.0000e+00, b2: 3.3726e-03, l1_reg: 3.5123e+02, l2_reg: 6.8289e+01
2020/08/27, 17:21:42, Iteration: 19400, Train Loss: 3.7924e+00, c1: 1.5886e-02, c2: 1.3713e-02, d: 3.0878e-02, b1: 0.0000e+00, b2: 4.0868e-03, l1_reg: 3.5200e+02, l2_reg: 6.8765e+01
2020/08/27, 17:21:46, Iteration: 19600, Train Loss: 3.9806e+00, c1: 2.5240e-02, c2: 2.9195e-02, d: 3.0942e-02, b1: 0.0000e+00, b2: 3.4197e-03, l1_reg: 3.5280e+02, l2_reg: 6.9263e+01
2020/08/27, 17:21:51, Iteration: 19800, Train Loss: 3.9300e+00, c1: 2.5666e-02, c2: 2.6289e-02, d: 2.9793e-02, b1: 0.0000e+00, b2: 4.3121e-03, l1_reg: 3.5356e+02, l2_reg: 6.9752e+01
2020/08/27, 17:21:55, Iteration: 20000, Train Loss: 3.7991e+00, c1: 1.8971e-02, c2: 2.0880e-02, d: 2.8497e-02, b1: 0.0000e+00, b2: 5.5091e-03, l1_reg: 3.5410e+02, l2_reg: 7.0182e+01
2020/08/27, 17:21:59, Iteration: 20200, Train Loss: 3.7377e+00, c1: 2.1572e-02, c2: 2.3267

2020/08/27, 17:24:52, Iteration: 28400, Train Loss: 3.0183e+00, c1: 3.1485e-02, c2: 3.1165e-02, d: 1.8209e-02, b1: 0.0000e+00, b2: 5.7085e-03, l1_reg: 3.6853e+02, l2_reg: 8.4022e+01
2020/08/27, 17:24:56, Iteration: 28600, Train Loss: 2.8341e+00, c1: 2.2839e-02, c2: 2.2710e-02, d: 1.7935e-02, b1: 0.0000e+00, b2: 5.8505e-03, l1_reg: 3.6867e+02, l2_reg: 8.4314e+01
2020/08/27, 17:25:00, Iteration: 28800, Train Loss: 2.9198e+00, c1: 2.5386e-02, c2: 2.4371e-02, d: 1.8014e-02, b1: 0.0000e+00, b2: 6.2086e-03, l1_reg: 3.6870e+02, l2_reg: 8.4595e+01
2020/08/27, 17:25:04, Iteration: 29000, Train Loss: 2.4940e+00, c1: 1.1328e-02, c2: 9.0876e-03, d: 1.7489e-02, b1: 0.0000e+00, b2: 5.4098e-03, l1_reg: 3.6882e+02, l2_reg: 8.4917e+01
2020/08/27, 17:25:09, Iteration: 29200, Train Loss: 2.7973e+00, c1: 1.5095e-02, c2: 1.4357e-02, d: 1.7877e-02, b1: 0.0000e+00, b2: 7.1508e-03, l1_reg: 3.6909e+02, l2_reg: 8.5283e+01
2020/08/27, 17:25:13, Iteration: 29400, Train Loss: 2.4283e+00, c1: 6.3940e-03, c2: 7.4715

In [15]:
params_ = get_params(opt_state)

x_test = jnp.linspace(*domain[:, 0], 200)
t_test = jnp.linspace(*domain[:, 1], 100)
xt_tests = [tensor_grid([x_test, ti]) for ti in t_test]

r_hat_preds = [quadrature(params_, xt[:, 0:1], xt[:, 1:2], quad.v, quad.w) for xt in xt_tests]
# uv_preds = [model(params_, xt_test) for xt_test in xt_tests]
# u_preds, v_preds = [uv_pred[:, 0:1] for uv_pred in uv_preds], [uv_pred[:, 1:2] for uv_pred in uv_preds]
# duv_dxt_preds = [jacobian(params_, xt_test) for xt_test in xt_tests]
# du_dx_preds, dv_dx_preds = [duv_dxt_pred[:, 0:1, 0] for duv_dxt_pred in duv_dxt_preds], [duv_dxt_pred[:, 1:2, 0] for duv_dxt_pred in duv_dxt_preds]
# du_dt_preds, dv_dt_preds = [duv_dxt_pred[:, 0:1, 1] for duv_dxt_pred in duv_dxt_preds], [duv_dxt_pred[:, 1:2, 1] for duv_dxt_pred in duv_dxt_preds]

ERROR:tornado.application:Exception in callback <bound method TimerBase._on_timer of <matplotlib.backends.backend_webagg_core.TimerTornado object at 0x7ff5e85a59d0>>
Traceback (most recent call last):
  File "/usr/lib/python3.8/site-packages/tornado/ioloop.py", line 907, in _run
    return self.callback()
  File "/usr/lib/python3.8/site-packages/matplotlib/backend_bases.py", line 1194, in _on_timer
    ret = func(*args, **kwargs)
  File "/usr/lib/python3.8/site-packages/matplotlib/animation.py", line 1420, in _step
    self._init_draw()
  File "/usr/lib/python3.8/site-packages/matplotlib/animation.py", line 1695, in _init_draw
    self._draw_frame(next(self.new_frame_seq()))
  File "/usr/lib/python3.8/site-packages/matplotlib/animation.py", line 1718, in _draw_frame
    self._drawn_artists = self._func(framedata, *self._args)
  File "<ipython-input-13-d85ab7517393>", line 21, in animate
    ax.set_title("r, t = {:.4f}".format(t_test[i]))
AttributeError: 'numpy.ndarray' object has no at

In [18]:
from matplotlib import animation
%matplotlib notebook

fig, ax = plt.subplots(1, 1, figsize = (5, 5))
lines = []
line, = ax.plot([], [], lw = 1.5, label = "pred")
lines.append(line)
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.legend()
ax.grid()
    
def init():
	for line in lines:
		line.set_data([], [])
	return lines

def animate(i):
	r_hat_pred = r_hat_preds[i]
	lines[0].set_data(x_test, r_hat_pred)
	ax.set_title("r, t = {:.4f}".format(t_test[i]))
	return lines

anim = animation.FuncAnimation(fig, animate, frames = len(t_test), interval = 100, blit = True)
plt.show()

<IPython.core.display.Javascript object>

In [17]:
from scipy.io import loadmat
data = loadmat("epsilon_1e-16.mat")
x_true_, u_true = data["x"], data["u"]
x_true = np.zeros_like(u_true)
x_true[0] = 0; x_true[-1] = 1
x_true[1:-1] = x_true_

f, ax = plt.subplots(1, 2, figsize = (10, 5))
ax[0].plot(x_test, r_hat_preds[-1], label = "pred")
ax[1].plot(x_true, u_true, label = "true")
for i in range(2):
	ax[i].legend()
	ax[i].grid()
plt.show()

<IPython.core.display.Javascript object>

In [None]:
# from matplotlib import animation
# %matplotlib notebook

# fig, ax = plt.subplots(2, 3, figsize = (15, 10))
# lines = []
# for i in range(2):
# 	for j in range(3):
# 		line, = ax[i][j].plot([], [], lw = 1.5, label = "pred")
# 		lines.append(line)
# 		ax[i][j].set_xlim([-1, 1])
# 		ax[i][j].set_ylim([-5, 5])
# 		ax[i][j].legend()
# 		ax[i][j].grid()
# ax[0][0].set_ylim([0.9, 2.1])
# ax[1][0].set_ylim([-0.1, 1.0])
    
# def init():
# 	for line in lines:
# 		line.set_data([], [])
# 	return lines

# def animate(i):
# 	u_pred, v_pred = u_preds[i], v_preds[i]
# 	du_dx_pred, dv_dx_pred = du_dx_preds[i], dv_dx_preds[i]
# 	du_dt_pred, dv_dt_pred = du_dt_preds[i], dv_dt_preds[i]
# # 	u_true, v_true = np.real(uv_true[i, :]), np.imag(uv_true[i, :])
	
# 	lines[0].set_data(x_test, u_pred)
# 	ax[0][0].set_title("u, t = {:.4f}".format(t_test[i]))
# 	lines[1].set_data(x_test, du_dx_pred)
# 	ax[0][1].set_title("du_dx, t = {:.4f}".format(t_test[i]))
# 	lines[2].set_data(x_test, du_dt_pred)
# 	ax[0][2].set_title("du_dt, t = {:.4f}".format(t_test[i]))
    
# 	lines[3].set_data(x_test, v_pred)
# 	ax[1][0].set_title("v, t = {:.4f}".format(t_test[i]))
# 	lines[4].set_data(x_test, dv_dx_pred)
# 	ax[1][1].set_title("dv_dx, t = {:.4f}".format(t_test[i]))
# 	lines[5].set_data(x_test, dv_dt_pred)
# 	ax[1][2].set_title("dv_dt, t = {:.4f}".format(t_test[i]))

# 	return lines

# anim = animation.FuncAnimation(fig, animate, frames = len(t_test), interval = 50, blit = True)
# plt.show()

In [None]:
# from scipy.io import loadmat
# data = loadmat("epsilon_0.49.mat")
# x_true, u_true, v_true = data["x"], data["u"], data["v"]

# f, ax = plt.subplots(1, 2, figsize = (12, 5))
# ax[0].plot(x_test, u_preds[-1], label = "pred")
# ax[0].plot(x_true, u_true, label = "true")
# ax[0].set_title("u, t = {:.2f}".format(t_test[-1]))
# ax[1].plot(x_test, v_preds[-1], label = "pred")
# ax[1].plot(x_true, v_true, label = "true")
# ax[1].set_title("v, t = {:.2f}".format(t_test[-1]))
# for i in range(2):
# 	ax[i].legend()
# 	ax[i].grid()
# plt.show()

In [None]:
# x = jnp.linspace(*domain[:, 0], 10000).reshape((-1, 1))
# t = jnp.zeros_like(x)
# xt = jnp.hstack([x, t])


# # direct_params_ = direct_params
# direct_params_ = get_params(opt_state)
# duv_dxt = jacobian(direct_params_, xt)
# du_dt, dv_dt = duv_dxt[:, 0:1, 1], duv_dxt[:, 1:2, 1]
# du_dx, dv_dx = duv_dxt[:, 0:1, 0], duv_dxt[:, 1:2, 0]
# duv_dxxtt = hessian(direct_params_, xt)
# du_dxx, dv_dxx = duv_dxxtt[:, 0:1, 0, 0], duv_dxxtt[:, 1:2, 0, 0] 
# uv = model(direct_params_, xt)
# u, v = uv[:, 0:1], uv[:, 1:2]
# loss_c1 = epsilon*du_dt + 0.5*epsilon**2*dv_dxx - V*v
# loss_c2 = epsilon*dv_dt - 0.5*epsilon**2*du_dxx + V*u

# du0_dx, dv0_dx = du0_dx_fn(xt), dv0_dx_fn(xt)
# du0_dxx, dv0_dxx = du0_dxx_fn(xt), dv0_dxx_fn(xt)
# u0, v0 = u0_fn(xt[:, 0:1], xt[:, 1:2]), v0_fn(xt[:, 0:1], xt[:, 1:2])
# du0_dt = 1.0/epsilon*(V*v0 - epsilon**2/2*dv0_dxx)
# dv0_dt = 1.0/epsilon*(epsilon**2/2*du0_dxx - V*u0)
# loss_c10 = epsilon*du0_dt + 0.5*epsilon**2*dv0_dxx - V*v0
# loss_c20 = epsilon*dv0_dt - 0.5*epsilon**2*du0_dxx + V*u0

# %matplotlib inline
# plt.rcParams.update(plt.rcParamsDefault)
# plt.rcParams["text.usetex"] = True

# f, ax = plt.subplots(2, 5, figsize = (20, 10))
# i, j = 0, 0
# ax[i][j].plot(x, du_dt, label = "pred")
# ax[i][j].plot(x, du0_dt, label = "true")
# ax[i][j].set_title(r"$\frac{\partial u}{\partial t}$")
# i = 1
# ax[i][j].plot(x, dv_dt, label = "pred")
# ax[i][j].plot(x, dv0_dt, label = "true")
# ax[i][j].set_title(r"$\frac{\partial v}{\partial t}$")
# i, j = 0, j+1
# ax[i][j].plot(x, du_dx, label = "pred")
# ax[i][j].plot(x, du0_dx, label = "true")
# ax[i][j].set_title(r"$\frac{\partial u}{\partial x}$")
# i = 1
# ax[i][j].plot(x, dv_dx, label = "pred")
# ax[i][j].plot(x, dv0_dx, label = "true")
# ax[i][j].set_title(r"$\frac{\partial v}{\partial x}$")
# i, j = 0, j+1
# ax[i][j].plot(x, du_dxx, label = "pred")
# ax[i][j].plot(x, du0_dxx, label = "true")
# ax[i][j].set_title(r"$\frac{\partial^2 u}{\partial x^2}$")
# i = 1
# ax[i][j].plot(x, dv_dxx, label = "pred")
# ax[i][j].plot(x, dv0_dxx, label = "true")
# ax[i][j].set_title(r"$\frac{\partial^2 v}{\partial x^2}$")
# i, j = 0, j+1
# ax[i][j].plot(x, u, label = "pred")
# ax[i][j].plot(x, u0, label = "true")
# ax[i][j].set_title(r"$u$")
# i = 1
# ax[i][j].plot(x, v, label = "pred")
# ax[i][j].plot(x, v0, label = "true")
# ax[i][j].set_title(r"$v$")
# i, j = 0, j+1
# ax[i][j].plot(x, loss_c1, label = "pred")
# ax[i][j].plot(x, loss_c10, label = "true")
# ax[i][j].set_title(r"loss c1")
# i = 1
# ax[i][j].plot(x, loss_c2, label = "pred")
# ax[i][j].plot(x, loss_c20, label = "true")
# ax[i][j].set_title(r"loss c2")

# for i in range(2):
# 	for j in range(5):
# 		ax[i][j].legend()
# 		ax[i][j].grid()
# plt.show()

In [None]:
get_params(opt_state)

In [None]:
du_dxx

In [None]:
du0_dx_fn(jnp.array([[-0.25, 0.0], [0.25, 0.0]]))