# Jax Version of Stenosis2D

## Notes
* It runs fine, but the error is the same across all runs ==> My thought is that it might have something to do with the learning rate, or with how we propagate the data.


Suggestions
* Look at normalization
* Look at updating Reynolds and Peclet number

## Import of libraries

In [None]:
# Mount Google Drive at /content/drive (uses the Colab-only google.colab package).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Change into the project folder on the mounted Drive, then print the working
# directory; the relative paths used later in the notebook start from here.
%cd /content/drive/MyDrive/ENM531\ Project/
!pwd

/content/drive/.shortcut-targets-by-id/1rxnV06iwKeFXatMhja60OvGThJkZnQHG/ENM531 Project
/content/drive/.shortcut-targets-by-id/1rxnV06iwKeFXatMhja60OvGThJkZnQHG/ENM531 Project


In [None]:
import jax.numpy as np
import numpy as onp
from jax import random, jit, vmap, grad, device_put, jacrev, hessian
from jax.experimental.optimizers import optimizer, make_schedule, exponential_decay
import scipy.io

import itertools
from functools import partial
from tqdm import trange
from torch.utils import data
import matplotlib.pyplot as plt

# from james_utilities import MLP_pde, Navier_Stokes_2D, Strain_Rate_2D, mean_squared_error, relative_error, fwd_gradients
from JAX_utilities import MLP_pde, Navier_Stokes_2D, Strain_Rate_2D, mean_squared_error, relative_error, fwd_gradients

## Initialize optimizer, data generator, and HFM class

In [None]:
@optimizer
def sgd(step_size):
    """Construct a plain gradient-descent optimizer.

    Args:
        step_size: learning rate, given in any form ``make_schedule`` accepts
            (it is converted to a function of the iteration index).

    Returns:
        The ``(init, update, get_params)`` triple expected by ``@optimizer``.
    """
    schedule = make_schedule(step_size)

    def init(x0):
        # No auxiliary state: the optimizer state is the parameters themselves.
        return x0

    def update(i, g, x):
        # One descent step along the negative gradient, scaled by the schedule.
        learning_rate = schedule(i)
        return x - learning_rate * g

    def get_params(x):
        return x

    return init, update, get_params

@optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
    """Construct an Adam optimizer.

    Args:
        step_size: learning rate, given in any form ``make_schedule`` accepts
            (it is converted to a function of the iteration index).
        b1: decay factor of the running mean of the gradient.
        b2: decay factor of the running mean of the squared gradient.
        eps: small constant added to the denominator of the update.

    Returns:
        The ``(init, update, get_params)`` triple expected by ``@optimizer``;
        the per-leaf state is ``(params, first_moment, second_moment)``.
    """
    schedule = make_schedule(step_size)

    def init(x0):
        # Both moment accumulators start at zero with the parameter's shape.
        first_moment = np.zeros_like(x0)
        second_moment = np.zeros_like(x0)
        return x0, first_moment, second_moment

    def update(i, g, state):
        x, m, v = state
        step = i + 1
        # Exponential moving averages of the gradient and its square.
        m = (1 - b1) * g + b1 * m
        v = (1 - b2) * np.square(g) + b2 * v
        # Bias correction compensates for the zero initialisation above.
        m_unbiased = m / (1 - np.asarray(b1, m.dtype) ** step)
        v_unbiased = v / (1 - np.asarray(b2, m.dtype) ** step)
        x = x - schedule(i) * m_unbiased / (np.sqrt(v_unbiased) + eps)
        return x, m, v

    def get_params(state):
        # The parameters are the first entry of the state triple.
        return state[0]

    return init, update, get_params

