In [44]:
import numpy as np
import scipy.io
import matplotlib
import matplotlib.pyplot as plt
import time

matplotlib.rcParams.update({'font.size': 8})

import jax.numpy as jnp
from jax import value_and_grad, jit, random, vmap, grad, jacfwd
import optax
from pyDOE import lhs

from tqdm.contrib import tzip

KEY = random.PRNGKey(1)

In [11]:
# Load Data
data = scipy.io.loadmat("NLS.mat")

noise = 0.0        
    
# Doman bounds
lb = np.array([-5.0, 0.0])
ub = np.array([5.0, np.pi/2])

N0 = 50
N_b = 50
N_f = 20000
layers = [2, 100, 100, 100, 100, 2]

data = scipy.io.loadmat('NLS.mat')

t = data['tt'].flatten()[:,None]
x = data['x'].flatten()[:,None]
Exact = data['uu']
Exact_u = np.real(Exact)
Exact_v = np.imag(Exact)
Exact_h = np.sqrt(Exact_u**2 + Exact_v**2)

X, T = np.meshgrid(x,t)

X_star = np.hstack((X.flatten()[:,None], T.flatten()[:,None]))
u_star = Exact_u.T.flatten()[:,None]
v_star = Exact_v.T.flatten()[:,None]
h_star = Exact_h.T.flatten()[:,None]

###########################

idx_x = np.random.choice(x.shape[0], N0, replace=False)
x0 = x[idx_x,:]
u0 = Exact_u[idx_x,0:1]
v0 = Exact_v[idx_x,0:1]

idx_t = np.random.choice(t.shape[0], N_b, replace=False)
tb = t[idx_t,:]

X_f = lb + (ub-lb)*lhs(2, N_f)

In [67]:
class PINN:
    def __init__(self, x0, u0, v0, tb, X_f, layers, lb, ub):
        X0 = np.concatenate((x0, 0*x0), 1) # (x0, 0)
        X_lb = np.concatenate((0*tb + lb[0], tb), 1) # (lb[0], tb)
        X_ub = np.concatenate((0*tb + ub[0], tb), 1) # (ub[0], tb)

        self.lb = jnp.array(lb)
        self.ub = jnp.array(ub)

        self.x0 = jnp.array(X0[:,0:1])
        self.t0 = jnp.array(X0[:,1:2])

        self.x_lb = jnp.array(X_lb[:,0:1])
        self.t_lb = jnp.array(X_lb[:,1:2])

        self.x_ub = jnp.array(X_ub[:,0:1])
        self.t_ub = jnp.array(X_ub[:,1:2])

        self.x_f = jnp.array(X_f[:,0:1])
        self.t_f = jnp.array(X_f[:,1:2])

        self.u0 = jnp.array(u0)
        self.v0 = jnp.array(v0)


        # Initialize NNs
        self.layers = layers
        self.params = self.init_params(layers)
        
        
    def init_params(self, layers, key=KEY):
        '''
        Initialize parameters in the MLP. Weights are initialized
        using Xavier initialization, while biases are zero initialized.

        Returns
        - params: the initialized parameters
        '''
        def xavier_init(input_dim, output_dim, key=key):
            '''Use Xavier initialization for weights of a single layer'''
            std_dev = jnp.sqrt(2/(input_dim + output_dim)) # compute standard deviation for xavier init
            w = std_dev * random.normal(key, (input_dim, output_dim)) # initialize the weights
            return w

        params = []

        for l in range(len(layers) - 1):
            w = xavier_init(layers[l], layers[l+1]) # xavier initialize the weight
            b = jnp.zeros(layers[l+1]) # zero initialize the bias
            params.append((w, b)) # append weight and bias for this layer to params

        return params
    
    def neural_net(self, X, params):
        num_layers = len(params) + 1
        weights, biases = zip(*params)
        
        H = 2.0*(X - self.lb)/(self.ub - self.lb) - 1.0
        for l in range(0,num_layers-2):
            W = weights[l]
            b = biases[l]
            H = jnp.tanh(jnp.add(jnp.matmul(H, W), b))
