# 1D Schrodinger equation
---
(Example From PINNs)

$$
\begin{aligned}
&ih_t + 0.5h_{xx} + |h|^2 h = 0, \ (x, t) \in [-5, 5]\times [0, \pi/2], \\ 
&h(0, x) = 2 \sech (x), \\
&h(t, -5) = h(t, 5), \\
&h_x(t, -5) = h_x(t, 5).
\end{aligned}
$$

Consider 
$$
h(t, x) = u(t, x) + iv(t, x), 
$$
where $u, v$ are real-valued functions. Then the equation can be given by

$$
\begin{aligned}
&u_t + 0.5v_{xx} + (u^2+v^2)v = 0, \\
&-v_t + 0.5u_{xx} + (u^2+v^2)u = 0, \\
&u(0, x) = 2\sech (x), \ v(0, x) = 0, \\
&u(t, -5) = u(t, 5), \ v(t, -5) = v(t, 5), \\
&u_x(t, -5) = u_x(t, 5), \ v_x(t, -5) = v_x(t, 5).
\end{aligned}
$$

In [1]:
NAME = "siren"

In [2]:
import jax, jax.nn
from jax import random
import jax.numpy as jnp
from jax.experimental import optimizers

import sys, os
sys.path.append("../../")
	
from Seismic_wave_inversion_PINN.data_utils import *

from collections import namedtuple

In [3]:
def siren_layer_params(key, scale, m, n):
	w_key, b_key = random.split(key)
	return random.uniform(w_key, (m, n), jnp.float32, minval = -scale, maxval = scale), jnp.zeros((n, ), jnp.float32)

def init_siren_params(key, layers, c0, w0):
	keys = random.split(key, len(layers))
	return [siren_layer_params(keys[0], w0*jnp.sqrt(c0/layers[0]), layers[0], layers[1])] + \
			[siren_layer_params(k, jnp.sqrt(c0/m), m, n) for m, n, k in zip(layers[1:-1], layers[2:], keys[1:])]

layers = [2, 128, 128, 128, 128, 128, 2] # (x, t) -> (u, v)
c0 = 1.0
w0 = 10.0
lambda_0 = 1e-10
direct_params = init_siren_params(random.PRNGKey(0), layers, c0, w0)

@jax.jit
def scalar_u_model(params, x, t):
	x_ = jnp.hstack([x, t])
	for w, b in params[:-1]:
		x_ = jnp.sin(jnp.dot(x_, w) + b)
	return jnp.sum(jnp.dot(x_, params[-1][0][:, 0:1]) + params[-1][1][0])

@jax.jit
def scalar_v_model(params, x, t):
	x_ = jnp.hstack([x, t])
	for w, b in params[:-1]:
		x_ = jnp.sin(jnp.dot(x_, w) + b)
	return jnp.sum(jnp.dot(x_, params[-1][0][:, 1:2]) + params[-1][1][1])

u_model = jax.jit(jax.vmap(scalar_u_model, in_axes = (None, 0, 0)))
v_model = jax.jit(jax.vmap(scalar_v_model, in_axes = (None, 0, 0)))



In [4]:
@jax.jit
def mse(pred, true):
	return jnp.mean(jnp.square(pred - true))

@jax.jit
def l2_regularization(params, lambda_0):
	res = 0
	for p in params:
		res += jnp.sum(jnp.square(p[0]))
	return res*lambda_0

@jax.jit
def scalar_du_dx(params, x, t):
    return jnp.sum(jax.grad(scalar_u_model, 1)(params, x, t))

@jax.jit
def scalar_du_dt(params, x, t):
    return jnp.sum(jax.grad(scalar_u_model, 2)(params, x, t))

du_dx = jax.jit(jax.vmap(scalar_du_dx, in_axes = (None, 0, 0)))
du_dt = jax.jit(jax.vmap(scalar_du_dt, in_axes = (None, 0, 0)))

@jax.jit
def du_dxx(params, x, t):
    return jax.grad(scalar_du_dx, 1)(params, x, t)

@jax.jit
def du_dtt(params, x, t):
    return jax.grad(scalar_du_dt, 2)(params, x, t)

@jax.jit
def scalar_dv_dx(params, x, t):
    return jnp.sum(jax.grad(scalar_v_model, 1)(params, x, t))

@jax.jit
def scalar_dv_dt(params, x, t):
    return jnp.sum(jax.grad(scalar_v_model, 2)(params, x, t))