In [None]:
class HFM:
    """Physics-informed network for 2D hidden fluid mechanics, written in JAX.

    One MLP maps a point (t, x, y) to the four fields (c, u, v, p).  The
    training loss is the mean-squared mismatch of the predicted concentration
    ``c`` at the measurement points plus the mean-squared residuals of the
    scalar-transport, momentum and continuity equations at collocation points.

    Notational conventions:
        _data: measurement points (inputs with a known concentration)
        _eqns: collocation points used to regress the equations
        _pred: outputs of the neural network
    """

    def __init__(self, t_data, x_data, y_data, c_data,
                       t_eqns, x_eqns, y_eqns,
                       layers, rng_key=random.PRNGKey(19)):
        """Create the network, the optimizer state and the loggers.

        Args:
            t_data, x_data, y_data, c_data: measurement coordinates and
                concentrations; kept on the instance for reference.
            t_eqns, x_eqns, y_eqns: collocation coordinates; kept likewise.
            layers: layer sizes handed to ``MLP_pde``.
            rng_key: PRNG key used to initialise the network parameters.
        """
        # MLP init and apply functions
        self.net_init, self.net_apply = MLP_pde(layers)
        params = self.net_init(rng_key)

        # Adam with an exponentially decaying learning rate
        lr = exponential_decay(1e-3, decay_steps=1000, decay_rate=0.999)
        self.opt_init, self.opt_update, self.get_params = adam(lr)
        self.opt_state = self.opt_init(params)

        # Logger
        self.itercount = itertools.count()
        self.loss_log = []

        # specs
        self.layers = layers

        # flow properties (Peclet and Reynolds numbers used in the residuals)
        self.Pec = 15.0
        self.Rey = 5.0

        # data
        self.t_data, self.x_data, self.y_data, self.c_data = t_data, x_data, y_data, c_data
        self.t_eqns, self.x_eqns, self.y_eqns = t_eqns, x_eqns, y_eqns

    def neural_net(self, params, t, x, y):
        """Evaluate the MLP on a set of points.

        The coordinates are stacked, their transpose is fed to ``net_apply``
        and the transposed output is returned.  For 1-D inputs of length N
        this unpacks along axis 0 into (c, u, v, p), each of length N
        (assuming ``net_apply`` maps (N, 3) to (N, 4) -- TODO confirm against
        ``MLP_pde``).
        """
        inputs = np.stack([t, x, y])
        nn = self.net_apply(params, inputs.T)
        return nn.T

    def _point_net(self, params, t, x, y):
        """Evaluate the MLP at ONE scalar point; returns (c, u, v, p), shape (4,).

        Differentiating this scalar-input function and vmapping over points
        yields per-point derivatives, instead of the cross-point Jacobian that
        differentiating the batched network would produce.
        """
        inputs = np.stack([t, x, y])[None, :]
        return np.reshape(self.net_apply(params, inputs), (-1,))

    def _first_derivatives(self, params, t, x, y, argnums=(1, 2, 3)):
        """Per-point first derivatives of the outputs w.r.t. the chosen inputs.

        ``argnums`` indexes the arguments of ``_point_net`` (1 = t, 2 = x,
        3 = y).  Returns one (N, 4) array per entry of ``argnums``.
        """
        per_point = jacrev(self._point_net, argnums)
        return vmap(per_point, in_axes=(None, 0, 0, 0))(
            params, np.ravel(t), np.ravel(x), np.ravel(y))

    def _second_derivative(self, params, t, x, y, argnum):
        """Per-point second derivative of the outputs w.r.t. one input; (N, 4)."""
        per_point = hessian(self._point_net, argnum)
        return vmap(per_point, in_axes=(None, 0, 0, 0))(
            params, np.ravel(t), np.ravel(x), np.ravel(y))

    def loss(self, params, t_data, x_data, y_data, c_data, batch=None):
        """Data mismatch plus PDE residual loss.

        Args:
            params: network parameters.
            t_data, x_data, y_data: measurement coordinates (any shape with N
                elements; they are flattened).
            c_data: measured concentration at those points (N elements).
            batch: optional array reshapeable to (B, 3) whose columns are the
                (t, x, y) collocation points.  When omitted, the measurement
                points double as collocation points.

        Returns:
            Scalar loss: MSE of ``c`` at the measurement points plus the mean
            squared value of each of the four equation residuals.

        The function has no side effects, so it is safe to trace under
        ``grad``/``jit``.
        """
        t_data, x_data, y_data, c_data = [
            np.ravel(np.asarray(a)) for a in (t_data, x_data, y_data, c_data)]

        # Row 0 of the network output is the concentration.
        c_data_pred = self.neural_net(params, t_data, x_data, y_data)[0]

        if batch is None:
            t_eqns, x_eqns, y_eqns = t_data, x_data, y_data
        else:
            points = np.reshape(np.asarray(batch), (-1, 3))
            t_eqns, x_eqns, y_eqns = points[:, 0], points[:, 1], points[:, 2]

        c_eqns_pred, u_eqns_pred, v_eqns_pred, p_eqns_pred = self.neural_net(
            params, t_eqns, x_eqns, y_eqns)

        residuals = self.Navier_Stokes_2D(c_eqns_pred.reshape(-1, 1),
                                          u_eqns_pred.reshape(-1, 1),
                                          v_eqns_pred.reshape(-1, 1),
                                          p_eqns_pred.reshape(-1, 1),
                                          t_eqns.reshape(-1, 1),
                                          x_eqns.reshape(-1, 1),
                                          y_eqns.reshape(-1, 1),
                                          self.Pec,
                                          self.Rey,
                                          params)

        data_term = np.mean(np.square(c_data_pred - c_data))
        # Every residual should vanish, so its target value is zero.
        return data_term + sum(np.mean(np.square(e)) for e in residuals)

    def Navier_Stokes_2D(self, c, u, v, p, t, x, y, Pec, Rey, params):
        """Residuals of the transport and Navier-Stokes equations.

        Args:
            c, u, v, p: network outputs at the points, each of shape (N, 1).
            t, x, y: point coordinates (N elements each).
            Pec, Rey: Peclet and Reynolds numbers.
            params: network parameters (needed to differentiate the network).

        Returns:
            ``(e1, e2, e3, e4)``, each of shape (N, 1): scalar transport,
            x-momentum, y-momentum and continuity residuals.
        """
        Y = np.concatenate([c, u, v, p], 1)  # (N, 4)

        Y_t, Y_x, Y_y = self._first_derivatives(params, t, x, y)
        Y_xx = self._second_derivative(params, t, x, y, 2)
        Y_yy = self._second_derivative(params, t, x, y, 3)

        c = Y[:,0:1]
        u = Y[:,1:2]
        v = Y[:,2:3]
        p = Y[:,3:4]

        c_t = Y_t[:,0:1]
        u_t = Y_t[:,1:2]
        v_t = Y_t[:,2:3]

        c_x = Y_x[:,0:1]
        u_x = Y_x[:,1:2]
        v_x = Y_x[:,2:3]
        p_x = Y_x[:,3:4]

        c_y = Y_y[:,0:1]
        u_y = Y_y[:,1:2]
        v_y = Y_y[:,2:3]
        p_y = Y_y[:,3:4]

        c_xx = Y_xx[:,0:1]
        u_xx = Y_xx[:,1:2]
        v_xx = Y_xx[:,2:3]

        c_yy = Y_yy[:,0:1]
        u_yy = Y_yy[:,1:2]
        v_yy = Y_yy[:,2:3]

        e1 = c_t + (u*c_x + v*c_y) - (1.0/Pec)*(c_xx + c_yy)
        e2 = u_t + (u*u_x + v*u_y) + p_x - (1.0/Rey)*(u_xx + u_yy)
        e3 = v_t + (u*v_x + v*v_y) + p_y - (1.0/Rey)*(v_xx + v_yy)
        e4 = u_x + v_y

        return e1, e2, e3, e4

    def Gradient_Velocity_2D(self, u, v, x, y, t, params):
        """Spatial gradients of the velocity at the given points.

        ``u`` and ``v`` are accepted for signature compatibility only; the
        derivatives are taken of the network itself.

        Returns:
            ``[u_x, v_x, u_y, v_y]``, each of shape (N, 1).
        """
        Y_x, Y_y = self._first_derivatives(params, t, x, y, argnums=(2, 3))

        # Network outputs are ordered (c, u, v, p): u is column 1, v is column 2.
        u_x = Y_x[:,1:2]
        v_x = Y_x[:,2:3]

        u_y = Y_y[:,1:2]
        v_y = Y_y[:,2:3]

        return [u_x, v_x, u_y, v_y]

    def Strain_Rate_2D(self, u, v, x, y, t, params):
        """Strain-rate components ``[eps11dot, eps12dot, eps22dot]`` at the points."""
        [u_x, v_x, u_y, v_y] = self.Gradient_Velocity_2D(u, v, x, y, t, params)

        eps11dot = u_x
        eps12dot = 0.5*(v_x + u_y)
        eps22dot = v_y

        return [eps11dot, eps12dot, eps22dot]

    @partial(jit, static_argnums=(0,))
    def step(self, i, opt_state, batch, t_data, x_data, y_data, c_data):
        """Run one optimizer update and return the new optimizer state.

        ``batch`` holds the collocation points for this step and is forwarded
        to ``loss``; the gradient is taken with respect to the parameters only.
        """
        params = self.get_params(opt_state)
        gradients = grad(self.loss)(params, t_data, x_data, y_data, c_data, batch)
        return self.opt_update(i, gradients, opt_state)

    def train(self, dataset, t_data, x_data, y_data, c_data, nIter = 10):
        """Train for ``nIter`` steps, drawing one collocation batch per step.

        Args:
            dataset: iterable yielding ``(t, x, y)`` collocation batches.
            t_data, x_data, y_data, c_data: measurement points and values.
            nIter: number of optimizer steps.

        NOTE(review): the measurement arrays are used in full at every step;
        pass a subsample if they do not fit in device memory.
        """
        batches = iter(dataset)
        pbar = trange(nIter)
        # Main training loop
        for it in pbar:
            # Collocation batch as an array whose columns are (t, x, y)
            batch = np.squeeze(np.asarray(next(batches))).T
            self.opt_state = self.step(next(self.itercount), self.opt_state, batch,
                                       t_data, x_data, y_data, c_data)
            if it % 50 == 0:
                # Log the loss on the same points the step just used.
                params = self.get_params(self.opt_state)
                loss = self.loss(params, t_data, x_data, y_data, c_data, batch)
                self.loss_log.append(loss)
                pbar.set_postfix({'Loss': loss})

    def get_logs(self):
        """Return the list of logged loss values."""
        return self.loss_log

    @partial(jit, static_argnums=(0,))
    def predict(self, params, t, x, y):
        """Evaluate the network; returns the raw ``net_apply`` output.

        Unlike ``neural_net`` the output is NOT transposed, so callers
        transpose it themselves before unpacking into (c, u, v, p).
        """
        inputs = np.stack([t, x, y])
        outputs = self.net_apply(params, inputs.T)
        return outputs

    # TODO: port predict_eps_dot(t_star, x_star, y_star) from the TensorFlow
    # version; it should return the output of Strain_Rate_2D at the given
    # points using the current parameters.

