In [45]:
import numpy as np
import scipy.io
import matplotlib
import matplotlib.pyplot as plt
import time

matplotlib.rcParams.update({'font.size': 8})

import jax

import jax.numpy as jnp
from jax import value_and_grad, jit, random, vmap, grad, jacrev
import optax
from pyDOE import lhs

from tqdm.contrib import tzip

KEY = random.PRNGKey(1)

In [46]:
# Load Data
data = scipy.io.loadmat("NLS.mat")

noise = 0.0        
    
# Doman bounds
lb = np.array([-5.0, 0.0])
ub = np.array([5.0, np.pi/2])

N0 = 50
N_b = 50
N_f = 20000
layers = [2, 100, 100, 100, 100, 2]

data = scipy.io.loadmat('NLS.mat')

t = data['tt'].flatten()[:,None]
x = data['x'].flatten()[:,None]
Exact = data['uu']
Exact_u = np.real(Exact)
Exact_v = np.imag(Exact)
Exact_h = np.sqrt(Exact_u**2 + Exact_v**2)

X, T = np.meshgrid(x,t)

X_star = np.hstack((X.flatten()[:,None], T.flatten()[:,None]))
u_star = Exact_u.T.flatten()[:,None]
v_star = Exact_v.T.flatten()[:,None]
h_star = Exact_h.T.flatten()[:,None]

###########################

idx_x = np.random.choice(x.shape[0], N0, replace=False)
x0 = x[idx_x,:]
u0 = Exact_u[idx_x,0:1]
v0 = Exact_v[idx_x,0:1]

idx_t = np.random.choice(t.shape[0], N_b, replace=False)
tb = t[idx_t,:]

X_f = lb + (ub-lb)*lhs(2, N_f)

In [47]:
class PINN:
    def __init__(self, x0, u0, v0, tb, X_f, layers, lb, ub):
        X0 = np.concatenate((x0, 0*x0), 1) # (x0, 0)
        X_lb = np.concatenate((0*tb + lb[0], tb), 1) # (lb[0], tb)
        X_ub = np.concatenate((0*tb + ub[0], tb), 1) # (ub[0], tb)

        self.lb = jnp.array(lb)
        self.ub = jnp.array(ub)

        self.x0 = jnp.array(X0[:,0:1])
        self.t0 = jnp.array(X0[:,1:2])

        self.x_lb = jnp.array(X_lb[:,0:1])
        self.t_lb = jnp.array(X_lb[:,1:2])

        self.x_ub = jnp.array(X_ub[:,0:1])
        self.t_ub = jnp.array(X_ub[:,1:2])

        self.x_f = jnp.array(X_f[:,0:1])
        self.t_f = jnp.array(X_f[:,1:2])

        self.u0 = jnp.array(u0)
        self.v0 = jnp.array(v0)


        # Initialize NNs
        self.layers = layers
        self.params = self.init_params(layers)
        
        
    def init_params(self, layers, key=KEY):
        '''
        Initialize parameters in the MLP. Weights are initialized
        using Xavier initialization, while biases are zero initialized.

        Returns
        - params: the initialized parameters
        '''
        def xavier_init(input_dim, output_dim, key=key):
            '''Use Xavier initialization for weights of a single layer'''
            std_dev = jnp.sqrt(2/(input_dim + output_dim)) # compute standard deviation for xavier init
            w = std_dev * random.normal(key, (input_dim, output_dim)) # initialize the weights
            return w

        params = []

        for l in range(len(layers) - 1):
            w = xavier_init(layers[l], layers[l+1]) # xavier initialize the weight
            b = jnp.zeros(layers[l+1]) # zero initialize the bias
            params.append((w, b)) # append weight and bias for this layer to params

        return params
    
    def neural_net(self, X, params):
        num_layers = len(params) + 1
        weights, biases = zip(*params)
        
        H = 2.0*(X - self.lb)/(self.ub - self.lb) - 1.0
        for l in range(0,num_layers-2):
            W = weights[l]
            b = biases[l]
            H = jnp.tanh(jnp.add(jnp.matmul(H, W), b))