dv_dx = jax.jit(jax.vmap(scalar_dv_dx, in_axes = (None, 0, 0)))
dv_dt = jax.jit(jax.vmap(scalar_dv_dt, in_axes = (None, 0, 0)))

@jax.jit
def dv_dxx(params, x, t):
    return jax.grad(scalar_dv_dx, 1)(params, x, t)

@jax.jit
def dv_dtt(params, x, t):
    return jax.grad(scalar_dv_dt, 2)(params, x, t)

@jax.jit
def loss_fn_(params, batch):
	collocation, dirichlet, periodic_bc = batch["collocation"], batch["dirichlet"], batch["periodic_bc"]
	direct_params = params
	
	u_c = u_model(direct_params, collocation.x, collocation.t)
	v_c = v_model(direct_params, collocation.x, collocation.t)
	du_dt_c = du_dt(direct_params, collocation.x, collocation.t)
	dv_dt_c = dv_dt(direct_params, collocation.x, collocation.t)
	du_dxx_c = du_dxx(direct_params, collocation.x, collocation.t)
	dv_dxx_c = dv_dxx(direct_params, collocation.x, collocation.t)
	
	u_l = u_model(direct_params, periodic_bc.l, periodic_bc.t)
	u_r = u_model(direct_params, periodic_bc.r, periodic_bc.t)
	v_l = v_model(direct_params, periodic_bc.l, periodic_bc.t)
	v_r = v_model(direct_params, periodic_bc.r, periodic_bc.t)
	du_dx_l = du_dx(direct_params, periodic_bc.l, periodic_bc.t)
	du_dx_r = du_dx(direct_params, periodic_bc.r, periodic_bc.t)
	dv_dx_l = dv_dx(direct_params, periodic_bc.l, periodic_bc.t)
	dv_dx_r = dv_dx(direct_params, periodic_bc.r, periodic_bc.t)	
		
	u_d = u_model(direct_params, dirichlet.x, dirichlet.t).reshape((-1, 1))
	v_d = v_model(direct_params, dirichlet.x, dirichlet.t).reshape((-1, 1))
	
	loss_c1 = mse(du_dt_c + 0.5*dv_dxx_c + (u_c**2 + v_c**2)*v_c, 0)
	loss_c2 = mse(-dv_dt_c + 0.5*du_dxx_c + (u_c**2 + v_c**2)*u_c, 0)
	loss_c = loss_c1 + loss_c2
	
	loss_d1 = mse(u_d, dirichlet.u)
	loss_d2 = mse(v_d, dirichlet.v)
	loss_d = loss_d1 + loss_d2
	
	loss_pbc_d1 = mse(u_l, u_r)
	loss_pbc_d2 = mse(v_l, v_r)
	loss_pbc_n1 = mse(du_dx_l, du_dx_r)
	loss_pbc_n2 = mse(dv_dx_l, dv_dx_r)
	loss_pbc_d = loss_pbc_d1 + loss_pbc_d2
	loss_pbc_n = loss_pbc_n1 + loss_pbc_n2
	
	return loss_c, loss_d, loss_pbc_d, loss_pbc_n

@jax.jit
def loss_fn(params, batch):
	w = batch["weights"]
	loss_c, loss_d, loss_pbc_d, loss_pbc_n = loss_fn_(params, batch)
	return w["c"]*loss_c + w["d"]*loss_d + w["pbc_d"]*loss_pbc_d + w["pbc_n"]*loss_pbc_n + l2_regularization(params, lambda_0)

@jax.jit
def step(i, opt_state, batch):
	params = get_params(opt_state)
	grad = jax.grad(loss_fn, 0)(params, batch)
	return opt_update(i, grad, opt_state)

@jax.jit
def evaluate(params, batch):
	w = batch["weights"]
	loss_c, loss_d, loss_pbc_d, loss_pbc_n = loss_fn_(params, batch)
	return w["c"]*loss_c + w["d"]*loss_d + w["pbc_d"]*loss_pbc_d + w["pbc_n"]*loss_pbc_n, loss_c, loss_d, loss_pbc_d, loss_pbc_n

In [5]:
domain = np.array([[-5, 5], [0, np.pi/2]])
key = random.PRNGKey(1)
key, *subkeys = random.split(key, 4)

from scipy.io import loadmat
data = loadmat("NLS.mat")
x, t, h = data["x"].reshape((-1, 1)), data["tt"].reshape((-1, 1)), data["uu"].T
u, v = np.real(h), np.imag(h)