In [None]:
class DataGenerator(data.Dataset):
    """Dataset that returns a fresh random, normalised mini-batch on every access.

    Each ``__getitem__`` call ignores the index, draws ``batch_size`` distinct
    row indices and returns the corresponding rows of ``T``, ``X`` and ``Y``
    after subtracting the mean and dividing by the scale in ``norm_const``.
    """

    def __init__(self, T, X, Y,
                 norm_const=((0.0, 1.0), (0.0, 1.0), (0.0, 1.0)),
                 batch_size=128,
                 rng_key=random.PRNGKey(1234)):
        """Store the arrays and sampling settings.

        Args:
            T, X, Y: 2-D arrays with one row per point (rows are selected
                with ``array[idx, :]``).
            norm_const: ``((mu_T, sigma_T), (mu_X, sigma_X), (mu_Y, sigma_Y))``.
            batch_size: number of rows per batch.
            rng_key: initial PRNG key for the batch sampling.
        """
        self.T = T
        self.X = X
        self.Y = Y
        self.N = Y.shape[0]
        self.norm_const = norm_const
        self.batch_size = batch_size
        self.key = rng_key

    @partial(jit, static_argnums=(0,))
    def __data_generation(self, key, T, X, Y):
        """Sample ``batch_size`` rows without replacement and normalise them."""
        (mu_T, sigma_T), (mu_X, sigma_X), (mu_Y, sigma_Y) = self.norm_const
        idx = random.choice(key, self.N, (self.batch_size,), replace=False)
        T_norm = (T[idx,:] - mu_T)/sigma_T
        X_norm = (X[idx,:] - mu_X)/sigma_X
        Y_norm = (Y[idx,:] - mu_Y)/sigma_Y
        return T_norm, X_norm, Y_norm

    def __getitem__(self, index):
        """Return one random batch ``(T_norm, X_norm, Y_norm)``; ``index`` is unused."""
        # Sample with the subkey and keep the other half only for future
        # splits, so no key is ever consumed twice.
        self.key, subkey = random.split(self.key)
        return self.__data_generation(subkey, self.T, self.X, self.Y)