#             H = tf.tanh(tf.add(tf.matmul(H, W), b))
        W = weights[-1]
        b = biases[-1]
        Y = jnp.add(jnp.matmul(H, W), b)
        return Y
    
    def net_u_forward(self, x, t):
        X = jnp.concatenate([x,t],1)
        
        uv = self.neural_net(X, self.params)
        u = uv[:,0:1]
        return u
    
    def net_v_forward(self, x, t):
        X = jnp.concatenate([x,t],1)
        
        uv = self.neural_net(X, self.params)
        v = uv[:,1:2]
        return v
    
    def net_uv_forward(self, x, t):
        X = jnp.concatenate([x,t],1)
        
        uv = self.neural_net(X, self.params)
        u = uv[:,0:1]
        v = uv[:,1:2]
        return u, v
    
    def net_f_uv_forward(self, x, t):
        u, grad_f = vjp(self.net_u_forward, x, t)
        u_x, u_t = grad_f((jnp.ones(u.shape)))
        
        v, grad_f = vjp(self.net_v_forward, x, t)
        v_x, v_t = grad_f((jnp.ones(v.shape)))
        
        
        
        primals, grad_f = vjp(self.net_uv_forward, x, t)
        u, v = primals
        u_x, v_x = grad_f((x, t))
        u_x = u_x[0]
        
        u_t = grad(u, t)
        u_xx = grad
        
        v_t = grad(v, t)
        v_xx = grad(v_x, x)
        
        f_u = u_t + 0.5*v_xx + (u**2 + v**2)*v
        f_v = v_t - 0.5*u_xx - (u**2 + v**2)*u   
        
        return f_u, f_v
    
    def loss(self, x, y):
        x0, t0, x_lb, t_lb, x_ub, t_ub, x_f, t_f
    
    def loss(self, x0, t0, x_lb, t_lb, x_ub, t_ub, x_f, t_f):
        def convert_f(func, fixed_pos, fixed_arg):
            if fixed_pos == 1:
                return lambda x: func(x, fixed_arg)
            elif fixed_pos == 0:
                return lambda x: func(fixed_arg, x)
        def jac_fix_one(func, fixed_pos, fixed_arg, arg):
            fixed_fun = convert_f(func, fixed_pos, fixed_arg)
            jac = jacfwd(fixed_fun)(arg)
            return jac
            
            
        u0_pred = self.net_u_forward(x0, t0)
        v0_pred = self.net_v_forward(x0, t0)
        
        
        u_lb_pred = self.net_u_forward(x_lb, t_lb)
        jac = jac_fix_one(self.net_u_forward, 1, t_lb, x_lb)
        print("!!!!!")
        print(u_lb_pred)
        print(jac.shape)
        for i in jac:
            for j in i:
                for k in j:
                    for l in k:
                        print(l)
        print("!!!!!")
        
        
        u_lb_pred, grad_f = vjp(self.net_u_forward, x_lb, t_lb)
        u_x_lb_pred, _ = grad_f((jnp.ones(u_lb_pred.shape)))
        print(u_x_lb_pred)
        for i in u_x_lb_pred:
            for j in i:
                print(i)
        
        v_lb_pred, grad_f = vjp(self.net_v_forward, x_lb, t_lb)
        v_x_lb_pred, _ = grad_f((jnp.ones(v_lb_pred.shape)))
        
#         u_lb_pred = self.net_u_forward(x_lb, t_lb)
#         v_lb_pred = self.net_v_forward(x_lb, t_lb)
#         u_x_lb_pred = grad(self.net_u_forward)(x_lb, t_lb) # argnums=0: grad wrt 0th argument, which is x_lb
#         v_x_lb_pred = grad(self.net_v_forward, argnums=0)(x_lb, t_lb)


        u_ub_pred, grad_f = vjp(self.net_u_forward, x_ub, t_ub)
        u_x_ub_pred, _ = grad_f((jnp.ones(u_ub_pred.shape)))
        
        v_ub_pred, grad_f = vjp(self.net_v_forward, x_ub, t_ub)
        v_x_ub_pred, _ = grad_f((jnp.ones(v_ub_pred.shape)))
        
#         u_ub_pred = self.net_u_forward(x_ub, t_ub)
#         v_ub_pred = self.net_v_forward(x_ub, t_ub)
#         u_x_ub_pred = grad(u_ub_pred, x_ub)
#         v_x_ub_pred = grad(v_ub_pred, x_ub)
        
        
        f_u_pred, f_v_pred = self.net_f_uv_forward(x_f, t_f)
        
        loss = jnp.mean(jnp.square(self.u0 - u0_pred)) + \
               jnp.mean(jnp.square(self.v0 - v0_pred)) + \
               jnp.mean(jnp.square(u_lb_pred - u_ub_pred)) + \
               jnp.mean(jnp.square(v_lb_pred - v_ub_pred)) + \
               jnp.mean(jnp.square(u_x_lb_pred - u_x_ub_pred)) + \
               jnp.mean(jnp.square(v_x_lb_pred - v_x_ub_pred)) + \
               jnp.mean(jnp.square(f_u_pred)) + \
               jnp.mean(jnp.square(f_v_pred))
        return loss
    
    def train(self, nIter, optimizer):
        opt_state = optimizer.init(self.params)
        #@jit
        def step(params, opt_state):
            loss_value, grads = value_and_grad(self.loss)(self.x0, self.t0, self.x_lb, self.t_lb,
                                                          self.x_ub, self.t_ub, self.x_f, self.t_f)
            updates, opt_state = optimizer.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
            return params, opt_state, loss_value
        
        train_losses = [0]
        for it in range(nIter):
            self.params, opt_state, loss_value = step(self.params, opt_state)
            train_losses.append(loss_value)
            if it % 1000 == 0:
                print(f'Iteration {it}, average loss: {jnp.mean(loss_value)}')
        return train_losses
    
    #def predict(self, Xstar):
        
                
        
    
    