# ic
n_i = 50
ix_i = random.choice(subkeys[0], jnp.arange(len(x)), shape = (n_i, ), replace = False)
x_i = x[ix_i, :]
t_i = np.zeros_like(x_i)
u_i = u[0, ix_i].reshape((-1, 1))
v_i = v[0, ix_i].reshape((-1, 1))

# bc
n_b = 50
ix_b = random.choice(subkeys[1], jnp.arange(len(t)), shape = (n_b, ), replace = False)
t_b = t[ix_b, :]
x_lb = np.ones_like(t_b)*domain[0, 0]
x_rb = np.ones_like(t_b)*domain[0, 1]

n_c = 20000
from pyDOE import lhs
xt_c = lhs(2, n_c)
x_c = transform(xt_c[:, 0:1], *domain[0])
t_c = transform(xt_c[:, 1:2], *domain[1])

dataset_Dirichlet = namedtuple("dataset_Dirichlet", ["x", "t", "u", "v"])
dataset_Collocation = namedtuple("dataset_Collocation", ["x", "t"])
dataset_BC = namedtuple("dataset_BC", ["l", "r", "t"])

map_to_jnp_array = lambda x: map(lambda arr: jnp.array(arr), x)
dirichlet = dataset_Dirichlet(*map_to_jnp_array([x_i, t_i, u_i, v_i]))
periodic_bc = dataset_BC(*map_to_jnp_array([x_lb, x_rb, t_b]))
collocation = dataset_Collocation(*map_to_jnp_array([jnp.vstack([dirichlet.x, periodic_bc.l, periodic_bc.r, x_c]),
													jnp.vstack([dirichlet.t, periodic_bc.t, periodic_bc.t, t_c])]))

class Batch_Generator:
	def __init__(self, key, dataset, batch_size):
		self.key = key
		self.dataset = dataset
		self.batch_size = batch_size
		self.index = jnp.arange(dataset[0].shape[0])
		self.pointer = 0
		self._shuffle()
		
	def _shuffle(self):
		key, subkey = random.split(self.key)
		self.index = random.permutation(subkey, jnp.arange(self.dataset[0].shape[0]))
		self.key = key
		
	def __iter__(self):
		return self
	
	def __next__(self):
		if self.pointer >= len(self.index):
			self._shuffle()
			self.pointer = 0
		self.pointer += self.batch_size
		index_ = self.index[self.pointer-self.batch_size:self.pointer]
		return [d[index_, :] for d in self.dataset]

In [None]:
lr = 1e-3
start_iteration = 0
iterations = 50000
print_every = 100
save_every = 50000
batch_size = {"dirichlet": 50, "bc": 50, "collocation": 20000}
weights = {"c": 1.0, "d": 1.0, "pbc_d": 1.0, "pbc_n": 1.0}

key, *subkeys = random.split(key, 4)
Dirichlet = Batch_Generator(subkeys[0], dirichlet, batch_size["dirichlet"])
Collocation = Batch_Generator(subkeys[1], collocation, batch_size["collocation"])
BC = Batch_Generator(subkeys[2], periodic_bc, batch_size["bc"])
params = direct_params

opt_init, opt_update, get_params = optimizers.adam(lr)
opt_state = opt_init(params)
hist = {"iter": [], "loss": []}

for iteration in range(start_iteration, start_iteration+iterations+1):
	batch = {
		"dirichlet": dataset_Dirichlet(*next(Dirichlet)),
		"collocation": dataset_Collocation(*next(Collocation)),
		"periodic_bc": dataset_BC(*next(BC)),
		"weights": weights
	}
	opt_state = step(iteration, opt_state, batch)
	if (iteration-start_iteration) % print_every == 0:
		names = ["Loss", "c", "d"]
		params_ = get_params(opt_state)
		losses = evaluate(params_, batch)
		print("{}, Iteration: {}, Train".format(get_time(), iteration) + \
			  ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(names, losses)]))
		hist["iter"].append(iteration)
		hist["loss"].append(losses[0])
	if (iteration-start_iteration) % save_every == 0:
		params_ = np.asarray(get_params(opt_state), dtype = object)
		save_path = "models/{}/iteration_{}/params.npy".format(NAME, iteration)
		if not os.path.exists(os.path.dirname(save_path)):
			os.makedirs(os.path.dirname(save_path))
		np.save(save_path, params_)