## Initial data and layers setup

In [None]:
# Network architecture.
# layers = [3] + 10*[4*50] + [4]
# layers = [3] + 5*[4*10] + [4] # Instead of 10 layer DNN with 50 neurons per layer
# Inputs are (t,x,y) and outputs are (c,u,v,p)
layers = [3, 40, 40, 40, 40, 40, 4]

# Load Data
# NOTE(review): this rebinds the name `data`, which the import cell bound to
# `torch.utils.data`; re-running the DataGenerator cell after this one will
# fail on `data.Dataset` unless the import cell is run again first.
data = scipy.io.loadmat('./HFM/Data/Stenosis2D.mat')

t_star = data['t_star'] # T x 1
x_star = data['x_star'] # N x 1
y_star = data['y_star'] # N x 1

# Number of time snapshots and number of spatial points.
T = t_star.shape[0]
N = x_star.shape[0]

U_star = data['U_star'] # N x T
V_star = data['V_star'] # N x T
P_star = data['P_star'] # N x T
C_star = data['C_star'] # N x T

# Rearrange Data
# Broadcast the coordinate vectors to the N x T layout of the field arrays.
T_star = np.tile(t_star, (1,N)).T # N x T
X_star = np.tile(x_star, (1,T)) # N x T
Y_star = np.tile(y_star, (1,T)) # N x T

