# Schrodinger equation
---

$$
\begin{aligned}
\varepsilon u_t - i\frac{\epsilon^2}{2}\Delta u + i V(x)u = 0, \ t \in \mathbb{R}, x\in \mathbb{R}^d, \\
u(x, t = 0) = u_0(x),
\end{aligned}
$$
where $V$ is a given electrostatic potential, $0 < \varepsilon \ll 1$.

---

Example 2 in Shi Jin's paper:
$$
u(x, 0) = \sqrt{n_0(x)}e^{iS_0(x)/\varepsilon},
$$
(3.9-3.10)
$$
n_0(x) = (e^{-25(x-0.5)^2})^2, \ S_0(x) = 0.2(x^2 - x).
$$

Periodic BC.

---

Consider the real and imag part of $u$, i.e., 
$$
u(x, t) = p(x, t) + iq(x, t),
$$
then
$$
\begin{aligned}
&\varepsilon p_t+\frac{\varepsilon^2}{2}q_{xx} - V(x)q = 0, \\
&\varepsilon q_t-\frac{\varepsilon^2}{2}p_{xx} + V(x)p = 0,
\end{aligned}
$$
with ic
$$
p(x, 0) = e^{-25(x-0.5)^2}\cos(\frac{0.2(x^2-x)}{\varepsilon}), \ q(x, 0) = e^{-25(x-0.5)^2}\sin(\frac{0.2(x^2-x)}{\varepsilon}),
$$
and bc
$$
p(0, t) = p(1, t), q(0, t) = q(1, t), p_x(0, t) = p_x(1, t), q_x(0, t) = q_x(1, t).
$$

In [2]:
NAME = "1_siren"

In [3]:
import jax, jax.nn
from jax import random
import jax.numpy as jnp
from jax.experimental import optimizers

import sys, os
sys.path.append("../../")
	
from Seismic_wave_inversion_PINN.data_utils import *
from Seismic_wave_inversion_PINN.jax_model import *

from collections import namedtuple

In [4]:
key = random.PRNGKey(0)
key, subkey = random.split(key, 2)

layers = [2] + [128]*4 + [2] # (x, t) -> (u, v)
c0 = 1.0
w0 = 10.0
w1 = 1.0
lambda_0 = 1e-8
direct_params = init_siren_params(subkey, layers, c0, w0, w1)

domain = jnp.array([[0., 0.], [1., 1.]])
epsilon = 1.0
V = 1.0

@jax.jit
def model(params, xt):
	# first, normalize to [-1, 1]
	xt = 2.0*(xt - domain[0, :])/(domain[1, :]-domain[0, :]) - 1.0
	for w, b in params[:-1]:
		xt = jnp.sin(jnp.dot(xt, w) + b)
	return jnp.dot(xt, params[-1][0]) + params[-1][1]

jacobian = jacrev_fn(model)
hessian = hessian_fn(model)

In [8]:
metaloss = mse

jit_conservation = lambda i: jax.partial(jax.jit, static_argnums = i)

