# Physics-informed neural network for 3D inverse modeling of natural-state geothermal systems

In [None]:
# Published in February 2025

In [None]:
# The program was implemented under the following versions:
# Tensorflow version 2.9.1
# Tensorflow probability version 0.17.0
# Numpy version 1.21.1
# pandas version 1.4.3
# matplotlib version 3.5.0
# scipy version 1.4.1

In [None]:
import sys
sys.path.append("./")
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd
import sys
from time import time
from enum import Enum
from scipy import interpolate
from scipy.interpolate import Rbf
import scipy.optimize
import os
import shutil
from typing import List
from dataclasses import dataclass
from optimizers.lbfgs import LBfgsOptimizer
from utils import rearrange_coordinate_container as rcc
# 
script_dir = os.path.abspath('') + "/utils/"
sys.path.append(script_dir)
ipt_dir = os.path.abspath('') + "/utils/"
sys.path.append(script_dir)
import set_wells_train_val as swtv
import get_boundary_indexlist as gb
import get_tkp_on_new_coords as gtpk

In [None]:
# GPU
# Query the physical GPUs once; the same list is reported and then configured.
gpus = tf.config.experimental.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpus))

# Index (into the physical GPU list) of the single device made visible to TensorFlow.
gpu_id = 0
if gpus:
    try:
        # Enable memory growth on every physical GPU before restricting visibility.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        # Expose only the selected GPU to the TensorFlow runtime.
        tf.config.experimental.set_visible_devices(gpus[gpu_id], 'GPU')
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
    except RuntimeError as e:
        # Visible devices must be set before GPUs have been initialized
        print(e)

In [None]:
class NeumannBcApproximater:
    """
    Spline interpolant of a quantity prescribed on a Neumann boundary surface
    of a 3-dimensional rectangular domain.

    The boundary samples ``(X, Y) -> var`` are stored once at construction and
    ``function`` evaluates ``tfa.image.interpolate_spline`` on them at
    arbitrary ``(x, y)`` positions.

    Parameters
    ----------
    X : Tensor
        x (x1) coordinate of the surface on which the Neumann boundary condition is imposed.
        shape = (number_of_points, 1)
    Y : Tensor
        y (x2) coordinate of the surface on which the Neumann boundary condition is imposed.
        shape = (number_of_points, 1)
    var : Tensor
        Physical quantity on the surface imposing the Neumann boundary condition.
        shape = (number_of_points, 1)
    order_interpolate : int
        The order of interpolation of the boundary conditions
    regularization_interpolate : float
        Regularization weight for interpolation of boundary conditions

    Attributes
    -----------
    XY : Tensor
        Reshaped tensor for XY coordinates
        shape = (1, number_of_points, 2)
    var  : Tensor
        Reshaped tensor for a physical quantity
        shape = (1, number_of_points, 1)
    order : int
        Copy of ``order_interpolate``.
    reg : float
        Copy of ``regularization_interpolate``.
    """
    def __init__(self, X, Y, var, order_interpolate, regularization_interpolate):
        size = X.numpy().shape[0]

        # Column-wise concatenation gives one (x, y) pair per row; the leading
        # axis of size 1 is the batch axis expected by interpolate_spline.
        self.XY = tf.reshape(tf.concat([X, Y], axis=1), [1, size, 2])
        self.var = tf.reshape(var, [1, size, 1])
        self.order = order_interpolate
        self.reg = regularization_interpolate

    def function(self, x, y):
        """
        Approximated function for a quantity on the Neumann boundaries

        Input
        -----
        x : tf.Tensor
            x-coordinates of the positions at which the interpolated quantity is evaluated.
            shape = (number_of_points, 1)
        y : tf.Tensor
            y-coordinates of the positions at which the interpolated quantity is evaluated.
            shape = (number_of_points, 1)

        Output
        ------
        tf.Tensor
            The interpolated physical quantity, reshaped to the shape of ``x``.
        """
        size = x.shape[0]
        x_reshaped = tf.reshape(x, [1, size, 1])
        y_reshaped = tf.reshape(y, [1, size, 1])
        # Join along the LAST axis so that query row i is (x_i, y_i).
        # Joining along axis 1 and reshaping to [1, size, 2] would instead
        # produce rows (x_0, x_1), (x_2, x_3), ... and scramble the query points.
        xy = tf.concat([x_reshaped, y_reshaped], axis=2)
        var = tfa.image.interpolate_spline(self.XY, self.var, xy, self.order, regularization_weight=self.reg)
        return tf.reshape(var, x.shape)

In [None]:
@dataclass
class GridSpec:
    """
    Data class that holds the specification of the structured input grid.

    The three point counts are used by the solver to reshape flat coordinate
    and index arrays into grid form; the three index arrays are passed to
    ``tf.gather_nd`` to pick boundary points out of grid-shaped tensors.

    Attributes
    ----------
    N_x1: int
        Number of points along the first axis
    N_x2: int
        Number of points along the second axis
    N_x3: int
        Number of points along the third axis
    up_indices: np.ndarray
        An array of indices that refers to points on the upper surface boundary.
        shape = (num_upper_bound_points, ndim=3)
    low_indices: np.ndarray
        An array of indices that refers to points on the bottom boundary.
        shape = (num_lower_bound_points, ndim=3)
    side_indices: np.ndarray
        An array of indices that refers to points on the side boundaries.
        shape = (num_side_bound_points, ndim=3)
    """
    # Grid resolution per axis.
    N_x1: int
    N_x2: int
    N_x3: int

    # Multi-indices of the boundary points, one row per point.
    up_indices: np.ndarray
    low_indices: np.ndarray
    side_indices: np.ndarray