In [None]:
# Print the shape of every array in the loaded file: fields first, then coordinates.
for field_name in ('C_star', 'P_star', 'U_star', 'V_star', 't_star', 'x_star', 'y_star'):
    print(data[field_name].shape)

(52737, 201)
(52737, 201)
(52737, 201)
(52737, 201)
(201, 1)
(52737, 1)
(52737, 1)


## Noiseless Data

In [None]:
# Measurement ("data") points: choose which snapshots and spatial points to keep.
# With T_data == T and N_data == N every point is kept, in a shuffled order.
T_data = T # int(sys.argv[1])
N_data = N # int(sys.argv[2])
key = random.PRNGKey(19)
# Each random draw gets its own subkey; a key that has been handed to a
# sampler is never split or sampled from again.
key, subkey = random.split(key)
# The first and last snapshots are always included; the rest are drawn at random.
idx_t = np.concatenate([np.array([0]), random.choice(subkey, T-2, (T_data-2,), replace=False)+1, np.array([T-1])])
key, subkey = random.split(key)
idx_x = random.choice(subkey, N, (N_data,), replace=False)
t_data = np.ravel(T_star[:, idx_t][idx_x,:])[:, None]
x_data = np.ravel(X_star[:, idx_t][idx_x,:])[:, None]
y_data = np.ravel(Y_star[:, idx_t][idx_x,:])[:, None]
c_data = np.ravel(C_star[:, idx_t][idx_x,:])[:, None]

