In [1]:
import jax.numpy as jnp 
from jax import grad, jit, vmap, hessian
from jax import random
import jax.config  as config 
import pickle as pkl
import matplotlib.pyplot as plt
import jax 
from utils import gt
import os 
import jaxopt
from jax.flatten_util import ravel_pytree

In [2]:
config.update("jax_enable_x64", True)
key = random.PRNGKey(0)

In [3]:
path = './results/t1'
if not os.path.exists(path):
    os.makedirs(path)
layers = [3,20,20,1]
max_iter = 2000
val_int = 10
bw = 100 
iw = 100

In [4]:
initializer = jax.nn.initializers.glorot_uniform()
def random_layer_params(key,m,n,scale=1):
    return scale * initializer(key, (n,m),dtype=jnp.float64), jnp.zeros((n,))

def init_network_params(key,sizes):
    return [random_layer_params(key,x,y) for x,y in zip(sizes[:-1],sizes[1:])]

#params: [[[L1,L2],[L1,]],[[L2,L3],[L2,]],...,[[Ln-1,Ln],[Ln-1,]]]
#        list -> tuple -> array
#        layers -> weight, bias -> coefficients

def NN(activation):
    def model(params,x):
        output = x 
        for w,b in params[:-1]:
            #print(output.shape,jnp.dot(output,w.T).shape,b.shape)
            linear = jnp.dot(output,w.T) + b
            output = activation(linear)
        w,b = params[-1]
        #print(output.shape,jnp.dot(output,w.T).shape,b.shape)
        output = jnp.reshape(jnp.dot(output,w.T) + b,())
        return output
    return model 

params = init_network_params(key,layers)
u = NN(jnp.tanh)

In [5]:
def laplace(func):

    hess = hessian(func,0)

    lap = lambda x,t: jnp.trace(hess(x,t))

    return lap 

def parabolic(func):
    lap = laplace(func)
    time_diff = grad(func,1)
    par = lambda x: time_diff(x[:-1],x[-1])+lap(x[:-1],x[-1])

    return par 

def bdry(func):
    b = lambda x: func(x)
    return b 

def init(func):
    i = lambda x:func(x) 
    return i 

In [6]:
''' 
loading data
'''

with open("dataset/2000pts",'rb') as pfile:
    data = pkl.load(pfile)

d_c = data['domain']
b_c = data['bdry']
i_c = data['init']


In [7]:
f = lambda x: (1-2*jnp.pi**2*x[2])*jnp.sin(jnp.pi*x[0])*jnp.sin(jnp.pi*x[1])
f_d = vmap(f)(d_c) 

b_d = jnp.zeros([len(b_c),1])
i_d = jnp.zeros([len(i_c),1]) + 0.02*jax.random.normal(key,[len(i_c),1])

In [8]:
parab = lambda param: parabolic(lambda x,t: u(param,jnp.hstack([x,t])))
bdry_func = lambda param: bdry(lambda x: u(param,x))
init_func = lambda param: init(lambda x: u(param,x))

pred_d = lambda param,x : parab(param)(x)
pred_b = lambda param,x : bdry_func(param)(x)
pred_i = lambda param,x : init_func(param)(x)

v_pred_d = jit(vmap(pred_d,(None,0)))
v_pred_b = jit(vmap(pred_b,(None,0)))
v_pred_i = jit(vmap(pred_i,(None,0)))

In [9]:
@jit 
def loss(param):
    l_d = jnp.mean((v_pred_d(param,d_c) - f_d)**2)
    l_b = jnp.mean((v_pred_b(param,b_c) - b_d)**2)
    l_i = jnp.mean((v_pred_i(param,i_c) - i_d)**2)

    return l_d + bw * l_b + iw * l_i

In [10]:
y = lambda x: jnp.sin(jnp.pi*x[0])*jnp.sin(jnp.pi*x[1])*x[2]

res = lambda param,x : (u(param,x)-y(x))**2 
v_res = vmap(res,(None,0))


flat_param, unravel = ravel_pytree(params)
flat_loss = lambda flat_param: loss(unravel(flat_param))

LBFGS = jaxopt.LBFGS(fun=flat_loss,value_and_grad=False,linesearch='hager-zhang')
state = LBFGS.init_state(flat_param) 

for iter in jnp.arange(0,max_iter):
    flat_param,state = LBFGS.update(flat_param,state)
    if iter % val_int == 0:
        params = unravel(flat_param)
        error = jnp.sqrt(jnp.mean(v_res(params,d_c)))
        print("Iter {} | Error {}".format(iter,error))

Iter 0 | Error 0.5733184934599507
Iter 10 | Error 0.5715436295550633
Iter 20 | Error 0.26020823329845144
Iter 30 | Error 0.17649193408349287
Iter 40 | Error 0.19420168480772135
Iter 50 | Error 0.13825586041393503
Iter 60 | Error 0.08043049583386293
Iter 70 | Error 0.07632401381894778
Iter 80 | Error 0.06685910874356864
Iter 90 | Error 0.06286070710762598
Iter 100 | Error 0.05520329395818939
Iter 110 | Error 0.04651274479855597
Iter 120 | Error 0.04353244797828913
Iter 130 | Error 0.050420364343653275
Iter 140 | Error 0.04841416700522314
Iter 150 | Error 0.05125786502559185
Iter 160 | Error 0.04559126286478243
Iter 170 | Error 0.048203376378351925
Iter 180 | Error 0.04466500995281347
Iter 190 | Error 0.046325493715804786
Iter 200 | Error 0.042965565338742094
Iter 210 | Error 0.038907981240841986
Iter 220 | Error 0.03861442370026865
Iter 230 | Error 0.0399321935099246
Iter 240 | Error 0.038378958498781104
Iter 250 | Error 0.038382454036806464
Iter 260 | Error 0.03501270349788849
Iter 270

Iter 790 | Error 0.015287799549072835
