# Schrödinger equation
---

$$
\begin{aligned}
\varepsilon u_t - i\frac{\varepsilon^2}{2}\Delta u + i V(x)u = 0, \quad t \in \mathbb{R},\ x\in \mathbb{R}^d, \\
u(x, t = 0) = u_0(x),
\end{aligned}
$$
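where $V$ is a given electrostatic potential and $0 < \varepsilon \ll 1$. Multiplying by $i$ gives the familiar semiclassical form $i\varepsilon u_t = -\frac{\varepsilon^2}{2}\Delta u + V(x)u$.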

---

Example 1 in Shi Jin's paper (2008-jy-phase.pdf), eqs. (3.9)–(3.10), uses the WKB initial data
$$
u(x, 0) = \sqrt{n_0(x)}\,e^{iS_0(x)/\varepsilon},
$$
with
$$
n_0(x) = e^{-25x^2}, \quad S_0(x) = -0.2\log(2\cosh(5x)).
$$

Boundary conditions are periodic in $x$.

Computational domain: $(x, t) \in [-0.25, 0.25]\times [0, 0.5]$.

---

Split $u$ into its real and imaginary parts:
$$
u(x, t) = p(x, t) + iq(x, t),
$$
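Then, in one space dimension ($d = 1$, so $\Delta u = u_{xx}$), separating real and imaginary parts gives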
$$
\begin{aligned}
&\varepsilon p_t+\frac{\varepsilon^2}{2}q_{xx} - V(x)q = 0, \\
&\varepsilon q_t-\frac{\varepsilon^2}{2}p_{xx} + V(x)p = 0,
\end{aligned}
$$
with initial conditions (note that $\sqrt{n_0(x)} = e^{-25x^2/2}$)
$$
p(x, 0) = e^{-25x^2/2}\cos\left(-0.2\log(2\cosh(5x))/\varepsilon\right), \quad q(x, 0) = e^{-25x^2/2}\sin\left(-0.2\log(2\cosh(5x))/\varepsilon\right),
$$
and periodic boundary conditions, where $x_0 = -0.25$ and $x_1 = 0.25$:
$$
p(x_0, t) = p(x_1, t),\ q(x_0, t) = q(x_1, t),\ p_x(x_0, t) = p_x(x_1, t),\ q_x(x_0, t) = q_x(x_1, t).
$$

In [1]:
# Experiment/run identifier (presumably used to tag saved outputs — TODO confirm).
NAME = "1_2"

In [2]:
import jax, jax.nn
from jax import random
import jax.numpy as jnp
from jax.experimental import optimizers
from jax.ops import index, index_add, index_update


import sys, os
sys.path.append("../../")
	
from Seismic_wave_inversion_PINN.data_utils import *
from Seismic_wave_inversion_PINN.jax_model import *

from collections import namedtuple

In [8]:
# Fixed seed for reproducibility, split into a carry key and a subkey.
key, subkey = random.split(random.PRNGKey(1), 2)

# Placeholder parameters; the toy models in this cell never read `params`.
params = jnp.ones(shape=(3, 2))

@jax.jit
def model_(params, xt):
	"""Toy feature map applied along the last axis: (x, t) -> (x + 1, 2*x**2 + 2*x, t).

	Args:
		params: unused.
		xt: array whose last axis holds ``(x, t)``; a single point of shape
			``(2,)`` or a batch such as ``(N, 2)`` both work.

	Returns:
		Array of shape ``xt.shape[:-1] + (3,)``.
	"""
	# Disabled periodic embedding of the inputs (would need a global `domain`):
	# xt = jnp.sin(2.0*jnp.pi*(xt - domain[0, :])/(domain[1, :]-domain[0, :]) - jnp.pi)

	# Lift (x, t) -> (x, x, t) with a fixed selection matrix.
	xt1 = jnp.dot(xt, jnp.array([[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))
	# Functional updates via `.at[...]` (jax.ops.index_update has been removed
	# from JAX). Ellipsis indexing keeps this valid for single points as well as
	# batches, so the function can also be used under jax.vmap / jax.jacfwd.
	xt2 = xt1.at[..., 0].set(xt1[..., 0] + 1)
	xt_ = xt2.at[..., 1].set(2*xt2[..., 1]**2 + 2*xt2[..., 1])
	return xt_

@jax.jit
def model_2(params, xt):
	"""Per-sample toy feature map: (x, t) -> (x + 1, 2*x**2 + 2*x, t).

	Args:
		params: unused.
		xt: a single point of shape ``(2,)`` holding ``(x, t)``.

	Returns:
		Array of shape ``(3,)``; vmap it to process a batch of rows.
	"""
	# Disabled periodic embedding of the inputs (would need a global `domain`):
	# xt = jnp.sin(2.0*jnp.pi*(xt - domain[0, :])/(domain[1, :]-domain[0, :]) - jnp.pi)

	# Lift (x, t) -> (x, x, t) with a fixed selection matrix.
	xt1 = jnp.dot(xt, jnp.array([[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))
	# Functional updates via `.at[...]` (jax.ops.index_update has been removed from JAX).
	xt2 = xt1.at[0].set(xt1[0] + 1)
	xt_ = xt2.at[1].set(2*xt2[1]**2 + 2*xt2[1])
	return xt_

# Batched model: vmap the per-sample `model_2` (which takes one (x, t) point)
# over the rows of the input, broadcasting `params`.
model = jax.vmap(model_2, (None, 0))

In [9]:
# Three sample points, one (x, t) pair per row.
xt = jnp.asarray(
	[[1.0, 2.0],
	 [2.0, 3.0],
	 [3.0, 4.0]]
)
model_(params, xt)

DeviceArray([[ 2.,  4.,  2.],
             [ 3., 12.,  3.],
             [ 4., 24.,  4.]], dtype=float32)

In [10]:
def jacfwd_fn(model):
	"""Build a batched forward-mode Jacobian of ``model`` w.r.t. its inputs.

	Args:
		model: per-sample function ``model(params, x)``. It is differentiated
			with respect to its second argument and vmapped over the leading
			axis of ``inputs`` (``params`` is broadcast, not mapped).

	Returns:
		``jac_(params, inputs)`` giving the per-row Jacobians stacked along a
		leading batch axis: shape ``(batch, *out_shape, *x_shape)``.
	"""
	# Build and jit once. Wrapping in jax.jit inside jac_ would create a fresh
	# function object on every call, forcing a retrace/recompile each time.
	batched_jac = jax.jit(jax.vmap(jax.jacfwd(model, 1), in_axes=(None, 0)))

	def jac_(params, inputs):
		return batched_jac(params, inputs)
	return jac_

# jacfwd_fn vmaps over rows of the inputs, so it needs the per-sample model.
jacobian = jacfwd_fn(model_2)

In [11]:
# Jacobian of model_2 w.r.t. one (x, t) point (shape (3, 2)), mapped over the rows of `xt`.
per_sample_jacobian = jax.jacfwd(model_2, argnums=1)
jax.vmap(per_sample_jacobian, in_axes=(None, 0))(params, xt)

DeviceArray([[[ 1.,  0.],
              [ 6.,  0.],
              [ 0.,  1.]],

             [[ 1.,  0.],
              [10.,  0.],
              [ 0.,  1.]],

             [[ 1.,  0.],
              [14.,  0.],
              [ 0.,  1.]]], dtype=float32)