@jax.jit
def loss_fn_(params, batch):
	collocation, dirichlet, periodic_bc = batch["collocation"], batch["dirichlet"], batch["periodic_bc"]
	direct_params = params
	
	uv_c = model(direct_params, jnp.hstack([collocation.x, collocation.t]))
	u_c, v_c = uv_c[:, 0:1], uv_c[:, 1:2]
	
	# jacobian[i] = [[du/dx, du/dt],
	#                [dv/dx, dv/dt]]
	# i: the i^th input
	duv_dxt_c = jacobian(direct_params, jnp.hstack([collocation.x, collocation.t]))
	du_dt_c, dv_dt_c = duv_dxt_c[:, 0:1, 1], duv_dxt_c[:, 1:2, 1]
	
	# hessian[i] = [
    #				[[du/dxx, du/dxy],
	#                [du/dxy, du/dyy]],
	#               [[dv/dxx, dv/dxy],
	#                [dv/dxy, dv/dyy]]
	#              ]
	duv_dxxtt_c = hessian(direct_params, jnp.hstack([collocation.x, collocation.t]))
	du_dxx_c, dv_dxx_c = duv_dxxtt_c[:, 0:1, 0, 0], duv_dxxtt_c[:, 1:2, 0, 0] 
	
	uv_l = model(direct_params, jnp.hstack([periodic_bc.l, periodic_bc.t]))
	uv_r = model(direct_params, jnp.hstack([periodic_bc.r, periodic_bc.t]))
	u_l, v_l = uv_l[:, 0:1], uv_l[:, 1:2]
	u_r, v_r = uv_r[:, 0:1], uv_r[:, 1:2]
	
	duv_dxt_l = jacobian(direct_params, jnp.hstack([periodic_bc.l, periodic_bc.t]))
	duv_dxt_r = jacobian(direct_params, jnp.hstack([periodic_bc.r, periodic_bc.t]))
	du_dx_l, dv_dx_l = duv_dxt_l[:, 0:1, 0], duv_dxt_l[:, 1:2, 0]
	du_dx_r, dv_dx_r = duv_dxt_r[:, 0:1, 0], duv_dxt_r[:, 1:2, 0]
		
	uv_d = model(direct_params, jnp.hstack([dirichlet.x, dirichlet.t]))
	u_d, v_d = uv_d[:, 0:1], uv_d[:, 1:2]

	
	loss_c1 = metaloss(du_dt_c + 0.5*dv_dxx_c + (u_c**2 + v_c**2)*v_c, 0)
	loss_c2 = metaloss(-dv_dt_c + 0.5*du_dxx_c + (u_c**2 + v_c**2)*u_c, 0)
	loss_c = loss_c1 + loss_c2
	
	loss_d1 = metaloss(u_d, dirichlet.u)
	loss_d2 = metaloss(v_d, dirichlet.v)
	loss_d = loss_d1 + loss_d2
	
	loss_pbc_d1 = metaloss(u_l, u_r)
	loss_pbc_d2 = metaloss(v_l, v_r)
	loss_pbc_n1 = metaloss(du_dx_l, du_dx_r)
	loss_pbc_n2 = metaloss(dv_dx_l, dv_dx_r)
	loss_pbc_d = loss_pbc_d1 + loss_pbc_d2
	loss_pbc_n = loss_pbc_n1 + loss_pbc_n2
	
	return loss_c, loss_d, loss_pbc_d, loss_pbc_n


@jax.jit
def loss_fn(params, batch):
	w = batch["weights"]
	loss_c, loss_d, loss_pbc_d, loss_pbc_n = loss_fn_(params, batch)
	return w["c"]*loss_c + w["d"]*loss_d + w["pbc_d"]*loss_pbc_d + w["pbc_n"]*loss_pbc_n\
			+ l2_regularization(params, lambda_0)

@jax.jit
def step(i, opt_state, batch):
	params = get_params(opt_state)
	grad = jax.grad(loss_fn, 0)(params, batch)
	return opt_update(i, grad, opt_state)

@jax.jit
def evaluate(params, batch):
	w = batch["weights"]
	loss_c, loss_d, loss_pbc_d, loss_pbc_n = loss_fn_(params, batch)
	return w["c"]*loss_c + w["d"]*loss_d + w["pbc_d"]*loss_pbc_d + w["pbc_n"]*loss_pbc_n, \
			loss_c, loss_d, loss_pbc_d, loss_pbc_n

In [9]:
key = random.PRNGKey(1)
key, *subkeys = random.split(key, 4)

from scipy.io import loadmat
data = loadmat("../19_2020-08-01_Schrodinger/NLS.mat")
x, t, h = data["x"].reshape((-1, 1)), data["tt"].reshape((-1, 1)), data["uu"].T
u, v = np.real(h), np.imag(h)

import pickle
with open("../19_2020-08-01_Schrodinger/dataset_train.pkl", "rb") as file:
	[ix_i, ix_b, xt_c] = pickle.load(file)
	
# # ic
# n_i = 50
# ix_i = random.choice(subkeys[0], jnp.arange(len(x)), shape = (n_i, ), replace = False)
x_i = x[ix_i, :]
t_i = np.zeros_like(x_i)
u_i = u[0, ix_i].reshape((-1, 1))
v_i = v[0, ix_i].reshape((-1, 1))

