# 2D Burgers
---

Reference code: https://github.com/cics-nd/ar-pde-cnn/tree/master/2D-Burgers-SWAG

Paper: https://arxiv.org/abs/1906.05747

In [1]:
# Run-name prefix for this notebook. NOTE(review): presumably used to tag saved outputs — confirm.
NAME = "0726_"

In [2]:
import jax, jax.nn
from jax import random
import jax.numpy as jnp
from jax.experimental import optimizers

import sys, os
sys.path.append("../../")
	
from Seismic_wave_inversion_PINN.data_utils import *

from collections import namedtuple

In [None]:
def siren_layer_params(key, scale, m, n):
	"""Create one dense layer's parameters.

	The weight matrix has shape (m, n) with entries drawn uniformly from
	[-scale, scale) as float32. The bias is a float32 zero vector of shape (n,).
	`key` is split in two and only the first subkey is used, for the weights.

	Returns:
		A (weights, biases) tuple.
	"""
	weight_key, _unused_bias_key = random.split(key)
	weights = random.uniform(
		weight_key, (m, n), jnp.float32, minval=-scale, maxval=scale
	)
	biases = jnp.zeros((n,), jnp.float32)
	return weights, biases

def init_siren_params(key, layers, c, w0):
	"""Initialise a list of (weights, biases) pairs, one per consecutive pair in `layers`.

	The first layer (fan-in layers[0]) uses a uniform scale of
	w0 * sqrt(c / layers[0]). Every later layer uses sqrt(c / fan_in).
	`key` is split into len(layers) subkeys. Since there are only
	len(layers) - 1 layers, the final subkey is never used.
	"""
	subkeys = random.split(key, len(layers))
	input_layer = siren_layer_params(
		subkeys[0], w0 * jnp.sqrt(c / layers[0]), layers[0], layers[1]
	)
	later_layers = [
		siren_layer_params(subkey, jnp.sqrt(c / fan_in), fan_in, fan_out)
		for fan_in, fan_out, subkey in zip(layers[1:-1], layers[2:], subkeys[1:])
	]
	return [input_layer] + later_layers

# Forward ("direct") network architecture: 3 inputs, five hidden layers of
# width 512, and a single output.
layers = [3, 512, 512, 512, 512, 512, 1] # (x, z, t) -> p
# Initialisation constants passed to init_siren_params. The first layer is drawn
# with scale w0*sqrt(c/fan_in); every later layer uses sqrt(c/fan_in).
c = 6.0
w0 = 30.0
# NOTE(review): lambda_0 is not used in this cell; presumably a loss weight
# consumed later in the notebook — confirm.
lambda_0 = 1e-5
direct_params = init_siren_params(random.PRNGKey(0), layers, c, w0)

# Pretrained inverse-model checkpoint to load.
inverse_NAME = "0722_pretrain_inverse_problem"
inverse_iteration = 1000000
# NOTE(review): `np` is not imported explicitly here; presumably it comes from the
# data_utils star import — confirm. allow_pickle=True unpickles the file, so only
# load checkpoints from a trusted source.
inverse_params = np.load(
	"models/{}/inverse_model/iteration_{}/params.npy".format(inverse_NAME, inverse_iteration),
	allow_pickle=True,
)
# Convert every stored array in each layer entry into a jax array, keeping the
# per-layer list structure ([w, b], ...).
inverse_params = [list(map(jnp.asarray, layer)) for layer in inverse_params]

@jax.jit
def scalar_direct_model(params, x, z, t):
	"""Evaluate the direct network at a single (x, z, t) point.

	Every layer but the last is applied as sin(h @ W + b). The last layer is
	linear, and its output is summed to a scalar, which simply extracts the
	value when the output width is 1.
	"""
	*hidden_layers, (out_w, out_b) = params
	h = jnp.hstack([x, z, t])
	for weight, bias in hidden_layers:
		h = jnp.sin(h @ weight + bias)
	return (h @ out_w + out_b).sum()

@jax.jit
def scalar_inverse_model(params, x, z):
	"""Evaluate the inverse network at a single (x, z) point.

	Hidden layers apply sin(h . W + b). The final (W, b) pair is applied
	linearly, and the result is summed to a scalar.
	"""
	output_w, output_b = params[-1]
	activation = jnp.hstack([x, z])
	for layer_w, layer_b in params[:-1]:
		pre_activation = jnp.dot(activation, layer_w) + layer_b
		activation = jnp.sin(pre_activation)
	return jnp.sum(jnp.dot(activation, output_w) + output_b)

# Batched versions of the scalar models. Parameters are shared across the batch
# (axis None), and each coordinate array is mapped over its leading axis.
direct_model = jax.jit(jax.vmap(scalar_direct_model, (None, 0, 0, 0)))
inverse_model = jax.jit(jax.vmap(scalar_inverse_model, (None,) + (0,) * 2))