# Collocation ("eqns") points, selected the same way with independent draws.
T_eqns = T
N_eqns = N
key, subkey = random.split(key)
idx_t = np.concatenate([np.array([0]), random.choice(subkey, T-2, (T_eqns-2,), replace=False)+1, np.array([T-1])])
key, subkey = random.split(key)
idx_x = random.choice(subkey, N, (N_eqns,), replace=False)
t_eqns = np.ravel(T_star[:, idx_t][idx_x,:])[:, None]
x_eqns = np.ravel(X_star[:, idx_t][idx_x,:])[:, None]
y_eqns = np.ravel(Y_star[:, idx_t][idx_x,:])[:, None]

# Mean/std of each collocation coordinate; DataGenerator uses them to normalise its batches.
norm_const = ((np.mean(t_eqns), np.std(t_eqns)), (np.mean(x_eqns), np.std(x_eqns)), (np.mean(y_eqns), np.std(y_eqns)))
dataset = DataGenerator(t_eqns, x_eqns, y_eqns, norm_const=norm_const, batch_size=1000)

## Define and train model

In [None]:
# Define model
# Builds the network and optimizer state; the parameter initialisation is
# seeded with PRNGKey(19).
model = HFM(t_data, x_data, y_data, c_data,
            t_eqns, x_eqns, y_eqns,
            layers, rng_key=random.PRNGKey(19))

In [None]:
# Run 50000 optimizer steps; `dataset` supplies one collocation batch per step.
model.train(dataset, t_data, x_data, y_data, c_data, 50000) # train model

 82%|████████▏ | 40948/50000 [5:15:43<1:09:06,  2.18it/s, Loss=2.801248e-13]

## Shear stress calculation and saving shear stress data

In [None]:
# Calculate shear stress
# Shear = np.zeros((300,t_star.shape[0]))

# for snap in range(0,t_star.shape[0]):
    
#     x1_shear = np.linspace(15,25,100)[:,None]
#     x2_shear = np.linspace(25,35,100)[:,None]
#     x3_shear = np.linspace(35,55,100)[:,None]

#     x_shear = np.concatenate([x1_shear,x2_shear,x3_shear], axis=0)

#     y1_shear = 0.0*x1_shear
#     y2_shear = np.sqrt(25.0 - (x2_shear - 30.0)**2)
#     y3_shear = 0.0*x3_shear

#     y_shear = np.concatenate([y1_shear,y2_shear,y3_shear], axis=0)
        
#     t_shear = T_star[0,snap] + 0.0*x_shear
    
#     eps11_dot_shear, eps12_dot_shear, eps22_dot_shear = model.predict_eps_dot(t_shear, x_shear, y_shear)
    
#     nx1_shear = 0.0*x1_shear
#     nx2_shear = 6.0 - x2_shear/5.0
#     nx3_shear = 0.0*x3_shear
    
#     nx_shear = np.concatenate([nx1_shear,nx2_shear,nx3_shear], axis=0)
    
#     ny1_shear = -1.0 + 0.0*y1_shear
#     ny2_shear = -y2_shear/5.0
#     ny3_shear = -1.0 + 0.0*y3_shear
    
#     ny_shear = np.concatenate([ny1_shear,ny2_shear,ny3_shear], axis=0)
    
#     shear_x = 2.0*(1.0/5.0)*(eps11_dot_shear*nx_shear + eps12_dot_shear*ny_shear)
#     shear_y = 2.0*(1.0/5.0)*(eps12_dot_shear*nx_shear + eps22_dot_shear*ny_shear)
    
#     shear = np.sqrt(shear_x**2 + shear_y**2)
    