# # bc
# n_b = 50
# ix_b = random.choice(subkeys[1], jnp.arange(len(t)), shape = (n_b, ), replace = False)
t_b = t[ix_b, :]
x_lb = np.ones_like(t_b)*domain[0, 0]
x_rb = np.ones_like(t_b)*domain[1, 0]

# n_c = 20000
# from pyDOE import lhs
# xt_c = lhs(2, n_c)
# x_c = transform(xt_c[:, 0:1], *domain[:, 0])
# t_c = transform(xt_c[:, 1:2], *domain[:, 1])
x_c, t_c = xt_c[:, 0:1], xt_c[:, 1:2]

dataset_Dirichlet = namedtuple("dataset_Dirichlet", ["x", "t", "u", "v"])
dataset_Collocation = namedtuple("dataset_Collocation", ["x", "t"])
dataset_BC = namedtuple("dataset_BC", ["l", "r", "t"])

map_to_jnp_array = lambda x: map(lambda arr: jnp.array(arr), x)
dirichlet = dataset_Dirichlet(*map_to_jnp_array([x_i, t_i, u_i, v_i]))
periodic_bc = dataset_BC(*map_to_jnp_array([x_lb, x_rb, t_b]))
collocation = dataset_Collocation(*map_to_jnp_array([jnp.vstack([dirichlet.x, periodic_bc.l, periodic_bc.r, x_c]),
													jnp.vstack([dirichlet.t, periodic_bc.t, periodic_bc.t, t_c])]))

In [None]:
lr = 1e-3
start_iteration = 0
iterations = 50000
print_every = 100
save_every = 50000
batch_size = {"dirichlet": 50, "bc": 50, "collocation": 20150}
weights = {"c": 1.0, "d": 1.0, "pbc_d": 1.0, "pbc_n": 1.0}

key, *subkeys = random.split(key, 4)
Dirichlet = Batch_Generator(subkeys[0], dirichlet, batch_size["dirichlet"])
Collocation = Batch_Generator(subkeys[1], collocation, batch_size["collocation"])
BC = Batch_Generator(subkeys[2], periodic_bc, batch_size["bc"])
params = direct_params

opt_init, opt_update, get_params = optimizers.adam(lr)
opt_state = opt_init(params)
hist = {"iter": [], "loss": []}

for iteration in range(start_iteration, start_iteration+iterations+1):
	batch = {
		"dirichlet": dataset_Dirichlet(*next(Dirichlet)),
		"collocation": dataset_Collocation(*next(Collocation)),
		"periodic_bc": dataset_BC(*next(BC)),
		"weights": weights
	}
	opt_state = step(iteration, opt_state, batch)
	if (iteration-start_iteration) % print_every == 0:
		names = ["Loss", "c", "d"]
		params_ = get_params(opt_state)
		losses = evaluate(params_, batch)
		print("{}, Iteration: {}, Train".format(get_time(), iteration) + \
			  ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(names, losses)]))
		hist["iter"].append(iteration)
		hist["loss"].append(losses[0])
	if (iteration-start_iteration) % save_every == 0:
		params_ = np.asarray(get_params(opt_state), dtype = object)
		save_path = "models/{}/iteration_{}/params.npy".format(NAME, iteration)
		if not os.path.exists(os.path.dirname(save_path)):
			os.makedirs(os.path.dirname(save_path))
		np.save(save_path, params_)