#             H = tf.tanh(tf.add(tf.matmul(H, W), b))
        W = weights[-1]
        b = biases[-1]
        Y = jnp.add(jnp.matmul(H, W), b)
        return Y
    
    def net_u_forward(self, params, x, t):
        X = jnp.concatenate([x,t],1)
        
        uv = self.neural_net(X, self.params)
        u = uv[:,0:1]
        return u
    
    def net_v_forward(self, params, x, t):
        X = jnp.concatenate([x,t],1)
        
        uv = self.neural_net(X, self.params)
        v = uv[:,1:2]
        return v
    
    def net_uv_forward(self, params, x, t):
        X = jnp.concatenate([x,t],1)
        
        uv = self.neural_net(X, self.params)
        u = uv[:,0:1]
        v = uv[:,1:2]
        return u, v
    
    def net_f_uv_forward(self, params, x, t):
        u = self.net_u_forward(params, x, t)
        u_x, u_t = jacrev(self.net_u_forward, argnums=(0, 1))(params, x, t)
        
        v = self.net_v_forward(params, x, t)
        v_x, v_t = jacrev(self.net_v_forward, argnums=(0, 1))(params, x, t)
        
        u_xx = jacrev(jacrev(self.net_u_forward, argnums=1), argnums=1)(params, x, t)
        v_xx = jacrev(jacrev(self.net_v_forward, argnums=1), argnums=1)(params, x, t)
        
        f_u = u_t + 0.5*v_xx + (u**2 + v**2)*v
        f_v = v_t - 0.5*u_xx - (u**2 + v**2)*u   
        
        return f_u, f_v
    
    def loss(self, params, x0, t0, x_lb, t_lb, x_ub, t_ub, x_f, t_f):
        u0_pred = self.net_u_forward(params, x0, t0)
        v0_pred = self.net_v_forward(params, x0, t0)
        
        u_lb_pred = self.net_u_forward(params, x_lb, t_lb)
        u_x_lb_pred, _ = jacrev(self.net_u_forward, argnums=(0,1))(params, x_lb, t_lb)
        
        v_lb_pred = self.net_v_forward(params, x_lb, t_lb)
        v_x_lb_pred, _ = jacrev(self.net_v_forward, argnums=(0,1))(params, x_lb, t_lb)
        

        u_ub_pred = self.net_u_forward(params, x_ub, t_ub)
        u_x_ub_pred, _ = jacrev(self.net_u_forward, argnums=(0,1))(params, x_ub, t_ub)
        
        v_ub_pred = self.net_v_forward(params, x_ub, t_ub)
        v_x_ub_pred, _ = jacrev(self.net_v_forward, argnums=(0,1))(params, x_ub, t_ub)
        
        
        f_u_pred, f_v_pred = self.net_f_uv_forward(params, x_f, t_f)
        
        loss = jnp.mean(jnp.square(self.u0 - u0_pred)) + \
               jnp.mean(jnp.square(self.v0 - v0_pred)) + \
               jnp.mean(jnp.square(u_lb_pred - u_ub_pred)) + \
               jnp.mean(jnp.square(v_lb_pred - v_ub_pred)) + \
               jnp.mean(jnp.square(jnp.subtract(jnp.array(u_x_lb_pred), jnp.array(u_x_ub_pred)))) + \
               jnp.mean(jnp.square(jnp.subtract(v_x_lb_pred, v_x_ub_pred))) 
#                jnp.mean(jnp.square(f_u_pred)) + \
#                jnp.mean(jnp.square(f_v_pred))
        return loss
 

    def train(self, nIter, optimizer):
        opt_state = optimizer.init(self.params)
        
        @jit
        def step(paramms, opt_state):
            loss_value, grads = value_and_grad(self.loss)(paramms, self.x0, self.t0, self.x_lb, self.t_lb,
                                                          self.x_ub, self.t_ub, self.x_f, self.t_f)

            updates, opt_state = optimizer.update(grads, opt_state, paramms)
            paramms = optax.apply_updates(paramms, updates)
            return paramms, opt_state, loss_value
        
        train_losses = [0]
        for it in range(nIter):
            self.params, opt_state, loss_value = step(self.params, opt_state)
            train_losses.append(loss_value)
            if it % 1000 == 0:
                print(f'Iteration {it}, average loss: {jnp.mean(loss_value)}')
        return train_losses
    
    #def predict(self, Xstar):
        
                
        
    
    

In [48]:
# Train and Test
model = PINN(x0, u0, v0, tb, X_f, layers, lb, ub)
optimizer = optax.adam(learning_rate=5e-4)
             
start_time = time.time() 
# model.net_f_uv_forward1()
model.train(5000, optimizer)
elapsed = time.time() - start_time                
print('Training time: %.4f' % (elapsed))


# u_pred, v_pred, f_u_pred, f_v_pred = model.predict(X_star)
# h_pred = np.sqrt(u_pred**2 + v_pred**2)

ValueError: All input arrays must have the same shape.