In [68]:
# Train and Test
model = PINN(x0, u0, v0, tb, X_f, layers, lb, ub)
optimizer = optax.adam(learning_rate=5e-4)
             
start_time = time.time()                
model.train(1000, optimizer)
elapsed = time.time() - start_time                
print('Training time: %.4f' % (elapsed))


# u_pred, v_pred, f_u_pred, f_v_pred = model.predict(X_star)
# h_pred = np.sqrt(u_pred**2 + v_pred**2)

!!!!!
[[0.37043977]
 [0.3534008 ]
 [0.35568053]
 [0.37756115]
 [0.37297603]
 [0.36426994]
 [0.3732231 ]
 [0.37605402]
 [0.37950486]
 [0.3636548 ]
 [0.36399677]
 [0.37946206]
 [0.35766512]
 [0.3591708 ]
 [0.354164  ]
 [0.37550807]
 [0.3748383 ]
 [0.36754543]
 [0.3586638 ]
 [0.37430423]
 [0.3781013 ]
 [0.36972207]
 [0.37594512]
 [0.37873495]
 [0.36554196]
 [0.3794974 ]
 [0.36467442]
 [0.3581723 ]
 [0.35755396]
 [0.36632708]
 [0.36249372]
 [0.3633107 ]
 [0.37891835]
 [0.37926745]
 [0.37294286]
 [0.3507114 ]
 [0.37710276]
 [0.37943545]
 [0.35939702]
 [0.37009928]
 [0.37572983]
 [0.36110267]
 [0.36296406]
 [0.372725  ]
 [0.37417105]
 [0.37635684]
 [0.3772746 ]
 [0.35530257]
 [0.3601241 ]
 [0.36830503]]
(50, 1, 50, 1)
-0.06399963
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
-0.05947131
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.

[[-0.06399959]
 [-0.0594713 ]
 [-0.05982935]
 [-0.06323368]
 [-0.06244617]
 [-0.06374401]
 [-0.06248511]
 [-0.06295283]
 [-0.06381892]
 [-0.06104102]
 [-0.0610921 ]
 [-0.0639001 ]
 [-0.06340361]
 [-0.06348567]
 [-0.05959189]
 [-0.06411766]
 [-0.06274594]
 [-0.06388967]
 [-0.06028892]
 [-0.06410269]
 [-0.06334682]
 [-0.06194696]
 [-0.06411988]
 [-0.06404393]
 [-0.06380281]
 [-0.06386908]
 [-0.0611931 ]
 [-0.06343152]
 [-0.06011903]
 [-0.06143914]
 [-0.06365786]
 [-0.06098956]
 [-0.06402504]
 [-0.06396887]
 [-0.06407406]
 [-0.05903993]
 [-0.06314404]
 [-0.06391486]
 [-0.06040052]
 [-0.06398784]
 [-0.06411906]
 [-0.06358738]
 [-0.06093777]
 [-0.06240685]
 [-0.06263668]
 [-0.06411991]
 [-0.06411067]
 [-0.05977044]
 [-0.06051074]
 [-0.06392051]]
[-0.06399959]
[-0.0594713]
[-0.05982935]
[-0.06323368]
[-0.06244617]
[-0.06374401]
[-0.06248511]
[-0.06295283]
[-0.06381892]
[-0.06104102]
[-0.0610921]
[-0.0639001]
[-0.06340361]
[-0.06348567]
[-0.05959189]
[-0.06411766]
[-0.06274594]
[-0.06388967]


TypeError: Expected a callable value, got [[-0.1645398 ]
 [ 0.35295618]
 [ 0.09166733]
 ...
 [-0.3130107 ]
 [-0.2992484 ]
 [ 0.20865972]]