2020/08/05, 21:33:35, Iteration: 0, Train Loss: 7.7047e+00, c: 6.5884e+00, d: 9.1159e-01
2020/08/05, 21:33:38, Iteration: 100, Train Loss: 4.8824e-01, c: 5.2886e-02, d: 4.3439e-01
2020/08/05, 21:33:41, Iteration: 200, Train Loss: 4.0056e-01, c: 8.5346e-02, d: 3.1340e-01
2020/08/05, 21:33:44, Iteration: 300, Train Loss: 3.3547e-01, c: 1.0476e-01, d: 2.2678e-01
2020/08/05, 21:33:47, Iteration: 400, Train Loss: 2.8152e-01, c: 1.0403e-01, d: 1.7393e-01
2020/08/05, 21:33:50, Iteration: 500, Train Loss: 2.2077e-01, c: 8.8546e-02, d: 1.2978e-01
2020/08/05, 21:33:53, Iteration: 600, Train Loss: 1.8433e-01, c: 7.5391e-02, d: 1.0669e-01
2020/08/05, 21:33:56, Iteration: 700, Train Loss: 1.7345e-01, c: 7.7588e-02, d: 9.2238e-02
2020/08/05, 21:33:59, Iteration: 800, Train Loss: 1.4465e-01, c: 6.1974e-02, d: 8.0547e-02
2020/08/05, 21:34:02, Iteration: 900, Train Loss: 1.2834e-01, c: 5.5304e-02, d: 7.1116e-02
2020/08/05, 21:34:05, Iteration: 1000, Train Loss: 1.1908e-01, c: 5.4208e-02, d: 6.2036e-02


2020/08/05, 21:38:19, Iteration: 9000, Train Loss: 4.1784e-02, c: 3.0615e-02, d: 7.1857e-03
2020/08/05, 21:38:22, Iteration: 9100, Train Loss: 1.8888e-02, c: 1.0308e-02, d: 6.8724e-03
2020/08/05, 21:38:26, Iteration: 9200, Train Loss: 1.7482e-02, c: 9.2861e-03, d: 6.5467e-03
2020/08/05, 21:38:29, Iteration: 9300, Train Loss: 1.9045e-02, c: 1.0876e-02, d: 6.5625e-03
2020/08/05, 21:38:32, Iteration: 9400, Train Loss: 1.6744e-02, c: 9.0249e-03, d: 6.1469e-03
2020/08/05, 21:38:35, Iteration: 9500, Train Loss: 1.6277e-02, c: 8.8272e-03, d: 5.8897e-03
2020/08/05, 21:38:38, Iteration: 9600, Train Loss: 1.9430e-02, c: 1.2092e-02, d: 5.7015e-03
2020/08/05, 21:38:41, Iteration: 9700, Train Loss: 2.2622e-02, c: 1.3817e-02, d: 5.7620e-03
2020/08/05, 21:38:45, Iteration: 9800, Train Loss: 1.5240e-02, c: 8.4172e-03, d: 5.3892e-03
2020/08/05, 21:38:48, Iteration: 9900, Train Loss: 2.5133e-02, c: 1.8097e-02, d: 5.3900e-03
2020/08/05, 21:38:51, Iteration: 10000, Train Loss: 1.7306e-02, c: 1.0606e-02, d

2020/08/05, 21:43:00, Iteration: 17900, Train Loss: 3.4001e-03, c: 2.0007e-03, d: 1.1799e-03
2020/08/05, 21:43:03, Iteration: 18000, Train Loss: 4.6434e-03, c: 3.2186e-03, d: 1.1554e-03
2020/08/05, 21:43:07, Iteration: 18100, Train Loss: 3.2011e-03, c: 1.8680e-03, d: 1.1284e-03
2020/08/05, 21:43:10, Iteration: 18200, Train Loss: 4.9940e-03, c: 3.6156e-03, d: 1.1414e-03
2020/08/05, 21:43:13, Iteration: 18300, Train Loss: 3.1080e-03, c: 1.8052e-03, d: 1.1026e-03
2020/08/05, 21:43:16, Iteration: 18400, Train Loss: 5.2856e-03, c: 3.8284e-03, d: 1.0890e-03
2020/08/05, 21:43:19, Iteration: 18500, Train Loss: 4.7053e-03, c: 3.3578e-03, d: 1.0745e-03
2020/08/05, 21:43:22, Iteration: 18600, Train Loss: 3.1938e-03, c: 1.9284e-03, d: 1.0572e-03
2020/08/05, 21:43:25, Iteration: 18700, Train Loss: 3.1854e-03, c: 1.9496e-03, d: 1.0401e-03
2020/08/05, 21:43:28, Iteration: 18800, Train Loss: 3.4986e-03, c: 2.2628e-03, d: 1.0320e-03
2020/08/05, 21:43:32, Iteration: 18900, Train Loss: 5.9522e-03, c: 4.6

