# Function approximation
---
$$
y = \sin(x), \ x \in [-\pi, \pi].
$$

In [1]:
import flax, flax.nn
from flax import jax_utils, optim
from flax.training import lr_schedule

import jax, jax.nn
from jax import random
import jax.numpy as jnp

import sys
sys.path.append("../../")
	
from Seismic_wave_inversion_PINN.tf_model_utils import *
from Seismic_wave_inversion_PINN.data_utils import *

In [2]:
class MLP(flax.nn.Module):
	def apply(self, x, layers, activation_fn):
		for l in layers[:-1]:
			x = flax.nn.Dense(x, features = l, 
							 kernel_init = jax.nn.initializers.glorot_uniform(),
							 bias_init = lambda key, shape: jnp.zeros(shape),)
			x = activation_fn(x)
		x = flax.nn.Dense(x, features = layers[-1], 
						 kernel_init = jax.nn.initializers.glorot_uniform(),
						 bias_init = lambda key, shape: jnp.zeros(shape),)
		return x
	
def create_model(key, layers):
	module = MLP.partial(layers = layers, activation_fn = jax.nn.relu)
	_, initial_params = module.init_by_shape(key, [(1,)])
	print(initial_params)
	model = flax.nn.Model(module, initial_params)
	return model

layers = [128, 128, 128, 128, 1]
model = create_model(random.PRNGKey(0), layers)

{'Dense_0': {'bias': DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32), 'kernel': DeviceArray([[ 0.1228532 , -0.00718319,  0.14009224, -0.02008048,
               0.03648998,  0.07539438, -0.15733759,  0.19290626,
               0.12540536, -0.08498622, -0.05346658, -0.06495079,
               0.01131031, -0.14955327,  0.04657786, -0.20040599,
               0.01196754,  0.14329325, -0.06426512, -0.183

In [3]:
@jax.jit
def step(optimizer, batch):
	def loss_fn(model):
		y_pred = model(batch["x"])
		loss = jnp.mean(jnp.square(y_pred-batch["y"]))
		return loss
	grad = jax.grad(loss_fn)(optimizer.target)
	optimizer = optimizer.apply_gradient(grad)
	return optimizer

@jax.jit
def evaluate(model, dataset):
	y_pred = model(dataset["x"])
	loss = jnp.mean(jnp.square(y_pred-dataset["y"]))
	return jax.device_get(loss)

In [4]:
x = random.uniform(random.PRNGKey(0), (100, 1), minval = -jnp.pi, maxval = jnp.pi)
f = lambda x: jnp.sin(x)
y = f(x)
dataset = {"x": x, "y": y}

In [5]:
optimizer = flax.optim.Adam(learning_rate = 1e-4).create(model)
for iteration in range(1, 10001):
	optimizer = step(optimizer, dataset)
	if iteration % 100 == 0:
		print("{}, Iteration: {}, Loss: {:.4e}".format(get_time(), iteration, evaluate(optimizer.target, dataset)))

2020/07/20, 20:48:06, Iteration: 100, Loss: 1.6199e-01
2020/07/20, 20:48:07, Iteration: 200, Loss: 1.1874e-01
2020/07/20, 20:48:07, Iteration: 300, Loss: 4.4711e-02
2020/07/20, 20:48:07, Iteration: 400, Loss: 5.7958e-03
2020/07/20, 20:48:07, Iteration: 500, Loss: 8.8134e-04
2020/07/20, 20:48:07, Iteration: 600, Loss: 2.8121e-04
2020/07/20, 20:48:07, Iteration: 700, Loss: 1.2749e-04
2020/07/20, 20:48:07, Iteration: 800, Loss: 6.5499e-05
2020/07/20, 20:48:07, Iteration: 900, Loss: 3.5606e-05
2020/07/20, 20:48:07, Iteration: 1000, Loss: 2.1028e-05
2020/07/20, 20:48:07, Iteration: 1100, Loss: 1.3973e-05
2020/07/20, 20:48:07, Iteration: 1200, Loss: 1.0250e-05
2020/07/20, 20:48:07, Iteration: 1300, Loss: 8.4616e-06
2020/07/20, 20:48:07, Iteration: 1400, Loss: 5.8546e-06
2020/07/20, 20:48:07, Iteration: 1500, Loss: 4.5110e-06
2020/07/20, 20:48:07, Iteration: 1600, Loss: 3.3622e-06
2020/07/20, 20:48:07, Iteration: 1700, Loss: 2.5738e-06
2020/07/20, 20:48:08, Iteration: 1800, Loss: 2.0278e-06
2