#     Shear[:,snap] = np.ravel(shear)

In [None]:
# !ls

In [None]:
# Save shear data
# scipy.io.savemat('./Results/Stenosis2D_Pec_Re_shear_results_%s.mat' %(time.strftime('%d_%m_%Y')),
#                   {'Shear':Shear, 'x_shear':x_shear})

## Error calculation and saving prediction results

In [None]:
# Test Data: all spatial points of one snapshot
snap = np.array([55])
t_test = T_star[:,snap]
x_test = X_star[:,snap]
y_test = Y_star[:,snap]

c_test = C_star[:,snap]
u_test = U_star[:,snap]
v_test = V_star[:,snap]
p_test = P_star[:,snap]

# Prediction
opt_params = model.get_params(model.opt_state)
pred = model.predict(opt_params, t_test, x_test, y_test)
c_pred, u_pred, v_pred, p_pred = pred.T


def _rel_l2(pred, exact):
    """Relative L2 error: RMS of (pred - exact) over the RMS spread of exact."""
    return np.sqrt(np.mean(np.square(pred - exact))/np.mean(np.square(exact - np.mean(exact))))


# Error
error_c = _rel_l2(c_pred, c_test)
error_u = _rel_l2(u_pred, u_test)
error_v = _rel_l2(v_pred, v_test)
# The mean is removed from both pressure fields before they are compared.
error_p = _rel_l2(p_pred - np.mean(p_pred), p_test - np.mean(p_test))

print('Error c: %e' % (error_c))
print('Error u: %e' % (error_u))
print('Error v: %e' % (error_v))
print('Error p: %e' % (error_p))

NameError: ignored

In [None]:
import time

In [None]:
# Zero-filled result arrays with the same shape as the reference fields.
C_pred = 0*C_star
U_pred = 0*U_star
V_pred = 0*V_star
P_pred = 0*P_star


def _rel_l2(pred, exact):
    """Relative L2 error: RMS of (pred - exact) over the RMS spread of exact."""
    return np.sqrt(np.mean(np.square(pred - exact))/np.mean(np.square(exact - np.mean(exact))))


# The trained parameters do not change during evaluation, so fetch them once.
opt_params = model.get_params(model.opt_state)

for snap in range(0,t_star.shape[0]):
    t_test = T_star[:,snap:snap+1]
    x_test = X_star[:,snap:snap+1]
    y_test = Y_star[:,snap:snap+1]

    c_test = C_star[:,snap:snap+1]
    u_test = U_star[:,snap:snap+1]
    v_test = V_star[:,snap:snap+1]
    p_test = P_star[:,snap:snap+1]

    # Prediction
    pred = model.predict(opt_params, t_test, x_test, y_test)
    c_pred, u_pred, v_pred, p_pred = pred.T

    C_pred[:,snap:snap+1] = c_pred
    U_pred[:,snap:snap+1] = u_pred
    V_pred[:,snap:snap+1] = v_pred
    P_pred[:,snap:snap+1] = p_pred

    # Error
    error_c = _rel_l2(c_pred, c_test)
    error_u = _rel_l2(u_pred, u_test)
    error_v = _rel_l2(v_pred, v_test)
    # The mean is removed from both pressure fields before they are compared.
    error_p = _rel_l2(p_pred - np.mean(p_pred), p_test - np.mean(p_test))

    print('==============')
    print('Error c: %e' % (error_c))
    print('Error u: %e' % (error_u))
    print('Error v: %e' % (error_v))
    print('Error p: %e' % (error_p))
    print('==============')

# Save all predicted fields together with the flow parameters; the file name carries today's date.
scipy.io.savemat('./Results/JAX_Stenosis2D_Pec_Re_results_test%s.mat' %(time.strftime('%d_%m_%Y')),
                  {'C_pred':C_pred, 'U_pred':U_pred, 'V_pred':V_pred, 'P_pred':P_pred, 'Pec': model.Pec, 'Rey': model.Rey})