In [None]:
lr = 1e-4
start_iteration += iterations
iterations = 10000
print_every = 200
save_every = 5000
batch_size = {"dirichlet": 10000, "bc": 1000, "collocation": 10000}
weights = {"c": 1e-3, "d": 100.0, "pbc_d": 1.0, "pbc_n": 1.0, "conservation": 1e-6, "energy": 0}

# key, *subkeys = random.split(key, 4)
# Dirichlet = Batch_Generator(subkeys[0], dirichlet, batch_size["dirichlet"])
# Collocation = Batch_Generator(subkeys[1], collocation, batch_size["collocation"])
# BC = Batch_Generator(subkeys[2], periodic_bc, batch_size["bc"])
# params = direct_params

# opt_init, opt_update, get_params = optimizers.adam(lr)
# opt_state = opt_init(params)
# hist = {"iter": [], "loss": []}

for iteration in range(start_iteration, start_iteration+iterations+1):
	batch = {
		"dirichlet": dataset_Dirichlet(*next(Dirichlet)),
		"collocation": dataset_Collocation(*next(Collocation)),
		"periodic_bc": dataset_BC(*next(BC)),
		"weights": weights
	}
	opt_state = step(iteration, opt_state, batch, conservation)
	if (iteration-start_iteration) % print_every == 0:
		names = ["Loss", "c", "d", "pbc_d", "pbc_n", "conservation", "energy"]
		params_ = get_params(opt_state)
		losses = evaluate(params_, batch, conservation)
		print("{}, Iteration: {}, Train".format(get_time(), iteration) + \
			  ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(names, losses)]))
		hist["iter"].append(iteration)
		hist["loss"].append(losses)
	if (iteration-start_iteration) % save_every == 0:
		params_ = np.asarray(get_params(opt_state), dtype = object)
		save_path = "models/{}/iteration_{}/params.npy".format(NAME, iteration)
		if not os.path.exists(os.path.dirname(save_path)):
			os.makedirs(os.path.dirname(save_path))
		np.save(save_path, params_)

In [None]:
lr = 1e-4
start_iteration += iterations
iterations = 10000
print_every = 200
save_every = 5000
batch_size = {"dirichlet": 10000, "bc": 1000, "collocation": 10000}
weights = {"c": 1e-3, "d": 100.0, "pbc_d": 1.0, "pbc_n": 1.0, "conservation": 1e-6, "energy": 1e-8}

# key, *subkeys = random.split(key, 4)
# Dirichlet = Batch_Generator(subkeys[0], dirichlet, batch_size["dirichlet"])
# Collocation = Batch_Generator(subkeys[1], collocation, batch_size["collocation"])
# BC = Batch_Generator(subkeys[2], periodic_bc, batch_size["bc"])
# params = direct_params

# opt_init, opt_update, get_params = optimizers.adam(lr)
# opt_state = opt_init(params)
# hist = {"iter": [], "loss": []}

for iteration in range(start_iteration, start_iteration+iterations+1):
	batch = {
		"dirichlet": dataset_Dirichlet(*next(Dirichlet)),
		"collocation": dataset_Collocation(*next(Collocation)),
		"periodic_bc": dataset_BC(*next(BC)),
		"weights": weights
	}
	opt_state = step(iteration, opt_state, batch, conservation)
	if (iteration-start_iteration) % print_every == 0:
		names = ["Loss", "c", "d", "pbc_d", "pbc_n", "conservation", "energy"]
		params_ = get_params(opt_state)
		losses = evaluate(params_, batch, conservation)
		print("{}, Iteration: {}, Train".format(get_time(), iteration) + \
			  ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(names, losses)]))
		hist["iter"].append(iteration)
		hist["loss"].append(losses)
	if (iteration-start_iteration) % save_every == 0:
		params_ = np.asarray(get_params(opt_state), dtype = object)
		save_path = "models/{}/iteration_{}/params.npy".format(NAME, iteration)
		if not os.path.exists(os.path.dirname(save_path)):
			os.makedirs(os.path.dirname(save_path))
		np.save(save_path, params_)