2020/08/02, 00:48:42, Iteration: 0, Train Loss: 9.8935e-01, c: 1.3233e-01, d: 8.0831e-01
2020/08/02, 00:51:10, Iteration: 100, Train Loss: 3.9466e-01, c: 2.1591e-02, d: 3.7161e-01
2020/08/02, 00:53:11, Iteration: 200, Train Loss: 1.1961e-01, c: 6.3821e-02, d: 5.4533e-02
2020/08/02, 00:55:14, Iteration: 300, Train Loss: 9.4450e-02, c: 5.8385e-02, d: 3.4860e-02
2020/08/02, 00:57:19, Iteration: 400, Train Loss: 8.7343e-02, c: 5.1154e-02, d: 3.5661e-02
2020/08/02, 00:59:23, Iteration: 500, Train Loss: 8.5922e-02, c: 4.7202e-02, d: 3.7869e-02
2020/08/02, 01:01:27, Iteration: 600, Train Loss: 8.2136e-02, c: 4.4397e-02, d: 3.7399e-02
2020/08/02, 01:03:32, Iteration: 700, Train Loss: 8.5936e-02, c: 4.7181e-02, d: 3.8100e-02
2020/08/02, 01:05:38, Iteration: 800, Train Loss: 9.0064e-02, c: 4.4636e-02, d: 4.4830e-02
2020/08/02, 01:07:44, Iteration: 900, Train Loss: 8.5356e-02, c: 4.7248e-02, d: 3.7415e-02
2020/08/02, 01:09:50, Iteration: 1000, Train Loss: 8.2809e-02, c: 4.6874e-02, d: 3.5266e-02


2020/08/02, 04:06:13, Iteration: 9000, Train Loss: 7.8276e-02, c: 4.2923e-02, d: 3.5144e-02
2020/08/02, 04:08:29, Iteration: 9100, Train Loss: 7.8285e-02, c: 4.2412e-02, d: 3.5702e-02
2020/08/02, 04:10:43, Iteration: 9200, Train Loss: 7.8736e-02, c: 4.4240e-02, d: 3.4446e-02
2020/08/02, 04:12:58, Iteration: 9300, Train Loss: 7.7439e-02, c: 4.4068e-02, d: 3.3280e-02
2020/08/02, 04:15:14, Iteration: 9400, Train Loss: 7.9443e-02, c: 4.3303e-02, d: 3.6057e-02
2020/08/02, 04:17:30, Iteration: 9500, Train Loss: 7.7425e-02, c: 4.3602e-02, d: 3.3725e-02
2020/08/02, 04:19:45, Iteration: 9600, Train Loss: 7.8830e-02, c: 4.1537e-02, d: 3.7173e-02
2020/08/02, 04:22:01, Iteration: 9700, Train Loss: 7.9193e-02, c: 4.7042e-02, d: 3.2004e-02
2020/08/02, 04:24:16, Iteration: 9800, Train Loss: 7.9074e-02, c: 4.4664e-02, d: 3.4373e-02
2020/08/02, 04:26:32, Iteration: 9900, Train Loss: 7.7897e-02, c: 4.3435e-02, d: 3.4400e-02
2020/08/02, 04:28:48, Iteration: 10000, Train Loss: 8.1617e-02, c: 4.5222e-02, d

2020/08/02, 07:26:13, Iteration: 17900, Train Loss: 7.7413e-02, c: 4.3996e-02, d: 3.3374e-02
2020/08/02, 07:28:28, Iteration: 18000, Train Loss: 7.8071e-02, c: 4.1859e-02, d: 3.6102e-02
2020/08/02, 07:30:42, Iteration: 18100, Train Loss: 7.8712e-02, c: 4.1251e-02, d: 3.7420e-02
2020/08/02, 07:32:57, Iteration: 18200, Train Loss: 7.8809e-02, c: 4.1750e-02, d: 3.6982e-02
2020/08/02, 07:35:10, Iteration: 18300, Train Loss: 7.8255e-02, c: 4.1018e-02, d: 3.7212e-02
2020/08/02, 07:37:25, Iteration: 18400, Train Loss: 7.7918e-02, c: 4.2372e-02, d: 3.5506e-02
2020/08/02, 07:39:39, Iteration: 18500, Train Loss: 7.8267e-02, c: 4.0952e-02, d: 3.7284e-02
2020/08/02, 07:41:54, Iteration: 18600, Train Loss: 7.7420e-02, c: 4.3765e-02, d: 3.3630e-02
2020/08/02, 07:44:08, Iteration: 18700, Train Loss: 7.8157e-02, c: 4.5003e-02, d: 3.3140e-02
2020/08/02, 07:46:23, Iteration: 18800, Train Loss: 7.7613e-02, c: 4.0997e-02, d: 3.6552e-02
2020/08/02, 07:48:37, Iteration: 18900, Train Loss: 7.7535e-02, c: 4.3