In [None]:
# Define model architecture
class PINN_NeuralNet3(tf.keras.Model):
    """
    Basic architecture of the PINN model: three independent fully connected
    sub-networks that all receive the same input.

    Each sub-network rescales its input with ``2*(x - lb)/(ub - lb) - 1``,
    applies ``num_hidden_layers`` dense layers of ``num_neurons_per_layer``
    units, and ends in a dense output layer with ``output_dim`` units and no
    activation.

    Parameters
    ----------
    lb, ub
        Lower and upper bounds used to rescale the input.
    output_dim : int
        Number of units of each output layer.
    num_hidden_layers : int
        Number of hidden dense layers per sub-network.
    num_neurons_per_layer : int
        Number of units in every hidden layer.
    activation
        Activation identifier resolved with ``tf.keras.activations.get``.
    kernel_initializer
        Kernel initializer of the hidden layers.
    **kwargs
        Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, lb, ub,
            output_dim=1,
            num_hidden_layers=4,
            num_neurons_per_layer=50,
            activation='tanh',
            kernel_initializer='glorot_normal',
            **kwargs):
        super().__init__(**kwargs)

        self.num_hidden_layers = num_hidden_layers
        self.output_dim = output_dim
        self.lb = lb
        self.ub = ub

        def build_branch():
            """Create the (scaling, hidden layers, output layer) of one sub-network."""
            rescale = tf.keras.layers.Lambda(
                lambda x: 2.0*(x - lb)/(ub - lb) - 1.0)
            hidden_stack = [
                tf.keras.layers.Dense(
                    num_neurons_per_layer,
                    activation=tf.keras.activations.get(activation),
                    kernel_initializer=kernel_initializer)
                for _ in range(self.num_hidden_layers)
            ]
            head = tf.keras.layers.Dense(output_dim)
            return rescale, hidden_stack, head

        # Branches are created one after another (1, 2, 3) so that layers are
        # instantiated and registered in a fixed order.
        self.scale1, self.hidden1, self.out1 = build_branch()
        self.scale2, self.hidden2, self.out2 = build_branch()
        self.scale3, self.hidden3, self.out3 = build_branch()

    def call(self, X):
        """Forward-pass through the three sub-networks; returns their three outputs as a tuple."""
        Z1, Z2, Z3 = self.scale1(X), self.scale2(X), self.scale3(X)
        # Advance the three branches depth by depth (branch 1, 2, 3 per level).
        for layer1, layer2, layer3 in zip(self.hidden1, self.hidden2, self.hidden3):
            Z1 = layer1(Z1)
            Z2 = layer2(Z2)
            Z3 = layer3(Z3)
        return self.out1(Z1), self.out2(Z2), self.out3(Z3)

In [None]:
class VariableType(Enum):
    """
    Enumeration type that represents the type of physical quantity (T, p, k, etc.)

    Members of this enum are passed as the ``var`` argument of
    ``PINNSolver_3D.extract_points``.
    """
    T = 1  # temperature
    p = 2  # pressure
    k = 3  # permeability
    others = 5  # any other quantity (NOTE(review): value 4 is unused; reason not visible here)

In [None]:
class DataSet(Enum):
    """
    Enumeration type that represents the type of dataset (training, validation, all).
    """
    TRAIN = 0  # training data
    VAL = 1  # validation data
    ALL = 2  # all data (presumably training and validation together; confirm at the call sites)

In [None]:
class PINNSolver_3D():

    """
    Class that performs PINN optimization

    Parameters
    ----------
        model: class
                 Class of the physics-informed neural network
        model_input_grid_spec: Grid specification
        if_low_dirichlet: bool
                Bool type indicating whether or not to impose the Dirichlet condition on the bottom surface.
        PI_indices: Tensor
                The index of the collocation points used to calculate the losses for mass and energy balances (PI1loss and PI2loss)
        X_uTpbD: Tensor
                Coordinates of the upper surface boundary used for Dirichlet boundary condition
                shape = (Num_surface_bound, 3)
        X_lTpbD: Tensor
                Coordinates of the bottom boundary used for Dirichlet boundary condition
                shape = (Num_bottom_bound, 3)
        nvec_p: Tensor
                Normal vector for side boundaries
                shape = (Num_side_bounds, 3)
        X_TbN: Tensor
                Coordinates for the Neumann boundary condition at the bottom
                shape = (Num_bottom_bound, 3)
        nvec_T: Tensor
                Normal vector at the bottom
                shape = (Num_bottom_bound, 3)
        Y_uTpbD_T: Tensor
                Temperatures specified at the upper surface used as the Dirichlet boundary condition
        Y_uTpbD_p: Tensor
                Pressures specified at the upper surface used as the Dirichlet boundary condition
                shape = (Num_surface_bound, 1)
        Y_lTpbD_T: Tensor
                Temperatures specified at the bottom used as the Dirichlet boundary condition
                shape = (Num_bottom_bound, 1)
        Y_lTpbD_p: Tensor
                Pressures specified at the bottom used as the Dirichlet boundary condition
                shape = (Num_bottom_bound, 1)
        cof_PI1AD: float
                Coefficient for the loss of mass balance
        cof_PI2AD: float
                Coefficient for the loss of energy balance
        cof_pbN: float
                Coefficient for the Neumann boundary at the sides
        cof_TbN: float
                Coefficient for the Neumann boundary at the bottom
        T_true: nd.array
                Reference temperature
                shape = (Num_all_grid, 1)
        p_true: nd.array
                Reference pressure
                shape = (Num_all_grid, 1)
        k_true: nd.array
                Reference permeability
                shape = (Num_all_grid, 1)
        lmode: int
                The calculation mode specified
        dtype: Data type
    """
    def __init__(self, model, model_input_grid_spec: GridSpec,
                if_low_dirichlet,
                PI_indices,
                X_uTpbD, X_lTpbD,
                X_pbN, nvec_p, X_TbN, nvec_T,
                X_all,
                Y_uTpbD_T, Y_uTpbD_p, Y_lTpbD_T, Y_lTpbD_p,
                cof_PI1AD, cof_PI2AD, cof_pbN, cof_TbN,
                T_true, p_true, k_true,
                T_for_normalize, p_for_normalize, k_for_normalize,
                lmode,dtype,
                order_interpolate = 1,
                regularization_interpolate = 1e-4):
        """
        Store the model, grid, boundary data, loss coefficients and reference
        data, and initialize the loss histories and loss weights.

        See the class docstring for the meaning of the individual arguments.
        """
        self.model = model

        # Information about the input grid
        self._model_input_grid_spec: GridSpec = model_input_grid_spec
        spec = model_input_grid_spec
        # Shape used to interpret flat (list-ordered) data: third axis first.
        flat_order_dims = (spec.N_x3, spec.N_x2, spec.N_x1)

        # Grid of coordinates fed to the model when making a prediction:
        # X_all is viewed as (N_x3, N_x2, N_x1, 3), its first three axes are
        # reversed, and a leading axis of size 1 is added.
        coords_flat_order = tf.reshape(X_all, shape=flat_order_dims + (3,))
        self._model_input_grid = tf.transpose(
            coords_flat_order, perm=[2, 1, 0, 3])[tf.newaxis, ...]

        # Flat collocation indices -> per-axis indices, with the columns
        # reordered to match the axis order of the input grid.
        collocation_multi_index = tf.transpose(
            tf.unravel_index(PI_indices, dims=flat_order_dims))
        self._grid_collocation_point_indices = tf.gather(
            collocation_multi_index, indices=[2, 1, 0], axis=1)

        # Bool type for the bottom boundary condition
        self.if_low_dirichlet = if_low_dirichlet

        # Boundary points for the Dirichlet conditions (upper / bottom)
        self.xuTpbD, self.xlTpbD = X_uTpbD, X_lTpbD

        # Side (Neumann) boundary points, split into one column per coordinate
        self.xpbN1, self.xpbN2, self.xpbN3 = (
            X_pbN[:, col:col + 1] for col in range(3))
        self.nvecp = nvec_p

        # Bottom (Neumann) boundary points, split into one column per coordinate
        self.xTbN1, self.xTbN2, self.xTbN3 = (
            X_TbN[:, col:col + 1] for col in range(3))
        self.nvecT = nvec_T
        self.xall = X_all

        # Target values of the Dirichlet boundary conditions
        self.Y_uT, self.Y_up = Y_uTpbD_T, Y_uTpbD_p
        self.Y_lT, self.Y_lp = Y_lTpbD_T, Y_lTpbD_p

        # Loss coefficients; the bottom Neumann coefficient is forced to 0
        # when the bottom boundary uses a Dirichlet condition.
        self.cof_PI1AD = cof_PI1AD
        self.cof_PI2AD = cof_PI2AD
        self.cof_pbN = cof_pbN
        self.cof_TbN = 0 if if_low_dirichlet else cof_TbN

        # Initialize history of losses and global iteration counter
        for history_name in (
                'train_hist', 'val_hist',
                'Tloss_hist', 'ploss_hist', 'kloss_hist',
                'PI1lossAD_hist', 'PI2lossAD_hist',
                'TpbDloss_hist', 'pbNloss_hist', 'TbNloss_hist',
                'PI1lossAD_all_hist', 'PI2lossAD_all_hist',
                'wPI1_hist', 'wPI2_hist', 'wTbN_hist'):
            setattr(self, history_name, [])
        self.iter = 0

        # Reference (true) data
        self.T_true, self.p_true, self.k_true = T_true, p_true, k_true

        # Data for normalization
        self.T_norm = T_for_normalize
        self.p_norm = p_for_normalize
        self.k_norm = k_for_normalize

        # loss mode
        self.lmode = lmode

        # Parameters for interpolation
        self.order = order_interpolate
        self.reg = regularization_interpolate

        # The type of floating-point numbers
        self.dtype = dtype

        # Weighting parameters for each loss function
        self.wPI1 = self.wPI2 = self.wTbN = 1.0
        for buffer_name in ('gTPK', 'gPI1', 'gPI2', 'gTbN'):
            setattr(self, buffer_name, [])

    def Coef_IF97_Eq15_Region1(self,i):
        """
        Return the coefficients of term ``i`` of the IF97 Region 1 equation.

        Parameters
        ----------
        i : int
            1-based index of the term, 1 <= i <= 34.

        Returns
        -------
        tuple
            ``(Ii, Ji, ni)``: the two integer exponents and the coefficient
            of term ``i``.

        Raises
        ------
        ValueError
            If ``i`` is not an index between 1 and 34.
        """
        # One row per term, in index order. The coefficient literals are kept
        # in their original "mantissa*10**(exp)" form so the resulting floats
        # are bit-identical to the previous implementation.
        coefficients = (
            (0, -2, 0.14632971213167),
            (0, -1, -0.84548187169114),
            (0, 0, -0.37563603672040*10),
            (0, 1, 0.33855169168385*10),
            (0, 2, -0.95791963387872),
            (0, 3, 0.15772038513228),
            (0, 4, -0.16616417199501*10**(-1)),
            (0, 5, 0.81214629983568*10**(-3)),
            (1, -9, 0.28319080123804*10**(-3)),
            (1, -7, -0.60706301565874*10**(-3)),
            (1, -1, -0.18990068218419*10**(-1)),
            (1, 0, -0.32529748770505*10**(-1)),
            (1, 1, -0.21841717175414*10**(-1)),
            (1, 3, -0.52838357969930*10**(-4)),
            (2, -3, -0.47184321073267*10**(-3)),
            (2, 0, -0.30001780793026*10**(-3)),
            (2, 1, 0.47661393906987*10**(-4)),
            (2, 3, -0.44141845330846*10**(-5)),
            (2, 17, -0.72694996297594*10**(-15)),
            (3, -4, -0.31679644845054*10**(-4)),
            (3, 0, -0.28270797985312*10**(-5)),
            (3, 6, -0.85205128120103*10**(-9)),
            (4, -5, -0.22425281908000*10**(-5)),
            (4, -2, -0.65171222895601*10**(-6)),
            (4, 10, -0.14341729937924*10**(-12)),
            (5, -8, -0.40516996860117*10**(-6)),
            (8, -11, -0.12734301741641*10**(-8)),
            (8, -6, -0.17424871230634*10**(-9)),
            (21, -29, -0.68762131295531*10**(-18)),
            (23, -31, 0.14478307828521*10**(-19)),
            (29, -38, 0.26335781662795*10**(-22)),
            (30, -39, -0.11947622640071*10**(-22)),
            (31, -40, 0.18228094581404*10**(-23)),
            (32, -41, -0.93537087292458*10**(-25)),
        )
        # Reject unknown indices explicitly instead of failing later with an
        # unrelated "local variable referenced before assignment" error.
        if i not in range(1, len(coefficients) + 1):
            raise ValueError(
                f"IF97 Region 1 term index must be in 1..{len(coefficients)}, got {i!r}")
        return coefficients[int(i) - 1]

    # Specific volume at Region 1
    def IF97_SpecificVol_Region1(self,Tdeg,pMPa):
        gamma_pi = 0.0
        R = 0.461526*10**(-3) # Specific gas constant [kJ/(g*K)]
        TK = Tdeg + 273
        tau = 1386/TK
        ppi = pMPa/16.53
        for i in range(1,35):
            Ii, Ji, ni = self.Coef_IF97_Eq15_Region1(i)
            gamma_pi = gamma_pi - ni*Ii*((7.1-ppi)**(Ii-1))*((tau-1.222)**Ji)
        svol = ppi*gamma_pi*R*TK/pMPa
        return svol
    
    # Region 1 Specific_isobaric_heatcapacity
    def IF97_SpecificCp_Region1(self,Tdeg,pMPa):
        gamma_tautau = 0.0
        R = 0.461526 # Specific gas constant [kJ/(kg*K)]
        TK = Tdeg + 273
        tau = 1386/TK
        ppi = pMPa/16.53
        for i in range(1,35):
            Ii, Ji, ni = self.Coef_IF97_Eq15_Region1(i)
            gamma_tautau = gamma_tautau + ni*((7.1-ppi)**Ii)*Ji*(Ji-1)*(tau-1.222)**(Ji-2)
        scp = -R*(tau**2)*gamma_tautau
        return scp*10**3 # [J/(kg*K)]
    
    # Calculate density of water
    def Densw_pred(self,Tdeg,pPa):
        pMPa = pPa/(10**6)
        Densw = 1.0/self.IF97_SpecificVol_Region1(Tdeg,pMPa)
        return Densw
    
    # Calculate heat capacity
    def HCw_pred(self,Tdeg,pPa):
        pMPa = pPa/(10**6)
        HCw_calc = self.IF97_SpecificCp_Region1(Tdeg,pMPa)
        return HCw_calc
    
    def Viscow_pred(self,Tdeg,dens):
        """
        Viscosity of water from temperature and density.

        The result is the product of a temperature-only factor ``myu0`` and a
        temperature/density factor ``myu1``, scaled by 10**(-6).

        Parameters
        ----------
        Tdeg
            Temperature in degrees Celsius (273 is added to obtain kelvin).
        dens
            Density; it is divided by 322.0 to obtain the reduced density.

        Returns
        -------
        Viscosity (presumably in Pa*s given the 10**(-6) factor; confirm
        against the reference correlation).
        """
        That = (Tdeg + 273) / 647.096
        rouhat = dens / 322.0

        # Temperature-only factor: 100*sqrt(That) / sum_i c_i / That**i
        temp_coeffs = (1.67752, 2.20462, 0.6366564, -0.241605)
        temp_series = temp_coeffs[0] / That**0
        for power in range(1, len(temp_coeffs)):
            temp_series = temp_series + temp_coeffs[power] / That**power
        myu0 = 100.0 * tf.math.sqrt(That) / temp_series

        # Coefficient table H[ii][jj]: row index is the power of (1/That - 1),
        # column index is the power of (rouhat - 1).
        H = (
            (5.20094*10**(-1), 2.22531*10**(-1), -2.81378*10**(-1), 1.61913*10**(-1), -3.25372*10**(-2), 0.0, 0.0),
            (8.50895*10**(-2), 9.99115*10**(-1), -9.06851*10**(-1), 2.57399*10**(-1), 0.0, 0.0, 0.0),
            (-1.08374, 1.88797, -7.72479*10**(-1), 0.0, 0.0, 0.0, 0.0),
            (-2.89555*10**(-1), 1.26613, -4.89837*10**(-1), 0.0, 6.98452*10**(-2), 0.0, -4.35673*10**(-3)),
            (0.0, 0.0, -2.57040*10**(-1), 0.0, 0.0, 8.72102*10**(-3), 0.0),
            (0.0, 1.20573*10**(-1), 0.0, 0.0, 0.0, 0.0, -5.93264*10**(-4)),
        )
        exponent = 0.0
        for ii, row in enumerate(H):
            density_series = 0.0
            for jj, h_ij in enumerate(row):
                density_series = density_series + h_ij * (rouhat-1)**jj
            exponent = exponent + ((1/That-1)**ii) * density_series
        myu1 = tf.math.exp(rouhat*exponent)

        return myu0 * myu1 * 10**(-6)

    def denormalize(self,xx,xxdata):
        """
        Map normalized values back to the range of ``xxdata``.

        Returns ``xx * (max(xxdata) - min(xxdata)) + min(xxdata)``, with the
        extrema taken by NumPy over all elements of ``xxdata``.
        """
        lower = np.min(xxdata)
        span = np.max(xxdata) - lower
        return xx*span + lower

    def normalize(self,xx,xxdata):
        """
        Min-max scale ``xx`` with the extrema of ``xxdata``.

        Returns ``(xx - min(xxdata)) / (max(xxdata) - min(xxdata))``, with the
        extrema computed by ``tf.reduce_min`` / ``tf.reduce_max``.
        """
        upper = tf.reduce_max(xxdata)
        lower = tf.reduce_min(xxdata)
        shifted = xx - lower
        return shifted / (upper - lower)

    def reshape_and_rescale(
            self, var_grid: tf.Tensor,
            var: str) -> tf.Tensor:

        # The 3D grid data are reordered into a single column.
        var_ansatz = rcc.rearrange_grid_to_list(var_grid)

        # # Denormalize the quantities
        if var=='k': # permeability
            var_denm = self.denormalize(var_ansatz, self.k_norm)
        elif (var == 'T'): # temperature
            var_denm = self.denormalize(var_ansatz, self.T_norm)
        elif (var == 'p'): # pressure
            var_denm = self.denormalize(var_ansatz, self.p_norm)

        return var_denm

    def extract_points(self,
            in_tensor: tf.Tensor, grid_shape: tuple, var: Enum,
            model_input_grid: tf.Tensor,
            point_indices: tf.Tensor) -> tf.Tensor:
        """
        Pick the values at selected grid points out of list-ordered data.

        ``in_tensor`` is put back into grid form with
        ``rcc.rearrange_list_to_grid`` using ``grid_shape``, the leading axis
        is squeezed away, and the entries addressed by ``point_indices`` are
        gathered with ``tf.gather_nd``.

        ``var`` and ``model_input_grid`` are accepted but not used here.
        """
        # Re-organize the data in a 3D grid format and drop the leading axis.
        grid_values = tf.squeeze(
            rcc.rearrange_list_to_grid(arr_list=in_tensor, grid_shape=grid_shape),
            axis=0)
        # One gathered row per entry of point_indices.
        return tf.gather_nd(grid_values, point_indices)
    
    def get_mass_loss(self):
        """
        Residual of the mass balance on every point of the input grid.

        With ``ddv = dens / visc`` the flux components are

            g1 = 10**k * ddv * dp/dx1
            g2 = 10**k * ddv * (dp/dx2 - dens * 9.8)
            g3 = 10**k * ddv * dp/dx3

        and the residual is ``func_ms(dg1/dx1, dg2/dx2, dg3/dx3)``.

        Returns
        -------
        tf.Tensor
            Residual, shape = (1, N_x1, N_x2, N_x3, 1)

        Notes
        -----
        ``p_star``, ``T_star`` and ``k_star`` are module-level names that are
        not defined in this class (presumably the same data as
        ``self.p_norm`` / ``self.T_norm`` / ``self.k_norm``; TODO confirm).
        """
        grid = self._model_input_grid

        with tf.GradientTape(persistent=True) as tape1:
            # Outer tape: differentiates the fluxes (second-order derivatives).
            tape1.watch(grid)
            with tf.GradientTape(persistent=True) as tape2:
                # Inner tape: differentiates the pressure (first-order derivatives).
                tape2.watch(grid)

                # Compute quantities by the PINN
                # input shape = (1, N_x1, N_x2, N_x3, 3)
                # each output shape = (1, N_x1, N_x2, N_x3, 1)
                T_grid, p_grid, k_grid = self.model(grid)
                pdenm = self.denormalize(p_grid, p_star)

            # Pressure gradient on the entire grid; the last axis holds the
            # derivatives with respect to x1, x2, x3.
            p_x_grid = tape2.gradient(pdenm, grid)  # shape = (1, N_x1, N_x2, N_x3, 3)
            p_x1 = p_x_grid[:,:,:,:,0:1]
            p_x2 = p_x_grid[:,:,:,:,1:2]
            p_x3 = p_x_grid[:,:,:,:,2:3]
            Tdenm = self.denormalize(T_grid, T_star)
            kdenm = self.denormalize(k_grid, k_star)
            dens = self.Densw_pred(Tdenm, pdenm)
            visc = self.Viscow_pred(Tdenm, dens)
            # Gravity term; it is applied to the x2 component only.
            densg = dens * 9.8
            ddv = tf.math.divide(dens,visc)
            pf_x1 = tf.math.multiply(ddv, p_x1)
            pf_x3 = tf.math.multiply(ddv, p_x3)
            pf_x2tp = tf.math.subtract(p_x2,densg)
            pf_x2 = tf.math.multiply(ddv, pf_x2tp)
            # kdenm is used as a base-10 exponent.
            g1 = tf.math.multiply(10**kdenm, pf_x1)
            g2 = tf.math.multiply(10**kdenm, pf_x2)
            g3 = tf.math.multiply(10**kdenm, pf_x3)

        # tape1.gradient returns the derivatives of a flux with respect to ALL
        # three coordinates (last axis of size 3). The divergence needs only
        # dg_i/dx_i, so the matching component is selected for each flux;
        # summing the full gradients would mix in the cross derivatives.
        f1tp = tape1.gradient(g1, grid)[:,:,:,:,0:1]  # shape = (1, N_x1, N_x2, N_x3, 1)
        f2tp = tape1.gradient(g2, grid)[:,:,:,:,1:2]
        f3tp = tape1.gradient(g3, grid)[:,:,:,:,2:3]

        # Replace non-finite derivatives (NaN / Inf) by a large constant.
        posf1 = tf.math.is_finite(f1tp)
        posf2 = tf.math.is_finite(f2tp)
        posf3 = tf.math.is_finite(f3tp)
        f1 = tf.where(posf1,f1tp,[10**5])
        f2 = tf.where(posf2,f2tp,[10**5])
        f3 = tf.where(posf3,f3tp,[10**5])

        # Persistent tapes hold their resources until released.
        del tape2
        del tape1

        return self.func_ms(f1, f2, f3)

    def get_energy_loss(self):
        """
        Residual of the energy balance on every point of the input grid.

        With ``ddw = T * hcw * dens / visc`` the advective fluxes are

            g1 = 10**k * ddw * dp/dx1
            g2 = 10**k * ddw * (dp/dx2 - dens * 9.8)
            g3 = 10**k * ddw * dp/dx3

        and the residual is
        ``func_en(dg1/dx1, dg2/dx2, dg3/dx3, d2T/dx1^2, d2T/dx2^2, d2T/dx3^2)``.

        Returns
        -------
        tf.Tensor
            Residual, shape = (1, N_x1, N_x2, N_x3, 1)

        Notes
        -----
        ``p_star``, ``T_star`` and ``k_star`` are module-level names that are
        not defined in this class (presumably the same data as
        ``self.p_norm`` / ``self.T_norm`` / ``self.k_norm``; TODO confirm).
        """
        grid = self._model_input_grid

        with tf.GradientTape(persistent=True) as tape1:
            # Outer tape: differentiates fluxes and first derivatives of T.
            tape1.watch(grid)
            with tf.GradientTape(persistent=True) as tape2:
                # Inner tape: first-order derivatives of p and T.
                tape2.watch(grid)

                # Compute quantities by the PINN
                # input shape = (1, N_x1, N_x2, N_x3, 3)
                # each output shape = (1, N_x1, N_x2, N_x3, 1)
                T_grid, p_grid, k_grid = self.model(grid)
                pdenm = self.denormalize(p_grid, p_star)
                Tdenm = self.denormalize(T_grid, T_star)

            # First derivatives; the last axis holds d/dx1, d/dx2, d/dx3.
            p_x_grid = tape2.gradient(pdenm, grid)  # shape = (1, N_x1, N_x2, N_x3, 3)
            p_x1 = p_x_grid[:,:,:,:,0:1]
            p_x2 = p_x_grid[:,:,:,:,1:2]
            p_x3 = p_x_grid[:,:,:,:,2:3]
            T_x_grid = tape2.gradient(Tdenm, grid)  # shape = (1, N_x1, N_x2, N_x3, 3)
            T_x1 = T_x_grid[:,:,:,:,0:1]
            T_x2 = T_x_grid[:,:,:,:,1:2]
            T_x3 = T_x_grid[:,:,:,:,2:3]

            kdenm = self.denormalize(k_grid, k_star)
            dens = self.Densw_pred(Tdenm, pdenm)
            visc = self.Viscow_pred(Tdenm, dens)
            hcw = self.HCw_pred(Tdenm, pdenm)
            # Gravity term; it is applied to the x2 component only.
            densg = dens * 9.8
            ddw_tp1 = tf.math.divide(dens,visc)
            ddw_tp2 = tf.math.multiply(hcw,ddw_tp1)
            ddw = tf.math.multiply(Tdenm,ddw_tp2)
            pf_x1 = tf.math.multiply(ddw, p_x1)
            pf_x3 = tf.math.multiply(ddw, p_x3)
            pf_x2tp = tf.math.subtract(p_x2,densg)
            pf_x2 = tf.math.multiply(ddw, pf_x2tp)
            # kdenm is used as a base-10 exponent.
            g1 = tf.math.multiply(10**kdenm, pf_x1)
            g2 = tf.math.multiply(10**kdenm, pf_x2)
            g3 = tf.math.multiply(10**kdenm, pf_x3)

        # tape1.gradient returns derivatives with respect to ALL three
        # coordinates (last axis of size 3). Only the matching component
        # dg_i/dx_i belongs to the divergence, so it is selected explicitly;
        # using the full gradients would mix in the cross derivatives.
        f_x1tp_grid = tape1.gradient(g1, grid)[:,:,:,:,0:1]  # shape = (1, N_x1, N_x2, N_x3, 1)
        f_x2tp_grid = tape1.gradient(g2, grid)[:,:,:,:,1:2]
        f_x3tp_grid = tape1.gradient(g3, grid)[:,:,:,:,2:3]

        # Replace non-finite derivatives (NaN / Inf) by a large constant.
        posf1 = tf.math.is_finite(f_x1tp_grid)
        posf2 = tf.math.is_finite(f_x2tp_grid)
        posf3 = tf.math.is_finite(f_x3tp_grid)
        f_x1x1 = tf.where(posf1,f_x1tp_grid,[10**5])
        f_x2x2 = tf.where(posf2,f_x2tp_grid,[10**5])
        f_x3x3 = tf.where(posf3,f_x3tp_grid,[10**5])

        # Pure second derivatives of T (the terms of the Laplacian), again
        # keeping only the component that matches the first derivative.
        T_x1x1 = tape1.gradient(T_x1, grid)[:,:,:,:,0:1]
        T_x2x2 = tape1.gradient(T_x2, grid)[:,:,:,:,1:2]
        T_x3x3 = tape1.gradient(T_x3, grid)[:,:,:,:,2:3]

        # Persistent tapes hold their resources until released.
        del tape2
        del tape1

        return self.func_en(f_x1x1, f_x2x2, f_x3x3, T_x1x1, T_x2x2, T_x3x3)
    
    
    def func_ms(self, fl1, fl2, fl3):
        """Residual of the mass-balance PDE: the sum of the three flux derivatives."""
        partial_sum = fl1 + fl2
        return partial_sum + fl3

    def func_en(self, fl_x1x1, fl_x2x2, fl_x3x3,
               T_x1x1, T_x2x2, T_x3x3):
        """
        Residual of the energy-balance PDE.

        Sum of the three flux derivatives plus ``Lambda`` times the sum of the
        three second derivatives of T. ``Lambda`` is a module-level name that
        is not defined in this class.
        """
        flux_sum = fl_x1x1 + fl_x2x2 + fl_x3x3
        temperature_sum = T_x1x1 + T_x2x2 + T_x3x3
        return flux_sum + Lambda * temperature_sum
    
    def get_TpbD(self):
        """
        Normalized model temperature and pressure on the upper and bottom boundaries.

        Returns
        -------
        tuple of tf.Tensor
            ``(utt, upp, ltt, lpp)``: temperature and pressure at the upper
            boundary points, each of shape (1, N_up, 1), followed by
            temperature and pressure at the bottom boundary points, each of
            shape (1, N_low, 1).
        """
        T_grid, p_grid, _ = self.model(self._model_input_grid)
        grid_shape = T_grid.shape
        spec = self._model_input_grid_spec

        # The data stored in a three-dimensional grid is rearranged into a
        # single list and converted into physical quantities.
        T_physical = self.reshape_and_rescale(T_grid, var='T')
        p_physical = self.reshape_and_rescale(p_grid, var='p')

        def at_boundary(values, var_type, indices):
            """Gather list-ordered ``values`` at the boundary points ``indices``."""
            return self.extract_points(
                in_tensor=values, grid_shape=grid_shape, var=var_type,
                model_input_grid=self._model_input_grid,
                point_indices=indices)

        # Upper surface boundary
        T_up = at_boundary(T_physical, VariableType.T, spec.up_indices)
        p_up = at_boundary(p_physical, VariableType.p, spec.up_indices)
        # Bottom boundary
        T_low = at_boundary(T_physical, VariableType.T, spec.low_indices)
        p_low = at_boundary(p_physical, VariableType.p, spec.low_indices)

        # Back to normalized values, with a leading axis of size 1.
        up_shape = [1, spec.up_indices.shape[0], 1]
        utt = tf.reshape(self.normalize(T_up, self.T_norm), up_shape)
        upp = tf.reshape(self.normalize(p_up, self.p_norm), up_shape)

        low_shape = [1, spec.low_indices.shape[0], 1]
        ltt = tf.reshape(self.normalize(T_low, self.T_norm), low_shape)
        lpp = tf.reshape(self.normalize(p_low, self.p_norm), low_shape)

        return utt, upp, ltt, lpp

    def get_TbN(self):
        """
        Normal derivative of the (denormalized) temperature on the bottom boundary.

        The temperature gradient with respect to the input coordinates is
        gathered at the bottom boundary points and contracted with the
        normal vectors ``self.nvecT``.

        Returns
        -------
        tf.Tensor
            ``dT/dx1 * n1 + dT/dx2 * n2 + dT/dx3 * n3`` per bottom boundary
            point, shape = (N_low, 1)
        """
        low_indices = self._model_input_grid_spec.low_indices

        # A single tape is sufficient: one backward pass yields the derivatives
        # with respect to all three coordinates at once.
        with tf.GradientTape() as tape:
            tape.watch(self._model_input_grid)

            # Prediction by the PINN
            # input shape = (1, N_x1, N_x2, N_x3, 3)
            # each output shape = (1, N_x1, N_x2, N_x3, 1)
            T_grid, _, _ = self.model(self._model_input_grid)

            # The grid data is rearranged into a single list and converted
            # into physical quantities, then restricted to the bottom boundary.
            Tdenm = self.reshape_and_rescale(T_grid, var='T')
            Tdenm = self.extract_points(
                in_tensor=Tdenm, grid_shape=T_grid.shape, var=VariableType.T,
                model_input_grid=self._model_input_grid,
                point_indices=low_indices)

        # Gradient on the whole grid, then picked at the bottom boundary points.
        gT_grid = tape.gradient(Tdenm, self._model_input_grid)
        gT_low = tf.gather_nd(tf.squeeze(gT_grid, axis=0), low_indices)  # shape = (N_low, 3)

        # Directional derivative along the boundary normal.
        gT = (gT_low[:, 0:1]*self.nvecT[:,0:1]
              + gT_low[:, 1:2]*self.nvecT[:,1:2]
              + gT_low[:, 2:3]*self.nvecT[:,2:3])

        return gT

    def calculate_tpk_loss(self, X,
                            X_wells, T_wells, p_wells, k_wells,
                            well_neighbor_indices):
        """
        Data misfit of temperature, pressure and permeability at the wells.

        The model is evaluated on the input grid, its outputs are taken at the
        grid points listed in ``well_neighbor_indices``, interpolated to the
        well positions with ``tfa.image.interpolate_spline`` and compared with
        the well data.

        Parameters
        ----------
        X
            Coordinates of the grid points near the wells; reshaped to
            (1, N, 3) with N = ``X.shape[1]``.
        X_wells
            Positions at which the interpolated values are evaluated.
        T_wells, p_wells, k_wells
            Well data compared with the interpolated model output.
        well_neighbor_indices
            Flat grid indices of the points near the wells, with a leading
            axis of size 1.

        Returns
        -------
        tuple of tf.Tensor
            ``(Tloss, ploss, kloss)``: mean squared differences.
        """
        n_neighbors = X.shape[1]
        neighbor_coords = tf.reshape(X, [1, n_neighbors, 3])
        T_grid, p_grid, k_grid = self.model(self._model_input_grid)
        grid_shape = T_grid.shape

        # Flat indices -> per-axis indices, with the columns reordered to
        # match the axis order of the input grid.
        spec = self._model_input_grid_spec
        neighbor_multi_index = tf.transpose(
            tf.unravel_index(tf.squeeze(well_neighbor_indices, axis=0),
                             dims=(spec.N_x3, spec.N_x2, spec.N_x1)))
        grid_well_neighbor_indices = tf.gather(
            neighbor_multi_index, indices=[2, 1, 0], axis=1)

        def well_misfit(output_grid, symbol, var_type, norm_data, observed):
            """Mean squared misfit of one quantity at the wells."""
            # Grid -> list, converted into physical quantities.
            physical = self.reshape_and_rescale(output_grid, var=symbol)
            # Values at the points near the wells.
            near_wells = self.extract_points(
                in_tensor=physical, grid_shape=grid_shape, var=var_type,
                model_input_grid=self._model_input_grid,
                point_indices=grid_well_neighbor_indices)
            # Normalization
            normalized = tf.reshape(
                self.normalize(near_wells, norm_data), [1, n_neighbors, 1])
            # Interpolate to the well positions.
            at_wells = tfa.image.interpolate_spline(
                neighbor_coords, normalized, X_wells, self.order,
                regularization_weight=self.reg)
            return tf.reduce_mean(tf.square(at_wells - observed))

        Tloss = well_misfit(T_grid, 'T', VariableType.T, self.T_norm, T_wells)
        ploss = well_misfit(p_grid, 'p', VariableType.p, self.p_norm, p_wells)
        kloss = well_misfit(k_grid, 'k', VariableType.k, self.k_norm, k_wells)

        return Tloss, ploss, kloss

    def loss_func(self, X,
                Y_pbN, Y_TbN,
                X_wells, T_wells, p_wells, k_wells,
                well_neighbor_indices, dstype):
        """
        Evaluate the total loss together with its individual components.

        * The physics residual terms and the boundary terms are evaluated only
          when ``self.lmode`` is 0 or 1; for every other mode they are zero
          constants.
        * The data-misfit terms at the wells are evaluated by
          ``calculate_tpk_loss`` unless ``dstype`` is ``DataSet.ALL``, in which
          case they are zero constants.
        * The physics and boundary terms enter the total loss only when
          ``self.lmode`` is positive; otherwise the total is the data misfit.

        ``Y_pbN`` is accepted for interface compatibility but is not used here.

        Returns
        -------
        tuple
            ``(loss, Tloss, ploss, kloss, PI1loss_AD, PI2loss_AD, TpbDloss, TbNloss)``
        """
        def zero():
            # A fresh scalar zero in the solver's dtype.
            return tf.constant(0.0, self.dtype)

        def mean_square(residual):
            return tf.reduce_mean(tf.square(residual))

        if self.lmode != 1 and self.lmode != 0:
            # Physics and boundary terms are switched off in this mode.
            PI1loss_AD, PI2loss_AD = zero(), zero()
            TpbDloss, TbNloss = zero(), zero()
        else:
            # Physics-informed residuals (mass and energy).
            mass_residual = self.get_mass_loss()
            energy_residual = self.get_energy_loss()
            PI1loss_AD = mean_square(mass_residual)
            PI2loss_AD = mean_square(energy_residual)

            # Boundary predictions: Dirichlet (upper/lower) and Neumann.
            upper_T_pred, upper_p_pred, lower_T_pred, lower_p_pred = self.get_TpbD()
            TbN_pred = self.get_TbN()

            # Dirichlet misfit on the upper and on the lower boundary.
            upper_loss = (mean_square(self.Y_uT - upper_T_pred)
                          + mean_square(self.Y_up - upper_p_pred))
            lower_loss = (mean_square(self.Y_lT - lower_T_pred)
                          + mean_square(self.Y_lp - lower_p_pred))

            if self.if_low_dirichlet:
                # Dirichlet bottom: both surfaces contribute, no Neumann term.
                TpbDloss = upper_loss + lower_loss
                TbNloss = zero()
            else:
                # Neumann bottom: only the upper surface is Dirichlet.
                TpbDloss = upper_loss
                TbNloss = mean_square(Y_TbN - TbN_pred)

        # Data misfit at the wells (skipped for the "all points" data set).
        if dstype == DataSet.ALL:
            Tloss, ploss, kloss = zero(), zero(), zero()
        else:
            Tloss, ploss, kloss = self.calculate_tpk_loss(
                X,
                X_wells, T_wells, p_wells, k_wells,
                well_neighbor_indices)

        # Sum up: data misfit first, then the weighted physics/boundary terms.
        loss = Tloss + ploss + kloss
        if self.lmode > 0:
            loss = (loss
                    + self.wPI1 * PI1loss_AD * self.cof_PI1AD
                    + self.wPI2 * PI2loss_AD * self.cof_PI2AD
                    + TpbDloss
                    + self.wTbN * TbNloss * self.cof_TbN)

        return loss, Tloss, ploss, kloss, PI1loss_AD, PI2loss_AD, TpbDloss, TbNloss

    def get_grad(self, X,
                Y_pbN, Y_TbN,
                X_wells, T_wells, p_wells, k_wells,
                well_neighbor_indices, dstype, nan_to_zero=True):
        """
        Evaluate all loss terms and the gradient of the total loss.

        Every argument except ``nan_to_zero`` is forwarded unchanged to
        ``loss_func``.

        Side effects
        ------------
        * ``self.lmode < 0``: ``self.wPI1``, ``self.wPI2`` and ``self.wTbN``
          are reset to 1.0.
        * Otherwise, when ``self.iter`` is a multiple of 100, the three weights
          are re-estimated from per-term gradient magnitudes and blended into
          the current values (0.6 * old + 0.4 * new).
        * When ``self.iter`` is in addition a non-zero multiple of 2500, the
          flattened per-variable gradients are appended to ``self.gTPK``,
          ``self.gPI1``, ``self.gPI2`` and ``self.gTbN``.

        NOTE(review): these side effects do not depend on ``dstype``, so they
        also run when this method is called only to evaluate validation or
        all-point losses - confirm that this is intended.

        Parameters
        ----------
        nan_to_zero : bool
            When True, any gradient tensor of the total loss that contains at
            least one NaN is replaced as a whole by zeros.

        Returns
        -------
        tuple
            ``(loss, Tloss, ploss, kloss, PI1loss_AD, PI2loss_AD, TpbDloss,
            TbNloss, g)`` where ``g`` is the list of gradients of ``loss`` with
            respect to ``self.model.trainable_variables``.
        """
        # One persistent tape is enough: it can be queried once per loss term.
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(self.model.trainable_variables)
            (loss, Tloss, ploss, kloss,
             PI1loss_AD, PI2loss_AD, TpbDloss, TbNloss) = self.loss_func(
                X,
                Y_pbN, Y_TbN,
                X_wells, T_wells, p_wells, k_wells,
                well_neighbor_indices, dstype)

        # Read the variable list after the forward pass, as the gradient call needs it.
        variables = self.model.trainable_variables
        g = tape.gradient(loss, variables)

        if self.lmode < 0:
            self.wPI1 = 1.0
            self.wPI2 = 1.0
            self.wTbN = 1.0
        elif self.iter % 100 == 0:
            gradT = tape.gradient(Tloss, variables)
            gradP = tape.gradient(ploss, variables)
            gradK = tape.gradient(kloss, variables)
            gradPI1 = tape.gradient(PI1loss_AD, variables)
            gradPI2 = tape.gradient(PI2loss_AD, variables)
            gradTbN = tape.gradient(TbNloss, variables)

            store_grads = self.iter % 2500 == 0 and self.iter != 0

            def flat(grad):
                # Gradient tensor -> flat Python list of floats.
                return tf.reshape(grad, [-1]).numpy().tolist()

            # Mean |gradient| of every usable data-term gradient. The leading
            # zero entry is part of the averaged values (kept from the
            # original implementation).
            data_scores = [tf.constant(0.0, self.dtype)]
            # Max |gradient| of every usable gradient of the weighted terms.
            pi1_peaks, pi2_peaks, tbn_peaks = [], [], []

            for gT_el, gP_el, gK_el, gPI1_el, gPI2_el, gTbN_el in zip(
                    gradT, gradP, gradK, gradPI1, gradPI2, gradTbN):
                for data_grad in (gT_el, gP_el, gK_el):
                    if self._is_usable_grad(data_grad):
                        data_scores.append(tf.math.reduce_mean(tf.abs(data_grad)))
                        if store_grads:
                            self.gTPK.append(flat(data_grad))
                for term_grad, peaks, store_name in ((gPI1_el, pi1_peaks, 'gPI1'),
                                                     (gPI2_el, pi2_peaks, 'gPI2'),
                                                     (gTbN_el, tbn_peaks, 'gTbN')):
                    if self._is_usable_grad(term_grad):
                        peaks.append(tf.math.reduce_max(tf.abs(term_grad)))
                        if store_grads:
                            getattr(self, store_name).append(flat(term_grad))

            data_scale = tf.math.reduce_mean(tf.stack(data_scores))
            pi1_scale = self.cof_PI1AD * tf.math.reduce_max(self._as_vector(pi1_peaks, self.dtype))
            pi2_scale = self.cof_PI2AD * tf.math.reduce_max(self._as_vector(pi2_peaks, self.dtype))
            tbn_scale = self.cof_TbN * tf.math.reduce_max(self._as_vector(tbn_peaks, self.dtype))

            # Without any usable data gradient there is nothing to balance against.
            if not data_scale == 0.0:
                self.wPI1 = self._blend_weight(self.wPI1, data_scale, pi1_scale)
                self.wPI2 = self._blend_weight(self.wPI2, data_scale, pi2_scale)
                self.wTbN = self._blend_weight(self.wTbN, data_scale, tbn_scale)

        if nan_to_zero:
            for ind, grad in enumerate(g):
                if grad is not None and tf.reduce_any(tf.math.is_nan(grad)):
                    g[ind] = tf.zeros_like(grad)

        # Release the resources held by the persistent tape.
        del tape

        return loss, Tloss, ploss, kloss, PI1loss_AD, PI2loss_AD, TpbDloss, TbNloss, g

    @staticmethod
    def _is_usable_grad(grad):
        """Return True when ``grad`` exists and contains no NaN entry."""
        return grad is not None and not tf.reduce_any(tf.math.is_nan(grad))

    @staticmethod
    def _as_vector(scalars, dtype):
        """Pack scalar tensors into a rank-1 tensor; an empty list gives an empty tensor."""
        if not scalars:
            return tf.constant([], dtype)
        return tf.stack(scalars)

    @staticmethod
    def _blend_weight(weight, data_scale, term_scale):
        """
        Return the relaxed weight ``0.6 * weight + 0.4 * data_scale / term_scale``.

        ``weight`` is returned unchanged when ``term_scale`` or the ratio is NaN.
        NOTE(review): infinite ratios (e.g. ``term_scale == 0``) are not
        filtered, exactly as before - consider guarding with ``is_finite``.
        """
        if tf.reduce_any(tf.math.is_nan(term_scale)):
            return weight
        estimate = data_scale / term_scale
        if tf.reduce_any(tf.math.is_nan(estimate)):
            return weight
        return 0.6 * weight + 0.4 * estimate

    def solve_with_TFoptimizer(self, optimizer, train_ds, val_ds, all_ds, N=1001):
        """
        Train with a gradient-descent type optimizer for ``N`` epochs.

        Each epoch iterates over the shuffled training batches (one optimizer
        step per batch), then evaluates the validation batches and the
        all-point batches, stores the latest loss values on ``self`` and
        finally calls ``self.callback()``.
        """
        self.max_iteration = N

        def evaluate(batch, dstype):
            # Returns ([loss, Tloss, ploss, kloss, PI1, PI2, TpbD, TbN], gradients).
            *losses, grad_theta = self.get_grad(*batch, dstype)
            return losses, grad_theta

        train_attributes = ('current_loss', 'Tloss', 'ploss', 'kloss',
                            'PI1loss_AD', 'PI2loss_AD', 'TpbDloss', 'TbNloss')

        for epoch in range(N):
            # Training data: one gradient-descent step per batch.
            for batch in train_ds.shuffle(1000).batch(25).prefetch(1):
                losses, grad_theta = evaluate(batch, DataSet.TRAIN)
                optimizer.apply_gradients(zip(grad_theta, self.model.trainable_variables))
                for attribute, value in zip(train_attributes, losses):
                    setattr(self, attribute, value.numpy())

            # Validation data: only the total loss is kept.
            for batch in val_ds.shuffle(1000).batch(50).prefetch(1):
                losses, _ = evaluate(batch, DataSet.VAL)
                self.val_loss = losses[0].numpy()

            # All points: only the two physics residual losses are kept.
            for batch in all_ds.shuffle(1000).batch(50).prefetch(1):
                losses, _ = evaluate(batch, DataSet.ALL)
                self.PI1all_AD_loss = losses[4].numpy()
                self.PI2all_AD_loss = losses[5].numpy()

            self.callback()

    def solve_with_tfp_lbfgs(self, train_ds, val_ds, all_ds, max_iterations, debug_print=False):
        """
        Train with the L-BFGS optimizer wrapper ``LBfgsOptimizer``.

        The optimizer is advanced one iteration at a time. After each advance
        the training, validation and all-point losses are re-evaluated and
        stored on ``self``, then ``self.callback`` is invoked. The loop stops
        when the optimizer reports convergence or failure, or when
        ``max_iterations`` is reached. Each data set must contain exactly one
        element.

        Returns the last result state reported by the optimizer (``None`` when
        ``max_iterations`` is not positive).
        """
        self.max_iteration = max_iterations

        def single_batch(dataset):
            # One-element unpacking: raises unless the data set holds exactly one element.
            batch, = tuple(dataset.batch(1).as_numpy_iterator())
            return batch

        def get_loss_for_dataset(dataset, dstype):
            """
            Return the loss components for ``dataset``:
            [loss, Tloss, ploss, kloss, PI1loss, PI2loss, TpbDloss, TbNloss]
            """
            (X, Y_pbN, Y_TbN, X_wells, T_wells, p_wells, k_wells,
             well_neighbor_indices) = single_batch(dataset)
            *losses, _ = self.get_grad(tf.convert_to_tensor(X),
                                       Y_pbN, Y_TbN,
                                       X_wells, T_wells, p_wells, k_wells,
                                       well_neighbor_indices, dstype)
            return losses

        def loss_gradient_function(X,
                                   Y_pbN, Y_TbN,
                                   X_wells, T_wells, p_wells, k_wells,
                                   well_neighbor_indices, dstype):
            """Return the total loss and its gradient for L-BFGS."""
            outputs = self.get_grad(X,
                                    Y_pbN, Y_TbN,
                                    X_wells, T_wells, p_wells, k_wells,
                                    well_neighbor_indices, dstype)
            loss, gradient = outputs[0], outputs[-1]
            if debug_print:
                print(f'[L-BFGS] objective evaluated: loss value = {float(loss)}')
            return loss, gradient

        # Training data handed to the objective function on every evaluation.
        (X, Y_pbN, Y_TbN, X_wells, T_wells, p_wells, k_wells,
         well_neighbor_indices) = single_batch(train_ds)
        objective_kwargs = dict(
            X=tf.convert_to_tensor(X),
            Y_pbN=Y_pbN, Y_TbN=Y_TbN,
            X_wells=X_wells, T_wells=T_wells, p_wells=p_wells, k_wells=k_wells,
            well_neighbor_indices=well_neighbor_indices, dstype=DataSet.TRAIN)

        # Arguments passed on to tfp.optimize.lbfgs_minimize().
        lbfgs_kwargs = dict(
            tolerance=1e-12,
            max_line_search_iterations=100,
        )

        lbfgs_optimizer = LBfgsOptimizer(
            model=self.model,
            loss_function=loss_gradient_function,
            loss_function_kwargs=objective_kwargs,
            lbfgs_kwargs=lbfgs_kwargs,
        )

        train_attributes = ('current_loss', 'Tloss', 'ploss', 'kloss',
                            'PI1loss_AD', 'PI2loss_AD', 'TpbDloss', 'TbNloss')

        current_iterations = 0
        result_state = None
        callback_interval = 1
        while current_iterations < max_iterations:

            # Advance the optimizer by (at most) one callback interval.
            iteration_target = min(current_iterations + callback_interval, max_iterations)
            result_state = lbfgs_optimizer.optimize(max_iterations=iteration_target)

            current_iterations = result_state.num_total_iterations
            if debug_print:
                print(f'[L-BFGS] current result state: {result_state}')

            # Re-evaluate the loss components with the current network weights.
            train_losses = get_loss_for_dataset(train_ds, DataSet.TRAIN)
            for attribute, value in zip(train_attributes, train_losses):
                setattr(self, attribute, value.numpy())

            val_losses = get_loss_for_dataset(val_ds, DataSet.VAL)
            self.val_loss = val_losses[0].numpy()

            all_losses = get_loss_for_dataset(all_ds, DataSet.ALL)
            self.PI1all_AD_loss = all_losses[4].numpy()
            self.PI2all_AD_loss = all_losses[5].numpy()

            # Stop on convergence, or when the line search failed to find a
            # step size satisfying the Wolfe conditions
            # (https://www.tensorflow.org/probability/api_docs/python/tfp/optimizer/lbfgs_minimize).
            if result_state.converged:
                stop_message = f'[L-BFGS] converged: state={result_state}.'
            elif result_state.failed:
                stop_message = f'[L-BFGS] line search step failed: state={result_state}.'
            else:
                stop_message = None

            if stop_message is not None:
                print(stop_message)
                self.callback(force_all_output=True)
                return result_state

            self.callback()

        print(f'[L-BFGS] reached max iteration: state={result_state}.')
        return result_state

    def retrieve_all_point_indices(self, model_input_grid):
        """
        Enumerate the (ix, iy, iz) index triple of every grid point.

        The grid extents are read from axes 1, 2 and 3 of
        ``model_input_grid``. Rows are ordered with ix varying fastest, then
        iy, then iz. The result is a ``tf.constant`` with one row per grid
        point and three columns.
        """
        x_size = model_input_grid.shape[1]
        y_size = model_input_grid.shape[2]
        z_size = model_input_grid.shape[3]

        # 'ij' indexing with (z, y, x) axes makes x the fastest-varying index
        # once the arrays are flattened in C order.
        iz, iy, ix = np.meshgrid(np.arange(z_size, dtype=int),
                                 np.arange(y_size, dtype=int),
                                 np.arange(x_size, dtype=int),
                                 indexing='ij')
        index_triples = np.stack([ix.ravel(), iy.ravel(), iz.ravel()], axis=1)

        return tf.constant(index_triples)

    def flatten_list(self,nested_list):
        """
        Return a new flat list with the elements of ``nested_list``.

        Elements that are ``list`` instances are expanded in place, at any
        depth; every other element (tuples, strings, ...) is kept as is.
        """
        flattened = []
        # Stack of iterators: the top one is the list currently being walked.
        pending = [iter(nested_list)]
        while pending:
            for element in pending[-1]:
                if isinstance(element, list):
                    # Descend; the outer iterator resumes after the sub-list.
                    pending.append(iter(element))
                    break
                flattened.append(element)
            else:
                # Current list exhausted: go back to its parent.
                pending.pop()
        return flattened
    
    def callback(self, xr=None, force_all_output=False):
        """
        Log progress, record loss histories and write outputs for one iteration.

        Steps, in order:

        1. When ``self.iter`` is a multiple of 100, print the current losses.
        2. When ``self.iter`` is a multiple of 5000, save a weight checkpoint
           to ``./save_checkpoints``.
        3. Append the current loss values to the history lists, reset the
           stored gradient samples and increment ``self.iter``.
        4. When the incremented ``self.iter`` is 1 or a multiple of 100 (or
           ``force_all_output`` is set), write the predicted T, p, k to
           ``./save_predict`` and append the relative L2 errors against the
           reference fields to a CSV file in ``./save_checkpoints``.
        5. When the incremented ``self.iter`` is a multiple of 1000 (or
           ``force_all_output`` is set), write the loss histories and 3-D
           scatter plots of T, p and k to ``./save_checkpoints``.

        Parameters
        ----------
        xr : unused
            Accepted for interface compatibility; not read.
        force_all_output : bool
            Force steps 4 and 5 regardless of the iteration count.
        """
        if self.iter % 100 == 0:
            print('#----------------------#')
            print('It {:05d}:'.format(self.iter))
            print('trainloss,valloss = {:10.8e},{:10.8e}'.format(self.current_loss,self.val_loss))
            print('Tloss,Ploss,Kloss {:10.8e},{:10.8e},{:10.8e}'.format(self.Tloss,self.ploss,self.kloss))
            print('Mass_loss,Energy_loss {:10.8e},{:10.8e}'.format(self.PI1loss_AD,self.PI2loss_AD))
            print('TpbDloss,TpNloss {:10.8e},{:10.8e}'.format(self.TpbDloss,self.TbNloss))

        if self.iter % 5000 == 0:
            checkpoint = tf.train.Checkpoint(model=self.model.variables)
            manager = tf.train.CheckpointManager(checkpoint, directory='./save_checkpoints', checkpoint_name='weights.ckpt', max_to_keep=500)
            path = manager.save(checkpoint_number=self.iter)
            print("weights saved to %s" % path)

        self.train_hist.append(self.current_loss)
        self.val_hist.append(self.val_loss)
        self.Tloss_hist.append(self.Tloss)
        self.ploss_hist.append(self.ploss)
        self.kloss_hist.append(self.kloss)
        self.PI1lossAD_hist.append(self.PI1loss_AD)
        self.PI2lossAD_hist.append(self.PI2loss_AD)
        self.TpbDloss_hist.append(self.TpbDloss)
        self.TbNloss_hist.append(self.TbNloss)
        self.PI1lossAD_all_hist.append(self.PI1all_AD_loss)
        self.PI2lossAD_all_hist.append(self.PI2all_AD_loss)

        # Drop the gradient samples collected by get_grad. gTbN is reset
        # together with the others so that it cannot grow without bound.
        self.gTPK = []
        self.gPI1 = []
        self.gPI2 = []
        self.gTbN = []
        self.iter+=1

        write_predictions = self.iter % 100 == 0 or self.iter == 1 or force_all_output
        write_histories = self.iter % 1000 == 0 or force_all_output
        if not write_predictions:
            # A history iteration is always also a prediction iteration,
            # so nothing is left to do here.
            return

        # Both output directories are written below; make sure they exist.
        os.makedirs('./save_predict', exist_ok=True)
        os.makedirs('./save_checkpoints', exist_ok=True)

        T_grid, p_grid, k_grid = self.model(self._model_input_grid)

        Tdenm = self.reshape_and_rescale(T_grid, var='T')
        pdenm = self.reshape_and_rescale(p_grid, var='p')
        kdenm = self.reshape_and_rescale(k_grid, var='k')

        all_point_indices = self.retrieve_all_point_indices(self._model_input_grid)
        grid_shape = T_grid.shape
        Tdenm = self.extract_points(
            in_tensor=Tdenm, grid_shape=grid_shape, var=VariableType.T,
            model_input_grid=self._model_input_grid,
            point_indices=all_point_indices)
        pdenm = self.extract_points(
            in_tensor=pdenm, grid_shape=grid_shape, var=VariableType.p,
            model_input_grid=self._model_input_grid,
            point_indices=all_point_indices)
        # NOTE(review): unlike T and p, k is used as returned by
        # reshape_and_rescale (no extract_points call) - confirm this is intended.

        Tpkdenm_df = pd.concat([pd.DataFrame(Tdenm.numpy()),
                                pd.DataFrame(pdenm.numpy()),
                                pd.DataFrame(kdenm.numpy())], axis=1)
        Tpkdenm_df.columns=['T_pred','p_pred','k_pred']
        Tpkdenm_df.to_csv('./save_predict/predicted_'+str(self.iter)+'.csv',index=False)

        # Relative L2 error of each field against the reference model.
        N = self.xall.shape[0]
        err_T = self._relative_l2_error(Tdenm.numpy().reshape(N), self.T_true.reshape(N))
        err_p = self._relative_l2_error(pdenm.numpy().reshape(N), self.p_true.reshape(N))
        err_k = self._relative_l2_error(kdenm.numpy().reshape(N), self.k_true.reshape(N))
        print(self.iter, err_T, err_p, err_k)

        filename = "./save_checkpoints/error_" + "lmode_" + str(self.lmode) + ".csv"
        # The column header is written only together with the first record.
        header = ''
        if (self.iter == 1):
            header = 'Iterations, L2-error(T), L2-error (P), L2-error (log k)'

        with open(filename, 'a') as f:
            errorlist = np.array([self.iter, err_T, err_p, err_k]).reshape(1,4)
            np.savetxt(f, errorlist, delimiter = ',', header = header, comments = '')

        if not write_histories:
            return

        # One column per history; concatenating single-column frames keeps
        # histories of different length side by side (padded with NaN).
        history_columns = [
            ('train_loss', self.train_hist),
            ('val_loss', self.val_hist),
            ('Tnet_loss', self.Tloss_hist),
            ('Pnet_loss', self.ploss_hist),
            ('Knet_loss', self.kloss_hist),
            ('PI1_loss_AD', self.PI1lossAD_hist),
            ('PI2_loss_AD', self.PI2lossAD_hist),
            ('TpbDnet_loss', self.TpbDloss_hist),
            ('TbN_loss', self.TbNloss_hist),
            ('PI1_allgrid', self.PI1lossAD_all_hist),
            ('PI2_allgrid', self.PI2lossAD_all_hist),
            ('Weight_PI1loss', self.wPI1_hist),
            ('Weight_PI2loss', self.wPI2_hist),
            ('Weight_TbNloss', self.wTbN_hist),
        ]
        histories_df = pd.concat(
            [pd.DataFrame([history]).T for _, history in history_columns], axis=1)
        histories_df.columns = [column for column, _ in history_columns]
        histories_df.to_csv('./save_checkpoints/histories_loss.csv',index=False)

        # 3-D scatter plots of the predicted fields.
        suffix = str(self.iter) + "lmode_" + str(self.lmode) + ".png"
        self._save_scatter3d(Tdenm, "T", 'hot', 15, 350,
                             "./save_checkpoints/T_" + suffix)
        self._save_scatter3d(pdenm, "p", 'cool', 1e5, 2.5e7,
                             "./save_checkpoints/p_" + suffix)
        self._save_scatter3d(kdenm, "k", 'jet', -17, -12,
                             "./save_checkpoints/k_" + suffix)

    @staticmethod
    def _relative_l2_error(pred, true):
        """Return ||pred - true||_2 / ||true||_2 for two 1-D arrays."""
        diff = pred - true
        return np.sqrt(np.dot(diff, diff)/np.dot(true, true))

    def _save_scatter3d(self, values, title, cmap_name, vmin, vmax, filename):
        """Save a 3-D scatter plot of ``values`` at the points ``self.xall`` to ``filename``."""
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in
        # recent Matplotlib releases.
        points = ax.scatter(self.xall[:, 0], self.xall[:, 1], self.xall[:, 2],
                            c=values, vmin=vmin, vmax=vmax,
                            cmap=plt.get_cmap(cmap_name))
        # Columns 0, 1, 2 of xall are plotted on the x, y and z axes.
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("z")
        ax.set_title(title)
        fig.colorbar(points, ax=ax)
        fig.savefig(filename)
        # Only the figure created here is closed.
        plt.close(fig)
                
    def restore_checkpoint(self):
        """
        Restore the model variables from the latest checkpoint in ./save_checkpoints.

        The path is looked up with ``tf.train.latest_checkpoint``; its result
        is handed to ``Checkpoint.restore`` without further checks.
        """
        latest_path = tf.train.latest_checkpoint('./save_checkpoints')
        model_checkpoint = tf.train.Checkpoint(model=self.model.variables)
        model_checkpoint.restore(latest_path)

In [None]:
# Floating-point type used for the tensors created in the following cells.
DTYPE='float32'
# Use the same type as the Keras default float type.
tf.keras.backend.set_floatx(DTYPE)

def Normalize_input(val,idx_pos, val_star):
    """
    Min-max scale the entries ``val[idx_pos]`` with the range of ``val_star``.

    The minimum of ``val_star`` maps to 0 and its maximum to 1; values of
    ``val`` outside that range map outside [0, 1].
    """
    star_min = np.min(val_star)
    star_range = np.max(val_star) - star_min
    return (val[idx_pos] - star_min) / star_range

def Normalize_input_all(val, val_star):
    """
    Min-max scale every entry of ``val`` with the range of ``val_star``.

    The minimum of ``val_star`` maps to 0 and its maximum to 1.
    """
    lower, upper = np.min(val_star), np.max(val_star)
    return (val - lower) / (upper - lower)

In [None]:
"""
Cell block to read well and reference data
"""

# The number of datapoints in x (x1), y (x2), and z (x3) direction of the reference model
N_x1 = 18
N_x2 = 11
N_x3 = 18

# The total number of datapoints
N = N_x1 * N_x2 * N_x3

# Read reference data
dataset = pd.read_csv('./Reference_model.csv')
# One row per reference point; columns are (easting, northing, elevation).
X = np.array((dataset['X_Easting'],dataset['Y_Northing'], dataset['Elevation'])).T
print("shape of X = ", X.shape)
# Reference fields as column vectors with one row per reference point.
T_true = np.array(dataset['T_degC']).reshape(X.shape[0] ,1)
k_true = np.array(dataset['log10PER']).reshape(X.shape[0] ,1)
p_true = np.array(dataset['P_Pa']).reshape(X.shape[0] ,1)

# Read well coordinates & temperatures, pore pressures, permeabilities at wells
dataset_wells = pd.read_csv('./Welldata_30wells.csv')
# One row per well data point; same column layout as X.
X_star = np.array((dataset_wells['X_Easting'],dataset_wells['Y_Northing'], dataset_wells['Elevation'])).T
T_star = np.array(dataset_wells['T_degC']).reshape(X_star.shape[0] ,1)
k_star = np.array(dataset_wells['log10PER']).reshape(X_star.shape[0] ,1)
p_star = np.array(dataset_wells['P_Pa']).reshape(X_star.shape[0] ,1)
# TensorFlow copies of the well observations in the working dtype.
T_star_tf = tf.constant(T_star, DTYPE)
p_star_tf = tf.constant(p_star, DTYPE)
k_star_tf = tf.constant(k_star, DTYPE)

# Elevation threshold to define the boundary of training and validation data
training_validation_boundary = -1700

# Get the indices of the domain boundaries.
idx_low, idx_up, idx_west, idx_east, idx_south, idx_north \
            = gb.get_boundary_indexlist(N_x1, N_x2, N_x3, X)
print('idx_low.shape = ', len(idx_low))
print('idx_up.shape = ', len(idx_up))
print('idx_west.shape = ', len(idx_west))
print('idx_east.shape = ', len(idx_east))
print('idx_south.shape = ', len(idx_south))
print('idx_north.shape = ', len(idx_north))

# Number of points on the east boundary and on the south boundary.
ndx_ew = len(idx_east)
ndx_ns = len(idx_south)

# Type of the bottom boundary condition
if_low_dirichlet = True  # Dirichlet boundary (Temperature & Pressure boundary)
#if_low_dirichlet = False  # Neumann boundary (Heat flow boundary)

# Heat flow at the bottom boundary
qN = 0.6

# Representative thermal conductivity
Lambda = 2.0

# Thermal gradient at the bottom (heat flow divided by thermal conductivity)
dTdn = qN / Lambda

# Number of wells used for interpolation
pointnumber_of_interpolate = 1

# Interpolation parameters used to create approximate functions for Dirichlet boundary conditions on the top and bottom surfaces
order_interpolate_dirichlet = 1
regularization_interpolate_dirichlet = 1e-4

In [None]:
# Input grid
def _flat_to_x1x2x3(flat_indices):
    """
    Convert flat grid indices into (x1, x2, x3) index triplets.

    The flat indices are unravelled against a (N_x3, N_x2, N_x1) shape and the
    resulting columns are reordered from (x3, x2, x1) to (x1, x2, x3).
    """
    unravelled = np.unravel_index(flat_indices, shape=(N_x3, N_x2, N_x1))
    return np.transpose(unravelled)[:, [2, 1, 0]]


# Per-axis indices of the top, bottom and the four side faces (west, east, south, north)
up_indices = _flat_to_x1x2x3(idx_up)
low_indices = _flat_to_x1x2x3(idx_low)
side_indices = _flat_to_x1x2x3(
    np.concatenate([idx_west, idx_east, idx_south, idx_north]))

grid_spec = GridSpec(
    N_x1=N_x1,
    N_x2=N_x2,
    N_x3=N_x3,
    up_indices=up_indices,
    low_indices=low_indices,
    side_indices=side_indices,
)

print(grid_spec.up_indices.shape)

In [None]:
"""
Cell block to specify boundary conditions
"""

# Specify input domain bounds from the per-axis extrema of the reference grid
# Lower bounds
lb = tf.constant(X.min(0), dtype=DTYPE)
# Upper bounds
ub = tf.constant(X.max(0), dtype=DTYPE)
# Random ordering of all N grid indices (handed to the solver as PI_indices)
idx_f = np.random.choice(N,N, replace=False)
#------------------------------------------------------------------

# Dirichlet condition at the upper surface boundary
X_uTpbD1 = tf.constant(X[idx_up, :], dtype=DTYPE)
# The same coordinates are listed twice (presumably one entry each for T and p)
X_uTpbD = [X_uTpbD1,X_uTpbD1]
Y_uTpbD_T = tf.constant(Normalize_input(T_true, idx_up, T_star), dtype=DTYPE)
Y_uTpbD_p = tf.constant(Normalize_input(p_true, idx_up, p_star), dtype=DTYPE)

# Dirichlet condition at the bottom boundary
X_lTpbD1 = tf.constant(X[idx_low, :], dtype=DTYPE)
X_lTpbD = [X_lTpbD1,X_lTpbD1]
Y_lTpbD_T = tf.constant(Normalize_input(T_true, idx_low, T_star), dtype=DTYPE)
Y_lTpbD_p = tf.constant(Normalize_input(p_true, idx_low, p_star), dtype=DTYPE)

# Neumann condition at the side boundary (target value is zero on every side point)
x_pbN1 = tf.constant(X[idx_west, :], dtype=DTYPE)
x_pbN2 = tf.constant(X[idx_east, :], dtype=DTYPE)
x_pbN3 = tf.constant(X[idx_south, :], dtype=DTYPE)
x_pbN4 = tf.constant(X[idx_north, :], dtype=DTYPE)
X_pbN = tf.concat([x_pbN1, x_pbN2, x_pbN3, x_pbN4], axis = 0)
Y_pbN = tf.zeros((X_pbN.shape[0], 1), dtype=DTYPE)
# Unit normal vector of each side face
n_west_vec = tf.constant([[-1.0, 0.0, 0.0]], dtype=DTYPE)
n_east_vec = tf.constant([[1.0, 0.0, 0.0]], dtype=DTYPE)
n_south_vec = tf.constant([[0.0, -1.0, 0.0]], dtype=DTYPE)
n_north_vec = tf.constant([[0.0, 1.0, 0.0]], dtype=DTYPE)
# Tiling multiples: one copy of the normal per point of the face
n_west_size = tf.constant([N_x2 * N_x3,1], tf.int32)
n_east_size = tf.constant([N_x2 * N_x3,1], tf.int32)
n_south_size = tf.constant([N_x1 * N_x3,1], tf.int32)
n_north_size = tf.constant([N_x1 * N_x3,1], tf.int32)
n_west = tf.tile(n_west_vec,n_west_size)
n_east = tf.tile(n_east_vec,n_east_size)
n_south = tf.tile(n_south_vec,n_south_size)
n_north = tf.tile(n_north_vec,n_north_size)
# Normals stacked in the same west/east/south/north order as X_pbN
nvec_p = tf.concat([n_west, n_east, n_south, n_north], axis = 0)

# Neumann condition at the bottom boundary (constant gradient dTdn on every point)
X_TbN = tf.constant(X[idx_low, :], dtype=DTYPE)
Y_TbN = dTdn * tf.ones((X_TbN.shape[0], 1), dtype=DTYPE)
nTvec = tf.constant([[0.0, 0.0, -1.0]], dtype=DTYPE)
nTsize = tf.constant([N_x1 * N_x2, 1], tf.int32)
# Tiled with tf.tile like the side normals, so nvec_T is a TF tensor of DTYPE
# instead of the result of NumPy implicitly converting TF tensors.
nvec_T = tf.tile(nTvec, nTsize)

# Coordinates in all analyzed domain
x1_all = tf.constant(X[:,0:1], dtype=DTYPE)
x2_all = tf.constant(X[:,1:2], dtype=DTYPE)
x3_all = tf.constant(X[:,2:3], dtype=DTYPE)
X_all = tf.concat([x1_all, x2_all, x3_all], axis=1)
#------------------------------------------------------------------

In [None]:
"""
Cell block that obtains grid points near the wells.
"""

def _normalized_at(values, indices, reference):
    """Normalize ``values[indices]`` against the range of ``reference`` as a DTYPE tensor."""
    return tf.constant(Normalize_input(values, indices, reference), dtype=DTYPE)


def _normalized_wells(values, reference):
    """Normalize all of ``values`` against the range of ``reference`` as a DTYPE tensor."""
    return tf.constant(Normalize_input_all(values, reference), dtype=DTYPE)


# Get indices for datapoints near wells
idx_tpk_train, idx_tpk_val = swtv.get_wells_neighbor_indecies(
    X_star, X, training_validation_boundary, pointnumber_of_interpolate)

# Divide well data into training and validation data
# (the *_denm arrays are normalized further below)
(X_wells_train, T_wells_train_denm, p_wells_train_denm, k_wells_train_denm,
 X_wells_val, T_wells_val_denm, p_wells_val_denm, k_wells_val_denm) = \
    swtv.get_training_and_validation_data(
        X_star, T_star, p_star, k_star, training_validation_boundary)

# Number of grid points selected for training and for validation
N_train = idx_tpk_train.shape[0]
N_val = idx_tpk_val.shape[0]

# Finish in case of no training points
if N_train == 0:
    print("There are no training points for the loss function. Exit.")
    sys.exit()
# Warning in case of no validation point
if N_val == 0:
    print("Warning： There is no validation point. The loss function for the validation data will be NaN.")

# Number of well records in total and in each split
N_wells = X_star.shape[0]
N_wells_train = X_wells_train.shape[0]
N_wells_val = X_wells_val.shape[0]

# Well coordinates of each split as tensors
X_wells_train = tf.constant(X_wells_train, dtype=DTYPE)
X_wells_val = tf.constant(X_wells_val, dtype=DTYPE)

# Training data: reference-model values at the grid points near the wells
X_train = tf.constant(X[idx_tpk_train, :], dtype=DTYPE)
T_train = _normalized_at(T_true, idx_tpk_train, T_star)
p_train = _normalized_at(p_true, idx_tpk_train, p_star)
k_train = _normalized_at(k_true, idx_tpk_train, k_star)

# Validation data, built the same way from the validation indices
X_val = tf.constant(X[idx_tpk_val, :], dtype=DTYPE)
T_val = _normalized_at(T_true, idx_tpk_val, T_star)
p_val = _normalized_at(p_true, idx_tpk_val, p_star)
k_val = _normalized_at(k_true, idx_tpk_val, k_star)

# Normalize well data
T_wells_train = _normalized_wells(T_wells_train_denm, T_star)
p_wells_train = _normalized_wells(p_wells_train_denm, p_star)
k_wells_train = _normalized_wells(k_wells_train_denm, k_star)

T_wells_val = _normalized_wells(T_wells_val_denm, T_star)
p_wells_val = _normalized_wells(p_wells_val_denm, p_star)
k_wells_val = _normalized_wells(k_wells_val_denm, k_star)

In [None]:
"""
Cell block to create tf.data.Dataset for training and validation data
"""
#---------------------------------------------------------
# Each dataset holds a single element: one tuple carrying the whole batch.
# Tuple layout: coordinates, side Neumann targets, bottom Neumann targets,
# well coordinates, well T, well p, well k, grid indices.
train_ds = tf.data.Dataset.from_tensors((X_train,
                                        Y_pbN,Y_TbN,
                                        X_wells_train,
                                        T_wells_train,
                                        p_wells_train,
                                        k_wells_train,
                                        idx_tpk_train
                                        ))

val_ds = tf.data.Dataset.from_tensors((X_val,
                                        Y_pbN,
                                        Y_TbN,
                                        X_wells_val,
                                        T_wells_val,
                                        p_wells_val,
                                        k_wells_val,
                                        idx_tpk_val
                                    ))

#--------------------------------- All grid data
idx_all = np.arange(0,N,1)

# Temperature: coordinates and normalized reference values on every grid point
x1_T_all = tf.constant(X[idx_all,0:1], dtype=DTYPE)
x2_T_all = tf.constant(X[idx_all,1:2], dtype=DTYPE)
x3_T_all = tf.constant(X[idx_all,2:3], dtype=DTYPE)
X_T_all = tf.concat([x1_T_all, x2_T_all, x3_T_all], axis=1)
Y_T_all = tf.constant(Normalize_input(T_true,idx_all, T_star), dtype=DTYPE)

# Pressure: same layout
x1_p_all = tf.constant(X[idx_all,0:1], dtype=DTYPE)
x2_p_all = tf.constant(X[idx_all,1:2], dtype=DTYPE)
x3_p_all = tf.constant(X[idx_all,2:3], dtype=DTYPE)
X_p_all = tf.concat([x1_p_all, x2_p_all, x3_p_all], axis=1)
Y_p_all = tf.constant(Normalize_input(p_true,idx_all, p_star), dtype=DTYPE)

# Permeability: coordinates are taken on idx_all so that X_k_all lines up
# row-for-row with Y_k_all (and with the T and p arrays above).
x1_k_all = tf.constant(X[idx_all,0:1], dtype=DTYPE)
x2_k_all = tf.constant(X[idx_all,1:2], dtype=DTYPE)
x3_k_all = tf.constant(X[idx_all,2:3], dtype=DTYPE)
X_k_all = tf.concat([x1_k_all, x2_k_all, x3_k_all], axis=1)
Y_k_all = tf.constant(Normalize_input(k_true,idx_all, k_star), dtype=DTYPE)
#---------------------------------------------------------
X_alldata = X_p_all
YT_alldata = Y_T_all
Yp_alldata = Y_p_all
Yk_alldata = Y_k_all

# Only X_alldata and idx_all are specific to the full grid here; the well
# entries reuse the validation tensors as dummy data, because the full-grid
# dataset is rarely used.
all_ds = tf.data.Dataset.from_tensors((X_alldata,
                                        Y_pbN,Y_TbN,
                                        X_wells_val,
                                        T_wells_val,
                                        p_wells_val,
                                        k_wells_val,
                                        idx_all
                                    ))

In [None]:
#==========================================================================
##### Run PINN solver
# Create and build the network; the same `model` object is handed to every
# solver constructed in the stages below.
model = PINN_NeuralNet3(lb, ub)
model.build(input_shape=(None, N_x1, N_x2, N_x3, 3))

# Stage 0 settings: lmode -1 and unit weights for all loss coefficients
# (the meaning of each lmode value is defined inside PINNSolver_3D)
lmode = -1
cof_PI1_AD = cof_PI2_AD = cof_pbN = cof_TbN = 1

solver = PINNSolver_3D(
    model=model,
    model_input_grid_spec=grid_spec,
    if_low_dirichlet=if_low_dirichlet,
    PI_indices=idx_f,
    X_uTpbD=X_uTpbD,
    X_lTpbD=X_lTpbD,
    X_pbN=X_pbN,
    nvec_p=nvec_p,
    X_TbN=X_TbN,
    nvec_T=nvec_T,
    X_all=X_all,
    Y_uTpbD_T=Y_uTpbD_T,
    Y_uTpbD_p=Y_uTpbD_p,
    Y_lTpbD_T=Y_lTpbD_T,
    Y_lTpbD_p=Y_lTpbD_p,
    cof_PI1AD=cof_PI1_AD,
    cof_PI2AD=cof_PI2_AD,
    cof_pbN=cof_pbN,
    cof_TbN=cof_TbN,
    T_true=T_true,
    p_true=p_true,
    k_true=k_true,
    T_for_normalize=T_star_tf,
    p_for_normalize=p_star_tf,
    k_for_normalize=k_star_tf,
    lmode=lmode,
    dtype=DTYPE,
)

# Optimizer selection for this stage: 'TFoptimizer' (Adam) or 'L-BFGS'
mode = 'TFoptimizer'

if mode == 'L-BFGS':
    print('#================= L-BFGS training')
    max_iter = 5000
    lbfgs_result_state = solver.solve_with_tfp_lbfgs(
        train_ds=train_ds,
        val_ds=val_ds,
        all_ds=all_ds,
        max_iterations=max_iter,
        debug_print=False,
    )
    # Record how many iterations were actually run
    TF_maxstep = lbfgs_result_state.num_total_iterations
elif mode == 'TFoptimizer':
    print('#================= First Adam training without calculate physics and bound terms')
    TF_maxstep = 12001
    # Learning rate: 5e-3 up to step 3000, 2.5e-3 up to step 10000, then 5e-4
    lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[3000, 10000], values=[5e-3, 2.5e-3, 5e-4])
    optim = tf.keras.optimizers.Adam(learning_rate=lr)
    solver.solve_with_TFoptimizer(optim, train_ds, val_ds, all_ds, N=TF_maxstep)

# Keep this stage's loss history under its own name before the next stage runs
os.rename('./save_checkpoints/histories_loss.csv',
          './save_checkpoints/histories_loss_run0.csv')

##### Run PINN solver 1
# Stage 1 keeps every loss coefficient at 1 and switches the solver to lmode 0
# (the meaning of each lmode value is defined inside PINNSolver_3D).
cof_PI1_AD = 1
cof_PI2_AD = 1
cof_pbN = 1
cof_TbN = 1

# Initialize model
lmode = 0
solver2 = PINNSolver_3D(
    model=model,
    model_input_grid_spec=grid_spec,
    if_low_dirichlet=if_low_dirichlet,
    PI_indices=idx_f,
    X_uTpbD=X_uTpbD,X_lTpbD=X_lTpbD,
    X_pbN=X_pbN,nvec_p=nvec_p, X_TbN=X_TbN, nvec_T=nvec_T,
    X_all=X_all,
    Y_uTpbD_T=Y_uTpbD_T, Y_uTpbD_p=Y_uTpbD_p,
    Y_lTpbD_T=Y_lTpbD_T, Y_lTpbD_p=Y_lTpbD_p,
    cof_PI1AD=cof_PI1_AD, cof_PI2AD=cof_PI2_AD, cof_pbN=cof_pbN, cof_TbN=cof_TbN,
    T_true=T_true, p_true=p_true, k_true=k_true,
    T_for_normalize=T_star_tf, p_for_normalize=p_star_tf, k_for_normalize=k_star_tf,
    lmode=lmode,
    dtype=DTYPE)

# NOTE(review): the statement below only references the attribute; it is not
# a call, so nothing is restored here. Training presumably still continues
# from the previous stage because the same `model` object is passed in.
# Confirm whether a call was intended (and which arguments it takes) before
# changing it.
solver2.restore_checkpoint

# Optimizer selection for this stage: 'TFoptimizer' (Adam) or 'L-BFGS'
mode = 'TFoptimizer'

if mode == 'TFoptimizer':
    print('#================= First Adam training')
    TF_maxstep = 2001
    # NOTE(review): both schedule boundaries lie beyond TF_maxstep, so
    # presumably only the first rate (5e-4) is ever applied in this stage.
    lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay([3000,10000],[5e-4,1e-4,5e-5])
    optim = tf.keras.optimizers.Adam(learning_rate=lr)
    solver2.solve_with_TFoptimizer(optim, train_ds, val_ds, all_ds, N=TF_maxstep)
elif mode == 'L-BFGS':
    print('#================= L-BFGS training')
    max_iter = 20000
    lbfgs_result_state = solver2.solve_with_tfp_lbfgs(
        train_ds=train_ds, val_ds=val_ds, all_ds=all_ds,
        max_iterations=max_iter, debug_print=True)
    # The next stage looks up row TF_maxstep-2 of this run's loss history, so
    # TF_maxstep must describe this run; take it from the L-BFGS result, the
    # same way the L-BFGS branch of the first run does.
    TF_maxstep = lbfgs_result_state.num_total_iterations

# Keep this stage's loss history under its own name; the next stage reads it
os.rename(src='./save_checkpoints/histories_loss.csv', dst='./save_checkpoints/histories_loss_run1.csv')

##### Run PINN solver 2
# Read the loss history of run 1 and derive weighting coefficients for the
# PI1 / PI2 (and bottom Neumann) losses so that each weighted term matches the
# average of the T, p and k network losses at row TF_maxstep-2 of that history.
lhistries_df = pd.read_csv('./save_checkpoints/histories_loss_run1.csv')
Tloss = lhistries_df.Tnet_loss
Ploss = lhistries_df.Pnet_loss
Kloss = lhistries_df.Knet_loss
TbNloss = lhistries_df.TbN_loss
# TF_maxstep still holds the step count set by run 1
Tloss_last = Tloss[TF_maxstep-2]
Ploss_last = Ploss[TF_maxstep-2]
Kloss_last = Kloss[TF_maxstep-2]
TbNloss_last = TbNloss[TF_maxstep-2]
PI1loss_AD = lhistries_df.PI1_loss_AD
PI2loss_AD = lhistries_df.PI2_loss_AD
PI1loss_AD_last = PI1loss_AD[TF_maxstep-2]
PI2loss_AD_last = PI2loss_AD[TF_maxstep-2]
aveloss_last = (Tloss_last+Ploss_last+Kloss_last)/3
cof_PI1_AD = aveloss_last / PI1loss_AD_last
cof_PI2_AD = aveloss_last / PI2loss_AD_last

# The side Neumann term gets zero weight in this stage
cof_pbN = 0.0
# Avoid dividing by a bottom Neumann loss that is exactly zero
if TbNloss_last == 0.0:
    cof_TbN = 0.0
else:
    cof_TbN = aveloss_last / TbNloss_last

# Initialize model
# (the meaning of each lmode value is defined inside PINNSolver_3D)
lmode = 1
solver2 = PINNSolver_3D(
    model=model,
    model_input_grid_spec=grid_spec,
    if_low_dirichlet=if_low_dirichlet,
    PI_indices=idx_f,
    X_uTpbD=X_uTpbD,X_lTpbD=X_lTpbD,
    X_pbN=X_pbN,nvec_p=nvec_p, X_TbN=X_TbN, nvec_T=nvec_T,
    X_all=X_all,
    Y_uTpbD_T=Y_uTpbD_T, Y_uTpbD_p=Y_uTpbD_p,
    Y_lTpbD_T=Y_lTpbD_T, Y_lTpbD_p=Y_lTpbD_p,
    cof_PI1AD=cof_PI1_AD, cof_PI2AD=cof_PI2_AD, cof_pbN=cof_pbN, cof_TbN=cof_TbN,
    T_true=T_true, p_true=p_true, k_true=k_true,
    T_for_normalize=T_star_tf, p_for_normalize=p_star_tf, k_for_normalize=k_star_tf,
    lmode=lmode,
    dtype=DTYPE)

# NOTE(review): the statement below only references the attribute; it is not
# a call, so nothing is restored here. Training presumably still continues
# from the previous stage because the same `model` object is passed in.
# Confirm whether a call was intended (and which arguments it takes) before
# changing it.
solver2.restore_checkpoint

# Optimizer selection for this stage: 'TFoptimizer' (Adam) or 'L-BFGS'
mode = 'TFoptimizer'

if mode == 'TFoptimizer':
    print('#================= Second Adam training')
    TF_maxstep = 30001
    # Learning rate: 5e-4 up to step 3000, 1e-4 up to step 10000, then 5e-5
    lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay([3000,10000],[5e-4,1e-4,5e-5])
    optim = tf.keras.optimizers.Adam(learning_rate=lr)
    solver2.solve_with_TFoptimizer(optim, train_ds, val_ds, all_ds, N=TF_maxstep)
elif mode == 'L-BFGS':
    print('#================= L-BFGS training')
    max_iter = 2000
    lbfgs_result_state = solver2.solve_with_tfp_lbfgs(
        train_ds=train_ds, val_ds=val_ds, all_ds=all_ds,
        max_iterations=max_iter, debug_print=True)
    # The next stage looks up row TF_maxstep-2 of this run's loss history, so
    # TF_maxstep must describe this run; take it from the L-BFGS result, the
    # same way the L-BFGS branch of the first run does.
    TF_maxstep = lbfgs_result_state.num_total_iterations

# Keep this stage's loss history under its own name; the next stage reads it
os.rename(src='./save_checkpoints/histories_loss.csv', dst='./save_checkpoints/histories_loss_run2.csv')


##### Run PINN solver 3
# Re-derive the loss weights from the history written by run 2: each weight is
# the average of the T, p and k network losses divided by the loss it scales,
# all taken at row TF_maxstep-2 of that history.
lhistries_df = pd.read_csv('./save_checkpoints/histories_loss_run2.csv')
history_row = TF_maxstep - 2  # TF_maxstep still holds the value set by run 2

Tloss = lhistries_df['Tnet_loss']
Ploss = lhistries_df['Pnet_loss']
Kloss = lhistries_df['Knet_loss']
TbNloss = lhistries_df['TbN_loss']
PI1loss_AD = lhistries_df['PI1_loss_AD']
PI2loss_AD = lhistries_df['PI2_loss_AD']

Tloss_last = Tloss[history_row]
Ploss_last = Ploss[history_row]
Kloss_last = Kloss[history_row]
TbNloss_last = TbNloss[history_row]
PI1loss_AD_last = PI1loss_AD[history_row]
PI2loss_AD_last = PI2loss_AD[history_row]

aveloss_last = (Tloss_last+Ploss_last+Kloss_last)/3
cof_PI1_AD = aveloss_last / PI1loss_AD_last
cof_PI2_AD = aveloss_last / PI2loss_AD_last

# The side Neumann term gets zero weight; the bottom Neumann weight falls back
# to zero when its loss is exactly zero, which avoids a division by zero.
cof_pbN = 0.0
cof_TbN = 0.0 if TbNloss_last == 0.0 else aveloss_last / TbNloss_last

# Initialize model
# (the meaning of each lmode value is defined inside PINNSolver_3D)
lmode = 2
solver3 = PINNSolver_3D(
    model=model,
    model_input_grid_spec=grid_spec,
    if_low_dirichlet=if_low_dirichlet,
    PI_indices=idx_f,
    X_uTpbD=X_uTpbD,
    X_lTpbD=X_lTpbD,
    X_pbN=X_pbN,
    nvec_p=nvec_p,
    X_TbN=X_TbN,
    nvec_T=nvec_T,
    X_all=X_all,
    Y_uTpbD_T=Y_uTpbD_T,
    Y_uTpbD_p=Y_uTpbD_p,
    Y_lTpbD_T=Y_lTpbD_T,
    Y_lTpbD_p=Y_lTpbD_p,
    cof_PI1AD=cof_PI1_AD,
    cof_PI2AD=cof_PI2_AD,
    cof_pbN=cof_pbN,
    cof_TbN=cof_TbN,
    T_true=T_true,
    p_true=p_true,
    k_true=k_true,
    T_for_normalize=T_star_tf,
    p_for_normalize=p_star_tf,
    k_for_normalize=k_star_tf,
    lmode=lmode,
    dtype=DTYPE,
)

# NOTE(review): the statement below only references the attribute; it is not
# a call, so nothing is restored here. Confirm whether a call was intended.
solver3.restore_checkpoint

# Optimizer selection for this stage: 'TFoptimizer' (Adam) or 'L-BFGS'
mode = 'L-BFGS'

if mode == 'L-BFGS':
    print('#================= L-BFGS training')
    max_iter = 10000
    solver3.solve_with_tfp_lbfgs(
        train_ds=train_ds,
        val_ds=val_ds,
        all_ds=all_ds,
        max_iterations=max_iter,
        debug_print=True,
    )
elif mode == 'TFoptimizer':
    print('#================= Third Adam training')
    # Learning rate: 5e-4 up to step 3000, 1e-4 up to step 10000, then 5e-5
    lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[3000, 10000], values=[5e-4, 1e-4, 5e-5])
    optim = tf.keras.optimizers.Adam(learning_rate=lr)
    solver3.solve_with_TFoptimizer(optim, train_ds, val_ds, all_ds, N=10001)

# Keep this stage's loss history under its own name
os.rename('./save_checkpoints/histories_loss.csv',
          './save_checkpoints/histories_loss_run3_LBFGS.csv')