In [None]:
lr = 1e-4
start_iteration += iterations
iterations = 20000
print_every = 200
save_every = 5000
batch_size = {"dirichlet": 10000, "bc": 1000, "collocation": 10000}
weights = {"c": 1e-3, "d": 100.0, "pbc_d": 1.0, "pbc_n": 1.0, "conservation": 1e-5, "energy": 1e-8}

# key, *subkeys = random.split(key, 4)
# Dirichlet = Batch_Generator(subkeys[0], dirichlet, batch_size["dirichlet"])
# Collocation = Batch_Generator(subkeys[1], collocation, batch_size["collocation"])
# BC = Batch_Generator(subkeys[2], periodic_bc, batch_size["bc"])
# params = direct_params

# opt_init, opt_update, get_params = optimizers.adam(lr)
# opt_state = opt_init(params)
# hist = {"iter": [], "loss": []}

for iteration in range(start_iteration, start_iteration+iterations+1):
	batch = {
		"dirichlet": dataset_Dirichlet(*next(Dirichlet)),
		"collocation": dataset_Collocation(*next(Collocation)),
		"periodic_bc": dataset_BC(*next(BC)),
		"weights": weights
	}
	opt_state = step(iteration, opt_state, batch, conservation)
	if (iteration-start_iteration) % print_every == 0:
		names = ["Loss", "c", "d", "pbc_d", "pbc_n", "conservation", "energy"]
		params_ = get_params(opt_state)
		losses = evaluate(params_, batch, conservation)
		print("{}, Iteration: {}, Train".format(get_time(), iteration) + \
			  ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(names, losses)]))
		hist["iter"].append(iteration)
		hist["loss"].append(losses)
	if (iteration-start_iteration) % save_every == 0:
		params_ = np.asarray(get_params(opt_state), dtype = object)
		save_path = "models/{}/iteration_{}/params.npy".format(NAME, iteration)
		if not os.path.exists(os.path.dirname(save_path)):
			os.makedirs(os.path.dirname(save_path))
		np.save(save_path, params_)

In [None]:
lr = 1e-4
start_iteration += iterations
iterations = 20000
print_every = 200
save_every = 5000
batch_size = {"dirichlet": 10000, "bc": 1000, "collocation": 10000}
weights = {"c": 1e-3, "d": 100.0, "pbc_d": 1.0, "pbc_n": 1.0, "conservation": 1e-3, "energy": 1e-7}

# key, *subkeys = random.split(key, 4)
# Dirichlet = Batch_Generator(subkeys[0], dirichlet, batch_size["dirichlet"])
# Collocation = Batch_Generator(subkeys[1], collocation, batch_size["collocation"])
# BC = Batch_Generator(subkeys[2], periodic_bc, batch_size["bc"])
# params = direct_params

# opt_init, opt_update, get_params = optimizers.adam(lr)
# opt_state = opt_init(params)
# hist = {"iter": [], "loss": []}

for iteration in range(start_iteration, start_iteration+iterations+1):
	batch = {
		"dirichlet": dataset_Dirichlet(*next(Dirichlet)),
		"collocation": dataset_Collocation(*next(Collocation)),
		"periodic_bc": dataset_BC(*next(BC)),
		"weights": weights
	}
	opt_state = step(iteration, opt_state, batch, conservation)
	if (iteration-start_iteration) % print_every == 0:
		names = ["Loss", "c", "d", "pbc_d", "pbc_n", "conservation", "energy"]
		params_ = get_params(opt_state)
		losses = evaluate(params_, batch, conservation)
		print("{}, Iteration: {}, Train".format(get_time(), iteration) + \
			  ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(names, losses)]))
		hist["iter"].append(iteration)
		hist["loss"].append(losses)
	if (iteration-start_iteration) % save_every == 0:
		params_ = np.asarray(get_params(opt_state), dtype = object)
		save_path = "models/{}/iteration_{}/params.npy".format(NAME, iteration)
		if not os.path.exists(os.path.dirname(save_path)):
			os.makedirs(os.path.dirname(save_path))
		np.save(save_path, params_)

In [None]:
lr = 1e-4
start_iteration += iterations
iterations = 50000
print_every = 200
save_every = 5000
batch_size = {"dirichlet": 10000, "bc": 1000, "collocation": 10000}
weights = {"c": 1e-1, "d": 100.0, "pbc_d": 1.0, "pbc_n": 1.0, "conservation": 1.0, "energy": 1e-5}

# key, *subkeys = random.split(key, 4)
# Dirichlet = Batch_Generator(subkeys[0], dirichlet, batch_size["dirichlet"])
# Collocation = Batch_Generator(subkeys[1], collocation, batch_size["collocation"])
# BC = Batch_Generator(subkeys[2], periodic_bc, batch_size["bc"])
# params = direct_params

# opt_init, opt_update, get_params = optimizers.adam(lr)
# opt_state = opt_init(params)
# hist = {"iter": [], "loss": []}

for iteration in range(start_iteration, start_iteration+iterations+1):
	batch = {
		"dirichlet": dataset_Dirichlet(*next(Dirichlet)),
		"collocation": dataset_Collocation(*next(Collocation)),
		"periodic_bc": dataset_BC(*next(BC)),
		"weights": weights
	}
	opt_state = step(iteration, opt_state, batch, conservation)
	if (iteration-start_iteration) % print_every == 0:
		names = ["Loss", "c", "d", "pbc_d", "pbc_n", "conservation", "energy"]
		params_ = get_params(opt_state)
		losses = evaluate(params_, batch, conservation)
		print("{}, Iteration: {}, Train".format(get_time(), iteration) + \
			  ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(names, losses)]))
		hist["iter"].append(iteration)
		hist["loss"].append(losses)
	if (iteration-start_iteration) % save_every == 0:
		params_ = np.asarray(get_params(opt_state), dtype = object)
		save_path = "models/{}/iteration_{}/params.npy".format(NAME, iteration)
		if not os.path.exists(os.path.dirname(save_path)):
			os.makedirs(os.path.dirname(save_path))
		np.save(save_path, params_)

In [None]:
lr = 1e-4
start_iteration += iterations
iterations = 50000
print_every = 200
save_every = 5000
batch_size = {"dirichlet": 10000, "bc": 1000, "collocation": 10000}
weights = {"c": 1.0, "d": 1e4, "pbc_d": 1e2, "pbc_n": 1e3, "conservation": 1e2, "energy": 1e-3}

# key, *subkeys = random.split(key, 4)
# Dirichlet = Batch_Generator(subkeys[0], dirichlet, batch_size["dirichlet"])
# Collocation = Batch_Generator(subkeys[1], collocation, batch_size["collocation"])
# BC = Batch_Generator(subkeys[2], periodic_bc, batch_size["bc"])
# params = direct_params

# opt_init, opt_update, get_params = optimizers.adam(lr)
# opt_state = opt_init(params)
# hist = {"iter": [], "loss": []}

for iteration in range(start_iteration, start_iteration+iterations+1):
	batch = {
		"dirichlet": dataset_Dirichlet(*next(Dirichlet)),
		"collocation": dataset_Collocation(*next(Collocation)),
		"periodic_bc": dataset_BC(*next(BC)),
		"weights": weights
	}
	opt_state = step(iteration, opt_state, batch, conservation)
	if (iteration-start_iteration) % print_every == 0:
		names = ["Loss", "c", "d", "pbc_d", "pbc_n", "conservation", "energy"]
		params_ = get_params(opt_state)
		losses = evaluate(params_, batch, conservation)
		print("{}, Iteration: {}, Train".format(get_time(), iteration) + \
			  ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(names, losses)]))
		hist["iter"].append(iteration)
		hist["loss"].append(losses)
	if (iteration-start_iteration) % save_every == 0:
		params_ = np.asarray(get_params(opt_state), dtype = object)
		save_path = "models/{}/iteration_{}/params.npy".format(NAME, iteration)
		if not os.path.exists(os.path.dirname(save_path)):
			os.makedirs(os.path.dirname(save_path))
		np.save(save_path, params_)

In [None]:
lr = 1e-4
start_iteration += iterations
iterations = 50000
print_every = 200
save_every = 5000
batch_size = {"dirichlet": 10000, "bc": 1000, "collocation": 10000}
weights = {"c": 1.0, "d": 1e4, "pbc_d": 1e2, "pbc_n": 1e3, "conservation": 1e2, "energy": 1e-2}

# key, *subkeys = random.split(key, 4)
# Dirichlet = Batch_Generator(subkeys[0], dirichlet, batch_size["dirichlet"])
# Collocation = Batch_Generator(subkeys[1], collocation, batch_size["collocation"])
# BC = Batch_Generator(subkeys[2], periodic_bc, batch_size["bc"])
# params = direct_params

# opt_init, opt_update, get_params = optimizers.adam(lr)
# opt_state = opt_init(params)
# hist = {"iter": [], "loss": []}

for iteration in range(start_iteration, start_iteration+iterations+1):
	batch = {
		"dirichlet": dataset_Dirichlet(*next(Dirichlet)),
		"collocation": dataset_Collocation(*next(Collocation)),
		"periodic_bc": dataset_BC(*next(BC)),
		"weights": weights
	}
	opt_state = step(iteration, opt_state, batch, conservation)
	if (iteration-start_iteration) % print_every == 0:
		names = ["Loss", "c", "d", "pbc_d", "pbc_n", "conservation", "energy"]
		params_ = get_params(opt_state)
		losses = evaluate(params_, batch, conservation)
		print("{}, Iteration: {}, Train".format(get_time(), iteration) + \
			  ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(names, losses)]))
		hist["iter"].append(iteration)
		hist["loss"].append(losses)
	if (iteration-start_iteration) % save_every == 0:
		params_ = np.asarray(get_params(opt_state), dtype = object)
		save_path = "models/{}/iteration_{}/params.npy".format(NAME, iteration)
		if not os.path.exists(os.path.dirname(save_path)):
			os.makedirs(os.path.dirname(save_path))
		np.save(save_path, params_)

In [None]:
from scipy.io import loadmat

uv_true = loadmat("epsilon_1.0.mat")["u"].T

from matplotlib import animation
%matplotlib notebook

fig, ax = plt.subplots(1, 3, figsize = (15, 5))
lines = []
for i in range(3):
    line1, = ax[i].plot([], [], lw = 1.5, label = "true")
    line2, = ax[i].plot([], [], lw = 1.5, label = "pred")
    lines.extend([line1, line2])
    ax[i].set_xlim([0, 1])
    ax[i].set_ylim([-1, 1])
    ax[i].legend()
    ax[i].grid()
    
def init():
	for line in lines:
		line.set_data([], [])
	return lines

params_ = get_params(opt_state)

x_test = jnp.linspace(*domain[:, 0], 64)
t_test = jnp.linspace(*domain[:, 1], 101)
xt_tests = [tensor_grid([x_test, ti]) for ti in t_test]
uv_preds = [model(params_, xt_test) for xt_test in xt_tests]
u_preds, v_preds = [uv_pred[:, 0:1] for uv_pred in uv_preds], [uv_pred[:, 1:2] for uv_pred in uv_preds]

def animate(i):
	u_pred, v_pred = u_preds[i], v_preds[i]
	u_true, v_true = np.real(uv_true[i, :]), np.imag(uv_true[i, :])
	
	lines[0].set_data(x_test, u_true)
	lines[1].set_data(x_test, u_pred)
	ax[0].set_title("u, t = {:.2f}".format(t_test[i]))
    
	lines[2].set_data(x_test, v_true)
	lines[3].set_data(x_test, v_pred)
	ax[1].set_title("v, t = {:.2f}".format(t_test[i]))

	lines[4].set_data(x_test, np.sqrt(u_true**2+v_true**2))
	lines[5].set_data(x_test, np.sqrt(u_pred**2+v_pred**2))
	ax[2].set_title("|h|, t = {:.2f}".format(t_test[i]))

	return lines

anim = animation.FuncAnimation(fig, animate, frames = len(t_test), interval = 1000, blit = True)
plt.show()

In [None]:
uv_true.shape